diff --git a/combo/predict.py b/combo/predict.py index ff4d72b6785454e326f54baae6190f75190b7f3c..50babf4a63f798241acda7d3285e6ad62711da83 100644 --- a/combo/predict.py +++ b/combo/predict.py @@ -194,14 +194,15 @@ class COMBO(predictor.Predictor): if "enhanced_head" in predictions and predictions["enhanced_head"]: # TODO off-by-one hotfix, refactor - h = np.array(predictions["enhanced_head"]) + sentence_length = len(tree_tokens) + h = np.array(predictions["enhanced_head"])[:sentence_length, :sentence_length] h = np.concatenate((h[-1:], h[:-1])) - r = np.array(predictions["enhanced_deprel_prob"]) + r = np.array(predictions["enhanced_deprel_prob"])[:sentence_length, :sentence_length, :] r = np.concatenate((r[-1:], r[:-1])) graph.graph_and_tree_merge( - tree_arc_scores=predictions["head"], - tree_rel_scores=predictions["deprel"], + tree_arc_scores=predictions["head"][:sentence_length], + tree_rel_scores=predictions["deprel"][:sentence_length], graph_arc_scores=h, graph_rel_scores=r, idx2label=self.vocab.get_index_to_token_vocabulary("deprel_labels"), diff --git a/scripts/predict_iwpt21.py b/scripts/predict_iwpt21.py index 513ce3b7c1c97265c2a1ebebb2beb40a7a931905..61b9cf72a183c728a612507f21e459589b0005b6 100644 --- a/scripts/predict_iwpt21.py +++ b/scripts/predict_iwpt21.py @@ -36,6 +36,8 @@ flags.DEFINE_integer(name="cuda_device", default=-1, help="Cuda device id (-1 for cpu).") flags.DEFINE_boolean(name="expect_prefix", default=True, help="Whether to expect allennlp prefix.") +flags.DEFINE_integer(name="batch_size", default=32, + help="Batch size.") def run(_): @@ -68,6 +70,7 @@ def run(_): --input_file {test_file} --output_file {output_pred} --cuda_device {FLAGS.cuda_device} + --batch_size {FLAGS.batch_size} --silent """ utils.execute_command(command)