diff --git a/combo/predict.py b/combo/predict.py
index a5c99fd883b4ee40c5e9c76af44a7c7dbad85bdc..bd9f5d4637bf410bace029604afb915ed05311c2 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -24,7 +24,7 @@ class COMBO(predictor.Predictor):
                  model: models.Model,
                  dataset_reader: allen_data.DatasetReader,
                  tokenizer: allen_data.Tokenizer = tokenizers.WhitespaceTokenizer(),
-                 batch_size: int = 32,
+                 batch_size: int = 1024,
                  line_to_conllu: bool = True) -> None:
         super().__init__(model, dataset_reader)
         self.batch_size = batch_size
@@ -140,54 +140,56 @@ class COMBO(predictor.Predictor):
         tree = instance.fields["metadata"]["input"]
         field_names = instance.fields["metadata"]["field_names"]
         tree_tokens = [t for t in tree if isinstance(t["id"], int)]
-        for idx, token in enumerate(tree_tokens):
-            for field_name in field_names:
-                if field_name in predictions:
-                    if field_name in ["xpostag", "upostag", "semrel", "deprel"]:
-                        value = self.vocab.get_token_from_index(predictions[field_name][idx], field_name + "_labels")
-                        token[field_name] = value
-                    elif field_name in ["head"]:
-                        token[field_name] = int(predictions[field_name][idx])
-                    elif field_name == "deps":
-                        # Handled after every other decoding
-                        continue
-
-                    elif field_name in ["feats"]:
-                        slices = self._model.morphological_feat.slices
-                        features = []
-                        prediction = predictions[field_name][idx]
-                        for (cat, cat_indices), pred_idx in zip(slices.items(), prediction):
-                            if cat not in ["__PAD__", "_"]:
-                                value = self.vocab.get_token_from_index(cat_indices[pred_idx],
-                                                                        field_name + "_labels")
-                                # Exclude auxiliary values
-                                if "=None" not in value:
-                                    features.append(value)
-                        if len(features) == 0:
-                            field_value = "_"
-                        else:
-                            lowercase_features = [f.lower() for f in features]
-                            arg_indices = sorted(range(len(lowercase_features)), key=lowercase_features.__getitem__)
-                            field_value = "|".join(np.array(features)[arg_indices].tolist())
-
-                        token[field_name] = field_value
-                    elif field_name == "lemma":
-                        prediction = predictions[field_name][idx]
-                        word_chars = []
-                        for char_idx in prediction[1:-1]:
-                            pred_char = self.vocab.get_token_from_index(char_idx, "lemma_characters")
-
-                            if pred_char == "__END__":
-                                break
-                            elif pred_char == "__PAD__":
-                                continue
-                            elif "_" in pred_char:
-                                pred_char = "?"
-
-                            word_chars.append(pred_char)
-                        token[field_name] = "".join(word_chars)
+        for field_name in field_names:
+            if field_name not in predictions:
+                continue
+            field_predictions = predictions[field_name]
+            for idx, token in enumerate(tree_tokens):
+                if field_name in {"xpostag", "upostag", "semrel", "deprel"}:
+                    value = self.vocab.get_token_from_index(field_predictions[idx], field_name + "_labels")
+                    token[field_name] = value
+                elif field_name == "head":
+                    token[field_name] = int(field_predictions[idx])
+                elif field_name == "deps":
+                    # Handled after every other decoding
+                    continue
+
+                elif field_name == "feats":
+                    slices = self._model.morphological_feat.slices
+                    features = []
+                    prediction = field_predictions[idx]
+                    for (cat, cat_indices), pred_idx in zip(slices.items(), prediction):
+                        if cat not in ["__PAD__", "_"]:
+                            value = self.vocab.get_token_from_index(cat_indices[pred_idx],
+                                                                    field_name + "_labels")
+                            # Exclude auxiliary values
+                            if "=None" not in value:
+                                features.append(value)
+                    if len(features) == 0:
+                        field_value = "_"
                     else:
-                        raise NotImplementedError(f"Unknown field name {field_name}!")
+                        lowercase_features = [f.lower() for f in features]
+                        arg_indices = sorted(range(len(lowercase_features)), key=lowercase_features.__getitem__)
+                        field_value = "|".join(np.array(features)[arg_indices].tolist())
+
+                    token[field_name] = field_value
+                elif field_name == "lemma":
+                    prediction = field_predictions[idx]
+                    word_chars = []
+                    for char_idx in prediction[1:-1]:
+                        pred_char = self.vocab.get_token_from_index(char_idx, "lemma_characters")
+
+                        if pred_char == "__END__":
+                            break
+                        elif pred_char == "__PAD__":
+                            continue
+                        elif "_" in pred_char:
+                            pred_char = "?"
+
+                        word_chars.append(pred_char)
+                    token[field_name] = "".join(word_chars)
+                else:
+                    raise NotImplementedError(f"Unknown field name {field_name}!")
         if "enhanced_head" in predictions and predictions["enhanced_head"]:
             # TODO off-by-one hotfix, refactor
@@ -212,7 +214,7 @@ class COMBO(predictor.Predictor):
     @classmethod
     def from_pretrained(cls, path: str,
                         tokenizer=tokenizers.SpacyTokenizer(),
-                        batch_size: int = 32,
+                        batch_size: int = 1024,
                         cuda_device: int = -1):
         util.import_module_and_submodules("combo.commands")
         util.import_module_and_submodules("combo.models")
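For reference, below is a minimal, self-contained sketch (not part of the patch) of the field-major decoding loop that the second hunk introduces: the per-field prediction list is looked up once and absent fields are skipped before the per-token loop runs. The decode helper, the hard-coded field tuple, and the toy inputs are simplified stand-ins for COMBO's vocabulary lookup and instance metadata, not its actual API.

from typing import Any, Dict, List


def decode(field_name: str, index: int) -> str:
    # Stand-in for self.vocab.get_token_from_index(index, field_name + "_labels").
    return f"{field_name}_{index}"


def fill_tokens(predictions: Dict[str, List[Any]], tree_tokens: List[Dict[str, Any]]) -> None:
    # Field-major order: resolve each predicted field once, then visit every token,
    # instead of re-checking membership in `predictions` inside the per-token loop.
    for field_name in ("upostag", "deprel", "head"):
        if field_name not in predictions:
            continue  # skip fields the model did not predict
        field_predictions = predictions[field_name]
        for idx, token in enumerate(tree_tokens):
            if field_name == "head":
                token[field_name] = int(field_predictions[idx])
            else:
                token[field_name] = decode(field_name, field_predictions[idx])


tokens = [{"id": 1, "form": "Ala"}, {"id": 2, "form": "ma"}]
fill_tokens({"upostag": [3, 7], "head": [2.0, 0.0]}, tokens)
print(tokens)  # each token now carries its decoded upostag and an integer head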