procfunc 0.30.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76)
  1. procfunc/__init__.py +87 -0
  2. procfunc/color.py +57 -0
  3. procfunc/compute_graph/__init__.py +28 -0
  4. procfunc/compute_graph/compute_graph.py +115 -0
  5. procfunc/compute_graph/node.py +200 -0
  6. procfunc/compute_graph/operators_info.py +92 -0
  7. procfunc/compute_graph/proxy.py +173 -0
  8. procfunc/compute_graph/util.py +282 -0
  9. procfunc/context.py +115 -0
  10. procfunc/control.py +174 -0
  11. procfunc/nodes/__init__.py +66 -0
  12. procfunc/nodes/bindings_util.py +196 -0
  13. procfunc/nodes/bpy_node_info.py +280 -0
  14. procfunc/nodes/compositor.py +2242 -0
  15. procfunc/nodes/execute/construct_nodes.py +571 -0
  16. procfunc/nodes/execute/construct_special_cases.py +246 -0
  17. procfunc/nodes/execute/execute.py +548 -0
  18. procfunc/nodes/execute/infer_runtime_data_type.py +195 -0
  19. procfunc/nodes/execute/util.py +247 -0
  20. procfunc/nodes/func.py +1417 -0
  21. procfunc/nodes/geo.py +4240 -0
  22. procfunc/nodes/manifest.json +8769 -0
  23. procfunc/nodes/math.py +644 -0
  24. procfunc/nodes/node_function.py +160 -0
  25. procfunc/nodes/shader.py +2359 -0
  26. procfunc/nodes/types.py +347 -0
  27. procfunc/ops/__init__.py +35 -0
  28. procfunc/ops/_util.py +275 -0
  29. procfunc/ops/addons.py +59 -0
  30. procfunc/ops/attr.py +426 -0
  31. procfunc/ops/collection.py +90 -0
  32. procfunc/ops/curve.py +18 -0
  33. procfunc/ops/file.py +126 -0
  34. procfunc/ops/manifest.json +39149 -0
  35. procfunc/ops/mesh.py +1510 -0
  36. procfunc/ops/modifier.py +603 -0
  37. procfunc/ops/object.py +258 -0
  38. procfunc/ops/primitives/__init__.py +31 -0
  39. procfunc/ops/primitives/camera.py +45 -0
  40. procfunc/ops/primitives/curve.py +71 -0
  41. procfunc/ops/primitives/light.py +114 -0
  42. procfunc/ops/primitives/mesh.py +358 -0
  43. procfunc/ops/uv.py +271 -0
  44. procfunc/random.py +247 -0
  45. procfunc/tracer/__init__.py +43 -0
  46. procfunc/tracer/decorator.py +121 -0
  47. procfunc/tracer/patch.py +494 -0
  48. procfunc/tracer/proxy.py +127 -0
  49. procfunc/tracer/trace.py +222 -0
  50. procfunc/transforms/__init__.py +49 -0
  51. procfunc/transforms/cleanup.py +214 -0
  52. procfunc/transforms/convert.py +20 -0
  53. procfunc/transforms/distribution.py +191 -0
  54. procfunc/transforms/extract_materials.py +116 -0
  55. procfunc/transforms/infer_distribution.py +326 -0
  56. procfunc/transforms/parameters.py +15 -0
  57. procfunc/transforms/util.py +35 -0
  58. procfunc/transpiler/__init__.py +24 -0
  59. procfunc/transpiler/bpy_to_computegraph.py +1348 -0
  60. procfunc/transpiler/codegen.py +919 -0
  61. procfunc/transpiler/identifiers.py +595 -0
  62. procfunc/transpiler/main.py +299 -0
  63. procfunc/types.py +380 -0
  64. procfunc/util/__init__.py +0 -0
  65. procfunc/util/bpy_info.py +145 -0
  66. procfunc/util/camera.py +0 -0
  67. procfunc/util/keyframe.py +70 -0
  68. procfunc/util/log.py +96 -0
  69. procfunc/util/manifest.py +121 -0
  70. procfunc/util/pytree.py +343 -0
  71. procfunc/util/teardown.py +37 -0
  72. procfunc-0.30.0.dist-info/METADATA +120 -0
  73. procfunc-0.30.0.dist-info/RECORD +76 -0
  74. procfunc-0.30.0.dist-info/WHEEL +5 -0
  75. procfunc-0.30.0.dist-info/licenses/LICENSE.md +11 -0
  76. procfunc-0.30.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,494 @@
1
+ import enum
2
+ import functools
3
+ import logging
4
+ from dataclasses import dataclass
5
+ from types import ModuleType
6
+ from typing import Any, Callable, TypeAlias, TypeVar
7
+
8
+ from procfunc import compute_graph as cg
9
+ from procfunc.tracer.proxy import RngProxy
10
+ from procfunc.util import pytree
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+ PATCHING_FLAG_ATTR = "_gen_tracing_is_patched"
15
+ _MODULE_CALLFUNC = "__call__"
16
+ _MODULE_GENFUNC = "_generate"
17
+
18
+ Tfunc = TypeVar("Tfunc")
19
+
20
+ TWrapperCreate: TypeAlias = Callable[["PatchFunctionTarget", "Patcher"], Callable]
21
+
22
+
23
class TraceLevel(enum.IntEnum):
    """
    Granularity at which tracing records operations.

    Higher = coarser. Lower = finer.

    A registered wrap target whose own level is <= the Patcher's requested
    level is recorded as a single leaf node in the resulting graph; targets
    above the requested level are executed so that their internals are traced.
    """

    GRAMMAR = 100  # _distribution functions
    RANDOM_CONTROL = 60  # pf.control.choice
    RANDOM_PARAMS = 50  # np.random calls, all other operations of numbers.
    GENERATORS = 40
    NODEGROUPS = 30  # @node_function
    # finest level; Patcher also assigns it to banned (autopatch_remove) targets
    PRIMITIVES = 20
36
+
37
+
38
@dataclass
class PatchFunctionTarget:
    """A callable, looked up as ``frame[name]``, that the Patcher replaces while tracing."""

    # namespace holding the callable, e.g. a module's __dict__
    frame: dict
    # key of the callable inside ``frame``
    name: str
    # Compared to the user's requested trace level to decide whether this target
    # will be a leaf: at or below the requested level the call is recorded as a
    # single node, above it the function is executed and its internals traced.
    trace_level: TraceLevel
    # convert positional args to kwargs (cg.normalize_args_to_kwargs) before recording the call
    normalize: bool = True
    # run the real function directly when it is called without any proxy arguments
    allow_exec: bool = False
    # optional factory that builds the wrapper instead of the default leaf/non-leaf wrappers
    custom_trace_wrapper_create: TWrapperCreate | None = None
    source_name: str | None = None  # used for logging/debugging
    mutates: list[str] | None = None  # list of argument names that the function mutates
48
+
49
+
50
@dataclass
class Patch:
    """
    Record of one replacement of the ``fn_name`` entry of ``frame``.

    This base class only stores the before/after values; subclasses decide how
    the value is written back (item assignment vs. attribute assignment).
    """

    frame: dict[str, Any]
    fn_name: str
    orig_fn: Callable
    patched_fn: Callable

    def patch(self):
        """Install ``patched_fn``; subclasses must override."""
        raise NotImplementedError

    def unpatch(self):
        """Restore ``orig_fn``; subclasses must override."""
        raise NotImplementedError
62
+
63
+
64
@dataclass
class PatchSetItem(Patch):
    """Patch written with mapping item assignment: ``frame[fn_name] = value``."""

    def _install(self, fn: Callable) -> None:
        # patch() and unpatch() differ only in which callable they write back.
        self.frame[self.fn_name] = fn

    def patch(self):
        self._install(self.patched_fn)

    def unpatch(self):
        self._install(self.orig_fn)
71
+
72
+
73
@dataclass
class PatchSetAttr(Patch):
    """Patch written with ``setattr`` on an object/class ``frame``."""

    def patch(self):
        target_obj, attr_name, new_fn = self.frame, self.fn_name, self.patched_fn
        logger.debug(
            f"Patching {attr_name} in {id(target_obj)} from {self.orig_fn} to {new_fn}"
        )
        setattr(target_obj, attr_name, new_fn)

    def unpatch(self):
        target_obj, attr_name = self.frame, self.fn_name
        setattr(target_obj, attr_name, self.orig_fn)
83
+
84
+
85
def _targets_from_module(
    module: ModuleType,
    seen: set[int],
    trace_level: TraceLevel,
    normalize: bool = False,
    allow_exec: bool = False,
) -> list[PatchFunctionTarget]:
    """
    Collect wrap/ban targets from a module such as `math`, `numpy` or `blendfunc`.

    Every function gathered here is treated as a primitive: its internals are
    never traced.

    Args:
        module: the module to gather targets from.
        seen: ids of callables already claimed; matching objects are skipped.
            Updated in place with every callable gathered here.
        trace_level: level assigned to every gathered target.
        normalize: whether to try to convert args to kwargs where available.
            Fails for some modules such as numpy.
        allow_exec: whether the function may be executed when it is called
            with non-dynamic args.

    Returns:
        One PatchFunctionTarget per public callable in ``module``. Names
        starting with ``_``, classes, non-callables, already-seen objects, and
        names missing from ``__all__`` (when the module defines it) are skipped.
    """
    exports = getattr(module, "__all__", None)

    def _is_candidate(name: str, value: Any) -> bool:
        if name.startswith("_"):
            return False
        if issubclass(type(value), type):
            return False
        if not callable(value):
            return False
        if id(value) in seen:
            return False
        return exports is None or name in exports

    targets: list[PatchFunctionTarget] = []
    for name, value in module.__dict__.items():
        if not _is_candidate(name, value):
            continue
        # Claim the object right away so aliases later in this module (or in
        # later modules sharing ``seen``) are not gathered twice.
        seen.add(id(value))
        targets.append(
            PatchFunctionTarget(
                frame=module.__dict__,
                name=name,
                source_name=module.__name__,
                normalize=normalize,
                allow_exec=allow_exec,
                trace_level=trace_level,
            )
        )
    return targets
139
+
140
+
141
+ class Patcher:
142
+ """
143
+ A 'patch' is an invasive modification of another module (e.g numpy, math) which makes their functions behave differently for tracing
144
+
145
+ This class creates patches and records them so that we can undo them later
146
+ """
147
+
148
+ def __init__(
149
+ self,
150
+ trace_level: TraceLevel,
151
+ autopatch_wrap_modules: list[tuple[ModuleType, bool, TraceLevel]] | None = None,
152
+ autopatch_remove_modules: list[ModuleType] | None = None,
153
+ patch_functions: list[PatchFunctionTarget] | None = None,
154
+ search_scopes: list[dict] | None = None,
155
+ ):
156
+ self.trace_level = trace_level
157
+ if autopatch_wrap_modules is None:
158
+ autopatch_wrap_modules = []
159
+ if autopatch_remove_modules is None:
160
+ autopatch_remove_modules = []
161
+ if patch_functions is None:
162
+ patch_functions = []
163
+ if search_scopes is None:
164
+ search_scopes = []
165
+
166
+ self.patched: list[Patch] = []
167
+ self.visited_frames: set[int] = set()
168
+
169
+ modules_seen = set()
170
+
171
+ autowrap_targets = []
172
+ for m, allow_exec, mod_trace_level in autopatch_wrap_modules:
173
+ autowrap_targets += _targets_from_module(
174
+ m,
175
+ modules_seen,
176
+ trace_level=mod_trace_level,
177
+ allow_exec=allow_exec,
178
+ )
179
+
180
+ ban_targets = []
181
+ for m in autopatch_remove_modules:
182
+ ban_targets += _targets_from_module(
183
+ m, modules_seen, trace_level=TraceLevel.PRIMITIVES
184
+ )
185
+
186
+ self.patch_function_ids = {
187
+ # keys ensure we can have separate unique values for same func in different frames
188
+ id(target.frame[target.name]): target
189
+ for target in autowrap_targets + patch_functions
190
+ }
191
+ self.ban_function_ids = {
192
+ id(target.frame[target.name]): target for target in ban_targets
193
+ }
194
+
195
+ self._autopatch_wrap_modules = autopatch_wrap_modules
196
+ self._autopatch_remove_modules = autopatch_remove_modules
197
+
198
+ for scope in search_scopes:
199
+ logger.debug(f"Adding search scope {id(scope)}")
200
+ self.search_autowrap_targets(scope)
201
+
202
+ logger.debug(
203
+ f"{Patcher.__name__} found {len(self.patch_function_ids)=}, {len(self.ban_function_ids)=}"
204
+ )
205
+
206
+ def apply_preexecute_patches(
207
+ self,
208
+ func: Callable,
209
+ trace_level: TraceLevel | None = None,
210
+ ):
211
+ for target in self.patch_function_ids.values():
212
+ orig_fn = target.frame.get(target.name)
213
+ if orig_fn is None:
214
+ raise NotImplementedError(
215
+ f"{target.name} not found in frame (possibly a builtin)"
216
+ )
217
+
218
+ if getattr(orig_fn, PATCHING_FLAG_ATTR, False):
219
+ continue
220
+
221
+ func_wrapper = self.create_wrapper(orig_fn)
222
+ patch = PatchSetItem(
223
+ frame=target.frame,
224
+ fn_name=target.name,
225
+ orig_fn=orig_fn,
226
+ patched_fn=func_wrapper,
227
+ )
228
+ self.patch(patch)
229
+
230
+ for _id, target in self.ban_function_ids.items():
231
+ orig_fn = target.frame.get(target.name)
232
+ if orig_fn is None:
233
+ raise NotImplementedError(
234
+ f"{target.name} not found in frame (possibly a builtin)"
235
+ )
236
+ patch = PatchSetItem(
237
+ frame=target.frame,
238
+ fn_name=target.name,
239
+ orig_fn=orig_fn,
240
+ patched_fn=_create_banned_func_wrapper(orig_fn),
241
+ )
242
+ self.patch(patch)
243
+
244
+ for module, _allow_exec, _trace_level in self._autopatch_wrap_modules:
245
+ self.search_autowrap_targets(module.__dict__)
246
+ for module in self._autopatch_remove_modules:
247
+ self.search_autowrap_targets(module.__dict__)
248
+
249
+ self.search_autowrap_targets(func.__globals__)
250
+
251
+ # func may be itself be a target we need to wrap
252
+ func = self.create_wrapper(func)
253
+
254
+ return func
255
+
256
+ # call_wrapper = _create_module_call_wrapper(self)
257
+ # module_call_patch = PatchSetAttr(
258
+ # frame=Module,
259
+ # fn_name=_MODULE_CALLFUNC,
260
+ # orig_fn=_ORIG_MODULE_CALL,
261
+ # patched_fn=call_wrapper,
262
+ # )
263
+ # self.patch(module_call_patch)
264
+
265
+ # module_getattr_patch = PatchSetAttr(
266
+ # frame=Module,
267
+ # fn_name="__getattribute__",
268
+ # orig_fn=_ORIG_MODULE_GETATTR,
269
+ # patched_fn=_create_module_getattr_wrapper(),
270
+ # )
271
+ # patcher.patch(module_getattr_patch)
272
+
273
+ def create_wrapper(
274
+ self,
275
+ func: Callable,
276
+ ):
277
+ if getattr(func, PATCHING_FLAG_ATTR, False):
278
+ raise ValueError(
279
+ f"Function {func.__name__} is already wrapped, should have already been skipped"
280
+ )
281
+
282
+ wrap_target = self.patch_function_ids.get(id(func))
283
+ ban_target = self.ban_function_ids.get(id(func))
284
+
285
+ is_wrap = wrap_target is not None
286
+ is_ban = ban_target is not None
287
+
288
+ # Can't wrap functions without __globals__ unless they're explicit targets
289
+ if not hasattr(func, "__globals__") and not is_wrap and not is_ban:
290
+ return func
291
+
292
+ match is_wrap, is_ban:
293
+ case (False, False):
294
+ wrapper = _create_nonleaf_wrap_discover_wrapper(func, self)
295
+ case (True, False):
296
+ if wrap_target.custom_trace_wrapper_create is not None:
297
+ wrapper = wrap_target.custom_trace_wrapper_create(wrap_target, self)
298
+ elif wrap_target.trace_level <= self.trace_level:
299
+ wrapper = _create_leaf_func_proxy_wrapper(wrap_target, self)
300
+ else:
301
+ wrapper = _create_nonleaf_wrap_discover_wrapper(func, self)
302
+ case False, True:
303
+ wrapper = _create_banned_func_wrapper(func)
304
+ case True, True:
305
+ raise ValueError(
306
+ f"Got {func.__name__} which was both wrapped and banned? {wrap_target=} {ban_target=}"
307
+ )
308
+
309
+ # if logger.isEnabledFor(logging.DEBUG):
310
+ # logger.debug(
311
+ # f"Created wrapper for {func.__name__} {id(func)} -> {id(wrapper)} {is_wrap=} {is_ban=}"
312
+ # )
313
+
314
+ return wrapper
315
+
316
+ def search_autowrap_targets(self, frame: dict):
317
+ # for efficiency - dont bother re-searching frames, since all the functions will be tagged/wrapped already
318
+ if id(frame) in self.visited_frames:
319
+ logger.debug(f"Already-visited frame {id(frame)}")
320
+ return
321
+ self.visited_frames.add(id(frame))
322
+
323
+ for name, value in frame.items():
324
+ skip = (
325
+ getattr(value, PATCHING_FLAG_ATTR, False)
326
+ or (name.startswith("__") and name.endswith("__"))
327
+ or not callable(value)
328
+ )
329
+ if skip:
330
+ continue
331
+
332
+ patch = PatchSetItem(
333
+ frame=frame,
334
+ fn_name=name,
335
+ orig_fn=value,
336
+ patched_fn=self.create_wrapper(value),
337
+ )
338
+ self.patch(patch)
339
+
340
+ def patch(
341
+ self,
342
+ patch: Patch,
343
+ ):
344
+ if hasattr(patch.orig_fn, PATCHING_FLAG_ATTR):
345
+ logger.debug(f"skipping already patched {patch.orig_fn} {id(patch)}")
346
+ return
347
+
348
+ try:
349
+ setattr(patch.patched_fn, PATCHING_FLAG_ATTR, True)
350
+ except (TypeError, AttributeError) as _e:
351
+ # logger.debug(
352
+ # f"Failed to setattr {PATCHING_FLAG_ATTR} on {patch.patched_fn} {type(patch.patched_fn)=} {e}"
353
+ # )
354
+ pass # cant cache bools on some immutable types like Tuple
355
+
356
+ patch.patch()
357
+ self.patched.append(patch)
358
+
359
+ def unpatch_all(self):
360
+ for patch in self.patched:
361
+ patch.unpatch()
362
+ self.patched.clear()
363
+
364
+ def __enter__(self):
365
+ return self
366
+
367
+ def __exit__(self, exc_type, exc_val, exc_tb):
368
+ self.unpatch_all()
369
+ return False
370
+
371
+
372
def _create_leaf_func_proxy_wrapper(
    target: PatchFunctionTarget,
    patcher: Patcher,
) -> Callable:
    """
    Build a wrapper that records each call to ``target`` as one graph node.

    When called, the wrapper:
      - if ``target.allow_exec`` and no argument (even a nested one) is a
        ``cg.Proxy``, just runs the real function;
      - if ``target.allow_exec``, every proxy argument is an ``RngProxy`` and
        ``patcher.trace_level`` is finer than ``TraceLevel.RANDOM_PARAMS``,
        runs the real function with the concrete generators substituted in;
      - otherwise returns a ``cg.Proxy`` around a ``cg.FunctionCallNode`` whose
        args/kwargs have every proxy replaced by its node.

    For each name in ``target.mutates`` that is present as a keyword argument
    holding a proxy (directly, or after ``target.normalize``), that proxy is
    re-pointed at a ``cg.MutatedArgumentNode``.
    """
    func = target.frame[target.name]
    # The target's namespace may already hold our wrapper for this function:
    # e.g. ``math.__dict__["sin"]`` is patched first, and a user module that
    # did ``from math import sin`` is searched afterwards. Peel off tracing
    # wrappers so the recorded node (and direct execution) use the real function.
    while getattr(func, PATCHING_FLAG_ATTR, False) and hasattr(func, "__wrapped__"):
        func = func.__wrapped__
    # Not every callable has __name__ (e.g. functools.partial objects).
    func_name = getattr(func, "__name__", repr(func))

    def _unwrap_proxy(v):
        return v.node if isinstance(v, cg.Proxy) else v

    def _unbox_rng(v):
        return v.rng if isinstance(v, RngProxy) else v

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        logger.debug("Executing leaf_wrapper! func=%s", func_name)

        # zero proxy args means the user is trying to call this function with real args during tracing
        # e.g they might do: bias = np.zeros(3)
        # some functions will allow this and just execute the function and return the result
        if target.allow_exec:
            all_leaves, _ = pytree.flatten((args, kwargs))
            proxy_leaves = [v for v in all_leaves if isinstance(v, cg.Proxy)]
            if not proxy_leaves:
                return func(*args, **kwargs)
            rng_only = all(isinstance(v, RngProxy) for v in proxy_leaves)
            if rng_only and patcher.trace_level < TraceLevel.RANDOM_PARAMS:
                concrete_args = tuple(
                    pytree.PyTree(a).map(_unbox_rng).obj() for a in args
                )
                concrete_kwargs = {
                    k: pytree.PyTree(v).map(_unbox_rng).obj() for k, v in kwargs.items()
                }
                return func(*concrete_args, **concrete_kwargs)

        if target.normalize:
            args, kwargs = cg.normalize_args_to_kwargs(func, args, kwargs)

        # Convert any Proxy args to their underlying nodes, including those nested in containers
        node_args = tuple(pytree.PyTree(a).map(_unwrap_proxy).obj() for a in args)
        node_kwargs = {
            k: pytree.PyTree(v).map(_unwrap_proxy).obj() for k, v in kwargs.items()
        }

        node = cg.FunctionCallNode(func=func, args=node_args, kwargs=node_kwargs)

        # Handle mutations by updating proxies for mutated arguments.
        # Only arguments present in kwargs can be matched by name here.
        for param_name in target.mutates or []:
            if not (param_name in kwargs and isinstance(kwargs[param_name], cg.Proxy)):
                continue
            original_proxy = kwargs[param_name]
            mutated_node = cg.MutatedArgumentNode(
                mutator_call_node=node, original_node=original_proxy.node
            )
            logger.debug(
                "MutatedArgumentNode created: %s(%s=...) -> %s",
                func_name,
                param_name,
                mutated_node,
            )
            original_proxy.node = mutated_node

        return cg.Proxy(node)

    if hasattr(func, "reduce"):
        # numpy.random.Generator breaks if these special reduce functions-inside-functions are not copied over?
        # TODO: handle the general case of metadata attrs on wrapped functions, I believe torch.fx does this
        wrapper.reduce = func.reduce

    setattr(wrapper, PATCHING_FLAG_ATTR, True)
    return wrapper
440
+
441
+
442
def _create_nonleaf_wrap_discover_wrapper(
    func: Callable,
    patcher: Patcher,
):
    """
    Wrap ``func`` so it still executes normally while extending patching lazily.

    On each call the wrapper:
      1. lets ``patcher`` search ``func.__globals__``, which is a no-op after
         the first search of that namespace and is skipped entirely for
         callables without ``__globals__``;
      2. wraps any top-level positional/keyword argument that is a registered
         wrap target and not already a tracing wrapper;
      3. calls ``func``.

    Raises:
        ValueError: if ``func`` is already a tracing wrapper.
    """
    # TODO: should record non-leaf functions in the graph while also executing their internals?

    func_name = getattr(func, "__name__", repr(func))
    if getattr(func, PATCHING_FLAG_ATTR, False):
        raise ValueError(f"Function {func_name} is already wrapped")

    # Wrap any callable args that are registered targets but weren't reachable
    # via module-dict patching (e.g. passed as function-valued arguments)
    def maybe_wrap(v):
        if not callable(v):
            return v
        if getattr(v, PATCHING_FLAG_ATTR, False):
            return v
        if id(v) in patcher.patch_function_ids:
            logger.debug(
                "maybe_wrap: wrapping callable arg %r id=%s found in patch_function_ids",
                getattr(v, "__name__", v),
                id(v),
            )
            return patcher.create_wrapper(v)
        logger.debug(
            "maybe_wrap: callable arg %r id=%s NOT in patch_function_ids (size=%s)",
            getattr(v, "__name__", v),
            id(v),
            len(patcher.patch_function_ids),
        )
        return v

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        frame = getattr(func, "__globals__", None)
        # Only search real namespaces. Searching a throwaway ``{}`` achieved
        # nothing except recording its id() in ``patcher.visited_frames``; once
        # that dict is freed the id can be reused by a new namespace, which
        # would then be wrongly treated as already searched.
        if frame is not None:
            logger.debug("Searching autowrap targets for %s %s", func_name, id(frame))
            patcher.search_autowrap_targets(frame)
            logger.debug("Finished autowrap targets for %s %s", func_name, id(frame))

        args = tuple(maybe_wrap(a) for a in args)
        kwargs = {k: maybe_wrap(v) for k, v in kwargs.items()}

        return func(*args, **kwargs)

    setattr(wrapper, PATCHING_FLAG_ATTR, True)

    return wrapper
483
+
484
+
485
def _create_banned_func_wrapper(func: Callable) -> Callable:
    """
    Return a flagged stand-in for ``func`` that raises ValueError whenever it is called.

    The stand-in copies ``func``'s metadata via ``functools.wraps``.
    """

    def _refuse(*_args, **_kwargs):
        raise ValueError(
            f"Tracing failed - {func.__name__} is banned from use in traceable functions"
        )

    banned = functools.wraps(func)(_refuse)
    setattr(banned, PATCHING_FLAG_ATTR, True)
    return banned
@@ -0,0 +1,127 @@
1
+ from dataclasses import dataclass
2
+
3
+ import numpy as np
4
+
5
+ from procfunc import compute_graph as cg
6
+
7
+
8
@dataclass
class RngProxy(cg.Proxy):
    """
    We will do specialcase handling to trace rng nodes through the graph

    This allows us to know what the results of choice() and other random control flow will be,
    BUT only if the user has kept the rng non-dirty, i.e. the rng used for choice descends from the root via only spawn() calls

    This is essential so that the result of choice() during tracing is the same as the result of choice() during generation
    """

    rng: np.random.Generator

    # true if any randomness has been taken from this rng besides splitting it via spawn
    dirty: bool = False

    def __init__(
        self,
        node: cg.Node,
        rng: np.random.Generator,
        dirty: bool = False,
    ):
        super().__init__(node)
        node.metadata["known_type"] = np.random.Generator
        node.metadata["varname"] = "rng"
        self.rng = rng
        self.dirty = dirty

    def spawn(self, n_children: int) -> "RngSpawnResultProxy":
        """
        Returns a mock node which can only be unpacked into its constitutent mock rngs

        Records a ``spawn`` method-call node and spawns the concrete child
        generators right away, so that indexing/unpacking yields RngProxy
        objects backed by real rngs. The result inherits this proxy's
        ``dirty`` flag.

        Raises:
            TypeError: if ``n_children`` is not an int.
        """
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if not isinstance(n_children, int):
            raise TypeError(
                f"spawn() expects an int number of children, got {type(n_children).__name__}"
            )

        spawn_node = cg.MethodCallNode(
            callee=self.node,
            method_name="spawn",
            args=(n_children,),
            kwargs={},
        )

        child_rngs = list(self.rng.spawn(n_children))

        return RngSpawnResultProxy(
            node=spawn_node,
            from_rng_proxy=self,
            child_rngs=child_rngs,
            dirty=self.dirty,
        )

    def __getattr__(self, name: str):
        """
        Forward attribute access to ``cg.Proxy`` after validating it against ``self.rng``.

        Accessing any public Generator attribute other than ``spawn`` marks
        this proxy dirty.

        Raises:
            AttributeError: if ``self.rng`` has no attribute ``name``, or if
                ``rng`` itself has not been set yet.
        """
        # __getattr__ only runs when normal lookup fails. If ``rng`` itself is
        # missing (e.g. copy/pickle built the instance without __init__),
        # reading ``self.rng`` below would re-enter __getattr__ forever.
        if name == "rng":
            raise AttributeError(
                f"{type(self).__name__} has no attribute 'rng' (not initialised)"
            )

        if not hasattr(self.rng, name):
            raise AttributeError(
                f"__getattr__ {name} is invalid because {type(self.rng)} has no attribute {name}"
            )

        if not name.startswith("_") and name != "spawn":
            self.dirty = True

        return super().__getattr__(name)
73
+
74
+
75
@dataclass
class RngSpawnResultProxy(cg.Proxy):
    """
    Proxy for the list returned by ``Generator.spawn(n)`` during tracing.

    Holds the concrete child generators so that indexing or unpacking produces
    RngProxy objects backed by real rngs. A warning is emitted when it is
    created from a dirty rng.
    """

    from_rng_proxy: "RngProxy"
    child_rngs: list[np.random.Generator]
    dirty: bool = False

    def __init__(
        self,
        node: cg.Node,
        from_rng_proxy: "RngProxy",
        child_rngs: list[np.random.Generator],
        dirty: bool = False,
    ):
        super().__init__(node)
        node.metadata["varname"] = "rngs"
        self.from_rng_proxy = from_rng_proxy
        self.child_rngs = child_rngs
        self.dirty = dirty
        if not self.dirty:
            return

        import warnings

        warnings.warn(
            "RngSpawnResultProxy has dirty=True, tracing results may be incomplete"
        )

    def __getitem__(self, idx: int) -> "RngProxy":
        """Return the traced child rng at ``idx``; only 0 <= idx < len(children) is accepted."""
        child_count = len(self.child_rngs)
        if idx < 0 or idx >= child_count:
            raise IndexError(f"Index {idx} out of range for {child_count} children")

        item_node = cg.MethodCallNode(
            callee=self.node,
            method_name="__getitem__",
            args=(idx,),
            kwargs={},
            metadata={"known_value_type": np.random.Generator},
        )
        return RngProxy(node=item_node, rng=self.child_rngs[idx], dirty=self.dirty)

    def __iter__(self):
        """
        Yield every child rng in order, so ``x, y, z = rng.spawn(3)`` works while tracing.

        Iteration is allowed because the number of children is known.
        """
        child_count = len(self.child_rngs)
        for position in range(child_count):
            yield self.__getitem__(position)