warp-lang 1.2.2__py3-none-macosx_10_13_universal2.whl → 1.3.0__py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- warp/__init__.py +8 -6
- warp/autograd.py +823 -0
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +6 -2
- warp/builtins.py +1410 -886
- warp/codegen.py +503 -166
- warp/config.py +48 -18
- warp/context.py +400 -198
- warp/dlpack.py +8 -0
- warp/examples/assets/bunny.usd +0 -0
- warp/examples/benchmarks/benchmark_cloth_warp.py +1 -1
- warp/examples/benchmarks/benchmark_interop_torch.py +158 -0
- warp/examples/benchmarks/benchmark_launches.py +1 -1
- warp/examples/core/example_cupy.py +78 -0
- warp/examples/fem/example_apic_fluid.py +17 -36
- warp/examples/fem/example_burgers.py +9 -18
- warp/examples/fem/example_convection_diffusion.py +7 -17
- warp/examples/fem/example_convection_diffusion_dg.py +27 -47
- warp/examples/fem/example_deformed_geometry.py +11 -22
- warp/examples/fem/example_diffusion.py +7 -18
- warp/examples/fem/example_diffusion_3d.py +24 -28
- warp/examples/fem/example_diffusion_mgpu.py +7 -14
- warp/examples/fem/example_magnetostatics.py +190 -0
- warp/examples/fem/example_mixed_elasticity.py +111 -80
- warp/examples/fem/example_navier_stokes.py +30 -34
- warp/examples/fem/example_nonconforming_contact.py +290 -0
- warp/examples/fem/example_stokes.py +17 -32
- warp/examples/fem/example_stokes_transfer.py +12 -21
- warp/examples/fem/example_streamlines.py +350 -0
- warp/examples/fem/utils.py +936 -0
- warp/fabric.py +5 -2
- warp/fem/__init__.py +13 -3
- warp/fem/cache.py +161 -11
- warp/fem/dirichlet.py +37 -28
- warp/fem/domain.py +105 -14
- warp/fem/field/__init__.py +14 -3
- warp/fem/field/field.py +454 -11
- warp/fem/field/nodal_field.py +33 -18
- warp/fem/geometry/deformed_geometry.py +50 -15
- warp/fem/geometry/hexmesh.py +12 -24
- warp/fem/geometry/nanogrid.py +106 -31
- warp/fem/geometry/quadmesh_2d.py +6 -11
- warp/fem/geometry/tetmesh.py +103 -61
- warp/fem/geometry/trimesh_2d.py +98 -47
- warp/fem/integrate.py +231 -186
- warp/fem/operator.py +14 -9
- warp/fem/quadrature/pic_quadrature.py +35 -9
- warp/fem/quadrature/quadrature.py +119 -32
- warp/fem/space/basis_space.py +98 -22
- warp/fem/space/collocated_function_space.py +3 -1
- warp/fem/space/function_space.py +7 -2
- warp/fem/space/grid_2d_function_space.py +3 -3
- warp/fem/space/grid_3d_function_space.py +4 -4
- warp/fem/space/hexmesh_function_space.py +3 -2
- warp/fem/space/nanogrid_function_space.py +12 -14
- warp/fem/space/partition.py +45 -47
- warp/fem/space/restriction.py +19 -16
- warp/fem/space/shape/cube_shape_function.py +91 -3
- warp/fem/space/shape/shape_function.py +7 -0
- warp/fem/space/shape/square_shape_function.py +32 -0
- warp/fem/space/shape/tet_shape_function.py +11 -7
- warp/fem/space/shape/triangle_shape_function.py +10 -1
- warp/fem/space/topology.py +116 -42
- warp/fem/types.py +8 -1
- warp/fem/utils.py +301 -83
- warp/native/array.h +16 -0
- warp/native/builtin.h +0 -15
- warp/native/cuda_util.cpp +14 -6
- warp/native/exports.h +1348 -1308
- warp/native/quat.h +79 -0
- warp/native/rand.h +27 -4
- warp/native/sparse.cpp +83 -81
- warp/native/sparse.cu +381 -453
- warp/native/vec.h +64 -0
- warp/native/volume.cpp +40 -49
- warp/native/volume_builder.cu +2 -3
- warp/native/volume_builder.h +12 -17
- warp/native/warp.cu +3 -3
- warp/native/warp.h +69 -59
- warp/render/render_opengl.py +17 -9
- warp/sim/articulation.py +117 -17
- warp/sim/collide.py +35 -29
- warp/sim/model.py +123 -18
- warp/sim/render.py +3 -1
- warp/sparse.py +867 -203
- warp/stubs.py +312 -541
- warp/tape.py +29 -1
- warp/tests/disabled_kinematics.py +1 -1
- warp/tests/test_adam.py +1 -1
- warp/tests/test_arithmetic.py +1 -1
- warp/tests/test_array.py +58 -1
- warp/tests/test_array_reduce.py +1 -1
- warp/tests/test_async.py +1 -1
- warp/tests/test_atomic.py +1 -1
- warp/tests/test_bool.py +1 -1
- warp/tests/test_builtins_resolution.py +1 -1
- warp/tests/test_bvh.py +6 -1
- warp/tests/test_closest_point_edge_edge.py +1 -1
- warp/tests/test_codegen.py +66 -1
- warp/tests/test_compile_consts.py +1 -1
- warp/tests/test_conditional.py +1 -1
- warp/tests/test_copy.py +1 -1
- warp/tests/test_ctypes.py +1 -1
- warp/tests/test_dense.py +1 -1
- warp/tests/test_devices.py +1 -1
- warp/tests/test_dlpack.py +1 -1
- warp/tests/test_examples.py +33 -4
- warp/tests/test_fabricarray.py +5 -2
- warp/tests/test_fast_math.py +1 -1
- warp/tests/test_fem.py +213 -6
- warp/tests/test_fp16.py +1 -1
- warp/tests/test_func.py +1 -1
- warp/tests/test_future_annotations.py +90 -0
- warp/tests/test_generics.py +1 -1
- warp/tests/test_grad.py +1 -1
- warp/tests/test_grad_customs.py +1 -1
- warp/tests/test_grad_debug.py +247 -0
- warp/tests/test_hash_grid.py +6 -1
- warp/tests/test_implicit_init.py +354 -0
- warp/tests/test_import.py +1 -1
- warp/tests/test_indexedarray.py +1 -1
- warp/tests/test_intersect.py +1 -1
- warp/tests/test_jax.py +1 -1
- warp/tests/test_large.py +1 -1
- warp/tests/test_launch.py +1 -1
- warp/tests/test_lerp.py +1 -1
- warp/tests/test_linear_solvers.py +1 -1
- warp/tests/test_lvalue.py +1 -1
- warp/tests/test_marching_cubes.py +5 -2
- warp/tests/test_mat.py +34 -35
- warp/tests/test_mat_lite.py +2 -1
- warp/tests/test_mat_scalar_ops.py +1 -1
- warp/tests/test_math.py +1 -1
- warp/tests/test_matmul.py +20 -16
- warp/tests/test_matmul_lite.py +1 -1
- warp/tests/test_mempool.py +1 -1
- warp/tests/test_mesh.py +5 -2
- warp/tests/test_mesh_query_aabb.py +1 -1
- warp/tests/test_mesh_query_point.py +1 -1
- warp/tests/test_mesh_query_ray.py +1 -1
- warp/tests/test_mlp.py +1 -1
- warp/tests/test_model.py +1 -1
- warp/tests/test_module_hashing.py +77 -1
- warp/tests/test_modules_lite.py +1 -1
- warp/tests/test_multigpu.py +1 -1
- warp/tests/test_noise.py +1 -1
- warp/tests/test_operators.py +1 -1
- warp/tests/test_options.py +1 -1
- warp/tests/test_overwrite.py +542 -0
- warp/tests/test_peer.py +1 -1
- warp/tests/test_pinned.py +1 -1
- warp/tests/test_print.py +1 -1
- warp/tests/test_quat.py +15 -1
- warp/tests/test_rand.py +1 -1
- warp/tests/test_reload.py +1 -1
- warp/tests/test_rounding.py +1 -1
- warp/tests/test_runlength_encode.py +1 -1
- warp/tests/test_scalar_ops.py +95 -0
- warp/tests/test_sim_grad.py +1 -1
- warp/tests/test_sim_kinematics.py +1 -1
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +82 -15
- warp/tests/test_spatial.py +1 -1
- warp/tests/test_special_values.py +2 -11
- warp/tests/test_streams.py +11 -1
- warp/tests/test_struct.py +1 -1
- warp/tests/test_tape.py +1 -1
- warp/tests/test_torch.py +194 -1
- warp/tests/test_transient_module.py +1 -1
- warp/tests/test_types.py +1 -1
- warp/tests/test_utils.py +1 -1
- warp/tests/test_vec.py +15 -63
- warp/tests/test_vec_lite.py +2 -1
- warp/tests/test_vec_scalar_ops.py +65 -1
- warp/tests/test_verify_fp.py +1 -1
- warp/tests/test_volume.py +28 -2
- warp/tests/test_volume_write.py +1 -1
- warp/tests/unittest_serial.py +1 -1
- warp/tests/unittest_suites.py +9 -1
- warp/tests/walkthrough_debug.py +1 -1
- warp/thirdparty/unittest_parallel.py +2 -5
- warp/torch.py +103 -41
- warp/types.py +341 -224
- warp/utils.py +11 -2
- {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/METADATA +99 -46
- warp_lang-1.3.0.dist-info/RECORD +368 -0
- warp/examples/fem/bsr_utils.py +0 -378
- warp/examples/fem/mesh_utils.py +0 -133
- warp/examples/fem/plot_utils.py +0 -292
- warp_lang-1.2.2.dist-info/RECORD +0 -359
- {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/LICENSE.md +0 -0
- {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/top_level.txt +0 -0
warp/codegen.py
CHANGED
@@ -10,13 +10,14 @@ from __future__ import annotations
 import ast
 import builtins
 import ctypes
+import functools
 import inspect
 import math
 import re
 import sys
 import textwrap
 import types
-from typing import Any, Callable, Dict, Mapping
+from typing import Any, Callable, Dict, Mapping, Optional, Sequence
 
 import warp.config
 from warp.types import *
@@ -84,17 +85,108 @@ comparison_chain_strings = [
 ]
 
 
+def values_check_equal(a, b):
+    if isinstance(a, Sequence) and isinstance(b, Sequence):
+        if len(a) != len(b):
+            return False
+
+        return all(x == y for x, y in zip(a, b))
+
+    return a == b
+
+
 def op_str_is_chainable(op: str) -> builtins.bool:
     return op in comparison_chain_strings
 
 
+def get_closure_cell_contents(obj):
+    """Retrieve a closure's cell contents or `None` if it's empty."""
+    try:
+        return obj.cell_contents
+    except ValueError:
+        pass
+
+    return None
+
+
+def eval_annotations(annotations: Mapping[str, Any], obj: Any) -> Mapping[str, Any]:
+    """Un-stringize annotations caused by `from __future__ import annotations` of PEP 563."""
+    # Implementation backported from `inspect.get_annotations()` for Python 3.9 and older.
+    if not annotations:
+        return {}
+
+    if not any(isinstance(x, str) for x in annotations.values()):
+        # No annotation to un-stringize.
+        return annotations
+
+    if isinstance(obj, type):
+        # class
+        globals = {}
+        module_name = getattr(obj, "__module__", None)
+        if module_name:
+            module = sys.modules.get(module_name, None)
+            if module:
+                globals = getattr(module, "__dict__", {})
+        locals = dict(vars(obj))
+        unwrap = obj
+    elif isinstance(obj, types.ModuleType):
+        # module
+        globals = obj.__dict__
+        locals = {}
+        unwrap = None
+    elif callable(obj):
+        # function
+        globals = getattr(obj, "__globals__", {})
+        # Capture the variables from the surrounding scope.
+        closure_vars = zip(
+            obj.__code__.co_freevars, tuple(get_closure_cell_contents(x) for x in (obj.__closure__ or ()))
+        )
+        locals = {k: v for k, v in closure_vars if v is not None}
+        unwrap = obj
+    else:
+        raise TypeError(f"{obj!r} is not a module, class, or callable.")
+
+    if unwrap is not None:
+        while True:
+            if hasattr(unwrap, "__wrapped__"):
+                unwrap = unwrap.__wrapped__
+                continue
+            if isinstance(unwrap, functools.partial):
+                unwrap = unwrap.func
+                continue
+            break
+        if hasattr(unwrap, "__globals__"):
+            globals = unwrap.__globals__
+
+    # "Inject" type parameters into the local namespace
+    # (unless they are shadowed by assignments *in* the local namespace),
+    # as a way of emulating annotation scopes when calling `eval()`
+    type_params = getattr(obj, "__type_params__", ())
+    if type_params:
+        locals = {param.__name__: param for param in type_params} | locals
+
+    return {k: v if not isinstance(v, str) else eval(v, globals, locals) for k, v in annotations.items()}
+
+
 def get_annotations(obj: Any) -> Mapping[str, Any]:
-    """
+    """Same as `inspect.get_annotations()` but always returning un-stringized annotations."""
+    # This backports `inspect.get_annotations()` for Python 3.9 and older.
     # See https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
     if isinstance(obj, type):
-        return obj.__dict__.get("__annotations__", {})
+        annotations = obj.__dict__.get("__annotations__", {})
+    else:
+        annotations = getattr(obj, "__annotations__", {})
+
+    # Evaluating annotations can be done using the `eval_str` parameter with
+    # the official function from the `inspect` module.
+    return eval_annotations(annotations, obj)
 
-    return getattr(obj, "__annotations__", {})
+
+def get_full_arg_spec(func: Callable) -> inspect.FullArgSpec:
+    """Same as `inspect.getfullargspec()` but always returning un-stringized annotations."""
+    # See https://docs.python.org/3/howto/annotations.html#manually-un-stringizing-stringized-annotations
+    spec = inspect.getfullargspec(func)
+    return spec._replace(annotations=eval_annotations(spec.annotations, func))
 
 
 def struct_instance_repr_recursive(inst: StructInstance, depth: int) -> str:
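These annotation helpers address PEP 563: under `from __future__ import annotations`, every annotation is stored as a string and must be evaluated back into a real type before Warp can generate code for it (the new `warp/tests/test_future_annotations.py` exercises this). A minimal standalone sketch of the behavior being backported, assuming Python 3.10+ for `inspect.get_annotations`:

    import inspect

    src = "from __future__ import annotations\ndef f(x: int, y: float): pass\n"
    ns = {}
    exec(compile(src, "<demo>", "exec"), ns)

    print(ns["f"].__annotations__)  # {'x': 'int', 'y': 'float'} -- stringized
    print(inspect.get_annotations(ns["f"], eval_str=True))  # {'x': <class 'int'>, 'y': <class 'float'>}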
@@ -490,6 +582,14 @@ class Var:
         self.constant = constant
         self.prefix = prefix
 
+        # records whether this Var has been read from in a kernel function (array only)
+        self.is_read = False
+        # records whether this Var has been written to in a kernel function (array only)
+        self.is_write = False
+
+        # used to associate a view array Var with its parent array Var
+        self.parent = None
+
     def __str__(self):
         return self.label
 
@@ -532,6 +632,42 @@ class Var:
     def emit_adj(self):
         return self.emit("adj")
 
+    def mark_read(self):
+        """Marks this Var as having been read from in a kernel (array only)."""
+        if not is_array(self.type):
+            return
+
+        self.is_read = True
+
+        # recursively update all parent states
+        parent = self.parent
+        while parent is not None:
+            parent.is_read = True
+            parent = parent.parent
+
+    def mark_write(self, **kwargs):
+        """Marks this Var as having been written to in a kernel (array only)."""
+        if not is_array(self.type):
+            return
+
+        # detect if we are writing to an array after reading from it within the same kernel
+        if self.is_read and warp.config.verify_autograd_array_access:
+            if "kernel_name" in kwargs and "filename" in kwargs and "lineno" in kwargs:
+                print(
+                    f"Warning: Array passed to argument {self.label} in kernel {kwargs['kernel_name']} at {kwargs['filename']}:{kwargs['lineno']} is being written to after it has been read from within the same kernel. This may corrupt gradient computation in the backward pass."
+                )
+            else:
+                print(
+                    f"Warning: Array {self} is being written to after it has been read from within the same kernel. This may corrupt gradient computation in the backward pass."
+                )
+        self.is_write = True
+
+        # recursively update all parent states
+        parent = self.parent
+        while parent is not None:
+            parent.is_write = True
+            parent = parent.parent
+
 
 class Block:
     # Represents a basic block of instructions, e.g.: list
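These flags drive the new array overwrite detection: with `wp.config.verify_autograd_array_access` enabled, writing to an array that was previously read within the same kernel prints the warning above, since overwriting inputs can corrupt the values replayed during the backward pass. A rough sketch of the kind of kernel that gets flagged (kernel and array names are illustrative, and the flag should be set before the kernel module is built):

    import warp as wp

    wp.config.verify_autograd_array_access = True

    @wp.kernel
    def scale(x: wp.array(dtype=float)):
        tid = wp.tid()
        v = x[tid]        # read from x: marks x as read
        x[tid] = v * 2.0  # write after read: prints the overwrite warning

    x = wp.array([1.0, 2.0, 3.0], dtype=float, requires_grad=True)
    wp.launch(scale, dim=3, inputs=[x])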
@@ -547,6 +683,91 @@ class Block:
         self.vars = []
 
 
+def apply_defaults(
+    bound_args: inspect.BoundArguments,
+    values: Mapping[str, Any],
+):
+    # Similar to Python's `inspect.BoundArguments.apply_defaults()`
+    # but with the possibility to pass an augmented set of default values.
+    arguments = bound_args.arguments
+    new_arguments = []
+    for name in bound_args._signature.parameters.keys():
+        try:
+            new_arguments.append((name, arguments[name]))
+        except KeyError:
+            if name in values:
+                new_arguments.append((name, values[name]))
+
+    bound_args.arguments = dict(new_arguments)
+
+
+def func_match_args(func, arg_types, kwarg_types):
+    try:
+        # Try to bind the given arguments to the function's signature.
+        # This is not checking whether the argument types are matching,
+        # rather it's just assigning each argument to the corresponding
+        # function parameter.
+        bound_arg_types = func.signature.bind(*arg_types, **kwarg_types)
+    except TypeError:
+        return False
+
+    # Populate the bound arguments with any default values.
+    default_arg_types = {
+        k: None if v is None else get_arg_type(v)
+        for k, v in func.defaults.items()
+        if k not in bound_arg_types.arguments
+    }
+    apply_defaults(bound_arg_types, default_arg_types)
+    bound_arg_types = tuple(bound_arg_types.arguments.values())
+
+    # Check the given argument types against the ones defined on the function.
+    for bound_arg_type, func_arg_type in zip(bound_arg_types, func.input_types.values()):
+        # Let the `value_func` callback infer the type.
+        if bound_arg_type is None:
+            continue
+
+        # if arg type registered as Any, treat as
+        # template allowing any type to match
+        if func_arg_type == Any:
+            continue
+
+        # handle function refs as a special case
+        if func_arg_type == Callable and isinstance(bound_arg_type, warp.context.Function):
+            continue
+
+        # check arg type matches input variable type
+        if not types_equal(func_arg_type, strip_reference(bound_arg_type), match_generic=True):
+            return False
+
+    return True
+
+
+def get_arg_type(arg: Union[Var, Any]):
+    if isinstance(arg, Sequence):
+        return tuple(get_arg_type(x) for x in arg)
+
+    if isinstance(arg, (type, warp.context.Function)):
+        return arg
+
+    if isinstance(arg, Var):
+        return arg.type
+
+    return type(arg)
+
+
+def get_arg_value(arg: Union[Var, Any]):
+    if isinstance(arg, Sequence):
+        return tuple(get_arg_value(x) for x in arg)
+
+    if isinstance(arg, (type, warp.context.Function)):
+        return arg
+
+    if isinstance(arg, Var):
+        return arg.constant
+
+    return arg
+
 class Adjoint:
     # Source code transformer, this class takes a Python function and
     # generates forward and backward SSA forms of the function instructions
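The resolution helpers above lean on Python's own binding rules: `func.signature.bind()` assigns positional and keyword arguments to parameters exactly as a Python call would, and `apply_defaults()` then fills the holes from an augmented default set. A small stdlib-only sketch of the same mechanism (names are illustrative):

    import inspect

    def f(a, b, scale=2.0):
        return (a + b) * scale

    bound = inspect.signature(f).bind(1, b=3)  # raises TypeError if the call can't bind
    bound.apply_defaults()                     # stdlib counterpart of apply_defaults() above
    print(bound.arguments)                     # {'a': 1, 'b': 3, 'scale': 2.0}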
@@ -605,7 +826,7 @@ class Adjoint:
         adj.custom_reverse_num_input_args = custom_reverse_num_input_args
 
         # parse argument types
-        argspec = inspect.getfullargspec(func)
+        argspec = get_full_arg_spec(func)
 
         # ensure all arguments are annotated
         if overload_annotations is None:
@@ -646,6 +867,11 @@ class Adjoint:
 
     # generate function ssa form and adjoint
     def build(adj, builder, default_builder_options=None):
+        # arg Var read/write flags are held during module rebuilds, so we reset here even when skipping a build
+        for arg in adj.args:
+            arg.is_read = False
+            arg.is_write = False
+
         if adj.skip_build:
             return
 
@@ -682,15 +908,11 @@ class Adjoint:
         # recursively evaluate function body
         try:
             adj.eval(adj.tree.body[0])
-        except Exception as e:
+        except Exception:
            try:
-                if isinstance(e, KeyError) and getattr(e.args[0], "__module__", None) == "ast":
-                    msg = f'Syntax error: unsupported construct "ast.{e.args[0].__name__}"'
-                else:
-                    msg = "Error"
                 lineno = adj.lineno + adj.fun_lineno
                 line = adj.source_lines[adj.lineno]
-                msg
+                msg = f'Error while parsing function "{adj.fun_name}" at {adj.filename}:{lineno}:\n{line}\n'
                 ex, data, traceback = sys.exc_info()
                 e = ex(";".join([msg] + [str(a) for a in data.args])).with_traceback(traceback)
             finally:
@@ -808,6 +1030,20 @@ class Adjoint:
 
         return v
 
+    def register_var(adj, var):
+        # We sometimes initialize `Var` instances that might be thrown away
+        # afterwards, so this method allows to defer their registration among
+        # the list of primal vars until later on, instead of registering them
+        # immediately if we were to use `adj.add_var()` or `adj.add_constant()`.
+
+        if isinstance(var, (Reference, warp.context.Function)):
+            return var
+
+        if var.label is None:
+            return adj.add_var(var.type, var.constant)
+
+        return var
+
     # append a statement to the forward pass
     def add_forward(adj, statement, replay=None, skip_replay=False):
         adj.blocks[-1].body_forward.append(adj.indentation + statement)
@@ -873,12 +1109,10 @@ class Adjoint:
 
         return output
 
-    def resolve_func(adj, func, args, kwds, templates, min_outputs):
-        arg_types = [strip_reference(a.type) for a in args if not isinstance(a, warp.context.Function)]
-
+    def resolve_func(adj, func, arg_types, kwarg_types, min_outputs):
         if not func.is_builtin():
             # user-defined function
-            overload = func.get_overload(arg_types)
+            overload = func.get_overload(arg_types, kwarg_types)
             if overload is not None:
                 return overload
         else:
@@ -888,88 +1122,89 @@ class Adjoint:
                 # skip type checking for variadic functions
                 if not f.variadic:
                     # check argument counts are compatible (may be some default args)
-                    if len(f.input_types) < len(args):
+                    if len(f.input_types) < len(arg_types) + len(kwarg_types):
                         continue
 
-                    def match_args(args, f):
-                        # check argument types equal
-                        for i, (arg_name, arg_type) in enumerate(f.input_types.items()):
-                            # if arg type registered as Any, treat as
-                            # template allowing any type to match
-                            if arg_type == Any:
-                                continue
-
-                            # handle function refs as a special case
-                            if arg_type == Callable and type(args[i]) is warp.context.Function:
-                                continue
-
-                            if arg_type == Reference and is_reference(args[i].type):
-                                continue
-
-                            # look for default values for missing args
-                            if i >= len(args):
-                                if arg_name not in f.defaults:
-                                    return False
-                            else:
-                                # otherwise check arg type matches input variable type
-                                if not types_equal(arg_type, strip_reference(args[i].type), match_generic=True):
-                                    return False
-
-                        return True
-
-                    if not match_args(args, f):
+                    if not func_match_args(f, arg_types, kwarg_types):
                         continue
 
                 # check output dimensions match expectations
                 if min_outputs:
-                    try:
-                        value_type = f.value_func(args, kwds, templates)
-                        if not hasattr(value_type, "__len__") or len(value_type) != min_outputs:
-                            continue
-                    except Exception:
-                        # value func may fail if the user has given
-                        # incorrect args, so we need to catch this
+                    if not isinstance(f.value_type, Sequence) or len(f.value_type) != min_outputs:
                         continue
 
                 # found a match, use it
                 return f
 
         # unresolved function, report error
-        arg_types = []
+        arg_type_reprs = []
 
-        for x in args:
-            if isinstance(x, Var):
+        for x in arg_types:
+            if isinstance(x, warp.context.Function):
+                arg_type_reprs.append("function")
+            else:
                 # shorten Warp primitive type names
-                if isinstance(x.type, list):
-                    if len(x.type) != 1:
+                if isinstance(x, Sequence):
+                    if len(x) != 1:
                         raise WarpCodegenError("Argument must not be the result from a multi-valued function")
-                    arg_type = x.type[0]
+                    arg_type = x[0]
                 else:
-                    arg_type = x.type
+                    arg_type = x
 
-                arg_types.append(type_repr(arg_type))
-
-            if isinstance(x, warp.context.Function):
-                arg_types.append("function")
+                arg_type_reprs.append(type_repr(arg_type))
 
         raise WarpCodegenError(
-            f"Couldn't find function overload for '{func.key}' that matched inputs with types: [{', '.join(
+            f"Couldn't find function overload for '{func.key}' that matched inputs with types: [{', '.join(arg_type_reprs)}]"
        )
 
-    def add_call(adj, func, args,
-
-
+    def add_call(adj, func, args, kwargs, type_args, min_outputs=None):
+        # Extract the types and values passed as arguments to the function call.
+        arg_types = tuple(strip_reference(get_arg_type(x)) for x in args)
+        kwarg_types = {k: strip_reference(get_arg_type(v)) for k, v in kwargs.items()}
+
+        # Resolve the exact function signature among any existing overload.
+        func = adj.resolve_func(func, arg_types, kwarg_types, min_outputs)
+
+        # Bind the positional and keyword arguments to the function's signature
+        # in order to process them as Python does it.
+        bound_args = func.signature.bind(*args, **kwargs)
+
+        # Type args are the "compile time" argument values we get from codegen.
+        # For example, when calling `wp.vec3f(...)` from within a kernel,
+        # this translates in fact to calling the `vector()` built-in augmented
+        # with the type args `length=3, dtype=float`.
+        # Eventually, these need to be passed to the underlying C++ function,
+        # so we update the arguments with the type args here.
+        if type_args:
+            for arg in type_args:
+                if arg in bound_args.arguments:
+                    # In case of conflict, ideally we'd throw an error since
+                    # what comes from codegen should be the source of truth
+                    # and users also passing the same value as an argument
+                    # is redundant (e.g.: `wp.mat22(shape=(2, 2))`).
+                    # However, for backward compatibility, we allow that form
+                    # as long as the values are equal.
+                    if values_check_equal(get_arg_value(bound_args.arguments[arg]), type_args[arg]):
+                        continue
 
-
+                    raise RuntimeError(
+                        f"Remove the extraneous `{arg}` parameter "
+                        f"when calling the templated version of "
+                        f"`wp.{func.native_func}()`"
+                    )
 
-
-
-
-
-
-
-
-
+            type_vars = {k: Var(None, type=type(v), constant=v) for k, v in type_args.items()}
+            apply_defaults(bound_args, type_vars)
+
+        if func.defaults:
+            default_vars = {
+                k: Var(None, type=type(v), constant=v)
+                for k, v in func.defaults.items()
+                if k not in bound_args.arguments and v is not None
+            }
+            apply_defaults(bound_args, default_vars)
+
+        bound_args = bound_args.arguments
 
         # if it is a user-function then build it recursively
         if not func.is_builtin() and func not in adj.builder.functions:
@@ -983,23 +1218,38 @@ class Adjoint:
         if func.custom_replay_func:
             adj.builder.deferred_functions.append(func.custom_replay_func)
 
-        #
-
-
+        # Resolve the return value based on the types and values of the given arguments.
+        bound_arg_types = {k: get_arg_type(v) for k, v in bound_args.items()}
+        bound_arg_values = {k: get_arg_value(v) for k, v in bound_args.items()}
+        return_type = func.value_func(
+            {k: strip_reference(v) for k, v in bound_arg_types.items()},
+            bound_arg_values,
+        )
+
+        if func.dispatch_func is not None:
+            # If we have a built-in that requires special handling to dispatch
+            # the arguments to the underlying C++ function, then we can resolve
+            # these using the `dispatch_func`. Since this is only called from
+            # within codegen, we pass it directly `codegen.Var` objects,
+            # which allows for some more advanced resolution to be performed,
+            # for example by checking whether an argument corresponds to
+            # a literal value or references a variable.
 
-
-
+            func_args, template_args = func.dispatch_func(func.input_types, return_type, bound_args)
+        else:
+            func_args = tuple(bound_args.values())
+            template_args = ()
 
-
+        func_args = tuple(adj.register_var(x) for x in func_args)
+        func_name = compute_type_str(func.native_func, template_args)
+        use_initializer_list = func.initializer_list_func(bound_args, return_type)
 
-
-
-
-
-
-        )
-        for i, a in enumerate(args)
-        ]
+        fwd_args = []
+        for func_arg in func_args:
+            if not isinstance(func_arg, (Reference, warp.context.Function)):
+                func_arg = adj.load(func_arg)
+
+            fwd_args.append(strip_reference(func_arg))
 
         if return_type is None:
             # handles expression (zero output) functions, e.g.: void do_something();
@@ -1008,24 +1258,24 @@ class Adjoint:
             output_list = []
 
             forward_call = (
-                f"{func.namespace}{func_name}({adj.format_forward_call_args(
+                f"{func.namespace}{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"
             )
             replay_call = forward_call
             if func.custom_replay_func is not None or func.replay_snippet is not None:
-                replay_call = f"{func.namespace}replay_{func_name}({adj.format_forward_call_args(
+                replay_call = f"{func.namespace}replay_{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"
 
-        elif not isinstance(return_type,
+        elif not isinstance(return_type, Sequence) or len(return_type) == 1:
             # handle simple function (one output)
 
-            if isinstance(return_type,
+            if isinstance(return_type, Sequence):
                 return_type = return_type[0]
             output = adj.add_var(return_type)
             output_list = [output]
 
-            forward_call = f"var_{output} = {func.namespace}{func_name}({adj.format_forward_call_args(
+            forward_call = f"var_{output} = {func.namespace}{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"
             replay_call = forward_call
             if func.custom_replay_func is not None:
-                replay_call = f"var_{output} = {func.namespace}replay_{func_name}({adj.format_forward_call_args(
+                replay_call = f"var_{output} = {func.namespace}replay_{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"
 
         else:
             # handle multiple value functions
@@ -1034,7 +1284,7 @@ class Adjoint:
             output_list = output
 
             forward_call = (
-                f"{func.namespace}{func_name}({adj.format_forward_call_args(
+                f"{func.namespace}{func_name}({adj.format_forward_call_args(fwd_args + output, use_initializer_list)});"
             )
             replay_call = forward_call
 
@@ -1043,13 +1293,14 @@ class Adjoint:
         else:
             adj.add_forward(forward_call, replay=replay_call)
 
-        if not func.missing_grad and len(
+        if not func.missing_grad and len(func_args):
+            adj_args = tuple(strip_reference(x) for x in func_args)
             reverse_has_output_args = (
                 func.require_original_output_arg or len(output_list) > 1
             ) and func.custom_grad_func is None
             arg_str = adj.format_reverse_call_args(
-
-
+                fwd_args,
+                adj_args,
                 output_list,
                 use_initializer_list,
                 has_output_args=reverse_has_output_args,
@@ -1061,12 +1312,9 @@ class Adjoint:
 
         return output
 
-    def add_builtin_call(adj, func_name, args, min_outputs=None
-        if templates is None:
-            templates = []
-
+    def add_builtin_call(adj, func_name, args, min_outputs=None):
         func = warp.context.builtin_functions[func_name]
-        return adj.add_call(func, args,
+        return adj.add_call(func, args, {}, {}, min_outputs=min_outputs)
 
     def add_return(adj, var):
         if var is None or len(var) == 0:
@@ -1505,7 +1753,24 @@ class Adjoint:
 
     def emit_BinOp(adj, node):
         # evaluate binary operator arguments
+
+        if warp.config.verify_autograd_array_access:
+            # array overwrite tracking: in-place operators are a special case
+            # x[tid] = x[tid] + 1 is a read followed by a write, but we only want to record the write
+            # so we save the current arg read flags and restore them after lhs eval
+            is_read_states = []
+            for arg in adj.args:
+                is_read_states.append(arg.is_read)
+
+        # evaluate lhs binary operator argument
         left = adj.eval(node.left)
+
+        if warp.config.verify_autograd_array_access:
+            # restore arg read flags
+            for i, arg in enumerate(adj.args):
+                arg.is_read = is_read_states[i]
+
+        # evaluate rhs binary operator argument
         right = adj.eval(node.right)
 
         name = builtin_operators[type(node.op)]
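Saving the read flags before evaluating a binary op's left operand and restoring them afterwards means an in-place update is recorded as a write only, so the common pattern below does not by itself raise the read-after-write warning (illustrative kernel):

    @wp.kernel
    def increment(x: wp.array(dtype=float)):
        tid = wp.tid()
        x[tid] = x[tid] + 1.0  # recorded as a plain write, not a read followed by a write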
@@ -1569,6 +1834,9 @@ class Adjoint:
         # e.g.: wp.constant in the globals scope
         obj, _ = adj.resolve_static_expression(a)
 
+        if obj is None:
+            obj = adj.eval(a)
+
         if isinstance(obj, Var) and obj.constant is not None:
             obj = obj.constant
 
@@ -1728,13 +1996,40 @@ class Adjoint:
                 f"arguments to the function {adj.fun_name}, {adj.filename}:{lineno}:\n{line}\n"
             )
 
+    def resolve_arg(adj, arg):
+        # Always try to start with evaluating the argument since it can help
+        # detecting some issues such as global variables being accessed.
+        try:
+            var = adj.eval(arg)
+        except (WarpCodegenError, WarpCodegenKeyError) as e:
+            error = e
+        else:
+            error = None
+
+        # Check if we can resolve the argument as a static expression.
+        # If not, return the variable resulting from evaluating the argument.
+        expr, _ = adj.resolve_static_expression(arg)
+        if expr is None:
+            if error is not None:
+                raise error
+
+            return var
+
+        if isinstance(expr, (type, Var, warp.context.Function)):
+            return expr
+
+        return adj.add_constant(expr)
+
     def emit_Call(adj, node):
         adj.check_tid_in_func_error(node)
 
         # try and lookup function in globals by
         # resolving path (e.g.: module.submodule.attr)
         func, path = adj.resolve_static_expression(node.func)
-
+        if func is None:
+            func = adj.eval(node.func)
+
+        type_args = {}
 
         if not isinstance(func, warp.context.Function):
             attr = path[-1]
@@ -1747,7 +2042,6 @@ class Adjoint:
 
             # vector class type e.g.: wp.vec3f constructor
             if func is None and hasattr(caller, "_wp_generic_type_str_"):
-                templates = caller._wp_type_params_
                 func = warp.context.builtin_functions.get(caller._wp_constructor_)
 
             # scalar class type e.g.: wp.int8 constructor
@@ -1757,43 +2051,53 @@ class Adjoint:
             # struct constructor
             if func is None and isinstance(caller, Struct):
                 adj.builder.build_struct_recursive(caller)
-
+                if node.args or node.keywords:
+                    func = caller.value_constructor
+                else:
+                    func = caller.default_constructor
+
+            if hasattr(caller, "_wp_type_args_"):
+                type_args = caller._wp_type_args_
 
             if func is None:
                 raise WarpCodegenError(
                     f"Could not find function {'.'.join(path)} as a built-in or user-defined function. Note that user functions must be annotated with a @wp.func decorator to be called from a kernel."
                 )
 
-
-
-        #
+        # Check if any argument corresponds to an unsupported construct.
+        # Tuples are supported in the context of assigning multiple variables
+        # at once, but not in place of vectors when calling built-ins like
+        # `wp.length((1, 2, 3))`.
+        # Therefore, we need to catch this specific case here instead of
+        # more generally in `adj.eval()`.
         for arg in node.args:
-
-
-
-
-            def kwval(kw):
-                if isinstance(kw.value, ast.Num):
-                    return kw.value.n
-                elif isinstance(kw.value, ast.Tuple):
-                    arg_is_numeric, arg_values = zip(*(adj.eval_num(e) for e in kw.value.elts))
-                    if not all(arg_is_numeric):
-                        raise WarpCodegenError(
-                            f"All elements of the tuple keyword argument '{kw.name}' must be numeric constants, got '{arg_values}'"
-                        )
-                    return arg_values
-                else:
-                    return adj.resolve_static_expression(kw.value)[0]
-
-            kwds = {kw.arg: kwval(kw) for kw in node.keywords}
+            if isinstance(arg, ast.Tuple):
+                raise WarpCodegenError(
+                    "Tuple constructs are not supported in kernels. Use vectors like `wp.vec3()` instead."
+                )
 
         # get expected return count, e.g.: for multi-assignment
         min_outputs = None
         if hasattr(node, "expects"):
            min_outputs = node.expects
 
-        #
-
+        # Evaluate all positional and keyword arguments.
+        args = tuple(adj.resolve_arg(x) for x in node.args)
+        kwargs = {x.arg: adj.resolve_arg(x.value) for x in node.keywords}
+
+        if warp.config.verify_autograd_array_access:
+            # update arg read/write states according to what happens to that arg in the called function
+            if hasattr(func, "adj"):
+                for i, arg in enumerate(args):
+                    if func.adj.args[i].is_write:
+                        kernel_name = adj.fun_name
+                        filename = adj.filename
+                        lineno = adj.lineno + adj.fun_lineno
+                        arg.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
+                    if func.adj.args[i].is_read:
+                        arg.mark_read()
+
+        out = adj.add_call(func, args, kwargs, type_args, min_outputs=min_outputs)
         return out
 
     def emit_Index(adj, node):
@@ -1872,10 +2176,22 @@ class Adjoint:
             if len(indices) == target_type.ndim:
                 # handles array loads (where each dimension has an index specified)
                 out = adj.add_builtin_call("address", [target, *indices])
+
+                if warp.config.verify_autograd_array_access:
+                    target.mark_read()
+
             else:
                 # handles array views (fewer indices than dimensions)
                 out = adj.add_builtin_call("view", [target, *indices])
 
+                if warp.config.verify_autograd_array_access:
+                    # store reference to target Var to propagate downstream read/write state back to root arg Var
+                    out.parent = target
+
+                    # view arg inherits target Var's read/write states
+                    out.is_read = target.is_read
+                    out.is_write = target.is_write
+
         else:
             # handles non-array type indexing, e.g: vec3, mat33, etc
             out = adj.add_builtin_call("extract", [target, *indices])
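Indexing with fewer indices than dimensions produces a view, and the `parent` link lets read/write marks on the view propagate back to the root argument. An illustrative kernel (names are hypothetical):

    @wp.kernel
    def write_first_row(a: wp.array2d(dtype=float)):
        tid = wp.tid()
        row = a[0]      # view Var: its parent is the root array `a`
        row[tid] = 1.0  # marking the view written also marks `a` written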
@@ -1888,6 +2204,21 @@ class Adjoint:
 
         lhs = node.targets[0]
 
+        if not isinstance(lhs, ast.Tuple):
+            # Check if the rhs corresponds to an unsupported construct.
+            # Tuples are supported in the context of assigning multiple variables
+            # at once, but not for simple assignments like `x = (1, 2, 3)`.
+            # Therefore, we need to catch this specific case here instead of
+            # more generally in `adj.eval()`.
+            if isinstance(node.value, ast.List):
+                raise WarpCodegenError(
+                    "List constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead."
+                )
+            elif isinstance(node.value, ast.Tuple):
+                raise WarpCodegenError(
+                    "Tuple constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead."
+                )
+
         # handle the case where we are assigning multiple output variables
         if isinstance(lhs, ast.Tuple):
             # record the expected number of outputs on the node
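In practice this turns silently unsupported constructs into explicit errors. For example (illustrative kernels), the first definition below now fails with the tuple message when its module is built, and the fix is a fixed-size vector:

    @wp.kernel
    def broken():
        v = (1.0, 2.0, 3.0)  # WarpCodegenError: tuple constructs are not supported

    @wp.kernel
    def fixed():
        v = wp.vec3(1.0, 2.0, 3.0)  # use a vector instead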
@@ -1944,7 +2275,14 @@ class Adjoint:
         if is_array(target_type):
             adj.add_builtin_call("array_store", [target, *indices, rhs])
 
-
+            if warp.config.verify_autograd_array_access:
+                kernel_name = adj.fun_name
+                filename = adj.filename
+                lineno = adj.lineno + adj.fun_lineno
+
+                target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
+
+        elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
             if is_reference(target.type):
                 attr = adj.add_builtin_call("indexref", [target, *indices])
             else:
@@ -1961,7 +2299,7 @@ class Adjoint:
                 )
 
             else:
-                raise WarpCodegenError("Can only subscript assign array, vector, and matrix types")
+                raise WarpCodegenError("Can only subscript assign array, vector, quaternion, and matrix types")
 
         elif isinstance(lhs, ast.Name):
             # symbol name
@@ -2050,8 +2388,7 @@ class Adjoint:
 
     def emit_Tuple(adj, node):
         # LHS for expressions, such as i, j, k = 1, 2, 3
-        for elem in node.elts:
-            adj.eval(elem)
+        return tuple(adj.eval(x) for x in node.elts)
 
     def emit_Pass(adj, node):
         pass
@@ -2089,7 +2426,12 @@ class Adjoint:
         if hasattr(node, "lineno"):
             adj.set_lineno(node.lineno - 1)
 
-        emit_node = adj.node_visitors[type(node)]
+        try:
+            emit_node = adj.node_visitors[type(node)]
+        except KeyError as e:
+            type_name = type(node).__name__
+            namespace = "ast." if isinstance(node, ast.AST) else ""
+            raise WarpCodegenError(f"Construct `{namespace}{type_name}` not supported in kernels.") from e
 
         return emit_node(adj, node)
@@ -2120,18 +2462,18 @@ class Adjoint:
         vars_dict = {**adj.func.__globals__, **capturedvars}
 
         if path[0] in vars_dict:
-
+            expr = vars_dict[path[0]]
 
         # Support Warp types in kernels without the module suffix (e.g. v = vec3(0.0,0.2,0.4)):
         else:
-
+            expr = getattr(warp, path[0], None)
 
-        if
+        if expr:
             for i in range(1, len(path)):
-                if hasattr(
-
+                if hasattr(expr, path[i]):
+                    expr = getattr(expr, path[i])
 
-        return
+        return expr
 
     # Evaluates a static expression that does not depend on runtime values
     # if eval_types is True, try resolving the path using evaluated type information as well
@@ -2182,11 +2524,6 @@ class Adjoint:
         if captured_obj is not None:
             return captured_obj, path
 
-        # Still nothing found, maybe this is a predefined type attribute like `dtype`
-        if eval_types:
-            val = adj.eval(root_node)
-            return [val, path]
-
         return None, path
 
     # annotate generated code with the original source code line
@@ -2262,10 +2599,10 @@ cpu_module_header = """
 #define int(x) cast_int(x)
 #define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)
 
-#define builtin_tid1d() wp::tid(
-#define builtin_tid2d(x, y) wp::tid(x, y,
-#define builtin_tid3d(x, y, z) wp::tid(x, y, z,
-#define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w,
+#define builtin_tid1d() wp::tid(task_index)
+#define builtin_tid2d(x, y) wp::tid(x, y, task_index, dim)
+#define builtin_tid3d(x, y, z) wp::tid(x, y, z, task_index, dim)
+#define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, task_index, dim)
 
 """
 
@@ -2280,10 +2617,10 @@ cuda_module_header = """
 #define int(x) cast_int(x)
 #define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)
 
-#define builtin_tid1d() wp::tid(
-#define builtin_tid2d(x, y) wp::tid(x, y,
-#define builtin_tid3d(x, y, z) wp::tid(x, y, z,
-#define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w,
+#define builtin_tid1d() wp::tid(task_index)
+#define builtin_tid2d(x, y) wp::tid(x, y, task_index, dim)
+#define builtin_tid3d(x, y, z) wp::tid(x, y, z, task_index, dim)
+#define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, task_index, dim)
 
 """
 
@@ -2355,9 +2692,9 @@ cuda_kernel_template = """
 extern "C" __global__ void {name}_cuda_kernel_forward(
     {forward_args})
 {{
-    for (size_t
-
-
+    for (size_t task_index = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
+         task_index < dim.size;
+         task_index += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
     {{
         {forward_body} }}
 }}
@@ -2365,9 +2702,9 @@ extern "C" __global__ void {name}_cuda_kernel_forward(
 extern "C" __global__ void {name}_cuda_kernel_backward(
     {reverse_args})
 {{
-    for (size_t
-
-
+    for (size_t task_index = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
+         task_index < dim.size;
+         task_index += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
     {{
         {reverse_body} }}
 }}
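Both CUDA templates now use a grid-stride loop: each thread starts at its global index and advances by the total number of launched threads until it passes `dim.size`, so a fixed-size grid can cover arbitrarily large launches. A Python model of the indices a single thread processes (names are illustrative):

    def thread_task_indices(global_thread_id, total_threads, dim_size):
        task_index = global_thread_id
        while task_index < dim_size:
            yield task_index
            task_index += total_threads

    # e.g. thread 2 of 4 over 10 tasks covers indices [2, 6]
    print(list(thread_task_indices(2, 4, 10)))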
@@ -2396,10 +2733,8 @@ extern "C" {{
 WP_API void {name}_cpu_forward(
     {forward_args})
 {{
-    for (size_t i = 0; i < dim.size; ++i)
+    for (size_t task_index = 0; task_index < dim.size; ++task_index)
     {{
-        wp::s_threadIdx = i;
-
         {name}_cpu_kernel_forward(
             {forward_params});
     }}
@@ -2408,10 +2743,8 @@ WP_API void {name}_cpu_forward(
 WP_API void {name}_cpu_backward(
     {reverse_args})
 {{
-    for (size_t i = 0; i < dim.size; ++i)
+    for (size_t task_index = 0; task_index < dim.size; ++task_index)
    {{
-        wp::s_threadIdx = i;
-
         {name}_cpu_kernel_backward(
             {reverse_params});
     }}
@@ -2838,6 +3171,10 @@ def codegen_kernel(kernel, device, options):
     forward_args = ["wp::launch_bounds_t dim"]
     reverse_args = ["wp::launch_bounds_t dim"]
 
+    if device == "cpu":
+        forward_args.append("size_t task_index")
+        reverse_args.append("size_t task_index")
+
     # forward args
     for arg in adj.args:
         forward_args.append(arg.ctype() + " var_" + arg.label)
@@ -2886,7 +3223,7 @@ def codegen_module(kernel, device="cpu"):
 
     # build forward signature
     forward_args = ["wp::launch_bounds_t dim"]
-    forward_params = ["dim"]
+    forward_params = ["dim", "task_index"]
 
     for arg in adj.args:
         if hasattr(arg.type, "_wp_generic_type_str_"):