cpp-toolbox  0.0.1
A toolbox library for C++
Loading...
Searching...
No Matches
dataloader.hpp
Go to the documentation of this file.
1#pragma once
2
3#include <cstddef>
4#include <deque>
5#include <future>
6#include <utility>
7#include <vector>
8
10#include <cpp-toolbox/cpp-toolbox_export.hpp>
11
13
14namespace toolbox::io
15{
16
/**
 * @brief Create a std::future that is already ready and holds @p value.
 *
 * dataloader_t uses this as its synchronous fallback when no thread pool is
 * supplied, so that inline and pooled loads can share one queue of futures.
 *
 * @note Intentionally not `static`: a `static` function in a header has
 *       internal linkage, so each translation unit gets its own copy. External
 *       linkage templates (such as dataloader_t members) that call it would
 *       then refer to a different entity in every TU, which violates the ODR.
 *       A function template can be defined in a header without `static` or
 *       `inline`.
 *
 * @tparam T Type of the stored value.
 * @param value Value to store; it is moved into the future's shared state.
 * @return A valid std::future<T> whose get() returns @p value without blocking.
 */
template<typename T>
std::future<T> make_ready_future(T value)
{
  std::promise<T> promise;
  std::future<T> future = promise.get_future();
  promise.set_value(std::move(value));
  return future;
}
36
53template<typename DatasetT, typename SamplerT>
54class CPP_TOOLBOX_EXPORT dataloader_t
55{
56public:
60 using data_type = typename DatasetT::data_type;
64 using batch_type = std::vector<data_type>;
68 using index_type = typename SamplerT::index_type;
69
86 dataloader_t(DatasetT& dataset,
87 SamplerT sampler,
88 std::size_t batch_size,
89 std::size_t prefetch_batches = 0,
90 toolbox::base::thread_pool_t* pool = nullptr,
91 bool drop_last = false)
92 : m_dataset(dataset)
93 , m_sampler(std::move(sampler))
94 , m_batch_size(batch_size)
95 , m_prefetch_batches(prefetch_batches)
96 , m_pool(pool)
97 , m_drop_last(drop_last)
98 {
99 }
100
107 {
108 public:
113 : m_dataset_ptr(nullptr)
114 , m_sampler(0)
115 , m_batch_size(0)
116 , m_prefetch_batches(0)
117 , m_pool(nullptr)
118 , m_drop_last(false)
119 {
120 }
121
132 iterator(DatasetT* dataset,
133 SamplerT sampler,
134 std::size_t batch_size,
135 std::size_t prefetch_batches,
137 bool drop_last)
138 : m_dataset_ptr(dataset)
139 , m_sampler(std::move(sampler))
140 , m_batch_size(batch_size)
141 , m_prefetch_batches(prefetch_batches)
142 , m_pool(pool)
143 , m_drop_last(drop_last)
144 {
145 for (std::size_t i = 0; i < m_prefetch_batches && m_sampler.has_next();
146 ++i)
147 {
148 enqueue_fetch();
149 }
150 if (m_prefetch_batches == 0 && m_sampler.has_next()) {
151 enqueue_fetch();
152 }
153 ++(*this);
154 }
155
161 {
162 if (!m_dataset_ptr) {
163 return *this;
164 }
165
166 if (m_queue.empty()) {
167 m_dataset_ptr = nullptr;
168 return *this;
169 }
170
171 auto fut = std::move(m_queue.front());
172 m_queue.pop_front();
173 m_current_batch = fut.get();
174
175 if (m_sampler.has_next()) {
176 enqueue_fetch();
177 }
178
179 if (m_current_batch.size() < m_batch_size && m_drop_last) {
180 m_dataset_ptr = nullptr;
181 }
182
183 if (m_queue.empty() && !m_sampler.has_next()) {
184 if (m_current_batch.empty()
185 || (m_drop_last && m_current_batch.size() < m_batch_size))
186 {
187 m_dataset_ptr = nullptr;
188 }
189 }
190 return *this;
191 }
192
197 const batch_type& operator*() const { return m_current_batch; }
198
203 const batch_type* operator->() const { return &m_current_batch; }
204
210 bool operator!=(const iterator& other) const
211 {
212 return m_dataset_ptr != other.m_dataset_ptr;
213 }
214
215 private:
219 void enqueue_fetch()
220 {
221 auto indices = m_sampler.next_batch(m_batch_size);
222 if (indices.empty()) {
223 return;
224 }
225
226 DatasetT* dataset_ptr = m_dataset_ptr;
227 auto task = [dataset_ptr, indices = std::move(indices)]()
228 {
229 batch_type batch;
230 batch.reserve(indices.size());
231 for (auto idx : indices) {
232 auto item = dataset_ptr->get_item(idx);
233 if (item) {
234 batch.push_back(std::move(*item));
235 }
236 }
237 return batch;
238 };
239
240 if (m_pool) {
241 m_queue.push_back(m_pool->submit(task));
242 } else {
243 m_queue.push_back(make_ready_future(task()));
244 }
245 }
246
250 DatasetT* m_dataset_ptr;
254 SamplerT m_sampler;
258 std::size_t m_batch_size;
262 std::size_t m_prefetch_batches;
270 bool m_drop_last;
274 std::deque<std::future<batch_type>> m_queue;
278 batch_type m_current_batch;
279 };
280
292 {
293 m_sampler.reset();
294 return iterator(&m_dataset,
295 m_sampler,
296 m_batch_size,
297 m_prefetch_batches,
298 m_pool,
299 m_drop_last);
300 }
301
306 iterator end() { return iterator(); }
307
308private:
312 DatasetT& m_dataset;
316 SamplerT m_sampler;
320 std::size_t m_batch_size;
324 std::size_t m_prefetch_batches;
332 bool m_drop_last;
333};
334
335} // namespace toolbox::io
Definition thread_pool.hpp:60
数据加载器迭代器/Iterator for dataloader
Definition dataloader.hpp:107
iterator()
默认构造函数/Default constructor
Definition dataloader.hpp:112
iterator(DatasetT *dataset, SamplerT sampler, std::size_t batch_size, std::size_t prefetch_batches, toolbox::base::thread_pool_t *pool, bool drop_last)
构造函数/Constructor
Definition dataloader.hpp:132
iterator & operator++()
前置自增运算符/Prefix increment operator
Definition dataloader.hpp:160
const batch_type * operator->() const
指针运算符/Pointer operator
Definition dataloader.hpp:203
const batch_type & operator*() const
解引用运算符/Dereference operator
Definition dataloader.hpp:197
bool operator!=(const iterator &other) const
不等于运算符/Not-equal operator
Definition dataloader.hpp:210
通用数据加载器/Generic data loader
Definition dataloader.hpp:55
typename DatasetT::data_type data_type
数据类型/Data type
Definition dataloader.hpp:60
typename SamplerT::index_type index_type
索引类型/Index type
Definition dataloader.hpp:68
dataloader_t(DatasetT &dataset, SamplerT sampler, std::size_t batch_size, std::size_t prefetch_batches=0, toolbox::base::thread_pool_t *pool=nullptr, bool drop_last=false)
构造函数/Constructor
Definition dataloader.hpp:86
iterator end()
获取迭代终止点/Get end iterator
Definition dataloader.hpp:306
std::vector< data_type > batch_type
批次类型/Batch type
Definition dataloader.hpp:64
iterator begin()
获取迭代起始点/Get begin iterator
Definition dataloader.hpp:291
用于列出目录下的文件/For listing files in a directory
Definition dataloader.hpp:15