Files
speech-to-text/lib/server.py

359 lines
13 KiB
Python

import json
import os
import sys
import tempfile
import shutil
import urllib.request
import urllib.error
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from .config import (
HOST, PORT, BACKEND_URL, API_KEY, DEFAULT_LANG, MAX_MB, MAX_BYTES, MODEL_TYPE,
AVAILABLE_MODELS, WHISPER_MODEL, LLM_URL, LLM_API_KEY, LLM_MODEL
)
from .frontend import INDEX_HTML
from .transcribe import transcribe_file
from .summarize import summarize
def _parse_multipart(body, content_type):
boundary = None
for part in content_type.split(';'):
part = part.strip()
if part.lower().startswith('boundary='):
boundary = part[9:].strip('"\'')
break
if not boundary:
return None, {}
b = ('--' + boundary).encode()
parts = body.split(b)
file_bytes = None
filename = ''
fields = {}
for p in parts[1:-1]:
if p.startswith(b'\r\n'):
p = p[2:]
elif p.startswith(b'\n'):
p = p[1:]
idx = p.find(b'\r\n\r\n')
if idx >= 0:
headers = p[:idx].decode('utf-8', 'replace')
data = p[idx + 4:]
else:
idx = p.find(b'\n\n')
headers = p[:idx].decode('utf-8', 'replace')
data = p[idx + 2:]
if data.endswith(b'\r\n'):
data = data[:-2]
elif data.endswith(b'\n'):
data = data[:-1]
name = None
for line in headers.splitlines():
line = line.strip()
if line.lower().startswith('content-disposition:'):
for item in line.split(';'):
item = item.strip()
if item.lower().startswith('name='):
name = item[5:].strip('"\'')
elif item.lower().startswith('filename='):
filename = item[9:].strip('"\'')
break
if name == 'file':
file_bytes = data
elif name:
try:
fields[name] = data.decode('utf-8')
except UnicodeDecodeError:
fields[name] = data.decode('utf-8', 'replace')
fields['file'] = filename
return file_bytes, fields
class Handler(BaseHTTPRequestHandler):
    """HTTP request handler for the speech-to-text web page and its API.

    GET serves the front-end page and a few small JSON status endpoints.
    POST accepts an uploaded recording, a server-side file path, or text to
    summarize.  Transcription and summarization responses are written as
    newline-delimited JSON events on a connection that is closed afterwards.
    """

    server_version = "speech-to-text/1.0"
    protocol_version = "HTTP/1.1"

    # ------------------------------------------------------------------
    # Response helpers
    # ------------------------------------------------------------------

    def log_message(self, fmt, *args):
        """Write the access-log line to stderr, prefixed with the client address."""
        sys.stderr.write("[%s] %s\n" % (self.address_string(), fmt % args))

    def _send(self, code, body, content_type="text/plain; charset=utf-8", extra=None):
        """Send a complete response with an explicit Content-Length.

        Args:
            code: HTTP status code.
            body: ``str`` (encoded as UTF-8) or ``bytes``.
            content_type: Value for the ``Content-Type`` header.
            extra: Optional mapping of additional headers to send.
        """
        if isinstance(body, str):
            body = body.encode("utf-8")
        self.send_response(code)
        self.send_header("Content-Type", content_type)
        self.send_header("Content-Length", str(len(body)))
        if extra:
            for key, value in extra.items():
                self.send_header(key, value)
        self.end_headers()
        # HEAD gets the same headers as GET but no payload.
        if self.command != "HEAD":
            self.wfile.write(body)

    def _json(self, code, obj, extra=None):
        """Send *obj* serialized as a JSON response with status *code*."""
        self._send(code, json.dumps(obj), "application/json; charset=utf-8", extra=extra)

    def _begin_stream(self):
        """Send the status line and headers that open an NDJSON event stream."""
        self.send_response(200)
        self.send_header("Content-Type", "application/x-ndjson; charset=utf-8")
        self.send_header("Cache-Control", "no-cache")
        # No Content-Length is sent, so the end of the stream is signalled by
        # closing the connection.
        self.send_header("Connection", "close")
        self.end_headers()

    def _stream_event(self, obj):
        """Write one JSON event line to the client and flush it.

        A client that has gone away is ignored so the caller's work is not
        interrupted by the failed write.
        """
        line = json.dumps(obj, ensure_ascii=False) + '\n'
        try:
            self.wfile.write(line.encode('utf-8'))
            self.wfile.flush()
        except (BrokenPipeError, ConnectionResetError):
            pass

    def _stream_transcribe(self, input_path, language, response_format, model_name=None):
        """Transcribe *input_path* and stream progress, segments and the result.

        Emits ``progress`` and ``segment`` events from the callbacks handed to
        ``transcribe_file``, then a final ``done`` event carrying the result,
        or an ``error`` event if transcription raises.
        """
        self._begin_stream()

        def on_progress(chunk, total, elapsed, duration):
            self._stream_event({
                "type": "progress",
                "chunk": chunk,
                "total": total,
                "elapsed": round(elapsed, 1),
                "duration": round(duration, 1),
            })

        def on_segment(seg):
            self._stream_event({
                "type": "segment",
                "start": round(seg.get('start', 0), 2),
                "end": round(seg.get('end', 0), 2),
                "text": seg.get('text', ''),
            })

        try:
            result = transcribe_file(input_path, language, response_format,
                                     model_name=model_name,
                                     progress_callback=on_progress,
                                     segment_callback=on_segment)
            self._stream_event({
                "type": "done",
                "format": response_format,
                "content": result,
            })
            print(f"[response] format={response_format} len={len(result)}", file=sys.stderr)
        except Exception as e:
            # The 200 status is already on the wire, so failures can only be
            # reported in-band as an event.
            print(f"[response] error: {e}", file=sys.stderr)
            self._stream_event({
                "type": "error",
                "error": str(e),
            })

    # ------------------------------------------------------------------
    # Request-body helpers
    # ------------------------------------------------------------------

    def _content_length(self):
        """Return the request's Content-Length, or -1 if it is not an integer."""
        try:
            return int(self.headers.get("Content-Length", "0"))
        except (TypeError, ValueError):
            return -1

    def _read_exactly(self, n):
        """Read exactly *n* bytes from the request, or return ``None`` on EOF."""
        chunks = []
        remaining = n
        while remaining > 0:
            chunk = self.rfile.read(min(remaining, 1 << 16))
            if not chunk:
                return None
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def _drain(self, n):
        """Read and discard up to *n* bytes of request body."""
        remaining = n
        while remaining > 0:
            chunk = self.rfile.read(min(remaining, 1 << 16))
            if not chunk:
                break
            remaining -= len(chunk)

    def _read_json_body(self, max_bytes):
        """Read the request body and decode it as a JSON object.

        Args:
            max_bytes: Largest accepted Content-Length.

        Returns:
            The decoded ``dict``, or ``None`` after a 400 response has
            already been sent.
        """
        length = self._content_length()
        if length <= 0 or length > max_bytes:
            # The body is left unread, so the connection must not be reused:
            # the leftover bytes would be parsed as the next request.
            self._json(400, {"error": "invalid request"}, extra={"Connection": "close"})
            return None
        body = self._read_exactly(length)
        if body is None:
            self._json(400, {"error": "client disconnected"})
            return None
        try:
            data = json.loads(body)
        except ValueError:
            # Covers json.JSONDecodeError and the UnicodeDecodeError raised
            # when the bytes are not valid UTF-8/16/32.
            self._json(400, {"error": "invalid JSON"})
            return None
        if not isinstance(data, dict):
            self._json(400, {"error": "expected a JSON object"})
            return None
        return data

    @staticmethod
    def _safe_extension(filename):
        """Return a filesystem-safe extension taken from *filename*.

        The text after the last dot is used only when it is 1-16 ASCII
        letters/digits; anything else (no dot, empty, path separators, ...)
        yields ``'bin'``.
        """
        ext = filename.rpartition('.')[2] if '.' in filename else ''
        if ext and len(ext) <= 16 and ext.isascii() and ext.isalnum():
            return ext
        return 'bin'

    # ------------------------------------------------------------------
    # Routing
    # ------------------------------------------------------------------

    def do_GET(self):
        """Serve the front-end page and the JSON status endpoints."""
        if self.path in ("/", "/index.html"):
            model_tag = f"Whisper {os.environ.get('STT_WHISPER_MODEL', 'large-v3')}" if MODEL_TYPE == 'whisper' else "Parakeet TDT 0.6B"
            # The page is a template; fill in its placeholders.
            page = (INDEX_HTML
                    .replace("__MAX_MB__", str(MAX_MB))
                    .replace("__MAX_BYTES__", str(MAX_BYTES))
                    .replace("__DEFAULT_LANG__", DEFAULT_LANG)
                    .replace("__MODEL_TAG__", model_tag))
            self._send(200, page, "text/html; charset=utf-8")
        elif self.path == "/api/health":
            self._json(200, {
                "backend_url": BACKEND_URL,
                "backend_ok": self._backend_healthy(),
            })
        elif self.path == "/api/models":
            self._json(200, {
                "models": AVAILABLE_MODELS,
                "default": WHISPER_MODEL,
            })
        elif self.path == "/api/config":
            self._json(200, {
                "llm_url": LLM_URL,
                "llm_model": LLM_MODEL,
            })
        elif self.path == "/health":
            self._json(200, {"status": "ok"})
        else:
            self._send(404, "not found")

    def do_HEAD(self):
        """Answer HEAD like GET; ``_send`` omits the body for HEAD."""
        self.do_GET()

    def do_POST(self):
        """Dispatch POST requests to the transcription/summarization handlers."""
        if self.path == "/api/transcribe":
            self._proxy_transcribe()
        elif self.path == "/api/transcribe/path":
            self._transcribe_path()
        elif self.path == "/api/summarize":
            self._handle_summarize()
        else:
            self._send(404, "not found")

    # ------------------------------------------------------------------
    # Endpoint implementations
    # ------------------------------------------------------------------

    def _backend_healthy(self):
        """Report whether the transcription backend is reachable.

        Only the ``parakeet`` model type is probed (``GET BACKEND_URL/health``
        with a 3 second timeout); every other model type reports ``True``.
        """
        if MODEL_TYPE != 'parakeet':
            return True
        try:
            req = urllib.request.Request(BACKEND_URL + "/health", method="GET")
            with urllib.request.urlopen(req, timeout=3) as resp:
                return resp.status == 200
        except Exception:
            # Any failure to reach the backend simply means "not healthy".
            return False

    def _handle_summarize(self):
        """Stream an LLM summary of the posted ``text`` as NDJSON events.

        The JSON body (at most 5 MiB) must contain ``text`` and may override
        ``url``, ``api_key`` and ``model``; missing or empty overrides fall
        back to the configured values.
        """
        data = self._read_json_body(5 << 20)
        if data is None:
            return
        text = data.get('text', '')
        text = text.strip() if isinstance(text, str) else ''
        if not text:
            self._json(400, {"error": "missing text"})
            return
        # NOTE(review): when the client supplies its own 'url' but no
        # 'api_key', the server's configured LLM_API_KEY is sent to that URL.
        # Confirm this endpoint is only reachable by trusted clients.
        llm_url = data.get('url', LLM_URL) or LLM_URL
        api_key = data.get('api_key', LLM_API_KEY) or LLM_API_KEY
        model = data.get('model', LLM_MODEL) or LLM_MODEL
        if not llm_url:
            self._json(400, {"error": "no LLM endpoint configured. Set OPENAI_COMPATIBLE_ENDPOINT or provide 'url'."})
            return
        self._begin_stream()
        try:
            from .summarize import summarize_stream
            for text_part in summarize_stream(text, llm_url, api_key=api_key, model=model):
                self._stream_event({"type": "token", "text": text_part})
            self._stream_event({"type": "done"})
        except urllib.error.HTTPError as e:
            err_body = ""
            try:
                err_body = e.read().decode('utf-8', 'replace')
            except Exception:
                # The error body is only extra detail; report without it.
                pass
            self._stream_event({"type": "error", "error": f"LLM returned {e.code}: {err_body[:500]}"})
        except Exception as e:
            self._stream_event({"type": "error", "error": str(e)})

    # NOTE: transcribes a file by absolute path on the server's filesystem.
    # Intended for trusted local use only — no path sandboxing.
    def _transcribe_path(self):
        """Transcribe a file that already exists on the server.

        The JSON body (at most 1 MiB) must contain ``path`` and may contain
        ``language``, ``response_format`` and ``model``.
        """
        data = self._read_json_body(1 << 20)
        if data is None:
            return
        filepath = data.get('path', '')
        filepath = filepath.strip() if isinstance(filepath, str) else ''
        language = data.get('language', DEFAULT_LANG)
        response_format = data.get('response_format', 'json')
        model_name = data.get('model') or None
        if not filepath:
            self._json(400, {"error": "missing 'path'"})
            return
        filepath = os.path.abspath(filepath)
        if not os.path.isfile(filepath):
            self._json(400, {"error": f"file not found: {filepath}"})
            return
        self._stream_transcribe(filepath, language, response_format, model_name=model_name)

    def _proxy_transcribe(self):
        """Transcribe a recording uploaded as ``multipart/form-data``.

        The upload is written to a temporary directory, transcribed with the
        result streamed back, and the directory is removed afterwards.
        """
        length = self._content_length()
        if length <= 0:
            # Covers a missing, zero, negative or non-numeric Content-Length;
            # close the connection because any body sent is left unread.
            self._json(400, {"error": "empty request"}, extra={"Connection": "close"})
            return
        # Allow some slack on top of the file limit for the multipart framing.
        if length > MAX_BYTES + (1 << 16):
            self._drain(length)
            self._json(413, {"error": f"upload exceeds {MAX_MB} MB limit"})
            return
        ctype = self.headers.get("Content-Type", "")
        if "multipart/form-data" not in ctype:
            self._drain(length)
            self._json(400, {"error": "expected multipart/form-data"})
            return
        body = self._read_exactly(length)
        if body is None:
            self._json(400, {"error": "client disconnected during upload"})
            return
        file_bytes, fields = _parse_multipart(body, ctype)
        if file_bytes is None:
            self._json(400, {"error": "no file in request"})
            return
        language = fields.get('language', DEFAULT_LANG)
        response_format = fields.get('response_format', 'json')
        model_name = fields.get('model') or None

        tmpdir = tempfile.mkdtemp(prefix='stt_upload_')
        try:
            # The filename is client-controlled; only a sanitized extension
            # from it is allowed into the path.
            ext = self._safe_extension(fields.get('file', ''))
            input_path = os.path.join(tmpdir, f"input.{ext}")
            print(f"writing recording to temp file : {input_path}", file=sys.stderr)
            try:
                with open(input_path, 'wb') as f:
                    f.write(file_bytes)
            except OSError as e:
                # Nothing has been sent yet, so answer with a real status.
                print(f"[proxy] error writing upload: {e}", file=sys.stderr)
                self._json(500, {"error": str(e)})
                return
            try:
                self._stream_transcribe(input_path, language, response_format, model_name=model_name)
            except Exception as e:
                # The stream may already be open; try to report in-band and
                # otherwise give up quietly (the socket is probably gone).
                print(f"[proxy] error streaming response: {e}", file=sys.stderr)
                try:
                    if not self.wfile.closed:
                        self._stream_event({"type": "error", "error": str(e)})
                except Exception:
                    pass
        finally:
            shutil.rmtree(tmpdir, ignore_errors=True)
def run_server():
    """Run the threaded HTTP server on ``HOST:PORT`` until Ctrl-C.

    Prints a short startup banner to stderr, serves requests until a
    ``KeyboardInterrupt`` arrives, and closes the listening socket on the
    way out.
    """
    httpd = ThreadingHTTPServer((HOST, PORT), Handler)
    model_tag = "Whisper" if MODEL_TYPE == 'whisper' else "Parakeet"
    print(f"speech-to-text listening on http://{HOST}:{PORT}", file=sys.stderr)
    print(f"  -> model: {model_tag}", file=sys.stderr)
    print(f"  -> default language: {DEFAULT_LANG}, max upload: {MAX_MB} MB", file=sys.stderr)
    try:
        httpd.serve_forever()
    except KeyboardInterrupt:
        print("\nshutting down", file=sys.stderr)
    finally:
        # serve_forever() has already returned by the time we get here, so
        # there is no loop left for shutdown() to stop; what still needs
        # doing is releasing the listening socket.
        httpd.server_close()