warp-lang 1.4.2__py3-none-manylinux2014_aarch64.whl → 1.5.1__py3-none-manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang might be problematic.

Files changed (166)
  1. warp/__init__.py +4 -0
  2. warp/autograd.py +43 -8
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +21 -2
  6. warp/build_dll.py +23 -6
  7. warp/builtins.py +1819 -7
  8. warp/codegen.py +197 -61
  9. warp/config.py +2 -2
  10. warp/context.py +379 -107
  11. warp/examples/assets/pixel.jpg +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
  13. warp/examples/benchmarks/benchmark_gemm.py +121 -0
  14. warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
  15. warp/examples/benchmarks/benchmark_tile.py +179 -0
  16. warp/examples/fem/example_adaptive_grid.py +37 -10
  17. warp/examples/fem/example_apic_fluid.py +3 -2
  18. warp/examples/fem/example_convection_diffusion_dg.py +4 -5
  19. warp/examples/fem/example_deformed_geometry.py +1 -1
  20. warp/examples/fem/example_diffusion_3d.py +47 -4
  21. warp/examples/fem/example_distortion_energy.py +220 -0
  22. warp/examples/fem/example_magnetostatics.py +127 -85
  23. warp/examples/fem/example_nonconforming_contact.py +5 -5
  24. warp/examples/fem/example_stokes.py +3 -1
  25. warp/examples/fem/example_streamlines.py +12 -19
  26. warp/examples/fem/utils.py +38 -15
  27. warp/examples/sim/example_cloth.py +4 -25
  28. warp/examples/sim/example_quadruped.py +2 -1
  29. warp/examples/tile/example_tile_convolution.py +58 -0
  30. warp/examples/tile/example_tile_fft.py +47 -0
  31. warp/examples/tile/example_tile_filtering.py +105 -0
  32. warp/examples/tile/example_tile_matmul.py +79 -0
  33. warp/examples/tile/example_tile_mlp.py +375 -0
  34. warp/fem/__init__.py +8 -0
  35. warp/fem/cache.py +16 -12
  36. warp/fem/dirichlet.py +1 -1
  37. warp/fem/domain.py +44 -1
  38. warp/fem/field/__init__.py +1 -2
  39. warp/fem/field/field.py +31 -19
  40. warp/fem/field/nodal_field.py +101 -49
  41. warp/fem/field/virtual.py +794 -0
  42. warp/fem/geometry/__init__.py +2 -2
  43. warp/fem/geometry/deformed_geometry.py +3 -105
  44. warp/fem/geometry/element.py +13 -0
  45. warp/fem/geometry/geometry.py +165 -7
  46. warp/fem/geometry/grid_2d.py +3 -6
  47. warp/fem/geometry/grid_3d.py +31 -28
  48. warp/fem/geometry/hexmesh.py +3 -46
  49. warp/fem/geometry/nanogrid.py +3 -2
  50. warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
  51. warp/fem/geometry/tetmesh.py +2 -43
  52. warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
  53. warp/fem/integrate.py +683 -261
  54. warp/fem/linalg.py +404 -0
  55. warp/fem/operator.py +101 -18
  56. warp/fem/polynomial.py +5 -5
  57. warp/fem/quadrature/quadrature.py +45 -21
  58. warp/fem/space/__init__.py +45 -11
  59. warp/fem/space/basis_function_space.py +451 -0
  60. warp/fem/space/basis_space.py +58 -11
  61. warp/fem/space/function_space.py +146 -5
  62. warp/fem/space/grid_2d_function_space.py +80 -66
  63. warp/fem/space/grid_3d_function_space.py +113 -68
  64. warp/fem/space/hexmesh_function_space.py +96 -108
  65. warp/fem/space/nanogrid_function_space.py +62 -110
  66. warp/fem/space/quadmesh_function_space.py +208 -0
  67. warp/fem/space/shape/__init__.py +45 -7
  68. warp/fem/space/shape/cube_shape_function.py +328 -54
  69. warp/fem/space/shape/shape_function.py +10 -1
  70. warp/fem/space/shape/square_shape_function.py +328 -60
  71. warp/fem/space/shape/tet_shape_function.py +269 -19
  72. warp/fem/space/shape/triangle_shape_function.py +238 -19
  73. warp/fem/space/tetmesh_function_space.py +69 -37
  74. warp/fem/space/topology.py +38 -0
  75. warp/fem/space/trimesh_function_space.py +179 -0
  76. warp/fem/utils.py +6 -331
  77. warp/jax_experimental.py +3 -1
  78. warp/native/array.h +15 -0
  79. warp/native/builtin.h +66 -26
  80. warp/native/bvh.h +4 -0
  81. warp/native/coloring.cpp +604 -0
  82. warp/native/cuda_util.cpp +68 -51
  83. warp/native/cuda_util.h +2 -1
  84. warp/native/fabric.h +8 -0
  85. warp/native/hashgrid.h +4 -0
  86. warp/native/marching.cu +8 -0
  87. warp/native/mat.h +14 -3
  88. warp/native/mathdx.cpp +59 -0
  89. warp/native/mesh.h +4 -0
  90. warp/native/range.h +13 -1
  91. warp/native/reduce.cpp +9 -1
  92. warp/native/reduce.cu +7 -0
  93. warp/native/runlength_encode.cpp +9 -1
  94. warp/native/runlength_encode.cu +7 -1
  95. warp/native/scan.cpp +8 -0
  96. warp/native/scan.cu +8 -0
  97. warp/native/scan.h +8 -1
  98. warp/native/sparse.cpp +8 -0
  99. warp/native/sparse.cu +8 -0
  100. warp/native/temp_buffer.h +7 -0
  101. warp/native/tile.h +1854 -0
  102. warp/native/tile_gemm.h +341 -0
  103. warp/native/tile_reduce.h +210 -0
  104. warp/native/volume_builder.cu +8 -0
  105. warp/native/volume_builder.h +8 -0
  106. warp/native/warp.cpp +10 -2
  107. warp/native/warp.cu +369 -15
  108. warp/native/warp.h +12 -2
  109. warp/optim/adam.py +39 -4
  110. warp/paddle.py +29 -12
  111. warp/render/render_opengl.py +140 -67
  112. warp/sim/graph_coloring.py +292 -0
  113. warp/sim/import_urdf.py +8 -8
  114. warp/sim/integrator_euler.py +4 -2
  115. warp/sim/integrator_featherstone.py +115 -44
  116. warp/sim/integrator_vbd.py +6 -0
  117. warp/sim/model.py +109 -32
  118. warp/sparse.py +1 -1
  119. warp/stubs.py +569 -4
  120. warp/tape.py +12 -7
  121. warp/tests/assets/pixel.npy +0 -0
  122. warp/tests/aux_test_instancing_gc.py +18 -0
  123. warp/tests/test_array.py +39 -0
  124. warp/tests/test_codegen.py +81 -1
  125. warp/tests/test_codegen_instancing.py +30 -0
  126. warp/tests/test_collision.py +110 -0
  127. warp/tests/test_coloring.py +251 -0
  128. warp/tests/test_context.py +34 -0
  129. warp/tests/test_examples.py +21 -5
  130. warp/tests/test_fem.py +453 -113
  131. warp/tests/test_func.py +34 -4
  132. warp/tests/test_generics.py +52 -0
  133. warp/tests/test_iter.py +68 -0
  134. warp/tests/test_lerp.py +13 -87
  135. warp/tests/test_mat_scalar_ops.py +1 -1
  136. warp/tests/test_matmul.py +6 -9
  137. warp/tests/test_matmul_lite.py +6 -11
  138. warp/tests/test_mesh_query_point.py +1 -1
  139. warp/tests/test_module_hashing.py +23 -0
  140. warp/tests/test_overwrite.py +45 -0
  141. warp/tests/test_paddle.py +27 -87
  142. warp/tests/test_print.py +56 -1
  143. warp/tests/test_smoothstep.py +17 -83
  144. warp/tests/test_spatial.py +1 -1
  145. warp/tests/test_static.py +3 -3
  146. warp/tests/test_tile.py +744 -0
  147. warp/tests/test_tile_mathdx.py +144 -0
  148. warp/tests/test_tile_mlp.py +383 -0
  149. warp/tests/test_tile_reduce.py +374 -0
  150. warp/tests/test_tile_shared_memory.py +190 -0
  151. warp/tests/test_vbd.py +12 -20
  152. warp/tests/test_volume.py +43 -0
  153. warp/tests/unittest_suites.py +19 -2
  154. warp/tests/unittest_utils.py +4 -2
  155. warp/types.py +340 -74
  156. warp/utils.py +23 -3
  157. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/METADATA +32 -7
  158. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/RECORD +161 -134
  159. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/WHEEL +1 -1
  160. warp/fem/field/test.py +0 -180
  161. warp/fem/field/trial.py +0 -183
  162. warp/fem/space/collocated_function_space.py +0 -102
  163. warp/fem/space/quadmesh_2d_function_space.py +0 -261
  164. warp/fem/space/trimesh_2d_function_space.py +0 -153
  165. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/LICENSE.md +0 -0
  166. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/top_level.txt +0 -0
warp/codegen.py CHANGED
@@ -23,6 +23,10 @@ from typing import Any, Callable, Dict, Mapping, Optional, Sequence
 import warp.config
 from warp.types import *
 
+# used as a globally accessible copy
+# of current compile options (block_dim) etc
+options = {}
+
 
 class WarpCodegenError(RuntimeError):
     def __init__(self, message):
@@ -110,6 +114,16 @@ def get_closure_cell_contents(obj):
     return None
 
 
+def get_type_origin(tp):
+    # Compatible version of `typing.get_origin()` for Python 3.7 and older.
+    return getattr(tp, "__origin__", None)
+
+
+def get_type_args(tp):
+    # Compatible version of `typing.get_args()` for Python 3.7 and older.
+    return getattr(tp, "__args__", ())
+
+
 def eval_annotations(annotations: Mapping[str, Any], obj: Any) -> Mapping[str, Any]:
     """Un-stringize annotations caused by `from __future__ import annotations` of PEP 563."""
     # Implementation backported from `inspect.get_annotations()` for Python 3.9 and older.
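Note: the two new helpers simply mirror `typing.get_origin()` / `typing.get_args()` for older interpreters. A minimal sketch of their behavior (plain Python, illustrative only, not part of the package diff):

    from typing import Tuple

    # subscripted annotations expose their origin and arguments via dunders on Python 3.7+
    assert get_type_origin(Tuple[float, int]) is tuple
    assert get_type_args(Tuple[float, int]) == (float, int)

    # plain annotations have neither, so the helpers fall back to None / ()
    assert get_type_origin(float) is None
    assert get_type_args(float) == ()

These are the helpers that the new return-type validation in `codegen_func()` (further down in this diff) relies on.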
@@ -637,6 +651,8 @@ class Var:
                 dtypestr = f"wp::{t.dtype.__name__}"
             classstr = f"wp::{type(t).__name__}"
             return f"{classstr}_t<{dtypestr}>"
+        elif is_tile(t):
+            return t.ctype()
         elif isinstance(t, Struct):
             return t.native_name
         elif isinstance(t, type) and issubclass(t, StructInstance):
@@ -876,7 +892,7 @@ class Adjoint:
             # use source-level argument annotations
             if len(argspec.annotations) < len(argspec.args):
                 raise WarpCodegenError(f"Incomplete argument annotations on function {adj.fun_name}")
-            adj.arg_types = argspec.annotations
+            adj.arg_types = {k: v for k, v in argspec.annotations.items() if not (k == "return" and v is None)}
         else:
             # use overload argument annotations
             for arg_name in argspec.args:
@@ -914,6 +930,28 @@ class Adjoint:
         # for unit testing errors being spit out from kernels.
         adj.skip_build = False
 
+        # Collect the LTOIR required at link-time
+        adj.ltoirs = []
+
+    # allocate extra space for a function call that requires its
+    # own shared memory space, we treat shared memory as a stack
+    # where each function pushes and pops space off, the extra
+    # quantity is the 'roofline' amount required for the entire kernel
+    def alloc_shared_extra(adj, num_bytes):
+        adj.max_required_extra_shared_memory = max(adj.max_required_extra_shared_memory, num_bytes)
+
+    # returns the total number of bytes for a function
+    # based on it's own requirements + worst case
+    # requirements of any dependent functions
+    def get_total_required_shared(adj):
+        total_shared = 0
+
+        for var in adj.variables:
+            if is_tile(var.type) and var.type.storage == "shared":
+                total_shared += var.type.size_in_bytes()
+
+        return total_shared + adj.max_required_extra_shared_memory
+
     # generate function ssa form and adjoint
     def build(adj, builder, default_builder_options=None):
         # arg Var read/write flags are held during module rebuilds, so we reset here even when skipping a build
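The comments above describe a simple worst-case ("roofline") model of shared memory across the call tree: a function needs space for its own shared-storage tiles plus, because calls push and pop the shared-memory stack, the single largest requirement among its callees. A stand-alone sketch of that recurrence, assuming hypothetical byte counts (plain Python, not Warp API):

    # hypothetical shared-tile usage per function, in bytes
    own_shared = {"kernel": 4096, "helper_a": 8192, "helper_b": 2048}
    calls = {"kernel": ["helper_a", "helper_b"], "helper_a": [], "helper_b": []}

    def total_required_shared(f):
        # own tiles + worst case over dependent calls; callees reuse the
        # same stack space, so only the largest one contributes
        extra = max((total_required_shared(c) for c in calls[f]), default=0)
        return own_shared[f] + extra

    assert total_required_shared("kernel") == 4096 + 8192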
@@ -934,6 +972,9 @@ class Adjoint:
         else:
             adj.builder_options = default_builder_options
 
+        global options
+        options = adj.builder_options
+
         adj.symbols = {}  # map from symbols to adjoint variables
         adj.variables = []  # list of local variables (in order)
 
@@ -953,6 +994,9 @@ class Adjoint:
         # used to generate new label indices
         adj.label_count = 0
 
+        # tracks how much additional shared memory is required by any dependent function calls
+        adj.max_required_extra_shared_memory = 0
+
         # update symbol map for each argument
         for a in adj.args:
             adj.symbols[a.label] = a
@@ -969,6 +1013,7 @@ class Adjoint:
                 e = ex(";".join([msg] + [str(a) for a in data.args])).with_traceback(traceback)
             finally:
                 adj.skip_build = True
+                adj.builder = None
             raise e
 
         if builder is not None:
@@ -978,6 +1023,9 @@ class Adjoint:
                 elif isinstance(a.type, warp.types.array) and isinstance(a.type.dtype, Struct):
                     builder.build_struct_recursive(a.type.dtype)
 
+        # release builder reference for GC
+        adj.builder = None
+
     # code generation methods
     def format_template(adj, template, input_vars, output_var):
         # output var is always the 0th index
@@ -994,9 +1042,9 @@ class Adjoint:
             if isinstance(a, warp.context.Function):
                 # functions don't have a var_ prefix so strip it off here
                 if prefix == "var":
-                    arg_strs.append(a.native_func)
+                    arg_strs.append(f"{a.namespace}{a.native_func}")
                 else:
-                    arg_strs.append(f"{prefix}_{a.native_func}")
+                    arg_strs.append(f"{a.namespace}{prefix}_{a.native_func}")
             elif is_reference(a.type):
                 arg_strs.append(f"{prefix}_{a}")
             elif isinstance(a, Var):
@@ -1127,25 +1175,25 @@ class Adjoint:
         left = adj.load(left)
         s = output.emit() + " = " + ("(" * len(comps)) + left.emit() + " "
 
-        prev_comp = None
+        prev_comp_var = None
 
         for op, comp in zip(op_strings, comps):
             comp_chainable = op_str_is_chainable(op)
-            if comp_chainable and prev_comp:
-                # We restrict chaining to operands of the same type
-                if prev_comp.type is comp.type:
-                    prev_comp = adj.load(prev_comp)
-                    comp = adj.load(comp)
-                    s += "&& (" + prev_comp.emit() + " " + op + " " + comp.emit() + ")) "
+            if comp_chainable and prev_comp_var:
+                # We restrict chaining to operands of the same type
+                if prev_comp_var.type is comp.type:
+                    prev_comp_var = adj.load(prev_comp_var)
+                    comp_var = adj.load(comp)
+                    s += "&& (" + prev_comp_var.emit() + " " + op + " " + comp_var.emit() + ")) "
                 else:
                     raise WarpCodegenTypeError(
-                        f"Cannot chain comparisons of unequal types: {prev_comp.type} {op} {comp.type}."
+                        f"Cannot chain comparisons of unequal types: {prev_comp_var.type} {op} {comp.type}."
                     )
             else:
-                comp = adj.load(comp)
-                s += op + " " + comp.emit() + ") "
+                comp_var = adj.load(comp)
+                s += op + " " + comp_var.emit() + ") "
 
-            prev_comp = comp
+            prev_comp_var = comp_var
 
         s = s.rstrip() + ";"
 
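This hunk only renames locals (avoiding reuse of the loop variable `comp`); behavior is unchanged: chained comparisons are lowered to `&&`-joined pairwise checks and chaining is still restricted to operands of identical types. A hedged example of what that means in user code (the kernel below is illustrative, not from the package):

    import warp as wp

    @wp.kernel
    def in_unit_interval(x: wp.array(dtype=float), mask: wp.array(dtype=wp.int32)):
        i = wp.tid()
        # both links compare float32 against float32, so the chain is accepted
        # and emitted as (0.0 < x[i]) && (x[i] < 1.0)
        if 0.0 < x[i] < 1.0:
            mask[i] = 1

    # chaining operands of different types, e.g. an int against a float64 value,
    # raises WarpCodegenTypeError("Cannot chain comparisons of unequal types: ...")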
@@ -1278,15 +1326,34 @@ class Adjoint:
             bound_arg_values,
         )
 
-        if func.dispatch_func is not None:
-            # If we have a built-in that requires special handling to dispatch
-            # the arguments to the underlying C++ function, then we can resolve
-            # these using the `dispatch_func`. Since this is only called from
-            # within codegen, we pass it directly `codegen.Var` objects,
-            # which allows for some more advanced resolution to be performed,
-            # for example by checking whether an argument corresponds to
-            # a literal value or references a variable.
+        # immediately allocate output variables so we can pass them into the dispatch method
+        if return_type is None:
+            # void function
+            output = None
+            output_list = []
+        elif not isinstance(return_type, Sequence) or len(return_type) == 1:
+            # single return value function
+            if isinstance(return_type, Sequence):
+                return_type = return_type[0]
+            output = adj.add_var(return_type)
+            output_list = [output]
+        else:
+            # multiple return value function
+            output = [adj.add_var(v) for v in return_type]
+            output_list = output
 
+        # If we have a built-in that requires special handling to dispatch
+        # the arguments to the underlying C++ function, then we can resolve
+        # these using the `dispatch_func`. Since this is only called from
+        # within codegen, we pass it directly `codegen.Var` objects,
+        # which allows for some more advanced resolution to be performed,
+        # for example by checking whether an argument corresponds to
+        # a literal value or references a variable.
+        if func.lto_dispatch_func is not None:
+            func_args, template_args, ltoirs = func.lto_dispatch_func(
+                func.input_types, return_type, output_list, bound_args, options=adj.builder_options, builder=adj.builder
+            )
+        elif func.dispatch_func is not None:
             func_args, template_args = func.dispatch_func(func.input_types, return_type, bound_args)
         else:
             func_args = tuple(bound_args.values())
@@ -1299,20 +1366,18 @@ class Adjoint:
         fwd_args = []
         for func_arg in func_args:
             if not isinstance(func_arg, (Reference, warp.context.Function)):
-                func_arg = adj.load(func_arg)
+                func_arg_var = adj.load(func_arg)
+            else:
+                func_arg_var = func_arg
 
-            # if the argument is a function, build it recursively
-            if isinstance(func_arg, warp.context.Function):
-                adj.builder.build_function(func_arg)
+            # if the argument is a function (and not a builtin), then build it recursively
+            if isinstance(func_arg_var, warp.context.Function) and not func_arg_var.is_builtin():
+                adj.builder.build_function(func_arg_var)
 
-            fwd_args.append(strip_reference(func_arg))
+            fwd_args.append(strip_reference(func_arg_var))
 
         if return_type is None:
             # handles expression (zero output) functions, e.g.: void do_something();
-
-            output = None
-            output_list = []
-
             forward_call = (
                 f"{func.namespace}{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"
             )
@@ -1322,12 +1387,6 @@ class Adjoint:
 
         elif not isinstance(return_type, Sequence) or len(return_type) == 1:
             # handle simple function (one output)
-
-            if isinstance(return_type, Sequence):
-                return_type = return_type[0]
-            output = adj.add_var(return_type)
-            output_list = [output]
-
             forward_call = f"var_{output} = {func.namespace}{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"
             replay_call = forward_call
             if func.custom_replay_func is not None:
@@ -1335,10 +1394,6 @@ class Adjoint:
 
         else:
             # handle multiple value functions
-
-            output = [adj.add_var(v) for v in return_type]
-            output_list = output
-
             forward_call = (
                 f"{func.namespace}{func_name}({adj.format_forward_call_args(fwd_args + output, use_initializer_list)});"
             )
@@ -1366,6 +1421,11 @@ class Adjoint:
             reverse_call = f"{func.namespace}adj_{func.native_func}({arg_str});"
             adj.add_reverse(reverse_call)
 
+        # update our smem roofline requirements based on any
+        # shared memory required by the dependent function call
+        if not func.is_builtin():
+            adj.alloc_shared_extra(func.adj.get_total_required_shared())
+
         return output
 
     def add_builtin_call(adj, func_name, args, min_outputs=None):
@@ -1466,7 +1526,10 @@ class Adjoint:
 
         # zero adjoints
         for i in body_block.vars:
-            reverse.append(adj.indentation + f"\t{i.emit_adj()} = {{}};")
+            if is_tile(i.type):
+                reverse.append(adj.indentation + f"\t{i.emit_adj()}.grad_zero();")
+            else:
+                reverse.append(adj.indentation + f"\t{i.emit_adj()} = {{}};")
 
         # replay
         for i in body_block.body_replay:
@@ -2206,7 +2269,7 @@ class Adjoint:
 
     # returns the object being indexed, and the list of indices
    def eval_subscript(adj, node):
-        # We want to coalesce multi-dimentional array indexing into a single operation. This needs to deal with expressions like `a[i][j][x][y]` where `a` is a 2D array of matrices,
+        # We want to coalesce multi-dimensional array indexing into a single operation. This needs to deal with expressions like `a[i][j][x][y]` where `a` is a 2D array of matrices,
         # and essentially rewrite it into `a[i, j][x][y]`. Since the AST observes the indexing right-to-left, and we don't want to evaluate the index expressions prematurely,
         # this requires a first loop to check if this `node` only performs indexing on the array, and a second loop to evaluate and collect index variables.
         root = node
@@ -2286,6 +2349,14 @@ class Adjoint:
             out.is_read = target.is_read
             out.is_write = target.is_write
 
+        elif is_tile(target_type):
+            if len(indices) == 2:
+                # handles extracting a single element from a tile
+                out = adj.add_builtin_call("tile_extract", [target, *indices])
+            else:
+                # handles tile views
+                out = adj.add_builtin_call("tile_view", [target, *indices])
+
         else:
             # handles non-array type indexing, e.g: vec3, mat33, etc
             out = adj.add_builtin_call("extract", [target, *indices])
@@ -2500,8 +2571,10 @@ class Adjoint:
             adj.return_var = ()
             for ret in var:
                 if is_reference(ret.type):
-                    ret = adj.add_builtin_call("copy", [ret])
-                adj.return_var += (ret,)
+                    ret_var = adj.add_builtin_call("copy", [ret])
+                else:
+                    ret_var = ret
+                adj.return_var += (ret_var,)
 
         adj.add_return(adj.return_var)
 
@@ -2527,11 +2600,22 @@ class Adjoint:
         target_type = strip_reference(target.type)
 
         if is_array(target_type):
-            # target_type is not suitable for atomic array accumulation
-            if target_type.dtype not in warp.types.atomic_types:
+            # target_types int8, uint8, int16, uint16 are not suitable for atomic array accumulation
+            if target_type.dtype in warp.types.non_atomic_types:
                 make_new_assign_statement()
                 return
 
+            # the same holds true for vecs/mats/quats that are composed of these types
+            if (
+                type_is_vector(target_type.dtype)
+                or type_is_quaternion(target_type.dtype)
+                or type_is_matrix(target_type.dtype)
+            ):
+                dtype = getattr(target_type.dtype, "_wp_scalar_type_", None)
+                if dtype in warp.types.non_atomic_types:
+                    make_new_assign_statement()
+                    return
+
         kernel_name = adj.fun_name
         filename = adj.filename
         lineno = adj.lineno + adj.fun_lineno
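The effect of the widened check is on how `+=` / `-=` on array elements is lowered: dtypes that support atomics keep being emitted as atomic adds, while int8/uint8/int16/uint16 (and vectors, matrices, and quaternions built from them) now fall back to a plain read-modify-write assignment. A hedged sketch (illustrative kernel, not from the package):

    import warp as wp

    @wp.kernel
    def accumulate(hist: wp.array(dtype=wp.float32), counts: wp.array(dtype=wp.int8)):
        i = wp.tid()
        # float32 supports atomics, so this augmented assignment becomes an atomic add
        hist[i % 4] += 1.0
        # int8 is listed in warp.types.non_atomic_types, so this is rewritten into a
        # plain load/add/store and is not safe if threads write to the same element
        counts[i] += wp.int8(1)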
@@ -2955,6 +3039,7 @@ class Adjoint:
 # code generation
 
 cpu_module_header = """
+#define WP_TILE_BLOCK_DIM {tile_size}
 #define WP_NO_CRT
 #include "builtin.h"
 
@@ -2965,7 +3050,7 @@ cpu_module_header = """
 #define int(x) cast_int(x)
 #define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)
 
-#define builtin_tid1d() wp::tid(task_index)
+#define builtin_tid1d() wp::tid(task_index, dim)
 #define builtin_tid2d(x, y) wp::tid(x, y, task_index, dim)
 #define builtin_tid3d(x, y, z) wp::tid(x, y, z, task_index, dim)
 #define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, task_index, dim)
@@ -2973,6 +3058,7 @@ cpu_module_header = """
 """
 
 cuda_module_header = """
+#define WP_TILE_BLOCK_DIM {tile_size}
 #define WP_NO_CRT
 #include "builtin.h"
 
@@ -2983,10 +3069,10 @@ cuda_module_header = """
 #define int(x) cast_int(x)
 #define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)
 
-#define builtin_tid1d() wp::tid(task_index)
-#define builtin_tid2d(x, y) wp::tid(x, y, task_index, dim)
-#define builtin_tid3d(x, y, z) wp::tid(x, y, z, task_index, dim)
-#define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, task_index, dim)
+#define builtin_tid1d() wp::tid(_idx, dim)
+#define builtin_tid2d(x, y) wp::tid(x, y, _idx, dim)
+#define builtin_tid3d(x, y, z) wp::tid(x, y, z, _idx, dim)
+#define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, _idx, dim)
 
 """
 
@@ -3058,20 +3144,26 @@ cuda_kernel_template = """
 extern "C" __global__ void {name}_cuda_kernel_forward(
     {forward_args})
 {{
-    for (size_t task_index = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
-         task_index < dim.size;
-         task_index += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
+    for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
+         _idx < dim.size;
+         _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
     {{
+        // reset shared memory allocator
+        wp::tile_alloc_shared(0, true);
+
 {forward_body}    }}
 }}
 
 extern "C" __global__ void {name}_cuda_kernel_backward(
     {reverse_args})
 {{
-    for (size_t task_index = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
-         task_index < dim.size;
-         task_index += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
+    for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
+         _idx < dim.size;
+         _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
     {{
+        // reset shared memory allocator
+        wp::tile_alloc_shared(0, true);
+
 {reverse_body}    }}
 }}
 
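The `WP_TILE_BLOCK_DIM` define and the shared-memory allocator reset above support the tile programming model introduced in this release. A rough sketch of how it is driven from Python, loosely modeled on the new warp/examples/tile/example_tile_matmul.py; the exact signatures of `wp.tile_load()`, `wp.tile_matmul()`, and `wp.launch_tiled()` are assumptions based on the 1.5 documentation and may differ in later releases:

    import numpy as np
    import warp as wp

    TILE_M, TILE_N, TILE_K = 8, 4, 8
    TILE_THREADS = 64  # becomes the launch block_dim, i.e. WP_TILE_BLOCK_DIM

    @wp.kernel
    def tile_gemm(A: wp.array2d(dtype=float), B: wp.array2d(dtype=float), C: wp.array2d(dtype=float)):
        i, j = wp.tid()
        acc = wp.tile_zeros(m=TILE_M, n=TILE_N, dtype=wp.float32)
        for k in range(int(A.shape[1] / TILE_K)):
            a = wp.tile_load(A, i, k, m=TILE_M, n=TILE_K)
            b = wp.tile_load(B, k, j, m=TILE_K, n=TILE_N)
            wp.tile_matmul(a, b, acc)  # acc += a @ b, cooperatively across the block
        wp.tile_store(C, i, j, acc)

    A = wp.array(np.random.rand(TILE_M, TILE_K), dtype=float)
    B = wp.array(np.random.rand(TILE_K, TILE_N), dtype=float)
    C = wp.zeros((TILE_M, TILE_N), dtype=float)

    # one block of TILE_THREADS threads cooperates on each output tile
    wp.launch_tiled(tile_gemm, dim=[1, 1], inputs=[A, B, C], block_dim=TILE_THREADS)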
@@ -3309,7 +3401,9 @@ def codegen_func_forward(adj, func_type="kernel", device="cpu"):
     lines += ["// primal vars\n"]
 
     for var in adj.variables:
-        if var.constant is None:
+        if is_tile(var.type):
+            lines += [f"{var.ctype()} {var.emit()} = {var.type.cinit(requires_grad=False)};\n"]
+        elif var.constant is None:
             lines += [f"{var.ctype()} {var.emit()};\n"]
         else:
             lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
@@ -3344,7 +3438,9 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
     lines += ["// primal vars\n"]
 
     for var in adj.variables:
-        if var.constant is None:
+        if is_tile(var.type):
+            lines += [f"{var.ctype()} {var.emit()} = {var.type.cinit(requires_grad=True)};\n"]
+        elif var.constant is None:
             lines += [f"{var.ctype()} {var.emit()};\n"]
         else:
             lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
@@ -3354,7 +3450,20 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
     lines += ["// dual vars\n"]
 
     for var in adj.variables:
-        lines += [f"{var.ctype(value_type=True)} {var.emit_adj()} = {{}};\n"]
+        name = var.emit_adj()
+        ctype = var.ctype(value_type=True)
+
+        if is_tile(var.type):
+            if var.type.storage == "register":
+                lines += [
+                    f"{var.type.ctype()} {name}(0.0);\n"
+                ]  # reverse mode tiles alias the forward vars since shared tiles store both primal/dual vars together
+            elif var.type.storage == "shared":
+                lines += [
+                    f"{var.type.ctype()}& {name} = {var.emit()};\n"
+                ]  # reverse mode tiles alias the forward vars since shared tiles store both primal/dual vars together
+        else:
+            lines += [f"{ctype} {name} = {{}};\n"]
 
     # forward pass
     lines += ["//---------\n"]
@@ -3383,6 +3492,33 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
     if options is None:
         options = {}
 
+    if adj.return_var is not None and "return" in adj.arg_types:
+        if get_type_origin(adj.arg_types["return"]) is tuple:
+            if len(get_type_args(adj.arg_types["return"])) != len(adj.return_var):
+                raise WarpCodegenError(
+                    f"The function `{adj.fun_name}` has its return type "
+                    f"annotated as a tuple of {len(get_type_args(adj.arg_types['return']))} elements "
+                    f"but the code returns {len(adj.return_var)} values."
+                )
+            elif not types_equal(adj.arg_types["return"], tuple(x.type for x in adj.return_var)):
+                raise WarpCodegenError(
+                    f"The function `{adj.fun_name}` has its return type "
+                    f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
+                    f"but the code returns a tuple with types `({', '.join(warp.context.type_str(x.type) for x in adj.return_var)})`."
+                )
+        elif len(adj.return_var) > 1 and get_type_origin(adj.arg_types["return"]) is not tuple:
+            raise WarpCodegenError(
+                f"The function `{adj.fun_name}` has its return type "
+                f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
+                f"but the code returns {len(adj.return_var)} values."
+            )
+        elif not types_equal(adj.arg_types["return"], adj.return_var[0].type):
+            raise WarpCodegenError(
+                f"The function `{adj.fun_name}` has its return type "
+                f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
+                f"but the code returns a value of type `{warp.context.type_str(adj.return_var[0].type)}`."
+            )
+
     # forward header
     if adj.return_var is not None and len(adj.return_var) == 1:
         return_type = adj.return_var[0].ctype()
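Net effect: a `@wp.func` whose return annotation does not match what the body actually returns now fails at module build time with the errors above, instead of generating inconsistent C++. A hedged example (the function below is illustrative, not from the package):

    from typing import Tuple

    import warp as wp

    @wp.func
    def polar(x: float, y: float) -> Tuple[float, float]:
        # two returned values match the two-element tuple annotation, so this builds
        return wp.sqrt(x * x + y * y), wp.atan2(y, x)

    # Annotating the same body as `-> float` would now raise WarpCodegenError
    # ("... annotated as `float` but the code returns 2 values.").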
warp/config.py CHANGED
@@ -7,7 +7,7 @@
 
 from typing import Optional
 
-version: str = "1.4.2"
+version: str = "1.5.1"
 """Warp version string"""
 
 verify_fp: bool = False
@@ -16,7 +16,7 @@ Has performance implications.
 """
 
 verify_cuda: bool = False
-"""If `True`, Warp will check for CUDA errors after every launch and memory operation.
+"""If `True`, Warp will check for CUDA errors after every launch operation.
 CUDA error verification cannot be used during graph capture. Has performance implications.
 """
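The docstring change narrows the claim to launch operations only. Usage is unchanged; for debugging, the flag is typically set before `wp.init()` (sketch):

    import warp as wp

    # report CUDA errors close to the launch that caused them; adds synchronization
    # overhead and cannot be combined with CUDA graph capture
    wp.config.verify_cuda = True

    wp.init()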