@ -17,8 +17,11 @@ package org.thingsboard.rule.engine.ai;
import com.fasterxml.jackson.databind.node.ObjectNode ;
import com.google.common.util.concurrent.FluentFuture ;
import com.google.common.util.concurrent.Futures ;
import dev.langchain4j.data.message.AiMessage ;
import dev.langchain4j.data.message.ImageContent ;
import dev.langchain4j.data.message.SystemMessage ;
import dev.langchain4j.data.message.TextContent ;
import dev.langchain4j.data.message.UserMessage ;
import dev.langchain4j.model.chat.request.ResponseFormat ;
import dev.langchain4j.model.chat.request.ResponseFormatType ;
@ -32,6 +35,7 @@ import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments ;
import org.junit.jupiter.params.provider.MethodSource ;
import org.junit.jupiter.params.provider.ValueSource ;
import org.mockito.ArgumentCaptor ;
import org.mockito.Mock ;
import org.mockito.junit.jupiter.MockitoExtension ;
import org.thingsboard.common.util.JacksonUtil ;
@ -43,6 +47,10 @@ import org.thingsboard.rule.engine.api.RuleEngineAiChatModelService;
import org.thingsboard.rule.engine.api.TbContext ;
import org.thingsboard.rule.engine.api.TbNodeConfiguration ;
import org.thingsboard.rule.engine.api.TbNodeException ;
import org.thingsboard.server.common.data.GeneralFileDescriptor ;
import org.thingsboard.server.common.data.ResourceType ;
import org.thingsboard.server.common.data.TbResource ;
import org.thingsboard.server.common.data.TbResourceDataInfo ;
import org.thingsboard.server.common.data.ai.AiModel ;
import org.thingsboard.server.common.data.ai.model.AiModelConfig ;
import org.thingsboard.server.common.data.ai.model.chat.AnthropicChatModelConfig ;
@ -52,6 +60,7 @@ import org.thingsboard.server.common.data.ai.provider.OpenAiProviderConfig;
import org.thingsboard.server.common.data.id.AiModelId ;
import org.thingsboard.server.common.data.id.DeviceId ;
import org.thingsboard.server.common.data.id.RuleNodeId ;
import org.thingsboard.server.common.data.id.TbResourceId ;
import org.thingsboard.server.common.data.id.TenantId ;
import org.thingsboard.server.common.data.msg.TbNodeConnectionType ;
import org.thingsboard.server.common.data.rule.RuleNode ;
@ -59,9 +68,14 @@ import org.thingsboard.server.common.msg.TbMsg;
import org.thingsboard.server.common.msg.TbMsgMetaData ;
import org.thingsboard.server.dao.ai.AiModelService ;
import org.thingsboard.server.dao.exception.DataValidationException ;
import org.thingsboard.server.dao.resource.ResourceService ;
import org.thingsboard.server.dao.resource.TbResourceDataCache ;
import java.util.Base64 ;
import java.util.List ;
import java.util.Map ;
import java.util.Optional ;
import java.util.Set ;
import java.util.UUID ;
import java.util.stream.Stream ;
@ -76,16 +90,23 @@ import static org.mockito.BDDMockito.given;
import static org.mockito.BDDMockito.then ;
import static org.mockito.Mockito.lenient ;
import static org.mockito.Mockito.never ;
import static org.thingsboard.server.common.data.ResourceType.GENERAL ;
@ExtendWith ( MockitoExtension . class )
class TbAiNodeTest {
private static final byte [ ] PNG_IMAGE = Base64 . getDecoder ( ) . decode ( "iVBORw0KGgoAAAANSUhEUgAAAMgAAACgCAMAAAB+IdObAAAC9FBMVEUAAAABAQEBAgICAgICAwMCBAQDAwMDBQUDBgYEBAQEBwcECAgFCQkFCgoGBgYGCwsGDAwHBwcHDQ0HDg4ICAgIDw8IEBAJCQkJEREKEhIKExMLFBQLFRUMFhYMFxcNDQ0NGBgNGRkODg4OGhoOGxsPDw8PHBwPHR0QEBAQHh4QHx8RERERICARISESEhISIiITExMTIyMTJCQUJSUUJiYVKCgWFhYWKSkXFxcXGhwYGBgYLC0ZGRkaMDEaMTIbGxsbMjMcMzQdNTYfOTogICAgOzwiP0AiQEEjIyMjQkMkQ0QnJycnSEkoS0wpKSkrUFErUVIsLCwvV1gvWFkwWlszMzMzYGE1NTU2NjY3Zmc4aWo5OTk5ams5a2w6Ojo6bG07bm88cXI9cnM9c3Q/dndAQEBAeHlBeXpCQkJCe3xCfH1DQ0NEREREf4FFRUVFgIJGg4VHhYdISEhIhohJh4lLi41LjI5MTExMjpBNj5FNkJJOkpRQUFBQlZdRUVFSUlJTU1NTmpxUVFRUnZ9VVVVVnqBWVlZYpadZWVlZp6laqKpbW1tbqatbqqxcXFxcrK5dra9drrBeXl5er7FfsbNfsrRgs7VhYWFiYmJiuLpjubtku71lvL5lvb9mvsBnwcNowsRpxMZpxcdra2tryctsysxubm5uzc9vb29vz9Fw0dNx0tVy1Ndy1dhz1tlz19p0dHR02Nt02dx12t1229523N93d3d33eB33uF5eXl54eR6enp64+Z65Od75eh75ul8fHx85+p86Ot96ex96u2AgICA7vGA7/KB8fSC8/aD9PeD9fiEhISE9/qF+PuF+fyGhoaG+v2G+/6Hh4eH/P+IiIiMjIyNjY2Ojo6QkJCRkZGSkpKTk5Obm5ucnJyfn5+lpaWnp6eoqKipqamqqqqwsLCzs7O1tbW4uLi5ubm6urq7u7u8vLy/v7/BwcHCwsLFxcXGxsbPz8/Y2Nji4uLj4+Pv7+/4+Pj5+fn+/v7/75T///+GLm1tAAAAAWJLR0T7omo23AAABJtJREFUeNrt3Wd8E3UYB/CH0oqm1dJaS5N0IKu0qQSVinXG4gKlKFi3uMC9FVwoVQnQqCBgBVxFnKCoFFFExFGhliWt/zoYLuIMKEpB7b3xuf9dQu+MvAjXcsTf7/PJk/ul1/S+TS53r3KkNFfk0V6evDHbFGruQ3EQTzNVUFxkHOXFB6QbIQiCIAiC/GeSs/QkR6vkCPeUaNUeSUjkkdR1npCp6a7VV7U6P1dbKfNFrS89rJNas/T6rlZtkUS/i2evhw99Q92y9/r7nVzzw7VfeDX3y2qv893plTVb1uW+uw6xiyNpspAQ8bjLy8l5REiImOlUq3Pniunyxw8Ib+vqF7aB5AgdItLVmit0iOgc9W0owhDt1RSAABL3EGeDDqmXhwRXgw6pj3qESFhtgHC1DYSGrJCQjweFq4SEqzkD67zGah8Inay+p1yl4XqKWt2lF69UDxQrzzevXZprrDn2gfTIUs85Iv/oHpny8HKHdugeVZhpXNudu6u6J1P8lmpIX1ys10X6myVfPeLl919UZFi74JXjWtfCecfa5sj+odx908XSg9Taqdaw+3I1QuYLA6RG2AbiEDpE9JJnvcYP1BRhgiw3QuoAASTuIQnP6JCF8hQlcbYBwrWIKgPDIg9UGSGP2QdCnZ+QkDneKQs4swqe1CDJ09RaXfBUETWKm3a+gFMMEMc0+0AoJVX9nM1+VDsCznLurz64b5VWq7nWLLi81QfygYZfNlU7nAUP0nOwrLnGiiAIgiAIgiAIgiDI/zstLS3tMEtKSiycgAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIAAAggggAACCCCAAAIIIIBYAkEQBEEQBEEQBEEQBGmrdLwuyLmhg703km8Z63k7N2Tw0jnqFt/f0bROn69WBYOfbuxiyR+8MXC9vB8QCBTQkEAgMOG2gVyvDmTzdAWuifFp077m8f503vwZr/PSd28Hg+uaTjVDlOFEIxVrINVijfwi4glCHE1XioXPz6kX9xHNFIUkvyM/xqeduIPHup95bGni8edYotOUqJCrrII0iMv4LnNFg4Sczd/9/Zw4abchD0Ygv0pIBVFZG0Nq587lu/PE02EIXSQuaSfI92l88bfNFkHqLxUnEM1+bXQEMloMY8hgn893esyQIzbzWHtveXn51GW89AtfTeyATWZIWm919s6wBtLYdfXdVCyuuEdCHhoxwr/mAzdDtMQKoaP4duQmRVG+kUtyu83X3OuylX09f+9r0c6eOvkjx82fdPdLiHrdjsrD1Z39LP5W06ExQ475g8eqSR6PZ+oXvLSVNWk/nmmGKNcSXaBYBXEPFkMXV1GlhFyYlSof3t19ZOxfPJp+4/HTeh47JhGdqLQxJDtpyRJxBgUi+0g7QkYSlVsHoVtFrcNiyO0SsoXHDxIykej4v/8F+XxDKLRxmXWQfo2jyGJIh894PDs9FArNeIGXvlwbCn37Upl5rXObOMPtf1K4z5u8ne/sx0tl6hbfgtNkBEGQPZs4uUBwTxoTH5DxtM0TD46+20lpHrfXX7e52/jtyj9kFKbIT2L3FQAAAABJRU5ErkJggg==" ) ;
@Mock
TbContext ctxMock ;
@Mock
AiModelService aiModelServiceMock ;
@Mock
RuleEngineAiChatModelService aiChatModelServiceMock ;
@Mock
TbResourceDataCache tbResourceDataCacheMock ;
@Mock
ResourceService resourceServiceMock ;
TbAiNode aiNode ;
TbAiNodeConfiguration config ;
@ -141,6 +162,8 @@ class TbAiNodeTest {
lenient ( ) . when ( ctxMock . getAiModelService ( ) ) . thenReturn ( aiModelServiceMock ) ;
lenient ( ) . when ( ctxMock . getAiChatModelService ( ) ) . thenReturn ( aiChatModelServiceMock ) ;
lenient ( ) . when ( ctxMock . getDbCallbackExecutor ( ) ) . thenReturn ( new TestDbCallbackExecutor ( ) ) ;
lenient ( ) . when ( ctxMock . getTbResourceDataCache ( ) ) . thenReturn ( tbResourceDataCacheMock ) ;
lenient ( ) . when ( ctxMock . getResourceService ( ) ) . thenReturn ( resourceServiceMock ) ;
}
@Test
@ -158,6 +181,7 @@ class TbAiNodeTest {
assertThat ( config . getResponseFormat ( ) ) . isEqualTo ( new TbJsonResponseFormat ( ) ) ;
assertThat ( config . getTimeoutSeconds ( ) ) . isEqualTo ( 60 ) ;
assertThat ( config . isForceAck ( ) ) . isTrue ( ) ;
assertThat ( config . getResourceIds ( ) ) . isNull ( ) ;
}
/* -- Node initialization tests -- */
@ -373,6 +397,36 @@ class TbAiNodeTest {
. matches ( e - > ( ( TbNodeException ) e ) . isUnrecoverable ( ) ) ;
}
@Test
void givenNotExistingResources_whenInit_thenThrowsException ( ) {
// GIVEN
config = constructValidConfig ( ) ;
UUID resourceId = UUID . randomUUID ( ) ;
config . setResourceIds ( Set . of ( resourceId ) ) ;
// WHEN-THEN
assertThatThrownBy ( ( ) - > aiNode . init ( ctxMock , new TbNodeConfiguration ( JacksonUtil . valueToTree ( config ) ) ) )
. isInstanceOf ( TbNodeException . class )
. hasMessageContaining ( "[" + tenantId + "] Resource with ID: [" + resourceId + "] was not found" ) ;
}
@Test
void givenResourceOfWrongType_whenInit_thenThrowsException ( ) {
// GIVEN
config = constructValidConfig ( ) ;
UUID resourceId = UUID . randomUUID ( ) ;
config . setResourceIds ( Set . of ( resourceId ) ) ;
// WHEN-THEN
TbResource tbResource = new TbResource ( ) ;
tbResource . setResourceType ( ResourceType . DASHBOARD ) ;
given ( resourceServiceMock . findResourceInfoById ( any ( ) , any ( ) ) ) . willReturn ( tbResource ) ;
assertThatThrownBy ( ( ) - > aiNode . init ( ctxMock , new TbNodeConfiguration ( JacksonUtil . valueToTree ( config ) ) ) )
. isInstanceOf ( TbNodeException . class )
. hasMessageContaining ( "[" + tenantId + "] Resource with ID: [" + resourceId + "] has unsupported resource type: " + ResourceType . DASHBOARD ) ;
}
/* -- Message processing tests -- */
@Test
@ -560,6 +614,166 @@ class TbAiNodeTest {
) ;
}
@Test
void givenSystemPromptAndUserPromptAndResourcesConfigured_whenOnMsg_thenRequestContainsSystemAndUserAndResourceContent ( ) throws TbNodeException {
String systemPrompt = "Respond with valid JSON" ;
String userPrompt = "Tell me a joke" ;
String textData = "Text resource content for AI request." ;
String xmlData = "<?xml version=\"1.0\" encoding=\"UTF-8\"?><test></test>" ;
// GIVEN
config = constructValidConfig ( ) ;
config . setSystemPrompt ( systemPrompt ) ;
config . setUserPrompt ( userPrompt ) ;
UUID resourceId = UUID . randomUUID ( ) ;
UUID resourceId2 = UUID . randomUUID ( ) ;
UUID resourceId3 = UUID . randomUUID ( ) ;
config . setResourceIds ( Set . of ( resourceId , resourceId2 , resourceId3 ) ) ;
// WHEN-THEN
TbResource textResource = buildGeneralResource ( textData . getBytes ( ) , "text/plain" ) ;
TbResource xmlResource = buildGeneralResource ( xmlData . getBytes ( ) , "application/xml" ) ;
TbResource imageResource = buildGeneralResource ( PNG_IMAGE , "image/png" ) ;
given ( resourceServiceMock . findResourceInfoById ( any ( ) , eq ( new TbResourceId ( resourceId ) ) ) ) . willReturn ( textResource ) ;
given ( resourceServiceMock . findResourceInfoById ( any ( ) , eq ( new TbResourceId ( resourceId2 ) ) ) ) . willReturn ( xmlResource ) ;
given ( resourceServiceMock . findResourceInfoById ( any ( ) , eq ( new TbResourceId ( resourceId3 ) ) ) ) . willReturn ( imageResource ) ;
given ( tbResourceDataCacheMock . getResourceDataInfoAsync ( any ( ) , eq ( new TbResourceId ( resourceId ) ) ) ) . willReturn ( FluentFuture . from ( Futures . immediateFuture ( textResource . toResourceDataInfo ( ) ) ) ) ;
given ( tbResourceDataCacheMock . getResourceDataInfoAsync ( any ( ) , eq ( new TbResourceId ( resourceId2 ) ) ) ) . willReturn ( FluentFuture . from ( Futures . immediateFuture ( xmlResource . toResourceDataInfo ( ) ) ) ) ;
given ( tbResourceDataCacheMock . getResourceDataInfoAsync ( any ( ) , eq ( new TbResourceId ( resourceId3 ) ) ) ) . willReturn ( FluentFuture . from ( Futures . immediateFuture ( imageResource . toResourceDataInfo ( ) ) ) ) ;
aiNode . init ( ctxMock , new TbNodeConfiguration ( JacksonUtil . valueToTree ( config ) ) ) ;
var msg = TbMsg . newMsg ( )
. originator ( deviceId )
. data ( TbMsg . EMPTY_JSON_OBJECT )
. metaData ( TbMsgMetaData . EMPTY )
. build ( ) ;
var chatResponse = ChatResponse . builder ( )
. aiMessage ( AiMessage . from ( "{\"type\":\"joke\",\"setup\":\"Why did the scarecrow win an award?\",\"punchline\":\"Because he was outstanding in his field.\"}" ) )
. build ( ) ;
given ( aiChatModelServiceMock . sendChatRequestAsync ( any ( ) , any ( ) ) ) . willReturn ( FluentFuture . from ( immediateFuture ( chatResponse ) ) ) ;
// WHEN
aiNode . onMsg ( ctxMock , msg ) ;
// THEN
then ( aiChatModelServiceMock ) . should ( ) . sendChatRequestAsync ( any ( ) ,
argThat ( actualChatRequest - > {
assertThat ( actualChatRequest . messages ( ) ) . hasSize ( 2 ) ;
assertThat ( actualChatRequest . messages ( ) . get ( 0 ) ) . isEqualTo ( SystemMessage . from ( systemPrompt ) ) ;
assertThat ( ( ( UserMessage ) actualChatRequest . messages ( ) . get ( 1 ) ) . contents ( ) )
. containsAll ( List . of ( new TextContent ( userPrompt ) , new TextContent ( textData ) ,
new TextContent ( xmlData ) , new ImageContent ( Base64 . getEncoder ( ) . encodeToString ( PNG_IMAGE ) , "image/png" ) ) ) ;
return true ;
} )
) ;
}
@Test
void givenNullResource_whenOnMsg_thenRequestContainsSystemAndUserPrompt ( ) throws TbNodeException {
// GIVEN
config = constructValidConfig ( ) ;
UUID resourceId = UUID . randomUUID ( ) ;
config . setResourceIds ( Set . of ( resourceId ) ) ;
// WHEN-THEN
TbResource tbResource = buildGeneralResource ( "Text resource content for AI request." . getBytes ( ) , "text/plain" ) ;
given ( resourceServiceMock . findResourceInfoById ( any ( ) , eq ( new TbResourceId ( resourceId ) ) ) ) . willReturn ( tbResource ) ;
given ( tbResourceDataCacheMock . getResourceDataInfoAsync ( any ( ) , eq ( new TbResourceId ( resourceId ) ) ) ) . willReturn ( FluentFuture . from ( Futures . immediateFuture ( null ) ) ) ;
aiNode . init ( ctxMock , new TbNodeConfiguration ( JacksonUtil . valueToTree ( config ) ) ) ;
var msg = TbMsg . newMsg ( )
. originator ( deviceId )
. data ( TbMsg . EMPTY_JSON_OBJECT )
. metaData ( TbMsgMetaData . EMPTY )
. build ( ) ;
// WHEN
aiNode . onMsg ( ctxMock , msg ) ;
// THEN
then ( aiChatModelServiceMock ) . should ( ) . sendChatRequestAsync ( any ( ) ,
argThat ( actualChatRequest - > {
assertThat ( actualChatRequest . messages ( ) ) . hasSize ( 2 ) ;
assertThat ( actualChatRequest . messages ( ) . get ( 0 ) ) . isEqualTo ( SystemMessage . from ( config . getSystemPrompt ( ) ) ) ;
assertThat ( ( ( UserMessage ) actualChatRequest . messages ( ) . get ( 1 ) ) . contents ( ) )
. containsAll ( List . of ( new TextContent ( config . getUserPrompt ( ) ) ) ) ;
return true ;
} )
) ;
}
@Test
void givenResourceWithNoDescriptor_whenOnMsg_thenEnqueueForTellFailure ( ) throws TbNodeException {
// GIVEN
config = constructValidConfig ( ) ;
UUID resourceId = UUID . randomUUID ( ) ;
config . setResourceIds ( Set . of ( resourceId ) ) ;
// WHEN-THEN
TbResource tbResource = buildGeneralResource ( "Text resource content for AI request." . getBytes ( ) , "text/plain" ) ;
TbResourceDataInfo resourceDataInfo = new TbResourceDataInfo ( tbResource . getData ( ) , null ) ;
given ( resourceServiceMock . findResourceInfoById ( any ( ) , eq ( new TbResourceId ( resourceId ) ) ) ) . willReturn ( tbResource ) ;
given ( tbResourceDataCacheMock . getResourceDataInfoAsync ( any ( ) , eq ( new TbResourceId ( resourceId ) ) ) ) . willReturn ( FluentFuture . from ( Futures . immediateFuture ( resourceDataInfo ) ) ) ;
aiNode . init ( ctxMock , new TbNodeConfiguration ( JacksonUtil . valueToTree ( config ) ) ) ;
var msg = TbMsg . newMsg ( )
. originator ( deviceId )
. data ( TbMsg . EMPTY_JSON_OBJECT )
. metaData ( TbMsgMetaData . EMPTY )
. build ( ) ;
// WHEN
aiNode . onMsg ( ctxMock , msg ) ;
// THEN
var exceptionCaptor = ArgumentCaptor . forClass ( Throwable . class ) ;
then ( ctxMock ) . should ( ) . enqueueForTellFailure ( any ( ) , exceptionCaptor . capture ( ) ) ;
Throwable actualException = exceptionCaptor . getValue ( ) ;
assertThat ( actualException . getMessage ( ) ) . isEqualTo ( "Missing descriptor for resource" ) ;
}
@Test
void givenResourceWithNoMediaType_whenOnMsg_thenEnqueueForTellFailure ( ) throws TbNodeException {
// GIVEN
config = constructValidConfig ( ) ;
UUID resourceId = UUID . randomUUID ( ) ;
config . setResourceIds ( Set . of ( resourceId ) ) ;
// WHEN-THEN
TbResource tbResource = buildGeneralResource ( "Text resource content for AI request." . getBytes ( ) , "text/plain" ) ;
TbResourceDataInfo resourceDataInfo = new TbResourceDataInfo ( tbResource . getData ( ) , JacksonUtil . newObjectNode ( ) ) ;
given ( resourceServiceMock . findResourceInfoById ( any ( ) , eq ( new TbResourceId ( resourceId ) ) ) ) . willReturn ( tbResource ) ;
given ( tbResourceDataCacheMock . getResourceDataInfoAsync ( any ( ) , eq ( new TbResourceId ( resourceId ) ) ) ) . willReturn ( FluentFuture . from ( Futures . immediateFuture ( resourceDataInfo ) ) ) ;
aiNode . init ( ctxMock , new TbNodeConfiguration ( JacksonUtil . valueToTree ( config ) ) ) ;
var msg = TbMsg . newMsg ( )
. originator ( deviceId )
. data ( TbMsg . EMPTY_JSON_OBJECT )
. metaData ( TbMsgMetaData . EMPTY )
. build ( ) ;
// WHEN
aiNode . onMsg ( ctxMock , msg ) ;
// THEN
var exceptionCaptor = ArgumentCaptor . forClass ( Throwable . class ) ;
then ( ctxMock ) . should ( ) . enqueueForTellFailure ( any ( ) , exceptionCaptor . capture ( ) ) ;
Throwable actualException = exceptionCaptor . getValue ( ) ;
assertThat ( actualException . getMessage ( ) ) . isEqualTo ( "Missing mediaType in resource descriptor {}" ) ;
}
@Test
void givenTemplatedPrompts_whenOnMsg_thenRequestContainsSubstitutedMessages ( ) throws TbNodeException {
// GIVEN
@ -950,4 +1164,13 @@ class TbAiNodeTest {
then ( ctxMock ) . should ( never ( ) ) . tellFailure ( any ( ) , any ( ) ) ;
}
private TbResource buildGeneralResource ( byte [ ] data , String mediaType ) {
TbResource tbResource = new TbResource ( ) ;
tbResource . setResourceType ( GENERAL ) ;
GeneralFileDescriptor descriptor = new GeneralFileDescriptor ( mediaType ) ;
tbResource . setDescriptorValue ( descriptor ) ;
tbResource . setData ( data ) ;
return tbResource ;
}
}