Browse Source

Fix key_dictionary race causing cached keyId 0

pull/14536/head
dshvaika 6 months ago
parent
commit
09eb511599
  1. 2
      dao/src/main/java/org/thingsboard/server/dao/model/sqlts/dictionary/KeyDictionaryCompositeKey.java
  2. 5
      dao/src/main/java/org/thingsboard/server/dao/model/sqlts/dictionary/KeyDictionaryEntry.java
  3. 59
      dao/src/main/java/org/thingsboard/server/dao/sqlts/dictionary/JpaKeyDictionaryDao.java
  4. 4
      dao/src/main/java/org/thingsboard/server/dao/sqlts/dictionary/KeyDictionaryRepository.java
  5. 111
      dao/src/test/java/org/thingsboard/server/dao/sqlts/dictionary/KeyDictionaryDaoTest.java

2
dao/src/main/java/org/thingsboard/server/dao/model/sqlts/dictionary/KeyDictionaryCompositeKey.java

@ -25,7 +25,7 @@ import java.io.Serializable;
@Data
@NoArgsConstructor
@AllArgsConstructor
public class KeyDictionaryCompositeKey implements Serializable{
public class KeyDictionaryCompositeKey implements Serializable {
@Transient
private static final long serialVersionUID = -4089175869616037523L;

5
dao/src/main/java/org/thingsboard/server/dao/model/sqlts/dictionary/KeyDictionaryEntry.java

@ -36,8 +36,7 @@ public final class KeyDictionaryEntry {
@Column(name = KEY_COLUMN)
private String key;
@Column(name = KEY_ID_COLUMN, unique = true, columnDefinition = "int")
@Generated
private int keyId;
@Column(name = KEY_ID_COLUMN, unique = true, columnDefinition = "int", insertable = false, updatable = false)
private Integer keyId;
}

59
dao/src/main/java/org/thingsboard/server/dao/sqlts/dictionary/JpaKeyDictionaryDao.java

@ -17,8 +17,6 @@ package org.thingsboard.server.dao.sqlts.dictionary;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.hibernate.exception.ConstraintViolationException;
import org.springframework.dao.DataIntegrityViolationException;
import org.springframework.stereotype.Component;
import org.springframework.transaction.annotation.Propagation;
import org.springframework.transaction.annotation.Transactional;
@ -48,43 +46,32 @@ public class JpaKeyDictionaryDao extends JpaAbstractDaoListeningExecutorService
@Transactional(propagation = Propagation.NOT_SUPPORTED)
@Override
public Integer getOrSaveKeyId(String strKey) {
Integer keyId = keyDictionaryMap.get(strKey);
if (keyId == null) {
Optional<KeyDictionaryEntry> tsKvDictionaryOptional;
tsKvDictionaryOptional = keyDictionaryRepository.findById(new KeyDictionaryCompositeKey(strKey));
if (tsKvDictionaryOptional.isEmpty()) {
creationLock.lock();
try {
keyId = keyDictionaryMap.get(strKey);
if (keyId != null) {
return keyId;
}
tsKvDictionaryOptional = keyDictionaryRepository.findById(new KeyDictionaryCompositeKey(strKey));
if (tsKvDictionaryOptional.isEmpty()) {
KeyDictionaryEntry keyDictionaryEntry = new KeyDictionaryEntry();
keyDictionaryEntry.setKey(strKey);
try {
KeyDictionaryEntry saved = keyDictionaryRepository.save(keyDictionaryEntry);
keyDictionaryMap.put(saved.getKey(), saved.getKeyId());
keyId = saved.getKeyId();
} catch (DataIntegrityViolationException | ConstraintViolationException e) {
tsKvDictionaryOptional = keyDictionaryRepository.findById(new KeyDictionaryCompositeKey(strKey));
KeyDictionaryEntry dictionary = tsKvDictionaryOptional.orElseThrow(() -> new RuntimeException("Failed to get KeyDictionaryEntry entity from DB!"));
keyDictionaryMap.put(dictionary.getKey(), dictionary.getKeyId());
keyId = dictionary.getKeyId();
}
} else {
keyId = tsKvDictionaryOptional.get().getKeyId();
}
} finally {
creationLock.unlock();
Integer cached = keyDictionaryMap.get(strKey);
if (cached != null) {
return cached;
}
creationLock.lock();
try {
Integer keyId = keyDictionaryMap.get(strKey);
if (keyId != null) {
return keyId;
}
keyId = keyDictionaryRepository.upsertAndGetKeyId(strKey);
if (keyId == null || keyId == 0) {
log.warn("upsertAndGetKeyId returned: [{}] for key: [{}], falling back to findById", keyId, strKey);
KeyDictionaryCompositeKey id = new KeyDictionaryCompositeKey(strKey);
Optional<KeyDictionaryEntry> entryOpt = keyDictionaryRepository.findById(id);
if (entryOpt.isEmpty() ||
entryOpt.get().getKeyId() == null ||
entryOpt.get().getKeyId() == 0) {
throw new IllegalStateException("Failed to resolve keyId for string key: " + strKey + " after fallback. keyId: " + keyId);
}
} else {
keyId = tsKvDictionaryOptional.get().getKeyId();
keyDictionaryMap.put(strKey, keyId);
}
keyDictionaryMap.put(strKey, keyId);
return keyId;
} finally {
creationLock.unlock();
}
return keyId;
}
@Override

4
dao/src/main/java/org/thingsboard/server/dao/sqlts/dictionary/KeyDictionaryRepository.java

@ -19,6 +19,7 @@ import org.springframework.data.domain.Page;
import org.springframework.data.domain.Pageable;
import org.springframework.data.jpa.repository.JpaRepository;
import org.springframework.data.jpa.repository.Query;
import org.springframework.data.repository.query.Param;
import org.thingsboard.server.dao.model.sqlts.dictionary.KeyDictionaryCompositeKey;
import org.thingsboard.server.dao.model.sqlts.dictionary.KeyDictionaryEntry;
@ -31,4 +32,7 @@ public interface KeyDictionaryRepository extends JpaRepository<KeyDictionaryEntr
@Query("SELECT e FROM KeyDictionaryEntry e ORDER BY e.keyId ASC")
Page<KeyDictionaryEntry> findAll(Pageable pageable);
@Query(value = "INSERT INTO key_dictionary (key) VALUES (:key) ON CONFLICT (key) DO UPDATE SET key = EXCLUDED.key RETURNING key_id", nativeQuery = true)
Integer upsertAndGetKeyId(@Param("key") String key);
}

111
dao/src/test/java/org/thingsboard/server/dao/sqlts/dictionary/KeyDictionaryDaoTest.java

@ -0,0 +1,111 @@
/**
* Copyright © 2016-2025 The Thingsboard Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.thingsboard.server.dao.sqlts.dictionary;
import org.junit.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.thingsboard.server.dao.dictionary.KeyDictionaryDao;
import org.thingsboard.server.dao.model.sqlts.dictionary.KeyDictionaryCompositeKey;
import org.thingsboard.server.dao.model.sqlts.dictionary.KeyDictionaryEntry;
import org.thingsboard.server.dao.service.AbstractServiceTest;
import org.thingsboard.server.dao.service.DaoSqlTest;
import java.util.Arrays;
import java.util.Optional;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import static org.assertj.core.api.Assertions.assertThat;
@DaoSqlTest
public class KeyDictionaryDaoTest extends AbstractServiceTest {
@Autowired
private KeyDictionaryDao keyDictionaryDao;
@Autowired
private KeyDictionaryRepository keyDictionaryRepository;
private static final String KEY = "testKeyDictionaryDaoTestKey";
@Test
public void testGetOrSaveKeyId_concurrent() throws Exception {
int threads = 8;
ExecutorService executor = Executors.newFixedThreadPool(threads);
CountDownLatch allReady = new CountDownLatch(threads);
CountDownLatch start = new CountDownLatch(1);
CountDownLatch allDone = new CountDownLatch(threads);
Integer[] keyIds = new Integer[threads];
try {
for (int i = 0; i < threads; i++) {
final int idx = i;
executor.submit(() -> {
allReady.countDown();
try {
// wait until all threads are ready
start.await();
// concurrent call
Integer id = keyDictionaryDao.getOrSaveKeyId(KEY);
keyIds[idx] = id;
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
} finally {
allDone.countDown();
}
});
}
// ensure all threads are queued
allReady.await(5, TimeUnit.SECONDS);
// fire the start gun
start.countDown();
// wait for all to finish
allDone.await(10, TimeUnit.SECONDS);
} finally {
executor.shutdownNow();
}
// basic sanity
for (int i = 0; i < threads; i++) {
assertThat(keyIds[i])
.as("keyId[%s]", i)
.isNotNull()
.isGreaterThan(0);
}
// all threads must see the same keyId
int first = keyIds[0];
assertThat(first).isGreaterThan(0);
assertThat(Arrays.stream(keyIds).distinct().count())
.as("all threads should get the same keyId")
.isEqualTo(1);
// DB must have exactly one row for this key and the same id
KeyDictionaryCompositeKey id = new KeyDictionaryCompositeKey(KEY);
Optional<KeyDictionaryEntry> entry = keyDictionaryRepository.findById(id);
assertThat(entry.isPresent()).isTrue();
assertThat(entry.get().getKeyId()).isEqualTo(first);
keyDictionaryRepository.deleteById(id);
}
}
Loading…
Cancel
Save