Commit 038c9705 authored by Mateusz Klimaszewski

Fix batch predictions for DEPS.

parent ee349a12
Pipeline #2955 passed in 4 minutes and 58 seconds
@@ -194,14 +194,15 @@ class COMBO(predictor.Predictor):
         if "enhanced_head" in predictions and predictions["enhanced_head"]:
             # TODO off-by-one hotfix, refactor
-            h = np.array(predictions["enhanced_head"])
+            sentence_length = len(tree_tokens)
+            h = np.array(predictions["enhanced_head"])[:sentence_length, :sentence_length]
             h = np.concatenate((h[-1:], h[:-1]))
-            r = np.array(predictions["enhanced_deprel_prob"])
+            r = np.array(predictions["enhanced_deprel_prob"])[:sentence_length, :sentence_length, :]
             r = np.concatenate((r[-1:], r[:-1]))
             graph.graph_and_tree_merge(
-                tree_arc_scores=predictions["head"],
-                tree_rel_scores=predictions["deprel"],
+                tree_arc_scores=predictions["head"][:sentence_length],
+                tree_rel_scores=predictions["deprel"][:sentence_length],
                 graph_arc_scores=h,
                 graph_rel_scores=r,
                 idx2label=self.vocab.get_index_to_token_vocabulary("deprel_labels"),
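The crux of the change: when sentences are predicted in batches, a sentence's score arrays can come back padded to the length of the longest sentence in the batch, so the diff crops the enhanced_head, enhanced_deprel_prob, head and deprel predictions back to len(tree_tokens) before the existing off-by-one rotation and the tree/graph merge. A minimal numpy sketch of that cropping on toy data (array contents, shapes and names are illustrative, not taken from the repository):

import numpy as np

# Pretend the batch was padded to 5 tokens while this sentence has only 3.
padded_heads = np.random.rand(5, 5)   # stand-in for predictions["enhanced_head"]
sentence_length = 3                   # stand-in for len(tree_tokens)

h = padded_heads[:sentence_length, :sentence_length]  # drop the padding rows/columns
h = np.concatenate((h[-1:], h[:-1]))                   # off-by-one hotfix: move the last row to the front

assert h.shape == (sentence_length, sentence_length)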
@@ -36,6 +36,8 @@ flags.DEFINE_integer(name="cuda_device", default=-1,
                      help="Cuda device id (-1 for cpu).")
 flags.DEFINE_boolean(name="expect_prefix", default=True,
                      help="Whether to expect allennlp prefix.")
+flags.DEFINE_integer(name="batch_size", default=32,
+                     help="Batch size.")


 def run(_):
@@ -68,6 +70,7 @@ def run(_):
         --input_file {test_file}
         --output_file {output_pred}
         --cuda_device {FLAGS.cuda_device}
+        --batch_size {FLAGS.batch_size}
         --silent
         """
         utils.execute_command(command)
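For completeness, a self-contained sketch of the absl-flags pattern the runner script relies on, showing how the new batch_size flag flows from the command line into the generated prediction command. The executable name and surrounding structure below are placeholders, not the project's actual script; only the flag definitions mirror the diff:

from absl import app, flags

FLAGS = flags.FLAGS
flags.DEFINE_integer(name="cuda_device", default=-1, help="Cuda device id (-1 for cpu).")
flags.DEFINE_integer(name="batch_size", default=32, help="Batch size.")


def run(_):
    # FLAGS are populated by absl after parsing sys.argv, so running
    # `python runner.py --batch_size 64` changes the command built below.
    command = (
        f"combo_predict "                    # placeholder executable name
        f"--cuda_device {FLAGS.cuda_device} "
        f"--batch_size {FLAGS.batch_size} "
        f"--silent"
    )
    print(command)


if __name__ == "__main__":
    app.run(run)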