Could someone help me take a look at this code and tell me what's wrong with it? It crashes as soon as I run it.
using Flux
using Flux.Data
using Flux.CUDA
using Images
using Formatting
using Base.Iterators
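# Turn one raw batch (a vector of 28×28 grayscale images plus their integer labels)
# into a Float32 28×28×1×N array and a 10×N one-hot matrix on the chosen device.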
function preprocess(batch_images, batch_labels, device)
    x = Float32.(Flux.batch(map(img -> reshape(channelview(img), (28, 28, 1)), batch_images))) |> device
    y = Float32.(Flux.onehotbatch(batch_labels, 0:9)) |> device
    return x, y
end
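# Wrap the MNIST training set in a shuffled DataLoader that yields 32-image batches.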
function build_dataloader()
    images = MNIST.images()
    labels = MNIST.labels()
    trainloader = DataLoader((images, labels), batchsize=32, shuffle=true)
    return trainloader
end
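# Small CNN for 28×28 grayscale digits: two 3×3 convolutions, 2×2 max pooling,
# then two dense layers ending in a 10-way softmax.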
function build_model(device)
    return Chain(
        Conv((3, 3), 1 => 64, relu, stride=1, pad=1),    # output: 28×28×64
        Conv((3, 3), 64 => 128, relu, stride=1, pad=1),  # output: 28×28×128
        MaxPool((2, 2), stride=2),                       # output: 14×14×128
        Flux.flatten,
        Dense(14 * 14 * 128, 1024, relu),
        Dropout(0.5),
        Dense(1024, 10),
        softmax
    ) |> device
end
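# Callable wrapper that bundles the model with the device used during preprocessing.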
mutable struct Loss{M} <: Function
    m::M
    device::Function
end
function (loss::Loss)(images, labels)
    x, y = preprocess(images, labels, loss.device)
    return Flux.crossentropy(loss.m(x), y)
end
function train_model!(model, dataloader::DataLoader, opt, device, epochs::Integer=10)
    ps = Flux.params(model)
    loss_fn = Loss(model, device)
    for epoch ∈ 1:epochs
        printfmtln("epoch={:03d}", epoch)
        Flux.train!(loss_fn, ps, dataloader, opt)
    end
end
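# Entry point: build everything on the GPU and train with ADAM.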
function run_train()
    device = gpu
    model = build_model(device)
    dataloader = build_dataloader()
    opt = ADAM(0.001, (0.9, 0.999))
    train_model!(model, dataloader, opt, device)
end
run_train()
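
For reference, a minimal single-batch check built from the functions above would look something like this (check_one_batch is just a throwaway name; this is a sketch, not something I have verified):

function check_one_batch()
    device = gpu
    model = build_model(device)
    loader = build_dataloader()
    imgs, labs = first(loader)             # one raw batch straight from the DataLoader
    x, y = preprocess(imgs, labs, device)  # expected shapes: 28×28×1×32 and 10×32
    @show size(x) size(y)
    @show Flux.crossentropy(model(x), y)   # forward pass and loss only, no gradients
end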