diff --git a/combo/predict.py b/combo/predict.py index 8363e50e74638cad4e3751a49baebc03cfb91a9c..74c1ca57ddd5110d63e00130ebd05c6a492c9be5 100644 --- a/combo/predict.py +++ b/combo/predict.py @@ -109,7 +109,8 @@ class COMBO(PredictorModule): relation_distribution, relation_label_distribution) = self._predictions_as_tree(prediction, instance) sentence = conllu2sentence( tree, sentence_embedding, embeddings, - relation_distribution, relation_label_distribution + relation_distribution=relation_distribution, + relation_label_distribution=relation_label_distribution ) sentences.append(sentence) return sentences @@ -125,7 +126,11 @@ class COMBO(PredictorModule): predictions = super().predict_instance(instance) (tree, sentence_embedding, embeddings, relation_distribution, relation_label_distribution) = self._predictions_as_tree(predictions, instance) - return conllu2sentence(tree, sentence_embedding, embeddings, relation_distribution, relation_label_distribution) + return conllu2sentence( + tree, sentence_embedding, embeddings, + relation_distribution=relation_distribution, + relation_label_distribution=relation_label_distribution + ) @overrides def predict_json(self, inputs: JsonDict) -> data.Sentence: