import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D


# ============================================================
# 1. Equilibrium manifold surface in (a, b, x)
#    Equilibria: x^3 + a x + b = 0  ->  b = -x^3 - a x
# ============================================================
# Sample the state x and the control a on a regular grid, then solve the
# equilibrium condition x^3 + a*x + b = 0 for b at every grid node.
x_vals = np.linspace(-3.0, 3.0, num=200)
a_vals = np.linspace(-2.0, 2.0, num=200)
X, A = np.meshgrid(x_vals, a_vals)  # rows follow a, columns follow x
B = -(X**3) - A * X

# Left panel of a two-panel figure: the manifold drawn as a surface with
# a, b and x on the plot's x, y and z axes respectively.
fig = plt.figure(figsize=(12, 5))
ax1 = fig.add_subplot(121, projection='3d')
ax1.plot_surface(A, B, X, alpha=0.7, rstride=4, cstride=4)
ax1.set(
    xlabel='Em (a)',
    ylabel='Ra (b)',
    zlabel='B (x)',
    title='Equilibrium manifold: x³ + a x + b = 0',
)

# ============================================================
# 2. Cusp set in control space
#    Cusp: 4 a^3 + 27 b^2 = 0
# ============================================================
# Fold points satisfy both x^3 + a x + b = 0 and 3 x^2 + a = 0.
# Parametrising by the fold location s = x gives a = -3 s^2, b = 2 s^3,
# which traces the curve 4 a^3 + 27 b^2 = 0.
s = np.linspace(-2.0, 2.0, num=400)
a_cusp, b_cusp = -3.0 * s**2, 2.0 * s**3

# Right panel of the same figure: the cusp curve in the (a, b) plane.
ax2 = fig.add_subplot(122)
ax2.plot(a_cusp, b_cusp, linewidth=2)
ax2.set(xlabel='a', ylabel='b', title='Cusp set: 4a³ + 27b² = 0')
ax2.axhline(0, linewidth=0.5)  # reference line b = 0
ax2.axvline(0, linewidth=0.5)  # reference line a = 0
ax2.set_aspect('equal', 'box')

fig.tight_layout()
#plt.show()


# ============================================================
# 3. Potential V(x; a, b) for representative parameter values
# ============================================================
def V(x, a, b):
    """Return the potential x**4/4 + a*x**2/2 + b*x.

    Its derivative dV/dx = x**3 + a*x + b, so the stationary points of V
    are exactly the equilibria x**3 + a*x + b = 0. Operates element-wise,
    so *x* may be a scalar or a NumPy array.
    """
    quartic = x**4 / 4
    quadratic = a * x**2 / 2
    linear = b * x
    return quartic + quadratic + linear

x = np.linspace(-3, 3, num=400)

# (a, b, legend label) triples: 4a^3 + 27b^2 is positive, negative and
# zero respectively, i.e. one point outside, inside and on the cusp.
cases = [
    (1.0, 0.0, "outside cusp (a=1, b=0)"),
    (-1.0, 0.0, "inside cusp (a=-1, b=0)"),
    (-3.0, 2.0, "on cusp (a=-3, b=2)"),
]

plt.figure(figsize=(7, 5))
ax_potential = plt.gca()
for a_val, b_val, label in cases:
    ax_potential.plot(x, V(x, a_val, b_val), label=label)

ax_potential.set(
    xlabel="B (x)",
    ylabel="V(x; a, b)",
    title="Potential for typical (a, b)",
)
ax_potential.grid(True)
ax_potential.legend()
plt.tight_layout()
#plt.show()


# ============================================================
# 4. Slices of the equilibrium manifold at fixed a
#    For fixed a = a0: b = -x³ - a0 x
# ============================================================
# One slice per a value: outside, transitional, inside cusp.
a_slice_values = [1.0, 0.0, -1.0]
x_slice = np.linspace(-3, 3, num=400)

plt.figure(figsize=(8, 6))
ax_fixed_a = plt.gca()
for a0 in a_slice_values:
    # Solve the equilibrium condition for b and put x on the vertical axis.
    # For a < 0, db/dx = -3x^2 - a changes sign, so the curve folds
    # (x is multi-valued in b).
    b_of_x = -x_slice**3 - a0 * x_slice
    ax_fixed_a.plot(b_of_x, x_slice, label=f"a = {a0}")

ax_fixed_a.set(
    xlabel="Ra (b)",
    ylabel="B (x)",
    title="Equilibrium manifold slices for fixed a",
)
ax_fixed_a.grid(True)
ax_fixed_a.legend()
plt.tight_layout()
#plt.show()


# ============================================================
# 5. Slices of the equilibrium manifold at fixed b
#    For fixed b = b0: x³ + a x + b0 = 0  ->  a = -(x³ + b0)/x
#    (defined only for x ≠ 0; grid points with |x| <= eps are masked out)
# ============================================================
# Representative value(s) of b; add entries (e.g. 0.1, 0.0, -0.2) to
# overlay several slices on the same axes.
b_slice_values = [-0.1]
x_slice = np.linspace(-3, 3, 400)

plt.figure(figsize=(8, 6))

# a = -(x^3 + b0)/x is undefined at x = 0, so drop grid points with
# |x| <= eps *before* dividing. This avoids divide-by-zero warnings and
# inf/nan values whenever the grid contains 0 (e.g. an odd point count).
eps = 1e-6
mask = np.abs(x_slice) > eps
x_nonzero = x_slice[mask]

for b0 in b_slice_values:
    a_slice = -(x_nonzero**3 + b0) / x_nonzero
    # Markers only: for b0 != 0 the x < 0 and x > 0 branches diverge to
    # opposite infinities as x -> 0, so a connecting line would draw a
    # spurious segment between them.
    plt.plot(a_slice, x_nonzero, label=f"b = {b0}", linestyle='None', marker='.')

plt.xlabel("Em (a)")
plt.ylabel("B (x)")
plt.title("Equilibrium manifold slices for fixed b")
plt.grid(True)
plt.legend()
plt.tight_layout()
plt.show()
