# process_tsv.py -- tag a single TSV file with a PolDeepNer2 NER model.
# Author: Michał Marcińczuk
from __future__ import absolute_import, division, print_function
import argparse
import logging
import os
import torch
import poldeepner2
from poldeepner2.utils.data_utils import read_tsv, save_tsv
from poldeepner2.utils.seed import setup_seed
from poldeepner2.utils.sequences import FeatureGeneratorFactory
def main(args):
    """Tag every sentence of a TSV file with a NER model and save the result as TSV.

    Args:
        args: parsed command-line namespace (see ``parse_args``). Reads
            ``seed``, ``model``, ``device``, ``max_seq_length``,
            ``sequence_generator``, ``output_top_k``, ``input`` and ``output``.
    """
    # Seed first so that every random draw -- including any made while the
    # model is being constructed -- is reproducible for a given seed.
    if args.seed is not None:
        setup_seed(args.seed)

    print("Loading the NER model ...")
    ner = poldeepner2.load(args.model, device=args.device)

    # Options given explicitly on the command line take precedence over the
    # values stored in the model's configuration.
    cli_values = vars(args)
    config = ner.model.config.__dict__
    for param in ["device", "max_seq_length", "sequence_generator", "output_top_k"]:
        value = cli_values.get(param)
        if value is None:
            continue
        value_default = config.get(param)
        # Compare string forms so that e.g. 512 and "512" are not reported
        # (or overwritten) as a change.
        if str(value) != str(value_default):
            print(f"Forced change of the parameter: {param} '{value_default}' => '{value}'")
            config[param] = value

    logging.info("Processing ...")
    sentences_labels = read_tsv(args.input)
    # Keep only the first element of each entry; presumably the tokens, with
    # any gold labels in the input discarded -- confirm against read_tsv.
    sentences = [sentence[0] for sentence in sentences_labels]
    logging.info("Number of sentences to process: %d", len(sentences))

    # Inference only: no gradients are needed.
    with torch.no_grad():
        predictions, _stats = ner.process(sentences, args.max_seq_length)

    save_tsv(args.output, sentences, predictions)
    logging.info("done.")
def parse_args(argv=None):
    """Parse the command-line options of the script.

    Args:
        argv: list of argument strings to parse. ``None`` (the default)
            makes argparse parse ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace holding the options defined below.
    """
    parser = argparse.ArgumentParser(
        description='Process a single TSV with a NER model')
    parser.add_argument('--input', required=True, metavar='PATH',
                        help='path to the input TSV file')
    parser.add_argument('--model', required=True, metavar='PATH',
                        help='path or name of the model')
    parser.add_argument('--output', required=True, metavar='PATH',
                        help='path to the output TSV file')
    # The hyphenated alias matches the spelling of the other options; the
    # destination stays ``max_seq_length`` (taken from the first long option).
    parser.add_argument('--max_seq_length', '--max-seq-length', required=False, default=None,
                        metavar='N', type=int,
                        help='override default values of the max_seq_length')
    parser.add_argument('--device', default=None, metavar='cpu|cuda',
                        help='override default value of the device')
    parser.add_argument('--sequence-generator', type=str, choices=FeatureGeneratorFactory.methods,
                        help="method of sequence generation", default=None, required=False)
    parser.add_argument('--seed', required=False, default=None, metavar='N', type=int,
                        help='a seed used to initialize a number generator')
    parser.add_argument('--output-top-k', required=False, default=None, metavar='N', type=int,
                        help='output top k labels for each token')
    return parser.parse_args(argv)
if __name__ == "__main__":
    import sys

    cliargs = parse_args()
    try:
        main(cliargs)
    except ValueError as er:
        # Report on stderr and exit non-zero so that shells, pipelines and job
        # schedulers can detect the failure (previously the exit code was 0).
        print("[ERROR] %s" % er, file=sys.stderr)
        sys.exit(1)