Skip to content

Commit 4aee999

Browse files
committed
Improve handling of model parameters
Before this commit, the model parameters were stored as numpy arrays. Now the model parameters are converted to TensorFlow variables.
1 parent 667f00b commit 4aee999

File tree

5 files changed

+84
-65
lines changed

5 files changed

+84
-65
lines changed

deeplift/conversion/kerasapi_conversion.py

Lines changed: 37 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -446,23 +446,25 @@ def convert_sequential_model(
446446
+str(nonlinear_mxts_mode))
447447
sys.stdout.flush()
448448

449-
converted_layers = []
450-
batch_input_shape = model_config[0]['config'][KerasKeys.batch_input_shape]
451-
converted_layers.append(
452-
layers.core.Input(batch_shape=batch_input_shape, name="input"))
453-
#converted_layers is actually mutated to be extended with the
454-
#additional layers so the assignment is not strictly necessary,
455-
#but whatever
456-
converted_layers = sequential_container_conversion(
457-
config=model_config, name="", verbose=verbose,
458-
nonlinear_mxts_mode=nonlinear_mxts_mode,
459-
dense_mxts_mode=dense_mxts_mode,
460-
conv_mxts_mode=conv_mxts_mode,
461-
maxpool_deeplift_mode=maxpool_deeplift_mode,
462-
converted_layers=converted_layers,
463-
layer_overrides=layer_overrides)
464-
converted_layers[-1].build_fwd_pass_vars()
465-
return models.SequentialModel(converted_layers)
449+
# Use a variable scope so multiple DeepLIFT models can be built in one session without variable-name collisions
450+
with tf.variable_scope(None, default_name='deeplift'):
451+
converted_layers = []
452+
batch_input_shape = model_config[0]['config'][KerasKeys.batch_input_shape]
453+
converted_layers.append(
454+
layers.core.Input(batch_shape=batch_input_shape, name="input"))
455+
#converted_layers is actually mutated to be extended with the
456+
#additional layers so the assignment is not strictly necessary,
457+
#but whatever
458+
converted_layers = sequential_container_conversion(
459+
config=model_config, name="", verbose=verbose,
460+
nonlinear_mxts_mode=nonlinear_mxts_mode,
461+
dense_mxts_mode=dense_mxts_mode,
462+
conv_mxts_mode=conv_mxts_mode,
463+
maxpool_deeplift_mode=maxpool_deeplift_mode,
464+
converted_layers=converted_layers,
465+
layer_overrides=layer_overrides)
466+
converted_layers[-1].build_fwd_pass_vars()
467+
return models.SequentialModel(converted_layers)
466468

467469

468470
def sequential_container_conversion(config,
@@ -819,20 +821,22 @@ def convert_functional_model(
819821
if (verbose):
820822
print("nonlinear_mxts_mode is set to: "+str(nonlinear_mxts_mode))
821823

822-
converted_model_container = functional_container_conversion(
823-
config=model_config,
824-
name="", verbose=verbose,
825-
nonlinear_mxts_mode=nonlinear_mxts_mode,
826-
dense_mxts_mode=dense_mxts_mode,
827-
conv_mxts_mode=conv_mxts_mode,
828-
maxpool_deeplift_mode=maxpool_deeplift_mode,
829-
layer_overrides=layer_overrides,
830-
custom_conversion_funcs=custom_conversion_funcs)
831-
832-
for output_layer in converted_model_container.output_layers:
833-
output_layer.build_fwd_pass_vars()
834-
835-
return models.GraphModel(
836-
name_to_layer=converted_model_container.name_to_deeplift_layer,
837-
input_layer_names=converted_model_container.input_layer_names)
824+
# Use a variable scope so multiple DeepLIFT models can be built in one session without variable-name collisions
825+
with tf.variable_scope(None, default_name='deeplift'):
826+
converted_model_container = functional_container_conversion(
827+
config=model_config,
828+
name="", verbose=verbose,
829+
nonlinear_mxts_mode=nonlinear_mxts_mode,
830+
dense_mxts_mode=dense_mxts_mode,
831+
conv_mxts_mode=conv_mxts_mode,
832+
maxpool_deeplift_mode=maxpool_deeplift_mode,
833+
layer_overrides=layer_overrides,
834+
custom_conversion_funcs=custom_conversion_funcs)
835+
836+
for output_layer in converted_model_container.output_layers:
837+
output_layer.build_fwd_pass_vars()
838+
839+
return models.GraphModel(
840+
name_to_layer=converted_model_container.name_to_deeplift_layer,
841+
input_layer_names=converted_model_container.input_layer_names)
838842

deeplift/layers/convolutional.py

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .helper_functions import conv1d_transpose_via_conv2d
66
from . import helper_functions as hf
77
import tensorflow as tf
8+
from deeplift.util import to_tf_variable
89

910
PoolMode = deeplift.util.enum(max='max', avg='avg')
1011
PaddingMode = deeplift.util.enum(same='SAME', valid='VALID')
@@ -34,8 +35,8 @@ def __init__(self, kernel, bias, stride, padding, **kwargs):
3435
super(Conv1D, self).__init__(**kwargs)
3536
#kernel has dimensions:
3637
#length x inp_channels x num output channels
37-
self.kernel = kernel
38-
self.bias = bias
38+
self.kernel = to_tf_variable(kernel, name=self.get_name() + "_kernel")
39+
self.bias = to_tf_variable(bias, name=self.get_name() + "_bias")
3940
if (hasattr(stride, '__iter__')):
4041
assert len(stride)==1
4142
stride=stride[0]
@@ -54,7 +55,7 @@ def _compute_shape(self, input_shape):
5455
1+int((input_shape[1]-self.kernel.shape[0])/self.stride))
5556
elif (self.padding == PaddingMode.same):
5657
shape_to_return.append(
57-
int((input_shape[1]+self.stride-1)/self.stride))
58+
int((input_shape[1]+self.stride-1)/self.stride))
5859
else:
5960
raise RuntimeError("Please implement shape inference for"
6061
" padding mode: "+str(self.padding))
@@ -69,7 +70,7 @@ def _build_activation_vars(self, input_act_vars):
6970

7071
def _build_pos_and_neg_contribs(self):
7172
if (self.conv_mxts_mode == ConvMxtsMode.Linear):
72-
inp_diff_ref = self._get_input_diff_from_reference_vars()
73+
inp_diff_ref = self._get_input_diff_from_reference_vars()
7374
pos_contribs = (self._compute_conv_without_bias(
7475
x=inp_diff_ref*hf.gt_mask(inp_diff_ref,0.0),
7576
kernel=self.kernel*hf.gt_mask(self.kernel,0.0))
@@ -95,12 +96,12 @@ def _compute_conv_without_bias(self, x, kernel):
9596
padding=self.padding)
9697
return conv_without_bias
9798

98-
def _get_mxts_increments_for_inputs(self):
99+
def _get_mxts_increments_for_inputs(self):
99100
pos_mxts = self.get_pos_mxts()
100101
neg_mxts = self.get_neg_mxts()
101-
inp_diff_ref = self._get_input_diff_from_reference_vars()
102+
inp_diff_ref = self._get_input_diff_from_reference_vars()
102103
output_shape = self._get_input_shape()
103-
if (self.conv_mxts_mode == ConvMxtsMode.Linear):
104+
if (self.conv_mxts_mode == ConvMxtsMode.Linear):
104105
pos_inp_mask = hf.gt_mask(inp_diff_ref,0.0)
105106
neg_inp_mask = hf.lt_mask(inp_diff_ref,0.0)
106107
zero_inp_mask = hf.eq_mask(inp_diff_ref,0.0)
@@ -159,8 +160,8 @@ def __init__(self, kernel, bias, strides, padding, data_format, **kwargs):
159160
super(Conv2D, self).__init__(**kwargs)
160161
#kernel has dimensions:
161162
#rows_kern_width x cols_kern_width x inp_channels x num output channels
162-
self.kernel = kernel
163-
self.bias = bias
163+
self.kernel = to_tf_variable(kernel, name=self.get_name() + "_kernel")
164+
self.bias = to_tf_variable(bias, name=self.get_name() + "_bias")
164165
self.strides = strides
165166
self.padding = padding
166167
self.data_format = data_format
@@ -184,12 +185,12 @@ def _compute_shape(self, input_shape):
184185
zip(input_shape[1:3], self.kernel.shape[:2], self.strides):
185186
#overhangs are excluded
186187
shape_to_return.append(
187-
1+int((dim_inp_len-dim_kern_width)/dim_stride))
188+
1+int((dim_inp_len-dim_kern_width)/dim_stride))
188189
elif (self.padding == PaddingMode.same):
189190
for (dim_inp_len, dim_kern_width, dim_stride) in\
190191
zip(input_shape[1:3], self.kernel.shape[:2], self.strides):
191192
shape_to_return.append(
192-
int((dim_inp_len+dim_stride-1)/dim_stride))
193+
int((dim_inp_len+dim_stride-1)/dim_stride))
193194
else:
194195
raise RuntimeError("Please implement shape inference for"
195196
" border mode: "+str(self.padding))
@@ -216,11 +217,11 @@ def _build_activation_vars(self, input_act_vars):
216217
if (self.data_format == DataFormat.channels_first):
217218
to_return = tf.transpose(a=to_return,
218219
perm=[0,3,1,2])
219-
return to_return
220+
return to_return
220221

221222
def _build_pos_and_neg_contribs(self):
222223
if (self.conv_mxts_mode == ConvMxtsMode.Linear):
223-
inp_diff_ref = self._get_input_diff_from_reference_vars()
224+
inp_diff_ref = self._get_input_diff_from_reference_vars()
224225
if (self.data_format == DataFormat.channels_first):
225226
inp_diff_ref = tf.transpose(a=inp_diff_ref,
226227
perm=[0,2,3,1])
@@ -255,10 +256,10 @@ def _compute_conv_without_bias(self, x, kernel):
255256
padding=self.padding)
256257
return conv_without_bias
257258

258-
def _get_mxts_increments_for_inputs(self):
259+
def _get_mxts_increments_for_inputs(self):
259260
pos_mxts = self.get_pos_mxts()
260261
neg_mxts = self.get_neg_mxts()
261-
inp_diff_ref = self._get_input_diff_from_reference_vars()
262+
inp_diff_ref = self._get_input_diff_from_reference_vars()
262263
inp_act_vars = self.inputs.get_activation_vars()
263264
strides_to_supply = [1]+list(self.strides)+[1]
264265

@@ -270,11 +271,11 @@ def _get_mxts_increments_for_inputs(self):
270271

271272
output_shape = tf.shape(inp_act_vars)
272273

273-
if (self.conv_mxts_mode == ConvMxtsMode.Linear):
274+
if (self.conv_mxts_mode == ConvMxtsMode.Linear):
274275
pos_inp_mask = hf.gt_mask(inp_diff_ref,0.0)
275276
neg_inp_mask = hf.lt_mask(inp_diff_ref,0.0)
276277
zero_inp_mask = hf.eq_mask(inp_diff_ref, 0.0)
277-
278+
278279
inp_mxts_increments = pos_inp_mask*(
279280
tf.nn.conv2d_transpose(
280281
value=pos_mxts,
@@ -319,7 +320,7 @@ def _get_mxts_increments_for_inputs(self):
319320

320321
if (self.data_format == DataFormat.channels_first):
321322
pos_mxts_increments = tf.transpose(a=pos_mxts_increments,
322-
perm=(0,3,1,2))
323+
perm=(0,3,1,2))
323324
neg_mxts_increments = tf.transpose(a=neg_mxts_increments,
324325
perm=(0,3,1,2))
325326

deeplift/layers/core.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from collections import namedtuple
88
from collections import OrderedDict
99
from collections import defaultdict
10-
import deeplift.util
10+
import deeplift.util
11+
from deeplift.util import to_tf_variable
1112
from .helper_functions import (
1213
pseudocount_near_zero, add_val_to_col)
1314
from . import helper_functions as hf
@@ -520,8 +521,11 @@ class Dense(SingleInputMixin, OneDimOutputMixin, Node):
520521

521522
def __init__(self, kernel, bias, dense_mxts_mode, **kwargs):
522523
super(Dense, self).__init__(**kwargs)
523-
self.kernel = np.array(kernel).astype("float32")
524-
self.bias = np.array(bias).astype("float32")
524+
525+
self.kernel = to_tf_variable(np.array(kernel).astype("float32"),
526+
name=self.get_name() + "_kernel")
527+
self.bias = to_tf_variable(np.array(bias).astype("float32"),
528+
name=self.get_name() + "_bias")
525529
self.dense_mxts_mode = dense_mxts_mode
526530

527531
def _compute_shape(self, input_shape):
@@ -560,19 +564,22 @@ def _get_mxts_increments_for_inputs(self):
560564
pos_inp_mask = hf.gt_mask(inp_diff_ref,0.0)
561565
neg_inp_mask = hf.lt_mask(inp_diff_ref,0.0)
562566
zero_inp_mask = hf.eq_mask(inp_diff_ref,0.0)
567+
568+
kernel_T = tf.transpose(self.kernel)
569+
563570
inp_mxts_increments = pos_inp_mask*(
564571
tf.matmul(self.get_pos_mxts(),
565-
self.kernel.T*(hf.gt_mask(self.kernel.T, 0.0)))
572+
kernel_T*(hf.gt_mask(kernel_T, 0.0)))
566573
+ tf.matmul(self.get_neg_mxts(),
567-
self.kernel.T*(hf.lt_mask(self.kernel.T, 0.0))))
574+
kernel_T*(hf.lt_mask(kernel_T, 0.0))))
568575
inp_mxts_increments += neg_inp_mask*(
569576
tf.matmul(self.get_pos_mxts(),
570-
self.kernel.T*(hf.lt_mask(self.kernel.T, 0.0)))
577+
kernel_T*(hf.lt_mask(kernel_T, 0.0)))
571578
+ tf.matmul(self.get_neg_mxts(),
572-
self.kernel.T*(hf.gt_mask(self.kernel.T, 0.0))))
579+
kernel_T*(hf.gt_mask(kernel_T, 0.0))))
573580
inp_mxts_increments += zero_inp_mask*(
574581
tf.matmul(0.5*(self.get_pos_mxts()
575-
+self.get_neg_mxts()), self.kernel.T))
582+
+self.get_neg_mxts()), kernel_T))
576583
#pos_mxts and neg_mxts in the input get the same multiplier
577584
#because the breakdown between pos and neg wasn't used to
578585
#compute pos_contribs and neg_contribs in the forward pass

deeplift/layers/normalization.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from collections import namedtuple
88
from collections import OrderedDict
99
from collections import defaultdict
10-
import deeplift.util
10+
from deeplift.util import to_tf_variable
1111
from .helper_functions import (
1212
pseudocount_near_zero, add_val_to_col)
1313
from . import helper_functions as hf
@@ -33,12 +33,12 @@ def __init__(self, gamma, beta, axis,
3333
#implementation, seems to support these only being one dimensional
3434
assert len(mean.shape)==1
3535
assert len(var.shape)==1
36-
self.gamma = gamma
37-
self.beta = beta
36+
self.gamma = to_tf_variable(gamma, self.get_name() + '_gamma')
37+
self.beta = to_tf_variable(beta, self.get_name() + '_beta')
3838
self.axis = axis
39-
self.mean = mean
40-
self.var = var
41-
self.epsilon = epsilon
39+
self.mean = to_tf_variable(mean, self.get_name() + '_mean')
40+
self.var = to_tf_variable(var, self.get_name() + '_var')
41+
self.epsilon = tf.constant(epsilon)
4242

4343
def _compute_shape(self, input_shape):
4444
return input_shape

deeplift/util.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,3 +426,10 @@ def in_place_shuffle(arr):
426426
arr[chosen_index] = arr[i]
427427
arr[i] = val_at_index
428428
return arr
429+
430+
431+
def to_tf_variable(np_array, name):
    """Wrap model parameters in a non-trainable TensorFlow variable.

    Parameters
    ----------
    np_array : numpy.ndarray or array-like
        Parameter values; lists/tuples are converted to an ndarray first
        (the original check handled only ``list`` exactly, missing tuples
        and other sequence types).
    name : str
        Name for the created variable. Must be unique within the enclosing
        variable scope (callers wrap construction in a ``tf.variable_scope``).

    Returns
    -------
    A ``tf.Variable`` initialized with the given values and marked
    ``trainable=False`` so graph optimizers leave the weights untouched.
    """
    # np.asarray is a no-op for existing ndarrays and covers any
    # array-like input, not just list.
    if not isinstance(np_array, np.ndarray):
        np_array = np.asarray(np_array)
    # dtype is taken from the array so float32 weights stay float32.
    return tf.get_variable(name, dtype=np_array.dtype,
                           initializer=np_array, trainable=False)

0 commit comments

Comments
 (0)