Skip to content

Commit 60d1b81

Browse files
committed
update
1 parent 9db9be6 commit 60d1b81

File tree

1 file changed

+64
-115
lines changed

1 file changed

+64
-115
lines changed

src/diffusers/modular_pipelines/modular_pipeline.py

Lines changed: 64 additions & 115 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,16 @@ def expected_components(self) -> List[ComponentSpec]:
342342
def expected_configs(self) -> List[ConfigSpec]:
343343
return []
344344

345+
@property
346+
def intermediate_inputs(self) -> List[OutputParam]:
347+
"""List of intermediate output parameters. Must be implemented by subclasses."""
348+
return []
349+
350+
@property
351+
def intermediate_outputs(self) -> List[OutputParam]:
352+
"""List of intermediate output parameters. Must be implemented by subclasses."""
353+
return []
354+
345355
@classmethod
346356
def from_pretrained(
347357
cls,
@@ -423,6 +433,60 @@ def init_pipeline(
423433
)
424434
return modular_pipeline
425435

436+
def get_block_state(self, state: PipelineState) -> dict:
437+
"""Get all inputs and intermediates in one dictionary"""
438+
data = {}
439+
state_inputs = self.inputs + self.intermediate_inputs
440+
441+
# Check inputs
442+
for input_param in state_inputs:
443+
if input_param.name:
444+
value = state.get_input(input_param.name) or state.get_intermediate(input_param.name)
445+
if input_param.required and value is None:
446+
raise ValueError(f"Required input '{input_param.name}' is missing")
447+
elif value is not None or (value is None and input_param.name not in data):
448+
data[input_param.name] = value
449+
450+
elif input_param.kwargs_type:
451+
# if kwargs_type is provided, get all inputs with matching kwargs_type
452+
if input_param.kwargs_type not in data:
453+
data[input_param.kwargs_type] = {}
454+
inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type) or state.get_intermediate_kwargs(
455+
input_param.kwargs_type
456+
)
457+
if inputs_kwargs:
458+
for k, v in inputs_kwargs.items():
459+
if v is not None:
460+
data[k] = v
461+
data[input_param.kwargs_type][k] = v
462+
463+
return BlockState(**data)
464+
465+
def set_block_state(self, state: PipelineState, block_state: BlockState):
466+
for output_param in self.intermediate_outputs:
467+
if not hasattr(block_state, output_param.name):
468+
raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
469+
param = getattr(block_state, output_param.name)
470+
state.set_intermediate(output_param.name, param, output_param.kwargs_type)
471+
472+
for input_param in self.intermediate_inputs:
473+
if input_param.name and hasattr(block_state, input_param.name):
474+
param = getattr(block_state, input_param.name)
475+
# Only add if the value is different from what's in the state
476+
current_value = state.get_intermediate(input_param.name)
477+
if current_value is not param: # Using identity comparison to check if object was modified
478+
state.set_intermediate(input_param.name, param, input_param.kwargs_type)
479+
elif input_param.kwargs_type:
480+
# if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
481+
# we need to first find out which inputs are and loop through them.
482+
intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
483+
for param_name, current_value in intermediate_kwargs.items():
484+
if not hasattr(block_state, param_name):
485+
continue
486+
param = getattr(block_state, param_name)
487+
if current_value is not param: # Using identity comparison to check if object was modified
488+
state.set_intermediate(param_name, param, input_param.kwargs_type)
489+
426490
@staticmethod
427491
def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]:
428492
"""
@@ -652,51 +716,6 @@ def doc(self):
652716
expected_configs=self.expected_configs,
653717
)
654718

655-
# YiYi TODO: input and inteermediate inputs with same name? should warn?
656-
def get_block_state(self, state: PipelineState) -> dict:
657-
"""Get all inputs and intermediates in one dictionary"""
658-
data = {}
659-
660-
# Check inputs
661-
for input_param in self.inputs:
662-
if input_param.name:
663-
value = state.get_input(input_param.name)
664-
if input_param.required and value is None:
665-
raise ValueError(f"Required input '{input_param.name}' is missing")
666-
elif value is not None or (value is None and input_param.name not in data):
667-
data[input_param.name] = value
668-
elif input_param.kwargs_type:
669-
# if kwargs_type is provided, get all inputs with matching kwargs_type
670-
if input_param.kwargs_type not in data:
671-
data[input_param.kwargs_type] = {}
672-
inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type)
673-
if inputs_kwargs:
674-
for k, v in inputs_kwargs.items():
675-
if v is not None:
676-
data[k] = v
677-
data[input_param.kwargs_type][k] = v
678-
679-
# Check intermediates
680-
for input_param in self.intermediate_inputs:
681-
if input_param.name:
682-
value = state.get_intermediate(input_param.name)
683-
if input_param.required and value is None:
684-
raise ValueError(f"Required intermediate input '{input_param.name}' is missing")
685-
elif value is not None or (value is None and input_param.name not in data):
686-
data[input_param.name] = value
687-
elif input_param.kwargs_type:
688-
# if kwargs_type is provided, get all intermediates with matching kwargs_type
689-
if input_param.kwargs_type not in data:
690-
data[input_param.kwargs_type] = {}
691-
intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
692-
if intermediate_kwargs:
693-
for k, v in intermediate_kwargs.items():
694-
if v is not None:
695-
if k not in data:
696-
data[k] = v
697-
data[input_param.kwargs_type][k] = v
698-
return BlockState(**data)
699-
700719
def set_block_state(self, state: PipelineState, block_state: BlockState):
701720
for output_param in self.intermediate_outputs:
702721
if not hasattr(block_state, output_param.name):
@@ -1633,75 +1652,6 @@ def loop_step(self, components, state: PipelineState, **kwargs):
16331652
def __call__(self, components, state: PipelineState) -> PipelineState:
16341653
raise NotImplementedError("`__call__` method needs to be implemented by the subclass")
16351654

1636-
def get_block_state(self, state: PipelineState) -> dict:
1637-
"""Get all inputs and intermediates in one dictionary"""
1638-
data = {}
1639-
1640-
# Check inputs
1641-
for input_param in self.inputs:
1642-
if input_param.name:
1643-
value = state.get_input(input_param.name)
1644-
if input_param.required and value is None:
1645-
raise ValueError(f"Required input '{input_param.name}' is missing")
1646-
elif value is not None or (value is None and input_param.name not in data):
1647-
data[input_param.name] = value
1648-
elif input_param.kwargs_type:
1649-
# if kwargs_type is provided, get all inputs with matching kwargs_type
1650-
if input_param.kwargs_type not in data:
1651-
data[input_param.kwargs_type] = {}
1652-
inputs_kwargs = state.get_inputs_kwargs(input_param.kwargs_type)
1653-
if inputs_kwargs:
1654-
for k, v in inputs_kwargs.items():
1655-
if v is not None:
1656-
data[k] = v
1657-
data[input_param.kwargs_type][k] = v
1658-
1659-
# Check intermediates
1660-
for input_param in self.intermediate_inputs:
1661-
if input_param.name:
1662-
value = state.get_intermediate(input_param.name)
1663-
if input_param.required and value is None:
1664-
raise ValueError(f"Required intermediate input '{input_param.name}' is missing")
1665-
elif value is not None or (value is None and input_param.name not in data):
1666-
data[input_param.name] = value
1667-
elif input_param.kwargs_type:
1668-
# if kwargs_type is provided, get all intermediates with matching kwargs_type
1669-
if input_param.kwargs_type not in data:
1670-
data[input_param.kwargs_type] = {}
1671-
intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
1672-
if intermediate_kwargs:
1673-
for k, v in intermediate_kwargs.items():
1674-
if v is not None:
1675-
if k not in data:
1676-
data[k] = v
1677-
data[input_param.kwargs_type][k] = v
1678-
return BlockState(**data)
1679-
1680-
def set_block_state(self, state: PipelineState, block_state: BlockState):
1681-
for output_param in self.intermediate_outputs:
1682-
if not hasattr(block_state, output_param.name):
1683-
raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state")
1684-
param = getattr(block_state, output_param.name)
1685-
state.set_intermediate(output_param.name, param, output_param.kwargs_type)
1686-
1687-
for input_param in self.intermediate_inputs:
1688-
if input_param.name and hasattr(block_state, input_param.name):
1689-
param = getattr(block_state, input_param.name)
1690-
# Only add if the value is different from what's in the state
1691-
current_value = state.get_intermediate(input_param.name)
1692-
if current_value is not param: # Using identity comparison to check if object was modified
1693-
state.set_intermediate(input_param.name, param, input_param.kwargs_type)
1694-
elif input_param.kwargs_type:
1695-
# if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
1696-
# we need to first find out which inputs are and loop through them.
1697-
intermediate_kwargs = state.get_intermediate_kwargs(input_param.kwargs_type)
1698-
for param_name, current_value in intermediate_kwargs.items():
1699-
if not hasattr(block_state, param_name):
1700-
continue
1701-
param = getattr(block_state, param_name)
1702-
if current_value is not param: # Using identity comparison to check if object was modified
1703-
state.set_intermediate(param_name, param, input_param.kwargs_type)
1704-
17051655
@property
17061656
def doc(self):
17071657
return make_doc_string(
@@ -1974,7 +1924,6 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] =
19741924

19751925
# Add inputs to state, using defaults if not provided in the kwargs or the state
19761926
# if same input already in the state, will override it if provided in the kwargs
1977-
19781927
intermediate_inputs = [inp.name for inp in self.blocks.intermediate_inputs]
19791928
for expected_input_param in self.blocks.inputs:
19801929
name = expected_input_param.name

0 commit comments

Comments
 (0)