diff --git a/application/src/main/java/org/thingsboard/server/service/ai/Langchain4jChatModelConfigurerImpl.java b/application/src/main/java/org/thingsboard/server/service/ai/Langchain4jChatModelConfigurerImpl.java index 28d696468c..8e14eff009 100644 --- a/application/src/main/java/org/thingsboard/server/service/ai/Langchain4jChatModelConfigurerImpl.java +++ b/application/src/main/java/org/thingsboard/server/service/ai/Langchain4jChatModelConfigurerImpl.java @@ -37,6 +37,7 @@ import dev.langchain4j.model.openai.OpenAiChatModel; import dev.langchain4j.model.vertexai.gemini.VertexAiGeminiChatModel; import org.springframework.http.HttpHeaders; import org.springframework.stereotype.Component; +import org.thingsboard.common.util.SsrfProtectionValidator; import org.thingsboard.server.common.data.ai.model.chat.AmazonBedrockChatModelConfig; import org.thingsboard.server.common.data.ai.model.chat.AnthropicChatModelConfig; import org.thingsboard.server.common.data.ai.model.chat.AzureOpenAiChatModelConfig; @@ -58,6 +59,7 @@ import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient; import java.io.ByteArrayInputStream; import java.io.IOException; +import java.net.URI; import java.nio.charset.StandardCharsets; import java.time.Duration; import java.util.Base64; @@ -69,6 +71,7 @@ class Langchain4jChatModelConfigurerImpl implements Langchain4jChatModelConfigur @Override public ChatModel configureChatModel(OpenAiChatModelConfig chatModelConfig) { + validateBaseUrl(chatModelConfig.providerConfig().baseUrl()); return OpenAiChatModel.builder() .baseUrl(chatModelConfig.providerConfig().baseUrl()) .apiKey(chatModelConfig.providerConfig().apiKey()) @@ -86,6 +89,7 @@ class Langchain4jChatModelConfigurerImpl implements Langchain4jChatModelConfigur @Override public ChatModel configureChatModel(AzureOpenAiChatModelConfig chatModelConfig) { AzureOpenAiProviderConfig providerConfig = chatModelConfig.providerConfig(); + validateBaseUrl(providerConfig.endpoint()); return AzureOpenAiChatModel.builder() .endpoint(providerConfig.endpoint()) .serviceVersion(providerConfig.serviceVersion()) @@ -273,6 +277,7 @@ class Langchain4jChatModelConfigurerImpl implements Langchain4jChatModelConfigur @Override public ChatModel configureChatModel(OllamaChatModelConfig chatModelConfig) { + validateBaseUrl(chatModelConfig.providerConfig().baseUrl()); var builder = OllamaChatModel.builder() .baseUrl(chatModelConfig.providerConfig().baseUrl()) .modelName(chatModelConfig.modelId()) @@ -300,6 +305,10 @@ class Langchain4jChatModelConfigurerImpl implements Langchain4jChatModelConfigur return builder.build(); } + private static void validateBaseUrl(String url) { + SsrfProtectionValidator.validateUri(URI.create(url)); + } + private static Duration toDuration(Integer timeoutSeconds) { return timeoutSeconds != null ? Duration.ofSeconds(timeoutSeconds) : null; } diff --git a/application/src/test/java/org/thingsboard/server/controller/AiModelControllerTest.java b/application/src/test/java/org/thingsboard/server/controller/AiModelControllerTest.java index 099ef05752..84c5b31d00 100644 --- a/application/src/test/java/org/thingsboard/server/controller/AiModelControllerTest.java +++ b/application/src/test/java/org/thingsboard/server/controller/AiModelControllerTest.java @@ -19,6 +19,7 @@ import com.datastax.oss.driver.api.core.uuid.Uuids; import com.fasterxml.jackson.core.type.TypeReference; import org.junit.Test; import org.springframework.test.web.servlet.ResultActions; +import org.thingsboard.common.util.SsrfProtectionValidator; import org.thingsboard.server.common.data.EntityType; import org.thingsboard.server.common.data.ai.AiModel; import org.thingsboard.server.common.data.ai.model.chat.AnthropicChatModelConfig; @@ -136,6 +137,37 @@ public class AiModelControllerTest extends AbstractControllerTest { assertThat(updatedModel.getExternalId()).isNull(); } + @Test + public void saveAiModel_whenBaseUrlIsPrivateIp_shouldReturnBadRequest() throws Exception { + // GIVEN + loginTenantAdmin(); + SsrfProtectionValidator.setEnabled(true); + + try { + var modelConfig = OpenAiChatModelConfig.builder() + .providerConfig(OpenAiProviderConfig.builder() + .baseUrl("http://172.17.0.1:22/") + .apiKey("test-api-key") + .build()) + .modelId("gpt-4o") + .build(); + + AiModel model = AiModel.builder() + .tenantId(tenantId) + .name("SSRF test model") + .configuration(modelConfig) + .build(); + + // WHEN + ResultActions result = doPost("/api/ai/model", model); + + // THEN + result.andExpect(status().isBadRequest()); + } finally { + SsrfProtectionValidator.setEnabled(false); + } + } + /* --- Get by ID API tests --- */ @Test diff --git a/application/src/test/java/org/thingsboard/server/service/ai/Langchain4jChatModelConfigurerImplTest.java b/application/src/test/java/org/thingsboard/server/service/ai/Langchain4jChatModelConfigurerImplTest.java new file mode 100644 index 0000000000..61509cf1a3 --- /dev/null +++ b/application/src/test/java/org/thingsboard/server/service/ai/Langchain4jChatModelConfigurerImplTest.java @@ -0,0 +1,103 @@ +/** + * Copyright © 2016-2026 The Thingsboard Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.thingsboard.server.service.ai; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.parallel.ResourceLock; +import org.thingsboard.common.util.SsrfProtectionValidator; +import org.thingsboard.server.common.data.ai.model.chat.AzureOpenAiChatModelConfig; +import org.thingsboard.server.common.data.ai.model.chat.OllamaChatModelConfig; +import org.thingsboard.server.common.data.ai.model.chat.OpenAiChatModelConfig; +import org.thingsboard.server.common.data.ai.provider.AzureOpenAiProviderConfig; +import org.thingsboard.server.common.data.ai.provider.OllamaProviderConfig; +import org.thingsboard.server.common.data.ai.provider.OpenAiProviderConfig; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +@ResourceLock("SsrfProtectionValidator") +class Langchain4jChatModelConfigurerImplTest { + + private final Langchain4jChatModelConfigurerImpl configurer = new Langchain4jChatModelConfigurerImpl(); + + @BeforeEach + void enableSsrfProtection() { + SsrfProtectionValidator.setEnabled(true); + } + + @AfterEach + void disableSsrfProtection() { + SsrfProtectionValidator.setEnabled(false); + } + + @Test + void configureChatModel_openAi_withPrivateIp_shouldThrow() { + var config = OpenAiChatModelConfig.builder() + .providerConfig(OpenAiProviderConfig.builder() + .baseUrl("http://172.17.0.1:8080/") + .apiKey("test") + .build()) + .modelId("gpt-4o") + .build(); + + assertThatThrownBy(() -> configurer.configureChatModel(config)) + .isInstanceOf(RuntimeException.class) + .hasMessageContaining("URI is invalid"); + } + + @Test + void configureChatModel_openAi_withLocalhostUrl_shouldThrow() { + var config = OpenAiChatModelConfig.builder() + .providerConfig(OpenAiProviderConfig.builder() + .baseUrl("http://localhost:22/") + .apiKey("test") + .build()) + .modelId("gpt-4o") + .build(); + + assertThatThrownBy(() -> configurer.configureChatModel(config)) + .isInstanceOf(RuntimeException.class) + .hasMessageContaining("URI is invalid"); + } + + @Test + void configureChatModel_azureOpenAi_withPrivateIp_shouldThrow() { + var config = AzureOpenAiChatModelConfig.builder() + .providerConfig(new AzureOpenAiProviderConfig( + "http://10.0.0.1:8080/", null, "test-key")) + .modelId("gpt-4o") + .build(); + + assertThatThrownBy(() -> configurer.configureChatModel(config)) + .isInstanceOf(RuntimeException.class) + .hasMessageContaining("URI is invalid"); + } + + @Test + void configureChatModel_ollama_withPrivateIp_shouldThrow() { + var config = OllamaChatModelConfig.builder() + .providerConfig(new OllamaProviderConfig( + "http://192.168.1.100:11434/", new OllamaProviderConfig.OllamaAuth.None())) + .modelId("llama3") + .build(); + + assertThatThrownBy(() -> configurer.configureChatModel(config)) + .isInstanceOf(RuntimeException.class) + .hasMessageContaining("URI is invalid"); + } + +} diff --git a/dao/src/main/java/org/thingsboard/server/dao/service/validator/AiModelDataValidator.java b/dao/src/main/java/org/thingsboard/server/dao/service/validator/AiModelDataValidator.java index 9b138fe279..3d6c420e4c 100644 --- a/dao/src/main/java/org/thingsboard/server/dao/service/validator/AiModelDataValidator.java +++ b/dao/src/main/java/org/thingsboard/server/dao/service/validator/AiModelDataValidator.java @@ -17,13 +17,19 @@ package org.thingsboard.server.dao.service.validator; import lombok.RequiredArgsConstructor; import org.springframework.stereotype.Component; +import org.thingsboard.common.util.SsrfProtectionValidator; import org.thingsboard.server.common.data.ai.AiModel; +import org.thingsboard.server.common.data.ai.provider.AiProviderConfig; +import org.thingsboard.server.common.data.ai.provider.AzureOpenAiProviderConfig; +import org.thingsboard.server.common.data.ai.provider.OllamaProviderConfig; +import org.thingsboard.server.common.data.ai.provider.OpenAiProviderConfig; import org.thingsboard.server.common.data.id.TenantId; import org.thingsboard.server.dao.ai.AiModelDao; import org.thingsboard.server.dao.exception.DataValidationException; import org.thingsboard.server.dao.service.DataValidator; import org.thingsboard.server.dao.tenant.TenantService; +import java.net.URI; import java.util.Optional; @Component @@ -64,6 +70,26 @@ class AiModelDataValidator extends DataValidator { if (!tenantService.tenantExists(tenantId)) { throw new DataValidationException("AI model reference a non-existent tenant!"); } + + // provider URL SSRF validation + if (model.getConfiguration() != null) { + AiProviderConfig providerConfig = model.getConfiguration().providerConfig(); + String url = null; + if (providerConfig instanceof OpenAiProviderConfig c) { + url = c.baseUrl(); + } else if (providerConfig instanceof AzureOpenAiProviderConfig c) { + url = c.endpoint(); + } else if (providerConfig instanceof OllamaProviderConfig c) { + url = c.baseUrl(); + } + if (url != null) { + try { + SsrfProtectionValidator.validateUri(URI.create(url)); + } catch (Exception e) { + throw new DataValidationException("AI model provider URL is not allowed: " + e.getMessage()); + } + } + } } }