diff --git a/README.md b/README.md
index 5da02eb9a1bbffea048252e167a8f2df3cde2f03..9b88650adcda7a8d7635e5b058e4320286e42109 100644
--- a/README.md
+++ b/README.md
@@ -90,8 +90,8 @@ Use either `--predictor_name semantic-multitask-predictor` or `--predictor_name
 import combo.predict as predict
 
 model_path = "your_model.tar.gz"
-predictor = predict.SemanticMultitaskPredictor.from_pretrained(model_path)
-parsed_tree = predictor.predict_string("Sentence to parse.")["tree"]
+nlp = predict.SemanticMultitaskPredictor.from_pretrained(model_path)
+parsed_tree = nlp("Sentence to parse.")["tree"]
 ```
 
 ## Configuration
diff --git a/combo/data/dataset.py b/combo/data/dataset.py
index 56061db098b3452914ecf0b0173dac5ebb8482d4..ab5ce2aecb974a37a62248304281dadf4da8c9f2 100644
--- a/combo/data/dataset.py
+++ b/combo/data/dataset.py
@@ -6,6 +6,7 @@ from allennlp import data as allen_data
 from allennlp.common import checks
 from allennlp.data import fields as allen_fields, vocabulary
 from conllu import parser
+from dataclasses import dataclass
 from overrides import overrides
 
 from combo.data import fields
@@ -85,10 +86,11 @@
     @overrides
     def text_to_instance(self, tree: conllu.TokenList) -> allen_data.Instance:
         fields_: Dict[str, allen_data.Field] = {}
-        tokens = [allen_data.Token(t['token'],
-                                   pos_=t.get('upostag'),
-                                   tag_=t.get('xpostag'),
-                                   lemma_=t.get('lemma'))
+        tokens = [Token(t['token'],
+                        pos_=t.get('upostag'),
+                        tag_=t.get('xpostag'),
+                        lemma_=t.get('lemma'),
+                        feats_=t.get('feats'))
                   for t in tree]
 
         # features
@@ -228,3 +230,8 @@ def get_slices_if_not_provided(vocab: allen_data.Vocabulary):
             slices[name] = [idx]
     vocab.slices = slices
     return vocab.slices
+
+
+@dataclass
+class Token(allen_data.Token):
+    feats_: Optional[str] = None
diff --git a/combo/data/token_indexers/__init__.py b/combo/data/token_indexers/__init__.py
index 6b2a7b7510477a68af50c60d40cc21eef3d863bc..1b918b3ad66692a761b564c08d0c270745a263cb 100644
--- a/combo/data/token_indexers/__init__.py
+++ b/combo/data/token_indexers/__init__.py
@@ -1 +1,2 @@
 from .token_characters_indexer import TokenCharactersIndexer
+from .token_features_indexer import TokenFeatsIndexer
diff --git a/combo/data/token_indexers/token_features_indexer.py b/combo/data/token_indexers/token_features_indexer.py
new file mode 100644
index 0000000000000000000000000000000000000000..eac755bd4b9503c79ec0656d885c734a79541c18
--- /dev/null
+++ b/combo/data/token_indexers/token_features_indexer.py
@@ -0,0 +1,71 @@
+"""Features indexer."""
+import collections
+from typing import List, Dict
+
+import torch
+from allennlp import data
+from allennlp.common import util
+from overrides import overrides
+
+
+@data.TokenIndexer.register('feats_indexer')
+class TokenFeatsIndexer(data.TokenIndexer):
+
+    def __init__(
+            self,
+            namespace: str = "feats",
+            feature_name: str = "feats_",
+            token_min_padding_length: int = 0,
+    ) -> None:
+        super().__init__(token_min_padding_length)
+        self.namespace = namespace
+        self._feature_name = feature_name
+
+    @overrides
+    def count_vocab_items(self, token: data.Token, counter: Dict[str, Dict[str, int]]):
+        feats = self._feat_values(token)
+        for feat in feats:
+            counter[self.namespace][feat] += 1
+
+    @overrides
+    def tokens_to_indices(self, tokens: List[data.Token], vocabulary: data.Vocabulary) -> data.IndexedTokenList:
+        indices: List[List[int]] = []
+        vocab_size = vocabulary.get_vocab_size(self.namespace)
+        for token in tokens:
+            token_indices = []
+            feats = self._feat_values(token)
+            for feat in feats:
+                token_indices.append(vocabulary.get_token_index(feat, self.namespace))
+            indices.append(util.pad_sequence_to_length(token_indices, vocab_size))
+        return {"tokens": indices}
+
+    @overrides
+    def get_empty_token_list(self) -> data.IndexedTokenList:
+        return {"tokens": [[]]}
+
+    def _feat_values(self, token):
+        feats = getattr(token, self._feature_name)
+        if feats is None:
+            feats = collections.OrderedDict()
+        features = []
+        for feat, value in feats.items():
+            if feat in ['_', '__ROOT__']:
+                pass
+            else:
+                features.append(feat + '=' + value)
+        return features
+
+    @overrides
+    def as_padded_tensor_dict(
+            self, tokens: data.IndexedTokenList, padding_lengths: Dict[str, int]
+    ) -> Dict[str, torch.Tensor]:
+        tensor_dict = {}
+        for key, val in tokens.items():
+            vocab_size = len(val[0])
+            tensor = torch.tensor(util.pad_sequence_to_length(val,
+                                                              padding_lengths[key],
+                                                              default_value=lambda: [0] * vocab_size,
+                                                              )
+                                  )
+            tensor_dict[key] = tensor
+        return tensor_dict
diff --git a/combo/main.py b/combo/main.py
index 4bd9d65afa9762cc9ceb9e05df4cac466bd16a5e..d9951362346b1d08ef2fa2bc11aff9781a82d12c 100644
--- a/combo/main.py
+++ b/combo/main.py
@@ -45,13 +45,15 @@ flags.DEFINE_string(name='pretrained_transformer_name', default='',
                     help='Pretrained transformer model name (see transformers from HuggingFace library for list of'
                          'available models) for transformers based embeddings.')
 flags.DEFINE_multi_enum(name='features', default=['token', 'char'],
-                        enum_values=['token', 'char', 'upostag', 'xpostag', 'lemma'],
+                        enum_values=['token', 'char', 'upostag', 'xpostag', 'lemma', 'feats'],
                         help='Features used to train model (required `token` and `char`)')
 flags.DEFINE_multi_enum(name='targets', default=['deprel', 'feats', 'head', 'lemma', 'upostag', 'xpostag'],
                         enum_values=['deprel', 'feats', 'head', 'lemma', 'upostag', 'xpostag', 'semrel', 'sent'],
                         help='Targets of the model (required `deprel` and `head`)')
 flags.DEFINE_string(name='serialization_dir', default=None,
                     help='Model serialization directory (default - system temp dir).')
+flags.DEFINE_boolean(name='tensorboard', default=False,
+                     help='When provided, model will log tensorboard metrics.')
 
 # Finetune after training flags
 flags.DEFINE_string(name='finetuning_training_data_path', default='',
@@ -145,7 +147,7 @@ def run(_):
             FLAGS.input_file,
             FLAGS.output_file,
             FLAGS.batch_size,
-            FLAGS.silent,
+            not FLAGS.silent,
             use_dataset_reader,
         )
         manager.run()
@@ -181,7 +183,8 @@ def _get_ext_vars(finetuning: bool = False) -> Dict:
         'embedding_dim': str(FLAGS.embedding_dim),
         'cuda_device': str(FLAGS.cuda_device),
         'num_epochs': str(FLAGS.num_epochs),
-        'word_batch_size': str(FLAGS.word_batch_size)
+        'word_batch_size': str(FLAGS.word_batch_size),
+        'use_tensorboard': str(FLAGS.tensorboard),
     }
 
 
diff --git a/combo/models/embeddings.py b/combo/models/embeddings.py
index 15cd9ec14b8ab9ab43e900a45963dc2af7f75ab2..fe407751d970201147f8f8f8bb011c35e6b4530c 100644
--- a/combo/models/embeddings.py
+++ b/combo/models/embeddings.py
@@ -3,10 +3,10 @@ from typing import Optional
 
 import torch
 import torch.nn as nn
-from allennlp import nn as allen_nn, data
+from allennlp import nn as allen_nn, data, modules
 from allennlp.modules import token_embedders
+from allennlp.nn import util
 from overrides import overrides
-from transformers import modeling_auto
 
 from combo.models import base, dilated_cnn
 
@@ -25,7 +25,7 @@ class CharacterBasedWordEmbeddings(token_embedders.TokenEmbedder):
             num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
         )
-        self.dilated_cnn_encoder = dilated_cnn_encoder
+        self.dilated_cnn_encoder = modules.TimeDistributed(dilated_cnn_encoder)
         self.output_dim = embedding_dim
 
     def forward(self,
@@ -36,16 +36,8 @@ class CharacterBasedWordEmbeddings(token_embedders.TokenEmbedder):
 
         x = self.char_embed(x)
         x = x * char_mask.unsqueeze(-1).float()
-
-        BATCH_SIZE, SENTENCE_LENGTH, MAX_WORD_LENGTH, CHAR_EMB = x.size()
-
-        words = []
-        for i in range(SENTENCE_LENGTH):
-            word = x[:, i, :, :].reshape(BATCH_SIZE, MAX_WORD_LENGTH, CHAR_EMB).transpose(1, 2)
-            word = self.dilated_cnn_encoder(word)
-            word, _ = torch.max(word, dim=2)
-            words.append(word)
-        return torch.stack(words, dim=1)
+        x = self.dilated_cnn_encoder(x.transpose(2, 3))
+        return torch.max(x, dim=-1)[0]
 
     @overrides
     def get_output_dim(self) -> int:
@@ -120,7 +112,9 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEm
                  projection_dropout_rate: Optional[float] = 0.0,
                  freeze_transformer: bool = True):
         super().__init__(model_name)
-        if freeze_transformer:
+        self.freeze_transformer = freeze_transformer
+        if self.freeze_transformer:
+            self._matched_embedder.eval()
             for param in self._matched_embedder.parameters():
                 param.requires_grad = False
         if projection_dim:
@@ -154,8 +148,56 @@ class TransformersWordEmbedder(token_embedders.PretrainedTransformerMismatchedEm
 
     @overrides
     def train(self, mode: bool):
-        self.projection_layer.train(mode)
+        if self.freeze_transformer:
+            self.projection_layer.train(mode)
+        else:
+            super().train(mode)
 
     @overrides
     def eval(self):
-        self.projection_layer.eval()
+        if self.freeze_transformer:
+            self.projection_layer.eval()
+        else:
+            super().eval()
+
+
+@token_embedders.TokenEmbedder.register("feats_embedding")
+class FeatsTokenEmbedder(token_embedders.Embedding):
+
+    def __init__(self,
+                 embedding_dim: int,
+                 num_embeddings: int = None,
+                 weight: torch.FloatTensor = None,
+                 padding_index: int = None,
+                 trainable: bool = True,
+                 max_norm: float = None,
+                 norm_type: float = 2.0,
+                 scale_grad_by_freq: bool = False,
+                 sparse: bool = False,
+                 vocab_namespace: str = "feats",
+                 pretrained_file: str = None,
+                 vocab: data.Vocabulary = None):
+        super().__init__(
+            embedding_dim=embedding_dim,
+            num_embeddings=num_embeddings,
+            weight=weight,
+            padding_index=padding_index,
+            trainable=trainable,
+            max_norm=max_norm,
+            norm_type=norm_type,
+            scale_grad_by_freq=scale_grad_by_freq,
+            sparse=sparse,
+            vocab_namespace=vocab_namespace,
+            pretrained_file=pretrained_file,
+            vocab=vocab
+        )
+
+    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
+        # (batch_size, sentence_length, features_vocab_length)
+        mask = (tokens > 0).float()
+        # (batch_size, sentence_length, features_vocab_length, embedding_dim)
+        x = super().forward(tokens)
+        # (batch_size, sentence_length, embedding_dim)
+        return x.sum(dim=-2) / (
+            (mask.sum(dim=-1) + util.tiny_value_of_dtype(mask.dtype)).unsqueeze(dim=-1)
+        )
diff --git a/combo/models/lemma.py b/combo/models/lemma.py
index fecab0b369294ce644316f1b36ae570dfbaaba64..0df4bf7e5a1789129b4f060a2922f1f931cfc46b 100644
--- a/combo/models/lemma.py
+++ b/combo/models/lemma.py
@@ -3,7 +3,7 @@ from typing import Optional, Dict, List, Union
 
 import torch
 import torch.nn as nn
-from allennlp import data, nn as allen_nn
+from allennlp import data, nn as allen_nn, modules
 from allennlp.common import checks
 
 from combo.models import base, dilated_cnn, utils
@@ -23,7 +23,7 @@ class LemmatizerModel(base.Predictor):
             num_embeddings=num_embeddings,
             embedding_dim=embedding_dim,
         )
-        self.dilated_cnn_encoder = dilated_cnn_encoder
+        self.dilated_cnn_encoder = modules.TimeDistributed(dilated_cnn_encoder)
         self.input_projection_layer = input_projection_layer
 
     def forward(self,
@@ -36,20 +36,11 @@ class LemmatizerModel(base.Predictor):
         encoder_emb = self.input_projection_layer(encoder_emb)
 
         char_embeddings = self.char_embed(chars)
-        BATCH_SIZE, SENTENCE_LENGTH, WORD_EMB = encoder_emb.size()
-        _, _, MAX_WORD_LENGTH, CHAR_EMB = char_embeddings.size()
-
-
+        BATCH_SIZE, _, MAX_WORD_LENGTH, CHAR_EMB = char_embeddings.size()
         encoder_emb = encoder_emb.unsqueeze(2).repeat(1, 1, MAX_WORD_LENGTH, 1)
 
-        pred = []
-        for i in range(SENTENCE_LENGTH):
-            word_emb = (encoder_emb[:, i, :, :].reshape(BATCH_SIZE, MAX_WORD_LENGTH, -1))
-            char_sent_emb = char_embeddings[:, i, :].reshape(BATCH_SIZE, MAX_WORD_LENGTH, CHAR_EMB)
-            x = torch.cat((char_sent_emb, word_emb), -1).transpose(2, 1)
-            x = self.dilated_cnn_encoder(x)
-            pred.append(x)
-        x = torch.stack(pred, dim=1).transpose(2, 3)
+        x = torch.cat((char_embeddings, encoder_emb), dim=-1).transpose(2, 3)
+        x = self.dilated_cnn_encoder(x).transpose(2, 3)
         output = {
             'prediction': x.argmax(-1),
             'probability': x
diff --git a/combo/predict.py b/combo/predict.py
index 7c7bb8f36c4ad5c66790c30611f8420456c11a50..b30da436572f93edde9548dd40f5acff946c0fd5 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -57,9 +57,12 @@ class SemanticMultitaskPredictor(predictor.Predictor):
         logger.info('Took {} ms'.format((end_time - start_time) * 1000.0))
         return result
 
-    def predict_string(self, sentence: str):
+    def predict(self, sentence: str):
         return self.predict_json({'sentence': sentence})
 
+    def __call__(self, sentence: str):
+        return self.predict(sentence)
+
     @overrides
     def predict_json(self, inputs: common.JsonDict) -> common.JsonDict:
         start_time = time.time()
diff --git a/combo/training/tensorboard_writer.py b/combo/training/tensorboard_writer.py
new file mode 100644
index 0000000000000000000000000000000000000000..83a2f801adfcc637fad432ee79888bcae66bb282
--- /dev/null
+++ b/combo/training/tensorboard_writer.py
@@ -0,0 +1,68 @@
+from typing import Dict, Optional, List
+
+import torch
+from allennlp import models, common
+from allennlp.data import dataloader
+from allennlp.training import optimizers
+
+
+class NullTensorboardWriter(common.FromParams):
+
+    def log_batch(
+            self,
+            model: models.Model,
+            optimizer: optimizers.Optimizer,
+            batch_grad_norm: Optional[float],
+            metrics: Dict[str, float],
+            batch_group: List[List[dataloader.TensorDict]],
+            param_updates: Optional[Dict[str, torch.Tensor]],
+    ) -> None:
+        pass
+
+    def reset_epoch(self) -> None:
+        pass
+
+    def should_log_this_batch(self) -> bool:
+        return False
+
+    def should_log_histograms_this_batch(self) -> bool:
+        return False
+
+    def add_train_scalar(self, name: str, value: float, timestep: int = None) -> None:
+        pass
+
+    def add_train_histogram(self, name: str, values: torch.Tensor) -> None:
+        pass
+
+    def add_validation_scalar(self, name: str, value: float, timestep: int = None) -> None:
+        pass
+
+    def log_parameter_and_gradient_statistics(self, model: models.Model, batch_grad_norm: float) -> None:
+        pass
+
+    def log_learning_rates(self, model: models.Model, optimizer: torch.optim.Optimizer):
+        pass
+
+    def log_histograms(self, model: models.Model) -> None:
+        pass
+
+    def log_gradient_updates(self, model: models.Model, param_updates: Dict[str, torch.Tensor]) -> None:
+        pass
+
+    def log_metrics(
+            self,
+            train_metrics: dict,
+            val_metrics: dict = None,
+            epoch: int = None,
+            log_to_console: bool = False,
+    ) -> None:
+        pass
+
+    def enable_activation_logging(self, model: models.Model) -> None:
+        pass
+
+    def log_activation_histogram(self, outputs, log_prefix: str) -> None:
+        pass
+
+    def close(self) -> None:
+        pass
diff --git a/combo/training/trainer.py b/combo/training/trainer.py
index 01773c394199c097214e9887e872e77580cfad54..330cf53e1d7921bf1550846abc832a67776b7795 100644
--- a/combo/training/trainer.py
+++ b/combo/training/trainer.py
@@ -10,18 +10,20 @@ import torch.distributed as dist
 import torch.optim as optim
 import torch.optim.lr_scheduler
 import torch.utils.data as data
-from allennlp import training
+from allennlp import training, common
 from allennlp.common import checks
 from allennlp.common import util as common_util
 from allennlp.models import model
-from allennlp.training import checkpointer
+from allennlp.training import checkpointer, optimizers
 from allennlp.training import learning_rate_schedulers
 from allennlp.training import momentum_schedulers
 from allennlp.training import moving_average
-from allennlp.training import tensorboard_writer
+from allennlp.training import tensorboard_writer as allen_tensorboard_writer
 from allennlp.training import util as training_util
 from overrides import overrides
 
+from combo.training import tensorboard_writer as combo_tensorboard_writer
+
 logger = logging.getLogger(__name__)
 
 
@@ -47,13 +49,12 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
                  grad_norm: Optional[float] = None, grad_clipping: Optional[float] = None,
                  learning_rate_scheduler: Optional[learning_rate_schedulers.LearningRateScheduler] = None,
                  momentum_scheduler: Optional[momentum_schedulers.MomentumScheduler] = None,
-                 tensorboard_writer: tensorboard_writer.TensorboardWriter = None,
+                 tensorboard_writer: allen_tensorboard_writer.TensorboardWriter = None,
                  moving_average: Optional[moving_average.MovingAverage] = None,
                  batch_callbacks: List[training.BatchCallback] = None,
                  epoch_callbacks: List[training.EpochCallback] = None,
                  distributed: bool = False, local_rank: int = 0, world_size: int = 1,
                  num_gradient_accumulation_steps: int = 1, opt_level: Optional[str] = None) -> None:
-
         super().__init__(model, optimizer, data_loader, patience, validation_metric, validation_data_loader,
                          num_epochs, serialization_dir, checkpointer, cuda_device, grad_norm, grad_clipping,
                          learning_rate_scheduler, momentum_scheduler, tensorboard_writer, moving_average,
@@ -211,3 +212,60 @@ class GradientDescentTrainer(training.GradientDescentTrainer):
             self.model.load_state_dict(best_model_state)
 
         return metrics
+
+    @classmethod
+    def from_partial_objects(
+            cls,
+            model: model.Model,
+            serialization_dir: str,
+            data_loader: data.DataLoader,
+            validation_data_loader: data.DataLoader = None,
+            local_rank: int = 0,
+            patience: int = None,
+            validation_metric: str = "-loss",
+            num_epochs: int = 20,
+            cuda_device: int = -1,
+            grad_norm: float = None,
+            grad_clipping: float = None,
+            distributed: bool = None,
+            world_size: int = 1,
+            num_gradient_accumulation_steps: int = 1,
+            opt_level: Optional[str] = None,
+            no_grad: List[str] = None,
+            optimizer: common.Lazy[optimizers.Optimizer] = None,
+            learning_rate_scheduler: common.Lazy[learning_rate_schedulers.LearningRateScheduler] = None,
+            momentum_scheduler: common.Lazy[momentum_schedulers.MomentumScheduler] = None,
+            tensorboard_writer: common.Lazy[allen_tensorboard_writer.TensorboardWriter] = None,
+            moving_average: common.Lazy[moving_average.MovingAverage] = None,
+            checkpointer: common.Lazy[training.Checkpointer] = None,
+            batch_callbacks: List[training.BatchCallback] = None,
+            epoch_callbacks: List[training.EpochCallback] = None,
+    ) -> "training.Trainer":
+        if tensorboard_writer.construct() is None:
+            tensorboard_writer = common.Lazy(combo_tensorboard_writer.NullTensorboardWriter)
+        return super().from_partial_objects(
+            model=model,
+            serialization_dir=serialization_dir,
+            data_loader=data_loader,
+            validation_data_loader=validation_data_loader,
+            local_rank=local_rank,
+            patience=patience,
+            validation_metric=validation_metric,
+            num_epochs=num_epochs,
+            cuda_device=cuda_device,
+            grad_norm=grad_norm,
+            grad_clipping=grad_clipping,
+            distributed=distributed,
+            world_size=world_size,
+            num_gradient_accumulation_steps=num_gradient_accumulation_steps,
+            opt_level=opt_level,
+            no_grad=no_grad,
+            optimizer=optimizer,
+            learning_rate_scheduler=learning_rate_scheduler,
+            momentum_scheduler=momentum_scheduler,
+            tensorboard_writer=tensorboard_writer,
+            moving_average=moving_average,
+            checkpointer=checkpointer,
+            batch_callbacks=batch_callbacks,
+            epoch_callbacks=epoch_callbacks,
+        )
diff --git a/config.template.jsonnet b/config.template.jsonnet
index ae99eb37256ffd9fce5848e4fe0fbf9a401e4bbf..0f12d8ff5f4a70526bba3356db698a52723aa161 100644
--- a/config.template.jsonnet
+++ b/config.template.jsonnet
@@ -28,8 +28,6 @@ local features = std.split(std.extVar("features"), " ");
 # Choice "feats", "lemma", "upostag", "xpostag", "semrel". "sent"
 # Required "deprel", "head"
 local targets = std.split(std.extVar("targets"), " ");
-# Path for tensorboard metrics, str
-local metrics_dir = "./runs";
 # Word embedding dimension, int
 # If pretrained_tokens is not null must much provided dimensionality
 local embedding_dim = std.parseInt(std.extVar("embedding_dim"));
@@ -42,6 +40,9 @@ local xpostag_dim = 100;
 # Upostag embedding dimension, int
 # (discarded if upostag not in features)
 local upostag_dim = 100;
+# Feats embedding dimension, int
+# (discarded if feats not in features)
+local feats_dim = 100;
 # Lemma embedding dimension, int
 # (discarded if lemma not in features)
 local lemma_char_dim = 64;
@@ -67,7 +68,10 @@ local cycle_loss_n = 0;
 # Maximum length of the word, int
 # Shorter words are padded, longer - truncated
 local word_length = 30;
-
+# Whether to use tensorboard, bool
+local use_tensorboard = if std.extVar("use_tensorboard") == "True" then true else false;
+# Path for tensorboard metrics, str
+local metrics_dir = "./runs";
 
 # Helper functions
 local in_features(name) = !(std.length(std.find(name, features)) == 0);
@@ -141,6 +145,9 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
                 # +2 for start and end token
                 min_padding_length: word_length + 2,
             },
+            feats: {
+                type: "feats_indexer",
+            },
         },
         lemma_indexers: {
             char: {
@@ -233,6 +240,12 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
                         activations: ["relu", "relu", "linear"],
                     },
                 },
+                feats: if in_features("feats") then {
+                    type: "feats_embedding",
+                    padding_index: 0,
+                    embedding_dim: feats_dim,
+                    vocab_namespace: "feats",
+                },
             },
         },
         loss_weights: loss_weights,
@@ -244,7 +257,8 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
             char_dim + projected_embedding_dim +
            if in_features('xpostag') then xpostag_dim else 0 +
            if in_features('lemma') then lemma_char_dim else 0 +
-           if in_features('upostag') then upostag_dim else 0,
+           if in_features('upostag') then upostag_dim else 0 +
+           if in_features('feats') then feats_dim else 0,
             hidden_size: hidden_size,
             num_layers: num_layers,
             recurrent_dropout_probability: 0.33,
@@ -342,7 +356,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
             ],
         },
     }),
-    trainer: {
+    trainer: std.prune({
         checkpointer: {
             type: "finishing_only_checkpointer",
         },
@@ -362,12 +376,12 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
         learning_rate_scheduler: {
             type: "combo_scheduler",
        },
-        tensorboard_writer: {
+        tensorboard_writer: if use_tensorboard then {
            serialization_dir: metrics_dir,
            should_log_learning_rate: false,
            should_log_parameter_statistics: false,
            summary_interval: 100,
        },
        validation_metric: "+EM",
-    },
+    }),
 }
diff --git a/setup.py b/setup.py
index c7caeb3b98e89a09d0737a809e2c8af865f65a94..bb402c8f9dae2c06df51b9b94aa571ad0c0be2dc 100644
--- a/setup.py
+++ b/setup.py
@@ -21,5 +21,6 @@ setup(
     packages=find_packages(exclude=['tests']),
     setup_requires=['pytest-runner', 'pytest-pylint'],
     tests_require=['pytest', 'pylint'],
+    python_requires='>=3.6',
     entry_points={'console_scripts': ['combo = combo.main:main']},
 )
diff --git a/tests/test_main.py b/tests/test_main.py
index 325453c736c016c8dca322e3e8cb9bfbc404f5be..8c422451111583c4201e784a592a4847b62aa6ff 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -37,10 +37,10 @@ class TrainingEndToEndTest(unittest.TestCase):
             'cuda_device': '-1',
             'num_epochs': '1',
             'word_batch_size': '1',
+            'use_tensorboard': 'False'
         }
         params = Params.from_file(os.path.join(self.PROJECT_ROOT, 'config.template.jsonnet'),
                                   ext_vars=ext_vars)
-        params['trainer']['tensorboard_writer']['serialization_dir'] = os.path.join(self.TEST_DIR, 'metrics')
 
         # when
         model = train.train_model(params, serialization_dir=self.TEST_DIR)
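
Not part of the patch: a minimal, self-contained sketch (plain PyTorch, with a toy vocabulary and feature values invented for illustration) of what the new feats pipeline above is intended to do — `TokenFeatsIndexer` maps each `feat=value` string to a vocabulary index and zero-pads the per-token list, and `FeatsTokenEmbedder` mean-pools the embeddings of the non-padding entries. The `1e-13` stabilizer stands in for `util.tiny_value_of_dtype`.

```python
# Illustrative sketch only: mirrors the logic of TokenFeatsIndexer.tokens_to_indices
# and FeatsTokenEmbedder.forward with plain PyTorch and a made-up feats vocabulary.
import torch
import torch.nn as nn

# Hypothetical feats vocabulary; index 0 is reserved for padding.
feats_vocab = {"@@PADDING@@": 0, "Case=Nom": 1, "Number=Sing": 2, "Gender=Fem": 3}
vocab_size = len(feats_vocab)

# One sentence, two tokens: the first has two features, the second has one.
token_feats = [["Case=Nom", "Number=Sing"], ["Gender=Fem"]]

# Indexer step: map each "feat=value" string to its index, pad the list to vocab_size.
indices = [
    [feats_vocab[f] for f in feats] + [0] * (vocab_size - len(feats))
    for feats in token_feats
]
tokens = torch.tensor([indices])  # (batch=1, sentence_length=2, vocab_size)

# Embedder step: padding_idx=0 keeps the padding row at zero, so summing over the
# feature dimension and dividing by the number of real features is a mean-pool.
embedding = nn.Embedding(vocab_size, embedding_dim=8, padding_idx=0)
mask = (tokens > 0).float()                                   # (1, 2, vocab_size)
x = embedding(tokens)                                         # (1, 2, vocab_size, 8)
pooled = x.sum(dim=-2) / (mask.sum(dim=-1) + 1e-13).unsqueeze(-1)
print(pooled.shape)  # torch.Size([1, 2, 8])
```

The pooled per-token vector is what the `feats_dim` term added to the encoder `input_size` in `config.template.jsonnet` accounts for when `feats` is enabled as a feature.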