diff --git a/combo/predict.py b/combo/predict.py
index ff4d72b6785454e326f54baae6190f75190b7f3c..50babf4a63f798241acda7d3285e6ad62711da83 100644
--- a/combo/predict.py
+++ b/combo/predict.py
@@ -194,14 +194,15 @@ class COMBO(predictor.Predictor):
 
         if "enhanced_head" in predictions and predictions["enhanced_head"]:
             # TODO off-by-one hotfix, refactor
-            h = np.array(predictions["enhanced_head"])
+            sentence_length = len(tree_tokens)
+            h = np.array(predictions["enhanced_head"])[:sentence_length, :sentence_length]
             h = np.concatenate((h[-1:], h[:-1]))
-            r = np.array(predictions["enhanced_deprel_prob"])
+            r = np.array(predictions["enhanced_deprel_prob"])[:sentence_length, :sentence_length, :]
             r = np.concatenate((r[-1:], r[:-1]))
 
             graph.graph_and_tree_merge(
-                tree_arc_scores=predictions["head"],
-                tree_rel_scores=predictions["deprel"],
+                tree_arc_scores=predictions["head"][:sentence_length],
+                tree_rel_scores=predictions["deprel"][:sentence_length],
                 graph_arc_scores=h,
                 graph_rel_scores=r,
                 idx2label=self.vocab.get_index_to_token_vocabulary("deprel_labels"),
diff --git a/scripts/predict_iwpt21.py b/scripts/predict_iwpt21.py
index 513ce3b7c1c97265c2a1ebebb2beb40a7a931905..61b9cf72a183c728a612507f21e459589b0005b6 100644
--- a/scripts/predict_iwpt21.py
+++ b/scripts/predict_iwpt21.py
@@ -36,6 +36,8 @@ flags.DEFINE_integer(name="cuda_device", default=-1,
                      help="Cuda device id (-1 for cpu).")
 flags.DEFINE_boolean(name="expect_prefix", default=True,
                      help="Whether to expect allennlp prefix.")
+flags.DEFINE_integer(name="batch_size", default=32,
+                     help="Batch size.")
 
 
 def run(_):
@@ -68,6 +70,7 @@ def run(_):
         --input_file {test_file}
         --output_file {output_pred}
         --cuda_device {FLAGS.cuda_device}
+        --batch_size {FLAGS.batch_size}
         --silent
         """
         utils.execute_command(command)