from . import atFunctions as at
from math import atan2,sin,cos,sqrt
import numpy as np

# Mission-time epochs expressed as Modified Julian Dates (MJD 51544.0 is
# 2000-01-01 00:00:00). MAXI_MJD_BASE is the epoch passed to the at.*
# MJD <-> mission-time converters by mjd2maxi()/maxi2mjd().
ASTE_MJD_BASE = 51544.0 # 2000/01/01 00:00:00 0.0
MAXI_MJD_BASE = 51544.0 # 2000/01/01 00:00:00 0.0
# Reference MJD split into integer-day and fractional-day parts. The fraction
# is 64.184 s expressed in days; presumably the TT-UTC offset at the epoch
# (32.184 s + 32 leap seconds) -- confirm against the MAXI time definition.
MAXI_MJDREFI = 51544                # 2000.0 UT
MAXI_MJDREFF = 0.00074287037037037  # (64.184/86400)
# The GPS epoch (1980-01-06 00:00:00) expressed in MAXI time. maxi2gps()
# subtracts it and gps2maxi() adds it to shift between the two time scales.
MAXI_GPS_BASE = -630720013          # 1980/1/6 00:00:00 in maxitime

def _mat2eulerZYX(RM, V):
    """
    rm - (in) 3x3 matrix
    v - (out) 3 dimentional vector
    """
    cy_thresh= 8.8817841970012523e-16;
    cy = sqrt(RM[2][2]*RM[2][2] + RM[1][2]*RM[1][2]);
    if cy > cy_thresh:
        # cos(y) not close to zero, standard form
        V[0]= atan2(-RM[0][1],  RM[0][0])   #atan2(cos(y)*sin(z), cos(y)*cos(z))
        V[1]= atan2(RM[0][2],  cy)          #atan2(sin(y), cy)
        V[2]= atan2(-RM[1][2], RM[2][2])    #atan2(cos(y)*sin(x), cos(x)*cos(y))
    else:
        # cos(y) (close to) zero, so x -> 0.0 (see above)
        # so RM[1][0] -> sin(z), RM[1][1] -> cos(z) and
        # atan2(sin(y), cy)
        V[0]= atan2(RM[1][0],  RM[1][1])
        V[1]= atan2(RM[0][2],  cy)
        V[2]= 0.0

def _euler2matZYX(V, RM):
    """
    V - (in) 3 dimentional vector
    RM - (OUT) 3x3 matrix
    """
    x=V[0];
    y=V[1];
    z=V[2];

    RM[0][0]=cos(y)*cos(z);
    RM[1][0]=sin(x)*sin(y)*cos(z)-cos(x)*sin(z);
    RM[2][0]=cos(x)*sin(y)*cos(z)+sin(x)*sin(z);

    RM[0][1]=cos(y)*sin(z);
    RM[1][1]=sin(x)*sin(y)*sin(z)+cos(x)*cos(z);
    RM[2][1]=cos(x)*sin(y)*sin(z)-sin(x)*cos(z);

    RM[0][2]=-sin(y);
    RM[1][2]=sin(x)*cos(y);
    RM[2][2]=cos(x)*cos(y);

def mjd2maxi(mjd):
    """
    Convert Modified Julian Date(s) to MAXI mission time.

    Delegates to at.atMJDToMission using MAXI_MJD_BASE (MJD 51544.0) as
    the mission epoch. Accepts a scalar or an array:

    double mjd2maxi(double)
    double[] mjd2maxi(double[])
    """
    epoch_mjd = MAXI_MJD_BASE
    return at.atMJDToMission(epoch_mjd, mjd)

def maxi2mjd(mission):
    """
    Convert MAXI mission time to Modified Julian Date(s).

    Delegates to at.atMissionToMJD using MAXI_MJD_BASE (MJD 51544.0) as
    the mission epoch. Accepts a scalar or an array:

    double maxi2mjd(double)
    double[] maxi2mjd(double[])
    """
    epoch_mjd = MAXI_MJD_BASE
    return at.atMissionToMJD(epoch_mjd, mission)

def maxi2gps(mission):
    """
    Convert MAXI time to GPS time.

    MAXI_GPS_BASE is the GPS epoch expressed in MAXI time, so subtracting
    it re-references the value to the GPS epoch. Works element-wise on
    arrays as well as on scalars:

    double maxi2gps(double)
    double[] maxi2gps(double[])
    """
    gps_time = mission - MAXI_GPS_BASE
    return gps_time

def gps2maxi(gps):
    """
    Convert GPS time to MAXI time.

    Adds MAXI_GPS_BASE (the GPS epoch expressed in MAXI time) to the GPS
    value. Works element-wise on arrays as well as on scalars:

    double gps2maxi(double)
    double[] gps2maxi(double[])
    """
    mission_time = gps + MAXI_GPS_BASE
    return mission_time

def quat2EulerZYX(q):
    """\
    Convert quaternion(s) to ZYX Euler angles.

    AtVect atQuat2EulerZYX(AtVect)
    :
    AtVect[] atQuat2EulerZYX(AtVect[])

    q - a single quaternion of shape (4,) or a stack of shape (n, 4).
        Any array-like (e.g. a list) is accepted.
    Returns an array of shape (3,) for a single quaternion, else (n, 3).
    Raises ValueError if q is neither 1-D nor 2-D.
    """
    # Accept lists/tuples as the quat2angle* functions do; ndarrays pass through.
    q = np.asarray(q)
    expand = q.ndim == 1
    if expand:
        q = q.reshape((1, 4))
    if q.ndim != 2:
        raise ValueError("invalid dimention: q")
    n = q.shape[0]
    ea = np.zeros((n, 3), dtype=np.double)
    v = np.zeros((3, ), dtype=np.double)  # scratch output for _mat2eulerZYX
    rm = at.atQuatToRM(q)
    for i in range(n):
        _mat2eulerZYX(rm[i], v)
        # Reverse and negate: ea[i][j] = -v[2 - j].
        ea[i] = -v[::-1]
    if expand:
        return ea[0]
    return ea

def euler2QuatZYX(ea):
    """\
    Convert ZYX Euler angles to quaternion(s).

    AtVect atEuler2QuatZYX(AtVect)
    AtVect[] atEuler2QuatZYX(AtVect[])

    ea - angles of shape (3,) or (n, 3), ordered as _euler2matZYX expects.
         Any array-like (e.g. a list) is accepted.
    Returns the output of at.atRMToQuat: its first row for 1-D input,
    otherwise all rows.
    Raises ValueError if ea is neither 1-D nor 2-D.
    """
    ea = np.asarray(ea)
    expand = ea.ndim == 1
    if expand:
        ea = ea.reshape((1, 3))
    if ea.ndim != 2:
        raise ValueError("invalid dimention: ea")
    n = ea.shape[0]
    rm = np.zeros((n, 3, 3), dtype=np.double)
    for i in range(n):
        # _euler2matZYX only reads V[0..2], so the row is passed directly.
        _euler2matZYX(ea[i], rm[i])
    q = at.atRMToQuat(rm)
    if expand:
        return q[0]
    return q

def sunBeta(mjd, pos, vel):
    """
    Compute the Sun's elevation above the plane spanned by pos and vel.

    double atSunBeta(double, AtVect, AtVect)
    :
    double[] atSunBeta(double[], AtVect[], AtVect[])

    For each sample i the result is PI/2 - angle(sun_dir, pos x vel), where
    sun_dir comes from at.atSun(mjd) and pos/vel are normalised first
    (assumes at.atAngDistance returns radians -- confirm).

    mjd      - a number or 1-D array of MJDs.
    pos, vel - vectors of shape (3,) or (n, 3); any array-like is accepted.
    An input with fewer rows than the longest one has its last row reused
    (min(i, len - 1) indexing); at.valid_array_len decides which length
    combinations are allowed.
    Returns a scalar when mjd is a number and pos/vel are 1-D, otherwise a
    1-D array of length n.
    Raises ValueError on bad dimensions or incompatible lengths.
    """
    pos = np.asarray(pos)
    vel = np.asarray(vel)
    expand = bool(at.is_number(mjd) and pos.ndim == 1 and vel.ndim == 1)
    mjd = at.n2np(mjd)
    if pos.ndim == 1:
        pos = pos.reshape((1, 3))
    if vel.ndim == 1:
        vel = vel.reshape((1, 3))
    if mjd.ndim != 1:
        raise ValueError("Invalid dimention: mjd")
    if pos.ndim != 2:
        raise ValueError("Invalid dimention pos")
    if vel.ndim != 2:
        raise ValueError("Invalid dimention vel")
    if not at.valid_array_len(mjd, pos, vel):
        raise ValueError("Invalid array length")
    n_mjd, n_pos, n_vel = mjd.shape[0], pos.shape[0], vel.shape[0]
    n = max(n_mjd, n_pos, n_vel)
    beta = np.zeros((n, ), dtype=np.double)
    for i in range(n):
        nvect_sun = at.atNormVect(at.atSun(mjd[min(i, n_mjd - 1)]))
        axis_x = at.atNormVect(pos[min(i, n_pos - 1)])
        axis_y = at.atNormVect(vel[min(i, n_vel - 1)])
        # Normal to the plane spanned by pos and vel.
        axis_z = at.atVectProd(axis_x, axis_y)
        beta[i] = at.PI / 2.0 - at.atAngDistance(nvect_sun, axis_z)
    if expand:
        return beta[0]
    return beta

def quatDiff(q1, q2):
    """
    Compute the quaternion difference between q1 and q2.

    AtQuat atQuatDiff(AtQuat, AtQuat)
    AtQuat[] atQuatDiff(AtQuat[], AtQuat[])

    The result is at.atQuatProd(inverse(q2), q1). The inverse of q2 is
    obtained by inverting its rotation matrix (at.atInvRotMat). The product
    is passed through a quaternion -> matrix -> quaternion round trip before
    being returned.

    q1, q2 - quaternions of shape (4,) or (n, 4); any array-like is accepted.
    Returns a single quaternion when both inputs are 1-D, otherwise an array.
    Raises ValueError on bad dimensions or incompatible lengths.
    """
    q1 = np.asarray(q1)
    q2 = np.asarray(q2)
    expand = q1.ndim == 1 and q2.ndim == 1
    if q1.ndim == 1:
        q1 = q1.reshape((1, 4))
    if q2.ndim == 1:
        q2 = q2.reshape((1, 4))
    if q1.ndim != 2:
        raise ValueError("Invalid dimention q1")
    if q2.ndim != 2:
        raise ValueError("Invalid dimention q2")
    if not at.valid_array_len(q1, q2):
        raise ValueError("invalid array length")
    # Invert q2 by inverting its rotation matrix.
    iq2 = at.atRMToQuat(at.atInvRotMat(at.atQuatToRM(q2)))
    q3 = at.atQuatProd(iq2, q1)
    # Round trip through a rotation matrix; presumably this renormalises /
    # canonicalises the product -- confirm against at.atRMToQuat.
    q = at.atRMToQuat(at.atQuatToRM(q3))
    if expand:
        return q[0]
    return q


## matlab-compatible
# E2Q(input:YPR, output:Q)
def angle2quatZYX(ang):
    """\
    Convert ZYX rotation angles (yaw, pitch, roll) to quaternion(s),
    following MATLAB's angle2quat(r1, r2, r3, 'ZYX').

    ang - angles of shape (3,) or (n, 3); columns are (yaw, pitch, roll).
          Any array-like (e.g. a list) is accepted.
    Returns quaternion(s) of shape (4,) or (n, 4) with the scalar part LAST,
    i.e. the identity rotation is [0, 0, 0, 1].
    Raises ValueError if ang is neither 1-D nor 2-D.
    """
    # Accept lists/tuples, consistent with the quat2angle* functions.
    ang = np.asarray(ang)
    expand = ang.ndim == 1
    if expand:
        ang = ang.reshape((1, 3))
    if ang.ndim != 2:
        raise ValueError("invalid dimention: ang")
    n = ang.shape[0]
    q = np.zeros((n, 4), dtype=np.double)
    # Half-angle sines/cosines for yaw (1), pitch (2) and roll (3).
    c1 = np.cos(ang[:, 0] * 0.5)
    s1 = np.sin(ang[:, 0] * 0.5)
    c2 = np.cos(ang[:, 1] * 0.5)
    s2 = np.sin(ang[:, 1] * 0.5)
    c3 = np.cos(ang[:, 2] * 0.5)
    s3 = np.sin(ang[:, 2] * 0.5)
    q[:, 3] = c1 * c2 * c3 + s1 * s2 * s3  # scalar part
    q[:, 0] = c1 * c2 * s3 - s1 * s2 * c3
    q[:, 1] = c1 * s2 * c3 + s1 * c2 * s3
    q[:, 2] = s1 * c2 * c3 - c1 * s2 * s3
    if expand:
        return q[0]
    return q

# matlab-compatible
# Q2E(input:Q, output:YPR)
def quat2angleZYX(q):
    """
    Convert quaternion(s) to ZYX angles (yaw, pitch, roll), following
    MATLAB's quat2angle(q, 'ZYX').

    q - quaternion(s) of shape (4,) or (n, 4) with the scalar part LAST.
        The input is copied (never modified) and each row is normalised.
    Returns angles of shape (3,) or (n, 3). Rows whose quaternion has zero
    norm are returned as NaN.
    Raises ValueError if q is neither 1-D nor 2-D.
    """
    q = np.array(q, dtype=np.double)
    expand = q.ndim == 1
    if expand:
        q = q.reshape((1, 4))
    if q.ndim != 2:
        raise ValueError("invalid dimention: q")
    ang = np.zeros((q.shape[0], 3), dtype=np.double)
    norm = np.sqrt(q[:, 0]**2 + q[:, 1]**2 + q[:, 2]**2 + q[:, 3]**2)
    zero = norm <= 0
    # Use the identity for zero quaternions to avoid 0/0; their output row
    # is overwritten with NaN below.
    q[zero] = np.array([0, 0, 0, 1])
    norm[zero] = 1
    q /= norm[:, np.newaxis]
    x, y, z, w = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
    # Direction-cosine-matrix elements needed for this rotation order.
    C11 = w**2 + x**2 - y**2 - z**2
    C12 = 2.0 * (x * y + w * z)
    C13 = 2.0 * (x * z - w * y)
    C23 = 2.0 * (y * z + w * x)
    C33 = w**2 - x**2 - y**2 + z**2
    ang[:, 0] = np.arctan2(C12, C11)
    # Rounding can push |C13| just past 1 near pitch = +/-90 deg, which would
    # make arcsin return NaN; clamp as MATLAB's threeaxisrot does.
    ang[:, 1] = -np.arcsin(np.clip(C13, -1.0, 1.0))
    ang[:, 2] = np.arctan2(C23, C33)
    ang[zero] = np.array([np.nan, np.nan, np.nan])
    if expand:
        return ang[0]
    return ang

# matlab-compatible
# Q2E(input:Q, output:YRP)
def quat2angleZXY(q):
    """
    Convert quaternion(s) to ZXY angles (yaw, roll, pitch), following
    MATLAB's quat2angle(q, 'ZXY').

    q - quaternion(s) of shape (4,) or (n, 4) with the scalar part LAST.
        The input is copied (never modified) and each row is normalised.
    Returns angles of shape (3,) or (n, 3). Rows whose quaternion has zero
    norm are returned as NaN.
    Raises ValueError if q is neither 1-D nor 2-D.
    """
    q = np.array(q, dtype=np.double)
    expand = q.ndim == 1
    if expand:
        q = q.reshape((1, 4))
    if q.ndim != 2:
        raise ValueError("invalid dimention: q")
    ang = np.zeros((q.shape[0], 3), dtype=np.double)
    norm = np.sqrt(q[:, 0]**2 + q[:, 1]**2 + q[:, 2]**2 + q[:, 3]**2)
    zero = norm <= 0
    # Use the identity for zero quaternions to avoid 0/0; their output row
    # is overwritten with NaN below.
    q[zero] = np.array([0, 0, 0, 1])
    norm[zero] = 1
    q /= norm[:, np.newaxis]
    x, y, z, w = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
    # Direction-cosine-matrix elements needed for this rotation order.
    C13 = 2.0 * (x * z - w * y)
    C21 = 2.0 * (x * y - w * z)
    C22 = w**2 - x**2 + y**2 - z**2
    C23 = 2.0 * (y * z + w * x)
    C33 = w**2 - x**2 - y**2 + z**2
    ang[:, 0] = np.arctan2(-C21, C22)
    # Rounding can push |C23| just past 1 near roll = +/-90 deg, which would
    # make arcsin return NaN; clamp as MATLAB's threeaxisrot does.
    ang[:, 1] = np.arcsin(np.clip(C23, -1.0, 1.0))
    ang[:, 2] = np.arctan2(-C13, C33)
    ang[zero] = np.array([np.nan, np.nan, np.nan])
    if expand:
        return ang[0]
    return ang

# matlab-compatible
# Q2E(input:Q, output:PRY)
def quat2angleYXZ(q):
    """
    Convert quaternion(s) to YXZ angles (pitch, roll, yaw), following
    MATLAB's quat2angle(q, 'YXZ').

    q - quaternion(s) of shape (4,) or (n, 4) with the scalar part LAST.
        The input is copied (never modified) and each row is normalised.
    Returns angles of shape (3,) or (n, 3). Rows whose quaternion has zero
    norm are returned as NaN.
    Raises ValueError if q is neither 1-D nor 2-D.
    """
    q = np.array(q, dtype=np.double)
    expand = q.ndim == 1
    if expand:
        q = q.reshape((1, 4))
    if q.ndim != 2:
        raise ValueError("invalid dimention: q")
    ang = np.zeros((q.shape[0], 3), dtype=np.double)
    norm = np.sqrt(q[:, 0]**2 + q[:, 1]**2 + q[:, 2]**2 + q[:, 3]**2)
    zero = norm <= 0
    # Use the identity for zero quaternions to avoid 0/0; their output row
    # is overwritten with NaN below.
    q[zero] = np.array([0, 0, 0, 1])
    norm[zero] = 1
    q /= norm[:, np.newaxis]
    x, y, z, w = q[:, 0], q[:, 1], q[:, 2], q[:, 3]
    # Direction-cosine-matrix elements needed for this rotation order.
    C12 = 2.0 * (x * y + w * z)
    C22 = w**2 - x**2 + y**2 - z**2
    C31 = 2.0 * (x * z + w * y)
    C32 = 2.0 * (y * z - w * x)
    C33 = w**2 - x**2 - y**2 + z**2
    ang[:, 0] = np.arctan2(C31, C33)
    # Rounding can push |C32| just past 1 near roll = +/-90 deg, which would
    # make arcsin return NaN; clamp as MATLAB's threeaxisrot does.
    ang[:, 1] = -np.arcsin(np.clip(C32, -1.0, 1.0))
    ang[:, 2] = np.arctan2(C12, C22)
    ang[zero] = np.array([np.nan, np.nan, np.nan])
    if expand:
        return ang[0]
    return ang

