Skip to content

Commit 06f41b5

Browse files
ka-bearChrisRackauckas
authored andcommitted
Adaptible PIPNs
1 parent ce182c9 commit 06f41b5

File tree

1 file changed

+25
-27
lines changed

1 file changed

+25
-27
lines changed

src/pinn_types.jl

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -650,41 +650,39 @@ function PIPN(chain, strategy = GridTraining(0.1);
650650
logger = nothing,
651651
log_options = LogOptions(),
652652
iteration = nothing,
653+
shared_mlp1_sizes = [64, 64],
654+
shared_mlp2_sizes = [128, 1024],
655+
after_pool_mlp_sizes = [512, 256, 128],
653656
kwargs...)
654657

655-
input_dim = chain[1].in_dims[1]
656-
output_dim = chain[end].out_dims[1]
658+
input_dim = chain[1].in_dims[1]
659+
output_dim = chain[end].out_dims[1]
657660

658-
println("hi");
661+
# Create shared_mlp1
662+
shared_mlp1_layers = [Lux.Dense(i == 1 ? input_dim : shared_mlp1_sizes[i-1] => shared_mlp1_sizes[i], tanh) for i in 1:length(shared_mlp1_sizes)]
663+
shared_mlp1 = Lux.Chain(shared_mlp1_layers...)
659664

660-
shared_mlp1 = Lux.Chain(
661-
Lux.Dense(input_dim => 64, tanh),
662-
Lux.Dense(64 => 64, tanh)
663-
)
665+
# Create shared_mlp2
666+
shared_mlp2_layers = [Lux.Dense(i == 1 ? shared_mlp1_sizes[end] : shared_mlp2_sizes[i-1] => shared_mlp2_sizes[i], tanh) for i in 1:length(shared_mlp2_sizes)]
667+
shared_mlp2 = Lux.Chain(shared_mlp2_layers...)
664668

665-
shared_mlp2 = Lux.Chain(
666-
Lux.Dense(64 => 128, tanh),
667-
Lux.Dense(128 => 1024, tanh)
668-
)
669+
# Create after_pool_mlp
670+
after_pool_input_size = 2 * shared_mlp2_sizes[end] # Doubled due to concatenation
671+
after_pool_mlp_layers = [Lux.Dense(i == 1 ? after_pool_input_size : after_pool_mlp_sizes[i-1] => after_pool_mlp_sizes[i], tanh) for i in 1:length(after_pool_mlp_sizes)]
672+
after_pool_mlp = Lux.Chain(after_pool_mlp_layers...)
669673

670-
after_pool_mlp = Lux.Chain(
671-
Lux.Dense(2048 => 512, tanh), # Changed from 1024 to 2048
672-
Lux.Dense(512 => 256, tanh),
673-
Lux.Dense(256 => 128, tanh)
674-
)
674+
final_layer = Lux.Dense(after_pool_mlp_sizes[end] => output_dim)
675675

676-
final_layer = Lux.Dense(128 => output_dim)
677-
678-
if iteration isa Vector{Int64}
679-
self_increment = false
680-
else
681-
iteration = [1]
682-
self_increment = true
683-
end
676+
if iteration isa Vector{Int64}
677+
self_increment = false
678+
else
679+
iteration = [1]
680+
self_increment = true
681+
end
684682

685-
PIPN(shared_mlp1, shared_mlp2, after_pool_mlp, final_layer,
686-
strategy, init_params, param_estim, additional_loss, adaptive_loss,
687-
logger, log_options, iteration, self_increment, kwargs)
683+
PIPN(shared_mlp1, shared_mlp2, after_pool_mlp, final_layer,
684+
strategy, init_params, param_estim, additional_loss, adaptive_loss,
685+
logger, log_options, iteration, self_increment, kwargs)
688686
end
689687

690688
function (model::PIPN)(x, ps, st::NamedTuple)

0 commit comments

Comments
 (0)