Skip to content
Snippets Groups Projects
Select Git revision
  • 3630de8d41e18d27861fcc242b1ab77572726430
  • main default protected
  • ud_training_script
  • fix_seed
  • merged-with-ner
  • multiword_fix_transformer
  • transformer_encoder
  • combo3
  • save_deprel_matrix_to_npz
  • master protected
  • combo-lambo
  • lambo-sent-attributes
  • adding_lambo
  • develop
  • update_allenlp2
  • develop_tmp
  • tokens_truncation
  • LR_test
  • eud_iwpt
  • iob
  • eud_iwpt_shared_task_bert_finetuning
  • 3.3.1
  • list
  • 3.2.1
  • 3.0.3
  • 3.0.1
  • 3.0.0
  • v1.0.6
  • v1.0.5
  • v1.0.4
  • v1.0.3
  • v1.0.2
  • v1.0.1
  • v1.0.0
34 results

find_lr.py

Blame
  • Łukasz Pszenny's avatar
    Łukasz Pszenny authored
    8610de4c
    History
    find_lr.py 3.94 KiB
    """
    This script is designed for finding a learning rate for the specified architecture of a Named Entity Recognition (NER) model.
    It loads configuration settings from a JSON file and performs a learning rate search using a PyTorch Lightning learning
    rate finder. At the end of the script, it prints the suggested learning rate obtained from the learning rate finder.
    
    Arguments:
    ----------
    --config_path : str
        Path to the configuration file in JSON format.
    --data_path : str, optional
        Path to the data directory. If provided, it overrides the data path in the configuration.
    --check_config : bool, optional
        Flag to check the validity of the configuration.
    """
    from pathlib import Path
    import pytorch_lightning as pl
    from pytorch_lightning.tuner import Tuner
    import json
    import argparse
    from combo.ner_modules.data.utils import create_tag2id, create_char2id, calculate_longest_word
    from combo.ner_modules.NerModel import NerModel
    from combo.ner_modules.utils.utils import check_config_constraints
    from combo.ner_modules.utils.constructors import construct_loss_from_config, construct_tokenizer_from_config, construct_data_module_from_config
    import torch
    from combo.ner_modules.utils.utils import fix_common_warnings
    
    # Pin float32 matmul precision to "medium" to make lightning happy.
    torch.set_float32_matmul_precision("medium")

    # Command-line interface: a config path, an optional data-path override,
    # and a switch that enables the config check.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_path",
        default="",
        help="Path to config file",
    )
    parser.add_argument(
        "--data_path",
        default="",
        help="Path to data, if not provided taken from config",
    )
    parser.add_argument(
        "--check_config",
        action="store_true",
        help="Flag whether check on config should be done",
    )
    args = parser.parse_args()

    # Module-level values read by the script's entry point.
    PATH_CONFIG_FILE, DATA_PATH = Path(args.config_path), args.data_path
    
    if __name__ == "__main__":
        # Entry point: load the config, build the tokenizer / data module /
        # loss / model from it, run the Lightning learning-rate finder and
        # print its suggestion.
        fix_common_warnings()

        # Load the JSON config. The context manager guarantees the file handle
        # is closed even if parsing fails; the encoding is fixed to UTF-8 so
        # the result does not depend on the platform's default encoding.
        with open(PATH_CONFIG_FILE, encoding="utf-8") as config_file:
            config = json.load(config_file)

        # Override the data path from the config when one is given on the
        # command line
        if DATA_PATH:
            config["data"]["path_data"] = DATA_PATH

        # Check config
        if args.check_config:
            config = check_config_constraints(config)

        # The vocabularies and the max word length are all derived from the
        # same training file, so resolve its path and encoding once (after the
        # optional override / config check above).
        train_file = Path(config["data"]["path_data"]) / "train.txt"
        data_encoding = config["data"]["encoding"]

        # create vocabularies
        char_to_id = create_char2id(file_path=train_file,
                                    encoding=data_encoding)
        label_to_id = create_tag2id(file_path=train_file,
                                    encoding=data_encoding,
                                    include_special_tokens=config["data"]["use_start_end_token"])

        # Extract max word length and store it in the config for the
        # constructors below
        config["data"]["max_word_len"] = calculate_longest_word(file_path=train_file,
                                                                encoding=data_encoding)

        # construct tokenizer
        tokenizer = construct_tokenizer_from_config(config=config,
                                                    char_to_id_map=char_to_id,
                                                    label_to_id_map=label_to_id)
        # construct data module
        data_module = construct_data_module_from_config(config=config,
                                                        tokenizer=tokenizer)

        # construct loss
        loss = construct_loss_from_config(config=config,
                                          label_to_id=label_to_id)

        model = NerModel(loss_fn=loss,
                         char_to_id_map=char_to_id,
                         label_to_id_map=label_to_id,
                         config=config)

        # Create an instance of the PyTorch Lightning Trainer; devices and
        # accelerator come from the config's "trainer" section
        trainer = pl.Trainer(devices=config["trainer"]["devices"],
                             max_epochs=-1,
                             accelerator=config["trainer"]["accelerator"],
                             log_every_n_steps=10)
        tuner = Tuner(trainer)

        # Run the learning rate finder (not a full training run)
        lr_finder = tuner.lr_find(model, datamodule=data_module)

        # Print learning rate
        print(f"Found learning rate:{lr_finder.suggestion()}")