#! /usr/bin/env python
## gibbs_sampling_2d.py
import numpy as np
import matplotlib.pyplot as plt

def full_conditional(y, mu, g):
    """Draw one coordinate conditionally on the other coordinate ``y``.

    A Bernoulli variable with success probability ``1 / (1 + exp(-2*y*mu))``
    selects the sign of the mean; the returned value is a standard normal
    draw shifted by ``+mu`` on success and by ``-mu`` otherwise.

    Parameters
    ----------
    y : current value of the other coordinate.
    mu : magnitude of the mean shift.
    g : random generator providing ``binomial(n, p)`` and ``normal()``.

    Returns
    -------
    The sampled value, ``g.normal()`` plus the signed shift.
    """
    # Logistic weight of the "+mu" branch given the other coordinate.
    prob_plus = 1 / (1 + np.exp(-2 * y * mu))
    # The sign indicator is drawn before the noise so that the generator is
    # consumed in a fixed order: binomial first, then normal.
    if g.binomial(1, prob_plus):
        shift = mu
    else:
        shift = -mu
    return g.normal() + shift

def gibbs_sampling(M=100, seed=9, initial=(0, 0), mu=1):
    """Run a two-coordinate Gibbs sampler and return every visited state.

    Each sweep updates coordinate 0 and then coordinate 1, drawing each from
    ``full_conditional`` given the current value of the other coordinate.

    Parameters
    ----------
    M : number of sweeps to perform (must be non-negative).
    seed : seed passed to ``np.random.default_rng``.
    initial : starting point, a sequence of exactly two numbers. It is
        copied, so the caller's object is never modified.
    mu : mean-shift magnitude forwarded to ``full_conditional``.

    Returns
    -------
    Array of shape ``(M + 1, 2)``: row 0 is ``initial`` and row ``i + 1`` is
    the state after sweep ``i``.

    Raises
    ------
    ValueError
        If ``M`` is negative or ``initial`` does not have two entries.
    """
    if M < 0:
        raise ValueError(f"M must be non-negative, got {M}")
    # Work on a private copy so neither the default nor a caller-supplied
    # sequence is mutated by the in-place coordinate updates below.
    state = list(initial)
    if len(state) != 2:
        # Without this check a length-1 start would silently broadcast into
        # row 0 and only fail later inside the sweep.
        raise ValueError(
            f"initial must have exactly 2 entries, got {len(state)}"
        )
    gen = np.random.default_rng(seed=seed)
    xs = np.empty(shape=(M + 1, 2))
    xs[0] = state
    for i in range(M):
        for j in range(2):
            # Coordinate j is refreshed given the *current* other coordinate,
            # so the second update already sees the first one's new value.
            state[j] = full_conditional(state[1 - j], mu, gen)
        xs[i + 1] = state
    return xs

def trace_plots(z):
    """Show the path of each chain in the plane, one panel per chain.

    Parameters
    ----------
    z : sequence of 2-D arrays. For every entry, column 0 is plotted on the
        horizontal axis and column 1 on the vertical axis, with consecutive
        rows joined by thin lines and marked with crosses.

    Raises
    ------
    ValueError
        If ``z`` is empty (there is nothing to lay out).
    """
    # len() rather than truthiness: z may be a NumPy array, whose truth value
    # is ambiguous.
    if len(z) == 0:
        raise ValueError("trace_plots needs at least one chain to plot")
    # squeeze=False keeps `axs` a 2-D array even when there is a single
    # panel; otherwise subplots returns a bare Axes that cannot be indexed.
    _, axs = plt.subplots(
        1, len(z), figsize=(12, 4), constrained_layout=True, squeeze=False
    )
    for ax, chain in zip(axs[0], z):
        ax.plot(chain[:, 0], chain[:, 1], 'bx-', linewidth=.2, markersize=4)
        ax.set_xlabel(r'$\theta_1$', fontsize=16)
        ax.set_ylabel(r'$\theta_2$', fontsize=16)
    plt.show()

# Script driver: run one chain with mu=1 and one with mu=3 (all other sampler
# settings left at their defaults) and display the two traces side by side.
trace_plots(
    [gibbs_sampling(mu=shift) for shift in (1, 3)]
)
