Skip to content
Snippets Groups Projects
Commit 038c9705 authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski
Browse files

Fix batch predictions for DEPS.

parent ee349a12
No related merge requests found
Pipeline #2955 passed with stage
in 4 minutes and 58 seconds
...@@ -194,14 +194,15 @@ class COMBO(predictor.Predictor): ...@@ -194,14 +194,15 @@ class COMBO(predictor.Predictor):
if "enhanced_head" in predictions and predictions["enhanced_head"]: if "enhanced_head" in predictions and predictions["enhanced_head"]:
# TODO off-by-one hotfix, refactor # TODO off-by-one hotfix, refactor
h = np.array(predictions["enhanced_head"]) sentence_length = len(tree_tokens)
h = np.array(predictions["enhanced_head"])[:sentence_length, :sentence_length]
h = np.concatenate((h[-1:], h[:-1])) h = np.concatenate((h[-1:], h[:-1]))
r = np.array(predictions["enhanced_deprel_prob"]) r = np.array(predictions["enhanced_deprel_prob"])[:sentence_length, :sentence_length, :]
r = np.concatenate((r[-1:], r[:-1])) r = np.concatenate((r[-1:], r[:-1]))
graph.graph_and_tree_merge( graph.graph_and_tree_merge(
tree_arc_scores=predictions["head"], tree_arc_scores=predictions["head"][:sentence_length],
tree_rel_scores=predictions["deprel"], tree_rel_scores=predictions["deprel"][:sentence_length],
graph_arc_scores=h, graph_arc_scores=h,
graph_rel_scores=r, graph_rel_scores=r,
idx2label=self.vocab.get_index_to_token_vocabulary("deprel_labels"), idx2label=self.vocab.get_index_to_token_vocabulary("deprel_labels"),
......
...@@ -36,6 +36,8 @@ flags.DEFINE_integer(name="cuda_device", default=-1, ...@@ -36,6 +36,8 @@ flags.DEFINE_integer(name="cuda_device", default=-1,
help="Cuda device id (-1 for cpu).") help="Cuda device id (-1 for cpu).")
flags.DEFINE_boolean(name="expect_prefix", default=True, flags.DEFINE_boolean(name="expect_prefix", default=True,
help="Whether to expect allennlp prefix.") help="Whether to expect allennlp prefix.")
flags.DEFINE_integer(name="batch_size", default=32,
help="Batch size.")
def run(_): def run(_):
...@@ -68,6 +70,7 @@ def run(_): ...@@ -68,6 +70,7 @@ def run(_):
--input_file {test_file} --input_file {test_file}
--output_file {output_pred} --output_file {output_pred}
--cuda_device {FLAGS.cuda_device} --cuda_device {FLAGS.cuda_device}
--batch_size {FLAGS.batch_size}
--silent --silent
""" """
utils.execute_command(command) utils.execute_command(command)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment