diff --git a/combo/main.py b/combo/main.py
index 261e79759b97b4b692380f7a094c61f63ae0bc9a..3884f0511d8a0903a19ab0fb451da0827600b3d0 100755
--- a/combo/main.py
+++ b/combo/main.py
@@ -161,7 +161,7 @@ def get_defaults(dataset_reader: Optional[DatasetReader],
                                                    )
 
     if not training_data_loader:
-        training_data_loader = default_data_loader(dataset_reader, training_data_path)
+        training_data_loader = default_data_loader(dataset_reader, training_data_path, FLAGS.batch_size)
     else:
         if training_data_path:
             training_data_loader.data_path = training_data_path
@@ -170,7 +170,7 @@ def get_defaults(dataset_reader: Optional[DatasetReader],
                            str(training_data_loader.data_path), prefix=prefix)
 
     if FLAGS.validation_data_path and not validation_data_loader:
-        validation_data_loader = default_data_loader(dataset_reader, validation_data_path)
+        validation_data_loader = default_data_loader(dataset_reader, validation_data_path, FLAGS.batch_size)
     elif FLAGS.validation_data_path and validation_data_loader:
         if validation_data_path:
             validation_data_loader.data_path = validation_data_path