cpp-toolbox  0.0.1
A toolbox library for C++
Loading...
Searching...
No Matches
sampler.hpp
Go to the documentation of this file.
1#pragma once
2
3#include <algorithm>
4#include <cstddef>
5#include <numeric>
6#include <random>
7#include <vector>
8
9#include <cpp-toolbox/cpp-toolbox_export.hpp>
10
11namespace toolbox::io
12{
13
/**
 * @brief Sequential sampling policy.
 *
 * Policy object for sampler_t that leaves the index sequence exactly as it
 * was handed in, so indices are visited in their existing order.
 *
 * NOTE(review): the declaration line was reconstructed; confirm against the
 * upstream header whether it should also carry CPP_TOOLBOX_EXPORT.
 */
struct sequential_policy_t
{
  /**
   * @brief No-op call operator: the index list is not modified.
   *
   * The parameter is intentionally unnamed because it is never read.
   */
  void operator()(std::vector<std::size_t>& /*indices*/) const noexcept {}
};
32
/**
 * @brief Shuffle sampling policy.
 *
 * Policy object for sampler_t that randomly permutes the index sequence with
 * std::shuffle, driven by an internally owned std::mt19937 engine.
 *
 * NOTE(review): the declaration line was reconstructed; confirm against the
 * upstream header whether it should also carry CPP_TOOLBOX_EXPORT.
 */
class shuffle_policy_t
{
public:
  /**
   * @brief Construct the policy and seed its random engine.
   * @param seed Seed for the engine; when omitted, a value is drawn from
   *        std::random_device at the call site.
   */
  explicit shuffle_policy_t(unsigned seed = std::random_device {}())
      : m_rng(seed)
  {
  }

  /**
   * @brief Shuffle the indices in place.
   * @param indices Index list to permute; advances the engine state, which is
   *        why this operator is not const.
   */
  void operator()(std::vector<std::size_t>& indices)
  {
    std::shuffle(indices.begin(), indices.end(), m_rng);
  }

  /**
   * @brief Re-seed the random engine.
   * @param seed New seed value.
   */
  void set_seed(unsigned seed) { m_rng.seed(seed); }

private:
  /// Random engine used for shuffling.
  std::mt19937 m_rng;
};
76
94template<typename PolicyT = sequential_policy_t>
95class CPP_TOOLBOX_EXPORT sampler_t
96{
97public:
101 using index_type = std::size_t;
102
107 sampler_t(const sampler_t& other)
108 : m_dataset_size(other.m_dataset_size)
109 , m_policy(other.m_policy)
110 , m_indices(other.m_indices)
111 {
112 // 计算迭代器偏移/Calculate iterator offset
113 auto offset = other.m_iter - other.m_indices.begin();
114 m_iter = m_indices.begin() + offset;
115 }
116
123 {
124 if (this != &other) {
125 m_dataset_size = other.m_dataset_size;
126 m_policy = other.m_policy;
127 m_indices = other.m_indices;
128 // 计算迭代器偏移/Calculate iterator offset
129 auto offset = other.m_iter - other.m_indices.begin();
130 m_iter = m_indices.begin() + offset;
131 }
132 return *this;
133 }
134
140 explicit sampler_t(std::size_t dataset_size, PolicyT policy = PolicyT {})
141 : m_dataset_size(dataset_size)
142 , m_policy(std::move(policy))
143 {
144 m_indices.resize(m_dataset_size);
145 std::iota(m_indices.begin(), m_indices.end(), index_type {0});
146 reset();
147 }
148
158 void reset()
159 {
160 std::iota(m_indices.begin(), m_indices.end(), index_type {0});
161 m_policy(m_indices);
162 m_iter = m_indices.begin();
163 }
164
174 [[nodiscard]] bool has_next() const noexcept
175 {
176 return m_iter != m_indices.end();
177 }
178
188 index_type next() { return *m_iter++; }
189
198 std::vector<index_type> next_batch(std::size_t batch_size)
199 {
200 std::vector<index_type> batch;
201 batch.reserve(batch_size);
202 for (std::size_t i = 0; i < batch_size && has_next(); ++i) {
203 batch.push_back(next());
204 }
205 return batch;
206 }
207
208private:
212 std::size_t m_dataset_size;
216 PolicyT m_policy;
220 std::vector<index_type> m_indices;
224 typename std::vector<index_type>::iterator m_iter;
225};
226
227} // namespace toolbox::io
通用采样器/Generic sampler
Definition sampler.hpp:96
bool has_next() const noexcept
是否还有下一个索引/Whether there are more indices
Definition sampler.hpp:174
std::vector< index_type > next_batch(std::size_t batch_size)
获取批量索引/Get a batch of indices
Definition sampler.hpp:198
void reset()
重置索引并重新应用策略/Reset indices and reapply policy
Definition sampler.hpp:158
sampler_t(const sampler_t &other)
拷贝构造函数/Copy constructor
Definition sampler.hpp:107
index_type next()
获取下一个索引/Get the next index
Definition sampler.hpp:188
sampler_t & operator=(const sampler_t &other)
拷贝赋值操作符/Copy assignment operator
Definition sampler.hpp:122
std::size_t index_type
索引类型/Index type
Definition sampler.hpp:101
sampler_t(std::size_t dataset_size, PolicyT policy=PolicyT {})
构造函数/Constructor
Definition sampler.hpp:140
随机打乱采样策略/Shuffle sampling policy
Definition sampler.hpp:44
void operator()(std::vector< std::size_t > &indices)
打乱索引/Shuffle indices
Definition sampler.hpp:59
void set_seed(unsigned seed)
设置随机种子/Set random seed
Definition sampler.hpp:68
shuffle_policy_t(unsigned seed=std::random_device {}())
构造函数/Constructor
Definition sampler.hpp:50
用于列出目录下的文件/For listing files in a directory
Definition dataloader.hpp:15
顺序采样策略/Sequential sampling policy
Definition sampler.hpp:25
void operator()(std::vector< std::size_t > &) const noexcept
顺序采样操作符/Sequential sampling operator
Definition sampler.hpp:30