#! /usr/bin/env python
## latent_factors_stan.py

import stan
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import orthogonal_procrustes

# Simulate a synthetic dataset from a k-factor model:
#   X_i ~ N(0, Omega),  Omega = Lambda Lambda^T + diag(Sigma)
gen = np.random.default_rng(seed=0)
n, p, k = 50, 8, 3  # observations, observed dimensions, latent factors

Lambda = 10 * gen.normal(size=(p, k))                 # p x k loading matrix
Sigma = 1.0 / gen.gamma(shape=1, scale=1, size=p)     # reciprocal-gamma noise variances
Omega = Lambda.dot(Lambda.T) + np.diag(Sigma)         # implied covariance of X
X = gen.multivariate_normal(mean=np.zeros(p), cov=Omega, size=n)

# Data dictionary handed to stan.build.
sm_data = dict(n=n, p=p, X=X, k=k, a=1, b=1)

# Compile the Stan program from latent_factors.stan (resolved relative to the
# current working directory) against the simulated data, with a fixed seed.
with open('latent_factors.stan', 'r', newline='') as stan_file:
    model_code = stan_file.read()
sm = stan.build(model_code, data=sm_data, random_seed=1)

# MCMC settings: number of chains, kept draws per chain, and warmup
# iterations per chain (warmup draws are discarded).
chains = 1
samples = 10000
burn = 1000
fit = sm.sample(
    num_chains=chains,
    num_samples=samples,
    num_warmup=burn,
    save_warmup=False,
)

# Perform Procrustes alignment of sampled Lambdas.
# Omega depends on Lambda only through Lambda Lambda^T, which is unchanged by
# Lambda -> Lambda R for any orthogonal R (assuming the Stan model mirrors the
# simulation above). Draws can therefore differ by such rotations, so each one
# is rotated onto a common reference draw (the last one) before averaging.
lam_draws = fit['Lambda']  # draws on the last axis; extracted once, not per iteration
# Total number of draws actually returned. `samples` is the per-chain count
# and would undercount (and mis-normalise) when chains > 1.
n_draws = lam_draws.shape[-1]
lam_ref = lam_draws[:, :, -1]
# Copy: the in-place accumulation below must not modify the reference draw.
lam_hat = lam_ref.copy()
for i in range(n_draws - 1):
    draw = lam_draws[:, :, i]
    # R is the orthogonal matrix minimising ||draw @ R - lam_ref||_F.
    R = orthogonal_procrustes(draw, lam_ref)[0]
    lam_hat += draw.dot(R)

lam_hat /= n_draws
# Rotate the aligned average onto the true Lambda so it can be compared visually.
lam_hat = lam_hat.dot(orthogonal_procrustes(lam_hat, Lambda)[0])

# Crude estimate: element-wise mean of the unaligned draws, rotated onto the
# true Lambda in the same way.
lam_bar = np.mean(lam_draws, axis=2)
lam_bar = lam_bar.dot(orthogonal_procrustes(lam_bar, Lambda)[0])

# Plot estimate and true values for Lambda.
# All panels share one colour scale. With imshow's default per-image
# autoscaling, the same colour would mean different values in each panel,
# which makes the side-by-side comparison misleading.
panels = [
    (lam_hat, r'$\hat{\Lambda}$'),
    (Lambda, r'$\Lambda$'),
    (lam_bar, 'Crude estimate '+r'$\bar{\Lambda}$'),
]
vmin = min(mat.min() for mat, _ in panels)
vmax = max(mat.max() for mat, _ in panels)

fig, axs = plt.subplots(1, 3, figsize=(7, 4), constrained_layout=True)
fig.canvas.manager.set_window_title('Latent factor lambdas')
for ax, (mat, title) in zip(axs, panels):
    image = ax.imshow(mat, cmap='Blues', vmin=vmin, vmax=vmax)
    ax.set_title(title)
    ax.set(xticks=[], yticks=[])
# One colorbar suffices because every panel uses the same vmin/vmax.
fig.colorbar(image, ax=axs, shrink=0.6)
plt.show()
