bafe-engine 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- bafe/__init__.py +974 -0
- bafe/_binding.py +528 -0
- bafe/libbafe.so +0 -0
- bafe_engine-0.1.0.dist-info/METADATA +343 -0
- bafe_engine-0.1.0.dist-info/RECORD +8 -0
- bafe_engine-0.1.0.dist-info/WHEEL +5 -0
- bafe_engine-0.1.0.dist-info/licenses/LICENSE +201 -0
- bafe_engine-0.1.0.dist-info/top_level.txt +1 -0
bafe/__init__.py
ADDED
|
@@ -0,0 +1,974 @@
|
|
|
1
|
+
"""BAFE - Basic Algebra Fusion Engine.
|
|
2
|
+
|
|
3
|
+
Public API:
|
|
4
|
+
import bafe
|
|
5
|
+
|
|
6
|
+
@bafe.jit
|
|
7
|
+
def f(A, B, C):
|
|
8
|
+
return bafe.relu(bafe.matmul(A, B) + C)
|
|
9
|
+
|
|
10
|
+
The decorator traces the function (which builds an IR graph via the op
|
|
11
|
+
functions), runs the BAFE optimization pipeline, JIT-compiles the result,
|
|
12
|
+
and returns a callable that invokes the compiled kernel directly.
|
|
13
|
+
"""
|
|
14
|
+
from __future__ import annotations
|
|
15
|
+
|
|
16
|
+
import ctypes
|
|
17
|
+
import functools
|
|
18
|
+
import os
|
|
19
|
+
from typing import Callable, List, Tuple, Any
|
|
20
|
+
|
|
21
|
+
import numpy as np
|
|
22
|
+
|
|
23
|
+
from bafe._binding import (
|
|
24
|
+
_lib, BafeGraph, BafeShape, BafeOpAttrs, BafeNode,
|
|
25
|
+
make_shape, make_attrs, graph_summary,
|
|
26
|
+
BAFE_MAX_NODES, BAFE_MAX_CHILDREN,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
__version__ = "0.1.0"

# ---------------------------------------------------------------------------
# Dtype mapping
# ---------------------------------------------------------------------------

# numpy dtype -> bafe dtype enum value. Enumeration order defines the codes:
#   0 = BAFE_DTYPE_F32, 1 = BAFE_DTYPE_F64, 2 = BAFE_DTYPE_I32, 3 = BAFE_DTYPE_I64
_NP_TO_BAFE = {
    np.dtype(type_name): code
    for code, type_name in enumerate(("float32", "float64", "int32", "int64"))
}

# Reverse lookup: bafe dtype enum value -> numpy dtype.
_BAFE_TO_NP = {code: np_dtype for np_dtype, code in _NP_TO_BAFE.items()}
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
# ---------------------------------------------------------------------------
|
|
46
|
+
# Tracing context
|
|
47
|
+
# ---------------------------------------------------------------------------
|
|
48
|
+
|
|
49
|
+
class _TraceContext:
    """The IR graph being built by one trace, plus the names of its inputs."""

    def __init__(self):
        self.graph = BafeGraph()
        _lib.bafe_graph_init(ctypes.byref(self.graph))
        # Declared for (name, Tensor) pairs; nothing in this class fills it.
        self.inputs: List[Tuple[str, "Tensor"]] = []
        # Names passed to add_input, in call order.
        self.input_names: List[str] = []

    def add_input(self, name: str, shape: Tuple[int, ...], dtype: int,
                  layout: int = 0) -> int:
        """Append an input node to the graph and return its node id.

        Args:
            name: input name, passed to the C API as UTF-8.
            shape: input dimensions.
            dtype: bafe dtype enum value.
            layout: bafe_layout enum value (0=row, 1=col).

        Raises:
            RuntimeError: if the C API returns a negative node id.
        """
        shape_struct = make_shape(shape)
        graph_ref = ctypes.byref(self.graph)
        encoded_name = name.encode("utf-8")
        if layout != 0:
            nid = _lib.bafe_graph_add_input_with_layout(
                graph_ref,
                encoded_name,
                ctypes.byref(shape_struct),
                ctypes.c_int(dtype),
                ctypes.c_int(layout),
            )
        else:
            # Row-major inputs keep using the plain entry point for backward
            # compatibility.
            nid = _lib.bafe_graph_add_input(
                graph_ref,
                encoded_name,
                ctypes.byref(shape_struct),
                ctypes.c_int(dtype),
            )
        if nid < 0:
            raise RuntimeError(f"failed to add input {name}")
        self.input_names.append(name)
        return nid

    def add_op(self, op_name: str, children: List[int], **attrs) -> int:
        """Append an *op_name* node with the given child node ids and attrs.

        Args:
            op_name: op name, passed to the C API as UTF-8.
            children: node ids of the operands (may be empty).
            **attrs: forwarded to make_attrs; omitted entirely when empty.

        Returns:
            The new node id.

        Raises:
            RuntimeError: if the C API returns a negative node id.
        """
        child_count = len(children)
        child_array = (ctypes.c_int32 * child_count)(*children) if child_count else None
        attr_struct = make_attrs(**attrs) if attrs else None
        attr_ref = ctypes.byref(attr_struct) if attr_struct else None
        nid = _lib.bafe_graph_add(
            ctypes.byref(self.graph),
            op_name.encode("utf-8"),
            child_array,
            ctypes.c_int(child_count),
            attr_ref,
        )
        if nid < 0:
            raise RuntimeError(f"failed to add op {op_name} (children={children})")
        return nid

    def set_output(self, nid: int) -> None:
        """Mark node *nid* as the graph output."""
        graph_ref = ctypes.byref(self.graph)
        _lib.bafe_graph_set_output(graph_ref, ctypes.c_int32(nid))
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
# Trace being recorded by JittedFunction._compile; None when no @bafe.jit
# function is being traced.
_TRACE: _TraceContext | None = None


def _ensure_trace() -> _TraceContext:
    """Return the active trace context.

    Raises:
        RuntimeError: if no @bafe.jit trace is in progress.
    """
    # Only reads the module global, so no `global` statement is needed.
    active = _TRACE
    if active is not None:
        return active
    raise RuntimeError(
        "no active trace - bafe ops can only be called inside a "
        "function decorated with @bafe.jit"
    )
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
# ---------------------------------------------------------------------------
|
|
115
|
+
# Tensor handle (symbolic during tracing)
|
|
116
|
+
# ---------------------------------------------------------------------------
|
|
117
|
+
|
|
118
|
+
class Tensor:
    """Symbolic handle to a node in the IR graph being traced.

    Instances are created by the op functions while an @bafe.jit function is
    traced. The +, -, * and @ operators are overloaded so that user code can
    be written as ordinary math expressions, each of which records a node.
    """

    __slots__ = ("node_id", "shape", "dtype", "_name")

    def __init__(self, node_id: int, shape: Tuple[int, ...], dtype: int, name: str | None = None):
        self.node_id, self.shape, self.dtype, self._name = node_id, shape, dtype, name

    def __add__(self, other: "Tensor") -> "Tensor":
        return _binop("add", self, other)

    def __sub__(self, other: "Tensor") -> "Tensor":
        return _binop("sub", self, other)

    def __mul__(self, other: "Tensor") -> "Tensor":
        return _binop("mul", self, other)

    def __matmul__(self, other: "Tensor") -> "Tensor":
        """Record a rank-2 ``matmul`` node; output shape is (rows(self), cols(other))."""
        trace = _ensure_trace()
        if len(self.shape) != 2 or len(other.shape) != 2:
            raise ValueError("matmul requires rank-2 tensors")
        rows, inner = self.shape
        other_inner, cols = other.shape
        if inner != other_inner:
            raise ValueError(
                f"matmul shape mismatch: {self.shape} @ {other.shape}"
            )
        node = trace.add_op("matmul", [self.node_id, other.node_id])
        return Tensor(node, (rows, cols), self.dtype)

    @property
    def name(self) -> str | None:
        """Input name given at construction, or None for op results."""
        return self._name

    def __repr__(self) -> str:
        return f"Tensor(shape={self.shape}, dtype={_BAFE_TO_NP[self.dtype]})"
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
def _binop(op: str, a: Tensor, b: Tensor) -> Tensor:
    """Record the elementwise binary op *op* applied to *a* and *b*.

    The result shape comes from _broadcast_shapes; the result dtype is a's.
    """
    trace = _ensure_trace()
    result_shape = _broadcast_shapes(a.shape, b.shape)
    node = trace.add_op(op, [a.node_id, b.node_id])
    return Tensor(node, result_shape, a.dtype)
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def _broadcast_shapes(a, b):
    """Return the NumPy-style broadcast of shapes *a* and *b*.

    Shapes are right-aligned and the shorter one is left-padded with 1s.
    Each aligned pair of sizes must be equal, or one of them must be 1. The
    result takes the non-1 size, so a 0-length axis broadcast against 1
    stays 0.

    Raises:
        ValueError: if the shapes are not broadcast-compatible. The previous
            element-wise ``max`` silently accepted e.g. (3,) vs (4,).
    """
    a, b = tuple(a), tuple(b)
    rank = max(len(a), len(b))
    a_pad = (1,) * (rank - len(a)) + a
    b_pad = (1,) * (rank - len(b)) + b
    out = []
    for x, y in zip(a_pad, b_pad):
        if x == y or y == 1:
            out.append(x)
        elif x == 1:
            out.append(y)
        else:
            raise ValueError(f"shapes {a} and {b} are not broadcast-compatible")
    return tuple(out)
|
|
178
|
+
|
|
179
|
+
|
|
180
|
+
# ---------------------------------------------------------------------------
|
|
181
|
+
# Op functions (used inside @bafe.jit functions)
|
|
182
|
+
# ---------------------------------------------------------------------------
|
|
183
|
+
|
|
184
|
+
def input(shape: Tuple[int, ...], dtype: str | np.dtype = "float32",
          name: str = "x", layout: str = "row") -> Tensor:
    """Declare an input tensor in the active trace.

    Usually called automatically by @bafe.jit based on the function's
    arguments, but can also be called explicitly. Note that this name
    shadows the builtin ``input()`` inside this module.

    Args:
        shape: input dimensions; each entry is converted with ``int()``.
        dtype: anything ``np.dtype`` accepts; must map to float32, float64,
            int32 or int64.
        name: input name recorded in the graph.
        layout: "row" (default, C order), "col" (Fortran order), "blocked"
            or "tc". The layout tag tells BAFE how the input data is stored
            in memory, which affects codegen (access patterns) and the cost
            model (conversion penalties, fusion bonuses).

    Returns:
        A Tensor bound to the new input node.

    Raises:
        RuntimeError: outside an @bafe.jit trace, or if the graph rejects
            the input.
        ValueError: for an unsupported dtype or layout.
    """
    tc = _ensure_trace()
    np_dt = np.dtype(dtype)
    if np_dt not in _NP_TO_BAFE:
        raise ValueError(f"unsupported dtype {dtype}; supported: {list(_NP_TO_BAFE)}")
    bafe_dt = _NP_TO_BAFE[np_dt]
    layout_map = {"row": 0, "col": 1, "blocked": 2, "tc": 3}
    if layout not in layout_map:
        raise ValueError(f"unsupported layout {layout!r}; supported: {list(layout_map)}")
    # Materialize the shape once. Converting it twice (once for the graph,
    # once for the Tensor) gave the Tensor an empty shape when *shape* was a
    # one-shot iterator, so the Tensor disagreed with the graph.
    dims = tuple(int(s) for s in shape)
    nid = tc.add_input(name, dims, bafe_dt, layout=layout_map[layout])
    return Tensor(nid, dims, bafe_dt, name=name)
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def matmul(a: Tensor, b: Tensor) -> Tensor:
    """Functional form of ``a @ b``."""
    product = a @ b
    return product


def add(a: Tensor, b: Tensor) -> Tensor:
    """Functional form of ``a + b``."""
    total = a + b
    return total


def sub(a: Tensor, b: Tensor) -> Tensor:
    """Functional form of ``a - b``."""
    difference = a - b
    return difference


def mul(a: Tensor, b: Tensor) -> Tensor:
    """Functional form of ``a * b`` (elementwise)."""
    product = a * b
    return product
|
|
223
|
+
|
|
224
|
+
|
|
225
|
+
def _unary_op(op: str, x: Tensor) -> Tensor:
    """Record the single-operand node *op* on *x*; shape and dtype carry over."""
    node = _ensure_trace().add_op(op, [x.node_id])
    return Tensor(node, x.shape, x.dtype)


def relu(x: Tensor) -> Tensor:
    """Record a ``relu`` node on *x*."""
    return _unary_op("relu", x)


def sigmoid(x: Tensor) -> Tensor:
    """Record a ``sigmoid`` node on *x*."""
    return _unary_op("sigmoid", x)


def tanh(x: Tensor) -> Tensor:
    """Record a ``tanh`` node on *x*."""
    return _unary_op("tanh", x)
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def bias_add(x: Tensor, bias: Tensor) -> Tensor:
    """Record a ``bias_add`` node on *x* and *bias*.

    The result takes the shape and dtype of *x*.
    """
    trace = _ensure_trace()
    operands = [x.node_id, bias.node_id]
    return Tensor(trace.add_op("bias_add", operands), x.shape, x.dtype)
|
|
247
|
+
|
|
248
|
+
|
|
249
|
+
def transpose(x: Tensor, perm: Tuple[int, ...]) -> Tensor:
    """Record a ``transpose`` node that reorders the axes of *x*.

    Args:
        x: input tensor.
        perm: a permutation of ``range(len(x.shape))``. Negative entries
            count from the end (-1 is the last axis) and are normalized
            before being stored on the node.

    Returns:
        A Tensor whose shape is ``x.shape`` reordered by *perm*.

    Raises:
        ValueError: if *perm* is not a permutation of the axes of *x*. This
            is checked before anything is added to the graph.
        RuntimeError: if called outside an @bafe.jit trace.
    """
    rank = len(x.shape)
    # Materialize *perm* once; it is needed for both the attr and the shape.
    axes = []
    for p in perm:
        p = int(p)
        if not -rank <= p < rank:
            raise ValueError(f"transpose axis {p} out of range for rank-{rank} tensor")
        axes.append(p % rank)
    if sorted(axes) != list(range(rank)):
        raise ValueError(
            f"transpose perm {tuple(axes)} is not a permutation of {rank} axes"
        )
    tc = _ensure_trace()
    nid = tc.add_op("transpose", [x.node_id], perm=axes)
    return Tensor(nid, tuple(x.shape[p] for p in axes), x.dtype)
|
|
254
|
+
|
|
255
|
+
|
|
256
|
+
def _reduce_op(op: str, x: Tensor, axes, keepdims: bool) -> Tensor:
    """Record reduction node *op* over *axes* of *x*.

    *axes* is materialized once (the old code iterated it twice, so a
    one-shot iterable produced the wrong shape). The output shape is computed,
    and therefore validated, before the node is added to the graph.
    """
    tc = _ensure_trace()
    axes = tuple(axes)
    out_shape = _reduce_shape(x.shape, axes, keepdims)
    nid = tc.add_op(op, [x.node_id], axes=list(axes), keepdims=keepdims)
    return Tensor(nid, out_shape, x.dtype)


def reduce_sum(x: Tensor, axes: Tuple[int, ...], keepdims: bool = False) -> Tensor:
    """Record a ``reduce_sum`` node over *axes*; keep reduced dims as 1 if *keepdims*."""
    return _reduce_op("reduce_sum", x, axes, keepdims)


def reduce_max(x: Tensor, axes: Tuple[int, ...], keepdims: bool = False) -> Tensor:
    """Record a ``reduce_max`` node over *axes*; keep reduced dims as 1 if *keepdims*."""
    return _reduce_op("reduce_max", x, axes, keepdims)
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
def _shape_op(op: str, x: Tensor, shape) -> Tensor:
    """Record node *op* on *x* with a target *shape* attr; the result has that shape.

    *shape* is materialized once. The old code called list(shape) and then
    tuple(shape), so a one-shot iterable produced an empty Tensor.shape.
    """
    tc = _ensure_trace()
    dims = tuple(shape)
    nid = tc.add_op(op, [x.node_id], shape=list(dims))
    return Tensor(nid, dims, x.dtype)


def reshape(x: Tensor, shape: Tuple[int, ...]) -> Tensor:
    """Record a ``reshape`` node giving *x* the dimensions *shape*."""
    return _shape_op("reshape", x, shape)


def broadcast_to(x: Tensor, shape: Tuple[int, ...]) -> Tensor:
    """Record a ``broadcast_to`` node expanding *x* to *shape*."""
    return _shape_op("broadcast_to", x, shape)
|
|
280
|
+
|
|
281
|
+
|
|
282
|
+
def _reduce_shape(in_shape, axes, keepdims):
    """Return the shape left after reducing *in_shape* over *axes*.

    Negative axes count from the end. Reduced axes are dropped, or kept with
    size 1 when *keepdims* is true.

    Raises:
        ValueError: if an axis lies outside ``[-rank, rank)``.
    """
    rank = len(in_shape)
    reduced = set()
    for axis in axes:
        # The old `axis % rank` silently wrapped out-of-range axes (axis 2
        # of a rank-2 shape became axis 0) and divided by zero for rank 0.
        if not -rank <= axis < rank:
            raise ValueError(
                f"reduction axis {axis} is out of range for rank-{rank} shape {tuple(in_shape)}"
            )
        reduced.add(axis % rank)
    out = []
    for i, d in enumerate(in_shape):
        if i not in reduced:
            out.append(d)
        elif keepdims:
            out.append(1)
    return tuple(out)
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
# ---------------------------------------------------------------------------
|
|
295
|
+
# @bafe.jit decorator
|
|
296
|
+
# ---------------------------------------------------------------------------
|
|
297
|
+
|
|
298
|
+
class JittedFunction:
    """A JIT-compiled BAFE function.

    On the first call with a new combination of input shapes, dtypes and
    memory layouts, it:
      1. Inspects the input shapes/dtypes/layouts
      2. Runs the user function under a trace
      3. Calls bafe_optimize (or bafe_optimize_with_budget) +
         bafe_jit_get_or_compile
      4. Builds a ctypes function type taking one void* per input plus one
         for the output
      5. Invokes the kernel

    Subsequent calls with the same shapes/dtypes/layouts reuse the entry in
    ``self._compiled``. That cache is cleared after an autotune refit.

    Phase 2 (issue #1): if a search budget is set, uses stochastic
    multi-pass search to discover deeper rewrites.
    """

    def __init__(self, fn: Callable, budget=None, autotune=False, time_budget_ms=None):
        self._fn = fn
        self._budget = budget  # BafeSearchBudget or None (deterministic)
        self._autotune = autotune  # Phase 3: enable autotune loop
        self._time_budget_ms = time_budget_ms  # Phase 3 (issue #4): pruning time budget
        self._compiled = {}  # key: per-input (shape, dtype, layout) -> compiled tuple
        self._call_count = {}  # key -> call number (for autotune warmup tracking)
        functools.update_wrapper(self, fn)

    @staticmethod
    def _layout_of(a: np.ndarray) -> str:
        """Return "col" for rank >= 2 arrays that are Fortran- but not C-contiguous, else "row"."""
        if a.ndim >= 2 and a.flags["F_CONTIGUOUS"] and not a.flags["C_CONTIGUOUS"]:
            return "col"
        return "row"

    @staticmethod
    def _as_kernel_array(a: np.ndarray, dtype, layout: str) -> np.ndarray:
        """Return *a* as a dense *dtype* buffer in the memory order of *layout*.

        The kernel receives only a raw data pointer, so strided views (e.g.
        ``A[:, ::2]``) must be densified. When *a* already matches, NumPy
        returns it unchanged without copying.
        """
        if layout == "col":
            return np.asfortranarray(a, dtype=dtype)
        return np.ascontiguousarray(a, dtype=dtype)

    def __call__(self, *args: np.ndarray) -> np.ndarray:
        """Run the compiled kernel on *args*, compiling first on a cache miss.

        Args:
            *args: numpy arrays, one per traced input.

        Returns:
            A freshly allocated output array.

        Raises:
            TypeError: if no inputs are given, or an input is not an ndarray.
            RuntimeError: if optimization or JIT compilation fails.
        """
        if not args:
            raise TypeError("jitted function requires at least one input")
        for a in args:
            if not isinstance(a, np.ndarray):
                raise TypeError(f"expected numpy array, got {type(a)}")

        # Phase 2: the layout is part of the key, so row-major and col-major
        # inputs produce different compiled kernels.
        layouts = [self._layout_of(a) for a in args]
        key = tuple((a.shape, str(a.dtype), lay) for a, lay in zip(args, layouts))

        if key not in self._compiled:
            self._compiled[key] = self._compile(args)

        (fn_ptr, sig, in_dtypes, out_shape, out_dtype,
         opt_graph, graph_hash, predicted_cost) = self._compiled[key]

        out = np.zeros(out_shape, dtype=out_dtype)

        # Every input must be a dense buffer in the layout it was compiled
        # for. Non-contiguous views used to be passed through as-is, so the
        # kernel read the wrong elements. `buffers` keeps any copies alive
        # until the kernel returns.
        buffers = [self._as_kernel_array(a, dt, lay)
                   for a, dt, lay in zip(args, in_dtypes, layouts)]
        c_args = [buf.ctypes.data_as(ctypes.c_void_p) for buf in buffers]
        c_args.append(out.ctypes.data_as(ctypes.c_void_p))

        kernel = sig(fn_ptr)
        if self._autotune_enabled():
            self._run_autotuned(key, kernel, c_args, opt_graph, graph_hash, predicted_cost)
        else:
            kernel(*c_args)
        return out

    def _run_autotuned(self, key, kernel, c_args, opt_graph, graph_hash, predicted_cost):
        """Phase 3 (issue #6): run *kernel*, timing and logging it after warmup.

        The first ``config.warmup_calls`` calls for *key* run untimed.
        Afterwards the kernel runs ``config.timing_iters`` times (at least
        once). The mean time in ms is logged with the graph features and the
        predicted cost. Every ``config.refit_threshold`` log entries, the
        cost model is refit and both the Python-level and the C-level JIT
        caches are invalidated, so later calls re-optimize.
        """
        import time

        config = _lib.bafe_autotune_get_config()
        # Call numbers are tracked per key in Python because this loop is
        # driven from Python.
        call_num = self._call_count.get(key, 0) + 1
        self._call_count[key] = call_num

        if call_num <= config.warmup_calls:
            kernel(*c_args)
            return

        # NOTE(review): running the kernel several times on the same output
        # buffer assumes it overwrites `out` rather than accumulating into it
        # (the output is zero-initialised by the caller) — confirm.
        iters = config.timing_iters if config.timing_iters > 0 else 1
        t0 = time.perf_counter()
        for _ in range(iters):
            kernel(*c_args)
        observed_ms = (time.perf_counter() - t0) * 1000.0 / iters

        features = (ctypes.c_double * 8)()
        _lib.bafe_profiling_extract_features(ctypes.byref(opt_graph), features)
        _lib.bafe_profiling_add(
            graph_hash.encode("utf-8") if isinstance(graph_hash, str) else graph_hash,
            features,
            ctypes.c_double(predicted_cost),
            ctypes.c_double(observed_ms),
            ctypes.c_int(0),
        )

        log = _lib.bafe_profiling_get_log().contents
        # A non-positive threshold used to raise ZeroDivisionError here; it is
        # now treated as "never auto-refit".
        threshold = config.refit_threshold
        if threshold > 0 and log.n > 0 and log.n % threshold == 0:
            _lib.bafe_profiling_refit()
            # Phase 3 (issue #5): drop compiled kernels so the next call
            # re-optimizes with the calibrated cost model; the C-level JIT
            # cache is invalidated too so dlopen'd handles are released.
            self._compiled.clear()
            _lib.bafe_jit_invalidate_memory_cache()

    def _autotune_enabled(self) -> bool:
        """Check if autotune is enabled for this function."""
        return getattr(self, "_autotune", False)

    def _compile(self, args: Tuple[np.ndarray, ...]):
        """Trace, optimize and JIT-compile the wrapped function for *args*.

        Returns:
            (fn_ptr, sig, in_dtypes, out_shape, out_np_dtype, opt_graph_copy,
             graph_hash, predicted_cost)

        Raises:
            TypeError: for an unsupported input dtype or a non-Tensor result.
            RuntimeError: if optimization or JIT compilation fails.
        """
        global _TRACE
        prev_trace = _TRACE
        _TRACE = _TraceContext()
        try:
            # Name inputs after the function's positional parameters. Only the
            # first co_argcount entries of co_varnames are parameters; the
            # rest are local variables and must not become input names.
            code = getattr(self._fn, "__code__", None)
            param_names = code.co_varnames[:code.co_argcount] if code is not None else ()

            in_tensors = []
            for i, a in enumerate(args):
                np_dt = np.dtype(a.dtype)
                if np_dt not in _NP_TO_BAFE:
                    raise TypeError(f"unsupported dtype {np_dt}")
                name = param_names[i] if i < len(param_names) else f"in{i}"
                # Phase 2: Fortran-ordered inputs are tagged "col", others "row".
                in_tensors.append(input(a.shape, np_dt, name=name, layout=self._layout_of(a)))

            result = self._fn(*in_tensors)
            if not isinstance(result, Tensor):
                raise TypeError(
                    f"jitted function must return a Tensor, got {type(result)}"
                )
            _TRACE.set_output(result.node_id)
            in_graph = _TRACE.graph
        finally:
            _TRACE = prev_trace

        # Optimize. Phase 2 (issue #1): a budget enables stochastic multi-pass
        # search. Phase 3 (issue #4): time_budget_ms selects the pruning regime.
        optimized = BafeGraph()
        err_buf = ctypes.create_string_buffer(256)
        if self._budget is not None or self._time_budget_ms is not None:
            budget = self._budget if self._budget is not None else _lib.bafe_search_budget_default()
            if self._time_budget_ms is not None:
                budget.time_budget_ms = int(self._time_budget_ms)
            rc = _lib.bafe_optimize_with_budget(
                ctypes.byref(in_graph),
                ctypes.byref(optimized),
                ctypes.byref(budget),
                err_buf,
                ctypes.c_size_t(len(err_buf)),
            )
        else:
            rc = _lib.bafe_optimize(
                ctypes.byref(in_graph),
                ctypes.byref(optimized),
                err_buf,
                ctypes.c_size_t(len(err_buf)),
            )
        if rc != 0:
            raise RuntimeError(
                f"bafe_optimize failed (code {rc}): {err_buf.value.decode(errors='replace')}"
            )

        fn_ptr = _lib.bafe_jit_get_or_compile(
            ctypes.byref(optimized),
            err_buf,
            ctypes.c_size_t(len(err_buf)),
        )
        if not fn_ptr:
            raise RuntimeError(
                f"bafe_jit_get_or_compile failed: {err_buf.value.decode(errors='replace')}"
            )

        # Kernel signature: void name(in1*, in2*, ..., out*), all as void*.
        in_dtypes = [a.dtype for a in args]
        sig = ctypes.CFUNCTYPE(None, *([ctypes.c_void_p] * (len(args) + 1)))

        out_node = optimized.nodes[optimized.outputs[0]]
        out_shape = tuple(out_node.shape.dims[i] for i in range(out_node.shape.rank))
        out_np_dt = _BAFE_TO_NP[out_node.dtype]

        # Phase 3 (issue #6): graph hash + predicted cost for autotune logging.
        graph_hash_buf = ctypes.create_string_buffer(65)
        _lib.bafe_jit_hash_graph(ctypes.byref(optimized), graph_hash_buf, ctypes.c_size_t(65))
        graph_hash = graph_hash_buf.value.decode("utf-8")

        cm = _lib.bafe_cost_model_default()
        predicted_cost = _lib.bafe_cost_graph(ctypes.byref(cm), ctypes.byref(optimized))

        # Byte copy of the optimized graph kept with the cache entry for
        # feature extraction during autotune.
        opt_graph_copy = BafeGraph()
        ctypes.memmove(ctypes.byref(opt_graph_copy), ctypes.byref(optimized), ctypes.sizeof(BafeGraph))

        return (fn_ptr, sig, in_dtypes, out_shape, out_np_dt,
                opt_graph_copy, graph_hash, predicted_cost)
|
|
533
|
+
|
|
534
|
+
|
|
535
|
+
def jit(fn: Callable = None, *, budget=None, iters: int = None,
        temperature: float = None, seed: int = None, autotune: bool = False,
        time_budget_ms: int = None):
    """Decorator: trace + optimize + JIT-compile a tensor function.

    Phase 2 (issue #1): optional stochastic search parameters.
    Phase 3 (issue #4): optional time-budget pruning.
    Phase 3 (issue #6): optional autotune loop.

    Can be used bare (``@bafe.jit``), with keyword options
    (``@bafe.jit(...)``) or called directly (``bafe.jit(f, ...)``). The
    search options apply in every form; previously they were silently
    dropped when *fn* was passed positionally.

    Args:
        budget: a BafeSearchBudget for full control. Note: its
            time_budget_ms field is overwritten when *time_budget_ms* is given.
        iters: number of stochastic passes (default 4 if budget mode on).
        temperature: 0.0 = greedy, high = explore randomly (default 1.0).
        seed: PRNG seed for reproducibility (default 0xBAFE5EED).
        autotune: if True, enable the auto-tuning loop.
        time_budget_ms: wall-clock limit for the optimization search.
            Controls the pruning regime:
              <= 1 ms:   greedy (Level A+B only, beam=1)
              <= 10 ms:  light  (A+B+C, beam=4)
              <= 100 ms: beam   (A+B+C+D, beam=16)
              > 100 ms:  deep   (all tiers, beam=64)
            0 or None means no limit (uses stochastic search).

    Returns:
        A JittedFunction when *fn* is given, otherwise a decorator.

    Examples:
        @bafe.jit
        def f(A, B): ...   # deterministic (default)

        @bafe.jit(time_budget_ms=100)
        def f(A, B): ...   # 100ms pruning budget

        @bafe.jit(time_budget_ms=1000, autotune=True)
        def f(A, B): ...   # 1s budget + autotune
    """
    def deco(func):
        b = budget
        if b is None and (iters is not None or temperature is not None or seed is not None):
            b = make_search_budget(
                max_iters=iters if iters is not None else 4,
                temperature=temperature if temperature is not None else 1.0,
                seed=seed if seed is not None else 0xBAFE5EED,
            )
        if b is not None and time_budget_ms is not None:
            b.time_budget_ms = int(time_budget_ms)
        return JittedFunction(func, budget=b, autotune=autotune, time_budget_ms=time_budget_ms)

    if fn is not None:
        return deco(fn)
    return deco
|
|
583
|
+
|
|
584
|
+
|
|
585
|
+
def make_search_budget(max_iters: int = 4, max_nodes: int = 256,
                       max_rewrites: int = 64, time_budget_ms: int = 0,
                       temperature: float = 1.0, seed: int = 0xBAFE5EED,
                       enable_multi_pass: bool = True):
    """Create a BafeSearchBudget for ``@bafe.jit(budget=...)``.

    Fields of the stochastic search layer:
      - max_iters: number of stochastic passes (each pass re-applies rules
        to newly-created nodes, discovering deeper rewrites)
      - max_nodes: hard cap on graph size during search
      - max_rewrites: cap on total rewrites materialized
      - time_budget_ms: wall-clock limit (0 = no limit)
      - temperature: 0.0 = greedy (only cost-reducing rewrites),
        high = explore randomly
      - seed: PRNG seed for reproducibility, masked to 32 bits
      - enable_multi_pass: if False, degrades to deterministic single-pass
    """
    from bafe._binding import BafeSearchBudget

    budget = BafeSearchBudget()
    settings = (
        ("max_iters", int(max_iters)),
        ("max_nodes", int(max_nodes)),
        ("max_rewrites", int(max_rewrites)),
        ("time_budget_ms", int(time_budget_ms)),
        ("temperature", float(temperature)),
        ("seed", int(seed) & 0xFFFFFFFF),
        ("enable_multi_pass", bool(enable_multi_pass)),
    )
    for field_name, value in settings:
        setattr(budget, field_name, value)
    return budget
|
|
612
|
+
|
|
613
|
+
|
|
614
|
+
# expose the optimize function for low-level use
|
|
615
|
+
def optimize(graph: BafeGraph) -> BafeGraph:
    """Run the BAFE optimization pipeline on *graph*.

    Args:
        graph: the input graph (not modified by this wrapper).

    Returns:
        A new BafeGraph holding the optimized result.

    Raises:
        RuntimeError: if bafe_optimize returns a non-zero code.
    """
    out = BafeGraph()
    err = ctypes.create_string_buffer(256)
    # `byref` / `c_size_t` must be qualified: this module does `import
    # ctypes`, so the bare names raised NameError on every call.
    rc = _lib.bafe_optimize(
        ctypes.byref(graph), ctypes.byref(out), err, ctypes.c_size_t(len(err))
    )
    if rc != 0:
        raise RuntimeError(f"optimize failed: {err.value.decode(errors='replace')}")
    return out
|
|
623
|
+
|
|
624
|
+
|
|
625
|
+
# ---------------------------------------------------------------------------
|
|
626
|
+
# Phase 3 (issue #6): autotune API
|
|
627
|
+
# ---------------------------------------------------------------------------
|
|
628
|
+
|
|
629
|
+
def configure_autotune(refit_threshold: int = 20,
                       invalidation_drift: float = 0.25,
                       warmup_calls: int = 2,
                       timing_iters: int = 5):
    """Enable autotune and push new global settings to the C side.

    Starts from bafe_autotune_config_default(), sets ``enabled`` to True and
    applies the arguments below.

    Args:
        refit_threshold: refit the cost model after this many new samples.
        invalidation_drift: invalidate cached kernels when predictions
            drift by more than this ratio (0.25 = 25%).
        warmup_calls: skip timing for the first N calls (cache effects).
        timing_iters: average the kernel runtime over this many invocations.
    """
    cfg = _lib.bafe_autotune_config_default()
    overrides = {
        "enabled": True,
        "refit_threshold": int(refit_threshold),
        "invalidation_drift": float(invalidation_drift),
        "warmup_calls": int(warmup_calls),
        "timing_iters": int(timing_iters),
    }
    for attr, value in overrides.items():
        setattr(cfg, attr, value)
    _lib.bafe_autotune_configure(ctypes.byref(cfg))
|
|
649
|
+
|
|
650
|
+
|
|
651
|
+
def autotune_stats() -> dict:
    """Snapshot the autotune counters from bafe_autotune_get_stats().

    Returns:
        dict with keys total_calls, total_compiles, total_refits,
        total_invalidations, last_refit_r_squared and log_size.
    """
    stats = _lib.bafe_autotune_get_stats()
    field_names = (
        "total_calls",
        "total_compiles",
        "total_refits",
        "total_invalidations",
        "last_refit_r_squared",
        "log_size",
    )
    return {field: getattr(stats, field) for field in field_names}
|
|
667
|
+
|
|
668
|
+
|
|
669
|
+
def autotune_refit() -> int:
    """Force a cost model refit now.

    Returns:
        The return code of bafe_profiling_refit (0 on success, non-zero if
        there are not enough samples).
    """
    status = _lib.bafe_profiling_refit()
    return status
|
|
675
|
+
|
|
676
|
+
|
|
677
|
+
def autotune_model() -> dict:
    """Read the learned cost model from bafe_profiling_get_model().

    Returns:
        dict with keys weights (list of the first 8 weights), bias,
        r_squared, n_samples and valid (as a bool).
    """
    model = _lib.bafe_profiling_get_model().contents
    n_weights = 8
    weights = [model.weights[idx] for idx in range(n_weights)]
    return {
        "weights": weights,
        "bias": model.bias,
        "r_squared": model.r_squared,
        "n_samples": model.n_samples,
        "valid": bool(model.valid),
    }
|
|
691
|
+
|
|
692
|
+
|
|
693
|
+
def autotune_dump_log(path: str | os.PathLike) -> int:
    """Dump the profiling log to a JSONL file.

    Args:
        path: destination path as str, bytes or os.PathLike (e.g.
            pathlib.Path).

    Returns:
        The result of bafe_profiling_dump_jsonl (the number of records
        written).
    """
    # os.fsencode accepts str, bytes and PathLike. The previous
    # path.encode("utf-8") raised AttributeError for pathlib.Path.
    return _lib.bafe_profiling_dump_jsonl(os.fsencode(path))
|
|
696
|
+
|
|
697
|
+
|
|
698
|
+
def autotune_reset():
    """Clear profiling state (log, learned model, stats) via bafe_profiling_reset."""
    reset = _lib.bafe_profiling_reset
    reset()
|
|
701
|
+
|
|
702
|
+
|
|
703
|
+
def calibrate():
    """Return the calibrated cost model from bafe_cost_model_calibrated_default().

    Per the design notes, this is a BafeCostModel whose weights are adjusted
    from the learned model. It falls back to the static default model when
    no refit has happened yet, and bafe_optimize uses it for later
    extractions.
    """
    calibrated = _lib.bafe_cost_model_calibrated_default()
    return calibrated
|
|
716
|
+
|
|
717
|
+
|
|
718
|
+
def calibrated_cost_model() -> dict:
    """Inspect the calibrated cost model (for debugging).

    Returns:
        dict of the calibrated weights: alpha_flops, beta_bytes,
        gamma_intermediate, delta_fuse, epsilon_layout_conv,
        zeta_layout_fuse and eta_contiguous.
    """
    model = _lib.bafe_cost_model_calibrated_default()
    weight_names = (
        "alpha_flops",
        "beta_bytes",
        "gamma_intermediate",
        "delta_fuse",
        "epsilon_layout_conv",
        "zeta_layout_fuse",
        "eta_contiguous",
    )
    return {weight: getattr(model, weight) for weight in weight_names}
|
|
735
|
+
|
|
736
|
+
|
|
737
|
+
# ---------------------------------------------------------------------------
|
|
738
|
+
# Phase 3 (issue #7): Cross-kernel fusion
|
|
739
|
+
# ---------------------------------------------------------------------------
|
|
740
|
+
|
|
741
|
+
class FusedFunction:
    """A fused kernel combining two jitted functions.

    When you call `h = bafe.fuse(f, g)`, BAFE concatenates the two
    optimized graphs (f's output feeds g's first input) and compiles
    a single kernel. This avoids materializing the intermediate tensor.

    The fused kernel takes f's inputs followed by g's inputs[1:].

    Compiled kernels are cached per instance, keyed on the
    (shape, dtype, layout) of every positional argument.
    """

    def __init__(self, func_a: "JittedFunction", func_b: "JittedFunction"):
        self._func_a = func_a
        self._func_b = func_b
        self._compiled = {}  # key: (shapes, dtypes, layouts) -> compiled tuple

        # For fuse chaining: expose a _fn-like object whose __code__ reports
        # the combined argument count/names, so an outer FusedFunction can
        # split its own arguments. n_total = n_a_inputs + n_b_inputs - 1
        class _FnShim:
            def __init__(self, a_fn, b_fn):
                a_code = a_fn.__code__
                b_code = b_fn.__code__
                self.__code__ = type("_Code", (), {
                    "co_argcount": a_code.co_argcount + b_code.co_argcount - 1,
                    "co_varnames": a_code.co_varnames[:a_code.co_argcount]
                    + tuple(b_code.co_varnames[1:b_code.co_argcount]),
                })()

        self._fn = _FnShim(func_a._fn, func_b._fn)

        # Copy name/doc/module metadata (and __wrapped__) from the producer.
        # updated=() is deliberate: the default merges func_a.__dict__ into
        # ours, which would overwrite _fn/_compiled set above and, when
        # func_a is itself a FusedFunction, _func_a/_func_b too, breaking
        # chained fusion and sharing func_a's kernel cache.
        functools.update_wrapper(self, func_a, updated=())

    def __call__(self, *args: np.ndarray) -> np.ndarray:
        """Run the fused kernel on ``args`` (f's inputs, then g's inputs[1:]).

        Compiles on the first call for each distinct (shape, dtype, layout)
        signature and reuses the cached kernel afterwards.

        Raises:
            TypeError: if called with no arguments, or (when compiling a new
                signature) with the wrong number of arguments.
        """
        if not args:
            raise TypeError("fused function requires at least one input")

        def _layout_of(a):
            # Only a genuinely column-major (F- but not C-contiguous) array of
            # rank >= 2 is "col"; everything else is treated as row-major.
            if a.ndim >= 2 and a.flags["F_CONTIGUOUS"] and not a.flags["C_CONTIGUOUS"]:
                return "col"
            return "row"

        key = tuple((a.shape, str(a.dtype), _layout_of(a)) for a in args)

        if key not in self._compiled:
            self._compiled[key] = self._compile(args)

        fn_ptr, sig, in_dtypes, out_shape, out_dtype, _opt_graph, _hash, _pred = self._compiled[key]

        out = np.zeros(out_shape, dtype=out_dtype)

        c_args = []
        for a, dt in zip(args, in_dtypes):
            f_only = a.flags["F_CONTIGUOUS"] and not a.flags["C_CONTIGUOUS"]
            if a.dtype == dt and (f_only or a.flags["C_CONTIGUOUS"]):
                # Already the expected dtype in a dense layout: pass as-is.
                arr = a
            elif f_only:
                arr = np.asfortranarray(a, dtype=dt)
            else:
                # Covers dtype conversion and non-contiguous strided views
                # (e.g. x[::2]), which are keyed as "row" and must be
                # densified before their raw data pointer reaches the kernel.
                arr = np.ascontiguousarray(a, dtype=dt)
            # data_as keeps a reference to arr alive alongside the pointer.
            c_args.append(arr.ctypes.data_as(ctypes.c_void_p))
        c_args.append(out.ctypes.data_as(ctypes.c_void_p))

        kernel = sig(fn_ptr)
        kernel(*c_args)
        return out

    def _compile(self, args):
        """Compile the fused function for ``args``.

        Returns the same 8-tuple as JittedFunction._compile so that a
        FusedFunction can itself be fused (chaining):
        (fn_ptr, sig, in_dtypes, out_shape, out_dtype, opt_graph,
        graph_hash, predicted_cost).

        Raises:
            TypeError: if len(args) != n_f_inputs + n_g_inputs - 1.
            RuntimeError: if concatenation, optimization, or JIT compilation
                fails in the C library.
        """
        # Arity comes from each side's (possibly shimmed) __code__.
        n_f_inputs = self._func_a._fn.__code__.co_argcount
        n_g_inputs = self._func_b._fn.__code__.co_argcount
        n_total_inputs = n_f_inputs + n_g_inputs - 1

        if len(args) != n_total_inputs:
            raise TypeError(
                f"fused function expects {n_f_inputs + n_g_inputs - 1} args "
                f"({n_f_inputs} for f + {n_g_inputs - 1} for g, since g's "
                f"first input is f's output), got {len(args)}"
            )

        f_args = args[:n_f_inputs]
        g_args = args[n_f_inputs:]

        # compile f to get its optimized graph
        f_result = self._func_a._compile(f_args)
        _f_fn_ptr, _f_sig, f_in_dtypes, f_out_shape, f_out_dt, f_opt_graph, _f_hash, _f_pred = f_result

        # For g, compile with a dummy first input matching f's output
        dummy_f_out = np.zeros(f_out_shape, dtype=f_out_dt)
        g_result = self._func_b._compile((dummy_f_out, *g_args))
        g_in_dtypes = g_result[2]
        g_opt_graph = g_result[5]

        # Concatenate the two optimized graphs via the C API
        fused_graph = BafeGraph()
        err_buf = ctypes.create_string_buffer(256)
        rc = _lib.bafe_fuse_concat(
            ctypes.byref(f_opt_graph),
            ctypes.byref(g_opt_graph),
            ctypes.byref(fused_graph),
            err_buf,
            ctypes.c_size_t(len(err_buf)),
        )
        if rc != 0:
            raise RuntimeError(
                f"bafe_fuse_concat failed (code {rc}): "
                f"{err_buf.value.decode('utf-8', 'replace')}"
            )

        # Optimize + JIT compile the fused graph
        optimized = BafeGraph()
        rc = _lib.bafe_optimize(
            ctypes.byref(fused_graph),
            ctypes.byref(optimized),
            err_buf,
            ctypes.c_size_t(len(err_buf)),
        )
        if rc != 0:
            raise RuntimeError(
                f"bafe_optimize failed for fused graph (code {rc}): "
                f"{err_buf.value.decode('utf-8', 'replace')}"
            )

        fn_ptr = _lib.bafe_jit_get_or_compile(
            ctypes.byref(optimized),
            err_buf,
            ctypes.c_size_t(len(err_buf)),
        )
        if not fn_ptr:
            raise RuntimeError(
                f"JIT compile failed for fused graph: "
                f"{err_buf.value.decode('utf-8', 'replace')}"
            )

        # One void* per input plus the output buffer.
        sig = ctypes.CFUNCTYPE(None, *([ctypes.c_void_p] * (n_total_inputs + 1)))

        # Kernel inputs are f's inputs followed by g's inputs[1:]. Use the
        # dtypes each side was actually compiled for, so __call__ converts
        # any argument whose dtype differs instead of passing a mismatched
        # buffer straight to the kernel.
        in_dtypes = list(f_in_dtypes) + list(g_in_dtypes[1:])

        out_node = optimized.nodes[optimized.outputs[0]]
        out_shape = tuple(out_node.shape.dims[i] for i in range(out_node.shape.rank))
        out_dt = _BAFE_TO_NP[out_node.dtype]

        # keep a copy of the optimized graph for autotune feature extraction
        opt_graph_copy = BafeGraph()
        ctypes.memmove(ctypes.byref(opt_graph_copy), ctypes.byref(optimized), ctypes.sizeof(BafeGraph))

        graph_hash_buf = ctypes.create_string_buffer(65)
        _lib.bafe_jit_hash_graph(ctypes.byref(optimized), graph_hash_buf, ctypes.c_size_t(65))
        graph_hash = graph_hash_buf.value.decode("utf-8")

        cm = _lib.bafe_cost_model_calibrated_default()
        predicted = _lib.bafe_cost_graph(ctypes.byref(cm), ctypes.byref(optimized))

        return (fn_ptr, sig, in_dtypes, out_shape, out_dt,
                opt_graph_copy, graph_hash, predicted)
|
|
899
|
+
|
|
900
|
+
|
|
901
|
+
def fuse(func_a, func_b):
    """Combine two jitted functions into one fused kernel.

    ``h = bafe.fuse(f, g)`` yields a callable where ``h(a, b, c)`` computes
    ``g(f(a, b), c)`` as a single compiled kernel, so f's output is not
    materialized as a separate tensor.

    Args:
        func_a: producer; a @bafe.jit-decorated function or a FusedFunction.
        func_b: consumer; a @bafe.jit-decorated function or a FusedFunction.
            Its first argument is fed by func_a's output.

    Returns:
        A FusedFunction taking func_a's inputs followed by func_b's inputs[1:].

    Raises:
        TypeError: if either argument is not a JittedFunction or FusedFunction.
    """
    # FusedFunction is accepted on both sides so fusions can be chained.
    accepted_kinds = (JittedFunction, FusedFunction)
    for candidate in (func_a, func_b):
        if not isinstance(candidate, accepted_kinds):
            raise TypeError("fuse() requires @bafe.jit-decorated or fused functions")
    return FusedFunction(func_a, func_b)
|
|
922
|
+
|
|
923
|
+
|
|
924
|
+
__all__ = [
    # core API
    "Tensor",
    "jit",
    "optimize",
    "make_search_budget",
    "fuse",
    # graph-building ops
    "input",
    "matmul",
    "add",
    "sub",
    "mul",
    "relu",
    "sigmoid",
    "tanh",
    "bias_add",
    "transpose",
    "reduce_sum",
    "reduce_max",
    "reshape",
    "broadcast_to",
    # introspection
    "graph_summary",
    # autotuning / profiling
    "configure_autotune",
    "autotune_stats",
    "autotune_refit",
    "autotune_model",
    "autotune_dump_log",
    "autotune_reset",
    # cost-model calibration
    "calibrate",
    "calibrated_cost_model",
    # pruning controller helpers
    "pruning_regime_name",
    "pruning_beam_width",
    "pruning_iters",
    # metadata
    "__version__",
]
|
|
935
|
+
|
|
936
|
+
|
|
937
|
+
# ---------------------------------------------------------------------------
|
|
938
|
+
# Phase 3 (issue #4): Pruning controller helpers
|
|
939
|
+
# ---------------------------------------------------------------------------
|
|
940
|
+
|
|
941
|
+
# Regime codes from bafe_pruning_regime_from_budget -> human-readable names,
# listed in order of increasing time budget (codes 0..3).
_PRUNING_REGIMES = dict(enumerate(("greedy", "light", "beam", "deep")))
|
|
947
|
+
|
|
948
|
+
|
|
949
|
+
def pruning_regime_name(time_budget_ms: int) -> str:
|
|
950
|
+
"""Get the regime name for a time budget.
|
|
951
|
+
|
|
952
|
+
Returns one of: "greedy" (<=1ms), "light" (<=10ms), "beam" (<=100ms),
|
|
953
|
+
"deep" (>100ms). 0 or None returns "deep" (no limit).
|
|
954
|
+
"""
|
|
955
|
+
if time_budget_ms is None or time_budget_ms <= 0:
|
|
956
|
+
return "deep"
|
|
957
|
+
regime = _lib.bafe_pruning_regime_from_budget(int(time_budget_ms))
|
|
958
|
+
return _PRUNING_REGIMES.get(regime, "unknown")
|
|
959
|
+
|
|
960
|
+
|
|
961
|
+
def pruning_beam_width(time_budget_ms: float | None) -> int:
    """Get the beam width for a time budget's regime.

    None, 0, or a negative budget is passed to the C library as 0. Fractional
    budgets are rounded *up* to whole milliseconds, so e.g. 0.5ms is sent as
    1 rather than truncated to 0.
    """
    import math  # local import keeps this helper self-contained

    if time_budget_ms is None or time_budget_ms <= 0:
        time_budget_ms = 0
    regime = _lib.bafe_pruning_regime_from_budget(int(math.ceil(time_budget_ms)))
    return _lib.bafe_pruning_beam_width_for_regime(regime)
|
|
967
|
+
|
|
968
|
+
|
|
969
|
+
def pruning_iters(time_budget_ms: float | None) -> int:
    """Get the number of stochastic iterations for a time budget's regime.

    None, 0, or a negative budget is passed to the C library as 0. Fractional
    budgets are rounded *up* to whole milliseconds, so e.g. 0.5ms is sent as
    1 rather than truncated to 0.
    """
    import math  # local import keeps this helper self-contained

    if time_budget_ms is None or time_budget_ms <= 0:
        time_budget_ms = 0
    regime = _lib.bafe_pruning_regime_from_budget(int(math.ceil(time_budget_ms)))
    return _lib.bafe_pruning_iters_for_regime(regime)
|