cpp-toolbox  0.0.1
A toolbox library for C++
Loading...
Searching...
No Matches
custom_metric.hpp
Go to the documentation of this file.
1#pragma once
2
#include <cmath>
#include <cstddef>
#include <functional>
#include <memory>
#include <stdexcept>
#include <type_traits>
#include <utility>
#include <vector>
5
7
8namespace toolbox::metrics
9{
10
11// Function-based custom metric
12template<typename T>
13class CustomMetric : public base_metric_t<CustomMetric<T>, T>
14{
15public:
16 using element_type = T;
17 using distance_function = std::function<T(const T*, const T*, std::size_t)>;
18
19 explicit CustomMetric(distance_function dist_func)
20 : dist_func_(std::move(dist_func))
21 {
22 if (!dist_func_) {
23 throw std::invalid_argument("Distance function cannot be null");
24 }
25 }
26
27 constexpr T distance_impl(const T* a, const T* b, std::size_t size) const
28 {
29 return dist_func_(a, b, size);
30 }
31
32 constexpr T squared_distance_impl(const T* a, const T* b, std::size_t size) const
33 {
34 T dist = distance_impl(a, b, size);
35 return dist * dist;
36 }
37
38private:
39 distance_function dist_func_;
40};
41
42// Lambda-based custom metric with state
43template<typename T, typename Lambda>
44class LambdaMetric : public base_metric_t<LambdaMetric<T, Lambda>, T>
45{
46public:
47 using element_type = T;
48
49 explicit LambdaMetric(Lambda lambda) : lambda_(std::move(lambda)) {}
50
51 constexpr T distance_impl(const T* a, const T* b, std::size_t size) const
52 {
53 return lambda_(a, b, size);
54 }
55
56 constexpr T squared_distance_impl(const T* a, const T* b, std::size_t size) const
57 {
58 T dist = distance_impl(a, b, size);
59 return dist * dist;
60 }
61
62private:
63 Lambda lambda_;
64};
65
66// Factory function for creating lambda metrics
67template<typename T, typename Lambda>
68auto make_lambda_metric(Lambda&& lambda)
69{
70 return LambdaMetric<T, std::decay_t<Lambda>>(std::forward<Lambda>(lambda));
71}
72
73// Weighted metric wrapper
74template<typename BaseMetric>
75class WeightedMetric : public base_metric_t<WeightedMetric<BaseMetric>,
76 typename BaseMetric::element_type>
77{
78public:
79 using base_metric_type = BaseMetric;
80 using element_type = typename BaseMetric::element_type;
81
82 WeightedMetric(BaseMetric base_metric, const std::vector<element_type>& weights)
83 : base_metric_(std::move(base_metric)), weights_(weights) {}
84
86 const element_type* b,
87 std::size_t size) const
88 {
89 if (weights_.size() != size) {
90 throw std::invalid_argument("Weight vector size mismatch");
91 }
92
93 // Create weighted versions of the input vectors
94 std::vector<element_type> weighted_a(size);
95 std::vector<element_type> weighted_b(size);
96
97 for (std::size_t i = 0; i < size; ++i) {
98 element_type w = std::sqrt(weights_[i]);
99 weighted_a[i] = a[i] * w;
100 weighted_b[i] = b[i] * w;
101 }
102
103 return base_metric_.distance(weighted_a.data(), weighted_b.data(), size);
104 }
105
107 const element_type* b,
108 std::size_t size) const
109 {
110 element_type dist = distance_impl(a, b, size);
111 return dist * dist;
112 }
113
114private:
115 BaseMetric base_metric_;
116 std::vector<element_type> weights_;
117};
118
119// Factory function for creating weighted metrics
120template<typename BaseMetric>
121auto make_weighted_metric(BaseMetric&& metric,
122 const std::vector<typename std::decay_t<BaseMetric>::element_type>& weights)
123{
125 std::forward<BaseMetric>(metric), weights);
126}
127
128// Minkowski-like custom metric with parameter
// Holds one tunable parameter (parameter_, default T(2)) for Minkowski-like
// metrics. This class itself declares no visible distance_impl; the comment
// further down expects subclasses to provide it.
// NOTE(review): the CRTP argument passed to base_metric_t is
// ParameterizedMetric<T> itself, not a subclass. If base_metric_t's first
// argument is the CRTP "derived" type (as every other metric in this file
// suggests), a subclass's distance_impl is presumably never reached through
// base_metric_t — confirm against base_metric.hpp.
129template<typename T>
130class ParameterizedMetric : public base_metric_t<ParameterizedMetric<T>, T>
131{
132public:
133 using element_type = T;
134
136
// Stores param as the metric parameter; no validation is performed.
137 void set_parameter(T param) { parameter_ = param; }
// Returns the current metric parameter.
138 T get_parameter() const { return parameter_; }
139
140 // Note: This is an abstract class, derived classes must implement distance_impl
141 // We provide squared_distance_impl here for convenience
// NOTE(review): the class declares no pure virtual members, so it is not
// abstract in the C++ sense. The static_cast below converts `this` to this
// class's own type, so it cannot dispatch to a distance_impl declared in a
// subclass; lookup happens only in ParameterizedMetric<T> and its bases.
// Returns the square of distance_impl(a, b, size).
142 constexpr T squared_distance_impl(const T* a, const T* b, std::size_t size) const
143 {
144 T dist = static_cast<const ParameterizedMetric*>(this)->distance_impl(a, b, size);
145 return dist * dist;
146 }
147
148protected:
149 T parameter_ = T(2); // Default to L2-like behavior
150};
151
152// Composite metric that combines multiple metrics
153template<typename T>
154class CompositeMetric : public base_metric_t<CompositeMetric<T>, T>
155{
156public:
157 using element_type = T;
158 using metric_ptr = std::shared_ptr<base_metric_t<void, T>>;
159
160 void add_metric(metric_ptr metric, T weight = T(1))
161 {
162 metrics_.emplace_back(std::move(metric), weight);
163 }
164
165 constexpr T distance_impl(const T* a, const T* b, std::size_t size) const
166 {
167 T total_distance {};
168 T total_weight {};
169
170 for (const auto& [metric, weight] : metrics_) {
171 total_distance += weight * metric->distance(a, b, size);
172 total_weight += weight;
173 }
174
175 if (total_weight > T(0)) {
176 return total_distance / total_weight;
177 }
178 return T(0);
179 }
180
181 constexpr T squared_distance_impl(const T* a, const T* b, std::size_t size) const
182 {
183 T dist = distance_impl(a, b, size);
184 return dist * dist;
185 }
186
187private:
188 std::vector<std::pair<metric_ptr, T>> metrics_;
189};
190
191// Mahalanobis-like metric (simplified version without full covariance matrix)
192template<typename T>
193class ScaledMetric : public base_metric_t<ScaledMetric<T>, T>
194{
195public:
196 using element_type = T;
197
198 explicit ScaledMetric(const std::vector<T>& scales)
199 : scales_(scales) {}
200
201 constexpr T distance_impl(const T* a, const T* b, std::size_t size) const
202 {
203 if (scales_.size() != size) {
204 throw std::invalid_argument("Scale vector size mismatch");
205 }
206
207 T sum {};
208 for (std::size_t i = 0; i < size; ++i) {
209 T diff = (a[i] - b[i]) / scales_[i];
210 sum += diff * diff;
211 }
212 return std::sqrt(sum);
213 }
214
215 constexpr T squared_distance_impl(const T* a, const T* b, std::size_t size) const
216 {
217 if (scales_.size() != size) {
218 throw std::invalid_argument("Scale vector size mismatch");
219 }
220
221 T sum {};
222 for (std::size_t i = 0; i < size; ++i) {
223 T diff = (a[i] - b[i]) / scales_[i];
224 sum += diff * diff;
225 }
226 return sum; // Return squared distance directly (no sqrt)
227 }
228
229private:
230 std::vector<T> scales_;
231};
232
233} // namespace toolbox::metrics
Definition custom_metric.hpp:155
std::shared_ptr< base_metric_t< void, T > > metric_ptr
Definition custom_metric.hpp:158
constexpr T distance_impl(const T *a, const T *b, std::size_t size) const
Definition custom_metric.hpp:165
constexpr T squared_distance_impl(const T *a, const T *b, std::size_t size) const
Definition custom_metric.hpp:181
T element_type
Definition custom_metric.hpp:157
void add_metric(metric_ptr metric, T weight=T(1))
Definition custom_metric.hpp:160
Definition custom_metric.hpp:14
T element_type
Definition custom_metric.hpp:16
CustomMetric(distance_function dist_func)
Definition custom_metric.hpp:19
constexpr T distance_impl(const T *a, const T *b, std::size_t size) const
Definition custom_metric.hpp:27
constexpr T squared_distance_impl(const T *a, const T *b, std::size_t size) const
Definition custom_metric.hpp:32
std::function< T(const T *, const T *, std::size_t)> distance_function
Definition custom_metric.hpp:17
Definition custom_metric.hpp:45
constexpr T squared_distance_impl(const T *a, const T *b, std::size_t size) const
Definition custom_metric.hpp:56
LambdaMetric(Lambda lambda)
Definition custom_metric.hpp:49
constexpr T distance_impl(const T *a, const T *b, std::size_t size) const
Definition custom_metric.hpp:51
T element_type
Definition custom_metric.hpp:47
Definition custom_metric.hpp:131
constexpr T squared_distance_impl(const T *a, const T *b, std::size_t size) const
Definition custom_metric.hpp:142
T get_parameter() const
Definition custom_metric.hpp:138
T element_type
Definition custom_metric.hpp:133
T parameter_
Definition custom_metric.hpp:149
void set_parameter(T param)
Definition custom_metric.hpp:137
Definition custom_metric.hpp:194
T element_type
Definition custom_metric.hpp:196
ScaledMetric(const std::vector< T > &scales)
Definition custom_metric.hpp:198
constexpr T distance_impl(const T *a, const T *b, std::size_t size) const
Definition custom_metric.hpp:201
constexpr T squared_distance_impl(const T *a, const T *b, std::size_t size) const
Definition custom_metric.hpp:215
Definition custom_metric.hpp:77
BaseMetric base_metric_type
Definition custom_metric.hpp:79
typename BaseMetric::element_type element_type
Definition custom_metric.hpp:80
constexpr element_type squared_distance_impl(const element_type *a, const element_type *b, std::size_t size) const
Definition custom_metric.hpp:106
constexpr element_type distance_impl(const element_type *a, const element_type *b, std::size_t size) const
Definition custom_metric.hpp:85
WeightedMetric(BaseMetric base_metric, const std::vector< element_type > &weights)
Definition custom_metric.hpp:82
Definition base_metric.hpp:13
Definition angular_metrics.hpp:11
auto make_weighted_metric(BaseMetric &&metric, const std::vector< typename std::decay_t< BaseMetric >::element_type > &weights)
Definition custom_metric.hpp:121
auto make_lambda_metric(Lambda &&lambda)
Definition custom_metric.hpp:68