5
5
from .helper_functions import conv1d_transpose_via_conv2d
6
6
from . import helper_functions as hf
7
7
import tensorflow as tf
8
+ from deeplift .util import to_tf_variable
8
9
9
10
PoolMode = deeplift .util .enum (max = 'max' , avg = 'avg' )
10
11
PaddingMode = deeplift .util .enum (same = 'SAME' , valid = 'VALID' )
@@ -34,8 +35,8 @@ def __init__(self, kernel, bias, stride, padding, **kwargs):
34
35
super (Conv1D , self ).__init__ (** kwargs )
35
36
#kernel has dimensions:
36
37
#length x inp_channels x num output channels
37
- self .kernel = kernel
38
- self .bias = bias
38
+ self .kernel = to_tf_variable ( kernel , name = self . get_name () + "_kernel" )
39
+ self .bias = to_tf_variable ( bias , name = self . get_name () + "_bias" )
39
40
if (hasattr (stride , '__iter__' )):
40
41
assert len (stride )== 1
41
42
stride = stride [0 ]
@@ -54,7 +55,7 @@ def _compute_shape(self, input_shape):
54
55
1 + int ((input_shape [1 ]- self .kernel .shape [0 ])/ self .stride ))
55
56
elif (self .padding == PaddingMode .same ):
56
57
shape_to_return .append (
57
- int ((input_shape [1 ]+ self .stride - 1 )/ self .stride ))
58
+ int ((input_shape [1 ]+ self .stride - 1 )/ self .stride ))
58
59
else :
59
60
raise RuntimeError ("Please implement shape inference for"
60
61
" padding mode: " + str (self .padding ))
@@ -69,7 +70,7 @@ def _build_activation_vars(self, input_act_vars):
69
70
70
71
def _build_pos_and_neg_contribs (self ):
71
72
if (self .conv_mxts_mode == ConvMxtsMode .Linear ):
72
- inp_diff_ref = self ._get_input_diff_from_reference_vars ()
73
+ inp_diff_ref = self ._get_input_diff_from_reference_vars ()
73
74
pos_contribs = (self ._compute_conv_without_bias (
74
75
x = inp_diff_ref * hf .gt_mask (inp_diff_ref ,0.0 ),
75
76
kernel = self .kernel * hf .gt_mask (self .kernel ,0.0 ))
@@ -95,12 +96,12 @@ def _compute_conv_without_bias(self, x, kernel):
95
96
padding = self .padding )
96
97
return conv_without_bias
97
98
98
- def _get_mxts_increments_for_inputs (self ):
99
+ def _get_mxts_increments_for_inputs (self ):
99
100
pos_mxts = self .get_pos_mxts ()
100
101
neg_mxts = self .get_neg_mxts ()
101
- inp_diff_ref = self ._get_input_diff_from_reference_vars ()
102
+ inp_diff_ref = self ._get_input_diff_from_reference_vars ()
102
103
output_shape = self ._get_input_shape ()
103
- if (self .conv_mxts_mode == ConvMxtsMode .Linear ):
104
+ if (self .conv_mxts_mode == ConvMxtsMode .Linear ):
104
105
pos_inp_mask = hf .gt_mask (inp_diff_ref ,0.0 )
105
106
neg_inp_mask = hf .lt_mask (inp_diff_ref ,0.0 )
106
107
zero_inp_mask = hf .eq_mask (inp_diff_ref ,0.0 )
@@ -159,8 +160,8 @@ def __init__(self, kernel, bias, strides, padding, data_format, **kwargs):
159
160
super (Conv2D , self ).__init__ (** kwargs )
160
161
#kernel has dimensions:
161
162
#rows_kern_width x cols_kern_width x inp_channels x num output channels
162
- self .kernel = kernel
163
- self .bias = bias
163
+ self .kernel = to_tf_variable ( kernel , name = self . get_name () + "_kernel" )
164
+ self .bias = to_tf_variable ( bias , name = self . get_name () + "_bias" )
164
165
self .strides = strides
165
166
self .padding = padding
166
167
self .data_format = data_format
@@ -184,12 +185,12 @@ def _compute_shape(self, input_shape):
184
185
zip (input_shape [1 :3 ], self .kernel .shape [:2 ], self .strides ):
185
186
#overhangs are excluded
186
187
shape_to_return .append (
187
- 1 + int ((dim_inp_len - dim_kern_width )/ dim_stride ))
188
+ 1 + int ((dim_inp_len - dim_kern_width )/ dim_stride ))
188
189
elif (self .padding == PaddingMode .same ):
189
190
for (dim_inp_len , dim_kern_width , dim_stride ) in \
190
191
zip (input_shape [1 :3 ], self .kernel .shape [:2 ], self .strides ):
191
192
shape_to_return .append (
192
- int ((dim_inp_len + dim_stride - 1 )/ dim_stride ))
193
+ int ((dim_inp_len + dim_stride - 1 )/ dim_stride ))
193
194
else :
194
195
raise RuntimeError ("Please implement shape inference for"
195
196
" border mode: " + str (self .padding ))
@@ -216,11 +217,11 @@ def _build_activation_vars(self, input_act_vars):
216
217
if (self .data_format == DataFormat .channels_first ):
217
218
to_return = tf .transpose (a = to_return ,
218
219
perm = [0 ,3 ,1 ,2 ])
219
- return to_return
220
+ return to_return
220
221
221
222
def _build_pos_and_neg_contribs (self ):
222
223
if (self .conv_mxts_mode == ConvMxtsMode .Linear ):
223
- inp_diff_ref = self ._get_input_diff_from_reference_vars ()
224
+ inp_diff_ref = self ._get_input_diff_from_reference_vars ()
224
225
if (self .data_format == DataFormat .channels_first ):
225
226
inp_diff_ref = tf .transpose (a = inp_diff_ref ,
226
227
perm = [0 ,2 ,3 ,1 ])
@@ -255,10 +256,10 @@ def _compute_conv_without_bias(self, x, kernel):
255
256
padding = self .padding )
256
257
return conv_without_bias
257
258
258
- def _get_mxts_increments_for_inputs (self ):
259
+ def _get_mxts_increments_for_inputs (self ):
259
260
pos_mxts = self .get_pos_mxts ()
260
261
neg_mxts = self .get_neg_mxts ()
261
- inp_diff_ref = self ._get_input_diff_from_reference_vars ()
262
+ inp_diff_ref = self ._get_input_diff_from_reference_vars ()
262
263
inp_act_vars = self .inputs .get_activation_vars ()
263
264
strides_to_supply = [1 ]+ list (self .strides )+ [1 ]
264
265
@@ -270,11 +271,11 @@ def _get_mxts_increments_for_inputs(self):
270
271
271
272
output_shape = tf .shape (inp_act_vars )
272
273
273
- if (self .conv_mxts_mode == ConvMxtsMode .Linear ):
274
+ if (self .conv_mxts_mode == ConvMxtsMode .Linear ):
274
275
pos_inp_mask = hf .gt_mask (inp_diff_ref ,0.0 )
275
276
neg_inp_mask = hf .lt_mask (inp_diff_ref ,0.0 )
276
277
zero_inp_mask = hf .eq_mask (inp_diff_ref , 0.0 )
277
-
278
+
278
279
inp_mxts_increments = pos_inp_mask * (
279
280
tf .nn .conv2d_transpose (
280
281
value = pos_mxts ,
@@ -319,7 +320,7 @@ def _get_mxts_increments_for_inputs(self):
319
320
320
321
if (self .data_format == DataFormat .channels_first ):
321
322
pos_mxts_increments = tf .transpose (a = pos_mxts_increments ,
322
- perm = (0 ,3 ,1 ,2 ))
323
+ perm = (0 ,3 ,1 ,2 ))
323
324
neg_mxts_increments = tf .transpose (a = neg_mxts_increments ,
324
325
perm = (0 ,3 ,1 ,2 ))
325
326
0 commit comments