@@ -374,9 +374,15 @@ def forward(self, src, pha, err, hid, tri):
374
374
375
375
x = paddle .concat ([hid , pha , tri ], axis = 1 )
376
376
x = F .interpolate (
377
- x , (h_half , w_half ), mode = 'bilinear' , align_corners = False )
377
+ x ,
378
+ paddle .concat ((h_half , w_half )),
379
+ mode = 'bilinear' ,
380
+ align_corners = False )
378
381
y = F .interpolate (
379
- src , (h_half , w_half ), mode = 'bilinear' , align_corners = False )
382
+ src ,
383
+ paddle .concat ((h_half , w_half )),
384
+ mode = 'bilinear' ,
385
+ align_corners = False )
380
386
381
387
if self .kernel_size == 3 :
382
388
x = F .pad (x , [3 , 3 , 3 , 3 ])
@@ -386,10 +392,11 @@ def forward(self, src, pha, err, hid, tri):
386
392
x = self .conv2 (x )
387
393
388
394
if self .kernel_size == 3 :
389
- x = F .interpolate (x , ( h_full + 4 , w_full + 4 ))
395
+ x = F .interpolate (x , paddle . concat (( h_full + 4 , w_full + 4 ) ))
390
396
y = F .pad (src , [2 , 2 , 2 , 2 ])
391
397
else :
392
- x = F .interpolate (x , (h_full , w_full ), mode = 'nearest' )
398
+ x = F .interpolate (
399
+ x , paddle .concat ((h_full , w_full )), mode = 'nearest' )
393
400
y = src
394
401
395
402
x = self .conv3 (paddle .concat ([x , y ], axis = 1 ))
0 commit comments