Skip to content
Snippets Groups Projects
Select Git revision
  • 5fd4d9598fe7a36a37e3464414b81a49963065f2
  • master default protected
  • develop protected
  • develop-0.7.x
  • develop-0.8.0
  • dev_czuk
  • loader
  • kgr10_roberta
  • 14-BiLSTM-CRF-RoBERTa
  • 12-handle-long-sequences
  • 13-flair-embeddings
  • BiLSTM
  • v0.7.0
  • v0.6.1
  • v0.5
  • v0.4.1
  • v0.3
17 results

process_tsv.py

Blame
  • process_tsv.py 2.74 KiB
    from __future__ import absolute_import, division, print_function
    
    import argparse
    import logging
    import os
    
    import torch
    
    import poldeepner2
    from poldeepner2.utils.data_utils import read_tsv, save_tsv
    from poldeepner2.utils.seed import setup_seed
    from poldeepner2.utils.sequences import FeatureGeneratorFactory
    
    
    def _override_config(config, args,
                         params=("device", "max_seq_length", "sequence_generator", "output_top_k")):
        """Overwrite model-config attributes with values explicitly given on the command line.

        Args:
            config: the loaded model configuration object; its ``__dict__`` is updated in place.
            args: parsed command-line arguments (``argparse.Namespace``).
            params: names of the attributes that may be overridden. A value is applied only
                when the CLI value is not None and its string form differs from the current
                config value; each applied change is announced on stdout.
        """
        cli_values = vars(args)
        config_values = vars(config)  # the object's own __dict__, so writes go straight to config
        for param in params:
            value = cli_values.get(param)
            if value is None:
                continue
            value_default = config_values.get(param)
            # Compare string forms (presumably so that e.g. 512 and "512" count as equal and
            # do not trigger a spurious override).
            if str(value) != str(value_default):
                print(f"Forced change of the parameter: {param} '{value_default}' => '{value}'")
                config_values[param] = value


    def main(args):
        """Run a NER model over a TSV file and write its predictions as TSV.

        Args:
            args: parsed command-line arguments (see ``parse_args``). Uses ``seed``, ``model``,
                ``device``, ``input``, ``output`` and ``max_seq_length``; ``device``,
                ``max_seq_length``, ``sequence_generator`` and ``output_top_k`` additionally
                override the corresponding model-config entries when given.
        """
        # Seed before loading so that any randomness during model construction/loading
        # is covered as well, not only the inference step.
        if args.seed is not None:
            setup_seed(args.seed)

        print("Loading the NER model ...")
        ner = poldeepner2.load(args.model, device=args.device)
        _override_config(ner.model.config, args)

        logging.info("Processing ...")
        sentences_labels = read_tsv(args.input)
        # Only the first element of each read_tsv item is fed to the model
        # (presumably the token list; the second would be the gold labels) — confirm.
        sentences = [item[0] for item in sentences_labels]

        logging.info(f"Number of sentences to process: {len(sentences)}")
        # Inference only: no gradients are needed.
        with torch.no_grad():
            predictions, _stats = ner.process(sentences, args.max_seq_length)
        save_tsv(args.output, sentences, predictions)

        logging.info("done.")
    
    
    def parse_args(argv=None):
        """Parse the command-line options of this script.

        Args:
            argv: list of argument strings to parse; ``None`` (the default) means
                ``sys.argv[1:]``, exactly as ``argparse.ArgumentParser.parse_args`` does.

        Returns:
            argparse.Namespace with the attributes ``input``, ``model``, ``output``,
            ``max_seq_length``, ``device``, ``sequence_generator``, ``seed`` and
            ``output_top_k``.
        """
        parser = argparse.ArgumentParser(
            description='Process a single TSV with a NER model')
        parser.add_argument('--input', required=True, metavar='PATH',
                            help='path to the input TSV file to process')
        parser.add_argument('--model', required=True, metavar='PATH', help='path or name of the model')
        parser.add_argument('--output', required=True, metavar='PATH',
                            help='path to the output TSV file with the predictions')
        # The hyphenated spelling matches the other options; the underscore form stays for
        # backward compatibility and, being the first long option, fixes dest=max_seq_length.
        parser.add_argument('--max_seq_length', '--max-seq-length', required=False, default=None,
                            metavar='N', type=int,
                            help='override default values of the max_seq_length')
        parser.add_argument('--device', default=None, metavar='cpu|cuda',
                            help='override default value of the device')
        parser.add_argument('--sequence-generator', type=str, choices=FeatureGeneratorFactory.methods,
                            help="method of sequence generation", default=None, required=False)
        parser.add_argument('--seed', required=False, default=None, metavar='N', type=int,
                            help='a seed used to initialize a number generator')
        parser.add_argument('--output-top-k', required=False, default=None, metavar='N', type=int,
                            help='output top k labels for each token')
        return parser.parse_args(argv)
    
    
    if __name__ == "__main__":
        import sys

        cliargs = parse_args()
        try:
            main(cliargs)
        except ValueError as er:
            # sys.exit with a string writes it to stderr and exits with status 1, so shell
            # scripts and pipelines can detect that processing failed.
            sys.exit(f"[ERROR] {er}")