Skip to content
Snippets Groups Projects
Unverified Commit 6736343c authored by Marcin Wątroba's avatar Marcin Wątroba
Browse files

Fix repo checks

parent e800454d
Branches
1 merge request!13Change data model
Showing
with 82 additions and 60 deletions
......@@ -10,7 +10,7 @@ from sziszapangma.model.model import Word
class AlignmentUtil:
@staticmethod
def _optional_str_to_str(word: Optional[Word]) -> str:
return word['text'] if word is not None else ""
return word["text"] if word is not None else ""
@staticmethod
def _wer_step_to_pandas_row_lit(step: AlignmentStep) -> Tuple[str, str, str, float]:
......
......@@ -22,7 +22,7 @@ class DistanceCalculator(ABC):
class BinaryDistanceCalculator(DistanceCalculator):
def calculate_distance_for_words(self, word1: Word, word2: Word) -> float:
return 0 if word1['text'] == word2['text'] else 1
return 0 if word1["text"] == word2["text"] else 1
def calculate_distance_matrix(
self, reference: List[Word], hypothesis: List[Word]
......@@ -46,8 +46,8 @@ class CosineDistanceCalculator(DistanceCalculator):
def calculate_distance_for_words(self, word1: Word, word2: Word) -> float:
return self.cosine_distance_between_words_embeddings(
self._embedding_transformer.get_embedding(word1['text']),
self._embedding_transformer.get_embedding(word2['text']),
self._embedding_transformer.get_embedding(word1["text"]),
self._embedding_transformer.get_embedding(word2["text"]),
)
@staticmethod
......@@ -74,14 +74,14 @@ class CosineDistanceCalculator(DistanceCalculator):
self, reference: List[Word], hypothesis: List[Word]
) -> npt.NDArray[np.float64]:
embeddings_dict = self._embedding_transformer.get_embeddings(
list(set(it['text'] for it in (reference + hypothesis)))
list(set(it["text"] for it in (reference + hypothesis)))
)
return np.array(
[
[
self.cosine_distance_between_words_embeddings(
embeddings_dict[reference_word['text']],
embeddings_dict[hypothesis_word['text']],
embeddings_dict[reference_word["text"]],
embeddings_dict[hypothesis_word["text"]],
)
for hypothesis_word in hypothesis
]
......
......@@ -26,5 +26,8 @@ class ExperimentManager:
def process(self):
self._experiment_repository.initialise()
for processing_task in self._processing_tasks:
processing_task.process(self._record_id_iterator, self._experiment_repository,
self._relation_manager_provider)
processing_task.process(
self._record_id_iterator,
self._experiment_repository,
self._relation_manager_provider,
)
......@@ -15,8 +15,10 @@ class StepWordsMapper:
@staticmethod
def from_json_dict(input_json_dict: Dict[str, Any]) -> StepWords:
reference_word = None if "reference_word" not in input_json_dict \
else input_json_dict["reference_word"]
hypothesis_word = None if "hypothesis_word" not in input_json_dict \
else input_json_dict["hypothesis_word"]
reference_word = (
None if "reference_word" not in input_json_dict else input_json_dict["reference_word"]
)
hypothesis_word = (
None if "hypothesis_word" not in input_json_dict else input_json_dict["hypothesis_word"]
)
return StepWords(reference_word, hypothesis_word)
......@@ -4,7 +4,6 @@ from sziszapangma.model.relation_manager import RelationManager
class RelationManagerProvider(ABC):
@abstractmethod
def get_relation_manager(self, record_id: str) -> RelationManager:
pass
......@@ -30,14 +30,14 @@ class AsrTask(ProcessingTask):
return asr_value is not None and "transcription" in asr_value
def run_single_process(
self, record_id: str, experiment_repository: ExperimentRepository,
relation_manager: RelationManager
self,
record_id: str,
experiment_repository: ExperimentRepository,
relation_manager: RelationManager,
) -> None:
file_record_path = self._record_path_provider.get_path(record_id)
asr_result = self._asr_processor.call_recognise(file_record_path)
asr_result["transcription"] = [
create_new_word(it) for it in asr_result["transcription"]
]
asr_result["transcription"] = [create_new_word(it) for it in asr_result["transcription"]]
experiment_repository.update_property_for_key(
record_id, self._asr_property_name, asr_result
)
......@@ -43,8 +43,12 @@ class ClassicWerMetricTask(ProcessingTask):
is not None
)
def run_single_process(self, record_id: str, experiment_repository: ExperimentRepository,
relation_manager: RelationManager):
def run_single_process(
self,
record_id: str,
experiment_repository: ExperimentRepository,
relation_manager: RelationManager,
):
gold_transcript = TaskUtil.get_words_from_record(relation_manager)
asr_result = experiment_repository.get_property_for_key(record_id, self._asr_property_name)
if gold_transcript is not None and asr_result is not None and "transcription" in asr_result:
......@@ -59,9 +63,7 @@ class ClassicWerMetricTask(ProcessingTask):
)
def _get_alignment(
self,
gold_transcript: List[Word],
asr_transcript: List[Word]
self, gold_transcript: List[Word], asr_transcript: List[Word]
) -> List[AlignmentStep]:
return self._alignment_classic_calculator.calculate_alignment(
reference=gold_transcript, hypothesis=asr_transcript
......
......@@ -51,8 +51,12 @@ class EmbeddingWerMetricsTask(ProcessingTask):
is not None
)
def run_single_process(self, record_id: str, experiment_repository: ExperimentRepository,
relation_manager: RelationManager):
def run_single_process(
self,
record_id: str,
experiment_repository: ExperimentRepository,
relation_manager: RelationManager,
):
gold_transcript = TaskUtil.get_words_from_record(relation_manager)
asr_result = experiment_repository.get_property_for_key(record_id, self._asr_property_name)
if gold_transcript is not None and asr_result is not None and "transcription" in asr_result:
......
......@@ -16,8 +16,12 @@ class ProcessingTask(ABC):
self._task_name = task_name
@abstractmethod
def run_single_process(self, record_id: str, experiment_repository: ExperimentRepository,
relation_manager: RelationManager):
def run_single_process(
self,
record_id: str,
experiment_repository: ExperimentRepository,
relation_manager: RelationManager,
):
pass
@abstractmethod
......@@ -25,8 +29,10 @@ class ProcessingTask(ABC):
pass
def process(
self, record_id_iterator: RecordIdIterator, experiment_repository: ExperimentRepository,
relation_manager_provider: RelationManagerProvider
self,
record_id_iterator: RecordIdIterator,
experiment_repository: ExperimentRepository,
relation_manager_provider: RelationManagerProvider,
):
records_ids = list(record_id_iterator.get_all_records())
for record_index in range(len(records_ids)):
......
from typing import List
from typing import List, cast
from sziszapangma.model.model import Word
from sziszapangma.model.model import Document, Word
from sziszapangma.model.relation_manager import RelationManager
class TaskUtil:
@staticmethod
def get_words_from_record(relation_manager: RelationManager) -> List[Word]:
document = [itt for itt in relation_manager.get_all_items() if itt['type'] == 'Document'][0]
return [relation_manager.get_item_by_id(item_id) for item_id in document['word_ids']]
document = cast(
Document,
[itt for itt in relation_manager.get_all_items() if itt["type"] == "Document"][0],
)
return [
cast(Word, relation_manager.get_item_by_id(item_id)) for item_id in document["word_ids"]
]
@staticmethod
def _word_to_lower(word: Word) -> Word:
return Word(id=word['id'], type='Word', text=word['text'].lower())
return Word(id=word["id"], type="Word", text=word["text"].lower())
@staticmethod
def words_to_lower(words: List[Word]) -> List[Word]:
......
from typing import List, TypedDict, Literal, Any, Union
from typing import Any, List, Literal, TypedDict, Union
AnnotationType = Literal["lemma", "pos", "morph", "concept", "chunk", "turn"]
ReferenceType = Literal["Token", "Word", "Document", "SpanAnnotation"]
......
import uuid
from typing import List
from typing import Any, List
from sziszapangma.model.model import Word, SingleAnnotation, AnnotationType, SpanAnnotation, \
Document
from sziszapangma.model.model import (
AnnotationType,
Document,
SingleAnnotation,
SpanAnnotation,
Word,
)
def _get_uuid() -> str:
......@@ -10,21 +15,20 @@ def _get_uuid() -> str:
def create_new_word(text: str) -> Word:
return Word(id=_get_uuid(), type='Word', text=text)
return Word(id=_get_uuid(), type="Word", text=text)
def create_new_single_annotation(
annotation_type: AnnotationType,
value: any,
reference_id: str
annotation_type: AnnotationType, value: Any, reference_id: str
) -> SingleAnnotation:
return SingleAnnotation(id=_get_uuid(), type=annotation_type, value=value,
reference_id=reference_id)
return SingleAnnotation(
id=_get_uuid(), type=annotation_type, value=value, reference_id=reference_id
)
def create_new_span_annotation(name: str, elements: List[str]) -> SpanAnnotation:
return SpanAnnotation(id=_get_uuid(), type='SpanAnnotation', name=name, elements=elements)
return SpanAnnotation(id=_get_uuid(), type="SpanAnnotation", name=name, elements=elements)
def create_new_document(word_ids: List[str]) -> Document:
return Document(id=_get_uuid(), type='Document', word_ids=word_ids)
return Document(id=_get_uuid(), type="Document", word_ids=word_ids)
import json
import os
from abc import abstractmethod, ABC
from typing import List, TypedDict, Dict
from abc import ABC, abstractmethod
from typing import Dict, List, TypedDict
import pandas as pd
......@@ -73,15 +73,17 @@ class FileRelationManager(RelationManager):
return self.relations_dataframe.query(query).to_dict("records")
def get_all_relations_for_item(self, item_id: str) -> List[RelationItem]:
return self.relations_dataframe[
self.relations_dataframe["first_id"] == item_id
].to_dict("records")
return self.relations_dataframe[self.relations_dataframe["first_id"] == item_id].to_dict(
"records"
)
def get_item_by_id(self, item_id) -> UUIDable:
return self.items_dict[item_id]
def save_item(self, item: UUIDable):
self.items_dict[item.get("id")] = item
item_id = item.get("id")
if item_id is not None:
self.items_dict[item_id] = item
def save_relation(self, item1: UUIDable, item2: UUIDable):
relation_1_2 = {
......@@ -96,12 +98,8 @@ class FileRelationManager(RelationManager):
"second_id": item1["id"],
"second_type": item1["type"],
}
self.relations_dataframe = self.relations_dataframe.append(
relation_1_2, ignore_index=True
)
self.relations_dataframe = self.relations_dataframe.append(
relation_2_1, ignore_index=True
)
self.relations_dataframe = self.relations_dataframe.append(relation_1_2, ignore_index=True)
self.relations_dataframe = self.relations_dataframe.append(relation_2_1, ignore_index=True)
def commit(self):
self.relations_dataframe.to_csv(self.relations_csv_path, index=False)
......
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment