warp-lang 1.7.2__py3-none-win_amd64.whl → 1.8.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. See the package's registry page for more details.

Files changed (181)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/warp-clang.dll +0 -0
  5. warp/bin/warp.dll +0 -0
  6. warp/build.py +241 -252
  7. warp/build_dll.py +125 -26
  8. warp/builtins.py +1907 -384
  9. warp/codegen.py +257 -101
  10. warp/config.py +12 -1
  11. warp/constants.py +1 -1
  12. warp/context.py +657 -223
  13. warp/dlpack.py +1 -1
  14. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  15. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  16. warp/examples/core/example_sample_mesh.py +1 -1
  17. warp/examples/core/example_spin_lock.py +93 -0
  18. warp/examples/core/example_work_queue.py +118 -0
  19. warp/examples/fem/example_adaptive_grid.py +5 -5
  20. warp/examples/fem/example_apic_fluid.py +1 -1
  21. warp/examples/fem/example_burgers.py +1 -1
  22. warp/examples/fem/example_convection_diffusion.py +9 -6
  23. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  24. warp/examples/fem/example_deformed_geometry.py +1 -1
  25. warp/examples/fem/example_diffusion.py +2 -2
  26. warp/examples/fem/example_diffusion_3d.py +1 -1
  27. warp/examples/fem/example_distortion_energy.py +1 -1
  28. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  29. warp/examples/fem/example_magnetostatics.py +5 -3
  30. warp/examples/fem/example_mixed_elasticity.py +5 -3
  31. warp/examples/fem/example_navier_stokes.py +11 -9
  32. warp/examples/fem/example_nonconforming_contact.py +5 -3
  33. warp/examples/fem/example_streamlines.py +8 -3
  34. warp/examples/fem/utils.py +9 -8
  35. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  36. warp/examples/optim/example_drone.py +1 -1
  37. warp/examples/sim/example_cloth.py +1 -1
  38. warp/examples/sim/example_cloth_self_contact.py +48 -54
  39. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  40. warp/examples/tile/example_tile_cholesky.py +2 -1
  41. warp/examples/tile/example_tile_convolution.py +1 -1
  42. warp/examples/tile/example_tile_filtering.py +1 -1
  43. warp/examples/tile/example_tile_matmul.py +1 -1
  44. warp/examples/tile/example_tile_mlp.py +2 -0
  45. warp/fabric.py +7 -7
  46. warp/fem/__init__.py +5 -0
  47. warp/fem/adaptivity.py +1 -1
  48. warp/fem/cache.py +152 -63
  49. warp/fem/dirichlet.py +2 -2
  50. warp/fem/domain.py +136 -6
  51. warp/fem/field/field.py +141 -99
  52. warp/fem/field/nodal_field.py +85 -39
  53. warp/fem/field/virtual.py +97 -52
  54. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  55. warp/fem/geometry/closest_point.py +13 -0
  56. warp/fem/geometry/deformed_geometry.py +102 -40
  57. warp/fem/geometry/element.py +56 -2
  58. warp/fem/geometry/geometry.py +323 -22
  59. warp/fem/geometry/grid_2d.py +157 -62
  60. warp/fem/geometry/grid_3d.py +116 -20
  61. warp/fem/geometry/hexmesh.py +86 -20
  62. warp/fem/geometry/nanogrid.py +166 -86
  63. warp/fem/geometry/partition.py +59 -25
  64. warp/fem/geometry/quadmesh.py +86 -135
  65. warp/fem/geometry/tetmesh.py +47 -119
  66. warp/fem/geometry/trimesh.py +77 -270
  67. warp/fem/integrate.py +107 -52
  68. warp/fem/linalg.py +25 -58
  69. warp/fem/operator.py +124 -27
  70. warp/fem/quadrature/pic_quadrature.py +36 -14
  71. warp/fem/quadrature/quadrature.py +40 -16
  72. warp/fem/space/__init__.py +1 -1
  73. warp/fem/space/basis_function_space.py +66 -46
  74. warp/fem/space/basis_space.py +17 -4
  75. warp/fem/space/dof_mapper.py +1 -1
  76. warp/fem/space/function_space.py +2 -2
  77. warp/fem/space/grid_2d_function_space.py +4 -1
  78. warp/fem/space/hexmesh_function_space.py +4 -2
  79. warp/fem/space/nanogrid_function_space.py +3 -1
  80. warp/fem/space/partition.py +11 -2
  81. warp/fem/space/quadmesh_function_space.py +4 -1
  82. warp/fem/space/restriction.py +5 -2
  83. warp/fem/space/shape/__init__.py +10 -8
  84. warp/fem/space/tetmesh_function_space.py +4 -1
  85. warp/fem/space/topology.py +52 -21
  86. warp/fem/space/trimesh_function_space.py +4 -1
  87. warp/fem/utils.py +53 -8
  88. warp/jax.py +1 -2
  89. warp/jax_experimental/ffi.py +12 -17
  90. warp/jax_experimental/xla_ffi.py +37 -24
  91. warp/math.py +171 -1
  92. warp/native/array.h +99 -0
  93. warp/native/builtin.h +174 -31
  94. warp/native/coloring.cpp +1 -1
  95. warp/native/exports.h +118 -63
  96. warp/native/intersect.h +3 -3
  97. warp/native/mat.h +5 -10
  98. warp/native/mathdx.cpp +11 -5
  99. warp/native/matnn.h +1 -123
  100. warp/native/quat.h +28 -4
  101. warp/native/sparse.cpp +121 -258
  102. warp/native/sparse.cu +181 -274
  103. warp/native/spatial.h +305 -17
  104. warp/native/tile.h +583 -72
  105. warp/native/tile_radix_sort.h +1108 -0
  106. warp/native/tile_reduce.h +237 -2
  107. warp/native/tile_scan.h +240 -0
  108. warp/native/tuple.h +189 -0
  109. warp/native/vec.h +6 -16
  110. warp/native/warp.cpp +36 -4
  111. warp/native/warp.cu +574 -51
  112. warp/native/warp.h +47 -74
  113. warp/optim/linear.py +5 -1
  114. warp/paddle.py +7 -8
  115. warp/py.typed +0 -0
  116. warp/render/render_opengl.py +58 -29
  117. warp/render/render_usd.py +124 -61
  118. warp/sim/__init__.py +9 -0
  119. warp/sim/collide.py +252 -78
  120. warp/sim/graph_coloring.py +8 -1
  121. warp/sim/import_mjcf.py +4 -3
  122. warp/sim/import_usd.py +11 -7
  123. warp/sim/integrator.py +5 -2
  124. warp/sim/integrator_euler.py +1 -1
  125. warp/sim/integrator_featherstone.py +1 -1
  126. warp/sim/integrator_vbd.py +751 -320
  127. warp/sim/integrator_xpbd.py +1 -1
  128. warp/sim/model.py +265 -260
  129. warp/sim/utils.py +10 -7
  130. warp/sparse.py +303 -166
  131. warp/tape.py +52 -51
  132. warp/tests/cuda/test_conditional_captures.py +1046 -0
  133. warp/tests/cuda/test_streams.py +1 -1
  134. warp/tests/geometry/test_volume.py +2 -2
  135. warp/tests/interop/test_dlpack.py +9 -9
  136. warp/tests/interop/test_jax.py +0 -1
  137. warp/tests/run_coverage_serial.py +1 -1
  138. warp/tests/sim/disabled_kinematics.py +2 -2
  139. warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
  140. warp/tests/sim/test_collision.py +159 -51
  141. warp/tests/sim/test_coloring.py +15 -1
  142. warp/tests/test_array.py +254 -2
  143. warp/tests/test_array_reduce.py +2 -2
  144. warp/tests/test_atomic_cas.py +299 -0
  145. warp/tests/test_codegen.py +142 -19
  146. warp/tests/test_conditional.py +47 -1
  147. warp/tests/test_ctypes.py +0 -20
  148. warp/tests/test_devices.py +8 -0
  149. warp/tests/test_fabricarray.py +4 -2
  150. warp/tests/test_fem.py +58 -25
  151. warp/tests/test_func.py +42 -1
  152. warp/tests/test_grad.py +1 -1
  153. warp/tests/test_lerp.py +1 -3
  154. warp/tests/test_map.py +481 -0
  155. warp/tests/test_mat.py +1 -24
  156. warp/tests/test_quat.py +6 -15
  157. warp/tests/test_rounding.py +10 -38
  158. warp/tests/test_runlength_encode.py +7 -7
  159. warp/tests/test_smoothstep.py +1 -1
  160. warp/tests/test_sparse.py +51 -2
  161. warp/tests/test_spatial.py +507 -1
  162. warp/tests/test_struct.py +2 -2
  163. warp/tests/test_tuple.py +265 -0
  164. warp/tests/test_types.py +2 -2
  165. warp/tests/test_utils.py +24 -18
  166. warp/tests/tile/test_tile.py +420 -1
  167. warp/tests/tile/test_tile_mathdx.py +518 -14
  168. warp/tests/tile/test_tile_reduce.py +213 -0
  169. warp/tests/tile/test_tile_shared_memory.py +130 -1
  170. warp/tests/tile/test_tile_sort.py +117 -0
  171. warp/tests/unittest_suites.py +4 -6
  172. warp/types.py +462 -308
  173. warp/utils.py +647 -86
  174. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
  175. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/RECORD +178 -166
  176. warp/stubs.py +0 -3381
  177. warp/tests/sim/test_xpbd.py +0 -399
  178. warp/tests/test_mlp.py +0 -282
  179. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
  180. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
  181. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
warp/codegen.py CHANGED
@@ -26,7 +26,7 @@ import re
26
26
  import sys
27
27
  import textwrap
28
28
  import types
29
- from typing import Any, Callable, Dict, Mapping, Optional, Sequence, get_args, get_origin
29
+ from typing import Any, Callable, ClassVar, Mapping, Sequence, get_args, get_origin
30
30
 
31
31
  import warp.config
32
32
  from warp.types import *
@@ -57,7 +57,7 @@ class WarpCodegenKeyError(KeyError):
57
57
 
58
58
 
59
59
  # map operator to function name
60
- builtin_operators: Dict[type[ast.AST], str] = {}
60
+ builtin_operators: dict[type[ast.AST], str] = {}
61
61
 
62
62
  # see https://www.ics.uci.edu/~pattis/ICS-31/lectures/opexp.pdf for a
63
63
  # nice overview of python operators
@@ -321,7 +321,7 @@ class StructInstance:
321
321
  # vector/matrix type, e.g. vec3
322
322
  if value is None:
323
323
  setattr(self._ctype, name, var.type())
324
- elif types_equal(type(value), var.type):
324
+ elif type(value) == var.type:
325
325
  setattr(self._ctype, name, value)
326
326
  else:
327
327
  # conversion from list/tuple, ndarray, etc.
@@ -626,7 +626,7 @@ def compute_type_str(base_name, template_params):
626
626
 
627
627
  return p.__name__
628
628
 
629
- return f"{base_name}<{','.join(map(param2str, template_params))}>"
629
+ return f"{base_name}<{', '.join(map(param2str, template_params))}>"
630
630
 
631
631
 
632
632
  class Var:
@@ -635,9 +635,9 @@ class Var:
635
635
  label: str,
636
636
  type: type,
637
637
  requires_grad: builtins.bool = False,
638
- constant: Optional[builtins.bool] = None,
638
+ constant: builtins.bool | None = None,
639
639
  prefix: builtins.bool = True,
640
- relative_lineno: Optional[int] = None,
640
+ relative_lineno: int | None = None,
641
641
  ):
642
642
  # convert built-in types to wp types
643
643
  if type == float:
@@ -667,37 +667,44 @@ class Var:
667
667
  def __str__(self):
668
668
  return self.label
669
669
 
670
+ @staticmethod
671
+ def dtype_to_ctype(t: type) -> str:
672
+ if hasattr(t, "_wp_generic_type_str_"):
673
+ return compute_type_str(f"wp::{t._wp_generic_type_str_}", t._wp_type_params_)
674
+ elif isinstance(t, Struct):
675
+ return t.native_name
676
+ elif hasattr(t, "_wp_native_name_"):
677
+ return f"wp::{t._wp_native_name_}"
678
+ elif t.__name__ in ("bool", "int", "float"):
679
+ return t.__name__
680
+
681
+ return f"wp::{t.__name__}"
682
+
670
683
  @staticmethod
671
684
  def type_to_ctype(t: type, value_type: builtins.bool = False) -> str:
672
685
  if is_array(t):
673
- if hasattr(t.dtype, "_wp_generic_type_str_"):
674
- dtypestr = compute_type_str(f"wp::{t.dtype._wp_generic_type_str_}", t.dtype._wp_type_params_)
675
- elif isinstance(t.dtype, Struct):
676
- dtypestr = t.dtype.native_name
677
- elif t.dtype.__name__ in ("bool", "int", "float"):
678
- dtypestr = t.dtype.__name__
679
- else:
680
- dtypestr = f"wp::{t.dtype.__name__}"
686
+ dtypestr = Var.dtype_to_ctype(t.dtype)
681
687
  classstr = f"wp::{type(t).__name__}"
682
688
  return f"{classstr}_t<{dtypestr}>"
689
+ elif get_origin(t) is tuple:
690
+ dtypestr = ", ".join(Var.dtype_to_ctype(x) for x in get_args(t))
691
+ return f"wp::tuple_t<{dtypestr}>"
692
+ elif is_tuple(t):
693
+ dtypestr = ", ".join(Var.dtype_to_ctype(x) for x in t.types)
694
+ classstr = f"wp::{type(t).__name__}"
695
+ return f"{classstr}<{dtypestr}>"
683
696
  elif is_tile(t):
684
697
  return t.ctype()
685
- elif isinstance(t, Struct):
686
- return t.native_name
687
698
  elif isinstance(t, type) and issubclass(t, StructInstance):
688
699
  # ensure the actual Struct name is used instead of "NewStructInstance"
689
700
  return t.native_name
690
701
  elif is_reference(t):
691
702
  if not value_type:
692
703
  return Var.type_to_ctype(t.value_type) + "*"
693
- else:
694
- return Var.type_to_ctype(t.value_type)
695
- elif hasattr(t, "_wp_generic_type_str_"):
696
- return compute_type_str(f"wp::{t._wp_generic_type_str_}", t._wp_type_params_)
697
- elif t.__name__ in ("bool", "int", "float"):
698
- return t.__name__
699
- else:
700
- return f"wp::{t.__name__}"
704
+
705
+ return Var.type_to_ctype(t.value_type)
706
+
707
+ return Var.dtype_to_ctype(t)
701
708
 
702
709
  def ctype(self, value_type: builtins.bool = False) -> str:
703
710
  return Var.type_to_ctype(self.type, value_type)
@@ -821,17 +828,26 @@ def func_match_args(func, arg_types, kwarg_types):
821
828
  return True
822
829
 
823
830
 
824
- def get_arg_type(arg: Union[Var, Any]) -> type:
831
+ def get_arg_type(arg: Var | Any) -> type:
825
832
  if isinstance(arg, str):
826
833
  return str
827
834
 
828
835
  if isinstance(arg, Sequence):
829
836
  return tuple(get_arg_type(x) for x in arg)
830
837
 
838
+ if get_origin(arg) is tuple:
839
+ return tuple(get_arg_type(x) for x in get_args(arg))
840
+
841
+ if is_tuple(arg):
842
+ return arg
843
+
831
844
  if isinstance(arg, (type, warp.context.Function)):
832
845
  return arg
833
846
 
834
847
  if isinstance(arg, Var):
848
+ if get_origin(arg.type) is tuple:
849
+ return get_args(arg.type)
850
+
835
851
  return arg.type
836
852
 
837
853
  return type(arg)
@@ -845,7 +861,11 @@ def get_arg_value(arg: Any) -> Any:
845
861
  return arg
846
862
 
847
863
  if isinstance(arg, Var):
848
- return arg.constant
864
+ if is_tuple(arg.type):
865
+ return tuple(get_arg_value(x) for x in arg.type.values)
866
+
867
+ if arg.constant is not None:
868
+ return arg.constant
849
869
 
850
870
  return arg
851
871
 
@@ -863,7 +883,8 @@ class Adjoint:
863
883
  skip_reverse_codegen=False,
864
884
  custom_reverse_mode=False,
865
885
  custom_reverse_num_input_args=-1,
866
- transformers: Optional[List[ast.NodeTransformer]] = None,
886
+ transformers: list[ast.NodeTransformer] | None = None,
887
+ source: str | None = None,
867
888
  ):
868
889
  adj.func = func
869
890
 
@@ -877,19 +898,17 @@ class Adjoint:
877
898
  # extract name of source file
878
899
  adj.filename = inspect.getsourcefile(func) or "unknown source file"
879
900
  # get source file line number where function starts
880
- try:
881
- _, adj.fun_lineno = inspect.getsourcelines(func)
882
- except OSError as e:
883
- raise RuntimeError(
884
- "Directly evaluating Warp code defined as a string using `exec()` is not supported, "
885
- "please save it on a file and use `importlib` if needed."
886
- ) from e
901
+ adj.fun_lineno = 0
902
+ adj.source = source
903
+ if adj.source is None:
904
+ adj.source, adj.fun_lineno = adj.extract_function_source(func)
905
+
906
+ assert adj.source is not None, f"Failed to extract source code for function {func.__name__}"
887
907
 
888
908
  # Indicates where the function definition starts (excludes decorators)
889
909
  adj.fun_def_lineno = None
890
910
 
891
911
  # get function source code
892
- adj.source = inspect.getsource(func)
893
912
  # ensures that indented class methods can be parsed as kernels
894
913
  adj.source = textwrap.dedent(adj.source)
895
914
 
@@ -950,7 +969,7 @@ class Adjoint:
950
969
 
951
970
  # try to replace static expressions by their constant result if the
952
971
  # expression can be evaluated at declaration time
953
- adj.static_expressions: Dict[str, Any] = {}
972
+ adj.static_expressions: dict[str, Any] = {}
954
973
  if "static" in adj.source:
955
974
  adj.replace_static_expressions()
956
975
 
@@ -981,6 +1000,18 @@ class Adjoint:
981
1000
 
982
1001
  return total_shared + adj.max_required_extra_shared_memory
983
1002
 
1003
+ @staticmethod
1004
+ def extract_function_source(func: Callable) -> tuple[str, int]:
1005
+ try:
1006
+ _, fun_lineno = inspect.getsourcelines(func)
1007
+ source = inspect.getsource(func)
1008
+ except OSError as e:
1009
+ raise RuntimeError(
1010
+ "Directly evaluating Warp code defined as a string using `exec()` is not supported, "
1011
+ "please save it to a file and use `importlib` if needed."
1012
+ ) from e
1013
+ return source, fun_lineno
1014
+
984
1015
  # generate function ssa form and adjoint
985
1016
  def build(adj, builder, default_builder_options=None):
986
1017
  # arg Var read/write flags are held during module rebuilds, so we reset here even when skipping a build
@@ -1058,7 +1089,7 @@ class Adjoint:
1058
1089
  # code generation methods
1059
1090
  def format_template(adj, template, input_vars, output_var):
1060
1091
  # output var is always the 0th index
1061
- args = [output_var] + input_vars
1092
+ args = [output_var, *input_vars]
1062
1093
  s = template.format(*args)
1063
1094
 
1064
1095
  return s
@@ -1176,7 +1207,7 @@ class Adjoint:
1176
1207
 
1177
1208
  return var
1178
1209
 
1179
- def get_line_directive(adj, statement: str, relative_lineno: Optional[int] = None) -> Optional[str]:
1210
+ def get_line_directive(adj, statement: str, relative_lineno: int | None = None) -> str | None:
1180
1211
  """Get a line directive for the given statement.
1181
1212
 
1182
1213
  Args:
@@ -1202,7 +1233,7 @@ class Adjoint:
1202
1233
  return f'#line {line} "{normalized_path}"'
1203
1234
  return None
1204
1235
 
1205
- def add_forward(adj, statement: str, replay: Optional[str] = None, skip_replay: builtins.bool = False) -> None:
1236
+ def add_forward(adj, statement: str, replay: str | None = None, skip_replay: builtins.bool = False) -> None:
1206
1237
  """Append a statement to the forward pass."""
1207
1238
 
1208
1239
  if line_directive := adj.get_line_directive(statement, adj.lineno):
@@ -1300,7 +1331,8 @@ class Adjoint:
1300
1331
 
1301
1332
  # check output dimensions match expectations
1302
1333
  if min_outputs:
1303
- if not isinstance(f.value_type, Sequence) or len(f.value_type) != min_outputs:
1334
+ value_type = f.value_func(None, None)
1335
+ if not isinstance(value_type, Sequence) or len(value_type) != min_outputs:
1304
1336
  continue
1305
1337
 
1306
1338
  # found a match, use it
@@ -1396,6 +1428,17 @@ class Adjoint:
1396
1428
  bound_arg_values,
1397
1429
  )
1398
1430
 
1431
+ # Handle the special case where a Var instance is returned from the `value_func`
1432
+ # callback, in which case we replace the call with a reference to that variable.
1433
+ if isinstance(return_type, Var):
1434
+ return adj.register_var(return_type)
1435
+ elif isinstance(return_type, Sequence) and all(isinstance(x, Var) for x in return_type):
1436
+ return tuple(adj.register_var(x) for x in return_type)
1437
+
1438
+ if get_origin(return_type) is tuple:
1439
+ types = get_args(return_type)
1440
+ return_type = warp.types.tuple_t(types=types, values=(None,) * len(types))
1441
+
1399
1442
  # immediately allocate output variables so we can pass them into the dispatch method
1400
1443
  if return_type is None:
1401
1444
  # void function
@@ -1775,6 +1818,22 @@ class Adjoint:
1775
1818
  out = adj.add_builtin_call("where", [cond, var1, var2])
1776
1819
  adj.symbols[sym] = out
1777
1820
 
1821
+ def emit_IfExp(adj, node):
1822
+ cond = adj.eval(node.test)
1823
+
1824
+ if cond.constant is not None:
1825
+ return adj.eval(node.body) if cond.constant else adj.eval(node.orelse)
1826
+
1827
+ adj.begin_if(cond)
1828
+ body = adj.eval(node.body)
1829
+ adj.end_if(cond)
1830
+
1831
+ adj.begin_else(cond)
1832
+ orelse = adj.eval(node.orelse)
1833
+ adj.end_else(cond)
1834
+
1835
+ return adj.add_builtin_call("where", [cond, body, orelse])
1836
+
1778
1837
  def emit_Compare(adj, node):
1779
1838
  # node.left, node.ops (list of ops), node.comparators (things to compare to)
1780
1839
  # e.g. (left ops[0] node.comparators[0]) ops[1] node.comparators[1]
@@ -1831,7 +1890,7 @@ class Adjoint:
1831
1890
  if attr == "dtype":
1832
1891
  return type_scalar_type(var_type)
1833
1892
  elif attr == "length":
1834
- return type_length(var_type)
1893
+ return type_size(var_type)
1835
1894
 
1836
1895
  return getattr(var_type, attr, None)
1837
1896
 
@@ -1850,6 +1909,15 @@ class Adjoint:
1850
1909
  index = adj.add_constant(index)
1851
1910
  return index
1852
1911
 
1912
+ def transform_component(adj, component):
1913
+ if len(component) != 1:
1914
+ raise WarpCodegenAttributeError(f"Transform attribute must be single character, got .{component}")
1915
+
1916
+ if component not in ("p", "q"):
1917
+ raise WarpCodegenAttributeError(f"Attribute for transformation must be either 'p' or 'q', got {component}")
1918
+
1919
+ return component
1920
+
1853
1921
  @staticmethod
1854
1922
  def is_differentiable_value_type(var_type):
1855
1923
  # checks that the argument type is a value type (i.e, not an array)
@@ -1880,12 +1948,20 @@ class Adjoint:
1880
1948
 
1881
1949
  aggregate_type = strip_reference(aggregate.type)
1882
1950
 
1883
- # reading a vector component
1884
- if type_is_vector(aggregate_type):
1951
+ # reading a vector or quaternion component
1952
+ if type_is_vector(aggregate_type) or type_is_quaternion(aggregate_type):
1885
1953
  index = adj.vector_component_index(node.attr, aggregate_type)
1886
1954
 
1887
1955
  return adj.add_builtin_call("extract", [aggregate, index])
1888
1956
 
1957
+ elif type_is_transformation(aggregate_type):
1958
+ component = adj.transform_component(node.attr)
1959
+
1960
+ if component == "p":
1961
+ return adj.add_builtin_call("transform_get_translation", [aggregate])
1962
+ else:
1963
+ return adj.add_builtin_call("transform_get_rotation", [aggregate])
1964
+
1889
1965
  else:
1890
1966
  attr_type = Reference(aggregate_type.vars[node.attr].type)
1891
1967
  attr = adj.add_var(attr_type)
@@ -2282,6 +2358,10 @@ class Adjoint:
2282
2358
  else:
2283
2359
  func = caller.default_constructor
2284
2360
 
2361
+ # lambda function
2362
+ if func is None and getattr(caller, "__name__", None) == "<lambda>":
2363
+ raise NotImplementedError("Lambda expressions are not yet supported")
2364
+
2285
2365
  if hasattr(caller, "_wp_type_args_"):
2286
2366
  type_args = caller._wp_type_args_
2287
2367
 
@@ -2290,18 +2370,6 @@ class Adjoint:
2290
2370
  f"Could not find function {'.'.join(path)} as a built-in or user-defined function. Note that user functions must be annotated with a @wp.func decorator to be called from a kernel."
2291
2371
  )
2292
2372
 
2293
- # Check if any argument correspond to an unsupported construct.
2294
- # Tuples are supported in the context of assigning multiple variables
2295
- # at once, but not in place of vectors when calling built-ins like
2296
- # `wp.length((1, 2, 3))`.
2297
- # Therefore, we need to catch this specific case here instead of
2298
- # more generally in `adj.eval()`.
2299
- for arg in node.args:
2300
- if isinstance(arg, ast.Tuple):
2301
- raise WarpCodegenError(
2302
- "Tuple constructs are not supported in kernels. Use vectors like `wp.vec3()` instead."
2303
- )
2304
-
2305
2373
  # get expected return count, e.g.: for multi-assignment
2306
2374
  min_outputs = None
2307
2375
  if hasattr(node, "expects"):
@@ -2311,7 +2379,6 @@ class Adjoint:
2311
2379
  args = tuple(adj.resolve_arg(x) for x in node.args)
2312
2380
  kwargs = {x.arg: adj.resolve_arg(x.value) for x in node.keywords}
2313
2381
 
2314
- # add the call and build the callee adjoint if needed (func.adj)
2315
2382
  out = adj.add_call(func, args, kwargs, type_args, min_outputs=min_outputs)
2316
2383
 
2317
2384
  if warp.config.verify_autograd_array_access:
@@ -2461,10 +2528,6 @@ class Adjoint:
2461
2528
  raise WarpCodegenError(
2462
2529
  "List constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead."
2463
2530
  )
2464
- elif isinstance(node.value, ast.Tuple):
2465
- raise WarpCodegenError(
2466
- "Tuple constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead."
2467
- )
2468
2531
 
2469
2532
  # handle the case where we are assigning multiple output variables
2470
2533
  if isinstance(lhs, ast.Tuple):
@@ -2480,6 +2543,17 @@ class Adjoint:
2480
2543
  else:
2481
2544
  out = adj.eval(node.value)
2482
2545
 
2546
+ subtype = getattr(out, "type", None)
2547
+ if isinstance(subtype, warp.types.tuple_t):
2548
+ if len(out.type.types) != len(lhs.elts):
2549
+ raise WarpCodegenError(
2550
+ f"Invalid number of values to unpack (expected {len(lhs.elts)}, got {len(out.type.types)})."
2551
+ )
2552
+ target = out
2553
+ out = tuple(
2554
+ adj.add_builtin_call("extract", (target, adj.add_constant(i))) for i in range(len(lhs.elts))
2555
+ )
2556
+
2483
2557
  names = []
2484
2558
  for v in lhs.elts:
2485
2559
  if isinstance(v, ast.Name):
@@ -2532,7 +2606,12 @@ class Adjoint:
2532
2606
  elif is_tile(target_type):
2533
2607
  adj.add_builtin_call("assign", [target, *indices, rhs])
2534
2608
 
2535
- elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
2609
+ elif (
2610
+ type_is_vector(target_type)
2611
+ or type_is_quaternion(target_type)
2612
+ or type_is_matrix(target_type)
2613
+ or type_is_transformation(target_type)
2614
+ ):
2536
2615
  # recursively unwind AST, stopping at penultimate node
2537
2616
  node = lhs
2538
2617
  while hasattr(node, "value"):
@@ -2572,7 +2651,7 @@ class Adjoint:
2572
2651
 
2573
2652
  else:
2574
2653
  raise WarpCodegenError(
2575
- f"Can only subscript assign array, vector, quaternion, and matrix types, got {target_type}"
2654
+ f"Can only subscript assign array, vector, quaternion, transformation, and matrix types, got {target_type}"
2576
2655
  )
2577
2656
 
2578
2657
  elif isinstance(lhs, ast.Name):
@@ -2589,8 +2668,11 @@ class Adjoint:
2589
2668
  f"Error, assigning to existing symbol {name} ({adj.symbols[name].type}) with different type ({rhs.type})"
2590
2669
  )
2591
2670
 
2592
- # handle simple assignment case (a = b), where we generate a value copy rather than reference
2593
- if isinstance(node.value, ast.Name) or is_reference(rhs.type):
2671
+ if isinstance(node.value, ast.Tuple):
2672
+ out = rhs
2673
+ elif isinstance(rhs, Sequence):
2674
+ out = adj.add_builtin_call("tuple", rhs)
2675
+ elif isinstance(node.value, ast.Name) or is_reference(rhs.type):
2594
2676
  out = adj.add_builtin_call("copy", [rhs])
2595
2677
  else:
2596
2678
  out = rhs
@@ -2622,6 +2704,18 @@ class Adjoint:
2622
2704
  else:
2623
2705
  adj.add_builtin_call("assign_inplace", [aggregate, index, rhs])
2624
2706
 
2707
+ elif type_is_transformation(aggregate_type):
2708
+ component = adj.transform_component(lhs.attr)
2709
+
2710
+ # TODO: x[i,j].p = rhs case
2711
+ if is_reference(aggregate.type):
2712
+ raise WarpCodegenError(f"Error, assigning transform attribute {component} to an array element")
2713
+
2714
+ if component == "p":
2715
+ return adj.add_builtin_call("transform_set_translation", [aggregate, rhs])
2716
+ else:
2717
+ return adj.add_builtin_call("transform_set_rotation", [aggregate, rhs])
2718
+
2625
2719
  else:
2626
2720
  attr = adj.emit_Attribute(lhs)
2627
2721
  if is_reference(attr.type):
@@ -2644,7 +2738,9 @@ class Adjoint:
2644
2738
  elif isinstance(node.value, ast.Tuple):
2645
2739
  var = tuple(adj.eval(arg) for arg in node.value.elts)
2646
2740
  else:
2647
- var = (adj.eval(node.value),)
2741
+ var = adj.eval(node.value)
2742
+ if not isinstance(var, list) and not isinstance(var, tuple):
2743
+ var = (var,)
2648
2744
 
2649
2745
  if adj.return_var is not None:
2650
2746
  old_ctypes = tuple(v.ctype(value_type=True) for v in adj.return_var)
@@ -2697,6 +2793,7 @@ class Adjoint:
2697
2793
  type_is_vector(target_type.dtype)
2698
2794
  or type_is_quaternion(target_type.dtype)
2699
2795
  or type_is_matrix(target_type.dtype)
2796
+ or type_is_transformation(target_type.dtype)
2700
2797
  ):
2701
2798
  dtype = getattr(target_type.dtype, "_wp_scalar_type_", None)
2702
2799
  if dtype in warp.types.non_atomic_types:
@@ -2724,7 +2821,12 @@ class Adjoint:
2724
2821
  make_new_assign_statement()
2725
2822
  return
2726
2823
 
2727
- elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
2824
+ elif (
2825
+ type_is_vector(target_type)
2826
+ or type_is_quaternion(target_type)
2827
+ or type_is_matrix(target_type)
2828
+ or type_is_transformation(target_type)
2829
+ ):
2728
2830
  if isinstance(node.op, ast.Add):
2729
2831
  adj.add_builtin_call("add_inplace", [target, *indices, rhs])
2730
2832
  elif isinstance(node.op, ast.Sub):
@@ -2735,9 +2837,36 @@ class Adjoint:
2735
2837
  make_new_assign_statement()
2736
2838
  return
2737
2839
 
2840
+ elif is_tile(target.type):
2841
+ if isinstance(node.op, ast.Add):
2842
+ adj.add_builtin_call("tile_add_inplace", [target, *indices, rhs])
2843
+ elif isinstance(node.op, ast.Sub):
2844
+ adj.add_builtin_call("tile_sub_inplace", [target, *indices, rhs])
2845
+ else:
2846
+ if warp.config.verbose:
2847
+ print(f"Warning: in-place op {node.op} is not differentiable")
2848
+ make_new_assign_statement()
2849
+ return
2850
+
2738
2851
  else:
2739
2852
  raise WarpCodegenError("Can only subscript in-place assign array, vector, quaternion, and matrix types")
2740
2853
 
2854
+ elif isinstance(lhs, ast.Name):
2855
+ target = adj.eval(node.target)
2856
+ rhs = adj.eval(node.value)
2857
+
2858
+ if is_tile(target.type) and is_tile(rhs.type):
2859
+ if isinstance(node.op, ast.Add):
2860
+ adj.add_builtin_call("add_inplace", [target, rhs])
2861
+ elif isinstance(node.op, ast.Sub):
2862
+ adj.add_builtin_call("sub_inplace", [target, rhs])
2863
+ else:
2864
+ make_new_assign_statement()
2865
+ return
2866
+ else:
2867
+ make_new_assign_statement()
2868
+ return
2869
+
2741
2870
  # TODO
2742
2871
  elif isinstance(lhs, ast.Attribute):
2743
2872
  make_new_assign_statement()
@@ -2748,15 +2877,16 @@ class Adjoint:
2748
2877
  return
2749
2878
 
2750
2879
  def emit_Tuple(adj, node):
2751
- # LHS for expressions, such as i, j, k = 1, 2, 3
2752
- return tuple(adj.eval(x) for x in node.elts)
2880
+ elements = tuple(adj.eval(x) for x in node.elts)
2881
+ return adj.add_builtin_call("tuple", elements)
2753
2882
 
2754
2883
  def emit_Pass(adj, node):
2755
2884
  pass
2756
2885
 
2757
- node_visitors = {
2886
+ node_visitors: ClassVar[dict[type[ast.AST], Callable]] = {
2758
2887
  ast.FunctionDef: emit_FunctionDef,
2759
2888
  ast.If: emit_If,
2889
+ ast.IfExp: emit_IfExp,
2760
2890
  ast.Compare: emit_Compare,
2761
2891
  ast.BoolOp: emit_BoolOp,
2762
2892
  ast.Name: emit_Name,
@@ -2860,11 +2990,11 @@ class Adjoint:
2860
2990
  if isinstance(value, warp.context.Function):
2861
2991
  return True
2862
2992
 
2863
- def verify_struct(s: StructInstance, attr_path: List[str]):
2993
+ def verify_struct(s: StructInstance, attr_path: list[str]):
2864
2994
  for key in s._cls.vars.keys():
2865
2995
  v = getattr(s, key)
2866
2996
  if issubclass(type(v), StructInstance):
2867
- verify_struct(v, attr_path + [key])
2997
+ verify_struct(v, [*attr_path, key])
2868
2998
  else:
2869
2999
  try:
2870
3000
  adj.verify_static_return_value(v)
@@ -2879,7 +3009,8 @@ class Adjoint:
2879
3009
  raise ValueError(f"value of type {type(value)} cannot be constructed inside Warp kernels")
2880
3010
 
2881
3011
  # find the source code string of an AST node
2882
- def extract_node_source(adj, node) -> Optional[str]:
3012
+ @staticmethod
3013
+ def extract_node_source_from_lines(source_lines, node) -> str | None:
2883
3014
  if not hasattr(node, "lineno") or not hasattr(node, "col_offset"):
2884
3015
  return None
2885
3016
 
@@ -2895,12 +3026,12 @@ class Adjoint:
2895
3026
  end_line = start_line
2896
3027
  end_col = start_col
2897
3028
  parenthesis_count = 1
2898
- for lineno in range(start_line, len(adj.source_lines)):
3029
+ for lineno in range(start_line, len(source_lines)):
2899
3030
  if lineno == start_line:
2900
3031
  c_start = start_col
2901
3032
  else:
2902
3033
  c_start = 0
2903
- line = adj.source_lines[lineno]
3034
+ line = source_lines[lineno]
2904
3035
  for i in range(c_start, len(line)):
2905
3036
  c = line[i]
2906
3037
  if c == "(":
@@ -2916,21 +3047,57 @@ class Adjoint:
2916
3047
 
2917
3048
  if start_line == end_line:
2918
3049
  # single-line expression
2919
- return adj.source_lines[start_line][start_col:end_col]
3050
+ return source_lines[start_line][start_col:end_col]
2920
3051
  else:
2921
3052
  # multi-line expression
2922
3053
  lines = []
2923
3054
  # first line (from start_col to the end)
2924
- lines.append(adj.source_lines[start_line][start_col:])
3055
+ lines.append(source_lines[start_line][start_col:])
2925
3056
  # middle lines (entire lines)
2926
- lines.extend(adj.source_lines[start_line + 1 : end_line])
3057
+ lines.extend(source_lines[start_line + 1 : end_line])
2927
3058
  # last line (from the start to end_col)
2928
- lines.append(adj.source_lines[end_line][:end_col])
3059
+ lines.append(source_lines[end_line][:end_col])
2929
3060
  return "\n".join(lines).strip()
2930
3061
 
3062
+ @staticmethod
3063
+ def extract_lambda_source(func, only_body=False) -> str | None:
3064
+ try:
3065
+ source_lines = inspect.getsourcelines(func)[0]
3066
+ source_lines[0] = source_lines[0][source_lines[0].index("lambda") :]
3067
+ except OSError as e:
3068
+ raise WarpCodegenError(
3069
+ "Could not access lambda function source code. Please use a named function instead."
3070
+ ) from e
3071
+ source = "".join(source_lines)
3072
+ source = source[source.index("lambda") :].rstrip()
3073
+ # Remove trailing unbalanced parentheses
3074
+ while source.count("(") < source.count(")"):
3075
+ source = source[:-1]
3076
+ # extract lambda expression up until a comma, e.g. in the case of
3077
+ # "map(lambda a: (a + 2.0, a + 3.0), a, return_kernel=True)"
3078
+ si = max(source.find(")"), source.find(":"))
3079
+ ci = source.find(",", si)
3080
+ if ci != -1:
3081
+ source = source[:ci]
3082
+ tree = ast.parse(source)
3083
+ lambda_source = None
3084
+ for node in ast.walk(tree):
3085
+ if isinstance(node, ast.Lambda):
3086
+ if only_body:
3087
+ # extract the body of the lambda function
3088
+ lambda_source = Adjoint.extract_node_source_from_lines(source_lines, node.body)
3089
+ else:
3090
+ # extract the entire lambda function
3091
+ lambda_source = Adjoint.extract_node_source_from_lines(source_lines, node)
3092
+ break
3093
+ return lambda_source
3094
+
3095
+ def extract_node_source(adj, node) -> str | None:
3096
+ return adj.extract_node_source_from_lines(adj.source_lines, node)
3097
+
2931
3098
  # handles a wp.static() expression and returns the resulting object and a string representing the code
2932
3099
  # of the static expression
2933
- def evaluate_static_expression(adj, node) -> Tuple[Any, str]:
3100
+ def evaluate_static_expression(adj, node) -> tuple[Any, str]:
2934
3101
  if len(node.args) == 1:
2935
3102
  static_code = adj.extract_node_source(node.args[0])
2936
3103
  elif len(node.keywords) == 1:
@@ -2950,29 +3117,14 @@ class Adjoint:
2950
3117
 
2951
3118
  # Replace all constant `len()` expressions with their value.
2952
3119
  if "len" in static_code:
2953
-
2954
- def eval_len(obj):
2955
- if type_is_vector(obj):
2956
- return obj._length_
2957
- elif type_is_quaternion(obj):
2958
- return obj._length_
2959
- elif type_is_matrix(obj):
2960
- return obj._shape_[0]
2961
- elif type_is_transformation(obj):
2962
- return obj._length_
2963
- elif is_tile(obj):
2964
- return obj.shape[0]
2965
-
2966
- return len(obj)
2967
-
2968
3120
  len_expr_ctx = vars_dict.copy()
2969
3121
  constant_types = {k: v.type for k, v in adj.symbols.items() if isinstance(v, Var) and v.type is not None}
2970
3122
  len_expr_ctx.update(constant_types)
2971
- len_expr_ctx.update({"len": eval_len})
3123
+ len_expr_ctx.update({"len": warp.types.type_length})
2972
3124
 
2973
3125
  # We want to replace the expression code in-place,
2974
3126
  # so reparse it to get the correct column info.
2975
- len_value_locs: List[Tuple[int, int, int]] = []
3127
+ len_value_locs: list[tuple[int, int, int]] = []
2976
3128
  expr_tree = ast.parse(static_code)
2977
3129
  assert len(expr_tree.body) == 1 and isinstance(expr_tree.body[0], ast.Expr)
2978
3130
  expr_root = expr_tree.body[0].value
@@ -3134,14 +3286,14 @@ class Adjoint:
3134
3286
  # return the Python code corresponding to the given AST node
3135
3287
  return ast.get_source_segment(adj.source, node)
3136
3288
 
3137
- def get_references(adj) -> Tuple[Dict[str, Any], Dict[Any, Any], Dict[warp.context.Function, Any]]:
3289
+ def get_references(adj) -> tuple[dict[str, Any], dict[Any, Any], dict[warp.context.Function, Any]]:
3138
3290
  """Traverses ``adj.tree`` and returns referenced constants, types, and user-defined functions."""
3139
3291
 
3140
3292
  local_variables = set() # Track local variables appearing on the LHS so we know when variables are shadowed
3141
3293
 
3142
- constants: Dict[str, Any] = {}
3143
- types: Dict[Union[Struct, type], Any] = {}
3144
- functions: Dict[warp.context.Function, Any] = {}
3294
+ constants: dict[str, Any] = {}
3295
+ types: dict[Struct | type, Any] = {}
3296
+ functions: dict[warp.context.Function, Any] = {}
3145
3297
 
3146
3298
  for node in ast.walk(adj.tree):
3147
3299
  if isinstance(node, ast.Name) and node.id not in local_variables:
@@ -3200,6 +3352,8 @@ cpu_module_header = """
3200
3352
  #define builtin_tid3d(x, y, z) wp::tid(x, y, z, task_index, dim)
3201
3353
  #define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, task_index, dim)
3202
3354
 
3355
+ #define builtin_block_dim() wp::block_dim()
3356
+
3203
3357
  """
3204
3358
 
3205
3359
  cuda_module_header = """
@@ -3219,6 +3373,8 @@ cuda_module_header = """
3219
3373
  #define builtin_tid3d(x, y, z) wp::tid(x, y, z, _idx, dim)
3220
3374
  #define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, _idx, dim)
3221
3375
 
3376
+ #define builtin_block_dim() wp::block_dim()
3377
+
3222
3378
  """
3223
3379
 
3224
3380
  struct_template = """
@@ -3663,7 +3819,7 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
3663
3819
  f"annotated as a tuple of {len(get_args(adj.arg_types['return']))} elements "
3664
3820
  f"but the code returns {len(adj.return_var)} values."
3665
3821
  )
3666
- elif not types_equal(adj.arg_types["return"], tuple(x.type for x in adj.return_var)):
3822
+ elif not types_equal(adj.arg_types["return"], tuple(x.type for x in adj.return_var), match_generic=True):
3667
3823
  raise WarpCodegenError(
3668
3824
  f"The function `{adj.fun_name}` has its return type "
3669
3825
  f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "