import numpy as np
import cosmopower as cp

# Root directory of this likelihood package. File names below are appended by
# plain string concatenation, so the value must end with a path separator.
packages_path = '/absolute/path/to/this/folder/'


def _load_array(relative_path):
    """Return the ``.npy`` array stored at ``packages_path + relative_path``."""
    return np.load(packages_path + relative_path)


def _load_pickled_dict(relative_path):
    """Return the Python object pickled in a 0-d ``.npy`` object array.

    NOTE: ``allow_pickle=True`` executes pickle code on load, so this must
    only ever be pointed at trusted data files shipped with the package.
    ``encoding="latin1"`` is the setting NumPy documents for reading pickles
    written by Python 2.
    """
    return np.load(packages_path + relative_path,
                   allow_pickle=True, encoding="latin1").item()


# Beam-uncertainty templates. do_logp uses columns 0-3 of part 1, part 2 as a
# 1-D vector and columns 0-1 of part 3, i.e. seven templates in total.
beameig1 = _load_array("data/sptpol2023/beameig_part1_ProjectedOut.npy")
beameig2 = _load_array("data/sptpol2023/beameig_part2.npy")
beameig3 = _load_array("data/sptpol2023/beameig_part3.npy")

# Inverse band-power covariance, rescaled by 1e-24.
# NOTE(review): presumably a K^-4 -> uK^-4 unit conversion -- confirm.
invC = _load_array("data/sptpol2023/invC_np_NoBeam_E57.npy") / 1e24

# Measured band powers (read through clhat['bandpowers'][<spectrum key>]) and
# band-power window functions (read through bpwf[<spectrum key>]).
clhat = _load_pickled_dict("data/sptpol2023/clhat.npy")
bpwf = _load_pickled_dict("data/sptpol2023/bpwf_databeam.npy")
# Alternative window functions (judging by the file name, built with the
# simulation beam); to use them, load "data/sptpol2023/bpwf_simbeam.npy"
# in place of the "databeam" file above.

# CosmoPower emulators restored from disk: EE spectrum, TE spectrum and the
# derived-parameter network.
cp_ee = cp.cosmopower_NN(restore=True, restore_filename='/absolute_path_to/EE_v1')
cp_te = cp.cosmopower_PCAplusNN(restore=True, restore_filename='/absolute_path_to/TE_v1')
cp_der = cp.cosmopower_NN(restore=True, restore_filename='/absolute_path_to/DER_v1')



def _dl_to_padded_cl(dl):
    """Convert emulator D_ell (ell = 2..11000) to C_ell on ell = 1..11500.

    The ell = 1 entry and the 500 entries above ell = 11000 are set to zero.
    """
    cl = dl *2*np.pi/np.arange(2,11001)/np.arange(3,11002)
    return np.hstack(( 0., cl, np.zeros(500) ))  # index 0 is ell=1


def _central_difference(cl):
    """Return dC_ell/dell by central differences, with zeros at both ends."""
    return np.hstack((0, cl[2:] - cl[:-2], 0)) / 2


def _beam_factor(amps, rows):
    """Return the product of (1 + A_i * template_i) over the seven beam modes.

    ``amps`` holds (Abeam1, ..., Abeam7) and ``rows`` is the slice of band-power
    rows to evaluate the module-level ``beameig*`` templates on.
    """
    a1, a2, a3, a4, a5, a6, a7 = amps
    return ((1.+a1*beameig1[rows,0]) * (1.+a2*beameig1[rows,1])
            * (1.+a3*beameig1[rows,2]) * (1.+a4*beameig1[rows,3])
            * (1.+a5*beameig2[rows])
            * (1.+a6*beameig3[rows,0]) * (1.+a7*beameig3[rows,1]))


def do_logp(
        # Parameters that we may or may not sample over
        tau_reio,logA,n_s,H0,omega_b,omega_cdm,
        Tcal9to5=1.,
        Pcal090=1.,
        Tcal150=1.,
        Pcal150=1.,
        Abeam1=0,
        Abeam2=0,
        Abeam3=0,
        Abeam4=0,
        Abeam5=0,
        Abeam6=0,
        Abeam7=0,
        kappa= 0,
        beta = 0.0012309,
        czero_psEE_090 = 0,
        czero_psEE_150 = 0,
        ADust_TE_090   = 0,
        ADust_TE_150   = 0,
        ADust_EE_090   = 0,
        ADust_EE_150   = 0,
        alphaDust_TE = -2.42,
        alphaDust_EE = -2.42,
        # Keyword through which the cobaya likelihood instance will be passed.
        _self=None):
    """Return the Gaussian log-likelihood of the TE/EE band powers.

    The model is built from the emulated TE and EE spectra plus foreground and
    correction terms, binned with the band-power window functions, scaled by
    beam and calibration factors, and compared with the measured band powers
    through the inverse covariance ``invC``.

    Parameters
    ----------
    tau_reio, logA, n_s, H0, omega_b, omega_cdm :
        Cosmological parameters passed to the emulators (``logA`` is sent
        under the key ``'ln10^{10}A_s'``).
    Tcal9to5, Tcal150, Pcal090, Pcal150 :
        Calibration factors; the 90 GHz temperature calibration is formed as
        ``Tcal9to5 * Tcal150``. The model band powers are divided by the
        matching product of these factors.
    Abeam1 .. Abeam7 :
        Amplitudes of the seven beam templates; each contributes a factor
        ``1 + A_i * template_i`` to the model band powers.
    kappa :
        Amplitude of the ``-(2*C_ell + ell*dC_ell/dell)`` correction.
        NOTE(review): presumably super-sample lensing -- confirm.
    beta :
        Amplitude of the ``-dipole_cosine * ell * dC_ell/dell`` correction.
        NOTE(review): presumably aberration -- confirm.
    czero_psEE_090, czero_psEE_150 :
        EE Poisson amplitudes given as D_ell at ell = 3000; the cross-frequency
        level is ``sqrt(|090 * 150|)``.
    ADust_TE_090, ADust_TE_150, ADust_EE_090, ADust_EE_150 :
        Dust amplitudes in D_ell at the pivot ell = 80; cross-frequency
        amplitudes are ``sqrt(|090 * 150|)``.
    alphaDust_TE, alphaDust_EE :
        Dust power-law parameters; D_ell scales as ``(ell/80)**(alpha + 2)``.
    _self :
        Unused here; present so cobaya can pass the likelihood instance.

    Returns
    -------
    logp : float
        ``-0.5 * chi2`` with ``chi2 = residual @ invC @ residual``.
    derived : dict
        ``'theta_s_100'``, ``'sigma_8'`` and ``'Y_P'``, taken from the first
        three outputs of the derived-parameter emulator.
    """
    Tcal090 = Tcal9to5 * Tcal150
    params = { 'omega_b':[omega_b], 'omega_cdm':[omega_cdm], 'H0':[H0], 'tau_reio':[tau_reio], 'n_s':[n_s], 'ln10^{10}A_s':[logA] }

    # Emulated spectra. The (2.726e6)^2 prefactor is presumably T_CMB^2 in
    # uK^2 applied to dimensionless emulator output -- TODO confirm.
    dlee = 2.726e6*2.726e6*cp_ee.ten_to_predictions_np(params)[0]
    clee = _dl_to_padded_cl(dlee)
    dlte = 2.726e6*2.726e6*cp_te.predictions_np(params)[0]
    clte = _dl_to_padded_cl(dlte)
    deriv_te = _central_difference(clte)
    deriv_ee = _central_difference(clee)

    # Multipole grid shared by every term below (ell = 1..11500).
    ell = np.arange(1,11501)
    dl2cl = 2.*np.pi /ell/np.arange(2,11502)

    # Poisson amplitudes are quoted as D_ell at ell=3000; convert to flat C_ell.
    d3000 = 3000.*3001./2/np.pi
    poisson_level_EE_090 = czero_psEE_090 /d3000
    poisson_level_EE_9x5 = np.sqrt(np.abs(czero_psEE_090*czero_psEE_150)) /d3000
    poisson_level_EE_150 = czero_psEE_150 /d3000
    dipole_cosine = -0.4033

    # Terms common to all spectra of one type, computed once. The expressions
    # keep the original operation order so the floating-point result is the same.
    dust_shape_TE = (ell/80.)**(alphaDust_TE+2)
    dust_shape_EE = (ell/80.)**(alphaDust_EE+2)
    kappa_term_TE = kappa*(2*clte+deriv_te*ell)
    kappa_term_EE = kappa*(2*clee+deriv_ee*ell)
    beta_term_TE = beta*dipole_cosine*deriv_te*ell
    beta_term_EE = beta*dipole_cosine*deriv_ee*ell

    fgnd090_TE = dl2cl*ADust_TE_090*dust_shape_TE - kappa_term_TE - beta_term_TE
    # The 90x150 foreground model is identical for the TE and ET orderings.
    fgnd9x5_TE = dl2cl*np.sqrt(np.abs(ADust_TE_090*ADust_TE_150))*dust_shape_TE - kappa_term_TE - beta_term_TE
    fgnd150_TE = dl2cl*ADust_TE_150*dust_shape_TE - kappa_term_TE - beta_term_TE
    fgnd090_EE = dl2cl*ADust_EE_090*dust_shape_EE + poisson_level_EE_090 - kappa_term_EE - beta_term_EE
    fgnd9x5_EE = dl2cl*np.sqrt(np.abs(ADust_EE_090*ADust_EE_150))*dust_shape_EE + poisson_level_EE_9x5 - kappa_term_EE - beta_term_EE
    fgnd150_EE = dl2cl*ADust_EE_150*dust_shape_EE + poisson_level_EE_150 - kappa_term_EE - beta_term_EE

    # Beam factors depend only on the band-power row range (90x90, 90x150,
    # 150x150), so three evaluations serve all seven spectra.
    amps = (Abeam1, Abeam2, Abeam3, Abeam4, Abeam5, Abeam6, Abeam7)
    beam_090 = _beam_factor(amps, slice(None, 56))
    beam_9x5 = _beam_factor(amps, slice(56, 112))
    beam_150 = _beam_factor(amps, slice(112, None))

    # (spectrum key, unbinned model C_ell, beam factor, calibration divisor),
    # listed in the order that the rows/columns of invC are indexed by.
    spectra = (
        ('POL_090_TxPOL_090_E', clte + fgnd090_TE, beam_090, Tcal090**2 * Pcal090),
        ('POL_090_TxPOL_150_E', clte + fgnd9x5_TE, beam_9x5, Tcal090*Tcal150 * Pcal150),
        ('POL_150_TxPOL_090_E', clte + fgnd9x5_TE, beam_9x5, Tcal090*Tcal150 * Pcal090),
        ('POL_150_TxPOL_150_E', clte + fgnd150_TE, beam_150, Tcal150**2 * Pcal150),
        ('POL_090_ExPOL_090_E', clee + fgnd090_EE, beam_090, Tcal090**2 * Pcal090**2),
        ('POL_090_ExPOL_150_E', clee + fgnd9x5_EE, beam_9x5, Tcal090*Tcal150 * Pcal090*Pcal150),
        ('POL_150_ExPOL_150_E', clee + fgnd150_EE, beam_150, Tcal150**2 * Pcal150**2),
    )
    residual_parts = []
    for key, model_cl, beam, calibration in spectra:
        theory = bpwf[key] @ model_cl
        theory *= beam / calibration
        residual_parts.append(theory - clhat['bandpowers'][key])
    residual = np.hstack(residual_parts)

    chi2 = residual.T @ invC @ residual
    logp = -0.5 * chi2

    # Evaluate the derived-parameter emulator once and index its outputs.
    derived_values = cp_der.ten_to_predictions_np(params)[0]
    derived = {'theta_s_100': derived_values[0],
               'sigma_8':     derived_values[1],
               'Y_P':         derived_values[2]}
    return logp, derived


