From cffe54c397a9be9d1927e3433e03773f9b7f8d5b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Martyna=20Wi=C4=85cek?= <martyna.wiacek@ipipan.waw.pl>
Date: Thu, 23 Nov 2023 22:03:36 +0100
Subject: [PATCH] fixed passing relation_distribution and
 relation_label_distribution

---
 combo/predict.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/combo/predict.py b/combo/predict.py
index 8363e50..74c1ca5 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -109,7 +109,8 @@ class COMBO(PredictorModule):
              relation_distribution, relation_label_distribution) = self._predictions_as_tree(prediction, instance)
             sentence = conllu2sentence(
                 tree, sentence_embedding, embeddings,
-                relation_distribution, relation_label_distribution
+                relation_distribution=relation_distribution,
+                relation_label_distribution=relation_label_distribution
             )
             sentences.append(sentence)
         return sentences
@@ -125,7 +126,11 @@ class COMBO(PredictorModule):
         predictions = super().predict_instance(instance)
         (tree, sentence_embedding, embeddings,
          relation_distribution, relation_label_distribution) = self._predictions_as_tree(predictions, instance)
-        return conllu2sentence(tree, sentence_embedding, embeddings, relation_distribution, relation_label_distribution)
+        return conllu2sentence(
+            tree, sentence_embedding, embeddings,
+            relation_distribution=relation_distribution,
+            relation_label_distribution=relation_label_distribution
+        )
 
     @overrides
     def predict_json(self, inputs: JsonDict) -> data.Sentence:
-- 
GitLab