From 91f05282ea51b355cb2c04948919600cffb73f87 Mon Sep 17 00:00:00 2001 From: vzikratyi Date: Wed, 7 Oct 2020 13:06:22 +0300 Subject: [PATCH] Validate domains for SchemeTypes combinations --- .../server/dao/oauth2/OAuth2ServiceImpl.java | 9 +- .../dao/service/BaseOAuth2ServiceTest.java | 99 ++++++++++++++++++- 2 files changed, 102 insertions(+), 6 deletions(-) diff --git a/dao/src/main/java/org/thingsboard/server/dao/oauth2/OAuth2ServiceImpl.java b/dao/src/main/java/org/thingsboard/server/dao/oauth2/OAuth2ServiceImpl.java index dd287f0755..2ad0ad4f20 100644 --- a/dao/src/main/java/org/thingsboard/server/dao/oauth2/OAuth2ServiceImpl.java +++ b/dao/src/main/java/org/thingsboard/server/dao/oauth2/OAuth2ServiceImpl.java @@ -119,10 +119,17 @@ public class OAuth2ServiceImpl extends AbstractEntityService implements OAuth2Se if (StringUtils.isEmpty(domainInfo.getName())) { throw new DataValidationException("Domain name should be specified!"); } - if (StringUtils.isEmpty(domainInfo.getScheme())) { + if (domainInfo.getScheme() == null) { throw new DataValidationException("Domain scheme should be specified!"); } } + domainParams.getDomainInfos().stream() + .collect(Collectors.groupingBy(DomainInfo::getName)) + .forEach((domainName, domainInfos) -> { + if (domainInfos.size() > 1 && domainInfos.stream().anyMatch(domainInfo -> domainInfo.getScheme() == SchemeType.MIXED)) { + throw new DataValidationException("MIXED scheme type shouldn't be combined with another scheme type!"); + } + }); if (domainParams.getClientRegistrations() == null || domainParams.getClientRegistrations().isEmpty()) { throw new DataValidationException("Client registrations should be specified!"); } diff --git a/dao/src/test/java/org/thingsboard/server/dao/service/BaseOAuth2ServiceTest.java b/dao/src/test/java/org/thingsboard/server/dao/service/BaseOAuth2ServiceTest.java index 2daa0cb337..b54505a3d7 100644 --- a/dao/src/test/java/org/thingsboard/server/dao/service/BaseOAuth2ServiceTest.java +++ b/dao/src/test/java/org/thingsboard/server/dao/service/BaseOAuth2ServiceTest.java @@ -22,6 +22,7 @@ import org.junit.Before; import org.junit.Test; import org.springframework.beans.factory.annotation.Autowired; import org.thingsboard.server.common.data.oauth2.*; +import org.thingsboard.server.dao.exception.DataValidationException; import org.thingsboard.server.dao.oauth2.OAuth2Service; import java.util.*; @@ -45,6 +46,44 @@ public class BaseOAuth2ServiceTest extends AbstractServiceTest { Assert.assertTrue(oAuth2Service.findOAuth2Params().getDomainsParams().isEmpty()); } + @Test(expected = DataValidationException.class) + public void testSaveHttpAndMixedDomainsTogether() { + OAuth2ClientsParams clientsParams = new OAuth2ClientsParams(true, Sets.newHashSet( + OAuth2ClientsDomainParams.builder() + .domainInfos(Sets.newHashSet( + DomainInfo.builder().name("first-domain").scheme(SchemeType.HTTP).build(), + DomainInfo.builder().name("first-domain").scheme(SchemeType.MIXED).build(), + DomainInfo.builder().name("third-domain").scheme(SchemeType.HTTPS).build() + )) + .clientRegistrations(Sets.newHashSet( + validClientRegistrationDto(), + validClientRegistrationDto(), + validClientRegistrationDto() + )) + .build() + )); + oAuth2Service.saveOAuth2Params(clientsParams); + } + + @Test(expected = DataValidationException.class) + public void testSaveHttpsAndMixedDomainsTogether() { + OAuth2ClientsParams clientsParams = new OAuth2ClientsParams(true, Sets.newHashSet( + OAuth2ClientsDomainParams.builder() + .domainInfos(Sets.newHashSet( + DomainInfo.builder().name("first-domain").scheme(SchemeType.HTTPS).build(), + DomainInfo.builder().name("first-domain").scheme(SchemeType.MIXED).build(), + DomainInfo.builder().name("third-domain").scheme(SchemeType.HTTPS).build() + )) + .clientRegistrations(Sets.newHashSet( + validClientRegistrationDto(), + validClientRegistrationDto(), + validClientRegistrationDto() + )) + .build() + )); + oAuth2Service.saveOAuth2Params(clientsParams); + } + @Test public void testCreateAndFindParams() { OAuth2ClientsParams clientsParams = createDefaultClientsParams(); @@ -178,7 +217,7 @@ public class BaseOAuth2ServiceTest extends AbstractServiceTest { Assert.assertTrue(nonExistentDomainClients.isEmpty()); List firstDomainHttpClients = oAuth2Service.getOAuth2Clients("http", "first-domain"); - Assert.assertEquals(firstDomainHttpClients.size(), firstDomainHttpClients.size()); + Assert.assertEquals(firstGroupClientInfos.size(), firstDomainHttpClients.size()); firstGroupClientInfos.forEach(firstGroupClientInfo -> { Assert.assertTrue( firstDomainHttpClients.stream().anyMatch(clientInfo -> @@ -191,7 +230,7 @@ public class BaseOAuth2ServiceTest extends AbstractServiceTest { Assert.assertTrue(firstDomainHttpsClients.isEmpty()); List fourthDomainHttpClients = oAuth2Service.getOAuth2Clients("http", "fourth-domain"); - Assert.assertEquals(fourthDomainHttpClients.size(), secondGroupClientInfos.size()); + Assert.assertEquals(secondGroupClientInfos.size(), fourthDomainHttpClients.size()); secondGroupClientInfos.forEach(secondGroupClientInfo -> { Assert.assertTrue( fourthDomainHttpClients.stream().anyMatch(clientInfo -> @@ -200,7 +239,7 @@ public class BaseOAuth2ServiceTest extends AbstractServiceTest { ); }); List fourthDomainHttpsClients = oAuth2Service.getOAuth2Clients("https", "fourth-domain"); - Assert.assertEquals(fourthDomainHttpsClients.size(), secondGroupClientInfos.size()); + Assert.assertEquals(secondGroupClientInfos.size(), fourthDomainHttpsClients.size()); secondGroupClientInfos.forEach(secondGroupClientInfo -> { Assert.assertTrue( fourthDomainHttpsClients.stream().anyMatch(clientInfo -> @@ -210,7 +249,7 @@ public class BaseOAuth2ServiceTest extends AbstractServiceTest { }); List secondDomainHttpClients = oAuth2Service.getOAuth2Clients("http", "second-domain"); - Assert.assertEquals(secondDomainHttpClients.size(), firstGroupClientInfos.size() + secondGroupClientInfos.size()); + Assert.assertEquals(firstGroupClientInfos.size() + secondGroupClientInfos.size(), secondDomainHttpClients.size()); firstGroupClientInfos.forEach(firstGroupClientInfo -> { Assert.assertTrue( secondDomainHttpClients.stream().anyMatch(clientInfo -> @@ -227,7 +266,7 @@ public class BaseOAuth2ServiceTest extends AbstractServiceTest { }); List secondDomainHttpsClients = oAuth2Service.getOAuth2Clients("https", "second-domain"); - Assert.assertEquals(secondDomainHttpsClients.size(), firstGroupClientInfos.size() + thirdGroupClientInfos.size()); + Assert.assertEquals(firstGroupClientInfos.size() + thirdGroupClientInfos.size(), secondDomainHttpsClients.size()); firstGroupClientInfos.forEach(firstGroupClientInfo -> { Assert.assertTrue( secondDomainHttpsClients.stream().anyMatch(clientInfo -> @@ -244,6 +283,56 @@ public class BaseOAuth2ServiceTest extends AbstractServiceTest { }); } + @Test + public void testGetOAuth2ClientsForHttpAndHttps() { + Set firstGroup = Sets.newHashSet( + validClientRegistrationDto(), + validClientRegistrationDto(), + validClientRegistrationDto(), + validClientRegistrationDto() + ); + OAuth2ClientsParams clientsParams = new OAuth2ClientsParams(true, Sets.newHashSet( + OAuth2ClientsDomainParams.builder() + .domainInfos(Sets.newHashSet( + DomainInfo.builder().name("first-domain").scheme(SchemeType.HTTP).build(), + DomainInfo.builder().name("second-domain").scheme(SchemeType.MIXED).build(), + DomainInfo.builder().name("first-domain").scheme(SchemeType.HTTPS).build() + )) + .clientRegistrations(firstGroup) + .build() + )); + + oAuth2Service.saveOAuth2Params(clientsParams); + OAuth2ClientsParams foundClientsParams = oAuth2Service.findOAuth2Params(); + Assert.assertNotNull(foundClientsParams); + Assert.assertEquals(clientsParams, foundClientsParams); + + List firstGroupClientInfos = firstGroup.stream() + .map(clientRegistrationDto -> new OAuth2ClientInfo( + clientRegistrationDto.getLoginButtonLabel(), clientRegistrationDto.getLoginButtonIcon(), null)) + .collect(Collectors.toList()); + + List firstDomainHttpClients = oAuth2Service.getOAuth2Clients("http", "first-domain"); + Assert.assertEquals(firstGroupClientInfos.size(), firstDomainHttpClients.size()); + firstGroupClientInfos.forEach(firstGroupClientInfo -> { + Assert.assertTrue( + firstDomainHttpClients.stream().anyMatch(clientInfo -> + clientInfo.getIcon().equals(firstGroupClientInfo.getIcon()) + && clientInfo.getName().equals(firstGroupClientInfo.getName())) + ); + }); + + List firstDomainHttpsClients = oAuth2Service.getOAuth2Clients("https", "first-domain"); + Assert.assertEquals(firstGroupClientInfos.size(), firstDomainHttpsClients.size()); + firstGroupClientInfos.forEach(firstGroupClientInfo -> { + Assert.assertTrue( + firstDomainHttpsClients.stream().anyMatch(clientInfo -> + clientInfo.getIcon().equals(firstGroupClientInfo.getIcon()) + && clientInfo.getName().equals(firstGroupClientInfo.getName())) + ); + }); + } + @Test public void testGetDisabledOAuth2Clients() { OAuth2ClientsParams clientsParams = new OAuth2ClientsParams(true, Sets.newHashSet(