diff --git a/source/common/singleton/threadsafe_singleton.h b/source/common/singleton/threadsafe_singleton.h index 3af33deb2548..cb5de79254bf 100644 --- a/source/common/singleton/threadsafe_singleton.h +++ b/source/common/singleton/threadsafe_singleton.h @@ -66,6 +66,9 @@ template class InjectableSingleton { } static void clear() { loader_ = nullptr; } + // Atomically replace the value, returning the old value. + static T* replaceForTest(T* new_value) { return loader_.exchange(new_value); } + protected: static std::atomic loader_; }; @@ -86,20 +89,16 @@ template class ScopedInjectableLoader { std::unique_ptr instance_; }; -// This class saves the singleton object and restore the original singleton at destroy. This class -// is not thread safe. It can be used in single thread test. -template -class StackedScopedInjectableLoader : - // To access the protected loader_. - protected InjectableSingleton { +// This class saves the singleton object and restore the original singleton at destroy. +template class StackedScopedInjectableLoaderForTest { public: - explicit StackedScopedInjectableLoader(std::unique_ptr&& instance) { - original_loader_ = InjectableSingleton::getExisting(); - InjectableSingleton::clear(); + explicit StackedScopedInjectableLoaderForTest(std::unique_ptr&& instance) { instance_ = std::move(instance); - InjectableSingleton::initialize(instance_.get()); + original_loader_ = InjectableSingleton::replaceForTest(instance_.get()); + } + ~StackedScopedInjectableLoaderForTest() { + InjectableSingleton::replaceForTest(original_loader_); } - ~StackedScopedInjectableLoader() { InjectableSingleton::loader_ = original_loader_; } private: std::unique_ptr instance_; diff --git a/test/common/listener_manager/listener_manager_impl_test.cc b/test/common/listener_manager/listener_manager_impl_test.cc index bb81bb701d7a..d296966813b4 100644 --- a/test/common/listener_manager/listener_manager_impl_test.cc +++ b/test/common/listener_manager/listener_manager_impl_test.cc @@ -2604,7 +2604,8 @@ TEST_P(ListenerManagerImplTest, BindToPortEqualToFalse) { InSequence s; auto mock_interface = std::make_unique( std::vector{Network::Address::IpVersion::v4}); - StackedScopedInjectableLoader new_interface(std::move(mock_interface)); + StackedScopedInjectableLoaderForTest new_interface( + std::move(mock_interface)); ProdListenerComponentFactory real_listener_factory(server_); EXPECT_CALL(*worker_, start(_, _)); @@ -2643,7 +2644,8 @@ TEST_P(ListenerManagerImplTest, UpdateBindToPortEqualToFalse) { InSequence s; auto mock_interface = std::make_unique( std::vector{Network::Address::IpVersion::v4}); - StackedScopedInjectableLoader new_interface(std::move(mock_interface)); + StackedScopedInjectableLoaderForTest new_interface( + std::move(mock_interface)); ProdListenerComponentFactory real_listener_factory(server_); EXPECT_CALL(*worker_, start(_, _)); diff --git a/test/common/network/listen_socket_impl_test.cc b/test/common/network/listen_socket_impl_test.cc index 0bc6bd067bdc..769a62739177 100644 --- a/test/common/network/listen_socket_impl_test.cc +++ b/test/common/network/listen_socket_impl_test.cc @@ -229,7 +229,7 @@ TEST_P(ListenSocketImplTestTcp, SupportedIpFamilyVirtualSocketIsCreatedWithNoBsd auto any_address = version_ == Address::IpVersion::v4 ? Utility::getIpv4AnyAddress() : Utility::getIpv6AnyAddress(); - StackedScopedInjectableLoader new_interface(std::move(mock_interface)); + StackedScopedInjectableLoaderForTest new_interface(std::move(mock_interface)); { EXPECT_CALL(*mock_interface_ptr, socket(_, _, _)).Times(0); @@ -245,7 +245,7 @@ TEST_P(ListenSocketImplTestTcp, DeathAtUnSupportedIpFamilyListenSocket) { auto* mock_interface_ptr = mock_interface.get(); auto the_other_address = version_ == Address::IpVersion::v4 ? Utility::getIpv6AnyAddress() : Utility::getIpv4AnyAddress(); - StackedScopedInjectableLoader new_interface(std::move(mock_interface)); + StackedScopedInjectableLoaderForTest new_interface(std::move(mock_interface)); { EXPECT_CALL(*mock_interface_ptr, socket(_, _, _)).Times(0); EXPECT_CALL(*mock_interface_ptr, socket(_, _, _, _, _)).Times(0); diff --git a/test/integration/socket_interface_swap.cc b/test/integration/socket_interface_swap.cc index 9ea29fa7ad3b..7d29f0691a47 100644 --- a/test/integration/socket_interface_swap.cc +++ b/test/integration/socket_interface_swap.cc @@ -3,10 +3,8 @@ namespace Envoy { SocketInterfaceSwap::SocketInterfaceSwap(Network::Socket::Type socket_type) - : write_matcher_(std::make_shared(socket_type)) { - Envoy::Network::SocketInterfaceSingleton::clear(); - test_socket_interface_loader_ = std::make_unique( - std::make_unique( + : write_matcher_(std::make_shared(socket_type)), + test_socket_interface_loader_(std::make_unique( [write_matcher = write_matcher_](Envoy::Network::TestIoSocketHandle* io_handle) -> absl::optional { Api::IoErrorPtr error_override = write_matcher->returnConnectOverride(io_handle); @@ -28,8 +26,7 @@ SocketInterfaceSwap::SocketInterfaceSwap(Network::Socket::Type socket_type) }, [write_matcher = write_matcher_](Network::IoHandle::RecvMsgOutput& output) { write_matcher->readOverride(output); - })); -} + })) {} void SocketInterfaceSwap::IoHandleMatcher::setResumeWrites() { absl::MutexLock lock(&mutex_); diff --git a/test/integration/socket_interface_swap.h b/test/integration/socket_interface_swap.h index 57e82087de8f..1e16aaaf6156 100644 --- a/test/integration/socket_interface_swap.h +++ b/test/integration/socket_interface_swap.h @@ -116,15 +116,8 @@ class SocketInterfaceSwap { explicit SocketInterfaceSwap(Network::Socket::Type socket_type); - ~SocketInterfaceSwap() { - test_socket_interface_loader_.reset(); - Envoy::Network::SocketInterfaceSingleton::initialize(previous_socket_interface_); - } - - Envoy::Network::SocketInterface* const previous_socket_interface_{ - Envoy::Network::SocketInterfaceSingleton::getExisting()}; std::shared_ptr write_matcher_; - std::unique_ptr test_socket_interface_loader_; + StackedScopedInjectableLoaderForTest test_socket_interface_loader_; }; } // namespace Envoy