ノンパラメトリックベイズ(0)適応的棄却サンプリング
もうブームは去った気がするけど、ノンパラメトリックベイズの勉強をすることにした。統計学とかで、パラメトリックモデル/ノンパラメトリックモデルという概念があるけど、それとは関係ないように見える。infinite GMMを実装しようとしたら、適応的棄却サンプリング(ARS,Adaptive rejection sampling)を使うとか書いてあって、よく分からなかったので、そこから。
論文
adaptive rejection sampling for gibbs sampling(PDF)
http://www.math.chalmers.se/Stat/Grundutb/CTH/mve186/1415/adaptive.sampling.pdf
によると、元々はDevroyeという人の本に書いてある方法を改良したものらしい(?)。adaptive rejection samplingという名前は、この論文で命名されたのじゃないかと思うけど。
確率密度関数p(x)が解析的な式で与えられている状況で、普通の棄却サンプリングだと、
Rejection sampling
https://en.wikipedia.org/wiki/Rejection_sampling#Theory
提案分布(proposal distribution)は人間がいい感じに設定することになっているが、そんないい分布わかんねーよとかいうケースで、p(x)がlog-concave(文字通り、p(x)の対数が凹関数であるという条件)を満たす時に、p(x)の定義域上で、複数のxを選んで、各xにおいて、log p(x)の接線を引く。それらをつなぐと、log p(x)を上から近似する区分的線形関数が得られる(log-concaveである必要がここで生じる)ので、これのexponentialを取って正規化したものを提案分布とする、ということらしい(論文では、正規化してない関数を$u_k(x)$、正規化した確率密度関数を$s_k(x)$と書いている)。また、接線を計算するため、p(x)は一回以上微分可能である必要がある。逆関数法では、p(x)を積分する必要があるが、積分はできないけど、微分はできるというケースはよくある
区分的線形関数のexponentialを確率密度関数とするような乱数は、(少し面倒ではあるが)普通に逆関数法で生成できる。
逆関数法
https://ja.wikipedia.org/wiki/%E9%80%86%E9%96%A2%E6%95%B0%E6%B3%95
あと、棄却サンプリングで、提案分布に従う乱数xを発生させて、最終的に棄却された時、提案分布を作る時に選んだ"複数のx"に、これも追加して、提案分布を改善する(一番最初の点は人間が選ぶ必要がある。最低二点必要)。提案分布とp(x)の乖離が大きいxでは、棄却率が高く、これによって乖離の大きい部分が改善される可能性が高い。繰り返し行うと、提案分布の近似は、どんどんよくなっていくので"適応的"という名前が付いている。infinite GMMでは、毎回(Gibbsサンプリングの一ループ毎に)パラメータが変わって、提案分布も作りなおしになるけど
論文では、上から抑える関数とは別に下から抑える関数(論文でsqueezing functionと呼んでているもの)も作っていて、多分、元の確率密度関数を評価するよりも、区分線形関数のexponentialを評価する方が、一般的に計算量が少ないから、ということで、そうしているのだと思う。色々な改良やバリエーションがあるようだけど、区分線系関数のexponentialを使って、上(や下)から確率密度関数を抑える関数を作るという点は共通しているっぽい
※)論文2.2.2において、
が成立するので、squeezing testはスキップしても結果は変わらない
最近、頭が悪くなった気がするので、これくらいだと理解できて良い。
以下に、単純な実装を置いておく。棄却されるたびに、h(x)とh'(x)を計算し直すのは、無駄の極みだけど、わかりやすさ重視。例として、ガンマ分布(np.random.gammaで生成できるが)
""" This is an implementation of the following paper written in python 3.5 Adaptive Rejection Sampling for Gibbs Sampling https://www.jstor.org/stable/2347565?seq=1#page_scan_tab_contents """ import numpy as np def rnd_from_piecewise_exponential(coeffs , intercepts , points): """ density function h(x) in [points[i] , points[i+1]] h(x) = exp(coeffs[i]*x + intercepts[i]) """ assert(len(coeffs)==len(intercepts)),(coeffs,intercepts) assert(len(coeffs)+1==len(points)),(coeffs,points) u = np.random.random() CDF = [0.0] #--cumulative density value at each points for i in range(len(points)-1): z_cur = points[i] z_next = points[i+1] a = coeffs[i] b = intercepts[i] if a!=0.0: r = (np.exp(a*z_next+b) - np.exp(a*z_cur+b))/a else: r = (znext-z)*np.exp(b) r0 = CDF[-1] CDF.append( r + r0 ) u = CDF[-1]*u #--CDFが正規化されてないので for i in range(len(points)-1): if CDF[i] < u and u <=CDF[i+1]: a = coeffs[i] b = intercepts[i] z = points[i] """ compute t such that ( exp(a*t+b) - exp(a*z+b) )/a == u - CDF[i] """ if a!=0.0: t0 = a*(u - CDF[i]) + np.exp(a*z+b) assert(t0 > 0.0),t0 t = ( np.log(t0) - b )/a else: t = (u - CDF[i])/np.exp(b) + z assert(t >= z and t <= points[i+1]),(t,z,points[i+1]) return t def ars(h , hprime , points , support=(0 , np.inf)): xs = [x for x in points] while 1: xs.sort() x0 = xs[0] xN = xs[-1] zs = [support[0]] coeffs = [] intercepts = [] for i in range(len(xs)-1): x = xs[i] xnext = xs[i+1] zi = (h(xnext) - h(x) - xnext*hprime(xnext)+x*hprime(x))/(hprime(x) - hprime(xnext)) zs.append( zi ) coeffs.append( hprime(x) ) intercepts.append( h(x) - hprime(x)*x ) zs.append( support[1] ) coeffs.append( hprime(xN) ) intercepts.append( h(xN) - hprime(xN)*xN ) y = rnd_from_piecewise_exponential(coeffs , intercepts , zs) w = np.random.random() #-- omit squeezing test # lkval = 0.0 # for i in range(len(xs)-1): # if xs[i]<=y and y<=xs[i+1]: # xcur = xs[i] # xnext = xs[i+1] # lkval = ( (xnext - y)*h(xcur) + (y-xcur)*h(xnext) )/(xnext-xcur) # break # if w <= np.exp(lkval - ukval): # return y ukval = 0.0 for i in range(len(zs)-1): if zs[i]<=y and y<=zs[i+1]: ukval = h(xs[i]) + (y-xs[i])*hprime(xs[i]) break if w<=np.exp(h(y) - ukval): return y else: xs.append( y ) if __name__=="__main__": #-- gamma distribution alpha = 2.0 #-- shape parameter beta = 2.0 #-- rate parameter C = np.log(4.0) #-- log( beta^alpha/Gamma(alpha) ) def h(x): return ((alpha-1)*np.log(x) - beta*x + C) def hprime(x): return (alpha-1)/x - beta print( ars(h , hprime , [0.1 , 1.0 , 5.0]) )