@ -19,6 +19,7 @@ import com.github.benmanes.caffeine.cache.Cache;
import com.github.benmanes.caffeine.cache.Caffeine ;
import com.github.benmanes.caffeine.cache.RemovalCause ;
import lombok.RequiredArgsConstructor ;
import lombok.Setter ;
import lombok.extern.slf4j.Slf4j ;
import org.apache.commons.lang3.StringUtils ;
import org.springframework.beans.factory.BeanCreationNotAllowedException ;
@ -57,7 +58,6 @@ import org.thingsboard.server.service.ws.WebSocketSessionType;
import org.thingsboard.server.service.ws.WsCommandsWrapper ;
import org.thingsboard.server.service.ws.notification.cmd.NotificationCmdsWrapper ;
import org.thingsboard.server.service.ws.telemetry.cmd.TelemetryCmdsWrapper ;
import org.thingsboard.server.service.ws.telemetry.cmd.v2.AuthCmdUpdate ;
import javax.websocket.RemoteEndpoint ;
import javax.websocket.SendHandler ;
@ -70,10 +70,13 @@ import java.util.Queue;
import java.util.Set ;
import java.util.UUID ;
import java.util.concurrent.ConcurrentHashMap ;
import java.util.concurrent.ConcurrentLinkedQueue ;
import java.util.concurrent.ConcurrentMap ;
import java.util.concurrent.LinkedBlockingQueue ;
import java.util.concurrent.TimeUnit ;
import java.util.concurrent.atomic.AtomicBoolean ;
import java.util.concurrent.atomic.AtomicInteger ;
import java.util.concurrent.locks.Lock ;
import java.util.concurrent.locks.ReentrantLock ;
import static org.thingsboard.server.service.ws.DefaultWebSocketService.NUMBER_OF_PING_ATTEMPTS ;
@ -131,58 +134,60 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke
session . close ( CloseStatus . SERVER_ERROR . withReason ( "Session not found!" ) ) ;
return ;
}
WebSocketSessionRef sessionRef = sessionMd . sessionRef ;
String msg = message . getPayload ( ) ;
sessionMd . onMsg ( msg ) ;
} catch ( IOException e ) {
log . warn ( "IO error" , e ) ;
}
}
WsCommandsWrapper cmdsWrapper ;
try {
switch ( sessionRef . getSessionType ( ) ) {
case GENERAL :
cmdsWrapper = JacksonUtil . fromString ( msg , WsCommandsWrapper . class ) ;
break ;
case TELEMETRY :
cmdsWrapper = JacksonUtil . fromString ( msg , TelemetryCmdsWrapper . class ) . toCommonCmdsWrapper ( ) ;
break ;
case NOTIFICATIONS :
cmdsWrapper = JacksonUtil . fromString ( msg , NotificationCmdsWrapper . class ) . toCommonCmdsWrapper ( ) ;
break ;
default :
return ;
}
} catch ( Exception e ) {
log . warn ( "Failed to decode subscription cmd: {}" , e . getMessage ( ) , e ) ;
if ( sessionRef . getSecurityCtx ( ) ! = null ) {
webSocketService . sendError ( sessionRef , 1 , SubscriptionErrorCode . BAD_REQUEST , "Failed to parse the payload" ) ;
} else {
close ( sessionRef , CloseStatus . BAD_DATA . withReason ( e . getMessage ( ) ) ) ;
}
return ;
void processMsg ( SessionMetaData sessionMd , String msg ) throws IOException {
WebSocketSessionRef sessionRef = sessionMd . sessionRef ;
WsCommandsWrapper cmdsWrapper ;
try {
switch ( sessionRef . getSessionType ( ) ) {
case GENERAL :
cmdsWrapper = JacksonUtil . fromString ( msg , WsCommandsWrapper . class ) ;
break ;
case TELEMETRY :
cmdsWrapper = JacksonUtil . fromString ( msg , TelemetryCmdsWrapper . class ) . toCommonCmdsWrapper ( ) ;
break ;
case NOTIFICATIONS :
cmdsWrapper = JacksonUtil . fromString ( msg , NotificationCmdsWrapper . class ) . toCommonCmdsWrapper ( ) ;
break ;
default :
return ;
}
} catch ( Exception e ) {
log . warn ( "Failed to decode subscription cmd: {}" , e . getMessage ( ) , e ) ;
if ( sessionRef . getSecurityCtx ( ) ! = null ) {
log . trace ( "[{}][{}] Processing {}" , sessionRef . getSecurityCtx ( ) . getTenantId ( ) , session . getId ( ) , msg ) ;
webSocketService . handleCommands ( sessionRef , cmdsWrapper ) ;
webSocketService . sendError ( sessionRef , 1 , SubscriptionErrorCode . BAD_REQUEST , "Failed to parse the payload" ) ;
} else {
AuthCmd authCmd = cmdsWrapper . getAuthCmd ( ) ;
if ( authCmd = = null ) {
close ( sessionRef , CloseStatus . POLICY_VIOLATION . withReason ( "Auth cmd is missing" ) ) ;
return ;
}
log . trace ( "[{}] Authenticating session" , session . getId ( ) ) ;
SecurityUser securityCtx ;
try {
securityCtx = authenticationProvider . authenticate ( authCmd . getToken ( ) ) ;
} catch ( Exception e ) {
close ( sessionRef , CloseStatus . BAD_DATA . withReason ( e . getMessage ( ) ) ) ;
return ;
}
sessionRef . setSecurityCtx ( securityCtx ) ;
pendingSessions . invalidate ( session . getId ( ) ) ;
establishSession ( session , sessionRef ) ;
webSocketService . sendUpdate ( sessionRef . getSessionId ( ) , new AuthCmdUpdate ( 1 ) ) ;
close ( sessionRef , CloseStatus . BAD_DATA . withReason ( e . getMessage ( ) ) ) ;
}
} catch ( IOException e ) {
log . warn ( "IO error" , e ) ;
return ;
}
if ( sessionRef . getSecurityCtx ( ) ! = null ) {
log . trace ( "[{}][{}] Processing {}" , sessionRef . getSecurityCtx ( ) . getTenantId ( ) , sessionMd . session . getId ( ) , msg ) ;
webSocketService . handleCommands ( sessionRef , cmdsWrapper ) ;
} else {
AuthCmd authCmd = cmdsWrapper . getAuthCmd ( ) ;
if ( authCmd = = null ) {
close ( sessionRef , CloseStatus . POLICY_VIOLATION . withReason ( "Auth cmd is missing" ) ) ;
return ;
}
log . trace ( "[{}] Authenticating session" , sessionMd . session . getId ( ) ) ;
SecurityUser securityCtx ;
try {
securityCtx = authenticationProvider . authenticate ( authCmd . getToken ( ) ) ;
} catch ( Exception e ) {
close ( sessionRef , CloseStatus . BAD_DATA . withReason ( e . getMessage ( ) ) ) ;
return ;
}
sessionRef . setSecurityCtx ( securityCtx ) ;
pendingSessions . invalidate ( sessionMd . session . getId ( ) ) ;
establishSession ( sessionMd . session , sessionRef , sessionMd ) ;
}
}
@ -214,7 +219,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke
}
WebSocketSessionRef sessionRef = toRef ( session ) ;
log . debug ( "[{}][{}] Session opened from address: {}" , sessionRef . getSessionId ( ) , session . getId ( ) , session . getRemoteAddress ( ) ) ;
establishSession ( session , sessionRef ) ;
establishSession ( session , sessionRef , null ) ;
} catch ( InvalidParameterException e ) {
log . warn ( "[{}] Failed to start session" , session . getId ( ) , e ) ;
session . close ( CloseStatus . BAD_DATA . withReason ( e . getMessage ( ) ) ) ;
@ -224,24 +229,26 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke
}
}
private void establishSession ( WebSocketSession session , WebSocketSessionRef sessionRef ) throws IOException {
private void establishSession ( WebSocketSession session , WebSocketSessionRef sessionRef , SessionMetaData sessionMd ) throws IOException {
if ( sessionRef . getSecurityCtx ( ) ! = null ) {
if ( ! checkLimits ( session , sessionRef ) ) {
return ;
}
var tenantProfileConfiguration = getTenantProfileConfiguration ( sessionRef ) ;
int wsTenantProfileQueueLimit = tenantProfileConfiguration ! = null ?
tenantProfileConfiguration . getWsMsgQueueLimitPerSession ( ) : wsMaxQueueMessagesPerSession ;
SessionMetaData sessionMd = new SessionMetaData ( session , sessionRef ,
( wsTenantProfileQueueLimit > 0 & & wsTenantProfileQueueLimit < wsMaxQueueMessagesPerSession ) ?
wsTenantProfileQueueLimit : wsMaxQueueMessagesPerSession ) ;
int maxMsgQueueSize = Optional . ofNullable ( getTenantProfileConfiguration ( sessionRef ) )
. map ( DefaultTenantProfileConfiguration : : getWsMsgQueueLimitPerSession )
. filter ( profileLimit - > profileLimit > 0 & & profileLimit < wsMaxQueueMessagesPerSession )
. orElse ( wsMaxQueueMessagesPerSession ) ;
if ( sessionMd = = null ) {
sessionMd = new SessionMetaData ( session , sessionRef ) ;
}
sessionMd . setMaxMsgQueueSize ( maxMsgQueueSize ) ;
internalSessionMap . put ( session . getId ( ) , sessionMd ) ;
externalSessionMap . put ( sessionRef . getSessionId ( ) , session . getId ( ) ) ;
processInWebSocketService ( sessionRef , SessionEvent . onEstablished ( ) ) ;
log . info ( "[{}][{}][{}] Session established from address: {}" , sessionRef . getSecurityCtx ( ) . getTenantId ( ) , sessionRef . getSessionId ( ) , session . getId ( ) , session . getRemoteAddress ( ) ) ;
} else {
SessionMetaData sessionMd = new SessionMetaData ( session , sessionRef , wsMaxQueueMessagesPerSession ) ;
sessionMd = new SessionMetaData ( session , sessionRef ) ;
pendingSessions . put ( session . getId ( ) , sessionMd ) ;
externalSessionMap . put ( sessionRef . getSessionId ( ) , session . getId ( ) ) ;
}
@ -328,19 +335,22 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke
private final WebSocketSessionRef sessionRef ;
final AtomicBoolean isSending = new AtomicBoolean ( false ) ;
private final Queue < TbWebSocketMsg < ? > > msgQueue ;
private final Queue < TbWebSocketMsg < ? > > outboundMsgQueue = new ConcurrentLinkedQueue < > ( ) ;
private final AtomicInteger outboundMsgQueueSize = new AtomicInteger ( ) ;
@Setter
private int maxMsgQueueSize = wsMaxQueueMessagesPerSession ;
// TODO: msg queue as in org.thingsboard.server.transport.mqtt.session.DeviceSessionCtx
private final Queue < String > inboundMsgQueue = new ConcurrentLinkedQueue < > ( ) ;
private final Lock inboundMsgQueueProcessorLock = new ReentrantLock ( ) ;
private volatile long lastActivityTime ;
SessionMetaData ( WebSocketSession session , WebSocketSessionRef sessionRef , int maxMsgQueuePerSession ) {
SessionMetaData ( WebSocketSession session , WebSocketSessionRef sessionRef ) {
super ( ) ;
this . session = session ;
Session nativeSession = ( ( NativeWebSocketSession ) session ) . getNativeSession ( Session . class ) ;
this . asyncRemote = nativeSession . getAsyncRemote ( ) ;
this . sessionRef = sessionRef ;
this . msgQueue = new LinkedBlockingQueue < > ( maxMsgQueuePerSession ) ;
this . lastActivityTime = System . currentTimeMillis ( ) ;
}
@ -365,7 +375,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke
} catch ( IOException ioe ) {
log . trace ( "[{}] Session transport error" , session . getId ( ) , ioe ) ;
} finally {
m sgQueue. clear ( ) ;
outboundM sgQueue. clear ( ) ;
}
}
@ -378,19 +388,14 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke
}
void sendMsg ( TbWebSocketMsg < ? > msg ) {
try {
msgQueue . add ( msg ) ;
} catch ( RuntimeException e ) {
if ( log . isTraceEnabled ( ) ) {
log . trace ( "[{}][{}] Session closed due to queue error" , sessionRef . getSecurityCtx ( ) . getTenantId ( ) , session . getId ( ) , e ) ;
} else {
log . info ( "[{}][{}] Session closed due to queue error" , sessionRef . getSecurityCtx ( ) . getTenantId ( ) , session . getId ( ) ) ;
}
if ( outboundMsgQueueSize . get ( ) < maxMsgQueueSize ) {
outboundMsgQueue . add ( msg ) ;
outboundMsgQueueSize . incrementAndGet ( ) ;
processNextMsg ( ) ;
} else {
log . info ( "[{}][{}] Session closed due to updates queue size exceeded" , sessionRef . getSecurityCtx ( ) . getTenantId ( ) , session . getId ( ) ) ;
closeSession ( CloseStatus . POLICY_VIOLATION . withReason ( "Max pending updates limit reached!" ) ) ;
return ;
}
processNextMsg ( ) ;
}
private void sendMsgInternal ( TbWebSocketMsg < ? > msg ) {
@ -424,16 +429,39 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke
}
private void processNextMsg ( ) {
if ( m sgQueue. isEmpty ( ) | | ! isSending . compareAndSet ( false , true ) ) {
if ( outboundM sgQueue. isEmpty ( ) | | ! isSending . compareAndSet ( false , true ) ) {
return ;
}
TbWebSocketMsg < ? > msg = m sgQueue. poll ( ) ;
TbWebSocketMsg < ? > msg = outboundM sgQueue. poll ( ) ;
if ( msg ! = null ) {
outboundMsgQueueSize . decrementAndGet ( ) ;
sendMsgInternal ( msg ) ;
} else {
isSending . set ( false ) ;
}
}
public void onMsg ( String msg ) throws IOException {
inboundMsgQueue . add ( msg ) ;
tryProcessInboundMsgs ( ) ;
}
void tryProcessInboundMsgs ( ) throws IOException {
while ( ! inboundMsgQueue . isEmpty ( ) ) {
if ( inboundMsgQueueProcessorLock . tryLock ( ) ) {
try {
String msg ;
while ( ( msg = inboundMsgQueue . poll ( ) ) ! = null ) {
processMsg ( this , msg ) ;
}
} finally {
inboundMsgQueueProcessorLock . unlock ( ) ;
}
} else {
return ;
}
}
}
}
@Override