diff --git a/combo/main.py b/combo/main.py
index e6702cb0d2446db2e94f8189bba09f953fb17ef9..dc9a0d702984febe10f93eae1548b5d19046797b 100644
--- a/combo/main.py
+++ b/combo/main.py
@@ -151,8 +151,7 @@ def run(_):
     if FLAGS.input_file == "-":
         use_dataset_reader = False
         predictor.without_sentence_embedding = True
-    if use_dataset_reader:
-        predictor.line_to_conllu = True
+    predictor.line_to_conllu = False
     if FLAGS.silent:
         logging.getLogger("allennlp.common.params").disabled = True
     manager = allen_predict._PredictManager(
diff --git a/combo/predict.py b/combo/predict.py
index 9ab1f44c77d1466b1b2178a83890e0de69799a75..a5c99fd883b4ee40c5e9c76af44a7c7dbad85bdc 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -1,8 +1,7 @@
 import logging
 import os
-from typing import List, Union, Tuple
+from typing import List, Union, Dict, Any
 
-import conllu
 import numpy as np
 from allennlp import data as allen_data, common, models
 from allennlp.common import util
@@ -26,7 +25,7 @@ class COMBO(predictor.Predictor):
                  dataset_reader: allen_data.DatasetReader,
                  tokenizer: allen_data.Tokenizer = tokenizers.WhitespaceTokenizer(),
                  batch_size: int = 32,
-                 line_to_conllu: bool = False) -> None:
+                 line_to_conllu: bool = True) -> None:
         super().__init__(model, dataset_reader)
         self.batch_size = batch_size
         self.vocab = model.vocab
@@ -57,18 +56,21 @@ class COMBO(predictor.Predictor):
             if len(sentence) == 0:
                 return []
             example = sentence[0]
+            sentences = sentence
             if isinstance(example, str) or isinstance(example, list):
-                sentences = []
-                for sentences_batch in util.lazy_groups_of(sentence, self.batch_size):
-                    sentences_batch = self.predict_batch_json([self._to_input_json(s) for s in sentences_batch])
-                    sentences.extend(sentences_batch)
-                return sentences
+                result = []
+                sentences = [self._to_input_json(s) for s in sentences]
+                for sentences_batch in util.lazy_groups_of(sentences, self.batch_size):
+                    sentences_batch = self.predict_batch_json(sentences_batch)
+                    result.extend(sentences_batch)
+                return result
             elif isinstance(example, data.Sentence):
-                sentences = []
-                for sentences_batch in util.lazy_groups_of(sentence, self.batch_size):
-                    sentences_batch = self.predict_batch_instance([self._to_input_instance(s) for s in sentences_batch])
-                    sentences.extend(sentences_batch)
-                return sentences
+                result = []
+                sentences = [self._to_input_instance(s) for s in sentences]
+                for sentences_batch in util.lazy_groups_of(sentences, self.batch_size):
+                    sentences_batch = self.predict_batch_instance(sentences_batch)
+                    result.extend(sentences_batch)
+                return result
             else:
                 raise ValueError("List must have either sentences as str, List[str] or Sentence object.")
         else:
@@ -79,36 +81,27 @@ class COMBO(predictor.Predictor):
         sentences = []
         predictions = super().predict_batch_instance(instances)
         for prediction, instance in zip(predictions, instances):
-            tree, sentence_embedding = self.predict_instance_as_tree(instance)
+            tree, sentence_embedding = self._predictions_as_tree(prediction, instance)
             sentence = conllu2sentence(tree, sentence_embedding)
             sentences.append(sentence)
         return sentences
 
     @overrides
     def predict_batch_json(self, inputs: List[common.JsonDict]) -> List[data.Sentence]:
-        sentences = []
         instances = self._batch_json_to_instances(inputs)
-        predictions = self.predict_batch_instance(instances)
-        for prediction, instance in zip(predictions, instances):
-            tree, sentence_embedding = self.predict_instance_as_tree(instance)
-            sentence = conllu2sentence(tree, sentence_embedding)
-            sentences.append(sentence)
+        sentences = self.predict_batch_instance(instances)
         return sentences
 
     @overrides
     def predict_instance(self, instance: allen_data.Instance, serialize: bool = True) -> data.Sentence:
-        tree, sentence_embedding = self.predict_instance_as_tree(instance)
+        predictions = super().predict_instance(instance)
+        tree, sentence_embedding = self._predictions_as_tree(predictions, instance)
         return conllu2sentence(tree, sentence_embedding)
 
     @overrides
     def predict_json(self, inputs: common.JsonDict) -> data.Sentence:
         instance = self._json_to_instance(inputs)
-        tree, sentence_embedding = self.predict_instance_as_tree(instance)
-        return conllu2sentence(tree, sentence_embedding)
-
-    def predict_instance_as_tree(self, instance: allen_data.Instance) -> Tuple[conllu.TokenList, List[float]]:
-        predictions = super().predict_instance(instance)
-        return self._predictions_as_tree(predictions, instance)
+        return self.predict_instance(instance)
 
     @overrides
     def _json_to_instance(self, json_dict: common.JsonDict) -> allen_data.Instance:
@@ -143,7 +136,7 @@ class COMBO(predictor.Predictor):
     def _to_input_instance(self, sentence: data.Sentence) -> allen_data.Instance:
         return self._dataset_reader.text_to_instance(sentence2conllu(sentence))
 
-    def _predictions_as_tree(self, predictions, instance):
+    def _predictions_as_tree(self, predictions: Dict[str, Any], instance: allen_data.Instance):
         tree = instance.fields["metadata"]["input"]
         field_names = instance.fields["metadata"]["field_names"]
         tree_tokens = [t for t in tree if isinstance(t["id"], int)]