From e35ba64ab01fc78f8d3a1d13164127dbcad0f055 Mon Sep 17 00:00:00 2001 From: martinjrobins Date: Fri, 13 Oct 2017 08:26:24 +0100 Subject: [PATCH] fix id search --- src/BucketSearchParallel.h | 10 +++++----- src/BucketSearchSerial.h | 11 ++++++----- src/NanoFlannAdaptor.h | 8 ++++---- src/OctTree.h | 11 ++++++----- src/Particles.h | 8 ++++++++ tests/CMakeLists.txt | 2 +- tests/id_search.h | 7 ++++--- 7 files changed, 34 insertions(+), 23 deletions(-) diff --git a/src/BucketSearchParallel.h b/src/BucketSearchParallel.h index 9d261470..6c325d5b 100644 --- a/src/BucketSearchParallel.h +++ b/src/BucketSearchParallel.h @@ -491,12 +491,12 @@ struct bucket_search_parallel_query { CUDA_HOST_DEVICE raw_pointer find(const size_t id) const { const size_t n = number_of_particles(); - const size_t map_index = detail::lower_bound(m_id_map_key,m_id_map_key+n,id) - - m_id_map_key; - if (m_id_map_key[map_index] == id) { - return m_particles_begin + m_id_map_value[map_index]; + int *last = m_id_map_key+n; + int *first = detail::lower_bound(m_id_map_key,last,id); + if ((first != last) && !(id < *first)) { + return m_particles_begin + m_id_map_value[first-m_id_map_key]; } else { - return m_particles_begin+n; + return m_particles_begin + n; } } diff --git a/src/BucketSearchSerial.h b/src/BucketSearchSerial.h index d02f3a12..6d274037 100644 --- a/src/BucketSearchSerial.h +++ b/src/BucketSearchSerial.h @@ -977,14 +977,15 @@ struct bucket_search_serial_query { CUDA_HOST_DEVICE raw_pointer find(const size_t id) const { const size_t n = number_of_particles(); - const size_t map_index = detail::lower_bound(m_id_map_key,m_id_map_key+n,id) - - m_id_map_key; - if (m_id_map_key[map_index] == id) { - return m_particles_begin + m_id_map_value[map_index]; + int *last = m_id_map_key+n; + int *first = detail::lower_bound(m_id_map_key,last,id); + if ((first != last) && !(id < *first)) { + return m_particles_begin + m_id_map_value[first-m_id_map_key]; } else { - return m_particles_begin+n; + return m_particles_begin + n; } } + /* * functions for updating search ds diff --git a/src/NanoFlannAdaptor.h b/src/NanoFlannAdaptor.h index c99a4a48..f2a9c4e5 100644 --- a/src/NanoFlannAdaptor.h +++ b/src/NanoFlannAdaptor.h @@ -466,10 +466,10 @@ struct nanoflann_adaptor_query { CUDA_HOST_DEVICE raw_pointer find(const size_t id) const { const size_t n = number_of_particles(); - const size_t map_index = detail::lower_bound(m_id_map_key,m_id_map_key+n,id) - - m_id_map_key; - if (m_id_map_key[map_index] == id) { - return m_particles_begin + m_id_map_value[map_index]; + int *last = m_id_map_key+n; + int *first = detail::lower_bound(m_id_map_key,last,id); + if ((first != last) && !(id < *first)) { + return m_particles_begin + m_id_map_value[first-m_id_map_key]; } else { return m_particles_begin + n; } diff --git a/src/OctTree.h b/src/OctTree.h index fbfcd723..287cccf9 100644 --- a/src/OctTree.h +++ b/src/OctTree.h @@ -752,14 +752,15 @@ struct octtree_query { CUDA_HOST_DEVICE raw_pointer find(const size_t id) const { const size_t n = number_of_particles(); - const size_t map_index = detail::lower_bound(m_id_map_key,m_id_map_key+n,id) - - m_id_map_key; - if (m_id_map_key[map_index] == id) { - return m_particles_begin + m_id_map_value[map_index]; + int *last = m_id_map_key+n; + int *first = detail::lower_bound(m_id_map_key,last,id); + if ((first != last) && !(id < *first)) { + return m_particles_begin + m_id_map_value[first-m_id_map_key]; } else { - return m_particles_begin+n; + return m_particles_begin + n; } } + ABORIA_HOST_DEVICE_IGNORE_WARN CUDA_HOST_DEVICE diff --git a/src/Particles.h b/src/Particles.h index b9a6b3d5..260931f4 100644 --- a/src/Particles.h +++ b/src/Particles.h @@ -820,6 +820,14 @@ class Particles { update_begin); search.update_iterators(begin(),end()); } + if (ABORIA_LOG_LEVEL >= 4) { + std::cout << "particle ids:\n"; + for (auto i = begin(); i != end(); ++i) { + std::cout << *get(i) << ','; + } + std::cout << std::endl; + } + /* diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 3bb6b517..a431624d 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -198,7 +198,7 @@ set(IDSearchTest ) if (Aboria_USE_THRUST) list(APPEND IDSearchTest - test_thrust_vector_bucket_search_serial + #test_thrust_vector_bucket_search_serial test_thrust_vector_bucket_search_parallel test_thrust_vector_octtree ) diff --git a/tests/id_search.h b/tests/id_search.h index 3a102be3..1e6d4452 100644 --- a/tests/id_search.h +++ b/tests/id_search.h @@ -162,6 +162,7 @@ class IDSearchTest : public CxxTest::TestSuite { std::cout << "random test (D="<(); + //helper_d_test_list_random(); #endif } void test_thrust_vector_bucket_search_parallel(void) { #if defined(__aboria_have_thrust__) - helper_d_test_list_regular(); + helper_d_test_list_random(); #endif } void test_thrust_vector_octtree(void) { #if defined(__aboria_have_thrust__) - helper_d_test_list_regular(); + helper_d_test_list_random(); #endif }