Commit ba511cf5 authored by Maja Jablonska

PyTorch Lightning training module

parent 76fb7c4a
1 merge request: !46 Merge COMBO 3.0 into master
......@@ -16,6 +16,7 @@ from combo.common.params import remove_keys_from_params, Params
from combo.config import FromParameters, Registry
from combo.data import Vocabulary, Instance
from combo.data.batch import Batch
from combo.data.dataset_loaders.dataset_loader import TensorDict
from combo.modules.module import Module
from combo.nn import util, RegularizerApplicator
from combo.utils import ConfigurationError
......@@ -148,6 +149,31 @@ class Model(Module, FromParameters):
"""
raise NotImplementedError
def batch_outputs(self, batch: TensorDict, for_training: bool) -> Dict[str, torch.Tensor]:
"""
Does a forward pass on the given batch and returns the output dictionary that the model
returns, after adding any specified regularization penalty to the loss (if training).
"""
output_dict = self.__call__(**batch)
if for_training:
try:
assert "loss" in output_dict
regularization_penalty = self.get_regularization_penalty()
if regularization_penalty is not None:
output_dict["reg_loss"] = regularization_penalty
output_dict["loss"] += regularization_penalty
except AssertionError:
if for_training:
raise RuntimeError(
"The model you are trying to optimize does not contain a"
" 'loss' key in the output of model.forward(inputs)."
)
return output_dict
def forward_on_instance(self, instance: Instance) -> Dict[str, numpy.ndarray]:
"""
Takes an [`Instance`](../data/instance.md), which typically has raw text in it, converts
......
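A minimal usage sketch, not part of this commit, of how batch_outputs could drive a plain training loop; the names model (a constructed Model subclass) and loader (an iterable of TensorDict batches), and the optimizer choice, are assumptions:

    import torch

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # hypothetical optimizer and learning rate

    model.train()
    for batch in loader:
        optimizer.zero_grad()
        # Forward pass; any regularization penalty is already folded into output_dict["loss"].
        output_dict = model.batch_outputs(batch, for_training=True)
        output_dict["loss"].backward()
        optimizer.step()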
......@@ -5,6 +5,7 @@ https://github.com/allenai/allennlp/blob/main/allennlp/nn/module.py#L14
from typing import List, Optional, Tuple
import pytorch_lightning as pl
import torch
from combo.nn.util import (
    _check_incompatible_keys,
......@@ -13,7 +14,7 @@ from combo.nn.util import (
)
class Module(pl.LightningModule):
class Module(torch.nn.Module):
"""
This is just `torch.nn.Module` with some extra functionality.
"""
......
import logging
import os
import sys
from typing import List, Union, Dict, Any
from typing import List, Union, Dict, Any, Optional, Type
import numpy as np
import torch
......@@ -67,6 +67,9 @@ class COMBO(PredictorModule):
        output['loss'].backward()
        return output
    def configure_optimizers(self) -> Any:
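        # Returns None (no optimizer); optimization appears to be handled by the new
        # TrainableCombo Lightning module introduced in this commit, not by this predictor.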
        pass
    def predict(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]]):
        if isinstance(sentence, str):
            return self.predict_json({"sentence": sentence})
......
......@@ -102,31 +102,6 @@ class PredictorModule(pl.LightningModule, FromParameters):
        new_instances = self.predictions_to_labeled_instances(instance, outputs)
        return new_instances

    def batch_outputs(self, batch: TensorDict, for_training: bool) -> Dict[str, torch.Tensor]:
        """
        Does a forward pass on the given batch and returns the output dictionary that the model
        returns, after adding any specified regularization penalty to the loss (if training).
        """
        output_dict = self._model(**batch)

        if for_training:
            try:
                assert "loss" in output_dict
                regularization_penalty = self._model.get_regularization_penalty()

                if regularization_penalty is not None:
                    output_dict["reg_loss"] = regularization_penalty
                    output_dict["loss"] += regularization_penalty

            except AssertionError:
                if for_training:
                    raise RuntimeError(
                        "The model you are trying to optimize does not contain a"
                        " 'loss' key in the output of model.forward(inputs)."
                    )

        return output_dict

    def get_gradients(self, instances: List[Instance]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        Gets the gradients of the loss with respect to the model inputs.
......
This diff is collapsed.
from .checkpointer import FinishingTrainingCheckpointer
from .scheduler import Scheduler
from .trainer import GradientDescentTrainer
from typing import Optional, Type

import pytorch_lightning as pl
from torch import Tensor

from combo.config import FromParameters
from combo.data.dataset_loaders.dataset_loader import TensorDict
from combo.modules.model import Model


class TrainableCombo(pl.LightningModule, FromParameters):
    def __init__(self,
                 model: Model,
                 optimizer_type: Type,
                 learning_rate: float = 0.1):
        super().__init__()
        self.model = model
        self._optimizer_type = optimizer_type
        self._lr = learning_rate

    def forward(self, batch: TensorDict) -> TensorDict:
        return self.model.batch_outputs(batch, self.model.training)

    def training_step(self, batch: TensorDict, batch_idx: int) -> Tensor:
        output = self.forward(batch)
        self.log("train_loss", output['loss'], on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return output["loss"]

    def validation_step(self, batch: TensorDict, batch_idx: int) -> Tensor:
        output = self.forward(batch)
        self.log("validation_loss", output['loss'], on_step=True, on_epoch=True, prog_bar=True, logger=True)
        return output["loss"]

    def configure_optimizers(self):
        return self._optimizer_type(self.model.parameters(), lr=self._lr)
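For context, a sketch of how TrainableCombo might be fitted with a Lightning Trainer; this is illustrative only, and model, train_loader, val_loader (DataLoaders yielding TensorDict batches) and the trainer arguments are assumptions rather than part of this commit:

    import pytorch_lightning as pl
    import torch

    # model, train_loader and val_loader are assumed to already exist.
    module = TrainableCombo(model=model,
                            optimizer_type=torch.optim.Adam,
                            learning_rate=0.002)
    trainer = pl.Trainer(max_epochs=5)  # hypothetical settings
    trainer.fit(module, train_dataloaders=train_loader, val_dataloaders=val_loader)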
from pytorch_lightning import Trainer


class Callback:
    pass


class TransferPatienceEpochCallback:
    pass


class GradientDescentTrainer(Trainer):
    pass