diff --git a/application/src/main/java/org/thingsboard/server/config/ThingsboardSecurityConfiguration.java b/application/src/main/java/org/thingsboard/server/config/ThingsboardSecurityConfiguration.java index 55370b04b8..6b29ba0b28 100644 --- a/application/src/main/java/org/thingsboard/server/config/ThingsboardSecurityConfiguration.java +++ b/application/src/main/java/org/thingsboard/server/config/ThingsboardSecurityConfiguration.java @@ -32,14 +32,12 @@ import org.springframework.security.config.annotation.method.configuration.Enabl import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.http.SessionCreationPolicy; -import org.springframework.security.crypto.bcrypt.BCryptPasswordEncoder; import org.springframework.security.oauth2.client.web.OAuth2AuthorizationRequestResolver; import org.springframework.security.web.SecurityFilterChain; import org.springframework.security.web.authentication.AuthenticationFailureHandler; import org.springframework.security.web.authentication.AuthenticationSuccessHandler; import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter; import org.springframework.security.web.header.writers.StaticHeadersWriter; -import org.springframework.security.web.util.matcher.AntPathRequestMatcher; import org.springframework.web.cors.UrlBasedCorsConfigurationSource; import org.springframework.web.filter.CorsFilter; import org.springframework.web.filter.ShallowEtagHeaderFilter; @@ -63,7 +61,7 @@ import java.util.List; @Configuration @EnableWebSecurity -@EnableGlobalMethodSecurity(prePostEnabled=true) +@EnableGlobalMethodSecurity(prePostEnabled = true) @Order(SecurityProperties.BASIC_AUTH_ORDER) @TbCoreComponent public class ThingsboardSecurityConfiguration { @@ -79,7 +77,7 @@ public class ThingsboardSecurityConfiguration { public static final String TOKEN_REFRESH_ENTRY_POINT = "/api/auth/token"; protected static final String[] NON_TOKEN_BASED_AUTH_ENTRY_POINTS = new String[] {"/index.html", "/assets/**", "/static/**", "/api/noauth/**", "/webjars/**", "/api/license/**"}; public static final String TOKEN_BASED_AUTH_ENTRY_POINT = "/api/**"; - public static final String WS_TOKEN_BASED_AUTH_ENTRY_POINT = "/api/ws/**"; + public static final String WS_ENTRY_POINT = "/api/ws/**"; public static final String MAIL_OAUTH2_PROCESSING_ENTRY_POINT = "/api/admin/mail/oauth2/code"; public static final String DEVICE_CONNECTIVITY_CERTIFICATE_DOWNLOAD_ENTRY_POINT = "/api/device-connectivity/mqtts/certificate/download"; @@ -115,10 +113,6 @@ public class ThingsboardSecurityConfiguration { @Qualifier("jwtHeaderTokenExtractor") private TokenExtractor jwtHeaderTokenExtractor; - @Autowired - @Qualifier("jwtQueryTokenExtractor") - private TokenExtractor jwtQueryTokenExtractor; - @Autowired private AuthenticationManager authenticationManager; @Autowired private RateLimitProcessingFilter rateLimitProcessingFilter; @@ -150,7 +144,7 @@ public class ThingsboardSecurityConfiguration { protected JwtTokenAuthenticationProcessingFilter buildJwtTokenAuthenticationProcessingFilter() throws Exception { List pathsToSkip = new ArrayList<>(Arrays.asList(NON_TOKEN_BASED_AUTH_ENTRY_POINTS)); - pathsToSkip.addAll(Arrays.asList(WS_TOKEN_BASED_AUTH_ENTRY_POINT, TOKEN_REFRESH_ENTRY_POINT, FORM_BASED_LOGIN_ENTRY_POINT, + pathsToSkip.addAll(Arrays.asList(WS_ENTRY_POINT, TOKEN_REFRESH_ENTRY_POINT, FORM_BASED_LOGIN_ENTRY_POINT, PUBLIC_LOGIN_ENTRY_POINT, DEVICE_API_ENTRY_POINT, WEBJARS_ENTRY_POINT, MAIL_OAUTH2_PROCESSING_ENTRY_POINT, DEVICE_CONNECTIVITY_CERTIFICATE_DOWNLOAD_ENTRY_POINT)); SkipPathRequestMatcher matcher = new SkipPathRequestMatcher(pathsToSkip, TOKEN_BASED_AUTH_ENTRY_POINT); @@ -167,15 +161,6 @@ public class ThingsboardSecurityConfiguration { return filter; } - @Bean - protected JwtTokenAuthenticationProcessingFilter buildWsJwtTokenAuthenticationProcessingFilter() throws Exception { - AntPathRequestMatcher matcher = new AntPathRequestMatcher(WS_TOKEN_BASED_AUTH_ENTRY_POINT); - JwtTokenAuthenticationProcessingFilter filter - = new JwtTokenAuthenticationProcessingFilter(failureHandler, jwtQueryTokenExtractor, matcher); - filter.setAuthenticationManager(this.authenticationManager); - return filter; - } - @Bean public AuthenticationManager authenticationManager(ObjectPostProcessor objectPostProcessor) throws Exception { DefaultAuthenticationEventPublisher eventPublisher = objectPostProcessor @@ -229,7 +214,7 @@ public class ThingsboardSecurityConfiguration { .antMatchers(NON_TOKEN_BASED_AUTH_ENTRY_POINTS).permitAll() // static resources, user activation and password reset end-points .and() .authorizeRequests() - .antMatchers(WS_TOKEN_BASED_AUTH_ENTRY_POINT).authenticated() // Protected WebSocket API End-points + .antMatchers(WS_ENTRY_POINT).permitAll() // WebSocket API End-points .antMatchers(TOKEN_BASED_AUTH_ENTRY_POINT).authenticated() // Protected API End-points .and() .exceptionHandling().accessDeniedHandler(restAccessDeniedHandler) @@ -238,7 +223,6 @@ public class ThingsboardSecurityConfiguration { .addFilterBefore(buildRestPublicLoginProcessingFilter(), UsernamePasswordAuthenticationFilter.class) .addFilterBefore(buildJwtTokenAuthenticationProcessingFilter(), UsernamePasswordAuthenticationFilter.class) .addFilterBefore(buildRefreshTokenProcessingFilter(), UsernamePasswordAuthenticationFilter.class) - .addFilterBefore(buildWsJwtTokenAuthenticationProcessingFilter(), UsernamePasswordAuthenticationFilter.class) .addFilterAfter(rateLimitProcessingFilter, UsernamePasswordAuthenticationFilter.class); if (oauth2Configuration != null) { http.oauth2Login() diff --git a/application/src/main/java/org/thingsboard/server/config/WebSocketConfiguration.java b/application/src/main/java/org/thingsboard/server/config/WebSocketConfiguration.java index e6097ec5d2..467e20adb8 100644 --- a/application/src/main/java/org/thingsboard/server/config/WebSocketConfiguration.java +++ b/application/src/main/java/org/thingsboard/server/config/WebSocketConfiguration.java @@ -19,25 +19,13 @@ import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; -import org.springframework.http.HttpStatus; -import org.springframework.http.server.ServerHttpRequest; -import org.springframework.http.server.ServerHttpResponse; -import org.springframework.security.core.Authentication; -import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.config.annotation.EnableWebSocket; import org.springframework.web.socket.config.annotation.WebSocketConfigurer; import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry; -import org.springframework.web.socket.server.HandshakeInterceptor; import org.springframework.web.socket.server.standard.ServletServerContainerFactoryBean; -import org.springframework.web.socket.server.support.HttpSessionHandshakeInterceptor; -import org.thingsboard.server.common.data.exception.ThingsboardErrorCode; -import org.thingsboard.server.common.data.exception.ThingsboardException; import org.thingsboard.server.controller.plugin.TbWebSocketHandler; import org.thingsboard.server.queue.util.TbCoreComponent; -import org.thingsboard.server.service.security.model.SecurityUser; - -import java.util.Map; @Configuration @TbCoreComponent @@ -66,39 +54,7 @@ public class WebSocketConfiguration implements WebSocketConfigurer { log.error("TbWebSocketHandler expected but [{}] provided", wsHandler); throw new RuntimeException("TbWebSocketHandler expected but " + wsHandler + " provided"); } - registry.addHandler(wsHandler, WS_API_MAPPING).setAllowedOriginPatterns("*") - .addInterceptors(new HttpSessionHandshakeInterceptor(), new HandshakeInterceptor() { - - @Override - public boolean beforeHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, - Map attributes) throws Exception { - SecurityUser user = null; - try { - user = getCurrentUser(); - } catch (ThingsboardException ex) { - } - if (user == null) { - response.setStatusCode(HttpStatus.UNAUTHORIZED); - return false; - } else { - return true; - } - } - - @Override - public void afterHandshake(ServerHttpRequest request, ServerHttpResponse response, WebSocketHandler wsHandler, - Exception exception) { - //Do nothing - } - }); + registry.addHandler(wsHandler, WS_API_MAPPING).setAllowedOriginPatterns("*"); } - protected SecurityUser getCurrentUser() throws ThingsboardException { - Authentication authentication = SecurityContextHolder.getContext().getAuthentication(); - if (authentication != null && authentication.getPrincipal() instanceof SecurityUser) { - return (SecurityUser) authentication.getPrincipal(); - } else { - throw new ThingsboardException("You aren't authorized to perform this operation!", ThingsboardErrorCode.AUTHENTICATION); - } - } } diff --git a/application/src/main/java/org/thingsboard/server/controller/plugin/TbWebSocketHandler.java b/application/src/main/java/org/thingsboard/server/controller/plugin/TbWebSocketHandler.java index de073aca9a..855cbbc1e5 100644 --- a/application/src/main/java/org/thingsboard/server/controller/plugin/TbWebSocketHandler.java +++ b/application/src/main/java/org/thingsboard/server/controller/plugin/TbWebSocketHandler.java @@ -15,13 +15,16 @@ */ package org.thingsboard.server.controller.plugin; +import com.github.benmanes.caffeine.cache.Cache; +import com.github.benmanes.caffeine.cache.Caffeine; +import com.github.benmanes.caffeine.cache.RemovalCause; +import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.apache.commons.lang3.StringUtils; import org.springframework.beans.factory.BeanCreationNotAllowedException; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.context.annotation.Lazy; -import org.springframework.security.core.Authentication; import org.springframework.stereotype.Service; import org.springframework.web.socket.CloseStatus; import org.springframework.web.socket.PongMessage; @@ -29,6 +32,7 @@ import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.adapter.NativeWebSocketSession; import org.springframework.web.socket.handler.TextWebSocketHandler; +import org.thingsboard.common.util.JacksonUtil; import org.thingsboard.server.common.data.TenantProfile; import org.thingsboard.server.common.data.exception.ThingsboardErrorCode; import org.thingsboard.server.common.data.id.CustomerId; @@ -40,13 +44,20 @@ import org.thingsboard.server.config.WebSocketConfiguration; import org.thingsboard.server.dao.tenant.TbTenantProfileCache; import org.thingsboard.server.dao.util.limits.RateLimitService; import org.thingsboard.server.queue.util.TbCoreComponent; +import org.thingsboard.server.service.security.auth.jwt.JwtAuthenticationProvider; import org.thingsboard.server.service.security.model.SecurityUser; import org.thingsboard.server.service.security.model.UserPrincipal; +import org.thingsboard.server.service.subscription.SubscriptionErrorCode; +import org.thingsboard.server.service.ws.AuthCmd; import org.thingsboard.server.service.ws.SessionEvent; import org.thingsboard.server.service.ws.WebSocketMsgEndpoint; import org.thingsboard.server.service.ws.WebSocketService; import org.thingsboard.server.service.ws.WebSocketSessionRef; import org.thingsboard.server.service.ws.WebSocketSessionType; +import org.thingsboard.server.service.ws.WsCommandsWrapper; +import org.thingsboard.server.service.ws.notification.cmd.NotificationCmdsWrapper; +import org.thingsboard.server.service.ws.telemetry.cmd.TelemetryCmdsWrapper; +import org.thingsboard.server.service.ws.telemetry.cmd.v2.AuthCmdUpdate; import javax.websocket.RemoteEndpoint; import javax.websocket.SendHandler; @@ -61,6 +72,7 @@ import java.util.UUID; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import static org.thingsboard.server.service.ws.DefaultWebSocketService.NUMBER_OF_PING_ATTEMPTS; @@ -68,20 +80,20 @@ import static org.thingsboard.server.service.ws.DefaultWebSocketService.NUMBER_O @Service @TbCoreComponent @Slf4j +@RequiredArgsConstructor public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocketMsgEndpoint { private final ConcurrentMap internalSessionMap = new ConcurrentHashMap<>(); private final ConcurrentMap externalSessionMap = new ConcurrentHashMap<>(); - @Autowired @Lazy private WebSocketService webSocketService; - @Autowired private TbTenantProfileCache tenantProfileCache; - @Autowired private RateLimitService rateLimitService; + @Autowired + private JwtAuthenticationProvider authenticationProvider; @Value("${server.ws.send_timeout:5000}") private long sendTimeout; @@ -97,16 +109,77 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke private final ConcurrentMap> regularUserSessionsMap = new ConcurrentHashMap<>(); private final ConcurrentMap> publicUserSessionsMap = new ConcurrentHashMap<>(); + private final Cache pendingSessions = Caffeine.newBuilder() + .expireAfterWrite(10, TimeUnit.SECONDS) + .removalListener((sessionId, sessionMd, removalCause) -> { + if (removalCause == RemovalCause.EXPIRED && sessionMd != null) { + try { + close(sessionMd.sessionRef, CloseStatus.POLICY_VIOLATION); + } catch (IOException e) { + log.warn("IO error", e); + } + } + }) + .build(); + @Override public void handleTextMessage(WebSocketSession session, TextMessage message) { try { - SessionMetaData sessionMd = internalSessionMap.get(session.getId()); - if (sessionMd != null) { - log.trace("[{}][{}] Processing {}", sessionMd.sessionRef.getSecurityCtx().getTenantId(), session.getId(), message.getPayload()); - webSocketService.handleWebSocketMsg(sessionMd.sessionRef, message.getPayload()); - } else { + SessionMetaData sessionMd = getSessionMd(session.getId()); + if (sessionMd == null) { log.trace("[{}] Failed to find session", session.getId()); session.close(CloseStatus.SERVER_ERROR.withReason("Session not found!")); + return; + } + WebSocketSessionRef sessionRef = sessionMd.sessionRef; + String msg = message.getPayload(); + + WsCommandsWrapper cmdsWrapper; + try { + switch (sessionRef.getSessionType()) { + case GENERAL: + cmdsWrapper = JacksonUtil.fromString(msg, WsCommandsWrapper.class); + break; + case TELEMETRY: + cmdsWrapper = JacksonUtil.fromString(msg, TelemetryCmdsWrapper.class).toCommonCmdsWrapper(); + break; + case NOTIFICATIONS: + cmdsWrapper = JacksonUtil.fromString(msg, NotificationCmdsWrapper.class).toCommonCmdsWrapper(); + break; + default: + return; + } + } catch (Exception e) { + log.warn("Failed to decode subscription cmd: {}", e.getMessage(), e); + if (sessionRef.getSecurityCtx() != null) { + webSocketService.sendError(sessionRef, 1, SubscriptionErrorCode.BAD_REQUEST, "Failed to parse the payload"); + } else { + close(sessionRef, CloseStatus.BAD_DATA.withReason(e.getMessage())); + } + return; + } + + if (sessionRef.getSecurityCtx() != null) { + log.trace("[{}][{}] Processing {}", sessionRef.getSecurityCtx().getTenantId(), session.getId(), msg); + webSocketService.handleCommands(sessionRef, cmdsWrapper); + } else { + AuthCmd authCmd = cmdsWrapper.getAuthCmd(); + if (authCmd == null) { + close(sessionRef, CloseStatus.POLICY_VIOLATION.withReason("Auth cmd is missing")); + return; + } + log.trace("[{}] Authenticating session", session.getId()); + SecurityUser securityCtx; + try { + securityCtx = authenticationProvider.authenticate(authCmd.getToken()); + } catch (Exception e) { + close(sessionRef, CloseStatus.BAD_DATA.withReason(e.getMessage())); + return; + } + sessionRef.setSecurityCtx(securityCtx); + pendingSessions.invalidate(session.getId()); + establishSession(session, sessionRef); + webSocketService.sendUpdate(sessionRef.getSessionId(), new AuthCmdUpdate(1)); } } catch (IOException e) { log.warn("IO error", e); @@ -116,7 +189,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke @Override protected void handlePongMessage(WebSocketSession session, PongMessage message) throws Exception { try { - SessionMetaData sessionMd = internalSessionMap.get(session.getId()); + SessionMetaData sessionMd = getSessionMd(session.getId()); if (sessionMd != null) { log.trace("[{}][{}] Processing pong response {}", sessionMd.sessionRef.getSecurityCtx().getTenantId(), session.getId(), message.getPayload()); sessionMd.processPongMessage(System.currentTimeMillis()); @@ -139,36 +212,45 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke nativeSession.getAsyncRemote().setSendTimeout(sendTimeout); } } - String internalSessionId = session.getId(); WebSocketSessionRef sessionRef = toRef(session); - String externalSessionId = sessionRef.getSessionId(); + log.debug("[{}][{}] Session opened from address: {}", sessionRef.getSessionId(), session.getId(), session.getRemoteAddress()); + establishSession(session, sessionRef); + } catch (InvalidParameterException e) { + log.warn("[{}] Failed to start session", session.getId(), e); + session.close(CloseStatus.BAD_DATA.withReason(e.getMessage())); + } catch (Exception e) { + log.warn("[{}] Failed to start session", session.getId(), e); + session.close(CloseStatus.SERVER_ERROR.withReason(e.getMessage())); + } + } + private void establishSession(WebSocketSession session, WebSocketSessionRef sessionRef) throws IOException { + if (sessionRef.getSecurityCtx() != null) { if (!checkLimits(session, sessionRef)) { return; } var tenantProfileConfiguration = getTenantProfileConfiguration(sessionRef); int wsTenantProfileQueueLimit = tenantProfileConfiguration != null ? tenantProfileConfiguration.getWsMsgQueueLimitPerSession() : wsMaxQueueMessagesPerSession; - internalSessionMap.put(internalSessionId, new SessionMetaData(session, sessionRef, + SessionMetaData sessionMd = new SessionMetaData(session, sessionRef, (wsTenantProfileQueueLimit > 0 && wsTenantProfileQueueLimit < wsMaxQueueMessagesPerSession) ? - wsTenantProfileQueueLimit : wsMaxQueueMessagesPerSession)); + wsTenantProfileQueueLimit : wsMaxQueueMessagesPerSession); - externalSessionMap.put(externalSessionId, internalSessionId); + internalSessionMap.put(session.getId(), sessionMd); + externalSessionMap.put(sessionRef.getSessionId(), session.getId()); processInWebSocketService(sessionRef, SessionEvent.onEstablished()); - log.info("[{}][{}][{}] Session is opened from address: {}", sessionRef.getSecurityCtx().getTenantId(), externalSessionId, session.getId(), session.getRemoteAddress()); - } catch (InvalidParameterException e) { - log.warn("[{}] Failed to start session", session.getId(), e); - session.close(CloseStatus.BAD_DATA.withReason(e.getMessage())); - } catch (Exception e) { - log.warn("[{}] Failed to start session", session.getId(), e); - session.close(CloseStatus.SERVER_ERROR.withReason(e.getMessage())); + log.info("[{}][{}][{}] Session established from address: {}", sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSessionId(), session.getId(), session.getRemoteAddress()); + } else { + SessionMetaData sessionMd = new SessionMetaData(session, sessionRef, wsMaxQueueMessagesPerSession); + pendingSessions.put(session.getId(), sessionMd); + externalSessionMap.put(sessionRef.getSessionId(), session.getId()); } } @Override public void handleTransportError(WebSocketSession session, Throwable tError) throws Exception { super.handleTransportError(session, tError); - SessionMetaData sessionMd = internalSessionMap.get(session.getId()); + SessionMetaData sessionMd = getSessionMd(session.getId()); if (sessionMd != null) { processInWebSocketService(sessionMd.sessionRef, SessionEvent.onError(tError)); } else { @@ -181,10 +263,15 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke public void afterConnectionClosed(WebSocketSession session, CloseStatus closeStatus) throws Exception { super.afterConnectionClosed(session, closeStatus); SessionMetaData sessionMd = internalSessionMap.remove(session.getId()); + if (sessionMd == null) { + sessionMd = pendingSessions.asMap().remove(session.getId()); + } if (sessionMd != null) { - cleanupLimits(session, sessionMd.sessionRef); externalSessionMap.remove(sessionMd.sessionRef.getSessionId()); - processInWebSocketService(sessionMd.sessionRef, SessionEvent.onClosed()); + if (sessionMd.sessionRef.getSecurityCtx() != null) { + cleanupLimits(session, sessionMd.sessionRef); + processInWebSocketService(sessionMd.sessionRef, SessionEvent.onClosed()); + } log.info("[{}][{}][{}] Session is closed", sessionMd.sessionRef.getSecurityCtx().getTenantId(), sessionMd.sessionRef.getSessionId(), session.getId()); } else { log.info("[{}] Session is closed", session.getId()); @@ -192,8 +279,11 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke } private void processInWebSocketService(WebSocketSessionRef sessionRef, SessionEvent event) { + if (sessionRef.getSecurityCtx() == null) { + return; + } try { - webSocketService.handleWebSocketSessionEvent(sessionRef, event); + webSocketService.handleSessionEvent(sessionRef, event); } catch (BeanCreationNotAllowedException e) { log.warn("[{}] Failed to close session due to possible shutdown state", sessionRef.getSessionId()); } @@ -210,16 +300,28 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke .orElseThrow(() -> new InvalidParameterException("Unknown session type")); } - SecurityUser currentUser = (SecurityUser) ((Authentication) session.getPrincipal()).getPrincipal(); + SecurityUser securityCtx = null; + String token = StringUtils.substringAfter(session.getUri().getQuery(), "token="); + if (StringUtils.isNotEmpty(token)) { + securityCtx = authenticationProvider.authenticate(token); + } return WebSocketSessionRef.builder() .sessionId(UUID.randomUUID().toString()) - .securityCtx(currentUser) + .securityCtx(securityCtx) .localAddress(session.getLocalAddress()) .remoteAddress(session.getRemoteAddress()) .sessionType(sessionType) .build(); } + private SessionMetaData getSessionMd(String internalSessionId) { + SessionMetaData sessionMd = internalSessionMap.get(internalSessionId); + if (sessionMd == null) { + sessionMd = pendingSessions.getIfPresent(internalSessionId); + } + return sessionMd; + } + class SessionMetaData implements SendHandler { private final WebSocketSession session; private final RemoteEndpoint.Async asyncRemote; @@ -228,6 +330,8 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke final AtomicBoolean isSending = new AtomicBoolean(false); private final Queue> msgQueue; + // TODO: msg queue as in org.thingsboard.server.transport.mqtt.session.DeviceSessionCtx + private volatile long lastActivityTime; SessionMetaData(WebSocketSession session, WebSocketSessionRef sessionRef, int maxMsgQueuePerSession) { @@ -335,7 +439,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke @Override public void send(WebSocketSessionRef sessionRef, int subscriptionId, String msg) throws IOException { String externalId = sessionRef.getSessionId(); - log.debug("[{}] Processing {}", externalId, msg); + log.debug("[{}] Sending {}", externalId, msg); String internalId = externalSessionMap.get(externalId); if (internalId != null) { SessionMetaData sessionMd = internalSessionMap.get(internalId); @@ -383,7 +487,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke log.debug("[{}] Processing close request", externalId); String internalId = externalSessionMap.get(externalId); if (internalId != null) { - SessionMetaData sessionMd = internalSessionMap.get(internalId); + SessionMetaData sessionMd = getSessionMd(internalId); if (sessionMd != null) { sessionMd.session.close(reason); } else { @@ -394,7 +498,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke } } - private boolean checkLimits(WebSocketSession session, WebSocketSessionRef sessionRef) throws Exception { + private boolean checkLimits(WebSocketSession session, WebSocketSessionRef sessionRef) throws IOException { var tenantProfileConfiguration = getTenantProfileConfiguration(sessionRef); if (tenantProfileConfiguration == null) { return true; diff --git a/application/src/main/java/org/thingsboard/server/service/security/auth/DefaultTokenOutdatingService.java b/application/src/main/java/org/thingsboard/server/service/security/auth/DefaultTokenOutdatingService.java index 9f33466793..02c712a640 100644 --- a/application/src/main/java/org/thingsboard/server/service/security/auth/DefaultTokenOutdatingService.java +++ b/application/src/main/java/org/thingsboard/server/service/security/auth/DefaultTokenOutdatingService.java @@ -23,7 +23,6 @@ import org.thingsboard.server.cache.TbTransactionalCache; import org.thingsboard.server.common.data.StringUtils; import org.thingsboard.server.common.data.id.UserId; import org.thingsboard.server.common.data.security.event.UserAuthDataChangedEvent; -import org.thingsboard.server.common.data.security.model.JwtToken; import org.thingsboard.server.service.security.model.token.JwtTokenFactory; import java.util.Optional; @@ -49,7 +48,7 @@ public class DefaultTokenOutdatingService implements TokenOutdatingService { } @Override - public boolean isOutdated(JwtToken token, UserId userId) { + public boolean isOutdated(String token, UserId userId) { Claims claims = tokenFactory.parseTokenClaims(token).getBody(); long issueTime = claims.getIssuedAt().getTime(); String sessionId = claims.get("sessionId", String.class); diff --git a/application/src/main/java/org/thingsboard/server/service/security/auth/TokenOutdatingService.java b/application/src/main/java/org/thingsboard/server/service/security/auth/TokenOutdatingService.java index 20df639619..f1d0c499ff 100644 --- a/application/src/main/java/org/thingsboard/server/service/security/auth/TokenOutdatingService.java +++ b/application/src/main/java/org/thingsboard/server/service/security/auth/TokenOutdatingService.java @@ -16,10 +16,9 @@ package org.thingsboard.server.service.security.auth; import org.thingsboard.server.common.data.id.UserId; -import org.thingsboard.server.common.data.security.model.JwtToken; public interface TokenOutdatingService { - boolean isOutdated(JwtToken token, UserId userId); + boolean isOutdated(String token, UserId userId); } diff --git a/application/src/main/java/org/thingsboard/server/service/security/auth/jwt/JwtAuthenticationProvider.java b/application/src/main/java/org/thingsboard/server/service/security/auth/jwt/JwtAuthenticationProvider.java index e0242126fd..43389e2a09 100644 --- a/application/src/main/java/org/thingsboard/server/service/security/auth/jwt/JwtAuthenticationProvider.java +++ b/application/src/main/java/org/thingsboard/server/service/security/auth/jwt/JwtAuthenticationProvider.java @@ -16,7 +16,9 @@ package org.thingsboard.server.service.security.auth.jwt; import lombok.RequiredArgsConstructor; +import org.apache.commons.lang3.StringUtils; import org.springframework.security.authentication.AuthenticationProvider; +import org.springframework.security.authentication.BadCredentialsException; import org.springframework.security.core.Authentication; import org.springframework.security.core.AuthenticationException; import org.springframework.stereotype.Component; @@ -37,13 +39,19 @@ public class JwtAuthenticationProvider implements AuthenticationProvider { @Override public Authentication authenticate(Authentication authentication) throws AuthenticationException { RawAccessJwtToken rawAccessToken = (RawAccessJwtToken) authentication.getCredentials(); - SecurityUser securityUser = tokenFactory.parseAccessJwtToken(rawAccessToken); + SecurityUser securityUser = authenticate(rawAccessToken.getToken()); + return new JwtAuthenticationToken(securityUser); + } - if (tokenOutdatingService.isOutdated(rawAccessToken, securityUser.getId())) { + public SecurityUser authenticate(String accessToken) throws AuthenticationException { + if (StringUtils.isEmpty(accessToken)) { + throw new BadCredentialsException("Token is invalid"); + } + SecurityUser securityUser = tokenFactory.parseAccessJwtToken(accessToken); + if (tokenOutdatingService.isOutdated(accessToken, securityUser.getId())) { throw new JwtExpiredTokenException("Token is outdated"); } - - return new JwtAuthenticationToken(securityUser); + return securityUser; } @Override diff --git a/application/src/main/java/org/thingsboard/server/service/security/auth/jwt/RefreshTokenAuthenticationProvider.java b/application/src/main/java/org/thingsboard/server/service/security/auth/jwt/RefreshTokenAuthenticationProvider.java index e75d3aa4d1..527ff97ba5 100644 --- a/application/src/main/java/org/thingsboard/server/service/security/auth/jwt/RefreshTokenAuthenticationProvider.java +++ b/application/src/main/java/org/thingsboard/server/service/security/auth/jwt/RefreshTokenAuthenticationProvider.java @@ -57,7 +57,7 @@ public class RefreshTokenAuthenticationProvider implements AuthenticationProvide public Authentication authenticate(Authentication authentication) throws AuthenticationException { Assert.notNull(authentication, "No authentication data provided"); RawAccessJwtToken rawAccessToken = (RawAccessJwtToken) authentication.getCredentials(); - SecurityUser unsafeUser = tokenFactory.parseRefreshToken(rawAccessToken); + SecurityUser unsafeUser = tokenFactory.parseRefreshToken(rawAccessToken.getToken()); UserPrincipal principal = unsafeUser.getUserPrincipal(); SecurityUser securityUser; @@ -67,7 +67,7 @@ public class RefreshTokenAuthenticationProvider implements AuthenticationProvide securityUser = authenticateByPublicId(principal.getValue()); } securityUser.setSessionId(unsafeUser.getSessionId()); - if (tokenOutdatingService.isOutdated(rawAccessToken, securityUser.getId())) { + if (tokenOutdatingService.isOutdated(rawAccessToken.getToken(), securityUser.getId())) { throw new CredentialsExpiredException("Token is outdated"); } diff --git a/application/src/main/java/org/thingsboard/server/service/security/exception/JwtExpiredTokenException.java b/application/src/main/java/org/thingsboard/server/service/security/exception/JwtExpiredTokenException.java index 87b7403a85..4e4430ceaa 100644 --- a/application/src/main/java/org/thingsboard/server/service/security/exception/JwtExpiredTokenException.java +++ b/application/src/main/java/org/thingsboard/server/service/security/exception/JwtExpiredTokenException.java @@ -16,23 +16,22 @@ package org.thingsboard.server.service.security.exception; import org.springframework.security.core.AuthenticationException; -import org.thingsboard.server.common.data.security.model.JwtToken; public class JwtExpiredTokenException extends AuthenticationException { private static final long serialVersionUID = -5959543783324224864L; - private JwtToken token; + private String token; public JwtExpiredTokenException(String msg) { super(msg); } - public JwtExpiredTokenException(JwtToken token, String msg, Throwable t) { + public JwtExpiredTokenException(String token, String msg, Throwable t) { super(msg, t); this.token = token; } public String token() { - return this.token.getToken(); + return this.token; } } diff --git a/application/src/main/java/org/thingsboard/server/service/security/model/token/JwtTokenFactory.java b/application/src/main/java/org/thingsboard/server/service/security/model/token/JwtTokenFactory.java index a222a7e83a..a9d32c4d95 100644 --- a/application/src/main/java/org/thingsboard/server/service/security/model/token/JwtTokenFactory.java +++ b/application/src/main/java/org/thingsboard/server/service/security/model/token/JwtTokenFactory.java @@ -93,8 +93,8 @@ public class JwtTokenFactory { return new AccessJwtToken(token); } - public SecurityUser parseAccessJwtToken(RawAccessJwtToken rawAccessToken) { - Jws jwsClaims = parseTokenClaims(rawAccessToken); + public SecurityUser parseAccessJwtToken(String token) { + Jws jwsClaims = parseTokenClaims(token); Claims claims = jwsClaims.getBody(); String subject = claims.getSubject(); @SuppressWarnings("unchecked") @@ -145,8 +145,8 @@ public class JwtTokenFactory { return new AccessJwtToken(token); } - public SecurityUser parseRefreshToken(RawAccessJwtToken rawAccessToken) { - Jws jwsClaims = parseTokenClaims(rawAccessToken); + public SecurityUser parseRefreshToken(String token) { + Jws jwsClaims = parseTokenClaims(token); Claims claims = jwsClaims.getBody(); String subject = claims.getSubject(); @SuppressWarnings("unchecked") @@ -200,11 +200,11 @@ public class JwtTokenFactory { .signWith(SignatureAlgorithm.HS512, jwtSettingsService.getJwtSettings().getTokenSigningKey()); } - public Jws parseTokenClaims(JwtToken token) { + public Jws parseTokenClaims(String token) { try { return Jwts.parser() .setSigningKey(jwtSettingsService.getJwtSettings().getTokenSigningKey()) - .parseClaimsJws(token.getToken()); + .parseClaimsJws(token); } catch (UnsupportedJwtException | MalformedJwtException | IllegalArgumentException ex) { log.debug("Invalid JWT Token", ex); throw new BadCredentialsException("Invalid JWT token: ", ex); diff --git a/application/src/main/java/org/thingsboard/server/service/subscription/TbAbstractSubCtx.java b/application/src/main/java/org/thingsboard/server/service/subscription/TbAbstractSubCtx.java index 102431f63b..b4468c0bd0 100644 --- a/application/src/main/java/org/thingsboard/server/service/subscription/TbAbstractSubCtx.java +++ b/application/src/main/java/org/thingsboard/server/service/subscription/TbAbstractSubCtx.java @@ -336,7 +336,7 @@ public abstract class TbAbstractSubCtx { public void sendWsMsg(CmdUpdate update) { wsLock.lock(); try { - wsService.sendWsMsg(sessionRef.getSessionId(), update); + wsService.sendUpdate(sessionRef.getSessionId(), update); } finally { wsLock.unlock(); } diff --git a/application/src/main/java/org/thingsboard/server/service/ws/AuthCmd.java b/application/src/main/java/org/thingsboard/server/service/ws/AuthCmd.java new file mode 100644 index 0000000000..776cdd40d2 --- /dev/null +++ b/application/src/main/java/org/thingsboard/server/service/ws/AuthCmd.java @@ -0,0 +1,33 @@ +/** + * Copyright © 2016-2023 The Thingsboard Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.thingsboard.server.service.ws; + +import lombok.AllArgsConstructor; +import lombok.Data; +import lombok.NoArgsConstructor; + +@Data +@NoArgsConstructor +@AllArgsConstructor +public class AuthCmd implements WsCmd { + private int cmdId; + private String token; + + @Override + public WsCmdType getType() { + return WsCmdType.AUTH; + } +} diff --git a/application/src/main/java/org/thingsboard/server/service/ws/DefaultWebSocketService.java b/application/src/main/java/org/thingsboard/server/service/ws/DefaultWebSocketService.java index 0d7dc7fbfe..b1118224ef 100644 --- a/application/src/main/java/org/thingsboard/server/service/ws/DefaultWebSocketService.java +++ b/application/src/main/java/org/thingsboard/server/service/ws/DefaultWebSocketService.java @@ -66,8 +66,6 @@ import org.thingsboard.server.service.subscription.TbEntityDataSubscriptionServi import org.thingsboard.server.service.subscription.TbLocalSubscriptionService; import org.thingsboard.server.service.subscription.TbTimeSeriesSubscription; import org.thingsboard.server.service.ws.notification.NotificationCommandsHandler; -import org.thingsboard.server.service.ws.notification.cmd.NotificationCmdsWrapper; -import org.thingsboard.server.service.ws.telemetry.cmd.TelemetryCmdsWrapper; import org.thingsboard.server.service.ws.telemetry.cmd.v1.AttributesSubscriptionCmd; import org.thingsboard.server.service.ws.telemetry.cmd.v1.GetHistoryCmd; import org.thingsboard.server.service.ws.telemetry.cmd.v1.SubscriptionCmd; @@ -124,7 +122,6 @@ public class DefaultWebSocketService implements WebSocketService { private static final String FAILED_TO_FETCH_DATA = "Failed to fetch data!"; private static final String FAILED_TO_FETCH_ATTRIBUTES = "Failed to fetch attributes!"; private static final String SESSION_META_DATA_NOT_FOUND = "Session meta-data not found!"; - private static final String FAILED_TO_PARSE_WS_COMMAND = "Failed to parse websocket command!"; private final ConcurrentMap wsSessionsMap = new ConcurrentHashMap<>(); @@ -192,7 +189,7 @@ public class DefaultWebSocketService implements WebSocketService { } @Override - public void handleWebSocketSessionEvent(WebSocketSessionRef sessionRef, SessionEvent event) { + public void handleSessionEvent(WebSocketSessionRef sessionRef, SessionEvent event) { String sessionId = sessionRef.getSessionId(); log.debug(PROCESSING_MSG, sessionId, event); switch (event.getEventType()) { @@ -212,46 +209,20 @@ public class DefaultWebSocketService implements WebSocketService { } @Override - public void handleWebSocketMsg(WebSocketSessionRef sessionRef, String msg) { - if (log.isTraceEnabled()) { - log.trace("[{}] Processing: {}", sessionRef.getSessionId(), msg); - } - - try { - WsCommandsWrapper cmdsWrapper; - switch (sessionRef.getSessionType()) { - case GENERAL: - cmdsWrapper = JacksonUtil.fromString(msg, WsCommandsWrapper.class); - break; - case TELEMETRY: - cmdsWrapper = JacksonUtil.fromString(msg, TelemetryCmdsWrapper.class).toCommonCmdsWrapper(); - break; - case NOTIFICATIONS: - cmdsWrapper = JacksonUtil.fromString(msg, NotificationCmdsWrapper.class).toCommonCmdsWrapper(); - break; - default: - throw new IllegalArgumentException("Unknown session type"); - } - processCmds(sessionRef, cmdsWrapper); - } catch (Exception e) { - log.warn("Failed to decode subscription cmd: {}", e.getMessage(), e); - sendWsMsg(sessionRef, new TelemetrySubscriptionUpdate(UNKNOWN_SUBSCRIPTION_ID, SubscriptionErrorCode.BAD_REQUEST, FAILED_TO_PARSE_WS_COMMAND)); - } - } - - private void processCmds(WebSocketSessionRef sessionRef, WsCommandsWrapper cmdsWrapper) { - if (cmdsWrapper == null || CollectionUtils.isEmpty(cmdsWrapper.getCmds())) { + public void handleCommands(WebSocketSessionRef sessionRef, WsCommandsWrapper commandsWrapper) { + if (commandsWrapper == null || CollectionUtils.isEmpty(commandsWrapper.getCmds())) { return; } String sessionId = sessionRef.getSessionId(); - if (!validateSessionMetadata(sessionRef, cmdsWrapper.getCmds().get(0).getCmdId(), sessionId)) { + if (!validateSessionMetadata(sessionRef, UNKNOWN_SUBSCRIPTION_ID, sessionId)) { return; } - for (WsCmd cmd : cmdsWrapper.getCmds()) { + for (WsCmd cmd : commandsWrapper.getCmds()) { log.debug("[{}][{}][{}] Processing cmd: {}", sessionId, cmd.getType(), cmd.getCmdId(), cmd); try { - getCmdHandler(cmd.getType()).handle(sessionRef, cmd); + Optional.ofNullable(getCmdHandler(cmd.getType())) + .ifPresent(cmdHandler -> cmdHandler.handle(sessionRef, cmd)); } catch (Exception e) { log.error("[sessionId: {}, tenantId: {}, userId: {}] Failed to handle WS cmd: {}", sessionId, sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), cmd, e); @@ -288,19 +259,25 @@ public class DefaultWebSocketService implements WebSocketService { } @Override - public void sendWsMsg(String sessionId, TelemetrySubscriptionUpdate update) { - sendWsMsg(sessionId, update.getSubscriptionId(), update); + public void sendUpdate(String sessionId, TelemetrySubscriptionUpdate update) { + sendUpdate(sessionId, update.getSubscriptionId(), update); } @Override - public void sendWsMsg(String sessionId, CmdUpdate update) { - sendWsMsg(sessionId, update.getCmdId(), update); + public void sendUpdate(String sessionId, CmdUpdate update) { + sendUpdate(sessionId, update.getCmdId(), update); } - private void sendWsMsg(String sessionId, int cmdId, T update) { + @Override + public void sendError(WebSocketSessionRef sessionRef, int subId, SubscriptionErrorCode errorCode, String errorMsg) { + TelemetrySubscriptionUpdate update = new TelemetrySubscriptionUpdate(subId, errorCode, errorMsg); + sendUpdate(sessionRef, update); + } + + private void sendUpdate(String sessionId, int cmdId, T update) { WsSessionMetaData md = wsSessionsMap.get(sessionId); if (md != null) { - sendWsMsg(md.getSessionRef(), cmdId, update); + sendUpdate(md.getSessionRef(), cmdId, update); } } @@ -472,7 +449,7 @@ public class DefaultWebSocketService implements WebSocketService { .updateProcessor((subscription, update) -> { subLock.lock(); try { - sendWsMsg(subscription.getSessionId(), update); + sendUpdate(subscription.getSessionId(), update); } finally { subLock.unlock(); } @@ -482,7 +459,7 @@ public class DefaultWebSocketService implements WebSocketService { subLock.lock(); try { oldSubService.addSubscription(sub); - sendWsMsg(sessionRef, new TelemetrySubscriptionUpdate(cmd.getCmdId(), attributesData)); + sendUpdate(sessionRef, new TelemetrySubscriptionUpdate(cmd.getCmdId(), attributesData)); } finally { subLock.unlock(); } @@ -500,7 +477,7 @@ public class DefaultWebSocketService implements WebSocketService { update = new TelemetrySubscriptionUpdate(cmd.getCmdId(), SubscriptionErrorCode.INTERNAL_ERROR, FAILED_TO_FETCH_ATTRIBUTES); } - sendWsMsg(sessionRef, update); + sendUpdate(sessionRef, update); } }; @@ -529,7 +506,7 @@ public class DefaultWebSocketService implements WebSocketService { FutureCallback> callback = new FutureCallback>() { @Override public void onSuccess(List data) { - sendWsMsg(sessionRef, new TelemetrySubscriptionUpdate(cmd.getCmdId(), data)); + sendUpdate(sessionRef, new TelemetrySubscriptionUpdate(cmd.getCmdId(), data)); } @Override @@ -542,7 +519,7 @@ public class DefaultWebSocketService implements WebSocketService { update = new TelemetrySubscriptionUpdate(cmd.getCmdId(), SubscriptionErrorCode.INTERNAL_ERROR, FAILED_TO_FETCH_DATA); } - sendWsMsg(sessionRef, update); + sendUpdate(sessionRef, update); } }; accessValidator.validate(sessionRef.getSecurityCtx(), Operation.READ_TELEMETRY, entityId, @@ -577,7 +554,7 @@ public class DefaultWebSocketService implements WebSocketService { .updateProcessor((subscription, update) -> { subLock.lock(); try { - sendWsMsg(subscription.getSessionId(), update); + sendUpdate(subscription.getSessionId(), update); } finally { subLock.unlock(); } @@ -588,7 +565,7 @@ public class DefaultWebSocketService implements WebSocketService { subLock.lock(); try { oldSubService.addSubscription(sub); - sendWsMsg(sessionRef, new TelemetrySubscriptionUpdate(cmd.getCmdId(), attributesData)); + sendUpdate(sessionRef, new TelemetrySubscriptionUpdate(cmd.getCmdId(), attributesData)); } finally { subLock.unlock(); } @@ -672,7 +649,7 @@ public class DefaultWebSocketService implements WebSocketService { .updateProcessor((subscription, update) -> { subLock.lock(); try { - sendWsMsg(subscription.getSessionId(), update); + sendUpdate(subscription.getSessionId(), update); } finally { subLock.unlock(); } @@ -685,7 +662,7 @@ public class DefaultWebSocketService implements WebSocketService { subLock.lock(); try { oldSubService.addSubscription(sub); - sendWsMsg(sessionRef, new TelemetrySubscriptionUpdate(cmd.getCmdId(), data)); + sendUpdate(sessionRef, new TelemetrySubscriptionUpdate(cmd.getCmdId(), data)); } finally { subLock.unlock(); } @@ -701,7 +678,7 @@ public class DefaultWebSocketService implements WebSocketService { update = new TelemetrySubscriptionUpdate(cmd.getCmdId(), SubscriptionErrorCode.INTERNAL_ERROR, FAILED_TO_FETCH_DATA); } - sendWsMsg(sessionRef, update); + sendUpdate(sessionRef, update); } }; accessValidator.validate(sessionRef.getSecurityCtx(), Operation.READ_TELEMETRY, entityId, @@ -727,7 +704,7 @@ public class DefaultWebSocketService implements WebSocketService { .updateProcessor((subscription, update) -> { subLock.lock(); try { - sendWsMsg(subscription.getSessionId(), update); + sendUpdate(subscription.getSessionId(), update); } finally { subLock.unlock(); } @@ -740,7 +717,7 @@ public class DefaultWebSocketService implements WebSocketService { subLock.lock(); try { oldSubService.addSubscription(sub); - sendWsMsg(sessionRef, new TelemetrySubscriptionUpdate(cmd.getCmdId(), data)); + sendUpdate(sessionRef, new TelemetrySubscriptionUpdate(cmd.getCmdId(), data)); } finally { subLock.unlock(); } @@ -829,20 +806,15 @@ public class DefaultWebSocketService implements WebSocketService { return true; } - private void sendError(WebSocketSessionRef sessionRef, int subId, SubscriptionErrorCode errorCode, String errorMsg) { - TelemetrySubscriptionUpdate update = new TelemetrySubscriptionUpdate(subId, errorCode, errorMsg); - sendWsMsg(sessionRef, update); - } - - private void sendWsMsg(WebSocketSessionRef sessionRef, EntityDataUpdate update) { - sendWsMsg(sessionRef, update.getCmdId(), update); + private void sendUpdate(WebSocketSessionRef sessionRef, EntityDataUpdate update) { + sendUpdate(sessionRef, update.getCmdId(), update); } - private void sendWsMsg(WebSocketSessionRef sessionRef, TelemetrySubscriptionUpdate update) { - sendWsMsg(sessionRef, update.getSubscriptionId(), update); + private void sendUpdate(WebSocketSessionRef sessionRef, TelemetrySubscriptionUpdate update) { + sendUpdate(sessionRef, update.getSubscriptionId(), update); } - private void sendWsMsg(WebSocketSessionRef sessionRef, int cmdId, Object update) { + private void sendUpdate(WebSocketSessionRef sessionRef, int cmdId, Object update) { try { String msg = JacksonUtil.OBJECT_MAPPER.writeValueAsString(update); executor.submit(() -> { @@ -997,7 +969,7 @@ public class DefaultWebSocketService implements WebSocketService { return cmdHandler; } } - throw new IllegalArgumentException("Unknown command type " + cmdType); + return null; } public static WsCmdHandler newCmdHandler(WsCmdType cmdType, BiConsumer handler) { diff --git a/application/src/main/java/org/thingsboard/server/service/ws/WebSocketService.java b/application/src/main/java/org/thingsboard/server/service/ws/WebSocketService.java index 1ebd6101f0..19753b91ca 100644 --- a/application/src/main/java/org/thingsboard/server/service/ws/WebSocketService.java +++ b/application/src/main/java/org/thingsboard/server/service/ws/WebSocketService.java @@ -16,6 +16,7 @@ package org.thingsboard.server.service.ws; import org.springframework.web.socket.CloseStatus; +import org.thingsboard.server.service.subscription.SubscriptionErrorCode; import org.thingsboard.server.service.ws.telemetry.cmd.v2.CmdUpdate; import org.thingsboard.server.service.ws.telemetry.sub.TelemetrySubscriptionUpdate; @@ -24,13 +25,15 @@ import org.thingsboard.server.service.ws.telemetry.sub.TelemetrySubscriptionUpda */ public interface WebSocketService { - void handleWebSocketSessionEvent(WebSocketSessionRef sessionRef, SessionEvent sessionEvent); + void handleSessionEvent(WebSocketSessionRef sessionRef, SessionEvent sessionEvent); - void handleWebSocketMsg(WebSocketSessionRef sessionRef, String msg); + void handleCommands(WebSocketSessionRef sessionRef, WsCommandsWrapper commandsWrapper); - void sendWsMsg(String sessionId, TelemetrySubscriptionUpdate update); + void sendUpdate(String sessionId, TelemetrySubscriptionUpdate update); - void sendWsMsg(String sessionId, CmdUpdate update); + void sendUpdate(String sessionId, CmdUpdate update); + + void sendError(WebSocketSessionRef sessionRef, int subId, SubscriptionErrorCode errorCode, String errorMsg); void close(String sessionId, CloseStatus status); } diff --git a/application/src/main/java/org/thingsboard/server/service/ws/WebSocketSessionRef.java b/application/src/main/java/org/thingsboard/server/service/ws/WebSocketSessionRef.java index e799c8faa0..1dc75b0149 100644 --- a/application/src/main/java/org/thingsboard/server/service/ws/WebSocketSessionRef.java +++ b/application/src/main/java/org/thingsboard/server/service/ws/WebSocketSessionRef.java @@ -16,8 +16,7 @@ package org.thingsboard.server.service.ws; import lombok.Builder; -import lombok.Getter; -import lombok.RequiredArgsConstructor; +import lombok.Data; import org.thingsboard.server.service.security.model.SecurityUser; import java.net.InetSocketAddress; @@ -27,15 +26,14 @@ import java.util.concurrent.atomic.AtomicInteger; /** * Created by ashvayka on 27.03.18. */ -@RequiredArgsConstructor @Builder -@Getter +@Data public class WebSocketSessionRef { private static final long serialVersionUID = 1L; private final String sessionId; - private final SecurityUser securityCtx; + private SecurityUser securityCtx; private final InetSocketAddress localAddress; private final InetSocketAddress remoteAddress; private final WebSocketSessionType sessionType; diff --git a/application/src/main/java/org/thingsboard/server/service/ws/WsCmdType.java b/application/src/main/java/org/thingsboard/server/service/ws/WsCmdType.java index ef1d8520a2..285dd88efe 100644 --- a/application/src/main/java/org/thingsboard/server/service/ws/WsCmdType.java +++ b/application/src/main/java/org/thingsboard/server/service/ws/WsCmdType.java @@ -16,6 +16,8 @@ package org.thingsboard.server.service.ws; public enum WsCmdType { + AUTH, + ATTRIBUTES, TIMESERIES, TIMESERIES_HISTORY, diff --git a/application/src/main/java/org/thingsboard/server/service/ws/WsCommandsWrapper.java b/application/src/main/java/org/thingsboard/server/service/ws/WsCommandsWrapper.java index cbdbdf8c05..5704f3b2dc 100644 --- a/application/src/main/java/org/thingsboard/server/service/ws/WsCommandsWrapper.java +++ b/application/src/main/java/org/thingsboard/server/service/ws/WsCommandsWrapper.java @@ -45,6 +45,8 @@ import java.util.List; @NoArgsConstructor public class WsCommandsWrapper { + private AuthCmd authCmd; + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, property = "type") @JsonSubTypes({ @Type(name = "ATTRIBUTES", value = AttributesSubscriptionCmd.class), diff --git a/application/src/main/java/org/thingsboard/server/service/ws/notification/DefaultNotificationCommandsHandler.java b/application/src/main/java/org/thingsboard/server/service/ws/notification/DefaultNotificationCommandsHandler.java index 0937275a0d..5a02814870 100644 --- a/application/src/main/java/org/thingsboard/server/service/ws/notification/DefaultNotificationCommandsHandler.java +++ b/application/src/main/java/org/thingsboard/server/service/ws/notification/DefaultNotificationCommandsHandler.java @@ -245,7 +245,7 @@ public class DefaultNotificationCommandsHandler implements NotificationCommandsH private void sendUpdate(String sessionId, CmdUpdate update) { log.trace("[{}, cmdId: {}] Sending WS update: {}", sessionId, update.getCmdId(), update); - wsService.sendWsMsg(sessionId, update); + wsService.sendUpdate(sessionId, update); } } diff --git a/application/src/main/java/org/thingsboard/server/service/ws/notification/cmd/NotificationCmdsWrapper.java b/application/src/main/java/org/thingsboard/server/service/ws/notification/cmd/NotificationCmdsWrapper.java index 3e0eca1fab..8ba52a9b3c 100644 --- a/application/src/main/java/org/thingsboard/server/service/ws/notification/cmd/NotificationCmdsWrapper.java +++ b/application/src/main/java/org/thingsboard/server/service/ws/notification/cmd/NotificationCmdsWrapper.java @@ -42,7 +42,7 @@ public class NotificationCmdsWrapper { @JsonIgnore public WsCommandsWrapper toCommonCmdsWrapper() { - return new WsCommandsWrapper(Stream.of( + return new WsCommandsWrapper(null, Stream.of( unreadCountSubCmd, unreadSubCmd, markAsReadCmd, markAllAsReadCmd, unsubCmd ) .filter(Objects::nonNull) diff --git a/application/src/main/java/org/thingsboard/server/service/ws/telemetry/cmd/TelemetryCmdsWrapper.java b/application/src/main/java/org/thingsboard/server/service/ws/telemetry/cmd/TelemetryCmdsWrapper.java index e2029cbe36..1019124e6d 100644 --- a/application/src/main/java/org/thingsboard/server/service/ws/telemetry/cmd/TelemetryCmdsWrapper.java +++ b/application/src/main/java/org/thingsboard/server/service/ws/telemetry/cmd/TelemetryCmdsWrapper.java @@ -67,7 +67,7 @@ public class TelemetryCmdsWrapper { @JsonIgnore public WsCommandsWrapper toCommonCmdsWrapper() { - return new WsCommandsWrapper(Stream.of( + return new WsCommandsWrapper(null, Stream.of( attrSubCmds, tsSubCmds, historyCmds, entityDataCmds, entityDataUnsubscribeCmds, alarmDataCmds, alarmDataUnsubscribeCmds, entityCountCmds, entityCountUnsubscribeCmds, diff --git a/application/src/main/java/org/thingsboard/server/service/ws/telemetry/cmd/v2/AuthCmdUpdate.java b/application/src/main/java/org/thingsboard/server/service/ws/telemetry/cmd/v2/AuthCmdUpdate.java new file mode 100644 index 0000000000..61ed4fc9ce --- /dev/null +++ b/application/src/main/java/org/thingsboard/server/service/ws/telemetry/cmd/v2/AuthCmdUpdate.java @@ -0,0 +1,34 @@ +/** + * Copyright © 2016-2023 The Thingsboard Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.thingsboard.server.service.ws.telemetry.cmd.v2; + +import org.thingsboard.server.service.subscription.SubscriptionErrorCode; + +public class AuthCmdUpdate extends CmdUpdate { + + public AuthCmdUpdate(int cmdId) { + this(cmdId, SubscriptionErrorCode.NO_ERROR.getCode(), null); + } + + public AuthCmdUpdate(int cmdId, int errorCode, String errorMsg) { + super(cmdId, errorCode, errorMsg); + } + + @Override + public CmdUpdateType getCmdUpdateType() { + return CmdUpdateType.AUTH; + } +} diff --git a/application/src/main/java/org/thingsboard/server/service/ws/telemetry/cmd/v2/CmdUpdateType.java b/application/src/main/java/org/thingsboard/server/service/ws/telemetry/cmd/v2/CmdUpdateType.java index f5b3809ce2..04b3cbd06e 100644 --- a/application/src/main/java/org/thingsboard/server/service/ws/telemetry/cmd/v2/CmdUpdateType.java +++ b/application/src/main/java/org/thingsboard/server/service/ws/telemetry/cmd/v2/CmdUpdateType.java @@ -21,5 +21,6 @@ public enum CmdUpdateType { ALARM_COUNT_DATA, COUNT_DATA, NOTIFICATIONS, - NOTIFICATIONS_COUNT + NOTIFICATIONS_COUNT, + AUTH } diff --git a/application/src/test/java/org/thingsboard/server/controller/AbstractControllerTest.java b/application/src/test/java/org/thingsboard/server/controller/AbstractControllerTest.java index fe09a187e1..1729285034 100644 --- a/application/src/test/java/org/thingsboard/server/controller/AbstractControllerTest.java +++ b/application/src/test/java/org/thingsboard/server/controller/AbstractControllerTest.java @@ -105,8 +105,11 @@ public abstract class AbstractControllerTest extends AbstractNotifyEntityTest { } protected TbTestWebSocketClient buildAndConnectWebSocketClient(String path) throws URISyntaxException, InterruptedException { - TbTestWebSocketClient wsClient = new TbTestWebSocketClient(new URI(WS_URL + wsPort + path + "?token=" + token)); + TbTestWebSocketClient wsClient = new TbTestWebSocketClient(new URI(WS_URL + wsPort + path)); assertThat(wsClient.connectBlocking(TIMEOUT, TimeUnit.SECONDS)).isTrue(); + if (!path.contains("token=")) { + wsClient.authenticate(token); + } return wsClient; } diff --git a/application/src/test/java/org/thingsboard/server/controller/TbTestWebSocketClient.java b/application/src/test/java/org/thingsboard/server/controller/TbTestWebSocketClient.java index 0194911e5b..69eefa78bd 100644 --- a/application/src/test/java/org/thingsboard/server/controller/TbTestWebSocketClient.java +++ b/application/src/test/java/org/thingsboard/server/controller/TbTestWebSocketClient.java @@ -27,6 +27,7 @@ import org.thingsboard.server.common.data.query.EntityDataPageLink; import org.thingsboard.server.common.data.query.EntityDataQuery; import org.thingsboard.server.common.data.query.EntityFilter; import org.thingsboard.server.common.data.query.EntityKey; +import org.thingsboard.server.service.ws.AuthCmd; import org.thingsboard.server.service.ws.WsCmd; import org.thingsboard.server.service.ws.WsCommandsWrapper; import org.thingsboard.server.service.ws.telemetry.cmd.v1.AttributesSubscriptionCmd; @@ -64,6 +65,13 @@ public class TbTestWebSocketClient extends WebSocketClient { } + public void authenticate(String token) { + WsCommandsWrapper cmdsWrapper = new WsCommandsWrapper(); + cmdsWrapper.setAuthCmd(new AuthCmd(1, token)); + send(JacksonUtil.toString(cmdsWrapper)); + waitForReply(); + } + @Override public void onMessage(String s) { log.info("RECEIVED: {}", s); diff --git a/application/src/test/java/org/thingsboard/server/controller/WebsocketApiTest.java b/application/src/test/java/org/thingsboard/server/controller/WebsocketApiTest.java index 070fda9474..47ce021dae 100644 --- a/application/src/test/java/org/thingsboard/server/controller/WebsocketApiTest.java +++ b/application/src/test/java/org/thingsboard/server/controller/WebsocketApiTest.java @@ -667,7 +667,7 @@ public class WebsocketApiTest extends AbstractControllerTest { ObjectNode wrapperNode = JacksonUtil.newObjectNode(); wrapperNode.set("entityCountCmds", entityCountCmds); - wsClient = buildAndConnectWebSocketClient("/api/ws/plugins/telemetry"); + wsClient = buildAndConnectWebSocketClient("/api/ws/plugins/telemetry?token=" + token); wsClient.send(JacksonUtil.toString(wrapperNode)); EntityCountUpdate update = wsClient.parseCountReply(wsClient.waitForReply()); diff --git a/application/src/test/java/org/thingsboard/server/service/notification/AbstractNotificationApiTest.java b/application/src/test/java/org/thingsboard/server/service/notification/AbstractNotificationApiTest.java index 27a0eed1a4..f7fc195146 100644 --- a/application/src/test/java/org/thingsboard/server/service/notification/AbstractNotificationApiTest.java +++ b/application/src/test/java/org/thingsboard/server/service/notification/AbstractNotificationApiTest.java @@ -259,8 +259,9 @@ public abstract class AbstractNotificationApiTest extends AbstractControllerTest @Override protected NotificationApiWsClient buildAndConnectWebSocketClient() throws URISyntaxException, InterruptedException { - NotificationApiWsClient wsClient = new NotificationApiWsClient(WS_URL + wsPort, token); + NotificationApiWsClient wsClient = new NotificationApiWsClient(WS_URL + wsPort); assertThat(wsClient.connectBlocking(TIMEOUT, TimeUnit.SECONDS)).isTrue(); + wsClient.authenticate(token); return wsClient; } diff --git a/application/src/test/java/org/thingsboard/server/service/notification/NotificationApiWsClient.java b/application/src/test/java/org/thingsboard/server/service/notification/NotificationApiWsClient.java index 2cff2939c2..d96d04b31a 100644 --- a/application/src/test/java/org/thingsboard/server/service/notification/NotificationApiWsClient.java +++ b/application/src/test/java/org/thingsboard/server/service/notification/NotificationApiWsClient.java @@ -48,8 +48,8 @@ public class NotificationApiWsClient extends TbTestWebSocketClient { private int unreadCount; private List notifications; - public NotificationApiWsClient(String wsUrl, String token) throws URISyntaxException { - super(new URI(wsUrl + "/api/ws?token=" + token)); + public NotificationApiWsClient(String wsUrl) throws URISyntaxException { + super(new URI(wsUrl + "/api/ws")); } public NotificationApiWsClient subscribeForUnreadNotifications(int limit) { diff --git a/application/src/test/java/org/thingsboard/server/service/security/auth/JwtTokenFactoryTest.java b/application/src/test/java/org/thingsboard/server/service/security/auth/JwtTokenFactoryTest.java index 7bb4f1e559..48070678f6 100644 --- a/application/src/test/java/org/thingsboard/server/service/security/auth/JwtTokenFactoryTest.java +++ b/application/src/test/java/org/thingsboard/server/service/security/auth/JwtTokenFactoryTest.java @@ -106,7 +106,7 @@ public class JwtTokenFactoryTest { AccessJwtToken accessToken = tokenFactory.createAccessJwtToken(securityUser); checkExpirationTime(accessToken, jwtSettings.getTokenExpirationTime()); - SecurityUser parsedSecurityUser = tokenFactory.parseAccessJwtToken(new RawAccessJwtToken(accessToken.getToken())); + SecurityUser parsedSecurityUser = tokenFactory.parseAccessJwtToken(accessToken.getToken()); assertThat(parsedSecurityUser.getId()).isEqualTo(securityUser.getId()); assertThat(parsedSecurityUser.getEmail()).isEqualTo(securityUser.getEmail()); assertThat(parsedSecurityUser.getUserPrincipal()).matches(userPrincipal -> { @@ -135,7 +135,7 @@ public class JwtTokenFactoryTest { JwtToken refreshToken = tokenFactory.createRefreshToken(securityUser); checkExpirationTime(refreshToken, jwtSettings.getRefreshTokenExpTime()); - SecurityUser parsedSecurityUser = tokenFactory.parseRefreshToken(new RawAccessJwtToken(refreshToken.getToken())); + SecurityUser parsedSecurityUser = tokenFactory.parseRefreshToken(refreshToken.getToken()); assertThat(parsedSecurityUser.getId()).isEqualTo(securityUser.getId()); assertThat(parsedSecurityUser.getUserPrincipal()).matches(userPrincipal -> { return userPrincipal.getType().equals(securityUser.getUserPrincipal().getType()) @@ -159,7 +159,7 @@ public class JwtTokenFactoryTest { JwtToken preVerificationToken = tokenFactory.createPreVerificationToken(securityUser, tokenLifetime); checkExpirationTime(preVerificationToken, tokenLifetime); - SecurityUser parsedSecurityUser = tokenFactory.parseAccessJwtToken(new RawAccessJwtToken(preVerificationToken.getToken())); + SecurityUser parsedSecurityUser = tokenFactory.parseAccessJwtToken(preVerificationToken.getToken()); assertThat(parsedSecurityUser.getId()).isEqualTo(securityUser.getId()); assertThat(parsedSecurityUser.getAuthority()).isEqualTo(Authority.PRE_VERIFICATION_TOKEN); assertThat(parsedSecurityUser.getTenantId()).isEqualTo(securityUser.getTenantId()); @@ -198,7 +198,7 @@ public class JwtTokenFactoryTest { } private void checkExpirationTime(JwtToken jwtToken, int tokenLifetime) { - Claims claims = tokenFactory.parseTokenClaims(jwtToken).getBody(); + Claims claims = tokenFactory.parseTokenClaims(jwtToken.getToken()).getBody(); assertThat(claims.getExpiration()).matches(actualExpirationTime -> { Calendar expirationTime = Calendar.getInstance(); expirationTime.setTime(new Date()); diff --git a/application/src/test/java/org/thingsboard/server/service/security/auth/TokenOutdatingTest.java b/application/src/test/java/org/thingsboard/server/service/security/auth/TokenOutdatingTest.java index eeb9fe64a2..90d911833c 100644 --- a/application/src/test/java/org/thingsboard/server/service/security/auth/TokenOutdatingTest.java +++ b/application/src/test/java/org/thingsboard/server/service/security/auth/TokenOutdatingTest.java @@ -114,12 +114,12 @@ public class TokenOutdatingTest { // Token outdatage time is rounded to 1 sec. Need to wait before outdating so that outdatage time is strictly after token issue time SECONDS.sleep(1); eventPublisher.publishEvent(new UserCredentialsInvalidationEvent(securityUser.getId())); - assertTrue(tokenOutdatingService.isOutdated(jwtToken, securityUser.getId())); + assertTrue(tokenOutdatingService.isOutdated(jwtToken.getToken(), securityUser.getId())); SECONDS.sleep(1); JwtToken newJwtToken = tokenFactory.createAccessJwtToken(securityUser); - assertFalse(tokenOutdatingService.isOutdated(newJwtToken, securityUser.getId())); + assertFalse(tokenOutdatingService.isOutdated(newJwtToken.getToken(), securityUser.getId())); } @Test