module Muse3GLike

using ComponentArrays, Distributions, ForwardDiff, HDF5, LinearAlgebra, 
    PDMats, Random, Requires, SparseArrays, Statistics


# Likelihood object holding the products assembled by the outer constructor
# `spt3g_2yr_delensed_ee_optimal_pp_muse(filename; components, ℓ)`.
# All fields except `c0` are parametrically typed.
struct spt3g_2yr_delensed_ee_optimal_pp_muse{D1,Σ1,D2,Σ2,S,B,P,T,C,SD}
    d :: D1                  # "data" entry of the input file, restricted to `components`
    Σ :: Σ1                  # "covariance" entry of the input file, restricted to `components`
    d_transformed :: D2      # `transform(preprocess(d ./ s))`
    Σ_transformed :: Σ2      # PDMat covariance used for the Gaussian in the transformed space
    s :: S                   # "scale" entry of the input file, restricted to `components`
    BPWF :: B                # Dict (Symbol keys) built from the "BPWF" entry of the input file
    preprocess :: P          # closure applied before `transform` (drops the last EE bin)
    transform :: T           # closure applying `bandpower_transform` to bandpower components
    components :: C          # component names this likelihood was built for
    systematics_dist :: SD   # `nothing`, or Uniform priors for the non-bandpower components
    c0 :: Float64            # constant subtracted from `logpdf` in `loglike`
end
# Compact display: the type name followed by the components in use.
function Base.show(io::IO, s::spt3g_2yr_delensed_ee_optimal_pp_muse)
    return print(io, "spt3g_2yr_delensed_ee_optimal_pp_muse($(s.components))")
end


# Build the likelihood object from an HDF5 file.
#
# Arguments
# - `filename`: HDF5 file providing the entries "data", "axis", "scale",
#   "covariance" and "BPWF"; defaults to `muse3glike/dat/90_150_220.h5` next to
#   this source file.
# - `components` (keyword): names of the components to keep.
# - `ℓ` (keyword): indices used to slice every "BPWF" entry (the "ℓ" entry as a
#   vector, all other entries along their second dimension).
#
# Returns a `spt3g_2yr_delensed_ee_optimal_pp_muse` struct.
function spt3g_2yr_delensed_ee_optimal_pp_muse(
    filename = joinpath(@__DIR__, "muse3glike/dat/90_150_220.h5");
    components = (:ϕϕ, :EE),
    ℓ = 1:5000
)

    h5open(filename) do f

        # labelled zero vector with the layout described by the file's "axis" entry.
        # NOTE(review): `eval(Meta.parse(...))` executes whatever code the "axis"
        # string contains, so only open trusted files.
        𝟘 = zero(ComponentVector(read(f,"data"), ComponentArrays.Axis(eval(Meta.parse(read(f,"axis"))))))

        # preprocess, here drop last EE bin
        preprocess = x -> ComponentVector(k => (k == :EE) ? x.EE[1:end-1] : x[k] for k in components)
        # Jacobian of `preprocess`; the second term is an outer product of zero
        # vectors, presumably added only to attach component axes to the result
        # (TODO confirm)
        preprocess_jac = ForwardDiff.jacobian(preprocess, select(𝟘,components)) + preprocess(select(𝟘,components)) * select(𝟘,components)'
        
        # bandpower likelihood is taken as Gaussian in this transformed space. volume factor keeps uniform prior on bandpowers.
        transform = x -> ComponentVector(k => isbandpowers(k) ? bandpower_transform.(x[k]) : x[k] for k in keys(x))
        # diagonal with 1/0.8^2 on bandpower entries and 0 elsewhere; adding the
        # (zero) preprocessed block gives each entry the length of that block.
        # NOTE(review): the origin of the 0.8 constant is not visible here.
        transform_volume = Diagonal(ComponentVector(k => (isbandpowers(k) ? 1/0.8^2 : 0) .+ preprocess(𝟘)[k] for k in components))

        # preprocess and transform the data and covariance matrix
        # (the `+ 𝟘` / `+ 𝟘 * 𝟘'` terms add zeros; presumably they are there to
        # turn the plain arrays read from file into labelled ones - TODO confirm)
        s = select(read(f, "scale") + 𝟘, components)
        d = select(read(f, "data") + 𝟘, components)
        d_transformed = transform(preprocess(d ./ s))
        Σ = select(Symmetric(read(f, "covariance")) + 𝟘 * 𝟘', components)
        # inv( inv(J * (Σ ./ s ./ s') * J') + transform_volume ), with J = preprocess_jac
        Σ_transformed = PDMat(Symmetric(inv(inv(preprocess_jac * (Σ ./ s ./ s') * preprocess_jac') + transform_volume)))
        c0 = Distributions.mvnormal_c0(MvNormal(Σ_transformed)) # constant to match Python's loglike

        # bandpower window functions
        BPWF = Dict(Symbol(k) => (k == "ℓ") ? v[ℓ] : sparse(v[:,ℓ]) for (k,v) in read(f, "BPWF"))

        # prior for systematics which are explicitly sampled:
        # Uniform between d - 5*sqrt(diag(Σ)) and d + 5*sqrt(diag(Σ)), or `nothing`
        # when `components` contains only bandpowers
        components_syst = filter(!isbandpowers, components)
        systematics_dist = isempty(components_syst) ? nothing : select(Uniform.(eachcol(d .+ [-5 5] .* sqrt.(diag(Σ)))...), components_syst)

        spt3g_2yr_delensed_ee_optimal_pp_muse(d, Σ, d_transformed, Σ_transformed, s, BPWF, preprocess, transform, components, systematics_dist, c0)
    
    end

end

# Sampling distribution of the transformed data given inputs `x`.
#
# For each key of `x`: bandpower components are binned with the matching window
# function and divided by the stored scale, every other component is passed
# through unchanged. The resulting vector is preprocessed and transformed and
# used as the mean of an MvNormal with covariance `spt.Σ_transformed`.
function d_transformed_dist(spt::spt3g_2yr_delensed_ee_optimal_pp_muse, x)
    function predicted(k)
        if isbandpowers(k)
            return (spt.BPWF[k] * x[k]) ./ spt.s[k]
        else
            return x[k]
        end
    end
    model = ComponentVector(k => predicted(k) for k in keys(x))
    mean_transformed = spt.transform(spt.preprocess(model))
    return MvNormal(mean_transformed, spt.Σ_transformed)
end

# Log-likelihood of the stored transformed data given inputs `x`, with the
# constant `spt.c0` subtracted.
function loglike(spt::spt3g_2yr_delensed_ee_optimal_pp_muse, x)
    dist = d_transformed_dist(spt, x)
    logp = logpdf(dist, spt.d_transformed)
    return logp - spt.c0
end

# Turing model, registered through `@init @require` so that it is only evaluated
# once the Turing package (identified by the UUID below) is loaded.
@init @require Turing="fce5fe82-541a-59a6-adf8-730c64b5f9a0" @eval begin
    # `spectra` is splatted into the inputs of `d_transformed_dist`.
    # Returns a NamedTuple with `d_transformed` and the mean of its distribution.
    Turing.@model function turing_model_for_bandpowers(spt::spt3g_2yr_delensed_ee_optimal_pp_muse, spectra)
        # systematics are drawn from their priors when present, otherwise none are used.
        # NOTE(review): `spt.systematics_dist.v` reads a property `v` of the stored
        # prior container; the constructor stores a ComponentVector of Uniforms
        # there - confirm this should not be `getaxes(spt.systematics_dist)`.
        systematics = isnothing(spt.systematics_dist) ? () : ComponentVector(systematics ~ Turing.arraydist(spt.systematics_dist), getaxes(spt.systematics_dist.v))
        dist = d_transformed_dist(spt, (;spectra..., systematics...))
        d_transformed ~ dist
        # a NaN anywhere in d_transformed adds -Inf to the log probability
        any(isnan, d_transformed) && Turing.@addlogprob!(-Inf)
        (;d_transformed, μ_transformed=mean(dist))
    end
end

# Example inputs for every component of `spt`.
#
# Bandpower components are read from `muse3glike/dat/example_inputs.h5` (next to
# this source file) under the component's name; `:Acals` is set to ones and any
# other component to zeros, each shaped like the matching block of `spt.d`.
function example_inputs(spt::spt3g_2yr_delensed_ee_optimal_pp_muse)
    path = joinpath(@__DIR__, "muse3glike/dat/example_inputs.h5")
    h5open(path) do f
        function input_for(k)
            if isbandpowers(k)
                return read(f, string(k))
            elseif k == :Acals
                return one.(spt.d.Acals)
            else
                return zero.(spt.d[k])
            end
        end
        ComponentVector(k => input_for(k) for k in spt.components)
    end
end


# Map a scalar `x` from the open interval (0.01, 2.01) to the real line via
# `atanh(x - 1.01)`. Anything outside that interval (including NaN) maps to NaN.
#
# The NaN is built in the type of `x - 1.01`, i.e. the type the `atanh` branch
# returns. Both branches therefore agree on the return type, and non-float
# inputs such as `bandpower_transform(3)` yield NaN instead of throwing an
# `InexactError` from `Int(NaN)`.
function bandpower_transform(x)
    0.01 < x < 2.01 || return oftype(x - 1.01, NaN)
    return atanh(x - 1.01)
end
# Inverse of the bandpower transform: elementwise `tanh(x) + 1.01`.
function bandpower_untransform(x)
    return @. tanh(x) + 1.01
end
# True when `k` names one of the bandpower components (`:ϕϕ` or `:EE`).
function isbandpowers(k)
    bandpower_names = (:ϕϕ, :EE)
    return k ∈ bandpower_names
end



# These operations don't work correctly in ComponentArrays, so we define
# custom versions here that do. See:
# https://github.com/jonniedie/ComponentArrays.jl/issues/256 and
# https://github.com/jonniedie/ComponentArrays.jl/issues/257
# Build a new ComponentArray containing only the components of `c` named in
# `keys`, in the order given.
function select(c::ComponentArray, keys)
    chosen = (name => c[name] for name in keys)
    return ComponentArray(; chosen...)
end
# Matrix version: keep the rows and columns of `c` that belong to the
# components named in `keys`. The flat indices are found by running the vector
# `select` on a labelled index vector that shares the first axis of `c`.
function select(c::ComponentMatrix, keys)
    row_axis = getaxes(c)[1]
    idx = select(ComponentVector(1:size(c,1), row_axis), keys)
    kept_axis = getaxes(idx)[1]
    return ComponentMatrix(c[idx, idx], (kept_axis, kept_axis))
end
# The default constructor recomputes the Cholesky factorization of Σ every time
# eltype(μ) <: Dual. Avoid that by calling the fully-parameterized constructor
# directly.
function Distributions.MvNormal(
    μ::V, Σ::M
) where {R<:Real,T,D<:ForwardDiff.Dual{T,R},V<:AbstractVector{D},M<:AbstractPDMat{R}}
    return Distributions.MvNormal{D,M,V}(μ, Σ)
end
# Sample a Product distribution by broadcasting `rand` over its component
# distributions, so that arraydist(::ComponentVector) preserves labels.
function Distributions.rand(rng::Random.AbstractRNG, dist::Product)
    return rand.(Ref(rng), dist.v)
end
# Convenience constructor: collect a generator of `name => value` pairs into a
# NamedTuple and wrap it in a ComponentVector (saves a few parentheses at the
# call sites in this module).
function ComponentArrays.ComponentVector(g::Base.Generator)
    return ComponentVector(NamedTuple(g))
end

end

