A question about @inbounds

While optimizing some code recently, I ran into a problem involving @inbounds.

using CuArrays
CuArrays.allowscalar(false)  # scalar indexing is disabled here

x = cu(rand(2,4,4))  # dims 2 and 3 both have length 4, so CartesianIndex(4, 4) is in bounds
y = cu([CartesianIndex(i, i) for i in 1:4])
x[:, y]

This raises an error:

scalar getindex is disallowed

Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] assertscalar(::String) at /home/tj/.julia/packages/GPUArrays/1wgPO/src/indexing.jl:14
 [3] getindex at /home/tj/.julia/packages/GPUArrays/1wgPO/src/indexing.jl:54 [inlined]
 [4] iterate at ./abstractarray.jl:914 [inlined]
 [5] iterate at ./abstractarray.jl:912 [inlined]
 [6] checkindex(::Type{Bool}, ::Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}, ::CuArray{CartesianIndex{2},1,Nothing}) at ./multidimensional.jl:514
 [7] checkbounds_indices at ./multidimensional.jl:509 [inlined]
 [8] checkbounds_indices at ./abstractarray.jl:529 [inlined]
 [9] checkbounds at ./abstractarray.jl:482 [inlined]
 [10] checkbounds at ./abstractarray.jl:503 [inlined]
 [11] _getindex at ./multidimensional.jl:669 [inlined]
 [12] getindex(::CuArray{Float32,3,Nothing}, ::Function, ::CuArray{CartesianIndex{2},1,Nothing}) at ./abstractarray.jl:981

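The root cause is that _getindex calls checkbounds (frames [9] to [11]).

The Julia documentation describes how checkbounds is implemented: it recurses over the indices as below, and when an index is itself an array of CartesianIndex, checkindex iterates over that array element by element. On a CuArray each of those reads is a scalar getindex (frames [3] to [6]), which allowscalar(false) forbids: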

checkbounds_indices(Bool, (IA1, IA...), (I1, I...)) = checkindex(Bool, IA1, I1) &
                                                      checkbounds_indices(Bool, IA, I)
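To see the element-by-element reads without a GPU, here is a small CPU-only sketch. CountingVec is a hypothetical wrapper (not part of CuArrays) that counts one-element-at-a-time reads, the kind of access allowscalar(false) turns into an error. Calling only checkbounds, with no gathering from x at all, already reads every element of the index vector, assuming Base's generic checkindex for arrays of CartesianIndex loops over them as the trace above suggests.

```julia
# Hypothetical index container: counts scalar (one-element-at-a-time) reads.
# With CuArrays.allowscalar(false), each such read on a CuArray is an error.
struct CountingVec{T} <: AbstractVector{T}
    data::Vector{T}
    reads::Base.RefValue{Int}
end
CountingVec(data::Vector) = CountingVec(data, Ref(0))

Base.size(v::CountingVec) = size(v.data)
Base.IndexStyle(::Type{<:CountingVec}) = IndexLinear()
function Base.getindex(v::CountingVec, i::Int)
    v.reads[] += 1              # record one scalar read
    return v.data[i]
end

x = rand(2, 4, 4)               # dims 2 and 3 have length 4
y = CountingVec([CartesianIndex(i, i) for i in 1:4])

# Only the bounds check runs here; nothing is gathered from x.
inbounds_ok = checkbounds(Bool, x, :, y)
println("in bounds: ", inbounds_ok, "; scalar reads of y during the check: ", y.reads[])
```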

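Since the logic of my code guarantees that no out-of-bounds access can happen here, I wanted to skip the bounds check entirely, so I defined the following function: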

@inline f(a, b) = @inbounds a[:, b]
f(x, y)

2×4 CuArray{Float32,2,Nothing}:
 0.94804   0.683741  0.203128  0.582778
 0.461565  0.36298   0.537245  0.469615
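A CPU-only sketch of what @inbounds changes, again using a hypothetical read-counting index vector (CountingVec, not a CuArrays type). getindex propagates the caller's inbounds context into _getindex, so inside f the @boundscheck-guarded checkbounds call is elided and the per-element check loop disappears. On the CPU the gather itself still reads y, so the @inbounds count is not zero; the post's output shows the CuArray gather runs without scalar reads. Running Julia with --check-bounds=yes forces all bounds checks back on, in which case both counts match.

```julia
# Hypothetical index container that counts scalar reads (same as before).
struct CountingVec{T} <: AbstractVector{T}
    data::Vector{T}
    reads::Base.RefValue{Int}
end
CountingVec(data::Vector) = CountingVec(data, Ref(0))

Base.size(v::CountingVec) = size(v.data)
Base.IndexStyle(::Type{<:CountingVec}) = IndexLinear()
function Base.getindex(v::CountingVec, i::Int)
    v.reads[] += 1              # record one scalar read
    return v.data[i]
end

@inline f(a, b) = @inbounds a[:, b]   # the function from the post

x = rand(2, 4, 4)
idx = [CartesianIndex(i, i) for i in 1:4]

y_checked = CountingVec(idx)
r_checked = x[:, y_checked]           # bounds check + gather

y_inbounds = CountingVec(idx)
r_inbounds = f(x, y_inbounds)         # gather only; the check is elided

println("reads with check: ", y_checked.reads[],
        ", reads under @inbounds: ", y_inbounds.reads[])
```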

So far so good. However, I need to use this function inside Zygote:

gradient(f, x, y)

and it raises the same error as before. How can I fix this?
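The trace below hints at why. Under Zygote, the getindex call in f is routed through ZygoteRules' _pullback (frame [14]) into Zygote's own adjoint for getindex (frame [13], lib/array.jl:30), and that adjoint indexes the array again itself. @inbounds only elides @boundscheck blocks in code that reaches it through layers that forward the inbounds context (as Base.@propagate_inbounds does), so an extra layer in between can bring the check back. The CPU-only sketch models that extra layer with a @noinline function g, a hypothetical stand-in for the adjoint rather than Zygote's actual code; CountingVec is again a hypothetical read-counting index vector.

```julia
# Hypothetical index container that counts scalar reads.
struct CountingVec{T} <: AbstractVector{T}
    data::Vector{T}
    reads::Base.RefValue{Int}
end
CountingVec(data::Vector) = CountingVec(data, Ref(0))

Base.size(v::CountingVec) = size(v.data)
Base.IndexStyle(::Type{<:CountingVec}) = IndexLinear()
function Base.getindex(v::CountingVec, i::Int)
    v.reads[] += 1              # record one scalar read
    return v.data[i]
end

@inline f(a, b) = @inbounds a[:, b]   # @inbounds wraps the indexing directly

# Hypothetical stand-in for an adjoint that redoes the indexing in its own body.
# @noinline keeps g out of the caller's @inbounds region, so the bounds
# check inside getindex runs again.
@noinline g(a, b) = a[:, b]
h(a, b) = @inbounds g(a, b)           # this @inbounds does not reach g's body

x = rand(2, 4, 4)
idx = [CartesianIndex(i, i) for i in 1:4]

y_f = CountingVec(idx)
r_f = f(x, y_f)
y_h = CountingVec(idx)
r_h = h(x, y_h)

println("scalar reads via f: ", y_f.reads[], ", via h: ", y_h.reads[])
```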

scalar getindex is disallowed

Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] assertscalar(::String) at /home/tj/.julia/packages/GPUArrays/1wgPO/src/indexing.jl:14
 [3] getindex at /home/tj/.julia/packages/GPUArrays/1wgPO/src/indexing.jl:54 [inlined]
 [4] iterate at ./abstractarray.jl:914 [inlined]
 [5] iterate at ./abstractarray.jl:912 [inlined]
 [6] checkindex(::Type{Bool}, ::Tuple{Base.OneTo{Int64},Base.OneTo{Int64}}, ::CuArray{CartesianIndex{2},1,Nothing}) at ./multidimensional.jl:514
 [7] checkbounds_indices at ./multidimensional.jl:509 [inlined]
 [8] checkbounds_indices at ./abstractarray.jl:529 [inlined]
 [9] checkbounds at ./abstractarray.jl:482 [inlined]
 [10] checkbounds at ./abstractarray.jl:503 [inlined]
 [11] _getindex at ./multidimensional.jl:669 [inlined]
 [12] getindex at ./abstractarray.jl:981 [inlined]
 [13] adjoint at /home/tj/.julia/packages/Zygote/8dVxG/src/lib/array.jl:30 [inlined]
 [14] _pullback at /home/tj/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47 [inlined]
 [15] f at ./In[36]:1 [inlined]
 [16] _pullback(::Zygote.Context, ::typeof(f), ::CuArray{Float32,3,Nothing}, ::CuArray{CartesianIndex{2},1,Nothing}) at /home/tj/.julia/packages/Zygote/8dVxG/src/compiler/interface2.jl:0
 [17] _pullback(::Function, ::CuArray{Float32,3,Nothing}, ::CuArray{CartesianIndex{2},1,Nothing}) at /home/tj/.julia/packages/Zygote/8dVxG/src/compiler/interface.jl:31
 [18] pullback(::Function, ::CuArray{Float32,3,Nothing}, ::CuArray{CartesianIndex{2},1,Nothing}) at /home/tj/.julia/packages/Zygote/8dVxG/src/compiler/interface.jl:37
 [19] gradient(::Function, ::CuArray{Float32,3,Nothing}, ::CuArray{CartesianIndex{2},1,Nothing}) at /home/tj/.julia/packages/Zygote/8dVxG/src/compiler/interface.jl:46
 [20] top-level scope at In[40]:
···

Could you define checkbounds for CuArray? Then the call would no longer dispatch to the generic method.

Yes, for now I've defined my own method:

Base.checkindex(::Type{Bool}, inds::Tuple, I::CuArray{<:CartesianIndex})
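The method body is not shown above. A minimal sketch of one way to fill it in, assuming a single bulk copy of the index array to the host is acceptable: move the indices off the device in one transfer and let Base's generic method check them on the CPU. For a real CuArray that body would be checkindex(Bool, inds, Array(I)). Since no GPU is assumed here, DeviceVec is a hypothetical CPU stand-in that throws on scalar reads, like a CuArray under allowscalar(false), and to_host is a hypothetical helper playing the role of Array(::CuArray).

```julia
# Hypothetical CPU stand-in for a CuArray with allowscalar(false):
# whole-array operations are fine, element-by-element reads throw.
struct DeviceVec{T} <: AbstractVector{T}
    data::Vector{T}
end
Base.size(v::DeviceVec) = size(v.data)
Base.IndexStyle(::Type{<:DeviceVec}) = IndexLinear()
Base.getindex(::DeviceVec, ::Int) = error("scalar getindex is disallowed")

# Hypothetical bulk device-to-host copy (plays the role of Array(::CuArray)).
to_host(v::DeviceVec) = copy(v.data)

x = rand(2, 4, 4)
y = DeviceVec([CartesianIndex(i, i) for i in 1:4])

# Before the override: the generic checkindex iterates y and hits the error.
failed_before = try
    checkbounds(Bool, x, :, y)
    false
catch
    true
end

# Same signature shape as in the post, specialised to the stand-in type:
# copy the indices once, then reuse Base's generic check on a plain Vector.
function Base.checkindex(::Type{Bool}, inds::Tuple, I::DeviceVec{<:CartesianIndex})
    return checkindex(Bool, inds, to_host(I))
end

ok = checkbounds(Bool, x, :, y)                                  # valid indices
bad = checkbounds(Bool, x, :, DeviceVec([CartesianIndex(5, 1)])) # dim 2 has length 4
println("threw before override: ", failed_before, "; valid: ", ok, "; out of bounds: ", bad)
```

Simply returning true from the method would also silence the error, but it would hide genuinely out-of-bounds indices; the copy-then-check version keeps them detectable at the cost of one transfer per check.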