diff --git a/include/polarssl/asn1write.h b/include/polarssl/asn1write.h index 808bf3abe..d65ff79be 100644 --- a/include/polarssl/asn1write.h +++ b/include/polarssl/asn1write.h @@ -205,6 +205,28 @@ int asn1_write_bitstring( unsigned char **p, unsigned char *start, */ int asn1_write_octet_string( unsigned char **p, unsigned char *start, const unsigned char *buf, size_t size ); + +/** + * \brief Create or find a specific named_data entry for writing in a + * sequence or list based on the OID. If not already in there, + * a new entry is added to the head of the list. + * Warning: Destructive behaviour for the val data! + * + * \param list Pointer to the location of the head of the list to seek + * through (will be updated in case of a new entry) + * \param oid The OID to look for + * \param oid_len Size of the OID + * \param val Data to store (can be NULL if you want to fill it by hand) + * \param val_len Minimum length of the data buffer needed + * + * \return NULL if if there was a memory allocation error, or a pointer + * to the new / existing entry. + */ +asn1_named_data *asn1_store_named_data( asn1_named_data **list, + const char *oid, size_t oid_len, + const unsigned char *val, + size_t val_len ); + #ifdef __cplusplus } #endif diff --git a/library/asn1write.c b/library/asn1write.c index 893841f80..302acc33c 100644 --- a/library/asn1write.c +++ b/library/asn1write.c @@ -29,6 +29,14 @@ #include "polarssl/asn1write.h" +#if defined(POLARSSL_MEMORY_C) +#include "polarssl/memory.h" +#else +#include +#define polarssl_malloc malloc +#define polarssl_free free +#endif + int asn1_write_len( unsigned char **p, unsigned char *start, size_t len ) { if( len < 0x80 ) @@ -290,4 +298,65 @@ int asn1_write_octet_string( unsigned char **p, unsigned char *start, return( len ); } + +asn1_named_data *asn1_store_named_data( asn1_named_data **head, + const char *oid, size_t oid_len, + const unsigned char *val, + size_t val_len ) +{ + asn1_named_data *cur; + + if( ( cur = asn1_find_named_data( *head, oid, oid_len ) ) == NULL ) + { + // Add new entry if not present yet based on OID + // + if( ( cur = polarssl_malloc( sizeof(asn1_named_data) ) ) == NULL ) + return( NULL ); + + memset( cur, 0, sizeof(asn1_named_data) ); + + cur->oid.len = oid_len; + cur->oid.p = polarssl_malloc( oid_len ); + if( cur->oid.p == NULL ) + { + polarssl_free( cur ); + return( NULL ); + } + + cur->val.len = val_len; + cur->val.p = polarssl_malloc( val_len ); + if( cur->val.p == NULL ) + { + polarssl_free( cur->oid.p ); + polarssl_free( cur ); + return( NULL ); + } + + memcpy( cur->oid.p, oid, oid_len ); + + cur->next = *head; + *head = cur; + } + else if( cur->val.len < val_len ) + { + // Enlarge existing value buffer if needed + // + polarssl_free( cur->val.p ); + cur->val.p = NULL; + + cur->val.len = val_len; + cur->val.p = polarssl_malloc( val_len ); + if( cur->val.p == NULL ) + { + polarssl_free( cur->oid.p ); + polarssl_free( cur ); + return( NULL ); + } + } + + if( val != NULL ) + memcpy( cur->val.p, val, val_len ); + + return( cur ); +} #endif diff --git a/library/x509write.c b/library/x509write.c index d025abb05..8966ecb73 100644 --- a/library/x509write.c +++ b/library/x509write.c @@ -198,49 +198,10 @@ static int x509_set_extension( asn1_named_data **head, { asn1_named_data *cur; - if( ( cur = asn1_find_named_data( *head, oid, oid_len ) ) == NULL ) + if( ( cur = asn1_store_named_data( head, oid, oid_len, + NULL, val_len + 1 ) ) == NULL ) { - cur = polarssl_malloc( sizeof(asn1_named_data) ); - if( cur == NULL ) - return( POLARSSL_ERR_X509WRITE_MALLOC_FAILED ); - - memset( cur, 0, sizeof(asn1_named_data) ); - - cur->oid.len = oid_len; - cur->oid.p = polarssl_malloc( oid_len ); - if( cur->oid.p == NULL ) - { - polarssl_free( cur ); - return( POLARSSL_ERR_X509WRITE_MALLOC_FAILED ); - } - - cur->val.len = val_len + 1; - cur->val.p = polarssl_malloc( val_len + 1 ); - if( cur->val.p == NULL ) - { - polarssl_free( cur->oid.p ); - polarssl_free( cur ); - return( POLARSSL_ERR_X509WRITE_MALLOC_FAILED ); - } - - memcpy( cur->oid.p, oid, oid_len ); - - cur->next = *head; - *head = cur; - } - - if( cur->val.len != val_len + 1 ) - { - polarssl_free( cur->val.p ); - - cur->val.len = val_len + 1; - cur->val.p = polarssl_malloc( val_len + 1); - if( cur->val.p == NULL ) - { - polarssl_free( cur->oid.p ); - polarssl_free( cur ); - return( POLARSSL_ERR_X509WRITE_MALLOC_FAILED ); - } + return( POLARSSL_ERR_X509WRITE_MALLOC_FAILED ); } cur->val.p[0] = critical;