warp-lang 1.5.0__py3-none-macosx_10_13_universal2.whl → 1.6.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (132)
  1. warp/__init__.py +5 -0
  2. warp/autograd.py +414 -191
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +40 -12
  6. warp/build_dll.py +13 -6
  7. warp/builtins.py +1124 -497
  8. warp/codegen.py +261 -136
  9. warp/config.py +1 -1
  10. warp/context.py +357 -119
  11. warp/examples/assets/square_cloth.usd +0 -0
  12. warp/examples/benchmarks/benchmark_gemm.py +27 -18
  13. warp/examples/benchmarks/benchmark_interop_paddle.py +3 -3
  14. warp/examples/benchmarks/benchmark_interop_torch.py +3 -3
  15. warp/examples/core/example_torch.py +18 -34
  16. warp/examples/fem/example_apic_fluid.py +1 -0
  17. warp/examples/fem/example_mixed_elasticity.py +1 -1
  18. warp/examples/optim/example_bounce.py +1 -1
  19. warp/examples/optim/example_cloth_throw.py +1 -1
  20. warp/examples/optim/example_diffray.py +4 -15
  21. warp/examples/optim/example_drone.py +1 -1
  22. warp/examples/optim/example_softbody_properties.py +392 -0
  23. warp/examples/optim/example_trajectory.py +1 -3
  24. warp/examples/optim/example_walker.py +5 -0
  25. warp/examples/sim/example_cartpole.py +0 -2
  26. warp/examples/sim/example_cloth.py +3 -1
  27. warp/examples/sim/example_cloth_self_contact.py +260 -0
  28. warp/examples/sim/example_granular_collision_sdf.py +4 -5
  29. warp/examples/sim/example_jacobian_ik.py +0 -2
  30. warp/examples/sim/example_quadruped.py +5 -2
  31. warp/examples/tile/example_tile_cholesky.py +79 -0
  32. warp/examples/tile/example_tile_convolution.py +2 -2
  33. warp/examples/tile/example_tile_fft.py +2 -2
  34. warp/examples/tile/example_tile_filtering.py +3 -3
  35. warp/examples/tile/example_tile_matmul.py +4 -4
  36. warp/examples/tile/example_tile_mlp.py +12 -12
  37. warp/examples/tile/example_tile_nbody.py +180 -0
  38. warp/examples/tile/example_tile_walker.py +319 -0
  39. warp/fem/geometry/geometry.py +0 -2
  40. warp/math.py +147 -0
  41. warp/native/array.h +12 -0
  42. warp/native/builtin.h +0 -1
  43. warp/native/bvh.cpp +149 -70
  44. warp/native/bvh.cu +287 -68
  45. warp/native/bvh.h +195 -85
  46. warp/native/clang/clang.cpp +5 -1
  47. warp/native/coloring.cpp +5 -1
  48. warp/native/cuda_util.cpp +91 -53
  49. warp/native/cuda_util.h +5 -0
  50. warp/native/exports.h +40 -40
  51. warp/native/intersect.h +17 -0
  52. warp/native/mat.h +41 -0
  53. warp/native/mathdx.cpp +19 -0
  54. warp/native/mesh.cpp +25 -8
  55. warp/native/mesh.cu +153 -101
  56. warp/native/mesh.h +482 -403
  57. warp/native/quat.h +40 -0
  58. warp/native/solid_angle.h +7 -0
  59. warp/native/sort.cpp +85 -0
  60. warp/native/sort.cu +34 -0
  61. warp/native/sort.h +3 -1
  62. warp/native/spatial.h +11 -0
  63. warp/native/tile.h +1187 -669
  64. warp/native/tile_reduce.h +8 -6
  65. warp/native/vec.h +41 -0
  66. warp/native/warp.cpp +8 -1
  67. warp/native/warp.cu +263 -40
  68. warp/native/warp.h +19 -5
  69. warp/optim/linear.py +22 -4
  70. warp/render/render_opengl.py +130 -64
  71. warp/sim/__init__.py +6 -1
  72. warp/sim/collide.py +270 -26
  73. warp/sim/import_urdf.py +8 -8
  74. warp/sim/integrator_euler.py +25 -7
  75. warp/sim/integrator_featherstone.py +154 -35
  76. warp/sim/integrator_vbd.py +842 -40
  77. warp/sim/model.py +134 -72
  78. warp/sparse.py +1 -1
  79. warp/stubs.py +265 -132
  80. warp/tape.py +28 -30
  81. warp/tests/aux_test_module_unload.py +15 -0
  82. warp/tests/{test_sim_grad.py → flaky_test_sim_grad.py} +104 -63
  83. warp/tests/test_array.py +74 -0
  84. warp/tests/test_assert.py +242 -0
  85. warp/tests/test_codegen.py +14 -61
  86. warp/tests/test_collision.py +2 -2
  87. warp/tests/test_coloring.py +12 -2
  88. warp/tests/test_examples.py +12 -1
  89. warp/tests/test_func.py +21 -4
  90. warp/tests/test_grad_debug.py +87 -2
  91. warp/tests/test_hash_grid.py +1 -1
  92. warp/tests/test_ipc.py +116 -0
  93. warp/tests/test_lerp.py +13 -87
  94. warp/tests/test_mat.py +138 -167
  95. warp/tests/test_math.py +47 -1
  96. warp/tests/test_matmul.py +17 -16
  97. warp/tests/test_matmul_lite.py +10 -15
  98. warp/tests/test_mesh.py +84 -60
  99. warp/tests/test_mesh_query_aabb.py +165 -0
  100. warp/tests/test_mesh_query_point.py +328 -286
  101. warp/tests/test_mesh_query_ray.py +134 -121
  102. warp/tests/test_mlp.py +2 -2
  103. warp/tests/test_operators.py +43 -0
  104. warp/tests/test_overwrite.py +47 -2
  105. warp/tests/test_quat.py +77 -0
  106. warp/tests/test_reload.py +29 -0
  107. warp/tests/test_sim_grad_bounce_linear.py +204 -0
  108. warp/tests/test_smoothstep.py +17 -83
  109. warp/tests/test_static.py +19 -3
  110. warp/tests/test_tape.py +25 -0
  111. warp/tests/test_tile.py +178 -191
  112. warp/tests/test_tile_load.py +356 -0
  113. warp/tests/test_tile_mathdx.py +61 -8
  114. warp/tests/test_tile_mlp.py +17 -17
  115. warp/tests/test_tile_reduce.py +24 -18
  116. warp/tests/test_tile_shared_memory.py +66 -17
  117. warp/tests/test_tile_view.py +165 -0
  118. warp/tests/test_torch.py +35 -0
  119. warp/tests/test_utils.py +36 -24
  120. warp/tests/test_vec.py +110 -0
  121. warp/tests/unittest_suites.py +29 -4
  122. warp/tests/unittest_utils.py +30 -13
  123. warp/thirdparty/unittest_parallel.py +2 -2
  124. warp/types.py +411 -101
  125. warp/utils.py +10 -7
  126. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/METADATA +92 -69
  127. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/RECORD +130 -119
  128. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/WHEEL +1 -1
  129. warp/examples/benchmarks/benchmark_tile.py +0 -179
  130. warp/native/tile_gemm.h +0 -341
  131. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/LICENSE.md +0 -0
  132. {warp_lang-1.5.0.dist-info → warp_lang-1.6.0.dist-info}/top_level.txt +0 -0
warp/codegen.py CHANGED
@@ -282,9 +282,9 @@ class StructInstance:
             else:
                 # wp.array
                 assert isinstance(value, array)
-                assert types_equal(
-                    value.dtype, var.type.dtype
-                ), f"assign to struct member variable {name} failed, expected type {type_repr(var.type.dtype)}, got type {type_repr(value.dtype)}"
+                assert types_equal(value.dtype, var.type.dtype), (
+                    f"assign to struct member variable {name} failed, expected type {type_repr(var.type.dtype)}, got type {type_repr(value.dtype)}"
+                )
                 setattr(self._ctype, name, value.__ctype__())

         elif isinstance(var.type, Struct):
@@ -606,6 +606,9 @@ def compute_type_str(base_name, template_params):
                 return "bool"
             else:
                 return f"wp::{p.__name__}"
+        elif is_tile(p):
+            return p.ctype()
+
         return p.__name__

     return f"{base_name}<{','.join(map(param2str, template_params))}>"
@@ -947,7 +950,7 @@ class Adjoint:
         total_shared = 0

         for var in adj.variables:
-            if is_tile(var.type) and var.type.storage == "shared":
+            if is_tile(var.type) and var.type.storage == "shared" and var.type.owner:
                 total_shared += var.type.size_in_bytes()

         return total_shared + adj.max_required_extra_shared_memory
@@ -1139,6 +1142,9 @@ class Adjoint:
         if isinstance(var, (Reference, warp.context.Function)):
             return var

+        if isinstance(var, int):
+            return adj.add_constant(var)
+
         if var.label is None:
             return adj.add_var(var.type, var.constant)
@@ -1175,25 +1181,25 @@
         left = adj.load(left)
         s = output.emit() + " = " + ("(" * len(comps)) + left.emit() + " "

-        prev_comp = None
+        prev_comp_var = None

        for op, comp in zip(op_strings, comps):
             comp_chainable = op_str_is_chainable(op)
-            if comp_chainable and prev_comp:
-                # We restrict chaining to operands of the same type
-                if prev_comp.type is comp.type:
-                    prev_comp = adj.load(prev_comp)
-                    comp = adj.load(comp)
-                    s += "&& (" + prev_comp.emit() + " " + op + " " + comp.emit() + ")) "
+            if comp_chainable and prev_comp_var:
+                # We restrict chaining to operands of the same type
+                if prev_comp_var.type is comp.type:
+                    prev_comp_var = adj.load(prev_comp_var)
+                    comp_var = adj.load(comp)
+                    s += "&& (" + prev_comp_var.emit() + " " + op + " " + comp_var.emit() + ")) "
                 else:
                     raise WarpCodegenTypeError(
-                        f"Cannot chain comparisons of unequal types: {prev_comp.type} {op} {comp.type}."
+                        f"Cannot chain comparisons of unequal types: {prev_comp_var.type} {op} {comp.type}."
                     )
             else:
-                comp = adj.load(comp)
-                s += op + " " + comp.emit() + ") "
+                comp_var = adj.load(comp)
+                s += op + " " + comp_var.emit() + ") "

-            prev_comp = comp
+            prev_comp_var = comp_var

         s = s.rstrip() + ";"
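The renaming above (prev_comp to prev_comp_var, comp to comp_var) keeps the loaded value distinct from the loop variable; the chaining behavior itself is unchanged. For reference, this is the code path that lets kernels use Python-style chained comparisons between operands of the same type. A minimal sketch (hypothetical kernel, standard Warp usage):

    import warp as wp

    @wp.kernel
    def in_range(x: wp.array(dtype=float), lo: float, hi: float, flags: wp.array(dtype=int)):
        tid = wp.tid()
        # all operands are float, so codegen chains this into
        # (lo <= x[tid]) && (x[tid] <= hi); mixing operand types
        # would raise WarpCodegenTypeError, as shown above
        if lo <= x[tid] <= hi:
            flags[tid] = 1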
@@ -1349,8 +1355,9 @@
         # which allows for some more advanced resolution to be performed,
         # for example by checking whether an argument corresponds to
         # a literal value or references a variable.
+        extra_shared_memory = 0
         if func.lto_dispatch_func is not None:
-            func_args, template_args, ltoirs = func.lto_dispatch_func(
+            func_args, template_args, ltoirs, extra_shared_memory = func.lto_dispatch_func(
                 func.input_types, return_type, output_list, bound_args, options=adj.builder_options, builder=adj.builder
             )
         elif func.dispatch_func is not None:
@@ -1366,13 +1373,15 @@
         fwd_args = []
         for func_arg in func_args:
             if not isinstance(func_arg, (Reference, warp.context.Function)):
-                func_arg = adj.load(func_arg)
+                func_arg_var = adj.load(func_arg)
+            else:
+                func_arg_var = func_arg

             # if the argument is a function (and not a builtin), then build it recursively
-            if isinstance(func_arg, warp.context.Function) and not func_arg.is_builtin():
-                adj.builder.build_function(func_arg)
+            if isinstance(func_arg_var, warp.context.Function) and not func_arg_var.is_builtin():
+                adj.builder.build_function(func_arg_var)

-            fwd_args.append(strip_reference(func_arg))
+            fwd_args.append(strip_reference(func_arg_var))

         if return_type is None:
             # handles expression (zero output) functions, e.g.: void do_something();
@@ -1422,7 +1431,9 @@
         # update our smem roofline requirements based on any
         # shared memory required by the dependent function call
         if not func.is_builtin():
-            adj.alloc_shared_extra(func.adj.get_total_required_shared())
+            adj.alloc_shared_extra(func.adj.get_total_required_shared() + extra_shared_memory)
+        else:
+            adj.alloc_shared_extra(extra_shared_memory)

         return output
@@ -1525,7 +1536,8 @@
         # zero adjoints
         for i in body_block.vars:
             if is_tile(i.type):
-                reverse.append(adj.indentation + f"\t{i.emit_adj()}.grad_zero();")
+                if i.type.owner:
+                    reverse.append(adj.indentation + f"\t{i.emit_adj()}.grad_zero();")
             else:
                 reverse.append(adj.indentation + f"\t{i.emit_adj()} = {{}};")
@@ -1855,6 +1867,17 @@
             # stubbed @wp.native_func
             return

+    def emit_Assert(adj, node):
+        # eval condition
+        cond = adj.eval(node.test)
+        cond = adj.load(cond)
+
+        source_segment = ast.get_source_segment(adj.source, node)
+        # If a message was provided with the assert, " marks can interfere with the generated code
+        escaped_segment = source_segment.replace('"', '\\"')
+
+        adj.add_forward(f'assert(("{escaped_segment}",{cond.emit()}));')
+
     def emit_NameConstant(adj, node):
         if node.value:
             return adj.add_constant(node.value)
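emit_Assert lowers a Python assert in kernel code to the C assert macro, embedding the escaped source text so failures print the original expression. A sketch of usage (hypothetical kernel; since this relies on C assert, it only fires in debug builds where NDEBUG is undefined):

    import warp as wp

    @wp.kernel
    def safe_divide(a: wp.array(dtype=float), b: wp.array(dtype=float), out: wp.array(dtype=float)):
        tid = wp.tid()
        # generated code is roughly: assert(("assert b[tid] != 0.0", <cond>));
        assert b[tid] != 0.0
        out[tid] = a[tid] / b[tid]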
@@ -1898,12 +1921,25 @@

         name = builtin_operators[type(node.op)]

+        try:
+            # Check if there is any user-defined overload for this operator
+            user_func = adj.resolve_external_reference(name)
+            if isinstance(user_func, warp.context.Function):
+                return adj.add_call(user_func, (left, right), {}, {})
+        except WarpCodegenError:
+            pass
+
         return adj.add_builtin_call(name, [left, right])

     def emit_UnaryOp(adj, node):
         # evaluate unary op arguments
         arg = adj.eval(node.operand)

+        # evaluate expression to a compile-time constant if arg is a constant
+        if arg.constant is not None and math.isfinite(arg.constant):
+            if isinstance(node.op, ast.USub):
+                return adj.add_constant(-arg.constant)
+
         name = builtin_operators[type(node.op)]

         return adj.add_builtin_call(name, [arg])
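Two user-visible changes here: emit_BinOp now looks for a user-defined function named after the operator (e.g. add for +, sub for -) before falling back to the builtin, and emit_UnaryOp folds negation of finite numeric constants at compile time. A sketch of the overload path (hypothetical Complex struct; this assumes, per the except clause above, that resolution falls back to the builtin when argument types do not match):

    import warp as wp

    @wp.struct
    class Complex:
        re: float
        im: float

    # found via adj.resolve_external_reference("add") when `+` is used
    @wp.func
    def add(a: Complex, b: Complex):
        c = Complex()
        c.re = a.re + b.re  # float + float falls back to the builtin
        c.im = a.im + b.im
        return c

    @wp.kernel
    def combine(x: Complex, y: Complex, out: wp.array(dtype=Complex)):
        out[0] = x + y  # dispatches to the user-defined add()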
@@ -2348,12 +2384,16 @@
             out.is_write = target.is_write

         elif is_tile(target_type):
-            if len(indices) == 2:
+            if len(indices) == len(target_type.shape):
                 # handles extracting a single element from a tile
                 out = adj.add_builtin_call("tile_extract", [target, *indices])
-            else:
+            elif len(indices) < len(target_type.shape):
                 # handles tile views
-                out = adj.add_builtin_call("tile_view", [target, *indices])
+                out = adj.add_builtin_call("tile_view", [target, indices])
+            else:
+                raise RuntimeError(
+                    f"Incorrect number of indices specified for a tile view/extract, got {len(indices)} indices for a {len(target_type.shape)} dimensional tile."
+                )

         else:
             # handles non-array type indexing, e.g: vec3, mat33, etc
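Tile indexing now dispatches on the number of indices relative to the tile's rank rather than assuming 2D tiles: a full set of indices extracts a scalar, fewer indices produce a view, and too many raise an error. A sketch (hypothetical kernel; assumes the 1.6 shape-based wp.tile_load API and a wp.launch_tiled launch):

    import warp as wp

    @wp.kernel
    def tile_indexing(a: wp.array2d(dtype=float), out: wp.array(dtype=float)):
        t = wp.tile_load(a, shape=(8, 8))  # rank-2 tile
        row = t[0]    # 1 index < rank  -> tile_view (a 1D row view)
        x = t[0, 1]   # 2 indices == rank -> tile_extract (a scalar)
        out[0] = x + row[2]

    # launched e.g. with: wp.launch_tiled(tile_indexing, dim=[1], inputs=[a, out], block_dim=64)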
@@ -2445,6 +2485,9 @@

             target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)

+        elif is_tile(target_type):
+            adj.add_builtin_call("assign", [target, *indices, rhs])
+
         elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
             # recursively unwind AST, stopping at penultimate node
             node = lhs
@@ -2471,15 +2514,18 @@
                 print(
                     f"Warning: mutating {node_source} in function {adj.fun_name} at {adj.filename}:{lineno}: this is a non-differentiable operation.\n{line}\n"
                 )
-
             else:
-                out = adj.add_builtin_call("assign", [target, *indices, rhs])
-
-                # re-point target symbol to out var
-                for id in adj.symbols:
-                    if adj.symbols[id] == target:
-                        adj.symbols[id] = out
-                        break
+                if adj.builder_options.get("enable_backward", True):
+                    out = adj.add_builtin_call("assign", [target, *indices, rhs])
+
+                    # re-point target symbol to out var
+                    for id in adj.symbols:
+                        if adj.symbols[id] == target:
+                            adj.symbols[id] = out
+                            break
+                else:
+                    attr = adj.add_builtin_call("index", [target, *indices])
+                    adj.add_builtin_call("store", [attr, rhs])

         else:
             raise WarpCodegenError(
@@ -2516,22 +2562,23 @@

         # assigning to a vector or quaternion component
         if type_is_vector(aggregate_type) or type_is_quaternion(aggregate_type):
-            # TODO: handle wp.adjoint case
-
             index = adj.vector_component_index(lhs.attr, aggregate_type)

-            # TODO: array vec component case
             if is_reference(aggregate.type):
                 attr = adj.add_builtin_call("indexref", [aggregate, index])
                 adj.add_builtin_call("store", [attr, rhs])
             else:
-                out = adj.add_builtin_call("assign", [aggregate, index, rhs])
-
-                # re-point target symbol to out var
-                for id in adj.symbols:
-                    if adj.symbols[id] == aggregate:
-                        adj.symbols[id] = out
-                        break
+                if adj.builder_options.get("enable_backward", True):
+                    out = adj.add_builtin_call("assign", [aggregate, index, rhs])
+
+                    # re-point target symbol to out var
+                    for id in adj.symbols:
+                        if adj.symbols[id] == aggregate:
+                            adj.symbols[id] = out
+                            break
+                else:
+                    attr = adj.add_builtin_call("index", [aggregate, index])
+                    adj.add_builtin_call("store", [attr, rhs])

         else:
             attr = adj.emit_Attribute(lhs)
@@ -2569,8 +2616,10 @@
         adj.return_var = ()
         for ret in var:
             if is_reference(ret.type):
-                ret = adj.add_builtin_call("copy", [ret])
-            adj.return_var += (ret,)
+                ret_var = adj.add_builtin_call("copy", [ret])
+            else:
+                ret_var = ret
+            adj.return_var += (ret_var,)

         adj.add_return(adj.return_var)
@@ -2633,10 +2682,14 @@
             make_new_assign_statement()
             return

-        # TODO
         elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
-            make_new_assign_statement()
-            return
+            if isinstance(node.op, ast.Add):
+                adj.add_builtin_call("augassign_add", [target, *indices, rhs])
+            elif isinstance(node.op, ast.Sub):
+                adj.add_builtin_call("augassign_sub", [target, *indices, rhs])
+            else:
+                make_new_assign_statement()
+                return

         else:
             raise WarpCodegenError("Can only subscript in-place assign array, vector, quaternion, and matrix types")
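The new augassign_add/augassign_sub builtins give += and -= on vector, quaternion, and matrix components a direct in-place lowering instead of the previous rebuild-and-reassign fallback (other in-place operators still take the make_new_assign_statement path). A sketch (hypothetical kernel):

    import warp as wp

    @wp.kernel
    def integrate(points: wp.array(dtype=wp.vec3), vel: wp.array(dtype=wp.vec3), dt: float):
        tid = wp.tid()
        p = points[tid]
        p[1] += vel[tid][1] * dt  # lowers to augassign_add on the y component
        points[tid] = p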
@@ -2684,6 +2737,7 @@
         ast.Tuple: emit_Tuple,
         ast.Pass: emit_Pass,
         ast.Ellipsis: emit_Ellipsis,
+        ast.Assert: emit_Assert,
     }

     def eval(adj, node):
@@ -2846,11 +2900,62 @@
         if static_code is None:
             raise WarpCodegenError("Error extracting source code from wp.static() expression")

+        # Since this is an expression, we can enforce it to be defined on a single line.
+        static_code = static_code.replace("\n", "")
+
         vars_dict = adj.get_static_evaluation_context()
         # add constant variables to the static call context
         constant_vars = {k: v.constant for k, v in adj.symbols.items() if isinstance(v, Var) and v.constant is not None}
         vars_dict.update(constant_vars)

+        # Replace all constant `len()` expressions with their value.
+        if "len" in static_code:
+
+            def eval_len(obj):
+                if type_is_vector(obj):
+                    return obj._length_
+                elif type_is_quaternion(obj):
+                    return obj._length_
+                elif type_is_matrix(obj):
+                    return obj._shape_[0]
+                elif type_is_transformation(obj):
+                    return obj._length_
+                elif is_tile(obj):
+                    return obj.shape[0]
+
+                return len(obj)
+
+            len_expr_ctx = vars_dict.copy()
+            constant_types = {k: v.type for k, v in adj.symbols.items() if isinstance(v, Var) and v.type is not None}
+            len_expr_ctx.update(constant_types)
+            len_expr_ctx.update({"len": eval_len})
+
+            # We want to replace the expression code in-place,
+            # so reparse it to get the correct column info.
+            len_value_locs = []
+            expr_tree = ast.parse(static_code)
+            assert len(expr_tree.body) == 1 and isinstance(expr_tree.body[0], ast.Expr)
+            expr_root = expr_tree.body[0].value
+            for expr_node in ast.walk(expr_root):
+                if isinstance(expr_node, ast.Call) and expr_node.func.id == "len" and len(expr_node.args) == 1:
+                    len_expr = static_code[expr_node.col_offset : expr_node.end_col_offset]
+                    try:
+                        len_value = eval(len_expr, len_expr_ctx)
+                    except Exception:
+                        pass
+                    else:
+                        len_value_locs.append((len_value, expr_node.col_offset, expr_node.end_col_offset))
+
+            if len_value_locs:
+                new_static_code = ""
+                loc = 0
+                for value, start, end in len_value_locs:
+                    new_static_code += f"{static_code[loc:start]}{value}"
+                    loc = end
+
+                new_static_code += static_code[len_value_locs[-1][2] :]
+                static_code = new_static_code
+
         try:
             value = eval(static_code, vars_dict)
             if warp.config.verbose:
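The pre-pass above substitutes constant len() calls on vectors, quaternions, matrices, transformations, and tiles with their compile-time values before the wp.static() expression is evaluated. A sketch (hypothetical function):

    import warp as wp

    @wp.func
    def component_sum(v: wp.vec3):
        s = float(0.0)
        # len(v) is rewritten to 3 before evaluation, so the loop
        # bound is a true compile-time constant
        for i in range(wp.static(len(v))):
            s += v[i]
        return s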
@@ -3135,7 +3240,7 @@ static CUDA_CALLABLE void adj_{name}(

 """

-cuda_kernel_template = """
+cuda_kernel_template_forward = """

 extern "C" __global__ void {name}_cuda_kernel_forward(
     {forward_args})
@@ -3150,6 +3255,10 @@ extern "C" __global__ void {name}_cuda_kernel_forward(
         {forward_body} }}
 }}

+"""
+
+cuda_kernel_template_backward = """
+
 extern "C" __global__ void {name}_cuda_kernel_backward(
     {reverse_args})
 {{
@@ -3165,13 +3274,17 @@

 """

-cpu_kernel_template = """
+cpu_kernel_template_forward = """

 void {name}_cpu_kernel_forward(
     {forward_args})
 {{
     {forward_body}}}

+"""
+
+cpu_kernel_template_backward = """
+
 void {name}_cpu_kernel_backward(
     {reverse_args})
 {{
@@ -3179,7 +3292,7 @@ void {name}_cpu_kernel_backward(
     {reverse_body}}}

 """

-cpu_module_template = """
+cpu_module_template_forward = """

 extern "C" {{

@@ -3194,6 +3307,14 @@ WP_API void {name}_cpu_forward(
 }}
 }}

+}} // extern C
+
+"""
+
+cpu_module_template_backward = """
+
+extern "C" {{
+
 WP_API void {name}_cpu_backward(
     {reverse_args})
 {{
@@ -3208,36 +3329,6 @@ WP_API void {name}_cpu_backward(

 """

-cuda_module_header_template = """
-
-extern "C" {{
-
-// Python CUDA entry points
-WP_API void {name}_cuda_forward(
-    void* stream,
-    {forward_args});
-
-WP_API void {name}_cuda_backward(
-    void* stream,
-    {reverse_args});
-
-}} // extern C
-"""
-
-cpu_module_header_template = """
-
-extern "C" {{
-
-// Python CPU entry points
-WP_API void {name}_cpu_forward(
-    {forward_args});
-
-WP_API void {name}_cpu_backward(
-    {reverse_args});
-
-}} // extern C
-"""
-

 # converts a constant Python value to equivalent C-repr
 def constant_str(value):
@@ -3675,59 +3766,82 @@ def codegen_kernel(kernel, device, options):

     adj = kernel.adj

-    forward_args = ["wp::launch_bounds_t dim"]
-    reverse_args = ["wp::launch_bounds_t dim"]
+    if device == "cpu":
+        template_forward = cpu_kernel_template_forward
+        template_backward = cpu_kernel_template_backward
+    elif device == "cuda":
+        template_forward = cuda_kernel_template_forward
+        template_backward = cuda_kernel_template_backward
+    else:
+        raise ValueError(f"Device {device} is not supported")
+
+    template = ""
+    template_fmt_args = {
+        "name": kernel.get_mangled_name(),
+    }

+    # build forward signature
+    forward_args = ["wp::launch_bounds_t dim"]
     if device == "cpu":
         forward_args.append("size_t task_index")
-        reverse_args.append("size_t task_index")

-    # forward args
     for arg in adj.args:
         forward_args.append(arg.ctype() + " var_" + arg.label)
-        reverse_args.append(arg.ctype() + " var_" + arg.label)

-    # reverse args
-    for arg in adj.args:
-        # indexed array gradients are regular arrays
-        if isinstance(arg.type, indexedarray):
-            _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
-            reverse_args.append(_arg.ctype() + " adj_" + arg.label)
-        else:
-            reverse_args.append(arg.ctype() + " adj_" + arg.label)
-
-    # codegen body
     forward_body = codegen_func_forward(adj, func_type="kernel", device=device)
+    template_fmt_args.update(
+        {
+            "forward_args": indent(forward_args),
+            "forward_body": forward_body,
+        }
+    )
+    template += template_forward

     if options["enable_backward"]:
-        reverse_body = codegen_func_reverse(adj, func_type="kernel", device=device)
-    else:
-        reverse_body = ""
+        # build reverse signature
+        reverse_args = ["wp::launch_bounds_t dim"]
+        if device == "cpu":
+            reverse_args.append("size_t task_index")

-    if device == "cpu":
-        template = cpu_kernel_template
-    elif device == "cuda":
-        template = cuda_kernel_template
-    else:
-        raise ValueError(f"Device {device} is not supported")
+        for arg in adj.args:
+            reverse_args.append(arg.ctype() + " var_" + arg.label)

-    s = template.format(
-        name=kernel.get_mangled_name(),
-        forward_args=indent(forward_args),
-        reverse_args=indent(reverse_args),
-        forward_body=forward_body,
-        reverse_body=reverse_body,
-    )
+        for arg in adj.args:
+            # indexed array gradients are regular arrays
+            if isinstance(arg.type, indexedarray):
+                _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
+                reverse_args.append(_arg.ctype() + " adj_" + arg.label)
+            else:
+                reverse_args.append(arg.ctype() + " adj_" + arg.label)

+        reverse_body = codegen_func_reverse(adj, func_type="kernel", device=device)
+        template_fmt_args.update(
+            {
+                "reverse_args": indent(reverse_args),
+                "reverse_body": reverse_body,
+            }
+        )
+        template += template_backward
+
+    s = template.format(**template_fmt_args)
     return s


-def codegen_module(kernel, device="cpu"):
+def codegen_module(kernel, device, options):
     if device != "cpu":
         return ""

+    # Update the module's options with the ones defined on the kernel, if any.
+    options = dict(options)
+    options.update(kernel.options)
+
     adj = kernel.adj

+    template = ""
+    template_fmt_args = {
+        "name": kernel.get_mangled_name(),
+    }
+
     # build forward signature
     forward_args = ["wp::launch_bounds_t dim"]
     forward_params = ["dim", "task_index"]
@@ -3741,29 +3855,40 @@ def codegen_module(kernel, device="cpu"):
         forward_args.append(f"{arg.ctype()} var_{arg.label}")
         forward_params.append("var_" + arg.label)

-    # build reverse signature
-    reverse_args = [*forward_args]
-    reverse_params = [*forward_params]
+    template_fmt_args.update(
+        {
+            "forward_args": indent(forward_args),
+            "forward_params": indent(forward_params, 3),
+        }
+    )
+    template += cpu_module_template_forward

-    for arg in adj.args:
-        if isinstance(arg.type, indexedarray):
-            # indexed array gradients are regular arrays
-            _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
-            reverse_args.append(f"const {_arg.ctype()} adj_{arg.label}")
-            reverse_params.append(f"adj_{_arg.label}")
-        elif hasattr(arg.type, "_wp_generic_type_str_"):
-            # vectors and matrices are passed from Python by pointer
-            reverse_args.append(f"const {arg.ctype()}* adj_{arg.label}")
-            reverse_params.append(f"*adj_{arg.label}")
-        else:
-            reverse_args.append(f"{arg.ctype()} adj_{arg.label}")
-            reverse_params.append(f"adj_{arg.label}")
+    if options["enable_backward"]:
+        # build reverse signature
+        reverse_args = [*forward_args]
+        reverse_params = [*forward_params]

-    s = cpu_module_template.format(
-        name=kernel.get_mangled_name(),
-        forward_args=indent(forward_args),
-        reverse_args=indent(reverse_args),
-        forward_params=indent(forward_params, 3),
-        reverse_params=indent(reverse_params, 3),
-    )
+        for arg in adj.args:
+            if isinstance(arg.type, indexedarray):
+                # indexed array gradients are regular arrays
+                _arg = Var(arg.label, array(dtype=arg.type.dtype, ndim=arg.type.ndim))
+                reverse_args.append(f"const {_arg.ctype()} adj_{arg.label}")
+                reverse_params.append(f"adj_{_arg.label}")
+            elif hasattr(arg.type, "_wp_generic_type_str_"):
+                # vectors and matrices are passed from Python by pointer
+                reverse_args.append(f"const {arg.ctype()}* adj_{arg.label}")
+                reverse_params.append(f"*adj_{arg.label}")
+            else:
+                reverse_args.append(f"{arg.ctype()} adj_{arg.label}")
+                reverse_params.append(f"adj_{arg.label}")
+
+        template_fmt_args.update(
+            {
+                "reverse_args": indent(reverse_args),
+                "reverse_params": indent(reverse_params, 3),
+            }
+        )
+        template += cpu_module_template_backward
+
+    s = template.format(**template_fmt_args)
     return s
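With the kernel and module templates split into forward and backward halves, compiling with enable_backward disabled now omits the *_backward entry points entirely rather than emitting them with empty bodies. A sketch of opting out at module scope (wp.set_module_options is existing Warp API; the options.update(kernel.options) merge above suggests per-kernel overrides are honored as well):

    import warp as wp

    # no adjoint kernels are generated for this module
    wp.set_module_options({"enable_backward": False})

    @wp.kernel
    def scale(x: wp.array(dtype=float), s: float):
        tid = wp.tid()
        x[tid] = x[tid] * s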
warp/config.py CHANGED
@@ -7,7 +7,7 @@

 from typing import Optional

-version: str = "1.5.0"
+version: str = "1.6.0"
 """Warp version string"""

 verify_fp: bool = False