cpp-toolbox  0.0.1
A toolbox library for C++
Loading...
Searching...
No Matches
bfknn_parallel_impl.hpp
Go to the documentation of this file.
1#pragma once
2
3#include <algorithm>
4#include <cmath>
5#include <limits>
6#include <mutex>
7#include <numeric>
8#include <vector>
9
11
12namespace toolbox::pcl
13{
14
15// Generic parallel brute-force KNN implementation
// Copying overload of the input setter. The signature line (listing line 17) is
// missing from this extracted listing; the symbol index at the end of the page
// gives it as: std::size_t set_input_impl(const container_type& data).
// Copy-constructs a new shared container from data, stores it in m_data and
// returns the number of stored elements.
16template<typename Element, typename Metric>
18{
19 m_data = std::make_shared<container_type>(data);
20 return m_data->size();
21}
22
// Non-copying overload of the input setter. The signature line (listing line 24)
// is missing from this extracted listing.
// Stores the caller-supplied data handle in m_data unchanged and returns the
// element count, or 0 when the handle is null.
// NOTE(review): presumably the parameter is a container_ptr (shared ownership) -
// confirm against bfknn_parallel.hpp.
23template<typename Element, typename Metric>
25{
26 m_data = data;
27 return m_data ? m_data->size() : 0;
28}
29
// Selects the compile-time metric. The signature line (listing line 31) is
// missing from this extracted listing; the symbol index at the end of the page
// gives it as: void set_metric_impl(const metric_type& metric).
// Stores a copy of metric and clears m_use_runtime_metric, so the query
// functions in this file call m_compile_time_metric(query, element).
30template<typename Element, typename Metric>
32{
33 m_compile_time_metric = metric;
34 m_use_runtime_metric = false;
35}
36
// Selects a runtime metric. The signature lines (listing lines 38-39) are
// missing from this extracted listing, so the function name and parameter type
// are not visible here.
// Stores metric in m_runtime_metric and sets m_use_runtime_metric. The query
// functions in this file dereference it as m_runtime_metric->distance(a, b, 3)
// and only when it is non-null; a null metric therefore falls back to the
// compile-time metric without any error being reported.
// NOTE(review): presumably a shared pointer to a metric interface - confirm
// against bfknn_parallel.hpp.
37template<typename Element, typename Metric>
40{
41 m_runtime_metric = metric;
42 m_use_runtime_metric = true;
43}
44
// Brute-force k-nearest-neighbour query. The first signature line (listing line
// 46) is missing from this extracted listing; the symbol index at the end of the
// page gives it as: bool kneighbors_impl(const element_type& query,
// std::size_t num_neighbors, std::vector<std::size_t>& indices,
// std::vector<distance_type>& distances).
//
// query          element to search around
// num_neighbors  requested k; clamped to the number of stored elements
// indices        out: positions in *m_data of the k nearest elements
// distances      out: matching distances, ascending
//
// Returns false, leaving indices and distances untouched, when no data is set
// or the data set is empty; returns true otherwise (including k == 0).
// Elements at equal distance come out in unspecified relative order because
// std::partial_sort is not a stable sort.
45template<typename Element, typename Metric>
47 const element_type& query,
48 std::size_t num_neighbors,
49 std::vector<std::size_t>& indices,
50 std::vector<distance_type>& distances)
51{
52 if (!m_data || m_data->empty())
53 {
54 return false;
55 }
56
57 const std::size_t data_size = m_data->size();
// Clamp k so that begin() + num_neighbors passed to partial_sort below never
// goes past end().
58 num_neighbors = std::min(num_neighbors, data_size);
59
60 // For small datasets or when parallel is disabled, use sequential version
61 if (!m_parallel_enabled || data_size < k_parallel_threshold)
62 {
63 // Sequential implementation (same as non-parallel version)
64 std::vector<std::pair<distance_type, std::size_t>> distance_index_pairs;
65 distance_index_pairs.reserve(data_size);
66
67 for (std::size_t i = 0; i < data_size; ++i)
68 {
69 distance_type dist;
// The runtime metric is used only when it was selected AND is non-null;
// otherwise the compile-time metric functor is called.
70 if (m_use_runtime_metric && m_runtime_metric)
71 {
72 // For point types, convert to arrays and use distance method
// NOTE(review): this path reads .x/.y/.z and passes a fixed dimension of 3,
// so it requires 3-component point elements regardless of Element.
73 value_type arr_query[3] = {query.x, query.y, query.z};
74 value_type arr_data[3] = {(*m_data)[i].x, (*m_data)[i].y, (*m_data)[i].z};
75 dist = m_runtime_metric->distance(arr_query, arr_data, 3);
76 }
77 else
78 {
79 dist = m_compile_time_metric(query, (*m_data)[i]);
80 }
81 distance_index_pairs.emplace_back(dist, i);
82 }
83
// Only the first num_neighbors entries need to be ordered, by distance alone.
84 std::partial_sort(distance_index_pairs.begin(),
85 distance_index_pairs.begin() + num_neighbors,
86 distance_index_pairs.end(),
87 [](const auto& a, const auto& b) { return a.first < b.first; });
88
89 indices.resize(num_neighbors);
90 distances.resize(num_neighbors);
91 for (std::size_t i = 0; i < num_neighbors; ++i)
92 {
93 distances[i] = distance_index_pairs[i].first;
94 indices[i] = distance_index_pairs[i].second;
95 }
96 }
97 else
98 {
99 // Parallel implementation
// NOTE(review): listing line 100, which declares thread_pool, is missing from
// this extract; presumably it binds the thread pool singleton listed in the
// symbol index - confirm against the real header.
101 const std::size_t num_threads = thread_pool.get_thread_count();
// Ceiling division so that num_threads chunks cover every element.
// NOTE(review): divides by zero if get_thread_count() can return 0 - confirm
// the pool guarantees at least one worker.
102 const std::size_t chunk_size = (data_size + num_threads - 1) / num_threads;
103
104 // Thread-local storage for distance-index pairs
// One result vector per task; task t appends only to thread_results[t].
105 std::vector<std::vector<std::pair<distance_type, std::size_t>>> thread_results(num_threads);
106 std::vector<std::future<void>> futures;
107
108 // Launch parallel tasks
109 for (std::size_t t = 0; t < num_threads; ++t)
110 {
111 const std::size_t start = t * chunk_size;
112 const std::size_t end = std::min(start + chunk_size, data_size);
113
// Rounding up chunk_size can leave trailing tasks with no elements.
114 if (start >= data_size) break;
115
// query and thread_results are captured by reference; the wait loop below
// runs before this function returns.
116 futures.emplace_back(thread_pool.submit([this, &query, start, end, t, &thread_results]() {
117 auto& local_results = thread_results[t];
118 local_results.reserve(end - start);
119
120 for (std::size_t i = start; i < end; ++i)
121 {
122 distance_type dist;
123 if (m_use_runtime_metric && m_runtime_metric)
124 {
125 // For point types, convert to arrays and use distance method
126 value_type arr_query[3] = {query.x, query.y, query.z};
127 value_type arr_data[3] = {(*m_data)[i].x, (*m_data)[i].y, (*m_data)[i].z};
128 dist = m_runtime_metric->distance(arr_query, arr_data, 3);
129 }
130 else
131 {
132 dist = m_compile_time_metric(query, (*m_data)[i]);
133 }
134 local_results.emplace_back(dist, i);
135 }
136 }));
137 }
138
139 // Wait for all tasks to complete
// NOTE(review): std::future::wait only blocks; unlike get() it does not rethrow
// an exception stored by a task, so a throwing metric would go unnoticed and
// leave a short result set - confirm this is intended.
140 for (auto& future : futures)
141 {
142 future.wait();
143 }
144
145 // Merge results from all threads
// Chunks are concatenated in task order, so all_results is in index order
// before the partial sort.
146 std::vector<std::pair<distance_type, std::size_t>> all_results;
147 all_results.reserve(data_size);
148 for (const auto& thread_result : thread_results)
149 {
150 all_results.insert(all_results.end(), thread_result.begin(), thread_result.end());
151 }
152
153 // Find k nearest neighbors
154 std::partial_sort(all_results.begin(),
155 all_results.begin() + num_neighbors,
156 all_results.end(),
157 [](const auto& a, const auto& b) { return a.first < b.first; });
158
159 indices.resize(num_neighbors);
160 distances.resize(num_neighbors);
161 for (std::size_t i = 0; i < num_neighbors; ++i)
162 {
163 distances[i] = all_results[i].first;
164 indices[i] = all_results[i].second;
165 }
166 }
167
168 return true;
169}
170
// Brute-force radius query. The first signature line (listing line 172) is
// missing from this extracted listing, so the function name is not visible
// here; the parameter list below is.
//
// query      element to search around
// radius     inclusive distance bound (an element matches when dist <= radius)
// indices    out: positions in *m_data of all matching elements
// distances  out: matching distances, ascending
//
// Returns false, leaving indices and distances untouched, when no data is set,
// the data set is empty, or radius <= 0. Otherwise clears both outputs, fills
// them and returns true - also when nothing lies within the radius, in which
// case both outputs are empty. Elements at equal distance come out in
// unspecified relative order because std::sort is not a stable sort.
171template<typename Element, typename Metric>
173 const element_type& query,
174 distance_type radius,
175 std::vector<std::size_t>& indices,
176 std::vector<distance_type>& distances)
177{
178 if (!m_data || m_data->empty() || radius <= 0)
179 {
180 return false;
181 }
182
// The outputs are cleared only after the guard above has passed.
183 indices.clear();
184 distances.clear();
185
186 const std::size_t data_size = m_data->size();
187
188 // For small datasets or when parallel is disabled, use sequential version
189 if (!m_parallel_enabled || data_size < k_parallel_threshold)
190 {
191 std::vector<std::pair<distance_type, std::size_t>> distance_index_pairs;
192
193 for (std::size_t i = 0; i < data_size; ++i)
194 {
195 distance_type dist;
// The runtime metric is used only when it was selected AND is non-null;
// otherwise the compile-time metric functor is called.
196 if (m_use_runtime_metric && m_runtime_metric)
197 {
198 // For point types, convert to arrays and use distance method
// NOTE(review): this path reads .x/.y/.z and passes a fixed dimension of 3,
// so it requires 3-component point elements regardless of Element.
199 value_type arr_query[3] = {query.x, query.y, query.z};
200 value_type arr_data[3] = {(*m_data)[i].x, (*m_data)[i].y, (*m_data)[i].z};
201 dist = m_runtime_metric->distance(arr_query, arr_data, 3);
202 }
203 else
204 {
205 dist = m_compile_time_metric(query, (*m_data)[i]);
206 }
207
208 if (dist <= radius)
209 {
210 distance_index_pairs.emplace_back(dist, i);
211 }
212 }
213
214 std::sort(distance_index_pairs.begin(), distance_index_pairs.end(),
215 [](const auto& a, const auto& b) { return a.first < b.first; });
216
217 indices.reserve(distance_index_pairs.size());
218 distances.reserve(distance_index_pairs.size());
219 for (const auto& [dist, idx] : distance_index_pairs)
220 {
221 distances.push_back(dist);
222 indices.push_back(idx);
223 }
224 }
225 else
226 {
227 // Parallel implementation
// NOTE(review): listing line 228, which declares thread_pool, is missing from
// this extract; presumably it binds the thread pool singleton listed in the
// symbol index - confirm against the real header.
229 const std::size_t num_threads = thread_pool.get_thread_count();
// Ceiling division so that num_threads chunks cover every element.
// NOTE(review): divides by zero if get_thread_count() can return 0 - confirm
// the pool guarantees at least one worker.
230 const std::size_t chunk_size = (data_size + num_threads - 1) / num_threads;
231
232 // Thread-local storage for results
// One result vector per task; task t appends only to thread_results[t].
233 std::vector<std::vector<std::pair<distance_type, std::size_t>>> thread_results(num_threads);
234 std::vector<std::future<void>> futures;
235
236 // Launch parallel tasks
237 for (std::size_t t = 0; t < num_threads; ++t)
238 {
239 const std::size_t start = t * chunk_size;
240 const std::size_t end = std::min(start + chunk_size, data_size);
241
// Rounding up chunk_size can leave trailing tasks with no elements.
242 if (start >= data_size) break;
243
// query and thread_results are captured by reference, radius by value; the
// wait loop below runs before this function returns.
244 futures.emplace_back(thread_pool.submit([this, &query, radius, start, end, t, &thread_results]() {
245 auto& local_results = thread_results[t];
246
247 for (std::size_t i = start; i < end; ++i)
248 {
249 distance_type dist;
250 if (m_use_runtime_metric && m_runtime_metric)
251 {
252 // For point types, convert to arrays and use distance method
253 value_type arr_query[3] = {query.x, query.y, query.z};
254 value_type arr_data[3] = {(*m_data)[i].x, (*m_data)[i].y, (*m_data)[i].z};
255 dist = m_runtime_metric->distance(arr_query, arr_data, 3);
256 }
257 else
258 {
259 dist = m_compile_time_metric(query, (*m_data)[i]);
260 }
261
262 if (dist <= radius)
263 {
264 local_results.emplace_back(dist, i);
265 }
266 }
267 }));
268 }
269
270 // Wait for all tasks to complete
// NOTE(review): std::future::wait only blocks; unlike get() it does not rethrow
// an exception stored by a task, so a throwing metric would go unnoticed and
// matches from that chunk would be missing - confirm this is intended.
271 for (auto& future : futures)
272 {
273 future.wait();
274 }
275
276 // Merge and sort results
277 std::vector<std::pair<distance_type, std::size_t>> all_results;
278 for (const auto& thread_result : thread_results)
279 {
280 all_results.insert(all_results.end(), thread_result.begin(), thread_result.end());
281 }
282
283 std::sort(all_results.begin(), all_results.end(),
284 [](const auto& a, const auto& b) { return a.first < b.first; });
285
286 indices.reserve(all_results.size());
287 distances.reserve(all_results.size());
288 for (const auto& [dist, idx] : all_results)
289 {
290 distances.push_back(dist);
291 indices.push_back(idx);
292 }
293 }
294
295 return true;
296}
297
298} // namespace toolbox::pcl
static thread_pool_singleton_t & instance()
获取单例实例/Get the singleton instance
Definition thread_pool_singleton.hpp:23
Definition metric_factory.hpp:23
Definition bfknn_parallel.hpp:14
bool kneighbors_impl(const element_type &query, std::size_t num_neighbors, std::vector< std::size_t > &indices, std::vector< distance_type > &distances)
Definition bfknn_parallel_impl.hpp:46
typename base_type::container_ptr container_ptr
Definition bfknn_parallel.hpp:22
typename traits_type::distance_type distance_type
Definition bfknn_parallel.hpp:20
typename traits_type::metric_type metric_type
Definition bfknn_parallel.hpp:19
std::size_t set_input_impl(const container_type &data)
Definition bfknn_parallel_impl.hpp:17
typename base_type::container_type container_type
Definition bfknn_parallel.hpp:21
void set_metric_impl(const metric_type &metric)
Definition bfknn_parallel_impl.hpp:31
typename Element::value_type value_type
Definition bfknn_parallel.hpp:23
typename traits_type::element_type element_type
Definition bfknn_parallel.hpp:18
Definition base_correspondence_generator.hpp:18