FastCell Example Fixes, Generalized trainer for both batch_first args #174
base: master
Changes from 6 commits: c0bafc3, 7247d39, e195805, 9427308, 87c1740, fe86230, 9b050b8
@@ -144,8 +144,8 @@ def getVars(self):

     def get_model_size(self):
         '''
-        Function to get aimed model size
-        '''
+        Function to get aimed model size
+        '''
         mats = self.getVars()
         endW = self._num_W_matrices
         endU = endW + self._num_U_matrices
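Editor's aside (not part of the diff): for context, a get_model_size body in these cells typically finishes by counting the values kept after sparsification and reporting a byte count. The loop below is only a hedged sketch of that pattern, assuming float32 storage and the _wSparsity/_uSparsity fractions used elsewhere in this file; it is not the hunk's actual continuation.

        # Hedged sketch, not the repository's code:
        totalnnz = 0
        for i, mat in enumerate(mats):
            n = mat.numel()
            if i < endW:
                totalnnz += int(n * self._wSparsity)   # W factors kept at wSparsity
            elif i < endU:
                totalnnz += int(n * self._uSparsity)   # U factors kept at uSparsity
            else:
                totalnnz += n                          # biases and gate scalars stay dense
        return totalnnz * 4                            # bytes, assuming float32 weights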
@@ -261,7 +261,7 @@ def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",

         self.zeta = nn.Parameter(self._zetaInit * torch.ones([1, 1]))
         self.nu = nn.Parameter(self._nuInit * torch.ones([1, 1]))

-        self.copy_previous_UW()
+        # self.copy_previous_UW()

Review comment: Why is this commented out? This supports sparsity for KWS codes. The fastcell_Example.py doesn't depend on it because this is a bit broken given the new pytorch updates. Check with Harsha.

Review comment: This line is causing a segmentation fault in the rnnpool codes.

Review comment: Yeah! I know the code is broken; it needs to be fixed. @harsha-simhadri used this in one of the codes, so we might have to fix it.
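Editor's aside (not part of the diff or the thread): per the discussion above, copy_previous_UW belongs to the sparse-training support; it snapshots the current W/U matrices so their sparsity support can be re-applied after optimizer updates. The snippet below is only a hedged sketch of that idea, not the repository's actual implementation (which is what needs fixing here).

    def copy_previous_UW(self):
        # Hedged sketch: keep detached copies of the W and U factors so a later
        # sparsification step can re-apply their support after weight updates.
        n_wu = self._num_W_matrices + self._num_U_matrices
        self.oldmats = [m.detach().clone() for m in self.getVars()[:n_wu]]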

     @property
     def name(self):
@@ -330,7 +330,7 @@ class FastGRNNCUDACell(RNNCell):

     '''
     def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
                  update_nonlinearity="tanh", wRank=None, uRank=None, zetaInit=1.0, nuInit=-4.0, wSparsity=1.0, uSparsity=1.0, name="FastGRNNCUDACell"):
-        super(FastGRNNCUDACell, self).__init__(input_size, hidden_size, gate_non_linearity, update_nonlinearity,
+        super(FastGRNNCUDACell, self).__init__(input_size, hidden_size, gate_nonlinearity, update_nonlinearity,
                                                1, 1, 2, wRank, uRank, wSparsity, uSparsity)
         if utils.findCUDA() is None:
             raise Exception('FastGRNNCUDA is supported only on GPU devices.')
@@ -967,78 +967,142 @@ class BaseRNN(nn.Module):

         [batchSize, timeSteps, inputDims]
     '''

-    def __init__(self, cell: RNNCell, batch_first=False):
+    def __init__(self, cell: RNNCell, batch_first=False, cell_reverse: RNNCell=None, bidirectional=False):

Review comment: What is this RNNCell=None? I don't know what it does. Why should there be an RNNCell for this argument?

Review comment: If the bidirectional pass doesn't share weights, this passes an extra initialized RNNCell for doing the backward pass.

Review comment: Should I resolve this conversation? @oindrilasaha
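Editor's aside (not part of the diff): a hedged usage sketch of the new constructor arguments, with made-up sizes, following the explanation in the thread above. With shared weights the single cell serves both directions; an unshared backward pass is supplied through cell_reverse. (Assumes import torch and the cell classes defined in this file.)

# Shared weights: the forward cell is reused for the reverse pass.
cell = FastGRNNCell(32, 64)
birnn_shared = BaseRNN(cell, batch_first=True, bidirectional=True)

# Unshared weights: a separately initialized cell handles the reverse pass.
cell_rev = FastGRNNCell(32, 64)
birnn_unshared = BaseRNN(cell, batch_first=True, cell_reverse=cell_rev, bidirectional=True)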
         super(BaseRNN, self).__init__()
-        self._RNNCell = cell
+        self.RNNCell = cell
         self._batch_first = batch_first
+        self._bidirectional = bidirectional
+        if cell_reverse is not None:
+            self.RNNCell_reverse = cell_reverse
+        elif self._bidirectional:
+            self.RNNCell_reverse = cell

     def getVars(self):
-        return self._RNNCell.getVars()
+        return self.RNNCell.getVars()

     def forward(self, input, hiddenState=None,
                 cellState=None):
         self.device = input.device
+        self.num_directions = 2 if self._bidirectional else 1
+        # hidden
+        # for i in range(num_directions):
         hiddenStates = torch.zeros(
             [input.shape[0], input.shape[1],
-             self._RNNCell.output_size]).to(self.device)
+             self.RNNCell.output_size]).to(self.device)

Review comment: All the lines after this are about bi-directional stuff. I would defer to @oindrilasaha to check all this stuff. She has to sign off on it as she uses them the most.

Review comment: @oindrilasaha, any comments on the later section of the code?
+        if self._bidirectional:
+            hiddenStates_reverse = torch.zeros(
+                [input.shape[0], input.shape[1],
+                 self.RNNCell_reverse.output_size]).to(self.device)

         if hiddenState is None:
             hiddenState = torch.zeros(
-                [input.shape[0] if self._batch_first else input.shape[1],
-                 self._RNNCell.output_size]).to(self.device)
+                [self.num_directions, input.shape[0] if self._batch_first else input.shape[1],
+                 self.RNNCell.output_size]).to(self.device)

         if self._batch_first is True:
-            if self._RNNCell.cellType == "LSTMLR":
+            if self.RNNCell.cellType == "LSTMLR":
                 cellStates = torch.zeros(
                     [input.shape[0], input.shape[1],
-                     self._RNNCell.output_size]).to(self.device)
+                     self.RNNCell.output_size]).to(self.device)
+                if self._bidirectional:
+                    cellStates_reverse = torch.zeros(
+                        [input.shape[0], input.shape[1],
+                         self.RNNCell_reverse.output_size]).to(self.device)
                 if cellState is None:
                     cellState = torch.zeros(
-                        [input.shape[0], self._RNNCell.output_size]).to(self.device)
+                        [self.num_directions, input.shape[0], self.RNNCell.output_size]).to(self.device)
                 for i in range(0, input.shape[1]):
-                    hiddenState, cellState = self._RNNCell(
-                        input[:, i, :], (hiddenState, cellState))
-                    hiddenStates[:, i, :] = hiddenState
-                    cellStates[:, i, :] = cellState
-                return hiddenStates, cellStates
+                    hiddenState[0], cellState[0] = self.RNNCell(
+                        input[:, i, :], (hiddenState[0].clone(), cellState[0].clone()))
+                    hiddenStates[:, i, :] = hiddenState[0]
+                    cellStates[:, i, :] = cellState[0]
+                    if self._bidirectional:
+                        hiddenState[1], cellState[1] = self.RNNCell_reverse(
+                            input[:, input.shape[1]-i-1, :], (hiddenState[1].clone(), cellState[1].clone()))
+                        hiddenStates_reverse[:, i, :] = hiddenState[1]
+                        cellStates_reverse[:, i, :] = cellState[1]
+                if not self._bidirectional:
+                    return hiddenStates, cellStates
+                else:
+                    return torch.cat([hiddenStates,hiddenStates_reverse],-1), torch.cat([cellStates,cellStates_reverse],-1)
             else:
                 for i in range(0, input.shape[1]):
-                    hiddenState = self._RNNCell(input[:, i, :], hiddenState)
-                    hiddenStates[:, i, :] = hiddenState
-                return hiddenStates
+                    hiddenState[0] = self.RNNCell(input[:, i, :], hiddenState[0].clone())
+                    hiddenStates[:, i, :] = hiddenState[0]
+                    if self._bidirectional:
+                        hiddenState[1] = self.RNNCell_reverse(
+                            input[:, input.shape[1]-i-1, :], hiddenState[1].clone())
+                        hiddenStates_reverse[:, i, :] = hiddenState[1]
+                if not self._bidirectional:
+                    return hiddenStates
+                else:
+                    return torch.cat([hiddenStates,hiddenStates_reverse],-1)
         else:
-            if self._RNNCell.cellType == "LSTMLR":
+            if self.RNNCell.cellType == "LSTMLR":
                 cellStates = torch.zeros(
                     [input.shape[0], input.shape[1],
-                     self._RNNCell.output_size]).to(self.device)
+                     self.RNNCell.output_size]).to(self.device)
+                if self._bidirectional:
+                    cellStates_reverse = torch.zeros(
+                        [input.shape[0], input.shape[1],
+                         self.RNNCell_reverse.output_size]).to(self.device)
                 if cellState is None:
                     cellState = torch.zeros(
-                        [input.shape[1], self._RNNCell.output_size]).to(self.device)
+                        [self.num_directions, input.shape[1], self.RNNCell.output_size]).to(self.device)
                 for i in range(0, input.shape[0]):
-                    hiddenState, cellState = self._RNNCell(
-                        input[i, :, :], (hiddenState, cellState))
-                    hiddenStates[i, :, :] = hiddenState
-                    cellStates[i, :, :] = cellState
-                return hiddenStates, cellStates
+                    hiddenState[0], cellState[0] = self.RNNCell(
+                        input[i, :, :], (hiddenState[0].clone(), cellState[0].clone()))
+                    hiddenStates[i, :, :] = hiddenState[0]
+                    cellStates[i, :, :] = cellState[0]
+                    if self._bidirectional:
+                        hiddenState[1], cellState[1] = self.RNNCell_reverse(
+                            input[input.shape[0]-i-1, :, :], (hiddenState[1].clone(), cellState[1].clone()))
+                        hiddenStates_reverse[i, :, :] = hiddenState[1]
+                        cellStates_reverse[i, :, :] = cellState[1]
+                if not self._bidirectional:
+                    return hiddenStates, cellStates
+                else:
+                    return torch.cat([hiddenStates,hiddenStates_reverse],-1), torch.cat([cellStates,cellStates_reverse],-1)
             else:
                 for i in range(0, input.shape[0]):
-                    hiddenState = self._RNNCell(input[i, :, :], hiddenState)
-                    hiddenStates[i, :, :] = hiddenState
-                return hiddenStates
+                    hiddenState[0] = self.RNNCell(input[i, :, :], hiddenState[0].clone())
+                    hiddenStates[i, :, :] = hiddenState[0]
+                    if self._bidirectional:
+                        hiddenState[1] = self.RNNCell_reverse(
+                            input[input.shape[0]-i-1, :, :], hiddenState[1].clone())
+                        hiddenStates_reverse[i, :, :] = hiddenState[1]
+                if not self._bidirectional:
+                    return hiddenStates
+                else:
+                    return torch.cat([hiddenStates,hiddenStates_reverse],-1)
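Editor's aside (not part of the diff): a hedged sketch of the shapes this forward pass produces, with made-up sizes. The forward direction consumes input[:, i, :], the reverse direction consumes input[:, T-i-1, :], and the per-direction state sequences are concatenated on the last axis. (Assumes import torch and the classes defined in this file.)

cell = FastGRNNCell(16, 32)
rnn = BaseRNN(cell, batch_first=True, bidirectional=True)   # one cell shared by both directions
x = torch.randn(8, 50, 16)                                  # [batch, time, features]
y = rnn(x)                                                  # [8, 50, 64]: forward and reverse
                                                            # state sequences concatenated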

 class LSTM(nn.Module):
     """Equivalent to nn.LSTM using LSTMLRCell"""

     def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
                  update_nonlinearity="tanh", wRank=None, uRank=None,
-                 wSparsity=1.0, uSparsity=1.0, batch_first=False):
+                 wSparsity=1.0, uSparsity=1.0, batch_first=False,
+                 bidirectional=False, is_shared_bidirectional=True):
         super(LSTM, self).__init__()
+        self._bidirectional = bidirectional
+        self._batch_first = batch_first
+        self._is_shared_bidirectional = is_shared_bidirectional
         self.cell = LSTMLRCell(input_size, hidden_size,
                                gate_nonlinearity=gate_nonlinearity,
                                update_nonlinearity=update_nonlinearity,
                                wRank=wRank, uRank=uRank,
                                wSparsity=wSparsity, uSparsity=uSparsity)
-        self.unrollRNN = BaseRNN(self.cell, batch_first=batch_first)
+        self.unrollRNN = BaseRNN(self.cell, batch_first=self._batch_first, bidirectional=self._bidirectional)
+
+        if self._bidirectional is True and self._is_shared_bidirectional is False:
+            self.cell_reverse = LSTMLRCell(input_size, hidden_size,
+                                           gate_nonlinearity=gate_nonlinearity,
+                                           update_nonlinearity=update_nonlinearity,
+                                           wRank=wRank, uRank=uRank,
+                                           wSparsity=wSparsity, uSparsity=uSparsity)
+            self.unrollRNN = BaseRNN(self.cell, self.cell_reverse, batch_first=self._batch_first, bidirectional=self._bidirectional)

     def forward(self, input, hiddenState=None, cellState=None):
         return self.unrollRNN(input, hiddenState, cellState)
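Editor's aside (not part of the diff): a hedged usage sketch of the extended LSTM wrapper, with made-up sizes. With the default is_shared_bidirectional=True, the same LSTMLRCell is reused for both directions, so the parameter count does not grow; is_shared_bidirectional=False builds a second cell for the reverse pass. (Assumes import torch and the classes defined in this file.)

lstm = LSTM(input_size=16, hidden_size=32, batch_first=True, bidirectional=True)
x = torch.randn(4, 25, 16)      # [batch, time, features]
hidden, cells = lstm(x)         # each of shape [4, 25, 64]: forward and reverse halves concatenated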
@@ -1049,14 +1113,26 @@ class GRU(nn.Module):

     def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
                  update_nonlinearity="tanh", wRank=None, uRank=None,
-                 wSparsity=1.0, uSparsity=1.0, batch_first=False):
+                 wSparsity=1.0, uSparsity=1.0, batch_first=False,
+                 bidirectional=False, is_shared_bidirectional=True):
         super(GRU, self).__init__()
+        self._bidirectional = bidirectional
+        self._batch_first = batch_first
+        self._is_shared_bidirectional = is_shared_bidirectional
         self.cell = GRULRCell(input_size, hidden_size,
                               gate_nonlinearity=gate_nonlinearity,
                               update_nonlinearity=update_nonlinearity,
                               wRank=wRank, uRank=uRank,
                               wSparsity=wSparsity, uSparsity=uSparsity)
-        self.unrollRNN = BaseRNN(self.cell, batch_first=batch_first)
+        self.unrollRNN = BaseRNN(self.cell, batch_first=self._batch_first, bidirectional=self._bidirectional)
+
+        if self._bidirectional is True and self._is_shared_bidirectional is False:
+            self.cell_reverse = GRULRCell(input_size, hidden_size,
+                                          gate_nonlinearity=gate_nonlinearity,
+                                          update_nonlinearity=update_nonlinearity,
+                                          wRank=wRank, uRank=uRank,
+                                          wSparsity=wSparsity, uSparsity=uSparsity)
+            self.unrollRNN = BaseRNN(self.cell, self.cell_reverse, batch_first=self._batch_first, bidirectional=self._bidirectional)

     def forward(self, input, hiddenState=None, cellState=None):
         return self.unrollRNN(input, hiddenState, cellState)
@@ -1067,14 +1143,26 @@ class UGRNN(nn.Module):

     def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
                  update_nonlinearity="tanh", wRank=None, uRank=None,
-                 wSparsity=1.0, uSparsity=1.0, batch_first=False):
+                 wSparsity=1.0, uSparsity=1.0, batch_first=False,
+                 bidirectional=False, is_shared_bidirectional=True):
         super(UGRNN, self).__init__()
+        self._bidirectional = bidirectional
+        self._batch_first = batch_first
+        self._is_shared_bidirectional = is_shared_bidirectional
         self.cell = UGRNNLRCell(input_size, hidden_size,
                                 gate_nonlinearity=gate_nonlinearity,
                                 update_nonlinearity=update_nonlinearity,
                                 wRank=wRank, uRank=uRank,
                                 wSparsity=wSparsity, uSparsity=uSparsity)
-        self.unrollRNN = BaseRNN(self.cell, batch_first=batch_first)
+        self.unrollRNN = BaseRNN(self.cell, batch_first=self._batch_first, bidirectional=self._bidirectional)
+
+        if self._bidirectional is True and self._is_shared_bidirectional is False:
+            self.cell_reverse = UGRNNLRCell(input_size, hidden_size,
+                                            gate_nonlinearity=gate_nonlinearity,
+                                            update_nonlinearity=update_nonlinearity,
+                                            wRank=wRank, uRank=uRank,
+                                            wSparsity=wSparsity, uSparsity=uSparsity)
+            self.unrollRNN = BaseRNN(self.cell, self.cell_reverse, batch_first=self._batch_first, bidirectional=self._bidirectional)

     def forward(self, input, hiddenState=None, cellState=None):
         return self.unrollRNN(input, hiddenState, cellState)
@@ -1085,15 +1173,28 @@ class FastRNN(nn.Module):

     def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
                  update_nonlinearity="tanh", wRank=None, uRank=None,
-                 wSparsity=1.0, uSparsity=1.0, alphaInit=-3.0, betaInit=3.0, batch_first=False):
+                 wSparsity=1.0, uSparsity=1.0, alphaInit=-3.0, betaInit=3.0,
+                 batch_first=False, bidirectional=False, is_shared_bidirectional=True):
         super(FastRNN, self).__init__()
+        self._bidirectional = bidirectional
+        self._batch_first = batch_first
+        self._is_shared_bidirectional = is_shared_bidirectional
         self.cell = FastRNNCell(input_size, hidden_size,
                                 gate_nonlinearity=gate_nonlinearity,
                                 update_nonlinearity=update_nonlinearity,
                                 wRank=wRank, uRank=uRank,
                                 wSparsity=wSparsity, uSparsity=uSparsity,
                                 alphaInit=alphaInit, betaInit=betaInit)
-        self.unrollRNN = BaseRNN(self.cell, batch_first=batch_first)
+        self.unrollRNN = BaseRNN(self.cell, batch_first=self._batch_first, bidirectional=self._bidirectional)
+
+        if self._bidirectional is True and self._is_shared_bidirectional is False:
+            self.cell_reverse = FastRNNCell(input_size, hidden_size,
+                                            gate_nonlinearity=gate_nonlinearity,
+                                            update_nonlinearity=update_nonlinearity,
+                                            wRank=wRank, uRank=uRank,
+                                            wSparsity=wSparsity, uSparsity=uSparsity,
+                                            alphaInit=alphaInit, betaInit=betaInit)
+            self.unrollRNN = BaseRNN(self.cell, self.cell_reverse, batch_first=self._batch_first, bidirectional=self._bidirectional)

     def forward(self, input, hiddenState=None, cellState=None):
         return self.unrollRNN(input, hiddenState, cellState)
@@ -1105,15 +1206,27 @@ class FastGRNN(nn.Module):

     def __init__(self, input_size, hidden_size, gate_nonlinearity="sigmoid",
                  update_nonlinearity="tanh", wRank=None, uRank=None,
                  wSparsity=1.0, uSparsity=1.0, zetaInit=1.0, nuInit=-4.0,
-                 batch_first=False):
+                 batch_first=False, bidirectional=False, is_shared_bidirectional=True):
         super(FastGRNN, self).__init__()
+        self._bidirectional = bidirectional
+        self._batch_first = batch_first
+        self._is_shared_bidirectional = is_shared_bidirectional
         self.cell = FastGRNNCell(input_size, hidden_size,
                                  gate_nonlinearity=gate_nonlinearity,
                                  update_nonlinearity=update_nonlinearity,
                                  wRank=wRank, uRank=uRank,
                                  wSparsity=wSparsity, uSparsity=uSparsity,
                                  zetaInit=zetaInit, nuInit=nuInit)
-        self.unrollRNN = BaseRNN(self.cell, batch_first=batch_first)
+        self.unrollRNN = BaseRNN(self.cell, batch_first=self._batch_first, bidirectional=self._bidirectional)
+
+        if self._bidirectional is True and self._is_shared_bidirectional is False:
+            self.cell_reverse = FastGRNNCell(input_size, hidden_size,
+                                             gate_nonlinearity=gate_nonlinearity,
+                                             update_nonlinearity=update_nonlinearity,
+                                             wRank=wRank, uRank=uRank,
+                                             wSparsity=wSparsity, uSparsity=uSparsity,
+                                             zetaInit=zetaInit, nuInit=nuInit)
+            self.unrollRNN = BaseRNN(self.cell, self.cell_reverse, batch_first=self._batch_first, bidirectional=self._bidirectional)

     def getVars(self):
         return self.unrollRNN.getVars()
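Editor's aside (not part of the diff): tying back to the PR title, the same wrapper now covers both batch_first layouts, so a trainer only needs to supply the matching input shape. A hedged sketch with made-up sizes (assumes import torch and the classes defined in this file):

m_bf = FastGRNN(input_size=16, hidden_size=32, batch_first=True)
y_bf = m_bf(torch.randn(8, 50, 16))     # [batch, time, features] in, [8, 50, 32] out

m_tf = FastGRNN(input_size=16, hidden_size=32, batch_first=False)
y_tf = m_tf(torch.randn(50, 8, 16))     # [time, batch, features] in, [50, 8, 32] out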
@@ -1222,8 +1335,8 @@ def getVars(self):

     def get_model_size(self):
         '''
-        Function to get aimed model size
-        '''
+        Function to get aimed model size
+        '''
         mats = self.getVars()
         endW = self._num_W_matrices
         endU = endW + self._num_U_matrices