bafe-engine 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bafe/__init__.py ADDED
@@ -0,0 +1,974 @@
1
+ """BAFE - Basic Algebra Fusion Engine.
2
+
3
+ Public API:
4
+ import bafe
5
+
6
+ @bafe.jit
7
+ def f(A, B, C):
8
+ return bafe.relu(bafe.matmul(A, B) + C)
9
+
10
+ The decorator traces the function (which builds an IR graph via the op
11
+ functions), runs the BAFE optimization pipeline, JIT-compiles the result,
12
+ and returns a callable that invokes the compiled kernel directly.
13
+ """
14
+ from __future__ import annotations
15
+
16
+ import ctypes
17
+ import functools
18
+ import os
19
+ from typing import Callable, List, Tuple, Any
20
+
21
+ import numpy as np
22
+
23
+ from bafe._binding import (
24
+ _lib, BafeGraph, BafeShape, BafeOpAttrs, BafeNode,
25
+ make_shape, make_attrs, graph_summary,
26
+ BAFE_MAX_NODES, BAFE_MAX_CHILDREN,
27
+ )
28
+
29
+ __version__ = "0.1.0"
30
+
31
+ # ---------------------------------------------------------------------------
32
+ # Dtype mapping
33
+ # ---------------------------------------------------------------------------
34
+
35
# NumPy dtype -> BAFE dtype enum value. The position of each name in the
# tuple is its enum code: BAFE_DTYPE_F32=0, BAFE_DTYPE_F64=1,
# BAFE_DTYPE_I32=2, BAFE_DTYPE_I64=3.
_NP_TO_BAFE = {
    np.dtype(type_name): code
    for code, type_name in enumerate(("float32", "float64", "int32", "int64"))
}

# Reverse lookup: BAFE dtype enum value -> NumPy dtype.
_BAFE_TO_NP = dict(zip(_NP_TO_BAFE.values(), _NP_TO_BAFE.keys()))
43
+
44
+
45
+ # ---------------------------------------------------------------------------
46
+ # Tracing context
47
+ # ---------------------------------------------------------------------------
48
+
49
class _TraceContext:
    """Owns the graph under construction and the names of its inputs.

    The graph is a ``BafeGraph`` initialised through ``bafe_graph_init``;
    every method forwards to the native library and raises on a negative
    node id.
    """

    def __init__(self):
        self.graph = BafeGraph()
        _lib.bafe_graph_init(ctypes.byref(self.graph))
        self.inputs: List[Tuple[str, "Tensor"]] = []
        self.input_names: List[str] = []

    def add_input(self, name: str, shape: Tuple[int, ...], dtype: int,
                  layout: int = 0) -> int:
        """Register an input node and return its node id.

        ``layout`` is a bafe_layout enum value (0=row, 1=col). Layout 0 goes
        through the plain ``bafe_graph_add_input`` entry point (kept for
        backward compatibility); any other value uses the ``_with_layout``
        variant.

        Raises:
            RuntimeError: if the native call returns a negative id.
        """
        shape_struct = make_shape(shape)
        common = (
            ctypes.byref(self.graph),
            name.encode("utf-8"),
            ctypes.byref(shape_struct),
            ctypes.c_int(dtype),
        )
        if layout != 0:
            node_id = _lib.bafe_graph_add_input_with_layout(
                *common, ctypes.c_int(layout)
            )
        else:
            node_id = _lib.bafe_graph_add_input(*common)
        if node_id < 0:
            raise RuntimeError(f"failed to add input {name}")
        self.input_names.append(name)
        return node_id

    def add_op(self, op_name: str, children: List[int], **attrs) -> int:
        """Append an op node with the given child node ids; return its id.

        Keyword arguments are packed by ``make_attrs``; with no children
        and/or no attributes, ``None`` is passed for the matching pointer.

        Raises:
            RuntimeError: if the native call returns a negative id.
        """
        count = len(children)
        child_ids = (ctypes.c_int32 * count)(*children) if count else None
        op_attrs = make_attrs(**attrs) if attrs else None
        attrs_ref = ctypes.byref(op_attrs) if op_attrs else None
        node_id = _lib.bafe_graph_add(
            ctypes.byref(self.graph),
            op_name.encode("utf-8"),
            child_ids,
            ctypes.c_int(count),
            attrs_ref,
        )
        if node_id < 0:
            raise RuntimeError(f"failed to add op {op_name} (children={children})")
        return node_id

    def set_output(self, nid: int) -> None:
        """Mark node ``nid`` as a graph output."""
        graph_ref = ctypes.byref(self.graph)
        _lib.bafe_graph_set_output(graph_ref, ctypes.c_int32(nid))
99
+
100
+
101
# The trace currently being recorded, or None when no trace is active.
_TRACE: _TraceContext | None = None


def _ensure_trace() -> _TraceContext:
    """Return the active trace context.

    Raises:
        RuntimeError: if the module-level ``_TRACE`` slot is None.
    """
    active = _TRACE
    if active is not None:
        return active
    raise RuntimeError(
        "no active trace - bafe ops can only be called inside a "
        "function decorated with @bafe.jit"
    )
112
+
113
+
114
+ # ---------------------------------------------------------------------------
115
+ # Tensor handle (symbolic during tracing)
116
+ # ---------------------------------------------------------------------------
117
+
118
class Tensor:
    """Symbolic handle to a node of the graph being traced.

    A Tensor carries the node id, the shape tuple and the BAFE dtype code of
    the node it stands for. The arithmetic operators record new graph nodes
    through the active trace, so ordinary math expressions build the IR.
    """
    __slots__ = ("node_id", "shape", "dtype", "_name")

    def __init__(self, node_id: int, shape: Tuple[int, ...], dtype: int, name: str | None = None):
        self._name = name
        self.dtype = dtype
        self.shape = shape
        self.node_id = node_id

    @property
    def name(self) -> str | None:
        """Name given at construction time, or None."""
        return self._name

    def __add__(self, other: "Tensor") -> "Tensor":
        return _binop("add", self, other)

    def __sub__(self, other: "Tensor") -> "Tensor":
        return _binop("sub", self, other)

    def __mul__(self, other: "Tensor") -> "Tensor":
        return _binop("mul", self, other)

    def __matmul__(self, other: "Tensor") -> "Tensor":
        """Record a ``matmul`` node for two rank-2 tensors.

        Raises:
            ValueError: if either operand is not rank 2, or the inner
                dimensions differ.
        """
        trace = _ensure_trace()
        if not (len(self.shape) == 2 and len(other.shape) == 2):
            raise ValueError("matmul requires rank-2 tensors")
        if self.shape[1] != other.shape[0]:
            raise ValueError(
                f"matmul shape mismatch: {self.shape} @ {other.shape}"
            )
        node = trace.add_op("matmul", [self.node_id, other.node_id])
        rows, cols = self.shape[0], other.shape[1]
        # The result keeps the left operand's dtype code.
        return Tensor(node, (rows, cols), self.dtype)

    def __repr__(self) -> str:
        np_dtype = _BAFE_TO_NP[self.dtype]
        return f"Tensor(shape={self.shape}, dtype={np_dtype})"
160
+
161
+
162
def _binop(op: str, a: Tensor, b: Tensor) -> Tensor:
    """Record an elementwise binary op node and return its Tensor handle.

    The result shape comes from ``_broadcast_shapes`` and the dtype code is
    taken from the left operand ``a``.
    """
    trace = _ensure_trace()
    result_shape = _broadcast_shapes(a.shape, b.shape)
    operands = [a.node_id, b.node_id]
    return Tensor(trace.add_op(op, operands), result_shape, a.dtype)
168
+
169
+
170
def _broadcast_shapes(a, b):
    """Return the broadcast of shapes ``a`` and ``b`` as a tuple.

    The shorter shape is left-padded with 1s; on each axis the two extents
    must be equal or one of them must be 1. The previous implementation took
    ``max()`` per axis, which silently accepted incompatible shapes such as
    (3,) and (4,) and produced 1 instead of 0 when broadcasting 1 against 0.

    Raises:
        ValueError: if the shapes cannot be broadcast together.
    """
    a, b = tuple(a), tuple(b)
    rank = max(len(a), len(b))
    a_pad = (1,) * (rank - len(a)) + a
    b_pad = (1,) * (rank - len(b)) + b
    out = []
    for x, y in zip(a_pad, b_pad):
        if x == y or y == 1:
            out.append(x)
        elif x == 1:
            out.append(y)
        else:
            raise ValueError(f"cannot broadcast shapes {a} and {b}")
    return tuple(out)
178
+
179
+
180
+ # ---------------------------------------------------------------------------
181
+ # Op functions (used inside @bafe.jit functions)
182
+ # ---------------------------------------------------------------------------
183
+
184
def input(shape: Tuple[int, ...], dtype: str | np.dtype = "float32",
          name: str = "x", layout: str = "row") -> Tensor:
    """Declare an input tensor on the active trace.

    Usually called automatically by @bafe.jit based on the function's
    arguments, but can also be called explicitly.

    Args:
        shape: extents of the input; each entry is converted with ``int``.
        dtype: anything ``np.dtype`` accepts; must map to a BAFE dtype.
        name: name passed to the graph for this input.
        layout: one of "row" (default), "col", "blocked" or "tc"; mapped to
            the layout codes 0, 1, 2 and 3 respectively.

    Raises:
        RuntimeError: if no trace is active.
        ValueError: if the dtype or layout is not supported.
    """
    tc = _ensure_trace()
    np_dt = np.dtype(dtype)
    if np_dt not in _NP_TO_BAFE:
        raise ValueError(f"unsupported dtype {dtype}; supported: {list(_NP_TO_BAFE)}")
    bafe_dt = _NP_TO_BAFE[np_dt]
    layout_map = {"row": 0, "col": 1, "blocked": 2, "tc": 3}
    if layout not in layout_map:
        raise ValueError(f"unsupported layout {layout!r}; supported: {list(layout_map)}")
    # Materialise the shape once: the original iterated `shape` twice, so a
    # one-shot iterable left the returned Tensor with an empty shape.
    dims = tuple(int(s) for s in shape)
    nid = tc.add_input(name, dims, bafe_dt, layout=layout_map[layout])
    return Tensor(nid, dims, bafe_dt, name=name)
207
+
208
+
209
def matmul(a: Tensor, b: Tensor) -> Tensor:
    """Function form of ``a @ b``."""
    product = a @ b
    return product
211
+
212
+
213
def add(a: Tensor, b: Tensor) -> Tensor:
    """Function form of ``a + b``."""
    total = a + b
    return total
215
+
216
+
217
def sub(a: Tensor, b: Tensor) -> Tensor:
    """Function form of ``a - b``."""
    difference = a - b
    return difference
219
+
220
+
221
def mul(a: Tensor, b: Tensor) -> Tensor:
    """Function form of ``a * b``."""
    product = a * b
    return product
223
+
224
+
225
def relu(x: Tensor) -> Tensor:
    """Record a ``relu`` node on ``x``; the result keeps x's shape and dtype."""
    node = _ensure_trace().add_op("relu", [x.node_id])
    return Tensor(node, x.shape, x.dtype)
229
+
230
+
231
def sigmoid(x: Tensor) -> Tensor:
    """Record a ``sigmoid`` node on ``x``; the result keeps x's shape and dtype."""
    node = _ensure_trace().add_op("sigmoid", [x.node_id])
    return Tensor(node, x.shape, x.dtype)
235
+
236
+
237
def tanh(x: Tensor) -> Tensor:
    """Record a ``tanh`` node on ``x``; the result keeps x's shape and dtype."""
    node = _ensure_trace().add_op("tanh", [x.node_id])
    return Tensor(node, x.shape, x.dtype)
241
+
242
+
243
def bias_add(x: Tensor, bias: Tensor) -> Tensor:
    """Record a ``bias_add`` node with children ``x`` and ``bias``.

    The returned Tensor takes its shape and dtype from ``x``.
    """
    operands = [x.node_id, bias.node_id]
    node = _ensure_trace().add_op("bias_add", operands)
    return Tensor(node, x.shape, x.dtype)
247
+
248
+
249
def transpose(x: Tensor, perm: Tuple[int, ...]) -> Tensor:
    """Record a ``transpose`` node that reorders the axes of ``x`` by ``perm``.

    ``perm`` is validated before anything is appended to the graph, so an
    invalid permutation no longer leaves an orphan node behind. Negative
    entries are accepted for validation and shape inference and are passed
    to the op attributes unchanged.

    Raises:
        RuntimeError: if no trace is active.
        ValueError: if ``perm`` is not a permutation of the axes of ``x``.
    """
    tc = _ensure_trace()
    # Materialise once so a one-shot iterable is not exhausted by the first use.
    perm = tuple(perm)
    rank = len(x.shape)
    normalized = sorted(p + rank if p < 0 else p for p in perm)
    if normalized != list(range(rank)):
        raise ValueError(
            f"transpose perm {perm} is not a permutation of the {rank} axes "
            f"of shape {x.shape}"
        )
    out_shape = tuple(x.shape[p] for p in perm)
    nid = tc.add_op("transpose", [x.node_id], perm=list(perm))
    return Tensor(nid, out_shape, x.dtype)
254
+
255
+
256
def reduce_sum(x: Tensor, axes: Tuple[int, ...], keepdims: bool = False) -> Tensor:
    """Record a ``reduce_sum`` node over ``axes`` of ``x``.

    The output shape is computed by ``_reduce_shape`` before the node is
    appended, so a shape error does not leave an orphan node in the graph.
    ``axes`` is materialised once, which keeps one-shot iterables usable.
    """
    tc = _ensure_trace()
    axes = tuple(axes)
    out_shape = _reduce_shape(x.shape, axes, keepdims)
    nid = tc.add_op("reduce_sum", [x.node_id], axes=list(axes), keepdims=keepdims)
    return Tensor(nid, out_shape, x.dtype)
261
+
262
+
263
def reduce_max(x: Tensor, axes: Tuple[int, ...], keepdims: bool = False) -> Tensor:
    """Record a ``reduce_max`` node over ``axes`` of ``x``.

    The output shape is computed by ``_reduce_shape`` before the node is
    appended, so a shape error does not leave an orphan node in the graph.
    ``axes`` is materialised once, which keeps one-shot iterables usable.
    """
    tc = _ensure_trace()
    axes = tuple(axes)
    out_shape = _reduce_shape(x.shape, axes, keepdims)
    nid = tc.add_op("reduce_max", [x.node_id], axes=list(axes), keepdims=keepdims)
    return Tensor(nid, out_shape, x.dtype)
268
+
269
+
270
def reshape(x: Tensor, shape: Tuple[int, ...]) -> Tensor:
    """Record a ``reshape`` node giving ``x`` the new ``shape``.

    ``shape`` is materialised once; the original iterated it twice, so a
    one-shot iterable produced a Tensor with an empty shape. No element-count
    check is done here.
    """
    tc = _ensure_trace()
    dims = tuple(shape)
    nid = tc.add_op("reshape", [x.node_id], shape=list(dims))
    return Tensor(nid, dims, x.dtype)
274
+
275
+
276
def broadcast_to(x: Tensor, shape: Tuple[int, ...]) -> Tensor:
    """Record a ``broadcast_to`` node expanding ``x`` to ``shape``.

    ``shape`` is materialised once; the original iterated it twice, so a
    one-shot iterable produced a Tensor with an empty shape. Compatibility
    of ``x.shape`` with ``shape`` is not checked here.
    """
    tc = _ensure_trace()
    dims = tuple(shape)
    nid = tc.add_op("broadcast_to", [x.node_id], shape=list(dims))
    return Tensor(nid, dims, x.dtype)
280
+
281
+
282
def _reduce_shape(in_shape, axes, keepdims):
    """Return the shape that results from reducing ``in_shape`` over ``axes``.

    Negative axes count from the end. Reduced axes are dropped, or kept with
    extent 1 when ``keepdims`` is true.

    Raises:
        ValueError: if an axis is outside ``[-rank, rank)``. The previous
            ``axis % rank`` silently wrapped out-of-range axes (axis 5 on a
            rank-2 shape reduced axis 1) and raised ZeroDivisionError for a
            rank-0 shape.
    """
    rank = len(in_shape)
    reduced = set()
    for axis in axes:
        if not -rank <= axis < rank:
            raise ValueError(
                f"reduction axis {axis} is out of range for rank-{rank} "
                f"shape {tuple(in_shape)}"
            )
        reduced.add(axis % rank)
    out = []
    for i, d in enumerate(in_shape):
        if i not in reduced:
            out.append(d)
        elif keepdims:
            out.append(1)
    return tuple(out)
292
+
293
+
294
+ # ---------------------------------------------------------------------------
295
+ # @bafe.jit decorator
296
+ # ---------------------------------------------------------------------------
297
+
298
class JittedFunction:
    """A JIT-compiled BAFE function.

    On a call with concrete numpy arrays it:
      1. Builds a cache key from each input's shape, dtype and layout
      2. On a cache miss, runs the user function under a trace
      3. Calls bafe_optimize (or bafe_optimize_with_budget) and
         bafe_jit_get_or_compile
      4. Builds a ctypes function pointer with one void* per input plus one
         for the output
      5. Invokes the kernel and returns the freshly allocated output

    Later calls with the same shapes/dtypes/layouts reuse the cached entry.

    Phase 2 (issue #1): if a search budget is set, the budgeted optimizer
    entry point is used. Phase 3: optional time budget and autotune loop.
    """

    def __init__(self, fn: Callable, budget=None, autotune=False, time_budget_ms=None):
        self._fn = fn
        self._budget = budget                  # BafeSearchBudget or None (deterministic)
        self._autotune = autotune              # Phase 3: enable the autotune loop
        self._time_budget_ms = time_budget_ms  # Phase 3 (issue #4): pruning time budget
        self._compiled = {}                    # key -> 8-tuple returned by _compile
        self._call_count = {}                  # key -> number of autotuned calls so far
        functools.update_wrapper(self, fn)

    @staticmethod
    def _layout_of(a: np.ndarray) -> str:
        """Return "col" for a strictly Fortran-ordered array of rank >= 2, else "row"."""
        if a.ndim >= 2 and a.flags["F_CONTIGUOUS"] and not a.flags["C_CONTIGUOUS"]:
            return "col"
        return "row"

    @staticmethod
    def _as_kernel_input(a: np.ndarray, dt) -> np.ndarray:
        """Return ``a`` as a dense array of dtype ``dt`` in the layout it was keyed under.

        The kernel only receives a raw data pointer, so a strided view (which
        is tagged "row") has to be packed into C order first; the original
        passed such views through untouched. Arrays that already satisfy the
        requirement are not copied.
        """
        if a.ndim >= 2 and a.flags["F_CONTIGUOUS"] and not a.flags["C_CONTIGUOUS"]:
            return np.asfortranarray(a, dtype=dt)
        return np.ascontiguousarray(a, dtype=dt)

    def __call__(self, *args: np.ndarray) -> np.ndarray:
        """Run the compiled kernel on ``args`` and return the output array.

        Raises:
            TypeError: if no arguments are given or one is not an ndarray.
        """
        if not args:
            raise TypeError("jitted function requires at least one input")
        for a in args:
            if not isinstance(a, np.ndarray):
                raise TypeError(f"expected numpy array, got {type(a)}")

        # Phase 2: layout is part of the key so row-major and col-major
        # inputs get different compiled kernels.
        key = tuple((a.shape, str(a.dtype), self._layout_of(a)) for a in args)
        if key not in self._compiled:
            self._compiled[key] = self._compile(args)

        (fn_ptr, sig, in_dtypes, out_shape, out_dtype,
         opt_graph, graph_hash, predicted_cost) = self._compiled[key]

        out = np.zeros(out_shape, dtype=out_dtype)

        # Inputs first, output pointer last. data_as() keeps a reference to
        # its array, so any packed temporaries stay alive through the call.
        buffers = [self._as_kernel_input(a, dt) for a, dt in zip(args, in_dtypes)]
        buffers.append(out)
        c_args = [buf.ctypes.data_as(ctypes.c_void_p) for buf in buffers]

        kernel = sig(fn_ptr)
        if self._autotune_enabled():
            self._run_autotuned(kernel, c_args, key, opt_graph, graph_hash,
                                predicted_cost)
        else:
            kernel(*c_args)
        return out

    def _run_autotuned(self, kernel, c_args, key, opt_graph, graph_hash,
                       predicted_cost) -> None:
        """Run the kernel under the autotune loop (Phase 3, issue #6).

        The first ``warmup_calls`` calls per key run untimed. Afterwards the
        kernel is run ``timing_iters`` times, the mean wall time is logged
        with the graph features, and a refit is attempted whenever the log
        size is a positive multiple of ``refit_threshold``.
        """
        import time

        config = _lib.bafe_autotune_get_config()
        # Calls are driven from Python, so the per-key count is tracked here.
        call_num = self._call_count.get(key, 0) + 1
        self._call_count[key] = call_num

        if call_num <= config.warmup_calls:
            kernel(*c_args)
            return

        iters = config.timing_iters if config.timing_iters > 0 else 1
        t0 = time.perf_counter()
        for _ in range(iters):
            kernel(*c_args)
        t1 = time.perf_counter()
        observed_ms = (t1 - t0) * 1000.0 / iters

        features = (ctypes.c_double * 8)()
        _lib.bafe_profiling_extract_features(ctypes.byref(opt_graph), features)
        _lib.bafe_profiling_add(
            graph_hash.encode("utf-8") if isinstance(graph_hash, str) else graph_hash,
            features,
            ctypes.c_double(predicted_cost),
            ctypes.c_double(observed_ms),
            ctypes.c_int(0),
        )

        log = _lib.bafe_profiling_get_log().contents
        threshold = config.refit_threshold
        # threshold <= 0 would make the modulo raise ZeroDivisionError.
        if threshold > 0 and log.n > 0 and log.n % threshold == 0:
            # Phase 3 (issue #5): drop cached kernels only when the refit
            # reported success (0), so the next call re-optimizes with the
            # new cost model. The C-level JIT cache is invalidated as well.
            # NOTE(review): that cache looks process-wide; confirm pointers
            # held by other JittedFunction instances remain valid.
            if _lib.bafe_profiling_refit() == 0:
                self._compiled.clear()
                _lib.bafe_jit_invalidate_memory_cache()

    def _autotune_enabled(self) -> bool:
        """Check if autotune is enabled for this function."""
        return getattr(self, "_autotune", False)

    def _compile(self, args: Tuple[np.ndarray, ...]):
        """Trace, optimize and JIT-compile the wrapped function for ``args``.

        Returns:
            An 8-tuple ``(fn_ptr, sig, in_dtypes, out_shape, out_dtype,
            opt_graph, graph_hash, predicted_cost)``.

        Raises:
            TypeError: on an unsupported input dtype, or if the traced
                function does not return a Tensor.
            RuntimeError: if optimization or JIT compilation fails.
        """
        global _TRACE
        _TRACE = _TraceContext()
        try:
            var_names = self._fn.__code__.co_varnames
            in_tensors = []
            for i, a in enumerate(args):
                np_dt = np.dtype(a.dtype)
                if np_dt not in _NP_TO_BAFE:
                    raise TypeError(f"unsupported dtype {np_dt}")
                name = var_names[i] if i < len(var_names) else f"in{i}"
                # Phase 2: tag the input with the layout detected from the
                # array flags.
                in_tensors.append(
                    input(a.shape, np_dt, name=name, layout=self._layout_of(a))
                )

            result = self._fn(*in_tensors)
            if not isinstance(result, Tensor):
                raise TypeError(
                    f"jitted function must return a Tensor, got {type(result)}"
                )
            _TRACE.set_output(result.node_id)
            in_graph = _TRACE.graph
        finally:
            _TRACE = None

        # Phase 2 (issue #1) / Phase 3 (issue #4): a budget or a time budget
        # selects the budgeted optimizer entry point.
        optimized = BafeGraph()
        err_buf = ctypes.create_string_buffer(256)
        if self._budget is not None or self._time_budget_ms is not None:
            if self._budget is not None:
                budget = self._budget
            else:
                budget = _lib.bafe_search_budget_default()
            if self._time_budget_ms is not None:
                budget.time_budget_ms = int(self._time_budget_ms)
            rc = _lib.bafe_optimize_with_budget(
                ctypes.byref(in_graph),
                ctypes.byref(optimized),
                ctypes.byref(budget),
                err_buf,
                ctypes.c_size_t(len(err_buf)),
            )
        else:
            rc = _lib.bafe_optimize(
                ctypes.byref(in_graph),
                ctypes.byref(optimized),
                err_buf,
                ctypes.c_size_t(len(err_buf)),
            )
        if rc != 0:
            raise RuntimeError(
                f"bafe_optimize failed (code {rc}): {err_buf.value.decode()}"
            )

        fn_ptr = _lib.bafe_jit_get_or_compile(
            ctypes.byref(optimized),
            err_buf,
            ctypes.c_size_t(len(err_buf)),
        )
        if not fn_ptr:
            raise RuntimeError(
                f"bafe_jit_get_or_compile failed: {err_buf.value.decode()}"
            )

        # Every kernel argument (inputs, then the output) is passed as void*.
        in_dtypes = [a.dtype for a in args]
        sig = ctypes.CFUNCTYPE(None, *([ctypes.c_void_p] * (len(args) + 1)))

        # Output shape and dtype are read from the optimized graph's first output.
        out_node = optimized.nodes[optimized.outputs[0]]
        out_shape = tuple(out_node.shape.dims[i] for i in range(out_node.shape.rank))
        out_np_dt = _BAFE_TO_NP[out_node.dtype]

        # Phase 3 (issue #6): graph hash and predicted cost for autotune logging.
        graph_hash_buf = ctypes.create_string_buffer(65)
        _lib.bafe_jit_hash_graph(ctypes.byref(optimized), graph_hash_buf, ctypes.c_size_t(65))
        graph_hash = graph_hash_buf.value.decode("utf-8")

        cm = _lib.bafe_cost_model_default()
        predicted_cost = _lib.bafe_cost_graph(ctypes.byref(cm), ctypes.byref(optimized))

        # Byte copy of the optimized graph, kept for feature extraction.
        opt_graph_copy = BafeGraph()
        ctypes.memmove(ctypes.byref(opt_graph_copy), ctypes.byref(optimized),
                       ctypes.sizeof(BafeGraph))

        return (fn_ptr, sig, in_dtypes, out_shape, out_np_dt,
                opt_graph_copy, graph_hash, predicted_cost)
533
+
534
+
535
def jit(fn: Callable = None, *, budget=None, iters: int = None,
        temperature: float = None, seed: int = None, autotune: bool = False,
        time_budget_ms: int = None):
    """Decorator: trace + optimize + JIT-compile a tensor function.

    Works both bare (``@bafe.jit``) and with options (``@bafe.jit(...)``).
    Passing the function together with options, ``jit(f, iters=8)``, now
    honours them; previously ``budget``, ``iters``, ``temperature`` and
    ``seed`` were silently dropped on that path.

    Phase 2 (issue #1): optional stochastic search parameters.
    Phase 3 (issue #4): optional time-budget pruning.
    Phase 3 (issue #6): optional autotune loop.

    Args:
        fn: the function to wrap, or None when used with options.
        budget: a BafeSearchBudget for full control.
        iters: number of stochastic passes (default 4 if budget mode on).
        temperature: 0.0 = greedy, high = explore randomly (default 1.0).
        seed: PRNG seed for reproducibility (default 0xBAFE5EED).
        autotune: if True, enable the auto-tuning loop.
        time_budget_ms: wall-clock limit for the optimization search.
            Controls the pruning regime:
              <= 1 ms:   greedy (Level A+B only, beam=1)
              <= 10 ms:  light (A+B+C, beam=4)
              <= 100 ms: beam (A+B+C+D, beam=16)
              > 100 ms:  deep (all tiers, beam=64)
            0 or None means no limit (uses stochastic search).

    Returns:
        A JittedFunction when ``fn`` is given, otherwise a decorator.

    Examples:
        @bafe.jit
        def f(A, B): ...              # deterministic (default)

        @bafe.jit(time_budget_ms=100)
        def f(A, B): ...              # 100ms pruning budget

        @bafe.jit(time_budget_ms=1000, autotune=True)
        def f(A, B): ...              # 1s budget + autotune
    """
    def deco(target):
        b = budget
        wants_search = (iters is not None or temperature is not None
                        or seed is not None)
        if b is None and wants_search:
            b = make_search_budget(
                max_iters=iters if iters is not None else 4,
                temperature=temperature if temperature is not None else 1.0,
                seed=seed if seed is not None else 0xBAFE5EED,
            )
        if b is not None and time_budget_ms is not None:
            b.time_budget_ms = int(time_budget_ms)
        return JittedFunction(target, budget=b, autotune=autotune,
                              time_budget_ms=time_budget_ms)

    # Single code path for both decorator forms so options are never lost.
    return deco if fn is None else deco(fn)
583
+
584
+
585
def make_search_budget(max_iters: int = 4, max_nodes: int = 256,
                       max_rewrites: int = 64, time_budget_ms: int = 0,
                       temperature: float = 1.0, seed: int = 0xBAFE5EED,
                       enable_multi_pass: bool = True):
    """Build a BafeSearchBudget for use with @bafe.jit(budget=...).

    Each argument is coerced and stored on the field of the same name:
      - max_iters, max_nodes, max_rewrites, time_budget_ms: ``int``
      - temperature: ``float``
      - seed: ``int`` masked to the low 32 bits
      - enable_multi_pass: ``bool``
    """
    from bafe._binding import BafeSearchBudget

    fields = (
        ("max_iters", int(max_iters)),
        ("max_nodes", int(max_nodes)),
        ("max_rewrites", int(max_rewrites)),
        ("time_budget_ms", int(time_budget_ms)),
        ("temperature", float(temperature)),
        ("seed", int(seed) & 0xFFFFFFFF),
        ("enable_multi_pass", bool(enable_multi_pass)),
    )
    result = BafeSearchBudget()
    for field_name, value in fields:
        setattr(result, field_name, value)
    return result
612
+
613
+
614
+ # expose the optimize function for low-level use
615
def optimize(graph: BafeGraph) -> BafeGraph:
    """Run the BAFE optimization pipeline on ``graph`` and return the result.

    The original body used bare ``byref`` and ``c_size_t``, which this module
    never imports (only ``import ctypes``), so every call raised NameError.

    Raises:
        RuntimeError: if ``bafe_optimize`` returns a non-zero code.
    """
    out = BafeGraph()
    err = ctypes.create_string_buffer(256)
    rc = _lib.bafe_optimize(
        ctypes.byref(graph),
        ctypes.byref(out),
        err,
        ctypes.c_size_t(len(err)),
    )
    if rc != 0:
        raise RuntimeError(f"optimize failed: {err.value.decode()}")
    return out
623
+
624
+
625
+ # ---------------------------------------------------------------------------
626
+ # Phase 3 (issue #6): autotune API
627
+ # ---------------------------------------------------------------------------
628
+
629
def configure_autotune(refit_threshold: int = 20,
                       invalidation_drift: float = 0.25,
                       warmup_calls: int = 2,
                       timing_iters: int = 5):
    """Configure the global autotune settings.

    Starts from ``bafe_autotune_config_default()``, sets ``enabled`` to True,
    overwrites the four fields below and hands the result to
    ``bafe_autotune_configure``.

    Args:
        refit_threshold: refit the cost model after this many new samples.
        invalidation_drift: drift ratio stored on the config (0.25 = 25%).
        warmup_calls: skip timing for the first N calls.
        timing_iters: average the kernel runtime over this many invocations.
    """
    settings = _lib.bafe_autotune_config_default()
    settings.enabled = True
    overrides = (
        ("refit_threshold", int(refit_threshold)),
        ("invalidation_drift", float(invalidation_drift)),
        ("warmup_calls", int(warmup_calls)),
        ("timing_iters", int(timing_iters)),
    )
    for field_name, value in overrides:
        setattr(settings, field_name, value)
    _lib.bafe_autotune_configure(ctypes.byref(settings))
649
+
650
+
651
def autotune_stats() -> dict:
    """Get current autotune statistics.

    Returns a dict with the keys total_calls, total_compiles, total_refits,
    total_invalidations, last_refit_r_squared and log_size, each read from
    the same-named field of ``bafe_autotune_get_stats()``.
    """
    snapshot = _lib.bafe_autotune_get_stats()
    field_names = (
        "total_calls",
        "total_compiles",
        "total_refits",
        "total_invalidations",
        "last_refit_r_squared",
        "log_size",
    )
    return {field: getattr(snapshot, field) for field in field_names}
667
+
668
+
669
def autotune_refit() -> int:
    """Manually trigger a cost model refit.

    Returns the status of ``bafe_profiling_refit``: 0 on success, non-zero
    if there are not enough samples.
    """
    status = _lib.bafe_profiling_refit()
    return status
675
+
676
+
677
def autotune_model() -> dict:
    """Get the current learned cost model.

    Returns a dict with:
      weights (list of the first 8 entries), bias, r_squared, n_samples,
      valid (coerced to bool)
    """
    model = _lib.bafe_profiling_get_model().contents
    weights = []
    for idx in range(8):
        weights.append(model.weights[idx])
    summary = {"weights": weights}
    for field in ("bias", "r_squared", "n_samples"):
        summary[field] = getattr(model, field)
    summary["valid"] = bool(model.valid)
    return summary
691
+
692
+
693
def autotune_dump_log(path: str) -> int:
    """Dump the profiling log to a JSONL file at ``path``.

    The path is UTF-8 encoded before being handed to
    ``bafe_profiling_dump_jsonl``. Returns number of records.
    """
    encoded_path = path.encode("utf-8")
    return _lib.bafe_profiling_dump_jsonl(encoded_path)
696
+
697
+
698
def autotune_reset():
    """Reset all profiling state via ``bafe_profiling_reset``; returns None."""
    _lib.bafe_profiling_reset()
    return None
701
+
702
+
703
def calibrate():
    """Return the result of ``bafe_cost_model_calibrated_default()``.

    The original documentation describes the result as a BafeCostModel whose
    per-node weights are adjusted from the learned model, falling back to
    the static default cost model when no refit has happened yet.

    NOTE(review): the original also states the calibrated model is picked up
    automatically by bafe_optimize; that happens (if at all) on the native
    side and is not visible from this module - confirm.
    """
    calibrated = _lib.bafe_cost_model_calibrated_default()
    return calibrated
716
+
717
+
718
def calibrated_cost_model() -> dict:
    """Inspect the calibrated cost model (for debugging).

    Returns a dict with the calibrated weights:
      alpha_flops, beta_bytes, gamma_intermediate, delta_fuse,
      epsilon_layout_conv, zeta_layout_fuse, eta_contiguous
    """
    model = _lib.bafe_cost_model_calibrated_default()
    weight_names = (
        "alpha_flops",
        "beta_bytes",
        "gamma_intermediate",
        "delta_fuse",
        "epsilon_layout_conv",
        "zeta_layout_fuse",
        "eta_contiguous",
    )
    return {weight: getattr(model, weight) for weight in weight_names}
735
+
736
+
737
+ # ---------------------------------------------------------------------------
738
+ # Phase 3 (issue #7): Cross-kernel fusion
739
+ # ---------------------------------------------------------------------------
740
+
741
class FusedFunction:
    """A fused kernel combining two jitted functions.

    When you call `h = bafe.fuse(f, g)`, BAFE concatenates the two
    optimized graphs (f's output feeds g's first input) and compiles
    a single kernel.

    The fused kernel takes f's inputs followed by g's inputs[1:].
    """

    def __init__(self, func_a: "JittedFunction", func_b: "JittedFunction"):
        # Copy only descriptive metadata from the producer. With the default
        # `updated=("__dict__",)`, update_wrapper merges func_a.__dict__ into
        # this instance; the original called it last, which replaced `_fn`
        # with func_a's raw function and made `_compiled` the very same dict
        # as func_a's cache (so a one-input consumer could reuse f's kernel).
        functools.update_wrapper(self, func_a, updated=())
        self._func_a = func_a
        self._func_b = func_b
        self._compiled = {}  # key: (shapes, dtypes, layouts) -> compiled tuple
        # For fuse chaining: expose a `_fn`-like object whose `__code__`
        # reports the combined argument list (f's args + g's args[1:]).
        self._fn = self._make_fn_shim(func_a._fn, func_b._fn)

    @staticmethod
    def _make_fn_shim(a_fn, b_fn):
        """Build an object with a ``__code__`` carrying co_argcount/co_varnames."""
        a_code, b_code = a_fn.__code__, b_fn.__code__
        code = type("_Code", (), {
            "co_argcount": a_code.co_argcount + b_code.co_argcount - 1,
            "co_varnames": tuple(a_code.co_varnames[:a_code.co_argcount])
            + tuple(b_code.co_varnames[1:b_code.co_argcount]),
        })()

        class _FnShim:
            def __init__(self, shim_code):
                self.__code__ = shim_code

        return _FnShim(code)

    @staticmethod
    def _layout_of(a: np.ndarray) -> str:
        """Return "col" for a strictly Fortran-ordered array of rank >= 2, else "row"."""
        if a.ndim >= 2 and a.flags["F_CONTIGUOUS"] and not a.flags["C_CONTIGUOUS"]:
            return "col"
        return "row"

    @staticmethod
    def _as_kernel_input(a: np.ndarray, dt) -> np.ndarray:
        """Return ``a`` as a dense array of dtype ``dt`` in the layout it was keyed under.

        The kernel only sees a raw pointer, so strided views (tagged "row")
        are packed into C order; already-dense arrays are not copied.
        """
        if a.ndim >= 2 and a.flags["F_CONTIGUOUS"] and not a.flags["C_CONTIGUOUS"]:
            return np.asfortranarray(a, dtype=dt)
        return np.ascontiguousarray(a, dtype=dt)

    def __call__(self, *args: np.ndarray) -> np.ndarray:
        """Run the fused kernel on f's inputs followed by g's inputs[1:].

        Raises:
            TypeError: if no arguments are given, one is not an ndarray, or
                the argument count does not match (raised while compiling).
        """
        if not args:
            raise TypeError("fused function requires at least one input")
        for a in args:
            if not isinstance(a, np.ndarray):
                raise TypeError(f"expected numpy array, got {type(a)}")

        key = tuple((a.shape, str(a.dtype), self._layout_of(a)) for a in args)
        if key not in self._compiled:
            self._compiled[key] = self._compile(args)

        fn_ptr, sig, in_dtypes, out_shape, out_dtype, _opt_graph, _hash, _pred = self._compiled[key]

        out = np.zeros(out_shape, dtype=out_dtype)

        # Inputs first, output pointer last. data_as() keeps a reference to
        # its array, so packed temporaries stay alive through the call.
        buffers = [self._as_kernel_input(a, dt) for a, dt in zip(args, in_dtypes)]
        buffers.append(out)
        c_args = [buf.ctypes.data_as(ctypes.c_void_p) for buf in buffers]

        kernel = sig(fn_ptr)
        kernel(*c_args)
        return out

    def _compile(self, args):
        """Compile the fused function and return the 8-tuple matching
        JittedFunction._compile (for fuse chaining).

        Raises:
            TypeError: if ``len(args)`` is not n_f_inputs + n_g_inputs - 1.
            RuntimeError: if graph concatenation, optimization or JIT
                compilation fails.
        """
        # The split point comes from the two functions' signatures.
        n_f_inputs = self._func_a._fn.__code__.co_argcount
        n_g_inputs = self._func_b._fn.__code__.co_argcount
        n_total_inputs = n_f_inputs + n_g_inputs - 1

        if len(args) != n_total_inputs:
            raise TypeError(
                f"fused function expects {n_f_inputs + n_g_inputs - 1} args "
                f"({n_f_inputs} for f + {n_g_inputs - 1} for g, since g's "
                f"first input is f's output), got {len(args)}"
            )

        f_args = args[:n_f_inputs]
        g_args = args[n_f_inputs:]

        # Compile f on its own to obtain its optimized graph and output spec.
        f_result = self._func_a._compile(f_args)
        f_out_shape, f_out_dt, f_opt_graph = f_result[3], f_result[4], f_result[5]

        # Compile g with a dummy first input shaped like f's output.
        dummy_f_out = np.zeros(f_out_shape, dtype=f_out_dt)
        g_result = self._func_b._compile((dummy_f_out,) + g_args)
        g_opt_graph = g_result[5]

        # Concatenate the two optimized graphs via the C API.
        fused_graph = BafeGraph()
        err_buf = ctypes.create_string_buffer(256)
        rc = _lib.bafe_fuse_concat(
            ctypes.byref(f_opt_graph),
            ctypes.byref(g_opt_graph),
            ctypes.byref(fused_graph),
            err_buf,
            ctypes.c_size_t(len(err_buf)),
        )
        if rc != 0:
            raise RuntimeError(
                f"bafe_fuse_concat failed (code {rc}): {err_buf.value.decode()}"
            )

        # Optimize + JIT compile the fused graph.
        optimized = BafeGraph()
        rc = _lib.bafe_optimize(
            ctypes.byref(fused_graph),
            ctypes.byref(optimized),
            err_buf,
            ctypes.c_size_t(len(err_buf)),
        )
        if rc != 0:
            raise RuntimeError(
                f"bafe_optimize failed for fused graph (code {rc}): {err_buf.value.decode()}"
            )

        fn_ptr = _lib.bafe_jit_get_or_compile(
            ctypes.byref(optimized),
            err_buf,
            ctypes.c_size_t(len(err_buf)),
        )
        if not fn_ptr:
            raise RuntimeError(
                f"JIT compile failed for fused graph: {err_buf.value.decode()}"
            )

        # Every kernel argument (inputs, then the output) is passed as void*.
        sig = ctypes.CFUNCTYPE(None, *([ctypes.c_void_p] * (n_total_inputs + 1)))
        in_dtypes = [a.dtype for a in args]

        out_node = optimized.nodes[optimized.outputs[0]]
        out_shape = tuple(out_node.shape.dims[i] for i in range(out_node.shape.rank))
        out_dt = _BAFE_TO_NP[out_node.dtype]

        # Byte copy of the optimized graph, kept for feature extraction.
        opt_graph_copy = BafeGraph()
        ctypes.memmove(ctypes.byref(opt_graph_copy), ctypes.byref(optimized),
                       ctypes.sizeof(BafeGraph))

        graph_hash_buf = ctypes.create_string_buffer(65)
        _lib.bafe_jit_hash_graph(ctypes.byref(optimized), graph_hash_buf, ctypes.c_size_t(65))
        graph_hash = graph_hash_buf.value.decode("utf-8")

        cm = _lib.bafe_cost_model_calibrated_default()
        predicted = _lib.bafe_cost_graph(ctypes.byref(cm), ctypes.byref(optimized))

        return (fn_ptr, sig, in_dtypes, out_shape, out_dt,
                opt_graph_copy, graph_hash, predicted)
899
+
900
+
901
def fuse(func_a, func_b):
    """Combine a producer and a consumer into one fused kernel.

    With ``h = bafe.fuse(f, g)``, the call ``h(a, b, c)`` is equivalent to
    ``g(f(a, b), c)`` but is compiled as a single kernel, so the
    intermediate tensor (f's output) is never materialized.

    Args:
        func_a: producer; a @bafe.jit-decorated function or FusedFunction.
        func_b: consumer; a @bafe.jit-decorated function or FusedFunction
            whose first argument receives the producer's output.

    Returns:
        A FusedFunction that takes f's inputs + g's inputs[1:].

    Raises:
        TypeError: if either operand is neither a JittedFunction nor a
            FusedFunction.
    """
    # FusedFunction is accepted as an operand so fusions can be chained.
    fusable_types = (JittedFunction, FusedFunction)
    for operand in (func_a, func_b):
        if not isinstance(operand, fusable_types):
            raise TypeError("fuse() requires @bafe.jit-decorated or fused functions")
    return FusedFunction(func_a, func_b)
922
+
923
+
924
# Names exported by ``from bafe import *``.
__all__ = [
    # core entry points
    "Tensor",
    "jit",
    "optimize",
    "make_search_budget",
    "fuse",
    # ops
    "input",
    "matmul",
    "add",
    "sub",
    "mul",
    "relu",
    "sigmoid",
    "tanh",
    "bias_add",
    "transpose",
    "reduce_sum",
    "reduce_max",
    "reshape",
    "broadcast_to",
    # graph inspection
    "graph_summary",
    # autotune
    "configure_autotune",
    "autotune_stats",
    "autotune_refit",
    "autotune_model",
    "autotune_dump_log",
    "autotune_reset",
    # cost-model calibration
    "calibrate",
    "calibrated_cost_model",
    # pruning controller helpers
    "pruning_regime_name",
    "pruning_beam_width",
    "pruning_iters",
    # package metadata
    "__version__",
]
935
+
936
+
937
+ # ---------------------------------------------------------------------------
938
+ # Phase 3 (issue #4): Pruning controller helpers
939
+ # ---------------------------------------------------------------------------
940
+
941
# Display names for the integer regime codes looked up by
# pruning_regime_name(); a name's position in the tuple is its code
# (0 -> "greedy", 1 -> "light", 2 -> "beam", 3 -> "deep").
_PRUNING_REGIMES = dict(enumerate(("greedy", "light", "beam", "deep")))
947
+
948
+
949
def pruning_regime_name(time_budget_ms: int) -> str:
    """Return the name of the pruning regime selected for a time budget.

    The result is one of "greedy" (<=1ms), "light" (<=10ms),
    "beam" (<=100ms) or "deep" (>100ms). A budget of None, 0 or any
    negative value means "no limit" and yields "deep" without consulting
    the native library. If the library reports a regime code that has no
    entry in ``_PRUNING_REGIMES``, "unknown" is returned.
    """
    no_limit = time_budget_ms is None or time_budget_ms <= 0
    if no_limit:
        return "deep"
    regime_code = _lib.bafe_pruning_regime_from_budget(int(time_budget_ms))
    return _PRUNING_REGIMES.get(regime_code, "unknown")
959
+
960
+
961
def pruning_beam_width(time_budget_ms: int) -> int:
    """Return the beam width of the regime chosen for a time budget.

    A budget of None, 0 or any negative value is forwarded to the native
    library as 0; any other value is truncated with ``int()`` first.
    """
    no_limit = time_budget_ms is None or time_budget_ms <= 0
    budget = 0 if no_limit else int(time_budget_ms)
    regime_code = _lib.bafe_pruning_regime_from_budget(budget)
    return _lib.bafe_pruning_beam_width_for_regime(regime_code)
967
+
968
+
969
def pruning_iters(time_budget_ms: int) -> int:
    """Return the stochastic iteration count of a time budget's regime.

    A budget of None, 0 or any negative value is forwarded to the native
    library as 0; any other value is truncated with ``int()`` first.
    """
    no_limit = time_budget_ms is None or time_budget_ms <= 0
    budget = 0 if no_limit else int(time_budget_ms)
    regime_code = _lib.bafe_pruning_regime_from_budget(budget)
    return _lib.bafe_pruning_iters_for_regime(regime_code)