#! /usr/bin/env python
## linear_model_latent_factors_stan.py

from pathlib import Path

import cmdstanpy
import numpy as np
import matplotlib.pyplot as plt

# Simulate data.
# One seeded Generator drives every draw below, so the data set is
# reproducible; the draw order (Z, then X, then y) must not change.
gen = np.random.default_rng(seed=0)
n = 50     # number of observations (rows of X, Z and length of y)
p = 8      # number of columns of X
q = 3      # number of columns of Z (presumably the latent-factor count; confirm in the .stan file)
u = v = 1  # passed to Stan unchanged; their meaning is defined by the .stan model

# Drawing with size=(rows, cols) fills the array in C order, giving exactly the
# same values as drawing a flat vector of rows*cols and reshaping it.
Z = 10 * gen.normal(size=(n, q))
X = 10 * gen.normal(size=(n, p))

# Covariance of y: X X^T + Z Z^T + I.  X X^T + Z Z^T has rank at most p + q,
# so Sigma has (at least) n - p - q eigenvalues equal to 1.  The default
# SVD-based multivariate_normal draw can be sensitive to last-bit changes in
# Sigma within that repeated eigenspace, so keep this expression as-is to
# reproduce the same y.
Sigma = X.dot(X.T) + Z.dot(Z.T) + np.identity(n)
y = gen.multivariate_normal(np.zeros(n), Sigma)

# Data passed to the Stan model.  Z is not included: only X and y are
# observed by the sampler (presumably Z plays the role of latent factors the
# model infers — confirm against the .stan file).
sm_data = {
    'n': n,
    'p': p,
    'X': X,
    'y': y,
    'q': q,
    'u': u,
    'v': v,
    'a': 1,
    'b': 1,
}

# Locate and compile the Stan model.
# Prefer the .stan file that sits next to this script, so the script works
# regardless of the current working directory.  Fall back to the bare,
# cwd-relative name (the original behaviour) when __file__ is undefined
# (e.g. code pasted into an interactive session) or the file is not there.
STAN_FILE_NAME = 'linear_model_latent_factors.stan'
try:
    stan_file = Path(__file__).resolve().parent / STAN_FILE_NAME
except NameError:
    stan_file = Path(STAN_FILE_NAME)
if not stan_file.exists():
    stan_file = Path(STAN_FILE_NAME)
sm = cmdstanpy.CmdStanModel(stan_file=str(stan_file))


# MCMC settings: number of chains, post-warmup draws per chain, warmup draws.
chains, samples, burn = 1, 1000, 1000
fit = sm.sample(
    data=sm_data,
    chains=chains,
    iter_sampling=samples,
    iter_warmup=burn,
    save_warmup=False,  # warmup draws are not stored in the output
    seed=1,             # fixed sampler seed for reproducible draws
)

print(fit)
