Skip to content
Snippets Groups Projects
Commit d8329fba authored by Maja Jablonska's avatar Maja Jablonska
Browse files

Merge branch 'main' of https://gitlab.clarin-pl.eu/syntactic-tools/combo

# Conflicts:
#	combo/modules/seq2seq_encoders/transformer_encoder.py
parents 447006ce 99fe2a13
Branches
Tags
No related merge requests found
Pipeline #16617 passed with stage
in 49 seconds
...@@ -4,8 +4,10 @@ import os ...@@ -4,8 +4,10 @@ import os
import pathlib import pathlib
import tempfile import tempfile
from itertools import chain from itertools import chain
import random
from typing import Dict, Optional, Any, Tuple from typing import Dict, Optional, Any, Tuple
import numpy
import torch import torch
from absl import app, flags from absl import app, flags
import pytorch_lightning as pl import pytorch_lightning as pl
...@@ -47,6 +49,7 @@ flags.DEFINE_integer(name="n_cuda_devices", default=-1, ...@@ -47,6 +49,7 @@ flags.DEFINE_integer(name="n_cuda_devices", default=-1,
help="Number of devices to train on (default -1 auto mode - train on as many as possible)") help="Number of devices to train on (default -1 auto mode - train on as many as possible)")
flags.DEFINE_string(name="output_file", default="output.log", flags.DEFINE_string(name="output_file", default="output.log",
help="Predictions result file.") help="Predictions result file.")
flags.DEFINE_integer(name="seed", default=None, help="Random seed.")
# Training flags # Training flags
flags.DEFINE_string(name="training_data_path", default="", help="Training data path(s)") flags.DEFINE_string(name="training_data_path", default="", help="Training data path(s)")
...@@ -293,7 +296,17 @@ def read_model_from_config(logging_prefix: str) -> Optional[ ...@@ -293,7 +296,17 @@ def read_model_from_config(logging_prefix: str) -> Optional[
return model, dataset_reader, training_data_loader, validation_data_loader, vocabulary return model, dataset_reader, training_data_loader, validation_data_loader, vocabulary
def set_seed(seed: int) -> None:
random.seed(seed)
numpy.random.seed(seed)
torch.manual_seed(seed)
def run(_): def run(_):
if FLAGS.seed:
set_seed(FLAGS.seed)
if FLAGS.mode == 'train': if FLAGS.mode == 'train':
model, dataset_reader, training_data_loader, validation_data_loader, vocabulary = None, None, None, None, None model, dataset_reader, training_data_loader, validation_data_loader, vocabulary = None, None, None, None, None
......
from typing import Optional from typing import Optional
from overrides import overrides
import torch import torch
from torch import nn from torch import nn
from combo.modules.encoder import _EncoderBase from combo.modules.encoder import _EncoderBase
from combo.config.from_parameters import FromParameters, register_arguments from combo.config.from_parameters import FromParameters, register_arguments
# from modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder
from combo.nn.utils import add_positional_features from combo.nn.utils import add_positional_features
# from allennlp.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder
# from allennlp.nn.util import add_positional_features
class TransformerEncoder(_EncoderBase, FromParameters): class TransformerEncoder(_EncoderBase, FromParameters):
""" """
Implements a stacked self-attention encoder similar to the Transformer Implements a stacked self-attention encoder similar to the Transformer
......
...@@ -30,7 +30,7 @@ REQUIREMENTS = [ ...@@ -30,7 +30,7 @@ REQUIREMENTS = [
setup( setup(
name="combo", name="combo",
version="3.1.4", version="3.1.5",
author="Maja Jablonska", author="Maja Jablonska",
author_email="maja.jablonska@ipipan.waw.pl", author_email="maja.jablonska@ipipan.waw.pl",
install_requires=REQUIREMENTS, install_requires=REQUIREMENTS,
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment