Skip to content
Snippets Groups Projects
Commit 038c9705 authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski
Browse files

Fix batch predictions for DEPS.

parent ee349a12
No related branches found
No related tags found
No related merge requests found
Pipeline #2955 passed
...@@ -194,14 +194,15 @@ class COMBO(predictor.Predictor): ...@@ -194,14 +194,15 @@ class COMBO(predictor.Predictor):
if "enhanced_head" in predictions and predictions["enhanced_head"]: if "enhanced_head" in predictions and predictions["enhanced_head"]:
# TODO off-by-one hotfix, refactor # TODO off-by-one hotfix, refactor
h = np.array(predictions["enhanced_head"]) sentence_length = len(tree_tokens)
h = np.array(predictions["enhanced_head"])[:sentence_length, :sentence_length]
h = np.concatenate((h[-1:], h[:-1])) h = np.concatenate((h[-1:], h[:-1]))
r = np.array(predictions["enhanced_deprel_prob"]) r = np.array(predictions["enhanced_deprel_prob"])[:sentence_length, :sentence_length, :]
r = np.concatenate((r[-1:], r[:-1])) r = np.concatenate((r[-1:], r[:-1]))
graph.graph_and_tree_merge( graph.graph_and_tree_merge(
tree_arc_scores=predictions["head"], tree_arc_scores=predictions["head"][:sentence_length],
tree_rel_scores=predictions["deprel"], tree_rel_scores=predictions["deprel"][:sentence_length],
graph_arc_scores=h, graph_arc_scores=h,
graph_rel_scores=r, graph_rel_scores=r,
idx2label=self.vocab.get_index_to_token_vocabulary("deprel_labels"), idx2label=self.vocab.get_index_to_token_vocabulary("deprel_labels"),
......
...@@ -36,6 +36,8 @@ flags.DEFINE_integer(name="cuda_device", default=-1, ...@@ -36,6 +36,8 @@ flags.DEFINE_integer(name="cuda_device", default=-1,
help="Cuda device id (-1 for cpu).") help="Cuda device id (-1 for cpu).")
flags.DEFINE_boolean(name="expect_prefix", default=True, flags.DEFINE_boolean(name="expect_prefix", default=True,
help="Whether to expect allennlp prefix.") help="Whether to expect allennlp prefix.")
flags.DEFINE_integer(name="batch_size", default=32,
help="Batch size.")
def run(_): def run(_):
...@@ -68,6 +70,7 @@ def run(_): ...@@ -68,6 +70,7 @@ def run(_):
--input_file {test_file} --input_file {test_file}
--output_file {output_pred} --output_file {output_pred}
--cuda_device {FLAGS.cuda_device} --cuda_device {FLAGS.cuda_device}
--batch_size {FLAGS.batch_size}
--silent --silent
""" """
utils.execute_command(command) utils.execute_command(command)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment