bafe-engine 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
bafe/_binding.py ADDED
@@ -0,0 +1,528 @@
1
+ """BAFE - ctypes binding to libbafe.so.
2
+
3
+ This module loads libbafe.so and exposes the C API as Python functions.
4
+ The public API (matmul, add, relu, @jit, ...) is in __init__.py; this
5
+ file is the low-level FFI.
6
+
7
+ The library is searched in this order:
8
+ 1. $BAFE_LIB environment variable (full path to .so)
9
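+ 2. libbafe.so bundled next to this module (pip-installed packages)
10
+ 3. development builds: bafe/build/libbafe.so or build/libbafe.so, looked up relative to this file's parent directories and to the current working directory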
11
+ 4. system library paths (via ctypes.util.find_library)
12
+ """
13
+ from __future__ import annotations
14
+
15
+ import ctypes
16
+ import ctypes.util
17
+ import os
18
+ import sys
19
+ from pathlib import Path
20
+ from ctypes import (
21
+ c_int, c_int32, c_uint32, c_size_t, c_double, c_bool, c_char, c_char_p, c_void_p,
22
+ POINTER, Structure, byref, cast, string_at,
23
+ )
24
+
25
+ # ---------------------------------------------------------------------------
26
+ # Path resolution for libbafe.so
27
+ # ---------------------------------------------------------------------------
28
+
29
def _find_library() -> str:
    """Locate the libbafe shared library and return something CDLL can load.

    Search order (first match wins):

    1. ``$BAFE_LIB`` -- used when it names an existing file.
    2. ``libbafe.so`` bundled next to this module (pip-installed packages).
    3. Development build directories, relative to this file's parent
       directories and to the current working directory.
    4. System library paths, via ``ctypes.util.find_library``.

    Returns:
        A filesystem path for steps 1-3, or whatever name
        ``find_library`` reports for step 4.

    Raises:
        RuntimeError: if nothing was found.  When ``$BAFE_LIB`` was set but
            did not name an existing file, the message says so, because the
            generic "set BAFE_LIB" advice would otherwise be misleading.
    """
    # 1. BAFE_LIB env var (for development/testing).  An unusable value is
    #    not fatal here: the remaining locations are still tried.
    env = os.environ.get("BAFE_LIB")
    if env and Path(env).is_file():
        return env

    # 2. Bundled .so (next to _binding.py, for pip-installed packages).
    here = Path(__file__).resolve().parent
    bundled = here / "libbafe.so"
    if bundled.is_file():
        return str(bundled)

    # 3. Development paths.
    cwd = Path.cwd()
    candidates = [
        here.parent.parent / "bafe" / "build" / "libbafe.so",
        here.parent / "build" / "libbafe.so",
        cwd / "bafe" / "build" / "libbafe.so",
        cwd / "build" / "libbafe.so",
    ]
    for candidate in candidates:
        if candidate.is_file():
            return str(candidate)

    # 4. System library paths.
    found = ctypes.util.find_library("bafe")  # type: ignore[attr-defined]
    if found:
        return found

    message = (
        "libbafe.so not found. Set BAFE_LIB to its path, or run `make` "
        "in the project root to build it."
    )
    if env:
        # Tell the user their explicit setting was seen but rejected.
        message += f" (BAFE_LIB is set to '{env}', which is not an existing file.)"
    raise RuntimeError(message)
56
+
57
+
58
+ # ---------------------------------------------------------------------------
59
+ # C struct definitions (must match bafe/types.h, bafe/ir.h, bafe/ops.h, etc.)
60
+ # ---------------------------------------------------------------------------
61
+
62
+ class BafeShape(Structure):  # ctypes layout of a shape: an 8-slot int32 dims array followed by an int32 rank; must match the C declaration
63
+ _fields_ = [
64
+ ("dims", c_int32 * 8), # BAFE_MAX_RANK = 8
65
+ ("rank", c_int32),  # presumably the number of dims entries in use -- confirm against bafe/types.h
66
+ ]
67
+
68
+
69
+ BAFE_MAX_CHILDREN = 4  # capacity of the children arrays in BafeNode and BafeAlternative
70
+ BAFE_MAX_ATTR_LEN = 32  # capacity of the attribute arrays and fixed-size name buffers
71
+
72
+
73
+ class BafeOpAttrs(Structure):  # ctypes layout of per-op attributes; field order and types must match the C declaration
74
+ _fields_ = [
75
+ ("n_axes", c_int32),  # count paired with axes (make_attrs stores len(axes) here)
76
+ ("axes", c_int32 * BAFE_MAX_ATTR_LEN),
77
+ ("n_perm", c_int32),  # count paired with perm
78
+ ("perm", c_int32 * BAFE_MAX_ATTR_LEN),
79
+ ("n_shape", c_int32),  # count paired with shape
80
+ ("shape", c_int32 * BAFE_MAX_ATTR_LEN),
81
+ ("keepdims", c_bool),
82
+ ("scalar_value", c_double),  # make_attrs(scalar=...) writes this and sets has_scalar
83
+ ("has_scalar", c_bool),
84
+ ("name", c_char * BAFE_MAX_ATTR_LEN),  # inline byte buffer; make_attrs stores at most BAFE_MAX_ATTR_LEN - 1 bytes
85
+ ]
86
+
87
+
88
+ BAFE_MAX_NODES = 4096  # capacity of BafeGraph.nodes / inputs / outputs
89
+
90
+
91
+ class BafeNode(Structure):  # ctypes layout of one graph node; field order and types must match the C declaration
92
+ _fields_ = [
93
+ ("id", c_int32),
94
+ ("op_name", c_char_p),  # char pointer, not an inline buffer like input_name
95
+ ("attrs", BafeOpAttrs),  # embedded by value
96
+ ("n_children", c_int),  # presumably the number of children slots in use -- confirm against bafe/ir.h
97
+ ("children", c_int32 * BAFE_MAX_CHILDREN),
98
+ ("shape", BafeShape),  # embedded by value
99
+ ("dtype", c_int),  # plain int, like the dtype arguments of the bafe_dtype_* prototypes
100
+ ("layout", c_int), # Phase 2: bafe_layout enum
101
+ ("input_name", c_char * BAFE_MAX_ATTR_LEN),  # inline byte buffer
102
+ ("is_input", c_bool),
103
+ ("is_constant", c_bool),
104
+ ("const_value", c_double),
105
+ ]
106
+
107
+
108
+ class BafeGraph(Structure):  # ctypes layout of a whole graph; the BAFE_MAX_NODES nodes are embedded by value, so one instance is a large (multi-megabyte) object
109
+ _fields_ = [
110
+ ("nodes", BafeNode * BAFE_MAX_NODES),
111
+ ("n_nodes", c_int),  # presumably the number of nodes slots in use -- confirm against bafe/ir.h
112
+ ("inputs", c_int32 * BAFE_MAX_NODES),
113
+ ("n_inputs", c_int),  # count paired with inputs
114
+ ("outputs", c_int32 * BAFE_MAX_NODES),
115
+ ("n_outputs", c_int),  # count paired with outputs
116
+ ]
117
+
118
+
119
+ # Phase 2: rewrite alternatives (used by tests)
120
+ BAFE_MAX_ALTERNATIVES = 1024  # capacity of BafeAltList.items
121
+
122
+
123
+ class BafeAlternative(Structure):  # ctypes layout of one rewrite alternative; field order and types must match the C declaration
124
+ _fields_ = [
125
+ ("original_node_id", c_int32),
126
+ ("op_name", c_char_p),  # char pointer, not an inline buffer
127
+ ("attrs", BafeOpAttrs),  # embedded by value
128
+ ("n_children", c_int),  # presumably the number of children slots in use -- confirm against the C header
129
+ ("children", c_int32 * BAFE_MAX_CHILDREN),
130
+ ]
131
+
132
+
133
+ class BafeAltList(Structure):  # fixed-capacity list of BafeAlternative, passed by pointer to the bafe_rewrite_* / bafe_pruning_run* prototypes
134
+ _fields_ = [
135
+ ("items", BafeAlternative * BAFE_MAX_ALTERNATIVES),  # embedded by value
136
+ ("n", c_int),  # presumably the number of items in use -- confirm against the C header
137
+ ]
138
+
139
+
140
+ # Phase 2: cost model struct (used by tests)
141
+ class BafeCostModel(Structure):  # seven c_double fields; returned by value from the bafe_cost_model_* prototypes, so the layout must match the C struct exactly
142
+ _fields_ = [
143
+ ("alpha_flops", c_double),
144
+ ("beta_bytes", c_double),
145
+ ("gamma_intermediate", c_double),
146
+ ("delta_fuse", c_double),
147
+ ("epsilon_layout_conv", c_double),
148
+ ("zeta_layout_fuse", c_double),
149
+ ("eta_contiguous", c_double),
150
+ ]
151
+
152
+
153
+ # ---------------------------------------------------------------------------
154
+ # Load library and set up function prototypes
155
+ # ---------------------------------------------------------------------------
156
+
157
# Resolve and load the shared library once, at import time; every prototype
# in this module is attached to this handle.
_lib_path = _find_library()
_lib = ctypes.CDLL(_lib_path)

# Phase 2: rewrite + cost bindings (used by tests).
# Each row is (symbol, restype, argtypes); rows are applied top to bottom.
for _symbol, _restype, _argtypes in (
    ("bafe_cost_model_default", BafeCostModel, []),
    ("bafe_cost_graph", c_double,
     [POINTER(BafeCostModel), POINTER(BafeGraph)]),
    ("bafe_cost_model_calibrate", BafeCostModel,
     [POINTER(BafeCostModel), POINTER(c_double), c_int, c_double]),
    ("bafe_cost_model_calibrated_default", BafeCostModel, []),
    ("bafe_rewrite_find", c_int,
     [POINTER(BafeGraph), POINTER(BafeAltList)]),
    ("bafe_rewrite_default_count", c_int, []),
):
    _proto = getattr(_lib, _symbol)
    _proto.argtypes = _argtypes
    _proto.restype = _restype
# Do not leave the loop temporaries behind as module attributes.
del _symbol, _restype, _argtypes, _proto
173
+
174
+
175
+ # Phase 2 (issue #1): stochastic search budget + stats
176
+ class BafeSearchBudget(Structure):  # limits handed by pointer to bafe_rewrite_stochastic* and the *_with_budget prototypes; layout must match the C struct
177
+ _fields_ = [
178
+ ("max_iters", c_int),
179
+ ("max_nodes", c_int),
180
+ ("max_rewrites", c_int),
181
+ ("time_budget_ms", c_int),
182
+ ("temperature", c_double),
183
+ ("seed", c_uint32),
184
+ ("enable_multi_pass", c_bool),
185
+ ]
186
+
187
+
188
+ class BafeSearchStats(Structure):  # out-parameter of bafe_rewrite_stochastic_stats (passed by pointer); layout must match the C struct
189
+ _fields_ = [
190
+ ("iters_done", c_int),
191
+ ("alts_found", c_int),
192
+ ("alts_materialized", c_int),
193
+ ("nodes_added", c_int),
194
+ ("elapsed_ms", c_double),
195
+ ]
196
+
197
+
198
# Stochastic-search entry points.
# Each row is (symbol, restype, argtypes); rows are applied top to bottom.
for _symbol, _restype, _argtypes in (
    ("bafe_search_budget_default", BafeSearchBudget, []),
    ("bafe_rewrite_stochastic", c_int,
     [POINTER(BafeGraph), POINTER(BafeAltList), POINTER(BafeSearchBudget)]),
    ("bafe_rewrite_stochastic_stats", c_int,
     [POINTER(BafeGraph), POINTER(BafeAltList), POINTER(BafeSearchBudget),
      POINTER(BafeSearchStats)]),
    ("bafe_optimize_with_budget", c_int,
     [POINTER(BafeGraph), POINTER(BafeGraph), POINTER(BafeSearchBudget),
      c_char_p, c_size_t]),
    # Returns an opaque pointer (c_void_p), not a status code.
    ("bafe_optimize_and_compile_with_budget", c_void_p,
     [POINTER(BafeGraph), POINTER(BafeSearchBudget), c_char_p, c_size_t]),
):
    _proto = getattr(_lib, _symbol)
    _proto.argtypes = _argtypes
    _proto.restype = _restype
# Do not leave the loop temporaries behind as module attributes.
del _symbol, _restype, _argtypes, _proto
215
+
216
+
217
+ # Phase 3 (issue #6): auto-tuning + profiling
218
+ BAFE_NUM_FEATURES = 8  # length of the feature / weight vectors in the profiling structs
219
+ BAFE_PROFILING_LOG_SIZE = 4096  # capacity of BafeProfilingLog.records
220
+
221
+
222
+ class BafeProfilingRecord(Structure):  # ctypes layout of one profiling log entry; must match the C struct
223
+ _fields_ = [
224
+ ("graph_hash", c_char * 65),  # inline 65-byte buffer (presumably 64 hex characters plus NUL -- confirm)
225
+ ("features", c_double * BAFE_NUM_FEATURES),
226
+ ("predicted_cost", c_double),
227
+ ("observed_ms", c_double),
228
+ ("kernel_id", c_int),
229
+ ]
230
+
231
+
232
+ class BafeProfilingLog(Structure):  # fixed-capacity record store, exposed by pointer through bafe_profiling_get_log; must match the C struct
233
+ _fields_ = [
234
+ ("records", BafeProfilingRecord * BAFE_PROFILING_LOG_SIZE),  # embedded by value
235
+ ("n", c_int),
236
+ ("head", c_int),  # NOTE(review): head/wrapped suggest a ring buffer -- confirm against the C implementation
237
+ ("wrapped", c_bool),
238
+ ]
239
+
240
+
241
+ class BafeLearnedCostModel(Structure):  # fitted model exposed by pointer through bafe_profiling_get_model; must match the C struct
242
+ _fields_ = [
243
+ ("weights", c_double * BAFE_NUM_FEATURES),  # same length as BafeProfilingRecord.features
244
+ ("bias", c_double),
245
+ ("r_squared", c_double),
246
+ ("n_samples", c_int),
247
+ ("valid", c_bool),
248
+ ]
249
+
250
+
251
+ class BafeAutotuneConfig(Structure):  # returned by value from bafe_autotune_config_default / bafe_autotune_get_config and passed by pointer to bafe_autotune_configure
252
+ _fields_ = [
253
+ ("enabled", c_bool),
254
+ ("refit_threshold", c_int),
255
+ ("invalidation_drift", c_double),
256
+ ("warmup_calls", c_int),
257
+ ("timing_iters", c_int),
258
+ ]
259
+
260
+
261
+ class BafeAutotuneStats(Structure):  # returned by value from bafe_autotune_get_stats; must match the C struct
262
+ _fields_ = [
263
+ ("total_calls", c_int),
264
+ ("total_compiles", c_int),
265
+ ("total_refits", c_int),
266
+ ("total_invalidations", c_int),
267
+ ("last_refit_r_squared", c_double),
268
+ ("log_size", c_int),
269
+ ]
270
+
271
+
272
# Profiling, learned cost model and auto-tuning entry points.
# Each row is (symbol, restype, argtypes); rows are applied top to bottom.
# A restype of None declares a C function returning void.
for _symbol, _restype, _argtypes in (
    ("bafe_profiling_init", None, []),
    ("bafe_profiling_reset", None, []),
    ("bafe_profiling_extract_features", None,
     [POINTER(BafeGraph), POINTER(c_double)]),
    ("bafe_profiling_add", None,
     [c_char_p, POINTER(c_double), c_double, c_double, c_int]),
    ("bafe_profiling_get_log", POINTER(BafeProfilingLog), []),
    ("bafe_profiling_dump_jsonl", c_int, [c_char_p]),
    ("bafe_profiling_refit", c_int, []),
    ("bafe_profiling_get_model", POINTER(BafeLearnedCostModel), []),
    ("bafe_profiling_predict_ms", c_double, [POINTER(c_double)]),
    ("bafe_autotune_config_default", BafeAutotuneConfig, []),
    ("bafe_autotune_configure", None, [POINTER(BafeAutotuneConfig)]),
    ("bafe_autotune_get_config", BafeAutotuneConfig, []),
    ("bafe_autotune_get_stats", BafeAutotuneStats, []),
    ("bafe_jit_invalidate_memory_cache", None, []),
):
    _proto = getattr(_lib, _symbol)
    _proto.argtypes = _argtypes
    _proto.restype = _restype
# Do not leave the loop temporaries behind as module attributes.
del _symbol, _restype, _argtypes, _proto
300
+
301
+
302
# Phase 3 (issue #7): cross-kernel fusion.
# Each row is (symbol, restype, argtypes); rows are applied top to bottom.
for _symbol, _restype, _argtypes in (
    ("bafe_fuse_concat", c_int,
     [POINTER(BafeGraph), POINTER(BafeGraph), POINTER(BafeGraph),
      c_char_p, c_size_t]),
    # Returns an opaque pointer (c_void_p), not a status code.
    ("bafe_fuse_compile", c_void_p,
     [POINTER(BafeGraph), POINTER(BafeGraph), c_char_p, c_size_t]),
):
    _proto = getattr(_lib, _symbol)
    _proto.argtypes = _argtypes
    _proto.restype = _restype
# Do not leave the loop temporaries behind as module attributes.
del _symbol, _restype, _argtypes, _proto
309
+
310
+
311
+ # Phase 3 (issue #4): multi-tier pruning with time budget
312
+ class BafePruningConfig(Structure):  # returned by value from bafe_pruning_config_default and passed by pointer to bafe_pruning_run; must match the C struct
313
+ _fields_ = [
314
+ ("time_budget_ms", c_int),
315
+ ("max_nodes", c_int),
316
+ ("max_rewrites", c_int),
317
+ ("max_egraph_size", c_int),
318
+ ("beam_width", c_int),
319
+ ("heuristic_threshold", c_double),
320
+ ("temperature", c_double),
321
+ ("seed", c_uint32),
322
+ ("enable_anytime", c_bool),
323
+ ]
324
+
325
+
326
+ class BafePruningStats(Structure):  # out-parameter of bafe_pruning_run / bafe_pruning_run_with_budget (passed by pointer); must match the C struct
327
+ _fields_ = [
328
+ ("regime", c_int),  # plain int, like the argument of the bafe_pruning_*_for_regime prototypes
329
+ ("tier_a_passed", c_int),
330
+ ("tier_b_passed", c_int),
331
+ ("tier_c_kept", c_int),
332
+ ("tier_d_materialized", c_int),
333
+ ("total_alts_found", c_int),
334
+ ("best_cost", c_int),  # NOTE(review): c_int here, while bafe_cost_graph returns c_double -- verify against the C header
335
+ ("elapsed_ms", c_double),
336
+ ("was_interrupted", c_bool),
337
+ ]
338
+
339
+
340
# Multi-tier pruning entry points.
# Each row is (symbol, restype, argtypes); rows are applied top to bottom.
for _symbol, _restype, _argtypes in (
    ("bafe_pruning_config_default", BafePruningConfig, []),
    ("bafe_pruning_regime_from_budget", c_int, [c_int]),
    ("bafe_pruning_beam_width_for_regime", c_int, [c_int]),
    ("bafe_pruning_iters_for_regime", c_int, [c_int]),
    ("bafe_pruning_run", c_int,
     [POINTER(BafeGraph), POINTER(BafeAltList), POINTER(BafePruningConfig),
      POINTER(BafePruningStats)]),
    # Same as bafe_pruning_run but takes a bare c_int instead of a config
    # pointer in the third position.
    ("bafe_pruning_run_with_budget", c_int,
     [POINTER(BafeGraph), POINTER(BafeAltList), c_int,
      POINTER(BafePruningStats)]),
):
    _proto = getattr(_lib, _symbol)
    _proto.argtypes = _argtypes
    _proto.restype = _restype
# Do not leave the loop temporaries behind as module attributes.
del _symbol, _restype, _argtypes, _proto
354
+
355
+
356
# types: dtype, shape and layout helpers.
# Each row is (symbol, restype, argtypes); rows are applied top to bottom.
_shape_p = POINTER(BafeShape)
_i32_p = POINTER(c_int32)
for _symbol, _restype, _argtypes in (
    ("bafe_dtype_c_name", c_char_p, [c_int]),
    ("bafe_dtype_numpy_name", c_char_p, [c_int]),
    ("bafe_dtype_byte_size", c_size_t, [c_int]),
    ("bafe_dtype_from_str", c_int, [c_char_p]),

    ("bafe_shape_make", BafeShape, [c_int32, _i32_p]),
    ("bafe_shape_numel", c_size_t, [_shape_p]),
    ("bafe_shape_broadcast", BafeShape, [_shape_p, _shape_p]),
    ("bafe_shape_reduce", BafeShape, [_shape_p, _i32_p, c_int32, c_bool]),
    ("bafe_shape_transpose", BafeShape, [_shape_p, _i32_p]),
    ("bafe_shape_eq", c_bool, [_shape_p, _shape_p]),
    ("bafe_shape_is_scalar", c_bool, [_shape_p]),
    ("bafe_shape_is_empty", c_bool, [_shape_p]),
    ("bafe_shape_rank", c_int32, [_shape_p]),
    ("bafe_shape_nbytes", c_size_t, [_shape_p, c_int]),
    ("bafe_shape_dim", c_int32, [_shape_p, c_int32]),
    ("bafe_layout_name", c_char_p, [c_int]),
):
    _proto = getattr(_lib, _symbol)
    _proto.argtypes = _argtypes
    _proto.restype = _restype
# Do not leave the temporaries behind as module attributes.
del _symbol, _restype, _argtypes, _proto, _shape_p, _i32_p
390
+
391
# ops + ir: operator lookup and graph construction.
# Each row is (symbol, restype, argtypes); rows are applied top to bottom.
_graph_p = POINTER(BafeGraph)
_shape_p = POINTER(BafeShape)
_i32_p = POINTER(c_int32)
for _symbol, _restype, _argtypes in (
    # ops
    ("bafe_op_get", c_void_p, [c_char_p]),

    # ir
    ("bafe_graph_init", None, [_graph_p]),
    ("bafe_graph_add_input", c_int32, [_graph_p, c_char_p, _shape_p, c_int]),
    ("bafe_graph_add_input_with_layout", c_int32,
     [_graph_p, c_char_p, _shape_p, c_int, c_int]),
    ("bafe_graph_set_node_layout", c_int, [_graph_p, c_int32, c_int]),
    ("bafe_graph_get_node_layout", c_int, [_graph_p, c_int32]),
    ("bafe_graph_add_constant", c_int32,
     [_graph_p, c_double, _shape_p, c_int]),
    ("bafe_graph_add", c_int32,
     [_graph_p, c_char_p, _i32_p, c_int, POINTER(BafeOpAttrs)]),

    # Two-operand builders: (graph, int32, int32) -> int32.
    ("bafe_graph_matmul", c_int32, [_graph_p, c_int32, c_int32]),
    ("bafe_graph_add_op", c_int32, [_graph_p, c_int32, c_int32]),
    ("bafe_graph_mul", c_int32, [_graph_p, c_int32, c_int32]),
    ("bafe_graph_sub", c_int32, [_graph_p, c_int32, c_int32]),
    ("bafe_graph_bias_add", c_int32, [_graph_p, c_int32, c_int32]),

    # One-operand builders: (graph, int32) -> int32.
    ("bafe_graph_relu", c_int32, [_graph_p, c_int32]),
    ("bafe_graph_sigmoid", c_int32, [_graph_p, c_int32]),
    ("bafe_graph_tanh", c_int32, [_graph_p, c_int32]),
    ("bafe_graph_neg", c_int32, [_graph_p, c_int32]),

    # Builders that also take an int32 array plus its length.
    ("bafe_graph_transpose", c_int32, [_graph_p, c_int32, _i32_p, c_int32]),
    ("bafe_graph_reduce_sum", c_int32,
     [_graph_p, c_int32, _i32_p, c_int32, c_int]),
    ("bafe_graph_reduce_max", c_int32,
     [_graph_p, c_int32, _i32_p, c_int32, c_int]),
    ("bafe_graph_reshape", c_int32, [_graph_p, c_int32, _i32_p, c_int32]),
    ("bafe_graph_broadcast_to", c_int32,
     [_graph_p, c_int32, _i32_p, c_int32]),

    ("bafe_graph_set_output", None, [_graph_p, c_int32]),
):
    _proto = getattr(_lib, _symbol)
    _proto.argtypes = _argtypes
    _proto.restype = _restype
# Do not leave the temporaries behind as module attributes.
del _symbol, _restype, _argtypes, _proto, _graph_p, _shape_p, _i32_p
444
+
445
# bafe (top-level) + jit: optimise, compile and cache control.
# Each row is (symbol, restype, argtypes); rows are applied top to bottom.
for _symbol, _restype, _argtypes in (
    # bafe (top-level)
    ("bafe_optimize", c_int,
     [POINTER(BafeGraph), POINTER(BafeGraph), c_char_p, c_size_t]),
    ("bafe_optimize_and_compile", c_void_p,
     [POINTER(BafeGraph), c_char_p, c_size_t]),

    # jit
    ("bafe_jit_get_or_compile", c_void_p,
     [POINTER(BafeGraph), c_char_p, c_size_t]),
    ("bafe_jit_set_cache_dir", None, [c_char_p]),
    ("bafe_jit_cache_dir", c_char_p, []),
):
    _proto = getattr(_lib, _symbol)
    _proto.argtypes = _argtypes
    _proto.restype = _restype
# Do not leave the loop temporaries behind as module attributes.
del _symbol, _restype, _argtypes, _proto
458
+
459
+ class BafeJitStats(Structure):  # four c_int counters, returned by value from bafe_jit_get_stats; must match the C struct
460
+ _fields_ = [
461
+ ("hits", c_int),
462
+ ("misses", c_int),
463
+ ("compiles", c_int),
464
+ ("compile_failures", c_int),
465
+ ]
466
+
467
+ _lib.bafe_jit_get_stats.argtypes = []  # takes no arguments
468
+ _lib.bafe_jit_get_stats.restype = BafeJitStats  # struct returned by value, so it is declared after the class it returns
469
+
470
+
471
+ # ---------------------------------------------------------------------------
472
+ # Convenience helpers for the public API
473
+ # ---------------------------------------------------------------------------
474
+
475
def make_shape(dims):
    """Build a BafeShape from an iterable of ints.

    Args:
        dims: the dimension sizes.  Any iterable is accepted (tuple, list,
            generator, ...); it is materialised once so that one-shot
            iterators work too.

    Returns:
        The BafeShape produced by the C function ``bafe_shape_make``.

    NOTE(review): BafeShape.dims has 8 slots (BAFE_MAX_RANK).  Longer inputs
    are forwarded unchanged; how ``bafe_shape_make`` reacts to a rank above 8
    is not visible from this file -- confirm against the C implementation.
    """
    values = list(dims)
    rank = len(values)
    # Allocate at least one element so that rank 0 still hands a non-empty
    # array to the C side.
    buffer = (c_int32 * max(rank, 1))(*values)
    return _lib.bafe_shape_make(c_int32(rank), buffer)
480
+
481
+
482
def make_attrs(**kw):
    """Build a BafeOpAttrs from keyword args.

    Recognised keywords (any other keyword is ignored):

    * ``axes``, ``perm``, ``shape`` -- iterables of ints, copied into the
      matching fixed-size array; the paired ``n_*`` field gets the length.
    * ``keepdims`` -- any truthy/falsy value.
    * ``scalar`` -- converted with ``float()``; also sets ``has_scalar``.
    * ``name`` -- str, stored UTF-8 encoded in at most
      ``BAFE_MAX_ATTR_LEN - 1`` bytes.  Truncation never splits a
      multi-byte character, so the stored bytes always decode as UTF-8.

    Returns:
        A BafeOpAttrs with every field not mentioned above left at zero.

    Raises:
        IndexError: if ``axes``, ``perm`` or ``shape`` holds more than
            ``BAFE_MAX_ATTR_LEN`` entries.
    """
    attrs = BafeOpAttrs()
    # zero it
    ctypes.memset(byref(attrs), 0, ctypes.sizeof(attrs))

    # The three integer-list attributes share one layout: an array field
    # plus an "n_<field>" count.
    for field in ("axes", "perm", "shape"):
        if field not in kw:
            continue
        values = list(kw[field])
        if len(values) > BAFE_MAX_ATTR_LEN:
            # Same exception type the array assignment would raise, but
            # raised up front with a message that names the problem.
            raise IndexError(
                f"{field} has {len(values)} entries; "
                f"at most {BAFE_MAX_ATTR_LEN} are supported"
            )
        setattr(attrs, "n_" + field, len(values))
        slots = getattr(attrs, field)
        for i, v in enumerate(values):
            slots[i] = v

    if "keepdims" in kw:
        attrs.keepdims = 1 if kw["keepdims"] else 0
    if "scalar" in kw:
        attrs.scalar_value = float(kw["scalar"])
        attrs.has_scalar = 1
    if "name" in kw:
        raw = kw["name"].encode("utf-8")[:BAFE_MAX_ATTR_LEN - 1]
        # A byte slice can end inside a multi-byte sequence; decoding with
        # "ignore" drops only that dangling tail.
        attrs.name = raw.decode("utf-8", "ignore").encode("utf-8")
    return attrs
511
+
512
+
513
def graph_summary(g: BafeGraph) -> str:
    """Return the text that the C function ``bafe_graph_summary`` writes for *g*.

    The C function is given an 8192-byte buffer and its size; the bytes up
    to the first NUL are decoded as UTF-8.  The integer the C function
    returns is not inspected.
    """
    # bafe_graph_summary is in ir.h; its prototype is (re)declared on each call.
    summarize = _lib.bafe_graph_summary
    summarize.argtypes = [POINTER(BafeGraph), c_char_p, c_size_t]
    summarize.restype = c_int

    out = ctypes.create_string_buffer(8192)
    summarize(byref(g), out, c_size_t(len(out)))
    return out.value.decode("utf-8")
521
+
522
+
523
# Names exported by ``from bafe._binding import *``.  The Phase 2/3 structs
# (cost model, search, profiling, pruning, ...) are not listed here, so they
# must be imported by name.
__all__ = [
    # library handle
    "_lib",
    "_lib_path",
    # core structs
    "BafeShape",
    "BafeOpAttrs",
    "BafeNode",
    "BafeGraph",
    "BafeJitStats",
    # capacity constants
    "BAFE_MAX_NODES",
    "BAFE_MAX_CHILDREN",
    "BAFE_MAX_ATTR_LEN",
    # helpers
    "make_shape",
    "make_attrs",
    "graph_summary",
]
bafe/libbafe.so ADDED
Binary file