From 3f7e1a00f4ac1db9d54ff35df010de83c9ddd8d6 Mon Sep 17 00:00:00 2001
From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com>
Date: Tue, 2 Jun 2020 11:44:00 +0200
Subject: [PATCH] Make tensorboard optional and add morphological features as a
 possible input.

---
 README.md                                     |  4 +-
 combo/data/dataset.py                         | 15 +++-
 combo/data/token_indexers/__init__.py         |  1 +
 .../token_indexers/token_features_indexer.py  | 71 ++++++++++++++++++
 combo/main.py                                 |  9 ++-
 combo/models/embeddings.py                    | 74 +++++++++++++++----
 combo/models/lemma.py                         | 19 ++---
 combo/predict.py                              |  5 +-
 combo/training/tensorboard_writer.py          | 68 +++++++++++++++++
 combo/training/trainer.py                     | 68 +++++++++++++++--
 config.template.jsonnet                       | 28 +++++--
 setup.py                                      |  1 +
 tests/test_main.py                            |  2 +-
 13 files changed, 312 insertions(+), 53 deletions(-)
 create mode 100644 combo/data/token_indexers/token_features_indexer.py
 create mode 100644 combo/training/tensorboard_writer.py

diff --git a/README.md b/README.md
index 5da02eb..9b88650 100644
--- a/README.md
+++ b/README.md
@@ -90,8 +90,8 @@ Use either `--predictor_name semantic-multitask-predictor` or `--predictor_name
 import combo.predict as predict
 
 model_path = "your_model.tar.gz"
-predictor = predict.SemanticMultitaskPredictor.from_pretrained(model_path)
-parsed_tree = predictor.predict_string("Sentence to parse.")["tree"]
+nlp = predict.SemanticMultitaskPredictor.from_pretrained(model_path)
+parsed_tree = nlp("Sentence to parse.")["tree"]
 ```
 
 ## Configuration
diff --git a/combo/data/dataset.py b/combo/data/dataset.py
index 56061db..ab5ce2a 100644
--- a/combo/data/dataset.py
+++ b/combo/data/dataset.py
@@ -6,6 +6,7 @@ from allennlp import data as allen_data
 from allennlp.common import checks
 from allennlp.data import fields as allen_fields, vocabulary
 from conllu import parser
+from dataclasses import dataclass
 from overrides import overrides
 
 from combo.data import fields
@@ -85,10 +86,11 @@ class UniversalDependenciesDatasetReader(allen_data.DatasetReader):
     @overrides
     def text_to_instance(self, tree: conllu.TokenList) -> allen_data.Instance:
         fields_: Dict[str, allen_data.Field] = {}
-        tokens = [allen_data.Token(t['token'],
-                                   pos_=t.get('upostag'),
-                                   tag_=t.get('xpostag'),
-                                   lemma_=t.get('lemma'))
+        tokens = [Token(t['token'],
+                        pos_=t.get('upostag'),
+                        tag_=t.get('xpostag'),
+                        lemma_=t.get('lemma'),
+                        feats_=t.get('feats'))
                   for t in tree]
 
         # features
@@ -228,3 +230,8 @@ def get_slices_if_not_provided(vocab: allen_data.Vocabulary):
                 slices[name] = [idx]
         vocab.slices = slices
         return vocab.slices
+
+
+@dataclass
+class Token(allen_data.Token):
+    feats_: Optional[str] = None
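
For reference, a minimal sketch of how the extended `Token` is constructed, mirroring the `text_to_instance` change above; the new `feats_` slot simply carries whatever `conllu` parsed from the FEATS column (illustrative values):

```python
# Illustration only; mirrors the Token(...) call in the dataset reader above.
from combo.data.dataset import Token

# conllu parses FEATS into a dict such as {"Case": "Gen", "Number": "Sing"} (None for "_").
token = Token("domu", lemma_="dom", pos_="NOUN", feats_={"Case": "Gen", "Number": "Sing"})
print(token.feats_)  # {'Case': 'Gen', 'Number': 'Sing'}
```
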
diff --git a/combo/data/token_indexers/__init__.py b/combo/data/token_indexers/__init__.py
index 6b2a7b7..1b918b3 100644
--- a/combo/data/token_indexers/__init__.py
+++ b/combo/data/token_indexers/__init__.py
@@ -1 +1,2 @@
 from .token_characters_indexer import TokenCharactersIndexer
+from .token_features_indexer import TokenFeatsIndexer
diff --git a/combo/data/token_indexers/token_features_indexer.py b/combo/data/token_indexers/token_features_indexer.py
new file mode 100644
index 0000000..eac755b
--- /dev/null
+++ b/combo/data/token_indexers/token_features_indexer.py
@@ -0,0 +1,71 @@
+"""Features indexer."""
+import collections
+from typing import List, Dict
+
+import torch
+from allennlp import data
+from allennlp.common import util
+from overrides import overrides
+
+
+@data.TokenIndexer.register('feats_indexer')
+class TokenFeatsIndexer(data.TokenIndexer):
+
+    def __init__(
+            self,
+            namespace: str = "feats",
+            feature_name: str = "feats_",
+            token_min_padding_length: int = 0,
+    ) -> None:
+        super().__init__(token_min_padding_length)
+        self.namespace = namespace
+        self._feature_name = feature_name
+
+    @overrides
+    def count_vocab_items(self, token: data.Token, counter: Dict[str, Dict[str, int]]):
+        feats = self._feat_values(token)
+        for feat in feats:
+            counter[self.namespace][feat] += 1
+
+    @overrides
+    def tokens_to_indices(self, tokens: List[data.Token], vocabulary: data.Vocabulary) -> data.IndexedTokenList:
+        indices: List[List[int]] = []
+        vocab_size = vocabulary.get_vocab_size(self.namespace)
+        for token in tokens:
+            token_indices = []
+            feats = self._feat_values(token)
+            for feat in feats:
+                token_indices.append(vocabulary.get_token_index(feat, self.namespace))
+            indices.append(util.pad_sequence_to_length(token_indices, vocab_size))
+        return {"tokens": indices}
+
+    @overrides
+    def get_empty_token_list(self) -> data.IndexedTokenList:
+        return {"tokens": [[]]}
+
+    def _feat_values(self, token):
+        feats = getattr(token, self._feature_name)
+        if feats is None:
+            feats = collections.OrderedDict()
+        features = []
+        for feat, value in feats.items():
+            if feat in ['_', '__ROOT__']:
+                pass
+            else:
+                features.append(feat + '=' + value)
+        return features
+
+    @overrides
+    def as_padded_tensor_dict(
+            self, tokens: data.IndexedTokenList, padding_lengths: Dict[str, int]
+    ) -> Dict[str, torch.Tensor]:
+        tensor_dict = {}
+        for key, val in tokens.items():
+            vocab_size = len(val[0])
+            tensor = torch.tensor(util.pad_sequence_to_length(val,
+                                                              padding_lengths[key],
+                                                              default_value=lambda: [0] * vocab_size,
+                                                              )
+                                  )
+            tensor_dict[key] = tensor
+        return tensor_dict
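
For reference, a brief sketch of what the indexer produces: `_feat_values` flattens a token's FEATS dict into `key=value` strings (skipping the `_` and `__ROOT__` placeholders), and `tokens_to_indices` then pads each token's index list to the size of the `feats` vocabulary. A simplified, self-contained restatement of the flattening step:

```python
# Illustration only: a standalone re-statement of TokenFeatsIndexer._feat_values.
import collections
from typing import Dict, List, Optional

def flatten_feats(feats: Optional[Dict[str, str]]) -> List[str]:
    feats = feats or collections.OrderedDict()
    return [f"{feat}={value}" for feat, value in feats.items()
            if feat not in ("_", "__ROOT__")]

print(flatten_feats(collections.OrderedDict([("Case", "Nom"), ("Number", "Sing")])))
# ['Case=Nom', 'Number=Sing']
```
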
diff --git a/combo/main.py b/combo/main.py
index 4bd9d65..d995136 100644
--- a/combo/main.py
+++ b/combo/main.py
@@ -45,13 +45,15 @@ flags.DEFINE_string(name='pretrained_transformer_name', default='',
                     help='Pretrained transformer model name (see transformers from HuggingFace library for list of'
                          'available models) for transformers based embeddings.')
 flags.DEFINE_multi_enum(name='features', default=['token', 'char'],
-                        enum_values=['token', 'char', 'upostag', 'xpostag', 'lemma'],
+                        enum_values=['token', 'char', 'upostag', 'xpostag', 'lemma', 'feats'],
                         help='Features used to train model (required `token` and `char`)')
 flags.DEFINE_multi_enum(name='targets', default=['deprel', 'feats', 'head', 'lemma', 'upostag', 'xpostag'],
                         enum_values=['deprel', 'feats', 'head', 'lemma', 'upostag', 'xpostag', 'semrel', 'sent'],
                         help='Targets of the model (required `deprel` and `head`)')
 flags.DEFINE_string(name='serialization_dir', default=None,
                     help='Model serialization directory (default - system temp dir).')
+flags.DEFINE_boolean(name='tensorboard', default=False,
+                     help='When provided, the model will log TensorBoard metrics.')
 
 # Finetune after training flags
 flags.DEFINE_string(name='finetuning_training_data_path', default='',
@@ -145,7 +147,7 @@ def run(_):
             FLAGS.input_file,
             FLAGS.output_file,
             FLAGS.batch_size,
-            FLAGS.silent,
+            not FLAGS.silent,
             use_dataset_reader,
         )
         manager.run()
@@ -181,7 +183,8 @@ def _get_ext_vars(finetuning: bool = False) -> Dict:
             'embedding_dim': str(FLAGS.embedding_dim),
             'cuda_device': str(FLAGS.cuda_device),
             'num_epochs': str(FLAGS.num_epochs),
-            'word_batch_size': str(FLAGS.word_batch_size)
+            'word_batch_size': str(FLAGS.word_batch_size),
+            'use_tensorboard': str(FLAGS.tensorboard),
         }
 
 
diff --git a/combo/models/embeddings.py b/combo/models/embeddings.py
index 15cd9ec..fe40775 100644
--- a/combo/models/embeddings.py
+++ b/combo/models/embeddings.py
@@ -3,10 +3,10 @@ from typing import Optional
 
 import torch
 import torch.nn as nn
-from allennlp import nn as allen_nn, data
+from allennlp import nn as allen_nn, data, modules
 from allennlp.modules import token_embedders
+from allennlp.nn import util
 from overrides import overrides
-from transformers import modeling_auto
 
 from combo.models import base, dilated_cnn
 
@@ -25,7 +25,7 @@ class CharacterBasedWordEmbeddings(token_embedders.TokenEmbedder):
             num_embeddings=num_embeddings,
             embedding_dim=embedding_dim,
         )
-        self.dilated_cnn_encoder = dilated_cnn_encoder
+        self.dilated_cnn_encoder = modules.TimeDistributed(dilated_cnn_encoder)
         self.output_dim = embedding_dim
 
     def forward(self,
@@ -36,16 +36,8 @@ class CharacterBasedWordEmbeddings(token_embedders.TokenEmbedder):
 
         x = self.char_embed(x)
         x = x * char_mask.unsqueeze(-1).float()
-
-        BATCH_SIZE, SENTENCE_LENGTH, MAX_WORD_LENGTH, CHAR_EMB = x.size()
-
-        words = []
-        for i in range(SENTENCE_LENGTH):
-            word = x[:, i, :, :].reshape(BATCH_SIZE, MAX_WORD_LENGTH, CHAR_EMB).transpose(1, 2)
-            word = self.dilated_cnn_encoder(word)
-            word, _ = torch.max(word, dim=2)
-            words.append(word)
-        return torch.stack(words, dim=1)
+        x = self.dilated_cnn_encoder(x.transpose(2, 3))
+        return torch.max(x, dim=-1)[0]
 
     @overrides
     def get_output_dim(self) -> int:
@@ -120,7 +112,9 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEm
                  projection_dropout_rate: Optional[float] = 0.0,
                  freeze_transformer: bool = True):
         super().__init__(model_name)
-        if freeze_transformer:
+        self.freeze_transformer = freeze_transformer
+        if self.freeze_transformer:
+            self._matched_embedder.eval()
             for param in self._matched_embedder.parameters():
                 param.requires_grad = False
         if projection_dim:
@@ -154,8 +148,56 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEm
 
     @overrides
     def train(self, mode: bool):
-        self.projection_layer.train(mode)
+        if self.freeze_transformer:
+            self.projection_layer.train(mode)
+        else:
+            super().train(mode)
 
     @overrides
     def eval(self):
-        self.projection_layer.eval()
+        if self.freeze_transformer:
+            self.projection_layer.eval()
+        else:
+            super().eval()
+
+
+@token_embedders.TokenEmbedder.register("feats_embedding")
+class FeatsTokenEmbedder(token_embedders.Embedding):
+
+    def __init__(self,
+                 embedding_dim: int,
+                 num_embeddings: int = None,
+                 weight: torch.FloatTensor = None,
+                 padding_index: int = None,
+                 trainable: bool = True,
+                 max_norm: float = None,
+                 norm_type: float = 2.0,
+                 scale_grad_by_freq: bool = False,
+                 sparse: bool = False,
+                 vocab_namespace: str = "feats",
+                 pretrained_file: str = None,
+                 vocab: data.Vocabulary = None):
+        super().__init__(
+            embedding_dim=embedding_dim,
+            num_embeddings=num_embeddings,
+            weight=weight,
+            padding_index=padding_index,
+            trainable=trainable,
+            max_norm=max_norm,
+            norm_type=norm_type,
+            scale_grad_by_freq=scale_grad_by_freq,
+            sparse=sparse,
+            vocab_namespace=vocab_namespace,
+            pretrained_file=pretrained_file,
+            vocab=vocab
+        )
+
+    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
+        # (batch_size, sentence_length, features_vocab_length)
+        mask = (tokens > 0).float()
+        # (batch_size, sentence_length, features_vocab_length, embedding_dim)
+        x = super().forward(tokens)
+        # (batch_size, sentence_length, embedding_dim)
+        return x.sum(dim=-2) / (
+            (mask.sum(dim=-1) + util.tiny_value_of_dtype(mask.dtype)).unsqueeze(dim=-1)
+        )
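
For clarity, a standalone sketch of the pooling done in `FeatsTokenEmbedder.forward`: with `padding_index=0` (as set in the config below), padded positions embed to zero vectors, so summing over the feature axis and dividing by the number of non-padding features gives the mean embedding of a token's active morphological features.

```python
# Simplified illustration of the masked mean pooling; shapes as in the comments above.
import torch

feats_vocab, emb_dim = 5, 4
tokens = torch.tensor([[[1, 3, 0, 0, 0],     # token with two active features
                        [0, 0, 0, 0, 0]]])   # token with no features at all
embedding = torch.nn.Embedding(feats_vocab, emb_dim, padding_idx=0)

x = embedding(tokens)                               # (1, 2, 5, 4); index-0 rows are zero
mask = (tokens > 0).float()                         # (1, 2, 5)
pooled = x.sum(dim=-2) / (mask.sum(dim=-1) + 1e-13).unsqueeze(-1)  # (1, 2, 4)
```
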
diff --git a/combo/models/lemma.py b/combo/models/lemma.py
index fecab0b..0df4bf7 100644
--- a/combo/models/lemma.py
+++ b/combo/models/lemma.py
@@ -3,7 +3,7 @@ from typing import Optional, Dict, List, Union
 
 import torch
 import torch.nn as nn
-from allennlp import data, nn as allen_nn
+from allennlp import data, nn as allen_nn, modules
 from allennlp.common import checks
 
 from combo.models import base, dilated_cnn, utils
@@ -23,7 +23,7 @@ class LemmatizerModel(base.Predictor):
             num_embeddings=num_embeddings,
             embedding_dim=embedding_dim,
         )
-        self.dilated_cnn_encoder = dilated_cnn_encoder
+        self.dilated_cnn_encoder = modules.TimeDistributed(dilated_cnn_encoder)
         self.input_projection_layer = input_projection_layer
 
     def forward(self,
@@ -36,20 +36,11 @@ class LemmatizerModel(base.Predictor):
         encoder_emb = self.input_projection_layer(encoder_emb)
         char_embeddings = self.char_embed(chars)
 
-        BATCH_SIZE, SENTENCE_LENGTH, WORD_EMB = encoder_emb.size()
-        _, _, MAX_WORD_LENGTH, CHAR_EMB = char_embeddings.size()
-
-
+        BATCH_SIZE, _, MAX_WORD_LENGTH, CHAR_EMB = char_embeddings.size()
         encoder_emb = encoder_emb.unsqueeze(2).repeat(1, 1, MAX_WORD_LENGTH, 1)
 
-        pred = []
-        for i in range(SENTENCE_LENGTH):
-            word_emb = (encoder_emb[:, i, :, :].reshape(BATCH_SIZE, MAX_WORD_LENGTH, -1))
-            char_sent_emb = char_embeddings[:, i, :].reshape(BATCH_SIZE, MAX_WORD_LENGTH, CHAR_EMB)
-            x = torch.cat((char_sent_emb, word_emb), -1).transpose(2, 1)
-            x = self.dilated_cnn_encoder(x)
-            pred.append(x)
-        x = torch.stack(pred, dim=1).transpose(2, 3)
+        x = torch.cat((char_embeddings, encoder_emb), dim=-1).transpose(2, 3)
+        x = self.dilated_cnn_encoder(x).transpose(2, 3)
         output = {
             'prediction': x.argmax(-1),
             'probability': x
diff --git a/combo/predict.py b/combo/predict.py
index 7c7bb8f..b30da43 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -57,9 +57,12 @@ class SemanticMultitaskPredictor(predictor.Predictor):
         logger.info('Took {} ms'.format((end_time - start_time) * 1000.0))
         return result
 
-    def predict_string(self, sentence: str):
+    def predict(self, sentence: str):
         return self.predict_json({'sentence': sentence})
 
+    def __call__(self, sentence: str):
+        return self.predict(sentence)
+
     @overrides
     def predict_json(self, inputs: common.JsonDict) -> common.JsonDict:
         start_time = time.time()
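
The rename keeps the predictor usable both explicitly and as a callable, which is what the updated README snippet relies on. A minimal usage sketch (the model path is a placeholder):

```python
import combo.predict as predict

nlp = predict.SemanticMultitaskPredictor.from_pretrained("your_model.tar.gz")  # placeholder path
tree = nlp("Sentence to parse.")["tree"]          # __call__ delegates to predict()
tree = nlp.predict("Sentence to parse.")["tree"]  # equivalent explicit form
```
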
diff --git a/combo/training/tensorboard_writer.py b/combo/training/tensorboard_writer.py
new file mode 100644
index 0000000..83a2f80
--- /dev/null
+++ b/combo/training/tensorboard_writer.py
@@ -0,0 +1,68 @@
+from typing import Dict, Optional, List
+
+import torch
+from allennlp import models, common
+from allennlp.data import dataloader
+from allennlp.training import optimizers
+
+
+class NullTensorboardWriter(common.FromParams):
+
+    def log_batch(
+        self,
+        model: models.Model,
+        optimizer: optimizers.Optimizer,
+        batch_grad_norm: Optional[float],
+        metrics: Dict[str, float],
+        batch_group: List[List[dataloader.TensorDict]],
+        param_updates: Optional[Dict[str, torch.Tensor]],
+    ) -> None:
+        pass
+
+    def reset_epoch(self) -> None:
+        pass
+
+    def should_log_this_batch(self) -> bool:
+        return False
+
+    def should_log_histograms_this_batch(self) -> bool:
+        return False
+
+    def add_train_scalar(self, name: str, value: float, timestep: int = None) -> None:
+        pass
+
+    def add_train_histogram(self, name: str, values: torch.Tensor) -> None:
+        pass
+
+    def add_validation_scalar(self, name: str, value: float, timestep: int = None) -> None:
+        pass
+
+    def log_parameter_and_gradient_statistics(self, model: models.Model, batch_grad_norm: float) -> None:
+        pass
+
+    def log_learning_rates(self, model: models.Model, optimizer: torch.optim.Optimizer):
+        pass
+
+    def log_histograms(self, model: models.Model) -> None:
+        pass
+
+    def log_gradient_updates(self, model: models.Model, param_updates: Dict[str, torch.Tensor]) -> None:
+        pass
+
+    def log_metrics(
+        self,
+        train_metrics: dict,
+        val_metrics: dict = None,
+        epoch: int = None,
+        log_to_console: bool = False,
+    ) -> None:
+        pass
+
+    def enable_activation_logging(self, model: models.Model) -> None:
+        pass
+
+    def log_activation_histogram(self, outputs, log_prefix: str) -> None:
+        pass
+
+    def close(self) -> None:
+        pass
diff --git a/combo/training/trainer.py b/combo/training/trainer.py
index 01773c3..330cf53 100644
--- a/combo/training/trainer.py
+++ b/combo/training/trainer.py
@@ -10,18 +10,20 @@ import torch.distributed as dist
 import torch.optim as optim
 import torch.optim.lr_scheduler
 import torch.utils.data as data
-from allennlp import training
+from allennlp import training, common
 from allennlp.common import checks
 from allennlp.common import util as common_util
 from allennlp.models import model
-from allennlp.training import checkpointer
+from allennlp.training import checkpointer, optimizers
 from allennlp.training import learning_rate_schedulers
 from allennlp.training import momentum_schedulers
 from allennlp.training import moving_average
-from allennlp.training import tensorboard_writer
+from allennlp.training import tensorboard_writer as allen_tensorboard_writer
 from allennlp.training import util as training_util
 from overrides import overrides
 
+from combo.training import tensorboard_writer as combo_tensorboard_writer
+
 logger = logging.getLogger(__name__)
 
 
@@ -47,13 +49,12 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
                  grad_norm: Optional[float] = None, grad_clipping: Optional[float] = None,
                  learning_rate_scheduler: Optional[learning_rate_schedulers.LearningRateScheduler] = None,
                  momentum_scheduler: Optional[momentum_schedulers.MomentumScheduler] = None,
-                 tensorboard_writer: tensorboard_writer.TensorboardWriter = None,
+                 tensorboard_writer: allen_tensorboard_writer.TensorboardWriter = None,
                  moving_average: Optional[moving_average.MovingAverage] = None,
                  batch_callbacks: List[training.BatchCallback] = None,
                  epoch_callbacks: List[training.EpochCallback] = None, distributed: bool = False, local_rank: int = 0,
                  world_size: int = 1, num_gradient_accumulation_steps: int = 1,
                  opt_level: Optional[str] = None) -> None:
-
         super().__init__(model, optimizer, data_loader, patience, validation_metric, validation_data_loader, num_epochs,
                          serialization_dir, checkpointer, cuda_device, grad_norm, grad_clipping,
                          learning_rate_scheduler, momentum_scheduler, tensorboard_writer, moving_average,
@@ -211,3 +212,60 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             self.model.load_state_dict(best_model_state)
 
         return metrics
+
+    @classmethod
+    def from_partial_objects(
+        cls,
+        model: model.Model,
+        serialization_dir: str,
+        data_loader: data.DataLoader,
+        validation_data_loader: data.DataLoader = None,
+        local_rank: int = 0,
+        patience: int = None,
+        validation_metric: str = "-loss",
+        num_epochs: int = 20,
+        cuda_device: int = -1,
+        grad_norm: float = None,
+        grad_clipping: float = None,
+        distributed: bool = None,
+        world_size: int = 1,
+        num_gradient_accumulation_steps: int = 1,
+        opt_level: Optional[str] = None,
+        no_grad: List[str] = None,
+        optimizer: common.Lazy[optimizers.Optimizer] = None,
+        learning_rate_scheduler: common.Lazy[learning_rate_schedulers.LearningRateScheduler] = None,
+        momentum_scheduler: common.Lazy[momentum_schedulers.MomentumScheduler] = None,
+        tensorboard_writer: common.Lazy[allen_tensorboard_writer.TensorboardWriter] = None,
+        moving_average: common.Lazy[moving_average.MovingAverage] = None,
+        checkpointer: common.Lazy[training.Checkpointer] = None,
+        batch_callbacks: List[training.BatchCallback] = None,
+        epoch_callbacks: List[training.EpochCallback] = None,
+    ) -> "training.Trainer":
+        if tensorboard_writer.construct() is None:
+            tensorboard_writer = common.Lazy(combo_tensorboard_writer.NullTensorboardWriter)
+        return super().from_partial_objects(
+            model=model,
+            serialization_dir=serialization_dir,
+            data_loader=data_loader,
+            validation_data_loader=validation_data_loader,
+            local_rank=local_rank,
+            patience=patience,
+            validation_metric=validation_metric,
+            num_epochs=num_epochs,
+            cuda_device=cuda_device,
+            grad_norm=grad_norm,
+            grad_clipping=grad_clipping,
+            distributed=distributed,
+            world_size=world_size,
+            num_gradient_accumulation_steps=num_gradient_accumulation_steps,
+            opt_level=opt_level,
+            no_grad=no_grad,
+            optimizer=optimizer,
+            learning_rate_scheduler=learning_rate_scheduler,
+            momentum_scheduler=momentum_scheduler,
+            tensorboard_writer=tensorboard_writer,
+            moving_average=moving_average,
+            checkpointer=checkpointer,
+            batch_callbacks=batch_callbacks,
+            epoch_callbacks=epoch_callbacks,
+        )
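
A minimal sketch of the fallback added in `from_partial_objects`: when the configuration contains no `tensorboard_writer` block, the lazily constructed writer comes back as `None`, and the no-op `NullTensorboardWriter` is substituted so the training loop can call its logging methods unconditionally.

```python
# Simplified restatement for illustration; assumes common.Lazy behaves as used above.
from allennlp import common

from combo.training import tensorboard_writer as combo_tensorboard_writer

def resolve_tensorboard_writer(tensorboard_writer: common.Lazy) -> common.Lazy:
    if tensorboard_writer.construct() is None:
        # No writer configured: fall back to the writer whose methods are all no-ops.
        return common.Lazy(combo_tensorboard_writer.NullTensorboardWriter)
    return tensorboard_writer
```
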
diff --git a/config.template.jsonnet b/config.template.jsonnet
index ae99eb3..0f12d8f 100644
--- a/config.template.jsonnet
+++ b/config.template.jsonnet
@@ -28,8 +28,6 @@ local features = std.split(std.extVar("features"), " ");
 # Choice "feats", "lemma", "upostag", "xpostag", "semrel". "sent"
 # Required "deprel", "head"
 local targets = std.split(std.extVar("targets"), " ");
-# Path for tensorboard metrics, str
-local metrics_dir = "./runs";
 # Word embedding dimension, int
 # If pretrained_tokens is not null, it must match the provided dimensionality
 local embedding_dim = std.parseInt(std.extVar("embedding_dim"));
@@ -42,6 +40,9 @@ local xpostag_dim = 100;
 # Upostag embedding dimension, int
 # (discarded if upostag not in features)
 local upostag_dim = 100;
+# Feats embedding dimension, int
+# (discarded if feats not in features)
+local feats_dim = 100;
 # Lemma embedding dimension, int
 # (discarded if lemma not in features)
 local lemma_char_dim = 64;
@@ -67,7 +68,10 @@ local cycle_loss_n = 0;
 # Maximum length of the word, int
 # Shorter words are padded, longer - truncated
 local word_length = 30;
-
+# Whether to use tensorboard, bool
+local use_tensorboard = if std.extVar("use_tensorboard") == "True" then true else false;
+# Path for tensorboard metrics, str
+local metrics_dir = "./runs";
 
 # Helper functions
 local in_features(name) = !(std.length(std.find(name, features)) == 0);
@@ -141,6 +145,9 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
                 # +2 for start and end token
                 min_padding_length: word_length + 2,
             },
+            feats: {
+                type: "feats_indexer",
+            },
         },
         lemma_indexers: {
             char: {
@@ -233,6 +240,12 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
                         activations: ["relu", "relu", "linear"],
                     },
                 },
+                feats: if in_features("feats") then {
+                    type: "feats_embedding",
+                    padding_index: 0,
+                    embedding_dim: feats_dim,
+                    vocab_namespace: "feats",
+                },
             },
         },
         loss_weights: loss_weights,
@@ -244,7 +257,8 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
                 char_dim + projected_embedding_dim +
                 if in_features('xpostag') then xpostag_dim else 0 +
                 if in_features('lemma') then lemma_char_dim else 0 +
-                if in_features('upostag') then upostag_dim else 0,
+                if in_features('upostag') then upostag_dim else 0 +
+                if in_features('feats') then feats_dim else 0,
                 hidden_size: hidden_size,
                 num_layers: num_layers,
                 recurrent_dropout_probability: 0.33,
@@ -342,7 +356,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
             ],
         },
     }),
-    trainer: {
+    trainer: std.prune({
         checkpointer: {
             type: "finishing_only_checkpointer",
         },
@@ -362,12 +376,12 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
         learning_rate_scheduler: {
             type: "combo_scheduler",
         },
-        tensorboard_writer: {
+        tensorboard_writer: if use_tensorboard then {
             serialization_dir: metrics_dir,
             should_log_learning_rate: false,
             should_log_parameter_statistics: false,
             summary_interval: 100,
         },
         validation_metric: "+EM",
-    },
+    }),
 }
diff --git a/setup.py b/setup.py
index c7caeb3..bb402c8 100644
--- a/setup.py
+++ b/setup.py
@@ -21,5 +21,6 @@ setup(
     packages=find_packages(exclude=['tests']),
     setup_requires=['pytest-runner', 'pytest-pylint'],
     tests_require=['pytest', 'pylint'],
+    python_requires='>=3.6',
     entry_points={'console_scripts': ['combo = combo.main:main']},
 )
diff --git a/tests/test_main.py b/tests/test_main.py
index 325453c..8c42245 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -37,10 +37,10 @@ class TrainingEndToEndTest(unittest.TestCase):
             'cuda_device': '-1',
             'num_epochs': '1',
             'word_batch_size': '1',
+            'use_tensorboard': 'False'
         }
         params = Params.from_file(os.path.join(self.PROJECT_ROOT, 'config.template.jsonnet'),
                                   ext_vars=ext_vars)
-        params['trainer']['tensorboard_writer']['serialization_dir'] = os.path.join(self.TEST_DIR, 'metrics')
 
         # when
         model = train.train_model(params, serialization_dir=self.TEST_DIR)
-- 
GitLab