sft-cli 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sft/__init__.py +3 -0
- sft/browser.py +947 -0
- sft/cli.py +63 -0
- sft/index.py +197 -0
- sft_cli-0.1.0.dist-info/METADATA +115 -0
- sft_cli-0.1.0.dist-info/RECORD +8 -0
- sft_cli-0.1.0.dist-info/WHEEL +4 -0
- sft_cli-0.1.0.dist-info/entry_points.txt +2 -0
sft/__init__.py
ADDED
sft/browser.py
ADDED
|
@@ -0,0 +1,947 @@
|
|
|
1
|
+
"""Textual TUI application for browsing safetensors files."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
from enum import Enum
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
|
|
9
|
+
from rich.text import Text
|
|
10
|
+
from textual.app import App, ComposeResult
|
|
11
|
+
from textual.binding import Binding
|
|
12
|
+
from textual.containers import Container
|
|
13
|
+
from textual.message import Message
|
|
14
|
+
from textual.screen import ModalScreen
|
|
15
|
+
from textual.widgets import DataTable, Footer, Input, Label, Static, Tree
|
|
16
|
+
from textual.widgets.tree import TreeNode
|
|
17
|
+
|
|
18
|
+
from sft.index import (
|
|
19
|
+
PrefixTree,
|
|
20
|
+
PrefixTreeNode,
|
|
21
|
+
TensorIndex,
|
|
22
|
+
TensorInfo,
|
|
23
|
+
natural_sort_key,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def format_bytes(nbytes: int) -> str:
    """Render a byte count as a short human-readable string.

    Units are 1024-based. Counts below 1024 are shown as a plain integer
    number of bytes. KB and MB use one decimal place. Anything from
    1024**3 upward is shown in GB with two decimal places.
    """
    kib = 1024
    mib = kib * 1024
    gib = mib * 1024

    if nbytes < kib:
        return f"{nbytes} B"
    if nbytes < mib:
        return f"{nbytes / kib:.1f} KB"
    if nbytes < gib:
        return f"{nbytes / mib:.1f} MB"
    return f"{nbytes / gib:.2f} GB"
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def format_shape(shape: tuple[int, ...]) -> str:
    """Render a tensor shape as a parenthesised, comma-separated string.

    A scalar (empty shape) becomes ``"()"``. A 1-D shape has no trailing
    comma, e.g. ``(7,)`` is shown as ``"(7)"``.
    """
    dims = ", ".join(map(str, shape))
    return "(" + dims + ")"
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class SortMode(Enum):
    """Orderings the tensor table can be sorted by.

    Each member's value is the short label the table shows in its border
    subtitle (``"sort: <value>"``).
    """

    # By (natural-sorted) full tensor name.
    NAME_ASC = "name ↑"
    NAME_DESC = "name ↓"
    # By tensor size in bytes.
    SIZE_ASC = "size ↑"
    SIZE_DESC = "size ↓"
    # By number of dimensions.
    RANK_ASC = "rank ↑"
    RANK_DESC = "rank ↓"


# Order in which the "s" key cycles through the sort modes.
SORT_ORDER = [
    SortMode[member_name]
    for member_name in (
        "NAME_ASC",
        "NAME_DESC",
        "SIZE_DESC",
        "SIZE_ASC",
        "RANK_DESC",
        "RANK_ASC",
    )
]
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
class TensorDetailScreen(ModalScreen):
    """Modal screen showing the details of a single tensor.

    Closed with ESC or SPACE. The tensor name and dtype come from the
    file's header, so each row is built as a plain ``rich.text.Text``
    rather than interpolated into a markup string. A name containing
    ``[...]`` would otherwise be parsed as a markup tag, and a stray
    closing tag raises a markup error.
    """

    CSS = """
    TensorDetailScreen {
        align: center middle;
    }

    #detail-container {
        width: 60;
        height: auto;
        max-height: 80%;
        background: $surface;
        border: thick $primary;
        padding: 1 2;
    }

    #detail-title {
        text-align: center;
        text-style: bold;
        margin-bottom: 1;
    }

    .detail-row {
        margin: 0;
    }

    .detail-label {
        color: $text-muted;
    }

    .detail-value {
        color: $text;
    }
    """

    BINDINGS = [
        Binding("escape", "dismiss", "Close"),
        Binding("space", "dismiss", "Close"),
    ]

    def __init__(self, tensor: TensorInfo) -> None:
        """Store the tensor whose details will be displayed."""
        super().__init__()
        self.tensor = tensor

    @staticmethod
    def _row(label: str, value: str) -> Text:
        """Build a ``<label>: <value>`` row with a dimmed label and literal value."""
        return Text.assemble((f"{label}:", "dim"), " ", value)

    def compose(self) -> ComposeResult:
        """Lay out one row per tensor property plus a closing hint."""
        t = self.tensor
        with Container(id="detail-container"):
            yield Label("Tensor Details", id="detail-title")
            yield Static(self._row("Name", t.full_name), classes="detail-row")
            yield Static(
                self._row("Shape", format_shape(t.shape)), classes="detail-row"
            )
            yield Static(self._row("Rank", str(t.rank)), classes="detail-row")
            yield Static(self._row("Dtype", str(t.dtype)), classes="detail-row")
            yield Static(
                self._row("Size", f"{format_bytes(t.nbytes)} ({t.nbytes:,} bytes)"),
                classes="detail-row",
            )
            yield Static(self._row("Numel", f"{t.numel:,}"), classes="detail-row")
            yield Static(
                "\n[dim]Press ESC or SPACE to close[/dim]", classes="detail-row"
            )
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
class MetadataScreen(ModalScreen):
    """Modal screen showing the metadata dict passed in for a file.

    Closed with ESC or ``m``. The metadata and the file name originate from
    the file on disk, so they are rendered as plain ``Text``. Brackets in
    them must not be interpreted as markup, and a stray closing tag would
    raise.
    """

    CSS = """
    MetadataScreen {
        align: center middle;
    }

    #metadata-container {
        width: 70;
        height: auto;
        max-height: 80%;
        background: $surface;
        border: thick $secondary;
        padding: 1 2;
    }

    #metadata-title {
        text-align: center;
        text-style: bold;
        margin-bottom: 1;
    }

    #metadata-content {
        height: auto;
        max-height: 20;
    }
    """

    BINDINGS = [
        Binding("escape", "dismiss", "Close"),
        Binding("m", "dismiss", "Close"),
    ]

    def __init__(self, metadata: dict, file_path: Path) -> None:
        """
        Args:
            metadata: Metadata mapping to display (an empty dict shows a
                "no metadata" notice).
            file_path: Path of the file; only its name is shown.
        """
        super().__init__()
        self.metadata = metadata
        self.file_path = file_path

    def compose(self) -> ComposeResult:
        """Lay out the file name, pretty-printed metadata and a closing hint."""
        with Container(id="metadata-container"):
            yield Label("File Metadata", id="metadata-title")
            yield Static(Text.assemble(("File:", "dim"), " ", self.file_path.name))

            if self.metadata:
                # ensure_ascii=False keeps non-ASCII values readable instead
                # of showing \uXXXX escapes; Text() prevents markup parsing.
                formatted = json.dumps(self.metadata, indent=2, ensure_ascii=False)
                yield Static(Text(f"\n{formatted}"), id="metadata-content")
            else:
                yield Static("\n[dim]No metadata found in file[/dim]")

            yield Static("\n[dim]Press ESC or M to close[/dim]")
|
|
183
|
+
|
|
184
|
+
|
|
185
|
+
class FilterScreen(ModalScreen):
    """Modal screen for filtering tensors by dtype.

    Up to five of the available dtypes (in sorted order) are listed and can
    be toggled with the keys 1-5. ``c`` clears the selection. Closing the
    screen (ESC or ``f``) dismisses it with ``{"dtypes": [...]}`` when any
    dtype is selected, or with an empty dict otherwise.

    Option rows and key hints are rendered as plain ``Text``. Hints such as
    ``[c]`` and the file-provided dtype names must not be parsed as markup
    tags.
    """

    CSS = """
    FilterScreen {
        align: center middle;
    }

    #filter-container {
        width: 50;
        height: auto;
        max-height: 80%;
        background: $surface;
        border: thick $accent;
        padding: 1 2;
    }

    #filter-title {
        text-align: center;
        text-style: bold;
        margin-bottom: 1;
    }

    .filter-section {
        margin: 1 0;
    }

    .filter-label {
        color: $text-muted;
        margin-bottom: 0;
    }

    .filter-options {
        margin-left: 2;
    }

    .dtype-option {
        margin: 0;
    }

    .dtype-option.selected {
        color: $success;
    }
    """

    BINDINGS = [
        Binding("escape", "dismiss", "Close"),
        Binding("f", "dismiss", "Close"),
        Binding("c", "clear_filters", "Clear All"),
        Binding("1", "toggle_dtype_0", "Toggle", show=False),
        Binding("2", "toggle_dtype_1", "Toggle", show=False),
        Binding("3", "toggle_dtype_2", "Toggle", show=False),
        Binding("4", "toggle_dtype_3", "Toggle", show=False),
        Binding("5", "toggle_dtype_4", "Toggle", show=False),
    ]

    COMMON_DTYPES = ["F16", "F32", "BF16", "I8", "I32"]

    # Number of dtypes that get a toggle key; must match the "1".."5"
    # bindings above.
    _MAX_TOGGLES = 5

    def __init__(self, current_filters: dict, available_dtypes: set[str]) -> None:
        """
        Args:
            current_filters: Currently active filters; its optional
                ``"dtypes"`` entry seeds the initial selection.
            available_dtypes: Dtypes present in the file.
        """
        super().__init__()
        self.current_filters = current_filters.copy()
        self.available_dtypes = sorted(available_dtypes)
        self.selected_dtypes: set[str] = set(current_filters.get("dtypes", []))

    def _toggle_count(self) -> int:
        """Return how many dtype options are displayed (and toggleable)."""
        return min(self._MAX_TOGGLES, len(self.available_dtypes))

    def _option_label(self, index: int) -> Text:
        """Return the display line for the dtype option at ``index``."""
        dtype = self.available_dtypes[index]
        mark = "✓" if dtype in self.selected_dtypes else " "
        return Text(f" [{index + 1}] {mark} {dtype}")

    def _refresh_option(self, index: int) -> None:
        """Redraw one option so its text and CSS class match the selection."""
        widget = self.query_one(f"#dtype-{index}", Static)
        widget.update(self._option_label(index))
        if self.available_dtypes[index] in self.selected_dtypes:
            widget.add_class("selected")
        else:
            widget.remove_class("selected")

    def compose(self) -> ComposeResult:
        """Lay out the dtype options and the key help."""
        with Container(id="filter-container"):
            yield Label("Filter Tensors", id="filter-title")

            # Dtype filter
            yield Static("[bold]Dtype Filter[/bold]", classes="filter-section")
            for i in range(self._toggle_count()):
                is_selected = self.available_dtypes[i] in self.selected_dtypes
                yield Static(
                    self._option_label(i),
                    classes="dtype-option selected" if is_selected else "dtype-option",
                    id=f"dtype-{i}",
                )

            yield Static("\n[dim]Keys:[/dim]", classes="filter-section")
            yield Static(Text(" [1-5] Toggle dtype"))
            yield Static(Text(" [c] Clear all filters"))
            yield Static(Text(" [ESC/f] Close"))

    def _toggle_dtype(self, index: int) -> None:
        """Toggle the dtype option at ``index``; out-of-range indices are ignored."""
        if index >= self._toggle_count():
            return

        dtype = self.available_dtypes[index]
        if dtype in self.selected_dtypes:
            self.selected_dtypes.discard(dtype)
        else:
            self.selected_dtypes.add(dtype)
        self._refresh_option(index)

    def action_toggle_dtype_0(self) -> None:
        self._toggle_dtype(0)

    def action_toggle_dtype_1(self) -> None:
        self._toggle_dtype(1)

    def action_toggle_dtype_2(self) -> None:
        self._toggle_dtype(2)

    def action_toggle_dtype_3(self) -> None:
        self._toggle_dtype(3)

    def action_toggle_dtype_4(self) -> None:
        self._toggle_dtype(4)

    def action_clear_filters(self) -> None:
        """Clear all filters and redraw every option with the same layout."""
        self.selected_dtypes.clear()
        for i in range(self._toggle_count()):
            self._refresh_option(i)

    def action_dismiss(self) -> None:
        """Dismiss and return the filters (dtypes in sorted, stable order)."""
        filters: dict = {}
        if self.selected_dtypes:
            filters["dtypes"] = sorted(self.selected_dtypes)
        self.dismiss(filters)
|
|
323
|
+
|
|
324
|
+
|
|
325
|
+
class FilteredPrefixTree:
    """A filtered view of a PrefixTree containing only matching tensors.

    The filtered tree mirrors the original's structure but keeps only
    branches that lead to at least one matching tensor, with aggregate
    counts and byte totals recomputed over the matches. ``root`` is ``None``
    when nothing matches. It exposes the same ``index``, ``delimiter``,
    ``root`` and ``get_tensors_under`` as the original, so the two can be
    used interchangeably. Tensor ids index into ``index.tensors``.
    """

    def __init__(
        self, original_tree: PrefixTree, matching_tensors: list[TensorInfo]
    ) -> None:
        """Build a filtered tree from matching tensors.

        Args:
            original_tree: The unfiltered tree whose structure is reused.
            matching_tensors: Tensors to keep, matched by ``full_name``.
        """
        self.original_tree = original_tree
        self.index = original_tree.index
        self.delimiter = original_tree.delimiter
        self.matching_tensor_names = {t.full_name for t in matching_tensors}

        # Build filtered tree structure (None when nothing matches).
        self.root = self._build_filtered_node(original_tree.root, "")

    def _build_filtered_node(
        self, original_node: PrefixTreeNode, prefix: str
    ) -> PrefixTreeNode | None:
        """Recursively build a filtered copy of ``original_node``.

        Returns ``None`` if neither the node nor any descendant holds a
        matching tensor. ``prefix`` is the node's full name, joined with the
        tree's delimiter.
        """
        tensors = self.index.tensors
        matching_direct = [
            tid
            for tid in original_node.tensor_ids
            if tensors[tid].full_name in self.matching_tensor_names
        ]

        filtered_children: dict[str, PrefixTreeNode] = {}
        for child_name, child_node in original_node.children.items():
            # Join with the tree's own delimiter (not a hard-coded ".") so
            # prefixes agree with how get_tensors_under() splits them.
            child_prefix = (
                f"{prefix}{self.delimiter}{child_name}" if prefix else child_name
            )
            filtered_child = self._build_filtered_node(child_node, child_prefix)
            if filtered_child is not None:
                filtered_children[child_name] = filtered_child

        # If no matches in this subtree, prune it.
        if not matching_direct and not filtered_children:
            return None

        node = PrefixTreeNode(name=original_node.name)
        node.tensor_ids = matching_direct
        node.children = filtered_children

        # Aggregates cover direct matches plus every surviving child subtree.
        node.aggregate_count = len(matching_direct) + sum(
            c.aggregate_count for c in filtered_children.values()
        )
        node.aggregate_bytes = sum(
            tensors[tid].nbytes for tid in matching_direct
        ) + sum(c.aggregate_bytes for c in filtered_children.values())

        return node

    def get_tensors_under(self, prefix: str) -> list[TensorInfo]:
        """Get all matching tensors under a given prefix.

        An empty prefix returns every match, in ``index.tensors`` order. A
        prefix that does not exist in the filtered tree returns ``[]``.
        """
        if self.root is None:
            return []

        if not prefix:
            return [
                t
                for t in self.index.tensors
                if t.full_name in self.matching_tensor_names
            ]

        # Navigate to the prefix node.
        node = self.root
        for part in prefix.split(self.delimiter):
            child = node.children.get(part)
            if child is None:
                return []
            node = child

        return [self.index.tensors[tid] for tid in self._collect_tensor_ids(node)]

    def _collect_tensor_ids(self, node: PrefixTreeNode) -> list[int]:
        """Collect all tensor ids under ``node`` in pre-order.

        Each node's own ids come first, then each child subtree in
        ``children`` order. An explicit stack avoids Python's recursion limit
        on very deep name hierarchies.
        """
        ids: list[int] = []
        stack = [node]
        while stack:
            current = stack.pop()
            ids.extend(current.tensor_ids)
            # Push children reversed so the first child is visited first.
            stack.extend(reversed(list(current.children.values())))
        return ids
|
|
411
|
+
|
|
412
|
+
|
|
413
|
+
class HierarchyTree(Tree):
    """Tree widget for navigating tensor namespaces.

    Displays either the full ``PrefixTree`` or, after ``apply_filter``, a
    ``FilteredPrefixTree``. Every cursor move posts a
    ``HierarchyTree.NodeSelected`` message carrying the highlighted node's
    prefix.
    """

    BINDINGS = [
        Binding("left", "collapse_node", "Collapse", show=False),
        Binding("right", "expand_node", "Expand", show=False),
        Binding("enter", "toggle_node", "Toggle", show=False),
    ]

    class NodeSelected(Message):
        """Message sent when a tree node is selected or highlighted."""

        def __init__(self, prefix: str, node: PrefixTreeNode) -> None:
            self.prefix = prefix
            self.node = node
            super().__init__()

    def __init__(self, prefix_tree: PrefixTree) -> None:
        """Create the widget for ``prefix_tree``; nodes are built on mount."""
        super().__init__("root")
        self.prefix_tree = prefix_tree
        self.filtered_tree: FilteredPrefixTree | None = None
        # Prefix string represented by each widget node ("" for the root).
        self._node_prefixes: dict[TreeNode, str] = {}

    @property
    def active_tree(self) -> PrefixTree | FilteredPrefixTree:
        """Return the currently active tree (filtered or original)."""
        if self.filtered_tree is not None:
            return self.filtered_tree
        return self.prefix_tree

    def on_mount(self) -> None:
        """Build the tree when mounted."""
        self._rebuild_tree_view()

    def _rebuild_tree_view(self) -> None:
        """Rebuild all widget nodes from the active tree."""
        self.root.remove_children()
        self._node_prefixes.clear()
        self._node_prefixes[self.root] = ""

        active = self.active_tree
        file_name = self.prefix_tree.index.file_path.name

        if active.root is None:
            # No matches - show empty state.
            self.root.set_label(self._make_label(file_name + " (no matches)", 0, 0))
            return

        self.root.expand()
        self._build_tree(self.root, active.root, "")
        self.root.set_label(
            self._make_label(
                file_name,
                active.root.aggregate_count,
                active.root.aggregate_bytes,
            )
        )

    def apply_filter(
        self, matching_tensors: list[TensorInfo] | None, query: str = ""
    ) -> None:
        """Show only ``matching_tensors``; ``None`` restores the full tree.

        ``query`` is only displayed, in the border subtitle.
        """
        if matching_tensors is None:
            self.filtered_tree = None
        else:
            self.filtered_tree = FilteredPrefixTree(self.prefix_tree, matching_tensors)

        # Always reassign, so a filter without a query (or a cleared filter)
        # never keeps displaying an older query.
        if matching_tensors is not None and query:
            self.border_subtitle = f"search: {query}"
        else:
            self.border_subtitle = ""

        self._rebuild_tree_view()

    def _make_label(self, name: str, count: int, nbytes: int) -> Text:
        """Create a ``name (count, size)`` label for a tree node."""
        label = Text()
        label.append(name, style="bold")
        label.append(f" ({count}, {format_bytes(nbytes)})", style="dim")
        return label

    def _build_tree(self, parent: TreeNode, node: PrefixTreeNode, prefix: str) -> None:
        """Recursively add widget nodes for ``node``'s children under ``parent``.

        Children are ordered by ``natural_sort_key`` of their names. Nodes
        without children are added as leaves, so they get no expand icon.
        """
        # Use the tree's delimiter so prefixes round-trip through
        # _get_prefix_tree_node() and get_tensors_under(), which split on it.
        delimiter = self.active_tree.delimiter
        ordered = sorted(node.children.items(), key=lambda item: natural_sort_key(item[0]))

        for child_name, child_node in ordered:
            child_prefix = f"{prefix}{delimiter}{child_name}" if prefix else child_name
            label = self._make_label(
                child_name,
                child_node.aggregate_count,
                child_node.aggregate_bytes,
            )

            if child_node.children:
                tree_node = parent.add(label, expand=False)
                self._node_prefixes[tree_node] = child_prefix
                self._build_tree(tree_node, child_node, child_prefix)
            else:
                tree_node = parent.add_leaf(label)
                self._node_prefixes[tree_node] = child_prefix

    def _get_prefix_tree_node(self, tree_node: TreeNode) -> tuple[str, PrefixTreeNode]:
        """Get the prefix and PrefixTreeNode for a given widget node.

        If a prefix part is missing from the active tree, the deepest node
        reached so far is returned. An empty placeholder node is returned
        when the active tree has no root.
        """
        prefix = self._node_prefixes.get(tree_node, "")

        active = self.active_tree
        node = active.root
        if node is None:
            return prefix, PrefixTreeNode(name="")

        if prefix:
            for part in prefix.split(active.delimiter):
                if part not in node.children:
                    break
                node = node.children[part]

        return prefix, node

    def on_tree_node_highlighted(self, event: Tree.NodeHighlighted) -> None:
        """Handle node highlight (cursor movement) - update right panel."""
        prefix, node = self._get_prefix_tree_node(event.node)
        self.post_message(self.NodeSelected(prefix, node))

    def action_toggle_node(self) -> None:
        """Toggle expand/collapse of the highlighted node if it has children."""
        node = self.cursor_node
        if node is not None and node.children:
            node.toggle()

    def action_collapse_node(self) -> None:
        """Collapse the highlighted node, or move to its parent if already collapsed."""
        node = self.cursor_node
        if node is None:
            return
        if node.is_expanded:
            node.collapse()
        elif node.parent is not None:
            # Only move the cursor: select_node() also "selects" the parent,
            # which toggles (collapses) it when auto_expand is enabled.
            self.move_cursor(node.parent)

    def action_expand_node(self) -> None:
        """Expand the currently highlighted node."""
        node = self.cursor_node
        if node is not None and not node.is_expanded:
            node.expand()
|
|
573
|
+
|
|
574
|
+
|
|
575
|
+
class TensorTable(DataTable):
    """Table widget listing tensors (name, shape, dtype, size).

    The table keeps its own copy of the list it is given, so sorting never
    reorders the caller's list. This matters for ``TensorIndex.tensors``,
    whose positions are used as tensor ids by the prefix trees. Once a sort
    mode has been chosen, it is re-applied whenever new tensors arrive, so
    the rows always match the ``sort:`` subtitle.
    """

    def __init__(self) -> None:
        super().__init__()
        self.cursor_type = "row"
        self.zebra_stripes = True
        self._tensors: list[TensorInfo] = []
        self._current_prefix: str = ""
        self._sort_mode: SortMode | None = None
        self._columns_initialized: bool = False

    def on_mount(self) -> None:
        """Set up table columns."""
        self._setup_columns()

    def _setup_columns(self) -> None:
        """Set up table columns (only once)."""
        if self._columns_initialized:
            return
        self.add_column("Name", key="name")
        self.add_column("Shape", key="shape")
        self.add_column("Dtype", key="dtype")
        self.add_column("Size", key="size")
        self._columns_initialized = True

    def _get_sort_indicator(self) -> str:
        """Get a string indicating the current sort mode."""
        if self._sort_mode is None:
            return ""
        return f" [{self._sort_mode.value}]"

    def update_tensors(self, tensors: list[TensorInfo], prefix: str = "") -> None:
        """Replace the table contents with (a copy of) ``tensors``."""
        # Copy: sorting is done in place and must not touch the caller's list.
        self._tensors = list(tensors)
        self._current_prefix = prefix
        if self._sort_mode is not None:
            self._sort_tensors(self._sort_mode)
        self._refresh_table()

    def _refresh_table(self) -> None:
        """Redraw all rows from ``self._tensors`` and update the sort subtitle."""
        self.clear()

        for tensor in self._tensors:
            # Name and dtype come from the file; wrap them in Text so they are
            # shown literally instead of being parsed as markup.
            self.add_row(
                Text(tensor.full_name),
                format_shape(tensor.shape),
                Text(str(tensor.dtype)),
                format_bytes(tensor.nbytes),
                key=tensor.full_name,
            )

        if self._sort_mode:
            self.border_subtitle = f"sort: {self._sort_mode.value}"
        else:
            self.border_subtitle = ""

    def get_selected_tensor(self) -> TensorInfo | None:
        """Get the tensor under the row cursor, or ``None`` if there is none."""
        if self.cursor_row is None or self.cursor_row >= len(self._tensors):
            return None
        return self._tensors[self.cursor_row]

    def _sort_tensors(self, mode: SortMode) -> None:
        """Sort the table's private tensor list in place according to ``mode``."""
        if mode in (SortMode.NAME_ASC, SortMode.NAME_DESC):
            self._tensors.sort(
                key=lambda t: natural_sort_key(t.full_name),
                reverse=mode is SortMode.NAME_DESC,
            )
        elif mode in (SortMode.SIZE_ASC, SortMode.SIZE_DESC):
            self._tensors.sort(
                key=lambda t: t.nbytes,
                reverse=mode is SortMode.SIZE_DESC,
            )
        elif mode is SortMode.RANK_ASC:
            # Ties broken by name so equal-rank tensors stay readable.
            self._tensors.sort(key=lambda t: (t.rank, natural_sort_key(t.full_name)))
        elif mode is SortMode.RANK_DESC:
            self._tensors.sort(key=lambda t: (-t.rank, natural_sort_key(t.full_name)))

    def sort_by(self, mode: SortMode) -> None:
        """Sort tensors by the given mode and remember it for later updates."""
        self._sort_mode = mode
        self._sort_tensors(mode)
        self._refresh_table()
|
|
659
|
+
|
|
660
|
+
|
|
661
|
+
class SearchInput(Input):
    """Single-line search box.

    Hidden by default (``display: none``). The app reveals it by adding the
    ``visible`` CSS class and hides it again by removing that class.
    """

    DEFAULT_CSS = """
    SearchInput {
        display: none;
        height: 3;
        border: solid $accent;
        background: $surface;
    }

    SearchInput.visible {
        display: block;
    }
    """

    def __init__(self) -> None:
        hint = "Type to search..."
        super().__init__(placeholder=hint)
        self.border_title = "Search (ESC to cancel)"
|
|
680
|
+
|
|
681
|
+
|
|
682
|
+
class SftApp(App):
|
|
683
|
+
"""Interactive browser for .safetensors files."""
|
|
684
|
+
|
|
685
|
+
TITLE = "sft"
|
|
686
|
+
|
|
687
|
+
CSS = """
|
|
688
|
+
Screen {
|
|
689
|
+
layout: grid;
|
|
690
|
+
grid-size: 2 1;
|
|
691
|
+
grid-columns: 1fr 2fr;
|
|
692
|
+
}
|
|
693
|
+
|
|
694
|
+
HierarchyTree {
|
|
695
|
+
height: 100%;
|
|
696
|
+
border: solid $primary;
|
|
697
|
+
scrollbar-gutter: stable;
|
|
698
|
+
}
|
|
699
|
+
|
|
700
|
+
TensorTable {
|
|
701
|
+
height: 100%;
|
|
702
|
+
border: solid $secondary;
|
|
703
|
+
}
|
|
704
|
+
|
|
705
|
+
SearchInput {
|
|
706
|
+
column-span: 2;
|
|
707
|
+
dock: bottom;
|
|
708
|
+
}
|
|
709
|
+
"""
|
|
710
|
+
|
|
711
|
+
BINDINGS = [
|
|
712
|
+
Binding("q", "quit", "Quit"),
|
|
713
|
+
Binding("tab", "toggle_panel", "Switch Panel", show=True),
|
|
714
|
+
Binding("slash", "start_search", "Search", show=True),
|
|
715
|
+
Binding("escape", "cancel_search", "Cancel", show=False),
|
|
716
|
+
Binding("s", "cycle_sort", "Sort", show=True),
|
|
717
|
+
Binding("f", "show_filters", "Filter", show=True),
|
|
718
|
+
Binding("space", "show_details", "Details", show=True),
|
|
719
|
+
Binding("m", "show_metadata", "Metadata", show=True),
|
|
720
|
+
Binding("g", "goto_top", "Top", show=False),
|
|
721
|
+
Binding("G", "goto_bottom", "Bottom", show=False),
|
|
722
|
+
]
|
|
723
|
+
|
|
724
|
+
def __init__(self, file_path: Path) -> None:
    """Create the app for ``file_path``; the file is parsed later, in ``compose``."""
    super().__init__()
    self.file_path = file_path

    # Set by compose() once the file has been parsed.
    self.index: TensorIndex | None = None
    self.prefix_tree: PrefixTree | None = None

    # Navigation / view state.
    self._current_prefix: str = ""
    self._sort_mode_index: int = 0
    self._search_active: bool = False
    self._current_filters: dict = {}

    # _base_tensors: the current selection before any filtering.
    # _all_tensors: the list the table is populated from by several handlers.
    self._all_tensors: list[TensorInfo] = []
    self._base_tensors: list[TensorInfo] = []
|
|
736
|
+
|
|
737
|
+
def compose(self) -> ComposeResult:
    """Parse the file, then lay out the tree, the table and the hidden search box.

    If loading fails, only the footer and an error message are composed.
    ``self.index`` and ``self.prefix_tree`` then stay ``None``, which is
    how the rest of the app detects that no data is available.
    """
    yield Footer()

    # Build into locals and publish only after everything succeeded.
    # Previously, a failure inside PrefixTree() left self.index set, so
    # on_mount went on to query a TensorTable that was never composed.
    try:
        index = TensorIndex.from_file(self.file_path)
        prefix_tree = PrefixTree(index)
    except Exception as e:  # UI boundary: show any load error on screen
        # Text(): the message may contain brackets (paths, parser output)
        # that must not be parsed as markup.
        yield Static(Text(f"Error loading file: {e}"), id="error")
        return

    self.index = index
    self.prefix_tree = prefix_tree
    self._all_tensors = index.tensors.copy()
    self._base_tensors = index.tensors.copy()

    yield HierarchyTree(prefix_tree)
    yield TensorTable()
    yield SearchInput()
|
|
754
|
+
|
|
755
|
+
def on_mount(self) -> None:
    """Fill the table with every tensor and focus the tree.

    Does nothing when the file failed to load (``self.index`` is ``None``).
    """
    if self.index is None:
        return

    table = self.query_one(TensorTable)
    # Hand over a copy: sorting the table must never reorder
    # index.tensors, whose positions are used as tensor ids.
    table.update_tensors(list(self.index.tensors))

    self.query_one(HierarchyTree).focus()
|
|
767
|
+
|
|
768
|
+
def on_hierarchy_tree_node_selected(
    self, event: HierarchyTree.NodeSelected
) -> None:
    """Show the tensors under the highlighted tree node in the table.

    Tensors are taken from whichever tree is active (search-filtered or
    full) and remembered as the unfiltered base list. Any active dtype
    filters are then applied, and the table is refreshed and re-sorted
    with the current sort mode.
    """
    prefix = event.prefix
    self._current_prefix = prefix

    hierarchy = self.query_one(HierarchyTree)
    under_prefix = hierarchy.active_tree.get_tensors_under(prefix)
    self._base_tensors = under_prefix.copy()

    if not self._current_filters:
        self._all_tensors = under_prefix.copy()
    else:
        # NOTE(review): _apply_filters() is defined outside this view, but
        # the table below is still fed the *unfiltered* list, which may undo
        # whatever _apply_filters() displays. Confirm the intended behaviour.
        self._apply_filters()

    tensor_table = self.query_one(TensorTable)
    tensor_table.update_tensors(under_prefix, prefix)

    if self._sort_mode_index > 0:
        tensor_table.sort_by(SORT_ORDER[self._sort_mode_index])
|
|
792
|
+
|
|
793
|
+
def action_toggle_panel(self) -> None:
    """Toggle focus between the tree and table panels.

    A no-op when the file failed to load: the panels were never composed,
    and querying them would raise.
    """
    if self.index is None:
        return

    tree = self.query_one(HierarchyTree)
    table = self.query_one(TensorTable)
    if tree.has_focus:
        table.focus()
    else:
        tree.focus()
|
|
802
|
+
|
|
803
|
+
def action_start_search(self) -> None:
    """Reveal and focus the search box, and mark search mode as active.

    A no-op when the file failed to load, since no search box exists then.
    """
    if self.index is None:
        return

    search_input = self.query_one(SearchInput)
    search_input.add_class("visible")
    search_input.focus()
    self._search_active = True
|
|
809
|
+
|
|
810
|
+
def action_cancel_search(self) -> None:
    """Cancel search: hide the box, clear the tree filter and list every tensor.

    A no-op when the file failed to load. ESC is bound app-wide, and
    querying the widgets that were never composed would raise.
    """
    if self.index is None:
        return

    search_input = self.query_one(SearchInput)
    search_input.remove_class("visible")
    search_input.value = ""
    self._search_active = False

    # Clear tree filter.
    tree = self.query_one(HierarchyTree)
    tree.apply_filter(None)

    # Reset to show all tensors.
    self._current_prefix = ""
    self._base_tensors = self.index.tensors.copy()
    self._all_tensors = self.index.tensors.copy()

    table = self.query_one(TensorTable)
    # Give the table its own list so its in-place sort can't reorder ours.
    table.update_tensors(list(self._all_tensors), self._current_prefix)

    if self._sort_mode_index > 0:
        table.sort_by(SORT_ORDER[self._sort_mode_index])

    tree.focus()
|
|
836
|
+
|
|
837
|
+
def on_input_changed(self, event: Input.Changed) -> None:
    """React to edits in the search box while search mode is active.

    A non-empty query is matched case-insensitively as a substring of each
    tensor's full name, against the whole index rather than the current
    prefix. The tree is filtered to the matches, and the matches become the
    base list for the table. An empty query restores the unfiltered tree and
    the full tensor list.

    In both cases the table is rebuilt through ``_apply_filters``, so the
    active dtype filter and sort mode are honoured consistently.
    """
    if not self._search_active:
        return

    query = event.value.lower()
    tree = self.query_one(HierarchyTree)

    if query:
        # Filter tensors from the full index (not current selection)
        matches = [t for t in self.index.tensors if query in t.full_name.lower()]
        # Update tree with filter (pass query for display)
        tree.apply_filter(matches, query)
        self._base_tensors = matches
    else:
        # Clear filter
        tree.apply_filter(None)
        self._base_tensors = self.index.tensors.copy()

    self._current_prefix = ""
    # Sets _all_tensors, refreshes the table, and re-applies the current sort.
    # Previously the empty-query branch skipped the sort, and neither branch
    # applied the dtype filter.
    self._apply_filters()
|
|
869
|
+
|
|
870
|
+
def on_input_submitted(self, _event: Input.Submitted) -> None:
    """On Enter, move focus to the tensor table.

    Search mode is deliberately left on, so the filtered results stay visible.
    """
    self.query_one(TensorTable).focus()
|
|
875
|
+
|
|
876
|
+
def action_cycle_sort(self) -> None:
    """Advance to the next mode in SORT_ORDER (wrapping around) and re-sort the table."""
    next_index = (self._sort_mode_index + 1) % len(SORT_ORDER)
    self._sort_mode_index = next_index
    self.query_one(TensorTable).sort_by(SORT_ORDER[next_index])
|
|
883
|
+
|
|
884
|
+
def action_show_details(self) -> None:
    """Open a TensorDetailScreen for the table's selected tensor.

    Does nothing when the table reports no (truthy) selection.
    """
    selected = self.query_one(TensorTable).get_selected_tensor()
    if not selected:
        return
    self.push_screen(TensorDetailScreen(selected))
|
|
891
|
+
|
|
892
|
+
def action_show_metadata(self) -> None:
    """Open a MetadataScreen showing the file-level metadata and path.

    Skipped while no index is loaded (``self.index`` is falsy).
    """
    index = self.index
    if not index:
        return
    self.push_screen(MetadataScreen(index.metadata, self.file_path))
|
|
896
|
+
|
|
897
|
+
def action_show_filters(self) -> None:
    """Open the filter palette and apply the filters it returns.

    The palette is offered the current filters and the set of dtypes present
    in the loaded file. When it returns a result, that result replaces
    ``self._current_filters`` and the table is rebuilt via ``_apply_filters``.
    Does nothing while no index is loaded.
    """
    if self.index is None:
        return

    # Get available dtypes
    available_dtypes = {t.dtype for t in self.index.tensors}

    def on_filter_result(filters: dict | None) -> None:
        """Store the palette's result and re-filter the table."""
        # NOTE(review): the callback may receive None if the palette is ever
        # dismissed without a result. Storing None would make _apply_filters
        # fail on its `"dtypes" in ...` test, so keep the previous filters.
        if filters is None:
            return
        self._current_filters = filters
        self._apply_filters()

    self.push_screen(
        FilterScreen(self._current_filters, available_dtypes),
        on_filter_result,
    )
|
|
914
|
+
|
|
915
|
+
def _apply_filters(self) -> None:
    """Rebuild the visible tensor list from ``_base_tensors`` and refresh the table.

    When the current filters contain a non-empty ``"dtypes"`` entry, only
    tensors whose dtype is listed there are kept. The result is stored in
    ``_all_tensors`` and shown in the table under ``_current_prefix``. The
    active sort mode is re-applied unless it is the first (index 0) mode.
    """
    filters = self._current_filters
    visible = self._base_tensors.copy()  # never mutate the base list itself

    if "dtypes" in filters and filters["dtypes"]:
        wanted = set(filters["dtypes"])
        visible = [tensor for tensor in visible if tensor.dtype in wanted]

    self._all_tensors = visible

    table = self.query_one(TensorTable)
    table.update_tensors(visible, self._current_prefix)
    if self._sort_mode_index > 0:
        table.sort_by(SORT_ORDER[self._sort_mode_index])
|
|
934
|
+
|
|
935
|
+
def action_goto_top(self) -> None:
    """Jump to the top of whichever widget has focus.

    On a DataTable the cursor moves to row 0; on a Tree the root node is
    selected. Any other focused widget is left alone.
    """
    widget = self.focused
    if isinstance(widget, DataTable):
        widget.move_cursor(row=0)
        return
    if isinstance(widget, Tree):
        widget.select_node(widget.root)
|
|
942
|
+
|
|
943
|
+
def action_goto_bottom(self) -> None:
    """Go to the bottom of the currently focused widget.

    On a DataTable the cursor moves to the last row; on a Tree the cursor
    moves to the last visible line. Other focused widgets are left alone.
    """
    focused = self.focused
    if isinstance(focused, DataTable):
        # An empty table (e.g. a search with no matches) has no last row;
        # skip instead of requesting row -1.
        if focused.row_count > 0:
            focused.move_cursor(row=focused.row_count - 1)
    elif isinstance(focused, Tree):
        # Trees were previously ignored, so `G` did nothing in the hierarchy
        # panel even though action_goto_top supports it.
        focused.cursor_line = focused.last_line
|
sft/cli.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
"""CLI entry point for sft."""
|
|
2
|
+
|
|
3
|
+
from pathlib import Path
from typing import Optional

import typer

from sft import __version__
|
|
8
|
+
|
|
9
|
+
# Root Typer application for the `sft` command. Running it with no arguments
# prints help, and shell-completion installation options are disabled.
app = typer.Typer(
    add_completion=False,
    help="An interactive terminal browser for .safetensors files.",
    name="sft",
    no_args_is_help=True,
)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
def version_callback(value: bool) -> None:
    """Handle the eager ``--version`` option: print the version and stop.

    Returns immediately when the flag was not given (``value`` is falsy).
    Otherwise echoes ``sft <version>`` and raises ``typer.Exit``.
    """
    if not value:
        return
    typer.echo(f"sft {__version__}")
    raise typer.Exit()
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@app.command()
def main(
    file: Path = typer.Argument(
        ...,
        help="Path to a .safetensors file to browse.",
        exists=True,
        file_okay=True,
        dir_okay=False,
        readable=True,
        resolve_path=True,
    ),
    # Optional[bool] rather than `bool | None`: this module has no
    # `from __future__ import annotations`, so the annotation is evaluated when
    # the function is defined, and PEP 604 unions raise TypeError on
    # Python 3.9, which the package declares as supported.
    _version: Optional[bool] = typer.Option(
        None,
        "--version",
        "-v",
        help="Show version and exit.",
        callback=version_callback,
        is_eager=True,
    ),
) -> None:
    """Open an interactive browser for a .safetensors file.

    Typer has already checked that ``file`` exists, is readable and is not a
    directory. Here the extension must be ``.safetensors`` (case-insensitive);
    otherwise an error is printed to stderr and the command exits with code 1.
    """
    # Validate file extension
    if file.suffix.lower() != ".safetensors":
        typer.secho(
            f"Error: Expected a .safetensors file, got '{file.suffix}'",
            fg=typer.colors.RED,
            err=True,
        )
        raise typer.Exit(code=1)

    # Import here to avoid slow startup for --help/--version
    from sft.browser import SftApp

    # Launch the TUI
    app_instance = SftApp(file)
    app_instance.run()


if __name__ == "__main__":
    app()
|
sft/index.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
1
|
+
"""Data model and parsing for safetensors files."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import re
|
|
6
|
+
from dataclasses import dataclass, field
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
def natural_sort_key(s: str) -> list[int | str]:
    """Generate a sort key for natural (human) sorting.

    Splits the string into text and numeric parts so that numbers are
    compared numerically: 'layer.2' < 'layer.10' instead of
    'layer.10' < 'layer.2'. Text parts are lower-cased, so the ordering is
    case-insensitive.

    Args:
        s: The string to build a key for.

    Returns:
        A list alternating lower-cased text and ints, always starting and
        ending with a (possibly empty) text part.
    """
    parts = re.split(r"(\d+)", s)
    # With a single capturing group, re.split yields text at even indices and
    # the captured digit runs at odd indices. Decide by position, not by
    # str.isdigit(): isdigit() is also true for characters such as "²" that
    # \d does not match and int() cannot parse.
    return [int(part) if i % 2 else part.lower() for i, part in enumerate(parts)]
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
class TensorInfo:
    """Header-level description of one tensor: name, shape, dtype and size."""

    full_name: str  # complete tensor name as stored in the file header
    shape: tuple[int, ...]  # dimension sizes; () for a scalar
    dtype: str  # dtype string exactly as written in the header
    nbytes: int  # size of the tensor's data in bytes

    @property
    def rank(self) -> int:
        """Number of dimensions, i.e. ``len(shape)``."""
        return len(self.shape)

    @property
    def numel(self) -> int:
        """Total element count: the product of all dimensions (1 for a scalar)."""
        count = 1
        for extent in self.shape:
            count = count * extent
        return count
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
@dataclass
class TensorIndex:
    """Index of all tensors in a safetensors file, built from the header only."""

    tensors: list[TensorInfo]  # sorted by natural_sort_key(full_name)
    metadata: dict[str, Any]  # the header's "__metadata__" entry ({} if absent)
    file_path: Path

    @classmethod
    def from_file(cls, path: Path) -> TensorIndex:
        """Parse a safetensors file and extract tensor metadata (header only).

        The file starts with an 8-byte little-endian unsigned header length,
        followed by that many bytes of UTF-8 JSON. Only this header is read;
        tensor data is never loaded.

        Args:
            path: The file to parse.

        Returns:
            A TensorIndex whose tensors are naturally sorted by full name.

        Raises:
            ValueError: if the file is shorter than the 8-byte length prefix,
                the declared header length does not fit in the file, or the
                header is not valid UTF-8 JSON describing an object.
            KeyError: if a tensor entry lacks "dtype", "shape" or
                "data_offsets".
            OSError: if the file cannot be opened or read.
        """
        import json
        import os
        import struct

        with open(path, "rb") as f:
            # Read header size (first 8 bytes, little-endian u64)
            header_size_bytes = f.read(8)
            if len(header_size_bytes) < 8:
                raise ValueError("Invalid safetensors file: too short")

            header_size = struct.unpack("<Q", header_size_bytes)[0]

            # Validate the declared length before reading. A corrupt value
            # could otherwise request an enormous read (which may allocate
            # that many bytes up front), and a short read of a truncated
            # header could parse as if it were complete.
            available = os.fstat(f.fileno()).st_size - 8
            if header_size > available:
                raise ValueError(
                    "Invalid safetensors file: header length "
                    f"{header_size} exceeds the {available} bytes available"
                )
            header_bytes = f.read(header_size)

        # UnicodeDecodeError and json.JSONDecodeError are both ValueErrors.
        header = json.loads(header_bytes.decode("utf-8"))
        if not isinstance(header, dict):
            raise ValueError("Invalid safetensors file: header is not a JSON object")

        # "__metadata__" is optional; an explicit null is treated as absent.
        metadata: dict[str, Any] = header.pop("__metadata__", None) or {}

        # Every remaining header entry describes one tensor.
        tensors: list[TensorInfo] = []
        for name, info in header.items():
            dtype_str = info["dtype"]
            shape = tuple(info["shape"])
            # The tensor's byte size is the span between its start and end offsets.
            data_offsets = info["data_offsets"]
            nbytes = data_offsets[1] - data_offsets[0]
            tensors.append(
                TensorInfo(
                    full_name=name,
                    shape=shape,
                    dtype=dtype_str,
                    nbytes=nbytes,
                )
            )

        # Sort tensors by name for consistent ordering (natural sort)
        tensors.sort(key=lambda t: natural_sort_key(t.full_name))

        return cls(tensors=tensors, metadata=metadata, file_path=path)

    @property
    def total_tensors(self) -> int:
        """Return the total number of tensors."""
        return len(self.tensors)

    @property
    def total_bytes(self) -> int:
        """Return the total size of all tensors in bytes."""
        return sum(t.nbytes for t in self.tensors)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
@dataclass
class PrefixTreeNode:
    """One namespace segment of the prefix tree.

    ``children`` maps the next name segment to its node. ``tensor_ids`` holds
    the ids of tensors whose name path ends exactly at this node. The two
    aggregate fields stay 0 until ``compute_aggregates`` runs.
    """

    name: str  # this node's segment ("" is used for the root)
    children: dict[str, PrefixTreeNode] = field(default_factory=dict)
    tensor_ids: list[int] = field(default_factory=list)  # indices into the tensor list
    aggregate_count: int = 0  # tensors in this node's whole subtree
    aggregate_bytes: int = 0  # summed nbytes over this node's whole subtree

    def add_tensor(self, parts: list[str], tensor_id: int, nbytes: int) -> None:
        """Record ``tensor_id`` at the node reached by following ``parts`` from here.

        Missing nodes along the path are created. ``nbytes`` is accepted but
        not used; byte totals come from ``compute_aggregates``.
        """
        node = self
        for segment in parts:
            child = node.children.get(segment)
            if child is None:
                child = PrefixTreeNode(name=segment)
                node.children[segment] = child
            node = child
        node.tensor_ids.append(tensor_id)

    def compute_aggregates(self, tensors: list[TensorInfo]) -> None:
        """Fill ``aggregate_count``/``aggregate_bytes`` for this subtree, bottom-up.

        ``tensors`` is indexed by the ids stored in ``tensor_ids``.
        """
        # Children must be finished before their totals are summed here.
        for child in self.children.values():
            child.compute_aggregates(tensors)

        own_bytes = sum(tensors[tid].nbytes for tid in self.tensor_ids)
        kids = self.children.values()
        self.aggregate_count = len(self.tensor_ids) + sum(k.aggregate_count for k in kids)
        self.aggregate_bytes = own_bytes + sum(k.aggregate_bytes for k in kids)
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
class PrefixTree:
    """Namespace hierarchy of a TensorIndex, formed by splitting names on a delimiter."""

    def __init__(self, index: TensorIndex, delimiter: str = ".") -> None:
        """Build the tree for ``index`` and fill in every node's aggregates."""
        self.index = index
        self.delimiter = delimiter
        self.root = PrefixTreeNode(name="")

        tensor_list = index.tensors
        # A tensor's id is its position in index.tensors.
        for position, info in enumerate(tensor_list):
            self.root.add_tensor(info.full_name.split(delimiter), position, info.nbytes)

        self.root.compute_aggregates(tensor_list)

    def get_tensors_under(self, prefix: str) -> list[TensorInfo]:
        """Return every tensor in the subtree named by ``prefix``.

        An empty prefix returns ``self.index.tensors`` itself (not a copy).
        An unknown prefix returns ``[]``. Otherwise tensors come back in the
        order produced by ``_collect_tensor_ids``.
        """
        if not prefix:
            return self.index.tensors

        node = self.root
        for segment in prefix.split(self.delimiter):
            node = node.children.get(segment)
            if node is None:
                return []

        return [self.index.tensors[tid] for tid in self._collect_tensor_ids(node)]

    def _collect_tensor_ids(self, node: PrefixTreeNode) -> list[int]:
        """Gather tensor ids under ``node`` in depth-first preorder.

        Each node contributes its own ids first, followed by its children's
        subtrees in insertion order. An explicit stack replaces recursion;
        children are pushed in reverse so the first child is visited first.
        """
        collected: list[int] = []
        pending = [node]
        while pending:
            current = pending.pop()
            collected.extend(current.tensor_ids)
            pending.extend(reversed(current.children.values()))
        return collected
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: sft-cli
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: An interactive terminal browser for .safetensors files
|
|
5
|
+
Project-URL: Homepage, https://github.com/matanby/sft-cli
|
|
6
|
+
Project-URL: Repository, https://github.com/matanby/sft-cli
|
|
7
|
+
Author: Matan Ben-Yosef
|
|
8
|
+
License: MIT
|
|
9
|
+
Keywords: browser,cli,machine-learning,safetensors,tui
|
|
10
|
+
Classifier: Development Status :: 3 - Alpha
|
|
11
|
+
Classifier: Environment :: Console
|
|
12
|
+
Classifier: Intended Audience :: Developers
|
|
13
|
+
Classifier: Intended Audience :: Science/Research
|
|
14
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
15
|
+
Classifier: Programming Language :: Python :: 3
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
18
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
19
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
20
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
21
|
+
Classifier: Topic :: Utilities
|
|
22
|
+
Requires-Python: >=3.9
|
|
23
|
+
Requires-Dist: textual>=0.40
|
|
24
|
+
Requires-Dist: typer>=0.9
|
|
25
|
+
Description-Content-Type: text/markdown
|
|
26
|
+
|
|
27
|
+
# sft
|
|
28
|
+
|
|
29
|
+
An interactive terminal browser for `.safetensors` files.
|
|
30
|
+
|
|
31
|
+
## Installation
|
|
32
|
+
|
|
33
|
+
The recommended way to install `sft` is via [uv](https://docs.astral.sh/uv/):
|
|
34
|
+
|
|
35
|
+
```bash
|
|
36
|
+
uv tool install sft-cli
|
|
37
|
+
```
|
|
38
|
+
|
|
39
|
+
This makes `sft` available globally as a command-line tool.
|
|
40
|
+
|
|
41
|
+
Alternatively, install via pip:
|
|
42
|
+
|
|
43
|
+
```bash
|
|
44
|
+
pip install sft-cli
|
|
45
|
+
```
|
|
46
|
+
|
|
47
|
+
Or install from source:
|
|
48
|
+
|
|
49
|
+
```bash
|
|
50
|
+
git clone https://github.com/matanby/sft-cli
|
|
51
|
+
cd sft-cli
|
|
52
|
+
pip install -e .
|
|
53
|
+
```
|
|
54
|
+
|
|
55
|
+
## Usage
|
|
56
|
+
|
|
57
|
+
```bash
|
|
58
|
+
sft model.safetensors
|
|
59
|
+
```
|
|
60
|
+
|
|
61
|
+
## Features
|
|
62
|
+
|
|
63
|
+
- **Interactive TUI** — Browse tensors with keyboard navigation
|
|
64
|
+
- **Hierarchy View** — Tensors organized by namespace (e.g., `unet.down_blocks.0`)
|
|
65
|
+
- **Fast** — Header-only parsing, instant startup even for multi-GB files
|
|
66
|
+
- **Safe** — Read-only, never loads tensor data
|
|
67
|
+
- **Search** — Find tensors by name with `/`
|
|
68
|
+
- **Filter** — Filter by dtype with `f`
|
|
69
|
+
- **Sort** — Sort by name, size, or rank with `s`
|
|
70
|
+
- **Details** — View tensor details with `Space`
|
|
71
|
+
- **Metadata** — View file metadata with `m`
|
|
72
|
+
|
|
73
|
+
## Keybindings
|
|
74
|
+
|
|
75
|
+
### Navigation
|
|
76
|
+
| Key | Action |
|
|
77
|
+
|-----|--------|
|
|
78
|
+
| `↑`/`↓` | Move selection |
|
|
79
|
+
| `←`/`→` | Collapse/Expand tree node |
|
|
80
|
+
| `Enter` | Select/focus node |
|
|
81
|
+
| `Tab` | Switch between tree and table |
|
|
82
|
+
| `g`/`G` | Go to top/bottom |
|
|
83
|
+
|
|
84
|
+
### Search & Filter
|
|
85
|
+
| Key | Action |
|
|
86
|
+
|-----|--------|
|
|
87
|
+
| `/` | Start search |
|
|
88
|
+
| `f` | Open filter palette |
|
|
89
|
+
| `Esc` | Cancel search/close dialogs |
|
|
90
|
+
|
|
91
|
+
### Sorting
|
|
92
|
+
| Key | Action |
|
|
93
|
+
|-----|--------|
|
|
94
|
+
| `s` | Cycle sort mode (name ↑↓, size ↑↓, rank ↑↓) |
|
|
95
|
+
|
|
96
|
+
### Inspection
|
|
97
|
+
| Key | Action |
|
|
98
|
+
|-----|--------|
|
|
99
|
+
| `Space` | Show tensor details |
|
|
100
|
+
| `m` | Show file metadata |
|
|
101
|
+
|
|
102
|
+
### Application
|
|
103
|
+
| Key | Action |
|
|
104
|
+
|-----|--------|
|
|
105
|
+
| `q` | Quit |
|
|
106
|
+
|
|
107
|
+
## Technical Details
|
|
108
|
+
|
|
109
|
+
- **Header-only parsing** — sft reads only the safetensors header, never loading tensor data
|
|
110
|
+
- **Instant startup** — Even multi-GB model files open instantly
|
|
111
|
+
- **Memory efficient** — Uses minimal memory regardless of file size
|
|
112
|
+
|
|
113
|
+
## License
|
|
114
|
+
|
|
115
|
+
MIT
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
sft/__init__.py,sha256=vshzLW58M_PYp8ssp6k1iwo3Fhb263a2ACgC-LrSpEk,93
|
|
2
|
+
sft/browser.py,sha256=qIpHHI2vDB2N-a_Dm8-NJ_C_qmNgpPTT2j9IHGwKJPQ,30564
|
|
3
|
+
sft/cli.py,sha256=gFbaJHC5cJ2kIk3NJ5y3HV9YXnq3YRl5wcwfBpNDrPQ,1423
|
|
4
|
+
sft/index.py,sha256=PKhraBxQ9Vayz1Wsni59hjyo661-QStc3nQ0z-bPPvA,6284
|
|
5
|
+
sft_cli-0.1.0.dist-info/METADATA,sha256=oi_5tfW5UWbNXXV3MPUhxoEs0DJL_0fUGt_xvlfTtE8,2993
|
|
6
|
+
sft_cli-0.1.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
|
|
7
|
+
sft_cli-0.1.0.dist-info/entry_points.txt,sha256=t4jny1Xgeheb-L0GiTC71-BZMSMzkLT5mGfSwWJB1Vk,36
|
|
8
|
+
sft_cli-0.1.0.dist-info/RECORD,,
|