"""
This script finds a learning rate for a specified Named Entity Recognition (NER) model architecture.
It loads configuration settings from a JSON file and performs a learning rate search using the PyTorch Lightning
learning rate finder. At the end, it prints the suggested learning rate obtained from the finder.
Arguments:
----------
--config_path : str
Path to the configuration file in JSON format.
--data_path : str, optional
Path to the data directory. If provided, it overrides the data path in the configuration.
--check_config : bool, optional
Flag to check the validity of the configuration.
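
Example (illustrative paths):
-----------------------------
    python find_lr.py --config_path configs/ner_config.json --data_path data/ --check_config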
"""
import argparse
import json
from pathlib import Path

import pytorch_lightning as pl
import torch
from pytorch_lightning.tuner import Tuner

from combo.ner_modules.data.utils import calculate_longest_word, create_char2id, create_tag2id
from combo.ner_modules.NerModel import NerModel
from combo.ner_modules.utils.constructors import (construct_data_module_from_config,
                                                  construct_loss_from_config,
                                                  construct_tokenizer_from_config)
from combo.ner_modules.utils.utils import check_config_constraints, fix_common_warnings
torch.set_float32_matmul_precision("medium")  # trade a little matmul precision for speed; silences Lightning's Tensor Core warning
# Argument parsing
parser = argparse.ArgumentParser()
parser.add_argument("--config_path", action="store", default="", help="Path to config file")
parser.add_argument("--data_path", action="store", default="", help="Path to data; if not provided, taken from config")
parser.add_argument("--check_config", action="store_true", help="Flag indicating whether the config should be validated")
args = parser.parse_args()
PATH_CONFIG_FILE = Path(args.config_path)
DATA_PATH = args.data_path
if __name__ == "__main__":
    fix_common_warnings()
    # Loading config file
    with open(PATH_CONFIG_FILE) as config_file:
        config = json.load(config_file)
    # Override the data path from the config if one was given on the command line
    if DATA_PATH:
        config["data"]["path_data"] = DATA_PATH
    # Check config
    if args.check_config:
        config = check_config_constraints(config)
    # create vocabularies
    char_to_id = create_char2id(file_path=Path(config["data"]["path_data"]) / "train.txt",
                                encoding=config["data"]["encoding"])
    label_to_id = create_tag2id(file_path=Path(config["data"]["path_data"]) / "train.txt",
                                encoding=config["data"]["encoding"],
                                include_special_tokens=config["data"]["use_start_end_token"])
    # Extract max word length
    config["data"]["max_word_len"] = calculate_longest_word(file_path=Path(config["data"]["path_data"]) / "train.txt",
                                                            encoding=config["data"]["encoding"])
    # construct tokenizer
    tokenizer = construct_tokenizer_from_config(config=config,
                                                char_to_id_map=char_to_id,
                                                label_to_id_map=label_to_id)
    # construct data module
    data_module = construct_data_module_from_config(config=config,
                                                    tokenizer=tokenizer)
    # construct loss
    loss = construct_loss_from_config(config=config,
                                      label_to_id=label_to_id)
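    # construct NER model from the loss function, vocabularies and config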
    model = NerModel(loss_fn=loss,
                     char_to_id_map=char_to_id,
                     label_to_id_map=label_to_id,
                     config=config)
    # Create an instance of the PyTorch Lightning Trainer
    trainer = pl.Trainer(devices=config["trainer"]["devices"],
                         max_epochs=-1,
                         accelerator=config["trainer"]["accelerator"],
                         log_every_n_steps=10)
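    # max_epochs=-1 removes the epoch cap; the LR finder below decides how many batches to run
    # Since PyTorch Lightning 2.0, lr_find is accessed through a Tuner wrapping the Trainer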
    tuner = Tuner(trainer)
    # Run the learning rate finder (this does not fully train the model)
    lr_finder = tuner.lr_find(model, datamodule=data_module)
    # Print suggested learning rate
    print(f"Found learning rate: {lr_finder.suggestion()}")