Commit b6440d55 authored by Martyna Wiącek

Added transformer encoder instead of BiLSTM

parent 649ff6cc
1 merge request: !48 Transformer encoder
@@ -123,7 +123,7 @@
}
},
"seq_encoder": {
"type": "combo_encoder",
"type": "combo_transformer_encoder",
"parameters": {
"layer_dropout_probability": 0.33,
"stacked_bilstm": {
......
import functools
import inspect
import typing
from types import NoneType  # NoneType is exposed by the types module from Python 3.10 onward
from typing import Any, Callable, Dict, List, Optional
from combo.common.params import Params
@@ -58,6 +59,8 @@ def serialize_single_value(value: Any, pass_down_parameter_names: List[str] = None
return {k: serialize_single_value(v, pass_down_parameter_names) for k, v in value.items()}
elif isinstance(value, int) or isinstance(value, float) or isinstance(value, str):
return value
elif isinstance(value, NoneType):
return None
else:
return str(value)
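A quick behaviour check of the extended serializer (a sketch; serialize_single_value is assumed to be imported from its defining module, which this diff does not name):

serialize_single_value(None)         # -> None, via the new NoneType branch
serialize_single_value(0.33)         # -> 0.33, primitive passthrough
serialize_single_value({"p": None})  # -> {"p": None}, dicts recurse per value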
@@ -145,6 +148,8 @@ class FromParameters:
parameters_dict = {}
for pn, param_value in parameters_to_serialize.items():
if pn in pass_down_parameter_names:
continue
parameters_dict[pn] = serialize_single_value(param_value,
......
@@ -12,7 +12,7 @@ from combo.data.token_indexers import TokenConstPaddingCharactersIndexer, \
from combo.data.tokenizers import CharacterTokenizer
from combo.data.vocabulary import Vocabulary
from combo.combo_model import ComboModel
-from combo.models.encoder import ComboEncoder, ComboStackedBidirectionalLSTM
+from combo.models.encoder import ComboEncoder, ComboStackedBidirectionalLSTM, ComboTransformerEncoder
from combo.models.dilated_cnn import DilatedCnnEncoder
from combo.modules.lemma import LemmatizerModel
from combo.modules.morpho import MorphologicalFeatures
@@ -94,7 +94,130 @@ def default_vocabulary(data_loader: DataLoader) -> Vocabulary:
)
-def default_model(pretrained_transformer_name: str, vocabulary: Vocabulary) -> ComboModel:
+def default_model(pretrained_transformer_name: str, vocabulary: Vocabulary, use_transformer_encoder: bool = False) -> ComboModel:
if use_transformer_encoder:
return ComboModel(
vocabulary=vocabulary,
dependency_relation=DependencyRelationModel(
vocabulary=vocabulary,
dependency_projection_layer=Linear(
activation=TanhActivation(),
dropout_rate=0.25,
in_features=164,
out_features=128
),
head_predictor=HeadPredictionModel(
cycle_loss_n=0,
dependency_projection_layer=Linear(
activation=TanhActivation(),
in_features=164,
out_features=512
),
head_projection_layer=Linear(
activation=TanhActivation(),
in_features=164,
out_features=512
)
),
head_projection_layer=Linear(
activation=TanhActivation(),
dropout_rate=0.25,
in_features=164,
out_features=128
),
vocab_namespace="deprel_labels"
),
lemmatizer=LemmatizerModel(
vocabulary=vocabulary,
activations=[ReLUActivation(), ReLUActivation(), ReLUActivation(), LinearActivation()],
char_vocab_namespace="token_characters",
dilation=[1, 2, 4, 1],
embedding_dim=256,
filters=[256, 256, 256],
input_projection_layer=Linear(
activation=TanhActivation(),
dropout_rate=0.25,
in_features=164,
out_features=32
),
kernel_size=[3, 3, 3, 1],
lemma_vocab_namespace="lemma_characters",
padding=[1, 2, 4, 0],
stride=[1, 1, 1, 1]
),
loss_weights={
"deprel": 0.8,
"feats": 0.2,
"head": 0.2,
"lemma": 0.05,
"semrel": 0.05,
"upostag": 0.05,
"xpostag": 0.05
},
morphological_feat=MorphologicalFeatures(
vocabulary=vocabulary,
activations=[TanhActivation(), LinearActivation()],
dropout=[0.25, 0.],
hidden_dims=[128],
input_dim=164,
num_layers=2,
vocab_namespace="feats_labels"
),
regularizer=RegularizerApplicator([
(".*conv1d.*", L2Regularizer(1e-6)),
(".*forward.*", L2Regularizer(1e-6)),
(".*backward.*", L2Regularizer(1e-6)),
(".*char_embed.*", L2Regularizer(1e-5))
]),
seq_encoder=ComboTransformerEncoder(
layer_dropout_probability=0.33,
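# (assumption) input_dim=164 presumably equals the concatenated text_field_embedder output below: 100 (projected transformer embedding) + 64 (char CNN embedding).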
input_dim=164,
num_layers=2,
feedforward_hidden_dim=2048,
num_attention_heads=4,
positional_encoding=None,
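# (assumption) with positional_encoding=None the underlying TransformerEncoder presumably adds no positional signal; "sinusoidal" would use add_positional_features from combo/nn/util, added later in this diff.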
positional_embedding_size=512,
dropout_prob=0.1,
activation="relu"
),
text_field_embedder=BasicTextFieldEmbedder(
token_embedders={
"char": CharacterBasedWordEmbedder(
vocabulary=vocabulary,
dilated_cnn_encoder=DilatedCnnEncoder(
activations=[ReLUActivation(), ReLUActivation(), LinearActivation()],
dilation=[1, 2, 4],
filters=[512, 256, 64],
input_dim=64,
kernel_size=[3, 3, 3],
padding=[1, 2, 4],
stride=[1, 1, 1],
),
embedding_dim=64
),
"token": TransformersWordEmbedder(pretrained_transformer_name, projection_dim=100)
}
),
upos_tagger=FeedForwardPredictor.from_vocab(
vocabulary=vocabulary,
activations=[TanhActivation(), LinearActivation()],
dropout=[0.25, 0.],
hidden_dims=[64],
input_dim=164,
num_layers=2,
vocab_namespace="upostag_labels"
),
xpos_tagger=FeedForwardPredictor.from_vocab(
vocabulary=vocabulary,
activations=[TanhActivation(), LinearActivation()],
dropout=[0.25, 0.],
hidden_dims=[64],
input_dim=164,
num_layers=2,
vocab_namespace="xpostag_labels"
)
)
else:
return ComboModel(
vocabulary=vocabulary,
dependency_relation=DependencyRelationModel(
@@ -215,3 +338,6 @@ def default_model(pretrained_transformer_name: str, vocabulary: Vocabulary) -> ComboModel:
vocab_namespace="xpostag_labels"
)
)
@@ -47,6 +47,7 @@ flags.DEFINE_integer(name="n_cuda_devices", default=-1,
help="Number of devices to train on (default -1 auto mode - train on as many as possible)")
flags.DEFINE_string(name="output_file", default="output.log",
help="Predictions result file.")
flags.DEFINE_boolean(name="transformer_encoder", default=False, help="Use transformer encoder.")
# Training flags
flags.DEFINE_string(name="training_data_path", default="", help="Training data path(s)")
@@ -312,6 +313,9 @@ def run(_):
FLAGS.validation_data_path,
prefix
)
model = default_model(FLAGS.pretrained_transformer_name, vocabulary, use_transformer_encoder=FLAGS.transformer_encoder)
if FLAGS.use_pure_config and model is None:
......
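Taken together with the new transformer_encoder flag above, switching encoders at training time presumably comes down to adding --transformer_encoder to the usual training invocation; all other flags keep their defaults.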
-from .encoder import ComboStackedBidirectionalLSTM, ComboEncoder
+from .encoder import ComboStackedBidirectionalLSTM, ComboEncoder, ComboTransformerEncoder
@@ -15,6 +15,7 @@ from combo.modules.augmented_lstm import AugmentedLstm
from combo.modules.input_variational_dropout import InputVariationalDropout
from combo.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder
from combo.utils import ConfigurationError
from combo.modules.seq2seq_encoders.transformer_encoder import TransformerEncoder
TensorPair = Tuple[torch.Tensor, torch.Tensor]
@@ -247,3 +248,46 @@ class ComboEncoder(Seq2SeqEncoder, FromParameters):
x = self.layer_dropout(inputs)
x = super().forward(x, mask)
return self.layer_dropout(x)
@Registry.register('combo_transformer_encoder')
class ComboTransformerEncoder(TransformerEncoder, FromParameters):
"""COMBO encoder (https://www.aclweb.org/anthology/K18-2004.pdf).
This implementation uses Variational Dropout on the input and then outputs of each BiLSTM layer
(instead of used Gaussian Dropout and Gaussian Noise).
"""
@register_arguments
def __init__(self,
layer_dropout_probability: float,
input_dim: int,
num_layers: int,
feedforward_hidden_dim: int = 2048,
num_attention_heads: int = 4,
positional_encoding: Optional[str] = None,
positional_embedding_size: int = 512,
dropout_prob: float = 0.1,
activation: str = "relu"
):
super().__init__(
input_dim,
num_layers,
feedforward_hidden_dim,
num_attention_heads,
positional_encoding,
positional_embedding_size,
dropout_prob,
activation
)
self.layer_dropout = InputVariationalDropout(p=layer_dropout_probability)
def forward(self,
inputs: torch.Tensor,
mask: torch.BoolTensor,
hidden_state: Optional[torch.Tensor] = None) -> torch.Tensor:
x = self.layer_dropout(inputs)
x = super().forward(x, mask)
return self.layer_dropout(x)
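For orientation, a minimal smoke test of the new encoder could look like this sketch (import path and constructor arguments follow this diff; the batch dimensions are made up):

import torch
from combo.models.encoder import ComboTransformerEncoder

encoder = ComboTransformerEncoder(
    layer_dropout_probability=0.33,
    input_dim=164,
    num_layers=2,
)
inputs = torch.randn(8, 20, 164)            # (batch, timesteps, input_dim)
mask = torch.ones(8, 20, dtype=torch.bool)  # no padded positions in this toy batch
out = encoder(inputs, mask)                 # expected shape: (8, 20, 164)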
@@ -2,6 +2,7 @@
Adapted from AllenNLP
https://github.com/allenai/allennlp/blob/80fb6061e568cb9d6ab5d45b661e86eb61b92c82/allennlp/nn/util.py
"""
import math
from typing import Union, Dict, Optional, List, Any, NamedTuple
import torch
@@ -479,3 +480,61 @@ def batched_span_select(target: torch.Tensor, spans: torch.LongTensor) -> torch.Tensor:
span_embeddings = batched_index_select(target, span_indices)
return span_embeddings, span_mask
def add_positional_features(
tensor: torch.Tensor, min_timescale: float = 1.0, max_timescale: float = 1.0e4
):
"""
Implements the frequency-based positional encoding described
in [Attention is All you Need][0].
Adds sinusoids of different frequencies to a `Tensor`. A sinusoid of a
different frequency and phase is added to each dimension of the input `Tensor`.
This allows the attention heads to use absolute and relative positions.
The number of timescales is equal to hidden_dim / 2 within the range
(min_timescale, max_timescale). For each timescale, the two sinusoidal
signals sin(timestep / timescale) and cos(timestep / timescale) are
generated and concatenated along the hidden_dim dimension.
[0]: https://www.semanticscholar.org/paper/Attention-Is-All-You-Need-Vaswani-Shazeer/0737da0767d77606169cbf4187b83e1ab62f6077
# Parameters
tensor : `torch.Tensor`
a Tensor with shape (batch_size, timesteps, hidden_dim).
min_timescale : `float`, optional (default = `1.0`)
The smallest timescale to use.
max_timescale : `float`, optional (default = `1.0e4`)
The largest timescale to use.
# Returns
`torch.Tensor`
The input tensor augmented with the sinusoidal frequencies.
""" # noqa
_, timesteps, hidden_dim = tensor.size()
timestep_range = get_range_vector(timesteps, get_device_of(tensor)).data.float()
# We're generating both cos and sin frequencies,
# so half for each.
num_timescales = hidden_dim // 2
timescale_range = get_range_vector(num_timescales, get_device_of(tensor)).data.float()
log_timescale_increments = math.log(float(max_timescale) / float(min_timescale)) / float(
num_timescales - 1
)
inverse_timescales = min_timescale * torch.exp(timescale_range * -log_timescale_increments)
# Broadcasted multiplication - shape (timesteps, num_timescales)
scaled_time = timestep_range.unsqueeze(1) * inverse_timescales.unsqueeze(0)
# shape (timesteps, 2 * num_timescales)
sinusoids = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 1)
if hidden_dim % 2 != 0:
# if the number of dimensions is odd, the cos and sin
# timescales had size (hidden_dim - 1) / 2, so we need
# to add a row of zeros to make up the difference.
sinusoids = torch.cat([sinusoids, sinusoids.new_zeros(timesteps, 1)], 1)
return tensor + sinusoids.unsqueeze(0)
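As a quick illustration of the encoding (a sketch; assumes add_positional_features is in scope as defined above):

import torch

x = torch.zeros(1, 4, 6)   # batch=1, timesteps=4, hidden_dim=6 -> 3 timescales
y = add_positional_features(x)
# With zero input, y is the raw encoding: y[0, t, :3] = sin(t / s), y[0, t, 3:] = cos(t / s).
print(y[0, 0])             # position 0 -> tensor([0., 0., 0., 1., 1., 1.])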