/MST Parser/Parameters.cs

http://github.com/rasoolims/MSTParserCSharp · C# · 284 lines · 235 code · 49 blank · 0 comment · 33 complexity · 55783464554800fdd894a16219688be2 MD5 · raw file

  1. using System;
  2. using MSTParser.Extensions;
  3. namespace MSTParser
  4. {
  5. public class Parameters
  6. {
  7. public LossTypes LossType = LossTypes.Punc;
  8. public double[] parameters;
  9. public double[] Total;
  /// <summary>
  /// Creates a parameter set holding <paramref name="size"/> weights and an equally
  /// sized running-total vector, all starting at zero, with the loss type set to
  /// <c>LossTypes.Punc</c>.
  /// </summary>
  /// <param name="size">Number of entries allocated for both <c>parameters</c> and <c>Total</c>.</param>
  public Parameters(int size)
  {
      LossType = LossTypes.Punc;

      // Newly allocated double arrays are zero-filled by the runtime, which is
      // exactly the state the weights and totals must start in.
      parameters = new double[size];
      Total = new double[size];
  }
  /// <summary>Stores the loss type that <c>NumErrors</c> consults when counting errors.</summary>
  /// <param name="lt">Value written to <c>LossType</c>.</param>
  public void SetLoss(LossTypes lt)
  {
      this.LossType = lt;
  }
  /// <summary>
  /// Scales every entry of <c>Total</c> by <c>1.0 / avVal</c> (in place) and makes
  /// <c>parameters</c> hold those averaged values.
  /// </summary>
  /// <param name="avVal">
  /// Divisor for the accumulated totals. It is not validated: a value of zero
  /// produces infinities or NaN in the result.
  /// </param>
  /// <remarks>
  /// <c>parameters</c> receives a copy of <c>Total</c> rather than a reference to it.
  /// Sharing one array between the two fields would make any later call to
  /// <c>UpdateParamsMIRA</c> apply both of its increments to the same storage.
  /// </remarks>
  public void AverageParams(double avVal)
  {
      // Same factor the loop previously recomputed on every iteration.
      double scale = 1.0/avVal;
      for (int j = 0; j < Total.Length; j++)
      {
          Total[j] *= scale;
      }

      // A fresh array keeps the two vectors independent and leaves any array
      // previously referenced by `parameters` untouched, as before.
      parameters = (double[]) Total.Clone();
  }
  /// <summary>
  /// Performs one weight update from an instance and a list of candidate parses:
  /// builds one constraint per candidate, obtains a multiplier for each from
  /// <c>hildreth</c>, and adds the weighted difference vectors to <c>parameters</c>
  /// and (additionally scaled by <paramref name="upd"/>) to <c>Total</c>.
  /// </summary>
  /// <param name="inst">Instance supplying the reference parse string (<c>ActParseTree</c>) and its feature vector (<c>Fv</c>).</param>
  /// <param name="d">
  /// Candidate table. Rows are read from index 0 up to the first row whose column 0
  /// is null. Column 0 is cast to <c>FeatureVector</c>; column 1 is cast to
  /// <c>string</c> and passed to <c>NumErrors</c> as the predicted parse.
  /// </param>
  /// <param name="upd">Extra factor applied to each increment added to <c>Total</c>.</param>
  public void UpdateParamsMIRA(DependencyInstance inst, object[,] d, double upd)
  {
      string actParseTree = inst.ActParseTree;
      FeatureVector actFV = inst.Fv;

      // Number of leading rows that actually carry a candidate.
      int K = 0;
      while (K < d.GetLength(0) && d[K, 0] != null)
      {
          K++;
      }
      if (K == 0)
      {
          // No constraints: there is nothing to solve for and nothing to update.
          return;
      }

      // The score of the reference features does not change while the constraints
      // are being assembled, so it is computed once instead of once per candidate.
      double actScore = GetScore(actFV);

      var b = new double[K];
      var dist = new FeatureVector[K];
      for (int k = 0; k < K; k++)
      {
          var candidateFV = (FeatureVector) d[k, 0];
          double scoreGap = actScore - GetScore(candidateFV);

          // Right-hand side of constraint k: error count minus the current score gap.
          b[k] = NumErrors(inst, (string) d[k, 1], actParseTree) - scoreGap;
          dist[k] = FeatureVector.GetDistVector(actFV, candidateFV);
      }

      double[] alpha = hildreth(dist, b);

      for (int k = 0; k < K; k++)
      {
          double weight = alpha[k];
          double totalWeight = upd*weight;
          foreach (Feature feature in dist[k].FVector)
          {
              // Entries with a negative index are not backed by a weight slot.
              if (feature.Index < 0)
              {
                  continue;
              }
              parameters[feature.Index] += weight*feature.Value;
              Total[feature.Index] += totalWeight*feature.Value;
          }
      }
  }
  /// <summary>
  /// Returns the sum of <c>parameters[index] * value</c> over every entry of
  /// <paramref name="fv"/> whose index is non-negative.
  /// </summary>
  /// <param name="fv">Feature vector whose <c>FVector</c> entries are scored.</param>
  /// <returns>The accumulated score; 0.0 when no entry contributes.</returns>
  public double GetScore(FeatureVector fv)
  {
      double total = 0.0;
      foreach (Feature entry in fv.FVector)
      {
          // Negative indices have no weight slot and are skipped.
          if (entry.Index < 0)
          {
              continue;
          }
          total += parameters[entry.Index]*entry.Value;
      }
      return total;
  }
  /// <summary>
  /// Iteratively computes one multiplier per constraint from the constraint vectors
  /// <paramref name="a"/> and right-hand sides <paramref name="b"/>. Each step picks
  /// the constraint with the largest <c>kkt</c> value, moves its multiplier (never
  /// letting it drop below zero) and updates the remaining values accordingly.
  /// </summary>
  /// <param name="a">Constraint vectors; pairwise dot products are taken via <c>FeatureVector.DotProduct</c>.</param>
  /// <param name="b">Right-hand side for each constraint; determines the length of the result.</param>
  /// <returns>The multipliers, one per entry of <paramref name="b"/>.</returns>
  /// <remarks>
  /// Stops when the largest <c>kkt</c> value falls below 1e-8 or after 10000 iterations.
  /// </remarks>
  private double[] hildreth(FeatureVector[] a, double[] b)
  {
      const int iterationCap = 10000;
      const double stopThreshold = 0.00000001;
      const double nearZero = 0.000000000001;

      int count = b.Length;
      var alpha = new double[count];
      var f = new double[count];
      var kkt = new double[count];

      // Dot-product table; only the diagonal is filled up front, the other
      // columns are computed the first time their constraint is selected.
      int K = a.Length;
      var products = new double[K][];
      for (int row = 0; row < K; row++)
      {
          products[row] = new double[K];
      }
      var columnFilled = new bool[K];
      for (int row = 0; row < K; row++)
      {
          products[row][row] = FeatureVector.DotProduct(a[row], a[row]);
      }

      for (int row = 0; row < count; row++)
      {
          f[row] = b[row];
          kkt[row] = b[row];
      }

      int pick = IndexOfLargest(kkt);
      double pickValue = pick < 0 ? double.NegativeInfinity : kkt[pick];

      for (int iter = 0; pickValue >= stopThreshold && iter < iterationCap; iter++)
      {
          double diagonal = products[pick][pick];
          double step = diagonal <= nearZero ? 0.0 : f[pick]/diagonal;

          // If the full step would make the multiplier negative, move it to zero instead.
          double applied = (alpha[pick] + step < 0.0) ? -1.0*alpha[pick] : step;
          alpha[pick] = alpha[pick] + applied;

          if (!columnFilled[pick])
          {
              for (int row = 0; row < K; row++)
              {
                  products[row][pick] = FeatureVector.DotProduct(a[row], a[pick]);
              }
              columnFilled[pick] = true;
          }

          for (int row = 0; row < count; row++)
          {
              f[row] -= applied*products[row][pick];
              // Constraints with a (numerically) positive multiplier are ranked by magnitude.
              kkt[row] = alpha[row] > nearZero ? Math.Abs(f[row]) : f[row];
          }

          pick = IndexOfLargest(kkt);
          pickValue = pick < 0 ? double.NegativeInfinity : kkt[pick];
      }

      return alpha;
  }

  /// <summary>
  /// Returns the index of the first entry strictly greater than every earlier entry
  /// and than negative infinity, or -1 when no entry qualifies.
  /// </summary>
  private static int IndexOfLargest(double[] values)
  {
      int best = -1;
      double bestValue = double.NegativeInfinity;
      for (int i = 0; i < values.Length; i++)
      {
          if (values[i] > bestValue)
          {
              bestValue = values[i];
              best = i;
          }
      }
      return best;
  }
  /// <summary>
  /// Returns the combined error count of a predicted parse against the reference
  /// parse: the "Dep" count plus the "Label" count, using the punctuation-skipping
  /// variants when <c>LossType</c> is <c>LossTypes.NoPunc</c>.
  /// </summary>
  /// <param name="inst">Instance forwarded to the counting methods.</param>
  /// <param name="pred">Predicted parse string.</param>
  /// <param name="act">Reference parse string.</param>
  /// <returns>Sum of the two error counts.</returns>
  public double NumErrors(DependencyInstance inst, string pred, string act)
  {
      bool skipPunctuation = LossType == LossTypes.NoPunc;

      double depErrors = skipPunctuation
          ? NumErrorsDepNoPunc(inst, pred, act)
          : NumErrorsDep(inst, pred, act);
      double labelErrors = skipPunctuation
          ? NumErrorsLabelNoPunc(inst, pred, act)
          : NumErrorsLabel(inst, pred, act);

      return depErrors + labelErrors;
  }
  /// <summary>
  /// Counts the tokens of <paramref name="act"/> whose first ':'-separated field is
  /// not matched by the token at the same position in <paramref name="pred"/>.
  /// </summary>
  /// <param name="inst">Not used by this method.</param>
  /// <param name="pred">Predicted parse: space-separated tokens, each split on ':'.</param>
  /// <param name="act">Reference parse in the same format.</param>
  /// <returns>Number of reference tokens minus the number of positions whose first fields are equal.</returns>
  public double NumErrorsDep(DependencyInstance inst, string pred, string act)
  {
      string[] goldTokens = act.Split(' ');
      string[] guessTokens = pred.Split(' ');

      double errors = goldTokens.Length;
      for (int t = 0; t < guessTokens.Length; t++)
      {
          string guessField = guessTokens[t].Split(':')[0];
          string goldField = goldTokens[t].Split(':')[0];
          if (guessField == goldField)
          {
              // Every agreeing position removes one error from the total.
              errors -= 1.0;
          }
      }
      return errors;
  }
  /// <summary>
  /// Counts the tokens of <paramref name="act"/> whose second ':'-separated field is
  /// not matched by the token at the same position in <paramref name="pred"/>.
  /// </summary>
  /// <param name="inst">Not used by this method.</param>
  /// <param name="pred">Predicted parse: space-separated tokens, each split on ':'.</param>
  /// <param name="act">Reference parse in the same format.</param>
  /// <returns>Number of reference tokens minus the number of positions whose second fields are equal.</returns>
  public double NumErrorsLabel(DependencyInstance inst, string pred, string act)
  {
      string[] goldTokens = act.Split(' ');
      string[] guessTokens = pred.Split(' ');

      double errors = goldTokens.Length;
      for (int t = 0; t < guessTokens.Length; t++)
      {
          string guessField = guessTokens[t].Split(':')[1];
          string goldField = goldTokens[t].Split(':')[1];
          if (guessField == goldField)
          {
              // Every agreeing position removes one error from the total.
              errors -= 1.0;
          }
      }
      return errors;
  }
  /// <summary>
  /// Like <c>NumErrorsDep</c>, but positions whose tag in <c>inst.POS</c> matches the
  /// punctuation pattern are left out of the count entirely.
  /// </summary>
  /// <param name="inst">Instance whose <c>POS</c> array is consulted at index <c>i + 1</c> for token <c>i</c>.</param>
  /// <param name="pred">Predicted parse: space-separated tokens, each split on ':'.</param>
  /// <param name="act">Reference parse in the same format.</param>
  /// <returns>Reference token count minus skipped positions minus positions whose first fields are equal.</returns>
  public double NumErrorsDepNoPunc(DependencyInstance inst, string pred, string act)
  {
      const string punctuationPattern = @"[,:\.'`]+";

      string[] goldTokens = act.Split(' ');
      string[] guessTokens = pred.Split(' ');
      string[] tags = inst.POS;

      int agreeing = 0;
      int skipped = 0;
      for (int t = 0; t < guessTokens.Length; t++)
      {
          string guessField = guessTokens[t].Split(':')[0];
          string goldField = goldTokens[t].Split(':')[0];

          // The tag array is offset by one relative to the token positions.
          if (pos_IsPunctuation(tags[t + 1], punctuationPattern))
          {
              skipped++;
          }
          else if (guessField == goldField)
          {
              agreeing++;
          }
      }

      double errors = goldTokens.Length;
      errors -= skipped;
      errors -= agreeing;
      return errors;
  }

  /// <summary>Applies the <c>Matches</c> string extension to a tag with the given pattern.</summary>
  private static bool pos_IsPunctuation(string tag, string pattern)
  {
      return tag.Matches(pattern);
  }
  /// <summary>
  /// Like <c>NumErrorsLabel</c>, but positions whose tag in <c>inst.POS</c> matches the
  /// punctuation pattern are left out of the count entirely.
  /// </summary>
  /// <param name="inst">Instance whose <c>POS</c> array is consulted at index <c>i + 1</c> for token <c>i</c>.</param>
  /// <param name="pred">Predicted parse: space-separated tokens, each split on ':'.</param>
  /// <param name="act">Reference parse in the same format.</param>
  /// <returns>Reference token count minus skipped positions minus positions whose second fields are equal.</returns>
  public double NumErrorsLabelNoPunc(DependencyInstance inst, string pred, string act)
  {
      string[] goldTokens = act.Split(' ');
      string[] guessTokens = pred.Split(' ');
      string[] tags = inst.POS;

      int agreeing = 0;
      int skipped = 0;
      int t = 0;
      while (t < guessTokens.Length)
      {
          string guessField = guessTokens[t].Split(':')[1];
          string goldField = goldTokens[t].Split(':')[1];

          // The tag array is offset by one relative to the token positions.
          bool isPunctuation = tags[t + 1].Matches("[,:.'`]+");
          if (isPunctuation)
          {
              skipped++;
          }
          else if (guessField == goldField)
          {
              agreeing++;
          }
          t++;
      }

      double errors = goldTokens.Length;
      errors -= skipped;
      errors -= agreeing;
      return errors;
  }
  234. }
  235. }