@@ -342,6 +342,16 @@ def expected_components(self) -> List[ComponentSpec]:
342
342
def expected_configs (self ) -> List [ConfigSpec ]:
343
343
return []
344
344
345
+ @property
346
+ def intermediate_inputs (self ) -> List [OutputParam ]:
347
+ """List of intermediate output parameters. Must be implemented by subclasses."""
348
+ return []
349
+
350
+ @property
351
+ def intermediate_outputs (self ) -> List [OutputParam ]:
352
+ """List of intermediate output parameters. Must be implemented by subclasses."""
353
+ return []
354
+
345
355
@classmethod
346
356
def from_pretrained (
347
357
cls ,
@@ -423,6 +433,60 @@ def init_pipeline(
423
433
)
424
434
return modular_pipeline
425
435
436
+ def get_block_state (self , state : PipelineState ) -> dict :
437
+ """Get all inputs and intermediates in one dictionary"""
438
+ data = {}
439
+ state_inputs = self .inputs + self .intermediate_inputs
440
+
441
+ # Check inputs
442
+ for input_param in state_inputs :
443
+ if input_param .name :
444
+ value = state .get_input (input_param .name ) or state .get_intermediate (input_param .name )
445
+ if input_param .required and value is None :
446
+ raise ValueError (f"Required input '{ input_param .name } ' is missing" )
447
+ elif value is not None or (value is None and input_param .name not in data ):
448
+ data [input_param .name ] = value
449
+
450
+ elif input_param .kwargs_type :
451
+ # if kwargs_type is provided, get all inputs with matching kwargs_type
452
+ if input_param .kwargs_type not in data :
453
+ data [input_param .kwargs_type ] = {}
454
+ inputs_kwargs = state .get_inputs_kwargs (input_param .kwargs_type ) or state .get_intermediate_kwargs (
455
+ input_param .kwargs_type
456
+ )
457
+ if inputs_kwargs :
458
+ for k , v in inputs_kwargs .items ():
459
+ if v is not None :
460
+ data [k ] = v
461
+ data [input_param .kwargs_type ][k ] = v
462
+
463
+ return BlockState (** data )
464
+
465
+ def set_block_state (self , state : PipelineState , block_state : BlockState ):
466
+ for output_param in self .intermediate_outputs :
467
+ if not hasattr (block_state , output_param .name ):
468
+ raise ValueError (f"Intermediate output '{ output_param .name } ' is missing in block state" )
469
+ param = getattr (block_state , output_param .name )
470
+ state .set_intermediate (output_param .name , param , output_param .kwargs_type )
471
+
472
+ for input_param in self .intermediate_inputs :
473
+ if input_param .name and hasattr (block_state , input_param .name ):
474
+ param = getattr (block_state , input_param .name )
475
+ # Only add if the value is different from what's in the state
476
+ current_value = state .get_intermediate (input_param .name )
477
+ if current_value is not param : # Using identity comparison to check if object was modified
478
+ state .set_intermediate (input_param .name , param , input_param .kwargs_type )
479
+ elif input_param .kwargs_type :
480
+ # if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
481
+ # we need to first find out which inputs are and loop through them.
482
+ intermediate_kwargs = state .get_intermediate_kwargs (input_param .kwargs_type )
483
+ for param_name , current_value in intermediate_kwargs .items ():
484
+ if not hasattr (block_state , param_name ):
485
+ continue
486
+ param = getattr (block_state , param_name )
487
+ if current_value is not param : # Using identity comparison to check if object was modified
488
+ state .set_intermediate (param_name , param , input_param .kwargs_type )
489
+
426
490
@staticmethod
427
491
def combine_inputs (* named_input_lists : List [Tuple [str , List [InputParam ]]]) -> List [InputParam ]:
428
492
"""
@@ -652,51 +716,6 @@ def doc(self):
652
716
expected_configs = self .expected_configs ,
653
717
)
654
718
655
- # YiYi TODO: input and inteermediate inputs with same name? should warn?
656
- def get_block_state (self , state : PipelineState ) -> dict :
657
- """Get all inputs and intermediates in one dictionary"""
658
- data = {}
659
-
660
- # Check inputs
661
- for input_param in self .inputs :
662
- if input_param .name :
663
- value = state .get_input (input_param .name )
664
- if input_param .required and value is None :
665
- raise ValueError (f"Required input '{ input_param .name } ' is missing" )
666
- elif value is not None or (value is None and input_param .name not in data ):
667
- data [input_param .name ] = value
668
- elif input_param .kwargs_type :
669
- # if kwargs_type is provided, get all inputs with matching kwargs_type
670
- if input_param .kwargs_type not in data :
671
- data [input_param .kwargs_type ] = {}
672
- inputs_kwargs = state .get_inputs_kwargs (input_param .kwargs_type )
673
- if inputs_kwargs :
674
- for k , v in inputs_kwargs .items ():
675
- if v is not None :
676
- data [k ] = v
677
- data [input_param .kwargs_type ][k ] = v
678
-
679
- # Check intermediates
680
- for input_param in self .intermediate_inputs :
681
- if input_param .name :
682
- value = state .get_intermediate (input_param .name )
683
- if input_param .required and value is None :
684
- raise ValueError (f"Required intermediate input '{ input_param .name } ' is missing" )
685
- elif value is not None or (value is None and input_param .name not in data ):
686
- data [input_param .name ] = value
687
- elif input_param .kwargs_type :
688
- # if kwargs_type is provided, get all intermediates with matching kwargs_type
689
- if input_param .kwargs_type not in data :
690
- data [input_param .kwargs_type ] = {}
691
- intermediate_kwargs = state .get_intermediate_kwargs (input_param .kwargs_type )
692
- if intermediate_kwargs :
693
- for k , v in intermediate_kwargs .items ():
694
- if v is not None :
695
- if k not in data :
696
- data [k ] = v
697
- data [input_param .kwargs_type ][k ] = v
698
- return BlockState (** data )
699
-
700
719
def set_block_state (self , state : PipelineState , block_state : BlockState ):
701
720
for output_param in self .intermediate_outputs :
702
721
if not hasattr (block_state , output_param .name ):
@@ -1633,75 +1652,6 @@ def loop_step(self, components, state: PipelineState, **kwargs):
1633
1652
def __call__ (self , components , state : PipelineState ) -> PipelineState :
1634
1653
raise NotImplementedError ("`__call__` method needs to be implemented by the subclass" )
1635
1654
1636
- def get_block_state (self , state : PipelineState ) -> dict :
1637
- """Get all inputs and intermediates in one dictionary"""
1638
- data = {}
1639
-
1640
- # Check inputs
1641
- for input_param in self .inputs :
1642
- if input_param .name :
1643
- value = state .get_input (input_param .name )
1644
- if input_param .required and value is None :
1645
- raise ValueError (f"Required input '{ input_param .name } ' is missing" )
1646
- elif value is not None or (value is None and input_param .name not in data ):
1647
- data [input_param .name ] = value
1648
- elif input_param .kwargs_type :
1649
- # if kwargs_type is provided, get all inputs with matching kwargs_type
1650
- if input_param .kwargs_type not in data :
1651
- data [input_param .kwargs_type ] = {}
1652
- inputs_kwargs = state .get_inputs_kwargs (input_param .kwargs_type )
1653
- if inputs_kwargs :
1654
- for k , v in inputs_kwargs .items ():
1655
- if v is not None :
1656
- data [k ] = v
1657
- data [input_param .kwargs_type ][k ] = v
1658
-
1659
- # Check intermediates
1660
- for input_param in self .intermediate_inputs :
1661
- if input_param .name :
1662
- value = state .get_intermediate (input_param .name )
1663
- if input_param .required and value is None :
1664
- raise ValueError (f"Required intermediate input '{ input_param .name } ' is missing" )
1665
- elif value is not None or (value is None and input_param .name not in data ):
1666
- data [input_param .name ] = value
1667
- elif input_param .kwargs_type :
1668
- # if kwargs_type is provided, get all intermediates with matching kwargs_type
1669
- if input_param .kwargs_type not in data :
1670
- data [input_param .kwargs_type ] = {}
1671
- intermediate_kwargs = state .get_intermediate_kwargs (input_param .kwargs_type )
1672
- if intermediate_kwargs :
1673
- for k , v in intermediate_kwargs .items ():
1674
- if v is not None :
1675
- if k not in data :
1676
- data [k ] = v
1677
- data [input_param .kwargs_type ][k ] = v
1678
- return BlockState (** data )
1679
-
1680
- def set_block_state (self , state : PipelineState , block_state : BlockState ):
1681
- for output_param in self .intermediate_outputs :
1682
- if not hasattr (block_state , output_param .name ):
1683
- raise ValueError (f"Intermediate output '{ output_param .name } ' is missing in block state" )
1684
- param = getattr (block_state , output_param .name )
1685
- state .set_intermediate (output_param .name , param , output_param .kwargs_type )
1686
-
1687
- for input_param in self .intermediate_inputs :
1688
- if input_param .name and hasattr (block_state , input_param .name ):
1689
- param = getattr (block_state , input_param .name )
1690
- # Only add if the value is different from what's in the state
1691
- current_value = state .get_intermediate (input_param .name )
1692
- if current_value is not param : # Using identity comparison to check if object was modified
1693
- state .set_intermediate (input_param .name , param , input_param .kwargs_type )
1694
- elif input_param .kwargs_type :
1695
- # if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters
1696
- # we need to first find out which inputs are and loop through them.
1697
- intermediate_kwargs = state .get_intermediate_kwargs (input_param .kwargs_type )
1698
- for param_name , current_value in intermediate_kwargs .items ():
1699
- if not hasattr (block_state , param_name ):
1700
- continue
1701
- param = getattr (block_state , param_name )
1702
- if current_value is not param : # Using identity comparison to check if object was modified
1703
- state .set_intermediate (param_name , param , input_param .kwargs_type )
1704
-
1705
1655
@property
1706
1656
def doc (self ):
1707
1657
return make_doc_string (
@@ -1974,7 +1924,6 @@ def __call__(self, state: PipelineState = None, output: Union[str, List[str]] =
1974
1924
1975
1925
# Add inputs to state, using defaults if not provided in the kwargs or the state
1976
1926
# if same input already in the state, will override it if provided in the kwargs
1977
-
1978
1927
intermediate_inputs = [inp .name for inp in self .blocks .intermediate_inputs ]
1979
1928
for expected_input_param in self .blocks .inputs :
1980
1929
name = expected_input_param .name
0 commit comments