checkpoint-engine 0.3.0rc1__py3-none-any.whl → 0.3.1rc0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registry.
@@ -0,0 +1,390 @@
import concurrent.futures
import json
import os
import pickle
from typing import TYPE_CHECKING, Any, BinaryIO

import numpy as np
import torch
from loguru import logger
from pydantic import BaseModel
from safetensors.torch import _getdtype, safe_open

from checkpoint_engine.data_types import (
    MemoryBuffer,
    ParameterMeta,
)


if TYPE_CHECKING:
    from checkpoint_engine.data_types import FileMeta

# 256-byte alignment used when flattening torch tensors into a uint8 buffer
_ALIGN_SIZE = 256


def _align_size(dtype: torch.dtype, shape: torch.Size) -> int:
    return (dtype.itemsize * shape.numel() + _ALIGN_SIZE - 1) // _ALIGN_SIZE * _ALIGN_SIZE
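
A quick worked example of the rounding above (illustrative only, not part of the module): a bfloat16 tensor of 1000 elements occupies 2000 bytes, which is padded up to the next multiple of 256.

# Illustrative sketch: how _align_size pads a tensor's byte size to 256 bytes.
import torch

dtype, shape = torch.bfloat16, torch.Size([1000])
nbytes = dtype.itemsize * shape.numel()        # 2 * 1000 = 2000 bytes
aligned = (nbytes + 256 - 1) // 256 * 256      # rounded up to 2048
assert aligned == 2048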

def _load_checkpoint_file(file_path: str) -> tuple[int, dict[str, tuple["FileMeta", torch.Tensor]]]:
    def _safetensors_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]:
        ret = {}
        with safe_open(fn, framework="pt") as f:
            for name in f.keys():  # noqa: SIM118
                weight = f.get_tensor(name)
                meta = {
                    "key": name,
                    "dtype": weight.dtype,
                    "shape": weight.shape,
                    "type": type(weight),
                    "tp_concat_dim": -1,  # safetensors does not support tp_concat_dim
                }
                ret[name] = (meta, weight)
        return ret

    # deprecated, will be removed in the future
    def _fast_np_load(fn: str) -> dict[str, tuple["FileMeta", torch.Tensor]]:
        """Load a *.npy file and return memmaps plus the related tensor metas."""

        def parse_npy_header(fin: BinaryIO) -> dict[str, Any]:
            start = fin.tell()
            major, minor = np.lib.format.read_magic(fin)
            if major == 1 and minor == 0:
                read_header_fn = np.lib.format.read_array_header_1_0
            elif major == 2 and minor == 0:
                read_header_fn = np.lib.format.read_array_header_2_0
            else:
                raise ValueError(
                    f"unknown version {major}.{minor} when parsing npy header from {fn}"
                )
            shape, is_fortran, dtype = read_header_fn(fin)
            return {
                "shape": shape,
                "is_fortran": is_fortran,
                "dtype": dtype,
                "header_length": fin.tell() - start,
            }

        meta_fn = fn + ".meta"
        with open(meta_fn, "rb") as fin:
            meta_lst = pickle.load(fin)

        tensors = []
        offset = 0
        with open(fn, "rb") as fin:
            fin.seek(0, os.SEEK_END)
            filesize = fin.tell()
            fin.seek(0)
            while fin.tell() < filesize:
                tensor_meta = parse_npy_header(fin)
                tensor = np.memmap(
                    fn,
                    dtype=tensor_meta["dtype"],
                    mode="c",
                    offset=offset + tensor_meta["header_length"],
                    shape=tensor_meta["shape"],
                )
                offset += tensor_meta["header_length"] + tensor.nbytes
                fin.seek(offset)
                tensors.append(tensor)

        assert len(meta_lst) == len(tensors)
        ret = {}
        for meta, tensor in zip(meta_lst, tensors):
            if meta["type"] == torch.Tensor:
                tensor = torch.from_numpy(tensor)
                tensor = tensor.view(dtype=meta["dtype"]).view(*meta["shape"])
            ret[meta["key"]] = (meta, tensor)
        return ret

    tp_rank = 0
    if file_path.endswith(".npy"):
        logger.warning("numpy model file is deprecated, will be removed in the future")
        filename_split = os.path.basename(file_path).split(".")
        # if using numpy and wanting to specify the tp rank,
        # the file should be named model.{layer}.{tp}[.{ep}].npy
        tp_rank = int(filename_split[2]) if len(filename_split) > 3 else 0
        ret = _fast_np_load(file_path)
    elif file_path.endswith(".safetensors"):
        ret = _safetensors_load(file_path)
    else:
        raise ValueError(f"unsupported file format: {file_path}")
    return tp_rank, ret
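
As the comment above notes, legacy numpy checkpoints encode the tensor-parallel rank in the file name. A small sketch of that convention (the paths below are hypothetical):

# Illustrative sketch of the model.{layer}.{tp}[.{ep}].npy naming convention.
import os

for path, expected_tp in [("/ckpt/model.3.1.npy", 1), ("/ckpt/model.3.npy", 0)]:
    parts = os.path.basename(path).split(".")
    tp_rank = int(parts[2]) if len(parts) > 3 else 0
    assert tp_rank == expected_tp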

def _concat_tp_weights(
    tp_weights: list[torch.Tensor], tp_concat_dim: int, tp_size: int
) -> torch.Tensor:
    """Concat tp weights using the meta info.

    If tp_concat_dim is -1, the weight is shared across tp ranks, so the first
    shard is returned as-is. Otherwise the shards are concatenated along tp_concat_dim.
    """
    if tp_concat_dim == -1:
        return tp_weights[0]
    assert tp_size == len(tp_weights)
    if len(tp_weights) == 1:
        return tp_weights[0]
    return torch.cat(tp_weights, dim=tp_concat_dim)
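
A usage sketch (illustrative only, assuming the function above is in scope): two shards concatenated along dim 0 reconstruct the full parameter, while a shared weight (tp_concat_dim == -1) keeps the first shard.

# Illustrative sketch of _concat_tp_weights with two toy shards.
import torch

shards = [torch.zeros(4, 8), torch.ones(4, 8)]
full = _concat_tp_weights(shards, tp_concat_dim=0, tp_size=2)
assert full.shape == torch.Size([8, 8])
shared = _concat_tp_weights(shards, tp_concat_dim=-1, tp_size=2)
assert shared.shape == torch.Size([4, 8])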

def _load_checkpoint(files: list[str]) -> dict[str, torch.Tensor]:
    class TPMeta(BaseModel):
        concat_dim: int
        size: int

    parameters: dict[str, torch.Tensor] = {}
    parameter_metas: dict[str, ParameterMeta] = {}
    tp_metas: dict[str, TPMeta] = {}
    parameters_with_tp: dict[str, dict[int, torch.Tensor]] = {}
    for file in files:
        tp_rank, ret = _load_checkpoint_file(file)
        for parameter_name, (meta, weight) in ret.items():
            if parameter_name not in parameters_with_tp:
                parameters_with_tp[parameter_name] = {}
            parameters_with_tp[parameter_name][tp_rank] = weight
            if parameter_name not in tp_metas:
                tp_metas[parameter_name] = TPMeta(
                    concat_dim=meta["tp_concat_dim"],
                    size=1,
                )
            if parameter_name not in parameter_metas:
                assert isinstance(meta["dtype"], torch.dtype), (
                    f"meta {meta} dtype should be torch.dtype"
                )
                assert isinstance(meta["shape"], torch.Size), (
                    f"meta {meta} shape should be torch.Size"
                )
                parameter_metas[parameter_name] = ParameterMeta(
                    name=parameter_name,
                    shape=meta["shape"],
                    dtype=meta["dtype"],
                    aligned_size=_align_size(meta["dtype"], meta["shape"]),
                )
            tp_meta = tp_metas[parameter_name]
            if tp_meta.concat_dim != -1:
                tp_meta.size = max(tp_meta.size, tp_rank + 1)
    for name, tp_meta in tp_metas.items():
        if tp_meta.concat_dim != -1:
            shape = list(parameter_metas[name].shape)
            shape[tp_meta.concat_dim] = shape[tp_meta.concat_dim] * tp_meta.size
            parameter_metas[name] = ParameterMeta(
                name=name,
                shape=torch.Size(shape),
                dtype=parameter_metas[name].dtype,
                aligned_size=_align_size(parameter_metas[name].dtype, torch.Size(shape)),
            )
        weights_in_cpu = [parameters_with_tp[name][key] for key in sorted(parameters_with_tp[name])]
        # TODO: the concat here is serial, which may be slow, but since tp-sharded
        # storage will not be used in the future, we ignore this performance issue for now.
        parameters[name] = _concat_tp_weights(weights_in_cpu, tp_meta.concat_dim, tp_meta.size)
    for name, parameter in parameters.items():
        assert name in parameter_metas, f"parameter {name} not found in parameter_metas"
        assert parameter_metas[name].shape == parameter.shape, (
            f"parameter {name} shape mismatch, {parameter_metas[name].shape} != {parameter.shape}"
        )
        assert parameter_metas[name].dtype == parameter.dtype, (
            f"parameter {name} dtype mismatch, {parameter_metas[name].dtype} != {parameter.dtype}"
        )
    return parameters
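
A usage sketch of the loader above (hypothetical paths): given the per-rank shard files of a legacy numpy checkpoint, it returns full CPU tensors keyed by parameter name.

# Illustrative sketch: merging two hypothetical tp shard files into full tensors.
files = [
    "/ckpt/model.0.0.npy",  # tp rank 0 (hypothetical path)
    "/ckpt/model.0.1.npy",  # tp rank 1 (hypothetical path)
]
parameters = _load_checkpoint(files)
for name, tensor in parameters.items():
    print(name, tuple(tensor.shape), tensor.dtype)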

def _inplace_pin_memory(files: list[str], rank: int | None = None) -> list[MemoryBuffer]:
    def _parse_and_pin_from_safetensors(file_path: str) -> MemoryBuffer:
        """
        For the safetensors format, see https://huggingface.co/docs/safetensors/en/index#format.
        We load the safetensors file as bytes, then parse the header manually to get the parameter metas.
        The actual tensor data lives in the remaining bytes and is naturally aligned.
        We pin only those remaining bytes as the buffer, which makes pinning faster.
        """

        def _pin(t: torch.Tensor):
            """
            Pin the memory of the tensor in-place.
            See: https://github.com/pytorch/pytorch/issues/32167
            """
            cudart = torch.cuda.cudart()
            r = cudart.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0)
            assert r == 0, f"pin memory error, error code: {r}"

        # TODO: should this be restricted to /dev/shm? Files on disk also seem to work.
        size = os.stat(file_path).st_size
        flag_size = 8
        t = torch.from_file(file_path, True, size, dtype=torch.uint8)
        assert t.nbytes > flag_size, (
            f"tensor nbytes {t.nbytes} should be greater than flag_size {flag_size}"
        )
        start_pos = (
            int.from_bytes(t[0:flag_size].numpy().tobytes(), byteorder="little", signed=False)
            + flag_size
        )
        header_tensor = t[flag_size:start_pos]
        header = json.loads(header_tensor.numpy().tobytes())
        if "__metadata__" in header:
            header.pop("__metadata__")

        metas: list[ParameterMeta] = []
        offset = 0
        try:
            for name, meta in sorted(header.items(), key=lambda x: x[1]["data_offsets"]):
                start, end = meta["data_offsets"]
                # safetensors format ensures offsets are aligned
                assert offset == start, f"offset {offset} should be equal to start {start}"
                metas.append(
                    ParameterMeta(
                        name=name,
                        dtype=_getdtype(meta["dtype"]),
                        shape=torch.Size(meta["shape"]),
                        aligned_size=end - start,
                    )
                )
                offset = end
        except Exception as e:
            logger.error(f"fail to parse safetensors header from {file_path}: {e}")
            raise

        buffer = t[start_pos:]
        assert offset == buffer.nbytes, (
            f"offset {offset} should be equal to buffer.nbytes {buffer.nbytes}"
        )
        # Remove the file after loading it successfully. This avoids doubling the memory usage.
        # We assume files in /dev/shm/ are temporary, so it is safe to remove them after loading.
        os.remove(file_path)
        _pin(buffer)
        logger.info(
            f"[rank{rank}] inplace pin memory for file {file_path} finished, size {buffer.nbytes / 1024 / 1024:.2f}MiB"
        )
        return MemoryBuffer(buffer=buffer, size=buffer.nbytes, metas=metas, manually_pinned=True)

    memory_buffers: list[MemoryBuffer] = []
    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
        memory_buffers = list(executor.map(_parse_and_pin_from_safetensors, files))
    return memory_buffers
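
The manual header parsing above relies on the safetensors on-disk layout: an 8-byte little-endian header length, a JSON header, then the raw tensor bytes. A standalone sketch of that layout (illustrative only; real files may pad the JSON header with spaces):

# Illustrative sketch of the layout read by _parse_and_pin_from_safetensors.
import json

header = {"w": {"dtype": "F32", "shape": [2], "data_offsets": [0, 8]}}
header_bytes = json.dumps(header).encode()
blob = len(header_bytes).to_bytes(8, "little") + header_bytes + bytes(8)

start_pos = int.from_bytes(blob[:8], "little") + 8
parsed = json.loads(blob[8:start_pos])
assert parsed["w"]["data_offsets"] == [0, 8]
assert len(blob[start_pos:]) == 8  # raw data bytes for "w"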

def _normal_pin_memory(
    files: list[str],
    named_tensors: dict[str, torch.Tensor],
    rank: int | None = None,
    shared_pin_memory: list[MemoryBuffer] | None = None,
) -> list[MemoryBuffer]:
    parameters = _load_checkpoint(files)
    if named_tensors:
        parameters.update(named_tensors)
    bucket_size = max(4 << 30, max(_align_size(x.dtype, x.shape) for x in parameters.values()))

    class MemoryBucket(BaseModel):
        size: int
        metas: list[ParameterMeta]

    buckets: list[MemoryBucket] = []
    buckets.append(MemoryBucket(size=0, metas=[]))
    for name, tensor in sorted(parameters.items()):
        size = _align_size(tensor.dtype, tensor.shape)
        if buckets[-1].size + size > bucket_size:
            assert buckets[-1].metas, f"buckets[{len(buckets) - 1}] should not be empty"
            buckets.append(MemoryBucket(size=0, metas=[]))
        buckets[-1].metas.append(
            ParameterMeta(name=name, shape=tensor.shape, dtype=tensor.dtype, aligned_size=size)
        )
        buckets[-1].size += size

    memory_buffers = [
        MemoryBuffer(buffer=torch.empty(0), size=bucket.size, metas=bucket.metas)
        for bucket in buckets
    ]

    def register_pin_memory(
        idx: int, size: int, shared_pin_memory: list[MemoryBuffer] | None = None
    ) -> tuple[int, torch.Tensor]:
        if shared_pin_memory:
            # If shared_pin_memory is provided, reuse its pinned buffer instead of allocating a new one.
            # Reusing pinned memory only supports a fixed checkpoint shape, which is registered the first time.
            assert idx < len(shared_pin_memory), (
                f"idx {idx} should be less than shared_pin_memory length {len(shared_pin_memory)}"
            )
            assert shared_pin_memory[idx].size == size, (
                f"shared_pin_memory[{idx}].size {shared_pin_memory[idx].size} should be equal to {size}"
            )
            return idx, shared_pin_memory[idx].buffer
        else:
            buffer = torch.empty(size, dtype=torch.uint8, pin_memory=True)
            return idx, buffer

    def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
        buffer[offset : offset + tensor.nbytes] = tensor.view(-1).view(dtype=torch.uint8)

    with concurrent.futures.ThreadPoolExecutor(max_workers=32) as executor:
        futures = [
            executor.submit(
                register_pin_memory,
                idx,
                bucket.size,
                shared_pin_memory,
            )
            for idx, bucket in enumerate(buckets)
        ]
        new_futures = []
        for future in concurrent.futures.as_completed(futures):
            idx, buffer = future.result()
            assert buffer.numel() == buckets[idx].size, (
                f"buffer numel {buffer.numel()} should be equal to bucket size {buckets[idx].size}"
            )
            memory_buffers[idx].buffer = buffer
            logger.info(
                f"[rank{rank}] register pin_memory for bucket {idx + 1}/{len(buckets)} finished, "
                f"size {buffer.numel() / 1024 / 1024:.2f}MiB, start to copy tensors to buffer"
            )
            offset = 0
            for meta in buckets[idx].metas:
                name = meta.name
                tensor = parameters[name]
                size = _align_size(tensor.dtype, tensor.shape)
                assert size == _align_size(meta.dtype, meta.shape), (
                    f"tensor {name} size {size} should be equal to meta size {_align_size(meta.dtype, meta.shape)}"
                )
                new_futures.append(executor.submit(register_tensor, buffer, offset, tensor))
                offset += size
        for future in concurrent.futures.as_completed(new_futures):
            future.result()
    return memory_buffers
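
The bucketing above packs parameters greedily into pinned buffers whose capacity is at least 4 GiB (or the largest aligned tensor, if that is bigger). A toy sketch of the same greedy packing with hypothetical byte sizes:

# Illustrative sketch of the greedy bucket packing used in _normal_pin_memory.
aligned_sizes = {"a": 3 << 30, "b": 2 << 30, "c": 1 << 30}  # hypothetical sizes
bucket_size = max(4 << 30, max(aligned_sizes.values()))     # 4 GiB here

buckets, current = [[]], 0
for name in sorted(aligned_sizes):
    size = aligned_sizes[name]
    if current + size > bucket_size:
        buckets.append([])
        current = 0
    buckets[-1].append(name)
    current += size
assert buckets == [["a"], ["b", "c"]]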

def _register_checkpoint(
    *,
    files: list[str],
    named_tensors: dict[str, torch.Tensor],
    rank: int | None = None,
    shared_pin_memory: list[MemoryBuffer] | None = None,
    inplace_pin: bool = False,
) -> list[MemoryBuffer]:
    logger.info(
        f"[rank{rank}] start to register checkpoint with {len(files)} files and {len(named_tensors)} named_tensors"
    )
    if not files and not named_tensors:
        return []
    memory_buffers: list[MemoryBuffer] = []
    if inplace_pin:
        logger.info(f"[rank{rank}] allow inplace pin memory for /dev/shm/ safetensors files")
        files_to_inplace_pin = [
            file
            for file in files
            if file.startswith("/dev/shm/") and file.endswith(".safetensors")  # noqa: S108
        ]
        files_to_normal_pin = [file for file in files if file not in files_to_inplace_pin]
    else:
        files_to_normal_pin = files
        files_to_inplace_pin = []
    if files_to_normal_pin or named_tensors:
        memory_buffers.extend(
            _normal_pin_memory(
                files=files_to_normal_pin,
                named_tensors=named_tensors,
                rank=rank,
                shared_pin_memory=shared_pin_memory,
            )
        )
    if files_to_inplace_pin:
        memory_buffers.extend(_inplace_pin_memory(files_to_inplace_pin, rank=rank))
    return memory_buffers
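
Finally, a hypothetical call sketch for the entry point above: with inplace_pin=True, safetensors files under /dev/shm/ are pinned in place (and then deleted), while everything else is copied into freshly allocated pinned buckets.

# Illustrative sketch: a hypothetical _register_checkpoint call (paths are made up).
buffers = _register_checkpoint(
    files=[
        "/dev/shm/model-00001-of-00002.safetensors",  # pinned in place
        "/data/model-00002-of-00002.safetensors",     # copied into pinned buckets
    ],
    named_tensors={},
    rank=0,
    inplace_pin=True,
)
print(sum(b.size for b in buffers), "bytes registered")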