AwsClaudeClient.java 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
  1. //package com.ruoyi.aicall.utils;
  2. //
  3. //import software.amazon.awssdk.auth.credentials.AwsBasicCredentials;
  4. //import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider;
  5. //import software.amazon.awssdk.core.SdkBytes;
  6. //import software.amazon.awssdk.regions.Region;
  7. //import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
  8. //import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeAsyncClient;
  9. //import software.amazon.awssdk.services.bedrockruntime.model.*;
  10. //import software.amazon.awssdk.core.async.SdkPublisher;
  11. //
  12. //import java.util.ArrayList;
  13. //import java.util.Arrays;
  14. //import java.util.List;
  15. //import java.util.concurrent.CompletableFuture;
  16. //import java.util.concurrent.ExecutionException;
  17. //
  18. ///**
  19. // * AWS Bedrock Claude 模型调用工具类
  20. // * 支持 Claude 3 (Haiku, Sonnet, Opus) 和 Claude 2 系列模型
  21. // */
  22. //public class AwsClaudeClient implements AutoCloseable {
  23. //
  24. // private final BedrockRuntimeClient bedrockClient;
  25. // private final BedrockRuntimeAsyncClient asyncClient;
  26. // private final String modelId;
  27. //
  28. // // 常用模型ID常量
  29. // public static final String CLAUDE_3_5_SONNET = "anthropic.claude-3-5-sonnet-20240620-v1:0";
  30. // public static final String CLAUDE_3_OPUS = "anthropic.claude-3-opus-20240229-v1:0";
  31. // public static final String CLAUDE_3_SONNET = "anthropic.claude-3-sonnet-20240229-v1:0";
  32. // public static final String CLAUDE_3_HAIKU = "anthropic.claude-3-haiku-20240307-v1:0";
  33. // public static final String CLAUDE_2_1 = "anthropic.claude-v2:1";
  34. // public static final String CLAUDE_2 = "anthropic.claude-v2";
  35. // public static final String CLAUDE_INSTANT = "anthropic.claude-instant-v1";
  36. //
  37. // /**
  38. // * 构造函数
  39. // *
  40. // * @param accessKey AWS Access Key
  41. // * @param secretKey AWS Secret Key
  42. // * @param region AWS 区域 (如 Region.US_EAST_1)
  43. // * @param modelId 模型ID,可使用本类提供的常量
  44. // */
  45. // public AwsClaudeClient(String accessKey, String secretKey, Region region, String modelId) {
  46. // AwsBasicCredentials credentials = AwsBasicCredentials.create(accessKey, secretKey);
  47. // StaticCredentialsProvider credentialsProvider = StaticCredentialsProvider.create(credentials);
  48. //
  49. // this.bedrockClient = BedrockRuntimeClient.builder()
  50. // .region(region)
  51. // .credentialsProvider(credentialsProvider)
  52. // .build();
  53. //
  54. // this.asyncClient = BedrockRuntimeAsyncClient.builder()
  55. // .region(region)
  56. // .credentialsProvider(credentialsProvider)
  57. // .build();
  58. //
  59. // this.modelId = modelId;
  60. // }
  61. //
  62. // /**
  63. // * 简单文本对话(Claude 3 版本)
  64. // *
  65. // * @param userMessage 用户输入的消息
  66. // * @return 模型的回复文本
  67. // */
  68. // public String chat(String userMessage) {
  69. // return chat(userMessage, 1024, 0.7);
  70. // }
  71. //
  72. // /**
  73. // * 带参数控制的文本对话(Claude 3 版本)
  74. // *
  75. // * @param userMessage 用户输入的消息
  76. // * @param maxTokens 最大生成token数
  77. // * @param temperature 温度参数 (0.0 - 1.0)
  78. // * @return 模型的回复文本
  79. // */
  80. // public String chat(String userMessage, int maxTokens, double temperature) {
  81. // // 构建 Claude 3 消息格式
  82. // String payload = String.format(
  83. // "{" +
  84. // "\"anthropic_version\": \"bedrock-2023-05-31\"," +
  85. // "\"max_tokens\": %d," +
  86. // "\"temperature\": %.2f," +
  87. // "\"messages\": [" +
  88. // " {\"role\": \"user\", \"content\": \"%s\"}" +
  89. // "]" +
  90. // "}",
  91. // maxTokens, temperature, escapeJson(userMessage)
  92. // );
  93. //
  94. // return invokeModel(payload);
  95. // }
  96. //
  97. // /**
  98. // * 多轮对话(Claude 3 版本)
  99. // *
  100. // * @param messages 消息列表,包含 role 和 content
  101. // * @param maxTokens 最大生成token数
  102. // * @param temperature 温度参数
  103. // * @return 模型的回复文本
  104. // */
  105. // public String chat(List<Message> messages, int maxTokens, double temperature) {
  106. // StringBuilder messagesJson = new StringBuilder();
  107. // for (int i = 0; i < messages.size(); i++) {
  108. // Message msg = messages.get(i);
  109. // messagesJson.append(String.format(
  110. // "{\"role\": \"%s\", \"content\": \"%s\"}",
  111. // msg.getRole(), escapeJson(msg.getContent())
  112. // ));
  113. // if (i < messages.size() - 1) {
  114. // messagesJson.append(",");
  115. // }
  116. // }
  117. //
  118. // String payload = String.format(
  119. // "{" +
  120. // "\"anthropic_version\": \"bedrock-2023-05-31\"," +
  121. // "\"max_tokens\": %d," +
  122. // "\"temperature\": %.2f," +
  123. // "\"messages\": [%s]" +
  124. // "}",
  125. // maxTokens, temperature, messagesJson.toString()
  126. // );
  127. //
  128. // return invokeModel(payload);
  129. // }
  130. //
  131. // /**
  132. // * 流式调用(使用异步客户端)
  133. // *
  134. // * @param userMessage 用户消息
  135. // * @param callback 流式响应回调
  136. // */
  137. // public void chatStream(String userMessage, StreamCallback callback) {
  138. // String payload = String.format(
  139. // "{" +
  140. // "\"anthropic_version\": \"bedrock-2023-05-31\"," +
  141. // "\"max_tokens\": 4096," +
  142. // "\"messages\": [" +
  143. // " {\"role\": \"user\", \"content\": \"%s\"}" +
  144. // "]" +
  145. // "}",
  146. // escapeJson(userMessage)
  147. // );
  148. //
  149. // InvokeModelWithResponseStreamRequest request = InvokeModelWithResponseStreamRequest.builder()
  150. // .modelId(modelId)
  151. // .body(SdkBytes.fromUtf8String(payload))
  152. // .build();
  153. //
  154. // // 使用异步客户端进行流式调用
  155. // CompletableFuture<Void> future = asyncClient.invokeModelWithResponseStream(request,
  156. // new InvokeModelWithResponseStreamResponseHandler() {
  157. //
  158. // @Override
  159. // public void responseReceived(InvokeModelWithResponseStreamResponse invokeModelWithResponseStreamResponse) {
  160. //
  161. // }
  162. //
  163. // @Override
  164. // public void onEventStream(SdkPublisher<ResponseStream> publisher) {
  165. // publisher.subscribe(event -> {
  166. // if (event instanceof PayloadPart) {
  167. // PayloadPart payloadPart = (PayloadPart) event;
  168. // String chunk = payloadPart.bytes().asUtf8String();
  169. // String content = extractContentFromChunk(chunk);
  170. // if (content != null && !content.isEmpty()) {
  171. // callback.onChunkReceived(content);
  172. // }
  173. // }
  174. // });
  175. // }
  176. //
  177. // @Override
  178. // public void exceptionOccurred(Throwable throwable) {
  179. // callback.onError(throwable);
  180. // }
  181. //
  182. // @Override
  183. // public void complete() {
  184. // callback.onComplete();
  185. // }
  186. // });
  187. //
  188. // // 等待流式调用完成
  189. // try {
  190. // future.get();
  191. // } catch (InterruptedException | ExecutionException e) {
  192. // throw new RuntimeException("流式调用失败", e);
  193. // }
  194. // }
  195. //
  196. // /**
  197. // * 调用模型并解析响应
  198. // */
  199. // private String invokeModel(String payload) {
  200. // InvokeModelRequest request = InvokeModelRequest.builder()
  201. // .modelId(modelId)
  202. // .body(SdkBytes.fromUtf8String(payload))
  203. // .build();
  204. //
  205. // InvokeModelResponse response = bedrockClient.invokeModel(request);
  206. // String responseBody = response.body().asUtf8String();
  207. //
  208. // return extractContent(responseBody);
  209. // }
  210. //
  211. // /**
  212. // * 从响应 JSON 中提取 content 文本
  213. // */
  214. // private String extractContent(String jsonResponse) {
  215. // // Claude 3 格式处理
  216. // int contentStart = jsonResponse.indexOf("\"content\":");
  217. // if (contentStart != -1) {
  218. // int textStart = jsonResponse.indexOf("\"text\":", contentStart);
  219. // if (textStart != -1) {
  220. // int quoteStart = jsonResponse.indexOf("\"", textStart + 7) + 1;
  221. // int quoteEnd = jsonResponse.indexOf("\"", quoteStart);
  222. // if (quoteStart > 0 && quoteEnd > quoteStart) {
  223. // return unescapeJson(jsonResponse.substring(quoteStart, quoteEnd));
  224. // }
  225. // }
  226. // }
  227. //
  228. // // Claude 2 格式处理
  229. // int completionStart = jsonResponse.indexOf("\"completion\":");
  230. // if (completionStart != -1) {
  231. // int quoteStart = jsonResponse.indexOf("\"", completionStart + 13) + 1;
  232. // int quoteEnd = jsonResponse.lastIndexOf("\"");
  233. // if (quoteStart > 0 && quoteEnd > quoteStart) {
  234. // return unescapeJson(jsonResponse.substring(quoteStart, quoteEnd));
  235. // }
  236. // }
  237. //
  238. // return jsonResponse;
  239. // }
  240. //
  241. // /**
  242. // * 从流式响应块中提取内容
  243. // */
  244. // private String extractContentFromChunk(String chunk) {
  245. // // 流式响应通常是 JSON 行格式
  246. // if (chunk.contains("\"type\":\"content_block_delta\"") || chunk.contains("\"delta\"")) {
  247. // int textStart = chunk.indexOf("\"text\":");
  248. // if (textStart != -1) {
  249. // int quoteStart = chunk.indexOf("\"", textStart + 7) + 1;
  250. // int quoteEnd = chunk.indexOf("\"", quoteStart);
  251. // if (quoteStart > 0 && quoteEnd > quoteStart) {
  252. // return unescapeJson(chunk.substring(quoteStart, quoteEnd));
  253. // }
  254. // }
  255. // }
  256. // return chunk;
  257. // }
  258. //
  259. // /**
  260. // * JSON 字符串转义
  261. // */
  262. // private String escapeJson(String input) {
  263. // if (input == null) return "";
  264. // return input.replace("\\", "\\\\")
  265. // .replace("\"", "\\\"")
  266. // .replace("\n", "\\n")
  267. // .replace("\r", "\\r")
  268. // .replace("\t", "\\t");
  269. // }
  270. //
  271. // /**
  272. // * JSON 字符串反转义
  273. // */
  274. // private String unescapeJson(String input) {
  275. // if (input == null) return "";
  276. // return input.replace("\\\"", "\"")
  277. // .replace("\\n", "\n")
  278. // .replace("\\r", "\r")
  279. // .replace("\\t", "\t")
  280. // .replace("\\\\", "\\");
  281. // }
  282. //
  283. // @Override
  284. // public void close() {
  285. // bedrockClient.close();
  286. // asyncClient.close();
  287. // }
  288. //
  289. // // ==================== 内部类和接口 ====================
  290. //
  291. // /**
  292. // * 消息对象,用于多轮对话
  293. // */
  294. // public static class Message {
  295. // private final String role; // "user" 或 "assistant"
  296. // private final String content;
  297. //
  298. // public Message(String role, String content) {
  299. // this.role = role;
  300. // this.content = content;
  301. // }
  302. //
  303. // public String getRole() { return role; }
  304. // public String getContent() { return content; }
  305. //
  306. // public static Message user(String content) {
  307. // return new Message("user", content);
  308. // }
  309. //
  310. // public static Message assistant(String content) {
  311. // return new Message("assistant", content);
  312. // }
  313. // }
  314. //
  315. // /**
  316. // * 流式响应回调接口
  317. // */
  318. // public interface StreamCallback {
  319. // void onChunkReceived(String chunk);
  320. // void onError(Throwable throwable);
  321. // void onComplete();
  322. // }
  323. //
  324. // // ==================== 使用示例 ====================
  325. //
  326. // public static void main(String[] args) {
  327. // // 配置参数(建议从配置文件或环境变量读取)
  328. // String accessKey = "AKIAZSHZ4G4NU37EOYFT";
  329. // String secretKey = "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";
  330. // Region region = Region.US_EAST_1; // Bedrock 主要在此区域
  331. //
  332. // try (AwsClaudeClient client = new AwsClaudeClient(
  333. // accessKey, secretKey, region, CLAUDE_3_5_SONNET)) {
  334. //
  335. // // 1. 简单调用
  336. // System.out.println("=== 简单调用 ===");
  337. // String response = client.chat("你好,请介绍一下你自己");
  338. // System.out.println("响应: " + response);
  339. //
  340. // // 2. 带参数调用
  341. // System.out.println("\n=== 带参数调用 ===");
  342. // String response2 = client.chat(
  343. // "写一首关于春天的诗",
  344. // 2048, // 最大token
  345. // 0.9 // 更高的创造性
  346. // );
  347. // System.out.println("诗歌: " + response2);
  348. //
  349. // // 3. 多轮对话
  350. // System.out.println("\n=== 多轮对话 ===");
  351. // List<Message> messages = Arrays.asList(
  352. // Message.user("你好"),
  353. // Message.assistant("你好!有什么我可以帮助你的吗?"),
  354. // Message.user("请用 Java 写一个单例模式")
  355. // );
  356. // String multiTurnResponse = client.chat(messages, 2048, 0.7);
  357. // System.out.println("多轮响应: " + multiTurnResponse);
  358. //
  359. // // 4. 流式调用
  360. // System.out.println("\n=== 流式调用 ===");
  361. // client.chatStream("请详细解释 Java 的内存模型,包括堆、栈、方法区等", new StreamCallback() {
  362. // @Override
  363. // public void onChunkReceived(String chunk) {
  364. // System.out.print(chunk);
  365. // }
  366. //
  367. // @Override
  368. // public void onError(Throwable throwable) {
  369. // System.err.println("流式调用出错: " + throwable.getMessage());
  370. // }
  371. //
  372. // @Override
  373. // public void onComplete() {
  374. // System.out.println("\n[流式调用完成]");
  375. // }
  376. // });
  377. //
  378. // } catch (Exception e) {
  379. // e.printStackTrace();
  380. // }
  381. // }
  382. //}