@ -15,20 +15,29 @@
* /
* /
package org.thingsboard.server.service.ai ;
package org.thingsboard.server.service.ai ;
import com.google.cloud.vertexai.api.GenerationConfig ;
import dev.langchain4j.model.ModelProvider ;
import dev.langchain4j.model.chat.ChatModel ;
import dev.langchain4j.model.chat.ChatModel ;
import dev.langchain4j.model.chat.request.ChatRequestParameters ;
import org.junit.jupiter.api.AfterEach ;
import org.junit.jupiter.api.AfterEach ;
import org.junit.jupiter.api.BeforeEach ;
import org.junit.jupiter.api.Test ;
import org.junit.jupiter.api.Test ;
import org.junit.jupiter.api.parallel.ResourceLock ;
import org.junit.jupiter.api.parallel.ResourceLock ;
import org.springframework.test.util.ReflectionTestUtils ;
import org.thingsboard.common.util.SsrfProtectionValidator ;
import org.thingsboard.common.util.SsrfProtectionValidator ;
import org.thingsboard.server.common.data.ai.model.chat.AmazonBedrockChatModelConfig ;
import org.thingsboard.server.common.data.ai.model.chat.AnthropicChatModelConfig ;
import org.thingsboard.server.common.data.ai.model.chat.AzureOpenAiChatModelConfig ;
import org.thingsboard.server.common.data.ai.model.chat.AzureOpenAiChatModelConfig ;
import org.thingsboard.server.common.data.ai.model.chat.GitHubModelsChatModelConfig ;
import org.thingsboard.server.common.data.ai.model.chat.GoogleAiGeminiChatModelConfig ;
import org.thingsboard.server.common.data.ai.model.chat.GoogleVertexAiGeminiChatModelConfig ;
import org.thingsboard.server.common.data.ai.model.chat.GoogleVertexAiGeminiChatModelConfig ;
import org.thingsboard.server.common.data.ai.model.chat.MistralAiChatModelConfig ;
import org.thingsboard.server.common.data.ai.model.chat.OllamaChatModelConfig ;
import org.thingsboard.server.common.data.ai.model.chat.OllamaChatModelConfig ;
import org.thingsboard.server.common.data.ai.model.chat.OpenAiChatModelConfig ;
import org.thingsboard.server.common.data.ai.model.chat.OpenAiChatModelConfig ;
import org.thingsboard.server.common.data.ai.provider.AmazonBedrockProviderConfig ;
import org.thingsboard.server.common.data.ai.provider.AnthropicProviderConfig ;
import org.thingsboard.server.common.data.ai.provider.AzureOpenAiProviderConfig ;
import org.thingsboard.server.common.data.ai.provider.AzureOpenAiProviderConfig ;
import org.thingsboard.server.common.data.ai.provider.GitHubModelsProviderConfig ;
import org.thingsboard.server.common.data.ai.provider.GoogleAiGeminiProviderConfig ;
import org.thingsboard.server.common.data.ai.provider.GoogleVertexAiGeminiProviderConfig ;
import org.thingsboard.server.common.data.ai.provider.GoogleVertexAiGeminiProviderConfig ;
import org.thingsboard.server.common.data.ai.provider.MistralAiProviderConfig ;
import org.thingsboard.server.common.data.ai.provider.OllamaProviderConfig ;
import org.thingsboard.server.common.data.ai.provider.OllamaProviderConfig ;
import org.thingsboard.server.common.data.ai.provider.OpenAiProviderConfig ;
import org.thingsboard.server.common.data.ai.provider.OpenAiProviderConfig ;
@ -53,18 +62,280 @@ class Langchain4jChatModelConfigurerImplTest {
private final Langchain4jChatModelConfigurerImpl configurer = new Langchain4jChatModelConfigurerImpl ( ) ;
private final Langchain4jChatModelConfigurerImpl configurer = new Langchain4jChatModelConfigurerImpl ( ) ;
@BeforeEach
void enableSsrfProtection ( ) {
SsrfProtectionValidator . setEnabled ( true ) ;
}
@AfterEach
@AfterEach
void disable SsrfProtection( ) {
void resetSsrfProtection ( ) {
SsrfProtectionValidator . setEnabled ( false ) ;
SsrfProtectionValidator . setEnabled ( false ) ;
}
}
// ============================== Configuration correctness (one per provider) ==============================
// For each provider we feed a fully populated config and assert that the returned ChatModel carries the same
// values, using only the public ChatModel surface (provider() and defaultRequestParameters()) — no reflection.
@Test
void shouldConfigureOpenAiModel_whenGivenOpenAiConfig ( ) {
// GIVEN
var config = OpenAiChatModelConfig . builder ( )
. providerConfig ( OpenAiProviderConfig . builder ( )
. baseUrl ( "https://api.openai.com/v1" )
. apiKey ( "test-key" )
. build ( ) )
. modelId ( "gpt-4o" )
. temperature ( 0 . 7 )
. topP ( 0 . 9 )
. frequencyPenalty ( 0 . 5 )
. presencePenalty ( 0 . 25 )
. maxOutputTokens ( 500 )
. timeoutSeconds ( 60 )
. maxRetries ( 3 )
. build ( ) ;
// WHEN
ChatModel chatModel = configurer . configureChatModel ( config ) ;
// THEN
assertThat ( chatModel . provider ( ) ) . isEqualTo ( ModelProvider . OPEN_AI ) ;
ChatRequestParameters params = chatModel . defaultRequestParameters ( ) ;
assertThat ( params . modelName ( ) ) . isEqualTo ( "gpt-4o" ) ;
assertThat ( params . temperature ( ) ) . isEqualTo ( 0 . 7 ) ;
assertThat ( params . topP ( ) ) . isEqualTo ( 0 . 9 ) ;
assertThat ( params . frequencyPenalty ( ) ) . isEqualTo ( 0 . 5 ) ;
assertThat ( params . presencePenalty ( ) ) . isEqualTo ( 0 . 25 ) ;
assertThat ( params . maxOutputTokens ( ) ) . isEqualTo ( 500 ) ;
}
@Test
void shouldConfigureAzureOpenAiModel_whenGivenAzureOpenAiConfig ( ) {
// GIVEN
var config = AzureOpenAiChatModelConfig . builder ( )
. providerConfig ( new AzureOpenAiProviderConfig (
"https://my-resource.openai.azure.com/" , "2024-05-01-preview" , "test-key" ) )
. modelId ( "gpt-4o" )
. temperature ( 0 . 7 )
. topP ( 0 . 9 )
. frequencyPenalty ( 0 . 5 )
. presencePenalty ( 0 . 25 )
. maxOutputTokens ( 500 )
. timeoutSeconds ( 60 )
. maxRetries ( 3 )
. build ( ) ;
// WHEN
ChatModel chatModel = configurer . configureChatModel ( config ) ;
// THEN
assertThat ( chatModel . provider ( ) ) . isEqualTo ( ModelProvider . AZURE_OPEN_AI ) ;
ChatRequestParameters params = chatModel . defaultRequestParameters ( ) ;
assertThat ( params . modelName ( ) ) . isEqualTo ( "gpt-4o" ) ; // deployment name maps to modelName
assertThat ( params . temperature ( ) ) . isEqualTo ( 0 . 7 ) ;
assertThat ( params . topP ( ) ) . isEqualTo ( 0 . 9 ) ;
assertThat ( params . frequencyPenalty ( ) ) . isEqualTo ( 0 . 5 ) ;
assertThat ( params . presencePenalty ( ) ) . isEqualTo ( 0 . 25 ) ;
assertThat ( params . maxOutputTokens ( ) ) . isEqualTo ( 500 ) ;
}
@Test
void shouldConfigureGoogleAiGeminiModel_whenGivenGoogleAiGeminiConfig ( ) {
// GIVEN
var config = GoogleAiGeminiChatModelConfig . builder ( )
. providerConfig ( new GoogleAiGeminiProviderConfig ( "test-key" ) )
. modelId ( "gemini-2.5-flash" )
. temperature ( 0 . 7 )
. topP ( 0 . 9 )
. topK ( 40 )
. maxOutputTokens ( 500 )
. timeoutSeconds ( 60 )
. maxRetries ( 3 )
. build ( ) ;
// WHEN
ChatModel chatModel = configurer . configureChatModel ( config ) ;
// THEN
assertThat ( chatModel . provider ( ) ) . isEqualTo ( ModelProvider . GOOGLE_GENAI ) ;
ChatRequestParameters params = chatModel . defaultRequestParameters ( ) ;
assertThat ( params . modelName ( ) ) . isEqualTo ( "gemini-2.5-flash" ) ;
assertThat ( params . temperature ( ) ) . isEqualTo ( 0 . 7 ) ;
assertThat ( params . topP ( ) ) . isEqualTo ( 0 . 9 ) ;
assertThat ( params . topK ( ) ) . isEqualTo ( 40 ) ;
assertThat ( params . maxOutputTokens ( ) ) . isEqualTo ( 500 ) ;
}
@Test
void shouldConfigureGoogleVertexAiGeminiModel_whenGivenGoogleVertexAiGeminiConfig ( ) {
// GIVEN
var config = GoogleVertexAiGeminiChatModelConfig . builder ( )
. providerConfig ( new GoogleVertexAiGeminiProviderConfig (
"key.json" , "test-project" , "us-central1" , TEST_SERVICE_ACCOUNT_KEY ) )
. modelId ( "gemini-2.5-flash" )
. temperature ( 0 . 7 )
. topP ( 0 . 9 )
. topK ( 40 )
. maxOutputTokens ( 500 )
. timeoutSeconds ( 60 )
. maxRetries ( 3 )
. build ( ) ;
// WHEN
ChatModel chatModel = configurer . configureChatModel ( config ) ;
// THEN
assertThat ( chatModel . provider ( ) ) . isEqualTo ( ModelProvider . GOOGLE_GENAI ) ;
ChatRequestParameters params = chatModel . defaultRequestParameters ( ) ;
assertThat ( params . modelName ( ) ) . isEqualTo ( "gemini-2.5-flash" ) ;
assertThat ( params . temperature ( ) ) . isEqualTo ( 0 . 7 ) ;
assertThat ( params . topP ( ) ) . isEqualTo ( 0 . 9 ) ;
assertThat ( params . topK ( ) ) . isEqualTo ( 40 ) ;
assertThat ( params . maxOutputTokens ( ) ) . isEqualTo ( 500 ) ;
}
@Test
void shouldConfigureMistralAiModel_whenGivenMistralAiConfig ( ) {
// GIVEN
var config = MistralAiChatModelConfig . builder ( )
. providerConfig ( new MistralAiProviderConfig ( "test-key" ) )
. modelId ( "mistral-large-latest" )
. temperature ( 0 . 7 )
. topP ( 0 . 9 )
. frequencyPenalty ( 0 . 5 )
. presencePenalty ( 0 . 25 )
. maxOutputTokens ( 500 )
. timeoutSeconds ( 60 )
. maxRetries ( 3 )
. build ( ) ;
// WHEN
ChatModel chatModel = configurer . configureChatModel ( config ) ;
// THEN
assertThat ( chatModel . provider ( ) ) . isEqualTo ( ModelProvider . MISTRAL_AI ) ;
ChatRequestParameters params = chatModel . defaultRequestParameters ( ) ;
assertThat ( params . modelName ( ) ) . isEqualTo ( "mistral-large-latest" ) ;
assertThat ( params . temperature ( ) ) . isEqualTo ( 0 . 7 ) ;
assertThat ( params . topP ( ) ) . isEqualTo ( 0 . 9 ) ;
assertThat ( params . frequencyPenalty ( ) ) . isEqualTo ( 0 . 5 ) ;
assertThat ( params . presencePenalty ( ) ) . isEqualTo ( 0 . 25 ) ;
assertThat ( params . maxOutputTokens ( ) ) . isEqualTo ( 500 ) ;
}
@Test
void shouldConfigureAnthropicModel_whenGivenAnthropicConfig ( ) {
// GIVEN
var config = AnthropicChatModelConfig . builder ( )
. providerConfig ( new AnthropicProviderConfig ( "test-key" ) )
. modelId ( "claude-opus-4-8" )
. temperature ( 0 . 7 )
. topP ( 0 . 9 )
. topK ( 40 )
. maxOutputTokens ( 500 )
. timeoutSeconds ( 60 )
. maxRetries ( 3 )
. build ( ) ;
// WHEN
ChatModel chatModel = configurer . configureChatModel ( config ) ;
// THEN
assertThat ( chatModel . provider ( ) ) . isEqualTo ( ModelProvider . ANTHROPIC ) ;
ChatRequestParameters params = chatModel . defaultRequestParameters ( ) ;
assertThat ( params . modelName ( ) ) . isEqualTo ( "claude-opus-4-8" ) ;
assertThat ( params . temperature ( ) ) . isEqualTo ( 0 . 7 ) ;
assertThat ( params . topP ( ) ) . isEqualTo ( 0 . 9 ) ;
assertThat ( params . topK ( ) ) . isEqualTo ( 40 ) ;
assertThat ( params . maxOutputTokens ( ) ) . isEqualTo ( 500 ) ;
}
@Test
@Test
void configureChatModel_openAi_withPrivateIp_shouldThrow ( ) {
void shouldConfigureAmazonBedrockModel_whenGivenAmazonBedrockConfig ( ) {
// GIVEN
var config = AmazonBedrockChatModelConfig . builder ( )
. providerConfig ( new AmazonBedrockProviderConfig (
"us-east-1" , "test-access-key-id" , "test-secret-access-key" ) )
. modelId ( "anthropic.claude-3-5-sonnet-20240620-v1:0" )
. temperature ( 0 . 7 )
. topP ( 0 . 9 )
. maxOutputTokens ( 500 )
. timeoutSeconds ( 60 )
. maxRetries ( 3 )
. build ( ) ;
// WHEN
ChatModel chatModel = configurer . configureChatModel ( config ) ;
// THEN
assertThat ( chatModel . provider ( ) ) . isEqualTo ( ModelProvider . AMAZON_BEDROCK ) ;
ChatRequestParameters params = chatModel . defaultRequestParameters ( ) ;
assertThat ( params . modelName ( ) ) . isEqualTo ( "anthropic.claude-3-5-sonnet-20240620-v1:0" ) ;
assertThat ( params . temperature ( ) ) . isEqualTo ( 0 . 7 ) ;
assertThat ( params . topP ( ) ) . isEqualTo ( 0 . 9 ) ;
assertThat ( params . maxOutputTokens ( ) ) . isEqualTo ( 500 ) ;
}
@Test
void shouldConfigureGitHubModelsModel_whenGivenGitHubModelsConfig ( ) {
// GIVEN
var config = GitHubModelsChatModelConfig . builder ( )
. providerConfig ( new GitHubModelsProviderConfig ( "ghp-test-token" ) )
. modelId ( "gpt-4o" )
. temperature ( 0 . 7 )
. topP ( 0 . 9 )
. frequencyPenalty ( 0 . 5 )
. presencePenalty ( 0 . 25 )
. maxOutputTokens ( 500 )
. timeoutSeconds ( 60 )
. maxRetries ( 3 )
. build ( ) ;
// WHEN
ChatModel chatModel = configurer . configureChatModel ( config ) ;
// THEN
assertThat ( chatModel . provider ( ) ) . isEqualTo ( ModelProvider . GITHUB_MODELS ) ;
ChatRequestParameters params = chatModel . defaultRequestParameters ( ) ;
assertThat ( params . modelName ( ) ) . isEqualTo ( "gpt-4o" ) ;
assertThat ( params . temperature ( ) ) . isEqualTo ( 0 . 7 ) ;
assertThat ( params . topP ( ) ) . isEqualTo ( 0 . 9 ) ;
assertThat ( params . frequencyPenalty ( ) ) . isEqualTo ( 0 . 5 ) ;
assertThat ( params . presencePenalty ( ) ) . isEqualTo ( 0 . 25 ) ;
assertThat ( params . maxOutputTokens ( ) ) . isEqualTo ( 500 ) ; // maxCompletionTokens maps to maxOutputTokens
}
@Test
void shouldConfigureOllamaModel_whenGivenOllamaConfig ( ) {
// GIVEN
var config = OllamaChatModelConfig . builder ( )
. providerConfig ( new OllamaProviderConfig (
"http://localhost:11434" , new OllamaProviderConfig . OllamaAuth . None ( ) ) )
. modelId ( "llama3" )
. temperature ( 0 . 7 )
. topP ( 0 . 9 )
. topK ( 40 )
. contextLength ( 4096 )
. maxOutputTokens ( 500 )
. timeoutSeconds ( 60 )
. maxRetries ( 3 )
. build ( ) ;
// WHEN
ChatModel chatModel = configurer . configureChatModel ( config ) ;
// THEN
assertThat ( chatModel . provider ( ) ) . isEqualTo ( ModelProvider . OLLAMA ) ;
ChatRequestParameters params = chatModel . defaultRequestParameters ( ) ;
assertThat ( params . modelName ( ) ) . isEqualTo ( "llama3" ) ;
assertThat ( params . temperature ( ) ) . isEqualTo ( 0 . 7 ) ;
assertThat ( params . topP ( ) ) . isEqualTo ( 0 . 9 ) ;
assertThat ( params . topK ( ) ) . isEqualTo ( 40 ) ;
assertThat ( params . maxOutputTokens ( ) ) . isEqualTo ( 500 ) ; // numPredict maps to maxOutputTokens
}
// ============================== Base URL SSRF validation ==============================
// Providers that accept a user-supplied base URL must reject hosts that resolve to private/loopback addresses
// when SSRF protection is enabled.
@Test
void shouldThrow_whenOpenAiBaseUrlIsPrivateIp ( ) {
// GIVEN
SsrfProtectionValidator . setEnabled ( true ) ;
var config = OpenAiChatModelConfig . builder ( )
var config = OpenAiChatModelConfig . builder ( )
. providerConfig ( OpenAiProviderConfig . builder ( )
. providerConfig ( OpenAiProviderConfig . builder ( )
. baseUrl ( "http://172.17.0.1:8080/" )
. baseUrl ( "http://172.17.0.1:8080/" )
@ -73,13 +344,16 @@ class Langchain4jChatModelConfigurerImplTest {
. modelId ( "gpt-4o" )
. modelId ( "gpt-4o" )
. build ( ) ;
. build ( ) ;
// WHEN / THEN
assertThatThrownBy ( ( ) - > configurer . configureChatModel ( config ) )
assertThatThrownBy ( ( ) - > configurer . configureChatModel ( config ) )
. isInstanceOf ( RuntimeException . class )
. isInstanceOf ( RuntimeException . class )
. hasMessageContaining ( "URI is invalid" ) ;
. hasMessageContaining ( "URI is invalid" ) ;
}
}
@Test
@Test
void configureChatModel_openAi_withLocalhostUrl_shouldThrow ( ) {
void shouldThrow_whenOpenAiBaseUrlIsLocalhost ( ) {
// GIVEN
SsrfProtectionValidator . setEnabled ( true ) ;
var config = OpenAiChatModelConfig . builder ( )
var config = OpenAiChatModelConfig . builder ( )
. providerConfig ( OpenAiProviderConfig . builder ( )
. providerConfig ( OpenAiProviderConfig . builder ( )
. baseUrl ( "http://localhost:22/" )
. baseUrl ( "http://localhost:22/" )
@ -88,57 +362,42 @@ class Langchain4jChatModelConfigurerImplTest {
. modelId ( "gpt-4o" )
. modelId ( "gpt-4o" )
. build ( ) ;
. build ( ) ;
// WHEN / THEN
assertThatThrownBy ( ( ) - > configurer . configureChatModel ( config ) )
assertThatThrownBy ( ( ) - > configurer . configureChatModel ( config ) )
. isInstanceOf ( RuntimeException . class )
. isInstanceOf ( RuntimeException . class )
. hasMessageContaining ( "URI is invalid" ) ;
. hasMessageContaining ( "URI is invalid" ) ;
}
}
@Test
@Test
void configureChatModel_azureOpenAi_withPrivateIp_shouldThrow ( ) {
void shouldThrow_whenAzureOpenAiEndpointIsPrivateIp ( ) {
// GIVEN
SsrfProtectionValidator . setEnabled ( true ) ;
var config = AzureOpenAiChatModelConfig . builder ( )
var config = AzureOpenAiChatModelConfig . builder ( )
. providerConfig ( new AzureOpenAiProviderConfig (
. providerConfig ( new AzureOpenAiProviderConfig (
"http://10.0.0.1:8080/" , null , "test-key" ) )
"http://10.0.0.1:8080/" , null , "test-key" ) )
. modelId ( "gpt-4o" )
. modelId ( "gpt-4o" )
. build ( ) ;
. build ( ) ;
// WHEN / THEN
assertThatThrownBy ( ( ) - > configurer . configureChatModel ( config ) )
assertThatThrownBy ( ( ) - > configurer . configureChatModel ( config ) )
. isInstanceOf ( RuntimeException . class )
. isInstanceOf ( RuntimeException . class )
. hasMessageContaining ( "URI is invalid" ) ;
. hasMessageContaining ( "URI is invalid" ) ;
}
}
@Test
@Test
void configureChatModel_ollama_withPrivateIp_shouldThrow ( ) {
void shouldThrow_whenOllamaBaseUrlIsPrivateIp ( ) {
// GIVEN
SsrfProtectionValidator . setEnabled ( true ) ;
var config = OllamaChatModelConfig . builder ( )
var config = OllamaChatModelConfig . builder ( )
. providerConfig ( new OllamaProviderConfig (
. providerConfig ( new OllamaProviderConfig (
"http://192.168.1.100:11434/" , new OllamaProviderConfig . OllamaAuth . None ( ) ) )
"http://192.168.1.100:11434/" , new OllamaProviderConfig . OllamaAuth . None ( ) ) )
. modelId ( "llama3" )
. modelId ( "llama3" )
. build ( ) ;
. build ( ) ;
// WHEN / THEN
assertThatThrownBy ( ( ) - > configurer . configureChatModel ( config ) )
assertThatThrownBy ( ( ) - > configurer . configureChatModel ( config ) )
. isInstanceOf ( RuntimeException . class )
. isInstanceOf ( RuntimeException . class )
. hasMessageContaining ( "URI is invalid" ) ;
. hasMessageContaining ( "URI is invalid" ) ;
}
}
@Test
void configureChatModel_vertexAi_setsFrequencyAndPresencePenaltyFromCorrectConfigFields ( ) {
// GIVEN
var providerConfig = new GoogleVertexAiGeminiProviderConfig (
"test.json" , "test-project" , "us-central1" , TEST_SERVICE_ACCOUNT_KEY
) ;
var chatModelConfig = GoogleVertexAiGeminiChatModelConfig . builder ( )
. providerConfig ( providerConfig )
. modelId ( "gemini-2.0-flash" )
. frequencyPenalty ( 0 . 3 )
. presencePenalty ( 0 . 7 )
. build ( ) ;
// WHEN
ChatModel chatModel = configurer . configureChatModel ( chatModelConfig ) ;
// THEN
var generationConfig = ( GenerationConfig ) ReflectionTestUtils . getField ( chatModel , "generationConfig" ) ;
assertThat ( generationConfig . getFrequencyPenalty ( ) ) . isEqualTo ( 0 . 3f ) ;
assertThat ( generationConfig . getPresencePenalty ( ) ) . isEqualTo ( 0 . 7f ) ;
}
}
}