warp-lang 1.7.2rc1__py3-none-macosx_10_13_universal2.whl → 1.8.0__py3-none-macosx_10_13_universal2.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/__init__.py +3 -1
- warp/__init__.pyi +3489 -1
- warp/autograd.py +45 -122
- warp/bin/libwarp.dylib +0 -0
- warp/build.py +241 -252
- warp/build_dll.py +125 -26
- warp/builtins.py +1907 -384
- warp/codegen.py +257 -101
- warp/config.py +12 -1
- warp/constants.py +1 -1
- warp/context.py +657 -223
- warp/dlpack.py +1 -1
- warp/examples/benchmarks/benchmark_cloth.py +2 -2
- warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
- warp/examples/core/example_sample_mesh.py +1 -1
- warp/examples/core/example_spin_lock.py +93 -0
- warp/examples/core/example_work_queue.py +118 -0
- warp/examples/fem/example_adaptive_grid.py +5 -5
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +1 -1
- warp/examples/fem/example_convection_diffusion.py +9 -6
- warp/examples/fem/example_darcy_ls_optimization.py +489 -0
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion.py +2 -2
- warp/examples/fem/example_diffusion_3d.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_elastic_shape_optimization.py +387 -0
- warp/examples/fem/example_magnetostatics.py +5 -3
- warp/examples/fem/example_mixed_elasticity.py +5 -3
- warp/examples/fem/example_navier_stokes.py +11 -9
- warp/examples/fem/example_nonconforming_contact.py +5 -3
- warp/examples/fem/example_streamlines.py +8 -3
- warp/examples/fem/utils.py +9 -8
- warp/examples/interop/example_jax_ffi_callback.py +2 -2
- warp/examples/optim/example_drone.py +1 -1
- warp/examples/sim/example_cloth.py +1 -1
- warp/examples/sim/example_cloth_self_contact.py +48 -54
- warp/examples/tile/example_tile_block_cholesky.py +502 -0
- warp/examples/tile/example_tile_cholesky.py +2 -1
- warp/examples/tile/example_tile_convolution.py +1 -1
- warp/examples/tile/example_tile_filtering.py +1 -1
- warp/examples/tile/example_tile_matmul.py +1 -1
- warp/examples/tile/example_tile_mlp.py +2 -0
- warp/fabric.py +7 -7
- warp/fem/__init__.py +5 -0
- warp/fem/adaptivity.py +1 -1
- warp/fem/cache.py +152 -63
- warp/fem/dirichlet.py +2 -2
- warp/fem/domain.py +136 -6
- warp/fem/field/field.py +141 -99
- warp/fem/field/nodal_field.py +85 -39
- warp/fem/field/virtual.py +97 -52
- warp/fem/geometry/adaptive_nanogrid.py +91 -86
- warp/fem/geometry/closest_point.py +13 -0
- warp/fem/geometry/deformed_geometry.py +102 -40
- warp/fem/geometry/element.py +56 -2
- warp/fem/geometry/geometry.py +323 -22
- warp/fem/geometry/grid_2d.py +157 -62
- warp/fem/geometry/grid_3d.py +116 -20
- warp/fem/geometry/hexmesh.py +86 -20
- warp/fem/geometry/nanogrid.py +166 -86
- warp/fem/geometry/partition.py +59 -25
- warp/fem/geometry/quadmesh.py +86 -135
- warp/fem/geometry/tetmesh.py +47 -119
- warp/fem/geometry/trimesh.py +77 -270
- warp/fem/integrate.py +107 -52
- warp/fem/linalg.py +25 -58
- warp/fem/operator.py +124 -27
- warp/fem/quadrature/pic_quadrature.py +36 -14
- warp/fem/quadrature/quadrature.py +40 -16
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +66 -46
- warp/fem/space/basis_space.py +17 -4
- warp/fem/space/dof_mapper.py +1 -1
- warp/fem/space/function_space.py +2 -2
- warp/fem/space/grid_2d_function_space.py +4 -1
- warp/fem/space/hexmesh_function_space.py +4 -2
- warp/fem/space/nanogrid_function_space.py +3 -1
- warp/fem/space/partition.py +11 -2
- warp/fem/space/quadmesh_function_space.py +4 -1
- warp/fem/space/restriction.py +5 -2
- warp/fem/space/shape/__init__.py +10 -8
- warp/fem/space/tetmesh_function_space.py +4 -1
- warp/fem/space/topology.py +52 -21
- warp/fem/space/trimesh_function_space.py +4 -1
- warp/fem/utils.py +53 -8
- warp/jax.py +1 -2
- warp/jax_experimental/ffi.py +12 -17
- warp/jax_experimental/xla_ffi.py +37 -24
- warp/math.py +171 -1
- warp/native/array.h +99 -0
- warp/native/builtin.h +174 -31
- warp/native/coloring.cpp +1 -1
- warp/native/exports.h +118 -63
- warp/native/intersect.h +3 -3
- warp/native/mat.h +5 -10
- warp/native/mathdx.cpp +11 -5
- warp/native/matnn.h +1 -123
- warp/native/quat.h +28 -4
- warp/native/sparse.cpp +121 -258
- warp/native/sparse.cu +181 -274
- warp/native/spatial.h +305 -17
- warp/native/tile.h +583 -72
- warp/native/tile_radix_sort.h +1108 -0
- warp/native/tile_reduce.h +237 -2
- warp/native/tile_scan.h +240 -0
- warp/native/tuple.h +189 -0
- warp/native/vec.h +6 -16
- warp/native/warp.cpp +36 -4
- warp/native/warp.cu +574 -51
- warp/native/warp.h +47 -74
- warp/optim/linear.py +5 -1
- warp/paddle.py +7 -8
- warp/py.typed +0 -0
- warp/render/render_opengl.py +58 -29
- warp/render/render_usd.py +124 -61
- warp/sim/__init__.py +9 -0
- warp/sim/collide.py +252 -78
- warp/sim/graph_coloring.py +8 -1
- warp/sim/import_mjcf.py +4 -3
- warp/sim/import_usd.py +11 -7
- warp/sim/integrator.py +5 -2
- warp/sim/integrator_euler.py +1 -1
- warp/sim/integrator_featherstone.py +1 -1
- warp/sim/integrator_vbd.py +751 -320
- warp/sim/integrator_xpbd.py +1 -1
- warp/sim/model.py +265 -260
- warp/sim/utils.py +10 -7
- warp/sparse.py +303 -166
- warp/tape.py +52 -51
- warp/tests/cuda/test_conditional_captures.py +1046 -0
- warp/tests/cuda/test_streams.py +1 -1
- warp/tests/geometry/test_volume.py +2 -2
- warp/tests/interop/test_dlpack.py +9 -9
- warp/tests/interop/test_jax.py +0 -1
- warp/tests/run_coverage_serial.py +1 -1
- warp/tests/sim/disabled_kinematics.py +2 -2
- warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
- warp/tests/sim/test_collision.py +159 -51
- warp/tests/sim/test_coloring.py +15 -1
- warp/tests/test_array.py +254 -2
- warp/tests/test_array_reduce.py +2 -2
- warp/tests/test_atomic_cas.py +299 -0
- warp/tests/test_codegen.py +142 -19
- warp/tests/test_conditional.py +47 -1
- warp/tests/test_ctypes.py +0 -20
- warp/tests/test_devices.py +8 -0
- warp/tests/test_fabricarray.py +4 -2
- warp/tests/test_fem.py +58 -25
- warp/tests/test_func.py +42 -1
- warp/tests/test_grad.py +1 -1
- warp/tests/test_lerp.py +1 -3
- warp/tests/test_map.py +481 -0
- warp/tests/test_mat.py +1 -24
- warp/tests/test_quat.py +6 -15
- warp/tests/test_rounding.py +10 -38
- warp/tests/test_runlength_encode.py +7 -7
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +51 -2
- warp/tests/test_spatial.py +507 -1
- warp/tests/test_struct.py +2 -2
- warp/tests/test_tuple.py +265 -0
- warp/tests/test_types.py +2 -2
- warp/tests/test_utils.py +24 -18
- warp/tests/tile/test_tile.py +420 -1
- warp/tests/tile/test_tile_mathdx.py +518 -14
- warp/tests/tile/test_tile_reduce.py +213 -0
- warp/tests/tile/test_tile_shared_memory.py +130 -1
- warp/tests/tile/test_tile_sort.py +117 -0
- warp/tests/unittest_suites.py +4 -6
- warp/types.py +462 -308
- warp/utils.py +647 -86
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/RECORD +177 -165
- warp/stubs.py +0 -3381
- warp/tests/sim/test_xpbd.py +0 -399
- warp/tests/test_mlp.py +0 -282
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
warp/codegen.py
CHANGED
|
@@ -26,7 +26,7 @@ import re
|
|
|
26
26
|
import sys
|
|
27
27
|
import textwrap
|
|
28
28
|
import types
|
|
29
|
-
from typing import Any, Callable,
|
|
29
|
+
from typing import Any, Callable, ClassVar, Mapping, Sequence, get_args, get_origin
|
|
30
30
|
|
|
31
31
|
import warp.config
|
|
32
32
|
from warp.types import *
|
|
@@ -57,7 +57,7 @@ class WarpCodegenKeyError(KeyError):
|
|
|
57
57
|
|
|
58
58
|
|
|
59
59
|
# map operator to function name
|
|
60
|
-
builtin_operators:
|
|
60
|
+
builtin_operators: dict[type[ast.AST], str] = {}
|
|
61
61
|
|
|
62
62
|
# see https://www.ics.uci.edu/~pattis/ICS-31/lectures/opexp.pdf for a
|
|
63
63
|
# nice overview of python operators
|
|
@@ -321,7 +321,7 @@ class StructInstance:
|
|
|
321
321
|
# vector/matrix type, e.g. vec3
|
|
322
322
|
if value is None:
|
|
323
323
|
setattr(self._ctype, name, var.type())
|
|
324
|
-
elif
|
|
324
|
+
elif type(value) == var.type:
|
|
325
325
|
setattr(self._ctype, name, value)
|
|
326
326
|
else:
|
|
327
327
|
# conversion from list/tuple, ndarray, etc.
|
|
@@ -626,7 +626,7 @@ def compute_type_str(base_name, template_params):
|
|
|
626
626
|
|
|
627
627
|
return p.__name__
|
|
628
628
|
|
|
629
|
-
return f"{base_name}<{','.join(map(param2str, template_params))}>"
|
|
629
|
+
return f"{base_name}<{', '.join(map(param2str, template_params))}>"
|
|
630
630
|
|
|
631
631
|
|
|
632
632
|
class Var:
|
|
@@ -635,9 +635,9 @@ class Var:
|
|
|
635
635
|
label: str,
|
|
636
636
|
type: type,
|
|
637
637
|
requires_grad: builtins.bool = False,
|
|
638
|
-
constant:
|
|
638
|
+
constant: builtins.bool | None = None,
|
|
639
639
|
prefix: builtins.bool = True,
|
|
640
|
-
relative_lineno:
|
|
640
|
+
relative_lineno: int | None = None,
|
|
641
641
|
):
|
|
642
642
|
# convert built-in types to wp types
|
|
643
643
|
if type == float:
|
|
@@ -667,37 +667,44 @@ class Var:
|
|
|
667
667
|
def __str__(self):
|
|
668
668
|
return self.label
|
|
669
669
|
|
|
670
|
+
@staticmethod
|
|
671
|
+
def dtype_to_ctype(t: type) -> str:
|
|
672
|
+
if hasattr(t, "_wp_generic_type_str_"):
|
|
673
|
+
return compute_type_str(f"wp::{t._wp_generic_type_str_}", t._wp_type_params_)
|
|
674
|
+
elif isinstance(t, Struct):
|
|
675
|
+
return t.native_name
|
|
676
|
+
elif hasattr(t, "_wp_native_name_"):
|
|
677
|
+
return f"wp::{t._wp_native_name_}"
|
|
678
|
+
elif t.__name__ in ("bool", "int", "float"):
|
|
679
|
+
return t.__name__
|
|
680
|
+
|
|
681
|
+
return f"wp::{t.__name__}"
|
|
682
|
+
|
|
670
683
|
@staticmethod
|
|
671
684
|
def type_to_ctype(t: type, value_type: builtins.bool = False) -> str:
|
|
672
685
|
if is_array(t):
|
|
673
|
-
|
|
674
|
-
dtypestr = compute_type_str(f"wp::{t.dtype._wp_generic_type_str_}", t.dtype._wp_type_params_)
|
|
675
|
-
elif isinstance(t.dtype, Struct):
|
|
676
|
-
dtypestr = t.dtype.native_name
|
|
677
|
-
elif t.dtype.__name__ in ("bool", "int", "float"):
|
|
678
|
-
dtypestr = t.dtype.__name__
|
|
679
|
-
else:
|
|
680
|
-
dtypestr = f"wp::{t.dtype.__name__}"
|
|
686
|
+
dtypestr = Var.dtype_to_ctype(t.dtype)
|
|
681
687
|
classstr = f"wp::{type(t).__name__}"
|
|
682
688
|
return f"{classstr}_t<{dtypestr}>"
|
|
689
|
+
elif get_origin(t) is tuple:
|
|
690
|
+
dtypestr = ", ".join(Var.dtype_to_ctype(x) for x in get_args(t))
|
|
691
|
+
return f"wp::tuple_t<{dtypestr}>"
|
|
692
|
+
elif is_tuple(t):
|
|
693
|
+
dtypestr = ", ".join(Var.dtype_to_ctype(x) for x in t.types)
|
|
694
|
+
classstr = f"wp::{type(t).__name__}"
|
|
695
|
+
return f"{classstr}<{dtypestr}>"
|
|
683
696
|
elif is_tile(t):
|
|
684
697
|
return t.ctype()
|
|
685
|
-
elif isinstance(t, Struct):
|
|
686
|
-
return t.native_name
|
|
687
698
|
elif isinstance(t, type) and issubclass(t, StructInstance):
|
|
688
699
|
# ensure the actual Struct name is used instead of "NewStructInstance"
|
|
689
700
|
return t.native_name
|
|
690
701
|
elif is_reference(t):
|
|
691
702
|
if not value_type:
|
|
692
703
|
return Var.type_to_ctype(t.value_type) + "*"
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
elif t.__name__ in ("bool", "int", "float"):
|
|
698
|
-
return t.__name__
|
|
699
|
-
else:
|
|
700
|
-
return f"wp::{t.__name__}"
|
|
704
|
+
|
|
705
|
+
return Var.type_to_ctype(t.value_type)
|
|
706
|
+
|
|
707
|
+
return Var.dtype_to_ctype(t)
|
|
701
708
|
|
|
702
709
|
def ctype(self, value_type: builtins.bool = False) -> str:
|
|
703
710
|
return Var.type_to_ctype(self.type, value_type)
|
|
@@ -821,17 +828,26 @@ def func_match_args(func, arg_types, kwarg_types):
|
|
|
821
828
|
return True
|
|
822
829
|
|
|
823
830
|
|
|
824
|
-
def get_arg_type(arg:
|
|
831
|
+
def get_arg_type(arg: Var | Any) -> type:
|
|
825
832
|
if isinstance(arg, str):
|
|
826
833
|
return str
|
|
827
834
|
|
|
828
835
|
if isinstance(arg, Sequence):
|
|
829
836
|
return tuple(get_arg_type(x) for x in arg)
|
|
830
837
|
|
|
838
|
+
if get_origin(arg) is tuple:
|
|
839
|
+
return tuple(get_arg_type(x) for x in get_args(arg))
|
|
840
|
+
|
|
841
|
+
if is_tuple(arg):
|
|
842
|
+
return arg
|
|
843
|
+
|
|
831
844
|
if isinstance(arg, (type, warp.context.Function)):
|
|
832
845
|
return arg
|
|
833
846
|
|
|
834
847
|
if isinstance(arg, Var):
|
|
848
|
+
if get_origin(arg.type) is tuple:
|
|
849
|
+
return get_args(arg.type)
|
|
850
|
+
|
|
835
851
|
return arg.type
|
|
836
852
|
|
|
837
853
|
return type(arg)
|
|
@@ -845,7 +861,11 @@ def get_arg_value(arg: Any) -> Any:
|
|
|
845
861
|
return arg
|
|
846
862
|
|
|
847
863
|
if isinstance(arg, Var):
|
|
848
|
-
|
|
864
|
+
if is_tuple(arg.type):
|
|
865
|
+
return tuple(get_arg_value(x) for x in arg.type.values)
|
|
866
|
+
|
|
867
|
+
if arg.constant is not None:
|
|
868
|
+
return arg.constant
|
|
849
869
|
|
|
850
870
|
return arg
|
|
851
871
|
|
|
@@ -863,7 +883,8 @@ class Adjoint:
|
|
|
863
883
|
skip_reverse_codegen=False,
|
|
864
884
|
custom_reverse_mode=False,
|
|
865
885
|
custom_reverse_num_input_args=-1,
|
|
866
|
-
transformers:
|
|
886
|
+
transformers: list[ast.NodeTransformer] | None = None,
|
|
887
|
+
source: str | None = None,
|
|
867
888
|
):
|
|
868
889
|
adj.func = func
|
|
869
890
|
|
|
@@ -877,19 +898,17 @@ class Adjoint:
|
|
|
877
898
|
# extract name of source file
|
|
878
899
|
adj.filename = inspect.getsourcefile(func) or "unknown source file"
|
|
879
900
|
# get source file line number where function starts
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
) from e
|
|
901
|
+
adj.fun_lineno = 0
|
|
902
|
+
adj.source = source
|
|
903
|
+
if adj.source is None:
|
|
904
|
+
adj.source, adj.fun_lineno = adj.extract_function_source(func)
|
|
905
|
+
|
|
906
|
+
assert adj.source is not None, f"Failed to extract source code for function {func.__name__}"
|
|
887
907
|
|
|
888
908
|
# Indicates where the function definition starts (excludes decorators)
|
|
889
909
|
adj.fun_def_lineno = None
|
|
890
910
|
|
|
891
911
|
# get function source code
|
|
892
|
-
adj.source = inspect.getsource(func)
|
|
893
912
|
# ensures that indented class methods can be parsed as kernels
|
|
894
913
|
adj.source = textwrap.dedent(adj.source)
|
|
895
914
|
|
|
@@ -950,7 +969,7 @@ class Adjoint:
|
|
|
950
969
|
|
|
951
970
|
# try to replace static expressions by their constant result if the
|
|
952
971
|
# expression can be evaluated at declaration time
|
|
953
|
-
adj.static_expressions:
|
|
972
|
+
adj.static_expressions: dict[str, Any] = {}
|
|
954
973
|
if "static" in adj.source:
|
|
955
974
|
adj.replace_static_expressions()
|
|
956
975
|
|
|
@@ -981,6 +1000,18 @@ class Adjoint:
|
|
|
981
1000
|
|
|
982
1001
|
return total_shared + adj.max_required_extra_shared_memory
|
|
983
1002
|
|
|
1003
|
+
@staticmethod
|
|
1004
|
+
def extract_function_source(func: Callable) -> tuple[str, int]:
|
|
1005
|
+
try:
|
|
1006
|
+
_, fun_lineno = inspect.getsourcelines(func)
|
|
1007
|
+
source = inspect.getsource(func)
|
|
1008
|
+
except OSError as e:
|
|
1009
|
+
raise RuntimeError(
|
|
1010
|
+
"Directly evaluating Warp code defined as a string using `exec()` is not supported, "
|
|
1011
|
+
"please save it to a file and use `importlib` if needed."
|
|
1012
|
+
) from e
|
|
1013
|
+
return source, fun_lineno
|
|
1014
|
+
|
|
984
1015
|
# generate function ssa form and adjoint
|
|
985
1016
|
def build(adj, builder, default_builder_options=None):
|
|
986
1017
|
# arg Var read/write flags are held during module rebuilds, so we reset here even when skipping a build
|
|
@@ -1058,7 +1089,7 @@ class Adjoint:
|
|
|
1058
1089
|
# code generation methods
|
|
1059
1090
|
def format_template(adj, template, input_vars, output_var):
|
|
1060
1091
|
# output var is always the 0th index
|
|
1061
|
-
args = [output_var
|
|
1092
|
+
args = [output_var, *input_vars]
|
|
1062
1093
|
s = template.format(*args)
|
|
1063
1094
|
|
|
1064
1095
|
return s
|
|
@@ -1176,7 +1207,7 @@ class Adjoint:
|
|
|
1176
1207
|
|
|
1177
1208
|
return var
|
|
1178
1209
|
|
|
1179
|
-
def get_line_directive(adj, statement: str, relative_lineno:
|
|
1210
|
+
def get_line_directive(adj, statement: str, relative_lineno: int | None = None) -> str | None:
|
|
1180
1211
|
"""Get a line directive for the given statement.
|
|
1181
1212
|
|
|
1182
1213
|
Args:
|
|
@@ -1202,7 +1233,7 @@ class Adjoint:
|
|
|
1202
1233
|
return f'#line {line} "{normalized_path}"'
|
|
1203
1234
|
return None
|
|
1204
1235
|
|
|
1205
|
-
def add_forward(adj, statement: str, replay:
|
|
1236
|
+
def add_forward(adj, statement: str, replay: str | None = None, skip_replay: builtins.bool = False) -> None:
|
|
1206
1237
|
"""Append a statement to the forward pass."""
|
|
1207
1238
|
|
|
1208
1239
|
if line_directive := adj.get_line_directive(statement, adj.lineno):
|
|
@@ -1300,7 +1331,8 @@ class Adjoint:
|
|
|
1300
1331
|
|
|
1301
1332
|
# check output dimensions match expectations
|
|
1302
1333
|
if min_outputs:
|
|
1303
|
-
|
|
1334
|
+
value_type = f.value_func(None, None)
|
|
1335
|
+
if not isinstance(value_type, Sequence) or len(value_type) != min_outputs:
|
|
1304
1336
|
continue
|
|
1305
1337
|
|
|
1306
1338
|
# found a match, use it
|
|
@@ -1396,6 +1428,17 @@ class Adjoint:
|
|
|
1396
1428
|
bound_arg_values,
|
|
1397
1429
|
)
|
|
1398
1430
|
|
|
1431
|
+
# Handle the special case where a Var instance is returned from the `value_func`
|
|
1432
|
+
# callback, in which case we replace the call with a reference to that variable.
|
|
1433
|
+
if isinstance(return_type, Var):
|
|
1434
|
+
return adj.register_var(return_type)
|
|
1435
|
+
elif isinstance(return_type, Sequence) and all(isinstance(x, Var) for x in return_type):
|
|
1436
|
+
return tuple(adj.register_var(x) for x in return_type)
|
|
1437
|
+
|
|
1438
|
+
if get_origin(return_type) is tuple:
|
|
1439
|
+
types = get_args(return_type)
|
|
1440
|
+
return_type = warp.types.tuple_t(types=types, values=(None,) * len(types))
|
|
1441
|
+
|
|
1399
1442
|
# immediately allocate output variables so we can pass them into the dispatch method
|
|
1400
1443
|
if return_type is None:
|
|
1401
1444
|
# void function
|
|
@@ -1775,6 +1818,22 @@ class Adjoint:
|
|
|
1775
1818
|
out = adj.add_builtin_call("where", [cond, var1, var2])
|
|
1776
1819
|
adj.symbols[sym] = out
|
|
1777
1820
|
|
|
1821
|
+
def emit_IfExp(adj, node):
|
|
1822
|
+
cond = adj.eval(node.test)
|
|
1823
|
+
|
|
1824
|
+
if cond.constant is not None:
|
|
1825
|
+
return adj.eval(node.body) if cond.constant else adj.eval(node.orelse)
|
|
1826
|
+
|
|
1827
|
+
adj.begin_if(cond)
|
|
1828
|
+
body = adj.eval(node.body)
|
|
1829
|
+
adj.end_if(cond)
|
|
1830
|
+
|
|
1831
|
+
adj.begin_else(cond)
|
|
1832
|
+
orelse = adj.eval(node.orelse)
|
|
1833
|
+
adj.end_else(cond)
|
|
1834
|
+
|
|
1835
|
+
return adj.add_builtin_call("where", [cond, body, orelse])
|
|
1836
|
+
|
|
1778
1837
|
def emit_Compare(adj, node):
|
|
1779
1838
|
# node.left, node.ops (list of ops), node.comparators (things to compare to)
|
|
1780
1839
|
# e.g. (left ops[0] node.comparators[0]) ops[1] node.comparators[1]
|
|
@@ -1831,7 +1890,7 @@ class Adjoint:
|
|
|
1831
1890
|
if attr == "dtype":
|
|
1832
1891
|
return type_scalar_type(var_type)
|
|
1833
1892
|
elif attr == "length":
|
|
1834
|
-
return
|
|
1893
|
+
return type_size(var_type)
|
|
1835
1894
|
|
|
1836
1895
|
return getattr(var_type, attr, None)
|
|
1837
1896
|
|
|
@@ -1850,6 +1909,15 @@ class Adjoint:
|
|
|
1850
1909
|
index = adj.add_constant(index)
|
|
1851
1910
|
return index
|
|
1852
1911
|
|
|
1912
|
+
def transform_component(adj, component):
|
|
1913
|
+
if len(component) != 1:
|
|
1914
|
+
raise WarpCodegenAttributeError(f"Transform attribute must be single character, got .{component}")
|
|
1915
|
+
|
|
1916
|
+
if component not in ("p", "q"):
|
|
1917
|
+
raise WarpCodegenAttributeError(f"Attribute for transformation must be either 'p' or 'q', got {component}")
|
|
1918
|
+
|
|
1919
|
+
return component
|
|
1920
|
+
|
|
1853
1921
|
@staticmethod
|
|
1854
1922
|
def is_differentiable_value_type(var_type):
|
|
1855
1923
|
# checks that the argument type is a value type (i.e, not an array)
|
|
@@ -1880,12 +1948,20 @@ class Adjoint:
|
|
|
1880
1948
|
|
|
1881
1949
|
aggregate_type = strip_reference(aggregate.type)
|
|
1882
1950
|
|
|
1883
|
-
# reading a vector component
|
|
1884
|
-
if type_is_vector(aggregate_type):
|
|
1951
|
+
# reading a vector or quaternion component
|
|
1952
|
+
if type_is_vector(aggregate_type) or type_is_quaternion(aggregate_type):
|
|
1885
1953
|
index = adj.vector_component_index(node.attr, aggregate_type)
|
|
1886
1954
|
|
|
1887
1955
|
return adj.add_builtin_call("extract", [aggregate, index])
|
|
1888
1956
|
|
|
1957
|
+
elif type_is_transformation(aggregate_type):
|
|
1958
|
+
component = adj.transform_component(node.attr)
|
|
1959
|
+
|
|
1960
|
+
if component == "p":
|
|
1961
|
+
return adj.add_builtin_call("transform_get_translation", [aggregate])
|
|
1962
|
+
else:
|
|
1963
|
+
return adj.add_builtin_call("transform_get_rotation", [aggregate])
|
|
1964
|
+
|
|
1889
1965
|
else:
|
|
1890
1966
|
attr_type = Reference(aggregate_type.vars[node.attr].type)
|
|
1891
1967
|
attr = adj.add_var(attr_type)
|
|
@@ -2282,6 +2358,10 @@ class Adjoint:
|
|
|
2282
2358
|
else:
|
|
2283
2359
|
func = caller.default_constructor
|
|
2284
2360
|
|
|
2361
|
+
# lambda function
|
|
2362
|
+
if func is None and getattr(caller, "__name__", None) == "<lambda>":
|
|
2363
|
+
raise NotImplementedError("Lambda expressions are not yet supported")
|
|
2364
|
+
|
|
2285
2365
|
if hasattr(caller, "_wp_type_args_"):
|
|
2286
2366
|
type_args = caller._wp_type_args_
|
|
2287
2367
|
|
|
@@ -2290,18 +2370,6 @@ class Adjoint:
|
|
|
2290
2370
|
f"Could not find function {'.'.join(path)} as a built-in or user-defined function. Note that user functions must be annotated with a @wp.func decorator to be called from a kernel."
|
|
2291
2371
|
)
|
|
2292
2372
|
|
|
2293
|
-
# Check if any argument correspond to an unsupported construct.
|
|
2294
|
-
# Tuples are supported in the context of assigning multiple variables
|
|
2295
|
-
# at once, but not in place of vectors when calling built-ins like
|
|
2296
|
-
# `wp.length((1, 2, 3))`.
|
|
2297
|
-
# Therefore, we need to catch this specific case here instead of
|
|
2298
|
-
# more generally in `adj.eval()`.
|
|
2299
|
-
for arg in node.args:
|
|
2300
|
-
if isinstance(arg, ast.Tuple):
|
|
2301
|
-
raise WarpCodegenError(
|
|
2302
|
-
"Tuple constructs are not supported in kernels. Use vectors like `wp.vec3()` instead."
|
|
2303
|
-
)
|
|
2304
|
-
|
|
2305
2373
|
# get expected return count, e.g.: for multi-assignment
|
|
2306
2374
|
min_outputs = None
|
|
2307
2375
|
if hasattr(node, "expects"):
|
|
@@ -2311,7 +2379,6 @@ class Adjoint:
|
|
|
2311
2379
|
args = tuple(adj.resolve_arg(x) for x in node.args)
|
|
2312
2380
|
kwargs = {x.arg: adj.resolve_arg(x.value) for x in node.keywords}
|
|
2313
2381
|
|
|
2314
|
-
# add the call and build the callee adjoint if needed (func.adj)
|
|
2315
2382
|
out = adj.add_call(func, args, kwargs, type_args, min_outputs=min_outputs)
|
|
2316
2383
|
|
|
2317
2384
|
if warp.config.verify_autograd_array_access:
|
|
@@ -2461,10 +2528,6 @@ class Adjoint:
|
|
|
2461
2528
|
raise WarpCodegenError(
|
|
2462
2529
|
"List constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead."
|
|
2463
2530
|
)
|
|
2464
|
-
elif isinstance(node.value, ast.Tuple):
|
|
2465
|
-
raise WarpCodegenError(
|
|
2466
|
-
"Tuple constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead."
|
|
2467
|
-
)
|
|
2468
2531
|
|
|
2469
2532
|
# handle the case where we are assigning multiple output variables
|
|
2470
2533
|
if isinstance(lhs, ast.Tuple):
|
|
@@ -2480,6 +2543,17 @@ class Adjoint:
|
|
|
2480
2543
|
else:
|
|
2481
2544
|
out = adj.eval(node.value)
|
|
2482
2545
|
|
|
2546
|
+
subtype = getattr(out, "type", None)
|
|
2547
|
+
if isinstance(subtype, warp.types.tuple_t):
|
|
2548
|
+
if len(out.type.types) != len(lhs.elts):
|
|
2549
|
+
raise WarpCodegenError(
|
|
2550
|
+
f"Invalid number of values to unpack (expected {len(lhs.elts)}, got {len(out.type.types)})."
|
|
2551
|
+
)
|
|
2552
|
+
target = out
|
|
2553
|
+
out = tuple(
|
|
2554
|
+
adj.add_builtin_call("extract", (target, adj.add_constant(i))) for i in range(len(lhs.elts))
|
|
2555
|
+
)
|
|
2556
|
+
|
|
2483
2557
|
names = []
|
|
2484
2558
|
for v in lhs.elts:
|
|
2485
2559
|
if isinstance(v, ast.Name):
|
|
@@ -2532,7 +2606,12 @@ class Adjoint:
|
|
|
2532
2606
|
elif is_tile(target_type):
|
|
2533
2607
|
adj.add_builtin_call("assign", [target, *indices, rhs])
|
|
2534
2608
|
|
|
2535
|
-
elif
|
|
2609
|
+
elif (
|
|
2610
|
+
type_is_vector(target_type)
|
|
2611
|
+
or type_is_quaternion(target_type)
|
|
2612
|
+
or type_is_matrix(target_type)
|
|
2613
|
+
or type_is_transformation(target_type)
|
|
2614
|
+
):
|
|
2536
2615
|
# recursively unwind AST, stopping at penultimate node
|
|
2537
2616
|
node = lhs
|
|
2538
2617
|
while hasattr(node, "value"):
|
|
@@ -2572,7 +2651,7 @@ class Adjoint:
|
|
|
2572
2651
|
|
|
2573
2652
|
else:
|
|
2574
2653
|
raise WarpCodegenError(
|
|
2575
|
-
f"Can only subscript assign array, vector, quaternion, and matrix types, got {target_type}"
|
|
2654
|
+
f"Can only subscript assign array, vector, quaternion, transformation, and matrix types, got {target_type}"
|
|
2576
2655
|
)
|
|
2577
2656
|
|
|
2578
2657
|
elif isinstance(lhs, ast.Name):
|
|
@@ -2589,8 +2668,11 @@ class Adjoint:
|
|
|
2589
2668
|
f"Error, assigning to existing symbol {name} ({adj.symbols[name].type}) with different type ({rhs.type})"
|
|
2590
2669
|
)
|
|
2591
2670
|
|
|
2592
|
-
|
|
2593
|
-
|
|
2671
|
+
if isinstance(node.value, ast.Tuple):
|
|
2672
|
+
out = rhs
|
|
2673
|
+
elif isinstance(rhs, Sequence):
|
|
2674
|
+
out = adj.add_builtin_call("tuple", rhs)
|
|
2675
|
+
elif isinstance(node.value, ast.Name) or is_reference(rhs.type):
|
|
2594
2676
|
out = adj.add_builtin_call("copy", [rhs])
|
|
2595
2677
|
else:
|
|
2596
2678
|
out = rhs
|
|
@@ -2622,6 +2704,18 @@ class Adjoint:
|
|
|
2622
2704
|
else:
|
|
2623
2705
|
adj.add_builtin_call("assign_inplace", [aggregate, index, rhs])
|
|
2624
2706
|
|
|
2707
|
+
elif type_is_transformation(aggregate_type):
|
|
2708
|
+
component = adj.transform_component(lhs.attr)
|
|
2709
|
+
|
|
2710
|
+
# TODO: x[i,j].p = rhs case
|
|
2711
|
+
if is_reference(aggregate.type):
|
|
2712
|
+
raise WarpCodegenError(f"Error, assigning transform attribute {component} to an array element")
|
|
2713
|
+
|
|
2714
|
+
if component == "p":
|
|
2715
|
+
return adj.add_builtin_call("transform_set_translation", [aggregate, rhs])
|
|
2716
|
+
else:
|
|
2717
|
+
return adj.add_builtin_call("transform_set_rotation", [aggregate, rhs])
|
|
2718
|
+
|
|
2625
2719
|
else:
|
|
2626
2720
|
attr = adj.emit_Attribute(lhs)
|
|
2627
2721
|
if is_reference(attr.type):
|
|
@@ -2644,7 +2738,9 @@ class Adjoint:
|
|
|
2644
2738
|
elif isinstance(node.value, ast.Tuple):
|
|
2645
2739
|
var = tuple(adj.eval(arg) for arg in node.value.elts)
|
|
2646
2740
|
else:
|
|
2647
|
-
var =
|
|
2741
|
+
var = adj.eval(node.value)
|
|
2742
|
+
if not isinstance(var, list) and not isinstance(var, tuple):
|
|
2743
|
+
var = (var,)
|
|
2648
2744
|
|
|
2649
2745
|
if adj.return_var is not None:
|
|
2650
2746
|
old_ctypes = tuple(v.ctype(value_type=True) for v in adj.return_var)
|
|
@@ -2697,6 +2793,7 @@ class Adjoint:
|
|
|
2697
2793
|
type_is_vector(target_type.dtype)
|
|
2698
2794
|
or type_is_quaternion(target_type.dtype)
|
|
2699
2795
|
or type_is_matrix(target_type.dtype)
|
|
2796
|
+
or type_is_transformation(target_type.dtype)
|
|
2700
2797
|
):
|
|
2701
2798
|
dtype = getattr(target_type.dtype, "_wp_scalar_type_", None)
|
|
2702
2799
|
if dtype in warp.types.non_atomic_types:
|
|
@@ -2724,7 +2821,12 @@ class Adjoint:
|
|
|
2724
2821
|
make_new_assign_statement()
|
|
2725
2822
|
return
|
|
2726
2823
|
|
|
2727
|
-
elif
|
|
2824
|
+
elif (
|
|
2825
|
+
type_is_vector(target_type)
|
|
2826
|
+
or type_is_quaternion(target_type)
|
|
2827
|
+
or type_is_matrix(target_type)
|
|
2828
|
+
or type_is_transformation(target_type)
|
|
2829
|
+
):
|
|
2728
2830
|
if isinstance(node.op, ast.Add):
|
|
2729
2831
|
adj.add_builtin_call("add_inplace", [target, *indices, rhs])
|
|
2730
2832
|
elif isinstance(node.op, ast.Sub):
|
|
@@ -2735,9 +2837,36 @@ class Adjoint:
|
|
|
2735
2837
|
make_new_assign_statement()
|
|
2736
2838
|
return
|
|
2737
2839
|
|
|
2840
|
+
elif is_tile(target.type):
|
|
2841
|
+
if isinstance(node.op, ast.Add):
|
|
2842
|
+
adj.add_builtin_call("tile_add_inplace", [target, *indices, rhs])
|
|
2843
|
+
elif isinstance(node.op, ast.Sub):
|
|
2844
|
+
adj.add_builtin_call("tile_sub_inplace", [target, *indices, rhs])
|
|
2845
|
+
else:
|
|
2846
|
+
if warp.config.verbose:
|
|
2847
|
+
print(f"Warning: in-place op {node.op} is not differentiable")
|
|
2848
|
+
make_new_assign_statement()
|
|
2849
|
+
return
|
|
2850
|
+
|
|
2738
2851
|
else:
|
|
2739
2852
|
raise WarpCodegenError("Can only subscript in-place assign array, vector, quaternion, and matrix types")
|
|
2740
2853
|
|
|
2854
|
+
elif isinstance(lhs, ast.Name):
|
|
2855
|
+
target = adj.eval(node.target)
|
|
2856
|
+
rhs = adj.eval(node.value)
|
|
2857
|
+
|
|
2858
|
+
if is_tile(target.type) and is_tile(rhs.type):
|
|
2859
|
+
if isinstance(node.op, ast.Add):
|
|
2860
|
+
adj.add_builtin_call("add_inplace", [target, rhs])
|
|
2861
|
+
elif isinstance(node.op, ast.Sub):
|
|
2862
|
+
adj.add_builtin_call("sub_inplace", [target, rhs])
|
|
2863
|
+
else:
|
|
2864
|
+
make_new_assign_statement()
|
|
2865
|
+
return
|
|
2866
|
+
else:
|
|
2867
|
+
make_new_assign_statement()
|
|
2868
|
+
return
|
|
2869
|
+
|
|
2741
2870
|
# TODO
|
|
2742
2871
|
elif isinstance(lhs, ast.Attribute):
|
|
2743
2872
|
make_new_assign_statement()
|
|
@@ -2748,15 +2877,16 @@ class Adjoint:
|
|
|
2748
2877
|
return
|
|
2749
2878
|
|
|
2750
2879
|
def emit_Tuple(adj, node):
|
|
2751
|
-
|
|
2752
|
-
return
|
|
2880
|
+
elements = tuple(adj.eval(x) for x in node.elts)
|
|
2881
|
+
return adj.add_builtin_call("tuple", elements)
|
|
2753
2882
|
|
|
2754
2883
|
def emit_Pass(adj, node):
|
|
2755
2884
|
pass
|
|
2756
2885
|
|
|
2757
|
-
node_visitors = {
|
|
2886
|
+
node_visitors: ClassVar[dict[type[ast.AST], Callable]] = {
|
|
2758
2887
|
ast.FunctionDef: emit_FunctionDef,
|
|
2759
2888
|
ast.If: emit_If,
|
|
2889
|
+
ast.IfExp: emit_IfExp,
|
|
2760
2890
|
ast.Compare: emit_Compare,
|
|
2761
2891
|
ast.BoolOp: emit_BoolOp,
|
|
2762
2892
|
ast.Name: emit_Name,
|
|
@@ -2860,11 +2990,11 @@ class Adjoint:
|
|
|
2860
2990
|
if isinstance(value, warp.context.Function):
|
|
2861
2991
|
return True
|
|
2862
2992
|
|
|
2863
|
-
def verify_struct(s: StructInstance, attr_path:
|
|
2993
|
+
def verify_struct(s: StructInstance, attr_path: list[str]):
|
|
2864
2994
|
for key in s._cls.vars.keys():
|
|
2865
2995
|
v = getattr(s, key)
|
|
2866
2996
|
if issubclass(type(v), StructInstance):
|
|
2867
|
-
verify_struct(v, attr_path
|
|
2997
|
+
verify_struct(v, [*attr_path, key])
|
|
2868
2998
|
else:
|
|
2869
2999
|
try:
|
|
2870
3000
|
adj.verify_static_return_value(v)
|
|
@@ -2879,7 +3009,8 @@ class Adjoint:
|
|
|
2879
3009
|
raise ValueError(f"value of type {type(value)} cannot be constructed inside Warp kernels")
|
|
2880
3010
|
|
|
2881
3011
|
# find the source code string of an AST node
|
|
2882
|
-
|
|
3012
|
+
@staticmethod
|
|
3013
|
+
def extract_node_source_from_lines(source_lines, node) -> str | None:
|
|
2883
3014
|
if not hasattr(node, "lineno") or not hasattr(node, "col_offset"):
|
|
2884
3015
|
return None
|
|
2885
3016
|
|
|
@@ -2895,12 +3026,12 @@ class Adjoint:
|
|
|
2895
3026
|
end_line = start_line
|
|
2896
3027
|
end_col = start_col
|
|
2897
3028
|
parenthesis_count = 1
|
|
2898
|
-
for lineno in range(start_line, len(
|
|
3029
|
+
for lineno in range(start_line, len(source_lines)):
|
|
2899
3030
|
if lineno == start_line:
|
|
2900
3031
|
c_start = start_col
|
|
2901
3032
|
else:
|
|
2902
3033
|
c_start = 0
|
|
2903
|
-
line =
|
|
3034
|
+
line = source_lines[lineno]
|
|
2904
3035
|
for i in range(c_start, len(line)):
|
|
2905
3036
|
c = line[i]
|
|
2906
3037
|
if c == "(":
|
|
@@ -2916,21 +3047,57 @@ class Adjoint:
|
|
|
2916
3047
|
|
|
2917
3048
|
if start_line == end_line:
|
|
2918
3049
|
# single-line expression
|
|
2919
|
-
return
|
|
3050
|
+
return source_lines[start_line][start_col:end_col]
|
|
2920
3051
|
else:
|
|
2921
3052
|
# multi-line expression
|
|
2922
3053
|
lines = []
|
|
2923
3054
|
# first line (from start_col to the end)
|
|
2924
|
-
lines.append(
|
|
3055
|
+
lines.append(source_lines[start_line][start_col:])
|
|
2925
3056
|
# middle lines (entire lines)
|
|
2926
|
-
lines.extend(
|
|
3057
|
+
lines.extend(source_lines[start_line + 1 : end_line])
|
|
2927
3058
|
# last line (from the start to end_col)
|
|
2928
|
-
lines.append(
|
|
3059
|
+
lines.append(source_lines[end_line][:end_col])
|
|
2929
3060
|
return "\n".join(lines).strip()
|
|
2930
3061
|
|
|
3062
|
+
@staticmethod
|
|
3063
|
+
def extract_lambda_source(func, only_body=False) -> str | None:
|
|
3064
|
+
try:
|
|
3065
|
+
source_lines = inspect.getsourcelines(func)[0]
|
|
3066
|
+
source_lines[0] = source_lines[0][source_lines[0].index("lambda") :]
|
|
3067
|
+
except OSError as e:
|
|
3068
|
+
raise WarpCodegenError(
|
|
3069
|
+
"Could not access lambda function source code. Please use a named function instead."
|
|
3070
|
+
) from e
|
|
3071
|
+
source = "".join(source_lines)
|
|
3072
|
+
source = source[source.index("lambda") :].rstrip()
|
|
3073
|
+
# Remove trailing unbalanced parentheses
|
|
3074
|
+
while source.count("(") < source.count(")"):
|
|
3075
|
+
source = source[:-1]
|
|
3076
|
+
# extract lambda expression up until a comma, e.g. in the case of
|
|
3077
|
+
# "map(lambda a: (a + 2.0, a + 3.0), a, return_kernel=True)"
|
|
3078
|
+
si = max(source.find(")"), source.find(":"))
|
|
3079
|
+
ci = source.find(",", si)
|
|
3080
|
+
if ci != -1:
|
|
3081
|
+
source = source[:ci]
|
|
3082
|
+
tree = ast.parse(source)
|
|
3083
|
+
lambda_source = None
|
|
3084
|
+
for node in ast.walk(tree):
|
|
3085
|
+
if isinstance(node, ast.Lambda):
|
|
3086
|
+
if only_body:
|
|
3087
|
+
# extract the body of the lambda function
|
|
3088
|
+
lambda_source = Adjoint.extract_node_source_from_lines(source_lines, node.body)
|
|
3089
|
+
else:
|
|
3090
|
+
# extract the entire lambda function
|
|
3091
|
+
lambda_source = Adjoint.extract_node_source_from_lines(source_lines, node)
|
|
3092
|
+
break
|
|
3093
|
+
return lambda_source
|
|
3094
|
+
|
|
3095
|
+
def extract_node_source(adj, node) -> str | None:
|
|
3096
|
+
return adj.extract_node_source_from_lines(adj.source_lines, node)
|
|
3097
|
+
|
|
2931
3098
|
# handles a wp.static() expression and returns the resulting object and a string representing the code
|
|
2932
3099
|
# of the static expression
|
|
2933
|
-
def evaluate_static_expression(adj, node) ->
|
|
3100
|
+
def evaluate_static_expression(adj, node) -> tuple[Any, str]:
|
|
2934
3101
|
if len(node.args) == 1:
|
|
2935
3102
|
static_code = adj.extract_node_source(node.args[0])
|
|
2936
3103
|
elif len(node.keywords) == 1:
|
|
@@ -2950,29 +3117,14 @@ class Adjoint:
|
|
|
2950
3117
|
|
|
2951
3118
|
# Replace all constant `len()` expressions with their value.
|
|
2952
3119
|
if "len" in static_code:
|
|
2953
|
-
|
|
2954
|
-
def eval_len(obj):
|
|
2955
|
-
if type_is_vector(obj):
|
|
2956
|
-
return obj._length_
|
|
2957
|
-
elif type_is_quaternion(obj):
|
|
2958
|
-
return obj._length_
|
|
2959
|
-
elif type_is_matrix(obj):
|
|
2960
|
-
return obj._shape_[0]
|
|
2961
|
-
elif type_is_transformation(obj):
|
|
2962
|
-
return obj._length_
|
|
2963
|
-
elif is_tile(obj):
|
|
2964
|
-
return obj.shape[0]
|
|
2965
|
-
|
|
2966
|
-
return len(obj)
|
|
2967
|
-
|
|
2968
3120
|
len_expr_ctx = vars_dict.copy()
|
|
2969
3121
|
constant_types = {k: v.type for k, v in adj.symbols.items() if isinstance(v, Var) and v.type is not None}
|
|
2970
3122
|
len_expr_ctx.update(constant_types)
|
|
2971
|
-
len_expr_ctx.update({"len":
|
|
3123
|
+
len_expr_ctx.update({"len": warp.types.type_length})
|
|
2972
3124
|
|
|
2973
3125
|
# We want to replace the expression code in-place,
|
|
2974
3126
|
# so reparse it to get the correct column info.
|
|
2975
|
-
len_value_locs:
|
|
3127
|
+
len_value_locs: list[tuple[int, int, int]] = []
|
|
2976
3128
|
expr_tree = ast.parse(static_code)
|
|
2977
3129
|
assert len(expr_tree.body) == 1 and isinstance(expr_tree.body[0], ast.Expr)
|
|
2978
3130
|
expr_root = expr_tree.body[0].value
|
|
@@ -3134,14 +3286,14 @@ class Adjoint:
|
|
|
3134
3286
|
# return the Python code corresponding to the given AST node
|
|
3135
3287
|
return ast.get_source_segment(adj.source, node)
|
|
3136
3288
|
|
|
3137
|
-
def get_references(adj) ->
|
|
3289
|
+
def get_references(adj) -> tuple[dict[str, Any], dict[Any, Any], dict[warp.context.Function, Any]]:
|
|
3138
3290
|
"""Traverses ``adj.tree`` and returns referenced constants, types, and user-defined functions."""
|
|
3139
3291
|
|
|
3140
3292
|
local_variables = set() # Track local variables appearing on the LHS so we know when variables are shadowed
|
|
3141
3293
|
|
|
3142
|
-
constants:
|
|
3143
|
-
types:
|
|
3144
|
-
functions:
|
|
3294
|
+
constants: dict[str, Any] = {}
|
|
3295
|
+
types: dict[Struct | type, Any] = {}
|
|
3296
|
+
functions: dict[warp.context.Function, Any] = {}
|
|
3145
3297
|
|
|
3146
3298
|
for node in ast.walk(adj.tree):
|
|
3147
3299
|
if isinstance(node, ast.Name) and node.id not in local_variables:
|
|
@@ -3200,6 +3352,8 @@ cpu_module_header = """
|
|
|
3200
3352
|
#define builtin_tid3d(x, y, z) wp::tid(x, y, z, task_index, dim)
|
|
3201
3353
|
#define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, task_index, dim)
|
|
3202
3354
|
|
|
3355
|
+
#define builtin_block_dim() wp::block_dim()
|
|
3356
|
+
|
|
3203
3357
|
"""
|
|
3204
3358
|
|
|
3205
3359
|
cuda_module_header = """
|
|
@@ -3219,6 +3373,8 @@ cuda_module_header = """
|
|
|
3219
3373
|
#define builtin_tid3d(x, y, z) wp::tid(x, y, z, _idx, dim)
|
|
3220
3374
|
#define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, _idx, dim)
|
|
3221
3375
|
|
|
3376
|
+
#define builtin_block_dim() wp::block_dim()
|
|
3377
|
+
|
|
3222
3378
|
"""
|
|
3223
3379
|
|
|
3224
3380
|
struct_template = """
|
|
@@ -3663,7 +3819,7 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
|
|
|
3663
3819
|
f"annotated as a tuple of {len(get_args(adj.arg_types['return']))} elements "
|
|
3664
3820
|
f"but the code returns {len(adj.return_var)} values."
|
|
3665
3821
|
)
|
|
3666
|
-
elif not types_equal(adj.arg_types["return"], tuple(x.type for x in adj.return_var)):
|
|
3822
|
+
elif not types_equal(adj.arg_types["return"], tuple(x.type for x in adj.return_var), match_generic=True):
|
|
3667
3823
|
raise WarpCodegenError(
|
|
3668
3824
|
f"The function `{adj.fun_name}` has its return type "
|
|
3669
3825
|
f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
|