359 lines
13 KiB
Python
359 lines
13 KiB
Python
import json
|
|
import os
|
|
import sys
|
|
import tempfile
|
|
import shutil
|
|
import urllib.request
|
|
import urllib.error
|
|
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
|
|
|
from .config import (
|
|
HOST, PORT, BACKEND_URL, API_KEY, DEFAULT_LANG, MAX_MB, MAX_BYTES, MODEL_TYPE,
|
|
AVAILABLE_MODELS, WHISPER_MODEL, LLM_URL, LLM_API_KEY, LLM_MODEL
|
|
)
|
|
from .frontend import INDEX_HTML
|
|
from .transcribe import transcribe_file
|
|
from .summarize import summarize
|
|
|
|
|
|
def _parse_multipart(body, content_type):
|
|
boundary = None
|
|
for part in content_type.split(';'):
|
|
part = part.strip()
|
|
if part.lower().startswith('boundary='):
|
|
boundary = part[9:].strip('"\'')
|
|
break
|
|
if not boundary:
|
|
return None, {}
|
|
b = ('--' + boundary).encode()
|
|
parts = body.split(b)
|
|
file_bytes = None
|
|
filename = ''
|
|
fields = {}
|
|
for p in parts[1:-1]:
|
|
if p.startswith(b'\r\n'):
|
|
p = p[2:]
|
|
elif p.startswith(b'\n'):
|
|
p = p[1:]
|
|
idx = p.find(b'\r\n\r\n')
|
|
if idx >= 0:
|
|
headers = p[:idx].decode('utf-8', 'replace')
|
|
data = p[idx + 4:]
|
|
else:
|
|
idx = p.find(b'\n\n')
|
|
headers = p[:idx].decode('utf-8', 'replace')
|
|
data = p[idx + 2:]
|
|
if data.endswith(b'\r\n'):
|
|
data = data[:-2]
|
|
elif data.endswith(b'\n'):
|
|
data = data[:-1]
|
|
name = None
|
|
for line in headers.splitlines():
|
|
line = line.strip()
|
|
if line.lower().startswith('content-disposition:'):
|
|
for item in line.split(';'):
|
|
item = item.strip()
|
|
if item.lower().startswith('name='):
|
|
name = item[5:].strip('"\'')
|
|
elif item.lower().startswith('filename='):
|
|
filename = item[9:].strip('"\'')
|
|
break
|
|
if name == 'file':
|
|
file_bytes = data
|
|
elif name:
|
|
try:
|
|
fields[name] = data.decode('utf-8')
|
|
except UnicodeDecodeError:
|
|
fields[name] = data.decode('utf-8', 'replace')
|
|
fields['file'] = filename
|
|
return file_bytes, fields
|
|
|
|
|
|
class Handler(BaseHTTPRequestHandler):
    """Request handler for the speech-to-text web front end.

    GET serves the HTML page and small JSON endpoints. POST accepts
    transcription and summarisation requests; those endpoints answer with a
    stream of newline-delimited JSON events.
    """

    server_version = "speech-to-text/1.0"
    protocol_version = "HTTP/1.1"

    # Largest JSON request bodies accepted by the POST endpoints.
    _SUMMARIZE_MAX_BODY = 5 << 20
    _PATH_MAX_BODY = 1 << 20

    def log_message(self, fmt, *args):
        """Write one access-log line to stderr, prefixed by the client address."""
        sys.stderr.write("[%s] %s\n" % (self.address_string(), fmt % args))

    def _send(self, code, body, content_type="text/plain; charset=utf-8", extra=None):
        """Send a complete response with an explicit Content-Length.

        Args:
            code: HTTP status code.
            body: Response payload; ``str`` is encoded as UTF-8.
            content_type: Value for the Content-Type header.
            extra: Optional mapping of additional headers to send.
        """
        if isinstance(body, str):
            body = body.encode("utf-8")
        self.send_response(code)
        self.send_header("Content-Type", content_type)
        self.send_header("Content-Length", str(len(body)))
        if extra:
            for k, v in extra.items():
                self.send_header(k, v)
        self.end_headers()
        # HEAD responses carry the headers only.
        if self.command != "HEAD":
            self.wfile.write(body)

    def _json(self, code, obj, extra=None):
        """Send ``obj`` serialised as a JSON response with status ``code``."""
        self._send(code, json.dumps(obj), "application/json; charset=utf-8", extra)

    def _reject(self, code, message):
        """Send a JSON error for a request whose body has not been read.

        The unread body would otherwise be parsed as the next request on this
        HTTP/1.1 connection, so the connection is marked to be closed.
        """
        self._json(code, {"error": message}, extra={"Connection": "close"})

    def _stream_event(self, obj):
        """Write one NDJSON event and flush; ignore a vanished client."""
        line = json.dumps(obj, ensure_ascii=False) + '\n'
        try:
            self.wfile.write(line.encode('utf-8'))
            self.wfile.flush()
        except (BrokenPipeError, ConnectionResetError):
            # Best effort: the client went away, nothing left to notify.
            pass

    def _begin_stream(self):
        """Send the 200 status and headers that open an NDJSON event stream."""
        self.send_response(200)
        self.send_header("Content-Type", "application/x-ndjson; charset=utf-8")
        self.send_header("Cache-Control", "no-cache")
        # No Content-Length is sent, so the end of the stream is signalled by
        # closing the connection.
        self.send_header("Connection", "close")
        self.end_headers()

    def _stream_transcribe(self, input_path, language, response_format, model_name=None):
        """Run ``transcribe_file`` on ``input_path`` and stream its events.

        Emits ``progress`` and ``segment`` events from the callbacks, then a
        final ``done`` event carrying the result, or an ``error`` event if
        transcription raises.
        """
        self._begin_stream()

        def on_progress(chunk, total, elapsed, duration):
            self._stream_event({
                "type": "progress",
                "chunk": chunk,
                "total": total,
                "elapsed": round(elapsed, 1),
                "duration": round(duration, 1),
            })

        def on_segment(seg):
            self._stream_event({
                "type": "segment",
                "start": round(seg.get('start', 0), 2),
                "end": round(seg.get('end', 0), 2),
                "text": seg.get('text', ''),
            })

        try:
            result = transcribe_file(input_path, language, response_format,
                                     model_name=model_name,
                                     progress_callback=on_progress,
                                     segment_callback=on_segment)
            self._stream_event({
                "type": "done",
                "format": response_format,
                "content": result,
            })
            print(f"[response] format={response_format} len={len(result)}", file=sys.stderr)
        except Exception as e:
            # Headers are already sent, so failures are reported in-stream.
            print(f"[response] error: {e}", file=sys.stderr)
            self._stream_event({
                "type": "error",
                "error": str(e),
            })

    def do_GET(self):
        """Serve the HTML page and the read-only JSON endpoints."""
        if self.path in ("/", "/index.html"):
            # NOTE(review): reads STT_WHISPER_MODEL from the environment rather
            # than the imported WHISPER_MODEL -- confirm the two agree.
            model_tag = f"Whisper {os.environ.get('STT_WHISPER_MODEL', 'large-v3')}" if MODEL_TYPE == 'whisper' else "Parakeet TDT 0.6B"
            page = (INDEX_HTML
                    .replace("__MAX_MB__", str(MAX_MB))
                    .replace("__MAX_BYTES__", str(MAX_BYTES))
                    .replace("__DEFAULT_LANG__", DEFAULT_LANG)
                    .replace("__MODEL_TAG__", model_tag))
            self._send(200, page, "text/html; charset=utf-8")
        elif self.path == "/api/health":
            self._json(200, {
                "backend_url": BACKEND_URL,
                "backend_ok": self._backend_healthy(),
            })
        elif self.path == "/api/models":
            self._json(200, {
                "models": AVAILABLE_MODELS,
                "default": WHISPER_MODEL,
            })
        elif self.path == "/api/config":
            self._json(200, {
                "llm_url": LLM_URL,
                "llm_model": LLM_MODEL,
            })
        elif self.path == "/health":
            self._json(200, {"status": "ok"})
        else:
            self._send(404, "not found")

    def do_HEAD(self):
        """Answer HEAD like GET; ``_send`` omits the body for HEAD."""
        self.do_GET()

    def do_POST(self):
        """Dispatch POST requests to the matching endpoint handler."""
        if self.path == "/api/transcribe":
            self._proxy_transcribe()
        elif self.path == "/api/transcribe/path":
            self._transcribe_path()
        elif self.path == "/api/summarize":
            self._handle_summarize()
        else:
            self._send(404, "not found")

    def _backend_healthy(self):
        """Return True unless MODEL_TYPE is 'parakeet' and its /health check fails."""
        if MODEL_TYPE != 'parakeet':
            return True
        try:
            req = urllib.request.Request(BACKEND_URL + "/health", method="GET")
            with urllib.request.urlopen(req, timeout=3) as resp:
                return resp.status == 200
        except Exception:
            # Any failure to reach the backend counts as unhealthy.
            return False

    def _content_length(self):
        """Return the Content-Length header as an int.

        Returns 0 when the header is absent and ``None`` when it is not a
        valid integer.
        """
        try:
            return int(self.headers.get("Content-Length", "0"))
        except (TypeError, ValueError):
            return None

    def _read_json_body(self, max_bytes):
        """Read and decode a JSON-object request body of at most ``max_bytes``.

        Returns the decoded dict, or ``None`` after a 400 response has already
        been sent (bad length, client disconnect, invalid JSON, or a JSON
        value that is not an object).
        """
        length = self._content_length()
        if length is None or length <= 0 or length > max_bytes:
            self._reject(400, "invalid request")
            return None
        body = self._read_exactly(length)
        if body is None:
            self._json(400, {"error": "client disconnected"})
            return None
        try:
            data = json.loads(body)
        except ValueError:
            # Covers JSONDecodeError and the UnicodeDecodeError raised for
            # bodies that are not valid UTF-8.
            self._json(400, {"error": "invalid JSON"})
            return None
        if not isinstance(data, dict):
            self._json(400, {"error": "JSON body must be an object"})
            return None
        return data

    def _handle_summarize(self):
        """Handle POST /api/summarize: stream summary tokens as NDJSON.

        The JSON body must contain ``text``; ``url``, ``api_key`` and
        ``model`` optionally override the configured LLM settings.
        """
        data = self._read_json_body(self._SUMMARIZE_MAX_BODY)
        if data is None:
            return

        text = data.get('text', '')
        text = text.strip() if isinstance(text, str) else ''
        if not text:
            self._json(400, {"error": "missing text"})
            return

        # Empty or missing overrides fall back to the configured values.
        llm_url = data.get('url') or LLM_URL
        api_key = data.get('api_key') or LLM_API_KEY
        model = data.get('model') or LLM_MODEL

        if not llm_url:
            self._json(400, {"error": "no LLM endpoint configured. Set OPENAI_COMPATIBLE_ENDPOINT or provide 'url'."})
            return

        self._begin_stream()

        try:
            from .summarize import summarize_stream
            for text_part in summarize_stream(text, llm_url, api_key=api_key, model=model):
                self._stream_event({"type": "token", "text": text_part})
            self._stream_event({"type": "done"})
        except urllib.error.HTTPError as e:
            err_body = ""
            try:
                err_body = e.read().decode('utf-8', 'replace')
            except Exception:
                # The error body is only extra detail; report without it.
                pass
            self._stream_event({"type": "error", "error": f"LLM returned {e.code}: {err_body[:500]}"})
        except Exception as e:
            self._stream_event({"type": "error", "error": str(e)})

    # NOTE: transcribes a file by absolute path on the server's filesystem.
    # Intended for trusted local use only — no path sandboxing.
    def _transcribe_path(self):
        """Handle POST /api/transcribe/path: transcribe a server-side file.

        The JSON body must contain ``path``; ``language``, ``response_format``
        and ``model`` are optional.
        """
        data = self._read_json_body(self._PATH_MAX_BODY)
        if data is None:
            return

        filepath = data.get('path', '')
        filepath = filepath.strip() if isinstance(filepath, str) else ''
        language = data.get('language', DEFAULT_LANG)
        response_format = data.get('response_format', 'json')
        model_name = data.get('model') or None

        if not filepath:
            self._json(400, {"error": "missing 'path'"})
            return
        filepath = os.path.abspath(filepath)
        if not os.path.isfile(filepath):
            self._json(400, {"error": f"file not found: {filepath}"})
            return

        self._stream_transcribe(filepath, language, response_format, model_name=model_name)

    @staticmethod
    def _safe_extension(filename):
        """Return an alphanumeric-only extension from ``filename`` ('bin' if none).

        The filename is supplied by the client, so anything that could act as
        a path separator is stripped before it is used in a temp-file name.
        """
        if '.' not in filename:
            return 'bin'
        ext = ''.join(ch for ch in filename.rsplit('.', 1)[1] if ch.isalnum())
        return ext[:16] or 'bin'

    def _proxy_transcribe(self):
        """Handle POST /api/transcribe: transcribe a multipart file upload.

        The upload is written to a temporary directory, transcribed with the
        result streamed as NDJSON, and the directory is removed afterwards.
        """
        length = self._content_length()
        if length is None:
            self._reject(400, "invalid Content-Length")
            return
        if length <= 0:
            self._json(400, {"error": "empty request"})
            return
        # Allow some slack on top of MAX_BYTES for the multipart framing.
        if length > MAX_BYTES + (1 << 16):
            self._drain(length)
            self._json(413, {"error": f"upload exceeds {MAX_MB} MB limit"})
            return

        ctype = self.headers.get("Content-Type", "")
        if "multipart/form-data" not in ctype:
            self._drain(length)
            self._json(400, {"error": "expected multipart/form-data"})
            return

        body = self._read_exactly(length)
        if body is None:
            self._json(400, {"error": "client disconnected during upload"})
            return

        file_bytes, fields = _parse_multipart(body, ctype)
        if file_bytes is None:
            self._json(400, {"error": "no file in request"})
            return

        language = fields.get('language', DEFAULT_LANG)
        response_format = fields.get('response_format', 'json')
        model_name = fields.get('model') or None

        try:
            tmpdir = tempfile.mkdtemp(prefix='stt_upload_')
        except OSError as e:
            print(f"[proxy] error writing upload: {e}", file=sys.stderr)
            self._json(500, {"error": str(e)})
            return

        try:
            ext = self._safe_extension(fields.get('file', ''))
            input_path = os.path.join(tmpdir, f"input.{ext}")
            print(f"writing recording to temp file : {input_path}", file=sys.stderr)
            try:
                with open(input_path, 'wb') as f:
                    f.write(file_bytes)
            except OSError as e:
                # No response has been started yet, so answer with a regular
                # JSON error rather than a bare NDJSON event.
                print(f"[proxy] error writing upload: {e}", file=sys.stderr)
                self._json(500, {"error": str(e)})
                return

            try:
                self._stream_transcribe(input_path, language, response_format, model_name=model_name)
            except Exception as e:
                # _stream_transcribe reports transcription failures in-stream;
                # anything escaping it happened while talking to the client.
                print(f"[proxy] error streaming response: {e}", file=sys.stderr)
                self.close_connection = True
        finally:
            shutil.rmtree(tmpdir, ignore_errors=True)

    def _read_exactly(self, n):
        """Read exactly ``n`` bytes from the request; None if the stream ends early."""
        chunks = []
        remaining = n
        while remaining > 0:
            chunk = self.rfile.read(min(remaining, 1 << 16))
            if not chunk:
                return None
            chunks.append(chunk)
            remaining -= len(chunk)
        return b"".join(chunks)

    def _drain(self, n):
        """Read and discard up to ``n`` request bytes, stopping at end of stream."""
        remaining = n
        while remaining > 0:
            chunk = self.rfile.read(min(remaining, 1 << 16))
            if not chunk:
                break
            remaining -= len(chunk)
|
|
|
|
|
|
def run_server():
    """Start the threaded HTTP server on (HOST, PORT) and serve until Ctrl-C.

    Startup information is printed to stderr. On KeyboardInterrupt the
    listening socket is closed before returning.
    """
    httpd = ThreadingHTTPServer((HOST, PORT), Handler)
    model_tag = "Whisper" if MODEL_TYPE == 'whisper' else "Parakeet"
    print(f"speech-to-text listening on http://{HOST}:{PORT}", file=sys.stderr)
    print(f"  -> model: {model_tag}", file=sys.stderr)
    print(f"  -> default language: {DEFAULT_LANG}, max upload: {MAX_MB} MB", file=sys.stderr)
    try:
        httpd.serve_forever()
    except KeyboardInterrupt:
        print("\nshutting down", file=sys.stderr)
    finally:
        # serve_forever() has already returned here, so shutdown() (meant to
        # be called from another thread while the loop runs) is not needed;
        # server_close() releases the listening socket.
        httpd.server_close()
|