49 lines
1.4 KiB
Python
49 lines
1.4 KiB
Python
import sys
|
|
from .config import WHISPER_MODEL, WHISPER_DEVICE, WHISPER_COMPUTE
|
|
|
|
# Cache of already-loaded models, keyed by model name. Entries are added by
# get_whisper_model() and removed by unload_model().
_models = {}
|
|
|
|
|
|
def _detect_device():
    """Return the device to load models on: ``"cuda"`` or ``"cpu"``.

    ``"cuda"`` is returned only when ``WHISPER_DEVICE`` is ``"cuda"`` and
    ctranslate2 reports at least one supported compute type for CUDA.
    Any other outcome yields ``"cpu"``; when CUDA was requested but cannot
    be used, the reason is written to stderr so the downgrade is visible.

    Returns:
        The string ``"cuda"`` or ``"cpu"``.
    """
    if WHISPER_DEVICE != "cuda":
        return "cpu"

    try:
        # Imported lazily so the module can be imported without ctranslate2.
        import ctranslate2

        cuda_types = ctranslate2.get_supported_compute_types("cuda")
    except Exception as exc:
        # Deliberately broad: a missing ctranslate2 package and a failing
        # CUDA probe both mean "use the CPU". Unlike before, say why instead
        # of silently ignoring the configured device.
        print(f"[whisper] CUDA probe failed ({exc!r}), using CPU", file=sys.stderr)
        return "cpu"

    if not cuda_types:
        print("[whisper] CUDA not available in ctranslate2, using CPU", file=sys.stderr)
        return "cpu"

    return "cuda"
|
|
|
|
|
|
def get_whisper_model(model_name=None):
    """Return the ``WhisperModel`` for *model_name*, loading it on first use.

    Loaded models are kept in the module-level ``_models`` cache, so later
    calls with the same name return the same instance without reloading.

    Args:
        model_name: Name passed to ``WhisperModel``. ``None`` selects the
            configured ``WHISPER_MODEL``.

    Returns:
        The cached or newly created ``faster_whisper.WhisperModel``.
    """
    name = WHISPER_MODEL if model_name is None else model_name

    try:
        return _models[name]
    except KeyError:
        pass  # Not loaded yet: fall through and load it below.

    # Only a "cuda" setting is probed; any other configured value is used as-is.
    target = _detect_device() if WHISPER_DEVICE == "cuda" else WHISPER_DEVICE

    # float16-based compute types are swapped for int8 when running on the CPU
    # (presumably because they are GPU-oriented -- confirm against ctranslate2).
    wants_half = WHISPER_COMPUTE in ("int8_float16", "float16")
    precision = "int8" if target == "cpu" and wants_half else WHISPER_COMPUTE

    # Imported lazily, after device detection, exactly when a load is needed.
    from faster_whisper import WhisperModel

    print(f"[whisper] loading {name} device={target} compute={precision}...", file=sys.stderr)
    loaded = WhisperModel(name, device=target, compute_type=precision)
    print(f"[whisper] {name} loaded on {target}", file=sys.stderr)

    _models[name] = loaded
    return loaded
|
|
|
|
|
|
def unload_model(model_name):
    """Drop *model_name* from the model cache, if it is present.

    Only the cache's reference is removed; a message is written to stderr
    when an entry was actually dropped. Unknown names are ignored silently.

    Args:
        model_name: Cache key of the model to forget.
    """
    try:
        del _models[model_name]
    except KeyError:
        # Nothing cached under this name: no-op, and no message.
        return
    print(f"[whisper] unloaded {model_name}", file=sys.stderr)
|