自定义广播多个拼在一起出错

最近使用了 * 法广播来扩展一个结构体的算子,正常两个放在一起没问题,但是与其他广播混合使用却出现了问题,如下是最小问题复现:

mutable struct NamedArray
    data::Array
    name::String
    function NamedArray(x, name)
        new(x, name)
    end
end

function Base.Broadcast.broadcasted(::typeof(*), x::Array, y::NamedArray)
    return NamedArray(x .* y.data, y.name)
end

然后 [4 3 2 1] .* NamedArray([1 2 3 4], "julia") 可以出正确的结果,但是

([5 4 3 2] .- 1) .*  NamedArray([1 2 3 4],"julia")

就会报错

ERROR: MethodError: no method matching length(::NamedArray)
Closest candidates are:
  length(::Union{Base.KeySet, Base.ValueIterator}) at D:\julia-1.7.2-win64\share\julia\base\abstractdict.jl:58
  length(::Union{LinearAlgebra.Adjoint{T, S}, LinearAlgebra.Transpose{T, S}} where {T, S}) at D:\julia-1.7.2-win64\share\julia\stdlib\v1.7\LinearAlgebra\src\adjtrans.jl:171
  length(::Base.Unicode.GraphemeIterator{S}) where S at D:\julia-1.7.2-win64\share\julia\base\strings\unicode.jl:664

然后加了

Base.length(x::NamedArray) = length(x.data)

求长度方法后又报错

julia> ([5 4 3 2] .- 1) .*  NamedArray([1 2 3 4], "julia")
ERROR: MethodError: no method matching iterate(::NamedArray)
Closest candidates are:
  iterate(::Union{LinRange, StepRangeLen}) at D:\julia-1.7.2-win64\share\julia\base\range.jl:826
  iterate(::Union{LinRange, StepRangeLen}, ::Integer) at D:\julia-1.7.2-win64\share\julia\base\range.jl:826
  iterate(::T) where T<:Union{Base.KeySet{<:Any, <:Dict}, Base.ValueIterator{<:Dict}} at D:\julia-1.7.2-win64\share\julia\base\dict.jl:695

然后再创建个迭代方法还是不对…

julia> Base.iterate(s::NamedArray, i::Integer=1) = i>length(s.data) ? nothing : (s.data[i],i+1)
julia> ([5 4 3 2] .- 1) .*  NamedArray([1 2 3 4], "julia")
4×4 Matrix{Int64}:
  4   3  2  1
  8   6  4  2
 12   9  6  3
 16  12  8  4

按照我的理解,应该是先求出 ([5 4 3 2] .- 1) 再计算其与 NamedArray([1 2 3 4], "julia") 的点乘,就像下边的就是对的

julia> begin
           a = [5 4 3 2] .- 1;
           b = NamedArray([1 2 3 4], "julia");
           a .* b
       end
NamedArray([4 6 6 4], "julia")

同时,原生的 Julia 数组也是可以的

julia> (1 .- randn(1,3)) .* randn(2,3)
2×3 Matrix{Float64}:
 -0.566199   1.13881   1.47938
  0.243334  -1.59071  -0.38588

有懂的同学麻烦帮个忙,谢谢啦 ::

1 个赞

你这里没必要自己去实现所有的细节,可以简单利用 Julia 的矩阵协议来复用一些共同的实现:

mutable struct NamedArray{T,N,AT<:AbstractArray{A,T}} <: AbstractArray{T,N}
    data::AT
    name::String
    function NamedArray(x::AbstractArray{T,N}, name) where {T,N}
        new{T,N,typeof(x)}(x, name)
    end
end

Base.size(x::NamedArray) = size(x.data)
Base.@propagate_inbounds Base.getindex(x::NamedArray, inds::Int...) = x.data[inds...]

然后就都可以用了.

Ref: Interfaces · The Julia Language

如果确实想要自己实现一个广播的话,可以参考 Interfaces · The Julia Language 给了一个非常相似的例子

julia的广播机制参考18337 03的37min.

我复制你的代码来玩了一下:

(3-2) .* NamedArray([1 2 3 4], "julia")  # 复现相同的报错

NamedArray([1 2 3 4], "julia")[1] # 报错:没有定义getIndex

定义一个Array不定义getindex感觉有点说不过去哈.
我愿意押五块钱你这个问题应该在定义完getindex就能解决…

加了getindex 是不行的,跟我加了 iterate 后的结果是一样的 :sweat_smile:

这个例子确实跟我的例子几乎一样,解决了一部分问题,但是这样应该就是直接继承了数组的点操作,包括加减乘除。其实我想实现自定义加减乘除来实现一些额外操作,例如点加、点减、点乘、点除都分别对应一种对 .name 属性的操作,比如

function Base.Broadcast.broadcasted(::typeof(+), x::Array, y::NamedArray)
    return NamedArray(x .+ y.data, y.name * " add")
end

function Base.Broadcast.broadcasted(::typeof(-), x::Array, y::NamedArray)
    return NamedArray(x .- y.data, y.name * " minus")
end

function Base.Broadcast.broadcasted(::typeof(*), x::Array, y::NamedArray)
    return NamedArray(x .* y.data, y.name * " mul")
end

function Base.Broadcast.broadcasted(::typeof(/), x::Array, y::NamedArray)
    return NamedArray(x ./ y.data, y.name * " div")
end

boradcasted 应该返回一个 Broadcasted 或者其他相同接口的 lazy container 对象。你这里面的代码实际上会真正把每一个中间结果 (例如 x ./ y.data)算出来,再存储为 NamedArray 对象,先不谈如何做到你想要的功能,首先它在性能上就和广播的 “多个计算仅需要一次内存分配” 的特性相违背了。

如果你要实现自己的广播,最好复用标准的矩阵接口,然后再根据广播的接口来进行扩展,不是很清楚你想算出个什么东西,所以我这里随便造了一个函数 build_name 函数

mutable struct NamedArray{T,N,AT<:AbstractArray{T,N}} <: AbstractArray{T,N}
    data::AT
    name::String
    function NamedArray(x::AbstractArray{T,N}, name) where {T,N}
        new{T,N,typeof(x)}(x, name)
    end
end

Base.size(x::NamedArray) = size(x.data)
Base.@propagate_inbounds Base.getindex(x::NamedArray, inds::Int...) = x.data[inds...]
Base.@propagate_inbounds Base.setindex!(x::NamedArray, val, inds::Int...) = x.data[inds...] = val

Base.BroadcastStyle(::Type{<:NamedArray}) = Broadcast.ArrayStyle{NamedArray}()
function Base.similar(bc::Broadcast.Broadcasted{Broadcast.ArrayStyle{NamedArray}}, ::Type{T}) where {T}
    name = build_name(bc)
    NamedArray(similar(Array{T}, axes(bc)), name)
end

build_name(bc::Base.Broadcast.Broadcasted) = build_name(bc.f, bc.args)
function build_name(f, args::Tuple)
    if length(args) == 1
        return fname(f) * " " * argname(args[1])
    elseif length(args) == 2
        return argname(args[1]) * " " * fname(f) * " " * argname(args[2])
    else
        error("only support unary and binary operations.")
    end
    return name
end

fname(::typeof(-)) = "minus"
fname(::typeof(+)) = "add"
fname(::typeof(*)) = "mul"
fname(::typeof(/)) = "div"

argname(x) = ""
argname(x::NamedArray) = x.name
argname(bc::Base.Broadcast.Broadcasted) = build_name(bc)

julia> x1 = [5 4 3 2]
1×4 Matrix{Int64}:
 5  4  3  2

julia> x2 = 1
1

julia> x3 = NamedArray([1 2 3 4], " julia")
1×4 NamedArray{Int64, 2, Matrix{Int64}}:
 1  2  3  4

julia> y = (x1 .- x2) .* x3
1×4 NamedArray{Int64, 2, Matrix{Int64}}:
 4  6  6  4

julia> y.name
" minus  mul  julia"

换句话说,name 的构建是发生在为结果创建内存的过程 (similar) 而不是在实际数值计算的环节,因为前面说了 broadcasted 需要返回一个 lazy container 对象。


y = (x1 .- x2) .* x3

背后大概做的是类似于下面的操作:

bc = broadcasted(*, broadcasted(-, x1, x2), x3) # 这一步并不会真正计算数据
y = collect(bc) # 真正的计算和内存分配发生在这一步
2 个赞

多个计算仅需要一次内存分配, 在大部分场景中都是极其高效的,但是也有些特殊需求是想保留中间计算结果的,例如 (1 .- x) .* y ./ c, 中不妨将其拆解为三个算式:
a1 = (1 .- x)
a2 = a1 .* y
a3 = a2 ./ c
a1 a2 a3的计算结果都需要保存在某个结构中以供后续分析,这样就避免不了多次分配,这时候我们希望在三个表达式中仍然使用 broadcast 操作,但是不希望将多个 broadcast 一次性合并再统一计算,也就是按照加括号最优先,然后再按照加减乘除优先级依次计算。

我对Julia内核不了解,但是猜测目前这个例子里应该是没法这样搞的。不过还是非常感谢!

一般是两种方案,取决于你要在哪个抽象级别存储信息:

  • 如果可以逐元素存储的信息,最好采用 array of struct 的抽象,比如说 Array{T} 其中 T 是一个存储了额外信息的结构体。一个典型的例子就是自动微分框架的 dual number Automatic Differentiation with Dual Numbers | juliabloggers.com
  • 如果是矩阵本身的 metadata,用上面这种方式将 metadata 的信息计算拆开到一个单独的环节进行处理。与广播的逐元素计算混在一起会因为不太兼容其他的 Julia 矩阵类型导致很多不必要的胶水代码。