ckptkit 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ckptkit/__init__.py ADDED
@@ -0,0 +1,112 @@
1
+ """ckptkit — inspect, convert, diff, and merge model checkpoints."""
2
+
3
+ from ckptkit._types import (
4
+ DTYPE_SIZES,
5
+ CheckpointFormat,
6
+ CheckpointInfo,
7
+ CkptError,
8
+ DiffEntry,
9
+ DiffResult,
10
+ DType,
11
+ FormatError,
12
+ MergeConfig,
13
+ MergeError,
14
+ TensorInfo,
15
+ )
16
+ from ckptkit.convert import (
17
+ ConversionConfig,
18
+ ConversionFormat,
19
+ ConversionResult,
20
+ FormatConverter,
21
+ convert_checkpoint,
22
+ )
23
+ from ckptkit.diff import diff, diff_infos, format_diff, format_diff_rich, format_diff_table
24
+ from ckptkit.estimator import (
25
+ EstimationResult,
26
+ QuantEstimationResult,
27
+ TensorEstimate,
28
+ estimate_quantized_size,
29
+ estimate_reduction,
30
+ format_estimation,
31
+ )
32
+ from ckptkit.gguf import GGUFInfo, format_gguf_info, inspect_gguf, parse_gguf
33
+ from ckptkit.inspect import (
34
+ detect_format,
35
+ format_params,
36
+ format_size,
37
+ inspect,
38
+ inspect_safetensors,
39
+ )
40
+ from ckptkit.merge import find_lora_pairs, merge_lora_state_dicts
41
+ from ckptkit.metadata import (
42
+ CheckpointMetadata,
43
+ MetadataEditor,
44
+ extract_metadata_from_path,
45
+ format_metadata_report,
46
+ )
47
+ from ckptkit.stats import CheckpointStats, TensorStats, stats_from_info
48
+ from ckptkit.validate import ValidationIssue, ValidationResult, validate
49
+
50
__version__ = "0.3.0"

# Public API of the package. Every name listed here is imported at the top of
# this module from the submodule named in its group comment.
__all__ = [
    "__version__",
    # ckptkit._types: enums, dataclasses and exceptions
    "CheckpointFormat",
    "CheckpointInfo",
    "CkptError",
    "DType",
    "DTYPE_SIZES",
    "DiffEntry",
    "DiffResult",
    "FormatError",
    "MergeConfig",
    "MergeError",
    "TensorInfo",
    # ckptkit.inspect
    "detect_format",
    "inspect",
    "inspect_safetensors",
    "format_size",
    "format_params",
    # ckptkit.diff
    "diff",
    "diff_infos",
    "format_diff",
    "format_diff_rich",
    "format_diff_table",
    # ckptkit.merge
    "merge_lora_state_dicts",
    "find_lora_pairs",
    # ckptkit.stats
    "CheckpointStats",
    "TensorStats",
    "stats_from_info",
    # ckptkit.validate
    "validate",
    "ValidationResult",
    "ValidationIssue",
    # ckptkit.estimator
    "EstimationResult",
    "QuantEstimationResult",
    "TensorEstimate",
    "estimate_reduction",
    "estimate_quantized_size",
    "format_estimation",
    # ckptkit.convert
    "ConversionConfig",
    "ConversionFormat",
    "ConversionResult",
    "FormatConverter",
    "convert_checkpoint",
    # ckptkit.gguf
    "GGUFInfo",
    "parse_gguf",
    "inspect_gguf",
    "format_gguf_info",
    # ckptkit.metadata
    "CheckpointMetadata",
    "MetadataEditor",
    "extract_metadata_from_path",
    "format_metadata_report",
]
ckptkit/_types.py ADDED
@@ -0,0 +1,164 @@
1
+ """Core types for ckptkit."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from enum import Enum
7
+ from typing import Any
8
+
9
+
10
class CheckpointFormat(str, Enum):
    """Checkpoint container formats that ckptkit distinguishes.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase value (for example ``CheckpointFormat.PYTORCH == "pytorch"``).
    """

    SAFETENSORS = "safetensors"
    PYTORCH = "pytorch"
    NUMPY = "numpy"
    UNKNOWN = "unknown"
17
+
18
+
19
class DType(str, Enum):
    """Tensor element types recognised by ckptkit.

    Each member's value is identical to its name, and members are ``str``
    instances, so ``DType.F16 == "F16"`` holds.
    """

    # Floating point
    F32 = "F32"
    F16 = "F16"
    BF16 = "BF16"
    F64 = "F64"
    # Integers
    I64 = "I64"
    I32 = "I32"
    I16 = "I16"
    I8 = "I8"
    U8 = "U8"
    # Other
    BOOL = "BOOL"
    UNKNOWN = "UNKNOWN"
33
+
34
+
35
# Width in bytes of a single element of each dtype.
# DType.UNKNOWN has no entry here; TensorInfo.size_bytes looks sizes up with a
# default of 0 for dtypes that are missing from this table.
DTYPE_SIZES: dict[DType, int] = {
    DType.F32: 4,
    DType.F16: 2,
    DType.BF16: 2,
    DType.F64: 8,
    DType.I64: 8,
    DType.I32: 4,
    DType.I16: 2,
    DType.I8: 1,
    DType.U8: 1,
    DType.BOOL: 1,
}
41
+
42
+
43
@dataclass
class TensorInfo:
    """Describes one tensor stored in a checkpoint: its name, shape and dtype."""

    name: str
    shape: list[int]
    dtype: DType
    # NOTE(review): presumably the tensor's byte range inside the file; the
    # code in this module never reads these two fields. Confirm with readers.
    offset_start: int = 0
    offset_end: int = 0

    @property
    def numel(self) -> int:
        """Element count: the product of every dimension (1 for an empty shape)."""
        count = 1
        for dim in self.shape:
            count *= dim
        return count

    @property
    def size_bytes(self) -> int:
        """Element count times the per-element width from ``DTYPE_SIZES``.

        Dtypes without an entry in ``DTYPE_SIZES`` contribute a width of 0.
        """
        element_count = self.numel
        bytes_per_element = DTYPE_SIZES.get(self.dtype, 0)
        return element_count * bytes_per_element

    @property
    def shape_str(self) -> str:
        """Dimensions joined with ``×``, e.g. ``"4096×4096"``."""
        return "×".join(map(str, self.shape))
69
+
70
+
71
@dataclass
class CheckpointInfo:
    """Summary of a checkpoint file: where it is, its format, and its tensors."""

    path: str
    format: CheckpointFormat
    file_size: int
    tensors: list[TensorInfo]
    metadata: dict[str, Any] = field(default_factory=dict)

    @property
    def n_tensors(self) -> int:
        """Number of tensors listed in the checkpoint."""
        return len(self.tensors)

    @property
    def n_parameters(self) -> int:
        """Sum of ``numel`` over all tensors."""
        total = 0
        for tensor in self.tensors:
            total += tensor.numel
        return total

    @property
    def total_bytes(self) -> int:
        """Sum of ``size_bytes`` over all tensors."""
        total = 0
        for tensor in self.tensors:
            total += tensor.size_bytes
        return total

    def dtype_summary(self) -> dict[str, int]:
        """Return the parameter count per dtype, keyed by the dtype's string value."""
        totals: dict[str, int] = {}
        for tensor in self.tensors:
            dtype_name = tensor.dtype.value
            if dtype_name in totals:
                totals[dtype_name] += tensor.numel
            else:
                totals[dtype_name] = tensor.numel
        return totals

    def layer_groups(self) -> dict[str, list[TensorInfo]]:
        """Bucket tensors by a prefix of their dotted name (e.g. 'model.layers.0').

        Names with more than three dot-separated segments are keyed by their
        first three segments; shorter names are keyed by the full name.
        """
        grouped: dict[str, list[TensorInfo]] = {}
        for tensor in self.tensors:
            segments = tensor.name.split(".")
            if len(segments) > 3:
                prefix = ".".join(segments[:3])
            else:
                prefix = tensor.name
            if prefix not in grouped:
                grouped[prefix] = []
            grouped[prefix].append(tensor)
        return grouped
110
+
111
+
112
@dataclass
class DiffEntry:
    """One recorded difference for a single tensor between two checkpoints."""

    tensor_name: str
    # Expected values: "added", "removed", "shape_changed", "dtype_changed",
    # "values_changed". The field is a plain str and is not validated.
    change_type: str
    # Optional free-form description of the change.
    details: str = ""
119
+
120
+
121
@dataclass
class DiffResult:
    """Outcome of comparing checkpoint ``path_a`` against checkpoint ``path_b``."""

    path_a: str
    path_b: str
    entries: list[DiffEntry]
    # Counters supplied by whoever builds the result; both default to 0.
    n_shared: int = 0
    n_identical: int = 0

    @property
    def n_changes(self) -> int:
        """Number of recorded difference entries."""
        return len(self.entries)

    @property
    def has_changes(self) -> bool:
        """True when at least one difference entry was recorded."""
        return self.n_changes > 0
138
+
139
+
140
@dataclass
class MergeConfig:
    """Settings for merging a LoRA adapter into a base checkpoint.

    Raises:
        ValueError: on construction, if ``alpha`` is not within [0.0, 2.0].
    """

    base_path: str
    adapter_path: str
    output_path: str
    alpha: float = 1.0
    device: str = "cpu"

    def __post_init__(self) -> None:
        """Reject an ``alpha`` outside the closed interval [0.0, 2.0]."""
        # Written as a chained comparison on purpose: NaN fails it and is
        # therefore rejected along with out-of-range numbers.
        alpha_ok = 0.0 <= self.alpha <= 2.0
        if alpha_ok:
            return
        raise ValueError(f"alpha must be between 0.0 and 2.0, got {self.alpha}")
153
+
154
+
155
class CkptError(Exception):
    """Root of the ckptkit exception hierarchy."""
157
+
158
+
159
class FormatError(CkptError):
    """Raised for a file whose format is unsupported or whose contents are corrupt."""
161
+
162
+
163
class MergeError(CkptError):
    """Raised when merging checkpoints fails."""
ckptkit/cli.py ADDED
@@ -0,0 +1,163 @@
1
+ """CLI for ckptkit."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import sys
6
+ from typing import Any
7
+
8
+ try:
9
+ import click
10
+ _HAS_CLICK = True
11
+ except ImportError:
12
+ _HAS_CLICK = False
13
+
14
+ try:
15
+ from rich.console import Console
16
+ from rich.table import Table
17
+ _console = Console()
18
+ _HAS_RICH = True
19
+ except ImportError:
20
+ _HAS_RICH = False
21
+ _console = None # type: ignore[assignment]
22
+
23
+
24
def _build_cli() -> Any:
    """Build the click command group for the ``ckptkit`` CLI.

    Returns:
        The click group exposing the ``info``, ``diff``, ``stats`` and
        ``validate`` commands, or ``None`` when click could not be imported
        (module-level ``_HAS_CLICK`` is false).

    When rich is available (``_HAS_RICH``), ``info`` and ``diff`` render
    through the module-level ``_console``; otherwise they fall back to plain
    ``click.echo`` output.
    """
    if not _HAS_CLICK:
        return None

    # Imported lazily so that importing this module stays cheap and does not
    # fail when click is absent.
    from ckptkit.diff import diff, format_diff
    from ckptkit.inspect import format_params, format_size, inspect
    from ckptkit.stats import format_stats, stats_from_info
    from ckptkit.validate import validate

    # Maximum number of tensors listed by the human-readable ``info`` output.
    max_listed = 50

    # Rich markup for the first column of the ``diff`` table, keyed by
    # ``DiffEntry.change_type``. Built once rather than per table row.
    diff_symbols = {
        "added": "[green]+[/green]",
        "removed": "[red]-[/red]",
        "shape_changed": "[yellow]~[/yellow]",
        "dtype_changed": "[yellow]~[/yellow]",
        "values_changed": "[yellow]~[/yellow]",
    }

    @click.group()
    @click.version_option(package_name="ckptkit")
    def cli() -> None:
        """ckptkit — inspect, convert, diff, and merge model checkpoints."""

    @cli.command()
    @click.argument("path", type=click.Path(exists=True))
    @click.option("--json", "as_json", is_flag=True, help="Output as JSON.")
    def info(path: str, as_json: bool) -> None:
        """Inspect a checkpoint file."""
        import json as json_mod

        ckpt_info = inspect(path)

        if as_json:
            # Machine-readable output lists every tensor, with no truncation.
            data = {
                "path": ckpt_info.path,
                "format": ckpt_info.format.value,
                "file_size": ckpt_info.file_size,
                "n_tensors": ckpt_info.n_tensors,
                "n_parameters": ckpt_info.n_parameters,
                "tensors": [
                    {"name": t.name, "shape": t.shape, "dtype": t.dtype.value}
                    for t in ckpt_info.tensors
                ],
            }
            click.echo(json_mod.dumps(data, indent=2))
            return

        shown = ckpt_info.tensors[:max_listed]
        n_hidden = len(ckpt_info.tensors) - len(shown)

        if _HAS_RICH and _console is not None:
            # Paths and tensor names are arbitrary text; escape them so that
            # square brackets are not parsed as Rich markup tags.
            from rich.markup import escape

            _console.print(f"\n[bold]File:[/bold] {escape(ckpt_info.path)}")
            _console.print(f"[bold]Format:[/bold] {ckpt_info.format.value}")
            _console.print(f"[bold]Size:[/bold] {format_size(ckpt_info.file_size)}")
            _console.print(
                f"[bold]Parameters:[/bold] {ckpt_info.n_parameters:,} "
                f"({format_params(ckpt_info.n_parameters)})"
            )
            _console.print(f"[bold]Tensors:[/bold] {ckpt_info.n_tensors}")
            _console.print()

            table = Table(title="Tensors", show_lines=False)
            table.add_column("Name", style="cyan")
            table.add_column("Shape", justify="right")
            table.add_column("DType", justify="center")
            table.add_column("Params", justify="right")
            table.add_column("Size", justify="right")

            for t in shown:
                table.add_row(
                    escape(t.name),
                    t.shape_str,
                    t.dtype.value,
                    f"{t.numel:,}",
                    format_size(t.size_bytes),
                )
            if n_hidden > 0:
                table.add_row("...", f"({n_hidden} more)", "", "", "")

            _console.print(table)
        else:
            click.echo(f"File: {ckpt_info.path}")
            click.echo(f"Format: {ckpt_info.format.value}")
            click.echo(f"Parameters: {ckpt_info.n_parameters:,}")
            for t in shown:
                click.echo(f"  {t.name}: {t.shape} {t.dtype.value}")
            # Mirror the rich table: say so when the listing was truncated.
            if n_hidden > 0:
                click.echo(f"  ... ({n_hidden} more)")

    @cli.command(name="diff")
    @click.argument("path_a", type=click.Path(exists=True))
    @click.argument("path_b", type=click.Path(exists=True))
    def diff_cmd(path_a: str, path_b: str) -> None:
        """Compare two checkpoint files."""
        result = diff(path_a, path_b)

        if _HAS_RICH and _console is not None:
            # See ``info``: escape free-form text before it reaches Rich.
            from rich.markup import escape

            _console.print(f"\n[bold]A:[/bold] {escape(result.path_a)}")
            _console.print(f"[bold]B:[/bold] {escape(result.path_b)}")
            _console.print(
                f"Shared: {result.n_shared}  Identical: {result.n_identical}  "
                f"Changes: {result.n_changes}"
            )
            _console.print()

            if result.entries:
                table = Table(show_lines=False)
                table.add_column("", style="bold", width=3)
                table.add_column("Tensor", style="cyan")
                table.add_column("Change")
                table.add_column("Details")

                for e in result.entries:
                    table.add_row(
                        # Change types without a symbol are shown as "?".
                        diff_symbols.get(e.change_type, "?"),
                        escape(e.tensor_name),
                        e.change_type,
                        escape(e.details),
                    )
                _console.print(table)
            else:
                _console.print("[green]Checkpoints are structurally identical.[/green]")
        else:
            click.echo(format_diff(result))

    @cli.command()
    @click.argument("path", type=click.Path(exists=True))
    def stats(path: str) -> None:
        """Show checkpoint statistics."""
        ckpt_info = inspect(path)
        ckpt_stats = stats_from_info(ckpt_info)
        click.echo(format_stats(ckpt_stats))

    @cli.command(name="validate")
    @click.argument("path", type=click.Path(exists=True))
    def validate_cmd(path: str) -> None:
        """Validate checkpoint integrity."""
        result = validate(path)
        if result.valid:
            click.echo(f"✓ {path}: valid ({result.format.value})")
        else:
            click.echo(f"✗ {path}: invalid")
        # Issues are listed for valid files too, so warnings stay visible.
        for issue in result.issues:
            prefix = "  ERROR:" if issue.severity == "error" else "  WARN:"
            click.echo(f"{prefix} {issue.message}")
        # A non-zero exit status lets scripts detect an invalid checkpoint.
        if not result.valid:
            raise SystemExit(1)

    return cli
146
+
147
+
148
# The click group, or None when click is not installed (see _build_cli).
cli = _build_cli()


def main() -> None:
    """Run the CLI, or print an install hint and exit with status 1.

    The hint is written to stderr when ``cli`` is ``None``, i.e. when
    ``_build_cli`` could not build the command group.
    """
    if cli is not None:
        cli()
        return

    print(
        "The CLI requires extra dependencies. Install with:\n"
        "  pip install ckptkit[cli]",
        file=sys.stderr,
    )
    sys.exit(1)


if __name__ == "__main__":
    main()