@@ -423,12 +423,12 @@ def forward(self, src, pha, err, hid, tri):
423
423
x = paddle .concat ([hid , pha , tri ], axis = 1 )
424
424
x = F .interpolate (
425
425
x ,
426
- paddle .concat ((h_half , w_half )),
426
+ paddle .stack ((h_half , w_half )). squeeze ( ),
427
427
mode = 'bilinear' ,
428
428
align_corners = False )
429
429
y = F .interpolate (
430
430
src ,
431
- paddle .concat ((h_half , w_half )),
431
+ paddle .stack ((h_half , w_half )). squeeze ( ),
432
432
mode = 'bilinear' ,
433
433
align_corners = False )
434
434
@@ -440,11 +440,11 @@ def forward(self, src, pha, err, hid, tri):
440
440
x = self .conv2 (x )
441
441
442
442
if self .kernel_size == 3 :
443
- x = F .interpolate (x , paddle .concat ((h_full + 4 , w_full + 4 )))
443
+ x = F .interpolate (x , paddle .stack ((h_full + 4 , w_full + 4 )). squeeze ( ))
444
444
y = F .pad (src , [2 , 2 , 2 , 2 ])
445
445
else :
446
446
x = F .interpolate (
447
- x , paddle .concat ((h_full , w_full )), mode = 'nearest' )
447
+ x , paddle .stack ((h_full , w_full )). squeeze ( ), mode = 'nearest' )
448
448
y = src
449
449
450
450
x = self .conv3 (paddle .concat ([x , y ], axis = 1 ))
0 commit comments