winer_worker.py
    """Winer worker implementation."""
    from winer import Winer, allign_into_entities, get_sentences_from_document
    import clarin_json
    from clarin_json.containers import Span
    import gc
    import torch
    
    import nlp_ws
    
    import logging
    _log = logging.getLogger(__name__)
    
    
    class WinerWorker(nlp_ws.NLPWorker):
        """Class implementing Winer worker."""
    
        @classmethod
        def static_init(cls, config):
            """Initialize process.
    
            :param config: configuration from ini file
            :type config: dict
            """
            cls.models_location = config["tool"].get("models_cache_dir",
                                                     "/home/worker/models/base")
            cls.batch_size = int(config["tool"].get("batch_size", "128"))
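            # An illustrative "[tool]" section of the worker ini file; the
            # exact layout is an assumption, and both keys are optional,
            # falling back to the defaults used above:
            #
            #   [tool]
            #   models_cache_dir = /home/worker/models/base
            #   batch_size = 128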
    
        def __init__(self):
            """Constructor."""
            _log.info("Loading models...")
            self._model = Winer(self.models_location,
                                batch_size=self.batch_size)
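            # Cap tokenized inputs at 512 tokens so a packed chunk can never
            # exceed the model's maximum input length.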
            self._model._tokenizer.model_max_length = 512
    
        @staticmethod
        def pack_sentences_to_max_tokens(
                plain_inputs,
                tokenized_inputs,
                sent_starts,
                max_tokens=512
        ):
            """
            Pack sentences into chunks, ensuring that the token count 
            per chunk does not exceed a given maximum.
    
            This method takes in plain text sentences, their tokenized 
            versions, and sentence start indices,
            and it creates batches of sentences such that the total 
            token count per batch does not exceed the
            specified max_tokens limit.
    
            Args:
                plain_inputs (list): List of plain text sentences.
                tokenized_inputs (list): List of tokenized versions 
                of the sentences.
                sent_starts (list): List of sentence start indices.
                max_tokens (int, optional): The maximum number of 
                tokens allowed per chunk. Defaults to 512.
    
            Returns:
                tuple: 
                    - packed_plain_inputs (list): List of packed 
                    plain text sentences, where each element 
                    is a chunk of sentences.
                    - packed_tokenized_inputs (list): List of packed 
                    tokenized inputs corresponding to the plain 
                    text chunks.
                    - packed_sent_starts (list): List of sentence 
                    start indices for each chunk.
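
            Example (illustrative only; tokens and offsets are made up,
            with a tiny ``max_tokens`` for brevity):
                >>> WinerWorker.pack_sentences_to_max_tokens(
                ...     ["a b", "c d e", "f"],
                ...     [["a", "b"], ["c", "d", "e"], ["f"]],
                ...     [0, 4, 10],
                ...     max_tokens=4)
                (['a b', 'c d e f'], [['a', 'b'], ['c', 'd', 'e', 'f']], [0, 4])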
            """
            packed_plain_inputs = []
            packed_tokenized_inputs = []
            packed_sent_starts = []
    
            current_plain_inputs = []
            current_tokenized_inputs = []
            current_sent_start = []
            current_token_count = 0
    
            for sentence, sentence_tokens, sent_start in zip(
                    plain_inputs, tokenized_inputs, sent_starts):

                # Flush the current chunk before it would overflow
                if (current_token_count + len(sentence_tokens) > max_tokens
                        and current_plain_inputs):
                    packed_plain_inputs.append(' '.join(current_plain_inputs))
                    packed_tokenized_inputs.append(current_tokenized_inputs)
                    packed_sent_starts.append(current_sent_start[0])

                    # Reset for a new chunk
                    current_plain_inputs = []
                    current_tokenized_inputs = []
                    current_sent_start = []
                    current_token_count = 0

                # Always keep the current sentence so that nothing is dropped
                # at a chunk boundary; a single over-long sentence simply
                # forms its own chunk.
                current_plain_inputs.append(sentence)
                current_tokenized_inputs.extend(sentence_tokens)
                current_sent_start.append(sent_start)
                current_token_count += len(sentence_tokens)

            # Flush the last chunk
            if current_plain_inputs:
                packed_plain_inputs.append(' '.join(current_plain_inputs))
                packed_tokenized_inputs.append(current_tokenized_inputs)
                packed_sent_starts.append(current_sent_start[0])
    
            return packed_plain_inputs, packed_tokenized_inputs, packed_sent_starts
    
        def process(
            self,
            input_path: str,
            task_options: dict,
            output_path: str
        ) -> None:
            """Called for each request made to the worker.
    
            :param input_path: Path to a file with text documents from which
                the worker should read text.
            :type input_path: str
            :param task_options: runtime options; only ``ensure_ascii``
                (defaults to False) is used
            :type task_options: dict
            :param output_path: Path to the file where the worker
                will store the result.
            :type output_path: str
            """
            # Read inputs and open output
            F_ASCII = task_options.get("ensure_ascii", False)
            with clarin_json.open(output_path, "w", ensure_ascii=F_ASCII) as fout:
                with clarin_json.open(input_path, "r") as fin:
                    for document in fin:
                        # Pre-process document
                        plain_inputs, tokenized_inputs, sent_starts = \
                            get_sentences_from_document(document)
    
                        (
                            packed_plain_texts,
                            packed_tokenized_inputs,
                            packed_sent_starts
                        ) = (
                            self.pack_sentences_to_max_tokens(
                                plain_inputs,
                                tokenized_inputs,
                                sent_starts,
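                                # 510 rather than the 512 cap set on the
                                # tokenizer, presumably to leave room for the
                                # special tokens added around each input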
                                max_tokens=510
                            )
                        )
    
                        # Process data
                        aggregations = self._model.process(packed_tokenized_inputs)
    
                        # Aggregate results to entities
                        entities = allign_into_entities(
                            packed_tokenized_inputs,
                            aggregations,
                            inputs_sentences=packed_plain_texts,
                            sentences_range=packed_sent_starts
                        )
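                        # Flatten per-sentence entity lists into one list of
                        # Span objects stored on the document's "ner" layer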
                        document.set_spans(
                            [Span(**entity)
                             for sent_entities in entities
                             for entity in sent_entities],
                            "ner")
    
                        # Write processed document
                        fout.write(document)
    
            # Clean the memory
            gc.collect()
            torch.cuda.empty_cache()
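
A worker like this is normally started through the nlp_ws service runner rather than run directly. The snippet below is only a sketch of such an entry point; the module path and the nlp_ws.NLPService.main call are assumptions based on typical CLARIN-PL workers, not something taken from this file.

    # main.py -- hypothetical entry point (assumed nlp_ws API)
    import nlp_ws

    from winer_worker import WinerWorker

    if __name__ == '__main__':
        nlp_ws.NLPService.main(WinerWorker)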