import os
import uuid
import json
import struct
import io
import tempfile

from PIL import Image, ImageFilter, ImageEnhance, ImageOps, ImageDraw


# File extensions (lower-case, without the dot) accepted as input and
# offered as output. These sets are not referenced by the conversion code in
# this module; presumably they are read by callers — verify before changing.
SUPPORTED_READ = set(
    "bmp eps gif ico jpg jpeg png tga tiff wbmp webp pdf exr svg".split()
)

# Same extensions as SUPPORTED_READ, held in a separate set object.
SUPPORTED_WRITE = set(SUPPORTED_READ)

# Extension -> Pillow format name. Every format name is the upper-cased
# extension except "jpg", which Pillow only knows as "JPEG".
FORMAT_MAP = {
    ext: "JPEG" if ext == "jpg" else ext.upper()
    for ext in "jpg jpeg png bmp gif ico tga tiff webp pdf eps wbmp exr svg".split()
}


def _parse_url_entry(url_entry):
    if isinstance(url_entry, str):
        try:
            parsed = json.loads(url_entry)
            if isinstance(parsed, dict):
                return parsed.get("path", "")
        except (json.JSONDecodeError, TypeError):
            return url_entry
    elif isinstance(url_entry, dict):
        return url_entry.get("path", "")
    return ""


def _load_image(file_path, upload_dir):
    full_path = os.path.join(upload_dir, file_path)
    if not os.path.exists(full_path):
        full_path = file_path
    if not os.path.exists(full_path):
        return None, None, 1

    ext = os.path.splitext(full_path)[1].lower().lstrip(".")

    if ext == "svg":
        return [_load_svg(full_path)], full_path, 1

    if ext == "exr":
        return [_load_exr(full_path)], full_path, 1

    if ext == "wbmp":
        return [_load_wbmp(full_path)], full_path, 1

    # Special handling for PDF files
    if ext == "pdf":
        images, page_count = _load_pdf_pages_debug(full_path)
        return images, full_path, page_count

    try:
        img = Image.open(full_path)
        img.load()
        return [img], full_path, 1
    except Exception as e:
        raise ValueError(f"Cannot open image: {e}")


def _load_pdf_pages_debug(file_path):
    """Load every page of a PDF as a PIL image, trying several backends in turn.

    Backends are tried in this order; the first one that yields at least one
    page wins:

    1. pdf2image at 150 DPI.
    2. PyPDF2 page count, then pdf2image at 150, 200 and 300 DPI, until the
       number of rendered pages matches that count.
    3. Wand (ImageMagick) at resolution 150.
    4. pdfplumber page count, then pdf2image at 200 DPI.
    5. PyMuPDF (fitz) with a 2x zoom matrix.

    A backend whose package is not installed is skipped silently; any other
    error is logged as a warning before moving on. Each backend collects
    pages in its own list, so pages from a backend that failed part-way are
    never mixed into a later backend's result.

    If every backend fails, white 800x600 placeholder pages labelled
    "PDF Rendering Failed" are returned. There is one per page as counted by
    PyPDF2, or 30 when the page count cannot be read.

    Returns:
        ``(images, page_count)``: a list of PIL images and the page count
        reported by the backend that produced them.
    """
    import logging
    _log = logging.getLogger(__name__)

    _log.debug("Loading PDF: %s (exists=%s)", file_path, os.path.exists(file_path))

    # Method 1: pdf2image (most reliable)
    try:
        from pdf2image import convert_from_path
        images = convert_from_path(file_path, dpi=150, fmt='png')
        if images:
            _log.debug("pdf2image loaded %d pages", len(images))
            return images, len(images)
    except ImportError:
        pass
    except Exception as e:
        _log.warning("pdf2image error: %s", e)

    # Method 2: PyPDF2 page count + pdf2image with multiple DPI attempts
    try:
        import PyPDF2
        from pdf2image import convert_from_path
        with open(file_path, 'rb') as f:
            page_count = len(PyPDF2.PdfReader(f).pages)
        if page_count > 0:
            for dpi in (150, 200, 300):
                try:
                    images = convert_from_path(file_path, dpi=dpi, fmt='png',
                                               first_page=1, last_page=page_count)
                except Exception as e:
                    _log.debug("pdf2image at %d DPI failed: %s", dpi, e)
                    continue
                # Accept only a complete render; otherwise retry at the next DPI.
                if len(images) == page_count:
                    _log.debug("pdf2image+PyPDF2 converted %d pages at %d DPI", page_count, dpi)
                    return images, page_count
    except ImportError:
        pass
    except Exception as e:
        _log.warning("PyPDF2 error: %s", e)

    # Method 3: wand (ImageMagick)
    try:
        from wand.image import Image as WandImage
        wand_pages = []
        with WandImage(filename=file_path, resolution=150) as wand_img:
            page_count = len(wand_img.sequence)
            for page in wand_img.sequence:
                with WandImage(page) as page_img:
                    page_img.format = 'png'
                    wand_pages.append(Image.open(io.BytesIO(page_img.make_blob())))
        if wand_pages:
            _log.debug("Wand loaded %d pages", page_count)
            return wand_pages, page_count
    except ImportError:
        pass
    except Exception as e:
        _log.warning("Wand error: %s", e)

    # Method 4: pdfplumber page count + pdf2image
    try:
        import pdfplumber
        from pdf2image import convert_from_path
        with pdfplumber.open(file_path) as pdf:
            page_count = len(pdf.pages)
        if page_count > 0:
            images = convert_from_path(file_path, dpi=200, fmt='png')
            if images:
                _log.debug("pdfplumber loaded %d pages", page_count)
                return images, page_count
    except ImportError:
        pass
    except Exception as e:
        _log.warning("pdfplumber error: %s", e)

    # Method 5: PyMuPDF (fitz)
    try:
        import fitz
        fitz_pages = []
        doc = fitz.open(file_path)
        try:
            page_count = len(doc)
            zoom = fitz.Matrix(2.0, 2.0)
            for i in range(page_count):
                pix = doc[i].get_pixmap(matrix=zoom, alpha=False)
                fitz_pages.append(Image.open(io.BytesIO(pix.tobytes("ppm"))))
        finally:
            # Release the document even if a page fails to render.
            doc.close()
        if fitz_pages:
            _log.debug("PyMuPDF loaded %d pages", page_count)
            return fitz_pages, page_count
    except ImportError:
        pass
    except Exception as e:
        _log.warning("PyMuPDF error: %s", e)

    # Fallback: placeholder images when all methods fail
    _log.warning("All PDF loading methods failed for %s; creating placeholders", file_path)
    try:
        import PyPDF2
        with open(file_path, 'rb') as f:
            page_count = len(PyPDF2.PdfReader(f).pages)
    except Exception:
        # Page count unknown. NOTE(review): 30 is an arbitrary placeholder count.
        page_count = 30

    placeholders = []
    for i in range(page_count):
        img = Image.new("RGB", (800, 600), (255, 255, 255))
        ImageDraw.Draw(img).text((100, 300), f"Page {i+1} - PDF Rendering Failed", fill=(0, 0, 0))
        placeholders.append(img)

    return placeholders, page_count


def _load_svg(file_path):
    """Rasterise an SVG file into a PIL image, degrading gracefully.

    Options are tried in this order:

    1. cairosvg renders the file to PNG, which Pillow then opens. Rendering
       errors other than ImportError propagate to the caller.
    2. Pillow opens the file directly. This works only if the installed
       Pillow can read it.
    3. A blank, fully transparent 800x600 RGBA image is returned.
    """
    try:
        import cairosvg
    except (ImportError, OSError):
        # ImportError: cairosvg is not installed. OSError: cairosvg is installed
        # but the native cairo library it loads at import time is missing.
        # Either way, fall through to the next option instead of crashing.
        cairosvg = None

    if cairosvg is not None:
        try:
            return Image.open(io.BytesIO(cairosvg.svg2png(url=file_path)))
        except ImportError:
            pass  # a lazily imported optional dependency is missing

    try:
        img = Image.open(file_path)
        img.load()
        return img
    except Exception:
        pass  # Pillow cannot decode this file; use the blank canvas below

    # Last resort: an empty transparent canvas so the conversion can proceed.
    return Image.new("RGBA", (800, 600), (255, 255, 255, 0))


def _load_exr(file_path):
    """Load the R, G and B channels of an OpenEXR file as an 8-bit RGB image.

    Channels are read as 32-bit floats, multiplied by 255 and clipped to
    0-255. No tone mapping or colour-space conversion is applied. NaN
    samples become 0, and +inf/-inf become 1.0/0.0 before scaling, so the
    uint8 cast is well defined.

    If OpenEXR, Imath or numpy is not installed, a solid grey 800x600 RGB
    placeholder is returned instead.
    """
    try:
        import OpenEXR
        import Imath
        import numpy as np
    except ImportError:
        return Image.new("RGB", (800, 600), (128, 128, 128))

    exr_file = OpenEXR.InputFile(file_path)
    try:
        header = exr_file.header()
        dw = header['dataWindow']
        w = dw.max.x - dw.min.x + 1
        h = dw.max.y - dw.min.y + 1
        pt = Imath.PixelType(Imath.PixelType.FLOAT)
        r, g, b = (
            np.frombuffer(raw, dtype=np.float32).reshape(h, w)
            for raw in exr_file.channels(["R", "G", "B"], pt)
        )
    finally:
        # The channel arrays own their byte buffers, so closing is safe here.
        exr_file.close()

    rgb = np.stack([r, g, b], axis=-1)
    rgb = np.nan_to_num(rgb, nan=0.0, posinf=1.0, neginf=0.0)
    rgb = np.clip(rgb * 255, 0, 255).astype(np.uint8)
    return Image.fromarray(rgb, "RGB")


def _load_wbmp(file_path):
    with open(file_path, "rb") as f:
        data = f.read()

    if len(data) < 4:
        raise ValueError("Invalid WBMP file")

    idx = 0
    _type = data[idx]
    idx += 1
    _fix = data[idx]
    idx += 1

    width = 0
    while idx < len(data):
        b = data[idx]
        idx += 1
        width = (width << 7) | (b & 0x7F)
        if not (b & 0x80):
            break

    height = 0
    while idx < len(data):
        b = data[idx]
        idx += 1
        height = (height << 7) | (b & 0x7F)
        if not (b & 0x80):
            break

    if width <= 0 or height <= 0 or width > 10000 or height > 10000:
        raise ValueError(f"Invalid WBMP dimensions: {width}x{height}")

    img = Image.new("1", (width, height), 0)
    pixels = img.load()
    row_bytes = (width + 7) // 8

    for y in range(height):
        for x in range(width):
            byte_idx = idx + y * row_bytes + (x // 8)
            if byte_idx < len(data):
                bit = (data[byte_idx] >> (7 - (x % 8))) & 1
                pixels[x, y] = bit

    return img


def _apply_options(img, options):
    if not options or not isinstance(options, dict):
        return img

    if img.mode == "P":
        img = img.convert("RGBA")

    change_size = options.get("change_size", {})
    if change_size:
        w = change_size.get("width")
        h = change_size.get("height")
        if w and h and isinstance(w, int) and isinstance(h, int):
            if 1 <= w <= 20000 and 1 <= h <= 20000:
                img = img.resize((w, h), Image.LANCZOS)
        elif w and isinstance(w, int) and 1 <= w <= 20000:
            ratio = w / img.width
            new_h = max(1, int(img.height * ratio))
            img = img.resize((w, new_h), Image.LANCZOS)
        elif h and isinstance(h, int) and 1 <= h <= 20000:
            ratio = h / img.height
            new_w = max(1, int(img.width * ratio))
            img = img.resize((new_w, h), Image.LANCZOS)

    color = options.get("color", "")
    if color:
        color_lower = color.lower().strip()
        if color_lower == "gray":
            img = ImageOps.grayscale(img)
            img = img.convert("RGB")
        elif color_lower == "monochrome":
            threshold = options.get("blackAndWhiteTreshold", 128)
            gray = ImageOps.grayscale(img)
            img = gray.point(lambda p: 255 if p > threshold else 0, "1")
            img = img.convert("RGB")
        elif color_lower == "negate":
            if img.mode == "RGBA":
                r, g, b, a = img.split()
                rgb = Image.merge("RGB", (r, g, b))
                rgb = ImageOps.invert(rgb)
                r2, g2, b2 = rgb.split()
                img = Image.merge("RGBA", (r2, g2, b2, a))
            else:
                if img.mode != "RGB":
                    img = img.convert("RGB")
                img = ImageOps.invert(img)
        elif color_lower == "year 1980":
            if img.mode != "RGB":
                img = img.convert("RGB")
            img = ImageOps.grayscale(img)
            img = ImageOps.colorize(img, black="#2b1810", white="#e8d5a3", mid="#8b7355")
        elif color_lower == "year 1900":
            if img.mode != "RGB":
                img = img.convert("RGB")
            img = ImageOps.grayscale(img)
            img = ImageOps.colorize(img, black="#1a1410", white="#d4c5a0", mid="#6b5e4a")
            enhancer = ImageEnhance.Contrast(img)
            img = enhancer.enhance(0.85)

    crop = options.get("cropFrom", {})
    if crop and any(crop.get(k) for k in ["top", "bottom", "left", "right"]):
        top = max(0, int(crop.get("top", 0) or 0))
        bottom = max(0, int(crop.get("bottom", 0) or 0))
        left = max(0, int(crop.get("left", 0) or 0))
        right = max(0, int(crop.get("right", 0) or 0))

        new_left = left
        new_top = top
        new_right = max(new_left + 1, img.width - right)
        new_bottom = max(new_top + 1, img.height - bottom)

        if new_right > new_left and new_bottom > new_top:
            img = img.crop((new_left, new_top, new_right, new_bottom))

    enhancements = options.get("enhancement", [])
    if enhancements and isinstance(enhancements, list):
        for effect in enhancements:
            effect_lower = effect.lower().strip()
            if effect_lower == "deskew":
                try:
                    img = ImageOps.exif_transpose(img)
                except Exception:
                    pass
            elif effect_lower == "equalize":
                if img.mode == "RGBA":
                    r, g, b, a = img.split()
                    rgb = Image.merge("RGB", (r, g, b))
                    rgb = ImageOps.equalize(rgb)
                    r2, g2, b2 = rgb.split()
                    img = Image.merge("RGBA", (r2, g2, b2, a))
                else:
                    if img.mode not in ("RGB", "L"):
                        img = img.convert("RGB")
                    img = ImageOps.equalize(img)
            elif effect_lower == "normalize":
                img = ImageOps.autocontrast(img, cutoff=0.5)
            elif effect_lower == "enhance":
                enhancer = ImageEnhance.Sharpness(img)
                img = enhancer.enhance(1.5)
                enhancer = ImageEnhance.Contrast(img)
                img = enhancer.enhance(1.2)
            elif effect_lower == "sharpen":
                img = img.filter(ImageFilter.SHARPEN)
            elif effect_lower == "no antialias":
                pass
            elif effect_lower == "despeckle":
                img = img.filter(ImageFilter.MedianFilter(size=3))

    dpi_val = options.get("dpi")
    if dpi_val and isinstance(dpi_val, int) and 10 <= dpi_val <= 1200:
        img.info["dpi"] = (dpi_val, dpi_val)

    bw_threshold = options.get("blackAndWhiteTreshold")
    if bw_threshold and isinstance(bw_threshold, int) and not color:
        gray = ImageOps.grayscale(img)
        img = gray.point(lambda p: 255 if p > bw_threshold else 0, "1")
        img = img.convert("RGB")

    return img


def _has_transparency(img):
    """Return True if ``img`` has an alpha band or palette transparency."""
    return img.mode in ("RGBA", "LA", "PA") or (
        img.mode == "P" and "transparency" in img.info
    )


def _flatten_onto_white(img):
    """Composite ``img`` over an opaque white canvas and return an RGB image.

    ``img`` is first converted to RGBA. Its alpha (an LA/PA/RGBA band or
    palette transparency) becomes the paste mask, so transparent pixels turn
    white instead of exposing whatever colour is stored underneath.
    """
    if img.mode != "RGBA":
        img = img.convert("RGBA")
    background = Image.new("RGB", img.size, (255, 255, 255))
    background.paste(img, mask=img.getchannel("A"))
    return background


def _save_image(img, output_path, target_format, options=None, page_num=None):
    """Write ``img`` to ``output_path`` in ``target_format``.

    The image is first adapted to a mode the target format supports.
    Formats without an alpha channel (JPEG, BMP, PDF) have any transparency
    composited onto white. WBMP, EXR and SVG are written by this module's own
    writers; every other format goes through ``Image.save``.

    Args:
        img: PIL image to write.
        output_path: Destination file path.
        target_format: Target extension such as ``"jpg"``. It is mapped
            through ``FORMAT_MAP``, or upper-cased and passed to Pillow as-is.
        options: Optional option dict. Only ``"dpi"`` is read here (an int in
            10..1200); otherwise any DPI already in ``img.info`` is kept.
        page_num: Unused; accepted for call-site compatibility.
    """
    pil_format = FORMAT_MAP.get(target_format.lower(), target_format.upper())

    dpi = None
    if options and options.get("dpi"):
        dpi_val = options["dpi"]
        if isinstance(dpi_val, int) and 10 <= dpi_val <= 1200:
            dpi = (dpi_val, dpi_val)
    if not dpi and "dpi" in img.info:
        dpi = img.info["dpi"]

    save_kwargs = {}

    if pil_format == "JPEG":
        if img.mode in ("RGBA", "LA", "PA", "P"):
            img = _flatten_onto_white(img)
        elif img.mode not in ("RGB", "L"):
            img = img.convert("RGB")
        save_kwargs["quality"] = 95
        save_kwargs["optimize"] = True
        if dpi:
            save_kwargs["dpi"] = dpi

    elif pil_format == "PNG":
        if img.mode == "1":
            img = img.convert("RGBA")
        save_kwargs["optimize"] = True
        if dpi:
            save_kwargs["dpi"] = dpi

    elif pil_format == "BMP":
        if _has_transparency(img):
            img = _flatten_onto_white(img)
        elif img.mode not in ("RGB", "L", "1"):
            img = img.convert("RGB")
        if dpi:
            save_kwargs["dpi"] = dpi

    elif pil_format == "GIF":
        if img.mode == "RGBA":
            # GIF output drops transparency: flatten onto opaque white first.
            background = Image.new("RGBA", img.size, (255, 255, 255, 255))
            background.paste(img, mask=img.split()[-1])
            img = background
        img = img.convert("P", palette=Image.ADAPTIVE, colors=256)

    elif pil_format == "ICO":
        # Embed every standard icon size that fits inside the image.
        w, h = img.size
        sizes = [(s, s) for s in (256, 128, 64, 48, 32, 16) if w >= s and h >= s]
        if not sizes:
            sizes = [(min(w, 32), min(h, 32))]
        save_kwargs["sizes"] = sizes
        if img.mode != "RGBA":
            img = img.convert("RGBA")

    elif pil_format == "TIFF":
        if img.mode not in ("RGB", "RGBA", "L", "1", "CMYK"):
            img = img.convert("RGB")
        save_kwargs["compression"] = "tiff_lzw"
        if dpi:
            save_kwargs["dpi"] = dpi

    elif pil_format == "WEBP":
        save_kwargs["quality"] = 90
        save_kwargs["method"] = 4
        if dpi:
            save_kwargs["dpi"] = dpi

    elif pil_format == "TGA":
        if img.mode not in ("RGB", "RGBA", "L"):
            img = img.convert("RGB")
        save_kwargs["compression"] = "tga_rle"

    elif pil_format == "EPS":
        if img.mode not in ("RGB", "L", "CMYK"):
            img = img.convert("RGB")
        if dpi:
            save_kwargs["dpi"] = dpi

    elif pil_format == "PDF":
        if _has_transparency(img):
            img = _flatten_onto_white(img)
        elif img.mode not in ("RGB", "L", "1", "CMYK"):
            img = img.convert("RGB")
        save_kwargs["resolution"] = dpi[0] if dpi else 150

    elif pil_format == "WBMP":
        _save_wbmp(img, output_path)
        return

    elif pil_format == "EXR":
        _save_exr(img, output_path)
        return

    elif pil_format == "SVG":
        _save_svg(img, output_path)
        return

    img.save(output_path, format=pil_format, **save_kwargs)


def _save_wbmp(img, output_path):
    """Write ``img`` as an uncompressed type-0 WBMP file.

    The image is converted to mode "1" with ``Image.convert``'s default
    settings. The file is a TypeField (0), a FixHeaderField (0), width and
    height as WBMP multi-byte integers, then the bitmap.

    The bitmap has one row per scanline. Pixels are packed 8 per byte,
    most-significant bit first, with a non-zero pixel stored as 1, and each
    row is zero-padded to a whole byte.
    """
    img = img.convert("1")
    width, height = img.size

    def encode_multibyte(val):
        # 7 bits per byte, most significant group first; every byte except
        # the last carries the 0x80 continuation bit.
        result = bytearray()
        result.append(val & 0x7F)
        val >>= 7
        while val > 0:
            result.append((val & 0x7F) | 0x80)
            val >>= 7
        result.reverse()
        return result

    header = bytearray(b"\x00\x00")  # TypeField 0, FixHeaderField 0
    header += encode_multibyte(width)
    header += encode_multibyte(height)

    # Pillow's raw packer for mode "1" emits exactly the WBMP bitmap layout
    # (MSB-first, non-zero -> 1, rows padded to a byte boundary). It does so
    # in C rather than with a per-pixel Python loop.
    bitmap = img.tobytes()

    with open(output_path, "wb") as f:
        f.write(bytes(header) + bitmap)


def _save_exr(img, output_path):
    """Write ``img`` as an OpenEXR file with 32-bit float R, G, B channels.

    The image is converted to RGB, so any alpha is dropped. 8-bit values are
    divided by 255 with no colour-space conversion.

    If OpenEXR, Imath or numpy is not installed, a PNG is written instead.
    When ``output_path`` ends in ``.exr``, only that final extension is
    replaced by ``.png``; otherwise the PNG is written to ``output_path``
    itself.

    Returns:
        The path of the file actually written.
    """
    try:
        import OpenEXR
        import Imath
        import numpy as np
    except ImportError:
        if img.mode != "RGB":
            img = img.convert("RGB")
        # Swap only the file extension; a plain str.replace would also rewrite
        # any ".exr" appearing in directory names.
        root, ext = os.path.splitext(output_path)
        png_path = root + ".png" if ext.lower() == ".exr" else output_path
        img.save(png_path, format="PNG")
        return png_path

    if img.mode != "RGB":
        img = img.convert("RGB")
    arr = np.array(img).astype(np.float32) / 255.0
    h, w, _ = arr.shape

    header = OpenEXR.Header(w, h)
    float_channel = Imath.Channel(Imath.PixelType(Imath.PixelType.FLOAT))
    header["channels"] = {"R": float_channel, "G": float_channel, "B": float_channel}

    out = OpenEXR.OutputFile(output_path, header)
    try:
        out.writePixels({
            "R": arr[:, :, 0].tobytes(),
            "G": arr[:, :, 1].tobytes(),
            "B": arr[:, :, 2].tobytes(),
        })
    finally:
        out.close()
    return output_path


def _save_svg(img, output_path):
    """Wrap ``img`` in an SVG document as an embedded, base64-encoded PNG.

    The image is converted to RGBA, PNG-encoded in memory and referenced from
    a single ``<image>`` element. That element and the root ``<svg>`` element
    are both sized to the image's pixel dimensions. The file is written as
    UTF-8 text.
    """
    import base64

    rgba = img if img.mode == "RGBA" else img.convert("RGBA")

    png_buffer = io.BytesIO()
    rgba.save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue()).decode("ascii")

    width, height = rgba.size
    document_lines = [
        '<?xml version="1.0" encoding="UTF-8"?>',
        '<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink"',
        f'     width="{width}" height="{height}" viewBox="0 0 {width} {height}">',
        f'  <image width="{width}" height="{height}" xlink:href="data:image/png;base64,{encoded}"/>',
        '</svg>',
    ]

    with open(output_path, "w", encoding="utf-8") as handle:
        handle.write("\n".join(document_lines))


def convert(urls, target_format, options, config):
    """Convert every file referenced in ``urls`` to ``target_format``.

    If ``options["ocr"]`` is truthy, the whole request is delegated to
    ``ocr_converter.convert``. Its output format comes from
    ``options["ocr_output"]`` (txt/docx/pdf, default txt).

    Otherwise each entry is:

    1. resolved with ``_parse_url_entry``;
    2. loaded, possibly as several pages;
    3. processed with ``_apply_options``;
    4. saved into its own random sub-folder of ``config["UPLOAD_DIR"]``
       (default ``"static/uploads"``).

    Pages are suffixed ``_page_NNNN`` whenever more than one page is written
    for a file. Failures are collected per page and per file rather than
    aborting the request.

    Returns:
        ``{"error": True, "message": str}`` on total failure or when no usable
        entries were given. Otherwise ``{"error": False, "results": [paths],
        "output_path": first path, "output_count": int, "errors": [str] or
        None}``.
    """
    import traceback

    try:
        upload_dir = config.get("UPLOAD_DIR", "static/uploads")
        target_format_lower = target_format.lower()

        # OCR — triggered by the "ocr" checkbox in the options panel
        if options and isinstance(options, dict) and options.get("ocr"):
            from converters import ocr_converter
            ocr_fmt = (options.get("ocr_output") or "txt").lower().strip()
            if ocr_fmt not in ("txt", "docx", "pdf"):
                ocr_fmt = "txt"
            return ocr_converter.convert(urls, ocr_fmt, options, config)

        if target_format_lower not in FORMAT_MAP:
            return {"error": True, "message": f"Unsupported target format: {target_format}"}

        results = []
        errors = []

        for url_entry in urls:
            file_path = _parse_url_entry(url_entry)
            if not file_path or file_path == "empty":
                continue

            try:
                images, _source_path, page_count = _load_image(file_path, upload_dir)

                if not images:
                    errors.append(f"File not found or empty: {os.path.basename(file_path)}")
                    continue

                source_name = os.path.splitext(os.path.basename(file_path))[0]

                # A fresh random folder per input file keeps outputs from
                # different uploads with the same name apart.
                output_dir = os.path.join(upload_dir, uuid.uuid4().hex)
                os.makedirs(output_dir, exist_ok=True)

                # Number the pages whenever more than one image will be written,
                # so pages can never overwrite each other.
                multi_page = page_count > 1 or len(images) > 1

                file_results = []

                for page_num, img in enumerate(images, 1):
                    processed_img = None
                    try:
                        processed_img = _apply_options(img, options)

                        if multi_page:
                            # Zero-padded so lexical order equals page order.
                            output_filename = f"{source_name}_page_{page_num:04d}.{target_format_lower}"
                        else:
                            output_filename = f"{source_name}.{target_format_lower}"

                        output_path = os.path.join(output_dir, output_filename)
                        _save_image(processed_img, output_path, target_format_lower, options, page_num)

                        # Never report a path that was not written. For example,
                        # _save_exr can fall back to writing a .png next to it.
                        if not os.path.isfile(output_path):
                            raise OSError(f"output file {output_filename} was not created")

                        file_results.append(output_path)

                    except Exception as e:
                        errors.append(f"Failed to convert page {page_num} of {os.path.basename(file_path)}: {str(e)}")
                        traceback.print_exc()
                    finally:
                        # Release pixel data and file handles even on failure.
                        if processed_img is not None and processed_img is not img:
                            processed_img.close()
                        img.close()

                if not file_results:
                    # Nothing was produced for this file: remove its folder if
                    # it is still empty (best effort).
                    try:
                        os.rmdir(output_dir)
                    except OSError:
                        pass

                results.extend(file_results)

            except Exception as e:
                fname = os.path.basename(file_path) if file_path else "unknown"
                errors.append(f"Failed to convert {fname}: {str(e)}")
                traceback.print_exc()

        if not results and errors:
            return {
                "error": True,
                "message": "; ".join(errors)
            }

        if not results and not errors:
            return {
                "error": True,
                "message": "No files were provided for conversion."
            }

        return {
            "error": False,
            "results": results,
            "output_path": results[0] if results else "",
            "output_count": len(results),
            "errors": errors if errors else None
        }

    except Exception as e:
        traceback.print_exc()
        return {"error": True, "message": f"Conversion failed: {str(e)}"}