diff --git a/include/mbedtls/ssl_internal.h b/include/mbedtls/ssl_internal.h index b9084b437..a34d38521 100644 --- a/include/mbedtls/ssl_internal.h +++ b/include/mbedtls/ssl_internal.h @@ -155,6 +155,9 @@ #define MBEDTLS_SSL_OUT_PAYLOAD_LEN ( MBEDTLS_SSL_PAYLOAD_OVERHEAD + \ ( MBEDTLS_SSL_OUT_CONTENT_LEN ) ) +/* The maximum number of buffered handshake messages. */ +#define MBEDTLS_SSL_MAX_BUFFERED_HS 2 + /* Maximum length we can advertise as our max content length for RFC 6066 max_fragment_length extension negotiation purposes (the lesser of both sizes, if they are unequal.) @@ -313,6 +316,14 @@ struct mbedtls_ssl_handshake_params uint8_t seen_ccs; /*!< Indicates if a CCS message has * been seen in the current flight. */ + struct mbedtls_ssl_hs_buffer + { + uint8_t is_valid : 1; + uint8_t is_fragmented : 1; + uint8_t is_complete : 1; + unsigned char *data; + } hs[MBEDTLS_SSL_MAX_BUFFERED_HS]; + } buffering; #endif /* MBEDTLS_SSL_PROTO_DTLS */ @@ -372,6 +383,8 @@ struct mbedtls_ssl_handshake_params #endif /* MBEDTLS_SSL_ASYNC_PRIVATE */ }; +typedef struct mbedtls_ssl_hs_buffer mbedtls_ssl_hs_buffer; + /* * This structure contains a full set of runtime transform parameters * either in negotiation or active. diff --git a/library/ssl_tls.c b/library/ssl_tls.c index 5e573422e..7e01aa35a 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -167,6 +167,8 @@ static int ssl_get_remaining_payload_in_datagram( mbedtls_ssl_context const *ssl return( (int) remaining ); } +static void ssl_buffering_free( mbedtls_ssl_context *ssl ); + /* * Double the retransmit timeout value, within the allowed range, * returning -1 if the maximum value has already been reached. @@ -3072,6 +3074,9 @@ void mbedtls_ssl_recv_flight_completed( mbedtls_ssl_context *ssl ) /* We don't want to remember CCS's across flight boundaries. */ ssl->handshake->buffering.seen_ccs = 0; + /* Clear future message buffering structure. */ + ssl_buffering_free( ssl ); + /* Cancel timer */ ssl_set_timer( ssl, 0 ); @@ -3747,9 +3752,9 @@ int mbedtls_ssl_prepare_handshake_record( mbedtls_ssl_context *ssl ) void mbedtls_ssl_update_handshake_status( mbedtls_ssl_context *ssl ) { + mbedtls_ssl_handshake_params * const hs = ssl->handshake; - if( ssl->state != MBEDTLS_SSL_HANDSHAKE_OVER && - ssl->handshake != NULL ) + if( ssl->state != MBEDTLS_SSL_HANDSHAKE_OVER && hs != NULL ) { ssl->handshake->update_checksum( ssl, ssl->in_msg, ssl->in_hslen ); } @@ -3759,7 +3764,8 @@ void mbedtls_ssl_update_handshake_status( mbedtls_ssl_context *ssl ) if( ssl->conf->transport == MBEDTLS_SSL_TRANSPORT_DATAGRAM && ssl->handshake != NULL ) { - ssl->handshake->in_msg_seq++; + unsigned offset; + mbedtls_ssl_hs_buffer *hs_buf; /* Clear up handshake reassembly structure, if any. */ if( ssl->handshake->hs_msg != NULL ) @@ -3767,6 +3773,28 @@ void mbedtls_ssl_update_handshake_status( mbedtls_ssl_context *ssl ) mbedtls_free( ssl->handshake->hs_msg ); ssl->handshake->hs_msg = NULL; } + + /* Increment handshake sequence number */ + hs->in_msg_seq++; + + /* + * Clear up handshake buffering and reassembly structure. + */ + + /* Free first entry */ + hs_buf = &hs->buffering.hs[0]; + if( hs_buf->is_valid ) + mbedtls_free( hs_buf->data ); + + /* Shift all other entries */ + for( offset = 0; offset + 1 < MBEDTLS_SSL_MAX_BUFFERED_HS; + offset++, hs_buf++ ) + { + *hs_buf = *(hs_buf + 1); + } + + /* Create a fresh last entry */ + memset( hs_buf, 0, sizeof( mbedtls_ssl_hs_buffer ) ); } #endif } @@ -8286,6 +8314,29 @@ static void ssl_key_cert_free( mbedtls_ssl_key_cert *key_cert ) } #endif /* MBEDTLS_X509_CRT_PARSE_C */ +#if defined(MBEDTLS_SSL_PROTO_DTLS) + +static void ssl_buffering_free( mbedtls_ssl_context *ssl ) +{ + unsigned offset; + mbedtls_ssl_handshake_params * const hs = ssl->handshake; + + if( hs == NULL ) + return; + + for( offset = 0; offset < MBEDTLS_SSL_MAX_BUFFERED_HS; offset++ ) + { + mbedtls_ssl_hs_buffer *hs_buf = &hs->buffering.hs[offset]; + if( hs_buf->is_valid == 1 ) + { + mbedtls_free( hs_buf->data ); + memset( hs_buf, 0, sizeof( mbedtls_ssl_hs_buffer ) ); + } + } +} + +#endif /* MBEDTLS_SSL_PROTO_DTLS */ + void mbedtls_ssl_handshake_free( mbedtls_ssl_context *ssl ) { mbedtls_ssl_handshake_params *handshake = ssl->handshake; @@ -8367,6 +8418,7 @@ void mbedtls_ssl_handshake_free( mbedtls_ssl_context *ssl ) mbedtls_free( handshake->verify_cookie ); mbedtls_free( handshake->hs_msg ); ssl_flight_free( handshake->flight ); + ssl_buffering_free( ssl ); #endif mbedtls_platform_zeroize( handshake,