warp_lang-1.8.1-py3-none-manylinux_2_34_aarch64.whl → warp_lang-1.9.0-py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang might be problematic.

Files changed (134)
  1. warp/__init__.py +282 -103
  2. warp/__init__.pyi +482 -110
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +93 -30
  6. warp/build_dll.py +47 -67
  7. warp/builtins.py +955 -137
  8. warp/codegen.py +312 -206
  9. warp/config.py +1 -1
  10. warp/context.py +1249 -784
  11. warp/examples/core/example_marching_cubes.py +1 -0
  12. warp/examples/core/example_render_opengl.py +100 -3
  13. warp/examples/fem/example_apic_fluid.py +98 -52
  14. warp/examples/fem/example_convection_diffusion_dg.py +25 -4
  15. warp/examples/fem/example_diffusion_mgpu.py +8 -3
  16. warp/examples/fem/utils.py +68 -22
  17. warp/fabric.py +1 -1
  18. warp/fem/cache.py +27 -19
  19. warp/fem/domain.py +2 -2
  20. warp/fem/field/nodal_field.py +2 -2
  21. warp/fem/field/virtual.py +264 -166
  22. warp/fem/geometry/geometry.py +5 -5
  23. warp/fem/integrate.py +129 -51
  24. warp/fem/space/restriction.py +4 -0
  25. warp/fem/space/shape/tet_shape_function.py +3 -10
  26. warp/jax_experimental/custom_call.py +1 -1
  27. warp/jax_experimental/ffi.py +2 -1
  28. warp/marching_cubes.py +708 -0
  29. warp/native/array.h +99 -4
  30. warp/native/builtin.h +82 -5
  31. warp/native/bvh.cpp +64 -28
  32. warp/native/bvh.cu +58 -58
  33. warp/native/bvh.h +2 -2
  34. warp/native/clang/clang.cpp +7 -7
  35. warp/native/coloring.cpp +8 -2
  36. warp/native/crt.cpp +2 -2
  37. warp/native/crt.h +3 -5
  38. warp/native/cuda_util.cpp +41 -10
  39. warp/native/cuda_util.h +10 -4
  40. warp/native/exports.h +1842 -1908
  41. warp/native/fabric.h +2 -1
  42. warp/native/hashgrid.cpp +37 -37
  43. warp/native/hashgrid.cu +2 -2
  44. warp/native/initializer_array.h +1 -1
  45. warp/native/intersect.h +2 -2
  46. warp/native/mat.h +1910 -116
  47. warp/native/mathdx.cpp +43 -43
  48. warp/native/mesh.cpp +24 -24
  49. warp/native/mesh.cu +26 -26
  50. warp/native/mesh.h +4 -2
  51. warp/native/nanovdb/GridHandle.h +179 -12
  52. warp/native/nanovdb/HostBuffer.h +8 -7
  53. warp/native/nanovdb/NanoVDB.h +517 -895
  54. warp/native/nanovdb/NodeManager.h +323 -0
  55. warp/native/nanovdb/PNanoVDB.h +2 -2
  56. warp/native/quat.h +331 -14
  57. warp/native/range.h +7 -1
  58. warp/native/reduce.cpp +10 -10
  59. warp/native/reduce.cu +13 -14
  60. warp/native/runlength_encode.cpp +2 -2
  61. warp/native/runlength_encode.cu +5 -5
  62. warp/native/scan.cpp +3 -3
  63. warp/native/scan.cu +4 -4
  64. warp/native/sort.cpp +10 -10
  65. warp/native/sort.cu +22 -22
  66. warp/native/sparse.cpp +8 -8
  67. warp/native/sparse.cu +13 -13
  68. warp/native/spatial.h +366 -17
  69. warp/native/temp_buffer.h +2 -2
  70. warp/native/tile.h +283 -69
  71. warp/native/vec.h +381 -14
  72. warp/native/volume.cpp +54 -54
  73. warp/native/volume.cu +1 -1
  74. warp/native/volume.h +2 -1
  75. warp/native/volume_builder.cu +30 -37
  76. warp/native/warp.cpp +150 -149
  77. warp/native/warp.cu +323 -192
  78. warp/native/warp.h +227 -226
  79. warp/optim/linear.py +736 -271
  80. warp/render/imgui_manager.py +289 -0
  81. warp/render/render_opengl.py +85 -6
  82. warp/sim/graph_coloring.py +2 -2
  83. warp/sparse.py +558 -175
  84. warp/tests/aux_test_module_aot.py +7 -0
  85. warp/tests/cuda/test_async.py +3 -3
  86. warp/tests/cuda/test_conditional_captures.py +101 -0
  87. warp/tests/geometry/test_marching_cubes.py +233 -12
  88. warp/tests/sim/test_coloring.py +6 -6
  89. warp/tests/test_array.py +56 -5
  90. warp/tests/test_codegen.py +3 -2
  91. warp/tests/test_context.py +8 -15
  92. warp/tests/test_enum.py +136 -0
  93. warp/tests/test_examples.py +2 -2
  94. warp/tests/test_fem.py +45 -2
  95. warp/tests/test_fixedarray.py +229 -0
  96. warp/tests/test_func.py +18 -15
  97. warp/tests/test_future_annotations.py +7 -5
  98. warp/tests/test_linear_solvers.py +30 -0
  99. warp/tests/test_map.py +1 -1
  100. warp/tests/test_mat.py +1518 -378
  101. warp/tests/test_mat_assign_copy.py +178 -0
  102. warp/tests/test_mat_constructors.py +574 -0
  103. warp/tests/test_module_aot.py +287 -0
  104. warp/tests/test_print.py +69 -0
  105. warp/tests/test_quat.py +140 -34
  106. warp/tests/test_quat_assign_copy.py +145 -0
  107. warp/tests/test_reload.py +2 -1
  108. warp/tests/test_sparse.py +71 -0
  109. warp/tests/test_spatial.py +140 -34
  110. warp/tests/test_spatial_assign_copy.py +160 -0
  111. warp/tests/test_struct.py +43 -3
  112. warp/tests/test_types.py +0 -20
  113. warp/tests/test_vec.py +179 -34
  114. warp/tests/test_vec_assign_copy.py +143 -0
  115. warp/tests/tile/test_tile.py +184 -18
  116. warp/tests/tile/test_tile_cholesky.py +605 -0
  117. warp/tests/tile/test_tile_load.py +169 -0
  118. warp/tests/tile/test_tile_mathdx.py +2 -558
  119. warp/tests/tile/test_tile_matmul.py +1 -1
  120. warp/tests/tile/test_tile_mlp.py +1 -1
  121. warp/tests/tile/test_tile_shared_memory.py +5 -5
  122. warp/tests/unittest_suites.py +6 -0
  123. warp/tests/walkthrough_debug.py +1 -1
  124. warp/thirdparty/unittest_parallel.py +108 -9
  125. warp/types.py +554 -264
  126. warp/utils.py +68 -86
  127. {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/METADATA +28 -65
  128. {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/RECORD +131 -121
  129. warp/native/marching.cpp +0 -19
  130. warp/native/marching.cu +0 -514
  131. warp/native/marching.h +0 -19
  132. {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/WHEEL +0 -0
  133. {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/licenses/LICENSE.md +0 -0
  134. {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/top_level.txt +0 -0
warp/codegen.py CHANGED
@@ -18,9 +18,11 @@ from __future__ import annotations
  import ast
  import builtins
  import ctypes
+ import enum
  import functools
  import hashlib
  import inspect
+ import itertools
  import math
  import re
  import sys
@@ -614,6 +616,8 @@ def compute_type_str(base_name, template_params):
          return base_name

      def param2str(p):
+         if isinstance(p, builtins.bool):
+             return "true" if p else "false"
          if isinstance(p, int):
              return str(p)
          elif hasattr(p, "_wp_generic_type_str_"):
@@ -625,6 +629,8 @@ def compute_type_str(base_name, template_params):
              return f"wp::{p.__name__}"
          elif is_tile(p):
              return p.ctype()
+         elif isinstance(p, Struct):
+             return p.native_name

          return p.__name__

@@ -684,7 +690,12 @@ class Var:

      @staticmethod
      def type_to_ctype(t: type, value_type: builtins.bool = False) -> str:
-         if is_array(t):
+         if isinstance(t, fixedarray):
+             template_args = (str(t.size), Var.dtype_to_ctype(t.dtype))
+             dtypestr = ", ".join(template_args)
+             classstr = f"wp::{type(t).__name__}"
+             return f"{classstr}_t<{dtypestr}>"
+         elif is_array(t):
              dtypestr = Var.dtype_to_ctype(t.dtype)
              classstr = f"wp::{type(t).__name__}"
              return f"{classstr}_t<{dtypestr}>"
@@ -780,11 +791,10 @@ def apply_defaults(
      arguments = bound_args.arguments
      new_arguments = []
      for name in bound_args._signature.parameters.keys():
-         try:
+         if name in arguments:
              new_arguments.append((name, arguments[name]))
-         except KeyError:
-             if name in values:
-                 new_arguments.append((name, values[name]))
+         elif name in values:
+             new_arguments.append((name, values[name]))

      bound_args.arguments = dict(new_arguments)

@@ -837,6 +847,9 @@ def get_arg_type(arg: Var | Any) -> type:
      if isinstance(arg, Sequence):
          return tuple(get_arg_type(x) for x in arg)

+     if is_array(arg):
+         return arg
+
      if get_origin(arg) is tuple:
          return tuple(get_arg_type(x) for x in get_args(arg))

@@ -896,6 +909,8 @@ class Adjoint:
          adj.skip_forward_codegen = skip_forward_codegen
          # whether the generation of the adjoint code is skipped for this function
          adj.skip_reverse_codegen = skip_reverse_codegen
+         # Whether this function is used by a kernel that has the backward pass enabled.
+         adj.used_by_backward_kernel = False

          # extract name of source file
          adj.filename = inspect.getsourcefile(func) or "unknown source file"
@@ -962,7 +977,7 @@ class Adjoint:
                  continue

              # add variable for argument
-             arg = Var(name, type, False)
+             arg = Var(name, type, requires_grad=False)
              adj.args.append(arg)

          # pre-populate symbol dictionary with function argument names
@@ -1071,17 +1086,21 @@ class Adjoint:
          # recursively evaluate function body
          try:
              adj.eval(adj.tree.body[0])
-         except Exception:
+         except Exception as original_exc:
              try:
                  lineno = adj.lineno + adj.fun_lineno
                  line = adj.source_lines[adj.lineno]
                  msg = f'Error while parsing function "{adj.fun_name}" at {adj.filename}:{lineno}:\n{line}\n'
-                 ex, data, traceback = sys.exc_info()
-                 e = ex(";".join([msg] + [str(a) for a in data.args])).with_traceback(traceback)
+
+                 # Combine the new message with the original exception's arguments
+                 new_args = (";".join([msg] + [str(a) for a in original_exc.args]),)
+
+                 # Enhance the original exception with parser context before re-raising.
+                 # 'from None' is used to suppress Python's chained exceptions for a cleaner error output.
+                 raise type(original_exc)(*new_args).with_traceback(original_exc.__traceback__) from None
              finally:
                  adj.skip_build = True
                  adj.builder = None
-             raise e

          if builder is not None:
              for a in adj.args:
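
For illustration, a sketch of how the cleaner re-raise reads from the user's side (file name, line number, and exact message text hypothetical):

    import warp as wp

    @wp.kernel
    def bad(xs: wp.array(dtype=float)):
        xs[0] = [1.0, 2.0]  # lists are unsupported in kernels

    # Building this module now surfaces a single exception that carries the
    # parser context, roughly:
    #   WarpCodegenError: Error while parsing function "bad" at example.py:5:
    #       xs[0] = [1.0, 2.0]
    #   ;List constructs are not supported in kernels. ...
    # `from None` keeps the traceback free of "During handling of the above
    # exception..." chains.
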
@@ -1227,9 +1246,9 @@ class Adjoint:

          # lineinfo is enabled by default in debug mode regardless of the builder option, don't want to unnecessarily
          # emit line directives in generated code if it's not being compiled with line information
-         lineinfo_enabled = (
-             adj.builder_options.get("lineinfo", False) or adj.builder_options.get("mode", "release") == "debug"
-         )
+         build_mode = val if (val := adj.builder_options.get("mode")) is not None else warp.config.mode
+
+         lineinfo_enabled = adj.builder_options.get("lineinfo", False) or build_mode == "debug"

          if relative_lineno is not None and lineinfo_enabled and warp.config.line_directives:
              is_comment = statement.strip().startswith("//")
@@ -1348,7 +1367,7 @@ class Adjoint:
          # unresolved function, report error
          arg_type_reprs = []

-         for x in arg_types:
+         for x in itertools.chain(arg_types, kwarg_types.values()):
              if isinstance(x, warp.context.Function):
                  arg_type_reprs.append("function")
              else:
@@ -1378,7 +1397,7 @@ class Adjoint:
          # in order to process them as Python does it.
          bound_args: inspect.BoundArguments = func.signature.bind(*args, **kwargs)

-         # Type args are the compile time argument values we get from codegen.
+         # Type args are the "compile time" argument values we get from codegen.
          # For example, when calling `wp.vec3f(...)` from within a kernel,
          # this translates in fact to calling the `vector()` built-in augmented
          # with the type args `length=3, dtype=float`.
@@ -1416,20 +1435,30 @@ class Adjoint:
          bound_args = bound_args.arguments

          # if it is a user-function then build it recursively
-         if not func.is_builtin() and func not in adj.builder.functions:
-             adj.builder.build_function(func)
-             # add custom grad, replay functions to the list of functions
-             # to be built later (invalid code could be generated if we built them now)
-             # so that they are not missed when only the forward function is imported
-             # from another module
-             if func.custom_grad_func:
-                 adj.builder.deferred_functions.append(func.custom_grad_func)
-             if func.custom_replay_func:
-                 adj.builder.deferred_functions.append(func.custom_replay_func)
+         if not func.is_builtin():
+             # If the function called is a user function,
+             # we need to ensure its adjoint is also being generated.
+             if adj.used_by_backward_kernel:
+                 func.adj.used_by_backward_kernel = True
+
+             if adj.builder is None:
+                 func.build(None)
+
+             elif func not in adj.builder.functions:
+                 adj.builder.build_function(func)
+                 # add custom grad, replay functions to the list of functions
+                 # to be built later (invalid code could be generated if we built them now)
+                 # so that they are not missed when only the forward function is imported
+                 # from another module
+                 if func.custom_grad_func:
+                     adj.builder.deferred_functions.append(func.custom_grad_func)
+                 if func.custom_replay_func:
+                     adj.builder.deferred_functions.append(func.custom_replay_func)

          # Resolve the return value based on the types and values of the given arguments.
          bound_arg_types = {k: get_arg_type(v) for k, v in bound_args.items()}
          bound_arg_values = {k: get_arg_value(v) for k, v in bound_args.items()}
+
          return_type = func.value_func(
              {k: strip_reference(v) for k, v in bound_arg_types.items()},
              bound_arg_values,
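
In user terms, the new `used_by_backward_kernel` flag means a `@wp.func` only gets its adjoint emitted when a backward-enabled kernel actually depends on it. A minimal sketch (module options API as in existing Warp releases):

    import warp as wp

    wp.set_module_options({"enable_backward": False})

    @wp.func
    def square(x: float) -> float:
        return x * x

    @wp.kernel
    def apply(xs: wp.array(dtype=float)):
        tid = wp.tid()
        # with no backward-enabled kernel depending on `square`, codegen can
        # now skip generating its reverse-mode body entirely
        xs[tid] = square(xs[tid])
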
@@ -1493,6 +1522,9 @@ class Adjoint:

              # if the argument is a function (and not a builtin), then build it recursively
              if isinstance(func_arg_var, warp.context.Function) and not func_arg_var.is_builtin():
+                 if adj.used_by_backward_kernel:
+                     func_arg_var.adj.used_by_backward_kernel = True
+
                  adj.builder.build_function(func_arg_var)

                  fwd_args.append(strip_reference(func_arg_var))
@@ -1886,6 +1918,9 @@ class Adjoint:
              return obj
          if isinstance(obj, type):
              return obj
+         if isinstance(obj, Struct):
+             adj.builder.build_struct_recursive(obj)
+             return obj
          if isinstance(obj, types.ModuleType):
              return obj

@@ -1938,11 +1973,17 @@ class Adjoint:
          aggregate = adj.eval(node.value)

          try:
+             if isinstance(aggregate, Var) and aggregate.constant is not None:
+                 # this case may occur when the attribute is a constant, e.g.: `IntEnum.A.value`
+                 return aggregate
+
              if isinstance(aggregate, types.ModuleType) or isinstance(aggregate, type):
                  out = getattr(aggregate, node.attr)

                  if warp.types.is_value(out):
                      return adj.add_constant(out)
+                 if isinstance(out, (enum.IntEnum, enum.IntFlag)):
+                     return adj.add_constant(int(out))

                  return out

@@ -1970,18 +2011,29 @@ class Adjoint:
                  return adj.add_builtin_call("transform_get_rotation", [aggregate])

              else:
-                 attr_type = Reference(aggregate_type.vars[node.attr].type)
+                 attr_var = aggregate_type.vars[node.attr]
+
+                 # represent pointer types as uint64
+                 if isinstance(attr_var.type, pointer_t):
+                     cast = f"({Var.dtype_to_ctype(uint64)}*)"
+                     adj_cast = f"({Var.dtype_to_ctype(attr_var.type.dtype)}*)"
+                     attr_type = Reference(uint64)
+                 else:
+                     cast = ""
+                     adj_cast = ""
+                     attr_type = Reference(attr_var.type)
+
                  attr = adj.add_var(attr_type)

                  if is_reference(aggregate.type):
-                     adj.add_forward(f"{attr.emit()} = &({aggregate.emit()}->{node.attr});")
+                     adj.add_forward(f"{attr.emit()} = {cast}&({aggregate.emit()}->{attr_var.label});")
                  else:
-                     adj.add_forward(f"{attr.emit()} = &({aggregate.emit()}.{node.attr});")
+                     adj.add_forward(f"{attr.emit()} = {cast}&({aggregate.emit()}.{attr_var.label});")

                  if adj.is_differentiable_value_type(strip_reference(attr_type)):
-                     adj.add_reverse(f"{aggregate.emit_adj()}.{node.attr} += {attr.emit_adj()};")
+                     adj.add_reverse(f"{aggregate.emit_adj()}.{attr_var.label} += {adj_cast}{attr.emit_adj()};")
                  else:
-                     adj.add_reverse(f"{aggregate.emit_adj()}.{node.attr} = {attr.emit_adj()};")
+                     adj.add_reverse(f"{aggregate.emit_adj()}.{attr_var.label} = {adj_cast}{attr.emit_adj()};")

                  return attr

@@ -2309,9 +2361,12 @@ class Adjoint:

              return var

-         if isinstance(expr, (type, Var, warp.context.Function)):
+         if isinstance(expr, (type, Struct, Var, warp.context.Function)):
              return expr

+         if isinstance(expr, (enum.IntEnum, enum.IntFlag)):
+             return adj.add_constant(int(expr))
+
          return adj.add_constant(expr)

      def emit_Call(adj, node):
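
Together with the `emit_Attribute` changes above, kernels can now reference `IntEnum`/`IntFlag` members directly; they are folded to integer constants at code-generation time. A minimal sketch mirroring the new warp/tests/test_enum.py:

    import enum

    import warp as wp

    class Color(enum.IntEnum):
        RED = 0
        GREEN = 1
        BLUE = 2

    @wp.kernel
    def pick(out: wp.array(dtype=int)):
        # enum members are lowered via adj.add_constant(int(...))
        out[0] = Color.GREEN
        out[1] = Color.BLUE.value
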
@@ -2360,7 +2415,8 @@ class Adjoint:

          # struct constructor
          if func is None and isinstance(caller, Struct):
-             adj.builder.build_struct_recursive(caller)
+             if adj.builder is not None:
+                 adj.builder.build_struct_recursive(caller)
              if node.args or node.keywords:
                  func = caller.value_constructor
              else:
@@ -2420,68 +2476,45 @@ class Adjoint:

          return adj.eval(node.value)

-     # returns the object being indexed, and the list of indices
-     def eval_subscript(adj, node):
-         # We want to coalesce multi-dimensional array indexing into a single operation. This needs to deal with expressions like `a[i][j][x][y]` where `a` is a 2D array of matrices,
-         # and essentially rewrite it into `a[i, j][x][y]`. Since the AST observes the indexing right-to-left, and we don't want to evaluate the index expressions prematurely,
-         # this requires a first loop to check if this `node` only performs indexing on the array, and a second loop to evaluate and collect index variables.
-         root = node
-         count = 0
-         array = None
-         while isinstance(root, ast.Subscript):
-             if isinstance(root.slice, ast.Tuple):
-                 # handles the x[i, j] case (Python 3.8.x upward)
-                 count += len(root.slice.elts)
-             elif isinstance(root.slice, ast.Index) and isinstance(root.slice.value, ast.Tuple):
-                 # handles the x[i, j] case (Python 3.7.x)
-                 count += len(root.slice.value.elts)
-             else:
-                 # simple expression, e.g.: x[i]
-                 count += 1
-
-             if isinstance(root.value, ast.Name):
-                 symbol = adj.emit_Name(root.value)
-                 symbol_type = strip_reference(symbol.type)
-                 if is_array(symbol_type):
-                     array = symbol
-                     break
-
-             root = root.value
-
-         # If not all indices index into the array, just evaluate the right-most indexing operation.
-         if not array or (count > array.type.ndim):
-             count = 1
-
-         indices = []
-         root = node
-         while len(indices) < count:
-             if isinstance(root.slice, ast.Tuple):
-                 ij = [adj.eval(arg) for arg in root.slice.elts]
-             elif isinstance(root.slice, ast.Index) and isinstance(root.slice.value, ast.Tuple):
-                 ij = [adj.eval(arg) for arg in root.slice.value.elts]
-             else:
-                 ij = [adj.eval(root.slice)]
-
-             indices = ij + indices  # prepend
-
-             root = root.value
-
-         target = adj.eval(root)
-
-         return target, indices
-
-     def emit_Subscript(adj, node):
-         if hasattr(node.value, "attr") and node.value.attr == "adjoint":
-             # handle adjoint of a variable, i.e. wp.adjoint[var]
-             node.slice.is_adjoint = True
-             var = adj.eval(node.slice)
-             var_name = var.label
-             var = Var(f"adj_{var_name}", type=var.type, constant=None, prefix=False)
-             return var
-
-         target, indices = adj.eval_subscript(node)
+     def eval_indices(adj, target_type, indices):
+         nodes = indices
+         if hasattr(target_type, "_wp_generic_type_hint_"):
+             indices = []
+             for dim, node in enumerate(nodes):
+                 if isinstance(node, ast.Slice):
+                     # In the context of slicing a vec/mat type, indices are expected
+                     # to be compile-time constants, hence we can infer the actual slice
+                     # bounds also at compile-time.
+                     length = target_type._shape_[dim]
+                     step = 1 if node.step is None else adj.eval(node.step).constant
+
+                     if node.lower is None:
+                         start = length - 1 if step < 0 else 0
+                     else:
+                         start = adj.eval(node.lower).constant
+                         start = min(max(start, -length), length)
+                         start = start + length if start < 0 else start
+
+                     if node.upper is None:
+                         stop = -1 if step < 0 else length
+                     else:
+                         stop = adj.eval(node.upper).constant
+                         stop = min(max(stop, -length), length)
+                         stop = stop + length if stop < 0 else stop
+
+                     slice = adj.add_builtin_call("slice", (start, stop, step))
+                     indices.append(slice)
+                 else:
+                     indices.append(adj.eval(node))
+
+             return tuple(indices)
+         else:
+             return tuple(adj.eval(x) for x in nodes)

+     def emit_indexing(adj, target, indices):
          target_type = strip_reference(target.type)
+         indices = adj.eval_indices(target_type, indices)
+
          if is_array(target_type):
              if len(indices) == target_type.ndim:
                  # handles array loads (where each dimension has an index specified)
@@ -2520,47 +2553,116 @@ class Adjoint:

          return out

+     # from a list of lists of indices, strip the first `count` indices
+     @staticmethod
+     def strip_indices(indices, count):
+         dim = count
+         while count > 0:
+             ij = indices[0]
+             indices = indices[1:]
+             count -= len(ij)
+
+             # report straddling like in `arr2d[0][1,2]` as a syntax error
+             if count < 0:
+                 raise WarpCodegenError(
+                     f"Incorrect number of indices specified for array indexing, got {dim - count} indices for a {dim} dimensional array."
+                 )
+
+         return indices
+
+     def recurse_subscript(adj, node, indices):
+         if isinstance(node, ast.Name):
+             target = adj.eval(node)
+             return target, indices
+
+         if isinstance(node, ast.Subscript):
+             if hasattr(node.value, "attr") and node.value.attr == "adjoint":
+                 return adj.eval(node), indices
+
+             if isinstance(node.slice, ast.Tuple):
+                 ij = node.slice.elts
+             elif isinstance(node.slice, ast.Index) and isinstance(node.slice.value, ast.Tuple):
+                 # The node `ast.Index` is deprecated in Python 3.9.
+                 ij = node.slice.value.elts
+             elif isinstance(node.slice, ast.ExtSlice):
+                 # The node `ast.ExtSlice` is deprecated in Python 3.9.
+                 ij = node.slice.dims
+             else:
+                 ij = [node.slice]
+
+             indices = [ij, *indices]  # prepend
+
+             target, indices = adj.recurse_subscript(node.value, indices)
+
+             target_type = strip_reference(target.type)
+             if is_array(target_type):
+                 flat_indices = [i for ij in indices for i in ij]
+                 if len(flat_indices) > target_type.ndim:
+                     target = adj.emit_indexing(target, flat_indices[: target_type.ndim])
+                     indices = adj.strip_indices(indices, target_type.ndim)
+
+             return target, indices
+
+         target = adj.eval(node)
+         return target, indices
+
+     # returns the object being indexed, and the list of indices
+     def eval_subscript(adj, node):
+         target, indices = adj.recurse_subscript(node, [])
+         flat_indices = [i for ij in indices for i in ij]
+         return target, flat_indices
+
+     def emit_Subscript(adj, node):
+         if hasattr(node.value, "attr") and node.value.attr == "adjoint":
+             # handle adjoint of a variable, i.e. wp.adjoint[var]
+             node.slice.is_adjoint = True
+             var = adj.eval(node.slice)
+             var_name = var.label
+             var = Var(f"adj_{var_name}", type=var.type, constant=None, prefix=False)
+             return var
+
+         target, indices = adj.eval_subscript(node)
+
+         return adj.emit_indexing(target, indices)
+
      def emit_Assign(adj, node):
          if len(node.targets) != 1:
              raise WarpCodegenError("Assigning the same value to multiple variables is not supported")

-         lhs = node.targets[0]
+         # Check if the rhs corresponds to an unsupported construct.
+         # Tuples are supported in the context of assigning multiple variables
+         # at once, but not for simple assignments like `x = (1, 2, 3)`.
+         # Therefore, we need to catch this specific case here instead of
+         # more generally in `adj.eval()`.
+         if isinstance(node.value, ast.List):
+             raise WarpCodegenError(
+                 "List constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead."
+             )

-         if not isinstance(lhs, ast.Tuple):
-             # Check if the rhs corresponds to an unsupported construct.
-             # Tuples are supported in the context of assigning multiple variables
-             # at once, but not for simple assignments like `x = (1, 2, 3)`.
-             # Therefore, we need to catch this specific case here instead of
-             # more generally in `adj.eval()`.
-             if isinstance(node.value, ast.List):
-                 raise WarpCodegenError(
-                     "List constructs are not supported in kernels. Use vectors like `wp.vec3()` for small collections instead."
-                 )
+         lhs = node.targets[0]

-         # handle the case where we are assigning multiple output variables
-         if isinstance(lhs, ast.Tuple):
+         if isinstance(lhs, ast.Tuple) and isinstance(node.value, ast.Call):
              # record the expected number of outputs on the node
              # we do this so we can decide which function to
              # call based on the number of expected outputs
-             if isinstance(node.value, ast.Call):
-                 node.value.expects = len(lhs.elts)
+             node.value.expects = len(lhs.elts)

-         # evaluate values
-         if isinstance(node.value, ast.Tuple):
-             out = [adj.eval(v) for v in node.value.elts]
-         else:
-             out = adj.eval(node.value)
+         # evaluate rhs
+         if isinstance(lhs, ast.Tuple) and isinstance(node.value, ast.Tuple):
+             rhs = [adj.eval(v) for v in node.value.elts]
+         else:
+             rhs = adj.eval(node.value)
+
+         # handle the case where we are assigning multiple output variables
+         if isinstance(lhs, ast.Tuple):
+             subtype = getattr(rhs, "type", None)

-             subtype = getattr(out, "type", None)
              if isinstance(subtype, warp.types.tuple_t):
-                 if len(out.type.types) != len(lhs.elts):
+                 if len(rhs.type.types) != len(lhs.elts):
                      raise WarpCodegenError(
-                         f"Invalid number of values to unpack (expected {len(lhs.elts)}, got {len(out.type.types)})."
+                         f"Invalid number of values to unpack (expected {len(lhs.elts)}, got {len(rhs.type.types)})."
                      )
-                 target = out
-                 out = tuple(
-                     adj.add_builtin_call("extract", (target, adj.add_constant(i))) for i in range(len(lhs.elts))
-                 )
+                 rhs = tuple(adj.add_builtin_call("extract", (rhs, adj.add_constant(i))) for i in range(len(lhs.elts)))

              names = []
              for v in lhs.elts:
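
The net effect of `strip_indices`/`recurse_subscript`/`eval_indices`: chained array indexing such as `a[i][j]` is still coalesced into a single `a[i, j]` access, straddled forms like `arr2d[0][1, 2]` now raise a clear error, and vec/mat values gain compile-time slicing. A minimal sketch (slice semantics assumed from the bounds logic in `eval_indices` above):

    import warp as wp

    @wp.kernel
    def demo(a: wp.array2d(dtype=float), vs: wp.array(dtype=wp.vec3), out: wp.array(dtype=float)):
        tid = wp.tid()
        x = a[0][1]          # coalesced into a single a[0, 1] load
        head = vs[tid][0:2]  # constant-bound slice of a vec3 -> vec2
        out[tid] = x + head[0] + head[1]
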
@@ -2571,11 +2673,12 @@ class Adjoint:
                          "Multiple return functions can only assign to simple variables, e.g.: x, y = func()"
                      )

-             if len(names) != len(out):
+             if len(names) != len(rhs):
                  raise WarpCodegenError(
-                     f"Multiple return functions need to receive all their output values, incorrect number of values to unpack (expected {len(out)}, got {len(names)})"
+                     f"Multiple return functions need to receive all their output values, incorrect number of values to unpack (expected {len(rhs)}, got {len(names)})"
                  )

+             out = rhs
              for name, rhs in zip(names, out):
                  if name in adj.symbols:
                      if not types_equal(rhs.type, adj.symbols[name].type):
@@ -2587,8 +2690,6 @@ class Adjoint:

          # handles the case where we are assigning to an array index (e.g.: arr[i] = 2.0)
          elif isinstance(lhs, ast.Subscript):
-             rhs = adj.eval(node.value)
-
              if hasattr(lhs.value, "attr") and lhs.value.attr == "adjoint":
                  # handle adjoint of a variable, i.e. wp.adjoint[var]
                  lhs.slice.is_adjoint = True
@@ -2600,6 +2701,7 @@ class Adjoint:
              target, indices = adj.eval_subscript(lhs)

              target_type = strip_reference(target.type)
+             indices = adj.eval_indices(target_type, indices)

              if is_array(target_type):
                  adj.add_builtin_call("array_store", [target, *indices, rhs])
@@ -2621,14 +2723,11 @@ class Adjoint:
                  or type_is_transformation(target_type)
              ):
                  # recursively unwind AST, stopping at penultimate node
-                 node = lhs
-                 while hasattr(node, "value"):
-                     if hasattr(node.value, "value"):
-                         node = node.value
-                     else:
-                         break
+                 root = lhs
+                 while hasattr(root.value, "value"):
+                     root = root.value
                  # lhs is updating a variable adjoint (i.e. wp.adjoint[var])
-                 if hasattr(node, "attr") and node.attr == "adjoint":
+                 if hasattr(root, "attr") and root.attr == "adjoint":
                      attr = adj.add_builtin_call("index", [target, *indices])
                      adj.add_builtin_call("store", [attr, rhs])
                      return
@@ -2666,9 +2765,6 @@ class Adjoint:
              # symbol name
              name = lhs.id

-             # evaluate rhs
-             rhs = adj.eval(node.value)
-
              # check type matches if symbol already defined
              if name in adj.symbols:
                  if not types_equal(strip_reference(rhs.type), adj.symbols[name].type):
2689
2785
  adj.symbols[name] = out
2690
2786
 
2691
2787
  elif isinstance(lhs, ast.Attribute):
2692
- rhs = adj.eval(node.value)
2693
2788
  aggregate = adj.eval(lhs.value)
2694
2789
  aggregate_type = strip_reference(aggregate.type)
2695
2790
 
@@ -2777,9 +2872,9 @@ class Adjoint:
2777
2872
  new_node = ast.Assign(targets=[lhs], value=ast.BinOp(lhs, node.op, node.value))
2778
2873
  adj.eval(new_node)
2779
2874
 
2780
- if isinstance(lhs, ast.Subscript):
2781
- rhs = adj.eval(node.value)
2875
+ rhs = adj.eval(node.value)
2782
2876
 
2877
+ if isinstance(lhs, ast.Subscript):
2783
2878
  # wp.adjoint[var] appears in custom grad functions, and does not require
2784
2879
  # special consideration in the AugAssign case
2785
2880
  if hasattr(lhs.value, "attr") and lhs.value.attr == "adjoint":
@@ -2789,6 +2884,7 @@ class Adjoint:
              target, indices = adj.eval_subscript(lhs)

              target_type = strip_reference(target.type)
+             indices = adj.eval_indices(target_type, indices)

              if is_array(target_type):
                  # target_types int8, uint8, int16, uint16 are not suitable for atomic array accumulation
@@ -2861,7 +2957,6 @@ class Adjoint:

          elif isinstance(lhs, ast.Name):
              target = adj.eval(node.target)
-             rhs = adj.eval(node.value)

              if is_tile(target.type) and is_tile(rhs.type):
                  if isinstance(node.op, ast.Add):
@@ -3163,6 +3258,8 @@ class Adjoint:

          try:
              value = eval(code_to_eval, vars_dict)
+             if isinstance(value, (enum.IntEnum, enum.IntFlag)):
+                 value = int(value)
              if warp.config.verbose:
                  print(f"Evaluated static command: {static_code} = {value}")
          except NameError as e:
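
`wp.static()` expressions get the same coercion, so an enum-valued static expression now folds to a plain int. A small sketch:

    import enum

    import warp as wp

    class Mode(enum.IntEnum):
        FAST = 0
        ACCURATE = 1

    MODE = Mode.ACCURATE

    @wp.kernel
    def run(out: wp.array(dtype=int)):
        # the evaluated static value is coerced with int(value) before embedding
        out[0] = wp.static(MODE)
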
@@ -3373,6 +3470,11 @@ cuda_module_header = """
  #define WP_NO_CRT
  #include "builtin.h"

+ // Map wp.breakpoint() to a device brkpt at the call site so cuda-gdb attributes the stop to the generated .cu line
+ #if defined(__CUDACC__) && !defined(_MSC_VER)
+ #define __debugbreak() __brkpt()
+ #endif
+
  // avoid namespacing of float type for casting to float type, this is to avoid wp::float(x), which is not valid in C++
  #define float(x) cast_float(x)
  #define adj_float(x, adj_x, adj_ret) adj_cast_float(x, adj_x, adj_ret)
@@ -3410,6 +3512,12 @@ static CUDA_CALLABLE void adj_{name}({reverse_args})
  {{
  {reverse_body}}}

+ // Required when compiling adjoints.
+ CUDA_CALLABLE {name} add(const {name}& a, const {name}& b)
+ {{
+     return {name}();
+ }}
+
  CUDA_CALLABLE void adj_atomic_add({name}* p, {name} t)
  {{
  {atomic_add_body}}}
@@ -3490,7 +3598,8 @@ cuda_kernel_template_backward = """
  cpu_kernel_template_forward = """

  void {name}_cpu_kernel_forward(
-     {forward_args})
+     {forward_args},
+     wp_args_{name} *_wp_args)
  {{
  {forward_body}}}

@@ -3499,7 +3608,9 @@ void {name}_cpu_kernel_forward(
  cpu_kernel_template_backward = """

  void {name}_cpu_kernel_backward(
-     {reverse_args})
+     {reverse_args},
+     wp_args_{name} *_wp_args,
+     wp_args_{name} *_wp_adj_args)
  {{
  {reverse_body}}}

@@ -3511,15 +3622,15 @@ extern "C" {{

  // Python CPU entry points
  WP_API void {name}_cpu_forward(
-     {forward_args})
+     wp::launch_bounds_t dim,
+     wp_args_{name} *_wp_args)
  {{
      for (size_t task_index = 0; task_index < dim.size; ++task_index)
      {{
          // init shared memory allocator
          wp::tile_alloc_shared(0, true);

-         {name}_cpu_kernel_forward(
-             {forward_params});
+         {name}_cpu_kernel_forward(dim, task_index, _wp_args);

          // check shared memory allocator
          wp::tile_alloc_shared(0, false, true);
@@ -3536,15 +3647,16 @@ cpu_module_template_backward = """
  extern "C" {{

  WP_API void {name}_cpu_backward(
-     {reverse_args})
+     wp::launch_bounds_t dim,
+     wp_args_{name} *_wp_args,
+     wp_args_{name} *_wp_adj_args)
  {{
      for (size_t task_index = 0; task_index < dim.size; ++task_index)
      {{
          // initialize shared memory allocator
          wp::tile_alloc_shared(0, true);

-         {name}_cpu_kernel_backward(
-             {reverse_params});
+         {name}_cpu_kernel_backward(dim, task_index, _wp_args, _wp_adj_args);

          // check shared memory allocator
          wp::tile_alloc_shared(0, false, true);
@@ -3575,7 +3687,7 @@ def constant_str(value):
          # special case for float16, which is stored as uint16 in the ctypes.Array
          from warp.context import runtime

-         scalar_value = runtime.core.half_bits_to_float
+         scalar_value = runtime.core.wp_half_bits_to_float
      else:

          def scalar_value(x):
@@ -3713,8 +3825,17 @@ def codegen_func_forward(adj, func_type="kernel", device="cpu"):

      indent_block = " " * indent

-     # primal vars
      lines = []
+
+     # argument vars
+     if device == "cpu" and func_type == "kernel":
+         lines += ["//---------\n"]
+         lines += ["// argument vars\n"]
+
+         for var in adj.args:
+             lines += [f"{var.ctype()} {var.emit()} = _wp_args->{var.label};\n"]
+
+     # primal vars
      lines += ["//---------\n"]
      lines += ["// primal vars\n"]

@@ -3758,6 +3879,17 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):

      lines = []

+     # argument vars
+     if device == "cpu" and func_type == "kernel":
+         lines += ["//---------\n"]
+         lines += ["// argument vars\n"]
+
+         for var in adj.args:
+             lines += [f"{var.ctype()} {var.emit()} = _wp_args->{var.label};\n"]
+
+         for var in adj.args:
+             lines += [f"{var.ctype()} {var.emit_adj()} = _wp_adj_args->{var.label};\n"]
+
      # primal vars
      lines += ["//---------\n"]
      lines += ["// primal vars\n"]
@@ -3849,6 +3981,19 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
                  f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
                  f"but the code returns a value of type `{warp.context.type_str(adj.return_var[0].type)}`."
              )
+         elif (
+             isinstance(adj.return_var[0].type, warp.types.fixedarray)
+             and type(adj.arg_types["return"]) is warp.types.array
+         ):
+             # If the return statement yields a `fixedarray` while the function is annotated
+             # to return a standard `array`, then raise an error since the `fixedarray` storage
+             # allocated on the stack will be freed once the function exits, meaning that the
+             # resulting `array` instance will point to invalid data.
+             raise WarpCodegenError(
+                 f"The function `{adj.fun_name}` returns a fixed-size array "
+                 f"whereas it has its return type annotated as "
+                 f"`{warp.context.type_str(adj.arg_types['return'])}`."
+             )

      # Build line directive for function definition (subtract 1 to account for 1-indexing of AST line numbers)
      # This is used as a catch-all C-to-Python source line mapping for any code that does not have
@@ -3927,10 +4072,10 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
      if adj.custom_reverse_mode:
          reverse_body = "\t// user-defined adjoint code\n" + forward_body
      else:
-         if options.get("enable_backward", True):
+         if options.get("enable_backward", True) and adj.used_by_backward_kernel:
              reverse_body = codegen_func_reverse(adj, func_type="function", device=device)
          else:
-             reverse_body = '\t// reverse mode disabled (module option "enable_backward" is False)\n'
+             reverse_body = '\t// reverse mode disabled (module option "enable_backward" is False or no dependent kernel found with "enable_backward")\n'
      s += reverse_template.format(
          name=c_func_name,
          return_type=return_type,
4022
4167
 
4023
4168
  adj = kernel.adj
4024
4169
 
4170
+ args_struct = ""
4171
+ if device == "cpu":
4172
+ args_struct = f"struct wp_args_{kernel.get_mangled_name()} {{\n"
4173
+ for i in adj.args:
4174
+ args_struct += f" {i.ctype()} {i.label};\n"
4175
+ args_struct += "};\n"
4176
+
4025
4177
  # Build line directive for function definition (subtract 1 to account for 1-indexing of AST line numbers)
4026
4178
  # This is used as a catch-all C-to-Python source line mapping for any code that does not have
4027
4179
  # a direct mapping to a Python source line.
@@ -4047,9 +4199,9 @@ def codegen_kernel(kernel, device, options):
4047
4199
  forward_args = ["wp::launch_bounds_t dim"]
4048
4200
  if device == "cpu":
4049
4201
  forward_args.append("size_t task_index")
4050
-
4051
- for arg in adj.args:
4052
- forward_args.append(arg.ctype() + " var_" + arg.label)
4202
+ else:
4203
+ for arg in adj.args:
4204
+ forward_args.append(arg.ctype() + " var_" + arg.label)
4053
4205
 
4054
4206
  forward_body = codegen_func_forward(adj, func_type="kernel", device=device)
4055
4207
  template_fmt_args.update(
@@ -4066,17 +4218,16 @@ def codegen_kernel(kernel, device, options):
4066
4218
  reverse_args = ["wp::launch_bounds_t dim"]
4067
4219
  if device == "cpu":
4068
4220
  reverse_args.append("size_t task_index")
4069
-
4070
- for arg in adj.args:
4071
- reverse_args.append(arg.ctype() + " var_" + arg.label)
4072
-
4073
- for arg in adj.args:
4074
- # indexed array gradients are regular arrays
4075
- if isinstance(arg.type, indexedarray):
4076
- _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
4077
- reverse_args.append(_arg.ctype() + " adj_" + arg.label)
4078
- else:
4079
- reverse_args.append(arg.ctype() + " adj_" + arg.label)
4221
+ else:
4222
+ for arg in adj.args:
4223
+ reverse_args.append(arg.ctype() + " var_" + arg.label)
4224
+ for arg in adj.args:
4225
+ # indexed array gradients are regular arrays
4226
+ if isinstance(arg.type, indexedarray):
4227
+ _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
4228
+ reverse_args.append(_arg.ctype() + " adj_" + arg.label)
4229
+ else:
4230
+ reverse_args.append(arg.ctype() + " adj_" + arg.label)
4080
4231
 
4081
4232
  reverse_body = codegen_func_reverse(adj, func_type="kernel", device=device)
4082
4233
  template_fmt_args.update(
@@ -4088,7 +4239,7 @@ def codegen_kernel(kernel, device, options):
          template += template_backward

      s = template.format(**template_fmt_args)
-     return s
+     return args_struct + s


  def codegen_module(kernel, device, options):
@@ -4099,59 +4250,14 @@ def codegen_module(kernel, device, options):
      options = dict(options)
      options.update(kernel.options)

-     adj = kernel.adj
-
      template = ""
      template_fmt_args = {
          "name": kernel.get_mangled_name(),
      }

-     # build forward signature
-     forward_args = ["wp::launch_bounds_t dim"]
-     forward_params = ["dim", "task_index"]
-
-     for arg in adj.args:
-         if hasattr(arg.type, "_wp_generic_type_str_"):
-             # vectors and matrices are passed from Python by pointer
-             forward_args.append(f"const {arg.ctype()}* var_" + arg.label)
-             forward_params.append(f"*var_{arg.label}")
-         else:
-             forward_args.append(f"{arg.ctype()} var_{arg.label}")
-             forward_params.append("var_" + arg.label)
-
-     template_fmt_args.update(
-         {
-             "forward_args": indent(forward_args),
-             "forward_params": indent(forward_params, 3),
-         }
-     )
      template += cpu_module_template_forward

      if options["enable_backward"]:
-         # build reverse signature
-         reverse_args = [*forward_args]
-         reverse_params = [*forward_params]
-
-         for arg in adj.args:
-             if isinstance(arg.type, indexedarray):
-                 # indexed array gradients are regular arrays
-                 _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
-                 reverse_args.append(f"const {_arg.ctype()} adj_{arg.label}")
-                 reverse_params.append(f"adj_{_arg.label}")
-             elif hasattr(arg.type, "_wp_generic_type_str_"):
-                 # vectors and matrices are passed from Python by pointer
-                 reverse_args.append(f"const {arg.ctype()}* adj_{arg.label}")
-                 reverse_params.append(f"*adj_{arg.label}")
-             else:
-                 reverse_args.append(f"{arg.ctype()} adj_{arg.label}")
-                 reverse_params.append(f"adj_{arg.label}")
-
-         template_fmt_args.update(
-             {
-                 "reverse_args": indent(reverse_args),
-                 "reverse_params": indent(reverse_params, 3),
-             }
-         )
          template += cpu_module_template_backward

      s = template.format(**template_fmt_args)