warp-lang 1.7.2rc1__py3-none-manylinux_2_34_aarch64.whl → 1.8.1__py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic; see the registry's advisory page for this release for more details.
- warp/__init__.py +3 -1
- warp/__init__.pyi +3489 -1
- warp/autograd.py +45 -122
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/build.py +241 -252
- warp/build_dll.py +130 -26
- warp/builtins.py +1907 -384
- warp/codegen.py +272 -104
- warp/config.py +12 -1
- warp/constants.py +1 -1
- warp/context.py +770 -238
- warp/dlpack.py +1 -1
- warp/examples/benchmarks/benchmark_cloth.py +2 -2
- warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
- warp/examples/core/example_sample_mesh.py +1 -1
- warp/examples/core/example_spin_lock.py +93 -0
- warp/examples/core/example_work_queue.py +118 -0
- warp/examples/fem/example_adaptive_grid.py +5 -5
- warp/examples/fem/example_apic_fluid.py +1 -1
- warp/examples/fem/example_burgers.py +1 -1
- warp/examples/fem/example_convection_diffusion.py +9 -6
- warp/examples/fem/example_darcy_ls_optimization.py +489 -0
- warp/examples/fem/example_deformed_geometry.py +1 -1
- warp/examples/fem/example_diffusion.py +2 -2
- warp/examples/fem/example_diffusion_3d.py +1 -1
- warp/examples/fem/example_distortion_energy.py +1 -1
- warp/examples/fem/example_elastic_shape_optimization.py +387 -0
- warp/examples/fem/example_magnetostatics.py +5 -3
- warp/examples/fem/example_mixed_elasticity.py +5 -3
- warp/examples/fem/example_navier_stokes.py +11 -9
- warp/examples/fem/example_nonconforming_contact.py +5 -3
- warp/examples/fem/example_streamlines.py +8 -3
- warp/examples/fem/utils.py +9 -8
- warp/examples/interop/example_jax_callable.py +34 -4
- warp/examples/interop/example_jax_ffi_callback.py +2 -2
- warp/examples/interop/example_jax_kernel.py +27 -1
- warp/examples/optim/example_drone.py +1 -1
- warp/examples/sim/example_cloth.py +1 -1
- warp/examples/sim/example_cloth_self_contact.py +48 -54
- warp/examples/tile/example_tile_block_cholesky.py +502 -0
- warp/examples/tile/example_tile_cholesky.py +2 -1
- warp/examples/tile/example_tile_convolution.py +1 -1
- warp/examples/tile/example_tile_filtering.py +1 -1
- warp/examples/tile/example_tile_matmul.py +1 -1
- warp/examples/tile/example_tile_mlp.py +2 -0
- warp/fabric.py +7 -7
- warp/fem/__init__.py +5 -0
- warp/fem/adaptivity.py +1 -1
- warp/fem/cache.py +152 -63
- warp/fem/dirichlet.py +2 -2
- warp/fem/domain.py +136 -6
- warp/fem/field/field.py +141 -99
- warp/fem/field/nodal_field.py +85 -39
- warp/fem/field/virtual.py +99 -52
- warp/fem/geometry/adaptive_nanogrid.py +91 -86
- warp/fem/geometry/closest_point.py +13 -0
- warp/fem/geometry/deformed_geometry.py +102 -40
- warp/fem/geometry/element.py +56 -2
- warp/fem/geometry/geometry.py +323 -22
- warp/fem/geometry/grid_2d.py +157 -62
- warp/fem/geometry/grid_3d.py +116 -20
- warp/fem/geometry/hexmesh.py +86 -20
- warp/fem/geometry/nanogrid.py +166 -86
- warp/fem/geometry/partition.py +59 -25
- warp/fem/geometry/quadmesh.py +86 -135
- warp/fem/geometry/tetmesh.py +47 -119
- warp/fem/geometry/trimesh.py +77 -270
- warp/fem/integrate.py +181 -95
- warp/fem/linalg.py +25 -58
- warp/fem/operator.py +124 -27
- warp/fem/quadrature/pic_quadrature.py +36 -14
- warp/fem/quadrature/quadrature.py +40 -16
- warp/fem/space/__init__.py +1 -1
- warp/fem/space/basis_function_space.py +66 -46
- warp/fem/space/basis_space.py +17 -4
- warp/fem/space/dof_mapper.py +1 -1
- warp/fem/space/function_space.py +2 -2
- warp/fem/space/grid_2d_function_space.py +4 -1
- warp/fem/space/hexmesh_function_space.py +4 -2
- warp/fem/space/nanogrid_function_space.py +3 -1
- warp/fem/space/partition.py +11 -2
- warp/fem/space/quadmesh_function_space.py +4 -1
- warp/fem/space/restriction.py +5 -2
- warp/fem/space/shape/__init__.py +10 -8
- warp/fem/space/tetmesh_function_space.py +4 -1
- warp/fem/space/topology.py +52 -21
- warp/fem/space/trimesh_function_space.py +4 -1
- warp/fem/utils.py +53 -8
- warp/jax.py +1 -2
- warp/jax_experimental/ffi.py +210 -67
- warp/jax_experimental/xla_ffi.py +37 -24
- warp/math.py +171 -1
- warp/native/array.h +103 -4
- warp/native/builtin.h +182 -35
- warp/native/coloring.cpp +6 -2
- warp/native/cuda_util.cpp +1 -1
- warp/native/exports.h +118 -63
- warp/native/intersect.h +5 -5
- warp/native/mat.h +8 -13
- warp/native/mathdx.cpp +11 -5
- warp/native/matnn.h +1 -123
- warp/native/mesh.h +1 -1
- warp/native/quat.h +34 -6
- warp/native/rand.h +7 -7
- warp/native/sparse.cpp +121 -258
- warp/native/sparse.cu +181 -274
- warp/native/spatial.h +305 -17
- warp/native/svd.h +23 -8
- warp/native/tile.h +603 -73
- warp/native/tile_radix_sort.h +1112 -0
- warp/native/tile_reduce.h +239 -13
- warp/native/tile_scan.h +240 -0
- warp/native/tuple.h +189 -0
- warp/native/vec.h +10 -20
- warp/native/warp.cpp +36 -4
- warp/native/warp.cu +588 -52
- warp/native/warp.h +47 -74
- warp/optim/linear.py +5 -1
- warp/paddle.py +7 -8
- warp/py.typed +0 -0
- warp/render/render_opengl.py +110 -80
- warp/render/render_usd.py +124 -62
- warp/sim/__init__.py +9 -0
- warp/sim/collide.py +253 -80
- warp/sim/graph_coloring.py +8 -1
- warp/sim/import_mjcf.py +4 -3
- warp/sim/import_usd.py +11 -7
- warp/sim/integrator.py +5 -2
- warp/sim/integrator_euler.py +1 -1
- warp/sim/integrator_featherstone.py +1 -1
- warp/sim/integrator_vbd.py +761 -322
- warp/sim/integrator_xpbd.py +1 -1
- warp/sim/model.py +265 -260
- warp/sim/utils.py +10 -7
- warp/sparse.py +303 -166
- warp/tape.py +54 -51
- warp/tests/cuda/test_conditional_captures.py +1046 -0
- warp/tests/cuda/test_streams.py +1 -1
- warp/tests/geometry/test_volume.py +2 -2
- warp/tests/interop/test_dlpack.py +9 -9
- warp/tests/interop/test_jax.py +0 -1
- warp/tests/run_coverage_serial.py +1 -1
- warp/tests/sim/disabled_kinematics.py +2 -2
- warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
- warp/tests/sim/test_collision.py +159 -51
- warp/tests/sim/test_coloring.py +91 -2
- warp/tests/test_array.py +254 -2
- warp/tests/test_array_reduce.py +2 -2
- warp/tests/test_assert.py +53 -0
- warp/tests/test_atomic_cas.py +312 -0
- warp/tests/test_codegen.py +142 -19
- warp/tests/test_conditional.py +47 -1
- warp/tests/test_ctypes.py +0 -20
- warp/tests/test_devices.py +8 -0
- warp/tests/test_fabricarray.py +4 -2
- warp/tests/test_fem.py +58 -25
- warp/tests/test_func.py +42 -1
- warp/tests/test_grad.py +1 -1
- warp/tests/test_lerp.py +1 -3
- warp/tests/test_map.py +481 -0
- warp/tests/test_mat.py +23 -24
- warp/tests/test_quat.py +28 -15
- warp/tests/test_rounding.py +10 -38
- warp/tests/test_runlength_encode.py +7 -7
- warp/tests/test_smoothstep.py +1 -1
- warp/tests/test_sparse.py +83 -2
- warp/tests/test_spatial.py +507 -1
- warp/tests/test_static.py +48 -0
- warp/tests/test_struct.py +2 -2
- warp/tests/test_tape.py +38 -0
- warp/tests/test_tuple.py +265 -0
- warp/tests/test_types.py +2 -2
- warp/tests/test_utils.py +24 -18
- warp/tests/test_vec.py +38 -408
- warp/tests/test_vec_constructors.py +325 -0
- warp/tests/tile/test_tile.py +438 -131
- warp/tests/tile/test_tile_mathdx.py +518 -14
- warp/tests/tile/test_tile_matmul.py +179 -0
- warp/tests/tile/test_tile_reduce.py +307 -5
- warp/tests/tile/test_tile_shared_memory.py +136 -7
- warp/tests/tile/test_tile_sort.py +121 -0
- warp/tests/unittest_suites.py +14 -6
- warp/types.py +462 -308
- warp/utils.py +647 -86
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +190 -176
- warp/stubs.py +0 -3381
- warp/tests/sim/test_xpbd.py +0 -399
- warp/tests/test_mlp.py +0 -282
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
warp/codegen.py
CHANGED
|
@@ -26,7 +26,7 @@ import re
|
|
|
26
26
|
import sys
|
|
27
27
|
import textwrap
|
|
28
28
|
import types
|
|
29
|
-
from typing import Any, Callable,
|
|
29
|
+
from typing import Any, Callable, ClassVar, Mapping, Sequence, get_args, get_origin
|
|
30
30
|
|
|
31
31
|
import warp.config
|
|
32
32
|
from warp.types import *
|
|
@@ -57,7 +57,7 @@ class WarpCodegenKeyError(KeyError):
|
|
|
57
57
|
|
|
58
58
|
|
|
59
59
|
# map operator to function name
|
|
60
|
-
builtin_operators:
|
|
60
|
+
builtin_operators: dict[type[ast.AST], str] = {}
|
|
61
61
|
|
|
62
62
|
# see https://www.ics.uci.edu/~pattis/ICS-31/lectures/opexp.pdf for a
|
|
63
63
|
# nice overview of python operators
|
|
@@ -321,7 +321,7 @@ class StructInstance:
|
|
|
321
321
|
# vector/matrix type, e.g. vec3
|
|
322
322
|
if value is None:
|
|
323
323
|
setattr(self._ctype, name, var.type())
|
|
324
|
-
elif
|
|
324
|
+
elif type(value) == var.type:
|
|
325
325
|
setattr(self._ctype, name, value)
|
|
326
326
|
else:
|
|
327
327
|
# conversion from list/tuple, ndarray, etc.
|
|
@@ -616,6 +616,8 @@ def compute_type_str(base_name, template_params):
|
|
|
616
616
|
def param2str(p):
|
|
617
617
|
if isinstance(p, int):
|
|
618
618
|
return str(p)
|
|
619
|
+
elif hasattr(p, "_wp_generic_type_str_"):
|
|
620
|
+
return compute_type_str(f"wp::{p._wp_generic_type_str_}", p._wp_type_params_)
|
|
619
621
|
elif hasattr(p, "_type_"):
|
|
620
622
|
if p.__name__ == "bool":
|
|
621
623
|
return "bool"
|
|
@@ -626,7 +628,7 @@ def compute_type_str(base_name, template_params):
|
|
|
626
628
|
|
|
627
629
|
return p.__name__
|
|
628
630
|
|
|
629
|
-
return f"{base_name}<{','.join(map(param2str, template_params))}>"
|
|
631
|
+
return f"{base_name}<{', '.join(map(param2str, template_params))}>"
|
|
630
632
|
|
|
631
633
|
|
|
632
634
|
class Var:
|
|
@@ -635,9 +637,9 @@ class Var:
|
|
|
635
637
|
label: str,
|
|
636
638
|
type: type,
|
|
637
639
|
requires_grad: builtins.bool = False,
|
|
638
|
-
constant:
|
|
640
|
+
constant: builtins.bool | None = None,
|
|
639
641
|
prefix: builtins.bool = True,
|
|
640
|
-
relative_lineno:
|
|
642
|
+
relative_lineno: int | None = None,
|
|
641
643
|
):
|
|
642
644
|
# convert built-in types to wp types
|
|
643
645
|
if type == float:
|
|
@@ -667,37 +669,44 @@ class Var:
|
|
|
667
669
|
def __str__(self):
|
|
668
670
|
return self.label
|
|
669
671
|
|
|
672
|
+
@staticmethod
|
|
673
|
+
def dtype_to_ctype(t: type) -> str:
|
|
674
|
+
if hasattr(t, "_wp_generic_type_str_"):
|
|
675
|
+
return compute_type_str(f"wp::{t._wp_generic_type_str_}", t._wp_type_params_)
|
|
676
|
+
elif isinstance(t, Struct):
|
|
677
|
+
return t.native_name
|
|
678
|
+
elif hasattr(t, "_wp_native_name_"):
|
|
679
|
+
return f"wp::{t._wp_native_name_}"
|
|
680
|
+
elif t.__name__ in ("bool", "int", "float"):
|
|
681
|
+
return t.__name__
|
|
682
|
+
|
|
683
|
+
return f"wp::{t.__name__}"
|
|
684
|
+
|
|
670
685
|
@staticmethod
|
|
671
686
|
def type_to_ctype(t: type, value_type: builtins.bool = False) -> str:
|
|
672
687
|
if is_array(t):
|
|
673
|
-
|
|
674
|
-
dtypestr = compute_type_str(f"wp::{t.dtype._wp_generic_type_str_}", t.dtype._wp_type_params_)
|
|
675
|
-
elif isinstance(t.dtype, Struct):
|
|
676
|
-
dtypestr = t.dtype.native_name
|
|
677
|
-
elif t.dtype.__name__ in ("bool", "int", "float"):
|
|
678
|
-
dtypestr = t.dtype.__name__
|
|
679
|
-
else:
|
|
680
|
-
dtypestr = f"wp::{t.dtype.__name__}"
|
|
688
|
+
dtypestr = Var.dtype_to_ctype(t.dtype)
|
|
681
689
|
classstr = f"wp::{type(t).__name__}"
|
|
682
690
|
return f"{classstr}_t<{dtypestr}>"
|
|
691
|
+
elif get_origin(t) is tuple:
|
|
692
|
+
dtypestr = ", ".join(Var.dtype_to_ctype(x) for x in get_args(t))
|
|
693
|
+
return f"wp::tuple_t<{dtypestr}>"
|
|
694
|
+
elif is_tuple(t):
|
|
695
|
+
dtypestr = ", ".join(Var.dtype_to_ctype(x) for x in t.types)
|
|
696
|
+
classstr = f"wp::{type(t).__name__}"
|
|
697
|
+
return f"{classstr}<{dtypestr}>"
|
|
683
698
|
elif is_tile(t):
|
|
684
699
|
return t.ctype()
|
|
685
|
-
elif isinstance(t, Struct):
|
|
686
|
-
return t.native_name
|
|
687
700
|
elif isinstance(t, type) and issubclass(t, StructInstance):
|
|
688
701
|
# ensure the actual Struct name is used instead of "NewStructInstance"
|
|
689
702
|
return t.native_name
|
|
690
703
|
elif is_reference(t):
|
|
691
704
|
if not value_type:
|
|
692
705
|
return Var.type_to_ctype(t.value_type) + "*"
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
|
|
696
|
-
|
|
697
|
-
elif t.__name__ in ("bool", "int", "float"):
|
|
698
|
-
return t.__name__
|
|
699
|
-
else:
|
|
700
|
-
return f"wp::{t.__name__}"
|
|
706
|
+
|
|
707
|
+
return Var.type_to_ctype(t.value_type)
|
|
708
|
+
|
|
709
|
+
return Var.dtype_to_ctype(t)
|
|
701
710
|
|
|
702
711
|
def ctype(self, value_type: builtins.bool = False) -> str:
|
|
703
712
|
return Var.type_to_ctype(self.type, value_type)
|
|
@@ -821,17 +830,26 @@ def func_match_args(func, arg_types, kwarg_types):
|
|
|
821
830
|
return True
|
|
822
831
|
|
|
823
832
|
|
|
824
|
-
def get_arg_type(arg:
|
|
833
|
+
def get_arg_type(arg: Var | Any) -> type:
|
|
825
834
|
if isinstance(arg, str):
|
|
826
835
|
return str
|
|
827
836
|
|
|
828
837
|
if isinstance(arg, Sequence):
|
|
829
838
|
return tuple(get_arg_type(x) for x in arg)
|
|
830
839
|
|
|
840
|
+
if get_origin(arg) is tuple:
|
|
841
|
+
return tuple(get_arg_type(x) for x in get_args(arg))
|
|
842
|
+
|
|
843
|
+
if is_tuple(arg):
|
|
844
|
+
return arg
|
|
845
|
+
|
|
831
846
|
if isinstance(arg, (type, warp.context.Function)):
|
|
832
847
|
return arg
|
|
833
848
|
|
|
834
849
|
if isinstance(arg, Var):
|
|
850
|
+
if get_origin(arg.type) is tuple:
|
|
851
|
+
return get_args(arg.type)
|
|
852
|
+
|
|
835
853
|
return arg.type
|
|
836
854
|
|
|
837
855
|
return type(arg)
|
|
@@ -845,7 +863,11 @@ def get_arg_value(arg: Any) -> Any:
|
|
|
845
863
|
return arg
|
|
846
864
|
|
|
847
865
|
if isinstance(arg, Var):
|
|
848
|
-
|
|
866
|
+
if is_tuple(arg.type):
|
|
867
|
+
return tuple(get_arg_value(x) for x in arg.type.values)
|
|
868
|
+
|
|
869
|
+
if arg.constant is not None:
|
|
870
|
+
return arg.constant
|
|
849
871
|
|
|
850
872
|
return arg
|
|
851
873
|
|
|
@@ -863,7 +885,8 @@ class Adjoint:
|
|
|
863
885
|
skip_reverse_codegen=False,
|
|
864
886
|
custom_reverse_mode=False,
|
|
865
887
|
custom_reverse_num_input_args=-1,
|
|
866
|
-
transformers:
|
|
888
|
+
transformers: list[ast.NodeTransformer] | None = None,
|
|
889
|
+
source: str | None = None,
|
|
867
890
|
):
|
|
868
891
|
adj.func = func
|
|
869
892
|
|
|
@@ -877,19 +900,17 @@ class Adjoint:
|
|
|
877
900
|
# extract name of source file
|
|
878
901
|
adj.filename = inspect.getsourcefile(func) or "unknown source file"
|
|
879
902
|
# get source file line number where function starts
|
|
880
|
-
|
|
881
|
-
|
|
882
|
-
|
|
883
|
-
|
|
884
|
-
|
|
885
|
-
|
|
886
|
-
) from e
|
|
903
|
+
adj.fun_lineno = 0
|
|
904
|
+
adj.source = source
|
|
905
|
+
if adj.source is None:
|
|
906
|
+
adj.source, adj.fun_lineno = adj.extract_function_source(func)
|
|
907
|
+
|
|
908
|
+
assert adj.source is not None, f"Failed to extract source code for function {func.__name__}"
|
|
887
909
|
|
|
888
910
|
# Indicates where the function definition starts (excludes decorators)
|
|
889
911
|
adj.fun_def_lineno = None
|
|
890
912
|
|
|
891
913
|
# get function source code
|
|
892
|
-
adj.source = inspect.getsource(func)
|
|
893
914
|
# ensures that indented class methods can be parsed as kernels
|
|
894
915
|
adj.source = textwrap.dedent(adj.source)
|
|
895
916
|
|
|
@@ -948,9 +969,14 @@ class Adjoint:
|
|
|
948
969
|
# this is to avoid registering false references to overshadowed modules
|
|
949
970
|
adj.symbols[name] = arg
|
|
950
971
|
|
|
972
|
+
# Indicates whether there are unresolved static expressions in the function.
|
|
973
|
+
# These stem from wp.static() expressions that could not be evaluated at declaration time.
|
|
974
|
+
# This will signal to the module builder that this module needs to be rebuilt even if the module hash is unchanged.
|
|
975
|
+
adj.has_unresolved_static_expressions = False
|
|
976
|
+
|
|
951
977
|
# try to replace static expressions by their constant result if the
|
|
952
978
|
# expression can be evaluated at declaration time
|
|
953
|
-
adj.static_expressions:
|
|
979
|
+
adj.static_expressions: dict[str, Any] = {}
|
|
954
980
|
if "static" in adj.source:
|
|
955
981
|
adj.replace_static_expressions()
|
|
956
982
|
|
|
@@ -981,6 +1007,18 @@ class Adjoint:
|
|
|
981
1007
|
|
|
982
1008
|
return total_shared + adj.max_required_extra_shared_memory
|
|
983
1009
|
|
|
1010
|
+
@staticmethod
|
|
1011
|
+
def extract_function_source(func: Callable) -> tuple[str, int]:
|
|
1012
|
+
try:
|
|
1013
|
+
_, fun_lineno = inspect.getsourcelines(func)
|
|
1014
|
+
source = inspect.getsource(func)
|
|
1015
|
+
except OSError as e:
|
|
1016
|
+
raise RuntimeError(
|
|
1017
|
+
"Directly evaluating Warp code defined as a string using `exec()` is not supported, "
|
|
1018
|
+
"please save it to a file and use `importlib` if needed."
|
|
1019
|
+
) from e
|
|
1020
|
+
return source, fun_lineno
|
|
1021
|
+
|
|
984
1022
|
# generate function ssa form and adjoint
|
|
985
1023
|
def build(adj, builder, default_builder_options=None):
|
|
986
1024
|
# arg Var read/write flags are held during module rebuilds, so we reset here even when skipping a build
|
|
@@ -1058,7 +1096,7 @@ class Adjoint:
|
|
|
1058
1096
|
# code generation methods
|
|
1059
1097
|
def format_template(adj, template, input_vars, output_var):
|
|
1060
1098
|
# output var is always the 0th index
|
|
1061
|
-
args = [output_var
|
|
1099
|
+
args = [output_var, *input_vars]
|
|
1062
1100
|
s = template.format(*args)
|
|
1063
1101
|
|
|
1064
1102
|
return s
|
|
@@ -1176,7 +1214,7 @@ class Adjoint:
|
|
|
1176
1214
|
|
|
1177
1215
|
return var
|
|
1178
1216
|
|
|
1179
|
-
def get_line_directive(adj, statement: str, relative_lineno:
|
|
1217
|
+
def get_line_directive(adj, statement: str, relative_lineno: int | None = None) -> str | None:
|
|
1180
1218
|
"""Get a line directive for the given statement.
|
|
1181
1219
|
|
|
1182
1220
|
Args:
|
|
@@ -1202,7 +1240,7 @@ class Adjoint:
|
|
|
1202
1240
|
return f'#line {line} "{normalized_path}"'
|
|
1203
1241
|
return None
|
|
1204
1242
|
|
|
1205
|
-
def add_forward(adj, statement: str, replay:
|
|
1243
|
+
def add_forward(adj, statement: str, replay: str | None = None, skip_replay: builtins.bool = False) -> None:
|
|
1206
1244
|
"""Append a statement to the forward pass."""
|
|
1207
1245
|
|
|
1208
1246
|
if line_directive := adj.get_line_directive(statement, adj.lineno):
|
|
@@ -1300,7 +1338,8 @@ class Adjoint:
|
|
|
1300
1338
|
|
|
1301
1339
|
# check output dimensions match expectations
|
|
1302
1340
|
if min_outputs:
|
|
1303
|
-
|
|
1341
|
+
value_type = f.value_func(None, None)
|
|
1342
|
+
if not isinstance(value_type, Sequence) or len(value_type) != min_outputs:
|
|
1304
1343
|
continue
|
|
1305
1344
|
|
|
1306
1345
|
# found a match, use it
|
|
@@ -1396,6 +1435,17 @@ class Adjoint:
|
|
|
1396
1435
|
bound_arg_values,
|
|
1397
1436
|
)
|
|
1398
1437
|
|
|
1438
|
+
# Handle the special case where a Var instance is returned from the `value_func`
|
|
1439
|
+
# callback, in which case we replace the call with a reference to that variable.
|
|
1440
|
+
if isinstance(return_type, Var):
|
|
1441
|
+
return adj.register_var(return_type)
|
|
1442
|
+
elif isinstance(return_type, Sequence) and all(isinstance(x, Var) for x in return_type):
|
|
1443
|
+
return tuple(adj.register_var(x) for x in return_type)
|
|
1444
|
+
|
|
1445
|
+
if get_origin(return_type) is tuple:
|
|
1446
|
+
types = get_args(return_type)
|
|
1447
|
+
return_type = warp.types.tuple_t(types=types, values=(None,) * len(types))
|
|
1448
|
+
|
|
1399
1449
|
# immediately allocate output variables so we can pass them into the dispatch method
|
|
1400
1450
|
if return_type is None:
|
|
1401
1451
|
# void function
|
|
@@ -1775,6 +1825,22 @@ class Adjoint:
|
|
|
1775
1825
|
out = adj.add_builtin_call("where", [cond, var1, var2])
|
|
1776
1826
|
adj.symbols[sym] = out
|
|
1777
1827
|
|
|
1828
|
+
def emit_IfExp(adj, node):
|
|
1829
|
+
cond = adj.eval(node.test)
|
|
1830
|
+
|
|
1831
|
+
if cond.constant is not None:
|
|
1832
|
+
return adj.eval(node.body) if cond.constant else adj.eval(node.orelse)
|
|
1833
|
+
|
|
1834
|
+
adj.begin_if(cond)
|
|
1835
|
+
body = adj.eval(node.body)
|
|
1836
|
+
adj.end_if(cond)
|
|
1837
|
+
|
|
1838
|
+
adj.begin_else(cond)
|
|
1839
|
+
orelse = adj.eval(node.orelse)
|
|
1840
|
+
adj.end_else(cond)
|
|
1841
|
+
|
|
1842
|
+
return adj.add_builtin_call("where", [cond, body, orelse])
|
|
1843
|
+
|
|
1778
1844
|
def emit_Compare(adj, node):
|
|
1779
1845
|
# node.left, node.ops (list of ops), node.comparators (things to compare to)
|
|
1780
1846
|
# e.g. (left ops[0] node.comparators[0]) ops[1] node.comparators[1]
|
|
@@ -1831,7 +1897,7 @@ class Adjoint:
|
|
|
1831
1897
|
if attr == "dtype":
|
|
1832
1898
|
return type_scalar_type(var_type)
|
|
1833
1899
|
elif attr == "length":
|
|
1834
|
-
return
|
|
1900
|
+
return type_size(var_type)
|
|
1835
1901
|
|
|
1836
1902
|
return getattr(var_type, attr, None)
|
|
1837
1903
|
|
|
@@ -1850,6 +1916,15 @@ class Adjoint:
|
|
|
1850
1916
|
index = adj.add_constant(index)
|
|
1851
1917
|
return index
|
|
1852
1918
|
|
|
1919
|
+
def transform_component(adj, component):
|
|
1920
|
+
if len(component) != 1:
|
|
1921
|
+
raise WarpCodegenAttributeError(f"Transform attribute must be single character, got .{component}")
|
|
1922
|
+
|
|
1923
|
+
if component not in ("p", "q"):
|
|
1924
|
+
raise WarpCodegenAttributeError(f"Attribute for transformation must be either 'p' or 'q', got {component}")
|
|
1925
|
+
|
|
1926
|
+
return component
|
|
1927
|
+
|
|
1853
1928
|
@staticmethod
|
|
1854
1929
|
def is_differentiable_value_type(var_type):
|
|
1855
1930
|
# checks that the argument type is a value type (i.e, not an array)
|
|
@@ -1880,12 +1955,20 @@ class Adjoint:
|
|
|
1880
1955
|
|
|
1881
1956
|
aggregate_type = strip_reference(aggregate.type)
|
|
1882
1957
|
|
|
1883
|
-
# reading a vector component
|
|
1884
|
-
if type_is_vector(aggregate_type):
|
|
1958
|
+
# reading a vector or quaternion component
|
|
1959
|
+
if type_is_vector(aggregate_type) or type_is_quaternion(aggregate_type):
|
|
1885
1960
|
index = adj.vector_component_index(node.attr, aggregate_type)
|
|
1886
1961
|
|
|
1887
1962
|
return adj.add_builtin_call("extract", [aggregate, index])
|
|
1888
1963
|
|
|
1964
|
+
elif type_is_transformation(aggregate_type):
|
|
1965
|
+
component = adj.transform_component(node.attr)
|
|
1966
|
+
|
|
1967
|
+
if component == "p":
|
|
1968
|
+
return adj.add_builtin_call("transform_get_translation", [aggregate])
|
|
1969
|
+
else:
|
|
1970
|
+
return adj.add_builtin_call("transform_get_rotation", [aggregate])
|
|
1971
|
+
|
|
1889
1972
|
else:
|
|
1890
1973
|
attr_type = Reference(aggregate_type.vars[node.attr].type)
|
|
1891
1974
|
attr = adj.add_var(attr_type)
|
|
@@ -2246,8 +2329,9 @@ class Adjoint:
|
|
|
2246
2329
|
|
|
2247
2330
|
if adj.is_static_expression(func):
|
|
2248
2331
|
# try to evaluate wp.static() expressions
|
|
2249
|
-
obj,
|
|
2332
|
+
obj, code = adj.evaluate_static_expression(node)
|
|
2250
2333
|
if obj is not None:
|
|
2334
|
+
adj.static_expressions[code] = obj
|
|
2251
2335
|
if isinstance(obj, warp.context.Function):
|
|
2252
2336
|
# special handling for wp.static() evaluating to a function
|
|
2253
2337
|
return obj
|
|
@@ -2282,6 +2366,10 @@ class Adjoint:
|
|
|
2282
2366
|
else:
|
|
2283
2367
|
func = caller.default_constructor
|
|
2284
2368
|
|
|
2369
|
+
# lambda function
|
|
2370
|
+
if func is None and getattr(caller, "__name__", None) == "<lambda>":
|
|
2371
|
+
raise NotImplementedError("Lambda expressions are not yet supported")
|
|
2372
|
+
|
|
2285
2373
|
if hasattr(caller, "_wp_type_args_"):
|
|
2286
2374
|
type_args = caller._wp_type_args_
|
|
2287
2375
|
|
|
@@ -2290,18 +2378,6 @@ class Adjoint:
|
|
|
2290
2378
|
f"Could not find function {'.'.join(path)} as a built-in or user-defined function. Note that user functions must be annotated with a @wp.func decorator to be called from a kernel."
|
|
2291
2379
|
)
|
|
2292
2380
|
|
|
2293
|
-
# Check if any argument correspond to an unsupported construct.
|
|
2294
|
-
# Tuples are supported in the context of assigning multiple variables
|
|
2295
|
-
# at once, but not in place of vectors when calling built-ins like
|
|
2296
|
-
# `wp.length((1, 2, 3))`.
|
|
2297
|
-
# Therefore, we need to catch this specific case here instead of
|
|
2298
|
-
# more generally in `adj.eval()`.
|
|
2299
|
-
for arg in node.args:
|
|
2300
|
-
if isinstance(arg, ast.Tuple):
|
|
2301
|
-
raise WarpCodegenError(
|
|
2302
|
-
"Tuple constructs are not supported in kernels. Use vectors like `wp.vec3()` instead."
|
|
2303
|
-
)
|
|
2304
|
-
|
|
2305
2381
|
# get expected return count, e.g.: for multi-assignment
|
|
2306
2382
|
min_outputs = None
|
|
2307
2383
|
if hasattr(node, "expects"):
|
|
@@ -2311,7 +2387,6 @@ class Adjoint:
|
|
|
2311
2387
|
args = tuple(adj.resolve_arg(x) for x in node.args)
|
|
2312
2388
|
kwargs = {x.arg: adj.resolve_arg(x.value) for x in node.keywords}
|
|
2313
2389
|
|
|
2314
|
-
# add the call and build the callee adjoint if needed (func.adj)
|
|
2315
2390
|
out = adj.add_call(func, args, kwargs, type_args, min_outputs=min_outputs)
|
|
2316
2391
|
|
|
2317
2392
|
if warp.config.verify_autograd_array_access:
|
|
@@ -2461,10 +2536,6 @@ class Adjoint:
|
|
|
2461
2536
|
raise WarpCodegenError(
|
|
2462
2537
|
"List constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead."
|
|
2463
2538
|
)
|
|
2464
|
-
elif isinstance(node.value, ast.Tuple):
|
|
2465
|
-
raise WarpCodegenError(
|
|
2466
|
-
"Tuple constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead."
|
|
2467
|
-
)
|
|
2468
2539
|
|
|
2469
2540
|
# handle the case where we are assigning multiple output variables
|
|
2470
2541
|
if isinstance(lhs, ast.Tuple):
|
|
@@ -2480,6 +2551,17 @@ class Adjoint:
|
|
|
2480
2551
|
else:
|
|
2481
2552
|
out = adj.eval(node.value)
|
|
2482
2553
|
|
|
2554
|
+
subtype = getattr(out, "type", None)
|
|
2555
|
+
if isinstance(subtype, warp.types.tuple_t):
|
|
2556
|
+
if len(out.type.types) != len(lhs.elts):
|
|
2557
|
+
raise WarpCodegenError(
|
|
2558
|
+
f"Invalid number of values to unpack (expected {len(lhs.elts)}, got {len(out.type.types)})."
|
|
2559
|
+
)
|
|
2560
|
+
target = out
|
|
2561
|
+
out = tuple(
|
|
2562
|
+
adj.add_builtin_call("extract", (target, adj.add_constant(i))) for i in range(len(lhs.elts))
|
|
2563
|
+
)
|
|
2564
|
+
|
|
2483
2565
|
names = []
|
|
2484
2566
|
for v in lhs.elts:
|
|
2485
2567
|
if isinstance(v, ast.Name):
|
|
@@ -2532,7 +2614,12 @@ class Adjoint:
|
|
|
2532
2614
|
elif is_tile(target_type):
|
|
2533
2615
|
adj.add_builtin_call("assign", [target, *indices, rhs])
|
|
2534
2616
|
|
|
2535
|
-
elif
|
|
2617
|
+
elif (
|
|
2618
|
+
type_is_vector(target_type)
|
|
2619
|
+
or type_is_quaternion(target_type)
|
|
2620
|
+
or type_is_matrix(target_type)
|
|
2621
|
+
or type_is_transformation(target_type)
|
|
2622
|
+
):
|
|
2536
2623
|
# recursively unwind AST, stopping at penultimate node
|
|
2537
2624
|
node = lhs
|
|
2538
2625
|
while hasattr(node, "value"):
|
|
@@ -2572,7 +2659,7 @@ class Adjoint:
|
|
|
2572
2659
|
|
|
2573
2660
|
else:
|
|
2574
2661
|
raise WarpCodegenError(
|
|
2575
|
-
f"Can only subscript assign array, vector, quaternion, and matrix types, got {target_type}"
|
|
2662
|
+
f"Can only subscript assign array, vector, quaternion, transformation, and matrix types, got {target_type}"
|
|
2576
2663
|
)
|
|
2577
2664
|
|
|
2578
2665
|
elif isinstance(lhs, ast.Name):
|
|
@@ -2589,8 +2676,11 @@ class Adjoint:
|
|
|
2589
2676
|
f"Error, assigning to existing symbol {name} ({adj.symbols[name].type}) with different type ({rhs.type})"
|
|
2590
2677
|
)
|
|
2591
2678
|
|
|
2592
|
-
|
|
2593
|
-
|
|
2679
|
+
if isinstance(node.value, ast.Tuple):
|
|
2680
|
+
out = rhs
|
|
2681
|
+
elif isinstance(rhs, Sequence):
|
|
2682
|
+
out = adj.add_builtin_call("tuple", rhs)
|
|
2683
|
+
elif isinstance(node.value, ast.Name) or is_reference(rhs.type):
|
|
2594
2684
|
out = adj.add_builtin_call("copy", [rhs])
|
|
2595
2685
|
else:
|
|
2596
2686
|
out = rhs
|
|
@@ -2622,6 +2712,18 @@ class Adjoint:
|
|
|
2622
2712
|
else:
|
|
2623
2713
|
adj.add_builtin_call("assign_inplace", [aggregate, index, rhs])
|
|
2624
2714
|
|
|
2715
|
+
elif type_is_transformation(aggregate_type):
|
|
2716
|
+
component = adj.transform_component(lhs.attr)
|
|
2717
|
+
|
|
2718
|
+
# TODO: x[i,j].p = rhs case
|
|
2719
|
+
if is_reference(aggregate.type):
|
|
2720
|
+
raise WarpCodegenError(f"Error, assigning transform attribute {component} to an array element")
|
|
2721
|
+
|
|
2722
|
+
if component == "p":
|
|
2723
|
+
return adj.add_builtin_call("transform_set_translation", [aggregate, rhs])
|
|
2724
|
+
else:
|
|
2725
|
+
return adj.add_builtin_call("transform_set_rotation", [aggregate, rhs])
|
|
2726
|
+
|
|
2625
2727
|
else:
|
|
2626
2728
|
attr = adj.emit_Attribute(lhs)
|
|
2627
2729
|
if is_reference(attr.type):
|
|
@@ -2644,7 +2746,9 @@ class Adjoint:
|
|
|
2644
2746
|
elif isinstance(node.value, ast.Tuple):
|
|
2645
2747
|
var = tuple(adj.eval(arg) for arg in node.value.elts)
|
|
2646
2748
|
else:
|
|
2647
|
-
var =
|
|
2749
|
+
var = adj.eval(node.value)
|
|
2750
|
+
if not isinstance(var, list) and not isinstance(var, tuple):
|
|
2751
|
+
var = (var,)
|
|
2648
2752
|
|
|
2649
2753
|
if adj.return_var is not None:
|
|
2650
2754
|
old_ctypes = tuple(v.ctype(value_type=True) for v in adj.return_var)
|
|
@@ -2697,6 +2801,7 @@ class Adjoint:
|
|
|
2697
2801
|
type_is_vector(target_type.dtype)
|
|
2698
2802
|
or type_is_quaternion(target_type.dtype)
|
|
2699
2803
|
or type_is_matrix(target_type.dtype)
|
|
2804
|
+
or type_is_transformation(target_type.dtype)
|
|
2700
2805
|
):
|
|
2701
2806
|
dtype = getattr(target_type.dtype, "_wp_scalar_type_", None)
|
|
2702
2807
|
if dtype in warp.types.non_atomic_types:
|
|
@@ -2724,7 +2829,12 @@ class Adjoint:
|
|
|
2724
2829
|
make_new_assign_statement()
|
|
2725
2830
|
return
|
|
2726
2831
|
|
|
2727
|
-
elif
|
|
2832
|
+
elif (
|
|
2833
|
+
type_is_vector(target_type)
|
|
2834
|
+
or type_is_quaternion(target_type)
|
|
2835
|
+
or type_is_matrix(target_type)
|
|
2836
|
+
or type_is_transformation(target_type)
|
|
2837
|
+
):
|
|
2728
2838
|
if isinstance(node.op, ast.Add):
|
|
2729
2839
|
adj.add_builtin_call("add_inplace", [target, *indices, rhs])
|
|
2730
2840
|
elif isinstance(node.op, ast.Sub):
|
|
@@ -2735,9 +2845,36 @@ class Adjoint:
|
|
|
2735
2845
|
make_new_assign_statement()
|
|
2736
2846
|
return
|
|
2737
2847
|
|
|
2848
|
+
elif is_tile(target.type):
|
|
2849
|
+
if isinstance(node.op, ast.Add):
|
|
2850
|
+
adj.add_builtin_call("tile_add_inplace", [target, *indices, rhs])
|
|
2851
|
+
elif isinstance(node.op, ast.Sub):
|
|
2852
|
+
adj.add_builtin_call("tile_sub_inplace", [target, *indices, rhs])
|
|
2853
|
+
else:
|
|
2854
|
+
if warp.config.verbose:
|
|
2855
|
+
print(f"Warning: in-place op {node.op} is not differentiable")
|
|
2856
|
+
make_new_assign_statement()
|
|
2857
|
+
return
|
|
2858
|
+
|
|
2738
2859
|
else:
|
|
2739
2860
|
raise WarpCodegenError("Can only subscript in-place assign array, vector, quaternion, and matrix types")
|
|
2740
2861
|
|
|
2862
|
+
elif isinstance(lhs, ast.Name):
|
|
2863
|
+
target = adj.eval(node.target)
|
|
2864
|
+
rhs = adj.eval(node.value)
|
|
2865
|
+
|
|
2866
|
+
if is_tile(target.type) and is_tile(rhs.type):
|
|
2867
|
+
if isinstance(node.op, ast.Add):
|
|
2868
|
+
adj.add_builtin_call("add_inplace", [target, rhs])
|
|
2869
|
+
elif isinstance(node.op, ast.Sub):
|
|
2870
|
+
adj.add_builtin_call("sub_inplace", [target, rhs])
|
|
2871
|
+
else:
|
|
2872
|
+
make_new_assign_statement()
|
|
2873
|
+
return
|
|
2874
|
+
else:
|
|
2875
|
+
make_new_assign_statement()
|
|
2876
|
+
return
|
|
2877
|
+
|
|
2741
2878
|
# TODO
|
|
2742
2879
|
elif isinstance(lhs, ast.Attribute):
|
|
2743
2880
|
make_new_assign_statement()
|
|
@@ -2748,15 +2885,16 @@ class Adjoint:
|
|
|
2748
2885
|
return
|
|
2749
2886
|
|
|
2750
2887
|
def emit_Tuple(adj, node):
|
|
2751
|
-
|
|
2752
|
-
return
|
|
2888
|
+
elements = tuple(adj.eval(x) for x in node.elts)
|
|
2889
|
+
return adj.add_builtin_call("tuple", elements)
|
|
2753
2890
|
|
|
2754
2891
|
def emit_Pass(adj, node):
|
|
2755
2892
|
pass
|
|
2756
2893
|
|
|
2757
|
-
node_visitors = {
|
|
2894
|
+
node_visitors: ClassVar[dict[type[ast.AST], Callable]] = {
|
|
2758
2895
|
ast.FunctionDef: emit_FunctionDef,
|
|
2759
2896
|
ast.If: emit_If,
|
|
2897
|
+
ast.IfExp: emit_IfExp,
|
|
2760
2898
|
ast.Compare: emit_Compare,
|
|
2761
2899
|
ast.BoolOp: emit_BoolOp,
|
|
2762
2900
|
ast.Name: emit_Name,
|
|
@@ -2860,11 +2998,11 @@ class Adjoint:
|
|
|
2860
2998
|
if isinstance(value, warp.context.Function):
|
|
2861
2999
|
return True
|
|
2862
3000
|
|
|
2863
|
-
def verify_struct(s: StructInstance, attr_path:
|
|
3001
|
+
def verify_struct(s: StructInstance, attr_path: list[str]):
|
|
2864
3002
|
for key in s._cls.vars.keys():
|
|
2865
3003
|
v = getattr(s, key)
|
|
2866
3004
|
if issubclass(type(v), StructInstance):
|
|
2867
|
-
verify_struct(v, attr_path
|
|
3005
|
+
verify_struct(v, [*attr_path, key])
|
|
2868
3006
|
else:
|
|
2869
3007
|
try:
|
|
2870
3008
|
adj.verify_static_return_value(v)
|
|
@@ -2879,7 +3017,8 @@ class Adjoint:
|
|
|
2879
3017
|
raise ValueError(f"value of type {type(value)} cannot be constructed inside Warp kernels")
|
|
2880
3018
|
|
|
2881
3019
|
# find the source code string of an AST node
|
|
2882
|
-
|
|
3020
|
+
@staticmethod
|
|
3021
|
+
def extract_node_source_from_lines(source_lines, node) -> str | None:
|
|
2883
3022
|
if not hasattr(node, "lineno") or not hasattr(node, "col_offset"):
|
|
2884
3023
|
return None
|
|
2885
3024
|
|
|
@@ -2895,12 +3034,12 @@ class Adjoint:
|
|
|
2895
3034
|
end_line = start_line
|
|
2896
3035
|
end_col = start_col
|
|
2897
3036
|
parenthesis_count = 1
|
|
2898
|
-
for lineno in range(start_line, len(
|
|
3037
|
+
for lineno in range(start_line, len(source_lines)):
|
|
2899
3038
|
if lineno == start_line:
|
|
2900
3039
|
c_start = start_col
|
|
2901
3040
|
else:
|
|
2902
3041
|
c_start = 0
|
|
2903
|
-
line =
|
|
3042
|
+
line = source_lines[lineno]
|
|
2904
3043
|
for i in range(c_start, len(line)):
|
|
2905
3044
|
c = line[i]
|
|
2906
3045
|
if c == "(":
|
|
@@ -2916,21 +3055,57 @@ class Adjoint:
|
|
|
2916
3055
|
|
|
2917
3056
|
if start_line == end_line:
|
|
2918
3057
|
# single-line expression
|
|
2919
|
-
return
|
|
3058
|
+
return source_lines[start_line][start_col:end_col]
|
|
2920
3059
|
else:
|
|
2921
3060
|
# multi-line expression
|
|
2922
3061
|
lines = []
|
|
2923
3062
|
# first line (from start_col to the end)
|
|
2924
|
-
lines.append(
|
|
3063
|
+
lines.append(source_lines[start_line][start_col:])
|
|
2925
3064
|
# middle lines (entire lines)
|
|
2926
|
-
lines.extend(
|
|
3065
|
+
lines.extend(source_lines[start_line + 1 : end_line])
|
|
2927
3066
|
# last line (from the start to end_col)
|
|
2928
|
-
lines.append(
|
|
3067
|
+
lines.append(source_lines[end_line][:end_col])
|
|
2929
3068
|
return "\n".join(lines).strip()
|
|
2930
3069
|
|
|
3070
|
+
@staticmethod
|
|
3071
|
+
def extract_lambda_source(func, only_body=False) -> str | None:
|
|
3072
|
+
try:
|
|
3073
|
+
source_lines = inspect.getsourcelines(func)[0]
|
|
3074
|
+
source_lines[0] = source_lines[0][source_lines[0].index("lambda") :]
|
|
3075
|
+
except OSError as e:
|
|
3076
|
+
raise WarpCodegenError(
|
|
3077
|
+
"Could not access lambda function source code. Please use a named function instead."
|
|
3078
|
+
) from e
|
|
3079
|
+
source = "".join(source_lines)
|
|
3080
|
+
source = source[source.index("lambda") :].rstrip()
|
|
3081
|
+
# Remove trailing unbalanced parentheses
|
|
3082
|
+
while source.count("(") < source.count(")"):
|
|
3083
|
+
source = source[:-1]
|
|
3084
|
+
# extract lambda expression up until a comma, e.g. in the case of
|
|
3085
|
+
# "map(lambda a: (a + 2.0, a + 3.0), a, return_kernel=True)"
|
|
3086
|
+
si = max(source.find(")"), source.find(":"))
|
|
3087
|
+
ci = source.find(",", si)
|
|
3088
|
+
if ci != -1:
|
|
3089
|
+
source = source[:ci]
|
|
3090
|
+
tree = ast.parse(source)
|
|
3091
|
+
lambda_source = None
|
|
3092
|
+
for node in ast.walk(tree):
|
|
3093
|
+
if isinstance(node, ast.Lambda):
|
|
3094
|
+
if only_body:
|
|
3095
|
+
# extract the body of the lambda function
|
|
3096
|
+
lambda_source = Adjoint.extract_node_source_from_lines(source_lines, node.body)
|
|
3097
|
+
else:
|
|
3098
|
+
# extract the entire lambda function
|
|
3099
|
+
lambda_source = Adjoint.extract_node_source_from_lines(source_lines, node)
|
|
3100
|
+
break
|
|
3101
|
+
return lambda_source
|
|
3102
|
+
|
|
3103
|
+
def extract_node_source(adj, node) -> str | None:
|
|
3104
|
+
return adj.extract_node_source_from_lines(adj.source_lines, node)
|
|
3105
|
+
|
|
2931
3106
|
# handles a wp.static() expression and returns the resulting object and a string representing the code
|
|
2932
3107
|
# of the static expression
|
|
2933
|
-
def evaluate_static_expression(adj, node) ->
|
|
3108
|
+
def evaluate_static_expression(adj, node) -> tuple[Any, str]:
|
|
2934
3109
|
if len(node.args) == 1:
|
|
2935
3110
|
static_code = adj.extract_node_source(node.args[0])
|
|
2936
3111
|
elif len(node.keywords) == 1:
|
|
@@ -2942,6 +3117,7 @@ class Adjoint:
|
|
|
2942
3117
|
|
|
2943
3118
|
# Since this is an expression, we can enforce it to be defined on a single line.
|
|
2944
3119
|
static_code = static_code.replace("\n", "")
|
|
3120
|
+
code_to_eval = static_code # code to be evaluated
|
|
2945
3121
|
|
|
2946
3122
|
vars_dict = adj.get_static_evaluation_context()
|
|
2947
3123
|
# add constant variables to the static call context
|
|
@@ -2950,29 +3126,14 @@ class Adjoint:
|
|
|
2950
3126
|
|
|
2951
3127
|
# Replace all constant `len()` expressions with their value.
|
|
2952
3128
|
if "len" in static_code:
|
|
2953
|
-
|
|
2954
|
-
def eval_len(obj):
|
|
2955
|
-
if type_is_vector(obj):
|
|
2956
|
-
return obj._length_
|
|
2957
|
-
elif type_is_quaternion(obj):
|
|
2958
|
-
return obj._length_
|
|
2959
|
-
elif type_is_matrix(obj):
|
|
2960
|
-
return obj._shape_[0]
|
|
2961
|
-
elif type_is_transformation(obj):
|
|
2962
|
-
return obj._length_
|
|
2963
|
-
elif is_tile(obj):
|
|
2964
|
-
return obj.shape[0]
|
|
2965
|
-
|
|
2966
|
-
return len(obj)
|
|
2967
|
-
|
|
2968
3129
|
len_expr_ctx = vars_dict.copy()
|
|
2969
3130
|
constant_types = {k: v.type for k, v in adj.symbols.items() if isinstance(v, Var) and v.type is not None}
|
|
2970
3131
|
len_expr_ctx.update(constant_types)
|
|
2971
|
-
len_expr_ctx.update({"len":
|
|
3132
|
+
len_expr_ctx.update({"len": warp.types.type_length})
|
|
2972
3133
|
|
|
2973
3134
|
# We want to replace the expression code in-place,
|
|
2974
3135
|
# so reparse it to get the correct column info.
|
|
2975
|
-
len_value_locs:
|
|
3136
|
+
len_value_locs: list[tuple[int, int, int]] = []
|
|
2976
3137
|
expr_tree = ast.parse(static_code)
|
|
2977
3138
|
assert len(expr_tree.body) == 1 and isinstance(expr_tree.body[0], ast.Expr)
|
|
2978
3139
|
expr_root = expr_tree.body[0].value
|
|
@@ -2998,10 +3159,10 @@ class Adjoint:
|
|
|
2998
3159
|
loc = end
|
|
2999
3160
|
|
|
3000
3161
|
new_static_code += static_code[len_value_locs[-1][2] :]
|
|
3001
|
-
|
|
3162
|
+
code_to_eval = new_static_code
|
|
3002
3163
|
|
|
3003
3164
|
try:
|
|
3004
|
-
value = eval(
|
|
3165
|
+
value = eval(code_to_eval, vars_dict)
|
|
3005
3166
|
if warp.config.verbose:
|
|
3006
3167
|
print(f"Evaluated static command: {static_code} = {value}")
|
|
3007
3168
|
except NameError as e:
|
|
@@ -3054,6 +3215,9 @@ class Adjoint:
|
|
|
3054
3215
|
# (and is therefore not executable and raises this exception), in which
|
|
3055
3216
|
# case changing the constant, or the code affecting this constant, would lead to
|
|
3056
3217
|
# a different module hash anyway.
|
|
3218
|
+
# In any case, we mark this Adjoint to have unresolvable static expressions.
|
|
3219
|
+
# This will trigger a code generation step even if the module hash is unchanged.
|
|
3220
|
+
adj.has_unresolved_static_expressions = True
|
|
3057
3221
|
pass
|
|
3058
3222
|
|
|
3059
3223
|
return self.generic_visit(node)
|
|
@@ -3134,14 +3298,14 @@ class Adjoint:
|
|
|
3134
3298
|
# return the Python code corresponding to the given AST node
|
|
3135
3299
|
return ast.get_source_segment(adj.source, node)
|
|
3136
3300
|
|
|
3137
|
-
def get_references(adj) ->
|
|
3301
|
+
def get_references(adj) -> tuple[dict[str, Any], dict[Any, Any], dict[warp.context.Function, Any]]:
|
|
3138
3302
|
"""Traverses ``adj.tree`` and returns referenced constants, types, and user-defined functions."""
|
|
3139
3303
|
|
|
3140
3304
|
local_variables = set() # Track local variables appearing on the LHS so we know when variables are shadowed
|
|
3141
3305
|
|
|
3142
|
-
constants:
|
|
3143
|
-
types:
|
|
3144
|
-
functions:
|
|
3306
|
+
constants: dict[str, Any] = {}
|
|
3307
|
+
types: dict[Struct | type, Any] = {}
|
|
3308
|
+
functions: dict[warp.context.Function, Any] = {}
|
|
3145
3309
|
|
|
3146
3310
|
for node in ast.walk(adj.tree):
|
|
3147
3311
|
if isinstance(node, ast.Name) and node.id not in local_variables:
|
|
@@ -3200,6 +3364,8 @@ cpu_module_header = """
|
|
|
3200
3364
|
#define builtin_tid3d(x, y, z) wp::tid(x, y, z, task_index, dim)
|
|
3201
3365
|
#define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, task_index, dim)
|
|
3202
3366
|
|
|
3367
|
+
#define builtin_block_dim() wp::block_dim()
|
|
3368
|
+
|
|
3203
3369
|
"""
|
|
3204
3370
|
|
|
3205
3371
|
cuda_module_header = """
|
|
@@ -3219,6 +3385,8 @@ cuda_module_header = """
|
|
|
3219
3385
|
#define builtin_tid3d(x, y, z) wp::tid(x, y, z, _idx, dim)
|
|
3220
3386
|
#define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, _idx, dim)
|
|
3221
3387
|
|
|
3388
|
+
#define builtin_block_dim() wp::block_dim()
|
|
3389
|
+
|
|
3222
3390
|
"""
|
|
3223
3391
|
|
|
3224
3392
|
struct_template = """
|
|
@@ -3663,7 +3831,7 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
|
|
|
3663
3831
|
f"annotated as a tuple of {len(get_args(adj.arg_types['return']))} elements "
|
|
3664
3832
|
f"but the code returns {len(adj.return_var)} values."
|
|
3665
3833
|
)
|
|
3666
|
-
elif not types_equal(adj.arg_types["return"], tuple(x.type for x in adj.return_var)):
|
|
3834
|
+
elif not types_equal(adj.arg_types["return"], tuple(x.type for x in adj.return_var), match_generic=True):
|
|
3667
3835
|
raise WarpCodegenError(
|
|
3668
3836
|
f"The function `{adj.fun_name}` has its return type "
|
|
3669
3837
|
f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
|