/Neuroph/org/neuroph/core/NeuralNetwork.java

https://bitbucket.org/dusankrivosija/som · Java · 650 lines · 281 code · 76 blank · 293 comment · 23 complexity · 53eef372f5992d07ad53f503fa83e5f4 MD5 · raw file

  1. /**
  2. * Copyright 2010 Neuroph Project http://neuroph.sourceforge.net
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. package org.neuroph.core;
  17. import java.io.BufferedInputStream;
  18. import java.io.BufferedOutputStream;
  19. import java.io.File;
  20. import java.io.FileInputStream;
  21. import java.io.FileNotFoundException;
  22. import java.io.FileOutputStream;
  23. import java.io.IOException;
  24. import java.io.InputStream;
  25. import java.io.ObjectInputStream;
  26. import java.io.ObjectOutputStream;
  27. import java.io.Serializable;
  28. import java.util.HashMap;
  29. import java.util.Iterator;
  30. import java.util.Map;
  31. import java.util.Observable;
  32. import java.util.Random;
  33. import java.util.Vector;
  34. import org.neuroph.core.exceptions.VectorSizeMismatchException;
  35. import org.neuroph.core.learning.IterativeLearning;
  36. import org.neuroph.core.learning.LearningRule;
  37. import org.neuroph.core.learning.TrainingSet;
  38. import org.neuroph.util.NeuralNetworkType;
  39. import org.neuroph.util.VectorParser;
  40. import org.neuroph.util.plugins.LabelsPlugin;
  41. import org.neuroph.util.plugins.PluginBase;
  42. /**
  43. *<pre>
  44. * Base class for artificial neural networks. It provides generic structure and functionality
  45. * for the neural networks. Neural network contains a collection of neuron layers and learning rule.
  46. * Custom neural networks are created by deriving from this class, creating layers of interconnected network specific neurons,
  47. * and setting network specific learning rule.
  48. *</pre>
  49. *
  50. * @see Layer
  51. * @see LearningRule
  52. * @author Zoran Sevarac <sevarac@gmail.com>
  53. */
  54. public class NeuralNetwork extends Observable implements Runnable, Serializable {
  55. /**
  56. * The class fingerprint that is set to indicate serialization
  57. * compatibility with a previous version of the class.
  58. */
  59. private static final long serialVersionUID = 3L;
  60. /**
  61. * Network type id (see neuroph.util.NeuralNetworkType)
  62. */
  63. private NeuralNetworkType type;
  64. /**
  65. * Neural network
  66. */
  67. private Vector<Layer> layers;
  68. /**
  69. * Reference to network input neurons
  70. */
  71. private Vector<Neuron> inputNeurons;
  72. /**
  73. * Reference to newtwork output neurons
  74. */
  75. private Vector<Neuron> outputNeurons;
  76. /**
  77. * Learning rule for this network
  78. */
  79. private LearningRule learningRule; // learning algorithme
  80. /**
  81. * Separate thread for learning rule
  82. */
  83. private transient Thread learningThread; // thread for learning rule
  84. /**
  85. * Plugins collection
  86. */
  87. private Map<String, PluginBase> plugins;
  88. /**
  89. * Creates an instance of empty neural network.
  90. */
  91. public NeuralNetwork() {
  92. this.layers = new Vector<Layer>();
  93. this.plugins = new HashMap<String, PluginBase>();
  94. this.addPlugin(new LabelsPlugin());
  95. }
  96. /**
  97. * Adds layer to neural network
  98. *
  99. * @param layer
  100. * layer to add
  101. */
  102. public void addLayer(Layer layer) {
  103. layer.setParentNetwork(this);
  104. this.layers.add(layer);
  105. }
  106. /**
  107. * Adds layer to specified index position in network
  108. *
  109. * @param idx
  110. * index position to add layer
  111. * @param layer
  112. * layer to add
  113. */
  114. public void addLayer(int idx, Layer layer) {
  115. layer.setParentNetwork(this);
  116. this.layers.add(idx, layer);
  117. }
  118. /**
  119. * Removes specified layer from network
  120. *
  121. * @param layer
  122. * layer to remove
  123. */
  124. public void removeLayer(Layer layer) {
  125. this.layers.removeElement(layer);
  126. }
  127. /**
  128. * Removes layer at specified index position from net
  129. *
  130. * @param idx
  131. * int value represents index postion of layer which should be
  132. * removed
  133. */
  134. public void removeLayerAt(int idx) {
  135. this.layers.removeElementAt(idx);
  136. }
  137. /**
  138. * Returns interface for iterating layers
  139. *
  140. * @return iterator interface for network getLayersIterator
  141. */
  142. public Iterator<Layer> getLayersIterator() {
  143. return this.layers.iterator();
  144. }
  145. /**
  146. * Returns getLayersIterator Vector collection
  147. *
  148. * @return getLayersIterator Vector collection
  149. */
  150. public Vector<Layer> getLayers() {
  151. return this.layers;
  152. }
  153. /**
  154. * Returns layer at specified index
  155. *
  156. * @param idx
  157. * layer index position
  158. * @return layer at specified index position
  159. */
  160. public Layer getLayerAt(int idx) {
  161. return this.layers.elementAt(idx);
  162. }
  163. /**
  164. * Returns index position of the specified layer
  165. *
  166. * @param layer
  167. * requested Layer object
  168. * @return layer position index
  169. */
  170. public int indexOf(Layer layer) {
  171. return this.layers.indexOf(layer);
  172. }
  173. /**
  174. * Returns number of layers in network
  175. *
  176. * @return number of layes in net
  177. */
  178. public int getLayersCount() {
  179. return this.layers.size();
  180. }
  181. /**
  182. * Sets network input. Input Vector is collection of Double values.
  183. *
  184. * @param inputVector
  185. * network input vector
  186. */
  187. public void setInput(Vector<Double> inputVector) throws VectorSizeMismatchException {
  188. if (inputVector.size() != this.inputNeurons.size())
  189. throw new VectorSizeMismatchException("Input vector size does not match network input dimension!");
  190. Iterator<Double> inputIterator = inputVector.iterator();
  191. for(Neuron neuron : this.inputNeurons) {
  192. Double input = inputIterator.next(); // get input value
  193. neuron.setInput(input); // set input to the coresponding neuron
  194. }
  195. }
  196. /**
  197. * Sets network input. Input is array of double values.
  198. *
  199. * @param inputArray
  200. * network input as double array
  201. */
  202. public void setInput(double ... inputArray) throws VectorSizeMismatchException {
  203. if (inputArray.length != inputNeurons.size())
  204. throw new VectorSizeMismatchException("Input vector size does not match network input dimension!");
  205. setInput(VectorParser.convertToVector(inputArray));
  206. }
  207. /**
  208. * Returns network output Vector. Output Vector is a collection of Double
  209. * values.
  210. *
  211. * @return network output Vector
  212. */
  213. public Vector<Double> getOutput() {
  214. Vector<Double> outputVector = new Vector<Double>();
  215. for(Neuron neuron : this.outputNeurons) {
  216. double output = neuron.getOutput();
  217. outputVector.add(new Double(output));
  218. }
  219. return outputVector;
  220. }
  221. /**
  222. * Returns network output vector as double array
  223. *
  224. * @return network output vector as double array
  225. */
  226. public double[] getOutputAsArray() {
  227. return VectorParser.convertToArray(getOutput());
  228. }
  229. /**
  230. * Performs calculation on whole network
  231. */
  232. public void calculate() {
  233. for(Layer layer : this.layers) {
  234. layer.calculate();
  235. }
  236. }
  237. /**
  238. * Resets the activation levels for whole network
  239. */
  240. public void reset() {
  241. for(Layer layer : this.layers) {
  242. layer.reset();
  243. }
  244. }
  245. /**
  246. * Implementation of Runnable interface for calculating network in the
  247. * separate thread.
  248. */
  249. @Override
  250. public void run() {
  251. this.calculate();
  252. }
  253. /**
  254. * Trains the network to learn the specified training set.
  255. * This method is deprecated use learnInNewThread or learnInSameThread instead.
  256. * @param trainingSetToLearn
  257. * set of training elements to learn
  258. * @deprecated
  259. */
  260. public void learn(TrainingSet trainingSetToLearn) {
  261. learnInNewThread(trainingSetToLearn);
  262. }
  263. /**
  264. * Starts learning in a new thread to learn the specified training set,
  265. * and immediately returns from method to the current thread execution
  266. * @param trainingSetToLearn
  267. * set of training elements to learn
  268. */
  269. public void learnInNewThread(TrainingSet trainingSetToLearn) {
  270. learningRule.setTrainingSet(trainingSetToLearn);
  271. learningThread = new Thread(learningRule);
  272. learningThread.start();
  273. }
  274. /**
  275. * Starts learning with specified learning rule in new thread to learn the
  276. * specified training set, and immediately returns from method to the current thread execution
  277. * @param trainingSetToLearn
  278. * set of training elements to learn
  279. * @param learningRule
  280. * learning algorithm
  281. */
  282. public void learnInNewThread(TrainingSet trainingSetToLearn, LearningRule learningRule) {
  283. setLearningRule(learningRule);
  284. learningRule.setTrainingSet(trainingSetToLearn);
  285. learningThread = new Thread(learningRule);
  286. learningThread.start();
  287. }
  288. /**
  289. * Starts the learning in the current running thread to learn the specified
  290. * training set, and returns from method when network is done learning
  291. * @param trainingSetToLearn
  292. * set of training elements to learn
  293. */
  294. public void learnInSameThread(TrainingSet trainingSetToLearn) {
  295. learningRule.setTrainingSet(trainingSetToLearn);
  296. learningRule.run();
  297. }
  298. /**
  299. * Starts the learning with specified learning rule in the current running
  300. * thread to learn the specified training set, and returns from method when network is done learning
  301. * @param trainingSetToLearn
  302. * set of training elements to learn
  303. * @param learningRule
  304. * learning algorithm
  305. * *
  306. */
  307. public void learnInSameThread(TrainingSet trainingSetToLearn, LearningRule learningRule) {
  308. setLearningRule(learningRule);
  309. learningRule.setTrainingSet(trainingSetToLearn);
  310. learningRule.run();
  311. }
  312. /**
  313. * Stops learning
  314. */
  315. public void stopLearning() {
  316. learningRule.stopLearning();
  317. }
  318. /**
  319. * Pause the learning - puts learning thread in wait state.
  320. * Makes sense only wen learning is done in new thread with learnInNewThread() method
  321. */
  322. public void pauseLearning() {
  323. if ( learningRule instanceof IterativeLearning)
  324. ((IterativeLearning)learningRule).pause();
  325. }
  326. /**
  327. * Resumes paused learning - notifies the learning thread to continue
  328. */
  329. public void resumeLearning() {
  330. if ( learningRule instanceof IterativeLearning)
  331. ((IterativeLearning)learningRule).resume();
  332. }
  333. /**
  334. * Randomizes connection weights for the whole network
  335. */
  336. public void randomizeWeights() {
  337. for(Layer layer : this.layers) {
  338. layer.randomizeWeights();
  339. }
  340. }
  341. /**
  342. * Initialize connection weights for the whole network to a value
  343. *
  344. * @param value the weight value
  345. */
  346. public void initializeWeights(double value) {
  347. for(Layer layer : this.layers) {
  348. layer.initializeWeights(value);
  349. }
  350. }
  351. /**
  352. * Initialize connection weights for the whole network using a
  353. * random number generator
  354. *
  355. * @param generator the random number generator
  356. */
  357. public void initializeWeights(Random generator) {
  358. for(Layer layer : this.layers) {
  359. layer.initializeWeights(generator);
  360. }
  361. }
  362. public void initializeWeights(double min, double max) {
  363. for(Layer layer : this.layers) {
  364. layer.initializeWeights(min, max);
  365. }
  366. }
  367. /**
  368. * Returns type of this network
  369. *
  370. * @return network type
  371. */
  372. public NeuralNetworkType getNetworkType() {
  373. return type;
  374. }
  375. /**
  376. * Sets type for this network
  377. *
  378. * @param type network type
  379. */
  380. public void setNetworkType(NeuralNetworkType type) {
  381. this.type = type;
  382. }
  383. /**
  384. * Gets reference to input neurons Vector.
  385. *
  386. * @return input neurons Vector
  387. */
  388. public Vector<Neuron> getInputNeurons() {
  389. return this.inputNeurons;
  390. }
  391. /**
  392. * Sets reference to input neurons Vector
  393. *
  394. * @param inputNeurons
  395. * input neurons collection
  396. */
  397. public void setInputNeurons(Vector<Neuron> inputNeurons) {
  398. this.inputNeurons = inputNeurons;
  399. }
  400. /**
  401. * Returns reference to output neurons Vector.
  402. *
  403. * @return output neurons Vector
  404. */
  405. public Vector<Neuron> getOutputNeurons() {
  406. return this.outputNeurons;
  407. }
  408. /**
  409. * Sets reference to output neurons Vector.
  410. *
  411. * @param outputNeurons
  412. * output neurons collection
  413. */
  414. public void setOutputNeurons(Vector<Neuron> outputNeurons) {
  415. this.outputNeurons = outputNeurons;
  416. }
  417. /**
  418. * Returns the learning algorithm of this network
  419. *
  420. * @return algorithm for network training
  421. */
  422. public LearningRule getLearningRule() {
  423. return this.learningRule;
  424. }
  425. /**
  426. * Sets learning algorithm for this network
  427. *
  428. * @param learningRule learning algorithm for this network
  429. */
  430. public void setLearningRule(LearningRule learningRule) {
  431. learningRule.setNeuralNetwork(this);
  432. this.learningRule = learningRule;
  433. }
  434. /**
  435. * Returns the current learning thread (if it is learning in the new thread
  436. * Check what happens if it learns in the same thread)
  437. */
  438. public Thread getLearningThread() {
  439. return learningThread;
  440. }
  441. /**
  442. * Notifies observers about some change
  443. */
  444. public void notifyChange() {
  445. setChanged();
  446. notifyObservers();
  447. clearChanged();
  448. }
  449. /**
  450. * Creates connection with specified weight value between specified neurons
  451. *
  452. * @param fromNeuron neuron to connect
  453. * @param toNeuron neuron to connect to
  454. * @param weightVal connection weight value
  455. */
  456. public void createConnection(Neuron fromNeuron, Neuron toNeuron, double weightVal) {
  457. Connection connection = new Connection(fromNeuron, weightVal);
  458. toNeuron.addInputConnection(connection);
  459. }
  460. @Override
  461. public String toString() {
  462. if (plugins.containsKey("LabelsPlugin")) {
  463. LabelsPlugin labelsPlugin = ((LabelsPlugin)this.getPlugin("LabelsPlugin"));
  464. String label = labelsPlugin.getLabel(this);
  465. if (label!=null) return label;
  466. }
  467. return super.toString();
  468. }
  469. /**
  470. * Saves neural network into the specified file.
  471. *
  472. * @param filePath
  473. * file path to save network into
  474. */
  475. public void save(String filePath) {
  476. ObjectOutputStream out = null;
  477. try {
  478. File file = new File(filePath);
  479. out = new ObjectOutputStream( new BufferedOutputStream( new FileOutputStream(file)));
  480. out.writeObject(this);
  481. out.flush();
  482. } catch(IOException ioe) {
  483. ioe.printStackTrace();
  484. } finally {
  485. if(out != null) {
  486. try {
  487. out.close();
  488. } catch (IOException e) {
  489. }
  490. }
  491. }
  492. }
  493. /**
  494. * Loads neural network from the specified file.
  495. *
  496. * @param filePath
  497. * file path to load network from
  498. * @return loaded neural network as NeuralNetwork object
  499. */
  500. public static NeuralNetwork load(String filePath) {
  501. ObjectInputStream oistream = null;
  502. try {
  503. File file = new File(filePath);
  504. if (!file.exists()) {
  505. throw new FileNotFoundException("Cannot find file: " + filePath);
  506. }
  507. oistream = new ObjectInputStream( new BufferedInputStream(new FileInputStream(filePath)));
  508. NeuralNetwork nnet = (NeuralNetwork) oistream.readObject();
  509. return nnet;
  510. } catch(IOException ioe) {
  511. ioe.printStackTrace();
  512. } catch(ClassNotFoundException cnfe) {
  513. cnfe.printStackTrace();
  514. } finally {
  515. if(oistream != null) {
  516. try {
  517. oistream.close();
  518. } catch (IOException ioe) {
  519. }
  520. }
  521. }
  522. return null;
  523. }
  524. /**
  525. * Loads neural network from the specified InputStream.
  526. *
  527. * @param inputStream
  528. * input stream to load network from
  529. * @return loaded neural network as NeuralNetwork object
  530. */
  531. public static NeuralNetwork load(InputStream inputStream) {
  532. ObjectInputStream oistream = null;
  533. try {
  534. oistream = new ObjectInputStream(new BufferedInputStream(inputStream));
  535. NeuralNetwork nnet = (NeuralNetwork) oistream.readObject();
  536. return nnet;
  537. } catch(IOException ioe) {
  538. ioe.printStackTrace();
  539. } catch(ClassNotFoundException cnfe) {
  540. cnfe.printStackTrace();
  541. } finally {
  542. if(oistream != null) {
  543. try {
  544. oistream.close();
  545. } catch (IOException ioe) {
  546. }
  547. }
  548. }
  549. return null;
  550. }
  551. /**
  552. * Adds plugin to neural network
  553. * @param plugin neural network plugin to add
  554. */
  555. public void addPlugin(PluginBase plugin) {
  556. plugin.setParentNetwork(this);
  557. this.plugins.put(plugin.getName(), plugin);
  558. }
  559. /**
  560. * Returns the requested plugin
  561. * @param pluginName name of the plugin to get
  562. * @return plugin with specified name
  563. */
  564. public PluginBase getPlugin(String pluginName) {
  565. return this.plugins.get(pluginName);
  566. }
  567. /**
  568. * Removes the plugin with specified name
  569. * @param pluginName name of the plugin to remove
  570. */
  571. public void removePlugin(String pluginName) {
  572. this.plugins.remove(pluginName);
  573. }
  574. }