diff --git a/combo/main.py b/combo/main.py
index e44feda58d6aacba48535bed040070dbc3d4ee92..49a0fc74e10e8693d70871e7a6ee6e2cb409803f 100755
--- a/combo/main.py
+++ b/combo/main.py
@@ -4,8 +4,10 @@ import os
 import pathlib
 import tempfile
 from itertools import chain
+import random
 from typing import Dict, Optional, Any, Tuple
 
+import numpy
 import torch
 from absl import app, flags
 import pytorch_lightning as pl
@@ -47,6 +49,7 @@ flags.DEFINE_integer(name="n_cuda_devices", default=-1,
                      help="Number of devices to train on (default -1 auto mode - train on as many as possible)")
 flags.DEFINE_string(name="output_file", default="output.log",
                     help="Predictions result file.")
+flags.DEFINE_integer(name="seed", default=None, help="Random seed.")
 
 # Training flags
 flags.DEFINE_string(name="training_data_path", default="", help="Training data path(s)")
@@ -293,7 +296,17 @@ def read_model_from_config(logging_prefix: str) -> Optional[
     return model, dataset_reader, training_data_loader, validation_data_loader, vocabulary
 
 
+def set_seed(seed: int) -> None:
+    random.seed(seed)
+    numpy.random.seed(seed)
+    torch.manual_seed(seed)
+
+
 def run(_):
+
+    if FLAGS.seed is not None:
+        set_seed(FLAGS.seed)
+
     if FLAGS.mode == 'train':
         model, dataset_reader, training_data_loader, validation_data_loader, vocabulary = None, None, None, None, None
 
diff --git a/combo/modules/seq2seq_encoders/transformer_encoder.py b/combo/modules/seq2seq_encoders/transformer_encoder.py
index fc389b898c056eaecb20966f11f226b0c46d1358..3b126fac38009a82cfbe3efa259c7cd2c6a4a2af 100644
--- a/combo/modules/seq2seq_encoders/transformer_encoder.py
+++ b/combo/modules/seq2seq_encoders/transformer_encoder.py
@@ -1,20 +1,13 @@
 from typing import Optional
 
-from overrides import overrides
 import torch
 from torch import nn
 
 from combo.modules.encoder import _EncoderBase
 from combo.config.from_parameters import FromParameters, register_arguments
-
-# from modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder
 from combo.nn.utils import add_positional_features
-# from allennlp.modules.seq2seq_encoders.seq2seq_encoder import Seq2SeqEncoder
-# from allennlp.nn.util import add_positional_features
-
-
 
 
 class TransformerEncoder(_EncoderBase, FromParameters):
     """
     Implements a stacked self-attention encoder similar to the Transformer
diff --git a/setup.py b/setup.py
index bd002c466e7694d1c23851bd3b3b6bd066beab70..196c10829f777535aee156d564ca0b3e57bcdee1 100644
--- a/setup.py
+++ b/setup.py
@@ -30,7 +30,7 @@ REQUIREMENTS = [
 
 setup(
     name="combo",
-    version="3.1.4",
+    version="3.1.5",
     author="Maja Jablonska",
     author_email="maja.jablonska@ipipan.waw.pl",
     install_requires=REQUIREMENTS,