# train.py — PolDeepNER2 training entry point
from __future__ import absolute_import, division, print_function
import argparse
import glob
import json
import logging
import os
import sys
import time
import re
from argparse import Namespace
from pathlib import Path
import torch
from tqdm import tqdm
from pytorch_transformers import WarmupLinearSchedule, AdamW
from torch.utils.data import DataLoader, RandomSampler
from poldeepner2.model.hf_for_token_calssification import Pdn2ModelConfiguration
from poldeepner2.utils.seed import setup_seed
from poldeepner2.utils.sequences import FeatureGeneratorFactory
from poldeepner2.utils.train_utils import add_train_args, evaluate_model
from poldeepner2.utils.data_utils import NerProcessor, create_dataset
from poldeepner2.model.hf_for_token_calssification import Pdn2TokenClassification
def generate_unique_suffix(output_dir: str) -> str:
    """Return a numeric suffix that makes *output_dir* unique.

    If *output_dir* does not exist yet, no suffix is needed and an empty
    string is returned. Otherwise sibling paths matching ``output_dir*``
    are scanned for a trailing ``_NNN`` counter and the next counter is
    returned, zero-padded to three digits (e.g. ``"_002"``).

    Args:
        output_dir: candidate output directory path.

    Returns:
        ``""`` when the path is free, otherwise ``"_NNN"``.
    """
    if not Path(output_dir).exists():
        return ""
    max_suffix = 0
    # FIX: escape the base path so glob metacharacters ([, ?, *) in the
    # directory name cannot corrupt the pattern.
    for path in glob.glob(glob.escape(output_dir) + "*"):
        suffix = path.split("_")[-1]
        # Non-numeric tails (including the unsuffixed base dir) count as 0.
        suffix = int(suffix) if suffix.isdigit() else 0
        max_suffix = max(max_suffix, suffix)
    return f"_{max_suffix + 1:03d}"
def get_output_dir(args: Namespace) -> str:
    """Expand ``{param}`` placeholders in ``args.output_dir``.

    Each ``{name}`` placeholder is replaced with ``str(getattr(args, name))``;
    characters outside ``[a-zA-Z0-9_/-]`` in the substituted value are
    replaced with ``_`` so the result is a safe path fragment.

    Args:
        args: parsed command-line arguments holding ``output_dir`` and any
            attributes referenced by its placeholders.

    Returns:
        The expanded output directory path.

    Raises:
        ValueError: if a placeholder names an attribute ``args`` lacks.
    """
    parts = re.split(r"({[^}]+})", args.output_dir.lstrip())
    path = []
    for part in parts:
        if part.startswith("{"):
            name = part[1:-1]
            if name not in args.__dict__:
                raise ValueError(f"Invalid parameter name in output_dir: {name}")
            # FIX: raw string instead of a no-op f-string for the regex;
            # sanitize the value so it cannot inject odd filename characters.
            path.append(re.sub(r"[^a-zA-Z0-9_/-]", "_", str(args.__dict__[name])))
        else:
            path.append(part)
    return "".join(path)
def train_model(args: Namespace):
    """Train a PDN2 token-classification model end to end.

    Resolves the output directory (placeholder expansion plus a unique
    numeric suffix), loads the datasets, builds the model / optimizer /
    scheduler, runs the training loop with optional per-epoch validation
    and testing, and writes final evaluation reports.

    Args:
        args: parsed command-line arguments (see ``add_train_args``).

    Raises:
        ValueError: if the output directory already exists and is not
            empty, or if ``gradient_accumulation_steps`` < 1.
        ImportError: if ``--fp16`` is requested but apex is not installed.
    """
    # Resolve the output directory and append a numeric suffix so reruns
    # never clobber a previous experiment.
    args.output_dir = get_output_dir(args)
    suffix = generate_unique_suffix(args.output_dir)
    args.output_dir += suffix

    config = {
        "epochs": args.num_train_epochs,
        "language_model": args.pretrained_path,
        "batch_size": args.train_batch_size,
        "data_train": args.data_train,
        "data_tune": args.data_tune,
        "data_test": args.data_test,
        "max_seq_length": args.max_seq_length,
        "warmup_proportion": args.warmup_proportion,
        "learning_rate": args.learning_rate,
        "gradient_accumulation_steps": args.gradient_accumulation_steps,
        "sequence_generator": args.sequence_generator,
        "dropout": args.dropout,
        "output_dir": args.output_dir
    }

    if args.wandb:
        # Imported lazily so wandb stays an optional dependency.
        import wandb
        wandb.init(project=args.wandb, config=config)
        wandb.run.name = f"seq_{args.max_seq_length}_seed_{args.seed}_epoch_{args.num_train_epochs}_{suffix}"
        wandb.run.save()

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError("Output directory (%s) already exists and is not "
                         "empty." % args.output_dir)
    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    # FIX: use a context manager so the args-file handle is closed promptly
    # (the original left the file object dangling).
    with open(str(Path(args.output_dir) / "train_args.json"), "w", encoding="utf-8") as args_file:
        json.dump(args.__dict__, args_file, indent=4)

    logger = logging.getLogger(__name__)
    for item in sorted(config.items()):
        logger.info(item)

    if args.gradient_accumulation_steps < 1:
        # FIX: the original mixed a "{}" template with "%" formatting, which
        # raised TypeError instead of the intended ValueError.
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: "
            f"{args.gradient_accumulation_steps}, should be >= 1")
    # Effective per-step batch size after gradient accumulation.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    setup_seed(args.seed)

    # Determine set of labels across every dataset that will be touched.
    processor = NerProcessor()
    datasets = [args.data_train]
    if args.data_tune:
        datasets.append(args.data_tune)
    if args.data_test:
        datasets.append(args.data_test)
    label_list = processor.get_labels(datasets)
    logger.info(f"Labels: {label_list}")
    num_labels = len(label_list) + 1  # add one for IGNORE label
    logger.info(f"Number of labels: {num_labels}")

    # Load training data
    logger.info("Loading training data...")
    t0 = time.time()
    train_examples = processor.get_examples(args.data_train, "train")
    logger.info(f"Training data was loaded in {time.time() - t0} second(s)")

    # Infer encoder hidden size from the checkpoint name; fall back to the
    # explicit --hidden_size argument for non-standard model names.
    hidden_size = 4096 if 'xxlarge' in args.pretrained_path else \
        2048 if 'xlarge' in args.pretrained_path else \
        1024 if 'large' in args.pretrained_path else \
        768 if 'base' in args.pretrained_path else args.hidden_size
    device = args.device

    logger.info("Loading pretrained model...")
    t0 = time.time()
    config = Pdn2ModelConfiguration(
        labels=label_list,
        hidden_size=hidden_size,
        dropout_p=args.dropout,
        device=device,
        max_seq_length=args.max_seq_length,
        sequence_generator=args.sequence_generator,
        seed=args.seed
    )
    model = Pdn2TokenClassification(path=args.pretrained_path, config=config)
    model.to(device)
    logger.info(f"Pretrained model was loaded in {time.time()-t0} second(s)")

    # Feature generators: a separate one may be used for evaluation.
    gen = FeatureGeneratorFactory.create(args.sequence_generator,
                                         label_list=label_list,
                                         max_seq_length=args.max_seq_length,
                                         encode_method=model.encode_word)
    if args.sequence_generator_for_eval is not None:
        gen_eval = FeatureGeneratorFactory.create(args.sequence_generator_for_eval,
                                                  label_list=label_list,
                                                  max_seq_length=args.max_seq_length,
                                                  encode_method=model.encode_word)
    else:
        gen_eval = gen
    train_features = gen.generate(train_examples)

    num_train_optimization_steps = int(
        len(train_features) / args.train_batch_size /
        args.gradient_accumulation_steps) * args.num_train_epochs

    # No weight decay for biases and the final layer norm.
    no_decay = ['bias', 'final_layer_norm.weight']
    params = list(model.named_parameters())
    optimizer_grouped_parameters = [
        {'params': [p for n, p in params if not any(
            nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in params if any(
            nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    warmup_steps = int(args.warmup_proportion * num_train_optimization_steps)
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps, t_total=num_train_optimization_steps)

    # Optionally freeze the encoder so only the classification head trains.
    if args.freeze_model:
        logger.info("Freezing XLM-R model...")
        for n, p in model.named_parameters():
            if 'encoder' in n and p.requires_grad:
                # FIX: use the module logger (lazy %-args) instead of the
                # root logging module, consistent with the rest of the file.
                logger.info("Parameter %s - freezed", n)
                p.requires_grad = False
            else:
                logger.info("Parameter %s - unchanged", n)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex "
                "to use fp16 training.")
        model, optimizer = amp.initialize(
            model, optimizer, opt_level=args.fp16_opt_level)

    # Train the model
    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", len(train_examples))
    logger.info(" Batch size = %d", args.train_batch_size)
    logger.info(" Num steps = %d", num_train_optimization_steps)
    train_data = create_dataset(train_features)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

    # Pre-build validation/test feature sets once, outside the epoch loop.
    best_val_f1 = 0.0
    if args.data_tune:
        val_examples = processor.get_examples(args.data_tune, "tune")
        val_features = gen_eval.generate(val_examples)
        val_data = create_dataset(val_features)
    if args.data_test:
        eval_examples = processor.get_examples(args.data_test, "test")
        eval_features = gen_eval.generate(eval_examples)
        eval_data = create_dataset(eval_features)

    for epoch_no in range(1, args.num_train_epochs + 1):
        epoch_stats = {"epoch": epoch_no}
        logger.info("Epoch %d" % epoch_no)
        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0
        model.train()
        steps = len(train_dataloader)
        time_start = time.time()
        for step, batch in tqdm(enumerate(train_dataloader), total=steps):
            batch = tuple(t.to(device) for t in batch)
            input_ids, label_ids, valid_ids = batch
            loss = model(input_ids, label_ids, valid_ids)
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            tr_loss += loss.item()
            nb_tr_examples += input_ids.size(0)
            nb_tr_steps += 1
            # FIX: log a Python float, not a live tensor — keeps the wandb
            # payload serializable and drops any autograd references.
            epoch_stats["loss"] = loss.item()
            epoch_stats["learning_rate"] = scheduler.get_last_lr()[0]
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()
                model.zero_grad()
            epoch_stats["step"] = step
            if args.wandb:
                wandb.log(epoch_stats)
        if args.wandb:
            epoch_stats["epoch_training_time"] = time.time() - time_start

        if args.data_tune:
            logger.info("Testing on validation set...")
            time_start = time.time()
            f1, precision, recall, report = evaluate_model(model, val_data, label_list, args.eval_batch_size, device)
            time_end = time.time()
            epoch_stats["validation_F1"] = f1
            epoch_stats["validation_P"] = precision
            epoch_stats["validation_R"] = recall
            epoch_stats["epoch_validation_time"] = time_end - time_start
            # Keep only the checkpoint with the best validation F1.
            if f1 > best_val_f1:
                best_val_f1 = f1
                logger.info("Found better f1=%.4f on validation set. "
                            "Saving model" % f1)
                logger.info("%s\n" % report)
                model.save(args.output_dir)
            else:
                logger.info("%s\n" % report)

        if args.data_test:
            logger.info("\nTesting on test set...")
            time_start = time.time()
            f1, precision, recall, report = evaluate_model(model, eval_data, label_list, args.eval_batch_size, device)
            time_end = time.time()
            epoch_stats["test_F1"] = f1
            epoch_stats["test_P"] = precision
            epoch_stats["test_R"] = recall
            epoch_stats["epoch_testing_time"] = time_end - time_start
            logger.info("%s\n" % report)

        if args.epoch_save_model:
            epoch_output_dir = os.path.join(args.output_dir, "e%03d" % epoch_no)
            # FIX: tolerate an existing directory (e.g. resumed/re-run epochs)
            # instead of crashing with FileExistsError.
            os.makedirs(epoch_output_dir, exist_ok=True)
            model.save(epoch_output_dir)

        if args.wandb:
            wandb.log(epoch_stats)

    # Final evaluation reports, written next to the saved model.
    if args.data_tune:
        eval_data = create_dataset(val_features)
        f1, precision, recall, report = evaluate_model(model, eval_data, label_list, args.eval_batch_size, device)
        logger.info("\n%s", report)
        output_eval_file = os.path.join(args.output_dir, "valid_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Writing results to file *****")
            writer.write(report)
            logger.info("Done.")

    if args.data_test:
        eval_data = create_dataset(eval_features)
        f1, precision, recall, report = evaluate_model(model, eval_data, label_list, args.eval_batch_size, device)
        logger.info("\n%s", report)
        output_eval_file = os.path.join(args.output_dir, "test_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Writing results to file *****")
            writer.write(report)
            logger.info("Done.")

    if args.wandb:
        wandb.finish()
    # Release GPU memory held by the model before returning.
    del model
    torch.cuda.empty_cache()
if __name__ == "__main__":
    # Configure root logging with timestamped records.
    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO)
    # Mirror all log output to stdout as well.
    logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    parser = argparse.ArgumentParser()
    # Project helper registers every training flag on the parser.
    parser = add_train_args(parser)
    args = parser.parse_args()
    train_model(args)