import base64
import json
import logging
import math
import os
from typing import Sequence

import requests
from fastapi import Request
from opentelemetry import trace
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import ReadableSpan, TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor, SpanExporter, SpanExportResult
from opentelemetry.trace import SpanKind, StatusCode
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import Response

logger = logging.getLogger(__name__)


def appsentinels_init_tracing() -> None:
    """Install a global TracerProvider that batches spans to an OTLP/JSON endpoint.

    The collector URL is read from ``OTEL_EXPORTER_OTLP_ENDPOINT``; a missing
    variable raises ``KeyError``. The value is handed to the exporter as-is.

    NOTE(review): the OTel spec treats ``OTEL_EXPORTER_OTLP_ENDPOINT`` as a base
    URL and appends ``/v1/traces`` for HTTP exporters; this code does not, so the
    variable must already contain the full traces path — confirm deployments.
    """
    tracer_provider = TracerProvider(resource=Resource.create())
    collector_url = os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"]
    batcher = BatchSpanProcessor(OTLPHTTPJsonExporter(endpoint=collector_url))
    tracer_provider.add_span_processor(batcher)
    trace.set_tracer_provider(tracer_provider)

# --- OTLP/HTTP JSON exporter -------------------------------------------------

# SDK SpanKind -> OTLP ``Span.SpanKind`` enum number. OTLP numbers the kinds
# INTERNAL..CONSUMER as 1..5; 0 is SPAN_KIND_UNSPECIFIED and is not listed here.
_KIND_MAP = {
    kind: number
    for number, kind in enumerate(
        (
            SpanKind.INTERNAL,
            SpanKind.SERVER,
            SpanKind.CLIENT,
            SpanKind.PRODUCER,
            SpanKind.CONSUMER,
        ),
        start=1,
    )
}


def _encode_id(value: int, length: int) -> str:
    return base64.b64encode(value.to_bytes(length, "big")).decode()


def _attr_value(v):
    if isinstance(v, bool):
        return {"boolValue": v}
    if isinstance(v, int):
        return {"intValue": str(v)}
    if isinstance(v, float):
        return {"doubleValue": v}
    if isinstance(v, (list, tuple)):
        return {"arrayValue": {"values": [_attr_value(i) for i in v]}}
    return {"stringValue": str(v)}


def _fmt_attrs(attrs):
    if not attrs:
        return []
    return [{"key": k, "value": _attr_value(v)} for k, v in attrs.items()]


def _serialize_status(status) -> dict:
    """Encode an SDK ``Status`` as an OTLP/JSON ``Status`` object.

    OTLP status codes are UNSET=0, OK=1, ERROR=2. A non-empty description is
    sent as ``message``.
    """
    if status.status_code == StatusCode.ERROR:
        code = 2
    elif status.status_code == StatusCode.OK:
        code = 1
    else:
        code = 0
    encoded = {"code": code}
    if status.description:
        encoded["message"] = status.description
    return encoded


def _serialize_span(span: ReadableSpan) -> dict:
    """Encode one finished span as an OTLP/JSON ``Span`` object."""
    encoded = {
        "traceId": _encode_id(span.context.trace_id, 16),
        "spanId": _encode_id(span.context.span_id, 8),
        "name": span.name,
        # 0 = SPAN_KIND_UNSPECIFIED for any kind missing from the map.
        "kind": _KIND_MAP.get(span.kind, 0),
        # Nanosecond timestamps are 64-bit, so they are sent as strings.
        "startTimeUnixNano": str(span.start_time),
        "endTimeUnixNano": str(span.end_time),
        "attributes": _fmt_attrs(span.attributes),
        "events": [
            {
                "name": e.name,
                "timeUnixNano": str(e.timestamp),
                "attributes": _fmt_attrs(e.attributes),
            }
            for e in span.events
        ],
        "links": [
            {
                "traceId": _encode_id(lk.context.trace_id, 16),
                "spanId": _encode_id(lk.context.span_id, 8),
                "attributes": _fmt_attrs(lk.attributes),
            }
            for lk in span.links
        ],
        "status": _serialize_status(span.status),
    }
    if span.parent is not None:
        encoded["parentSpanId"] = _encode_id(span.parent.span_id, 8)
    return encoded


def _serialize(spans: Sequence[ReadableSpan]) -> dict:
    """Build an OTLP/JSON ``ExportTraceServiceRequest`` body for *spans*.

    Spans are grouped first by resource object and then by instrumentation
    scope ``(name, version)``, preserving first-seen order at both levels.
    Spans without an instrumentation scope are grouped under an empty scope.

    Returns:
        A JSON-serializable ``{"resourceSpans": [...]}`` dict.
    """
    # id() is a safe grouping key: every Resource is kept alive by the spans that
    # reference it for the whole duration of this call.
    grouped: dict = {}
    for span in spans:
        scope = span.instrumentation_scope
        scope_key = (scope.name, scope.version) if scope is not None else (None, None)
        _, by_scope = grouped.setdefault(id(span.resource), (span.resource, {}))
        by_scope.setdefault(scope_key, []).append(_serialize_span(span))

    return {
        "resourceSpans": [
            {
                "resource": {"attributes": _fmt_attrs(dict(resource.attributes))},
                "scopeSpans": [
                    {
                        "scope": {"name": name or "", "version": version or ""},
                        "spans": encoded_spans,
                    }
                    for (name, version), encoded_spans in by_scope.items()
                ],
            }
            for resource, by_scope in grouped.values()
        ]
    }


class OTLPHTTPJsonExporter(SpanExporter):
    """Span exporter that POSTs OTLP/JSON payloads to a single HTTP endpoint.

    Args:
        endpoint: URL the JSON body is POSTed to, used verbatim.
        headers: optional extra HTTP headers, merged over the default
            ``Content-Type: application/json``. The caller's mapping is copied,
            never mutated.
        timeout: per-request timeout in seconds, passed to ``requests``.
    """

    def __init__(self, endpoint: str, headers: dict | None = None, timeout: int = 10):
        self.endpoint = endpoint
        self.headers = {"Content-Type": "application/json", **(headers or {})}
        self.timeout = timeout

    def export(self, spans: Sequence[ReadableSpan]) -> SpanExportResult:
        """Serialize *spans* and POST them in one request.

        Returns ``SUCCESS`` when the request completes and the final status is
        not 4xx/5xx. Otherwise the error is logged and ``FAILURE`` is returned.
        Errors raised while serializing propagate to the caller.
        """
        # Serialize once; the same string is logged and sent.
        body = json.dumps(_serialize(spans))
        logger.debug("Exporting %d span(s) to %s: %s", len(spans), self.endpoint, body)
        try:
            resp = requests.post(
                self.endpoint,
                data=body,
                headers=self.headers,
                timeout=self.timeout,
            )
            resp.raise_for_status()
        except requests.exceptions.ConnectionError as e:
            logger.error("Connection error exporting spans to %s: %s", self.endpoint, e)
        except requests.exceptions.HTTPError as e:
            # requests.Response is falsy for 4xx/5xx (bool(resp) == resp.ok), so an
            # explicit None check is required or the error body is never logged.
            detail = e.response.text if e.response is not None else ""
            logger.error("HTTP error exporting spans: %s — response: %s", e, detail)
        except Exception as e:
            logger.error("Unexpected error exporting spans: %s", e)
        else:
            logger.info("Successfully exported %d span(s) — HTTP %s", len(spans), resp.status_code)
            return SpanExportResult.SUCCESS
        return SpanExportResult.FAILURE

    def shutdown(self) -> None:
        """No-op: the exporter holds no resources that need releasing."""


# --- Body/header capture middleware ------------------------------------------

_MAX_BODY_BYTES = 131_072

_TEXT_CONTENT_TYPES = (
    "application/json",
    "application/xml",
    "application/x-www-form-urlencoded",
    "application/graphql",
    "text/",
)


def _should_capture_body(content_type: str) -> bool:
    """Return True if a body with this ``Content-Type`` should be recorded.

    Only the media type is considered. Parameters such as ``charset`` are
    ignored, and the comparison is case-insensitive. A body is captured when:

    - the media type starts with one of ``_TEXT_CONTENT_TYPES``, or
    - it uses an RFC 6839 structured-syntax suffix ``+json`` / ``+xml``
      (e.g. ``application/problem+json``).

    ``text/event-stream`` is always excluded. It is an open-ended stream, and
    buffering it to capture a body would hold back delivery to the client.
    """
    media_type = content_type.split(";", 1)[0].strip().lower()
    if media_type == "text/event-stream":
        return False
    if media_type.endswith(("+json", "+xml")):
        return True
    return any(media_type.startswith(prefix) for prefix in _TEXT_CONTENT_TYPES)


def _capture_headers(span, headers, prefix: str) -> None:
    seen: dict = {}
    for key, value in headers.items():
        seen.setdefault(key.lower(), []).append(value)
    for key, values in seen.items():
        span.set_attribute(f"{prefix}.{key}", values if len(values) > 1 else values[0])


class BodyCaptureMiddleware(BaseHTTPMiddleware):
    """Annotate the current span with HTTP request/response headers and bodies.

    Does nothing unless the current span is recording. Before the handler runs,
    it records the request target, the headers and (for textual content types)
    the body. Afterwards it records the response headers and (for textual
    content types) the response body. Bodies are truncated to
    ``_MAX_BODY_BYTES`` bytes and decoded as UTF-8 with replacement.
    Annotation errors are logged and never stop the request.
    """

    async def dispatch(self, request: Request, call_next):
        span = trace.get_current_span()

        # No recording span is active — nothing to annotate.
        if not span.is_recording():
            return await call_next(request)

        await self._annotate_request(span, request)
        response = await call_next(request)

        try:
            _capture_headers(span, response.headers, "http.response.header")
            capture_body = _should_capture_body(response.headers.get("content-type", ""))
        except Exception:
            logger.exception("AppSentinels: failed to annotate response span — continuing")
            capture_body = False

        # A body that will not be recorded (downloads, images, streams...) is passed
        # through untouched rather than buffered in memory.
        if not capture_body:
            return response
        return await self._capture_response_body(span, response)

    @staticmethod
    async def _annotate_request(span, request: Request) -> None:
        """Record target, headers and (textual) body of *request* on *span*; never raises."""
        try:
            raw_qs = request.scope.get("query_string", b"").decode("utf-8", errors="replace")
            full_target = str(request.url.path)
            if raw_qs:
                full_target += "?" + raw_qs
            span.set_attribute("http.target", full_target)
            _capture_headers(span, request.headers, "http.request.header")
            if _should_capture_body(request.headers.get("content-type", "")):
                # NOTE(review): reading the body inside BaseHTTPMiddleware relies on a
                # Starlette version that replays it to the downstream app — confirm pin.
                req_body = await request.body()
                if req_body:
                    span.set_attribute(
                        "http.request.body",
                        req_body[:_MAX_BODY_BYTES].decode("utf-8", errors="replace"),
                    )
        except Exception:
            logger.exception("AppSentinels: failed to annotate request span — continuing")

    @staticmethod
    async def _capture_response_body(span, response) -> Response:
        """Drain *response*, record its body on *span*, and return an equivalent Response."""
        # The original iterator can only be consumed once, so the client must be
        # sent a rebuilt Response; the read error is logged and what was read is kept.
        chunks = []
        read_failed = False
        try:
            async for chunk in response.body_iterator:
                chunks.append(chunk)
        except Exception:
            read_failed = True
            logger.exception("AppSentinels: failed to read response body — returning partial")
        resp_body = b"".join(chunks)

        try:
            if resp_body:
                span.set_attribute(
                    "http.response.body",
                    resp_body[:_MAX_BODY_BYTES].decode("utf-8", errors="replace"),
                )
        except Exception:
            logger.exception("AppSentinels: failed to annotate response span — continuing")

        # Pass the Headers object, not dict(...): Response reads headers.items(),
        # which on Starlette Headers yields repeated names (e.g. several Set-Cookie).
        rebuilt = Response(
            content=resp_body,
            status_code=response.status_code,
            headers=response.headers,
            media_type=response.media_type,
        )
        if read_failed and "content-length" in rebuilt.headers:
            # The upstream Content-Length describes the full body, not the partial one
            # we actually have; a mismatch makes the server abort the connection.
            rebuilt.headers["content-length"] = str(len(resp_body))
        return rebuilt
