我照着原来的示例改了一个判断大小的版本，就可以用了。
using Flux: Chain, Dense, params, crossentropy, onehotbatch, mse,
ADAMW, train!, softmax
using Test
# Data preparation
"""
    func(x, y)

Compare `x` with `y` and return `"gt"` if `x > y`, `"lt"` if `x < y`, and `"eq"`
otherwise (this also covers unordered pairs, e.g. when either value is `NaN`).
"""
func(x, y) = x > y ? "gt" : (x < y ? "lt" : "eq")

"""
    func(a::AbstractArray)

Apply `func(x, y)` column-wise, taking `x` from row 1 and `y` from row 2 of `a`.
A 2×n matrix yields an n-element vector of labels; rows beyond the second are ignored.
Accepts any `AbstractArray` (views, adjoints, ...), not only a plain `Array`.
"""
func(a::AbstractArray) = func.(a[1, :], a[2, :])
const LABELS = ["gt", "lt", "eq"];
@test func([1 0 1; 0 1 1]) == LABELS
# Draw both coordinates from the grid {0, 1/9, ..., 1} instead of continuous
# `rand(2, n)`. With continuous data the chance that x == y is zero, so the "eq"
# class would never appear in the training set and the model could never learn it.
# With 10 grid values, roughly 1 column in 10 is a tie. Dividing the same integer
# by 9 gives bit-identical floats, so the ties are exact.
raw_x = rand(0:9, 2, 200) ./ 9;
raw_y = func(raw_x);
X = raw_x;
Y = onehotbatch(raw_y, LABELS);
# Model
# 2 inputs -> 10 hidden units -> 3 class probabilities (ordered as in `LABELS`).
# NOTE(review): `Dense` defaults to the identity activation, so these two layers
# compose to a single affine map. That suffices for separating gt/eq/lt, but add a
# nonlinearity (e.g. `relu`) if a more expressive model is ever needed.
m = Chain(Dense(2, 10), Dense(10, 3), softmax)
# Cross-entropy is the standard loss for softmax class probabilities against
# one-hot targets. Mean-squared error through a softmax gives vanishing gradients
# once the outputs saturate, even when the prediction is wrong.
loss(x, y) = crossentropy(m(x), y)
opt = ADAMW()
# Helpers
"""
    deepbuzz(x, y)

Return the label from `LABELS` that the global model `m` scores highest for the
input pair `(x, y)`.
"""
deepbuzz(x, y) = LABELS[argmax(m([x; y]))]

"""
    deepbuzz(a::AbstractArray)

Column-wise version of `deepbuzz(x, y)`, taking `x` from row 1 and `y` from row 2
of `a`. Accepts any `AbstractArray` (views, adjoints, ...), not only a plain `Array`.
"""
deepbuzz(a::AbstractArray) = deepbuzz.(a[1, :], a[2, :])
"""
    monitor(e)

Print a one-line progress report for epoch `e`. The line contains the loss on the
full training set `(X, Y)`, rounded to 4 digits, followed by the model's
predictions for the probe pairs (1, 0), (0, 1) and (1, 1).
"""
function monitor(e)
    # NOTE(review): `.data` unwraps the tracked value returned by a Tracker-based
    # Flux. On Zygote-based Flux, `loss` returns a plain number and this field
    # access would fail. TODO confirm the Flux version.
    current = round(loss(X, Y).data; digits=4)
    print("epoch $(lpad(e, 4)): loss = $(current) | ")
    println(deepbuzz([1 0 1; 0 1 1]))
    return nothing
end
# Training
# Take one full-batch optimizer step per epoch on the single (X, Y) pair.
# A progress line is printed after every 100th step, including the very first
# step (epoch 0).
for epoch in 0:3000
    train!(loss, params(m), [(X, Y)], opt)
    iszero(epoch % 100) && monitor(epoch)
end
epoch 2700: loss = 0.0088 | ["gt", "lt", "gt"]
epoch 2800: loss = 0.0085 | ["gt", "lt", "lt"]
epoch 2900: loss = 0.0082 | ["gt", "lt", "lt"]
epoch 3000: loss = 0.0079 | ["gt", "lt", "lt"]
julia>
julia> deepbuzz([0,1])
1-element Array{String,1}:
"lt"
julia> deepbuzz([0,-1])
1-element Array{String,1}:
"gt"
julia> deepbuzz([0,0])
1-element Array{String,1}:
"lt"
julia> deepbuzz([999,0])
1-element Array{String,1}:
"gt"
julia> deepbuzz([999,5000])
1-element Array{String,1}:
"lt"
julia> deepbuzz([3.15; 3.14])
1-element Array{String,1}: