采用Flux训练的模型求解出来的结果无法绘制图形,疑问

求高手赐教
目的是拟合曲线y=sin(x), 0.0<x<3.0;
代码如下:

using Flux;
using Flux: @epochs, throttle;

# 产生数据集
xs = reshape(collect(0:0.05:3.0), 1, :);
ys = sin.(xs) .+ 0.15*rand(1, length(xs));
dataset = [(xs, ys)];  # or dataset = Iterators.repeated((xs, ys), 1000)

# 定义模型
m = Chain(Dense(1, 50, tanh),
    Dense(50, 25),
    Dense(25, 1));

loss(x, y) = Flux.mse(m(x), y); # 定义误差函数
evalcb() = @show(loss(xs, ys))
opt = ADAM();
ps = params(m);

# 训练模型
@epochs 1000 Flux.train!(loss, ps, dataset, opt, cb = throttle(evalcb, 10));

# 绘制数据图像
using Gadfly;
l1 = Gadfly.layer(x=xs, y=ys, Geom.point, Theme(default_color="black"));
Gadfly.plot(l1, Guide.xlabel("x"), Guide.ylabel("y"), Guide.title("sin(x)"), Guide.manual_color_key("Legend", ["origin"], ["black"]))
x0 = [1.5 2.5];
println("真实值:$(sin.(x0));", "预测值:$(m(x0))")
l2 = Gadfly.layer(x=x0, y=m(x0), Geom.point, Theme(default_color="red"))
Gadfly.plot(l1, l2, Guide.xlabel("x"), Guide.ylabel("y"), Guide.title("sin(x)"), Guide.manual_color_key("Legend", ["origin", "predict"], ["black", "red"]))
Gadfly.plot(x=[1.5 2.5], y=m([1.5 2.5]), Geom.point)

出现的问题是,利用训练的模型求解得到的值无法绘图(提示类型错误) ;
报错截图如下:

报错问题是类型不匹配,所以要想办法把 Track Array 里的数据取出来;

>y=m(x)
>?y

查看Track Array 用法,
然后发现画图的时候应该用 y.data

十分感谢你的解答,:+1:
修正后代码如下:

# ----------(x) -> f(x)--------------
using Flux;
using Flux: @epochs, throttle;

# 产生数据集
xs = reshape(collect(0:0.05:6.0), 1, :);
ys = 2.0.*sin.(xs) .+ 0.15*rand(1, length(xs));
dataset = [(xs, ys)];  # or dataset = Iterators.repeated((xs, ys), 1000)

# 定义模型
m = Chain(Dense(1, 50, tanh),
    Dense(50, 25),
    Dense(25, 1));

loss(x, y) = Flux.mse(m(x), y); # 定义误差函数
evalcb() = @show(loss(xs, ys))
opt = ADAM();
ps = Flux.params(m);

# 训练模型
# @epochs 10000 Flux.train!(loss, ps, dataset, opt, cb = throttle(evalcb, 10));
for i in 1:3000
    Flux.train!(loss, ps, dataset, opt);
    if i%100 == 0
        println("$(i) steps loss is $(loss(xs, ys))")        
    end
end

# 绘制数据图像
using Gadfly;
l1 = Gadfly.layer(x=xs, y=ys, Geom.point, Theme(default_color="black"));
Gadfly.plot(l1, Guide.xlabel("x"), Guide.ylabel("y"), Guide.title("sin(x)"), Guide.manual_color_key("Legend", ["origin"], ["black"]))
x0 = xs;
println("真实值:$(2.0.*sin.(x0));", "预测值:$(Float64.(m(x0).data))")
l2 = Gadfly.layer(x=x0, y=m(x0).data, Geom.line, Theme(default_color="red"))
p = Gadfly.plot(l1, l2, Guide.xlabel("x"), Guide.ylabel("y"), Guide.title("sin(x)"), Guide.manual_color_key("Legend", ["origin", "predict"], ["black", "red"]))
p |> SVG("pic.SVG")

输出图像如下:

京ICP备17009874号-2