#! /usr/bin/env python
## logistic_regression_stan.py

import stan
import numpy as np
import matplotlib.pyplot as plt

# Simulate data: Bernoulli responses with a logistic link, P(y=1|x) = sigmoid(beta*x)
gen = np.random.default_rng(seed=0)
n = 25          # number of observations
m = 50          # number of grid points at which the regression function is evaluated
T = 5           # half-width of the covariate range: x and grid span [-T, T]
x = np.linspace(start=-T, stop=T, num=n)
grid = np.linspace(start=-T, stop=T, num=m)
beta = 0.5      # true regression coefficient
# Draw one Bernoulli response per covariate value (per-element calls keep the
# RNG stream identical to the original formulation)
y = [gen.binomial(1, 1.0 / (1.0 + np.exp(-beta * x_i))) for x_i in x]
# Data dict for the Stan model: n/p = design dimensions, a/b = prior
# hyperparameters (presumably for beta's prior — confirm against the .stan file),
# X/grid reshaped to column matrices as the model expects 2-D arrays.
sm_data = {'n': n, 'p': 1, 'a': 1, 'b': 0.5,
           'X': x.reshape((n, 1)), 'y': y,
           'm': m, 'grid': grid.reshape((m, 1))}

# Compile the Stan program and condition it on the simulated data.
# newline='' preserves the file's line endings exactly as written.
with open('logistic_regression.stan', 'r', newline='') as model_file:
    model_code = model_file.read()
sm = stan.build(model_code, sm_data, random_seed=1)

# MCMC configuration: number of chains, post-warmup draws, and warmup iterations
chains = 2
samples = 10000
burn = 1000
# Draw posterior samples; warmup draws are discarded from the returned fit
fit = sm.sample(num_chains=chains,
                num_samples=samples,
                num_warmup=burn,
                save_warmup=False)

# Plot the posterior mean regression function (left) and the approximate
# posterior density of beta (right).
fig, axs = plt.subplots(1, 2, figsize=(10, 4), constrained_layout=True)
fig.canvas.manager.set_window_title('Logistic regression posterior')
# Posterior mean of the regression function at each grid point:
# fit['fn_vals'] has one row per grid point and one column per draw.
f = np.mean(fit['fn_vals'], axis=1)
# True regression function on the grid, vectorized (was a per-element loop)
true_f = 1.0 / (1.0 + np.exp(-beta * grid))
# Posterior draws of the (single) regression coefficient
b = fit['beta'][0]
axs[0].plot(grid, f)
axs[0].plot(grid, true_f, color='c', lw=2, linestyle='--')
axs[0].scatter(x, y, color='black')
axs[0].set_title('Posterior mean regression function')
axs[0].set_xlabel(r'$x$')
# Histogram approximation to the posterior density of beta
# (return value was previously bound to an unused variable `h`)
axs[1].hist(b, 200, density=True)
axs[1].axvline(beta, color='c', lw=2, linestyle='--')
axs[1].set_title('Approximate posterior density of ' + r'$\beta$')
axs[1].set_xlabel(r'$\beta$')
plt.show()
