From 7f5fcee361d8d6a557d6caeb149f8ece40f5a8bd Mon Sep 17 00:00:00 2001 From: Steve Kim <86316075+sbSteveK@users.noreply.github.com> Date: Thu, 23 Feb 2023 15:51:14 -0800 Subject: [PATCH] Secure Tunnel V2 WebSocket Protocol support (#533) * Secure Tunnel API Expansion * Secure Tunnel V2 WebSocket Protocol support (Multiplexing) --- README.md | 1 + crt/aws-c-iot | 2 +- crt/aws-crt-cpp | 2 +- documents/Secure_Tunnel_Userguide.md | 165 ++++ samples/README.md | 10 +- .../secure_tunneling/secure_tunnel/main.cpp | 557 +++++++----- secure_tunneling/CMakeLists.txt | 3 - .../aws/iotsecuretunneling/SecureTunnel.h | 799 ++++++++++++++++-- secure_tunneling/source/SecureTunnel.cpp | 585 +++++++++++-- secure_tunneling/tests/CMakeLists.txt | 20 - secure_tunneling/tests/SecureTunnelTest.cpp | 263 ------ utils/run_secure_tunnel_ci.py | 39 +- 12 files changed, 1818 insertions(+), 628 deletions(-) create mode 100644 documents/Secure_Tunnel_Userguide.md delete mode 100644 secure_tunneling/tests/CMakeLists.txt delete mode 100644 secure_tunneling/tests/SecureTunnelTest.cpp diff --git a/README.md b/README.md index 2dfde6e2b..e1341dbca 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,7 @@ __Jump To:__ * [FAQ](./documents/FAQ.md) * [Giving Feedback and Contributions](#Giving-Feedback-and-Contributions) * [MQTT5 User Guide](./documents/MQTT5_Userguide.md) +* [Secure Tunnel User Guide](./documents/Secure_Tunnel_Userguide.md) ## Installation diff --git a/crt/aws-c-iot b/crt/aws-c-iot index fc21ca507..09ded2b5e 160000 --- a/crt/aws-c-iot +++ b/crt/aws-c-iot @@ -1 +1 @@ -Subproject commit fc21ca50727548626969a51bc18cf07eca35061b +Subproject commit 09ded2b5e5bd34bbcf0fd71b5482381cf7f08627 diff --git a/crt/aws-crt-cpp b/crt/aws-crt-cpp index f81587db3..7ff9e0343 160000 --- a/crt/aws-crt-cpp +++ b/crt/aws-crt-cpp @@ -1 +1 @@ -Subproject commit f81587db32a2f13764b0378fb069161f6742e376 +Subproject commit 7ff9e0343c978fc54f440b98147c2f72d304f6d8 diff --git a/documents/Secure_Tunnel_Userguide.md b/documents/Secure_Tunnel_Userguide.md new file mode 100644 index 000000000..d6f071e4a --- /dev/null +++ b/documents/Secure_Tunnel_Userguide.md @@ -0,0 +1,165 @@ +# Introduction +When devices are deployed behind restricted firewalls at remote sites, you need a way to gain access to those devices for troubleshooting, configuration updates, and other operational tasks. Use secure tunneling to establish bidirectional communication to remote devices over a secure connection that is managed by AWS IoT. Secure tunneling does not require updates to your existing inbound firewall rules, so you can keep the same security level provided by firewall rules at a remote site. + +More information on the service and how to open, close, and manage secure tunnels can be found here: https://docs.aws.amazon.com/iot/latest/developerguide/secure-tunneling.html + +A sample is also provided and can be found here: https://github.com/aws/aws-iot-device-sdk-cpp-v2/tree/main/samples#secure-tunnel + + + +# Getting started with Secure Tunnels +## How to Create a Secure Tunnel Client +Once a Secure Tunnel builder has been created, it is ready to make a Secure Tunnel client. Something important to note is that once a Secure Tunnel client is built and finalized, the configuration is immutable and cannot be changed. Further modifications to the Secure Tunnel builder will not change the settings of already created Secure Tunnel clients. + +```cpp +// Create Secure Tunnel Builder +SecureTunnelBuilder builder = SecureTunnelBuilder(...); + +// Build Secure Tunnel Client +std::shared_ptr secureTunnel = builder.Build(); + +if (secureTunnel == nullptr) +{ + fprintf(stdout, "Secure Tunnel creation failed.\n"); + return -1; +} + +// Start the secure tunnel connection +if (!secureTunnel->Start()) +{ + fprintf("Failed to start Secure Tunnel\n"); + return -1; +} +``` +## Callbacks + +### OnConnectionSuccess +When the Secure Tunnel Client successfully connects with the Secure Tunnel service, this callback will return the available (if any) service ids. + +### OnConnectionFailure +When a WebSocket upgrade request fails to connect, this callback will return an error code. + +### OnConnectionShutdown +When the WebSocket connection shuts down, this callback will be invoked. + +### OnSendDataComplete +When a message has been completely written to the socket, this callback will be invoked. + +### OnMessageReceived +When a message is received on an open Secure Tunnel stream, this callback will return the message. + +### OnStreamStarted +When a stream is started by a Source connected to the Destination, the Destination will invoke this callback and return the stream information. + +### OnStreamStopped +When an open stream is closed, this callback will be invoked and return the stopped stream's information. + +### OnSessionReset +When the Secure Tunnel service requests the Secure Tunnel client fully reset, this callback is invoked. + +### OnStopped +When the Secure Tunnel has reached a fully stopped state this callback is invoked. + +## Setting Secure Tunnel Callbacks +The Secure Tunnel client uses callbacks to keep the user updated on its status and pass along relavant information. These can be set up using the Secure Tunnel builder's With functions. + +```cpp +// Create Secure Tunnel Builder +SecureTunnelBuilder builder = SecureTunnelBuilder(...); + +// Setting the onMessageReceived callback using the builder +builder.WithOnMessageReceived([&](SecureTunnel *secureTunnel, const MessageReceivedEventData &eventData) { + { + std::shared_ptr message = eventData.message; + if (message->getServiceId().has_value()){ + fprintf( + stdout, + "Message received on service id:'" PRInSTR "'\n", + AWS_BYTE_CURSOR_PRI(message->getServiceId().value())); + } + + if(message->getPayload().has_value()){ + fprintf( + stdout, + "Message has payload:'" PRInSTR "'\n", + AWS_BYTE_CURSOR_PRI(message->getPayload().value())); + } + } + }); + +// Build Secure Tunnel Client +std::shared_ptr secureTunnel = builder.Build(); + +if (secureTunnel == nullptr) +{ + fprintf(stdout, "Secure Tunnel creation failed.\n"); + return -1; +} + +// Start the secure tunnel connection +if (!secureTunnel->Start()) +{ + fprintf("Failed to start Secure Tunnel\n"); + return -1; +} + +// Messages received on a stream will now be printed to stdout. +``` + +# How to Start and Stop + +## Start +Invoking `Start()` on the Secure Tunnel Client will put it into an active state where it recurrently establishes a connection to the configured Secure Tunnel endpoint using the provided [Client Access Token](https://docs.aws.amazon.com/iot/latest/developerguide/secure-tunneling-concepts.html). If a [Client Token](https://docs.aws.amazon.com/iot/latest/developerguide/secure-tunneling-concepts.html) is provided, the Secure Tunnel Client will use it. If a Client Token is not provided, the Secure Tunnel Client will automatically generate one for use on a reconnection attempts. The Client Token for any initial connection to the Secure Tunnel service **MUST** be unique. Reusing a Client Token from a previous connection will result in a failed connection to the Secure Tunnel Service. +```cpp +// Create Secure Tunnel Builder +SecureTunnelBuilder builder = SecureTunnelBuilder(...); + +// Adding a client token to the builder +String clientToken; +builder.WithClientToken(clientToken.c_str()); + +// Build Secure Tunnel Client +std::shared_ptr secureTunnel = builder.Build(); + +if (secureTunnel == nullptr) +{ + fprintf(stdout, "Secure Tunnel creation failed.\n"); + return -1; +} + +// Start the secure tunnel connection +if (!secureTunnel->Start()) +{ + fprintf("Failed to start Secure Tunnel\n"); + return -1; +} +``` + +## Stop +Invoking `Stop()` on the Secure Tunnel Client breaks the current connection (if any) and moves the client into an idle state. +```cpp +if(!secureTunnel->Stop()){ + fprintf(stdout, "Failed to stop the Secure Tunnel connection session. Exiting..\n"); +} +``` + +# Secure Tunnel Operations + +## Send Message +The SendMessage operation takes a description of the Message you wish to send and returns a success/failure in the synchronous logic that kicks off the Send Message operation. When the message is fully written to the socket, the OnSendDataComplete callback will be invoked. + +```cpp +Crt::String serviceId_string = "ssh"; +Crt::String message_string = "any payload"; + +ByteCursor serviceId = ByteCursorFromString(serviceId_string); +ByteCursor payload = ByteCursorFromString(message_string); + +// Create Message +std::shared_ptr message = std::make_shared(); +message->withServiceId(serviceId); +message->withPayload(payload); + +// Send Message +secureTunnel->SendMessage(message); +``` diff --git a/samples/README.md b/samples/README.md index 985278448..37d52d446 100644 --- a/samples/README.md +++ b/samples/README.md @@ -977,7 +977,7 @@ using a permanent certificate set, replace the paths specified in the `--cert` a ## Secure Tunnel -This sample uses AWS IoT [Secure Tunneling](https://docs.aws.amazon.com/iot/latest/developerguide/secure-tunneling.html) Service to connect a destination and a source with each other through the AWS Secure Tunnel endpoint using access tokens. +This sample uses AWS IoT [Secure Tunneling](https://docs.aws.amazon.com/iot/latest/developerguide/secure-tunneling.html) Service to connect a destination and a source with each other through the AWS Secure Tunnel endpoint using access tokens using the [V2WebSocketProtocol](https://github.com/aws-samples/aws-iot-securetunneling-localproxy/blob/main/V2WebSocketProtocolGuide.md). [Secure Tunnel Userguide](https://github.com/aws/aws-iot-device-sdk-cpp-v2/blob/main/documents/Secure_Tunnel_Userguide.md) Source: `samples/secure_tunneling/secure_tunnel` @@ -986,17 +986,13 @@ Create a new secure tunnel in the AWS IoT console (https://console.aws.amazon.co Provide the necessary arguments along with the destination access token and start the sample in destination mode (default). ``` sh -./secure_tunnel --endpoint --ca_file ---cert --key ---thing_name --region --access_token_file +./secure_tunnel --region --access_token_file ``` Provide the necessary arguments along with the source access token and start a second sample in source mode by using the flag --localProxyModeSource. ``` sh -./secure_tunnel --endpoint --ca_file ---cert --key ---thing_name --region --access_token_file +./secure_tunnel --region --access_token_file --localProxyModeSource ``` diff --git a/samples/secure_tunneling/secure_tunnel/main.cpp b/samples/secure_tunneling/secure_tunnel/main.cpp index 711130bfd..eabd2f944 100644 --- a/samples/secure_tunneling/secure_tunnel/main.cpp +++ b/samples/secure_tunneling/secure_tunnel/main.cpp @@ -2,6 +2,7 @@ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. * SPDX-License-Identifier: Apache-2.0. */ +#include #include #include #include @@ -18,97 +19,160 @@ using namespace Aws::Iotsecuretunneling; using namespace Aws::Crt::Io; using namespace std::chrono_literals; -int main(int argc, char *argv[]) +void logMessage(std::shared_ptr message) { - ApiHandle apiHandle; - - String accessToken; - aws_secure_tunneling_local_proxy_mode localProxyMode; - - String proxyHost; - uint16_t proxyPort(8080); - String proxyUserName; - String proxyPassword; - - std::shared_ptr secureTunnel; + if (message->getServiceId().has_value()) + { + if (message->getPayload().has_value()) + { + fprintf( + stdout, + "Message received on service id:'" PRInSTR "' with payload:'" PRInSTR "'\n", + AWS_BYTE_CURSOR_PRI(message->getServiceId().value()), + AWS_BYTE_CURSOR_PRI(message->getPayload().value())); + } + else + { + fprintf( + stdout, + "Message with service id:'" PRInSTR "' with no payload.\n", + AWS_BYTE_CURSOR_PRI(message->getServiceId().value())); + } + return; + } + if (message->getPayload().has_value()) + { + fprintf( + stdout, + "Message received with payload:'" PRInSTR "'\n", + AWS_BYTE_CURSOR_PRI(message->getPayload().value())); + } +} - /*********************** Parse Arguments ***************************/ - Utils::CommandLineUtils cmdUtils = Utils::CommandLineUtils(); - cmdUtils.AddCommonProxyCommands(); - cmdUtils.RegisterProgramName("secure_tunnel"); - cmdUtils.RegisterCommand("region", "", "The region of your secure tunnel"); - cmdUtils.RegisterCommand( +void setupCommandLineUtils(Utils::CommandLineUtils *cmdUtils, int argc, char *argv[]) +{ + cmdUtils->AddCommonProxyCommands(); + cmdUtils->RegisterProgramName("secure_tunnel"); + cmdUtils->RegisterCommand("region", "", "The region of your secure tunnel"); + cmdUtils->RegisterCommand( "ca_file", "", "Path to AmazonRootCA1.pem (optional, system trust store used by default)."); - cmdUtils.RegisterCommand( + cmdUtils->RegisterCommand( "access_token_file", "", "Path to the tunneling access token file (optional if --access_token used)."); - cmdUtils.RegisterCommand("access_token", "", "Tunneling access token (optional if --access_token_file used)."); - cmdUtils.RegisterCommand( + cmdUtils->RegisterCommand( + "access_token", "", "Tunneling access token (optional if --access_token_file used)."); + cmdUtils->RegisterCommand( "local_proxy_mode_source", "", "Use to set local proxy mode to source (optional, default='destination')."); - cmdUtils.RegisterCommand("message", "", "Message to send (optional, default='Hello World!')."); - cmdUtils.RegisterCommand("test", "", "Used to trigger internal testing (optional, ignore unless testing)."); - cmdUtils.RegisterCommand( + cmdUtils->RegisterCommand( + "client_token", "", "Tunneling access token (optional if --client_token_file used)."); + cmdUtils->RegisterCommand("message", "", "Message to send (optional, default='Hello World!')."); + cmdUtils->RegisterCommand( "proxy_user_name", "", "User name passed if proxy server requires a user name (optional)"); - cmdUtils.RegisterCommand( + cmdUtils->RegisterCommand( "proxy_password", "", "Password passed if proxy server requires a password (optional)"); - cmdUtils.AddLoggingCommands(); + cmdUtils->RegisterCommand("count", "", "Number of messages to send before completing (optional, default='5')"); + cmdUtils->AddLoggingCommands(); const char **const_argv = (const char **)argv; - cmdUtils.SendArguments(const_argv, const_argv + argc); - cmdUtils.StartLoggingBasedOnCommand(&apiHandle); + cmdUtils->SendArguments(const_argv, const_argv + argc); +} - /* - * Generate secure tunneling endpoint using region - */ - String region = cmdUtils.GetCommandRequired("region"); - String endpoint = "data.tunneling.iot." + region + ".amazonaws.com"; +void setupCommandLineValues( + Utils::CommandLineUtils *cmdUtils, + String *endpoint, + String *accessToken, + String *clientToken, + String *caFile, + String *proxyHost, + String *proxyUserName, + String *proxyPassword, + uint16_t &proxyPort, + uint16_t &messageCount, + aws_secure_tunneling_local_proxy_mode &localProxyMode, + String *payloadMessage) +{ + /* Generate secure tunneling endpoint using region */ + String region = cmdUtils->GetCommandRequired("region"); + endpoint->assign("data.tunneling.iot." + region + ".amazonaws.com"); + + String tempAccessToken; + /* An access token is required to connect to the secure tunnel service */ + if (cmdUtils->HasCommand("access_token")) + { + tempAccessToken = cmdUtils->GetCommand("access_token"); + } + else if (cmdUtils->HasCommand("access_token_file")) + { + tempAccessToken = cmdUtils->GetCommand("access_token_file"); - if (!(cmdUtils.HasCommand("access_token_file") || cmdUtils.HasCommand("access_token"))) + std::ifstream accessTokenFile(tempAccessToken.c_str()); + if (accessTokenFile.is_open()) + { + getline(accessTokenFile, tempAccessToken); + accessTokenFile.close(); + } + else + { + fprintf(stderr, "Failed to open access token file"); + exit(-1); + } + } + else { - cmdUtils.PrintHelp(); + cmdUtils->PrintHelp(); fprintf(stderr, "--access_token_file or --access_token must be set to connect through a secure tunnel"); exit(-1); } + accessToken->assign(tempAccessToken); - if (cmdUtils.HasCommand("access_token")) + String tempClientToken; + /* + * A client token is optional as one will be automatically generated if it is absent but it is recommended the + * customer provides their own so they can reuse it with other secure tunnel clients after the secure tunnel client + * is terminated. + * */ + if (cmdUtils->HasCommand("client_token")) { - accessToken = cmdUtils.GetCommand("access_token"); + tempClientToken = cmdUtils->GetCommand("client_token"); } - else + + if (cmdUtils->HasCommand("client_token_file")) { - accessToken = cmdUtils.GetCommand("access_token_file"); + tempClientToken = cmdUtils->GetCommand("client_token_file"); - std::ifstream accessTokenFile(accessToken.c_str()); - if (accessTokenFile.is_open()) + std::ifstream clientTokenFile(tempClientToken.c_str()); + if (clientTokenFile.is_open()) { - getline(accessTokenFile, accessToken); - accessTokenFile.close(); + getline(clientTokenFile, tempClientToken); + clientTokenFile.close(); } else { - fprintf(stderr, "Failed to open access token file"); + fprintf(stderr, "Failed to open client token file\n"); exit(-1); } } - if (cmdUtils.HasCommand("proxy_host") || cmdUtils.HasCommand("proxy_port")) + clientToken->assign(tempClientToken); + + caFile->assign(cmdUtils->GetCommandOrDefault("ca_file", "")); + + if (cmdUtils->HasCommand("proxy_host") || cmdUtils->HasCommand("proxy_port")) { - proxyHost = - cmdUtils.GetCommandRequired("proxy_host", "--proxy_host must be set to connect through a proxy.").c_str(); + proxyHost->assign( + cmdUtils->GetCommandRequired("proxy_host", "--proxy_host must be set to connect through a proxy.").c_str()); int port = atoi( - cmdUtils.GetCommandRequired("proxy_port", "--proxy_port must be set to connect through a proxy.").c_str()); + cmdUtils->GetCommandRequired("proxy_port", "--proxy_port must be set to connect through a proxy.").c_str()); if (port > 0 && port <= UINT16_MAX) { proxyPort = static_cast(port); } - proxyUserName = cmdUtils.GetCommandOrDefault("proxy_user_name", ""); - proxyPassword = cmdUtils.GetCommandOrDefault("proxy_password", ""); + proxyUserName->assign(cmdUtils->GetCommandOrDefault("proxy_user_name", "")); + proxyPassword->assign(cmdUtils->GetCommandOrDefault("proxy_password", "")); } - String caFile = cmdUtils.GetCommandOrDefault("ca_file", ""); - /* * localProxyMode is set to destination by default unless flag is set to source */ - if (cmdUtils.HasCommand("local_proxy_mode_source")) + if (cmdUtils->HasCommand("local_proxy_mode_source")) { localProxyMode = AWS_SECURE_TUNNELING_SOURCE_MODE; } @@ -117,13 +181,64 @@ int main(int argc, char *argv[]) localProxyMode = AWS_SECURE_TUNNELING_DESTINATION_MODE; } - String message = cmdUtils.GetCommandOrDefault("message", "Hello World"); + payloadMessage->assign(cmdUtils->GetCommandOrDefault("message", "Hello World")); + int count = atoi(cmdUtils->GetCommandOrDefault("count", "5").c_str()); + messageCount = static_cast(count); +} + +int main(int argc, char *argv[]) +{ + struct aws_allocator *allocator = aws_default_allocator(); + /************************ Setup the Lib ****************************/ /* - * For internal testing + * Do the global initialization for the API and aws-c-iot. */ - bool isTest = cmdUtils.HasCommand("test"); - int expectedMessageCount = 5; + ApiHandle apiHandle; + aws_iotdevice_library_init(allocator); + + /* + * In a real world application you probably don't want to enforce synchronous behavior + * but this is a sample console application, so we'll just do that with a condition variable. + */ + std::promise connectionCompletedPromise; + std::promise connectionClosedPromise; + std::promise clientStoppedPromise; + + /* service id storage for use in sample */ + Aws::Crt::ByteBuf m_serviceIdStorage; + AWS_ZERO_STRUCT(m_serviceIdStorage); + Aws::Crt::Optional m_serviceId; + + String endpoint; + String accessToken; + String clientToken; + String caFile; + String proxyHost; + uint16_t proxyPort(8080); + String proxyUserName; + String proxyPassword; + aws_secure_tunneling_local_proxy_mode localProxyMode; + String payloadMessage; + uint16_t messageCount(5); + + /*********************** Parse Arguments ***************************/ + Utils::CommandLineUtils cmdUtils = Utils::CommandLineUtils(); + setupCommandLineUtils(&cmdUtils, argc, argv); + cmdUtils.StartLoggingBasedOnCommand(&apiHandle); + setupCommandLineValues( + &cmdUtils, + &endpoint, + &accessToken, + &clientToken, + &caFile, + &proxyHost, + &proxyUserName, + &proxyPassword, + proxyPort, + messageCount, + localProxyMode, + &payloadMessage); if (apiHandle.GetOrCreateStaticDefaultClientBootstrap()->LastError() != AWS_ERROR_SUCCESS) { @@ -134,111 +249,166 @@ int main(int argc, char *argv[]) exit(-1); } - /* - * In a real world application you probably don't want to enforce synchronous behavior - * but this is a sample console application, so we'll just do that with a condition variable. - */ - std::promise connectionCompletedPromise; - std::promise connectionClosedPromise; + /* Use a SecureTunnelBuilder to set up and build the secure tunnel client */ + SecureTunnelBuilder builder = SecureTunnelBuilder(allocator, accessToken.c_str(), localProxyMode, endpoint.c_str()); + + if (caFile.length() > 0) + { + builder.WithRootCa(caFile.c_str()); + } + + builder.WithClientToken(clientToken.c_str()); + + /* Add callbacks using the builder */ + + builder.WithOnConnectionSuccess([&](SecureTunnel *secureTunnel, const ConnectionSuccessEventData &eventData) { + if (eventData.connectionData->getServiceId1().has_value()) + { + /* If secure tunnel is using service ids, store one for future use */ + aws_byte_buf_clean_up(&m_serviceIdStorage); + AWS_ZERO_STRUCT(m_serviceIdStorage); + aws_byte_buf_init_copy_from_cursor( + &m_serviceIdStorage, allocator, eventData.connectionData->getServiceId1().value()); + m_serviceId = aws_byte_cursor_from_buf(&m_serviceIdStorage); + + fprintf( + stdout, + "Secure Tunnel connected with Service IDs '" PRInSTR "'", + AWS_BYTE_CURSOR_PRI(eventData.connectionData->getServiceId1().value())); + if (eventData.connectionData->getServiceId2().has_value()) + { + fprintf( + stdout, ", '" PRInSTR "'", AWS_BYTE_CURSOR_PRI(eventData.connectionData->getServiceId2().value())); + if (eventData.connectionData->getServiceId3().has_value()) + { + fprintf( + stdout, + ", '" PRInSTR "'", + AWS_BYTE_CURSOR_PRI(eventData.connectionData->getServiceId3().value())); + } + } + fprintf(stdout, "\n"); - /*********************** Callbacks ***************************/ - auto OnConnectionComplete = [&]() { - switch (localProxyMode) + /* Stream Start can only be called from Source Mode */ + if (localProxyMode == AWS_SECURE_TUNNELING_SOURCE_MODE) + { + fprintf( + stdout, + "Sending Stream Start request with service id:'" PRInSTR "'\n", + AWS_BYTE_CURSOR_PRI(eventData.connectionData->getServiceId1().value())); + secureTunnel->SendStreamStart(eventData.connectionData->getServiceId1().value()); + } + } + else { - case AWS_SECURE_TUNNELING_DESTINATION_MODE: - fprintf(stdout, "Connection Complete in Destination Mode\n"); - break; - case AWS_SECURE_TUNNELING_SOURCE_MODE: - connectionCompletedPromise.set_value(true); - fprintf(stdout, "Connection Complete in Source Mode\n"); + fprintf(stdout, "Secure Tunnel is not using Service Ids.\n"); + + /* Stream Start can only be called from Source Mode */ + if (localProxyMode == AWS_SECURE_TUNNELING_SOURCE_MODE) + { fprintf(stdout, "Sending Stream Start request\n"); secureTunnel->SendStreamStart(); - break; + } } - }; - auto OnConnectionShutdown = [&]() { + connectionCompletedPromise.set_value(true); + }); + + builder.WithOnConnectionFailure([&](SecureTunnel *secureTunnel, int errorCode) { + (void)secureTunnel; + fprintf(stdout, "Connection attempt failed with error code %d(%s)\n", errorCode, ErrorDebugString(errorCode)); + }); + + builder.WithOnConnectionShutdown([&]() { fprintf(stdout, "Connection Shutdown\n"); connectionClosedPromise.set_value(true); - }; + }); - auto OnSendDataComplete = [&](int error_code) { - switch (localProxyMode) + builder.WithOnMessageReceived([&](SecureTunnel *secureTunnel, const MessageReceivedEventData &eventData) { { - case AWS_SECURE_TUNNELING_DESTINATION_MODE: - if (!error_code) - { - fprintf(stdout, "Send Data Complete in Destination Mode\n"); - } - else - { - fprintf(stderr, "Send Data Failed: %s\n", ErrorDebugString(error_code)); - } + std::shared_ptr message = eventData.message; - break; - case AWS_SECURE_TUNNELING_SOURCE_MODE: - if (!error_code) - { - fprintf(stdout, "Send Data Complete in Source Mode\n"); - } - else - { - fprintf(stderr, "Send Data Failed: %s\n", ErrorDebugString(error_code)); - } - break; - } - }; + logMessage(message); + std::shared_ptr echoMessage; - auto OnDataReceive = [&](const struct aws_byte_buf &data) { - String receivedData = String((char *)data.buffer, data.len); - String returnMessage = "Echo:" + receivedData; + switch (localProxyMode) + { + case AWS_SECURE_TUNNELING_DESTINATION_MODE: - fprintf(stdout, "Received: \"%s\"\n", receivedData.c_str()); + echoMessage = std::make_shared(message->getPayload().value()); - switch (localProxyMode) - { - case AWS_SECURE_TUNNELING_DESTINATION_MODE: - fprintf(stdout, "Data Receive Complete in Destination\n"); - fprintf(stdout, "Sending response message:\"%s\"\n", returnMessage.c_str()); - secureTunnel->SendData(ByteCursorFromCString(returnMessage.c_str())); - if (isTest) - { - expectedMessageCount--; - if (expectedMessageCount <= 0) + /* Echo message on same service id received message came on */ + if (message->getServiceId().has_value()) { - exit(0); + echoMessage->withServiceId(message->getServiceId().value()); } - } - break; - case AWS_SECURE_TUNNELING_SOURCE_MODE: - fprintf(stdout, "Data Receive Complete in Source\n"); - break; + + secureTunnel->SendMessage(echoMessage); + + fprintf(stdout, "Sending Echo Message\n"); + + break; + case AWS_SECURE_TUNNELING_SOURCE_MODE: + + break; + } } - }; + }); - /* - * This only fires in Destination Mode - */ - auto OnStreamStart = [&]() { fprintf(stdout, "Stream Started in Destination Mode\n"); }; + builder.WithOnStreamStarted( + [&](SecureTunnel *secureTunnel, int errorCode, const StreamStartedEventData &eventData) { + (void)secureTunnel; + if (!errorCode) + { + std::shared_ptr streamStartedData = eventData.streamStartedData; + + if (streamStartedData->getServiceId().has_value()) + { + fprintf( + stdout, + "Stream started on service id: '" PRInSTR "'\n", + AWS_BYTE_CURSOR_PRI(streamStartedData->getServiceId().value())); + } + else + { + fprintf(stdout, "Stream started using V1 Protocol"); + } + } + }); - auto OnStreamReset = [&]() { fprintf(stdout, "Stream Reset\n"); }; + builder.WithOnStreamStopped([&](SecureTunnel *secureTunnel, const StreamStoppedEventData &eventData) { + (void)secureTunnel; + std::shared_ptr streamStoppedData = eventData.streamStoppedData; - auto OnSessionReset = [&]() { fprintf(stdout, "Session Reset\n"); }; + if (streamStoppedData->getServiceId().has_value()) + { + fprintf( + stdout, + "Stream stopped on service id: '" PRInSTR "'\n", + AWS_BYTE_CURSOR_PRI(streamStoppedData->getServiceId().value())); + } + else + { + fprintf(stdout, "Stream stopped using V1 Protocol"); + } + }); - /*********************** Proxy Connection Setup ***************************/ - /* - * Setup HttpClientCommectionProxyOptions for connecting through a proxy before the Secure Tunnel - */ + builder.WithOnStopped([&](SecureTunnel *secureTunnel) { + (void)secureTunnel; + fprintf(stdout, "Secure Tunnel has entered Stopped State\n"); + clientStoppedPromise.set_value(true); + }); + //*********************************************************************************************************************** + /* Proxy Options */ + //*********************************************************************************************************************** if (proxyHost.length() > 0) { auto proxyOptions = Aws::Crt::Http::HttpClientConnectionProxyOptions(); proxyOptions.HostName = proxyHost.c_str(); proxyOptions.Port = proxyPort; - /* - * Set up Proxy Strategy if a user name and password is provided - */ + /* Set up Proxy Strategy if a user name and password is provided */ if (proxyUserName.length() > 0 || proxyPassword.length() > 0) { fprintf(stdout, "Creating proxy strategy\n"); @@ -247,7 +417,7 @@ int main(int argc, char *argv[]) basicAuthConfig.Username = proxyUserName.c_str(); basicAuthConfig.Password = proxyPassword.c_str(); proxyOptions.ProxyStrategy = - Aws::Crt::Http::HttpProxyStrategy::CreateBasicHttpProxyStrategy(basicAuthConfig, Aws::Crt::g_allocator); + Aws::Crt::Http::HttpProxyStrategy::CreateBasicHttpProxyStrategy(basicAuthConfig, allocator); proxyOptions.AuthType = Aws::Crt::Http::AwsHttpProxyAuthenticationType::Basic; } else @@ -255,88 +425,93 @@ int main(int argc, char *argv[]) proxyOptions.AuthType = Aws::Crt::Http::AwsHttpProxyAuthenticationType::None; } - /*********************** Secure Tunnel Setup ***************************/ - /* - * Create a new SecureTunnel using the SecureTunnelBuilder - */ - secureTunnel = - SecureTunnelBuilder( - Aws::Crt::g_allocator, SocketOptions(), accessToken.c_str(), localProxyMode, endpoint.c_str()) - .WithRootCa(caFile.c_str()) - .WithHttpClientConnectionProxyOptions(proxyOptions) - .WithOnConnectionComplete(OnConnectionComplete) - .WithOnConnectionShutdown(OnConnectionShutdown) - .WithOnSendDataComplete(OnSendDataComplete) - .WithOnDataReceive(OnDataReceive) - .WithOnStreamStart(OnStreamStart) - .WithOnStreamReset(OnStreamReset) - .WithOnSessionReset(OnSessionReset) - .Build(); - } - else - { - /*********************** Secure Tunnel Setup ***************************/ - /* - * Create a new SecureTunnel using the SecureTunnelBuilder - */ - secureTunnel = - SecureTunnelBuilder( - Aws::Crt::g_allocator, SocketOptions(), accessToken.c_str(), localProxyMode, endpoint.c_str()) - .WithRootCa(caFile.c_str()) - .WithOnConnectionComplete(OnConnectionComplete) - .WithOnConnectionShutdown(OnConnectionShutdown) - .WithOnSendDataComplete(OnSendDataComplete) - .WithOnDataReceive(OnDataReceive) - .WithOnStreamStart(OnStreamStart) - .WithOnStreamReset(OnStreamReset) - .WithOnSessionReset(OnSessionReset) - .Build(); + /* Add proxy options to the builder */ + builder.WithHttpClientConnectionProxyOptions(proxyOptions); } + /* Create Secure Tunnel using the options set with the builder */ + std::shared_ptr secureTunnel = builder.Build(); + if (!secureTunnel) { fprintf(stderr, "Secure Tunnel Creation failed: %s\n", ErrorDebugString(LastError())); exit(-1); } - if (secureTunnel->Connect() == AWS_OP_ERR) + /* Set the Secure Tunnel Client to desire a connected state */ + if (secureTunnel->Start()) { fprintf(stderr, "Secure Tunnel Connect call failed: %s\n", ErrorDebugString(LastError())); exit(-1); } - int messageCount = 0; + + bool keepRunning = true; + uint16_t messagesSent = 0; if (connectionCompletedPromise.get_future().get()) { - while (true) - { - std::this_thread::sleep_for(3000ms); + std::this_thread::sleep_for(1000ms); + /* + * In Destination mode the Secure Tunnel Client will remain open and echo messages that come in. + * In Source mode the Secure Tunnel Client will send 4 messages and then disconnect and terminate. + */ + while (keepRunning) + { if (localProxyMode == AWS_SECURE_TUNNELING_SOURCE_MODE) { - messageCount++; - String toSend = (std::to_string(messageCount) + ": " + message.c_str()).c_str(); + messagesSent++; + String toSend = (std::to_string(messagesSent) + ": " + payloadMessage.c_str()).c_str(); - if (!secureTunnel->SendData(ByteCursorFromCString(toSend.c_str()))) + if (messagesSent <= messageCount) { - fprintf(stdout, "Sending Message:\"%s\"\n", toSend.c_str()); - if (messageCount >= 5) + std::shared_ptr message = std::make_shared(ByteCursorFromCString(toSend.c_str())); + + /* If the secure tunnel has service ids, we will use one for our messages. */ + if (m_serviceId.has_value()) { - fprintf(stdout, "Closing Connection\n"); - if (secureTunnel->Close() == AWS_OP_ERR) - { - fprintf(stderr, "Secure Tunnel Close call failed: %s\n", ErrorDebugString(LastError())); - exit(-1); - } + message->withServiceId(m_serviceId.value()); } + + secureTunnel->SendMessage(message); + + fprintf(stdout, "Sending Message:\"%s\"\n", toSend.c_str()); + + std::this_thread::sleep_for(2000ms); } - else if (connectionClosedPromise.get_future().get()) + else { - fprintf(stdout, "Sample Complete"); - - exit(0); + keepRunning = false; } } } } + + std::this_thread::sleep_for(3000ms); + + fprintf(stdout, "Closing Connection\n"); + /* Set the Secure Tunnel Client to desire a stopped state */ + if (secureTunnel->Stop() == AWS_OP_ERR) + { + fprintf(stderr, "Secure Tunnel Close call failed: %s\n", ErrorDebugString(LastError())); + exit(-1); + } + + if (connectionClosedPromise.get_future().get()) + { + fprintf(stdout, "Secure Tunnel Connection Closed\n"); + } + + /* The Secure Tunnel Client at this point will report they are stopped and can be safely removed. */ + if (clientStoppedPromise.get_future().get()) + { + secureTunnel = nullptr; + } + + fprintf(stdout, "Secure Tunnel Sample Completed\n"); + + /* Clean Up */ + aws_byte_buf_clean_up(&m_serviceIdStorage); + + return 0; } diff --git a/secure_tunneling/CMakeLists.txt b/secure_tunneling/CMakeLists.txt index eaa9d7dc7..e913c7316 100644 --- a/secure_tunneling/CMakeLists.txt +++ b/secure_tunneling/CMakeLists.txt @@ -137,6 +137,3 @@ install(FILES "${CMAKE_CURRENT_BINARY_DIR}/iotsecuretunneling-cpp-config.cmake" DESTINATION "${CMAKE_INSTALL_LIBDIR}/IotSecureTunneling-cpp/cmake/" COMPONENT Development) -if (BUILD_TESTING) - add_subdirectory(tests) -endif() diff --git a/secure_tunneling/include/aws/iotsecuretunneling/SecureTunnel.h b/secure_tunneling/include/aws/iotsecuretunneling/SecureTunnel.h index 56dde464e..873805d93 100644 --- a/secure_tunneling/include/aws/iotsecuretunneling/SecureTunnel.h +++ b/secure_tunneling/include/aws/iotsecuretunneling/SecureTunnel.h @@ -16,17 +16,344 @@ namespace Aws { namespace Iotsecuretunneling { + /** + * Data model for Secure Tunnel messages. + */ + class AWS_IOTSECURETUNNELING_API Message + { + public: + Message( + const aws_secure_tunnel_message_view &raw_options, + Crt::Allocator *allocator = Crt::ApiAllocator()) noexcept; + Message(Crt::Allocator *allocator = Crt::ApiAllocator()) noexcept; + Message(Crt::ByteCursor payload, Crt::Allocator *allocator = Crt::ApiAllocator()) noexcept; + Message( + Crt::ByteCursor serviceId, + Crt::ByteCursor payload, + Crt::Allocator *allocator = Crt::ApiAllocator()) noexcept; + + /** + * Sets the service id for the secure tunnel message. + * + * @param serviceId The service id for the secure tunnel message. + * @return The Message Object after setting the payload. + */ + Message &withServiceId(Crt::ByteCursor serviceId) noexcept; + + /** + * Sets the payload for the secure tunnel message. + * + * @param payload The payload for the secure tunnel message. + * @return The Message Object after setting the payload. + */ + Message &withPayload(Crt::ByteCursor payload) noexcept; + + bool initializeRawOptions(aws_secure_tunnel_message_view &raw_options) noexcept; + + /** + * The service id of the secure tunnel message. + * + * @return The service id of the secure tunnel message. + */ + const Crt::Optional &getServiceId() const noexcept; + + /** + * The payload of the secure tunnel message. + * + * @return The payload of the secure tunnel message. + */ + const Crt::Optional &getPayload() const noexcept; + + virtual ~Message(); + /* Do not allow direct copy or move */ + Message(const Message &) = delete; + Message(Message &&) noexcept = delete; + Message &operator=(const Message &) = delete; + Message &operator=(Message &&) noexcept = delete; + + private: + Crt::Allocator *m_allocator; + + /** + * The service id used for multiplexing. + * + * If left empty, a V1 protocol message is assumed. + */ + Crt::Optional m_serviceId; + + /** + * The payload of the secure tunnel message. + */ + Crt::Optional m_payload; + + /////////////////////////////////////////////////////////////////////////// + // Underlying data storage for internal use + /////////////////////////////////////////////////////////////////////////// + Crt::ByteBuf m_payloadStorage; + Crt::ByteBuf m_serviceIdStorage; + }; + + /** + * The data returned when a message is received on the secure tunnel. + */ + struct AWS_IOTSECURETUNNELING_API MessageReceivedEventData + { + MessageReceivedEventData() : message(nullptr) {} + std::shared_ptr message; + }; + + /** + * Data model for Secure Tunnel connection view. + */ + class AWS_IOTSECURETUNNELING_API ConnectionData + { + public: + ConnectionData( + const aws_secure_tunnel_connection_view &raw_options, + Crt::Allocator *allocator = Crt::ApiAllocator()) noexcept; + + /** + * Service id 1 of the secure tunnel. + * + * @return Service id 1 of the secure tunnel. + */ + const Crt::Optional &getServiceId1() const noexcept; + + /** + * Service id 2 of the secure tunnel. + * + * @return Service id 2 of the secure tunnel. + */ + const Crt::Optional &getServiceId2() const noexcept; + + /** + * Service id 3 of the secure tunnel. + * + * @return Service id 3 of the secure tunnel. + */ + const Crt::Optional &getServiceId3() const noexcept; + + virtual ~ConnectionData(); + /* Do not allow direct copy or move */ + ConnectionData(const ConnectionData &) = delete; + ConnectionData(ConnectionData &&) noexcept = delete; + ConnectionData &operator=(const ConnectionData &) = delete; + ConnectionData &operator=(ConnectionData &&) noexcept = delete; + + private: + Crt::Allocator *m_allocator; + + /** + * Service id 1 used for multiplexing. + * + * If left empty, a V1 protocol message is assumed. + */ + Crt::Optional m_serviceId1; + + /** + * Service id 2 used for multiplexing. + */ + Crt::Optional m_serviceId2; + + /** + * Service id 2 used for multiplexing. + */ + Crt::Optional m_serviceId3; + + /////////////////////////////////////////////////////////////////////////// + // Underlying data storage for internal use + /////////////////////////////////////////////////////////////////////////// + Crt::ByteBuf m_serviceId1Storage; + Crt::ByteBuf m_serviceId2Storage; + Crt::ByteBuf m_serviceId3Storage; + }; + + /** + * The data returned when a connection with secure tunnel service is established. + */ + struct AWS_IOTSECURETUNNELING_API ConnectionSuccessEventData + { + ConnectionSuccessEventData() : connectionData(nullptr) {} + std::shared_ptr connectionData; + }; + + /** + * Data model for started Secure Tunnel streams. + */ + class AWS_IOTSECURETUNNELING_API StreamStartedData + { + public: + StreamStartedData( + const aws_secure_tunnel_message_view &raw_options, + Crt::Allocator *allocator = Crt::ApiAllocator()) noexcept; + + /** + * Service id of the started stream. + * + * @return Service id of the started stream. + */ + const Crt::Optional &getServiceId() const noexcept; + + virtual ~StreamStartedData(); + /* Do not allow direct copy or move */ + StreamStartedData(const StreamStartedData &) = delete; + StreamStartedData(StreamStartedData &&) noexcept = delete; + StreamStartedData &operator=(const StreamStartedData &) = delete; + StreamStartedData &operator=(StreamStartedData &&) noexcept = delete; + + private: + Crt::Allocator *m_allocator; + + /** + * Service id of started stream. + * + * If left empty, a V1 protocolstream is assumed. + */ + Crt::Optional m_serviceId; + + /////////////////////////////////////////////////////////////////////////// + // Underlying data storage for internal use + /////////////////////////////////////////////////////////////////////////// + Crt::ByteBuf m_serviceIdStorage; + }; + + /** + * The data returned when a stream is started on the Secure Tunnel. + */ + struct AWS_IOTSECURETUNNELING_API StreamStartedEventData + { + StreamStartedEventData() : streamStartedData(nullptr) {} + std::shared_ptr streamStartedData; + }; + + /** + * Data model for started Secure Tunnel streams. + */ + class AWS_IOTSECURETUNNELING_API StreamStoppedData + { + public: + StreamStoppedData( + const aws_secure_tunnel_message_view &raw_options, + Crt::Allocator *allocator = Crt::ApiAllocator()) noexcept; + + /** + * Service id of the stopped stream. + * + * @return Service id of the stopped stream. + */ + const Crt::Optional &getServiceId() const noexcept; + + /** + * Stream id of the stopped stream. + */ + + virtual ~StreamStoppedData(); + /* Do not allow direct copy or move */ + StreamStoppedData(const StreamStoppedData &) = delete; + StreamStoppedData(StreamStoppedData &&) noexcept = delete; + StreamStoppedData &operator=(const StreamStoppedData &) = delete; + StreamStoppedData &operator=(StreamStoppedData &&) noexcept = delete; + + private: + Crt::Allocator *m_allocator; + + /** + * Service id of started stream. + * + * If left empty, a V1 protocolstream is assumed. + */ + Crt::Optional m_serviceId; + + /////////////////////////////////////////////////////////////////////////// + // Underlying data storage for internal use + /////////////////////////////////////////////////////////////////////////// + Crt::ByteBuf m_serviceIdStorage; + }; + + /** + * The data returned when a stream is closed on the Secure Tunnel. + */ + struct AWS_IOTSECURETUNNELING_API StreamStoppedEventData + { + StreamStoppedEventData() : streamStoppedData(nullptr) {} + std::shared_ptr streamStoppedData; + }; + class SecureTunnel; // Client callback type definitions - using OnConnectionComplete = std::function; + + /** + * Type signature of the callback invoked when connection is established with the secure tunnel service and + * available service ids are returned. + */ + using OnConnectionSuccess = std::function; + + /** + * Type signature of the callback invoked when connection is established with the secure tunnel service and + * available service ids are returned. + */ + using OnConnectionFailure = std::function; + + /** + * Type signature of the callback invoked when connection is shutdown. + */ using OnConnectionShutdown = std::function; + + /** + * Type signature of the callback invoked when data has been sent through the secure tunnel connection. + */ using OnSendDataComplete = std::function; + + /** + * Type signature of the callback invoked when a message is received through the secure tunnel connection. + */ + using OnMessageReceived = std::function; + + /** + * Type signature of the callback invoked when a stream has been started with a source through the secure tunnel + * connection. + */ + using OnStreamStarted = + std::function; + + /** + * Type signature of the callback invoked when a stream has been closed + */ + + using OnStreamStopped = std::function; + + /** + * Type signature of the callback invoked when the secure tunnel receives a Session Reset. + */ + using OnSessionReset = std::function; + + /** + * Type signature of the callback invoked when the secure tunnel completes transitioning to a stopped state. + */ + using OnStopped = std::function; + + /** + * Deprecated - OnConnectionSuccess and OnConnectionFailure + */ + using OnConnectionComplete = std::function; + /** + * Deprecated - Use OnMessageReceived + */ using OnDataReceive = std::function; + /** + * Deprecated - Use OnStreamStarted + */ using OnStreamStart = std::function; + + /** + * Deprecated - Use OnStreamStopped + */ using OnStreamReset = std::function; - using OnSessionReset = std::function; + /** + * Represents a unique configuration for a secure tunnel + */ class AWS_IOTSECURETUNNELING_API SecureTunnelBuilder final { public: @@ -34,10 +361,20 @@ namespace Aws * Constructor arguments are the minimum required to create a secure tunnel */ SecureTunnelBuilder( - Crt::Allocator *allocator, // Should out live this object - Aws::Crt::Io::ClientBootstrap &clientBootstrap, // Should out live this object - const Aws::Crt::Io::SocketOptions &socketOptions, // Make a copy and save in this object - const std::string &accessToken, // Make a copy and save in this object + Crt::Allocator *allocator, // Should out live this object + Crt::Io::ClientBootstrap &clientBootstrap, // Should out live this object + const Crt::Io::SocketOptions &socketOptions, // Make a copy and save in this object + const std::string &accessToken, // Make a copy and save in this object + aws_secure_tunneling_local_proxy_mode localProxyMode, + const std::string &endpointHost); // Make a copy and save in this object + + /** + * Constructor arguments are the minimum required to create a secure tunnel + */ + SecureTunnelBuilder( + Crt::Allocator *allocator, // Should out live this object + const Crt::Io::SocketOptions &socketOptions, // Make a copy and save in this object + const std::string &accessToken, // Make a copy and save in this object aws_secure_tunneling_local_proxy_mode localProxyMode, const std::string &endpointHost); // Make a copy and save in this object @@ -45,61 +382,281 @@ namespace Aws * Constructor arguments are the minimum required to create a secure tunnel */ SecureTunnelBuilder( - Crt::Allocator *allocator, // Should out live this object - const Aws::Crt::Io::SocketOptions &socketOptions, // Make a copy and save in this object - const std::string &accessToken, // Make a copy and save in this object + Crt::Allocator *allocator, // Should out live this object + const std::string &accessToken, // Make a copy and save in this object aws_secure_tunneling_local_proxy_mode localProxyMode, const std::string &endpointHost); // Make a copy and save in this object + /* Optional members */ /** - * Optional members + * Sets rootCA to be used for this secure tunnel connection overriding the default trust store. + * + * @param rootCa string to use as rootCA for secure tunnel connection + * + * @return this builder object */ SecureTunnelBuilder &WithRootCa(const std::string &rootCa); + + /** + * Sets Client Token to a specified value rather than allowing the secure tunnel to auto-generate one. + * + * @param clientToken string to use as unique client token for secure tunnel connection + * + * @return this builder object + */ + SecureTunnelBuilder &WithClientToken(const std::string &clientToken); + + /** + * Sets http proxy options. + * + * @param httpClientConnectionProxyOptions http proxy configuration for connection establishment + * + * @return this builder object + */ SecureTunnelBuilder &WithHttpClientConnectionProxyOptions( - const Aws::Crt::Http::HttpClientConnectionProxyOptions &httpClientConnectionProxyOptions); - SecureTunnelBuilder &WithOnConnectionComplete(OnConnectionComplete onConnectionComplete); + const Crt::Http::HttpClientConnectionProxyOptions &httpClientConnectionProxyOptions); + + /** + * Setup callback handler trigged when an Secure Tunnel establishes a connection and receives available + * service ids. + * + * @param onConnectionSuccess + * + * @return this builder object + */ + SecureTunnelBuilder &WithOnConnectionSuccess(OnConnectionSuccess onConnectionSuccess); + + /** + * Setup callback handler trigged when an Secure Tunnel fails a connection attempt. + * + * @param onConnectionFailure + * + * @return this builder object + */ + SecureTunnelBuilder &WithOnConnectionFailure(OnConnectionFailure onConnectionFailure); + + /** + * Setup callback handler trigged when an Secure Tunnel shuts down connection to the secure tunnel service. + * + * @param onConnectionShutdown + * + * @return this builder object + */ SecureTunnelBuilder &WithOnConnectionShutdown(OnConnectionShutdown onConnectionShutdown); + + /** + * Setup callback handler trigged when an Secure Tunnel completes sending data to the secure tunnel service. + * + * @param onSendDataComplete + * + * @return this builder object + */ SecureTunnelBuilder &WithOnSendDataComplete(OnSendDataComplete onSendDataComplete); - SecureTunnelBuilder &WithOnDataReceive(OnDataReceive onDataReceive); - SecureTunnelBuilder &WithOnStreamStart(OnStreamStart onStreamStart); + + /** + * Setup callback handler trigged when an Secure Tunnel receives a Message through the secure tunnel + * service. + * + * @param onMessageReceived + * + * @return this builder object + */ + SecureTunnelBuilder &WithOnMessageReceived(OnMessageReceived onMessageReceived); + + /** + * Setup callback handler trigged when an Secure Tunnel starts a stream with a source through the secure + * tunnel service. + * + * @param onStreamStarted + * + * @return this builder object + */ + SecureTunnelBuilder &WithOnStreamStarted(OnStreamStarted onStreamStarted); + + /** + * Setup callback handler trigged when an Secure Tunnel stops a stream. + * + * @param onStreamStopped + * + * @return this builder object + */ + SecureTunnelBuilder &WithOnStreamStopped(OnStreamStopped onStreamStopped); + + /** + * Setup callback handler trigged when an Secure Tunnel receives a stream reset. + * + * @param onStreamReset + * + * @return this builder object + */ SecureTunnelBuilder &WithOnStreamReset(OnStreamReset onStreamReset); + + /** + * Setup callback handler trigged when an Secure Tunnel receives a session reset from the secure tunnel + * service. + * + * @param onSessionReset + * + * @return this builder object + */ SecureTunnelBuilder &WithOnSessionReset(OnSessionReset onSessionReset); + /** + * Setup callback handler trigged when an Secure Tunnel completes entering a stopped state + * + * @param onStopped + * + * @return this builder object + */ + SecureTunnelBuilder &WithOnStopped(OnStopped onStopped); + + /** + * Deprecated - Use WithOnMessageReceived() + */ + SecureTunnelBuilder &WithOnDataReceive(OnDataReceive onDataReceive); + /** + * Deprecated - Use WithOnConnectionSuccess() and WithOnConnectionFailure() + */ + SecureTunnelBuilder &WithOnConnectionComplete(OnConnectionComplete onConnectionComplete); + /** + * Deprecated - Use WithOnStreamStarted + */ + SecureTunnelBuilder &WithOnStreamStart(OnStreamStart onStreamStart); + /** * Will return a shared pointer to a new SecureTunnel that countains a * new aws_secure_tunnel that is generated using the set members of SecureTunnelBuilder. - * On failure, the shared_ptr will contain a nullptr, call Aws::Crt::LastErrorOrUnknown(); to get the reason + * On failure, the shared_ptr will contain a nullptr, call Crt::LastErrorOrUnknown(); to get the reason * for failure. */ std::shared_ptr Build() noexcept; private: + /* Required Memebers */ + + Crt::Allocator *m_allocator; + /** - * Required members + * Client bootstrap to use. In almost all cases, this can be left undefined. + */ + Crt::Io::ClientBootstrap *m_clientBootstrap; + + /** + * Controls socket properties of the underlying connections made by the secure tunnel. Leave undefined to + * use defaults (no TCP keep alive, 10 second socket timeout). + */ + Crt::Io::SocketOptions m_socketOptions; + + /** + * Token used to establish a WebSocket connection with the secure tunnel service. This token is one time use + * and must be rotated to establish a new connection to the secure tunnel unless using a unique client + * token. */ - Crt::Allocator *m_allocator; - Aws::Crt::Io::ClientBootstrap *m_clientBootstrap; - Aws::Crt::Io::SocketOptions m_socketOptions; std::string m_accessToken; + + /** + * Proxy mode to use. + */ aws_secure_tunneling_local_proxy_mode m_localProxyMode; + + /** + * AWS Secure Tunnel endpoint to connect to. + */ std::string m_endpointHost; + /* Optional members */ /** - * Optional members + * Client token is used to reconnect to a secure tunnel after initial connection. If this is not set by the + * user, one will be automatically generated and used to maintain a connection as long as the secure tunnel + * has the desired state of CONNECTED. + */ + std::string m_clientToken; + + /** + * If set, this will be used to override the default trust store. */ std::string m_rootCa; + + /** + * If set, http proxy configuration will be used for connection establishment + */ Crt::Optional m_httpClientConnectionProxyOptions; + /* Callbacks */ /** - * Callbacks + * Callback handler trigged when secure tunnel establishes connection with secure tunnel service and + * receives available service ids. + */ + OnConnectionSuccess m_OnConnectionSuccess; + + /* Callbacks */ + /** + * Callback handler trigged when secure tunnel establishes fails a connection attempt with secure tunnel + * service. + */ + OnConnectionFailure m_OnConnectionFailure; + + /** + * Callback handler trigged when secure tunnel connection to secure tunnel service is closed. */ - OnConnectionComplete m_OnConnectionComplete; OnConnectionShutdown m_OnConnectionShutdown; + + /** + * Callback handler trigged when secure tunnel completes sending data to the secure tunnel service. + */ OnSendDataComplete m_OnSendDataComplete; + + /** + * Callback handler trigged when secure tunnel receives a message from the secure tunnel service. + * + * @param SecureTunnel: The shared secure tunnel + * @param MessageReceivedEventData: Data received + */ + OnMessageReceived m_OnMessageReceived; + + /** + * Callback handler trigged when secure tunnel receives a stream start from a source device. + * + * @param SecureTunnel: The shared secure tunnel + * @param int: error code + * @param StreamStartedEventData: Stream Started data + */ + OnStreamStarted m_OnStreamStarted; + + /** + * Callback handler trigged when secure tunnel receives a stream reset. + * + * @param SecureTunnel: The shared secure tunnel + * @param StreamStoppedEventData: Stream Started data + */ + OnStreamStopped m_OnStreamStopped; + + /** + * Callback handler trigged when secure tunnel receives a session reset from the secure tunnel service. + */ + OnSessionReset m_OnSessionReset; + + /** + * Callback handler trigged when secure tunnel completes transition to stopped state. + */ + OnStopped m_OnStopped; + + /** + * Deprecated - Use m_OnConnectionSuccess and m_OnConnectionFailure + */ + OnConnectionComplete m_OnConnectionComplete; + /** + * Deprecated - Use m_OnMessageReceived + */ OnDataReceive m_OnDataReceive; + /** + * Deprecated - Use m_OnStreamStarted + */ OnStreamStart m_OnStreamStart; + /** + * Deprecated - Use m_OnStreamStopped + */ OnStreamReset m_OnStreamReset; - OnSessionReset m_OnSessionReset; friend class SecureTunnel; }; @@ -108,9 +665,9 @@ namespace Aws { public: SecureTunnel( - Crt::Allocator *allocator, // Should out live this object - Aws::Crt::Io::ClientBootstrap *clientBootstrap, // Should out live this object - const Aws::Crt::Io::SocketOptions &socketOptions, // Make a copy and save in this object + Crt::Allocator *allocator, // Should out live this object + Crt::Io::ClientBootstrap *clientBootstrap, // Should out live this object + const Crt::Io::SocketOptions &socketOptions, // Make a copy and save in this object const std::string &accessToken, // Make a copy and save in this object aws_secure_tunneling_local_proxy_mode localProxyMode, @@ -126,8 +683,8 @@ namespace Aws OnSessionReset onSessionReset); SecureTunnel( - Crt::Allocator *allocator, // Should out live this object - const Aws::Crt::Io::SocketOptions &socketOptions, // Make a copy and save in this object + Crt::Allocator *allocator, // Should out live this object + const Crt::Io::SocketOptions &socketOptions, // Make a copy and save in this object const std::string &accessToken, // Make a copy and save in this object aws_secure_tunneling_local_proxy_mode localProxyMode, @@ -141,30 +698,95 @@ namespace Aws OnStreamStart onStreamStart, OnStreamReset onStreamReset, OnSessionReset onSessionReset); - SecureTunnel(const SecureTunnel &) = delete; - SecureTunnel(SecureTunnel &&) noexcept; virtual ~SecureTunnel(); - + SecureTunnel(const SecureTunnel &) = delete; SecureTunnel &operator=(const SecureTunnel &) = delete; + + SecureTunnel(SecureTunnel &&) noexcept; SecureTunnel &operator=(SecureTunnel &&) noexcept; bool IsValid(); + /** + * Notifies the secure tunnel that you want it to attempt to connect to the configured endpoint. + * The secure tunnel will attempt to stay connected and attempt to reconnect if disconnected. + * + * @return success/failure in the synchronous logic that kicks off the start process + */ + int Start(); + + /** + * Notifies the secure tunnel that you want it to transition to the stopped state, disconnecting any + * existing connection and stopping subsequent reconnect attempts. + * + * @return success/failure in the synchronous logic that kicks off the stop process + */ + int Stop(); + + /** + * Tells the secure tunnel to attempt to send a Message + * + * @param messageOptions: Message to send to the secure tunnel service. + * + * @return success/failure in the synchronous logic that kicks off the Send Message operation + */ + int SendMessage(std::shared_ptr messageOptions) noexcept; + + /* SOURCE MODE ONLY */ + /** + * Notifies the secure tunnel that you want to start a stream with the Destination device. This will result + * in a V1 stream. + * + * @return success/failure in the synchronous logic that kicks off the Stream Start operation + */ + int SendStreamStart(); + + /** + * Notifies the secure tunnel that you want to start a stream with the Destination device on a specific + * service id. This will result in a V2 stream. + * + * @param serviceId: The Service Id to start a stream on. + * + * @return success/failure in the synchronous logic that kicks off the Stream Start operation + */ + int SendStreamStart(std::string serviceId); + + /** + * Notifies the secure tunnel that you want to start a stream with the Destination device on a specific + * service id. This will result in a V2 stream. + * + * @param serviceId: The Service Id to start a stream on. + * + * @return success/failure in the synchronous logic that kicks off the Stream Start operation + */ + int SendStreamStart(Crt::ByteCursor serviceId); + + aws_secure_tunnel *GetUnderlyingHandle(); + + /** + * Deprecated - use Start() + */ int Connect(); + /** + * Deprecated - Use Stop() + */ int Close(); + /** + * Deprecated - Use Stop() + */ void Shutdown(); + /** + * Deprecated - Use SendMessage() + */ int SendData(const Crt::ByteCursor &data); - int SendStreamStart(); - + /* Should not be exposed. Under the hood only operation. */ int SendStreamReset(); - aws_secure_tunnel *GetUnderlyingHandle(); - private: /** * This constructor is used with SecureTunnelBuilder and should be modified when members are added or @@ -172,46 +794,121 @@ namespace Aws */ SecureTunnel( Crt::Allocator *allocator, - Aws::Crt::Io::ClientBootstrap *clientBootstrap, - const Aws::Crt::Io::SocketOptions &socketOptions, + Crt::Io::ClientBootstrap *clientBootstrap, + const Crt::Io::SocketOptions &socketOptions, const std::string &accessToken, + const std::string &clientToken, aws_secure_tunneling_local_proxy_mode localProxyMode, const std::string &endpointHost, const std::string &rootCa, - Aws::Crt::Http::HttpClientConnectionProxyOptions *httpClientConnectionProxyOptions, + Crt::Http::HttpClientConnectionProxyOptions *httpClientConnectionProxyOptions, + OnConnectionSuccess onConnectionSuccess, + OnConnectionFailure onConnectionFailure, OnConnectionComplete onConnectionComplete, OnConnectionShutdown onConnectionShutdown, OnSendDataComplete onSendDataComplete, + OnMessageReceived onMessageReceived, OnDataReceive onDataReceive, + OnStreamStarted onStreamStarted, OnStreamStart onStreamStart, + OnStreamStopped onStreamStopped, OnStreamReset onStreamReset, - OnSessionReset onSessionReset); - - // aws-c-iot callbacks - static void s_OnConnectionComplete(void *user_data); - static void s_OnConnectionShutdown(void *user_data); + OnSessionReset onSessionReset, + OnStopped onStopped); + + /* Static Callbacks */ + static void s_OnMessageReceived(const struct aws_secure_tunnel_message_view *message, void *user_data); + static void s_OnConnectionComplete( + const struct aws_secure_tunnel_connection_view *connection, + int error_code, + void *user_data); + static void s_OnConnectionFailure(int error_code, void *user_data); + static void s_OnConnectionShutdown(int error_code, void *user_data); static void s_OnSendDataComplete(int error_code, void *user_data); - static void s_OnDataReceive(const struct aws_byte_buf *data, void *user_data); - static void s_OnStreamStart(void *user_data); - static void s_OnStreamReset(void *user_data); + static void s_OnStreamStopped( + const struct aws_secure_tunnel_message_view *message, + int error_code, + void *user_data); static void s_OnSessionReset(void *user_data); + static void s_OnStopped(void *user_data); static void s_OnTerminationComplete(void *user_data); + static void s_OnStreamStarted( + const struct aws_secure_tunnel_message_view *message, + int error_code, + void *user_data); void OnTerminationComplete(); - // Client callbacks - OnConnectionComplete m_OnConnectionComplete; + /** + * Callback handler trigged when secure tunnel receives a Message. + */ + OnMessageReceived m_OnMessageReceived; + + /** + * Callback handler trigged when secure tunnel establishes connection with the secure tunnel service and + * receives service ids. + */ + OnConnectionSuccess m_OnConnectionSuccess; + + /** + * Callback handler trigged when secure tunnel fails a connection attempt with the secure tunnel service. + */ + OnConnectionFailure m_OnConnectionFailure; + + /** + * Callback handler trigged when secure tunnel shuts down connection. + */ OnConnectionShutdown m_OnConnectionShutdown; + + /** + * Callback handler trigged when secure tunnel completes sending data to the secure tunnel service. + */ OnSendDataComplete m_OnSendDataComplete; + + /** + * Callback handler trigged when secure tunnel starts a stream with a source device through the secure + * tunnel service. + */ + OnStreamStarted m_OnStreamStarted; + + /** + * Callback handler trigged when secure tunnel closes a stream + */ + OnStreamStopped m_OnStreamStopped; + + /** + * Callback handler trigged when secure tunnel receives a session reset from the secure tunnel service. + */ + OnSessionReset m_OnSessionReset; + + /** + * Callback handler trigged when secure tunnel finishes entering a stopped state. + */ + OnStopped m_OnStopped; + + aws_secure_tunnel *m_secure_tunnel; + Crt::Allocator *m_allocator; + + /** + * Deprecated - m_OnMessageReceived + */ OnDataReceive m_OnDataReceive; + /** + * Deprecated - Use m_OnConnectionSuccess and m_OnConnectionFailure + */ + OnConnectionComplete m_OnConnectionComplete; + /** + * Deprecated - Use m_OnStreamStarted + */ OnStreamStart m_OnStreamStart; + /** + * Deprecated - Use m_OnStreamStopped + */ OnStreamReset m_OnStreamReset; - OnSessionReset m_OnSessionReset; - aws_secure_tunnel *m_secure_tunnel; - std::promise m_TerminationComplete; + std::shared_ptr m_selfRef; friend class SecureTunnelBuilder; }; diff --git a/secure_tunneling/source/SecureTunnel.cpp b/secure_tunneling/source/SecureTunnel.cpp index 578bc8c7b..b76d182e9 100644 --- a/secure_tunneling/source/SecureTunnel.cpp +++ b/secure_tunneling/source/SecureTunnel.cpp @@ -10,6 +10,195 @@ namespace Aws { namespace Iotsecuretunneling { + void setPacketByteBufOptional( + Crt::Optional &optional, + Crt::ByteBuf &optionalStorage, + Crt::Allocator *allocator, + const Crt::ByteCursor *value) + { + aws_byte_buf_clean_up(&optionalStorage); + AWS_ZERO_STRUCT(optionalStorage); + if (value != nullptr) + { + aws_byte_buf_init_copy_from_cursor(&optionalStorage, allocator, *value); + optional = aws_byte_cursor_from_buf(&optionalStorage); + } + else + { + optional.reset(); + } + } + + void setPacketStringOptional( + Crt::Optional &optional, + Crt::String &optionalStorage, + const aws_byte_cursor *value) + { + if (value != nullptr) + { + optionalStorage = Crt::String((const char *)value->ptr, value->len); + struct aws_byte_cursor optional_cursor; + optional_cursor.ptr = (uint8_t *)optionalStorage.c_str(); + optional_cursor.len = optionalStorage.size(); + optional = optional_cursor; + } + } + + //*********************************************************************************************************************** + /* Message */ + //*********************************************************************************************************************** + + Message::Message(const aws_secure_tunnel_message_view &message, Crt::Allocator *allocator) noexcept + : m_allocator(allocator) + { + AWS_ZERO_STRUCT(m_payloadStorage); + AWS_ZERO_STRUCT(m_serviceIdStorage); + + setPacketByteBufOptional(m_payload, m_payloadStorage, m_allocator, message.payload); + setPacketByteBufOptional(m_serviceId, m_serviceIdStorage, m_allocator, message.service_id); + } + + /* Default constructor */ + Message::Message(Crt::Allocator *allocator) noexcept : m_allocator(allocator) + { + AWS_ZERO_STRUCT(m_payloadStorage); + AWS_ZERO_STRUCT(m_serviceIdStorage); + } + + Message::Message(Crt::ByteCursor payload, Crt::Allocator *allocator) noexcept : m_allocator(allocator) + { + AWS_ZERO_STRUCT(m_payloadStorage); + AWS_ZERO_STRUCT(m_serviceIdStorage); + + aws_byte_buf_clean_up(&m_payloadStorage); + aws_byte_buf_init_copy_from_cursor(&m_payloadStorage, m_allocator, payload); + m_payload = aws_byte_cursor_from_buf(&m_payloadStorage); + } + + Message::Message(Crt::ByteCursor serviceId, Crt::ByteCursor payload, Crt::Allocator *allocator) noexcept + : m_allocator(allocator) + { + AWS_ZERO_STRUCT(m_payloadStorage); + AWS_ZERO_STRUCT(m_serviceIdStorage); + + aws_byte_buf_clean_up(&m_payloadStorage); + aws_byte_buf_init_copy_from_cursor(&m_payloadStorage, m_allocator, payload); + m_payload = aws_byte_cursor_from_buf(&m_payloadStorage); + + aws_byte_buf_clean_up(&m_serviceIdStorage); + aws_byte_buf_init_copy_from_cursor(&m_serviceIdStorage, m_allocator, serviceId); + m_serviceId = aws_byte_cursor_from_buf(&m_serviceIdStorage); + } + + Message &Message::withPayload(Crt::ByteCursor payload) noexcept + { + aws_byte_buf_clean_up(&m_payloadStorage); + aws_byte_buf_init_copy_from_cursor(&m_payloadStorage, m_allocator, payload); + m_payload = aws_byte_cursor_from_buf(&m_payloadStorage); + return *this; + } + + Message &Message::withServiceId(Crt::ByteCursor serviceId) noexcept + { + aws_byte_buf_clean_up(&m_serviceIdStorage); + aws_byte_buf_init_copy_from_cursor(&m_serviceIdStorage, m_allocator, serviceId); + m_serviceId = aws_byte_cursor_from_buf(&m_serviceIdStorage); + return *this; + } + + bool Message::initializeRawOptions(aws_secure_tunnel_message_view &raw_options) noexcept + { + AWS_ZERO_STRUCT(raw_options); + if (m_payload.has_value()) + { + raw_options.payload = &m_payload.value(); + } + if (m_serviceId.has_value()) + { + raw_options.service_id = &m_serviceId.value(); + } + + return true; + } + + const Crt::Optional &Message::getPayload() const noexcept { return m_payload; } + + const Crt::Optional &Message::getServiceId() const noexcept { return m_serviceId; } + + Message::~Message() + { + aws_byte_buf_clean_up(&m_payloadStorage); + aws_byte_buf_clean_up(&m_serviceIdStorage); + } + + //*********************************************************************************************************************** + /* ConnectionData */ + //*********************************************************************************************************************** + + ConnectionData::ConnectionData( + const aws_secure_tunnel_connection_view &connection, + Crt::Allocator *allocator) noexcept + : m_allocator(allocator) + { + AWS_ZERO_STRUCT(m_serviceId1Storage); + AWS_ZERO_STRUCT(m_serviceId2Storage); + AWS_ZERO_STRUCT(m_serviceId3Storage); + + setPacketByteBufOptional(m_serviceId1, m_serviceId1Storage, m_allocator, connection.service_id_1); + setPacketByteBufOptional(m_serviceId2, m_serviceId2Storage, m_allocator, connection.service_id_2); + setPacketByteBufOptional(m_serviceId3, m_serviceId3Storage, m_allocator, connection.service_id_3); + } + + const Crt::Optional &ConnectionData::getServiceId1() const noexcept { return m_serviceId1; } + const Crt::Optional &ConnectionData::getServiceId2() const noexcept { return m_serviceId2; } + const Crt::Optional &ConnectionData::getServiceId3() const noexcept { return m_serviceId3; } + + ConnectionData::~ConnectionData() + { + aws_byte_buf_clean_up(&m_serviceId1Storage); + aws_byte_buf_clean_up(&m_serviceId2Storage); + aws_byte_buf_clean_up(&m_serviceId3Storage); + } + + //*********************************************************************************************************************** + /* StreamStartedData */ + //*********************************************************************************************************************** + + StreamStartedData::StreamStartedData( + const aws_secure_tunnel_message_view &message, + Crt::Allocator *allocator) noexcept + : m_allocator(allocator) + { + AWS_ZERO_STRUCT(m_serviceIdStorage); + + setPacketByteBufOptional(m_serviceId, m_serviceIdStorage, m_allocator, message.service_id); + } + + const Crt::Optional &StreamStartedData::getServiceId() const noexcept { return m_serviceId; } + + StreamStartedData::~StreamStartedData() { aws_byte_buf_clean_up(&m_serviceIdStorage); } + + //*********************************************************************************************************************** + /* StreamStoppedData */ + //*********************************************************************************************************************** + + StreamStoppedData::StreamStoppedData( + const aws_secure_tunnel_message_view &message, + Crt::Allocator *allocator) noexcept + : m_allocator(allocator) + { + AWS_ZERO_STRUCT(m_serviceIdStorage); + + setPacketByteBufOptional(m_serviceId, m_serviceIdStorage, m_allocator, message.service_id); + } + + const Crt::Optional &StreamStoppedData::getServiceId() const noexcept { return m_serviceId; } + + StreamStoppedData::~StreamStoppedData() { aws_byte_buf_clean_up(&m_serviceIdStorage); } + + //*********************************************************************************************************************** + /* SecureTunnelBuilder */ + //*********************************************************************************************************************** SecureTunnelBuilder::SecureTunnelBuilder( Crt::Allocator *allocator, // Should out live this object Aws::Crt::Io::ClientBootstrap &clientBootstrap, // Should out live this object @@ -19,8 +208,8 @@ namespace Aws const std::string &endpointHost) // Make a copy and save in this object : m_allocator(allocator), m_clientBootstrap(&clientBootstrap), m_socketOptions(socketOptions), m_accessToken(accessToken), m_localProxyMode(localProxyMode), m_endpointHost(endpointHost), m_rootCa(""), - m_httpClientConnectionProxyOptions(), m_OnConnectionComplete(), m_OnConnectionShutdown(), - m_OnSendDataComplete(), m_OnDataReceive(), m_OnStreamStart(), m_OnStreamReset(), m_OnSessionReset() + m_httpClientConnectionProxyOptions(), m_OnConnectionShutdown(), m_OnSendDataComplete(), + m_OnSessionReset(), m_OnConnectionComplete(), m_OnDataReceive(), m_OnStreamStart(), m_OnStreamReset() { } @@ -33,8 +222,21 @@ namespace Aws : m_allocator(allocator), m_clientBootstrap(Crt::ApiHandle::GetOrCreateStaticDefaultClientBootstrap()), m_socketOptions(socketOptions), m_accessToken(accessToken), m_localProxyMode(localProxyMode), m_endpointHost(endpointHost), m_rootCa(""), m_httpClientConnectionProxyOptions(), - m_OnConnectionComplete(), m_OnConnectionShutdown(), m_OnSendDataComplete(), m_OnDataReceive(), - m_OnStreamStart(), m_OnStreamReset(), m_OnSessionReset() + m_OnConnectionShutdown(), m_OnSendDataComplete(), m_OnSessionReset(), m_OnConnectionComplete(), + m_OnDataReceive(), m_OnStreamStart(), m_OnStreamReset() + { + } + + SecureTunnelBuilder::SecureTunnelBuilder( + Crt::Allocator *allocator, // Should out live this object + const std::string &accessToken, // Make a copy and save in this object + aws_secure_tunneling_local_proxy_mode localProxyMode, + const std::string &endpointHost) // Make a copy and save in this object + : m_allocator(allocator), m_clientBootstrap(Crt::ApiHandle::GetOrCreateStaticDefaultClientBootstrap()), + m_socketOptions(Crt::Io::SocketOptions()), m_accessToken(accessToken), m_localProxyMode(localProxyMode), + m_endpointHost(endpointHost), m_rootCa(""), m_httpClientConnectionProxyOptions(), + m_OnConnectionShutdown(), m_OnSendDataComplete(), m_OnSessionReset(), m_OnConnectionComplete(), + m_OnDataReceive(), m_OnStreamStart(), m_OnStreamReset() { } @@ -51,9 +253,15 @@ namespace Aws return *this; } - SecureTunnelBuilder &SecureTunnelBuilder::WithOnConnectionComplete(OnConnectionComplete onConnectionComplete) + SecureTunnelBuilder &SecureTunnelBuilder::WithOnConnectionSuccess(OnConnectionSuccess onConnectionSuccess) { - m_OnConnectionComplete = std::move(onConnectionComplete); + m_OnConnectionSuccess = std::move(onConnectionSuccess); + return *this; + } + + SecureTunnelBuilder &SecureTunnelBuilder::WithOnConnectionFailure(OnConnectionFailure onConnectionFailure) + { + m_OnConnectionFailure = std::move(onConnectionFailure); return *this; } @@ -69,15 +277,21 @@ namespace Aws return *this; } - SecureTunnelBuilder &SecureTunnelBuilder::WithOnDataReceive(OnDataReceive onDataReceive) + SecureTunnelBuilder &SecureTunnelBuilder::WithOnMessageReceived(OnMessageReceived OnMessageReceived) { - m_OnDataReceive = std::move(onDataReceive); + m_OnMessageReceived = std::move(OnMessageReceived); return *this; } - SecureTunnelBuilder &SecureTunnelBuilder::WithOnStreamStart(OnStreamStart onStreamStart) + SecureTunnelBuilder &SecureTunnelBuilder::WithOnStreamStarted(OnStreamStarted onStreamStarted) { - m_OnStreamStart = std::move(onStreamStart); + m_OnStreamStarted = std::move(onStreamStarted); + return *this; + } + + SecureTunnelBuilder &SecureTunnelBuilder::WithOnStreamStopped(OnStreamStopped onStreamStopped) + { + m_OnStreamStopped = std::move(onStreamStopped); return *this; } @@ -93,6 +307,37 @@ namespace Aws return *this; } + SecureTunnelBuilder &SecureTunnelBuilder::WithOnStopped(OnStopped onStopped) + { + m_OnStopped = std::move(onStopped); + return *this; + } + + SecureTunnelBuilder &SecureTunnelBuilder::WithClientToken(const std::string &clientToken) + { + m_clientToken = clientToken; + return *this; + } + + /* Deprecated - Use WithOnConnectionSuccess and WithOnConnectionFailure */ + SecureTunnelBuilder &SecureTunnelBuilder::WithOnConnectionComplete(OnConnectionComplete onConnectionComplete) + { + m_OnConnectionComplete = std::move(onConnectionComplete); + return *this; + } + /* Deprecated - Use WithOnStreamStarted */ + SecureTunnelBuilder &SecureTunnelBuilder::WithOnStreamStart(OnStreamStart onStreamStart) + { + m_OnStreamStart = std::move(onStreamStart); + return *this; + } + /* Deprecated - Use WithOnMessageReceived */ + SecureTunnelBuilder &SecureTunnelBuilder::WithOnDataReceive(OnDataReceive onDataReceive) + { + m_OnDataReceive = std::move(onDataReceive); + return *this; + } + std::shared_ptr SecureTunnelBuilder::Build() noexcept { auto tunnel = std::shared_ptr(new SecureTunnel( @@ -100,17 +345,24 @@ namespace Aws m_clientBootstrap, m_socketOptions, m_accessToken, + m_clientToken, m_localProxyMode, m_endpointHost, m_rootCa, m_httpClientConnectionProxyOptions.has_value() ? &m_httpClientConnectionProxyOptions.value() : nullptr, + m_OnConnectionSuccess, + m_OnConnectionFailure, m_OnConnectionComplete, m_OnConnectionShutdown, m_OnSendDataComplete, + m_OnMessageReceived, m_OnDataReceive, + m_OnStreamStarted, m_OnStreamStart, + m_OnStreamStopped, m_OnStreamReset, - m_OnSessionReset)); + m_OnSessionReset, + m_OnStopped)); if (tunnel->m_secure_tunnel == nullptr) { @@ -120,6 +372,10 @@ namespace Aws return tunnel; } + //*********************************************************************************************************************** + /* SecureTunnel */ + //*********************************************************************************************************************** + /** * Private SecureTunnel constructor used by SecureTunnelBuilder on SecureTunnelBuilder::Build() and by old * SecureTunnel constructor which should be deprecated @@ -129,34 +385,45 @@ namespace Aws Aws::Crt::Io::ClientBootstrap *clientBootstrap, const Aws::Crt::Io::SocketOptions &socketOptions, const std::string &accessToken, + const std::string &clientToken, aws_secure_tunneling_local_proxy_mode localProxyMode, const std::string &endpointHost, const std::string &rootCa, Aws::Crt::Http::HttpClientConnectionProxyOptions *httpClientConnectionProxyOptions, + OnConnectionSuccess onConnectionSuccess, + OnConnectionFailure onConnectionFailure, OnConnectionComplete onConnectionComplete, OnConnectionShutdown onConnectionShutdown, OnSendDataComplete onSendDataComplete, + OnMessageReceived onMessageReceived, OnDataReceive onDataReceive, + OnStreamStarted onStreamStarted, OnStreamStart onStreamStart, + OnStreamStopped onStreamStopped, OnStreamReset onStreamReset, - OnSessionReset onSessionReset) + OnSessionReset onSessionReset, + OnStopped onStopped) { // Client callbacks + m_OnConnectionSuccess = std::move(onConnectionSuccess); + m_OnConnectionFailure = std::move(onConnectionFailure); m_OnConnectionComplete = std::move(onConnectionComplete); m_OnConnectionShutdown = std::move(onConnectionShutdown); m_OnSendDataComplete = std::move(onSendDataComplete); + m_OnMessageReceived = std::move(onMessageReceived); m_OnDataReceive = std::move(onDataReceive); + m_OnStreamStarted = std::move(onStreamStarted); m_OnStreamStart = std::move(onStreamStart); m_OnStreamReset = std::move(onStreamReset); m_OnSessionReset = std::move(onSessionReset); + m_OnStopped = std::move(onStopped); // Initialize aws_secure_tunnel_options aws_secure_tunnel_options config; AWS_ZERO_STRUCT(config); - config.allocator = allocator; config.bootstrap = clientBootstrap ? clientBootstrap->GetUnderlyingHandle() : nullptr; config.socket_options = &socketOptions.GetImpl(); @@ -169,14 +436,20 @@ namespace Aws config.root_ca = rootCa.c_str(); } + if (clientToken.length() > 0) + { + config.client_token = aws_byte_cursor_from_c_str(clientToken.c_str()); + } + + /* callbacks for native secure tunnel */ + config.on_message_received = s_OnMessageReceived; config.on_connection_complete = s_OnConnectionComplete; config.on_connection_shutdown = s_OnConnectionShutdown; config.on_send_data_complete = s_OnSendDataComplete; - config.on_data_receive = s_OnDataReceive; - config.on_stream_start = s_OnStreamStart; - config.on_stream_reset = s_OnStreamReset; + config.on_stream_start = s_OnStreamStarted; + config.on_stream_reset = s_OnStreamStopped; config.on_session_reset = s_OnSessionReset; - config.on_termination_complete = s_OnTerminationComplete; + config.on_stopped = s_OnStopped; config.user_data = this; @@ -189,12 +462,12 @@ namespace Aws } // Create the secure tunnel - m_secure_tunnel = aws_secure_tunnel_new(&config); + m_secure_tunnel = aws_secure_tunnel_new(allocator, &config); + m_allocator = allocator; } /** - * Should be deprecated when possible. - * SecureTunnelBuilder::Build() should be used to generate new SecureTunnels + * Deprecated - Use SecureTunnelBuilder */ SecureTunnel::SecureTunnel( Crt::Allocator *allocator, @@ -218,23 +491,29 @@ namespace Aws clientBootstrap, socketOptions, accessToken, + nullptr, localProxyMode, endpointHost, rootCa, nullptr, + nullptr, + nullptr, onConnectionComplete, onConnectionShutdown, onSendDataComplete, + nullptr, onDataReceive, + nullptr, onStreamStart, + nullptr, onStreamReset, - onSessionReset) + onSessionReset, + nullptr) { } /** - * Should be deprecated when possible. - * SecureTunnelBuilder::Build() should be used to generate new SecureTunnels + * Deprecated - Use SecureTunnelBuilder */ SecureTunnel::SecureTunnel( Crt::Allocator *allocator, @@ -257,60 +536,82 @@ namespace Aws Crt::ApiHandle::GetOrCreateStaticDefaultClientBootstrap(), socketOptions, accessToken, + nullptr, localProxyMode, endpointHost, rootCa, nullptr, + nullptr, + nullptr, onConnectionComplete, onConnectionShutdown, onSendDataComplete, + nullptr, onDataReceive, + nullptr, onStreamStart, + nullptr, onStreamReset, - onSessionReset) + onSessionReset, + nullptr) { } + SecureTunnel::~SecureTunnel() + { + if (m_secure_tunnel) + { + aws_secure_tunnel_release(m_secure_tunnel); + m_secure_tunnel = nullptr; + } + } + SecureTunnel::SecureTunnel(SecureTunnel &&other) noexcept { - m_OnConnectionComplete = std::move(other.m_OnConnectionComplete); + m_OnConnectionSuccess = std::move(other.m_OnConnectionSuccess); + m_OnConnectionFailure = std::move(other.m_OnConnectionFailure); m_OnConnectionShutdown = std::move(other.m_OnConnectionShutdown); m_OnSendDataComplete = std::move(other.m_OnSendDataComplete); - m_OnDataReceive = std::move(other.m_OnDataReceive); - m_OnStreamStart = std::move(other.m_OnStreamStart); + m_OnMessageReceived = std::move(other.m_OnMessageReceived); + m_OnStreamStarted = std::move(other.m_OnStreamStarted); m_OnStreamReset = std::move(other.m_OnStreamReset); m_OnSessionReset = std::move(other.m_OnSessionReset); + m_OnStopped = std::move(other.m_OnStopped); - m_TerminationComplete = std::move(other.m_TerminationComplete); + /* Deprecated - Use m_OnConnectionSuccess and m_OnConnectionFailure */ + m_OnConnectionComplete = std::move(other.m_OnConnectionComplete); + /* Deprecated - Use m_OnMessageReceived */ + m_OnDataReceive = std::move(other.m_OnDataReceive); + /* Deprecated - Use m_OnStreamStarted */ + m_OnStreamStart = std::move(other.m_OnStreamStart); m_secure_tunnel = other.m_secure_tunnel; other.m_secure_tunnel = nullptr; } - SecureTunnel::~SecureTunnel() - { - if (m_secure_tunnel) - { - aws_secure_tunnel_release(m_secure_tunnel); - } - } - SecureTunnel &SecureTunnel::operator=(SecureTunnel &&other) noexcept { if (this != &other) { this->~SecureTunnel(); - m_OnConnectionComplete = std::move(other.m_OnConnectionComplete); + m_OnConnectionSuccess = std::move(other.m_OnConnectionSuccess); + m_OnConnectionFailure = std::move(other.m_OnConnectionFailure); m_OnConnectionShutdown = std::move(other.m_OnConnectionShutdown); m_OnSendDataComplete = std::move(other.m_OnSendDataComplete); - m_OnDataReceive = std::move(other.m_OnDataReceive); - m_OnStreamStart = std::move(other.m_OnStreamStart); + m_OnMessageReceived = std::move(other.m_OnMessageReceived); + m_OnStreamStarted = std::move(other.m_OnStreamStarted); m_OnStreamReset = std::move(other.m_OnStreamReset); m_OnSessionReset = std::move(other.m_OnSessionReset); + m_OnStopped = std::move(other.m_OnStopped); - m_TerminationComplete = std::move(other.m_TerminationComplete); + /* Deprecated - Use m_OnConnectionSuccess and m_OnConnectionFailure */ + m_OnConnectionComplete = std::move(other.m_OnConnectionComplete); + /* Deprecated - Use m_OnMessageReceived */ + m_OnDataReceive = std::move(other.m_OnDataReceive); + /* Deprecated - Use m_OnStreamStarted */ + m_OnStreamStart = std::move(other.m_OnStreamStart); m_secure_tunnel = other.m_secure_tunnel; @@ -322,33 +623,105 @@ namespace Aws bool SecureTunnel::IsValid() { return m_secure_tunnel ? true : false; } - int SecureTunnel::Connect() { return aws_secure_tunnel_connect(m_secure_tunnel); } + int SecureTunnel::Start() + { + // if (m_selfRef == nullptr) + // { + // m_selfRef = this->getptr(); + // } + return aws_secure_tunnel_start(m_secure_tunnel); + } + + int SecureTunnel::Stop() { return aws_secure_tunnel_stop(m_secure_tunnel); } + + /* Deprecated - Use Start() */ + int SecureTunnel::Connect() { return Start(); } - int SecureTunnel::Close() { return aws_secure_tunnel_close(m_secure_tunnel); } + /* Deprecated - Use Stop() */ + int SecureTunnel::Close() { return Stop(); } + /* Deprecated - Use SendMessage() */ int SecureTunnel::SendData(const Crt::ByteCursor &data) { - return aws_secure_tunnel_send_data(m_secure_tunnel, &data); + // return SendData("", data); + std::shared_ptr message = std::make_shared(data); + return SendMessage(message); + } + + int SecureTunnel::SendMessage(std::shared_ptr messageOptions) noexcept + { + if (messageOptions == nullptr) + { + return AWS_OP_ERR; + } + + aws_secure_tunnel_message_view message; + messageOptions->initializeRawOptions(message); + return aws_secure_tunnel_send_message(m_secure_tunnel, &message); } - int SecureTunnel::SendStreamStart() { return aws_secure_tunnel_stream_start(m_secure_tunnel); } + int SecureTunnel::SendStreamStart() { return SendStreamStart(""); } + int SecureTunnel::SendStreamStart(std::string serviceId) + { + struct aws_byte_cursor service_id_cur; + AWS_ZERO_STRUCT(service_id_cur); + if (serviceId.length() > 0) + { + service_id_cur = aws_byte_cursor_from_c_str(serviceId.c_str()); + } + return SendStreamStart(service_id_cur); + } + int SecureTunnel::SendStreamStart(Crt::ByteCursor serviceId) + { + struct aws_secure_tunnel_message_view messageView; + AWS_ZERO_STRUCT(messageView); + messageView.service_id = &serviceId; + return aws_secure_tunnel_stream_start(m_secure_tunnel, &messageView); + } - int SecureTunnel::SendStreamReset() { return aws_secure_tunnel_stream_reset(m_secure_tunnel); } + int SecureTunnel::SendStreamReset() { return aws_secure_tunnel_stream_reset(m_secure_tunnel, NULL); } aws_secure_tunnel *SecureTunnel::GetUnderlyingHandle() { return m_secure_tunnel; } - void SecureTunnel::s_OnConnectionComplete(void *user_data) + void SecureTunnel::s_OnConnectionComplete( + const struct aws_secure_tunnel_connection_view *connection, + int error_code, + void *user_data) { auto *secureTunnel = static_cast(user_data); - if (secureTunnel->m_OnConnectionComplete) + + if (!error_code) { - secureTunnel->m_OnConnectionComplete(); + /* Check for full callback */ + if (secureTunnel->m_OnConnectionSuccess) + { + std::shared_ptr packet = + std::make_shared(*connection, secureTunnel->m_allocator); + ConnectionSuccessEventData eventData; + eventData.connectionData = packet; + secureTunnel->m_OnConnectionSuccess(secureTunnel, eventData); + return; + } + + /* Fall back on deprecated complete callback */ + if (secureTunnel->m_OnConnectionComplete) + { + secureTunnel->m_OnConnectionComplete(); + } + } + else + { + if (secureTunnel->m_OnConnectionFailure) + { + secureTunnel->m_OnConnectionFailure(secureTunnel, error_code); + } } } - void SecureTunnel::s_OnConnectionShutdown(void *user_data) + void SecureTunnel::s_OnConnectionShutdown(int error_code, void *user_data) { - auto *secureTunnel = static_cast(user_data); + (void)error_code; + SecureTunnel *secureTunnel = static_cast(user_data); if (secureTunnel->m_OnConnectionShutdown) { secureTunnel->m_OnConnectionShutdown(); @@ -357,34 +730,101 @@ namespace Aws void SecureTunnel::s_OnSendDataComplete(int error_code, void *user_data) { - auto *secureTunnel = static_cast(user_data); + SecureTunnel *secureTunnel = static_cast(user_data); if (secureTunnel->m_OnSendDataComplete) { secureTunnel->m_OnSendDataComplete(error_code); } } - void SecureTunnel::s_OnDataReceive(const struct aws_byte_buf *data, void *user_data) + void SecureTunnel::s_OnMessageReceived(const struct aws_secure_tunnel_message_view *message, void *user_data) { - auto *secureTunnel = static_cast(user_data); - if (secureTunnel->m_OnDataReceive) + SecureTunnel *secureTunnel = static_cast(user_data); + if (secureTunnel != nullptr) { - secureTunnel->m_OnDataReceive(*data); + if (message != NULL) + { + /* V2 Protocol API */ + if (secureTunnel->m_OnMessageReceived != nullptr) + { + std::shared_ptr packet = + std::make_shared(*message, secureTunnel->m_allocator); + MessageReceivedEventData eventData; + eventData.message = packet; + secureTunnel->m_OnMessageReceived(secureTunnel, eventData); + return; + } + + /* V1 Protocol API */ + if (secureTunnel->m_OnDataReceive != nullptr) + { + /* + * Old API (V1) expects an aws_byte_buf. Temporarily creating one from an aws_byte_cursor. The + * data will be managed syncronous here with the expectation that the user copies what they + * need as it is cleared as soon as this function completes + */ + struct aws_byte_buf payload_buf; + AWS_ZERO_STRUCT(payload_buf); + payload_buf.allocator = NULL; + payload_buf.buffer = message->payload->ptr; + payload_buf.len = message->payload->len; + secureTunnel->m_OnDataReceive(payload_buf); + return; + } + } + else + { + AWS_LOGF_ERROR(AWS_LS_IOTDEVICE_SECURE_TUNNELING, "Failed to access message view."); + } } } - void SecureTunnel::s_OnStreamStart(void *user_data) + void SecureTunnel::s_OnStreamStarted( + const struct aws_secure_tunnel_message_view *message, + int error_code, + void *user_data) { - auto *secureTunnel = static_cast(user_data); - if (secureTunnel->m_OnStreamStart) + SecureTunnel *secureTunnel = static_cast(user_data); + if (!error_code) { - secureTunnel->m_OnStreamStart(); + if (secureTunnel->m_OnStreamStarted) + { + std::shared_ptr packet = + std::make_shared(*message, secureTunnel->m_allocator); + StreamStartedEventData eventData; + eventData.streamStartedData = packet; + secureTunnel->m_OnStreamStarted(secureTunnel, error_code, eventData); + return; + } + + /* Fall back on deprecated stream start callback */ + if (secureTunnel->m_OnStreamStart) + { + secureTunnel->m_OnStreamStart(); + } } } - void SecureTunnel::s_OnStreamReset(void *user_data) + void SecureTunnel::s_OnStreamStopped( + const struct aws_secure_tunnel_message_view *message, + int error_code, + void *user_data) { - auto *secureTunnel = static_cast(user_data); + (void)message; + (void)error_code; + SecureTunnel *secureTunnel = static_cast(user_data); + + if (secureTunnel->m_OnStreamStopped) + { + std::shared_ptr packet = + std::make_shared(*message, secureTunnel->m_allocator); + StreamStoppedEventData eventData; + eventData.streamStoppedData = packet; + secureTunnel->m_OnStreamStopped(secureTunnel, eventData); + return; + } + + /* Fall back on deprecated stream reset callback */ if (secureTunnel->m_OnStreamReset) { secureTunnel->m_OnStreamReset(); @@ -393,28 +833,23 @@ namespace Aws void SecureTunnel::s_OnSessionReset(void *user_data) { - auto *secureTunnel = static_cast(user_data); + SecureTunnel *secureTunnel = static_cast(user_data); if (secureTunnel->m_OnSessionReset) { secureTunnel->m_OnSessionReset(); } } - void SecureTunnel::s_OnTerminationComplete(void *user_data) + void SecureTunnel::s_OnStopped(void *user_data) { - auto *secureTunnel = static_cast(user_data); - secureTunnel->OnTerminationComplete(); + SecureTunnel *secureTunnel = static_cast(user_data); + secureTunnel->m_selfRef = nullptr; + if (secureTunnel->m_OnStopped) + { + secureTunnel->m_OnStopped(secureTunnel); + } } - void SecureTunnel::OnTerminationComplete() { m_TerminationComplete.set_value(); } - - void SecureTunnel::Shutdown() - { - Close(); - aws_secure_tunnel_release(m_secure_tunnel); - m_secure_tunnel = nullptr; - - m_TerminationComplete.get_future().wait(); - } + void SecureTunnel::Shutdown() { Stop(); } } // namespace Iotsecuretunneling } // namespace Aws diff --git a/secure_tunneling/tests/CMakeLists.txt b/secure_tunneling/tests/CMakeLists.txt deleted file mode 100644 index fd4050245..000000000 --- a/secure_tunneling/tests/CMakeLists.txt +++ /dev/null @@ -1,20 +0,0 @@ -include(AwsTestHarness) -enable_testing() -include(CTest) - -file(GLOB TEST_SRC "*.cpp") -file(GLOB TEST_HDRS "*.h") -file(GLOB TESTS ${TEST_HDRS} ${TEST_SRC}) - -set(TEST_BINARY_NAME ${PROJECT_NAME}-tests) - -aws_use_package(aws-crt-cpp) -aws_use_package(IotSecureTunneling-cpp) - -if (UNIX AND NOT APPLE) - add_test_case(SecureTunnelingHandleStreamStartTest) - add_test_case(SecureTunnelingHandleDataReceiveTest) - add_test_case(SecureTunnelingHandleStreamResetTest) - add_test_case(SecureTunnelingHandleSessionResetTest) - generate_cpp_test_driver(${TEST_BINARY_NAME}) -endif() diff --git a/secure_tunneling/tests/SecureTunnelTest.cpp b/secure_tunneling/tests/SecureTunnelTest.cpp deleted file mode 100644 index fd2d6cc9c..000000000 --- a/secure_tunneling/tests/SecureTunnelTest.cpp +++ /dev/null @@ -1,263 +0,0 @@ -/** - * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - * SPDX-License-Identifier: Apache-2.0. - */ - -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace std; -using namespace Aws::Crt::Io; -using namespace Aws::Iotdevicecommon; -using namespace Aws::Iotsecuretunneling; - -#define INVALID_STREAM_ID 0 -#define STREAM_ID 10 -#define PAYLOAD "secure tunneling data payload" - -extern "C" -{ - struct aws_websocket_incoming_frame; - extern bool on_websocket_incoming_frame_payload( - struct aws_websocket *websocket, - const struct aws_websocket_incoming_frame *frame, - struct aws_byte_cursor data, - void *user_data); -} - -struct SecureTunnelingTestContext -{ - unique_ptr deviceApiHandle; - unique_ptr elGroup; - unique_ptr resolver; - unique_ptr clientBootstrap; - shared_ptr secureTunnel; - - aws_secure_tunneling_local_proxy_mode localProxyMode; - - SecureTunnelingTestContext() { localProxyMode = AWS_SECURE_TUNNELING_DESTINATION_MODE; } -}; -static SecureTunnelingTestContext s_testContext; - -// Client callbacks implementation -static void s_OnConnectionComplete() {} -static void s_OnConnectionShutdown() {} - -static void s_OnSendDataComplete(int errorCode) {} - -static bool s_OnDataReceiveCorrectPayload = false; -static void s_OnDataReceive(const Aws::Crt::ByteBuf &data) -{ - s_OnDataReceiveCorrectPayload = aws_byte_buf_eq_c_str(&data, PAYLOAD); -} - -static bool s_OnStreamStartCalled = false; -static void s_OnStreamStart() -{ - s_OnStreamStartCalled = true; -} - -static bool s_OnStreamResetCalled = false; -static void s_OnStreamReset() -{ - s_OnStreamResetCalled = true; -} - -static bool s_OnSessionResetCalled = false; -static void s_OnSessionReset() -{ - s_OnSessionResetCalled = true; -} - -static int before(struct aws_allocator *allocator, void *ctx) -{ - auto *testContext = static_cast(ctx); - - aws_http_library_init(allocator); - - testContext->deviceApiHandle = unique_ptr(new DeviceApiHandle(allocator)); - testContext->elGroup = unique_ptr(new EventLoopGroup(1, allocator)); - testContext->resolver = unique_ptr(new DefaultHostResolver(*testContext->elGroup, 8, 30, allocator)); - testContext->clientBootstrap = - unique_ptr(new ClientBootstrap(*testContext->elGroup, *testContext->resolver, allocator)); - testContext->secureTunnel = SecureTunnelBuilder( - allocator, - *testContext->clientBootstrap, - SocketOptions(), - "access_token", - testContext->localProxyMode, - "endpoint") - .WithRootCa("") - .WithOnConnectionComplete(s_OnConnectionComplete) - .WithOnConnectionShutdown(s_OnConnectionShutdown) - .WithOnSendDataComplete(s_OnSendDataComplete) - .WithOnDataReceive(s_OnDataReceive) - .WithOnStreamStart(s_OnStreamStart) - .WithOnStreamReset(s_OnStreamReset) - .WithOnSessionReset(s_OnSessionReset) - .Build(); - return AWS_ERROR_SUCCESS; -} - -static int after(struct aws_allocator *allocator, int setup_result, void *ctx) -{ - auto *testContext = static_cast(ctx); - - testContext->secureTunnel->Shutdown(); - - testContext->secureTunnel.reset(); - testContext->clientBootstrap.reset(); - testContext->resolver.reset(); - testContext->elGroup.reset(); - testContext->deviceApiHandle.reset(); - - aws_http_library_clean_up(); - - return AWS_ERROR_SUCCESS; -} - -static void s_send_secure_tunneling_frame_to_websocket( - const struct aws_iot_st_msg *st_msg, - struct aws_allocator *allocator, - struct aws_secure_tunnel *secure_tunnel) -{ - struct aws_byte_buf serialized_st_msg; - aws_iot_st_msg_serialize_from_struct(&serialized_st_msg, allocator, *st_msg); - - /* Prepend 2 bytes length */ - struct aws_byte_buf websocket_frame; - aws_byte_buf_init(&websocket_frame, allocator, serialized_st_msg.len + 2); - aws_byte_buf_write_be16(&websocket_frame, (uint16_t)serialized_st_msg.len); - struct aws_byte_cursor c = aws_byte_cursor_from_buf(&serialized_st_msg); - aws_byte_buf_append(&websocket_frame, &c); - c = aws_byte_cursor_from_buf(&websocket_frame); - - on_websocket_incoming_frame_payload(NULL, NULL, c, secure_tunnel); - - aws_byte_buf_clean_up(&serialized_st_msg); - aws_byte_buf_clean_up(&websocket_frame); -} - -AWS_TEST_CASE_FIXTURE( - SecureTunnelingHandleStreamStartTest, - before, - s_SecureTunnelingHandleStreamStartTest, - after, - &s_testContext); -static int s_SecureTunnelingHandleStreamStartTest(Aws::Crt::Allocator *allocator, void *ctx) -{ - auto *testContext = static_cast(ctx); - - struct aws_iot_st_msg st_msg; - AWS_ZERO_STRUCT(st_msg); - st_msg.type = STREAM_START; - st_msg.stream_id = STREAM_ID; - s_OnStreamStartCalled = false; - s_send_secure_tunneling_frame_to_websocket(&st_msg, allocator, testContext->secureTunnel->GetUnderlyingHandle()); - - ASSERT_TRUE(s_OnStreamStartCalled); - ASSERT_INT_EQUALS(STREAM_ID, testContext->secureTunnel->GetUnderlyingHandle()->stream_id); - ASSERT_UINT_EQUALS(0, testContext->secureTunnel->GetUnderlyingHandle()->received_data.len); - - return AWS_ERROR_SUCCESS; -} - -AWS_TEST_CASE_FIXTURE( - SecureTunnelingHandleDataReceiveTest, - before, - s_SecureTunnelingHandleDataReceiveTest, - after, - &s_testContext); -static int s_SecureTunnelingHandleDataReceiveTest(Aws::Crt::Allocator *allocator, void *ctx) -{ - auto *testContext = static_cast(ctx); - - /* Send StreamStart first */ - struct aws_iot_st_msg st_msg; - AWS_ZERO_STRUCT(st_msg); - st_msg.type = STREAM_START; - st_msg.stream_id = STREAM_ID; - s_send_secure_tunneling_frame_to_websocket(&st_msg, allocator, testContext->secureTunnel->GetUnderlyingHandle()); - - /* Send data */ - AWS_ZERO_STRUCT(st_msg); - st_msg.type = DATA; - st_msg.stream_id = STREAM_ID; - st_msg.payload = aws_byte_buf_from_c_str(PAYLOAD); - s_OnDataReceiveCorrectPayload = false; - s_send_secure_tunneling_frame_to_websocket(&st_msg, allocator, testContext->secureTunnel->GetUnderlyingHandle()); - - ASSERT_TRUE(s_OnDataReceiveCorrectPayload); - ASSERT_INT_EQUALS(STREAM_ID, testContext->secureTunnel->GetUnderlyingHandle()->stream_id); - ASSERT_UINT_EQUALS(0, testContext->secureTunnel->GetUnderlyingHandle()->received_data.len); - - return AWS_OP_SUCCESS; -} - -AWS_TEST_CASE_FIXTURE( - SecureTunnelingHandleStreamResetTest, - before, - SecureTunnelingHandleStreamResetTest, - after, - &s_testContext); -static int SecureTunnelingHandleStreamResetTest(Aws::Crt::Allocator *allocator, void *ctx) -{ - auto *testContext = static_cast(ctx); - - /* Send StreamStart first */ - struct aws_iot_st_msg st_msg; - AWS_ZERO_STRUCT(st_msg); - st_msg.type = STREAM_START; - st_msg.stream_id = STREAM_ID; - s_send_secure_tunneling_frame_to_websocket(&st_msg, allocator, testContext->secureTunnel->GetUnderlyingHandle()); - - /* Send StreamReset */ - AWS_ZERO_STRUCT(st_msg); - st_msg.type = STREAM_RESET; - st_msg.stream_id = STREAM_ID; - s_OnStreamResetCalled = false; - s_send_secure_tunneling_frame_to_websocket(&st_msg, allocator, testContext->secureTunnel->GetUnderlyingHandle()); - - ASSERT_TRUE(s_OnStreamResetCalled); - ASSERT_INT_EQUALS(INVALID_STREAM_ID, testContext->secureTunnel->GetUnderlyingHandle()->stream_id); - ASSERT_UINT_EQUALS(0, testContext->secureTunnel->GetUnderlyingHandle()->received_data.len); - - return AWS_OP_SUCCESS; -} - -AWS_TEST_CASE_FIXTURE( - SecureTunnelingHandleSessionResetTest, - before, - s_SecureTunnelingHandleSessionResetTest, - after, - &s_testContext); -static int s_SecureTunnelingHandleSessionResetTest(struct aws_allocator *allocator, void *ctx) -{ - auto *testContext = static_cast(ctx); - - /* Send StreamStart first */ - struct aws_iot_st_msg st_msg; - AWS_ZERO_STRUCT(st_msg); - st_msg.type = STREAM_START; - st_msg.stream_id = STREAM_ID; - s_send_secure_tunneling_frame_to_websocket(&st_msg, allocator, testContext->secureTunnel->GetUnderlyingHandle()); - - /* Send StreamReset */ - AWS_ZERO_STRUCT(st_msg); - st_msg.type = SESSION_RESET; - st_msg.stream_id = STREAM_ID; - s_OnSessionResetCalled = false; - s_send_secure_tunneling_frame_to_websocket(&st_msg, allocator, testContext->secureTunnel->GetUnderlyingHandle()); - - ASSERT_TRUE(s_OnSessionResetCalled); - ASSERT_INT_EQUALS(INVALID_STREAM_ID, testContext->secureTunnel->GetUnderlyingHandle()->stream_id); - ASSERT_UINT_EQUALS(0, testContext->secureTunnel->GetUnderlyingHandle()->received_data.len); - - return AWS_OP_SUCCESS; -} diff --git a/utils/run_secure_tunnel_ci.py b/utils/run_secure_tunnel_ci.py index baa33997f..a9d7105fc 100644 --- a/utils/run_secure_tunnel_ci.py +++ b/utils/run_secure_tunnel_ci.py @@ -17,29 +17,32 @@ def getSecretsAndLaunch(parsed_commands): exit_code = 0 - print ("Creating secure tunnel client using Boto3") + print("Creating secure tunnel client using Boto3") tunnel_client = None try: - tunnel_client = boto3.client("iotsecuretunneling", region_name=parsed_commands.sample_region) + tunnel_client = boto3.client( + "iotsecuretunneling", region_name=parsed_commands.sample_region) except Exception: - print ("Could not create tunnel client!") + print("Could not create tunnel client!") exit(-1) tunnel_data = None try: - tunnel_data = tunnel_client.open_tunnel() + tunnel_data = tunnel_client.open_tunnel( + destinationConfig={'services': ['ssh', 'http', ]}) except Exception: - print ("Could not open tunnel!") + print("Could not open tunnel!") exit(-1) - print ("Launching Secure Tunnel samples...") + print("Launching Secure Tunnel samples...") exit_code = launch_samples(parsed_commands, tunnel_data) - print ("Closing tunnel...") + print("Closing tunnel...") try: - tunnel_client.close_tunnel(tunnelId=tunnel_data["tunnelId"], delete=True) + tunnel_client.close_tunnel( + tunnelId=tunnel_data["tunnelId"], delete=True) except Exception: - print ("Could not close tunnel!") + print("Could not close tunnel!") exit(-1) return exit_code @@ -49,13 +52,17 @@ def launch_samples(parsed_commands, tunnel_data): exit_code = 0 # Right now secure tunneling is only in C++, so we only support launching the sample in the C++ way - launch_arguments_destination = ["--test", "--region", parsed_commands.sample_region, "--access_token", tunnel_data["destinationAccessToken"]] - launch_arguments_source = ["--local_proxy_mode_source", "--region", parsed_commands.sample_region, "--access_token", tunnel_data["sourceAccessToken"]] - - destination_run = subprocess.Popen(args=launch_arguments_destination, executable=parsed_commands.sample_file) - print ("About to sleep before running source part of sample...") - sleep(10) # Sleep to give the destination some time to run - source_run = subprocess.Popen(args=launch_arguments_source, executable=parsed_commands.sample_file) + launch_arguments_destination = [ + "--test", "--region", parsed_commands.sample_region, "--access_token", tunnel_data["destinationAccessToken"]] + launch_arguments_source = ["--local_proxy_mode_source", "--region", + parsed_commands.sample_region, "--access_token", tunnel_data["sourceAccessToken"]] + + destination_run = subprocess.Popen( + args=launch_arguments_destination, executable=parsed_commands.sample_file) + print("About to sleep before running source part of sample...") + sleep(10) # Sleep to give the destination some time to run + source_run = subprocess.Popen( + args=launch_arguments_source, executable=parsed_commands.sample_file) # Wait for the source to finish source_run.wait()