@ -15,13 +15,16 @@
* /
package org.thingsboard.server.controller.plugin ;
import com.github.benmanes.caffeine.cache.Cache ;
import com.github.benmanes.caffeine.cache.Caffeine ;
import com.github.benmanes.caffeine.cache.RemovalCause ;
import lombok.RequiredArgsConstructor ;
import lombok.extern.slf4j.Slf4j ;
import org.apache.commons.lang3.StringUtils ;
import org.springframework.beans.factory.BeanCreationNotAllowedException ;
import org.springframework.beans.factory.annotation.Autowired ;
import org.springframework.beans.factory.annotation.Value ;
import org.springframework.context.annotation.Lazy ;
import org.springframework.security.core.Authentication ;
import org.springframework.stereotype.Service ;
import org.springframework.web.socket.CloseStatus ;
import org.springframework.web.socket.PongMessage ;
@ -29,6 +32,7 @@ import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession ;
import org.springframework.web.socket.adapter.NativeWebSocketSession ;
import org.springframework.web.socket.handler.TextWebSocketHandler ;
import org.thingsboard.common.util.JacksonUtil ;
import org.thingsboard.server.common.data.TenantProfile ;
import org.thingsboard.server.common.data.exception.ThingsboardErrorCode ;
import org.thingsboard.server.common.data.id.CustomerId ;
@ -40,13 +44,20 @@ import org.thingsboard.server.config.WebSocketConfiguration;
import org.thingsboard.server.dao.tenant.TbTenantProfileCache ;
import org.thingsboard.server.dao.util.limits.RateLimitService ;
import org.thingsboard.server.queue.util.TbCoreComponent ;
import org.thingsboard.server.service.security.auth.jwt.JwtAuthenticationProvider ;
import org.thingsboard.server.service.security.model.SecurityUser ;
import org.thingsboard.server.service.security.model.UserPrincipal ;
import org.thingsboard.server.service.subscription.SubscriptionErrorCode ;
import org.thingsboard.server.service.ws.AuthCmd ;
import org.thingsboard.server.service.ws.SessionEvent ;
import org.thingsboard.server.service.ws.WebSocketMsgEndpoint ;
import org.thingsboard.server.service.ws.WebSocketService ;
import org.thingsboard.server.service.ws.WebSocketSessionRef ;
import org.thingsboard.server.service.ws.WebSocketSessionType ;
import org.thingsboard.server.service.ws.WsCommandsWrapper ;
import org.thingsboard.server.service.ws.notification.cmd.NotificationCmdsWrapper ;
import org.thingsboard.server.service.ws.telemetry.cmd.TelemetryCmdsWrapper ;
import org.thingsboard.server.service.ws.telemetry.cmd.v2.AuthCmdUpdate ;
import javax.websocket.RemoteEndpoint ;
import javax.websocket.SendHandler ;
@ -61,6 +72,7 @@ import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap ;
import java.util.concurrent.ConcurrentMap ;
import java.util.concurrent.LinkedBlockingQueue ;
import java.util.concurrent.TimeUnit ;
import java.util.concurrent.atomic.AtomicBoolean ;
import static org.thingsboard.server.service.ws.DefaultWebSocketService.NUMBER_OF_PING_ATTEMPTS ;
@ -68,20 +80,20 @@ import static org.thingsboard.server.service.ws.DefaultWebSocketService.NUMBER_O
@Service
@TbCoreComponent
@Slf4j
@RequiredArgsConstructor
public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocketMsgEndpoint {
private final ConcurrentMap < String , SessionMetaData > internalSessionMap = new ConcurrentHashMap < > ( ) ;
private final ConcurrentMap < String , String > externalSessionMap = new ConcurrentHashMap < > ( ) ;
@Autowired @Lazy
private WebSocketService webSocketService ;
@Autowired
private TbTenantProfileCache tenantProfileCache ;
@Autowired
private RateLimitService rateLimitService ;
@Autowired
private JwtAuthenticationProvider authenticationProvider ;
@Value ( "${server.ws.send_timeout:5000}" )
private long sendTimeout ;
@ -97,16 +109,77 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke
private final ConcurrentMap < UserId , Set < String > > regularUserSessionsMap = new ConcurrentHashMap < > ( ) ;
private final ConcurrentMap < UserId , Set < String > > publicUserSessionsMap = new ConcurrentHashMap < > ( ) ;
private final Cache < String , SessionMetaData > pendingSessions = Caffeine . newBuilder ( )
. expireAfterWrite ( 10 , TimeUnit . SECONDS )
. < String , SessionMetaData > removalListener ( ( sessionId , sessionMd , removalCause ) - > {
if ( removalCause = = RemovalCause . EXPIRED & & sessionMd ! = null ) {
try {
close ( sessionMd . sessionRef , CloseStatus . POLICY_VIOLATION ) ;
} catch ( IOException e ) {
log . warn ( "IO error" , e ) ;
}
}
} )
. build ( ) ;
@Override
public void handleTextMessage ( WebSocketSession session , TextMessage message ) {
try {
SessionMetaData sessionMd = internalSessionMap . get ( session . getId ( ) ) ;
if ( sessionMd ! = null ) {
log . trace ( "[{}][{}] Processing {}" , sessionMd . sessionRef . getSecurityCtx ( ) . getTenantId ( ) , session . getId ( ) , message . getPayload ( ) ) ;
webSocketService . handleWebSocketMsg ( sessionMd . sessionRef , message . getPayload ( ) ) ;
} else {
SessionMetaData sessionMd = getSessionMd ( session . getId ( ) ) ;
if ( sessionMd = = null ) {
log . trace ( "[{}] Failed to find session" , session . getId ( ) ) ;
session . close ( CloseStatus . SERVER_ERROR . withReason ( "Session not found!" ) ) ;
return ;
}
WebSocketSessionRef sessionRef = sessionMd . sessionRef ;
String msg = message . getPayload ( ) ;
WsCommandsWrapper cmdsWrapper ;
try {
switch ( sessionRef . getSessionType ( ) ) {
case GENERAL :
cmdsWrapper = JacksonUtil . fromString ( msg , WsCommandsWrapper . class ) ;
break ;
case TELEMETRY :
cmdsWrapper = JacksonUtil . fromString ( msg , TelemetryCmdsWrapper . class ) . toCommonCmdsWrapper ( ) ;
break ;
case NOTIFICATIONS :
cmdsWrapper = JacksonUtil . fromString ( msg , NotificationCmdsWrapper . class ) . toCommonCmdsWrapper ( ) ;
break ;
default :
return ;
}
} catch ( Exception e ) {
log . warn ( "Failed to decode subscription cmd: {}" , e . getMessage ( ) , e ) ;
if ( sessionRef . getSecurityCtx ( ) ! = null ) {
webSocketService . sendError ( sessionRef , 1 , SubscriptionErrorCode . BAD_REQUEST , "Failed to parse the payload" ) ;
} else {
close ( sessionRef , CloseStatus . BAD_DATA . withReason ( e . getMessage ( ) ) ) ;
}
return ;
}
if ( sessionRef . getSecurityCtx ( ) ! = null ) {
log . trace ( "[{}][{}] Processing {}" , sessionRef . getSecurityCtx ( ) . getTenantId ( ) , session . getId ( ) , msg ) ;
webSocketService . handleCommands ( sessionRef , cmdsWrapper ) ;
} else {
AuthCmd authCmd = cmdsWrapper . getAuthCmd ( ) ;
if ( authCmd = = null ) {
close ( sessionRef , CloseStatus . POLICY_VIOLATION . withReason ( "Auth cmd is missing" ) ) ;
return ;
}
log . trace ( "[{}] Authenticating session" , session . getId ( ) ) ;
SecurityUser securityCtx ;
try {
securityCtx = authenticationProvider . authenticate ( authCmd . getToken ( ) ) ;
} catch ( Exception e ) {
close ( sessionRef , CloseStatus . BAD_DATA . withReason ( e . getMessage ( ) ) ) ;
return ;
}
sessionRef . setSecurityCtx ( securityCtx ) ;
pendingSessions . invalidate ( session . getId ( ) ) ;
establishSession ( session , sessionRef ) ;
webSocketService . sendUpdate ( sessionRef . getSessionId ( ) , new AuthCmdUpdate ( 1 ) ) ;
}
} catch ( IOException e ) {
log . warn ( "IO error" , e ) ;
@ -116,7 +189,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke
@Override
protected void handlePongMessage ( WebSocketSession session , PongMessage message ) throws Exception {
try {
SessionMetaData sessionMd = internalSessionMap . get ( session . getId ( ) ) ;
SessionMetaData sessionMd = getSessionMd ( session . getId ( ) ) ;
if ( sessionMd ! = null ) {
log . trace ( "[{}][{}] Processing pong response {}" , sessionMd . sessionRef . getSecurityCtx ( ) . getTenantId ( ) , session . getId ( ) , message . getPayload ( ) ) ;
sessionMd . processPongMessage ( System . currentTimeMillis ( ) ) ;
@ -139,36 +212,45 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke
nativeSession . getAsyncRemote ( ) . setSendTimeout ( sendTimeout ) ;
}
}
String internalSessionId = session . getId ( ) ;
WebSocketSessionRef sessionRef = toRef ( session ) ;
String externalSessionId = sessionRef . getSessionId ( ) ;
log . debug ( "[{}][{}] Session opened from address: {}" , sessionRef . getSessionId ( ) , session . getId ( ) , session . getRemoteAddress ( ) ) ;
establishSession ( session , sessionRef ) ;
} catch ( InvalidParameterException e ) {
log . warn ( "[{}] Failed to start session" , session . getId ( ) , e ) ;
session . close ( CloseStatus . BAD_DATA . withReason ( e . getMessage ( ) ) ) ;
} catch ( Exception e ) {
log . warn ( "[{}] Failed to start session" , session . getId ( ) , e ) ;
session . close ( CloseStatus . SERVER_ERROR . withReason ( e . getMessage ( ) ) ) ;
}
}
private void establishSession ( WebSocketSession session , WebSocketSessionRef sessionRef ) throws IOException {
if ( sessionRef . getSecurityCtx ( ) ! = null ) {
if ( ! checkLimits ( session , sessionRef ) ) {
return ;
}
var tenantProfileConfiguration = getTenantProfileConfiguration ( sessionRef ) ;
int wsTenantProfileQueueLimit = tenantProfileConfiguration ! = null ?
tenantProfileConfiguration . getWsMsgQueueLimitPerSession ( ) : wsMaxQueueMessagesPerSession ;
internalSessionMap . put ( internalSessionId , new SessionMetaData ( session , sessionRef ,
SessionMetaData sessionMd = new SessionMetaData ( session , sessionRef ,
( wsTenantProfileQueueLimit > 0 & & wsTenantProfileQueueLimit < wsMaxQueueMessagesPerSession ) ?
wsTenantProfileQueueLimit : wsMaxQueueMessagesPerSession ) ) ;
wsTenantProfileQueueLimit : wsMaxQueueMessagesPerSession ) ;
externalSessionMap . put ( externalSessionId , internalSessionId ) ;
internalSessionMap . put ( session . getId ( ) , sessionMd ) ;
externalSessionMap . put ( sessionRef . getSessionId ( ) , session . getId ( ) ) ;
processInWebSocketService ( sessionRef , SessionEvent . onEstablished ( ) ) ;
log . info ( "[{}][{}][{}] Session is opened from address: {}" , sessionRef . getSecurityCtx ( ) . getTenantId ( ) , externalSessionId , session . getId ( ) , session . getRemoteAddress ( ) ) ;
} catch ( InvalidParameterException e ) {
log . warn ( "[{}] Failed to start session" , session . getId ( ) , e ) ;
session . close ( CloseStatus . BAD_DATA . withReason ( e . getMessage ( ) ) ) ;
} catch ( Exception e ) {
log . warn ( "[{}] Failed to start session" , session . getId ( ) , e ) ;
session . close ( CloseStatus . SERVER_ERROR . withReason ( e . getMessage ( ) ) ) ;
log . info ( "[{}][{}][{}] Session established from address: {}" , sessionRef . getSecurityCtx ( ) . getTenantId ( ) , sessionRef . getSessionId ( ) , session . getId ( ) , session . getRemoteAddress ( ) ) ;
} else {
SessionMetaData sessionMd = new SessionMetaData ( session , sessionRef , wsMaxQueueMessagesPerSession ) ;
pendingSessions . put ( session . getId ( ) , sessionMd ) ;
externalSessionMap . put ( sessionRef . getSessionId ( ) , session . getId ( ) ) ;
}
}
@Override
public void handleTransportError ( WebSocketSession session , Throwable tError ) throws Exception {
super . handleTransportError ( session , tError ) ;
SessionMetaData sessionMd = internalSessionMap . get ( session . getId ( ) ) ;
SessionMetaData sessionMd = getSessionMd ( session . getId ( ) ) ;
if ( sessionMd ! = null ) {
processInWebSocketService ( sessionMd . sessionRef , SessionEvent . onError ( tError ) ) ;
} else {
@ -181,10 +263,15 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke
public void afterConnectionClosed ( WebSocketSession session , CloseStatus closeStatus ) throws Exception {
super . afterConnectionClosed ( session , closeStatus ) ;
SessionMetaData sessionMd = internalSessionMap . remove ( session . getId ( ) ) ;
if ( sessionMd = = null ) {
sessionMd = pendingSessions . asMap ( ) . remove ( session . getId ( ) ) ;
}
if ( sessionMd ! = null ) {
cleanupLimits ( session , sessionMd . sessionRef ) ;
externalSessionMap . remove ( sessionMd . sessionRef . getSessionId ( ) ) ;
processInWebSocketService ( sessionMd . sessionRef , SessionEvent . onClosed ( ) ) ;
if ( sessionMd . sessionRef . getSecurityCtx ( ) ! = null ) {
cleanupLimits ( session , sessionMd . sessionRef ) ;
processInWebSocketService ( sessionMd . sessionRef , SessionEvent . onClosed ( ) ) ;
}
log . info ( "[{}][{}][{}] Session is closed" , sessionMd . sessionRef . getSecurityCtx ( ) . getTenantId ( ) , sessionMd . sessionRef . getSessionId ( ) , session . getId ( ) ) ;
} else {
log . info ( "[{}] Session is closed" , session . getId ( ) ) ;
@ -192,8 +279,11 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke
}
private void processInWebSocketService ( WebSocketSessionRef sessionRef , SessionEvent event ) {
if ( sessionRef . getSecurityCtx ( ) = = null ) {
return ;
}
try {
webSocketService . handleWebSocketSessionEvent ( sessionRef , event ) ;
webSocketService . handleSessionEvent ( sessionRef , event ) ;
} catch ( BeanCreationNotAllowedException e ) {
log . warn ( "[{}] Failed to close session due to possible shutdown state" , sessionRef . getSessionId ( ) ) ;
}
@ -210,16 +300,28 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke
. orElseThrow ( ( ) - > new InvalidParameterException ( "Unknown session type" ) ) ;
}
SecurityUser currentUser = ( SecurityUser ) ( ( Authentication ) session . getPrincipal ( ) ) . getPrincipal ( ) ;
SecurityUser securityCtx = null ;
String token = StringUtils . substringAfter ( session . getUri ( ) . getQuery ( ) , "token=" ) ;
if ( StringUtils . isNotEmpty ( token ) ) {
securityCtx = authenticationProvider . authenticate ( token ) ;
}
return WebSocketSessionRef . builder ( )
. sessionId ( UUID . randomUUID ( ) . toString ( ) )
. securityCtx ( currentUser )
. securityCtx ( securityCtx )
. localAddress ( session . getLocalAddress ( ) )
. remoteAddress ( session . getRemoteAddress ( ) )
. sessionType ( sessionType )
. build ( ) ;
}
private SessionMetaData getSessionMd ( String internalSessionId ) {
SessionMetaData sessionMd = internalSessionMap . get ( internalSessionId ) ;
if ( sessionMd = = null ) {
sessionMd = pendingSessions . getIfPresent ( internalSessionId ) ;
}
return sessionMd ;
}
class SessionMetaData implements SendHandler {
private final WebSocketSession session ;
private final RemoteEndpoint . Async asyncRemote ;
@ -228,6 +330,8 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke
final AtomicBoolean isSending = new AtomicBoolean ( false ) ;
private final Queue < TbWebSocketMsg < ? > > msgQueue ;
// TODO: msg queue as in org.thingsboard.server.transport.mqtt.session.DeviceSessionCtx
private volatile long lastActivityTime ;
SessionMetaData ( WebSocketSession session , WebSocketSessionRef sessionRef , int maxMsgQueuePerSession ) {
@ -335,7 +439,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke
@Override
public void send ( WebSocketSessionRef sessionRef , int subscriptionId , String msg ) throws IOException {
String externalId = sessionRef . getSessionId ( ) ;
log . debug ( "[{}] Process ing {}" , externalId , msg ) ;
log . debug ( "[{}] Send ing {}" , externalId , msg ) ;
String internalId = externalSessionMap . get ( externalId ) ;
if ( internalId ! = null ) {
SessionMetaData sessionMd = internalSessionMap . get ( internalId ) ;
@ -383,7 +487,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke
log . debug ( "[{}] Processing close request" , externalId ) ;
String internalId = externalSessionMap . get ( externalId ) ;
if ( internalId ! = null ) {
SessionMetaData sessionMd = internalSessionMap . get ( internalId ) ;
SessionMetaData sessionMd = getSessionMd ( internalId ) ;
if ( sessionMd ! = null ) {
sessionMd . session . close ( reason ) ;
} else {
@ -394,7 +498,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke
}
}
private boolean checkLimits ( WebSocketSession session , WebSocketSessionRef sessionRef ) throws Exception {
private boolean checkLimits ( WebSocketSession session , WebSocketSessionRef sessionRef ) throws IO Exception {
var tenantProfileConfiguration = getTenantProfileConfiguration ( sessionRef ) ;
if ( tenantProfileConfiguration = = null ) {
return true ;