haoline 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- haoline/.streamlit/config.toml +10 -0
- haoline/__init__.py +248 -0
- haoline/analyzer.py +935 -0
- haoline/cli.py +2712 -0
- haoline/compare.py +811 -0
- haoline/compare_visualizations.py +1564 -0
- haoline/edge_analysis.py +525 -0
- haoline/eval/__init__.py +131 -0
- haoline/eval/adapters.py +844 -0
- haoline/eval/cli.py +390 -0
- haoline/eval/comparison.py +542 -0
- haoline/eval/deployment.py +633 -0
- haoline/eval/schemas.py +833 -0
- haoline/examples/__init__.py +15 -0
- haoline/examples/basic_inspection.py +74 -0
- haoline/examples/compare_models.py +117 -0
- haoline/examples/hardware_estimation.py +78 -0
- haoline/format_adapters.py +1001 -0
- haoline/formats/__init__.py +123 -0
- haoline/formats/coreml.py +250 -0
- haoline/formats/gguf.py +483 -0
- haoline/formats/openvino.py +255 -0
- haoline/formats/safetensors.py +273 -0
- haoline/formats/tflite.py +369 -0
- haoline/hardware.py +2307 -0
- haoline/hierarchical_graph.py +462 -0
- haoline/html_export.py +1573 -0
- haoline/layer_summary.py +769 -0
- haoline/llm_summarizer.py +465 -0
- haoline/op_icons.py +618 -0
- haoline/operational_profiling.py +1492 -0
- haoline/patterns.py +1116 -0
- haoline/pdf_generator.py +265 -0
- haoline/privacy.py +250 -0
- haoline/pydantic_models.py +241 -0
- haoline/report.py +1923 -0
- haoline/report_sections.py +539 -0
- haoline/risks.py +521 -0
- haoline/schema.py +523 -0
- haoline/streamlit_app.py +2024 -0
- haoline/tests/__init__.py +4 -0
- haoline/tests/conftest.py +123 -0
- haoline/tests/test_analyzer.py +868 -0
- haoline/tests/test_compare_visualizations.py +293 -0
- haoline/tests/test_edge_analysis.py +243 -0
- haoline/tests/test_eval.py +604 -0
- haoline/tests/test_format_adapters.py +460 -0
- haoline/tests/test_hardware.py +237 -0
- haoline/tests/test_hardware_recommender.py +90 -0
- haoline/tests/test_hierarchical_graph.py +326 -0
- haoline/tests/test_html_export.py +180 -0
- haoline/tests/test_layer_summary.py +428 -0
- haoline/tests/test_llm_patterns.py +540 -0
- haoline/tests/test_llm_summarizer.py +339 -0
- haoline/tests/test_patterns.py +774 -0
- haoline/tests/test_pytorch.py +327 -0
- haoline/tests/test_report.py +383 -0
- haoline/tests/test_risks.py +398 -0
- haoline/tests/test_schema.py +417 -0
- haoline/tests/test_tensorflow.py +380 -0
- haoline/tests/test_visualizations.py +316 -0
- haoline/universal_ir.py +856 -0
- haoline/visualizations.py +1086 -0
- haoline/visualize_yolo.py +44 -0
- haoline/web.py +110 -0
- haoline-0.3.0.dist-info/METADATA +471 -0
- haoline-0.3.0.dist-info/RECORD +70 -0
- haoline-0.3.0.dist-info/WHEEL +4 -0
- haoline-0.3.0.dist-info/entry_points.txt +5 -0
- haoline-0.3.0.dist-info/licenses/LICENSE +22 -0
|
@@ -0,0 +1,255 @@
|
|
|
1
|
+
# Copyright (c) 2025 HaoLine Contributors
|
|
2
|
+
# SPDX-License-Identifier: MIT
|
|
3
|
+
|
|
4
|
+
"""
|
|
5
|
+
OpenVINO format reader.
|
|
6
|
+
|
|
7
|
+
OpenVINO is Intel's toolkit for optimizing and deploying AI inference.
|
|
8
|
+
This reader extracts model metadata from OpenVINO IR format (.xml/.bin).
|
|
9
|
+
|
|
10
|
+
Requires: openvino (pip install openvino)
|
|
11
|
+
|
|
12
|
+
Reference: https://docs.openvino.ai/
|
|
13
|
+
"""
|
|
14
|
+
|
|
15
|
+
from __future__ import annotations
|
|
16
|
+
|
|
17
|
+
from dataclasses import dataclass, field
|
|
18
|
+
from pathlib import Path
|
|
19
|
+
from typing import Any
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass
class OpenVINOLayerInfo:
    """Record describing one layer (operation) of an OpenVINO model.

    Attributes:
        name: Name of the layer.
        type: Type name of the layer.
        input_shapes: One shape tuple per layer input (empty by default).
        output_shapes: One shape tuple per layer output (empty by default).
        precision: Precision label of the layer; defaults to "FP32".
    """

    # Required identification fields.
    name: str
    type: str

    # Each instance gets its own fresh list (default_factory), so the
    # defaults are never shared between instances.
    input_shapes: list[tuple[int, ...]] = field(default_factory=list)
    output_shapes: list[tuple[int, ...]] = field(default_factory=list)

    precision: str = "FP32"
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
@dataclass
class OpenVINOInfo:
    """Container for the information extracted from an OpenVINO model.

    Attributes:
        path: Location of the model file.
        name: Model name.
        framework: Name of the originating framework.
        layers: Per-layer records.
        inputs: One dictionary per model input.
        outputs: One dictionary per model output.
        metadata: Additional key/value information about the model.
    """

    path: Path
    name: str
    framework: str
    layers: list[OpenVINOLayerInfo] = field(default_factory=list)
    inputs: list[dict[str, Any]] = field(default_factory=list)
    outputs: list[dict[str, Any]] = field(default_factory=list)
    metadata: dict[str, Any] = field(default_factory=dict)

    @staticmethod
    def _tally(labels) -> dict[str, int]:
        """Count occurrences of each label, keyed in first-seen order."""
        totals: dict[str, int] = {}
        for label in labels:
            if label in totals:
                totals[label] += 1
            else:
                totals[label] = 1
        return totals

    @property
    def layer_count(self) -> int:
        """Return how many layers the model has."""
        return len(self.layers)

    @property
    def layer_type_counts(self) -> dict[str, int]:
        """Return a mapping of layer type to the number of such layers."""
        return self._tally(entry.type for entry in self.layers)

    @property
    def precision_breakdown(self) -> dict[str, int]:
        """Return a mapping of precision label to the number of such layers."""
        return self._tally(entry.precision for entry in self.layers)

    def to_dict(self) -> dict[str, Any]:
        """Build a plain-dictionary summary intended for JSON output.

        The individual layer records are not included; only their count and
        the per-type / per-precision tallies are.
        """
        summary: dict[str, Any] = {}
        summary["path"] = str(self.path)
        summary["name"] = self.name
        summary["framework"] = self.framework
        summary["layer_count"] = self.layer_count
        summary["layer_type_counts"] = self.layer_type_counts
        summary["precision_breakdown"] = self.precision_breakdown
        summary["inputs"] = self.inputs
        summary["outputs"] = self.outputs
        summary["metadata"] = self.metadata
        return summary
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
class OpenVINOReader:
    """Reader for OpenVINO IR format (.xml/.bin)."""

    def __init__(self, path: str | Path):
        """
        Initialize reader with file path.

        Args:
            path: Path to the OpenVINO model (.xml or directory). A directory
                is searched for ``*.xml`` files; a path with another suffix is
                replaced by its ``.xml`` sibling when that sibling exists.

        Raises:
            FileNotFoundError: If no model file can be located.
            ImportError: If openvino is not installed.
        """
        self.path = Path(path)

        # Handle both .xml path and directory containing model
        if self.path.is_dir():
            # sorted() makes the choice deterministic; glob order depends on
            # the filesystem.
            xml_files = sorted(self.path.glob("*.xml"))
            if not xml_files:
                raise FileNotFoundError(f"No .xml file found in: {self.path}")
            self.path = xml_files[0]
        elif self.path.suffix.lower() != ".xml":
            # Try the .xml sibling (e.g. when given the .bin weights file)
            xml_path = self.path.with_suffix(".xml")
            if xml_path.exists():
                self.path = xml_path

        if not self.path.exists():
            raise FileNotFoundError(f"OpenVINO model not found: {self.path}")

        try:
            import openvino  # noqa: F401
        except ImportError as e:
            raise ImportError("openvino required. Install with: pip install openvino") from e

    @staticmethod
    def _shape_to_tuple(shape) -> tuple[int, ...]:
        """Convert a partial shape to a tuple of ints, using -1 for dynamic dims."""
        if shape.is_static:
            return tuple(shape.to_shape())
        return tuple(d.get_length() if d.is_static else -1 for d in shape)

    @staticmethod
    def _port_name(port) -> str:
        """Return a tensor name for a port, or its node's name if it has none."""
        # get_any_name() raises when the tensor carries no names at all.
        if port.get_names():
            return port.get_any_name()
        return port.get_node().get_friendly_name()

    @staticmethod
    def _element_type_name(element_type) -> str:
        """Return the short type name (e.g. "f32") of an element type."""
        getter = getattr(element_type, "get_type_name", None)
        if callable(getter):
            return getter()
        return str(element_type)

    @staticmethod
    def _normalize_precision(type_name: str) -> str:
        """Map a short element-type name to a precision label.

        Only the leading letter is rewritten ("f32" -> "FP32", "i8" -> "INT8");
        every other name is just upper-cased ("bf16" -> "BF16", "u8" -> "U8").
        """
        lowered = type_name.lower()
        if lowered.startswith("f"):
            return "FP" + lowered[1:].upper()
        if lowered.startswith("i"):
            return "INT" + lowered[1:].upper()
        return lowered.upper()

    def _describe_port(self, port) -> dict[str, Any]:
        """Build the name/shape/element_type dictionary for a model port."""
        return {
            "name": self._port_name(port),
            # Plain ints (-1 for dynamic) so the result is JSON-serializable.
            "shape": list(self._shape_to_tuple(port.get_partial_shape())),
            "element_type": str(port.get_element_type()),
        }

    def read(self) -> OpenVINOInfo:
        """
        Read and parse the OpenVINO model.

        Returns:
            OpenVINOInfo with model metadata.
        """
        try:
            from openvino import Core
        except ImportError:
            # Older releases only expose Core under openvino.runtime.
            from openvino.runtime import Core

        model = Core().read_model(str(self.path))

        # Extract basic info
        name = model.get_friendly_name()

        # Get RT info for framework
        rt_info = model.get_rt_info()
        framework = "unknown"
        if "framework" in rt_info:
            framework = str(rt_info["framework"].value)

        inputs = [self._describe_port(port) for port in model.inputs]
        outputs = [self._describe_port(port) for port in model.outputs]

        # Extract layers (operations)
        layers = []
        for op in model.get_ordered_ops():
            op_outputs = op.outputs()

            # Precision is taken from the first output; ops without outputs
            # keep the "FP32" default.
            precision = "FP32"
            if op_outputs:
                precision = self._normalize_precision(
                    self._element_type_name(op_outputs[0].get_element_type())
                )

            layers.append(
                OpenVINOLayerInfo(
                    name=op.get_friendly_name(),
                    type=op.get_type_name(),
                    input_shapes=[
                        self._shape_to_tuple(port.get_partial_shape()) for port in op.inputs()
                    ],
                    output_shapes=[
                        self._shape_to_tuple(port.get_partial_shape()) for port in op_outputs
                    ],
                    precision=precision,
                )
            )

        # Additional metadata
        metadata = {}
        for key in rt_info:
            try:
                metadata[key] = str(rt_info[key].value)
            except Exception:
                # Best effort: an entry that cannot be read or stringified is
                # skipped rather than failing the whole read.
                continue

        return OpenVINOInfo(
            path=self.path,
            name=name,
            framework=framework,
            layers=layers,
            inputs=inputs,
            outputs=outputs,
            metadata=metadata,
        )
|
|
214
|
+
|
|
215
|
+
|
|
216
|
+
def is_openvino_file(path: str | Path) -> bool:
    """
    Check if a file is an OpenVINO model.

    An ``.xml`` file qualifies when a ``<net ...>`` element carrying a
    ``name`` or ``version`` attribute appears within its first 500
    characters. A ``.bin`` file qualifies when the ``.xml`` file next to it
    does.

    Args:
        path: Path to check.

    Returns:
        True if the file is an OpenVINO IR model.
    """
    import re

    path = Path(path)
    if not path.exists():
        return False

    suffix = path.suffix.lower()

    # Check for .xml file
    if suffix == ".xml":
        # Quick check for OpenVINO XML structure
        try:
            with open(path, encoding="utf-8") as f:
                header = f.read(500)
        except (OSError, ValueError):
            # Unreadable, or not UTF-8 text (UnicodeDecodeError is a ValueError).
            return False
        # The attribute must sit inside the <net> tag itself: a plain
        # substring test for "version=" is satisfied by the XML declaration,
        # and "<net" alone also matches tags such as <network>.
        return re.search(r"<net\s[^>]*\b(?:name|version)\s*=", header) is not None

    # Check if .bin file has corresponding .xml
    if suffix == ".bin":
        # The recursive call returns False when the sibling does not exist.
        return is_openvino_file(path.with_suffix(".xml"))

    return False
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def is_available() -> bool:
    """Report whether the ``openvino`` package can be imported."""
    try:
        import openvino  # noqa: F401
    except ImportError:
        return False
    else:
        return True
|
|
@@ -0,0 +1,273 @@
|
|
|
1
|
+
# Copyright (c) 2025 HaoLine Contributors
|
|
2
|
+
# SPDX-License-Identifier: MIT
|
|
3
|
+
|
|
4
|
+
"""
|
|
5
|
+
SafeTensors format reader.
|
|
6
|
+
|
|
7
|
+
SafeTensors is a simple, safe format for storing tensors, widely used
|
|
8
|
+
in the HuggingFace ecosystem. Unlike pickle-based formats, SafeTensors
|
|
9
|
+
cannot execute arbitrary code.
|
|
10
|
+
|
|
11
|
+
This reader extracts:
|
|
12
|
+
- Tensor names, shapes, and dtypes
|
|
13
|
+
- Parameter counts and memory estimates
|
|
14
|
+
- Metadata (if present)
|
|
15
|
+
|
|
16
|
+
Reference: https://github.com/huggingface/safetensors
|
|
17
|
+
"""
|
|
18
|
+
|
|
19
|
+
from __future__ import annotations
|
|
20
|
+
|
|
21
|
+
from dataclasses import dataclass, field
|
|
22
|
+
from pathlib import Path
|
|
23
|
+
from typing import Any
|
|
24
|
+
|
|
25
|
+
# Bytes per element for each SafeTensors dtype string.
# Unsigned 16/32/64-bit and 8-bit float types are listed explicitly so that
# lookups with a fallback size do not misreport them.
DTYPE_SIZES: dict[str, int] = {
    "F64": 8,
    "F32": 4,
    "F16": 2,
    "BF16": 2,
    "F8_E4M3": 1,
    "F8_E5M2": 1,
    "I64": 8,
    "I32": 4,
    "I16": 2,
    "I8": 1,
    "U64": 8,
    "U32": 4,
    "U16": 2,
    "U8": 1,
    "BOOL": 1,
}
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
@dataclass
class SafeTensorInfo:
    """Name, dtype string and shape of one tensor.

    Attributes:
        name: Tensor name.
        dtype: Dtype string, used as the key into ``DTYPE_SIZES``.
        shape: Tensor dimensions.
    """

    name: str
    dtype: str
    shape: tuple[int, ...]

    @property
    def n_elements(self) -> int:
        """Return the product of all dimensions (1 for an empty shape)."""
        count = 1
        for extent in self.shape:
            count = count * extent
        return count

    @property
    def size_bytes(self) -> int:
        """Return the element count times the per-element byte width.

        A dtype missing from ``DTYPE_SIZES`` is counted as 4 bytes wide.
        """
        bytes_per_element = DTYPE_SIZES.get(self.dtype, 4)
        return self.n_elements * bytes_per_element
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
@dataclass
class SafeTensorsInfo:
    """Aggregate view of the tensors and metadata found in a SafeTensors file.

    Attributes:
        path: Location of the file.
        tensors: One record per tensor.
        metadata: String key/value metadata (empty by default).
    """

    path: Path
    tensors: list[SafeTensorInfo] = field(default_factory=list)
    metadata: dict[str, str] = field(default_factory=dict)

    @property
    def total_params(self) -> int:
        """Return the summed element count of all tensors."""
        running = 0
        for entry in self.tensors:
            running = running + entry.n_elements
        return running

    @property
    def total_size_bytes(self) -> int:
        """Return the summed byte size of all tensors."""
        running = 0
        for entry in self.tensors:
            running = running + entry.size_bytes
        return running

    @property
    def dtype_breakdown(self) -> dict[str, int]:
        """Return a mapping of dtype to the number of tensors with that dtype."""
        per_dtype: dict[str, int] = {}
        for entry in self.tensors:
            if entry.dtype in per_dtype:
                per_dtype[entry.dtype] += 1
            else:
                per_dtype[entry.dtype] = 1
        return per_dtype

    @property
    def size_breakdown(self) -> dict[str, int]:
        """Return a mapping of dtype to the total bytes held in that dtype."""
        per_dtype: dict[str, int] = {}
        for entry in self.tensors:
            if entry.dtype in per_dtype:
                per_dtype[entry.dtype] += entry.size_bytes
            else:
                per_dtype[entry.dtype] = 0 + entry.size_bytes
        return per_dtype

    def to_dict(self) -> dict[str, Any]:
        """Build a plain-dictionary summary intended for JSON output."""
        tensor_rows = []
        for entry in self.tensors:
            tensor_rows.append({"name": entry.name, "dtype": entry.dtype, "shape": entry.shape})

        summary: dict[str, Any] = {}
        summary["path"] = str(self.path)
        summary["tensor_count"] = len(self.tensors)
        summary["total_params"] = self.total_params
        summary["total_size_bytes"] = self.total_size_bytes
        summary["dtype_breakdown"] = self.dtype_breakdown
        summary["size_breakdown"] = self.size_breakdown
        summary["metadata"] = self.metadata
        summary["tensors"] = tensor_rows
        return summary
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
class SafeTensorsReader:
    """Reader for SafeTensors format files."""

    # Upper bound accepted for the JSON header, matching the 100MB sanity
    # limit applied when sniffing files.
    _MAX_HEADER_BYTES = 100 * 1024 * 1024

    def __init__(self, path: str | Path):
        """
        Initialize reader with file path.

        Args:
            path: Path to the SafeTensors file.

        Raises:
            FileNotFoundError: If the path does not exist.
            ImportError: If safetensors library is not installed.
        """
        self.path = Path(path)
        if not self.path.exists():
            raise FileNotFoundError(f"SafeTensors file not found: {self.path}")

        try:
            import safetensors  # noqa: F401
        except ImportError as e:
            raise ImportError(
                "safetensors library required. Install with: pip install safetensors"
            ) from e

    def _load_header(self) -> dict[str, Any]:
        """Parse and return the JSON header at the start of the file.

        Layout: 8 bytes (header size as little-endian u64) + header JSON.

        Raises:
            ValueError: If the size field or header is truncated, the header
                is larger than the accepted limit, or it is not a JSON object.
        """
        import json

        with open(self.path, "rb") as f:
            size_field = f.read(8)
            if len(size_field) != 8:
                raise ValueError(f"Truncated SafeTensors header in: {self.path}")
            header_size = int.from_bytes(size_field, "little")
            # Checked before reading so a corrupt size field cannot trigger a
            # huge allocation.
            if header_size > self._MAX_HEADER_BYTES:
                raise ValueError(f"SafeTensors header too large in: {self.path}")
            header_bytes = f.read(header_size)

        if len(header_bytes) != header_size:
            raise ValueError(f"Truncated SafeTensors header in: {self.path}")

        # json.JSONDecodeError is a ValueError, so malformed JSON surfaces as
        # the same exception family as the checks above.
        header = json.loads(header_bytes)
        if not isinstance(header, dict):
            raise ValueError(f"SafeTensors header is not a JSON object in: {self.path}")
        return header

    def read(self) -> SafeTensorsInfo:
        """
        Read and parse the SafeTensors file.

        The file is opened through the safetensors library for its metadata
        and tensor names. Dtype and shape come from the file header, so no
        tensor data is loaded and the dtype strings are the ones stored in
        the file rather than a numpy approximation.

        Returns:
            SafeTensorsInfo with tensor information.
        """
        from safetensors import safe_open

        header = self._load_header()
        tensors = []
        metadata = {}

        with safe_open(self.path, framework="np") as f:
            # Get metadata if present
            meta = f.metadata()
            if meta:
                metadata = dict(meta)

            # Get tensor info
            for name in f.keys():
                entry = header[name]
                tensors.append(
                    SafeTensorInfo(
                        name=name,
                        dtype=entry["dtype"],
                        shape=tuple(entry["shape"]),
                    )
                )

        return SafeTensorsInfo(
            path=self.path,
            tensors=tensors,
            metadata=metadata,
        )

    def read_header_only(self) -> SafeTensorsInfo:
        """
        Read only the header without loading tensor data.

        This is faster for large files when you only need metadata.

        Returns:
            SafeTensorsInfo with tensor information.

        Raises:
            ValueError: If the header is truncated, oversized or malformed.
        """
        header = self._load_header()

        # "__metadata__" is the one header key that is not a tensor entry.
        metadata = header.pop("__metadata__", None) or {}

        tensors = [
            SafeTensorInfo(name=name, dtype=entry["dtype"], shape=tuple(entry["shape"]))
            for name, entry in header.items()
        ]

        return SafeTensorsInfo(
            path=self.path,
            tensors=tensors,
            metadata=metadata,
        )

    def _numpy_dtype_to_safetensors(self, dtype) -> str:
        """Convert numpy dtype to safetensors dtype string.

        Dtypes without a mapping are reported as "F32".
        """
        import numpy as np

        dtype_map = {
            np.float64: "F64",
            np.float32: "F32",
            np.float16: "F16",
            np.int64: "I64",
            np.int32: "I32",
            np.int16: "I16",
            np.int8: "I8",
            np.uint64: "U64",
            np.uint32: "U32",
            np.uint16: "U16",
            np.uint8: "U8",
            np.bool_: "BOOL",
        }

        # bfloat16 has no numpy scalar type to key on, so match on the name
        if "bfloat16" in str(dtype):
            return "BF16"

        return dtype_map.get(dtype.type, "F32")
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def is_safetensors_file(path: str | Path) -> bool:
    """
    Check if a file is a valid SafeTensors file.

    This is a cheap structural check, not a full parse: the file must have
    the ``.safetensors`` extension, a plausible 8-byte header-size field, be
    long enough to hold the header it declares, and the header must begin
    with ``{``.

    Args:
        path: Path to check.

    Returns:
        True if the file appears to be a valid SafeTensors file.
    """
    path = Path(path)
    # is_file() is already False for paths that do not exist.
    if not path.is_file():
        return False

    # Check extension
    if path.suffix.lower() != ".safetensors":
        return False

    # Try to read header
    try:
        with open(path, "rb") as f:
            size_field = f.read(8)
            if len(size_field) != 8:
                return False
            header_size = int.from_bytes(size_field, "little")

            # Sanity check: header shouldn't be empty or larger than 100MB
            if header_size == 0 or header_size > 100 * 1024 * 1024:
                return False

            # The declared header must actually fit in the file.
            if path.stat().st_size < 8 + header_size:
                return False

            # The header is a JSON object, so it should start with '{'
            return f.read(1) == b"{"
    except OSError:
        return False
|
|
264
|
+
|
|
265
|
+
|
|
266
|
+
def is_available() -> bool:
    """Report whether the ``safetensors`` package can be imported."""
    try:
        import safetensors  # noqa: F401
    except ImportError:
        return False
    else:
        return True
|