procfunc 0.30.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (76) hide show
  1. procfunc/__init__.py +87 -0
  2. procfunc/color.py +57 -0
  3. procfunc/compute_graph/__init__.py +28 -0
  4. procfunc/compute_graph/compute_graph.py +115 -0
  5. procfunc/compute_graph/node.py +200 -0
  6. procfunc/compute_graph/operators_info.py +92 -0
  7. procfunc/compute_graph/proxy.py +173 -0
  8. procfunc/compute_graph/util.py +282 -0
  9. procfunc/context.py +115 -0
  10. procfunc/control.py +174 -0
  11. procfunc/nodes/__init__.py +66 -0
  12. procfunc/nodes/bindings_util.py +196 -0
  13. procfunc/nodes/bpy_node_info.py +280 -0
  14. procfunc/nodes/compositor.py +2242 -0
  15. procfunc/nodes/execute/construct_nodes.py +571 -0
  16. procfunc/nodes/execute/construct_special_cases.py +246 -0
  17. procfunc/nodes/execute/execute.py +548 -0
  18. procfunc/nodes/execute/infer_runtime_data_type.py +195 -0
  19. procfunc/nodes/execute/util.py +247 -0
  20. procfunc/nodes/func.py +1417 -0
  21. procfunc/nodes/geo.py +4240 -0
  22. procfunc/nodes/manifest.json +8769 -0
  23. procfunc/nodes/math.py +644 -0
  24. procfunc/nodes/node_function.py +160 -0
  25. procfunc/nodes/shader.py +2359 -0
  26. procfunc/nodes/types.py +347 -0
  27. procfunc/ops/__init__.py +35 -0
  28. procfunc/ops/_util.py +275 -0
  29. procfunc/ops/addons.py +59 -0
  30. procfunc/ops/attr.py +426 -0
  31. procfunc/ops/collection.py +90 -0
  32. procfunc/ops/curve.py +18 -0
  33. procfunc/ops/file.py +126 -0
  34. procfunc/ops/manifest.json +39149 -0
  35. procfunc/ops/mesh.py +1510 -0
  36. procfunc/ops/modifier.py +603 -0
  37. procfunc/ops/object.py +258 -0
  38. procfunc/ops/primitives/__init__.py +31 -0
  39. procfunc/ops/primitives/camera.py +45 -0
  40. procfunc/ops/primitives/curve.py +71 -0
  41. procfunc/ops/primitives/light.py +114 -0
  42. procfunc/ops/primitives/mesh.py +358 -0
  43. procfunc/ops/uv.py +271 -0
  44. procfunc/random.py +247 -0
  45. procfunc/tracer/__init__.py +43 -0
  46. procfunc/tracer/decorator.py +121 -0
  47. procfunc/tracer/patch.py +494 -0
  48. procfunc/tracer/proxy.py +127 -0
  49. procfunc/tracer/trace.py +222 -0
  50. procfunc/transforms/__init__.py +49 -0
  51. procfunc/transforms/cleanup.py +214 -0
  52. procfunc/transforms/convert.py +20 -0
  53. procfunc/transforms/distribution.py +191 -0
  54. procfunc/transforms/extract_materials.py +116 -0
  55. procfunc/transforms/infer_distribution.py +326 -0
  56. procfunc/transforms/parameters.py +15 -0
  57. procfunc/transforms/util.py +35 -0
  58. procfunc/transpiler/__init__.py +24 -0
  59. procfunc/transpiler/bpy_to_computegraph.py +1348 -0
  60. procfunc/transpiler/codegen.py +919 -0
  61. procfunc/transpiler/identifiers.py +595 -0
  62. procfunc/transpiler/main.py +299 -0
  63. procfunc/types.py +380 -0
  64. procfunc/util/__init__.py +0 -0
  65. procfunc/util/bpy_info.py +145 -0
  66. procfunc/util/camera.py +0 -0
  67. procfunc/util/keyframe.py +70 -0
  68. procfunc/util/log.py +96 -0
  69. procfunc/util/manifest.py +121 -0
  70. procfunc/util/pytree.py +343 -0
  71. procfunc/util/teardown.py +37 -0
  72. procfunc-0.30.0.dist-info/METADATA +120 -0
  73. procfunc-0.30.0.dist-info/RECORD +76 -0
  74. procfunc-0.30.0.dist-info/WHEEL +5 -0
  75. procfunc-0.30.0.dist-info/licenses/LICENSE.md +11 -0
  76. procfunc-0.30.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,222 @@
1
+ """
2
+ torch.fx / jax-like function-compute-graph tracing tool, but specially designed for procedural generation functions
3
+
4
+
5
+ This tool was heavily inspired by torch.fx.symbolic_trace https://docs.pytorch.org/docs/2.6/fx.html
6
+ """
7
+
8
+ import builtins
9
+ import inspect
10
+ import logging
11
+ import math
12
+ import random
13
+ from types import ModuleType
14
+ from typing import Any, Callable
15
+
16
+ import numpy as np
17
+ import numpy.linalg
18
+
19
+ import procfunc as pf
20
+ from procfunc import compute_graph as cg
21
+ from procfunc.util import pytree
22
+
23
+ from .patch import (
24
+ Patcher,
25
+ PatchFunctionTarget,
26
+ TraceLevel,
27
+ )
28
+ from .proxy import RngProxy
29
+
30
logger = logging.getLogger(__name__)

# Registry of (module, allow_exec, trace_level) triples filled by `autowrap_module`.
# `trace` hands this list to every Patcher it builds as `autopatch_wrap_modules`.
_autowrap_modules: list[tuple[ModuleType, bool, TraceLevel]] = []


def autowrap_module(
    module: ModuleType,
    allow_exec: bool = False,
    trace_level: TraceLevel = TraceLevel.PRIMITIVES,
):
    """
    Register `module` so the tracer's Patcher wraps its functions automatically.

    Args:
        module: Module whose functions should be wrapped during tracing.
        allow_exec: Stored with the module and forwarded to the Patcher.
        trace_level: Stored with the module and forwarded to the Patcher.
    """
    registration = (module, allow_exec, trace_level)
    _autowrap_modules.append(registration)


# Math-style libraries are wrapped at the default PRIMITIVES level with allow_exec=True.
autowrap_module(math, allow_exec=True)
autowrap_module(np, allow_exec=True)
autowrap_module(numpy.linalg, allow_exec=True)
46
+
47
+
48
# Modules handed to every Patcher as `autopatch_remove_modules`.
# NOTE(review): presumably banned because they draw from hidden global RNG state
# that a traced graph cannot capture (the tracer proxies explicit
# np.random.Generator objects instead) — confirm.
_banned_modules: list[ModuleType] = [np.random, random]


def add_banned_module(module: ModuleType):
    """Append `module` to the list the Patcher receives as ``autopatch_remove_modules``."""
    banned = _banned_modules
    banned.append(module)
56
+
57
+
58
# Individually registered function targets; `trace` passes this list to every
# Patcher it builds as `patch_functions`.
_patch_function_targets: list[PatchFunctionTarget] = []


def add_wrap_target(target: PatchFunctionTarget):
    """Register one function `target` to be handed to the Patcher by `trace`."""
    targets = _patch_function_targets
    targets.append(target)
63
+
64
+
65
# Python builtins that are replaced with tracing-aware wrappers during a trace.
WRAP_BUILTINS = ["min", "max", "abs", "round", "sum"]


def _register_primitive_wrappers(frame: dict, names: list[str], source_name: str):
    """Register a PRIMITIVES-level, non-normalizing, exec-allowed wrap target for each name in `frame`."""
    for target_name in names:
        add_wrap_target(
            PatchFunctionTarget(
                frame=frame,
                name=target_name,
                trace_level=TraceLevel.PRIMITIVES,
                normalize=False,
                allow_exec=True,
                source_name=source_name,
            )
        )


_register_primitive_wrappers(builtins.__dict__, WRAP_BUILTINS, "builtins")

# Constructors looked up in procfunc's own namespace, reported with source "mathutils".
WRAP_CONSTRUCTORS = ["Vector", "Color", "Euler"]

_register_primitive_wrappers(pf.__dict__, WRAP_CONSTRUCTORS, "mathutils")
92
+
93
+
94
# Module namespaces (``__dict__``s) in which the Patcher also looks for references
# to already-registered targets; `trace` passes this list as `search_scopes`.
_search_scopes: list[dict] = []


def add_search_scope(module: ModuleType):
    """
    Make `module`'s namespace an extra place to search for existing patch targets.

    Example: the original ``to_mesh_object`` is already a target, but
    ``procfunc.nodes`` holds its own reference to it. Calling
    ``add_search_scope(procfunc.nodes)`` lets that module's reference be wrapped too.
    """
    namespace = vars(module)
    _search_scopes.append(namespace)
106
+
107
+
108
def _map_args(
    func: Callable,
    **inputs: Any,
) -> dict[str, cg.Proxy | Any]:
    """
    Build the keyword arguments that `trace` calls `func` with.

    For each named parameter of ``func``:
      - if it appears in `inputs`, that value is used. A numpy Generator, bare or
        wrapped in a ``cg.ConstantNode``, is wrapped in a non-dirty ``RngProxy``.
      - otherwise, if the parameter has a default, the default is used;
      - otherwise a ``cg.InputPlaceholderNode`` is created and a ``cg.Proxy`` for it
        is passed, so the parameter becomes a graph input.

    The result is splatted as keyword arguments, so ``*args`` parameters are skipped
    (they cannot be filled by keyword). A ``**kwargs`` parameter receives every
    input that matches no named parameter. Without one, such inputs are dropped
    with a warning instead of silently.

    Args:
        func: Function whose signature is inspected.
        **inputs: Caller-supplied argument values, by parameter name.

    Returns:
        Mapping from argument name to the value to pass to ``func``.
    """

    def _wrap_input(val: Any) -> Any:
        # numpy Generators are passed as RngProxy objects rather than raw.
        if isinstance(val, np.random.Generator):
            return RngProxy(cg.ConstantNode(value=val), val, dirty=False)
        if isinstance(val, cg.ConstantNode) and isinstance(
            val.value, np.random.Generator
        ):
            return RngProxy(val, val.value, dirty=False)
        return val

    signature = inspect.signature(func)
    res = {}
    accepts_var_kwargs = False

    for name, param in signature.parameters.items():
        if param.kind is inspect.Parameter.VAR_POSITIONAL:
            continue
        if param.kind is inspect.Parameter.VAR_KEYWORD:
            accepts_var_kwargs = True
            continue

        if name in inputs:
            res[name] = _wrap_input(inputs[name])
        elif param.default is not param.empty:
            res[name] = param.default
        else:
            node = cg.InputPlaceholderNode(
                name=name, default_value=None, metadata={"varname": name}
            )
            res[name] = cg.Proxy(node)

    unmatched = [key for key in inputs if key not in res]
    if unmatched:
        if accepts_var_kwargs:
            for key in unmatched:
                res[key] = _wrap_input(inputs[key])
        else:
            logger.warning(
                f"Ignoring inputs {unmatched} that match no parameter of {func}"
            )

    return res
140
+
141
+
142
def trace(
    func: Callable,
    trace_level: TraceLevel = TraceLevel.GENERATORS,
    name: str | None = None,
    **inputs: Any,
):
    """
    Turn a python function into a graph datastructure.

    Using this datastructure is (usually) equivalent to executing the function.

    Args:
        func: The function to trace
        trace_level: Granularity of the graph. Functions at this level become leaves;
            finer functions are traced through. choice() peeks through all options when
            trace_level >= RANDOM_CONTROL, or resolves to the chosen branch when finer.
        name: Name of the resulting graph. Defaults to ``func.__name__``.
        **inputs: Values bound to ``func``'s parameters by name (see ``_map_args``).
            Parameters with neither an input nor a default become graph inputs.

    Returns:
        A ``cg.ComputeGraph`` whose inputs are the placeholder nodes created for
        unbound parameters and whose outputs mirror ``func``'s return value. Proxies
        become their nodes; any other leaf becomes a ``cg.ConstantNode``.

    Raises:
        TypeError: If `name` is not given and ``func`` has no string ``__name__``.
        RuntimeError: If another trace is already in progress.
    """

    logger.debug(f"Tracing {func} {id(func)=} with {inputs=} {trace_level=}")

    if name is None:
        name = getattr(func, "__name__", None)
        if not isinstance(name, str):
            raise TypeError(
                f"Can't infer a graph name for {func!r}; pass name= explicitly"
            )

    if pf.context.globals.current_trace_level is not None:
        # TODO we can lift this restriction fairly(?) easily by having a global patcher & saving/restoring this state
        # rather than setting to false at end of function. see fx.trace for an example
        raise RuntimeError(
            f"Can't trace {name}, tracing is already in progress for another function. "
            "Nested tracing is not yet supported - contact the developers to request"
        )

    proxy_args = _map_args(func, **inputs)

    pf.context.globals.current_trace_level = trace_level.value
    patcher = None
    try:
        # Constructed inside the try so a failure here still resets the global trace
        # level; otherwise every later trace would be rejected as "nested".
        patcher = Patcher(
            trace_level=trace_level,
            autopatch_wrap_modules=_autowrap_modules,
            autopatch_remove_modules=_banned_modules,
            search_scopes=_search_scopes,
            patch_functions=_patch_function_targets,
        )
        # Keep `func` bound to the caller's function: it is recorded in the graph
        # metadata below, whereas the patched variant is presumably only meaningful
        # while the patches are applied.
        patched_func = patcher.apply_preexecute_patches(func, trace_level)
        logger.debug(f"Executing {patched_func.__name__} id(func)={id(patched_func)}")
        func_result = patched_func(**proxy_args)
    finally:
        pf.context.globals.current_trace_level = None
        if patcher is not None:
            patcher.unpatch_all()

    metadata = {
        "operations": [
            (trace, {"func": func, "trace_level": trace_level}),
        ],
    }

    def extract_node(v):
        # Proxies contribute their traced node; plain values become constants.
        if isinstance(v, cg.Proxy):
            return v.node
        return cg.ConstantNode(value=v)

    outputs = pytree.PyTree(func_result)
    outputs = outputs.map(extract_node)

    input_nodes = {
        k: v.node
        for k, v in proxy_args.items()
        if isinstance(v, cg.Proxy) and isinstance(v.node, cg.InputPlaceholderNode)
    }

    compgraph = cg.ComputeGraph(
        inputs=pytree.PyTree(input_nodes),
        outputs=outputs,
        name=name,
        metadata=metadata,
    )

    return compgraph
@@ -0,0 +1,49 @@
1
+ from .cleanup import (
2
+ coerce_shaders_to_materialresult,
3
+ eliminate_duplicate_result_types,
4
+ eliminate_duplicate_subgraphs,
5
+ extract_shader_vectors_as_inputs,
6
+ fill_graph_defaults_with_call_node,
7
+ remove_v1_name_from_graph,
8
+ replace_ids,
9
+ )
10
+ from .convert import (
11
+ colors_to_hsv_definition,
12
+ )
13
+ from .distribution import (
14
+ distribution_to_mode,
15
+ outlier_distribution,
16
+ )
17
+ from .extract_materials import (
18
+ extract_materials_from_graph,
19
+ extract_materials_from_graphs,
20
+ )
21
+ from .infer_distribution import (
22
+ infer_distribution_hypercube,
23
+ infer_nodegroup_distributions,
24
+ )
25
+ from .parameters import (
26
+ extract_parameter_distributions,
27
+ )
28
+ from .util import (
29
+ map_graph_list,
30
+ map_subgraphs,
31
+ )
32
+
33
__all__ = [
    # .cleanup
    "coerce_shaders_to_materialresult",
    "eliminate_duplicate_result_types",
    "eliminate_duplicate_subgraphs",
    "extract_shader_vectors_as_inputs",
    "fill_graph_defaults_with_call_node",
    "remove_v1_name_from_graph",
    "replace_ids",
    # .convert
    "colors_to_hsv_definition",
    # .distribution
    "distribution_to_mode",
    "outlier_distribution",
    # .extract_materials
    "extract_materials_from_graph",
    "extract_materials_from_graphs",
    # .infer_distribution
    "infer_distribution_hypercube",
    "infer_nodegroup_distributions",
    # .parameters
    "extract_parameter_distributions",
    # .util
    "map_graph_list",
    "map_subgraphs",
]
@@ -0,0 +1,214 @@
1
+ import logging
2
+ from collections import defaultdict
3
+ from typing import Any, Callable
4
+
5
+ from procfunc import compute_graph as cg
6
+ from procfunc import types as t
7
+ from procfunc.nodes import types as nt
8
+ from procfunc.nodes.shader import coord, geometry
9
+ from procfunc.util import pytree
10
+
11
logger = logging.getLogger(__name__)


def remove_v1_name_from_graph(
    _call_node: "cg.Node", graph: "cg.ComputeGraph"
) -> "cg.ComputeGraph":
    """
    Strip the legacy ``nodegroup_`` / ``shader_`` name prefixes from `graph`, in place.

    ``nodegroup_`` is stripped first, then ``shader_``, so ``nodegroup_shader_x``
    becomes ``x``. Only a *leading* prefix is removed. The same text appearing later
    in the name is kept (``str.replace`` would have removed every occurrence).

    Args:
        _call_node: Unused; kept so the signature matches the other
            ``(call_node, graph)`` transforms in this module.
        graph: Graph to rename.

    Returns:
        `graph`.
    """
    for prefix in ("nodegroup_", "shader_"):
        if graph.name.startswith(prefix):
            graph.name = graph.name[len(prefix):]
    return graph
22
+
23
+
24
def eliminate_duplicate_subgraphs(
    graphs: list[cg.ComputeGraph],
) -> list[cg.ComputeGraph]:
    """
    Merge structurally-identical nested subgraphs across `graphs`, in place.

    Subgraphs are compared with ``cg.graph_nodes_equal``. The first one seen becomes
    canonical, and it adopts the shortest name among its duplicates. Every
    ``cg.SubgraphCallNode`` that referenced a duplicate is then pointed at the
    canonical subgraph.

    Args:
        graphs: Top-level graphs to deduplicate.

    Returns:
        `graphs` (the same list, with call nodes rewritten).
    """
    unique: list[cg.ComputeGraph] = []
    unique_ids: set[int] = set()
    removed: list[cg.ComputeGraph] = []
    # maps id(duplicate_subgraph) -> canonical subgraph to replace it with.
    # The objects stay referenced by `unique`/`removed`, so their ids remain valid.
    replacements: dict[int, cg.ComputeGraph] = {}

    for topgraph in graphs:
        subgraphs = reversed(
            list(cg.traverse_nested_graphs(topgraph, yield_call_nodes=True))
        )
        for _call_node, subgraph in subgraphs:
            # The same subgraph object can be reached through several call nodes;
            # without this guard it would match itself and be reported as removed.
            if id(subgraph) in unique_ids or id(subgraph) in replacements:
                continue
            match = next((g for g in unique if cg.graph_nodes_equal(subgraph, g)), None)
            if match is None:
                unique.append(subgraph)
                unique_ids.add(id(subgraph))
                continue
            if len(subgraph.name) < len(match.name):
                match.name = subgraph.name
            replacements[id(subgraph)] = match
            removed.append(subgraph)

    # second pass: update ALL call nodes (in all nested subgraphs) that reference a replaced subgraph
    for topgraph in graphs:
        for subgraph in cg.traverse_nested_graphs(topgraph):
            for node in cg.traverse_depth_first(subgraph):
                if (
                    isinstance(node, cg.SubgraphCallNode)
                    and id(node.subgraph) in replacements
                ):
                    node.subgraph = replacements[id(node.subgraph)]

    logger.debug(f"Eliminated duplicated subgraphs {[g.name for g in removed]}")

    return graphs
59
+
60
+
61
def eliminate_duplicate_result_types(
    graphs: list[cg.ComputeGraph],
    uses_threshold: int = 1,
) -> list[cg.ComputeGraph]:
    """
    Make subgraphs whose namedtuple outputs share field names use one namedtuple type.

    Subgraph output types are grouped by their ``_fields`` list. Groups with more than
    `uses_threshold` members have every member after the first retargeted, in place,
    to the first member's output type (via ``outputs.spec.container``). Outputs that
    are not namedtuples are ignored.

    Returns:
        `graphs`.
    """
    groups: dict[type, list[cg.ComputeGraph]] = defaultdict(list)

    for graph in graphs:
        for subgraph in cg.traverse_nested_graphs(graph):
            out_type = subgraph.outputs.toplevel_type()
            if out_type is None or not pytree.is_type_namedtuple(out_type):
                continue
            # Reuse the first already-seen type with identical field names, if any.
            canonical = next(
                (
                    known
                    for known in groups.keys()
                    if list(known._fields) == list(out_type._fields)
                ),
                out_type,
            )
            groups[canonical].append(subgraph)

    for members in groups.values():
        if len(members) <= uses_threshold:
            continue
        head_type = members[0].outputs.toplevel_type()
        for other in members[1:]:
            other.outputs.spec.container = head_type

    return graphs
88
+
89
+
90
def fill_graph_defaults_with_call_node(
    call_node: cg.SubgraphCallNode,
    graph: cg.ComputeGraph,
) -> cg.ComputeGraph:
    """
    Copy constant keyword arguments of `call_node` into `graph`'s input defaults, in place.

    For every input of `graph`, a matching non-None, non-``cg.Node`` kwarg on
    `call_node` is written to the input node's ``kwargs["default_value"]``.
    Nothing is changed when `call_node` is None, or when any input already has a
    float ``default_value`` other than 0.0.

    Returns:
        `graph`.
    """
    if call_node is None:
        return graph

    has_custom_float_default = any(
        isinstance(inp.default_value, float) and inp.default_value != 0.0
        for inp in graph.inputs.values()
    )
    if has_custom_float_default:
        logger.debug(
            f"Skipping {graph.name} because it has nondefault existing default args"
        )
        return graph

    for input_name, placeholder in graph.inputs.items():
        value = call_node.kwargs.get(input_name, None)
        if value is None or isinstance(value, cg.Node):
            continue
        placeholder.kwargs["default_value"] = value

    return graph
112
+
113
+
114
def coerce_shaders_to_materialresult(
    _call_node: cg.Node, subgraph: cg.ComputeGraph
) -> cg.ComputeGraph:
    """
    Rewrite a shader-like subgraph's outputs as a ``t.Material``, in place.

    A subgraph qualifies when its output dict has a non-None ``"surface"`` entry
    (or, failing that, ``"bsdf"``). It must have no other non-None entries besides
    ``"displacement"`` and ``"volume"``. Subgraphs that already return
    ``t.Material`` or have no surface are returned unchanged. Subgraphs with any
    other output are also returned unchanged, with a warning, rather than having
    that output silently dropped.

    Args:
        _call_node: Unused; kept so the signature matches the other
            ``(call_node, graph)`` transforms in this module.
        subgraph: Graph to convert.

    Returns:
        `subgraph`, with ``outputs`` replaced when it was converted.
    """
    if subgraph.outputs.toplevel_type() is t.Material:
        return subgraph
    outputs = subgraph.outputs.dict()

    # Prefer "surface"; fall back to "bsdf".
    surface_key = "surface" if outputs.get("surface") is not None else "bsdf"
    surface = outputs.get(surface_key)
    if surface is None:
        return subgraph

    # Any other non-None output has no Material field to go to. Reject by key
    # (a length comparison misses e.g. {"bsdf", "foo"} or {"surface", "bsdf"}).
    consumed = {surface_key, "displacement", "volume"}
    extra = [k for k, v in outputs.items() if k not in consumed and v is not None]
    if extra:
        logger.warning(
            f"{coerce_shaders_to_materialresult.__name__} skipping due to extra outputs: {outputs.keys()}"
        )
        return subgraph

    shader_outputs = {
        "surface": surface,
        "displacement": outputs.get("displacement"),
        "volume": outputs.get("volume"),
    }
    logger.debug(
        f"{coerce_shaders_to_materialresult.__name__} converted {subgraph.name} output"
    )
    subgraph.outputs = pytree.PyTree(t.Material(**shader_outputs))

    return subgraph
139
+
140
+
141
def replace_ids(
    graph: cg.ComputeGraph,
    ids: set[int],
    val: Any,
):
    """
    Replace, in place, every node argument whose ``id()`` is in `ids` with `val`.

    Node-valued and constant arguments are both considered (the traversal runs
    with ``yield_consts=True``). A positional argument is replaced by rebuilding the
    parent's ``args`` tuple; a keyword argument is replaced in ``parent.kwargs``.

    Args:
        graph: The graph to rewrite.
        ids: ``id()`` values of the argument objects to replace. NOTE(review): equal
            small ints / interned strings share one id, so passing such an id would
            replace every equal constant — callers presumably pass node ids.
        val: Replacement value, e.g. a new input placeholder node.

    Returns:
        `graph`.
    """

    assert isinstance(graph, cg.ComputeGraph)

    # Snapshot the traversal before mutating: the loop rewrites parent.args /
    # parent.kwargs that a lazy traversal could still be walking, and holding the
    # references keeps replaced objects alive so their ids cannot be reused mid-loop.
    edges = list(
        cg.traverse_depth_first(
            graph, yield_consts=True, yield_name=True, yield_parent=True
        )
    )
    for arg_name, parent, child in edges:
        if id(child) not in ids:
            continue
        if isinstance(arg_name, int):
            args = list(parent.args)
            args[arg_name] = val
            parent.args = tuple(args)
        else:
            parent.kwargs[arg_name] = val

    return graph
169
+
170
+
171
def extract_as_input(
    graph: cg.ComputeGraph,
    nodes: set[int],
    name: str,
    arg_type: type,
):
    """
    Add a new input placeholder to `graph` and substitute it for the given nodes.

    If `graph` already has an input called `name`, the new input is named
    ``f"{name}_2"`` (or the next free suffix) and a warning is logged. The existing
    input is not overwritten, because nodes in the graph may still reference it.

    Args:
        graph: Graph to modify in place; its inputs must be a dict.
        nodes: ``id()`` values of the nodes/arguments to replace (see ``replace_ids``).
        name: Preferred input name (also stored as the ``varname`` metadata).
        arg_type: Stored as the placeholder's ``known_value_type`` metadata.

    Returns:
        `graph`.
    """
    inputs = graph.inputs.obj()
    assert isinstance(inputs, dict), inputs

    input_name = name
    suffix = 2
    while input_name in inputs:
        input_name = f"{name}_{suffix}"
        suffix += 1
    if input_name != name:
        logger.warning(
            f"{graph.name} already has an input named {name!r}; extracting as {input_name!r}"
        )

    # name= matches how _map_args in procfunc.tracer builds input placeholders.
    inp = cg.InputPlaceholderNode(
        name=input_name,
        default_value=None,
        metadata={"known_value_type": arg_type, "varname": input_name},
    )
    inputs[input_name] = inp
    graph.inputs = pytree.PyTree(inputs)

    return replace_ids(graph, nodes, inp)
187
+
188
+
189
def extract_shader_vectors_as_inputs(
    graph: cg.ComputeGraph,
    extract_funcs: list[Callable[..., Any]] | None = None,
):
    """
    Replace shader coordinate-source nodes in `graph` with a single ``"vector"`` input.

    A node is selected when it is a ``cg.FunctionCallNode`` whose ``func`` is in
    `extract_funcs` (default ``[coord, geometry]``). A ``cg.GetAttributeNode`` whose
    first argument is such a call is also selected. All selected nodes are routed to
    one new input placeholder typed ``nt.ProcNode[t.Vector]`` via ``extract_as_input``.

    Returns:
        `graph`, unchanged when nothing matched, otherwise modified in place.
    """
    targets = [coord, geometry] if extract_funcs is None else extract_funcs

    def _is_vector_target(node: cg.FunctionCallNode) -> bool:
        if not isinstance(node, cg.FunctionCallNode):
            return False
        return node.func in targets

    found: set[int] = set()
    for node in cg.traverse_depth_first(graph):
        if _is_vector_target(node):
            found.add(id(node))
        elif isinstance(node, cg.GetAttributeNode) and _is_vector_target(
            node.args[0]
        ):
            found.add(id(node))

    if not found:
        return graph

    extract_as_input(graph, found, "vector", nt.ProcNode[t.Vector])
    return graph
@@ -0,0 +1,20 @@
1
+ import procfunc as pf
2
+ from procfunc import compute_graph as cg
3
+
4
+
5
def colors_to_hsv_definition(graph: cg.ComputeGraph) -> cg.ComputeGraph:
    """
    Replace every literal ``pf.Color`` argument in `graph` with an HSV-defined call, in place.

    Each color becomes ``cg.FunctionCallNode(pf.color.hsv_to_rgba, kwargs={"hsv": hsv})``,
    where ``hsv`` is the color's ``.hsv`` rounded to 4 decimals. Positional and
    keyword arguments of every node from ``cg.traverse_depth_first`` are rewritten.

    Returns:
        `graph`.
    """

    def _hsv_call(color):
        hsv = tuple(round(x, 4) for x in color.hsv)
        return cg.FunctionCallNode(
            pf.color.hsv_to_rgba, args=(), kwargs={"hsv": hsv}
        )

    for node in cg.traverse_depth_first(graph):
        # Node args are kept as tuples (replace_ids in transforms.cleanup rebuilds
        # them with tuple(...)), so they cannot be assigned by index. Rebuild the
        # tuple instead when it holds any Color.
        if any(isinstance(arg, pf.Color) for arg in node.args):
            node.args = tuple(
                _hsv_call(arg) if isinstance(arg, pf.Color) else arg
                for arg in node.args
            )
        for key, arg in list(node.kwargs.items()):
            if isinstance(arg, pf.Color):
                node.kwargs[key] = _hsv_call(arg)

    return graph
@@ -0,0 +1,191 @@
1
+ import enum
2
+ import logging
3
+ from typing import Any, Callable
4
+
5
+ import numpy as np
6
+
7
+ import procfunc as pf
8
+ from procfunc import compute_graph as cg
9
+ from procfunc.compute_graph import transform_compute_graph
10
+ from procfunc.random import random_distrib_funcs
11
+
12
logger = logging.getLogger(__name__)


class NumpyRandomDistrib(enum.Enum):
    """
    Distributions recognised as ``np.random.Generator`` method calls.

    Each value is the Generator method name; ``FUNCNAME_TO_DISTRIB`` maps those
    names back to members, and ``as_distribution`` uses that lookup.

    NOTE(review): GAUSSIAN, STANDARD_LOGISTIC, STANDARD_LAPLACE and STANDARD_PARETO
    do not appear to be methods of ``numpy.random.Generator``. SHUFFLE is listed
    under permutations rather than distributions in the numpy docs, and ``dirichlet``
    / ``multivariate_hypergeometric`` are absent. Members with no matching method can
    never be produced from a real numpy call — confirm against the linked page.
    """

    # from https://numpy.org/doc/2.2/reference/random/generator.html#distributions
    BETA = "beta"
    BINOMIAL = "binomial"
    CHISQUARE = "chisquare"
    EXPONENTIAL = "exponential"
    F = "f"
    GAMMA = "gamma"
    GAUSSIAN = "gaussian"
    GEOMETRIC = "geometric"
    GUMBEL = "gumbel"
    HYPERGEOMETRIC = "hypergeometric"
    LAPLACE = "laplace"
    LOGISTIC = "logistic"
    LOGNORMAL = "lognormal"
    LOGSERIES = "logseries"
    MULTINOMIAL = "multinomial"
    MULTIVARIATE_NORMAL = "multivariate_normal"
    NEGATIVE_BINOMIAL = "negative_binomial"
    NORMAL = "normal"
    NONCENTRAL_CHISQUARE = "noncentral_chisquare"
    NONCENTRAL_F = "noncentral_f"
    PARETO = "pareto"
    POISSON = "poisson"
    POWER = "power"
    RAYLEIGH = "rayleigh"
    SHUFFLE = "shuffle"
    STANDARD_CAUCHY = "standard_cauchy"
    STANDARD_EXPONENTIAL = "standard_exponential"
    STANDARD_GAMMA = "standard_gamma"
    STANDARD_NORMAL = "standard_normal"
    STANDARD_LOGISTIC = "standard_logistic"
    STANDARD_LAPLACE = "standard_laplace"
    STANDARD_PARETO = "standard_pareto"
    STANDARD_T = "standard_t"
    TRIANGULAR = "triangular"
    UNIFORM = "uniform"
    VONMISES = "vonmises"
    WALD = "wald"
    WEIBULL = "weibull"
    ZIPF = "zipf"
    # Generator.integers (integer sampling), tracked alongside the distributions.
    INTEGER = "integers"
57
+
58
+
59
# Generator method name (the enum value) -> NumpyRandomDistrib member.
# Iterating the enum directly is equivalent here: every member has a distinct value,
# so there are no aliases that __members__ would add.
FUNCNAME_TO_DISTRIB = {member.value: member for member in NumpyRandomDistrib}
60
+
61
+
62
+ def as_distribution(
63
+ node: cg.Node,
64
+ ) -> NumpyRandomDistrib | Callable[[np.random.Generator, ...], Any] | None:
65
+ match node:
66
+ case cg.FunctionCallNode(func=x) if x in random_distrib_funcs:
67
+ return x
68
+ case cg.MethodCallNode(method_name=method, args=(arg_0,)) if (
69
+ method in FUNCNAME_TO_DISTRIB
70
+ and isinstance(arg_0, cg.Node)
71
+ and arg_0.metadata.get("known_value_type") is np.random.Generator
72
+ ):
73
+ return FUNCNAME_TO_DISTRIB[method]
74
+ case _:
75
+ return None
76
+
77
+
78
+ def distribution_to_mode(
79
+ compute_graph: cg.ComputeGraph,
80
+ graph_name: str | None = None,
81
+ ) -> cg.ComputeGraph:
82
+ """
83
+ Transform a generator to use the mode of the distribution instead of random samples
84
+ """
85
+
86
+ def map_to_mode(node: cg.Node) -> cg.Node | float:
87
+ match as_distribution(node):
88
+ case NumpyRandomDistrib.UNIFORM:
89
+ assert len(node.args) == 1 or len(node.kwargs) == 0, (
90
+ "implementation may be errored for mix of args and kwargs"
91
+ )
92
+ low = node.kwargs.get("low", node.args[1])
93
+ high = node.kwargs.get("high", node.args[2])
94
+ if isinstance(low, cg.Node) or isinstance(high, cg.Node):
95
+ logger.warning(
96
+ f"Uniform mode not implemented for {node=} with non-constant {low=} {high=}"
97
+ )
98
+ return node
99
+ return (low + high) / 2
100
+ case NumpyRandomDistrib.NORMAL:
101
+ assert len(node.args) == 1 or len(node.kwargs) == 0, (
102
+ "implementation may be errored for mix of args and kwargs"
103
+ )
104
+ mean = node.kwargs.get("mean", node.args[1])
105
+ _std = node.kwargs.get("std", node.args[2])
106
+ if isinstance(mean, cg.Node):
107
+ logger.warning(
108
+ f"Normal mode not implemented for {node=} with non-constant {mean=}"
109
+ )
110
+ return node
111
+ return mean
112
+ case _:
113
+ return node
114
+
115
+ return transform_compute_graph(
116
+ compute_graph,
117
+ map_to_mode,
118
+ graph_name=graph_name or compute_graph.name + "_mode",
119
+ )
120
+
121
+
122
+ def map_to_outlier(
123
+ node: cg.Node, pct: float = 0.05, normal_clip_std: float = 3.0
124
+ ) -> cg.Node:
125
+ varname = node.metadata.get("varname", None)
126
+ varname = (varname + "_outlier") if varname else None
127
+
128
+ match as_distribution(node):
129
+ case NumpyRandomDistrib.UNIFORM:
130
+ assert len(node.args) == 1 or len(node.kwargs) == 0, (
131
+ "cant handle arg/kwarg mix"
132
+ )
133
+ low = node.kwargs.get("low", node.args[1])
134
+ high = node.kwargs.get("high", node.args[2])
135
+ if isinstance(low, cg.Node) or isinstance(high, cg.Node):
136
+ logger.warning(
137
+ f"outlier not implemented for {node=} with {min=} {max=}"
138
+ )
139
+ return node
140
+ rng = node.args[0]
141
+ assert isinstance(rng, cg.Node), f"got {node.args[0]=}"
142
+ return cg.FunctionCallNode(
143
+ func=pf.random.uniform_tails,
144
+ args=(rng,),
145
+ kwargs=dict(low=low, high=high, tail_pct=pct),
146
+ varname=varname,
147
+ )
148
+ case NumpyRandomDistrib.NORMAL:
149
+ assert len(node.args) == 1 or len(node.kwargs) == 0, (
150
+ "cant handle arg/kwarg mix"
151
+ )
152
+ mean = node.kwargs.get("mean", node.args[1])
153
+ std = node.kwargs.get("std", node.args[2])
154
+ if isinstance(mean, cg.Node):
155
+ logger.warning(f"outlier not implemented for {node=} with {mean=}")
156
+ return node
157
+ rng = node.args[0]
158
+ return cg.FunctionCallNode(
159
+ func=pf.random.uniform_tails,
160
+ args=(rng,),
161
+ kwargs=dict(
162
+ tail_pct=pct,
163
+ low=mean - normal_clip_std * std,
164
+ high=mean + normal_clip_std * std,
165
+ ),
166
+ varname=varname,
167
+ )
168
+ case func if func in pf.random.random_distrib_funcs:
169
+ logger.warning(
170
+ f"{outlier_distribution.__name__} not implemented for {func}"
171
+ )
172
+ return node
173
+ case _:
174
+ return node
175
+
176
+
177
def outlier_distribution(
    compute_graph: cg.ComputeGraph,
    pct: float = 0.05,
    graph_name: str | None = None,
    normal_clip_std: float = 3.0,
) -> cg.ComputeGraph:
    """
    Transform a generator so its random draws produce outliers.

    Every node is passed through ``map_to_outlier(node, pct, normal_clip_std)`` by
    ``transform_compute_graph``.

    Args:
        compute_graph: Graph to transform.
        pct: Tail probability forwarded to ``map_to_outlier``.
        graph_name: Name of the result; defaults to ``compute_graph.name + "_outlier"``.
        normal_clip_std: Forwarded to ``map_to_outlier``.
    """

    def _to_outlier(node):
        return map_to_outlier(node, pct, normal_clip_std)

    if not graph_name:
        graph_name = compute_graph.name + "_outlier"

    return transform_compute_graph(compute_graph, _to_outlier, graph_name=graph_name)