迷路ソルバー

ふと、A*-アルゴリズムって実装したことなかったな、と思ったので実装してみただけ。A*-アルゴリズムが正しいことは自明でないと思うのだけど、意外とどこにも証明とかなくて、結局、原論文を読めばいいという結論に至った。

"A Formal Basis for the Heuristic Determination of Minimum Cost Paths"
http://fai.cs.uni-saarland.de/teaching/winter12-13/heuristic-search-material/Astar.pdf

#!/usr/bin/env python

#--breadth first
def route_bfs(M, start, end):
    """Find a shortest route through the maze with breadth-first search.

    Args:
        M: the maze as a sequence of rows (strings or sequences of
            characters); a cell holding ``'*'`` is a wall, anything else
            is walkable.
        start: ``(x, y)`` tuple of the starting cell (x = column, y = row).
        end: ``(x, y)`` tuple of the goal cell.

    Returns:
        The list of ``(x, y)`` cells from ``start`` to ``end`` inclusive,
        or ``[]`` when ``end`` cannot be reached.
    """
    from collections import deque

    def is_open(x, y):
        # Cells outside the grid (including short rows) count as walls.
        return 0 <= y < len(M) and 0 <= x < len(M[y]) and M[y][x] != '*'

    def neighbors(pos):
        x, y = pos
        # Order matters for tie-breaking: left, right, up, down.
        for cell in ((x - 1, y), (x + 1, y), (x, y - 1), (x, y + 1)):
            if is_open(*cell):
                yield cell

    # parent doubles as the visited set, so every cell is enqueued at most
    # once instead of once per simple path leading to it.
    parent = {start: None}
    queue = deque([start])
    while queue:
        current = queue.popleft()
        if current == end:
            path = []
            while current is not None:
                path.append(current)
                current = parent[current]
            path.reverse()
            return path
        for cell in neighbors(current):
            if cell not in parent:
                parent[cell] = current
                queue.append(cell)
    return []



#-- depth first
def route_dfs(M, start, end):
    """Find a shortest route by exhaustive depth-first search.

    Every simple (non self-intersecting) path from ``start`` is explored
    recursively and the shortest one reaching ``end`` is kept; among
    equally short candidates the one found first (neighbour order: left,
    right, up, down) wins.  The running time grows with the number of
    simple paths, so this is only practical on small or narrow mazes.

    Args:
        M: the maze as a sequence of rows; ``'*'`` marks a wall.
        start: ``(x, y)`` tuple of the starting cell (x = column, y = row).
        end: ``(x, y)`` tuple of the goal cell.

    Returns:
        The list of ``(x, y)`` cells from ``start`` to ``end`` inclusive,
        or ``[]`` when ``end`` cannot be reached.
    """
    def is_open(x, y):
        # Cells outside the grid (including short rows) count as walls.
        return 0 <= y < len(M) and 0 <= x < len(M[y]) and M[y][x] != '*'

    def neighbors(pos):
        x, y = pos
        for cell in ((x - 1, y), (x + 1, y), (x, y - 1), (x, y + 1)):
            if is_open(*cell):
                yield cell

    def search(node, on_path):
        """Return the shortest route node -> end avoiding ``on_path``."""
        if node == end:
            return [node]
        on_path.add(node)
        best = None
        for cell in neighbors(node):
            if cell in on_path:
                continue
            tail = search(cell, on_path)
            # Strict '<' keeps the first of several equally short tails.
            if tail and (best is None or len(tail) < len(best)):
                best = tail
        # Backtrack so sibling branches may pass through this cell again.
        on_path.remove(node)
        if best is None:
            return []
        return [node] + best

    return search(start, set())


#-- Dijkstra
def route_dijk(M, start, end):
    """Find a shortest route through the maze with Dijkstra's algorithm.

    Every step between adjacent walkable cells costs 1.  Cells are settled
    in order of increasing distance from ``start`` using a binary heap, and
    the search stops as soon as ``end`` is settled.

    Args:
        M: the maze as a sequence of rows; ``'*'`` marks a wall.
        start: ``(x, y)`` tuple of the starting cell (x = column, y = row).
        end: ``(x, y)`` tuple of the goal cell.

    Returns:
        The list of ``(x, y)`` cells from ``start`` to ``end`` inclusive,
        or ``[]`` when ``end`` cannot be reached.
    """
    import heapq

    def is_open(x, y):
        # Cells outside the grid (including short rows) count as walls.
        return 0 <= y < len(M) and 0 <= x < len(M[y]) and M[y][x] != '*'

    def neighbors(pos):
        x, y = pos
        for cell in ((x - 1, y), (x + 1, y), (x, y - 1), (x, y + 1)):
            if is_open(*cell):
                yield cell

    dist = {start: 0}
    parent = {start: None}
    settled = set()
    heap = [(0, start)]
    while heap:
        d, current = heapq.heappop(heap)
        if current in settled:
            continue  # stale heap entry; a shorter one was handled already
        settled.add(current)
        if current == end:
            break
        for cell in neighbors(current):
            candidate = d + 1
            # Strict improvement only: re-pushing on ties would repeat work
            # once per alternative shortest path.
            if cell not in dist or candidate < dist[cell]:
                dist[cell] = candidate
                parent[cell] = current
                heapq.heappush(heap, (candidate, cell))

    if end not in parent:
        return []
    path = []
    current = end
    while current is not None:
        path.append(current)
        current = parent[current]
    path.reverse()
    return path


#-- A*
def route_astar(M, start, end):
    """Find a shortest route through the maze with A* search.

    The heuristic is the Manhattan distance to ``end`` and every step
    between adjacent walkable cells costs 1.  Open cells live in a binary
    heap ordered by ``f = g + h``; a cell is expanded again only if a
    strictly cheaper way to reach it is found later.

    Args:
        M: the maze as a sequence of rows; ``'*'`` marks a wall.
        start: ``(x, y)`` tuple of the starting cell (x = column, y = row).
        end: ``(x, y)`` tuple of the goal cell.

    Returns:
        The list of ``(x, y)`` cells from ``start`` to ``end`` inclusive,
        or ``[]`` when ``end`` cannot be reached.
    """
    import heapq

    def h(pos):    #--heuristic
        return abs(pos[0] - end[0]) + abs(pos[1] - end[1])

    def is_open(x, y):
        # Cells outside the grid (including short rows) count as walls.
        return 0 <= y < len(M) and 0 <= x < len(M[y]) and M[y][x] != '*'

    def neighbors(pos):
        x, y = pos
        for cell in ((x - 1, y), (x + 1, y), (x, y - 1), (x, y + 1)):
            if is_open(*cell):
                yield cell

    g = {start: 0}      # cheapest known cost from start to each cell
    parent = {}
    # Heap entries are (f, h, cell); on equal f the cell nearer the goal
    # is expanded first.
    heap = [(h(start), h(start), start)]
    while heap:
        f, hval, node = heapq.heappop(heap)
        if f - hval > g[node]:
            continue  # stale entry: the cell was re-queued with a lower cost
        if node == end:
            path = [node]
            while node != start:
                node = parent[node]
                path.append(node)
            path.reverse()
            return path
        candidate = g[node] + 1
        for cell in neighbors(node):
            if cell not in g or candidate < g[cell]:
                g[cell] = candidate
                parent[cell] = node
                hcell = h(cell)
                heapq.heappush(heap, (candidate + hcell, hcell, cell))
    return []


import time
if __name__=="__main__":
    # Benchmark driver: runs each solver once on three mazes and prints the
    # wall-clock time per solver.  default_timer exists on both Python 2
    # and 3, whereas time.clock was removed in Python 3.8.
    from timeit import default_timer

    def benchmark(maze, start, goal):
        """Time every registered solver once on ``maze`` and print it."""
        for name, route in algorithms:
            st = default_timer()
            route(maze, start, goal)
            et = default_timer()
            print("%s: %1.4f(sec)" % (name, et - st))

    M = """*S*****************************
*   *                   *     *
*** ******* ***** * *** *** * *
*         * *   * *   *     * *
***** *** * *** * *** *** *****
*   * *   *     *       *     *
*** * *** * * *** * * ******* *
*     *   * *   * * *   *     *
*** ***** *** * * * *** * *****
*     *         * * *   *   * *
* *********** ********* ***** *
* * *   *   * *   *     *     *
* *** * * *** * *** ***** *** *
*     *     *   *     *     * *
* ********* ********* * ***** *
* *   *     *         * *     *
* * ******* *** ******* *** ***
* * * *   * *   * *     * *   *
*** * * *** * *** * * *** * * *
*           *       *     * * *
*****************************G*"""
    M = M.split('\n')
    algorithms = [("bfs",route_bfs),("dfs",route_dfs),("Dijkstra",route_dijk),("A*",route_astar)]

    # 1) Walled maze, from the 'S' cell on the top row to the 'G' cell on
    #    the bottom row.
    benchmark(M, (1,0), (29,20))
    print("------------------")

    # 2) A single open corridor, 100 cells wide.
    M = [" "*100]
    benchmark(M, (1,0), (80,0))
    print("-------------------")

    # 3) A fully open 6x6 grid.
    M = [" "*6 for _ in range(6)]
    benchmark(M, (0,0), (4,4))