# sequential_jsonl.py
from src.pipeline.interface import Pipeline
from typing import Dict
from src.suppressors.interface import Suppressor
from src.detectors.interface import Detector
from src.replacers.interface import ReplacerInterface
from src.input_parsers.interface import InputParser
import json

class SequentialJSONLPipeline(Pipeline):
    """Pipeline that processes a JSONL file one line at a time.

    For every non-blank input line the pipeline:
      1. parses it with ``input_parser``,
      2. runs every detector on the parsed input and concatenates their results,
      3. passes the combined detections through ``suppressor``,
      4. applies the replacers in dict order, each receiving the text and the
         annotations returned by the previous replacer.

    The replaced texts are returned either as JSONL (one ``{"text": ...}``
    object per line) or, with ``concat_to_txt=True``, as a single string.
    """

    def __init__(
        self,
        input_parser: InputParser,
        detectors: Dict[str, Detector],
        suppressor: Suppressor,
        replacers: Dict[str, ReplacerInterface],
        concat_to_txt: bool = False,
    ):
        """Store the pipeline components.

        Args:
            input_parser: Parses one raw input line.
            detectors: Detectors keyed by name. The keys are not used here;
                detectors run in the dict's iteration order.
            suppressor: Filters the combined list of detected entities.
            replacers: Replacers keyed by name. The keys are not used here;
                replacers run in the dict's iteration order.
            concat_to_txt: If True, ``run`` returns one concatenated string
                instead of JSONL.
        """
        self._input_parser = input_parser
        self._detectors = detectors
        self._suppressor = suppressor
        self._replacers = replacers
        self._concat_to_txt = concat_to_txt

    def run(self, input_path) -> str:
        """Process every non-blank line of ``input_path`` and return the output.

        Args:
            input_path: Path to the input file, opened as UTF-8 text.

        Returns:
            Either the replaced texts joined into one string (see
            ``_concat_texts``) when ``concat_to_txt`` is set, or JSONL: one
            ``{"text": <replaced text>}`` object per processed line, joined
            with ``"\\n"`` and with no trailing newline.
        """
        texts = []
        # JSON interchange text must be UTF-8 (RFC 8259 §8.1). The platform
        # default (locale) encoding would mis-decode non-ASCII input on
        # systems whose locale is not UTF-8.
        with open(input_path, "r", encoding="utf-8") as f:
            # Iterate lazily instead of f.readlines(), so the whole file is
            # never held in memory twice.
            for line in f:
                if line.strip() == "":
                    continue
                texts.append(self._process_line(line))

        if self._concat_to_txt:
            return self._concat_texts(texts)
        return "\n".join(
            json.dumps({"text": text}, ensure_ascii=False) for text in texts
        )

    def _process_line(self, line: str):
        """Run one raw input line through parse -> detect -> suppress -> replace.

        Returns the text produced by the last replacer, or the parsed text
        itself when there are no replacers.
        """
        parsed_input = self._input_parser.parse(line)
        # NOTE(review): parse() is indexed as [0] = text, [1] = extra input for
        # the detectors (presumably pre-existing annotations) — confirm
        # against InputParser.

        detected_entities = []
        for detector in self._detectors.values():
            detected_entities += detector.detect(parsed_input[0], parsed_input[1])

        annotations_left = self._suppressor.suppress(detected_entities)

        replaced_input = parsed_input[0]
        for replacer in self._replacers.values():
            # Each replacer sees the output of the previous one.
            replaced_input, annotations_left = replacer.replace(
                replaced_input, annotations_left
            )
        return replaced_input

    @staticmethod
    def _concat_texts(texts) -> str:
        """Concatenate ``texts``, inserting a single space only where needed.

        A space is put in front of a text when the output so far is non-empty,
        does not end with whitespace, and the text does not start with
        whitespace. Consequently, an empty text that follows a non-whitespace
        ending contributes a lone space.

        Runs in linear time: only the last emitted character is inspected,
        and the pieces are joined once at the end.
        """
        pieces = []
        last_char = ""  # last character emitted so far; "" while output is empty
        for text in texts:
            # str.isspace uses the same whitespace definition as str.strip.
            needs_space = (
                last_char != ""
                and not last_char.isspace()
                and not text[:1].isspace()
            )
            chunk = " " + text if needs_space else text
            pieces.append(chunk)
            if chunk:
                last_char = chunk[-1]
        return "".join(pieces)