diff --git a/application/src/main/java/org/thingsboard/server/controller/plugin/TbWebSocketHandler.java b/application/src/main/java/org/thingsboard/server/controller/plugin/TbWebSocketHandler.java index 7a396ff795..481d412d59 100644 --- a/application/src/main/java/org/thingsboard/server/controller/plugin/TbWebSocketHandler.java +++ b/application/src/main/java/org/thingsboard/server/controller/plugin/TbWebSocketHandler.java @@ -219,12 +219,12 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke .build(); } - private class SessionMetaData implements SendHandler { + class SessionMetaData implements SendHandler { private final WebSocketSession session; private final RemoteEndpoint.Async asyncRemote; private final WebSocketSessionRef sessionRef; - private final AtomicBoolean isSending = new AtomicBoolean(false); + final AtomicBoolean isSending = new AtomicBoolean(false); private final Queue> msgQueue; private volatile long lastActivityTime; @@ -239,7 +239,7 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke this.lastActivityTime = System.currentTimeMillis(); } - synchronized void sendPing(long currentTime) { + void sendPing(long currentTime) { try { long timeSinceLastActivity = currentTime - lastActivityTime; if (timeSinceLastActivity >= pingTimeout) { @@ -254,37 +254,38 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke } } - private void closeSession(CloseStatus reason) { + void closeSession(CloseStatus reason) { try { close(this.sessionRef, reason); } catch (IOException ioe) { log.trace("[{}] Session transport error", session.getId(), ioe); + } finally { + msgQueue.clear(); } } - synchronized void processPongMessage(long currentTime) { + void processPongMessage(long currentTime) { lastActivityTime = currentTime; } - synchronized void sendMsg(String msg) { + void sendMsg(String msg) { sendMsg(new TbWebSocketTextMsg(msg)); } - synchronized void sendMsg(TbWebSocketMsg msg) { - if (isSending.compareAndSet(false, true)) { - sendMsgInternal(msg); - } else { - try { - msgQueue.add(msg); - } catch (RuntimeException e) { - if (log.isTraceEnabled()) { - log.trace("[{}][{}] Session closed due to queue error", sessionRef.getSecurityCtx().getTenantId(), session.getId(), e); - } else { - log.info("[{}][{}] Session closed due to queue error", sessionRef.getSecurityCtx().getTenantId(), session.getId()); - } - closeSession(CloseStatus.POLICY_VIOLATION.withReason("Max pending updates limit reached!")); + void sendMsg(TbWebSocketMsg msg) { + try { + msgQueue.add(msg); + } catch (RuntimeException e) { + if (log.isTraceEnabled()) { + log.trace("[{}][{}] Session closed due to queue error", sessionRef.getSecurityCtx().getTenantId(), session.getId(), e); + } else { + log.info("[{}][{}] Session closed due to queue error", sessionRef.getSecurityCtx().getTenantId(), session.getId()); } + closeSession(CloseStatus.POLICY_VIOLATION.withReason("Max pending updates limit reached!")); + return; } + + processNextMsg(); } private void sendMsgInternal(TbWebSocketMsg msg) { @@ -292,9 +293,11 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke if (TbWebSocketMsgType.TEXT.equals(msg.getType())) { TbWebSocketTextMsg textMsg = (TbWebSocketTextMsg) msg; this.asyncRemote.sendText(textMsg.getMsg(), this); + // isSending status will be reset in the onResult method by call back } else { TbWebSocketPingMsg pingMsg = (TbWebSocketPingMsg) msg; - this.asyncRemote.sendPing(pingMsg.getMsg()); + this.asyncRemote.sendPing(pingMsg.getMsg()); // blocking call + isSending.set(false); processNextMsg(); } } catch (Exception e) { @@ -308,12 +311,17 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke if (!result.isOK()) { log.trace("[{}] Failed to send msg", session.getId(), result.getException()); closeSession(CloseStatus.SESSION_NOT_RELIABLE); - } else { - processNextMsg(); + return; } + + isSending.set(false); + processNextMsg(); } private void processNextMsg() { + if (msgQueue.isEmpty() || !isSending.compareAndSet(false, true)) { + return; + } TbWebSocketMsg msg = msgQueue.poll(); if (msg != null) { sendMsgInternal(msg); @@ -397,19 +405,21 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke if (tenantProfileConfiguration == null) { return true; } - + boolean limitAllowed; String sessionId = session.getId(); if (tenantProfileConfiguration.getMaxWsSessionsPerTenant() > 0) { Set tenantSessions = tenantSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getTenantId(), id -> ConcurrentHashMap.newKeySet()); synchronized (tenantSessions) { - if (tenantSessions.size() < tenantProfileConfiguration.getMaxWsSessionsPerTenant()) { + limitAllowed = tenantSessions.size() < tenantProfileConfiguration.getMaxWsSessionsPerTenant(); + if (limitAllowed) { tenantSessions.add(sessionId); - } else { + } + } + if (!limitAllowed) { log.info("[{}][{}][{}] Failed to start session. Max tenant sessions limit reached" , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), sessionId); session.close(CloseStatus.POLICY_VIOLATION.withReason("Max tenant sessions limit reached!")); return false; - } } } @@ -417,42 +427,48 @@ public class TbWebSocketHandler extends TextWebSocketHandler implements WebSocke if (tenantProfileConfiguration.getMaxWsSessionsPerCustomer() > 0) { Set customerSessions = customerSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getCustomerId(), id -> ConcurrentHashMap.newKeySet()); synchronized (customerSessions) { - if (customerSessions.size() < tenantProfileConfiguration.getMaxWsSessionsPerCustomer()) { + limitAllowed = customerSessions.size() < tenantProfileConfiguration.getMaxWsSessionsPerCustomer(); + if (limitAllowed) { customerSessions.add(sessionId); - } else { + } + } + if (!limitAllowed) { log.info("[{}][{}][{}] Failed to start session. Max customer sessions limit reached" , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), sessionId); session.close(CloseStatus.POLICY_VIOLATION.withReason("Max customer sessions limit reached")); return false; - } } } if (tenantProfileConfiguration.getMaxWsSessionsPerRegularUser() > 0 && UserPrincipal.Type.USER_NAME.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) { Set regularUserSessions = regularUserSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet()); synchronized (regularUserSessions) { - if (regularUserSessions.size() < tenantProfileConfiguration.getMaxWsSessionsPerRegularUser()) { + limitAllowed = regularUserSessions.size() < tenantProfileConfiguration.getMaxWsSessionsPerRegularUser(); + if (limitAllowed) { regularUserSessions.add(sessionId); - } else { + } + } + if (!limitAllowed) { log.info("[{}][{}][{}] Failed to start session. Max regular user sessions limit reached" , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), sessionId); session.close(CloseStatus.POLICY_VIOLATION.withReason("Max regular user sessions limit reached")); return false; - } } } if (tenantProfileConfiguration.getMaxWsSessionsPerPublicUser() > 0 && UserPrincipal.Type.PUBLIC_ID.equals(sessionRef.getSecurityCtx().getUserPrincipal().getType())) { Set publicUserSessions = publicUserSessionsMap.computeIfAbsent(sessionRef.getSecurityCtx().getId(), id -> ConcurrentHashMap.newKeySet()); synchronized (publicUserSessions) { - if (publicUserSessions.size() < tenantProfileConfiguration.getMaxWsSessionsPerPublicUser()) { + limitAllowed = publicUserSessions.size() < tenantProfileConfiguration.getMaxWsSessionsPerPublicUser(); + if (limitAllowed) { publicUserSessions.add(sessionId); - } else { + } + } + if (!limitAllowed) { log.info("[{}][{}][{}] Failed to start session. Max public user sessions limit reached" , sessionRef.getSecurityCtx().getTenantId(), sessionRef.getSecurityCtx().getId(), sessionId); session.close(CloseStatus.POLICY_VIOLATION.withReason("Max public user sessions limit reached")); return false; - } } } } diff --git a/application/src/test/java/org/thingsboard/server/controller/plugin/TbWebSocketHandlerTest.java b/application/src/test/java/org/thingsboard/server/controller/plugin/TbWebSocketHandlerTest.java new file mode 100644 index 0000000000..0394e8a505 --- /dev/null +++ b/application/src/test/java/org/thingsboard/server/controller/plugin/TbWebSocketHandlerTest.java @@ -0,0 +1,160 @@ +/** + * Copyright © 2016-2023 The Thingsboard Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.thingsboard.server.controller.plugin; + +import lombok.extern.slf4j.Slf4j; +import org.awaitility.Awaitility; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.adapter.NativeWebSocketSession; +import org.thingsboard.common.util.ThingsBoardThreadFactory; +import org.thingsboard.server.service.ws.WebSocketSessionRef; + +import javax.websocket.RemoteEndpoint; +import javax.websocket.SendHandler; +import javax.websocket.SendResult; +import javax.websocket.Session; +import java.io.IOException; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.BDDMockito.willAnswer; +import static org.mockito.BDDMockito.willDoNothing; +import static org.mockito.BDDMockito.willReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +@Slf4j +class TbWebSocketHandlerTest { + + TbWebSocketHandler wsHandler; + NativeWebSocketSession session; + Session nativeSession; + RemoteEndpoint.Async asyncRemote; + WebSocketSessionRef sessionRef; + int maxMsgQueuePerSession; + TbWebSocketHandler.SessionMetaData sendHandler; + ExecutorService executor; + + @BeforeEach + void setUp() throws IOException { + maxMsgQueuePerSession = 100; + executor = Executors.newCachedThreadPool(ThingsBoardThreadFactory.forName(getClass().getSimpleName())); + wsHandler = spy(new TbWebSocketHandler()); + willDoNothing().given(wsHandler).close(any(), any()); + session = mock(NativeWebSocketSession.class); + nativeSession = mock(Session.class); + willReturn(nativeSession).given(session).getNativeSession(Session.class); + asyncRemote = mock(RemoteEndpoint.Async.class); + willReturn(asyncRemote).given(nativeSession).getAsyncRemote(); + sessionRef = mock(WebSocketSessionRef.class, Mockito.RETURNS_DEEP_STUBS); //prevent NPE on logs + sendHandler = spy(wsHandler.new SessionMetaData(session, sessionRef, maxMsgQueuePerSession)); + } + + @AfterEach + void tearDown() { + if (executor != null) { + executor.shutdownNow(); + } + } + + @Test + void sendHandler_sendMsg_parallel_no_race() throws InterruptedException { + CountDownLatch finishLatch = new CountDownLatch(maxMsgQueuePerSession * 2); + AtomicInteger sendersCount = new AtomicInteger(); + willAnswer(invocation -> { + assertThat(sendersCount.incrementAndGet()).as("no race").isEqualTo(1); + String text = invocation.getArgument(0); + SendHandler onResultHandler = invocation.getArgument(1); + SendResult sendResult = new SendResult(); + executor.submit(() -> { + sendersCount.decrementAndGet(); + onResultHandler.onResult(sendResult); + finishLatch.countDown(); + }); + return null; + }).given(asyncRemote).sendText(anyString(), any()); + + assertThat(sendHandler.isSending.get()).as("sendHandler not is in sending state").isFalse(); + //first batch + IntStream.range(0, maxMsgQueuePerSession).parallel().forEach(i -> sendHandler.sendMsg("hello " + i)); + Awaitility.await("first batch processed").atMost(30, TimeUnit.SECONDS).until(() -> finishLatch.getCount() == maxMsgQueuePerSession); + assertThat(sendHandler.isSending.get()).as("sendHandler not is in sending state").isFalse(); + //second batch - to test pause between big msg batches + IntStream.range(100, 100 + maxMsgQueuePerSession).parallel().forEach(i -> sendHandler.sendMsg("hello " + i)); + assertThat(finishLatch.await(30, TimeUnit.SECONDS)).as("all callbacks fired").isTrue(); + + verify(sendHandler, never()).closeSession(any()); + verify(sendHandler, times(maxMsgQueuePerSession * 2)).onResult(any()); + assertThat(sendHandler.isSending.get()).as("sendHandler not is in sending state").isFalse(); + } + + @Test + void sendHandler_sendMsg_message_order() throws InterruptedException { + CountDownLatch finishLatch = new CountDownLatch(maxMsgQueuePerSession); + Collection outputs = new ConcurrentLinkedQueue<>(); + willAnswer(invocation -> { + String text = invocation.getArgument(0); + outputs.add(text); + SendHandler onResultHandler = invocation.getArgument(1); + SendResult sendResult = new SendResult(); + executor.submit(() -> { + onResultHandler.onResult(sendResult); + finishLatch.countDown(); + }); + return null; + }).given(asyncRemote).sendText(anyString(), any()); + + List inputs = IntStream.range(0, maxMsgQueuePerSession).mapToObj(i -> "msg " + i).collect(Collectors.toList()); + inputs.forEach(s -> sendHandler.sendMsg(s)); + + assertThat(finishLatch.await(30, TimeUnit.SECONDS)).as("all callbacks fired").isTrue(); + assertThat(outputs).as("inputs exactly the same as outputs").containsExactlyElementsOf(inputs); + + verify(sendHandler, never()).closeSession(any()); + verify(sendHandler, times(maxMsgQueuePerSession)).onResult(any()); + } + + @Test + void sendHandler_sendMsg_queue_size_exceed() { + willDoNothing().given(asyncRemote).sendText(anyString(), any()); // send text will never call back, so queue will grow each sendMsg + sendHandler.sendMsg("first message to stay in-flight all the time during this test"); + IntStream.range(0, maxMsgQueuePerSession).parallel().forEach(i -> sendHandler.sendMsg("hello " + i)); + verify(sendHandler, never()).closeSession(any()); + sendHandler.sendMsg("excessive message"); + verify(sendHandler, times(1)).closeSession(eq(new CloseStatus(1008, "Max pending updates limit reached!"))); + verify(asyncRemote, times(1)).sendText(anyString(), any()); + } + +}