过多同名函数导致类型推断失败?

在做一个大型项目时发现的东西, 下面是一个可以重现这个问题的简单例子

abstract type CustomType end

struct AType <: CustomType end
struct BType <: CustomType end
struct CType <: CustomType end
struct DType <: CustomType end

func(a::AType, x) = 1 + x
func(b::BType, x) = 2 + x
func(c::CType, x) = 3 + x
func(d::DType, x) = 4 + x

struct EType
    cust::CustomType
end

foo(e::EType, x) = func(e.cust, x)

运行 code_warntype(foo, (EType, Int)) 输出

MethodInstance for foo(::EType, ::Int64)
  from foo(e::EType, x) in Main at e:\Codes\Jl\nyasRT.jl\aaa.jl:17
Arguments
  #self#::Core.Const(foo)
  e::EType
  x::Int64
Body::Any
1 ─ %1 = Base.getproperty(e, :cust)::CustomType
│   %2 = Main.func(%1, x)::Any
└──      return %2

但是把第4个 func 注释掉1个后再运行 code_warntype 得到

MethodInstance for foo(::EType, ::Int64)
  from foo(e::EType, x) in Main at e:\Codes\Jl\nyasRT.jl\aaa.jl:17
Arguments
  #self#::Core.Const(foo)
  e::EType
  x::Int64
Body::Int64
1 ─ %1 = Base.getproperty(e, :cust)::CustomType
│   %2 = Main.func(%1, x)::Int64
└──      return %2

也就是说 func 数量超过3个的话编译器就不能推断返回类型了.

有什么办法使得编译器推断 foo 的返回类型呢.
在实际项目里是每一个 func 都返回相同类型的, 但是编译器就是不懂, 给 func 添加参数类型声明和直接声明 func 的返回类型都不行.

给一些我的观察。在你的例子中,

struct EType
    cust::CustomType
end

本身是一个类型不稳定的struct定义。即便在只有三个实例的情况下,虽然它能推断Body返回的是Int64,但是1 ─ %1 = Base.getproperty(e, :cust)::CustomType依旧是不稳定的。

如果你把struct本身定义为稳定的,

julia> struct GType
           cust::DType
       end

julia> foo(g::GType, x) = func(g.cust, x)
foo (generic function with 3 methods)

julia> code_warntype(foo, (GType, Int))
MethodInstance for foo(::GType, ::Int64)
  from foo(g::GType, x) in Main at REPL[28]:1
Arguments
  #self#::Core.Const(foo)
  g::Core.Const(GType(DType()))
  x::Int64
Body::Int64
1 ─ %1 = Base.getproperty(g, :cust)::Core.Const(DType())
│   %2 = Main.func(%1, x)::Int64
└──      return %2

或者使用parametric struct

julia> struct HType{T<:CustomType}
           cust::T
       end

julia> h = HType(DType())
HType{DType}(DType())

julia> foo(h::HType, x) = func(h.cust, x)
foo (generic function with 4 methods)

julia> @code_warntype foo(h, 1)
MethodInstance for foo(::HType{DType}, ::Int64)
  from foo(h::HType, x) in Main at REPL[37]:1
Arguments
  #self#::Core.Const(foo)
  h::Core.Const(HType{DType}(DType()))
  x::Int64
Body::Int64
1 ─ %1 = Base.getproperty(h, :cust)::Core.Const(DType())
│   %2 = Main.func(%1, x)::Int64
└──      return %2

那么衍生出来的问题就是,原先的

struct EType
    cust::CustomType
end

这个写法其实完全可以被

struct HType{T<:CustomType} end

替换掉,而类型实例则可以写为

h = HType{DType}()

确实有考虑过第二种方法,但是在我的项目里另外的地方会组成一个元素类型为 EType 的数组, 并且 EType 里有3个类似 CustomType 的抽象类型对象 (3个对象之间是没关联的), 我看性能建议说应该避免过多的多重派发.
所以使用 HType 那样方法会把类型不稳定带到其他地方吗?

其实如果可以告诉编译器 func(::CustomType, ::Int) 必定会返回 Int 的话就兼顾代码简洁性并且使性能损失减到最少 (大概?). 但是没找到这种方法.

其实我发现把语句

foo(e::EType, x) = func(e.cust, x)

替换为

foo(e::EType, x::Int) = func(e.cust, x)::Int

的话就可以给编译器声明返回的类型, 但是因为 func 在项目里很多地方都在用, 全部加上这个类型声明的话就不太好看了.
而且这样子实际运行还是从 func(e.cust, x) 拿到 Any 再判断是否为 Int, 感觉这样的性能损失就不太应该发生 (C艹就没有这样的性能损失).

如果你想通过类型稳定来保证代码效率,那么使用含参类(parametric struct)是通常的选择。我例子中的HType保证了类型稳定;一般需要避免将抽象类型Abstract type作为struct内部变量的类型定义,可以参考官方的performance tips。

如果你怀疑定义三个类HType{P,Q,R}会滥用多重派发,最好实践对比一下。也可以参考GitHub上一些别人写的Julia库是如何构建自定义类型的。

一般来说,如果想要拿到高性能地代码的话,不应该使用抽象类型来声明结构体.

这是一个非常好的观察… 之所以如此是因为 Julia 提供了一个 Union-Splitting 地机制来优化一些非常小的 “抽象” 集合体的性能。https://julialang.org/blog/2018/08/union-splitting/ 这个技术允许 Julia 在这些非常小的类型集合上依然优化出比较不错的性能,这最主要的一个应用就在比如说 Union{Nothing,T} 这种场景,比如 iter(X) 的返回值就是这样。

其实如果可以告诉编译器 func(::CustomType, ::Int) 必定会返回 Int 的话就兼顾代码简洁性并且使性能损失减到最少 (大概?). 但是没找到这种方法.

在 Julia 里面,你几乎很少需要通过手动标注类型来指定某个变量或函数的返回值类型,换句话说,在绝大部分时候,你都不需要做下面这种操作

foo(e::EType, x::Int) = func(e.cust, x)::Int

这种操作本质上是将类型不稳定在 foo 的层面阻断了,但是在 func 内部依然是不稳定的。

更好的策略就是 @henry2004y 所说的用参数化类型来构建类型稳定代码,从而让 Julia 和编译期去把这些事情给推导出来并进行优化。很多时候编译器比你想象的要聪明,只是需要你告诉它足够多的(类型)信息。

1 个赞

确实使用含参类可以使 foo(::HType) 的返回类型确定, 但是使用 Hype 组成的数组就变成了类型不确定了.
下面是项目里调用 func 的部分内容:

hlist = HType[]

push!(hlist, HType(AType()))
push!(hlist, HType(DType()))

function usefulfunc(hs, x)
    res = 0
    for h ∈ hs
        res += foo(h, x)
    end
    return res
end

使用 @code_warntype usefulfunc(hlist, 1) 发现类型不确定的问题就被抛到了 h ∈ hs 上, usefulfunc 依然返回的是 Any. 因为在项目里 HType 是有多个参数类型的, 所以也不能把 usefulfunc 限制到 usefulfunc(hs::Vector{HType{T}}, x) where {T} 上.

对于怎么优化这部分自己已经完全没有头绪了, 还请大佬多多指教.

对于无法消除的类型不稳定,一个常用的技巧是 function barrier https://docs.julialang.org/en/v1/manual/performance-tips/#kernel-functions

在你的这个例子里,foo 就是 kernel function,至少在 foo 的内部,类型是稳定的。

如果还要进一步阻断类型不稳定的话,那么就不要用一个 hlist,而是用一组 hlist 来分别存放不同类型的元素: hlist_atype, hlist_btype, ... 然后构建一个更大的 kernel 函数。

在我的项目里 HType 是有3个类型参数的, 总共产生超过100种具体类型. 但是在实际运行里只有几个或者十几个 HType 对象, 并且对象之间的具体类型很可能都不相同的.
并且 usefulfunc 在一次程序运行里会被调用几千万甚至上亿次, 实际上大部分时间都花在运行时的类型推断上了.

回到我一开始的程序 (使用 EType), 并且 func 数量不超过3个, 下面是使用 @btime 测量的运行情况

5.458 s (45 allocations: 59.46 MiB)

但是当 func 数量超过3个, 就算用上 HType, 并且在 usefulfunc 里面声明返回类型 res += foo(h, x)::Int, 这时的运行情况也十分不理想

62.008 s (1562596710 allocations: 68.37 GiB)

并且使用 julia --track-allocation=user 确定了在 res += foo(h, x)::Int 这一行产生了巨量的内存占用.

所以也许比起让编译器确定函数的输入类型, 我这里的情况可能更需要让编译器确定函数的输出类型?

这里的讨论似乎有点抓不住重点,我和 @henry2004y 说的和你所设想的似乎并不是一件事情。如果你可以给一个 MWE 的话可能更好讨论一些,比如说如果你能够把两个版本的完整测试代码都拿出来的话我们会更容易判断究竟问题在哪里。

这个vector无论怎么处理都会是类型不稳定的,本质上和[Int16, Int32, Int64, Float32]这样的没啥区别。我感觉还是核心函数的结构设计上有问题,可以改进。我之前也遇到过类似的问题,是通过把不稳定的类型放到更外层的函数,然后让kernel稳定的方式来提高效率的。@johnnychen94 如果想让这里的res尽可能稳定你会怎么做?

1 个赞

@johnnychen94 如果想让这里的res 尽可能稳定你会怎么做?

这个例子的问题在于 foo 太简单了,因此性能差距特别明显,办法是拆分出更大的 kernel function 来让底层计算依然在类型稳定的基础上进行。

v1:

hlist = HType[]

for i in 1:10000
    v = rand()
    if v > 0.7
        push!(hlist, HType(AType()))
    elseif v > 0.4
        push!(hlist, HType(BType()))
    elseif v > 0.2
        push!(hlist, HType(CType()))
    else
        push!(hlist, HType(DType()))
    end
end

function usefulfunc(hs::Vector{<:HType}, x)
    res = 0
    for h ∈ hs
        res += foo(h, x)
    end
    return res
end

@btime usefulfunc($hlist, 2) # 248.732 μs (9883 allocations: 154.42 KiB)

v2:

hss = [HType{AType}[], HType{BType}[], HType{CType}[], HType{DType}[]]
for i in 1:10000
    v = rand()
    if v > 0.7
        push!(hss[1], HType(AType()))
    elseif v > 0.4
        push!(hss[2], HType(BType()))
    elseif v > 0.2
        push!(hss[3], HType(CType()))
    else
        push!(hss[4], HType(DType()))
    end
end
usefulfunc(hs::Vector{Vector}, x) = mapreduce(hs->usefulfunc(hs, x), +, hss)

@btime usefulfunc($hss, 2) # 168.340 ns (12 allocations: 192 bytes)
1 个赞

补充一点,多重派发性能快的前提是类型在编译期间能够进行推断,上面的这段代码其实已经违背了这一前提,因为 rand() 的随机性导致你不可能在编译期间推断所有类型。能做的仅仅是通过构建 function barrier 来让 Julia 编译器尽可能地优化底层函数,从而让类型不稳定对性能的影响不那么大。

不是所有场景都能够追求绝对的类型可推断的,比如说 FileIO 里面读取一个文件 load(::String) 就没办法从 String 类型去推断 load 的返回值类型是什么:它有可能是 Array{RGB}, 也可能是 DataFrame 或者各种各样的类型,完全取决于 runtime 的具体数值。但是为了方便性,依然还是构建了这么一个东西出来。

我感觉DataFrames.jl里面也是这样的。

我已经把整块项目提交到GitHub上, 这是链接, 因为这个是个人项目, 阅读起来可能会有点迷惑.

运行这个代码需要 StaticArrays 包. 而 PNGFilesColorTypes 仅用于输出图片, 把 ./test.jl 里相应的地方注释掉就不需要这两个包.

在项目里与使用 EType 相同的代码放在 ./src.backup 里, 而使用与 HType 相同的参数类型的代码放在 ./src 里. test 第1, 第2 行可以切换两个不同实现的模块.

项目里与 ETypeHType 相似的类型在放在 ./src/object.jl 里的 Object, 而 CustomType 则对应 ./src/surfaces.jl 里面的 AbstractSurface, func 对应着 surface 下 record.

现在产生问题的地方在 ./src/world.jl 里, ObjectWorld 以数组储存, 并且 hitobject 函数反复调用 record(obj.surface, ...). 具体构建 World 对象的函数是在 ./test.jl 里的 arealight_scenes.

我进行测试时是直接 include("test.jl"), 然后运行 @time main((1080, 1920), 16, nothing) (或者 @btime).
因为在 ./src.backup/surfaces.jl 下的 Circle 被注释了, 所以编译器可以获得 record(obj.surface, ...) 的返回类型从而在短时间内运行完成; 但如果没有注释掉 Circle 的话就会运行超长时间.

struct World{Cam, Amb, LT<:AbstractLight}
    camera::Cam
    ambient::Amb
    objects::Vector{Object}
    lights::Dict{LT,Vector{T}}
    objlights::Vector{Int}
end

然后

- push!(world.lights, nyasRT.Directional(Vec3(0, -1, -.5), RGB(.2, .8, 1)))
+ push!(world.lights[Directional], nyasRT.Directional(Vec3(0, -1, -.5), RGB(.2, .8, 1)))

后面渲染的时候也分别对每种类型单独做循环,比如说

- for light in world.lights
+ for LT in subtypes(AbstractLight)
+     haskey(world.lights, LT) || continue
+     for light in world.lights[LT]
      ...

这样的话就是把更小的 kernel 函数 (对某一个光的渲染)转换成了更大的 kernel 函数 (对某一类光的所有光源进行渲染)

没有做测试,但是应该可以降低类型不稳定带来的开销。(实际上应该已经把类型稳定化了)

直接给 World 加上 LT 类参的话, 不是会直接限制了下面 lights 的具体类型, 从而导致不能 push 入其他关照类型?

并且目前造成性能问题的是 world.jl 里面检查光线与 objects 碰撞的 hitobject 函数.

大概看明白了不过代码有点太长了所以不太好直接改出一个可用的版本。

这里的话因为 ObjectAbstractLight 的实际类型还是相对有限的,所以可以将每一类存放在对应的列表里,然后再将每一类的列表构造成一个大的 Tuple

struct ObjectList{T}
    objects::T
end

struct World{Cam, Amb, OT, LT}
    camara::Cam
    ambient::Amb
    objects::OT
    lights::LT
    objlights::Vector{Tuple{Int,Int}}
end

然后对应 world 的构造稍微调整一下,比如说:

# 这里可以枚举所有可能的类型,如果某个类型没有内容那么就返回一个空的列表,这样可以保证
# 构造出来的 objects 类型可以在编译期确定
obj_balls = [nyasRT.Object(ballsur, balltex, ballbrdf)]
obj_lights = [...]
obj_floors = [...]
objects = (obj_balls, obj_lights, obj_floors)

# lights 同理
world(cam, amb, objects, lights)

这样下来,world 里面每个值的类型都是稳定的。

因为这个调整涉及到其他代码的逻辑,所以我这里就没有去测试了。

确实这样子可以使得 objects 的类型全部确定下来, 下面代码是我对这个方法的实现, 因为这个函数只会在一开始调用一次, 所以这个函数内的类型不稳定就不在意了.

function setupobjects(objs::Vector{Object})
    objdict = Dict{DataType, Any}()
    for obj ∈ objs
        T = typeof(obj)
        T ∉ keys(objdict) && (objdict[T] = T[])
        push!(objdict[T], obj)
    end
    return (values(objdict)..., )
end

下面是对这个函数的测试

julia> objs isa Vector{nyasRT.Object}
true

julia> objlists = nyasRT.setupobjects(objs);

julia> typeof(objlists)
Tuple{Vector{Main.nyasRT.Object{Main.nyasRT.Sphere, Main.nyasRT.PureColor, Main.nyasRT.Phong}}, 
      Vector{Main.nyasRT.Object{Main.nyasRT.InftyPlane, Main.nyasRT.PureColor, Main.nyasRT.Lambertian}}}

但是问题是如何对 objlists 里面的元素进行历遍呢, 下面是我设想的一种方法

function hitobjects(objs, ray)
    hitsetidx = hiteleidx = 0
    rec = nyasRT.HittingRecord()
    for (setidx, objlists) ∈ enumerate(objs)
        for (eleidx, obj) ∈ enumerate(objlists)
            tmprec = nyasRT.hit(obj.surface, ray, rec.t)
            if tmprec.hitted
                rec = tmprec
                hitsetidx = setidx
                hiteleidx = eleidx
            end
        end
    end
    return (hitsetidx, hiteleidx), rec
end

但是通过 code_warntypes 查看这个函数依然是类型不稳定的

julia> @code_warntype hitobjects(objs, nyasRT.Ray(Vec3(0), Vec3(0)))
MethodInstance for hitobjects(::Tuple{Vector{Object{Sphere, PureColor, Phong}}, Vector{Object{InftyPlane, PureColor, Lambertian}}}, ::Ray)
  from hitobjects(objs, ray) in Main at 
Arguments
  #self#::Core.Const(test3)
  objs::Tuple{Vector{Object{Sphere, PureColor, Phong}}, Vector{Object{InftyPlane, PureColor, Lambertian}}}
  ray::Ray
Locals
  @_4::Union{Nothing, Tuple{Union{Vector{Object{InftyPlane, PureColor, Lambertian}}, Vector{Object{Sphere, PureColor, Phong}}}, Int64}}
  rec::HittingRecord
  @_6::Union{Nothing, Tuple{Object{InftyPlane, PureColor, Lambertian}, Int64}, Tuple{Object{Sphere, PureColor, Phong}, Int64}}
  objlists::Union{Vector{Object{InftyPlane, PureColor, Lambertian}}, Vector{Object{Sphere, PureColor, Phong}}}
  obj::Union{Object{InftyPlane, PureColor, Lambertian}, Object{Sphere, PureColor, Phong}}
  tmprec::Tuple{Bool, Float64}
Body::HittingRecord
1 ─ %1  = HittingRecord::Core.Const(HittingRecord)
│         (rec = (%1)())
│   %3  = objs::Tuple{Vector{Object{Sphere, PureColor, Phong}}, Vector{Object{InftyPlane, PureColor, Lambertian}}}
│         (@_4 = Base.iterate(%3))
│   %5  = (@_4::Core.PartialStruct(Tuple{Vector{Object{Sphere, PureColor, Phong}}, Int64}, Any[Vector{Object{Sphere, PureColor, Phong}}, Core.Const(2)]) === nothing)::Core.Const(false)
│   %6  = Base.not_int(%5)::Core.Const(true)
└──       goto #6 if not %6
2 ┄ %8  = @_4::Union{Tuple{Vector{Object{InftyPlane, PureColor, Lambertian}}, Int64}, Tuple{Vector{Object{Sphere, PureColor, Phong}}, Int64}}
│         (objlists = Core.getfield(%8, 1))
│   %10 = Core.getfield(%8, 2)::Int64
│   %11 = objlists::Union{Vector{Object{InftyPlane, PureColor, Lambertian}}, Vector{Object{Sphere, PureColor, Phong}}}
│         (@_6 = Base.iterate(%11))
│   %13 = (@_6 === nothing)::Bool
│   %14 = Base.not_int(%13)::Bool
└──       goto #4 if not %14
3 ─ %16 = @_6::Union{Tuple{Object{InftyPlane, PureColor, Lambertian}, Int64}, Tuple{Object{Sphere, PureColor, Phong}, Int64}}
│         (obj = Core.getfield(%16, 1))
│         Core.getfield(%16, 2)
│   %19 = hit::Core.Const(hit)
│   %20 = Base.getproperty(obj, :surface)::Union{InftyPlane, Sphere}
│   %21 = Base.getproperty(rec::Core.Const(HittingRecord(false, Inf, [0.0, 0.0, 0.0], [0.0, 0.0, 0.0])), :t)::Core.Const(Inf)
│         (tmprec = (%19)(%20, ray, %21))
│         Base.getproperty(tmprec, :hitted)
│         Core.Const(:(Core.typeassert(%23, Core.Bool)))
│         Core.Const(:(rec = tmprec))
│         Core.Const(:(goto %27))
│         Core.Const(:(@_6 = Base.iterate(%11, %18)))
│         Core.Const(:(@_6 === nothing))
│         Core.Const(:(Base.not_int(%28)))
│         Core.Const(:(goto %32 if not %29))
└──       Core.Const(:(goto %16))
4 ┄       (@_4 = Base.iterate(%3, %10))
│   %33 = (@_4 === nothing)::Bool
│   %34 = Base.not_int(%33)::Bool
└──       goto #6 if not %34
5 ─       goto #2
6 ┄       return rec::Core.Const(HittingRecord(false, Inf, [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]))

我无意中翻到一篇旧文Analyzing sources of compiler latency in Julia: method invalidations,其中给出的第一组例子刚好能够解释本帖最开始的问题。我的理解是这是目前版本为了提高编译效率作出的妥协。

备案号:京ICP备17009874号-2