/tests/milvus-java-test/src/main/java/com/Utils.java

https://github.com/milvus-io/milvus · Java · 242 lines · 212 code · 24 blank · 6 comment · 19 complexity · 2198207250a9d9c2811c0c6969b36794 MD5 · raw file

  1. package com;
  2. import com.alibaba.fastjson.JSON;
  3. import com.alibaba.fastjson.JSONArray;
  4. import io.milvus.client.*;
  5. import com.alibaba.fastjson.JSONObject;
  6. import org.apache.commons.lang3.RandomStringUtils;
  7. import java.nio.ByteBuffer;
  8. import java.util.*;
  9. import java.util.stream.Collectors;
  10. import java.util.stream.Stream;
  11. public class Utils {
  12. public static List<Float> normalize(List<Float> w2v){
  13. float squareSum = w2v.stream().map(x -> x * x).reduce((float) 0, Float::sum);
  14. final float norm = (float) Math.sqrt(squareSum);
  15. w2v = w2v.stream().map(x -> x / norm).collect(Collectors.toList());
  16. return w2v;
  17. }
  18. public static String genUniqueStr(String str_value){
  19. String prefix = "_"+RandomStringUtils.randomAlphabetic(10);
  20. String str = str_value == null || str_value.trim().isEmpty() ? "test" : str_value;
  21. return str.trim()+prefix;
  22. }
  23. public static List<List<Float>> genVectors(int vectorCount, int dimension, boolean norm) {
  24. List<List<Float>> vectors = new ArrayList<>();
  25. Random random = new Random();
  26. for (int i = 0; i < vectorCount; ++i) {
  27. List<Float> vector = new ArrayList<>();
  28. for (int j = 0; j < dimension; ++j) {
  29. vector.add(random.nextFloat());
  30. }
  31. if (norm == true) {
  32. vector = normalize(vector);
  33. }
  34. vectors.add(vector);
  35. }
  36. return vectors;
  37. }
  38. static List<ByteBuffer> genBinaryVectors(long vectorCount, long dimension) {
  39. Random random = new Random();
  40. List<ByteBuffer> vectors = new ArrayList<>();
  41. final long dimensionInByte = dimension / 8;
  42. for (long i = 0; i < vectorCount; ++i) {
  43. ByteBuffer byteBuffer = ByteBuffer.allocate((int) dimensionInByte);
  44. random.nextBytes(byteBuffer.array());
  45. vectors.add(byteBuffer);
  46. }
  47. return vectors;
  48. }
  49. private static List<Map<String, Object>> genBaseFieldsWithoutVector(){
  50. List<Map<String,Object>> fieldsList = new ArrayList<>();
  51. Map<String, Object> intFields = new HashMap<>();
  52. intFields.put("field","int64");
  53. intFields.put("type",DataType.INT64);
  54. Map<String, Object> floatField = new HashMap<>();
  55. floatField.put("field","float");
  56. floatField.put("type",DataType.FLOAT);
  57. fieldsList.add(intFields);
  58. fieldsList.add(floatField);
  59. return fieldsList;
  60. }
  61. public static List<Map<String, Object>> genDefaultFields(int dimension, boolean isBinary){
  62. List<Map<String, Object>> defaultFieldList = genBaseFieldsWithoutVector();
  63. Map<String, Object> vectorField = new HashMap<>();
  64. if (isBinary){
  65. vectorField.put("field","binary_vector");
  66. vectorField.put("type",DataType.VECTOR_BINARY);
  67. }else {
  68. vectorField.put("field","float_vector");
  69. vectorField.put("type",DataType.VECTOR_FLOAT);
  70. }
  71. JSONObject jsonObject = new JSONObject();
  72. jsonObject.put("dim", dimension);
  73. vectorField.put("params", jsonObject.toString());
  74. defaultFieldList.add(vectorField);
  75. return defaultFieldList;
  76. }
  77. public static List<Map<String,Object>> genDefaultEntities(int dimension, int vectorCount, List<List<Float>> vectors){
  78. List<Map<String,Object>> fieldsMap = genDefaultFields(dimension, false);
  79. List<Long> intValues = new ArrayList<>(vectorCount);
  80. List<Float> floatValues = new ArrayList<>(vectorCount);
  81. for (int i = 0; i < vectorCount; ++i) {
  82. intValues.add((long) i);
  83. floatValues.add((float) i);
  84. }
  85. for(Map<String,Object> field: fieldsMap){
  86. String fieldType = field.get("field").toString();
  87. switch (fieldType){
  88. case "int64":
  89. field.put("values",intValues);
  90. break;
  91. case "float":
  92. field.put("values",floatValues);
  93. break;
  94. case "float_vector":
  95. field.put("values",vectors);
  96. break;
  97. }
  98. }
  99. return fieldsMap;
  100. }
  101. public static List<Map<String,Object>> genDefaultBinaryEntities(int dimension, int vectorCount, List<ByteBuffer> vectorsBinary){
  102. List<Map<String,Object>> binaryFieldsMap = genDefaultFields(dimension, true);
  103. List<Long> intValues = new ArrayList<>(vectorCount);
  104. List<Float> floatValues = new ArrayList<>(vectorCount);
  105. // List<List<Float>> vectors = genVectors(vectorCount,dimension,false);
  106. // List<ByteBuffer> binaryVectors = genBinaryVectors(vectorCount,dimension);
  107. for (int i = 0; i < vectorCount; ++i) {
  108. intValues.add((long) i);
  109. floatValues.add((float) i);
  110. }
  111. for(Map<String,Object> field: binaryFieldsMap){
  112. String fieldType = field.get("field").toString();
  113. switch (fieldType){
  114. case "int64":
  115. field.put("values",intValues);
  116. break;
  117. case "float":
  118. field.put("values",floatValues);
  119. break;
  120. case "binary_vector":
  121. field.put("values",vectorsBinary);
  122. break;
  123. }
  124. }
  125. return binaryFieldsMap;
  126. }
  127. public static String setIndexParam(String indexType, String metricType, int nlist) {
  128. // ("{\"index_type\": \"IVF_SQ8\", \"metric_type\": \"L2\", \"\"params\": {\"nlist\": 2048}}")
  129. // JSONObject indexParam = new JSONObject();
  130. // indexParam.put("nlist", nlist);
  131. // return JSONObject.toJSONString(indexParam);
  132. String indexParams = String.format("{\"index_type\": %s, \"metric_type\": %s, \"params\": {\"nlist\": %s}}", indexType, metricType, nlist);
  133. return indexParams;
  134. }
  135. public static String setSearchParam(String metricType, List<List<Float>> queryVectors, int topk, int nprobe) {
  136. JSONObject searchParam = new JSONObject();
  137. JSONObject fieldParam = new JSONObject();
  138. fieldParam.put("topk", topk);
  139. fieldParam.put("metric_type", metricType);
  140. fieldParam.put("query", queryVectors);
  141. fieldParam.put("type", Constants.vectorType);
  142. JSONObject tmpSearchParam = new JSONObject();
  143. tmpSearchParam.put("nprobe", nprobe);
  144. fieldParam.put("params", tmpSearchParam);
  145. JSONObject vectorParams = new JSONObject();
  146. vectorParams.put(Constants.floatFieldName, fieldParam);
  147. searchParam.put("vector", vectorParams);
  148. JSONObject param = new JSONObject();
  149. JSONObject mustParam = new JSONObject();
  150. JSONArray tmp = new JSONArray();
  151. tmp.add(searchParam);
  152. mustParam.put("must", tmp);
  153. param.put("bool", mustParam);
  154. return JSONObject.toJSONString(param);
  155. }
  156. public static String setBinarySearchParam(String metricType, List<ByteBuffer> queryVectors, int topk, int nprobe) {
  157. JSONObject searchParam = new JSONObject();
  158. JSONObject fieldParam = new JSONObject();
  159. fieldParam.put("topk", topk);
  160. fieldParam.put("metricType", metricType);
  161. fieldParam.put("queryVectors", queryVectors);
  162. JSONObject tmpSearchParam = new JSONObject();
  163. tmpSearchParam.put("nprobe", nprobe);
  164. fieldParam.put("params", tmpSearchParam);
  165. JSONObject vectorParams = new JSONObject();
  166. vectorParams.put(Constants.floatFieldName, fieldParam);
  167. searchParam.put("vector", vectorParams);
  168. JSONObject boolParam = new JSONObject();
  169. JSONObject mustParam = new JSONObject();
  170. mustParam.put("must", new JSONArray().add(searchParam));
  171. boolParam.put("bool", mustParam);
  172. return JSONObject.toJSONString(searchParam);
  173. }
  174. public static int getIndexParamValue(String indexParam, String key) {
  175. return JSONObject.parseObject(indexParam).getIntValue(key);
  176. }
  177. public static JSONObject getCollectionInfo(String collectionInfo) {
  178. return JSONObject.parseObject(collectionInfo);
  179. }
  180. public static List<Long> toListIds(int id) {
  181. List<Long> ids = new ArrayList<>();
  182. ids.add((long)id);
  183. return ids;
  184. }
  185. public static List<Long> toListIds(long id) {
  186. List<Long> ids = new ArrayList<>();
  187. ids.add(id);
  188. return ids;
  189. }
  190. public static int getParam(String params, String key){
  191. JSONObject jsonObject = JSONObject.parseObject(params);
  192. System.out.println(jsonObject.toString());
  193. Integer value = jsonObject.getInteger(key);
  194. return value;
  195. }
  196. public static List<Float> getVector(List<Map<String,Object>> entities, int i){
  197. List<Float> vector = new ArrayList<>();
  198. entities.forEach(entity -> {
  199. if("float_vector".equals(entity.get("field")) && Objects.nonNull(entity.get("values"))){
  200. vector.add(((List<Float>)entity.get("values")).get(i));
  201. }
  202. });
  203. return vector;
  204. }
  205. public static JSONArray parseJsonArray(String message, String type) {
  206. JSONObject jsonObject = JSONObject.parseObject(message);
  207. JSONArray partitionsJsonArray = jsonObject.getJSONArray("partitions");
  208. if ("partitions".equals(type))
  209. return partitionsJsonArray;
  210. JSONArray segmentsJsonArray = ((JSONObject)partitionsJsonArray.get(0)).getJSONArray("segments");
  211. if ("segments".equals(type))
  212. return segmentsJsonArray;
  213. JSONArray filesJsonArray = ((JSONObject)segmentsJsonArray.get(0)).getJSONArray("files");
  214. if ("files".equals(type))
  215. return filesJsonArray;
  216. throw new RuntimeException("unsupported type");
  217. }
  218. }