checkpoint-engine 0.3.0rc1__py3-none-any.whl → 0.3.1rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checkpoint_engine/__init__.py +36 -0
- checkpoint_engine/__main__.py +28 -0
- checkpoint_engine/_version.py +2 -2
- checkpoint_engine/api.py +95 -0
- checkpoint_engine/data_types.py +111 -0
- checkpoint_engine/p2p_store.py +210 -0
- checkpoint_engine/pin_memory.py +390 -0
- checkpoint_engine/ps.py +23 -797
- checkpoint_engine/worker.py +18 -9
- {checkpoint_engine-0.3.0rc1.dist-info → checkpoint_engine-0.3.1rc0.dist-info}/METADATA +1 -1
- checkpoint_engine-0.3.1rc0.dist-info/RECORD +15 -0
- checkpoint_engine-0.3.0rc1.dist-info/RECORD +0 -10
- {checkpoint_engine-0.3.0rc1.dist-info → checkpoint_engine-0.3.1rc0.dist-info}/WHEEL +0 -0
- {checkpoint_engine-0.3.0rc1.dist-info → checkpoint_engine-0.3.1rc0.dist-info}/licenses/LICENCE +0 -0
- {checkpoint_engine-0.3.0rc1.dist-info → checkpoint_engine-0.3.1rc0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,390 @@
|
|
|
1
|
+
import concurrent.futures
|
|
2
|
+
import json
|
|
3
|
+
import os
|
|
4
|
+
import pickle
|
|
5
|
+
from typing import TYPE_CHECKING, Any, BinaryIO
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import torch
|
|
9
|
+
from loguru import logger
|
|
10
|
+
from pydantic import BaseModel
|
|
11
|
+
from safetensors.torch import _getdtype, safe_open
|
|
12
|
+
|
|
13
|
+
from checkpoint_engine.data_types import (
|
|
14
|
+
MemoryBuffer,
|
|
15
|
+
ParameterMeta,
|
|
16
|
+
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
if TYPE_CHECKING:
|
|
20
|
+
from checkpoint_engine.data_types import FileMeta
|
|
21
|
+
|
|
22
|
+
# Byte alignment used when tensors are flattened into uint8 buffers:
# every tensor slot is rounded up to a multiple of this (256 bytes).
_ALIGN_SIZE: int = 1 << 8
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def _align_size(dtype: torch.dtype, shape: torch.Size) -> int:
    """Return the byte size of a ``dtype``/``shape`` tensor rounded up to ``_ALIGN_SIZE``.

    A tensor with zero elements occupies 0 bytes.
    """
    raw_nbytes = shape.numel() * dtype.itemsize
    # Ceil division via negated floor division: number of _ALIGN_SIZE blocks needed.
    num_blocks = -(-raw_nbytes // _ALIGN_SIZE)
    return num_blocks * _ALIGN_SIZE
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple["FileMeta", torch.Tensor]]]:
|
|
31
|
+
def _safetensors_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]:
|
|
32
|
+
ret = {}
|
|
33
|
+
with safe_open(fn, framework="pt") as f:
|
|
34
|
+
for name in f.keys(): # noqa: SIM118
|
|
35
|
+
weight = f.get_tensor(name)
|
|
36
|
+
meta = {
|
|
37
|
+
"key": name,
|
|
38
|
+
"dtype": weight.dtype,
|
|
39
|
+
"shape": weight.shape,
|
|
40
|
+
"type": type(weight),
|
|
41
|
+
"tp_concat_dim": -1, # safetensors does not support tp_concat_dim
|
|
42
|
+
}
|
|
43
|
+
ret[name] = (meta, weight)
|
|
44
|
+
return ret
|
|
45
|
+
|
|
46
|
+
# deprecated, will be removed in the future
|
|
47
|
+
def _fast_np_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]:
|
|
48
|
+
"""load *.np file and return memmap and related tensor meta"""
|
|
49
|
+
|
|
50
|
+
def parse_npy_header(fin: BinaryIO) -> dict[str, Any]:
|
|
51
|
+
start = fin.tell()
|
|
52
|
+
major, minor = np.lib.format.read_magic(fin)
|
|
53
|
+
if major == 1 and minor == 0:
|
|
54
|
+
read_header_fn = np.lib.format.read_array_header_1_0
|
|
55
|
+
elif major == 2 and minor == 0:
|
|
56
|
+
read_header_fn = np.lib.format.read_array_header_2_0
|
|
57
|
+
else:
|
|
58
|
+
raise ValueError(
|
|
59
|
+
f"unknown version {major}.{minor} when parsing npy header from {fn}"
|
|
60
|
+
)
|
|
61
|
+
shape, is_fortran, dtype = read_header_fn(fin)
|
|
62
|
+
return {
|
|
63
|
+
"shape": shape,
|
|
64
|
+
"is_fortran": is_fortran,
|
|
65
|
+
"dtype": dtype,
|
|
66
|
+
"header_length": fin.tell() - start,
|
|
67
|
+
}
|
|
68
|
+
|
|
69
|
+
meta_fn = fn + ".meta"
|
|
70
|
+
with open(meta_fn, "rb") as fin:
|
|
71
|
+
meta_lst = pickle.load(fin)
|
|
72
|
+
|
|
73
|
+
tensors = []
|
|
74
|
+
offset = 0
|
|
75
|
+
with open(fn, "rb") as fin:
|
|
76
|
+
fin.seek(0, os.SEEK_END)
|
|
77
|
+
filesize = fin.tell()
|
|
78
|
+
fin.seek(0)
|
|
79
|
+
while fin.tell() < filesize:
|
|
80
|
+
tensor_meta = parse_npy_header(fin)
|
|
81
|
+
tensor = np.memmap(
|
|
82
|
+
fn,
|
|
83
|
+
dtype=tensor_meta["dtype"],
|
|
84
|
+
mode="c",
|
|
85
|
+
offset=offset + tensor_meta["header_length"],
|
|
86
|
+
shape=tensor_meta["shape"],
|
|
87
|
+
)
|
|
88
|
+
offset += tensor_meta["header_length"] + tensor.nbytes
|
|
89
|
+
fin.seek(offset)
|
|
90
|
+
tensors.append(tensor)
|
|
91
|
+
|
|
92
|
+
assert len(meta_lst) == len(tensors)
|
|
93
|
+
ret = {}
|
|
94
|
+
for meta, tensor in zip(meta_lst, tensors):
|
|
95
|
+
if meta["type"] == torch.Tensor:
|
|
96
|
+
tensor = torch.from_numpy(tensor)
|
|
97
|
+
tensor = tensor.view(dtype=meta["dtype"]).view(*meta["shape"])
|
|
98
|
+
ret[meta["key"]] = (meta, tensor)
|
|
99
|
+
return ret
|
|
100
|
+
|
|
101
|
+
tp_rank = 0
|
|
102
|
+
if file_path.endswith(".npy"):
|
|
103
|
+
logger.warning("numpy model file is deprecated, will be removed in the future")
|
|
104
|
+
filename_split = os.path.basename(file_path).split(".")
|
|
105
|
+
# if using numpy and want to specify tp rank
|
|
106
|
+
# file should be in model.{layer}.{tp}[.{ep}].npy format
|
|
107
|
+
tp_rank = int(filename_split[2]) if len(filename_split) > 3 else 0
|
|
108
|
+
ret = _fast_np_load(file_path)
|
|
109
|
+
elif file_path.endswith(".safetensors"):
|
|
110
|
+
ret = _safetensors_load(file_path)
|
|
111
|
+
else:
|
|
112
|
+
raise ValueError(f"unsupported file format: {file_path}")
|
|
113
|
+
return tp_rank, ret
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def _concat_tp_weights(
    tp_weights: list[torch.Tensor], tp_concat_dim: int, tp_size: int
) -> torch.Tensor:
    """Merge the tensor-parallel shards of one parameter.

    Args:
        tp_weights: the parameter's shards. The caller is expected to pass them
            in tp-rank order.
        tp_concat_dim: dim to concatenate along. ``-1`` marks a weight shared by
            all tp ranks, in which case the first shard is returned unchanged.
            Because of this, ``-1`` cannot mean "last dim" here.
        tp_size: expected number of shards when ``tp_concat_dim != -1``.

    Returns:
        The first shard for shared weights or for a single shard. Otherwise,
        the concatenation of all shards along ``tp_concat_dim``.

    Raises:
        AssertionError: if ``tp_concat_dim != -1`` and ``len(tp_weights) != tp_size``
            (the check is skipped under ``python -O``).
    """
    if tp_concat_dim == -1:
        return tp_weights[0]
    assert tp_size == len(tp_weights), (
        f"expected {tp_size} tp shards, got {len(tp_weights)}"
    )
    if len(tp_weights) == 1:
        return tp_weights[0]
    # torch.cat accepts the sequence directly; no need to copy it into a new list.
    return torch.cat(tp_weights, dim=tp_concat_dim)
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
    """Load checkpoint files and merge tensor-parallel shards into full tensors.

    Each file is read with ``_load_checkpoint_file`` and its tensors are grouped
    per parameter by tp rank.

    - Parameters whose ``tp_concat_dim`` is not ``-1`` are concatenated along that
      dim, in ascending rank order. Their expected size along that dim is scaled by
      the highest rank seen + 1.
    - All other parameters keep the shard of the lowest rank.

    Every merged tensor is then checked against the expected shape and dtype.

    Args:
        files: checkpoint file paths accepted by ``_load_checkpoint_file``.

    Returns:
        Mapping from parameter name to the merged tensor.
    """

    class TPMeta(BaseModel):
        concat_dim: int
        size: int

    shards: dict[str, dict[int, torch.Tensor]] = {}
    tp_info: dict[str, TPMeta] = {}
    expected: dict[str, ParameterMeta] = {}

    for path in files:
        rank, loaded = _load_checkpoint_file(path)
        for pname, (meta, tensor) in loaded.items():
            shards.setdefault(pname, {})[rank] = tensor
            if pname not in tp_info:
                tp_info[pname] = TPMeta(concat_dim=meta["tp_concat_dim"], size=1)
            if pname not in expected:
                assert isinstance(meta["dtype"], torch.dtype), (
                    f"meta {meta} dtype should be torch.dtype"
                )
                assert isinstance(meta["shape"], torch.Size), (
                    f"meta {meta} shape should be torch.Size"
                )
                expected[pname] = ParameterMeta(
                    name=pname,
                    shape=meta["shape"],
                    dtype=meta["dtype"],
                    aligned_size=_align_size(meta["dtype"], meta["shape"]),
                )
            info = tp_info[pname]
            if info.concat_dim != -1:
                # number of shards = highest tp rank seen so far + 1
                info.size = max(info.size, rank + 1)

    merged: dict[str, torch.Tensor] = {}
    for pname, info in tp_info.items():
        if info.concat_dim != -1:
            # Widen the expected shape along the concat dim to the full (unsharded) size.
            full_dims = list(expected[pname].shape)
            full_dims[info.concat_dim] *= info.size
            full_shape = torch.Size(full_dims)
            dtype = expected[pname].dtype
            expected[pname] = ParameterMeta(
                name=pname,
                shape=full_shape,
                dtype=dtype,
                aligned_size=_align_size(dtype, full_shape),
            )
        by_rank = shards[pname]
        ordered = [by_rank[r] for r in sorted(by_rank)]
        # TODO: here concat is serial, which may be slow
        # but since tp storage is not used in the future
        # we ignore this performance issue for now
        merged[pname] = _concat_tp_weights(ordered, info.concat_dim, info.size)

    for pname, tensor in merged.items():
        assert pname in expected, f"parameter {pname} not found in parameter_metas"
        want = expected[pname]
        assert want.shape == tensor.shape, (
            f"parameter {pname} shape mismatch, {want.shape} != {tensor.shape}"
        )
        assert want.dtype == tensor.dtype, (
            f"parameter {pname} dtype mismatch, {want.dtype} != {tensor.dtype}"
        )
    return merged
|
|
191
|
+
|
|
192
|
+
|
|
193
|
+
def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[MemoryBuffer]:
    """Pin safetensors files in place, without copying their tensor data.

    Each file goes through the following steps:

    - It is mapped with ``torch.from_file(path, True, size, dtype=torch.uint8)``.
    - Its JSON header is parsed by hand to build the ``ParameterMeta`` list.
    - The file is removed from the filesystem.
    - The byte range after the header is registered with ``cudaHostRegister``.

    Files are processed on a pool of 32 threads.

    Args:
        files: safetensors file paths. Each one is deleted after it is parsed.
        rank: only used to tag log messages.

    Returns:
        One ``MemoryBuffer`` per file, in the order of ``files``.
    """

    def _register_host_memory(region: torch.Tensor) -> None:
        # Pin an existing allocation in place.
        # See: https://github.com/pytorch/pytorch/issues/32167
        err = torch.cuda.cudart().cudaHostRegister(
            region.data_ptr(), region.numel() * region.element_size(), 0
        )
        assert err == 0, f"pin memory error, error code: {err}"

    def _parse_and_pin_from_safetensors(file_path: str) -> MemoryBuffer:
        # Layout (https://huggingface.co/docs/safetensors/en/index#format):
        #   [8-byte little-endian header length N][N bytes JSON header][tensor data]
        # Only the tensor-data region is pinned, which keeps pinning fast.
        # TODO: should only support /dev/shm? but we found files in disk also work?
        flag_size = 8
        mapped = torch.from_file(file_path, True, os.stat(file_path).st_size, dtype=torch.uint8)
        assert mapped.nbytes > flag_size, (
            f"tensor nbytes {mapped.nbytes} should be greater than flag_size {flag_size}"
        )
        header_len = int.from_bytes(
            mapped[:flag_size].numpy().tobytes(), byteorder="little", signed=False
        )
        data_start = flag_size + header_len
        header = json.loads(mapped[flag_size:data_start].numpy().tobytes())
        header.pop("__metadata__", None)

        metas: list[ParameterMeta] = []
        cursor = 0
        try:
            for name, entry in sorted(header.items(), key=lambda kv: kv[1]["data_offsets"]):
                start, end = entry["data_offsets"]
                # safetensors format ensures offsets are aligned
                assert cursor == start, f"offset {cursor} should be equal to start {start}"
                metas.append(
                    ParameterMeta(
                        name=name,
                        dtype=_getdtype(entry["dtype"]),
                        shape=torch.Size(entry["shape"]),
                        aligned_size=end - start,
                    )
                )
                cursor = end
        except Exception as e:
            logger.error(f"fail to parse safetensors header from {file_path}: {e}")
            raise

        buffer = mapped[data_start:]
        assert cursor == buffer.nbytes, (
            f"offset {cursor} should be equal to buffer.nbytes {buffer.nbytes}"
        )
        # Remove the file after successfully loading. This will avoid doubling the memory usage.
        # We assume files in /dev/shm/ are temporary files. So it's safe to remove them after loading.
        os.remove(file_path)
        _register_host_memory(buffer)
        logger.info(
            f"[rank{rank}] inplace pin memory for file {file_path} finished, size {buffer.nbytes / 1024 / 1024:.2f}MiB"
        )
        return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas, manually_pinned=True)

    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as pool:
        return list(pool.map(_parse_and_pin_from_safetensors, files))
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def _normal_pin_memory(
    files: list[str],
    named_tensors: dict[str, torch.Tensor],
    rank: int | None = None,
    shared_pin_memory: list[MemoryBuffer] | None = None,
) -> list[MemoryBuffer]:
    """Load tensors and copy them into pinned, bucketed uint8 buffers.

    Tensors come from two sources:

    - ``files``, merged by ``_load_checkpoint``;
    - ``named_tensors``, which override same-named file tensors.

    They are sorted by name and packed into buckets of at most
    ``max(4 GiB, largest aligned tensor)`` bytes. Each tensor occupies
    ``_align_size`` bytes. Every bucket gets a pinned buffer, which is either
    newly allocated or ``shared_pin_memory[idx]`` when provided. Tensors are
    then copied into the buffers on a 32-thread pool.

    Args:
        files: checkpoint files to load.
        named_tensors: extra tensors keyed by parameter name.
        rank: only used to tag log messages.
        shared_pin_memory: previously registered buffers to reuse. It must hold an
            entry of the same size for every bucket index.

    Returns:
        One ``MemoryBuffer`` per bucket, or ``[]`` when there is no tensor to pin.
    """
    parameters = _load_checkpoint(files)
    if named_tensors:
        parameters.update(named_tensors)
    if not parameters:
        # Without this guard, max() below raises on an empty sequence.
        logger.warning(f"[rank{rank}] no tensors to pin from {len(files)} files")
        return []
    # At least 4 GiB, and never smaller than the largest single tensor, so every
    # tensor fits in one bucket.
    bucket_size = max(4 << 30, max(_align_size(x.dtype, x.shape) for x in parameters.values()))

    class MemoryBucket(BaseModel):
        size: int
        metas: list[ParameterMeta]

    buckets: list[MemoryBucket] = [MemoryBucket(size=0, metas=[])]
    for name, tensor in sorted(parameters.items()):
        size = _align_size(tensor.dtype, tensor.shape)
        if buckets[-1].size + size > bucket_size:
            # Check the bucket's contents, not the model object (which is always truthy).
            assert buckets[-1].metas, f"buckets[{len(buckets) - 1}] should not be empty"
            buckets.append(MemoryBucket(size=0, metas=[]))
        buckets[-1].metas.append(
            ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype, aligned_size=size)
        )
        buckets[-1].size += size

    memory_buffers = [
        MemoryBuffer(buffer=torch.empty(0), size=bucket.size, metas=bucket.metas)
        for bucket in buckets
    ]

    def register_pin_memory(
        idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None
    ) -> tuple[int, torch.Tensor]:
        if shared_pin_memory:
            # If shared_pin_memory is provided, reuse the pin memory buffer, do not allocate new one
            # Reusing pin memory only support fixed shape of checkpoints, which is registered the first time
            assert idx < len(shared_pin_memory), (
                f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}"
            )
            assert shared_pin_memory[idx].size == size, (
                f"shared_pin_memory[{idx}].size {shared_pin_memory[idx].size} should be equal to {size}"
            )
            return idx, shared_pin_memory[idx].buffer
        else:
            buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
            return idx, buffer

    def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
        # reshape (not view): a non-contiguous tensor (e.g. from named_tensors) is
        # flattened through a copy instead of raising; contiguous tensors are not copied.
        buffer[offset : offset + tensor.nbytes] = tensor.reshape(-1).view(dtype=torch.uint8)

    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
        futures = [
            executor.submit(
                register_pin_memory,
                idx,
                bucket.size,
                shared_pin_memory,
            )
            for idx, bucket in enumerate(buckets)
        ]
        new_futures = []
        # As soon as a bucket's buffer is ready, schedule the copies of its tensors.
        for future in concurrent.futures.as_completed(futures):
            idx, buffer = future.result()
            assert buffer.numel() == buckets[idx].size, (
                f"buffer numel {buffer.numel()} should be equal to bucket size {buckets[idx].size}"
            )
            memory_buffers[idx].buffer = buffer
            logger.info(
                f"[rank{rank}] register pin_memory for bucket {idx + 1}/{len(buckets)} finished, "
                f"size {buffer.numel() / 1024 / 1024:.2f}MiB, start to copy tensors to buffer"
            )
            offset = 0
            for meta in buckets[idx].metas:
                name = meta.name
                tensor = parameters[name]
                size = _align_size(tensor.dtype, tensor.shape)
                assert size == _align_size(meta.dtype, meta.shape), (
                    f"tensor {name} size {size} should be equal to meta size {_align_size(meta.dtype, meta.shape)}"
                )
                new_futures.append(executor.submit(register_tensor, buffer, offset, tensor))
                offset += size
        for future in concurrent.futures.as_completed(new_futures):
            future.result()
    return memory_buffers
|
|
352
|
+
|
|
353
|
+
|
|
354
|
+
def _register_checkpoint(
    *,
    files: list[str],
    named_tensors: dict[str, torch.Tensor],
    rank: int | None = None,
    shared_pin_memory: list[MemoryBuffer] | None = None,
    inplace_pin: bool = False,
) -> list[MemoryBuffer]:
    """Pin a checkpoint's tensors in host memory and return the resulting buffers.

    When ``inplace_pin`` is set, ``.safetensors`` files under ``/dev/shm/`` are
    pinned in place by ``_inplace_pin_memory``, which deletes those files. Every
    other file, plus ``named_tensors``, is copied into pinned buckets by
    ``_normal_pin_memory``.

    Args:
        files: checkpoint file paths.
        named_tensors: extra tensors keyed by parameter name.
        rank: only used to tag log messages.
        shared_pin_memory: buffers to reuse; only used for the copied (normal) path.
        inplace_pin: allow in-place pinning of ``/dev/shm/`` safetensors files.

    Returns:
        Copied buffers first, then in-place buffers. Returns ``[]`` when there are
        neither files nor named tensors.
    """
    logger.info(
        f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
    )
    if not files and not named_tensors:
        return []
    memory_buffers: list[MemoryBuffer] = []
    if inplace_pin:
        logger.info(f"[rank{rank}] allow inplace pin memory for /dev/shm/ safetensors files")
        # Single order-preserving pass; avoids an O(n) list lookup per file.
        files_to_inplace_pin: list[str] = []
        files_to_normal_pin: list[str] = []
        for file in files:
            if file.startswith("/dev/shm/") and file.endswith(".safetensors"):  # noqa: S108
                files_to_inplace_pin.append(file)
            else:
                files_to_normal_pin.append(file)
    else:
        files_to_normal_pin = files
        files_to_inplace_pin = []
    if files_to_normal_pin or named_tensors:
        memory_buffers.extend(
            _normal_pin_memory(
                files=files_to_normal_pin,
                named_tensors=named_tensors,
                rank=rank,
                shared_pin_memory=shared_pin_memory,
            )
        )
    if files_to_inplace_pin:
        memory_buffers.extend(_inplace_pin_memory(files_to_inplace_pin, rank=rank))
    return memory_buffers
|