Commit bb4757a5 authored by Maja Jablonska

Refine "finetuning" option in CLI

parent c2e015c5
1 merge request: !46 Merge COMBO 3.0 into master
@@ -65,8 +65,8 @@ def default_ud_dataset_reader() -> UniversalDependenciesDatasetReader:
 def default_data_loader(dataset_reader: DatasetReader,
                         file_path: str,
-                        batch_size: int,
-                        batches_per_epoch: int) -> SimpleDataLoader:
+                        batch_size: int = 16,
+                        batches_per_epoch: int = 4) -> SimpleDataLoader:
     return SimpleDataLoader.from_dataset_reader(dataset_reader,
                                                 data_path=file_path,
                                                 batch_size=batch_size,
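
Defaulting batch_size to 16 and batches_per_epoch to 4 lets callers, such as the new finetune branch further down, build a loader from just a reader and a file path. A minimal usage sketch (the CoNLL-U path is hypothetical):

```python
from default_model import default_ud_dataset_reader, default_data_loader

# Uses the new defaults: batch_size=16, batches_per_epoch=4.
loader = default_data_loader(default_ud_dataset_reader(), "train.conllu")

# Explicit values still override the defaults.
big_loader = default_data_loader(default_ud_dataset_reader(), "train.conllu",
                                 batch_size=64, batches_per_epoch=8)
```
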
......
 import json
 import logging
 import os
 import pathlib
 from typing import Dict
@@ -14,8 +13,7 @@ from combo.training.trainable_combo import TrainableCombo
 from combo.utils import checks
 from config import resolve
 from data.dataset import UniversalDependenciesDatasetReader
-from default_model import default_ud_dataset_reader
+from default_model import default_ud_dataset_reader, default_data_loader
 from models import ComboModel
 from predict import COMBO
@@ -24,7 +22,7 @@ _FEATURES = ["token", "char", "upostag", "xpostag", "lemma", "feats"]
 _TARGETS = ["deprel", "feats", "head", "lemma", "upostag", "xpostag", "semrel", "sent", "deps"]
 FLAGS = flags.FLAGS
-flags.DEFINE_enum(name="mode", default=None, enum_values=["train", "predict"],
+flags.DEFINE_enum(name="mode", default=None, enum_values=["train", "predict", "finetune"],
                   help="Specify COMBO mode: train or predict")
 # Common flags
@@ -59,11 +57,11 @@ flags.DEFINE_boolean(name="tensorboard", default=False,
                      help="When provided model will log tensorboard metrics.")
 # Finetune after training flags
-flags.DEFINE_list(name="finetuning_training_data_path", default="",
+flags.DEFINE_string(name="finetuning_training_data_path", default="",
                   help="Training data path(s)")
-flags.DEFINE_list(name="finetuning_validation_data_path", default="",
+flags.DEFINE_string(name="finetuning_validation_data_path", default="",
                   help="Validation data path(s)")
-flags.DEFINE_string(name="config_path", default=str(pathlib.Path(__file__).parent / "config.template.json"),
+flags.DEFINE_string(name="config_path", default=str(pathlib.Path(__file__).parent / "params.json"),
                    help="Config file path.")
 # Test after training flags
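
The switch from DEFINE_list to DEFINE_string matters because these flags now feed default_data_loader, which expects a single file path rather than a list. A small sketch of the parsing difference, with illustrative flag names:

```python
from absl import flags

flags.DEFINE_list(name="as_list", default="", help="Comma-split into a list.")
flags.DEFINE_string(name="as_string", default="", help="Kept as one string.")

FLAGS = flags.FLAGS
FLAGS(["prog", "--as_list=a.conllu", "--as_string=a.conllu"])  # parse a sample argv

print(FLAGS.as_list)    # ['a.conllu'] -- a Python list
print(FLAGS.as_string)  # 'a.conllu'   -- a plain str, usable as a file path
```
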
@@ -90,8 +88,8 @@ flags.DEFINE_enum(name="predictor_name", default="combo-lambo",
                   help="Use predictor with whitespace, spacy or lambo (recommended) tokenizer.")
-def get_model_for_training() -> TrainableCombo:
-    with open(FLAGS.config_path, 'rb') as f:
+def get_model_for_training(config_path) -> TrainableCombo:
+    with open(config_path, 'rb') as f:
         config_json = json.load(f)
     train_data_loader = resolve(config_json['train_data_loader'])
     val_data_loader = resolve(config_json['val_data_loader'])
@@ -109,16 +107,20 @@ def get_model_for_training() -> TrainableCombo:
     return model
+def get_saved_model(parameters) -> ComboModel:
+    return ComboModel.load(os.path.join(FLAGS.model_path),
+                           config=parameters,
+                           weights_file=os.path.join(FLAGS.model_path, 'best.th'),
+                           cuda_device=FLAGS.cuda_device)
 def get_predictor() -> COMBO:
     # Check for GPU
     # allen_checks.check_for_gpu(FLAGS.cuda_device)
     checks.file_exists(FLAGS.model_path)
     with open(os.path.join(FLAGS.model_path, 'params.json'), 'r') as f:
         serialized = json.load(f)
-    model = ComboModel.load(serialized,
-                            os.path.join(FLAGS.model_path),
-                            os.path.join(FLAGS.model_path, 'best.th'),
-                            cuda_device=FLAGS.cuda_device)
+    model = get_saved_model(serialized)
     if 'dataset_reader' in serialized:
         dataset_reader = resolve(serialized['dataset_reader'])
     else:
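
Factoring out get_saved_model gives the predict and finetune branches one shared restore path. A sketch of that sequence, assuming main.py's flags are already defined and parsed:

```python
import json
import os

# Both CLI branches read the serialized config, then delegate to the
# helper, which resolves best.th inside FLAGS.model_path.
with open(os.path.join(FLAGS.model_path, 'params.json'), 'r') as f:
    serialized = json.load(f)
model = get_saved_model(serialized)
```
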
@@ -128,7 +130,7 @@ def get_predictor() -> COMBO:
 def run(_):
     if FLAGS.mode == 'train':
-        trained_nlp = get_model_for_training()
+        trained_nlp = get_model_for_training(FLAGS.config_path)
         trained_nlp.save(FLAGS.serialization_dir)
     elif FLAGS.mode == 'predict':
         predictor = get_predictor()
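
Taking the config path as an argument, instead of reading FLAGS inside the function, makes the training entry point callable directly, e.g. from a test. A hypothetical direct call, with illustrative paths:

```python
trained_nlp = get_model_for_training("experiments/exp1/config.json")
trained_nlp.save("experiments/exp1/model/")
```
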
@@ -138,6 +140,29 @@ def run(_):
             for token in prediction.tokens:
                 print("{:15} {:15} {:10} {:10} {:10}".format(token.text, token.lemma, token.upostag, token.head,
                                                              token.deprel))
+    elif FLAGS.mode == 'finetune':
+        checks.file_exists(FLAGS.model_path)
+        with open(os.path.join(FLAGS.model_path, 'params.json'), 'r') as f:
+            serialized = json.load(f)
+        if 'dataset_reader' in serialized:
+            dataset_reader = resolve(serialized['dataset_reader'])
+        else:
+            dataset_reader = default_ud_dataset_reader()
+        model = get_saved_model(serialized)
+        nlp = TrainableCombo(model, torch.optim.Adam,
+                             optimizer_kwargs={'betas': [0.9, 0.9], 'lr': 0.002},
+                             validation_metrics=['EM'])
+        trainer = pl.Trainer(max_epochs=FLAGS.num_epochs,
+                             default_root_dir=FLAGS.serialization_dir,
+                             gradient_clip_val=5)
+        train_data_loader = default_data_loader(dataset_reader,
+                                                FLAGS.finetuning_training_data_path)
+        val_data_loader = default_data_loader(dataset_reader,
+                                              FLAGS.finetuning_validation_data_path)
+        train_data_loader.index_with(model.vocab)
+        val_data_loader.index_with(model.vocab)
+        trainer.fit(model=nlp, train_dataloaders=train_data_loader, val_dataloaders=val_data_loader)
+        model.save(FLAGS.serialization_dir)
 def _get_ext_vars(finetuning: bool = False) -> Dict:
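
One detail in the finetune branch worth calling out: the fresh loaders are indexed with the restored model's vocabulary, so the fine-tuning data maps through the same token ids the trained embeddings expect, and the loaders pick up the new defaults. A reduced sketch with a hypothetical data path:

```python
reader = default_ud_dataset_reader()
# batch_size=16 and batches_per_epoch=4 come from the new loader defaults.
loader = default_data_loader(reader, "finetune_train.conllu")
loader.index_with(model.vocab)  # model: the ComboModel from get_saved_model
```
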
......
@@ -7,7 +7,7 @@ import logging
 import os
 import re
 from os import PathLike
-from typing import Dict, List, Set, Type, Optional, Union
+from typing import Dict, List, Set, Type, Optional, Union, Any
 import numpy
 import torch
@@ -401,8 +401,8 @@ class Model(Module, FromParameters):
     @classmethod
     def load(
         cls,
-        config: Params,
         serialization_dir: Union[str, PathLike],
+        config: Optional[Union[Union[str, PathLike], Dict[str, Any]]] = None,
         weights_file: Optional[Union[str, PathLike]] = None,
         cuda_device: int = -1,
     ) -> "Model":
@@ -432,6 +432,12 @@ class Model(Module, FromParameters):
         The model specified in the configuration, loaded with the serialized
         vocabulary and the trained weights.
         """
+        if config is None:
+            with open(os.path.join(serialization_dir, 'params.json'), 'r') as f:
+                config = json.load(f)
+        elif isinstance(config, str) or isinstance(config, PathLike):
+            with open(config, 'r') as f:
+                config = json.load(f)
         # Peek at the class of the model.
         model_type = (
......
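
With the relaxed signature, Model.load accepts a parsed dict, a path to a params file, or nothing at all. A sketch of the three equivalent invocations, using a hypothetical saved/ directory:

```python
import json

from models import ComboModel

# 1. No config: params.json is read from the serialization directory.
model = ComboModel.load("saved/")

# 2. Config given as a path to a JSON file.
model = ComboModel.load("saved/", config="saved/params.json")

# 3. Config given as an already-parsed dict, as get_saved_model does.
with open("saved/params.json", "r") as f:
    model = ComboModel.load("saved/", config=json.load(f))
```
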