Skip to content

Commit

Permalink
Add support for should_sign_header
Browse files Browse the repository at this point in the history
  • Loading branch information
lauzadis committed Apr 26, 2024
1 parent 738217a commit f4c6550
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@ package aws.sdk.kotlin.crt.auth.signing

import aws.sdk.kotlin.crt.*
import aws.sdk.kotlin.crt.auth.credentials.Credentials
import aws.sdk.kotlin.crt.http.Headers
import aws.sdk.kotlin.crt.http.HttpRequest
import aws.sdk.kotlin.crt.http.HttpRequestBodyStream
import aws.sdk.kotlin.crt.http.headers
import aws.sdk.kotlin.crt.http.*
import aws.sdk.kotlin.crt.io.Uri
import kotlinx.coroutines.test.runTest
import kotlin.test.*
Expand Down Expand Up @@ -192,4 +189,49 @@ class SigningTest : CrtTest() {
val expectedSignature = "8b578658fa1705d62bf26aa73e764ac4b705e6d9efd223a2d9e156580f085de4" // validated using DefaultAwsSigner
assertEquals(expectedSignature, signature)
}

@Test
fun testShouldSignHeader() = runTest {
val request = HttpRequestBuilder().apply {
method = "POST"
encodedPath = "https://www.example.com"
headers {
append("bad-header", "should not be signed")
append("Host", "https://www.example.com")
}
}.build()

val baseSigningConfig = AwsSigningConfig.Builder().apply {
algorithm = AwsSigningAlgorithm.SIGV4
signatureType = AwsSignatureType.HTTP_REQUEST_VIA_HEADERS
region = "us-east-1"
service = "service"
date = Platform.epochMilliNow()
credentials = Credentials(TEST_ACCESS_KEY_ID, TEST_SECRET_ACCESS_KEY, null)
useDoubleUriEncode = true
normalizeUriPath = true
}

val skipHeaderConfig = baseSigningConfig.apply {
shouldSignHeader = { it != "bad-header" }
}.build()
val implicitSignAllHeadersConfig = baseSigningConfig.apply {
shouldSignHeader = null
}.build()
val explicitSignAllHeadersConfig = baseSigningConfig.apply {
shouldSignHeader = { true }
}.build()

val skipHeaderSignedRequest = AwsSigner.signRequest(request, skipHeaderConfig)
assertTrue(skipHeaderSignedRequest.headers.contains("Authorization"))
assertFalse(skipHeaderSignedRequest.headers["Authorization"]!!.contains("bad-header"))

val implicitSignAllHeadersRequest = AwsSigner.signRequest(request, implicitSignAllHeadersConfig)
assertTrue(implicitSignAllHeadersRequest.headers.contains("Authorization"))
assertTrue(implicitSignAllHeadersRequest.headers["Authorization"]!!.contains("bad-header"))

val explicitSignAllHeadersRequest = AwsSigner.signRequest(request, explicitSignAllHeadersConfig)
assertTrue(explicitSignAllHeadersRequest.headers.contains("Authorization"))
assertTrue(explicitSignAllHeadersRequest.headers["Authorization"]!!.contains("bad-header"))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,11 @@ private fun AwsSigningConfig.toNativeSigningConfig(): CPointer<aws_signing_confi
service.initFromCursor(this@toNativeSigningConfig.service.toAwsString().asAwsByteCursor())
aws_date_time_init_epoch_millis(date.ptr, this@toNativeSigningConfig.date.toULong())

// FIXME Can't convert Kotlin config's shouldSignHeader to a C function without capturing the config variable, and staticCFunction cannot capture variables.
// should_sign_header = [email protected]?.toNativeShouldSignHeaderFn()
// should_sign_header_ud =
this@toNativeSigningConfig.shouldSignHeader?.let {
val shouldSignHeaderStableRef = StableRef.create(it)
should_sign_header = staticCFunction(::nativeShouldSignHeaderFn)
should_sign_header_ud = shouldSignHeaderStableRef.asCPointer()
}

flags.use_double_uri_encode = if (this@toNativeSigningConfig.useDoubleUriEncode) 1u else 0u
flags.should_normalize_uri_path = if (this@toNativeSigningConfig.normalizeUriPath) 1u else 0u
Expand All @@ -215,6 +217,16 @@ private fun AwsSigningConfig.toNativeSigningConfig(): CPointer<aws_signing_confi
return config.ptr
}

private typealias ShouldSignHeaderFunction = (String) -> Boolean
private fun nativeShouldSignHeaderFn(headerName: CPointer<aws_byte_cursor>?, userData: COpaquePointer?): Boolean {
checkNotNull(headerName) { "aws_should_sign_header_fn expected non-null header name" }
if (userData == null) { return true }

val kShouldSignHeaderFn = userData.asStableRef<ShouldSignHeaderFunction>().get()
val kHeaderName = headerName.pointed.toKString()
return kShouldSignHeaderFn(kHeaderName)
}

/**
* Callback for standard request signing. Applies the given signing result to the HTTP message and then returns the
* signature via callback channel.
Expand Down

0 comments on commit f4c6550

Please sign in to comment.