diff --git a/include/mbedtls/ssl.h b/include/mbedtls/ssl.h index 51a3a0638..eed441319 100644 --- a/include/mbedtls/ssl.h +++ b/include/mbedtls/ssl.h @@ -349,6 +349,8 @@ #define MBEDTLS_TLS_EXT_SESSION_TICKET 35 +#define MBEDTLS_TLS_EXT_ECJPAKE_KKPP 256 /* experimental */ + #define MBEDTLS_TLS_EXT_RENEGOTIATION_INFO 0xFF01 /* diff --git a/library/ssl_cli.c b/library/ssl_cli.c index ba6a6166d..bab63a711 100644 --- a/library/ssl_cli.c +++ b/library/ssl_cli.c @@ -280,6 +280,48 @@ static void ssl_write_supported_point_formats_ext( mbedtls_ssl_context *ssl, } #endif /* MBEDTLS_ECDH_C || MBEDTLS_ECDSA_C || MBEDTLS_ECJPAKE_C */ +#if defined(MBEDTLS_ECJPAKE_C) +static void ssl_write_ecjpake_kkpp_ext( mbedtls_ssl_context *ssl, + unsigned char *buf, + size_t *olen ) +{ + int ret; + unsigned char *p = buf; + const unsigned char *end = ssl->out_buf + MBEDTLS_SSL_MAX_CONTENT_LEN; + size_t kkpp_len; + + *olen = 0; + + /* Skip costly extension if we can't use EC J-PAKE anyway */ + if( mbedtls_ecjpake_check( &ssl->handshake->ecjpake_ctx ) != 0 ) + return; + + MBEDTLS_SSL_DEBUG_MSG( 3, ( "client hello, adding ecjpake_kkpp extension" ) ); + + if( end - p < 4 ) + { + MBEDTLS_SSL_DEBUG_MSG( 1, ( "buffer too small" ) ); + return; + } + + *p++ = (unsigned char)( ( MBEDTLS_TLS_EXT_ECJPAKE_KKPP >> 8 ) & 0xFF ); + *p++ = (unsigned char)( ( MBEDTLS_TLS_EXT_ECJPAKE_KKPP ) & 0xFF ); + + if( ( ret = mbedtls_ecjpake_write_round_one( &ssl->handshake->ecjpake_ctx, + p + 2, end - p - 2, &kkpp_len, + ssl->conf->f_rng, ssl->conf->p_rng ) ) != 0 ) + { + MBEDTLS_SSL_DEBUG_RET( 1 , "mbedtls_ecjpake_write_round_one", ret ); + return; + } + + *p++ = (unsigned char)( ( kkpp_len >> 8 ) & 0xFF ); + *p++ = (unsigned char)( ( kkpp_len ) & 0xFF ); + + *olen = kkpp_len + 4; +} +#endif /* MBEDTLS_ECJPAKE_C */ + #if defined(MBEDTLS_SSL_MAX_FRAGMENT_LENGTH) static void ssl_write_max_fragment_length_ext( mbedtls_ssl_context *ssl, unsigned char *buf, @@ -781,6 +823,11 @@ static int ssl_write_client_hello( mbedtls_ssl_context *ssl ) ext_len += olen; #endif +#if defined(MBEDTLS_ECJPAKE_C) + ssl_write_ecjpake_kkpp_ext( ssl, p + 2 + ext_len, &olen ); + ext_len += olen; +#endif + #if defined(MBEDTLS_SSL_MAX_FRAGMENT_LENGTH) ssl_write_max_fragment_length_ext( ssl, p + 2 + ext_len, &olen ); ext_len += olen;