pulse-framework 0.1.40__py3-none-any.whl → 0.1.42__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pulse/hooks/init.py ADDED
@@ -0,0 +1,460 @@
1
+ from __future__ import annotations
2
+
3
+ import ast
4
+ import ctypes
5
+ import functools
6
+ import inspect
7
+ import textwrap
8
+ import types
9
+ from collections.abc import Callable, Sequence
10
+ from typing import Any, Literal, cast, override
11
+
12
+ from pulse.hooks.core import HookState, hooks
13
+
14
# Per-component storage for `with ps.init()` blocks, keyed by the call site
# (code object, line number). The factory is a lambda so that `InitState`,
# which is defined further down this module, is only looked up when the hook
# state is first created.
_init_hook = hooks.create("init_storage", lambda: InitState())

# `PyFrame_LocalsToFast` (CPython C API) copies a frame's `f_locals` mapping
# back into its fast locals; `InitContext.restore_variables` relies on it to
# rebind names in the caller. Look it up defensively: `ctypes.pythonapi` may
# itself be missing on non-CPython interpreters, and in that case the portable
# `InitFallbackRewriter` is selected instead of crashing at import time.
_pythonapi = getattr(ctypes, "pythonapi", None)
_CAN_USE_CPYTHON = _pythonapi is not None and hasattr(
    _pythonapi, "PyFrame_LocalsToFast"
)
if _CAN_USE_CPYTHON:
    PyFrame_LocalsToFast = _pythonapi.PyFrame_LocalsToFast
    PyFrame_LocalsToFast.argtypes = [ctypes.py_object, ctypes.c_int]
    PyFrame_LocalsToFast.restype = None
22
+
23
+
24
def previous_frame() -> types.FrameType:
    """Return the frame two levels above this helper on the call stack.

    The first hop reaches whichever function called ``previous_frame`` (for
    example ``InitContext.__enter__``). The second hop reaches *that*
    function's caller, which is the frame the caller actually wants. Every
    hop is asserted to exist.
    """
    here = inspect.currentframe()
    assert here is not None, "currentframe() returned None"

    direct_caller = here.f_back
    assert direct_caller is not None, "f_back is None"

    target = direct_caller.f_back
    assert target is not None, "f_back.f_back is None"
    return target
39
+
40
+
41
+ class InitContext:
42
+ """Context that captures locals on first render and restores thereafter."""
43
+
44
+ callsite: tuple[Any, int] | None
45
+ frame: types.FrameType | None
46
+ first_render: bool
47
+ pre_keys: set[str]
48
+ saved: dict[str, Any]
49
+
50
+ def __init__(self):
51
+ self.callsite = None
52
+ self.frame = None
53
+ self.first_render = False
54
+ self.pre_keys = set()
55
+ self.saved = {}
56
+
57
+ def __enter__(self):
58
+ self.frame = previous_frame()
59
+ self.pre_keys = set(self.frame.f_locals.keys())
60
+ # Use code object to disambiguate identical line numbers in different fns.
61
+ self.callsite = (self.frame.f_code, self.frame.f_lineno)
62
+
63
+ storage = _init_hook().storage
64
+ entry = storage.get(self.callsite)
65
+ if entry is None:
66
+ self.first_render = True
67
+ self.saved = {}
68
+ else:
69
+ self.first_render = False
70
+ self.saved = entry["vars"]
71
+ return self
72
+
73
+ def restore_variables(self):
74
+ if self.first_render:
75
+ return
76
+ frame = self.frame if self.frame is not None else previous_frame()
77
+ frame.f_locals.update(self.saved)
78
+ PyFrame_LocalsToFast(frame, 1)
79
+
80
+ def save(self, values: dict[str, Any]):
81
+ self.saved = values
82
+ assert self.callsite is not None, "callsite is None"
83
+ storage = _init_hook().storage
84
+ storage[self.callsite] = {"vars": values}
85
+
86
+ def _capture_new_locals(self) -> dict[str, Any]:
87
+ frame = self.frame
88
+ assert frame is not None, "frame is None"
89
+ captured = {}
90
+ for name, value in frame.f_locals.items():
91
+ if name in self.pre_keys:
92
+ continue
93
+ if value is self:
94
+ continue
95
+ captured[name] = value
96
+ return captured
97
+
98
+ def __exit__(
99
+ self,
100
+ exc_type: type[BaseException] | None,
101
+ exc_value: BaseException | None,
102
+ exc_tb: Any,
103
+ ) -> Literal[False]:
104
+ if exc_type is None:
105
+ captured = self._capture_new_locals()
106
+ assert self.callsite is not None, "callsite None"
107
+ storage = _init_hook().storage
108
+ storage[self.callsite] = {"vars": captured}
109
+ self.frame = None
110
+ return False
111
+
112
+
113
def init() -> InitContext:
    """Create a fresh context object for a ``with ps.init():`` block.

    A new :class:`InitContext` is returned on every call. Values survive
    across renders because the context reads and writes the init hook's
    storage, not because this object is reused.
    """
    context = InitContext()
    return context
115
+
116
+
117
+ # ---------------------------- AST rewriting -------------------------------
118
+
119
+
120
+ class InitCPythonRewriter(ast.NodeTransformer):
121
+ counter: int
122
+ _init_names: set[str]
123
+ _init_modules: set[str]
124
+
125
+ def __init__(self, init_names: set[str], init_modules: set[str]):
126
+ super().__init__()
127
+ self.counter = 0
128
+ self._init_names = init_names
129
+ self._init_modules = init_modules
130
+
131
+ @override
132
+ def visit_With(self, node: ast.With):
133
+ node = cast(ast.With, self.generic_visit(node))
134
+ if not node.items:
135
+ return node
136
+
137
+ item = node.items[0]
138
+ if self.is_init_call(item.context_expr):
139
+ ctx_name = f"_init_ctx_{self.counter}"
140
+ self.counter += 1
141
+ new_item = ast.withitem(
142
+ context_expr=item.context_expr,
143
+ optional_vars=ast.Name(id=ctx_name, ctx=ast.Store()),
144
+ )
145
+
146
+ restore_call = ast.Expr(
147
+ value=ast.Call(
148
+ func=ast.Attribute(
149
+ value=ast.Name(id=ctx_name, ctx=ast.Load()),
150
+ attr="restore_variables",
151
+ ctx=ast.Load(),
152
+ ),
153
+ args=[],
154
+ keywords=[],
155
+ )
156
+ )
157
+
158
+ new_if = ast.If(
159
+ test=ast.Attribute(
160
+ value=ast.Name(id=ctx_name, ctx=ast.Load()),
161
+ attr="first_render",
162
+ ctx=ast.Load(),
163
+ ),
164
+ body=node.body,
165
+ orelse=[restore_call],
166
+ )
167
+
168
+ return ast.With(
169
+ items=[new_item],
170
+ body=[new_if],
171
+ type_comment=getattr(node, "type_comment", None),
172
+ )
173
+
174
+ return node
175
+
176
+ def is_init_call(self, expr: ast.AST) -> bool:
177
+ if not isinstance(expr, ast.Call):
178
+ return False
179
+ func = expr.func
180
+ if isinstance(func, ast.Name) and func.id in self._init_names:
181
+ return True
182
+ if (
183
+ isinstance(func, ast.Attribute)
184
+ and isinstance(func.value, ast.Name)
185
+ and func.value.id in self._init_modules
186
+ and func.attr == "init"
187
+ ):
188
+ return True
189
+ return False
190
+
191
+
192
class InitFallbackRewriter(ast.NodeTransformer):
    """Rewrite using explicit rebinding (portable, no LocalsToFast).

    Each matching block becomes::

        with ps.init() as _init_ctx_N, <any further items>:
            <user `as` target> = _init_ctx_N   # only if one was given
            if _init_ctx_N.first_render:
                <original body>
                _init_ctx_N.save({"a": a, ...})
            else:
                a = _init_ctx_N.saved["a"]
                ...

    where ``a, ...`` are the names returned by ``_collect_assigned_names`` for
    the block body. Additional context managers in the same ``with`` statement
    and the user's ``as`` target are preserved, and the generated statements
    keep the original ``with`` statement's location.
    """

    counter: int  # suffix for the next generated `_init_ctx_N` name
    _init_names: set[str]  # bare names bound to `init`
    _init_modules: set[str]  # names whose `.init` attribute is `init`

    def __init__(self, init_names: set[str], init_modules: set[str]):
        super().__init__()
        self.counter = 0
        self._init_names = init_names
        self._init_modules = init_modules

    @override
    def visit_With(self, node: ast.With):
        node = cast(ast.With, self.generic_visit(node))
        if not node.items:
            return node

        item = node.items[0]
        if not self.is_init_call(item.context_expr):
            return node

        ctx_name = f"_init_ctx_{self.counter}"
        self.counter += 1
        new_item = ast.withitem(
            context_expr=item.context_expr,
            optional_vars=ast.Name(id=ctx_name, ctx=ast.Store()),
        )

        assigned = _collect_assigned_names(node.body)

        # `_init_ctx_N.save({"name": name, ...})` at the end of the first render.
        save_call = ast.Expr(
            value=ast.Call(
                func=ast.Attribute(
                    value=ast.Name(id=ctx_name, ctx=ast.Load()),
                    attr="save",
                    ctx=ast.Load(),
                ),
                args=[
                    ast.Dict(
                        keys=[ast.Constant(n) for n in assigned],
                        values=[ast.Name(id=n, ctx=ast.Load()) for n in assigned],
                    )
                ],
                keywords=[],
            )
        )

        # `name = _init_ctx_N.saved["name"]` for every collected name.
        restore_assigns: Sequence[ast.stmt] = [
            ast.Assign(
                targets=[ast.Name(id=name, ctx=ast.Store())],
                value=ast.Subscript(
                    value=ast.Attribute(
                        value=ast.Name(id=ctx_name, ctx=ast.Load()),
                        attr="saved",
                        ctx=ast.Load(),
                    ),
                    slice=ast.Constant(name),
                    ctx=ast.Load(),
                ),
            )
            for name in assigned
        ]

        new_if = ast.If(
            test=ast.Attribute(
                value=ast.Name(id=ctx_name, ctx=ast.Load()),
                attr="first_render",
                ctx=ast.Load(),
            ),
            body=node.body + [save_call],
            orelse=list(restore_assigns),
        )
        ast.copy_location(new_if, node)

        new_body: list[ast.stmt] = []
        if item.optional_vars is not None:
            # Keep the user's `with ps.init() as name:` binding working.
            alias = ast.Assign(
                targets=[item.optional_vars],
                value=ast.Name(id=ctx_name, ctx=ast.Load()),
            )
            new_body.append(ast.copy_location(alias, node))
        new_body.append(new_if)

        new_with = ast.With(
            # Preserve any further context managers instead of dropping them.
            items=[new_item, *node.items[1:]],
            body=new_body,
            type_comment=getattr(node, "type_comment", None),
        )
        # Keep the original statement's position (tracebacks, f_lineno).
        return ast.copy_location(new_with, node)

    def is_init_call(self, expr: ast.AST) -> bool:
        """Return True for ``<init name>(...)`` or ``<init module>.init(...)``."""
        if not isinstance(expr, ast.Call):
            return False
        func = expr.func
        if isinstance(func, ast.Name) and func.id in self._init_names:
            return True
        if (
            isinstance(func, ast.Attribute)
            and isinstance(func.value, ast.Name)
            and func.value.id in self._init_modules
            and func.attr == "init"
        ):
            return True
        return False
287
+
288
+
289
+ def _collect_assigned_names(body: list[ast.stmt]) -> list[str]:
290
+ names: set[str] = set()
291
+
292
+ def add_target(target: ast.AST):
293
+ if isinstance(target, ast.Name):
294
+ names.add(target.id)
295
+ elif isinstance(target, (ast.Tuple, ast.List)):
296
+ for elt in target.elts:
297
+ add_target(elt)
298
+
299
+ for stmt in body:
300
+ if isinstance(stmt, ast.Assign):
301
+ for target in stmt.targets:
302
+ add_target(target)
303
+ elif isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name):
304
+ names.add(stmt.target.id)
305
+ elif isinstance(stmt, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
306
+ names.add(stmt.name)
307
+ return list(names)
308
+
309
+
310
def rewrite_init_blocks(func: Callable[..., Any]) -> Callable[..., Any]:
    """Rewrite `with ps.init()` blocks in the provided function, if present.

    Returns *func* unchanged when its source contains no init block.
    Otherwise the source is parsed, every init block is rewritten so its body
    only runs on the first render, and a newly compiled function is returned.
    Wrapper metadata (name, docstring, ``__wrapped__``, ...) is copied from
    *func* via ``functools.update_wrapper``.

    Raises:
        RuntimeError: if the source cannot be read, or if an init block
            contains control flow.
    """

    source = _get_source(func)  # raises immediately if missing

    # Cheap textual prefilter. NOTE: it also skips functions that only use an
    # alias whose name lacks "init" (e.g. `from pulse import init as once`),
    # even though `_resolve_init_bindings` would recognize such an alias.
    if "init" not in source:
        return func

    tree = ast.parse(source)

    init_names, init_modules = _resolve_init_bindings(func)

    # Remove decorators so the re-exec'd function isn't double-wrapped. Only
    # the top-level definition is touched; nested helpers keep theirs.
    for node in tree.body:
        if (
            isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
            and node.name == func.__name__
        ):
            node.decorator_list = []

    if not _contains_ps_init(tree, init_names, init_modules):
        return func

    if _has_disallowed_control_flow(tree, init_names, init_modules):
        raise RuntimeError(
            "ps.init blocks cannot contain control flow (if/for/while/try/with/match)"
        )

    rewriter: ast.NodeTransformer
    if _CAN_USE_CPYTHON:
        rewriter = InitCPythonRewriter(init_names, init_modules)
    else:
        rewriter = InitFallbackRewriter(init_names, init_modules)

    tree = rewriter.visit(tree)
    ast.fix_missing_locations(tree)

    filename = inspect.getsourcefile(func) or "<rewrite>"
    compiled = compile(tree, filename=filename, mode="exec")

    # The source is compiled as a top-level def, so the function's closure
    # variables must be visible as globals while it executes.
    global_ns = dict(func.__globals__)
    closure_vars = inspect.getclosurevars(func)
    global_ns.update(closure_vars.nonlocals)
    # Ensure `ps` resolves during exec.
    injected_ps = False
    if "ps" not in global_ns:
        try:
            import pulse as ps

            global_ns["ps"] = ps
            injected_ps = True
        except Exception:
            pass
    local_ns: dict[str, Any] = {}
    exec(compiled, global_ns, local_ns)
    rewritten = local_ns.get(func.__name__) or global_ns[func.__name__]

    if not closure_vars.nonlocals and not injected_ps:
        # `global_ns` is only a snapshot. When nothing had to be injected into
        # it, bind the rewritten code to the module's real globals instead so
        # names defined or rebound later in the module (e.g. a helper defined
        # below this function) resolve exactly as for the original function.
        rebound = types.FunctionType(
            rewritten.__code__,
            func.__globals__,
            rewritten.__name__,
            rewritten.__defaults__,
            rewritten.__closure__,
        )
        rebound.__kwdefaults__ = rewritten.__kwdefaults__
        rewritten = rebound
    # Otherwise the function keeps reading closure values from the snapshot,
    # so later rebinding of module globals is not visible to it.

    functools.update_wrapper(rewritten, func)
    return rewritten
366
+
367
+
368
+ def _contains_ps_init(
369
+ tree: ast.AST, init_names: set[str], init_modules: set[str]
370
+ ) -> bool:
371
+ checker = _InitCallChecker(init_names, init_modules)
372
+ return checker.contains_init(tree)
373
+
374
+
375
+ def _has_disallowed_control_flow(
376
+ tree: ast.AST, init_names: set[str], init_modules: set[str]
377
+ ) -> bool:
378
+ disallowed = (ast.If, ast.For, ast.While, ast.Try, ast.With, ast.Match)
379
+ checker = _InitCallChecker(init_names, init_modules)
380
+ for node in ast.walk(tree):
381
+ if isinstance(node, ast.With):
382
+ first = node.items[0] if node.items else None
383
+ if first and checker.is_init_call(first.context_expr):
384
+ continue
385
+ if isinstance(node, disallowed):
386
+ return True
387
+ return False
388
+
389
+
390
+ class _InitCallChecker:
391
+ init_names: set[str]
392
+ init_modules: set[str]
393
+
394
+ def __init__(self, init_names: set[str], init_modules: set[str]):
395
+ self.init_names = init_names
396
+ self.init_modules = init_modules
397
+
398
+ def is_init_call(self, expr: ast.AST) -> bool:
399
+ if not isinstance(expr, ast.Call):
400
+ return False
401
+ func = expr.func
402
+ if isinstance(func, ast.Name) and func.id in self.init_names:
403
+ return True
404
+ if (
405
+ isinstance(func, ast.Attribute)
406
+ and isinstance(func.value, ast.Name)
407
+ and func.value.id in self.init_modules
408
+ and func.attr == "init"
409
+ ):
410
+ return True
411
+ return False
412
+
413
+ def contains_init(self, tree: ast.AST) -> bool:
414
+ for node in ast.walk(tree):
415
+ if self.is_init_call(node):
416
+ return True
417
+ return False
418
+
419
+
420
+ def _get_source(func: Callable[..., Any]) -> str:
421
+ try:
422
+ return textwrap.dedent(inspect.getsource(func))
423
+ except OSError as exc:
424
+ src = getattr(func, "__source__", None)
425
+ if src is None:
426
+ raise RuntimeError(
427
+ f"ps.init rewrite failed: unable to read source ({exc})"
428
+ ) from exc
429
+ return textwrap.dedent(src)
430
+
431
+
432
def _resolve_init_bindings(func: Callable[..., Any]) -> tuple[set[str], set[str]]:
    """Collect the identifiers through which *func* can reach ``init``.

    Scans the function's module globals and its closure variables, and
    returns ``(init_names, init_modules)``:

    - ``init_names``: names bound directly to this module's ``init``;
    - ``init_modules``: names whose ``.init`` attribute is ``init``.

    Objects whose ``init`` attribute lookup raises are skipped.
    """
    closure = inspect.getclosurevars(func)
    direct: set[str] = set()
    via_module: set[str] = set()

    for namespace in (func.__globals__, closure.nonlocals, closure.globals):
        for binding, value in namespace.items():
            if value is init:
                direct.add(binding)
            try:
                attr = getattr(value, "init", None)
            except Exception:
                continue
            if attr is init:
                via_module.add(binding)

    return direct, via_module
452
+
453
+
454
class InitState(HookState):
    """Hook state that stores the values captured by ``with ps.init()`` blocks.

    ``storage`` maps a call site ``(code object, line number)`` to a dict of
    the form ``{"vars": {name: value, ...}}``.
    """

    def __init__(self) -> None:
        # Run HookState's initializer as well, as StatesHookState does;
        # skipping it would leave any base-class attributes unset.
        super().__init__()
        self.storage: dict[tuple[Any, int], dict[str, Any]] = {}

    @override
    def dispose(self) -> None:
        """Forget every saved call site."""
        self.storage.clear()
pulse/hooks/states.py CHANGED
@@ -17,48 +17,106 @@ S9 = TypeVar("S9", bound=State)
17
17
  S10 = TypeVar("S10", bound=State)
18
18
 
19
19
 
20
class StateNamespace:
    """The states created by ``pulse.states`` for one ``key`` in a component.

    The first call creates the state instances; later calls return the same
    instances. ``called`` guards against a second call for the same key
    within one render and is reset by the owning hook state.
    """

    __slots__: tuple[str, ...] = ("states", "key", "called")
    states: tuple[State, ...]
    key: str | None
    called: bool

    def __init__(self, key: str | None) -> None:
        self.states = ()
        self.key = key
        self.called = False

    def _key_suffix(self) -> str:
        """Describe this namespace's key for error messages."""
        if self.key is not None:
            return f" with key='{self.key}'"
        return " without a key"

    @staticmethod
    def _dispose_quietly(state: State) -> None:
        """Dispose *state* unless it has already been disposed."""
        try:
            if not state.__disposed__:
                state.dispose()
        except RuntimeError:
            # Already disposed, ignore
            pass

    def ensure_not_called(self) -> None:
        """Raise ``RuntimeError`` if this key was already used in this render."""
        if self.called:
            raise RuntimeError(
                f"`pulse.states` can only be called once per component render{self._key_suffix()}"
            )

    def get_or_create_states(
        self, args: tuple[State | Callable[[], State], ...]
    ) -> tuple[State, ...]:
        """Return the stored states, creating them from *args* on first use.

        On reuse, the argument count must match the original call, and any
        ``State`` instance passed directly that is not one of the stored
        instances is disposed, since it will never be used.

        Raises:
            RuntimeError: if the number of arguments changed between renders.
        """
        if self.states:
            existing_states = self.states
            if len(args) != len(existing_states):
                raise RuntimeError(
                    f"`pulse.states` called with {len(args)} argument(s) but was previously "
                    + f"called with {len(existing_states)} argument(s){self._key_suffix()}. "
                    + "The number of arguments must remain consistent across renders."
                )
            # Compare by identity: whether an instance is the stored one must
            # not depend on how State implements __eq__/__hash__.
            for arg in args:
                if isinstance(arg, State) and not any(
                    arg is existing for existing in existing_states
                ):
                    self._dispose_quietly(arg)
            return existing_states

        created: list[State] = []
        try:
            for arg in args:
                created.append(_instantiate_state(arg))
        except BaseException:
            # Don't leak states this call already constructed from factories
            # (instances passed in directly are left to their owner).
            for arg, state in zip(args, created):
                if state is not arg:
                    self._dispose_quietly(state)
            raise
        self.states = tuple(created)
        return self.states

    def dispose(self) -> None:
        """Dispose every stored state and forget them."""
        for state in self.states:
            self._dispose_quietly(state)
        self.states = ()
84
+
85
+
20
86
  class StatesHookState(HookState):
21
- __slots__: tuple[str, ...] = ("initialized", "states", "key", "_called")
22
- initialized: bool
23
- _called: bool
87
+ __slots__: tuple[str, ...] = ("namespaces",)
88
+ namespaces: dict[str | None, StateNamespace]
24
89
 
25
90
  def __init__(self) -> None:
26
91
  super().__init__()
27
- self.initialized = False
28
- self.states: tuple[State, ...] = ()
29
- self.key: str | None = None
30
- self._called = False
92
+ self.namespaces = {}
31
93
 
32
94
  @override
33
95
  def on_render_start(self, render_cycle: int) -> None:
34
96
  super().on_render_start(render_cycle)
35
- self._called = False
36
-
37
- def replace(self, states: list[State], key: str | None) -> None:
38
- self.dispose_states()
39
- self.states = tuple(states)
40
- self.key = key
41
- self.initialized = True
42
-
43
- def dispose_states(self) -> None:
44
- for state in self.states:
45
- state.dispose()
46
- self.states = ()
47
- self.initialized = False
48
- self.key = None
97
+ if self.namespaces:
98
+ for namespace in self.namespaces.values():
99
+ namespace.called = False
100
+
101
+ def get_namespace(self, key: str | None) -> StateNamespace:
102
+ if key not in self.namespaces:
103
+ self.namespaces[key] = StateNamespace(key)
104
+ return self.namespaces[key]
105
+
106
+ def get_or_create_states(
107
+ self, args: tuple[State | Callable[[], State], ...], key: str | None
108
+ ) -> tuple[State, ...]:
109
+ namespace = self.get_namespace(key)
110
+ namespace.ensure_not_called()
111
+ result = namespace.get_or_create_states(args)
112
+ namespace.called = True
113
+ return result
49
114
 
50
115
  @override
51
116
  def dispose(self) -> None:
52
- self.dispose_states()
53
-
54
- def ensure_not_called(self) -> None:
55
- if self._called:
56
- raise RuntimeError(
57
- "`pulse.states` can only be called once per component render"
58
- )
59
-
60
- def mark_called(self) -> None:
61
- self._called = True
117
+ for namespace in self.namespaces.values():
118
+ namespace.dispose()
119
+ self.namespaces.clear()
62
120
 
63
121
 
64
122
  def _instantiate_state(arg: State | Callable[[], State]) -> State:
@@ -219,29 +277,8 @@ def states(*args: S | Callable[[], S], key: str | None = ...) -> tuple[S, ...]:
219
277
 
220
278
 
221
279
  def states(*args: State | Callable[[], State], key: str | None = None):
222
- state = _states_hook()
223
- state.ensure_not_called()
224
-
225
- if not state.initialized:
226
- instances = [_instantiate_state(arg) for arg in args]
227
- state.replace(instances, key)
228
- state.mark_called()
229
- result = state.states
230
- return result[0] if len(result) == 1 else result
231
-
232
- if key is not None and key != state.key:
233
- instances = [_instantiate_state(arg) for arg in args]
234
- state.replace(instances, key)
235
- state.mark_called()
236
- result = state.states
237
- return result[0] if len(result) == 1 else result
238
-
239
- for arg in args:
240
- if isinstance(arg, State):
241
- arg.dispose()
242
-
243
- state.mark_called()
244
- result = state.states
280
+ hook_state = _states_hook()
281
+ result = hook_state.get_or_create_states(args, key)
245
282
  return result[0] if len(result) == 1 else result
246
283
 
247
284
 
pulse/messages.py CHANGED
@@ -175,6 +175,6 @@ class Directives(TypedDict):
175
175
  socketio: SocketIODirectives
176
176
 
177
177
 
178
- class PrerenderResult(TypedDict):
178
+ class Prerender(TypedDict):
179
179
  views: dict[str, ServerInitMessage | None]
180
180
  directives: Directives