diff --git a/scripts/train_iwpt21.py b/scripts/train_iwpt21.py index e4705f7dc35f82496488ff7e2b7d1dba1838ec5e..b5838c491acb41390c1bd70329bced4b7f044df2 100644 --- a/scripts/train_iwpt21.py +++ b/scripts/train_iwpt21.py @@ -119,7 +119,6 @@ def run(_): --pretrained_transformer_name {utils.LANG2TRANSFORMER[lang]} --serialization_dir {serialization_dir} --cuda_device {FLAGS.cuda_device} - --word_batch_size 2500 --config_path {pathlib.Path.cwd() / 'combo' / 'config.graph.template.jsonnet'} --notensorboard """ @@ -128,6 +127,11 @@ def run(_): if lang in {"fr", "ru"}: command = command + " --targets deprel,head,upostag,lemma,feats" + if lang in {"ta"}: + command = command + " --word_batch_size 500" + else: + command = command + " --word_batch_size 2500" + utils.execute_command("".join(command.splitlines()))