warp-lang 1.4.2-py3-none-win_amd64.whl → 1.5.0-py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (158)
  1. warp/__init__.py +4 -0
  2. warp/autograd.py +43 -8
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +21 -2
  6. warp/build_dll.py +23 -6
  7. warp/builtins.py +1783 -2
  8. warp/codegen.py +177 -45
  9. warp/config.py +2 -2
  10. warp/context.py +321 -73
  11. warp/examples/assets/pixel.jpg +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
  13. warp/examples/benchmarks/benchmark_gemm.py +121 -0
  14. warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
  15. warp/examples/benchmarks/benchmark_tile.py +179 -0
  16. warp/examples/fem/example_adaptive_grid.py +37 -10
  17. warp/examples/fem/example_apic_fluid.py +3 -2
  18. warp/examples/fem/example_convection_diffusion_dg.py +4 -5
  19. warp/examples/fem/example_deformed_geometry.py +1 -1
  20. warp/examples/fem/example_diffusion_3d.py +47 -4
  21. warp/examples/fem/example_distortion_energy.py +220 -0
  22. warp/examples/fem/example_magnetostatics.py +127 -85
  23. warp/examples/fem/example_nonconforming_contact.py +5 -5
  24. warp/examples/fem/example_stokes.py +3 -1
  25. warp/examples/fem/example_streamlines.py +12 -19
  26. warp/examples/fem/utils.py +38 -15
  27. warp/examples/sim/example_cloth.py +2 -25
  28. warp/examples/sim/example_quadruped.py +2 -1
  29. warp/examples/tile/example_tile_convolution.py +58 -0
  30. warp/examples/tile/example_tile_fft.py +47 -0
  31. warp/examples/tile/example_tile_filtering.py +105 -0
  32. warp/examples/tile/example_tile_matmul.py +79 -0
  33. warp/examples/tile/example_tile_mlp.py +375 -0
  34. warp/fem/__init__.py +8 -0
  35. warp/fem/cache.py +16 -12
  36. warp/fem/dirichlet.py +1 -1
  37. warp/fem/domain.py +44 -1
  38. warp/fem/field/__init__.py +1 -2
  39. warp/fem/field/field.py +31 -19
  40. warp/fem/field/nodal_field.py +101 -49
  41. warp/fem/field/virtual.py +794 -0
  42. warp/fem/geometry/__init__.py +2 -2
  43. warp/fem/geometry/deformed_geometry.py +3 -105
  44. warp/fem/geometry/element.py +13 -0
  45. warp/fem/geometry/geometry.py +165 -5
  46. warp/fem/geometry/grid_2d.py +3 -6
  47. warp/fem/geometry/grid_3d.py +31 -28
  48. warp/fem/geometry/hexmesh.py +3 -46
  49. warp/fem/geometry/nanogrid.py +3 -2
  50. warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
  51. warp/fem/geometry/tetmesh.py +2 -43
  52. warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
  53. warp/fem/integrate.py +683 -261
  54. warp/fem/linalg.py +404 -0
  55. warp/fem/operator.py +101 -18
  56. warp/fem/polynomial.py +5 -5
  57. warp/fem/quadrature/quadrature.py +45 -21
  58. warp/fem/space/__init__.py +45 -11
  59. warp/fem/space/basis_function_space.py +451 -0
  60. warp/fem/space/basis_space.py +58 -11
  61. warp/fem/space/function_space.py +146 -5
  62. warp/fem/space/grid_2d_function_space.py +80 -66
  63. warp/fem/space/grid_3d_function_space.py +113 -68
  64. warp/fem/space/hexmesh_function_space.py +96 -108
  65. warp/fem/space/nanogrid_function_space.py +62 -110
  66. warp/fem/space/quadmesh_function_space.py +208 -0
  67. warp/fem/space/shape/__init__.py +45 -7
  68. warp/fem/space/shape/cube_shape_function.py +328 -54
  69. warp/fem/space/shape/shape_function.py +10 -1
  70. warp/fem/space/shape/square_shape_function.py +328 -60
  71. warp/fem/space/shape/tet_shape_function.py +269 -19
  72. warp/fem/space/shape/triangle_shape_function.py +238 -19
  73. warp/fem/space/tetmesh_function_space.py +69 -37
  74. warp/fem/space/topology.py +38 -0
  75. warp/fem/space/trimesh_function_space.py +179 -0
  76. warp/fem/utils.py +6 -331
  77. warp/jax_experimental.py +3 -1
  78. warp/native/array.h +15 -0
  79. warp/native/builtin.h +66 -26
  80. warp/native/bvh.h +4 -0
  81. warp/native/coloring.cpp +600 -0
  82. warp/native/cuda_util.cpp +14 -0
  83. warp/native/cuda_util.h +2 -1
  84. warp/native/fabric.h +8 -0
  85. warp/native/hashgrid.h +4 -0
  86. warp/native/marching.cu +8 -0
  87. warp/native/mat.h +14 -3
  88. warp/native/mathdx.cpp +59 -0
  89. warp/native/mesh.h +4 -0
  90. warp/native/range.h +13 -1
  91. warp/native/reduce.cpp +9 -1
  92. warp/native/reduce.cu +7 -0
  93. warp/native/runlength_encode.cpp +9 -1
  94. warp/native/runlength_encode.cu +7 -1
  95. warp/native/scan.cpp +8 -0
  96. warp/native/scan.cu +8 -0
  97. warp/native/scan.h +8 -1
  98. warp/native/sparse.cpp +8 -0
  99. warp/native/sparse.cu +8 -0
  100. warp/native/temp_buffer.h +7 -0
  101. warp/native/tile.h +1857 -0
  102. warp/native/tile_gemm.h +341 -0
  103. warp/native/tile_reduce.h +210 -0
  104. warp/native/volume_builder.cu +8 -0
  105. warp/native/volume_builder.h +8 -0
  106. warp/native/warp.cpp +10 -2
  107. warp/native/warp.cu +369 -15
  108. warp/native/warp.h +12 -2
  109. warp/optim/adam.py +39 -4
  110. warp/paddle.py +29 -12
  111. warp/render/render_opengl.py +137 -65
  112. warp/sim/graph_coloring.py +292 -0
  113. warp/sim/integrator_euler.py +4 -2
  114. warp/sim/integrator_featherstone.py +115 -44
  115. warp/sim/integrator_vbd.py +6 -0
  116. warp/sim/model.py +88 -15
  117. warp/stubs.py +569 -4
  118. warp/tape.py +12 -7
  119. warp/tests/assets/pixel.npy +0 -0
  120. warp/tests/aux_test_instancing_gc.py +18 -0
  121. warp/tests/test_array.py +39 -0
  122. warp/tests/test_codegen.py +81 -1
  123. warp/tests/test_codegen_instancing.py +30 -0
  124. warp/tests/test_collision.py +110 -0
  125. warp/tests/test_coloring.py +241 -0
  126. warp/tests/test_context.py +34 -0
  127. warp/tests/test_examples.py +18 -4
  128. warp/tests/test_fem.py +453 -113
  129. warp/tests/test_func.py +13 -0
  130. warp/tests/test_generics.py +52 -0
  131. warp/tests/test_iter.py +68 -0
  132. warp/tests/test_mat_scalar_ops.py +1 -1
  133. warp/tests/test_mesh_query_point.py +1 -1
  134. warp/tests/test_module_hashing.py +23 -0
  135. warp/tests/test_paddle.py +27 -87
  136. warp/tests/test_print.py +56 -1
  137. warp/tests/test_spatial.py +1 -1
  138. warp/tests/test_tile.py +700 -0
  139. warp/tests/test_tile_mathdx.py +144 -0
  140. warp/tests/test_tile_mlp.py +383 -0
  141. warp/tests/test_tile_reduce.py +374 -0
  142. warp/tests/test_tile_shared_memory.py +190 -0
  143. warp/tests/test_vbd.py +12 -20
  144. warp/tests/test_volume.py +43 -0
  145. warp/tests/unittest_suites.py +19 -2
  146. warp/tests/unittest_utils.py +4 -0
  147. warp/types.py +338 -72
  148. warp/utils.py +22 -1
  149. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/METADATA +33 -7
  150. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/RECORD +153 -126
  151. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/WHEEL +1 -1
  152. warp/fem/field/test.py +0 -180
  153. warp/fem/field/trial.py +0 -183
  154. warp/fem/space/collocated_function_space.py +0 -102
  155. warp/fem/space/quadmesh_2d_function_space.py +0 -261
  156. warp/fem/space/trimesh_2d_function_space.py +0 -153
  157. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/LICENSE.md +0 -0
  158. {warp_lang-1.4.2.dist-info → warp_lang-1.5.0.dist-info}/top_level.txt +0 -0
warp/codegen.py CHANGED
@@ -23,6 +23,10 @@ from typing import Any, Callable, Dict, Mapping, Optional, Sequence
 import warp.config
 from warp.types import *
 
+# used as a globally accessible copy
+# of current compile options (block_dim) etc
+options = {}
+
 
 class WarpCodegenError(RuntimeError):
     def __init__(self, message):
@@ -110,6 +114,16 @@ def get_closure_cell_contents(obj):
     return None
 
 
+def get_type_origin(tp):
+    # Compatible version of `typing.get_origin()` for Python 3.7 and older.
+    return getattr(tp, "__origin__", None)
+
+
+def get_type_args(tp):
+    # Compatible version of `typing.get_args()` for Python 3.7 and older.
+    return getattr(tp, "__args__", ())
+
+
 def eval_annotations(annotations: Mapping[str, Any], obj: Any) -> Mapping[str, Any]:
     """Un-stringize annotations caused by `from __future__ import annotations` of PEP 563."""
     # Implementation backported from `inspect.get_annotations()` for Python 3.9 and older.
@@ -637,6 +651,8 @@ class Var:
             dtypestr = f"wp::{t.dtype.__name__}"
             classstr = f"wp::{type(t).__name__}"
             return f"{classstr}_t<{dtypestr}>"
+        elif is_tile(t):
+            return t.ctype()
         elif isinstance(t, Struct):
             return t.native_name
         elif isinstance(t, type) and issubclass(t, StructInstance):
@@ -876,7 +892,7 @@ class Adjoint:
             # use source-level argument annotations
             if len(argspec.annotations) < len(argspec.args):
                 raise WarpCodegenError(f"Incomplete argument annotations on function {adj.fun_name}")
-            adj.arg_types = argspec.annotations
+            adj.arg_types = {k: v for k, v in argspec.annotations.items() if not (k == "return" and v is None)}
         else:
             # use overload argument annotations
             for arg_name in argspec.args:
@@ -914,6 +930,28 @@ class Adjoint:
         # for unit testing errors being spit out from kernels.
         adj.skip_build = False
 
+        # Collect the LTOIR required at link-time
+        adj.ltoirs = []
+
+    # allocate extra space for a function call that requires its
+    # own shared memory space, we treat shared memory as a stack
+    # where each function pushes and pops space off, the extra
+    # quantity is the 'roofline' amount required for the entire kernel
+    def alloc_shared_extra(adj, num_bytes):
+        adj.max_required_extra_shared_memory = max(adj.max_required_extra_shared_memory, num_bytes)
+
+    # returns the total number of bytes for a function
+    # based on it's own requirements + worst case
+    # requirements of any dependent functions
+    def get_total_required_shared(adj):
+        total_shared = 0
+
+        for var in adj.variables:
+            if is_tile(var.type) and var.type.storage == "shared":
+                total_shared += var.type.size_in_bytes()
+
+        return total_shared + adj.max_required_extra_shared_memory
+
     # generate function ssa form and adjoint
     def build(adj, builder, default_builder_options=None):
         # arg Var read/write flags are held during module rebuilds, so we reset here even when skipping a build
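Note: the alloc_shared_extra / get_total_required_shared pair added above treats dynamic shared memory as a stack: a function's requirement is the bytes of its own shared tiles plus the worst-case ("roofline") requirement of whichever function it calls. A minimal sketch of that accounting with made-up byte counts (not the real Adjoint API):

    # Illustrative only: mirrors the roofline logic of Adjoint.get_total_required_shared()
    def total_required_shared(own_shared_tile_bytes, callee_requirements):
        # callees push and pop the same stack space, so only the largest callee
        # contributes extra shared memory on top of this function's own tiles
        max_required_extra = max(callee_requirements, default=0)
        return sum(own_shared_tile_bytes) + max_required_extra

    # a kernel with two 4 KiB shared tiles calling helpers that need 8 KiB and 2 KiB
    assert total_required_shared([4096, 4096], [8192, 2048]) == 16384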
@@ -934,6 +972,9 @@ class Adjoint:
         else:
             adj.builder_options = default_builder_options
 
+        global options
+        options = adj.builder_options
+
         adj.symbols = {}  # map from symbols to adjoint variables
         adj.variables = []  # list of local variables (in order)
 
@@ -953,6 +994,9 @@ class Adjoint:
         # used to generate new label indices
         adj.label_count = 0
 
+        # tracks how much additional shared memory is required by any dependent function calls
+        adj.max_required_extra_shared_memory = 0
+
         # update symbol map for each argument
         for a in adj.args:
             adj.symbols[a.label] = a
@@ -969,6 +1013,7 @@ class Adjoint:
                 e = ex(";".join([msg] + [str(a) for a in data.args])).with_traceback(traceback)
             finally:
                 adj.skip_build = True
+                adj.builder = None
             raise e
 
         if builder is not None:
@@ -978,6 +1023,9 @@ class Adjoint:
                 elif isinstance(a.type, warp.types.array) and isinstance(a.type.dtype, Struct):
                     builder.build_struct_recursive(a.type.dtype)
 
+        # release builder reference for GC
+        adj.builder = None
+
     # code generation methods
     def format_template(adj, template, input_vars, output_var):
        # output var is always the 0th index
@@ -994,9 +1042,9 @@ class Adjoint:
             if isinstance(a, warp.context.Function):
                 # functions don't have a var_ prefix so strip it off here
                 if prefix == "var":
-                    arg_strs.append(a.native_func)
+                    arg_strs.append(f"{a.namespace}{a.native_func}")
                 else:
-                    arg_strs.append(f"{prefix}_{a.native_func}")
+                    arg_strs.append(f"{a.namespace}{prefix}_{a.native_func}")
             elif is_reference(a.type):
                 arg_strs.append(f"{prefix}_{a}")
             elif isinstance(a, Var):
@@ -1278,15 +1326,34 @@ class Adjoint:
             bound_arg_values,
         )
 
-        if func.dispatch_func is not None:
-            # If we have a built-in that requires special handling to dispatch
-            # the arguments to the underlying C++ function, then we can resolve
-            # these using the `dispatch_func`. Since this is only called from
-            # within codegen, we pass it directly `codegen.Var` objects,
-            # which allows for some more advanced resolution to be performed,
-            # for example by checking whether an argument corresponds to
-            # a literal value or references a variable.
+        # immediately allocate output variables so we can pass them into the dispatch method
+        if return_type is None:
+            # void function
+            output = None
+            output_list = []
+        elif not isinstance(return_type, Sequence) or len(return_type) == 1:
+            # single return value function
+            if isinstance(return_type, Sequence):
+                return_type = return_type[0]
+            output = adj.add_var(return_type)
+            output_list = [output]
+        else:
+            # multiple return value function
+            output = [adj.add_var(v) for v in return_type]
+            output_list = output
 
+        # If we have a built-in that requires special handling to dispatch
+        # the arguments to the underlying C++ function, then we can resolve
+        # these using the `dispatch_func`. Since this is only called from
+        # within codegen, we pass it directly `codegen.Var` objects,
+        # which allows for some more advanced resolution to be performed,
+        # for example by checking whether an argument corresponds to
+        # a literal value or references a variable.
+        if func.lto_dispatch_func is not None:
+            func_args, template_args, ltoirs = func.lto_dispatch_func(
+                func.input_types, return_type, output_list, bound_args, options=adj.builder_options, builder=adj.builder
+            )
+        elif func.dispatch_func is not None:
             func_args, template_args = func.dispatch_func(func.input_types, return_type, bound_args)
         else:
             func_args = tuple(bound_args.values())
@@ -1301,18 +1368,14 @@ class Adjoint:
             if not isinstance(func_arg, (Reference, warp.context.Function)):
                 func_arg = adj.load(func_arg)
 
-            # if the argument is a function, build it recursively
-            if isinstance(func_arg, warp.context.Function):
+            # if the argument is a function (and not a builtin), then build it recursively
+            if isinstance(func_arg, warp.context.Function) and not func_arg.is_builtin():
                 adj.builder.build_function(func_arg)
 
             fwd_args.append(strip_reference(func_arg))
 
         if return_type is None:
             # handles expression (zero output) functions, e.g.: void do_something();
-
-            output = None
-            output_list = []
-
             forward_call = (
                 f"{func.namespace}{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"
             )
@@ -1322,12 +1385,6 @@ class Adjoint:
 
         elif not isinstance(return_type, Sequence) or len(return_type) == 1:
             # handle simple function (one output)
-
-            if isinstance(return_type, Sequence):
-                return_type = return_type[0]
-            output = adj.add_var(return_type)
-            output_list = [output]
-
             forward_call = f"var_{output} = {func.namespace}{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"
             replay_call = forward_call
             if func.custom_replay_func is not None:
@@ -1335,10 +1392,6 @@ class Adjoint:
 
         else:
             # handle multiple value functions
-
-            output = [adj.add_var(v) for v in return_type]
-            output_list = output
-
             forward_call = (
                 f"{func.namespace}{func_name}({adj.format_forward_call_args(fwd_args + output, use_initializer_list)});"
             )
@@ -1366,6 +1419,11 @@ class Adjoint:
             reverse_call = f"{func.namespace}adj_{func.native_func}({arg_str});"
             adj.add_reverse(reverse_call)
 
+        # update our smem roofline requirements based on any
+        # shared memory required by the dependent function call
+        if not func.is_builtin():
+            adj.alloc_shared_extra(func.adj.get_total_required_shared())
+
         return output
 
     def add_builtin_call(adj, func_name, args, min_outputs=None):
@@ -1466,7 +1524,10 @@ class Adjoint:
 
         # zero adjoints
         for i in body_block.vars:
-            reverse.append(adj.indentation + f"\t{i.emit_adj()} = {{}};")
+            if is_tile(i.type):
+                reverse.append(adj.indentation + f"\t{i.emit_adj()}.grad_zero();")
+            else:
+                reverse.append(adj.indentation + f"\t{i.emit_adj()} = {{}};")
 
         # replay
         for i in body_block.body_replay:
@@ -2206,7 +2267,7 @@ class Adjoint:
 
     # returns the object being indexed, and the list of indices
    def eval_subscript(adj, node):
-        # We want to coalesce multi-dimentional array indexing into a single operation. This needs to deal with expressions like `a[i][j][x][y]` where `a` is a 2D array of matrices,
+        # We want to coalesce multi-dimensional array indexing into a single operation. This needs to deal with expressions like `a[i][j][x][y]` where `a` is a 2D array of matrices,
         # and essentially rewrite it into `a[i, j][x][y]`. Since the AST observes the indexing right-to-left, and we don't want to evaluate the index expressions prematurely,
         # this requires a first loop to check if this `node` only performs indexing on the array, and a second loop to evaluate and collect index variables.
         root = node
@@ -2286,6 +2347,14 @@ class Adjoint:
             out.is_read = target.is_read
             out.is_write = target.is_write
 
+        elif is_tile(target_type):
+            if len(indices) == 2:
+                # handles extracting a single element from a tile
+                out = adj.add_builtin_call("tile_extract", [target, *indices])
+            else:
+                # handles tile views
+                out = adj.add_builtin_call("tile_view", [target, *indices])
+
         else:
             # handles non-array type indexing, e.g: vec3, mat33, etc
             out = adj.add_builtin_call("extract", [target, *indices])
@@ -2527,11 +2596,22 @@ class Adjoint:
         target_type = strip_reference(target.type)
 
         if is_array(target_type):
-            # target_type is not suitable for atomic array accumulation
-            if target_type.dtype not in warp.types.atomic_types:
+            # target_types int8, uint8, int16, uint16 are not suitable for atomic array accumulation
+            if target_type.dtype in warp.types.non_atomic_types:
                 make_new_assign_statement()
                 return
 
+            # the same holds true for vecs/mats/quats that are composed of these types
+            if (
+                type_is_vector(target_type.dtype)
+                or type_is_quaternion(target_type.dtype)
+                or type_is_matrix(target_type.dtype)
+            ):
+                dtype = getattr(target_type.dtype, "_wp_scalar_type_", None)
+                if dtype in warp.types.non_atomic_types:
+                    make_new_assign_statement()
+                    return
+
         kernel_name = adj.fun_name
         filename = adj.filename
         lineno = adj.lineno + adj.fun_lineno
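Note: the hunk above widens the check that decides whether an in-place += on an array element can use atomic accumulation. Arrays of int8, uint8, int16, and uint16 (now listed in warp.types.non_atomic_types), and vector/quaternion/matrix types built from those scalars (detected through _wp_scalar_type_), fall back to an ordinary assignment instead. A hedged sketch of the two paths as they appear in user kernels:

    import warp as wp

    @wp.kernel
    def accumulate(dest_f32: wp.array(dtype=float), dest_i8: wp.array(dtype=wp.int8)):
        tid = wp.tid()
        # float32 supports atomics, so += lowers to an atomic add on the array element
        dest_f32[0] += 1.0
        # int8 is in warp.types.non_atomic_types, so += is rewritten into a plain
        # read-modify-write assignment; concurrent updates to the same element
        # from different threads are therefore not atomic
        dest_i8[tid] += wp.int8(1)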
@@ -2955,6 +3035,7 @@ class Adjoint:
 # code generation
 
 cpu_module_header = """
+#define WP_TILE_BLOCK_DIM {tile_size}
 #define WP_NO_CRT
 #include "builtin.h"
 
@@ -2965,7 +3046,7 @@ cpu_module_header = """
 #define int(x) cast_int(x)
 #define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)
 
-#define builtin_tid1d() wp::tid(task_index)
+#define builtin_tid1d() wp::tid(task_index, dim)
 #define builtin_tid2d(x, y) wp::tid(x, y, task_index, dim)
 #define builtin_tid3d(x, y, z) wp::tid(x, y, z, task_index, dim)
 #define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, task_index, dim)
@@ -2973,6 +3054,7 @@ cpu_module_header = """
 """
 
 cuda_module_header = """
+#define WP_TILE_BLOCK_DIM {tile_size}
 #define WP_NO_CRT
 #include "builtin.h"
 
@@ -2983,10 +3065,10 @@ cuda_module_header = """
 #define int(x) cast_int(x)
 #define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)
 
-#define builtin_tid1d() wp::tid(task_index)
-#define builtin_tid2d(x, y) wp::tid(x, y, task_index, dim)
-#define builtin_tid3d(x, y, z) wp::tid(x, y, z, task_index, dim)
-#define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, task_index, dim)
+#define builtin_tid1d() wp::tid(_idx, dim)
+#define builtin_tid2d(x, y) wp::tid(x, y, _idx, dim)
+#define builtin_tid3d(x, y, z) wp::tid(x, y, z, _idx, dim)
+#define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, _idx, dim)
 
 
 """
@@ -3058,20 +3140,26 @@ cuda_kernel_template = """
 extern "C" __global__ void {name}_cuda_kernel_forward(
     {forward_args})
 {{
-    for (size_t task_index = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
-         task_index < dim.size;
-         task_index += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
+    for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
+         _idx < dim.size;
+         _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
     {{
+        // reset shared memory allocator
+        wp::tile_alloc_shared(0, true);
+
 {forward_body}    }}
 }}
 
 extern "C" __global__ void {name}_cuda_kernel_backward(
     {reverse_args})
 {{
-    for (size_t task_index = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
-         task_index < dim.size;
-         task_index += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
+    for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
+         _idx < dim.size;
+         _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
     {{
+        // reset shared memory allocator
+        wp::tile_alloc_shared(0, true);
+
 {reverse_body}    }}
 }}
@@ -3309,7 +3397,9 @@ def codegen_func_forward(adj, func_type="kernel", device="cpu"):
     lines += ["// primal vars\n"]
 
     for var in adj.variables:
-        if var.constant is None:
+        if is_tile(var.type):
+            lines += [f"{var.ctype()} {var.emit()} = {var.type.cinit(requires_grad=False)};\n"]
+        elif var.constant is None:
             lines += [f"{var.ctype()} {var.emit()};\n"]
         else:
             lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
@@ -3344,7 +3434,9 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
     lines += ["// primal vars\n"]
 
     for var in adj.variables:
-        if var.constant is None:
+        if is_tile(var.type):
+            lines += [f"{var.ctype()} {var.emit()} = {var.type.cinit(requires_grad=True)};\n"]
+        elif var.constant is None:
             lines += [f"{var.ctype()} {var.emit()};\n"]
         else:
             lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
@@ -3354,7 +3446,20 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
     lines += ["// dual vars\n"]
 
     for var in adj.variables:
-        lines += [f"{var.ctype(value_type=True)} {var.emit_adj()} = {{}};\n"]
+        name = var.emit_adj()
+        ctype = var.ctype(value_type=True)
+
+        if is_tile(var.type):
+            if var.type.storage == "register":
+                lines += [
+                    f"{var.type.ctype()} {name}(0.0);\n"
+                ]  # reverse mode tiles alias the forward vars since shared tiles store both primal/dual vars together
+            elif var.type.storage == "shared":
+                lines += [
+                    f"{var.type.ctype()}& {name} = {var.emit()};\n"
+                ]  # reverse mode tiles alias the forward vars since shared tiles store both primal/dual vars together
+        else:
+            lines += [f"{ctype} {name} = {{}};\n"]
 
     # forward pass
     lines += ["//---------\n"]
@@ -3383,6 +3488,33 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
     if options is None:
         options = {}
 
+    if adj.return_var is not None and "return" in adj.arg_types:
+        if get_type_origin(adj.arg_types["return"]) is tuple:
+            if len(get_type_args(adj.arg_types["return"])) != len(adj.return_var):
+                raise WarpCodegenError(
+                    f"The function `{adj.fun_name}` has its return type "
+                    f"annotated as a tuple of {len(get_type_args(adj.arg_types['return']))} elements "
+                    f"but the code returns {len(adj.return_var)} values."
+                )
+            elif not types_equal(adj.arg_types["return"], tuple(x.type for x in adj.return_var)):
+                raise WarpCodegenError(
+                    f"The function `{adj.fun_name}` has its return type "
+                    f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
+                    f"but the code returns a tuple with types `({', '.join(warp.context.type_str(x.type) for x in adj.return_var)})`."
+                )
+        elif len(adj.return_var) > 1 and get_type_origin(adj.arg_types["return"]) is not tuple:
+            raise WarpCodegenError(
+                f"The function `{adj.fun_name}` has its return type "
+                f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
+                f"but the code returns {len(adj.return_var)} values."
+            )
+        elif not types_equal(adj.arg_types["return"], adj.return_var[0].type):
+            raise WarpCodegenError(
+                f"The function `{adj.fun_name}` has its return type "
+                f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
+                f"but the code returns a value of type `{warp.context.type_str(adj.return_var[0].type)}`."
+            )
+
     # forward header
     if adj.return_var is not None and len(adj.return_var) == 1:
         return_type = adj.return_var[0].ctype()
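Note: the block added to codegen_func validates an explicit return annotation on a Warp function against the values its body actually returns, raising WarpCodegenError on any mismatch. A small illustrative sketch of the kind of annotation this covers (not taken from the package):

    import warp as wp

    @wp.func
    def length_sq(v: wp.vec3) -> float:
        # annotation matches the returned type, so codegen accepts it
        return wp.dot(v, v)

    # An annotation that disagrees with the body, e.g. declaring `-> wp.vec3`
    # while returning wp.dot(v, v) (a float), or annotating a two-element tuple
    # while returning three values, now fails module compilation with a
    # WarpCodegenError that reports the annotated vs. actual return types.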
warp/config.py CHANGED
@@ -7,7 +7,7 @@
 
 from typing import Optional
 
-version: str = "1.4.2"
+version: str = "1.5.0"
 """Warp version string"""
 
 verify_fp: bool = False
@@ -16,7 +16,7 @@ Has performance implications.
 """
 
 verify_cuda: bool = False
-"""If `True`, Warp will check for CUDA errors after every launch and memory operation.
+"""If `True`, Warp will check for CUDA errors after every launch operation.
 CUDA error verification cannot be used during graph capture. Has performance implications.
 """