diff --git a/application/src/main/java/org/thingsboard/server/actors/ruleChain/DefaultTbContext.java b/application/src/main/java/org/thingsboard/server/actors/ruleChain/DefaultTbContext.java index 38d6f4bccb..609c087a51 100644 --- a/application/src/main/java/org/thingsboard/server/actors/ruleChain/DefaultTbContext.java +++ b/application/src/main/java/org/thingsboard/server/actors/ruleChain/DefaultTbContext.java @@ -167,6 +167,11 @@ class DefaultTbContext implements TbContext { } private void enqueue(TopicPartitionInfo tpi, TbMsg tbMsg, Consumer onFailure, Runnable onSuccess) { + if (!tbMsg.isValid()) { + log.trace("[{}] Skip invalid message: {}", getTenantId(), tbMsg); + onFailure.accept(new IllegalArgumentException("Source message is no longer valid!")); + return; + } TransportProtos.ToRuleEngineMsg msg = TransportProtos.ToRuleEngineMsg.newBuilder() .setTenantIdMSB(getTenantId().getId().getMostSignificantBits()) .setTenantIdLSB(getTenantId().getId().getLeastSignificantBits()) @@ -235,6 +240,11 @@ class DefaultTbContext implements TbContext { } private void enqueueForTellNext(TopicPartitionInfo tpi, String queueName, TbMsg source, Set relationTypes, String failureMessage, Runnable onSuccess, Consumer onFailure) { + if (!source.isValid()) { + log.trace("[{}] Skip invalid message: {}", getTenantId(), source); + onFailure.accept(new IllegalArgumentException("Source message is no longer valid!")); + return; + } RuleChainId ruleChainId = nodeCtx.getSelf().getRuleChainId(); RuleNodeId ruleNodeId = nodeCtx.getSelf().getId(); TbMsg tbMsg = TbMsg.newMsg(source, queueName, ruleChainId, ruleNodeId); diff --git a/application/src/main/java/org/thingsboard/server/actors/ruleChain/RuleChainActorMessageProcessor.java b/application/src/main/java/org/thingsboard/server/actors/ruleChain/RuleChainActorMessageProcessor.java index 88a4f6a202..6f6ea8fcec 100644 --- a/application/src/main/java/org/thingsboard/server/actors/ruleChain/RuleChainActorMessageProcessor.java +++ b/application/src/main/java/org/thingsboard/server/actors/ruleChain/RuleChainActorMessageProcessor.java @@ -200,17 +200,20 @@ public class RuleChainActorMessageProcessor extends ComponentMsgProcessor relationTypes, String failureMessage) { try { - checkActive(msg); + checkComponentStateActive(msg); EntityId entityId = msg.getOriginator(); TopicPartitionInfo tpi = systemContext.resolve(ServiceType.TB_RULE_ENGINE, msg.getQueueName(), tenantId, entityId); diff --git a/application/src/main/java/org/thingsboard/server/actors/ruleChain/RuleNodeActor.java b/application/src/main/java/org/thingsboard/server/actors/ruleChain/RuleNodeActor.java index 8fef00c49c..60f462f4d4 100644 --- a/application/src/main/java/org/thingsboard/server/actors/ruleChain/RuleNodeActor.java +++ b/application/src/main/java/org/thingsboard/server/actors/ruleChain/RuleNodeActor.java @@ -27,6 +27,7 @@ import org.thingsboard.server.common.data.id.RuleChainId; import org.thingsboard.server.common.data.id.RuleNodeId; import org.thingsboard.server.common.data.id.TenantId; import org.thingsboard.server.common.msg.TbActorMsg; +import org.thingsboard.server.common.msg.TbMsg; import org.thingsboard.server.common.msg.plugin.ComponentLifecycleMsg; import org.thingsboard.server.common.msg.queue.PartitionChangeMsg; @@ -86,12 +87,19 @@ public class RuleNodeActor extends ComponentActor extends Abstract schedulePeriodicMsgWithDelay(context, new StatsPersistTick(), statsPersistFrequency, statsPersistFrequency); } - protected void checkActive(TbMsg tbMsg) throws RuleNodeException { + protected boolean checkMsgValid(TbMsg tbMsg) { + var valid = tbMsg.isValid(); + if (!valid) { + if (log.isTraceEnabled()) { + log.trace("Skip processing of message: {} because it is no longer valid!", tbMsg); + } + } + return valid; + } + + protected void checkComponentStateActive(TbMsg tbMsg) throws RuleNodeException { if (state != ComponentLifecycleState.ACTIVE) { log.debug("Component is not active. Current state [{}] for processor [{}][{}] tenant [{}]", state, entityId.getEntityType(), entityId, tenantId); RuleNodeException ruleNodeException = getInactiveException(); diff --git a/application/src/main/java/org/thingsboard/server/service/queue/TbMsgPackCallback.java b/application/src/main/java/org/thingsboard/server/service/queue/TbMsgPackCallback.java index eff1ecaf86..e3ce0927ef 100644 --- a/application/src/main/java/org/thingsboard/server/service/queue/TbMsgPackCallback.java +++ b/application/src/main/java/org/thingsboard/server/service/queue/TbMsgPackCallback.java @@ -66,6 +66,11 @@ public class TbMsgPackCallback implements TbMsgCallback { ctx.onFailure(tenantId, id, e); } + @Override + public boolean isMsgValid() { + return !ctx.isComplete(); + } + @Override public void onProcessingStart(RuleNodeInfo ruleNodeInfo) { log.trace("[{}] ON PROCESSING START: {}", id, ruleNodeInfo); diff --git a/application/src/main/java/org/thingsboard/server/service/queue/TbMsgPackProcessingContext.java b/application/src/main/java/org/thingsboard/server/service/queue/TbMsgPackProcessingContext.java index d7a064c4ca..62285a411c 100644 --- a/application/src/main/java/org/thingsboard/server/service/queue/TbMsgPackProcessingContext.java +++ b/application/src/main/java/org/thingsboard/server/service/queue/TbMsgPackProcessingContext.java @@ -53,6 +53,8 @@ public class TbMsgPackProcessingContext { private final ConcurrentMap exceptionsMap = new ConcurrentHashMap<>(); private final ConcurrentMap lastRuleNodeMap = new ConcurrentHashMap<>(); + @Getter + private volatile boolean complete = false; public TbMsgPackProcessingContext(String queueName, TbRuleEngineSubmitStrategy submitStrategy) { this.queueName = queueName; @@ -149,6 +151,7 @@ public class TbMsgPackProcessingContext { } public void cleanup() { + complete = true; pendingMap.clear(); successMap.clear(); failedMap.clear(); diff --git a/application/src/main/resources/thingsboard.yml b/application/src/main/resources/thingsboard.yml index 6865aebf9a..5fc13c6952 100644 --- a/application/src/main/resources/thingsboard.yml +++ b/application/src/main/resources/thingsboard.yml @@ -619,6 +619,13 @@ transport: log: enabled: "${TB_TRANSPORT_LOG_ENABLED:true}" max_length: "${TB_TRANSPORT_LOG_MAX_LENGTH:1024}" + rate_limits: + # Enable or disable generic rate limits. Device and Tenant specific rate limits are controlled in Tenant Profile. + ip_limits_enabled: "${TB_TRANSPORT_IP_RATE_LIMITS_ENABLED:false}" + # Maximum number of connect attempts with invalid credentials + max_wrong_credentials_per_ip: "${TB_TRANSPORT_MAX_WRONG_CREDENTIALS_PER_IP:10}" + # Timeout to expire block IP addresses + ip_block_timeout: "${TB_TRANSPORT_IP_BLOCK_TIMEOUT:60000}" # Local HTTP transport parameters http: enabled: "${HTTP_ENABLED:true}" @@ -630,6 +637,9 @@ transport: enabled: "${MQTT_ENABLED:true}" bind_address: "${MQTT_BIND_ADDRESS:0.0.0.0}" bind_port: "${MQTT_BIND_PORT:1883}" + # Enable proxy protocol support. Disabled by default. If enabled, supports both v1 and v2. + # Useful to get the real IP address of the client in the logs and for rate limits. + proxy_enabled: "${MQTT_PROXY_PROTOCOL_ENABLED:false}" timeout: "${MQTT_TIMEOUT:10000}" msg_queue_size_per_device_limit: "${MQTT_MSG_QUEUE_SIZE_PER_DEVICE_LIMIT:100}" # messages await in the queue before device connected state. This limit works on low level before TenantProfileLimits mechanism netty: diff --git a/application/src/test/java/org/thingsboard/server/rules/flow/AbstractRuleEngineFlowIntegrationTest.java b/application/src/test/java/org/thingsboard/server/rules/flow/AbstractRuleEngineFlowIntegrationTest.java index 90a0404751..eb79bc9d35 100644 --- a/application/src/test/java/org/thingsboard/server/rules/flow/AbstractRuleEngineFlowIntegrationTest.java +++ b/application/src/test/java/org/thingsboard/server/rules/flow/AbstractRuleEngineFlowIntegrationTest.java @@ -144,6 +144,7 @@ public abstract class AbstractRuleEngineFlowIntegrationTest extends AbstractRule Thread.sleep(1000); TbMsgCallback tbMsgCallback = Mockito.mock(TbMsgCallback.class); + Mockito.when(tbMsgCallback.isMsgValid()).thenReturn(true); TbMsg tbMsg = TbMsg.newMsg("CUSTOM", device.getId(), new TbMsgMetaData(), "{}", tbMsgCallback); QueueToRuleEngineMsg qMsg = new QueueToRuleEngineMsg(savedTenant.getId(), tbMsg, null, null); // Pushing Message to the system @@ -256,6 +257,7 @@ public abstract class AbstractRuleEngineFlowIntegrationTest extends AbstractRule Thread.sleep(1000); TbMsgCallback tbMsgCallback = Mockito.mock(TbMsgCallback.class); + Mockito.when(tbMsgCallback.isMsgValid()).thenReturn(true); TbMsg tbMsg = TbMsg.newMsg("CUSTOM", device.getId(), new TbMsgMetaData(), "{}", tbMsgCallback); QueueToRuleEngineMsg qMsg = new QueueToRuleEngineMsg(savedTenant.getId(), tbMsg, null, null); // Pushing Message to the system diff --git a/application/src/test/java/org/thingsboard/server/rules/lifecycle/AbstractRuleEngineLifecycleIntegrationTest.java b/application/src/test/java/org/thingsboard/server/rules/lifecycle/AbstractRuleEngineLifecycleIntegrationTest.java index ccdf788634..4a92321e91 100644 --- a/application/src/test/java/org/thingsboard/server/rules/lifecycle/AbstractRuleEngineLifecycleIntegrationTest.java +++ b/application/src/test/java/org/thingsboard/server/rules/lifecycle/AbstractRuleEngineLifecycleIntegrationTest.java @@ -140,6 +140,7 @@ public abstract class AbstractRuleEngineLifecycleIntegrationTest extends Abstrac Thread.sleep(1000); TbMsgCallback tbMsgCallback = Mockito.mock(TbMsgCallback.class); + Mockito.when(tbMsgCallback.isMsgValid()).thenReturn(true); TbMsg tbMsg = TbMsg.newMsg("CUSTOM", device.getId(), new TbMsgMetaData(), "{}", tbMsgCallback); QueueToRuleEngineMsg qMsg = new QueueToRuleEngineMsg(savedTenant.getId(), tbMsg, null, null); // Pushing Message to the system diff --git a/common/message/src/main/java/org/thingsboard/server/common/msg/TbMsg.java b/common/message/src/main/java/org/thingsboard/server/common/msg/TbMsg.java index 63c148b951..7c77d8a06b 100644 --- a/common/message/src/main/java/org/thingsboard/server/common/msg/TbMsg.java +++ b/common/message/src/main/java/org/thingsboard/server/common/msg/TbMsg.java @@ -269,7 +269,7 @@ public final class TbMsg implements Serializable { } public TbMsgCallback getCallback() { - //May be null in case of deserialization; + // May be null in case of deserialization; if (callback != null) { return callback; } else { @@ -288,4 +288,12 @@ public final class TbMsg implements Serializable { public TbMsgProcessingStackItem popFormStack() { return ctx.pop(); } + + /** + * Checks if the message is still valid for processing. May be invalid if the message pack is timed-out or canceled. + * @return 'true' if message is valid for processing, 'false' otherwise. + */ + public boolean isValid() { + return getCallback().isMsgValid(); + } } diff --git a/common/message/src/main/java/org/thingsboard/server/common/msg/queue/TbMsgCallback.java b/common/message/src/main/java/org/thingsboard/server/common/msg/queue/TbMsgCallback.java index fd62df83c9..2bab1e8340 100644 --- a/common/message/src/main/java/org/thingsboard/server/common/msg/queue/TbMsgCallback.java +++ b/common/message/src/main/java/org/thingsboard/server/common/msg/queue/TbMsgCallback.java @@ -17,6 +17,9 @@ package org.thingsboard.server.common.msg.queue; import org.thingsboard.server.common.data.id.RuleNodeId; +/** + * Should be renamed to TbMsgPackContext, but this can't be changed due to backward-compatibility. + */ public interface TbMsgCallback { TbMsgCallback EMPTY = new TbMsgCallback() { @@ -36,11 +39,20 @@ public interface TbMsgCallback { void onFailure(RuleEngineException e); + /** + * Returns 'true' if rule engine is expecting the message to be processed, 'false' otherwise. + * message may no longer be valid, if the message pack is already expired/canceled/failed. + * + * @return 'true' if rule engine is expecting the message to be processed, 'false' otherwise. + */ + default boolean isMsgValid() { + return true; + } + default void onProcessingStart(RuleNodeInfo ruleNodeInfo) { } default void onProcessingEnd(RuleNodeId ruleNodeId) { } - } diff --git a/common/transport/mqtt/src/main/java/org/thingsboard/server/transport/mqtt/MqttTransportContext.java b/common/transport/mqtt/src/main/java/org/thingsboard/server/transport/mqtt/MqttTransportContext.java index ef3db6cf84..efb3fd8606 100644 --- a/common/transport/mqtt/src/main/java/org/thingsboard/server/transport/mqtt/MqttTransportContext.java +++ b/common/transport/mqtt/src/main/java/org/thingsboard/server/transport/mqtt/MqttTransportContext.java @@ -23,14 +23,12 @@ import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.autoconfigure.condition.ConditionalOnExpression; import org.springframework.stereotype.Component; -import org.thingsboard.common.util.ThingsBoardExecutors; import org.thingsboard.server.common.transport.TransportContext; import org.thingsboard.server.transport.mqtt.adaptors.JsonMqttAdaptor; import org.thingsboard.server.transport.mqtt.adaptors.ProtoMqttAdaptor; import javax.annotation.PostConstruct; -import javax.annotation.PreDestroy; -import java.util.concurrent.ExecutorService; +import java.net.InetSocketAddress; import java.util.concurrent.atomic.AtomicInteger; /** @@ -73,6 +71,10 @@ public class MqttTransportContext extends TransportContext { @Value("${transport.mqtt.timeout:10000}") private long timeout; + @Getter + @Value("${transport.mqtt.proxy_enabled:false}") + private boolean proxyEnabled; + private final AtomicInteger connectionsCounter = new AtomicInteger(); @PostConstruct @@ -88,4 +90,17 @@ public class MqttTransportContext extends TransportContext { public void channelUnregistered() { connectionsCounter.decrementAndGet(); } + + public boolean checkAddress(InetSocketAddress address) { + return rateLimitService.checkAddress(address); + } + + public void onAuthSuccess(InetSocketAddress address) { + rateLimitService.onAuthSuccess(address); + } + + public void onAuthFailure(InetSocketAddress address) { + rateLimitService.onAuthFailure(address); + } + } diff --git a/common/transport/mqtt/src/main/java/org/thingsboard/server/transport/mqtt/MqttTransportHandler.java b/common/transport/mqtt/src/main/java/org/thingsboard/server/transport/mqtt/MqttTransportHandler.java index bf474cc266..50e0f1b190 100644 --- a/common/transport/mqtt/src/main/java/org/thingsboard/server/transport/mqtt/MqttTransportHandler.java +++ b/common/transport/mqtt/src/main/java/org/thingsboard/server/transport/mqtt/MqttTransportHandler.java @@ -164,6 +164,9 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement @Override public void channelRead(ChannelHandlerContext ctx, Object msg) { log.trace("[{}] Processing msg: {}", sessionId, msg); + if (address == null) { + address = getAddress(ctx); + } try { if (msg instanceof MqttMessage) { MqttMessage message = (MqttMessage) msg; @@ -182,8 +185,11 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement } } + InetSocketAddress getAddress(ChannelHandlerContext ctx) { + return ctx.channel().attr(MqttTransportService.ADDRESS).get(); + } + void processMqttMsg(ChannelHandlerContext ctx, MqttMessage msg) { - address = getAddress(ctx); if (msg.fixedHeader() == null) { log.info("[{}:{}] Invalid message received", address.getHostName(), address.getPort()); ctx.close(); @@ -199,10 +205,6 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement } } - InetSocketAddress getAddress(ChannelHandlerContext ctx) { - return (InetSocketAddress) ctx.channel().remoteAddress(); - } - private void processProvisionSessionMsg(ChannelHandlerContext ctx, MqttMessage msg) { switch (msg.fixedHeader().messageType()) { case PUBLISH: @@ -771,7 +773,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement private void processAuthTokenConnect(ChannelHandlerContext ctx, MqttConnectMessage connectMessage) { String userName = connectMessage.payload().userName(); - log.debug("[{}] Processing connect msg for client with user name: {}!", sessionId, userName); + log.debug("[{}][{}] Processing connect msg for client with user name: {}!", address, sessionId, userName); TransportProtos.ValidateBasicMqttCredRequestMsg.Builder request = TransportProtos.ValidateBasicMqttCredRequestMsg.newBuilder() .setClientId(connectMessage.payload().clientIdentifier()); if (userName != null) { @@ -820,6 +822,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement } }); } catch (Exception e) { + context.onAuthFailure(address); ctx.writeAndFlush(createMqttConnAckMsg(CONNECTION_REFUSED_NOT_AUTHORIZED, connectMessage)); log.trace("[{}] X509 auth failure: {}", sessionId, address, e); ctx.close(); @@ -931,9 +934,11 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement private void onValidateDeviceResponse(ValidateDeviceCredentialsResponse msg, ChannelHandlerContext ctx, MqttConnectMessage connectMessage) { if (!msg.hasDeviceInfo()) { + context.onAuthFailure(address); ctx.writeAndFlush(createMqttConnAckMsg(CONNECTION_REFUSED_NOT_AUTHORIZED, connectMessage)); ctx.close(); } else { + context.onAuthSuccess(address); deviceSessionCtx.setDeviceInfo(msg.getDeviceInfo()); deviceSessionCtx.setDeviceProfile(msg.getDeviceProfile()); deviceSessionCtx.setSessionInfo(SessionInfoCreator.create(msg, context, sessionId)); diff --git a/common/transport/mqtt/src/main/java/org/thingsboard/server/transport/mqtt/MqttTransportServerInitializer.java b/common/transport/mqtt/src/main/java/org/thingsboard/server/transport/mqtt/MqttTransportServerInitializer.java index edc26cb4a9..07d87e9ce1 100644 --- a/common/transport/mqtt/src/main/java/org/thingsboard/server/transport/mqtt/MqttTransportServerInitializer.java +++ b/common/transport/mqtt/src/main/java/org/thingsboard/server/transport/mqtt/MqttTransportServerInitializer.java @@ -18,9 +18,12 @@ package org.thingsboard.server.transport.mqtt; import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelPipeline; import io.netty.channel.socket.SocketChannel; +import io.netty.handler.codec.haproxy.HAProxyMessageDecoder; import io.netty.handler.codec.mqtt.MqttDecoder; import io.netty.handler.codec.mqtt.MqttEncoder; import io.netty.handler.ssl.SslHandler; +import org.thingsboard.server.transport.mqtt.limits.IpFilter; +import org.thingsboard.server.transport.mqtt.limits.ProxyIpFilter; /** * @author Andrew Shvayka @@ -39,6 +42,12 @@ public class MqttTransportServerInitializer extends ChannelInitializer ADDRESS = AttributeKey.newInstance("SRC_ADDRESS"); + @Value("${transport.mqtt.bind_address}") private String host; @Value("${transport.mqtt.bind_port}") diff --git a/common/transport/mqtt/src/main/java/org/thingsboard/server/transport/mqtt/limits/IpFilter.java b/common/transport/mqtt/src/main/java/org/thingsboard/server/transport/mqtt/limits/IpFilter.java new file mode 100644 index 0000000000..de8b8dacae --- /dev/null +++ b/common/transport/mqtt/src/main/java/org/thingsboard/server/transport/mqtt/limits/IpFilter.java @@ -0,0 +1,46 @@ +/** + * Copyright © 2016-2021 The Thingsboard Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.thingsboard.server.transport.mqtt.limits; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.haproxy.HAProxyMessage; +import io.netty.handler.ipfilter.AbstractRemoteAddressFilter; +import lombok.extern.slf4j.Slf4j; +import org.thingsboard.server.transport.mqtt.MqttTransportContext; +import org.thingsboard.server.transport.mqtt.MqttTransportService; + +import java.net.InetSocketAddress; + +@Slf4j +public class IpFilter extends AbstractRemoteAddressFilter { + + private MqttTransportContext context; + + public IpFilter(MqttTransportContext context) { + this.context = context; + } + + @Override + protected boolean accept(ChannelHandlerContext ctx, InetSocketAddress remoteAddress) throws Exception { + if(context.checkAddress(remoteAddress)){ + ctx.channel().attr(MqttTransportService.ADDRESS).set(remoteAddress); + return true; + } else { + return false; + } + } +} diff --git a/common/transport/mqtt/src/main/java/org/thingsboard/server/transport/mqtt/limits/ProxyIpFilter.java b/common/transport/mqtt/src/main/java/org/thingsboard/server/transport/mqtt/limits/ProxyIpFilter.java new file mode 100644 index 0000000000..7cd42e36da --- /dev/null +++ b/common/transport/mqtt/src/main/java/org/thingsboard/server/transport/mqtt/limits/ProxyIpFilter.java @@ -0,0 +1,60 @@ +/** + * Copyright © 2016-2021 The Thingsboard Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.thingsboard.server.transport.mqtt.limits; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.haproxy.HAProxyMessage; +import io.netty.util.AttributeKey; +import lombok.extern.slf4j.Slf4j; +import org.thingsboard.server.transport.mqtt.MqttTransportContext; +import org.thingsboard.server.transport.mqtt.MqttTransportService; + +import java.net.InetAddress; +import java.net.InetSocketAddress; + +@Slf4j +public class ProxyIpFilter extends ChannelInboundHandlerAdapter { + + + private MqttTransportContext context; + + public ProxyIpFilter(MqttTransportContext context) { + this.context = context; + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if(msg instanceof HAProxyMessage){ + HAProxyMessage proxyMsg = (HAProxyMessage) msg; + if(proxyMsg.sourceAddress() != null && proxyMsg.sourcePort() > 0) { + InetSocketAddress address = new InetSocketAddress(proxyMsg.sourceAddress(), proxyMsg.sourcePort()); + if(!context.checkAddress(address)){ + ctx.close(); + } else { + ctx.channel().attr(MqttTransportService.ADDRESS).set(address); + // We no longer need this channel in the pipeline. Similar to HAProxyMessageDecoder + ctx.pipeline().remove(this); + } + } else { + log.debug("Received local health-check connection message: {}", proxyMsg); + ctx.close(); + } + } else { + super.channelRead(ctx, msg); + } + } +} diff --git a/common/transport/mqtt/src/test/java/org/thingsboard/server/transport/mqtt/MqttTransportHandlerTest.java b/common/transport/mqtt/src/test/java/org/thingsboard/server/transport/mqtt/MqttTransportHandlerTest.java index a9903fe867..cc26ff6fa8 100644 --- a/common/transport/mqtt/src/test/java/org/thingsboard/server/transport/mqtt/MqttTransportHandlerTest.java +++ b/common/transport/mqtt/src/test/java/org/thingsboard/server/transport/mqtt/MqttTransportHandlerTest.java @@ -116,7 +116,7 @@ public class MqttTransportHandlerTest { MqttConnectMessage msg = getMqttConnectMessage(); willDoNothing().given(handler).processConnect(ctx, msg); - handler.processMqttMsg(ctx, msg); + handler.channelRead(ctx, msg); assertThat(handler.address, is(IP_ADDR)); assertThat(handler.deviceSessionCtx.getChannel(), is(ctx)); @@ -152,7 +152,7 @@ public class MqttTransportHandlerTest { List messages = Stream.generate(this::getMqttPublishMessage).limit(MSG_QUEUE_LIMIT).collect(Collectors.toList()); - messages.forEach((msg) -> handler.processMqttMsg(ctx, msg)); + messages.forEach((msg) -> handler.channelRead(ctx, msg)); assertThat(handler.address, is(IP_ADDR)); assertThat(handler.deviceSessionCtx.getChannel(), is(ctx)); diff --git a/common/transport/transport-api/src/main/java/org/thingsboard/server/common/transport/TransportContext.java b/common/transport/transport-api/src/main/java/org/thingsboard/server/common/transport/TransportContext.java index ec21851b84..03760566c4 100644 --- a/common/transport/transport-api/src/main/java/org/thingsboard/server/common/transport/TransportContext.java +++ b/common/transport/transport-api/src/main/java/org/thingsboard/server/common/transport/TransportContext.java @@ -22,6 +22,7 @@ import lombok.extern.slf4j.Slf4j; import org.springframework.beans.factory.annotation.Autowired; import org.thingsboard.common.util.ThingsBoardExecutors; import org.thingsboard.server.cache.ota.OtaPackageDataCache; +import org.thingsboard.server.common.transport.limits.TransportRateLimitService; import org.thingsboard.server.queue.discovery.TbServiceInfoProvider; import org.thingsboard.server.queue.scheduler.SchedulerComponent; @@ -57,6 +58,9 @@ public abstract class TransportContext { @Autowired private TransportResourceCache transportResourceCache; + @Autowired + protected TransportRateLimitService rateLimitService; + @PostConstruct public void init() { executor = ThingsBoardExecutors.newWorkStealingPool(50, getClass()); @@ -73,4 +77,6 @@ public abstract class TransportContext { return serviceInfoProvider.getServiceId(); } + + } diff --git a/common/transport/transport-api/src/main/java/org/thingsboard/server/common/transport/limits/DefaultTransportRateLimitService.java b/common/transport/transport-api/src/main/java/org/thingsboard/server/common/transport/limits/DefaultTransportRateLimitService.java index 5ee4064466..1d7fe4162b 100644 --- a/common/transport/transport-api/src/main/java/org/thingsboard/server/common/transport/limits/DefaultTransportRateLimitService.java +++ b/common/transport/transport-api/src/main/java/org/thingsboard/server/common/transport/limits/DefaultTransportRateLimitService.java @@ -16,6 +16,7 @@ package org.thingsboard.server.common.transport.limits; import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Service; import org.springframework.util.StringUtils; import org.thingsboard.server.common.data.EntityType; @@ -25,12 +26,13 @@ import org.thingsboard.server.common.data.id.EntityId; import org.thingsboard.server.common.data.id.TenantId; import org.thingsboard.server.common.data.tenant.profile.DefaultTenantProfileConfiguration; import org.thingsboard.server.common.data.tenant.profile.TenantProfileData; -import org.thingsboard.server.common.msg.tools.TbRateLimits; import org.thingsboard.server.common.transport.TransportTenantProfileCache; import org.thingsboard.server.common.transport.profile.TenantProfileUpdateResult; import org.thingsboard.server.queue.util.TbTransportComponent; -import java.util.HashSet; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.Map; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; @@ -47,9 +49,17 @@ public class DefaultTransportRateLimitService implements TransportRateLimitServi private final ConcurrentMap> tenantDevices = new ConcurrentHashMap<>(); private final ConcurrentMap perTenantLimits = new ConcurrentHashMap<>(); private final ConcurrentMap perDeviceLimits = new ConcurrentHashMap<>(); + private final Map ipMap = new ConcurrentHashMap<>(); private final TransportTenantProfileCache tenantProfileCache; + @Value("${transport.rate_limits.ip_limits_enabled:false}") + private boolean ipRateLimitsEnabled; + @Value("${transport.rate_limits.max_wrong_credentials_per_ip:10}") + private int maxWrongCredentialsPerIp; + @Value("${transport.rate_limits.ip_block_timeout:60000}") + private long ipBlockTimeout; + public DefaultTransportRateLimitService(TransportTenantProfileCache tenantProfileCache) { this.tenantProfileCache = tenantProfileCache; } @@ -116,6 +126,75 @@ public class DefaultTransportRateLimitService implements TransportRateLimitServi tenantAllowed.put(tenantId, allowed); } + @Override + public boolean checkAddress(InetSocketAddress address) { + if (!ipRateLimitsEnabled) { + return true; + } + var stats = ipMap.computeIfAbsent(address.getAddress(), a -> new InetAddressRateLimitStats()); + return !stats.isBlocked() || (stats.getLastActivityTs() + ipBlockTimeout < System.currentTimeMillis()); + } + + @Override + public void onAuthSuccess(InetSocketAddress address) { + if (!ipRateLimitsEnabled) { + return; + } + + var stats = ipMap.computeIfAbsent(address.getAddress(), a -> new InetAddressRateLimitStats()); + stats.getLock().lock(); + try { + stats.setLastActivityTs(System.currentTimeMillis()); + stats.setFailureCount(0); + if (stats.isBlocked()) { + stats.setBlocked(false); + log.info("[{}] IP address un-blocked due to correct credentials.", address.getAddress()); + } + } finally { + stats.getLock().unlock(); + } + } + + @Override + public void onAuthFailure(InetSocketAddress address) { + if (!ipRateLimitsEnabled) { + return; + } + + var stats = ipMap.computeIfAbsent(address.getAddress(), a -> new InetAddressRateLimitStats()); + stats.getLock().lock(); + try { + stats.setLastActivityTs(System.currentTimeMillis()); + int failureCount = stats.getFailureCount() + 1; + stats.setFailureCount(failureCount); + if (failureCount >= maxWrongCredentialsPerIp) { + log.info("[{}] IP address blocked due to constantly wrong credentials.", address.getAddress()); + stats.setBlocked(true); + } + } finally { + stats.getLock().unlock(); + } + } + + @Override + public void invalidateRateLimitsIpTable(long sessionInactivityTimeout) { + if (!ipRateLimitsEnabled) { + return; + } + long currentTime = System.currentTimeMillis(); + long expTime = currentTime - Math.max(sessionInactivityTimeout, ipBlockTimeout); + for (var entry : ipMap.entrySet()) { + var stats = entry.getValue(); + if (stats.getLastActivityTs() < expTime) { + log.debug("[{}] IP address removed due to session inactivity timeout.", entry.getKey()); + ipMap.remove(entry.getKey()); + } else if (stats.isBlocked() && (stats.getLastActivityTs() + ipBlockTimeout < currentTime)) { + log.info("[{}] IP address unblocked due ip block timeout.", entry.getKey()); + stats.setBlocked(false); + } + } + } + private void mergeLimits(T entityId, EntityTransportRateLimits newRateLimits, Function getFunction, BiConsumer putFunction) { diff --git a/common/transport/transport-api/src/main/java/org/thingsboard/server/common/transport/limits/InetAddressRateLimitStats.java b/common/transport/transport-api/src/main/java/org/thingsboard/server/common/transport/limits/InetAddressRateLimitStats.java new file mode 100644 index 0000000000..ba857df418 --- /dev/null +++ b/common/transport/transport-api/src/main/java/org/thingsboard/server/common/transport/limits/InetAddressRateLimitStats.java @@ -0,0 +1,33 @@ +/** + * Copyright © 2016-2021 The Thingsboard Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.thingsboard.server.common.transport.limits; + +import lombok.Data; + +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; + +@Data +public class InetAddressRateLimitStats { + + private final Lock lock = new ReentrantLock(); + + private boolean blocked; + private long lastActivityTs; + private int failureCount; + private int connectionsCount; + +} diff --git a/common/transport/transport-api/src/main/java/org/thingsboard/server/common/transport/limits/TransportRateLimitService.java b/common/transport/transport-api/src/main/java/org/thingsboard/server/common/transport/limits/TransportRateLimitService.java index b7c6e27e0d..9a0880022f 100644 --- a/common/transport/transport-api/src/main/java/org/thingsboard/server/common/transport/limits/TransportRateLimitService.java +++ b/common/transport/transport-api/src/main/java/org/thingsboard/server/common/transport/limits/TransportRateLimitService.java @@ -20,6 +20,8 @@ import org.thingsboard.server.common.data.id.DeviceId; import org.thingsboard.server.common.data.id.TenantId; import org.thingsboard.server.common.transport.profile.TenantProfileUpdateResult; +import java.net.InetSocketAddress; + public interface TransportRateLimitService { EntityType checkLimits(TenantId tenantId, DeviceId deviceId, int dataPoints); @@ -33,4 +35,13 @@ public interface TransportRateLimitService { void remove(DeviceId deviceId); void update(TenantId tenantId, boolean transportEnabled); + + boolean checkAddress(InetSocketAddress address); + + void onAuthSuccess(InetSocketAddress address); + + void onAuthFailure(InetSocketAddress address); + + void invalidateRateLimitsIpTable(long sessionInactivityTimeout); + } diff --git a/common/transport/transport-api/src/main/java/org/thingsboard/server/common/transport/service/DefaultTransportService.java b/common/transport/transport-api/src/main/java/org/thingsboard/server/common/transport/service/DefaultTransportService.java index 1e9cccceb2..4dfe42885c 100644 --- a/common/transport/transport-api/src/main/java/org/thingsboard/server/common/transport/service/DefaultTransportService.java +++ b/common/transport/transport-api/src/main/java/org/thingsboard/server/common/transport/service/DefaultTransportService.java @@ -209,6 +209,7 @@ public class DefaultTransportService implements TransportService { this.transportApiStats = statsFactory.createMessagesStats(StatsType.TRANSPORT.getName() + ".producer"); this.transportCallbackExecutor = ThingsBoardExecutors.newWorkStealingPool(20, getClass()); this.scheduler.scheduleAtFixedRate(this::checkInactivityAndReportActivity, new Random().nextInt((int) sessionReportTimeout), sessionReportTimeout, TimeUnit.MILLISECONDS); + this.scheduler.scheduleAtFixedRate(this::invalidateRateLimits, new Random().nextInt((int) sessionReportTimeout), sessionReportTimeout, TimeUnit.MILLISECONDS); transportApiRequestTemplate = queueProvider.createTransportApiRequestTemplate(); transportApiRequestTemplate.setMessagesStats(transportApiStats); ruleEngineMsgProducer = producerProvider.getRuleEngineMsgProducer(); @@ -247,6 +248,10 @@ public class DefaultTransportService implements TransportService { }); } + private void invalidateRateLimits() { + rateLimitService.invalidateRateLimitsIpTable(sessionInactivityTimeout); + } + @PreDestroy public void destroy() { stopped = true; diff --git a/transport/mqtt/src/main/resources/tb-mqtt-transport.yml b/transport/mqtt/src/main/resources/tb-mqtt-transport.yml index 435b780818..e75825d474 100644 --- a/transport/mqtt/src/main/resources/tb-mqtt-transport.yml +++ b/transport/mqtt/src/main/resources/tb-mqtt-transport.yml @@ -88,6 +88,9 @@ transport: mqtt: bind_address: "${MQTT_BIND_ADDRESS:0.0.0.0}" bind_port: "${MQTT_BIND_PORT:1883}" + # Enable proxy protocol support. Disabled by default. If enabled, supports both v1 and v2. + # Useful to get the real IP address of the client in the logs and for rate limits. + proxy_enabled: "${MQTT_PROXY_PROTOCOL_ENABLED:false}" timeout: "${MQTT_TIMEOUT:10000}" msg_queue_size_per_device_limit: "${MQTT_MSG_QUEUE_SIZE_PER_DEVICE_LIMIT:100}" # messages await in the queue before device connected state. This limit works on low level before TenantProfileLimits mechanism netty: @@ -146,6 +149,15 @@ transport: stats: enabled: "${TB_TRANSPORT_STATS_ENABLED:true}" print-interval-ms: "${TB_TRANSPORT_STATS_PRINT_INTERVAL_MS:60000}" + client_side_rpc: + timeout: "${CLIENT_SIDE_RPC_TIMEOUT:60000}" + rate_limits: + # Enable or disable generic rate limits. Device and Tenant specific rate limits are controlled in Tenant Profile. + ip_limits_enabled: "${TB_TRANSPORT_IP_RATE_LIMITS_ENABLED:false}" + # Maximum number of connect attempts with invalid credentials + max_wrong_credentials_per_ip: "${TB_TRANSPORT_MAX_WRONG_CREDENTIALS_PER_IP:10}" + # Timeout to expire block IP addresses + ip_block_timeout: "${TB_TRANSPORT_IP_BLOCK_TIMEOUT:60000}" queue: