#! /usr/bin/env python
## gibbs_sampling_normal_mixture.py

import numpy as np
import matplotlib.pyplot as plt
from scipy.special import gammaln

def full_con_p(x,sx,sxx,n,al,a,b,la,mu0=0.0):
    """Return the unnormalised weight for assigning ``x`` to one mixture component.

    The weight is the collapsed Dirichlet prior weight ``al + n`` times the
    normal-gamma posterior predictive density of ``x`` given the ``n``
    observations already in the component. The constant factor
    ``1/sqrt(2*pi)`` is omitted because it cancels when the weights are
    normalised across components.

    Parameters
    ----------
    x : float or ndarray
        Candidate observation(s); arrays are evaluated element-wise.
    sx, sxx : float
        Sum and sum of squares of the component's current observations.
    n : int
        Number of observations currently in the component.
    al : float
        Dirichlet concentration.
    a, b : float
        Shape and rate of the gamma prior on the component precision.
    la : float
        Prior precision scaling of the component mean.
    mu0 : float, default 0.0
        Prior mean of the component mean (0 reproduces the original formula).

    Returns
    -------
    float or ndarray
        ``(al + n) * sqrt(2*pi) * p(x | component data)``.
    """
    a_star = a+.5*n
    # The prior mean enters the posterior rate only through la*mu0 and
    # la*mu0**2; with mu0 == 0 both terms vanish exactly.
    lmu = la*mu0
    # Posterior rate without x (b_n) and with x added (b_n1).
    b_n = b+.5*(sxx + lmu*mu0 - (sx+lmu)**2/(la+n))
    b_n1 = b+.5*(sxx+x**2 + lmu*mu0 - (sx+x+lmu)**2/(la+n+1))
    lp = np.log(al+n) + gammaln(a_star+.5) - gammaln(a_star)
    # .5*log((la+n)/(la+n+1)) written as in the original formulation.
    lp += .5*np.log(1-1/(la+n+1))
    lp += a_star*np.log(b_n)
    lp -= (a_star+.5)*np.log(b_n1)
    return np.exp(lp)

def gibbs_sampling(x,m=2,mu=1,al=.1,a=.1,b=.1,la=1,M=100,seed=0):
    """Run a collapsed Gibbs sampler over the allocations of an m-component normal mixture.

    Component weights and each component's (mean, precision) are integrated
    out. Only the allocation ``z`` is resampled, one observation at a time,
    with probabilities proportional to ``full_con_p``.

    Parameters
    ----------
    x : 1-D array-like of float
        Observations. The initial allocation splits ``x`` into ``m``
        consecutive blocks, so pre-sorting ``x`` gives a sensible start.
    m : int
        Number of mixture components.
    mu : float
        Prior mean of the component means.
    al, a, b, la : float
        Dirichlet concentration, gamma shape, gamma rate and prior precision
        scaling of the component mean (passed through to ``full_con_p``).
    M : int
        Number of full sweeps over the data.
    seed : int or None
        Seed for ``numpy.random.default_rng``.

    Returns
    -------
    tuple
        ``(z, ns, means)``: final allocation per observation, component
        counts, and component sample means (NaN for an empty component).
        The weights ``ns / n`` and the means are also printed, one
        comma-separated line each.

    Raises
    ------
    ValueError
        If ``m < 1`` or ``x`` is empty.
    """
    x = np.asarray(x, dtype=float)
    n = len(x)
    if m < 1:
        raise ValueError("m must be at least 1")
    if n == 0:
        raise ValueError("x must contain at least one observation")
    gen = np.random.default_rng(seed=seed)
    # full_con_p (as called here) uses a prior mean of 0. The model is
    # translation invariant, so centring the data on mu imposes prior mean mu.
    xc = x - mu
    # Initial allocation: m consecutive blocks, the last absorbing the remainder.
    ns = np.full(m, n // m)
    ns[-1] = n - (m - 1) * (n // m)
    z = np.repeat(np.arange(m), ns)
    sx = np.bincount(z, weights=xc, minlength=m)
    sxx = np.bincount(z, weights=xc**2, minlength=m)
    pz = np.empty(m)
    for _ in range(M):
        for i in range(n):
            xi = xc[i]
            # Remove observation i from its current component before scoring.
            k = z[i]
            ns[k] -= 1
            sx[k] -= xi
            sxx[k] -= xi**2
            for j in range(m):
                pz[j] = full_con_p(xi, sx[j], sxx[j], ns[j], al, a, b, la)
            k = gen.choice(m, p=pz / pz.sum())
            z[i] = k
            ns[k] += 1
            sx[k] += xi
            sxx[k] += xi**2
    weights = ns / n
    # An empty component has no mean; report NaN instead of dividing by zero.
    with np.errstate(divide="ignore", invalid="ignore"):
        means = np.where(ns > 0, sx / ns, np.nan) + mu
    print(", ".join(map(str, weights)))
    print(", ".join(map(str, means)))
    return z, ns, means

def simulate_beta_mixture(n, beta_pars, probs, seed=0):
    """Draw ``n`` samples from a finite mixture of beta distributions.

    Parameters
    ----------
    n : int
        Number of samples.
    beta_pars : sequence of (alpha, beta) pairs
        Shape parameters of each component.
    probs : sequence of float
        Mixing proportions, one per component; must sum to 1 as required by
        ``Generator.choice``.
    seed : int or None, default 0
        Seed for ``numpy.random.default_rng``; a fixed default keeps the
        draws reproducible.

    Returns
    -------
    numpy.ndarray
        1-D float array of length ``n`` with values in [0, 1].
    """
    gen = np.random.default_rng(seed=seed)
    pars = np.asarray(beta_pars, dtype=float)
    # Component label for each draw, then one beta draw per label (vectorised).
    z = gen.choice(len(probs), n, p=probs)
    return gen.beta(pars[z, 0], pars[z, 1])

if __name__ == "__main__":
    # Demo: 30/70 mixture of Beta(20, 10) and Beta(2, 3) draws. Sorting x
    # means gibbs_sampling's block-wise initial allocation groups
    # neighbouring values. Guarded so importing this module does not run the
    # (slow) sampler.
    x = np.sort(simulate_beta_mixture(10000,[[20,10],[2,3]],[0.3,0.7]))
    gibbs_sampling(x,2)
