PageRenderTime 35ms CodeModel.GetById 22ms RepoModel.GetById 0ms app.codeStats 0ms

/projects/weka-3-6-9/weka-src/src/main/java/weka/classifiers/trees/m5/Rule.java

https://gitlab.com/essere.lab.public/qualitas.class-corpus
Java | 647 lines | 309 code | 102 blank | 236 comment | 53 complexity | 9953593635ff9f738e59810a764f1c6e MD5 | raw file
  1. /*
  2. * This program is free software; you can redistribute it and/or modify
  3. * it under the terms of the GNU General Public License as published by
  4. * the Free Software Foundation; either version 2 of the License, or
  5. * (at your option) any later version.
  6. *
  7. * This program is distributed in the hope that it will be useful,
  8. * but WITHOUT ANY WARRANTY; without even the implied warranty of
  9. * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  10. * GNU General Public License for more details.
  11. *
  12. * You should have received a copy of the GNU General Public License
  13. * along with this program; if not, write to the Free Software
  14. * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
  15. */
  16. /*
  17. * Rule.java
  18. * Copyright (C) 2000 University of Waikato, Hamilton, New Zealand
  19. *
  20. */
  21. package weka.classifiers.trees.m5;
  22. import weka.core.Instance;
  23. import weka.core.Instances;
  24. import weka.core.RevisionHandler;
  25. import weka.core.RevisionUtils;
  26. import weka.core.Utils;
  27. import java.io.Serializable;
  28. /**
  29. * Generates a single m5 tree or rule
  30. *
  31. * @author Mark Hall
  32. * @version $Revision: 6260 $
  33. */
  34. public class Rule
  35. implements Serializable, RevisionHandler {
  36. /** for serialization */
  37. private static final long serialVersionUID = -4458627451682483204L;
  38. protected static int LEFT = 0;
  39. protected static int RIGHT = 1;
  40. /**
  41. * the instances covered by this rule
  42. */
  43. private Instances m_instances;
  44. /**
  45. * the class index
  46. */
  47. private int m_classIndex;
  48. /**
  49. * the number of attributes
  50. */
  51. private int m_numAttributes;
  52. /**
  53. * the number of instances in the dataset
  54. */
  55. private int m_numInstances;
  56. /**
  57. * the indexes of the attributes used to split on for this rule
  58. */
  59. private int[] m_splitAtts;
  60. /**
  61. * the corresponding values of the split points
  62. */
  63. private double[] m_splitVals;
  64. /**
  65. * the corresponding internal nodes. Used for smoothing rules.
  66. */
  67. private RuleNode[] m_internalNodes;
  68. /**
  69. * the corresponding relational operators (0 = "<=", 1 = ">")
  70. */
  71. private int[] m_relOps;
  72. /**
  73. * the leaf encapsulating the linear model for this rule
  74. */
  75. private RuleNode m_ruleModel;
  76. /**
  77. * the top of the m5 tree for this rule
  78. */
  79. protected RuleNode m_topOfTree;
  80. /**
  81. * the standard deviation of the class for all the instances
  82. */
  83. private double m_globalStdDev;
  84. /**
  85. * the absolute deviation of the class for all the instances
  86. */
  87. private double m_globalAbsDev;
  88. /**
  89. * the instances covered by this rule
  90. */
  91. private Instances m_covered;
  92. /**
  93. * the number of instances covered by this rule
  94. */
  95. private int m_numCovered;
  96. /**
  97. * the instances not covered by this rule
  98. */
  99. private Instances m_notCovered;
  100. /**
  101. * use a pruned m5 tree rather than make a rule
  102. */
  103. private boolean m_useTree;
  104. /**
  105. * use the original m5 smoothing procedure
  106. */
  107. private boolean m_smoothPredictions;
  108. /**
  109. * Save instances at each node in an M5 tree for visualization purposes.
  110. */
  111. private boolean m_saveInstances;
  112. /**
  113. * Make a regression tree instead of a model tree
  114. */
  115. private boolean m_regressionTree;
  116. /**
  117. * Build unpruned tree/rule
  118. */
  119. private boolean m_useUnpruned;
  120. /**
  121. * The minimum number of instances to allow at a leaf node
  122. */
  123. private double m_minNumInstances;
  124. /**
  125. * Constructor declaration
  126. *
  127. */
  128. public Rule() {
  129. m_useTree = false;
  130. m_smoothPredictions = false;
  131. m_useUnpruned = false;
  132. m_minNumInstances = 4;
  133. }
  134. /**
  135. * Generates a single rule or m5 model tree.
  136. *
  137. * @param data set of instances serving as training data
  138. * @exception Exception if the rule has not been generated
  139. * successfully
  140. */
  141. public void buildClassifier(Instances data) throws Exception {
  142. m_instances = null;
  143. m_topOfTree = null;
  144. m_covered = null;
  145. m_notCovered = null;
  146. m_ruleModel = null;
  147. m_splitAtts = null;
  148. m_splitVals = null;
  149. m_relOps = null;
  150. m_internalNodes = null;
  151. m_instances = data;
  152. m_classIndex = m_instances.classIndex();
  153. m_numAttributes = m_instances.numAttributes();
  154. m_numInstances = m_instances.numInstances();
  155. // first calculate global deviation of class attribute
  156. m_globalStdDev = Rule.stdDev(m_classIndex, m_instances);
  157. m_globalAbsDev = Rule.absDev(m_classIndex, m_instances);
  158. m_topOfTree = new RuleNode(m_globalStdDev, m_globalAbsDev, null);
  159. m_topOfTree.setSaveInstances(m_saveInstances);
  160. m_topOfTree.setRegressionTree(m_regressionTree);
  161. m_topOfTree.setMinNumInstances(m_minNumInstances);
  162. m_topOfTree.buildClassifier(m_instances);
  163. if (!m_useUnpruned) {
  164. m_topOfTree.prune();
  165. } else {
  166. m_topOfTree.installLinearModels();
  167. }
  168. if (m_smoothPredictions) {
  169. m_topOfTree.installSmoothedModels();
  170. }
  171. //m_topOfTree.printAllModels();
  172. m_topOfTree.numLeaves(0);
  173. if (!m_useTree) {
  174. makeRule();
  175. // save space
  176. // m_topOfTree = null;
  177. }
  178. // save space
  179. m_instances = new Instances(m_instances, 0);
  180. }
  181. /**
  182. * Calculates a prediction for an instance using this rule
  183. * or M5 model tree
  184. *
  185. * @param instance the instance whos class value is to be predicted
  186. * @return the prediction
  187. * @exception Exception if a prediction can't be made.
  188. */
  189. public double classifyInstance(Instance instance) throws Exception {
  190. if (m_useTree) {
  191. return m_topOfTree.classifyInstance(instance);
  192. }
  193. // does the instance pass the rule's conditions?
  194. if (m_splitAtts.length > 0) {
  195. for (int i = 0; i < m_relOps.length; i++) {
  196. if (m_relOps[i] == LEFT) // left
  197. {
  198. if (instance.value(m_splitAtts[i]) > m_splitVals[i]) {
  199. throw new Exception("Rule does not classify instance");
  200. }
  201. } else {
  202. if (instance.value(m_splitAtts[i]) <= m_splitVals[i]) {
  203. throw new Exception("Rule does not classify instance");
  204. }
  205. }
  206. }
  207. }
  208. // the linear model's prediction for this rule
  209. return m_ruleModel.classifyInstance(instance);
  210. }
  211. /**
  212. * Returns the top of the tree.
  213. */
  214. public RuleNode topOfTree() {
  215. return m_topOfTree;
  216. }
  217. /**
  218. * Make the single best rule from a pruned m5 model tree
  219. *
  220. * @exception Exception if something goes wrong.
  221. */
  222. private void makeRule() throws Exception {
  223. RuleNode[] best_leaf = new RuleNode[1];
  224. double[] best_cov = new double[1];
  225. RuleNode temp;
  226. m_notCovered = new Instances(m_instances, 0);
  227. m_covered = new Instances(m_instances, 0);
  228. best_cov[0] = -1;
  229. best_leaf[0] = null;
  230. m_topOfTree.findBestLeaf(best_cov, best_leaf);
  231. temp = best_leaf[0];
  232. if (temp == null) {
  233. throw new Exception("Unable to generate rule!");
  234. }
  235. // save the linear model for this rule
  236. m_ruleModel = temp;
  237. int count = 0;
  238. while (temp.parentNode() != null) {
  239. count++;
  240. temp = temp.parentNode();
  241. }
  242. temp = best_leaf[0];
  243. m_relOps = new int[count];
  244. m_splitAtts = new int[count];
  245. m_splitVals = new double[count];
  246. if (m_smoothPredictions) {
  247. m_internalNodes = new RuleNode[count];
  248. }
  249. // trace back to the root
  250. int i = 0;
  251. while (temp.parentNode() != null) {
  252. m_splitAtts[i] = temp.parentNode().splitAtt();
  253. m_splitVals[i] = temp.parentNode().splitVal();
  254. if (temp.parentNode().leftNode() == temp) {
  255. m_relOps[i] = LEFT;
  256. temp.parentNode().m_right = null;
  257. } else {
  258. m_relOps[i] = RIGHT;
  259. temp.parentNode().m_left = null;
  260. }
  261. if (m_smoothPredictions) {
  262. m_internalNodes[i] = temp.parentNode();
  263. }
  264. temp = temp.parentNode();
  265. i++;
  266. }
  267. // now assemble the covered and uncovered instances
  268. boolean ok;
  269. for (i = 0; i < m_numInstances; i++) {
  270. ok = true;
  271. for (int j = 0; j < m_relOps.length; j++) {
  272. if (m_relOps[j] == LEFT)
  273. {
  274. if (m_instances.instance(i).value(m_splitAtts[j])
  275. > m_splitVals[j]) {
  276. m_notCovered.add(m_instances.instance(i));
  277. ok = false;
  278. break;
  279. }
  280. } else {
  281. if (m_instances.instance(i).value(m_splitAtts[j])
  282. <= m_splitVals[j]) {
  283. m_notCovered.add(m_instances.instance(i));
  284. ok = false;
  285. break;
  286. }
  287. }
  288. }
  289. if (ok) {
  290. m_numCovered++;
  291. // m_covered.add(m_instances.instance(i));
  292. }
  293. }
  294. }
  295. /**
  296. * Return a description of the m5 tree or rule
  297. *
  298. * @return a description of the m5 tree or rule as a String
  299. */
  300. public String toString() {
  301. if (m_useTree) {
  302. return treeToString();
  303. } else {
  304. return ruleToString();
  305. }
  306. }
  307. /**
  308. * Return a description of the m5 tree
  309. *
  310. * @return a description of the m5 tree as a String
  311. */
  312. private String treeToString() {
  313. StringBuffer text = new StringBuffer();
  314. if (m_topOfTree == null) {
  315. return "Tree/Rule has not been built yet!";
  316. }
  317. text.append("M5 "
  318. + ((m_useUnpruned)
  319. ? "unpruned "
  320. : "pruned ")
  321. + ((m_regressionTree)
  322. ? "regression "
  323. : "model ")
  324. +"tree:\n");
  325. if (m_smoothPredictions == true) {
  326. text.append("(using smoothed linear models)\n");
  327. }
  328. text.append(m_topOfTree.treeToString(0));
  329. text.append(m_topOfTree.printLeafModels());
  330. text.append("\nNumber of Rules : " + m_topOfTree.numberOfLinearModels());
  331. return text.toString();
  332. }
  333. /**
  334. * Return a description of the rule
  335. *
  336. * @return a description of the rule as a String
  337. */
  338. private String ruleToString() {
  339. StringBuffer text = new StringBuffer();
  340. if (m_splitAtts.length > 0) {
  341. text.append("IF\n");
  342. for (int i = m_splitAtts.length - 1; i >= 0; i--) {
  343. text.append("\t" + m_covered.attribute(m_splitAtts[i]).name() + " ");
  344. if (m_relOps[i] == 0) {
  345. text.append("<= ");
  346. } else {
  347. text.append("> ");
  348. }
  349. text.append(Utils.doubleToString(m_splitVals[i], 1, 3) + "\n");
  350. }
  351. text.append("THEN\n");
  352. }
  353. if (m_ruleModel != null) {
  354. try {
  355. text.append(m_ruleModel.printNodeLinearModel());
  356. text.append(" [" + m_numCovered/*m_covered.numInstances()*/);
  357. if (m_globalAbsDev > 0.0) {
  358. text.append("/"+Utils.doubleToString((100 *
  359. m_ruleModel.
  360. rootMeanSquaredError() /
  361. m_globalStdDev), 1, 3)
  362. + "%]\n\n");
  363. } else {
  364. text.append("]\n\n");
  365. }
  366. } catch (Exception e) {
  367. return "Can't print rule";
  368. }
  369. }
  370. // System.out.println(m_instances);
  371. return text.toString();
  372. }
  373. /**
  374. * Use unpruned tree/rules
  375. *
  376. * @param unpruned true if unpruned tree/rules are to be generated
  377. */
  378. public void setUnpruned(boolean unpruned) {
  379. m_useUnpruned = unpruned;
  380. }
  381. /**
  382. * Get whether unpruned tree/rules are being generated
  383. *
  384. * @return true if unpruned tree/rules are to be generated
  385. */
  386. public boolean getUnpruned() {
  387. return m_useUnpruned;
  388. }
  389. /**
  390. * Use an m5 tree rather than generate rules
  391. *
  392. * @param u true if m5 tree is to be used
  393. */
  394. public void setUseTree(boolean u) {
  395. m_useTree = u;
  396. }
  397. /**
  398. * get whether an m5 tree is being used rather than rules
  399. *
  400. * @return true if an m5 tree is being used.
  401. */
  402. public boolean getUseTree() {
  403. return m_useTree;
  404. }
  405. /**
  406. * Smooth predictions
  407. *
  408. * @param s true if smoothing is to be used
  409. */
  410. public void setSmoothing(boolean s) {
  411. m_smoothPredictions = s;
  412. }
  413. /**
  414. * Get whether or not smoothing has been turned on
  415. *
  416. * @return true if smoothing is being used
  417. */
  418. public boolean getSmoothing() {
  419. return m_smoothPredictions;
  420. }
  421. /**
  422. * Get the instances not covered by this rule
  423. *
  424. * @return the instances not covered
  425. */
  426. public Instances notCoveredInstances() {
  427. return m_notCovered;
  428. }
  429. /**
  430. * Free up memory consumed by the set of instances
  431. * not covered by this rule.
  432. */
  433. public void freeNotCoveredInstances() {
  434. m_notCovered = null;
  435. }
  436. // /**
  437. // * Get the instances covered by this rule
  438. // *
  439. // * @return the instances covered by this rule
  440. // */
  441. // public Instances coveredInstances() {
  442. // return m_covered;
  443. // }
  444. /**
  445. * Returns the standard deviation value of the supplied attribute index.
  446. *
  447. * @param attr an attribute index
  448. * @param inst the instances
  449. * @return the standard deviation value
  450. */
  451. protected static final double stdDev(int attr, Instances inst) {
  452. int i,count=0;
  453. double sd,va,sum=0.0,sqrSum=0.0,value;
  454. for(i = 0; i <= inst.numInstances() - 1; i++) {
  455. count++;
  456. value = inst.instance(i).value(attr);
  457. sum += value;
  458. sqrSum += value * value;
  459. }
  460. if(count > 1) {
  461. va = (sqrSum - sum * sum / count) / count;
  462. va = Math.abs(va);
  463. sd = Math.sqrt(va);
  464. } else {
  465. sd = 0.0;
  466. }
  467. return sd;
  468. }
  469. /**
  470. * Returns the absolute deviation value of the supplied attribute index.
  471. *
  472. * @param attr an attribute index
  473. * @param inst the instances
  474. * @return the absolute deviation value
  475. */
  476. protected static final double absDev(int attr, Instances inst) {
  477. int i;
  478. double average=0.0,absdiff=0.0,absDev;
  479. for(i = 0; i <= inst.numInstances()-1; i++) {
  480. average += inst.instance(i).value(attr);
  481. }
  482. if(inst.numInstances() > 1) {
  483. average /= (double)inst.numInstances();
  484. for(i=0; i <= inst.numInstances()-1; i++) {
  485. absdiff += Math.abs(inst.instance(i).value(attr) - average);
  486. }
  487. absDev = absdiff / (double)inst.numInstances();
  488. } else {
  489. absDev = 0.0;
  490. }
  491. return absDev;
  492. }
  493. /**
  494. * Sets whether instances at each node in an M5 tree should be saved
  495. * for visualization purposes. Default is to save memory.
  496. *
  497. * @param save a <code>boolean</code> value
  498. */
  499. protected void setSaveInstances(boolean save) {
  500. m_saveInstances = save;
  501. }
  502. /**
  503. * Get the value of regressionTree.
  504. *
  505. * @return Value of regressionTree.
  506. */
  507. public boolean getRegressionTree() {
  508. return m_regressionTree;
  509. }
  510. /**
  511. * Set the value of regressionTree.
  512. *
  513. * @param newregressionTree Value to assign to regressionTree.
  514. */
  515. public void setRegressionTree(boolean newregressionTree) {
  516. m_regressionTree = newregressionTree;
  517. }
  518. /**
  519. * Set the minumum number of instances to allow at a leaf node
  520. *
  521. * @param minNum the minimum number of instances
  522. */
  523. public void setMinNumInstances(double minNum) {
  524. m_minNumInstances = minNum;
  525. }
  526. /**
  527. * Get the minimum number of instances to allow at a leaf node
  528. *
  529. * @return a <code>double</code> value
  530. */
  531. public double getMinNumInstances() {
  532. return m_minNumInstances;
  533. }
  534. public RuleNode getM5RootNode() {
  535. return m_topOfTree;
  536. }
  537. /**
  538. * Returns the revision string.
  539. *
  540. * @return the revision
  541. */
  542. public String getRevision() {
  543. return RevisionUtils.extract("$Revision: 6260 $");
  544. }
  545. }