Skip to content

Commit 3464cfd

Browse files
committed
Add BatchNorm, Mul, Add, Sub, Mean, and more ops
1 parent 3fc3293 commit 3464cfd

File tree

2 files changed

+153
-1
lines changed

2 files changed

+153
-1
lines changed

utensor_cgen/backend/utensor/code_generator/rearch/_operators/_impls.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,39 @@ def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
511511
tensor_var_map=tensor_var_map,
512512
)
513513

514+
@OperatorFactory.register
515+
class _ConvOperator(_CommonParams):
516+
op_type = "ConvOperator"
517+
518+
@classmethod
519+
@must_return_type(Hashable)
520+
def get_constructor_parameters(cls, op_info):
521+
522+
strides = [
523+
1,
524+
op_info.op_attr['StrideW'],
525+
op_info.op_attr['StrideH'],
526+
1,
527+
]
528+
padding = cls._PADDING_MAP[op_info.op_attr['Padding']]
529+
strides_str = ','.join(map(str, strides))
530+
return ("{{ {} }}".format(strides_str), padding)
531+
532+
def get_declare_snippet(self, op_var_name, tensor_var_map):
533+
return DeclareOpSnippet(
534+
op=self,
535+
templ_dtypes=[self.out_dtypes[0]],
536+
op_var_name=op_var_name,
537+
)
538+
539+
def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
540+
return ConvOpEvalSnippet(
541+
op_info=op_info,
542+
templ_dtypes=[self.out_dtypes[0]],
543+
op_name=op_var_name,
544+
tensor_var_map=tensor_var_map,
545+
)
546+
514547

515548
@OperatorFactory.register
516549
class _QuantizedFullyConnectedOperator(_CommonParams):
@@ -689,3 +722,123 @@ def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
689722
tensor_var_map=tensor_var_map,
690723
nested_namespaces=type(self).namespaces,
691724
)
725+
726+
@OperatorFactory.register
727+
class _BatchNormOperator(_CommonParams):
728+
op_type = "BatchNormOperator"
729+
730+
@classmethod
731+
@must_return_type(Hashable)
732+
def get_constructor_parameters(cls, op_info):
733+
strides = [
734+
1,
735+
op_info.op_attr['StrideW'],
736+
op_info.op_attr['StrideH'],
737+
1,
738+
]
739+
padding = cls._PADDING_MAP[op_info.op_attr['Padding']]
740+
strides_str = ','.join(map(str, strides))
741+
return ("{{ {} }}".format(strides_str), padding)
742+
743+
def get_declare_snippet(self, op_var_name, tensor_var_map):
744+
return DeclareOpSnippet(
745+
op=self,
746+
templ_dtypes=[self.out_dtypes[0]],
747+
op_var_name=op_var_name,
748+
)
749+
750+
def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
751+
return BatchNormSnippet(
752+
op_info=op_info,
753+
templ_dtypes=[self.out_dtypes[0]],
754+
op_name=op_var_name,
755+
tensor_var_map=tensor_var_map,
756+
)
757+
758+
@OperatorFactory.register
759+
class _MeanOperator(_CommonParams):
760+
op_type = "MeanOperator"
761+
762+
@classmethod
763+
@must_return_type(Hashable)
764+
def get_constructor_parameters(cls, op_info):
765+
keep_dims = str(op_info.op_attr["keep_dims"])
766+
return (" {} ".format(keep_dims), )
767+
768+
def get_declare_snippet(self, op_var_name, tensor_var_map):
769+
return DeclareOpSnippet(
770+
op=self,
771+
templ_dtypes=[self.out_dtypes[0]],
772+
op_var_name=op_var_name,
773+
)
774+
775+
def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
776+
return BatchNormSnippet(
777+
op_info=op_info,
778+
templ_dtypes=[self.out_dtypes[0]],
779+
op_name=op_var_name,
780+
tensor_var_map=tensor_var_map,
781+
)
782+
783+
@OperatorFactory.register
784+
class _SoftmaxOperator(_CommonParams):
785+
op_type = "SoftmaxOperator"
786+
787+
@classmethod
788+
@must_return_type(Hashable)
789+
def get_constructor_parameters(cls, op_info):
790+
Beta = op_info.op_attr["Beta"]
791+
return (" %f " % Beta,)
792+
793+
def get_declare_snippet(self, op_var_name, tensor_var_map):
794+
return DeclareOpSnippet(
795+
op=self,
796+
templ_dtypes=[self.out_dtypes[0]],
797+
op_var_name=op_var_name,
798+
)
799+
800+
def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
801+
return BatchNormSnippet(
802+
op_info=op_info,
803+
templ_dtypes=[self.out_dtypes[0]],
804+
op_name=op_var_name,
805+
tensor_var_map=tensor_var_map,
806+
)
807+
808+
@OperatorFactory.register
809+
class _MulOperator(_Operator):
810+
op_type = 'MulOperator'
811+
812+
def get_declare_snippet(self, op_var_name, tensor_var_map):
813+
return DeclareOpSnippet(
814+
op=self,
815+
templ_dtypes=[self.in_dtypes[0]],
816+
op_var_name=op_var_name,
817+
)
818+
819+
def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
820+
return MulOpEvalSnippet(
821+
op_info=op_info,
822+
templ_dtypes=[self.in_dtypes[0]],
823+
op_name=op_var_name,
824+
tensor_var_map=tensor_var_map,
825+
)
826+
827+
@OperatorFactory.register
828+
class _SubOperator(_Operator):
829+
op_type = 'SubOperator'
830+
831+
def get_declare_snippet(self, op_var_name, tensor_var_map):
832+
return DeclareOpSnippet(
833+
op=self,
834+
templ_dtypes=[self.in_dtypes[0]],
835+
op_var_name=op_var_name,
836+
)
837+
838+
def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
839+
return SubOpEvalSnippet(
840+
op_info=op_info,
841+
templ_dtypes=[self.in_dtypes[0]],
842+
op_name=op_var_name,
843+
tensor_var_map=tensor_var_map,
844+
)

utensor_cgen/backend/utensor/snippets/rearch/_snippets.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
"QuantizedFullyConnectedSnippet",
3131
"MissingOpEvalSnippet",
3232
"TimeSlotContainer",
33-
"BatchNormSnippet",
3433
"MulOpEvalSnippet",
3534
"SubOpEvalSnippet",
3635
"ConvOpEvalSnippet",

0 commit comments

Comments
 (0)