cpp-toolbox  0.0.1
A toolbox library for C++
Loading...
Searching...
No Matches
prosac_registration_impl.hpp
Go to the documentation of this file.
1#pragma once
2
4
5namespace toolbox::pcl
6{
7
8template<typename DataType>
10{
11 // 获取基类成员访问 / Get base class member access
12 const auto& source_cloud = this->get_source_cloud();
13 const auto& target_cloud = this->get_target_cloud();
14 const auto& correspondences = this->get_correspondences();
15 const auto max_iterations = this->get_max_iterations();
16 const auto inlier_threshold = this->get_inlier_threshold();
17 const auto min_inliers = this->get_min_inliers();
18 const auto random_seed = this->get_random_seed();
19
20 // 初始化结果 / Initialize result
21 result.transformation = transformation_t::Identity();
22 result.fitness_score = std::numeric_limits<DataType>::max();
23 result.inliers.clear();
24 result.converged = false;
25 result.num_iterations = 0;
26
27 // 检查是否有足够的对应关系 / Check if there are enough correspondences
28 const std::size_t num_correspondences = correspondences->size();
29 if (num_correspondences < m_sample_size) {
30 LOG_ERROR_S << "错误:对应关系数量不足 / Error: Insufficient correspondences: "
31 << num_correspondences << " < " << m_sample_size;
32 return false;
33 }
34
35 // 预计算采样调度 / Precompute sampling schedule
36 precompute_sampling_schedule(num_correspondences);
37
38 // 初始化随机数生成器 / Initialize random number generator
39 std::mt19937 generator(random_seed);
40
41 // PROSAC主循环变量 / PROSAC main loop variables
42 std::size_t n = m_sample_size; // 当前采样池大小 / Current sampling pool size
43 std::size_t t = 0; // 迭代计数器 / Iteration counter
44 std::size_t best_inlier_count = 0;
45 transformation_t best_transformation = transformation_t::Identity();
46 std::vector<std::size_t> best_inliers;
47
48 // 用于早停的变量 / Variables for early stopping
49 auto start_time = std::chrono::steady_clock::now();
50 const auto max_time = std::chrono::seconds(300); // 5分钟超时 / 5 minutes timeout
51
52 LOG_INFO_S << "开始PROSAC配准 / Starting PROSAC registration with "
53 << num_correspondences << " correspondences, sample size "
54 << m_sample_size;
55
56 // PROSAC主循环 / PROSAC main loop
57 while (t < max_iterations) {
58 // 检查超时 / Check timeout
59 auto current_time = std::chrono::steady_clock::now();
60 if (current_time - start_time > max_time) {
61 LOG_WARN_S << "警告:PROSAC达到时间限制 / Warning: PROSAC reached time limit";
62 break;
63 }
64
65 // 更新采样池大小 / Update sampling pool size
66 if (t == m_T_n[n - 1] && n < num_correspondences) {
67 n++;
68 }
69
70 // 渐进式采样 / Progressive sampling
71 std::vector<correspondence_t> sample;
72 sample.reserve(m_sample_size);
73 progressive_sample(sample, n, t, generator);
74
75 // 检查样本有效性 / Check sample validity
76 if (!is_sample_valid(sample)) {
77 t++;
78 continue;
79 }
80
81 // 估计变换 / Estimate transformation
82 transformation_t transform = estimate_transformation(sample);
83
84 // 计算内点 / Count inliers
85 std::vector<std::size_t> inliers;
86 std::size_t inlier_count = count_inliers(transform, inliers);
87
88 // 更新最佳模型 / Update best model
89 if (inlier_count > best_inlier_count) {
90 best_inlier_count = inlier_count;
91 best_transformation = transform;
92 best_inliers = inliers;
93
94 LOG_INFO_S << "迭代 / Iteration " << t << ": 找到更好的模型 / found better model with "
95 << inlier_count << " inliers (n=" << n << ")";
96
97 // 检查早停条件 / Check early stopping condition
98 DataType inlier_ratio = static_cast<DataType>(inlier_count) /
99 static_cast<DataType>(num_correspondences);
100 if (inlier_ratio >= m_early_stop_ratio) {
101 LOG_INFO_S << "达到早停条件 / Reached early stop condition: inlier ratio = "
102 << inlier_ratio;
103 break;
104 }
105
106 // 检查非随机性准则 / Check non-randomness criterion
107 if (check_non_randomness(inlier_count, n)) {
108 LOG_INFO_S << "满足非随机性准则 / Non-randomness criterion satisfied";
109 break;
110 }
111 }
112
113 // 检查最大性准则 / Check maximality criterion
114 if (best_inlier_count >= min_inliers &&
115 check_maximality(best_inlier_count, n, t)) {
116 LOG_INFO_S << "满足最大性准则 / Maximality criterion satisfied";
117 break;
118 }
119
120 t++;
121 m_total_samples++;
122 }
123
124 // 设置结果 / Set results
125 result.num_iterations = t;
126 m_best_inlier_count = best_inlier_count;
127
128 if (best_inlier_count >= min_inliers) {
129 // 如果需要,使用所有内点精炼结果 / Refine result using all inliers if needed
130 if (m_refine_result && best_inlier_count > m_sample_size) {
131 LOG_INFO_S << "使用 / Using " << best_inlier_count
132 << " 个内点精炼变换 / inliers to refine transformation";
133 best_transformation = refine_transformation(best_inliers);
134
135 // 重新计算内点 / Recompute inliers
136 best_inlier_count = count_inliers(best_transformation, best_inliers);
137 }
138
139 result.transformation = best_transformation;
140 result.inliers = best_inliers;
141 result.fitness_score = compute_fitness_score(best_transformation, best_inliers);
142 result.converged = true;
143
144 LOG_INFO_S << "PROSAC配准成功 / PROSAC registration successful: "
145 << best_inlier_count << " inliers in " << t << " iterations";
146 } else {
147 LOG_WARN_S << "警告:PROSAC未找到足够的内点 / Warning: PROSAC did not find enough inliers: "
148 << best_inlier_count << " < " << min_inliers;
149 }
150
151 return result.converged;
152}
153
154template<typename DataType>
156{
157 const auto& correspondences = this->get_correspondences();
158
159 if (!correspondences || correspondences->empty()) {
160 LOG_ERROR_S << "错误:对应关系为空 / Error: Correspondences are empty";
161 return false;
162 }
163
164 if (m_sorted_indices.empty()) {
165 LOG_WARN_S << "警告:未提供排序索引,假设对应关系已排序 / "
166 "Warning: No sorted indices provided, assuming correspondences are sorted";
167 } else if (m_sorted_indices.size() != correspondences->size()) {
168 LOG_ERROR_S << "错误:排序索引大小与对应关系不匹配 / "
169 "Error: Sorted indices size doesn't match correspondences";
170 return false;
171 }
172
173 return true;
174}
175
176template<typename DataType>
178 std::size_t n_correspondences)
179{
180 m_T_n.clear();
181 m_T_n.reserve(n_correspondences);
182
183 // T_m初始值 / Initial value of T_m
184 DataType T_m = static_cast<DataType>(n_correspondences) *
185 std::pow(1.0 - m_initial_inlier_ratio,
186 static_cast<DataType>(m_sample_size));
187
188 // 添加前m个T值(都是1) / Add first m T values (all are 1)
189 for (std::size_t i = 0; i < m_sample_size; ++i) {
190 m_T_n.push_back(1);
191 }
192
193 // 计算T_n for n = m+1 to N / Compute T_n for n = m+1 to N
194 for (std::size_t n = m_sample_size + 1; n <= n_correspondences; ++n) {
195 // T_n = T_{n-1} + ceil(T_m * (n - m) / (m * C(n, m)))
196 std::size_t T_n_minus_1 = m_T_n.back();
197
198 // 避免整数溢出,使用对数计算 / Avoid integer overflow, use logarithmic computation
199 DataType log_numerator = std::log(static_cast<DataType>(n - m_sample_size)) +
200 std::log(T_m);
201 DataType log_denominator = std::log(static_cast<DataType>(m_sample_size));
202
203 // 计算组合数的对数 / Compute logarithm of binomial coefficient
204 for (std::size_t i = 0; i < m_sample_size; ++i) {
205 log_denominator += std::log(static_cast<DataType>(n - i)) -
206 std::log(static_cast<DataType>(m_sample_size - i));
207 }
208
209 DataType increment = std::exp(log_numerator - log_denominator);
210 std::size_t T_n = T_n_minus_1 + static_cast<std::size_t>(std::ceil(increment));
211
212 m_T_n.push_back(T_n);
213 }
214}
215
216template<typename DataType>
217void prosac_registration_t<DataType>::progressive_sample(
218 std::vector<correspondence_t>& sample, std::size_t n, std::size_t t,
219 std::mt19937& generator) const
220{
221 const auto& correspondences = this->get_correspondences();
222 sample.clear();
223
224 if (t >= m_T_n[n - 1]) {
225 // PROSAC采样:选择第n个对应关系和前n-1个中的m-1个 /
226 // PROSAC sampling: select nth correspondence and m-1 from first n-1
227
228 // 添加第n个对应关系 / Add nth correspondence
229 if (!m_sorted_indices.empty()) {
230 sample.push_back((*correspondences)[m_sorted_indices[n - 1]]);
231 } else {
232 sample.push_back((*correspondences)[n - 1]);
233 }
234
235 // 从前n-1个中随机选择m-1个 / Randomly select m-1 from first n-1
236 std::vector<std::size_t> indices;
237 indices.reserve(n - 1);
238 for (std::size_t i = 0; i < n - 1; ++i) {
239 indices.push_back(i);
240 }
241
242 std::shuffle(indices.begin(), indices.end(), generator);
243
244 for (std::size_t i = 0; i < m_sample_size - 1; ++i) {
245 if (!m_sorted_indices.empty()) {
246 sample.push_back((*correspondences)[m_sorted_indices[indices[i]]]);
247 } else {
248 sample.push_back((*correspondences)[indices[i]]);
249 }
250 }
251 } else {
252 // 标准RANSAC采样:从前n个中随机选择m个 /
253 // Standard RANSAC sampling: randomly select m from first n
254 std::vector<std::size_t> indices;
255 indices.reserve(n);
256 for (std::size_t i = 0; i < n; ++i) {
257 indices.push_back(i);
258 }
259
260 std::shuffle(indices.begin(), indices.end(), generator);
261
262 for (std::size_t i = 0; i < m_sample_size; ++i) {
263 if (!m_sorted_indices.empty()) {
264 sample.push_back((*correspondences)[m_sorted_indices[indices[i]]]);
265 } else {
266 sample.push_back((*correspondences)[indices[i]]);
267 }
268 }
269 }
270}
271
272template<typename DataType>
274prosac_registration_t<DataType>::estimate_transformation(
275 const std::vector<correspondence_t>& sample) const
276{
277 const auto& source_cloud = this->get_source_cloud();
278 const auto& target_cloud = this->get_target_cloud();
279
280 // 提取样本点 / Extract sample points
281 Eigen::Matrix<DataType, 3, Eigen::Dynamic> src_points(3, sample.size());
282 Eigen::Matrix<DataType, 3, Eigen::Dynamic> tgt_points(3, sample.size());
283
284 for (std::size_t i = 0; i < sample.size(); ++i) {
285 const auto& src_pt = source_cloud->points[sample[i].src_idx];
286 const auto& tgt_pt = target_cloud->points[sample[i].dst_idx];
287
288 src_points.col(i) = vector3_t(src_pt.x, src_pt.y, src_pt.z);
289 tgt_points.col(i) = vector3_t(tgt_pt.x, tgt_pt.y, tgt_pt.z);
290 }
291
292 // 计算质心 / Compute centroids
293 vector3_t src_centroid = src_points.rowwise().mean();
294 vector3_t tgt_centroid = tgt_points.rowwise().mean();
295
296 // 中心化点云 / Center point clouds
297 Eigen::Matrix<DataType, 3, Eigen::Dynamic> src_centered =
298 src_points.colwise() - src_centroid;
299 Eigen::Matrix<DataType, 3, Eigen::Dynamic> tgt_centered =
300 tgt_points.colwise() - tgt_centroid;
301
302 // 计算协方差矩阵 / Compute covariance matrix
303 matrix3_t H = src_centered * tgt_centered.transpose();
304
305 // SVD分解 / SVD decomposition
306 Eigen::JacobiSVD<matrix3_t> svd(H, Eigen::ComputeFullU | Eigen::ComputeFullV);
307 matrix3_t U = svd.matrixU();
308 matrix3_t V = svd.matrixV();
309
310 // 计算旋转矩阵 / Compute rotation matrix
311 matrix3_t R = V * U.transpose();
312
313 // 处理反射情况 / Handle reflection case
314 if (R.determinant() < 0) {
315 V.col(2) *= -1;
316 R = V * U.transpose();
317 }
318
319 // 计算平移向量 / Compute translation vector
320 vector3_t t = tgt_centroid - R * src_centroid;
321
322 // 构建变换矩阵 / Build transformation matrix
323 transformation_t transform = transformation_t::Identity();
324 transform.template block<3, 3>(0, 0) = R;
325 transform.template block<3, 1>(0, 3) = t;
326
327 return transform;
328}
329
330template<typename DataType>
331std::size_t prosac_registration_t<DataType>::count_inliers(
332 const transformation_t& transform, std::vector<std::size_t>& inliers) const
333{
334 const auto& source_cloud = this->get_source_cloud();
335 const auto& target_cloud = this->get_target_cloud();
336 const auto& correspondences = this->get_correspondences();
337 const auto inlier_threshold = this->get_inlier_threshold();
338
339 inliers.clear();
340 inliers.reserve(correspondences->size());
341
342 // 对每个对应关系检查是否为内点 / Check each correspondence for inlier status
343 for (std::size_t i = 0; i < correspondences->size(); ++i) {
344 const auto& corr = (*correspondences)[i];
345 const auto& src_pt = source_cloud->points[corr.src_idx];
346 const auto& tgt_pt = target_cloud->points[corr.dst_idx];
347
348 // 变换源点 / Transform source point
349 vector3_t src_vec(src_pt.x, src_pt.y, src_pt.z);
350 vector3_t transformed = transform.template block<3, 3>(0, 0) * src_vec +
351 transform.template block<3, 1>(0, 3);
352
353 // 计算距离 / Compute distance
354 DataType dist = std::sqrt((transformed[0] - tgt_pt.x) * (transformed[0] - tgt_pt.x) +
355 (transformed[1] - tgt_pt.y) * (transformed[1] - tgt_pt.y) +
356 (transformed[2] - tgt_pt.z) * (transformed[2] - tgt_pt.z));
357
358 if (dist <= inlier_threshold) {
359 inliers.push_back(i);
360 }
361 }
362
363 return inliers.size();
364}
365
366template<typename DataType>
367bool prosac_registration_t<DataType>::check_non_randomness(
368 std::size_t inlier_count, std::size_t n) const
369{
370 // 计算观察到这么多内点的概率 / Compute probability of observing this many inliers
371 DataType p_good = 1.0;
372
373 for (std::size_t j = m_sample_size; j <= inlier_count; ++j) {
374 DataType beta = compute_beta(j, m_sample_size, n);
375 p_good *= (1.0 - beta);
376 }
377
378 p_good = 1.0 - p_good;
379
380 return p_good < m_non_randomness_threshold;
381}
382
383template<typename DataType>
384bool prosac_registration_t<DataType>::check_maximality(
385 std::size_t inlier_count, std::size_t n, std::size_t t) const
386{
387 // 计算找到更好模型所需的期望迭代次数 /
388 // Compute expected iterations to find better model
389 DataType inlier_ratio = static_cast<DataType>(inlier_count) /
390 static_cast<DataType>(n);
391
392 if (inlier_ratio <= 0) {
393 return false;
394 }
395
396 DataType p_better = std::pow(inlier_ratio, static_cast<DataType>(m_sample_size));
397
398 if (p_better <= 0) {
399 return true; // 不可能找到更好的模型 / Impossible to find better model
400 }
401
402 DataType k_max = std::log(1.0 - m_confidence) / std::log(1.0 - p_better);
403
404 return static_cast<DataType>(t) >= k_max;
405}
406
407template<typename DataType>
409prosac_registration_t<DataType>::refine_transformation(
410 const std::vector<std::size_t>& inlier_indices) const
411{
412 const auto& source_cloud = this->get_source_cloud();
413 const auto& target_cloud = this->get_target_cloud();
414 const auto& correspondences = this->get_correspondences();
415
416 // 提取所有内点 / Extract all inlier points
417 Eigen::Matrix<DataType, 3, Eigen::Dynamic> src_points(3, inlier_indices.size());
418 Eigen::Matrix<DataType, 3, Eigen::Dynamic> tgt_points(3, inlier_indices.size());
419
420 for (std::size_t i = 0; i < inlier_indices.size(); ++i) {
421 const auto& corr = (*correspondences)[inlier_indices[i]];
422 const auto& src_pt = source_cloud->points[corr.src_idx];
423 const auto& tgt_pt = target_cloud->points[corr.dst_idx];
424
425 src_points.col(i) = vector3_t(src_pt.x, src_pt.y, src_pt.z);
426 tgt_points.col(i) = vector3_t(tgt_pt.x, tgt_pt.y, tgt_pt.z);
427 }
428
429 // 使用SVD计算最优变换(与estimate_transformation相同的方法) /
430 // Compute optimal transformation using SVD (same method as estimate_transformation)
431 vector3_t src_centroid = src_points.rowwise().mean();
432 vector3_t tgt_centroid = tgt_points.rowwise().mean();
433
434 Eigen::Matrix<DataType, 3, Eigen::Dynamic> src_centered =
435 src_points.colwise() - src_centroid;
436 Eigen::Matrix<DataType, 3, Eigen::Dynamic> tgt_centered =
437 tgt_points.colwise() - tgt_centroid;
438
439 matrix3_t H = src_centered * tgt_centered.transpose();
440
441 Eigen::JacobiSVD<matrix3_t> svd(H, Eigen::ComputeFullU | Eigen::ComputeFullV);
442 matrix3_t U = svd.matrixU();
443 matrix3_t V = svd.matrixV();
444
445 matrix3_t R = V * U.transpose();
446
447 if (R.determinant() < 0) {
448 V.col(2) *= -1;
449 R = V * U.transpose();
450 }
451
452 vector3_t t = tgt_centroid - R * src_centroid;
453
454 transformation_t transform = transformation_t::Identity();
455 transform.template block<3, 3>(0, 0) = R;
456 transform.template block<3, 1>(0, 3) = t;
457
458 return transform;
459}
460
461template<typename DataType>
462std::size_t prosac_registration_t<DataType>::compute_binomial_coefficient(
463 std::size_t n, std::size_t k) const
464{
465 if (k > n) return 0;
466 if (k == 0 || k == n) return 1;
467
468 // 使用Pascal三角形的性质优化计算 / Optimize using Pascal's triangle property
469 k = std::min(k, n - k);
470
471 std::size_t result = 1;
472 for (std::size_t i = 0; i < k; ++i) {
473 result = result * (n - i) / (i + 1);
474 }
475
476 return result;
477}
478
479template<typename DataType>
480DataType prosac_registration_t<DataType>::compute_beta(
481 std::size_t i, std::size_t m, std::size_t n) const
482{
483 if (i < m) return 0;
484 if (i > n) return 0;
485
486 // beta(i, m, n) = C(i-1, m-1) * C(n-i, 1) / C(n, m)
487 // = i * C(i-1, m-1) / C(n, m)
488
489 // 使用对数避免溢出 / Use logarithm to avoid overflow
490 DataType log_beta = std::log(static_cast<DataType>(i));
491
492 // log(C(i-1, m-1))
493 for (std::size_t j = 0; j < m - 1; ++j) {
494 log_beta += std::log(static_cast<DataType>(i - 1 - j)) -
495 std::log(static_cast<DataType>(j + 1));
496 }
497
498 // log(C(n, m))
499 for (std::size_t j = 0; j < m; ++j) {
500 log_beta -= std::log(static_cast<DataType>(n - j)) -
501 std::log(static_cast<DataType>(j + 1));
502 }
503
504 return std::exp(log_beta);
505}
506
507template<typename DataType>
508bool prosac_registration_t<DataType>::is_sample_valid(
509 const std::vector<correspondence_t>& sample) const
510{
511 if (sample.size() < 3) {
512 return false;
513 }
514
515 const auto& source_cloud = this->get_source_cloud();
516
517 // 检查是否有重复的对应关系 / Check for duplicate correspondences
518 std::set<std::size_t> src_indices, dst_indices;
519 for (const auto& corr : sample) {
520 if (!src_indices.insert(corr.src_idx).second ||
521 !dst_indices.insert(corr.dst_idx).second) {
522 return false;
523 }
524 }
525
526 // 检查源点是否共线 / Check if source points are collinear
527 const auto& p1 = source_cloud->points[sample[0].src_idx];
528 const auto& p2 = source_cloud->points[sample[1].src_idx];
529 const auto& p3 = source_cloud->points[sample[2].src_idx];
530
531 vector3_t v1(p2.x - p1.x, p2.y - p1.y, p2.z - p1.z);
532 vector3_t v2(p3.x - p1.x, p3.y - p1.y, p3.z - p1.z);
533
534 vector3_t cross = v1.cross(v2);
535 DataType cross_norm = cross.norm();
536
537 const DataType collinear_threshold = static_cast<DataType>(1e-6);
538 return cross_norm > collinear_threshold;
539}
540
541template<typename DataType>
542DataType prosac_registration_t<DataType>::compute_fitness_score(
543 const transformation_t& transform,
544 const std::vector<std::size_t>& inliers) const
545{
546 if (inliers.empty()) {
547 return std::numeric_limits<DataType>::max();
548 }
549
550 const auto& source_cloud = this->get_source_cloud();
551 const auto& target_cloud = this->get_target_cloud();
552 const auto& correspondences = this->get_correspondences();
553
554 DataType total_distance = 0;
555
556 // 计算所有内点的平均距离 / Compute average distance of all inliers
557 for (std::size_t idx : inliers) {
558 const auto& corr = (*correspondences)[idx];
559 const auto& src_pt = source_cloud->points[corr.src_idx];
560 const auto& tgt_pt = target_cloud->points[corr.dst_idx];
561
562 // 变换源点 / Transform source point
563 vector3_t src_vec(src_pt.x, src_pt.y, src_pt.z);
564 vector3_t transformed = transform.template block<3, 3>(0, 0) * src_vec +
565 transform.template block<3, 1>(0, 3);
566
567 // 计算距离 / Compute distance
568 DataType dist = std::sqrt((transformed[0] - tgt_pt.x) * (transformed[0] - tgt_pt.x) +
569 (transformed[1] - tgt_pt.y) * (transformed[1] - tgt_pt.y) +
570 (transformed[2] - tgt_pt.z) * (transformed[2] - tgt_pt.z));
571
572 total_distance += dist;
573 }
574
575 return total_distance / static_cast<DataType>(inliers.size());
576}
577
578} // namespace toolbox::pcl
PROSAC (渐进式采样一致性) 粗配准算法 / PROSAC (Progressive Sample Consensus) coarse registration algorithm.
Definition prosac_registration.hpp:60
bool validate_input_impl() const
额外的输入验证 / Additional input validation
Definition prosac_registration_impl.hpp:155
bool align_impl(result_type &result)
派生类实现的配准算法 / Registration algorithm implementation by derived class
Definition prosac_registration_impl.hpp:9
Eigen::Matrix< DataType, 4, 4 > transformation_t
Definition prosac_registration.hpp:72
#define LOG_INFO_S
INFO级别流式日志的宏 / Macro for INFO level stream logging.
Definition thread_logger.hpp:1330
#define LOG_ERROR_S
ERROR级别流式日志的宏 / Macro for ERROR level stream logging.
Definition thread_logger.hpp:1332
#define LOG_WARN_S
WARN级别流式日志的宏 / Macro for WARN level stream logging.
Definition thread_logger.hpp:1331
Definition base_correspondence_generator.hpp:18
std::vector< T > sample(const std::vector< T > &population, size_t k)
从vector中随机采样k个元素/Randomly sample k elements from vector
Definition random.hpp:518