warp-lang 1.4.1__py3-none-manylinux2014_aarch64.whl → 1.5.0__py3-none-manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang might be problematic.
Files changed (164)
  1. warp/__init__.py +4 -0
  2. warp/autograd.py +43 -8
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +21 -2
  6. warp/build_dll.py +23 -6
  7. warp/builtins.py +1920 -111
  8. warp/codegen.py +186 -62
  9. warp/config.py +2 -2
  10. warp/context.py +322 -73
  11. warp/examples/assets/pixel.jpg +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
  13. warp/examples/benchmarks/benchmark_gemm.py +121 -0
  14. warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
  15. warp/examples/benchmarks/benchmark_tile.py +179 -0
  16. warp/examples/core/example_dem.py +2 -1
  17. warp/examples/core/example_mesh_intersect.py +3 -3
  18. warp/examples/fem/example_adaptive_grid.py +37 -10
  19. warp/examples/fem/example_apic_fluid.py +3 -2
  20. warp/examples/fem/example_convection_diffusion_dg.py +4 -5
  21. warp/examples/fem/example_deformed_geometry.py +1 -1
  22. warp/examples/fem/example_diffusion_3d.py +47 -4
  23. warp/examples/fem/example_distortion_energy.py +220 -0
  24. warp/examples/fem/example_magnetostatics.py +127 -85
  25. warp/examples/fem/example_nonconforming_contact.py +5 -5
  26. warp/examples/fem/example_stokes.py +3 -1
  27. warp/examples/fem/example_streamlines.py +12 -19
  28. warp/examples/fem/utils.py +38 -15
  29. warp/examples/optim/example_walker.py +2 -2
  30. warp/examples/sim/example_cloth.py +2 -25
  31. warp/examples/sim/example_jacobian_ik.py +6 -2
  32. warp/examples/sim/example_quadruped.py +2 -1
  33. warp/examples/tile/example_tile_convolution.py +58 -0
  34. warp/examples/tile/example_tile_fft.py +47 -0
  35. warp/examples/tile/example_tile_filtering.py +105 -0
  36. warp/examples/tile/example_tile_matmul.py +79 -0
  37. warp/examples/tile/example_tile_mlp.py +375 -0
  38. warp/fem/__init__.py +8 -0
  39. warp/fem/cache.py +16 -12
  40. warp/fem/dirichlet.py +1 -1
  41. warp/fem/domain.py +44 -1
  42. warp/fem/field/__init__.py +1 -2
  43. warp/fem/field/field.py +31 -19
  44. warp/fem/field/nodal_field.py +101 -49
  45. warp/fem/field/virtual.py +794 -0
  46. warp/fem/geometry/__init__.py +2 -2
  47. warp/fem/geometry/deformed_geometry.py +3 -105
  48. warp/fem/geometry/element.py +13 -0
  49. warp/fem/geometry/geometry.py +165 -5
  50. warp/fem/geometry/grid_2d.py +3 -6
  51. warp/fem/geometry/grid_3d.py +31 -28
  52. warp/fem/geometry/hexmesh.py +3 -46
  53. warp/fem/geometry/nanogrid.py +3 -2
  54. warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
  55. warp/fem/geometry/tetmesh.py +2 -43
  56. warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
  57. warp/fem/integrate.py +683 -261
  58. warp/fem/linalg.py +404 -0
  59. warp/fem/operator.py +101 -18
  60. warp/fem/polynomial.py +5 -5
  61. warp/fem/quadrature/quadrature.py +45 -21
  62. warp/fem/space/__init__.py +45 -11
  63. warp/fem/space/basis_function_space.py +451 -0
  64. warp/fem/space/basis_space.py +58 -11
  65. warp/fem/space/function_space.py +146 -5
  66. warp/fem/space/grid_2d_function_space.py +80 -66
  67. warp/fem/space/grid_3d_function_space.py +113 -68
  68. warp/fem/space/hexmesh_function_space.py +96 -108
  69. warp/fem/space/nanogrid_function_space.py +62 -110
  70. warp/fem/space/quadmesh_function_space.py +208 -0
  71. warp/fem/space/shape/__init__.py +45 -7
  72. warp/fem/space/shape/cube_shape_function.py +328 -54
  73. warp/fem/space/shape/shape_function.py +10 -1
  74. warp/fem/space/shape/square_shape_function.py +328 -60
  75. warp/fem/space/shape/tet_shape_function.py +269 -19
  76. warp/fem/space/shape/triangle_shape_function.py +238 -19
  77. warp/fem/space/tetmesh_function_space.py +69 -37
  78. warp/fem/space/topology.py +38 -0
  79. warp/fem/space/trimesh_function_space.py +179 -0
  80. warp/fem/utils.py +6 -331
  81. warp/jax_experimental.py +3 -1
  82. warp/native/array.h +55 -40
  83. warp/native/builtin.h +124 -43
  84. warp/native/bvh.h +4 -0
  85. warp/native/coloring.cpp +600 -0
  86. warp/native/cuda_util.cpp +14 -0
  87. warp/native/cuda_util.h +2 -1
  88. warp/native/fabric.h +8 -0
  89. warp/native/hashgrid.h +4 -0
  90. warp/native/marching.cu +8 -0
  91. warp/native/mat.h +14 -3
  92. warp/native/mathdx.cpp +59 -0
  93. warp/native/mesh.h +4 -0
  94. warp/native/range.h +13 -1
  95. warp/native/reduce.cpp +9 -1
  96. warp/native/reduce.cu +7 -0
  97. warp/native/runlength_encode.cpp +9 -1
  98. warp/native/runlength_encode.cu +7 -1
  99. warp/native/scan.cpp +8 -0
  100. warp/native/scan.cu +8 -0
  101. warp/native/scan.h +8 -1
  102. warp/native/sparse.cpp +8 -0
  103. warp/native/sparse.cu +8 -0
  104. warp/native/temp_buffer.h +7 -0
  105. warp/native/tile.h +1857 -0
  106. warp/native/tile_gemm.h +341 -0
  107. warp/native/tile_reduce.h +210 -0
  108. warp/native/volume_builder.cu +8 -0
  109. warp/native/volume_builder.h +8 -0
  110. warp/native/warp.cpp +10 -2
  111. warp/native/warp.cu +369 -15
  112. warp/native/warp.h +12 -2
  113. warp/optim/adam.py +39 -4
  114. warp/paddle.py +29 -12
  115. warp/render/render_opengl.py +137 -65
  116. warp/sim/graph_coloring.py +292 -0
  117. warp/sim/integrator_euler.py +4 -2
  118. warp/sim/integrator_featherstone.py +115 -44
  119. warp/sim/integrator_vbd.py +6 -0
  120. warp/sim/model.py +90 -17
  121. warp/stubs.py +651 -85
  122. warp/tape.py +12 -7
  123. warp/tests/assets/pixel.npy +0 -0
  124. warp/tests/aux_test_instancing_gc.py +18 -0
  125. warp/tests/test_array.py +207 -48
  126. warp/tests/test_closest_point_edge_edge.py +8 -8
  127. warp/tests/test_codegen.py +120 -1
  128. warp/tests/test_codegen_instancing.py +30 -0
  129. warp/tests/test_collision.py +110 -0
  130. warp/tests/test_coloring.py +241 -0
  131. warp/tests/test_context.py +34 -0
  132. warp/tests/test_examples.py +18 -4
  133. warp/tests/test_fabricarray.py +33 -0
  134. warp/tests/test_fem.py +453 -113
  135. warp/tests/test_func.py +48 -1
  136. warp/tests/test_generics.py +52 -0
  137. warp/tests/test_iter.py +68 -0
  138. warp/tests/test_mat_scalar_ops.py +1 -1
  139. warp/tests/test_mesh_query_point.py +5 -4
  140. warp/tests/test_module_hashing.py +23 -0
  141. warp/tests/test_paddle.py +27 -87
  142. warp/tests/test_print.py +191 -1
  143. warp/tests/test_spatial.py +1 -1
  144. warp/tests/test_tile.py +700 -0
  145. warp/tests/test_tile_mathdx.py +144 -0
  146. warp/tests/test_tile_mlp.py +383 -0
  147. warp/tests/test_tile_reduce.py +374 -0
  148. warp/tests/test_tile_shared_memory.py +190 -0
  149. warp/tests/test_vbd.py +12 -20
  150. warp/tests/test_volume.py +43 -0
  151. warp/tests/unittest_suites.py +23 -2
  152. warp/tests/unittest_utils.py +4 -0
  153. warp/types.py +339 -73
  154. warp/utils.py +22 -1
  155. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/METADATA +33 -7
  156. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/RECORD +159 -132
  157. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/WHEEL +1 -1
  158. warp/fem/field/test.py +0 -180
  159. warp/fem/field/trial.py +0 -183
  160. warp/fem/space/collocated_function_space.py +0 -102
  161. warp/fem/space/quadmesh_2d_function_space.py +0 -261
  162. warp/fem/space/trimesh_2d_function_space.py +0 -153
  163. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/LICENSE.md +0 -0
  164. {warp_lang-1.4.1.dist-info → warp_lang-1.5.0.dist-info}/top_level.txt +0 -0
warp/codegen.py CHANGED
@@ -23,6 +23,10 @@ from typing import Any, Callable, Dict, Mapping, Optional, Sequence
  import warp.config
  from warp.types import *

+ # used as a globally accessible copy
+ # of current compile options (block_dim) etc
+ options = {}
+

  class WarpCodegenError(RuntimeError):
  def __init__(self, message):
@@ -110,6 +114,16 @@ def get_closure_cell_contents(obj):
  return None


+ def get_type_origin(tp):
+ # Compatible version of `typing.get_origin()` for Python 3.7 and older.
+ return getattr(tp, "__origin__", None)
+
+
+ def get_type_args(tp):
+ # Compatible version of `typing.get_args()` for Python 3.7 and older.
+ return getattr(tp, "__args__", ())
+
+
  def eval_annotations(annotations: Mapping[str, Any], obj: Any) -> Mapping[str, Any]:
  """Un-stringize annotations caused by `from __future__ import annotations` of PEP 563."""
  # Implementation backported from `inspect.get_annotations()` for Python 3.9 and older.
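These `getattr`-based helpers behave like `typing.get_origin()` / `typing.get_args()` on newer interpreters; a quick standalone sanity check (plain Python, nothing Warp-specific assumed):

    from typing import Tuple

    def get_type_origin(tp):
        # same getattr-based fallback as the backport above
        return getattr(tp, "__origin__", None)

    def get_type_args(tp):
        return getattr(tp, "__args__", ())

    print(get_type_origin(Tuple[int, float]))        # <class 'tuple'>
    print(get_type_args(Tuple[int, float]))          # (<class 'int'>, <class 'float'>)
    print(get_type_origin(int), get_type_args(int))  # None ()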
@@ -637,6 +651,8 @@ class Var:
  dtypestr = f"wp::{t.dtype.__name__}"
  classstr = f"wp::{type(t).__name__}"
  return f"{classstr}_t<{dtypestr}>"
+ elif is_tile(t):
+ return t.ctype()
  elif isinstance(t, Struct):
  return t.native_name
  elif isinstance(t, type) and issubclass(t, StructInstance):
@@ -876,7 +892,7 @@ class Adjoint:
  # use source-level argument annotations
  if len(argspec.annotations) < len(argspec.args):
  raise WarpCodegenError(f"Incomplete argument annotations on function {adj.fun_name}")
- adj.arg_types = argspec.annotations
+ adj.arg_types = {k: v for k, v in argspec.annotations.items() if not (k == "return" and v is None)}
  else:
  # use overload argument annotations
  for arg_name in argspec.args:
@@ -914,6 +930,28 @@ class Adjoint:
  # for unit testing errors being spit out from kernels.
  adj.skip_build = False

+ # Collect the LTOIR required at link-time
+ adj.ltoirs = []
+
+ # allocate extra space for a function call that requires its
+ # own shared memory space, we treat shared memory as a stack
+ # where each function pushes and pops space off, the extra
+ # quantity is the 'roofline' amount required for the entire kernel
+ def alloc_shared_extra(adj, num_bytes):
+ adj.max_required_extra_shared_memory = max(adj.max_required_extra_shared_memory, num_bytes)
+
+ # returns the total number of bytes for a function
+ # based on it's own requirements + worst case
+ # requirements of any dependent functions
+ def get_total_required_shared(adj):
+ total_shared = 0
+
+ for var in adj.variables:
+ if is_tile(var.type) and var.type.storage == "shared":
+ total_shared += var.type.size_in_bytes()
+
+ return total_shared + adj.max_required_extra_shared_memory
+
  # generate function ssa form and adjoint
  def build(adj, builder, default_builder_options=None):
  # arg Var read/write flags are held during module rebuilds, so we reset here even when skipping a build
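The "roofline" bookkeeping added here can be pictured with a small standalone sketch (hypothetical helper, not Warp's internal API): since calls push and pop the same shared-memory stack, a function needs its own bytes plus the worst case of any single function it calls.

    def total_required_shared(own_bytes, callees):
        # own_bytes: shared memory the function allocates itself
        # callees: list of (own_bytes, callees) tuples for the functions it calls
        extra = 0
        for callee_own, callee_callees in callees:
            extra = max(extra, total_required_shared(callee_own, callee_callees))
        return own_bytes + extra

    # a kernel using 1 KiB itself and calling helpers that need 2 KiB and 3 KiB
    print(total_required_shared(1024, [(2048, []), (3072, [])]))  # 4096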
@@ -934,12 +972,17 @@
  else:
  adj.builder_options = default_builder_options

+ global options
+ options = adj.builder_options
+
  adj.symbols = {} # map from symbols to adjoint variables
  adj.variables = [] # list of local variables (in order)

  adj.return_var = None # return type for function or kernel
  adj.loop_symbols = [] # symbols at the start of each loop
- adj.loop_const_iter_symbols = [] # iteration variables (constant) for static loops
+ adj.loop_const_iter_symbols = (
+ set()
+ ) # constant iteration variables for static loops (mutating them does not raise an error)

  # blocks
  adj.blocks = [Block()]
@@ -951,6 +994,9 @@
  # used to generate new label indices
  adj.label_count = 0

+ # tracks how much additional shared memory is required by any dependent function calls
+ adj.max_required_extra_shared_memory = 0
+
  # update symbol map for each argument
  for a in adj.args:
  adj.symbols[a.label] = a
@@ -967,6 +1013,7 @@
  e = ex(";".join([msg] + [str(a) for a in data.args])).with_traceback(traceback)
  finally:
  adj.skip_build = True
+ adj.builder = None
  raise e

  if builder is not None:
@@ -976,6 +1023,9 @@
  elif isinstance(a.type, warp.types.array) and isinstance(a.type.dtype, Struct):
  builder.build_struct_recursive(a.type.dtype)

+ # release builder reference for GC
+ adj.builder = None
+
  # code generation methods
  def format_template(adj, template, input_vars, output_var):
  # output var is always the 0th index
@@ -992,9 +1042,9 @@
  if isinstance(a, warp.context.Function):
  # functions don't have a var_ prefix so strip it off here
  if prefix == "var":
- arg_strs.append(a.native_func)
+ arg_strs.append(f"{a.namespace}{a.native_func}")
  else:
- arg_strs.append(f"{prefix}_{a.native_func}")
+ arg_strs.append(f"{a.namespace}{prefix}_{a.native_func}")
  elif is_reference(a.type):
  arg_strs.append(f"{prefix}_{a}")
  elif isinstance(a, Var):
@@ -1276,15 +1326,34 @@
  bound_arg_values,
  )

- if func.dispatch_func is not None:
- # If we have a built-in that requires special handling to dispatch
- # the arguments to the underlying C++ function, then we can resolve
- # these using the `dispatch_func`. Since this is only called from
- # within codegen, we pass it directly `codegen.Var` objects,
- # which allows for some more advanced resolution to be performed,
- # for example by checking whether an argument corresponds to
- # a literal value or references a variable.
+ # immediately allocate output variables so we can pass them into the dispatch method
+ if return_type is None:
+ # void function
+ output = None
+ output_list = []
+ elif not isinstance(return_type, Sequence) or len(return_type) == 1:
+ # single return value function
+ if isinstance(return_type, Sequence):
+ return_type = return_type[0]
+ output = adj.add_var(return_type)
+ output_list = [output]
+ else:
+ # multiple return value function
+ output = [adj.add_var(v) for v in return_type]
+ output_list = output

+ # If we have a built-in that requires special handling to dispatch
+ # the arguments to the underlying C++ function, then we can resolve
+ # these using the `dispatch_func`. Since this is only called from
+ # within codegen, we pass it directly `codegen.Var` objects,
+ # which allows for some more advanced resolution to be performed,
+ # for example by checking whether an argument corresponds to
+ # a literal value or references a variable.
+ if func.lto_dispatch_func is not None:
+ func_args, template_args, ltoirs = func.lto_dispatch_func(
+ func.input_types, return_type, output_list, bound_args, options=adj.builder_options, builder=adj.builder
+ )
+ elif func.dispatch_func is not None:
  func_args, template_args = func.dispatch_func(func.input_types, return_type, bound_args)
  else:
  func_args = tuple(bound_args.values())
@@ -1299,18 +1368,14 @@
  if not isinstance(func_arg, (Reference, warp.context.Function)):
  func_arg = adj.load(func_arg)

- # if the argument is a function, build it recursively
- if isinstance(func_arg, warp.context.Function):
+ # if the argument is a function (and not a builtin), then build it recursively
+ if isinstance(func_arg, warp.context.Function) and not func_arg.is_builtin():
  adj.builder.build_function(func_arg)

  fwd_args.append(strip_reference(func_arg))

  if return_type is None:
  # handles expression (zero output) functions, e.g.: void do_something();
-
- output = None
- output_list = []
-
  forward_call = (
  f"{func.namespace}{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"
  )
@@ -1320,12 +1385,6 @@

  elif not isinstance(return_type, Sequence) or len(return_type) == 1:
  # handle simple function (one output)
-
- if isinstance(return_type, Sequence):
- return_type = return_type[0]
- output = adj.add_var(return_type)
- output_list = [output]
-
  forward_call = f"var_{output} = {func.namespace}{func_name}({adj.format_forward_call_args(fwd_args, use_initializer_list)});"
  replay_call = forward_call
  if func.custom_replay_func is not None:
@@ -1333,10 +1392,6 @@
  else:
  # handle multiple value functions

- output = [adj.add_var(v) for v in return_type]
- output_list = output
-
  forward_call = (
  f"{func.namespace}{func_name}({adj.format_forward_call_args(fwd_args + output, use_initializer_list)});"
  )
@@ -1364,6 +1419,11 @@
  reverse_call = f"{func.namespace}adj_{func.native_func}({arg_str});"
  adj.add_reverse(reverse_call)

+ # update our smem roofline requirements based on any
+ # shared memory required by the dependent function call
+ if not func.is_builtin():
+ adj.alloc_shared_extra(func.adj.get_total_required_shared())
+
  return output

  def add_builtin_call(adj, func_name, args, min_outputs=None):
@@ -1464,7 +1524,10 @@

  # zero adjoints
  for i in body_block.vars:
- reverse.append(adj.indentation + f"\t{i.emit_adj()} = {{}};")
+ if is_tile(i.type):
+ reverse.append(adj.indentation + f"\t{i.emit_adj()}.grad_zero();")
+ else:
+ reverse.append(adj.indentation + f"\t{i.emit_adj()} = {{}};")

  # replay
  for i in body_block.body_replay:
@@ -2000,22 +2063,11 @@
  )
  return range_call

- def begin_record_constant_iter_symbols(adj):
- if len(adj.loop_const_iter_symbols) > 0:
- adj.loop_const_iter_symbols.append(adj.loop_const_iter_symbols[-1])
- else:
- adj.loop_const_iter_symbols.append(set())
-
- def end_record_constant_iter_symbols(adj):
- if len(adj.loop_const_iter_symbols) > 0:
- adj.loop_const_iter_symbols.pop()
-
  def record_constant_iter_symbol(adj, sym):
- if len(adj.loop_const_iter_symbols) > 0:
- adj.loop_const_iter_symbols[-1].add(sym)
+ adj.loop_const_iter_symbols.add(sym)

  def is_constant_iter_symbol(adj, sym):
- return len(adj.loop_const_iter_symbols) > 0 and sym in adj.loop_const_iter_symbols[-1]
+ return sym in adj.loop_const_iter_symbols

  def emit_For(adj, node):
  # try and unroll simple range() statements that use constant args
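For reference, the constant iteration symbols tracked by these methods come from `range()` loops with constant bounds, which Warp can unroll at code-generation time. A hedged user-level sketch (assuming the documented `max_unroll` module option; the loop index `j` below is such a constant iteration symbol):

    import warp as wp

    wp.set_module_options({"max_unroll": 8})  # allow short constant-bound loops to unroll

    @wp.kernel
    def scale_rows(a: wp.array2d(dtype=float), s: float):
        tid = wp.tid()
        # constant-bound range: eligible for unrolling during codegen
        for j in range(4):
            a[tid, j] = a[tid, j] * s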
@@ -2045,7 +2097,6 @@
  iter = adj.eval(node.iter)

  adj.symbols[node.target.id] = adj.begin_for(iter)
- adj.begin_record_constant_iter_symbols()

  # for loops should be side-effect free, here we store a copy
  adj.loop_symbols.append(adj.symbols.copy())
@@ -2056,7 +2107,6 @@

  adj.materialize_redefinitions(adj.loop_symbols[-1])
  adj.loop_symbols.pop()
- adj.end_record_constant_iter_symbols()

  adj.end_for(iter)

@@ -2217,7 +2267,7 @@

  # returns the object being indexed, and the list of indices
  def eval_subscript(adj, node):
- # We want to coalesce multi-dimentional array indexing into a single operation. This needs to deal with expressions like `a[i][j][x][y]` where `a` is a 2D array of matrices,
+ # We want to coalesce multi-dimensional array indexing into a single operation. This needs to deal with expressions like `a[i][j][x][y]` where `a` is a 2D array of matrices,
  # and essentially rewrite it into `a[i, j][x][y]`. Since the AST observes the indexing right-to-left, and we don't want to evaluate the index expressions prematurely,
  # this requires a first loop to check if this `node` only performs indexing on the array, and a second loop to evaluate and collect index variables.
  root = node
@@ -2297,6 +2347,14 @@
  out.is_read = target.is_read
  out.is_write = target.is_write

+ elif is_tile(target_type):
+ if len(indices) == 2:
+ # handles extracting a single element from a tile
+ out = adj.add_builtin_call("tile_extract", [target, *indices])
+ else:
+ # handles tile views
+ out = adj.add_builtin_call("tile_view", [target, *indices])
+
  else:
  # handles non-array type indexing, e.g: vec3, mat33, etc
  out = adj.add_builtin_call("extract", [target, *indices])
@@ -2538,11 +2596,22 @@
  target_type = strip_reference(target.type)

  if is_array(target_type):
- # target_type is not suitable for atomic array accumulation
- if target_type.dtype not in warp.types.atomic_types:
+ # target_types int8, uint8, int16, uint16 are not suitable for atomic array accumulation
+ if target_type.dtype in warp.types.non_atomic_types:
  make_new_assign_statement()
  return

+ # the same holds true for vecs/mats/quats that are composed of these types
+ if (
+ type_is_vector(target_type.dtype)
+ or type_is_quaternion(target_type.dtype)
+ or type_is_matrix(target_type.dtype)
+ ):
+ dtype = getattr(target_type.dtype, "_wp_scalar_type_", None)
+ if dtype in warp.types.non_atomic_types:
+ make_new_assign_statement()
+ return
+
  kernel_name = adj.fun_name
  filename = adj.filename
  lineno = adj.lineno + adj.fun_lineno
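For context, this check decides whether an in-place `+=` on an array element can be emitted as an atomic add; 8- and 16-bit integer dtypes (and vectors/quaternions/matrices built from them) fall back to a plain assignment instead. A hedged user-level illustration of the behaviour this hunk describes, not a guaranteed API contract:

    import warp as wp

    @wp.kernel
    def accumulate(hist_f: wp.array(dtype=float), hist_b: wp.array(dtype=wp.uint8)):
        tid = wp.tid()
        # float32 supports atomic accumulation, so this becomes an atomic add
        hist_f[0] += 1.0
        # uint8 is a non-atomic dtype, so codegen falls back to an ordinary
        # read-modify-write assignment (not thread-safe across the launch)
        hist_b[0] += wp.uint8(1)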
@@ -2559,7 +2628,10 @@
  if warp.config.verify_autograd_array_access:
  target.mark_write(kernel_name=kernel_name, filename=filename, lineno=lineno)
  else:
- print(f"Warning: in-place op {node.op} is not differentiable")
+ if warp.config.verbose:
+ print(f"Warning: in-place op {node.op} is not differentiable")
+ make_new_assign_statement()
+ return

  # TODO
  elif type_is_vector(target_type) or type_is_quaternion(target_type) or type_is_matrix(target_type):
@@ -2963,6 +3035,7 @@
  # code generation

  cpu_module_header = """
+ #define WP_TILE_BLOCK_DIM {tile_size}
  #define WP_NO_CRT
  #include "builtin.h"

@@ -2973,7 +3046,7 @@ cpu_module_header = """
  #define int(x) cast_int(x)
  #define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)

- #define builtin_tid1d() wp::tid(task_index)
+ #define builtin_tid1d() wp::tid(task_index, dim)
  #define builtin_tid2d(x, y) wp::tid(x, y, task_index, dim)
  #define builtin_tid3d(x, y, z) wp::tid(x, y, z, task_index, dim)
  #define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, task_index, dim)
@@ -2981,6 +3054,7 @@ cpu_module_header = """
  """

  cuda_module_header = """
+ #define WP_TILE_BLOCK_DIM {tile_size}
  #define WP_NO_CRT
  #include "builtin.h"

@@ -2991,10 +3065,10 @@ cuda_module_header = """
  #define int(x) cast_int(x)
  #define adj_int(x, adj_x, adj_ret) adj_cast_int(x, adj_x, adj_ret)

- #define builtin_tid1d() wp::tid(task_index)
- #define builtin_tid2d(x, y) wp::tid(x, y, task_index, dim)
- #define builtin_tid3d(x, y, z) wp::tid(x, y, z, task_index, dim)
- #define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, task_index, dim)
+ #define builtin_tid1d() wp::tid(_idx, dim)
+ #define builtin_tid2d(x, y) wp::tid(x, y, _idx, dim)
+ #define builtin_tid3d(x, y, z) wp::tid(x, y, z, _idx, dim)
+ #define builtin_tid4d(x, y, z, w) wp::tid(x, y, z, w, _idx, dim)

  """

@@ -3066,20 +3140,26 @@ cuda_kernel_template = """
  extern "C" __global__ void {name}_cuda_kernel_forward(
  {forward_args})
  {{
- for (size_t task_index = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
- task_index < dim.size;
- task_index += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
+ for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
+ _idx < dim.size;
+ _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
  {{
+ // reset shared memory allocator
+ wp::tile_alloc_shared(0, true);
+
  {forward_body} }}
  }}

  extern "C" __global__ void {name}_cuda_kernel_backward(
  {reverse_args})
  {{
- for (size_t task_index = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
- task_index < dim.size;
- task_index += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
+ for (size_t _idx = static_cast<size_t>(blockDim.x) * static_cast<size_t>(blockIdx.x) + static_cast<size_t>(threadIdx.x);
+ _idx < dim.size;
+ _idx += static_cast<size_t>(blockDim.x) * static_cast<size_t>(gridDim.x))
  {{
+ // reset shared memory allocator
+ wp::tile_alloc_shared(0, true);
+
  {reverse_body} }}
  }}

@@ -3317,7 +3397,9 @@ def codegen_func_forward(adj, func_type="kernel", device="cpu"):
  lines += ["// primal vars\n"]

  for var in adj.variables:
- if var.constant is None:
+ if is_tile(var.type):
+ lines += [f"{var.ctype()} {var.emit()} = {var.type.cinit(requires_grad=False)};\n"]
+ elif var.constant is None:
  lines += [f"{var.ctype()} {var.emit()};\n"]
  else:
  lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
@@ -3352,7 +3434,9 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
  lines += ["// primal vars\n"]

  for var in adj.variables:
- if var.constant is None:
+ if is_tile(var.type):
+ lines += [f"{var.ctype()} {var.emit()} = {var.type.cinit(requires_grad=True)};\n"]
+ elif var.constant is None:
  lines += [f"{var.ctype()} {var.emit()};\n"]
  else:
  lines += [f"const {var.ctype()} {var.emit()} = {constant_str(var.constant)};\n"]
@@ -3362,7 +3446,20 @@ def codegen_func_reverse(adj, func_type="kernel", device="cpu"):
  lines += ["// dual vars\n"]

  for var in adj.variables:
- lines += [f"{var.ctype(value_type=True)} {var.emit_adj()} = {{}};\n"]
+ name = var.emit_adj()
+ ctype = var.ctype(value_type=True)
+
+ if is_tile(var.type):
+ if var.type.storage == "register":
+ lines += [
+ f"{var.type.ctype()} {name}(0.0);\n"
+ ] # reverse mode tiles alias the forward vars since shared tiles store both primal/dual vars together
+ elif var.type.storage == "shared":
+ lines += [
+ f"{var.type.ctype()}& {name} = {var.emit()};\n"
+ ] # reverse mode tiles alias the forward vars since shared tiles store both primal/dual vars together
+ else:
+ lines += [f"{ctype} {name} = {{}};\n"]

  # forward pass
  lines += ["//---------\n"]
@@ -3391,6 +3488,33 @@ def codegen_func(adj, c_func_name: str, device="cpu", options=None):
  if options is None:
  options = {}

+ if adj.return_var is not None and "return" in adj.arg_types:
+ if get_type_origin(adj.arg_types["return"]) is tuple:
+ if len(get_type_args(adj.arg_types["return"])) != len(adj.return_var):
+ raise WarpCodegenError(
+ f"The function `{adj.fun_name}` has its return type "
+ f"annotated as a tuple of {len(get_type_args(adj.arg_types['return']))} elements "
+ f"but the code returns {len(adj.return_var)} values."
+ )
+ elif not types_equal(adj.arg_types["return"], tuple(x.type for x in adj.return_var)):
+ raise WarpCodegenError(
+ f"The function `{adj.fun_name}` has its return type "
+ f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
+ f"but the code returns a tuple with types `({', '.join(warp.context.type_str(x.type) for x in adj.return_var)})`."
+ )
+ elif len(adj.return_var) > 1 and get_type_origin(adj.arg_types["return"]) is not tuple:
+ raise WarpCodegenError(
+ f"The function `{adj.fun_name}` has its return type "
+ f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
+ f"but the code returns {len(adj.return_var)} values."
+ )
+ elif not types_equal(adj.arg_types["return"], adj.return_var[0].type):
+ raise WarpCodegenError(
+ f"The function `{adj.fun_name}` has its return type "
+ f"annotated as `{warp.context.type_str(adj.arg_types['return'])}` "
+ f"but the code returns a value of type `{warp.context.type_str(adj.return_var[0].type)}`."
+ )
+
  # forward header
  if adj.return_var is not None and len(adj.return_var) == 1:
  return_type = adj.return_var[0].ctype()
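These checks compare a `@wp.func` return annotation against what the body actually returns. A minimal sketch of code they accept, with a commented mismatch case that would now raise `WarpCodegenError` when the module is built:

    import warp as wp

    @wp.func
    def length_sq(v: wp.vec3) -> float:
        # annotation matches the returned type: accepted
        return wp.dot(v, v)

    # A mismatched annotation such as
    #
    #   @wp.func
    #   def bad(v: wp.vec3) -> wp.vec3:
    #       return wp.dot(v, v)   # returns float, annotated as vec3
    #
    # is rejected by the new validation at build time.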
warp/config.py CHANGED
@@ -7,7 +7,7 @@

  from typing import Optional

- version: str = "1.4.1"
+ version: str = "1.5.0"
  """Warp version string"""

  verify_fp: bool = False
@@ -16,7 +16,7 @@ Has performance implications.
  """

  verify_cuda: bool = False
- """If `True`, Warp will check for CUDA errors after every launch and memory operation.
+ """If `True`, Warp will check for CUDA errors after every launch operation.
  CUDA error verification cannot be used during graph capture. Has performance implications.
  """
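Usage note for the `verify_cuda` flag documented above, shown as a brief hedged example (standard Warp configuration; the flag has performance implications, so it is intended for debugging):

    import warp as wp

    wp.config.verify_cuda = True  # check for CUDA errors after launches (debug only)
    wp.init()

    @wp.kernel
    def fill(a: wp.array(dtype=float), value: float):
        a[wp.tid()] = value

    a = wp.zeros(16, dtype=float)
    wp.launch(fill, dim=16, inputs=[a, 1.0])
    print(a.numpy())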