Browse Source

fix: add SSRF protection to AI model provider URLs

Apply SsrfProtectionValidator.validateUri() to AI providers with
user-supplied URLs (OpenAI baseUrl, Azure OpenAI endpoint, Ollama
baseUrl). Validation at two layers:
- execution time in Langchain4jChatModelConfigurerImpl
- save time in AiModelDataValidator

Controlled by the existing SSRF_PROTECTION_ENABLED flag.
pull/15412/head
Oleksandra Matviienko 2 months ago
parent
commit
c908d04afb
  1. 9
      application/src/main/java/org/thingsboard/server/service/ai/Langchain4jChatModelConfigurerImpl.java
  2. 32
      application/src/test/java/org/thingsboard/server/controller/AiModelControllerTest.java
  3. 103
      application/src/test/java/org/thingsboard/server/service/ai/Langchain4jChatModelConfigurerImplTest.java
  4. 26
      dao/src/main/java/org/thingsboard/server/dao/service/validator/AiModelDataValidator.java

9
application/src/main/java/org/thingsboard/server/service/ai/Langchain4jChatModelConfigurerImpl.java

@ -37,6 +37,7 @@ import dev.langchain4j.model.openai.OpenAiChatModel;
import dev.langchain4j.model.vertexai.gemini.VertexAiGeminiChatModel;
import org.springframework.http.HttpHeaders;
import org.springframework.stereotype.Component;
import org.thingsboard.common.util.SsrfProtectionValidator;
import org.thingsboard.server.common.data.ai.model.chat.AmazonBedrockChatModelConfig;
import org.thingsboard.server.common.data.ai.model.chat.AnthropicChatModelConfig;
import org.thingsboard.server.common.data.ai.model.chat.AzureOpenAiChatModelConfig;
@ -58,6 +59,7 @@ import software.amazon.awssdk.services.bedrockruntime.BedrockRuntimeClient;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Base64;
@ -69,6 +71,7 @@ class Langchain4jChatModelConfigurerImpl implements Langchain4jChatModelConfigur
@Override
public ChatModel configureChatModel(OpenAiChatModelConfig chatModelConfig) {
validateBaseUrl(chatModelConfig.providerConfig().baseUrl());
return OpenAiChatModel.builder()
.baseUrl(chatModelConfig.providerConfig().baseUrl())
.apiKey(chatModelConfig.providerConfig().apiKey())
@ -86,6 +89,7 @@ class Langchain4jChatModelConfigurerImpl implements Langchain4jChatModelConfigur
@Override
public ChatModel configureChatModel(AzureOpenAiChatModelConfig chatModelConfig) {
AzureOpenAiProviderConfig providerConfig = chatModelConfig.providerConfig();
validateBaseUrl(providerConfig.endpoint());
return AzureOpenAiChatModel.builder()
.endpoint(providerConfig.endpoint())
.serviceVersion(providerConfig.serviceVersion())
@ -273,6 +277,7 @@ class Langchain4jChatModelConfigurerImpl implements Langchain4jChatModelConfigur
@Override
public ChatModel configureChatModel(OllamaChatModelConfig chatModelConfig) {
validateBaseUrl(chatModelConfig.providerConfig().baseUrl());
var builder = OllamaChatModel.builder()
.baseUrl(chatModelConfig.providerConfig().baseUrl())
.modelName(chatModelConfig.modelId())
@ -300,6 +305,10 @@ class Langchain4jChatModelConfigurerImpl implements Langchain4jChatModelConfigur
return builder.build();
}
private static void validateBaseUrl(String url) {
SsrfProtectionValidator.validateUri(URI.create(url));
}
private static Duration toDuration(Integer timeoutSeconds) {
return timeoutSeconds != null ? Duration.ofSeconds(timeoutSeconds) : null;
}

32
application/src/test/java/org/thingsboard/server/controller/AiModelControllerTest.java

@ -19,6 +19,7 @@ import com.datastax.oss.driver.api.core.uuid.Uuids;
import com.fasterxml.jackson.core.type.TypeReference;
import org.junit.Test;
import org.springframework.test.web.servlet.ResultActions;
import org.thingsboard.common.util.SsrfProtectionValidator;
import org.thingsboard.server.common.data.EntityType;
import org.thingsboard.server.common.data.ai.AiModel;
import org.thingsboard.server.common.data.ai.model.chat.AnthropicChatModelConfig;
@ -136,6 +137,37 @@ public class AiModelControllerTest extends AbstractControllerTest {
assertThat(updatedModel.getExternalId()).isNull();
}
@Test
public void saveAiModel_whenBaseUrlIsPrivateIp_shouldReturnBadRequest() throws Exception {
// GIVEN
loginTenantAdmin();
SsrfProtectionValidator.setEnabled(true);
try {
var modelConfig = OpenAiChatModelConfig.builder()
.providerConfig(OpenAiProviderConfig.builder()
.baseUrl("http://172.17.0.1:22/")
.apiKey("test-api-key")
.build())
.modelId("gpt-4o")
.build();
AiModel model = AiModel.builder()
.tenantId(tenantId)
.name("SSRF test model")
.configuration(modelConfig)
.build();
// WHEN
ResultActions result = doPost("/api/ai/model", model);
// THEN
result.andExpect(status().isBadRequest());
} finally {
SsrfProtectionValidator.setEnabled(false);
}
}
/* --- Get by ID API tests --- */
@Test

103
application/src/test/java/org/thingsboard/server/service/ai/Langchain4jChatModelConfigurerImplTest.java

@ -0,0 +1,103 @@
/**
* Copyright © 2016-2026 The Thingsboard Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.thingsboard.server.service.ai;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.parallel.ResourceLock;
import org.thingsboard.common.util.SsrfProtectionValidator;
import org.thingsboard.server.common.data.ai.model.chat.AzureOpenAiChatModelConfig;
import org.thingsboard.server.common.data.ai.model.chat.OllamaChatModelConfig;
import org.thingsboard.server.common.data.ai.model.chat.OpenAiChatModelConfig;
import org.thingsboard.server.common.data.ai.provider.AzureOpenAiProviderConfig;
import org.thingsboard.server.common.data.ai.provider.OllamaProviderConfig;
import org.thingsboard.server.common.data.ai.provider.OpenAiProviderConfig;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
@ResourceLock("SsrfProtectionValidator")
class Langchain4jChatModelConfigurerImplTest {
private final Langchain4jChatModelConfigurerImpl configurer = new Langchain4jChatModelConfigurerImpl();
@BeforeEach
void enableSsrfProtection() {
SsrfProtectionValidator.setEnabled(true);
}
@AfterEach
void disableSsrfProtection() {
SsrfProtectionValidator.setEnabled(false);
}
@Test
void configureChatModel_openAi_withPrivateIp_shouldThrow() {
var config = OpenAiChatModelConfig.builder()
.providerConfig(OpenAiProviderConfig.builder()
.baseUrl("http://172.17.0.1:8080/")
.apiKey("test")
.build())
.modelId("gpt-4o")
.build();
assertThatThrownBy(() -> configurer.configureChatModel(config))
.isInstanceOf(RuntimeException.class)
.hasMessageContaining("URI is invalid");
}
@Test
void configureChatModel_openAi_withLocalhostUrl_shouldThrow() {
var config = OpenAiChatModelConfig.builder()
.providerConfig(OpenAiProviderConfig.builder()
.baseUrl("http://localhost:22/")
.apiKey("test")
.build())
.modelId("gpt-4o")
.build();
assertThatThrownBy(() -> configurer.configureChatModel(config))
.isInstanceOf(RuntimeException.class)
.hasMessageContaining("URI is invalid");
}
@Test
void configureChatModel_azureOpenAi_withPrivateIp_shouldThrow() {
var config = AzureOpenAiChatModelConfig.builder()
.providerConfig(new AzureOpenAiProviderConfig(
"http://10.0.0.1:8080/", null, "test-key"))
.modelId("gpt-4o")
.build();
assertThatThrownBy(() -> configurer.configureChatModel(config))
.isInstanceOf(RuntimeException.class)
.hasMessageContaining("URI is invalid");
}
@Test
void configureChatModel_ollama_withPrivateIp_shouldThrow() {
var config = OllamaChatModelConfig.builder()
.providerConfig(new OllamaProviderConfig(
"http://192.168.1.100:11434/", new OllamaProviderConfig.OllamaAuth.None()))
.modelId("llama3")
.build();
assertThatThrownBy(() -> configurer.configureChatModel(config))
.isInstanceOf(RuntimeException.class)
.hasMessageContaining("URI is invalid");
}
}

26
dao/src/main/java/org/thingsboard/server/dao/service/validator/AiModelDataValidator.java

@ -17,13 +17,19 @@ package org.thingsboard.server.dao.service.validator;
import lombok.RequiredArgsConstructor;
import org.springframework.stereotype.Component;
import org.thingsboard.common.util.SsrfProtectionValidator;
import org.thingsboard.server.common.data.ai.AiModel;
import org.thingsboard.server.common.data.ai.provider.AiProviderConfig;
import org.thingsboard.server.common.data.ai.provider.AzureOpenAiProviderConfig;
import org.thingsboard.server.common.data.ai.provider.OllamaProviderConfig;
import org.thingsboard.server.common.data.ai.provider.OpenAiProviderConfig;
import org.thingsboard.server.common.data.id.TenantId;
import org.thingsboard.server.dao.ai.AiModelDao;
import org.thingsboard.server.dao.exception.DataValidationException;
import org.thingsboard.server.dao.service.DataValidator;
import org.thingsboard.server.dao.tenant.TenantService;
import java.net.URI;
import java.util.Optional;
@Component
@ -64,6 +70,26 @@ class AiModelDataValidator extends DataValidator<AiModel> {
if (!tenantService.tenantExists(tenantId)) {
throw new DataValidationException("AI model reference a non-existent tenant!");
}
// provider URL SSRF validation
if (model.getConfiguration() != null) {
AiProviderConfig providerConfig = model.getConfiguration().providerConfig();
String url = null;
if (providerConfig instanceof OpenAiProviderConfig c) {
url = c.baseUrl();
} else if (providerConfig instanceof AzureOpenAiProviderConfig c) {
url = c.endpoint();
} else if (providerConfig instanceof OllamaProviderConfig c) {
url = c.baseUrl();
}
if (url != null) {
try {
SsrfProtectionValidator.validateUri(URI.create(url));
} catch (Exception e) {
throw new DataValidationException("AI model provider URL is not allowed: " + e.getMessage());
}
}
}
}
}

Loading…
Cancel
Save