#! /usr/bin/env python
## spline_regression_simulate.py
from spline_regression import model_average
import numpy as np
import matplotlib.pyplot as plt
import sys

def reg_fn(t):
    """Return the "true" regression function 10 + 5*sin(t) + t**2/5.

    The script uses it as the mean of the simulated observations and as the
    dashed reference curve in the plots. Built only from elementwise
    operations, so ``t`` may be a scalar or a NumPy array.
    """
    periodic = 5 * np.sin(t)
    quadratic = t**2 / 5.0
    return 10 + periodic + quadratic

def spline_fn(t, tau, beta, d=1):
    """Evaluate a degree-``d`` spline in truncated-power form at the point ``t``.

    The value is
        beta[0] + sum_{k=1..d} beta[k] * t**k
                + sum_j beta[d+1+j] * (t - tau[j])**d   for each knot with t > tau[j]
    so ``beta`` must hold at least ``d + 1 + len(tau)`` coefficients.
    Knot terms switch on strictly after the knot, i.e. when ``t > tau[j]``.

    Args:
        t: scalar evaluation point. An array would make the ``t > knot``
            test ambiguous.
        tau: sequence of knot locations.
        beta: coefficient sequence, in the layout described above.
        d: polynomial degree (default 1).

    Returns:
        The spline value at ``t``.
    """
    # Global polynomial part: intercept plus monomials t**1 .. t**d.
    powers = [t**k for k in range(1, d + 1)]
    result = beta[0] + np.dot(beta[1:d + 1], powers)
    # Truncated-power terms: one coefficient per knot, after the polynomial ones.
    for offset, knot in enumerate(tau):
        contribution = beta[d + 1 + offset] * (t - knot)**d if t > knot else 0
        result += contribution
    return result

def _parse_degree(argv, default=3):
    """Return the spline degree given as ``argv[1]``, or ``default`` if absent.

    Exits with a usage message (status 1) when the argument is not a
    non-negative integer. A negative degree would otherwise be passed on
    silently: it breaks the coefficient slicing ``beta[1:(d+1)]`` and yields
    negative powers in ``spline_fn``, and it also reaches ``model_average``.
    """
    if len(argv) < 2:
        return default
    try:
        degree = int(argv[1])
    except ValueError:
        degree = None
    if degree is None or degree < 0:
        sys.exit('usage: ' + argv[0] + ' [degree]  (degree must be a '
                 'non-negative integer, got ' + repr(argv[1]) + ')')
    return degree


gen = np.random.default_rng(seed=0)
n = 10  # number of observations
d = _parse_degree(sys.argv)  # degree of splines (default 3)
T = 10  # size of function domain
x = np.linspace(start=0, stop=T, num=n)
# Noisy observations: one draw per design point, centred on the true
# regression function (Generator.normal's default scale is 1.0).
y = [gen.normal(loc=reg_fn(x_i)) for x_i in x]
grid = np.linspace(start=0, stop=T, num=50)

# NOTE(review): the literal 40 is passed straight through to model_average.
# Its meaning is defined in spline_regression; confirm it there.
# Only len(grid) is passed, not the grid itself. ave_f is plotted against
# `grid` below, so model_average presumably evaluates on the same 50-point
# linspace over [0, T]; verify against its implementation.
ave_f, max_m, max_mn, max_tau, pm = model_average(x, y, 40, len(grid), T, d)

fig, axs = plt.subplots(1, 3, figsize=(12, 4), constrained_layout=True)
fig.canvas.manager.set_window_title('Spline regression posterior')
true_f = [reg_fn(x_i) for x_i in grid]
# The outer panels share the reference curve (dashed) and the raw data.
for plt_ind in [0, 2]:
    axs[plt_ind].plot(grid, true_f, color='c', lw=2, linestyle='--')
    axs[plt_ind].scatter(x, y, color='black')
    axs[plt_ind].set_xlabel(r'$x$')
# Left: model-averaged regression function returned by model_average.
axs[0].plot(grid, ave_f)
axs[0].set_title('Posterior mean regression function')
# Middle: posterior probability of each model index m.
axs[1].bar(range(len(pm)), pm)
axs[1].set_xlabel(r'$m$')
axs[1].set_title('Posterior ' + r'$p(m\vert x,y)$')
# Right: the spline built from the knots/coefficients returned for max_m.
axs[2].plot(grid, [spline_fn(t_i, max_tau, max_mn, d) for t_i in grid])
axs[2].set_title('Mean regression function for ' + r'$m=$' + str(max_m))
plt.show()
