#! /usr/bin/env python
## linear_model_latent_factors_stan.py

from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import stan

# Simulate data from the marginal model
#     y ~ N(0, X X^T + Z Z^T + I_n),
# which is equivalent to y = X beta + Z gamma + eps with beta ~ N(0, I_p),
# gamma ~ N(0, I_q) and eps ~ N(0, I_n). The noise variance is fixed at 1
# (the identity term) rather than drawn at random.
# X (n x p) is observed and passed to Stan; Z (n x q) holds the latent
# factors and is deliberately withheld from the model data.
gen = np.random.default_rng(seed=0)
n = 50  # number of observations
p = 8   # number of observed covariates (columns of X)
q = 3   # number of latent factors (columns of Z)
u = v = 1  # NOTE(review): passed to Stan, presumably prior scales - confirm in the .stan file
# Drawing with size=(rows, cols) yields the same values, in the same order,
# as drawing rows*cols values and reshaping them.
Z = 10 * gen.normal(size=(n, q))
X = 10 * gen.normal(size=(n, p))
Sigma = X.dot(X.T) + Z.dot(Z.T) + np.identity(n)
y = gen.multivariate_normal(np.zeros(n), Sigma)

# Data passed to the Stan program; keys must match its data declarations.
# NOTE(review): 'a' and 'b' look like prior hyperparameters - confirm in the .stan file.
sm_data = {'n': n, 'p': p, 'X': X, 'y': y, 'q': q, 'u': u, 'v': v, 'a': 1, 'b': 1}

# Initialise stan object.
# Resolve the Stan program relative to this script so the example works no
# matter which directory it is launched from; fall back to the current
# working directory when __file__ is not defined (e.g. interactive sessions).
script_dir = Path(__file__).resolve().parent if '__file__' in globals() else Path.cwd()
stan_path = script_dir / 'linear_model_latent_factors.stan'
# Decode explicitly as UTF-8 instead of the platform's locale default;
# newline='' passes the file's line endings through untranslated.
with open(stan_path, 'r', encoding='utf-8', newline='') as f:
    sm = stan.build(f.read(), data=sm_data, random_seed=1)

# MCMC settings: one chain, 1000 warm-up iterations and 1000 retained draws.
# Warm-up iterations are not stored in the fit (save_warmup=False).
# NOTE(review): a single chain gives no between-chain convergence check
# (e.g. R-hat); consider running several chains.
chains = 1
samples = 1000
burn = 1000
fit = sm.sample(
    num_chains=chains,
    num_samples=samples,
    num_warmup=burn,
    save_warmup=False,
)

# Display the resulting fit object.
print(fit)
