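I recently wanted to use the Flux package to solve a simple Laplace equation (a problem involving second derivatives).
However, I'm not sure how to compute these derivatives with Tracker.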
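Written out, the model in the code below appears to be the following. This is reconstructed from the definitions of `u`, `resi` and `loss`. Here $N$ denotes the network `m`, and $(x_i, y_i)$ are the 465 columns of `X`:

$$
u(x, y) = \sin\!\left(\frac{\pi x}{2}\right) N(x, y), \qquad
r(x, y) = \frac{\partial^2 u}{\partial x^2} + \frac{\partial^2 u}{\partial y^2}, \qquad
L = \sum_{i=1}^{465} r(x_i, y_i)^2
$$

The factor $\sin(\pi x / 2)$ vanishes at $x = 0$, so it is presumably there to enforce the boundary condition $u(0, y) = 0$. Applying the product rule shows which derivatives of $N$ the nested Tracker calls must produce:

$$
\frac{\partial^2 u}{\partial x^2} = -\frac{\pi^2}{4}\sin\!\left(\frac{\pi x}{2}\right) N + \pi\cos\!\left(\frac{\pi x}{2}\right)\frac{\partial N}{\partial x} + \sin\!\left(\frac{\pi x}{2}\right)\frac{\partial^2 N}{\partial x^2},
\qquad
\frac{\partial^2 u}{\partial y^2} = \sin\!\left(\frac{\pi x}{2}\right)\frac{\partial^2 N}{\partial y^2}
$$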
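Here is the code:

```julia
using Flux
X = rand(2, 465)     # 465 sample points; each column is one (x, y) pair in [0, 1)^2
Y = zeros(2, 465)    # dummy targets; loss below does not use them
m = Chain(Dense(2, 20, tanh), Dense(20, 20, tanh), Dense(20, 1))  # network N(x, y)
mat_x(varIn) = [1.0 0.0] * varIn   # select the x row
mat_y(varIn) = [0.0 1.0] * varIn   # select the y row
u(x) = sin.(π .* mat_x(x) ./ 2.) .* m(x)   # trial solution sin(πx/2) * N(x, y)
# First derivatives: pull a seed of 1 back through u, keep the x (or y) row of the gradient
ux(x) = [1.0 0.0] * Tracker.forward(u, x)[2](1)[1]
uy(x) = [0.0 1.0] * Tracker.forward(u, x)[2](1)[1]
# Second derivatives: differentiate ux and uy again (nested Tracker.forward)
uxx(x) = [1.0 0.0] * Tracker.forward(ux, x)[2](1)[1]
uyy(x) = [0.0 1.0] * Tracker.forward(uy, x)[2](1)[1]
resi(x) = uxx(x) + uyy(x)          # Laplace residual u_xx + u_yy
loss(x, y) = sum(resi(x).^2)       # sum of squared residuals; y is unused
using Base.Iterators: repeated
using Flux: throttle
dataset = repeated((X, Y), 2)      # the same batch, twice
evalcb = () -> @show(loss(X, Y))
opt = ADAM()
# evalcb runs at most once every 10 seconds
Flux.train!(loss, params(m), dataset, opt, cb = throttle(evalcb, 10))
```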
Running it produces the following error:

```
MethodError: Cannot `convert` an object of type Base.ReshapedArray{Float64,2,LinearAlgebra.Transpose{Float64,Array{Float64,2}},Tuple{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}}} to an object of type LinearAlgebra.Transpose{Float32,Array{Float32,2}}
Closest candidates are:
convert(::Type{LinearAlgebra.Transpose{T,S}}, !Matched::LinearAlgebra.Transpose) where {T, S} at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.1/LinearAlgebra/src/adjtrans.jl:139
convert(::Type{T<:AbstractArray}, !Matched::T<:AbstractArray) where T<:AbstractArray at abstractarray.jl:14
convert(::Type{T<:AbstractArray}, !Matched::LinearAlgebra.Factorization) where T<:AbstractArray at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.1/LinearAlgebra/src/factorization.jl:46
...
```
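Once the error is resolved, one way to check whether the nested `Tracker.forward` calls return the intended second derivatives is to compare them against a finite-difference approximation. Finite differences are a swapped-in technique used only as a check, not the author's method. This is a minimal sketch that assumes a point-wise function taking a 2-element vector `[x, y]` and returning a plain real number. The helper name `fd_laplacian` and the step size `h` are hypothetical and not part of Flux:

```julia
# Hypothetical helper (not part of Flux): approximates u_xx + u_yy at p = [x, y]
# with second-order central differences; h is an assumed step size.
function fd_laplacian(u, p::AbstractVector{<:Real}; h::Real = 1e-3)
    ex = [h, zero(h)]   # step in the x direction
    ey = [zero(h), h]   # step in the y direction
    u0 = u(p)
    uxx = (u(p .+ ex) - 2u0 + u(p .- ex)) / h^2
    uyy = (u(p .+ ey) - 2u0 + u(p .- ey)) / h^2
    return uxx + uyy
end

# x^2 + y^2 has Laplacian 4 everywhere; x^2 - y^2 is harmonic (Laplacian 0).
println(fd_laplacian(p -> p[1]^2 + p[2]^2, [0.3, 0.7]))   # close to 4.0
println(fd_laplacian(p -> p[1]^2 - p[2]^2, [0.5, -0.2]))  # close to 0.0
```

Evaluating such a check at a few columns of `X` and comparing the results with `uxx(x) + uyy(x)` would show whether the automatic-differentiation route agrees with the product-rule expansion of $u_{xx} + u_{yy}$.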
I'm a beginner and not yet familiar with Flux or Julia.
Could someone point out where the error comes from?
Thanks!