diff --git a/application/src/main/java/org/thingsboard/server/actors/ActorSystemContext.java b/application/src/main/java/org/thingsboard/server/actors/ActorSystemContext.java index 892eb2beda..a9692f4b9d 100644 --- a/application/src/main/java/org/thingsboard/server/actors/ActorSystemContext.java +++ b/application/src/main/java/org/thingsboard/server/actors/ActorSystemContext.java @@ -615,11 +615,21 @@ public class ActorSystemContext { @Value("${actors.rule.external.ssrf_additional_blocked_hosts:}") private List ssrfAdditionalBlockedHosts; + @Value("${actors.rule.external.ssrf_allowed_hosts:}") + private List ssrfAllowedHosts; + @PostConstruct public void init() { this.localCacheType = "caffeine".equals(cacheType); SsrfProtectionValidator.setEnabled(ssrfProtectionEnabled); SsrfProtectionValidator.setAdditionalBlockedHosts(ssrfAdditionalBlockedHosts); + SsrfProtectionValidator.setAllowedHosts(ssrfAllowedHosts); + if (!ssrfProtectionEnabled) { + log.warn("SSRF protection for external rule nodes is DISABLED. This allows rule chains to make HTTP requests to " + + "internal/private network addresses including cloud metadata endpoints. It is strongly recommended to " + + "enable SSRF protection by setting SSRF_PROTECTION_ENABLED=true. If your rule chains need to access " + + "devices on local networks, use SSRF_ALLOWED_HOSTS to whitelist specific addresses or ranges."); + } } @Value("${actors.tenant.create_components_on_init:true}") diff --git a/application/src/main/java/org/thingsboard/server/service/notification/channels/MicrosoftTeamsNotificationChannel.java b/application/src/main/java/org/thingsboard/server/service/notification/channels/MicrosoftTeamsNotificationChannel.java index df163283ce..d023ed9b53 100644 --- a/application/src/main/java/org/thingsboard/server/service/notification/channels/MicrosoftTeamsNotificationChannel.java +++ b/application/src/main/java/org/thingsboard/server/service/notification/channels/MicrosoftTeamsNotificationChannel.java @@ -29,6 +29,7 @@ import org.springframework.http.MediaType; import org.springframework.stereotype.Component; import org.springframework.web.client.RestTemplate; import org.thingsboard.common.util.JacksonUtil; +import org.thingsboard.common.util.SsrfProtectionValidator; import org.thingsboard.server.common.data.id.TenantId; import org.thingsboard.server.common.data.notification.NotificationDeliveryMethod; import org.thingsboard.server.common.data.notification.info.NotificationInfo; @@ -109,10 +110,13 @@ public class MicrosoftTeamsNotificationChannel implements NotificationChannel request = new HttpEntity<>(JacksonUtil.toString(teamsAdaptiveCard), headers); - restTemplate.postForEntity(new URI(targetConfig.getWebhookUrl()), request, String.class); + restTemplate.postForEntity(webhookUri, request, String.class); } private void sendTeamsMessageCard(MicrosoftTeamsNotificationTargetConfig targetConfig, MicrosoftTeamsDeliveryMethodNotificationTemplate processedTemplate, NotificationProcessingContext ctx) throws JsonProcessingException, URISyntaxException { @@ -139,10 +143,13 @@ public class MicrosoftTeamsNotificationChannel implements NotificationChannel request = new HttpEntity<>(JacksonUtil.toString(teamsMessageCard), headers); - restTemplate.postForEntity(new URI(targetConfig.getWebhookUrl()), request, String.class); + restTemplate.postForEntity(webhookUri, request, String.class); } private String getButtonUri(MicrosoftTeamsDeliveryMethodNotificationTemplate processedTemplate, NotificationProcessingContext ctx) throws JsonProcessingException { diff --git a/application/src/main/java/org/thingsboard/server/service/security/auth/oauth2/CustomOAuth2ClientMapper.java b/application/src/main/java/org/thingsboard/server/service/security/auth/oauth2/CustomOAuth2ClientMapper.java index 8477c69a99..5cd6ca09b2 100644 --- a/application/src/main/java/org/thingsboard/server/service/security/auth/oauth2/CustomOAuth2ClientMapper.java +++ b/application/src/main/java/org/thingsboard/server/service/security/auth/oauth2/CustomOAuth2ClientMapper.java @@ -23,11 +23,14 @@ import org.springframework.security.oauth2.client.authentication.OAuth2Authentic import org.springframework.stereotype.Service; import org.springframework.web.client.RestTemplate; import org.thingsboard.common.util.JacksonUtil; +import org.thingsboard.common.util.SsrfProtectionValidator; import org.thingsboard.server.common.data.StringUtils; import org.thingsboard.server.common.data.oauth2.OAuth2CustomMapperConfig; import org.thingsboard.server.common.data.oauth2.OAuth2MapperConfig; import org.thingsboard.server.common.data.oauth2.OAuth2Client; import org.thingsboard.server.dao.oauth2.OAuth2User; + +import java.net.URI; import org.thingsboard.server.queue.util.TbCoreComponent; import org.thingsboard.server.service.security.model.SecurityUser; @@ -63,6 +66,12 @@ public class CustomOAuth2ClientMapper extends AbstractOAuth2ClientMapper impleme log.error("Can't convert principal to JSON string", e); throw new RuntimeException("Can't convert principal to JSON string", e); } + try { + SsrfProtectionValidator.validateUri(new URI(custom.getUrl())); + } catch (Exception e) { + log.error("SSRF validation failed for custom mapper URL '{}'", custom.getUrl(), e); + throw new RuntimeException("Unable to login. Please contact your Administrator!"); + } try { return restTemplate.postForEntity(custom.getUrl(), request, OAuth2User.class).getBody(); } catch (Exception e) { diff --git a/application/src/main/java/org/thingsboard/server/service/security/auth/oauth2/GithubOAuth2ClientMapper.java b/application/src/main/java/org/thingsboard/server/service/security/auth/oauth2/GithubOAuth2ClientMapper.java index c7aa893d59..2861036097 100644 --- a/application/src/main/java/org/thingsboard/server/service/security/auth/oauth2/GithubOAuth2ClientMapper.java +++ b/application/src/main/java/org/thingsboard/server/service/security/auth/oauth2/GithubOAuth2ClientMapper.java @@ -24,6 +24,7 @@ import org.springframework.boot.web.client.RestTemplateBuilder; import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken; import org.springframework.stereotype.Service; import org.springframework.web.client.RestTemplate; +import org.thingsboard.common.util.SsrfProtectionValidator; import org.thingsboard.server.common.data.oauth2.OAuth2MapperConfig; import org.thingsboard.server.common.data.oauth2.OAuth2Client; import org.thingsboard.server.dao.oauth2.OAuth2Configuration; @@ -31,6 +32,7 @@ import org.thingsboard.server.dao.oauth2.OAuth2User; import org.thingsboard.server.queue.util.TbCoreComponent; import org.thingsboard.server.service.security.model.SecurityUser; +import java.net.URI; import java.util.ArrayList; import java.util.Map; import java.util.Optional; @@ -62,6 +64,12 @@ public class GithubOAuth2ClientMapper extends AbstractOAuth2ClientMapper impleme restTemplateBuilder = restTemplateBuilder.defaultHeader(AUTHORIZATION, "token " + oauth2Token); RestTemplate restTemplate = restTemplateBuilder.build(); + try { + SsrfProtectionValidator.validateUri(new URI(emailUrl)); + } catch (Exception e) { + log.error("SSRF validation failed for GitHub email URL '{}'", emailUrl, e); + throw new RuntimeException("Unable to login. Please contact your Administrator!"); + } GithubEmailsResponse githubEmailsResponse; try { githubEmailsResponse = restTemplate.getForEntity(emailUrl, GithubEmailsResponse.class).getBody(); diff --git a/application/src/main/resources/thingsboard.yml b/application/src/main/resources/thingsboard.yml index 60a158febe..08913b2c9f 100644 --- a/application/src/main/resources/thingsboard.yml +++ b/application/src/main/resources/thingsboard.yml @@ -509,6 +509,10 @@ actors: # Comma-separated list of additional blocked destinations (IPs, CIDR subnets, or hostnames). # Example: "198.51.100.0/24,metadata.tencentyun.com,rancher-metadata" ssrf_additional_blocked_hosts: "${SSRF_ADDITIONAL_BLOCKED_HOSTS:}" + # Comma-separated list of allowed destinations that bypass SSRF blocking (IPs, CIDR subnets, or hostnames). + # Use this when your rule chains need to reach devices on private networks (e.g., 192.168.1.0/24). + # Example: "192.168.1.0/24,10.0.0.0/8,my-internal-service.corp" + ssrf_allowed_hosts: "${SSRF_ALLOWED_HOSTS:}" rpc: # Maximum number of persistent RPC call retries in case of failed request delivery. max_retries: "${ACTORS_RPC_MAX_RETRIES:5}" diff --git a/common/util/src/main/java/org/thingsboard/common/util/SsrfProtectionValidator.java b/common/util/src/main/java/org/thingsboard/common/util/SsrfProtectionValidator.java index 15da77f663..eb89b164a2 100644 --- a/common/util/src/main/java/org/thingsboard/common/util/SsrfProtectionValidator.java +++ b/common/util/src/main/java/org/thingsboard/common/util/SsrfProtectionValidator.java @@ -38,6 +38,7 @@ public class SsrfProtectionValidator { private static final Set BLOCKED_HOSTNAME_SUFFIXES = Set.of(".internal", ".local"); private static volatile AdditionalBlockedHosts additionalBlocked = AdditionalBlockedHosts.EMPTY; + private static volatile AllowedHosts allowedHosts = AllowedHosts.EMPTY; // Well-known cloud metadata endpoints not covered by the JDK checks (isLoopback, isSiteLocal, isLinkLocal) private static final List CLOUD_METADATA_RANGES = List.of( @@ -66,6 +67,13 @@ public class SsrfProtectionValidator { } String hostLower = host.toLowerCase(); + + // Allow-listed hostnames bypass all hostname and IP checks + AllowedHosts currentAllowed = allowedHosts; + if (currentAllowed.hostnames.contains(hostLower)) { + return; + } + if (BLOCKED_HOSTNAMES.contains(hostLower) || additionalBlocked.hostnames.contains(hostLower)) { throwBlockedHost(host); } @@ -98,7 +106,15 @@ public class SsrfProtectionValidator { } } - private static boolean isBlockedAddress(InetAddress address) { + public static boolean isBlockedAddress(InetAddress address) { + // Check allow-list first: allowed addresses bypass all block checks + AllowedHosts currentAllowed = allowedHosts; + for (CidrRange cidr : currentAllowed.cidrRanges) { + if (cidr.contains(address)) { + return false; + } + } + // Covers 127.0.0.0/8 and ::1 if (address.isLoopbackAddress()) { return true; @@ -142,6 +158,10 @@ public class SsrfProtectionValidator { throw new RuntimeException("URI is invalid: host '" + host + "' is not allowed"); } + public static boolean isEnabled() { + return enabled; + } + public static void setEnabled(boolean enabled) { SsrfProtectionValidator.enabled = enabled; } @@ -179,10 +199,42 @@ public class SsrfProtectionValidator { return !entry.isEmpty() && (Character.isDigit(entry.charAt(0)) || entry.contains(":")); } + public static void setAllowedHosts(List entries) { + if (entries == null || entries.isEmpty()) { + allowedHosts = AllowedHosts.EMPTY; + return; + } + List cidrRanges = new ArrayList<>(); + Set hostnames = new HashSet<>(); + for (String entry : entries) { + String trimmed = entry.trim(); + if (trimmed.isEmpty()) { + continue; + } + if (trimmed.contains("/") || isIpLiteral(trimmed)) { + try { + cidrRanges.add(CidrRange.parse(trimmed)); + } catch (Exception e) { + log.warn("Failed to parse allowed CIDR/IP entry '{}': {}", trimmed, e.getMessage()); + } + } else { + hostnames.add(trimmed.toLowerCase()); + } + } + allowedHosts = new AllowedHosts( + Collections.unmodifiableList(cidrRanges), + Collections.unmodifiableSet(hostnames)); + log.info("SSRF allowed hosts configured: {} CIDR range(s), {} hostname(s)", cidrRanges.size(), hostnames.size()); + } + record AdditionalBlockedHosts(List cidrRanges, Set hostnames) { static final AdditionalBlockedHosts EMPTY = new AdditionalBlockedHosts(Collections.emptyList(), Collections.emptySet()); } + record AllowedHosts(List cidrRanges, Set hostnames) { + static final AllowedHosts EMPTY = new AllowedHosts(Collections.emptyList(), Collections.emptySet()); + } + record CidrRange(byte[] network, int prefixLength) { static CidrRange of(String ip, int prefixLength) { diff --git a/common/util/src/test/java/org/thingsboard/common/util/SsrfProtectionValidatorTest.java b/common/util/src/test/java/org/thingsboard/common/util/SsrfProtectionValidatorTest.java index 6cb2d21a9a..52ab865f6c 100644 --- a/common/util/src/test/java/org/thingsboard/common/util/SsrfProtectionValidatorTest.java +++ b/common/util/src/test/java/org/thingsboard/common/util/SsrfProtectionValidatorTest.java @@ -20,10 +20,12 @@ import org.junit.jupiter.api.parallel.ResourceLock; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; +import java.net.InetAddress; import java.net.URI; import java.util.Collections; import java.util.List; +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatNoException; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -335,4 +337,92 @@ public class SsrfProtectionValidatorTest { } } + // --- Allow-list tests --- + + @Test + void testAllowListCidrAllowsPrivateAddress() { + try { + SsrfProtectionValidator.setAllowedHosts(List.of("192.168.1.0/24")); + // 192.168.1.1 is normally blocked (site-local), but allow-listed + assertThatNoException().isThrownBy(() -> SsrfProtectionValidator.validateUri(URI.create("http://192.168.1.1"), true)); + // Other private ranges remain blocked + assertThatThrownBy(() -> SsrfProtectionValidator.validateUri(URI.create("http://10.0.0.1"), true)) + .isInstanceOf(RuntimeException.class) + .hasMessageContaining("URI is invalid"); + } finally { + SsrfProtectionValidator.setAllowedHosts(Collections.emptyList()); + } + } + + @Test + void testAllowListHostnameBypassesSuffixCheck() { + try { + SsrfProtectionValidator.setAllowedHosts(List.of("my-device.local")); + // .local suffix is normally blocked, but allow-listed hostname passes + assertThatNoException().isThrownBy(() -> SsrfProtectionValidator.validateUri(URI.create("http://my-device.local/api"), true)); + // Other .local hostnames remain blocked + assertThatThrownBy(() -> SsrfProtectionValidator.validateUri(URI.create("http://other-device.local/api"), true)) + .isInstanceOf(RuntimeException.class) + .hasMessageContaining("URI is invalid"); + } finally { + SsrfProtectionValidator.setAllowedHosts(Collections.emptyList()); + } + } + + @Test + void testAllowListPrecedenceOverBlockList() { + try { + // Block 8.8.8.0/24 via additional-blocked, but allow 8.8.8.8 via allow-list + SsrfProtectionValidator.setAdditionalBlockedHosts(List.of("8.8.8.0/24")); + SsrfProtectionValidator.setAllowedHosts(List.of("8.8.8.8")); + // Allow-list should win + assertThatNoException().isThrownBy(() -> SsrfProtectionValidator.validateUri(URI.create("https://8.8.8.8"), true)); + // Adjacent IP still blocked + assertThatThrownBy(() -> SsrfProtectionValidator.validateUri(URI.create("https://8.8.8.9"), true)) + .isInstanceOf(RuntimeException.class) + .hasMessageContaining("URI is invalid"); + } finally { + SsrfProtectionValidator.setAdditionalBlockedHosts(Collections.emptyList()); + SsrfProtectionValidator.setAllowedHosts(Collections.emptyList()); + } + } + + @Test + void testIsBlockedAddressPublicApi() throws Exception { + InetAddress loopback = InetAddress.getByName("127.0.0.1"); + assertThat(SsrfProtectionValidator.isBlockedAddress(loopback)).isTrue(); + + InetAddress publicIp = InetAddress.getByName("8.8.8.8"); + assertThat(SsrfProtectionValidator.isBlockedAddress(publicIp)).isFalse(); + + // Allow-listed private address + try { + SsrfProtectionValidator.setAllowedHosts(List.of("10.0.0.0/8")); + InetAddress privateIp = InetAddress.getByName("10.1.2.3"); + assertThat(SsrfProtectionValidator.isBlockedAddress(privateIp)).isFalse(); + } finally { + SsrfProtectionValidator.setAllowedHosts(Collections.emptyList()); + } + } + + @Test + void testIsEnabledAccessor() { + boolean original = SsrfProtectionValidator.isEnabled(); + try { + SsrfProtectionValidator.setEnabled(true); + assertThat(SsrfProtectionValidator.isEnabled()).isTrue(); + SsrfProtectionValidator.setEnabled(false); + assertThat(SsrfProtectionValidator.isEnabled()).isFalse(); + } finally { + SsrfProtectionValidator.setEnabled(original); + } + } + + @Test + void testSetAllowedHostsEmptyAndNull() { + // Should not throw + SsrfProtectionValidator.setAllowedHosts(Collections.emptyList()); + SsrfProtectionValidator.setAllowedHosts(null); + } + } diff --git a/rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/rest/SsrfSafeAddressResolverGroup.java b/rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/rest/SsrfSafeAddressResolverGroup.java new file mode 100644 index 0000000000..f3cd0c83bb --- /dev/null +++ b/rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/rest/SsrfSafeAddressResolverGroup.java @@ -0,0 +1,135 @@ +/** + * Copyright © 2016-2026 The Thingsboard Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.thingsboard.rule.engine.rest; + +import io.netty.resolver.AddressResolver; +import io.netty.resolver.AddressResolverGroup; +import io.netty.resolver.DefaultAddressResolverGroup; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.Promise; +import org.thingsboard.common.util.SsrfProtectionValidator; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.List; +import java.util.stream.Collectors; + +/** + * Custom Netty {@link AddressResolverGroup} that validates every resolved IP address + * against the SSRF block-list at connection time. This eliminates the DNS rebinding + * TOCTOU gap where a hostname resolves to a safe IP during validation but to a + * private/metadata IP when the actual connection is made. + */ +public final class SsrfSafeAddressResolverGroup extends AddressResolverGroup { + + public static final SsrfSafeAddressResolverGroup INSTANCE = new SsrfSafeAddressResolverGroup(); + + private SsrfSafeAddressResolverGroup() { + } + + @Override + protected AddressResolver newResolver(EventExecutor executor) throws Exception { + AddressResolver delegate = DefaultAddressResolverGroup.INSTANCE.getResolver(executor); + return new SsrfValidatingResolver(executor, delegate); + } + + private static final class SsrfValidatingResolver implements AddressResolver { + + private final EventExecutor executor; + private final AddressResolver delegate; + + SsrfValidatingResolver(EventExecutor executor, AddressResolver delegate) { + this.executor = executor; + this.delegate = delegate; + } + + @Override + public boolean isSupported(SocketAddress address) { + return delegate.isSupported(address); + } + + @Override + public boolean isResolved(SocketAddress address) { + return delegate.isResolved(address); + } + + @Override + public Future resolve(SocketAddress address) { + return resolve(address, executor.newPromise()); + } + + @Override + public Future resolve(SocketAddress address, Promise promise) { + delegate.resolve(address).addListener((Future future) -> { + if (!future.isSuccess()) { + promise.tryFailure(future.cause()); + return; + } + InetSocketAddress resolved = future.getNow(); + if (SsrfProtectionValidator.isEnabled() && isBlocked(resolved)) { + promise.tryFailure(new RuntimeException( + "SSRF protection: resolved address " + resolved.getAddress().getHostAddress() + " is blocked")); + } else { + promise.trySuccess(resolved); + } + }); + return promise; + } + + @Override + public Future> resolveAll(SocketAddress address) { + return resolveAll(address, executor.newPromise()); + } + + @Override + public Future> resolveAll(SocketAddress address, Promise> promise) { + delegate.resolveAll(address).addListener((Future> future) -> { + if (!future.isSuccess()) { + promise.tryFailure(future.cause()); + return; + } + List resolved = future.getNow(); + if (!SsrfProtectionValidator.isEnabled()) { + promise.trySuccess(resolved); + return; + } + List safe = resolved.stream() + .filter(addr -> !isBlocked(addr)) + .collect(Collectors.toList()); + if (safe.isEmpty()) { + String host = address instanceof InetSocketAddress isa ? isa.getHostString() : address.toString(); + promise.tryFailure(new RuntimeException( + "SSRF protection: all resolved addresses for " + host + " are blocked")); + } else { + promise.trySuccess(safe); + } + }); + return promise; + } + + @Override + public void close() { + delegate.close(); + } + + private static boolean isBlocked(InetSocketAddress socketAddress) { + InetAddress addr = socketAddress.getAddress(); + return addr != null && SsrfProtectionValidator.isBlockedAddress(addr); + } + } +} diff --git a/rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/rest/TbHttpClient.java b/rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/rest/TbHttpClient.java index 24f88e88a3..c450622746 100644 --- a/rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/rest/TbHttpClient.java +++ b/rule-engine/rule-engine-components/src/main/java/org/thingsboard/rule/engine/rest/TbHttpClient.java @@ -103,6 +103,8 @@ public class TbHttpClient { .build(); HttpClient httpClient = HttpClient.create(connectionProvider) + .resolver(SsrfSafeAddressResolverGroup.INSTANCE) + .followRedirect(false) .runOn(getSharedOrCreateEventLoopGroup(eventLoopGroupShared)) .doOnConnected(c -> c.addHandlerLast(new ReadTimeoutHandler(config.getReadTimeoutMs(), TimeUnit.MILLISECONDS))); diff --git a/rule-engine/rule-engine-components/src/test/java/org/thingsboard/rule/engine/rest/SsrfSafeAddressResolverGroupTest.java b/rule-engine/rule-engine-components/src/test/java/org/thingsboard/rule/engine/rest/SsrfSafeAddressResolverGroupTest.java new file mode 100644 index 0000000000..8d19239ea7 --- /dev/null +++ b/rule-engine/rule-engine-components/src/test/java/org/thingsboard/rule/engine/rest/SsrfSafeAddressResolverGroupTest.java @@ -0,0 +1,150 @@ +/** + * Copyright © 2016-2026 The Thingsboard Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.thingsboard.rule.engine.rest; + +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.resolver.AddressResolver; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.Promise; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.parallel.ResourceLock; +import org.thingsboard.common.util.SsrfProtectionValidator; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +@ResourceLock("SsrfSafeAddressResolverGroupTest") +class SsrfSafeAddressResolverGroupTest { + + private static NioEventLoopGroup eventLoopGroup; + + @BeforeAll + static void setUp() { + eventLoopGroup = new NioEventLoopGroup(1); + } + + @AfterAll + static void tearDown() { + eventLoopGroup.shutdownGracefully(0, 5, TimeUnit.SECONDS); + SsrfProtectionValidator.setEnabled(false); + SsrfProtectionValidator.setAllowedHosts(Collections.emptyList()); + } + + @BeforeEach + void enableSsrf() { + SsrfProtectionValidator.setEnabled(true); + SsrfProtectionValidator.setAllowedHosts(Collections.emptyList()); + } + + @AfterEach + void resetState() { + SsrfProtectionValidator.setAllowedHosts(Collections.emptyList()); + SsrfProtectionValidator.setEnabled(false); + } + + @Test + void isBlockedAddressWorksForLoopback() throws Exception { + assertThat(SsrfProtectionValidator.isBlockedAddress(InetAddress.getByName("127.0.0.1"))).isTrue(); + assertThat(SsrfProtectionValidator.isBlockedAddress(InetAddress.getByName("192.168.1.1"))).isTrue(); + assertThat(SsrfProtectionValidator.isBlockedAddress(InetAddress.getByName("8.8.8.8"))).isFalse(); + } + + @Test + void resolvePublicIpSucceeds() throws Exception { + EventExecutor executor = eventLoopGroup.next(); + AddressResolver resolver = SsrfSafeAddressResolverGroup.INSTANCE.getResolver(executor); + Promise promise = executor.newPromise(); + + executor.submit(() -> resolver.resolve(InetSocketAddress.createUnresolved("example.com", 80), promise)); + InetSocketAddress result = promise.get(10, TimeUnit.SECONDS); + + assertThat(result.getAddress()).isNotNull(); + assertThat(result.getAddress().isLoopbackAddress()).isFalse(); + assertThat(result.getAddress().isSiteLocalAddress()).isFalse(); + } + + @Test + void resolveLoopbackFailsWhenSsrfEnabled() throws Exception { + assertThat(SsrfProtectionValidator.isEnabled()).isTrue(); + + EventExecutor executor = eventLoopGroup.next(); + AddressResolver resolver = SsrfSafeAddressResolverGroup.INSTANCE.getResolver(executor); + Promise promise = executor.newPromise(); + + executor.submit(() -> resolver.resolve(InetSocketAddress.createUnresolved("127.0.0.1", 80), promise)); + + assertThatThrownBy(() -> promise.get(10, TimeUnit.SECONDS)) + .isInstanceOf(ExecutionException.class) + .hasRootCauseInstanceOf(RuntimeException.class) + .rootCause().hasMessageContaining("SSRF protection"); + } + + @Test + void resolvePrivateIpFailsWhenSsrfEnabled() throws Exception { + assertThat(SsrfProtectionValidator.isEnabled()).isTrue(); + + EventExecutor executor = eventLoopGroup.next(); + AddressResolver resolver = SsrfSafeAddressResolverGroup.INSTANCE.getResolver(executor); + Promise promise = executor.newPromise(); + + executor.submit(() -> resolver.resolve(InetSocketAddress.createUnresolved("192.168.1.1", 80), promise)); + + assertThatThrownBy(() -> promise.get(10, TimeUnit.SECONDS)) + .isInstanceOf(ExecutionException.class) + .hasRootCauseInstanceOf(RuntimeException.class) + .rootCause().hasMessageContaining("SSRF protection"); + } + + @Test + void resolveAllowedPrivateIpSucceeds() throws Exception { + SsrfProtectionValidator.setAllowedHosts(List.of("192.168.1.0/24")); + + EventExecutor executor = eventLoopGroup.next(); + AddressResolver resolver = SsrfSafeAddressResolverGroup.INSTANCE.getResolver(executor); + Promise promise = executor.newPromise(); + + executor.submit(() -> resolver.resolve(InetSocketAddress.createUnresolved("192.168.1.1", 80), promise)); + InetSocketAddress result = promise.get(10, TimeUnit.SECONDS); + + assertThat(result.getAddress().getHostAddress()).isEqualTo("192.168.1.1"); + } + + @Test + void resolveAllReturnsAllWhenSsrfDisabled() throws Exception { + SsrfProtectionValidator.setEnabled(false); + + EventExecutor executor = eventLoopGroup.next(); + AddressResolver resolver = SsrfSafeAddressResolverGroup.INSTANCE.getResolver(executor); + Promise> promise = executor.newPromise(); + + executor.submit(() -> resolver.resolveAll(InetSocketAddress.createUnresolved("127.0.0.1", 80), promise)); + List results = promise.get(10, TimeUnit.SECONDS); + + assertThat(results).isNotEmpty(); + } +}