Skip to content

Commit 12c4c49

Browse files
Add method to set model network for nnUNet predictor and update predict method
1 parent 44803a2 commit 12c4c49

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

monai/deploy/operators/monet_bundle_inference_operator.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,17 @@ def _init_config(self, config_names):
6060

6161
self._nnunet_predictor = parser.get_parsed_content("network_def")
6262

63+
def _set_model_network(self, model_network):
64+
"""Sets the model network for the nnUNet predictor."""
65+
if not (isinstance(model_network, torch.nn.Module) or isinstance(model_network, torch.jit.ScriptModule)):
66+
raise TypeError(f"Expected model_network to be a torch.nn.Module or a torch.jit.ScriptModule, got {type(model_network)}")
67+
self._nnunet_predictor.network = model_network
68+
6369
def predict(self, data: Any, *args, **kwargs) -> Union[Image, Any, Tuple[Any, ...], Dict[Any, Any]]:
6470
"""Predicts output using the inferer. If multimodal data is provided as keyword arguments,
6571
it concatenates the data with the main input data."""
6672

67-
self._nnunet_predictor.predictor.network = self._model_network
68-
73+
self._set_model_network(self._model_network)
6974
if len(kwargs) > 0:
7075
multimodal_data = {"image": data}
7176
for key in kwargs.keys():

0 commit comments

Comments
 (0)