Skip to content

Commit

Permalink
Merge pull request #513 from robguima/robguima/transfer_tls13_sessions
Browse files Browse the repository at this point in the history
adds TLS13 support for ptls_import() and ptls_export()
  • Loading branch information
kazuho committed Mar 5, 2024
2 parents 628e876 + 483973c commit 703553c
Show file tree
Hide file tree
Showing 3 changed files with 199 additions and 78 deletions.
1 change: 1 addition & 0 deletions include/picotls.h
Original file line number Diff line number Diff line change
Expand Up @@ -1588,6 +1588,7 @@ int ptls_handshake_is_complete(ptls_t *tls);
int ptls_is_psk_handshake(ptls_t *tls);
/**
* return if a ECH handshake was performed, as well as optionally the kem and cipher-suite being used
* FIXME: this function always return false when the TLS session is exported and imported
*/
int ptls_is_ech_handshake(ptls_t *tls, uint8_t *config_id, ptls_hpke_kem_t **kem, ptls_hpke_cipher_suite_t **cipher);
/**
Expand Down
173 changes: 117 additions & 56 deletions lib/picotls.c
Original file line number Diff line number Diff line change
Expand Up @@ -1615,7 +1615,7 @@ static int get_traffic_keys(ptls_aead_algorithm_t *aead, ptls_hash_algorithm_t *
return ret;
}

static int setup_traffic_protection(ptls_t *tls, int is_enc, const char *secret_label, size_t epoch, int skip_notify)
static int setup_traffic_protection(ptls_t *tls, int is_enc, const char *secret_label, size_t epoch, uint64_t seq, int skip_notify)
{
static const char *log_labels[2][4] = {
{NULL, "CLIENT_EARLY_TRAFFIC_SECRET", "CLIENT_HANDSHAKE_TRAFFIC_SECRET", "CLIENT_TRAFFIC_SECRET_0"},
Expand Down Expand Up @@ -1645,7 +1645,7 @@ static int setup_traffic_protection(ptls_t *tls, int is_enc, const char *secret_
if ((ctx->aead = ptls_aead_new(tls->cipher_suite->aead, tls->cipher_suite->hash, is_enc, ctx->secret,
tls->ctx->hkdf_label_prefix__obsolete)) == NULL)
return PTLS_ERROR_NO_MEMORY; /* TODO obtain error from ptls_aead_new */
ctx->seq = 0;
ctx->seq = seq;

PTLS_DEBUGF("[%s] %02x%02x,%02x%02x\n", log_labels[ptls_is_server(tls)][epoch], (unsigned)ctx->secret[0],
(unsigned)ctx->secret[1], (unsigned)ctx->aead->static_iv[0], (unsigned)ctx->aead->static_iv[1]);
Expand All @@ -1664,7 +1664,7 @@ static int commission_handshake_secret(ptls_t *tls)
free(tls->pending_handshake_secret);
tls->pending_handshake_secret = NULL;

return setup_traffic_protection(tls, is_enc, NULL, 2, 1);
return setup_traffic_protection(tls, is_enc, NULL, 2, 0, 1);
}

static void log_client_random(ptls_t *tls)
Expand Down Expand Up @@ -2479,7 +2479,7 @@ static int send_client_hello(ptls_t *tls, ptls_message_emitter_t *emitter, ptls_

if (tls->client.using_early_data) {
assert(!is_second_flight);
if ((ret = setup_traffic_protection(tls, 1, "c e traffic", 1, 0)) != 0)
if ((ret = setup_traffic_protection(tls, 1, "c e traffic", 1, 0, 0)) != 0)
goto Exit;
if ((ret = push_change_cipher_spec(tls, emitter)) != 0)
goto Exit;
Expand Down Expand Up @@ -2795,7 +2795,7 @@ static int client_handle_hello(ptls_t *tls, ptls_message_emitter_t *emitter, ptl

if ((ret = key_schedule_extract(tls->key_schedule, ecdh_secret)) != 0)
goto Exit;
if ((ret = setup_traffic_protection(tls, 0, "s hs traffic", 2, 0)) != 0)
if ((ret = setup_traffic_protection(tls, 0, "s hs traffic", 2, 0, 0)) != 0)
goto Exit;
if (tls->client.using_early_data) {
if ((tls->pending_handshake_secret = malloc(PTLS_MAX_DIGEST_SIZE)) == NULL) {
Expand All @@ -2808,7 +2808,7 @@ static int client_handle_hello(ptls_t *tls, ptls_message_emitter_t *emitter, ptl
(ret = tls->ctx->update_traffic_key->cb(tls->ctx->update_traffic_key, tls, 1, 2, tls->pending_handshake_secret)) != 0)
goto Exit;
} else {
if ((ret = setup_traffic_protection(tls, 1, "c hs traffic", 2, 0)) != 0)
if ((ret = setup_traffic_protection(tls, 1, "c hs traffic", 2, 0, 0)) != 0)
goto Exit;
}

Expand Down Expand Up @@ -3373,7 +3373,7 @@ static int client_handle_finished(ptls_t *tls, ptls_message_emitter_t *emitter,
/* update traffic keys by using messages upto ServerFinished, but commission them after sending ClientFinished */
if ((ret = key_schedule_extract(tls->key_schedule, ptls_iovec_init(NULL, 0))) != 0)
goto Exit;
if ((ret = setup_traffic_protection(tls, 0, "s ap traffic", 3, 0)) != 0)
if ((ret = setup_traffic_protection(tls, 0, "s ap traffic", 3, 0, 0)) != 0)
goto Exit;
if ((ret = derive_secret(tls->key_schedule, send_secret, "c ap traffic")) != 0)
goto Exit;
Expand Down Expand Up @@ -3407,7 +3407,7 @@ static int client_handle_finished(ptls_t *tls, ptls_message_emitter_t *emitter,
ret = send_finished(tls, emitter);

memcpy(tls->traffic_protection.enc.secret, send_secret, sizeof(send_secret));
if ((ret = setup_traffic_protection(tls, 1, NULL, 3, 0)) != 0)
if ((ret = setup_traffic_protection(tls, 1, NULL, 3, 0, 0)) != 0)
goto Exit;

tls->state = PTLS_STATE_CLIENT_POST_HANDSHAKE;
Expand Down Expand Up @@ -4572,7 +4572,7 @@ static int server_handle_hello(ptls_t *tls, ptls_message_emitter_t *emitter, ptl
}
if ((ret = derive_exporter_secret(tls, 1)) != 0)
goto Exit;
if ((ret = setup_traffic_protection(tls, 0, "c e traffic", 1, 0)) != 0)
if ((ret = setup_traffic_protection(tls, 0, "c e traffic", 1, 0, 0)) != 0)
goto Exit;
}

Expand Down Expand Up @@ -4631,7 +4631,7 @@ static int server_handle_hello(ptls_t *tls, ptls_message_emitter_t *emitter, ptl
/* create protection contexts for the handshake */
assert(tls->key_schedule->generation == 1);
key_schedule_extract(tls->key_schedule, ecdh_secret);
if ((ret = setup_traffic_protection(tls, 1, "s hs traffic", 2, 0)) != 0)
if ((ret = setup_traffic_protection(tls, 1, "s hs traffic", 2, 0, 0)) != 0)
goto Exit;
if (tls->pending_handshake_secret != NULL) {
if ((ret = derive_secret(tls->key_schedule, tls->pending_handshake_secret, "c hs traffic")) != 0)
Expand All @@ -4640,7 +4640,7 @@ static int server_handle_hello(ptls_t *tls, ptls_message_emitter_t *emitter, ptl
(ret = tls->ctx->update_traffic_key->cb(tls->ctx->update_traffic_key, tls, 0, 2, tls->pending_handshake_secret)) != 0)
goto Exit;
} else {
if ((ret = setup_traffic_protection(tls, 0, "c hs traffic", 2, 0)) != 0)
if ((ret = setup_traffic_protection(tls, 0, "c hs traffic", 2, 0, 0)) != 0)
goto Exit;
if (ch->psk.early_data_indication)
tls->server.early_data_skipped_bytes = 0;
Expand Down Expand Up @@ -4766,7 +4766,7 @@ static int server_finish_handshake(ptls_t *tls, ptls_message_emitter_t *emitter,
assert(tls->key_schedule->generation == 2);
if ((ret = key_schedule_extract(tls->key_schedule, ptls_iovec_init(NULL, 0))) != 0)
goto Exit;
if ((ret = setup_traffic_protection(tls, 1, "s ap traffic", 3, 0)) != 0)
if ((ret = setup_traffic_protection(tls, 1, "s ap traffic", 3, 0, 0)) != 0)
goto Exit;
if ((ret = derive_secret(tls->key_schedule, tls->server.pending_traffic_secret, "c ap traffic")) != 0)
goto Exit;
Expand Down Expand Up @@ -4827,7 +4827,7 @@ static int server_handle_finished(ptls_t *tls, ptls_iovec_t message)

memcpy(tls->traffic_protection.dec.secret, tls->server.pending_traffic_secret, sizeof(tls->server.pending_traffic_secret));
ptls_clear_memory(tls->server.pending_traffic_secret, sizeof(tls->server.pending_traffic_secret));
if ((ret = setup_traffic_protection(tls, 0, NULL, 3, 0)) != 0)
if ((ret = setup_traffic_protection(tls, 0, NULL, 3, 0, 0)) != 0)
return ret;

ptls__key_schedule_update_hash(tls->key_schedule, message.base, message.len, 0);
Expand All @@ -4847,7 +4847,7 @@ static int update_traffic_key(ptls_t *tls, int is_enc)
"traffic upd", ptls_iovec_init(NULL, 0), NULL)) != 0)
goto Exit;
memcpy(tp->secret, secret, sizeof(secret));
ret = setup_traffic_protection(tls, is_enc, NULL, 3, 1);
ret = setup_traffic_protection(tls, is_enc, NULL, 3, 0, 1);

Exit:
ptls_clear_memory(secret, sizeof(secret));
Expand Down Expand Up @@ -5017,36 +5017,47 @@ ptls_t *ptls_server_new(ptls_context_t *ctx)
return tls;
}

#define export_tls_params(output, is_server, session_reused, protocol_version, cipher, client_random, server_name, \
negotiated_protocol, ver_block) \
do { \
const char *_server_name = (server_name); \
ptls_iovec_t _negotiated_protocol = (negotiated_protocol); \
ptls_buffer_push_block((output), 2, { \
ptls_buffer_push((output), (is_server)); \
ptls_buffer_push((output), (session_reused)); \
ptls_buffer_push16((output), (protocol_version)); \
ptls_buffer_push16((output), (cipher)->id); \
ptls_buffer_pushv((output), (client_random), PTLS_HELLO_RANDOM_SIZE); \
ptls_buffer_push_block((output), 2, { \
size_t len = _server_name != NULL ? strlen(_server_name) : 0; \
ptls_buffer_pushv((output), _server_name, len); \
}); \
ptls_buffer_push_block((output), 2, \
{ ptls_buffer_pushv((output), _negotiated_protocol.base, _negotiated_protocol.len); }); \
ptls_buffer_push_block((output), 2, {ver_block}); /* version-specific block */ \
ptls_buffer_push_block((output), 2, {}); /* for future extensions */ \
}); \
} while (0)

static int export_tls12_params(ptls_buffer_t *output, int is_server, int session_reused, ptls_cipher_suite_t *cipher,
const void *client_random, const char *server_name, ptls_iovec_t negotiated_protocol,
const void *enc_key, const void *enc_iv, uint64_t enc_seq, uint64_t enc_record_iv,
const void *dec_key, const void *dec_iv, uint64_t dec_seq)
{
int ret;

ptls_buffer_push_block(output, 2, {
ptls_buffer_push(output, is_server);
ptls_buffer_push(output, session_reused);
ptls_buffer_push16(output, PTLS_PROTOCOL_VERSION_TLS12);
ptls_buffer_push16(output, cipher->id);
ptls_buffer_pushv(output, client_random, PTLS_HELLO_RANDOM_SIZE);
ptls_buffer_push_block(output, 2, {
size_t len = server_name != NULL ? strlen(server_name) : 0;
ptls_buffer_pushv(output, server_name, len);
});
ptls_buffer_push_block(output, 2, { ptls_buffer_pushv(output, negotiated_protocol.base, negotiated_protocol.len); });
ptls_buffer_push_block(output, 2, {
ptls_buffer_pushv(output, enc_key, cipher->aead->key_size);
ptls_buffer_pushv(output, enc_iv, cipher->aead->tls12.fixed_iv_size);
ptls_buffer_push64(output, enc_seq);
if (cipher->aead->tls12.record_iv_size != 0)
ptls_buffer_push64(output, enc_record_iv);
ptls_buffer_pushv(output, dec_key, cipher->aead->key_size);
ptls_buffer_pushv(output, dec_iv, cipher->aead->tls12.fixed_iv_size);
ptls_buffer_push64(output, dec_seq);
});
ptls_buffer_push_block(output, 2, {}); /* for future extensions */
});
export_tls_params(output, is_server, session_reused, PTLS_PROTOCOL_VERSION_TLS12, cipher, client_random, server_name,
negotiated_protocol, {
ptls_buffer_pushv(output, enc_key, cipher->aead->key_size);
ptls_buffer_pushv(output, enc_iv, cipher->aead->tls12.fixed_iv_size);
ptls_buffer_push64(output, enc_seq);
if (cipher->aead->tls12.record_iv_size != 0)
ptls_buffer_push64(output, enc_record_iv);
ptls_buffer_pushv(output, dec_key, cipher->aead->key_size);
ptls_buffer_pushv(output, dec_iv, cipher->aead->tls12.fixed_iv_size);
ptls_buffer_push64(output, dec_seq);
});
ret = 0;

Exit:
return ret;
Expand Down Expand Up @@ -5094,20 +5105,39 @@ int ptls_build_tls12_export_params(ptls_context_t *ctx, ptls_buffer_t *output, i

int ptls_export(ptls_t *tls, ptls_buffer_t *output)
{
/* TODO add tls13 support */
if (!tls->traffic_protection.enc.tls12)
return PTLS_ERROR_LIBRARY;

ptls_iovec_t negotiated_protocol =
ptls_iovec_init(tls->negotiated_protocol, tls->negotiated_protocol != NULL ? strlen(tls->negotiated_protocol) : 0);
return export_tls12_params(output, tls->is_server, tls->is_psk_handshake, tls->cipher_suite, tls->client_random,
tls->server_name, negotiated_protocol, tls->traffic_protection.enc.secret,
tls->traffic_protection.enc.secret + PTLS_MAX_SECRET_SIZE, tls->traffic_protection.enc.seq,
tls->traffic_protection.enc.tls12_enc_record_iv, tls->traffic_protection.dec.secret,
tls->traffic_protection.dec.secret + PTLS_MAX_SECRET_SIZE, tls->traffic_protection.dec.seq);
int ret;

if (tls->state != PTLS_STATE_SERVER_POST_HANDSHAKE) {
ret = PTLS_ERROR_LIBRARY;
goto Exit;
}

if (ptls_get_protocol_version(tls) == PTLS_PROTOCOL_VERSION_TLS13) {
export_tls_params(output, tls->is_server, tls->is_psk_handshake, PTLS_PROTOCOL_VERSION_TLS13, tls->cipher_suite,
tls->client_random, tls->server_name, negotiated_protocol, {
ptls_buffer_pushv(output, tls->traffic_protection.enc.secret, tls->cipher_suite->hash->digest_size);
ptls_buffer_push64(output, tls->traffic_protection.enc.seq);
ptls_buffer_pushv(output, tls->traffic_protection.dec.secret, tls->cipher_suite->hash->digest_size);
ptls_buffer_push64(output, tls->traffic_protection.dec.seq);
});
ret = 0;
} else {
if ((ret = export_tls12_params(output, tls->is_server, tls->is_psk_handshake, tls->cipher_suite, tls->client_random,
tls->server_name, negotiated_protocol, tls->traffic_protection.enc.secret,
tls->traffic_protection.enc.secret + PTLS_MAX_SECRET_SIZE, tls->traffic_protection.enc.seq,
tls->traffic_protection.enc.tls12_enc_record_iv, tls->traffic_protection.dec.secret,
tls->traffic_protection.dec.secret + PTLS_MAX_SECRET_SIZE,
tls->traffic_protection.dec.seq)) != 0)
goto Exit;
}

Exit:
return ret;
}

static int build_tls12_traffic_protection(ptls_t *tls, int is_enc, const uint8_t **src, const uint8_t *const end)
static int import_tls12_traffic_protection(ptls_t *tls, int is_enc, const uint8_t **src, const uint8_t *const end)
{
struct st_ptls_traffic_protection_t *tp = is_enc ? &tls->traffic_protection.enc : &tls->traffic_protection.dec;

Expand All @@ -5134,6 +5164,22 @@ static int build_tls12_traffic_protection(ptls_t *tls, int is_enc, const uint8_t
return 0;
}

static int import_tls13_traffic_protection(ptls_t *tls, int is_enc, const uint8_t **src, const uint8_t *const end)
{
struct st_ptls_traffic_protection_t *tp = is_enc ? &tls->traffic_protection.enc : &tls->traffic_protection.dec;

/* set properties */
memcpy(tp->secret, *src, tls->cipher_suite->hash->digest_size);
*src += tls->cipher_suite->hash->digest_size;
if (ptls_decode64(&tp->seq, src, end) != 0)
return PTLS_ALERT_DECODE_ERROR;

if (setup_traffic_protection(tls, is_enc, NULL, 3, tp->seq, 0) != 0)
return PTLS_ERROR_INCOMPATIBLE_KEY;

return 0;
}

int ptls_import(ptls_context_t *ctx, ptls_t **tls, ptls_iovec_t params)
{
const uint8_t *src = params.base, *const end = src + params.len;
Expand All @@ -5159,11 +5205,6 @@ int ptls_import(ptls_context_t *ctx, ptls_t **tls, ptls_iovec_t params)
goto Exit;
if ((ret = ptls_decode16(&csid, &src, end)) != 0)
goto Exit;
(*tls)->cipher_suite = ptls_find_cipher_suite(ctx->tls12_cipher_suites, csid);
if ((*tls)->cipher_suite == NULL) {
ret = PTLS_ALERT_HANDSHAKE_FAILURE;
goto Exit;
}
/* other version-independent stuff */
if (end - src < PTLS_HELLO_RANDOM_SIZE) {
ret = PTLS_ALERT_DECODE_ERROR;
Expand All @@ -5189,15 +5230,36 @@ int ptls_import(ptls_context_t *ctx, ptls_t **tls, ptls_iovec_t params)
ptls_decode_open_block(src, end, 2, {
switch (protocol_version) {
case PTLS_PROTOCOL_VERSION_TLS12:
(*tls)->cipher_suite = ptls_find_cipher_suite(ctx->tls12_cipher_suites, csid);
if ((*tls)->cipher_suite == NULL) {
ret = PTLS_ALERT_HANDSHAKE_FAILURE;
goto Exit;
}
/* setup AEAD keys */
if ((ret = build_tls12_traffic_protection(*tls, 1, &src, end)) != 0)
if ((ret = import_tls12_traffic_protection(*tls, 1, &src, end)) != 0)
goto Exit;
if ((ret = build_tls12_traffic_protection(*tls, 0, &src, end)) != 0)
if ((ret = import_tls12_traffic_protection(*tls, 0, &src, end)) != 0)
goto Exit;
break;
case PTLS_PROTOCOL_VERSION_TLS13:
(*tls)->cipher_suite = ptls_find_cipher_suite(ctx->cipher_suites, csid);
if ((*tls)->cipher_suite == NULL) {
ret = PTLS_ALERT_HANDSHAKE_FAILURE;
goto Exit;
}
/* setup AEAD keys */
if (((*tls)->key_schedule = key_schedule_new((*tls)->cipher_suite, NULL, (*tls)->ech.aead != NULL)) == NULL) {
ret = PTLS_ERROR_NO_MEMORY;
goto Exit;
}
if ((ret = import_tls13_traffic_protection(*tls, 1, &src, end)) != 0)
goto Exit;
if ((ret = import_tls13_traffic_protection(*tls, 0, &src, end)) != 0)
goto Exit;
break;
default:
ret = PTLS_ALERT_ILLEGAL_PARAMETER;
break;
goto Exit;
}
});
/* extensions */
Expand Down Expand Up @@ -6232,7 +6294,6 @@ ptls_aead_context_t *new_aead(ptls_aead_algorithm_t *aead, ptls_hash_algorithm_t
if ((ret = get_traffic_keys(aead, hash, key_iv.key, key_iv.iv, secret, hash_value, label_prefix)) != 0)
goto Exit;
ctx = ptls_aead_new_direct(aead, is_enc, key_iv.key, key_iv.iv);

Exit:
ptls_clear_memory(&key_iv, sizeof(key_iv));
return ctx;
Expand Down
Loading

0 comments on commit 703553c

Please sign in to comment.