Commit 888e0f11 authored by Mateusz Klimaszewski

Fix bug with double predictions.

parent 4386185d
2 merge requests: !11 Fix bug with double predictions - develop to master merge., !10 Double prediction fix

@@ -151,8 +151,7 @@ def run(_):
     if FLAGS.input_file == "-":
         use_dataset_reader = False
         predictor.without_sentence_embedding = True
-    if use_dataset_reader:
-        predictor.line_to_conllu = True
+        predictor.line_to_conllu = False
     if FLAGS.silent:
         logging.getLogger("allennlp.common.params").disabled = True
     manager = allen_predict._PredictManager(
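
Taken together with the constructor change further down, the effect of this hunk is that CoNLL-U line output (line_to_conllu) is now on by default and is only switched off when raw text is read from stdin. A minimal sketch of that decision logic, using a hypothetical helper name that is not part of the repository:

def should_emit_conllu_lines(input_file: str) -> bool:
    # Hypothetical illustration: line_to_conllu now defaults to True
    # (the new constructor default) and is only disabled for raw text
    # streamed from stdin ("-").
    line_to_conllu = True
    if input_file == "-":
        line_to_conllu = False
    return line_to_conllu


assert should_emit_conllu_lines("input.conllu") is True
assert should_emit_conllu_lines("-") is False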

 import logging
 import os
-from typing import List, Union, Tuple
+from typing import List, Union, Dict, Any
 import conllu
 import numpy as np
 from allennlp import data as allen_data, common, models
 from allennlp.common import util

@@ -26,7 +25,7 @@ class COMBO(predictor.Predictor):
                  dataset_reader: allen_data.DatasetReader,
                  tokenizer: allen_data.Tokenizer = tokenizers.WhitespaceTokenizer(),
                  batch_size: int = 32,
-                 line_to_conllu: bool = False) -> None:
+                 line_to_conllu: bool = True) -> None:
         super().__init__(model, dataset_reader)
         self.batch_size = batch_size
         self.vocab = model.vocab

@@ -57,18 +56,21 @@ class COMBO(predictor.Predictor):
             if len(sentence) == 0:
                 return []
             example = sentence[0]
+            sentences = sentence
             if isinstance(example, str) or isinstance(example, list):
-                sentences = []
-                for sentences_batch in util.lazy_groups_of(sentence, self.batch_size):
-                    sentences_batch = self.predict_batch_json([self._to_input_json(s) for s in sentences_batch])
-                    sentences.extend(sentences_batch)
-                return sentences
+                result = []
+                sentences = [self._to_input_json(s) for s in sentences]
+                for sentences_batch in util.lazy_groups_of(sentences, self.batch_size):
+                    sentences_batch = self.predict_batch_json(sentences_batch)
+                    result.extend(sentences_batch)
+                return result
             elif isinstance(example, data.Sentence):
-                sentences = []
-                for sentences_batch in util.lazy_groups_of(sentence, self.batch_size):
-                    sentences_batch = self.predict_batch_instance([self._to_input_instance(s) for s in sentences_batch])
-                    sentences.extend(sentences_batch)
-                return sentences
+                result = []
+                sentences = [self._to_input_instance(s) for s in sentences]
+                for sentences_batch in util.lazy_groups_of(sentences, self.batch_size):
+                    sentences_batch = self.predict_batch_instance(sentences_batch)
+                    result.extend(sentences_batch)
+                return result
             else:
                 raise ValueError("List must have either sentences as str, List[str] or Sentence object.")
         else:
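
Both rewritten branches follow the same pattern: convert every input once up front, then feed fixed-size groups to the batch predictor and collect its outputs in a separate result list. A self-contained sketch of that pattern, with lazy_groups_of re-implemented locally and a stand-in predict_batch function (neither is the project's code):

from itertools import islice
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")


def lazy_groups_of(iterable: Iterable[T], group_size: int) -> Iterator[List[T]]:
    # Local stand-in for allennlp.common.util.lazy_groups_of.
    iterator = iter(iterable)
    while True:
        group = list(islice(iterator, group_size))
        if not group:
            return
        yield group


def predict_batch(batch: List[str]) -> List[str]:
    # Stand-in for predict_batch_json / predict_batch_instance.
    return [item.upper() for item in batch]


sentences = [f"sentence {i}" for i in range(5)]
result: List[str] = []
for batch in lazy_groups_of(sentences, group_size=2):  # analogue of self.batch_size
    result.extend(predict_batch(batch))
print(result)  # all five inputs, predicted once each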

@@ -79,36 +81,27 @@ class COMBO(predictor.Predictor):
         sentences = []
         predictions = super().predict_batch_instance(instances)
         for prediction, instance in zip(predictions, instances):
-            tree, sentence_embedding = self.predict_instance_as_tree(instance)
+            tree, sentence_embedding = self._predictions_as_tree(prediction, instance)
             sentence = conllu2sentence(tree, sentence_embedding)
             sentences.append(sentence)
         return sentences
 
     @overrides
     def predict_batch_json(self, inputs: List[common.JsonDict]) -> List[data.Sentence]:
-        sentences = []
         instances = self._batch_json_to_instances(inputs)
-        predictions = self.predict_batch_instance(instances)
-        for prediction, instance in zip(predictions, instances):
-            tree, sentence_embedding = self.predict_instance_as_tree(instance)
-            sentence = conllu2sentence(tree, sentence_embedding)
-            sentences.append(sentence)
+        sentences = self.predict_batch_instance(instances)
         return sentences
 
     @overrides
     def predict_instance(self, instance: allen_data.Instance, serialize: bool = True) -> data.Sentence:
-        tree, sentence_embedding = self.predict_instance_as_tree(instance)
+        predictions = super().predict_instance(instance)
+        tree, sentence_embedding = self._predictions_as_tree(predictions, instance)
         return conllu2sentence(tree, sentence_embedding)
 
     @overrides
     def predict_json(self, inputs: common.JsonDict) -> data.Sentence:
         instance = self._json_to_instance(inputs)
-        tree, sentence_embedding = self.predict_instance_as_tree(instance)
-        return conllu2sentence(tree, sentence_embedding)
-
-    def predict_instance_as_tree(self, instance: allen_data.Instance) -> Tuple[conllu.TokenList, List[float]]:
-        predictions = super().predict_instance(instance)
-        return self._predictions_as_tree(predictions, instance)
+        return self.predict_instance(instance)
 
     @overrides
     def _json_to_instance(self, json_dict: common.JsonDict) -> allen_data.Instance:
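
This hunk is the core of the fix. The old predict_batch_instance ran the model once via super().predict_batch_instance(instances) and then ignored those outputs, because predict_instance_as_tree called super().predict_instance(instance) again for every item, so each sentence passed through the model twice (and the old predict_batch_json added yet another pass on top). A toy sketch of the old and new shapes, with a call counter standing in for the model's forward pass (all names here are illustrative, not the project's API):

class ToyPredictor:
    def __init__(self) -> None:
        self.forward_calls = 0

    def _forward(self, instance: str) -> str:
        self.forward_calls += 1
        return instance.upper()

    def predict_batch_buggy(self, instances: list) -> list:
        # Old shape: the batch predictions are computed, then thrown away,
        # and every instance is predicted a second time.
        _ = [self._forward(i) for i in instances]
        return [self._forward(i) for i in instances]

    def predict_batch_fixed(self, instances: list) -> list:
        # New shape: the batch predictions are reused directly.
        return [self._forward(i) for i in instances]


buggy, fixed = ToyPredictor(), ToyPredictor()
buggy.predict_batch_buggy(["a", "b", "c"])
fixed.predict_batch_fixed(["a", "b", "c"])
print(buggy.forward_calls, fixed.forward_calls)  # 6 vs 3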

@@ -143,7 +136,7 @@ class COMBO(predictor.Predictor):
     def _to_input_instance(self, sentence: data.Sentence) -> allen_data.Instance:
         return self._dataset_reader.text_to_instance(sentence2conllu(sentence))
 
-    def _predictions_as_tree(self, predictions, instance):
+    def _predictions_as_tree(self, predictions: Dict[str, Any], instance: allen_data.Instance):
         tree = instance.fields["metadata"]["input"]
         field_names = instance.fields["metadata"]["field_names"]
         tree_tokens = [t for t in tree if isinstance(t["id"], int)]
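
With these changes every public entry point (predict, predict_json, predict_instance and their batch variants) runs the model exactly once per sentence. A usage sketch, assuming predictor is an already constructed instance of the COMBO class shown above (model loading is outside the scope of this diff):

# Single raw sentence -> one data.Sentence.
sentence = predictor.predict("The quick brown fox jumps over the lazy dog.")

# List of raw sentences -> list of data.Sentence objects; the inputs are
# processed in groups of predictor.batch_size via predict_batch_json.
sentences = predictor.predict(["First sentence.", "Second sentence."])

# Anything other than str, List[str] or Sentence elements raises the
# ValueError shown in the predict() hunk above.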