Skip to content
Snippets Groups Projects
Commit 5118ae5b authored by Michał Marcińczuk's avatar Michał Marcińczuk
Browse files

Write training arguments to the file.

parent b8d24205
Branches
1 merge request!41Dev v07
Pipeline #6242 failed with stage
in 2 minutes and 25 seconds
......@@ -4,6 +4,7 @@ from __future__ import absolute_import, division, print_function
import argparse
import glob
import json
import logging
import os
import sys
......@@ -82,6 +83,7 @@ def train_model(args: Namespace):
"empty." % args.output_dir)
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
json.dump(args.__dict__, open(str(Path(args.output_dir) / "train_args.json"), "w", encoding="utf-8"), indent=4)
logger = logging.getLogger(__name__)
for item in sorted(config.items()):
......@@ -169,7 +171,7 @@ def train_model(args: Namespace):
if args.freeze_model:
logger.info("Freezing XLM-R model...")
for n, p in model.named_parameters():
if 'xlmr' in n and p.requires_grad:
if 'encoder' in n and p.requires_grad:
logging.info("Parameter %s - freezed" % n)
p.requires_grad = False
else:
......@@ -253,7 +255,7 @@ def train_model(args: Namespace):
epoch_stats["epoch_training_time"] = time.time() - time_start
if args.data_tune:
logger.info("\nTesting on validation set...")
logger.info("Testing on validation set...")
time_start = time.time()
f1, precision, recall, report = evaluate_model(model, val_data, label_list, args.eval_batch_size, device)
time_end = time.time()
......@@ -264,8 +266,8 @@ def train_model(args: Namespace):
if f1 > best_val_f1:
best_val_f1 = f1
logger.info("\nFound better f1=%.4f on validation set. "
"Saving model\n" % f1)
logger.info("Found better f1=%.4f on validation set. "
"Saving model" % f1)
logger.info("%s\n" % report)
model.save(args.output_dir)
else:
......@@ -283,8 +285,7 @@ def train_model(args: Namespace):
logger.info("%s\n" % report)
if args.epoch_save_model:
epoch_output_dir = os.path.join(args.output_dir,
"e%03d" % epoch_no)
epoch_output_dir = os.path.join(args.output_dir, "e%03d" % epoch_no)
os.makedirs(epoch_output_dir)
model.save(epoch_output_dir)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment