Commit bd6a44c3 authored by Łukasz Pszenny

Truncation in token indexer and script to train COMBO on PDB with transformer_encoder

parent 4a278559
@@ -77,6 +77,11 @@ local in_features(name) = !(std.length(std.find(name, features)) == 0);
local in_targets(name) = !(std.length(std.find(name, targets)) == 0);
local use_transformer = pretrained_transformer_name != null;
# Transformer encoder options
local use_transformer_encoder = if std.extVar("use_transformer_encoder") == "True" then true else false;
local num_layers_transformer_encoder = 6;
local num_attention_heads = 8;
# Verify some configuration requirements
assert in_features("token"): "Key 'token' must be in features!";
assert in_features("char"): "Key 'char' must be in features!";
@@ -252,7 +257,17 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
},
},
loss_weights: loss_weights,
seq_encoder: {
seq_encoder: if use_transformer_encoder then {
type: "pytorch_transformer",
input_dim: (char_dim + projected_embedding_dim +
(if in_features('xpostag') then xpostag_dim else 0) +
(if in_features('lemma') then lemma_char_dim else 0) +
(if in_features('upostag') then upostag_dim else 0) +
(if in_features('feats') then feats_dim else 0)),
num_layers: num_layers_transformer_encoder,
feedforward_hidden_dim: hidden_size,
num_attention_heads: num_attention_heads,
positional_encoding: "sinusoidal"} else {
type: "combo_encoder",
layer_dropout_probability: 0.33,
stacked_bilstm: {
@@ -266,7 +281,7 @@ assert pretrained_tokens == null || pretrained_transformer_name == null: "Can't
num_layers: num_layers,
recurrent_dropout_probability: 0.33,
layer_dropout_probability: 0.33
},
}
},
dependency_relation: {
type: "combo_dependency_parsing_from_vocab",
......
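When `use_transformer_encoder` is true, the config above replaces COMBO's BiLSTM `combo_encoder` with AllenNLP's registered `pytorch_transformer` seq2seq encoder. A minimal sketch (not part of this commit) of constructing the equivalent module directly in Python, assuming an allennlp version that ships `PytorchTransformer`; the dimensions below are illustrative, not read from the config:

```python
import torch
from allennlp.modules.seq2seq_encoders import PytorchTransformer

encoder = PytorchTransformer(
    input_dim=160,                   # char_dim + projected_embedding_dim + optional feature dims
    num_layers=6,                    # num_layers_transformer_encoder
    feedforward_hidden_dim=512,      # hidden_size
    num_attention_heads=8,           # num_attention_heads
    positional_encoding="sinusoidal",
)

inputs = torch.randn(2, 20, 160)             # (batch, tokens, input_dim)
mask = torch.ones(2, 20, dtype=torch.bool)   # True for real (non-padding) tokens
outputs = encoder(inputs, mask)              # shape: (2, 20, 160)
```

The encoder keeps its output dimension equal to `input_dim`, so the rest of the COMBO model can consume it exactly like the BiLSTM branch.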
from typing import Optional, Dict, Any, List, Tuple
from allennlp import data
from allennlp.data import token_indexers, tokenizers
from allennlp.data import token_indexers, tokenizers, IndexedTokenList, vocabulary
from overrides import overrides
from typing import List
@data.TokenIndexer.register("pretrained_transformer_mismatched_fixed")
class PretrainedTransformerMismatchedIndexer(token_indexers.PretrainedTransformerMismatchedIndexer):
"""TODO(mklimasz) Remove during next allennlp update, fixed on allennlp master."""
def __init__(self, model_name: str, namespace: str = "tags", max_length: int = None,
tokenizer_kwargs: Optional[Dict[str, Any]] = None, **kwargs) -> None:
@@ -24,6 +26,37 @@ class PretrainedTransformerMismatchedIndexer(token_indexers.PretrainedTransforme
        self._num_added_start_tokens = self._matched_indexer._num_added_start_tokens
        self._num_added_end_tokens = self._matched_indexer._num_added_end_tokens

    @overrides
    def tokens_to_indices(self,
                          tokens,
                          vocabulary: vocabulary.Vocabulary) -> IndexedTokenList:
        """
        Overridden in order to raise an error when the number of wordpiece tokens needed to embed
        a sentence exceeds the maximal input length of the model.
        """
        self._matched_indexer._add_encoding_to_vocabulary_if_needed(vocabulary)

        wordpieces, offsets = self._allennlp_tokenizer.intra_word_tokenize(
            [t.ensure_text() for t in tokens])

        if len(wordpieces) > self._tokenizer.max_len_single_sentence:
            raise ValueError("The following sentence consists of more wordpiece tokens than the model can process:\n" +
                             " ".join([str(x) for x in tokens[:10]]) + " ... \n" +
                             f"Maximal input: {self._tokenizer.max_len_single_sentence}\n" +
                             f"Current input: {len(wordpieces)}")

        offsets = [x if x is not None else (-1, -1) for x in offsets]

        output: IndexedTokenList = {
            "token_ids": [t.text_id for t in wordpieces],
            "mask": [True] * len(tokens),  # for original tokens (i.e. word-level)
            "type_ids": [t.type_id for t in wordpieces],
            "offsets": offsets,
            "wordpiece_mask": [True] * len(wordpieces),  # for wordpieces (i.e. subword-level)
        }

        return self._matched_indexer._postprocess_output(output)
class PretrainedTransformerIndexer(token_indexers.PretrainedTransformerIndexer):
......
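The override above makes the mismatched indexer fail fast instead of silently producing an input the transformer cannot handle. A minimal usage sketch, assuming the HerBERT model used elsewhere in this commit and the allennlp modules mirrored by the imports above (it downloads the tokenizer on first use):

```python
from allennlp.data import tokenizers, vocabulary

# Hypothetical usage of the fixed indexer registered above.
indexer = PretrainedTransformerMismatchedIndexer(model_name="allegro/herbert-base-cased")
vocab = vocabulary.Vocabulary()
tokens = [tokenizers.Token(t) for t in "Ala ma kota .".split()]

# Returns token_ids / offsets / masks for short sentences and raises ValueError
# once the wordpiece count exceeds the model's max_len_single_sentence.
indexed = indexer.tokens_to_indices(tokens, vocab)
```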
@@ -57,6 +57,8 @@ flags.DEFINE_string(name="serialization_dir", default=None,
help="Model serialization directory (default - system temp dir).")
flags.DEFINE_boolean(name="tensorboard", default=False,
help="When provided model will log tensorboard metrics.")
flags.DEFINE_boolean(name="use_transformer_encoder", default=False,
help="Indicator whether to use transformer encoder or BiLSTM (default)")
# Finetune after training flags
flags.DEFINE_list(name="finetuning_training_data_path", default="",
@@ -197,6 +199,7 @@ def _get_ext_vars(finetuning: bool = False) -> Dict:
"num_epochs": str(FLAGS.num_epochs),
"word_batch_size": str(FLAGS.word_batch_size),
"use_tensorboard": str(FLAGS.tensorboard),
"use_transformer_encoder": str(FLAGS.use_transformer_encoder)
}
......
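`_get_ext_vars` serializes every flag with `str()`, so the new flag reaches the Jsonnet template as the external variable `use_transformer_encoder` with the value `"True"` or `"False"`, which is exactly what the `std.extVar` comparison in the config checks. A small sketch of the same round trip, assuming the `jsonnet` Python bindings (module `_jsonnet`) are available; AllenNLP's `Params.from_file` performs an equivalent evaluation internally:

```python
import json
import _jsonnet  # assumption: `pip install jsonnet`

snippet = """
local use_transformer_encoder = if std.extVar("use_transformer_encoder") == "True" then true else false;
{ use_transformer_encoder: use_transformer_encoder }
"""

config = json.loads(_jsonnet.evaluate_snippet(
    "example.jsonnet", snippet, ext_vars={"use_transformer_encoder": str(True)}))
print(config)  # {'use_transformer_encoder': True}
```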
import pathlib
from absl import app
from absl import flags
from scripts import utils
LANG = ["Polish"]
TREEBANKS = {"Polish": "UD_Polish-PDB"}

FLAGS = flags.FLAGS
flags.DEFINE_string(name="data_dir", default="/home/pszenny/combo",
                    help="Path to data directory.")
flags.DEFINE_string(name="models_dir", default="/home/pszenny/combo/tmp/",
                    help="Model serialization dir.")
flags.DEFINE_integer(name="cuda_device", default=-1,
                     help="Cuda device id (-1 for cpu).")
flags.DEFINE_boolean(name="expect_prefix", default=True,
                     help="Whether to expect allennlp prefix.")
flags.DEFINE_integer(name="batch_size", default=32,
                     help="Batch size.")

def run(_):
    for encoder in ["BiLSTM", "transformer"]:
        models_dir = pathlib.Path(FLAGS.models_dir)
        for model_dir in models_dir.iterdir():
            lang = model_dir.name
            if lang not in LANG:
                print("Skipping unknown directory: ", lang)
                continue

            if FLAGS.expect_prefix:
                model_dir = models_dir / lang / encoder
                model_dir = list(model_dir.iterdir())
                assert len(model_dir) == 1, f"There is an incorrect number of models: {model_dir}"
                model_dir = model_dir[0]

            data_dir = pathlib.Path(FLAGS.data_dir)
            data_dir = data_dir / TREEBANKS[lang]
            files = list(data_dir.iterdir())

            test_file = [f for f in files if "test" in f.name and ".conllu" in f.name]
            assert len(test_file) == 1, "Couldn't find test file."
            test_file = test_file[0]

            # Include the encoder name so the transformer run does not overwrite
            # the BiLSTM predictions.
            output_pred = data_dir / f'{lang}_{encoder}_pred.conllu'
            command = f"""combo --mode predict --model_path {model_dir}
            --input_file {test_file}
            --output_file {output_pred}
            --cuda_device {FLAGS.cuda_device}
            --batch_size {FLAGS.batch_size}
            --silent
            """
            utils.execute_command(command)
    return 1

def main():
    app.run(run)


if __name__ == "__main__":
    main()
\ No newline at end of file
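Both this script and the training script below rely on `scripts.utils.execute_command`, which is not part of this commit. A minimal sketch of what such a helper could look like, assuming it simply splits the multi-line command string and shells out (the real implementation may differ):

```python
import subprocess


def execute_command(command: str) -> None:
    """Hypothetical stand-in for scripts.utils.execute_command: run a command string."""
    # Splitting on whitespace collapses the newlines and indentation used in the
    # f-string commands above into a flat argument list.
    subprocess.run(command.split(), check=True)
```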
"""Script to train Dependency Parsing models based on UD 2.x data."""
import pathlib
from absl import app
from absl import flags
from scripts import utils
# ls -1 | xargs -i echo "\"{}\","
# UD 2.7
TREEBANKS = ["UD_Polish-PDB"]
embedding_model = "allegro/herbert-base-cased"
FLAGS = flags.FLAGS
flags.DEFINE_list(name="treebanks", default=TREEBANKS,
                  help=f"Treebanks to train. Possible values: {TREEBANKS}.")
flags.DEFINE_string(name="data_dir", default="/home/pszenny/combo/",
                    help="Path to UD data directory.")
flags.DEFINE_string(name="serialization_dir", default="/home/pszenny/combo/tmp/",
                    help="Model serialization directory.")
flags.DEFINE_integer(name="cuda_device", default=-1,
                     help="Cuda device id (-1 for cpu).")
flags.DEFINE_string(name="train_config_path", default="/home/pszenny/combo/combo/config.template.jsonnet",
                    help="Path to the jsonnet config file.")
flags.DEFINE_boolean(name="use_transformer_encoder", default=False,
                     help="Whether to use a transformer encoder instead of the default BiLSTM encoder.")


def run(_):
    treebanks_dir = pathlib.Path(FLAGS.data_dir)
    for treebank in FLAGS.treebanks:
        assert treebank in TREEBANKS, f"Unknown treebank {treebank}."
        treebank_dir = treebanks_dir / treebank
        treebank_parts = treebank[3:].split("-")
        language = treebank_parts[0]

        files = list(treebank_dir.iterdir())

        training_file = [f for f in files if "train" in f.name and ".conllu" in f.name]
        assert len(training_file) == 1, "Couldn't find training file."
        training_file_path = training_file[0]

        valid_file = [f for f in files if "dev" in f.name and ".conllu" in f.name]
        assert len(valid_file) == 1, "Couldn't find validation file."
        valid_file_path = valid_file[0]

        # First training run: BiLSTM encoder (use_transformer_encoder defaults to False).
        serialization_dir = pathlib.Path(FLAGS.serialization_dir) / language / "BiLSTM"
        serialization_dir.mkdir(exist_ok=True, parents=True)
        word_batch_size = 2500
        # Boolean absl flags must be passed as --flag=Value; a space-separated value is not parsed.
        command = f"""time combo --mode train
        --cuda_device {FLAGS.cuda_device}
        --training_data_path {training_file_path}
        --validation_data_path {valid_file_path}
        --pretrained_transformer_name {embedding_model}
        --serialization_dir {serialization_dir}
        --use_transformer_encoder={FLAGS.use_transformer_encoder}
        --config_path {FLAGS.train_config_path}
        --notensorboard
        --word_batch_size {word_batch_size}
        --targets deprel,head,upostag,lemma,feats,xpostag
        """
        utils.execute_command(command)

        # Second training run: transformer encoder.
        FLAGS.use_transformer_encoder = True
        serialization_dir = pathlib.Path(FLAGS.serialization_dir) / language / "transformer"
        serialization_dir.mkdir(exist_ok=True, parents=True)
        command = f"""time combo --mode train
        --cuda_device {FLAGS.cuda_device}
        --training_data_path {training_file_path}
        --validation_data_path {valid_file_path}
        --pretrained_transformer_name {embedding_model}
        --serialization_dir {serialization_dir}
        --use_transformer_encoder={FLAGS.use_transformer_encoder}
        --config_path {FLAGS.train_config_path}
        --notensorboard
        --word_batch_size {word_batch_size}
        --targets deprel,head,upostag,lemma,feats,xpostag
        """
        utils.execute_command(command)

def main():
    app.run(run)


if __name__ == "__main__":
    main()
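For reference, the two runs above are expected to leave one model per encoder under `<serialization_dir>/<language>/<encoder>/`, which is the layout the prediction script earlier in this commit walks when `--expect_prefix` is set. A small illustration (the paths are the flag defaults, used here only as an example):

```python
import pathlib

serialization_dir = pathlib.Path("/home/pszenny/combo/tmp/")  # default --serialization_dir
language = "Polish"

expected_model_dirs = [serialization_dir / language / encoder
                       for encoder in ("BiLSTM", "transformer")]
print(expected_model_dirs)
# [PosixPath('/home/pszenny/combo/tmp/Polish/BiLSTM'),
#  PosixPath('/home/pszenny/combo/tmp/Polish/transformer')]
```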