From 6736343cec1cc35bfba7fdd13206a5000c95b623 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Marcin=20W=C4=85troba?= <markowanga@gmail.com>
Date: Tue, 28 Dec 2021 12:54:48 +0100
Subject: [PATCH] Fix repo checks

---
 sziszapangma/core/alignment/alignment_util.py |  2 +-
 .../alignment/distance_matrix_calculator.py   | 12 ++++-----
 .../integration/experiment_manager.py         |  7 +++--
 .../integration/mapper/step_words_mapper.py   | 10 ++++---
 .../integration/relation_manager_provider.py  |  1 -
 sziszapangma/integration/task/asr_task.py     | 10 +++----
 .../task/classic_wer_metric_task.py           | 12 +++++----
 .../task/embedding_wer_metrics_task.py        |  8 ++++--
 .../integration/task/processing_task.py       | 14 +++++++---
 sziszapangma/integration/task/task_util.py    | 16 +++++++-----
 sziszapangma/model/model.py                   |  2 +-
 sziszapangma/model/model_creators.py          | 26 +++++++++++--------
 sziszapangma/model/relation_manager.py        | 22 +++++++---------
 13 files changed, 82 insertions(+), 60 deletions(-)

diff --git a/sziszapangma/core/alignment/alignment_util.py b/sziszapangma/core/alignment/alignment_util.py
index 5c2c850..0b7c6c4 100644
--- a/sziszapangma/core/alignment/alignment_util.py
+++ b/sziszapangma/core/alignment/alignment_util.py
@@ -10,7 +10,7 @@ from sziszapangma.model.model import Word
 class AlignmentUtil:
     @staticmethod
     def _optional_str_to_str(word: Optional[Word]) -> str:
-        return word['text'] if word is not None else ""
+        return word["text"] if word is not None else ""
 
     @staticmethod
     def _wer_step_to_pandas_row_lit(step: AlignmentStep) -> Tuple[str, str, str, float]:
diff --git a/sziszapangma/core/alignment/distance_matrix_calculator.py b/sziszapangma/core/alignment/distance_matrix_calculator.py
index 7c1f5be..7fabcb3 100644
--- a/sziszapangma/core/alignment/distance_matrix_calculator.py
+++ b/sziszapangma/core/alignment/distance_matrix_calculator.py
@@ -22,7 +22,7 @@ class DistanceCalculator(ABC):
 
 class BinaryDistanceCalculator(DistanceCalculator):
     def calculate_distance_for_words(self, word1: Word, word2: Word) -> float:
-        return 0 if word1['text'] == word2['text'] else 1
+        return 0 if word1["text"] == word2["text"] else 1
 
     def calculate_distance_matrix(
         self, reference: List[Word], hypothesis: List[Word]
@@ -46,8 +46,8 @@ class CosineDistanceCalculator(DistanceCalculator):
 
     def calculate_distance_for_words(self, word1: Word, word2: Word) -> float:
         return self.cosine_distance_between_words_embeddings(
-            self._embedding_transformer.get_embedding(word1['text']),
-            self._embedding_transformer.get_embedding(word2['text']),
+            self._embedding_transformer.get_embedding(word1["text"]),
+            self._embedding_transformer.get_embedding(word2["text"]),
         )
 
     @staticmethod
@@ -74,14 +74,14 @@ class CosineDistanceCalculator(DistanceCalculator):
         self, reference: List[Word], hypothesis: List[Word]
     ) -> npt.NDArray[np.float64]:
         embeddings_dict = self._embedding_transformer.get_embeddings(
-            list(set(it['text'] for it in (reference + hypothesis)))
+            list(set(it["text"] for it in (reference + hypothesis)))
         )
         return np.array(
             [
                 [
                     self.cosine_distance_between_words_embeddings(
-                        embeddings_dict[reference_word['text']],
-                        embeddings_dict[hypothesis_word['text']],
+                        embeddings_dict[reference_word["text"]],
+                        embeddings_dict[hypothesis_word["text"]],
                     )
                     for hypothesis_word in hypothesis
                 ]
diff --git a/sziszapangma/integration/experiment_manager.py b/sziszapangma/integration/experiment_manager.py
index 5190b12..85566e0 100644
--- a/sziszapangma/integration/experiment_manager.py
+++ b/sziszapangma/integration/experiment_manager.py
@@ -26,5 +26,8 @@ class ExperimentManager:
     def process(self):
         self._experiment_repository.initialise()
         for processing_task in self._processing_tasks:
-            processing_task.process(self._record_id_iterator, self._experiment_repository,
-                                    self._relation_manager_provider)
+            processing_task.process(
+                self._record_id_iterator,
+                self._experiment_repository,
+                self._relation_manager_provider,
+            )
diff --git a/sziszapangma/integration/mapper/step_words_mapper.py b/sziszapangma/integration/mapper/step_words_mapper.py
index 6cb1511..5b86049 100644
--- a/sziszapangma/integration/mapper/step_words_mapper.py
+++ b/sziszapangma/integration/mapper/step_words_mapper.py
@@ -15,8 +15,10 @@ class StepWordsMapper:
 
     @staticmethod
     def from_json_dict(input_json_dict: Dict[str, Any]) -> StepWords:
-        reference_word = None if "reference_word" not in input_json_dict \
-            else input_json_dict["reference_word"]
-        hypothesis_word = None if "hypothesis_word" not in input_json_dict \
-            else input_json_dict["hypothesis_word"]
+        reference_word = (
+            None if "reference_word" not in input_json_dict else input_json_dict["reference_word"]
+        )
+        hypothesis_word = (
+            None if "hypothesis_word" not in input_json_dict else input_json_dict["hypothesis_word"]
+        )
         return StepWords(reference_word, hypothesis_word)
diff --git a/sziszapangma/integration/relation_manager_provider.py b/sziszapangma/integration/relation_manager_provider.py
index a92fb11..8d54c43 100644
--- a/sziszapangma/integration/relation_manager_provider.py
+++ b/sziszapangma/integration/relation_manager_provider.py
@@ -4,7 +4,6 @@ from sziszapangma.model.relation_manager import RelationManager
 
 
 class RelationManagerProvider(ABC):
-
     @abstractmethod
     def get_relation_manager(self, record_id: str) -> RelationManager:
         pass
diff --git a/sziszapangma/integration/task/asr_task.py b/sziszapangma/integration/task/asr_task.py
index f93045c..bbb60b2 100644
--- a/sziszapangma/integration/task/asr_task.py
+++ b/sziszapangma/integration/task/asr_task.py
@@ -30,14 +30,14 @@ class AsrTask(ProcessingTask):
         return asr_value is not None and "transcription" in asr_value
 
     def run_single_process(
-        self, record_id: str, experiment_repository: ExperimentRepository,
-        relation_manager: RelationManager
+        self,
+        record_id: str,
+        experiment_repository: ExperimentRepository,
+        relation_manager: RelationManager,
     ) -> None:
         file_record_path = self._record_path_provider.get_path(record_id)
         asr_result = self._asr_processor.call_recognise(file_record_path)
-        asr_result["transcription"] = [
-            create_new_word(it) for it in asr_result["transcription"]
-        ]
+        asr_result["transcription"] = [create_new_word(it) for it in asr_result["transcription"]]
         experiment_repository.update_property_for_key(
             record_id, self._asr_property_name, asr_result
         )
diff --git a/sziszapangma/integration/task/classic_wer_metric_task.py b/sziszapangma/integration/task/classic_wer_metric_task.py
index e3cbc2e..27ee08a 100644
--- a/sziszapangma/integration/task/classic_wer_metric_task.py
+++ b/sziszapangma/integration/task/classic_wer_metric_task.py
@@ -43,8 +43,12 @@ class ClassicWerMetricTask(ProcessingTask):
             is not None
         )
 
-    def run_single_process(self, record_id: str, experiment_repository: ExperimentRepository,
-                           relation_manager: RelationManager):
+    def run_single_process(
+        self,
+        record_id: str,
+        experiment_repository: ExperimentRepository,
+        relation_manager: RelationManager,
+    ):
         gold_transcript = TaskUtil.get_words_from_record(relation_manager)
         asr_result = experiment_repository.get_property_for_key(record_id, self._asr_property_name)
         if gold_transcript is not None and asr_result is not None and "transcription" in asr_result:
@@ -59,9 +63,7 @@ class ClassicWerMetricTask(ProcessingTask):
             )
 
     def _get_alignment(
-        self,
-        gold_transcript: List[Word],
-        asr_transcript: List[Word]
+        self, gold_transcript: List[Word], asr_transcript: List[Word]
     ) -> List[AlignmentStep]:
         return self._alignment_classic_calculator.calculate_alignment(
             reference=gold_transcript, hypothesis=asr_transcript
diff --git a/sziszapangma/integration/task/embedding_wer_metrics_task.py b/sziszapangma/integration/task/embedding_wer_metrics_task.py
index 426cb8d..c0f54ad 100644
--- a/sziszapangma/integration/task/embedding_wer_metrics_task.py
+++ b/sziszapangma/integration/task/embedding_wer_metrics_task.py
@@ -51,8 +51,12 @@ class EmbeddingWerMetricsTask(ProcessingTask):
             is not None
         )
 
-    def run_single_process(self, record_id: str, experiment_repository: ExperimentRepository,
-                           relation_manager: RelationManager):
+    def run_single_process(
+        self,
+        record_id: str,
+        experiment_repository: ExperimentRepository,
+        relation_manager: RelationManager,
+    ):
         gold_transcript = TaskUtil.get_words_from_record(relation_manager)
         asr_result = experiment_repository.get_property_for_key(record_id, self._asr_property_name)
         if gold_transcript is not None and asr_result is not None and "transcription" in asr_result:
diff --git a/sziszapangma/integration/task/processing_task.py b/sziszapangma/integration/task/processing_task.py
index 8d30526..6959ae1 100644
--- a/sziszapangma/integration/task/processing_task.py
+++ b/sziszapangma/integration/task/processing_task.py
@@ -16,8 +16,12 @@ class ProcessingTask(ABC):
         self._task_name = task_name
 
     @abstractmethod
-    def run_single_process(self, record_id: str, experiment_repository: ExperimentRepository,
-                           relation_manager: RelationManager):
+    def run_single_process(
+        self,
+        record_id: str,
+        experiment_repository: ExperimentRepository,
+        relation_manager: RelationManager,
+    ):
         pass
 
     @abstractmethod
@@ -25,8 +29,10 @@ class ProcessingTask(ABC):
         pass
 
     def process(
-        self, record_id_iterator: RecordIdIterator, experiment_repository: ExperimentRepository,
-        relation_manager_provider: RelationManagerProvider
+        self,
+        record_id_iterator: RecordIdIterator,
+        experiment_repository: ExperimentRepository,
+        relation_manager_provider: RelationManagerProvider,
     ):
         records_ids = list(record_id_iterator.get_all_records())
         for record_index in range(len(records_ids)):
diff --git a/sziszapangma/integration/task/task_util.py b/sziszapangma/integration/task/task_util.py
index 71efe44..228d2c5 100644
--- a/sziszapangma/integration/task/task_util.py
+++ b/sziszapangma/integration/task/task_util.py
@@ -1,19 +1,23 @@
-from typing import List
+from typing import List, cast
 
-from sziszapangma.model.model import Word
+from sziszapangma.model.model import Document, Word
 from sziszapangma.model.relation_manager import RelationManager
 
 
 class TaskUtil:
-
     @staticmethod
     def get_words_from_record(relation_manager: RelationManager) -> List[Word]:
-        document = [itt for itt in relation_manager.get_all_items() if itt['type'] == 'Document'][0]
-        return [relation_manager.get_item_by_id(item_id) for item_id in document['word_ids']]
+        document = cast(
+            Document,
+            [itt for itt in relation_manager.get_all_items() if itt["type"] == "Document"][0],
+        )
+        return [
+            cast(Word, relation_manager.get_item_by_id(item_id)) for item_id in document["word_ids"]
+        ]
 
     @staticmethod
     def _word_to_lower(word: Word) -> Word:
-        return Word(id=word['id'], type='Word', text=word['text'].lower())
+        return Word(id=word["id"], type="Word", text=word["text"].lower())
 
     @staticmethod
     def words_to_lower(words: List[Word]) -> List[Word]:
diff --git a/sziszapangma/model/model.py b/sziszapangma/model/model.py
index 2bcba17..0236650 100644
--- a/sziszapangma/model/model.py
+++ b/sziszapangma/model/model.py
@@ -1,4 +1,4 @@
-from typing import List, TypedDict, Literal, Any, Union
+from typing import Any, List, Literal, TypedDict, Union
 
 AnnotationType = Literal["lemma", "pos", "morph", "concept", "chunk", "turn"]
 ReferenceType = Literal["Token", "Word", "Document", "SpanAnnotation"]
diff --git a/sziszapangma/model/model_creators.py b/sziszapangma/model/model_creators.py
index 702e713..dbd27f3 100644
--- a/sziszapangma/model/model_creators.py
+++ b/sziszapangma/model/model_creators.py
@@ -1,8 +1,13 @@
 import uuid
-from typing import List
+from typing import Any, List
 
-from sziszapangma.model.model import Word, SingleAnnotation, AnnotationType, SpanAnnotation, \
-    Document
+from sziszapangma.model.model import (
+    AnnotationType,
+    Document,
+    SingleAnnotation,
+    SpanAnnotation,
+    Word,
+)
 
 
 def _get_uuid() -> str:
@@ -10,21 +15,20 @@ def _get_uuid() -> str:
 
 
 def create_new_word(text: str) -> Word:
-    return Word(id=_get_uuid(), type='Word', text=text)
+    return Word(id=_get_uuid(), type="Word", text=text)
 
 
 def create_new_single_annotation(
-    annotation_type: AnnotationType,
-    value: any,
-    reference_id: str
+    annotation_type: AnnotationType, value: Any, reference_id: str
 ) -> SingleAnnotation:
-    return SingleAnnotation(id=_get_uuid(), type=annotation_type, value=value,
-                            reference_id=reference_id)
+    return SingleAnnotation(
+        id=_get_uuid(), type=annotation_type, value=value, reference_id=reference_id
+    )
 
 
 def create_new_span_annotation(name: str, elements: List[str]) -> SpanAnnotation:
-    return SpanAnnotation(id=_get_uuid(), type='SpanAnnotation', name=name, elements=elements)
+    return SpanAnnotation(id=_get_uuid(), type="SpanAnnotation", name=name, elements=elements)
 
 
 def create_new_document(word_ids: List[str]) -> Document:
-    return Document(id=_get_uuid(), type='Document', word_ids=word_ids)
+    return Document(id=_get_uuid(), type="Document", word_ids=word_ids)
diff --git a/sziszapangma/model/relation_manager.py b/sziszapangma/model/relation_manager.py
index d99b20e..74ec581 100644
--- a/sziszapangma/model/relation_manager.py
+++ b/sziszapangma/model/relation_manager.py
@@ -1,7 +1,7 @@
 import json
 import os
-from abc import abstractmethod, ABC
-from typing import List, TypedDict, Dict
+from abc import ABC, abstractmethod
+from typing import Dict, List, TypedDict
 
 import pandas as pd
 
@@ -73,15 +73,17 @@ class FileRelationManager(RelationManager):
         return self.relations_dataframe.query(query).to_dict("records")
 
     def get_all_relations_for_item(self, item_id: str) -> List[RelationItem]:
-        return self.relations_dataframe[
-            self.relations_dataframe["first_id"] == item_id
-            ].to_dict("records")
+        return self.relations_dataframe[self.relations_dataframe["first_id"] == item_id].to_dict(
+            "records"
+        )
 
     def get_item_by_id(self, item_id) -> UUIDable:
         return self.items_dict[item_id]
 
     def save_item(self, item: UUIDable):
-        self.items_dict[item.get("id")] = item
+        item_id = item.get("id")
+        if item_id is not None:
+            self.items_dict[item_id] = item
 
     def save_relation(self, item1: UUIDable, item2: UUIDable):
         relation_1_2 = {
@@ -96,12 +98,8 @@ class FileRelationManager(RelationManager):
             "second_id": item1["id"],
             "second_type": item1["type"],
         }
-        self.relations_dataframe = self.relations_dataframe.append(
-            relation_1_2, ignore_index=True
-        )
-        self.relations_dataframe = self.relations_dataframe.append(
-            relation_2_1, ignore_index=True
-        )
+        self.relations_dataframe = self.relations_dataframe.append(relation_1_2, ignore_index=True)
+        self.relations_dataframe = self.relations_dataframe.append(relation_2_1, ignore_index=True)
 
     def commit(self):
         self.relations_dataframe.to_csv(self.relations_csv_path, index=False)
-- 
GitLab