"""Search Analytics data fetching with pagination support."""

from __future__ import annotations

import logging
from datetime import date

from src.api.rate_limiter import RateLimiter, retry_on_api_error
from src.models.data_models import PageMetrics, QueryMetrics, SearchMetrics

logger = logging.getLogger("seo_optimizer")

ROW_LIMIT = 25000  # Search Console API maximum for rowLimit per request


@retry_on_api_error()
def _execute_query(service, site_url: str, request_body: dict) -> list[dict]:
    """Execute a single searchAnalytics.query request."""
    response = (
        service.searchanalytics()
        .query(siteUrl=site_url, body=request_body)
        .execute()
    )
    return response.get("rows", [])
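
# Each row returned by the API is a dict shaped roughly like this
# (illustrative values):
#   {"keys": ["python tutorial"], "clicks": 12, "impressions": 340,
#    "ctr": 0.0353, "position": 7.4}
# "keys" holds one entry per requested dimension and is absent when the
# request specifies no dimensions.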


def fetch_search_analytics(
    service,
    site_url: str,
    start_date: date,
    end_date: date,
    dimensions: list[str] | None = None,
    rate_limiter: RateLimiter | None = None,
) -> list[dict]:
    """Fetch search analytics data with pagination.

    Args:
        service: Webmasters v3 service.
        site_url: GSC property URL.
        start_date: Start date for the query.
        end_date: End date for the query.
        dimensions: List of dimensions (e.g., ['query', 'page']).
        rate_limiter: Optional rate limiter instance.

    Returns:
        List of row dicts from the API response.
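
    Example (illustrative; assumes ``service`` is an authorized client and
    the property URL is a placeholder):

        rows = fetch_search_analytics(
            service,
            "https://example.com/",
            date(2024, 1, 1),
            date(2024, 1, 31),
            dimensions=["query"],
        )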
    """
    if dimensions is None:
        dimensions = []

    all_rows = []
    start_row = 0

    while True:
        if rate_limiter:
            rate_limiter.acquire(site_url)

        request_body = {
            "startDate": start_date.isoformat(),
            "endDate": end_date.isoformat(),
            "dimensions": dimensions,
            "rowLimit": ROW_LIMIT,
            "startRow": start_row,
        }

        rows = _execute_query(service, site_url, request_body)
        if not rows:
            break

        all_rows.extend(rows)
        logger.debug(
            "Fetched %d rows (total: %d) for %s [%s-%s] dims=%s",
            len(rows), len(all_rows), site_url,
            start_date, end_date, dimensions,
        )

        if len(rows) < ROW_LIMIT:
            break
        start_row += ROW_LIMIT

    return all_rows


def _aggregate_totals(rows: list[dict]) -> SearchMetrics:
    """Aggregate rows into total SearchMetrics."""
    if not rows:
        return SearchMetrics()

    total_clicks = sum(r.get("clicks", 0) for r in rows)
    total_impressions = sum(r.get("impressions", 0) for r in rows)
    ctr = total_clicks / total_impressions if total_impressions > 0 else 0.0

    # Weighted average position
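    # e.g. positions 2.0 @ 100 impressions and 10.0 @ 50 impressions give
    # (2.0 * 100 + 10.0 * 50) / 150 = 700 / 150 ≈ 4.67, so rows with more
    # impressions dominate the average.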
    weighted_pos = sum(
        r.get("position", 0) * r.get("impressions", 0) for r in rows
    )
    avg_position = weighted_pos / total_impressions if total_impressions > 0 else 0.0

    return SearchMetrics(
        clicks=total_clicks,
        impressions=total_impressions,
        ctr=round(ctr, 4),
        position=round(avg_position, 2),
    )


def _rows_to_query_metrics(rows: list[dict], top_n: int = 20) -> list[QueryMetrics]:
    """Convert API rows (with 'query' dimension) to sorted QueryMetrics."""
    metrics = []
    for row in rows:
        keys = row.get("keys", [])
        if not keys:
            continue
        metrics.append(
            QueryMetrics(
                query=keys[0],
                clicks=row.get("clicks", 0),
                impressions=row.get("impressions", 0),
                ctr=round(row.get("ctr", 0), 4),
                position=round(row.get("position", 0), 2),
            )
        )
    metrics.sort(key=lambda m: m.clicks, reverse=True)
    return metrics[:top_n]


def _rows_to_page_metrics(rows: list[dict], top_n: int = 20) -> list[PageMetrics]:
    """Convert API rows (with 'page' dimension) to sorted PageMetrics."""
    metrics = []
    for row in rows:
        keys = row.get("keys", [])
        if not keys:
            continue
        metrics.append(
            PageMetrics(
                page=keys[0],
                clicks=row.get("clicks", 0),
                impressions=row.get("impressions", 0),
                ctr=round(row.get("ctr", 0), 4),
                position=round(row.get("position", 0), 2),
            )
        )
    metrics.sort(key=lambda m: m.clicks, reverse=True)
    return metrics[:top_n]


def fetch_for_period(
    service,
    site_url: str,
    start_date: date,
    end_date: date,
    rate_limiter: RateLimiter | None = None,
) -> dict:
    """Fetch all analytics dimensions for a period.

    Returns:
        Dict with keys: total, queries, pages, rows_by_query, rows_by_page
    """
    # Total (no dimensions)
    total_rows = fetch_search_analytics(
        service, site_url, start_date, end_date,
        dimensions=[], rate_limiter=rate_limiter,
    )
    total = _aggregate_totals(total_rows)

    # By query
    query_rows = fetch_search_analytics(
        service, site_url, start_date, end_date,
        dimensions=["query"], rate_limiter=rate_limiter,
    )
    top_queries = _rows_to_query_metrics(query_rows)

    # By page
    page_rows = fetch_search_analytics(
        service, site_url, start_date, end_date,
        dimensions=["page"], rate_limiter=rate_limiter,
    )
    top_pages = _rows_to_page_metrics(page_rows)

    logger.info(
        "Period %s~%s for %s: clicks=%d, impressions=%d, ctr=%.2f%%, pos=%.1f",
        start_date, end_date, site_url,
        total.clicks, total.impressions, total.ctr * 100, total.position,
    )

    return {
        "total": total,
        "top_queries": top_queries,
        "top_pages": top_pages,
        "rows_by_query": query_rows,
        "rows_by_page": page_rows,
    }
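
# Illustrative sketch (not part of this module's API): comparing the most
# recent 28 days against the preceding 28 days, assuming an authorized
# ``service`` and a real property URL.
#
#   from datetime import timedelta
#
#   today = date.today()
#   current = fetch_for_period(
#       service, site_url, today - timedelta(days=27), today,
#   )
#   previous = fetch_for_period(
#       service, site_url, today - timedelta(days=55), today - timedelta(days=28),
#   )
#   click_delta = current["total"].clicks - previous["total"].clicks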
