From dc482287a1ca3295b3e714246384286e09c5e29e Mon Sep 17 00:00:00 2001 From: Mateusz Klimaszewski <mk.klimaszewski@gmail.com> Date: Thu, 30 Jul 2020 09:32:04 +0200 Subject: [PATCH] Fix handling sentence embeddings. --- combo/predict.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/combo/predict.py b/combo/predict.py index 35dbbee..2e5693f 100644 --- a/combo/predict.py +++ b/combo/predict.py @@ -79,7 +79,7 @@ class SemanticMultitaskPredictor(predictor.Predictor): for prediction, instance in zip(predictions, instances): tree = self._predictions_as_tree(prediction, instance) if serialize: - tree = tree.serialize() + tree = self._serialize(tree) tree_json = util.sanitize(tree) trees.append(collections.OrderedDict([ ("tree", tree_json), @@ -89,11 +89,13 @@ class SemanticMultitaskPredictor(predictor.Predictor): @overrides def predict_instance(self, instance: allen_data.Instance, serialize: bool = True) -> common.JsonDict: tree = self.predict_instance_as_tree(instance) + sentence_embedding = tree.metadata.get("sentence_embedding", []) if serialize: - tree = tree.serialize() + tree = self._serialize(tree) tree_json = util.sanitize(tree) result = collections.OrderedDict([ ("tree", tree_json), + ("sentence_embedding", sentence_embedding) ]) return result @@ -118,6 +120,7 @@ class SemanticMultitaskPredictor(predictor.Predictor): result = collections.OrderedDict([ ("tree", tree_json), ]) + result["sentence_embedding"] = tree.metadata.get("sentence_embedding", []) return result def predict_instance_as_tree(self, instance: allen_data.Instance) -> conllu.TokenList: @@ -217,9 +220,15 @@ class SemanticMultitaskPredictor(predictor.Predictor): raise NotImplementedError(f"Unknown field name {field_name}!") if self._dataset_reader and "sent" in self._dataset_reader._targets: - tree.metadata["sentence_embedding"] = str(predictions["sentence_embedding"]) + tree.metadata["sentence_embedding"] = predictions["sentence_embedding"] return tree + @staticmethod + def _serialize(tree: conllu.TokenList): + if "sentence_embedding" in tree.metadata: + tree.metadata["sentence_embedding"] = str(tree.metadata["sentence_embedding"]) + return tree.serialize() + @classmethod def with_spacy_tokenizer(cls, model: models.Model, dataset_reader: allen_data.DatasetReader): -- GitLab