"""
    PIPN(chain, strategy = GridTraining(0.1); kwargs...)

Construct a Physics-Informed PointNet (PIPN) wrapper. The input/output
dimensionality is inferred from the user-supplied `chain`; the internal
PointNet MLP stack (two shared MLPs, a post-pooling MLP, and a final
linear layer) is built from the `*_sizes` keyword arguments.

# Keywords
- `shared_mlp1_sizes = [64, 64]`: widths of the first shared MLP.
- `shared_mlp2_sizes = [128, 1024]`: widths of the second shared MLP.
- `after_pool_mlp_sizes = [512, 256, 128]`: widths of the MLP applied after
  pooling; its input width is `2 * shared_mlp2_sizes[end]` because the pooled
  global feature is concatenated with the per-point feature.
- `iteration`: pass a caller-owned `Vector{Int64}` counter to control the
  iteration count externally; otherwise an internal `[1]` counter is created
  and self-incremented.
"""
function PIPN(chain, strategy = GridTraining(0.1);
        # NOTE(review): the next four keyword defaults are outside the visible
        # diff hunk; they are reconstructed from their use in the constructor
        # call below (mirroring the PhysicsInformedNN convention) — confirm
        # against the full file.
        init_params = nothing,
        param_estim = false,
        additional_loss = nothing,
        adaptive_loss = nothing,
        logger = nothing,
        log_options = LogOptions(),
        iteration = nothing,
        shared_mlp1_sizes = [64, 64],
        shared_mlp2_sizes = [128, 1024],
        after_pool_mlp_sizes = [512, 256, 128],
        kwargs...)
    # Problem dimensionality is taken from the user-provided chain.
    input_dim = chain[1].in_dims[1]
    output_dim = chain[end].out_dims[1]

    # Build a Chain of tanh-activated Dense layers with widths `sizes`,
    # consuming `in_dim` features.
    #
    # The parentheses around the ternary are essential: without them,
    # `i == 1 ? in_dim : sizes[i - 1] => sizes[i]` parses as
    # `i == 1 ? in_dim : (sizes[i - 1] => sizes[i])` (the conditional binds
    # looser than `=>`), so the first layer would receive a bare integer
    # instead of an `in => out` pair.
    build_mlp(in_dim, sizes) = Lux.Chain(
        [Lux.Dense((i == 1 ? in_dim : sizes[i - 1]) => sizes[i], tanh)
         for i in eachindex(sizes)]...)

    shared_mlp1 = build_mlp(input_dim, shared_mlp1_sizes)
    shared_mlp2 = build_mlp(shared_mlp1_sizes[end], shared_mlp2_sizes)

    # Input width is doubled: the pooled global feature is concatenated with
    # the per-point feature before this MLP.
    after_pool_mlp = build_mlp(2 * shared_mlp2_sizes[end], after_pool_mlp_sizes)

    final_layer = Lux.Dense(after_pool_mlp_sizes[end] => output_dim)

    # A caller-supplied mutable counter is respected (no self-increment);
    # otherwise the model owns the counter and advances it itself.
    if iteration isa Vector{Int64}
        self_increment = false
    else
        iteration = [1]
        self_increment = true
    end

    PIPN(shared_mlp1, shared_mlp2, after_pool_mlp, final_layer,
        strategy, init_params, param_estim, additional_loss, adaptive_loss,
        logger, log_options, iteration, self_increment, kwargs)
end
689
687
690
688
function (model:: PIPN )(x, ps, st:: NamedTuple )
0 commit comments