"""Run word-level WER metric tasks (classic and embedding-based) for one ASR system on one dataset.

Usage::

    python <this_script> --dataset <dataset_name> --asr <asr_name>
"""
import argparse
import os
from typing import Optional

from experiment.const_pipeline_names import GOLD_TRANSCRIPT
from experiment.experiment_dependency_provider import get_record_provider, get_repository
from sziszapangma.core.transformer.web_embedding_transformer import WebEmbeddingTransformer
from sziszapangma.integration.experiment_manager import ExperimentManager
from sziszapangma.integration.task.classic_wer_metric_task import ClassicWerMetricTask
from sziszapangma.integration.task.embedding_wer_metrics_task import EmbeddingWerMetricsTask

# Defaults reproduce the values that were previously hard-coded in the call to
# WebEmbeddingTransformer. Each one can be overridden by the environment variable
# named next to it, or by the matching keyword argument of run_word_wer_pipeline.
_DEFAULT_EMBEDDING_LANGUAGE = 'pl'  # env: WEB_EMBEDDING_LANGUAGE
_DEFAULT_EMBEDDING_SERVICE_URL = 'http://localhost:5003'  # env: WEB_EMBEDDING_SERVICE_URL
# NOTE(review): this looks like an access key for the embedding service. It should not
# live in source control; prefer supplying it via WEB_EMBEDDING_SERVICE_KEY. It is kept
# here only so existing invocations behave exactly as before.
_DEFAULT_EMBEDDING_SERVICE_KEY = 'fjsd-mkwe-oius-m9h2'  # env: WEB_EMBEDDING_SERVICE_KEY


def run_word_wer_pipeline(
        dataset_name: str,
        asr_name: str,
        *,
        embedding_language: Optional[str] = None,
        embedding_service_url: Optional[str] = None,
        embedding_service_key: Optional[str] = None,
) -> None:
    """Compute classic and embedding-based word WER for ``asr_name`` on ``dataset_name``.

    Builds an ``ExperimentManager`` over the dataset's record provider with two tasks:

    * ``ClassicWerMetricTask``, configured with output property names
      ``{asr_name}__word_wer_classic_metrics`` / ``{asr_name}__word_wer_classic_alignment``;
    * ``EmbeddingWerMetricsTask``, configured with output property names
      ``{asr_name}__word_wer_embeddings_metrics`` / ``{asr_name}__word_wer_embeddings_alignment``.

    Both tasks compare the property ``{asr_name}__result`` against ``GOLD_TRANSCRIPT``.
    It then calls ``process()`` on the manager.

    Args:
        dataset_name: Dataset identifier passed to ``get_record_provider`` and ``get_repository``.
        asr_name: ASR system identifier, used as the prefix of every property name.
        embedding_language: First argument of ``WebEmbeddingTransformer``. Falls back to
            ``$WEB_EMBEDDING_LANGUAGE``, then ``'pl'``.
        embedding_service_url: Embedding service URL. Falls back to
            ``$WEB_EMBEDDING_SERVICE_URL``, then ``'http://localhost:5003'``.
        embedding_service_key: Third argument of ``WebEmbeddingTransformer``; presumably an
            access key for the service (TODO confirm). Falls back to
            ``$WEB_EMBEDDING_SERVICE_KEY``, then the previous hard-coded value.
    """
    language = embedding_language or os.environ.get(
        'WEB_EMBEDDING_LANGUAGE', _DEFAULT_EMBEDDING_LANGUAGE)
    service_url = embedding_service_url or os.environ.get(
        'WEB_EMBEDDING_SERVICE_URL', _DEFAULT_EMBEDDING_SERVICE_URL)
    service_key = embedding_service_key or os.environ.get(
        'WEB_EMBEDDING_SERVICE_KEY', _DEFAULT_EMBEDDING_SERVICE_KEY)

    # The same object serves as both the record-id iterator and the relation manager provider.
    record_provider = get_record_provider(dataset_name)
    asr_result_property = f'{asr_name}__result'

    experiment_processor = ExperimentManager(
        record_id_iterator=record_provider,
        processing_tasks=[
            ClassicWerMetricTask(
                task_name=f'ClassicWerMetricTask___{dataset_name}___{asr_name}',
                asr_property_name=asr_result_property,
                gold_transcript_property_name=GOLD_TRANSCRIPT,
                metrics_property_name=f'{asr_name}__word_wer_classic_metrics',
                require_update=False,
                alignment_property_name=f'{asr_name}__word_wer_classic_alignment',
            ),
            EmbeddingWerMetricsTask(
                # Qualified by dataset and ASR like the classic task above, so task names
                # from different pipeline runs are distinguishable.
                task_name=f'EmbeddingWerMetricsTask___{dataset_name}___{asr_name}',
                asr_property_name=asr_result_property,
                gold_transcript_property_name=GOLD_TRANSCRIPT,
                metrics_property_name=f'{asr_name}__word_wer_embeddings_metrics',
                require_update=False,
                embedding_transformer=WebEmbeddingTransformer(language, service_url, service_key),
                alignment_property_name=f'{asr_name}__word_wer_embeddings_alignment',
            ),
        ],
        experiment_repository=get_repository(dataset_name),
        relation_manager_provider=record_provider,
    )
    experiment_processor.process()


def _parse_args() -> argparse.Namespace:
    """Parse the command line; both ``--dataset`` and ``--asr`` are mandatory."""
    parser = argparse.ArgumentParser(
        description='Compute classic and embedding-based word WER metrics.')
    # required=True: without it a missing flag becomes None and silently produces
    # property names such as 'None__result'.
    parser.add_argument('--dataset', required=True, help='dataset name')
    parser.add_argument('--asr', required=True, help='ASR system name')
    return parser.parse_args()


if __name__ == '__main__':
    args = _parse_args()
    run_word_wer_pipeline(args.dataset, args.asr)