Skip to content
Snippets Groups Projects
Commit ed29c7a0 authored by Martyna Wiącek's avatar Martyna Wiącek
Browse files

add setting random seed

parent f0811ea2
Branches
Tags
1 merge request!50Fix imports and random seed
......@@ -4,8 +4,10 @@ import os
import pathlib
import tempfile
from itertools import chain
import random
from typing import Dict, Optional, Any, Tuple
import numpy.random.bit_generator
import torch
from absl import app, flags
import pytorch_lightning as pl
......@@ -47,6 +49,7 @@ flags.DEFINE_integer(name="n_cuda_devices", default=-1,
help="Number of devices to train on (default -1 auto mode - train on as many as possible)")
flags.DEFINE_string(name="output_file", default="output.log",
help="Predictions result file.")
flags.DEFINE_integer(name="seed", default=None, help="Random seed.")
# Training flags
flags.DEFINE_string(name="training_data_path", default="", help="Training data path(s)")
......@@ -293,7 +296,17 @@ def read_model_from_config(logging_prefix: str) -> Optional[
return model, dataset_reader, training_data_loader, validation_data_loader, vocabulary
def set_seed(seed: int) -> None:
random.seed(seed)
numpy.random.seed(seed)
torch.manual_seed(seed)
def run(_):
if FLAGS.seed:
set_seed(FLAGS.seed)
if FLAGS.mode == 'train':
model, dataset_reader, training_data_loader, validation_data_loader, vocabulary = None, None, None, None, None
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment