Skip to content

Commit

Permalink
fixes data race in CRTClientEngine (#424)
Browse files Browse the repository at this point in the history
* fixes data race in CRTClientEngine
* creates nested actor named SerialExecutor in CRTClientEngine
* public functions were directed to the actor to manage connection pools
  • Loading branch information
brennanMKE committed Jul 21, 2022
1 parent 902cbcd commit 44e4a16
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,67 @@ import Darwin
#endif

public class CRTClientEngine: HttpClientEngine {
actor SerialExecutor {
private var logger: LogAgent

private let windowSize: Int
private let maxConnectionsPerEndpoint: Int
private var connectionPools: [Endpoint: HttpClientConnectionManager] = [:]

init(config: CRTClientEngineConfig) {
self.windowSize = config.windowSize
self.maxConnectionsPerEndpoint = config.maxConnectionsPerEndpoint
self.logger = SwiftLogger(label: "SerialExecutor")
}

func getOrCreateConnectionPool(endpoint: Endpoint) -> HttpClientConnectionManager {
guard let connectionPool = connectionPools[endpoint] else {
let newConnectionPool = createConnectionPool(endpoint: endpoint)
connectionPools[endpoint] = newConnectionPool // save in dictionary
return newConnectionPool
}

return connectionPool
}

func closeAllPendingConnections() {
for (endpoint, value) in connectionPools {
logger.debug("Connection to endpoint: \(String(describing: endpoint.url?.absoluteString)) is closing")
value.closePendingConnections()
}
}

private func createConnectionPool(endpoint: Endpoint) -> HttpClientConnectionManager {
let tlsConnectionOptions = SDKDefaultIO.shared.tlsContext.newConnectionOptions()
do {
try tlsConnectionOptions.setServerName(endpoint.host)
} catch let err {
logger.error("Server name was not able to be set in TLS Connection Options. TLS Negotiation will fail.")
logger.error("Error: \(err.localizedDescription)")
}
let socketOptions = SocketOptions(socketType: .stream)
#if os(iOS) || os(watchOS)
socketOptions.connectTimeoutMs = 30_000
#endif
let options = HttpClientConnectionOptions(clientBootstrap: SDKDefaultIO.shared.clientBootstrap,
hostName: endpoint.host,
initialWindowSize: windowSize,
port: UInt16(endpoint.port),
proxyOptions: nil,
socketOptions: socketOptions,
tlsOptions: tlsConnectionOptions,
monitoringOptions: nil,
maxConnections: maxConnectionsPerEndpoint,
enableManualWindowManagement: false) // not using backpressure yet
logger.debug("Creating connection pool for \(String(describing: endpoint.url?.absoluteString))" +
"with max connections: \(maxConnectionsPerEndpoint)")
return HttpClientConnectionManager(options: options)
}
}

public typealias StreamContinuation = CheckedContinuation<HttpResponse, Error>
private var logger: LogAgent
private var connectionPools: [Endpoint: HttpClientConnectionManager] = [:]
private let serialExecutor: SerialExecutor
private let CONTENT_LENGTH_HEADER = "Content-Length"
private let AWS_COMMON_RUNTIME = "AwsCommonRuntime"
private let DEFAULT_STREAM_WINDOW_SIZE = 16 * 1024 * 1024 // 16 MB
Expand All @@ -26,70 +84,33 @@ public class CRTClientEngine: HttpClientEngine {
self.maxConnectionsPerEndpoint = config.maxConnectionsPerEndpoint
self.windowSize = config.windowSize
self.logger = SwiftLogger(label: "CRTClientEngine")
}

private func createConnectionPool(endpoint: Endpoint) -> HttpClientConnectionManager {
let tlsConnectionOptions = SDKDefaultIO.shared.tlsContext.newConnectionOptions()
do {
try tlsConnectionOptions.setServerName(endpoint.host)
} catch let err {
logger.error("Server name was not able to be set in TLS Connection Options. TLS Negotiation will fail.")
logger.error("Error: \(err.localizedDescription)")
}
let socketOptions = SocketOptions(socketType: .stream)
#if os(iOS) || os(watchOS)
socketOptions.connectTimeoutMs = 30_000
#endif
let options = HttpClientConnectionOptions(clientBootstrap: SDKDefaultIO.shared.clientBootstrap,
hostName: endpoint.host,
initialWindowSize: windowSize,
port: UInt16(endpoint.port),
proxyOptions: nil,
socketOptions: socketOptions,
tlsOptions: tlsConnectionOptions,
monitoringOptions: nil,
maxConnections: maxConnectionsPerEndpoint,
enableManualWindowManagement: false) // not using backpressure yet
logger.debug("Creating connection pool for \(String(describing: endpoint.url?.absoluteString))" +
"with max connections: \(maxConnectionsPerEndpoint)")
return HttpClientConnectionManager(options: options)
}

private func getOrCreateConnectionPool(endpoint: Endpoint) -> HttpClientConnectionManager {

guard let connectionPool = connectionPools[endpoint] else {
let newConnectionPool = createConnectionPool(endpoint: endpoint)
connectionPools[endpoint] = newConnectionPool // save in dictionary
return newConnectionPool
}

return connectionPool
self.serialExecutor = SerialExecutor(config: config)
}

public func execute(request: SdkHttpRequest) async throws -> HttpResponse {
let connectionMgr = getOrCreateConnectionPool(endpoint: request.endpoint)
let connectionMgr = await serialExecutor.getOrCreateConnectionPool(endpoint: request.endpoint)
let connection = try await connectionMgr.acquireConnection()
self.logger.debug("Connection was acquired to: \(String(describing: request.endpoint.url?.absoluteString))")
return try await withCheckedThrowingContinuation({ (continuation: StreamContinuation) in
let requestOptions = makeHttpRequestStreamOptions(request, continuation)
let stream = connection.makeRequest(requestOptions: requestOptions)
stream.activate()
do {
let requestOptions = makeHttpRequestStreamOptions(request, continuation)
let stream = try connection.makeRequest(requestOptions: requestOptions)
try stream.activate()
} catch {
continuation.resume(throwing: error)
}
})

}

public func close() {
for (endpoint, value) in connectionPools {
logger.debug("Connection to endpoint: \(String(describing: endpoint.url?.absoluteString)) is closing")
value.closePendingConnections()
}
public func close() async {
await serialExecutor.closeAllPendingConnections()
}

public func makeHttpRequestStreamOptions(_ request: SdkHttpRequest, _ continuation: StreamContinuation) -> HttpRequestOptions {
let response = HttpResponse()
let crtRequest = request.toHttpRequest(bufferSize: windowSize)
let streamReader: StreamReader = DataStreamReader()

let requestOptions = HttpRequestOptions(request: crtRequest) { [self] (stream, _, httpHeaders) in
logger.debug("headers were received")
response.statusCode = HttpStatusCode(rawValue: Int(stream.statusCode)) ?? HttpStatusCode.notFound
Expand All @@ -113,11 +134,11 @@ public class CRTClientEngine: HttpClientEngine {
return
}
}

response.body = .stream(.reader(streamReader))

response.statusCode = HttpStatusCode(rawValue: Int(stream.statusCode)) ?? HttpStatusCode.notFound

continuation.resume(returning: response)
}
return requestOptions
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@ import AwsCommonRuntimeKit

public protocol HttpClientEngine {
func execute(request: SdkHttpRequest) async throws -> HttpResponse
func close()
func close() async
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ public class SdkHttpClient {
}

public func close() {
engine.close()
Task {
await self.engine.close()
}
}

}
Expand Down

0 comments on commit 44e4a16

Please sign in to comment.