diff --git a/sziszapangma/core/alignment/alignment_util.py b/sziszapangma/core/alignment/alignment_util.py index 5c2c850462b7d42795f5af06af9e09ad821b50c6..0b7c6c42136e94602b1792d987359c6d92bfb605 100644 --- a/sziszapangma/core/alignment/alignment_util.py +++ b/sziszapangma/core/alignment/alignment_util.py @@ -10,7 +10,7 @@ from sziszapangma.model.model import Word class AlignmentUtil: @staticmethod def _optional_str_to_str(word: Optional[Word]) -> str: - return word['text'] if word is not None else "" + return word["text"] if word is not None else "" @staticmethod def _wer_step_to_pandas_row_lit(step: AlignmentStep) -> Tuple[str, str, str, float]: diff --git a/sziszapangma/core/alignment/distance_matrix_calculator.py b/sziszapangma/core/alignment/distance_matrix_calculator.py index 7c1f5bebec2c664e66ce989422dc68ef9b4674bf..7fabcb3518d87a4345508c8b97d00067df4747c9 100644 --- a/sziszapangma/core/alignment/distance_matrix_calculator.py +++ b/sziszapangma/core/alignment/distance_matrix_calculator.py @@ -22,7 +22,7 @@ class DistanceCalculator(ABC): class BinaryDistanceCalculator(DistanceCalculator): def calculate_distance_for_words(self, word1: Word, word2: Word) -> float: - return 0 if word1['text'] == word2['text'] else 1 + return 0 if word1["text"] == word2["text"] else 1 def calculate_distance_matrix( self, reference: List[Word], hypothesis: List[Word] @@ -46,8 +46,8 @@ class CosineDistanceCalculator(DistanceCalculator): def calculate_distance_for_words(self, word1: Word, word2: Word) -> float: return self.cosine_distance_between_words_embeddings( - self._embedding_transformer.get_embedding(word1['text']), - self._embedding_transformer.get_embedding(word2['text']), + self._embedding_transformer.get_embedding(word1["text"]), + self._embedding_transformer.get_embedding(word2["text"]), ) @staticmethod @@ -74,14 +74,14 @@ class CosineDistanceCalculator(DistanceCalculator): self, reference: List[Word], hypothesis: List[Word] ) -> npt.NDArray[np.float64]: embeddings_dict = self._embedding_transformer.get_embeddings( - list(set(it['text'] for it in (reference + hypothesis))) + list(set(it["text"] for it in (reference + hypothesis))) ) return np.array( [ [ self.cosine_distance_between_words_embeddings( - embeddings_dict[reference_word['text']], - embeddings_dict[hypothesis_word['text']], + embeddings_dict[reference_word["text"]], + embeddings_dict[hypothesis_word["text"]], ) for hypothesis_word in hypothesis ] diff --git a/sziszapangma/integration/experiment_manager.py b/sziszapangma/integration/experiment_manager.py index 5190b121156435b0803e9fbc86ee60dff4ff3d7c..85566e0d8bac3c624df697c80355f9cd6fa506a2 100644 --- a/sziszapangma/integration/experiment_manager.py +++ b/sziszapangma/integration/experiment_manager.py @@ -26,5 +26,8 @@ class ExperimentManager: def process(self): self._experiment_repository.initialise() for processing_task in self._processing_tasks: - processing_task.process(self._record_id_iterator, self._experiment_repository, - self._relation_manager_provider) + processing_task.process( + self._record_id_iterator, + self._experiment_repository, + self._relation_manager_provider, + ) diff --git a/sziszapangma/integration/mapper/step_words_mapper.py b/sziszapangma/integration/mapper/step_words_mapper.py index 6cb1511dda793089863dffb426bee5b58a96664e..5b860491c70b4bbd4d3ae1fc4857cbbb43b77598 100644 --- a/sziszapangma/integration/mapper/step_words_mapper.py +++ b/sziszapangma/integration/mapper/step_words_mapper.py @@ -15,8 +15,10 @@ class StepWordsMapper: @staticmethod def from_json_dict(input_json_dict: Dict[str, Any]) -> StepWords: - reference_word = None if "reference_word" not in input_json_dict \ - else input_json_dict["reference_word"] - hypothesis_word = None if "hypothesis_word" not in input_json_dict \ - else input_json_dict["hypothesis_word"] + reference_word = ( + None if "reference_word" not in input_json_dict else input_json_dict["reference_word"] + ) + hypothesis_word = ( + None if "hypothesis_word" not in input_json_dict else input_json_dict["hypothesis_word"] + ) return StepWords(reference_word, hypothesis_word) diff --git a/sziszapangma/integration/relation_manager_provider.py b/sziszapangma/integration/relation_manager_provider.py index a92fb1132e986c11b940923fcacf37e83d0f749b..8d54c4320366379d4bdcaac3f499c93571738719 100644 --- a/sziszapangma/integration/relation_manager_provider.py +++ b/sziszapangma/integration/relation_manager_provider.py @@ -4,7 +4,6 @@ from sziszapangma.model.relation_manager import RelationManager class RelationManagerProvider(ABC): - @abstractmethod def get_relation_manager(self, record_id: str) -> RelationManager: pass diff --git a/sziszapangma/integration/task/asr_task.py b/sziszapangma/integration/task/asr_task.py index f93045c269a183d64f43f1546c9d887f8a3c055b..bbb60b2e5aa98a4156d7ca61100ff531624e5613 100644 --- a/sziszapangma/integration/task/asr_task.py +++ b/sziszapangma/integration/task/asr_task.py @@ -30,14 +30,14 @@ class AsrTask(ProcessingTask): return asr_value is not None and "transcription" in asr_value def run_single_process( - self, record_id: str, experiment_repository: ExperimentRepository, - relation_manager: RelationManager + self, + record_id: str, + experiment_repository: ExperimentRepository, + relation_manager: RelationManager, ) -> None: file_record_path = self._record_path_provider.get_path(record_id) asr_result = self._asr_processor.call_recognise(file_record_path) - asr_result["transcription"] = [ - create_new_word(it) for it in asr_result["transcription"] - ] + asr_result["transcription"] = [create_new_word(it) for it in asr_result["transcription"]] experiment_repository.update_property_for_key( record_id, self._asr_property_name, asr_result ) diff --git a/sziszapangma/integration/task/classic_wer_metric_task.py b/sziszapangma/integration/task/classic_wer_metric_task.py index e3cbc2e8321012b283a1309c70becf724931cd66..27ee08aeee9b6eeb7179d3c05947da47f49b1d91 100644 --- a/sziszapangma/integration/task/classic_wer_metric_task.py +++ b/sziszapangma/integration/task/classic_wer_metric_task.py @@ -43,8 +43,12 @@ class ClassicWerMetricTask(ProcessingTask): is not None ) - def run_single_process(self, record_id: str, experiment_repository: ExperimentRepository, - relation_manager: RelationManager): + def run_single_process( + self, + record_id: str, + experiment_repository: ExperimentRepository, + relation_manager: RelationManager, + ): gold_transcript = TaskUtil.get_words_from_record(relation_manager) asr_result = experiment_repository.get_property_for_key(record_id, self._asr_property_name) if gold_transcript is not None and asr_result is not None and "transcription" in asr_result: @@ -59,9 +63,7 @@ class ClassicWerMetricTask(ProcessingTask): ) def _get_alignment( - self, - gold_transcript: List[Word], - asr_transcript: List[Word] + self, gold_transcript: List[Word], asr_transcript: List[Word] ) -> List[AlignmentStep]: return self._alignment_classic_calculator.calculate_alignment( reference=gold_transcript, hypothesis=asr_transcript diff --git a/sziszapangma/integration/task/embedding_wer_metrics_task.py b/sziszapangma/integration/task/embedding_wer_metrics_task.py index 426cb8d8df62594b8147dfa75de7ef5c6857f8e5..c0f54ad88a57914f115b9589c49b73623a6af2f4 100644 --- a/sziszapangma/integration/task/embedding_wer_metrics_task.py +++ b/sziszapangma/integration/task/embedding_wer_metrics_task.py @@ -51,8 +51,12 @@ class EmbeddingWerMetricsTask(ProcessingTask): is not None ) - def run_single_process(self, record_id: str, experiment_repository: ExperimentRepository, - relation_manager: RelationManager): + def run_single_process( + self, + record_id: str, + experiment_repository: ExperimentRepository, + relation_manager: RelationManager, + ): gold_transcript = TaskUtil.get_words_from_record(relation_manager) asr_result = experiment_repository.get_property_for_key(record_id, self._asr_property_name) if gold_transcript is not None and asr_result is not None and "transcription" in asr_result: diff --git a/sziszapangma/integration/task/processing_task.py b/sziszapangma/integration/task/processing_task.py index 8d30526097db531c60e0902942199c2ab742c54b..6959ae15075a81c1acc3b7050651572bf6829646 100644 --- a/sziszapangma/integration/task/processing_task.py +++ b/sziszapangma/integration/task/processing_task.py @@ -16,8 +16,12 @@ class ProcessingTask(ABC): self._task_name = task_name @abstractmethod - def run_single_process(self, record_id: str, experiment_repository: ExperimentRepository, - relation_manager: RelationManager): + def run_single_process( + self, + record_id: str, + experiment_repository: ExperimentRepository, + relation_manager: RelationManager, + ): pass @abstractmethod @@ -25,8 +29,10 @@ class ProcessingTask(ABC): pass def process( - self, record_id_iterator: RecordIdIterator, experiment_repository: ExperimentRepository, - relation_manager_provider: RelationManagerProvider + self, + record_id_iterator: RecordIdIterator, + experiment_repository: ExperimentRepository, + relation_manager_provider: RelationManagerProvider, ): records_ids = list(record_id_iterator.get_all_records()) for record_index in range(len(records_ids)): diff --git a/sziszapangma/integration/task/task_util.py b/sziszapangma/integration/task/task_util.py index 71efe4479f0e0f1f9743c676c6d2a3ed25f232f0..228d2c59287ce20b7dd909bcab3af9ec63a74288 100644 --- a/sziszapangma/integration/task/task_util.py +++ b/sziszapangma/integration/task/task_util.py @@ -1,19 +1,23 @@ -from typing import List +from typing import List, cast -from sziszapangma.model.model import Word +from sziszapangma.model.model import Document, Word from sziszapangma.model.relation_manager import RelationManager class TaskUtil: - @staticmethod def get_words_from_record(relation_manager: RelationManager) -> List[Word]: - document = [itt for itt in relation_manager.get_all_items() if itt['type'] == 'Document'][0] - return [relation_manager.get_item_by_id(item_id) for item_id in document['word_ids']] + document = cast( + Document, + [itt for itt in relation_manager.get_all_items() if itt["type"] == "Document"][0], + ) + return [ + cast(Word, relation_manager.get_item_by_id(item_id)) for item_id in document["word_ids"] + ] @staticmethod def _word_to_lower(word: Word) -> Word: - return Word(id=word['id'], type='Word', text=word['text'].lower()) + return Word(id=word["id"], type="Word", text=word["text"].lower()) @staticmethod def words_to_lower(words: List[Word]) -> List[Word]: diff --git a/sziszapangma/model/model.py b/sziszapangma/model/model.py index 2bcba172644f7706541a457b594ff94f4d96763d..023665078d16adffbeab94ca6f06c593a374c476 100644 --- a/sziszapangma/model/model.py +++ b/sziszapangma/model/model.py @@ -1,4 +1,4 @@ -from typing import List, TypedDict, Literal, Any, Union +from typing import Any, List, Literal, TypedDict, Union AnnotationType = Literal["lemma", "pos", "morph", "concept", "chunk", "turn"] ReferenceType = Literal["Token", "Word", "Document", "SpanAnnotation"] diff --git a/sziszapangma/model/model_creators.py b/sziszapangma/model/model_creators.py index 702e713741ee6588cc954d2af0bbd2727be4ebc1..dbd27f39394e42ea611936c434c25d9d56a3edee 100644 --- a/sziszapangma/model/model_creators.py +++ b/sziszapangma/model/model_creators.py @@ -1,8 +1,13 @@ import uuid -from typing import List +from typing import Any, List -from sziszapangma.model.model import Word, SingleAnnotation, AnnotationType, SpanAnnotation, \ - Document +from sziszapangma.model.model import ( + AnnotationType, + Document, + SingleAnnotation, + SpanAnnotation, + Word, +) def _get_uuid() -> str: @@ -10,21 +15,20 @@ def _get_uuid() -> str: def create_new_word(text: str) -> Word: - return Word(id=_get_uuid(), type='Word', text=text) + return Word(id=_get_uuid(), type="Word", text=text) def create_new_single_annotation( - annotation_type: AnnotationType, - value: any, - reference_id: str + annotation_type: AnnotationType, value: Any, reference_id: str ) -> SingleAnnotation: - return SingleAnnotation(id=_get_uuid(), type=annotation_type, value=value, - reference_id=reference_id) + return SingleAnnotation( + id=_get_uuid(), type=annotation_type, value=value, reference_id=reference_id + ) def create_new_span_annotation(name: str, elements: List[str]) -> SpanAnnotation: - return SpanAnnotation(id=_get_uuid(), type='SpanAnnotation', name=name, elements=elements) + return SpanAnnotation(id=_get_uuid(), type="SpanAnnotation", name=name, elements=elements) def create_new_document(word_ids: List[str]) -> Document: - return Document(id=_get_uuid(), type='Document', word_ids=word_ids) + return Document(id=_get_uuid(), type="Document", word_ids=word_ids) diff --git a/sziszapangma/model/relation_manager.py b/sziszapangma/model/relation_manager.py index d99b20e487614dee1d1be0b28502ce34a6aa5606..74ec58181bc2dc659018a476528a096c5231a490 100644 --- a/sziszapangma/model/relation_manager.py +++ b/sziszapangma/model/relation_manager.py @@ -1,7 +1,7 @@ import json import os -from abc import abstractmethod, ABC -from typing import List, TypedDict, Dict +from abc import ABC, abstractmethod +from typing import Dict, List, TypedDict import pandas as pd @@ -73,15 +73,17 @@ class FileRelationManager(RelationManager): return self.relations_dataframe.query(query).to_dict("records") def get_all_relations_for_item(self, item_id: str) -> List[RelationItem]: - return self.relations_dataframe[ - self.relations_dataframe["first_id"] == item_id - ].to_dict("records") + return self.relations_dataframe[self.relations_dataframe["first_id"] == item_id].to_dict( + "records" + ) def get_item_by_id(self, item_id) -> UUIDable: return self.items_dict[item_id] def save_item(self, item: UUIDable): - self.items_dict[item.get("id")] = item + item_id = item.get("id") + if item_id is not None: + self.items_dict[item_id] = item def save_relation(self, item1: UUIDable, item2: UUIDable): relation_1_2 = { @@ -96,12 +98,8 @@ class FileRelationManager(RelationManager): "second_id": item1["id"], "second_type": item1["type"], } - self.relations_dataframe = self.relations_dataframe.append( - relation_1_2, ignore_index=True - ) - self.relations_dataframe = self.relations_dataframe.append( - relation_2_1, ignore_index=True - ) + self.relations_dataframe = self.relations_dataframe.append(relation_1_2, ignore_index=True) + self.relations_dataframe = self.relations_dataframe.append(relation_2_1, ignore_index=True) def commit(self): self.relations_dataframe.to_csv(self.relations_csv_path, index=False)