warp-lang 1.7.2rc1__py3-none-macosx_10_13_universal2.whl → 1.8.1__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; further details were linked from the original registry page.

Files changed (192)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +241 -252
  6. warp/build_dll.py +130 -26
  7. warp/builtins.py +1907 -384
  8. warp/codegen.py +272 -104
  9. warp/config.py +12 -1
  10. warp/constants.py +1 -1
  11. warp/context.py +770 -238
  12. warp/dlpack.py +1 -1
  13. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  14. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  15. warp/examples/core/example_sample_mesh.py +1 -1
  16. warp/examples/core/example_spin_lock.py +93 -0
  17. warp/examples/core/example_work_queue.py +118 -0
  18. warp/examples/fem/example_adaptive_grid.py +5 -5
  19. warp/examples/fem/example_apic_fluid.py +1 -1
  20. warp/examples/fem/example_burgers.py +1 -1
  21. warp/examples/fem/example_convection_diffusion.py +9 -6
  22. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  23. warp/examples/fem/example_deformed_geometry.py +1 -1
  24. warp/examples/fem/example_diffusion.py +2 -2
  25. warp/examples/fem/example_diffusion_3d.py +1 -1
  26. warp/examples/fem/example_distortion_energy.py +1 -1
  27. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  28. warp/examples/fem/example_magnetostatics.py +5 -3
  29. warp/examples/fem/example_mixed_elasticity.py +5 -3
  30. warp/examples/fem/example_navier_stokes.py +11 -9
  31. warp/examples/fem/example_nonconforming_contact.py +5 -3
  32. warp/examples/fem/example_streamlines.py +8 -3
  33. warp/examples/fem/utils.py +9 -8
  34. warp/examples/interop/example_jax_callable.py +34 -4
  35. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  36. warp/examples/interop/example_jax_kernel.py +27 -1
  37. warp/examples/optim/example_drone.py +1 -1
  38. warp/examples/sim/example_cloth.py +1 -1
  39. warp/examples/sim/example_cloth_self_contact.py +48 -54
  40. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  41. warp/examples/tile/example_tile_cholesky.py +2 -1
  42. warp/examples/tile/example_tile_convolution.py +1 -1
  43. warp/examples/tile/example_tile_filtering.py +1 -1
  44. warp/examples/tile/example_tile_matmul.py +1 -1
  45. warp/examples/tile/example_tile_mlp.py +2 -0
  46. warp/fabric.py +7 -7
  47. warp/fem/__init__.py +5 -0
  48. warp/fem/adaptivity.py +1 -1
  49. warp/fem/cache.py +152 -63
  50. warp/fem/dirichlet.py +2 -2
  51. warp/fem/domain.py +136 -6
  52. warp/fem/field/field.py +141 -99
  53. warp/fem/field/nodal_field.py +85 -39
  54. warp/fem/field/virtual.py +99 -52
  55. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  56. warp/fem/geometry/closest_point.py +13 -0
  57. warp/fem/geometry/deformed_geometry.py +102 -40
  58. warp/fem/geometry/element.py +56 -2
  59. warp/fem/geometry/geometry.py +323 -22
  60. warp/fem/geometry/grid_2d.py +157 -62
  61. warp/fem/geometry/grid_3d.py +116 -20
  62. warp/fem/geometry/hexmesh.py +86 -20
  63. warp/fem/geometry/nanogrid.py +166 -86
  64. warp/fem/geometry/partition.py +59 -25
  65. warp/fem/geometry/quadmesh.py +86 -135
  66. warp/fem/geometry/tetmesh.py +47 -119
  67. warp/fem/geometry/trimesh.py +77 -270
  68. warp/fem/integrate.py +181 -95
  69. warp/fem/linalg.py +25 -58
  70. warp/fem/operator.py +124 -27
  71. warp/fem/quadrature/pic_quadrature.py +36 -14
  72. warp/fem/quadrature/quadrature.py +40 -16
  73. warp/fem/space/__init__.py +1 -1
  74. warp/fem/space/basis_function_space.py +66 -46
  75. warp/fem/space/basis_space.py +17 -4
  76. warp/fem/space/dof_mapper.py +1 -1
  77. warp/fem/space/function_space.py +2 -2
  78. warp/fem/space/grid_2d_function_space.py +4 -1
  79. warp/fem/space/hexmesh_function_space.py +4 -2
  80. warp/fem/space/nanogrid_function_space.py +3 -1
  81. warp/fem/space/partition.py +11 -2
  82. warp/fem/space/quadmesh_function_space.py +4 -1
  83. warp/fem/space/restriction.py +5 -2
  84. warp/fem/space/shape/__init__.py +10 -8
  85. warp/fem/space/tetmesh_function_space.py +4 -1
  86. warp/fem/space/topology.py +52 -21
  87. warp/fem/space/trimesh_function_space.py +4 -1
  88. warp/fem/utils.py +53 -8
  89. warp/jax.py +1 -2
  90. warp/jax_experimental/ffi.py +210 -67
  91. warp/jax_experimental/xla_ffi.py +37 -24
  92. warp/math.py +171 -1
  93. warp/native/array.h +103 -4
  94. warp/native/builtin.h +182 -35
  95. warp/native/coloring.cpp +6 -2
  96. warp/native/cuda_util.cpp +1 -1
  97. warp/native/exports.h +118 -63
  98. warp/native/intersect.h +5 -5
  99. warp/native/mat.h +8 -13
  100. warp/native/mathdx.cpp +11 -5
  101. warp/native/matnn.h +1 -123
  102. warp/native/mesh.h +1 -1
  103. warp/native/quat.h +34 -6
  104. warp/native/rand.h +7 -7
  105. warp/native/sparse.cpp +121 -258
  106. warp/native/sparse.cu +181 -274
  107. warp/native/spatial.h +305 -17
  108. warp/native/svd.h +23 -8
  109. warp/native/tile.h +603 -73
  110. warp/native/tile_radix_sort.h +1112 -0
  111. warp/native/tile_reduce.h +239 -13
  112. warp/native/tile_scan.h +240 -0
  113. warp/native/tuple.h +189 -0
  114. warp/native/vec.h +10 -20
  115. warp/native/warp.cpp +36 -4
  116. warp/native/warp.cu +588 -52
  117. warp/native/warp.h +47 -74
  118. warp/optim/linear.py +5 -1
  119. warp/paddle.py +7 -8
  120. warp/py.typed +0 -0
  121. warp/render/render_opengl.py +110 -80
  122. warp/render/render_usd.py +124 -62
  123. warp/sim/__init__.py +9 -0
  124. warp/sim/collide.py +253 -80
  125. warp/sim/graph_coloring.py +8 -1
  126. warp/sim/import_mjcf.py +4 -3
  127. warp/sim/import_usd.py +11 -7
  128. warp/sim/integrator.py +5 -2
  129. warp/sim/integrator_euler.py +1 -1
  130. warp/sim/integrator_featherstone.py +1 -1
  131. warp/sim/integrator_vbd.py +761 -322
  132. warp/sim/integrator_xpbd.py +1 -1
  133. warp/sim/model.py +265 -260
  134. warp/sim/utils.py +10 -7
  135. warp/sparse.py +303 -166
  136. warp/tape.py +54 -51
  137. warp/tests/cuda/test_conditional_captures.py +1046 -0
  138. warp/tests/cuda/test_streams.py +1 -1
  139. warp/tests/geometry/test_volume.py +2 -2
  140. warp/tests/interop/test_dlpack.py +9 -9
  141. warp/tests/interop/test_jax.py +0 -1
  142. warp/tests/run_coverage_serial.py +1 -1
  143. warp/tests/sim/disabled_kinematics.py +2 -2
  144. warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
  145. warp/tests/sim/test_collision.py +159 -51
  146. warp/tests/sim/test_coloring.py +91 -2
  147. warp/tests/test_array.py +254 -2
  148. warp/tests/test_array_reduce.py +2 -2
  149. warp/tests/test_assert.py +53 -0
  150. warp/tests/test_atomic_cas.py +312 -0
  151. warp/tests/test_codegen.py +142 -19
  152. warp/tests/test_conditional.py +47 -1
  153. warp/tests/test_ctypes.py +0 -20
  154. warp/tests/test_devices.py +8 -0
  155. warp/tests/test_fabricarray.py +4 -2
  156. warp/tests/test_fem.py +58 -25
  157. warp/tests/test_func.py +42 -1
  158. warp/tests/test_grad.py +1 -1
  159. warp/tests/test_lerp.py +1 -3
  160. warp/tests/test_map.py +481 -0
  161. warp/tests/test_mat.py +23 -24
  162. warp/tests/test_quat.py +28 -15
  163. warp/tests/test_rounding.py +10 -38
  164. warp/tests/test_runlength_encode.py +7 -7
  165. warp/tests/test_smoothstep.py +1 -1
  166. warp/tests/test_sparse.py +83 -2
  167. warp/tests/test_spatial.py +507 -1
  168. warp/tests/test_static.py +48 -0
  169. warp/tests/test_struct.py +2 -2
  170. warp/tests/test_tape.py +38 -0
  171. warp/tests/test_tuple.py +265 -0
  172. warp/tests/test_types.py +2 -2
  173. warp/tests/test_utils.py +24 -18
  174. warp/tests/test_vec.py +38 -408
  175. warp/tests/test_vec_constructors.py +325 -0
  176. warp/tests/tile/test_tile.py +438 -131
  177. warp/tests/tile/test_tile_mathdx.py +518 -14
  178. warp/tests/tile/test_tile_matmul.py +179 -0
  179. warp/tests/tile/test_tile_reduce.py +307 -5
  180. warp/tests/tile/test_tile_shared_memory.py +136 -7
  181. warp/tests/tile/test_tile_sort.py +121 -0
  182. warp/tests/unittest_suites.py +14 -6
  183. warp/types.py +462 -308
  184. warp/utils.py +647 -86
  185. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
  186. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +189 -175
  187. warp/stubs.py +0 -3381
  188. warp/tests/sim/test_xpbd.py +0 -399
  189. warp/tests/test_mlp.py +0 -282
  190. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
  191. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
  192. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
warp/codegen.py CHANGED
@@ -26,7 +26,7 @@ import re
26
26
  import sys
27
27
  import textwrap
28
28
  import types
29
- from typing import Any, Callable, Dict, Mapping, Optional, Sequence, get_args, get_origin
29
+ from typing import Any, Callable, ClassVar, Mapping, Sequence, get_args, get_origin
30
30
 
31
31
  import warp.config
32
32
  from warp.types import *
@@ -57,7 +57,7 @@ class WarpCodegenKeyError(KeyError):
57
57
 
58
58
 
59
59
  # map operator to function name
60
- builtin_operators: Dict[type[ast.AST], str] = {}
60
+ builtin_operators: dict[type[ast.AST], str] = {}
61
61
 
62
62
  # see https://www.ics.uci.edu/~pattis/ICS-31/lectures/opexp.pdf for a
63
63
  # nice overview of python operators
@@ -321,7 +321,7 @@ class StructInstance:
321
321
  # vector/matrix type, e.g. vec3
322
322
  if value is None:
323
323
  setattr(self._ctype, name, var.type())
324
- elif types_equal(type(value), var.type):
324
+ elif type(value) == var.type:
325
325
  setattr(self._ctype, name, value)
326
326
  else:
327
327
  # conversion from list/tuple, ndarray, etc.
@@ -616,6 +616,8 @@ def compute_type_str(base_name, template_params):
616
616
  def param2str(p):
617
617
  if isinstance(p, int):
618
618
  return str(p)
619
+ elif hasattr(p, "_wp_generic_type_str_"):
620
+ return compute_type_str(f"wp::{p._wp_generic_type_str_}", p._wp_type_params_)
619
621
  elif hasattr(p, "_type_"):
620
622
  if p.__name__ == "bool":
621
623
  return "bool"
@@ -626,7 +628,7 @@ def compute_type_str(base_name, template_params):
626
628
 
627
629
  return p.__name__
628
630
 
629
- return f"{base_name}<{','.join(map(param2str, template_params))}>"
631
+ return f"{base_name}<{', '.join(map(param2str, template_params))}>"
630
632
 
631
633
 
632
634
  class Var:
@@ -635,9 +637,9 @@ class Var:
635
637
  label: str,
636
638
  type: type,
637
639
  requires_grad: builtins.bool = False,
638
- constant: Optional[builtins.bool] = None,
640
+ constant: builtins.bool | None = None,
639
641
  prefix: builtins.bool = True,
640
- relative_lineno: Optional[int] = None,
642
+ relative_lineno: int | None = None,
641
643
  ):
642
644
  # convert built-in types to wp types
643
645
  if type == float:
@@ -667,37 +669,44 @@ class Var:
667
669
  def __str__(self):
668
670
  return self.label
669
671
 
672
+ @staticmethod
673
+ def dtype_to_ctype(t: type) -> str:
674
+ if hasattr(t, "_wp_generic_type_str_"):
675
+ return compute_type_str(f"wp::{t._wp_generic_type_str_}", t._wp_type_params_)
676
+ elif isinstance(t, Struct):
677
+ return t.native_name
678
+ elif hasattr(t, "_wp_native_name_"):
679
+ return f"wp::{t._wp_native_name_}"
680
+ elif t.__name__ in ("bool", "int", "float"):
681
+ return t.__name__
682
+
683
+ return f"wp::{t.__name__}"
684
+
670
685
  @staticmethod
671
686
  def type_to_ctype(t: type, value_type: builtins.bool = False) -> str:
672
687
  if is_array(t):
673
- if hasattr(t.dtype, "_wp_generic_type_str_"):
674
- dtypestr = compute_type_str(f"wp::{t.dtype._wp_generic_type_str_}", t.dtype._wp_type_params_)
675
- elif isinstance(t.dtype, Struct):
676
- dtypestr = t.dtype.native_name
677
- elif t.dtype.__name__ in ("bool", "int", "float"):
678
- dtypestr = t.dtype.__name__
679
- else:
680
- dtypestr = f"wp::{t.dtype.__name__}"
688
+ dtypestr = Var.dtype_to_ctype(t.dtype)
681
689
  classstr = f"wp::{type(t).__name__}"
682
690
  return f"{classstr}_t<{dtypestr}>"
691
+ elif get_origin(t) is tuple:
692
+ dtypestr = ", ".join(Var.dtype_to_ctype(x) for x in get_args(t))
693
+ return f"wp::tuple_t<{dtypestr}>"
694
+ elif is_tuple(t):
695
+ dtypestr = ", ".join(Var.dtype_to_ctype(x) for x in t.types)
696
+ classstr = f"wp::{type(t).__name__}"
697
+ return f"{classstr}<{dtypestr}>"
683
698
  elif is_tile(t):
684
699
  return t.ctype()
685
- elif isinstance(t, Struct):
686
- return t.native_name
687
700
  elif isinstance(t, type) and issubclass(t, StructInstance):
688
701
  # ensure the actual Struct name is used instead of "NewStructInstance"
689
702
  return t.native_name
690
703
  elif is_reference(t):
691
704
  if not value_type:
692
705
  return Var.type_to_ctype(t.value_type) + "*"
693
- else:
694
- return Var.type_to_ctype(t.value_type)
695
- elif hasattr(t, "_wp_generic_type_str_"):
696
- return compute_type_str(f"wp::{t._wp_generic_type_str_}", t._wp_type_params_)
697
- elif t.__name__ in ("bool", "int", "float"):
698
- return t.__name__
699
- else:
700
- return f"wp::{t.__name__}"
706
+
707
+ return Var.type_to_ctype(t.value_type)
708
+
709
+ return Var.dtype_to_ctype(t)
701
710
 
702
711
  def ctype(self, value_type: builtins.bool = False) -> str:
703
712
  return Var.type_to_ctype(self.type, value_type)
@@ -821,17 +830,26 @@ def func_match_args(func, arg_types, kwarg_types):
821
830
  return True
822
831
 
823
832
 
824
- def get_arg_type(arg: Union[Var, Any]) -> type:
833
+ def get_arg_type(arg: Var | Any) -> type:
825
834
  if isinstance(arg, str):
826
835
  return str
827
836
 
828
837
  if isinstance(arg, Sequence):
829
838
  return tuple(get_arg_type(x) for x in arg)
830
839
 
840
+ if get_origin(arg) is tuple:
841
+ return tuple(get_arg_type(x) for x in get_args(arg))
842
+
843
+ if is_tuple(arg):
844
+ return arg
845
+
831
846
  if isinstance(arg, (type, warp.context.Function)):
832
847
  return arg
833
848
 
834
849
  if isinstance(arg, Var):
850
+ if get_origin(arg.type) is tuple:
851
+ return get_args(arg.type)
852
+
835
853
  return arg.type
836
854
 
837
855
  return type(arg)
@@ -845,7 +863,11 @@ def get_arg_value(arg: Any) -> Any:
845
863
  return arg
846
864
 
847
865
  if isinstance(arg, Var):
848
- return arg.constant
866
+ if is_tuple(arg.type):
867
+ return tuple(get_arg_value(x) for x in arg.type.values)
868
+
869
+ if arg.constant is not None:
870
+ return arg.constant
849
871
 
850
872
  return arg
851
873
 
@@ -863,7 +885,8 @@ class Adjoint:
863
885
  skip_reverse_codegen=False,
864
886
  custom_reverse_mode=False,
865
887
  custom_reverse_num_input_args=-1,
866
- transformers: Optional[List[ast.NodeTransformer]] = None,
888
+ transformers: list[ast.NodeTransformer] | None = None,
889
+ source: str | None = None,
867
890
  ):
868
891
  adj.func = func
869
892
 
@@ -877,19 +900,17 @@ class Adjoint:
877
900
  # extract name of source file
878
901
  adj.filename = inspect.getsourcefile(func) or "unknown source file"
879
902
  # get source file line number where function starts
880
- try:
881
- _, adj.fun_lineno = inspect.getsourcelines(func)
882
- except OSError as e:
883
- raise RuntimeError(
884
- "Directly evaluating Warp code defined as a string using `exec()` is not supported, "
885
- "please save it on a file and use `importlib` if needed."
886
- ) from e
903
+ adj.fun_lineno = 0
904
+ adj.source = source
905
+ if adj.source is None:
906
+ adj.source, adj.fun_lineno = adj.extract_function_source(func)
907
+
908
+ assert adj.source is not None, f"Failed to extract source code for function {func.__name__}"
887
909
 
888
910
  # Indicates where the function definition starts (excludes decorators)
889
911
  adj.fun_def_lineno = None
890
912
 
891
913
  # get function source code
892
- adj.source = inspect.getsource(func)
893
914
  # ensures that indented class methods can be parsed as kernels
894
915
  adj.source = textwrap.dedent(adj.source)
895
916
 
@@ -948,9 +969,14 @@ class Adjoint:
948
969
  # this is to avoid registering false references to overshadowed modules
949
970
  adj.symbols[name] = arg
950
971
 
972
+ # Indicates whether there are unresolved static expressions in the function.
973
+ # These stem from wp.static() expressions that could not be evaluated at declaration time.
974
+ # This will signal to the module builder that this module needs to be rebuilt even if the module hash is unchanged.
975
+ adj.has_unresolved_static_expressions = False
976
+
951
977
  # try to replace static expressions by their constant result if the
952
978
  # expression can be evaluated at declaration time
953
- adj.static_expressions: Dict[str, Any] = {}
979
+ adj.static_expressions: dict[str, Any] = {}
954
980
  if "static" in adj.source:
955
981
  adj.replace_static_expressions()
956
982
 
@@ -981,6 +1007,18 @@ class Adjoint:
981
1007
 
982
1008
  return total_shared + adj.max_required_extra_shared_memory
983
1009
 
1010
+ @staticmethod
1011
+ def extract_function_source(func: Callable) -> tuple[str, int]:
1012
+ try:
1013
+ _, fun_lineno = inspect.getsourcelines(func)
1014
+ source = inspect.getsource(func)
1015
+ except OSError as e:
1016
+ raise RuntimeError(
1017
+ "Directly evaluating Warp code defined as a string using `exec()` is not supported, "
1018
+ "please save it to a file and use `importlib` if needed."
1019
+ ) from e
1020
+ return source, fun_lineno
1021
+
984
1022
  # generate function ssa form and adjoint
985
1023
  def build(adj, builder, default_builder_options=None):
986
1024
  # arg Var read/write flags are held during module rebuilds, so we reset here even when skipping a build
@@ -1058,7 +1096,7 @@ class Adjoint:
1058
1096
  # code generation methods
1059
1097
  def format_template(adj, template, input_vars, output_var):
1060
1098
  # output var is always the 0th index
1061
- args = [output_var] + input_vars
1099
+ args = [output_var, *input_vars]
1062
1100
  s = template.format(*args)
1063
1101
 
1064
1102
  return s
@@ -1176,7 +1214,7 @@ class Adjoint:
1176
1214
 
1177
1215
  return var
1178
1216
 
1179
- def get_line_directive(adj, statement: str, relative_lineno: Optional[int] = None) -> Optional[str]:
1217
+ def get_line_directive(adj, statement: str, relative_lineno: int | None = None) -> str | None:
1180
1218
  """Get a line directive for the given statement.
1181
1219
 
1182
1220
  Args:
@@ -1202,7 +1240,7 @@ class Adjoint:
1202
1240
  return f'#line {line} "{normalized_path}"'
1203
1241
  return None
1204
1242
 
1205
- def add_forward(adj, statement: str, replay: Optional[str] = None, skip_replay: builtins.bool = False) -> None:
1243
+ def add_forward(adj, statement: str, replay: str | None = None, skip_replay: builtins.bool = False) -> None:
1206
1244
  """Append a statement to the forward pass."""
1207
1245
 
1208
1246
  if line_directive := adj.get_line_directive(statement, adj.lineno):
@@ -1300,7 +1338,8 @@ class Adjoint:
1300
1338
 
1301
1339
  # check output dimensions match expectations
1302
1340
  if min_outputs:
1303
- if not isinstance(f.value_type, Sequence) or len(f.value_type) != min_outputs:
1341
+ value_type = f.value_func(None, None)
1342
+ if not isinstance(value_type, Sequence) or len(value_type) != min_outputs:
1304
1343
  continue
1305
1344
 
1306
1345
  # found a match, use it
@@ -1396,6 +1435,17 @@ class Adjoint:
1396
1435
  bound_arg_values,
1397
1436
  )
1398
1437
 
1438
+ # Handle the special case where a Var instance is returned from the `value_func`
1439
+ # callback, in which case we replace the call with a reference to that variable.
1440
+ if isinstance(return_type, Var):
1441
+ return adj.register_var(return_type)
1442
+ elif isinstance(return_type, Sequence) and all(isinstance(x, Var) for x in return_type):
1443
+ return tuple(adj.register_var(x) for x in return_type)
1444
+
1445
+ if get_origin(return_type) is tuple:
1446
+ types = get_args(return_type)
1447
+ return_type = warp.types.tuple_t(types=types, values=(None,) * len(types))
1448
+
1399
1449
  # immediately allocate output variables so we can pass them into the dispatch method
1400
1450
  if return_type is None:
1401
1451
  # void function
@@ -1775,6 +1825,22 @@ class Adjoint:
1775
1825
  out = adj.add_builtin_call("where", [cond, var1, var2])
1776
1826
  adj.symbols[sym] = out
1777
1827
 
1828
+ def emit_IfExp(adj, node):
1829
+ cond = adj.eval(node.test)
1830
+
1831
+ if cond.constant is not None:
1832
+ return adj.eval(node.body) if cond.constant else adj.eval(node.orelse)
1833
+
1834
+ adj.begin_if(cond)
1835
+ body = adj.eval(node.body)
1836
+ adj.end_if(cond)
1837
+
1838
+ adj.begin_else(cond)
1839
+ orelse = adj.eval(node.orelse)
1840
+ adj.end_else(cond)
1841
+
1842
+ return adj.add_builtin_call("where", [cond, body, orelse])
1843
+
1778
1844
  def emit_Compare(adj, node):
1779
1845
  # node.left, node.ops (list of ops), node.comparators (things to compare to)
1780
1846
  # e.g. (left ops[0] node.comparators[0]) ops[1] node.comparators[1]
@@ -1831,7 +1897,7 @@ class Adjoint:
1831
1897
  if attr == "dtype":
1832
1898
  return type_scalar_type(var_type)
1833
1899
  elif attr == "length":
1834
- return type_length(var_type)
1900
+ return type_size(var_type)
1835
1901
 
1836
1902
  return getattr(var_type, attr, None)
1837
1903
 
@@ -1850,6 +1916,15 @@ class Adjoint:
1850
1916
  index = adj.add_constant(index)
1851
1917
  return index
1852
1918
 
1919
+ def transform_component(adj, component):
1920
+ if len(component) != 1:
1921
+ raise WarpCodegenAttributeError(f"Transform attribute must be single character, got .{component}")
1922
+
1923
+ if component not in ("p", "q"):
1924
+ raise WarpCodegenAttributeError(f"Attribute for transformation must be either 'p' or 'q', got {component}")
1925
+
1926
+ return component
1927
+
1853
1928
  @staticmethod
1854
1929
  def is_differentiable_value_type(var_type):
1855
1930
  # checks that the argument type is a value type (i.e, not an array)
@@ -1880,12 +1955,20 @@ class Adjoint:
1880
1955
 
1881
1956
  aggregate_type = strip_reference(aggregate.type)
1882
1957
 
1883
- # reading a vector component
1884
- if type_is_vector(aggregate_type):
1958
+ # reading a vector or quaternion component
1959
+ if type_is_vector(aggregate_type) or type_is_quaternion(aggregate_type):
1885
1960
  index = adj.vector_component_index(node.attr, aggregate_type)
1886
1961
 
1887
1962
  return adj.add_builtin_call("extract", [aggregate, index])
1888
1963
 
1964
+ elif type_is_transformation(aggregate_type):
1965
+ component = adj.transform_component(node.attr)
1966
+
1967
+ if component == "p":
1968
+ return adj.add_builtin_call("transform_get_translation", [aggregate])
1969
+ else:
1970
+ return adj.add_builtin_call("transform_get_rotation", [aggregate])
1971
+
1889
1972
  else:
1890
1973
  attr_type = Reference(aggregate_type.vars[node.attr].type)
1891
1974
  attr = adj.add_var(attr_type)
@@ -2246,8 +2329,9 @@ class Adjoint:
2246
2329
 
2247
2330
  if adj.is_static_expression(func):
2248
2331
  # try to evaluate wp.static() expressions
2249
- obj, _ = adj.evaluate_static_expression(node)
2332
+ obj, code = adj.evaluate_static_expression(node)
2250
2333
  if obj is not None:
2334
+ adj.static_expressions[code] = obj
2251
2335
  if isinstance(obj, warp.context.Function):
2252
2336
  # special handling for wp.static() evaluating to a function
2253
2337
  return obj
@@ -2282,6 +2366,10 @@ class Adjoint:
2282
2366
  else:
2283
2367
  func = caller.default_constructor
2284
2368
 
2369
+ # lambda function
2370
+ if func is None and getattr(caller, "__name__", None) == "<lambda>":
2371
+ raise NotImplementedError("Lambda expressions are not yet supported")
2372
+
2285
2373
  if hasattr(caller, "_wp_type_args_"):
2286
2374
  type_args = caller._wp_type_args_
2287
2375
 
@@ -2290,18 +2378,6 @@ class Adjoint:
2290
2378
  f"Could not find function {'.'.join(path)} as a built-in or user-defined function. Note that user functions must be annotated with a @wp.func decorator to be called from a kernel."
2291
2379
  )
2292
2380
 
2293
- # Check if any argument correspond to an unsupported construct.
2294
- # Tuples are supported in the context of assigning multiple variables
2295
- # at once, but not in place of vectors when calling built-ins like
2296
- # `wp.length((1, 2, 3))`.
2297
- # Therefore, we need to catch this specific case here instead of
2298
- # more generally in `adj.eval()`.
2299
- for arg in node.args:
2300
- if isinstance(arg, ast.Tuple):
2301
- raise WarpCodegenError(
2302
- "Tuple constructs are not supported in kernels. Use vectors like `wp.vec3()` instead."
2303
- )
2304
-
2305
2381
  # get expected return count, e.g.: for multi-assignment
2306
2382
  min_outputs = None
2307
2383
  if hasattr(node, "expects"):
@@ -2311,7 +2387,6 @@ class Adjoint:
2311
2387
  args = tuple(adj.resolve_arg(x) for x in node.args)
2312
2388
  kwargs = {x.arg: adj.resolve_arg(x.value) for x in node.keywords}
2313
2389
 
2314
- # add the call and build the callee adjoint if needed (func.adj)
2315
2390
  out = adj.add_call(func, args, kwargs, type_args, min_outputs=min_outputs)
2316
2391
 
2317
2392
  if warp.config.verify_autograd_array_access:
@@ -2461,10 +2536,6 @@ class Adjoint:
2461
2536
  raise WarpCodegenError(
2462
2537
  "List constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead."
2463
2538
  )
2464
- elif isinstance(node.value, ast.Tuple):
2465
- raise WarpCodegenError(
2466
- "Tuple constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead."
2467
- )
2468
2539
 
2469
2540
  # handle the case where we are assigning multiple output variables
2470
2541
  if isinstance(lhs, ast.Tuple):
@@ -2480,6 +2551,17 @@ class Adjoint:
2480
2551
  else:
2481
2552
  out = adj.eval(node.value)
2482
2553
 
2554
+ subtype = getattr(out, "type", None)
2555
+ if isinstance(subtype, warp.types.tuple_t):
2556
+ if len(out.type.types) != len(lhs.elts):
2557
+ raise WarpCodegenError(
2558
+ f"Invalid number of values to unpack (expected {len(lhs.elts)}, got {len(out.type.types)})."
2559
+ )
2560
+ target = out
2561
+ out = tuple(
2562
+ adj.add_builtin_call("extract", (target, adj.add_constant(i))) for i in range(len(lhs.elts))
2563
+ )
2564
+
2483
2565
  names = []
2484
2566
  for v in lhs.elts:
2485
2567
  if isinstance(v, ast.Name):
@@ -2532,7 +2614,12 @@ class Adjoint:
2532
2614
  elif is_tile(target_type):
2533
2615
  adj.add_builtin_call("assign", [target, *indices, rhs])
2534
2616
 
2535
- elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
2617
+ elif (
2618
+ type_is_vector(target_type)
2619
+ or type_is_quaternion(target_type)
2620
+ or type_is_matrix(target_type)
2621
+ or type_is_transformation(target_type)
2622
+ ):
2536
2623
  # recursively unwind AST, stopping at penultimate node
2537
2624
  node = lhs
2538
2625
  while hasattr(node, "value"):
@@ -2572,7 +2659,7 @@ class Adjoint:
2572
2659
 
2573
2660
  else:
2574
2661
  raise WarpCodegenError(
2575
- f"Can only subscript assign array, vector, quaternion, and matrix types, got {target_type}"
2662
+ f"Can only subscript assign array, vector, quaternion, transformation, and matrix types, got {target_type}"
2576
2663
  )
2577
2664
 
2578
2665
  elif isinstance(lhs, ast.Name):
@@ -2589,8 +2676,11 @@ class Adjoint:
2589
2676
  f"Error, assigning to existing symbol {name} ({adj.symbols[name].type}) with different type ({rhs.type})"
2590
2677
  )
2591
2678
 
2592
- # handle simple assignment case (a = b), where we generate a value copy rather than reference
2593
- if isinstance(node.value, ast.Name) or is_reference(rhs.type):
2679
+ if isinstance(node.value, ast.Tuple):
2680
+ out = rhs
2681
+ elif isinstance(rhs, Sequence):
2682
+ out = adj.add_builtin_call("tuple", rhs)
2683
+ elif isinstance(node.value, ast.Name) or is_reference(rhs.type):
2594
2684
  out = adj.add_builtin_call("copy", [rhs])
2595
2685
  else:
2596
2686
  out = rhs
@@ -2622,6 +2712,18 @@ class Adjoint:
2622
2712
  else:
2623
2713
  adj.add_builtin_call("assign_inplace", [aggregate, index, rhs])
2624
2714
 
2715
+ elif type_is_transformation(aggregate_type):
2716
+ component = adj.transform_component(lhs.attr)
2717
+
2718
+ # TODO: x[i,j].p = rhs case
2719
+ if is_reference(aggregate.type):
2720
+ raise WarpCodegenError(f"Error, assigning transform attribute {component} to an array element")
2721
+
2722
+ if component == "p":
2723
+ return adj.add_builtin_call("transform_set_translation", [aggregate, rhs])
2724
+ else:
2725
+ return adj.add_builtin_call("transform_set_rotation", [aggregate, rhs])
2726
+
2625
2727
  else:
2626
2728
  attr = adj.emit_Attribute(lhs)
2627
2729
  if is_reference(attr.type):
@@ -2644,7 +2746,9 @@ class Adjoint:
2644
2746
  elif isinstance(node.value, ast.Tuple):
2645
2747
  var = tuple(adj.eval(arg) for arg in node.value.elts)
2646
2748
  else:
2647
- var = (adj.eval(node.value),)
2749
+ var = adj.eval(node.value)
2750
+ if not isinstance(var, list) and not isinstance(var, tuple):
2751
+ var = (var,)
2648
2752
 
2649
2753
  if adj.return_var is not None:
2650
2754
  old_ctypes = tuple(v.ctype(value_type=True) for v in adj.return_var)
@@ -2697,6 +2801,7 @@ class Adjoint:
2697
2801
  type_is_vector(target_type.dtype)
2698
2802
  or type_is_quaternion(target_type.dtype)
2699
2803
  or type_is_matrix(target_type.dtype)
2804
+ or type_is_transformation(target_type.dtype)
2700
2805
  ):
2701
2806
  dtype = getattr(target_type.dtype, "_wp_scalar_type_", None)
2702
2807
  if dtype in warp.types.non_atomic_types:
@@ -2724,7 +2829,12 @@ class Adjoint:
2724
2829
  make_new_assign_statement()
2725
2830
  return
2726
2831
 
2727
- elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
2832
+ elif (
2833
+ type_is_vector(target_type)
2834
+ or type_is_quaternion(target_type)
2835
+ or type_is_matrix(target_type)
2836
+ or type_is_transformation(target_type)
2837
+ ):
2728
2838
  if isinstance(node.op, ast.Add):
2729
2839
  adj.add_builtin_call("add_inplace", [target, *indices, rhs])
2730
2840
  elif isinstance(node.op, ast.Sub):
@@ -2735,9 +2845,36 @@ class Adjoint:
2735
2845
  make_new_assign_statement()
2736
2846
  return
2737
2847
 
2848
+ elif is_tile(target.type):
2849
+ if isinstance(node.op, ast.Add):
2850
+ adj.add_builtin_call("tile_add_inplace", [target, *indices, rhs])
2851
+ elif isinstance(node.op, ast.Sub):
2852
+ adj.add_builtin_call("tile_sub_inplace", [target, *indices, rhs])
2853
+ else:
2854
+ if warp.config.verbose:
2855
+ print(f"Warning: in-place op {node.op} is not differentiable")
2856
+ make_new_assign_statement()
2857
+ return
2858
+
2738
2859
  else:
2739
2860
  raise WarpCodegenError("Can only subscript in-place assign array, vector, quaternion, and matrix types")
2740
2861
 
2862
+ elif isinstance(lhs, ast.Name):
2863
+ target = adj.eval(node.target)
2864
+ rhs = adj.eval(node.value)
2865
+
2866
+ if is_tile(target.type) and is_tile(rhs.type):
2867
+ if isinstance(node.op, ast.Add):
2868
+ adj.add_builtin_call("add_inplace", [target, rhs])
2869
+ elif isinstance(node.op, ast.Sub):
2870
+ adj.add_builtin_call("sub_inplace", [target, rhs])
2871
+ else:
2872
+ make_new_assign_statement()
2873
+ return
2874
+ else:
2875
+ make_new_assign_statement()
2876
+ return
2877
+
2741
2878
  # TODO
2742
2879
  elif isinstance(lhs, ast.Attribute):
2743
2880
  make_new_assign_statement()
@@ -2748,15 +2885,16 @@ class Adjoint:
2748
2885
  return
2749
2886
 
2750
2887
  def emit_Tuple(adj, node):
2751
- # LHS for expressions, such as i, j, k = 1, 2, 3
2752
- return tuple(adj.eval(x) for x in node.elts)
2888
+ elements = tuple(adj.eval(x) for x in node.elts)
2889
+ return adj.add_builtin_call("tuple", elements)
2753
2890
 
2754
2891
  def emit_Pass(adj, node):
2755
2892
  pass
2756
2893
 
2757
- node_visitors = {
2894
+ node_visitors: ClassVar[dict[type[ast.AST], Callable]] = {
2758
2895
  ast.FunctionDef: emit_FunctionDef,
2759
2896
  ast.If: emit_If,
2897
+ ast.IfExp: emit_IfExp,
2760
2898
  ast.Compare: emit_Compare,
2761
2899
  ast.BoolOp: emit_BoolOp,
2762
2900
  ast.Name: emit_Name,
@@ -2860,11 +2998,11 @@ class Adjoint:
2860
2998
  if isinstance(value, warp.context.Function):
2861
2999
  return True
2862
3000
 
2863
- def verify_struct(s: StructInstance, attr_path: List[str]):
3001
+ def verify_struct(s: StructInstance, attr_path: list[str]):
2864
3002
  for key in s._cls.vars.keys():
2865
3003
  v = getattr(s, key)
2866
3004
  if issubclass(type(v), StructInstance):
2867
- verify_struct(v, attr_path + [key])
3005
+ verify_struct(v, [*attr_path, key])
2868
3006
  else:
2869
3007
  try:
2870
3008
  adj.verify_static_return_value(v)
@@ -2879,7 +3017,8 @@ class Adjoint:
2879
3017
  raise ValueError(f"value of type {type(value)} cannot be constructed inside Warp kernels")
2880
3018
 
2881
3019
  # find the source code string of an AST node
2882
- def extract_node_source(adj, node) -> Optional[str]:
3020
+ @staticmethod
3021
+ def extract_node_source_from_lines(source_lines, node) -> str | None:
2883
3022
  if not hasattr(node, "lineno") or not hasattr(node, "col_offset"):
2884
3023
  return None
2885
3024
 
@@ -2895,12 +3034,12 @@ class Adjoint:
2895
3034
  end_line = start_line
2896
3035
  end_col = start_col
2897
3036
  parenthesis_count = 1
2898
- for lineno in range(start_line, len(adj.source_lines)):
3037
+ for lineno in range(start_line, len(source_lines)):
2899
3038
  if lineno == start_line:
2900
3039
  c_start = start_col
2901
3040
  else:
2902
3041
  c_start = 0
2903
- line = adj.source_lines[lineno]
3042
+ line = source_lines[lineno]
2904
3043
  for i in range(c_start, len(line)):
2905
3044
  c = line[i]
2906
3045
  if c == "(":
@@ -2916,21 +3055,57 @@ class Adjoint:
2916
3055
 
2917
3056
  if start_line == end_line:
2918
3057
  # single-line expression
2919
- return adj.source_lines[start_line][start_col:end_col]
3058
+ return source_lines[start_line][start_col:end_col]
2920
3059
  else:
2921
3060
  # multi-line expression
2922
3061
  lines = []
2923
3062
  # first line (from start_col to the end)
2924
- lines.append(adj.source_lines[start_line][start_col:])
3063
+ lines.append(source_lines[start_line][start_col:])
2925
3064
  # middle lines (entire lines)
2926
- lines.extend(adj.source_lines[start_line + 1 : end_line])
3065
+ lines.extend(source_lines[start_line + 1 : end_line])
2927
3066
  # last line (from the start to end_col)
2928
- lines.append(adj.source_lines[end_line][:end_col])
3067
+ lines.append(source_lines[end_line][:end_col])
2929
3068
  return "\n".join(lines).strip()
2930
3069
 
3070
+ @staticmethod
3071
+ def extract_lambda_source(func, only_body=False) -> str | None:
3072
+ try:
3073
+ source_lines = inspect.getsourcelines(func)[0]
3074
+ source_lines[0] = source_lines[0][source_lines[0].index("lambda") :]
3075
+ except OSError as e:
3076
+ raise WarpCodegenError(
3077
+ "Could not access lambda function source code. Please use a named function instead."
3078
+ ) from e
3079
+ source = "".join(source_lines)
3080
+ source = source[source.index("lambda") :].rstrip()
3081
+ # Remove trailing unbalanced parentheses
3082
+ while source.count("(") < source.count(")"):
3083
+ source = source[:-1]
3084
+ # extract lambda expression up until a comma, e.g. in the case of
3085
+ # "map(lambda a: (a + 2.0, a + 3.0), a, return_kernel=True)"
3086
+ si = max(source.find(")"), source.find(":"))
3087
+ ci = source.find(",", si)
3088
+ if ci != -1:
3089
+ source = source[:ci]
3090
+ tree = ast.parse(source)
3091
+ lambda_source = None
3092
+ for node in ast.walk(tree):
3093
+ if isinstance(node, ast.Lambda):
3094
+ if only_body:
3095
+ # extract the body of the lambda function
3096
+ lambda_source = Adjoint.extract_node_source_from_lines(source_lines, node.body)
3097
+ else:
3098
+ # extract the entire lambda function
3099
+ lambda_source = Adjoint.extract_node_source_from_lines(source_lines, node)
3100
+ break
3101
+ return lambda_source
3102
+
3103
+ def extract_node_source(adj, node) -> str | None:
3104
+ return adj.extract_node_source_from_lines(adj.source_lines, node)
3105
+
2931
3106
  # handles a wp.static() expression and returns the resulting object and a string representing the code
2932
3107
  # of the static expression
2933
- def evaluate_static_expression(adj, node) -> Tuple[Any, str]:
3108
+ def evaluate_static_expression(adj, node) -> tuple[Any, str]:
2934
3109
  if len(node.args) == 1:
2935
3110
  static_code = adj.extract_node_source(node.args[0])
2936
3111
  elif len(node.keywords) == 1:
@@ -2942,6 +3117,7 @@ class Adjoint:
2942
3117
 
2943
3118
  # Since this is an expression, we can enforce it to be defined on a single line.
2944
3119
  static_code = static_code.replace("\n", "")
3120
+ code_to_eval = static_code # code to be evaluated
2945
3121
 
2946
3122
  vars_dict = adj.get_static_evaluation_context()
2947
3123
  # add constant variables to the static call context
@@ -2950,29 +3126,14 @@ class Adjoint:
2950
3126
 
2951
3127
  # Replace all constant `len()` expressions with their value.
2952
3128
  if "len" in static_code:
2953
-
2954
- def eval_len(obj):
2955
- if type_is_vector(obj):
2956
- return obj._length_
2957
- elif type_is_quaternion(obj):
2958
- return obj._length_
2959
- elif type_is_matrix(obj):
2960
- return obj._shape_[0]
2961
- elif type_is_transformation(obj):
2962
- return obj._length_
2963
- elif is_tile(obj):
2964
- return obj.shape[0]
2965
-
2966
- return len(obj)
2967
-
2968
3129
  len_expr_ctx = vars_dict.copy()
2969
3130
  constant_types = {k: v.type for k, v in adj.symbols.items() if isinstance(v, Var) and v.type is not None}
2970
3131
  len_expr_ctx.update(constant_types)
2971
- len_expr_ctx.update({"len": eval_len})
3132
+ len_expr_ctx.update({"len": warp.types.type_length})
2972
3133
 
2973
3134
  # We want to replace the expression code in-place,
2974
3135
  # so reparse it to get the correct column info.
2975
- len_value_locs: List[Tuple[int, int, int]] = []
3136
+ len_value_locs: list[tuple[int, int, int]] = []
2976
3137
  expr_tree = ast.parse(static_code)
2977
3138
  assert len(expr_tree.body) == 1 and isinstance(expr_tree.body[0], ast.Expr)
2978
3139
  expr_root = expr_tree.body[0].value
@@ -2998,10 +3159,10 @@ class Adjoint:
2998
3159
  loc = end
2999
3160
 
3000
3161
  new_static_code += static_code[len_value_locs[-1][2] :]
3001
- static_code = new_static_code
3162
+ code_to_eval = new_static_code
3002
3163
 
3003
3164
  try:
3004
- value = eval(static_code, vars_dict)
3165
+ value = eval(code_to_eval, vars_dict)
3005
3166
  if warp.config.verbose:
3006
3167
  print(f"Evaluated static command: {static_code} = {value}")
3007
3168
  except NameError as e:
@@ -3054,6 +3215,9 @@ class Adjoint:
3054
3215
  # (and is therefore not executable and raises this exception), in which
3055
3216
  # case changing the constant, or the code affecting this constant, would lead to
3056
3217
  # a different module hash anyway.
3218
+ # In any case, we mark this Adjoint to have unresolvable static expressions.
3219
+ # This will trigger a code generation step even if the module hash is unchanged.
3220
+ adj.has_unresolved_static_expressions = True
3057
3221
  pass
3058
3222
 
3059
3223
  return self.generic_visit(node)
@@ -3134,14 +3298,14 @@ class Adjoint:
3134
3298
  # return the Python code corresponding to the given AST node
3135
3299
  return ast.get_source_segment(adj.source, node)
3136
3300
 
3137
- def get_references(adj) -> Tuple[Dict[str, Any], Dict[Any, Any], Dict[warp.context.Function, Any]]:
3301
+ def get_references(adj) -> tuple[dict[str, Any], dict[Any, Any], dict[warp.context.Function, Any]]:
3138
3302
  """Traverses ``adj.tree`` and returns referenced constants, types, and user-defined functions."""
3139
3303
 
3140
3304
  local_variables = set() # Track local variables appearing on the LHS so we know when variables are shadowed
3141
3305
 
3142
- constants: Dict[str, Any] = {}
3143
- types: Dict[Union[Struct, type], Any] = {}
3144
- functions: Dict[warp.context.Function, Any] = {}
3306
+ constants: dict[str, Any] = {}
3307
+ types: dict[Struct | type, Any] = {}
3308
+ functions: dict[warp.context.Function, Any] = {}
3145
3309
 
3146
3310
  for node in ast.walk(adj.tree):
3147
3311
  if isinstance(node, ast.Name) and node.id not in local_variables:
@@ -3200,6 +3364,8 @@ cpu_module_header = """
3200
3364
  #define builtin_tid3d(x, y, z) wp::tid(x, y, z, task_index, dim)
3201
3365
  #define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, task_index, dim)
3202
3366
 
3367
+ #define builtin_block_dim() wp::block_dim()
3368
+
3203
3369
  """
3204
3370
 
3205
3371
  cuda_module_header = """
@@ -3219,6 +3385,8 @@ cuda_module_header = """
3219
3385
  #define builtin_tid3d(x, y, z) wp::tid(x, y, z, _idx, dim)
3220
3386
  #define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, _idx, dim)
3221
3387
 
3388
+ #define builtin_block_dim() wp::block_dim()
3389
+
3222
3390
  """
3223
3391
 
3224
3392
  struct_template = """
@@ -3663,7 +3831,7 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
3663
3831
  f"annotated as a tuple of {len(get_args(adj.arg_types['return']))} elements "
3664
3832
  f"but the code returns {len(adj.return_var)} values."
3665
3833
  )
3666
- elif not types_equal(adj.arg_types["return"], tuple(x.type for x in adj.return_var)):
3834
+ elif not types_equal(adj.arg_types["return"], tuple(x.type for x in adj.return_var), match_generic=True):
3667
3835
  raise WarpCodegenError(
3668
3836
  f"The function `{adj.fun_name}` has its return type "
3669
3837
  f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "