kiwi-array 0.2.47__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
kiwi_array/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ from __future__ import annotations
2
+
3
+ __version__ = "0.2.47"
4
+
5
+
6
+ def load_ipython_extension(ipython):
7
+ from .notebook import load_ipython_extension as load
8
+
9
+ load(ipython)
kiwi_array/bridge.py ADDED
@@ -0,0 +1,801 @@
1
+ from __future__ import annotations
2
+
3
+ import ctypes
4
+ import ctypes.util
5
+ import os
6
+ import platform
7
+ import shutil
8
+ import shlex
9
+ import subprocess
10
+ from dataclasses import dataclass
11
+ from importlib import metadata as importlib_metadata
12
+ from pathlib import Path
13
+ from typing import Any, Optional, Sequence, Union
14
+
15
+ try:
16
+ import numpy as _np
17
+ except ImportError: # pragma: no cover - optional dependency
18
+ _np = None
19
+
20
+ LIB_ENV_VAR = "KIWI_BRIDGE_LIB"
21
+ LEGACY_LIB_ENV_VAR = "KIWI_JUPYTER_BRIDGE_LIB"
22
+ RUNTIME_ENV_VAR = "KIWI_RUNTIME"
23
+ CUDA_DRIVER_LIB_ENV_VAR = "KIWI_CUDA_DRIVER_LIB"
24
+ CUDA_PRELOAD_DRIVER_ENV_VAR = "KIWI_CUDA_PRELOAD_DRIVER"
25
+ SOURCE_ROOT_ENV_VAR = "KIWI_SOURCE_ROOT"
26
+ ZIG_ENV_VAR = "KIWI_ZIG"
27
+ MLX_PREFIX_ENV_VAR = "KIWI_MLX_PREFIX"
28
+ MLX_C_INCLUDE_ENV_VAR = "KIWI_MLX_C_INCLUDE"
29
+ RUNTIME_ENTRY_POINT_GROUP = "kiwi_array.runtimes"
30
+
31
+ DEVICE_VALUES = {
32
+ "auto": 0,
33
+ "cpu": 1,
34
+ "gpu": 2,
35
+ }
36
+
37
+ STATUS_NAMES = {
38
+ 0: "ok",
39
+ 1: "parse",
40
+ 2: "type",
41
+ 3: "name",
42
+ 4: "domain",
43
+ 5: "rank",
44
+ 6: "nyi",
45
+ 7: "length",
46
+ 8: "index",
47
+ 9: "mlx",
48
+ 10: "device",
49
+ 11: "error",
50
+ 12: "oom",
51
+ }
52
+
53
+ AUTOGRAD_PATH_NAMES = {
54
+ 0: "none",
55
+ 1: "mlx",
56
+ 2: "finite_difference",
57
+ }
58
+
59
+ PathLikeArg = Union[os.PathLike[str], str]
60
+
61
+
62
+ class _CKiwiEvalResult(ctypes.Structure):
63
+ _fields_ = [
64
+ ("status", ctypes.c_int),
65
+ ("echoed", ctypes.c_bool),
66
+ ("autograd_path", ctypes.c_int),
67
+ ("text_ptr", ctypes.c_void_p),
68
+ ("text_len", ctypes.c_size_t),
69
+ ("display_mime_ptr", ctypes.c_void_p),
70
+ ("display_mime_len", ctypes.c_size_t),
71
+ ("display_data_ptr", ctypes.c_void_p),
72
+ ("display_data_len", ctypes.c_size_t),
73
+ ]
74
+
75
+
76
+ @dataclass(frozen=True)
77
+ class KiwiEvalResult:
78
+ status: str
79
+ echoed: bool
80
+ autograd_path: str
81
+ text: Optional[str]
82
+ display_mime: Optional[str] = None
83
+ display_data: Optional[str] = None
84
+
85
+
86
+ class KiwiBridgeError(RuntimeError):
87
+ pass
88
+
89
+
90
+ @dataclass(frozen=True)
91
+ class RuntimeDescriptor:
92
+ name: str
93
+ priority: int
94
+ backend: str
95
+ accelerator: Optional[str]
96
+ native_dir: Optional[Path]
97
+ lib_dir: Optional[Path]
98
+ bin_dir: Optional[Path] = None
99
+ source: str = "unknown"
100
+
101
+ def bridge_library_path(self) -> Optional[Path]:
102
+ if self.native_dir is None:
103
+ return None
104
+ return self.native_dir / platform_library_name("kiwi_bridge")
105
+
106
+ def cli_path(self) -> Optional[Path]:
107
+ if self.bin_dir is None:
108
+ return None
109
+ return self.bin_dir / platform_executable_name("kiwi")
110
+
111
+
112
+ def _find_source_root() -> Optional[Path]:
113
+ override = os.environ.get(SOURCE_ROOT_ENV_VAR)
114
+ if override:
115
+ return Path(override).expanduser()
116
+
117
+ for parent in Path(__file__).resolve().parents:
118
+ if (parent / "build.zig").is_file() and (parent / "src" / "kiwi_bridge.zig").is_file():
119
+ return parent
120
+ return None
121
+
122
+
123
+ def implementation_root() -> Path:
124
+ root = _find_source_root()
125
+ if root is None:
126
+ raise KiwiBridgeError(
127
+ "Kiwi source root not found. Set KIWI_SOURCE_ROOT when building the "
128
+ "bridge from source, or install a platform wheel that carries the "
129
+ "native Kiwi runtime."
130
+ )
131
+ return root
132
+
133
+
134
+ def _has_workspace_dependency_layout() -> bool:
135
+ root = _find_source_root()
136
+ if root is None:
137
+ return False
138
+ return (root.parents[1] / "vendor" / "mlx-c").exists()
139
+
140
+
141
+ def repo_root() -> Path:
142
+ root = implementation_root()
143
+ return root.parents[1] if _has_workspace_dependency_layout() else root
144
+
145
+
146
+ def _package_root() -> Optional[Path]:
147
+ return Path(__file__).resolve().parent
148
+
149
+
150
+ def _packaged_native_dir() -> Optional[Path]:
151
+ root = _package_root()
152
+ return None if root is None else root / "native"
153
+
154
+
155
+ def _packaged_library_dir() -> Optional[Path]:
156
+ root = _package_root()
157
+ return None if root is None else root / "lib"
158
+
159
+
160
+ def _is_platform_dynamic_library_name(name: str) -> bool:
161
+ system = platform.system()
162
+ if system == "Darwin":
163
+ return name.endswith(".dylib")
164
+ if system == "Windows":
165
+ return name.endswith(".dll")
166
+ return ".so" in name
167
+
168
+
169
+ def _dir_has_payload(path: Optional[Path]) -> bool:
170
+ if path is None or not path.is_dir():
171
+ return False
172
+ return any(
173
+ child.is_file()
174
+ and child.name not in {".gitignore", ".gitkeep"}
175
+ and _is_platform_dynamic_library_name(child.name)
176
+ for child in path.iterdir()
177
+ )
178
+
179
+
180
+ def _path_from_descriptor(value: object) -> Optional[Path]:
181
+ if value in (None, ""):
182
+ return None
183
+ return Path(os.fspath(value)).expanduser()
184
+
185
+
186
+ def _coerce_runtime_descriptor(raw: object, source: str) -> RuntimeDescriptor:
187
+ if isinstance(raw, RuntimeDescriptor):
188
+ return raw
189
+ if not isinstance(raw, dict):
190
+ raise TypeError(f"runtime descriptor from {source} must be a dict")
191
+ return RuntimeDescriptor(
192
+ name=str(raw["name"]),
193
+ priority=int(raw.get("priority", 0)),
194
+ backend=str(raw.get("backend", "unknown")),
195
+ accelerator=None if raw.get("accelerator") is None else str(raw.get("accelerator")),
196
+ native_dir=_path_from_descriptor(raw.get("native_dir")),
197
+ lib_dir=_path_from_descriptor(raw.get("lib_dir")),
198
+ bin_dir=_path_from_descriptor(raw.get("bin_dir")),
199
+ source=source,
200
+ )
201
+
202
+
203
+ def _embedded_runtime_descriptor() -> RuntimeDescriptor:
204
+ root = _package_root()
205
+ return RuntimeDescriptor(
206
+ name="host",
207
+ priority=10,
208
+ backend="host",
209
+ accelerator=None,
210
+ native_dir=root / "native",
211
+ lib_dir=root / "lib",
212
+ bin_dir=root / "bin",
213
+ source="kiwi_array",
214
+ )
215
+
216
+
217
+ def _entry_point_runtime_descriptors() -> list[RuntimeDescriptor]:
218
+ try:
219
+ entry_points = importlib_metadata.entry_points()
220
+ except Exception:
221
+ return []
222
+ if hasattr(entry_points, "select"):
223
+ selected = entry_points.select(group=RUNTIME_ENTRY_POINT_GROUP)
224
+ else: # pragma: no cover - Python <3.10 compatibility path
225
+ selected = entry_points.get(RUNTIME_ENTRY_POINT_GROUP, [])
226
+
227
+ descriptors: list[RuntimeDescriptor] = []
228
+ for entry_point in selected:
229
+ try:
230
+ loaded = entry_point.load()
231
+ raw = loaded() if callable(loaded) else loaded
232
+ descriptors.append(_coerce_runtime_descriptor(raw, entry_point.name))
233
+ except Exception:
234
+ continue
235
+ return descriptors
236
+
237
+
238
+ def discover_runtime_descriptors() -> list[RuntimeDescriptor]:
239
+ descriptors = [_embedded_runtime_descriptor(), *_entry_point_runtime_descriptors()]
240
+ by_name: dict[str, RuntimeDescriptor] = {}
241
+ for descriptor in descriptors:
242
+ current = by_name.get(descriptor.name)
243
+ if current is None or descriptor.priority > current.priority:
244
+ by_name[descriptor.name] = descriptor
245
+ return sorted(by_name.values(), key=lambda item: (-item.priority, item.name))
246
+
247
+
248
+ def _runtime_selection() -> str:
249
+ value = os.environ.get(RUNTIME_ENV_VAR, "auto").strip()
250
+ return value or "auto"
251
+
252
+
253
+ def _selected_runtime_descriptors() -> list[RuntimeDescriptor]:
254
+ selection = _runtime_selection()
255
+ descriptors = discover_runtime_descriptors()
256
+ if selection == "auto":
257
+ return descriptors
258
+ return [descriptor for descriptor in descriptors if descriptor.name == selection]
259
+
260
+
261
+ def active_runtime_descriptor() -> Optional[RuntimeDescriptor]:
262
+ for descriptor in _selected_runtime_descriptors():
263
+ bridge = descriptor.bridge_library_path()
264
+ if bridge is not None and bridge.exists():
265
+ return descriptor
266
+ return None
267
+
268
+
269
+ def zig_executable() -> str:
270
+ override = os.environ.get(ZIG_ENV_VAR) or os.environ.get("ZIG")
271
+ if override:
272
+ return str(Path(override).expanduser())
273
+ bundled = repo_root() / "tools" / "zig"
274
+ if bundled.exists():
275
+ return str(bundled)
276
+ path_zig = shutil.which("zig")
277
+ return path_zig if path_zig is not None else "zig"
278
+
279
+
280
+ def default_mlx_prefix() -> Path:
281
+ if _has_workspace_dependency_layout():
282
+ return repo_root() / ".artifacts" / "mlx" / "macos-default-install"
283
+ return implementation_root() / ".deps" / "mlx"
284
+
285
+
286
+ def default_mlx_c_include() -> Path:
287
+ if _has_workspace_dependency_layout():
288
+ return repo_root() / "vendor" / "mlx-c"
289
+ return implementation_root() / ".deps" / "mlx-c"
290
+
291
+
292
+ def mlx_prefix() -> Path:
293
+ return Path(
294
+ os.environ.get(MLX_PREFIX_ENV_VAR)
295
+ or default_mlx_prefix()
296
+ ).expanduser()
297
+
298
+
299
+ def mlx_c_include() -> Path:
300
+ return Path(
301
+ os.environ.get(MLX_C_INCLUDE_ENV_VAR)
302
+ or default_mlx_c_include()
303
+ ).expanduser()
304
+
305
+
306
+ def runtime_library_search_dir() -> Path:
307
+ source_root = _find_source_root()
308
+ if _runtime_selection() == "auto" and source_root is not None:
309
+ source_bridge = source_root / "zig-out" / "lib" / platform_library_name("kiwi_bridge")
310
+ if source_bridge.exists():
311
+ return mlx_prefix() / "lib"
312
+
313
+ descriptor = active_runtime_descriptor()
314
+ if descriptor is not None and descriptor.lib_dir is not None and _dir_has_payload(descriptor.lib_dir):
315
+ return descriptor.lib_dir
316
+ return mlx_prefix() / "lib"
317
+
318
+
319
+ def runtime_library_env_var() -> Optional[str]:
320
+ system = platform.system()
321
+ if system == "Darwin":
322
+ return "DYLD_LIBRARY_PATH"
323
+ if system == "Linux":
324
+ return "LD_LIBRARY_PATH"
325
+ return None
326
+
327
+
328
+ def build_args(optimize: str) -> list[str]:
329
+ bridge_mlx_prefix = mlx_prefix()
330
+ bridge_mlx_c_include = mlx_c_include()
331
+ return [
332
+ zig_executable(),
333
+ "build",
334
+ f"-Doptimize={optimize}",
335
+ "-Dpublic-cli=true",
336
+ "-Dcli-name=kiwi",
337
+ "-Dinstall-sdk=true",
338
+ f"-Dmlx-prefix={bridge_mlx_prefix}",
339
+ f"-Dmlx-c-include={bridge_mlx_c_include}",
340
+ ]
341
+
342
+
343
+ def platform_library_name(base_name: str) -> str:
344
+ system = platform.system()
345
+ if system == "Darwin":
346
+ return f"lib{base_name}.dylib"
347
+ if system == "Windows":
348
+ return f"{base_name}.dll"
349
+ return f"lib{base_name}.so"
350
+
351
+
352
+ def platform_executable_name(base_name: str) -> str:
353
+ return f"{base_name}.exe" if platform.system() == "Windows" else base_name
354
+
355
+
356
+ def default_cli_candidates() -> list[Path]:
357
+ candidates: list[Path] = []
358
+ source_root = _find_source_root()
359
+ source_cli = None
360
+ if source_root is not None:
361
+ source_cli = source_root / "zig-out" / "bin" / platform_executable_name("kiwi")
362
+
363
+ if _runtime_selection() == "auto" and source_cli is not None:
364
+ candidates.append(source_cli)
365
+
366
+ for descriptor in _selected_runtime_descriptors():
367
+ cli = descriptor.cli_path()
368
+ if cli is not None:
369
+ candidates.append(cli)
370
+
371
+ if _runtime_selection() != "auto" and source_cli is not None:
372
+ candidates.append(source_cli)
373
+ return candidates
374
+
375
+
376
+ def find_cli_path() -> Path:
377
+ for candidate in default_cli_candidates():
378
+ if candidate.exists():
379
+ return candidate
380
+ joined = ", ".join(str(candidate) for candidate in default_cli_candidates())
381
+ raise FileNotFoundError(
382
+ "Kiwi CLI not found. Install a runtime payload such as "
383
+ "`kiwi-array-host`, `kiwi-array-cpu`, `kiwi-array-metal`, or "
384
+ "`kiwi-array-cuda12`. Searched: "
385
+ f"{joined}"
386
+ )
387
+
388
+
389
+ def default_library_candidates() -> list[Path]:
390
+ candidates: list[Path] = []
391
+ for env_var in (LIB_ENV_VAR, LEGACY_LIB_ENV_VAR):
392
+ env_path = os.environ.get(env_var)
393
+ if env_path:
394
+ candidates.append(Path(env_path).expanduser())
395
+
396
+ source_root = _find_source_root()
397
+ source_bridge = None
398
+ if source_root is not None:
399
+ source_bridge = source_root / "zig-out" / "lib" / platform_library_name("kiwi_bridge")
400
+
401
+ if _runtime_selection() == "auto" and source_bridge is not None:
402
+ candidates.append(source_bridge)
403
+
404
+ for descriptor in _selected_runtime_descriptors():
405
+ bridge = descriptor.bridge_library_path()
406
+ if bridge is not None:
407
+ candidates.append(bridge)
408
+
409
+ if _runtime_selection() != "auto" and source_bridge is not None:
410
+ candidates.append(source_bridge)
411
+ return candidates
412
+
413
+
414
+ def find_library_path(path: Optional[PathLikeArg] = None) -> Path:
415
+ if path is not None:
416
+ candidate = Path(path).expanduser()
417
+ if candidate.exists():
418
+ return candidate
419
+ raise FileNotFoundError(f"Kiwi bridge library not found at {candidate}")
420
+
421
+ for candidate in default_library_candidates():
422
+ if candidate.exists():
423
+ return candidate
424
+ joined = ", ".join(str(candidate) for candidate in default_library_candidates())
425
+ source_root = _find_source_root()
426
+ if source_root is None:
427
+ hint = "Install a platform wheel that carries the native Kiwi runtime."
428
+ else:
429
+ hint = f"Build it with `{shlex.join(build_args('ReleaseFast'))}` from {source_root}."
430
+ raise FileNotFoundError(f"Kiwi bridge library not found. {hint} Searched: {joined}")
431
+
432
+
433
+ def build_bridge(optimize: str = "ReleaseFast") -> None:
434
+ subprocess.run(
435
+ build_args(optimize),
436
+ cwd=implementation_root(),
437
+ check=True,
438
+ )
439
+
440
+
441
+ def _configure_library(lib: ctypes.CDLL) -> ctypes.CDLL:
442
+ lib.kiwi_session_create.argtypes = [ctypes.c_int]
443
+ lib.kiwi_session_create.restype = ctypes.c_void_p
444
+
445
+ lib.kiwi_session_destroy.argtypes = [ctypes.c_void_p]
446
+ lib.kiwi_session_destroy.restype = None
447
+
448
+ lib.kiwi_session_eval.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ctypes.c_size_t]
449
+ lib.kiwi_session_eval.restype = _CKiwiEvalResult
450
+
451
+ lib.kiwi_session_set_global_float_array.argtypes = [
452
+ ctypes.c_void_p,
453
+ ctypes.c_char_p,
454
+ ctypes.c_size_t,
455
+ ctypes.POINTER(ctypes.c_float),
456
+ ctypes.POINTER(ctypes.c_int32),
457
+ ctypes.c_size_t,
458
+ ]
459
+ lib.kiwi_session_set_global_float_array.restype = ctypes.c_int
460
+
461
+ if hasattr(lib, "kiwi_session_set_global_mlx_float_array"):
462
+ lib.kiwi_session_set_global_mlx_float_array.argtypes = [
463
+ ctypes.c_void_p,
464
+ ctypes.c_char_p,
465
+ ctypes.c_size_t,
466
+ ctypes.POINTER(ctypes.c_float),
467
+ ctypes.POINTER(ctypes.c_int32),
468
+ ctypes.c_size_t,
469
+ ]
470
+ lib.kiwi_session_set_global_mlx_float_array.restype = ctypes.c_int
471
+
472
+ lib.kiwi_session_set_global_int_array.argtypes = [
473
+ ctypes.c_void_p,
474
+ ctypes.c_char_p,
475
+ ctypes.c_size_t,
476
+ ctypes.POINTER(ctypes.c_int32),
477
+ ctypes.POINTER(ctypes.c_int32),
478
+ ctypes.c_size_t,
479
+ ]
480
+ lib.kiwi_session_set_global_int_array.restype = ctypes.c_int
481
+
482
+ lib.kiwi_session_set_global_bool_array.argtypes = [
483
+ ctypes.c_void_p,
484
+ ctypes.c_char_p,
485
+ ctypes.c_size_t,
486
+ ctypes.POINTER(ctypes.c_bool),
487
+ ctypes.POINTER(ctypes.c_int32),
488
+ ctypes.c_size_t,
489
+ ]
490
+ lib.kiwi_session_set_global_bool_array.restype = ctypes.c_int
491
+
492
+ lib.kiwi_eval_result_free.argtypes = [_CKiwiEvalResult]
493
+ lib.kiwi_eval_result_free.restype = None
494
+
495
+ lib.kiwi_status_name.argtypes = [ctypes.c_int]
496
+ lib.kiwi_status_name.restype = ctypes.c_char_p
497
+ return lib
498
+
499
+
500
+ def _preload_runtime_dependencies(descriptor: Optional[RuntimeDescriptor] = None) -> None:
501
+ _preload_cuda_driver()
502
+
503
+ search_dirs = []
504
+ if descriptor is not None and descriptor.lib_dir is not None:
505
+ search_dirs.append(descriptor.lib_dir)
506
+ search_dirs.append(runtime_library_search_dir())
507
+ packaged_lib = _packaged_library_dir()
508
+ if packaged_lib is not None:
509
+ search_dirs.append(packaged_lib)
510
+ source_root = _find_source_root()
511
+ if source_root is not None:
512
+ search_dirs.append(source_root / "zig-out" / "lib")
513
+
514
+ seen: set[Path] = set()
515
+ for directory in search_dirs:
516
+ if directory in seen:
517
+ continue
518
+ seen.add(directory)
519
+ for base_name in ("mlx", "duckdb"):
520
+ library = directory / platform_library_name(base_name)
521
+ if library.exists():
522
+ ctypes.CDLL(str(library), mode=getattr(ctypes, "RTLD_GLOBAL", 0))
523
+
524
+
525
+ def _preload_cuda_driver() -> None:
526
+ if platform.system() != "Linux":
527
+ return
528
+
529
+ preload_mode = os.environ.get(CUDA_PRELOAD_DRIVER_ENV_VAR, "0").strip().lower()
530
+ if preload_mode in {"", "0", "false", "no", "off"}:
531
+ return
532
+
533
+ candidates = []
534
+ override = os.environ.get(CUDA_DRIVER_LIB_ENV_VAR)
535
+ if override:
536
+ candidates.append(Path(override).expanduser())
537
+
538
+ found = ctypes.util.find_library("cuda")
539
+ if found:
540
+ candidates.append(Path(found))
541
+
542
+ candidates.extend(
543
+ [
544
+ Path("/usr/lib64-nvidia/libcuda.so.1"),
545
+ Path("/usr/local/nvidia/lib64/libcuda.so.1"),
546
+ Path("/usr/lib/x86_64-linux-gnu/libcuda.so.1"),
547
+ ]
548
+ )
549
+ if preload_mode in {"compat", "cuda-compat"}:
550
+ candidates.append(Path("/usr/local/cuda/compat/libcuda.so.1"))
551
+
552
+ for candidate in candidates:
553
+ try:
554
+ if candidate.is_absolute() and not candidate.exists():
555
+ continue
556
+ ctypes.CDLL(str(candidate), mode=getattr(ctypes, "RTLD_GLOBAL", 0))
557
+ return
558
+ except OSError:
559
+ continue
560
+
561
+
562
+ def _load_library_at(path: Path, descriptor: Optional[RuntimeDescriptor] = None) -> ctypes.CDLL:
563
+ _preload_runtime_dependencies(descriptor)
564
+ return _configure_library(ctypes.CDLL(str(path)))
565
+
566
+
567
+ def load_library(path: Optional[PathLikeArg] = None) -> ctypes.CDLL:
568
+ if path is not None or os.environ.get(LIB_ENV_VAR) or os.environ.get(LEGACY_LIB_ENV_VAR):
569
+ return _load_library_at(find_library_path(path))
570
+
571
+ errors: list[str] = []
572
+ selection = _runtime_selection()
573
+ source_root = _find_source_root()
574
+ if selection == "auto" and source_root is not None:
575
+ source_bridge = source_root / "zig-out" / "lib" / platform_library_name("kiwi_bridge")
576
+ if source_bridge.exists():
577
+ try:
578
+ return _load_library_at(source_bridge)
579
+ except OSError as exc:
580
+ errors.append(f"source: {exc}")
581
+
582
+ for descriptor in _selected_runtime_descriptors():
583
+ bridge = descriptor.bridge_library_path()
584
+ if bridge is None or not bridge.exists():
585
+ continue
586
+ try:
587
+ return _load_library_at(bridge, descriptor)
588
+ except OSError as exc:
589
+ errors.append(f"{descriptor.name}: {exc}")
590
+ if selection != "auto":
591
+ raise KiwiBridgeError(f"failed to load Kiwi runtime {descriptor.name!r}: {exc}") from exc
592
+
593
+ if selection != "auto" and source_root is not None:
594
+ source_bridge = source_root / "zig-out" / "lib" / platform_library_name("kiwi_bridge")
595
+ if source_bridge.exists():
596
+ try:
597
+ return _load_library_at(source_bridge)
598
+ except OSError as exc:
599
+ errors.append(f"source: {exc}")
600
+
601
+ if errors:
602
+ raise KiwiBridgeError("failed to load any Kiwi runtime: " + "; ".join(errors))
603
+ return _load_library_at(find_library_path())
604
+
605
+
606
+ def _normalize_dims(dims: Optional[Sequence[int]], data) -> tuple[int, ...]:
607
+ if dims is not None:
608
+ return tuple(int(dim) for dim in dims)
609
+
610
+ shape = getattr(data, "shape", None)
611
+ if shape is None:
612
+ raise TypeError("dims are required unless data exposes a shape")
613
+ return tuple(int(dim) for dim in shape)
614
+
615
+
616
+ def _product(dims: Sequence[int]) -> int:
617
+ total = 1
618
+ for dim in dims:
619
+ if dim < 0:
620
+ raise ValueError(f"negative dimension {dim}")
621
+ total *= dim
622
+ return total
623
+
624
+
625
+ def _sequence_values(data, count: int):
626
+ if count == 1 and not isinstance(data, (list, tuple)):
627
+ return [data]
628
+ values = list(data)
629
+ if len(values) != count:
630
+ raise ValueError(f"expected {count} items, got {len(values)}")
631
+ return values
632
+
633
+
634
+ def _prepare_array_data(data, dims: tuple[int, ...], c_type, numpy_dtype):
635
+ count = _product(dims)
636
+ if _np is not None:
637
+ arr = _np.asarray(data, dtype=numpy_dtype)
638
+ if arr.size != count:
639
+ raise ValueError(f"expected {count} items, got {arr.size}")
640
+ flat = _np.ascontiguousarray(arr.reshape(-1))
641
+ ptr = None if count == 0 else flat.ctypes.data_as(ctypes.POINTER(c_type))
642
+ return flat, ptr
643
+
644
+ values = _sequence_values(data, count)
645
+ buf = (c_type * count)(*values)
646
+ ptr = None if count == 0 else buf
647
+ return buf, ptr
648
+
649
+
650
+ class KiwiSession:
651
+ def __init__(self, device: str = "auto", library_path: Optional[PathLikeArg] = None) -> None:
652
+ try:
653
+ device_value = DEVICE_VALUES[device]
654
+ except KeyError as exc:
655
+ raise ValueError(f"unknown Kiwi device {device!r}") from exc
656
+
657
+ self._lib = load_library(library_path)
658
+ self._handle = self._lib.kiwi_session_create(device_value)
659
+ if not self._handle:
660
+ raise KiwiBridgeError("failed to create Kiwi session")
661
+
662
+ def close(self) -> None:
663
+ if self._handle:
664
+ self._lib.kiwi_session_destroy(self._handle)
665
+ self._handle = None
666
+
667
+ def eval(self, source: str) -> KiwiEvalResult:
668
+ self._ensure_open()
669
+ encoded = source.encode("utf-8")
670
+ result = self._lib.kiwi_session_eval(self._handle, encoded, len(encoded))
671
+ try:
672
+ text = None
673
+ if result.text_ptr:
674
+ text = ctypes.string_at(result.text_ptr, result.text_len).decode("utf-8")
675
+ display_mime = None
676
+ if result.display_mime_ptr:
677
+ display_mime = ctypes.string_at(
678
+ result.display_mime_ptr,
679
+ result.display_mime_len,
680
+ ).decode("utf-8")
681
+ display_data = None
682
+ if result.display_data_ptr:
683
+ display_data = ctypes.string_at(
684
+ result.display_data_ptr,
685
+ result.display_data_len,
686
+ ).decode("utf-8")
687
+ return KiwiEvalResult(
688
+ status=STATUS_NAMES.get(result.status, "error"),
689
+ echoed=bool(result.echoed),
690
+ autograd_path=AUTOGRAD_PATH_NAMES.get(result.autograd_path, "none"),
691
+ text=text,
692
+ display_mime=display_mime,
693
+ display_data=display_data,
694
+ )
695
+ finally:
696
+ self._lib.kiwi_eval_result_free(result)
697
+
698
+ def set_global_float_array(self, name: str, data, dims: Optional[Sequence[int]] = None) -> None:
699
+ self._set_global_array(
700
+ "kiwi_session_set_global_float_array",
701
+ name,
702
+ data,
703
+ dims,
704
+ ctypes.c_float,
705
+ _np.float32 if _np is not None else None,
706
+ )
707
+
708
+ def set_global_mlx_float_array(self, name: str, data, dims: Optional[Sequence[int]] = None) -> None:
709
+ self._set_global_array(
710
+ "kiwi_session_set_global_mlx_float_array",
711
+ name,
712
+ data,
713
+ dims,
714
+ ctypes.c_float,
715
+ _np.float32 if _np is not None else None,
716
+ )
717
+
718
+ def set_global_int_array(self, name: str, data, dims: Optional[Sequence[int]] = None) -> None:
719
+ self._set_global_array(
720
+ "kiwi_session_set_global_int_array",
721
+ name,
722
+ data,
723
+ dims,
724
+ ctypes.c_int32,
725
+ _np.int32 if _np is not None else None,
726
+ )
727
+
728
+ def set_global_bool_array(self, name: str, data, dims: Optional[Sequence[int]] = None) -> None:
729
+ self._set_global_array(
730
+ "kiwi_session_set_global_bool_array",
731
+ name,
732
+ data,
733
+ dims,
734
+ ctypes.c_bool,
735
+ _np.bool_ if _np is not None else None,
736
+ )
737
+
738
+ def __enter__(self) -> "KiwiSession":
739
+ return self
740
+
741
+ def __exit__(self, exc_type, exc, tb) -> None:
742
+ self.close()
743
+
744
+ def _ensure_open(self) -> None:
745
+ if not self._handle:
746
+ raise KiwiBridgeError("Kiwi session is closed")
747
+
748
+ def _set_global_array(
749
+ self,
750
+ fn_name: str,
751
+ name: str,
752
+ data,
753
+ dims: Optional[Sequence[int]],
754
+ c_type,
755
+ numpy_dtype,
756
+ ) -> None:
757
+ self._ensure_open()
758
+ if not hasattr(self._lib, fn_name):
759
+ raise KiwiBridgeError(f"native Kiwi bridge does not expose {fn_name}")
760
+ dims_tuple = _normalize_dims(dims, data)
761
+ owner, data_ptr = _prepare_array_data(data, dims_tuple, c_type, numpy_dtype)
762
+ dims_owner = (ctypes.c_int32 * len(dims_tuple))(*dims_tuple)
763
+ dims_ptr = None if len(dims_tuple) == 0 else dims_owner
764
+ encoded = name.encode("utf-8")
765
+ status = getattr(self._lib, fn_name)(
766
+ self._handle,
767
+ encoded,
768
+ len(encoded),
769
+ data_ptr,
770
+ dims_ptr,
771
+ len(dims_tuple),
772
+ )
773
+ _ = owner
774
+ if status != 0:
775
+ raise KiwiBridgeError(f"failed to set global {name!r}: {STATUS_NAMES.get(status, 'error')}")
776
+
777
+
778
+ __all__ = [
779
+ "AUTOGRAD_PATH_NAMES",
780
+ "KiwiBridgeError",
781
+ "KiwiEvalResult",
782
+ "KiwiSession",
783
+ "LEGACY_LIB_ENV_VAR",
784
+ "LIB_ENV_VAR",
785
+ "RUNTIME_ENV_VAR",
786
+ "RuntimeDescriptor",
787
+ "STATUS_NAMES",
788
+ "active_runtime_descriptor",
789
+ "build_bridge",
790
+ "default_library_candidates",
791
+ "default_cli_candidates",
792
+ "discover_runtime_descriptors",
793
+ "find_cli_path",
794
+ "find_library_path",
795
+ "implementation_root",
796
+ "load_library",
797
+ "platform_library_name",
798
+ "platform_executable_name",
799
+ "runtime_library_env_var",
800
+ "runtime_library_search_dir",
801
+ ]
kiwi_array/cli.py ADDED
@@ -0,0 +1,36 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import subprocess
5
+ import sys
6
+
7
+ from . import bridge
8
+
9
+
10
+ def _launcher_environment() -> dict[str, str]:
11
+ env = os.environ.copy()
12
+ library_var = bridge.runtime_library_env_var()
13
+ if library_var is None:
14
+ return env
15
+ library_dir = bridge.runtime_library_search_dir()
16
+ existing = env.get(library_var)
17
+ env[library_var] = str(library_dir) if not existing else f"{library_dir}{os.pathsep}{existing}"
18
+ return env
19
+
20
+
21
+ def main() -> int:
22
+ try:
23
+ cli = bridge.find_cli_path()
24
+ except FileNotFoundError as exc:
25
+ print(exc, file=sys.stderr)
26
+ return 1
27
+
28
+ argv = [str(cli), *sys.argv[1:]]
29
+ if hasattr(os, "execve"):
30
+ os.execve(str(cli), argv, _launcher_environment())
31
+ completed = subprocess.run(argv, env=_launcher_environment())
32
+ return int(completed.returncode)
33
+
34
+
35
+ if __name__ == "__main__":
36
+ raise SystemExit(main())
kiwi_array/notebook.py ADDED
@@ -0,0 +1,268 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import os
5
+ import platform
6
+ import subprocess
7
+ import sys
8
+ import uuid
9
+ from dataclasses import dataclass
10
+ from typing import Iterator, Protocol
11
+
12
+ from kiwi_array.bridge import KiwiEvalResult, KiwiSession, find_library_path, runtime_library_search_dir
13
+
14
+ from . import __version__
15
+
16
+
17
+ class SessionProtocol(Protocol):
18
+ def eval(self, source: str) -> KiwiEvalResult:
19
+ ...
20
+
21
+ def close(self) -> None:
22
+ ...
23
+
24
+
25
+ @dataclass(frozen=True)
26
+ class CellOutput:
27
+ line_no: int
28
+ text: str
29
+ autograd_path: str
30
+ display_mime: str | None = None
31
+ display_data: str | None = None
32
+
33
+
34
+ @dataclass(frozen=True)
35
+ class SmokeCheck:
36
+ name: str
37
+ ok: bool
38
+ detail: str
39
+
40
+
41
+ class KiwiNotebookError(RuntimeError):
42
+ def __init__(self, line_no: int, status: str) -> None:
43
+ self.line_no = line_no
44
+ self.status = status
45
+ super().__init__(f"line {line_no}: !{status}")
46
+
47
+
48
+ def iter_executable_lines(code: str) -> Iterator[tuple[int, str]]:
49
+ for line_no, raw_line in enumerate(code.splitlines(), start=1):
50
+ line = raw_line.rstrip("\r")
51
+ trimmed = line.strip(" \t")
52
+ if not trimmed or trimmed.startswith("/"):
53
+ continue
54
+ yield line_no, line
55
+
56
+
57
+ def execute_cell(session: SessionProtocol, code: str) -> list[CellOutput]:
58
+ outputs: list[CellOutput] = []
59
+ for line_no, line in iter_executable_lines(code):
60
+ result = session.eval(line)
61
+ if result.status != "ok":
62
+ raise KiwiNotebookError(line_no, result.status)
63
+ if result.echoed and result.text is not None:
64
+ outputs.append(
65
+ CellOutput(
66
+ line_no=line_no,
67
+ text=result.text,
68
+ autograd_path=result.autograd_path,
69
+ display_mime=result.display_mime,
70
+ display_data=result.display_data,
71
+ )
72
+ )
73
+ return outputs
74
+
75
+
76
+ def display_bundle_for_output(output: CellOutput) -> dict[str, object]:
77
+ bundle: dict[str, object] = {"text/plain": output.text}
78
+ if output.display_mime is None or output.display_data is None:
79
+ return bundle
80
+ try:
81
+ payload = json.loads(output.display_data)
82
+ except json.JSONDecodeError:
83
+ return bundle
84
+ bundle[output.display_mime] = payload
85
+ if "vegalite" in output.display_mime:
86
+ bundle["text/html"] = vegalite_html(payload)
87
+ return bundle
88
+
89
+
90
+ def vegalite_embed_options() -> dict[str, object]:
91
+ options: dict[str, object] = {"actions": False, "renderer": "canvas"}
92
+ theme = os.environ.get("KIWI_VEGALITE_THEME", "").strip()
93
+ if theme and theme.lower() not in {"default", "none"}:
94
+ options["theme"] = theme
95
+ return options
96
+
97
+
98
+ def vegalite_html(spec: object) -> str:
99
+ element_id = f"kiwi-vegalite-{uuid.uuid4().hex}"
100
+ spec_json = json.dumps(spec, separators=(",", ":"))
101
+ options_json = json.dumps(vegalite_embed_options(), separators=(",", ":"))
102
+ return f"""
103
+ <div id="{element_id}"></div>
104
+ <script src="https://cdn.jsdelivr.net/npm/vega@5"></script>
105
+ <script src="https://cdn.jsdelivr.net/npm/vega-lite@5"></script>
106
+ <script src="https://cdn.jsdelivr.net/npm/vega-embed@6"></script>
107
+ <script>
108
+ (function() {{
109
+ const spec = {spec_json};
110
+ const render = function() {{
111
+ if (!window.vegaEmbed) {{
112
+ document.getElementById("{element_id}").textContent = {json.dumps("Vega-Embed failed to load.")};
113
+ return;
114
+ }}
115
+ window.vegaEmbed("#{element_id}", spec, {options_json});
116
+ }};
117
+ render();
118
+ }})();
119
+ </script>
120
+ """
121
+
122
+
123
+ def _safe_value(label: str, value_fn) -> tuple[str, str]:
124
+ try:
125
+ value = str(value_fn())
126
+ except Exception as exc: # pragma: no cover - defensive diagnostic path
127
+ value = f"!{type(exc).__name__}: {exc}"
128
+ return label, value
129
+
130
+
131
def _nvidia_smi_summary(*, timeout: float = 3.0) -> str:
    """Summarize NVIDIA GPUs via ``nvidia-smi`` for the diagnostics report.

    Returns the tool's stripped CSV output (``name, driver_version,
    memory.total`` columns, no header, no units) or a short reason string:

    * ``"not linux"`` on any other platform (the query is only tried on Linux);
    * ``"not found"`` when the ``nvidia-smi`` executable is missing;
    * ``"timeout after <n>s"`` when it does not finish within *timeout* seconds;
    * ``"!<ExceptionType>: <message>"`` for other OS errors such as a
      permission failure (same format as ``_safe_value``);
    * its stderr (or stdout), or ``"exit <code>"``, on a non-zero exit;
    * ``"no gpu"`` when it succeeds but prints nothing.
    """
    if platform.system() != "Linux":
        return "not linux"
    try:
        result = subprocess.run(
            [
                "nvidia-smi",
                "--query-gpu=name,driver_version,memory.total",
                "--format=csv,noheader,nounits",
            ],
            check=False,
            capture_output=True,
            text=True,
            timeout=timeout,
        )
    except FileNotFoundError:
        return "not found"
    except subprocess.TimeoutExpired:
        # Reported separately so a slow or hung tool is not mistaken for a missing one.
        return f"timeout after {timeout:g}s"
    except OSError as exc:
        # e.g. PermissionError: the tool exists but could not be started.
        return f"!{type(exc).__name__}: {exc}"
    if result.returncode != 0:
        message = (result.stderr or result.stdout).strip()
        return message or f"exit {result.returncode}"
    return result.stdout.strip() or "no gpu"
152
+
153
+
154
def collect_info() -> list[tuple[str, str]]:
    """Gather the ``(label, value)`` diagnostic rows printed by ``%kinfo``.

    The bridge and runtime-library lookups go through ``_safe_value``.  A
    failing lookup therefore appears as an ``"!<ExceptionType>: ..."`` row
    instead of aborting the report.

    ``vegalite_theme`` reports the theme actually passed to Vega-Embed (as
    decided by ``vegalite_embed_options``), not the raw environment value.
    Blank, ``default`` and ``none`` settings therefore all display as
    ``default``, matching what is rendered.
    """
    effective_theme = str(vegalite_embed_options().get("theme", "default"))
    return [
        ("kiwi-array", __version__),
        ("python", sys.version.split()[0]),
        ("platform", platform.platform()),
        _safe_value("bridge", find_library_path),
        _safe_value("runtime_libs", runtime_library_search_dir),
        ("device_env", os.environ.get("KIWI_JUPYTER_DEVICE", "auto")),
        ("vegalite_theme", effective_theme),
        ("nvidia_smi", _nvidia_smi_summary()),
    ]
165
+
166
+
167
def format_info_text(info: list[tuple[str, str]]) -> str:
    """Render ``(label, value)`` rows as ``label: value`` lines.

    Labels are right-aligned to the length of the longest label; an empty
    list renders as ``""``.
    """
    if not info:
        return ""
    pad = max(len(label) for label, _value in info)
    rows = []
    for label, value in info:
        rows.append(f"{label.rjust(pad)}: {value}")
    return "\n".join(rows)
170
+
171
+
172
def _run_smoke_check(name: str, evaluate, expected: str) -> SmokeCheck:
    """Run one smoke evaluation and convert its outcome into a ``SmokeCheck``.

    *evaluate* is a zero-argument callable returning an eval result with
    ``status`` and ``text`` attributes.  The check passes only when the status
    is ``"ok"`` and the text equals *expected*.  The detail is the result text,
    or ``"!<status>"`` when the status is not ok or there is no text.  An
    exception raised by *evaluate* becomes a failed check with detail
    ``"!<ExceptionType>: <message>"`` rather than aborting the whole run.
    """
    try:
        result = evaluate()
    except Exception as exc:  # smoke checks report failures; they do not raise
        return SmokeCheck(name=name, ok=False, detail=f"!{type(exc).__name__}: {exc}")
    status_ok = result.status == "ok"
    if status_ok and result.text is not None:
        detail = result.text
    else:
        detail = f"!{result.status}"
    return SmokeCheck(name=name, ok=status_ok and result.text == expected, detail=detail)


def run_smoke(session: SessionProtocol, include_mlx: bool = False) -> list[SmokeCheck]:
    """Run the built-in smoke evaluations against *session*.

    Always evaluates a scalar sum, a vector reduction and a gradient.  With
    *include_mlx*, it also calls the session's ``set_global_mlx_float_array``
    to define ``kx`` and then evaluates a gradient over it.  When the session
    has no such setter, a failed ``mlx_global`` check is recorded instead.

    Every evaluation is guarded, so an exception from the session is reported
    as a failed check instead of propagating to the caller.
    """
    checks: list[SmokeCheck] = []
    for name, source, expected in (
        ("scalar", "1+1", "2"),
        ("vector", "+/1 2 3", "6"),
        ("grad", "grad[{+/(x*x)}][1 2 3]", "2 4 6"),
    ):
        # Bind `source` now; the lambda is invoked immediately by the helper anyway.
        checks.append(_run_smoke_check(name, lambda src=source: session.eval(src), expected))

    if not include_mlx:
        return checks

    setter = getattr(session, "set_global_mlx_float_array", None)
    if setter is None:
        checks.append(SmokeCheck(name="mlx_global", ok=False, detail="native bridge lacks MLX setter"))
        return checks

    def _push_and_eval():
        # Both the push and the eval are inside the guarded call.
        setter("kx", [1.0, 2.0, 3.0], (3,))
        return session.eval("grad[{+/(x*x)}][kx]")

    checks.append(_run_smoke_check("mlx_global", _push_and_eval, "2 4 6"))
    return checks
199
+
200
+
201
def format_smoke_text(checks: list[SmokeCheck]) -> str:
    """Render smoke checks as ``name: ok|error detail`` lines.

    Check names are right-aligned to the longest name; no checks renders as
    ``""``.
    """
    pad = max((len(item.name) for item in checks), default=0)

    def _row(item) -> str:
        verdict = "error"
        if item.ok:
            verdict = "ok"
        return f"{item.name.rjust(pad)}: {verdict} {item.detail}"

    return "\n".join(map(_row, checks))
208
+
209
+
210
def load_ipython_extension(ipython) -> None:
    """Register the Kiwi magics on *ipython*.

    Registers ``%k``/``%%k`` and ``%kiwi``/``%%kiwi`` (evaluate Kiwi code),
    ``%kinfo`` (print diagnostics) and ``%ksmoke`` (run smoke checks).  A
    single ``KiwiSession`` is created lazily on first use from
    ``KIWI_JUPYTER_DEVICE`` (default ``"auto"``) and ``KIWI_BRIDGE_LIB``.  The
    line forms ``%k --reset`` / ``%k -r`` close and forget it.
    """
    from IPython.core.magic import Magics, line_cell_magic, line_magic, magics_class
    from IPython.display import display

    @magics_class
    class KiwiMagics(Magics):
        def __init__(self, shell) -> None:
            super().__init__(shell)
            # Created on first use by _kiwi_session; cleared by _drop_session.
            self._session: KiwiSession | None = None

        def _kiwi_session(self) -> KiwiSession:
            """Return the shared session, creating it from the environment if needed."""
            session = self._session
            if session is None:
                session = KiwiSession(
                    device=os.environ.get("KIWI_JUPYTER_DEVICE", "auto"),
                    library_path=os.environ.get("KIWI_BRIDGE_LIB"),
                )
                self._session = session
            return session

        def _drop_session(self) -> None:
            """Close the current session, if any, so the next run starts fresh."""
            if self._session is None:
                return
            self._session.close()
            self._session = None

        def _execute(self, line: str, cell: str | None) -> None:
            """Evaluate the cell body (or the line, in line mode) and display each output."""
            is_line_mode = cell is None
            if is_line_mode and line.strip() in {"--reset", "-r"}:
                self._drop_session()
                return
            source = line if is_line_mode else cell
            for output in execute_cell(self._kiwi_session(), source):
                bundle = display_bundle_for_output(output)
                display(
                    bundle,
                    raw=True,
                    metadata={"kiwi_autograd_path": output.autograd_path},
                )

        @line_cell_magic
        def kiwi(self, line: str, cell: str | None = None) -> None:
            self._execute(line, cell)

        @line_cell_magic
        def k(self, line: str, cell: str | None = None) -> None:
            self._execute(line, cell)

        @line_magic
        def kinfo(self, line: str) -> None:
            del line  # arguments are ignored
            print(format_info_text(collect_info()))

        @line_magic
        def ksmoke(self, line: str) -> None:
            include_mlx = any(flag in ("--mlx", "--gpu") for flag in line.split())
            report = run_smoke(self._kiwi_session(), include_mlx=include_mlx)
            print(format_smoke_text(report))

    ipython.register_magics(KiwiMagics)
265
+
266
+
267
def unload_ipython_extension(ipython) -> None:
    """IPython extension unload hook; performs no teardown.

    NOTE(review): any session created by the magics is not closed here.
    """
    _ = ipython  # accepted only to satisfy the hook signature
@@ -0,0 +1,65 @@
1
+ Metadata-Version: 2.4
2
+ Name: kiwi-array
3
+ Version: 0.2.47
4
+ Summary: Python loader, IPython magics, backend discovery, and CLI launcher for Kiwi.
5
+ Requires-Python: >=3.11
6
+ Description-Content-Type: text/markdown
7
+ Requires-Dist: kiwi-array-host==0.2.47
8
+ Provides-Extra: host
9
+ Requires-Dist: kiwi-array-host==0.2.47; extra == "host"
10
+ Provides-Extra: cpu
11
+ Requires-Dist: kiwi-array-cpu==0.2.47; extra == "cpu"
12
+ Provides-Extra: metal
13
+ Requires-Dist: kiwi-array-metal==0.2.47; extra == "metal"
14
+ Provides-Extra: cuda12
15
+ Requires-Dist: kiwi-array-cuda12==0.2.47; extra == "cuda12"
16
+ Provides-Extra: notebook
17
+ Requires-Dist: ipython>=8; extra == "notebook"
18
+ Provides-Extra: jupyter
19
+ Requires-Dist: kiwi-array-jupyter==0.2.47; extra == "jupyter"
20
+
21
+ # kiwi-array
22
+
23
+ `kiwi-array` is the Python loader package for Kiwi. It contains the
24
+ `kiwi_array.bridge` ctypes wrapper, IPython magics, runtime backend discovery,
25
+ and a small `kiwi` CLI launcher that execs the real native CLI from an
26
+ installed backend payload.
27
+
28
+ By default, installing `kiwi-array` also installs `kiwi-array-host`, the
29
+ conservative host backend. Accelerated backends are explicit extras:
30
+ `kiwi-array[cpu]`, `kiwi-array[metal]`, or `kiwi-array[cuda12]`.
31
+
32
+ In IPython-compatible hosted notebooks, load the extension and use either the
33
+ descriptive or short magic:
34
+
35
+ ```python
36
+ %load_ext kiwi_array
37
+ ```
38
+
39
+ ```text
40
+ %%k
41
+ x:1 2 3
42
+ +/x
43
+ ```
44
+
45
+ `%%kiwi` is registered as the explicit alias for `%%k`.
46
+
47
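+ Use the lightweight diagnostics in shared notebook setup cells: `%kinfo` prints the package version, platform, bridge library path, and an `nvidia-smi` summary, and `%ksmoke` runs a few quick evaluations to confirm the bridge works: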
48
+
49
+ ```python
50
+ %kinfo
51
+ %ksmoke
52
+ ```
53
+
54
+ On a GPU runtime, `%ksmoke --gpu` (or its alias `--mlx`) also pushes a small MLX-backed vector through
55
+ the native bridge and evaluates a gradient over it.
56
+
57
+ Runtime payload packages such as `kiwi-array-host`, `kiwi-array-metal`, and
58
+ `kiwi-array-cuda12` register themselves through the `kiwi_array.runtimes` entry
59
+ point group. Set `KIWI_RUNTIME=host`, `KIWI_RUNTIME=cpu`,
60
+ `KIWI_RUNTIME=metal`, `KIWI_RUNTIME=cuda12`, or leave `KIWI_RUNTIME=auto` to
61
+ control runtime selection.
62
+
63
+ By default, Vega-Lite outputs use the notebook's standard light rendering. Set
64
+ `KIWI_VEGALITE_THEME=dark` before rendering if you want the HTML fallback to use
65
+ Vega-Embed's dark theme.
@@ -0,0 +1,9 @@
1
+ kiwi_array/__init__.py,sha256=KS25NzHIwwg4tsOJ84QjhJM75tb_CxyN_yOigXRpNNg,174
2
+ kiwi_array/bridge.py,sha256=CvOce3e65BanIBk_CH1aFnlNL1Yo5YOpWjrYgav8F3k,25420
3
+ kiwi_array/cli.py,sha256=IsDcnvMOY1d_GTceFcrJwdx7vyR7yLMdEU3pOzQILJs,922
4
+ kiwi_array/notebook.py,sha256=DfYRf1LALgBJV0XEBNXpWyi82bG8yeeVLKqce-R7cNM,8979
5
+ kiwi_array-0.2.47.dist-info/METADATA,sha256=VQjuq1BfJ59x0RjC_b8_WJtz9wfP7i35ejrObXg_kz8,2154
6
+ kiwi_array-0.2.47.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
7
+ kiwi_array-0.2.47.dist-info/entry_points.txt,sha256=rHIbSUqGhESM76fFGiKgf6InC-bDRNnJdJQe8uAEBAo,45
8
+ kiwi_array-0.2.47.dist-info/top_level.txt,sha256=QZLBgjHZ20dRwlzhEjdi0HhTJaXovd4JoRSpsyu0dDk,11
9
+ kiwi_array-0.2.47.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (82.0.1)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ kiwi = kiwi_array.cli:main
@@ -0,0 +1 @@
1
+ kiwi_array