Browse Source

AI rule node: add frequency penalty for the models that support it

pull/13371/head
Dmytro Skarzhynets 11 months ago
parent
commit
dfe4dea436
No known key found for this signature in database GPG Key ID: 2B51652F224037DF
  1. 8
      application/src/main/java/org/thingsboard/server/service/ai/Langchain4jChatModelConfigurerImpl.java
  2. 1
      common/data/src/main/java/org/thingsboard/server/common/data/ai/model/chat/AzureOpenAiChatModel.java
  3. 1
      common/data/src/main/java/org/thingsboard/server/common/data/ai/model/chat/GitHubModelsChatModel.java
  4. 1
      common/data/src/main/java/org/thingsboard/server/common/data/ai/model/chat/GoogleAiGeminiChatModel.java
  5. 1
      common/data/src/main/java/org/thingsboard/server/common/data/ai/model/chat/GoogleVertexAiGeminiChatModel.java
  6. 1
      common/data/src/main/java/org/thingsboard/server/common/data/ai/model/chat/MistralAiChatModel.java
  7. 1
      common/data/src/main/java/org/thingsboard/server/common/data/ai/model/chat/OpenAiChatModel.java

8
application/src/main/java/org/thingsboard/server/service/ai/Langchain4jChatModelConfigurerImpl.java

@ -61,6 +61,7 @@ class Langchain4jChatModelConfigurerImpl implements Langchain4jChatModelConfigur
.modelName(modelConfig.modelId())
.temperature(modelConfig.temperature())
.topP(modelConfig.topP())
.frequencyPenalty(modelConfig.frequencyPenalty())
.timeout(toDuration(modelConfig.timeoutSeconds()))
.maxRetries(modelConfig.maxRetries())
.build();
@ -74,6 +75,7 @@ class Langchain4jChatModelConfigurerImpl implements Langchain4jChatModelConfigur
.deploymentName(modelConfig.modelId())
.temperature(modelConfig.temperature())
.topP(modelConfig.topP())
.frequencyPenalty(modelConfig.frequencyPenalty())
.timeout(toDuration(modelConfig.timeoutSeconds()))
.maxRetries(modelConfig.maxRetries())
.build();
@ -88,6 +90,7 @@ class Langchain4jChatModelConfigurerImpl implements Langchain4jChatModelConfigur
.temperature(modelConfig.temperature())
.topP(modelConfig.topP())
.topK(modelConfig.topK())
.frequencyPenalty(modelConfig.frequencyPenalty())
.timeout(toDuration(modelConfig.timeoutSeconds()))
.maxRetries(modelConfig.maxRetries())
.build();
@ -153,6 +156,9 @@ class Langchain4jChatModelConfigurerImpl implements Langchain4jChatModelConfigur
if (modelConfig.topK() != null) {
generationConfigBuilder.setTopK(modelConfig.topK());
}
if (modelConfig.frequencyPenalty() != null) {
generationConfigBuilder.setFrequencyPenalty(modelConfig.frequencyPenalty().floatValue());
}
var generationConfig = generationConfigBuilder.build();
// construct generative model instance
@ -177,6 +183,7 @@ class Langchain4jChatModelConfigurerImpl implements Langchain4jChatModelConfigur
.modelName(modelConfig.modelId())
.temperature(modelConfig.temperature())
.topP(modelConfig.topP())
.frequencyPenalty(modelConfig.frequencyPenalty())
.timeout(toDuration(modelConfig.timeoutSeconds()))
.maxRetries(modelConfig.maxRetries())
.build();
@ -232,6 +239,7 @@ class Langchain4jChatModelConfigurerImpl implements Langchain4jChatModelConfigur
.modelName(modelConfig.modelId())
.temperature(modelConfig.temperature())
.topP(modelConfig.topP())
.frequencyPenalty(modelConfig.frequencyPenalty())
.timeout(toDuration(modelConfig.timeoutSeconds()))
.maxRetries(modelConfig.maxRetries())
.build();

1
common/data/src/main/java/org/thingsboard/server/common/data/ai/model/chat/AzureOpenAiChatModel.java

@ -31,6 +31,7 @@ public record AzureOpenAiChatModel(
String modelId,
Double temperature,
Double topP,
Double frequencyPenalty,
Integer timeoutSeconds,
Integer maxRetries
) implements AiChatModelConfig<AzureOpenAiChatModel.Config> {}

1
common/data/src/main/java/org/thingsboard/server/common/data/ai/model/chat/GitHubModelsChatModel.java

@ -31,6 +31,7 @@ public record GitHubModelsChatModel(
String modelId,
Double temperature,
Double topP,
Double frequencyPenalty,
Integer timeoutSeconds,
Integer maxRetries
) implements AiChatModelConfig<GitHubModelsChatModel.Config> {}

1
common/data/src/main/java/org/thingsboard/server/common/data/ai/model/chat/GoogleAiGeminiChatModel.java

@ -32,6 +32,7 @@ public record GoogleAiGeminiChatModel(
Double temperature,
Double topP,
Integer topK,
Double frequencyPenalty,
Integer timeoutSeconds,
Integer maxRetries
) implements AiChatModelConfig<GoogleAiGeminiChatModel.Config> {}

1
common/data/src/main/java/org/thingsboard/server/common/data/ai/model/chat/GoogleVertexAiGeminiChatModel.java

@ -32,6 +32,7 @@ public record GoogleVertexAiGeminiChatModel(
Double temperature,
Double topP,
Integer topK,
Double frequencyPenalty,
Integer timeoutSeconds,
Integer maxRetries
) implements AiChatModelConfig<GoogleVertexAiGeminiChatModel.Config> {}

1
common/data/src/main/java/org/thingsboard/server/common/data/ai/model/chat/MistralAiChatModel.java

@ -31,6 +31,7 @@ public record MistralAiChatModel(
String modelId,
Double temperature,
Double topP,
Double frequencyPenalty,
Integer timeoutSeconds,
Integer maxRetries
) implements AiChatModelConfig<MistralAiChatModel.Config> {}

1
common/data/src/main/java/org/thingsboard/server/common/data/ai/model/chat/OpenAiChatModel.java

@ -31,6 +31,7 @@ public record OpenAiChatModel(
String modelId,
Double temperature,
Double topP,
Double frequencyPenalty,
Integer timeoutSeconds,
Integer maxRetries
) implements AiChatModelConfig<OpenAiChatModel.Config> {}

Loading…
Cancel
Save