Bladeren bron

chroma 模块

ct 1 dag geleden
bovenliggende
commit
0df0fa0030

+ 18 - 0
fs-ai-api/src/main/java/com/fs/ai/rag/KnowledgeVectorService.java

@@ -3,8 +3,14 @@ package com.fs.ai.rag;
 import com.fs.ai.rag.dto.DeleteReq;
 import com.fs.ai.rag.dto.CreateDatabaseReq;
 import com.fs.ai.rag.dto.CreateTenantReq;
+import com.fs.ai.rag.dto.CollectionCreateReq;
+import com.fs.ai.rag.dto.CollectionDeleteReq;
+import com.fs.ai.rag.dto.CollectionListReq;
 import com.fs.ai.rag.dto.IndexReq;
 import com.fs.ai.rag.dto.QueryReq;
+import com.fs.ai.rag.dto.RecordDeleteReq;
+import com.fs.ai.rag.dto.RecordQueryReq;
+import com.fs.ai.rag.dto.RecordUpsertReq;
 
 import java.util.List;
 import java.util.Map;
@@ -19,4 +25,16 @@ public interface KnowledgeVectorService {
     Map<String, Object> createTenant(CreateTenantReq req);
 
     Map<String, Object> createDatabase(CreateDatabaseReq req);
+
+    Map<String, Object> createCollection(CollectionCreateReq req);
+
+    Object listCollections(CollectionListReq req);
+
+    void deleteCollection(CollectionDeleteReq req);
+
+    Map<String, Object> upsertRecords(RecordUpsertReq req);
+
+    Map<String, Object> queryRecords(RecordQueryReq req);
+
+    Map<String, Object> deleteRecords(RecordDeleteReq req);
 }

+ 41 - 0
fs-ai-api/src/main/java/com/fs/ai/rag/controller/AiRagController.java

@@ -58,6 +58,43 @@ public class AiRagController {
         return R.ok().put("data", knowledgeVectorService.createDatabase(req));
     }
 
+    @ApiOperation("创建集合(可指定 tenant_id / database)")
+    @PostMapping("/collection/create")
+    public R createCollection(@RequestBody CollectionCreateReq req) {
+        return R.ok().put("data", knowledgeVectorService.createCollection(req));
+    }
+
+    @ApiOperation("集合列表(可指定 tenant_id / database)")
+    @PostMapping("/collection/list")
+    public R listCollections(@RequestBody(required = false) CollectionListReq req) {
+        return R.ok().put("data", knowledgeVectorService.listCollections(req == null ? new CollectionListReq() : req));
+    }
+
+    @ApiOperation("删除集合(collection_id 或 collection_name)")
+    @PostMapping("/collection/delete")
+    public R deleteCollection(@RequestBody CollectionDeleteReq req) {
+        knowledgeVectorService.deleteCollection(req);
+        return R.ok();
+    }
+
+    @ApiOperation("记录 upsert(collection_id 或 collection_name)")
+    @PostMapping("/record/upsert")
+    public R upsertRecords(@RequestBody RecordUpsertReq req) {
+        return R.ok().put("data", knowledgeVectorService.upsertRecords(req));
+    }
+
+    @ApiOperation("记录 query(collection_id 或 collection_name)")
+    @PostMapping("/record/query")
+    public R queryRecords(@RequestBody RecordQueryReq req) {
+        return R.ok().put("data", knowledgeVectorService.queryRecords(req));
+    }
+
+    @ApiOperation("记录 delete(collection_id 或 collection_name)")
+    @PostMapping("/record/delete")
+    public R deleteRecords(@RequestBody RecordDeleteReq req) {
+        return R.ok().put("data", knowledgeVectorService.deleteRecords(req));
+    }
+
     @SuppressWarnings("unchecked")
     private List<IndexReq> toIndexReqList(Object req) {
         if (req instanceof Map) {
@@ -85,6 +122,10 @@ public class AiRagController {
     private IndexReq convertDocument(Map<String, Object> doc) {
         IndexReq req = new IndexReq();
         req.setDocId(toStr(doc.get("doc_id")));
+        req.setTenantCode(toStr(doc.get("tenantCode")));
+        if (req.getTenantCode() == null) {
+            req.setTenantCode(toStr(doc.get("tenant_code")));
+        }
         req.setTenantId(toStr(doc.get("tenant_id")));
         req.setText(toStr(doc.get("content")));
 

+ 2 - 0
fs-ai-api/src/main/java/com/fs/ai/rag/dto/IndexReq.java

@@ -7,6 +7,8 @@ import lombok.Data;
 @Data
 @ApiModel("知识入库")
 public class IndexReq {
+    @ApiModelProperty("租户编码(将作为 Chroma tenant)")
+    private String tenantCode;
     @ApiModelProperty("公司/租户维度,用作 Chroma metadata.tenant_id")
     private String tenantId;
     @ApiModelProperty("业务主键,全文更新时先按此删旧向量")

+ 275 - 6
fs-ai-api/src/main/java/com/fs/ai/rag/impl/KnowledgeVectorServiceImpl.java

@@ -2,11 +2,17 @@ package com.fs.ai.rag.impl;
 
 import com.fs.ai.rag.AiRagProperties;
 import com.fs.ai.rag.KnowledgeVectorService;
+import com.fs.ai.rag.dto.CollectionCreateReq;
+import com.fs.ai.rag.dto.CollectionDeleteReq;
+import com.fs.ai.rag.dto.CollectionListReq;
 import com.fs.ai.rag.dto.CreateDatabaseReq;
 import com.fs.ai.rag.dto.CreateTenantReq;
 import com.fs.ai.rag.dto.DeleteReq;
 import com.fs.ai.rag.dto.IndexReq;
 import com.fs.ai.rag.dto.QueryReq;
+import com.fs.ai.rag.dto.RecordDeleteReq;
+import com.fs.ai.rag.dto.RecordQueryReq;
+import com.fs.ai.rag.dto.RecordUpsertReq;
 import lombok.RequiredArgsConstructor;
 import lombok.extern.slf4j.Slf4j;
 import org.apache.commons.lang3.StringUtils;
@@ -24,6 +30,8 @@ import java.util.*;
 @Slf4j
 @RequiredArgsConstructor
 public class KnowledgeVectorServiceImpl implements KnowledgeVectorService {
+    private static final String WORKFLOW_DATABASE = "ai_workflow";
+    private static final String WORKFLOW_COLLECTION = "workflow_knowledge_base";
 
     private final AiRagProperties props;
     private final RestTemplate restTemplate;
@@ -57,9 +65,19 @@ public class KnowledgeVectorServiceImpl implements KnowledgeVectorService {
             return;
         }
         validateIndex(req);
-        EmbeddingProfile embeddingProfile = resolveEmbeddingProfile(req.getCollectionName());
-        ChromaScope scope = resolveScope();
-        String collId = getOrCreateCollectionId(scope, req.getCollectionName());
+        String tenantCode = StringUtils.trimToNull(req.getTenantCode());
+        String targetCollection = WORKFLOW_COLLECTION;
+        ChromaScope scope = StringUtils.isNotBlank(tenantCode)
+                ? new ChromaScope(tenantCode, WORKFLOW_DATABASE)
+                : resolveScope();
+        if (StringUtils.isBlank(tenantCode) && StringUtils.isNotBlank(req.getCollectionName())) {
+            targetCollection = req.getCollectionName();
+        }
+
+        EmbeddingProfile embeddingProfile = resolveEmbeddingProfile(targetCollection);
+        String collId = StringUtils.isNotBlank(tenantCode)
+                ? getOrCreateWorkflowCollectionId(scope)
+                : getOrCreateCollectionId(scope, targetCollection);
 
         deleteByDocIdInternal(scope, collId, req.getDocId());
 
@@ -79,7 +97,7 @@ public class KnowledgeVectorServiceImpl implements KnowledgeVectorService {
             ids.add(req.getDocId() + "_" + i);
             Map<String, Object> meta = new LinkedHashMap<>();
             meta.put("doc_id", req.getDocId());
-            meta.put("tenant_id", req.getTenantId());
+            meta.put("tenant_id", StringUtils.defaultIfBlank(req.getTenantCode(), req.getTenantId()));
             meta.put("chunk_index", i);
             meta.put("total_chunks", n);
             metadatas.add(meta);
@@ -161,6 +179,75 @@ public class KnowledgeVectorServiceImpl implements KnowledgeVectorService {
         return postJsonReturnMap(url, body);
     }
 
+    @Override
+    public Map<String, Object> createCollection(CollectionCreateReq req) {
+        ChromaScope scope = resolveScopeOverride(req);
+        String name = trim(req.getName());
+        if (StringUtils.isBlank(name)) {
+            throw new IllegalArgumentException("name 不能为空");
+        }
+        String url = tenantsDatabasesPath(scope) + "/collections";
+        Map<String, Object> body = new LinkedHashMap<>();
+        body.put("name", name);
+        body.put("get_or_create", req.getGetOrCreate() == null ? Boolean.TRUE : req.getGetOrCreate());
+        if (req.getMetadata() != null) {
+            body.put("metadata", req.getMetadata());
+        }
+        return postJsonReturnMap(url, body);
+    }
+
+    @Override
+    public Object listCollections(CollectionListReq req) {
+        ChromaScope scope = resolveScopeOverride(req);
+        int limit = req.getLimit() == null ? 100 : req.getLimit();
+        int offset = req.getOffset() == null ? 0 : req.getOffset();
+        String url = tenantsDatabasesPath(scope) + "/collections?limit=" + limit + "&offset=" + offset;
+        return getJsonObject(url);
+    }
+
+    @Override
+    public void deleteCollection(CollectionDeleteReq req) {
+        ChromaScope scope = resolveScopeOverride(req);
+        String collectionId = trim(req.getCollectionId());
+        if (StringUtils.isBlank(collectionId)) {
+            String collectionName = trim(req.getCollectionName());
+            if (StringUtils.isBlank(collectionName)) {
+                throw new IllegalArgumentException("collection_id 或 collection_name 必填其一");
+            }
+            collectionId = findCollectionIdByName(scope, collectionName);
+            if (StringUtils.isBlank(collectionId)) {
+                throw new IllegalArgumentException("未找到集合: " + collectionName);
+            }
+        }
+        String url = collectionsPath(scope, collectionId);
+        HttpHeaders h = chromaHeaders();
+        restTemplate.exchange(url, HttpMethod.DELETE, new HttpEntity<>(h), String.class);
+    }
+
+    @Override
+    public Map<String, Object> upsertRecords(RecordUpsertReq req) {
+        ChromaScope scope = resolveScopeOverride(req);
+        String collectionId = resolveCollectionId(scope, req);
+        String url = collectionsPath(scope, collectionId) + "/upsert";
+        return postJsonReturnMap(url, stripControlFields(req));
+    }
+
+    @Override
+    public Map<String, Object> queryRecords(RecordQueryReq req) {
+        ChromaScope scope = resolveScopeOverride(req);
+        String collectionId = resolveCollectionId(scope, req);
+        String url = collectionsPath(scope, collectionId) + "/query";
+        return postJsonReturnMap(url, stripControlFields(req));
+    }
+
+    @Override
+    public Map<String, Object> deleteRecords(RecordDeleteReq req) {
+        ChromaScope scope = resolveScopeOverride(req);
+        String collectionId = resolveCollectionId(scope, req);
+        String url = collectionsPath(scope, collectionId) + "/delete";
+        return postJsonReturnMap(url, stripControlFields(req));
+    }
+
     private void deleteByDocIdInternal(ChromaScope scope, String collectionId, String docId) {
         String url = collectionsPath(scope, collectionId) + "/delete";
         Map<String, Object> where = new LinkedHashMap<>();
@@ -221,6 +308,120 @@ public class KnowledgeVectorServiceImpl implements KnowledgeVectorService {
         }
     }
 
+    private ChromaScope resolveScopeOverride(CollectionCreateReq req) {
+        String tid = req == null ? null : trim(req.getTenantId());
+        String db = req == null ? null : trim(req.getDatabase());
+        if (StringUtils.isNotBlank(tid) && StringUtils.isNotBlank(db)) {
+            return new ChromaScope(tid, db);
+        }
+        ChromaScope fallback = resolveScope();
+        return new ChromaScope(StringUtils.defaultIfBlank(tid, fallback.tenantId),
+                StringUtils.defaultIfBlank(db, fallback.databaseName));
+    }
+
+    private ChromaScope resolveScopeOverride(CollectionListReq req) {
+        String tid = req == null ? null : trim(req.getTenantId());
+        String db = req == null ? null : trim(req.getDatabase());
+        if (StringUtils.isNotBlank(tid) && StringUtils.isNotBlank(db)) {
+            return new ChromaScope(tid, db);
+        }
+        ChromaScope fallback = resolveScope();
+        return new ChromaScope(StringUtils.defaultIfBlank(tid, fallback.tenantId),
+                StringUtils.defaultIfBlank(db, fallback.databaseName));
+    }
+
+    private ChromaScope resolveScopeOverride(CollectionDeleteReq req) {
+        String tid = req == null ? null : trim(req.getTenantId());
+        String db = req == null ? null : trim(req.getDatabase());
+        if (StringUtils.isNotBlank(tid) && StringUtils.isNotBlank(db)) {
+            return new ChromaScope(tid, db);
+        }
+        ChromaScope fallback = resolveScope();
+        return new ChromaScope(StringUtils.defaultIfBlank(tid, fallback.tenantId),
+                StringUtils.defaultIfBlank(db, fallback.databaseName));
+    }
+
+    private ChromaScope resolveScopeOverride(RecordUpsertReq req) {
+        String tid = req == null ? null : trim(req.getTenantId());
+        String db = req == null ? null : trim(req.getDatabase());
+        if (StringUtils.isNotBlank(tid) && StringUtils.isNotBlank(db)) {
+            return new ChromaScope(tid, db);
+        }
+        ChromaScope fallback = resolveScope();
+        return new ChromaScope(StringUtils.defaultIfBlank(tid, fallback.tenantId),
+                StringUtils.defaultIfBlank(db, fallback.databaseName));
+    }
+
+    private ChromaScope resolveScopeOverride(RecordQueryReq req) {
+        String tid = req == null ? null : trim(req.getTenantId());
+        String db = req == null ? null : trim(req.getDatabase());
+        if (StringUtils.isNotBlank(tid) && StringUtils.isNotBlank(db)) {
+            return new ChromaScope(tid, db);
+        }
+        ChromaScope fallback = resolveScope();
+        return new ChromaScope(StringUtils.defaultIfBlank(tid, fallback.tenantId),
+                StringUtils.defaultIfBlank(db, fallback.databaseName));
+    }
+
+    private ChromaScope resolveScopeOverride(RecordDeleteReq req) {
+        String tid = req == null ? null : trim(req.getTenantId());
+        String db = req == null ? null : trim(req.getDatabase());
+        if (StringUtils.isNotBlank(tid) && StringUtils.isNotBlank(db)) {
+            return new ChromaScope(tid, db);
+        }
+        ChromaScope fallback = resolveScope();
+        return new ChromaScope(StringUtils.defaultIfBlank(tid, fallback.tenantId),
+                StringUtils.defaultIfBlank(db, fallback.databaseName));
+    }
+
+    private String resolveCollectionId(ChromaScope scope, RecordUpsertReq req) {
+        String collectionId = trim(req.getCollectionId());
+        if (StringUtils.isNotBlank(collectionId)) {
+            return collectionId;
+        }
+        String collectionName = trim(req.getCollectionName());
+        if (StringUtils.isBlank(collectionName)) {
+            throw new IllegalArgumentException("collection_id 或 collection_name 必填其一");
+        }
+        String resolved = findCollectionIdByName(scope, collectionName);
+        if (StringUtils.isBlank(resolved)) {
+            throw new IllegalArgumentException("未找到集合: " + collectionName);
+        }
+        return resolved;
+    }
+
+    private String resolveCollectionId(ChromaScope scope, RecordQueryReq req) {
+        String collectionId = trim(req.getCollectionId());
+        if (StringUtils.isNotBlank(collectionId)) {
+            return collectionId;
+        }
+        String collectionName = trim(req.getCollectionName());
+        if (StringUtils.isBlank(collectionName)) {
+            throw new IllegalArgumentException("collection_id 或 collection_name 必填其一");
+        }
+        String resolved = findCollectionIdByName(scope, collectionName);
+        if (StringUtils.isBlank(resolved)) {
+            throw new IllegalArgumentException("未找到集合: " + collectionName);
+        }
+        return resolved;
+    }
+
+    private String resolveCollectionId(ChromaScope scope, RecordDeleteReq req) {
+        String collectionId = trim(req.getCollectionId());
+        if (StringUtils.isNotBlank(collectionId)) {
+            return collectionId;
+        }
+        String collectionName = trim(req.getCollectionName());
+        if (StringUtils.isBlank(collectionName)) {
+            throw new IllegalArgumentException("collection_id 或 collection_name 必填其一");
+        }
+        String resolved = findCollectionIdByName(scope, collectionName);
+        if (StringUtils.isBlank(resolved)) {
+            throw new IllegalArgumentException("未找到集合: " + collectionName);
+        }
+        return resolved;
+    }
+
     private String getOrCreateCollectionId(ChromaScope scope, String name) {
         String existing = findCollectionIdByName(scope, name);
         if (existing != null) {
@@ -238,6 +439,34 @@ public class KnowledgeVectorServiceImpl implements KnowledgeVectorService {
         return Objects.requireNonNull(findCollectionIdByName(scope, name), "创建集合失败: " + name);
     }
 
+    private String getOrCreateWorkflowCollectionId(ChromaScope scope) {
+        String existing = findCollectionIdByName(scope, WORKFLOW_COLLECTION);
+        if (existing != null) {
+            return existing;
+        }
+        String url = tenantsDatabasesPath(scope) + "/collections";
+        Map<String, Object> body = new LinkedHashMap<>();
+        Map<String, Object> configuration = new LinkedHashMap<>();
+        Map<String, Object> hnsw = new LinkedHashMap<>();
+        hnsw.put("space", "cosine");
+        configuration.put("hnsw", hnsw);
+        body.put("configuration", configuration);
+        body.put("get_or_create", Boolean.TRUE);
+        Map<String, Object> metadata = new LinkedHashMap<>();
+        metadata.put("description", "用户创建的工作流知识库");
+        metadata.put("owner", "team-ai");
+        body.put("metadata", metadata);
+        body.put("name", WORKFLOW_COLLECTION);
+        body.put("schema", null);
+        Map<String, Object> resp = postJsonReturnMap(url, body);
+        Object id = resp.get("id");
+        if (id != null) {
+            return id.toString();
+        }
+        return Objects.requireNonNull(findCollectionIdByName(scope, WORKFLOW_COLLECTION),
+                "创建默认工作流集合失败: " + WORKFLOW_COLLECTION);
+    }
+
     @SuppressWarnings("unchecked")
     private String findCollectionIdByName(ChromaScope scope, String name) {
         String url = tenantsDatabasesPath(scope) + "/collections?limit=500&offset=0";
@@ -270,8 +499,11 @@ public class KnowledgeVectorServiceImpl implements KnowledgeVectorService {
     }
 
     private void validateIndex(IndexReq req) {
-        if (StringUtils.isAnyBlank(req.getTenantId(), req.getDocId(), req.getCollectionName())) {
-            throw new IllegalArgumentException("tenantId/docId/collectionName 不能为空");
+        if (StringUtils.isBlank(req.getTenantCode()) && StringUtils.isBlank(req.getTenantId())) {
+            throw new IllegalArgumentException("tenantCode 或 tenantId 至少填一个");
+        }
+        if (StringUtils.isBlank(req.getDocId())) {
+            throw new IllegalArgumentException("docId 不能为空");
         }
         if (StringUtils.isBlank(req.getText())) {
             throw new IllegalArgumentException("text 不能为空");
@@ -410,6 +642,43 @@ public class KnowledgeVectorServiceImpl implements KnowledgeVectorService {
         return url + "/v1/embeddings";
     }
 
+    private String trim(Object val) {
+        return val == null ? null : StringUtils.trimToNull(String.valueOf(val));
+    }
+
+    private Map<String, Object> stripControlFields(RecordUpsertReq req) {
+        Map<String, Object> body = new LinkedHashMap<>();
+        body.put("ids", req.getIds());
+        body.put("documents", req.getDocuments());
+        body.put("embeddings", req.getEmbeddings());
+        body.put("metadatas", req.getMetadatas());
+        return body;
+    }
+
+    private Map<String, Object> stripControlFields(RecordQueryReq req) {
+        Map<String, Object> body = new LinkedHashMap<>();
+        body.put("query_embeddings", req.getQueryEmbeddings());
+        body.put("n_results", req.getNResults());
+        if (req.getWhere() != null) {
+            body.put("where", req.getWhere());
+        }
+        if (req.getInclude() != null) {
+            body.put("include", req.getInclude());
+        }
+        return body;
+    }
+
+    private Map<String, Object> stripControlFields(RecordDeleteReq req) {
+        Map<String, Object> body = new LinkedHashMap<>();
+        if (req.getIds() != null) {
+            body.put("ids", req.getIds());
+        }
+        if (req.getWhere() != null) {
+            body.put("where", req.getWhere());
+        }
+        return body;
+    }
+
     private void postJson(String url, Map<String, Object> body) {
         HttpHeaders h = chromaHeaders();
         h.setContentType(MediaType.APPLICATION_JSON);

+ 0 - 2
fs-ai-api/src/main/resources/application.yml

@@ -23,8 +23,6 @@ ai:
     # 若 Chroma 开启鉴权,填写 x-chroma-token(OpenAPI 名称 ApiKeyAuth)
     chroma-token: ""
     # 可选:手动指定租户 UUID、库名;留空则尝试 GET /auth/identity 自动解析
-    chroma-tenant-id: "saasai"
-    chroma-database: ylrz_saas_ai
     embedding-url: https://api.openai.com/v1/embeddings
     embedding-api-key: ""
     embedding-model: text-embedding-3-small