关于使用Flux.Tracker.forward计算函数二阶导数


#1

最近想使用Flux包计算一个简单的拉普拉斯方程(二阶导数问题)。

发现不太清楚如何使用Tracker计算。

代码如下:

using Flux

X = rand(2, 465)
Y = zeros(2, 465)

m = Chain(Dense(2, 20, tanh), Dense(20, 20, tanh), Dense(20, 1))

mat_x(varIn) = [1.0 0.0] * varIn
mat_y(varIn) = [0.0 1.0] * varIn

u(x) = sin.(π .* mat_x(x) ./ 2.) .* m(x)

ux(x) = [1.0 0.0] * Tracker.forward(u, x)[2](1)[1]
uy(x) = [0.0 1.0] * Tracker.forward(u, x)[2](1)[1]
uxx(x) = [1.0 0.0] * Tracker.forward(ux, x)[2](1)[1]
uyy(x) = [0.0 1.0] * Tracker.forward(uy, x)[2](1)[1]

resi(x) = uxx(x) + uyy(x)
loss(x, y) = sum(resi(x).^2)

using Base.Iterators: repeated
using Flux: throttle

dataset = repeated((X, Y), 2)
evalcb = () -> @show(loss(X, Y))
opt = ADAM()

Flux.train!(loss, params(m), dataset, opt, cb = throttle(evalcb, 10))

发现报错如下:

MethodError: Cannot `convert` an object of type Base.ReshapedArray{Float64,2,LinearAlgebra.Transpose{Float64,Array{Float64,2}},Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}} to an object of type LinearAlgebra.Transpose{Float32,Array{Float32,2}}
Closest candidates are:
  convert(::Type{LinearAlgebra.Transpose{T,S}}, !Matched::LinearAlgebra.Transpose) where {T, S} at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.1/LinearAlgebra/src/adjtrans.jl:139
  convert(::Type{T<:AbstractArray}, !Matched::T<:AbstractArray) where T<:AbstractArray at abstractarray.jl:14
  convert(::Type{T<:AbstractArray}, !Matched::LinearAlgebra.Factorization) where T<:AbstractArray at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.1/LinearAlgebra/src/factorization.jl:46
  ...

新手入门,还不太会使用Flux和Julia。
请教一下是在哪里出现了错误?
感谢!


#2

你这个牵涉到梯度的嵌套,然而默认的Flux.train!函数里面求梯度的时候没使用nest=true参数,所以才会出错,如果想顺利运行的话,可以把最后Flux.train!(...)那一行改为:

ps = params(m)
for (X,Y) in dataset
    gs = Tracker.gradient(()->@show(loss(X,Y)), ps,nest=true)
    Tracker.update!(opt, ps, gs)
end

#3

谢谢!:kissing_heart: