文法推論(1)Regular Positive Negative Inference

まじめに文法推論を勉強していこうと思う。
Grammatical Inference: Learning Automata and Grammars
http://www.amazon.co.jp/Grammatical-Inference-Learning-Automata-Grammars/dp/0521763169
という本が出ていたりする


文法推論は、正規表現とか文脈自由文法とか、様々な確率文法とかHMMとかを学習する問題だと思う。少なくとも、1960年代には、そういうことを調べる問題意識はあったらしい。どーでもいいけど、HMMについて書いた入門的な記事とか読むと、Baum-WelchとかViterbiとかは説明してあるけど、そもそも、どうやってモデルを作るのか説明してあるのを見たことがないのだけど、世間の人は、それで一体どうしてるのだろう。重要な問題だと思うのだけど、結構マイナーなテーマの気もする


まずは、とりあえず一番簡単であるはずの正規表現の推論。libalfとかいう、C/C++で書かれたDFA/NFA学習アルゴリズムを実装したフレームワークがあるらしい。そこには、アルゴリズムが9個くらい載っている。多い
About the libalf library
http://libalf.informatik.rwth-aachen.de/index.php?page=about


あと、Matlab/Octave向けに実装されたgitoolboxというのもある。
gitoolbox
http://code.google.com/p/gitoolbox/wiki/Documentation?tm=6

というわけで、自分で書く必要は全くないけど、勉強用ということでpythonで書く。


正規表現は、正例だけでは"学習できない"ことが知られているらしい(極限同定可能とかいう基準によるのだと思う)。証明とか知らないけど、適当な正規表現にマッチする例だけ全部与えても、それを受理するDFAは一般に無数にあるから?一応、正例だけで推論しようという試みとして、k-testable languageという正規表現のサブセットがあるらしいけど、実用性は全然ない気がする。libalfにあるのは、正例と負例の両方から学習するもので、RPNI(Regular Positive Negative Inference)というのが、よく見るので、これを実装して見る。論文は1992年なので、意外と新しい
IDENTIFYING REGULAR LANGUAGES IN POLYNOMIAL TIME
http://grfia.dlsi.ua.es/repositori/grfia/pubs/76/asspr1992.pdf


アルゴリズムは単純で、最初に正例を丁度受理集合とするDFAをPTA(Prefix-Tree Automata)として作って、隣接する状態をマージ(マージするとNFAになるので、適当にDFA化)して、それが負例を受理しなければ、マージしたDFAを元に、同じ処理を続けていく。負例を受理してしまった場合は、一つ前に戻って別のマージを試すetc.という具合で処理は進んでいく。初期状態として、PTAを選ぶ必然性はなくて、正例を受理し、負例を受理しない任意のDFAから始められるのだと思う。PTAの状態集合は、正例のprefixの集合と1:1に対応するので、以下の実装では、それを利用している

#!/usr/bin/env python
# -*- coding:utf-8 -*-

def RPNI(positives , negatives):
   assert(len(set(positives) & set(negatives))==0),"正例と負例に重複がある"
   def merge(nodes , symbols):
       nodes.sort(key=lambda xs:len(min(xs)))
       for xs in nodes:
           for ys in nodes:
               if xs==ys:continue
               if len( set(xs) & set([s[:-1] for s in ys]) )==0:
                   continue   #-- xs nodeからys nodeへの遷移がない
               nfa_nodes = [xs+ys]
               #print ("trial:merge %s and %s" % (str(xs),str(ys)))
               #-- mergeしなかったnodeを追加
               for zs in nodes:
                  if xs!=zs and ys!=zs:nfa_nodes.append(zs)
               yield nfa2dfa(nfa_nodes , symbols)
   def nfa2dfa(nodes , symbols):
       current_nodes = nodes
       while True:
           for sym in symbols:
               for xs in current_nodes:
                   testset = set([s+sym for s in xs])
                   adj_nodes = []
                   for ys in current_nodes:
                       if xs!=ys and len(testset & set(ys)):
                          adj_nodes.append(ys)
                   if len(testset & set(xs))>0 and len(adj_nodes)>0:
                      #print("nfa2dfa:merged %s and %s" % (str(xs) , str(adj_nodes)))
                      next_nodes = [xs+sum(adj_nodes, [])]
                      for zs in current_nodes:
                          if zs==xs or (zs in adj_nodes):continue
                          next_nodes.append(zs)
                      break
                   elif len(testset & set(xs))==0 and len(adj_nodes)>1:
                      #print("nfa2dfa:merged %s" % str(adj_nodes))
                      next_nodes = [sum(adj_nodes, [])]
                      for zs in current_nodes:
                          if zs in adj_nodes:continue
                          next_nodes.append( zs )
                      break
               else:
                   continue
               break
           else:
               return current_nodes
           current_nodes = next_nodes
   def gendfa(nodes , symbols):
       acc_st = []
       init_st = 0
       tbl = {}   #-- 遷移表
       for n,xs in enumerate(nodes):
          if '' in xs:init_st = n
          if len(set(xs) & set(positives))>0:acc_st.append(n)
          for sym in symbols:
              testset = set([s+sym for s in xs])
              for m,ys in enumerate(nodes):
                  if len(testset & set(ys))>0:
                     tbl[(n,sym)] = m
                     break
       def __test__(s):
           cur_st = init_st
           for c in s:
               cur_st = tbl.get((cur_st,c) , None)
               if cur_st is None:return False
           return (cur_st in acc_st)
       return __test__
   prefixes = list(set(sum([[s[:c] for c in xrange(len(s)+1)] for s in positives] , [])))
   chars = set([])
   for p in prefixes:
       for q in prefixes:
            if len(q)==0:continue
            if p==q[:-1]:chars.add(q[-1])
   current_nodes = [[n] for n in prefixes]
   while True:
       for next_nodes in merge(current_nodes , chars):
           #-- next_nodesが負例を受理しないかどうかのチェック
           for node in next_nodes:
               if len(set(node) & set(positives))==0:continue #--nodeは受理状態でない
               if len(set(node) & set(negatives))>0:
                  #print("trial failed:%s accepted\n" % (str(set(node) & set(negatives))))
                  break
           else:
               current_nodes = next_nodes
               break
           continue
       else:   #-- 適切なnext_nodesが見つからなかったらここに来る
           #print("DFA found:%s" % str(current_nodes))
           return gendfa(current_nodes , chars)


def testRPNI(positives,negatives):
   f = RPNI(positives , negatives)
   for s in positives:
       if not f(s):return False
   for s in negatives:
       if f(s):return False
   return True



if __name__=="__main__":
   testRPNI(["00","000" , "011" , "101","1010"] , ["1","01","10","1000"])
   testRPNI(["","aaa","aaba","ababa","bb","bbaaa"] , ["aa","ab","aaaa","ba"])
   testRPNI(["xxx@hoge.com" , "xyx@fuga.com" , "xxy@hoge.com"] , ["xxy" , "xxx@","xxy@hoge" , "xxx@fuga"])