torch-memory-saver 0.0.6__tar.gz → 0.0.9rc1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {torch_memory_saver-0.0.6 → torch_memory_saver-0.0.9rc1}/LICENSE +0 -0
- torch_memory_saver-0.0.9rc1/PKG-INFO +7 -0
- torch_memory_saver-0.0.9rc1/README.md +104 -0
- torch_memory_saver-0.0.9rc1/csrc/api_forwarder.cpp +52 -0
- torch_memory_saver-0.0.9rc1/csrc/core.cpp +332 -0
- torch_memory_saver-0.0.9rc1/csrc/entrypoint.cpp +122 -0
- torch_memory_saver-0.0.9rc1/setup.py +154 -0
- torch_memory_saver-0.0.9rc1/test/test_examples.py +73 -0
- torch_memory_saver-0.0.9rc1/torch_memory_saver/__init__.py +5 -0
- torch_memory_saver-0.0.9rc1/torch_memory_saver/binary_wrapper.py +31 -0
- torch_memory_saver-0.0.9rc1/torch_memory_saver/entrypoint.py +142 -0
- torch_memory_saver-0.0.9rc1/torch_memory_saver/hooks/__init__.py +0 -0
- torch_memory_saver-0.0.9rc1/torch_memory_saver/hooks/base.py +21 -0
- torch_memory_saver-0.0.9rc1/torch_memory_saver/hooks/mode_preload.py +26 -0
- torch_memory_saver-0.0.9rc1/torch_memory_saver/hooks/mode_torch.py +19 -0
- torch_memory_saver-0.0.9rc1/torch_memory_saver/testing_utils.py +10 -0
- torch_memory_saver-0.0.9rc1/torch_memory_saver/utils.py +27 -0
- torch_memory_saver-0.0.9rc1/torch_memory_saver.egg-info/PKG-INFO +7 -0
- torch_memory_saver-0.0.9rc1/torch_memory_saver.egg-info/SOURCES.txt +20 -0
- {torch_memory_saver-0.0.6 → torch_memory_saver-0.0.9rc1}/torch_memory_saver.egg-info/dependency_links.txt +0 -0
- torch_memory_saver-0.0.9rc1/torch_memory_saver.egg-info/top_level.txt +3 -0
- torch_memory_saver-0.0.6/PKG-INFO +0 -12
- torch_memory_saver-0.0.6/README.md +0 -29
- torch_memory_saver-0.0.6/csrc/torch_memory_saver.cpp +0 -288
- torch_memory_saver-0.0.6/setup.py +0 -49
- torch_memory_saver-0.0.6/torch_memory_saver/__init__.py +0 -115
- torch_memory_saver-0.0.6/torch_memory_saver.egg-info/PKG-INFO +0 -12
- torch_memory_saver-0.0.6/torch_memory_saver.egg-info/SOURCES.txt +0 -9
- torch_memory_saver-0.0.6/torch_memory_saver.egg-info/top_level.txt +0 -2
- {torch_memory_saver-0.0.6 → torch_memory_saver-0.0.9rc1}/setup.cfg +0 -0
torch_memory_saver-0.0.9rc1/README.md
@@ -0,0 +1,104 @@
# Torch Memory Saver

A PyTorch library that allows tensor memory to be temporarily released and later resumed.

Please refer to https://github.com/sgl-project/sglang/issues/2542#issuecomment-2563641647 for details.

## Examples and Features

### Basic Example

```python
# 1. Tensors that you want to pause must be created within `region`
with torch_memory_saver.region():
    pauseable_tensor = torch.full((1_000_000_000,), 100, dtype=torch.uint8, device='cuda')

# 2. After `pause`, the CUDA memory backing those tensors is released.
# For example, check `nvidia-smi`'s memory usage to verify.
torch_memory_saver.pause()

# 3. After `resume`, CUDA memory is re-occupied for those tensors.
torch_memory_saver.resume()
```

While paused, the physical memory is released but the virtual address range is preserved. On resume, the virtual addresses stay unchanged while the physical memory is re-allocated.
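Because the release happens at the driver level, it is visible outside PyTorch's caching allocator. A minimal sketch of how one might verify this (assuming the singleton import style used by the package; `torch.cuda.mem_get_info` reports driver-level free/total memory in bytes):

```python
import torch
from torch_memory_saver import torch_memory_saver  # assumed export style

with torch_memory_saver.region():
    t = torch.full((1_000_000_000,), 100, dtype=torch.uint8, device='cuda')

free_before, _ = torch.cuda.mem_get_info()
torch_memory_saver.pause()
free_after, _ = torch.cuda.mem_get_info()
assert free_after - free_before > 900_000_000  # roughly 1 GB returned to the driver
torch_memory_saver.resume()
```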
### Multiple Tags

Please refer to https://github.com/sgl-project/sglang/issues/7009 for details.

```python
# 1. Create tensors with different tags
with torch_memory_saver.region(tag="type1"):
    tensor1 = torch.full((5_000_000_000,), 100, dtype=torch.uint8, device='cuda')

with torch_memory_saver.region(tag="type2"):
    tensor2 = torch.full((5_000_000_000,), 100, dtype=torch.uint8, device='cuda')

# 2. Pause and resume with different tags selectively
torch_memory_saver.pause("type1")
torch_memory_saver.pause("type2")

torch_memory_saver.resume("type2")
torch_memory_saver.resume("type1")

torch_memory_saver.pause("type1")
torch_memory_saver.resume("type1")
```
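Judging from the native implementation below (csrc/entrypoint.cpp forwards a null tag as an empty string, and the tag filter in csrc/core.cpp matches everything when the tag is empty), calling `pause()`/`resume()` without a tag acts on allocations under every tag; a small illustration under that assumption:

```python
# Assumption (from the csrc tag filter): no tag means "all allocations"
torch_memory_saver.pause()   # pauses both tensor1 and tensor2
torch_memory_saver.resume()  # resumes both
```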
### Release Memory in CUDA Graph

Not only does torch_memory_saver make tensors compatible with CUDA graphs, it can also release the memory held by a CUDA graph itself (i.e. its intermediate tensors).

API: change `torch.cuda.graph(...)` to `torch_memory_saver.cuda_graph(...)`.
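A minimal sketch, assuming `torch_memory_saver.cuda_graph(...)` mirrors the `torch.cuda.graph(...)` context-manager signature (the exact keyword arguments may differ; `model` and `static_input` are placeholders):

```python
g = torch.cuda.CUDAGraph()
with torch_memory_saver.cuda_graph(g):
    static_output = model(static_input)  # captured; graph-internal memory lives in the region

torch_memory_saver.pause()   # releases the graph's internal memory too
torch_memory_saver.resume()
g.replay()                   # virtual addresses were preserved across pause/resume
```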
### CPU Backup

By default, to save time, tensor content is thrown away on pause. This is useful, for example, for a KV cache that is about to become stale, or for model weights that are about to be updated.

If you want the tensor content to be kept unchanged across a pause, use `enable_cpu_backup`.

```python
with torch_memory_saver.region(enable_cpu_backup=True):
    tensor1 = torch.full((5_000_000_000,), 42, dtype=torch.uint8, device='cuda')

torch_memory_saver.pause()
torch_memory_saver.resume()

assert tensor1[0] == 42, "content is kept unchanged"
```
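For scale: backing up the 5 GB tensor above costs one device-to-host copy per pause and one host-to-device copy per resume. At a typical ~20-25 GB/s for pinned-memory transfers over PCIe 4.0 x16, that is roughly 0.2-0.25 s each way (a back-of-envelope estimate, not a measured number), which is the time the default discard behavior avoids.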
### Hook Modes

There are two hook modes:

* **preload**: Use `LD_PRELOAD` to hook CUDA's malloc and free APIs and change allocation behavior.
* **torch**: Use torch's custom-allocator API to change allocation behavior.

The mode can be chosen by:

```python
torch_memory_saver.hook_mode = "torch"
```
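In preload mode the hook library has to be injected before the process starts; a minimal sketch of launching a worker that way (the `.so` name and path are illustrative, not the package's real file layout):

```python
import os
import subprocess

env = dict(os.environ)
# Hypothetical path: point LD_PRELOAD at the bundled hook library
env["LD_PRELOAD"] = "/path/to/libtorch_memory_saver_preload.so"
subprocess.run(["python", "worker.py"], env=env, check=True)
```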
### Example of RL with CUDA Graph

Please refer to `rl_example.py` for details.

## Development

```bash
make reinstall
```

You can use this command for local testing:

```bash
pytest /path/to/torch_memory_saver/test
```

Or this one to test a single case (e.g. the `simple` one here):

```bash
pytest /path/to/torch_memory_saver/test/test_examples.py::test_simple -s
```
torch_memory_saver-0.0.9rc1/csrc/api_forwarder.cpp
@@ -0,0 +1,52 @@
#include <iostream>
#include "api_forwarder.h"
#include "utils.h"
#include "macro.h"

namespace APIForwarder {
    using CudaMallocFunc = cudaError_t (*)(void **, size_t);
    using CudaFreeFunc = cudaError_t (*)(void *);

    static void *check_dlsym(void *value) {
        if (nullptr == value) {
            std::cerr << "[torch_memory_saver.cpp] dlsym failed dlerror=" << dlerror() << std::endl;
            exit(1);
        }
        return value;
    }

    static CudaMallocFunc real_cuda_malloc_ = nullptr;
    static CudaFreeFunc real_cuda_free_ = nullptr;

    // Look up the real cudaMalloc lazily via RTLD_NEXT so the LD_PRELOAD hook can forward to it
    cudaError_t call_real_cuda_malloc(void **ptr, size_t size) {
        if (C10_UNLIKELY(nullptr == real_cuda_malloc_)) {
            real_cuda_malloc_ = (CudaMallocFunc) check_dlsym(dlsym(RTLD_NEXT, "cudaMalloc"));
        }

        cudaError_t ret = real_cuda_malloc_(ptr, size);

#ifdef TMS_DEBUG_LOG
        std::cout << "[torch_memory_saver.cpp] cudaMalloc [MODE NORMAL]"
                  << " ptr=" << ptr << " *ptr=" << *ptr << " size=" << size << " ret=" << ret
                  << std::endl;
#endif

        return ret;
    }

    cudaError_t call_real_cuda_free(void *ptr) {
        if (C10_UNLIKELY(nullptr == real_cuda_free_)) {
            real_cuda_free_ = (CudaFreeFunc) check_dlsym(dlsym(RTLD_NEXT, "cudaFree"));
        }

        cudaError_t ret = real_cuda_free_(ptr);

#ifdef TMS_DEBUG_LOG
        std::cout << "[torch_memory_saver.cpp] cudaFree [MODE NORMAL]"
                  << " ptr=" << ptr << " ret=" << ret
                  << std::endl;
#endif

        return ret;
    }
}
torch_memory_saver-0.0.9rc1/csrc/core.cpp
@@ -0,0 +1,332 @@
#include "core.h"
#include "utils.h"
#include "macro.h"
#include "api_forwarder.h"

TorchMemorySaver::TorchMemorySaver() {}

TorchMemorySaver &TorchMemorySaver::instance() {
    static TorchMemorySaver instance;
    return instance;
}

cudaError_t TorchMemorySaver::malloc(void **ptr, CUdevice device, size_t size, const std::string &tag, const bool enable_cpu_backup) {
#if defined(USE_ROCM)
    CURESULT_CHECK(hipCtxGetDevice(&device));

    hipMemAllocationProp prop = {};
    prop.type = hipMemAllocationTypePinned;
    prop.location.type = hipMemLocationTypeDevice;
    prop.location.id = device;
    prop.allocFlags.compressionType = 0x0;

    // Round the size up to both the allocation granularity and the chunk size
    size_t granularity;
    CURESULT_CHECK(hipMemGetAllocationGranularity(&granularity, &prop,
                                                  hipMemAllocationGranularityMinimum));
    size_t aligned_size = ((size + granularity - 1) / granularity) * granularity;
    aligned_size = (aligned_size + MEMCREATE_CHUNK_SIZE - 1) / MEMCREATE_CHUNK_SIZE * MEMCREATE_CHUNK_SIZE;

    assert(MEMCREATE_CHUNK_SIZE % granularity == 0);
    assert(aligned_size % MEMCREATE_CHUNK_SIZE == 0);
    assert(aligned_size % granularity == 0);

    // Create allocation metadata
    AllocationMetadata metadata;
    metadata.size = size;
    metadata.aligned_size = aligned_size;
    metadata.device = device;
    metadata.tag = tag;
    metadata.enable_cpu_backup = enable_cpu_backup;
    metadata.cpu_backup = nullptr;

    // Get global device ID using our utility function
    int global_device_id = DeviceUtils::get_global_device_id(device);

    // Pick the NUMA node for the reserved address range based on the device ID
    uint64_t node_id = 0;
    if (global_device_id > 3) {
        node_id = 1;
    }

#ifdef TMS_DEBUG_LOG
    std::cout << "[torch_memory_saver.cpp] TorchMemorySaver.cuda_malloc "
              << " ptr=" << ptr << " *ptr=" << *ptr << " size=" << size
              << " granularity=" << granularity
              << " aligned_size=" << aligned_size
              << " node_id=" << node_id
              << " device=" << device
              << " global_device_id=" << global_device_id
              << std::endl;
#endif

    hipDeviceptr_t d_mem;
    // Reserve an aligned virtual address range; ROCm checks the granularity
    CURESULT_CHECK(hipMemAddressReserve(&d_mem, aligned_size, granularity, 0, node_id));
    *ptr = (void *) d_mem;

    // Create and map physical chunks into the reserved range
    CUDAUtils::cu_mem_create_and_map(device, aligned_size, (hipDeviceptr_t) *ptr,
                                     metadata.allocHandles, metadata.chunk_sizes);
    size_t num_chunks = metadata.allocHandles.size();
    {
        const std::lock_guard<std::mutex> lock(allocator_metadata_mutex_);
        allocation_metadata_.emplace(*ptr, std::move(metadata));
    }
#ifdef TMS_DEBUG_LOG
    std::cout << "[torch_memory_saver.cpp] TorchMemorySaver.cuda_malloc "
              << " ptr=" << ptr << " *ptr=" << *ptr << " size=" << size
              << " aligned_size=" << aligned_size
              << " num_chunks=" << num_chunks
              << std::endl;
#endif

#elif defined(USE_CUDA)
    // Reserve a virtual address range, create a physical allocation, and map it in
    CUmemGenericAllocationHandle allocHandle;
    CUDAUtils::cu_mem_create(&allocHandle, size, device);
    CURESULT_CHECK(cuMemAddressReserve((CUdeviceptr *) ptr, size, 0, 0, 0));
    CURESULT_CHECK(cuMemMap((CUdeviceptr) *ptr, size, 0, allocHandle, 0));
    CUDAUtils::cu_mem_set_access(*ptr, size, device);

    {
        const std::lock_guard<std::mutex> lock(allocator_metadata_mutex_);
        allocation_metadata_.emplace(
            *ptr,
            AllocationMetadata{size, device, tag, AllocationState::ACTIVE, enable_cpu_backup, nullptr, allocHandle}
        );
    }

#ifdef TMS_DEBUG_LOG
    std::cout << "[torch_memory_saver.cpp] TorchMemorySaver.cuda_malloc "
              << " ptr=" << ptr << " *ptr=" << *ptr << " size=" << size
              << " allocHandle=" << allocHandle << " tag=" << tag
              << std::endl;
#endif

#else
#error "USE_PLATFORM is not set"
#endif
    return cudaSuccess;
}

cudaError_t TorchMemorySaver::free(void *ptr) {
#if defined(USE_ROCM)
    AllocationMetadata metadata;
    {
        const std::lock_guard<std::mutex> lock(allocator_metadata_mutex_);
        SIMPLE_CHECK(allocation_metadata_.count(ptr), "Trying to free a pointer not allocated here");
        metadata = std::move(allocation_metadata_[ptr]);
        allocation_metadata_.erase(ptr);
    }

    // Unmap and release chunks
    CUDAUtils::cu_mem_unmap_and_release(metadata.device, metadata.size,
                                        (hipDeviceptr_t) ptr, metadata.allocHandles, metadata.chunk_sizes);

    // Free the reserved address range using the stored aligned_size
    CURESULT_CHECK(hipMemAddressFree((hipDeviceptr_t) ptr, metadata.aligned_size));

#ifdef TMS_DEBUG_LOG
    std::cout << "[torch_memory_saver.cpp] TorchMemorySaver.cuda_free "
              << " ptr=" << ptr << " metadata.size=" << metadata.size
              << " metadata.aligned_size=" << metadata.aligned_size
              << " num_chunks=" << metadata.allocHandles.size()
              << std::endl;
#endif
#elif defined(USE_CUDA)
    AllocationMetadata metadata;
    {
        const std::lock_guard<std::mutex> lock(allocator_metadata_mutex_);
        // Pointers we did not allocate (e.g. created outside any region) go to the real cudaFree
        if (allocation_metadata_.count(ptr) == 0) {
            return APIForwarder::call_real_cuda_free(ptr);
        }

        metadata = allocation_metadata_[ptr];
        allocation_metadata_.erase(ptr);
    }

    CURESULT_CHECK(cuMemUnmap((CUdeviceptr) ptr, metadata.size));
    CURESULT_CHECK(cuMemRelease(metadata.allocHandle));
    CURESULT_CHECK(cuMemAddressFree((CUdeviceptr) ptr, metadata.size));

    if (nullptr != metadata.cpu_backup) {
        CUDA_ERROR_CHECK(cudaFreeHost(metadata.cpu_backup));
        metadata.cpu_backup = nullptr;
    }

#ifdef TMS_DEBUG_LOG
    std::cout << "[torch_memory_saver.cpp] TorchMemorySaver.cuda_free "
              << " ptr=" << ptr << " metadata.size=" << metadata.size
              << " metadata.allocHandle=" << metadata.allocHandle << " tag=" << metadata.tag
              << std::endl;
#endif

#else
#error "USE_PLATFORM is not set"
#endif
    return cudaSuccess;
}

void TorchMemorySaver::pause(const std::string &tag) {
    const std::lock_guard<std::mutex> lock(allocator_metadata_mutex_);

#if defined(USE_ROCM)
    for (auto it = allocation_metadata_.begin(); it != allocation_metadata_.end(); ++it) {
        void *ptr = it->first;
        AllocationMetadata &metadata = it->second;

        // An empty filter tag matches every allocation
        if (!tag.empty() && metadata.tag != tag) {
            continue;
        }

        // Optionally copy the content to pinned host memory before releasing it
        if (metadata.enable_cpu_backup) {
            if (nullptr == metadata.cpu_backup) {
                CUDA_ERROR_CHECK(hipMallocHost(&metadata.cpu_backup, metadata.aligned_size));
            }
            SIMPLE_CHECK(metadata.cpu_backup != nullptr, "cpu_backup should not be nullptr");
            // TODO: may use cudaMemcpyAsync if needed
            CUDA_ERROR_CHECK(cudaMemcpy(metadata.cpu_backup, ptr, metadata.aligned_size, hipMemcpyDeviceToHost));
        }

        // Unmap and release chunks (but keep metadata for resume)
        CUDAUtils::cu_mem_unmap_and_release(metadata.device, metadata.aligned_size,
                                            (hipDeviceptr_t) ptr, metadata.allocHandles, metadata.chunk_sizes);

#ifdef TMS_DEBUG_LOG
        std::cout << "[torch_memory_saver.cpp] TorchMemorySaver.pause"
                  << " ptr=" << ptr << " metadata.size=" << metadata.size
                  << " metadata.aligned_size=" << metadata.aligned_size
                  << " num_chunks=" << metadata.allocHandles.size()
                  << std::endl;
#endif
    }
#elif defined(USE_CUDA)
    for (auto it = allocation_metadata_.begin(); it != allocation_metadata_.end(); ++it) {
        void *ptr = it->first;
        AllocationMetadata &metadata = it->second;

        // An empty filter tag matches every allocation
        if (!tag.empty() && metadata.tag != tag) {
            continue;
        }

        if (metadata.state != AllocationState::ACTIVE) {
            std::cerr << "[torch_memory_saver.cpp] Cannot pause allocation that is not active."
                      << " tag=" << metadata.tag << " ptr=" << std::to_string((uintptr_t) ptr)
                      << " file=" << __FILE__ << " func=" << __func__ << " line=" << __LINE__
                      << std::endl;
            exit(1);
        }

        // Optionally copy the content to pinned host memory before releasing it
        if (metadata.enable_cpu_backup) {
            if (nullptr == metadata.cpu_backup) {
                CUDA_ERROR_CHECK(cudaMallocHost(&metadata.cpu_backup, metadata.size));
            }
            SIMPLE_CHECK(metadata.cpu_backup != nullptr, "cpu_backup should not be nullptr");
            // TODO: may use cudaMemcpyAsync if needed
            CUDA_ERROR_CHECK(cudaMemcpy(metadata.cpu_backup, ptr, metadata.size, cudaMemcpyDeviceToHost));
        }

        // Release the physical memory but keep the virtual address reservation
        CURESULT_CHECK(cuMemUnmap((CUdeviceptr) ptr, metadata.size));
        CURESULT_CHECK(cuMemRelease(metadata.allocHandle));

        metadata.state = AllocationState::PAUSED;

#ifdef TMS_DEBUG_LOG
        std::cout << "[torch_memory_saver.cpp] TorchMemorySaver.pause"
                  << " ptr=" << ptr << " metadata.size=" << metadata.size << " metadata.allocHandle="
                  << metadata.allocHandle << " tag=" << metadata.tag << " filter_tag=" << tag
                  << " metadata.enable_cpu_backup=" << metadata.enable_cpu_backup
                  << std::endl;
#endif
    }
#else
#error "USE_PLATFORM is not set"
#endif
}

void TorchMemorySaver::resume(const std::string &tag) {
    const std::lock_guard<std::mutex> lock(allocator_metadata_mutex_);

#if defined(USE_ROCM)
    for (auto it = allocation_metadata_.begin(); it != allocation_metadata_.end(); ++it) {
        void *ptr = it->first;
        AllocationMetadata &metadata = it->second;

        if (!tag.empty() && metadata.tag != tag) {
            continue;
        }

        // Create new handles and map chunks into the preserved address range
        CUDAUtils::cu_mem_create_and_map(metadata.device, metadata.aligned_size,
                                         (hipDeviceptr_t) ptr, metadata.allocHandles, metadata.chunk_sizes);

#ifdef TMS_DEBUG_LOG
        std::cout << "[torch_memory_saver.cpp] TorchMemorySaver.resume"
                  << " ptr=" << ptr << " metadata.size=" << metadata.size
                  << " metadata.aligned_size=" << metadata.aligned_size
                  << " num_chunks=" << metadata.allocHandles.size()
                  << std::endl;
#endif
    }

#elif defined(USE_CUDA)
    for (auto it = allocation_metadata_.begin(); it != allocation_metadata_.end(); ++it) {
        void *ptr = it->first;
        AllocationMetadata &metadata = it->second;

        if (!tag.empty() && metadata.tag != tag) {
            continue;
        }

        if (metadata.state != AllocationState::PAUSED) {
            std::cerr << "[torch_memory_saver.cpp] Cannot resume allocation that is not paused."
                      << " tag=" << metadata.tag << " ptr=" << std::to_string((uintptr_t) ptr)
                      << " file=" << __FILE__ << " func=" << __func__ << " line=" << __LINE__
                      << std::endl;
            exit(1);
        }

        // Map a freshly created physical allocation into the preserved address range
        CUmemGenericAllocationHandle newAllocHandle;
        CUDAUtils::cu_mem_create(&newAllocHandle, metadata.size, metadata.device);

        CURESULT_CHECK(cuMemMap((CUdeviceptr) ptr, metadata.size, 0, newAllocHandle, 0));

        CUDAUtils::cu_mem_set_access(ptr, metadata.size, metadata.device);

        if (metadata.enable_cpu_backup) {
            SIMPLE_CHECK(metadata.cpu_backup != nullptr, "cpu_backup should not be nullptr");
            // TODO: may use cudaMemcpyAsync if needed
            CUDA_ERROR_CHECK(cudaMemcpy(ptr, metadata.cpu_backup, metadata.size, cudaMemcpyHostToDevice));
            // maybe we can free host memory if needed (currently keep it there to reduce re-alloc time)
        }

#ifdef TMS_DEBUG_LOG
        std::cout << "[torch_memory_saver.cpp] TorchMemorySaver.resume"
                  << " ptr=" << ptr << " metadata.size=" << metadata.size << " (old)metadata.allocHandle="
                  << metadata.allocHandle
                  << " (new)newAllocHandle=" << newAllocHandle << " tag=" << metadata.tag << " filter_tag=" << tag
                  << " metadata.enable_cpu_backup=" << metadata.enable_cpu_backup
                  << std::endl;
#endif

        metadata.state = AllocationState::ACTIVE;
        metadata.allocHandle = newAllocHandle;
    }
#else
#error "USE_PLATFORM is not set"
#endif
}
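The CUDA path above is what makes pause/resume address-stable: `cuMemAddressReserve` keeps the virtual range alive while `cuMemUnmap`/`cuMemRelease` drop the physical backing, and resume maps a freshly created handle into the same range. A hedged Python illustration of that invariant (assuming the singleton export used by the README examples):

```python
import torch
from torch_memory_saver import torch_memory_saver  # assumed export style

with torch_memory_saver.region():
    t = torch.ones(1 << 20, device="cuda")

addr = t.data_ptr()
torch_memory_saver.pause()   # physical pages are gone; the pointer must not be dereferenced
torch_memory_saver.resume()  # the same virtual range is re-backed by new physical memory
t.fill_(1)                   # safe again: no pointer fix-up was needed
assert t.data_ptr() == addr  # the tensor kept its virtual address throughout
```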
torch_memory_saver-0.0.9rc1/csrc/entrypoint.cpp
@@ -0,0 +1,122 @@
#include "utils.h"
#include "core.h"
#include "api_forwarder.h"
#include <optional>
#include "macro.h"

// ----------------------------------------------- threadlocal configs --------------------------------------------------

class ThreadLocalConfig {
public:
    std::string current_tag_ = "default";

    bool is_interesting_region() {
        if (!is_interesting_region_.has_value()) {
            is_interesting_region_ = get_bool_env_var("TMS_INIT_ENABLE");
        }
        return is_interesting_region_.value();
    }

    void set_interesting_region(bool value) {
        is_interesting_region_ = value;
    }

    bool enable_cpu_backup() {
        if (!enable_cpu_backup_.has_value()) {
            enable_cpu_backup_ = get_bool_env_var("TMS_INIT_ENABLE_CPU_BACKUP");
        }
        return enable_cpu_backup_.value();
    }

    void set_enable_cpu_backup(bool value) {
        enable_cpu_backup_ = value;
    }

private:
    std::optional<bool> is_interesting_region_;
    std::optional<bool> enable_cpu_backup_;
};

static thread_local ThreadLocalConfig thread_local_config;

// ------------------------------------------------- entrypoints :: hook ------------------------------------------------

#ifdef TMS_HOOK_MODE_PRELOAD
// These definitions shadow the CUDA runtime symbols when the library is injected via LD_PRELOAD
cudaError_t cudaMalloc(void **ptr, size_t size) {
    if (thread_local_config.is_interesting_region()) {
        return TorchMemorySaver::instance().malloc(
            ptr, CUDAUtils::cu_ctx_get_device(), size, thread_local_config.current_tag_, thread_local_config.enable_cpu_backup());
    } else {
        return APIForwarder::call_real_cuda_malloc(ptr, size);
    }
}

cudaError_t cudaFree(void *ptr) {
    if (thread_local_config.is_interesting_region()) {
        return TorchMemorySaver::instance().free(ptr);
    } else {
        return APIForwarder::call_real_cuda_free(ptr);
    }
}
#endif

#ifdef TMS_HOOK_MODE_TORCH
// These functions are registered through torch's custom-allocator API
extern "C" {
void *tms_torch_malloc(ssize_t size, int device, cudaStream_t stream) {
#ifdef TMS_DEBUG_LOG
    std::cout << "[torch_memory_saver.cpp] tms_torch_malloc "
              << " size=" << size << " device=" << device << " stream=" << stream
              << std::endl;
#endif
    SIMPLE_CHECK(thread_local_config.is_interesting_region(), "only support interesting region");
    void *ptr;
    TorchMemorySaver::instance().malloc(
        &ptr, CUDAUtils::cu_device_get(device), size, thread_local_config.current_tag_, thread_local_config.enable_cpu_backup());
    return ptr;
}

void tms_torch_free(void *ptr, ssize_t ssize, int device, cudaStream_t stream) {
#ifdef TMS_DEBUG_LOG
    std::cout << "[torch_memory_saver.cpp] tms_torch_free "
              << " ptr=" << ptr << " ssize=" << ssize << " device=" << device << " stream=" << stream
              << std::endl;
#endif
    SIMPLE_CHECK(thread_local_config.is_interesting_region(), "only support interesting region");
    TorchMemorySaver::instance().free(ptr);
}
}
#endif

// ------------------------------------------------- entrypoints :: others ------------------------------------------------

extern "C" {
void tms_set_interesting_region(bool is_interesting_region) {
    thread_local_config.set_interesting_region(is_interesting_region);
}

bool tms_get_interesting_region() {
    return thread_local_config.is_interesting_region();
}

void tms_set_current_tag(const char *tag) {
    SIMPLE_CHECK(tag != nullptr, "tag should not be null");
    thread_local_config.current_tag_ = tag;
}

bool tms_get_enable_cpu_backup() {
    return thread_local_config.enable_cpu_backup();
}

void tms_set_enable_cpu_backup(bool enable_cpu_backup) {
    thread_local_config.set_enable_cpu_backup(enable_cpu_backup);
}

// A null tag pauses/resumes allocations under every tag
void tms_pause(const char *tag) {
    std::string tag_str = (tag != nullptr) ? std::string(tag) : "";
    TorchMemorySaver::instance().pause(tag_str);
}

void tms_resume(const char *tag) {
    std::string tag_str = (tag != nullptr) ? std::string(tag) : "";
    TorchMemorySaver::instance().resume(tag_str);
}
}
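The extern "C" surface above is what the Python layer binds against; an illustrative ctypes sketch of how it could be consumed (the shared-library path is hypothetical, and the package's real torch_memory_saver/binary_wrapper.py may differ):

```python
import ctypes

lib = ctypes.CDLL("/path/to/torch_memory_saver.so")  # hypothetical path
lib.tms_set_interesting_region.argtypes = [ctypes.c_bool]
lib.tms_set_current_tag.argtypes = [ctypes.c_char_p]
lib.tms_pause.argtypes = [ctypes.c_char_p]
lib.tms_resume.argtypes = [ctypes.c_char_p]

lib.tms_set_current_tag(b"kv_cache")
lib.tms_set_interesting_region(True)
# ... allocate CUDA tensors here; they are recorded under the "kv_cache" tag ...
lib.tms_set_interesting_region(False)

lib.tms_pause(b"kv_cache")   # release physical memory for that tag only
lib.tms_resume(b"kv_cache")  # map fresh physical memory at the same addresses
lib.tms_pause(None)          # a null tag acts on all tags (see tms_pause above)
```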