/test/classifiers/id3_test.rb

https://github.com/MarcelorjOliveira/Sistemas-Especialistas · Ruby · 208 lines · 165 code · 23 blank · 20 comment · 0 complexity · d39e24878eba7c33d9f1901efa118002 MD5 · raw file

  1. # id3_test.rb
  2. #
  3. # This is a unit test file for the ID3 algorithm (Quinlan) implemented
  4. # in ai4r
  5. #
  6. # Author:: Sergio Fierens
  7. # License:: MPL 1.1
  8. # Project:: ai4r
  9. # Url:: http://ai4r.rubyforge.org/
  10. #
  11. # You can redistribute it and/or modify it under the terms of
  12. # the Mozilla Public License version 1.1 as published by the
  13. # Mozilla Foundation at http://www.mozilla.org/MPL/MPL-1.1.txt
  14. require File.dirname(__FILE__) + '/../../lib/ai4r/classifiers/id3'
  15. require 'test/unit'
  16. DATA_LABELS = [ 'city', 'age_range', 'gender', 'marketing_target' ]
  17. DATA_ITEMS = [ ['New York', '<30', 'M', 'Y'],
  18. ['Chicago', '<30', 'M', 'Y'],
  19. ['Chicago', '<30', 'F', 'Y'],
  20. ['New York', '<30', 'M', 'Y'],
  21. ['New York', '<30', 'M', 'Y'],
  22. ['Chicago', '[30-50)', 'M', 'Y'],
  23. ['New York', '[30-50)', 'F', 'N'],
  24. ['Chicago', '[30-50)', 'F', 'Y'],
  25. ['New York', '[30-50)', 'F', 'N'],
  26. ['Chicago', '[50-80]', 'M', 'N'],
  27. ['New York', '[50-80]', 'F', 'N'],
  28. ['New York', '[50-80]', 'M', 'N'],
  29. ['Chicago', '[50-80]', 'M', 'N'],
  30. ['New York', '[50-80]', 'F', 'N'],
  31. ['Chicago', '>80', 'F', 'Y']
  32. ]
  33. SPLIT_DATA_ITEMS_BY_CITY = [ [
  34. ["New York", "<30", "M", "Y"],
  35. ["New York", "<30", "M", "Y"],
  36. ["New York", "<30", "M", "Y"],
  37. ["New York", "[30-50)", "F", "N"],
  38. ["New York", "[30-50)", "F", "N"],
  39. ["New York", "[50-80]", "F", "N"],
  40. ["New York", "[50-80]", "M", "N"],
  41. ["New York", "[50-80]", "F", "N"]],
  42. [
  43. ["Chicago", "<30", "M", "Y"],
  44. ["Chicago", "<30", "F", "Y"],
  45. ["Chicago", "[30-50)", "M", "Y"],
  46. ["Chicago", "[30-50)", "F", "Y"],
  47. ["Chicago", "[50-80]", "M", "N"],
  48. ["Chicago", "[50-80]", "M", "N"],
  49. ["Chicago", ">80", "F", "Y"]]
  50. ]
  51. SPLIT_DATA_ITEMS_BY_AGE = [ [
  52. ["New York", "<30", "M", "Y"],
  53. ["Chicago", "<30", "M", "Y"],
  54. ["Chicago", "<30", "F", "Y"],
  55. ["New York", "<30", "M", "Y"],
  56. ["New York", "<30", "M", "Y"]],
  57. [
  58. ["Chicago", "[30-50)", "M", "Y"],
  59. ["New York", "[30-50)", "F", "N"],
  60. ["Chicago", "[30-50)", "F", "Y"],
  61. ["New York", "[30-50)", "F", "N"]],
  62. [
  63. ["Chicago", "[50-80]", "M", "N"],
  64. ["New York", "[50-80]", "F", "N"],
  65. ["New York", "[50-80]", "M", "N"],
  66. ["Chicago", "[50-80]", "M", "N"],
  67. ["New York", "[50-80]", "F", "N"]],
  68. [
  69. ["Chicago", ">80", "F", "Y"]]
  70. ]
  71. EXPECTED_RULES_STRING =
  72. "if age_range=='<30' then marketing_target='Y'\n"+
  73. "elsif age_range=='[30-50)' and city=='Chicago' then marketing_target='Y'\n"+
  74. "elsif age_range=='[30-50)' and city=='New York' then marketing_target='N'\n"+
  75. "elsif age_range=='[50-80]' then marketing_target='N'\n"+
  76. "elsif age_range=='>80' then marketing_target='Y'\n"+
  77. "else raise 'There was not enough information during training to do a proper induction for this data element' end"
  78. include Ai4r::Classifiers
  79. include Ai4r::Data
  80. class ID3Test < Test::Unit::TestCase
  81. def test_build
  82. Ai4r::Classifiers::ID3.send(:public, *Ai4r::Classifiers::ID3.protected_instance_methods)
  83. Ai4r::Classifiers::ID3.send(:public, *Ai4r::Classifiers::ID3.private_instance_methods)
  84. end
  85. def test_log2
  86. assert_equal 1.0, ID3.log2(2)
  87. assert_equal 0.0, ID3.log2(0)
  88. assert 1.585 - ID3.log2(3) < 0.001
  89. end
  90. def test_sum
  91. assert_equal 28, ID3.sum([5, 0, 22, 1])
  92. assert_equal 0, ID3.sum([])
  93. end
  94. def test_data_labels
  95. id3 = ID3.new.build(DataSet.new(:data_items =>DATA_ITEMS))
  96. expected_default = [ 'attribute_1', 'attribute_2', 'attribute_3', 'class_value' ]
  97. assert_equal(expected_default, id3.data_set.data_labels)
  98. id3 = ID3.new.build(DataSet.new(:data_items =>DATA_ITEMS, :data_labels => DATA_LABELS))
  99. assert_equal(DATA_LABELS, id3.data_set.data_labels)
  100. end
  101. def test_domain
  102. id3 = ID3.new.build(DataSet.new(:data_items =>DATA_ITEMS, :data_labels => DATA_LABELS))
  103. expected_domain = [["New York", "Chicago"], ["<30", "[30-50)", "[50-80]", ">80"], ["M", "F"], ["Y", "N"]]
  104. assert_equal expected_domain, id3.domain(DATA_ITEMS)
  105. end
  106. def test_grid
  107. id3 = ID3.new.build(DataSet.new(:data_items =>DATA_ITEMS, :data_labels => DATA_LABELS))
  108. expected_grid = [[3, 5], [5, 2]]
  109. domain = id3.domain(DATA_ITEMS)
  110. assert_equal expected_grid, id3.freq_grid(0, DATA_ITEMS, domain)
  111. expected_grid = [[5, 0], [2, 2], [0, 5], [1, 0]]
  112. assert_equal expected_grid, id3.freq_grid(1, DATA_ITEMS, domain)
  113. end
  114. def test_entropy
  115. id3 = ID3.new.build(DataSet.new(:data_items =>DATA_ITEMS, :data_labels => DATA_LABELS))
  116. expected_entropy = 0.9118
  117. domain = id3.domain(DATA_ITEMS)
  118. freq_grid = id3.freq_grid(0, DATA_ITEMS, domain)
  119. assert expected_entropy - id3.entropy(freq_grid, DATA_ITEMS.length) < 0.0001
  120. expected_entropy = 0.2667
  121. freq_grid = id3.freq_grid(1, DATA_ITEMS, domain)
  122. assert expected_entropy - id3.entropy(freq_grid, DATA_ITEMS.length) < 0.0001
  123. expected_entropy = 0.9688
  124. freq_grid = id3.freq_grid(2, DATA_ITEMS, domain)
  125. assert expected_entropy - id3.entropy(freq_grid, DATA_ITEMS.length) < 0.0001
  126. end
  127. def test_min_entropy_index
  128. id3 = ID3.new.build(DataSet.new(:data_items =>DATA_ITEMS, :data_labels => DATA_LABELS))
  129. domain = id3.domain(DATA_ITEMS)
  130. assert_equal 1, id3.min_entropy_index(DATA_ITEMS, domain)
  131. assert_equal 0, id3.min_entropy_index(DATA_ITEMS, domain, [1])
  132. assert_equal 2, id3.min_entropy_index(DATA_ITEMS, domain, [1, 0])
  133. end
  134. def test_split_data_examples
  135. id3 = ID3.new.build(DataSet.new(:data_items =>DATA_ITEMS, :data_labels => DATA_LABELS))
  136. domain = id3.domain(DATA_ITEMS)
  137. res = id3.split_data_examples(DATA_ITEMS, domain, 0)
  138. assert_equal(SPLIT_DATA_ITEMS_BY_CITY, res)
  139. res = id3.split_data_examples(DATA_ITEMS, domain, 1)
  140. assert_equal(SPLIT_DATA_ITEMS_BY_AGE, res)
  141. end
  142. def test_most_freq
  143. id3 = ID3.new.build(DataSet.new(:data_items =>DATA_ITEMS, :data_labels => DATA_LABELS))
  144. domain = id3.domain(DATA_ITEMS)
  145. assert_equal 'Y', id3.most_freq(DATA_ITEMS, domain)
  146. assert_equal 'Y', id3.most_freq(SPLIT_DATA_ITEMS_BY_AGE[3], domain)
  147. assert_equal 'N', id3.most_freq(SPLIT_DATA_ITEMS_BY_AGE[2], domain)
  148. end
  149. def test_get_rules
  150. assert_equal [["marketing_target='N'"]], CategoryNode.new('marketing_target', 'N').get_rules
  151. id3 = ID3.new.build(DataSet.new(:data_items =>DATA_ITEMS, :data_labels => DATA_LABELS))
  152. assert_equal EXPECTED_RULES_STRING, id3.get_rules
  153. end
  154. def test_eval
  155. id3 = ID3.new.build(DataSet.new(:data_items =>DATA_ITEMS, :data_labels => DATA_LABELS))
  156. #if age_range='<30' then marketing_target='Y'
  157. assert_equal 'Y', id3.eval(['New York', '<30', 'F'])
  158. assert_equal 'Y', id3.eval(['Chicago', '<30', 'M'])
  159. #if age_range='[30-50)' and city='Chicago' then marketing_target='Y'
  160. assert_equal 'Y', id3.eval(['Chicago', '[30-50)', 'F'])
  161. assert_equal 'Y', id3.eval(['Chicago', '[30-50)', 'M'])
  162. #if age_range='[30-50)' and city='New York' then marketing_target='N'
  163. assert_equal 'N', id3.eval(['New York', '[30-50)', 'F'])
  164. assert_equal 'N', id3.eval(['New York', '[30-50)', 'M'])
  165. #if age_range='[50-80]' then marketing_target='N'
  166. assert_equal 'N', id3.eval(['New York', '[50-80]', 'F'])
  167. assert_equal 'N', id3.eval(['Chicago', '[50-80]', 'M'])
  168. #if age_range='>80' then marketing_target='Y'
  169. assert_equal 'Y', id3.eval(['New York', '>80', 'M'])
  170. assert_equal 'Y', id3.eval(['Chicago', '>80', 'F'])
  171. end
  172. def test_rules_eval
  173. id3 = ID3.new.build(DataSet.new(:data_items =>DATA_ITEMS, :data_labels => DATA_LABELS))
  174. #if age_range='<30' then marketing_target='Y'
  175. age_range = '<30'
  176. marketing_target = nil
  177. eval id3.get_rules
  178. assert_equal 'Y', marketing_target
  179. #if age_range='[30-50)' and city='New York' then marketing_target='N'
  180. age_range='[30-50)'
  181. city='New York'
  182. eval id3.get_rules
  183. assert_equal 'N', marketing_target
  184. end
  185. end