From 337b0b6a3b95064c3aab36b66264e29d5ab9a58d Mon Sep 17 00:00:00 2001 From: Rahul Mohan Date: Tue, 8 Dec 2020 17:38:45 -0500 Subject: [PATCH 1/5] Added attention consensus model --- variantworks/attention.py | 76 +++++++++++++++++++++++++++++++++++++++ variantworks/networks.py | 66 ++++++++++++++++++++++++++++++++++ 2 files changed, 142 insertions(+) create mode 100644 variantworks/attention.py diff --git a/variantworks/attention.py b/variantworks/attention.py new file mode 100644 index 0000000..5e70ce1 --- /dev/null +++ b/variantworks/attention.py @@ -0,0 +1,76 @@ +import torch +import torch.nn as nn + + +class Attention(nn.Module): + """ Applies attention mechanism on the `context` using the `query`. + + Implementation from: https://pytorchnlp.readthedocs.io/en/latest/_modules/torchnlp/nn/attention.html + + Args: + dimensions (int): Dimensionality of the query and context. + attention_type (str, optional): How to compute the attention score: + + * dot: :math:`score(H_j,q) = H_j^T q` + * general: :math:`score(H_j, q) = H_j^T W_a q` + """ + def __init__(self, dimensions, attention_type='general'): + super(Attention, self).__init__() + + if attention_type not in ['dot', 'general']: + raise ValueError('Invalid attention type selected.') + + self.attention_type = attention_type + if self.attention_type == 'general': + self.linear_in = nn.Linear(dimensions, dimensions, bias=False) + + self.linear_out = nn.Linear(dimensions * 2, dimensions, bias=False) + self.softmax = nn.Softmax(dim=-1) + self.tanh = nn.Tanh() + + def forward(self, query, context): + """ + Args: + query (:class:`torch.FloatTensor` [batch size, output length, dimensions]): Sequence of + queries to query the context. + context (:class:`torch.FloatTensor` [batch size, query length, dimensions]): Data + overwhich to apply the attention mechanism. + + Returns: + :class:`tuple` with `output` and `weights`: + * **output** (:class:`torch.LongTensor` [batch size, output length, dimensions]): + Tensor containing the attended features. + * **weights** (:class:`torch.FloatTensor` [batch size, output length, query length]): + Tensor containing attention weights. 
+        """
+        batch_size, output_len, dimensions = query.size()
+        query_len = context.size(1)
+
+        if self.attention_type == "general":
+            query = query.reshape(batch_size * output_len, dimensions)
+            query = self.linear_in(query)
+            query = query.reshape(batch_size, output_len, dimensions)
+
+        # (batch_size, output_len, dimensions) * (batch_size, query_len, dimensions) ->
+        # (batch_size, output_len, query_len)
+        attention_scores = torch.bmm(query, context.transpose(1, 2).contiguous())
+
+        # Compute weights across every context sequence
+        attention_scores = attention_scores.view(batch_size * output_len, query_len)
+        attention_weights = self.softmax(attention_scores)
+        attention_weights = attention_weights.view(batch_size, output_len, query_len)
+
+        # (batch_size, output_len, query_len) * (batch_size, query_len, dimensions) ->
+        # (batch_size, output_len, dimensions)
+        mix = torch.bmm(attention_weights, context)
+
+        # concat -> (batch_size * output_len, 2*dimensions)
+        combined = torch.cat((mix, query), dim=2)
+        combined = combined.view(batch_size * output_len, 2 * dimensions)
+
+        # Apply linear_out on every 2nd dimension of concat
+        # output -> (batch_size, output_len, dimensions)
+        output = self.linear_out(combined).view(batch_size, output_len, dimensions)
+        output = self.tanh(output)
+
+        return output, attention_weights
diff --git a/variantworks/networks.py b/variantworks/networks.py
index f8af0e7..fa340aa 100644
--- a/variantworks/networks.py
+++ b/variantworks/networks.py
@@ -22,6 +22,7 @@
 from nemo.utils.decorators import add_port_docs
 from nemo.core.neural_types import NeuralType, ChannelType, LogitsType
 from nemo.core.neural_factory import DeviceType
+from variantworks.attention import Attention
 
 
 class AlexNet(TrainableNM):
@@ -175,3 +176,68 @@ def forward(self, encoding):
         encoding, h_n = self.gru(encoding)
         encoding = self.classifier(encoding)
         return encoding
+
+
+class ConsensusAttention(TrainableNM):
+    """A Neural Module for training a Consensus Attention Model."""
+
+    @property
+    @add_port_docs()
+    def input_ports(self):
+        """Return definitions of module input ports.
+
+        Returns:
+            Module input ports.
+        """
+        return {
+            "encoding": NeuralType(('B', 'W', 'C'), ChannelType()),
+        }
+
+    @property
+    @add_port_docs()
+    def output_ports(self):
+        """Return definitions of module output ports.
+
+        Returns:
+            Module output ports.
+        """
+        return {
+            # Variant type
+            'output_logit': NeuralType(('B', 'W', 'D'), LogitsType()),
+        }
+
+    def __init__(self, sequence_length, input_feature_size, num_output_logits):
+        """Construct a Consensus Attention NeMo instance.
+
+        Args:
+            sequence_length : Length of sequence to feed into RNN.
+            input_feature_size : Length of input feature set.
+            num_output_logits : Number of output classes of classifier.
+
+        Returns:
+            Instance of class.
+        """
+        super().__init__()
+        self.num_output_logits = num_output_logits
+
+        self.attn = Attention(input_feature_size)
+        self.gru = nn.GRU(input_feature_size, 16, 1, batch_first=True, bidirectional=True)
+        self.classifier = nn.Linear(32, self.num_output_logits)
+
+        self._device = torch.device(
+            "cuda" if self.placement == DeviceType.GPU else "cpu")
+        self.to(self._device)
+
+    def forward(self, encoding):
+        """Abstract function to run the network.
+
+        Args:
+            encoding : Input sequence to run network on.
+
+        Returns:
+            Output of forward pass.
+ """ + encoding, weights = self.attn(encoding, encoding) + encoding, h_n = self.gru(encoding) + encoding = self.classifier(encoding) + return encoding From c16afb02c4fa2468f4252c8cbd60decf7eb49afc Mon Sep 17 00:00:00 2001 From: Joyjit Daw Date: Thu, 10 Dec 2020 11:22:28 -0500 Subject: [PATCH 2/5] [layers] move attention module to layers --- variantworks/{ => layers}/attention.py | 16 ++++++++++++++++ variantworks/networks.py | 3 ++- 2 files changed, 18 insertions(+), 1 deletion(-) rename variantworks/{ => layers}/attention.py (84%) diff --git a/variantworks/attention.py b/variantworks/layers/attention.py similarity index 84% rename from variantworks/attention.py rename to variantworks/layers/attention.py index 5e70ce1..ef2e1da 100644 --- a/variantworks/attention.py +++ b/variantworks/layers/attention.py @@ -1,3 +1,19 @@ +# +# Copyright 2020 NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + import torch import torch.nn as nn diff --git a/variantworks/networks.py b/variantworks/networks.py index fa340aa..242c72f 100644 --- a/variantworks/networks.py +++ b/variantworks/networks.py @@ -22,7 +22,8 @@ from nemo.utils.decorators import add_port_docs from nemo.core.neural_types import NeuralType, ChannelType, LogitsType from nemo.core.neural_factory import DeviceType -from variantworks.attention import Attention + +from variantworks.layers.attention import Attention class AlexNet(TrainableNM): From 384b2e8188acfa9c2be399999dca8693bacfdca9 Mon Sep 17 00:00:00 2001 From: Joyjit Daw Date: Thu, 10 Dec 2020 16:47:47 -0500 Subject: [PATCH 3/5] [networks] refactor and clean up comments --- variantworks/layers/attention.py | 39 +++++++++++++++++--------------- variantworks/networks.py | 4 ++-- 2 files changed, 23 insertions(+), 20 deletions(-) diff --git a/variantworks/layers/attention.py b/variantworks/layers/attention.py index ef2e1da..f372fe4 100644 --- a/variantworks/layers/attention.py +++ b/variantworks/layers/attention.py @@ -13,24 +13,28 @@ # See the License for the specific language governing permissions and # limitations under the License. # +"""Attention related layers.""" import torch import torch.nn as nn class Attention(nn.Module): - """ Applies attention mechanism on the `context` using the `query`. + """Applies attention mechanism on the `context` using the `query`. Implementation from: https://pytorchnlp.readthedocs.io/en/latest/_modules/torchnlp/nn/attention.html - - Args: - dimensions (int): Dimensionality of the query and context. - attention_type (str, optional): How to compute the attention score: - - * dot: :math:`score(H_j,q) = H_j^T q` - * general: :math:`score(H_j, q) = H_j^T W_a q` """ + def __init__(self, dimensions, attention_type='general'): + """Construct an Attention layer. + + Args: + dimensions (int): Dimensionality of the query and context. 
+ attention_type (str, optional): How to compute the attention score: + + * dot: :math:`score(H_j,q) = H_j^T q` + * general: :math:`score(H_j, q) = H_j^T W_a q` + """ super(Attention, self).__init__() if attention_type not in ['dot', 'general']: @@ -45,19 +49,18 @@ def __init__(self, dimensions, attention_type='general'): self.tanh = nn.Tanh() def forward(self, query, context): - """ + """Forward method. + Args: - query (:class:`torch.FloatTensor` [batch size, output length, dimensions]): Sequence of - queries to query the context. - context (:class:`torch.FloatTensor` [batch size, query length, dimensions]): Data - overwhich to apply the attention mechanism. + query : Sequence of queries to query the \ + context [batch size, output length, dimensions]. + context : Data over which to apply the attention \ + mechanism [batch size, query length, dimensions]. Returns: - :class:`tuple` with `output` and `weights`: - * **output** (:class:`torch.LongTensor` [batch size, output length, dimensions]): - Tensor containing the attended features. - * **weights** (:class:`torch.FloatTensor` [batch size, output length, query length]): - Tensor containing attention weights. + Tuple with output and weights: + * output : Tensor containing the attended features [batch size, output length, dimensions]. + * weights : Tensor containing attention weights [batch size, output length, query length]. """ batch_size, output_len, dimensions = query.size() query_len = context.size(1) diff --git a/variantworks/networks.py b/variantworks/networks.py index 242c72f..deb7ee0 100644 --- a/variantworks/networks.py +++ b/variantworks/networks.py @@ -238,7 +238,7 @@ def forward(self, encoding): Returns: Output of forward pass. """ - encoding, weights = self.attn(encoding, encoding) - encoding, h_n = self.gru(encoding) + encoding, _ = self.attn(encoding, encoding) + encoding, _ = self.gru(encoding) encoding = self.classifier(encoding) return encoding From a05ac722504eb390d118d05deb3596c064193a9c Mon Sep 17 00:00:00 2001 From: Joyjit Daw Date: Mon, 14 Dec 2020 13:43:49 -0500 Subject: [PATCH 4/5] [layers] add bsd 3-clause license from original source code repo --- variantworks/layers/attention.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/variantworks/layers/attention.py b/variantworks/layers/attention.py index f372fe4..3efc523 100644 --- a/variantworks/layers/attention.py +++ b/variantworks/layers/attention.py @@ -13,6 +13,37 @@ # See the License for the specific language governing permissions and # limitations under the License. # +# +# The implementation in this file is adopted from a 3rd party repository with BSD 3-Clause License. +# BSD 3-Clause License +# +# Copyright (c) James Bradbury and Soumith Chintala 2016, +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """Attention related layers.""" import torch From 6f72bc08924ba59c5480d1bda3e73d789cc50de7 Mon Sep 17 00:00:00 2001 From: Joyjit Daw Date: Mon, 14 Dec 2020 13:52:16 -0500 Subject: [PATCH 5/5] [tests] add test for attention layer --- tests/test_nn_layers.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 tests/test_nn_layers.py diff --git a/tests/test_nn_layers.py b/tests/test_nn_layers.py new file mode 100644 index 0000000..e7f3620 --- /dev/null +++ b/tests/test_nn_layers.py @@ -0,0 +1,26 @@ +# +# Copyright 2020 NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import torch + +from variantworks.layers.attention import Attention + + +def test_attention_layer(): + input_tensor = torch.zeros((10, 10, 5), dtype=torch.float32) + attn_layer = Attention(5) + out, _ = attn_layer(input_tensor, input_tensor) + assert(torch.all(input_tensor.eq(out)))
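
Usage sketch (illustrative only, not part of the patches above): the snippet below drives the new Attention layer in the same self-attention configuration that ConsensusAttention.forward uses, and checks the shapes documented in its docstring. The batch size, window length, and feature size are arbitrary assumptions; the import path reflects the module's final location after PATCH 2/5.

import torch

from variantworks.layers.attention import Attention

batch_size, window_len, features = 4, 100, 5

attn = Attention(features, attention_type='general')
encoding = torch.rand(batch_size, window_len, features)

# Self-attention: the same tensor serves as both query and context,
# mirroring `self.attn(encoding, encoding)` in ConsensusAttention.forward.
output, weights = attn(encoding, encoding)

assert output.shape == (batch_size, window_len, features)
assert weights.shape == (batch_size, window_len, window_len)
# Attention weights are softmax-normalized across the context positions.
assert torch.allclose(weights.sum(dim=-1), torch.ones(batch_size, window_len))

In ConsensusAttention itself, the attended encoding is then passed through a single-layer bidirectional GRU with hidden size 16 (32 features after concatenating both directions) and a linear classifier that emits num_output_logits scores per position.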