Skip to content
Snippets Groups Projects
Commit 4eaad34d authored by Mateusz Klimaszewski's avatar Mateusz Klimaszewski
Browse files

Fix batch predictions for DEPS.

parent ce2d7be6
Branches
Tags
2 merge requests: !37 Release 1.0.4., !36 Release 1.0.4
This commit is part of merge request !36. Comments created here will be created in the context of that merge request.
...@@ -197,14 +197,15 @@ class COMBO(predictor.Predictor): ...@@ -197,14 +197,15 @@ class COMBO(predictor.Predictor):
if "enhanced_head" in predictions and predictions["enhanced_head"]: if "enhanced_head" in predictions and predictions["enhanced_head"]:
# TODO off-by-one hotfix, refactor # TODO off-by-one hotfix, refactor
h = np.array(predictions["enhanced_head"]) sentence_length = len(tree_tokens)
h = np.array(predictions["enhanced_head"])[:sentence_length, :sentence_length]
h = np.concatenate((h[-1:], h[:-1])) h = np.concatenate((h[-1:], h[:-1]))
r = np.array(predictions["enhanced_deprel_prob"]) r = np.array(predictions["enhanced_deprel_prob"])[:sentence_length, :sentence_length, :]
r = np.concatenate((r[-1:], r[:-1])) r = np.concatenate((r[-1:], r[:-1]))
graph.graph_and_tree_merge( graph.graph_and_tree_merge(
tree_arc_scores=predictions["head"], tree_arc_scores=predictions["head"][:sentence_length],
tree_rel_scores=predictions["deprel"], tree_rel_scores=predictions["deprel"][:sentence_length],
graph_arc_scores=h, graph_arc_scores=h,
graph_rel_scores=r, graph_rel_scores=r,
idx2label=self.vocab.get_index_to_token_vocabulary("deprel_labels"), idx2label=self.vocab.get_index_to_token_vocabulary("deprel_labels"),
......
...@@ -36,6 +36,8 @@ flags.DEFINE_integer(name="cuda_device", default=-1, ...@@ -36,6 +36,8 @@ flags.DEFINE_integer(name="cuda_device", default=-1,
help="Cuda device id (-1 for cpu).") help="Cuda device id (-1 for cpu).")
flags.DEFINE_boolean(name="expect_prefix", default=True, flags.DEFINE_boolean(name="expect_prefix", default=True,
help="Whether to expect allennlp prefix.") help="Whether to expect allennlp prefix.")
flags.DEFINE_integer(name="batch_size", default=32,
help="Batch size.")
def run(_): def run(_):
...@@ -68,6 +70,7 @@ def run(_): ...@@ -68,6 +70,7 @@ def run(_):
--input_file {test_file} --input_file {test_file}
--output_file {output_pred} --output_file {output_pred}
--cuda_device {FLAGS.cuda_device} --cuda_device {FLAGS.cuda_device}
--batch_size {FLAGS.batch_size}
--silent --silent
""" """
utils.execute_command(command) utils.execute_command(command)
......
0% — Loading, or an error occurred.
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment