kiwi-array 0.2.47__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kiwi_array/__init__.py +9 -0
- kiwi_array/bridge.py +801 -0
- kiwi_array/cli.py +36 -0
- kiwi_array/notebook.py +268 -0
- kiwi_array-0.2.47.dist-info/METADATA +65 -0
- kiwi_array-0.2.47.dist-info/RECORD +9 -0
- kiwi_array-0.2.47.dist-info/WHEEL +5 -0
- kiwi_array-0.2.47.dist-info/entry_points.txt +2 -0
- kiwi_array-0.2.47.dist-info/top_level.txt +1 -0
kiwi_array/__init__.py
ADDED
kiwi_array/bridge.py
ADDED
|
@@ -0,0 +1,801 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import ctypes
|
|
4
|
+
import ctypes.util
|
|
5
|
+
import os
|
|
6
|
+
import platform
|
|
7
|
+
import shutil
|
|
8
|
+
import shlex
|
|
9
|
+
import subprocess
|
|
10
|
+
from dataclasses import dataclass
|
|
11
|
+
from importlib import metadata as importlib_metadata
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from typing import Any, Optional, Sequence, Union
|
|
14
|
+
|
|
15
|
+
try:
|
|
16
|
+
import numpy as _np
|
|
17
|
+
except ImportError: # pragma: no cover - optional dependency
|
|
18
|
+
_np = None
|
|
19
|
+
|
|
20
|
+
# Environment variables consulted when locating, building and preloading the native runtime.
LIB_ENV_VAR = "KIWI_BRIDGE_LIB"
LEGACY_LIB_ENV_VAR = "KIWI_JUPYTER_BRIDGE_LIB"
RUNTIME_ENV_VAR = "KIWI_RUNTIME"
CUDA_DRIVER_LIB_ENV_VAR = "KIWI_CUDA_DRIVER_LIB"
CUDA_PRELOAD_DRIVER_ENV_VAR = "KIWI_CUDA_PRELOAD_DRIVER"
SOURCE_ROOT_ENV_VAR = "KIWI_SOURCE_ROOT"
ZIG_ENV_VAR = "KIWI_ZIG"
MLX_PREFIX_ENV_VAR = "KIWI_MLX_PREFIX"
MLX_C_INCLUDE_ENV_VAR = "KIWI_MLX_C_INCLUDE"

# importlib.metadata entry-point group through which runtime packages publish descriptors.
RUNTIME_ENTRY_POINT_GROUP = "kiwi_array.runtimes"

# Device selector values handed to the native ``kiwi_session_create``.
DEVICE_VALUES = dict(auto=0, cpu=1, gpu=2)

# Native status codes -> readable names; the code is the position in this tuple.
STATUS_NAMES = dict(
    enumerate(
        (
            "ok",
            "parse",
            "type",
            "name",
            "domain",
            "rank",
            "nyi",
            "length",
            "index",
            "mlx",
            "device",
            "error",
            "oom",
        )
    )
)

# Native autograd-path codes -> readable names, again keyed by position.
AUTOGRAD_PATH_NAMES = dict(enumerate(("none", "mlx", "finite_difference")))

# Anything accepted where a filesystem path is expected.
PathLikeArg = Union[os.PathLike[str], str]
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class _CKiwiEvalResult(ctypes.Structure):
|
|
63
|
+
_fields_ = [
|
|
64
|
+
("status", ctypes.c_int),
|
|
65
|
+
("echoed", ctypes.c_bool),
|
|
66
|
+
("autograd_path", ctypes.c_int),
|
|
67
|
+
("text_ptr", ctypes.c_void_p),
|
|
68
|
+
("text_len", ctypes.c_size_t),
|
|
69
|
+
("display_mime_ptr", ctypes.c_void_p),
|
|
70
|
+
("display_mime_len", ctypes.c_size_t),
|
|
71
|
+
("display_data_ptr", ctypes.c_void_p),
|
|
72
|
+
("display_data_len", ctypes.c_size_t),
|
|
73
|
+
]
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
@dataclass(frozen=True)
class KiwiEvalResult:
    """Python-side, immutable view of one ``KiwiSession.eval`` outcome.

    ``status`` and ``autograd_path`` hold the readable names looked up from the native
    integer codes. The text fields are UTF-8 decoded copies of the native buffers, or
    ``None`` when the native side supplied no buffer.
    """

    # Readable status name such as "ok" or "parse".
    status: str
    # Native "echoed" flag, converted to bool.
    echoed: bool
    # Readable autograd path name such as "none", "mlx" or "finite_difference".
    autograd_path: str
    # Printed output, if any.
    text: Optional[str]
    # Optional rich-display MIME type and payload.
    display_mime: Optional[str] = None
    display_data: Optional[str] = None
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class KiwiBridgeError(RuntimeError):
    """Raised when the native Kiwi bridge cannot be located, loaded or used."""
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
@dataclass(frozen=True)
class RuntimeDescriptor:
    """Immutable description of one Kiwi runtime payload.

    Descriptors come from this package (the embedded "host" runtime) or from the
    ``kiwi_array.runtimes`` entry-point group. When two share a ``name``, the higher
    ``priority`` wins during discovery.
    """

    name: str
    priority: int
    backend: str
    accelerator: Optional[str]
    # Directory holding the bridge shared library.
    native_dir: Optional[Path]
    # Directory holding dependent shared libraries.
    lib_dir: Optional[Path]
    # Directory holding the ``kiwi`` executable.
    bin_dir: Optional[Path] = None
    # Where the descriptor came from (package or entry-point name).
    source: str = "unknown"

    def bridge_library_path(self) -> Optional[Path]:
        """Return the expected bridge library path, or ``None`` without a native dir."""
        native = self.native_dir
        return None if native is None else native / platform_library_name("kiwi_bridge")

    def cli_path(self) -> Optional[Path]:
        """Return the expected ``kiwi`` executable path, or ``None`` without a bin dir."""
        binaries = self.bin_dir
        return None if binaries is None else binaries / platform_executable_name("kiwi")
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _find_source_root() -> Optional[Path]:
    """Locate the Kiwi source checkout, if any.

    ``KIWI_SOURCE_ROOT`` (user-expanded, not resolved) takes precedence. Otherwise the
    nearest ancestor of this module that contains both ``build.zig`` and
    ``src/kiwi_bridge.zig`` is returned. Returns ``None`` when neither applies.
    """
    configured = os.environ.get(SOURCE_ROOT_ENV_VAR)
    if configured:
        return Path(configured).expanduser()

    def looks_like_checkout(directory: Path) -> bool:
        return (directory / "build.zig").is_file() and (directory / "src" / "kiwi_bridge.zig").is_file()

    return next(
        (ancestor for ancestor in Path(__file__).resolve().parents if looks_like_checkout(ancestor)),
        None,
    )
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def implementation_root() -> Path:
    """Return the Kiwi source root.

    Raises:
        KiwiBridgeError: no source root could be found.
    """
    found = _find_source_root()
    if found is not None:
        return found
    raise KiwiBridgeError(
        "Kiwi source root not found. Set KIWI_SOURCE_ROOT when building the "
        "bridge from source, or install a platform wheel that carries the "
        "native Kiwi runtime."
    )
|
|
132
|
+
|
|
133
|
+
|
|
134
|
+
def _has_workspace_dependency_layout() -> bool:
    """Report whether the source root sits inside a larger workspace.

    A workspace here means a directory two levels above the source root that contains
    ``vendor/mlx-c``.

    Returns ``False`` when there is no source root, or when the root has fewer than two
    ancestors. A relative or shallow ``KIWI_SOURCE_ROOT`` such as ``kiwi`` or ``/kiwi``
    has fewer than two; indexing ``parents[1]`` there would raise ``IndexError`` instead
    of answering "no".
    """
    root = _find_source_root()
    if root is None:
        return False
    ancestors = root.parents
    if len(ancestors) < 2:
        return False
    return (ancestors[1] / "vendor" / "mlx-c").exists()
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def repo_root() -> Path:
    """Return the repository root.

    This is the workspace two levels up in the workspace layout, otherwise the source
    root itself. Raises ``KiwiBridgeError`` when no source root exists.
    """
    implementation = implementation_root()
    if not _has_workspace_dependency_layout():
        return implementation
    return implementation.parents[1]
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def _package_root() -> Optional[Path]:
|
|
147
|
+
return Path(__file__).resolve().parent
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
def _packaged_native_dir() -> Optional[Path]:
    """Return the package's ``native`` subdirectory (it may not exist on disk)."""
    package_dir = _package_root()
    if package_dir is None:
        return None
    return package_dir / "native"
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def _packaged_library_dir() -> Optional[Path]:
    """Return the package's ``lib`` subdirectory (it may not exist on disk)."""
    package_dir = _package_root()
    if package_dir is None:
        return None
    return package_dir / "lib"
|
|
158
|
+
|
|
159
|
+
|
|
160
|
+
def _is_platform_dynamic_library_name(name: str) -> bool:
|
|
161
|
+
system = platform.system()
|
|
162
|
+
if system == "Darwin":
|
|
163
|
+
return name.endswith(".dylib")
|
|
164
|
+
if system == "Windows":
|
|
165
|
+
return name.endswith(".dll")
|
|
166
|
+
return ".so" in name
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def _dir_has_payload(path: Optional[Path]) -> bool:
    """Return True if ``path`` is a directory containing at least one shared library.

    Placeholder files (``.gitignore``, ``.gitkeep``) are ignored. Only direct children
    are inspected.
    """
    if path is not None and path.is_dir():
        placeholders = {".gitignore", ".gitkeep"}
        for entry in path.iterdir():
            if (
                entry.is_file()
                and entry.name not in placeholders
                and _is_platform_dynamic_library_name(entry.name)
            ):
                return True
    return False
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
def _path_from_descriptor(value: object) -> Optional[Path]:
|
|
181
|
+
if value in (None, ""):
|
|
182
|
+
return None
|
|
183
|
+
return Path(os.fspath(value)).expanduser()
|
|
184
|
+
|
|
185
|
+
|
|
186
|
+
def _coerce_runtime_descriptor(raw: object, source: str) -> RuntimeDescriptor:
    """Normalise an entry-point payload into a ``RuntimeDescriptor``.

    Args:
        raw: A ready ``RuntimeDescriptor`` (returned unchanged), or a ``dict`` with a
            required ``"name"`` key and optional ``priority``, ``backend``,
            ``accelerator``, ``native_dir``, ``lib_dir`` and ``bin_dir`` keys.
        source: Stored as the descriptor's ``source`` and used in error messages.

    Raises:
        TypeError: ``raw`` is neither a descriptor nor a dict.
        KeyError: the dict has no ``"name"`` key.
    """
    if isinstance(raw, RuntimeDescriptor):
        return raw
    if not isinstance(raw, dict):
        raise TypeError(f"runtime descriptor from {source} must be a dict")

    name = str(raw["name"])
    priority = int(raw.get("priority", 0))
    backend = str(raw.get("backend", "unknown"))
    accelerator_value = raw.get("accelerator")
    accelerator = str(accelerator_value) if accelerator_value is not None else None
    return RuntimeDescriptor(
        name=name,
        priority=priority,
        backend=backend,
        accelerator=accelerator,
        native_dir=_path_from_descriptor(raw.get("native_dir")),
        lib_dir=_path_from_descriptor(raw.get("lib_dir")),
        bin_dir=_path_from_descriptor(raw.get("bin_dir")),
        source=source,
    )
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
def _embedded_runtime_descriptor() -> RuntimeDescriptor:
    """Describe the "host" runtime shipped inside this package (priority 10).

    Its ``native``, ``lib`` and ``bin`` directories live next to this module; they
    may be absent from the installed wheel.
    """
    package_dir = _package_root()
    native, libraries, binaries = (package_dir / sub for sub in ("native", "lib", "bin"))
    return RuntimeDescriptor(
        name="host",
        priority=10,
        backend="host",
        accelerator=None,
        native_dir=native,
        lib_dir=libraries,
        bin_dir=binaries,
        source="kiwi_array",
    )
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def _entry_point_runtime_descriptors() -> list[RuntimeDescriptor]:
    """Collect descriptors advertised in the ``kiwi_array.runtimes`` entry-point group.

    Each entry point may resolve to a descriptor/dict or to a callable returning one.
    Discovery is best-effort: a broken plugin is skipped silently rather than
    preventing the others from loading.
    """
    try:
        all_entry_points = importlib_metadata.entry_points()
    except Exception:
        return []

    if not hasattr(all_entry_points, "select"):  # pragma: no cover - Python <3.10 compatibility path
        group = all_entry_points.get(RUNTIME_ENTRY_POINT_GROUP, [])
    else:
        group = all_entry_points.select(group=RUNTIME_ENTRY_POINT_GROUP)

    found: list[RuntimeDescriptor] = []
    for plugin in group:
        try:
            target = plugin.load()
            payload = target() if callable(target) else target
            found.append(_coerce_runtime_descriptor(payload, plugin.name))
        except Exception:
            continue
    return found
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
def discover_runtime_descriptors() -> list[RuntimeDescriptor]:
    """Return all known runtimes, de-duplicated by name.

    For duplicate names the highest priority wins; on a tie, the earlier one (the
    embedded host first) is kept. The result is ordered by descending priority, then
    by name.
    """
    best: dict[str, RuntimeDescriptor] = {}
    for candidate in (_embedded_runtime_descriptor(), *_entry_point_runtime_descriptors()):
        incumbent = best.get(candidate.name)
        if incumbent is not None and candidate.priority <= incumbent.priority:
            continue
        best[candidate.name] = candidate
    return sorted(best.values(), key=lambda runtime: (-runtime.priority, runtime.name))
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def _runtime_selection() -> str:
    """Return the runtime name requested via ``KIWI_RUNTIME``.

    An unset or blank value yields ``"auto"``.
    """
    stripped = os.environ.get(RUNTIME_ENV_VAR, "auto").strip()
    if not stripped:
        return "auto"
    return stripped
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def _selected_runtime_descriptors() -> list[RuntimeDescriptor]:
    """Return discovered runtimes, narrowed to the one named by ``KIWI_RUNTIME``.

    In ``auto`` mode every discovered runtime is returned.
    """
    wanted = _runtime_selection()
    available = discover_runtime_descriptors()
    if wanted != "auto":
        available = [runtime for runtime in available if runtime.name == wanted]
    return available
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
def active_runtime_descriptor() -> Optional[RuntimeDescriptor]:
    """Return the first selected runtime whose bridge library exists, else ``None``."""
    for candidate in _selected_runtime_descriptors():
        library = candidate.bridge_library_path()
        if library is None or not library.exists():
            continue
        return candidate
    return None
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def zig_executable() -> str:
    """Return the zig compiler to use for building the bridge.

    Resolution order:
    1. ``KIWI_ZIG`` or ``ZIG`` (user-expanded).
    2. ``tools/zig`` under the repository root, when a source checkout exists.
    3. ``zig`` found on ``PATH``.
    4. The bare name ``"zig"``.

    Outside a source checkout, ``repo_root()`` raises ``KiwiBridgeError``. Previously
    that error escaped before the ``PATH`` fallback could run; it now just skips
    step 2.
    """
    override = os.environ.get(ZIG_ENV_VAR) or os.environ.get("ZIG")
    if override:
        return str(Path(override).expanduser())

    try:
        bundled: Optional[Path] = repo_root() / "tools" / "zig"
    except KiwiBridgeError:
        bundled = None
    if bundled is not None and bundled.exists():
        return str(bundled)

    path_zig = shutil.which("zig")
    return path_zig if path_zig is not None else "zig"
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
def default_mlx_prefix() -> Path:
    """Return the default MLX install prefix.

    - Workspace layout: ``<repo>/.artifacts/mlx/macos-default-install``.
    - Otherwise: ``<source root>/.deps/mlx``.

    Raises ``KiwiBridgeError`` (via ``implementation_root``) when no source root exists.
    """
    if not _has_workspace_dependency_layout():
        return implementation_root() / ".deps" / "mlx"
    return repo_root() / ".artifacts" / "mlx" / "macos-default-install"
|
|
284
|
+
|
|
285
|
+
|
|
286
|
+
def default_mlx_c_include() -> Path:
    """Return the default mlx-c include directory.

    - Workspace layout: ``<repo>/vendor/mlx-c``.
    - Otherwise: ``<source root>/.deps/mlx-c``.
    """
    if not _has_workspace_dependency_layout():
        return implementation_root() / ".deps" / "mlx-c"
    return repo_root() / "vendor" / "mlx-c"
|
|
290
|
+
|
|
291
|
+
|
|
292
|
+
def mlx_prefix() -> Path:
    """Return ``KIWI_MLX_PREFIX`` when set and non-empty, else ``default_mlx_prefix()``.

    The result is user-expanded.
    """
    configured = os.environ.get(MLX_PREFIX_ENV_VAR)
    chosen = configured if configured else default_mlx_prefix()
    return Path(chosen).expanduser()
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def mlx_c_include() -> Path:
    """Return ``KIWI_MLX_C_INCLUDE`` when set and non-empty, else ``default_mlx_c_include()``.

    The result is user-expanded.
    """
    configured = os.environ.get(MLX_C_INCLUDE_ENV_VAR)
    chosen = configured if configured else default_mlx_c_include()
    return Path(chosen).expanduser()
|
|
304
|
+
|
|
305
|
+
|
|
306
|
+
def runtime_library_search_dir() -> Path:
    """Return the directory the dynamic loader should search for runtime dependencies.

    Resolution order:
    1. In ``auto`` mode with a locally built source bridge: ``<mlx prefix>/lib``.
    2. The active runtime's ``lib_dir``, if it contains shared libraries.
    3. ``<mlx prefix>/lib``, when ``KIWI_MLX_PREFIX`` is set or a source root exists.
    4. For an installed wheel with neither of those: the active runtime's ``lib_dir``,
       falling back to the packaged ``lib`` directory.

    Step 4 exists because ``mlx_prefix()`` raises ``KiwiBridgeError`` without a source
    root. Previously that broke library loading and the CLI launcher in installed
    wheels. The returned directory may not exist; callers only probe it.
    """
    source_root = _find_source_root()
    if _runtime_selection() == "auto" and source_root is not None:
        source_bridge = source_root / "zig-out" / "lib" / platform_library_name("kiwi_bridge")
        if source_bridge.exists():
            return mlx_prefix() / "lib"

    descriptor = active_runtime_descriptor()
    if descriptor is not None and descriptor.lib_dir is not None and _dir_has_payload(descriptor.lib_dir):
        return descriptor.lib_dir

    if os.environ.get(MLX_PREFIX_ENV_VAR) or source_root is not None:
        return mlx_prefix() / "lib"

    if descriptor is not None and descriptor.lib_dir is not None:
        return descriptor.lib_dir
    packaged = _packaged_library_dir()
    if packaged is not None:
        return packaged
    return mlx_prefix() / "lib"
|
|
317
|
+
|
|
318
|
+
|
|
319
|
+
def runtime_library_env_var() -> Optional[str]:
    """Return the loader search-path variable for this OS.

    ``DYLD_LIBRARY_PATH`` on macOS, ``LD_LIBRARY_PATH`` on Linux, ``None`` elsewhere.
    """
    loader_vars = {"Darwin": "DYLD_LIBRARY_PATH", "Linux": "LD_LIBRARY_PATH"}
    return loader_vars.get(platform.system())
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
def build_args(optimize: str) -> list[str]:
    """Return the ``zig build`` command line used to build the bridge, CLI and SDK.

    Args:
        optimize: Zig optimize mode, e.g. ``"ReleaseFast"``.
    """
    prefix = mlx_prefix()
    include_dir = mlx_c_include()
    fixed_flags = [
        "build",
        f"-Doptimize={optimize}",
        "-Dpublic-cli=true",
        "-Dcli-name=kiwi",
        "-Dinstall-sdk=true",
    ]
    return [
        zig_executable(),
        *fixed_flags,
        f"-Dmlx-prefix={prefix}",
        f"-Dmlx-c-include={include_dir}",
    ]
|
|
341
|
+
|
|
342
|
+
|
|
343
|
+
def platform_library_name(base_name: str) -> str:
    """Return the shared-library file name for ``base_name`` on this OS.

    ``lib<name>.dylib`` on macOS, ``<name>.dll`` on Windows, ``lib<name>.so`` elsewhere.
    """
    templates = {"Darwin": "lib{}.dylib", "Windows": "{}.dll"}
    return templates.get(platform.system(), "lib{}.so").format(base_name)
|
|
350
|
+
|
|
351
|
+
|
|
352
|
+
def platform_executable_name(base_name: str) -> str:
    """Return ``base_name`` with ``.exe`` appended on Windows, unchanged elsewhere."""
    if platform.system() != "Windows":
        return base_name
    return f"{base_name}.exe"
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
def default_cli_candidates() -> list[Path]:
    """Return ``kiwi`` executable paths to probe, in priority order.

    A locally built ``zig-out/bin/kiwi`` is tried first in ``auto`` mode and last
    when a specific runtime is selected. Selected runtimes' CLIs go in between.
    """
    root = _find_source_root()
    built_cli = None if root is None else root / "zig-out" / "bin" / platform_executable_name("kiwi")
    prefer_source = _runtime_selection() == "auto"

    ordered: list[Path] = []
    if prefer_source and built_cli is not None:
        ordered.append(built_cli)
    ordered.extend(
        cli
        for cli in (runtime.cli_path() for runtime in _selected_runtime_descriptors())
        if cli is not None
    )
    if not prefer_source and built_cli is not None:
        ordered.append(built_cli)
    return ordered
|
|
374
|
+
|
|
375
|
+
|
|
376
|
+
def find_cli_path() -> Path:
    """Return the first existing ``kiwi`` executable from ``default_cli_candidates()``.

    Raises:
        FileNotFoundError: no candidate exists; the message lists every probed path.
    """
    # Build the candidate list once. Building it imports runtime entry points, and
    # reusing it keeps the error message identical to what was actually probed.
    candidates = default_cli_candidates()
    for candidate in candidates:
        if candidate.exists():
            return candidate
    joined = ", ".join(str(candidate) for candidate in candidates)
    raise FileNotFoundError(
        "Kiwi CLI not found. Install a runtime payload such as "
        "`kiwi-array-host`, `kiwi-array-cpu`, `kiwi-array-metal`, or "
        "`kiwi-array-cuda12`. Searched: "
        f"{joined}"
    )
|
|
387
|
+
|
|
388
|
+
|
|
389
|
+
def default_library_candidates() -> list[Path]:
    """Return bridge library paths to probe, in priority order.

    Order:
    1. ``KIWI_BRIDGE_LIB`` and ``KIWI_JUPYTER_BRIDGE_LIB``, when set.
    2. The locally built ``zig-out/lib`` bridge (``auto`` mode only).
    3. Selected runtimes' bridges.
    4. The local build, when a specific runtime was requested.
    """
    env_values = (os.environ.get(LIB_ENV_VAR), os.environ.get(LEGACY_LIB_ENV_VAR))
    ordered = [Path(value).expanduser() for value in env_values if value]

    root = _find_source_root()
    built_bridge = None if root is None else root / "zig-out" / "lib" / platform_library_name("kiwi_bridge")
    prefer_source = _runtime_selection() == "auto"

    if prefer_source and built_bridge is not None:
        ordered.append(built_bridge)
    ordered.extend(
        bridge
        for bridge in (runtime.bridge_library_path() for runtime in _selected_runtime_descriptors())
        if bridge is not None
    )
    if not prefer_source and built_bridge is not None:
        ordered.append(built_bridge)
    return ordered
|
|
412
|
+
|
|
413
|
+
|
|
414
|
+
def find_library_path(path: Optional[PathLikeArg] = None) -> Path:
    """Locate the Kiwi bridge shared library.

    Args:
        path: Explicit library path. When given it must exist; no search is done.

    Returns:
        The explicit path, or the first existing entry of ``default_library_candidates()``.

    Raises:
        FileNotFoundError: the explicit path is missing, or no candidate exists. The
            message carries a build hint and the list of searched paths.
    """
    if path is not None:
        candidate = Path(path).expanduser()
        if candidate.exists():
            return candidate
        raise FileNotFoundError(f"Kiwi bridge library not found at {candidate}")

    # Build the candidate list once. Building it imports runtime entry points, and
    # reusing it keeps the error message identical to what was actually probed.
    candidates = default_library_candidates()
    for candidate in candidates:
        if candidate.exists():
            return candidate

    joined = ", ".join(str(candidate) for candidate in candidates)
    source_root = _find_source_root()
    if source_root is None:
        hint = "Install a platform wheel that carries the native Kiwi runtime."
    else:
        hint = f"Build it with `{shlex.join(build_args('ReleaseFast'))}` from {source_root}."
    raise FileNotFoundError(f"Kiwi bridge library not found. {hint} Searched: {joined}")
|
|
431
|
+
|
|
432
|
+
|
|
433
|
+
def build_bridge(optimize: str = "ReleaseFast") -> None:
    """Run ``zig build`` for the bridge inside the source root.

    Raises:
        subprocess.CalledProcessError: the build exits non-zero.
        KiwiBridgeError: no source root can be found.
    """
    command = build_args(optimize)
    subprocess.run(command, cwd=implementation_root(), check=True)
|
|
439
|
+
|
|
440
|
+
|
|
441
|
+
def _configure_library(lib: ctypes.CDLL) -> ctypes.CDLL:
    """Declare ctypes signatures for the bridge's exported functions and return ``lib``.

    ``kiwi_session_set_global_mlx_float_array`` is optional and only configured when
    the library exports it. Every other symbol is required; a missing one raises
    ``AttributeError`` from the attribute lookup.
    """

    def declare(symbol: str, argtypes: list, restype) -> None:
        function = getattr(lib, symbol)
        function.argtypes = argtypes
        function.restype = restype

    def array_setter_args(element_type) -> list:
        # (session handle, name bytes, name length, data pointer, dims pointer, ndim)
        return [
            ctypes.c_void_p,
            ctypes.c_char_p,
            ctypes.c_size_t,
            ctypes.POINTER(element_type),
            ctypes.POINTER(ctypes.c_int32),
            ctypes.c_size_t,
        ]

    declare("kiwi_session_create", [ctypes.c_int], ctypes.c_void_p)
    declare("kiwi_session_destroy", [ctypes.c_void_p], None)
    declare("kiwi_session_eval", [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_size_t], _CKiwiEvalResult)
    declare("kiwi_session_set_global_float_array", array_setter_args(ctypes.c_float), ctypes.c_int)
    if hasattr(lib, "kiwi_session_set_global_mlx_float_array"):
        declare("kiwi_session_set_global_mlx_float_array", array_setter_args(ctypes.c_float), ctypes.c_int)
    declare("kiwi_session_set_global_int_array", array_setter_args(ctypes.c_int32), ctypes.c_int)
    declare("kiwi_session_set_global_bool_array", array_setter_args(ctypes.c_bool), ctypes.c_int)
    declare("kiwi_eval_result_free", [_CKiwiEvalResult], None)
    declare("kiwi_status_name", [ctypes.c_int], ctypes.c_char_p)
    return lib
|
|
498
|
+
|
|
499
|
+
|
|
500
|
+
def _preload_runtime_dependencies(descriptor: Optional[RuntimeDescriptor] = None) -> None:
    """Load the bridge's shared-library dependencies with ``RTLD_GLOBAL`` where available.

    Steps:
    1. Optionally preload the CUDA driver.
    2. Probe, in order and without duplicates: the descriptor's ``lib_dir``,
       ``runtime_library_search_dir()``, the packaged ``lib`` directory and the
       source build's ``zig-out/lib``.
    3. Load each ``mlx`` / ``duckdb`` library that exists in those directories.
    """
    _preload_cuda_driver()

    directories: list[Path] = []
    if descriptor is not None and descriptor.lib_dir is not None:
        directories.append(descriptor.lib_dir)
    directories.append(runtime_library_search_dir())
    packaged_lib = _packaged_library_dir()
    if packaged_lib is not None:
        directories.append(packaged_lib)
    source_root = _find_source_root()
    if source_root is not None:
        directories.append(source_root / "zig-out" / "lib")

    load_mode = getattr(ctypes, "RTLD_GLOBAL", 0)
    library_names = [platform_library_name(base) for base in ("mlx", "duckdb")]
    # dict.fromkeys drops repeated directories while keeping first-seen order.
    for directory in dict.fromkeys(directories):
        for library_name in library_names:
            library = directory / library_name
            if library.exists():
                ctypes.CDLL(str(library), mode=load_mode)
|
|
523
|
+
|
|
524
|
+
|
|
525
|
+
def _preload_cuda_driver() -> None:
    """Opt-in, Linux-only preload of ``libcuda`` with ``RTLD_GLOBAL``.

    Enabled by setting ``KIWI_CUDA_PRELOAD_DRIVER`` to anything other than
    ``""``/``0``/``false``/``no``/``off``.

    Candidates are tried in order and the first that loads wins:
    1. ``KIWI_CUDA_DRIVER_LIB``.
    2. ``ctypes.util.find_library("cuda")``.
    3. Common driver locations.
    4. In ``compat``/``cuda-compat`` mode, the CUDA forward-compat library.

    Absolute paths that do not exist are skipped. Load failures (``OSError``) are
    ignored.
    """
    if platform.system() != "Linux":
        return

    mode = os.environ.get(CUDA_PRELOAD_DRIVER_ENV_VAR, "0").strip().lower()
    if mode in {"", "0", "false", "no", "off"}:
        return

    override = os.environ.get(CUDA_DRIVER_LIB_ENV_VAR)
    candidates = [Path(override).expanduser()] if override else []
    located = ctypes.util.find_library("cuda")
    if located:
        candidates.append(Path(located))
    candidates += [
        Path(known)
        for known in (
            "/usr/lib64-nvidia/libcuda.so.1",
            "/usr/local/nvidia/lib64/libcuda.so.1",
            "/usr/lib/x86_64-linux-gnu/libcuda.so.1",
        )
    ]
    if mode in {"compat", "cuda-compat"}:
        candidates.append(Path("/usr/local/cuda/compat/libcuda.so.1"))

    load_mode = getattr(ctypes, "RTLD_GLOBAL", 0)
    for candidate in candidates:
        try:
            # Relative names (e.g. "libcuda.so.1") are left to the dynamic linker's search.
            if candidate.is_absolute() and not candidate.exists():
                continue
            ctypes.CDLL(str(candidate), mode=load_mode)
        except OSError:
            continue
        return
|
|
560
|
+
|
|
561
|
+
|
|
562
|
+
def _load_library_at(path: Path, descriptor: Optional[RuntimeDescriptor] = None) -> ctypes.CDLL:
    """Preload dependencies, then open and configure the bridge at ``path``."""
    _preload_runtime_dependencies(descriptor)
    handle = ctypes.CDLL(str(path))
    return _configure_library(handle)
|
|
565
|
+
|
|
566
|
+
|
|
567
|
+
def load_library(path: Optional[PathLikeArg] = None) -> ctypes.CDLL:
    """Load and configure the Kiwi bridge library.

    An explicit ``path``, ``KIWI_BRIDGE_LIB`` or ``KIWI_JUPYTER_BRIDGE_LIB`` bypasses
    runtime resolution. Otherwise the following are tried in order:
    1. A local source build (``auto`` mode only).
    2. Each selected runtime's bridge.
    3. The local source build (when a specific runtime was requested).

    Raises:
        KiwiBridgeError: a specifically selected runtime fails to load, or every
            attempted candidate failed.
        FileNotFoundError: nothing was attempted and no library can be found.
    """
    explicit = path is not None or os.environ.get(LIB_ENV_VAR) or os.environ.get(LEGACY_LIB_ENV_VAR)
    if explicit:
        return _load_library_at(find_library_path(path))

    selection = _runtime_selection()
    is_auto = selection == "auto"
    source_root = _find_source_root()
    failures: list[str] = []

    def load_source_build() -> Optional[ctypes.CDLL]:
        # Returns None when there is no usable local build (failures are recorded).
        if source_root is None:
            return None
        built = source_root / "zig-out" / "lib" / platform_library_name("kiwi_bridge")
        if not built.exists():
            return None
        try:
            return _load_library_at(built)
        except OSError as exc:
            failures.append(f"source: {exc}")
            return None

    if is_auto:
        loaded = load_source_build()
        if loaded is not None:
            return loaded

    for descriptor in _selected_runtime_descriptors():
        bridge = descriptor.bridge_library_path()
        if bridge is None or not bridge.exists():
            continue
        try:
            return _load_library_at(bridge, descriptor)
        except OSError as exc:
            failures.append(f"{descriptor.name}: {exc}")
            if not is_auto:
                raise KiwiBridgeError(f"failed to load Kiwi runtime {descriptor.name!r}: {exc}") from exc

    if not is_auto:
        loaded = load_source_build()
        if loaded is not None:
            return loaded

    if failures:
        raise KiwiBridgeError("failed to load any Kiwi runtime: " + "; ".join(failures))
    return _load_library_at(find_library_path())
|
|
604
|
+
|
|
605
|
+
|
|
606
|
+
def _normalize_dims(dims: Optional[Sequence[int]], data) -> tuple[int, ...]:
|
|
607
|
+
if dims is not None:
|
|
608
|
+
return tuple(int(dim) for dim in dims)
|
|
609
|
+
|
|
610
|
+
shape = getattr(data, "shape", None)
|
|
611
|
+
if shape is None:
|
|
612
|
+
raise TypeError("dims are required unless data exposes a shape")
|
|
613
|
+
return tuple(int(dim) for dim in shape)
|
|
614
|
+
|
|
615
|
+
|
|
616
|
+
def _product(dims: Sequence[int]) -> int:
|
|
617
|
+
total = 1
|
|
618
|
+
for dim in dims:
|
|
619
|
+
if dim < 0:
|
|
620
|
+
raise ValueError(f"negative dimension {dim}")
|
|
621
|
+
total *= dim
|
|
622
|
+
return total
|
|
623
|
+
|
|
624
|
+
|
|
625
|
+
def _sequence_values(data, count: int):
|
|
626
|
+
if count == 1 and not isinstance(data, (list, tuple)):
|
|
627
|
+
return [data]
|
|
628
|
+
values = list(data)
|
|
629
|
+
if len(values) != count:
|
|
630
|
+
raise ValueError(f"expected {count} items, got {len(values)}")
|
|
631
|
+
return values
|
|
632
|
+
|
|
633
|
+
|
|
634
|
+
def _prepare_array_data(data, dims: tuple[int, ...], c_type, numpy_dtype):
    """Materialise ``data`` as a flat, contiguous C buffer of ``c_type``.

    Returns:
        ``(owner, pointer)``. The caller must keep ``owner`` alive while the native
        side reads ``pointer``. ``pointer`` is ``None`` for zero-element arrays.

    Raises:
        ValueError: the element count does not match ``dims``.
    """
    count = _product(dims)
    if _np is None:
        buffer = (c_type * count)(*_sequence_values(data, count))
        return buffer, (buffer if count else None)

    array = _np.asarray(data, dtype=numpy_dtype)
    if array.size != count:
        raise ValueError(f"expected {count} items, got {array.size}")
    contiguous = _np.ascontiguousarray(array.reshape(-1))
    pointer = contiguous.ctypes.data_as(ctypes.POINTER(c_type)) if count else None
    return contiguous, pointer
|
|
648
|
+
|
|
649
|
+
|
|
650
|
+
class KiwiSession:
    """Interactive Kiwi session backed by the native ``kiwi_bridge`` library.

    Evaluate source with :meth:`eval` and bind host arrays as globals with the
    ``set_global_*`` methods. Release the native handle with :meth:`close`, or use the
    session as a context manager.
    """

    def __init__(self, device: str = "auto", library_path: Optional[PathLikeArg] = None) -> None:
        """Load the bridge library and create a native session.

        Args:
            device: ``"auto"``, ``"cpu"`` or ``"gpu"`` (keys of ``DEVICE_VALUES``).
            library_path: Optional explicit bridge library path, forwarded to
                ``load_library``.

        Raises:
            ValueError: unknown ``device``.
            KiwiBridgeError: the native side returned a null session handle.
        """
        # Bind the handle before anything can fail. Then close()/_ensure_open() behave
        # sensibly on a half-constructed session instead of raising AttributeError.
        self._handle = None
        try:
            device_value = DEVICE_VALUES[device]
        except KeyError as exc:
            raise ValueError(f"unknown Kiwi device {device!r}") from exc

        self._lib = load_library(library_path)
        self._handle = self._lib.kiwi_session_create(device_value)
        if not self._handle:
            raise KiwiBridgeError("failed to create Kiwi session")

    def close(self) -> None:
        """Destroy the native session. Calling it again is a no-op."""
        if self._handle:
            self._lib.kiwi_session_destroy(self._handle)
            self._handle = None

    @staticmethod
    def _read_utf8(pointer, length) -> Optional[str]:
        """Decode ``length`` bytes at ``pointer`` as UTF-8, or return ``None`` for NULL."""
        if not pointer:
            return None
        return ctypes.string_at(pointer, length).decode("utf-8")

    def eval(self, source: str) -> KiwiEvalResult:
        """Evaluate Kiwi ``source`` and return its decoded result.

        The native result is always released with ``kiwi_eval_result_free``, even if
        decoding fails.

        Raises:
            KiwiBridgeError: the session is closed.
            UnicodeDecodeError: a native buffer is not valid UTF-8.
        """
        self._ensure_open()
        encoded = source.encode("utf-8")
        result = self._lib.kiwi_session_eval(self._handle, encoded, len(encoded))
        try:
            text = self._read_utf8(result.text_ptr, result.text_len)
            display_mime = self._read_utf8(result.display_mime_ptr, result.display_mime_len)
            display_data = self._read_utf8(result.display_data_ptr, result.display_data_len)
            return KiwiEvalResult(
                status=STATUS_NAMES.get(result.status, "error"),
                echoed=bool(result.echoed),
                autograd_path=AUTOGRAD_PATH_NAMES.get(result.autograd_path, "none"),
                text=text,
                display_mime=display_mime,
                display_data=display_data,
            )
        finally:
            self._lib.kiwi_eval_result_free(result)

    def set_global_float_array(self, name: str, data, dims: Optional[Sequence[int]] = None) -> None:
        """Bind ``data`` as a float32 global named ``name``."""
        self._set_global_array(
            "kiwi_session_set_global_float_array",
            name,
            data,
            dims,
            ctypes.c_float,
            _np.float32 if _np is not None else None,
        )

    def set_global_mlx_float_array(self, name: str, data, dims: Optional[Sequence[int]] = None) -> None:
        """Bind ``data`` as a float32 global via the bridge's MLX-specific setter.

        Raises ``KiwiBridgeError`` if the loaded bridge does not export that setter.
        """
        self._set_global_array(
            "kiwi_session_set_global_mlx_float_array",
            name,
            data,
            dims,
            ctypes.c_float,
            _np.float32 if _np is not None else None,
        )

    def set_global_int_array(self, name: str, data, dims: Optional[Sequence[int]] = None) -> None:
        """Bind ``data`` as an int32 global named ``name``."""
        self._set_global_array(
            "kiwi_session_set_global_int_array",
            name,
            data,
            dims,
            ctypes.c_int32,
            _np.int32 if _np is not None else None,
        )

    def set_global_bool_array(self, name: str, data, dims: Optional[Sequence[int]] = None) -> None:
        """Bind ``data`` as a boolean global named ``name``."""
        self._set_global_array(
            "kiwi_session_set_global_bool_array",
            name,
            data,
            dims,
            ctypes.c_bool,
            _np.bool_ if _np is not None else None,
        )

    def __enter__(self) -> "KiwiSession":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def _ensure_open(self) -> None:
        """Raise ``KiwiBridgeError`` if the session has been closed."""
        if not self._handle:
            raise KiwiBridgeError("Kiwi session is closed")

    def _set_global_array(
        self,
        fn_name: str,
        name: str,
        data,
        dims: Optional[Sequence[int]],
        c_type,
        numpy_dtype,
    ) -> None:
        """Shared implementation of the ``set_global_*`` methods.

        Raises:
            KiwiBridgeError: the session is closed, the setter is not exported, or the
                native call returns a non-zero status.
            ValueError: a dimension is negative or exceeds int32, or the element count
                does not match ``dims``.
            TypeError: ``dims`` is omitted and ``data`` has no ``shape``.
        """
        self._ensure_open()
        if not hasattr(self._lib, fn_name):
            raise KiwiBridgeError(f"native Kiwi bridge does not expose {fn_name}")
        dims_tuple = _normalize_dims(dims, data)
        # Dims cross the FFI boundary as int32. ctypes does not range-check integer
        # initialisers, so an oversized extent would silently wrap.
        int32_max = 2**31 - 1
        for dim in dims_tuple:
            if dim > int32_max:
                raise ValueError(f"dimension {dim} exceeds the int32 range supported by the Kiwi bridge")
        owner, data_ptr = _prepare_array_data(data, dims_tuple, c_type, numpy_dtype)
        dims_owner = (ctypes.c_int32 * len(dims_tuple))(*dims_tuple)
        dims_ptr = None if len(dims_tuple) == 0 else dims_owner
        encoded = name.encode("utf-8")
        status = getattr(self._lib, fn_name)(
            self._handle,
            encoded,
            len(encoded),
            data_ptr,
            dims_ptr,
            len(dims_tuple),
        )
        # Reference the buffer owner after the call so it is kept alive for the
        # whole native read.
        _ = owner
        if status != 0:
            raise KiwiBridgeError(f"failed to set global {name!r}: {STATUS_NAMES.get(status, 'error')}")
|
|
776
|
+
|
|
777
|
+
|
|
778
|
+
# Public API of kiwi_array.bridge (order preserved).
__all__ = [
    "AUTOGRAD_PATH_NAMES", "KiwiBridgeError", "KiwiEvalResult", "KiwiSession",
    "LEGACY_LIB_ENV_VAR", "LIB_ENV_VAR", "RUNTIME_ENV_VAR", "RuntimeDescriptor",
    "STATUS_NAMES", "active_runtime_descriptor", "build_bridge",
    "default_library_candidates", "default_cli_candidates",
    "discover_runtime_descriptors", "find_cli_path", "find_library_path",
    "implementation_root", "load_library", "platform_library_name",
    "platform_executable_name", "runtime_library_env_var",
    "runtime_library_search_dir",
]
|
kiwi_array/cli.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
import subprocess
|
|
5
|
+
import sys
|
|
6
|
+
|
|
7
|
+
from . import bridge
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _launcher_environment() -> dict[str, str]:
    """Return a copy of ``os.environ`` prepared for running the native CLI.

    The runtime library directory is prepended to the platform's loader search
    variable. The environment is returned unchanged on platforms that have none.
    """
    environment = dict(os.environ)
    variable = bridge.runtime_library_env_var()
    if variable is not None:
        search_dir = str(bridge.runtime_library_search_dir())
        previous = environment.get(variable)
        environment[variable] = os.pathsep.join([search_dir, previous]) if previous else search_dir
    return environment
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def main() -> int:
    """Launch the native Kiwi CLI located by ``bridge.find_cli_path()``.

    Arguments after the program name are forwarded unchanged, and the child
    environment comes from ``_launcher_environment()`` (runtime library
    directory prepended to the loader search path).

    On POSIX the current process image is replaced with ``os.execve`` and this
    function does not return. On other platforms the CLI runs as a child
    process and its exit status is returned.

    Returns:
        1 if no CLI binary could be found (the error is printed to stderr);
        otherwise, on non-POSIX platforms, the child's exit status.
    """
    try:
        cli = bridge.find_cli_path()
    except FileNotFoundError as exc:
        print(exc, file=sys.stderr)
        return 1

    executable = str(cli)
    argv = [executable, *sys.argv[1:]]
    env = _launcher_environment()
    # os.execve exists on Windows as well, so hasattr() is not a usable test.
    # There it spawns a new process and lets this one exit immediately, which
    # loses the exit status and does not quote arguments containing spaces.
    # Only replace the process image where exec has real POSIX semantics.
    if os.name == "posix":
        os.execve(executable, argv, env)
    completed = subprocess.run(argv, env=env)
    return int(completed.returncode)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())
|
kiwi_array/notebook.py
ADDED
|
@@ -0,0 +1,268 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
import platform
|
|
6
|
+
import subprocess
|
|
7
|
+
import sys
|
|
8
|
+
import uuid
|
|
9
|
+
from dataclasses import dataclass
|
|
10
|
+
from typing import Iterator, Protocol
|
|
11
|
+
|
|
12
|
+
from kiwi_array.bridge import KiwiEvalResult, KiwiSession, find_library_path, runtime_library_search_dir
|
|
13
|
+
|
|
14
|
+
from . import __version__
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class SessionProtocol(Protocol):
    """Structural type for anything the notebook helpers can evaluate Kiwi with.

    ``KiwiSession`` satisfies it; tests can supply any object with matching
    ``eval`` and ``close`` methods.
    """

    def eval(self, source: str) -> KiwiEvalResult:
        """Evaluate one line of Kiwi source and return its result."""

    def close(self) -> None:
        """Close the session."""
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass(frozen=True)
|
|
26
|
+
class CellOutput:
|
|
27
|
+
line_no: int
|
|
28
|
+
text: str
|
|
29
|
+
autograd_path: str
|
|
30
|
+
display_mime: str | None = None
|
|
31
|
+
display_data: str | None = None
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
@dataclass(frozen=True, eq=True)
class SmokeCheck:
    """Outcome of a single ``%ksmoke`` check."""

    name: str  # short label, e.g. "scalar" or "mlx_global"
    ok: bool  # whether the check produced the expected text
    detail: str  # observed text, or a "!"-prefixed failure description
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class KiwiNotebookError(RuntimeError):
    """Raised when a line of a Kiwi notebook cell evaluates to a non-ok status.

    Attributes:
        line_no: 1-based line number within the cell that failed.
        status: the status string reported by the session for that line.
    """

    def __init__(self, line_no: int, status: str) -> None:
        self.line_no = line_no
        self.status = status
        super().__init__(f"line {line_no}: !{status}")

    def __reduce__(self):
        # BaseException pickles/copies itself as ``cls(*self.args)``, but args
        # only holds the formatted message, which does not fit this __init__
        # signature. Rebuild from the real constructor arguments instead and
        # keep any extra instance state.
        return (type(self), (self.line_no, self.status), self.__dict__)
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def iter_executable_lines(code: str) -> Iterator[tuple[int, str]]:
    """Yield ``(line_no, line)`` for every line of *code* that should be evaluated.

    Lines are split on ``\\n``, ``\\r\\n`` and ``\\r`` only. ``str.splitlines``
    is deliberately avoided: it also breaks on ``\\x0b``, ``\\x0c``,
    ``\\x1c``-``\\x1e``, ``\\x85`` and U+2028/U+2029. That would cut a Kiwi
    statement containing one of those characters (e.g. inside a string) in two
    and shift every following line number.

    Skipped lines are:

    * lines that are empty or contain only spaces/tabs;
    * lines whose first non-space/tab character is ``/`` (treated as comments).

    Line numbers are 1-based and count skipped lines too, so they refer to the
    cell as written. The yielded line keeps its original leading whitespace.
    """
    normalized = code.replace("\r\n", "\n").replace("\r", "\n")
    for line_no, line in enumerate(normalized.split("\n"), start=1):
        trimmed = line.strip(" \t")
        if not trimmed or trimmed.startswith("/"):
            continue
        yield line_no, line
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def execute_cell(session: SessionProtocol, code: str) -> list[CellOutput]:
    """Evaluate a notebook cell line by line and collect its echoed outputs.

    Each line from ``iter_executable_lines(code)`` is passed to
    ``session.eval`` in order. Evaluation stops at the first result whose
    status is not ``"ok"``: a ``KiwiNotebookError`` carrying that line number
    and status is raised, and earlier lines have already been evaluated.

    Only results that are echoed and carry text become ``CellOutput`` entries.
    """
    collected: list[CellOutput] = []
    for line_no, source_line in iter_executable_lines(code):
        evaluated = session.eval(source_line)
        if evaluated.status != "ok":
            raise KiwiNotebookError(line_no, evaluated.status)
        # Silent results (not echoed, or without text) produce no output.
        if not evaluated.echoed or evaluated.text is None:
            continue
        collected.append(
            CellOutput(
                line_no=line_no,
                text=evaluated.text,
                autograd_path=evaluated.autograd_path,
                display_mime=evaluated.display_mime,
                display_data=evaluated.display_data,
            )
        )
    return collected
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def display_bundle_for_output(output: CellOutput) -> dict[str, object]:
    """Build the MIME bundle passed to IPython's ``display`` for one output.

    The bundle always contains ``text/plain``. If the output carries both a
    MIME type and a payload that parses as JSON, the parsed payload is added
    under that MIME type. Vega-Lite payloads (MIME type containing
    ``"vegalite"``) additionally get a ``text/html`` fallback. Payloads that are
    not valid JSON are ignored and only the plain text is shown.
    """
    plain: dict[str, object] = {"text/plain": output.text}
    mime = output.display_mime
    raw = output.display_data
    if mime is None or raw is None:
        return plain
    try:
        decoded = json.loads(raw)
    except json.JSONDecodeError:
        return plain
    else:
        rich: dict[str, object] = {**plain, mime: decoded}
        if "vegalite" in mime:
            rich["text/html"] = vegalite_html(decoded)
        return rich
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
def vegalite_embed_options() -> dict[str, object]:
    """Return the options object passed to ``vegaEmbed``.

    Actions are hidden and the canvas renderer is used. A theme is added only
    when ``KIWI_VEGALITE_THEME`` is set to something other than blank,
    ``default`` or ``none`` (case-insensitive, surrounding whitespace ignored).
    """
    requested = os.environ.get("KIWI_VEGALITE_THEME", "").strip()
    options: dict[str, object] = {"actions": False, "renderer": "canvas"}
    if not requested or requested.lower() in ("default", "none"):
        return options
    options["theme"] = requested
    return options
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def vegalite_html(spec: object) -> str:
|
|
99
|
+
element_id = f"kiwi-vegalite-{uuid.uuid4().hex}"
|
|
100
|
+
spec_json = json.dumps(spec, separators=(",", ":"))
|
|
101
|
+
options_json = json.dumps(vegalite_embed_options(), separators=(",", ":"))
|
|
102
|
+
return f"""
|
|
103
|
+
<div id="{element_id}"></div>
|
|
104
|
+
<script src="https://cdn.jsdelivr.net/npm/vega@5"></script>
|
|
105
|
+
<script src="https://cdn.jsdelivr.net/npm/vega-lite@5"></script>
|
|
106
|
+
<script src="https://cdn.jsdelivr.net/npm/vega-embed@6"></script>
|
|
107
|
+
<script>
|
|
108
|
+
(function() {{
|
|
109
|
+
const spec = {spec_json};
|
|
110
|
+
const render = function() {{
|
|
111
|
+
if (!window.vegaEmbed) {{
|
|
112
|
+
document.getElementById("{element_id}").textContent = {json.dumps("Vega-Embed failed to load.")};
|
|
113
|
+
return;
|
|
114
|
+
}}
|
|
115
|
+
window.vegaEmbed("#{element_id}", spec, {options_json});
|
|
116
|
+
}};
|
|
117
|
+
render();
|
|
118
|
+
}})();
|
|
119
|
+
</script>
|
|
120
|
+
"""
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def _safe_value(label: str, value_fn) -> tuple[str, str]:
|
|
124
|
+
try:
|
|
125
|
+
value = str(value_fn())
|
|
126
|
+
except Exception as exc: # pragma: no cover - defensive diagnostic path
|
|
127
|
+
value = f"!{type(exc).__name__}: {exc}"
|
|
128
|
+
return label, value
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def _nvidia_smi_summary() -> str:
|
|
132
|
+
if platform.system() != "Linux":
|
|
133
|
+
return "not linux"
|
|
134
|
+
try:
|
|
135
|
+
result = subprocess.run(
|
|
136
|
+
[
|
|
137
|
+
"nvidia-smi",
|
|
138
|
+
"--query-gpu=name,driver_version,memory.total",
|
|
139
|
+
"--format=csv,noheader,nounits",
|
|
140
|
+
],
|
|
141
|
+
check=False,
|
|
142
|
+
capture_output=True,
|
|
143
|
+
text=True,
|
|
144
|
+
timeout=3,
|
|
145
|
+
)
|
|
146
|
+
except (OSError, subprocess.TimeoutExpired):
|
|
147
|
+
return "not found"
|
|
148
|
+
if result.returncode != 0:
|
|
149
|
+
message = (result.stderr or result.stdout).strip()
|
|
150
|
+
return message or f"exit {result.returncode}"
|
|
151
|
+
return result.stdout.strip() or "no gpu"
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def collect_info() -> list[tuple[str, str]]:
    """Gather ``(label, value)`` rows describing the Kiwi notebook environment.

    Rows cover the package, Python and platform versions, the bridge library
    and runtime library locations (errors are reported inline via
    ``_safe_value``), the device and Vega-Lite theme environment settings,
    and the ``nvidia-smi`` summary.
    """
    environ = os.environ
    rows: list[tuple[str, str]] = [
        ("kiwi-array", __version__),
        ("python", sys.version.split()[0]),
        ("platform", platform.platform()),
    ]
    rows.append(_safe_value("bridge", find_library_path))
    rows.append(_safe_value("runtime_libs", runtime_library_search_dir))
    rows.extend(
        [
            ("device_env", environ.get("KIWI_JUPYTER_DEVICE", "auto")),
            ("vegalite_theme", environ.get("KIWI_VEGALITE_THEME", "default")),
            ("nvidia_smi", _nvidia_smi_summary()),
        ]
    )
    return rows
|
|
165
|
+
|
|
166
|
+
|
|
167
|
+
def format_info_text(info: list[tuple[str, str]]) -> str:
    """Render ``(label, value)`` rows as ``label: value`` lines.

    Labels are right-aligned to the longest one so the colons line up. An
    empty list renders as an empty string.
    """
    if not info:
        return ""
    pad = max(len(label) for label, _value in info)
    lines = [f"{label:>{pad}}: {value}" for label, value in info]
    return "\n".join(lines)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def run_smoke(session: SessionProtocol, include_mlx: bool = False) -> list[SmokeCheck]:
    """Run quick end-to-end checks against *session* and report each outcome.

    The basic checks evaluate ``1+1``, ``+/1 2 3`` and
    ``grad[{+/(x*x)}][1 2 3]`` and compare the result text with ``2``, ``6``
    and ``2 4 6``. With ``include_mlx=True``, an extra ``mlx_global`` check calls
    the session's ``set_global_mlx_float_array`` (if it has one) to bind ``kx``
    and then evaluates the same gradient over ``kx``.

    Every check yields a :class:`SmokeCheck`. An exception raised while running
    any check is recorded as a failure with detail ``"!ExcType: message"``
    instead of aborting the remaining checks.
    """

    def _evaluate(name: str, source: str, expected: str) -> SmokeCheck:
        # Evaluate one source line and compare its text with *expected*; any
        # exception (from eval or while reading the result) becomes a failure.
        try:
            result = session.eval(source)
            ok = result.status == "ok" and result.text == expected
            if result.status == "ok" and result.text is not None:
                detail = result.text
            else:
                detail = f"!{result.status}"
        except Exception as exc:
            return SmokeCheck(name=name, ok=False, detail=f"!{type(exc).__name__}: {exc}")
        return SmokeCheck(name=name, ok=ok, detail=detail)

    checks: list[SmokeCheck] = [
        _evaluate("scalar", "1+1", "2"),
        _evaluate("vector", "+/1 2 3", "6"),
        _evaluate("grad", "grad[{+/(x*x)}][1 2 3]", "2 4 6"),
    ]
    if not include_mlx:
        return checks

    setter = getattr(session, "set_global_mlx_float_array", None)
    if setter is None:
        checks.append(SmokeCheck(name="mlx_global", ok=False, detail="native bridge lacks MLX setter"))
        return checks
    try:
        setter("kx", [1.0, 2.0, 3.0], (3,))
    except Exception as exc:
        checks.append(SmokeCheck(name="mlx_global", ok=False, detail=f"!{type(exc).__name__}: {exc}"))
        return checks
    checks.append(_evaluate("mlx_global", "grad[{+/(x*x)}][kx]", "2 4 6"))
    return checks
|
|
199
|
+
|
|
200
|
+
|
|
201
|
+
def format_smoke_text(checks: list[SmokeCheck]) -> str:
    """Render smoke checks as ``name: ok|error detail`` lines.

    Names are right-aligned to the longest name. No checks render as "".
    """
    pad = max((len(c.name) for c in checks), default=0)
    return "\n".join(
        f"{c.name.rjust(pad)}: {'ok' if c.ok else 'error'} {c.detail}" for c in checks
    )
|
|
208
|
+
|
|
209
|
+
|
|
210
|
+
# KiwiMagics instances registered by load_ipython_extension, keyed by id() of
# the shell they were registered on, so unload_ipython_extension (and a repeated
# load) can close the native session each instance may have opened.
_REGISTERED_MAGICS: dict[int, object] = {}


def load_ipython_extension(ipython) -> None:
    """Register the Kiwi magics on the IPython shell *ipython*.

    Registered magics:
      * ``%k``/``%%k`` and ``%kiwi``/``%%kiwi``: evaluate Kiwi code
        (``%k --reset`` or ``%k -r`` closes the current session instead);
      * ``%kinfo``: print environment diagnostics;
      * ``%ksmoke``: run smoke checks (``--gpu``/``--mlx`` adds the MLX check).

    All magics on one shell share a single ``KiwiSession``, created on first
    use from ``KIWI_JUPYTER_DEVICE`` (default ``"auto"``) and
    ``KIWI_BRIDGE_LIB``.
    """
    from IPython.core.magic import Magics, line_cell_magic, line_magic, magics_class
    from IPython.display import display

    @magics_class
    class KiwiMagics(Magics):
        def __init__(self, shell) -> None:
            super().__init__(shell)
            # Opened lazily by _kiwi_session(); cleared by _reset().
            self._session: KiwiSession | None = None

        def _kiwi_session(self) -> KiwiSession:
            # Environment variables are read when the session is first needed,
            # not when the extension is loaded.
            if self._session is None:
                device = os.environ.get("KIWI_JUPYTER_DEVICE", "auto")
                library_path = os.environ.get("KIWI_BRIDGE_LIB")
                self._session = KiwiSession(device=device, library_path=library_path)
            return self._session

        def _reset(self) -> None:
            # Forget the session before closing it, so a failing close() does
            # not leave a half-closed session cached for the next magic.
            session, self._session = self._session, None
            if session is not None:
                session.close()

        def _run(self, line: str, cell: str | None) -> None:
            # The reset flag is only recognised in the one-line form.
            if cell is None and line.strip() in {"--reset", "-r"}:
                self._reset()
                return
            code = line if cell is None else cell
            outputs = execute_cell(self._kiwi_session(), code)
            for output in outputs:
                display(
                    display_bundle_for_output(output),
                    raw=True,
                    metadata={"kiwi_autograd_path": output.autograd_path},
                )

        @line_cell_magic
        def kiwi(self, line: str, cell: str | None = None) -> None:
            self._run(line, cell)

        @line_cell_magic
        def k(self, line: str, cell: str | None = None) -> None:
            self._run(line, cell)

        @line_magic
        def kinfo(self, line: str) -> None:
            del line
            print(format_info_text(collect_info()))

        @line_magic
        def ksmoke(self, line: str) -> None:
            flags = set(line.split())
            include_mlx = bool(flags & {"--mlx", "--gpu"})
            print(format_smoke_text(run_smoke(self._kiwi_session(), include_mlx=include_mlx)))

    # If this shell already has an instance from an earlier load, close its
    # session before the new instance replaces it.
    previous = _REGISTERED_MAGICS.pop(id(ipython), None)
    if previous is not None:
        previous._reset()
    magics = KiwiMagics(ipython)
    ipython.register_magics(magics)
    _REGISTERED_MAGICS[id(ipython)] = magics


def unload_ipython_extension(ipython) -> None:
    """Close the Kiwi session this extension opened for *ipython*, if any.

    The magic names are not unregistered; using one again after unloading
    lazily opens a fresh session.
    """
    magics = _REGISTERED_MAGICS.pop(id(ipython), None)
    if magics is not None:
        magics._reset()
|
|
@@ -0,0 +1,65 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: kiwi-array
|
|
3
|
+
Version: 0.2.47
|
|
4
|
+
Summary: Python loader, IPython magics, backend discovery, and CLI launcher for Kiwi.
|
|
5
|
+
Requires-Python: >=3.11
|
|
6
|
+
Description-Content-Type: text/markdown
|
|
7
|
+
Requires-Dist: kiwi-array-host==0.2.47
|
|
8
|
+
Provides-Extra: host
|
|
9
|
+
Requires-Dist: kiwi-array-host==0.2.47; extra == "host"
|
|
10
|
+
Provides-Extra: cpu
|
|
11
|
+
Requires-Dist: kiwi-array-cpu==0.2.47; extra == "cpu"
|
|
12
|
+
Provides-Extra: metal
|
|
13
|
+
Requires-Dist: kiwi-array-metal==0.2.47; extra == "metal"
|
|
14
|
+
Provides-Extra: cuda12
|
|
15
|
+
Requires-Dist: kiwi-array-cuda12==0.2.47; extra == "cuda12"
|
|
16
|
+
Provides-Extra: notebook
|
|
17
|
+
Requires-Dist: ipython>=8; extra == "notebook"
|
|
18
|
+
Provides-Extra: jupyter
|
|
19
|
+
Requires-Dist: kiwi-array-jupyter==0.2.47; extra == "jupyter"
|
|
20
|
+
|
|
21
|
+
# kiwi-array
|
|
22
|
+
|
|
23
|
+
`kiwi-array` is the Python loader package for Kiwi. It contains the
|
|
24
|
+
`kiwi_array.bridge` ctypes wrapper, IPython magics, runtime backend discovery,
|
|
25
|
+
and a small `kiwi` CLI launcher that execs the real native CLI from an
|
|
26
|
+
installed backend payload.
|
|
27
|
+
|
|
28
|
+
By default, installing `kiwi-array` also installs `kiwi-array-host`, the
|
|
29
|
+
conservative host backend. Accelerated backends are explicit extras:
|
|
30
|
+
`kiwi-array[cpu]`, `kiwi-array[metal]`, or `kiwi-array[cuda12]`.
|
|
31
|
+
|
|
32
|
+
In IPython-compatible hosted notebooks, load the extension and use either the
|
|
33
|
+
descriptive or short magic:
|
|
34
|
+
|
|
35
|
+
```python
|
|
36
|
+
%load_ext kiwi_array
|
|
37
|
+
```
|
|
38
|
+
|
|
39
|
+
```text
|
|
40
|
+
%%k
|
|
41
|
+
x:1 2 3
|
|
42
|
+
+/x
|
|
43
|
+
```
|
|
44
|
+
|
|
45
|
+
`%%kiwi` is registered as a descriptive alias for `%%k`; both also work as single-line magics (`%k ...`, `%kiwi ...`).
|
|
46
|
+
|
|
47
|
+
Use the lightweight diagnostics when sharing hosted notebook setup cells:
|
|
48
|
+
|
|
49
|
+
```python
|
|
50
|
+
%kinfo
|
|
51
|
+
%ksmoke
|
|
52
|
+
```
|
|
53
|
+
|
|
54
|
+
On a GPU runtime, `%ksmoke --gpu` (or `%ksmoke --mlx`) also pushes a small MLX-backed vector through
|
|
55
|
+
the native bridge and evaluates a gradient over it.
|
|
56
|
+
|
|
57
|
+
Runtime payload packages such as `kiwi-array-host`, `kiwi-array-metal`, and
|
|
58
|
+
`kiwi-array-cuda12` register themselves through the `kiwi_array.runtimes` entry
|
|
59
|
+
point group. Set `KIWI_RUNTIME` to `host`, `cpu`, `metal`, or `cuda12` to select a
runtime explicitly, or leave it at `auto` to let Kiwi choose one.
|
|
62
|
+
|
|
63
|
+
By default, Vega-Lite outputs use the notebook's standard light rendering. Set
|
|
64
|
+
`KIWI_VEGALITE_THEME=dark` before rendering if you want the HTML fallback to use
|
|
65
|
+
Vega-Embed's dark theme.
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
kiwi_array/__init__.py,sha256=KS25NzHIwwg4tsOJ84QjhJM75tb_CxyN_yOigXRpNNg,174
|
|
2
|
+
kiwi_array/bridge.py,sha256=CvOce3e65BanIBk_CH1aFnlNL1Yo5YOpWjrYgav8F3k,25420
|
|
3
|
+
kiwi_array/cli.py,sha256=IsDcnvMOY1d_GTceFcrJwdx7vyR7yLMdEU3pOzQILJs,922
|
|
4
|
+
kiwi_array/notebook.py,sha256=DfYRf1LALgBJV0XEBNXpWyi82bG8yeeVLKqce-R7cNM,8979
|
|
5
|
+
kiwi_array-0.2.47.dist-info/METADATA,sha256=VQjuq1BfJ59x0RjC_b8_WJtz9wfP7i35ejrObXg_kz8,2154
|
|
6
|
+
kiwi_array-0.2.47.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
|
|
7
|
+
kiwi_array-0.2.47.dist-info/entry_points.txt,sha256=rHIbSUqGhESM76fFGiKgf6InC-bDRNnJdJQe8uAEBAo,45
|
|
8
|
+
kiwi_array-0.2.47.dist-info/top_level.txt,sha256=QZLBgjHZ20dRwlzhEjdi0HhTJaXovd4JoRSpsyu0dDk,11
|
|
9
|
+
kiwi_array-0.2.47.dist-info/RECORD,,
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
kiwi_array
|