Files
speech-to-text/lib/transcribe.py

257 lines
11 KiB
Python

import os
import sys
import json
import shutil
import uuid
import urllib.request
import urllib.error
import concurrent.futures
from .config import CHUNK_SECONDS, CHUNK_WORKERS, BACKEND_URL, API_KEY, MODEL_TYPE
from .audio import (
prepare_audio, detect_silence, find_split_points, extract_chunk,
build_response, get_duration,
)
def _build_multipart_request(file_path, fields, url, api_key=None):
    """Build a multipart/form-data POST request that uploads one audio file.

    Args:
        file_path: Path of the file to upload. It is read fully into memory
            and sent as the ``file`` part with content type ``audio/wav``.
        fields: Mapping of extra form field names to values. Values that are
            ``None`` are omitted from the form; every other value is sent as
            ``str(value)``.
        url: Full URL the request is addressed to.
        api_key: Optional token; when truthy it is sent as
            ``Authorization: Bearer <api_key>``.

    Returns:
        A ``urllib.request.Request`` (method POST) ready for ``urlopen``.

    Raises:
        OSError: If ``file_path`` cannot be opened or read.
    """
    # A random suffix keeps the boundary from colliding with the payload.
    boundary = '----STTChunkBoundary' + uuid.uuid4().hex

    lines = []
    for name, value in fields.items():
        if value is None:
            # An unset optional field (e.g. no language hint) is left out
            # entirely; joining a None into the header would raise TypeError.
            continue
        lines.append(f'--{boundary}')
        lines.append(f'Content-Disposition: form-data; name="{name}"')
        lines.append('')
        lines.append(str(value))

    filename = os.path.basename(file_path)
    lines.append(f'--{boundary}')
    lines.append(f'Content-Disposition: form-data; name="file"; filename="{filename}"')
    lines.append('Content-Type: audio/wav')
    lines.append('')
    # The trailing CRLF ends the blank line that separates the file part's
    # headers from its binary content.
    header = '\r\n'.join(lines).encode('utf-8') + b'\r\n'

    with open(file_path, 'rb') as f:
        file_data = f.read()

    footer = f'\r\n--{boundary}--\r\n'.encode('utf-8')
    body = header + file_data + footer

    headers = {
        'Content-Type': f'multipart/form-data; boundary={boundary}',
        'Content-Length': str(len(body)),
    }
    if api_key:
        headers['Authorization'] = 'Bearer ' + api_key
    return urllib.request.Request(url, data=body, headers=headers, method='POST')
def _transcribe_chunk_parakeet(path, language, url, api_key, timeout=600):
    """Upload one audio file to the backend and return its decoded JSON reply.

    The file is posted to ``<url>/v1/audio/transcriptions`` as a multipart
    form with ``language`` and ``response_format='verbose_json'`` fields.

    Args:
        path: Audio file to upload.
        language: Value sent in the ``language`` form field.
        url: Backend base URL; the endpoint path is appended to it as-is.
        api_key: Optional bearer token forwarded to the request builder.
        timeout: Timeout in seconds handed to ``urlopen``.

    Returns:
        The response body parsed with ``json.loads``.

    Errors raised by ``urlopen`` or by JSON decoding propagate to the caller.
    """
    endpoint = url + '/v1/audio/transcriptions'
    form_fields = {'language': language, 'response_format': 'verbose_json'}
    request = _build_multipart_request(path, form_fields, endpoint, api_key)
    with urllib.request.urlopen(request, timeout=timeout) as response:
        payload = response.read()
    return json.loads(payload)
def _retry_empty_chunk_parakeet(wav_path, start, end, language, tmpdir, chunk_idx):
    """Re-transcribe one chunk as a sequence of shorter sub-chunks.

    The span ``[start, end)`` is cut into pieces of ``max(CHUNK_SECONDS / 2, 7)``
    seconds. A trailing remainder too short to stand on its own is merged into
    the last piece so that the whole span is covered. Each piece is extracted
    to ``tmpdir`` and sent to the backend one after another.

    Args:
        wav_path: Source audio file the sub-chunks are extracted from.
        start: Start of the chunk, in seconds.
        end: End of the chunk, in seconds.
        language: Language passed to the backend; also the fallback for a
            segment's ``language`` when the response does not carry one.
        tmpdir: Directory that receives the temporary sub-chunk files.
        chunk_idx: Index of the parent chunk, used only in file names.

    Returns:
        A list of segment dicts (``id``, ``start``, ``end``, ``text``,
        ``language``) for the sub-chunks that produced non-empty text, in
        time order. ``id`` is a placeholder ``0``. The list is empty when
        nothing was recovered or the span is no longer than the minimum
        sub-chunk length.
    """
    sub_size = max(CHUNK_SECONDS / 2, 7)
    sub_min = max(sub_size / 2, 5)

    spans = []
    t = start
    # Only start a piece when more than sub_min seconds remain.
    while t < end - sub_min:
        sub_end = min(t + sub_size, end)
        spans.append((t, sub_end))
        t = sub_end
    if spans and t < end:
        # Fold the short tail into the last piece instead of leaving that
        # audio out of the retry.
        spans[-1] = (spans[-1][0], end)

    sub_results = []
    for i, (s, e) in enumerate(spans):
        sub_path = os.path.join(tmpdir, f"retry_{chunk_idx}_{i}.wav")
        try:
            extract_chunk(wav_path, sub_path, s, e)
            resp = _transcribe_chunk_parakeet(sub_path, language, BACKEND_URL, API_KEY)
            text = resp.get('text', '').strip()
            if text:
                sub_results.append({
                    'id': 0, 'start': s, 'end': e,
                    'text': text,
                    'language': resp.get('language', language),
                })
                preview = (text[:60] + '...') if len(text) > 60 else text
                print(f"  [sub {s:.0f}-{e:.0f}s]: \"{preview}\"", file=sys.stderr)
        except Exception as ex:
            # Best effort: a failing sub-chunk is logged and skipped so the
            # remaining ones still get their chance.
            print(f"  [sub {s:.0f}-{e:.0f}s] failed: {ex}", file=sys.stderr)
    return sub_results
def transcribe_file_parakeet(input_path, language, response_format,
                             progress_callback=None, segment_callback=None):
    """Transcribe an audio file through the HTTP backend.

    The input is first converted by ``prepare_audio``. When chunking is
    disabled (``CHUNK_SECONDS <= 0``) or the audio is no longer than one
    chunk, it is sent in a single request. Otherwise it is split at detected
    silences, the chunks are transcribed concurrently by ``CHUNK_WORKERS``
    threads, and chunks that failed or came back empty are retried as smaller
    sub-chunks.

    Args:
        input_path: Path of the audio file to transcribe.
        language: Language passed to the backend; also the fallback for a
            segment's ``language`` when the response does not carry one.
        response_format: Output format forwarded to ``build_response``.
        progress_callback: Optional callable invoked as chunks complete.
        segment_callback: Optional callable invoked with each segment dict
            as soon as its chunk has been transcribed (completion order).

    Returns:
        On the single-request path, the bare transcript string for the
        ``'json'`` and ``'text'`` formats; in every other case the result of
        ``build_response``.

    The temporary directory created by ``prepare_audio`` is always removed.
    """
    tmpdir, wav_path, duration = prepare_audio(input_path)
    try:
        if CHUNK_SECONDS <= 0 or duration <= CHUNK_SECONDS:
            if progress_callback:
                progress_callback(1, 1, 0, duration)
            resp = _transcribe_chunk_parakeet(wav_path, language, BACKEND_URL, API_KEY)
            text = resp.get('text', '').strip()
            seg = {
                'id': 0, 'start': 0, 'end': duration,
                'text': text,
                'language': resp.get('language', language),
            }
            if segment_callback:
                segment_callback(seg)
            if response_format in ('json', 'text'):
                return text
            return build_response([seg], response_format)

        silences = detect_silence(wav_path)
        print(f"[chunking] detected {len(silences)} silence(s)", file=sys.stderr)
        splits = find_split_points(silences, duration)
        n_chunks = len(splits) - 1
        print(f"[chunking] {duration:.1f}s -> {n_chunks} chunks", file=sys.stderr)

        # One slot per chunk so results stay in time order even though the
        # requests complete in arbitrary order.
        results = [None] * n_chunks  # type: list[dict | None]
        with concurrent.futures.ThreadPoolExecutor(max_workers=CHUNK_WORKERS) as executor:
            future_to_idx = {}
            for i in range(n_chunks):
                chunk_path = os.path.join(tmpdir, f"chunk_{i:03d}.wav")
                extract_chunk(wav_path, chunk_path, splits[i], splits[i + 1])
                future = executor.submit(_transcribe_chunk_parakeet, chunk_path, language, BACKEND_URL, API_KEY)
                future_to_idx[future] = i

            completed = 0
            for future in concurrent.futures.as_completed(future_to_idx):
                idx = future_to_idx[future]
                completed += 1
                try:
                    resp = future.result()
                except Exception as e:
                    # A failed chunk keeps its None slot and is picked up by
                    # the retry pass below.
                    print(f"[chunk {idx + 1}/{n_chunks}] failed: {e}", file=sys.stderr)
                    results[idx] = None
                    continue
                text = resp.get('text', '').strip()
                seg = {
                    'id': idx, 'start': splits[idx], 'end': splits[idx + 1],
                    'text': text,
                    'language': resp.get('language', language),
                }
                results[idx] = seg
                if segment_callback:
                    segment_callback(seg)
                preview = text[:80] + '...' if len(text) > 80 else text
                print(f"[chunk {idx + 1}/{n_chunks}] {splits[idx]:.1f}s-{splits[idx + 1]:.1f}s: \"{preview}\"", file=sys.stderr)
                if progress_callback:
                    # NOTE(review): this call passes five arguments while the
                    # single-request call above passes four; confirm the
                    # callback accepts both shapes.
                    progress_callback(completed, n_chunks, splits[idx], splits[idx + 1], duration)

        empty_indices = [i for i, r in enumerate(results) if r is None or not r.get('text')]
        # Recovered sub-segments are kept per chunk index rather than being
        # inserted into `results`: inserting would shift every later slot, so
        # the remaining empty indices would point at the wrong chunk.
        recovered = {}
        if empty_indices:
            print(f"[retry] {len(empty_indices)} empty chunk(s), sub-chunking...", file=sys.stderr)
            for idx in empty_indices:
                sub_segments = _retry_empty_chunk_parakeet(
                    wav_path, splits[idx], splits[idx + 1], language, tmpdir, idx
                )
                if sub_segments:
                    recovered[idx] = sub_segments
                    print(f"[retry] chunk {idx + 1}/{n_chunks}: recovered {len(sub_segments)} sub-segment(s)", file=sys.stderr)

        # Flatten in chunk order, substituting recovered sub-segments and
        # dropping chunks that are still empty.
        segments = []
        for idx, seg in enumerate(results):
            if idx in recovered:
                segments.extend(recovered[idx])
            elif seg and seg.get('text'):
                segments.append(seg)

        for i, seg in enumerate(segments):
            seg['id'] = i
        print(f"[chunking] {len(segments)} segments with text", file=sys.stderr)
        return build_response(segments, response_format)
    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)
def transcribe_file_whisper(input_path, language, response_format, model_name=None,
                            progress_callback=None, segment_callback=None):
    """Transcribe an audio file with the in-process whisper model.

    The input is first converted by ``prepare_audio``. When chunking is
    disabled (``CHUNK_SECONDS <= 0``) or the audio is no longer than one
    chunk, it is transcribed in one pass. Otherwise it is split at detected
    silences and the chunks are transcribed one after another, with segment
    times shifted by each chunk's start offset.

    Args:
        input_path: Path of the audio file to transcribe.
        language: Language passed to ``model.transcribe``.
        response_format: Output format forwarded to ``build_response``.
        model_name: Model identifier handed to ``get_whisper_model``.
        progress_callback: Optional callable invoked before each chunk as
            ``(chunk_number, total_chunks, chunk_start, duration)``.
        segment_callback: Optional callable invoked with each segment dict
            that has non-empty text, as soon as it is produced.

    Returns:
        The result of ``build_response`` for the collected segments, which
        are numbered consecutively through their ``id`` key.

    The temporary directory created by ``prepare_audio`` is always removed.
    """
    # Whisper uses a single GPU — parallel chunk inference would just contend
    # for the same device, so chunks are processed sequentially.
    from .whisper_model import get_whisper_model
    model = get_whisper_model(model_name)
    tmpdir, wav_path, duration = prepare_audio(input_path)

    def collect(path, offset):
        """Transcribe `path` and return its segments shifted by `offset` seconds."""
        segments_iter, info = model.transcribe(path, language=language, vad_filter=True)
        collected = []
        for seg in segments_iter:
            entry = {
                'start': round(seg.start + offset, 2),
                'end': round(seg.end + offset, 2),
                'text': seg.text.strip(),
                'language': info.language,
            }
            if entry['text'] and segment_callback:
                segment_callback(entry)
            collected.append(entry)
        return collected

    try:
        if CHUNK_SECONDS <= 0 or duration <= CHUNK_SECONDS:
            if progress_callback:
                progress_callback(1, 1, 0, duration)
            all_segments = collect(wav_path, 0)
            # Number the segments here too, so the single-pass result has the
            # same shape as the chunked one.
            for i, seg in enumerate(all_segments):
                seg['id'] = i
            return build_response(all_segments, response_format)

        silences = detect_silence(wav_path)
        print(f"[chunking] detected {len(silences)} silence(s)", file=sys.stderr)
        splits = find_split_points(silences, duration)
        n_chunks = len(splits) - 1
        print(f"[chunking] {duration:.1f}s -> {n_chunks} chunks", file=sys.stderr)

        all_segments = []
        for i in range(n_chunks):
            if progress_callback:
                progress_callback(i + 1, n_chunks, splits[i], duration)
            chunk_path = os.path.join(tmpdir, f"chunk_{i:03d}.wav")
            extract_chunk(wav_path, chunk_path, splits[i], splits[i + 1])
            chunk_segments = collect(chunk_path, splits[i])
            all_segments.extend(chunk_segments)

            # Preview only what this chunk produced; slicing the accumulated
            # list would show the previous chunk's text for a silent chunk.
            preview = ' '.join(s['text'] for s in chunk_segments[-2:] if s['text'])
            if len(preview) > 80:
                preview = preview[:80] + '...'
            print(f"[chunk {i + 1}/{n_chunks}] {splits[i]:.1f}s-{splits[i + 1]:.1f}s: \"{preview}\"", file=sys.stderr)

        for i, seg in enumerate(all_segments):
            seg['id'] = i
        print(f"[chunking] {len(all_segments)} segments with text", file=sys.stderr)
        return build_response(all_segments, response_format)
    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)
def _whisper_to_response(segments_list, info, offset, fmt):
    """Convert whisper segments to dicts and format them with ``build_response``.

    Args:
        segments_list: Iterable of objects exposing ``start``, ``end`` and
            ``text`` attributes.
        info: Object whose ``language`` attribute is copied onto each segment.
        offset: Seconds added to every segment's start and end time.
        fmt: Response format forwarded to ``build_response``.

    Returns:
        Whatever ``build_response`` returns for the converted segments.
    """
    shifted = [
        {
            'start': item.start + offset,
            'end': item.end + offset,
            'text': item.text.strip(),
            'language': info.language,
        }
        for item in segments_list
    ]
    return build_response(shifted, fmt)
def transcribe_file(input_path, language, response_format, model_name=None,
                    progress_callback=None, segment_callback=None):
    """Transcribe a file with whichever engine is selected.

    The HTTP backend is used unless ``MODEL_TYPE`` is ``'whisper'`` or a
    truthy ``model_name`` is given; in either of those cases the in-process
    whisper path runs and receives ``model_name``.

    Args:
        input_path: Path of the audio file to transcribe.
        language: Language forwarded to the chosen engine.
        response_format: Output format forwarded to the chosen engine.
        model_name: Optional whisper model identifier; supplying one forces
            the whisper path.
        progress_callback: Optional progress callable, passed through.
        segment_callback: Optional per-segment callable, passed through.

    Returns:
        The result of the chosen engine's transcription function.
    """
    callbacks = {
        'progress_callback': progress_callback,
        'segment_callback': segment_callback,
    }
    # Guard clause: fall back to the backend only when nothing asks for whisper.
    if MODEL_TYPE != 'whisper' and not model_name:
        return transcribe_file_parakeet(input_path, language, response_format,
                                        **callbacks)
    return transcribe_file_whisper(input_path, language, response_format,
                                   model_name=model_name, **callbacks)