/core/src/main/java/com/alibaba/alink/operator/common/clustering/kmeans/KMeansUtil.java

https://github.com/alibaba/Alink · Java · 311 lines · 200 code · 19 blank · 92 comment · 26 complexity · 4d3544fcd2c3273839d5f5f9dc828f38 MD5 · raw file

  1. package com.alibaba.alink.operator.common.clustering.kmeans;
  2. import com.alibaba.alink.common.linalg.*;
  3. import com.alibaba.alink.common.utils.JsonConverter;
  4. import com.alibaba.alink.common.utils.TableUtil;
  5. import com.alibaba.alink.operator.common.clustering.DistanceType;
  6. import com.alibaba.alink.operator.common.distance.ContinuousDistance;
  7. import com.alibaba.alink.operator.common.distance.FastDistance;
  8. import com.alibaba.alink.operator.common.distance.FastDistanceMatrixData;
  9. import com.alibaba.alink.operator.common.distance.FastDistanceVectorData;
  10. import com.alibaba.alink.params.shared.clustering.HasKMeansWithHaversineDistanceType;
  11. import org.apache.commons.math3.stat.StatUtils;
  12. import org.apache.flink.api.java.tuple.Tuple2;
  13. import org.apache.flink.ml.api.misc.param.Params;
  14. import org.apache.flink.types.Row;
  15. import org.apache.flink.util.Preconditions;
  16. import java.io.Serializable;
  17. import java.util.ArrayList;
  18. import java.util.Arrays;
  19. import java.util.List;
  20. /**
  21. * Common functions for KMeans.
  22. */
  23. public class KMeansUtil implements Serializable {
  24. /**
  25. * Build the FastDistanceMatrixData from a list of FastDistanceVectorData.
  26. *
  27. * @param vectors list of FastDistanceVectorData.
  28. * @param distance FastDistance.
  29. * @param vectorSize vectorSize.
  30. * @return FastDistanceMatrixData
  31. */
  32. public static FastDistanceMatrixData buildCentroidsMatrix(List<FastDistanceVectorData> vectors,
  33. FastDistance distance,
  34. int vectorSize) {
  35. DenseMatrix matrix = new DenseMatrix(vectorSize, vectors.size());
  36. for (int i = 0; i < vectors.size(); i++) {
  37. MatVecOp.appendVectorToMatrix(matrix, false, i, vectors.get(i).getVector());
  38. }
  39. FastDistanceMatrixData centroid = new FastDistanceMatrixData(matrix);
  40. distance.updateLabel(centroid);
  41. return centroid;
  42. }
  43. /**
  44. * Find the closest centroid from centroids for sample, and add the sample to sumMatrix.
  45. *
  46. * @param sample query sample.
  47. * @param sampleWeight sample weight.
  48. * @param centroids centroids.
  49. * @param vectorSize vectorsize.
  50. * @param sumMatrix the sumMatrix to be update.
  51. * @param k centroid number.
  52. * @param fastDistance distance.
  53. * @param distanceMatrix preallocated distance result matrix.
  54. * @return the closest cluster index.
  55. */
  56. public static int updateSumMatrix(FastDistanceVectorData sample,
  57. long sampleWeight,
  58. FastDistanceMatrixData centroids,
  59. int vectorSize,
  60. double[] sumMatrix,
  61. int k,
  62. FastDistance fastDistance,
  63. DenseMatrix distanceMatrix) {
  64. Preconditions.checkNotNull(sumMatrix);
  65. Preconditions.checkNotNull(distanceMatrix);
  66. Preconditions.checkArgument(distanceMatrix.numRows() == centroids.getVectors().numCols() &&
  67. distanceMatrix.numCols() == 1, "Memory not preallocated!");
  68. fastDistance.calc(sample, centroids, distanceMatrix);
  69. int clusterIndex = getClosestClusterIndex(sample, centroids, k, fastDistance, distanceMatrix).f0;
  70. int startIndex = clusterIndex * (vectorSize + 1);
  71. Vector vec = sample.getVector();
  72. if (vec instanceof DenseVector) {
  73. BLAS.axpy(vectorSize, sampleWeight, ((DenseVector)vec).getData(), 0, sumMatrix, startIndex);
  74. } else {
  75. SparseVector sparseVector = (SparseVector)vec;
  76. sparseVector.forEach((index, value) -> sumMatrix[startIndex + index] += sampleWeight * value);
  77. }
  78. sumMatrix[startIndex + vectorSize] += sampleWeight;
  79. return clusterIndex;
  80. }
  81. /**
  82. * Find the closest cluster index.
  83. *
  84. * @param sample query sample.
  85. * @param centroids centroids.
  86. * @param k cluster number.
  87. * @param distance FastDistance.
  88. * @param distanceMatrix Preallocated distance matrix.
  89. * @return the closest cluster index and distance.
  90. */
  91. public static Tuple2<Integer, Double> getClosestClusterIndex(FastDistanceVectorData sample,
  92. FastDistanceMatrixData centroids,
  93. int k,
  94. FastDistance distance,
  95. DenseMatrix distanceMatrix) {
  96. getClusterDistances(sample, centroids, distance, distanceMatrix);
  97. double[] data = distanceMatrix.getData();
  98. int index = getMinPointIndex(data, k);
  99. return Tuple2.of(index, data[index]);
  100. }
  101. /**
  102. * Find the distances from the centroids.
  103. * @param sample query sample.
  104. * @param centroids centroids.
  105. * @param distance FastDistance.
  106. * @param distanceMatrix Preallocated distance matrix.
  107. * @return the distance array.
  108. */
  109. public static double[] getClusterDistances(FastDistanceVectorData sample,
  110. FastDistanceMatrixData centroids,
  111. FastDistance distance,
  112. DenseMatrix distanceMatrix) {
  113. Preconditions.checkNotNull(distanceMatrix);
  114. Preconditions.checkArgument(distanceMatrix.numRows() == centroids.getVectors().numCols() &&
  115. distanceMatrix.numCols() == 1, "Memory not preallocated!");
  116. distance.calc(sample, centroids, distanceMatrix);
  117. return distanceMatrix.getData();
  118. }
  119. /**
  120. * Find the closest cluster index.
  121. *
  122. * @param trainModelData trainModel
  123. * @param sample query sample
  124. * @param distance ContinuousDistance
  125. * @return the index and distance.
  126. */
  127. public static Tuple2<Integer, Double> getClosestClusterIndex(KMeansTrainModelData trainModelData,
  128. Vector sample,
  129. ContinuousDistance distance) {
  130. double[] distances = getClusterDistances(trainModelData, sample, distance);
  131. int index = getMinPointIndex(distances, trainModelData.params.k);
  132. return Tuple2.of(index, distances[index]);
  133. }
  134. /**
  135. * Find the distances from the centroids.
  136. *
  137. * @param trainModelData trainModel
  138. * @param sample query sample
  139. * @param distance ContinuousDistance
  140. * @return the distance array.
  141. */
  142. public static double[] getClusterDistances(KMeansTrainModelData trainModelData,
  143. Vector sample,
  144. ContinuousDistance distance) {
  145. double[] res = new double[trainModelData.params.k];
  146. for(int i = 0; i < res.length; i++){
  147. res[i] = distance.calc(trainModelData.getClusterVector(i), sample);
  148. }
  149. return res;
  150. }
  151. public static int getMinPointIndex(double[] data, int endIndex){
  152. Preconditions.checkArgument(endIndex <= data.length, "End index must be less than data length!");
  153. int index = -1;
  154. double min = Double.MAX_VALUE;
  155. for (int i = 0; i < endIndex; i++) {
  156. if (data[i] < min) {
  157. index = i;
  158. min = data[i];
  159. }
  160. }
  161. return index;
  162. }
  163. /**
  164. * Get the selected columns indexes from the input columns. Support vector input or latitudeCol and longtitude
  165. * inputs.
  166. *
  167. * @param params ParamSummary.
  168. * @param dataCols input columns.
  169. * @return selected columns indexes.
  170. */
  171. public static int[] getKmeansPredictColIdxs(KMeansTrainModelData.ParamSummary params, String[] dataCols) {
  172. Preconditions.checkArgument((null == params.longtitudeColName) == (null == params.latitudeColName),
  173. "Model Format error!");
  174. Preconditions.checkArgument(params.distanceType.equals(HasKMeansWithHaversineDistanceType.DistanceType.HAVERSINE) == (null == params.vectorColName
  175. && null != params.longtitudeColName),
  176. "Model Format error!");
  177. int[] colIdxs;
  178. if (null != params.vectorColName) {
  179. colIdxs = new int[1];
  180. colIdxs[0] = TableUtil.findColIndexWithAssert(dataCols, params.vectorColName);
  181. } else {
  182. colIdxs = new int[2];
  183. colIdxs[0] = TableUtil.findColIndexWithAssert(dataCols, params.latitudeColName);
  184. colIdxs[1] = TableUtil.findColIndexWithAssert(dataCols, params.longtitudeColName);
  185. }
  186. return colIdxs;
  187. }
  188. /**
  189. * Extract the vector from Row.
  190. *
  191. * @param colIdxs selected column indices.
  192. * @param row Row.
  193. * @return the vector.
  194. */
  195. public static Vector getKMeansPredictVector(int[] colIdxs, Row row) {
  196. Vector vec;
  197. if (colIdxs.length > 1) {
  198. vec = new DenseVector(2);
  199. vec.set(0, ((Number)row.getField(colIdxs[0])).doubleValue());
  200. vec.set(1, ((Number)row.getField(colIdxs[1])).doubleValue());
  201. } else {
  202. vec = VectorUtil.getVector(row.getField(colIdxs[0]));
  203. }
  204. return vec;
  205. }
  206. /**
  207. * Transform KMeansPredictModelData to KMeansTrainModelData.
  208. *
  209. * @param predictModelData KMeansPredictModelData.
  210. * @return KMeansTrainModelData.
  211. */
  212. public static KMeansTrainModelData transformPredictDataToTrainData(KMeansPredictModelData predictModelData) {
  213. KMeansTrainModelData modelData = new KMeansTrainModelData();
  214. modelData.params = predictModelData.params;
  215. modelData.centroids = new ArrayList<>();
  216. for (int i = 0; i < predictModelData.params.k; i++) {
  217. KMeansTrainModelData.ClusterSummary clusterSummary = new KMeansTrainModelData.ClusterSummary(
  218. predictModelData.getClusterVector(i),
  219. predictModelData.getClusterId(i),
  220. predictModelData.getClusterWeight(i));
  221. modelData.centroids.add(clusterSummary);
  222. }
  223. return modelData;
  224. }
  225. /**
  226. * Transform KMeansTrainModelData to KMeansPredictModelData.
  227. *
  228. * @param trainModelData KMeansTrainModelData.
  229. * @return KMeansPredictModelData.
  230. */
  231. public static KMeansPredictModelData transformTrainDataToPredictData(KMeansTrainModelData trainModelData) {
  232. KMeansPredictModelData modelData = new KMeansPredictModelData();
  233. modelData.params = trainModelData.params;
  234. DenseMatrix denseMatrix = new DenseMatrix(trainModelData.params.vectorSize, trainModelData.params.k);
  235. Row[] rows = new Row[trainModelData.params.k];
  236. int index = 0;
  237. for (int i = 0; i < trainModelData.centroids.size(); i++) {
  238. MatVecOp.appendVectorToMatrix(denseMatrix, false, index, trainModelData.getClusterVector(i));
  239. rows[index] = Row.of(trainModelData.getClusterId(i), trainModelData.getClusterWeight(i));
  240. index++;
  241. }
  242. modelData.centroids = new FastDistanceMatrixData(denseMatrix, rows);
  243. (modelData.params.distanceType.getFastDistance()).updateLabel(modelData.centroids);
  244. return modelData;
  245. }
  246. public static double[] getProbArrayFromDistanceArray(double[] distances){
  247. double sum = StatUtils.sum(distances);
  248. double ratio = 1.0 / sum / (distances.length - 1);
  249. double[] probs = new double[distances.length];
  250. Arrays.fill(probs, 1.0 / (distances.length - 1));
  251. BLAS.axpy(-ratio, distances, probs);
  252. return probs;
  253. }
  254. /**
  255. * Load KMeansTrainModelData from saved model.
  256. *
  257. * @param params saved params.
  258. * @param data saved data.
  259. * @return KMeansTrainModelData.
  260. */
  261. public static KMeansTrainModelData loadModelForTrain(Params params, Iterable<String> data) {
  262. KMeansTrainModelData trainModelData = new KMeansTrainModelData();
  263. trainModelData.params = new KMeansTrainModelData.ParamSummary(params);
  264. trainModelData.centroids = new ArrayList<>(trainModelData.params.k);
  265. data.forEach(s -> {
  266. try {
  267. trainModelData.centroids.add(JsonConverter.fromJson(s, KMeansTrainModelData.ClusterSummary.class));
  268. } catch (Exception e) {
  269. OldClusterSummary oldClusterSummary = JsonConverter.fromJson(s, OldClusterSummary.class);
  270. DenseVector vec;
  271. if (oldClusterSummary.center.contains("data")) {
  272. vec = JsonConverter.fromJson(oldClusterSummary.center, DenseVector.class);
  273. } else {
  274. vec = new DenseVector(JsonConverter.fromJson(oldClusterSummary.center, double[].class));
  275. }
  276. KMeansTrainModelData.ClusterSummary clusterSummary = new KMeansTrainModelData.ClusterSummary(
  277. vec,
  278. oldClusterSummary.clusterId,
  279. oldClusterSummary.weight
  280. );
  281. trainModelData.centroids.add(clusterSummary);
  282. }
  283. });
  284. return trainModelData;
  285. }
  286. static class OldClusterSummary implements Serializable {
  287. public long clusterId;
  288. public double weight;
  289. public String center;
  290. public DenseVector vec;
  291. }
  292. }