diff --git a/combo/predict.py b/combo/predict.py index 35dbbee3af5720b110c37cee4c4e12f72243c8c4..2e5693f75a635915d4e7ed4f7c063bcda21fa8e2 100644 --- a/combo/predict.py +++ b/combo/predict.py @@ -79,7 +79,7 @@ class SemanticMultitaskPredictor(predictor.Predictor): for prediction, instance in zip(predictions, instances): tree = self._predictions_as_tree(prediction, instance) if serialize: - tree = tree.serialize() + tree = self._serialize(tree) tree_json = util.sanitize(tree) trees.append(collections.OrderedDict([ ("tree", tree_json), @@ -89,11 +89,13 @@ class SemanticMultitaskPredictor(predictor.Predictor): @overrides def predict_instance(self, instance: allen_data.Instance, serialize: bool = True) -> common.JsonDict: tree = self.predict_instance_as_tree(instance) + sentence_embedding = tree.metadata.get("sentence_embedding", []) if serialize: - tree = tree.serialize() + tree = self._serialize(tree) tree_json = util.sanitize(tree) result = collections.OrderedDict([ ("tree", tree_json), + ("sentence_embedding", sentence_embedding) ]) return result @@ -118,6 +120,7 @@ class SemanticMultitaskPredictor(predictor.Predictor): result = collections.OrderedDict([ ("tree", tree_json), ]) + result["sentence_embedding"] = tree.metadata.get("sentence_embedding", []) return result def predict_instance_as_tree(self, instance: allen_data.Instance) -> conllu.TokenList: @@ -217,9 +220,15 @@ class SemanticMultitaskPredictor(predictor.Predictor): raise NotImplementedError(f"Unknown field name {field_name}!") if self._dataset_reader and "sent" in self._dataset_reader._targets: - tree.metadata["sentence_embedding"] = str(predictions["sentence_embedding"]) + tree.metadata["sentence_embedding"] = predictions["sentence_embedding"] return tree + @staticmethod + def _serialize(tree: conllu.TokenList): + if "sentence_embedding" in tree.metadata: + tree.metadata["sentence_embedding"] = str(tree.metadata["sentence_embedding"]) + return tree.serialize() + @classmethod def with_spacy_tokenizer(cls, model: models.Model, dataset_reader: allen_data.DatasetReader):