WebSocketServerHandler.java 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361
  1. package com.telerobot.fs.wshandle.nettyserver;
  2. import com.alibaba.fastjson.JSON;
  3. import com.alibaba.fastjson.JSONObject;
  4. import com.telerobot.fs.config.AppContextProvider;
  5. import com.telerobot.fs.entity.dao.BizGroup;
  6. import com.telerobot.fs.service.SysService;
  7. import com.telerobot.fs.utils.CommonUtils;
  8. import com.telerobot.fs.utils.StringUtils;
  9. import com.telerobot.fs.utils.ThreadUtil;
  10. import com.telerobot.fs.wshandle.*;
  11. import com.telerobot.fs.wshandle.SecurityManager;
  12. import io.netty.buffer.ByteBuf;
  13. import io.netty.buffer.Unpooled;
  14. import io.netty.channel.ChannelFuture;
  15. import io.netty.channel.ChannelFutureListener;
  16. import io.netty.channel.ChannelHandler.Sharable;
  17. import io.netty.channel.ChannelHandlerContext;
  18. import io.netty.handler.codec.http.DefaultFullHttpResponse;
  19. import io.netty.handler.codec.http.FullHttpRequest;
  20. import io.netty.handler.codec.http.HttpResponseStatus;
  21. import io.netty.handler.codec.http.HttpVersion;
  22. import io.netty.handler.codec.http.websocketx.*;
  23. import io.netty.util.CharsetUtil;
  24. import org.slf4j.Logger;
  25. import org.slf4j.LoggerFactory;
  26. import org.springframework.stereotype.Component;
  27. import java.util.List;
  28. import java.util.Map;
  29. /**
  30. * websocket 具体业务处理方法
  31. *
  32. * @author DELL
  33. */
  34. @Component
  35. @Sharable
  36. public class WebSocketServerHandler extends BaseWebSocketServerHandler {
  37. private static final Logger logger = LoggerFactory.getLogger(WebSocketServerHandler.class);
  38. /**
  39. * 当客户端连接成功,返回个成功信息
  40. */
  41. @Override
  42. public void channelActive(final ChannelHandlerContext ctx) throws Exception {}
  43. /**
  44. * 当客户端断开连接
  45. */
  46. @Override
  47. public void channelInactive(ChannelHandlerContext ctx) throws Exception {
  48. // 从连接池内剔除
  49. String clientId = ctx.channel().id().asLongText();
  50. logger.debug("client disconnected:{}", clientId);
  51. WebsocketThreadPool.addTask(new Runnable() {
  52. @Override
  53. public void run() {
  54. MessageHandlerEngine engine = MessageHandlerEngineList.getInstance().getMsgHandlerEngine(clientId);
  55. if(engine != null){
  56. InactiveNotice.onDisconnected(engine.getSessionInfo());
  57. }
  58. }
  59. });
  60. MessageHandlerEngineList.getInstance().delete(clientId, true);
  61. ctx.close();
  62. }
  63. @Override
  64. public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
  65. ctx.flush();
  66. }
  67. @Override
  68. protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception {
  69. // http://xxxx
  70. if (msg instanceof FullHttpRequest) {
  71. handleHttpRequest(ctx, (FullHttpRequest) msg);
  72. } else if (msg instanceof WebSocketFrame) {
  73. // ws://xxxx
  74. handlerWebSocketFrame(ctx, (WebSocketFrame) msg);
  75. }
  76. // channelRead0 不需要显式释放msg; (丢弃已接收的消息)
  77. // 如果放入到线程池去处理msg, 需要显式释放这个msg;
  78. }
  79. // @Override
  80. // public void channelRead(ChannelHandlerContext ctx, Object msg) {
  81. // ReferenceCountUtil.release(msg);
  82. // }
  83. public void handlerWebSocketFrame(ChannelHandlerContext ctx, WebSocketFrame frame) throws Exception {
  84. // 关闭请求
  85. if (frame instanceof CloseWebSocketFrame) {
  86. WebSocketServerHandshaker handshaker = Constant.handShakerMap.get(ctx.channel().id().asLongText());
  87. if (handshaker != null) {
  88. handshaker.close(ctx.channel(), (CloseWebSocketFrame) frame.retain());
  89. }
  90. return;
  91. }
  92. // ping请求
  93. if (frame instanceof PingWebSocketFrame) {
  94. ctx.channel().write(new PongWebSocketFrame(frame.content().retain()));
  95. return;
  96. }
  97. // 只支持文本格式,不支持二进制消息
  98. if (!(frame instanceof TextWebSocketFrame)) {
  99. throw new Exception("only plain-text format supported.");
  100. }
  101. // 客服端发送过来的消息
  102. final String msg = ((TextWebSocketFrame) frame).text();
  103. logger.debug("current connections:{}", MessageHandlerEngineList.getInstance().size());
  104. WebsocketThreadPool.addTask(new Runnable() {
  105. @Override
  106. public void run() {
  107. String clientID = ctx.channel().id().asLongText();
  108. logger.debug(String.format("receive message: %s, from: %s , clientId: %s .", msg,
  109. ctx.channel().remoteAddress().toString(), clientID));
  110. MessageHandlerEngine msgEngine = MessageHandlerEngineList.getInstance()
  111. .getMsgHandlerEngine(clientID);
  112. int trycount = 0;
  113. int maxtry = 1000;
  114. long startTime = System.currentTimeMillis();
  115. while (msgEngine == null) {
  116. ThreadUtil.sleep(5);
  117. trycount += 1;
  118. msgEngine = MessageHandlerEngineList.getInstance().getMsgHandlerEngine(clientID);
  119. if (msgEngine != null) {
  120. int spendSeconds = (int) ((System.currentTimeMillis() - startTime) / 1000);
  121. logger.debug("successfully get messageHandlerEngine object, spend seconds : {}" , spendSeconds);
  122. break;
  123. }
  124. if (trycount > maxtry) {
  125. break;
  126. }
  127. }
  128. if (msgEngine == null) {
  129. MessageResponse replyMsg = new MessageResponse();
  130. replyMsg.setMsg("server too busy, can't get msgEngine");
  131. replyMsg.setStatus(500);
  132. ctx.writeAndFlush(new TextWebSocketFrame(replyMsg.toString()));
  133. ctx.close();
  134. logger.error("{} server too busy, can't get msgEngine.", clientID);
  135. return;
  136. }
  137. MsgStruct msgObj = null;
  138. try {
  139. msgObj = JSON.parseObject(msg, MsgStruct.class);
  140. } catch (Exception e) {
  141. sendReplyToAgent(400, "invalid json format.", msgEngine);
  142. return;
  143. }
  144. if (msgObj == null) {
  145. sendReplyToAgent(400, "operation not supported", msgEngine);
  146. return;
  147. }
  148. boolean notHasHeader = StringUtils.isNullOrEmpty(msgObj.getAction());
  149. boolean notHasBody = StringUtils.isNullOrEmpty(msgObj.getBody());
  150. if (notHasHeader || notHasBody) {
  151. sendReplyToAgent(400, "except both 'action' and 'body' in request msg.", msgEngine);
  152. return;
  153. }
  154. if(!msgEngine.checkAuth()){
  155. return;
  156. }
  157. if (msgEngine.getSessionInfo() == null || !msgEngine.getSessionInfo().IsValid()) {
  158. String tips = "can not process your request, phone-bar login timeout.";
  159. logger.warn(tips);
  160. sendReplyToAgent(RespStatus.UNAUTHORIZED, tips, msgEngine);
  161. return;
  162. }
  163. if(!"setHearBeat".equals(msgObj.getAction())){
  164. msgEngine.processMsg(msgObj);
  165. }else{
  166. // 心跳
  167. logger.info("{} recv websocket client heartBeat: {}", msgEngine.getTraceId(), msgEngine.getClientSessionID() );
  168. //每次消息接收,都更新下用户活动时间;
  169. SessionEntity sessionInfo = msgEngine.getSessionInfo();
  170. if (sessionInfo != null) {
  171. sessionInfo.setLastActiveTime(System.currentTimeMillis());
  172. }
  173. }
  174. }
  175. });
  176. }
  177. private void sendReplyToAgent(int statusCode, String msg, MessageHandlerEngine messageHandlerEngine){
  178. MessageResponse replyMsg = new MessageResponse();
  179. replyMsg.setMsg(msg);
  180. replyMsg.setStatus(statusCode);
  181. messageHandlerEngine.sendReplyToAgent(replyMsg);
  182. }
  183. /**
  184. * 第一次请求是http请求,请求头包括ws的信息
  185. * @param ctx
  186. * @param req
  187. */
  188. public void handleHttpRequest(ChannelHandlerContext ctx, FullHttpRequest req) {
  189. if (!req.decoderResult().isSuccess()) {
  190. sendHttpResponse(ctx, req,
  191. new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.BAD_REQUEST));
  192. return;
  193. }
  194. WebSocketServerHandshaker handshaker = Constant.handShakerMap.get(ctx.channel().id().asLongText());
  195. if (handshaker == null) {
  196. String wsuri = "ws://127.0.0.1:1081" + req.uri();
  197. WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory(wsuri, null, true);
  198. handshaker = wsFactory.newHandshaker(req);
  199. Constant.handShakerMap.put(ctx.channel().id().asLongText(), handshaker);
  200. // 在这里处理用户登录;
  201. handleWsLogin(ctx, req.uri(), req);
  202. }
  203. if (handshaker == null) {
  204. // 不支持
  205. WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel());
  206. } else {
  207. handshaker.handshake(ctx.channel(), req);
  208. }
  209. }
  210. private void handleWsLogin(ChannelHandlerContext ctx, final String requestURI, FullHttpRequest req) {
  211. WebsocketThreadPool.addTask(new Runnable() {
  212. @Override
  213. public void run() {
  214. logger.debug(String.format("websocket OnOpen, getId:%s , remoteAddress: %s.",
  215. ctx.channel().id().asLongText(), ctx.channel().remoteAddress().toString()));
  216. MessageResponse replyMsg = new MessageResponse();
  217. String queryString = requestURI;
  218. if (!queryString.contains("?")) {
  219. queryString = "";
  220. } else {
  221. String[] tmpArray = queryString.split("\\?");
  222. if (tmpArray.length == 2) {
  223. queryString = tmpArray[1];
  224. }
  225. }
  226. Map<String, String> params = CommonUtils.processRequestParameter(queryString);
  227. String token = params.get("loginToken");
  228. logger.info("{} recv login request.", token);
  229. Map<String, String> loginMap = CommonUtils.validateToken(token, ""); // Boolean.parseBoolean("true");
  230. if (null == loginMap) {
  231. replyMsg.setStatus(400);
  232. replyMsg.setMsg("token verify failed.");
  233. ctx.writeAndFlush(new TextWebSocketFrame(replyMsg.toString()));
  234. ctx.close();
  235. } else {
  236. String extnum = loginMap.get("extnum");
  237. String opnum = loginMap.get("opnum");
  238. String skillLevel = loginMap.get("skillLevel");
  239. String groupId = loginMap.get("groupId");
  240. String tips = String.format("successfully decode loginToken, extnum=%s, opnum=%s, skillLevel=%s, groupId=%s",
  241. extnum, opnum, skillLevel, groupId
  242. );
  243. logger.info(tips);
  244. if (StringUtils.isNullOrEmpty(extnum) ||
  245. StringUtils.isNullOrEmpty(opnum) ||
  246. StringUtils.isNullOrEmpty(skillLevel) ||
  247. StringUtils.isNullOrEmpty(groupId) ) {
  248. replyMsg.setStatus(400);
  249. replyMsg.setMsg(tips + " ; parameter missing, (extnum、 opnum、 skillLevel、groupId) 至少有一个为空... ");
  250. ctx.writeAndFlush(new TextWebSocketFrame(replyMsg.toString()));
  251. ctx.close();
  252. return;
  253. }
  254. String traceId = String.format("%s-%s:", opnum, extnum);
  255. JSONObject jsonObject = new JSONObject();
  256. jsonObject.put("extnum", extnum);
  257. jsonObject.put("opnum", opnum);
  258. jsonObject.put("groupId", groupId);
  259. List<BizGroup> groups = AppContextProvider.getBean(SysService.class).getAllGroupList();
  260. jsonObject.put("groups", groups);
  261. replyMsg.setStatus(200);
  262. replyMsg.setObject(jsonObject);
  263. // 优先从 Nginx 传递的 header 获取真实 IP
  264. String clientIP = req.headers().get("X-Real-IP");
  265. if (clientIP == null || clientIP.isEmpty()) {
  266. clientIP = req.headers().get("X-Forwarded-For");
  267. // X-Forwarded-For 可能是多个 IP,取第一个
  268. if (clientIP != null && clientIP.contains(",")) {
  269. clientIP = clientIP.split(",")[0].trim();
  270. }
  271. }
  272. // 兜底:如果 header 为空,再用 remoteAddress
  273. if (clientIP == null || clientIP.isEmpty()) {
  274. String remoteAddr = ctx.channel().remoteAddress().toString();
  275. clientIP = CommonUtils.getIpFromFullAddress(remoteAddr);
  276. }
  277. logger.info("real client IP: {}", clientIP);
  278. // String remoteAddr = ctx.channel().remoteAddress().toString();
  279. // String clientIP = CommonUtils.getIpFromFullAddress(remoteAddr);
  280. SessionEntity sessionEntity = new SessionEntity();
  281. sessionEntity.setClientIp(clientIP);
  282. sessionEntity.setExtNum(extnum);
  283. sessionEntity.setOpNum(opnum);
  284. sessionEntity.setSessionId(ctx.channel().id().asLongText());
  285. sessionEntity.setLastActiveTime(System.currentTimeMillis());
  286. sessionEntity.setSkillLevel(Integer.parseInt(skillLevel));
  287. sessionEntity.setLoginTime(System.currentTimeMillis());
  288. sessionEntity.setGroupId(groupId);
  289. boolean addSessionOk = SessionManager.getInstance().add(sessionEntity); // 添加到会话管理
  290. if (!addSessionOk) {
  291. logger.error("{} failed to add current session.", traceId);
  292. return;
  293. }
  294. SecurityManager.getInstance().addClientIpToFirewallWhiteList(clientIP);
  295. logger.info("{} successfully add current session.", traceId);
  296. MessageHandlerEngine myEngine = new MessageHandlerEngine(ctx);
  297. logger.info("{} successfully create MsgEngine for current user.", traceId);
  298. if (!MessageHandlerEngineList.getInstance().add(myEngine)) {
  299. logger.error("{} failed to add MsgEngine to SysList.", traceId);
  300. }
  301. myEngine.initSession(sessionEntity);// 初始化session信息
  302. myEngine.sendReplyToAgent(replyMsg);
  303. }
  304. }
  305. });
  306. }
  307. public static void sendHttpResponse(ChannelHandlerContext ctx, FullHttpRequest req, DefaultFullHttpResponse res) {
  308. // 返回应答给客户端
  309. if (res.status().code() != 200) {
  310. ByteBuf buf = Unpooled.copiedBuffer(res.status().toString(), CharsetUtil.UTF_8);
  311. res.content().writeBytes(buf);
  312. buf.release();
  313. }
  314. // 如果是非Keep-Alive,关闭连接
  315. ChannelFuture f = ctx.channel().writeAndFlush(res);
  316. if (!isKeepAlive(req) || res.status().code() != 200) {
  317. f.addListener(ChannelFutureListener.CLOSE);
  318. }
  319. }
  320. private static boolean isKeepAlive(FullHttpRequest req) {
  321. return false;
  322. }
  323. // 异常处理,netty默认是关闭channel
  324. @Override
  325. public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
  326. if(cause instanceof java.io.IOException) {
  327. logger.info("netty IOException: {}", cause.toString());
  328. }else{
  329. logger.error("netty exceptionCaught: {}", cause.toString());
  330. }
  331. MessageHandlerEngineList.getInstance().delete(ctx.channel().id().asLongText(), true);
  332. ctx.close();
  333. }
  334. }