From de9301ada5a70cac4da84282c7a2b8b44acbb147 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Martyna=20Wi=C4=85cek?= <martyna.wiacek@ipipan.waw.pl>
Date: Tue, 11 Apr 2023 22:49:10 +0200
Subject: [PATCH] Predict and save deprel matrix

---
 combo/data/api.py | 34 +++++++++++++++++++++++++++-------
 combo/main.py     |  4 ++++
 combo/predict.py  |  4 +++-
 3 files changed, 34 insertions(+), 8 deletions(-)

diff --git a/combo/data/api.py b/combo/data/api.py
index adb1b27..deec11f 100644
--- a/combo/data/api.py
+++ b/combo/data/api.py
@@ -1,7 +1,13 @@
 import collections
 import dataclasses
 import json
+import os
+import string
 from dataclasses import dataclass, field
+from json import JSONEncoder
+import random
+
+import numpy
 from typing import Optional, List, Dict, Any, Union, Tuple
 
 import conllu
@@ -32,14 +38,28 @@ class Sentence:
     relation_label_distribution: List[float] = field(default_factory=list, repr=False)
     metadata: Dict[str, Any] = field(default_factory=collections.OrderedDict)
 
-    def to_json(self):
+    def to_json(self, save_relation_distribution_path=None):
+        class NumpyArrayEncoder(JSONEncoder):
+            def default(self, obj):
+                if isinstance(obj, numpy.ndarray):
+                    return obj.tolist()
+                return JSONEncoder.default(self, obj)
+
+        hash = ''.join(random.sample(string.ascii_letters + string.digits, 32))
+        numpy.savez(
+            os.path.join(save_relation_distribution_path, hash + '.npz'),
+            relation_distribution=self.relation_distribution,
+            relation_distribution_trimmed=self.relation_distribution[1:, 1:],
+            relation_label_distribution=self.relation_label_distribution
+        )
+        # numpy.savez(hash + '_trimmed.npz', self.relation_distribution[1:, 1:])
+
         return json.dumps({
-            "tokens": [dataclasses.asdict(t) for t in self.tokens],
-            "sentence_embedding": self.sentence_embedding,
-            "relation_distribution": self.relation_distribution,
-            "relation_label_distribution": self.relation_label_distribution,
-            "metadata": self.metadata,
-        })
+            "tokens": [(t.token, t.lemma, t.upostag, t.xpostag, t. feats, t.head, t.deprel) for t in self.tokens],
+            # "sentence_embedding": self.sentence_embedding,
+            "relation_distribution_hash": hash,
+            "path_file": str(os.path.join(save_relation_distribution_path, hash + '.npz'))
+        }, cls=NumpyArrayEncoder)
 
     def __len__(self):
         return len(self.tokens)
diff --git a/combo/main.py b/combo/main.py
index d1e0292..f319ae1 100644
--- a/combo/main.py
+++ b/combo/main.py
@@ -88,6 +88,8 @@ flags.DEFINE_boolean(name="silent", default=True,
 flags.DEFINE_enum(name="predictor_name", default="combo-spacy",
                   enum_values=["combo", "combo-spacy"],
                   help="Use predictor with whitespace or spacy tokenizer.")
+flags.DEFINE_string(name="save_relation_distribution_path", default=None,
+                     help="Save relation distribution to file.")
 
 
 def run(_):
@@ -153,6 +155,8 @@ def run(_):
             use_dataset_reader = False
             predictor.without_sentence_embedding = True
             predictor.line_to_conllu = False
+        predictor.line_to_conllu = False
+        predictor.save_relation_distribution_path = FLAGS.save_relation_distribution_path
         if FLAGS.silent:
             logging.getLogger("allennlp.common.params").disabled = True
         manager = allen_predict._PredictManager(
diff --git a/combo/predict.py b/combo/predict.py
index cdcd694..ef44be2 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -134,7 +134,9 @@ class COMBO(predictor.Predictor):
         if self.line_to_conllu:
             return sentence2conllu(outputs, keep_semrel=self.dataset_reader.use_sem).serialize()
         else:
-            return outputs.to_json()
+            if not os.path.exists(self.save_relation_distribution_path):
+                os.makedirs(self.save_relation_distribution_path)
+            return outputs.to_json(self.save_relation_distribution_path)
 
     @staticmethod
     def _to_input_json(sentence: str):
-- 
GitLab