@@ -511,6 +511,39 @@ def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
511
511
tensor_var_map = tensor_var_map ,
512
512
)
513
513
514
+ @OperatorFactory .register
515
+ class _ConvOperator (_CommonParams ):
516
+ op_type = "ConvOperator"
517
+
518
+ @classmethod
519
+ @must_return_type (Hashable )
520
+ def get_constructor_parameters (cls , op_info ):
521
+
522
+ strides = [
523
+ 1 ,
524
+ op_info .op_attr ['StrideW' ],
525
+ op_info .op_attr ['StrideH' ],
526
+ 1 ,
527
+ ]
528
+ padding = cls ._PADDING_MAP [op_info .op_attr ['Padding' ]]
529
+ strides_str = ',' .join (map (str , strides ))
530
+ return ("{{ {} }}" .format (strides_str ), padding )
531
+
532
+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
533
+ return DeclareOpSnippet (
534
+ op = self ,
535
+ templ_dtypes = [self .out_dtypes [0 ]],
536
+ op_var_name = op_var_name ,
537
+ )
538
+
539
+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
540
+ return ConvOpEvalSnippet (
541
+ op_info = op_info ,
542
+ templ_dtypes = [self .out_dtypes [0 ]],
543
+ op_name = op_var_name ,
544
+ tensor_var_map = tensor_var_map ,
545
+ )
546
+
514
547
515
548
@OperatorFactory .register
516
549
class _QuantizedFullyConnectedOperator (_CommonParams ):
@@ -689,3 +722,123 @@ def get_eval_snippet(self, op_var_name, op_info, tensor_var_map):
689
722
tensor_var_map = tensor_var_map ,
690
723
nested_namespaces = type (self ).namespaces ,
691
724
)
725
+
726
+ @OperatorFactory .register
727
+ class _BatchNormOperator (_CommonParams ):
728
+ op_type = "BatchNormOperator"
729
+
730
+ @classmethod
731
+ @must_return_type (Hashable )
732
+ def get_constructor_parameters (cls , op_info ):
733
+ strides = [
734
+ 1 ,
735
+ op_info .op_attr ['StrideW' ],
736
+ op_info .op_attr ['StrideH' ],
737
+ 1 ,
738
+ ]
739
+ padding = cls ._PADDING_MAP [op_info .op_attr ['Padding' ]]
740
+ strides_str = ',' .join (map (str , strides ))
741
+ return ("{{ {} }}" .format (strides_str ), padding )
742
+
743
+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
744
+ return DeclareOpSnippet (
745
+ op = self ,
746
+ templ_dtypes = [self .out_dtypes [0 ]],
747
+ op_var_name = op_var_name ,
748
+ )
749
+
750
+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
751
+ return BatchNormSnippet (
752
+ op_info = op_info ,
753
+ templ_dtypes = [self .out_dtypes [0 ]],
754
+ op_name = op_var_name ,
755
+ tensor_var_map = tensor_var_map ,
756
+ )
757
+
758
+ @OperatorFactory .register
759
+ class _MeanOperator (_CommonParams ):
760
+ op_type = "MeanOperator"
761
+
762
+ @classmethod
763
+ @must_return_type (Hashable )
764
+ def get_constructor_parameters (cls , op_info ):
765
+ keep_dims = str (op_info .op_attr ["keep_dims" ])
766
+ return (" {} " .format (keep_dims ), )
767
+
768
+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
769
+ return DeclareOpSnippet (
770
+ op = self ,
771
+ templ_dtypes = [self .out_dtypes [0 ]],
772
+ op_var_name = op_var_name ,
773
+ )
774
+
775
+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
776
+ return BatchNormSnippet (
777
+ op_info = op_info ,
778
+ templ_dtypes = [self .out_dtypes [0 ]],
779
+ op_name = op_var_name ,
780
+ tensor_var_map = tensor_var_map ,
781
+ )
782
+
783
+ @OperatorFactory .register
784
+ class _SoftmaxOperator (_CommonParams ):
785
+ op_type = "SoftmaxOperator"
786
+
787
+ @classmethod
788
+ @must_return_type (Hashable )
789
+ def get_constructor_parameters (cls , op_info ):
790
+ Beta = op_info .op_attr ["Beta" ]
791
+ return (" %f " % Beta ,)
792
+
793
+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
794
+ return DeclareOpSnippet (
795
+ op = self ,
796
+ templ_dtypes = [self .out_dtypes [0 ]],
797
+ op_var_name = op_var_name ,
798
+ )
799
+
800
+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
801
+ return BatchNormSnippet (
802
+ op_info = op_info ,
803
+ templ_dtypes = [self .out_dtypes [0 ]],
804
+ op_name = op_var_name ,
805
+ tensor_var_map = tensor_var_map ,
806
+ )
807
+
808
+ @OperatorFactory .register
809
+ class _MulOperator (_Operator ):
810
+ op_type = 'MulOperator'
811
+
812
+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
813
+ return DeclareOpSnippet (
814
+ op = self ,
815
+ templ_dtypes = [self .in_dtypes [0 ]],
816
+ op_var_name = op_var_name ,
817
+ )
818
+
819
+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
820
+ return MulOpEvalSnippet (
821
+ op_info = op_info ,
822
+ templ_dtypes = [self .in_dtypes [0 ]],
823
+ op_name = op_var_name ,
824
+ tensor_var_map = tensor_var_map ,
825
+ )
826
+
827
+ @OperatorFactory .register
828
+ class _SubOperator (_Operator ):
829
+ op_type = 'SubOperator'
830
+
831
+ def get_declare_snippet (self , op_var_name , tensor_var_map ):
832
+ return DeclareOpSnippet (
833
+ op = self ,
834
+ templ_dtypes = [self .in_dtypes [0 ]],
835
+ op_var_name = op_var_name ,
836
+ )
837
+
838
+ def get_eval_snippet (self , op_var_name , op_info , tensor_var_map ):
839
+ return SubOpEvalSnippet (
840
+ op_info = op_info ,
841
+ templ_dtypes = [self .in_dtypes [0 ]],
842
+ op_name = op_var_name ,
843
+ tensor_var_map = tensor_var_map ,
844
+ )
0 commit comments