Skip to content
Snippets Groups Projects
Commit 5118ae5b authored by Michał Marcińczuk's avatar Michał Marcińczuk
Browse files

Write training arguments to the file.

parent b8d24205
No related branches found
No related tags found
1 merge request!41Dev v07
Pipeline #6242 failed
......@@ -4,6 +4,7 @@ from __future__ import absolute_import, division, print_function
import argparse
import glob
import json
import logging
import os
import sys
......@@ -82,6 +83,7 @@ def train_model(args: Namespace):
"empty." % args.output_dir)
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
json.dump(args.__dict__, open(str(Path(args.output_dir) / "train_args.json"), "w", encoding="utf-8"), indent=4)
logger = logging.getLogger(__name__)
for item in sorted(config.items()):
......@@ -169,7 +171,7 @@ def train_model(args: Namespace):
if args.freeze_model:
logger.info("Freezing XLM-R model...")
for n, p in model.named_parameters():
if 'xlmr' in n and p.requires_grad:
if 'encoder' in n and p.requires_grad:
logging.info("Parameter %s - freezed" % n)
p.requires_grad = False
else:
......@@ -253,7 +255,7 @@ def train_model(args: Namespace):
epoch_stats["epoch_training_time"] = time.time() - time_start
if args.data_tune:
logger.info("\nTesting on validation set...")
logger.info("Testing on validation set...")
time_start = time.time()
f1, precision, recall, report = evaluate_model(model, val_data, label_list, args.eval_batch_size, device)
time_end = time.time()
......@@ -264,8 +266,8 @@ def train_model(args: Namespace):
if f1 > best_val_f1:
best_val_f1 = f1
logger.info("\nFound better f1=%.4f on validation set. "
"Saving model\n" % f1)
logger.info("Found better f1=%.4f on validation set. "
"Saving model" % f1)
logger.info("%s\n" % report)
model.save(args.output_dir)
else:
......@@ -283,8 +285,7 @@ def train_model(args: Namespace):
logger.info("%s\n" % report)
if args.epoch_save_model:
epoch_output_dir = os.path.join(args.output_dir,
"e%03d" % epoch_no)
epoch_output_dir = os.path.join(args.output_dir, "e%03d" % epoch_no)
os.makedirs(epoch_output_dir)
model.save(epoch_output_dir)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment