//package com.ruoyi.aicall.utils; // //import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; //import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; //import software.amazon.awssdk.core.SdkBytes; //import software.amazon.awssdk.regions.Region; //import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; //import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient; //import software.amazon.awssdk.services.bedrockruntime.model.*; //import software.amazon.awssdk.core.async.SdkPublisher; // //import java.util.ArrayList; //import java.util.Arrays; //import java.util.List; //import java.util.concurrent.CompletableFuture; //import java.util.concurrent.ExecutionException; // ///** // * AWS Bedrock Claude 模型调用工具类 // * 支持 Claude 3 (Haiku, Sonnet, Opus) 和 Claude 2 系列模型 // */ //public class AwsClaudeClient implements AutoCloseable { // // private final BedrockRuntimeClient bedrockClient; // private final BedrockRuntimeAsyncClient asyncClient; // private final String modelId; // // // 常用模型ID常量 // public static final String CLAUDE_3_5_SONNET = "anthropic.claude-3-5-sonnet-20240620-v1:0"; // public static final String CLAUDE_3_OPUS = "anthropic.claude-3-opus-20240229-v1:0"; // public static final String CLAUDE_3_SONNET = "anthropic.claude-3-sonnet-20240229-v1:0"; // public static final String CLAUDE_3_HAIKU = "anthropic.claude-3-haiku-20240307-v1:0"; // public static final String CLAUDE_2_1 = "anthropic.claude-v2:1"; // public static final String CLAUDE_2 = "anthropic.claude-v2"; // public static final String CLAUDE_INSTANT = "anthropic.claude-instant-v1"; // // /** // * 构造函数 // * // * @param accessKey AWS Access Key // * @param secretKey AWS Secret Key // * @param region AWS 区域 (如 Region.US_EAST_1) // * @param modelId 模型ID,可使用本类提供的常量 // */ // public AwsClaudeClient(String accessKey, String secretKey, Region region, String modelId) { // AwsBasicCredentials credentials = AwsBasicCredentials.create(accessKey, secretKey); // StaticCredentialsProvider credentialsProvider = StaticCredentialsProvider.create(credentials); // // this.bedrockClient = BedrockRuntimeClient.builder() // .region(region) // .credentialsProvider(credentialsProvider) // .build(); // // this.asyncClient = BedrockRuntimeAsyncClient.builder() // .region(region) // .credentialsProvider(credentialsProvider) // .build(); // // this.modelId = modelId; // } // // /** // * 简单文本对话(Claude 3 版本) // * // * @param userMessage 用户输入的消息 // * @return 模型的回复文本 // */ // public String chat(String userMessage) { // return chat(userMessage, 1024, 0.7); // } // // /** // * 带参数控制的文本对话(Claude 3 版本) // * // * @param userMessage 用户输入的消息 // * @param maxTokens 最大生成token数 // * @param temperature 温度参数 (0.0 - 1.0) // * @return 模型的回复文本 // */ // public String chat(String userMessage, int maxTokens, double temperature) { // // 构建 Claude 3 消息格式 // String payload = String.format( // "{" + // "\"anthropic_version\": \"bedrock-2023-05-31\"," + // "\"max_tokens\": %d," + // "\"temperature\": %.2f," + // "\"messages\": [" + // " {\"role\": \"user\", \"content\": \"%s\"}" + // "]" + // "}", // maxTokens, temperature, escapeJson(userMessage) // ); // // return invokeModel(payload); // } // // /** // * 多轮对话(Claude 3 版本) // * // * @param messages 消息列表,包含 role 和 content // * @param maxTokens 最大生成token数 // * @param temperature 温度参数 // * @return 模型的回复文本 // */ // public String chat(List messages, int maxTokens, double temperature) { // StringBuilder messagesJson = new StringBuilder(); // for (int i = 0; i < messages.size(); i++) { // Message msg = messages.get(i); // messagesJson.append(String.format( // "{\"role\": \"%s\", \"content\": \"%s\"}", // msg.getRole(), escapeJson(msg.getContent()) // )); // if (i < messages.size() - 1) { // messagesJson.append(","); // } // } // // String payload = String.format( // "{" + // "\"anthropic_version\": \"bedrock-2023-05-31\"," + // "\"max_tokens\": %d," + // "\"temperature\": %.2f," + // "\"messages\": [%s]" + // "}", // maxTokens, temperature, messagesJson.toString() // ); // // return invokeModel(payload); // } // // /** // * 流式调用(使用异步客户端) // * // * @param userMessage 用户消息 // * @param callback 流式响应回调 // */ // public void chatStream(String userMessage, StreamCallback callback) { // String payload = String.format( // "{" + // "\"anthropic_version\": \"bedrock-2023-05-31\"," + // "\"max_tokens\": 4096," + // "\"messages\": [" + // " {\"role\": \"user\", \"content\": \"%s\"}" + // "]" + // "}", // escapeJson(userMessage) // ); // // InvokeModelWithResponseStreamRequest request = InvokeModelWithResponseStreamRequest.builder() // .modelId(modelId) // .body(SdkBytes.fromUtf8String(payload)) // .build(); // // // 使用异步客户端进行流式调用 // CompletableFuture future = asyncClient.invokeModelWithResponseStream(request, // new InvokeModelWithResponseStreamResponseHandler() { // // @Override // public void responseReceived(InvokeModelWithResponseStreamResponse invokeModelWithResponseStreamResponse) { // // } // // @Override // public void onEventStream(SdkPublisher publisher) { // publisher.subscribe(event -> { // if (event instanceof PayloadPart) { // PayloadPart payloadPart = (PayloadPart) event; // String chunk = payloadPart.bytes().asUtf8String(); // String content = extractContentFromChunk(chunk); // if (content != null && !content.isEmpty()) { // callback.onChunkReceived(content); // } // } // }); // } // // @Override // public void exceptionOccurred(Throwable throwable) { // callback.onError(throwable); // } // // @Override // public void complete() { // callback.onComplete(); // } // }); // // // 等待流式调用完成 // try { // future.get(); // } catch (InterruptedException | ExecutionException e) { // throw new RuntimeException("流式调用失败", e); // } // } // // /** // * 调用模型并解析响应 // */ // private String invokeModel(String payload) { // InvokeModelRequest request = InvokeModelRequest.builder() // .modelId(modelId) // .body(SdkBytes.fromUtf8String(payload)) // .build(); // // InvokeModelResponse response = bedrockClient.invokeModel(request); // String responseBody = response.body().asUtf8String(); // // return extractContent(responseBody); // } // // /** // * 从响应 JSON 中提取 content 文本 // */ // private String extractContent(String jsonResponse) { // // Claude 3 格式处理 // int contentStart = jsonResponse.indexOf("\"content\":"); // if (contentStart != -1) { // int textStart = jsonResponse.indexOf("\"text\":", contentStart); // if (textStart != -1) { // int quoteStart = jsonResponse.indexOf("\"", textStart + 7) + 1; // int quoteEnd = jsonResponse.indexOf("\"", quoteStart); // if (quoteStart > 0 && quoteEnd > quoteStart) { // return unescapeJson(jsonResponse.substring(quoteStart, quoteEnd)); // } // } // } // // // Claude 2 格式处理 // int completionStart = jsonResponse.indexOf("\"completion\":"); // if (completionStart != -1) { // int quoteStart = jsonResponse.indexOf("\"", completionStart + 13) + 1; // int quoteEnd = jsonResponse.lastIndexOf("\""); // if (quoteStart > 0 && quoteEnd > quoteStart) { // return unescapeJson(jsonResponse.substring(quoteStart, quoteEnd)); // } // } // // return jsonResponse; // } // // /** // * 从流式响应块中提取内容 // */ // private String extractContentFromChunk(String chunk) { // // 流式响应通常是 JSON 行格式 // if (chunk.contains("\"type\":\"content_block_delta\"") || chunk.contains("\"delta\"")) { // int textStart = chunk.indexOf("\"text\":"); // if (textStart != -1) { // int quoteStart = chunk.indexOf("\"", textStart + 7) + 1; // int quoteEnd = chunk.indexOf("\"", quoteStart); // if (quoteStart > 0 && quoteEnd > quoteStart) { // return unescapeJson(chunk.substring(quoteStart, quoteEnd)); // } // } // } // return chunk; // } // // /** // * JSON 字符串转义 // */ // private String escapeJson(String input) { // if (input == null) return ""; // return input.replace("\\", "\\\\") // .replace("\"", "\\\"") // .replace("\n", "\\n") // .replace("\r", "\\r") // .replace("\t", "\\t"); // } // // /** // * JSON 字符串反转义 // */ // private String unescapeJson(String input) { // if (input == null) return ""; // return input.replace("\\\"", "\"") // .replace("\\n", "\n") // .replace("\\r", "\r") // .replace("\\t", "\t") // .replace("\\\\", "\\"); // } // // @Override // public void close() { // bedrockClient.close(); // asyncClient.close(); // } // // // ==================== 内部类和接口 ==================== // // /** // * 消息对象,用于多轮对话 // */ // public static class Message { // private final String role; // "user" 或 "assistant" // private final String content; // // public Message(String role, String content) { // this.role = role; // this.content = content; // } // // public String getRole() { return role; } // public String getContent() { return content; } // // public static Message user(String content) { // return new Message("user", content); // } // // public static Message assistant(String content) { // return new Message("assistant", content); // } // } // // /** // * 流式响应回调接口 // */ // public interface StreamCallback { // void onChunkReceived(String chunk); // void onError(Throwable throwable); // void onComplete(); // } // // // ==================== 使用示例 ==================== // // public static void main(String[] args) { // // 配置参数(建议从配置文件或环境变量读取) // String accessKey = "AKIAZSHZ4G4NU37EOYFT"; // String secretKey = "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"; // Region region = Region.US_EAST_1; // Bedrock 主要在此区域 // // try (AwsClaudeClient client = new AwsClaudeClient( // accessKey, secretKey, region, CLAUDE_3_5_SONNET)) { // // // 1. 简单调用 // System.out.println("=== 简单调用 ==="); // String response = client.chat("你好,请介绍一下你自己"); // System.out.println("响应: " + response); // // // 2. 带参数调用 // System.out.println("\n=== 带参数调用 ==="); // String response2 = client.chat( // "写一首关于春天的诗", // 2048, // 最大token // 0.9 // 更高的创造性 // ); // System.out.println("诗歌: " + response2); // // // 3. 多轮对话 // System.out.println("\n=== 多轮对话 ==="); // List messages = Arrays.asList( // Message.user("你好"), // Message.assistant("你好!有什么我可以帮助你的吗?"), // Message.user("请用 Java 写一个单例模式") // ); // String multiTurnResponse = client.chat(messages, 2048, 0.7); // System.out.println("多轮响应: " + multiTurnResponse); // // // 4. 流式调用 // System.out.println("\n=== 流式调用 ==="); // client.chatStream("请详细解释 Java 的内存模型,包括堆、栈、方法区等", new StreamCallback() { // @Override // public void onChunkReceived(String chunk) { // System.out.print(chunk); // } // // @Override // public void onError(Throwable throwable) { // System.err.println("流式调用出错: " + throwable.getMessage()); // } // // @Override // public void onComplete() { // System.out.println("\n[流式调用完成]"); // } // }); // // } catch (Exception e) { // e.printStackTrace(); // } // } //}