Skip to content

Commit 847245e

Browse files
Merge pull request #18 from ArashAkbarinia/development
Development
2 parents 62b3a1b + bf565dd commit 847245e

File tree

11 files changed

+375
-101
lines changed

11 files changed

+375
-101
lines changed

CITATION.cff

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,9 @@
33

44
cff-version: 1.2.0
55
title: >-
6-
Osculari: a Python package to explore and interpret deep
7-
neural networks
6+
Osculari: a Python package to explore artificial neural networks with psychophysical experiments
87
message: >-
9-
If you use this software, please cite it using the
10-
metadata from this file.
8+
If you use this software, please cite it using the metadata from this file.
119
type: software
1210
authors:
1311
- given-names: Arash
@@ -18,5 +16,5 @@ identifiers:
1816
value: 10.5281/zenodo.10214006
1917
repository-code: 'https://github.com/ArashAkbarinia/osculari'
2018
license: MIT
21-
version: v0.0.2
22-
date-released: '2023-11-28'
19+
version: v0.0.4
20+
date-released: '2023-12-21'

README.md

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
# osculari
22

3-
[![Python version](https://img.shields.io/pypi/pyversions/osculari)](https://pypi.org/project/osculari/)
43
[![Project Status](https://www.repostatus.org/badges/latest/active.svg)](https://www.repostatus.org/#active)
5-
[![Documentation Status](https://readthedocs.org/projects/osculari/badge/?version=latest)](https://osculari.readthedocs.io/en/latest/?badge=latest)
4+
[![Build Status](https://github.com/ArashAkbarinia/osculari/actions/workflows/python-package.yml/badge.svg)](https://github.com/ArashAkbarinia/osculari)
65
[![PyPi Status](https://img.shields.io/pypi/v/osculari.svg)](https://pypi.org/project/osculari/)
6+
[![Python version](https://img.shields.io/pypi/pyversions/osculari)](https://pypi.org/project/osculari/)
7+
[![Documentation Status](https://readthedocs.org/projects/osculari/badge/?version=latest)](https://osculari.readthedocs.io/en/latest/?badge=latest)
8+
[![Documentation Status](https://static.pepy.tech/badge/osculari)](https://pypi.org/project/osculari/)
9+
[![Documentation Status](https://codecov.io/gh/ArashAkbarinia/osculari/branch/main/graph/badge.svg)](https://app.codecov.io/gh/ArashAkbarinia/osculari)
10+
[![Pytorch version](https://img.shields.io/badge/PyTorch_1.9.1+-ee4c2c?logo=pytorch&logoColor=white)](https://pytorch.org/get-started/locally/)
711
[![Licence](https://img.shields.io/pypi/l/osculari.svg)](LICENSE)
812
[![DOI](https://zenodo.org/badge/717052640.svg)](https://zenodo.org/doi/10.5281/zenodo.10214005)
913

docs/source/index.rst

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,16 @@ Osculari
88

99
.. image:: https://www.repostatus.org/badges/latest/active.svg
1010
:target: https://www.repostatus.org/#active
11-
.. image:: https://img.shields.io/github/v/release/ArashAkbarinia/osculari?logo=github
11+
.. image:: https://github.com/ArashAkbarinia/osculari/actions/workflows/python-package.yml/badge.svg
1212
:target: https://github.com/ArashAkbarinia/osculari
13+
.. image:: https://img.shields.io/pypi/v/osculari.svg
14+
:target: https://pypi.org/project/osculari/
1315
.. image:: https://img.shields.io/pypi/pyversions/osculari.svg
1416
:target: https://pypi.org/project/osculari/
1517
.. image:: https://static.pepy.tech/badge/osculari
1618
:target: https://pypi.org/project/osculari/
19+
.. image:: https://codecov.io/gh/ArashAkbarinia/osculari/branch/main/graph/badge.svg
20+
:target: https://app.codecov.io/gh/ArashAkbarinia/osculari
1721
.. image:: https://img.shields.io/badge/PyTorch_1.9.1+-ee4c2c?logo=pytorch&logoColor=white
1822
:target: https://pytorch.org/get-started/locally/
1923
.. image:: https://img.shields.io/pypi/l/osculari.svg

osculari/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@
99
from .models.pretrained_models import available_models
1010

1111
# Version variable
12-
__version__ = "0.0.3"
12+
__version__ = "0.0.4"

osculari/models/model_utils.py

Lines changed: 236 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -9,83 +9,288 @@
99
import torch.nn as nn
1010
import torchvision.transforms.functional as torchvis_fun
1111

12-
from . import pretrained_layers
12+
from . import pretrained_layers, pretrained_models
1313

1414
__all__ = [
1515
'generic_features_size',
16-
'check_input_size'
16+
'check_input_size',
17+
'register_model_hooks'
1718
]
1819

1920

2021
def out_hook(name: str, out_dict: Dict, sequence_first: Optional[bool] = False) -> Callable:
21-
"""Creating callable hook function"""
22+
"""
23+
Create a hook to capture the output of a specific layer in a PyTorch model.
24+
25+
Parameters:
26+
name (str): Name of the layer.
27+
out_dict (Dict): Dictionary to store the captured output.
28+
sequence_first (Optional[bool]): Whether to permute the output tensor if it has three
29+
dimensions with the sequence dimension first.
30+
31+
Returns:
32+
Callable: Hook function.
33+
"""
2234

2335
def hook(_model: nn.Module, _input_x: torch.Tensor, output_y: torch.Tensor):
36+
"""Detach the output tensor and store it in the dictionary"""
2437
out_dict[name] = output_y.detach()
38+
2539
if sequence_first and len(out_dict[name].shape) == 3:
26-
# clip output (SequenceLength, Batch, HiddenDimension)
40+
# If sequence_first is True and the tensor has three dimensions, permute the tensor
41+
# (SequenceLength, Batch, HiddenDimension) -> (Batch, SequenceLength, HiddenDimension)
2742
out_dict[name] = out_dict[name].permute(1, 0, 2)
2843

2944
return hook
3045

3146

32-
def resnet_hooks(model: nn.Module, layers: List[str], is_clip: Optional[bool] = False) -> (
33-
Dict, Dict):
34-
"""Creates hooks for the ResNet model."""
35-
act_dict = dict()
36-
rf_hooks = dict()
37-
model_layers = list(model.children())
47+
def _resnet_hooks(model: nn.Module, layers: List[str], architecture: str) -> (Dict, Dict):
48+
"""Setting up hooks for the ResNet architecture."""
49+
is_clip = 'clip' in architecture
50+
acts, hooks = dict(), dict()
51+
if architecture in pretrained_models.available_models()['segmentation']:
52+
model_layers = list(model.parent_model.children())
53+
else:
54+
model_layers = list(model.children())
3855
for layer in layers:
3956
l_ind = pretrained_layers.resnet_layer(layer, is_clip=is_clip)
40-
rf_hooks[layer] = model_layers[l_ind].register_forward_hook(out_hook(layer, act_dict))
41-
return act_dict, rf_hooks
57+
hooks[layer] = model_layers[l_ind].register_forward_hook(out_hook(layer, acts))
58+
return acts, hooks
4259

4360

44-
def clip_hooks(model: nn.Module, layers: List[str], architecture: str) -> (Dict, Dict):
45-
"""Creates hooks for the Clip model."""
61+
def _clip_hooks(model: nn.Module, layers: List[str], architecture: str) -> (Dict, Dict):
62+
"""Setting up hooks for the CLIP networks."""
4663
if architecture.replace('clip_', '') in ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'RN50x64']:
47-
act_dict, rf_hooks = resnet_hooks(model, layers, is_clip=True)
64+
acts, hooks = _resnet_hooks(model, layers, architecture)
4865
else:
49-
act_dict = dict()
50-
rf_hooks = dict()
66+
acts, hooks = dict(), dict()
5167
for layer in layers:
5268
if layer == 'encoder':
5369
layer_hook = model
54-
elif layer == 'conv1':
70+
elif layer == 'conv_proj':
5571
layer_hook = model.conv1
5672
else:
5773
block_ind = int(layer.replace('block', ''))
5874
layer_hook = model.transformer.resblocks[block_ind]
59-
rf_hooks[layer] = layer_hook.register_forward_hook(out_hook(layer, act_dict, True))
60-
return act_dict, rf_hooks
75+
hooks[layer] = layer_hook.register_forward_hook(out_hook(layer, acts, True))
76+
return acts, hooks
6177

6278

63-
def vit_hooks(model: nn.Module, layers: List[str]) -> (Dict, Dict):
64-
act_dict = dict()
65-
rf_hooks = dict()
79+
def _vit_hooks(model: nn.Module, layers: List[str]) -> (Dict, Dict):
80+
"""Setting up hooks for the ViT architecture."""
81+
acts, hooks = dict(), dict()
6682
for layer in layers:
6783
if layer == 'fc':
68-
layer_hook = model.heads
84+
layer_hook = model
6985
elif layer == 'conv_proj':
7086
layer_hook = model.conv_proj
7187
else:
7288
block_ind = int(layer.replace('block', ''))
7389
layer_hook = model.encoder.layers[block_ind]
74-
rf_hooks[layer] = layer_hook.register_forward_hook(out_hook(layer, act_dict))
75-
return act_dict, rf_hooks
90+
hooks[layer] = layer_hook.register_forward_hook(out_hook(layer, acts))
91+
return acts, hooks
92+
93+
94+
def _maxvit_hooks(model: nn.Module, layers: List[str]) -> (Dict, Dict):
95+
"""Setting up hooks for the ViT architecture."""
96+
acts, hooks = dict(), dict()
97+
for layer in layers:
98+
if layer == 'fc':
99+
layer_hook = model
100+
elif layer == 'stem':
101+
layer_hook = model.stem
102+
elif 'block' in layer:
103+
l_ind = int(layer.replace('block', '')) - 1
104+
layer_hook = list(model.blocks.children())[l_ind]
105+
elif 'classifier' in layer:
106+
l_ind = int(layer.replace('classifier', ''))
107+
layer_hook = list(model.classifier.children())[l_ind]
108+
else:
109+
raise RuntimeError('Unsupported MaxViT layer %s' % layer)
110+
hooks[layer] = layer_hook.register_forward_hook(out_hook(layer, acts))
111+
return acts, hooks
112+
113+
114+
def _swin_hooks(model: nn.Module, layers: List[str]) -> (Dict, Dict):
115+
"""Setting up hooks for the SwinTransformer architecture."""
116+
return _attribute_hooks(model, layers, {'block': model.features})
117+
118+
119+
def _regnet_hooks(model: nn.Module, layers: List[str]) -> (Dict, Dict):
120+
"""Setting up hooks for the RegNet architecture."""
121+
acts, hooks = dict(), dict()
122+
for layer in layers:
123+
if layer == 'fc':
124+
layer_hook = model
125+
elif layer == 'stem':
126+
layer_hook = model.stem
127+
elif 'block' in layer:
128+
l_ind = int(layer.replace('block', '')) - 1
129+
layer_hook = list(model.trunk_output.children())[l_ind]
130+
else:
131+
raise RuntimeError('Unsupported regnet layer %s' % layer)
132+
hooks[layer] = layer_hook.register_forward_hook(out_hook(layer, acts))
133+
return acts, hooks
134+
135+
136+
def _child_hook(children: List, layer: str, keyword: str):
137+
l_ind = int(layer.replace(keyword, ''))
138+
return children[l_ind]
139+
140+
141+
def _attribute_hooks(model: nn.Module, layers: List[str],
142+
attributes: Optional[Dict] = None) -> (Dict, Dict):
143+
"""Setting up hooks for networks with children attributes."""
144+
acts, hooks = dict(), dict()
145+
# A dynamic way to get model children with different names
146+
if attributes is None:
147+
attributes = {
148+
'feature': model.features,
149+
'classifier': model.classifier
150+
}
151+
# Looping through all the layers and making the hooks
152+
for layer in layers:
153+
if layer == 'fc':
154+
layer_hook = model
155+
else:
156+
layer_hook = None
157+
for attr, children in attributes.items():
158+
if attr in layer:
159+
layer_hook = _child_hook(children, layer, attr)
160+
break
161+
if layer_hook is None:
162+
raise RuntimeError('Unsupported layer %s' % layer)
163+
hooks[layer] = layer_hook.register_forward_hook(out_hook(layer, acts))
164+
return acts, hooks
165+
166+
167+
def _alexnet_hooks(model: nn.Module, layers: List[str]) -> (Dict, Dict):
168+
"""Setting up hooks for the AlexNet architecture."""
169+
return _attribute_hooks(model, layers)
170+
171+
172+
def _convnext_hooks(model: nn.Module, layers: List[str]) -> (Dict, Dict):
173+
"""Setting up hooks for the ConvNeXt architecture."""
174+
return _attribute_hooks(model, layers)
175+
176+
177+
def _efficientnet_hooks(model: nn.Module, layers: List[str]) -> (Dict, Dict):
178+
"""Setting up hooks for the EfficientNet architecture."""
179+
return _attribute_hooks(model, layers)
180+
181+
182+
def _densenet_hooks(model: nn.Module, layers: List[str]) -> (Dict, Dict):
183+
"""Setting up hooks for the DensNet architecture."""
184+
return _attribute_hooks(model, layers)
185+
186+
187+
def _googlenet_hooks(model: nn.Module, layers: List[str]) -> (Dict, Dict):
188+
"""Setting up hooks for the GoogLeNet architecture."""
189+
acts, hooks = dict(), dict()
190+
model_layers = list(model.parent_model.children())
191+
for layer in layers:
192+
l_ind = pretrained_layers.googlenet_cutoff_slice(layer)
193+
l_ind = -1 if l_ind is None else l_ind - 1
194+
hooks[layer] = model_layers[l_ind].register_forward_hook(out_hook(layer, acts))
195+
return acts, hooks
196+
197+
198+
def _inception_hooks(model: nn.Module, layers: List[str]) -> (Dict, Dict):
199+
"""Setting up hooks for the Inception architecture."""
200+
acts, hooks = dict(), dict()
201+
model_layers = list(model.parent_model.children())
202+
for layer in layers:
203+
l_ind = pretrained_layers.inception_cutoff_slice(layer)
204+
l_ind = -1 if l_ind is None else l_ind - 1
205+
hooks[layer] = model_layers[l_ind].register_forward_hook(out_hook(layer, acts))
206+
return acts, hooks
207+
208+
209+
def _mnasnet_hooks(model: nn.Module, layers: List[str]) -> (Dict, Dict):
210+
"""Setting up hooks for the Mnasnet architecture."""
211+
return _attribute_hooks(model, layers, {'layer': model.layers})
212+
213+
214+
def _shufflenet_hooks(model: nn.Module, layers: List[str]) -> (Dict, Dict):
215+
"""Setting up hooks for the ShuffleNet architecture."""
216+
return _attribute_hooks(model, layers, {'layer': list(model.children())})
217+
218+
219+
def _mobilenet_hooks(model: nn.Module, layers: List[str], architecture: str) -> (Dict, Dict):
220+
if architecture in ['lraspp_mobilenet_v3_large', 'deeplabv3_mobilenet_v3_large']:
221+
return _attribute_hooks(model, layers, {'feature': list(model.parent_model.children())})
222+
return _attribute_hooks(model, layers)
223+
224+
225+
def _squeezenet_hooks(model: nn.Module, layers: List[str]) -> (Dict, Dict):
226+
"""Setting up hooks for the SqueezeNet architecture."""
227+
return _attribute_hooks(model, layers)
228+
229+
230+
def _vgg_hooks(model: nn.Module, layers: List[str]) -> (Dict, Dict):
231+
"""Setting up hooks for the VGG architecture."""
232+
return _attribute_hooks(model, layers)
76233

77234

78235
def register_model_hooks(model: nn.Module, architecture: str, layers: List[str]) -> (Dict, Dict):
79-
"""Registering forward hooks to the network."""
236+
"""
237+
Register hooks for capturing activation for specific layers in the model.
238+
239+
Parameters:
240+
model (nn.Module): PyTorch model.
241+
architecture (str): Model architecture name.
242+
layers (List[str]): List of layer names for which to register hooks.
243+
244+
Raises:
245+
RuntimeError: If the specified layer is not supported for the given architecture.
246+
247+
Returns:
248+
(Dict, Dict): Dictionaries containing activation values and registered forward hooks.
249+
"""
250+
for layer in layers:
251+
if layer not in pretrained_layers.available_layers(architecture):
252+
raise RuntimeError(
253+
'Layer %s is not supported for architecture %s. Call '
254+
'pretrained_layers.available_layers to see a list of supported layers for an '
255+
'architecture.' % (layer, architecture)
256+
)
257+
80258
if is_resnet_backbone(architecture):
81-
act_dict, rf_hooks = resnet_hooks(model, layers)
259+
return _resnet_hooks(model, layers, architecture)
82260
elif 'clip' in architecture:
83-
act_dict, rf_hooks = clip_hooks(model, layers, architecture)
261+
return _clip_hooks(model, layers, architecture)
262+
elif 'maxvit' in architecture:
263+
return _maxvit_hooks(model, layers)
84264
elif 'vit_' in architecture:
85-
act_dict, rf_hooks = vit_hooks(model, layers)
265+
return _vit_hooks(model, layers)
266+
elif 'regnet' in architecture:
267+
return _regnet_hooks(model, layers)
268+
elif 'vgg' in architecture:
269+
return _vgg_hooks(model, layers)
270+
elif architecture == 'alexnet':
271+
return _alexnet_hooks(model, layers)
272+
elif architecture == 'googlenet':
273+
return _googlenet_hooks(model, layers)
274+
elif architecture == 'inception_v3':
275+
return _inception_hooks(model, layers)
276+
elif 'convnext' in architecture:
277+
return _convnext_hooks(model, layers)
278+
elif 'densenet' in architecture:
279+
return _densenet_hooks(model, layers)
280+
elif 'mnasnet' in architecture:
281+
return _mnasnet_hooks(model, layers)
282+
elif 'shufflenet' in architecture:
283+
return _shufflenet_hooks(model, layers)
284+
elif 'squeezenet' in architecture:
285+
return _squeezenet_hooks(model, layers)
286+
elif 'efficientnet' in architecture:
287+
return _efficientnet_hooks(model, layers)
288+
elif 'mobilenet' in architecture:
289+
return _mobilenet_hooks(model, layers, architecture)
290+
elif 'swin_' in architecture:
291+
return _swin_hooks(model, layers)
86292
else:
87293
raise RuntimeError('Model hooks does not support network %s' % architecture)
88-
return act_dict, rf_hooks
89294

90295

91296
def is_resnet_backbone(architecture: str) -> bool:

0 commit comments

Comments
 (0)