Flux的用法问题

求帮我看看这个代码有哪些问题,一跑就挂了。

using Flux
using Flux.Data
using Flux.CUDA
using Images
using Formatting
using Base.Iterators

function preprocess(batch_images, batch_labels, device)
    x = Float32.(Flux.batch(map(img -> reshape(channelview(img), (28, 28, 1)), batch_images))) |> device
    y = Float32.(Flux.onehotbatch(batch_labels, 0:9)) |> device
    return x, y
end


function build_dataloader()
    images = MNIST.images()
    labels = MNIST.labels()
    trainloader = DataLoader((images, labels), batchsize=32, shuffle=true)
    return trainloader
end

function build_model(device)
    return Chain(
        Conv((3, 3), 1 => 64, relu, stride=1, pad=1),  # 64 * 28 * 28
        Conv((3, 3), 64 => 128, relu, stride=1, pad=1),  # 128 * 28 * 28
        MaxPool((2, 2), stride=2),
        Flux.flatten,
        Dense(14 * 14 * 128, 1024, relu),
        Dropout(0.5),
        Dense(1024, 10),
        softmax
    ) |> device
end

mutable struct Loss{M} <: Function
    m::M
    device::Function
end


function (loss::Loss)(images, labels)
    x, y = preprocess(images, labels, loss.device)
    loss = Flux.crossentropy(loss.m(x), y)
    return loss
end


function train_model!(model, dataloader::DataLoader, opt, device, epochs::Integer=10)
    ps = Flux.params(model)
    loss_fn = Loss(model, device)
    for epoch ∈ 1:epochs
        printfmtln("epoch={:03d}", epoch)
        Flux.train!(loss_fn, ps, dataloader, opt)
        break
    end
end


function run_train()
    device = gpu
    model = build_model(device)
    dataloader = build_dataloader()
    loss_fn = Loss(model, device)
    opt = ADAM(0.001, (0.9, 0.999))

    train_model!(model, dataloader, opt, device)
end


run_train()

这么长一大串,咋帮…
报错也不贴,咋帮…
也不花时间整理问题,咋帮…

不是我不贴报错,直接程序挂了,我说的挂不是exception,而是直接进程没了

sigh…

你这里用了 GPU

这段代码在CPU上运行有试过么?
model简化,去掉dataloader试过么?
去掉customize的Loss有试过么?
去掉train_model!单独运行一个epoch有试过么?

1赞

京ICP备17009874号-2