Commit 76fb7c4a authored by Maja Jablonska

Add training step

parent 4fb40473
1 merge request: !46 Merge COMBO 3.0 into master
@@ -5,8 +5,7 @@ https://github.com/allenai/allennlp/blob/main/allennlp/data/data_loaders/simple_
import math
import random
-from typing import Optional, List, Iterator
+from typing import Optional, List, Iterator, Callable
import torch
@@ -36,6 +35,7 @@ class SimpleDataLoader(DataLoader):
        shuffle: bool = False,
        batches_per_epoch: Optional[int] = None,
        vocab: Optional[Vocabulary] = None,
+        collate_fn: Optional[Callable[[List[Instance]], TensorDict]] = DefaultDataCollator()
    ) -> None:
        self.instances = instances
        self.batch_size = batch_size
@@ -44,7 +44,7 @@ class SimpleDataLoader(DataLoader):
        self.vocab = vocab
        self.cuda_device: Optional[torch.device] = None
        self._batch_generator: Optional[Iterator[TensorDict]] = None
-        self.collate_fn = DefaultDataCollator()
+        self.collate_fn = collate_fn

    def __len__(self) -> int:
        if self.batches_per_epoch is not None:
@@ -96,9 +96,10 @@ class SimpleDataLoader(DataLoader):
        shuffle: bool = False,
        batches_per_epoch: Optional[int] = None,
        quiet: bool = False,
+        collate_fn: Optional[Callable[[List[Instance]], TensorDict]] = DefaultDataCollator()
    ) -> "SimpleDataLoader":
        instance_iter = reader.read(data_path)
        if not quiet:
            instance_iter = Tqdm.tqdm(instance_iter, desc="loading instances")
        instances = list(instance_iter)
-        return cls(instances, batch_size, shuffle=shuffle, batches_per_epoch=batches_per_epoch)
+        return cls(instances, batch_size, shuffle=shuffle, batches_per_epoch=batches_per_epoch, collate_fn=collate_fn)
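With this change a caller can supply its own batching function instead of the default collator. A minimal usage sketch, assuming this classmethod keeps its AllenNLP name `from_dataset_reader` and that `my_reader` and `my_collate` are defined elsewhere (neither is part of this commit):

    # Hypothetical usage; the reader, data path, and collate function are assumptions.
    loader = SimpleDataLoader.from_dataset_reader(
        reader=my_reader,            # any DatasetReader
        data_path="train.conllu",    # assumed data file
        batch_size=32,
        shuffle=True,
        collate_fn=my_collate,       # Callable[[List[Instance]], TensorDict]
    )
    for batch in loader:             # each batch is the TensorDict my_collate returns
        ...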
This diff is collapsed.
@@ -164,6 +164,20 @@ def device_mapping(cuda_device: int):
    return inner_device_mapping


def find_text_field_embedder(model: torch.nn.Module) -> torch.nn.Module:
    """
    Takes a `Model` and returns the `Module` that is a `TextFieldEmbedder`. We return just the
    first one, as it's very rare to have more than one. If there isn't a `TextFieldEmbedder` in the
    given `Model`, we raise a `ValueError`.
    """
    from combo.modules.text_field_embedders.text_field_embedder import TextFieldEmbedder

    for module in model.modules():
        if isinstance(module, TextFieldEmbedder):
            return module
    raise ValueError("Couldn't find TextFieldEmbedder!")
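A short sketch of how the new helper might be used, e.g. to attach a forward hook to the embedder of an already-constructed model (`my_model` is a placeholder, not something this commit defines):

    embedder = find_text_field_embedder(my_model)   # first TextFieldEmbedder found
    handle = embedder.register_forward_hook(
        lambda module, inputs, outputs: print(outputs.shape))
    handle.remove()                                 # detach the hook when done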
def get_lengths_from_binary_sequence_mask(mask: torch.BoolTensor) -> torch.LongTensor:
    """
    Compute sequence lengths for each batch element in a tensor using a
......
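The hunk is cut off here, but the signature already fixes the contract: a boolean padding mask goes in, one integer length per batch element comes out. An illustrative call with made-up values:

    mask = torch.tensor([[True, True, True, False],
                         [True, False, False, False]])
    lengths = get_lengths_from_binary_sequence_mask(mask)
    # expected: tensor([3, 1]) -- the number of True entries per row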
@@ -4,12 +4,14 @@ import sys
from typing import List, Union, Dict, Any
import numpy as np
import torch
from overrides import overrides
from combo import data, models, common
from combo.common import util
from combo.config import Registry
from combo.data import tokenizers, Instance, conllu2sentence, tokens2conllu, sentence2conllu
from combo.data.dataset_loaders.dataset_loader import TensorDict
from combo.data.dataset_readers.dataset_reader import DatasetReader
from combo.data.instance import JsonDict
@@ -56,6 +58,15 @@ class COMBO(PredictorModule):
            logger.error('Exiting.')
            sys.exit(1)

    def forward(self, inputs: TensorDict, training: bool = False) -> Dict[str, torch.Tensor]:
        return self.batch_outputs(inputs, training)

    def training_step(self, batch: TensorDict) -> Dict[str, torch.Tensor]:
        self._model.train()
        # Run in training mode so batch_outputs enforces the 'loss' key
        # and adds any regularization penalty before backpropagation.
        output = self.forward(batch, training=True)
        output['loss'].backward()
        return output
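`training_step` switches the model to train mode, runs the forward pass, and backpropagates the loss, but it never touches an optimizer. A minimal loop around it could look like this sketch, where `predictor` (a COMBO instance) and `loader` are assumed to be constructed elsewhere:

    optimizer = torch.optim.Adam(predictor._model.parameters(), lr=1e-3)  # assumed choice
    for batch in loader:
        optimizer.zero_grad()
        output = predictor.training_step(batch)   # forward pass + loss.backward()
        optimizer.step()
        print(float(output['loss']))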
    def predict(self, sentence: Union[str, List[str], List[List[str]], List[data.Sentence]]):
        if isinstance(sentence, str):
            return self.predict_json({"sentence": sentence})
......
@@ -3,7 +3,7 @@ Adapted from AllenNLP
https://github.com/allenai/allennlp/blob/main/allennlp/predictors/predictor.py
"""
-from typing import List, Iterator, Dict, Tuple, Any
+from typing import List, Iterator, Dict, Tuple, Any, Union
import logging
import json
import re
@@ -11,6 +11,7 @@ from contextlib import contextmanager
import numpy
import torch
from overrides import overrides
from torch.utils.hooks import RemovableHandle
from torch import Tensor
from torch import backends
@@ -20,6 +21,7 @@ import pytorch_lightning as pl
from combo.common.util import sanitize
from combo.config import FromParameters
from combo.data.batch import Batch
from combo.data.dataset_loaders.dataset_loader import TensorDict
from combo.data.dataset_readers.dataset_reader import DatasetReader
from combo.data.instance import JsonDict, Instance
from combo.modules.model import Model
@@ -43,6 +45,28 @@ class PredictorModule(pl.LightningModule, FromParameters):
        self.cuda_device = next(self._model.named_parameters())[1].get_device()
        self._token_offsets: List[Tensor] = []

    def forward(self, inputs: Instance) -> Dict[str, torch.Tensor]:
        r"""
        Same as :meth:`torch.nn.Module.forward`.

        Args:
            inputs: the instance to run through the underlying model.

        Returns:
            The model's output dictionary.
        """
        return self._model.forward(inputs)

    @overrides
    def training_step(self, *args: Any, **kwargs: Any) -> Dict[str, torch.Tensor]:
        raise NotImplementedError()

    @overrides
    def configure_optimizers(self) -> Any:
        raise NotImplementedError()
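Because `PredictorModule` is a `pl.LightningModule`, a subclass that wants Lightning-driven training has to override both hooks. One possible shape, with the optimizer choice being an assumption rather than anything this commit prescribes:

    # Illustrative subclass, not part of this commit.
    class TrainablePredictor(PredictorModule):
        def training_step(self, batch, batch_idx):
            # Lightning calls backward itself, so just return the loss.
            return self.batch_outputs(batch, for_training=True)['loss']

        def configure_optimizers(self):
            return torch.optim.Adam(self._model.parameters(), lr=2e-5)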
    def load_line(self, line: str) -> JsonDict:
        """
        If your inputs are not in JSON-lines format (e.g. you have a CSV)
@@ -78,6 +102,31 @@ class PredictorModule(pl.LightningModule, FromParameters):
        new_instances = self.predictions_to_labeled_instances(instance, outputs)
        return new_instances
    def batch_outputs(self, batch: TensorDict, for_training: bool) -> Dict[str, torch.Tensor]:
        """
        Does a forward pass on the given batch and returns the output dictionary that the model
        returns, after adding any specified regularization penalty to the loss (if training).
        """
        output_dict = self._model(**batch)

        if for_training:
            if "loss" not in output_dict:
                raise RuntimeError(
                    "The model you are trying to optimize does not contain a"
                    " 'loss' key in the output of model.forward(inputs)."
                )
            regularization_penalty = self._model.get_regularization_penalty()
            if regularization_penalty is not None:
                output_dict["reg_loss"] = regularization_penalty
                output_dict["loss"] += regularization_penalty

        return output_dict
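The `for_training` flag therefore gates two things at once: an evaluation pass may legitimately omit `loss`, while a training pass must produce it and gets the regularization penalty folded in. In sketch form (`batch` assumed to exist):

    eval_out = predictor.batch_outputs(batch, for_training=False)   # 'loss' optional here
    train_out = predictor.batch_outputs(batch, for_training=True)   # RuntimeError if 'loss' missing
    # with a regularizer configured, train_out['loss'] already includes train_out['reg_loss']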
    def get_gradients(self, instances: List[Instance]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """
        Gets the gradients of the loss with respect to the model inputs.
......