|
9 | 9 | import torch.nn as nn
|
10 | 10 | import torchvision.transforms.functional as torchvis_fun
|
11 | 11 |
|
12 |
| -from . import pretrained_layers |
| 12 | +from . import pretrained_layers, pretrained_models |
13 | 13 |
|
14 | 14 | __all__ = [
|
15 | 15 | 'generic_features_size',
|
16 |
| - 'check_input_size' |
| 16 | + 'check_input_size', |
| 17 | + 'register_model_hooks' |
17 | 18 | ]
|
18 | 19 |
|
19 | 20 |
|
20 | 21 | def out_hook(name: str, out_dict: Dict, sequence_first: Optional[bool] = False) -> Callable:
|
21 |
| - """Creating callable hook function""" |
| 22 | + """ |
| 23 | + Create a hook to capture the output of a specific layer in a PyTorch model. |
| 24 | +
|
| 25 | + Parameters: |
| 26 | + name (str): Name of the layer. |
| 27 | + out_dict (Dict): Dictionary to store the captured output. |
| 28 | + sequence_first (Optional[bool]): Whether to permute the output tensor if it has three |
| 29 | + dimensions with the sequence dimension first. |
| 30 | +
|
| 31 | + Returns: |
| 32 | + Callable: Hook function. |
| 33 | + """ |
22 | 34 |
|
23 | 35 | def hook(_model: nn.Module, _input_x: torch.Tensor, output_y: torch.Tensor):
|
| 36 | + """Detach the output tensor and store it in the dictionary""" |
24 | 37 | out_dict[name] = output_y.detach()
|
| 38 | + |
25 | 39 | if sequence_first and len(out_dict[name].shape) == 3:
|
26 |
| - # clip output (SequenceLength, Batch, HiddenDimension) |
| 40 | + # If sequence_first is True and the tensor has three dimensions, permute the tensor |
| 41 | + # (SequenceLength, Batch, HiddenDimension) -> (Batch, SequenceLength, HiddenDimension) |
27 | 42 | out_dict[name] = out_dict[name].permute(1, 0, 2)
|
28 | 43 |
|
29 | 44 | return hook
|
30 | 45 |
|
31 | 46 |
|
32 |
| -def resnet_hooks(model: nn.Module, layers: List[str], is_clip: Optional[bool] = False) -> ( |
33 |
| - Dict, Dict): |
34 |
| - """Creates hooks for the ResNet model.""" |
35 |
| - act_dict = dict() |
36 |
| - rf_hooks = dict() |
37 |
| - model_layers = list(model.children()) |
| 47 | +def _resnet_hooks(model: nn.Module, layers: List[str], architecture: str) -> (Dict, Dict): |
| 48 | + """Setting up hooks for the ResNet architecture.""" |
| 49 | + is_clip = 'clip' in architecture |
| 50 | + acts, hooks = dict(), dict() |
| 51 | + if architecture in pretrained_models.available_models()['segmentation']: |
| 52 | + model_layers = list(model.parent_model.children()) |
| 53 | + else: |
| 54 | + model_layers = list(model.children()) |
38 | 55 | for layer in layers:
|
39 | 56 | l_ind = pretrained_layers.resnet_layer(layer, is_clip=is_clip)
|
40 |
| - rf_hooks[layer] = model_layers[l_ind].register_forward_hook(out_hook(layer, act_dict)) |
41 |
| - return act_dict, rf_hooks |
| 57 | + hooks[layer] = model_layers[l_ind].register_forward_hook(out_hook(layer, acts)) |
| 58 | + return acts, hooks |
42 | 59 |
|
43 | 60 |
|
44 |
| -def clip_hooks(model: nn.Module, layers: List[str], architecture: str) -> (Dict, Dict): |
45 |
| - """Creates hooks for the Clip model.""" |
| 61 | +def _clip_hooks(model: nn.Module, layers: List[str], architecture: str) -> (Dict, Dict): |
| 62 | + """Setting up hooks for the CLIP networks.""" |
46 | 63 | if architecture.replace('clip_', '') in ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64']:
|
47 |
| - act_dict, rf_hooks = resnet_hooks(model, layers, is_clip=True) |
| 64 | + acts, hooks = _resnet_hooks(model, layers, architecture) |
48 | 65 | else:
|
49 |
| - act_dict = dict() |
50 |
| - rf_hooks = dict() |
| 66 | + acts, hooks = dict(), dict() |
51 | 67 | for layer in layers:
|
52 | 68 | if layer == 'encoder':
|
53 | 69 | layer_hook = model
|
54 |
| - elif layer == 'conv1': |
| 70 | + elif layer == 'conv_proj': |
55 | 71 | layer_hook = model.conv1
|
56 | 72 | else:
|
57 | 73 | block_ind = int(layer.replace('block', ''))
|
58 | 74 | layer_hook = model.transformer.resblocks[block_ind]
|
59 |
| - rf_hooks[layer] = layer_hook.register_forward_hook(out_hook(layer, act_dict, True)) |
60 |
| - return act_dict, rf_hooks |
| 75 | + hooks[layer] = layer_hook.register_forward_hook(out_hook(layer, acts, True)) |
| 76 | + return acts, hooks |
61 | 77 |
|
62 | 78 |
|
63 |
| -def vit_hooks(model: nn.Module, layers: List[str]) -> (Dict, Dict): |
64 |
| - act_dict = dict() |
65 |
| - rf_hooks = dict() |
| 79 | +def _vit_hooks(model: nn.Module, layers: List[str]) -> (Dict, Dict): |
| 80 | + """Setting up hooks for the ViT architecture.""" |
| 81 | + acts, hooks = dict(), dict() |
66 | 82 | for layer in layers:
|
67 | 83 | if layer == 'fc':
|
68 |
| - layer_hook = model.heads |
| 84 | + layer_hook = model |
69 | 85 | elif layer == 'conv_proj':
|
70 | 86 | layer_hook = model.conv_proj
|
71 | 87 | else:
|
72 | 88 | block_ind = int(layer.replace('block', ''))
|
73 | 89 | layer_hook = model.encoder.layers[block_ind]
|
74 |
| - rf_hooks[layer] = layer_hook.register_forward_hook(out_hook(layer, act_dict)) |
75 |
| - return act_dict, rf_hooks |
| 90 | + hooks[layer] = layer_hook.register_forward_hook(out_hook(layer, acts)) |
| 91 | + return acts, hooks |
| 92 | + |
| 93 | + |
| 94 | +def _maxvit_hooks(model: nn.Module, layers: List[str]) -> (Dict, Dict): |
| 95 | + """Setting up hooks for the ViT architecture.""" |
| 96 | + acts, hooks = dict(), dict() |
| 97 | + for layer in layers: |
| 98 | + if layer == 'fc': |
| 99 | + layer_hook = model |
| 100 | + elif layer == 'stem': |
| 101 | + layer_hook = model.stem |
| 102 | + elif 'block' in layer: |
| 103 | + l_ind = int(layer.replace('block', '')) - 1 |
| 104 | + layer_hook = list(model.blocks.children())[l_ind] |
| 105 | + elif 'classifier' in layer: |
| 106 | + l_ind = int(layer.replace('classifier', '')) |
| 107 | + layer_hook = list(model.classifier.children())[l_ind] |
| 108 | + else: |
| 109 | + raise RuntimeError('Unsupported MaxViT layer %s' % layer) |
| 110 | + hooks[layer] = layer_hook.register_forward_hook(out_hook(layer, acts)) |
| 111 | + return acts, hooks |
| 112 | + |
| 113 | + |
| 114 | +def _swin_hooks(model: nn.Module, layers: List[str]) -> (Dict, Dict): |
| 115 | + """Setting up hooks for the SwinTransformer architecture.""" |
| 116 | + return _attribute_hooks(model, layers, {'block': model.features}) |
| 117 | + |
| 118 | + |
| 119 | +def _regnet_hooks(model: nn.Module, layers: List[str]) -> (Dict, Dict): |
| 120 | + """Setting up hooks for the RegNet architecture.""" |
| 121 | + acts, hooks = dict(), dict() |
| 122 | + for layer in layers: |
| 123 | + if layer == 'fc': |
| 124 | + layer_hook = model |
| 125 | + elif layer == 'stem': |
| 126 | + layer_hook = model.stem |
| 127 | + elif 'block' in layer: |
| 128 | + l_ind = int(layer.replace('block', '')) - 1 |
| 129 | + layer_hook = list(model.trunk_output.children())[l_ind] |
| 130 | + else: |
| 131 | + raise RuntimeError('Unsupported regnet layer %s' % layer) |
| 132 | + hooks[layer] = layer_hook.register_forward_hook(out_hook(layer, acts)) |
| 133 | + return acts, hooks |
| 134 | + |
| 135 | + |
| 136 | +def _child_hook(children: List, layer: str, keyword: str): |
| 137 | + l_ind = int(layer.replace(keyword, '')) |
| 138 | + return children[l_ind] |
| 139 | + |
| 140 | + |
| 141 | +def _attribute_hooks(model: nn.Module, layers: List[str], |
| 142 | + attributes: Optional[Dict] = None) -> (Dict, Dict): |
| 143 | + """Setting up hooks for networks with children attributes.""" |
| 144 | + acts, hooks = dict(), dict() |
| 145 | + # A dynamic way to get model children with different names |
| 146 | + if attributes is None: |
| 147 | + attributes = { |
| 148 | + 'feature': model.features, |
| 149 | + 'classifier': model.classifier |
| 150 | + } |
| 151 | + # Looping through all the layers and making the hooks |
| 152 | + for layer in layers: |
| 153 | + if layer == 'fc': |
| 154 | + layer_hook = model |
| 155 | + else: |
| 156 | + layer_hook = None |
| 157 | + for attr, children in attributes.items(): |
| 158 | + if attr in layer: |
| 159 | + layer_hook = _child_hook(children, layer, attr) |
| 160 | + break |
| 161 | + if layer_hook is None: |
| 162 | + raise RuntimeError('Unsupported layer %s' % layer) |
| 163 | + hooks[layer] = layer_hook.register_forward_hook(out_hook(layer, acts)) |
| 164 | + return acts, hooks |
| 165 | + |
| 166 | + |
| 167 | +def _alexnet_hooks(model: nn.Module, layers: List[str]) -> (Dict, Dict): |
| 168 | + """Setting up hooks for the AlexNet architecture.""" |
| 169 | + return _attribute_hooks(model, layers) |
| 170 | + |
| 171 | + |
| 172 | +def _convnext_hooks(model: nn.Module, layers: List[str]) -> (Dict, Dict): |
| 173 | + """Setting up hooks for the ConvNeXt architecture.""" |
| 174 | + return _attribute_hooks(model, layers) |
| 175 | + |
| 176 | + |
| 177 | +def _efficientnet_hooks(model: nn.Module, layers: List[str]) -> (Dict, Dict): |
| 178 | + """Setting up hooks for the EfficientNet architecture.""" |
| 179 | + return _attribute_hooks(model, layers) |
| 180 | + |
| 181 | + |
| 182 | +def _densenet_hooks(model: nn.Module, layers: List[str]) -> (Dict, Dict): |
| 183 | + """Setting up hooks for the DensNet architecture.""" |
| 184 | + return _attribute_hooks(model, layers) |
| 185 | + |
| 186 | + |
| 187 | +def _googlenet_hooks(model: nn.Module, layers: List[str]) -> (Dict, Dict): |
| 188 | + """Setting up hooks for the GoogLeNet architecture.""" |
| 189 | + acts, hooks = dict(), dict() |
| 190 | + model_layers = list(model.parent_model.children()) |
| 191 | + for layer in layers: |
| 192 | + l_ind = pretrained_layers.googlenet_cutoff_slice(layer) |
| 193 | + l_ind = -1 if l_ind is None else l_ind - 1 |
| 194 | + hooks[layer] = model_layers[l_ind].register_forward_hook(out_hook(layer, acts)) |
| 195 | + return acts, hooks |
| 196 | + |
| 197 | + |
| 198 | +def _inception_hooks(model: nn.Module, layers: List[str]) -> (Dict, Dict): |
| 199 | + """Setting up hooks for the Inception architecture.""" |
| 200 | + acts, hooks = dict(), dict() |
| 201 | + model_layers = list(model.parent_model.children()) |
| 202 | + for layer in layers: |
| 203 | + l_ind = pretrained_layers.inception_cutoff_slice(layer) |
| 204 | + l_ind = -1 if l_ind is None else l_ind - 1 |
| 205 | + hooks[layer] = model_layers[l_ind].register_forward_hook(out_hook(layer, acts)) |
| 206 | + return acts, hooks |
| 207 | + |
| 208 | + |
| 209 | +def _mnasnet_hooks(model: nn.Module, layers: List[str]) -> (Dict, Dict): |
| 210 | + """Setting up hooks for the Mnasnet architecture.""" |
| 211 | + return _attribute_hooks(model, layers, {'layer': model.layers}) |
| 212 | + |
| 213 | + |
| 214 | +def _shufflenet_hooks(model: nn.Module, layers: List[str]) -> (Dict, Dict): |
| 215 | + """Setting up hooks for the ShuffleNet architecture.""" |
| 216 | + return _attribute_hooks(model, layers, {'layer': list(model.children())}) |
| 217 | + |
| 218 | + |
| 219 | +def _mobilenet_hooks(model: nn.Module, layers: List[str], architecture: str) -> (Dict, Dict): |
| 220 | + if architecture in ['lraspp_mobilenet_v3_large', 'deeplabv3_mobilenet_v3_large']: |
| 221 | + return _attribute_hooks(model, layers, {'feature': list(model.parent_model.children())}) |
| 222 | + return _attribute_hooks(model, layers) |
| 223 | + |
| 224 | + |
| 225 | +def _squeezenet_hooks(model: nn.Module, layers: List[str]) -> (Dict, Dict): |
| 226 | + """Setting up hooks for the SqueezeNet architecture.""" |
| 227 | + return _attribute_hooks(model, layers) |
| 228 | + |
| 229 | + |
| 230 | +def _vgg_hooks(model: nn.Module, layers: List[str]) -> (Dict, Dict): |
| 231 | + """Setting up hooks for the VGG architecture.""" |
| 232 | + return _attribute_hooks(model, layers) |
76 | 233 |
|
77 | 234 |
|
78 | 235 | def register_model_hooks(model: nn.Module, architecture: str, layers: List[str]) -> (Dict, Dict):
|
79 |
| - """Registering forward hooks to the network.""" |
| 236 | + """ |
| 237 | + Register hooks for capturing activation for specific layers in the model. |
| 238 | +
|
| 239 | + Parameters: |
| 240 | + model (nn.Module): PyTorch model. |
| 241 | + architecture (str): Model architecture name. |
| 242 | + layers (List[str]): List of layer names for which to register hooks. |
| 243 | +
|
| 244 | + Raises: |
| 245 | + RuntimeError: If the specified layer is not supported for the given architecture. |
| 246 | +
|
| 247 | + Returns: |
| 248 | + (Dict, Dict): Dictionaries containing activation values and registered forward hooks. |
| 249 | + """ |
| 250 | + for layer in layers: |
| 251 | + if layer not in pretrained_layers.available_layers(architecture): |
| 252 | + raise RuntimeError( |
| 253 | + 'Layer %s is not supported for architecture %s. Call ' |
| 254 | + 'pretrained_layers.available_layers to see a list of supported layers for an ' |
| 255 | + 'architecture.' % (layer, architecture) |
| 256 | + ) |
| 257 | + |
80 | 258 | if is_resnet_backbone(architecture):
|
81 |
| - act_dict, rf_hooks = resnet_hooks(model, layers) |
| 259 | + return _resnet_hooks(model, layers, architecture) |
82 | 260 | elif 'clip' in architecture:
|
83 |
| - act_dict, rf_hooks = clip_hooks(model, layers, architecture) |
| 261 | + return _clip_hooks(model, layers, architecture) |
| 262 | + elif 'maxvit' in architecture: |
| 263 | + return _maxvit_hooks(model, layers) |
84 | 264 | elif 'vit_' in architecture:
|
85 |
| - act_dict, rf_hooks = vit_hooks(model, layers) |
| 265 | + return _vit_hooks(model, layers) |
| 266 | + elif 'regnet' in architecture: |
| 267 | + return _regnet_hooks(model, layers) |
| 268 | + elif 'vgg' in architecture: |
| 269 | + return _vgg_hooks(model, layers) |
| 270 | + elif architecture == 'alexnet': |
| 271 | + return _alexnet_hooks(model, layers) |
| 272 | + elif architecture == 'googlenet': |
| 273 | + return _googlenet_hooks(model, layers) |
| 274 | + elif architecture == 'inception_v3': |
| 275 | + return _inception_hooks(model, layers) |
| 276 | + elif 'convnext' in architecture: |
| 277 | + return _convnext_hooks(model, layers) |
| 278 | + elif 'densenet' in architecture: |
| 279 | + return _densenet_hooks(model, layers) |
| 280 | + elif 'mnasnet' in architecture: |
| 281 | + return _mnasnet_hooks(model, layers) |
| 282 | + elif 'shufflenet' in architecture: |
| 283 | + return _shufflenet_hooks(model, layers) |
| 284 | + elif 'squeezenet' in architecture: |
| 285 | + return _squeezenet_hooks(model, layers) |
| 286 | + elif 'efficientnet' in architecture: |
| 287 | + return _efficientnet_hooks(model, layers) |
| 288 | + elif 'mobilenet' in architecture: |
| 289 | + return _mobilenet_hooks(model, layers, architecture) |
| 290 | + elif 'swin_' in architecture: |
| 291 | + return _swin_hooks(model, layers) |
86 | 292 | else:
|
87 | 293 | raise RuntimeError('Model hooks does not support network %s' % architecture)
|
88 |
| - return act_dict, rf_hooks |
89 | 294 |
|
90 | 295 |
|
91 | 296 | def is_resnet_backbone(architecture: str) -> bool:
|
|
0 commit comments