cpp-toolbox  0.0.1
A toolbox library for C++
Loading...
Searching...
No Matches
metric_factory.hpp
Go to the documentation of this file.
1#pragma once
2
3#include <memory>
4#include <string>
5#include <unordered_map>
6#include <functional>
7#include <stdexcept>
8#include <vector>
9
16
17namespace toolbox::metrics
18{
19
/**
 * @brief Abstract metric interface used for runtime polymorphism.
 *
 * Implementations provide the two pointer-based virtual functions; the
 * container overloads below are conveniences that forward to them.
 *
 * @tparam T Element type of the inputs, also used as the result type.
 */
template<typename T>
class IMetric
{
public:
  using element_type = T;
  using result_type = T;

  virtual ~IMetric() = default;

  /// Distance between two arrays that each hold @p size elements.
  virtual T distance(const T* a, const T* b, std::size_t size) const = 0;

  /// Squared-distance counterpart of distance() with the same arguments.
  virtual T squared_distance(const T* a,
                             const T* b,
                             std::size_t size) const = 0;

  /**
   * @brief Container convenience overload of distance().
   * @param a First input; must expose data() and size().
   * @param b Second input; must expose data() and size().
   * @return Result of the pointer-based distance() over a.size() elements.
   * @throws std::invalid_argument if a.size() != b.size().
   */
  template<typename Container>
  T distance(const Container& a, const Container& b) const
  {
    require_equal_sizes(a, b);
    return distance(a.data(), b.data(), a.size());
  }

  /**
   * @brief Container convenience overload of squared_distance().
   * @param a First input; must expose data() and size().
   * @param b Second input; must expose data() and size().
   * @return Result of the pointer-based squared_distance() over a.size()
   *         elements.
   * @throws std::invalid_argument if a.size() != b.size().
   */
  template<typename Container>
  T squared_distance(const Container& a, const Container& b) const
  {
    require_equal_sizes(a, b);
    return squared_distance(a.data(), b.data(), a.size());
  }

private:
  // Only a.size() is forwarded as the element count, so a shorter b would be
  // read past its end; reject mismatched inputs before touching the data.
  template<typename Container>
  static void require_equal_sizes(const Container& a, const Container& b)
  {
    if (a.size() != b.size()) {
      throw std::invalid_argument("Metric inputs must have the same size");
    }
  }
};
46
47// Concrete wrapper for CRTP metrics
48template<typename MetricType>
49class MetricWrapper : public IMetric<typename MetricType::element_type>
50{
51public:
52 using T = typename MetricType::element_type;
53
54 explicit MetricWrapper(MetricType metric) : metric_(std::move(metric)) {}
55
56 T distance(const T* a, const T* b, std::size_t size) const override
57 {
58 return metric_.distance(a, b, size);
59 }
60
61 T squared_distance(const T* a, const T* b, std::size_t size) const override
62 {
63 return metric_.squared_distance(a, b, size);
64 }
65
66private:
67 MetricType metric_;
68};
69
70// Metric factory for runtime metric creation
71template<typename T>
73{
74public:
75 using metric_ptr = std::unique_ptr<IMetric<T>>;
76 using creator_func = std::function<metric_ptr()>;
77
79 {
80 static MetricFactory factory;
81 return factory;
82 }
83
84 // Register a metric type
85 template<typename MetricType>
86 void register_metric(const std::string& name)
87 {
88 creators_[name] = []() -> metric_ptr {
89 return std::make_unique<MetricWrapper<MetricType>>(MetricType{});
90 };
91 }
92
93 // Register a metric with parameters
94 void register_metric(const std::string& name, creator_func creator)
95 {
96 creators_[name] = std::move(creator);
97 }
98
99 // Create a metric by name
100 metric_ptr create(const std::string& name) const
101 {
102 auto it = creators_.find(name);
103 if (it != creators_.end()) {
104 return it->second();
105 }
106 throw std::invalid_argument("Unknown metric: " + name);
107 }
108
109 // Get list of registered metrics
110 std::vector<std::string> available_metrics() const
111 {
112 std::vector<std::string> names;
113 names.reserve(creators_.size());
114 for (const auto& [name, _] : creators_) {
115 names.push_back(name);
116 }
117 return names;
118 }
119
120 // Clear all registrations
121 void clear()
122 {
123 creators_.clear();
124 }
125
126private:
128 {
129 // Register default metrics
130 register_default_metrics();
131 }
132
133 void register_default_metrics()
134 {
135 // Vector metrics
136 register_metric<L1Metric<T>>("l1");
137 register_metric<L1Metric<T>>("manhattan");
138 register_metric<L2Metric<T>>("l2");
139 register_metric<L2Metric<T>>("euclidean");
140 register_metric<LinfMetric<T>>("linf");
141 register_metric<LinfMetric<T>>("chebyshev");
142
143 // Histogram metrics
144 register_metric<ChiSquaredMetric<T>>("chi_squared");
145 register_metric<HistogramIntersectionMetric<T>>("histogram_intersection");
146 register_metric<BhattacharyyaMetric<T>>("bhattacharyya");
147 register_metric<HellingerMetric<T>>("hellinger");
148 register_metric<EMDMetric<T>>("emd");
149 register_metric<EMDMetric<T>>("wasserstein");
150 register_metric<KLDivergenceMetric<T>>("kl_divergence");
151 register_metric<JensenShannonMetric<T>>("jensen_shannon");
152
153 // Angular metrics
154 register_metric<CosineMetric<T>>("cosine");
155 register_metric<AngularMetric<T>>("angular");
156 register_metric<NormalizedAngularMetric<T>>("normalized_angular");
157 register_metric<CorrelationMetric<T>>("correlation");
158 register_metric<InnerProductMetric<T>>("inner_product");
159
160 // Lp metrics with specific p values
161 register_metric("l3", []() -> metric_ptr {
162 return std::make_unique<MetricWrapper<LpMetric<T, 3>>>(LpMetric<T, 3>{});
163 });
164
165 register_metric("l4", []() -> metric_ptr {
166 return std::make_unique<MetricWrapper<LpMetric<T, 4>>>(LpMetric<T, 4>{});
167 });
168 }
169
170 std::unordered_map<std::string, creator_func> creators_;
171};
172
173// Helper function to create metrics
174template<typename T>
175std::unique_ptr<IMetric<T>> create_metric(const std::string& name)
176{
177 return MetricFactory<T>::instance().create(name);
178}
179
180// Metric registry for plugin architecture
181template<typename T>
183{
184public:
186 {
187 std::string name;
188 std::string description;
192 };
193
195 {
196 static MetricRegistry registry;
197 return registry;
198 }
199
200 template<typename MetricType>
201 void register_metric(const std::string& name, const std::string& description)
202 {
203 MetricInfo info;
204 info.name = name;
205 info.description = description;
209
210 metrics_[name] = info;
211
212 // Also register in factory
213 MetricFactory<T>::instance().template register_metric<MetricType>(name);
214 }
215
216 const MetricInfo* get_info(const std::string& name) const
217 {
218 auto it = metrics_.find(name);
219 return it != metrics_.end() ? &it->second : nullptr;
220 }
221
222 std::vector<MetricInfo> list_metrics() const
223 {
224 std::vector<MetricInfo> result;
225 result.reserve(metrics_.size());
226 for (const auto& [_, info] : metrics_) {
227 result.push_back(info);
228 }
229 return result;
230 }
231
232private:
233 std::unordered_map<std::string, MetricInfo> metrics_;
234};
235
236// Convenience function for creating metrics with automatic type deduction
237template<typename Container>
238auto create_metric_for(const Container& example, const std::string& name)
239{
240 using T = typename Container::value_type;
241 return create_metric<T>(name);
242}
243
244} // namespace toolbox::metrics
Definition metric_factory.hpp:23
T element_type
Definition metric_factory.hpp:25
virtual ~IMetric()=default
T squared_distance(const Container &a, const Container &b) const
Definition metric_factory.hpp:41
T result_type
Definition metric_factory.hpp:26
virtual T distance(const T *a, const T *b, std::size_t size) const =0
virtual T squared_distance(const T *a, const T *b, std::size_t size) const =0
T distance(const Container &a, const Container &b) const
Definition metric_factory.hpp:35
Definition metric_factory.hpp:73
void register_metric(const std::string &name)
Definition metric_factory.hpp:86
metric_ptr create(const std::string &name) const
Definition metric_factory.hpp:100
std::unique_ptr< IMetric< T > > metric_ptr
Definition metric_factory.hpp:75
std::vector< std::string > available_metrics() const
Definition metric_factory.hpp:110
std::function< metric_ptr()> creator_func
Definition metric_factory.hpp:76
static MetricFactory & instance()
Definition metric_factory.hpp:78
void clear()
Definition metric_factory.hpp:121
void register_metric(const std::string &name, creator_func creator)
Definition metric_factory.hpp:94
Definition metric_factory.hpp:183
std::vector< MetricInfo > list_metrics() const
Definition metric_factory.hpp:222
const MetricInfo * get_info(const std::string &name) const
Definition metric_factory.hpp:216
void register_metric(const std::string &name, const std::string &description)
Definition metric_factory.hpp:201
static MetricRegistry & instance()
Definition metric_factory.hpp:194
Definition metric_factory.hpp:50
T distance(const T *a, const T *b, std::size_t size) const override
Definition metric_factory.hpp:56
MetricWrapper(MetricType metric)
Definition metric_factory.hpp:54
typename MetricType::element_type T
Definition metric_factory.hpp:52
T squared_distance(const T *a, const T *b, std::size_t size) const override
Definition metric_factory.hpp:61
Definition angular_metrics.hpp:11
auto create_metric_for(const Container &example, const std::string &name)
Definition metric_factory.hpp:238
std::unique_ptr< IMetric< T > > create_metric(const std::string &name)
Definition metric_factory.hpp:175
Definition metric_factory.hpp:186
bool is_symmetric
Definition metric_factory.hpp:189
bool has_squared_form
Definition metric_factory.hpp:190
std::string name
Definition metric_factory.hpp:187
std::string description
Definition metric_factory.hpp:188
bool requires_positive_values
Definition metric_factory.hpp:191
Definition metric_traits.hpp:25