diff --git a/combo/main.py b/combo/main.py index 17c960ac7caa84513692841abb989955c7925721..dc9a0d702984febe10f93eae1548b5d19046797b 100644 --- a/combo/main.py +++ b/combo/main.py @@ -84,8 +84,8 @@ flags.DEFINE_integer(name="batch_size", default=1, help="Prediction batch size.") flags.DEFINE_boolean(name="silent", default=True, help="Silent prediction to file (without printing to console).") -flags.DEFINE_enum(name="predictor_name", default="semantic-multitask-predictor-spacy", - enum_values=["semantic-multitask-predictor", "semantic-multitask-predictor-spacy"], +flags.DEFINE_enum(name="predictor_name", default="combo-spacy", + enum_values=["combo", "combo-spacy"], help="Use predictor with whitespace or spacy tokenizer.") @@ -151,8 +151,7 @@ def run(_): if FLAGS.input_file == "-": use_dataset_reader = False predictor.without_sentence_embedding = True - if use_dataset_reader: - predictor.line_to_conllu = True + predictor.line_to_conllu = False if FLAGS.silent: logging.getLogger("allennlp.common.params").disabled = True manager = allen_predict._PredictManager( diff --git a/combo/predict.py b/combo/predict.py index 21941d91d56170e7c552af2a3ac1af229816f76d..a5c99fd883b4ee40c5e9c76af44a7c7dbad85bdc 100644 --- a/combo/predict.py +++ b/combo/predict.py @@ -1,8 +1,7 @@ import logging import os -from typing import List, Union, Tuple +from typing import List, Union, Dict, Any -import conllu import numpy as np from allennlp import data as allen_data, common, models from allennlp.common import util @@ -17,8 +16,8 @@ from combo.utils import download, graph logger = logging.getLogger(__name__) -@predictor.Predictor.register("semantic-multitask-predictor") -@predictor.Predictor.register("semantic-multitask-predictor-spacy", constructor="with_spacy_tokenizer") +@predictor.Predictor.register("combo") +@predictor.Predictor.register("combo-spacy", constructor="with_spacy_tokenizer") class COMBO(predictor.Predictor): def __init__(self, @@ -26,7 +25,7 @@ class COMBO(predictor.Predictor): dataset_reader: allen_data.DatasetReader, tokenizer: allen_data.Tokenizer = tokenizers.WhitespaceTokenizer(), batch_size: int = 32, - line_to_conllu: bool = False) -> None: + line_to_conllu: bool = True) -> None: super().__init__(model, dataset_reader) self.batch_size = batch_size self.vocab = model.vocab @@ -57,18 +56,21 @@ class COMBO(predictor.Predictor): if len(sentence) == 0: return [] example = sentence[0] + sentences = sentence if isinstance(example, str) or isinstance(example, list): - sentences = [] - for sentences_batch in util.lazy_groups_of(sentence, self.batch_size): - sentences_batch = self.predict_batch_json([self._to_input_json(s) for s in sentences_batch]) - sentences.extend(sentences_batch) - return sentences + result = [] + sentences = [self._to_input_json(s) for s in sentences] + for sentences_batch in util.lazy_groups_of(sentences, self.batch_size): + sentences_batch = self.predict_batch_json(sentences_batch) + result.extend(sentences_batch) + return result elif isinstance(example, data.Sentence): - sentences = [] - for sentences_batch in util.lazy_groups_of(sentence, self.batch_size): - sentences_batch = self.predict_batch_instance([self._to_input_instance(s) for s in sentences_batch]) - sentences.extend(sentences_batch) - return sentences + result = [] + sentences = [self._to_input_instance(s) for s in sentences] + for sentences_batch in util.lazy_groups_of(sentences, self.batch_size): + sentences_batch = self.predict_batch_instance(sentences_batch) + result.extend(sentences_batch) + return result else: raise ValueError("List must have either sentences as str, List[str] or Sentence object.") else: @@ -79,36 +81,27 @@ class COMBO(predictor.Predictor): sentences = [] predictions = super().predict_batch_instance(instances) for prediction, instance in zip(predictions, instances): - tree, sentence_embedding = self.predict_instance_as_tree(instance) + tree, sentence_embedding = self._predictions_as_tree(prediction, instance) sentence = conllu2sentence(tree, sentence_embedding) sentences.append(sentence) return sentences @overrides def predict_batch_json(self, inputs: List[common.JsonDict]) -> List[data.Sentence]: - sentences = [] instances = self._batch_json_to_instances(inputs) - predictions = self.predict_batch_instance(instances) - for prediction, instance in zip(predictions, instances): - tree, sentence_embedding = self.predict_instance_as_tree(instance) - sentence = conllu2sentence(tree, sentence_embedding) - sentences.append(sentence) + sentences = self.predict_batch_instance(instances) return sentences @overrides def predict_instance(self, instance: allen_data.Instance, serialize: bool = True) -> data.Sentence: - tree, sentence_embedding = self.predict_instance_as_tree(instance) + predictions = super().predict_instance(instance) + tree, sentence_embedding = self._predictions_as_tree(predictions, instance) return conllu2sentence(tree, sentence_embedding) @overrides def predict_json(self, inputs: common.JsonDict) -> data.Sentence: instance = self._json_to_instance(inputs) - tree, sentence_embedding = self.predict_instance_as_tree(instance) - return conllu2sentence(tree, sentence_embedding) - - def predict_instance_as_tree(self, instance: allen_data.Instance) -> Tuple[conllu.TokenList, List[float]]: - predictions = super().predict_instance(instance) - return self._predictions_as_tree(predictions, instance) + return self.predict_instance(instance) @overrides def _json_to_instance(self, json_dict: common.JsonDict) -> allen_data.Instance: @@ -143,7 +136,7 @@ class COMBO(predictor.Predictor): def _to_input_instance(self, sentence: data.Sentence) -> allen_data.Instance: return self._dataset_reader.text_to_instance(sentence2conllu(sentence)) - def _predictions_as_tree(self, predictions, instance): + def _predictions_as_tree(self, predictions: Dict[str, Any], instance: allen_data.Instance): tree = instance.fields["metadata"]["input"] field_names = instance.fields["metadata"]["field_names"] tree_tokens = [t for t in tree if isinstance(t["id"], int)] diff --git a/docs/prediction.md b/docs/prediction.md index 6de5d0e1892389ba5cd18c25b88947db3f717074..125c1298c0df6b4a8a9b0ce0544e90422f4ae155 100644 --- a/docs/prediction.md +++ b/docs/prediction.md @@ -28,7 +28,7 @@ combo --mode predict --model_path your_model_tar_gz --input_file your_text_file There are 2 tokenizers: whitespace and spacy-based (`en_core_web_sm` model). -Use either `--predictor_name semantic-multitask-predictor` or `--predictor_name semantic-multitask-predictor-spacy`. +Use either `--predictor_name combo` or `--predictor_name combo-spacy`. ## Python ```python