Commit 25860d3b authored by Maja Jablonska

Add Lightning training classes

parent ba511cf5
Part of merge request !46: Merge COMBO 3.0 into master
@@ -141,6 +141,17 @@ class Vocabulary(FromParameters):
                                            self._oov_token)
         self._retained_counter: Optional[Dict[str, Dict[str, int]]] = None
+        self._extend(
+            counter,
+            min_count,
+            max_vocab_size,
+            non_padded_namespaces,
+            pretrained_files,
+            only_include_pretrained_words,
+            tokens_to_add,
+            min_pretrained_embeddings
+        )
 
     def _extend(self,
                 counter: Dict[str, Dict[str, int]] = None,
                 min_count: Dict[str, int] = None,
...
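With this change the constructor populates the vocabulary itself, so callers no longer need to invoke `_extend` after construction. A minimal sketch of building a vocabulary straight from token counts, assuming `Vocabulary` is exported from `combo.data.vocabulary` and accepts the same keyword arguments it forwards to `_extend` (the counter layout follows the `Dict[str, Dict[str, int]]` hint above):

``` python
from combo.data.vocabulary import Vocabulary  # module path assumed from the notebook imports below

# namespace -> token -> count, matching the Dict[str, Dict[str, int]] type hint
counter = {"tokens": {"the": 10, "cat": 3, "sat": 1}}

vocab = Vocabulary(
    counter=counter,
    min_count={"tokens": 2},               # drop tokens seen fewer than twice
    non_padded_namespaces=["head_labels"]  # namespaces that get no padding/OOV entries
)
```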
%% Cell type:code id:initial_id tags:
``` python
from combo.data.dataset_readers import UniversalDependenciesDatasetReader
from combo.data.tokenizers import CharacterTokenizer
from combo.data.token_indexers import TokenConstPaddingCharactersIndexer, TokenFeatsIndexer, PretrainedTransformerFixedMismatchedIndexer, SingleIdTokenIndexer
from combo.data.dataset_loaders import SimpleDataLoader
from combo.data.vocabulary import FromInstancesVocabulary
```
%% Cell type:code id:abb6ce33c2e461e6 tags:
``` python
def default_const_character_indexer():
return TokenConstPaddingCharactersIndexer(
tokenizer=CharacterTokenizer(end_tokens=["__END__"],
start_tokens=["__START__"]),
min_padding_length=32,
namespace="lemma_characters"
)
dataset_reader = UniversalDependenciesDatasetReader(
features=["token", "char"],
lemma_indexers={
"char": default_const_character_indexer()
},
targets=["deprel", "head", "upostag", "lemma", "feats", "xpostag"],
token_indexers={
"char": default_const_character_indexer(),
"feats": TokenFeatsIndexer(),
"lemma": default_const_character_indexer(),
"token": PretrainedTransformerFixedMismatchedIndexer("bert-base-cased"),
"upostag": SingleIdTokenIndexer(
feature_name="pos_",
namespace="upostag"
),
"xpostag": SingleIdTokenIndexer(
feature_name="tag_",
namespace="xpostag"
)
},
use_sem=False
)
```
%% Cell type:code id:3519b6753622def0 tags:
``` python
FILE_PATH = '/Users/majajablonska/Documents/train.conllu'
data_loader = SimpleDataLoader.from_dataset_reader(dataset_reader,
data_path=FILE_PATH,
batch_size=4)
```
%% Output
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
- Avoid using `tokenizers` before the fork if possible
- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
%% Cell type:code id:eb23ae8415cb52c2 tags:
``` python
# pull a single instance from the loader, e.g. to check that reading the file works
for i in data_loader.iter_instances():
    break
```
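To see what the reader actually produced, the single instance pulled above can be inspected. A quick sketch, assuming COMBO keeps the AllenNLP-style `Instance` API (a `fields` mapping and a readable string form); the field names depend on the reader configuration above:

``` python
# field names created by the dataset reader for this sentence
print(list(i.fields.keys()))

# printing the instance gives a readable summary of every field
print(i)
```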
%% Cell type:code id:834f448f90453d03 tags:
``` python
vocabulary = FromInstancesVocabulary.from_instances_extended(
data_loader.iter_instances(),
non_padded_namespaces=['head_labels'],
only_include_pretrained_words=True,
oov_token='_',
padding_token='__PAD__'
)
```
%% Output
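Before moving on it is worth checking that the expected namespaces were built and that the explicit `oov_token`/`padding_token` ended up where the indexers expect them. A sketch, assuming the vocabulary keeps the AllenNLP-style accessors `get_vocab_size` and `get_index_to_token_vocabulary`:

``` python
# number of distinct character symbols collected for lemmas
print(vocabulary.get_vocab_size("lemma_characters"))

# a peek at the UPOS namespace filled by the SingleIdTokenIndexer above
upos = vocabulary.get_index_to_token_vocabulary("upostag")
print(dict(list(upos.items())[:10]))
```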
%% Cell type:code id:82d4c789c15866ab tags:
``` python
```
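The notebook stops at an empty cell; a natural continuation is to wrap a COMBO `Model` in the `TrainableCombo` Lightning module added by this commit (see the diff further down) and hand it to a `pytorch_lightning.Trainer`. This is a sketch under explicit assumptions: `model` is a `combo.modules.model.Model` built with `vocabulary` elsewhere (model construction is not part of this diff), the import path of `TrainableCombo` mirrors the `combo.training` package that already exports `Scheduler`, and `SimpleDataLoader` supports the AllenNLP-style `index_with` call; all hyperparameter values are illustrative:

``` python
import pytorch_lightning as pl
import torch

from combo.training import TrainableCombo  # exact module path is an assumption

# `model` is assumed to be a combo Model constructed with `vocabulary` elsewhere.
nlp_module = TrainableCombo(
    model,
    optimizer_type=torch.optim.Adam,
    optimizer_kwargs={"lr": 0.002},
    scheduler_kwargs={"patience": 6, "decreases": 2},
    validation_metrics=["loss"],  # keys must match those returned by model.get_metrics()
)

data_loader.index_with(vocabulary)  # assumption: the loader is indexed with the vocabulary first
trainer = pl.Trainer(max_epochs=1, gradient_clip_val=5.0)
trainer.fit(model=nlp_module, train_dataloaders=data_loader)
```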
%% Cell type:markdown id:9a4de0a90632538 tags:
-class Scheduler:
-    pass
+import torch
+from typing import Callable, List, Union
+
+from overrides import overrides
+
+
+class Scheduler(torch.optim.lr_scheduler.LambdaLR):
+
+    def __init__(self,
+                 optimizer: torch.optim.Optimizer,
+                 patience: int = 6,
+                 decreases: int = 2,
+                 threshold: float = 1e-3,
+                 last_epoch: int = -1,
+                 verbose: bool = False):
+        super().__init__(optimizer, [self._lr_lambda], last_epoch, verbose)
+        self.patience = patience
+        self.decreases = decreases
+        self.threshold = threshold
+        self.start_patience = patience
+        self.best_score = 0.0
+
+    @staticmethod
+    def _lr_lambda(idx: int) -> float:
+        return 1.0 / (1.0 + idx * 1e-4)
+
+    def step(self, metric: float = None) -> None:
+        super().step()
+        if metric is not None:
+            if metric - self.best_score > self.threshold:
+                self.best_score = metric if metric > self.best_score else self.best_score
+                self.patience = self.start_patience
+            else:
+                if self.patience <= 1:
+                    if self.decreases == 0:
+                        # The Trainer should trigger early stopping
+                        self.patience = 0
+                    else:
+                        self.patience = self.start_patience
+                        self.decreases -= 1
+                        self.threshold /= 2
+                        self.base_lrs = [x / 2 for x in self.base_lrs]
+                else:
+                    self.patience -= 1
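The scheduler combines a slow per-step decay (`_lr_lambda`) with a plateau rule: as long as the monitored metric improves by more than `threshold`, nothing extra happens; once it stalls for `patience` steps the base learning rates are halved, at most `decreases` times, after which `patience` is set to 0 as a signal for the trainer to stop early. A small standalone sketch of that behaviour, reusing the `from combo.training import Scheduler` import that appears in the Lightning module diff below and a made-up metric trace:

``` python
import torch

from combo.training import Scheduler

# a dummy parameter so the optimizer has something to manage
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=2e-3)
scheduler = Scheduler(optimizer, patience=2, decreases=1, threshold=1e-3)

# the metric improves twice and then plateaus, triggering one halving of base_lrs
for metric in [0.10, 0.20, 0.20, 0.20, 0.20]:
    optimizer.step()
    scheduler.step(metric=metric)
    print(scheduler.patience, scheduler.decreases, scheduler.base_lrs)
```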
-from typing import Optional, Type
+from typing import Any, Dict, List, Optional, Type
 
 import pytorch_lightning as pl
+import torch
 from torch import Tensor
 
 from combo.config import FromParameters
 from combo.data.dataset_loaders.dataset_loader import TensorDict
 from combo.modules.model import Model
+from combo.training import Scheduler
 
 
 class TrainableCombo(pl.LightningModule, FromParameters):
     def __init__(self,
                  model: Model,
-                 optimizer_type: Type,
-                 learning_rate: float = 0.1):
+                 optimizer_type: Type = torch.optim.Adam,
+                 optimizer_kwargs: Optional[Dict[str, Any]] = None,
+                 scheduler_type: Type = Scheduler,
+                 scheduler_kwargs: Optional[Dict[str, Any]] = None,
+                 validation_metrics: List[str] = None):
         super().__init__()
         self.model = model
         self._optimizer_type = optimizer_type
-        self._lr = learning_rate
+        self._optimizer_kwargs = optimizer_kwargs if optimizer_kwargs else {}
+        self._scheduler_type = scheduler_type
+        self._scheduler_kwargs = scheduler_kwargs if scheduler_kwargs else {}
+        self._validation_metrics = validation_metrics if validation_metrics else []
 
     def forward(self, batch: TensorDict) -> TensorDict:
         return self.model.batch_outputs(batch, self.model.training)

@@ -28,8 +39,17 @@ class TrainableCombo(pl.LightningModule, FromParameters):
     def validation_step(self, batch: TensorDict, batch_idx: int) -> Tensor:
         output = self.forward(batch)
-        self.log("validation_loss", output['loss'], on_step=True, on_epoch=True, prog_bar=True, logger=True)
+        metrics = self.model.get_metrics()
+        for k in metrics.keys():
+            if k in self._validation_metrics:
+                self.log(k, metrics[k], on_epoch=True, prog_bar=True, logger=True)
         return output["loss"]
 
+    def lr_scheduler_step(self, scheduler: torch.optim.lr_scheduler, metric: Optional[Any]) -> None:
+        scheduler.step(metric=metric)
+
     def configure_optimizers(self):
-        return self._optimizer_type(self.model.parameters(), lr=self._lr)
+        optimizer = self._optimizer_type(self.model.parameters(), **self._optimizer_kwargs)
+        return ([optimizer],
+                [{'scheduler': self._scheduler_type(optimizer, **self._scheduler_kwargs),
+                  'interval': 'epoch'}])