Commit aff8d3c7 authored by Martyna Wiącek

transformer on top of multiword_fix

parent 11b2e8a3
2 merge requests: !49 Multiword fix transformer encoder, !47 Fixed multiword prediction + bug that made the code write empty predictions
This commit is part of merge request !47.
"""Main COMBO model.""" """Main COMBO model."""
from typing import Optional, Dict, Any, List from typing import Optional, Dict, Any, List, Union
import numpy import numpy
import torch import torch
...@@ -24,6 +24,7 @@ from combo.nn.utils import get_text_field_mask ...@@ -24,6 +24,7 @@ from combo.nn.utils import get_text_field_mask
from combo.predictors import Predictor from combo.predictors import Predictor
from combo.utils import metrics from combo.utils import metrics
from combo.utils import ConfigurationError from combo.utils import ConfigurationError
from combo.modules.seq2seq_encoders.transformer_encoder import TransformerEncoder
@Registry.register("semantic_multitask")
@@ -39,7 +40,7 @@ class ComboModel(Model, FromParameters):
vocabulary: data.Vocabulary,
loss_weights: Dict[str, float],
text_field_embedder: TextFieldEmbedder,
seq_encoder: Union[Seq2SeqEncoder, TransformerEncoder],
use_sample_weight: bool = True,
lemmatizer: LemmatizerModel = None,
upos_tagger: MorphologicalFeatures = None,
...
@@ -58,6 +58,8 @@ def serialize_single_value(value: Any, pass_down_parameter_names: List[str] = No
return {k: serialize_single_value(v, pass_down_parameter_names) for k, v in value.items()}
elif isinstance(value, int) or isinstance(value, float) or isinstance(value, str):
return value
elif value is None:
return None
else:
return str(value)
...
@@ -12,7 +12,7 @@ from combo.data.token_indexers import TokenConstPaddingCharactersIndexer, \
from combo.data.tokenizers import CharacterTokenizer
from combo.data.vocabulary import Vocabulary
from combo.combo_model import ComboModel
from combo.models.encoder import ComboEncoder, ComboStackedBidirectionalLSTM, ComboTransformerEncoder
from combo.models.dilated_cnn import DilatedCnnEncoder
from combo.modules.lemma import LemmatizerModel
from combo.modules.morpho import MorphologicalFeatures
@@ -122,7 +122,130 @@ def default_vocabulary(data_loader: DataLoader) -> Vocabulary:
)
def default_model(pretrained_transformer_name: str, vocabulary: Vocabulary, use_transformer_encoder: bool = False) -> ComboModel:
if use_transformer_encoder:
return ComboModel(
vocabulary=vocabulary,
dependency_relation=DependencyRelationModel(
vocabulary=vocabulary,
dependency_projection_layer=Linear(
activation=TanhActivation(),
dropout_rate=0.25,
in_features=1024,
out_features=128
),
head_predictor=HeadPredictionModel(
cycle_loss_n=0,
dependency_projection_layer=Linear(
activation=TanhActivation(),
in_features=1024,
out_features=512
),
head_projection_layer=Linear(
activation=TanhActivation(),
in_features=1024,
out_features=512
)
),
head_projection_layer=Linear(
activation=TanhActivation(),
dropout_rate=0.25,
in_features=1024,
out_features=128
),
vocab_namespace="deprel_labels"
),
lemmatizer=LemmatizerModel(
vocabulary=vocabulary,
activations=[GELUActivation(), GELUActivation(), GELUActivation(), LinearActivation()],
char_vocab_namespace="token_characters",
dilation=[1, 2, 4, 1],
embedding_dim=300,
filters=[256, 256, 256],
input_projection_layer=Linear(
activation=TanhActivation(),
dropout_rate=0.25,
in_features=1024,
out_features=32
),
kernel_size=[3, 3, 3, 1],
lemma_vocab_namespace="lemma_characters",
padding=[1, 2, 4, 0],
stride=[1, 1, 1, 1]
),
loss_weights={
"deprel": 0.8,
"feats": 0.2,
"head": 0.2,
"lemma": 0.05,
"semrel": 0.05,
"upostag": 0.05,
"xpostag": 0.05
},
morphological_feat=MorphologicalFeatures(
vocabulary=vocabulary,
activations=[TanhActivation(), LinearActivation()],
dropout=[0.25, 0.],
hidden_dims=[128],
input_dim=1024,
num_layers=2,
vocab_namespace="feats_labels"
),
regularizer=RegularizerApplicator([
(".*conv1d.*", L2Regularizer(1e-6)),
(".*forward.*", L2Regularizer(1e-6)),
(".*backward.*", L2Regularizer(1e-6)),
(".*char_embed.*", L2Regularizer(1e-5))
]),
seq_encoder=ComboTransformerEncoder(
layer_dropout_probability=0.33,
input_dim=164,
num_layers=2,
feedforward_hidden_dim=2048,
num_attention_heads=4,
positional_encoding=None,
positional_embedding_size=512,
dropout_prob=0.1,
activation="relu"
),
text_field_embedder=BasicTextFieldEmbedder(
token_embedders={
"char": CharacterBasedWordEmbedder(
vocabulary=vocabulary,
dilated_cnn_encoder=DilatedCnnEncoder(
activations=[GELUActivation(), GELUActivation(), LinearActivation()],
dilation=[1, 2, 4],
filters=[512, 256, 64],
input_dim=64,
kernel_size=[3, 3, 3],
padding=[1, 2, 4],
stride=[1, 1, 1],
),
embedding_dim=64
),
"token": TransformersWordEmbedder(pretrained_transformer_name, projection_dim=100)
}
),
upos_tagger=FeedForwardPredictor.from_vocab(
vocabulary=vocabulary,
activations=[TanhActivation(), LinearActivation()],
dropout=[0.25, 0.],
hidden_dims=[64],
input_dim=1024,
num_layers=2,
vocab_namespace="upostag_labels"
),
xpos_tagger=FeedForwardPredictor.from_vocab(
vocabulary=vocabulary,
activations=[TanhActivation(), LinearActivation()],
dropout=[0.25, 0.],
hidden_dims=[64],
input_dim=1024,
num_layers=2,
vocab_namespace="xpostag_labels"
)
)
else:
return ComboModel(
vocabulary=vocabulary,
dependency_relation=DependencyRelationModel(
...
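For reference, the transformer branch's input_dim of 164 follows from the text field embedder configured above (assuming BasicTextFieldEmbedder concatenates its token embedders per word, as in AllenNLP); a minimal sketch of that arithmetic, not repository code:

# The two token embedders configured above are concatenated per word.
char_dim = 64      # CharacterBasedWordEmbedder(embedding_dim=64)
token_dim = 100    # TransformersWordEmbedder(..., projection_dim=100)
encoder_input_dim = char_dim + token_dim
assert encoder_input_dim == 164  # matches ComboTransformerEncoder(input_dim=164)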
@@ -84,6 +84,7 @@ flags.DEFINE_boolean(name="turns", default=False,
help="Segment into sentences on sentence break or on turn break.")
flags.DEFINE_boolean(name="split_subwords", default=False,
help="Split subwords (e.g. don\'t = do, n\'t) into separate tokens.")
flags.DEFINE_boolean(name="transformer_encoder", default=False, help="Use transformer encoder.")
# Finetune after training flags
flags.DEFINE_string(name="finetuning_training_data_path", default="",
@@ -312,6 +313,10 @@ def run(_):
FLAGS.validation_data_path,
prefix
)
if FLAGS.transformer_encoder:
model = default_model(FLAGS.pretrained_transformer_name, vocabulary, True)
else:
model = default_model(FLAGS.pretrained_transformer_name, vocabulary)
if FLAGS.use_pure_config and model is None:
...
from .encoder import ComboStackedBidirectionalLSTM, ComboEncoder, ComboTransformerEncoder
@@ -15,6 +15,7 @@ from combo.modules.augmented_lstm import AugmentedLstm
from combo.modules.input_variational_dropout import InputVariationalDropout
from combo.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder
from combo.utils import ConfigurationError
from combo.modules.seq2seq_encoders.transformer_encoder import TransformerEncoder
TensorPair = Tuple[torch.Tensor, torch.Tensor]
@@ -247,3 +248,46 @@ class ComboEncoder(Seq2SeqEncoder, FromParameters):
x = self.layer_dropout(inputs)
x = super().forward(x, mask)
return self.layer_dropout(x)
@Registry.register('combo_transformer_encoder')
class ComboTransformerEncoder(TransformerEncoder, FromParameters):
"""COMBO encoder (https://www.aclweb.org/anthology/K18-2004.pdf).
This implementation uses Variational Dropout on the input and then outputs of each BiLSTM layer
(instead of used Gaussian Dropout and Gaussian Noise).
"""
@register_arguments
def __init__(self,
layer_dropout_probability: float,
input_dim: int,
num_layers: int,
feedforward_hidden_dim: int = 2048,
num_attention_heads: int = 4,
positional_encoding: Optional[str] = None,
positional_embedding_size: int = 512,
dropout_prob: float = 0.1,
activation: str = "relu"
# stacked_transformer: ComboStackedBidirectionalLSTM,
):
super().__init__(
input_dim,
num_layers,
feedforward_hidden_dim,
num_attention_heads,
positional_encoding,
positional_embedding_size,
dropout_prob,
activation
)
self.layer_dropout = InputVariationalDropout(p=layer_dropout_probability)
def forward(self,
inputs: torch.Tensor,
mask: torch.BoolTensor,
hidden_state: torch.Tensor = None) -> torch.Tensor:
x = self.layer_dropout(inputs)
x = super().forward(x, mask)
return self.layer_dropout(x)
\ No newline at end of file
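A minimal usage sketch for the new encoder, assuming the combo.models.encoder import path used earlier in this commit; it only checks shapes and is not repository code:

import torch
from combo.models.encoder import ComboTransformerEncoder  # assumed re-export path, as imported above

encoder = ComboTransformerEncoder(
    layer_dropout_probability=0.33,
    input_dim=164,            # char (64) + projected transformer (100) embeddings
    num_layers=2,
    num_attention_heads=4,    # 164 is divisible by 4
)

inputs = torch.randn(2, 7, 164)              # (batch, sequence, features)
mask = torch.ones(2, 7, dtype=torch.bool)    # True marks real tokens
mask[1, 5:] = False                          # last two positions of the second sentence are padding

output = encoder(inputs, mask)               # variational dropout -> self-attention -> variational dropout
assert output.shape == (2, 7, encoder.get_output_dim())   # output width equals input_dim (164)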
from typing import Optional
from overrides import overrides
import torch
from torch import nn
from combo.modules.encoder import _EncoderBase
from combo.config.from_parameters import FromParameters, register_arguments
# from modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder
from combo.nn.utils import add_positional_features
# from allennlp.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder
# from allennlp.nn.util import add_positional_features
class TransformerEncoder(_EncoderBase, FromParameters):
"""
Implements a stacked self-attention encoder similar to the Transformer
architecture in [Attention is all you Need]
(https://www.semanticscholar.org/paper/Attention-Is-All-You-Need-Vaswani-Shazeer/0737da0767d77606169cbf4187b83e1ab62f6077).
This class adapts the Transformer from torch.nn for use in COMBO. Optionally, it adds positional encodings.
# Parameters
input_dim : `int`, required.
The input dimension of the encoder.
feedforward_hidden_dim : `int`, required.
The middle dimension of the FeedForward network. The input and output
dimensions are fixed to ensure sizes match up for the self attention layers.
num_layers : `int`, required.
The number of stacked self attention -> feedforward -> layer normalisation blocks.
num_attention_heads : `int`, required.
The number of attention heads to use per layer.
positional_encoding : `Optional[str]`, optional, (default = `None`)
Specifies how positional information is added to the input: `None` (no positional features),
`"sinusoidal"` (fixed sinusoidal features) or `"embedding"` (learned positional embeddings).
Adding positional information is strongly recommended, since without it the self attention
layers have no notion of absolute or relative position (they only compute pairwise similarity
between vectors of elements), which can be an important feature for many tasks.
positional_embedding_size : `int`, optional, (default = `512`)
The maximum number of positions supported when `positional_encoding == "embedding"`.
dropout_prob : `float`, optional, (default = `0.1`)
The dropout probability for the feedforward network.
""" # noqa
def __init__(
self,
input_dim: int,
num_layers: int,
feedforward_hidden_dim: int = 2048,
num_attention_heads: int = 4,
positional_encoding: Optional[str] = None,
positional_embedding_size: int = 512,
dropout_prob: float = 0.1,
activation: str = "relu",
) -> None:
super().__init__()
layer = nn.TransformerEncoderLayer(
d_model=input_dim,
nhead=num_attention_heads,
dim_feedforward=feedforward_hidden_dim,
dropout=dropout_prob,
activation=activation,
)
self._transformer = nn.TransformerEncoder(layer, num_layers)
self._input_dim = input_dim
# initialize parameters
# We do this before the embeddings are initialized so we get the default initialization for the embeddings.
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
if positional_encoding is None:
self._sinusoidal_positional_encoding = False
self._positional_embedding = None
elif positional_encoding == "sinusoidal":
self._sinusoidal_positional_encoding = True
self._positional_embedding = None
elif positional_encoding == "embedding":
self._sinusoidal_positional_encoding = False
self._positional_embedding = nn.Embedding(positional_embedding_size, input_dim)
else:
raise ValueError(
"positional_encoding must be one of None, 'sinusoidal', or 'embedding'"
)
def get_input_dim(self) -> int:
return self._input_dim
def get_output_dim(self) -> int:
return self._input_dim
def is_bidirectional(self):
return False
def forward(self, inputs: torch.Tensor, mask: torch.BoolTensor):
output = inputs
if self._sinusoidal_positional_encoding:
output = add_positional_features(output)
if self._positional_embedding is not None:
position_ids = torch.arange(inputs.size(1), dtype=torch.long, device=output.device)
position_ids = position_ids.unsqueeze(0).expand(inputs.shape[:-1])
output = output + self._positional_embedding(position_ids)
# For some reason the torch transformer expects the shape (sequence, batch, features), not the more
# familiar (batch, sequence, features), so we have to fix it.
output = output.permute(1, 0, 2)
# For some other reason, the torch transformer takes the mask backwards.
mask = ~mask
output = self._transformer(output, src_key_padding_mask=mask)
output = output.permute(1, 0, 2)
return output
\ No newline at end of file
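The wrapped torch module can also be used on its own. A short sketch, assuming the module path above; it shows the optional sinusoidal positions and the mask convention (True = real token here, flipped internally to True = padding for src_key_padding_mask):

import torch
from combo.modules.seq2seq_encoders.transformer_encoder import TransformerEncoder

encoder = TransformerEncoder(
    input_dim=128,
    num_layers=2,
    num_attention_heads=8,
    positional_encoding="sinusoidal",     # None, "sinusoidal" or "embedding"
)

inputs = torch.randn(4, 20, 128)          # (batch, sequence, features)
mask = torch.ones(4, 20, dtype=torch.bool)
mask[:, 15:] = False                      # positions 15..19 are padding

output = encoder(inputs, mask)            # internally permuted to (sequence, batch, features) and back
assert output.shape == (4, 20, encoder.get_output_dim())   # 128, same as input_dim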
@@ -2,6 +2,7 @@
Adapted from AllenNLP
https://github.com/allenai/allennlp/blob/80fb6061e568cb9d6ab5d45b661e86eb61b92c82/allennlp/nn/util.py
"""
import math
from typing import Union, Dict, Optional, List, Any, NamedTuple
import torch
@@ -479,3 +480,61 @@ def batched_span_select(target: torch.Tensor, spans: torch.LongTensor) -> torch.
span_embeddings = batched_index_select(target, span_indices)
return span_embeddings, span_mask
def add_positional_features(
tensor: torch.Tensor, min_timescale: float = 1.0, max_timescale: float = 1.0e4
):
"""
Implements the frequency-based positional encoding described
in [Attention is All you Need][0].
Adds sinusoids of different frequencies to a `Tensor`. A sinusoid of a
different frequency and phase is added to each dimension of the input `Tensor`.
This allows the attention heads to use absolute and relative positions.
The number of timescales is equal to hidden_dim / 2 within the range
(min_timescale, max_timescale). For each timescale, the two sinusoidal
signals sin(timestep / timescale) and cos(timestep / timescale) are
generated and concatenated along the hidden_dim dimension.
[0]: https://www.semanticscholar.org/paper/Attention-Is-All-You-Need-Vaswani-Shazeer/0737da0767d77606169cbf4187b83e1ab62f6077
# Parameters
tensor : `torch.Tensor`
a Tensor with shape (batch_size, timesteps, hidden_dim).
min_timescale : `float`, optional (default = `1.0`)
The smallest timescale to use.
max_timescale : `float`, optional (default = `1.0e4`)
The largest timescale to use.
# Returns
`torch.Tensor`
The input tensor augmented with the sinusoidal frequencies.
""" # noqa
_, timesteps, hidden_dim = tensor.size()
timestep_range = get_range_vector(timesteps, get_device_of(tensor)).data.float()
# We're generating both cos and sin frequencies,
# so half for each.
num_timescales = hidden_dim // 2
timescale_range = get_range_vector(num_timescales, get_device_of(tensor)).data.float()
log_timescale_increments = math.log(float(max_timescale) / float(min_timescale)) / float(
num_timescales - 1
)
inverse_timescales = min_timescale * torch.exp(timescale_range * -log_timescale_increments)
# Broadcasted multiplication - shape (timesteps, num_timescales)
scaled_time = timestep_range.unsqueeze(1) * inverse_timescales.unsqueeze(0)
# shape (timesteps, 2 * num_timescales)
sinusoids = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 1)
if hidden_dim % 2 != 0:
# if the number of dimensions is odd, the cos and sin
# timescales had size (hidden_dim - 1) / 2, so we need
# to add a row of zeros to make up the difference.
sinusoids = torch.cat([sinusoids, sinusoids.new_zeros(timesteps, 1)], 1)
return tensor + sinusoids.unsqueeze(0)
\ No newline at end of file
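A quick sanity sketch of the sinusoidal features, assuming combo.nn.utils exposes add_positional_features as added above: with a zero input, the returned tensor is the positional signal itself, sines in the first hidden_dim/2 columns and cosines in the rest, so position 0 yields zeros followed by ones.

import torch
from combo.nn.utils import add_positional_features

x = torch.zeros(1, 5, 8)            # (batch, timesteps, hidden_dim)
pe = add_positional_features(x)     # zero input -> output is the pure positional encoding
print(pe[0, 0])                     # tensor([0., 0., 0., 0., 1., 1., 1., 1.])
print(pe.shape)                     # torch.Size([1, 5, 8])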