/python_modules/plearn/learners/modulelearners/sampler/inputweights.py

https://github.com/lisa-lab/PLearn · Python · 132 lines · 86 code · 24 blank · 22 comment · 23 complexity · 029e9c90ceb315f29095bbb1fd8fd73d MD5 · raw file

  1. from plearn.learners.modulelearners import *
# Default magnification for displayed weight images; view_inputweights()
# passes it to init_screen() and draw_normalized_image().
zoom_factor = 5
  3. from plearn.learners.modulelearners.sampler import *
  4. import sys, os.path
  5. def view_inputweights(learner, Nim, save_dir):
  6. save_image=False
  7. if save_dir<>None and len(save_dir)<>0:
  8. print "\nDo you want to save learner in the directory "+save_dir+"?"
  9. print "1.[default] No"
  10. print "2. Yes"
  11. c = pause()
  12. while c not in [0,1,2,EXITCODE]:
  13. c = pause()
  14. if c==2:
  15. save_image=True
  16. elif c==EXITCODE:
  17. return
  18. if save_image:
  19. print "\nChecking/creating directory "+save_dir+"\n"
  20. os.system('mkdir -p '+save_dir)
  21. inputweights_man()
  22. #
  23. # Getting the RBMmodule which sees the image (looking at size of the down layer)
  24. #
  25. modules=getModules(learner)
  26. for i in range(len(modules)):
  27. module = modules[i]
  28. if isModule(module,'RBM') and module.connection.down_size == Nim:
  29. image_RBM=learner.module.modules[i]
  30. break
  31. image_RBM_name=image_RBM.name
  32. zoom_factor = globals()['zoom_factor']
  33. if 'RBMMatrixConnection' in str(type(image_RBM.connection)):
  34. screen=init_screen(Nim,zoom_factor)
  35. for i in range(len(image_RBM.connection.weights)):
  36. weights=image_RBM.connection.weights[i]
  37. print str(i+1)+"/"+str(len(image_RBM.connection.weights))
  38. c = draw_normalized_image( weights, screen, zoom_factor )
  39. if save_image:
  40. fname = save_dir+'/filters-%05d.jpg' % i
  41. os.system('import -window "pygame window" ' + fname )
  42. if c==EXITCODE:
  43. return
  44. elif 'RBMMixedConnection' in str(type(image_RBM.connection)):
  45. N_filter = len(image_RBM.connection.sub_connections)
  46. N_inputim = len(image_RBM.connection.sub_connections[0])
  47. size_filter = image_RBM.connection.sub_connections[0][0].kernel.shape
  48. zoom_factor **= 2
  49. #
  50. # to complete.... see all the weights at the same time
  51. #
  52. # N=math.ceil(math.sqrt(N_filter))
  53. # print N
  54. # print (size_filter[0]*N_inputim+(N_inputim-1))*N_filter
  55. # print size_filter[1]*N_filter
  56. # print zoom_factor
  57. # return
  58. # screen=init_screen( (size_filter[0]*N_inputim*N , size_filter[1]*N) , zoom_factor)
  59. # for i in range(N_filter):
  60. # X = math.fmod(N_filter,i)
  61. # weights = image_RBM.connection.sub_connections[i][0].kernel
  62. # print str(i+1)+"/"+str(N_filter)
  63. # for j in range(1,N_inputim):
  64. # weights.resize( size_filter[0]*(j+1)*(i+1)+j*(i+1), size_filter[1]*(i+1) )
  65. # weights[size_filter[0]*j*i+1:]=image_RBM.connection.sub_connections[i][j].kernel
  66. # draw_normalized_image( weights, screen, zoom_factor )
  67. # return
  68. screen=init_screen( (size_filter[0]*N_inputim+(N_inputim-1) , size_filter[1]) , zoom_factor)
  69. for i in range(N_filter):
  70. weights = image_RBM.connection.sub_connections[i][0].kernel
  71. print str(i+1)+"/"+str(N_filter)
  72. for j in range(1,N_inputim):
  73. weights.resize( size_filter[0]*(j+1)+j, size_filter[1] )
  74. weights[size_filter[0]*j]=[0]*size_filter[1]
  75. weights[size_filter[0]*j+1:]=image_RBM.connection.sub_connections[i][j].kernel
  76. c = draw_normalized_image( weights, screen, zoom_factor )
  77. if save_image:
  78. fname = save_dir+'/filters-%05d.jpg' % i
  79. os.system('import -window "pygame window" ' + fname )
  80. if c==EXITCODE:
  81. return
  82. else:
  83. raise TypeError, "sampler::view_inputweights() not yet implemented for RBM connection of type "+str(type(image_RBM.connection))
  84. def inputweights_man():
  85. print "\nPlease type:"
  86. print ": <ENTER> : to continue Gibbs Sampling (same gibbs step)"
  87. print ": q : (quit) when you are fed up\n"
  88. print "Meaning of gray levels (g):"
  89. print "\tg = 127 <-> w = 0"
  90. print "\tg > 127 <-> w > 0"
  91. print "\tg < 127 <-> w < 0"
  92. print "\tg = 255 <-> w = +max{ -min(w), max(w) }"
  93. print "\tg = 0 <-> w = -max{ -min(w), max(w) }\n"
  94. if __name__ == "__main__":
  95. if len(sys.argv) < 2:
  96. print "Usage:\n\t" + sys.argv[0] + " <ModuleLearner_filename> <Image_size>\n"
  97. print "Purpose:\n\tSee weights of the RBM which sees image\n\t(i.e. with visible_size=Image_size)"
  98. inputweights_man()
  99. sys.exit()
  100. learner_filename = sys.argv[1]
  101. Nim = int(sys.argv[2])
  102. if os.path.isfile(learner_filename) == False:
  103. raise TypeError, "Cannot find file "+learner_filename
  104. print " loading... "+learner_filename
  105. learner = loadObject(learner_filename)
  106. if 'HyperLearner' in str(type(learner)):
  107. learner=learner.learner
  108. view_inputweights(learner, Nim)