ilang-python 0.1.0__py3-none-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ilang/__init__.py ADDED
@@ -0,0 +1,20 @@
1
+ """Python front-end for 𝚒."""
2
+
3
+ from .component import Bench, Component
4
+ from .tensor import Device, Tensor
5
+
6
class _i:
    """Callable namespace type behind the module-level ``i`` object.

    Calling an instance with an expression string builds a ``Component``;
    the core types are also reachable as attributes of the instance.
    """

    # Re-exported so the types can be reached as attributes of the instance.
    Component = Component
    Tensor = Tensor
    Device = Device

    def __call__(self, expr: str) -> Component:
        """Return ``Component(expr)`` for the expression string *expr*."""
        component = Component(expr)
        return component

    @property
    def I(self) -> Component:
        """Return the shared ``Component.I`` instance."""
        return Component.I
17
+
18
+ i = _i()
19
+
20
+ __all__ = ["Bench", "Component", "Device", "Tensor", "i"]
ilang/component.py ADDED
@@ -0,0 +1,529 @@
1
+ from __future__ import annotations
2
+
3
+ import ctypes
4
+ from dataclasses import dataclass
5
+ from datetime import datetime, timedelta
6
+ from math import floor, log10, sqrt
7
+ from typing import Any, ClassVar
8
+
9
+ from . import ffi
10
+ from .inputs import _inputs
11
+ from .tensor import Device, Tensor, _OwnedOutputs
12
+
13
+ __all__ = ["Bench", "Component"]
14
+
15
+
16
@dataclass
class Bench:
    """Timing summary of a benchmark: mean, spread and the raw samples."""

    # Average duration of the timed runs.
    mean: timedelta
    # Standard deviation of the timed runs.
    std: timedelta
    # Number of untimed warm-up executions.
    n_warmups: int
    # Number of timed executions.
    n_runs: int
    # Duration of every timed execution.
    runs: list[timedelta]

    def _human_time(self) -> str:
        """Format ``mean±std`` as rounded integers in a unit chosen from the mean.

        The unit is picked from the decimal order of magnitude of the mean
        (in seconds); a mean of exactly zero is reported in nanoseconds.
        """
        mean_seconds = self.mean.total_seconds()
        magnitude = floor(log10(mean_seconds)) if mean_seconds != 0 else -6

        # (largest magnitude covered, power-of-ten multiplier, unit label)
        for ceiling, scale, unit in ((-6, 9, "ns"), (-3, 6, "μs"), (0, 3, "ms")):
            if magnitude <= ceiling:
                break
        else:
            scale, unit = 0, "s"

        mean_txt = str(round(mean_seconds * 10**scale))
        std_txt = str(round(self.std.total_seconds() * 10**scale))
        return f"{mean_txt}±{std_txt} {unit}"

    def __repr__(self) -> str:
        """Return the human-readable timing followed by warm-up and run counts."""
        timing = self._human_time()
        return f"{timing}, warmups = {self.n_warmups}, runs = {self.n_runs}"
47
+
48
+
49
+ class Component:
50
+ I: ClassVar[Component] # noqa: E741
51
+
52
def __init__(
    self,
    expr: str | None = None,
    _ptr: ctypes.c_void_p | None = None,
    _bindings: tuple[Any | None, ...] | None = None,
) -> None:
    """Wrap a native component, parsing *expr* when no handle is supplied.

    Args:
        expr: Expression source passed to ``i_parse``; only used when
            ``_ptr`` is ``None``.
        _ptr: Existing native component handle to take over.
        _bindings: One entry per component input, ``None`` meaning "free".
            Defaults to all inputs free.

    Raises:
        TypeError: Neither ``expr`` nor ``_ptr`` was given.
        RuntimeError: The handle is null, or the binding metadata does not
            match what the native component reports.
    """
    if _ptr is None:
        if expr is None:
            raise TypeError("Component needs expression")
        _ptr = ffi._core.i_parse(expr.encode())
    self._ptr: ctypes.c_void_p | None = ffi._check_ptr(_ptr)

    n_inputs = self._input_count()
    if _bindings is None:
        _bindings = (None,) * n_inputs
    self._bindings: tuple[Any | None, ...] = _bindings
    if len(self._bindings) != n_inputs:
        raise RuntimeError(
            f"binding metadata has {len(self._bindings)} input(s), component has {n_inputs}"
        )

    # A Python-side value must be present exactly when the native state is non-zero.
    native_states = self._input_states(n_inputs)
    for index, (bound_value, state) in enumerate(zip(self._bindings, native_states)):
        has_value = bound_value is not None
        if has_value == (state == 0):
            raise RuntimeError(
                f"binding metadata for input {index} disagrees with component state"
            )

    # Compiled programs are created lazily, one per device.
    self._program: ctypes.c_void_p | None = None
    self._cuda_program: ctypes.c_void_p | None = None
79
+
80
def __del__(self) -> None:
    """Free the cached programs and the native component, if present.

    Attributes are read with ``getattr`` defaults so a partially
    initialised instance can still be finalised.
    """
    for slot in ("_program", "_cuda_program"):
        handle = getattr(self, slot, None)
        if handle:
            ffi._core.i_program_free(handle)
            setattr(self, slot, None)

    component = getattr(self, "_ptr", None)
    if component:
        ffi._core.i_component_free(component)
        self._ptr = None
93
+
94
def _input_count(self) -> int:
    """Return the input count reported by ``i_component_input_count``."""
    count = ctypes.c_size_t()
    status = ffi._core.i_component_input_count(self._ptr, ctypes.byref(count))  # type: ignore[arg-type]
    ffi._check(status)
    return int(count.value)
98
+
99
def _output_count(self) -> int:
    """Return the output count reported by ``i_component_output_count``."""
    count = ctypes.c_size_t()
    status = ffi._core.i_component_output_count(self._ptr, ctypes.byref(count))  # type: ignore[arg-type]
    ffi._check(status)
    return int(count.value)
103
+
104
def _input_states(self, count: int | None = None) -> tuple[int, ...]:
    """Return the native per-input state codes as a tuple of ints.

    Args:
        count: Number of inputs; queried from the component when omitted.
    """
    n_inputs = self._input_count() if count is None else count
    buffer = (ctypes.c_int * n_inputs)()
    ffi._check(ffi._core.i_component_input_states(self._ptr, buffer))  # type: ignore[arg-type]
    return tuple(int(state) for state in buffer)
110
+
111
def _bin(self, other: Component | str, fn: Any, bindings_fn: Any) -> Component:
    """Apply a binary native combinator to ``self`` and *other*.

    Args:
        other: Right-hand operand; anything that is not a ``Component`` is
            passed to the ``Component`` constructor first.
        fn: Native function taking the two component handles.
        bindings_fn: Callable computing the binding tuple of the result
            from the right-hand component.
    """
    rhs = other if isinstance(other, Component) else Component(other)
    merged_bindings = bindings_fn(rhs)
    raw = fn(self._ptr, rhs._ptr)
    return Component(
        _ptr=ffi._check_ptr(raw),  # type: ignore[arg-type]
        _bindings=merged_bindings,
    )
119
+
120
def chain(self, other: Component | str) -> Component:
    """Combine with *other* through native ``i_chain`` (operator ``>>``)."""
    native_fn = ffi._core.i_chain
    return self._bin(other, native_fn, self._chain_bindings)
122
+
123
def compose(self, other: Component | str) -> Component:
    """Combine with *other* through native ``i_compose`` (operator ``<<``)."""
    native_fn = ffi._core.i_compose
    return self._bin(other, native_fn, self._compose_bindings)
125
+
126
def fanout(self, other: Component | str) -> Component:
    """Combine with *other* through native ``i_fanout`` (operator ``&``)."""
    native_fn = ffi._core.i_fanout
    return self._bin(other, native_fn, self._fanout_bindings)
128
+
129
def pair(self, other: Component | str) -> Component:
    """Combine with *other* through native ``i_pair`` (operator ``|``)."""
    native_fn = ffi._core.i_pair
    return self._bin(other, native_fn, self._pair_bindings)
131
+
132
def swap(self) -> Component:
    """Return the result of native ``i_swap``, keeping the same bindings (operator ``~``)."""
    swapped = ffi._core.i_swap(self._ptr)  # type: ignore[arg-type]
    return Component(_ptr=ffi._check_ptr(swapped), _bindings=self._bindings)
137
+
138
def bind(self, *args: Any) -> Component:
    """Bind values to free inputs, in order, and return the bound component.

    Each positional argument is matched with the next free input. A
    ``None`` argument leaves that input free. When nothing is actually
    bound, ``self`` is returned unchanged.

    Raises:
        TypeError: More arguments were given than there are free inputs.
        RuntimeError: The native ``i_bind_input`` call failed.
    """
    free = _free_indices(self._bindings)
    if len(args) > len(free):
        raise TypeError(
            f"too many bindings: got {len(args)}, component has {len(free)} free input(s)"
        )

    bindings = list(self._bindings)
    ptr = self._ptr
    # True once ``ptr`` is an intermediate handle created here (not ``self._ptr``).
    owned_temp = False
    try:
        for physical_index, value in zip(free, args):
            if value is None:
                continue
            new_ptr = ffi._check_ptr(ffi._core.i_bind_input(ptr, physical_index))  # type: ignore[arg-type]
            previous, ptr = ptr, new_ptr
            if owned_temp:
                # The previous intermediate is superseded by the new handle.
                ffi._core.i_component_free(previous)
            owned_temp = True
            bindings[physical_index] = value
    except BaseException:
        # Do not leak the intermediate handle when a later binding fails.
        if owned_temp:
            ffi._core.i_component_free(ptr)
        raise

    if not owned_temp:
        return self
    return Component(_ptr=ptr, _bindings=tuple(bindings))
161
+
162
def __call__(self, *args: Any, into: Any = None) -> Any:
    """Execute (no arguments) or build a new component (with arguments).

    With no positional arguments the component is executed via ``exec``.
    Otherwise the arguments are processed left to right: ``Component``
    arguments are composed in, and runs of other values are bound to free
    inputs.

    Raises:
        TypeError: ``into`` was given together with positional arguments.
    """
    if not args:
        return self.exec(into=into)
    if into is not None:
        raise TypeError("into= is only valid when executing with an empty call")

    current = self
    queued: list[Any] = []
    for arg in args:
        if not isinstance(arg, Component):
            queued.append(arg)
            continue
        # Flush values collected so far before composing the component in.
        if queued:
            current = current.bind(*queued)
            queued = []
        current = current.compose(arg)

    return current.bind(*queued) if queued else current
183
+
184
def _chain_bindings(self, other: Component) -> tuple[Any | None, ...]:
    """Compute the binding tuple for ``self.chain(other)``.

    The first free inputs of *other*, as many as ``self`` has outputs, are
    dropped; the rest of *other*'s bindings follow ``self``'s bindings.
    """
    other_free = _free_indices(other._bindings)
    n_paired = min(self._output_count(), len(other_free))
    absorbed = set(other_free[:n_paired])
    remaining = tuple(
        value for position, value in enumerate(other._bindings) if position not in absorbed
    )
    return self._bindings + remaining
190
+
191
def _compose_bindings(self, other: Component) -> tuple[Any | None, ...]:
    """Compute the binding tuple for ``self.compose(other)``.

    The first free inputs of ``self``, as many as *other* has outputs, are
    dropped; the rest of ``self``'s bindings follow *other*'s bindings.
    """
    own_free = _free_indices(self._bindings)
    n_paired = min(len(own_free), other._output_count())
    absorbed = set(own_free[:n_paired])
    remaining = tuple(
        value for position, value in enumerate(self._bindings) if position not in absorbed
    )
    return other._bindings + remaining
197
+
198
def _fanout_bindings(self, other: Component) -> tuple[Any | None, ...]:
    """Compute the binding tuple for ``self.fanout(other)``.

    Free inputs are paired one-to-one between the two components; the
    paired free inputs of *other* are dropped and the rest of its bindings
    follow ``self``'s bindings.
    """
    other_free = _free_indices(other._bindings)
    n_paired = min(len(_free_indices(self._bindings)), len(other_free))
    absorbed = set(other_free[:n_paired])
    remaining = tuple(
        value for position, value in enumerate(other._bindings) if position not in absorbed
    )
    return self._bindings + remaining
204
+
205
def _pair_bindings(self, other: Component) -> tuple[Any | None, ...]:
    """Compute the binding tuple for ``self.pair(other)``: plain concatenation."""
    combined = self._bindings + other._bindings
    return combined
207
+
208
def __rshift__(self, other: Component | str) -> Component:
    """Implement ``self >> other`` as ``self.chain(other)``."""
    chained = self.chain(other)
    return chained
210
+
211
def __lshift__(self, other: Component | str) -> Component:
    """Implement ``self << other`` as ``self.compose(other)``."""
    composed = self.compose(other)
    return composed
213
+
214
def __and__(self, other: Component | str) -> Component:
    """Implement ``self & other`` as ``self.fanout(other)``."""
    fanned = self.fanout(other)
    return fanned
216
+
217
def __or__(self, other: Component | str) -> Component:
    """Implement ``self | other`` as ``self.pair(other)``."""
    paired = self.pair(other)
    return paired
219
+
220
def __invert__(self) -> Component:
    """Implement ``~self`` as ``self.swap()``."""
    swapped = self.swap()
    return swapped
222
+
223
def _compile(self, device: Device = Device.CPU) -> ctypes.c_void_p:
    """Return the compiled program for *device*, compiling it on first use.

    The CPU program is cached in ``_program``; the program for any other
    device is cached in ``_cuda_program``.
    """
    slot = "_program" if device is Device.CPU else "_cuda_program"
    program = getattr(self, slot)
    if program is None:
        program = ffi._check_ptr(
            ffi._core.i_compile(self._ptr, device._as_ffi())  # type: ignore[arg-type]
        )
        setattr(self, slot, program)
    return program
236
+
237
def _code(self, device: Device | str = Device.CPU) -> str:
    """Return the text produced by native ``i_code`` for *device*.

    The native string is always released with ``i_string_free``.

    Raises:
        RuntimeError: The native call failed or yielded no string data.
    """
    target = Device.coerce(device)
    raw = ffi._check_ptr(ffi._core.i_code(self._ptr, target._as_ffi()))  # type: ignore[arg-type]
    try:
        encoded = ctypes.cast(raw, ctypes.c_char_p).value
        if encoded is None:
            raise RuntimeError("Failed to decode None pointer")
        text = encoded.decode()
    finally:
        ffi._core.i_string_free(raw)
    return text
247
+
248
def _cuda_code(self) -> str:
    """Return ``self._code`` for ``Device.CUDA``."""
    return self._code(device=Device.CUDA)
250
+
251
def _physical_inputs(self, inputs: tuple[Any, ...]) -> tuple[Any, ...]:
    """Merge call-time *inputs* with the stored bindings.

    Call-time values fill the free inputs in order; the result has one
    entry per component input.

    Raises:
        TypeError: Too many inputs were given, or some input is still
            ``None`` after merging.
    """
    free = _free_indices(self._bindings)
    if len(inputs) > len(free):
        raise TypeError(
            f"too many inputs: got {len(inputs)}, component has {len(free)} free input(s)"
        )

    supplied = dict(zip(free, inputs))
    merged = tuple(
        supplied[position] if position in supplied else bound
        for position, bound in enumerate(self._bindings)
    )

    missing = [position for position, value in enumerate(merged) if value is None]
    if missing:
        raise TypeError(f"component is not fully bound; missing input(s) {missing}")
    return merged
266
+
267
def output_shapes(
    self, *inputs: Any, _program: ctypes.c_void_p | None = None, _physical: bool = False
) -> list[tuple[int, ...]]:
    """Return the shape of every output for the given inputs.

    Args:
        *inputs: Call-time inputs; merged with stored bindings unless
            ``_physical`` is true.
        _program: Compiled program to query; the default-device program
            from ``_compile()`` is used when omitted.
        _physical: Treat ``inputs`` as already merged with the bindings.
    """
    physical = inputs if _physical else self._physical_inputs(inputs)
    program = self._compile() if _program is None else _program
    # ``_keepalive`` is held only so it stays referenced during the native calls.
    input_arr, _keepalive = _inputs(physical)

    n_outputs = ffi._core.i_output_count(program)
    ranks = (ctypes.c_size_t * n_outputs)()
    ffi._check(ffi._core.i_output_ranks(program, ranks))

    # One size_t buffer per output, sized by that output's rank.
    size_ptr = ctypes.POINTER(ctypes.c_size_t)
    buffers: list[Any] = [(ctypes.c_size_t * rank)() for rank in ranks]
    pointer_table = (size_ptr * n_outputs)(
        *[ctypes.cast(buffer, size_ptr) for buffer in buffers]
    )
    ffi._check(
        ffi._core.i_output_shapes(program, input_arr, len(physical), pointer_table)
    )
    return [tuple(buffer) for buffer in buffers]
287
+
288
def exec(self, *inputs: Any, into: Any = None) -> Any:
    """Run the component on *inputs* and return its output(s).

    Args:
        *inputs: Values for the free inputs, in order.
        into: Optional marker selecting the output container type; when
            omitted the type is inferred from the inputs.
    """
    physical = self._physical_inputs(inputs)
    target, device = _resolve_target(physical, into)
    program = self._compile(device)

    # CPU runs that produce ilang tensors use the native-owned output path.
    use_owned = target == "tensor" and device is Device.CPU
    if use_owned:
        return self._exec_owned(program, *physical)
    return self._exec_allocated(program, target, device, *physical)
295
+
296
def _exec_owned(self, program: ctypes.c_void_p, *inputs: Any) -> Any:
    """Run *program* through ``i_exec`` and wrap the returned outputs.

    Returns a single ``Tensor`` for one output, otherwise a tuple.

    Raises:
        RuntimeError: The native call reported zero outputs.
    """
    input_arr, _keepalive = _inputs(inputs)
    result = ffi._core.i_exec(program, input_arr, len(inputs))
    if result.count == 0:
        # A zero count is treated as a failure; ``_check`` raises for non-zero codes.
        ffi._check(-1)

    owner = _OwnedOutputs(result)
    tensors: list[Tensor] = []
    for index in range(result.count):
        tensors.append(Tensor._from_owned(owner, index))
    return tensors[0] if len(tensors) == 1 else tuple(tensors)
308
+
309
def _exec_allocated(
    self, program: ctypes.c_void_p, target: str, device: Device, *inputs: Any
) -> Any:
    """Allocate outputs of the requested kind and run *program* into them.

    Args:
        program: Compiled program to run.
        target: ``"tensor"``, ``"numpy"`` or ``"torch"``.
        device: Device the outputs are allocated for.
        *inputs: Inputs already merged with the stored bindings.

    Raises:
        TypeError: Unknown *target*, or NumPy outputs requested for a
            non-CPU device.
    """
    shapes: list[tuple[int, ...]] = self.output_shapes(*inputs, _program=program, _physical=True)

    outs: list[Any]
    if target == "tensor":
        outs = [Tensor._empty(shape, device) for shape in shapes]
    elif target == "numpy":
        import numpy as np

        if device is not Device.CPU:
            raise TypeError("NumPy outputs only support CPU execution")
        outs = [np.empty(shape, dtype=np.float32) for shape in shapes]
    elif target == "torch":
        import torch

        placement = "cpu" if device is not Device.CUDA else "cuda"
        outs = [torch.empty(shape, dtype=torch.float32, device=placement) for shape in shapes]
    else:
        raise TypeError(f"unknown execution target {target!r}")

    # A lone output is passed bare so the result is returned unwrapped.
    destination = outs[0] if len(outs) == 1 else outs
    return self._exec_into(program, destination, *inputs)
334
+
335
def _exec_into(self, program: ctypes.c_void_p, outputs: Any, *inputs: Any) -> Any:
    """Run *program* via ``i_exec_into``, writing into caller-provided outputs.

    Args:
        program: Compiled program to run.
        outputs: A single output container or a tuple/list of them.
        *inputs: Inputs already merged with the stored bindings.

    Returns:
        The single output, or a tuple of the outputs when there are several.
    """
    if not isinstance(outputs, (tuple, list)):
        outputs = (outputs,)

    input_arr, _keepalive = _inputs(inputs)

    # ``converted`` stays referenced for the duration of the native call so the
    # objects returned next to each view are not collected early.
    converted = [_output(out) for out in outputs]
    views: list[ffi._CTensorMut] = [view for view, _keep in converted]
    output_arr: ctypes.Array[ffi._CTensorMut] = (ffi._CTensorMut * len(views))(*views)

    status = ffi._core.i_exec_into(program, input_arr, len(inputs), output_arr, len(views))
    ffi._check(status)

    if len(outputs) == 1:
        return outputs[0]
    return tuple(outputs)
355
+
356
def bench(
    self,
    inputs: list[Tensor],
    n_warmups: int = 10,
    n_runs: int = 100,
) -> Bench:
    """Time repeated executions of the component on *inputs*.

    Args:
        inputs: Values passed positionally to ``exec`` on every run.
        n_warmups: Untimed executions performed before measuring.
        n_runs: Timed executions; must be at least 1.

    Returns:
        A ``Bench`` with the mean, the sample standard deviation (zero when
        only one run was timed) and the individual run durations.

    Raises:
        ValueError: ``n_runs`` is smaller than 1.
    """
    # Monotonic high-resolution timer; unlike wall-clock time it cannot jump
    # when the system clock is adjusted.
    from time import perf_counter

    if n_runs < 1:
        raise ValueError("n_runs must be at least 1")

    for _ in range(n_warmups):
        self.exec(*inputs)

    runs: list[timedelta] = []
    for _ in range(n_runs):
        start = perf_counter()
        self.exec(*inputs)
        runs.append(timedelta(seconds=perf_counter() - start))

    seconds = [run.total_seconds() for run in runs]
    mean_seconds = sum(seconds) / len(seconds)
    if len(seconds) > 1:
        # Sample variance (n - 1 in the denominator).
        variance = sum((value - mean_seconds) ** 2 for value in seconds) / (len(seconds) - 1)
    else:
        # A single sample has no spread to estimate.
        variance = 0.0

    return Bench(
        mean=timedelta(seconds=mean_seconds),
        std=timedelta(seconds=sqrt(variance)),
        n_warmups=n_warmups,
        n_runs=n_runs,
        runs=runs,
    )
388
+
389
+
390
def _free_indices(bindings: tuple[Any | None, ...]) -> list[int]:
    """Return the positions in *bindings* whose entry is ``None``."""
    free: list[int] = []
    for position, bound in enumerate(bindings):
        if bound is None:
            free.append(position)
    return free
392
+
393
+
394
def _resolve_target(inputs: tuple[Any, ...], into: Any) -> tuple[str, Device]:
    """Decide the output kind and execution device for a call.

    The device must be the same for every input. The output kind comes from
    *into* when given, otherwise from the (single) kind of the inputs.

    Raises:
        TypeError: There are no inputs, the inputs disagree on device, or
            the kind cannot be inferred.
    """
    described = [_input_info(item) for item in inputs]

    device_set = {dev for _kind, dev in described}
    if not device_set:
        raise TypeError("cannot infer execution device without inputs")
    if len(device_set) > 1:
        raise TypeError("all inputs must be on the same device")
    (device,) = device_set

    if into is not None:
        return _target_from_marker(into), device

    kind_set = {kind for kind, _dev in described}
    if len(kind_set) > 1:
        raise TypeError("cannot infer execution target from mixed input tensor types")
    if not kind_set:
        raise TypeError("cannot infer execution target without inputs; pass into=...")
    return kind_set.pop(), device
412
+
413
+
414
def _target_from_marker(marker: Any) -> str:
    """Translate an ``into=`` marker into ``"tensor"``, ``"numpy"`` or ``"torch"``.

    Accepted markers are the ``Tensor`` class, a case-insensitive name
    (``tensor``/``ilang``/``i``, ``numpy``/``np``, ``torch``), ``np.ndarray``
    or ``torch.Tensor``. NumPy and Torch are optional; a missing package
    simply cannot match.

    Raises:
        TypeError: The marker is not recognised.
    """
    if marker is Tensor:
        return "tensor"

    if isinstance(marker, str):
        aliases = {
            "tensor": "tensor",
            "ilang": "tensor",
            "i": "tensor",
            "numpy": "numpy",
            "np": "numpy",
            "torch": "torch",
        }
        resolved = aliases.get(marker.lower())
        if resolved is not None:
            return resolved

    try:
        import numpy as np

        is_ndarray_type = marker is np.ndarray
    except ImportError:
        is_ndarray_type = False
    if is_ndarray_type:
        return "numpy"

    try:
        import torch

        is_torch_type = marker is torch.Tensor
    except ImportError:
        is_torch_type = False
    if is_torch_type:
        return "torch"

    raise TypeError(
        "into must be i.Tensor, 'numpy', 'torch', np.ndarray, or torch.Tensor"
    )
446
+
447
+
448
def _input_info(x: Any) -> tuple[str, Device]:
    """Classify one input as ``(kind, device)``.

    ``kind`` is ``"tensor"``, ``"numpy"`` or ``"torch"``. Values that are
    none of these but can be passed to the ``Tensor`` constructor count as
    CPU tensors.

    Raises:
        TypeError: A NumPy/Torch input has the wrong dtype or layout, or the
            value is of an unsupported type.
    """
    if isinstance(x, Tensor):
        return "tensor", x.device

    try:
        import numpy as np
    except ImportError:
        np = None
    if np is not None and isinstance(x, np.ndarray):
        acceptable = x.dtype == np.float32 and x.flags.c_contiguous
        if not acceptable:
            raise TypeError("NumPy inputs must be float32 and C-contiguous")
        return "numpy", Device.CPU

    try:
        import torch
    except ImportError:
        torch = None
    if torch is not None and isinstance(x, torch.Tensor):
        if str(x.dtype) != "torch.float32":
            raise TypeError("Torch tensors must be float32")
        if not x.is_contiguous():
            raise TypeError("Torch tensors must be contiguous")
        return "torch", (Device.CUDA if x.is_cuda else Device.CPU)

    # Last resort: accept anything the Tensor constructor accepts.
    convertible = True
    try:
        Tensor(x)
    except (TypeError, ValueError):
        convertible = False
    if convertible:
        return "tensor", Device.CPU

    raise TypeError(
        "inputs must be ilang.Tensor, NumPy arrays, Torch tensors, "
        "or Python scalars/lists"
    )
485
+
486
+
487
def _output(x: Any) -> tuple[ffi._CTensorMut, tuple[Any, ...]]:
    """Build the mutable native view for one output container.

    Returns:
        ``(view, keepalive)`` where ``keepalive`` holds the objects the view
        points into, so the caller can keep them referenced.

    Raises:
        TypeError: The container has the wrong dtype or layout, or is of an
            unsupported type.
    """
    float_ptr = ctypes.POINTER(ctypes.c_float)

    try:
        import numpy as np
    except ImportError:
        np = None
    if np is not None and isinstance(x, np.ndarray):
        acceptable = x.dtype == np.float32 and x.flags.c_contiguous
        if not acceptable:
            raise TypeError("NumPy outputs must be float32 and C-contiguous")
        dims, dims_buf = _shape_array(x.shape)
        view = ffi._CTensorMut(x.ctypes.data_as(float_ptr), dims_buf, len(dims))
        return view, (x, dims_buf)

    try:
        import torch
    except ImportError:
        torch = None
    if torch is not None and isinstance(x, torch.Tensor):
        if str(x.dtype) != "torch.float32":
            raise TypeError("Torch outputs must be float32")
        if not x.is_contiguous():
            raise TypeError("Torch outputs must be contiguous")
        dims, dims_buf = _shape_array(tuple(x.shape))
        view = ffi._CTensorMut(ctypes.cast(x.data_ptr(), float_ptr), dims_buf, len(dims))
        return view, (x, dims_buf)

    if isinstance(x, Tensor):
        return ffi._CTensorMut(x._data, x._shape_buf, len(x.shape)), (x,)

    raise TypeError("outputs must be ilang.Tensor, NumPy arrays, or Torch tensors")
518
+
519
+
520
def _shape_array(
    shape: tuple[int, ...],
) -> tuple[tuple[int, ...], Any]:
    """Return *shape* as a tuple of ints plus a matching ``c_size_t`` array."""
    dims = tuple(map(int, shape))
    array_type = ctypes.c_size_t * len(dims)
    buffer: Any = array_type(*dims)
    return dims, buffer
526
+
527
+
528
+ Component.I = Component(_ptr=ffi._core.i_identity()) # noqa: E741
529
+
ilang/ffi.py ADDED
@@ -0,0 +1,151 @@
1
+ from __future__ import annotations
2
+
3
+ import ctypes
4
+ import os
5
+ import sys
6
+ from pathlib import Path
7
+
8
+
9
def _load_core() -> ctypes.CDLL:
    """Locate and load the i-core shared library.

    The ``I_CORE_LIB`` environment variable, when set and non-empty, names
    the library directly. Otherwise the platform-specific file name is
    looked up next to this module and then in the sibling ``target/release``
    and ``target/debug`` directories.

    Raises:
        RuntimeError: No candidate file exists.
    """
    explicit = os.environ.get("I_CORE_LIB")
    if explicit:
        return ctypes.CDLL(explicit)

    if sys.platform == "darwin":
        filenames = ["libi_core.dylib"]
    elif sys.platform == "win32":
        filenames = ["i_core.dll"]
    else:
        filenames = ["libi_core.so"]

    package_dir = Path(__file__).resolve().parent
    build_dir = package_dir.parent / "target"
    search_dirs = (package_dir, build_dir / "release", build_dir / "debug")

    for directory in search_dirs:
        for filename in filenames:
            candidate = directory / filename
            if candidate.exists():
                return ctypes.CDLL(str(candidate))

    raise RuntimeError("could not find i-core library; run `cargo build -p i-core`")
33
+
34
+
35
# NOTE: ``_core`` is assigned once, near the end of this module, immediately
# before ``_bind_functions(_core)``. Loading the library here as well was
# redundant: nothing below reads ``_core`` at import time before that point.
36
+
37
+
38
def _check_ptr(ptr: ctypes.c_void_p | None) -> ctypes.c_void_p:
    """Return *ptr* unchanged, or raise when it is null/falsy.

    Raises:
        RuntimeError: With the message from ``i_error()`` when available,
            otherwise a generic one.
    """
    if ptr:
        return ptr
    message = _core.i_error()
    raise RuntimeError(message.decode() if message else "i-core error")
43
+
44
+
45
def _check(code: int) -> None:
    """Raise when a native status *code* is non-zero.

    Raises:
        RuntimeError: With the message from ``i_error()`` when available,
            otherwise a generic one.
    """
    if code == 0:
        return
    message = _core.i_error()
    raise RuntimeError(message.decode() if message else "i-core error")
49
+
50
+
51
class _CTensor(ctypes.Structure):
    """ctypes layout of the tensor view passed as native *input* arrays.

    Used as the input element type in the ``i_output_shapes``,
    ``i_exec_into`` and ``i_exec`` signatures.
    """

    _fields_ = [
        # Pointer to the float elements.
        ("data", ctypes.POINTER(ctypes.c_float)),
        # Pointer to ``rank`` dimension sizes.
        ("shape", ctypes.POINTER(ctypes.c_size_t)),
        # Number of dimensions.
        ("rank", ctypes.c_size_t),
    ]
57
+
58
+
59
class _CTensorMut(ctypes.Structure):
    """ctypes layout of the tensor view passed as native *output* arrays.

    Same fields as ``_CTensor``; used as the output element type in the
    ``i_exec_into`` signature.
    """

    _fields_ = [
        # Pointer to the float elements.
        ("data", ctypes.POINTER(ctypes.c_float)),
        # Pointer to ``rank`` dimension sizes.
        ("shape", ctypes.POINTER(ctypes.c_size_t)),
        # Number of dimensions.
        ("rank", ctypes.c_size_t),
    ]
65
+
66
+
67
class _COwnedTensor(ctypes.Structure):
    """ctypes layout of one element of the ``_COutputs.tensors`` array."""

    _fields_ = [
        # Pointer to the float elements.
        ("data", ctypes.POINTER(ctypes.c_float)),
        # Pointer to ``rank`` dimension sizes.
        ("shape", ctypes.POINTER(ctypes.c_size_t)),
        # Number of dimensions.
        ("rank", ctypes.c_size_t),
        # NOTE(review): presumably the element count of ``data`` — confirm against i-core.
        ("len", ctypes.c_size_t),
    ]
74
+
75
+
76
class _COutputs(ctypes.Structure):
    """ctypes layout of the value returned by ``i_exec`` and taken by ``i_outputs_free``."""

    _fields_ = [
        # Pointer to the first of ``count`` owned tensors.
        ("tensors", ctypes.POINTER(_COwnedTensor)),
        # Number of tensors in the array.
        ("count", ctypes.c_size_t),
    ]
81
+
82
+
83
def _bind_functions(core: ctypes.CDLL) -> None:
    """Declare ``argtypes``/``restype`` for every i-core function used here.

    Functions listed with ``keep_default`` get only ``argtypes``; their
    ``restype`` is left at the ctypes default.
    """
    void_p = ctypes.c_void_p
    size_t = ctypes.c_size_t
    c_int = ctypes.c_int
    float_p = ctypes.POINTER(ctypes.c_float)
    size_p = ctypes.POINTER(size_t)
    tensors_in = ctypes.POINTER(_CTensor)
    tensors_out = ctypes.POINTER(_CTensorMut)
    keep_default = object()  # marker: do not assign ``restype``

    # (symbol, argtypes, restype) in declaration order.
    signatures = [
        ("i_parse", [ctypes.c_char_p], void_p),
        ("i_identity", [], void_p),
        ("i_chain", [void_p, void_p], void_p),
        ("i_compose", [void_p, void_p], void_p),
        ("i_fanout", [void_p, void_p], void_p),
        ("i_pair", [void_p, void_p], void_p),
        ("i_swap", [void_p], void_p),
        ("i_bind_input", [void_p, size_t], void_p),
        ("i_component_input_count", [void_p, size_p], c_int),
        ("i_component_output_count", [void_p, size_p], c_int),
        ("i_component_input_states", [void_p, ctypes.POINTER(c_int)], c_int),
        ("i_code", [void_p, c_int], void_p),
        ("i_compile", [void_p, c_int], void_p),
        ("i_program_device", [void_p], c_int),
        ("i_alloc", [c_int, size_t], void_p),
        ("i_free", [c_int, float_p], keep_default),
        ("i_copy", [c_int, float_p, c_int, float_p, size_t], c_int),
        ("i_output_count", [void_p], size_t),
        ("i_output_ranks", [void_p, size_p], c_int),
        ("i_output_shapes", [void_p, tensors_in, size_t, ctypes.POINTER(size_p)], c_int),
        ("i_exec_into", [void_p, tensors_in, size_t, tensors_out, size_t], c_int),
        ("i_exec", [void_p, tensors_in, size_t], _COutputs),
        ("i_component_free", [void_p], keep_default),
        ("i_program_free", [void_p], keep_default),
        ("i_outputs_free", [_COutputs], keep_default),
        ("i_string_free", [void_p], keep_default),
        ("i_error", [], ctypes.c_char_p),
    ]

    for symbol, argtypes, restype in signatures:
        function = getattr(core, symbol)
        function.argtypes = argtypes
        if restype is not keep_default:
            function.restype = restype
146
+
147
+
148
+ _core: ctypes.CDLL = _load_core()
149
+ _bind_functions(_core)
150
+
151
+ __all__ = ["_core", "_check_ptr", "_check", "_CTensor", "_CTensorMut"]