Browse Source

Merge pull request #5875 from thingsboard/feature/mqtt-rate-limits

[3.3.3] IP Rate Limits for MQTT
pull/5891/head
Andrew Shvayka 4 years ago
committed by GitHub
parent
commit
8d12abb846
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
  1. 10
      application/src/main/java/org/thingsboard/server/actors/ruleChain/DefaultTbContext.java
  2. 28
      application/src/main/java/org/thingsboard/server/actors/ruleChain/RuleChainActorMessageProcessor.java
  3. 14
      application/src/main/java/org/thingsboard/server/actors/ruleChain/RuleNodeActor.java
  4. 4
      application/src/main/java/org/thingsboard/server/actors/ruleChain/RuleNodeActorMessageProcessor.java
  5. 12
      application/src/main/java/org/thingsboard/server/actors/shared/ComponentMsgProcessor.java
  6. 5
      application/src/main/java/org/thingsboard/server/service/queue/TbMsgPackCallback.java
  7. 3
      application/src/main/java/org/thingsboard/server/service/queue/TbMsgPackProcessingContext.java
  8. 10
      application/src/main/resources/thingsboard.yml
  9. 2
      application/src/test/java/org/thingsboard/server/rules/flow/AbstractRuleEngineFlowIntegrationTest.java
  10. 1
      application/src/test/java/org/thingsboard/server/rules/lifecycle/AbstractRuleEngineLifecycleIntegrationTest.java
  11. 10
      common/message/src/main/java/org/thingsboard/server/common/msg/TbMsg.java
  12. 14
      common/message/src/main/java/org/thingsboard/server/common/msg/queue/TbMsgCallback.java
  13. 21
      common/transport/mqtt/src/main/java/org/thingsboard/server/transport/mqtt/MqttTransportContext.java
  14. 17
      common/transport/mqtt/src/main/java/org/thingsboard/server/transport/mqtt/MqttTransportHandler.java
  15. 9
      common/transport/mqtt/src/main/java/org/thingsboard/server/transport/mqtt/MqttTransportServerInitializer.java
  16. 5
      common/transport/mqtt/src/main/java/org/thingsboard/server/transport/mqtt/MqttTransportService.java
  17. 46
      common/transport/mqtt/src/main/java/org/thingsboard/server/transport/mqtt/limits/IpFilter.java
  18. 60
      common/transport/mqtt/src/main/java/org/thingsboard/server/transport/mqtt/limits/ProxyIpFilter.java
  19. 4
      common/transport/mqtt/src/test/java/org/thingsboard/server/transport/mqtt/MqttTransportHandlerTest.java
  20. 6
      common/transport/transport-api/src/main/java/org/thingsboard/server/common/transport/TransportContext.java
  21. 83
      common/transport/transport-api/src/main/java/org/thingsboard/server/common/transport/limits/DefaultTransportRateLimitService.java
  22. 33
      common/transport/transport-api/src/main/java/org/thingsboard/server/common/transport/limits/InetAddressRateLimitStats.java
  23. 11
      common/transport/transport-api/src/main/java/org/thingsboard/server/common/transport/limits/TransportRateLimitService.java
  24. 5
      common/transport/transport-api/src/main/java/org/thingsboard/server/common/transport/service/DefaultTransportService.java
  25. 12
      transport/mqtt/src/main/resources/tb-mqtt-transport.yml

10
application/src/main/java/org/thingsboard/server/actors/ruleChain/DefaultTbContext.java

@ -167,6 +167,11 @@ class DefaultTbContext implements TbContext {
}
private void enqueue(TopicPartitionInfo tpi, TbMsg tbMsg, Consumer<Throwable> onFailure, Runnable onSuccess) {
if (!tbMsg.isValid()) {
log.trace("[{}] Skip invalid message: {}", getTenantId(), tbMsg);
onFailure.accept(new IllegalArgumentException("Source message is no longer valid!"));
return;
}
TransportProtos.ToRuleEngineMsg msg = TransportProtos.ToRuleEngineMsg.newBuilder()
.setTenantIdMSB(getTenantId().getId().getMostSignificantBits())
.setTenantIdLSB(getTenantId().getId().getLeastSignificantBits())
@ -235,6 +240,11 @@ class DefaultTbContext implements TbContext {
}
private void enqueueForTellNext(TopicPartitionInfo tpi, String queueName, TbMsg source, Set<String> relationTypes, String failureMessage, Runnable onSuccess, Consumer<Throwable> onFailure) {
if (!source.isValid()) {
log.trace("[{}] Skip invalid message: {}", getTenantId(), source);
onFailure.accept(new IllegalArgumentException("Source message is no longer valid!"));
return;
}
RuleChainId ruleChainId = nodeCtx.getSelf().getRuleChainId();
RuleNodeId ruleNodeId = nodeCtx.getSelf().getId();
TbMsg tbMsg = TbMsg.newMsg(source, queueName, ruleChainId, ruleNodeId);

28
application/src/main/java/org/thingsboard/server/actors/ruleChain/RuleChainActorMessageProcessor.java

@ -200,17 +200,20 @@ public class RuleChainActorMessageProcessor extends ComponentMsgProcessor<RuleCh
void onQueueToRuleEngineMsg(QueueToRuleEngineMsg envelope) {
TbMsg msg = envelope.getMsg();
if (!checkMsgValid(msg)) {
return;
}
log.trace("[{}][{}] Processing message [{}]: {}", entityId, firstId, msg.getId(), msg);
if (envelope.getRelationTypes() == null || envelope.getRelationTypes().isEmpty()) {
onTellNext(msg, true);
} else {
onTellNext(envelope.getMsg(), envelope.getMsg().getRuleNodeId(), envelope.getRelationTypes(), envelope.getFailureMessage());
onTellNext(msg, envelope.getMsg().getRuleNodeId(), envelope.getRelationTypes(), envelope.getFailureMessage());
}
}
private void onTellNext(TbMsg msg, boolean useRuleNodeIdFromMsg) {
try {
checkActive(msg);
checkComponentStateActive(msg);
RuleNodeId targetId = useRuleNodeIdFromMsg ? msg.getRuleNodeId() : null;
RuleNodeCtx targetCtx;
if (targetId == null) {
@ -234,6 +237,10 @@ public class RuleChainActorMessageProcessor extends ComponentMsgProcessor<RuleCh
}
public void onRuleChainInputMsg(RuleChainInputMsg envelope) {
var msg = envelope.getMsg();
if (!checkMsgValid(msg)) {
return;
}
if (entityId.equals(envelope.getRuleChainId())) {
onTellNext(envelope.getMsg(), false);
} else {
@ -242,6 +249,10 @@ public class RuleChainActorMessageProcessor extends ComponentMsgProcessor<RuleCh
}
public void onRuleChainOutputMsg(RuleChainOutputMsg envelope) {
var msg = envelope.getMsg();
if (!checkMsgValid(msg)) {
return;
}
if (entityId.equals(envelope.getRuleChainId())) {
var originatorNodeId = envelope.getTargetRuleNodeId();
RuleNodeCtx ruleNodeCtx = nodeActors.get(originatorNodeId);
@ -255,8 +266,12 @@ public class RuleChainActorMessageProcessor extends ComponentMsgProcessor<RuleCh
}
void onRuleChainToRuleChainMsg(RuleChainToRuleChainMsg envelope) {
var msg = envelope.getMsg();
if (!checkMsgValid(msg)) {
return;
}
try {
checkActive(envelope.getMsg());
checkComponentStateActive(envelope.getMsg());
if (firstNode != null) {
pushMsgToNode(firstNode, envelope.getMsg(), envelope.getFromRelationType());
} else {
@ -268,12 +283,15 @@ public class RuleChainActorMessageProcessor extends ComponentMsgProcessor<RuleCh
}
void onTellNext(RuleNodeToRuleChainTellNextMsg envelope) {
onTellNext(envelope.getMsg(), envelope.getOriginator(), envelope.getRelationTypes(), envelope.getFailureMessage());
var msg = envelope.getMsg();
if (checkMsgValid(msg)) {
onTellNext(msg, envelope.getOriginator(), envelope.getRelationTypes(), envelope.getFailureMessage());
}
}
private void onTellNext(TbMsg msg, RuleNodeId originatorNodeId, Set<String> relationTypes, String failureMessage) {
try {
checkActive(msg);
checkComponentStateActive(msg);
EntityId entityId = msg.getOriginator();
TopicPartitionInfo tpi = systemContext.resolve(ServiceType.TB_RULE_ENGINE, msg.getQueueName(), tenantId, entityId);

14
application/src/main/java/org/thingsboard/server/actors/ruleChain/RuleNodeActor.java

@ -27,6 +27,7 @@ import org.thingsboard.server.common.data.id.RuleChainId;
import org.thingsboard.server.common.data.id.RuleNodeId;
import org.thingsboard.server.common.data.id.TenantId;
import org.thingsboard.server.common.msg.TbActorMsg;
import org.thingsboard.server.common.msg.TbMsg;
import org.thingsboard.server.common.msg.plugin.ComponentLifecycleMsg;
import org.thingsboard.server.common.msg.queue.PartitionChangeMsg;
@ -86,12 +87,19 @@ public class RuleNodeActor extends ComponentActor<RuleNodeId, RuleNodeActorMessa
}
}
private void onRuleChainToRuleNodeMsg(RuleChainToRuleNodeMsg msg) {
private void onRuleChainToRuleNodeMsg(RuleChainToRuleNodeMsg envelope) {
TbMsg msg = envelope.getMsg();
if (!msg.isValid()) {
if (log.isTraceEnabled()) {
log.trace("Skip processing of message: {} because it is no longer valid!", msg);
}
return;
}
if (log.isDebugEnabled()) {
log.debug("[{}][{}][{}] Going to process rule msg: {}", ruleChainId, id, processor.getComponentName(), msg.getMsg());
log.debug("[{}][{}][{}] Going to process rule engine msg: {}", ruleChainId, id, processor.getComponentName(), msg);
}
try {
processor.onRuleChainToRuleNodeMsg(msg);
processor.onRuleChainToRuleNodeMsg(envelope);
increaseMessagesProcessedCount();
} catch (Exception e) {
logAndPersist("onRuleMsg", e);

4
application/src/main/java/org/thingsboard/server/actors/ruleChain/RuleNodeActorMessageProcessor.java

@ -101,7 +101,7 @@ public class RuleNodeActorMessageProcessor extends ComponentMsgProcessor<RuleNod
}
public void onRuleToSelfMsg(RuleNodeToSelfMsg msg) throws Exception {
checkActive(msg.getMsg());
checkComponentStateActive(msg.getMsg());
TbMsg tbMsg = msg.getMsg();
int ruleNodeCount = tbMsg.getAndIncrementRuleNodeCounter();
int maxRuleNodeExecutionsPerMessage = getTenantProfileConfiguration().getMaxRuleNodeExecsPerMessage();
@ -122,7 +122,7 @@ public class RuleNodeActorMessageProcessor extends ComponentMsgProcessor<RuleNod
void onRuleChainToRuleNodeMsg(RuleChainToRuleNodeMsg msg) throws Exception {
msg.getMsg().getCallback().onProcessingStart(info);
checkActive(msg.getMsg());
checkComponentStateActive(msg.getMsg());
TbMsg tbMsg = msg.getMsg();
int ruleNodeCount = tbMsg.getAndIncrementRuleNodeCounter();
int maxRuleNodeExecutionsPerMessage = getTenantProfileConfiguration().getMaxRuleNodeExecsPerMessage();

12
application/src/main/java/org/thingsboard/server/actors/shared/ComponentMsgProcessor.java

@ -82,7 +82,17 @@ public abstract class ComponentMsgProcessor<T extends EntityId> extends Abstract
schedulePeriodicMsgWithDelay(context, new StatsPersistTick(), statsPersistFrequency, statsPersistFrequency);
}
protected void checkActive(TbMsg tbMsg) throws RuleNodeException {
protected boolean checkMsgValid(TbMsg tbMsg) {
var valid = tbMsg.isValid();
if (!valid) {
if (log.isTraceEnabled()) {
log.trace("Skip processing of message: {} because it is no longer valid!", tbMsg);
}
}
return valid;
}
protected void checkComponentStateActive(TbMsg tbMsg) throws RuleNodeException {
if (state != ComponentLifecycleState.ACTIVE) {
log.debug("Component is not active. Current state [{}] for processor [{}][{}] tenant [{}]", state, entityId.getEntityType(), entityId, tenantId);
RuleNodeException ruleNodeException = getInactiveException();

5
application/src/main/java/org/thingsboard/server/service/queue/TbMsgPackCallback.java

@ -66,6 +66,11 @@ public class TbMsgPackCallback implements TbMsgCallback {
ctx.onFailure(tenantId, id, e);
}
@Override
public boolean isMsgValid() {
return !ctx.isComplete();
}
@Override
public void onProcessingStart(RuleNodeInfo ruleNodeInfo) {
log.trace("[{}] ON PROCESSING START: {}", id, ruleNodeInfo);

3
application/src/main/java/org/thingsboard/server/service/queue/TbMsgPackProcessingContext.java

@ -53,6 +53,8 @@ public class TbMsgPackProcessingContext {
private final ConcurrentMap<TenantId, RuleEngineException> exceptionsMap = new ConcurrentHashMap<>();
private final ConcurrentMap<UUID, RuleNodeInfo> lastRuleNodeMap = new ConcurrentHashMap<>();
@Getter
private volatile boolean complete = false;
public TbMsgPackProcessingContext(String queueName, TbRuleEngineSubmitStrategy submitStrategy) {
this.queueName = queueName;
@ -149,6 +151,7 @@ public class TbMsgPackProcessingContext {
}
public void cleanup() {
complete = true;
pendingMap.clear();
successMap.clear();
failedMap.clear();

10
application/src/main/resources/thingsboard.yml

@ -619,6 +619,13 @@ transport:
log:
enabled: "${TB_TRANSPORT_LOG_ENABLED:true}"
max_length: "${TB_TRANSPORT_LOG_MAX_LENGTH:1024}"
rate_limits:
# Enable or disable generic rate limits. Device and Tenant specific rate limits are controlled in Tenant Profile.
ip_limits_enabled: "${TB_TRANSPORT_IP_RATE_LIMITS_ENABLED:false}"
# Maximum number of connect attempts with invalid credentials
max_wrong_credentials_per_ip: "${TB_TRANSPORT_MAX_WRONG_CREDENTIALS_PER_IP:10}"
# Timeout to expire block IP addresses
ip_block_timeout: "${TB_TRANSPORT_IP_BLOCK_TIMEOUT:60000}"
# Local HTTP transport parameters
http:
enabled: "${HTTP_ENABLED:true}"
@ -630,6 +637,9 @@ transport:
enabled: "${MQTT_ENABLED:true}"
bind_address: "${MQTT_BIND_ADDRESS:0.0.0.0}"
bind_port: "${MQTT_BIND_PORT:1883}"
# Enable proxy protocol support. Disabled by default. If enabled, supports both v1 and v2.
# Useful to get the real IP address of the client in the logs and for rate limits.
proxy_enabled: "${MQTT_PROXY_PROTOCOL_ENABLED:false}"
timeout: "${MQTT_TIMEOUT:10000}"
msg_queue_size_per_device_limit: "${MQTT_MSG_QUEUE_SIZE_PER_DEVICE_LIMIT:100}" # messages await in the queue before device connected state. This limit works on low level before TenantProfileLimits mechanism
netty:

2
application/src/test/java/org/thingsboard/server/rules/flow/AbstractRuleEngineFlowIntegrationTest.java

@ -144,6 +144,7 @@ public abstract class AbstractRuleEngineFlowIntegrationTest extends AbstractRule
Thread.sleep(1000);
TbMsgCallback tbMsgCallback = Mockito.mock(TbMsgCallback.class);
Mockito.when(tbMsgCallback.isMsgValid()).thenReturn(true);
TbMsg tbMsg = TbMsg.newMsg("CUSTOM", device.getId(), new TbMsgMetaData(), "{}", tbMsgCallback);
QueueToRuleEngineMsg qMsg = new QueueToRuleEngineMsg(savedTenant.getId(), tbMsg, null, null);
// Pushing Message to the system
@ -256,6 +257,7 @@ public abstract class AbstractRuleEngineFlowIntegrationTest extends AbstractRule
Thread.sleep(1000);
TbMsgCallback tbMsgCallback = Mockito.mock(TbMsgCallback.class);
Mockito.when(tbMsgCallback.isMsgValid()).thenReturn(true);
TbMsg tbMsg = TbMsg.newMsg("CUSTOM", device.getId(), new TbMsgMetaData(), "{}", tbMsgCallback);
QueueToRuleEngineMsg qMsg = new QueueToRuleEngineMsg(savedTenant.getId(), tbMsg, null, null);
// Pushing Message to the system

1
application/src/test/java/org/thingsboard/server/rules/lifecycle/AbstractRuleEngineLifecycleIntegrationTest.java

@ -140,6 +140,7 @@ public abstract class AbstractRuleEngineLifecycleIntegrationTest extends Abstrac
Thread.sleep(1000);
TbMsgCallback tbMsgCallback = Mockito.mock(TbMsgCallback.class);
Mockito.when(tbMsgCallback.isMsgValid()).thenReturn(true);
TbMsg tbMsg = TbMsg.newMsg("CUSTOM", device.getId(), new TbMsgMetaData(), "{}", tbMsgCallback);
QueueToRuleEngineMsg qMsg = new QueueToRuleEngineMsg(savedTenant.getId(), tbMsg, null, null);
// Pushing Message to the system

10
common/message/src/main/java/org/thingsboard/server/common/msg/TbMsg.java

@ -269,7 +269,7 @@ public final class TbMsg implements Serializable {
}
public TbMsgCallback getCallback() {
//May be null in case of deserialization;
// May be null in case of deserialization;
if (callback != null) {
return callback;
} else {
@ -288,4 +288,12 @@ public final class TbMsg implements Serializable {
public TbMsgProcessingStackItem popFormStack() {
return ctx.pop();
}
/**
* Checks if the message is still valid for processing. May be invalid if the message pack is timed-out or canceled.
* @return 'true' if message is valid for processing, 'false' otherwise.
*/
public boolean isValid() {
return getCallback().isMsgValid();
}
}

14
common/message/src/main/java/org/thingsboard/server/common/msg/queue/TbMsgCallback.java

@ -17,6 +17,9 @@ package org.thingsboard.server.common.msg.queue;
import org.thingsboard.server.common.data.id.RuleNodeId;
/**
* Should be renamed to TbMsgPackContext, but this can't be changed due to backward-compatibility.
*/
public interface TbMsgCallback {
TbMsgCallback EMPTY = new TbMsgCallback() {
@ -36,11 +39,20 @@ public interface TbMsgCallback {
void onFailure(RuleEngineException e);
/**
* Returns 'true' if rule engine is expecting the message to be processed, 'false' otherwise.
* message may no longer be valid, if the message pack is already expired/canceled/failed.
*
* @return 'true' if rule engine is expecting the message to be processed, 'false' otherwise.
*/
default boolean isMsgValid() {
return true;
}
default void onProcessingStart(RuleNodeInfo ruleNodeInfo) {
}
default void onProcessingEnd(RuleNodeId ruleNodeId) {
}
}

21
common/transport/mqtt/src/main/java/org/thingsboard/server/transport/mqtt/MqttTransportContext.java

@ -23,14 +23,12 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.autoconfigure.condition.ConditionalOnExpression;
import org.springframework.stereotype.Component;
import org.thingsboard.common.util.ThingsBoardExecutors;
import org.thingsboard.server.common.transport.TransportContext;
import org.thingsboard.server.transport.mqtt.adaptors.JsonMqttAdaptor;
import org.thingsboard.server.transport.mqtt.adaptors.ProtoMqttAdaptor;
import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import java.util.concurrent.ExecutorService;
import java.net.InetSocketAddress;
import java.util.concurrent.atomic.AtomicInteger;
/**
@ -73,6 +71,10 @@ public class MqttTransportContext extends TransportContext {
@Value("${transport.mqtt.timeout:10000}")
private long timeout;
@Getter
@Value("${transport.mqtt.proxy_enabled:false}")
private boolean proxyEnabled;
private final AtomicInteger connectionsCounter = new AtomicInteger();
@PostConstruct
@ -88,4 +90,17 @@ public class MqttTransportContext extends TransportContext {
public void channelUnregistered() {
connectionsCounter.decrementAndGet();
}
public boolean checkAddress(InetSocketAddress address) {
return rateLimitService.checkAddress(address);
}
public void onAuthSuccess(InetSocketAddress address) {
rateLimitService.onAuthSuccess(address);
}
public void onAuthFailure(InetSocketAddress address) {
rateLimitService.onAuthFailure(address);
}
}

17
common/transport/mqtt/src/main/java/org/thingsboard/server/transport/mqtt/MqttTransportHandler.java

@ -164,6 +164,9 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
log.trace("[{}] Processing msg: {}", sessionId, msg);
if (address == null) {
address = getAddress(ctx);
}
try {
if (msg instanceof MqttMessage) {
MqttMessage message = (MqttMessage) msg;
@ -182,8 +185,11 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
}
}
InetSocketAddress getAddress(ChannelHandlerContext ctx) {
return ctx.channel().attr(MqttTransportService.ADDRESS).get();
}
void processMqttMsg(ChannelHandlerContext ctx, MqttMessage msg) {
address = getAddress(ctx);
if (msg.fixedHeader() == null) {
log.info("[{}:{}] Invalid message received", address.getHostName(), address.getPort());
ctx.close();
@ -199,10 +205,6 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
}
}
InetSocketAddress getAddress(ChannelHandlerContext ctx) {
return (InetSocketAddress) ctx.channel().remoteAddress();
}
private void processProvisionSessionMsg(ChannelHandlerContext ctx, MqttMessage msg) {
switch (msg.fixedHeader().messageType()) {
case PUBLISH:
@ -771,7 +773,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
private void processAuthTokenConnect(ChannelHandlerContext ctx, MqttConnectMessage connectMessage) {
String userName = connectMessage.payload().userName();
log.debug("[{}] Processing connect msg for client with user name: {}!", sessionId, userName);
log.debug("[{}][{}] Processing connect msg for client with user name: {}!", address, sessionId, userName);
TransportProtos.ValidateBasicMqttCredRequestMsg.Builder request = TransportProtos.ValidateBasicMqttCredRequestMsg.newBuilder()
.setClientId(connectMessage.payload().clientIdentifier());
if (userName != null) {
@ -820,6 +822,7 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
}
});
} catch (Exception e) {
context.onAuthFailure(address);
ctx.writeAndFlush(createMqttConnAckMsg(CONNECTION_REFUSED_NOT_AUTHORIZED, connectMessage));
log.trace("[{}] X509 auth failure: {}", sessionId, address, e);
ctx.close();
@ -931,9 +934,11 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
private void onValidateDeviceResponse(ValidateDeviceCredentialsResponse msg, ChannelHandlerContext ctx, MqttConnectMessage connectMessage) {
if (!msg.hasDeviceInfo()) {
context.onAuthFailure(address);
ctx.writeAndFlush(createMqttConnAckMsg(CONNECTION_REFUSED_NOT_AUTHORIZED, connectMessage));
ctx.close();
} else {
context.onAuthSuccess(address);
deviceSessionCtx.setDeviceInfo(msg.getDeviceInfo());
deviceSessionCtx.setDeviceProfile(msg.getDeviceProfile());
deviceSessionCtx.setSessionInfo(SessionInfoCreator.create(msg, context, sessionId));

9
common/transport/mqtt/src/main/java/org/thingsboard/server/transport/mqtt/MqttTransportServerInitializer.java

@ -18,9 +18,12 @@ package org.thingsboard.server.transport.mqtt;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.haproxy.HAProxyMessageDecoder;
import io.netty.handler.codec.mqtt.MqttDecoder;
import io.netty.handler.codec.mqtt.MqttEncoder;
import io.netty.handler.ssl.SslHandler;
import org.thingsboard.server.transport.mqtt.limits.IpFilter;
import org.thingsboard.server.transport.mqtt.limits.ProxyIpFilter;
/**
* @author Andrew Shvayka
@ -39,6 +42,12 @@ public class MqttTransportServerInitializer extends ChannelInitializer<SocketCha
public void initChannel(SocketChannel ch) {
ChannelPipeline pipeline = ch.pipeline();
SslHandler sslHandler = null;
if (context.isProxyEnabled()) {
pipeline.addLast("proxy", new HAProxyMessageDecoder());
pipeline.addLast("ipFilter", new ProxyIpFilter(context));
} else {
pipeline.addLast("ipFilter", new IpFilter(context));
}
if (sslEnabled && context.getSslHandlerProvider() != null) {
sslHandler = context.getSslHandlerProvider().getSslHandler();
pipeline.addLast(sslHandler);

5
common/transport/mqtt/src/main/java/org/thingsboard/server/transport/mqtt/MqttTransportService.java

@ -21,6 +21,7 @@ import io.netty.channel.ChannelOption;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.util.AttributeKey;
import io.netty.util.ResourceLeakDetector;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
@ -33,6 +34,8 @@ import org.thingsboard.server.common.data.TbTransportService;
import javax.annotation.PostConstruct;
import javax.annotation.PreDestroy;
import java.net.InetAddress;
import java.net.InetSocketAddress;
/**
* @author Andrew Shvayka
@ -42,6 +45,8 @@ import javax.annotation.PreDestroy;
@Slf4j
public class MqttTransportService implements TbTransportService {
public static AttributeKey<InetSocketAddress> ADDRESS = AttributeKey.newInstance("SRC_ADDRESS");
@Value("${transport.mqtt.bind_address}")
private String host;
@Value("${transport.mqtt.bind_port}")

46
common/transport/mqtt/src/main/java/org/thingsboard/server/transport/mqtt/limits/IpFilter.java

@ -0,0 +1,46 @@
/**
* Copyright © 2016-2021 The Thingsboard Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.thingsboard.server.transport.mqtt.limits;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.handler.ipfilter.AbstractRemoteAddressFilter;
import lombok.extern.slf4j.Slf4j;
import org.thingsboard.server.transport.mqtt.MqttTransportContext;
import org.thingsboard.server.transport.mqtt.MqttTransportService;
import java.net.InetSocketAddress;
@Slf4j
public class IpFilter extends AbstractRemoteAddressFilter<InetSocketAddress> {
private MqttTransportContext context;
public IpFilter(MqttTransportContext context) {
this.context = context;
}
@Override
protected boolean accept(ChannelHandlerContext ctx, InetSocketAddress remoteAddress) throws Exception {
if(context.checkAddress(remoteAddress)){
ctx.channel().attr(MqttTransportService.ADDRESS).set(remoteAddress);
return true;
} else {
return false;
}
}
}

60
common/transport/mqtt/src/main/java/org/thingsboard/server/transport/mqtt/limits/ProxyIpFilter.java

@ -0,0 +1,60 @@
/**
* Copyright © 2016-2021 The Thingsboard Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.thingsboard.server.transport.mqtt.limits;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.util.AttributeKey;
import lombok.extern.slf4j.Slf4j;
import org.thingsboard.server.transport.mqtt.MqttTransportContext;
import org.thingsboard.server.transport.mqtt.MqttTransportService;
import java.net.InetAddress;
import java.net.InetSocketAddress;
@Slf4j
public class ProxyIpFilter extends ChannelInboundHandlerAdapter {
private MqttTransportContext context;
public ProxyIpFilter(MqttTransportContext context) {
this.context = context;
}
@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
if(msg instanceof HAProxyMessage){
HAProxyMessage proxyMsg = (HAProxyMessage) msg;
if(proxyMsg.sourceAddress() != null && proxyMsg.sourcePort() > 0) {
InetSocketAddress address = new InetSocketAddress(proxyMsg.sourceAddress(), proxyMsg.sourcePort());
if(!context.checkAddress(address)){
ctx.close();
} else {
ctx.channel().attr(MqttTransportService.ADDRESS).set(address);
// We no longer need this channel in the pipeline. Similar to HAProxyMessageDecoder
ctx.pipeline().remove(this);
}
} else {
log.debug("Received local health-check connection message: {}", proxyMsg);
ctx.close();
}
} else {
super.channelRead(ctx, msg);
}
}
}

4
common/transport/mqtt/src/test/java/org/thingsboard/server/transport/mqtt/MqttTransportHandlerTest.java

@ -116,7 +116,7 @@ public class MqttTransportHandlerTest {
MqttConnectMessage msg = getMqttConnectMessage();
willDoNothing().given(handler).processConnect(ctx, msg);
handler.processMqttMsg(ctx, msg);
handler.channelRead(ctx, msg);
assertThat(handler.address, is(IP_ADDR));
assertThat(handler.deviceSessionCtx.getChannel(), is(ctx));
@ -152,7 +152,7 @@ public class MqttTransportHandlerTest {
List<MqttPublishMessage> messages = Stream.generate(this::getMqttPublishMessage).limit(MSG_QUEUE_LIMIT).collect(Collectors.toList());
messages.forEach((msg) -> handler.processMqttMsg(ctx, msg));
messages.forEach((msg) -> handler.channelRead(ctx, msg));
assertThat(handler.address, is(IP_ADDR));
assertThat(handler.deviceSessionCtx.getChannel(), is(ctx));

6
common/transport/transport-api/src/main/java/org/thingsboard/server/common/transport/TransportContext.java

@ -22,6 +22,7 @@ import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Autowired;
import org.thingsboard.common.util.ThingsBoardExecutors;
import org.thingsboard.server.cache.ota.OtaPackageDataCache;
import org.thingsboard.server.common.transport.limits.TransportRateLimitService;
import org.thingsboard.server.queue.discovery.TbServiceInfoProvider;
import org.thingsboard.server.queue.scheduler.SchedulerComponent;
@ -57,6 +58,9 @@ public abstract class TransportContext {
@Autowired
private TransportResourceCache transportResourceCache;
@Autowired
protected TransportRateLimitService rateLimitService;
@PostConstruct
public void init() {
executor = ThingsBoardExecutors.newWorkStealingPool(50, getClass());
@ -73,4 +77,6 @@ public abstract class TransportContext {
return serviceInfoProvider.getServiceId();
}
}

83
common/transport/transport-api/src/main/java/org/thingsboard/server/common/transport/limits/DefaultTransportRateLimitService.java

@ -16,6 +16,7 @@
package org.thingsboard.server.common.transport.limits;
import lombok.extern.slf4j.Slf4j;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;
import org.springframework.util.StringUtils;
import org.thingsboard.server.common.data.EntityType;
@ -25,12 +26,13 @@ import org.thingsboard.server.common.data.id.EntityId;
import org.thingsboard.server.common.data.id.TenantId;
import org.thingsboard.server.common.data.tenant.profile.DefaultTenantProfileConfiguration;
import org.thingsboard.server.common.data.tenant.profile.TenantProfileData;
import org.thingsboard.server.common.msg.tools.TbRateLimits;
import org.thingsboard.server.common.transport.TransportTenantProfileCache;
import org.thingsboard.server.common.transport.profile.TenantProfileUpdateResult;
import org.thingsboard.server.queue.util.TbTransportComponent;
import java.util.HashSet;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
@ -47,9 +49,17 @@ public class DefaultTransportRateLimitService implements TransportRateLimitServi
private final ConcurrentMap<TenantId, Set<DeviceId>> tenantDevices = new ConcurrentHashMap<>();
private final ConcurrentMap<TenantId, EntityTransportRateLimits> perTenantLimits = new ConcurrentHashMap<>();
private final ConcurrentMap<DeviceId, EntityTransportRateLimits> perDeviceLimits = new ConcurrentHashMap<>();
private final Map<InetAddress, InetAddressRateLimitStats> ipMap = new ConcurrentHashMap<>();
private final TransportTenantProfileCache tenantProfileCache;
@Value("${transport.rate_limits.ip_limits_enabled:false}")
private boolean ipRateLimitsEnabled;
@Value("${transport.rate_limits.max_wrong_credentials_per_ip:10}")
private int maxWrongCredentialsPerIp;
@Value("${transport.rate_limits.ip_block_timeout:60000}")
private long ipBlockTimeout;
public DefaultTransportRateLimitService(TransportTenantProfileCache tenantProfileCache) {
this.tenantProfileCache = tenantProfileCache;
}
@ -116,6 +126,75 @@ public class DefaultTransportRateLimitService implements TransportRateLimitServi
tenantAllowed.put(tenantId, allowed);
}
@Override
public boolean checkAddress(InetSocketAddress address) {
if (!ipRateLimitsEnabled) {
return true;
}
var stats = ipMap.computeIfAbsent(address.getAddress(), a -> new InetAddressRateLimitStats());
return !stats.isBlocked() || (stats.getLastActivityTs() + ipBlockTimeout < System.currentTimeMillis());
}
@Override
public void onAuthSuccess(InetSocketAddress address) {
if (!ipRateLimitsEnabled) {
return;
}
var stats = ipMap.computeIfAbsent(address.getAddress(), a -> new InetAddressRateLimitStats());
stats.getLock().lock();
try {
stats.setLastActivityTs(System.currentTimeMillis());
stats.setFailureCount(0);
if (stats.isBlocked()) {
stats.setBlocked(false);
log.info("[{}] IP address un-blocked due to correct credentials.", address.getAddress());
}
} finally {
stats.getLock().unlock();
}
}
@Override
public void onAuthFailure(InetSocketAddress address) {
if (!ipRateLimitsEnabled) {
return;
}
var stats = ipMap.computeIfAbsent(address.getAddress(), a -> new InetAddressRateLimitStats());
stats.getLock().lock();
try {
stats.setLastActivityTs(System.currentTimeMillis());
int failureCount = stats.getFailureCount() + 1;
stats.setFailureCount(failureCount);
if (failureCount >= maxWrongCredentialsPerIp) {
log.info("[{}] IP address blocked due to constantly wrong credentials.", address.getAddress());
stats.setBlocked(true);
}
} finally {
stats.getLock().unlock();
}
}
@Override
public void invalidateRateLimitsIpTable(long sessionInactivityTimeout) {
if (!ipRateLimitsEnabled) {
return;
}
long currentTime = System.currentTimeMillis();
long expTime = currentTime - Math.max(sessionInactivityTimeout, ipBlockTimeout);
for (var entry : ipMap.entrySet()) {
var stats = entry.getValue();
if (stats.getLastActivityTs() < expTime) {
log.debug("[{}] IP address removed due to session inactivity timeout.", entry.getKey());
ipMap.remove(entry.getKey());
} else if (stats.isBlocked() && (stats.getLastActivityTs() + ipBlockTimeout < currentTime)) {
log.info("[{}] IP address unblocked due ip block timeout.", entry.getKey());
stats.setBlocked(false);
}
}
}
private <T extends EntityId> void mergeLimits(T entityId, EntityTransportRateLimits newRateLimits,
Function<T, EntityTransportRateLimits> getFunction,
BiConsumer<T, EntityTransportRateLimits> putFunction) {

33
common/transport/transport-api/src/main/java/org/thingsboard/server/common/transport/limits/InetAddressRateLimitStats.java

@ -0,0 +1,33 @@
/**
* Copyright © 2016-2021 The Thingsboard Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.thingsboard.server.common.transport.limits;
import lombok.Data;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
@Data
public class InetAddressRateLimitStats {
private final Lock lock = new ReentrantLock();
private boolean blocked;
private long lastActivityTs;
private int failureCount;
private int connectionsCount;
}

11
common/transport/transport-api/src/main/java/org/thingsboard/server/common/transport/limits/TransportRateLimitService.java

@ -20,6 +20,8 @@ import org.thingsboard.server.common.data.id.DeviceId;
import org.thingsboard.server.common.data.id.TenantId;
import org.thingsboard.server.common.transport.profile.TenantProfileUpdateResult;
import java.net.InetSocketAddress;
public interface TransportRateLimitService {
EntityType checkLimits(TenantId tenantId, DeviceId deviceId, int dataPoints);
@ -33,4 +35,13 @@ public interface TransportRateLimitService {
void remove(DeviceId deviceId);
void update(TenantId tenantId, boolean transportEnabled);
boolean checkAddress(InetSocketAddress address);
void onAuthSuccess(InetSocketAddress address);
void onAuthFailure(InetSocketAddress address);
void invalidateRateLimitsIpTable(long sessionInactivityTimeout);
}

5
common/transport/transport-api/src/main/java/org/thingsboard/server/common/transport/service/DefaultTransportService.java

@ -209,6 +209,7 @@ public class DefaultTransportService implements TransportService {
this.transportApiStats = statsFactory.createMessagesStats(StatsType.TRANSPORT.getName() + ".producer");
this.transportCallbackExecutor = ThingsBoardExecutors.newWorkStealingPool(20, getClass());
this.scheduler.scheduleAtFixedRate(this::checkInactivityAndReportActivity, new Random().nextInt((int) sessionReportTimeout), sessionReportTimeout, TimeUnit.MILLISECONDS);
this.scheduler.scheduleAtFixedRate(this::invalidateRateLimits, new Random().nextInt((int) sessionReportTimeout), sessionReportTimeout, TimeUnit.MILLISECONDS);
transportApiRequestTemplate = queueProvider.createTransportApiRequestTemplate();
transportApiRequestTemplate.setMessagesStats(transportApiStats);
ruleEngineMsgProducer = producerProvider.getRuleEngineMsgProducer();
@ -247,6 +248,10 @@ public class DefaultTransportService implements TransportService {
});
}
private void invalidateRateLimits() {
rateLimitService.invalidateRateLimitsIpTable(sessionInactivityTimeout);
}
@PreDestroy
public void destroy() {
stopped = true;

12
transport/mqtt/src/main/resources/tb-mqtt-transport.yml

@ -88,6 +88,9 @@ transport:
mqtt:
bind_address: "${MQTT_BIND_ADDRESS:0.0.0.0}"
bind_port: "${MQTT_BIND_PORT:1883}"
# Enable proxy protocol support. Disabled by default. If enabled, supports both v1 and v2.
# Useful to get the real IP address of the client in the logs and for rate limits.
proxy_enabled: "${MQTT_PROXY_PROTOCOL_ENABLED:false}"
timeout: "${MQTT_TIMEOUT:10000}"
msg_queue_size_per_device_limit: "${MQTT_MSG_QUEUE_SIZE_PER_DEVICE_LIMIT:100}" # messages await in the queue before device connected state. This limit works on low level before TenantProfileLimits mechanism
netty:
@ -146,6 +149,15 @@ transport:
stats:
enabled: "${TB_TRANSPORT_STATS_ENABLED:true}"
print-interval-ms: "${TB_TRANSPORT_STATS_PRINT_INTERVAL_MS:60000}"
client_side_rpc:
timeout: "${CLIENT_SIDE_RPC_TIMEOUT:60000}"
rate_limits:
# Enable or disable generic rate limits. Device and Tenant specific rate limits are controlled in Tenant Profile.
ip_limits_enabled: "${TB_TRANSPORT_IP_RATE_LIMITS_ENABLED:false}"
# Maximum number of connect attempts with invalid credentials
max_wrong_credentials_per_ip: "${TB_TRANSPORT_MAX_WRONG_CREDENTIALS_PER_IP:10}"
# Timeout to expire block IP addresses
ip_block_timeout: "${TB_TRANSPORT_IP_BLOCK_TIMEOUT:60000}"
queue:

Loading…
Cancel
Save