diff --git a/src/utils.py b/src/utils.py index 906de65e72dbf1561ec8107d2a76cc2b77f0c922..aef4911fc322026a5b071478d5275c642c281fee 100644 --- a/src/utils.py +++ b/src/utils.py @@ -6,7 +6,8 @@ from typing import Optional import yaml -PROJECT_ROOT = os.path.dirname(os.path.realpath("/".join(__file__.split("/")) + "/..")) +PROJECT_ROOT = os.path.dirname(os.path.realpath( + "/".join(__file__.split("/")) + "/..")) def get_config() -> dict: @@ -52,6 +53,26 @@ def remove_punctuation(text: str) -> str: return "".join(filter(lambda x: x.isalnum() or x.isspace(), text)) +def unify_whitespaces(text: str) -> str: + """Maps all whitespace characters into a simple ' ' + + Args: + text (str): Text containing multiple forms of whitespace + + Returns: + str: Text with a single form of whitespace + """ + result = "" + + for c in text: + if c.isspace(): + result += " " + else: + result += c + + return result + + def preprocess(text: str) -> str: """Makes sure that input is in the same format as training data (no non-alphanum chars, no double spaces, all lowercase etc.) @@ -63,6 +84,7 @@ def preprocess(text: str) -> str: str: Text in training-data format """ text = remove_punctuation(text) + text = unify_whitespaces(text) text = remove_multiple_spaces(text) text = text.lower() text = text.strip() diff --git a/worker.py b/worker.py index c7132f543ca808d54075a0193837ae70a458b54a..54a25c115b7cc46d4160e46be5ba105aa0951430 100755 --- a/worker.py +++ b/worker.py @@ -23,11 +23,14 @@ class Worker(nlp_ws.NLPWorker): self.config["deployment"]["device"], ) + self.model.train(False) + def process(self, input_file: str, task_options: dict, output_file: str) -> None: """Implementation of example tasks that copies files.""" with open(input_file, "r") as f: - text = preprocess(f.read()) + text = str(f.read()) + text = preprocess(text) text_processed = apply_actions_punctuation( text, self.chunk_size, self.tokenizer, self.model, self.threshold )