Browse Source

Fix SSRF DNS rebinding bypass, add allow-list, protect additional HTTP vectors

Add SsrfSafeAddressResolverGroup that validates resolved IPs at Netty
connection time, eliminating the TOCTOU gap where DNS rebinding domains
resolve to safe IPs during validation but to private/metadata IPs at
connection time. Disable HTTP redirects in TbHttpClient to prevent
redirect-based SSRF bypass.

Add allow-list support (SSRF_ALLOWED_HOSTS) to SsrfProtectionValidator
so customers with IoT devices on private networks can whitelist specific
addresses or CIDR ranges while keeping SSRF protection enabled.

Add SSRF validation to MS Teams webhook, custom OAuth2 mapper, and
GitHub OAuth2 mapper endpoints. Log a warning when SSRF protection is
disabled.
pull/15253/head
Viacheslav Klimov 3 months ago
parent
commit
ae8246fc60
Failed to extract signature
  1. 10
      application/src/main/java/org/thingsboard/server/actors/ActorSystemContext.java
  2. 11
      application/src/main/java/org/thingsboard/server/service/notification/channels/MicrosoftTeamsNotificationChannel.java
  3. 9
      application/src/main/java/org/thingsboard/server/service/security/auth/oauth2/CustomOAuth2ClientMapper.java
  4. 8
      application/src/main/java/org/thingsboard/server/service/security/auth/oauth2/GithubOAuth2ClientMapper.java
  5. 4
      application/src/main/resources/thingsboard.yml
  6. 54
      common/util/src/main/java/org/thingsboard/common/util/SsrfProtectionValidator.java
  7. 90
      common/util/src/test/java/org/thingsboard/common/util/SsrfProtectionValidatorTest.java
  8. 135
      rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/rest/SsrfSafeAddressResolverGroup.java
  9. 2
      rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/rest/TbHttpClient.java
  10. 150
      rule-engine/rule-engine-components/src/test/java/org/thingsboard/rule/engine/rest/SsrfSafeAddressResolverGroupTest.java

10
application/src/main/java/org/thingsboard/server/actors/ActorSystemContext.java

@ -615,11 +615,21 @@ public class ActorSystemContext {
@Value("${actors.rule.external.ssrf_additional_blocked_hosts:}")
private List<String> ssrfAdditionalBlockedHosts;
@Value("${actors.rule.external.ssrf_allowed_hosts:}")
private List<String> ssrfAllowedHosts;
@PostConstruct
public void init() {
this.localCacheType = "caffeine".equals(cacheType);
SsrfProtectionValidator.setEnabled(ssrfProtectionEnabled);
SsrfProtectionValidator.setAdditionalBlockedHosts(ssrfAdditionalBlockedHosts);
SsrfProtectionValidator.setAllowedHosts(ssrfAllowedHosts);
if (!ssrfProtectionEnabled) {
log.warn("SSRF protection for external rule nodes is DISABLED. This allows rule chains to make HTTP requests to " +
"internal/private network addresses including cloud metadata endpoints. It is strongly recommended to " +
"enable SSRF protection by setting SSRF_PROTECTION_ENABLED=true. If your rule chains need to access " +
"devices on local networks, use SSRF_ALLOWED_HOSTS to whitelist specific addresses or ranges.");
}
}
@Value("${actors.tenant.create_components_on_init:true}")

11
application/src/main/java/org/thingsboard/server/service/notification/channels/MicrosoftTeamsNotificationChannel.java

@ -29,6 +29,7 @@ import org.springframework.http.MediaType;
import org.springframework.stereotype.Component;
import org.springframework.web.client.RestTemplate;
import org.thingsboard.common.util.JacksonUtil;
import org.thingsboard.common.util.SsrfProtectionValidator;
import org.thingsboard.server.common.data.id.TenantId;
import org.thingsboard.server.common.data.notification.NotificationDeliveryMethod;
import org.thingsboard.server.common.data.notification.info.NotificationInfo;
@ -109,10 +110,13 @@ public class MicrosoftTeamsNotificationChannel implements NotificationChannel<Mi
adaptiveCard.getActions().add(actionOpenUrl);
}
URI webhookUri = new URI(targetConfig.getWebhookUrl());
SsrfProtectionValidator.validateUri(webhookUri);
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
HttpEntity<String> request = new HttpEntity<>(JacksonUtil.toString(teamsAdaptiveCard), headers);
restTemplate.postForEntity(new URI(targetConfig.getWebhookUrl()), request, String.class);
restTemplate.postForEntity(webhookUri, request, String.class);
}
private void sendTeamsMessageCard(MicrosoftTeamsNotificationTargetConfig targetConfig, MicrosoftTeamsDeliveryMethodNotificationTemplate processedTemplate, NotificationProcessingContext ctx) throws JsonProcessingException, URISyntaxException {
@ -139,10 +143,13 @@ public class MicrosoftTeamsNotificationChannel implements NotificationChannel<Mi
teamsMessageCard.setPotentialAction(List.of(actionCard));
}
URI webhookUri = new URI(targetConfig.getWebhookUrl());
SsrfProtectionValidator.validateUri(webhookUri);
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
HttpEntity<String> request = new HttpEntity<>(JacksonUtil.toString(teamsMessageCard), headers);
restTemplate.postForEntity(new URI(targetConfig.getWebhookUrl()), request, String.class);
restTemplate.postForEntity(webhookUri, request, String.class);
}
private String getButtonUri(MicrosoftTeamsDeliveryMethodNotificationTemplate processedTemplate, NotificationProcessingContext ctx) throws JsonProcessingException {

9
application/src/main/java/org/thingsboard/server/service/security/auth/oauth2/CustomOAuth2ClientMapper.java

@ -23,11 +23,14 @@ import org.springframework.security.oauth2.client.authentication.OAuth2Authentic
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate;
import org.thingsboard.common.util.JacksonUtil;
import org.thingsboard.common.util.SsrfProtectionValidator;
import org.thingsboard.server.common.data.StringUtils;
import org.thingsboard.server.common.data.oauth2.OAuth2CustomMapperConfig;
import org.thingsboard.server.common.data.oauth2.OAuth2MapperConfig;
import org.thingsboard.server.common.data.oauth2.OAuth2Client;
import org.thingsboard.server.dao.oauth2.OAuth2User;
import java.net.URI;
import org.thingsboard.server.queue.util.TbCoreComponent;
import org.thingsboard.server.service.security.model.SecurityUser;
@ -63,6 +66,12 @@ public class CustomOAuth2ClientMapper extends AbstractOAuth2ClientMapper impleme
log.error("Can't convert principal to JSON string", e);
throw new RuntimeException("Can't convert principal to JSON string", e);
}
try {
SsrfProtectionValidator.validateUri(new URI(custom.getUrl()));
} catch (Exception e) {
log.error("SSRF validation failed for custom mapper URL '{}'", custom.getUrl(), e);
throw new RuntimeException("Unable to login. Please contact your Administrator!");
}
try {
return restTemplate.postForEntity(custom.getUrl(), request, OAuth2User.class).getBody();
} catch (Exception e) {

8
application/src/main/java/org/thingsboard/server/service/security/auth/oauth2/GithubOAuth2ClientMapper.java

@ -24,6 +24,7 @@ import org.springframework.boot.web.client.RestTemplateBuilder;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.stereotype.Service;
import org.springframework.web.client.RestTemplate;
import org.thingsboard.common.util.SsrfProtectionValidator;
import org.thingsboard.server.common.data.oauth2.OAuth2MapperConfig;
import org.thingsboard.server.common.data.oauth2.OAuth2Client;
import org.thingsboard.server.dao.oauth2.OAuth2Configuration;
@ -31,6 +32,7 @@ import org.thingsboard.server.dao.oauth2.OAuth2User;
import org.thingsboard.server.queue.util.TbCoreComponent;
import org.thingsboard.server.service.security.model.SecurityUser;
import java.net.URI;
import java.util.ArrayList;
import java.util.Map;
import java.util.Optional;
@ -62,6 +64,12 @@ public class GithubOAuth2ClientMapper extends AbstractOAuth2ClientMapper impleme
restTemplateBuilder = restTemplateBuilder.defaultHeader(AUTHORIZATION, "token " + oauth2Token);
RestTemplate restTemplate = restTemplateBuilder.build();
try {
SsrfProtectionValidator.validateUri(new URI(emailUrl));
} catch (Exception e) {
log.error("SSRF validation failed for GitHub email URL '{}'", emailUrl, e);
throw new RuntimeException("Unable to login. Please contact your Administrator!");
}
GithubEmailsResponse githubEmailsResponse;
try {
githubEmailsResponse = restTemplate.getForEntity(emailUrl, GithubEmailsResponse.class).getBody();

4
application/src/main/resources/thingsboard.yml

@ -509,6 +509,10 @@ actors:
# Comma-separated list of additional blocked destinations (IPs, CIDR subnets, or hostnames).
# Example: "198.51.100.0/24,metadata.tencentyun.com,rancher-metadata"
ssrf_additional_blocked_hosts: "${SSRF_ADDITIONAL_BLOCKED_HOSTS:}"
# Comma-separated list of allowed destinations that bypass SSRF blocking (IPs, CIDR subnets, or hostnames).
# Use this when your rule chains need to reach devices on private networks (e.g., 192.168.1.0/24).
# Example: "192.168.1.0/24,10.0.0.0/8,my-internal-service.corp"
ssrf_allowed_hosts: "${SSRF_ALLOWED_HOSTS:}"
rpc:
# Maximum number of persistent RPC call retries in case of failed request delivery.
max_retries: "${ACTORS_RPC_MAX_RETRIES:5}"

54
common/util/src/main/java/org/thingsboard/common/util/SsrfProtectionValidator.java

@ -38,6 +38,7 @@ public class SsrfProtectionValidator {
private static final Set<String> BLOCKED_HOSTNAME_SUFFIXES = Set.of(".internal", ".local");
private static volatile AdditionalBlockedHosts additionalBlocked = AdditionalBlockedHosts.EMPTY;
private static volatile AllowedHosts allowedHosts = AllowedHosts.EMPTY;
// Well-known cloud metadata endpoints not covered by the JDK checks (isLoopback, isSiteLocal, isLinkLocal)
private static final List<CidrRange> CLOUD_METADATA_RANGES = List.of(
@ -66,6 +67,13 @@ public class SsrfProtectionValidator {
}
String hostLower = host.toLowerCase();
// Allow-listed hostnames bypass all hostname and IP checks
AllowedHosts currentAllowed = allowedHosts;
if (currentAllowed.hostnames.contains(hostLower)) {
return;
}
if (BLOCKED_HOSTNAMES.contains(hostLower) || additionalBlocked.hostnames.contains(hostLower)) {
throwBlockedHost(host);
}
@ -98,7 +106,15 @@ public class SsrfProtectionValidator {
}
}
private static boolean isBlockedAddress(InetAddress address) {
public static boolean isBlockedAddress(InetAddress address) {
// Check allow-list first: allowed addresses bypass all block checks
AllowedHosts currentAllowed = allowedHosts;
for (CidrRange cidr : currentAllowed.cidrRanges) {
if (cidr.contains(address)) {
return false;
}
}
// Covers 127.0.0.0/8 and ::1
if (address.isLoopbackAddress()) {
return true;
@ -142,6 +158,10 @@ public class SsrfProtectionValidator {
throw new RuntimeException("URI is invalid: host '" + host + "' is not allowed");
}
public static boolean isEnabled() {
return enabled;
}
public static void setEnabled(boolean enabled) {
SsrfProtectionValidator.enabled = enabled;
}
@ -179,10 +199,42 @@ public class SsrfProtectionValidator {
return !entry.isEmpty() && (Character.isDigit(entry.charAt(0)) || entry.contains(":"));
}
public static void setAllowedHosts(List<String> entries) {
if (entries == null || entries.isEmpty()) {
allowedHosts = AllowedHosts.EMPTY;
return;
}
List<CidrRange> cidrRanges = new ArrayList<>();
Set<String> hostnames = new HashSet<>();
for (String entry : entries) {
String trimmed = entry.trim();
if (trimmed.isEmpty()) {
continue;
}
if (trimmed.contains("/") || isIpLiteral(trimmed)) {
try {
cidrRanges.add(CidrRange.parse(trimmed));
} catch (Exception e) {
log.warn("Failed to parse allowed CIDR/IP entry '{}': {}", trimmed, e.getMessage());
}
} else {
hostnames.add(trimmed.toLowerCase());
}
}
allowedHosts = new AllowedHosts(
Collections.unmodifiableList(cidrRanges),
Collections.unmodifiableSet(hostnames));
log.info("SSRF allowed hosts configured: {} CIDR range(s), {} hostname(s)", cidrRanges.size(), hostnames.size());
}
record AdditionalBlockedHosts(List<CidrRange> cidrRanges, Set<String> hostnames) {
static final AdditionalBlockedHosts EMPTY = new AdditionalBlockedHosts(Collections.emptyList(), Collections.emptySet());
}
record AllowedHosts(List<CidrRange> cidrRanges, Set<String> hostnames) {
static final AllowedHosts EMPTY = new AllowedHosts(Collections.emptyList(), Collections.emptySet());
}
record CidrRange(byte[] network, int prefixLength) {
static CidrRange of(String ip, int prefixLength) {

90
common/util/src/test/java/org/thingsboard/common/util/SsrfProtectionValidatorTest.java

@ -20,10 +20,12 @@ import org.junit.jupiter.api.parallel.ResourceLock;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;
import java.net.InetAddress;
import java.net.URI;
import java.util.Collections;
import java.util.List;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatNoException;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
@ -335,4 +337,92 @@ public class SsrfProtectionValidatorTest {
}
}
// --- Allow-list tests ---
@Test
void testAllowListCidrAllowsPrivateAddress() {
try {
SsrfProtectionValidator.setAllowedHosts(List.of("192.168.1.0/24"));
// 192.168.1.1 is normally blocked (site-local), but allow-listed
assertThatNoException().isThrownBy(() -> SsrfProtectionValidator.validateUri(URI.create("http://192.168.1.1"), true));
// Other private ranges remain blocked
assertThatThrownBy(() -> SsrfProtectionValidator.validateUri(URI.create("http://10.0.0.1"), true))
.isInstanceOf(RuntimeException.class)
.hasMessageContaining("URI is invalid");
} finally {
SsrfProtectionValidator.setAllowedHosts(Collections.emptyList());
}
}
@Test
void testAllowListHostnameBypassesSuffixCheck() {
try {
SsrfProtectionValidator.setAllowedHosts(List.of("my-device.local"));
// .local suffix is normally blocked, but allow-listed hostname passes
assertThatNoException().isThrownBy(() -> SsrfProtectionValidator.validateUri(URI.create("http://my-device.local/api"), true));
// Other .local hostnames remain blocked
assertThatThrownBy(() -> SsrfProtectionValidator.validateUri(URI.create("http://other-device.local/api"), true))
.isInstanceOf(RuntimeException.class)
.hasMessageContaining("URI is invalid");
} finally {
SsrfProtectionValidator.setAllowedHosts(Collections.emptyList());
}
}
@Test
void testAllowListPrecedenceOverBlockList() {
try {
// Block 8.8.8.0/24 via additional-blocked, but allow 8.8.8.8 via allow-list
SsrfProtectionValidator.setAdditionalBlockedHosts(List.of("8.8.8.0/24"));
SsrfProtectionValidator.setAllowedHosts(List.of("8.8.8.8"));
// Allow-list should win
assertThatNoException().isThrownBy(() -> SsrfProtectionValidator.validateUri(URI.create("https://8.8.8.8"), true));
// Adjacent IP still blocked
assertThatThrownBy(() -> SsrfProtectionValidator.validateUri(URI.create("https://8.8.8.9"), true))
.isInstanceOf(RuntimeException.class)
.hasMessageContaining("URI is invalid");
} finally {
SsrfProtectionValidator.setAdditionalBlockedHosts(Collections.emptyList());
SsrfProtectionValidator.setAllowedHosts(Collections.emptyList());
}
}
@Test
void testIsBlockedAddressPublicApi() throws Exception {
InetAddress loopback = InetAddress.getByName("127.0.0.1");
assertThat(SsrfProtectionValidator.isBlockedAddress(loopback)).isTrue();
InetAddress publicIp = InetAddress.getByName("8.8.8.8");
assertThat(SsrfProtectionValidator.isBlockedAddress(publicIp)).isFalse();
// Allow-listed private address
try {
SsrfProtectionValidator.setAllowedHosts(List.of("10.0.0.0/8"));
InetAddress privateIp = InetAddress.getByName("10.1.2.3");
assertThat(SsrfProtectionValidator.isBlockedAddress(privateIp)).isFalse();
} finally {
SsrfProtectionValidator.setAllowedHosts(Collections.emptyList());
}
}
@Test
void testIsEnabledAccessor() {
boolean original = SsrfProtectionValidator.isEnabled();
try {
SsrfProtectionValidator.setEnabled(true);
assertThat(SsrfProtectionValidator.isEnabled()).isTrue();
SsrfProtectionValidator.setEnabled(false);
assertThat(SsrfProtectionValidator.isEnabled()).isFalse();
} finally {
SsrfProtectionValidator.setEnabled(original);
}
}
@Test
void testSetAllowedHostsEmptyAndNull() {
// Should not throw
SsrfProtectionValidator.setAllowedHosts(Collections.emptyList());
SsrfProtectionValidator.setAllowedHosts(null);
}
}

135
rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/rest/SsrfSafeAddressResolverGroup.java

@ -0,0 +1,135 @@
/**
* Copyright © 2016-2026 The Thingsboard Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.thingsboard.rule.engine.rest;
import io.netty.resolver.AddressResolver;
import io.netty.resolver.AddressResolverGroup;
import io.netty.resolver.DefaultAddressResolverGroup;
import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.Promise;
import org.thingsboard.common.util.SsrfProtectionValidator;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.List;
import java.util.stream.Collectors;
/**
* Custom Netty {@link AddressResolverGroup} that validates every resolved IP address
* against the SSRF block-list at connection time. This eliminates the DNS rebinding
* TOCTOU gap where a hostname resolves to a safe IP during validation but to a
* private/metadata IP when the actual connection is made.
*/
public final class SsrfSafeAddressResolverGroup extends AddressResolverGroup<InetSocketAddress> {
public static final SsrfSafeAddressResolverGroup INSTANCE = new SsrfSafeAddressResolverGroup();
private SsrfSafeAddressResolverGroup() {
}
@Override
protected AddressResolver<InetSocketAddress> newResolver(EventExecutor executor) throws Exception {
AddressResolver<InetSocketAddress> delegate = DefaultAddressResolverGroup.INSTANCE.getResolver(executor);
return new SsrfValidatingResolver(executor, delegate);
}
private static final class SsrfValidatingResolver implements AddressResolver<InetSocketAddress> {
private final EventExecutor executor;
private final AddressResolver<InetSocketAddress> delegate;
SsrfValidatingResolver(EventExecutor executor, AddressResolver<InetSocketAddress> delegate) {
this.executor = executor;
this.delegate = delegate;
}
@Override
public boolean isSupported(SocketAddress address) {
return delegate.isSupported(address);
}
@Override
public boolean isResolved(SocketAddress address) {
return delegate.isResolved(address);
}
@Override
public Future<InetSocketAddress> resolve(SocketAddress address) {
return resolve(address, executor.newPromise());
}
@Override
public Future<InetSocketAddress> resolve(SocketAddress address, Promise<InetSocketAddress> promise) {
delegate.resolve(address).addListener((Future<InetSocketAddress> future) -> {
if (!future.isSuccess()) {
promise.tryFailure(future.cause());
return;
}
InetSocketAddress resolved = future.getNow();
if (SsrfProtectionValidator.isEnabled() && isBlocked(resolved)) {
promise.tryFailure(new RuntimeException(
"SSRF protection: resolved address " + resolved.getAddress().getHostAddress() + " is blocked"));
} else {
promise.trySuccess(resolved);
}
});
return promise;
}
@Override
public Future<List<InetSocketAddress>> resolveAll(SocketAddress address) {
return resolveAll(address, executor.newPromise());
}
@Override
public Future<List<InetSocketAddress>> resolveAll(SocketAddress address, Promise<List<InetSocketAddress>> promise) {
delegate.resolveAll(address).addListener((Future<List<InetSocketAddress>> future) -> {
if (!future.isSuccess()) {
promise.tryFailure(future.cause());
return;
}
List<InetSocketAddress> resolved = future.getNow();
if (!SsrfProtectionValidator.isEnabled()) {
promise.trySuccess(resolved);
return;
}
List<InetSocketAddress> safe = resolved.stream()
.filter(addr -> !isBlocked(addr))
.collect(Collectors.toList());
if (safe.isEmpty()) {
String host = address instanceof InetSocketAddress isa ? isa.getHostString() : address.toString();
promise.tryFailure(new RuntimeException(
"SSRF protection: all resolved addresses for " + host + " are blocked"));
} else {
promise.trySuccess(safe);
}
});
return promise;
}
@Override
public void close() {
delegate.close();
}
private static boolean isBlocked(InetSocketAddress socketAddress) {
InetAddress addr = socketAddress.getAddress();
return addr != null && SsrfProtectionValidator.isBlockedAddress(addr);
}
}
}

2
rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/rest/TbHttpClient.java

@ -103,6 +103,8 @@ public class TbHttpClient {
.build();
HttpClient httpClient = HttpClient.create(connectionProvider)
.resolver(SsrfSafeAddressResolverGroup.INSTANCE)
.followRedirect(false)
.runOn(getSharedOrCreateEventLoopGroup(eventLoopGroupShared))
.doOnConnected(c ->
c.addHandlerLast(new ReadTimeoutHandler(config.getReadTimeoutMs(), TimeUnit.MILLISECONDS)));

150
rule-engine/rule-engine-components/src/test/java/org/thingsboard/rule/engine/rest/SsrfSafeAddressResolverGroupTest.java

@ -0,0 +1,150 @@
/**
* Copyright © 2016-2026 The Thingsboard Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.thingsboard.rule.engine.rest;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.resolver.AddressResolver;
import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.Promise;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.parallel.ResourceLock;
import org.thingsboard.common.util.SsrfProtectionValidator;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
@ResourceLock("SsrfSafeAddressResolverGroupTest")
class SsrfSafeAddressResolverGroupTest {
private static NioEventLoopGroup eventLoopGroup;
@BeforeAll
static void setUp() {
eventLoopGroup = new NioEventLoopGroup(1);
}
@AfterAll
static void tearDown() {
eventLoopGroup.shutdownGracefully(0, 5, TimeUnit.SECONDS);
SsrfProtectionValidator.setEnabled(false);
SsrfProtectionValidator.setAllowedHosts(Collections.emptyList());
}
@BeforeEach
void enableSsrf() {
SsrfProtectionValidator.setEnabled(true);
SsrfProtectionValidator.setAllowedHosts(Collections.emptyList());
}
@AfterEach
void resetState() {
SsrfProtectionValidator.setAllowedHosts(Collections.emptyList());
SsrfProtectionValidator.setEnabled(false);
}
@Test
void isBlockedAddressWorksForLoopback() throws Exception {
assertThat(SsrfProtectionValidator.isBlockedAddress(InetAddress.getByName("127.0.0.1"))).isTrue();
assertThat(SsrfProtectionValidator.isBlockedAddress(InetAddress.getByName("192.168.1.1"))).isTrue();
assertThat(SsrfProtectionValidator.isBlockedAddress(InetAddress.getByName("8.8.8.8"))).isFalse();
}
@Test
void resolvePublicIpSucceeds() throws Exception {
EventExecutor executor = eventLoopGroup.next();
AddressResolver<InetSocketAddress> resolver = SsrfSafeAddressResolverGroup.INSTANCE.getResolver(executor);
Promise<InetSocketAddress> promise = executor.newPromise();
executor.submit(() -> resolver.resolve(InetSocketAddress.createUnresolved("example.com", 80), promise));
InetSocketAddress result = promise.get(10, TimeUnit.SECONDS);
assertThat(result.getAddress()).isNotNull();
assertThat(result.getAddress().isLoopbackAddress()).isFalse();
assertThat(result.getAddress().isSiteLocalAddress()).isFalse();
}
@Test
void resolveLoopbackFailsWhenSsrfEnabled() throws Exception {
assertThat(SsrfProtectionValidator.isEnabled()).isTrue();
EventExecutor executor = eventLoopGroup.next();
AddressResolver<InetSocketAddress> resolver = SsrfSafeAddressResolverGroup.INSTANCE.getResolver(executor);
Promise<InetSocketAddress> promise = executor.newPromise();
executor.submit(() -> resolver.resolve(InetSocketAddress.createUnresolved("127.0.0.1", 80), promise));
assertThatThrownBy(() -> promise.get(10, TimeUnit.SECONDS))
.isInstanceOf(ExecutionException.class)
.hasRootCauseInstanceOf(RuntimeException.class)
.rootCause().hasMessageContaining("SSRF protection");
}
@Test
void resolvePrivateIpFailsWhenSsrfEnabled() throws Exception {
assertThat(SsrfProtectionValidator.isEnabled()).isTrue();
EventExecutor executor = eventLoopGroup.next();
AddressResolver<InetSocketAddress> resolver = SsrfSafeAddressResolverGroup.INSTANCE.getResolver(executor);
Promise<InetSocketAddress> promise = executor.newPromise();
executor.submit(() -> resolver.resolve(InetSocketAddress.createUnresolved("192.168.1.1", 80), promise));
assertThatThrownBy(() -> promise.get(10, TimeUnit.SECONDS))
.isInstanceOf(ExecutionException.class)
.hasRootCauseInstanceOf(RuntimeException.class)
.rootCause().hasMessageContaining("SSRF protection");
}
@Test
void resolveAllowedPrivateIpSucceeds() throws Exception {
SsrfProtectionValidator.setAllowedHosts(List.of("192.168.1.0/24"));
EventExecutor executor = eventLoopGroup.next();
AddressResolver<InetSocketAddress> resolver = SsrfSafeAddressResolverGroup.INSTANCE.getResolver(executor);
Promise<InetSocketAddress> promise = executor.newPromise();
executor.submit(() -> resolver.resolve(InetSocketAddress.createUnresolved("192.168.1.1", 80), promise));
InetSocketAddress result = promise.get(10, TimeUnit.SECONDS);
assertThat(result.getAddress().getHostAddress()).isEqualTo("192.168.1.1");
}
@Test
void resolveAllReturnsAllWhenSsrfDisabled() throws Exception {
SsrfProtectionValidator.setEnabled(false);
EventExecutor executor = eventLoopGroup.next();
AddressResolver<InetSocketAddress> resolver = SsrfSafeAddressResolverGroup.INSTANCE.getResolver(executor);
Promise<List<InetSocketAddress>> promise = executor.newPromise();
executor.submit(() -> resolver.resolveAll(InetSocketAddress.createUnresolved("127.0.0.1", 80), promise));
List<InetSocketAddress> results = promise.get(10, TimeUnit.SECONDS);
assertThat(results).isNotEmpty();
}
}
Loading…
Cancel
Save