cpp-toolbox  0.0.1
A toolbox library for C++
Loading...
Searching...
No Matches
ransac_registration_impl.hpp
Go to the documentation of this file.
#pragma once

#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <future>
#include <limits>
#include <numeric>
#include <random>
#include <thread>
#include <vector>

11
12namespace toolbox::pcl
13{
14
15template<typename DataType>
17{
18 // 初始化结果 / Initialize result
19 result.transformation.setIdentity();
20 result.fitness_score = std::numeric_limits<DataType>::max();
21 result.inliers.clear();
22 result.num_iterations = 0;
23 result.converged = false;
24
25 // 获取对应关系 / Get correspondences
26 auto correspondences = this->get_correspondences();
27 if (!correspondences || correspondences->empty()) {
28 LOG_ERROR_S << "RANSAC: 没有提供对应关系 / No correspondences provided";
29 return false;
30 }
31
32 // 如果对应关系太少 / If too few correspondences
33 if (correspondences->size() < m_sample_size) {
34 LOG_ERROR_S << "RANSAC: 对应关系数量不足 / Insufficient correspondences: "
35 << correspondences->size() << " < " << m_sample_size;
36 return false;
37 }
38
39 LOG_DEBUG_S << "RANSAC: 开始配准,对应关系数量 / Starting registration with "
40 "correspondences: "
41 << correspondences->size();
42
43 // 初始化随机数生成器 / Initialize random number generator
44 std::mt19937 generator(this->get_random_seed());
45
46 // 计算所需迭代次数 / Calculate required iterations
47 DataType outlier_ratio =
48 static_cast<DataType>(0.5); // 初始估计 / Initial estimate
49 std::size_t max_iterations = this->get_max_iterations();
50 std::size_t adaptive_iterations = calculate_iterations(outlier_ratio);
51 std::size_t iterations = std::min(max_iterations, adaptive_iterations);
52
53 // 最佳结果 / Best result
54 transformation_t best_transform;
55 best_transform.setIdentity();
56 std::vector<std::size_t> best_inliers;
57 std::size_t best_inlier_count = 0;
58
59 // 收敛检查变量 / Convergence check variables
60 const std::size_t convergence_window_size = 20; // 滑动窗口大小 / Sliding window size
61 std::vector<std::size_t> inlier_history; // 记录每次迭代的最佳内点数 / Record best inlier count for each iteration
62 inlier_history.reserve(convergence_window_size);
63
64 // 计时器 / Timer
66 timer.start();
67
68 // 主RANSAC循环 / Main RANSAC loop
69 for (std::size_t iter = 0; iter < iterations; ++iter) {
70 result.num_iterations = iter + 1;
71
72 // 随机采样 / Random sampling
73 std::vector<correspondence_t> sample;
74 sample_correspondences(sample, generator);
75
76 // 估计变换 / Estimate transformation
77 transformation_t transform = estimate_rigid_transform_svd(sample);
78
79 // 计算内点 / Count inliers
80 std::vector<std::size_t> inliers;
81 std::size_t inlier_count = count_inliers(transform, inliers);
82
83 // 更新最佳结果 / Update best result
84 if (inlier_count > best_inlier_count) {
85 best_transform = transform;
86 best_inliers = std::move(inliers);
87 best_inlier_count = inlier_count;
88
89 // 更新自适应迭代次数 / Update adaptive iterations
90 outlier_ratio =
91 static_cast<DataType>(correspondences->size() - best_inlier_count)
92 / correspondences->size();
93 adaptive_iterations = calculate_iterations(outlier_ratio);
94 iterations = std::min(max_iterations, adaptive_iterations);
95
96 // 早停检查 / Early stopping check
97 DataType inlier_ratio =
98 static_cast<DataType>(best_inlier_count) / correspondences->size();
99 if (inlier_ratio >= m_early_stop_ratio) {
100 LOG_DEBUG_S << "RANSAC: 早停,内点比例 / Early stopping, inlier ratio: "
101 << inlier_ratio;
102 break;
103 }
104 }
105
106 // 收敛检查 - 每次迭代都检查 / Convergence check - check every iteration
107 // 记录当前最佳内点数 / Record current best inlier count
108 inlier_history.push_back(best_inlier_count);
109
110 // 如果历史记录已满,使用滑动窗口 / Use sliding window if history is full
111 if (inlier_history.size() > convergence_window_size) {
112 inlier_history.erase(inlier_history.begin());
113 }
114
115 // 需要足够的历史数据才能判断收敛 / Need enough history to check convergence
116 if (inlier_history.size() >= convergence_window_size) {
117 // 计算窗口内的总改进 / Calculate total improvement in window
118 std::size_t window_improvement = inlier_history.back() - inlier_history.front();
119
120 // 计算平均每次迭代的改进 / Calculate average improvement per iteration
121 DataType avg_improvement_per_iter = static_cast<DataType>(window_improvement)
122 / static_cast<DataType>(convergence_window_size - 1);
123
124 // 计算相对改进率 / Calculate relative improvement rate
125 DataType relative_improvement = 0.0;
126 if (inlier_history.front() > 0) {
127 relative_improvement = static_cast<DataType>(window_improvement)
128 / static_cast<DataType>(inlier_history.front());
129 }
130
131 // 收敛条件:绝对改进和相对改进都很小 / Convergence: both absolute and relative improvements are small
132 const DataType min_avg_improvement = 0.5; // 平均每次迭代至少增加0.5个内点 / At least 0.5 inlier per iteration on average
133 const DataType min_relative_improvement = 0.01; // 或者1%的相对改进 / Or 1% relative improvement
134
135 if (avg_improvement_per_iter < min_avg_improvement &&
136 relative_improvement < min_relative_improvement) {
137 LOG_DEBUG_S << "RANSAC: 收敛,最近 " << convergence_window_size
138 << " 次迭代内点数增加 / Converged, inlier count improved by "
139 << window_improvement << " 个 / in last " << convergence_window_size
140 << " iterations (相对改进 / relative improvement: "
141 << relative_improvement * 100 << "%)";
142 result.converged = true;
143 break;
144 }
145 }
146 }
147
148 timer.stop();
149 double elapsed_time = timer.elapsed_time();
150 LOG_DEBUG_S << "RANSAC: 完成 " << result.num_iterations
151 << " 次迭代,耗时 / iterations in: " << elapsed_time << " 秒/s";
152
153 // 检查是否找到足够的内点 / Check if enough inliers found
154 if (best_inlier_count < this->get_min_inliers()) {
155 LOG_ERROR_S << "RANSAC: 内点数量不足 / Insufficient inliers: "
156 << best_inlier_count << " < " << this->get_min_inliers();
157 return false;
158 }
159
160 // 精炼结果(可选) / Refine result (optional)
161 if (m_refine_result && best_inlier_count >= m_sample_size) {
162 LOG_DEBUG_S << "RANSAC: 使用 " << best_inlier_count
163 << " 个内点精炼结果 / Refining with inliers";
164 best_transform = refine_transformation(best_inliers);
165
166 // 重新计算内点 / Recompute inliers
167 best_inlier_count = count_inliers(best_transform, best_inliers);
168 }
169
170 // 设置结果 / Set result
171 result.transformation = best_transform;
172 result.inliers = std::move(best_inliers);
173 result.fitness_score = compute_fitness_score(best_transform, result.inliers);
174 result.converged =
175 result.converged || (best_inlier_count >= this->get_min_inliers());
176
177 LOG_DEBUG_S << "RANSAC: 配准完成,内点 / Registration complete, inliers: "
178 << result.inliers.size() << "/" << correspondences->size()
179 << ", 质量评分 / fitness score: " << result.fitness_score;
180
181 return true;
182}
183
184template<typename DataType>
186{
187 // 基类已经验证了点云,这里验证RANSAC特定的输入 / Base class validated clouds,
188 // validate RANSAC-specific input
189 auto correspondences = this->get_correspondences();
190 if (!correspondences || correspondences->empty()) {
191 LOG_ERROR_S << "RANSAC: 需要对应关系 / Correspondences required";
192 return false;
193 }
194
195 if (correspondences->size() < m_sample_size) {
196 LOG_ERROR_S << "RANSAC: 对应关系数量不足以进行采样 / Not enough "
197 "correspondences for sampling";
198 return false;
199 }
200
201 return true;
202}
203
204template<typename DataType>
206 DataType outlier_ratio) const
207{
208 // N = log(1 - p) / log(1 - (1 - e)^s)
209 // p = confidence, e = outlier_ratio, s = sample_size
210
211 if (outlier_ratio <= 0 || outlier_ratio >= 1) {
212 return this->get_max_iterations();
213 }
214
215 DataType inlier_ratio = 1 - outlier_ratio;
216 DataType sample_success_prob =
217 std::pow(inlier_ratio, static_cast<DataType>(m_sample_size));
218
219 if (sample_success_prob <= 0 || sample_success_prob >= 1) {
220 return this->get_max_iterations();
221 }
222
223 DataType num_iterations =
224 std::log(1 - m_confidence) / std::log(1 - sample_success_prob);
225
226 return static_cast<std::size_t>(std::ceil(num_iterations));
227}
228
229template<typename DataType>
230void ransac_registration_t<DataType>::sample_correspondences(
231 std::vector<correspondence_t>& sample, std::mt19937& generator) const
232{
233 auto correspondences = this->get_correspondences();
234 const std::size_t num_correspondences = correspondences->size();
235
236 // 清空并预留空间 / Clear and reserve space
237 sample.clear();
238 sample.reserve(m_sample_size);
239
240 // 使用Fisher-Yates洗牌算法的变体进行无重复采样 / Use variant of Fisher-Yates
241 // shuffle for sampling without replacement
242 std::vector<std::size_t> indices(num_correspondences);
243 std::iota(indices.begin(), indices.end(), 0);
244
245 for (std::size_t i = 0; i < m_sample_size; ++i) {
246 std::uniform_int_distribution<std::size_t> dist(i, num_correspondences - 1);
247 std::size_t j = dist(generator);
248 std::swap(indices[i], indices[j]);
249 sample.push_back((*correspondences)[indices[i]]);
250 }
251}
252
253template<typename DataType>
255ransac_registration_t<DataType>::estimate_rigid_transform_svd(
256 const std::vector<correspondence_t>& sample) const
257{
258 const std::size_t n = sample.size();
259 transformation_t transform;
260 transform.setIdentity();
261
262 if (n < 3) {
263 LOG_WARN_S << "RANSAC: 样本数量不足以估计变换 / Insufficient samples for "
264 "transformation estimation";
265 return transform;
266 }
267
268 auto source_cloud = this->get_source_cloud();
269 auto target_cloud = this->get_target_cloud();
270
271 // 计算质心 / Compute centroids
272 vector3_t source_centroid = vector3_t::Zero();
273 vector3_t target_centroid = vector3_t::Zero();
274
275 for (const auto& corr : sample) {
276 const auto& src_pt = source_cloud->points[corr.src_idx];
277 const auto& tgt_pt = target_cloud->points[corr.dst_idx];
278
279 source_centroid += vector3_t(src_pt.x, src_pt.y, src_pt.z);
280 target_centroid += vector3_t(tgt_pt.x, tgt_pt.y, tgt_pt.z);
281 }
282
283 source_centroid /= static_cast<DataType>(n);
284 target_centroid /= static_cast<DataType>(n);
285
286 // 构建协方差矩阵 / Build covariance matrix
287 matrix3_t H = matrix3_t::Zero();
288
289 for (const auto& corr : sample) {
290 const auto& src_pt = source_cloud->points[corr.src_idx];
291 const auto& tgt_pt = target_cloud->points[corr.dst_idx];
292
293 vector3_t src_centered(src_pt.x - source_centroid[0],
294 src_pt.y - source_centroid[1],
295 src_pt.z - source_centroid[2]);
296 vector3_t tgt_centered(tgt_pt.x - target_centroid[0],
297 tgt_pt.y - target_centroid[1],
298 tgt_pt.z - target_centroid[2]);
299
300 H += src_centered * tgt_centered.transpose();
301 }
302
303 // SVD分解 / SVD decomposition
304 Eigen::JacobiSVD<matrix3_t> svd(H, Eigen::ComputeFullU | Eigen::ComputeFullV);
305 matrix3_t U = svd.matrixU();
306 matrix3_t V = svd.matrixV();
307
308 // 计算旋转矩阵 / Compute rotation matrix
309 matrix3_t R = V * U.transpose();
310
311 // 处理反射情况 / Handle reflection case
312 if (R.determinant() < 0) {
313 V.col(2) *= -1;
314 R = V * U.transpose();
315 }
316
317 // 计算平移向量 / Compute translation vector
318 vector3_t t = target_centroid - R * source_centroid;
319
320 // 构建4x4变换矩阵 / Build 4x4 transformation matrix
321 transform.template block<3, 3>(0, 0) = R;
322 transform.template block<3, 1>(0, 3) = t;
323
324 return transform;
325}
326
327template<typename DataType>
328std::size_t ransac_registration_t<DataType>::count_inliers(
329 const transformation_t& transform, std::vector<std::size_t>& inliers) const
330{
331 inliers.clear();
332
333 auto correspondences = this->get_correspondences();
334 auto source_cloud = this->get_source_cloud();
335 auto target_cloud = this->get_target_cloud();
336
337 const DataType threshold_squared =
338 this->get_inlier_threshold() * this->get_inlier_threshold();
339
340 // 根据是否启用并行选择实现 / Choose implementation based on parallel enabled
341 if (this->is_parallel_enabled()) {
342 // 并行版本 / Parallel version
343 // 获取线程数 / Get number of threads
344 const std::size_t num_threads = std::thread::hardware_concurrency();
345 std::vector<std::vector<std::size_t>> local_inliers(num_threads);
346
347 // 将对应关系分成多个块 / Divide correspondences into chunks
348 const std::size_t chunk_size =
349 (correspondences->size() + num_threads - 1) / num_threads;
350 std::vector<std::future<void>> futures;
351
352 for (std::size_t thread_id = 0; thread_id < num_threads; ++thread_id) {
353 std::size_t start = thread_id * chunk_size;
354 std::size_t end = std::min(start + chunk_size, correspondences->size());
355
356 if (start >= end)
357 break;
358
359 futures.push_back(std::async(
360 std::launch::async,
361 [&, thread_id, start, end]()
362 {
363 auto& thread_inliers = local_inliers[thread_id];
364
365 for (std::size_t i = start; i < end; ++i) {
366 const auto& corr = (*correspondences)[i];
367 const auto& src_pt = source_cloud->points[corr.src_idx];
368
369 // 变换源点 / Transform source point
370 vector3_t src_vec(src_pt.x, src_pt.y, src_pt.z);
371 vector3_t transformed =
372 transform.template block<3, 3>(0, 0) * src_vec
373 + transform.template block<3, 1>(0, 3);
374
375 // 计算到目标点的距离 / Compute distance to target point
376 const auto& tgt_pt = target_cloud->points[corr.dst_idx];
377 DataType dx = transformed[0] - tgt_pt.x;
378 DataType dy = transformed[1] - tgt_pt.y;
379 DataType dz = transformed[2] - tgt_pt.z;
380 DataType dist_squared = dx * dx + dy * dy + dz * dz;
381
382 if (dist_squared <= threshold_squared) {
383 thread_inliers.push_back(i);
384 }
385 }
386 }));
387 }
388
389 // 等待所有线程完成 / Wait for all threads to complete
390 for (auto& future : futures) {
391 future.wait();
392 }
393
394 // 合并结果 / Merge results
395 for (const auto& thread_inliers : local_inliers) {
396 inliers.insert(
397 inliers.end(), thread_inliers.begin(), thread_inliers.end());
398 }
399 } else {
400 // 串行版本 / Sequential version
401 for (std::size_t i = 0; i < correspondences->size(); ++i) {
402 const auto& corr = (*correspondences)[i];
403 const auto& src_pt = source_cloud->points[corr.src_idx];
404
405 // 变换源点 / Transform source point
406 vector3_t src_vec(src_pt.x, src_pt.y, src_pt.z);
407 vector3_t transformed = transform.template block<3, 3>(0, 0) * src_vec
408 + transform.template block<3, 1>(0, 3);
409
410 // 计算到目标点的距离 / Compute distance to target point
411 const auto& tgt_pt = target_cloud->points[corr.dst_idx];
412 DataType dx = transformed[0] - tgt_pt.x;
413 DataType dy = transformed[1] - tgt_pt.y;
414 DataType dz = transformed[2] - tgt_pt.z;
415 DataType dist_squared = dx * dx + dy * dy + dz * dz;
416
417 if (dist_squared <= threshold_squared) {
418 inliers.push_back(i);
419 }
420 }
421 }
422
423 return inliers.size();
424}
425
426template<typename DataType>
428ransac_registration_t<DataType>::refine_transformation(
429 const std::vector<std::size_t>& inlier_indices) const
430{
431 // 使用所有内点重新估计变换 / Re-estimate transformation using all inliers
432 auto correspondences = this->get_correspondences();
433 std::vector<correspondence_t> inlier_correspondences;
434 inlier_correspondences.reserve(inlier_indices.size());
435
436 for (std::size_t idx : inlier_indices) {
437 inlier_correspondences.push_back((*correspondences)[idx]);
438 }
439
440 return estimate_rigid_transform_svd(inlier_correspondences);
441}
442
443template<typename DataType>
444DataType ransac_registration_t<DataType>::compute_fitness_score(
445 const transformation_t& transform,
446 const std::vector<std::size_t>& inliers) const
447{
448 if (inliers.empty()) {
449 return std::numeric_limits<DataType>::max();
450 }
451
452 // 使用 LCPMetric 计算适应度分数 / Use LCPMetric to compute fitness score
453 toolbox::metrics::LCPMetric<DataType> lcp_metric(this->get_inlier_threshold());
454
455 auto correspondences = this->get_correspondences();
456 auto source_cloud = this->get_source_cloud();
457 auto target_cloud = this->get_target_cloud();
458
459 // 创建内点对应的源点云和目标点云 / Create source and target clouds for inliers
460 toolbox::types::point_cloud_t<DataType> inlier_source, inlier_target;
461 inlier_source.points.reserve(inliers.size());
462 inlier_target.points.reserve(inliers.size());
463
464 for (std::size_t idx : inliers) {
465 const auto& corr = (*correspondences)[idx];
466 inlier_source.points.push_back(source_cloud->points[corr.src_idx]);
467 inlier_target.points.push_back(target_cloud->points[corr.dst_idx]);
468 }
469
470 // 计算LCP分数 / Compute LCP score
471 return lcp_metric.compute_lcp_score(inlier_source, inlier_target, transform);
472}
473
// Explicit instantiation of ransac_registration_t for float and double.
// NOTE(review): these are explicit instantiation *definitions* placed in a
// header that is guarded only by #pragma once. If this file is included from
// more than one translation unit, the program contains more than one
// definition of the same instantiation (ill-formed, no diagnostic required).
// Confirm this header is included from exactly one .cpp. Otherwise, move these
// two lines into a single .cpp and declare them here as `extern template`.
template class ransac_registration_t<float>;
template class ransac_registration_t<double>;
477
478} // namespace toolbox::pcl
LCP (Largest Common Pointset) metric for evaluating point cloud registration.
Definition point_cloud_metrics.hpp:465
RANSAC粗配准算法 / RANSAC coarse registration algorithm.
Definition ransac_registration.hpp:46
bool validate_input_impl() const
额外的输入验证 / Additional input validation
Definition ransac_registration_impl.hpp:185
Eigen::Matrix< DataType, 4, 4 > transformation_t
Definition ransac_registration.hpp:55
bool align_impl(result_type &result)
派生类实现的配准算法 / Registration algorithm implementation by derived class
Definition ransac_registration_impl.hpp:16
包含点和相关数据的点云类 / A point cloud class containing points and associated data
Definition point.hpp:268
std::vector< point_t< T > > points
点坐标 / Point coordinates
Definition point.hpp:270
A high-resolution stopwatch timer for measuring elapsed time.
Definition timer.hpp:41
void stop()
Stop (pause) the timer and accumulate the duration.
auto elapsed_time() const -> double
Get the total elapsed time in seconds.
void start()
Start or resume the timer.
#define LOG_DEBUG_S
DEBUG级别流式日志的宏 / Macro for DEBUG level stream logging.
Definition thread_logger.hpp:1329
#define LOG_ERROR_S
ERROR级别流式日志的宏 / Macro for ERROR level stream logging.
Definition thread_logger.hpp:1332
#define LOG_WARN_S
WARN级别流式日志的宏 / Macro for WARN level stream logging.
Definition thread_logger.hpp:1331
Definition base_correspondence_generator.hpp:18
std::vector< T > sample(const std::vector< T > &population, size_t k)
从vector中随机采样k个元素/Randomly sample k elements from vector
Definition random.hpp:518