warp-lang 1.8.0-py3-none-macosx_10_13_universal2.whl → 1.9.0-py3-none-macosx_10_13_universal2.whl

This diff compares the contents of two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the registry.

Potentially problematic release: this version of warp-lang might be problematic.

Files changed (153)
  1. warp/__init__.py +282 -103
  2. warp/__init__.pyi +482 -110
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +93 -30
  6. warp/build_dll.py +48 -63
  7. warp/builtins.py +955 -137
  8. warp/codegen.py +327 -209
  9. warp/config.py +1 -1
  10. warp/context.py +1363 -800
  11. warp/examples/core/example_marching_cubes.py +1 -0
  12. warp/examples/core/example_render_opengl.py +100 -3
  13. warp/examples/fem/example_apic_fluid.py +98 -52
  14. warp/examples/fem/example_convection_diffusion_dg.py +25 -4
  15. warp/examples/fem/example_diffusion_mgpu.py +8 -3
  16. warp/examples/fem/utils.py +68 -22
  17. warp/examples/interop/example_jax_callable.py +34 -4
  18. warp/examples/interop/example_jax_kernel.py +27 -1
  19. warp/fabric.py +1 -1
  20. warp/fem/cache.py +27 -19
  21. warp/fem/domain.py +2 -2
  22. warp/fem/field/nodal_field.py +2 -2
  23. warp/fem/field/virtual.py +266 -166
  24. warp/fem/geometry/geometry.py +5 -5
  25. warp/fem/integrate.py +200 -91
  26. warp/fem/space/restriction.py +4 -0
  27. warp/fem/space/shape/tet_shape_function.py +3 -10
  28. warp/jax_experimental/custom_call.py +1 -1
  29. warp/jax_experimental/ffi.py +203 -54
  30. warp/marching_cubes.py +708 -0
  31. warp/native/array.h +103 -8
  32. warp/native/builtin.h +90 -9
  33. warp/native/bvh.cpp +64 -28
  34. warp/native/bvh.cu +58 -58
  35. warp/native/bvh.h +2 -2
  36. warp/native/clang/clang.cpp +7 -7
  37. warp/native/coloring.cpp +13 -3
  38. warp/native/crt.cpp +2 -2
  39. warp/native/crt.h +3 -5
  40. warp/native/cuda_util.cpp +42 -11
  41. warp/native/cuda_util.h +10 -4
  42. warp/native/exports.h +1842 -1908
  43. warp/native/fabric.h +2 -1
  44. warp/native/hashgrid.cpp +37 -37
  45. warp/native/hashgrid.cu +2 -2
  46. warp/native/initializer_array.h +1 -1
  47. warp/native/intersect.h +4 -4
  48. warp/native/mat.h +1913 -119
  49. warp/native/mathdx.cpp +43 -43
  50. warp/native/mesh.cpp +24 -24
  51. warp/native/mesh.cu +26 -26
  52. warp/native/mesh.h +5 -3
  53. warp/native/nanovdb/GridHandle.h +179 -12
  54. warp/native/nanovdb/HostBuffer.h +8 -7
  55. warp/native/nanovdb/NanoVDB.h +517 -895
  56. warp/native/nanovdb/NodeManager.h +323 -0
  57. warp/native/nanovdb/PNanoVDB.h +2 -2
  58. warp/native/quat.h +337 -16
  59. warp/native/rand.h +7 -7
  60. warp/native/range.h +7 -1
  61. warp/native/reduce.cpp +10 -10
  62. warp/native/reduce.cu +13 -14
  63. warp/native/runlength_encode.cpp +2 -2
  64. warp/native/runlength_encode.cu +5 -5
  65. warp/native/scan.cpp +3 -3
  66. warp/native/scan.cu +4 -4
  67. warp/native/sort.cpp +10 -10
  68. warp/native/sort.cu +22 -22
  69. warp/native/sparse.cpp +8 -8
  70. warp/native/sparse.cu +14 -14
  71. warp/native/spatial.h +366 -17
  72. warp/native/svd.h +23 -8
  73. warp/native/temp_buffer.h +2 -2
  74. warp/native/tile.h +303 -70
  75. warp/native/tile_radix_sort.h +5 -1
  76. warp/native/tile_reduce.h +16 -25
  77. warp/native/tuple.h +2 -2
  78. warp/native/vec.h +385 -18
  79. warp/native/volume.cpp +54 -54
  80. warp/native/volume.cu +1 -1
  81. warp/native/volume.h +2 -1
  82. warp/native/volume_builder.cu +30 -37
  83. warp/native/warp.cpp +150 -149
  84. warp/native/warp.cu +337 -193
  85. warp/native/warp.h +227 -226
  86. warp/optim/linear.py +736 -271
  87. warp/render/imgui_manager.py +289 -0
  88. warp/render/render_opengl.py +137 -57
  89. warp/render/render_usd.py +0 -1
  90. warp/sim/collide.py +1 -2
  91. warp/sim/graph_coloring.py +2 -2
  92. warp/sim/integrator_vbd.py +10 -2
  93. warp/sparse.py +559 -176
  94. warp/tape.py +2 -0
  95. warp/tests/aux_test_module_aot.py +7 -0
  96. warp/tests/cuda/test_async.py +3 -3
  97. warp/tests/cuda/test_conditional_captures.py +101 -0
  98. warp/tests/geometry/test_marching_cubes.py +233 -12
  99. warp/tests/sim/test_cloth.py +89 -6
  100. warp/tests/sim/test_coloring.py +82 -7
  101. warp/tests/test_array.py +56 -5
  102. warp/tests/test_assert.py +53 -0
  103. warp/tests/test_atomic_cas.py +127 -114
  104. warp/tests/test_codegen.py +3 -2
  105. warp/tests/test_context.py +8 -15
  106. warp/tests/test_enum.py +136 -0
  107. warp/tests/test_examples.py +2 -2
  108. warp/tests/test_fem.py +45 -2
  109. warp/tests/test_fixedarray.py +229 -0
  110. warp/tests/test_func.py +18 -15
  111. warp/tests/test_future_annotations.py +7 -5
  112. warp/tests/test_linear_solvers.py +30 -0
  113. warp/tests/test_map.py +1 -1
  114. warp/tests/test_mat.py +1540 -378
  115. warp/tests/test_mat_assign_copy.py +178 -0
  116. warp/tests/test_mat_constructors.py +574 -0
  117. warp/tests/test_module_aot.py +287 -0
  118. warp/tests/test_print.py +69 -0
  119. warp/tests/test_quat.py +162 -34
  120. warp/tests/test_quat_assign_copy.py +145 -0
  121. warp/tests/test_reload.py +2 -1
  122. warp/tests/test_sparse.py +103 -0
  123. warp/tests/test_spatial.py +140 -34
  124. warp/tests/test_spatial_assign_copy.py +160 -0
  125. warp/tests/test_static.py +48 -0
  126. warp/tests/test_struct.py +43 -3
  127. warp/tests/test_tape.py +38 -0
  128. warp/tests/test_types.py +0 -20
  129. warp/tests/test_vec.py +216 -441
  130. warp/tests/test_vec_assign_copy.py +143 -0
  131. warp/tests/test_vec_constructors.py +325 -0
  132. warp/tests/tile/test_tile.py +206 -152
  133. warp/tests/tile/test_tile_cholesky.py +605 -0
  134. warp/tests/tile/test_tile_load.py +169 -0
  135. warp/tests/tile/test_tile_mathdx.py +2 -558
  136. warp/tests/tile/test_tile_matmul.py +179 -0
  137. warp/tests/tile/test_tile_mlp.py +1 -1
  138. warp/tests/tile/test_tile_reduce.py +100 -11
  139. warp/tests/tile/test_tile_shared_memory.py +16 -16
  140. warp/tests/tile/test_tile_sort.py +59 -55
  141. warp/tests/unittest_suites.py +16 -0
  142. warp/tests/walkthrough_debug.py +1 -1
  143. warp/thirdparty/unittest_parallel.py +108 -9
  144. warp/types.py +554 -264
  145. warp/utils.py +68 -86
  146. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/METADATA +28 -65
  147. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/RECORD +150 -138
  148. warp/native/marching.cpp +0 -19
  149. warp/native/marching.cu +0 -514
  150. warp/native/marching.h +0 -19
  151. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/WHEEL +0 -0
  152. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/licenses/LICENSE.md +0 -0
  153. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/top_level.txt +0 -0
warp/codegen.py CHANGED
@@ -18,9 +18,11 @@ from __future__ import annotations
 import ast
 import builtins
 import ctypes
+import enum
 import functools
 import hashlib
 import inspect
+import itertools
 import math
 import re
 import sys
@@ -614,8 +616,12 @@ def compute_type_str(base_name, template_params):
     return base_name

 def param2str(p):
+    if isinstance(p, builtins.bool):
+        return "true" if p else "false"
     if isinstance(p, int):
         return str(p)
+    elif hasattr(p, "_wp_generic_type_str_"):
+        return compute_type_str(f"wp::{p._wp_generic_type_str_}", p._wp_type_params_)
     elif hasattr(p, "_type_"):
         if p.__name__ == "bool":
             return "bool"
@@ -623,6 +629,8 @@ def compute_type_str(base_name, template_params):
             return f"wp::{p.__name__}"
     elif is_tile(p):
         return p.ctype()
+    elif isinstance(p, Struct):
+        return p.native_name

     return p.__name__

@@ -682,7 +690,12 @@ class Var:

     @staticmethod
     def type_to_ctype(t: type, value_type: builtins.bool = False) -> str:
-        if is_array(t):
+        if isinstance(t, fixedarray):
+            template_args = (str(t.size), Var.dtype_to_ctype(t.dtype))
+            dtypestr = ", ".join(template_args)
+            classstr = f"wp::{type(t).__name__}"
+            return f"{classstr}_t<{dtypestr}>"
+        elif is_array(t):
             dtypestr = Var.dtype_to_ctype(t.dtype)
             classstr = f"wp::{type(t).__name__}"
             return f"{classstr}_t<{dtypestr}>"
@@ -778,11 +791,10 @@ def apply_defaults(
     arguments = bound_args.arguments
     new_arguments = []
     for name in bound_args._signature.parameters.keys():
-        try:
+        if name in arguments:
             new_arguments.append((name, arguments[name]))
-        except KeyError:
-            if name in values:
-                new_arguments.append((name, values[name]))
+        elif name in values:
+            new_arguments.append((name, values[name]))

     bound_args.arguments = dict(new_arguments)

@@ -835,6 +847,9 @@ def get_arg_type(arg: Var | Any) -> type:
     if isinstance(arg, Sequence):
         return tuple(get_arg_type(x) for x in arg)

+    if is_array(arg):
+        return arg
+
     if get_origin(arg) is tuple:
         return tuple(get_arg_type(x) for x in get_args(arg))

@@ -894,6 +909,8 @@ class Adjoint:
         adj.skip_forward_codegen = skip_forward_codegen
         # whether the generation of the adjoint code is skipped for this function
         adj.skip_reverse_codegen = skip_reverse_codegen
+        # Whether this function is used by a kernel that has the backward pass enabled.
+        adj.used_by_backward_kernel = False

         # extract name of source file
         adj.filename = inspect.getsourcefile(func) or "unknown source file"
@@ -960,13 +977,18 @@ class Adjoint:
                 continue

             # add variable for argument
-            arg = Var(name, type, False)
+            arg = Var(name, type, requires_grad=False)
             adj.args.append(arg)

             # pre-populate symbol dictionary with function argument names
             # this is to avoid registering false references to overshadowed modules
             adj.symbols[name] = arg

+        # Indicates whether there are unresolved static expressions in the function.
+        # These stem from wp.static() expressions that could not be evaluated at declaration time.
+        # This will signal to the module builder that this module needs to be rebuilt even if the module hash is unchanged.
+        adj.has_unresolved_static_expressions = False
+
         # try to replace static expressions by their constant result if the
         # expression can be evaluated at declaration time
         adj.static_expressions: dict[str, Any] = {}
@@ -1064,17 +1086,21 @@ class Adjoint:
         # recursively evaluate function body
         try:
             adj.eval(adj.tree.body[0])
-        except Exception:
+        except Exception as original_exc:
             try:
                 lineno = adj.lineno + adj.fun_lineno
                 line = adj.source_lines[adj.lineno]
                 msg = f'Error while parsing function "{adj.fun_name}" at {adj.filename}:{lineno}:\n{line}\n'
-                ex, data, traceback = sys.exc_info()
-                e = ex(";".join([msg] + [str(a) for a in data.args])).with_traceback(traceback)
+
+                # Combine the new message with the original exception's arguments
+                new_args = (";".join([msg] + [str(a) for a in original_exc.args]),)
+
+                # Enhance the original exception with parser context before re-raising.
+                # 'from None' is used to suppress Python's chained exceptions for a cleaner error output.
+                raise type(original_exc)(*new_args).with_traceback(original_exc.__traceback__) from None
             finally:
                 adj.skip_build = True
                 adj.builder = None
-            raise e

         if builder is not None:
             for a in adj.args:
@@ -1220,9 +1246,9 @@ class Adjoint:

         # lineinfo is enabled by default in debug mode regardless of the builder option, don't want to unnecessarily
         # emit line directives in generated code if it's not being compiled with line information
-        lineinfo_enabled = (
-            adj.builder_options.get("lineinfo", False) or adj.builder_options.get("mode", "release") == "debug"
-        )
+        build_mode = val if (val := adj.builder_options.get("mode")) is not None else warp.config.mode
+
+        lineinfo_enabled = adj.builder_options.get("lineinfo", False) or build_mode == "debug"

         if relative_lineno is not None and lineinfo_enabled and warp.config.line_directives:
             is_comment = statement.strip().startswith("//")
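
Note: with this change, a module that does not set its own "mode" build option now inherits the global `warp.config.mode` instead of assuming "release". A minimal sketch of the two knobs involved (both are standard Warp APIs):

    import warp as wp

    # global default; modules that don't set their own "mode" option inherit this
    wp.config.mode = "debug"

    # a per-module "mode" option still takes precedence for the current module
    wp.set_module_options({"mode": "release"})
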
@@ -1341,7 +1367,7 @@ class Adjoint:
             # unresolved function, report error
             arg_type_reprs = []

-            for x in arg_types:
+            for x in itertools.chain(arg_types, kwarg_types.values()):
                 if isinstance(x, warp.context.Function):
                     arg_type_reprs.append("function")
                 else:
@@ -1371,7 +1397,7 @@ class Adjoint:
         # in order to process them as Python does it.
         bound_args: inspect.BoundArguments = func.signature.bind(*args, **kwargs)

-        # Type args are the compile time argument values we get from codegen.
+        # Type args are the "compile time" argument values we get from codegen.
         # For example, when calling `wp.vec3f(...)` from within a kernel,
         # this translates in fact to calling the `vector()` built-in augmented
         # with the type args `length=3, dtype=float`.
@@ -1409,20 +1435,30 @@ class Adjoint:
         bound_args = bound_args.arguments

         # if it is a user-function then build it recursively
-        if not func.is_builtin() and func not in adj.builder.functions:
-            adj.builder.build_function(func)
-            # add custom grad, replay functions to the list of functions
-            # to be built later (invalid code could be generated if we built them now)
-            # so that they are not missed when only the forward function is imported
-            # from another module
-            if func.custom_grad_func:
-                adj.builder.deferred_functions.append(func.custom_grad_func)
-            if func.custom_replay_func:
-                adj.builder.deferred_functions.append(func.custom_replay_func)
+        if not func.is_builtin():
+            # If the function called is a user function,
+            # we need to ensure its adjoint is also being generated.
+            if adj.used_by_backward_kernel:
+                func.adj.used_by_backward_kernel = True
+
+            if adj.builder is None:
+                func.build(None)
+
+            elif func not in adj.builder.functions:
+                adj.builder.build_function(func)
+                # add custom grad, replay functions to the list of functions
+                # to be built later (invalid code could be generated if we built them now)
+                # so that they are not missed when only the forward function is imported
+                # from another module
+                if func.custom_grad_func:
+                    adj.builder.deferred_functions.append(func.custom_grad_func)
+                if func.custom_replay_func:
+                    adj.builder.deferred_functions.append(func.custom_replay_func)

         # Resolve the return value based on the types and values of the given arguments.
         bound_arg_types = {k: get_arg_type(v) for k, v in bound_args.items()}
         bound_arg_values = {k: get_arg_value(v) for k, v in bound_args.items()}
+
         return_type = func.value_func(
             {k: strip_reference(v) for k, v in bound_arg_types.items()},
             bound_arg_values,
@@ -1486,6 +1522,9 @@ class Adjoint:

         # if the argument is a function (and not a builtin), then build it recursively
         if isinstance(func_arg_var, warp.context.Function) and not func_arg_var.is_builtin():
+            if adj.used_by_backward_kernel:
+                func_arg_var.adj.used_by_backward_kernel = True
+
             adj.builder.build_function(func_arg_var)

             fwd_args.append(strip_reference(func_arg_var))
@@ -1879,6 +1918,9 @@ class Adjoint:
             return obj
         if isinstance(obj, type):
             return obj
+        if isinstance(obj, Struct):
+            adj.builder.build_struct_recursive(obj)
+            return obj
         if isinstance(obj, types.ModuleType):
             return obj

@@ -1931,11 +1973,17 @@ class Adjoint:
         aggregate = adj.eval(node.value)

         try:
+            if isinstance(aggregate, Var) and aggregate.constant is not None:
+                # this case may occur when the attribute is a constant, e.g.: `IntEnum.A.value`
+                return aggregate
+
             if isinstance(aggregate, types.ModuleType) or isinstance(aggregate, type):
                 out = getattr(aggregate, node.attr)

                 if warp.types.is_value(out):
                     return adj.add_constant(out)
+                if isinstance(out, (enum.IntEnum, enum.IntFlag)):
+                    return adj.add_constant(int(out))

                 return out

@@ -1963,18 +2011,29 @@ class Adjoint:
             return adj.add_builtin_call("transform_get_rotation", [aggregate])

         else:
-            attr_type = Reference(aggregate_type.vars[node.attr].type)
+            attr_var = aggregate_type.vars[node.attr]
+
+            # represent pointer types as uint64
+            if isinstance(attr_var.type, pointer_t):
+                cast = f"({Var.dtype_to_ctype(uint64)}*)"
+                adj_cast = f"({Var.dtype_to_ctype(attr_var.type.dtype)}*)"
+                attr_type = Reference(uint64)
+            else:
+                cast = ""
+                adj_cast = ""
+                attr_type = Reference(attr_var.type)
+
             attr = adj.add_var(attr_type)

             if is_reference(aggregate.type):
-                adj.add_forward(f"{attr.emit()} = &({aggregate.emit()}->{node.attr});")
+                adj.add_forward(f"{attr.emit()} = {cast}&({aggregate.emit()}->{attr_var.label});")
             else:
-                adj.add_forward(f"{attr.emit()} = &({aggregate.emit()}.{node.attr});")
+                adj.add_forward(f"{attr.emit()} = {cast}&({aggregate.emit()}.{attr_var.label});")

             if adj.is_differentiable_value_type(strip_reference(attr_type)):
-                adj.add_reverse(f"{aggregate.emit_adj()}.{node.attr} += {attr.emit_adj()};")
+                adj.add_reverse(f"{aggregate.emit_adj()}.{attr_var.label} += {adj_cast}{attr.emit_adj()};")
             else:
-                adj.add_reverse(f"{aggregate.emit_adj()}.{node.attr} = {attr.emit_adj()};")
+                adj.add_reverse(f"{aggregate.emit_adj()}.{attr_var.label} = {adj_cast}{attr.emit_adj()};")

             return attr

@@ -2302,9 +2361,12 @@ class Adjoint:

             return var

-        if isinstance(expr, (type, Var, warp.context.Function)):
+        if isinstance(expr, (type, Struct, Var, warp.context.Function)):
             return expr

+        if isinstance(expr, (enum.IntEnum, enum.IntFlag)):
+            return adj.add_constant(int(expr))
+
         return adj.add_constant(expr)

     def emit_Call(adj, node):
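
Note: together with the `emit_Attribute` change above, this lets `enum.IntEnum`/`enum.IntFlag` members fold into plain integer constants during code generation (this release also adds `warp/tests/test_enum.py`). A minimal sketch with a hypothetical kernel:

    import enum

    import warp as wp

    class Phase(enum.IntEnum):
        SOLID = 0
        FLUID = 1

    @wp.kernel
    def count_fluid(phases: wp.array(dtype=wp.int32), out: wp.array(dtype=wp.int32)):
        tid = wp.tid()
        # Phase.FLUID is folded into the integer constant 1 at codegen time
        if phases[tid] == Phase.FLUID:
            wp.atomic_add(out, 0, 1)
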
@@ -2322,8 +2384,9 @@ class Adjoint:

         if adj.is_static_expression(func):
             # try to evaluate wp.static() expressions
-            obj, _ = adj.evaluate_static_expression(node)
+            obj, code = adj.evaluate_static_expression(node)
             if obj is not None:
+                adj.static_expressions[code] = obj
                 if isinstance(obj, warp.context.Function):
                     # special handling for wp.static() evaluating to a function
                     return obj
@@ -2352,7 +2415,8 @@ class Adjoint:

         # struct constructor
         if func is None and isinstance(caller, Struct):
-            adj.builder.build_struct_recursive(caller)
+            if adj.builder is not None:
+                adj.builder.build_struct_recursive(caller)
             if node.args or node.keywords:
                 func = caller.value_constructor
             else:
@@ -2412,68 +2476,45 @@ class Adjoint:

         return adj.eval(node.value)

-    # returns the object being indexed, and the list of indices
-    def eval_subscript(adj, node):
-        # We want to coalesce multi-dimensional array indexing into a single operation. This needs to deal with expressions like `a[i][j][x][y]` where `a` is a 2D array of matrices,
-        # and essentially rewrite it into `a[i, j][x][y]`. Since the AST observes the indexing right-to-left, and we don't want to evaluate the index expressions prematurely,
-        # this requires a first loop to check if this `node` only performs indexing on the array, and a second loop to evaluate and collect index variables.
-        root = node
-        count = 0
-        array = None
-        while isinstance(root, ast.Subscript):
-            if isinstance(root.slice, ast.Tuple):
-                # handles the x[i, j] case (Python 3.8.x upward)
-                count += len(root.slice.elts)
-            elif isinstance(root.slice, ast.Index) and isinstance(root.slice.value, ast.Tuple):
-                # handles the x[i, j] case (Python 3.7.x)
-                count += len(root.slice.value.elts)
-            else:
-                # simple expression, e.g.: x[i]
-                count += 1
-
-            if isinstance(root.value, ast.Name):
-                symbol = adj.emit_Name(root.value)
-                symbol_type = strip_reference(symbol.type)
-                if is_array(symbol_type):
-                    array = symbol
-                    break
-
-            root = root.value
-
-        # If not all indices index into the array, just evaluate the right-most indexing operation.
-        if not array or (count > array.type.ndim):
-            count = 1
-
-        indices = []
-        root = node
-        while len(indices) < count:
-            if isinstance(root.slice, ast.Tuple):
-                ij = [adj.eval(arg) for arg in root.slice.elts]
-            elif isinstance(root.slice, ast.Index) and isinstance(root.slice.value, ast.Tuple):
-                ij = [adj.eval(arg) for arg in root.slice.value.elts]
-            else:
-                ij = [adj.eval(root.slice)]
-
-            indices = ij + indices  # prepend
-
-            root = root.value
-
-        target = adj.eval(root)
+    def eval_indices(adj, target_type, indices):
+        nodes = indices
+        if hasattr(target_type, "_wp_generic_type_hint_"):
+            indices = []
+            for dim, node in enumerate(nodes):
+                if isinstance(node, ast.Slice):
+                    # In the context of slicing a vec/mat type, indices are expected
+                    # to be compile-time constants, hence we can infer the actual slice
+                    # bounds also at compile-time.
+                    length = target_type._shape_[dim]
+                    step = 1 if node.step is None else adj.eval(node.step).constant
+
+                    if node.lower is None:
+                        start = length - 1 if step < 0 else 0
+                    else:
+                        start = adj.eval(node.lower).constant
+                        start = min(max(start, -length), length)
+                        start = start + length if start < 0 else start

-        return target, indices
+                    if node.upper is None:
+                        stop = -1 if step < 0 else length
+                    else:
+                        stop = adj.eval(node.upper).constant
+                        stop = min(max(stop, -length), length)
+                        stop = stop + length if stop < 0 else stop

-    def emit_Subscript(adj, node):
-        if hasattr(node.value, "attr") and node.value.attr == "adjoint":
-            # handle adjoint of a variable, i.e. wp.adjoint[var]
-            node.slice.is_adjoint = True
-            var = adj.eval(node.slice)
-            var_name = var.label
-            var = Var(f"adj_{var_name}", type=var.type, constant=None, prefix=False)
-            return var
+                    slice = adj.add_builtin_call("slice", (start, stop, step))
+                    indices.append(slice)
+                else:
+                    indices.append(adj.eval(node))

-        target, indices = adj.eval_subscript(node)
+            return tuple(indices)
+        else:
+            return tuple(adj.eval(x) for x in nodes)

+    def emit_indexing(adj, target, indices):
         target_type = strip_reference(target.type)
+        indices = adj.eval_indices(target_type, indices)
+
         if is_array(target_type):
             if len(indices) == target_type.ndim:
                 # handles array loads (where each dimension has an index specified)
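
Note: `eval_indices` resolves slice bounds on vec/mat types entirely at compile time, which is what backs the vector/matrix slicing exercised by the expanded `test_vec.py`/`test_mat.py` suites. A minimal sketch, assuming NumPy-style semantics where slicing a `wp.vec4` with `[1:3]` yields a `wp.vec2`:

    import warp as wp

    @wp.kernel
    def slice_demo(out: wp.array(dtype=wp.vec2)):
        v = wp.vec4(1.0, 2.0, 3.0, 4.0)
        # start/stop/step must be compile-time constants; bounds are clamped like Python slices
        out[0] = v[1:3]
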
@@ -2512,47 +2553,116 @@ class Adjoint:

         return out

+    # from a list of lists of indices, strip the first `count` indices
+    @staticmethod
+    def strip_indices(indices, count):
+        dim = count
+        while count > 0:
+            ij = indices[0]
+            indices = indices[1:]
+            count -= len(ij)
+
+            # report straddling like in `arr2d[0][1,2]` as a syntax error
+            if count < 0:
+                raise WarpCodegenError(
+                    f"Incorrect number of indices specified for array indexing, got {dim - count} indices for a {dim} dimensional array."
+                )
+
+        return indices
+
+    def recurse_subscript(adj, node, indices):
+        if isinstance(node, ast.Name):
+            target = adj.eval(node)
+            return target, indices
+
+        if isinstance(node, ast.Subscript):
+            if hasattr(node.value, "attr") and node.value.attr == "adjoint":
+                return adj.eval(node), indices
+
+            if isinstance(node.slice, ast.Tuple):
+                ij = node.slice.elts
+            elif isinstance(node.slice, ast.Index) and isinstance(node.slice.value, ast.Tuple):
+                # The node `ast.Index` is deprecated in Python 3.9.
+                ij = node.slice.value.elts
+            elif isinstance(node.slice, ast.ExtSlice):
+                # The node `ast.ExtSlice` is deprecated in Python 3.9.
+                ij = node.slice.dims
+            else:
+                ij = [node.slice]
+
+            indices = [ij, *indices]  # prepend
+
+            target, indices = adj.recurse_subscript(node.value, indices)
+
+            target_type = strip_reference(target.type)
+            if is_array(target_type):
+                flat_indices = [i for ij in indices for i in ij]
+                if len(flat_indices) > target_type.ndim:
+                    target = adj.emit_indexing(target, flat_indices[: target_type.ndim])
+                    indices = adj.strip_indices(indices, target_type.ndim)
+
+            return target, indices
+
+        target = adj.eval(node)
+        return target, indices
+
+    # returns the object being indexed, and the list of indices
+    def eval_subscript(adj, node):
+        target, indices = adj.recurse_subscript(node, [])
+        flat_indices = [i for ij in indices for i in ij]
+        return target, flat_indices
+
+    def emit_Subscript(adj, node):
+        if hasattr(node.value, "attr") and node.value.attr == "adjoint":
+            # handle adjoint of a variable, i.e. wp.adjoint[var]
+            node.slice.is_adjoint = True
+            var = adj.eval(node.slice)
+            var_name = var.label
+            var = Var(f"adj_{var_name}", type=var.type, constant=None, prefix=False)
+            return var
+
+        target, indices = adj.eval_subscript(node)
+
+        return adj.emit_indexing(target, indices)
+
     def emit_Assign(adj, node):
         if len(node.targets) != 1:
             raise WarpCodegenError("Assigning the same value to multiple variables is not supported")

-        lhs = node.targets[0]
+        # Check if the rhs corresponds to an unsupported construct.
+        # Tuples are supported in the context of assigning multiple variables
+        # at once, but not for simple assignments like `x = (1, 2, 3)`.
+        # Therefore, we need to catch this specific case here instead of
+        # more generally in `adj.eval()`.
+        if isinstance(node.value, ast.List):
+            raise WarpCodegenError(
+                "List constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead."
+            )

-        if not isinstance(lhs, ast.Tuple):
-            # Check if the rhs corresponds to an unsupported construct.
-            # Tuples are supported in the context of assigning multiple variables
-            # at once, but not for simple assignments like `x = (1, 2, 3)`.
-            # Therefore, we need to catch this specific case here instead of
-            # more generally in `adj.eval()`.
-            if isinstance(node.value, ast.List):
-                raise WarpCodegenError(
-                    "List constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead."
-                )
+        lhs = node.targets[0]

-        # handle the case where we are assigning multiple output variables
-        if isinstance(lhs, ast.Tuple):
+        if isinstance(lhs, ast.Tuple) and isinstance(node.value, ast.Call):
             # record the expected number of outputs on the node
             # we do this so we can decide which function to
             # call based on the number of expected outputs
-            if isinstance(node.value, ast.Call):
-                node.value.expects = len(lhs.elts)
+            node.value.expects = len(lhs.elts)

-        # evaluate values
-        if isinstance(node.value, ast.Tuple):
-            out = [adj.eval(v) for v in node.value.elts]
-        else:
-            out = adj.eval(node.value)
+        # evaluate rhs
+        if isinstance(lhs, ast.Tuple) and isinstance(node.value, ast.Tuple):
+            rhs = [adj.eval(v) for v in node.value.elts]
+        else:
+            rhs = adj.eval(node.value)
+
+        # handle the case where we are assigning multiple output variables
+        if isinstance(lhs, ast.Tuple):
+            subtype = getattr(rhs, "type", None)

-            subtype = getattr(out, "type", None)
             if isinstance(subtype, warp.types.tuple_t):
-                if len(out.type.types) != len(lhs.elts):
+                if len(rhs.type.types) != len(lhs.elts):
                     raise WarpCodegenError(
-                        f"Invalid number of values to unpack (expected {len(lhs.elts)}, got {len(out.type.types)})."
+                        f"Invalid number of values to unpack (expected {len(lhs.elts)}, got {len(rhs.type.types)})."
                     )
-                target = out
-                out = tuple(
-                    adj.add_builtin_call("extract", (target, adj.add_constant(i))) for i in range(len(lhs.elts))
-                )
+                rhs = tuple(adj.add_builtin_call("extract", (rhs, adj.add_constant(i))) for i in range(len(lhs.elts)))

             names = []
             for v in lhs.elts:
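
Note: `recurse_subscript` coalesces chained indexing on arrays into a single load, so `a[i][j]` on a 2D array compiles like `a[i, j]`, while "straddled" forms such as `arr2d[0][1, 2]` now raise a codegen error. A minimal sketch with a hypothetical kernel:

    import warp as wp

    @wp.kernel
    def index_demo(a: wp.array2d(dtype=wp.mat33), out: wp.array(dtype=float)):
        i = wp.tid()
        m = a[i][0]       # coalesced into the single array load a[i, 0]
        out[i] = m[1, 2]  # the remaining indices apply to the mat33 value
        # a[i][0, 0] would straddle array and matrix indexing and is rejected
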
@@ -2563,11 +2673,12 @@ class Adjoint:
                     "Multiple return functions can only assign to simple variables, e.g.: x, y = func()"
                 )

-            if len(names) != len(out):
+            if len(names) != len(rhs):
                 raise WarpCodegenError(
-                    f"Multiple return functions need to receive all their output values, incorrect number of values to unpack (expected {len(out)}, got {len(names)})"
+                    f"Multiple return functions need to receive all their output values, incorrect number of values to unpack (expected {len(rhs)}, got {len(names)})"
                 )

+            out = rhs
             for name, rhs in zip(names, out):
                 if name in adj.symbols:
                     if not types_equal(rhs.type, adj.symbols[name].type):
@@ -2579,8 +2690,6 @@ class Adjoint:

         # handles the case where we are assigning to an array index (e.g.: arr[i] = 2.0)
         elif isinstance(lhs, ast.Subscript):
-            rhs = adj.eval(node.value)
-
             if hasattr(lhs.value, "attr") and lhs.value.attr == "adjoint":
                 # handle adjoint of a variable, i.e. wp.adjoint[var]
                 lhs.slice.is_adjoint = True
@@ -2592,6 +2701,7 @@ class Adjoint:
             target, indices = adj.eval_subscript(lhs)

             target_type = strip_reference(target.type)
+            indices = adj.eval_indices(target_type, indices)

             if is_array(target_type):
                 adj.add_builtin_call("array_store", [target, *indices, rhs])
@@ -2613,14 +2723,11 @@ class Adjoint:
                 or type_is_transformation(target_type)
             ):
                 # recursively unwind AST, stopping at penultimate node
-                node = lhs
-                while hasattr(node, "value"):
-                    if hasattr(node.value, "value"):
-                        node = node.value
-                    else:
-                        break
+                root = lhs
+                while hasattr(root.value, "value"):
+                    root = root.value
                 # lhs is updating a variable adjoint (i.e. wp.adjoint[var])
-                if hasattr(node, "attr") and node.attr == "adjoint":
+                if hasattr(root, "attr") and root.attr == "adjoint":
                     attr = adj.add_builtin_call("index", [target, *indices])
                     adj.add_builtin_call("store", [attr, rhs])
                     return
@@ -2658,9 +2765,6 @@ class Adjoint:
             # symbol name
             name = lhs.id

-            # evaluate rhs
-            rhs = adj.eval(node.value)
-
             # check type matches if symbol already defined
             if name in adj.symbols:
                 if not types_equal(strip_reference(rhs.type), adj.symbols[name].type):
@@ -2681,7 +2785,6 @@ class Adjoint:
             adj.symbols[name] = out

         elif isinstance(lhs, ast.Attribute):
-            rhs = adj.eval(node.value)
             aggregate = adj.eval(lhs.value)
             aggregate_type = strip_reference(aggregate.type)

@@ -2769,9 +2872,9 @@ class Adjoint:
             new_node = ast.Assign(targets=[lhs], value=ast.BinOp(lhs, node.op, node.value))
             adj.eval(new_node)

-        if isinstance(lhs, ast.Subscript):
-            rhs = adj.eval(node.value)
+        rhs = adj.eval(node.value)

+        if isinstance(lhs, ast.Subscript):
             # wp.adjoint[var] appears in custom grad functions, and does not require
             # special consideration in the AugAssign case
             if hasattr(lhs.value, "attr") and lhs.value.attr == "adjoint":
@@ -2781,6 +2884,7 @@ class Adjoint:

             target, indices = adj.eval_subscript(lhs)

             target_type = strip_reference(target.type)
+            indices = adj.eval_indices(target_type, indices)

             if is_array(target_type):
@@ -2853,7 +2957,6 @@ class Adjoint:

         elif isinstance(lhs, ast.Name):
             target = adj.eval(node.target)
-            rhs = adj.eval(node.value)

             if is_tile(target.type) and is_tile(rhs.type):
                 if isinstance(node.op, ast.Add):
@@ -3109,6 +3212,7 @@ class Adjoint:

         # Since this is an expression, we can enforce it to be defined on a single line.
         static_code = static_code.replace("\n", "")
+        code_to_eval = static_code  # code to be evaluated

         vars_dict = adj.get_static_evaluation_context()
         # add constant variables to the static call context
@@ -3150,10 +3254,12 @@ class Adjoint:
                 loc = end

             new_static_code += static_code[len_value_locs[-1][2] :]
-            static_code = new_static_code
+            code_to_eval = new_static_code

         try:
-            value = eval(static_code, vars_dict)
+            value = eval(code_to_eval, vars_dict)
+            if isinstance(value, (enum.IntEnum, enum.IntFlag)):
+                value = int(value)
             if warp.config.verbose:
                 print(f"Evaluated static command: {static_code} = {value}")
         except NameError as e:
@@ -3206,6 +3312,9 @@ class Adjoint:
                 # (and is therefore not executable and raises this exception), in which
                 # case changing the constant, or the code affecting this constant, would lead to
                 # a different module hash anyway.
+                # In any case, we mark this Adjoint to have unresolvable static expressions.
+                # This will trigger a code generation step even if the module hash is unchanged.
+                adj.has_unresolved_static_expressions = True
                 pass

             return self.generic_visit(node)
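
Note: `wp.static()` evaluates an expression once at kernel declaration time and bakes the result into the generated code; when the expression cannot be resolved yet, the new `has_unresolved_static_expressions` flag forces code generation to run even if the module hash is unchanged. A minimal sketch of the resolvable case:

    import warp as wp

    SCALE = 3.0

    @wp.kernel
    def scale_kernel(x: wp.array(dtype=float)):
        tid = wp.tid()
        # evaluated once at declaration time and emitted as a literal constant
        x[tid] = x[tid] * wp.static(SCALE * 2.0)
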
@@ -3361,6 +3470,11 @@ cuda_module_header = """
 #define WP_NO_CRT
 #include "builtin.h"

+// Map wp.breakpoint() to a device brkpt at the call site so cuda-gdb attributes the stop to the generated .cu line
+#if defined(__CUDACC__) && !defined(_MSC_VER)
+#define __debugbreak() __brkpt()
+#endif
+
 // avoid namespacing of float type for casting to float type, this is to avoid wp::float(x), which is not valid in C++
 #define float(x) cast_float(x)
 #define adj_float(x, adj_x, adj_ret) adj_cast_float(x, adj_x, adj_ret)
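
Note: this mapping affects `wp.breakpoint()` in CUDA debug builds; with `__brkpt()` emitted at the call site, cuda-gdb attributes the stop to the generated .cu line. A minimal sketch (assumes a debug build and an attached debugger):

    import warp as wp

    wp.config.mode = "debug"

    @wp.kernel
    def trap_kernel(a: wp.array(dtype=float)):
        tid = wp.tid()
        if tid == 0:
            wp.breakpoint()  # execution stops here under cuda-gdb
        a[tid] = 2.0 * a[tid]
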
@@ -3398,6 +3512,12 @@ static CUDA_CALLABLE void adj_{name}({reverse_args})
 {{
 {reverse_body}}}

+// Required when compiling adjoints.
+CUDA_CALLABLE {name} add(const {name}& a, const {name}& b)
+{{
+    return {name}();
+}}
+
 CUDA_CALLABLE void adj_atomic_add({name}* p, {name} t)
 {{
 {atomic_add_body}}}
@@ -3478,7 +3598,8 @@ cuda_kernel_template_backward = """
 cpu_kernel_template_forward = """

 void {name}_cpu_kernel_forward(
-    {forward_args})
+    {forward_args},
+    wp_args_{name} *_wp_args)
 {{
 {forward_body}}}

@@ -3487,7 +3608,9 @@ void {name}_cpu_kernel_forward(
 cpu_kernel_template_backward = """

 void {name}_cpu_kernel_backward(
-    {reverse_args})
+    {reverse_args},
+    wp_args_{name} *_wp_args,
+    wp_args_{name} *_wp_adj_args)
 {{
 {reverse_body}}}

@@ -3499,15 +3622,15 @@ extern "C" {{

 // Python CPU entry points
 WP_API void {name}_cpu_forward(
-    {forward_args})
+    wp::launch_bounds_t dim,
+    wp_args_{name} *_wp_args)
 {{
     for (size_t task_index = 0; task_index < dim.size; ++task_index)
     {{
         // init shared memory allocator
         wp::tile_alloc_shared(0, true);

-        {name}_cpu_kernel_forward(
-            {forward_params});
+        {name}_cpu_kernel_forward(dim, task_index, _wp_args);

         // check shared memory allocator
         wp::tile_alloc_shared(0, false, true);
@@ -3524,15 +3647,16 @@ cpu_module_template_backward = """
 extern "C" {{

 WP_API void {name}_cpu_backward(
-    {reverse_args})
+    wp::launch_bounds_t dim,
+    wp_args_{name} *_wp_args,
+    wp_args_{name} *_wp_adj_args)
 {{
     for (size_t task_index = 0; task_index < dim.size; ++task_index)
     {{
         // initialize shared memory allocator
         wp::tile_alloc_shared(0, true);

-        {name}_cpu_kernel_backward(
-            {reverse_params});
+        {name}_cpu_kernel_backward(dim, task_index, _wp_args, _wp_adj_args);

         // check shared memory allocator
         wp::tile_alloc_shared(0, false, true);
@@ -3563,7 +3687,7 @@ def constant_str(value):
         # special case for float16, which is stored as uint16 in the ctypes.Array
         from warp.context import runtime

-        scalar_value = runtime.core.half_bits_to_float
+        scalar_value = runtime.core.wp_half_bits_to_float
     else:

         def scalar_value(x):
@@ -3701,8 +3825,17 @@ def codegen_func_forward(adj, func_type="kernel", device="cpu"):

     indent_block = " " * indent

-    # primal vars
     lines = []
+
+    # argument vars
+    if device == "cpu" and func_type == "kernel":
+        lines += ["//---------\n"]
+        lines += ["// argument vars\n"]
+
+        for var in adj.args:
+            lines += [f"{var.ctype()} {var.emit()} = _wp_args->{var.label};\n"]
+
+    # primal vars
     lines += ["//---------\n"]
     lines += ["// primal vars\n"]

@@ -3746,6 +3879,17 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):

     lines = []

+    # argument vars
+    if device == "cpu" and func_type == "kernel":
+        lines += ["//---------\n"]
+        lines += ["// argument vars\n"]
+
+        for var in adj.args:
+            lines += [f"{var.ctype()} {var.emit()} = _wp_args->{var.label};\n"]
+
+        for var in adj.args:
+            lines += [f"{var.ctype()} {var.emit_adj()} = _wp_adj_args->{var.label};\n"]
+
     # primal vars
     lines += ["//---------\n"]
     lines += ["// primal vars\n"]
@@ -3837,6 +3981,19 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
                 f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
                 f"but the code returns a value of type `{warp.context.type_str(adj.return_var[0].type)}`."
             )
+        elif (
+            isinstance(adj.return_var[0].type, warp.types.fixedarray)
+            and type(adj.arg_types["return"]) is warp.types.array
+        ):
+            # If the return statement yields a `fixedarray` while the function is annotated
+            # to return a standard `array`, then raise an error, since the `fixedarray` storage
+            # allocated on the stack will be freed once the function exits, meaning that the
+            # resulting `array` instance would point to invalid data.
+            raise WarpCodegenError(
+                f"The function `{adj.fun_name}` returns a fixed-size array "
+                f"whereas it has its return type annotated as "
+                f"`{warp.context.type_str(adj.arg_types['return'])}`."
+            )

     # Build line directive for function definition (subtract 1 to account for 1-indexing of AST line numbers)
     # This is used as a catch-all C-to-Python source line mapping for any code that does not have
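
Note: this check guards the `fixedarray` type introduced in 1.9 (see the new `warp/tests/test_fixedarray.py`): its storage lives on the function's stack, so returning it as a plain `wp.array` would leave the caller with dangling data. A hedged sketch of the rejected pattern, assuming the in-kernel constructor spelling `wp.fixedarray(shape=..., dtype=...)`:

    import warp as wp

    @wp.func
    def make_values() -> wp.array(dtype=float):
        # rejected by the check above: the fixed-size storage is stack-allocated
        # and is freed as soon as this function returns
        values = wp.fixedarray(shape=(4,), dtype=float)
        return values
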
@@ -3915,10 +4072,10 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
     if adj.custom_reverse_mode:
         reverse_body = "\t// user-defined adjoint code\n" + forward_body
     else:
-        if options.get("enable_backward", True):
+        if options.get("enable_backward", True) and adj.used_by_backward_kernel:
             reverse_body = codegen_func_reverse(adj, func_type="function", device=device)
         else:
-            reverse_body = '\t// reverse mode disabled (module option "enable_backward" is False)\n'
+            reverse_body = '\t// reverse mode disabled (module option "enable_backward" is False or no dependent kernel found with "enable_backward")\n'
     s += reverse_template.format(
         name=c_func_name,
         return_type=return_type,
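
Note: with `used_by_backward_kernel`, a user function only receives adjoint code when some kernel that needs a backward pass actually reaches it, rather than whenever `enable_backward` is on. Opting a whole module out still works the same way (standard API):

    import warp as wp

    # no adjoint code is generated for kernels or functions in this module
    wp.set_module_options({"enable_backward": False})

    @wp.func
    def square(x: float) -> float:
        return x * x

    @wp.kernel
    def apply_square(a: wp.array(dtype=float)):
        tid = wp.tid()
        a[tid] = square(a[tid])
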
@@ -4010,6 +4167,13 @@ def codegen_kernel(kernel, device, options):

     adj = kernel.adj

+    args_struct = ""
+    if device == "cpu":
+        args_struct = f"struct wp_args_{kernel.get_mangled_name()} {{\n"
+        for i in adj.args:
+            args_struct += f"    {i.ctype()} {i.label};\n"
+        args_struct += "};\n"
+
     # Build line directive for function definition (subtract 1 to account for 1-indexing of AST line numbers)
     # This is used as a catch-all C-to-Python source line mapping for any code that does not have
     # a direct mapping to a Python source line.
@@ -4035,9 +4199,9 @@ def codegen_kernel(kernel, device, options):
     forward_args = ["wp::launch_bounds_t dim"]
     if device == "cpu":
         forward_args.append("size_t task_index")
-
-    for arg in adj.args:
-        forward_args.append(arg.ctype() + " var_" + arg.label)
+    else:
+        for arg in adj.args:
+            forward_args.append(arg.ctype() + " var_" + arg.label)

     forward_body = codegen_func_forward(adj, func_type="kernel", device=device)
     template_fmt_args.update(
@@ -4054,17 +4218,16 @@ def codegen_kernel(kernel, device, options):
         reverse_args = ["wp::launch_bounds_t dim"]
         if device == "cpu":
             reverse_args.append("size_t task_index")
-
-        for arg in adj.args:
-            reverse_args.append(arg.ctype() + " var_" + arg.label)
-
-        for arg in adj.args:
-            # indexed array gradients are regular arrays
-            if isinstance(arg.type, indexedarray):
-                _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
-                reverse_args.append(_arg.ctype() + " adj_" + arg.label)
-            else:
-                reverse_args.append(arg.ctype() + " adj_" + arg.label)
+        else:
+            for arg in adj.args:
+                reverse_args.append(arg.ctype() + " var_" + arg.label)
+            for arg in adj.args:
+                # indexed array gradients are regular arrays
+                if isinstance(arg.type, indexedarray):
+                    _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
+                    reverse_args.append(_arg.ctype() + " adj_" + arg.label)
+                else:
+                    reverse_args.append(arg.ctype() + " adj_" + arg.label)

         reverse_body = codegen_func_reverse(adj, func_type="kernel", device=device)
         template_fmt_args.update(
@@ -4076,7 +4239,7 @@ def codegen_kernel(kernel, device, options):
         template += template_backward

     s = template.format(**template_fmt_args)
-    return s
+    return args_struct + s


 def codegen_module(kernel, device, options):
@@ -4087,59 +4250,14 @@ def codegen_module(kernel, device, options):
     options = dict(options)
     options.update(kernel.options)

-    adj = kernel.adj
-
     template = ""
     template_fmt_args = {
         "name": kernel.get_mangled_name(),
     }

-    # build forward signature
-    forward_args = ["wp::launch_bounds_t dim"]
-    forward_params = ["dim", "task_index"]
-
-    for arg in adj.args:
-        if hasattr(arg.type, "_wp_generic_type_str_"):
-            # vectors and matrices are passed from Python by pointer
-            forward_args.append(f"const {arg.ctype()}* var_" + arg.label)
-            forward_params.append(f"*var_{arg.label}")
-        else:
-            forward_args.append(f"{arg.ctype()} var_{arg.label}")
-            forward_params.append("var_" + arg.label)
-
-    template_fmt_args.update(
-        {
-            "forward_args": indent(forward_args),
-            "forward_params": indent(forward_params, 3),
-        }
-    )
     template += cpu_module_template_forward

     if options["enable_backward"]:
-        # build reverse signature
-        reverse_args = [*forward_args]
-        reverse_params = [*forward_params]
-
-        for arg in adj.args:
-            if isinstance(arg.type, indexedarray):
-                # indexed array gradients are regular arrays
-                _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
-                reverse_args.append(f"const {_arg.ctype()} adj_{arg.label}")
-                reverse_params.append(f"adj_{_arg.label}")
-            elif hasattr(arg.type, "_wp_generic_type_str_"):
-                # vectors and matrices are passed from Python by pointer
-                reverse_args.append(f"const {arg.ctype()}* adj_{arg.label}")
-                reverse_params.append(f"*adj_{arg.label}")
-            else:
-                reverse_args.append(f"{arg.ctype()} adj_{arg.label}")
-                reverse_params.append(f"adj_{arg.label}")
-
-        template_fmt_args.update(
-            {
-                "reverse_args": indent(reverse_args),
-                "reverse_params": indent(reverse_params, 3),
-            }
-        )
         template += cpu_module_template_backward

     s = template.format(**template_fmt_args)