diff --git a/combo/combo_model.py b/combo/combo_model.py
index b87eb1c0719fdd1e8fc4e1e4695559deb81644cb..d87d514adeddb223bec672e43d664d347f97cf82 100644
--- a/combo/combo_model.py
+++ b/combo/combo_model.py
@@ -1,5 +1,5 @@
 """Main COMBO model."""
-from typing import Optional, Dict, Any, List
+from typing import Optional, Dict, Any, List, Union
 
 import numpy
 import torch
@@ -24,6 +24,7 @@ from combo.nn.utils import get_text_field_mask
 from combo.predictors import Predictor
 from combo.utils import metrics
 from combo.utils import ConfigurationError
+from modules.seq2seq_encoders.transformer_encoder import TransformerEncoder
 
 
 @Registry.register("semantic_multitask")
@@ -39,7 +40,7 @@ class ComboModel(Model, FromParameters):
                  vocabulary: data.Vocabulary,
                  loss_weights: Dict[str, float],
                  text_field_embedder: TextFieldEmbedder,
-                 seq_encoder: Seq2SeqEncoder,
+                 seq_encoder: Union[Seq2SeqEncoder, TransformerEncoder],
                  use_sample_weight: bool = True,
                  lemmatizer: LemmatizerModel = None,
                  upos_tagger: MorphologicalFeatures = None,
diff --git a/combo/config/from_parameters.py b/combo/config/from_parameters.py
index 743c80ac18587ed741651388e5b0dff6815d7be6..8d624073f83449b68afdc078a65f1e9934b27b7c 100644
--- a/combo/config/from_parameters.py
+++ b/combo/config/from_parameters.py
@@ -58,6 +58,8 @@ def serialize_single_value(value: Any, pass_down_parameter_names: List[str] = No
         return {k: serialize_single_value(v, pass_down_parameter_names) for k, v in value.items()}
     elif isinstance(value, int) or isinstance(value, float) or isinstance(value, str):
         return value
+    elif value is None:
+        return None
     else:
         return str(value)
 
diff --git a/combo/default_model.py b/combo/default_model.py
index 4ec705d0f0e28b68fe28c89f9c216b67c3561197..61be958707c01d1e79fe39c77b2bbc64af99310f 100644
--- a/combo/default_model.py
+++ b/combo/default_model.py
@@ -12,7 +12,7 @@ from combo.data.token_indexers import TokenConstPaddingCharactersIndexer, \
 from combo.data.tokenizers import CharacterTokenizer
 from combo.data.vocabulary import Vocabulary
 from combo.combo_model import ComboModel
-from combo.models.encoder import ComboEncoder, ComboStackedBidirectionalLSTM
+from combo.models.encoder import ComboEncoder, ComboStackedBidirectionalLSTM, ComboTransformerEncoder
 from combo.models.dilated_cnn import DilatedCnnEncoder
 from combo.modules.lemma import LemmatizerModel
 from combo.modules.morpho import MorphologicalFeatures
@@ -122,124 +122,247 @@ def default_vocabulary(data_loader: DataLoader) -> Vocabulary:
     )
 
 
-def default_model(pretrained_transformer_name: str, vocabulary: Vocabulary) -> ComboModel:
-    return ComboModel(
-        vocabulary=vocabulary,
-        dependency_relation=DependencyRelationModel(
+def default_model(pretrained_transformer_name: str, vocabulary: Vocabulary, use_transformer_encoder=False) -> ComboModel:
+    if use_transformer_encoder:
+        return ComboModel(
             vocabulary=vocabulary,
-            dependency_projection_layer=Linear(
-                activation=TanhActivation(),
-                dropout_rate=0.25,
-                in_features=1024,
-                out_features=128
-            ),
-            head_predictor=HeadPredictionModel(
-                cycle_loss_n=0,
+            dependency_relation=DependencyRelationModel(
+                vocabulary=vocabulary,
                 dependency_projection_layer=Linear(
                     activation=TanhActivation(),
+                    dropout_rate=0.25,
                     in_features=1024,
-                    out_features=512
+                    out_features=128
+                ),
+                head_predictor=HeadPredictionModel(
+                    cycle_loss_n=0,
+                    dependency_projection_layer=Linear(
+                        activation=TanhActivation(),
+                        in_features=1024,
+                        out_features=512
+                    ),
+                    head_projection_layer=Linear(
+                        activation=TanhActivation(),
+                        in_features=1024,
+                        out_features=512
+                    )
                 ),
                 head_projection_layer=Linear(
                     activation=TanhActivation(),
+                    dropout_rate=0.25,
                     in_features=1024,
-                    out_features=512
-                )
+                    out_features=128
+                ),
+                vocab_namespace="deprel_labels"
             ),
-            head_projection_layer=Linear(
-                activation=TanhActivation(),
-                dropout_rate=0.25,
-                in_features=1024,
-                out_features=128
+            lemmatizer=LemmatizerModel(
+                vocabulary=vocabulary,
+                activations=[GELUActivation(), GELUActivation(), GELUActivation(), LinearActivation()],
+                char_vocab_namespace="token_characters",
+                dilation=[1, 2, 4, 1],
+                embedding_dim=300,
+                filters=[256, 256, 256],
+                input_projection_layer=Linear(
+                    activation=TanhActivation(),
+                    dropout_rate=0.25,
+                    in_features=1024,
+                    out_features=32
+                ),
+                kernel_size=[3, 3, 3, 1],
+                lemma_vocab_namespace="lemma_characters",
+                padding=[1, 2, 4, 0],
+                stride=[1, 1, 1, 1]
             ),
-            vocab_namespace="deprel_labels"
-        ),
-        lemmatizer=LemmatizerModel(
-            vocabulary=vocabulary,
-            activations=[GELUActivation(), GELUActivation(), GELUActivation(), LinearActivation()],
-            char_vocab_namespace="token_characters",
-            dilation=[1, 2, 4, 1],
-            embedding_dim=300,
-            filters=[256, 256, 256],
-            input_projection_layer=Linear(
-                activation=TanhActivation(),
-                dropout_rate=0.25,
-                in_features=1024,
-                out_features=32
+            loss_weights={
+                "deprel": 0.8,
+                "feats": 0.2,
+                "head": 0.2,
+                "lemma": 0.05,
+                "semrel": 0.05,
+                "upostag": 0.05,
+                "xpostag": 0.05
+            },
+            morphological_feat=MorphologicalFeatures(
+                vocabulary=vocabulary,
+                activations=[TanhActivation(), LinearActivation()],
+                dropout=[0.25, 0.],
+                hidden_dims=[128],
+                input_dim=1024,
+                num_layers=2,
+                vocab_namespace="feats_labels"
             ),
-            kernel_size=[3, 3, 3, 1],
-            lemma_vocab_namespace="lemma_characters",
-            padding=[1, 2, 4, 0],
-            stride=[1, 1, 1, 1]
-        ),
-        loss_weights={
-            "deprel": 0.8,
-            "feats": 0.2,
-            "head": 0.2,
-            "lemma": 0.05,
-            "semrel": 0.05,
-            "upostag": 0.05,
-            "xpostag": 0.05
-        },
-        morphological_feat=MorphologicalFeatures(
-            vocabulary=vocabulary,
-            activations=[TanhActivation(), LinearActivation()],
-            dropout=[0.25, 0.],
-            hidden_dims=[128],
-            input_dim=1024,
-            num_layers=2,
-            vocab_namespace="feats_labels"
-        ),
-        regularizer=RegularizerApplicator([
-            (".*conv1d.*", L2Regularizer(1e-6)),
-            (".*forward.*", L2Regularizer(1e-6)),
-            (".*backward.*", L2Regularizer(1e-6)),
-            (".*char_embed.*", L2Regularizer(1e-5))
-        ]),
-        seq_encoder=ComboEncoder(
-            layer_dropout_probability=0.33,
-            stacked_bilstm=ComboStackedBidirectionalLSTM(
-                hidden_size=512,
-                input_size=164,
+            regularizer=RegularizerApplicator([
+                (".*conv1d.*", L2Regularizer(1e-6)),
+                (".*forward.*", L2Regularizer(1e-6)),
+                (".*backward.*", L2Regularizer(1e-6)),
+                (".*char_embed.*", L2Regularizer(1e-5))
+            ]),
+            seq_encoder=ComboTransformerEncoder(
                 layer_dropout_probability=0.33,
+                input_dim=164,
                 num_layers=2,
-                recurrent_dropout_probability=0.33
+                feedforward_hidden_dim=2048,
+                num_attention_heads=4,
+                positional_encoding=None,
+                positional_embedding_size=512,
+                dropout_prob=0.1,
+                activation="relu"
+            ),
+            text_field_embedder=BasicTextFieldEmbedder(
+                token_embedders={
+                    "char": CharacterBasedWordEmbedder(
+                        vocabulary=vocabulary,
+                        dilated_cnn_encoder=DilatedCnnEncoder(
+                            activations=[GELUActivation(), GELUActivation(), LinearActivation()],
+                            dilation=[1, 2, 4],
+                            filters=[512, 256, 64],
+                            input_dim=64,
+                            kernel_size=[3, 3, 3],
+                            padding=[1, 2, 4],
+                            stride=[1, 1, 1],
+                        ),
+                        embedding_dim=64
+                    ),
+                    "token": TransformersWordEmbedder(pretrained_transformer_name, projection_dim=100)
+                }
+            ),
+            upos_tagger=FeedForwardPredictor.from_vocab(
+                vocabulary=vocabulary,
+                activations=[TanhActivation(), LinearActivation()],
+                dropout=[0.25, 0.],
+                hidden_dims=[64],
+                input_dim=1024,
+                num_layers=2,
+                vocab_namespace="upostag_labels"
+            ),
+            xpos_tagger=FeedForwardPredictor.from_vocab(
+                vocabulary=vocabulary,
+                activations=[TanhActivation(), LinearActivation()],
+                dropout=[0.25, 0.],
+                hidden_dims=[64],
+                input_dim=1024,
+                num_layers=2,
+                vocab_namespace="xpostag_labels"
             )
-        ),
-        text_field_embedder=BasicTextFieldEmbedder(
-            token_embedders={
-                "char": CharacterBasedWordEmbedder(
-                    vocabulary=vocabulary,
-                    dilated_cnn_encoder=DilatedCnnEncoder(
-                        activations=[GELUActivation(), GELUActivation(), LinearActivation()],
-                        dilation=[1, 2, 4],
-                        filters=[512, 256, 64],
-                        input_dim=64,
-                        kernel_size=[3, 3, 3],
-                        padding=[1, 2, 4],
-                        stride=[1, 1, 1],
+        )
+    else:
+        return ComboModel(
+            vocabulary=vocabulary,
+            dependency_relation=DependencyRelationModel(
+                vocabulary=vocabulary,
+                dependency_projection_layer=Linear(
+                    activation=TanhActivation(),
+                    dropout_rate=0.25,
+                    in_features=1024,
+                    out_features=128
+                ),
+                head_predictor=HeadPredictionModel(
+                    cycle_loss_n=0,
+                    dependency_projection_layer=Linear(
+                        activation=TanhActivation(),
+                        in_features=1024,
+                        out_features=512
                     ),
-                    embedding_dim=64
+                    head_projection_layer=Linear(
+                        activation=TanhActivation(),
+                        in_features=1024,
+                        out_features=512
+                    )
                 ),
-                "token": TransformersWordEmbedder(pretrained_transformer_name, projection_dim=100)
-            }
-        ),
-        upos_tagger=FeedForwardPredictor.from_vocab(
-            vocabulary=vocabulary,
-            activations=[TanhActivation(), LinearActivation()],
-            dropout=[0.25, 0.],
-            hidden_dims=[64],
-            input_dim=1024,
-            num_layers=2,
-            vocab_namespace="upostag_labels"
-        ),
-        xpos_tagger=FeedForwardPredictor.from_vocab(
-            vocabulary=vocabulary,
-            activations=[TanhActivation(), LinearActivation()],
-            dropout=[0.25, 0.],
-            hidden_dims=[64],
-            input_dim=1024,
-            num_layers=2,
-            vocab_namespace="xpostag_labels"
+                head_projection_layer=Linear(
+                    activation=TanhActivation(),
+                    dropout_rate=0.25,
+                    in_features=1024,
+                    out_features=128
+                ),
+                vocab_namespace="deprel_labels"
+            ),
+            lemmatizer=LemmatizerModel(
+                vocabulary=vocabulary,
+                activations=[GELUActivation(), GELUActivation(), GELUActivation(), LinearActivation()],
+                char_vocab_namespace="token_characters",
+                dilation=[1, 2, 4, 1],
+                embedding_dim=300,
+                filters=[256, 256, 256],
+                input_projection_layer=Linear(
+                    activation=TanhActivation(),
+                    dropout_rate=0.25,
+                    in_features=1024,
+                    out_features=32
+                ),
+                kernel_size=[3, 3, 3, 1],
+                lemma_vocab_namespace="lemma_characters",
+                padding=[1, 2, 4, 0],
+                stride=[1, 1, 1, 1]
+            ),
+            loss_weights={
+                "deprel": 0.8,
+                "feats": 0.2,
+                "head": 0.2,
+                "lemma": 0.05,
+                "semrel": 0.05,
+                "upostag": 0.05,
+                "xpostag": 0.05
+            },
+            morphological_feat=MorphologicalFeatures(
+                vocabulary=vocabulary,
+                activations=[TanhActivation(), LinearActivation()],
+                dropout=[0.25, 0.],
+                hidden_dims=[128],
+                input_dim=1024,
+                num_layers=2,
+                vocab_namespace="feats_labels"
+            ),
+            regularizer=RegularizerApplicator([
+                (".*conv1d.*", L2Regularizer(1e-6)),
+                (".*forward.*", L2Regularizer(1e-6)),
+                (".*backward.*", L2Regularizer(1e-6)),
+                (".*char_embed.*", L2Regularizer(1e-5))
+            ]),
+            seq_encoder=ComboEncoder(
+                layer_dropout_probability=0.33,
+                stacked_bilstm=ComboStackedBidirectionalLSTM(
+                    hidden_size=512,
+                    input_size=164,
+                    layer_dropout_probability=0.33,
+                    num_layers=2,
+                    recurrent_dropout_probability=0.33
+                )
+            ),
+            text_field_embedder=BasicTextFieldEmbedder(
+                token_embedders={
+                    "char": CharacterBasedWordEmbedder(
+                        vocabulary=vocabulary,
+                        dilated_cnn_encoder=DilatedCnnEncoder(
+                            activations=[GELUActivation(), GELUActivation(), LinearActivation()],
+                            dilation=[1, 2, 4],
+                            filters=[512, 256, 64],
+                            input_dim=64,
+                            kernel_size=[3, 3, 3],
+                            padding=[1, 2, 4],
+                            stride=[1, 1, 1],
+                        ),
+                        embedding_dim=64
+                    ),
+                    "token": TransformersWordEmbedder(pretrained_transformer_name, projection_dim=100)
+                }
+            ),
+            upos_tagger=FeedForwardPredictor.from_vocab(
+                vocabulary=vocabulary,
+                activations=[TanhActivation(), LinearActivation()],
+                dropout=[0.25, 0.],
+                hidden_dims=[64],
+                input_dim=1024,
+                num_layers=2,
+                vocab_namespace="upostag_labels"
+            ),
+            xpos_tagger=FeedForwardPredictor.from_vocab(
+                vocabulary=vocabulary,
+                activations=[TanhActivation(), LinearActivation()],
+                dropout=[0.25, 0.],
+                hidden_dims=[64],
+                input_dim=1024,
+                num_layers=2,
+                vocab_namespace="xpostag_labels"
+            )
         )
-    )
diff --git a/combo/main.py b/combo/main.py
index 256e75aed0641d64b6b049d92563d5605badfda9..ba135d46a8a4333e2d8d5ba56e52636859560a64 100755
--- a/combo/main.py
+++ b/combo/main.py
@@ -84,6 +84,7 @@ flags.DEFINE_boolean(name="turns", default=False,
                      help="Segment into sentences on sentence break or on turn break.")
 flags.DEFINE_boolean(name="split_subwords", default=False,
                      help="Split subwords (e.g. don\'t = do, n\'t) into separate tokens.")
+flags.DEFINE_boolean(name="transformer_encoder", default=False, help="Use transformer encoder.")
 
 # Finetune after training flags
 flags.DEFINE_string(name="finetuning_training_data_path", default="",
@@ -312,7 +313,11 @@ def run(_):
                     FLAGS.validation_data_path,
                     prefix
                 )
-                model = default_model(FLAGS.pretrained_transformer_name, vocabulary)
+
+                if FLAGS.transformer_encoder:
+                    model = default_model(FLAGS.pretrained_transformer_name, vocabulary, True)
+                else:
+                    model = default_model(FLAGS.pretrained_transformer_name, vocabulary)
 
             if FLAGS.use_pure_config and model is None:
                 logger.error('Error in configuration - model could not be read from parameters. ' +
diff --git a/combo/models/__init__.py b/combo/models/__init__.py
index c2d36ae531e502632c5b33e4802fcc3419baa9dd..68d9046e3366ab8952e2c355f1d9f8721808b175 100644
--- a/combo/models/__init__.py
+++ b/combo/models/__init__.py
@@ -1 +1 @@
-from .encoder import ComboStackedBidirectionalLSTM, ComboEncoder
+from .encoder import ComboStackedBidirectionalLSTM, ComboEncoder, ComboTransformerEncoder
diff --git a/combo/models/encoder.py b/combo/models/encoder.py
index 8ed5aff127de187ff51a929ec5e42b345e62e498..c751836235bbfa09a8547d85fdc3e6862844d4d0 100644
--- a/combo/models/encoder.py
+++ b/combo/models/encoder.py
@@ -15,6 +15,7 @@ from combo.modules.augmented_lstm import AugmentedLstm
 from combo.modules.input_variational_dropout import InputVariationalDropout
 from combo.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder
 from combo.utils import ConfigurationError
+from combo.modules.seq2seq_encoders.transformer_encoder import TransformerEncoder
 
 TensorPair = Tuple[torch.Tensor, torch.Tensor]
 
@@ -247,3 +248,46 @@ class ComboEncoder(Seq2SeqEncoder, FromParameters):
         x = self.layer_dropout(inputs)
         x = super().forward(x, mask)
         return self.layer_dropout(x)
+
+
+@Registry.register('combo_transformer_encoder')
+class ComboTransformerEncoder(TransformerEncoder, FromParameters):
+    """COMBO encoder (https://www.aclweb.org/anthology/K18-2004.pdf).
+
+    This implementation uses Variational Dropout on the input and then outputs of each BiLSTM layer
+    (instead of used Gaussian Dropout and Gaussian Noise).
+    """
+
+    @register_arguments
+    def __init__(self,
+                 layer_dropout_probability: float,
+                 input_dim: int,
+                 num_layers: int,
+                 feedforward_hidden_dim: int = 2048,
+                 num_attention_heads: int = 4,
+                 positional_encoding: Optional[str] = None,
+                 positional_embedding_size: int = 512,
+                 dropout_prob: float = 0.1,
+                 activation: str = "relu"
+                 # stacked_transformer: ComboStackedBidirectionalLSTM,
+                ):
+        super().__init__(
+            input_dim,
+            num_layers,
+            feedforward_hidden_dim,
+            num_attention_heads,
+            positional_encoding,
+            positional_embedding_size,
+            dropout_prob,
+            activation
+        )
+
+        self.layer_dropout = input_variational_dropout.InputVariationalDropout(p=layer_dropout_probability)
+
+    def forward(self,
+                inputs: torch.Tensor,
+                mask: torch.BoolTensor,
+                hidden_state: torch.Tensor = None) -> torch.Tensor:
+        x = self.layer_dropout(inputs)
+        x = super().forward(x, mask)
+        return self.layer_dropout(x)
\ No newline at end of file
diff --git a/combo/modules/seq2seq_encoders/transformer_encoder.py b/combo/modules/seq2seq_encoders/transformer_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..a49100def51df028bd8427ea65e0ab8bd818905a
--- /dev/null
+++ b/combo/modules/seq2seq_encoders/transformer_encoder.py
@@ -0,0 +1,117 @@
+from typing import Optional
+
+from overrides import overrides
+import torch
+from torch import nn
+
+from combo.modules.encoder import _EncoderBase
+from combo.config.from_parameters import FromParameters, register_arguments
+
+# from modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder
+from nn.utils import add_positional_features
+
+
+# from allennlp.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder
+# from allennlp.nn.util import add_positional_features
+
+
+class TransformerEncoder(_EncoderBase, FromParameters):
+    """
+    Implements a stacked self-attention encoder similar to the Transformer
+    architecture in [Attention is all you Need]
+    (https://www.semanticscholar.org/paper/Attention-Is-All-You-Need-Vaswani-Shazeer/0737da0767d77606169cbf4187b83e1ab62f6077).
+
+    This class adapts the Transformer from torch.nn for use in AllenNLP. Optionally, it adds positional encodings.
+
+    Registered as a `Seq2SeqEncoder` with name "pytorch_transformer".
+
+    # Parameters
+
+    input_dim : `int`, required.
+        The input dimension of the encoder.
+    feedforward_hidden_dim : `int`, required.
+        The middle dimension of the FeedForward network. The input and output
+        dimensions are fixed to ensure sizes match up for the self attention layers.
+    num_layers : `int`, required.
+        The number of stacked self attention -> feedforward -> layer normalisation blocks.
+    num_attention_heads : `int`, required.
+        The number of attention heads to use per layer.
+    use_positional_encoding : `bool`, optional, (default = `True`)
+        Whether to add sinusoidal frequencies to the input tensor. This is strongly recommended,
+        as without this feature, the self attention layers have no idea of absolute or relative
+        position (as they are just computing pairwise similarity between vectors of elements),
+        which can be important features for many tasks.
+    dropout_prob : `float`, optional, (default = `0.1`)
+        The dropout probability for the feedforward network.
+    """  # noqa
+
+    def __init__(
+        self,
+        input_dim: int,
+        num_layers: int,
+        feedforward_hidden_dim: int = 2048,
+        num_attention_heads: int = 4,
+        positional_encoding: Optional[str] = None,
+        positional_embedding_size: int = 512,
+        dropout_prob: float = 0.1,
+        activation: str = "relu",
+    ) -> None:
+        super().__init__()
+
+        layer = nn.TransformerEncoderLayer(
+            d_model=input_dim,
+            nhead=num_attention_heads,
+            dim_feedforward=feedforward_hidden_dim,
+            dropout=dropout_prob,
+            activation=activation,
+        )
+        self._transformer = nn.TransformerEncoder(layer, num_layers)
+        self._input_dim = input_dim
+
+        # initialize parameters
+        # We do this before the embeddings are initialized so we get the default initialization for the embeddings.
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
+        if positional_encoding is None:
+            self._sinusoidal_positional_encoding = False
+            self._positional_embedding = None
+        elif positional_encoding == "sinusoidal":
+            self._sinusoidal_positional_encoding = True
+            self._positional_embedding = None
+        elif positional_encoding == "embedding":
+            self._sinusoidal_positional_encoding = False
+            self._positional_embedding = nn.Embedding(positional_embedding_size, input_dim)
+        else:
+            raise ValueError(
+                "positional_encoding must be one of None, 'sinusoidal', or 'embedding'"
+            )
+
+    def get_input_dim(self) -> int:
+        return self._input_dim
+
+    def get_output_dim(self) -> int:
+        return self._input_dim
+
+    def is_bidirectional(self):
+        return False
+
+    def forward(self, inputs: torch.Tensor, mask: torch.BoolTensor):
+        output = inputs
+        if self._sinusoidal_positional_encoding:
+            output = add_positional_features(output)
+        if self._positional_embedding is not None:
+            position_ids = torch.arange(inputs.size(1), dtype=torch.long, device=output.device)
+            position_ids = position_ids.unsqueeze(0).expand(inputs.shape[:-1])
+            output = output + self._positional_embedding(position_ids)
+
+        # For some reason the torch transformer expects the shape (sequence, batch, features), not the more
+        # familiar (batch, sequence, features), so we have to fix it.
+        output = output.permute(1, 0, 2)
+        # For some other reason, the torch transformer takes the mask backwards.
+        mask = ~mask
+        output = self._transformer(output, src_key_padding_mask=mask)
+        output = output.permute(1, 0, 2)
+
+        return output
\ No newline at end of file
diff --git a/combo/nn/utils.py b/combo/nn/utils.py
index 4333be0a1fc5736c87a27efb18426eaee1c5d2e7..bdf4dc3c196d47975eed97c5bc57302441345bb8 100644
--- a/combo/nn/utils.py
+++ b/combo/nn/utils.py
@@ -2,6 +2,7 @@
 Adapted from AllenNLP
 https://github.com/allenai/allennlp/blob/80fb6061e568cb9d6ab5d45b661e86eb61b92c82/allennlp/nn/util.py
 """
+import math
 from typing import Union, Dict, Optional, List, Any, NamedTuple
 
 import torch
@@ -479,3 +480,61 @@ def batched_span_select(target: torch.Tensor, spans: torch.LongTensor) -> torch.
     span_embeddings = batched_index_select(target, span_indices)
 
     return span_embeddings, span_mask
+
+
+def add_positional_features(
+    tensor: torch.Tensor, min_timescale: float = 1.0, max_timescale: float = 1.0e4
+):
+
+    """
+    Implements the frequency-based positional encoding described
+    in [Attention is All you Need][0].
+
+    Adds sinusoids of different frequencies to a `Tensor`. A sinusoid of a
+    different frequency and phase is added to each dimension of the input `Tensor`.
+    This allows the attention heads to use absolute and relative positions.
+
+    The number of timescales is equal to hidden_dim / 2 within the range
+    (min_timescale, max_timescale). For each timescale, the two sinusoidal
+    signals sin(timestep / timescale) and cos(timestep / timescale) are
+    generated and concatenated along the hidden_dim dimension.
+
+    [0]: https://www.semanticscholar.org/paper/Attention-Is-All-You-Need-Vaswani-Shazeer/0737da0767d77606169cbf4187b83e1ab62f6077
+
+    # Parameters
+
+    tensor : `torch.Tensor`
+        a Tensor with shape (batch_size, timesteps, hidden_dim).
+    min_timescale : `float`, optional (default = `1.0`)
+        The smallest timescale to use.
+    max_timescale : `float`, optional (default = `1.0e4`)
+        The largest timescale to use.
+
+    # Returns
+
+    `torch.Tensor`
+        The input tensor augmented with the sinusoidal frequencies.
+    """  # noqa
+    _, timesteps, hidden_dim = tensor.size()
+
+    timestep_range = get_range_vector(timesteps, get_device_of(tensor)).data.float()
+    # We're generating both cos and sin frequencies,
+    # so half for each.
+    num_timescales = hidden_dim // 2
+    timescale_range = get_range_vector(num_timescales, get_device_of(tensor)).data.float()
+
+    log_timescale_increments = math.log(float(max_timescale) / float(min_timescale)) / float(
+        num_timescales - 1
+    )
+    inverse_timescales = min_timescale * torch.exp(timescale_range * -log_timescale_increments)
+
+    # Broadcasted multiplication - shape (timesteps, num_timescales)
+    scaled_time = timestep_range.unsqueeze(1) * inverse_timescales.unsqueeze(0)
+    # shape (timesteps, 2 * num_timescales)
+    sinusoids = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 1)
+    if hidden_dim % 2 != 0:
+        # if the number of dimensions is odd, the cos and sin
+        # timescales had size (hidden_dim - 1) / 2, so we need
+        # to add a row of zeros to make up the difference.
+        sinusoids = torch.cat([sinusoids, sinusoids.new_zeros(timesteps, 1)], 1)
+    return tensor + sinusoids.unsqueeze(0)
\ No newline at end of file