Skip to content

Commit

Permalink
fix: Don't change HTTP request components after they have been signed (
Browse files Browse the repository at this point in the history
  • Loading branch information
jbelkins committed May 16, 2023
1 parent 0feeb94 commit a9b3758
Show file tree
Hide file tree
Showing 16 changed files with 176 additions and 161 deletions.
56 changes: 14 additions & 42 deletions Sources/ClientRuntime/Networking/Endpoint.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import Foundation

public struct Endpoint {
public struct Endpoint: Hashable {
public let path: String
public let queryItems: [URLQueryItem]?
public let protocolType: ProtocolType?
Expand Down Expand Up @@ -58,55 +58,27 @@ public struct Endpoint {
}
}

public extension Endpoint {
extension Endpoint {
// We still have to keep 'url' as an optional, since we're
// dealing with dynamic components that could be invalid.
var url: URL? {
public var url: URL? {
var components = URLComponents()
components.scheme = protocolType?.rawValue
components.host = host
components.path = path
components.percentEncodedQueryItems = queryItems
components.percentEncodedPath = path
components.percentEncodedQuery = queryItemString

return components.url
}

var queryItemString: String {
guard let queryItems = queryItems, !queryItems.isEmpty else {
return ""
}
let queryString = queryItems.map { "\($0.name)=\($0.value ?? "")" }.joined(separator: "&")
return "?\(queryString)"
}
}

// It was discovered that in Swift 5.8 and earlier versions, the URLQueryItem type does not correctly implement
// Hashable: namely, multiple URLQueryItems with the same name & value and that are equal by the == operator will have
// different hash values.
//
// Github issue filed against open-source Foundation:
// https://github.com/apple/swift-corelibs-foundation/issues/4737
//
// This extension is intended to correct this problem for the Endpoint type by substituting a
// different structure with the same properties as URLQueryItem when the Endpoint is hashed.
//
// This extension may be removed, and the compiler-generated Hashable compliance may be used instead, once the
// URLQueryItem's Hashable implementation is fixed in open-source Foundation.
extension Endpoint: Hashable {

private struct QueryItem: Hashable {
let name: String
let value: String?
}

public func hash(into hasher: inout Hasher) {
hasher.combine(path)
let queryItemElements = queryItems?.map { QueryItem(name: $0.name, value: $0.value) }
hasher.combine(queryItemElements)
hasher.combine(protocolType)
hasher.combine(host)
hasher.combine(port)
hasher.combine(headers)
hasher.combine(properties)
var queryItemString: String? {
guard let queryItems = queryItems else { return nil }
return queryItems.map { queryItem in
if let value = queryItem.value {
return "\(queryItem.name)=\(value)"
} else {
return queryItem.name
}
}.joined(separator: "&")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public class CRTClientEngine: HttpClientEngine {

return connectionPool
}

private func createConnectionPool(endpoint: Endpoint) throws -> HTTPClientConnectionManager {
let tlsConnectionOptions = TLSConnectionOptions(
context: sharedDefaultIO.tlsContext,
Expand All @@ -69,7 +69,7 @@ public class CRTClientEngine: HttpClientEngine {
enableManualWindowManagement: false
) // not using backpressure yet
logger.debug("""
Creating connection pool for \(String(describing: endpoint.url?.absoluteString)) \
Creating connection pool for \(String(describing: endpoint.host)) \
with max connections: \(maxConnectionsPerEndpoint)
""")
return try HTTPClientConnectionManager(options: options)
Expand All @@ -96,7 +96,7 @@ public class CRTClientEngine: HttpClientEngine {
enableStreamManualWindowManagement: false
)
logger.debug("""
Creating connection pool for \(String(describing: endpoint.url?.absoluteString)) \
Creating connection pool for \(String(describing: endpoint.host)) \
with max connections: \(maxConnectionsPerEndpoint)
""")

Expand Down Expand Up @@ -274,7 +274,7 @@ public class CRTClientEngine: HttpClientEngine {
}

requestOptions.http2ManualDataWrites = http2ManualDataWrites

response.body = .stream(stream)
return requestOptions
}
Expand Down
65 changes: 26 additions & 39 deletions Sources/ClientRuntime/Networking/Http/SdkHttpRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,56 +10,44 @@ import AwsCommonRuntimeKit
// we need to maintain a reference to this same request while we add headers
// in the CRT engine so that is why it's a class
public class SdkHttpRequest {
public var body: HttpBody
public var headers: Headers
public let queryItems: [URLQueryItem]?
public let body: HttpBody
public let endpoint: Endpoint
public let method: HttpMethodType
public var headers: Headers { endpoint.headers ?? Headers() }
public var path: String { endpoint.path }
public var host: String { endpoint.host }
public var queryItems: [URLQueryItem]? { endpoint.queryItems }

public init(method: HttpMethodType,
endpoint: Endpoint,
headers: Headers,
queryItems: [URLQueryItem]? = nil,
body: HttpBody = HttpBody.none) {
self.method = method
self.endpoint = endpoint
self.headers = headers
self.body = body
self.queryItems = queryItems
}
}

// Create a `CharacterSet` of the characters that need not be percent encoded in the
// resulting URL. This set consists of alphanumerics plus underscore, dash, tilde, and
// period. Any other character should be percent-encoded when used in a path segment.
// Forward-slash is added as well because the segments have already been joined into a path.
//
// See, for URL-allowed characters:
// https://www.rfc-editor.org/rfc/rfc3986#section-2.3
private let allowed = CharacterSet.alphanumerics.union(CharacterSet(charactersIn: "/_-.~"))

extension SdkHttpRequest {
public func toHttpRequest() throws -> HTTPRequest {
let httpHeaders = headers.toHttpHeaders()

public func toHttpRequest(escaping: Bool = false) throws -> HTTPRequest {
let httpRequest = try HTTPRequest()
httpRequest.method = method.rawValue
let encodedPath = endpoint.path.addingPercentEncoding(withAllowedCharacters: allowed) ?? endpoint.path
httpRequest.path = "\(encodedPath)\(endpoint.queryItemString)"
httpRequest.addHeaders(headers: httpHeaders)
let encodedPath = escaping ? endpoint.path.urlPercentEncodedForPath : endpoint.path
httpRequest.path = [encodedPath, endpoint.queryItemString].compactMap { $0 }.joined(separator: "?")
httpRequest.addHeaders(headers: headers.toHttpHeaders())
httpRequest.body = StreamableHttpBody(body: body)
return httpRequest
}

/// Convert the SDK request to a CRT HTTPRequestBase
/// CRT converts the HTTPRequestBase to HTTP2Request internally if the protocol is HTTP/2
/// - Returns: the CRT request
public func toHttp2Request() throws -> HTTPRequestBase {
let httpHeaders = headers.toHttpHeaders()
public func toHttp2Request(escaping: Bool = false) throws -> HTTPRequestBase {
let httpRequest = try HTTPRequest()
httpRequest.method = method.rawValue
let encodedPath = endpoint.path.addingPercentEncoding(withAllowedCharacters: allowed) ?? endpoint.path
httpRequest.path = "\(encodedPath)\(endpoint.queryItemString)"
httpRequest.addHeaders(headers: httpHeaders)
let encodedPath = escaping ? endpoint.path.urlPercentEncodedForPath : endpoint.path
httpRequest.path = [encodedPath, endpoint.queryItemString].compactMap { $0 }.joined(separator: "?")
httpRequest.addHeaders(headers: headers.toHttpHeaders())

// HTTP2Request used with manual writes hence we need to set the body to nil
// so that CRT does not write the body for us (we will write it manually)
Expand Down Expand Up @@ -96,11 +84,11 @@ extension SdkHttpRequestBuilder {
public func update(from crtRequest: HTTPRequestBase, originalRequest: SdkHttpRequest) -> SdkHttpRequestBuilder {
headers = convertSignedHeadersToHeaders(crtRequest: crtRequest)
methodType = originalRequest.method
host = originalRequest.endpoint.host
if let crtRequest = crtRequest as? HTTPRequest {
let pathAndQueryItems = URLComponents(string: crtRequest.path)
path = pathAndQueryItems?.path ?? "/"
queryItems = pathAndQueryItems?.percentEncodedQueryItems ?? [URLQueryItem]()
host = originalRequest.host
if let crtRequest = crtRequest as? HTTPRequest, let components = URLComponents(string: crtRequest.path) {
path = components.percentEncodedPath
queryItems = components.percentEncodedQueryItems?.map { URLQueryItem(name: $0.name, value: $0.value) }
?? [URLQueryItem]()
} else if crtRequest as? HTTP2Request != nil {
assertionFailure("HTTP2Request not supported")
} else {
Expand All @@ -123,11 +111,11 @@ public class SdkHttpRequestBuilder {
var host: String = ""
var path: String = "/"
var body: HttpBody = .none
var queryItems = [URLQueryItem]()
var queryItems: [URLQueryItem]? = nil
var port: Int16 = 443
var protocolType: ProtocolType = .https

public var currentQueryItems: [URLQueryItem] {
public var currentQueryItems: [URLQueryItem]? {
return queryItems
}

Expand Down Expand Up @@ -179,14 +167,14 @@ public class SdkHttpRequestBuilder {

@discardableResult
public func withQueryItems(_ value: [URLQueryItem]) -> SdkHttpRequestBuilder {
self.queryItems = value
self.queryItems = self.queryItems ?? []
self.queryItems?.append(contentsOf: value)
return self
}

@discardableResult
public func withQueryItem(_ value: URLQueryItem) -> SdkHttpRequestBuilder {
self.queryItems.append(value)
return self
withQueryItems([value])
}

@discardableResult
Expand All @@ -206,11 +194,10 @@ public class SdkHttpRequestBuilder {
path: path,
port: port,
queryItems: queryItems,
protocolType: protocolType)
protocolType: protocolType,
headers: headers)
return SdkHttpRequest(method: methodType,
endpoint: endpoint,
headers: headers,
queryItems: queryItems,
body: body)
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
//
// Copyright Amazon.com Inc. or its affiliates.
// All Rights Reserved.
//
// SPDX-License-Identifier: Apache-2.0
//

import Foundation

// Creates a `CharacterSet` of the characters that need not be percent encoded in the
// resulting URL. This set consists of alphanumerics plus underscore, dash, tilde, and
// period. Any other character should be percent-encoded when used in a path segment.
// Forward-slash is added as well because the segments have already been joined into a path.
//
// See, for URL-allowed characters:
// https://www.rfc-editor.org/rfc/rfc3986#section-2.3
private let allowedForPath = CharacterSet.alphanumerics.union(CharacterSet(charactersIn: "/_-.~"))
private let allowedForQuery = CharacterSet.alphanumerics.union(CharacterSet(charactersIn: "_-.~"))

extension String {

/// Encodes a URL component for inclusion in the path or query items, using percent-escaping.
///
/// All characters except alphanumerics plus forward slash, underscore, dash, tilde, and period will be escaped.
var urlPercentEncodedForPath: String {
addingPercentEncoding(withAllowedCharacters: allowedForPath) ?? self
}

/// Encodes a URL component for inclusion in query item name or value, using percent-escaping.
///
/// All characters except alphanumerics plus forward slash, underscore, dash, tilde, and period will be escaped.
var urlPercentEncodedForQuery: String {
addingPercentEncoding(withAllowedCharacters: allowedForQuery) ?? self
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,7 @@ extension String {

extension String {
public func urlPercentEncoding() -> String {
if let encodedString = self.addingPercentEncoding(withAllowedCharacters: .singleUrlQueryAllowed) {
return encodedString
}
return self
self.urlPercentEncodedForQuery
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,14 @@
*/

import Foundation

public typealias URL = Foundation.URL

extension URL {
func toQueryItems() -> [URLQueryItem]? { return URLComponents(url: self,
resolvingAgainstBaseURL: false)?.queryItems }

func toQueryItems() -> [URLQueryItem]? {
URLComponents(url: self, resolvingAgainstBaseURL: false)?
.queryItems?
.map { URLQueryItem(name: $0.name, value: $0.value) }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,14 @@
* SPDX-License-Identifier: Apache-2.0.
*/

import struct Foundation.URLQueryItem
public typealias URLQueryItem = Foundation.URLQueryItem
public typealias URLQueryItem = MyURLQueryItem

extension URLQueryItem: Comparable {
/// Compares two `URLQueryItem` instances by their `name` property.
/// - Parameters:
/// - lhs: The first `URLQueryItem` to compare.
/// - rhs: The second `URLQueryItem` to compare.
/// - Returns: `true` if the `name` property of `lhs` is less than the `name` property of `rhs`.
public static func < (lhs: URLQueryItem, rhs: URLQueryItem) -> Bool {
lhs.name < rhs.name
public struct MyURLQueryItem: Hashable {
public var name: String
public var value: String?

public init(name: String, value: String?) {
self.name = name
self.value = value
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
//
// SPDX-License-Identifier: Apache-2.0
//

import XCTest
import ClientRuntime

Expand All @@ -19,12 +20,12 @@ extension HttpRequestTestBase {
assertQueryItems(expectedQueryItems, actualQueryItems, file: file, line: line)
}

private func convertToQueryItems(data: Data) -> [URLQueryItem] {
private func convertToQueryItems(data: Data) -> [ClientRuntime.URLQueryItem] {
guard let queryString = String(data: data, encoding: .utf8) else {
XCTFail("Failed to decode data")
return []
}
var queryItems: [URLQueryItem] = []
var queryItems: [ClientRuntime.URLQueryItem] = []
let sanitizedQueryString = queryString.replacingOccurrences(of: "\n", with: "")
let keyValuePairs = sanitizedQueryString.components(separatedBy: "&")
for keyValue in keyValuePairs {
Expand Down
Loading

0 comments on commit a9b3758

Please sign in to comment.