import qiskit 
import qiskit.quantum_info as qiskit_quantum_info
import numpy as np

# Width of the register allocated by make_circuit (its gates address qubits 0 and 1).
nbqubits: int = 2

def make_circuit(x, t):
    """Build the two-qubit variational circuit for a single sample.

    RY rotation layers alternate with CZ(0, 1) entanglers, in this order:
    features x[0], x[1] -> CZ -> parameters t[0], t[1] -> CZ ->
    features x[2], x[3] -> CZ -> parameters t[2], t[3].
    Each RY gate is labelled with the name of the angle it applies
    ("x0" ... "x3", "t0" ... "t3").

    Args:
        x: Indexable with at least 4 rotation angles for one sample
            (the usage example in this module feeds rows produced by
            ``rescale_to_angle``).
        t: Indexable with at least 4 trainable rotation angles.

    Returns:
        A ``qiskit.QuantumCircuit`` on ``nbqubits`` qubits. Only qubits
        0 and 1 are addressed.
    """
    qc = qiskit.QuantumCircuit(nbqubits)
    # Each stage: (angle source, indices used for qubits 0 and 1,
    # gate labels for qubits 0 and 1, whether a CZ(0, 1) follows).
    stages = (
        (x, (0, 1), ("x0", "x1"), True),
        (t, (0, 1), ("t0", "t1"), True),
        (x, (2, 3), ("x2", "x3"), True),
        (t, (2, 3), ("t2", "t3"), False),
    )
    for angles, indices, labels, entangle in stages:
        for qubit, (index, label) in enumerate(zip(indices, labels)):
            # Third positional argument of QuantumCircuit.ry is the gate label.
            qc.ry(angles[index], qubit, label)
        if entangle:
            qc.cz(0, 1)
    return qc

def _predict_proba(x, t):
    """Simulate the circuit for one sample and return its class scores.

    The circuit from ``make_circuit`` is evaluated exactly via a
    statevector (no shot sampling), so the result is deterministic.

    Args:
        x: Encoding angles for one sample (see ``make_circuit``).
        t: Trainable angles (see ``make_circuit``).

    Returns:
        The basis-state probabilities with the final entry removed. For
        the two-qubit circuit this leaves 3 values. The dropped outcome
        is the one where both qubits read 1. The remaining values need
        not sum to 1.
    """
    state = qiskit_quantum_info.Statevector.from_instruction(make_circuit(x, t))
    # Keep every basis-state probability except the last one; the kept
    # entries serve as per-class scores.
    return state.probabilities()[:-1]

def predict_proba(X, t):
    """Return the class scores of every sample in X, stacked row-wise.

    Args:
        X: Iterable of per-sample angle vectors (one row per sample).
        t: Trainable angles shared by all samples.

    Returns:
        ``numpy.ndarray`` whose i-th row is ``_predict_proba(X[i], t)``.
    """
    per_sample = [_predict_proba(sample, t) for sample in X]
    return np.asarray(per_sample)


def rescale_to_angle(X, X_min, X_max):
    """Linearly map X from [X_min, X_max] onto angles in [pi/2, 3*pi/2].

    X_min maps to pi/2 and X_max maps to 3*pi/2. Values outside the range
    extrapolate linearly. The arithmetic is element-wise, so NumPy
    broadcasting lets per-column bounds (e.g. ``X.min(axis=0)``) be
    passed for a 2-D X.

    NOTE(review): if X_max == X_min for any feature, the division
    produces inf/nan for NumPy inputs (ZeroDivisionError for plain
    Python floats).
    """
    # Operand order matches the original expression so float results are
    # bit-identical: (pi * offset) / span, then shift by pi/2.
    offset = X - X_min
    span = X_max - X_min
    scaled = np.pi * offset / span
    return scaled + np.pi / 2

# Example (caller importing this module as `qmlmodel`); equivalent to qmlmodel.predict(...):
# qmlmodel.predict_proba(qmlmodel.rescale_to_angle(iris.X_test, X_min, X_max), t).argmax(axis=1)

def predict(X, t):
    """Return the predicted class index for each sample in X.

    The prediction is the column index of the largest score in
    ``predict_proba(X, t)``.
    """
    scores = predict_proba(X, t)
    return np.argmax(scores, axis=1)

from scipy.special import softmax


def loss(t, X_train_rescaled, y_true):
    """Mean cross-entropy of softmax-normalised class scores.

    Args:
        t: Trainable angles passed through to ``predict_proba``. It is the
            first parameter, which matches the ``fun(x, *args)`` convention
            of optimizers such as ``scipy.optimize.minimize``.
            NOTE(review): presumably the intended caller; confirm.
        X_train_rescaled: Per-sample angle vectors, already rescaled
            (e.g. with ``rescale_to_angle``).
        y_true: Integer class index for each sample, used to pick one
            column per row.

    Returns:
        Mean over samples of ``-log(softmax(scores)[i, y_true[i]])``.

    NOTE(review): softmax is applied to probabilities in [0, 1], not to
    logits. With the three scores produced for two qubits, even a perfect
    prediction gives ``-log(e / (e + 2)) ~= 0.55``. Confirm this is
    intended rather than renormalising each row by its sum.
    """
    normalised = softmax(predict_proba(X_train_rescaled, t), axis=1)
    # Fancy indexing selects the true-class entry of each row.
    picked = normalised[np.arange(len(y_true)), y_true]
    return np.mean(-np.log(picked))
    