Loading...
Searching...
No Matches
cuda_stream.hpp
1#pragma once
2
3#include "cuda_error.hpp"
4
9
10namespace tf {
11
12
13// ----------------------------------------------------------------------------
14// cudaEventBase
15// ----------------------------------------------------------------------------
16
23
24 public:
25
29 cudaEvent_t operator () () const {
30 cudaEvent_t event;
31 TF_CHECK_CUDA(cudaEventCreate(&event), "failed to create a CUDA event");
32 return event;
33 }
34
38 cudaEvent_t operator () (unsigned int flag) const {
39 cudaEvent_t event;
40 TF_CHECK_CUDA(
41 cudaEventCreateWithFlags(&event, flag),
42 "failed to create a CUDA event with flag=", flag
43 );
44 return event;
45 }
46
50 cudaEvent_t operator () (cudaEvent_t event) const {
51 return event;
52 }
53};
54
61 public:
65 void operator () (cudaEvent_t event) const {
66 cudaEventDestroy(event);
67 }
68};
69
81template <typename Creator, typename Deleter>
82class cudaEventBase : public std::unique_ptr<std::remove_pointer_t<cudaEvent_t>, Deleter> {
83
84 static_assert(std::is_pointer_v<cudaEvent_t>, "cudaEvent_t is not a pointer type");
85
86 public:
87
94 using base_type = std::unique_ptr<std::remove_pointer_t<cudaEvent_t>, Deleter>;
95
103 template <typename... ArgsT>
104 explicit cudaEventBase(ArgsT&& ... args) : base_type(
105 Creator{}(std::forward<ArgsT>(args)...), Deleter()
106 ) {
107 }
108
113
118
119 private:
120
121 cudaEventBase(const cudaEventBase&) = delete;
122 cudaEventBase& operator = (const cudaEventBase&) = delete;
123};
124
129
130// ----------------------------------------------------------------------------
131// cudaStream
132// ----------------------------------------------------------------------------
133
140
141 public:
142
146 cudaStream_t operator () () const {
147 cudaStream_t stream;
148 TF_CHECK_CUDA(cudaStreamCreate(&stream), "failed to create a CUDA stream");
149 return stream;
150 }
151
155 cudaStream_t operator () (cudaStream_t stream) const {
156 return stream;
157 }
158};
159
166
167 public:
168
172 void operator () (cudaStream_t stream) const {
173 cudaStreamDestroy(stream);
174 }
175};
176
188template <typename Creator, typename Deleter>
189class cudaStreamBase : public std::unique_ptr<std::remove_pointer_t<cudaStream_t>, Deleter> {
190
191 static_assert(std::is_pointer_v<cudaStream_t>, "cudaStream_t is not a pointer type");
192
193 public:
194
201 using base_type = std::unique_ptr<std::remove_pointer_t<cudaStream_t>, Deleter>;
202
210 template <typename... ArgsT>
211 explicit cudaStreamBase(ArgsT&& ... args) : base_type(
212 Creator{}(std::forward<ArgsT>(args)...), Deleter()
213 ) {
214 }
215
220
225
233 TF_CHECK_CUDA(
234 cudaStreamSynchronize(this->get()), "failed to synchronize a CUDA stream"
235 );
236 return *this;
237 }
238
265 void begin_capture(cudaStreamCaptureMode m = cudaStreamCaptureModeGlobal) const {
266 TF_CHECK_CUDA(
267 cudaStreamBeginCapture(this->get(), m),
268 "failed to begin capture on stream ", this->get(), " with thread mode ", m
269 );
270 }
271
281 cudaGraph_t end_capture() const {
282 cudaGraph_t native_g;
283 TF_CHECK_CUDA(
284 cudaStreamEndCapture(this->get(), &native_g),
285 "failed to end capture on stream ", this->get()
286 );
287 return native_g;
288 }
289
296 void record(cudaEvent_t event) const {
297 TF_CHECK_CUDA(
298 cudaEventRecord(event, this->get()),
299 "failed to record event ", event, " on stream ", this->get()
300 );
301 }
302
309 void wait(cudaEvent_t event) const {
310 TF_CHECK_CUDA(
311 cudaStreamWaitEvent(this->get(), event, 0),
312 "failed to wait for event ", event, " on stream ", this->get()
313 );
314 }
315
321 template <typename C, typename D>
323
329 cudaStreamBase& run(cudaGraphExec_t exec);
330
331 private:
332
333 cudaStreamBase(const cudaStreamBase&) = delete;
334 cudaStreamBase& operator = (const cudaStreamBase&) = delete;
335};
336
341
342} // end of namespace tf -----------------------------------------------------
343
344
345
class to create a CUDA event with unique ownership
Definition cuda_stream.hpp:82
std::unique_ptr< std::remove_pointer_t< cudaEvent_t >, Deleter > base_type
base type for the underlying unique pointer
Definition cuda_stream.hpp:94
cudaEventBase(cudaEventBase &&)=default
constructs a cudaEvent from the given rhs using move semantics
cudaEventBase & operator=(cudaEventBase &&)=default
assign the rhs to *this using move semantics
cudaEventBase(ArgsT &&... args)
constructs a cudaEvent object by passing the given arguments to the event creator
Definition cuda_stream.hpp:104
class to create functors that construct CUDA events
Definition cuda_stream.hpp:22
cudaEvent_t operator()() const
creates a new cudaEvent_t object using cudaEventCreate
Definition cuda_stream.hpp:29
class to create a functor that deletes a CUDA event
Definition cuda_stream.hpp:60
void operator()(cudaEvent_t event) const
deletes the given cudaEvent_t object using cudaEventDestroy
Definition cuda_stream.hpp:65
class to create an executable CUDA graph with unique ownership
Definition cuda_graph_exec.hpp:93
class to create a CUDA stream with unique ownership
Definition cuda_stream.hpp:189
cudaStreamBase(cudaStreamBase &&)=default
constructs a cudaStream from the given rhs using move semantics
cudaStreamBase & synchronize()
synchronizes the associated stream
Definition cuda_stream.hpp:232
void begin_capture(cudaStreamCaptureMode m=cudaStreamCaptureModeGlobal) const
begins graph capturing on the stream
Definition cuda_stream.hpp:265
cudaGraph_t end_capture() const
ends graph capturing on the stream
Definition cuda_stream.hpp:281
cudaStreamBase(ArgsT &&... args)
constructs a cudaStream object by passing the given arguments to the stream creator
Definition cuda_stream.hpp:211
void record(cudaEvent_t event) const
records an event on the stream
Definition cuda_stream.hpp:296
cudaStreamBase & run(const cudaGraphExecBase< C, D > &exec)
runs the given executable CUDA graph
void wait(cudaEvent_t event) const
waits on an event
Definition cuda_stream.hpp:309
cudaStreamBase & operator=(cudaStreamBase &&)=default
assign the rhs to *this using move semantics
cudaStreamBase & run(cudaGraphExec_t exec)
runs the given executable CUDA graph
Definition cuda_graph_exec.hpp:366
std::unique_ptr< std::remove_pointer_t< cudaStream_t >, Deleter > base_type
base type for the underlying unique pointer
Definition cuda_stream.hpp:201
class to create functors that construct CUDA streams
Definition cuda_stream.hpp:139
cudaStream_t operator()() const
constructs a new cudaStream_t object using cudaStreamCreate
Definition cuda_stream.hpp:146
class to create a functor that deletes a CUDA stream
Definition cuda_stream.hpp:165
void operator()(cudaStream_t stream) const
deletes the given cudaStream_t object
Definition cuda_stream.hpp:172
taskflow namespace
Definition small_vector.hpp:20
cudaEventBase< cudaEventCreator, cudaEventDeleter > cudaEvent
default smart pointer type to manage a cudaEvent_t object with unique ownership
Definition cuda_stream.hpp:128
cudaStreamBase< cudaStreamCreator, cudaStreamDeleter > cudaStream
default smart pointer type to manage a cudaStream_t object with unique ownership
Definition cuda_stream.hpp:340