Commit 847f3a90 authored by Mateusz Klimaszewski

Refactor prediction loop.

parent 888e0f11
Merge requests: !13 Refactor merge develop to master, !12 Refactor
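The diff below changes the nesting order of the prediction loop: the old code iterated tokens in the outer loop and re-checked every field per token, while the new code iterates field names in the outer loop, skips absent fields once, and then fills that field for every token. A minimal standalone sketch of the pattern, with made-up tokens and predictions rather than the actual COMBO predictor state:

# Sketch of the loop-order change in this commit (hypothetical data).
tokens = [{"id": 1}, {"id": 2}]
field_names = ["upostag", "head"]
predictions = {"upostag": ["NOUN", "VERB"], "head": [0, 1]}

# Old shape: token-outer, field-inner; membership checked for every token.
for idx, token in enumerate(tokens):
    for field_name in field_names:
        if field_name in predictions:
            token[field_name] = predictions[field_name][idx]

# New shape: field-outer, token-inner; absent fields are skipped once per field.
for field_name in field_names:
    if field_name not in predictions:
        continue
    field_predictions = predictions[field_name]
    for idx, token in enumerate(tokens):
        token[field_name] = field_predictions[idx]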
@@ -24,7 +24,7 @@ class COMBO(predictor.Predictor):
model: models.Model,
dataset_reader: allen_data.DatasetReader,
tokenizer: allen_data.Tokenizer = tokenizers.WhitespaceTokenizer(),
batch_size: int = 32,
batch_size: int = 1024,
line_to_conllu: bool = True) -> None:
super().__init__(model, dataset_reader)
self.batch_size = batch_size
@@ -140,54 +140,56 @@ class COMBO(predictor.Predictor):
tree = instance.fields["metadata"]["input"]
field_names = instance.fields["metadata"]["field_names"]
tree_tokens = [t for t in tree if isinstance(t["id"], int)]
for idx, token in enumerate(tree_tokens):
for field_name in field_names:
if field_name in predictions:
if field_name in ["xpostag", "upostag", "semrel", "deprel"]:
value = self.vocab.get_token_from_index(predictions[field_name][idx], field_name + "_labels")
token[field_name] = value
elif field_name in ["head"]:
token[field_name] = int(predictions[field_name][idx])
elif field_name == "deps":
# Handled after every other decoding
continue
elif field_name in ["feats"]:
slices = self._model.morphological_feat.slices
features = []
prediction = predictions[field_name][idx]
for (cat, cat_indices), pred_idx in zip(slices.items(), prediction):
if cat not in ["__PAD__", "_"]:
value = self.vocab.get_token_from_index(cat_indices[pred_idx],
field_name + "_labels")
# Exclude auxiliary values
if "=None" not in value:
features.append(value)
if len(features) == 0:
field_value = "_"
else:
lowercase_features = [f.lower() for f in features]
arg_indices = sorted(range(len(lowercase_features)), key=lowercase_features.__getitem__)
field_value = "|".join(np.array(features)[arg_indices].tolist())
token[field_name] = field_value
elif field_name == "lemma":
prediction = predictions[field_name][idx]
word_chars = []
for char_idx in prediction[1:-1]:
pred_char = self.vocab.get_token_from_index(char_idx, "lemma_characters")
if pred_char == "__END__":
break
elif pred_char == "__PAD__":
continue
elif "_" in pred_char:
pred_char = "?"
word_chars.append(pred_char)
token[field_name] = "".join(word_chars)
for field_name in field_names:
if field_name not in predictions:
continue
field_predictions = predictions[field_name]
for idx, token in enumerate(tree_tokens):
if field_name in {"xpostag", "upostag", "semrel", "deprel"}:
value = self.vocab.get_token_from_index(field_predictions[idx], field_name + "_labels")
token[field_name] = value
elif field_name == "head":
token[field_name] = int(field_predictions[idx])
elif field_name == "deps":
# Handled after every other decoding
continue
elif field_name == "feats":
slices = self._model.morphological_feat.slices
features = []
prediction = field_predictions[idx]
for (cat, cat_indices), pred_idx in zip(slices.items(), prediction):
if cat not in ["__PAD__", "_"]:
value = self.vocab.get_token_from_index(cat_indices[pred_idx],
field_name + "_labels")
# Exclude auxiliary values
if "=None" not in value:
features.append(value)
if len(features) == 0:
field_value = "_"
else:
lowercase_features = [f.lower() for f in features]
arg_indices = sorted(range(len(lowercase_features)), key=lowercase_features.__getitem__)
field_value = "|".join(np.array(features)[arg_indices].tolist())
token[field_name] = field_value
elif field_name == "lemma":
prediction = field_predictions[idx]
word_chars = []
for char_idx in prediction[1:-1]:
pred_char = self.vocab.get_token_from_index(char_idx, "lemma_characters")
if pred_char == "__END__":
break
elif pred_char == "__PAD__":
continue
elif "_" in pred_char:
pred_char = "?"
word_chars.append(pred_char)
token[field_name] = "".join(word_chars)
else:
raise NotImplementedError(f"Unknown field name {field_name}!")
if "enhanced_head" in predictions and predictions["enhanced_head"]:
# TODO off-by-one hotfix, refactor
@@ -212,7 +214,7 @@ class COMBO(predictor.Predictor):
@classmethod
def from_pretrained(cls, path: str, tokenizer=tokenizers.SpacyTokenizer(),
batch_size: int = 32,
batch_size: int = 1024,
cuda_device: int = -1):
util.import_module_and_submodules("combo.commands")
util.import_module_and_submodules("combo.models")