import concurrent.futures
import contextlib
import json
import logging
import os
import queue
import sqlite3
import subprocess
import threading
import uuid as uuid_mod

# Module logger; every diagnostic emitted by this file goes through it.
_log = logging.getLogger(__name__)

# SQLite file holding the local_jobs table, resolved relative to this module
# as <module dir>/storage/sqlite.db.
_DB_PATH = os.path.join(os.path.dirname(__file__), "storage", "sqlite.db")

# Set to True once _ensure_table() has created/migrated local_jobs in this
# process; _TABLE_LOCK serialises the first-time setup across threads.
_TABLE_ENSURED = False
_TABLE_LOCK = threading.Lock()

# Queue name → numeric priority. Lower numbers are dequeued first; unknown
# queue names are treated as priority 2 by LocalQueue.enqueue_call().
_PRIORITY_MAP = {'high': 0, 'medium': 1, 'low': 2}

# ── subprocess tracking ────────────────────────────────────────────────────────
# Maps job_id → list of subprocess.Popen objects spawned during that job.
# Every access must hold _job_child_procs_lock.
_job_child_procs: dict = {}
_job_child_procs_lock = threading.Lock()

# Thread-local storage for the currently-running job_id so the patched
# subprocess.run() knows which job to register procs under.
_tl = threading.local()

# Heavy-tool process names that should be killed on timeout.
# Process names are compared after lower-casing and cutting at the first '.',
# so e.g. a process called 'soffice.bin' is matched through the 'soffice' entry.
_HEAVY_PROC_NAMES = frozenset({
    'soffice', 'soffice.bin', 'ffmpeg', 'pdftoppm', 'gs',
    'ghostscript', 'convert', 'ebook-convert', 'calibre',
})

# ── monkey-patch subprocess.run ────────────────────────────────────────────────
# Reference to the real subprocess.run, captured before the module attribute
# is replaced; the tracking wrapper delegates to it outside of a job.
_original_subprocess_run = subprocess.run


def _tracked_subprocess_run(cmd, *args, **kwargs):
    """
    Drop-in replacement for subprocess.run() that registers the spawned
    process with the current job so it can be killed on timeout.

    When the calling thread has no ``current_job_id`` in its thread-local
    storage, the call is forwarded untouched to the original
    ``subprocess.run``. Otherwise the process is started through
    ``subprocess.Popen``, appended to the job's entry in ``_job_child_procs``
    and then waited for.

    The run()-only keyword arguments ``input``, ``capture_output``,
    ``timeout`` and ``check`` are handled here; everything else is passed on
    to ``Popen``.

    Returns:
        subprocess.CompletedProcess for the finished command.

    Raises:
        ValueError: if both ``input`` and ``stdin`` are supplied.
        subprocess.TimeoutExpired: if ``timeout`` elapses (the child is killed
            and reaped first).
        subprocess.CalledProcessError: if ``check`` is true and the command
            exits with a non-zero status.
    """
    job_id = getattr(_tl, 'current_job_id', None)
    if job_id is None:
        # Not running inside a tracked job — behave normally.
        return _original_subprocess_run(cmd, *args, **kwargs)

    # These keywords belong to run(), not Popen(); passing any of them through
    # would make Popen() raise TypeError.
    timeout = kwargs.pop('timeout', None)
    check = kwargs.pop('check', False)
    capture = kwargs.pop('capture_output', False)
    input_data = kwargs.pop('input', None)

    if input_data is not None:
        # Same contract as subprocess.run(): input is fed through a pipe, so
        # an explicit stdin cannot be combined with it.
        if kwargs.get('stdin') is not None:
            raise ValueError('stdin and input arguments may not both be used.')
        kwargs['stdin'] = subprocess.PIPE
    if capture:
        kwargs.setdefault('stdout', subprocess.PIPE)
        kwargs.setdefault('stderr', subprocess.PIPE)

    proc = subprocess.Popen(cmd, *args, **kwargs)
    # Register before waiting so a timeout handler in another thread can find
    # (and kill) the process while communicate() is still blocked.
    with _job_child_procs_lock:
        _job_child_procs.setdefault(job_id, []).append(proc)

    try:
        stdout, stderr = proc.communicate(input_data, timeout=timeout)
    except subprocess.TimeoutExpired:
        proc.kill()
        # Reap the killed child so it does not linger as a zombie.
        proc.communicate()
        raise

    if check and proc.returncode != 0:
        raise subprocess.CalledProcessError(proc.returncode, cmd,
                                            stdout, stderr)
    return subprocess.CompletedProcess(
        args=cmd,
        returncode=proc.returncode,
        stdout=stdout,
        stderr=stderr,
    )


# Process-wide patch: from here on, any lookup of the attribute
# `subprocess.run` resolves to the tracking wrapper. Names bound earlier via
# `from subprocess import run` keep pointing at the original function.
subprocess.run = _tracked_subprocess_run


# ── kill helpers ───────────────────────────────────────────────────────────────

def _kill_job_procs(job_id: str):
    """Kill all registered child processes for a job, then psutil fallback.

    The psutil fallback is scoped to descendants of the registered Popen PIDs
    so it cannot accidentally kill processes belonging to other concurrent jobs.

    Only registered processes that are still running are inspected. A Popen
    whose child has already exited and been reaped may have had its PID
    recycled by the OS, so looking that PID up would examine (and possibly
    kill) an unrelated process.

    Args:
        job_id: Key of the job in ``_job_child_procs``. The entry is removed
            from the registry whether or not anything is killed.
    """
    with _job_child_procs_lock:
        procs = _job_child_procs.pop(job_id, [])

    # Keep only children that have not exited yet; finished ones need no kill
    # and their PIDs can no longer be trusted.
    live = []
    for p in procs:
        try:
            if p.poll() is None:
                live.append(p)
        except Exception:
            pass

    # ── Phase 1: snapshot descendants NOW, before killing parents ─────────────
    # We collect the psutil process objects for each live pid and ALL of their
    # recursive children *before* we kill anything.  This is critical: once
    # the parent is killed its children can be re-parented to PID 1, so a
    # post-kill `parent.children()` call would find nothing.
    heavy_to_kill = []   # psutil.Process objects to kill in phase 3
    try:
        import psutil as _psutil
    except Exception:
        # psutil is optional; without it only the registered Popen objects
        # themselves are killed (phase 2).
        _psutil = None

    if _psutil is not None:
        for p in live:
            try:
                parent = _psutil.Process(p.pid)
                # Descendants first, then the parent itself.
                family = parent.children(recursive=True)
                family.append(parent)
            except Exception:
                continue  # NoSuchProcess or permission — already gone
            for member in family:
                try:
                    name = member.name().lower().split('.')[0]
                    if name in _HEAVY_PROC_NAMES:
                        heavy_to_kill.append(member)
                except Exception:
                    pass

    # ── Phase 2: kill the registered Popen objects ─────────────────────────────
    for p in live:
        try:
            if p.poll() is None:
                p.kill()
        except Exception:
            pass

    # ── Phase 3: kill any collected heavy-tool descendants ────────────────────
    # These were snapshotted before the parent kill so we have their handles
    # even if they were re-parented after the parent exited.
    for proc in heavy_to_kill:
        try:
            if proc.is_running():
                proc.kill()
        except Exception:
            pass


def _cleanup_job_procs(job_id: str):
    """Drop a job's entry from the process registry without killing anything.

    A job_id that is not registered is ignored.
    """
    _job_child_procs_lock.acquire()
    try:
        _job_child_procs.pop(job_id, None)
    finally:
        _job_child_procs_lock.release()


def kill_all_heavy_procs() -> int:
    """
    Kill every heavy-tool process found among this process's descendants.

    Used by the admin "Kill Stuck Jobs" endpoint. The operation is
    best-effort: if psutil is unavailable or the process tree cannot be read,
    nothing is killed, and a failure on one process does not stop the scan.

    Returns:
        The number of processes on which kill() succeeded.
    """
    try:
        import psutil
        descendants = psutil.Process().children(recursive=True)
    except Exception:
        return 0

    count = 0
    for proc in descendants:
        try:
            base_name = proc.name().lower().split('.')[0]
            if base_name not in _HEAVY_PROC_NAMES:
                continue
            if not proc.is_running():
                continue
            proc.kill()
        except Exception:
            # Process vanished or access was denied; move on to the next one.
            continue
        count += 1
    return count


# ── DB table helper ────────────────────────────────────────────────────────────

def _ensure_table():
    """Create the ``local_jobs`` table (and add ``started_at``) once per process.

    The first caller creates the table if it is missing and attempts to add
    the ``started_at`` column for databases created before that column
    existed. Later calls return immediately. ``_TABLE_ENSURED`` is checked
    both before and after taking ``_TABLE_LOCK`` so concurrent first calls run
    the DDL only once.

    Raises:
        sqlite3.Error: if the database cannot be opened or the table cannot
            be created; the flag stays False so a later call retries.
    """
    global _TABLE_ENSURED
    if _TABLE_ENSURED:
        return
    with _TABLE_LOCK:
        if _TABLE_ENSURED:
            return
        # sqlite3's own context manager only commits/rolls back; closing()
        # makes sure the connection itself is released as well.
        with contextlib.closing(sqlite3.connect(_DB_PATH)) as conn:
            with conn:
                conn.execute("""
                    CREATE TABLE IF NOT EXISTS local_jobs (
                        id TEXT PRIMARY KEY,
                        status TEXT NOT NULL DEFAULT 'queued',
                        result_json TEXT,
                        created_at TEXT NOT NULL DEFAULT (datetime('now')),
                        started_at TEXT,
                        completed_at TEXT
                    )
                """)
                try:
                    conn.execute("ALTER TABLE local_jobs ADD COLUMN started_at TEXT")
                except sqlite3.OperationalError:
                    # Expected when the column already exists.
                    pass
        _TABLE_ENSURED = True


# ── LocalJob ───────────────────────────────────────────────────────────────────

class LocalJob:
    """Read-only snapshot of one row of the ``local_jobs`` table.

    Instances hold the job id, the status string stored for it and the
    decoded ``result_json`` payload (or None). They are not refreshed
    automatically; call :meth:`fetch` again to observe a newer state.
    """

    def __init__(self, job_id, status, result=None):
        self.id = job_id
        self._status = status
        self._result = result

    @property
    def is_finished(self):
        """True when the stored status is 'finished'."""
        return self._status == 'finished'

    @property
    def is_failed(self):
        """True when the stored status is 'failed'."""
        return self._status == 'failed'

    @property
    def is_queued(self):
        """True when the stored status is 'queued'."""
        return self._status == 'queued'

    @property
    def is_started(self):
        """True when the stored status is 'started'."""
        return self._status == 'started'

    @property
    def result(self):
        """Decoded result payload, or None if no result was stored."""
        return self._result

    @classmethod
    def fetch(cls, job_id, **kwargs):
        """Load the job with the given id from the database.

        Args:
            job_id: Primary key of the row in ``local_jobs``.
            **kwargs: Accepted and ignored.

        Returns:
            A new LocalJob built from the stored row.

        Raises:
            Exception: if no row with that id exists.
        """
        _ensure_table()
        # closing() releases the connection; sqlite3's own context manager
        # would only end the transaction and leave the connection open.
        with contextlib.closing(sqlite3.connect(_DB_PATH)) as conn:
            row = conn.execute(
                "SELECT id, status, result_json FROM local_jobs WHERE id=?",
                (job_id,)
            ).fetchone()
        if row is None:
            raise Exception(f"LocalJob not found: {job_id}")
        stored_id, status, result_json = row
        result = json.loads(result_json) if result_json else None
        return cls(stored_id, status, result)


# ── LocalQueue ─────────────────────────────────────────────────────────────────

class LocalQueue:
    """
    Priority-aware local queue backed by Python's queue.PriorityQueue.
    Items placed with queue_name='high' are always processed before
    'medium', which are processed before 'low'.

    Job state is persisted in the ``local_jobs`` table of the SQLite file at
    ``_DB_PATH``. Rows of a ``conversions`` table in the same file are kept in
    sync on a best-effort basis: failures while updating that table are
    logged and otherwise ignored. Constructing an instance starts
    ``_NUM_WORKERS`` daemon threads and reconciles rows left over from a
    previous run.
    """

    _NUM_WORKERS = 2

    def __init__(self):
        self._pq = queue.PriorityQueue()
        # Monotonic sequence number used as a tie-breaker so that entries with
        # equal priority are dequeued in insertion order and the remaining
        # tuple members (callables, args) are never compared.
        self._counter = 0
        self._counter_lock = threading.Lock()
        self._workers = []
        for _ in range(self._NUM_WORKERS):
            t = threading.Thread(target=self._consume, daemon=True)
            t.start()
            self._workers.append(t)
        self._cleanup_orphaned_jobs()

    # ── helpers ────────────────────────────────────────────────────────────────

    @staticmethod
    @contextlib.contextmanager
    def _connection():
        """Yield a connection to ``_DB_PATH`` that is committed and closed.

        The transaction is committed when the ``with`` body finishes normally
        and rolled back when it raises; the connection is closed in both
        cases (sqlite3's own context manager does not close it).
        """
        conn = sqlite3.connect(_DB_PATH)
        try:
            with conn:
                yield conn
        finally:
            conn.close()

    @staticmethod
    def _format_timeout(timeout_seconds):
        """Render a timeout as text, e.g. 90 → '1 minute 30s', 45 → '45 seconds'."""
        mins, secs = divmod(timeout_seconds, 60)
        if not mins:
            return f"{timeout_seconds} seconds"
        text = f"{mins} minute{'s' if mins != 1 else ''}"
        if secs:
            text += f" {secs}s"
        return text

    @staticmethod
    def _normalize_result(raw):
        """Convert a job function's return value into the stored result shape.

        Args:
            raw: Whatever the job function returned. Dicts, lists and
                non-empty strings are understood; anything else counts as
                "no output".

        Returns:
            ``(result_data, conv_status)`` where ``result_data`` is the dict
            serialised into ``local_jobs.result_json`` and ``conv_status`` is
            'finished' or 'failed'.
        """
        if isinstance(raw, dict):
            if raw.get('error'):
                return {
                    'error': True,
                    'message': raw.get('message', 'Conversion failed'),
                    'results': []
                }, 'failed'
            paths = raw.get('results') or []
            if isinstance(paths, str):
                paths = [paths]
            result_data = {
                'error': False,
                'message': 'Conversion complete',
                'results': paths
            }
            if 'results2' in raw:
                result_data['results2'] = raw['results2']
            return result_data, 'finished'
        if isinstance(raw, list):
            return {
                'error': False,
                'message': 'Conversion complete',
                'results': raw
            }, 'finished'
        if isinstance(raw, str) and raw:
            return {
                'error': False,
                'message': 'Conversion complete',
                'results': [raw]
            }, 'finished'
        return {
            'error': True,
            'message': 'Conversion returned no output',
            'results': []
        }, 'failed'

    # ── startup reconciliation ─────────────────────────────────────────────────

    def _cleanup_orphaned_jobs(self):
        """
        At startup:
        1. Any local_jobs rows still 'queued'/'started' are orphaned (server
           restarted mid-job). Mark them 'failed' and sync conversions.
        2. Any conversions rows still 'pending'/'started' whose local_jobs
           row is already 'finished'/'failed' are stale. Sync them.

        Both steps run in one transaction. Any error is logged and swallowed
        so that a broken database never prevents the queue from starting.
        """
        try:
            _ensure_table()
            _restart_msg = json.dumps({
                'error': True,
                'message': 'Server was restarted — please use the Try Again button.',
                'results': []
            })
            with self._connection() as conn:
                self._fail_orphaned_jobs(conn, _restart_msg)
                self._heal_stale_conversions(conn)
        except Exception as exc:
            _log.error("LocalQueue startup orphan cleanup failed: %s", exc)

    @staticmethod
    def _fail_orphaned_jobs(conn, restart_msg):
        """Mark unfinished local_jobs rows as failed and mirror that to conversions."""
        cur = conn.execute(
            "SELECT id FROM local_jobs WHERE status IN ('queued', 'started')"
        )
        orphan_ids = [row[0] for row in cur.fetchall()]
        if not orphan_ids:
            return
        _log.warning(
            "LocalQueue startup: marking %d orphaned job(s) as failed: %s",
            len(orphan_ids), orphan_ids
        )
        placeholders = ','.join('?' * len(orphan_ids))
        conn.execute(
            f"UPDATE local_jobs SET status='failed', result_json=?, "
            f"completed_at=datetime('now') WHERE id IN ({placeholders})",
            [restart_msg] + orphan_ids
        )
        try:
            # COALESCE makes a NULL error_message count as empty too, in case
            # the column is nullable; a plain `error_message=''` test would
            # never match NULL and the restart reason would be lost.
            conn.execute(
                f"UPDATE conversions SET status='failed', "
                f"error_message=CASE WHEN COALESCE(error_message, '')='' THEN 'Server was restarted' ELSE error_message END, "
                f"completed_at=COALESCE(completed_at, datetime('now')) "
                f"WHERE job_id IN ({placeholders}) AND status IN ('pending','started')",
                orphan_ids
            )
        except Exception as _e:
            _log.warning("LocalQueue startup: could not sync conversions table: %s", _e)

    @staticmethod
    def _heal_stale_conversions(conn):
        """Bring conversions rows in line with local_jobs rows that already ended."""
        try:
            stale = conn.execute("""
                SELECT c.id, lj.status AS lj_status, lj.result_json
                FROM conversions c
                JOIN local_jobs lj ON lj.id = c.job_id
                WHERE c.status IN ('pending', 'started')
                  AND lj.status IN ('finished', 'failed')
            """).fetchall()
            if not stale:
                return
            _log.warning(
                "LocalQueue startup: healing %d stale conversion(s) whose "
                "jobs already completed", len(stale)
            )
            for conv_id, lj_status, result_json in stale:
                try:
                    result = json.loads(result_json) if result_json else {}
                except Exception:
                    result = {}
                if not isinstance(result, dict):
                    # Unexpected payload shape; treat it as "no details".
                    result = {}
                has_error = result.get('error', False)
                new_status = 'failed' if (has_error or lj_status == 'failed') else 'finished'
                err_msg = result.get('message', '') if has_error else ''
                conn.execute(
                    "UPDATE conversions SET status=?, "
                    "error_message=CASE WHEN ? != '' THEN ? ELSE error_message END, "
                    "completed_at=COALESCE(completed_at, datetime('now')) "
                    "WHERE id=?",
                    (new_status, err_msg, err_msg, conv_id)
                )
        except Exception as _e:
            _log.warning("LocalQueue startup: stale conversions heal failed: %s", _e)

    # ── worker loop ────────────────────────────────────────────────────────────

    def _consume(self):
        """Worker thread body: pull jobs off the priority queue forever."""
        # Imported locally so the Empty lookup does not depend on the
        # module-level name `queue`.
        import queue as _queue
        while True:
            try:
                _, _, job_id, func, args, timeout_seconds = self._pq.get(timeout=5)
                self._run_job(job_id, func, args, timeout_seconds)
            except _queue.Empty:
                continue
            except Exception:
                # Never let one bad job terminate the worker thread.
                _log.exception("LocalQueue._consume: unexpected error")

    def _run_job(self, job_id, func, args, timeout_seconds=0):
        """Execute one job and persist its outcome.

        Args:
            job_id: Id of the ``local_jobs`` row created by enqueue_call().
            func: Callable to run.
            args: Positional arguments for ``func``.
            timeout_seconds: When truthy and positive, ``func`` runs in a
                helper thread and is abandoned after this many seconds; its
                tracked subprocesses are then killed. The helper thread
                itself is not interrupted.

        Outcomes written to ``local_jobs``: 'finished' with a result payload
        when ``func`` returns (including error payloads and timeouts), or
        'failed' when ``func`` raises.
        """
        # Set thread-local job_id in the *queue worker* thread.
        _tl.current_job_id = job_id
        try:
            with self._connection() as conn:
                conn.execute(
                    "UPDATE local_jobs SET status='started', started_at=datetime('now') WHERE id=?",
                    (job_id,)
                )
                try:
                    conn.execute(
                        "UPDATE conversions SET status='started' WHERE job_id=? AND status='pending'",
                        (job_id,)
                    )
                except Exception:
                    _log.debug("LocalQueue: could not mark conversions started for job %s",
                               job_id, exc_info=True)

            # ── run with optional timeout ──────────────────────────────────
            timed_out = False
            if timeout_seconds and timeout_seconds > 0:
                # Wrap func so the executor thread also carries the job_id in
                # its thread-local, enabling subprocess tracking.
                def _wrapped(*a):
                    _tl.current_job_id = job_id
                    return func(*a)

                _ex = concurrent.futures.ThreadPoolExecutor(max_workers=1)
                _future = _ex.submit(_wrapped, *args)
                # Shut down the executor without waiting — the thread may run
                # past the timeout and will eventually finish on its own.
                _ex.shutdown(wait=False)
                # wait() reports the deadline through its return value rather
                # than an exception, so a TimeoutError raised by func itself
                # is not mistaken for the job exceeding its time limit.
                done, _ = concurrent.futures.wait([_future], timeout=timeout_seconds)
                if done:
                    raw = _future.result()
                else:
                    timed_out = True
                    raw = None
            else:
                raw = func(*args)

            if timed_out:
                _kill_job_procs(job_id)
                time_str = self._format_timeout(timeout_seconds)
                timeout_msg = (
                    f"Your task took more than {time_str} to process. "
                    "Get Online-Convert Premium to increase your processing time per task. "
                    "Processing tasks for a long time uses a lot of server resources "
                    "and is only available for paying users."
                )
                result_data = {
                    'error': True,
                    'timeout_exceeded': True,
                    'message': timeout_msg,
                    'results': []
                }
                with self._connection() as conn:
                    conn.execute(
                        "UPDATE local_jobs SET status='finished', result_json=?, "
                        "completed_at=datetime('now') WHERE id=?",
                        (json.dumps(result_data), job_id)
                    )
                    try:
                        conn.execute(
                            "UPDATE conversions SET status='failed', error_message=?, "
                            "completed_at=datetime('now') "
                            "WHERE job_id=? AND status IN ('pending', 'started')",
                            (timeout_msg, job_id)
                        )
                    except Exception:
                        _log.debug("LocalQueue: could not record timeout in conversions for job %s",
                                   job_id, exc_info=True)
                return

            # ── process result ─────────────────────────────────────────────
            result_data, _conv_status = self._normalize_result(raw)
            _conv_output_path = (
                result_data['results'][0]
                if _conv_status == 'finished' and result_data.get('results')
                else None
            )
            _err_msg = result_data.get('message', '') if _conv_status == 'failed' else ''

            with self._connection() as conn:
                conn.execute(
                    "UPDATE local_jobs SET status='finished', result_json=?, "
                    "completed_at=datetime('now') WHERE id=?",
                    (json.dumps(result_data), job_id)
                )
                try:
                    conn.execute(
                        "UPDATE conversions SET status=?, output_path=COALESCE(?, output_path), "
                        "error_message=CASE WHEN ? != '' THEN ? ELSE error_message END, "
                        "completed_at=datetime('now') "
                        "WHERE job_id=? AND status IN ('pending', 'started')",
                        (_conv_status, _conv_output_path, _err_msg, _err_msg, job_id)
                    )
                except Exception:
                    _log.debug("LocalQueue: could not record result in conversions for job %s",
                               job_id, exc_info=True)

            _cleanup_job_procs(job_id)

        except Exception as e:
            # Regular (non-timeout) failures: subprocesses have already completed
            # (communicate() blocks until they exit), so no kill is needed.
            # Per task spec, killing on non-timeout exceptions is out of scope.
            _cleanup_job_procs(job_id)
            err_text = f'Conversion error: {str(e)}'
            result_data = {
                'error': True,
                'message': err_text,
                'results': []
            }
            with self._connection() as conn:
                conn.execute(
                    "UPDATE local_jobs SET status='failed', result_json=?, "
                    "completed_at=datetime('now') WHERE id=?",
                    (json.dumps(result_data), job_id)
                )
                try:
                    conn.execute(
                        "UPDATE conversions SET status='failed', error_message=?, "
                        "completed_at=datetime('now') "
                        "WHERE job_id=? AND status IN ('pending', 'started')",
                        (err_text, job_id)
                    )
                except Exception:
                    _log.debug("LocalQueue: could not record failure in conversions for job %s",
                               job_id, exc_info=True)
        finally:
            # The worker thread is reused; never leak a job id into the next job.
            _tl.current_job_id = None

    # ── public API ─────────────────────────────────────────────────────────────

    def enqueue_call(self, func, args=(), timeout=0, result_ttl=5000, queue_name='low', **kwargs):
        """Register a job in the database and schedule it for execution.

        Args:
            func: Callable to run in a worker thread.
            args: Positional arguments for ``func``.
            timeout: Time limit in seconds; a falsy or non-positive value
                disables the limit.
            result_ttl: Accepted and ignored.
            queue_name: 'high', 'medium' or 'low'; unknown names are treated
                as 'low'.
            **kwargs: Accepted and ignored.

        Returns:
            A handle whose ``id`` attribute is the new job's id.
        """
        _ensure_table()
        job_id = str(uuid_mod.uuid4())
        priority = _PRIORITY_MAP.get(queue_name, 2)

        # Persist the row first so the job is visible to LocalJob.fetch()
        # before any worker can pick it up.
        with self._connection() as conn:
            conn.execute(
                "INSERT INTO local_jobs (id, status) VALUES (?, 'queued')",
                (job_id,)
            )

        with self._counter_lock:
            seq = self._counter
            self._counter += 1

        self._pq.put((priority, seq, job_id, func, args, timeout))
        return _LocalJobHandle(job_id)


class _LocalJobHandle:
    def __init__(self, job_id):
        self.id = job_id


# Module-level singleton. Constructing it here means that importing this
# module starts the worker threads and runs the startup orphan cleanup.
_local_queue = LocalQueue()
