257 lines
11 KiB
Python
257 lines
11 KiB
Python
import os
|
|
import sys
|
|
import json
|
|
import shutil
|
|
import uuid
|
|
import urllib.request
|
|
import urllib.error
|
|
import concurrent.futures
|
|
|
|
from .config import CHUNK_SECONDS, CHUNK_WORKERS, BACKEND_URL, API_KEY, MODEL_TYPE
|
|
from .audio import (
|
|
prepare_audio, detect_silence, find_split_points, extract_chunk,
|
|
build_response, get_duration,
|
|
)
|
|
|
|
|
|
def _build_multipart_request(file_path, fields, url, api_key=None, content_type='audio/wav'):
    """Build a POST ``urllib.request.Request`` carrying a multipart/form-data body.

    The body contains one text part per entry of ``fields``, followed by the
    contents of ``file_path`` as the ``file`` part. The whole file is read
    into memory.

    Args:
        file_path: path of the file to upload; its basename is sent as the
            part's ``filename``.
        fields: mapping of form-field name -> value. Values are sent as
            ``str(value)``; entries whose value is ``None`` are omitted.
        url: request URL.
        api_key: if truthy, sent as ``Authorization: Bearer <api_key>``.
        content_type: MIME type declared for the file part.

    Returns:
        A ``urllib.request.Request`` with method POST, ready for ``urlopen``.
    """
    boundary = '----STTChunkBoundary' + uuid.uuid4().hex

    lines = []
    for name, value in fields.items():
        if value is None:
            # e.g. no language given: omit the field instead of crashing in
            # '\r\n'.join() on a non-str item.
            continue
        lines += [
            f'--{boundary}',
            f'Content-Disposition: form-data; name="{name}"',
            '',
            str(value),
        ]

    filename = os.path.basename(file_path)
    lines += [
        f'--{boundary}',
        f'Content-Disposition: form-data; name="file"; filename="{filename}"',
        f'Content-Type: {content_type}',
        '',  # blank line separating the part headers from the file bytes
    ]
    header = ('\r\n'.join(lines) + '\r\n').encode('utf-8')

    with open(file_path, 'rb') as f:
        file_data = f.read()

    footer = f'\r\n--{boundary}--\r\n'.encode('utf-8')
    body = b''.join((header, file_data, footer))

    headers = {
        'Content-Type': f'multipart/form-data; boundary={boundary}',
        'Content-Length': str(len(body)),
    }
    if api_key:
        headers['Authorization'] = 'Bearer ' + api_key

    return urllib.request.Request(url, data=body, headers=headers, method='POST')
|
|
|
|
|
|
def _transcribe_chunk_parakeet(path, language, url, api_key, timeout=600):
    """POST one audio file to the Parakeet backend and return the decoded JSON.

    The file is sent to ``<url>/v1/audio/transcriptions`` as multipart form
    data with fields ``language`` and ``response_format=verbose_json``.

    Args:
        path: audio file to upload.
        language: value of the ``language`` form field.
        url: backend base URL; a trailing ``/`` is tolerated.
        api_key: optional bearer token (omitted from the request when falsy).
        timeout: timeout in seconds passed to ``urlopen``.

    Returns:
        The response body parsed with ``json.loads``.

    Raises:
        Propagates errors from ``urlopen`` (e.g. ``urllib.error.HTTPError``)
        and ``json.loads``.
    """
    # Strip a trailing slash so a base URL like "http://host:8000/" does not
    # produce "//v1/...".
    endpoint = url.rstrip('/') + '/v1/audio/transcriptions'
    req = _build_multipart_request(
        path,
        {'language': language, 'response_format': 'verbose_json'},
        endpoint,
        api_key,
    )
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read())
|
|
|
|
|
|
def _retry_empty_chunk_parakeet(wav_path, start, end, language, tmpdir, chunk_idx):
    """Re-transcribe the span ``start``..``end`` of ``wav_path`` in smaller pieces.

    Intended for chunks whose first transcription failed or returned no text.
    The span is cut into consecutive windows of ``max(CHUNK_SECONDS / 2, 7)``
    seconds. A trailing remainder no longer than ``max(window / 2, 5)``
    seconds is merged into the last window. A span too short to yield any
    window is not retried.

    Each window is extracted into ``tmpdir`` and sent to the backend on its
    own. A window that fails is logged to stderr and skipped (best-effort).

    Args:
        wav_path: source WAV file.
        start: span start, in seconds.
        end: span end, in seconds.
        language: language forwarded to the backend.
        tmpdir: directory for the temporary sub-chunk WAV files.
        chunk_idx: index of the parent chunk; used only in temp file names.

    Returns:
        List of segment dicts (``id`` placeholder 0, ``start``, ``end``,
        ``text``, ``language``) for the windows that produced text, in order.
    """
    sub_size = max(CHUNK_SECONDS / 2, 7)
    sub_min = max(sub_size / 2, 5)

    windows = []
    t = start
    while t < end - sub_min:
        sub_end = min(t + sub_size, end)
        if sub_end - t < sub_min:
            break
        windows.append((t, sub_end))
        t = sub_end

    # The loop stops once no more than sub_min seconds remain. Previously that
    # remainder was never retried, so its audio was lost. Fold it into the
    # last window instead.
    if windows and t < end:
        windows[-1] = (windows[-1][0], end)

    sub_results = []
    for i, (s, e) in enumerate(windows):
        sub_path = os.path.join(tmpdir, f"retry_{chunk_idx}_{i}.wav")
        try:
            extract_chunk(wav_path, sub_path, s, e)
            resp = _transcribe_chunk_parakeet(sub_path, language, BACKEND_URL, API_KEY)
            text = resp.get('text', '').strip()
            if text:
                sub_results.append({
                    'id': 0, 'start': s, 'end': e,
                    'text': text,
                    'language': resp.get('language', language),
                })
                preview = text[:60] + '...' if len(text) > 60 else text
                print(f"  [sub {s:.0f}-{e:.0f}s]: \"{preview}\"", file=sys.stderr)
        except Exception as ex:
            # Best-effort: one bad window must not abort the remaining ones.
            print(f"  [sub {s:.0f}-{e:.0f}s] failed: {ex}", file=sys.stderr)

    return sub_results
|
|
|
|
|
|
def transcribe_file_parakeet(input_path, language, response_format,
                             progress_callback=None, segment_callback=None):
    """Transcribe an audio file with the Parakeet HTTP backend.

    ``prepare_audio`` converts the input into a WAV inside a temporary
    directory, and that directory is always removed on exit.

    - Single request: used when chunking is disabled (``CHUNK_SECONDS <= 0``)
      or the audio is no longer than ``CHUNK_SECONDS``.
    - Chunked: otherwise the audio is split at silences (``detect_silence`` /
      ``find_split_points``). The chunks are transcribed concurrently on
      ``CHUNK_WORKERS`` threads. Chunks that fail or return no text are
      retried as smaller sub-chunks.

    Args:
        input_path: path to the source audio file.
        language: language sent to the backend with every request.
        response_format: format passed to ``build_response``.
        progress_callback: optional callable invoked as
            ``progress_callback(done, total, chunk_start, duration)``.
        segment_callback: optional callable invoked with segment dicts as they
            become available. In the chunked path they arrive in completion
            order, not timeline order, and recovered sub-segments are
            delivered after the initial pass.

    Returns:
        In the single-request path, the bare transcript string for the
        ``'json'`` and ``'text'`` formats. Otherwise, the result of
        ``build_response`` over the segments that have text, renumbered by
        ``id`` in timeline order.
    """
    tmpdir, wav_path, duration = prepare_audio(input_path)

    try:
        if CHUNK_SECONDS <= 0 or duration <= CHUNK_SECONDS:
            if progress_callback:
                progress_callback(1, 1, 0, duration)
            resp = _transcribe_chunk_parakeet(wav_path, language, BACKEND_URL, API_KEY)
            # `or ''` also covers a backend that returns "text": null.
            text = (resp.get('text') or '').strip()
            seg = {
                'id': 0, 'start': 0, 'end': duration,
                'text': text,
                'language': resp.get('language', language),
            }
            if segment_callback:
                segment_callback(seg)
            # NOTE(review): unlike the chunked path and the whisper backend,
            # 'json' returns the bare string here rather than build_response's
            # result. Confirm the caller expects this.
            if response_format in ('json', 'text'):
                return text
            return build_response([seg], response_format)

        silences = detect_silence(wav_path)
        print(f"[chunking] detected {len(silences)} silence(s)", file=sys.stderr)

        splits = find_split_points(silences, duration)
        n_chunks = len(splits) - 1
        print(f"[chunking] {duration:.1f}s -> {n_chunks} chunks", file=sys.stderr)

        results = [None] * n_chunks  # type: list[dict | None]

        with concurrent.futures.ThreadPoolExecutor(max_workers=CHUNK_WORKERS) as executor:
            # Chunks are extracted on this thread; only the HTTP calls run in the pool.
            future_to_idx = {}
            for i in range(n_chunks):
                chunk_path = os.path.join(tmpdir, f"chunk_{i:03d}.wav")
                extract_chunk(wav_path, chunk_path, splits[i], splits[i + 1])
                future = executor.submit(_transcribe_chunk_parakeet, chunk_path, language, BACKEND_URL, API_KEY)
                future_to_idx[future] = i

            completed = 0
            for future in concurrent.futures.as_completed(future_to_idx):
                idx = future_to_idx[future]
                completed += 1
                try:
                    resp = future.result()
                except Exception as e:
                    # results[idx] stays None, so the chunk is retried below.
                    print(f"[chunk {idx + 1}/{n_chunks}] failed: {e}", file=sys.stderr)
                else:
                    text = (resp.get('text') or '').strip()
                    seg = {
                        'id': idx, 'start': splits[idx], 'end': splits[idx + 1],
                        'text': text,
                        'language': resp.get('language', language),
                    }
                    results[idx] = seg
                    if segment_callback:
                        segment_callback(seg)
                    preview = text[:80] + '...' if len(text) > 80 else text
                    print(f"[chunk {idx + 1}/{n_chunks}] {splits[idx]:.1f}s-{splits[idx + 1]:.1f}s: \"{preview}\"", file=sys.stderr)
                # Report progress for failed chunks too, so `done` reaches `total`.
                # Use the same 4-argument form (done, total, start, duration) as
                # the single-request path; this call previously passed 5 args.
                if progress_callback:
                    progress_callback(completed, n_chunks, splits[idx], duration)

        # Retry failed/empty chunks as smaller sub-chunks. Recovered segments are
        # kept per chunk index and spliced in afterwards. Inserting into
        # `results` inside this loop would shift every later index in
        # `empty_indices` and overwrite unrelated chunks.
        empty_indices = [i for i, r in enumerate(results) if r is None or not r.get('text')]
        recovered = {}
        if empty_indices:
            print(f"[retry] {len(empty_indices)} empty chunk(s), sub-chunking...", file=sys.stderr)
        for idx in empty_indices:
            sub_segments = _retry_empty_chunk_parakeet(
                wav_path, splits[idx], splits[idx + 1], language, tmpdir, idx
            )
            if sub_segments:
                recovered[idx] = sub_segments
                if segment_callback:
                    for ss in sub_segments:
                        segment_callback(ss)
                print(f"[retry] chunk {idx + 1}/{n_chunks}: recovered {len(sub_segments)} sub-segment(s)", file=sys.stderr)

        ordered = []
        for idx, r in enumerate(results):
            ordered.extend(recovered.get(idx, [r]))

        segments = [r for r in ordered if r and r.get('text')]
        for i, seg in enumerate(segments):
            seg['id'] = i
        print(f"[chunking] {len(segments)} segments with text", file=sys.stderr)
        return build_response(segments, response_format)

    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)
|
|
|
|
|
|
def _whisper_chunk_segments(model, audio_path, language, offset, segment_callback):
    """Transcribe one audio file with ``model`` and return its non-empty segments.

    Segment timestamps are shifted by ``offset`` seconds and rounded to two
    decimals. ``segment_callback`` (if given) is invoked with each segment as
    it is produced.
    """
    segments_iter, info = model.transcribe(audio_path, language=language, vad_filter=True)
    collected = []
    for seg in segments_iter:
        text = seg.text.strip()
        if not text:
            # Empty segments were never streamed to segment_callback; keep them
            # out of the response as well (matching the parakeet backend).
            continue
        s = {
            'start': round(seg.start + offset, 2),
            'end': round(seg.end + offset, 2),
            'text': text,
            'language': info.language,
        }
        if segment_callback:
            segment_callback(s)
        collected.append(s)
    return collected


def transcribe_file_whisper(input_path, language, response_format, model_name=None,
                            progress_callback=None, segment_callback=None):
    """Transcribe an audio file with a locally loaded Whisper model.

    The model comes from ``get_whisper_model(model_name)``. ``prepare_audio``
    converts the input into a WAV inside a temporary directory, and that
    directory is always removed on exit.

    - Single pass: used for short audio, or any audio when
      ``CHUNK_SECONDS <= 0``.
    - Chunked: longer audio is split at silences and the chunks are
      transcribed one after another. Segment timestamps are shifted back onto
      the full-file timeline.

    Args:
        input_path: path to the source audio file.
        language: language passed to ``model.transcribe``.
        response_format: format passed to ``build_response``.
        model_name: forwarded to ``get_whisper_model``.
        progress_callback: optional callable invoked as
            ``progress_callback(chunk_number, total, chunk_start, duration)``
            before each chunk is transcribed.
        segment_callback: optional callable invoked with each non-empty
            segment dict as it is produced.

    Returns:
        ``build_response`` over the non-empty segments, each given a
        sequential ``id`` in the order produced.
    """
    # Whisper uses a single GPU — parallel chunk inference would just contend
    # for the same device, so chunks are processed sequentially.
    from .whisper_model import get_whisper_model

    model = get_whisper_model(model_name)
    tmpdir, wav_path, duration = prepare_audio(input_path)

    try:
        if CHUNK_SECONDS <= 0 or duration <= CHUNK_SECONDS:
            if progress_callback:
                progress_callback(1, 1, 0, duration)
            all_segments = _whisper_chunk_segments(model, wav_path, language, 0, segment_callback)
        else:
            silences = detect_silence(wav_path)
            print(f"[chunking] detected {len(silences)} silence(s)", file=sys.stderr)

            splits = find_split_points(silences, duration)
            n_chunks = len(splits) - 1
            print(f"[chunking] {duration:.1f}s -> {n_chunks} chunks", file=sys.stderr)

            all_segments = []
            for i in range(n_chunks):
                if progress_callback:
                    progress_callback(i + 1, n_chunks, splits[i], duration)
                chunk_path = os.path.join(tmpdir, f"chunk_{i:03d}.wav")
                extract_chunk(wav_path, chunk_path, splits[i], splits[i + 1])
                chunk_segments = _whisper_chunk_segments(
                    model, chunk_path, language, splits[i], segment_callback
                )
                all_segments.extend(chunk_segments)

                # Preview only this chunk's output. Previously a silent chunk
                # re-displayed the previous chunk's last segments.
                preview = ' '.join(s['text'] for s in chunk_segments[-2:])
                if len(preview) > 80:
                    preview = preview[:80] + '...'
                print(f"[chunk {i + 1}/{n_chunks}] {splits[i]:.1f}s-{splits[i + 1]:.1f}s: \"{preview}\"", file=sys.stderr)

            print(f"[chunking] {len(all_segments)} segments with text", file=sys.stderr)

        # Give every segment an id, in both paths. The single pass used to
        # omit it, unlike every other segment producer in this module.
        for i, seg in enumerate(all_segments):
            seg['id'] = i
        return build_response(all_segments, response_format)

    finally:
        shutil.rmtree(tmpdir, ignore_errors=True)
|
|
|
|
|
|
def _whisper_to_response(segments_list, info, offset, fmt):
    """Convert Whisper segments into a ``build_response`` result.

    Args:
        segments_list: iterable of Whisper segment objects; only ``.start``,
            ``.end`` and ``.text`` are read.
        info: transcription info object; only ``info.language`` is read.
        offset: seconds added to every segment's start and end.
        fmt: response format passed to ``build_response``.

    Returns:
        Whatever ``build_response`` returns for the converted segments.
    """
    # Sequential 'id's match the segment dicts produced by the transcribe
    # functions in this module, which all carry an 'id'.
    result_segments = [
        {
            'id': i,
            'start': seg.start + offset,
            'end': seg.end + offset,
            'text': seg.text.strip(),
            'language': info.language,
        }
        for i, seg in enumerate(segments_list)
    ]
    return build_response(result_segments, fmt)
|
|
|
|
|
|
def transcribe_file(input_path, language, response_format, model_name=None,
                    progress_callback=None, segment_callback=None):
    """Transcribe ``input_path`` with the configured backend.

    Whisper is used when ``MODEL_TYPE == 'whisper'`` or when any truthy
    ``model_name`` is given; otherwise the Parakeet backend handles the file.
    All remaining arguments are forwarded unchanged, and the backend's result
    is returned as-is.
    """
    callbacks = {
        'progress_callback': progress_callback,
        'segment_callback': segment_callback,
    }
    use_whisper = MODEL_TYPE == 'whisper' or bool(model_name)
    if not use_whisper:
        return transcribe_file_parakeet(input_path, language, response_format, **callbacks)
    return transcribe_file_whisper(input_path, language, response_format,
                                   model_name=model_name, **callbacks)