dyr
1
第一张图是我想构建的网络结构,第二张图是我写的程序,但是这个函数是无法调用出参数的
# Struct to define model
function m(x)
pqsize = 72
vsize = 36
layer0 = Dense(pqsize, 52, relu)
layer1 = Dense(52, 32, relu)
z1 =
passthrough1 = Dense(pqsize, 32, relu)
layer2 = Dense(32, 21, relu)
passthrough2 = Dense(pqsize, 21, relu)
layer3 = Dense(21, 30, relu)
passthrough3 = Dense(pqsize, 30, relu)
layer4 = Dense(30, vsize, relu)
return layer4(passthrough3(x) + layer3(passthrough2(x) +
layer2(passthrough1(x) + layer1(layer0(x)))))
end
这是我写的程序,希望好心的小可爱能给我指个方向呜呜呜
简单来说你这里基本结构是对的,但是参数没写对
struct PassThroughBlock
forward
passthrough
end
# 告诉Flux这是一个Flux兼容的网络层,这个也许可以不写,但不太确定
Flux.@functor PassThroughBlock
# 先定义一个方便的构造函数
function PassThroughBlock(Ns::Tuple; activation = relu)
Ls = [Dense(Ns[1], Ns[2], activation)]
Ps = []
for (n_in, n_out) in zip(Ns[2:end-1], Ns[3:end])
push!(Ls, Dense(n_in, n_out, activation))
push!(Ps, Dense(Ns[1], n_out, activation))
end
return PassThroughBlock(Chain(Ls...), Chain(Ps...))
end
# 接下来定义怎么样进行前向传播
function (block::PassThroughBlock)(x)
Ls = block.forward
Ps = block.passthrough
z = Ls[1](x)
for (l, p) in zip(Ls[2:end], Ps)
z = l(z) + p(x)
end
return z
end
# 测试
x = ones(72, 10)
Ns = (72, 52, 32, 21, 30, 36)
block = PassThroughBlock(Ns)
block(x) # array of size (36, 10)
2 个赞
dyr
5
呜呜呜呜,谢谢大神,Julia社区真好,我会好好努力的呜呜呜
1 个赞