torch-memory-saver 0.0.6__tar.gz → 0.0.9rc1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. {torch_memory_saver-0.0.6 → torch_memory_saver-0.0.9rc1}/LICENSE +0 -0
  2. torch_memory_saver-0.0.9rc1/PKG-INFO +7 -0
  3. torch_memory_saver-0.0.9rc1/README.md +104 -0
  4. torch_memory_saver-0.0.9rc1/csrc/api_forwarder.cpp +52 -0
  5. torch_memory_saver-0.0.9rc1/csrc/core.cpp +332 -0
  6. torch_memory_saver-0.0.9rc1/csrc/entrypoint.cpp +122 -0
  7. torch_memory_saver-0.0.9rc1/setup.py +154 -0
  8. torch_memory_saver-0.0.9rc1/test/test_examples.py +73 -0
  9. torch_memory_saver-0.0.9rc1/torch_memory_saver/__init__.py +5 -0
  10. torch_memory_saver-0.0.9rc1/torch_memory_saver/binary_wrapper.py +31 -0
  11. torch_memory_saver-0.0.9rc1/torch_memory_saver/entrypoint.py +142 -0
  12. torch_memory_saver-0.0.9rc1/torch_memory_saver/hooks/__init__.py +0 -0
  13. torch_memory_saver-0.0.9rc1/torch_memory_saver/hooks/base.py +21 -0
  14. torch_memory_saver-0.0.9rc1/torch_memory_saver/hooks/mode_preload.py +26 -0
  15. torch_memory_saver-0.0.9rc1/torch_memory_saver/hooks/mode_torch.py +19 -0
  16. torch_memory_saver-0.0.9rc1/torch_memory_saver/testing_utils.py +10 -0
  17. torch_memory_saver-0.0.9rc1/torch_memory_saver/utils.py +27 -0
  18. torch_memory_saver-0.0.9rc1/torch_memory_saver.egg-info/PKG-INFO +7 -0
  19. torch_memory_saver-0.0.9rc1/torch_memory_saver.egg-info/SOURCES.txt +20 -0
  20. {torch_memory_saver-0.0.6 → torch_memory_saver-0.0.9rc1}/torch_memory_saver.egg-info/dependency_links.txt +0 -0
  21. torch_memory_saver-0.0.9rc1/torch_memory_saver.egg-info/top_level.txt +3 -0
  22. torch_memory_saver-0.0.6/PKG-INFO +0 -12
  23. torch_memory_saver-0.0.6/README.md +0 -29
  24. torch_memory_saver-0.0.6/csrc/torch_memory_saver.cpp +0 -288
  25. torch_memory_saver-0.0.6/setup.py +0 -49
  26. torch_memory_saver-0.0.6/torch_memory_saver/__init__.py +0 -115
  27. torch_memory_saver-0.0.6/torch_memory_saver.egg-info/PKG-INFO +0 -12
  28. torch_memory_saver-0.0.6/torch_memory_saver.egg-info/SOURCES.txt +0 -9
  29. torch_memory_saver-0.0.6/torch_memory_saver.egg-info/top_level.txt +0 -2
  30. {torch_memory_saver-0.0.6 → torch_memory_saver-0.0.9rc1}/setup.cfg +0 -0
@@ -0,0 +1,7 @@
+ Metadata-Version: 2.4
+ Name: torch_memory_saver
+ Version: 0.0.9rc1
+ Requires-Python: >=3.9
+ License-File: LICENSE
+ Dynamic: license-file
+ Dynamic: requires-python
@@ -0,0 +1,104 @@
+ # Torch Memory Saver
+
+ A PyTorch library that allows tensor memory to be temporarily released and later resumed.
+
+ Please refer to https://github.com/sgl-project/sglang/issues/2542#issuecomment-2563641647 for details.
+
+ ## Examples and Features
+
+ ### Basic Example
+
+ ```python
+ # 1. For tensors that want to be paused, create them within `region`
+ with torch_memory_saver.region():
+     pauseable_tensor = torch.full((1_000_000_000,), 100, dtype=torch.uint8, device='cuda')
+
+ # 2. After `pause`, the CUDA memory of those tensors is released.
+ #    For example, check `nvidia-smi`'s memory usage to verify.
+ torch_memory_saver.pause()
+
+ # 3. After `resume`, CUDA memory is re-occupied for those tensors.
+ torch_memory_saver.resume()
+ ```
+
+ During a pause, the physical memory is released while the virtual address range is preserved. On resume, the virtual addresses stay unchanged while new physical memory is allocated and mapped to them.
+
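
The preserved-address property above can be checked directly. A minimal sketch (assuming the module-level `torch_memory_saver` object used in the examples, importable as `from torch_memory_saver import torch_memory_saver`):

```python
import torch
from torch_memory_saver import torch_memory_saver  # assumed import path

with torch_memory_saver.region():
    t = torch.full((1_000_000,), 7, dtype=torch.uint8, device='cuda')

addr_before = t.data_ptr()          # virtual address of the tensor's storage
torch_memory_saver.pause()          # physical memory released; address range stays reserved
torch_memory_saver.resume()         # new physical memory mapped at the same addresses
assert t.data_ptr() == addr_before  # tensor remains usable without re-creation
```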
+ ### Multiple Tags
+
+ Please refer to https://github.com/sgl-project/sglang/issues/7009 for details.
+
+ ```python
+ # 1. Create tensors with different tags
+ with torch_memory_saver.region(tag="type1"):
+     tensor1 = torch.full((5_000_000_000,), 100, dtype=torch.uint8, device='cuda')
+
+ with torch_memory_saver.region(tag="type2"):
+     tensor2 = torch.full((5_000_000_000,), 100, dtype=torch.uint8, device='cuda')
+
+ # 2. Selectively pause and resume by tag
+ torch_memory_saver.pause("type1")
+ torch_memory_saver.pause("type2")
+
+ torch_memory_saver.resume("type2")
+ torch_memory_saver.resume("type1")
+
+ torch_memory_saver.pause("type1")
+ torch_memory_saver.resume("type1")
+ ```
+
+ ### Release Memory in CUDA Graph
+
+ Not only does torch_memory_saver make tensors compatible with CUDA graphs, it can also release the memory held by a CUDA graph itself (i.e. the intermediate tensors).
+
+ API: Change `torch.cuda.graph(...)` to `torch_memory_saver.cuda_graph(...)`
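
A hedged usage sketch for `cuda_graph` (only the `torch.cuda.graph(...)` → `torch_memory_saver.cuda_graph(...)` substitution is documented above; the assumption here is that it accepts a `torch.cuda.CUDAGraph` the way `torch.cuda.graph` does):

```python
import torch
from torch_memory_saver import torch_memory_saver  # assumed import path

g = torch.cuda.CUDAGraph()
static_x = torch.zeros(64, device='cuda')

# Capture with torch_memory_saver.cuda_graph instead of torch.cuda.graph so the
# graph's internal buffers become pauseable too (assumed drop-in signature).
with torch_memory_saver.cuda_graph(g):
    static_y = static_x * 2

torch_memory_saver.pause()   # also releases the graph's intermediate tensors
torch_memory_saver.resume()
g.replay()                   # the graph is usable again after resume
```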
+
+ ### CPU Backup
+
+ By default, to save time, the tensor content is thrown away on pause. This is useful when the content will not be needed again, for example a KV cache that is about to become stale, or model weights that are about to be updated.
+
+ If you want the tensor content to be kept unchanged, use `enable_cpu_backup`.
+
+ ```python
+ with torch_memory_saver.region(enable_cpu_backup=True):
+     tensor1 = torch.full((5_000_000_000,), 42, dtype=torch.uint8, device='cuda')
+
+ torch_memory_saver.pause()
+ torch_memory_saver.resume()
+
+ assert tensor1[0] == 42, "content is kept unchanged"
+ ```
+
+ ### Hook Modes
+
+ There are two hook modes:
+
+ * **preload**: Use `LD_PRELOAD` to hook CUDA's malloc and free APIs and change allocation behavior.
+ * **torch**: Use torch's custom-allocator API to change allocation behavior.
+
+ The mode can be chosen by:
+
+ ```python
+ torch_memory_saver.hook_mode = "torch"
+ ```
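
For illustration, a sketch of switching modes (assumption: `hook_mode` takes effect only if set before the first pauseable allocation, since it decides how malloc/free are intercepted):

```python
import torch
from torch_memory_saver import torch_memory_saver  # assumed import path

torch_memory_saver.hook_mode = "torch"  # assumed: set before any region/allocation

with torch_memory_saver.region():
    t = torch.empty(1 << 20, dtype=torch.uint8, device='cuda')
```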
+
+ ### Example of RL with CUDA Graph
+
+ Please refer to `rl_example.py` for details.
+
+ ## Development
+
+ ```bash
+ make reinstall
+ ```
+
+ You can use this command for local testing:
+
+ ```bash
+ pytest /path/to/torch_memory_saver/test
+ ```
+
+ Or this one to test a single case (e.g. the `simple` one here):
+
+ ```bash
+ pytest /path/to/torch_memory_saver/test/test_examples.py::test_simple -s
+ ```
@@ -0,0 +1,52 @@
+ #include <iostream>
+ #include "api_forwarder.h"
+ #include "utils.h"
+ #include "macro.h"
+
+ namespace APIForwarder {
+     using CudaMallocFunc = cudaError_t (*)(void**, size_t);
+     using CudaFreeFunc = cudaError_t (*)(void*);
+
+     static void *check_dlsym(void *value) {
+         if (nullptr == value) {
+             std::cerr << "[torch_memory_saver.cpp] dlsym failed dlerror=" << dlerror() << std::endl;
+             exit(1);
+         }
+         return value;
+     }
+
+     static CudaMallocFunc real_cuda_malloc_ = NULL;
+     static CudaFreeFunc real_cuda_free_ = NULL;
+
+     cudaError_t call_real_cuda_malloc(void **ptr, size_t size) {
+         if (C10_UNLIKELY(nullptr == real_cuda_malloc_)) {
+             real_cuda_malloc_ = (CudaMallocFunc) check_dlsym(dlsym(RTLD_NEXT, "cudaMalloc"));
+         }
+
+         cudaError_t ret = real_cuda_malloc_(ptr, size);
+
+ #ifdef TMS_DEBUG_LOG
+         std::cout << "[torch_memory_saver.cpp] cudaMalloc [MODE NORMAL]"
+                   << " ptr=" << ptr << " *ptr=" << *ptr << " size=" << size << " ret=" << ret
+                   << std::endl;
+ #endif
+
+         return ret;
+     }
+
+     cudaError_t call_real_cuda_free(void *ptr) {
+         if (C10_UNLIKELY(nullptr == real_cuda_free_)) {
+             real_cuda_free_ = (CudaFreeFunc) check_dlsym(dlsym(RTLD_NEXT, "cudaFree"));
+         }
+
+         cudaError_t ret = real_cuda_free_(ptr);
+
+ #ifdef TMS_DEBUG_LOG
+         std::cout << "[torch_memory_saver.cpp] cudaFree [MODE NORMAL]"
+                   << " ptr=" << ptr << " ret=" << ret
+                   << std::endl;
+ #endif
+
+         return ret;
+     }
+ }
@@ -0,0 +1,332 @@
+ #include "core.h"
+ #include "utils.h"
+ #include "macro.h"
+ #include "api_forwarder.h"
+
+ TorchMemorySaver::TorchMemorySaver() {}
+
+ TorchMemorySaver &TorchMemorySaver::instance() {
+     static TorchMemorySaver instance;
+     return instance;
+ }
+
+ cudaError_t TorchMemorySaver::malloc(void **ptr, CUdevice device, size_t size, const std::string& tag, const bool enable_cpu_backup) {
+ #if defined(USE_ROCM)
+     // hipDevice_t device;
+     CURESULT_CHECK(hipCtxGetDevice(&device));
+
+     // // Get granularity and calculate aligned size
+     // size_t granularity = CUDAUtils::cu_mem_get_granularity(device);
+     // size_t aligned_size = (size + granularity - 1) & ~(granularity - 1);
+
+     // //// Reserve aligned memory address, rocm will check granularity
+     // CURESULT_CHECK(hipMemAddressReserve((hipDeviceptr_t *)ptr, aligned_size, granularity, 0, 0));
+
+     hipMemAllocationProp prop = {};
+     prop.type = hipMemAllocationTypePinned;
+     prop.location.type = hipMemLocationTypeDevice;
+     prop.location.id = device;
+     prop.allocFlags.compressionType = 0x0;
+
+     size_t granularity;
+     CURESULT_CHECK(hipMemGetAllocationGranularity(&granularity, &prop,
+                                                   hipMemAllocationGranularityMinimum));
+     size_t aligned_size = ((size + granularity - 1) / granularity) * granularity;
+     aligned_size = (aligned_size + MEMCREATE_CHUNK_SIZE - 1) / MEMCREATE_CHUNK_SIZE * MEMCREATE_CHUNK_SIZE;
+
+     assert(MEMCREATE_CHUNK_SIZE % granularity == 0);
+     assert(aligned_size % MEMCREATE_CHUNK_SIZE == 0);
+     assert(aligned_size % granularity == 0);
+
+
+     // Create allocation metadata
+     AllocationMetadata metadata;
+     metadata.size = size;
+     metadata.aligned_size = aligned_size;
+     metadata.device = device;
+     //// Not sure (Check these parameters)
+     metadata.tag = tag;
+     metadata.enable_cpu_backup = enable_cpu_backup;
+     metadata.cpu_backup = nullptr;
+     ////
+
+     // Get global device ID using our utility function
+     int global_device_id = DeviceUtils::get_global_device_id(device);
+
+     // rewrite numa node
+     uint64_t node_id = 0;
+     if (global_device_id > 3) {
+         node_id = 1;
+     }
+
+ #ifdef TMS_DEBUG_LOG
+     std::cout << "[torch_memory_saver.cpp] TorchMemorySaver.cuda_malloc "
+               << " ptr=" << ptr << " *ptr=" << *ptr << " size=" << size
+               << " granularity=" << granularity
+               << " aligned_size=" << aligned_size
+               << " node_id=" << node_id
+               << " device=" << device
+               << " global_device_id=" << global_device_id
+               << std::endl;
+ #endif
+
+     hipDeviceptr_t d_mem;
+     // Reserve aligned memory address, rocm will check granularity
+     CURESULT_CHECK(hipMemAddressReserve(&d_mem, aligned_size, granularity, 0, node_id));
+     *ptr = (void*)d_mem;
+
+     // Create and map chunks
+     // CUDAUtils::cu_mem_create_and_map(device, size, (hipDeviceptr_t)*ptr,
+     CUDAUtils::cu_mem_create_and_map(device, aligned_size, (hipDeviceptr_t)*ptr,
+                                      metadata.allocHandles, metadata.chunk_sizes);
+     size_t num_chunks = metadata.allocHandles.size();
+     {
+         const std::lock_guard<std::mutex> lock(allocator_metadata_mutex_);
+         allocation_metadata_.emplace(*ptr, std::move(metadata));
+     }
+ #ifdef TMS_DEBUG_LOG
+     std::cout << "[torch_memory_saver.cpp] TorchMemorySaver.cuda_malloc "
+               << " ptr=" << ptr << " *ptr=" << *ptr << " size=" << size
+               << " metadata.aligned_size=" << metadata.aligned_size
+               << " num_chunks=" << num_chunks
+               << std::endl;
+ #endif
+
+ #elif defined(USE_CUDA)
+     CUmemGenericAllocationHandle allocHandle;
+     CUDAUtils::cu_mem_create(&allocHandle, size, device);
+     CURESULT_CHECK(cuMemAddressReserve((CUdeviceptr *) ptr, size, 0, 0, 0));
+     CURESULT_CHECK(cuMemMap((CUdeviceptr) *ptr, size, 0, allocHandle, 0));
+     CUDAUtils::cu_mem_set_access(*ptr, size, device);
+
+     {
+         const std::lock_guard<std::mutex> lock(allocator_metadata_mutex_);
+         allocation_metadata_.emplace(
+             *ptr,
+             AllocationMetadata{size, device, tag, AllocationState::ACTIVE, enable_cpu_backup, nullptr, allocHandle}
+         );
+     }
+
+ #ifdef TMS_DEBUG_LOG
+     std::cout << "[torch_memory_saver.cpp] TorchMemorySaver.cuda_malloc "
+               << " ptr=" << ptr << " *ptr=" << *ptr << " size=" << size
+               << " allocHandle=" << allocHandle << " tag=" << tag
+               << std::endl;
+ #endif
+
+ #else
+ #error "USE_PLATFORM is not set"
+ #endif
+     return cudaSuccess;
+ }
+
+ cudaError_t TorchMemorySaver::free(void *ptr) {
+ #if defined(USE_ROCM)
+     AllocationMetadata metadata;
+     {
+         const std::lock_guard<std::mutex> lock(allocator_metadata_mutex_);
+         SIMPLE_CHECK(allocation_metadata_.count(ptr), "Trying to free a pointer not allocated here");
+         metadata = std::move(allocation_metadata_[ptr]);
+         allocation_metadata_.erase(ptr);
+     }
+
+     // Unmap and release chunks
+     CUDAUtils::cu_mem_unmap_and_release(metadata.device, metadata.size,
+                                         (hipDeviceptr_t)ptr, metadata.allocHandles, metadata.chunk_sizes);
+
+     // Free the reserved address using stored aligned_size
+     CURESULT_CHECK(hipMemAddressFree((hipDeviceptr_t)ptr, metadata.aligned_size));
+
+ #ifdef TMS_DEBUG_LOG
+     std::cout << "[torch_memory_saver.cpp] TorchMemorySaver.cuda_free "
+               << " ptr=" << ptr << " metadata.size=" << metadata.size
+               << " metadata.aligned_size=" << metadata.aligned_size
+               << " num_chunks=" << metadata.allocHandles.size()
+               << std::endl;
+ #endif
+ #elif defined(USE_CUDA)
+     AllocationMetadata metadata;
+     {
+         const std::lock_guard<std::mutex> lock(allocator_metadata_mutex_);
+         if (allocation_metadata_.count(ptr) == 0) {
+             return APIForwarder::call_real_cuda_free(ptr);
+         }
+
+         metadata = allocation_metadata_[ptr];
+         allocation_metadata_.erase(ptr);
+     }
+
+     CURESULT_CHECK(cuMemUnmap((CUdeviceptr) ptr, metadata.size));
+     CURESULT_CHECK(cuMemRelease(metadata.allocHandle));
+     CURESULT_CHECK(cuMemAddressFree((CUdeviceptr) ptr, metadata.size));
+
+     if (nullptr != metadata.cpu_backup) {
+         CUDA_ERROR_CHECK(cudaFreeHost(metadata.cpu_backup));
+         metadata.cpu_backup = nullptr;
+     }
+
+ #ifdef TMS_DEBUG_LOG
+     std::cout << "[torch_memory_saver.cpp] TorchMemorySaver.cuda_free "
+               << " ptr=" << ptr << " metadata.size=" << metadata.size
+               << " metadata.allocHandle=" << metadata.allocHandle << " tag=" << metadata.tag
+               << std::endl;
+ #endif
+
+ #else
+ #error "USE_PLATFORM is not set"
+ #endif
+     return cudaSuccess;
+ }
+
+ void TorchMemorySaver::pause(const std::string& tag) {
+     const std::lock_guard<std::mutex> lock(allocator_metadata_mutex_);
+
+ #if defined(USE_ROCM)
+     for (auto it = allocation_metadata_.begin(); it != allocation_metadata_.end(); ++it) {
+         void *ptr = it->first;
+         AllocationMetadata &metadata = it->second;
+
+         if (!tag.empty() && metadata.tag != tag) {
+             continue;
+         }
+         // Copy CUDA's code supporting cpu_backup to here
+         if (metadata.enable_cpu_backup) {
+             if (nullptr == metadata.cpu_backup) {
+                 CUDA_ERROR_CHECK(hipMallocHost(&metadata.cpu_backup, metadata.aligned_size));
+             }
+             SIMPLE_CHECK(metadata.cpu_backup != nullptr, "cpu_backup should not be nullptr");
+             // TODO may use cudaMemcpyAsync if needed
+             CUDA_ERROR_CHECK(cudaMemcpy(metadata.cpu_backup, ptr, metadata.aligned_size, hipMemcpyDeviceToHost));
+         }
+         //
+
+         // Unmap and release chunks (but keep metadata for resume)
+         // CUDAUtils::cu_mem_unmap_and_release(metadata.device, metadata.size,
+         CUDAUtils::cu_mem_unmap_and_release(metadata.device, metadata.aligned_size,
+                                             (hipDeviceptr_t)ptr, metadata.allocHandles, metadata.chunk_sizes);
+
+ #ifdef TMS_DEBUG_LOG
+         std::cout << "[torch_memory_saver.cpp] TorchMemorySaver.pause"
+                   << " ptr=" << ptr << " metadata.size=" << metadata.size
+                   << " metadata.aligned_size=" << metadata.aligned_size
+                   << " num_chunks=" << metadata.allocHandles.size()
+                   << std::endl;
+ #endif
+     }
+ #elif defined(USE_CUDA)
+     for (auto it = allocation_metadata_.begin(); it != allocation_metadata_.end(); ++it) {
+         void *ptr = it->first;
+         AllocationMetadata& metadata = it->second;
+
+         if (!tag.empty() && metadata.tag != tag) {
+             continue;
+         }
+
+         if (metadata.state != AllocationState::ACTIVE) {
+             std::cerr << "[torch_memory_saver.cpp] Cannot pause allocation that is not active."
+                       << " tag=" << metadata.tag << " ptr=" << std::to_string((uintptr_t)ptr)
+                       << " file=" << __FILE__ << " func=" << __func__ << " line=" << __LINE__
+                       << std::endl;
+             exit(1);
+         }
+
+         if (metadata.enable_cpu_backup) {
+             if (nullptr == metadata.cpu_backup) {
+                 CUDA_ERROR_CHECK(cudaMallocHost(&metadata.cpu_backup, metadata.size));
+             }
+             SIMPLE_CHECK(metadata.cpu_backup != nullptr, "cpu_backup should not be nullptr");
+             // TODO may use cudaMemcpyAsync if needed
+             CUDA_ERROR_CHECK(cudaMemcpy(metadata.cpu_backup, ptr, metadata.size, cudaMemcpyDeviceToHost));
+         }
+
+         CURESULT_CHECK(cuMemUnmap((CUdeviceptr) ptr, metadata.size));
+         CURESULT_CHECK(cuMemRelease(metadata.allocHandle));
+
+         metadata.state = AllocationState::PAUSED;
+
+ #ifdef TMS_DEBUG_LOG
+         std::cout << "[torch_memory_saver.cpp] TorchMemorySaver.pause"
+                   << " ptr=" << ptr << " metadata.size=" << metadata.size << " metadata.allocHandle="
+                   << metadata.allocHandle << " tag=" << metadata.tag << " filter_tag=" << tag
+                   << " metadata.enable_cpu_backup=" << metadata.enable_cpu_backup
+                   << std::endl;
+ #endif
+     }
+ #else
+ #error "USE_PLATFORM is not set"
+ #endif
+ }
+
+ void TorchMemorySaver::resume(const std::string& tag) {
+     const std::lock_guard<std::mutex> lock(allocator_metadata_mutex_);
+
+ #if defined(USE_ROCM)
+     for (auto it = allocation_metadata_.begin(); it != allocation_metadata_.end(); ++it) {
+         void *ptr = it->first;
+         AllocationMetadata &metadata = it->second;
+
+         if (!tag.empty() && metadata.tag != tag) {
+             continue;
+         }
+
+         // Create new handles and map chunks
+         // CUDAUtils::cu_mem_create_and_map(metadata.device, metadata.size,
+         CUDAUtils::cu_mem_create_and_map(metadata.device, metadata.aligned_size,
+                                          (hipDeviceptr_t)ptr, metadata.allocHandles, metadata.chunk_sizes);
+
+ #ifdef TMS_DEBUG_LOG
+         std::cout << "[torch_memory_saver.cpp] TorchMemorySaver.resume"
+                   << " ptr=" << ptr << " metadata.size=" << metadata.size
+                   << " metadata.aligned_size=" << metadata.aligned_size
+                   << " num_chunks=" << metadata.allocHandles.size()
+                   << std::endl;
+ #endif
+     }
+
+ #elif defined(USE_CUDA)
+     for (auto it = allocation_metadata_.begin(); it != allocation_metadata_.end(); ++it) {
+         void *ptr = it->first;
+         AllocationMetadata &metadata = it->second;
+
+         if (!tag.empty() && metadata.tag != tag) {
+             continue;
+         }
+
+         if (metadata.state != AllocationState::PAUSED) {
+             std::cerr << "[torch_memory_saver.cpp] Cannot resume allocation that is not paused."
+                       << " tag=" << metadata.tag << " ptr=" << std::to_string((uintptr_t)ptr)
+                       << " file=" << __FILE__ << " func=" << __func__ << " line=" << __LINE__
+                       << std::endl;
+             exit(1);
+         }
+
+         CUmemGenericAllocationHandle newAllocHandle;
+         CUDAUtils::cu_mem_create(&newAllocHandle, metadata.size, metadata.device);
+
+         CURESULT_CHECK(cuMemMap((CUdeviceptr) ptr, metadata.size, 0, newAllocHandle, 0));
+
+         CUDAUtils::cu_mem_set_access(ptr, metadata.size, metadata.device);
+
+         if (metadata.enable_cpu_backup) {
+             SIMPLE_CHECK(metadata.cpu_backup != nullptr, "cpu_backup should not be nullptr");
+             // TODO may use cudaMemcpyAsync if needed
+             CUDA_ERROR_CHECK(cudaMemcpy(ptr, metadata.cpu_backup, metadata.size, cudaMemcpyHostToDevice));
+             // maybe we can free host memory if needed (currently keep it there to reduce re-alloc time)
+         }
+
+ #ifdef TMS_DEBUG_LOG
+         std::cout << "[torch_memory_saver.cpp] TorchMemorySaver.resume"
+                   << " ptr=" << ptr << " metadata.size=" << metadata.size << " (old)metadata.allocHandle="
+                   << metadata.allocHandle
+                   << " (new)newAllocHandle=" << newAllocHandle << " tag=" << metadata.tag << " filter_tag=" << tag
+                   << " metadata.enable_cpu_backup=" << metadata.enable_cpu_backup
+                   << std::endl;
+ #endif
+
+         metadata.state = AllocationState::ACTIVE;
+         metadata.allocHandle = newAllocHandle;
+     }
+ #else
+ #error "USE_PLATFORM is not set"
+ #endif
+ }
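
Note the filter in `pause`/`resume` above: an empty `tag` matches every tracked allocation, while a non-empty `tag` only affects allocations created under that tag. In Python terms (a sketch, assuming the wrappers forward an empty tag when none is given, matching `tms_pause`/`tms_resume` in `entrypoint.cpp` below):

```python
import torch
from torch_memory_saver import torch_memory_saver  # assumed import path

with torch_memory_saver.region(tag="kv_cache"):
    kv = torch.empty(1 << 20, dtype=torch.uint8, device='cuda')
with torch_memory_saver.region(tag="weights"):
    w = torch.empty(1 << 20, dtype=torch.uint8, device='cuda')

torch_memory_saver.pause("kv_cache")    # unmaps only allocations tagged "kv_cache"
torch_memory_saver.resume("kv_cache")

torch_memory_saver.pause()              # no tag -> empty filter -> pauses everything tracked
torch_memory_saver.resume()
```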
@@ -0,0 +1,122 @@
+ #include "utils.h"
+ #include "core.h"
+ #include "api_forwarder.h"
+ #include <optional>
+ #include "macro.h"
+
+ // ----------------------------------------------- threadlocal configs --------------------------------------------------
+
+ class ThreadLocalConfig {
+ public:
+     std::string current_tag_ = "default";
+
+     bool is_interesting_region() {
+         if (!is_interesting_region_.has_value()) {
+             is_interesting_region_ = get_bool_env_var("TMS_INIT_ENABLE");
+         }
+         return is_interesting_region_.value();
+     }
+
+     void set_interesting_region(bool value) {
+         is_interesting_region_ = value;
+     }
+
+     bool enable_cpu_backup() {
+         if (!enable_cpu_backup_.has_value()) {
+             enable_cpu_backup_ = get_bool_env_var("TMS_INIT_ENABLE_CPU_BACKUP");
+         }
+         return enable_cpu_backup_.value();
+     }
+
+     void set_enable_cpu_backup(bool value) {
+         enable_cpu_backup_ = value;
+     }
+
+ private:
+     std::optional<bool> is_interesting_region_;
+     std::optional<bool> enable_cpu_backup_;
+ };
+ static thread_local ThreadLocalConfig thread_local_config;
+
+ // ------------------------------------------------- entrypoints :: hook ------------------------------------------------
+
+ #ifdef TMS_HOOK_MODE_PRELOAD
+ cudaError_t cudaMalloc(void **ptr, size_t size) {
+     if (thread_local_config.is_interesting_region()) {
+         return TorchMemorySaver::instance().malloc(
+             ptr, CUDAUtils::cu_ctx_get_device(), size, thread_local_config.current_tag_, thread_local_config.enable_cpu_backup());
+     } else {
+         return APIForwarder::call_real_cuda_malloc(ptr, size);
+     }
+ }
+
+ cudaError_t cudaFree(void *ptr) {
+     if (thread_local_config.is_interesting_region()) {
+         return TorchMemorySaver::instance().free(ptr);
+     } else {
+         return APIForwarder::call_real_cuda_free(ptr);
+     }
+ }
+ #endif
+
+ #ifdef TMS_HOOK_MODE_TORCH
+ extern "C" {
+ void *tms_torch_malloc(ssize_t size, int device, cudaStream_t stream) {
+ #ifdef TMS_DEBUG_LOG
+     std::cout << "[torch_memory_saver.cpp] tms_torch_malloc "
+               << " size=" << size << " device=" << device << " stream=" << stream
+               << std::endl;
+ #endif
+     SIMPLE_CHECK(thread_local_config.is_interesting_region(), "only support interesting region");
+     void *ptr;
+     TorchMemorySaver::instance().malloc(
+         &ptr, CUDAUtils::cu_device_get(device), size, thread_local_config.current_tag_, thread_local_config.enable_cpu_backup());
+     return ptr;
+ }
+
+ void tms_torch_free(void *ptr, ssize_t ssize, int device, cudaStream_t stream) {
+ #ifdef TMS_DEBUG_LOG
+     std::cout << "[torch_memory_saver.cpp] tms_torch_free "
+               << " ptr=" << ptr << " ssize=" << ssize << " device=" << device << " stream=" << stream
+               << std::endl;
+ #endif
+     SIMPLE_CHECK(thread_local_config.is_interesting_region(), "only support interesting region");
+     TorchMemorySaver::instance().free(ptr);
+ }
+ }
+ #endif
+
+ // ------------------------------------------------- entrypoints :: others ------------------------------------------------
+
+ extern "C" {
+ void tms_set_interesting_region(bool is_interesting_region) {
+     thread_local_config.set_interesting_region(is_interesting_region);
+ }
+
+ bool tms_get_interesting_region() {
+     return thread_local_config.is_interesting_region();
+ }
+
+ void tms_set_current_tag(const char* tag) {
+     SIMPLE_CHECK(tag != nullptr, "tag should not be null");
+     thread_local_config.current_tag_ = tag;
+ }
+
+ bool tms_get_enable_cpu_backup() {
+     return thread_local_config.enable_cpu_backup();
+ }
+
+ void tms_set_enable_cpu_backup(bool enable_cpu_backup) {
+     thread_local_config.set_enable_cpu_backup(enable_cpu_backup);
+ }
+
+ void tms_pause(const char* tag) {
+     std::string tag_str = (tag != nullptr) ? std::string(tag) : "";
+     TorchMemorySaver::instance().pause(tag_str);
+ }
+
+ void tms_resume(const char* tag) {
+     std::string tag_str = (tag != nullptr) ? std::string(tag) : "";
+     TorchMemorySaver::instance().resume(tag_str);
+ }
+ }
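
`ThreadLocalConfig` seeds its initial state from `TMS_INIT_ENABLE` and `TMS_INIT_ENABLE_CPU_BACKUP` on first use, so interception can be active from process start. A hedged sketch of a launcher that relies on this (the launcher and `train.py` are hypothetical; only the environment-variable names come from the code above):

```python
import os
import subprocess

env = dict(os.environ)
env["TMS_INIT_ENABLE"] = "1"             # start with allocations treated as "interesting"
env["TMS_INIT_ENABLE_CPU_BACKUP"] = "1"  # default new allocations to CPU backup on pause
# In preload mode the saver's shared library would also need to be in LD_PRELOAD
# for the cudaMalloc/cudaFree hooks to take effect (path is deployment-specific).
subprocess.run(["python", "train.py"], env=env, check=True)
```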