Skip to content
Snippets Groups Projects
Unverified Commit 3829ad0b authored by Marcin Wątroba's avatar Marcin Wątroba
Browse files

Add web embeddings client

parent 27776f8e
No related branches found
No related tags found
1 merge request!10Feature/add auth asr service
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="MypyConfigService">
<option name="customMypyPath" value="$USER_HOME$/miniconda3/envs/asr-benchmarks/bin/mypy" />
</component>
</project>
\ No newline at end of file
import json
from typing import Dict, List, Optional
import numpy as np
import requests
from sziszapangma.core.transformer.embedding_transformer import EmbeddingTransformer
class WebEmbeddingTransformer(EmbeddingTransformer):
_lang_id: str
_host: str
_auth_token: Optional[str]
def __init__(self, lang_id: str, host: str, auth_token: Optional[str]):
self._lang_id = lang_id
self._host = host
self._auth_token = auth_token
def get_embedding(self, word: str) -> np.ndarray:
return self._make_request([word])[word]
def get_embeddings(self, words: List[str]) -> Dict[str, np.ndarray]:
return self._make_request(words)
def _make_request(self, words: List[str]) -> Dict[str, np.ndarray]:
print(f"call for words: {json.dumps(words)}")
headers = {"Content-Type": "application/json"}
if self._auth_token is not None:
headers["Authorization"] = f"Bearer {self._auth_token}"
result = requests.post(
f"{self._host}/embeddings/{self._lang_id}", headers=headers, json=words
)
if not result.ok:
print(f"response_code {result.status_code}")
result_text = result.text
parsed_result = json.loads(result_text)
return {it: np.array(parsed_result[it]) for it in parsed_result.keys()}
import os
import traceback
import uuid
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional
from flask import Flask, Response, jsonify, request
from flask_httpauth import HTTPTokenAuth
from sziszapangma.integration.service_core.asr.asr_result import AsrResult
_TEMP_DIRECTORY = "asr_processing"
_AUTH_TOKEN = "AUTH_TOKEN"
_SERVICE_PORT = "SERVICE_PORT"
class AsrBaseProcessor(ABC):
user_token: str
def __init__(self):
self.user_token = os.environ[_AUTH_TOKEN]
@abstractmethod
def process_asr(self, audio_file_path: str) -> AsrResult:
"""Method to call for ASR results."""
def process_request(self) -> Response:
file_tag = str(uuid.uuid4())
f = request.files["file"]
if f is not None and f.filename is not None:
file_extension = f.filename.split(".")[-1]
file_name = f"{file_tag}.{file_extension}"
file_path = f"{_TEMP_DIRECTORY}/{file_name}"
f.save(file_path)
try:
transcription = self.process_asr(file_path)
os.remove(file_path)
result_object = jsonify(
{"transcription": transcription.words, "full_text": transcription.full_text}
)
except Exception as exception:
print(exception)
traceback.print_exc()
result_object = jsonify({"error": "Error on asr processing"})
else:
result_object = jsonify({"error": "Error on asr processing"})
return result_object
def is_token_correct(self, token: str) -> Optional[str]:
if token == self.user_token:
return "asr_client"
else:
return None
def health_check(self) -> Response:
return jsonify({"status": "running"})
def start_processor(self):
app = Flask(__name__)
auth = HTTPTokenAuth(scheme="Bearer")
auth.verify_token(self.is_token_correct)
Path(_TEMP_DIRECTORY).mkdir(parents=True, exist_ok=True)
app.route("/process_asr", methods=["POST"])(auth.login_required(self.process_request))
app.route("/health_check", methods=["GET"])(self.health_check)
port = int(os.environ[_SERVICE_PORT]) if _SERVICE_PORT in os.environ else 5000
app.run(debug=True, host="0.0.0.0", port=port)
from dataclasses import dataclass
from typing import List
@dataclass(frozen=True)
class AsrResult:
words: List[str]
full_text: str
import os
from abc import ABC, abstractmethod
from typing import List, Optional
import numpy as np
import numpy.typing as npt
from flask import Flask, Response, jsonify, request
from flask_httpauth import HTTPTokenAuth
_TEMP_DIRECTORY = "asr_processing"
_AUTH_TOKEN = "AUTH_TOKEN"
_SERVICE_PORT = "SERVICE_PORT"
class EmbeddingBaseProcessor(ABC):
user_token: str
def __init__(self):
self.user_token = os.environ[_AUTH_TOKEN]
@abstractmethod
def get_embedding(self, phrase: str, language: str) -> npt.NDArray[np.float64]:
"""Method to call for embedding of phrase."""
@staticmethod
def numpy_to_list(arr) -> List[float]:
return [float(it) for it in arr]
def process_embeddings(self, language: str):
words = request.json
if words is not List[str]:
raise Exception("Incorrect body")
return jsonify(
{word: self.numpy_to_list(self.get_embedding(word, language)) for word in words}
)
def is_token_correct(self, token: str) -> Optional[str]:
if token == self.user_token:
return "asr_client"
else:
return None
def health_check(self) -> Response:
return jsonify({"status": "running"})
def start_processor(self):
app = Flask(__name__)
auth = HTTPTokenAuth(scheme="Bearer")
auth.verify_token(self.is_token_correct)
app.route("/process_embedding", methods=["POST"])(
auth.login_required(self.process_embeddings)
)
app.route("/health_check", methods=["GET"])(self.health_check)
port = int(os.environ[_SERVICE_PORT]) if _SERVICE_PORT in os.environ else 5000
app.run(debug=True, host="0.0.0.0", port=port)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment