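When using Flux.Tracker.gradient(), I found that the gradient expression apparently cannot be used directly in algebraic operations?
The code is as follows: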
nMesh = 101    # number of grid points
omega = 1.0
# X and Y are reshaped into 1×nMesh row matrices
X = Vector(range(0, stop=1, length=nMesh))
X = hcat(X...)
Y = zeros(nMesh)
Y = hcat(Y...)
m = Chain(Dense(1, 20, tanh), Dense(20, 1), softmax)
hyper(x) = 1.0 .+ m(x) .* x
# the Tracker.gradient call below is where the error is raised
resi(x) = omega .* hyper(x) .+ Tracker.gradient(hyper, x; nest = true)[1]
loss(x, y) = sum(resi(x).^2)
This raises the following error:
Function output is not scalar
Stacktrace:
[1] error(::String) at ./error.jl:33
[2] losscheck(::TrackedArray{…,Array{Float64,2}}) at /home/vavrines/.julia/packages/Tracker/6wcYJ/src/back.jl:153
[3] gradient_nested(::Function, ::Array{Float64,2}) at /home/vavrines/.julia/packages/Tracker/6wcYJ/src/back.jl:160
[4] #gradient#24 at /home/vavrines/.julia/packages/Tracker/6wcYJ/src/back.jl:164 [inlined]
[5] #gradient at ./none:0 [inlined]
[6] resi(::Array{Float64,2}) at ./In[13]:36
[7] loss(::Array{Float64,2}, ::Array{Float64,2}) at ./In[13]:37
[8] #15 at /home/vavrines/.julia/packages/Flux/zNlBL/src/optimise/train.jl:72 [inlined]
[9] gradient_(::getfield(Flux.Optimise, Symbol("##15#21")){typeof(loss),Tuple{Array{Float64,2},Array{Float64,2}}}, ::Tracker.Params) at /home/vavrines/.julia/packages/Tracker/6wcYJ/src/back.jl:97
[10] #gradient#24(::Bool, ::Function, ::Function, ::Tracker.Params) at /home/vavrines/.julia/packages/Tracker/6wcYJ/src/back.jl:164
[11] gradient at /home/vavrines/.julia/packages/Tracker/6wcYJ/src/back.jl:164 [inlined]
[12] macro expansion at /home/vavrines/.julia/packages/Flux/zNlBL/src/optimise/train.jl:71 [inlined]
[13] macro expansion at /home/vavrines/.julia/packages/Juno/TfNYn/src/progress.jl:133 [inlined]
[14] #train!#12(::getfield(Flux, Symbol("#throttled#18")){getfield(Flux, Symbol("##throttled#10#14")){Bool,Bool,getfield(Main, Symbol("##17#18")),Int64}}, ::Function, ::Function, ::Tracker.Params, ::Base.Iterators.Take{Base.Iterators.Repeated{Tuple{Array{Float64,2},Array{Float64,2}}}}, ::ADAM) at /home/vavrines/.julia/packages/Flux/zNlBL/src/optimise/train.jl:69
[15] (::getfield(Flux.Optimise, Symbol("#kw##train!")))(::NamedTuple{(:cb,),Tuple{getfield(Flux, Symbol("#throttled#18")){getfield(Flux, Symbol("##throttled#10#14")){Bool,Bool,getfield(Main, Symbol("##17#18")),Int64}}}}, ::typeof(Flux.Optimise.train!), ::Function, ::Tracker.Params, ::Base.Iterators.Take{Base.Iterators.Repeated{Tuple{Array{Float64,2},Array{Float64,2}}}}, ::ADAM) at ./none:0
[16] top-level scope at In[13]:42
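One possible reading of this trace, offered as a sketch and not a confirmed fix: the error comes from losscheck, which suggests that Tracker.gradient only accepts a function returning a scalar, whereas hyper(x) returns a 1×101 array. Since each column of x only affects its own output column, the gradient of sum(hyper(x)) with respect to x should coincide with the pointwise derivative of hyper. The helper name dhyper is hypothetical, the snippet assumes the same Tracker-based Flux release as in the trace, and softmax is left out on the assumption that, applied to a single output row, it would return all ones.

```julia
using Flux  # assumption: a Tracker-based Flux release, as in the stack trace

omega = 1.0
X = reshape(collect(range(0, stop=1, length=101)), 1, :)  # 1×101 input batch

# softmax omitted here (assumption, see the note above)
m = Chain(Dense(1, 20, tanh), Dense(20, 1))
hyper(x) = 1.0 .+ m(x) .* x

# Hypothetical helper: reduce to a scalar first so that Tracker.gradient accepts it.
# Each column of x only influences its own output, so the gradient of the sum
# is the pointwise derivative of hyper with respect to x.
dhyper(x) = Tracker.gradient(z -> sum(hyper(z)), x; nest = true)[1]

resi(x) = omega .* hyper(x) .+ dhyper(x)
loss(x, y) = sum(resi(x) .^ 2)
```

With this change, loss(X, Y) itself returns a scalar, which is what the outer gradient call in train! would need; whether nested differentiation works for every layer in this Flux version is not verified here.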
The same problem does not occur when using a TensorFlow placeholder, for example:
import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder)
omega = 1.0
varIn = tf.placeholder(name="input", dtype=tf.float64, shape=[None, 1])
core = varIn
hypothesis = 1.0 + varIn * core
residual = omega * hypothesis + tf.split(tf.gradients(hypothesis, varIn)[0], num_or_size_splits=1, axis=1)
Flux's documentation is not very complete. Has anyone run into the same problem, or does anyone know a concrete solution?
Thanks.