procfunc 0.30.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- procfunc/__init__.py +87 -0
- procfunc/color.py +57 -0
- procfunc/compute_graph/__init__.py +28 -0
- procfunc/compute_graph/compute_graph.py +115 -0
- procfunc/compute_graph/node.py +200 -0
- procfunc/compute_graph/operators_info.py +92 -0
- procfunc/compute_graph/proxy.py +173 -0
- procfunc/compute_graph/util.py +282 -0
- procfunc/context.py +115 -0
- procfunc/control.py +174 -0
- procfunc/nodes/__init__.py +66 -0
- procfunc/nodes/bindings_util.py +196 -0
- procfunc/nodes/bpy_node_info.py +280 -0
- procfunc/nodes/compositor.py +2242 -0
- procfunc/nodes/execute/construct_nodes.py +571 -0
- procfunc/nodes/execute/construct_special_cases.py +246 -0
- procfunc/nodes/execute/execute.py +548 -0
- procfunc/nodes/execute/infer_runtime_data_type.py +195 -0
- procfunc/nodes/execute/util.py +247 -0
- procfunc/nodes/func.py +1417 -0
- procfunc/nodes/geo.py +4240 -0
- procfunc/nodes/manifest.json +8769 -0
- procfunc/nodes/math.py +644 -0
- procfunc/nodes/node_function.py +160 -0
- procfunc/nodes/shader.py +2359 -0
- procfunc/nodes/types.py +347 -0
- procfunc/ops/__init__.py +35 -0
- procfunc/ops/_util.py +275 -0
- procfunc/ops/addons.py +59 -0
- procfunc/ops/attr.py +426 -0
- procfunc/ops/collection.py +90 -0
- procfunc/ops/curve.py +18 -0
- procfunc/ops/file.py +126 -0
- procfunc/ops/manifest.json +39149 -0
- procfunc/ops/mesh.py +1510 -0
- procfunc/ops/modifier.py +603 -0
- procfunc/ops/object.py +258 -0
- procfunc/ops/primitives/__init__.py +31 -0
- procfunc/ops/primitives/camera.py +45 -0
- procfunc/ops/primitives/curve.py +71 -0
- procfunc/ops/primitives/light.py +114 -0
- procfunc/ops/primitives/mesh.py +358 -0
- procfunc/ops/uv.py +271 -0
- procfunc/random.py +247 -0
- procfunc/tracer/__init__.py +43 -0
- procfunc/tracer/decorator.py +121 -0
- procfunc/tracer/patch.py +494 -0
- procfunc/tracer/proxy.py +127 -0
- procfunc/tracer/trace.py +222 -0
- procfunc/transforms/__init__.py +49 -0
- procfunc/transforms/cleanup.py +214 -0
- procfunc/transforms/convert.py +20 -0
- procfunc/transforms/distribution.py +191 -0
- procfunc/transforms/extract_materials.py +116 -0
- procfunc/transforms/infer_distribution.py +326 -0
- procfunc/transforms/parameters.py +15 -0
- procfunc/transforms/util.py +35 -0
- procfunc/transpiler/__init__.py +24 -0
- procfunc/transpiler/bpy_to_computegraph.py +1348 -0
- procfunc/transpiler/codegen.py +919 -0
- procfunc/transpiler/identifiers.py +595 -0
- procfunc/transpiler/main.py +299 -0
- procfunc/types.py +380 -0
- procfunc/util/__init__.py +0 -0
- procfunc/util/bpy_info.py +145 -0
- procfunc/util/camera.py +0 -0
- procfunc/util/keyframe.py +70 -0
- procfunc/util/log.py +96 -0
- procfunc/util/manifest.py +121 -0
- procfunc/util/pytree.py +343 -0
- procfunc/util/teardown.py +37 -0
- procfunc-0.30.0.dist-info/METADATA +120 -0
- procfunc-0.30.0.dist-info/RECORD +76 -0
- procfunc-0.30.0.dist-info/WHEEL +5 -0
- procfunc-0.30.0.dist-info/licenses/LICENSE.md +11 -0
- procfunc-0.30.0.dist-info/top_level.txt +1 -0
procfunc/tracer/patch.py
ADDED
|
@@ -0,0 +1,494 @@
|
|
|
1
|
+
import enum
|
|
2
|
+
import functools
|
|
3
|
+
import logging
|
|
4
|
+
from dataclasses import dataclass
|
|
5
|
+
from types import ModuleType
|
|
6
|
+
from typing import Any, Callable, TypeAlias, TypeVar
|
|
7
|
+
|
|
8
|
+
from procfunc import compute_graph as cg
|
|
9
|
+
from procfunc.tracer.proxy import RngProxy
|
|
10
|
+
from procfunc.util import pytree
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
PATCHING_FLAG_ATTR = "_gen_tracing_is_patched"
|
|
15
|
+
_MODULE_CALLFUNC = "__call__"
|
|
16
|
+
_MODULE_GENFUNC = "_generate"
|
|
17
|
+
|
|
18
|
+
Tfunc = TypeVar("Tfunc")
|
|
19
|
+
|
|
20
|
+
TWrapperCreate: TypeAlias = Callable[["PatchFunctionTarget", "Patcher"], Callable]
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class TraceLevel(enum.IntEnum):
    """
    Granularity levels for tracing. A larger value is coarser, a smaller value is finer.

    Every patch target carries one of these levels; it is compared against the
    level the user requested to decide which calls end up as nodes of the graph.
    """

    # _distribution functions
    GRAMMAR = 100
    # pf.control.choice
    RANDOM_CONTROL = 60
    # np.random calls, all other operations of numbers.
    RANDOM_PARAMS = 50
    GENERATORS = 40
    # @node_function
    NODEGROUPS = 30
    PRIMITIVES = 20
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@dataclass
class PatchFunctionTarget:
    """
    Describes one callable, stored under `name` in `frame`, that the tracer may wrap.
    """

    # mapping that currently holds the callable (a module `__dict__` when the
    # target is produced by `_targets_from_module`)
    frame: dict
    # key under which the callable is stored in `frame`
    name: str
    # Compared to the user's requested trace level to decide whether this target will be a leaf.
    trace_level: TraceLevel
    # when true, the leaf wrapper converts positional args to kwargs before recording the call
    normalize: bool = True
    # when true, the leaf wrapper runs the real function if no argument is a proxy
    allow_exec: bool = False
    # optional factory, called with (target, patcher), that builds the wrapper
    # instead of the default leaf / non-leaf wrappers
    custom_trace_wrapper_create: TWrapperCreate | None = None
    source_name: str | None = None  # used for logging/debugging
    mutates: list[str] | None = None  # list of argument names that the function mutates
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
@dataclass
class Patch:
    """
    Record of one replacement: `orig_fn` is swapped for `patched_fn` under
    `fn_name` in `frame`.

    This base class only stores the data; subclasses define how the replacement
    is written (`patch`) and how it is undone (`unpatch`).
    """

    frame: dict[str, Any]
    fn_name: str
    orig_fn: Callable
    patched_fn: Callable

    def patch(self):
        """Install `patched_fn`. Must be provided by a subclass."""
        raise NotImplementedError

    def unpatch(self):
        """Put `orig_fn` back. Must be provided by a subclass."""
        raise NotImplementedError
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
@dataclass
class PatchSetItem(Patch):
    """Patch applied by item assignment: `frame[fn_name] = ...`."""

    def _store(self, fn):
        # single place where the frame entry is written
        self.frame[self.fn_name] = fn

    def patch(self):
        """Write `patched_fn` into the frame entry."""
        self._store(self.patched_fn)

    def unpatch(self):
        """Write `orig_fn` back into the frame entry."""
        self._store(self.orig_fn)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@dataclass
class PatchSetAttr(Patch):
    """Patch applied by attribute assignment: `setattr(frame, fn_name, ...)`."""

    def patch(self):
        """Log the replacement at debug level, then set `patched_fn` on `frame`."""
        message = f"Patching {self.fn_name} in {id(self.frame)} from {self.orig_fn} to {self.patched_fn}"
        logger.debug(message)
        setattr(self.frame, self.fn_name, self.patched_fn)

    def unpatch(self):
        """Set `orig_fn` back on `frame`."""
        setattr(self.frame, self.fn_name, self.orig_fn)
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def _targets_from_module(
    module: ModuleType,
    seen: set[int],
    trace_level: TraceLevel,
    normalize: bool = False,
    allow_exec: bool = False,
) -> list[PatchFunctionTarget]:
    """
    Collect patch targets for the callables of a module such as `math`, `numpy`, `blendfunc`,
    where we want to wrap everything the module contains.

    Such functions are used as primitives - their internals are not traced.

    Args:
        module: the module to gather targets from
        seen: ids of callables that were already processed and must be skipped; the id of
            every callable accepted here is added to it
        trace_level: trace level stored on every target created here
        normalize: whether to try to convert args to kwargs where available. fails for some modules such as numpy
        allow_exec: whether to allow the function to be executed if it is called with non-dynamic args

    Returns:
        one PatchFunctionTarget per accepted entry of the module. Entries are ignored when their name
        starts with _, when they are classes, when they are not callable, when they were already seen,
        or when the module defines __all__ and the name is not listed in it.
    """

    namespace = module.__dict__
    public_names = getattr(module, "__all__", None)

    def _is_candidate(name, value) -> bool:
        # same checks, in the same order, as a single chained condition would perform
        if name.startswith("_"):
            return False
        if issubclass(type(value), type):
            return False
        if not callable(value):
            return False
        if id(value) in seen:
            return False
        return public_names is None or name in public_names

    targets = []
    for name, value in namespace.items():
        if not _is_candidate(name, value):
            continue

        # recorded immediately so a second name bound to the same object is skipped
        seen.add(id(value))
        targets.append(
            PatchFunctionTarget(
                frame=namespace,
                name=name,
                source_name=module.__name__,
                normalize=normalize,
                allow_exec=allow_exec,
                trace_level=trace_level,
            )
        )

    return targets
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
class Patcher:
|
|
142
|
+
"""
|
|
143
|
+
A 'patch' is an invasive modification of another module (e.g numpy, math) which makes their functions behave differently for tracing
|
|
144
|
+
|
|
145
|
+
This class creates patches and records them so that we can undo them later
|
|
146
|
+
"""
|
|
147
|
+
|
|
148
|
+
def __init__(
|
|
149
|
+
self,
|
|
150
|
+
trace_level: TraceLevel,
|
|
151
|
+
autopatch_wrap_modules: list[tuple[ModuleType, bool, TraceLevel]] | None = None,
|
|
152
|
+
autopatch_remove_modules: list[ModuleType] | None = None,
|
|
153
|
+
patch_functions: list[PatchFunctionTarget] | None = None,
|
|
154
|
+
search_scopes: list[dict] | None = None,
|
|
155
|
+
):
|
|
156
|
+
self.trace_level = trace_level
|
|
157
|
+
if autopatch_wrap_modules is None:
|
|
158
|
+
autopatch_wrap_modules = []
|
|
159
|
+
if autopatch_remove_modules is None:
|
|
160
|
+
autopatch_remove_modules = []
|
|
161
|
+
if patch_functions is None:
|
|
162
|
+
patch_functions = []
|
|
163
|
+
if search_scopes is None:
|
|
164
|
+
search_scopes = []
|
|
165
|
+
|
|
166
|
+
self.patched: list[Patch] = []
|
|
167
|
+
self.visited_frames: set[int] = set()
|
|
168
|
+
|
|
169
|
+
modules_seen = set()
|
|
170
|
+
|
|
171
|
+
autowrap_targets = []
|
|
172
|
+
for m, allow_exec, mod_trace_level in autopatch_wrap_modules:
|
|
173
|
+
autowrap_targets += _targets_from_module(
|
|
174
|
+
m,
|
|
175
|
+
modules_seen,
|
|
176
|
+
trace_level=mod_trace_level,
|
|
177
|
+
allow_exec=allow_exec,
|
|
178
|
+
)
|
|
179
|
+
|
|
180
|
+
ban_targets = []
|
|
181
|
+
for m in autopatch_remove_modules:
|
|
182
|
+
ban_targets += _targets_from_module(
|
|
183
|
+
m, modules_seen, trace_level=TraceLevel.PRIMITIVES
|
|
184
|
+
)
|
|
185
|
+
|
|
186
|
+
self.patch_function_ids = {
|
|
187
|
+
# keys ensure we can have separate unique values for same func in different frames
|
|
188
|
+
id(target.frame[target.name]): target
|
|
189
|
+
for target in autowrap_targets + patch_functions
|
|
190
|
+
}
|
|
191
|
+
self.ban_function_ids = {
|
|
192
|
+
id(target.frame[target.name]): target for target in ban_targets
|
|
193
|
+
}
|
|
194
|
+
|
|
195
|
+
self._autopatch_wrap_modules = autopatch_wrap_modules
|
|
196
|
+
self._autopatch_remove_modules = autopatch_remove_modules
|
|
197
|
+
|
|
198
|
+
for scope in search_scopes:
|
|
199
|
+
logger.debug(f"Adding search scope {id(scope)}")
|
|
200
|
+
self.search_autowrap_targets(scope)
|
|
201
|
+
|
|
202
|
+
logger.debug(
|
|
203
|
+
f"{Patcher.__name__} found {len(self.patch_function_ids)=}, {len(self.ban_function_ids)=}"
|
|
204
|
+
)
|
|
205
|
+
|
|
206
|
+
def apply_preexecute_patches(
|
|
207
|
+
self,
|
|
208
|
+
func: Callable,
|
|
209
|
+
trace_level: TraceLevel | None = None,
|
|
210
|
+
):
|
|
211
|
+
for target in self.patch_function_ids.values():
|
|
212
|
+
orig_fn = target.frame.get(target.name)
|
|
213
|
+
if orig_fn is None:
|
|
214
|
+
raise NotImplementedError(
|
|
215
|
+
f"{target.name} not found in frame (possibly a builtin)"
|
|
216
|
+
)
|
|
217
|
+
|
|
218
|
+
if getattr(orig_fn, PATCHING_FLAG_ATTR, False):
|
|
219
|
+
continue
|
|
220
|
+
|
|
221
|
+
func_wrapper = self.create_wrapper(orig_fn)
|
|
222
|
+
patch = PatchSetItem(
|
|
223
|
+
frame=target.frame,
|
|
224
|
+
fn_name=target.name,
|
|
225
|
+
orig_fn=orig_fn,
|
|
226
|
+
patched_fn=func_wrapper,
|
|
227
|
+
)
|
|
228
|
+
self.patch(patch)
|
|
229
|
+
|
|
230
|
+
for _id, target in self.ban_function_ids.items():
|
|
231
|
+
orig_fn = target.frame.get(target.name)
|
|
232
|
+
if orig_fn is None:
|
|
233
|
+
raise NotImplementedError(
|
|
234
|
+
f"{target.name} not found in frame (possibly a builtin)"
|
|
235
|
+
)
|
|
236
|
+
patch = PatchSetItem(
|
|
237
|
+
frame=target.frame,
|
|
238
|
+
fn_name=target.name,
|
|
239
|
+
orig_fn=orig_fn,
|
|
240
|
+
patched_fn=_create_banned_func_wrapper(orig_fn),
|
|
241
|
+
)
|
|
242
|
+
self.patch(patch)
|
|
243
|
+
|
|
244
|
+
for module, _allow_exec, _trace_level in self._autopatch_wrap_modules:
|
|
245
|
+
self.search_autowrap_targets(module.__dict__)
|
|
246
|
+
for module in self._autopatch_remove_modules:
|
|
247
|
+
self.search_autowrap_targets(module.__dict__)
|
|
248
|
+
|
|
249
|
+
self.search_autowrap_targets(func.__globals__)
|
|
250
|
+
|
|
251
|
+
# func may be itself be a target we need to wrap
|
|
252
|
+
func = self.create_wrapper(func)
|
|
253
|
+
|
|
254
|
+
return func
|
|
255
|
+
|
|
256
|
+
# call_wrapper = _create_module_call_wrapper(self)
|
|
257
|
+
# module_call_patch = PatchSetAttr(
|
|
258
|
+
# frame=Module,
|
|
259
|
+
# fn_name=_MODULE_CALLFUNC,
|
|
260
|
+
# orig_fn=_ORIG_MODULE_CALL,
|
|
261
|
+
# patched_fn=call_wrapper,
|
|
262
|
+
# )
|
|
263
|
+
# self.patch(module_call_patch)
|
|
264
|
+
|
|
265
|
+
# module_getattr_patch = PatchSetAttr(
|
|
266
|
+
# frame=Module,
|
|
267
|
+
# fn_name="__getattribute__",
|
|
268
|
+
# orig_fn=_ORIG_MODULE_GETATTR,
|
|
269
|
+
# patched_fn=_create_module_getattr_wrapper(),
|
|
270
|
+
# )
|
|
271
|
+
# patcher.patch(module_getattr_patch)
|
|
272
|
+
|
|
273
|
+
def create_wrapper(
|
|
274
|
+
self,
|
|
275
|
+
func: Callable,
|
|
276
|
+
):
|
|
277
|
+
if getattr(func, PATCHING_FLAG_ATTR, False):
|
|
278
|
+
raise ValueError(
|
|
279
|
+
f"Function {func.__name__} is already wrapped, should have already been skipped"
|
|
280
|
+
)
|
|
281
|
+
|
|
282
|
+
wrap_target = self.patch_function_ids.get(id(func))
|
|
283
|
+
ban_target = self.ban_function_ids.get(id(func))
|
|
284
|
+
|
|
285
|
+
is_wrap = wrap_target is not None
|
|
286
|
+
is_ban = ban_target is not None
|
|
287
|
+
|
|
288
|
+
# Can't wrap functions without __globals__ unless they're explicit targets
|
|
289
|
+
if not hasattr(func, "__globals__") and not is_wrap and not is_ban:
|
|
290
|
+
return func
|
|
291
|
+
|
|
292
|
+
match is_wrap, is_ban:
|
|
293
|
+
case (False, False):
|
|
294
|
+
wrapper = _create_nonleaf_wrap_discover_wrapper(func, self)
|
|
295
|
+
case (True, False):
|
|
296
|
+
if wrap_target.custom_trace_wrapper_create is not None:
|
|
297
|
+
wrapper = wrap_target.custom_trace_wrapper_create(wrap_target, self)
|
|
298
|
+
elif wrap_target.trace_level <= self.trace_level:
|
|
299
|
+
wrapper = _create_leaf_func_proxy_wrapper(wrap_target, self)
|
|
300
|
+
else:
|
|
301
|
+
wrapper = _create_nonleaf_wrap_discover_wrapper(func, self)
|
|
302
|
+
case False, True:
|
|
303
|
+
wrapper = _create_banned_func_wrapper(func)
|
|
304
|
+
case True, True:
|
|
305
|
+
raise ValueError(
|
|
306
|
+
f"Got {func.__name__} which was both wrapped and banned? {wrap_target=} {ban_target=}"
|
|
307
|
+
)
|
|
308
|
+
|
|
309
|
+
# if logger.isEnabledFor(logging.DEBUG):
|
|
310
|
+
# logger.debug(
|
|
311
|
+
# f"Created wrapper for {func.__name__} {id(func)} -> {id(wrapper)} {is_wrap=} {is_ban=}"
|
|
312
|
+
# )
|
|
313
|
+
|
|
314
|
+
return wrapper
|
|
315
|
+
|
|
316
|
+
def search_autowrap_targets(self, frame: dict):
|
|
317
|
+
# for efficiency - dont bother re-searching frames, since all the functions will be tagged/wrapped already
|
|
318
|
+
if id(frame) in self.visited_frames:
|
|
319
|
+
logger.debug(f"Already-visited frame {id(frame)}")
|
|
320
|
+
return
|
|
321
|
+
self.visited_frames.add(id(frame))
|
|
322
|
+
|
|
323
|
+
for name, value in frame.items():
|
|
324
|
+
skip = (
|
|
325
|
+
getattr(value, PATCHING_FLAG_ATTR, False)
|
|
326
|
+
or (name.startswith("__") and name.endswith("__"))
|
|
327
|
+
or not callable(value)
|
|
328
|
+
)
|
|
329
|
+
if skip:
|
|
330
|
+
continue
|
|
331
|
+
|
|
332
|
+
patch = PatchSetItem(
|
|
333
|
+
frame=frame,
|
|
334
|
+
fn_name=name,
|
|
335
|
+
orig_fn=value,
|
|
336
|
+
patched_fn=self.create_wrapper(value),
|
|
337
|
+
)
|
|
338
|
+
self.patch(patch)
|
|
339
|
+
|
|
340
|
+
def patch(
|
|
341
|
+
self,
|
|
342
|
+
patch: Patch,
|
|
343
|
+
):
|
|
344
|
+
if hasattr(patch.orig_fn, PATCHING_FLAG_ATTR):
|
|
345
|
+
logger.debug(f"skipping already patched {patch.orig_fn} {id(patch)}")
|
|
346
|
+
return
|
|
347
|
+
|
|
348
|
+
try:
|
|
349
|
+
setattr(patch.patched_fn, PATCHING_FLAG_ATTR, True)
|
|
350
|
+
except (TypeError, AttributeError) as _e:
|
|
351
|
+
# logger.debug(
|
|
352
|
+
# f"Failed to setattr {PATCHING_FLAG_ATTR} on {patch.patched_fn} {type(patch.patched_fn)=} {e}"
|
|
353
|
+
# )
|
|
354
|
+
pass # cant cache bools on some immutable types like Tuple
|
|
355
|
+
|
|
356
|
+
patch.patch()
|
|
357
|
+
self.patched.append(patch)
|
|
358
|
+
|
|
359
|
+
def unpatch_all(self):
|
|
360
|
+
for patch in self.patched:
|
|
361
|
+
patch.unpatch()
|
|
362
|
+
self.patched.clear()
|
|
363
|
+
|
|
364
|
+
def __enter__(self):
|
|
365
|
+
return self
|
|
366
|
+
|
|
367
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
368
|
+
self.unpatch_all()
|
|
369
|
+
return False
|
|
370
|
+
|
|
371
|
+
|
|
372
|
+
def _create_leaf_func_proxy_wrapper(
    target: PatchFunctionTarget,
    patcher: Patcher,
) -> Callable:
    """
    Build the wrapper that records calls to a leaf target as graph nodes.

    When called, the wrapper returns a `cg.Proxy` around a `cg.FunctionCallNode`
    instead of running the function. If `target.allow_exec` is set, the real
    function is executed instead when no argument is a proxy, or when every proxy
    argument is an `RngProxy` and the patcher's trace level is below RANDOM_PARAMS.

    Args:
        target: the target whose function (`target.frame[target.name]`) is wrapped
        patcher: the patcher that owns this wrapper; its trace level is consulted

    Returns:
        the wrapper function, already flagged as patched
    """
    func = target.frame[target.name]

    # The helpers do not depend on the call arguments, so they are created once
    # here instead of on every call of the wrapper.
    def _unwrap_proxy(v):
        return v.node if isinstance(v, cg.Proxy) else v

    def _unbox_rng(v):
        return v.rng if isinstance(v, RngProxy) else v

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        logger.debug("Executing leaf_wrapper! func=%s", func.__name__)

        # zero proxy args means the user is trying to call this function with real args during tracing
        # e.g they might do: bias = np.zeros(3)
        # some functions will allow this and just execute the function and return the result
        if target.allow_exec:
            all_leaves, _ = pytree.flatten((args, kwargs))
            proxy_leaves = [v for v in all_leaves if isinstance(v, cg.Proxy)]
            if not proxy_leaves:
                return func(*args, **kwargs)
            rng_only = all(isinstance(v, RngProxy) for v in proxy_leaves)
            if rng_only and patcher.trace_level < TraceLevel.RANDOM_PARAMS:
                # replace each RngProxy by the real generator it carries and run for real
                concrete_args = tuple(
                    pytree.PyTree(a).map(_unbox_rng).obj() for a in args
                )
                concrete_kwargs = {
                    k: pytree.PyTree(v).map(_unbox_rng).obj() for k, v in kwargs.items()
                }
                return func(*concrete_args, **concrete_kwargs)

        if target.normalize:
            args, kwargs = cg.normalize_args_to_kwargs(func, args, kwargs)

        # Convert any Proxy args to their underlying nodes, including those nested in containers
        node_args = tuple(pytree.PyTree(a).map(_unwrap_proxy).obj() for a in args)
        node_kwargs = {
            k: pytree.PyTree(v).map(_unwrap_proxy).obj() for k, v in kwargs.items()
        }

        node = cg.FunctionCallNode(func=func, args=node_args, kwargs=node_kwargs)

        # Handle mutations by updating proxies for mutated arguments.
        # Only arguments present in kwargs as a Proxy are handled here.
        for param_name in target.mutates or []:
            if not (param_name in kwargs and isinstance(kwargs[param_name], cg.Proxy)):
                continue
            original_proxy = kwargs[param_name]
            mutated_node = cg.MutatedArgumentNode(
                mutator_call_node=node, original_node=original_proxy.node
            )
            logger.debug(
                "MutatedArgumentNode created: %s(%s=...) -> %s",
                func.__name__,
                param_name,
                mutated_node,
            )
            original_proxy.node = mutated_node

        return cg.Proxy(node)

    if hasattr(func, "reduce"):
        # numpy.random.Generator breaks if these special reduce functions-inside-functions are not copied over?
        # TODO: handle the general case of metadata attrs on wrapped functions, I believe torch.fx does this
        wrapper.reduce = func.reduce

    setattr(wrapper, PATCHING_FLAG_ATTR, True)
    return wrapper
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
def _create_nonleaf_wrap_discover_wrapper(
    func: Callable,
    patcher: Patcher,
):
    """
    Build the wrapper for a function that is executed normally during tracing.

    Each time the wrapper is called it first lets the patcher search the globals of
    `func` for callables to wrap, then wraps any callable positional or keyword
    argument that is a registered patch target, and finally calls `func`.

    Args:
        func: the function to wrap
        patcher: the patcher used for the search and for wrapping callable arguments

    Returns:
        the wrapper function, already flagged as patched

    Raises:
        ValueError: if `func` already carries the patched flag
    """
    # TODO: should record non-leaf functions in the graph while also executing their internals?

    if getattr(func, PATCHING_FLAG_ATTR, False):
        raise ValueError(f"Function {func.__name__} is already wrapped")

    # Wrap any callable args that are registered targets but weren't reachable
    # via module-dict patching (e.g. passed as function-valued arguments)
    def maybe_wrap(v):
        if not callable(v):
            return v
        if getattr(v, PATCHING_FLAG_ATTR, False):
            return v
        if id(v) in patcher.patch_function_ids:
            logger.debug(
                "maybe_wrap: wrapping callable arg %r id=%s found in patch_function_ids",
                getattr(v, "__name__", v),
                id(v),
            )
            return patcher.create_wrapper(v)
        logger.debug(
            "maybe_wrap: callable arg %r id=%s NOT in patch_function_ids (size=%s)",
            getattr(v, "__name__", v),
            id(v),
            len(patcher.patch_function_ids),
        )
        return v

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Callables without __globals__ have no frame to search. Passing a throwaway
        # dict instead would store the id of a short-lived object in the patcher's
        # visited set on every call; that id can later be reused by a real frame
        # dict, which would then be skipped as "already visited".
        frame = getattr(func, "__globals__", None)
        if frame is not None:
            logger.debug(
                "Searching autowrap targets for %s %s", func.__name__, id(frame)
            )
            patcher.search_autowrap_targets(frame)
            logger.debug(
                "Finished autowrap targets for %s %s", func.__name__, id(frame)
            )

        args = tuple(maybe_wrap(a) for a in args)
        kwargs = {k: maybe_wrap(v) for k, v in kwargs.items()}

        return func(*args, **kwargs)

    setattr(wrapper, PATCHING_FLAG_ATTR, True)

    return wrapper
|
|
483
|
+
|
|
484
|
+
|
|
485
|
+
def _create_banned_func_wrapper(func: Callable) -> Callable:
    """
    Return a stand-in for `func` that raises ValueError whenever it is called.

    The stand-in copies the metadata of `func` and is flagged as patched.
    """

    def wrapped_func(*args, **kwargs):
        raise ValueError(
            f"Tracing failed - {func.__name__} is banned from use in traceable functions"
        )

    banned = functools.update_wrapper(wrapped_func, func)
    setattr(banned, PATCHING_FLAG_ATTR, True)

    return banned
|
procfunc/tracer/proxy.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
import numpy as np
|
|
4
|
+
|
|
5
|
+
from procfunc import compute_graph as cg
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@dataclass
class RngProxy(cg.Proxy):
    """
    We will do specialcase handling to trace rng nodes through the graph

    This allows us to know what the results of choice() and other random control flow will be,
    BUT only if the user has kept the rng non-dirty, i.e. the rng used for choice descends from the root via only spawn() calls

    This is essential so that the result of choice() during tracing is the same as the result of choice() during generation
    """

    # the real generator that backs this proxy during tracing
    rng: np.random.Generator

    # true if any randomness has been taken from this rng besides splitting it via spawn
    dirty: bool = False

    def __init__(
        self,
        node: cg.Node,
        rng: np.random.Generator,
        dirty: bool = False,
    ):
        """
        Args:
            node: graph node this proxy stands for; its metadata is tagged with the
                generator type and the variable name "rng"
            rng: the real generator carried alongside the node
            dirty: initial dirty state, inherited from the proxy this one came from
        """
        super().__init__(node)
        node.metadata["known_type"] = np.random.Generator
        node.metadata["varname"] = "rng"
        self.rng = rng
        self.dirty = dirty

    def spawn(self, n_children: int) -> "RngSpawnResultProxy":
        """
        Returns a mock node which can only be unpacked into its constituent mock rngs

        The real generator is spawned as well, so each child proxy carries a real child rng.
        """

        assert isinstance(n_children, int)

        spawn_node = cg.MethodCallNode(
            callee=self.node,
            method_name="spawn",
            args=(n_children,),
            kwargs={},
        )

        child_rngs = list(self.rng.spawn(n_children))

        return RngSpawnResultProxy(
            node=spawn_node,
            from_rng_proxy=self,
            child_rngs=child_rngs,
            dirty=self.dirty,
        )

    def __getattr__(self, name: str):
        """
        Forward attribute access to the base proxy, for attributes the real rng has.

        Accessing any public attribute of the rng other than `spawn` marks this proxy dirty.

        Raises:
            AttributeError: if the real rng has no attribute `name`, or if `rng` has
                not been set on this instance yet
        """
        # __getattr__ only runs after normal lookup failed. Reading `self.rng` here
        # would therefore call __getattr__ again, without end, whenever `rng` itself
        # is missing (e.g. a bare instance being probed by copy/pickle). Fetch it
        # without going through __getattr__ and fail with a plain AttributeError.
        try:
            rng = object.__getattribute__(self, "rng")
        except AttributeError:
            raise AttributeError(name) from None

        if not hasattr(rng, name):
            raise AttributeError(
                f"__getattr__ {name} is invalid because {type(rng)} has no attribute {name}"
            )

        if not name.startswith("_") and name != "spawn":
            self.dirty = True

        return super().__getattr__(name)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
@dataclass
class RngSpawnResultProxy(cg.Proxy):
    """
    Proxy for the result of `RngProxy.spawn`: a fixed-size sequence of child rngs.

    Indexing or iterating it yields one `RngProxy` per real child generator.
    """

    from_rng_proxy: "RngProxy"
    child_rngs: list[np.random.Generator]
    dirty: bool = False

    def __init__(
        self,
        node: cg.Node,
        from_rng_proxy: "RngProxy",
        child_rngs: list[np.random.Generator],
        dirty: bool = False,
    ):
        super().__init__(node)
        node.metadata["varname"] = "rngs"
        self.from_rng_proxy = from_rng_proxy
        self.child_rngs = child_rngs
        self.dirty = dirty
        if not self.dirty:
            return

        # a dirty parent is reported once, when the spawn result is created
        import warnings

        warnings.warn(
            "RngSpawnResultProxy has dirty=True, tracing results may be incomplete"
        )

    def __getitem__(self, idx: int) -> "RngProxy":
        """
        Return the proxy of child number `idx`.

        Raises:
            IndexError: if `idx` is negative or not smaller than the number of children
        """
        n_available = len(self.child_rngs)
        if idx < 0 or idx >= n_available:
            raise IndexError(
                f"Index {idx} out of range for {len(self.child_rngs)} children"
            )

        child_node = cg.MethodCallNode(
            callee=self.node,
            method_name="__getitem__",
            args=(idx,),
            kwargs={},
            metadata={"known_value_type": np.random.Generator},
        )
        child_rng = self.child_rngs[idx]

        return RngProxy(node=child_node, rng=child_rng, dirty=self.dirty)

    def __iter__(self):
        """
        specialcase to allow __iter__() since SpawnResult has known size

        allows x,y,z = rng.spawn(3) syntax to work in functions that need to be traceable
        """
        yield from map(self.__getitem__, range(len(self.child_rngs)))
|