Netty introduction
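Netty is an open-source Java framework originally provided by JBoss. It offers an asynchronous, event-driven network application framework and tools for quickly building high-performance, highly reliable network servers and clients. Netty is an asynchronous communication framework built on Java NIO. Its main appeal is simplicity: it is much easier to work with than raw Java NIO, and it ships with many ready-made components. Source code: https://github.com/netty/netty. Below, we adapt the WebSocket server demo from Netty's official examples and embed it in a Spring Boot project. Let's get started.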
Adding the Netty dependency
Add the Netty dependency to pom.xml:
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-all</artifactId>
<version>4.1.77.Final</version>
<type>pom</type>
</dependency>
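JWT dependency (jjwt 0.9.1 uses javax.xml.bind.DatatypeConverter, which JDK 11 and later no longer include, so on those JDKs you may also need a JAXB API dependency such as javax.xml.bind:jaxb-api)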
<dependency>
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt</artifactId>
<version>0.9.1</version>
</dependency>
Socket server main class
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.ApplicationArguments;
import org.springframework.boot.ApplicationRunner;
import org.springframework.stereotype.Component;
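@Slf4j
@Component // register as a Spring bean so that Spring Boot actually invokes this ApplicationRunner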
public class WebSocketServer implements ApplicationRunner {
@Value("${lk.socket.port:7011}")
private Integer socketPort;
@Override
public void run(ApplicationArguments args) throws Exception {
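// Reactor thread pools
// Boss group: a single thread that accepts incoming connections
EventLoopGroup bossGroup = new NioEventLoopGroup(1);
// Worker group: handles I/O for the accepted connections
EventLoopGroup workerGroup = new NioEventLoopGroup();
try {
ServerBootstrap b = new ServerBootstrap();
b.group(bossGroup, workerGroup)
.channel(NioServerSocketChannel.class) // accept new connections with an NIO selector
.handler(new LoggingHandler(LogLevel.INFO)) // log the server (boss) channel's events at INFO level
// The bootstrap can also set TCP parameters, and the boss and worker sides can be tuned separately to optimize server performance:
// option() applies to the listening (boss) channel, childOption() to each accepted child channel (both need import io.netty.channel.ChannelOption).
// SO_BACKLOG: maximum length of the queue of connections waiting on the listening socket. It is classically described as the
// incomplete-connection queue (three-way handshake not finished) plus the completed-connection queue; on Linux it bounds only the completed queue
//.option(ChannelOption.SO_BACKLOG, 5)
// SO_KEEPALIVE: enables TCP keepalive, a TCP-level heartbeat; the OS default idle time before probing is typically 7200 s
//.childOption(ChannelOption.SO_KEEPALIVE, true)
.childHandler(new WebSocketServerInitializer()); // our own business handlers for each accepted connection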
Channel ch = b.bind(socketPort).sync().channel();
log.info("websocket server has been started at port {}",socketPort);
ch.closeFuture().sync();
} finally {
bossGroup.shutdownGracefully();
workerGroup.shutdownGracefully();
}
}
}
Most of this is boilerplate. Only the WebSocketServerInitializer class carries our own business logic.
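One consequence of the fixed pattern above: `ch.closeFuture().sync()` blocks until the server channel is closed, so `run()` never returns. Spring Boot calls its runners one after another on the startup thread, so `SpringApplication.run()` does not return either, and any runner ordered after this one is never called. A common workaround is to move the blocking part onto its own thread. The sketch below shows that pattern with JDK classes only. `BackgroundServer`, `start()` and `stop()` are hypothetical names, and the `CountDownLatch` merely stands in for Netty's close future.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

/**
 * Hypothetical sketch: run a blocking server loop on its own thread so the caller
 * (for example ApplicationRunner.run) returns immediately. The latch stands in for
 * the server channel's close future; this is not Netty code.
 */
final class BackgroundServer {
    private final CountDownLatch closed = new CountDownLatch(1);
    private volatile Thread serverThread;

    /** Starts the server thread once and returns without waiting for it. */
    synchronized void start() {
        if (serverThread != null) {
            return; // already started
        }
        Thread t = new Thread(() -> {
            try {
                // Real server: bind the port here, then block on closeFuture().sync()
                closed.await();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
            // Real server: shut down the boss and worker groups here
        }, "websocket-server");
        t.start();
        serverThread = t;
    }

    boolean isRunning() {
        Thread t = serverThread;
        return t != null && t.isAlive();
    }

    /** Releases the blocking call and waits up to the timeout; returns true if the thread ended. */
    boolean stop(long timeout, TimeUnit unit) {
        closed.countDown();
        Thread t = serverThread;
        if (t == null) {
            return true;
        }
        try {
            t.join(unit.toMillis(timeout));
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        }
        return !t.isAlive();
    }
}
```

In a runner, `run()` would call `start()`, and a shutdown callback (for example a `@PreDestroy` method) would close the server channel and call `stop(...)`. The server thread is a non-daemon thread, so it keeps the JVM alive the same way the blocking call did.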
WebSocketServerInitializer class
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpServerCodec;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketServerCompressionHandler;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.timeout.IdleStateHandler;
/**
*/
public class WebSocketServerInitializer extends ChannelInitializer<SocketChannel> {
public static final String WEBSOCKET_PATH = "/websocket";
@Override
public void initChannel(SocketChannel ch) throws Exception {
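// Chain of responsibility: add the handlers that process this connection, in order
ChannelPipeline pipeline = ch.pipeline();
// Idle detection (reader 600 s, writer 600 s, all 3600 s). Added first so it sees every inbound read and
// outbound write; placed after WebSocketFrameHandler, which consumes the frames, it would never observe reads
pipeline.addLast(new IdleStateHandler(600, 600, 3600));
// Handles the idle events fired by IdleStateHandler (events travel towards the tail, so it must come after it)
pipeline.addLast(new HeartBeatHandler());
pipeline.addLast(new HttpServerCodec()); // HTTP codec: the WebSocket handshake is an HTTP upgrade request
pipeline.addLast(new HttpObjectAggregator(65536)); // aggregates chunked HTTP messages into a complete FullHttpRequest, max content length 65536 bytes
pipeline.addLast(new WebSocketServerCompressionHandler()); // WebSocket compression extension
pipeline.addLast(new WebSocketSecurityHandler()); // authentication handler, optional depending on business needs
// WebSocket handshake and control frames (ping, pong, close)
pipeline.addLast(new WebSocketServerProtocolHandler(WEBSOCKET_PATH, null, true));
// Business handler for the channel's frames
pipeline.addLast(new WebSocketFrameHandler());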
}
}
HttpServerCodec, HttpObjectAggregator, WebSocketServerCompressionHandler, WebSocketServerProtocolHandler and IdleStateHandler all ship with Netty. WebSocketSecurityHandler, WebSocketFrameHandler and HeartBeatHandler are handlers we implement for our own business. Let's go through them one by one. Show me the code.
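To see why handler order matters, here is a toy model of an inbound pipeline in plain Java. It is not Netty's API: `ToyPipeline`, `Handler` and `Context` are hypothetical, and the model only covers inbound messages. Each handler either forwards the message with `ctx.fire(...)` or consumes it, which is what a `SimpleChannelInboundHandler` such as `WebSocketFrameHandler` does with the frames it handles. Anything placed after a consuming handler never sees those messages.

```java
import java.util.ArrayList;
import java.util.List;

/** Toy model of an inbound chain of responsibility; hypothetical, not Netty's API. */
final class ToyPipeline {
    /** A handler either forwards the message with ctx.fire(msg) or consumes it. */
    interface Handler {
        void read(Context ctx, Object msg);
    }

    /** Gives a handler access to the next handler in the chain. */
    final class Context {
        private final int index;

        private Context(int index) {
            this.index = index;
        }

        void fire(Object msg) {
            dispatch(index + 1, msg);
        }
    }

    private final List<Handler> handlers = new ArrayList<>();

    ToyPipeline addLast(Handler handler) {
        handlers.add(handler);
        return this;
    }

    /** Entry point for an inbound message, like bytes arriving from the socket. */
    void read(Object msg) {
        dispatch(0, msg);
    }

    private void dispatch(int i, Object msg) {
        if (i < handlers.size()) {
            handlers.get(i).read(new Context(i), msg);
        }
        // Past the last handler the message is simply dropped in this model
    }
}
```

For example, with a forwarding "decoder", a consuming "business" handler and a forwarding third handler, a message reaches only the first two. The third handler is in the position of any handler added after WebSocketFrameHandler. In Netty the same rule holds, although the real API (`ChannelHandlerContext.fireChannelRead` and friends) differs from this model.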
WebSocketSecurityHandler: authentication check
Why do we need an authentication check?
- To verify that the client is legitimate, so illegitimate clients can be rejected already during the handshake
- The client arrives with a token that tells us which business client it is, and that client then has to be bound to our business data
WebSocketSecurityHandler performs this check while the HTTP connection is being upgraded to the WebSocket protocol.
import cn.hutool.core.map.MapUtil;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.*;
import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler;
import lombok.extern.slf4j.Slf4j;
import java.util.Map;
import java.util.Objects;
@Slf4j
public class WebSocketSecurityHandler extends SimpleChannelInboundHandler<FullHttpRequest> {
@Override
protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest request) {
// Parse the query string sent by the client. Example connection URL: ws://127.0.0.1:7011/websocket?token=eyJhbGciOiJIUzUxMiJ9.eyJleHAiOjE2NTQ2NjA1MTEsInVzZXJJZCI6Ind1Z2FuZ2xpIn0.e17O09aD8WrzfwA7UGkwVIByQGuElKyFsAZlrYueH55FiCUjgLcDmXPz7nAuyPfUpswQKPVCC9lx5q0hXdJAeQ
// The connection URL is issued to the client by a server-side business API
Map<String, String> paramMap = getUrlParams(request.uri());
String token = paramMap.get("token");
// The token is a JWT generated by JwtUtil
String userId = JwtUtil.validateToken(token);
if (userId != null) {
log.info("token validated, user id: {}", userId);
// Strip the query string: by default WebSocketServerProtocolHandler compares the URI with WEBSOCKET_PATH exactly
request.setUri(WebSocketServerInitializer.WEBSOCKET_PATH);
// Put the client's business id into a header so it can be bound later; alternatively store it on the channel: ctx.channel().attr(AttributeKey).set(userId);
request.headers().set("userId", userId);
// Validation passed: hand the request to the next handler (retain() because SimpleChannelInboundHandler releases it after channelRead0)
ctx.fireChannelRead(request.retain());
// Only the handshake is checked; later messages no longer pass through this handler
ctx.pipeline().remove(WebSocketSecurityHandler.class);
} else {
log.error("token validation failed");
ctx.close();
}
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
ctx.close();
log.error("error:\r\n" + cause.toString());
}
private static Map<String, String> getUrlParams(String url) {
Map<String, String> map = MapUtil.newHashMap(10);
// String.replace is literal, so "?" needs no escaping here; split() below takes a regex
url = url.replace(SeparatorEnum.QUESTION.getCode(), SeparatorEnum.SEMICOLON.getCode());
String[] parts = url.split(SeparatorEnum.SEMICOLON.getCode());
// No query string ("/websocket" or "/websocket?"): nothing to parse
if (parts.length < 2) {
return map;
}
for (String s : parts[1].split(SeparatorEnum.AND.getCode())) {
// Limit 2: keep any "=" inside the value and tolerate a key without "="
String[] kv = s.split(SeparatorEnum.EQUALS.getCode(), 2);
map.put(kv[0], kv.length > 1 ? kv[1] : "");
}
return map;
}
}
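Pulling the token out of the URI boils down to parsing the query string, which `getUrlParams` above does by hand. For comparison, here is a JDK-only sketch that also URL-decodes keys and values, tolerates a missing query string or a parameter without `=`, and keeps the first value of a repeated key. `QueryParams.parse` is a hypothetical name. Inside a Netty handler, `io.netty.handler.codec.http.QueryStringDecoder` (`new QueryStringDecoder(request.uri()).parameters()`) is the ready-made equivalent.

```java
import java.io.UnsupportedEncodingException;
import java.net.URLDecoder;
import java.util.LinkedHashMap;
import java.util.Map;

/** Hypothetical helper: parse "path?k1=v1&k2=v2" into a map of decoded parameters. */
final class QueryParams {
    private QueryParams() {
    }

    static Map<String, String> parse(String uri) {
        Map<String, String> params = new LinkedHashMap<>();
        int q = uri.indexOf('?');
        if (q < 0 || q == uri.length() - 1) {
            return params; // no query string at all
        }
        for (String pair : uri.substring(q + 1).split("&")) {
            if (pair.isEmpty()) {
                continue; // e.g. "a=1&&b=2"
            }
            int eq = pair.indexOf('=');
            String key = decode(eq < 0 ? pair : pair.substring(0, eq));
            String value = eq < 0 ? "" : decode(pair.substring(eq + 1));
            params.putIfAbsent(key, value); // keep the first value of a repeated key
        }
        return params;
    }

    private static String decode(String s) {
        try {
            return URLDecoder.decode(s, "UTF-8");
        } catch (UnsupportedEncodingException e) {
            throw new IllegalStateException("UTF-8 is always supported", e);
        }
    }
}
```

A JWT only uses Base64url characters and dots, so decoding leaves it unchanged. Decoding matters for other parameters, such as a name sent as `a%20b`.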
WebSocketFrameHandler: WebSocket channel handler
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.http.*;
import io.netty.handler.codec.http.websocketx.*;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class WebSocketFrameHandler extends SimpleChannelInboundHandler<WebSocketFrame> {
@Override
public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
log.info("new connection");
}
@Override
public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
String userId = SocketUtil.getChannelClientId(ctx.channel());
log.info("disconnect: {}", userId);
ChannelHolder.removeChannel(userId);
}
@Override
protected void channelRead0(ChannelHandlerContext ctx, WebSocketFrame frame) throws Exception {
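// With the default WebSocketServerProtocolHandler settings, close and ping frames are answered by the
// protocol handler itself and do not reach this handler; these two branches are only a fallback
if (frame instanceof CloseWebSocketFrame) {
CloseWebSocketFrame closeFrame = (CloseWebSocketFrame) frame;
log.info("client id: {},close reason:{} ", ctx.channel().id().asShortText(), closeFrame.reasonText());
return;
}
if (frame instanceof PingWebSocketFrame) {
// writeAndFlush: a plain write() only buffers the frame until something else flushes the channel
ctx.writeAndFlush(new PongWebSocketFrame(frame.content().retain()));
return;
}
if (frame instanceof TextWebSocketFrame) {
String request = ((TextWebSocketFrame) frame).text();
log.info("client_id:{} channel id:{} input:{}", SocketUtil.getChannelClientId(ctx.channel()), ctx.channel().id().asShortText(), request);
// Heartbeat packets always get "ok" as the reply
if ("Heartbeat Packet".equals(request)) {
ctx.channel().writeAndFlush(new TextWebSocketFrame("ok"));
} else {
// Non-heartbeat messages: extend according to your own business
ctx.channel().writeAndFlush(new TextWebSocketFrame("hello js!!!"));
}
return;
}
if (frame instanceof BinaryWebSocketFrame) {
// Echo binary frames back; retain() because SimpleChannelInboundHandler releases the frame after channelRead0
ctx.writeAndFlush(frame.retain());
}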
}
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
SocketChannel channel = (SocketChannel) ctx.channel();
// Handshake-complete event: the HTTP connection has been upgraded to the WebSocket protocol
if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
WebSocketServerProtocolHandler.HandshakeComplete handshakeCompletedEvent = (WebSocketServerProtocolHandler.HandshakeComplete) evt;
// Headers of the handshake request (userId was added by WebSocketSecurityHandler)
HttpHeaders headers = handshakeCompletedEvent.requestHeaders();
String userId = headers.get("userId");
channel.attr(SocketUtil.userIdKey).set(userId);
log.info("client id: {},HandshakeComplete", userId);
log.info("request headers:{}", headers);
ChannelHolder.addChannel(userId, channel);
} else {
// Pass any other user event on to the handlers after this one
super.userEventTriggered(ctx, evt);
}
}
@Override
public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
ctx.close();
log.error("error:\r\n" + cause.toString());
}
}
HeartBeatHandler: heartbeat handler
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.timeout.IdleState;
import io.netty.handler.timeout.IdleStateEvent;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class HeartBeatHandler extends ChannelInboundHandlerAdapter {
@Override
public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
if (evt instanceof IdleStateEvent) {
IdleStateEvent idleStateEvent = (IdleStateEvent) evt;
Channel channel = ctx.channel();
String clientId = SocketUtil.getChannelClientId(channel);
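// Reader idle: nothing was read within readerIdleTime
if (idleStateEvent.state() == IdleState.READER_IDLE) {
log.info("client:{} READER_IDLE...", clientId);
// Writer idle: nothing was written within writerIdleTime
} else if (idleStateEvent.state() == IdleState.WRITER_IDLE) {
log.info("client:{} WRITER_IDLE...", clientId);
// All idle: neither read nor write within allIdleTime, so close the connection
} else if (idleStateEvent.state() == IdleState.ALL_IDLE) {
log.info("client:{} ALL_IDLE...", clientId);
SocketUtil.closeChannel(channel);
}
} else {
// Not an idle event: pass it on to the next handler
super.userEventTriggered(ctx, evt);
}
}
}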
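The reader-idle, writer-idle and all-idle events are triggered according to the idle times configured with
// Heartbeat idle detection settings
pipeline.addLast(new IdleStateHandler(600,600,3600));
The three arguments are readerIdleTimeSeconds, writerIdleTimeSeconds and allIdleTimeSeconds. Put the corresponding business logic in the event handler.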
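How the three timeouts interact can be sketched with a simplified model of the bookkeeping: reader idle compares the time since the last read, writer idle the time since the last write, and all idle the time since the last read or write, each against its own threshold. This is a hypothetical model (`IdleModel`, times in seconds), not Netty's code. The real `IdleStateHandler` schedules timer tasks on the channel's event loop rather than being polled.

```java
import java.util.EnumSet;
import java.util.Set;

/** Simplified, hypothetical model of which idle states apply at a given moment (times in seconds). */
final class IdleModel {
    enum State { READER_IDLE, WRITER_IDLE, ALL_IDLE }

    private final long readerIdle;
    private final long writerIdle;
    private final long allIdle;
    private long lastRead;
    private long lastWrite;

    /** A threshold of 0 disables that check, as with IdleStateHandler. */
    IdleModel(long readerIdle, long writerIdle, long allIdle, long now) {
        this.readerIdle = readerIdle;
        this.writerIdle = writerIdle;
        this.allIdle = allIdle;
        this.lastRead = now;
        this.lastWrite = now;
    }

    void onRead(long now) {
        lastRead = now;
    }

    void onWrite(long now) {
        lastWrite = now;
    }

    Set<State> idleStates(long now) {
        Set<State> states = EnumSet.noneOf(State.class);
        if (readerIdle > 0 && now - lastRead >= readerIdle) {
            states.add(State.READER_IDLE);
        }
        if (writerIdle > 0 && now - lastWrite >= writerIdle) {
            states.add(State.WRITER_IDLE);
        }
        // "All idle" resets on either kind of activity
        if (allIdle > 0 && now - Math.max(lastRead, lastWrite) >= allIdle) {
            states.add(State.ALL_IDLE);
        }
        return states;
    }
}
```

With `new IdleModel(600, 600, 3600, 0)`, a silent connection is reader- and writer-idle at t = 600 s, but a single read at t = 3000 s postpones all-idle until t = 6600 s. A client that sends a heartbeat every few seconds therefore never reaches the close in the ALL_IDLE branch, provided the idle handler actually observes those reads.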
ChannelHolder
import io.netty.channel.socket.SocketChannel;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
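// Manages client channel connections, keyed by client (user) id
public class ChannelHolder {
public static Map<String, SocketChannel> channelMap = new ConcurrentHashMap<>();
public static void addChannel(String clientId, SocketChannel channel){
// ConcurrentHashMap rejects null keys: skip connections that carry no client id
if (clientId == null) {
return;
}
channelMap.put(clientId, channel);
}
public static Map<String, SocketChannel> getChannels(){
return channelMap;
}
public static SocketChannel getChannel(String clientId){
return clientId == null ? null : channelMap.get(clientId);
}
public static void removeChannel(String clientId){
// The id is null when a connection closes before its handshake completed (e.g. a rejected token)
if (clientId == null) {
return;
}
channelMap.remove(clientId);
}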
public static int getSize(){
return channelMap.size();
}
}
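One subtlety with a map keyed by user id: if a user reconnects and the new channel's handshake completes before the old channel's `handlerRemoved` runs, `removeChannel(userId)` removes the new connection. `ConcurrentHashMap.remove(key, value)` only removes the entry while it still maps to the given value. The registry below is a hypothetical, Netty-free sketch of that idea. `ConnectionRegistry` is an invented name, and the type parameter `C` stands in for `SocketChannel`.

```java
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;

/** Hypothetical registry sketch; C stands in for the channel type (for example SocketChannel). */
final class ConnectionRegistry<C> {
    private final Map<String, C> channels = new ConcurrentHashMap<>();

    /** Registers a connection; returns the connection it replaced, or null. */
    C add(String clientId, C channel) {
        return channels.put(Objects.requireNonNull(clientId), Objects.requireNonNull(channel));
    }

    /** Removes the entry only while it still points at this exact connection. */
    boolean remove(String clientId, C channel) {
        return clientId != null && channel != null && channels.remove(clientId, channel);
    }

    C get(String clientId) {
        return clientId == null ? null : channels.get(clientId);
    }

    int size() {
        return channels.size();
    }
}
```

Applied to the classes in this article, that would mean passing `ctx.channel()` along with the user id wherever a connection is removed, in `handlerRemoved` and in `SocketUtil.closeChannel`.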
JsonUtil: JSON utility class
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.SerializationFeature;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.text.SimpleDateFormat;
import java.util.TimeZone;
public class JsonUtil {
private static final Logger logger = LoggerFactory.getLogger(JsonUtil.class);
public static String toString(Object o) {
try {
return (getMapper().writeValueAsString(o));
} catch (Exception e) {
logger.error("Error writing json object: {}", e.getMessage());
}
return "";
}
public static <T> T fromString(String s, Class<T> cls) {
try {
return getMapper().readValue(s, cls);
} catch (Exception e) {
logger.error("Error parse string to json object: {}", e.getMessage());
}
return null;
}
public static <T> T fromString(String s, TypeReference<T> typeReference) {
try {
return getMapper().readValue(s, typeReference);
} catch (Exception e) {
logger.error("Error parse string to json object: {}", e.getMessage());
}
return null;
}
// Created once when the class is initialized, so every thread sees the same configured instance
private static final ObjectMapper mapper = new ObjectMapper();
static {
mapper.setTimeZone(TimeZone.getTimeZone("GMT+8"));
}
public static ObjectMapper getMapper() {
return mapper;
}
public static void configureDateFormatString() {
getMapper().configure(SerializationFeature.WRITE_DATES_AS_TIMESTAMPS, false);
getMapper().setDateFormat(new SimpleDateFormat("yyyy-MM-dd HH:mm:ss"));
}
public static void configureTimeZone(TimeZone timeZone) {
getMapper().setTimeZone(timeZone);
}
}
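A lazily created shared object such as the `ObjectMapper` behind `getMapper()` needs safe publication. With an unsynchronized `if (mapper == null)` check, two threads can both see `null` and build separate instances, and a `configure...()` call may land on an instance that is then replaced. If lazy creation is really wanted, the initialization-on-demand holder idiom provides it without locks, because the JVM initializes a class exactly once. The sketch uses hypothetical names, with `Expensive` standing in for `ObjectMapper`.

```java
import java.util.concurrent.atomic.AtomicInteger;

/** Hypothetical demo of the initialization-on-demand holder idiom. */
final class LazyHolderDemo {
    /** Counts constructor calls so the demo can show there is only one. */
    static final AtomicInteger CREATED = new AtomicInteger();

    /** Stand-in for an expensive shared object such as Jackson's ObjectMapper. */
    static final class Expensive {
        Expensive() {
            CREATED.incrementAndGet();
        }
    }

    private LazyHolderDemo() {
    }

    /** Initialized on the first call to get(); class initialization is thread-safe and happens once. */
    private static final class Holder {
        static final Expensive INSTANCE = new Expensive();
    }

    static Expensive get() {
        return Holder.INSTANCE;
    }
}
```

Eager initialization in a `static final` field, as in the `getMapper()` above, is simpler and usually enough. The holder idiom only pays off when creating the object is costly and it may never be needed.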
JwtUtil
package com.mdkw.likang.websocket;
import java.util.Date;
import java.util.HashMap;
import java.util.Map;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Configuration;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.SignatureAlgorithm;
public class JwtUtil {
private static long EXPIRATION_TIME = 3600000L;
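// Demo secret only: load it from configuration in production. jjwt 0.9.1's String-based signWith()/setSigningKey() treat this value as a Base64-encoded key
private static String SECRET = "#V8o8cpr&xql&@uP";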
/**
* Generate a JWT carrying the userId claim
*
* @param userId
* @return
*/
public static String generateToken(String userId) {
HashMap<String, Object> map = new HashMap<>();
// you can put any data in the map
map.put("userId", userId);
String jwt = Jwts.builder().setClaims(map).setExpiration(new Date(System.currentTimeMillis() + EXPIRATION_TIME))
.signWith(SignatureAlgorithm.HS512, SECRET).compact();
return jwt;
}
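/**
* Validate a JWT
*
* @param token the token sent by the client
* @return the userId claim, or null if the token is missing, invalid, expired or has no userId
*/
public static String validateToken(String token) {
if (token == null || token.isEmpty()) {
return null;
}
try {
Map<String, Object> body = Jwts.parser().setSigningKey(SECRET).parseClaimsJws(token).getBody();
String userId = (String) body.get("userId");
return (userId == null || userId.isEmpty()) ? null : userId;
} catch (io.jsonwebtoken.JwtException | IllegalArgumentException e) {
// Bad signature, malformed or expired token: treat the client as unauthenticated
return null;
}
}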
public static long getEXPIRATION_TIME() {
return JwtUtil.EXPIRATION_TIME;
}
static class TokenValidationException extends RuntimeException {
/**
*
*/
private static final long serialVersionUID = -7946690694369283250L;
public TokenValidationException(String msg) {
super(msg);
}
}
}
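To make the token format concrete: the sample token used in this article decodes to the header `{"alg":"HS512"}` and the claims `{"exp":1654660511,"userId":"wugangli"}`, and its third part is an HMAC-SHA512 signature over the first two Base64url-encoded parts. The JDK-only sketch below builds and checks tokens of that shape. It is a simplified illustration, not a replacement for jjwt. `TinyJwt` is hypothetical. It only handles a `userId` string claim (assumed to need no JSON escaping) plus `exp` in seconds. It also uses the UTF-8 bytes of the secret as the key, whereas jjwt 0.9.1's `signWith(SignatureAlgorithm, String)` and `setSigningKey(String)` treat the string as Base64-encoded, so tokens from the two will not interoperate.

```java
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.security.MessageDigest;
import java.util.Base64;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;

/** Hypothetical, simplified HS512 JWT sketch (userId and exp claims only); not jjwt-compatible. */
final class TinyJwt {
    private static final Base64.Encoder ENC = Base64.getUrlEncoder().withoutPadding();
    private static final Base64.Decoder DEC = Base64.getUrlDecoder();
    private static final Pattern USER_ID = Pattern.compile("\"userId\":\"([^\"]*)\"");
    private static final Pattern EXP = Pattern.compile("\"exp\":(\\d+)");

    private final byte[] key;

    TinyJwt(String secret) {
        // Assumption: the raw UTF-8 bytes of the secret are the HMAC key
        this.key = secret.getBytes(StandardCharsets.UTF_8);
    }

    /** Builds header.claims.signature; assumes userId needs no JSON escaping. */
    String generate(String userId, long expEpochSeconds) {
        String header = ENC.encodeToString("{\"alg\":\"HS512\"}".getBytes(StandardCharsets.UTF_8));
        String claims = ENC.encodeToString(
                ("{\"exp\":" + expEpochSeconds + ",\"userId\":\"" + userId + "\"}").getBytes(StandardCharsets.UTF_8));
        String signingInput = header + "." + claims;
        return signingInput + "." + ENC.encodeToString(hmac(signingInput));
    }

    /** Returns the userId, or null if the token is malformed, tampered with or expired. */
    String validate(String token, long nowEpochSeconds) {
        if (token == null) {
            return null;
        }
        String[] parts = token.split("\\.");
        if (parts.length != 3) {
            return null;
        }
        try {
            byte[] expected = hmac(parts[0] + "." + parts[1]);
            // Constant-time comparison, so timing does not reveal how many bytes matched
            if (!MessageDigest.isEqual(expected, DEC.decode(parts[2]))) {
                return null;
            }
            String claims = new String(DEC.decode(parts[1]), StandardCharsets.UTF_8);
            Matcher exp = EXP.matcher(claims);
            if (!exp.find() || Long.parseLong(exp.group(1)) <= nowEpochSeconds) {
                return null; // no expiry claim, or already expired
            }
            Matcher user = USER_ID.matcher(claims);
            return user.find() && !user.group(1).isEmpty() ? user.group(1) : null;
        } catch (IllegalArgumentException e) {
            return null; // invalid Base64url or an out-of-range number
        }
    }

    private byte[] hmac(String data) {
        try {
            Mac mac = Mac.getInstance("HmacSHA512");
            mac.init(new SecretKeySpec(key, "HmacSHA512"));
            return mac.doFinal(data.getBytes(StandardCharsets.UTF_8));
        } catch (GeneralSecurityException e) {
            throw new IllegalStateException("HmacSHA512 unavailable", e);
        }
    }
}
```

The signature is checked before the claims are read, so a client cannot change its `userId` or push `exp` forward without knowing the secret. This is the property `WebSocketSecurityHandler` relies on when it trusts the `userId` from a validated token.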
SeparatorEnum
import lombok.AllArgsConstructor;
import lombok.Getter;
@Getter
@AllArgsConstructor
public enum SeparatorEnum {
/**
* Comma
*/
COMMA(","),
SLASH("/"),
LINE("\\|"),
QUESTION("?"),
SEMICOLON(";"),
EQUALS("="),
POUND("#"),
MINUS("-"),
AND("&"),
UNDERLINE("_"),
SPOT("."),
SPOT_E("\\."),
APOSTROPHE("'"),
PERCENTAGE("%"),
GT(">")
;
/**
* Value
*/
public final String code;
}
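The enum keeps both `SPOT(".")` and `SPOT_E("\\.")`, and stores `LINE` as `"\\|"`, because `String.split` takes a regular expression in which `.` and `|` are metacharacters, while `String.replace` treats its argument literally. An unescaped `"?"` passed to `split` would even throw a `PatternSyntaxException`. The JDK-only snippet below demonstrates the difference. `Pattern.quote` is an alternative to hand-escaping, and `SplitDemo` is a hypothetical name.

```java
import java.util.Arrays;
import java.util.regex.Pattern;

/** Hypothetical demo: String.split takes a regex, String.replace a literal. */
final class SplitDemo {
    private SplitDemo() {
    }

    /** Splits on a literal separator, whatever characters it contains. */
    static String[] splitLiteral(String s, String separator) {
        return s.split(Pattern.quote(separator));
    }

    public static void main(String[] args) {
        // "." as a regex matches every character, so all pieces are empty and get dropped: prints []
        System.out.println(Arrays.toString("a.b.c".split(".")));
        // Escaped, it splits on the literal dot: prints [a, b, c]
        System.out.println(Arrays.toString("a.b.c".split("\\.")));
        // Quoted, "|" is literal too: prints [a, b]
        System.out.println(Arrays.toString(splitLiteral("a|b", "|")));
        // replace() is literal, which is why getUrlParams can pass "?" unescaped: prints /websocket;token=x
        System.out.println("/websocket?token=x".replace("?", ";"));
    }
}
```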
SocketUtil: utility wrapper for business code to call
package com.mdkw.likang.websocket;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.util.AttributeKey;
import lombok.extern.slf4j.Slf4j;
@Slf4j
public class SocketUtil {
public static AttributeKey<String> userIdKey = AttributeKey.valueOf("userId");
public static <T> void pushMsgToOne(String id, T content) {
SocketChannel channel = ChannelHolder.getChannel(id);
if (channel == null) {
log.error("channel {} does not exist",id);
return;
}
ChannelMsg<T> channelMsg = ChannelMsg.newInstance(content);
send(channel, channelMsg);
}
private static <T> void send(SocketChannel channel, ChannelMsg<T> channelMsg) {
String userId = getChannelClientId(channel);
if (channel.isActive()) {
channel.writeAndFlush(new TextWebSocketFrame(JsonUtil.toString(channelMsg))).addListener((ChannelFutureListener) channelFuture -> {
if (channelFuture.isSuccess()) {
log.info("send msg to user:{} successful,content:{}",userId,JsonUtil.toString(channelMsg));
} else {
log.info("send msg to user:{} failed,content:{}",userId,JsonUtil.toString(channelMsg));
closeChannel(channel);
}
});
} else {
closeChannel(channel);
}
}
public static void closeChannel(Channel channel) {
channel.close();
String userId = getChannelClientId(channel);
ChannelHolder.removeChannel(userId);
}
public static String getChannelClientId(Channel channel) {
return channel.attr(userIdKey).get();
}
public static <T> void pushMsgToAll(T msg) {
ChannelHolder.getChannels().values().forEach((channel) -> {
ChannelMsg<T> channelMsg = ChannelMsg.newInstance(msg);
send(channel, channelMsg);
});
}
}
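`pushMsgToAll` iterates over `ChannelHolder.getChannels().values()`, and `send` may call `closeChannel`, which removes an entry from that same map in the middle of the loop. This is safe only because `channelMap` is a `ConcurrentHashMap`, whose iteration is weakly consistent. With a plain `HashMap`, the same loop ends in a `ConcurrentModificationException`. The JDK-only demo below uses hypothetical names and strings like `"u2:dead"` in place of channels, and shows both behaviors.

```java
import java.util.ConcurrentModificationException;
import java.util.Map;

/** Hypothetical demo: removing map entries while iterating values(), as pushMsgToAll plus closeChannel can. */
final class BroadcastDemo {
    private BroadcastDemo() {
    }

    /** "Sends" to every channel; values look like "u2:dead", and dead ones are removed by key mid-loop. */
    static void broadcast(Map<String, String> channels) {
        channels.values().forEach(channel -> {
            if (channel.endsWith(":dead")) {
                // Like SocketUtil.closeChannel -> ChannelHolder.removeChannel(userId)
                channels.remove(channel.substring(0, channel.indexOf(':')));
            }
        });
    }

    /** Returns true if broadcasting over this map throws ConcurrentModificationException. */
    static boolean failsWithCme(Map<String, String> channels) {
        try {
            broadcast(channels);
            return false;
        } catch (ConcurrentModificationException e) {
            return true;
        }
    }
}
```

If `channelMap` is ever swapped for a non-concurrent map, iterating over a copy such as `new ArrayList<>(map.values())` keeps the broadcast loop safe.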
JS client
<html><head><title>Web Socket Test</title></head>
<body>
<script type="text/javascript">
var socket;
// Id of the heartbeat timer, cleared when the socket closes
var heartbeatTimer;
if (!window.WebSocket) {
window.WebSocket = window.MozWebSocket;
}
if (window.WebSocket) {
// Replace with your real connection URL, usually issued by a business API
socket = new WebSocket("ws://127.0.0.1:7011/websocket?token=eyJhbGciOiJIUzUxMiJ9.eyJleHAiOjE2NTQ2NjA1MTEsInVzZXJJZCI6Ind1Z2FuZ2xpIn0.e17O09aD8WrzfwA7UGkwVIByQGuElKyFsAZlrYueH55FiCUjgLcDmXPz7nAuyPfUpswQKPVCC9lx5q0hXdJAeQ");
socket.onmessage = function(event) {
var ta = document.getElementById('responseText');
ta.value = ta.value + '\n' + event.data;
};
socket.onopen = function(event) {
var ta = document.getElementById('responseText');
ta.value = "Web Socket opened!";
// Send a heartbeat every 5 seconds (pass the function itself rather than a string to evaluate)
heartbeatTimer = setInterval(keepalive, 5000);
};
socket.onclose = function(event) {
var ta = document.getElementById('responseText');
ta.value = ta.value + "\nWeb Socket closed";
// Stop sending heartbeats on a closed socket
clearInterval(heartbeatTimer);
};
} else {
alert("Your browser does not support Web Socket.");
}
function send(message) {
if (!window.WebSocket) { return; }
if (socket.readyState == WebSocket.OPEN) {
socket.send(message);
} else {
alert("The socket is not open.");
}
}
function keepalive(){
var dataContent = "Heartbeat Packet";
// Send the heartbeat; the server replies "ok"
if (socket.readyState == WebSocket.OPEN) {
socket.send(dataContent);
}
}
</script>
<form onsubmit="return false;">
<input type="text" name="message" value="Hello, World!"/><input type="button" value="Send Web Socket Data"
onclick="send(this.form.message.value)" />
<h3>Output</h3>
<textarea id="responseText" style="width:500px;height:300px;"></textarea>
</form>
</body>
</html>