"""
Recurrent Neural Network implementation for use with structured linear maps.
"""
# machine learning/data science imports
import torch
import torch.nn as nn
import torch.nn.functional as F
# ecosystem imports
import slim
class RNNCell(nn.Module):
def __init__(self, input_size, hidden_size, bias=False, nonlin=F.gelu,
hidden_map=slim.Linear, input_map=slim.Linear, input_args=dict(),
hidden_args=dict()):
"""
:param input_size: (int) Dimension of input to rnn cell.
:param hidden_size: (int) Dimension of output of rnn cell.
:param bias: (bool) Whether to use bias.
        :param nonlin: (callable) Activation function.
        :param hidden_map: (nn.Module) Module constructor compatible with torch.nn.Linear for the hidden-to-hidden map.
        :param input_map: (nn.Module) Module constructor compatible with torch.nn.Linear for the input-to-hidden map.
        :param input_args: (dict) Keyword arguments used to instantiate the input map.
        :param hidden_args: (dict) Keyword arguments used to instantiate the hidden map.
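
        A minimal construction example (using the same slim maps as the :func:`forward` doctest below):

        .. doctest::

            >>> import slim
            >>> cell = slim.RNNCell(5, 8, hidden_map=slim.PerronFrobeniusLinear)
            >>> cell.in_features, cell.out_features
            (5, 8)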
"""
super().__init__()
self.input_size, self.hidden_size = input_size, hidden_size
self.in_features, self.out_features = input_size, hidden_size
self.nonlin = nonlin
self.lin_in = input_map(input_size, hidden_size, bias=bias, **input_args)
self.lin_hidden = hidden_map(hidden_size, hidden_size, bias=bias, **hidden_args)
    def reg_error(self):
"""
:return: (torch.float) Regularization error associated with linear maps.
"""
return (self.lin_in.reg_error() + self.lin_hidden.reg_error())/2.0
    def forward(self, input, hidden):
"""
:param input: (torch.Tensor, shape=[batch_size, input_size]) Input to cell
:param hidden: (torch.Tensor, shape=[batch_size, hidden_size]) Hidden state (typically previous output of cell)
        :return: (torch.Tensor, shape=[batch_size, hidden_size]) Cell output
.. doctest::
>>> import slim, torch
>>> cell = slim.RNNCell(5, 8, input_map=slim.Linear, hidden_map=slim.PerronFrobeniusLinear)
>>> x, h = torch.rand(20, 5), torch.rand(20, 8)
>>> output = cell(x, h)
>>> output.shape
torch.Size([20, 8])
"""
return self.nonlin(self.lin_hidden(hidden) + self.lin_in(input))
class RNN(nn.Module):
def __init__(self, input_size, hidden_size=16, num_layers=1, cell_args=dict()):
"""
        Stacked RNN with inputs and outputs matching the basic usage of the torch.nn.RNN module.
        The bidirectional, bias, nonlinearity, batch_first, and dropout arguments of torch.nn.RNN
        are not exposed; bias and nonlinearity are instead passed through cell_args.
        Cells can incorporate custom structured linear maps.
:param input_size: (int) Dimension of inputs
:param hidden_size: (int) Dimension of hidden states
:param num_layers: (int) Number of recurrent layers. E.g., setting num_layers=2 would mean stacking two RNNs together
to form a stacked RNN, with the second RNN taking in outputs of the first RNN and
computing the final results. Default: 1
:param cell_args: (dict) Arguments to instantiate RNN cells (see :class:`rnn.RNNCell` for args).
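
        A minimal construction sketch with default cells; the cell and initial-state lists reflect num_layers:

        .. doctest::

            >>> import slim
            >>> rnn = slim.RNN(5, hidden_size=8, num_layers=2)
            >>> len(rnn.rnn_cells), len(rnn.init_states)
            (2, 2)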
"""
super().__init__()
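        # first cell maps input_size -> hidden_size; any further cells in the stack map hidden_size -> hidden_size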
rnn_cells = [RNNCell(input_size, hidden_size, **cell_args)]
rnn_cells += [RNNCell(hidden_size, hidden_size, **cell_args)
for k in range(num_layers-1)]
self.rnn_cells = nn.ModuleList(rnn_cells)
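        # one trainable initial hidden state per layer, broadcast over the batch dimension in forward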
self.init_states = nn.ParameterList([nn.Parameter(torch.zeros(1, cell.hidden_size))
for cell in self.rnn_cells])
    def reg_error(self):
"""
:return: (torch.float) Regularization error associated with linear maps.
"""
return torch.mean(torch.stack([cell.reg_error() for cell in self.rnn_cells]))
    def forward(self, sequence, init_states=None):
"""
:param sequence: (torch.Tensor, shape=[seq_len, batch, input_size]) Input sequence to RNN
        :param init_states: (torch.Tensor, shape=[num_layers, batch, hidden_size]) :math:`h_0`, initial hidden states for the stacked RNNCells; if None, the module's trainable initial states are used.
:returns:
- output: (seq_len, batch, hidden_size) Sequence of outputs
- :math:`h_n`: (num_layers, batch, hidden_size) Final hidden states for stack of RNN cells.
.. doctest::
>>> import slim, torch
>>> rnn = slim.RNN(5, hidden_size=8, num_layers=3, cell_args={'hidden_map': slim.PerronFrobeniusLinear})
>>> x = torch.rand(20, 10, 5)
>>> output, h_n = rnn(x)
>>> output.shape, h_n.shape
(torch.Size([20, 10, 8]), torch.Size([3, 10, 8]))
"""
        assert len(sequence.shape) == 3, f'RNN takes an order-3 tensor with shape=(seq_len, nsamples, {self.rnn_cells[0].input_size})'
if init_states is None:
init_states = self.init_states
final_hiddens = []
for h, cell in zip(init_states, self.rnn_cells):
# loop over stack of cells
states = []
for seq_idx, cell_input in enumerate(sequence):
# loop over sequence
h = cell(cell_input, h)
states.append(h)
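            # the stacked outputs of this layer become the input sequence for the next layer in the stack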
sequence = torch.stack(states)
            final_hiddens.append(h)  # save each layer's final hidden state in case truncated backprop is needed
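        # sanity check: the last output of the top layer equals that layer's final hidden state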
assert torch.equal(sequence[-1, :, :], final_hiddens[-1])
return sequence, torch.stack(final_hiddens)
if __name__ == '__main__':
x = torch.rand(20, 5, 8)
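    # smoke test: exercise every registered linear map as the hidden map, with and without bias, for 1- and 2-layer stacks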
for bias in [True, False]:
for num_layers in [1, 2]:
for name, map in slim.maps.items():
print(name)
print(map)
rnn = RNN(8, hidden_size=8, num_layers=num_layers, cell_args={'bias': bias,
'nonlin': F.gelu,
'hidden_map': map,
'input_map': slim.Linear})
out = rnn(x)
print(out[0].shape, out[1].shape)
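            # maps without a square-shape restriction can also be used when hidden_size != input_size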
for map in set(slim.maps.values()) - slim.square_maps:
                print(map)
rnn = RNN(8, hidden_size=16, num_layers=num_layers, cell_args={'bias': bias,
'nonlin': F.gelu,
'hidden_map': map,
'input_map': slim.Linear})
out = rnn(x)
print(out[0].shape, out[1].shape)