From de9301ada5a70cac4da84282c7a2b8b44acbb147 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Martyna=20Wi=C4=85cek?= <martyna.wiacek@ipipan.waw.pl> Date: Tue, 11 Apr 2023 22:49:10 +0200 Subject: [PATCH] Predict and save deprel matrix --- combo/data/api.py | 34 +++++++++++++++++++++++++++------- combo/main.py | 4 ++++ combo/predict.py | 4 +++- 3 files changed, 34 insertions(+), 8 deletions(-) diff --git a/combo/data/api.py b/combo/data/api.py index adb1b27..deec11f 100644 --- a/combo/data/api.py +++ b/combo/data/api.py @@ -1,7 +1,13 @@ import collections import dataclasses import json +import os +import string from dataclasses import dataclass, field +from json import JSONEncoder +import random + +import numpy from typing import Optional, List, Dict, Any, Union, Tuple import conllu @@ -32,14 +38,28 @@ class Sentence: relation_label_distribution: List[float] = field(default_factory=list, repr=False) metadata: Dict[str, Any] = field(default_factory=collections.OrderedDict) - def to_json(self): + def to_json(self, save_relation_distribution_path=None): + class NumpyArrayEncoder(JSONEncoder): + def default(self, obj): + if isinstance(obj, numpy.ndarray): + return obj.tolist() + return JSONEncoder.default(self, obj) + + hash = ''.join(random.sample(string.ascii_letters + string.digits, 32)) + numpy.savez( + os.path.join(save_relation_distribution_path, hash + '.npz'), + relation_distribution=self.relation_distribution, + relation_distribution_trimmed=self.relation_distribution[1:, 1:], + relation_label_distribution=self.relation_label_distribution + ) + # numpy.savez(hash + '_trimmed.npz', self.relation_distribution[1:, 1:]) + return json.dumps({ - "tokens": [dataclasses.asdict(t) for t in self.tokens], - "sentence_embedding": self.sentence_embedding, - "relation_distribution": self.relation_distribution, - "relation_label_distribution": self.relation_label_distribution, - "metadata": self.metadata, - }) + "tokens": [(t.token, t.lemma, t.upostag, t.xpostag, t. feats, t.head, t.deprel) for t in self.tokens], + # "sentence_embedding": self.sentence_embedding, + "relation_distribution_hash": hash, + "path_file": str(os.path.join(save_relation_distribution_path, hash + '.npz')) + }, cls=NumpyArrayEncoder) def __len__(self): return len(self.tokens) diff --git a/combo/main.py b/combo/main.py index d1e0292..f319ae1 100644 --- a/combo/main.py +++ b/combo/main.py @@ -88,6 +88,8 @@ flags.DEFINE_boolean(name="silent", default=True, flags.DEFINE_enum(name="predictor_name", default="combo-spacy", enum_values=["combo", "combo-spacy"], help="Use predictor with whitespace or spacy tokenizer.") +flags.DEFINE_string(name="save_relation_distribution_path", default=None, + help="Save relation distribution to file.") def run(_): @@ -153,6 +155,8 @@ def run(_): use_dataset_reader = False predictor.without_sentence_embedding = True predictor.line_to_conllu = False + predictor.line_to_conllu = False + predictor.save_relation_distribution_path = FLAGS.save_relation_distribution_path if FLAGS.silent: logging.getLogger("allennlp.common.params").disabled = True manager = allen_predict._PredictManager( diff --git a/combo/predict.py b/combo/predict.py index cdcd694..ef44be2 100644 --- a/combo/predict.py +++ b/combo/predict.py @@ -134,7 +134,9 @@ class COMBO(predictor.Predictor): if self.line_to_conllu: return sentence2conllu(outputs, keep_semrel=self.dataset_reader.use_sem).serialize() else: - return outputs.to_json() + if not os.path.exists(self.save_relation_distribution_path): + os.makedirs(self.save_relation_distribution_path) + return outputs.to_json(self.save_relation_distribution_path) @staticmethod def _to_input_json(sentence: str): -- GitLab