PageRenderTime 46ms CodeModel.GetById 16ms RepoModel.GetById 1ms app.codeStats 0ms

/encog-core-silverlight/encog-core-silverlight/Persist/Persistors/SVMNetworkPersistor.cs

http://encog-cs.googlecode.com/
C# | 432 lines | 280 code | 54 blank | 98 comment | 69 complexity | cc7cb3c50c8f8bdc53e6f08042ff156e MD5 | raw file
  1. using System;
  2. using System.Collections.Generic;
  3. using System.Linq;
  4. using System.Text;
  5. using Encog.Parse.Tags.Read;
  6. using Encog.Neural.Networks.SVM;
  7. using Encog.MathUtil.LIBSVM;
  8. using Encog.Util;
  9. using Encog.Parse.Tags.Write;
  10. using Encog.Engine.Util;
  11. namespace Encog.Persist.Persistors
  12. {
  13. /// <summary>
  14. /// Persist a SVM network.
  15. /// </summary>
  16. public class SVMNetworkPersistor : IPersistor
  17. {
  18. /// <summary>
  19. /// Constants for the SVM types.
  20. /// </summary>
  21. public static readonly String[] svm_type_table = { "c_svc", "nu_svc",
  22. "one_class", "epsilon_svr", "nu_svr", };
  23. /// <summary>
  24. /// Constants for the kernel types.
  25. /// </summary>
  26. public static readonly String[] kernel_type_table = { "linear", "polynomial",
  27. "rbf", "sigmoid", "precomputed" };
  28. /// <summary>
  29. /// The input tag.
  30. /// </summary>
  31. public const String TAG_INPUT = "input";
  32. /// <summary>
  33. /// The output tag.
  34. /// </summary>
  35. public const String TAG_OUTPUT = "output";
  36. /// <summary>
  37. /// The models tag.
  38. /// </summary>
  39. public const String TAG_MODELS = "models";
  40. /// <summary>
  41. /// The data tag.
  42. /// </summary>
  43. public const String TAG_DATA = "Data";
  44. /// <summary>
  45. /// The row tag.
  46. /// </summary>
  47. public const String TAG_ROW = "Row";
  48. /// <summary>
  49. /// The model tag.
  50. /// </summary>
  51. public const String TAG_MODEL = "Model";
  52. /// <summary>
  53. /// The type of SVM this is.
  54. /// </summary>
  55. public const String TAG_TYPE_SVM = "typeSVM";
  56. /// <summary>
  57. /// The type of kernel to use.
  58. /// </summary>
  59. public const String TAG_TYPE_KERNEL = "typeKernel";
  60. /// <summary>
  61. /// The degree to use.
  62. /// </summary>
  63. public const String TAG_DEGREE = "degree";
  64. /// <summary>
  65. /// The gamma to use.
  66. /// </summary>
  67. public const String TAG_GAMMA = "gamma";
  68. /// <summary>
  69. /// The coefficient.
  70. /// </summary>
  71. public const String TAG_COEF0 = "coef0";
  72. /// <summary>
  73. /// The number of classes.
  74. /// </summary>
  75. public const String TAG_NUMCLASS = "numClass";
  76. /// <summary>
  77. /// The total number of cases.
  78. /// </summary>
  79. public const String TAG_TOTALSV = "totalSV";
  80. /// <summary>
  81. /// The rho to use.
  82. /// </summary>
  83. public const String TAG_RHO = "rho";
  84. /// <summary>
  85. /// The labels.
  86. /// </summary>
  87. public const String TAG_LABEL = "label";
  88. /// <summary>
  89. /// The A-probability.
  90. /// </summary>
  91. public const String TAG_PROB_A = "probA";
  92. /// <summary>
  93. /// The B-probability.
  94. /// </summary>
  95. public const String TAG_PROB_B = "probB";
  96. /// <summary>
  97. /// The number of support vectors.
  98. /// </summary>
  99. public const String TAG_NSV = "nSV";
  100. /// <summary>
  101. /// Load the SVM network.
  102. /// </summary>
  103. /// <param name="xmlin">Where to read it from.</param>
  104. /// <returns>The loaded object.</returns>
  105. public IEncogPersistedObject Load(ReadXML xmlin)
  106. {
  107. SVMNetwork result = null;
  108. int input = -1, output = -1;
  109. String name = xmlin.LastTag.Attributes[
  110. EncogPersistedCollection.ATTRIBUTE_NAME];
  111. String description = xmlin.LastTag.Attributes[
  112. EncogPersistedCollection.ATTRIBUTE_DESCRIPTION];
  113. while (xmlin.ReadToTag())
  114. {
  115. if (xmlin.IsIt(SVMNetworkPersistor.TAG_INPUT, true))
  116. {
  117. input = int.Parse(xmlin.ReadTextToTag());
  118. }
  119. else if (xmlin.IsIt(SVMNetworkPersistor.TAG_OUTPUT, true))
  120. {
  121. output = int.Parse(xmlin.ReadTextToTag());
  122. }
  123. else if (xmlin.IsIt(SVMNetworkPersistor.TAG_MODELS, true))
  124. {
  125. result = new SVMNetwork(input, output, false);
  126. HandleModels(xmlin, result);
  127. }
  128. else if (xmlin.IsIt(EncogPersistedCollection.TYPE_SVM, false))
  129. {
  130. break;
  131. }
  132. }
  133. result.Name = name;
  134. result.Description = description;
  135. return result;
  136. }
  137. /// <summary>
  138. /// Load the models.
  139. /// </summary>
  140. /// <param name="xmlin">Where to read the models from.</param>
  141. /// <param name="network">Where the models are read into.</param>
  142. private void HandleModels(ReadXML xmlin, SVMNetwork network)
  143. {
  144. int index = 0;
  145. while (xmlin.ReadToTag())
  146. {
  147. if (xmlin.IsIt(SVMNetworkPersistor.TAG_MODEL, true))
  148. {
  149. svm_parameter param = new svm_parameter();
  150. svm_model model = new svm_model();
  151. model.param = param;
  152. network.Models[index] = model;
  153. HandleModel(xmlin, network.Models[index]);
  154. index++;
  155. }
  156. else if (xmlin.IsIt(SVMNetworkPersistor.TAG_MODELS, false))
  157. {
  158. break;
  159. }
  160. }
  161. }
  162. /// <summary>
  163. /// Handle a model.
  164. /// </summary>
  165. /// <param name="xmlin">Where to read the model from.</param>
  166. /// <param name="model">Where to load the model into.</param>
  167. private void HandleModel(ReadXML xmlin, svm_model model)
  168. {
  169. while (xmlin.ReadToTag())
  170. {
  171. if (xmlin.IsIt(SVMNetworkPersistor.TAG_TYPE_SVM, true))
  172. {
  173. int i = EngineArray.FindStringInArray(
  174. SVMNetworkPersistor.svm_type_table, xmlin.ReadTextToTag());
  175. model.param.svm_type = i;
  176. }
  177. else if (xmlin.IsIt(SVMNetworkPersistor.TAG_DEGREE, true))
  178. {
  179. model.param.degree = int.Parse(xmlin.ReadTextToTag());
  180. }
  181. else if (xmlin.IsIt(SVMNetworkPersistor.TAG_GAMMA, true))
  182. {
  183. model.param.gamma = double.Parse(xmlin.ReadTextToTag());
  184. }
  185. else if (xmlin.IsIt(SVMNetworkPersistor.TAG_COEF0, true))
  186. {
  187. model.param.coef0 = double.Parse(xmlin.ReadTextToTag());
  188. }
  189. else if (xmlin.IsIt(SVMNetworkPersistor.TAG_NUMCLASS, true))
  190. {
  191. model.nr_class = int.Parse(xmlin.ReadTextToTag());
  192. }
  193. else if (xmlin.IsIt(SVMNetworkPersistor.TAG_TOTALSV, true))
  194. {
  195. model.l = int.Parse(xmlin.ReadTextToTag());
  196. }
  197. else if (xmlin.IsIt(SVMNetworkPersistor.TAG_RHO, true))
  198. {
  199. int n = model.nr_class * (model.nr_class - 1) / 2;
  200. model.rho = new double[n];
  201. String[] st = xmlin.ReadTextToTag().Split(',');
  202. for (int i = 0; i < n; i++)
  203. model.rho[i] = double.Parse(st[i]);
  204. }
  205. else if (xmlin.IsIt(SVMNetworkPersistor.TAG_LABEL, true))
  206. {
  207. int n = model.nr_class;
  208. model.label = new int[n];
  209. String[] st = xmlin.ReadTextToTag().Split(',');
  210. for (int i = 0; i < n; i++)
  211. model.label[i] = int.Parse(st[i]);
  212. }
  213. else if (xmlin.IsIt(SVMNetworkPersistor.TAG_PROB_A, true))
  214. {
  215. int n = model.nr_class * (model.nr_class - 1) / 2;
  216. model.probA = new double[n];
  217. String[] st = xmlin.ReadTextToTag().Split(',');
  218. for (int i = 0; i < n; i++)
  219. model.probA[i] = Double.Parse(st[i]);
  220. }
  221. else if (xmlin.IsIt(SVMNetworkPersistor.TAG_PROB_B, true))
  222. {
  223. int n = model.nr_class * (model.nr_class - 1) / 2;
  224. model.probB = new double[n];
  225. String[] st = xmlin.ReadTextToTag().Split(',');
  226. for (int i = 0; i < n; i++)
  227. model.probB[i] = Double.Parse(st[i]);
  228. }
  229. else if (xmlin.IsIt(SVMNetworkPersistor.TAG_NSV, true))
  230. {
  231. int n = model.nr_class;
  232. model.nSV = new int[n];
  233. String[] st = xmlin.ReadTextToTag().Split(',');
  234. for (int i = 0; i < n; i++)
  235. model.nSV[i] = int.Parse(st[i]);
  236. }
  237. else if (xmlin.IsIt(SVMNetworkPersistor.TAG_TYPE_KERNEL, true))
  238. {
  239. int i = EngineArray.FindStringInArray(
  240. SVMNetworkPersistor.kernel_type_table, xmlin
  241. .ReadTextToTag());
  242. model.param.kernel_type = i;
  243. }
  244. else if (xmlin.IsIt(SVMNetworkPersistor.TAG_DATA, true))
  245. {
  246. HandleData(xmlin, model);
  247. }
  248. else if (xmlin.IsIt(SVMNetworkPersistor.TAG_MODEL, false))
  249. {
  250. break;
  251. }
  252. }
  253. }
  254. /// <summary>
  255. /// Load the data from a model.
  256. /// </summary>
  257. /// <param name="xmlin">Where to read the data from.</param>
  258. /// <param name="model">The model to load data into.</param>
  259. private void HandleData(ReadXML xmlin, svm_model model)
  260. {
  261. int i = 0;
  262. int m = model.nr_class - 1;
  263. int l = model.l;
  264. model.sv_coef = EngineArray.AllocateDouble2D(m, l);
  265. model.SV = new svm_node[l][];
  266. while (xmlin.ReadToTag())
  267. {
  268. if (xmlin.IsIt(SVMNetworkPersistor.TAG_ROW, true))
  269. {
  270. String line = xmlin.ReadTextToTag();
  271. String[] st = xmlin.ReadTextToTag().Split(',');
  272. for (int k = 0; k < m; k++)
  273. model.sv_coef[k][i] = Double.Parse(st[i]);
  274. int n = st.Length / 2;
  275. model.SV[i] = new svm_node[n];
  276. int idx = 0;
  277. for (int j = 0; j < n; j++)
  278. {
  279. model.SV[i][j] = new svm_node();
  280. model.SV[i][j].index = int.Parse(st[idx++]);
  281. model.SV[i][j].value_Renamed = Double.Parse(st[idx++]);
  282. }
  283. i++;
  284. }
  285. else if (xmlin.IsIt(SVMNetworkPersistor.TAG_DATA, false))
  286. {
  287. break;
  288. }
  289. }
  290. }
  291. /// <summary>
  292. /// Save a model.
  293. /// </summary>
  294. /// <param name="xmlout">Where to save a model to.</param>
  295. /// <param name="model">The model to save to.</param>
  296. public static void SaveModel(WriteXML xmlout, svm_model model)
  297. {
  298. if (model != null)
  299. {
  300. xmlout.BeginTag(SVMNetworkPersistor.TAG_MODEL);
  301. svm_parameter param = model.param;
  302. xmlout.AddProperty(SVMNetworkPersistor.TAG_TYPE_SVM,
  303. svm_type_table[param.svm_type]);
  304. xmlout.AddProperty(SVMNetworkPersistor.TAG_TYPE_KERNEL,
  305. kernel_type_table[param.kernel_type]);
  306. if (param.kernel_type == svm_parameter.POLY)
  307. {
  308. xmlout.AddProperty(SVMNetworkPersistor.TAG_DEGREE, param.degree);
  309. }
  310. if (param.kernel_type == svm_parameter.POLY
  311. || param.kernel_type == svm_parameter.RBF
  312. || param.kernel_type == svm_parameter.SIGMOID)
  313. {
  314. xmlout.AddProperty(SVMNetworkPersistor.TAG_GAMMA, param.gamma);
  315. }
  316. if (param.kernel_type == svm_parameter.POLY
  317. || param.kernel_type == svm_parameter.SIGMOID)
  318. {
  319. xmlout.AddProperty(SVMNetworkPersistor.TAG_COEF0, param.coef0);
  320. }
  321. int nr_class = model.nr_class;
  322. int l = model.l;
  323. xmlout.AddProperty(SVMNetworkPersistor.TAG_NUMCLASS, nr_class);
  324. xmlout.AddProperty(SVMNetworkPersistor.TAG_TOTALSV, l);
  325. xmlout.AddProperty(SVMNetworkPersistor.TAG_RHO, model.rho, nr_class
  326. * (nr_class - 1) / 2);
  327. xmlout.AddProperty(SVMNetworkPersistor.TAG_LABEL, model.label,
  328. nr_class);
  329. xmlout.AddProperty(SVMNetworkPersistor.TAG_PROB_A, model.probA,
  330. nr_class * (nr_class - 1) / 2);
  331. xmlout.AddProperty(SVMNetworkPersistor.TAG_PROB_B, model.probB,
  332. nr_class * (nr_class - 1) / 2);
  333. xmlout.AddProperty(SVMNetworkPersistor.TAG_NSV, model.nSV, nr_class);
  334. xmlout.BeginTag(SVMNetworkPersistor.TAG_DATA);
  335. double[][] sv_coef = model.sv_coef;
  336. svm_node[][] SV = model.SV;
  337. StringBuilder line = new StringBuilder();
  338. for (int i = 0; i < l; i++)
  339. {
  340. line.Length = 0;
  341. for (int j = 0; j < nr_class - 1; j++)
  342. line.Append(sv_coef[j][i] + " ");
  343. svm_node[] p = SV[i];
  344. //if (param.kernel_type == svm_parameter.PRECOMPUTED)
  345. //{
  346. // line.Append("0:" + (int) (p[0].value));
  347. //}
  348. //else
  349. for (int j = 0; j < p.Length; j++)
  350. line.Append(p[j].index + ":" + p[j].value_Renamed + " ");
  351. xmlout.AddProperty(SVMNetworkPersistor.TAG_ROW, line.ToString());
  352. }
  353. xmlout.EndTag();
  354. xmlout.EndTag();
  355. }
  356. }
  357. /// <summary>
  358. /// Save a SVMNetwork.
  359. /// </summary>
  360. /// <param name="obj">The object to save.</param>
  361. /// <param name="xmlout">Where to save it to.</param>
  362. public void Save(IEncogPersistedObject obj, WriteXML xmlout)
  363. {
  364. PersistorUtil.BeginEncogObject(EncogPersistedCollection.TYPE_SVM, xmlout,
  365. obj, true);
  366. SVMNetwork net = (SVMNetwork)obj;
  367. xmlout.AddProperty(SVMNetworkPersistor.TAG_INPUT, net.InputCount);
  368. xmlout.AddProperty(SVMNetworkPersistor.TAG_OUTPUT, net.OutputCount);
  369. xmlout.BeginTag(SVMNetworkPersistor.TAG_MODELS);
  370. for (int i = 0; i < net.Models.Length; i++)
  371. {
  372. SaveModel(xmlout, net.Models[i]);
  373. }
  374. xmlout.EndTag();
  375. xmlout.EndTag();
  376. }
  377. }
  378. }