|
@@ -0,0 +1,313 @@
|
|
|
|
|
+package com.fs.ai.rag.service.impl;
|
|
|
|
|
+
|
|
|
|
|
+import cn.hutool.http.HttpRequest;
|
|
|
|
|
+import cn.hutool.http.HttpResponse;
|
|
|
|
|
+import cn.hutool.http.Method;
|
|
|
|
|
+import cn.hutool.json.JSONUtil;
|
|
|
|
|
+import com.fs.ai.rag.dto.*;
|
|
|
|
|
+import com.fs.ai.rag.service.QdrantService;
|
|
|
|
|
+import com.fs.framework.config.properties.QdrantProperties;
|
|
|
|
|
+import lombok.extern.slf4j.Slf4j;
|
|
|
|
|
+import org.apache.commons.lang3.StringUtils;
|
|
|
|
|
+import org.springframework.stereotype.Service;
|
|
|
|
|
+
|
|
|
|
|
+import java.util.ArrayList;
|
|
|
|
|
+import java.util.LinkedHashMap;
|
|
|
|
|
+import java.util.List;
|
|
|
|
|
+import java.util.Map;
|
|
|
|
|
+
|
|
|
|
|
+@Slf4j
|
|
|
|
|
+@Service
|
|
|
|
|
+public class QdrantServiceImpl implements QdrantService {
|
|
|
|
|
+
|
|
|
|
|
+ private final QdrantProperties qdrantProperties;
|
|
|
|
|
+
|
|
|
|
|
+ public QdrantServiceImpl(QdrantProperties qdrantProperties) {
|
|
|
|
|
+ this.qdrantProperties = qdrantProperties;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ @Override
|
|
|
|
|
+ public void createCollection(QdrantCollectionReq req) {
|
|
|
|
|
+ String collectionName = requireCollectionName(req.getCollectionName());
|
|
|
|
|
+ Integer vectorSize = req.getVectorSize() != null ? req.getVectorSize() : qdrantProperties.getVectorSize();
|
|
|
|
|
+ try {
|
|
|
|
|
+ if (collectionExists(collectionName)) {
|
|
|
|
|
+ log.info("Qdrant collection 已存在, 跳过创建, collectionName={}", collectionName);
|
|
|
|
|
+ return;
|
|
|
|
|
+ }
|
|
|
|
|
+ Map<String, Object> body = new LinkedHashMap<>();
|
|
|
|
|
+ Map<String, Object> vectors = new LinkedHashMap<>();
|
|
|
|
|
+ vectors.put("size", vectorSize);
|
|
|
|
|
+ vectors.put("distance", "Cosine");
|
|
|
|
|
+ body.put("vectors", vectors);
|
|
|
|
|
+ exchange(collectionUrl(collectionName), Method.PUT, body);
|
|
|
|
|
+ } catch (Exception e) {
|
|
|
|
|
+ throw new IllegalStateException("创建 Qdrant collection 失败: " + collectionName, e);
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ @Override
|
|
|
|
|
+ public void deleteCollection(String collectionName) {
|
|
|
|
|
+ String requiredName = requireCollectionName(collectionName);
|
|
|
|
|
+ try {
|
|
|
|
|
+ exchange(collectionUrl(requiredName), Method.DELETE, null);
|
|
|
|
|
+ } catch (Exception e) {
|
|
|
|
|
+ throw new IllegalStateException("删除 Qdrant collection 失败: " + requiredName, e);
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ @Override
|
|
|
|
|
+ public void upsertPoints(QdrantPointUpsertReq req) {
|
|
|
|
|
+ validateUpsertReq(req);
|
|
|
|
|
+ createCollectionIfAbsent(req.getCollectionName(), resolveVectorSize(req.getVectors()));
|
|
|
|
|
+ List<Map<String, Object>> points = new ArrayList<>();
|
|
|
|
|
+ for (int i = 0; i < req.getIds().size(); i++) {
|
|
|
|
|
+ Map<String, Object> point = new LinkedHashMap<>();
|
|
|
|
|
+ point.put("id", req.getIds().get(i));
|
|
|
|
|
+ point.put("vector", req.getVectors().get(i));
|
|
|
|
|
+ Map<String, Object> payload = buildPayload(req, i);
|
|
|
|
|
+ if (!payload.isEmpty()) {
|
|
|
|
|
+ point.put("payload", payload);
|
|
|
|
|
+ }
|
|
|
|
|
+ points.add(point);
|
|
|
|
|
+ }
|
|
|
|
|
+ Map<String, Object> body = new LinkedHashMap<>();
|
|
|
|
|
+ body.put("points", points);
|
|
|
|
|
+ body.put("wait", true);
|
|
|
|
|
+ try {
|
|
|
|
|
+ exchange(pointsUrl(req.getCollectionName()), Method.PUT, body);
|
|
|
|
|
+ } catch (Exception e) {
|
|
|
|
|
+ throw new IllegalStateException("Qdrant upsert 失败: " + req.getCollectionName(), e);
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ @Override
|
|
|
|
|
+ public void deletePoints(QdrantPointDeleteReq req) {
|
|
|
|
|
+ if (req == null || StringUtils.isBlank(req.getCollectionName()) || req.getIds() == null || req.getIds().isEmpty()) {
|
|
|
|
|
+ throw new IllegalArgumentException("collectionName 和 ids 不能为空");
|
|
|
|
|
+ }
|
|
|
|
|
+ Map<String, Object> body = new LinkedHashMap<>();
|
|
|
|
|
+ body.put("points", req.getIds());
|
|
|
|
|
+ body.put("wait", true);
|
|
|
|
|
+ try {
|
|
|
|
|
+ exchange(pointsDeleteUrl(req.getCollectionName()), Method.POST, body);
|
|
|
|
|
+ } catch (Exception e) {
|
|
|
|
|
+ throw new IllegalStateException("Qdrant delete 失败: " + req.getCollectionName(), e);
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ @Override
|
|
|
|
|
+ public Map<String, Object> getPoint(QdrantPointGetReq req) {
|
|
|
|
|
+ if (req == null || StringUtils.isBlank(req.getCollectionName()) || req.getId() == null) {
|
|
|
|
|
+ throw new IllegalArgumentException("collectionName 和 id 不能为空");
|
|
|
|
|
+ }
|
|
|
|
|
+ try {
|
|
|
|
|
+ Map<String, Object> resp = exchange(pointUrl(req.getCollectionName(), req.getId()), Method.GET, null);
|
|
|
|
|
+ Object result = resp.get("result");
|
|
|
|
|
+ if (!(result instanceof Map)) {
|
|
|
|
|
+ return new LinkedHashMap<>();
|
|
|
|
|
+ }
|
|
|
|
|
+ return toPointMap((Map<String, Object>) result);
|
|
|
|
|
+ } catch (IllegalStateException e) {
|
|
|
|
|
+ if (is404(e)) {
|
|
|
|
|
+ return new LinkedHashMap<>();
|
|
|
|
|
+ }
|
|
|
|
|
+ throw e;
|
|
|
|
|
+ } catch (Exception e) {
|
|
|
|
|
+ throw new IllegalStateException("Qdrant get point 失败: " + req.getCollectionName(), e);
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ @Override
|
|
|
|
|
+ public List<Map<String, Object>> search(QdrantPointSearchReq req) {
|
|
|
|
|
+ if (req == null || StringUtils.isBlank(req.getCollectionName()) || req.getVector() == null || req.getVector().isEmpty()) {
|
|
|
|
|
+ throw new IllegalArgumentException("collectionName 和 vector 不能为空");
|
|
|
|
|
+ }
|
|
|
|
|
+ Map<String, Object> body = new LinkedHashMap<>();
|
|
|
|
|
+ body.put("vector", req.getVector());
|
|
|
|
|
+ body.put("limit", req.getTopK() == null ? 5 : req.getTopK());
|
|
|
|
|
+ body.put("with_payload", true);
|
|
|
|
|
+ body.put("with_vector", true);
|
|
|
|
|
+ if (req.getFilter() != null && !req.getFilter().isEmpty()) {
|
|
|
|
|
+ body.put("filter", buildFilter(req.getFilter()));
|
|
|
|
|
+ }
|
|
|
|
|
+ try {
|
|
|
|
|
+ Map<String, Object> resp = exchange(searchUrl(req.getCollectionName()), Method.POST, body);
|
|
|
|
|
+ Object result = resp.get("result");
|
|
|
|
|
+ if (!(result instanceof List)) {
|
|
|
|
|
+ return new ArrayList<>();
|
|
|
|
|
+ }
|
|
|
|
|
+ List<Map<String, Object>> list = new ArrayList<>();
|
|
|
|
|
+ for (Object item : (List<?>) result) {
|
|
|
|
|
+ if (!(item instanceof Map)) {
|
|
|
|
|
+ continue;
|
|
|
|
|
+ }
|
|
|
|
|
+ Map<String, Object> source = (Map<String, Object>) item;
|
|
|
|
|
+ Map<String, Object> map = new LinkedHashMap<>();
|
|
|
|
|
+ map.put("id", source.get("id"));
|
|
|
|
|
+ map.put("score", source.get("score"));
|
|
|
|
|
+ Object payload = source.get("payload");
|
|
|
|
|
+ map.put("payload", payload instanceof Map ? payload : new LinkedHashMap<>());
|
|
|
|
|
+ list.add(map);
|
|
|
|
|
+ }
|
|
|
|
|
+ return list;
|
|
|
|
|
+ } catch (Exception e) {
|
|
|
|
|
+ throw new IllegalStateException("Qdrant search 失败: " + req.getCollectionName(), e);
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ private void createCollectionIfAbsent(String collectionName, int vectorSize) {
|
|
|
|
|
+ QdrantCollectionReq req = new QdrantCollectionReq();
|
|
|
|
|
+ req.setCollectionName(collectionName);
|
|
|
|
|
+ req.setVectorSize(vectorSize);
|
|
|
|
|
+ createCollection(req);
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ private boolean collectionExists(String collectionName) {
|
|
|
|
|
+ try {
|
|
|
|
|
+ Map<String, Object> resp = exchange(collectionUrl(collectionName), Method.GET, null);
|
|
|
|
|
+ return resp.get("result") != null || resp.get("status") != null;
|
|
|
|
|
+ } catch (IllegalStateException e) {
|
|
|
|
|
+ if (is404(e)) {
|
|
|
|
|
+ return false;
|
|
|
|
|
+ }
|
|
|
|
|
+ throw e;
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ private boolean is404(IllegalStateException e) {
|
|
|
|
|
+ return e.getMessage() != null && e.getMessage().contains("HTTP 404");
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ private String requireCollectionName(String collectionName) {
|
|
|
|
|
+ if (StringUtils.isBlank(collectionName)) {
|
|
|
|
|
+ throw new IllegalArgumentException("collectionName 不能为空");
|
|
|
|
|
+ }
|
|
|
|
|
+ return collectionName.trim();
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ private void validateUpsertReq(QdrantPointUpsertReq req) {
|
|
|
|
|
+ if (req == null || StringUtils.isBlank(req.getCollectionName())) {
|
|
|
|
|
+ throw new IllegalArgumentException("collectionName 不能为空");
|
|
|
|
|
+ }
|
|
|
|
|
+ if (req.getIds() == null || req.getIds().isEmpty()) {
|
|
|
|
|
+ throw new IllegalArgumentException("ids 不能为空");
|
|
|
|
|
+ }
|
|
|
|
|
+ if (req.getVectors() == null || req.getVectors().isEmpty()) {
|
|
|
|
|
+ throw new IllegalArgumentException("vectors 不能为空");
|
|
|
|
|
+ }
|
|
|
|
|
+ if (req.getIds().size() != req.getVectors().size()) {
|
|
|
|
|
+ throw new IllegalArgumentException("ids 和 vectors 数量不一致");
|
|
|
|
|
+ }
|
|
|
|
|
+ if (req.getDocuments() != null && !req.getDocuments().isEmpty() && req.getDocuments().size() != req.getIds().size()) {
|
|
|
|
|
+ throw new IllegalArgumentException("documents 和 ids 数量不一致");
|
|
|
|
|
+ }
|
|
|
|
|
+ if (req.getPayloads() != null && !req.getPayloads().isEmpty() && req.getPayloads().size() != req.getIds().size()) {
|
|
|
|
|
+ throw new IllegalArgumentException("payloads 和 ids 数量不一致");
|
|
|
|
|
+ }
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ private int resolveVectorSize(List<List<Float>> vectors) {
|
|
|
|
|
+ if (vectors == null || vectors.isEmpty() || vectors.get(0) == null) {
|
|
|
|
|
+ return qdrantProperties.getVectorSize();
|
|
|
|
|
+ }
|
|
|
|
|
+ return vectors.get(0).size();
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ private Map<String, Object> buildPayload(QdrantPointUpsertReq req, int index) {
|
|
|
|
|
+ Map<String, Object> payload = new LinkedHashMap<>();
|
|
|
|
|
+ if (req.getDocuments() != null && req.getDocuments().size() > index) {
|
|
|
|
|
+ payload.put("document", req.getDocuments().get(index));
|
|
|
|
|
+ }
|
|
|
|
|
+ if (req.getPayloads() != null && req.getPayloads().size() > index && req.getPayloads().get(index) != null) {
|
|
|
|
|
+ payload.putAll(req.getPayloads().get(index));
|
|
|
|
|
+ }
|
|
|
|
|
+ return payload;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ private Map<String, Object> buildFilter(Map<String, Object> filterMap) {
|
|
|
|
|
+ List<Map<String, Object>> must = new ArrayList<>();
|
|
|
|
|
+ for (Map.Entry<String, Object> entry : filterMap.entrySet()) {
|
|
|
|
|
+ if (entry.getValue() == null) {
|
|
|
|
|
+ continue;
|
|
|
|
|
+ }
|
|
|
|
|
+ Map<String, Object> match = new LinkedHashMap<>();
|
|
|
|
|
+ match.put("value", entry.getValue());
|
|
|
|
|
+ Map<String, Object> item = new LinkedHashMap<>();
|
|
|
|
|
+ item.put("key", entry.getKey());
|
|
|
|
|
+ item.put("match", match);
|
|
|
|
|
+ must.add(item);
|
|
|
|
|
+ }
|
|
|
|
|
+ Map<String, Object> filter = new LinkedHashMap<>();
|
|
|
|
|
+ filter.put("must", must);
|
|
|
|
|
+ return filter;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ private Map<String, Object> toPointMap(Map<String, Object> result) {
|
|
|
|
|
+ Map<String, Object> map = new LinkedHashMap<>();
|
|
|
|
|
+ map.put("id", result.get("id"));
|
|
|
|
|
+ Object payload = result.get("payload");
|
|
|
|
|
+ map.put("payload", payload instanceof Map ? payload : new LinkedHashMap<>());
|
|
|
|
|
+ map.put("vectors", result.get("vector"));
|
|
|
|
|
+ return map;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ @SuppressWarnings("unchecked")
|
|
|
|
|
+ private Map<String, Object> exchange(String url, Method method, Object body) {
|
|
|
|
|
+ String reqBody = body == null ? null : JSONUtil.toJsonStr(body);
|
|
|
|
|
+ HttpRequest request = new HttpRequest(url).method(method).header("Content-Type", "application/json");
|
|
|
|
|
+ if (StringUtils.isNotBlank(qdrantProperties.getApiKey())) {
|
|
|
|
|
+ request.header("api-key", qdrantProperties.getApiKey());
|
|
|
|
|
+ }
|
|
|
|
|
+ if (reqBody != null) {
|
|
|
|
|
+ request.body(reqBody);
|
|
|
|
|
+ }
|
|
|
|
|
+ HttpResponse response = request.execute();
|
|
|
|
|
+ String respBody = response.body();
|
|
|
|
|
+ log.info("Qdrant HTTP 调用完成, method={}, url={}, status={}, req={}, resp={}",
|
|
|
|
|
+ method, url, response.getStatus(), truncate(reqBody), truncate(respBody));
|
|
|
|
|
+ if (response.getStatus() < 200 || response.getStatus() >= 300) {
|
|
|
|
|
+ throw new IllegalStateException("Qdrant HTTP " + response.getStatus() + " 调用失败: " + url + ", body=" + respBody);
|
|
|
|
|
+ }
|
|
|
|
|
+ if (StringUtils.isBlank(respBody)) {
|
|
|
|
|
+ return new LinkedHashMap<>();
|
|
|
|
|
+ }
|
|
|
|
|
+ Object parsed = JSONUtil.toBean(respBody, Map.class);
|
|
|
|
|
+ return parsed instanceof Map ? (Map<String, Object>) parsed : new LinkedHashMap<>();
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ private String truncate(String text) {
|
|
|
|
|
+ if (text == null) {
|
|
|
|
|
+ return null;
|
|
|
|
|
+ }
|
|
|
|
|
+ return text.length() > 2000 ? text.substring(0, 2000) + "..." : text;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ private String baseUrl() {
|
|
|
|
|
+ String baseUrl = StringUtils.trimToEmpty(qdrantProperties.getBaseUrl());
|
|
|
|
|
+ while (baseUrl.endsWith("/")) {
|
|
|
|
|
+ baseUrl = baseUrl.substring(0, baseUrl.length() - 1);
|
|
|
|
|
+ }
|
|
|
|
|
+ return baseUrl;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ private String collectionUrl(String collectionName) {
|
|
|
|
|
+ return baseUrl() + "/collections/" + collectionName;
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ private String pointsUrl(String collectionName) {
|
|
|
|
|
+ return collectionUrl(collectionName) + "/points";
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ private String pointsDeleteUrl(String collectionName) {
|
|
|
|
|
+ return pointsUrl(collectionName) + "/delete";
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ private String pointUrl(String collectionName, Long id) {
|
|
|
|
|
+ return pointsUrl(collectionName) + "/" + id + "?with_payload=true&with_vector=true";
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
|
|
+ private String searchUrl(String collectionName) {
|
|
|
|
|
+ return pointsUrl(collectionName) + "/search";
|
|
|
|
|
+ }
|
|
|
|
|
+}
|