warp-lang 1.4.2__py3-none-manylinux2014_aarch64.whl → 1.5.1__py3-none-manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (166)
  1. warp/__init__.py +4 -0
  2. warp/autograd.py +43 -8
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +21 -2
  6. warp/build_dll.py +23 -6
  7. warp/builtins.py +1819 -7
  8. warp/codegen.py +197 -61
  9. warp/config.py +2 -2
  10. warp/context.py +379 -107
  11. warp/examples/assets/pixel.jpg +0 -0
  12. warp/examples/benchmarks/benchmark_cloth_paddle.py +86 -0
  13. warp/examples/benchmarks/benchmark_gemm.py +121 -0
  14. warp/examples/benchmarks/benchmark_interop_paddle.py +158 -0
  15. warp/examples/benchmarks/benchmark_tile.py +179 -0
  16. warp/examples/fem/example_adaptive_grid.py +37 -10
  17. warp/examples/fem/example_apic_fluid.py +3 -2
  18. warp/examples/fem/example_convection_diffusion_dg.py +4 -5
  19. warp/examples/fem/example_deformed_geometry.py +1 -1
  20. warp/examples/fem/example_diffusion_3d.py +47 -4
  21. warp/examples/fem/example_distortion_energy.py +220 -0
  22. warp/examples/fem/example_magnetostatics.py +127 -85
  23. warp/examples/fem/example_nonconforming_contact.py +5 -5
  24. warp/examples/fem/example_stokes.py +3 -1
  25. warp/examples/fem/example_streamlines.py +12 -19
  26. warp/examples/fem/utils.py +38 -15
  27. warp/examples/sim/example_cloth.py +4 -25
  28. warp/examples/sim/example_quadruped.py +2 -1
  29. warp/examples/tile/example_tile_convolution.py +58 -0
  30. warp/examples/tile/example_tile_fft.py +47 -0
  31. warp/examples/tile/example_tile_filtering.py +105 -0
  32. warp/examples/tile/example_tile_matmul.py +79 -0
  33. warp/examples/tile/example_tile_mlp.py +375 -0
  34. warp/fem/__init__.py +8 -0
  35. warp/fem/cache.py +16 -12
  36. warp/fem/dirichlet.py +1 -1
  37. warp/fem/domain.py +44 -1
  38. warp/fem/field/__init__.py +1 -2
  39. warp/fem/field/field.py +31 -19
  40. warp/fem/field/nodal_field.py +101 -49
  41. warp/fem/field/virtual.py +794 -0
  42. warp/fem/geometry/__init__.py +2 -2
  43. warp/fem/geometry/deformed_geometry.py +3 -105
  44. warp/fem/geometry/element.py +13 -0
  45. warp/fem/geometry/geometry.py +165 -7
  46. warp/fem/geometry/grid_2d.py +3 -6
  47. warp/fem/geometry/grid_3d.py +31 -28
  48. warp/fem/geometry/hexmesh.py +3 -46
  49. warp/fem/geometry/nanogrid.py +3 -2
  50. warp/fem/geometry/{quadmesh_2d.py → quadmesh.py} +280 -159
  51. warp/fem/geometry/tetmesh.py +2 -43
  52. warp/fem/geometry/{trimesh_2d.py → trimesh.py} +354 -186
  53. warp/fem/integrate.py +683 -261
  54. warp/fem/linalg.py +404 -0
  55. warp/fem/operator.py +101 -18
  56. warp/fem/polynomial.py +5 -5
  57. warp/fem/quadrature/quadrature.py +45 -21
  58. warp/fem/space/__init__.py +45 -11
  59. warp/fem/space/basis_function_space.py +451 -0
  60. warp/fem/space/basis_space.py +58 -11
  61. warp/fem/space/function_space.py +146 -5
  62. warp/fem/space/grid_2d_function_space.py +80 -66
  63. warp/fem/space/grid_3d_function_space.py +113 -68
  64. warp/fem/space/hexmesh_function_space.py +96 -108
  65. warp/fem/space/nanogrid_function_space.py +62 -110
  66. warp/fem/space/quadmesh_function_space.py +208 -0
  67. warp/fem/space/shape/__init__.py +45 -7
  68. warp/fem/space/shape/cube_shape_function.py +328 -54
  69. warp/fem/space/shape/shape_function.py +10 -1
  70. warp/fem/space/shape/square_shape_function.py +328 -60
  71. warp/fem/space/shape/tet_shape_function.py +269 -19
  72. warp/fem/space/shape/triangle_shape_function.py +238 -19
  73. warp/fem/space/tetmesh_function_space.py +69 -37
  74. warp/fem/space/topology.py +38 -0
  75. warp/fem/space/trimesh_function_space.py +179 -0
  76. warp/fem/utils.py +6 -331
  77. warp/jax_experimental.py +3 -1
  78. warp/native/array.h +15 -0
  79. warp/native/builtin.h +66 -26
  80. warp/native/bvh.h +4 -0
  81. warp/native/coloring.cpp +604 -0
  82. warp/native/cuda_util.cpp +68 -51
  83. warp/native/cuda_util.h +2 -1
  84. warp/native/fabric.h +8 -0
  85. warp/native/hashgrid.h +4 -0
  86. warp/native/marching.cu +8 -0
  87. warp/native/mat.h +14 -3
  88. warp/native/mathdx.cpp +59 -0
  89. warp/native/mesh.h +4 -0
  90. warp/native/range.h +13 -1
  91. warp/native/reduce.cpp +9 -1
  92. warp/native/reduce.cu +7 -0
  93. warp/native/runlength_encode.cpp +9 -1
  94. warp/native/runlength_encode.cu +7 -1
  95. warp/native/scan.cpp +8 -0
  96. warp/native/scan.cu +8 -0
  97. warp/native/scan.h +8 -1
  98. warp/native/sparse.cpp +8 -0
  99. warp/native/sparse.cu +8 -0
  100. warp/native/temp_buffer.h +7 -0
  101. warp/native/tile.h +1854 -0
  102. warp/native/tile_gemm.h +341 -0
  103. warp/native/tile_reduce.h +210 -0
  104. warp/native/volume_builder.cu +8 -0
  105. warp/native/volume_builder.h +8 -0
  106. warp/native/warp.cpp +10 -2
  107. warp/native/warp.cu +369 -15
  108. warp/native/warp.h +12 -2
  109. warp/optim/adam.py +39 -4
  110. warp/paddle.py +29 -12
  111. warp/render/render_opengl.py +140 -67
  112. warp/sim/graph_coloring.py +292 -0
  113. warp/sim/import_urdf.py +8 -8
  114. warp/sim/integrator_euler.py +4 -2
  115. warp/sim/integrator_featherstone.py +115 -44
  116. warp/sim/integrator_vbd.py +6 -0
  117. warp/sim/model.py +109 -32
  118. warp/sparse.py +1 -1
  119. warp/stubs.py +569 -4
  120. warp/tape.py +12 -7
  121. warp/tests/assets/pixel.npy +0 -0
  122. warp/tests/aux_test_instancing_gc.py +18 -0
  123. warp/tests/test_array.py +39 -0
  124. warp/tests/test_codegen.py +81 -1
  125. warp/tests/test_codegen_instancing.py +30 -0
  126. warp/tests/test_collision.py +110 -0
  127. warp/tests/test_coloring.py +251 -0
  128. warp/tests/test_context.py +34 -0
  129. warp/tests/test_examples.py +21 -5
  130. warp/tests/test_fem.py +453 -113
  131. warp/tests/test_func.py +34 -4
  132. warp/tests/test_generics.py +52 -0
  133. warp/tests/test_iter.py +68 -0
  134. warp/tests/test_lerp.py +13 -87
  135. warp/tests/test_mat_scalar_ops.py +1 -1
  136. warp/tests/test_matmul.py +6 -9
  137. warp/tests/test_matmul_lite.py +6 -11
  138. warp/tests/test_mesh_query_point.py +1 -1
  139. warp/tests/test_module_hashing.py +23 -0
  140. warp/tests/test_overwrite.py +45 -0
  141. warp/tests/test_paddle.py +27 -87
  142. warp/tests/test_print.py +56 -1
  143. warp/tests/test_smoothstep.py +17 -83
  144. warp/tests/test_spatial.py +1 -1
  145. warp/tests/test_static.py +3 -3
  146. warp/tests/test_tile.py +744 -0
  147. warp/tests/test_tile_mathdx.py +144 -0
  148. warp/tests/test_tile_mlp.py +383 -0
  149. warp/tests/test_tile_reduce.py +374 -0
  150. warp/tests/test_tile_shared_memory.py +190 -0
  151. warp/tests/test_vbd.py +12 -20
  152. warp/tests/test_volume.py +43 -0
  153. warp/tests/unittest_suites.py +19 -2
  154. warp/tests/unittest_utils.py +4 -2
  155. warp/types.py +340 -74
  156. warp/utils.py +23 -3
  157. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/METADATA +32 -7
  158. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/RECORD +161 -134
  159. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/WHEEL +1 -1
  160. warp/fem/field/test.py +0 -180
  161. warp/fem/field/trial.py +0 -183
  162. warp/fem/space/collocated_function_space.py +0 -102
  163. warp/fem/space/quadmesh_2d_function_space.py +0 -261
  164. warp/fem/space/trimesh_2d_function_space.py +0 -153
  165. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/LICENSE.md +0 -0
  166. {warp_lang-1.4.2.dist-info → warp_lang-1.5.1.dist-info}/top_level.txt +0 -0
warp/context.py CHANGED
@@ -7,21 +7,24 @@
7
7
 
8
8
  import ast
9
9
  import ctypes
10
+ import errno
10
11
  import functools
11
12
  import hashlib
12
13
  import inspect
13
14
  import io
14
15
  import itertools
16
+ import json
15
17
  import operator
16
18
  import os
17
19
  import platform
18
20
  import sys
21
+ import time
19
22
  import types
20
23
  import typing
21
24
  import weakref
22
25
  from copy import copy as shallowcopy
23
26
  from pathlib import Path
24
- from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union
27
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
25
28
 
26
29
  import numpy as np
27
30
 
@@ -101,6 +104,7 @@ class Function:
101
104
  value_func=None,
102
105
  export_func=None,
103
106
  dispatch_func=None,
107
+ lto_dispatch_func=None,
104
108
  module=None,
105
109
  variadic=False,
106
110
  initializer_list_func=None,
@@ -137,6 +141,7 @@ class Function:
137
141
  self.value_func = value_func # a function that takes a list of args and a list of templates and returns the value type, e.g.: load(array, index) returns the type of value being loaded
138
142
  self.export_func = export_func
139
143
  self.dispatch_func = dispatch_func
144
+ self.lto_dispatch_func = lto_dispatch_func
140
145
  self.input_types = {}
141
146
  self.export = export
142
147
  self.doc = doc
@@ -235,24 +240,23 @@ class Function:
235
240
  # in a way that is compatible with Python's semantics.
236
241
  signature_params = []
237
242
  signature_default_param_kind = inspect.Parameter.POSITIONAL_OR_KEYWORD
238
- for param_name in self.input_types.keys():
239
- if param_name.startswith("**"):
240
- param_name = param_name[2:]
243
+ for raw_param_name in self.input_types.keys():
244
+ if raw_param_name.startswith("**"):
245
+ param_name = raw_param_name[2:]
241
246
  param_kind = inspect.Parameter.VAR_KEYWORD
242
- elif param_name.startswith("*"):
243
- param_name = param_name[1:]
247
+ elif raw_param_name.startswith("*"):
248
+ param_name = raw_param_name[1:]
244
249
  param_kind = inspect.Parameter.VAR_POSITIONAL
245
250
 
246
251
  # Once a variadic argument like `*args` is found, any following
247
252
  # arguments need to be passed using keywords.
248
253
  signature_default_param_kind = inspect.Parameter.KEYWORD_ONLY
249
254
  else:
255
+ param_name = raw_param_name
250
256
  param_kind = signature_default_param_kind
251
257
 
252
- param = param = inspect.Parameter(
253
- param_name,
254
- param_kind,
255
- default=self.defaults.get(param_name, inspect.Parameter.empty),
258
+ param = inspect.Parameter(
259
+ param_name, param_kind, default=self.defaults.get(param_name, inspect.Parameter.empty)
256
260
  )
257
261
  signature_params.append(param)
258
262
  self.signature = inspect.Signature(signature_params)
@@ -291,22 +295,22 @@ class Function:
291
295
 
292
296
  if hasattr(self, "user_overloads") and len(self.user_overloads):
293
297
  # user-defined function with overloads
298
+ bound_args = self.signature.bind(*args, **kwargs)
299
+ if self.defaults:
300
+ warp.codegen.apply_defaults(bound_args, self.defaults)
294
301
 
295
- if len(kwargs):
296
- raise RuntimeError(
297
- f"Error calling function '{self.key}', keyword arguments are not supported for user-defined overloads."
298
- )
302
+ arguments = tuple(bound_args.arguments.values())
299
303
 
300
304
  # try and find a matching overload
301
305
  for overload in self.user_overloads.values():
302
- if len(overload.input_types) != len(args):
306
+ if len(overload.input_types) != len(arguments):
303
307
  continue
304
308
  template_types = list(overload.input_types.values())
305
309
  arg_names = list(overload.input_types.keys())
306
310
  try:
307
311
  # attempt to unify argument types with function template types
308
- warp.types.infer_argument_types(args, template_types, arg_names)
309
- return overload.func(*args)
312
+ warp.types.infer_argument_types(arguments, template_types, arg_names)
313
+ return overload.func(*arguments)
310
314
  except Exception:
311
315
  continue
312
316
 
@@ -506,11 +510,10 @@ def call_builtin(func: Function, *params) -> Tuple[bool, Any]:
506
510
  if elem_count != arg_type._length_:
507
511
  return (False, None)
508
512
 
509
- # Retrieve the element type of the sequence while ensuring
510
- # that it's homogeneous.
513
+ # Retrieve the element type of the sequence while ensuring that it's homogeneous.
511
514
  elem_type = type(arr[0])
512
- for i in range(1, elem_count):
513
- if type(arr[i]) is not elem_type:
515
+ for array_index in range(1, elem_count):
516
+ if type(arr[array_index]) is not elem_type:
514
517
  raise ValueError("All array elements must share the same type.")
515
518
 
516
519
  expected_elem_type = arg_type._wp_scalar_type_
@@ -540,10 +543,10 @@ def call_builtin(func: Function, *params) -> Tuple[bool, Any]:
540
543
  c_param = arg_type()
541
544
  if warp.types.type_is_matrix(arg_type):
542
545
  rows, cols = arg_type._shape_
543
- for i in range(rows):
544
- idx_start = i * cols
546
+ for row_index in range(rows):
547
+ idx_start = row_index * cols
545
548
  idx_end = idx_start + cols
546
- c_param[i] = arr[idx_start:idx_end]
549
+ c_param[row_index] = arr[idx_start:idx_end]
547
550
  else:
548
551
  c_param[:] = arr
549
552
 
@@ -619,10 +622,13 @@ def call_builtin(func: Function, *params) -> Tuple[bool, Any]:
619
622
 
620
623
 
621
624
  class KernelHooks:
622
- def __init__(self, forward, backward):
625
+ def __init__(self, forward, backward, forward_smem_bytes=0, backward_smem_bytes=0):
623
626
  self.forward = forward
624
627
  self.backward = backward
625
628
 
629
+ self.forward_smem_bytes = forward_smem_bytes
630
+ self.backward_smem_bytes = backward_smem_bytes
631
+
626
632
 
627
633
  # caches source and compiled entry points for a kernel (will be populated after module loads)
628
634
  class Kernel:
@@ -970,8 +976,17 @@ def struct(c):
970
976
  return s
971
977
 
972
978
 
973
- # overload a kernel with the given argument types
974
- def overload(kernel, arg_types=None):
979
+ def overload(kernel, arg_types=Union[None, Dict[str, Any], List[Any]]):
980
+ """Overload a generic kernel with the given argument types.
981
+
982
+ Can be called directly or used as a function decorator.
983
+
984
+ Args:
985
+ kernel: The generic kernel to be instantiated with concrete types.
986
+ arg_types: A list of concrete argument types for the kernel or a
987
+ dictionary specifying generic argument names as keys and concrete
988
+ types as variables.
989
+ """
975
990
  if isinstance(kernel, Kernel):
976
991
  # handle cases where user calls us directly, e.g. wp.overload(kernel, [args...])
977
992
 
@@ -1073,6 +1088,7 @@ def add_builtin(
1073
1088
  value_func=None,
1074
1089
  export_func=None,
1075
1090
  dispatch_func=None,
1091
+ lto_dispatch_func=None,
1076
1092
  doc="",
1077
1093
  namespace="wp::",
1078
1094
  variadic=False,
@@ -1113,6 +1129,9 @@ def add_builtin(
1113
1129
  The arguments returned must be of type `codegen.Var`.
1114
1130
  If not provided, all arguments passed by the users when calling
1115
1131
  the built-in are passed as-is as runtime arguments to the C++ function.
1132
+ lto_dispatch_func (Callable): Same as dispatch_func, but takes an 'option' dict
1133
+ as extra argument (indicating tile_size and target architecture) and returns
1134
+ an LTO-IR buffer as extra return value
1116
1135
  doc (str): Used to generate the Python's docstring and the HTML documentation.
1117
1136
  namespace: Namespace for the underlying C++ function.
1118
1137
  variadic (bool): Whether the function declares variadic arguments.
@@ -1220,16 +1239,16 @@ def add_builtin(
1220
1239
  typelists.append(l)
1221
1240
 
1222
1241
  for arg_types in itertools.product(*typelists):
1223
- arg_types = dict(zip(input_types.keys(), arg_types))
1242
+ concrete_arg_types = dict(zip(input_types.keys(), arg_types))
1224
1243
 
1225
1244
  # Some of these argument lists won't work, eg if the function is mul(), we won't be
1226
1245
  # able to do a matrix vector multiplication for a mat22 and a vec3. The `constraint`
1227
1246
  # function determines which combinations are valid:
1228
1247
  if constraint:
1229
- if constraint(arg_types) is False:
1248
+ if constraint(concrete_arg_types) is False:
1230
1249
  continue
1231
1250
 
1232
- return_type = value_func(arg_types, None)
1251
+ return_type = value_func(concrete_arg_types, None)
1233
1252
 
1234
1253
  # The return_type might just be vector_t(length=3,dtype=wp.float32), so we've got to match that
1235
1254
  # in the list of hard coded types so it knows it's returning one of them:
@@ -1247,11 +1266,12 @@ def add_builtin(
1247
1266
  # finally we can generate a function call for these concrete types:
1248
1267
  add_builtin(
1249
1268
  key,
1250
- input_types=arg_types,
1269
+ input_types=concrete_arg_types,
1251
1270
  value_type=return_type,
1252
1271
  value_func=value_func if return_type is Any else None,
1253
1272
  export_func=export_func,
1254
1273
  dispatch_func=dispatch_func,
1274
+ lto_dispatch_func=lto_dispatch_func,
1255
1275
  doc=doc,
1256
1276
  namespace=namespace,
1257
1277
  variadic=variadic,
@@ -1274,6 +1294,7 @@ def add_builtin(
1274
1294
  value_func=value_func,
1275
1295
  export_func=export_func,
1276
1296
  dispatch_func=dispatch_func,
1297
+ lto_dispatch_func=lto_dispatch_func,
1277
1298
  variadic=variadic,
1278
1299
  initializer_list_func=initializer_list_func,
1279
1300
  export=export,
@@ -1540,6 +1561,8 @@ class ModuleBuilder:
1540
1561
  self.options = options
1541
1562
  self.module = module
1542
1563
  self.deferred_functions = []
1564
+ self.ltoirs = {} # map from lto symbol to lto binary
1565
+ self.ltoirs_decl = {} # map from lto symbol to lto forward declaration
1543
1566
 
1544
1567
  if hasher is None:
1545
1568
  hasher = ModuleHasher(module)
@@ -1607,9 +1630,26 @@ class ModuleBuilder:
1607
1630
  # use dict to preserve import order
1608
1631
  self.functions[func] = None
1609
1632
 
1633
+ def build_meta(self):
1634
+ meta = {}
1635
+
1636
+ for kernel in self.kernels:
1637
+ name = kernel.get_mangled_name()
1638
+
1639
+ meta[name + "_cuda_kernel_forward_smem_bytes"] = kernel.adj.get_total_required_shared()
1640
+ meta[name + "_cuda_kernel_backward_smem_bytes"] = kernel.adj.get_total_required_shared() * 2
1641
+
1642
+ return meta
1643
+
1610
1644
  def codegen(self, device):
1611
1645
  source = ""
1612
1646
 
1647
+ # code-gen LTO forward declarations
1648
+ source += 'extern "C" {\n'
1649
+ for fwd in self.ltoirs_decl.values():
1650
+ source += fwd + "\n"
1651
+ source += "}\n"
1652
+
1613
1653
  # code-gen structs
1614
1654
  visited_structs = set()
1615
1655
  for struct in self.structs.keys():
@@ -1639,9 +1679,9 @@ class ModuleBuilder:
1639
1679
 
1640
1680
  # add headers
1641
1681
  if device == "cpu":
1642
- source = warp.codegen.cpu_module_header + source
1682
+ source = warp.codegen.cpu_module_header.format(tile_size=self.options["block_dim"]) + source
1643
1683
  else:
1644
- source = warp.codegen.cuda_module_header + source
1684
+ source = warp.codegen.cuda_module_header.format(tile_size=self.options["block_dim"]) + source
1645
1685
 
1646
1686
  return source
1647
1687
 
@@ -1660,11 +1700,12 @@ class ModuleExec:
1660
1700
  instance.handle = None
1661
1701
  return instance
1662
1702
 
1663
- def __init__(self, handle, module_hash, device):
1703
+ def __init__(self, handle, module_hash, device, meta):
1664
1704
  self.handle = handle
1665
1705
  self.module_hash = module_hash
1666
1706
  self.device = device
1667
1707
  self.kernel_hooks = {}
1708
+ self.meta = meta
1668
1709
 
1669
1710
  # release the loaded module
1670
1711
  def __del__(self):
@@ -1678,19 +1719,50 @@ class ModuleExec:
1678
1719
 
1679
1720
  # lookup and cache kernel entry points
1680
1721
  def get_kernel_hooks(self, kernel):
1681
- hooks = self.kernel_hooks.get(kernel)
1722
+ # Use kernel.adj as a unique key for cache lookups instead of the kernel itself.
1723
+ # This avoids holding a reference to the kernel and is faster than using
1724
+ # a WeakKeyDictionary with kernels as keys.
1725
+ hooks = self.kernel_hooks.get(kernel.adj)
1682
1726
  if hooks is not None:
1683
1727
  return hooks
1684
1728
 
1685
1729
  name = kernel.get_mangled_name()
1686
1730
 
1687
1731
  if self.device.is_cuda:
1688
- forward = runtime.core.cuda_get_kernel(
1689
- self.device.context, self.handle, (name + "_cuda_kernel_forward").encode("utf-8")
1732
+ forward_name = name + "_cuda_kernel_forward"
1733
+ forward_kernel = runtime.core.cuda_get_kernel(
1734
+ self.device.context, self.handle, forward_name.encode("utf-8")
1690
1735
  )
1691
- backward = runtime.core.cuda_get_kernel(
1692
- self.device.context, self.handle, (name + "_cuda_kernel_backward").encode("utf-8")
1736
+
1737
+ backward_name = name + "_cuda_kernel_backward"
1738
+ backward_kernel = runtime.core.cuda_get_kernel(
1739
+ self.device.context, self.handle, backward_name.encode("utf-8")
1693
1740
  )
1741
+
1742
+ # look up the required shared memory size for each kernel from module metadata
1743
+ forward_smem_bytes = self.meta[forward_name + "_smem_bytes"]
1744
+ backward_smem_bytes = self.meta[backward_name + "_smem_bytes"]
1745
+
1746
+ # configure kernels maximum shared memory size
1747
+ max_smem_bytes = runtime.core.cuda_get_max_shared_memory(self.device.context)
1748
+
1749
+ if not runtime.core.cuda_configure_kernel_shared_memory(forward_kernel, forward_smem_bytes):
1750
+ print(
1751
+ f"Warning: Failed to configure kernel dynamic shared memory for this device, tried to configure {forward_name} kernel for {forward_smem_bytes} bytes, but maximum available is {max_smem_bytes}"
1752
+ )
1753
+
1754
+ options = dict(kernel.module.options)
1755
+ options.update(kernel.options)
1756
+
1757
+ if options["enable_backward"] and not runtime.core.cuda_configure_kernel_shared_memory(
1758
+ backward_kernel, backward_smem_bytes
1759
+ ):
1760
+ print(
1761
+ f"Warning: Failed to configure kernel dynamic shared memory for this device, tried to configure {backward_name} kernel for {backward_smem_bytes} bytes, but maximum available is {max_smem_bytes}"
1762
+ )
1763
+
1764
+ hooks = KernelHooks(forward_kernel, backward_kernel, forward_smem_bytes, backward_smem_bytes)
1765
+
1694
1766
  else:
1695
1767
  func = ctypes.CFUNCTYPE(None)
1696
1768
  forward = (
@@ -1700,9 +1772,9 @@ class ModuleExec:
1700
1772
  func(runtime.llvm.lookup(self.handle.encode("utf-8"), (name + "_cpu_backward").encode("utf-8"))) or None
1701
1773
  )
1702
1774
 
1703
- hooks = KernelHooks(forward, backward)
1704
- self.kernel_hooks[kernel] = hooks
1775
+ hooks = KernelHooks(forward, backward)
1705
1776
 
1777
+ self.kernel_hooks[kernel.adj] = hooks
1706
1778
  return hooks
1707
1779
 
1708
1780
 
@@ -1712,7 +1784,8 @@ class ModuleExec:
1712
1784
  # build cache
1713
1785
  class Module:
1714
1786
  def __init__(self, name, loader):
1715
- self.name = name
1787
+ self.name = name if name is not None else "None"
1788
+
1716
1789
  self.loader = loader
1717
1790
 
1718
1791
  # lookup the latest versions of kernels, functions, and structs by key
@@ -1720,12 +1793,14 @@ class Module:
1720
1793
  self.functions = {} # (key: function)
1721
1794
  self.structs = {} # (key: struct)
1722
1795
 
1723
- # Set of all "live" kernels in this module.
1796
+ # Set of all "live" kernels in this module, i.e., kernels that still have references.
1797
+ # We keep a weak reference to every kernel ever created in this module and rely on Python GC
1798
+ # to release kernels that no longer have any references (in user code or internal bookkeeping).
1724
1799
  # The difference between `live_kernels` and `kernels` is that `live_kernels` may contain
1725
1800
  # multiple kernels with the same key (which is essential to support closures), while `kernels`
1726
1801
  # only holds the latest kernel for each key. When the module is built, we compute the hash
1727
1802
  # of each kernel in `live_kernels` and filter out duplicates for codegen.
1728
- self.live_kernels = weakref.WeakSet()
1803
+ self._live_kernels = weakref.WeakSet()
1729
1804
 
1730
1805
  # executable modules currently loaded
1731
1806
  self.execs = {} # (device.context: ModuleExec)
@@ -1749,6 +1824,7 @@ class Module:
1749
1824
  "fast_math": False,
1750
1825
  "cuda_output": None, # supported values: "ptx", "cubin", or None (automatic)
1751
1826
  "mode": warp.config.mode,
1827
+ "block_dim": 256,
1752
1828
  }
1753
1829
 
1754
1830
  # Module dependencies are determined by scanning each function
@@ -1773,7 +1849,7 @@ class Module:
1773
1849
  self.kernels[kernel.key] = kernel
1774
1850
 
1775
1851
  # track all kernel objects, even if they are duplicates
1776
- self.live_kernels.add(kernel)
1852
+ self._live_kernels.add(kernel)
1777
1853
 
1778
1854
  self.find_references(kernel.adj)
1779
1855
 
@@ -1839,6 +1915,19 @@ class Module:
1839
1915
  # for a reload of module on next launch
1840
1916
  self.mark_modified()
1841
1917
 
1918
+ @property
1919
+ def live_kernels(self):
1920
+ # Return a list of kernels that still have references.
1921
+ # We return a regular list instead of the WeakSet to avoid undesirable issues
1922
+ # if kernels are garbage collected before the caller is done using this list.
1923
+ # Note that we should avoid retaining strong references to kernels unnecessarily
1924
+ # so that Python GC can release kernels that no longer have user references.
1925
+ # It is tempting to call gc.collect() here to force garbage collection,
1926
+ # but this can have undesirable consequences (e.g., GC during graph capture),
1927
+ # so we should avoid it as a general rule. Instead, we rely on Python's
1928
+ # reference counting GC to collect kernels that have gone out of scope.
1929
+ return list(self._live_kernels)
1930
+
1842
1931
  # find kernel corresponding to a Python function
1843
1932
  def find_kernel(self, func):
1844
1933
  qualname = warp.codegen.make_full_qualified_name(func)
@@ -1879,9 +1968,17 @@ class Module:
1879
1968
  self.hasher = ModuleHasher(self)
1880
1969
  return self.hasher.get_module_hash()
1881
1970
 
1882
- def load(self, device) -> ModuleExec:
1971
+ def load(self, device, block_dim=None) -> ModuleExec:
1883
1972
  device = runtime.get_device(device)
1884
1973
 
1974
+ # re-compile module if tile size (blockdim) changes
1975
+ # todo: it would be better to have a method such as `module.get_kernel(block_dim=N)`
1976
+ # that can return a single kernel instance with a given block size
1977
+ if block_dim is not None:
1978
+ if self.options["block_dim"] != block_dim:
1979
+ self.unload()
1980
+ self.options["block_dim"] = block_dim
1981
+
1885
1982
  # compute the hash if needed
1886
1983
  if self.hasher is None:
1887
1984
  self.hasher = ModuleHasher(self)
@@ -1909,6 +2006,7 @@ class Module:
1909
2006
  # determine output paths
1910
2007
  if device.is_cpu:
1911
2008
  output_name = "module_codegen.o"
2009
+ output_arch = None
1912
2010
 
1913
2011
  elif device.is_cuda:
1914
2012
  # determine whether to use PTX or CUBIN
@@ -1947,7 +2045,12 @@ class Module:
1947
2045
  or not warp.config.cache_kernels
1948
2046
  or warp.config.verify_autograd_array_access
1949
2047
  ):
1950
- builder = ModuleBuilder(self, self.options, hasher=self.hasher)
2048
+ builder_options = {
2049
+ **self.options,
2050
+ # Some of the Tile codegen, such as cuFFTDx and cuBLASDx, requires knowledge of the target arch
2051
+ "output_arch": output_arch,
2052
+ }
2053
+ builder = ModuleBuilder(self, builder_options, hasher=self.hasher)
1951
2054
 
1952
2055
  # create a temporary (process unique) dir for build outputs before moving to the binary dir
1953
2056
  build_dir = os.path.join(
@@ -2010,6 +2113,7 @@ class Module:
2010
2113
  config=self.options["mode"],
2011
2114
  fast_math=self.options["fast_math"],
2012
2115
  verify_fp=warp.config.verify_fp,
2116
+ ltoirs=builder.ltoirs.values(),
2013
2117
  )
2014
2118
 
2015
2119
  except Exception as e:
@@ -2017,15 +2121,46 @@ class Module:
2017
2121
  module_load_timer.extra_msg = " (error)"
2018
2122
  raise (e)
2019
2123
 
2124
+ # ------------------------------------------------------------
2125
+ # build meta data
2126
+
2127
+ meta = builder.build_meta()
2128
+ meta_path = os.path.join(build_dir, "module_codegen.meta")
2129
+
2130
+ with open(meta_path, "w") as meta_file:
2131
+ json.dump(meta, meta_file)
2132
+
2020
2133
  # -----------------------------------------------------------
2021
2134
  # update cache
2022
2135
 
2023
- try:
2024
- # Copy process-specific build directory to a process-independent location
2025
- os.rename(build_dir, module_dir)
2026
- except (OSError, FileExistsError):
2027
- # another process likely updated the module dir first
2028
- pass
2136
+ def safe_rename(src, dst, attempts=5, delay=0.1):
2137
+ for i in range(attempts):
2138
+ try:
2139
+ os.rename(src, dst)
2140
+ return
2141
+ except FileExistsError:
2142
+ return
2143
+ except OSError as e:
2144
+ if e.errno == errno.ENOTEMPTY:
2145
+ # if directory exists we assume another process
2146
+ # got there first, in which case we will copy
2147
+ # our output to the directory manually in second step
2148
+ return
2149
+ else:
2150
+ # otherwise assume directory creation failed e.g.: access denied
2151
+ # on Windows we see occasional failures to rename directories due to
2152
+ # some process holding a lock on a file to be moved to workaround
2153
+ # this we make multiple attempts to rename with some delay
2154
+ if i < attempts - 1:
2155
+ time.sleep(delay)
2156
+ else:
2157
+ print(
2158
+ f"Could not update Warp cache with module binaries, trying to rename {build_dir} to {module_dir}, error {e}"
2159
+ )
2160
+ raise e
2161
+
2162
+ # try to move process outputs to cache
2163
+ safe_rename(build_dir, module_dir)
2029
2164
 
2030
2165
  if os.path.exists(module_dir):
2031
2166
  if not os.path.exists(binary_path):
@@ -2053,18 +2188,23 @@ class Module:
2053
2188
 
2054
2189
  # -----------------------------------------------------------
2055
2190
  # Load CPU or CUDA binary
2191
+
2192
+ meta_path = os.path.join(module_dir, "module_codegen.meta")
2193
+ with open(meta_path, "r") as meta_file:
2194
+ meta = json.load(meta_file)
2195
+
2056
2196
  if device.is_cpu:
2057
2197
  # LLVM modules are identified using strings, so we need to ensure uniqueness
2058
2198
  module_handle = f"{module_name}_{self.cpu_exec_id}"
2059
2199
  self.cpu_exec_id += 1
2060
2200
  runtime.llvm.load_obj(binary_path.encode("utf-8"), module_handle.encode("utf-8"))
2061
- module_exec = ModuleExec(module_handle, module_hash, device)
2201
+ module_exec = ModuleExec(module_handle, module_hash, device, meta)
2062
2202
  self.execs[None] = module_exec
2063
2203
 
2064
2204
  elif device.is_cuda:
2065
2205
  cuda_module = warp.build.load_cuda(binary_path, device)
2066
2206
  if cuda_module is not None:
2067
- module_exec = ModuleExec(cuda_module, module_hash, device)
2207
+ module_exec = ModuleExec(cuda_module, module_hash, device, meta)
2068
2208
  self.execs[device.context] = module_exec
2069
2209
  else:
2070
2210
  module_load_timer.extra_msg = " (error)"
@@ -2719,21 +2859,16 @@ class Graph:
2719
2859
 
2720
2860
  class Runtime:
2721
2861
  def __init__(self):
2722
- if sys.version_info < (3, 7):
2723
- raise RuntimeError("Warp requires Python 3.7 as a minimum")
2862
+ if sys.version_info < (3, 8):
2863
+ raise RuntimeError("Warp requires Python 3.8 as a minimum")
2724
2864
  if sys.version_info < (3, 9):
2725
2865
  warp.utils.warn(f"Python 3.9 or newer is recommended for running Warp, detected {sys.version_info}")
2726
2866
 
2727
2867
  bin_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "bin")
2728
2868
 
2729
2869
  if os.name == "nt":
2730
- if sys.version_info >= (3, 8):
2731
- # Python >= 3.8 this method to add dll search paths
2732
- os.add_dll_directory(bin_path)
2733
-
2734
- else:
2735
- # Python < 3.8 we add dll directory to path
2736
- os.environ["PATH"] = bin_path + os.pathsep + os.environ["PATH"]
2870
+ # Python >= 3.8 this method to add dll search paths
2871
+ os.add_dll_directory(bin_path)
2737
2872
 
2738
2873
  warp_lib = os.path.join(bin_path, "warp.dll")
2739
2874
  llvm_lib = os.path.join(bin_path, "warp-clang.dll")
@@ -3205,6 +3340,8 @@ class Runtime:
3205
3340
  self.core.is_cuda_compatibility_enabled.restype = ctypes.c_int
3206
3341
  self.core.is_cutlass_enabled.argtypes = None
3207
3342
  self.core.is_cutlass_enabled.restype = ctypes.c_int
3343
+ self.core.is_mathdx_enabled.argtypes = None
3344
+ self.core.is_mathdx_enabled.restype = ctypes.c_int
3208
3345
 
3209
3346
  self.core.cuda_driver_version.argtypes = None
3210
3347
  self.core.cuda_driver_version.restype = ctypes.c_int
@@ -3329,17 +3466,58 @@ class Runtime:
3329
3466
  self.core.cuda_graph_destroy.restype = ctypes.c_bool
3330
3467
 
3331
3468
  self.core.cuda_compile_program.argtypes = [
3332
- ctypes.c_char_p,
3333
- ctypes.c_int,
3334
- ctypes.c_char_p,
3335
- ctypes.c_bool,
3336
- ctypes.c_bool,
3337
- ctypes.c_bool,
3338
- ctypes.c_bool,
3339
- ctypes.c_char_p,
3469
+ ctypes.c_char_p, # cuda_src
3470
+ ctypes.c_int, # arch
3471
+ ctypes.c_char_p, # include_dir
3472
+ ctypes.c_int, # num_cuda_include_dirs
3473
+ ctypes.POINTER(ctypes.c_char_p), # cuda include dirs
3474
+ ctypes.c_bool, # debug
3475
+ ctypes.c_bool, # verbose
3476
+ ctypes.c_bool, # verify_fp
3477
+ ctypes.c_bool, # fast_math
3478
+ ctypes.c_char_p, # output_path
3479
+ ctypes.c_size_t, # num_ltoirs
3480
+ ctypes.POINTER(ctypes.c_char_p), # ltoirs
3481
+ ctypes.POINTER(ctypes.c_size_t), # ltoir_sizes
3340
3482
  ]
3341
3483
  self.core.cuda_compile_program.restype = ctypes.c_size_t
3342
3484
 
3485
+ self.core.cuda_compile_fft.argtypes = [
3486
+ ctypes.c_char_p, # lto
3487
+ ctypes.c_char_p, # function name
3488
+ ctypes.c_int, # num include dirs
3489
+ ctypes.POINTER(ctypes.c_char_p), # include dirs
3490
+ ctypes.c_char_p, # mathdx include dir
3491
+ ctypes.c_int, # arch
3492
+ ctypes.c_int, # size
3493
+ ctypes.c_int, # ept
3494
+ ctypes.c_int, # direction
3495
+ ctypes.c_int, # precision
3496
+ ctypes.POINTER(ctypes.c_int), # smem (out)
3497
+ ]
3498
+ self.core.cuda_compile_fft.restype = ctypes.c_bool
3499
+
3500
+ self.core.cuda_compile_dot.argtypes = [
3501
+ ctypes.c_char_p, # lto
3502
+ ctypes.c_char_p, # function name
3503
+ ctypes.c_int, # num include dirs
3504
+ ctypes.POINTER(ctypes.c_char_p), # include dirs
3505
+ ctypes.c_char_p, # mathdx include dir
3506
+ ctypes.c_int, # arch
3507
+ ctypes.c_int, # M
3508
+ ctypes.c_int, # N
3509
+ ctypes.c_int, # K
3510
+ ctypes.c_int, # a_precision
3511
+ ctypes.c_int, # b_precision
3512
+ ctypes.c_int, # c_precision
3513
+ ctypes.c_int, # type
3514
+ ctypes.c_int, # a_arrangement
3515
+ ctypes.c_int, # b_arrangement
3516
+ ctypes.c_int, # c_arrangement
3517
+ ctypes.c_int, # num threads
3518
+ ]
3519
+ self.core.cuda_compile_dot.restype = ctypes.c_bool
3520
+
3343
3521
  self.core.cuda_load_module.argtypes = [ctypes.c_void_p, ctypes.c_char_p]
3344
3522
  self.core.cuda_load_module.restype = ctypes.c_void_p
3345
3523
 
@@ -3349,11 +3527,19 @@ class Runtime:
3349
3527
  self.core.cuda_get_kernel.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_char_p]
3350
3528
  self.core.cuda_get_kernel.restype = ctypes.c_void_p
3351
3529
 
3530
+ self.core.cuda_get_max_shared_memory.argtypes = [ctypes.c_void_p]
3531
+ self.core.cuda_get_max_shared_memory.restype = ctypes.c_int
3532
+
3533
+ self.core.cuda_configure_kernel_shared_memory.argtypes = [ctypes.c_void_p, ctypes.c_int]
3534
+ self.core.cuda_configure_kernel_shared_memory.restype = ctypes.c_bool
3535
+
3352
3536
  self.core.cuda_launch_kernel.argtypes = [
3353
3537
  ctypes.c_void_p,
3354
3538
  ctypes.c_void_p,
3355
3539
  ctypes.c_size_t,
3356
3540
  ctypes.c_int,
3541
+ ctypes.c_int,
3542
+ ctypes.c_int,
3357
3543
  ctypes.POINTER(ctypes.c_void_p),
3358
3544
  ctypes.c_void_p,
3359
3545
  ]
@@ -3382,6 +3568,23 @@ class Runtime:
3382
3568
  self.core.cuda_timing_end.argtypes = []
3383
3569
  self.core.cuda_timing_end.restype = None
3384
3570
 
3571
+ self.core.graph_coloring.argtypes = [
3572
+ ctypes.c_int,
3573
+ warp.types.array_t,
3574
+ ctypes.c_int,
3575
+ warp.types.array_t,
3576
+ ]
3577
+ self.core.graph_coloring.restype = ctypes.c_int
3578
+
3579
+ self.core.balance_coloring.argtypes = [
3580
+ ctypes.c_int,
3581
+ warp.types.array_t,
3582
+ ctypes.c_int,
3583
+ ctypes.c_float,
3584
+ warp.types.array_t,
3585
+ ]
3586
+ self.core.balance_coloring.restype = ctypes.c_float
3587
+
3385
3588
  self.core.init.restype = ctypes.c_int
3386
3589
 
3387
3590
  except AttributeError as e:
@@ -3607,10 +3810,7 @@ class Runtime:
3607
3810
 
3608
3811
  def load_dll(self, dll_path):
3609
3812
  try:
3610
- if sys.version_info >= (3, 8):
3611
- dll = ctypes.CDLL(dll_path, winmode=0)
3612
- else:
3613
- dll = ctypes.CDLL(dll_path)
3813
+ dll = ctypes.CDLL(dll_path, winmode=0)
3614
3814
  except OSError as e:
3615
3815
  if "GLIBCXX" in str(e):
3616
3816
  raise RuntimeError(
@@ -3751,7 +3951,7 @@ def is_cuda_available() -> bool:
3751
3951
  return get_cuda_device_count() > 0
3752
3952
 
3753
3953
 
3754
- def is_device_available(device):
3954
+ def is_device_available(device: Device) -> bool:
3755
3955
  return device in get_devices()
3756
3956
 
3757
3957
 
@@ -3811,7 +4011,7 @@ def get_cuda_devices() -> List[Device]:
3811
4011
 
3812
4012
 
3813
4013
  def get_preferred_device() -> Device:
3814
- """Returns the preferred compute device, CUDA if available and CPU otherwise."""
4014
+ """Returns the preferred compute device, ``cuda:0`` if available and ``cpu`` otherwise."""
3815
4015
 
3816
4016
  init()
3817
4017
 
@@ -3896,7 +4096,7 @@ def set_mempool_enabled(device: Devicelike, enable: bool) -> None:
3896
4096
  They should generally be enabled, but there is a rare caveat. Copying data between different GPUs
3897
4097
  may fail during graph capture if the memory was allocated using pooled allocators and memory pool
3898
4098
  access is not enabled between the two GPUs. This is an internal CUDA limitation that is not related
3899
- to Warp. The preferred solution is to enable memory pool access using `warp.set_mempool_access_enabled()`.
4099
+ to Warp. The preferred solution is to enable memory pool access using :func:`set_mempool_access_enabled`.
3900
4100
  If peer access is not supported, then the default CUDA allocators must be used to pre-allocate the memory
3901
4101
  prior to graph capture.
3902
4102
  """
@@ -3951,7 +4151,7 @@ def set_mempool_release_threshold(device: Devicelike, threshold: Union[int, floa
3951
4151
 
3952
4152
 
3953
4153
  def get_mempool_release_threshold(device: Devicelike) -> int:
3954
- """Get the CUDA memory pool release threshold on the device."""
4154
+ """Get the CUDA memory pool release threshold on the device in bytes."""
3955
4155
 
3956
4156
  init()
3957
4157
 
@@ -3970,7 +4170,7 @@ def is_peer_access_supported(target_device: Devicelike, peer_device: Devicelike)
3970
4170
  """Check if `peer_device` can directly access the memory of `target_device` on this system.
3971
4171
 
3972
4172
  This applies to memory allocated using default CUDA allocators. For memory allocated using
3973
- CUDA pooled allocators, use `is_mempool_access_supported()`.
4173
+ CUDA pooled allocators, use :func:`is_mempool_access_supported()`.
3974
4174
 
3975
4175
  Returns:
3976
4176
  A Boolean value indicating if this peer access is supported by the system.
@@ -3991,7 +4191,7 @@ def is_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike) -
3991
4191
  """Check if `peer_device` can currently access the memory of `target_device`.
3992
4192
 
3993
4193
  This applies to memory allocated using default CUDA allocators. For memory allocated using
3994
- CUDA pooled allocators, use `is_mempool_access_enabled()`.
4194
+ CUDA pooled allocators, use :func:`is_mempool_access_enabled()`.
3995
4195
 
3996
4196
  Returns:
3997
4197
  A Boolean value indicating if this peer access is currently enabled.
@@ -4015,7 +4215,7 @@ def set_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike,
4015
4215
  a negative impact on memory consumption and allocation performance.
4016
4216
 
4017
4217
  This applies to memory allocated using default CUDA allocators. For memory allocated using
4018
- CUDA pooled allocators, use `set_mempool_access_enabled()`.
4218
+ CUDA pooled allocators, use :func:`set_mempool_access_enabled()`.
4019
4219
  """
4020
4220
 
4021
4221
  init()
@@ -4043,7 +4243,8 @@ def set_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike,
4043
4243
  def is_mempool_access_supported(target_device: Devicelike, peer_device: Devicelike) -> bool:
4044
4244
  """Check if `peer_device` can directly access the memory pool of `target_device`.
4045
4245
 
4046
- If mempool access is possible, it can be managed using `set_mempool_access_enabled()` and `is_mempool_access_enabled()`.
4246
+ If mempool access is possible, it can be managed using :func:`set_mempool_access_enabled()`
4247
+ and :func:`is_mempool_access_enabled()`.
4047
4248
 
4048
4249
  Returns:
4049
4250
  A Boolean value indicating if this memory pool access is supported by the system.
@@ -4061,7 +4262,7 @@ def is_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike
4061
4262
  """Check if `peer_device` can currently access the memory pool of `target_device`.
4062
4263
 
4063
4264
  This applies to memory allocated using CUDA pooled allocators. For memory allocated using
4064
- default CUDA allocators, use `is_peer_access_enabled()`.
4265
+ default CUDA allocators, use :func:`is_peer_access_enabled()`.
4065
4266
 
4066
4267
  Returns:
4067
4268
  A Boolean value indicating if this peer access is currently enabled.
@@ -4082,7 +4283,7 @@ def set_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelik
4082
4283
  """Enable or disable access from `peer_device` to the memory pool of `target_device`.
4083
4284
 
4084
4285
  This applies to memory allocated using CUDA pooled allocators. For memory allocated using
4085
- default CUDA allocators, use `set_peer_access_enabled()`.
4286
+ default CUDA allocators, use :func:`set_peer_access_enabled()`.
4086
4287
  """
4087
4288
 
4088
4289
  init()
@@ -4791,7 +4992,9 @@ def pack_arg(kernel, arg_type, arg_name, value, device, adjoint=False):
4791
4992
  # represents all data required for a kernel launch
4792
4993
  # so that launches can be replayed quickly, use `wp.launch(..., record_cmd=True)`
4793
4994
  class Launch:
4794
- def __init__(self, kernel, device, hooks=None, params=None, params_addr=None, bounds=None, max_blocks=0):
4995
+ def __init__(
4996
+ self, kernel, device, hooks=None, params=None, params_addr=None, bounds=None, max_blocks=0, block_dim=256
4997
+ ):
4795
4998
  # retain the module executable so it doesn't get unloaded
4796
4999
  self.module_exec = kernel.module.load(device)
4797
5000
  if not self.module_exec:
@@ -4830,6 +5033,7 @@ class Launch:
4830
5033
  self.device = device
4831
5034
  self.bounds = bounds
4832
5035
  self.max_blocks = max_blocks
5036
+ self.block_dim = block_dim
4833
5037
 
4834
5038
  def set_dim(self, dim):
4835
5039
  self.bounds = warp.types.launch_bounds_t(dim)
@@ -4911,6 +5115,8 @@ class Launch:
4911
5115
  self.hooks.forward,
4912
5116
  self.bounds.size,
4913
5117
  self.max_blocks,
5118
+ self.block_dim,
5119
+ self.hooks.forward_smem_bytes,
4914
5120
  self.params_addr,
4915
5121
  stream.cuda_stream,
4916
5122
  )
@@ -4929,6 +5135,7 @@ def launch(
4929
5135
  record_tape=True,
4930
5136
  record_cmd=False,
4931
5137
  max_blocks=0,
5138
+ block_dim=256,
4932
5139
  ):
4933
5140
  """Launch a Warp kernel on the target device
4934
5141
 
@@ -4948,6 +5155,7 @@ def launch(
4948
5155
  record_cmd: When True the launch will be returned as a ``Launch`` command object, the launch will not occur until the user calls ``cmd.launch()``
4949
5156
  max_blocks: The maximum number of CUDA thread blocks to use. Only has an effect for CUDA kernel launches.
4950
5157
  If negative or zero, the maximum hardware value will be used.
5158
+ block_dim: The number of threads per block.
4951
5159
  """
4952
5160
 
4953
5161
  init()
@@ -5001,7 +5209,12 @@ def launch(
5001
5209
  kernel = kernel.add_overload(fwd_types)
5002
5210
 
5003
5211
  # delay load modules, including new overload if needed
5004
- module_exec = kernel.module.load(device)
5212
+ try:
5213
+ module_exec = kernel.module.load(device, block_dim)
5214
+ except Exception:
5215
+ kernel.adj.skip_build = True
5216
+ raise
5217
+
5005
5218
  if not module_exec:
5006
5219
  return
5007
5220
 
@@ -5057,7 +5270,14 @@ def launch(
5057
5270
  )
5058
5271
 
5059
5272
  runtime.core.cuda_launch_kernel(
5060
- device.context, hooks.backward, bounds.size, max_blocks, kernel_params, stream.cuda_stream
5273
+ device.context,
5274
+ hooks.backward,
5275
+ bounds.size,
5276
+ max_blocks,
5277
+ block_dim,
5278
+ hooks.backward_smem_bytes,
5279
+ kernel_params,
5280
+ stream.cuda_stream,
5061
5281
  )
5062
5282
 
5063
5283
  else:
@@ -5074,13 +5294,22 @@ def launch(
5074
5294
  params_addr=kernel_params,
5075
5295
  bounds=bounds,
5076
5296
  device=device,
5297
+ max_blocks=max_blocks,
5298
+ block_dim=block_dim,
5077
5299
  )
5078
5300
  return launch
5079
5301
 
5080
5302
  else:
5081
5303
  # launch
5082
5304
  runtime.core.cuda_launch_kernel(
5083
- device.context, hooks.forward, bounds.size, max_blocks, kernel_params, stream.cuda_stream
5305
+ device.context,
5306
+ hooks.forward,
5307
+ bounds.size,
5308
+ max_blocks,
5309
+ block_dim,
5310
+ hooks.forward_smem_bytes,
5311
+ kernel_params,
5312
+ stream.cuda_stream,
5084
5313
  )
5085
5314
 
5086
5315
  try:
@@ -5094,13 +5323,65 @@ def launch(
5094
5323
  # record file, lineno, func as metadata
5095
5324
  frame = inspect.currentframe().f_back
5096
5325
  caller = {"file": frame.f_code.co_filename, "lineno": frame.f_lineno, "func": frame.f_code.co_name}
5097
- runtime.tape.record_launch(kernel, dim, max_blocks, inputs, outputs, device, metadata={"caller": caller})
5326
+ runtime.tape.record_launch(
5327
+ kernel, dim, max_blocks, inputs, outputs, device, block_dim, metadata={"caller": caller}
5328
+ )
5098
5329
 
5099
5330
  # detect illegal inter-kernel read/write access patterns if verification flag is set
5100
5331
  if warp.config.verify_autograd_array_access:
5101
5332
  runtime.tape._check_kernel_array_access(kernel, fwd_args)
5102
5333
 
5103
5334
 
5335
+ def launch_tiled(*args, **kwargs):
5336
+ """A helper method for launching a grid with an extra trailing dimension equal to the block size.
5337
+
5338
+ For example, to launch a 2D grid, where each element has 64 threads assigned you would use the following:
5339
+
5340
+ .. code-block:: python
5341
+
5342
+ wp.launch_tiled(kernel, [M, N], inputs=[...], block_dim=64)
5343
+
5344
+ Which is equivalent to the following:
5345
+
5346
+ .. code-block:: python
5347
+
5348
+ wp.launch(kernel, [M, N, 64], inputs=[...], block_dim=64)
5349
+
5350
+ Inside your kernel code you can retrieve the first two indices of the thread as usual, ignoring the implicit third dimension if desired:
5351
+
5352
+ .. code-block:: python
5353
+
5354
+ @wp.kernel
5355
+ def compute():
5356
+
5357
+ i, j = wp.tid()
5358
+
5359
+ ...
5360
+ """
5361
+
5362
+ # 'dim' and 'block_dim' must be passed as keyword args; dim is then promoted to a list in case it was passed as a scalar or tuple
5363
+ if "dim" not in kwargs:
5364
+ raise RuntimeError("Launch dimensions 'dim' argument should be passed via. keyword args for wp.launch_tiled()")
5365
+
5366
+ if "block_dim" not in kwargs:
5367
+ raise RuntimeError(
5368
+ "Launch block dimension 'block_dim' argument should be passed via. keyword args for wp.launch_tiled()"
5369
+ )
5370
+
5371
+ dim = kwargs["dim"]
5372
+ if not isinstance(dim, list):
5373
+ dim = list(dim) if isinstance(dim, tuple) else [dim]
5374
+
5375
+ if len(dim) > 3:
5376
+ raise RuntimeError("wp.launch_tiled() requires a grid with fewer than 4 dimensions")
5377
+
5378
+ # add trailing dimension
5379
+ kwargs["dim"] = dim + [kwargs["block_dim"]]
5380
+
5381
+ # forward to original launch method
5382
+ return launch(*args, **kwargs)
5383
+
5384
+
5104
5385
  def synchronize():
5105
5386
  """Manually synchronize the calling CPU thread with any outstanding CUDA work on all devices
5106
5387
 
@@ -5619,16 +5900,6 @@ def type_str(t):
5619
5900
  return "Any"
5620
5901
  elif t == Callable:
5621
5902
  return "Callable"
5622
- elif t == Tuple[int]:
5623
- return "Tuple[int]"
5624
- elif t == Tuple[int, int]:
5625
- return "Tuple[int, int]"
5626
- elif t == Tuple[int, int, int]:
5627
- return "Tuple[int, int, int]"
5628
- elif t == Tuple[int, int, int, int]:
5629
- return "Tuple[int, int, int, int]"
5630
- elif t == Tuple[int, ...]:
5631
- return "Tuple[int, ...]"
5632
5903
  elif isinstance(t, int):
5633
5904
  return str(t)
5634
5905
  elif isinstance(t, List):
@@ -5663,9 +5934,13 @@ def type_str(t):
5663
5934
  return f"Transformation[{type_str(t._wp_scalar_type_)}]"
5664
5935
 
5665
5936
  raise TypeError("Invalid vector or matrix dimensions")
5666
- elif typing.get_origin(t) in (List, Mapping, Sequence, Union, Tuple):
5667
- args_repr = ", ".join(type_str(x) for x in typing.get_args(t))
5668
- return f"{t.__name__}[{args_repr}]"
5937
+ elif warp.codegen.get_type_origin(t) in (list, tuple):
5938
+ args_repr = ", ".join(type_str(x) for x in warp.codegen.get_type_args(t))
5939
+ return f"{t._name}[{args_repr}]"
5940
+ elif t is Ellipsis:
5941
+ return "..."
5942
+ elif warp.types.is_tile(t):
5943
+ return "Tile"
5669
5944
 
5670
5945
  return t.__name__
5671
5946
 
@@ -5826,9 +6101,6 @@ def export_stubs(file): # pragma: no cover
5826
6101
  print('Cols = TypeVar("Cols", bound=int)', file=file)
5827
6102
  print('DType = TypeVar("DType")', file=file)
5828
6103
 
5829
- print('Int = TypeVar("Int")', file=file)
5830
- print('Float = TypeVar("Float")', file=file)
5831
- print('Scalar = TypeVar("Scalar")', file=file)
5832
6104
  print("Vector = Generic[Length, Scalar]", file=file)
5833
6105
  print("Matrix = Generic[Rows, Cols, Scalar]", file=file)
5834
6106
  print("Quaternion = Generic[Float]", file=file)