pulse-framework 0.1.41__py3-none-any.whl → 0.1.43__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pulse/hooks/init.py ADDED
@@ -0,0 +1,460 @@
1
+ from __future__ import annotations
2
+
3
+ import ast
4
+ import ctypes
5
+ import functools
6
+ import inspect
7
+ import textwrap
8
+ import types
9
+ from collections.abc import Callable, Sequence
10
+ from typing import Any, Literal, cast, override
11
+
12
+ from pulse.hooks.core import HookState, hooks
13
+
14
# Storage keyed by (code object, lineno) of the `with ps.init()` call site.
# The lambda defers the lookup of `InitState`, which is defined further down
# in this module, until the factory is actually called.
_init_hook = hooks.create("init_storage", lambda: InitState())

# The CPython fast path writes restored values back into a running frame via
# `PyFrame_LocalsToFast`. `ctypes.pythonapi` is a CPython-specific binding, so
# probe it through getattr: where it is missing this evaluates to False and
# rewrite_init_blocks picks InitFallbackRewriter instead of the module failing
# to import.
_CAN_USE_CPYTHON = hasattr(
    getattr(ctypes, "pythonapi", None), "PyFrame_LocalsToFast"
)
if _CAN_USE_CPYTHON:
    PyFrame_LocalsToFast = ctypes.pythonapi.PyFrame_LocalsToFast
    PyFrame_LocalsToFast.argtypes = [ctypes.py_object, ctypes.c_int]
    PyFrame_LocalsToFast.restype = None
22
+
23
+
24
def previous_frame() -> types.FrameType:
    """Return the frame that called this helper's caller.

    Two frames are skipped: this helper's own frame and its immediate caller
    (e.g. ``InitContext.__enter__``). When called from ``__enter__``, the
    result is therefore the frame executing the ``with`` statement.

    Returns:
        The frame two levels above this helper.

    Raises:
        AssertionError: if frame introspection is unavailable or the stack
            is shallower than expected.
    """
    current = inspect.currentframe()
    try:
        assert current is not None, "currentframe() returned None"
        # Skip this helper function's frame
        caller = current.f_back
        assert caller is not None, "f_back is None"
        # Skip the caller's frame (e.g., __enter__) to get the actual previous frame
        frame = caller.f_back
        assert frame is not None, "f_back.f_back is None"
        return frame
    finally:
        # `current` is this very frame; keeping it as a local forms a
        # self-reference cycle that would keep the whole f_back chain (the
        # caller's frame and its locals) alive until the cyclic GC runs.
        del current
39
+
40
+
41
class InitContext:
    """Context manager behind ``with ps.init():`` blocks.

    The first time a given call site runs, the block executes normally and,
    if it exits without an exception, the locals it introduced are recorded
    in the ``init_storage`` hook state. On later runs the recorded values are
    exposed through ``saved`` (and written back by ``restore_variables``)
    instead of re-running the block's body.
    """

    callsite: tuple[Any, int] | None
    frame: types.FrameType | None
    first_render: bool
    pre_keys: set[str]
    saved: dict[str, Any]

    def __init__(self):
        self.callsite = None
        self.frame = None
        self.first_render = False
        self.pre_keys = set()
        self.saved = {}

    def __enter__(self):
        # previous_frame() has to be invoked directly from this method: it
        # skips its own frame and ours, yielding the frame running `with`.
        owner = previous_frame()
        self.frame = owner
        self.pre_keys = set(owner.f_locals.keys())
        # Pair the line with the code object so equal line numbers in
        # different functions map to different entries.
        self.callsite = (owner.f_code, owner.f_lineno)

        entry = _init_hook().storage.get(self.callsite)
        self.first_render = entry is None
        self.saved = {} if entry is None else entry["vars"]
        return self

    def restore_variables(self):
        """Write the saved values back into the running frame's locals.

        Does nothing on a first render. Relies on the module-level
        ``PyFrame_LocalsToFast`` binding, which only exists when
        ``_CAN_USE_CPYTHON`` is true.
        """
        if self.first_render:
            return
        # The fallback lookup must stay a direct call from this method so
        # that previous_frame() resolves to our caller.
        target = self.frame
        if target is None:
            target = previous_frame()
        target.f_locals.update(self.saved)
        PyFrame_LocalsToFast(target, 1)

    def save(self, values: dict[str, Any]):
        """Record *values* for this call site (called by the fallback rewrite)."""
        self.saved = values
        assert self.callsite is not None, "callsite is None"
        _init_hook().storage[self.callsite] = {"vars": values}

    def _capture_new_locals(self) -> dict[str, Any]:
        """Return the locals created since ``__enter__``, minus this context."""
        frame = self.frame
        assert frame is not None, "frame is None"
        return {
            name: value
            for name, value in frame.f_locals.items()
            if name not in self.pre_keys and value is not self
        }

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_tb: Any,
    ) -> Literal[False]:
        """Persist the captured locals on a clean exit; never suppress errors."""
        if exc_type is None:
            captured = self._capture_new_locals()
            assert self.callsite is not None, "callsite None"
            _init_hook().storage[self.callsite] = {"vars": captured}
        # Release the frame reference held since __enter__.
        self.frame = None
        return False
111
+
112
+
113
def init() -> InitContext:
    """Create the context manager used as ``with ps.init():``.

    Each call returns a new ``InitContext``; the values it saves and restores
    are kept in the ``init_storage`` hook state, not on the context object.
    """
    context = InitContext()
    return context
115
+
116
+
117
+ # ---------------------------- AST rewriting -------------------------------
118
+
119
+
120
+ class InitCPythonRewriter(ast.NodeTransformer):
121
+ counter: int
122
+ _init_names: set[str]
123
+ _init_modules: set[str]
124
+
125
+ def __init__(self, init_names: set[str], init_modules: set[str]):
126
+ super().__init__()
127
+ self.counter = 0
128
+ self._init_names = init_names
129
+ self._init_modules = init_modules
130
+
131
+ @override
132
+ def visit_With(self, node: ast.With):
133
+ node = cast(ast.With, self.generic_visit(node))
134
+ if not node.items:
135
+ return node
136
+
137
+ item = node.items[0]
138
+ if self.is_init_call(item.context_expr):
139
+ ctx_name = f"_init_ctx_{self.counter}"
140
+ self.counter += 1
141
+ new_item = ast.withitem(
142
+ context_expr=item.context_expr,
143
+ optional_vars=ast.Name(id=ctx_name, ctx=ast.Store()),
144
+ )
145
+
146
+ restore_call = ast.Expr(
147
+ value=ast.Call(
148
+ func=ast.Attribute(
149
+ value=ast.Name(id=ctx_name, ctx=ast.Load()),
150
+ attr="restore_variables",
151
+ ctx=ast.Load(),
152
+ ),
153
+ args=[],
154
+ keywords=[],
155
+ )
156
+ )
157
+
158
+ new_if = ast.If(
159
+ test=ast.Attribute(
160
+ value=ast.Name(id=ctx_name, ctx=ast.Load()),
161
+ attr="first_render",
162
+ ctx=ast.Load(),
163
+ ),
164
+ body=node.body,
165
+ orelse=[restore_call],
166
+ )
167
+
168
+ return ast.With(
169
+ items=[new_item],
170
+ body=[new_if],
171
+ type_comment=getattr(node, "type_comment", None),
172
+ )
173
+
174
+ return node
175
+
176
+ def is_init_call(self, expr: ast.AST) -> bool:
177
+ if not isinstance(expr, ast.Call):
178
+ return False
179
+ func = expr.func
180
+ if isinstance(func, ast.Name) and func.id in self._init_names:
181
+ return True
182
+ if (
183
+ isinstance(func, ast.Attribute)
184
+ and isinstance(func.value, ast.Name)
185
+ and func.value.id in self._init_modules
186
+ and func.attr == "init"
187
+ ):
188
+ return True
189
+ return False
190
+
191
+
192
+ class InitFallbackRewriter(ast.NodeTransformer):
193
+ """Rewrite using explicit rebinding (portable, no LocalsToFast)."""
194
+
195
+ counter: int
196
+ _init_names: set[str]
197
+ _init_modules: set[str]
198
+
199
+ def __init__(self, init_names: set[str], init_modules: set[str]):
200
+ super().__init__()
201
+ self.counter = 0
202
+ self._init_names = init_names
203
+ self._init_modules = init_modules
204
+
205
+ @override
206
+ def visit_With(self, node: ast.With):
207
+ node = cast(ast.With, self.generic_visit(node))
208
+ if not node.items:
209
+ return node
210
+
211
+ item = node.items[0]
212
+ if not self.is_init_call(item.context_expr):
213
+ return node
214
+
215
+ ctx_name = f"_init_ctx_{self.counter}"
216
+ self.counter += 1
217
+ new_item = ast.withitem(
218
+ context_expr=item.context_expr,
219
+ optional_vars=ast.Name(id=ctx_name, ctx=ast.Store()),
220
+ )
221
+
222
+ assigned = _collect_assigned_names(node.body)
223
+
224
+ save_call = ast.Expr(
225
+ value=ast.Call(
226
+ func=ast.Attribute(
227
+ value=ast.Name(id=ctx_name, ctx=ast.Load()),
228
+ attr="save",
229
+ ctx=ast.Load(),
230
+ ),
231
+ args=[
232
+ ast.Dict(
233
+ keys=[ast.Constant(n) for n in assigned],
234
+ values=[ast.Name(id=n, ctx=ast.Load()) for n in assigned],
235
+ )
236
+ ],
237
+ keywords=[],
238
+ )
239
+ )
240
+
241
+ restore_assigns: Sequence[ast.stmt] = [
242
+ ast.Assign(
243
+ targets=[ast.Name(id=name, ctx=ast.Store())],
244
+ value=ast.Subscript(
245
+ value=ast.Attribute(
246
+ value=ast.Name(id=ctx_name, ctx=ast.Load()),
247
+ attr="saved",
248
+ ctx=ast.Load(),
249
+ ),
250
+ slice=ast.Constant(name),
251
+ ctx=ast.Load(),
252
+ ),
253
+ )
254
+ for name in assigned
255
+ ]
256
+
257
+ new_if = ast.If(
258
+ test=ast.Attribute(
259
+ value=ast.Name(id=ctx_name, ctx=ast.Load()),
260
+ attr="first_render",
261
+ ctx=ast.Load(),
262
+ ),
263
+ body=node.body + [save_call],
264
+ orelse=list(restore_assigns),
265
+ )
266
+
267
+ return ast.With(
268
+ items=[new_item],
269
+ body=[new_if],
270
+ type_comment=getattr(node, "type_comment", None),
271
+ )
272
+
273
+ def is_init_call(self, expr: ast.AST) -> bool:
274
+ if not isinstance(expr, ast.Call):
275
+ return False
276
+ func = expr.func
277
+ if isinstance(func, ast.Name) and func.id in self._init_names:
278
+ return True
279
+ if (
280
+ isinstance(func, ast.Attribute)
281
+ and isinstance(func.value, ast.Name)
282
+ and func.value.id in self._init_modules
283
+ and func.attr == "init"
284
+ ):
285
+ return True
286
+ return False
287
+
288
+
289
+ def _collect_assigned_names(body: list[ast.stmt]) -> list[str]:
290
+ names: set[str] = set()
291
+
292
+ def add_target(target: ast.AST):
293
+ if isinstance(target, ast.Name):
294
+ names.add(target.id)
295
+ elif isinstance(target, (ast.Tuple, ast.List)):
296
+ for elt in target.elts:
297
+ add_target(elt)
298
+
299
+ for stmt in body:
300
+ if isinstance(stmt, ast.Assign):
301
+ for target in stmt.targets:
302
+ add_target(target)
303
+ elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name):
304
+ names.add(stmt.target.id)
305
+ elif isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
306
+ names.add(stmt.name)
307
+ return list(names)
308
+
309
+
310
def rewrite_init_blocks(func: Callable[..., Any]) -> Callable[..., Any]:
    """Rewrite `with ps.init()` blocks in the provided function, if present.

    The function's source is parsed. Each `with <init>():` block is rewritten
    (by InitCPythonRewriter or InitFallbackRewriter, depending on
    ``_CAN_USE_CPYTHON``) and the result is recompiled into a new function.

    Args:
        func: The function to rewrite.

    Returns:
        ``func`` itself when no init block is found. Otherwise a new function
        with ``func``'s metadata copied over by ``functools.update_wrapper``.

    Raises:
        RuntimeError: if the source cannot be read, if an init block contains
            control flow, or if the recompiled function cannot be located.
    """
    source = _get_source(func)  # raises immediately if missing

    if "init" not in source:  # quick prefilter, allow alias detection later
        return func

    tree = ast.parse(source)

    init_names, init_modules = _resolve_init_bindings(func)

    # Remove decorators so the re-exec'd function isn't double-wrapped.
    _strip_decorators(tree, func.__name__)

    if not _contains_ps_init(tree, init_names, init_modules):
        return func

    if _has_disallowed_control_flow(tree, init_names, init_modules):
        raise RuntimeError(
            "ps.init blocks cannot contain control flow (if/for/while/try/with/match)"
        )

    rewriter: ast.NodeTransformer
    if _CAN_USE_CPYTHON:
        rewriter = InitCPythonRewriter(init_names, init_modules)
    else:
        rewriter = InitFallbackRewriter(init_names, init_modules)

    tree = rewriter.visit(tree)
    ast.fix_missing_locations(tree)

    source_file = inspect.getsourcefile(func)
    if source_file is not None:
        # The parsed source starts at line 1; shift it to the function's real
        # position so tracebacks point at the right lines of `source_file`.
        ast.increment_lineno(tree, func.__code__.co_firstlineno - 1)
    compiled = compile(tree, filename=source_file or "<rewrite>", mode="exec")

    # The module-level `def` in `compiled` is stored into local_ns, never
    # into the globals mapping.
    local_ns: dict[str, Any] = {}
    exec(compiled, _exec_globals(func), local_ns)
    rewritten = local_ns.get(func.__name__)
    if rewritten is None:
        raise RuntimeError(
            f"ps.init rewrite failed: function {func.__name__!r} not found in its source"
        )
    functools.update_wrapper(rewritten, func)
    return rewritten


def _strip_decorators(tree: ast.Module, name: str) -> None:
    """Clear the decorators of the top-level def called *name* in *tree*."""
    # Only top-level statements: a nested def that happens to share the name
    # must keep its decorators.
    for node in tree.body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == name:
            node.decorator_list = []


def _exec_globals(func: Callable[..., Any]) -> dict[str, Any]:
    """Return the globals mapping the rewritten function is created in."""
    nonlocals = inspect.getclosurevars(func).nonlocals
    if not nonlocals:
        # No closure: share the module's live globals, so names bound after
        # the rewrite (e.g. defined later in the module) still resolve.
        return func.__globals__
    # The function is recompiled at module level, so former closure variables
    # can only be supplied as globals. Use a snapshot copy so they do not leak
    # into the real module namespace.
    global_ns = dict(func.__globals__)
    global_ns.update(nonlocals)
    # Ensure `ps` resolves during exec.
    if "ps" not in global_ns:
        try:
            import pulse as ps

            global_ns["ps"] = ps
        except Exception:
            pass  # best-effort only
    return global_ns
366
+
367
+
368
+ def _contains_ps_init(
369
+ tree: ast.AST, init_names: set[str], init_modules: set[str]
370
+ ) -> bool:
371
+ checker = _InitCallChecker(init_names, init_modules)
372
+ return checker.contains_init(tree)
373
+
374
+
375
+ def _has_disallowed_control_flow(
376
+ tree: ast.AST, init_names: set[str], init_modules: set[str]
377
+ ) -> bool:
378
+ disallowed = (ast.If, ast.For, ast.While, ast.Try, ast.With, ast.Match)
379
+ checker = _InitCallChecker(init_names, init_modules)
380
+ for node in ast.walk(tree):
381
+ if isinstance(node, ast.With):
382
+ first = node.items[0] if node.items else None
383
+ if first and checker.is_init_call(first.context_expr):
384
+ continue
385
+ if isinstance(node, disallowed):
386
+ return True
387
+ return False
388
+
389
+
390
+ class _InitCallChecker:
391
+ init_names: set[str]
392
+ init_modules: set[str]
393
+
394
+ def __init__(self, init_names: set[str], init_modules: set[str]):
395
+ self.init_names = init_names
396
+ self.init_modules = init_modules
397
+
398
+ def is_init_call(self, expr: ast.AST) -> bool:
399
+ if not isinstance(expr, ast.Call):
400
+ return False
401
+ func = expr.func
402
+ if isinstance(func, ast.Name) and func.id in self.init_names:
403
+ return True
404
+ if (
405
+ isinstance(func, ast.Attribute)
406
+ and isinstance(func.value, ast.Name)
407
+ and func.value.id in self.init_modules
408
+ and func.attr == "init"
409
+ ):
410
+ return True
411
+ return False
412
+
413
+ def contains_init(self, tree: ast.AST) -> bool:
414
+ for node in ast.walk(tree):
415
+ if self.is_init_call(node):
416
+ return True
417
+ return False
418
+
419
+
420
+ def _get_source(func: Callable[..., Any]) -> str:
421
+ try:
422
+ return textwrap.dedent(inspect.getsource(func))
423
+ except OSError as exc:
424
+ src = getattr(func, "__source__", None)
425
+ if src is None:
426
+ raise RuntimeError(
427
+ f"ps.init rewrite failed: unable to read source ({exc})"
428
+ ) from exc
429
+ return textwrap.dedent(src)
430
+
431
+
432
def _resolve_init_bindings(func: Callable[..., Any]) -> tuple[set[str], set[str]]:
    """Find names/modules that resolve to pulse.init in the function scope.

    The function's module globals and its closure variables are searched.

    Returns:
        ``(init_names, init_modules)``. The first set holds names bound
        directly to ``init``. The second holds names bound to an object
        whose ``init`` attribute is ``init`` (e.g. ``import pulse as ps``).
    """
    closure = inspect.getclosurevars(func)
    direct: set[str] = set()
    via_attribute: set[str] = set()

    for namespace in (func.__globals__, closure.nonlocals, closure.globals):
        for binding, candidate in namespace.items():
            if candidate is init:
                direct.add(binding)
            try:
                attribute = getattr(candidate, "init", None)
            except Exception:
                # Objects whose attribute access raises are simply skipped.
                continue
            if attribute is init:
                via_attribute.add(binding)

    return direct, via_attribute
452
+
453
+
454
class InitState(HookState):
    """Hook state holding the values saved by ``ps.init`` blocks."""

    def __init__(self) -> None:
        # (code object, line number) of a `with ps.init()` statement ->
        # {"vars": {local name: value}}, as written by InitContext.
        entries: dict[tuple[Any, int], dict[str, Any]] = {}
        self.storage = entries

    @override
    def dispose(self) -> None:
        """Forget every saved entry."""
        self.storage.clear()