ckptkit 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ckptkit/__init__.py +112 -0
- ckptkit/_types.py +164 -0
- ckptkit/cli.py +163 -0
- ckptkit/convert.py +366 -0
- ckptkit/diff.py +233 -0
- ckptkit/estimator.py +193 -0
- ckptkit/gguf.py +348 -0
- ckptkit/inspect.py +192 -0
- ckptkit/merge.py +116 -0
- ckptkit/metadata.py +265 -0
- ckptkit/py.typed +0 -0
- ckptkit/stats.py +148 -0
- ckptkit/validate.py +141 -0
- ckptkit-0.3.0.dist-info/METADATA +223 -0
- ckptkit-0.3.0.dist-info/RECORD +18 -0
- ckptkit-0.3.0.dist-info/WHEEL +4 -0
- ckptkit-0.3.0.dist-info/entry_points.txt +2 -0
- ckptkit-0.3.0.dist-info/licenses/LICENSE +177 -0
ckptkit/__init__.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
"""ckptkit — inspect, convert, diff, and merge model checkpoints."""
|
|
2
|
+
|
|
3
|
+
from ckptkit._types import (
|
|
4
|
+
DTYPE_SIZES,
|
|
5
|
+
CheckpointFormat,
|
|
6
|
+
CheckpointInfo,
|
|
7
|
+
CkptError,
|
|
8
|
+
DiffEntry,
|
|
9
|
+
DiffResult,
|
|
10
|
+
DType,
|
|
11
|
+
FormatError,
|
|
12
|
+
MergeConfig,
|
|
13
|
+
MergeError,
|
|
14
|
+
TensorInfo,
|
|
15
|
+
)
|
|
16
|
+
from ckptkit.convert import (
|
|
17
|
+
ConversionConfig,
|
|
18
|
+
ConversionFormat,
|
|
19
|
+
ConversionResult,
|
|
20
|
+
FormatConverter,
|
|
21
|
+
convert_checkpoint,
|
|
22
|
+
)
|
|
23
|
+
from ckptkit.diff import diff, diff_infos, format_diff, format_diff_rich, format_diff_table
|
|
24
|
+
from ckptkit.estimator import (
|
|
25
|
+
EstimationResult,
|
|
26
|
+
QuantEstimationResult,
|
|
27
|
+
TensorEstimate,
|
|
28
|
+
estimate_quantized_size,
|
|
29
|
+
estimate_reduction,
|
|
30
|
+
format_estimation,
|
|
31
|
+
)
|
|
32
|
+
from ckptkit.gguf import GGUFInfo, format_gguf_info, inspect_gguf, parse_gguf
|
|
33
|
+
from ckptkit.inspect import (
|
|
34
|
+
detect_format,
|
|
35
|
+
format_params,
|
|
36
|
+
format_size,
|
|
37
|
+
inspect,
|
|
38
|
+
inspect_safetensors,
|
|
39
|
+
)
|
|
40
|
+
from ckptkit.merge import find_lora_pairs, merge_lora_state_dicts
|
|
41
|
+
from ckptkit.metadata import (
|
|
42
|
+
CheckpointMetadata,
|
|
43
|
+
MetadataEditor,
|
|
44
|
+
extract_metadata_from_path,
|
|
45
|
+
format_metadata_report,
|
|
46
|
+
)
|
|
47
|
+
from ckptkit.stats import CheckpointStats, TensorStats, stats_from_info
|
|
48
|
+
from ckptkit.validate import ValidationIssue, ValidationResult, validate
|
|
49
|
+
|
|
50
|
+
__version__ = "0.3.0"

# Names exported by ``from ckptkit import *``. Every entry other than
# ``__version__`` is imported at the top of this module; the comment headers
# group the names by the submodule they are imported from.
__all__ = [
    "__version__",
    # Types (ckptkit._types)
    "CheckpointFormat",
    "CheckpointInfo",
    "CkptError",
    "DType",
    "DTYPE_SIZES",
    "DiffEntry",
    "DiffResult",
    "FormatError",
    "MergeConfig",
    "MergeError",
    "TensorInfo",
    # Inspect (ckptkit.inspect)
    "detect_format",
    "inspect",
    "inspect_safetensors",
    "format_size",
    "format_params",
    # Diff (ckptkit.diff)
    "diff",
    "diff_infos",
    "format_diff",
    "format_diff_rich",
    "format_diff_table",
    # Merge (ckptkit.merge)
    "merge_lora_state_dicts",
    "find_lora_pairs",
    # Stats (ckptkit.stats)
    "CheckpointStats",
    "TensorStats",
    "stats_from_info",
    # Validate (ckptkit.validate)
    "validate",
    "ValidationResult",
    "ValidationIssue",
    # Estimator (ckptkit.estimator)
    "EstimationResult",
    "QuantEstimationResult",
    "TensorEstimate",
    "estimate_reduction",
    "estimate_quantized_size",
    "format_estimation",
    # Convert (ckptkit.convert)
    "ConversionConfig",
    "ConversionFormat",
    "ConversionResult",
    "FormatConverter",
    "convert_checkpoint",
    # GGUF (ckptkit.gguf)
    "GGUFInfo",
    "parse_gguf",
    "inspect_gguf",
    "format_gguf_info",
    # Metadata (ckptkit.metadata)
    "CheckpointMetadata",
    "MetadataEditor",
    "extract_metadata_from_path",
    "format_metadata_report",
]
|
ckptkit/_types.py
ADDED
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
"""Core types for ckptkit."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from enum import Enum
|
|
7
|
+
from typing import Any
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class CheckpointFormat(str, Enum):
    """Identifiers for the checkpoint formats ckptkit distinguishes.

    The enum mixes in ``str``, so every member compares equal to its
    lowercase value (``CheckpointFormat.NUMPY == "numpy"``) and can be looked
    up from that value (``CheckpointFormat("numpy")``).
    """

    SAFETENSORS = "safetensors"
    PYTORCH = "pytorch"
    NUMPY = "numpy"
    UNKNOWN = "unknown"
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class DType(str, Enum):
    """Tensor element types, named by their upper-case short code.

    Each member's value is identical to its name, and because the enum mixes
    in ``str`` a member compares equal to that code (``DType.F16 == "F16"``).
    """

    # Floating point
    F32 = "F32"
    F16 = "F16"
    BF16 = "BF16"
    F64 = "F64"
    # Signed integers
    I64 = "I64"
    I32 = "I32"
    I16 = "I16"
    I8 = "I8"
    # Unsigned integer / boolean
    U8 = "U8"
    BOOL = "BOOL"
    # Anything not covered above
    UNKNOWN = "UNKNOWN"
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
# Storage width, in bytes, of one element of each dtype.
# DType.UNKNOWN is intentionally absent from the table.
DTYPE_SIZES: dict[DType, int] = {
    DType.F32: 4,
    DType.F16: 2,
    DType.BF16: 2,
    DType.F64: 8,
    DType.I64: 8,
    DType.I32: 4,
    DType.I16: 2,
    DType.I8: 1,
    DType.U8: 1,
    DType.BOOL: 1,
}
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
@dataclass
class TensorInfo:
    """Description of one tensor stored in a checkpoint.

    Attributes:
        name: Tensor name.
        shape: Dimension sizes; an empty list gives ``numel == 1``.
        dtype: Element type of the tensor.
        offset_start: Start offset of the tensor (defaults to 0).
        offset_end: End offset of the tensor (defaults to 0).
    """

    name: str
    shape: list[int]
    dtype: DType
    offset_start: int = 0
    offset_end: int = 0

    @property
    def numel(self) -> int:
        """Element count: the product of all dimension sizes."""
        count = 1
        for dim in self.shape:
            count = count * dim
        return count

    @property
    def size_bytes(self) -> int:
        """Byte size: ``numel`` times the per-element width of ``dtype``.

        Dtypes missing from ``DTYPE_SIZES`` contribute a width of 0.
        """
        bytes_per_element = DTYPE_SIZES.get(self.dtype, 0)
        return self.numel * bytes_per_element

    @property
    def shape_str(self) -> str:
        """Dimension sizes joined with ``×`` (empty string for an empty shape)."""
        dims = [str(dim) for dim in self.shape]
        return "×".join(dims)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
@dataclass
class CheckpointInfo:
    """Description of a whole checkpoint file and the tensors it holds.

    Attributes:
        path: Location of the checkpoint.
        format: Checkpoint format.
        file_size: Size of the file.
        tensors: Per-tensor descriptions.
        metadata: Free-form metadata mapping (empty by default).
    """

    path: str
    format: CheckpointFormat
    file_size: int
    tensors: list[TensorInfo]
    metadata: dict[str, Any] = field(default_factory=dict)

    @property
    def n_tensors(self) -> int:
        """Number of tensors in the checkpoint."""
        return len(self.tensors)

    @property
    def n_parameters(self) -> int:
        """Sum of ``numel`` over all tensors."""
        total = 0
        for tensor in self.tensors:
            total += tensor.numel
        return total

    @property
    def total_bytes(self) -> int:
        """Sum of ``size_bytes`` over all tensors."""
        total = 0
        for tensor in self.tensors:
            total += tensor.size_bytes
        return total

    def dtype_summary(self) -> dict[str, int]:
        """Return the element count per dtype value, in first-seen order."""
        totals: dict[str, int] = {}
        for tensor in self.tensors:
            label = tensor.dtype.value
            if label in totals:
                totals[label] += tensor.numel
            else:
                totals[label] = tensor.numel
        return totals

    def layer_groups(self) -> dict[str, list[TensorInfo]]:
        """Bucket tensors by the first three dot-separated name components.

        For example ``model.layers.0.attn.weight`` lands in the
        ``model.layers.0`` bucket. Names with three components or fewer form
        their own bucket, because joining all of their components simply
        reproduces the full name.
        """
        buckets: dict[str, list[TensorInfo]] = {}
        for tensor in self.tensors:
            prefix = ".".join(tensor.name.split(".")[:3])
            if prefix not in buckets:
                buckets[prefix] = []
            buckets[prefix].append(tensor)
        return buckets
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
@dataclass
class DiffEntry:
    """One recorded difference between two checkpoints.

    Attributes:
        tensor_name: Name of the tensor the difference refers to.
        change_type: Kind of difference (see the values listed on the field).
        details: Optional free-text description; empty by default.
    """

    tensor_name: str
    # One of: "added", "removed", "shape_changed", "dtype_changed", "values_changed"
    change_type: str
    details: str = ""
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
@dataclass
class DiffResult:
    """Outcome of comparing checkpoint ``path_a`` against ``path_b``.

    Attributes:
        path_a: First checkpoint path.
        path_b: Second checkpoint path.
        entries: Recorded differences.
        n_shared: Shared-tensor count (defaults to 0).
        n_identical: Identical-tensor count (defaults to 0).
    """

    path_a: str
    path_b: str
    entries: list[DiffEntry]
    n_shared: int = 0
    n_identical: int = 0

    @property
    def n_changes(self) -> int:
        """Number of recorded differences."""
        return len(self.entries)

    @property
    def has_changes(self) -> bool:
        """``True`` when at least one difference was recorded."""
        return self.n_changes > 0
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
@dataclass
class MergeConfig:
    """Settings for merging a LoRA adapter into a base checkpoint.

    Attributes:
        base_path: Path of the base checkpoint.
        adapter_path: Path of the adapter checkpoint.
        output_path: Path the merged result is written to.
        alpha: Merge scale; must lie in the closed range [0.0, 2.0].
        device: Device name, ``"cpu"`` by default.

    Raises:
        ValueError: If ``alpha`` is outside [0.0, 2.0].
    """

    base_path: str
    adapter_path: str
    output_path: str
    alpha: float = 1.0
    device: str = "cpu"

    def __post_init__(self) -> None:
        """Reject an ``alpha`` outside the supported range."""
        # Keep the chained comparison: it is False for NaN, so NaN is rejected too.
        alpha_in_range = 0.0 <= self.alpha <= 2.0
        if alpha_in_range:
            return
        raise ValueError(f"alpha must be between 0.0 and 2.0, got {self.alpha}")
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
class CkptError(Exception):
    """Root of the ckptkit exception hierarchy.

    Catching this type also catches ``FormatError`` and ``MergeError``.
    """


class FormatError(CkptError):
    """Raised for a file format that is unsupported or corrupt."""


class MergeError(CkptError):
    """Raised when merging checkpoints fails."""
|
ckptkit/cli.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
"""CLI for ckptkit."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import sys
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
# click is optional: record whether it imported so that _build_cli() can
# return None (and main() can print an install hint) instead of this module
# failing at import time.
try:
    import click
    _HAS_CLICK = True
except ImportError:
    _HAS_CLICK = False

# rich is optional as well. When it is missing, _console is None and the
# commands fall back to plain click.echo output.
try:
    from rich.console import Console
    from rich.table import Table
    _console = Console()
    _HAS_RICH = True
except ImportError:
    _HAS_RICH = False
    _console = None  # type: ignore[assignment]
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _build_cli() -> Any:
    """Build the ``ckptkit`` click command group.

    Returns:
        The click group with the ``info``, ``diff``, ``stats`` and
        ``validate`` commands attached, or ``None`` when click could not be
        imported.
    """
    if not _HAS_CLICK:
        return None

    # Imported after the click guard, so these modules are only loaded when a
    # CLI can actually be built.
    from ckptkit.diff import diff, format_diff
    from ckptkit.inspect import format_params, format_size, inspect
    from ckptkit.stats import format_stats, stats_from_info
    from ckptkit.validate import validate

    # Number of tensors ``info`` lists before truncating its (non-JSON) output.
    max_listed = 50

    # Rich markup for the first column of the diff table, keyed by
    # DiffEntry.change_type. Built once rather than once per row, and it
    # includes "values_changed" so that change type no longer renders as "?".
    change_symbols = {
        "added": "[green]+[/green]",
        "removed": "[red]-[/red]",
        "shape_changed": "[yellow]~[/yellow]",
        "dtype_changed": "[yellow]~[/yellow]",
        "values_changed": "[yellow]~[/yellow]",
    }

    def _plain(text: Any) -> str:
        """Escape rich markup in ``text`` so it is printed literally.

        Paths and tensor names come from the user or from the checkpoint
        file; without escaping, a name such as ``model[v2]`` would be parsed
        as a style tag. Only called on the rich code paths, so rich is known
        to be importable here.
        """
        from rich.markup import escape

        return escape(str(text))

    @click.group()
    @click.version_option(package_name="ckptkit")
    def cli() -> None:
        """ckptkit — inspect, convert, diff, and merge model checkpoints."""

    @cli.command()
    @click.argument("path", type=click.Path(exists=True))
    @click.option("--json", "as_json", is_flag=True, help="Output as JSON.")
    def info(path: str, as_json: bool) -> None:
        """Inspect a checkpoint file."""
        ckpt_info = inspect(path)

        if as_json:
            import json as json_mod

            # JSON output is never truncated: every tensor is listed.
            data = {
                "path": ckpt_info.path,
                "format": ckpt_info.format.value,
                "file_size": ckpt_info.file_size,
                "n_tensors": ckpt_info.n_tensors,
                "n_parameters": ckpt_info.n_parameters,
                "tensors": [
                    {"name": t.name, "shape": t.shape, "dtype": t.dtype.value}
                    for t in ckpt_info.tensors
                ],
            }
            click.echo(json_mod.dumps(data, indent=2))
            return

        # Number of tensors left out of the listing (<= 0 when all are shown).
        n_hidden = len(ckpt_info.tensors) - max_listed

        if _HAS_RICH and _console is not None:
            _console.print(f"\n[bold]File:[/bold] {_plain(ckpt_info.path)}")
            _console.print(f"[bold]Format:[/bold] {ckpt_info.format.value}")
            _console.print(f"[bold]Size:[/bold] {format_size(ckpt_info.file_size)}")
            _console.print(f"[bold]Parameters:[/bold] {ckpt_info.n_parameters:,} ({format_params(ckpt_info.n_parameters)})")
            _console.print(f"[bold]Tensors:[/bold] {ckpt_info.n_tensors}")
            _console.print()

            table = Table(title="Tensors", show_lines=False)
            table.add_column("Name", style="cyan")
            table.add_column("Shape", justify="right")
            table.add_column("DType", justify="center")
            table.add_column("Params", justify="right")
            table.add_column("Size", justify="right")

            for t in ckpt_info.tensors[:max_listed]:
                table.add_row(
                    _plain(t.name), t.shape_str, t.dtype.value,
                    f"{t.numel:,}", format_size(t.size_bytes),
                )
            if n_hidden > 0:
                table.add_row("...", f"({n_hidden} more)", "", "", "")

            _console.print(table)
        else:
            click.echo(f"File: {ckpt_info.path}")
            click.echo(f"Format: {ckpt_info.format.value}")
            click.echo(f"Parameters: {ckpt_info.n_parameters:,}")
            for t in ckpt_info.tensors[:max_listed]:
                click.echo(f" {t.name}: {t.shape} {t.dtype.value}")
            # Mirror the rich table: say so when the listing was truncated.
            if n_hidden > 0:
                click.echo(f" ... ({n_hidden} more)")

    @cli.command(name="diff")
    @click.argument("path_a", type=click.Path(exists=True))
    @click.argument("path_b", type=click.Path(exists=True))
    def diff_cmd(path_a: str, path_b: str) -> None:
        """Compare two checkpoint files."""
        result = diff(path_a, path_b)

        if _HAS_RICH and _console is not None:
            _console.print(f"\n[bold]A:[/bold] {_plain(result.path_a)}")
            _console.print(f"[bold]B:[/bold] {_plain(result.path_b)}")
            _console.print(f"Shared: {result.n_shared} Identical: {result.n_identical} Changes: {result.n_changes}")
            _console.print()

            if result.entries:
                table = Table(show_lines=False)
                table.add_column("", style="bold", width=3)
                table.add_column("Tensor", style="cyan")
                table.add_column("Change")
                table.add_column("Details")

                for e in result.entries:
                    table.add_row(
                        change_symbols.get(e.change_type, "?"),
                        _plain(e.tensor_name),
                        e.change_type,
                        e.details,
                    )
                _console.print(table)
            else:
                _console.print("[green]Checkpoints are structurally identical.[/green]")
        else:
            click.echo(format_diff(result))

    @cli.command()
    @click.argument("path", type=click.Path(exists=True))
    def stats(path: str) -> None:
        """Show checkpoint statistics."""
        ckpt_info = inspect(path)
        ckpt_stats = stats_from_info(ckpt_info)
        click.echo(format_stats(ckpt_stats))

    @cli.command(name="validate")
    @click.argument("path", type=click.Path(exists=True))
    def validate_cmd(path: str) -> None:
        """Validate checkpoint integrity."""
        result = validate(path)
        if result.valid:
            click.echo(f"✓ {path}: valid ({result.format.value})")
        else:
            click.echo(f"✗ {path}: invalid")
        # Issues are printed for valid results too, so warnings stay visible.
        for issue in result.issues:
            prefix = " ERROR:" if issue.severity == "error" else " WARN:"
            click.echo(f"{prefix} {issue.message}")
        # A non-zero exit status lets scripts detect an invalid checkpoint.
        if not result.valid:
            raise SystemExit(1)

    return cli
|
|
146
|
+
|
|
147
|
+
|
|
148
|
+
# Built once at import time; this is None when click is not installed, a case
# main() checks for before invoking it.
cli = _build_cli()
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def main() -> None:
    """Run the command group, or explain how to install it when unavailable.

    When ``cli`` is ``None`` an install hint is written to stderr and the
    process exits with status 1.
    """
    if cli is not None:
        cli()
        return

    install_hint = (
        "The CLI requires extra dependencies. Install with:\n"
        "  pip install ckptkit[cli]"
    )
    print(install_hint, file=sys.stderr)
    sys.exit(1)
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
# Run the CLI when this file is executed directly as a script.
if __name__ == "__main__":
    main()
|