warp-lang 1.5.1-py3-none-win_amd64.whl → 1.6.1-py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (131)
  1. warp/__init__.py +5 -0
  2. warp/autograd.py +414 -191
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +40 -12
  6. warp/build_dll.py +13 -6
  7. warp/builtins.py +1077 -481
  8. warp/codegen.py +250 -122
  9. warp/config.py +65 -21
  10. warp/context.py +500 -149
  11. warp/examples/assets/square_cloth.usd +0 -0
  12. warp/examples/benchmarks/benchmark_gemm.py +27 -18
  13. warp/examples/benchmarks/benchmark_interop_paddle.py +3 -3
  14. warp/examples/benchmarks/benchmark_interop_torch.py +3 -3
  15. warp/examples/core/example_marching_cubes.py +1 -1
  16. warp/examples/core/example_mesh.py +1 -1
  17. warp/examples/core/example_torch.py +18 -34
  18. warp/examples/core/example_wave.py +1 -1
  19. warp/examples/fem/example_apic_fluid.py +1 -0
  20. warp/examples/fem/example_mixed_elasticity.py +1 -1
  21. warp/examples/optim/example_bounce.py +1 -1
  22. warp/examples/optim/example_cloth_throw.py +1 -1
  23. warp/examples/optim/example_diffray.py +4 -15
  24. warp/examples/optim/example_drone.py +1 -1
  25. warp/examples/optim/example_softbody_properties.py +392 -0
  26. warp/examples/optim/example_trajectory.py +1 -3
  27. warp/examples/optim/example_walker.py +5 -0
  28. warp/examples/sim/example_cartpole.py +0 -2
  29. warp/examples/sim/example_cloth_self_contact.py +314 -0
  30. warp/examples/sim/example_granular_collision_sdf.py +4 -5
  31. warp/examples/sim/example_jacobian_ik.py +0 -2
  32. warp/examples/sim/example_quadruped.py +5 -2
  33. warp/examples/tile/example_tile_cholesky.py +79 -0
  34. warp/examples/tile/example_tile_convolution.py +2 -2
  35. warp/examples/tile/example_tile_fft.py +2 -2
  36. warp/examples/tile/example_tile_filtering.py +3 -3
  37. warp/examples/tile/example_tile_matmul.py +4 -4
  38. warp/examples/tile/example_tile_mlp.py +12 -12
  39. warp/examples/tile/example_tile_nbody.py +191 -0
  40. warp/examples/tile/example_tile_walker.py +319 -0
  41. warp/math.py +147 -0
  42. warp/native/array.h +12 -0
  43. warp/native/builtin.h +0 -1
  44. warp/native/bvh.cpp +149 -70
  45. warp/native/bvh.cu +287 -68
  46. warp/native/bvh.h +195 -85
  47. warp/native/clang/clang.cpp +6 -2
  48. warp/native/crt.h +1 -0
  49. warp/native/cuda_util.cpp +35 -0
  50. warp/native/cuda_util.h +5 -0
  51. warp/native/exports.h +40 -40
  52. warp/native/intersect.h +17 -0
  53. warp/native/mat.h +57 -3
  54. warp/native/mathdx.cpp +19 -0
  55. warp/native/mesh.cpp +25 -8
  56. warp/native/mesh.cu +153 -101
  57. warp/native/mesh.h +482 -403
  58. warp/native/quat.h +40 -0
  59. warp/native/solid_angle.h +7 -0
  60. warp/native/sort.cpp +85 -0
  61. warp/native/sort.cu +34 -0
  62. warp/native/sort.h +3 -1
  63. warp/native/spatial.h +11 -0
  64. warp/native/tile.h +1189 -664
  65. warp/native/tile_reduce.h +8 -6
  66. warp/native/vec.h +41 -0
  67. warp/native/warp.cpp +8 -1
  68. warp/native/warp.cu +263 -40
  69. warp/native/warp.h +19 -5
  70. warp/optim/linear.py +22 -4
  71. warp/render/render_opengl.py +132 -59
  72. warp/render/render_usd.py +10 -2
  73. warp/sim/__init__.py +6 -1
  74. warp/sim/collide.py +289 -32
  75. warp/sim/import_urdf.py +20 -5
  76. warp/sim/integrator_euler.py +25 -7
  77. warp/sim/integrator_featherstone.py +147 -35
  78. warp/sim/integrator_vbd.py +842 -40
  79. warp/sim/model.py +173 -112
  80. warp/sim/render.py +2 -2
  81. warp/stubs.py +249 -116
  82. warp/tape.py +28 -30
  83. warp/tests/aux_test_module_unload.py +15 -0
  84. warp/tests/{test_sim_grad.py → flaky_test_sim_grad.py} +104 -63
  85. warp/tests/test_array.py +100 -0
  86. warp/tests/test_assert.py +242 -0
  87. warp/tests/test_codegen.py +14 -61
  88. warp/tests/test_collision.py +8 -8
  89. warp/tests/test_examples.py +16 -1
  90. warp/tests/test_grad_debug.py +87 -2
  91. warp/tests/test_hash_grid.py +1 -1
  92. warp/tests/test_ipc.py +116 -0
  93. warp/tests/test_launch.py +77 -26
  94. warp/tests/test_mat.py +213 -168
  95. warp/tests/test_math.py +47 -1
  96. warp/tests/test_matmul.py +11 -7
  97. warp/tests/test_matmul_lite.py +4 -4
  98. warp/tests/test_mesh.py +84 -60
  99. warp/tests/test_mesh_query_aabb.py +165 -0
  100. warp/tests/test_mesh_query_point.py +328 -286
  101. warp/tests/test_mesh_query_ray.py +134 -121
  102. warp/tests/test_mlp.py +2 -2
  103. warp/tests/test_operators.py +43 -0
  104. warp/tests/test_overwrite.py +6 -5
  105. warp/tests/test_quat.py +77 -0
  106. warp/tests/test_reload.py +29 -0
  107. warp/tests/test_sim_grad_bounce_linear.py +204 -0
  108. warp/tests/test_static.py +16 -0
  109. warp/tests/test_tape.py +25 -0
  110. warp/tests/test_tile.py +134 -191
  111. warp/tests/test_tile_load.py +399 -0
  112. warp/tests/test_tile_mathdx.py +61 -8
  113. warp/tests/test_tile_mlp.py +17 -17
  114. warp/tests/test_tile_reduce.py +24 -18
  115. warp/tests/test_tile_shared_memory.py +66 -17
  116. warp/tests/test_tile_view.py +165 -0
  117. warp/tests/test_torch.py +35 -0
  118. warp/tests/test_utils.py +36 -24
  119. warp/tests/test_vec.py +110 -0
  120. warp/tests/unittest_suites.py +29 -4
  121. warp/tests/unittest_utils.py +30 -11
  122. warp/thirdparty/unittest_parallel.py +5 -2
  123. warp/types.py +419 -111
  124. warp/utils.py +9 -5
  125. {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/METADATA +86 -45
  126. {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/RECORD +129 -118
  127. {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/WHEEL +1 -1
  128. warp/examples/benchmarks/benchmark_tile.py +0 -179
  129. warp/native/tile_gemm.h +0 -341
  130. {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/LICENSE.md +0 -0
  131. {warp_lang-1.5.1.dist-info → warp_lang-1.6.1.dist-info}/top_level.txt +0 -0
warp/codegen.py CHANGED
@@ -282,9 +282,9 @@ class StructInstance:
             else:
                 # wp.array
                 assert isinstance(value, array)
-                assert types_equal(
-                    value.dtype, var.type.dtype
-                ), f"assign to struct member variable {name} failed, expected type {type_repr(var.type.dtype)}, got type {type_repr(value.dtype)}"
+                assert types_equal(value.dtype, var.type.dtype), (
+                    f"assign to struct member variable {name} failed, expected type {type_repr(var.type.dtype)}, got type {type_repr(value.dtype)}"
+                )
                 setattr(self._ctype, name, value.__ctype__())
 
         elif isinstance(var.type, Struct):
@@ -606,6 +606,9 @@ def compute_type_str(base_name, template_params):
                 return "bool"
             else:
                 return f"wp::{p.__name__}"
+        elif is_tile(p):
+            return p.ctype()
+
         return p.__name__
 
     return f"{base_name}<{','.join(map(param2str, template_params))}>"
@@ -947,7 +950,7 @@ class Adjoint:
         total_shared = 0
 
         for var in adj.variables:
-            if is_tile(var.type) and var.type.storage == "shared":
+            if is_tile(var.type) and var.type.storage == "shared" and var.type.owner:
                 total_shared += var.type.size_in_bytes()
 
         return total_shared + adj.max_required_extra_shared_memory
@@ -1139,6 +1142,9 @@ class Adjoint:
         if isinstance(var, (Reference, warp.context.Function)):
             return var
 
+        if isinstance(var, int):
+            return adj.add_constant(var)
+
         if var.label is None:
             return adj.add_var(var.type, var.constant)
 
@@ -1349,8 +1355,9 @@
         # which allows for some more advanced resolution to be performed,
         # for example by checking whether an argument corresponds to
         # a literal value or references a variable.
+        extra_shared_memory = 0
         if func.lto_dispatch_func is not None:
-            func_args, template_args, ltoirs = func.lto_dispatch_func(
+            func_args, template_args, ltoirs, extra_shared_memory = func.lto_dispatch_func(
                 func.input_types, return_type, output_list, bound_args, options=adj.builder_options, builder=adj.builder
             )
         elif func.dispatch_func is not None:
@@ -1424,7 +1431,9 @@
         # update our smem roofline requirements based on any
         # shared memory required by the dependent function call
         if not func.is_builtin():
-            adj.alloc_shared_extra(func.adj.get_total_required_shared())
+            adj.alloc_shared_extra(func.adj.get_total_required_shared() + extra_shared_memory)
+        else:
+            adj.alloc_shared_extra(extra_shared_memory)
 
         return output
 
@@ -1527,7 +1536,8 @@
         # zero adjoints
         for i in body_block.vars:
             if is_tile(i.type):
-                reverse.append(adj.indentation + f"\t{i.emit_adj()}.grad_zero();")
+                if i.type.owner:
+                    reverse.append(adj.indentation + f"\t{i.emit_adj()}.grad_zero();")
             else:
                 reverse.append(adj.indentation + f"\t{i.emit_adj()} = {{}};")
 
@@ -1857,6 +1867,17 @@
             # stubbed @wp.native_func
             return
 
+    def emit_Assert(adj, node):
+        # eval condition
+        cond = adj.eval(node.test)
+        cond = adj.load(cond)
+
+        source_segment = ast.get_source_segment(adj.source, node)
+        # If a message was provided with the assert, " marks can interfere with the generated code
+        escaped_segment = source_segment.replace('"', '\\"')
+
+        adj.add_forward(f'assert(("{escaped_segment}",{cond.emit()}));')
+
     def emit_NameConstant(adj, node):
         if node.value:
             return adj.add_constant(node.value)
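The new emit_Assert handler maps Python assert statements onto the C assert() macro, embedding the quoted source segment in the failure message. A minimal kernel-level sketch of what this enables (illustrative; C asserts are normally compiled out of optimized builds, so failures are expected to trigger only in debug-mode builds):

import warp as wp

@wp.kernel
def expect_positive(values: wp.array(dtype=float)):
    i = wp.tid()
    # the generated code becomes assert(("<source segment>", cond)),
    # so a failure reports the original expression text
    assert values[i] > 0.0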
@@ -1900,12 +1921,25 @@
 
         name = builtin_operators[type(node.op)]
 
+        try:
+            # Check if there is any user-defined overload for this operator
+            user_func = adj.resolve_external_reference(name)
+            if isinstance(user_func, warp.context.Function):
+                return adj.add_call(user_func, (left, right), {}, {})
+        except WarpCodegenError:
+            pass
+
         return adj.add_builtin_call(name, [left, right])
 
     def emit_UnaryOp(adj, node):
         # evaluate unary op arguments
         arg = adj.eval(node.operand)
 
+        # evaluate expression to a compile-time constant if arg is a constant
+        if arg.constant is not None and math.isfinite(arg.constant):
+            if isinstance(node.op, ast.USub):
+                return adj.add_constant(-arg.constant)
+
         name = builtin_operators[type(node.op)]
 
         return adj.add_builtin_call(name, [arg])
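emit_BinOp now gives user-defined functions a chance to intercept an operator before the builtin is used: the operator is looked up by its builtin name ("add" for +, and so on) among external references, and emit_UnaryOp additionally folds negation of finite constants into a compile-time constant. A sketch of what this resolution path appears to allow; the Complex struct and the add overload are illustrative assumptions, not taken from the package:

import warp as wp

@wp.struct
class Complex:
    re: float
    im: float

# a function named after the operator's builtin name ("add") with matching
# argument types is picked up by resolve_external_reference() above
@wp.func
def add(a: Complex, b: Complex):
    c = Complex()
    c.re = a.re + b.re
    c.im = a.im + b.im
    return c

@wp.kernel
def combine(out: wp.array(dtype=Complex)):
    a = Complex()
    b = Complex()
    out[0] = a + b  # resolves to the add() overload above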
@@ -2244,15 +2278,22 @@
         out = adj.add_call(func, args, kwargs, type_args, min_outputs=min_outputs)
 
         if warp.config.verify_autograd_array_access:
+            # Extract the types and values passed as arguments to the function call.
+            arg_types = tuple(strip_reference(get_arg_type(x)) for x in args)
+            kwarg_types = {k: strip_reference(get_arg_type(v)) for k, v in kwargs.items()}
+
+            # Resolve the exact function signature among any existing overload.
+            resolved_func = adj.resolve_func(func, arg_types, kwarg_types, min_outputs)
+
             # update arg read/write states according to what happens to that arg in the called function
-            if hasattr(func, "adj"):
+            if hasattr(resolved_func, "adj"):
                 for i, arg in enumerate(args):
-                    if func.adj.args[i].is_write:
+                    if resolved_func.adj.args[i].is_write:
                         kernel_name = adj.fun_name
                         filename = adj.filename
                         lineno = adj.lineno + adj.fun_lineno
                         arg.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
-                    if func.adj.args[i].is_read:
+                    if resolved_func.adj.args[i].is_read:
                         arg.mark_read()
 
         return out
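The read/write marking above now resolves the exact overload before inspecting the per-argument is_read/is_write flags, which matters only when array-access verification is switched on. The flag is the one already present in warp.config:

import warp as wp

# opt-in: Warp records kernel array reads and writes on the tape and warns
# about overwrites that would corrupt gradients; resolving the concrete
# overload above keeps these per-argument flags accurate
wp.config.verify_autograd_array_access = True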
@@ -2350,12 +2391,16 @@
             out.is_write = target.is_write
 
         elif is_tile(target_type):
-            if len(indices) == 2:
+            if len(indices) == len(target_type.shape):
                 # handles extracting a single element from a tile
                 out = adj.add_builtin_call("tile_extract", [target, *indices])
-            else:
+            elif len(indices) < len(target_type.shape):
                 # handles tile views
-                out = adj.add_builtin_call("tile_view", [target, *indices])
+                out = adj.add_builtin_call("tile_view", [target, indices])
+            else:
+                raise RuntimeError(
+                    f"Incorrect number of indices specified for a tile view/extract, got {len(indices)} indices for a {len(target_type.shape)} dimensional tile."
+                )
 
         else:
             # handles non-array type indexing, e.g: vec3, mat33, etc
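Tile subscripting is now rank-aware: as many indices as the tile has dimensions lowers to tile_extract, fewer indices lower to tile_view, and too many indices raise an error. A rough sketch; the shape-based tile_load signature is an assumption drawn from the 1.6 tile API:

import warp as wp

@wp.kernel
def tile_index_demo(a: wp.array2d(dtype=float), out: wp.array(dtype=float)):
    t = wp.tile_load(a, shape=(4, 8))
    row = t[1]      # one index on a 2D tile -> tile_view of a row
    x = t[1, 2]     # two indices on a 2D tile -> tile_extract of an element
    out[0] = x + row[0]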
@@ -2447,6 +2492,9 @@
 
             target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
 
+        elif is_tile(target_type):
+            adj.add_builtin_call("assign", [target, *indices, rhs])
+
         elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
             # recursively unwind AST, stopping at penultimate node
             node = lhs
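Subscript assignment gains a tile path as well, so individual tile elements can be written in place (same shape-based tile_load/tile_store signature assumption as above):

import warp as wp

@wp.kernel
def tile_write_demo(a: wp.array2d(dtype=float)):
    t = wp.tile_load(a, shape=(4, 8))
    t[0, 0] = 1.0          # lowers to the "assign" builtin on the tile
    wp.tile_store(a, t)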
@@ -2473,15 +2521,18 @@
                 print(
                     f"Warning: mutating {node_source} in function {adj.fun_name} at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n"
                 )
-
             else:
-                out = adj.add_builtin_call("assign", [target, *indices, rhs])
-
-                # re-point target symbol to out var
-                for id in adj.symbols:
-                    if adj.symbols[id] == target:
-                        adj.symbols[id] = out
-                        break
+                if adj.builder_options.get("enable_backward", True):
+                    out = adj.add_builtin_call("assign", [target, *indices, rhs])
+
+                    # re-point target symbol to out var
+                    for id in adj.symbols:
+                        if adj.symbols[id] == target:
+                            adj.symbols[id] = out
+                            break
+                else:
+                    attr = adj.add_builtin_call("index", [target, *indices])
+                    adj.add_builtin_call("store", [attr, rhs])
 
         else:
             raise WarpCodegenError(
@@ -2518,22 +2569,23 @@
 
         # assigning to a vector or quaternion component
         if type_is_vector(aggregate_type) or type_is_quaternion(aggregate_type):
-            # TODO: handle wp.adjoint case
-
             index = adj.vector_component_index(lhs.attr, aggregate_type)
 
-            # TODO: array vec component case
             if is_reference(aggregate.type):
                 attr = adj.add_builtin_call("indexref", [aggregate, index])
                 adj.add_builtin_call("store", [attr, rhs])
             else:
-                out = adj.add_builtin_call("assign", [aggregate, index, rhs])
-
-                # re-point target symbol to out var
-                for id in adj.symbols:
-                    if adj.symbols[id] == aggregate:
-                        adj.symbols[id] = out
-                        break
+                if adj.builder_options.get("enable_backward", True):
+                    out = adj.add_builtin_call("assign", [aggregate, index, rhs])
+
+                    # re-point target symbol to out var
+                    for id in adj.symbols:
+                        if adj.symbols[id] == aggregate:
+                            adj.symbols[id] = out
+                            break
+                else:
+                    attr = adj.add_builtin_call("index", [aggregate, index])
+                    adj.add_builtin_call("store", [attr, rhs])
 
         else:
             attr = adj.emit_Attribute(lhs)
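Both component-assignment paths above now branch on the module's enable_backward option: with the backward pass enabled, the differentiable assign plus symbol re-pointing is emitted as before; with it disabled, a cheaper direct index/store is generated. Typical usage via the public knob, wp.set_module_options:

import warp as wp

# no adjoint kernels are needed in this module, so component writes can
# compile to a plain store instead of the differentiable assign sequence
wp.set_module_options({"enable_backward": False})

@wp.kernel
def flatten(points: wp.array(dtype=wp.vec3)):
    i = wp.tid()
    p = points[i]
    p.y = 0.0          # attribute write takes the index/store path here
    points[i] = p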
@@ -2637,10 +2689,14 @@
             make_new_assign_statement()
             return
 
-        # TODO
         elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
-            make_new_assign_statement()
-            return
+            if isinstance(node.op, ast.Add):
+                adj.add_builtin_call("augassign_add", [target, *indices, rhs])
+            elif isinstance(node.op, ast.Sub):
+                adj.add_builtin_call("augassign_sub", [target, *indices, rhs])
+            else:
+                make_new_assign_statement()
+                return
 
         else:
             raise WarpCodegenError("Can only subscript in-place assign array, vector, quaternion, and matrix types")
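In-place += and -= on vector, quaternion, and matrix components now lower to the dedicated augassign_add/augassign_sub builtins instead of being rewritten as full assignments; other in-place operators still fall back to a regular assignment statement. Sketch:

import warp as wp

@wp.kernel
def nudge(points: wp.array(dtype=wp.vec3)):
    i = wp.tid()
    p = points[i]
    p[1] += 0.1    # emits augassign_add on the component
    p[2] -= 0.1    # emits augassign_sub on the component
    points[i] = p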
@@ -2688,6 +2744,7 @@
         ast.Tuple: emit_Tuple,
         ast.Pass: emit_Pass,
         ast.Ellipsis: emit_Ellipsis,
+        ast.Assert: emit_Assert,
     }
 
     def eval(adj, node):
@@ -2850,11 +2907,62 @@
         if static_code is None:
             raise WarpCodegenError("Error extracting source code from wp.static() expression")
 
+        # Since this is an expression, we can enforce it to be defined on a single line.
+        static_code = static_code.replace("\n", "")
+
         vars_dict = adj.get_static_evaluation_context()
         # add constant variables to the static call context
         constant_vars = {k: v.constant for k, v in adj.symbols.items() if isinstance(v, Var) and v.constant is not None}
         vars_dict.update(constant_vars)
 
+        # Replace all constant `len()` expressions with their value.
+        if "len" in static_code:
+
+            def eval_len(obj):
+                if type_is_vector(obj):
+                    return obj._length_
+                elif type_is_quaternion(obj):
+                    return obj._length_
+                elif type_is_matrix(obj):
+                    return obj._shape_[0]
+                elif type_is_transformation(obj):
+                    return obj._length_
+                elif is_tile(obj):
+                    return obj.shape[0]
+
+                return len(obj)
+
+            len_expr_ctx = vars_dict.copy()
+            constant_types = {k: v.type for k, v in adj.symbols.items() if isinstance(v, Var) and v.type is not None}
+            len_expr_ctx.update(constant_types)
+            len_expr_ctx.update({"len": eval_len})
+
+            # We want to replace the expression code in-place,
+            # so reparse it to get the correct column info.
+            len_value_locs = []
+            expr_tree = ast.parse(static_code)
+            assert len(expr_tree.body) == 1 and isinstance(expr_tree.body[0], ast.Expr)
+            expr_root = expr_tree.body[0].value
+            for expr_node in ast.walk(expr_root):
+                if isinstance(expr_node, ast.Call) and expr_node.func.id == "len" and len(expr_node.args) == 1:
+                    len_expr = static_code[expr_node.col_offset : expr_node.end_col_offset]
+                    try:
+                        len_value = eval(len_expr, len_expr_ctx)
+                    except Exception:
+                        pass
+                    else:
+                        len_value_locs.append((len_value, expr_node.col_offset, expr_node.end_col_offset))
+
+            if len_value_locs:
+                new_static_code = ""
+                loc = 0
+                for value, start, end in len_value_locs:
+                    new_static_code += f"{static_code[loc:start]}{value}"
+                    loc = end
+
+                new_static_code += static_code[len_value_locs[-1][2] :]
+                static_code = new_static_code
+
         try:
             value = eval(static_code, vars_dict)
             if warp.config.verbose:
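wp.static() expressions are now collapsed onto one line, and any len() call inside them whose argument has a known length (vectors, quaternions, matrices, transformations, tiles, or plain Python sequences) is substituted with its constant value before evaluation. A sketch of the resulting idiom; the loop bound resolves at code-generation time:

import warp as wp

@wp.func
def component_sum(v: wp.vec3):
    s = float(0.0)
    for i in range(wp.static(len(v))):  # len(v) is rewritten to 3 above
        s = s + v[i]
    return s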
@@ -3139,7 +3247,7 @@ static CUDA_CALLABLE void adj_{name}(
 
 """
 
-cuda_kernel_template = """
+cuda_kernel_template_forward = """
 
 extern "C" __global__ void {name}_cuda_kernel_forward(
     {forward_args})
@@ -3154,6 +3262,10 @@ extern "C" __global__ void {name}_cuda_kernel_forward(
         {forward_body} }}
 }}
 
+"""
+
+cuda_kernel_template_backward = """
+
 extern "C" __global__ void {name}_cuda_kernel_backward(
     {reverse_args})
 {{
@@ -3169,13 +3281,17 @@
 
 """
 
-cpu_kernel_template = """
+cpu_kernel_template_forward = """
 
 void {name}_cpu_kernel_forward(
     {forward_args})
 {{
     {forward_body}}}
 
+"""
+
+cpu_kernel_template_backward = """
+
 void {name}_cpu_kernel_backward(
     {reverse_args})
 {{
@@ -3183,7 +3299,7 @@ void {name}_cpu_kernel_backward(
     {reverse_body}}}
 
 """
 
-cpu_module_template = """
+cpu_module_template_forward = """
 
 extern "C" {{
 
@@ -3198,6 +3314,14 @@ WP_API void {name}_cpu_forward(
     }}
 }}
 
+}} // extern C
+
+"""
+
+cpu_module_template_backward = """
+
+extern "C" {{
+
 WP_API void {name}_cpu_backward(
     {reverse_args})
 {{
@@ -3212,36 +3336,6 @@ WP_API void {name}_cpu_backward(
 
 """
 
-cuda_module_header_template = """
-
-extern "C" {{
-
-// Python CUDA entry points
-WP_API void {name}_cuda_forward(
-    void* stream,
-    {forward_args});
-
-WP_API void {name}_cuda_backward(
-    void* stream,
-    {reverse_args});
-
-}} // extern C
-"""
-
-cpu_module_header_template = """
-
-extern "C" {{
-
-// Python CPU entry points
-WP_API void {name}_cpu_forward(
-    {forward_args});
-
-WP_API void {name}_cpu_backward(
-    {reverse_args});
-
-}} // extern C
-"""
-
 
 # converts a constant Python value to equivalent C-repr
 def constant_str(value):
@@ -3679,59 +3773,82 @@ def codegen_kernel(kernel, device, options):
 
     adj = kernel.adj
 
-    forward_args = ["wp::launch_bounds_t dim"]
-    reverse_args = ["wp::launch_bounds_t dim"]
+    if device == "cpu":
+        template_forward = cpu_kernel_template_forward
+        template_backward = cpu_kernel_template_backward
+    elif device == "cuda":
+        template_forward = cuda_kernel_template_forward
+        template_backward = cuda_kernel_template_backward
+    else:
+        raise ValueError(f"Device {device} is not supported")
+
+    template = ""
+    template_fmt_args = {
+        "name": kernel.get_mangled_name(),
+    }
 
+    # build forward signature
+    forward_args = ["wp::launch_bounds_t dim"]
     if device == "cpu":
         forward_args.append("size_t task_index")
-        reverse_args.append("size_t task_index")
 
-    # forward args
     for arg in adj.args:
         forward_args.append(arg.ctype() + " var_" + arg.label)
-        reverse_args.append(arg.ctype() + " var_" + arg.label)
 
-    # reverse args
-    for arg in adj.args:
-        # indexed array gradients are regular arrays
-        if isinstance(arg.type, indexedarray):
-            _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
-            reverse_args.append(_arg.ctype() + " adj_" + arg.label)
-        else:
-            reverse_args.append(arg.ctype() + " adj_" + arg.label)
-
-    # codegen body
     forward_body = codegen_func_forward(adj, func_type="kernel", device=device)
+    template_fmt_args.update(
+        {
+            "forward_args": indent(forward_args),
+            "forward_body": forward_body,
+        }
+    )
+    template += template_forward
 
     if options["enable_backward"]:
-        reverse_body = codegen_func_reverse(adj, func_type="kernel", device=device)
-    else:
-        reverse_body = ""
+        # build reverse signature
+        reverse_args = ["wp::launch_bounds_t dim"]
+        if device == "cpu":
+            reverse_args.append("size_t task_index")
 
-    if device == "cpu":
-        template = cpu_kernel_template
-    elif device == "cuda":
-        template = cuda_kernel_template
-    else:
-        raise ValueError(f"Device {device} is not supported")
+        for arg in adj.args:
+            reverse_args.append(arg.ctype() + " var_" + arg.label)
 
-    s = template.format(
-        name=kernel.get_mangled_name(),
-        forward_args=indent(forward_args),
-        reverse_args=indent(reverse_args),
-        forward_body=forward_body,
-        reverse_body=reverse_body,
-    )
+        for arg in adj.args:
+            # indexed array gradients are regular arrays
+            if isinstance(arg.type, indexedarray):
+                _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
+                reverse_args.append(_arg.ctype() + " adj_" + arg.label)
+            else:
+                reverse_args.append(arg.ctype() + " adj_" + arg.label)
 
+        reverse_body = codegen_func_reverse(adj, func_type="kernel", device=device)
+        template_fmt_args.update(
+            {
+                "reverse_args": indent(reverse_args),
+                "reverse_body": reverse_body,
+            }
+        )
+        template += template_backward
+
+    s = template.format(**template_fmt_args)
     return s
 
 
-def codegen_module(kernel, device="cpu"):
+def codegen_module(kernel, device, options):
     if device != "cpu":
         return ""
 
+    # Update the module's options with the ones defined on the kernel, if any.
+    options = dict(options)
+    options.update(kernel.options)
+
     adj = kernel.adj
 
+    template = ""
+    template_fmt_args = {
+        "name": kernel.get_mangled_name(),
+    }
+
     # build forward signature
     forward_args = ["wp::launch_bounds_t dim"]
     forward_params = ["dim", "task_index"]
@@ -3745,29 +3862,40 @@ def codegen_module(kernel, device="cpu"):
         forward_args.append(f"{arg.ctype()} var_{arg.label}")
         forward_params.append("var_" + arg.label)
 
-    # build reverse signature
-    reverse_args = [*forward_args]
-    reverse_params = [*forward_params]
+    template_fmt_args.update(
+        {
+            "forward_args": indent(forward_args),
+            "forward_params": indent(forward_params, 3),
+        }
+    )
+    template += cpu_module_template_forward
 
-    for arg in adj.args:
-        if isinstance(arg.type, indexedarray):
-            # indexed array gradients are regular arrays
-            _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
-            reverse_args.append(f"const {_arg.ctype()} adj_{arg.label}")
-            reverse_params.append(f"adj_{_arg.label}")
-        elif hasattr(arg.type, "_wp_generic_type_str_"):
-            # vectors and matrices are passed from Python by pointer
-            reverse_args.append(f"const {arg.ctype()}* adj_{arg.label}")
-            reverse_params.append(f"*adj_{arg.label}")
-        else:
-            reverse_args.append(f"{arg.ctype()} adj_{arg.label}")
-            reverse_params.append(f"adj_{arg.label}")
+    if options["enable_backward"]:
+        # build reverse signature
+        reverse_args = [*forward_args]
+        reverse_params = [*forward_params]
 
-    s = cpu_module_template.format(
-        name=kernel.get_mangled_name(),
-        forward_args=indent(forward_args),
-        reverse_args=indent(reverse_args),
-        forward_params=indent(forward_params, 3),
-        reverse_params=indent(reverse_params, 3),
-    )
+        for arg in adj.args:
+            if isinstance(arg.type, indexedarray):
+                # indexed array gradients are regular arrays
+                _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
+                reverse_args.append(f"const {_arg.ctype()} adj_{arg.label}")
+                reverse_params.append(f"adj_{_arg.label}")
+            elif hasattr(arg.type, "_wp_generic_type_str_"):
+                # vectors and matrices are passed from Python by pointer
+                reverse_args.append(f"const {arg.ctype()}* adj_{arg.label}")
+                reverse_params.append(f"*adj_{arg.label}")
+            else:
+                reverse_args.append(f"{arg.ctype()} adj_{arg.label}")
+                reverse_params.append(f"adj_{arg.label}")
+
+        template_fmt_args.update(
+            {
+                "reverse_args": indent(reverse_args),
+                "reverse_params": indent(reverse_params, 3),
+            }
+        )
+        template += cpu_module_template_backward
+
+    s = template.format(**template_fmt_args)
    return s
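codegen_kernel and codegen_module now assemble only the templates a module needs and format them in a single pass, so building with enable_backward off drops the backward entry points from the generated source entirely rather than emitting empty bodies. The pattern, reduced to its essentials (a standalone illustration, not warp code):

# conditional template assembly: concatenate only the pieces that are
# wanted, collect their format arguments alongside, then format once
forward_tpl = "void {name}_forward() {{ {forward_body} }}\n"
backward_tpl = "void {name}_backward() {{ {reverse_body} }}\n"

def build_source(name, enable_backward):
    template = forward_tpl
    fmt_args = {"name": name, "forward_body": "/* ... */"}
    if enable_backward:
        template += backward_tpl
        fmt_args["reverse_body"] = "/* ... */"
    # format() only needs the placeholders actually present in `template`,
    # which holds because the backward keys are added together with the
    # backward template
    return template.format(**fmt_args)

print(build_source("demo", enable_backward=False))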