warp-lang 1.8.1-py3-none-manylinux_2_34_aarch64.whl → 1.9.1-py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (141)
  1. warp/__init__.py +282 -103
  2. warp/__init__.pyi +1904 -114
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +93 -30
  6. warp/build_dll.py +331 -101
  7. warp/builtins.py +1244 -160
  8. warp/codegen.py +317 -206
  9. warp/config.py +1 -1
  10. warp/context.py +1465 -789
  11. warp/examples/core/example_marching_cubes.py +1 -0
  12. warp/examples/core/example_render_opengl.py +100 -3
  13. warp/examples/fem/example_apic_fluid.py +98 -52
  14. warp/examples/fem/example_convection_diffusion_dg.py +25 -4
  15. warp/examples/fem/example_diffusion_mgpu.py +8 -3
  16. warp/examples/fem/utils.py +68 -22
  17. warp/examples/interop/example_jax_kernel.py +2 -1
  18. warp/fabric.py +1 -1
  19. warp/fem/cache.py +27 -19
  20. warp/fem/domain.py +2 -2
  21. warp/fem/field/nodal_field.py +2 -2
  22. warp/fem/field/virtual.py +264 -166
  23. warp/fem/geometry/geometry.py +5 -5
  24. warp/fem/integrate.py +129 -51
  25. warp/fem/space/restriction.py +4 -0
  26. warp/fem/space/shape/tet_shape_function.py +3 -10
  27. warp/jax_experimental/custom_call.py +25 -2
  28. warp/jax_experimental/ffi.py +22 -1
  29. warp/jax_experimental/xla_ffi.py +16 -7
  30. warp/marching_cubes.py +708 -0
  31. warp/native/array.h +99 -4
  32. warp/native/builtin.h +86 -9
  33. warp/native/bvh.cpp +64 -28
  34. warp/native/bvh.cu +58 -58
  35. warp/native/bvh.h +2 -2
  36. warp/native/clang/clang.cpp +7 -7
  37. warp/native/coloring.cpp +8 -2
  38. warp/native/crt.cpp +2 -2
  39. warp/native/crt.h +3 -5
  40. warp/native/cuda_util.cpp +41 -10
  41. warp/native/cuda_util.h +10 -4
  42. warp/native/exports.h +1842 -1908
  43. warp/native/fabric.h +2 -1
  44. warp/native/hashgrid.cpp +37 -37
  45. warp/native/hashgrid.cu +2 -2
  46. warp/native/initializer_array.h +1 -1
  47. warp/native/intersect.h +2 -2
  48. warp/native/mat.h +1910 -116
  49. warp/native/mathdx.cpp +43 -43
  50. warp/native/mesh.cpp +24 -24
  51. warp/native/mesh.cu +26 -26
  52. warp/native/mesh.h +4 -2
  53. warp/native/nanovdb/GridHandle.h +179 -12
  54. warp/native/nanovdb/HostBuffer.h +8 -7
  55. warp/native/nanovdb/NanoVDB.h +517 -895
  56. warp/native/nanovdb/NodeManager.h +323 -0
  57. warp/native/nanovdb/PNanoVDB.h +2 -2
  58. warp/native/quat.h +331 -14
  59. warp/native/range.h +7 -1
  60. warp/native/reduce.cpp +10 -10
  61. warp/native/reduce.cu +13 -14
  62. warp/native/runlength_encode.cpp +2 -2
  63. warp/native/runlength_encode.cu +5 -5
  64. warp/native/scan.cpp +3 -3
  65. warp/native/scan.cu +4 -4
  66. warp/native/sort.cpp +10 -10
  67. warp/native/sort.cu +40 -31
  68. warp/native/sort.h +2 -0
  69. warp/native/sparse.cpp +8 -8
  70. warp/native/sparse.cu +13 -13
  71. warp/native/spatial.h +366 -17
  72. warp/native/temp_buffer.h +2 -2
  73. warp/native/tile.h +471 -82
  74. warp/native/vec.h +328 -14
  75. warp/native/volume.cpp +54 -54
  76. warp/native/volume.cu +1 -1
  77. warp/native/volume.h +2 -1
  78. warp/native/volume_builder.cu +30 -37
  79. warp/native/warp.cpp +150 -149
  80. warp/native/warp.cu +377 -216
  81. warp/native/warp.h +227 -226
  82. warp/optim/linear.py +736 -271
  83. warp/render/imgui_manager.py +289 -0
  84. warp/render/render_opengl.py +99 -18
  85. warp/render/render_usd.py +1 -0
  86. warp/sim/graph_coloring.py +2 -2
  87. warp/sparse.py +558 -175
  88. warp/tests/aux_test_module_aot.py +7 -0
  89. warp/tests/cuda/test_async.py +3 -3
  90. warp/tests/cuda/test_conditional_captures.py +101 -0
  91. warp/tests/geometry/test_hash_grid.py +38 -0
  92. warp/tests/geometry/test_marching_cubes.py +233 -12
  93. warp/tests/interop/test_jax.py +608 -28
  94. warp/tests/sim/test_coloring.py +6 -6
  95. warp/tests/test_array.py +58 -5
  96. warp/tests/test_codegen.py +4 -3
  97. warp/tests/test_context.py +8 -15
  98. warp/tests/test_enum.py +136 -0
  99. warp/tests/test_examples.py +2 -2
  100. warp/tests/test_fem.py +49 -6
  101. warp/tests/test_fixedarray.py +229 -0
  102. warp/tests/test_func.py +18 -15
  103. warp/tests/test_future_annotations.py +7 -5
  104. warp/tests/test_linear_solvers.py +30 -0
  105. warp/tests/test_map.py +15 -1
  106. warp/tests/test_mat.py +1518 -378
  107. warp/tests/test_mat_assign_copy.py +178 -0
  108. warp/tests/test_mat_constructors.py +574 -0
  109. warp/tests/test_module_aot.py +287 -0
  110. warp/tests/test_print.py +69 -0
  111. warp/tests/test_quat.py +140 -34
  112. warp/tests/test_quat_assign_copy.py +145 -0
  113. warp/tests/test_reload.py +2 -1
  114. warp/tests/test_sparse.py +71 -0
  115. warp/tests/test_spatial.py +140 -34
  116. warp/tests/test_spatial_assign_copy.py +160 -0
  117. warp/tests/test_struct.py +43 -3
  118. warp/tests/test_tuple.py +96 -0
  119. warp/tests/test_types.py +61 -20
  120. warp/tests/test_vec.py +179 -34
  121. warp/tests/test_vec_assign_copy.py +143 -0
  122. warp/tests/tile/test_tile.py +245 -18
  123. warp/tests/tile/test_tile_cholesky.py +605 -0
  124. warp/tests/tile/test_tile_load.py +169 -0
  125. warp/tests/tile/test_tile_mathdx.py +2 -558
  126. warp/tests/tile/test_tile_matmul.py +1 -1
  127. warp/tests/tile/test_tile_mlp.py +1 -1
  128. warp/tests/tile/test_tile_shared_memory.py +5 -5
  129. warp/tests/unittest_suites.py +6 -0
  130. warp/tests/walkthrough_debug.py +1 -1
  131. warp/thirdparty/unittest_parallel.py +108 -9
  132. warp/types.py +571 -267
  133. warp/utils.py +68 -86
  134. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/METADATA +29 -69
  135. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/RECORD +138 -128
  136. warp/native/marching.cpp +0 -19
  137. warp/native/marching.cu +0 -514
  138. warp/native/marching.h +0 -19
  139. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/WHEEL +0 -0
  140. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/licenses/LICENSE.md +0 -0
  141. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/top_level.txt +0 -0
warp/codegen.py CHANGED
@@ -18,9 +18,11 @@ from __future__ import annotations
 import ast
 import builtins
 import ctypes
+import enum
 import functools
 import hashlib
 import inspect
+import itertools
 import math
 import re
 import sys
@@ -614,6 +616,8 @@ def compute_type_str(base_name, template_params):
         return base_name

     def param2str(p):
+        if isinstance(p, builtins.bool):
+            return "true" if p else "false"
         if isinstance(p, int):
             return str(p)
         elif hasattr(p, "_wp_generic_type_str_"):
@@ -625,6 +629,8 @@ def compute_type_str(base_name, template_params):
             return f"wp::{p.__name__}"
         elif is_tile(p):
             return p.ctype()
+        elif isinstance(p, Struct):
+            return p.native_name

         return p.__name__

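Reviewer note on the `param2str` additions above: the `builtins.bool` branch must come before the `int` check because `bool` is a subclass of `int` in Python, so `isinstance(True, int)` is true and `True` would otherwise stringify via `str(p)` as `"True"`, which is not a valid C++ literal. A minimal standalone sketch of that ordering (hypothetical helper, not the actual Warp code):

    def param2str_sketch(p):
        # bool first: isinstance(True, int) is also True
        if isinstance(p, bool):
            return "true" if p else "false"
        if isinstance(p, int):
            return str(p)
        return p.__name__

    assert param2str_sketch(True) == "true"
    assert param2str_sketch(3) == "3"
    assert param2str_sketch(float) == "float"
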
@@ -684,7 +690,12 @@ class Var:

     @staticmethod
     def type_to_ctype(t: type, value_type: builtins.bool = False) -> str:
-        if is_array(t):
+        if isinstance(t, fixedarray):
+            template_args = (str(t.size), Var.dtype_to_ctype(t.dtype))
+            dtypestr = ", ".join(template_args)
+            classstr = f"wp::{type(t).__name__}"
+            return f"{classstr}_t<{dtypestr}>"
+        elif is_array(t):
             dtypestr = Var.dtype_to_ctype(t.dtype)
             classstr = f"wp::{type(t).__name__}"
             return f"{classstr}_t<{dtypestr}>"
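For context, the new `fixedarray` branch renders a fixed-size array as a two-parameter C++ template with the size baked into the type. A standalone sketch of the string assembly, using made-up values (size 8, `wp::float32` dtype):

    size, dtype_ctype = 8, "wp::float32"  # hypothetical inputs
    dtypestr = ", ".join((str(size), dtype_ctype))
    print(f"wp::fixedarray_t<{dtypestr}>")  # -> wp::fixedarray_t<8, wp::float32>
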
@@ -780,11 +791,10 @@ def apply_defaults(
     arguments = bound_args.arguments
     new_arguments = []
     for name in bound_args._signature.parameters.keys():
-        try:
+        if name in arguments:
             new_arguments.append((name, arguments[name]))
-        except KeyError:
-            if name in values:
-                new_arguments.append((name, values[name]))
+        elif name in values:
+            new_arguments.append((name, values[name]))

     bound_args.arguments = dict(new_arguments)

@@ -837,6 +847,9 @@ def get_arg_type(arg: Var | Any) -> type:
     if isinstance(arg, Sequence):
         return tuple(get_arg_type(x) for x in arg)

+    if is_array(arg):
+        return arg
+
     if get_origin(arg) is tuple:
         return tuple(get_arg_type(x) for x in get_args(arg))

@@ -896,6 +909,8 @@ class Adjoint:
         adj.skip_forward_codegen = skip_forward_codegen
         # whether the generation of the adjoint code is skipped for this function
         adj.skip_reverse_codegen = skip_reverse_codegen
+        # whether this function is used by a kernel that has the backward pass enabled
+        adj.used_by_backward_kernel = False

         # extract name of source file
         adj.filename = inspect.getsourcefile(func) or "unknown source file"
@@ -962,7 +977,7 @@ class Adjoint:
                 continue

            # add variable for argument
-            arg = Var(name, type, False)
+            arg = Var(name, type, requires_grad=False)
             adj.args.append(arg)

             # pre-populate symbol dictionary with function argument names
@@ -1071,17 +1086,21 @@ class Adjoint:
         # recursively evaluate function body
         try:
             adj.eval(adj.tree.body[0])
-        except Exception:
+        except Exception as original_exc:
             try:
                 lineno = adj.lineno + adj.fun_lineno
                 line = adj.source_lines[adj.lineno]
                 msg = f'Error while parsing function "{adj.fun_name}" at {adj.filename}:{lineno}:\n{line}\n'
-                ex, data, traceback = sys.exc_info()
-                e = ex(";".join([msg] + [str(a) for a in data.args])).with_traceback(traceback)
+
+                # Combine the new message with the original exception's arguments
+                new_args = (";".join([msg] + [str(a) for a in original_exc.args]),)
+
+                # Enhance the original exception with parser context before re-raising.
+                # 'from None' is used to suppress Python's chained exceptions for a cleaner error output.
+                raise type(original_exc)(*new_args).with_traceback(original_exc.__traceback__) from None
             finally:
                 adj.skip_build = True
                 adj.builder = None
-            raise e

         if builder is not None:
             for a in adj.args:
@@ -1225,11 +1244,16 @@ class Adjoint:
             A line directive for the given statement, or None if no line directive is needed.
         """

+        if adj.filename == "unknown source file" or adj.fun_lineno == 0:
+            # early return if the function is not associated with a source file or is otherwise invalid
+            # TODO: Get line directives working with wp.map() functions
+            return None
+
         # lineinfo is enabled by default in debug mode regardless of the builder option, don't want to unnecessarily
         # emit line directives in generated code if it's not being compiled with line information
-        lineinfo_enabled = (
-            adj.builder_options.get("lineinfo", False) or adj.builder_options.get("mode", "release") == "debug"
-        )
+        build_mode = val if (val := adj.builder_options.get("mode")) is not None else warp.config.mode
+
+        lineinfo_enabled = adj.builder_options.get("lineinfo", False) or build_mode == "debug"

         if relative_lineno is not None and lineinfo_enabled and warp.config.line_directives:
             is_comment = statement.strip().startswith("//")
@@ -1348,7 +1372,7 @@ class Adjoint:
         # unresolved function, report error
         arg_type_reprs = []

-        for x in arg_types:
+        for x in itertools.chain(arg_types, kwarg_types.values()):
             if isinstance(x, warp.context.Function):
                 arg_type_reprs.append("function")
             else:
@@ -1378,7 +1402,7 @@ class Adjoint:
         # in order to process them as Python does it.
         bound_args: inspect.BoundArguments = func.signature.bind(*args, **kwargs)

-        # Type args are the compile time argument values we get from codegen.
+        # Type args are the "compile time" argument values we get from codegen.
         # For example, when calling `wp.vec3f(...)` from within a kernel,
         # this translates in fact to calling the `vector()` built-in augmented
         # with the type args `length=3, dtype=float`.
@@ -1416,20 +1440,30 @@ class Adjoint:
         bound_args = bound_args.arguments

         # if it is a user-function then build it recursively
-        if not func.is_builtin() and func not in adj.builder.functions:
-            adj.builder.build_function(func)
-            # add custom grad, replay functions to the list of functions
-            # to be built later (invalid code could be generated if we built them now)
-            # so that they are not missed when only the forward function is imported
-            # from another module
-            if func.custom_grad_func:
-                adj.builder.deferred_functions.append(func.custom_grad_func)
-            if func.custom_replay_func:
-                adj.builder.deferred_functions.append(func.custom_replay_func)
+        if not func.is_builtin():
+            # If the function called is a user function,
+            # we need to ensure its adjoint is also being generated.
+            if adj.used_by_backward_kernel:
+                func.adj.used_by_backward_kernel = True
+
+            if adj.builder is None:
+                func.build(None)
+
+            elif func not in adj.builder.functions:
+                adj.builder.build_function(func)
+                # add custom grad, replay functions to the list of functions
+                # to be built later (invalid code could be generated if we built them now)
+                # so that they are not missed when only the forward function is imported
+                # from another module
+                if func.custom_grad_func:
+                    adj.builder.deferred_functions.append(func.custom_grad_func)
+                if func.custom_replay_func:
+                    adj.builder.deferred_functions.append(func.custom_replay_func)

         # Resolve the return value based on the types and values of the given arguments.
         bound_arg_types = {k: get_arg_type(v) for k, v in bound_args.items()}
         bound_arg_values = {k: get_arg_value(v) for k, v in bound_args.items()}
+
         return_type = func.value_func(
             {k: strip_reference(v) for k, v in bound_arg_types.items()},
             bound_arg_values,
@@ -1493,6 +1527,9 @@ class Adjoint:

         # if the argument is a function (and not a builtin), then build it recursively
         if isinstance(func_arg_var, warp.context.Function) and not func_arg_var.is_builtin():
+            if adj.used_by_backward_kernel:
+                func_arg_var.adj.used_by_backward_kernel = True
+
             adj.builder.build_function(func_arg_var)

             fwd_args.append(strip_reference(func_arg_var))
@@ -1886,6 +1923,9 @@ class Adjoint:
             return obj
         if isinstance(obj, type):
             return obj
+        if isinstance(obj, Struct):
+            adj.builder.build_struct_recursive(obj)
+            return obj
         if isinstance(obj, types.ModuleType):
            return obj

@@ -1938,11 +1978,17 @@ class Adjoint:
         aggregate = adj.eval(node.value)

         try:
+            if isinstance(aggregate, Var) and aggregate.constant is not None:
+                # this case may occur when the attribute is a constant, e.g.: `IntEnum.A.value`
+                return aggregate
+
             if isinstance(aggregate, types.ModuleType) or isinstance(aggregate, type):
                 out = getattr(aggregate, node.attr)

                 if warp.types.is_value(out):
                     return adj.add_constant(out)
+                if isinstance(out, (enum.IntEnum, enum.IntFlag)):
+                    return adj.add_constant(int(out))

                 return out

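Together with the other enum coercions in this diff, this is what lets Python enums appear directly in kernel code (see the new `warp/tests/test_enum.py` in the file list above). A hedged sketch of the usage this enables:

    import enum

    import warp as wp


    class Phase(enum.IntEnum):
        SOLID = 0
        FLUID = 1


    @wp.kernel
    def count_fluid(phases: wp.array(dtype=wp.int32), out: wp.array(dtype=wp.int32)):
        tid = wp.tid()
        # Phase.FLUID is folded to the integer constant 1 at codegen time
        if phases[tid] == Phase.FLUID:
            wp.atomic_add(out, 0, 1)
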
@@ -1970,18 +2016,29 @@ class Adjoint:
                 return adj.add_builtin_call("transform_get_rotation", [aggregate])

             else:
-                attr_type = Reference(aggregate_type.vars[node.attr].type)
+                attr_var = aggregate_type.vars[node.attr]
+
+                # represent pointer types as uint64
+                if isinstance(attr_var.type, pointer_t):
+                    cast = f"({Var.dtype_to_ctype(uint64)}*)"
+                    adj_cast = f"({Var.dtype_to_ctype(attr_var.type.dtype)}*)"
+                    attr_type = Reference(uint64)
+                else:
+                    cast = ""
+                    adj_cast = ""
+                    attr_type = Reference(attr_var.type)
+
                 attr = adj.add_var(attr_type)

                 if is_reference(aggregate.type):
-                    adj.add_forward(f"{attr.emit()} = &({aggregate.emit()}->{node.attr});")
+                    adj.add_forward(f"{attr.emit()} = {cast}&({aggregate.emit()}->{attr_var.label});")
                 else:
-                    adj.add_forward(f"{attr.emit()} = &({aggregate.emit()}.{node.attr});")
+                    adj.add_forward(f"{attr.emit()} = {cast}&({aggregate.emit()}.{attr_var.label});")

                 if adj.is_differentiable_value_type(strip_reference(attr_type)):
-                    adj.add_reverse(f"{aggregate.emit_adj()}.{node.attr} += {attr.emit_adj()};")
+                    adj.add_reverse(f"{aggregate.emit_adj()}.{attr_var.label} += {adj_cast}{attr.emit_adj()};")
                 else:
-                    adj.add_reverse(f"{aggregate.emit_adj()}.{node.attr} = {attr.emit_adj()};")
+                    adj.add_reverse(f"{aggregate.emit_adj()}.{attr_var.label} = {adj_cast}{attr.emit_adj()};")

                 return attr

@@ -2309,9 +2366,12 @@ class Adjoint:

             return var

-        if isinstance(expr, (type, Var, warp.context.Function)):
+        if isinstance(expr, (type, Struct, Var, warp.context.Function)):
             return expr

+        if isinstance(expr, (enum.IntEnum, enum.IntFlag)):
+            return adj.add_constant(int(expr))
+
         return adj.add_constant(expr)

     def emit_Call(adj, node):
@@ -2360,7 +2420,8 @@ class Adjoint:

         # struct constructor
         if func is None and isinstance(caller, Struct):
-            adj.builder.build_struct_recursive(caller)
+            if adj.builder is not None:
+                adj.builder.build_struct_recursive(caller)
             if node.args or node.keywords:
                 func = caller.value_constructor
             else:
2420
2481
 
2421
2482
  return adj.eval(node.value)
2422
2483
 
2423
- # returns the object being indexed, and the list of indices
2424
- def eval_subscript(adj, node):
2425
- # We want to coalesce multi-dimensional array indexing into a single operation. This needs to deal with expressions like `a[i][j][x][y]` where `a` is a 2D array of matrices,
2426
- # and essentially rewrite it into `a[i, j][x][y]`. Since the AST observes the indexing right-to-left, and we don't want to evaluate the index expressions prematurely,
2427
- # this requires a first loop to check if this `node` only performs indexing on the array, and a second loop to evaluate and collect index variables.
2428
- root = node
2429
- count = 0
2430
- array = None
2431
- while isinstance(root, ast.Subscript):
2432
- if isinstance(root.slice, ast.Tuple):
2433
- # handles the x[i, j] case (Python 3.8.x upward)
2434
- count += len(root.slice.elts)
2435
- elif isinstance(root.slice, ast.Index) and isinstance(root.slice.value, ast.Tuple):
2436
- # handles the x[i, j] case (Python 3.7.x)
2437
- count += len(root.slice.value.elts)
2438
- else:
2439
- # simple expression, e.g.: x[i]
2440
- count += 1
2441
-
2442
- if isinstance(root.value, ast.Name):
2443
- symbol = adj.emit_Name(root.value)
2444
- symbol_type = strip_reference(symbol.type)
2445
- if is_array(symbol_type):
2446
- array = symbol
2447
- break
2448
-
2449
- root = root.value
2450
-
2451
- # If not all indices index into the array, just evaluate the right-most indexing operation.
2452
- if not array or (count > array.type.ndim):
2453
- count = 1
2454
-
2455
- indices = []
2456
- root = node
2457
- while len(indices) < count:
2458
- if isinstance(root.slice, ast.Tuple):
2459
- ij = [adj.eval(arg) for arg in root.slice.elts]
2460
- elif isinstance(root.slice, ast.Index) and isinstance(root.slice.value, ast.Tuple):
2461
- ij = [adj.eval(arg) for arg in root.slice.value.elts]
2462
- else:
2463
- ij = [adj.eval(root.slice)]
2464
-
2465
- indices = ij + indices # prepend
2466
-
2467
- root = root.value
2468
-
2469
- target = adj.eval(root)
2484
+ def eval_indices(adj, target_type, indices):
2485
+ nodes = indices
2486
+ if hasattr(target_type, "_wp_generic_type_hint_"):
2487
+ indices = []
2488
+ for dim, node in enumerate(nodes):
2489
+ if isinstance(node, ast.Slice):
2490
+ # In the context of slicing a vec/mat type, indices are expected
2491
+ # to be compile-time constants, hence we can infer the actual slice
2492
+ # bounds also at compile-time.
2493
+ length = target_type._shape_[dim]
2494
+ step = 1 if node.step is None else adj.eval(node.step).constant
2495
+
2496
+ if node.lower is None:
2497
+ start = length - 1 if step < 0 else 0
2498
+ else:
2499
+ start = adj.eval(node.lower).constant
2500
+ start = min(max(start, -length), length)
2501
+ start = start + length if start < 0 else start
2470
2502
 
2471
- return target, indices
2503
+ if node.upper is None:
2504
+ stop = -1 if step < 0 else length
2505
+ else:
2506
+ stop = adj.eval(node.upper).constant
2507
+ stop = min(max(stop, -length), length)
2508
+ stop = stop + length if stop < 0 else stop
2472
2509
 
2473
- def emit_Subscript(adj, node):
2474
- if hasattr(node.value, "attr") and node.value.attr == "adjoint":
2475
- # handle adjoint of a variable, i.e. wp.adjoint[var]
2476
- node.slice.is_adjoint = True
2477
- var = adj.eval(node.slice)
2478
- var_name = var.label
2479
- var = Var(f"adj_{var_name}", type=var.type, constant=None, prefix=False)
2480
- return var
2510
+ slice = adj.add_builtin_call("slice", (start, stop, step))
2511
+ indices.append(slice)
2512
+ else:
2513
+ indices.append(adj.eval(node))
2481
2514
 
2482
- target, indices = adj.eval_subscript(node)
2515
+ return tuple(indices)
2516
+ else:
2517
+ return tuple(adj.eval(x) for x in nodes)
2483
2518
 
2519
+ def emit_indexing(adj, target, indices):
2484
2520
  target_type = strip_reference(target.type)
2521
+ indices = adj.eval_indices(target_type, indices)
2522
+
2485
2523
  if is_array(target_type):
2486
2524
  if len(indices) == target_type.ndim:
2487
2525
  # handles array loads (where each dimension has an index specified)
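The `eval_indices`/`emit_indexing` split above is what resolves slices on vec/mat values to compile-time constants. A hedged sketch of the kernel-side syntax this supports (see the expanded `test_vec.py`/`test_mat.py` in the file list; the exact slice-to-vec2 semantics are assumed here):

    import warp as wp


    @wp.kernel
    def take_xy(points: wp.array(dtype=wp.vec3), out: wp.array(dtype=wp.vec2)):
        tid = wp.tid()
        p = points[tid]
        # slice bounds are compile-time constants, lowered to a wp::slice call
        out[tid] = p[:2]
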
@@ -2520,47 +2558,116 @@ class Adjoint:

         return out

+    # from a list of lists of indices, strip the first `count` indices
+    @staticmethod
+    def strip_indices(indices, count):
+        dim = count
+        while count > 0:
+            ij = indices[0]
+            indices = indices[1:]
+            count -= len(ij)
+
+            # report straddling like in `arr2d[0][1,2]` as a syntax error
+            if count < 0:
+                raise WarpCodegenError(
+                    f"Incorrect number of indices specified for array indexing, got {dim - count} indices for a {dim} dimensional array."
+                )
+
+        return indices
+
+    def recurse_subscript(adj, node, indices):
+        if isinstance(node, ast.Name):
+            target = adj.eval(node)
+            return target, indices
+
+        if isinstance(node, ast.Subscript):
+            if hasattr(node.value, "attr") and node.value.attr == "adjoint":
+                return adj.eval(node), indices
+
+            if isinstance(node.slice, ast.Tuple):
+                ij = node.slice.elts
+            elif isinstance(node.slice, ast.Index) and isinstance(node.slice.value, ast.Tuple):
+                # The node `ast.Index` is deprecated in Python 3.9.
+                ij = node.slice.value.elts
+            elif isinstance(node.slice, ast.ExtSlice):
+                # The node `ast.ExtSlice` is deprecated in Python 3.9.
+                ij = node.slice.dims
+            else:
+                ij = [node.slice]
+
+            indices = [ij, *indices]  # prepend
+
+            target, indices = adj.recurse_subscript(node.value, indices)
+
+            target_type = strip_reference(target.type)
+            if is_array(target_type):
+                flat_indices = [i for ij in indices for i in ij]
+                if len(flat_indices) > target_type.ndim:
+                    target = adj.emit_indexing(target, flat_indices[: target_type.ndim])
+                    indices = adj.strip_indices(indices, target_type.ndim)
+
+            return target, indices
+
+        target = adj.eval(node)
+        return target, indices
+
+    # returns the object being indexed, and the list of indices
+    def eval_subscript(adj, node):
+        target, indices = adj.recurse_subscript(node, [])
+        flat_indices = [i for ij in indices for i in ij]
+        return target, flat_indices
+
+    def emit_Subscript(adj, node):
+        if hasattr(node.value, "attr") and node.value.attr == "adjoint":
+            # handle adjoint of a variable, i.e. wp.adjoint[var]
+            node.slice.is_adjoint = True
+            var = adj.eval(node.slice)
+            var_name = var.label
+            var = Var(f"adj_{var_name}", type=var.type, constant=None, prefix=False)
+            return var
+
+        target, indices = adj.eval_subscript(node)
+
+        return adj.emit_indexing(target, indices)
+
     def emit_Assign(adj, node):
         if len(node.targets) != 1:
             raise WarpCodegenError("Assigning the same value to multiple variables is not supported")

-        lhs = node.targets[0]
+        # Check if the rhs corresponds to an unsupported construct.
+        # Tuples are supported in the context of assigning multiple variables
+        # at once, but not for simple assignments like `x = (1, 2, 3)`.
+        # Therefore, we need to catch this specific case here instead of
+        # more generally in `adj.eval()`.
+        if isinstance(node.value, ast.List):
+            raise WarpCodegenError(
+                "List constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead."
+            )

-        if not isinstance(lhs, ast.Tuple):
-            # Check if the rhs corresponds to an unsupported construct.
-            # Tuples are supported in the context of assigning multiple variables
-            # at once, but not for simple assignments like `x = (1, 2, 3)`.
-            # Therefore, we need to catch this specific case here instead of
-            # more generally in `adj.eval()`.
-            if isinstance(node.value, ast.List):
-                raise WarpCodegenError(
-                    "List constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead."
-                )
+        lhs = node.targets[0]

-        # handle the case where we are assigning multiple output variables
-        if isinstance(lhs, ast.Tuple):
+        if isinstance(lhs, ast.Tuple) and isinstance(node.value, ast.Call):
             # record the expected number of outputs on the node
             # we do this so we can decide which function to
             # call based on the number of expected outputs
-            if isinstance(node.value, ast.Call):
-                node.value.expects = len(lhs.elts)
+            node.value.expects = len(lhs.elts)

-        # evaluate values
-        if isinstance(node.value, ast.Tuple):
-            out = [adj.eval(v) for v in node.value.elts]
-        else:
-            out = adj.eval(node.value)
+        # evaluate rhs
+        if isinstance(lhs, ast.Tuple) and isinstance(node.value, ast.Tuple):
+            rhs = [adj.eval(v) for v in node.value.elts]
+        else:
+            rhs = adj.eval(node.value)
+
+        # handle the case where we are assigning multiple output variables
+        if isinstance(lhs, ast.Tuple):
+            subtype = getattr(rhs, "type", None)

-            subtype = getattr(out, "type", None)
             if isinstance(subtype, warp.types.tuple_t):
-                if len(out.type.types) != len(lhs.elts):
+                if len(rhs.type.types) != len(lhs.elts):
                     raise WarpCodegenError(
-                        f"Invalid number of values to unpack (expected {len(lhs.elts)}, got {len(out.type.types)})."
+                        f"Invalid number of values to unpack (expected {len(lhs.elts)}, got {len(rhs.type.types)})."
                     )
-                target = out
-                out = tuple(
-                    adj.add_builtin_call("extract", (target, adj.add_constant(i))) for i in range(len(lhs.elts))
-                )
+                rhs = tuple(adj.add_builtin_call("extract", (rhs, adj.add_constant(i))) for i in range(len(lhs.elts)))

             names = []
             for v in lhs.elts:
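The rewritten traversal (`recurse_subscript` plus `strip_indices`) coalesces chained indexing on arrays and reports straddled index groups instead of generating bad code. A hedged sketch for a 2D array of matrices:

    import warp as wp


    @wp.kernel
    def read_entry(a: wp.array2d(dtype=wp.mat33), out: wp.array(dtype=float)):
        tid = wp.tid()
        # a[0][1] is coalesced into the single 2D load a[0, 1]; the trailing
        # [0][2] then indexes into the loaded mat33. A straddled group such as
        # a[0][1, 2] (a tuple crossing the array/matrix boundary) now raises
        # WarpCodegenError.
        out[tid] = a[0][1][0][2]
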
@@ -2571,11 +2678,12 @@ class Adjoint:
                     "Multiple return functions can only assign to simple variables, e.g.: x, y = func()"
                 )

-            if len(names) != len(out):
+            if len(names) != len(rhs):
                 raise WarpCodegenError(
-                    f"Multiple return functions need to receive all their output values, incorrect number of values to unpack (expected {len(out)}, got {len(names)})"
+                    f"Multiple return functions need to receive all their output values, incorrect number of values to unpack (expected {len(rhs)}, got {len(names)})"
                 )

+            out = rhs
             for name, rhs in zip(names, out):
                 if name in adj.symbols:
                     if not types_equal(rhs.type, adj.symbols[name].type):
@@ -2587,8 +2695,6 @@ class Adjoint:

         # handles the case where we are assigning to an array index (e.g.: arr[i] = 2.0)
         elif isinstance(lhs, ast.Subscript):
-            rhs = adj.eval(node.value)
-
             if hasattr(lhs.value, "attr") and lhs.value.attr == "adjoint":
                 # handle adjoint of a variable, i.e. wp.adjoint[var]
                 lhs.slice.is_adjoint = True
2600
2706
  target, indices = adj.eval_subscript(lhs)
2601
2707
 
2602
2708
  target_type = strip_reference(target.type)
2709
+ indices = adj.eval_indices(target_type, indices)
2603
2710
 
2604
2711
  if is_array(target_type):
2605
2712
  adj.add_builtin_call("array_store", [target, *indices, rhs])
@@ -2621,14 +2728,11 @@ class Adjoint:
                 or type_is_transformation(target_type)
             ):
                 # recursively unwind AST, stopping at penultimate node
-                node = lhs
-                while hasattr(node, "value"):
-                    if hasattr(node.value, "value"):
-                        node = node.value
-                    else:
-                        break
+                root = lhs
+                while hasattr(root.value, "value"):
+                    root = root.value
                 # lhs is updating a variable adjoint (i.e. wp.adjoint[var])
-                if hasattr(node, "attr") and node.attr == "adjoint":
+                if hasattr(root, "attr") and root.attr == "adjoint":
                     attr = adj.add_builtin_call("index", [target, *indices])
                     adj.add_builtin_call("store", [attr, rhs])
                     return
@@ -2666,9 +2770,6 @@ class Adjoint:
             # symbol name
             name = lhs.id

-            # evaluate rhs
-            rhs = adj.eval(node.value)
-
             # check type matches if symbol already defined
             if name in adj.symbols:
                 if not types_equal(strip_reference(rhs.type), adj.symbols[name].type):
@@ -2689,7 +2790,6 @@ class Adjoint:
             adj.symbols[name] = out

         elif isinstance(lhs, ast.Attribute):
-            rhs = adj.eval(node.value)
             aggregate = adj.eval(lhs.value)
             aggregate_type = strip_reference(aggregate.type)

@@ -2777,9 +2877,9 @@ class Adjoint:
             new_node = ast.Assign(targets=[lhs], value=ast.BinOp(lhs, node.op, node.value))
             adj.eval(new_node)

-        if isinstance(lhs, ast.Subscript):
-            rhs = adj.eval(node.value)
+        rhs = adj.eval(node.value)

+        if isinstance(lhs, ast.Subscript):
             # wp.adjoint[var] appears in custom grad functions, and does not require
             # special consideration in the AugAssign case
             if hasattr(lhs.value, "attr") and lhs.value.attr == "adjoint":
@@ -2789,6 +2889,7 @@ class Adjoint:
             target, indices = adj.eval_subscript(lhs)

             target_type = strip_reference(target.type)
+            indices = adj.eval_indices(target_type, indices)

             if is_array(target_type):
                 # target_types int8, uint8, int16, uint16 are not suitable for atomic array accumulation
@@ -2861,7 +2962,6 @@ class Adjoint:

         elif isinstance(lhs, ast.Name):
             target = adj.eval(node.target)
-            rhs = adj.eval(node.value)

             if is_tile(target.type) and is_tile(rhs.type):
                 if isinstance(node.op, ast.Add):
@@ -3163,6 +3263,8 @@ class Adjoint:

         try:
             value = eval(code_to_eval, vars_dict)
+            if isinstance(value, (enum.IntEnum, enum.IntFlag)):
+                value = int(value)
             if warp.config.verbose:
                 print(f"Evaluated static command: {static_code} = {value}")
         except NameError as e:
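With the coercion above, `wp.static()` expressions that evaluate to enum members are embedded as plain integers. A hedged sketch:

    import enum

    import warp as wp


    class Mode(enum.IntEnum):
        FAST = 2


    @wp.kernel
    def scale(values: wp.array(dtype=wp.int32)):
        tid = wp.tid()
        # Mode.FAST is evaluated at module build time and baked in as 2
        values[tid] = values[tid] * wp.static(Mode.FAST)
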
@@ -3373,6 +3475,11 @@ cuda_module_header = """
 #define WP_NO_CRT
 #include "builtin.h"

+// Map wp.breakpoint() to a device brkpt at the call site so cuda-gdb attributes the stop to the generated .cu line
+#if defined(__CUDACC__) && !defined(_MSC_VER)
+#define __debugbreak() __brkpt()
+#endif
+
 // avoid namespacing of float type for casting to float type, this is to avoid wp::float(x), which is not valid in C++
 #define float(x) cast_float(x)
 #define adj_float(x, adj_x, adj_ret) adj_cast_float(x, adj_x, adj_ret)
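For context, `wp.breakpoint()` lowers to a `__debugbreak()` call in the generated source; the macro above remaps it to `__brkpt()` when compiling with nvcc so that cuda-gdb attributes the stop to the right line. A hedged usage sketch:

    import warp as wp


    @wp.kernel
    def inspect(values: wp.array(dtype=float)):
        tid = wp.tid()
        if tid == 0:
            # with cuda-gdb attached, execution halts here
            wp.breakpoint()
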
@@ -3410,6 +3517,12 @@ static CUDA_CALLABLE void adj_{name}({reverse_args})
 {{
 {reverse_body}}}

+// Required when compiling adjoints.
+CUDA_CALLABLE {name} add(const {name}& a, const {name}& b)
+{{
+    return {name}();
+}}
+
 CUDA_CALLABLE void adj_atomic_add({name}* p, {name} t)
 {{
 {atomic_add_body}}}
@@ -3490,7 +3603,8 @@ cuda_kernel_template_backward = """
 cpu_kernel_template_forward = """

 void {name}_cpu_kernel_forward(
-    {forward_args})
+    {forward_args},
+    wp_args_{name} *_wp_args)
 {{
 {forward_body}}}

@@ -3499,7 +3613,9 @@ void {name}_cpu_kernel_forward(
 cpu_kernel_template_backward = """

 void {name}_cpu_kernel_backward(
-    {reverse_args})
+    {reverse_args},
+    wp_args_{name} *_wp_args,
+    wp_args_{name} *_wp_adj_args)
 {{
 {reverse_body}}}

@@ -3511,15 +3627,15 @@ extern "C" {{

 // Python CPU entry points
 WP_API void {name}_cpu_forward(
-    {forward_args})
+    wp::launch_bounds_t dim,
+    wp_args_{name} *_wp_args)
 {{
     for (size_t task_index = 0; task_index < dim.size; ++task_index)
     {{
         // init shared memory allocator
         wp::tile_alloc_shared(0, true);

-        {name}_cpu_kernel_forward(
-            {forward_params});
+        {name}_cpu_kernel_forward(dim, task_index, _wp_args);

         // check shared memory allocator
         wp::tile_alloc_shared(0, false, true);
@@ -3536,15 +3652,16 @@ cpu_module_template_backward = """
 extern "C" {{

 WP_API void {name}_cpu_backward(
-    {reverse_args})
+    wp::launch_bounds_t dim,
+    wp_args_{name} *_wp_args,
+    wp_args_{name} *_wp_adj_args)
 {{
     for (size_t task_index = 0; task_index < dim.size; ++task_index)
     {{
         // initialize shared memory allocator
         wp::tile_alloc_shared(0, true);

-        {name}_cpu_kernel_backward(
-            {reverse_params});
+        {name}_cpu_kernel_backward(dim, task_index, _wp_args, _wp_adj_args);

         // check shared memory allocator
         wp::tile_alloc_shared(0, false, true);
@@ -3575,7 +3692,7 @@ def constant_str(value):
         # special case for float16, which is stored as uint16 in the ctypes.Array
         from warp.context import runtime

-        scalar_value = runtime.core.half_bits_to_float
+        scalar_value = runtime.core.wp_half_bits_to_float
     else:

         def scalar_value(x):
@@ -3713,8 +3830,17 @@ def codegen_func_forward(adj, func_type="kernel", device="cpu"):

     indent_block = " " * indent

-    # primal vars
     lines = []
+
+    # argument vars
+    if device == "cpu" and func_type == "kernel":
+        lines += ["//---------\n"]
+        lines += ["// argument vars\n"]
+
+        for var in adj.args:
+            lines += [f"{var.ctype()} {var.emit()} = _wp_args->{var.label};\n"]
+
+    # primal vars
     lines += ["//---------\n"]
     lines += ["// primal vars\n"]

@@ -3758,6 +3884,17 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):

     lines = []

+    # argument vars
+    if device == "cpu" and func_type == "kernel":
+        lines += ["//---------\n"]
+        lines += ["// argument vars\n"]
+
+        for var in adj.args:
+            lines += [f"{var.ctype()} {var.emit()} = _wp_args->{var.label};\n"]
+
+        for var in adj.args:
+            lines += [f"{var.ctype()} {var.emit_adj()} = _wp_adj_args->{var.label};\n"]
+
     # primal vars
     lines += ["//---------\n"]
     lines += ["// primal vars\n"]
@@ -3849,6 +3986,19 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
                 f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
                 f"but the code returns a value of type `{warp.context.type_str(adj.return_var[0].type)}`."
             )
+        elif (
+            isinstance(adj.return_var[0].type, warp.types.fixedarray)
+            and type(adj.arg_types["return"]) is warp.types.array
+        ):
+            # If the return statement yields a `fixedarray` while the function is annotated
+            # to return a standard `array`, then raise an error since the `fixedarray` storage
+            # allocated on the stack will be freed once the function exits, meaning that the
+            # resulting `array` instance would point to invalid data.
+            raise WarpCodegenError(
+                f"The function `{adj.fun_name}` returns a fixed-size array "
+                f"whereas it has its return type annotated as "
+                f"`{warp.context.type_str(adj.arg_types['return'])}`."
+            )

     # Build line directive for function definition (subtract 1 to account for 1-indexing of AST line numbers)
     # This is used as a catch-all C-to-Python source line mapping for any code that does not have
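A hedged sketch of the case this new check rejects (`wp.fixedarray` is 1.9's stack-allocated array; the exact constructor signature is assumed here):

    import warp as wp


    @wp.func
    def make_buffer() -> wp.array(dtype=float):  # annotated as a regular array...
        buf = wp.fixedarray(shape=8, dtype=float)  # ...but backed by stack storage
        return buf  # building this module now raises WarpCodegenError
                    # instead of producing a dangling pointer at runtime
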
@@ -3927,10 +4077,10 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
     if adj.custom_reverse_mode:
         reverse_body = "\t// user-defined adjoint code\n" + forward_body
     else:
-        if options.get("enable_backward", True):
+        if options.get("enable_backward", True) and adj.used_by_backward_kernel:
             reverse_body = codegen_func_reverse(adj, func_type="function", device=device)
         else:
-            reverse_body = '\t// reverse mode disabled (module option "enable_backward" is False)\n'
+            reverse_body = '\t// reverse mode disabled (module option "enable_backward" is False or no dependent kernel found with "enable_backward")\n'
     s += reverse_template.format(
         name=c_func_name,
         return_type=return_type,
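The practical effect of the `used_by_backward_kernel` gate: a user function's adjoint is emitted only when some kernel that actually needs gradients calls it, rather than for every `wp.func` in a module built with `enable_backward`. A hedged sketch:

    import warp as wp

    wp.set_module_options({"enable_backward": True})


    @wp.func
    def square(x: float) -> float:
        # adjoint code is generated only because the differentiable
        # kernel below calls this function
        return x * x


    @wp.kernel
    def loss(xs: wp.array(dtype=float), out: wp.array(dtype=float)):
        tid = wp.tid()
        out[tid] = square(xs[tid])
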
@@ -4022,6 +4172,13 @@ def codegen_kernel(kernel, device, options):

     adj = kernel.adj

+    args_struct = ""
+    if device == "cpu":
+        args_struct = f"struct wp_args_{kernel.get_mangled_name()} {{\n"
+        for i in adj.args:
+            args_struct += f"    {i.ctype()} {i.label};\n"
+        args_struct += "};\n"
+
     # Build line directive for function definition (subtract 1 to account for 1-indexing of AST line numbers)
     # This is used as a catch-all C-to-Python source line mapping for any code that does not have
     # a direct mapping to a Python source line.
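For context, on CPU the kernel arguments are now packed into one per-kernel struct and passed as a single `_wp_args` pointer (see the entry-point changes below) instead of one C parameter per argument. A standalone sketch of the string this loop builds, with made-up argument ctypes and a shortened mangled name:

    args = [("wp::array_t<wp::float32>", "a"), ("wp::int32", "n")]  # hypothetical
    args_struct = "struct wp_args_k {\n"
    for ctype, label in args:
        args_struct += f"    {ctype} {label};\n"
    args_struct += "};\n"
    print(args_struct)
    # struct wp_args_k {
    #     wp::array_t<wp::float32> a;
    #     wp::int32 n;
    # };
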
@@ -4047,9 +4204,9 @@ def codegen_kernel(kernel, device, options):
     forward_args = ["wp::launch_bounds_t dim"]
     if device == "cpu":
         forward_args.append("size_t task_index")
-
-    for arg in adj.args:
-        forward_args.append(arg.ctype() + " var_" + arg.label)
+    else:
+        for arg in adj.args:
+            forward_args.append(arg.ctype() + " var_" + arg.label)

     forward_body = codegen_func_forward(adj, func_type="kernel", device=device)
     template_fmt_args.update(
@@ -4066,17 +4223,16 @@ def codegen_kernel(kernel, device, options):
         reverse_args = ["wp::launch_bounds_t dim"]
         if device == "cpu":
             reverse_args.append("size_t task_index")
-
-        for arg in adj.args:
-            reverse_args.append(arg.ctype() + " var_" + arg.label)
-
-        for arg in adj.args:
-            # indexed array gradients are regular arrays
-            if isinstance(arg.type, indexedarray):
-                _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
-                reverse_args.append(_arg.ctype() + " adj_" + arg.label)
-            else:
-                reverse_args.append(arg.ctype() + " adj_" + arg.label)
+        else:
+            for arg in adj.args:
+                reverse_args.append(arg.ctype() + " var_" + arg.label)
+            for arg in adj.args:
+                # indexed array gradients are regular arrays
+                if isinstance(arg.type, indexedarray):
+                    _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
+                    reverse_args.append(_arg.ctype() + " adj_" + arg.label)
+                else:
+                    reverse_args.append(arg.ctype() + " adj_" + arg.label)

         reverse_body = codegen_func_reverse(adj, func_type="kernel", device=device)
         template_fmt_args.update(
@@ -4088,7 +4244,7 @@ def codegen_kernel(kernel, device, options):
         template += template_backward

     s = template.format(**template_fmt_args)
-    return s
+    return args_struct + s


 def codegen_module(kernel, device, options):
@@ -4099,59 +4255,14 @@ def codegen_module(kernel, device, options):
     options = dict(options)
     options.update(kernel.options)

-    adj = kernel.adj
-
     template = ""
     template_fmt_args = {
         "name": kernel.get_mangled_name(),
     }

-    # build forward signature
-    forward_args = ["wp::launch_bounds_t dim"]
-    forward_params = ["dim", "task_index"]
-
-    for arg in adj.args:
-        if hasattr(arg.type, "_wp_generic_type_str_"):
-            # vectors and matrices are passed from Python by pointer
-            forward_args.append(f"const {arg.ctype()}* var_" + arg.label)
-            forward_params.append(f"*var_{arg.label}")
-        else:
-            forward_args.append(f"{arg.ctype()} var_{arg.label}")
-            forward_params.append("var_" + arg.label)
-
-    template_fmt_args.update(
-        {
-            "forward_args": indent(forward_args),
-            "forward_params": indent(forward_params, 3),
-        }
-    )
     template += cpu_module_template_forward

     if options["enable_backward"]:
-        # build reverse signature
-        reverse_args = [*forward_args]
-        reverse_params = [*forward_params]
-
-        for arg in adj.args:
-            if isinstance(arg.type, indexedarray):
-                # indexed array gradients are regular arrays
-                _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
-                reverse_args.append(f"const {_arg.ctype()} adj_{arg.label}")
-                reverse_params.append(f"adj_{_arg.label}")
-            elif hasattr(arg.type, "_wp_generic_type_str_"):
-                # vectors and matrices are passed from Python by pointer
-                reverse_args.append(f"const {arg.ctype()}* adj_{arg.label}")
-                reverse_params.append(f"*adj_{arg.label}")
-            else:
-                reverse_args.append(f"{arg.ctype()} adj_{arg.label}")
-                reverse_params.append(f"adj_{arg.label}")
-
-        template_fmt_args.update(
-            {
-                "reverse_args": indent(reverse_args),
-                "reverse_params": indent(reverse_params, 3),
-            }
-        )
         template += cpu_module_template_backward

     s = template.format(**template_fmt_args)