/train_pyceptron.py

https://github.com/ricardokrieg/pyceptron
Python | 130 lines | 84 code | 24 blank | 22 comment | 16 complexity | 880fa92f4a3bf1e3fe4948388b104f12 MD5 | raw file
  1. #!/usr/bin/python
  2. #-*- coding: utf-8 -*-
  3. import sys
  4. import numpy
  5. from random import random
  6. import Tkinter as tkinter
  7. import threading
  8. import csv
  9. class TrainPyceptron(threading.Thread):
  10. def __init__(self):
  11. threading.Thread.__init__(self)
  12. self.w = None
  13. self.x = None
  14. self.n = 0.01
  15. self.epoch = 0
  16. self.max_epochs = 1000
  17. self.n_lines = 0
  18. self.n_ins = 0
  19. self.load_x('in.csv')
  20. self.init_w()
  21. print 'w:', self.w
  22. print 'x:', self.x
  23. print 'd:', self.d
  24. print 'n:', self.n
  25. # __init__
  26. def run(self):
  27. error = True
  28. while error:
  29. error = False
  30. print '============== epoca %02d ==============' % (self.epoch+1)
  31. for i, xi in enumerate(self.x):
  32. u = numpy.dot(xi, self.w)
  33. y = self.sinal(u)
  34. print self.w, u, y, self.d[i],
  35. if y != self.d[i]:
  36. error = True
  37. self.hebb(i, y)
  38. # if
  39. print error
  40. # for
  41. self.epoch += 1
  42. if self.epoch >= self.max_epochs:
  43. break
  44. # while
  45. f = open('pesos', 'w')
  46. w_str = str(self.w)[1:]
  47. w_str = w_str[:-1]
  48. w_str = w_str.replace(' ', ' ')
  49. w_str = w_str.replace(' ', ' ')
  50. f.write(w_str)
  51. f.close()
  52. if error:
  53. print 'ABORTED'
  54. print '+++++++++++++++'
  55. print 'Resultado:'
  56. print 'Epocas = %02d' % self.epoch
  57. print 'w = ', self.w
  58. print '+++++++++++++++'
  59. # run
  60. def degrau(self, u):
  61. return 1 if u >= 0 else 0
  62. # degrau
  63. def sinal(self, u):
  64. return 1.0 if u >= 0 else -1.0
  65. # sinal
  66. def hebb(self, i, y):
  67. self.w = self.w + self.n*(self.d[i] - y)*self.x[i]
  68. # hebb
  69. def init_w(self):
  70. self.w = numpy.zeros([self.n_ins-1])
  71. for i in range(0,self.n_ins-1):
  72. self.w.put([i], [round(random(),4)])
  73. # init_w
  74. def load_x(self, file_path):
  75. reader = csv.reader(open(file_path, 'r'), delimiter=';')
  76. for row in reader:
  77. self.n_lines += 1
  78. self.n_ins = 0
  79. for param in row:
  80. self.n_ins += 1
  81. # for
  82. reader = csv.reader(open(file_path, 'r'), delimiter=';')
  83. self.x = numpy.zeros([self.n_lines, self.n_ins-1])
  84. self.d = numpy.zeros([self.n_lines])
  85. current_line = 0
  86. for row in reader:
  87. for i, param in enumerate(row):
  88. if i == self.n_ins-1:
  89. self.d.put([current_line], [float(param)])
  90. else:
  91. self.x.put([current_line*self.n_ins+i-current_line], [float(param)])
  92. # for
  93. current_line += 1
  94. # for
  95. # load_x
  96. # TrainPyceptron
  97. def main():
  98. #window = tkinter.Tk()
  99. #window.title('Pyceptron - @ricardokrieg')
  100. #window.geometry('800x600')
  101. #window.configure(bg='black')
  102. train_pyceptron = TrainPyceptron()
  103. train_pyceptron.start()
  104. #window.mainloop();
  105. # main
  106. main()