train.py
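    """Fine-tune a Pdn2TokenClassification model for named entity recognition.

    The script reads the datasets given by data_train/data_tune/data_test,
    builds features with the configured sequence generator, trains with AdamW
    and a linear warmup schedule, and writes checkpoints and evaluation
    reports to the resolved output_dir.

    Example invocation (the exact command-line flags are defined in
    add_train_args; the names below are assumed for illustration only):

        python train.py \
            --data_train data/train.iob \
            --data_tune data/tune.iob \
            --pretrained_path xlm-roberta-large \
            --output_dir models/{pretrained_path}_seq_{max_seq_length} \
            --num_train_epochs 10 --max_seq_length 256 --seed 42
    """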
    from __future__ import absolute_import, division, print_function
    
    import argparse
    import glob
    import json
    import logging
    import os
    import sys
    import time
    import re
    from argparse import Namespace
    from pathlib import Path
    
    import torch
    from tqdm import tqdm
    from pytorch_transformers import WarmupLinearSchedule, AdamW
    from torch.utils.data import DataLoader, RandomSampler
    
    from poldeepner2.model.hf_for_token_calssification import (
        Pdn2ModelConfiguration, Pdn2TokenClassification)
    from poldeepner2.utils.seed import setup_seed
    from poldeepner2.utils.sequences import FeatureGeneratorFactory
    from poldeepner2.utils.train_utils import add_train_args, evaluate_model
    from poldeepner2.utils.data_utils import NerProcessor, create_dataset
    
    
    def generate_unique_suffix(output_dir: str) -> str:
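        """Return a numeric suffix such as "_001" when output_dir already exists.

        Paths sharing the output_dir prefix are scanned for trailing numeric
        suffixes and the next free index is returned, so repeated runs do not
        reuse an existing directory. An empty string is returned when the
        directory does not exist yet.
        """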
        if not Path(output_dir).exists():
            return ""
        max_suffix = 0
        for path in glob.glob(output_dir + "*"):
            suffix = path.split("_")[-1]
            suffix = int(suffix) if suffix.isdigit() else 0
            max_suffix = max(max_suffix, suffix)
        return f"_{max_suffix + 1:03d}"
    
    
    def get_output_dir(args: Namespace) -> str:
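        """Resolve {parameter} placeholders in args.output_dir.

        Each placeholder is replaced with the value of the corresponding args
        attribute, with characters outside [a-zA-Z0-9_/-] mapped to
        underscores; an unknown parameter name raises ValueError.
        """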
        parts = re.split(r"({[^}]+})", args.output_dir.lstrip())
        path = []
        for p in parts:
            if p.startswith("{"):
                name = p[1:-1]
                if name not in args.__dict__:
                    raise ValueError(f"Invalid parameter name in output_dir: {name}")
                path.append(re.sub(r"[^a-zA-Z0-9_/-]", "_", str(args.__dict__[name])))
            else:
                path.append(p)
        return "".join(path)
    
    
    def train_model(args: Namespace):
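        """Run the full training pipeline described by args.

        Resolves the output directory, loads data and the pretrained encoder,
        trains for num_train_epochs, evaluates on the tune/test sets after
        each epoch and, when a tune set is given, saves the checkpoint with
        the best validation F1.
        """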
        args.output_dir = get_output_dir(args)
        suffix = generate_unique_suffix(args.output_dir)
        args.output_dir += suffix
    
        config = {
            "epochs": args.num_train_epochs,
            "language_model": args.pretrained_path,
            "batch_size": args.train_batch_size,
            "data_train": args.data_train,
            "data_tune": args.data_tune,
            "data_test": args.data_test,
            "max_seq_length": args.max_seq_length,
            "warmup_proportion": args.warmup_proportion,
            "learning_rate": args.learning_rate,
            "gradient_accumulation_steps": args.gradient_accumulation_steps,
            "sequence_generator": args.sequence_generator,
            "dropout": args.dropout,
            "output_dir": args.output_dir
        }
    
        if args.wandb:
            import wandb
            wandb.init(project=args.wandb, config=config)
            wandb.run.name = f"seq_{args.max_seq_length}_seed_{args.seed}_epoch_{args.num_train_epochs}{suffix}"
            wandb.run.save()
    
        if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
            raise ValueError("Output directory (%s) already exists and is not "
                             "empty." % args.output_dir)
    
        Path(args.output_dir).mkdir(parents=True, exist_ok=True)
        with open(str(Path(args.output_dir) / "train_args.json"), "w", encoding="utf-8") as f:
            json.dump(args.__dict__, f, indent=4)
    
        logger = logging.getLogger(__name__)
        for item in sorted(config.items()):
            logger.info(item)
    
        if args.gradient_accumulation_steps < 1:
            raise ValueError("Invalid gradient_accumulation_steps parameter: "
                             f"{args.gradient_accumulation_steps}, should be >= 1")
    
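        # Split the requested batch size across gradient accumulation steps;
        # the effective batch size per optimizer update stays the same.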
        args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
        setup_seed(args.seed)
    
        # Determine set of labels
        processor = NerProcessor()
        datasets = [args.data_train]
        if args.data_tune:
            datasets.append(args.data_tune)
        if args.data_test:
            datasets.append(args.data_test)
        label_list = processor.get_labels(datasets)
        logger.info(f"Labels: {label_list}")
        num_labels = len(label_list) + 1  # add one for IGNORE label
        logger.info(f"Number of labels: {num_labels}")
    
        # Load training data
        logger.info("Loading training data...")
        t0 = time.time()
        train_examples = processor.get_examples(args.data_train, "train")
        logger.info(f"Training data was loaded in {time.time() - t0:.2f} second(s)")
    
        # preparing model configs
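        # Infer the encoder hidden size from the checkpoint name; fall back to
        # args.hidden_size when no size keyword matches.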
        hidden_size = 4096 if 'xxlarge' in args.pretrained_path else \
            2048 if 'xlarge' in args.pretrained_path else \
            1024 if 'large' in args.pretrained_path else \
            768 if 'base' in args.pretrained_path else args.hidden_size
        device = args.device
    
        logger.info("Loading pretrained model...")
        t0 = time.time()
        config = Pdn2ModelConfiguration(
            labels=label_list,
            hidden_size=hidden_size,
            dropout_p=args.dropout,
            device=device,
            max_seq_length=args.max_seq_length,
            sequence_generator=args.sequence_generator,
            seed=args.seed
        )
        model = Pdn2TokenClassification(path=args.pretrained_path, config=config)
        model.to(device)
        logger.info(f"Pretrained model was loaded in {time.time() - t0:.2f} second(s)")
    
        gen = FeatureGeneratorFactory.create(args.sequence_generator,
                                             label_list=label_list,
                                             max_seq_length=args.max_seq_length,
                                             encode_method=model.encode_word)
    
        if args.sequence_generator_for_eval is not None:
            gen_eval = FeatureGeneratorFactory.create(args.sequence_generator_for_eval,
                                                      label_list=label_list,
                                                      max_seq_length=args.max_seq_length,
                                                      encode_method=model.encode_word)
        else:
            gen_eval = gen
    
        train_features = gen.generate(train_examples)
    
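        # Total number of optimizer updates over training, used to size the
        # warmup schedule below.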
        num_train_optimization_steps = int(
            len(train_features) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
    
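        # Exclude biases and layer-norm weights from weight decay, as is usual
        # when fine-tuning transformer models.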
        no_decay = ['bias', 'final_layer_norm.weight']
        params = list(model.named_parameters())
        optimizer_grouped_parameters = [
            {'params': [p for n, p in params if not any(
                nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
            {'params': [p for n, p in params if any(
                nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
    
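        # Linearly warm the learning rate up for the first warmup_proportion of
        # updates, then decay it linearly to zero.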
        warmup_steps = int(args.warmup_proportion * num_train_optimization_steps)
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
        scheduler = WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps, t_total=num_train_optimization_steps)
    
        # freeze model if necessary
        if args.freeze_model:
            logger.info("Freezing XLM-R model...")
            for n, p in model.named_parameters():
                if 'encoder' in n and p.requires_grad:
                    logger.info("Parameter %s - frozen" % n)
                    p.requires_grad = False
                else:
                    logger.info("Parameter %s - unchanged" % n)
    
        if args.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex "
                    "to use fp16 training.")
            model, optimizer = amp.initialize(
                model, optimizer, opt_level=args.fp16_opt_level)
    
        # Train the model
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
    
        train_data = create_dataset(train_features)
        train_sampler = RandomSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
    
        # getting validation samples
        best_val_f1 = 0.0
        if args.data_tune:
            val_examples = processor.get_examples(args.data_tune, "tune")
            val_features = gen_eval.generate(val_examples)
            val_data = create_dataset(val_features)
    
        if args.data_test:
            eval_examples = processor.get_examples(args.data_test, "test")
            eval_features = gen_eval.generate(eval_examples)
            eval_data = create_dataset(eval_features)
    
        for epoch_no in range(1, args.num_train_epochs + 1):
            epoch_stats = {"epoch": epoch_no}
            logger.info("Epoch %d" % epoch_no)
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
    
            model.train()
            steps = len(train_dataloader)
    
            time_start = time.time()
    
            for step, batch in tqdm(enumerate(train_dataloader), total=steps):
                batch = tuple(t.to(device) for t in batch)
                input_ids, label_ids, valid_ids = batch
                loss = model(input_ids, label_ids, valid_ids)
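                # Scale the loss so gradients accumulated over several
                # micro-batches average out to one effective batch.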
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
    
                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
    
                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
    
                epoch_stats["loss"] = loss.item()
                epoch_stats["learning_rate"] = scheduler.get_last_lr()[0]
    
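                # Step the optimizer and scheduler only once every
                # gradient_accumulation_steps micro-batches.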
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    scheduler.step()
                    model.zero_grad()
    
                epoch_stats["step"] = step
                if args.wandb:
                    wandb.log(epoch_stats)
    
            if args.wandb:
                epoch_stats["epoch_training_time"] = time.time() - time_start
    
            if args.data_tune:
                logger.info("Testing on validation set...")
                time_start = time.time()
                f1, precision, recall, report = evaluate_model(model, val_data, label_list, args.eval_batch_size, device)
                time_end = time.time()
                epoch_stats["validation_F1"] = f1
                epoch_stats["validation_P"] = precision
                epoch_stats["validation_R"] = recall
                epoch_stats["epoch_validation_time"] = time_end - time_start
    
                if f1 > best_val_f1:
                    best_val_f1 = f1
                    logger.info("Found better f1=%.4f on validation set. "
                                "Saving model" % f1)
                    logger.info("%s\n" % report)
                    model.save(args.output_dir)
                else:
                    logger.info("%s\n" % report)
    
            if args.data_test:
                logger.info("\nTesting on test set...")
                time_start = time.time()
                f1, precision, recall, report = evaluate_model(model, eval_data, label_list, args.eval_batch_size, device)
                time_end = time.time()
                epoch_stats["test_F1"] = f1
                epoch_stats["test_P"] = precision
                epoch_stats["test_R"] = recall
                epoch_stats["epoch_testing_time"] = time_end - time_start
                logger.info("%s\n" % report)
    
            if args.epoch_save_model:
                epoch_output_dir = os.path.join(args.output_dir, "e%03d" % epoch_no)
                os.makedirs(epoch_output_dir, exist_ok=True)
                model.save(epoch_output_dir)
    
            if args.wandb:
                wandb.log(epoch_stats)
    
        if args.data_tune:
            eval_data = create_dataset(val_features)
            f1, precision, recall, report = evaluate_model(model, eval_data, label_list, args.eval_batch_size, device)
            logger.info("\n%s", report)
            output_eval_file = os.path.join(args.output_dir, "valid_results.txt")
            with open(output_eval_file, "w") as writer:
                logger.info("***** Writing results to file *****")
                writer.write(report)
                logger.info("Done.")
    
        if args.data_test:
            eval_data = create_dataset(eval_features)
            f1, precision, recall, report = evaluate_model(model, eval_data, label_list, args.eval_batch_size, device)
            logger.info("\n%s", report)
            output_eval_file = os.path.join(args.output_dir, "test_results.txt")
            with open(output_eval_file, "w") as writer:
                logger.info("***** Writing results to file *****")
                writer.write(report)
                logger.info("Done.")
    
        if args.wandb:
            wandb.finish()
    
        del model
        torch.cuda.empty_cache()
    
    
    if __name__ == "__main__":
        logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                            datefmt='%m/%d/%Y %H:%M:%S', level=logging.INFO)
        logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
    
        parser = argparse.ArgumentParser()
        parser = add_train_args(parser)
        args = parser.parse_args()
        train_model(args)