@ -21,7 +21,6 @@ import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext ;
import io.netty.channel.ChannelInboundHandlerAdapter ;
import io.netty.handler.codec.mqtt.MqttConnAckMessage ;
import io.netty.handler.codec.mqtt.MqttConnAckVariableHeader ;
import io.netty.handler.codec.mqtt.MqttConnectMessage ;
import io.netty.handler.codec.mqtt.MqttConnectReturnCode ;
import io.netty.handler.codec.mqtt.MqttFixedHeader ;
@ -63,12 +62,12 @@ import org.thingsboard.server.common.transport.adaptor.AdaptorException;
import org.thingsboard.server.common.transport.auth.SessionInfoCreator ;
import org.thingsboard.server.common.transport.auth.TransportDeviceInfo ;
import org.thingsboard.server.common.transport.auth.ValidateDeviceCredentialsResponse ;
import org.thingsboard.server.common.transport.auth.ValidateDeviceProfileCredentialsResponse ;
import org.thingsboard.server.common.transport.service.DefaultTransportService ;
import org.thingsboard.server.common.transport.service.SessionMetaData ;
import org.thingsboard.server.common.transport.util.SslUtil ;
import org.thingsboard.server.gen.transport.TransportProtos ;
import org.thingsboard.server.gen.transport.TransportProtos.ProvisionDeviceResponseMsg ;
import org.thingsboard.server.gen.transport.TransportProtos.ValidateDeviceX509CertRequestMsg ;
import org.thingsboard.server.queue.scheduler.SchedulerComponent ;
import org.thingsboard.server.transport.mqtt.adaptors.MqttTransportAdaptor ;
import org.thingsboard.server.transport.mqtt.session.DeviceSessionCtx ;
@ -80,7 +79,7 @@ import org.thingsboard.server.transport.mqtt.util.ReturnCodeResolver;
import javax.net.ssl.SSLPeerUnverifiedException ;
import java.io.IOException ;
import java.net.InetSocketAddress ;
import java.security.cert.Certificate ;
import java.security.cert.CertificateEncodingException ;
import java.security.cert.X509Certificate ;
import java.util.ArrayList ;
import java.util.Collections ;
@ -90,16 +89,15 @@ import java.util.UUID;
import java.util.concurrent.Callable ;
import java.util.concurrent.ConcurrentHashMap ;
import java.util.concurrent.ConcurrentMap ;
import java.util.concurrent.CountDownLatch ;
import java.util.concurrent.TimeUnit ;
import java.util.regex.Matcher ;
import java.util.regex.Pattern ;
import static com.amazonaws.util.StringUtils.UTF8 ;
import static io.netty.handler.codec.mqtt.MqttMessageType.CONNACK ;
import static io.netty.handler.codec.mqtt.MqttMessageType.CONNECT ;
import static io.netty.handler.codec.mqtt.MqttMessageType.PINGRESP ;
import static io.netty.handler.codec.mqtt.MqttMessageType.SUBACK ;
import static io.netty.handler.codec.mqtt.MqttMessageType.UNSUBACK ;
import static io.netty.handler.codec.mqtt.MqttQoS.AT_LEAST_ONCE ;
import static io.netty.handler.codec.mqtt.MqttQoS.AT_MOST_ONCE ;
import static org.thingsboard.server.common.transport.service.DefaultTransportService.SESSION_EVENT_MSG_CLOSED ;
@ -810,9 +808,9 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
deviceSessionCtx . setProvisionOnly ( true ) ;
ctx . writeAndFlush ( createMqttConnAckMsg ( ReturnCode . SUCCESS , msg ) ) ;
} else {
X509Certificate cert ;
if ( sslHandler ! = null & & ( cert = getX509Certificate ( ) ) ! = null ) {
processX509CertConnect ( ctx , cert , msg ) ;
X509Certificate [ ] chain ;
if ( sslHandler ! = null & & ( chain = getX509Certificate ( ) ) ! = null ) {
processX509CertConnect ( ctx , chain , msg ) ;
} else {
processAuthTokenConnect ( ctx , msg ) ;
}
@ -848,27 +846,82 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
} ) ;
}
private void processX509CertConnect ( ChannelHandlerContext ctx , X509Certificate cert , MqttConnectMessage connectMessage ) {
private void processX509CertConnect ( ChannelHandlerContext ctx , X509Certificate [ ] chain , MqttConnectMessage connectMessage ) {
try {
if ( ! context . isSkipValidityCheckForClientCert ( ) ) {
cert . checkValidity ( ) ;
String deviceCN = SslUtil . parseCommonName ( chain [ 0 ] ) ;
String deviceCertHash = EncryptionUtil . getSha3Hash ( SslUtil . getCertificateString ( chain [ 0 ] ) ) ;
for ( X509Certificate cert : chain ) {
try {
String strCert = SslUtil . getCertificateString ( cert ) ;
String sha3Hash = EncryptionUtil . getSha3Hash ( strCert ) ;
final ValidateDeviceCredentialsResponse [ ] validateDeviceCredentialsResponses = new ValidateDeviceCredentialsResponse [ 1 ] ;
CountDownLatch latch = new CountDownLatch ( 1 ) ;
transportService . process ( DeviceTransportType . MQTT , TransportProtos . ValidateDeviceX509CertRequestMsg . newBuilder ( ) . setHash ( sha3Hash ) . build ( ) ,
new TransportServiceCallback < > ( ) {
@Override
public void onSuccess ( ValidateDeviceCredentialsResponse msg ) {
if ( ! StringUtils . isEmpty ( msg . getCredentials ( ) ) ) {
validateDeviceCredentialsResponses [ 0 ] = msg ;
latch . countDown ( ) ;
} else {
transportService . process ( DeviceTransportType . MQTT ,
TransportProtos . ValidateDeviceProfileX509CertRequestMsg . newBuilder ( ) . setHash ( sha3Hash ) . build ( ) ,
new TransportServiceCallback < > ( ) {
@Override
public void onSuccess ( ValidateDeviceProfileCredentialsResponse msg ) {
if ( msg . isDeviceProfileFound ( ) ) {
transportService . process ( DeviceTransportType . MQTT ,
TransportProtos . UpdateOrCreateDeviceX509CertRequestMsg . newBuilder ( )
. setHash ( deviceCertHash )
. setCommonName ( deviceCN )
. setDeviceProfileIdMSB ( msg . getDeviceProfileId ( ) . getId ( ) . getMostSignificantBits ( ) )
. setDeviceProfileIdLSB ( msg . getDeviceProfileId ( ) . getId ( ) . getLeastSignificantBits ( ) )
. build ( ) ,
new TransportServiceCallback < > ( ) {
@Override
public void onSuccess ( ValidateDeviceCredentialsResponse msg ) {
if ( ! StringUtils . isEmpty ( msg . getCredentials ( ) ) ) {
validateDeviceCredentialsResponses [ 0 ] = msg ;
latch . countDown ( ) ;
}
}
@Override
public void onError ( Throwable e ) {
log . error ( e . getMessage ( ) , e ) ;
latch . countDown ( ) ;
}
}
) ;
} else {
latch . countDown ( ) ;
}
}
@Override
public void onError ( Throwable e ) {
log . error ( e . getMessage ( ) , e ) ;
latch . countDown ( ) ;
}
} ) ;
}
}
@Override
public void onError ( Throwable e ) {
log . error ( e . getMessage ( ) , e ) ;
latch . countDown ( ) ;
}
} ) ;
latch . await ( 10 , TimeUnit . SECONDS ) ;
if ( validateDeviceCredentialsResponses [ 0 ] ! = null & & validateDeviceCredentialsResponses [ 0 ] . hasDeviceInfo ( ) ) {
onValidateDeviceResponse ( validateDeviceCredentialsResponses [ 0 ] , ctx , connectMessage ) ;
break ;
}
} catch ( InterruptedException | CertificateEncodingException e ) {
log . error ( e . getMessage ( ) , e ) ;
}
}
String strCert = SslUtil . getCertificateString ( cert ) ;
String sha3Hash = EncryptionUtil . getSha3Hash ( strCert ) ;
transportService . process ( DeviceTransportType . MQTT , ValidateDeviceX509CertRequestMsg . newBuilder ( ) . setHash ( sha3Hash ) . build ( ) ,
new TransportServiceCallback < > ( ) {
@Override
public void onSuccess ( ValidateDeviceCredentialsResponse msg ) {
onValidateDeviceResponse ( msg , ctx , connectMessage ) ;
}
@Override
public void onError ( Throwable e ) {
log . trace ( "[{}] Failed to process credentials: {}" , address , sha3Hash , e ) ;
ctx . writeAndFlush ( createMqttConnAckMsg ( ReturnCode . SERVER_UNAVAILABLE_5 , connectMessage ) ) ;
ctx . close ( ) ;
}
} ) ;
} catch ( Exception e ) {
context . onAuthFailure ( address ) ;
ctx . writeAndFlush ( createMqttConnAckMsg ( ReturnCode . NOT_AUTHORIZED_5 , connectMessage ) ) ;
@ -877,17 +930,13 @@ public class MqttTransportHandler extends ChannelInboundHandlerAdapter implement
}
}
private X509Certificate getX509Certificate ( ) {
private X509Certificate [ ] getX509Certificate ( ) {
try {
Certificate [ ] certChain = sslHandler . engine ( ) . getSession ( ) . getPeerCertificates ( ) ;
if ( certChain . length > 0 ) {
return ( X509Certificate ) certChain [ 0 ] ;
}
return ( X509Certificate [ ] ) sslHandler . engine ( ) . getSession ( ) . getPeerCertificates ( ) ;
} catch ( SSLPeerUnverifiedException e ) {
log . warn ( e . getMessage ( ) ) ;
return null ;
}
return null ;
}
private MqttConnAckMessage createMqttConnAckMsg ( ReturnCode returnCode , MqttConnectMessage msg ) {