"""
Shared Whisper speech-to-text helper.

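The openai-whisper package (pip install openai-whisper) runs the model
entirely locally — no API key is required. The only network access is the
one-time download of each model checkpoint on first use; it is cached on
disk afterwards.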

Model sizes (approximate checkpoint download size; runtime RAM/VRAM use is
considerably higher):
  tiny   ~75 MB   fastest, lowest accuracy
  base   ~145 MB
  small  ~465 MB
  medium ~1.5 GB
  large  ~2.9 GB

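Model selection order: an explicit model_name argument, then the
WHISPER_MODEL environment variable, then the "whisper_model" site setting,
and finally the default (tiny). Unrecognized names fall back to tiny.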
"""

import os
import math

# Per-process cache of loaded Whisper models, keyed by normalized model name.
_MODEL_CACHE: dict = dict()


def _get_model(model_name: str = None):
    """Load a Whisper model, caching it for the lifetime of the process.

    The model name is resolved in this order:
      1. the explicit ``model_name`` argument;
      2. the ``WHISPER_MODEL`` environment variable;
      3. the ``whisper_model`` site setting (best effort, via ``helpers``);
      4. ``"tiny"``.

    The name is stripped and lower-cased. Names that the installed whisper
    package does not list in ``whisper.available_models()`` (e.g. typos)
    fall back to ``"tiny"``; official variants such as ``"large-v3"`` or
    ``"medium.en"`` are accepted instead of being silently downgraded.

    Once a model is cached, later calls return the same object. Concurrent
    first calls for the same name are not serialized, so a model may be
    loaded more than once in that case.

    Raises:
        RuntimeError: if the openai-whisper package is not installed.
    """
    if model_name is None:
        env_model = os.environ.get("WHISPER_MODEL", "").strip().lower()
        if env_model:
            model_name = env_model
        else:
            try:
                from helpers import get_site_settings
                model_name = (get_site_settings().get("whisper_model") or "tiny").strip().lower()
            except Exception:
                # Site settings are optional; any failure means "use the default".
                model_name = "tiny"
    model_name = model_name.strip().lower()

    try:
        import whisper
    except ImportError as err:
        raise RuntimeError(
            "openai-whisper is not installed. "
            "Run: pip install openai-whisper"
        ) from err

    # Only accept names whisper knows how to download. This also keeps
    # arbitrary file paths (which load_model would deserialize) out.
    if model_name not in whisper.available_models():
        model_name = "tiny"
    if model_name not in _MODEL_CACHE:
        _MODEL_CACHE[model_name] = whisper.load_model(model_name)
    return _MODEL_CACHE[model_name]


def transcribe(audio_path: str, model_name: str = None) -> dict:
    """
    Transcribe an audio/video file using Whisper.

    ``model_name`` is resolved by ``_get_model`` (explicit argument, then
    environment / site settings, then "tiny").

    Returns a dict:
      {
        "text":     str,          # full transcript, surrounding whitespace stripped
        "language": str,          # detected language code, e.g. "en"
        "segments": [             # list of timed segments
          {"start": float, "end": float, "text": str},
          ...
        ]
      }
    """
    model = _get_model(model_name)
    # fp16=False forces FP32 decoding: it avoids whisper's "FP16 is not
    # supported on CPU" warning, at the cost of speed on a GPU.
    result = model.transcribe(audio_path, fp16=False)
    segments = [
        {
            "start": seg["start"],
            "end": seg["end"],
            "text": seg["text"].strip(),
        }
        for seg in result.get("segments", [])
    ]
    return {
        # Whisper's joined transcript usually starts with a space; strip it so
        # the full text is consistent with the stripped segment texts.
        "text": result.get("text", "").strip(),
        "language": result.get("language", ""),
        "segments": segments,
    }


def _fmt_srt_time(seconds: float) -> str:
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = int(seconds % 60)
    millis = int(round((seconds - math.floor(seconds)) * 1000))
    return f"{hours:02d}:{minutes:02d}:{secs:02d},{millis:03d}"


def _fmt_vtt_time(seconds: float) -> str:
    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    secs = int(seconds % 60)
    millis = int(round((seconds - math.floor(seconds)) * 1000))
    return f"{hours:02d}:{minutes:02d}:{secs:02d}.{millis:03d}"


def to_srt(segments: list) -> str:
    """Render Whisper segments as an SRT subtitle document.

    Every segment becomes a cue numbered from 1: the number, a
    ``start --> end`` timing line and the segment text. Cues are separated
    by a blank line; an empty segment list gives an empty string.
    """
    def _cue(number, segment):
        timing = f"{_fmt_srt_time(segment['start'])} --> {_fmt_srt_time(segment['end'])}"
        return f"{number}\n{timing}\n{segment['text']}\n"

    return "\n".join(_cue(number, segment) for number, segment in enumerate(segments, start=1))


def to_vtt(segments: list) -> str:
    """Render Whisper segments as a WebVTT subtitle document.

    The output starts with the ``WEBVTT`` header, followed by one cue per
    segment (numbered from 1, then a ``start --> end`` timing line, then the
    text), with a blank line separating the header and each cue.
    """
    cues = [
        f"{number}\n"
        f"{_fmt_vtt_time(segment['start'])} --> {_fmt_vtt_time(segment['end'])}\n"
        f"{segment['text']}\n"
        for number, segment in enumerate(segments, start=1)
    ]
    return "\n".join(["WEBVTT\n", *cues])
