warp-lang 1.5.1__py3-none-macosx_10_13_universal2.whl → 1.6.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (123) hide show
  1. warp/__init__.py +5 -0
  2. warp/autograd.py +414 -191
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +40 -12
  6. warp/build_dll.py +13 -6
  7. warp/builtins.py +1076 -480
  8. warp/codegen.py +240 -119
  9. warp/config.py +1 -1
  10. warp/context.py +298 -84
  11. warp/examples/assets/square_cloth.usd +0 -0
  12. warp/examples/benchmarks/benchmark_gemm.py +27 -18
  13. warp/examples/benchmarks/benchmark_interop_paddle.py +3 -3
  14. warp/examples/benchmarks/benchmark_interop_torch.py +3 -3
  15. warp/examples/core/example_torch.py +18 -34
  16. warp/examples/fem/example_apic_fluid.py +1 -0
  17. warp/examples/fem/example_mixed_elasticity.py +1 -1
  18. warp/examples/optim/example_bounce.py +1 -1
  19. warp/examples/optim/example_cloth_throw.py +1 -1
  20. warp/examples/optim/example_diffray.py +4 -15
  21. warp/examples/optim/example_drone.py +1 -1
  22. warp/examples/optim/example_softbody_properties.py +392 -0
  23. warp/examples/optim/example_trajectory.py +1 -3
  24. warp/examples/optim/example_walker.py +5 -0
  25. warp/examples/sim/example_cartpole.py +0 -2
  26. warp/examples/sim/example_cloth_self_contact.py +260 -0
  27. warp/examples/sim/example_granular_collision_sdf.py +4 -5
  28. warp/examples/sim/example_jacobian_ik.py +0 -2
  29. warp/examples/sim/example_quadruped.py +5 -2
  30. warp/examples/tile/example_tile_cholesky.py +79 -0
  31. warp/examples/tile/example_tile_convolution.py +2 -2
  32. warp/examples/tile/example_tile_fft.py +2 -2
  33. warp/examples/tile/example_tile_filtering.py +3 -3
  34. warp/examples/tile/example_tile_matmul.py +4 -4
  35. warp/examples/tile/example_tile_mlp.py +12 -12
  36. warp/examples/tile/example_tile_nbody.py +180 -0
  37. warp/examples/tile/example_tile_walker.py +319 -0
  38. warp/math.py +147 -0
  39. warp/native/array.h +12 -0
  40. warp/native/builtin.h +0 -1
  41. warp/native/bvh.cpp +149 -70
  42. warp/native/bvh.cu +287 -68
  43. warp/native/bvh.h +195 -85
  44. warp/native/clang/clang.cpp +5 -1
  45. warp/native/cuda_util.cpp +35 -0
  46. warp/native/cuda_util.h +5 -0
  47. warp/native/exports.h +40 -40
  48. warp/native/intersect.h +17 -0
  49. warp/native/mat.h +41 -0
  50. warp/native/mathdx.cpp +19 -0
  51. warp/native/mesh.cpp +25 -8
  52. warp/native/mesh.cu +153 -101
  53. warp/native/mesh.h +482 -403
  54. warp/native/quat.h +40 -0
  55. warp/native/solid_angle.h +7 -0
  56. warp/native/sort.cpp +85 -0
  57. warp/native/sort.cu +34 -0
  58. warp/native/sort.h +3 -1
  59. warp/native/spatial.h +11 -0
  60. warp/native/tile.h +1185 -664
  61. warp/native/tile_reduce.h +8 -6
  62. warp/native/vec.h +41 -0
  63. warp/native/warp.cpp +8 -1
  64. warp/native/warp.cu +263 -40
  65. warp/native/warp.h +19 -5
  66. warp/optim/linear.py +22 -4
  67. warp/render/render_opengl.py +124 -59
  68. warp/sim/__init__.py +6 -1
  69. warp/sim/collide.py +270 -26
  70. warp/sim/integrator_euler.py +25 -7
  71. warp/sim/integrator_featherstone.py +154 -35
  72. warp/sim/integrator_vbd.py +842 -40
  73. warp/sim/model.py +111 -53
  74. warp/stubs.py +248 -115
  75. warp/tape.py +28 -30
  76. warp/tests/aux_test_module_unload.py +15 -0
  77. warp/tests/{test_sim_grad.py → flaky_test_sim_grad.py} +104 -63
  78. warp/tests/test_array.py +74 -0
  79. warp/tests/test_assert.py +242 -0
  80. warp/tests/test_codegen.py +14 -61
  81. warp/tests/test_collision.py +2 -2
  82. warp/tests/test_examples.py +9 -0
  83. warp/tests/test_grad_debug.py +87 -2
  84. warp/tests/test_hash_grid.py +1 -1
  85. warp/tests/test_ipc.py +116 -0
  86. warp/tests/test_mat.py +138 -167
  87. warp/tests/test_math.py +47 -1
  88. warp/tests/test_matmul.py +11 -7
  89. warp/tests/test_matmul_lite.py +4 -4
  90. warp/tests/test_mesh.py +84 -60
  91. warp/tests/test_mesh_query_aabb.py +165 -0
  92. warp/tests/test_mesh_query_point.py +328 -286
  93. warp/tests/test_mesh_query_ray.py +134 -121
  94. warp/tests/test_mlp.py +2 -2
  95. warp/tests/test_operators.py +43 -0
  96. warp/tests/test_overwrite.py +2 -2
  97. warp/tests/test_quat.py +77 -0
  98. warp/tests/test_reload.py +29 -0
  99. warp/tests/test_sim_grad_bounce_linear.py +204 -0
  100. warp/tests/test_static.py +16 -0
  101. warp/tests/test_tape.py +25 -0
  102. warp/tests/test_tile.py +134 -191
  103. warp/tests/test_tile_load.py +356 -0
  104. warp/tests/test_tile_mathdx.py +61 -8
  105. warp/tests/test_tile_mlp.py +17 -17
  106. warp/tests/test_tile_reduce.py +24 -18
  107. warp/tests/test_tile_shared_memory.py +66 -17
  108. warp/tests/test_tile_view.py +165 -0
  109. warp/tests/test_torch.py +35 -0
  110. warp/tests/test_utils.py +36 -24
  111. warp/tests/test_vec.py +110 -0
  112. warp/tests/unittest_suites.py +29 -4
  113. warp/tests/unittest_utils.py +30 -11
  114. warp/thirdparty/unittest_parallel.py +2 -2
  115. warp/types.py +409 -99
  116. warp/utils.py +9 -5
  117. {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/METADATA +68 -44
  118. {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/RECORD +121 -110
  119. {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/WHEEL +1 -1
  120. warp/examples/benchmarks/benchmark_tile.py +0 -179
  121. warp/native/tile_gemm.h +0 -341
  122. {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/LICENSE.md +0 -0
  123. {warp_lang-1.5.1.dist-info → warp_lang-1.6.0.dist-info}/top_level.txt +0 -0
warp/codegen.py CHANGED
@@ -282,9 +282,9 @@ class StructInstance:
282
282
  else:
283
283
  # wp.array
284
284
  assert isinstance(value, array)
285
- assert types_equal(
286
- value.dtype, var.type.dtype
287
- ), f"assign to struct member variable {name} failed, expected type {type_repr(var.type.dtype)}, got type {type_repr(value.dtype)}"
285
+ assert types_equal(value.dtype, var.type.dtype), (
286
+ f"assign to struct member variable {name} failed, expected type {type_repr(var.type.dtype)}, got type {type_repr(value.dtype)}"
287
+ )
288
288
  setattr(self._ctype, name, value.__ctype__())
289
289
 
290
290
  elif isinstance(var.type, Struct):
@@ -606,6 +606,9 @@ def compute_type_str(base_name, template_params):
606
606
  return "bool"
607
607
  else:
608
608
  return f"wp::{p.__name__}"
609
+ elif is_tile(p):
610
+ return p.ctype()
611
+
609
612
  return p.__name__
610
613
 
611
614
  return f"{base_name}<{','.join(map(param2str, template_params))}>"
@@ -947,7 +950,7 @@ class Adjoint:
947
950
  total_shared = 0
948
951
 
949
952
  for var in adj.variables:
950
- if is_tile(var.type) and var.type.storage == "shared":
953
+ if is_tile(var.type) and var.type.storage == "shared" and var.type.owner:
951
954
  total_shared += var.type.size_in_bytes()
952
955
 
953
956
  return total_shared + adj.max_required_extra_shared_memory
@@ -1139,6 +1142,9 @@ class Adjoint:
1139
1142
  if isinstance(var, (Reference, warp.context.Function)):
1140
1143
  return var
1141
1144
 
1145
+ if isinstance(var, int):
1146
+ return adj.add_constant(var)
1147
+
1142
1148
  if var.label is None:
1143
1149
  return adj.add_var(var.type, var.constant)
1144
1150
 
@@ -1349,8 +1355,9 @@ class Adjoint:
1349
1355
  # which allows for some more advanced resolution to be performed,
1350
1356
  # for example by checking whether an argument corresponds to
1351
1357
  # a literal value or references a variable.
1358
+ extra_shared_memory = 0
1352
1359
  if func.lto_dispatch_func is not None:
1353
- func_args, template_args, ltoirs = func.lto_dispatch_func(
1360
+ func_args, template_args, ltoirs, extra_shared_memory = func.lto_dispatch_func(
1354
1361
  func.input_types, return_type, output_list, bound_args, options=adj.builder_options, builder=adj.builder
1355
1362
  )
1356
1363
  elif func.dispatch_func is not None:
@@ -1424,7 +1431,9 @@ class Adjoint:
1424
1431
  # update our smem roofline requirements based on any
1425
1432
  # shared memory required by the dependent function call
1426
1433
  if not func.is_builtin():
1427
- adj.alloc_shared_extra(func.adj.get_total_required_shared())
1434
+ adj.alloc_shared_extra(func.adj.get_total_required_shared() + extra_shared_memory)
1435
+ else:
1436
+ adj.alloc_shared_extra(extra_shared_memory)
1428
1437
 
1429
1438
  return output
1430
1439
 
@@ -1527,7 +1536,8 @@ class Adjoint:
1527
1536
  # zero adjoints
1528
1537
  for i in body_block.vars:
1529
1538
  if is_tile(i.type):
1530
- reverse.append(adj.indentation + f"\t{i.emit_adj()}.grad_zero();")
1539
+ if i.type.owner:
1540
+ reverse.append(adj.indentation + f"\t{i.emit_adj()}.grad_zero();")
1531
1541
  else:
1532
1542
  reverse.append(adj.indentation + f"\t{i.emit_adj()} = {{}};")
1533
1543
 
@@ -1857,6 +1867,17 @@ class Adjoint:
1857
1867
  # stubbed @wp.native_func
1858
1868
  return
1859
1869
 
1870
+ def emit_Assert(adj, node):
1871
+ # eval condition
1872
+ cond = adj.eval(node.test)
1873
+ cond = adj.load(cond)
1874
+
1875
+ source_segment = ast.get_source_segment(adj.source, node)
1876
+ # If a message was provided with the assert, " marks can interfere with the generated code
1877
+ escaped_segment = source_segment.replace('"', '\\"')
1878
+
1879
+ adj.add_forward(f'assert(("{escaped_segment}",{cond.emit()}));')
1880
+
1860
1881
  def emit_NameConstant(adj, node):
1861
1882
  if node.value:
1862
1883
  return adj.add_constant(node.value)
@@ -1900,12 +1921,25 @@ class Adjoint:
1900
1921
 
1901
1922
  name = builtin_operators[type(node.op)]
1902
1923
 
1924
+ try:
1925
+ # Check if there is any user-defined overload for this operator
1926
+ user_func = adj.resolve_external_reference(name)
1927
+ if isinstance(user_func, warp.context.Function):
1928
+ return adj.add_call(user_func, (left, right), {}, {})
1929
+ except WarpCodegenError:
1930
+ pass
1931
+
1903
1932
  return adj.add_builtin_call(name, [left, right])
1904
1933
 
1905
1934
  def emit_UnaryOp(adj, node):
1906
1935
  # evaluate unary op arguments
1907
1936
  arg = adj.eval(node.operand)
1908
1937
 
1938
+ # evaluate expression to a compile-time constant if arg is a constant
1939
+ if arg.constant is not None and math.isfinite(arg.constant):
1940
+ if isinstance(node.op, ast.USub):
1941
+ return adj.add_constant(-arg.constant)
1942
+
1909
1943
  name = builtin_operators[type(node.op)]
1910
1944
 
1911
1945
  return adj.add_builtin_call(name, [arg])
@@ -2350,12 +2384,16 @@ class Adjoint:
2350
2384
  out.is_write = target.is_write
2351
2385
 
2352
2386
  elif is_tile(target_type):
2353
- if len(indices) == 2:
2387
+ if len(indices) == len(target_type.shape):
2354
2388
  # handles extracting a single element from a tile
2355
2389
  out = adj.add_builtin_call("tile_extract", [target, *indices])
2356
- else:
2390
+ elif len(indices) < len(target_type.shape):
2357
2391
  # handles tile views
2358
- out = adj.add_builtin_call("tile_view", [target, *indices])
2392
+ out = adj.add_builtin_call("tile_view", [target, indices])
2393
+ else:
2394
+ raise RuntimeError(
2395
+ f"Incorrect number of indices specified for a tile view/extract, got {len(indices)} indices for a {len(target_type.shape)} dimensional tile."
2396
+ )
2359
2397
 
2360
2398
  else:
2361
2399
  # handles non-array type indexing, e.g: vec3, mat33, etc
@@ -2447,6 +2485,9 @@ class Adjoint:
2447
2485
 
2448
2486
  target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
2449
2487
 
2488
+ elif is_tile(target_type):
2489
+ adj.add_builtin_call("assign", [target, *indices, rhs])
2490
+
2450
2491
  elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
2451
2492
  # recursively unwind AST, stopping at penultimate node
2452
2493
  node = lhs
@@ -2473,15 +2514,18 @@ class Adjoint:
2473
2514
  print(
2474
2515
  f"Warning: mutating {node_source} in function {adj.fun_name} at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n"
2475
2516
  )
2476
-
2477
2517
  else:
2478
- out = adj.add_builtin_call("assign", [target, *indices, rhs])
2479
-
2480
- # re-point target symbol to out var
2481
- for id in adj.symbols:
2482
- if adj.symbols[id] == target:
2483
- adj.symbols[id] = out
2484
- break
2518
+ if adj.builder_options.get("enable_backward", True):
2519
+ out = adj.add_builtin_call("assign", [target, *indices, rhs])
2520
+
2521
+ # re-point target symbol to out var
2522
+ for id in adj.symbols:
2523
+ if adj.symbols[id] == target:
2524
+ adj.symbols[id] = out
2525
+ break
2526
+ else:
2527
+ attr = adj.add_builtin_call("index", [target, *indices])
2528
+ adj.add_builtin_call("store", [attr, rhs])
2485
2529
 
2486
2530
  else:
2487
2531
  raise WarpCodegenError(
@@ -2518,22 +2562,23 @@ class Adjoint:
2518
2562
 
2519
2563
  # assigning to a vector or quaternion component
2520
2564
  if type_is_vector(aggregate_type) or type_is_quaternion(aggregate_type):
2521
- # TODO: handle wp.adjoint case
2522
-
2523
2565
  index = adj.vector_component_index(lhs.attr, aggregate_type)
2524
2566
 
2525
- # TODO: array vec component case
2526
2567
  if is_reference(aggregate.type):
2527
2568
  attr = adj.add_builtin_call("indexref", [aggregate, index])
2528
2569
  adj.add_builtin_call("store", [attr, rhs])
2529
2570
  else:
2530
- out = adj.add_builtin_call("assign", [aggregate, index, rhs])
2531
-
2532
- # re-point target symbol to out var
2533
- for id in adj.symbols:
2534
- if adj.symbols[id] == aggregate:
2535
- adj.symbols[id] = out
2536
- break
2571
+ if adj.builder_options.get("enable_backward", True):
2572
+ out = adj.add_builtin_call("assign", [aggregate, index, rhs])
2573
+
2574
+ # re-point target symbol to out var
2575
+ for id in adj.symbols:
2576
+ if adj.symbols[id] == aggregate:
2577
+ adj.symbols[id] = out
2578
+ break
2579
+ else:
2580
+ attr = adj.add_builtin_call("index", [aggregate, index])
2581
+ adj.add_builtin_call("store", [attr, rhs])
2537
2582
 
2538
2583
  else:
2539
2584
  attr = adj.emit_Attribute(lhs)
@@ -2637,10 +2682,14 @@ class Adjoint:
2637
2682
  make_new_assign_statement()
2638
2683
  return
2639
2684
 
2640
- # TODO
2641
2685
  elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
2642
- make_new_assign_statement()
2643
- return
2686
+ if isinstance(node.op, ast.Add):
2687
+ adj.add_builtin_call("augassign_add", [target, *indices, rhs])
2688
+ elif isinstance(node.op, ast.Sub):
2689
+ adj.add_builtin_call("augassign_sub", [target, *indices, rhs])
2690
+ else:
2691
+ make_new_assign_statement()
2692
+ return
2644
2693
 
2645
2694
  else:
2646
2695
  raise WarpCodegenError("Can only subscript in-place assign array, vector, quaternion, and matrix types")
@@ -2688,6 +2737,7 @@ class Adjoint:
2688
2737
  ast.Tuple: emit_Tuple,
2689
2738
  ast.Pass: emit_Pass,
2690
2739
  ast.Ellipsis: emit_Ellipsis,
2740
+ ast.Assert: emit_Assert,
2691
2741
  }
2692
2742
 
2693
2743
  def eval(adj, node):
@@ -2850,11 +2900,62 @@ class Adjoint:
2850
2900
  if static_code is None:
2851
2901
  raise WarpCodegenError("Error extracting source code from wp.static() expression")
2852
2902
 
2903
+ # Since this is an expression, we can enforce it to be defined on a single line.
2904
+ static_code = static_code.replace("\n", "")
2905
+
2853
2906
  vars_dict = adj.get_static_evaluation_context()
2854
2907
  # add constant variables to the static call context
2855
2908
  constant_vars = {k: v.constant for k, v in adj.symbols.items() if isinstance(v, Var) and v.constant is not None}
2856
2909
  vars_dict.update(constant_vars)
2857
2910
 
2911
+ # Replace all constant `len()` expressions with their value.
2912
+ if "len" in static_code:
2913
+
2914
+ def eval_len(obj):
2915
+ if type_is_vector(obj):
2916
+ return obj._length_
2917
+ elif type_is_quaternion(obj):
2918
+ return obj._length_
2919
+ elif type_is_matrix(obj):
2920
+ return obj._shape_[0]
2921
+ elif type_is_transformation(obj):
2922
+ return obj._length_
2923
+ elif is_tile(obj):
2924
+ return obj.shape[0]
2925
+
2926
+ return len(obj)
2927
+
2928
+ len_expr_ctx = vars_dict.copy()
2929
+ constant_types = {k: v.type for k, v in adj.symbols.items() if isinstance(v, Var) and v.type is not None}
2930
+ len_expr_ctx.update(constant_types)
2931
+ len_expr_ctx.update({"len": eval_len})
2932
+
2933
+ # We want to replace the expression code in-place,
2934
+ # so reparse it to get the correct column info.
2935
+ len_value_locs = []
2936
+ expr_tree = ast.parse(static_code)
2937
+ assert len(expr_tree.body) == 1 and isinstance(expr_tree.body[0], ast.Expr)
2938
+ expr_root = expr_tree.body[0].value
2939
+ for expr_node in ast.walk(expr_root):
2940
+ if isinstance(expr_node, ast.Call) and expr_node.func.id == "len" and len(expr_node.args) == 1:
2941
+ len_expr = static_code[expr_node.col_offset : expr_node.end_col_offset]
2942
+ try:
2943
+ len_value = eval(len_expr, len_expr_ctx)
2944
+ except Exception:
2945
+ pass
2946
+ else:
2947
+ len_value_locs.append((len_value, expr_node.col_offset, expr_node.end_col_offset))
2948
+
2949
+ if len_value_locs:
2950
+ new_static_code = ""
2951
+ loc = 0
2952
+ for value, start, end in len_value_locs:
2953
+ new_static_code += f"{static_code[loc:start]}{value}"
2954
+ loc = end
2955
+
2956
+ new_static_code += static_code[len_value_locs[-1][2] :]
2957
+ static_code = new_static_code
2958
+
2858
2959
  try:
2859
2960
  value = eval(static_code, vars_dict)
2860
2961
  if warp.config.verbose:
@@ -3139,7 +3240,7 @@ static CUDA_CALLABLE void adj_{name}(
3139
3240
 
3140
3241
  """
3141
3242
 
3142
- cuda_kernel_template = """
3243
+ cuda_kernel_template_forward = """
3143
3244
 
3144
3245
  extern "C" __global__ void {name}_cuda_kernel_forward(
3145
3246
  {forward_args})
@@ -3154,6 +3255,10 @@ extern "C" __global__ void {name}_cuda_kernel_forward(
3154
3255
  {forward_body} }}
3155
3256
  }}
3156
3257
 
3258
+ """
3259
+
3260
+ cuda_kernel_template_backward = """
3261
+
3157
3262
  extern "C" __global__ void {name}_cuda_kernel_backward(
3158
3263
  {reverse_args})
3159
3264
  {{
@@ -3169,13 +3274,17 @@ extern "C" __global__ void {name}_cuda_kernel_backward(
3169
3274
 
3170
3275
  """
3171
3276
 
3172
- cpu_kernel_template = """
3277
+ cpu_kernel_template_forward = """
3173
3278
 
3174
3279
  void {name}_cpu_kernel_forward(
3175
3280
  {forward_args})
3176
3281
  {{
3177
3282
  {forward_body}}}
3178
3283
 
3284
+ """
3285
+
3286
+ cpu_kernel_template_backward = """
3287
+
3179
3288
  void {name}_cpu_kernel_backward(
3180
3289
  {reverse_args})
3181
3290
  {{
@@ -3183,7 +3292,7 @@ void {name}_cpu_kernel_backward(
3183
3292
 
3184
3293
  """
3185
3294
 
3186
- cpu_module_template = """
3295
+ cpu_module_template_forward = """
3187
3296
 
3188
3297
  extern "C" {{
3189
3298
 
@@ -3198,6 +3307,14 @@ WP_API void {name}_cpu_forward(
3198
3307
  }}
3199
3308
  }}
3200
3309
 
3310
+ }} // extern C
3311
+
3312
+ """
3313
+
3314
+ cpu_module_template_backward = """
3315
+
3316
+ extern "C" {{
3317
+
3201
3318
  WP_API void {name}_cpu_backward(
3202
3319
  {reverse_args})
3203
3320
  {{
@@ -3212,36 +3329,6 @@ WP_API void {name}_cpu_backward(
3212
3329
 
3213
3330
  """
3214
3331
 
3215
- cuda_module_header_template = """
3216
-
3217
- extern "C" {{
3218
-
3219
- // Python CUDA entry points
3220
- WP_API void {name}_cuda_forward(
3221
- void* stream,
3222
- {forward_args});
3223
-
3224
- WP_API void {name}_cuda_backward(
3225
- void* stream,
3226
- {reverse_args});
3227
-
3228
- }} // extern C
3229
- """
3230
-
3231
- cpu_module_header_template = """
3232
-
3233
- extern "C" {{
3234
-
3235
- // Python CPU entry points
3236
- WP_API void {name}_cpu_forward(
3237
- {forward_args});
3238
-
3239
- WP_API void {name}_cpu_backward(
3240
- {reverse_args});
3241
-
3242
- }} // extern C
3243
- """
3244
-
3245
3332
 
3246
3333
  # converts a constant Python value to equivalent C-repr
3247
3334
  def constant_str(value):
@@ -3679,59 +3766,82 @@ def codegen_kernel(kernel, device, options):
3679
3766
 
3680
3767
  adj = kernel.adj
3681
3768
 
3682
- forward_args = ["wp::launch_bounds_t dim"]
3683
- reverse_args = ["wp::launch_bounds_t dim"]
3769
+ if device == "cpu":
3770
+ template_forward = cpu_kernel_template_forward
3771
+ template_backward = cpu_kernel_template_backward
3772
+ elif device == "cuda":
3773
+ template_forward = cuda_kernel_template_forward
3774
+ template_backward = cuda_kernel_template_backward
3775
+ else:
3776
+ raise ValueError(f"Device {device} is not supported")
3777
+
3778
+ template = ""
3779
+ template_fmt_args = {
3780
+ "name": kernel.get_mangled_name(),
3781
+ }
3684
3782
 
3783
+ # build forward signature
3784
+ forward_args = ["wp::launch_bounds_t dim"]
3685
3785
  if device == "cpu":
3686
3786
  forward_args.append("size_t task_index")
3687
- reverse_args.append("size_t task_index")
3688
3787
 
3689
- # forward args
3690
3788
  for arg in adj.args:
3691
3789
  forward_args.append(arg.ctype() + " var_" + arg.label)
3692
- reverse_args.append(arg.ctype() + " var_" + arg.label)
3693
-
3694
- # reverse args
3695
- for arg in adj.args:
3696
- # indexed array gradients are regular arrays
3697
- if isinstance(arg.type, indexedarray):
3698
- _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
3699
- reverse_args.append(_arg.ctype() + " adj_" + arg.label)
3700
- else:
3701
- reverse_args.append(arg.ctype() + " adj_" + arg.label)
3702
3790
 
3703
- # codegen body
3704
3791
  forward_body = codegen_func_forward(adj, func_type="kernel", device=device)
3792
+ template_fmt_args.update(
3793
+ {
3794
+ "forward_args": indent(forward_args),
3795
+ "forward_body": forward_body,
3796
+ }
3797
+ )
3798
+ template += template_forward
3705
3799
 
3706
3800
  if options["enable_backward"]:
3707
- reverse_body = codegen_func_reverse(adj, func_type="kernel", device=device)
3708
- else:
3709
- reverse_body = ""
3801
+ # build reverse signature
3802
+ reverse_args = ["wp::launch_bounds_t dim"]
3803
+ if device == "cpu":
3804
+ reverse_args.append("size_t task_index")
3710
3805
 
3711
- if device == "cpu":
3712
- template = cpu_kernel_template
3713
- elif device == "cuda":
3714
- template = cuda_kernel_template
3715
- else:
3716
- raise ValueError(f"Device {device} is not supported")
3806
+ for arg in adj.args:
3807
+ reverse_args.append(arg.ctype() + " var_" + arg.label)
3717
3808
 
3718
- s = template.format(
3719
- name=kernel.get_mangled_name(),
3720
- forward_args=indent(forward_args),
3721
- reverse_args=indent(reverse_args),
3722
- forward_body=forward_body,
3723
- reverse_body=reverse_body,
3724
- )
3809
+ for arg in adj.args:
3810
+ # indexed array gradients are regular arrays
3811
+ if isinstance(arg.type, indexedarray):
3812
+ _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
3813
+ reverse_args.append(_arg.ctype() + " adj_" + arg.label)
3814
+ else:
3815
+ reverse_args.append(arg.ctype() + " adj_" + arg.label)
3725
3816
 
3817
+ reverse_body = codegen_func_reverse(adj, func_type="kernel", device=device)
3818
+ template_fmt_args.update(
3819
+ {
3820
+ "reverse_args": indent(reverse_args),
3821
+ "reverse_body": reverse_body,
3822
+ }
3823
+ )
3824
+ template += template_backward
3825
+
3826
+ s = template.format(**template_fmt_args)
3726
3827
  return s
3727
3828
 
3728
3829
 
3729
- def codegen_module(kernel, device="cpu"):
3830
+ def codegen_module(kernel, device, options):
3730
3831
  if device != "cpu":
3731
3832
  return ""
3732
3833
 
3834
+ # Update the module's options with the ones defined on the kernel, if any.
3835
+ options = dict(options)
3836
+ options.update(kernel.options)
3837
+
3733
3838
  adj = kernel.adj
3734
3839
 
3840
+ template = ""
3841
+ template_fmt_args = {
3842
+ "name": kernel.get_mangled_name(),
3843
+ }
3844
+
3735
3845
  # build forward signature
3736
3846
  forward_args = ["wp::launch_bounds_t dim"]
3737
3847
  forward_params = ["dim", "task_index"]
@@ -3745,29 +3855,40 @@ def codegen_module(kernel, device="cpu"):
3745
3855
  forward_args.append(f"{arg.ctype()} var_{arg.label}")
3746
3856
  forward_params.append("var_" + arg.label)
3747
3857
 
3748
- # build reverse signature
3749
- reverse_args = [*forward_args]
3750
- reverse_params = [*forward_params]
3858
+ template_fmt_args.update(
3859
+ {
3860
+ "forward_args": indent(forward_args),
3861
+ "forward_params": indent(forward_params, 3),
3862
+ }
3863
+ )
3864
+ template += cpu_module_template_forward
3751
3865
 
3752
- for arg in adj.args:
3753
- if isinstance(arg.type, indexedarray):
3754
- # indexed array gradients are regular arrays
3755
- _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
3756
- reverse_args.append(f"const {_arg.ctype()} adj_{arg.label}")
3757
- reverse_params.append(f"adj_{_arg.label}")
3758
- elif hasattr(arg.type, "_wp_generic_type_str_"):
3759
- # vectors and matrices are passed from Python by pointer
3760
- reverse_args.append(f"const {arg.ctype()}* adj_{arg.label}")
3761
- reverse_params.append(f"*adj_{arg.label}")
3762
- else:
3763
- reverse_args.append(f"{arg.ctype()} adj_{arg.label}")
3764
- reverse_params.append(f"adj_{arg.label}")
3866
+ if options["enable_backward"]:
3867
+ # build reverse signature
3868
+ reverse_args = [*forward_args]
3869
+ reverse_params = [*forward_params]
3765
3870
 
3766
- s = cpu_module_template.format(
3767
- name=kernel.get_mangled_name(),
3768
- forward_args=indent(forward_args),
3769
- reverse_args=indent(reverse_args),
3770
- forward_params=indent(forward_params, 3),
3771
- reverse_params=indent(reverse_params, 3),
3772
- )
3871
+ for arg in adj.args:
3872
+ if isinstance(arg.type, indexedarray):
3873
+ # indexed array gradients are regular arrays
3874
+ _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
3875
+ reverse_args.append(f"const {_arg.ctype()} adj_{arg.label}")
3876
+ reverse_params.append(f"adj_{_arg.label}")
3877
+ elif hasattr(arg.type, "_wp_generic_type_str_"):
3878
+ # vectors and matrices are passed from Python by pointer
3879
+ reverse_args.append(f"const {arg.ctype()}* adj_{arg.label}")
3880
+ reverse_params.append(f"*adj_{arg.label}")
3881
+ else:
3882
+ reverse_args.append(f"{arg.ctype()} adj_{arg.label}")
3883
+ reverse_params.append(f"adj_{arg.label}")
3884
+
3885
+ template_fmt_args.update(
3886
+ {
3887
+ "reverse_args": indent(reverse_args),
3888
+ "reverse_params": indent(reverse_params, 3),
3889
+ }
3890
+ )
3891
+ template += cpu_module_template_backward
3892
+
3893
+ s = template.format(**template_fmt_args)
3773
3894
  return s
warp/config.py CHANGED
@@ -7,7 +7,7 @@
7
7
 
8
8
  from typing import Optional
9
9
 
10
- version: str = "1.5.1"
10
+ version: str = "1.6.0"
11
11
  """Warp version string"""
12
12
 
13
13
  verify_fp: bool = False