#! /usr/bin/env python
# polya_tree_sample.py

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

def polya_tree_sample(a, F_0_inv=norm.ppf, m=9, seed=0):
    """Draw one random density from a Polya tree prior.

    The real line is split into ``2**m`` cells of equal probability under
    the base distribution F_0; the cell boundaries are ``F_0_inv(k / 2**m)``
    for ``k = 1 .. 2**m - 1``. At each of the ``m`` levels of the dyadic
    tree, every node splits its probability mass between its two children
    as ``(B, 1 - B)`` with ``B ~ Beta(a * j**2, a * j**2)``. Here ``j``
    (1 .. m) is the depth of the children.

    Parameters
    ----------
    a : float
        Concentration parameter; must be positive. Larger values push every
        split towards 1/2, keeping the sample closer to the base density.
    F_0_inv : callable, optional
        Inverse CDF (quantile function) of the base distribution. It is
        called once per scalar float, so it need not be vectorized.
        Defaults to the standard normal ``norm.ppf``.
    m : int, optional
        Depth of the tree; the partition has ``2**m`` cells.
    seed : optional
        Anything accepted by ``numpy.random.default_rng``.

    Returns
    -------
    midpoints : ndarray, shape (2**m - 2,)
        Midpoints of the finite cells.
    density : ndarray, shape (2**m - 2,)
        Sampled probability of each finite cell divided by its width. The
        two unbounded tail cells are omitted.
    """
    gen = np.random.default_rng(seed=seed)
    n_cells = 2**m
    p = np.ones(n_cells)
    for j in range(m, 0, -1):
        # Beta parameters growing like j**2 make deeper splits ever closer
        # to 1/2; this is the usual choice for absolutely continuous samples.
        a_j = a * j**2
        # One split fraction per node at depth j - 1, ordered left to right.
        b = gen.beta(a_j, a_j, size=2**(j - 1))
        # Interleave as [b0, 1-b0, b1, 1-b1, ...]. Then stretch every entry
        # over the 2**(m-j) leaf cells below the corresponding child.
        factors = np.column_stack((b, 1 - b)).ravel()
        p *= np.repeat(factors, 2**(m - j))

    # Finite cell boundaries: the base quantiles k / 2**m, k = 1 .. 2**m - 1.
    x = np.array([F_0_inv(k / n_cells) for k in range(1, n_cells)])
    return (x[1:] + x[:-1]) / 2, p[1:-1] / np.diff(x)

def main():
    """Plot one Polya tree sample for each of several concentrations.

    All samples use the default seed, so each one starts from the same RNG
    state and the curves differ mainly through ``a``.
    """
    for a in [10, 10**3, 10**5]:
        plt.plot(*polya_tree_sample(a), label=rf"$\alpha=${a}")
    plt.legend(loc="upper right")
    plt.autoscale(enable=True, axis="x", tight=True)
    # canvas.manager may be None (e.g. canvases not created through pyplot),
    # so only set the window title when a manager exists.
    manager = plt.gcf().canvas.manager
    if manager is not None:
        manager.set_window_title("Polya tree samples")
    plt.show()


# Only plot when executed as a script, not when imported for reuse.
if __name__ == "__main__":
    main()
