"""Functional Modified Differential Multiplier Method (MDMM) for PyTorch.

This module provides functional constraints that don't rely on nn.Parameter,
making them compatible with torch.func.vmap, functional optimizers, and TensorDicts.

Modified from https://github.com/crowsonkb/mdmm
"""

import abc
from dataclasses import dataclass

import torch


@dataclass
class ConstraintReturn:
    """Result bundle produced when a single constraint is evaluated.

    Attributes:
        value: The penalty term the constraint contributes to the loss.
        fn_value: The constrained function value, exactly as supplied by the caller.
        inf: The constraint's infeasibility for that function value.
    """

    value: torch.Tensor
    fn_value: torch.Tensor
    inf: torch.Tensor


class FunctionalConstraint(metaclass=abc.ABCMeta):
    """Abstract parent of every functional constraint type.

    Subclasses provide :meth:`infeasibility`; this class turns that quantity
    into the MDMM penalty term. The dual variable is handed in by the caller
    on each call instead of being stored on the instance.

    Args:
        scale: Multiplier applied to the whole penalty.
        damping: Coefficient of the quadratic damping term.
    """

    def __init__(self, scale: float = 1.0, damping: float = 1.0):
        self.scale, self.damping = scale, damping

    @abc.abstractmethod
    def infeasibility(self, fn_value: torch.Tensor) -> torch.Tensor:
        """Return the constraint's infeasibility for ``fn_value``."""
        ...

    def compute_penalty(self, fn_value: torch.Tensor, lmbda: torch.Tensor) -> ConstraintReturn:
        """Evaluate the penalty ``scale * (lmbda * inf + damping * inf**2 / 2)``.

        Args:
            fn_value: Value of the constrained function.
            lmbda: Dual variable (Lagrange multiplier) for this constraint.

        Returns:
            A ``ConstraintReturn`` holding the penalty, the untouched
            ``fn_value`` and the infeasibility.
        """
        violation = self.infeasibility(fn_value)
        multiplier_part = lmbda * violation
        quadratic_part = self.damping * (violation**2) / 2

        # `scale` multiplies the sum, so the damping term is rescaled together
        # with the multiplier term.
        return ConstraintReturn(
            value=self.scale * (multiplier_part + quadratic_part),
            fn_value=fn_value,
            inf=violation,
        )


class EqConstraint(FunctionalConstraint):
    """Equality constraint: the function value is compared against ``value``.

    Args:
        value: Target for the constrained function value.
        scale: Forwarded to ``FunctionalConstraint``.
        damping: Forwarded to ``FunctionalConstraint``.
    """

    def __init__(self, value: float, scale: float = 1.0, damping: float = 1.0):
        super().__init__(scale=scale, damping=damping)
        self.value = value

    def infeasibility(self, fn_value: torch.Tensor) -> torch.Tensor:
        """Return ``value - fn_value``, which is zero when the target is met."""
        target = self.value
        return target - fn_value


class MaxConstraintHard(FunctionalConstraint):
    """Upper-bound inequality constraint without a slack variable.

    Args:
        max_val: Largest function value considered feasible.
        scale: Forwarded to ``FunctionalConstraint``.
        damping: Forwarded to ``FunctionalConstraint``.
    """

    def __init__(self, max_val: float, scale: float = 1.0, damping: float = 1.0):
        super().__init__(scale=scale, damping=damping)
        self.max_val = max_val

    def infeasibility(self, fn_value: torch.Tensor) -> torch.Tensor:
        """Return zero where ``fn_value <= max_val``, else ``max_val - fn_value``."""
        capped = fn_value.clamp(max=self.max_val)
        return capped - fn_value


class MinConstraintHard(FunctionalConstraint):
    """Lower-bound inequality constraint without a slack variable.

    Args:
        min_val: Smallest function value considered feasible.
        scale: Forwarded to ``FunctionalConstraint``.
        damping: Forwarded to ``FunctionalConstraint``.
    """

    def __init__(self, min_val: float, scale: float = 1.0, damping: float = 1.0):
        super().__init__(scale=scale, damping=damping)
        self.min_val = min_val

    def infeasibility(self, fn_value: torch.Tensor) -> torch.Tensor:
        """Return zero where ``fn_value >= min_val``, else ``min_val - fn_value``."""
        floored = fn_value.clamp(min=self.min_val)
        return floored - fn_value


class BoundConstraintHard(FunctionalConstraint):
    """Two-sided bound constraint: the function value should lie in ``[min_val, max_val]``.

    Args:
        min_val: Lower end of the feasible interval.
        max_val: Upper end of the feasible interval.
        scale: Forwarded to ``FunctionalConstraint``.
        damping: Forwarded to ``FunctionalConstraint``.
    """

    def __init__(self, min_val: float, max_val: float, scale: float = 1.0, damping: float = 1.0):
        super().__init__(scale=scale, damping=damping)
        self.min_val, self.max_val = min_val, max_val

    def infeasibility(self, fn_value: torch.Tensor) -> torch.Tensor:
        """Return the offset from ``fn_value`` to its clamped value (zero inside the bounds)."""
        bounded = fn_value.clamp(self.min_val, self.max_val)
        return bounded - fn_value


@dataclass
class MDMMReturn:
    """Aggregate result returned by calling a ``FunctionalMDMM`` instance.

    Attributes:
        value: Sum of the penalty terms of all constraints.
        fn_values: The function values that were passed in, one per constraint.
        infs: The infeasibility of each constraint, in constraint order.
    """

    value: torch.Tensor
    fn_values: list[torch.Tensor]
    infs: list[torch.Tensor]


class FunctionalMDMM:
    """The main Functional MDMM class, which combines multiple constraints.

    The dual variables (lambdas) are not stored on the instance. They live in
    a plain state dict created by :meth:`init_state` and are passed back in on
    every call.

    Args:
        constraints: The constraints to combine. Their order fixes the order
            of ``fn_values`` expected by :meth:`__call__` and of the rows of
            the lambdas tensor. May be empty.
    """

    def __init__(self, constraints: list[FunctionalConstraint]):
        self.constraints = constraints

    def init_state(self, batch_size: torch.Size, device: torch.device) -> dict[str, torch.Tensor]:
        """Returns the initial state dict containing the dual variables (lambdas).

        Args:
            batch_size: The batch dimensions. e.g., (num_models,)
            device: The device to place the tensors on.

        Returns:
            A dict with the single key ``"lambdas"`` mapping to a zero tensor
            of shape ``[num_constraints, *batch_size]``.
        """
        # Shape: [num_constraints, *batch_size]
        return {"lambdas": torch.zeros((len(self.constraints), *batch_size), device=device)}

    def __call__(
        self, state: dict[str, torch.Tensor], fn_values: list[torch.Tensor], lam_key: str = "lambdas"
    ) -> MDMMReturn:
        """Sum the penalty terms of all constraints.

        Args:
            state: State dict holding the lambdas under ``lam_key``; row ``i``
                is the dual variable of constraint ``i``.
            fn_values: One function value per constraint, in constraint order.
            lam_key: Key under which the lambdas are stored in ``state``.

        Returns:
            An ``MDMMReturn`` with the summed penalty, the ``fn_values`` list
            as passed in, and the per-constraint infeasibilities. With zero
            constraints the summed penalty is a scalar zero.

        Raises:
            ValueError: If ``fn_values`` and the constraints differ in length.
        """
        if len(fn_values) != len(self.constraints):
            raise ValueError(f"Expected {len(self.constraints)} function values, got {len(fn_values)}")

        lambdas = state[lam_key]

        # The accumulator takes device/dtype from the first function value.
        # With no constraints there is none, so fall back to the lambdas
        # tensor instead of failing with IndexError on fn_values[0].
        reference = fn_values[0] if fn_values else lambdas
        total_value = torch.zeros((), device=reference.device, dtype=reference.dtype)
        infs = []

        # Lengths were validated above, so strict=True only documents the invariant.
        for i, (c, fn_val) in enumerate(zip(self.constraints, fn_values, strict=True)):
            c_ret = c.compute_penalty(fn_val, lambdas[i])
            total_value = total_value + c_ret.value
            infs.append(c_ret.inf)

        return MDMMReturn(total_value, fn_values, infs)


def test_functional_mdmm():
    """Smoke-test ``FunctionalMDMM`` under ``torch.func.grad_and_value``.

    Builds a single violated ``fn_value <= 0`` constraint, differentiates the
    combined loss with respect to both the weights and the lambdas, flips the
    sign of the lambda gradient, prints the results and asserts that the
    flipped lambda gradient is positive.
    """
    from torch.func import grad_and_value

    # Single constraint requiring fn_value <= 0.0.
    max_constraint = MaxConstraintHard(max_val=0.0, damping=5.0)
    mdmm = FunctionalMDMM([max_constraint])

    # Parameter tree: primal weights plus the MDMM dual state.
    initial_dual_state = mdmm.init_state(torch.Size([]), torch.device("cpu"))
    params = {
        "w": torch.tensor([1.0, -1.0], requires_grad=True),
        "mdmm": initial_dual_state,
    }

    def loss_fn(tree):
        weights = tree["w"]

        # Dummy base objective to minimise.
        objective = (weights**2).sum()

        # Constrained quantity: w[0] + w[1] + 1 should be <= 0.
        # At the initial weights it is 1 - 1 + 1 = 1, so the constraint is violated.
        constrained = weights[0] + weights[1] + 1.0

        penalty = mdmm(tree["mdmm"], [constrained]).value
        return objective + penalty

    grads, loss = grad_and_value(loss_fn)(params)

    # The dual variables need gradient ascent, so their gradient is negated
    # before it would be handed to a descent-style update.
    flipped = -grads["mdmm"]["lambdas"]
    grads["mdmm"]["lambdas"] = flipped

    print("Loss evaluated successfully:", loss.item())
    print("w grads:", grads["w"])
    print("lambdas grad (reversed for ascent):", grads["mdmm"]["lambdas"])

    # fn_val is 1.0 and max_val is 0.0, so infeasibility = -1.0 and the raw
    # d(loss)/d(lambda) = scale * infeasibility is negative. After the sign
    # flip above it is positive, so a descent step would push lambda lower,
    # which makes lambda * infeasibility grow and penalises the violation more.
    assert grads["mdmm"]["lambdas"][0].item() > 0.0


if __name__ == "__main__":
    # Demo script: run the self-check above, then optimise a toy two-loss
    # problem with a hand-written MDMM loop and save the loss-space
    # trajectories as an animated GIF.
    test_functional_mdmm()
    # Imported inside the main guard, so merely importing this module does not
    # pull in matplotlib or numpy.
    import matplotlib.pyplot as plt
    import numpy as np
    import torch  # already imported at module level; repeated here harmlessly
    from matplotlib.animation import FuncAnimation

    # ---------------------------------------------------------------------------
    # Two competing losses: loss_1 is zero at theta = 0 and loss_2 is zero at
    # theta = 1, so they cannot both vanish. The Pareto front drawn further down
    # is the curve (2t^2, 2(1-t)^2), i.e. sqrt(l1) + sqrt(l2) = sqrt(2).
    # ---------------------------------------------------------------------------
    def loss_1(theta):
        # Squared distance of theta from the origin.
        return (theta**2).sum()

    def loss_2(theta):
        # Squared distance of theta from the all-ones point.
        return ((theta - 1.0) ** 2).sum()

    # ---------------------------------------------------------------------------
    # Modified Differential Method of Multipliers (MDMM)
    # ---------------------------------------------------------------------------
    damping = 10.0  # weight of the damping term in the Lagrangian below
    eps = 0.7  # level that loss_2 is constrained to (constraint = eps - loss_2)
    num_steps = 400  # optimisation steps per start; also the number of animation frames
    lr_theta = 0.01  # descent step size for theta
    lr_lambda = 1.0  # ascent step size for the multiplier

    def theta_from_losses(l1, l2):
        """Return a theta whose loss_1=l1 and loss_2=l2 (picks the branch with x>=y).

        With theta = (x, y): loss_2 - loss_1 = 2 - 2(x + y), which gives x + y,
        and substituting y = s - x into x^2 + y^2 = l1 gives a quadratic in x.

        Raises:
            ValueError: If no real theta attains the requested loss pair.
        """
        s = (l1 - l2 + 2) / 2  # x + y = s
        disc = 2 * l1 - s**2
        if disc < 0:
            raise ValueError(f"No real theta for (l1={l1}, l2={l2})")
        x = (s + disc**0.5) / 2
        y = s - x
        return torch.tensor([x, y], dtype=torch.float32)

    # Starting points specified in loss space (l1, l2)
    starting_losses = [
        (1.6, 1.8),
        (1.8, 1.7),
        (1.9, 1.5),
        (2.0, 1.8),
        (2.1, 1.3),
        (2.0, 0.3),  # below epsilon, far right
    ]

    starting_points = [theta_from_losses(l1, l2) for l1, l2 in starting_losses]

    # One colour per trajectory, matched by index. There are eight colours but
    # only six starting points, so the last two are never used for a trajectory.
    colours = [
        "#808080",  # grey
        "#e91e8c",  # pink
        "#8B4513",  # brown
        "#7B2FBE",  # purple
        "#DC143C",  # crimson
        "#2E8B57",  # green
        "#FF8C00",  # orange
        "#1E90FF",  # dodger blue
    ]

    # ---------------------------------------------------------------------------
    # Run optimisation from several starting points and record trajectories
    # ---------------------------------------------------------------------------
    print("Running MDMM from", len(starting_points), "starting points …")
    trajectories = []

    for start in starting_points:
        theta = start.clone().requires_grad_(True)
        lam = torch.tensor(0.0, requires_grad=True)
        traj = []

        for _ in range(num_steps):
            l1 = loss_1(theta)
            l2 = loss_2(theta)
            # Record the losses before this step's update is applied.
            traj.append((l1.item(), l2.item()))

            # Compute Lagrangian. `constraint` is zero when loss_2 equals eps.
            # Because `damp` is detached, the damping part contributes
            # damping * constraint * d(constraint)/d(theta) to the theta
            # gradient and nothing to the lambda gradient.
            constraint = eps - l2
            damp = damping * constraint.detach()  # stop_gradient equivalent
            L = l1 - (lam - damp) * constraint

            # Gradients for theta (minimise) and lambda (maximise)
            L.backward()

            with torch.no_grad():
                theta -= lr_theta * theta.grad
                # dL/dlam = -constraint = loss_2 - eps, so this raises lam
                # while loss_2 is above eps.
                lam += lr_lambda * lam.grad  # ascent on lambda
                lam.clamp_(min=0.0)  # keep the multiplier non-negative

            # Reset gradients so the next backward() does not accumulate.
            theta.grad = None
            lam.grad = None

        # Array of shape (num_steps, 2): columns are (loss_1, loss_2).
        trajectories.append(np.array(traj))

    print("Done.")

    # ---------------------------------------------------------------------------
    # Pareto front for plotting
    # ---------------------------------------------------------------------------
    # Losses along theta = (t, t) for t in [0, 1].
    t = np.linspace(0, 1, 300)
    pareto_l1 = 2 * t**2
    pareto_l2 = 2 * (1 - t) ** 2

    # ---------------------------------------------------------------------------
    # Animate
    # ---------------------------------------------------------------------------
    fig, ax = plt.subplots(figsize=(8, 7))
    fig.patch.set_facecolor("#D0D0D0")
    ax.set_facecolor("#D0D0D0")

    ax.set_xlim(-0.05, 2.2)
    ax.set_ylim(-0.05, 2.2)
    ax.set_xlabel("Loss #1", fontsize=13)
    ax.set_ylabel("Loss #2", fontsize=13)
    ax.set_title("The Modified Differential Method of Multipliers", fontsize=14, fontweight="bold", pad=12)

    # Pareto front (blue arc + hatched region below)
    ax.plot(pareto_l1, pareto_l2, color="#1E90FF", linewidth=2.5, zorder=2)
    ax.fill_between(
        pareto_l1, pareto_l2, 0, alpha=0.08, color="#1E90FF", hatch="//", edgecolor="#1E90FF", linewidth=0, zorder=1
    )

    # Epsilon constraint line (dashed + hatched band)
    ax.axhline(y=eps, color="black", linewidth=1.2, linestyle="--", zorder=1)
    ax.fill_between([0, 2.5], eps - 0.015, eps + 0.015, color="black", alpha=0.12, hatch="\\\\", zorder=1)

    # Epsilon annotation arrow
    ax.annotate("", xy=(0.08, 0.0), xytext=(0.08, eps), arrowprops=dict(arrowstyle="<->", color="black", lw=1.5))
    ax.text(0.15, eps / 2, r"$\varepsilon$", fontsize=18, va="center")

    # Prepare artists for each trajectory: one path line and one head marker
    # per colour. Only the first len(trajectories) pairs are updated in update().
    lines = []
    dots = []
    for colour in colours:
        (line,) = ax.plot([], [], color=colour, linewidth=1.6, zorder=3)
        (dot,) = ax.plot([], [], "o", color=colour, markersize=5, zorder=4)
        lines.append(line)
        dots.append(dot)

    # Black dots that appear once a trajectory reaches the constraint
    (converge_dots,) = ax.plot([], [], "o", color="black", markersize=5, zorder=5)

    def init():
        # Clear every artist and return them all for blitting.
        for line, dot in zip(lines, dots, strict=False):
            line.set_data([], [])
            dot.set_data([], [])
        converge_dots.set_data([], [])
        return lines + dots + [converge_dots]

    def update(frame):
        # Draw each trajectory up to and including step `frame`.
        cx, cy = [], []
        for i, traj in enumerate(trajectories):
            seg = traj[: frame + 1]
            if len(seg) > 0:
                lines[i].set_data(seg[:, 0], seg[:, 1])
                dots[i].set_data([seg[-1, 0]], [seg[-1, 1]])
                # Mark the head black when loss_2 is within 0.05 of eps,
                # but only after frame 40.
                if abs(seg[-1, 1] - eps) < 0.05 and frame > 40:
                    cx.append(seg[-1, 0])
                    cy.append(seg[-1, 1])
        converge_dots.set_data(cx, cy)
        return lines + dots + [converge_dots]

    anim = FuncAnimation(fig, update, init_func=init, frames=num_steps, interval=30, blit=True)

    plt.tight_layout()

    # Write the animation to mdmm.gif in the current working directory.
    anim.save("./mdmm.gif", writer="pillow", fps=30)
