求高手赐教
目的是拟合曲线y=sin(x), 0.0<x<3.0;
代码如下:
using Flux;
using Flux: @epochs, throttle;
# 产生数据集
xs = reshape(collect(0:0.05:3.0), 1, :);
ys = sin.(xs) .+ 0.15*rand(1, length(xs));
dataset = [(xs, ys)]; # or dataset = Iterators.repeated((xs, ys), 1000)
# 定义模型
m = Chain(Dense(1, 50, tanh),
Dense(50, 25),
Dense(25, 1));
loss(x, y) = Flux.mse(m(x), y); # 定义误差函数
evalcb() = @show(loss(xs, ys))
opt = ADAM();
ps = params(m);
# 训练模型
@epochs 1000 Flux.train!(loss, ps, dataset, opt, cb = throttle(evalcb, 10));
# 绘制数据图像
using Gadfly;
l1 = Gadfly.layer(x=xs, y=ys, Geom.point, Theme(default_color="black"));
Gadfly.plot(l1, Guide.xlabel("x"), Guide.ylabel("y"), Guide.title("sin(x)"), Guide.manual_color_key("Legend", ["origin"], ["black"]))
x0 = [1.5 2.5];
println("真实值:$(sin.(x0));", "预测值:$(m(x0))")
l2 = Gadfly.layer(x=x0, y=m(x0), Geom.point, Theme(default_color="red"))
Gadfly.plot(l1, l2, Guide.xlabel("x"), Guide.ylabel("y"), Guide.title("sin(x)"), Guide.manual_color_key("Legend", ["origin", "predict"], ["black", "red"]))
Gadfly.plot(x=[1.5 2.5], y=m([1.5 2.5]), Geom.point)
出现的问题是,利用训练的模型求解得到的值无法绘图(提示类型错误) ;
报错截图如下: