import numpy as np
import networkx as nx
from matplotlib import pyplot

def generate_problem(n,m):
    """Generate a random instance of a Z^2-synchronization problem.

    Parameters
    ----------
    n : int
        Number of unknown signs.
    m : int
        Number of distinct index pairs (i, j), with i < j, to observe.

    Returns
    -------
    xsol : numpy.ndarray of shape (n, 1)
        Ground-truth vector with entries in {-1, 1}.
    Omega : list of (int, int)
        The m distinct observed pairs, each with i < j, in the order they
        were drawn.
    y : numpy.ndarray of shape (m,)
        Measurements y[k] = xsol[i] * xsol[j] for (i, j) = Omega[k].

    Raises
    ------
    ValueError
        If m is negative or exceeds n * (n - 1) / 2, the number of distinct
        pairs available (the sampling loop could otherwise never finish).
    """
    max_pairs = n * (n - 1) // 2
    if m < 0 or m > max_pairs:
        raise ValueError(
            "m must be between 0 and n*(n-1)/2 = %d, got %d" % (max_pairs, m)
        )

    # Generate xsol with random entries in {-1,1}
    xsol = np.random.rand(n,1)
    xsol = 2 * (xsol > 0.5) - 1

    # Choose m pairs by rejection sampling. The set mirrors the list so the
    # duplicate test is O(1) while Omega keeps the draw order.
    Omega = []
    seen = set()
    while len(Omega) < m:
        i = np.random.randint(n)
        j = np.random.randint(n)
        if i < j and (i, j) not in seen:
            seen.add((i, j))
            Omega.append((i, j))

    # Measurements: product of the two signs of each observed pair
    y = np.array([xsol[i, 0] * xsol[j, 0] for i, j in Omega], dtype=float)

    return xsol, Omega, y


def Omega_connected(n,Omega):
    """Tell whether the pairs in Omega form a connected graph.

    The graph has the vertices 0, ..., n-1 (added explicitly, so a vertex
    that appears in no pair still counts) and one undirected edge per pair
    in Omega. The result is whatever networkx's ``is_connected`` reports
    for that graph.
    """
    graph = nx.Graph()
    graph.add_edges_from(Omega)
    # Isolated vertices are absent from Omega; make sure they are present.
    graph.add_nodes_from(range(n))
    return nx.is_connected(graph)


def cost_matrix(n,y,Omega):
    """Assemble the symmetric n x n measurement matrix.

    For each pair Omega[k] = (i, j), the value y[k] is written at both
    (i, j) and (j, i). Every other entry is zero.
    """
    Y = np.zeros((n, n))

    for idx, pair in enumerate(Omega):
        row, col = pair[0], pair[1]
        value = y[idx]
        Y[row, col] = value
        Y[col, row] = value

    return Y


def cost(Y,sig,U):
    """Evaluate the penalized Burer-Monteiro cost function.

    Computes  -<Y, U U^T> + (sig/2) * sum_i (||U[i,:]||^2 - 1)^2.

    Parameters
    ----------
    Y : numpy.ndarray of shape (n, n)
        Cost matrix.
    sig : float
        Weight of the penalty on the squared row norms.
    U : numpy.ndarray of shape (n, p)
        Current factor; must be two-dimensional.

    Returns
    -------
    float
        Value of the cost at U.
    """
    # <Y, U U^T> = sum of the entrywise product of (Y U) and U
    linear_part = -np.sum((Y @ U) * U)

    # Squared row norms minus one, computed for all rows at once (the same
    # expression BM_GD uses when forming the gradient).
    row_gaps = np.sum(U**2, axis=1) - 1

    return linear_part + (sig / 2) * np.sum(row_gaps**2)


def U_to_sol(U):
    """Turn a Burer-Monteiro factor into a candidate sign vector.

    Takes the first left singular vector of U and maps its strictly
    positive entries to +1 and all other entries to -1. The result is
    returned as a column of shape (n, 1).
    """
    left_vectors = np.linalg.svd(U)[0]
    leading = left_vectors[:, 0]

    signs = np.where(leading > 0, 1, -1)

    return signs.reshape(-1, 1)


def BM_GD(Y,sig,p,nb_its=1000,rank_increase=True):
    """Minimize the penalized Burer-Monteiro cost by gradient descent.

    Attempts to minimize, over U in R^(n x p),
        U -> -<Y,UU^T> + (sig/2) ||diag(UU^T) - 1||^2
    using gradient descent with a backtracking line search, starting from
    a random Gaussian initialization.

    Parameters
    ----------
    Y : numpy.ndarray of shape (n, n)
        Cost matrix.
    sig : float
        Penalty weight.
    p : int
        Initial number of columns of U.
    nb_its : int, optional
        Number of gradient iterations.
    rank_increase : bool, optional
        If true, a small random column is appended to U whenever the
        iterates appear to have stalled at a point whose singular values
        are all within a factor 20 of each other (see the comments in the
        loop). A message is printed each time this happens.

    Returns
    -------
    U : numpy.ndarray of shape (n, p_final)
        Final iterate; p_final >= p when rank_increase is true.
    costs : numpy.ndarray of shape (nb_its,)
        Cost at the start of each iteration.
    """
    n = np.shape(Y)[0]

    # Random initialization
    U = np.random.randn(n,p) / np.sqrt(p)
    costs = np.zeros(nb_its)

    for k_it in range(nb_its):

        costs[k_it] = cost(Y,sig,U)

        # Compute gradient
        # 2 (-YU + sig Diag(diag(UU^T) - 1)U)
        grad = -2 * (Y @ U)
        diff_norms = np.sum(U**2,axis=1) - 1
        grad = grad + 2 * sig * diff_norms.reshape((n,1)) * U
        # The norm is used several times below; compute it once.
        grad_norm = np.linalg.norm(grad)

        if k_it == 0:
            if grad_norm > 0:
                step = 0.1 * np.linalg.norm(U) / grad_norm
            else:
                # Zero (or non-finite) gradient: the ratio above would be
                # inf/NaN and the line search below could never succeed.
                step = 1.0

        # Backtrack until the sufficient-decrease test holds
        while True:

            U_new = U - step * grad
            if cost(Y,sig,U_new) < costs[k_it] - 0.2*step*grad_norm**2 + 1e-8:
                break

            step = step / 2
            if step == 0:
                # The step underflowed without the test ever passing (for
                # instance because the cost is not finite, or so large that
                # the 1e-8 slack is lost to rounding). Keep the current
                # iterate instead of looping forever.
                U_new = U
                break

        U = U_new
        step = step * 1.1

        # Increase the rank?
        if rank_increase and \
           (step * grad_norm / np.linalg.norm(U) < 1e-3):
            # If the norm of step * grad is small, we consider that we
            # have reached a critical point for our current value of p.

            # Only the singular values are needed for the test.
            S = np.linalg.svd(U, compute_uv=False)
            if np.min(S) > 0.05 * np.max(S):
                # The rank is not approximately deficient: the
                # critical point may be non-globally optimal => we
                # increase the rank.

                p += 1
                # Add a column to U
                U = np.hstack((U,0.1*np.random.randn(n,1)))
                print("Rank increases to ",p," at iteration ",k_it)

    return U, costs

# Build a random synchronization instance and its cost matrix
n = 20
m = 30
xsol, Omega, y = generate_problem(n, m)
if not Omega_connected(n, Omega):
    # Only a warning: the run continues either way.
    print("Omega not connected")
Y = cost_matrix(n, y, Omega)

# Run gradient descent on the Burer-Monteiro factorization, starting
# from a single column and letting the solver add columns as needed
p = 1
sig = 10 * m / n
nb_its = 1000
U, costs = BM_GD(Y, sig, p, nb_its, rank_increase=True)

# Round the factor to a sign vector, then fix the global sign so that
# the first entry agrees with xsol (x and -x are otherwise equivalent)
x = U_to_sol(U)
x = x * x[0] * xsol[0]

# Report the final cost terms and whether the signs were recovered
linear_cost = np.sum(-(Y@U)*U)
row_norms = np.sqrt(np.sum(U**2, axis=1))
print("Linear part of the cost: ", linear_cost)
print("Average distance between row norm and 1: ",
      np.sum(np.abs(row_norms - 1)) / n)
recovered = np.linalg.norm(x - xsol) < 1
print("xsol exactly recovered." if recovered else "xsol not recovered.")

# Plot the gap to the best cost seen over the first 90% of iterations;
# the 1e-8 offset keeps the curve positive on the log axis
nb_shown = round(0.9 * nb_its)
pyplot.semilogy(costs[0:nb_shown] - np.min(costs) + 1e-8)
pyplot.xlabel("number of iterations")
pyplot.ylabel("distance to minimum")
pyplot.show()
