def patchify(x, patch_size):
    """Fold square spatial patches into the channel dimension of a video tensor.

    Rearranges a 5-D tensor of shape ``[batch, channels, frames, height, width]``
    into ``[batch, channels * patch_size**2, frames, height // patch_size,
    width // patch_size]``. The inverse operation is ``unpatchify``.

    Args:
        x: Input tensor. Must be 5-D (``[B, C, F, H, W]``) unless
            ``patch_size == 1``, in which case it is returned unchanged.
        patch_size: Edge length of the square spatial patch.

    Returns:
        The patchified tensor, or ``x`` itself when ``patch_size == 1``.

    Raises:
        ValueError: If ``x`` is not 5-D, or if height/width are not divisible
            by ``patch_size``.
    """
    if patch_size == 1:
        return x

    if x.dim() != 5:
        raise ValueError(f"Invalid input shape: {x.shape}")
    # x shape: [batch_size, channels, frames, height, width]
    batch_size, channels, frames, height, width = x.shape

    # Ensure height and width are divisible by patch_size
    if height % patch_size != 0 or width % patch_size != 0:
        raise ValueError(
            f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})"
        )

    # Reshape to [batch_size, channels, frames, height//patch_size, patch_size, width//patch_size, patch_size]
    x = x.view(batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size)

    # Rearrange to [batch_size, channels * patch_size * patch_size, frames, height//patch_size, width//patch_size]
    # Channel packing order is (c, pw, ph) — this must stay in lockstep with
    # the permute used in unpatchify, which unpacks the same order.
    x = x.permute(0, 1, 6, 4, 2, 3, 5).contiguous()
    x = x.view(batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size)

    return x
def unpatchify(x, patch_size):
    """Unfold channel-packed spatial patches back into pixels (inverse of ``patchify``).

    Rearranges a 5-D tensor of shape
    ``[batch, channels * patch_size**2, frames, height, width]`` into
    ``[batch, channels, frames, height * patch_size, width * patch_size]``.

    Args:
        x: Input tensor. Must be 5-D unless ``patch_size == 1``, in which
            case it is returned unchanged.
        patch_size: Edge length of the square spatial patch used by ``patchify``.

    Returns:
        The unpatchified tensor, or ``x`` itself when ``patch_size == 1``.

    Raises:
        ValueError: If ``x`` is not 5-D.
    """
    if patch_size == 1:
        return x

    if x.dim() != 5:
        raise ValueError(f"Invalid input shape: {x.shape}")
    # x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width]
    batch_size, c_patches, frames, height, width = x.shape
    channels = c_patches // (patch_size * patch_size)

    # Reshape to [b, c, patch_size, patch_size, f, h, w]
    x = x.view(batch_size, channels, patch_size, patch_size, frames, height, width)

    # Rearrange to [b, c, f, h * patch_size, w * patch_size]
    # Dims 2/3 hold the (pw, ph) patch offsets in the packing order produced
    # by patchify's permute(0, 1, 6, 4, 2, 3, 5); this permute inverts it.
    x = x.permute(0, 1, 4, 5, 3, 6, 2).contiguous()
    x = x.view(batch_size, channels, frames, height * patch_size, width * patch_size)

    return x
@@ -1044,7 +1016,6 @@ def __init__(
1044
1016
patch_size : Optional [int ] = None ,
1045
1017
scale_factor_temporal : Optional [int ] = 4 ,
1046
1018
scale_factor_spatial : Optional [int ] = 8 ,
1047
- clip_output : bool = True ,
1048
1019
) -> None :
1049
1020
super ().__init__ ()
1050
1021
@@ -1244,10 +1215,11 @@ def _decode(self, z: torch.Tensor, return_dict: bool = True):
1244
1215
out_ = self .decoder (x [:, :, i : i + 1 , :, :], feat_cache = self ._feat_map , feat_idx = self ._conv_idx )
1245
1216
out = torch .cat ([out , out_ ], 2 )
1246
1217
1247
- if self .config .clip_output :
1248
- out = torch .clamp (out , min = - 1.0 , max = 1.0 )
1249
1218
if self .config .patch_size is not None :
1250
1219
out = unpatchify (out , patch_size = self .config .patch_size )
1220
+
1221
+ out = torch .clamp (out , min = - 1.0 , max = 1.0 )
1222
+
1251
1223
self .clear_cache ()
1252
1224
if not return_dict :
1253
1225
return (out ,)
0 commit comments