warp-lang 1.6.2__py3-none-win_amd64.whl → 1.7.1__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (191)
  1. warp/__init__.py +7 -1
  2. warp/autograd.py +12 -2
  3. warp/bin/warp-clang.dll +0 -0
  4. warp/bin/warp.dll +0 -0
  5. warp/build.py +410 -0
  6. warp/build_dll.py +6 -14
  7. warp/builtins.py +463 -372
  8. warp/codegen.py +196 -124
  9. warp/config.py +42 -6
  10. warp/context.py +496 -271
  11. warp/dlpack.py +8 -6
  12. warp/examples/assets/nonuniform.usd +0 -0
  13. warp/examples/assets/nvidia_logo.png +0 -0
  14. warp/examples/benchmarks/benchmark_cloth.py +1 -1
  15. warp/examples/benchmarks/benchmark_tile_load_store.py +103 -0
  16. warp/examples/core/example_sample_mesh.py +300 -0
  17. warp/examples/distributed/example_jacobi_mpi.py +507 -0
  18. warp/examples/fem/example_apic_fluid.py +1 -1
  19. warp/examples/fem/example_burgers.py +2 -2
  20. warp/examples/fem/example_deformed_geometry.py +1 -1
  21. warp/examples/fem/example_distortion_energy.py +1 -1
  22. warp/examples/fem/example_magnetostatics.py +6 -6
  23. warp/examples/fem/utils.py +9 -3
  24. warp/examples/interop/example_jax_callable.py +116 -0
  25. warp/examples/interop/example_jax_ffi_callback.py +132 -0
  26. warp/examples/interop/example_jax_kernel.py +205 -0
  27. warp/examples/optim/example_fluid_checkpoint.py +497 -0
  28. warp/examples/tile/example_tile_matmul.py +2 -4
  29. warp/fem/__init__.py +11 -1
  30. warp/fem/adaptivity.py +4 -4
  31. warp/fem/field/field.py +11 -1
  32. warp/fem/field/nodal_field.py +56 -88
  33. warp/fem/field/virtual.py +62 -23
  34. warp/fem/geometry/adaptive_nanogrid.py +16 -13
  35. warp/fem/geometry/closest_point.py +1 -1
  36. warp/fem/geometry/deformed_geometry.py +5 -2
  37. warp/fem/geometry/geometry.py +5 -0
  38. warp/fem/geometry/grid_2d.py +12 -12
  39. warp/fem/geometry/grid_3d.py +12 -15
  40. warp/fem/geometry/hexmesh.py +5 -7
  41. warp/fem/geometry/nanogrid.py +9 -11
  42. warp/fem/geometry/quadmesh.py +13 -13
  43. warp/fem/geometry/tetmesh.py +3 -4
  44. warp/fem/geometry/trimesh.py +7 -20
  45. warp/fem/integrate.py +262 -93
  46. warp/fem/linalg.py +5 -5
  47. warp/fem/quadrature/pic_quadrature.py +37 -22
  48. warp/fem/quadrature/quadrature.py +194 -25
  49. warp/fem/space/__init__.py +1 -1
  50. warp/fem/space/basis_function_space.py +4 -2
  51. warp/fem/space/basis_space.py +25 -18
  52. warp/fem/space/hexmesh_function_space.py +2 -2
  53. warp/fem/space/partition.py +6 -2
  54. warp/fem/space/quadmesh_function_space.py +8 -8
  55. warp/fem/space/shape/cube_shape_function.py +23 -23
  56. warp/fem/space/shape/square_shape_function.py +12 -12
  57. warp/fem/space/shape/triangle_shape_function.py +1 -1
  58. warp/fem/space/tetmesh_function_space.py +3 -3
  59. warp/fem/space/trimesh_function_space.py +2 -2
  60. warp/fem/utils.py +12 -6
  61. warp/jax.py +14 -1
  62. warp/jax_experimental/__init__.py +16 -0
  63. warp/{jax_experimental.py → jax_experimental/custom_call.py} +28 -29
  64. warp/jax_experimental/ffi.py +702 -0
  65. warp/jax_experimental/xla_ffi.py +602 -0
  66. warp/math.py +89 -0
  67. warp/native/array.h +13 -0
  68. warp/native/builtin.h +29 -3
  69. warp/native/bvh.cpp +3 -1
  70. warp/native/bvh.cu +42 -14
  71. warp/native/bvh.h +2 -1
  72. warp/native/clang/clang.cpp +30 -3
  73. warp/native/cuda_util.cpp +14 -0
  74. warp/native/cuda_util.h +2 -0
  75. warp/native/exports.h +68 -63
  76. warp/native/intersect.h +26 -26
  77. warp/native/intersect_adj.h +33 -33
  78. warp/native/marching.cu +1 -1
  79. warp/native/mat.h +513 -9
  80. warp/native/mesh.h +10 -10
  81. warp/native/quat.h +99 -11
  82. warp/native/rand.h +6 -0
  83. warp/native/sort.cpp +122 -59
  84. warp/native/sort.cu +152 -15
  85. warp/native/sort.h +8 -1
  86. warp/native/sparse.cpp +43 -22
  87. warp/native/sparse.cu +52 -17
  88. warp/native/svd.h +116 -0
  89. warp/native/tile.h +312 -116
  90. warp/native/tile_reduce.h +46 -3
  91. warp/native/vec.h +68 -7
  92. warp/native/volume.cpp +85 -113
  93. warp/native/volume_builder.cu +25 -10
  94. warp/native/volume_builder.h +6 -0
  95. warp/native/warp.cpp +5 -6
  96. warp/native/warp.cu +100 -11
  97. warp/native/warp.h +19 -10
  98. warp/optim/linear.py +10 -10
  99. warp/render/render_opengl.py +19 -17
  100. warp/render/render_usd.py +93 -3
  101. warp/sim/articulation.py +4 -4
  102. warp/sim/collide.py +32 -19
  103. warp/sim/import_mjcf.py +449 -155
  104. warp/sim/import_urdf.py +32 -12
  105. warp/sim/inertia.py +189 -156
  106. warp/sim/integrator_euler.py +8 -5
  107. warp/sim/integrator_featherstone.py +3 -10
  108. warp/sim/integrator_vbd.py +207 -2
  109. warp/sim/integrator_xpbd.py +8 -5
  110. warp/sim/model.py +71 -25
  111. warp/sim/render.py +4 -0
  112. warp/sim/utils.py +2 -2
  113. warp/sparse.py +642 -555
  114. warp/stubs.py +217 -20
  115. warp/tests/__main__.py +0 -15
  116. warp/tests/assets/torus.usda +1 -1
  117. warp/tests/cuda/__init__.py +0 -0
  118. warp/tests/{test_mempool.py → cuda/test_mempool.py} +39 -0
  119. warp/tests/{test_streams.py → cuda/test_streams.py} +71 -0
  120. warp/tests/geometry/__init__.py +0 -0
  121. warp/tests/{test_mesh_query_point.py → geometry/test_mesh_query_point.py} +66 -63
  122. warp/tests/{test_mesh_query_ray.py → geometry/test_mesh_query_ray.py} +1 -1
  123. warp/tests/{test_volume.py → geometry/test_volume.py} +41 -6
  124. warp/tests/interop/__init__.py +0 -0
  125. warp/tests/{test_dlpack.py → interop/test_dlpack.py} +28 -5
  126. warp/tests/sim/__init__.py +0 -0
  127. warp/tests/{disabled_kinematics.py → sim/disabled_kinematics.py} +9 -10
  128. warp/tests/{test_collision.py → sim/test_collision.py} +236 -205
  129. warp/tests/sim/test_inertia.py +161 -0
  130. warp/tests/{test_model.py → sim/test_model.py} +40 -0
  131. warp/tests/{flaky_test_sim_grad.py → sim/test_sim_grad.py} +4 -0
  132. warp/tests/{test_sim_kinematics.py → sim/test_sim_kinematics.py} +2 -1
  133. warp/tests/sim/test_vbd.py +597 -0
  134. warp/tests/sim/test_xpbd.py +399 -0
  135. warp/tests/test_bool.py +1 -1
  136. warp/tests/test_codegen.py +24 -3
  137. warp/tests/test_examples.py +40 -38
  138. warp/tests/test_fem.py +98 -14
  139. warp/tests/test_linear_solvers.py +0 -11
  140. warp/tests/test_mat.py +577 -156
  141. warp/tests/test_mat_scalar_ops.py +4 -4
  142. warp/tests/test_overwrite.py +0 -60
  143. warp/tests/test_quat.py +356 -151
  144. warp/tests/test_rand.py +44 -37
  145. warp/tests/test_sparse.py +47 -6
  146. warp/tests/test_spatial.py +75 -0
  147. warp/tests/test_static.py +1 -1
  148. warp/tests/test_utils.py +84 -4
  149. warp/tests/test_vec.py +336 -178
  150. warp/tests/tile/__init__.py +0 -0
  151. warp/tests/{test_tile.py → tile/test_tile.py} +136 -51
  152. warp/tests/{test_tile_load.py → tile/test_tile_load.py} +98 -1
  153. warp/tests/{test_tile_mathdx.py → tile/test_tile_mathdx.py} +9 -6
  154. warp/tests/{test_tile_mlp.py → tile/test_tile_mlp.py} +25 -14
  155. warp/tests/{test_tile_reduce.py → tile/test_tile_reduce.py} +60 -1
  156. warp/tests/{test_tile_view.py → tile/test_tile_view.py} +1 -1
  157. warp/tests/unittest_serial.py +1 -0
  158. warp/tests/unittest_suites.py +45 -62
  159. warp/tests/unittest_utils.py +2 -1
  160. warp/thirdparty/unittest_parallel.py +3 -1
  161. warp/types.py +175 -666
  162. warp/utils.py +137 -72
  163. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/METADATA +46 -12
  164. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/RECORD +184 -171
  165. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/WHEEL +1 -1
  166. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info/licenses}/LICENSE.md +0 -26
  167. warp/examples/optim/example_walker.py +0 -317
  168. warp/native/cutlass_gemm.cpp +0 -43
  169. warp/native/cutlass_gemm.cu +0 -382
  170. warp/tests/test_matmul.py +0 -511
  171. warp/tests/test_matmul_lite.py +0 -411
  172. warp/tests/test_vbd.py +0 -386
  173. warp/tests/unused_test_misc.py +0 -77
  174. /warp/tests/{test_async.py → cuda/test_async.py} +0 -0
  175. /warp/tests/{test_ipc.py → cuda/test_ipc.py} +0 -0
  176. /warp/tests/{test_multigpu.py → cuda/test_multigpu.py} +0 -0
  177. /warp/tests/{test_peer.py → cuda/test_peer.py} +0 -0
  178. /warp/tests/{test_pinned.py → cuda/test_pinned.py} +0 -0
  179. /warp/tests/{test_bvh.py → geometry/test_bvh.py} +0 -0
  180. /warp/tests/{test_hash_grid.py → geometry/test_hash_grid.py} +0 -0
  181. /warp/tests/{test_marching_cubes.py → geometry/test_marching_cubes.py} +0 -0
  182. /warp/tests/{test_mesh.py → geometry/test_mesh.py} +0 -0
  183. /warp/tests/{test_mesh_query_aabb.py → geometry/test_mesh_query_aabb.py} +0 -0
  184. /warp/tests/{test_volume_write.py → geometry/test_volume_write.py} +0 -0
  185. /warp/tests/{test_jax.py → interop/test_jax.py} +0 -0
  186. /warp/tests/{test_paddle.py → interop/test_paddle.py} +0 -0
  187. /warp/tests/{test_torch.py → interop/test_torch.py} +0 -0
  188. /warp/tests/{test_coloring.py → sim/test_coloring.py} +0 -0
  189. /warp/tests/{test_sim_grad_bounce_linear.py → sim/test_sim_grad_bounce_linear.py} +0 -0
  190. /warp/tests/{test_tile_shared_memory.py → tile/test_tile_shared_memory.py} +0 -0
  191. {warp_lang-1.6.2.dist-info → warp_lang-1.7.1.dist-info}/top_level.txt +0 -0
warp/__init__.py CHANGED
@@ -84,7 +84,12 @@ from warp.context import Stream, get_stream, set_stream, wait_stream, synchroniz
 from warp.context import Event, record_event, wait_event, synchronize_event, get_event_elapsed_time
 from warp.context import RegisteredGLBuffer
 from warp.context import is_mempool_supported, is_mempool_enabled, set_mempool_enabled
-from warp.context import set_mempool_release_threshold, get_mempool_release_threshold
+from warp.context import (
+    set_mempool_release_threshold,
+    get_mempool_release_threshold,
+    get_mempool_used_mem_current,
+    get_mempool_used_mem_high,
+)
 from warp.context import is_mempool_access_supported, is_mempool_access_enabled, set_mempool_access_enabled
 from warp.context import is_peer_access_supported, is_peer_access_enabled, set_peer_access_enabled

@@ -120,6 +125,7 @@ from warp.paddle import device_from_paddle, device_to_paddle
 from warp.paddle import stream_from_paddle

 from warp.build import clear_kernel_cache
+from warp.build import clear_lto_cache

 from warp.constants import *

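The new exports add memory-pool usage introspection and a separate LTO cache reset. A minimal usage sketch follows, assuming the device-string and byte-unit conventions of the existing mempool API; the exact argument forms are not confirmed by this diff.

import warp as wp

wp.init()

if wp.is_mempool_supported("cuda:0"):
    # assumed: threshold and usage counters are expressed in bytes
    wp.set_mempool_release_threshold("cuda:0", 1024 * 1024 * 1024)
    print(wp.get_mempool_used_mem_current("cuda:0"))  # assumed: current pool usage
    print(wp.get_mempool_used_mem_high("cuda:0"))     # assumed: high-water mark

# new in 1.7: clear cached LTO artifacts independently of the kernel cache
wp.clear_lto_cache()
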
warp/autograd.py CHANGED
@@ -52,7 +52,12 @@ def gradcheck(
 ) -> bool:
     """
     Checks whether the autodiff gradient of a Warp kernel matches finite differences.
-    Fails if the relative or absolute errors between the autodiff and finite difference gradients exceed the specified tolerance, or if the autodiff gradients contain NaN values.
+
+    Given the autodiff (:math:`\\nabla_\\text{AD}`) and finite difference gradients (:math:`\\nabla_\\text{FD}`), the check succeeds if the autodiff gradients contain no NaN values and the following condition holds:
+
+    .. math::
+
+        |\\nabla_\\text{AD} - \\nabla_\\text{FD}| \\leq atol + rtol \\cdot |\\nabla_\\text{FD}|.

     The kernel function and its adjoint version are launched with the given inputs and outputs, as well as the provided
     ``dim``, ``max_blocks``, and ``block_dim`` arguments (see :func:`warp.launch` for more details).
@@ -250,7 +255,12 @@ def gradcheck_tape(
 ) -> bool:
     """
     Checks whether the autodiff gradients for kernels recorded on the Warp tape match finite differences.
-    Fails if the relative or absolute errors between the autodiff and finite difference gradients exceed the specified tolerance, or if the autodiff gradients contain NaN values.
+
+    Given the autodiff (:math:`\\nabla_\\text{AD}`) and finite difference gradients (:math:`\\nabla_\\text{FD}`), the check succeeds if the autodiff gradients contain no NaN values and the following condition holds:
+
+    .. math::
+
+        |\\nabla_\\text{AD} - \\nabla_\\text{FD}| \\leq atol + rtol \\cdot |\\nabla_\\text{FD}|.

     Note:
         Only Warp kernels recorded on the tape are checked but not arbitrary functions that have been recorded, e.g. via :meth:`Tape.record_func`.
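
The reworked docstrings describe an element-wise tolerance test between autodiff and finite-difference gradients. Below is a minimal sketch of invoking the check on a trivial kernel; the keyword names (dim, inputs, outputs, rtol, atol) follow the docstring excerpts above and wp.launch conventions, and are assumptions rather than a verified 1.7.1 signature.

import warp as wp
import warp.autograd


@wp.kernel
def square(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    i = wp.tid()
    y[i] = x[i] * x[i]


x = wp.array([1.0, 2.0, 3.0], dtype=float, requires_grad=True)
y = wp.zeros(3, dtype=float, requires_grad=True)

# passes when |grad_AD - grad_FD| <= atol + rtol * |grad_FD| holds element-wise
ok = wp.autograd.gradcheck(square, dim=3, inputs=[x], outputs=[y], rtol=1e-2, atol=1e-4)
print(ok)
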
warp/bin/warp-clang.dll CHANGED
Binary file
warp/bin/warp.dll CHANGED
Binary file
warp/build.py CHANGED
@@ -14,10 +14,16 @@
 # limitations under the License.

 import ctypes
+import errno
+import hashlib
+import json
 import os
+import time
+from pathlib import Path

 import warp.config
 from warp.thirdparty import appdirs
+from warp.types import *

 # From nvJitLink.h
 nvJitLink_input_type = {"cubin": 1, "ptx": 2, "ltoir": 3, "fatbin": 4, "object": 5, "library": 6}
@@ -131,6 +137,7 @@ def clear_kernel_cache() -> None:

     Only directories beginning with ``wp_`` will be deleted.
     This function only clears the cache for the current Warp version.
+    LTO artifacts are not affected.
     """

     warp.context.init()
@@ -145,3 +152,406 @@ def clear_kernel_cache() -> None:
         if os.path.isdir(item_path) and item.startswith("wp_"):
             # Remove the directory and its contents
             shutil.rmtree(item_path, ignore_errors=True)
+
+
+def clear_lto_cache() -> None:
+    """Clear the LTO cache directory of previously generated LTO code.
+
+    The LTO cache is stored within a subdirectory of the kernel cache directory.
+    This function only clears the cache for the current Warp version.
+    """
+
+    warp.context.init()
+
+    import shutil
+
+    is_intialized = warp.context.runtime is not None
+    assert is_intialized, "The kernel cache directory is not configured; wp.init() has not been called yet or failed."
+
+    lto_path = os.path.join(warp.config.kernel_cache_dir, "lto")
+    if os.path.isdir(lto_path):
+        # Remove the lto directory and its contents
+        shutil.rmtree(lto_path, ignore_errors=True)
+
+
+def safe_rename(src, dst, attempts=5, delay=0.1):
+    for i in range(attempts):
+        try:
+            os.rename(src, dst)
+            return
+        except FileExistsError:
+            return
+        except OSError as e:
+            if e.errno == errno.ENOTEMPTY:
+                # if directory exists we assume another process
+                # got there first, in which case we will copy
+                # our output to the directory manually in second step
+                return
+            else:
+                # otherwise assume directory creation failed e.g.: access denied
+                # on Windows we see occasional failures to rename directories due to
+                # some process holding a lock on a file to be moved to workaround
+                # this we make multiple attempts to rename with some delay
+                if i < attempts - 1:
+                    time.sleep(delay)
+                else:
+                    print(
+                        f"Could not update Warp cache with compiled binaries, trying to rename {src} to {dst}, error {e}"
+                    )
+                    raise e
+
+
+def hash_symbol(symbol):
+    ch = hashlib.sha256()
+    ch.update(symbol.encode("utf-8"))
+    return ch.hexdigest()
+
+
+def get_lto_cache_dir():
+    lto_dir = os.path.join(warp.config.kernel_cache_dir, "lto")
+    return lto_dir
+
+
+def get_cached_lto(path):
+    if os.path.exists(path):
+        with open(path, "rb") as f:
+            lto_code_data = f.read()
+        return lto_code_data
+    else:
+        return None
+
+
+def get_cached_lto_meta(path, symbol):
+    if os.path.exists(path):
+        with open(path, "r") as f:
+            keys = json.load(f)
+            value = keys[symbol]
+        return value
+    else:
+        return None
+
+
+def build_lto_dot(M, N, K, adtype, bdtype, cdtype, alayout, blayout, clayout, arch, num_threads, builder):
+    # TODO: MathDx doesn't yet have heuristics for Blackwell
+    arch = min(arch, 90)
+
+    # Maps Python/Warp types to C++ types and enums
+    def cublasdx_type_map(dtype):
+        if dtype == float16:
+            return ("wp::float16", 3, 0)
+        if dtype == float32:
+            return ("wp::float32", 5, 0)
+        if dtype == float64:
+            return ("wp::float64", 6, 0)
+        if dtype == vec2h:
+            return ("wp::vec2h", 3, 1)
+        if dtype == vec2f:
+            return ("wp::vec2f", 5, 1)
+        if dtype == vec2d:
+            return ("wp::vec2d", 6, 1)
+        raise TypeError("Unsupported input type in tile_matmul")
+
+    def cublasdx_arrangement_map(layout):
+        if layout == "colmajor":
+            return 0  # CUBLASDX_ARRANGEMENT_COL_MAJOR
+        if layout == "rowmajor":
+            return 1  # CUBLASDX_ARRANGEMENT_ROW_MAJOR
+        raise ValueError("Unsupported layout in tile_matmul")
+
+    (a_dtype, a_prec, a_type) = cublasdx_type_map(adtype)
+    (b_dtype, b_prec, b_type) = cublasdx_type_map(bdtype)
+    (c_dtype, c_prec, c_type) = cublasdx_type_map(cdtype)
+    a_arrangement = cublasdx_arrangement_map(alayout)
+    b_arrangement = cublasdx_arrangement_map(blayout)
+    c_arrangement = cublasdx_arrangement_map(clayout)
+
+    if a_type != b_type or a_type != c_type:
+        raise TypeError("time_matmul(A, B, C) requires all inputs to be real or complex")
+
+    element_type = a_type
+
+    lto_symbol = f"dot_{M}_{N}_{K}_{arch}_{num_threads}_{a_arrangement}_{b_arrangement}_{c_arrangement}_{a_prec}_{b_prec}_{c_prec}_{element_type}"
+
+    # early out if LTO for this symbol is already cached in current module
+    if lto_symbol in builder.ltoirs:
+        return lto_symbol, builder.ltoirs[lto_symbol]
+
+    # hash symbol and determine output path
+    h = hash_symbol(lto_symbol)
+
+    lto_dir = get_lto_cache_dir()
+    lto_name = f"{h[:7]}.lto"
+    lto_path = os.path.join(lto_dir, lto_name)
+
+    # early out if LTO for this symbol is already built but not cached in current module
+    lto_code_data = get_cached_lto(lto_path)
+
+    if lto_code_data is not None:
+        builder.ltoirs[lto_symbol] = lto_code_data
+        builder.ltoirs_decl[lto_symbol] = (
+            f"void {lto_symbol}({c_dtype}, {a_dtype}*, {b_dtype}*, {c_dtype}, {c_dtype}*);"
+        )
+
+        return lto_symbol, lto_code_data
+
+    # create a temporary (process unique) dir for build outputs before moving to the binary dir
+    build_dir = f"{lto_dir}_p{os.getpid()}"
+
+    # dir may exist from previous attempts / runs / archs
+    Path(build_dir).mkdir(parents=True, exist_ok=True)
+
+    # temporary path to compile to in build_dir
+    temp_lto_path = os.path.join(build_dir, lto_name)
+
+    # compile LTO
+    result = warp.context.runtime.core.cuda_compile_dot(
+        temp_lto_path.encode("utf-8"),
+        lto_symbol.encode("utf-8"),
+        0,
+        None,
+        None,
+        arch,
+        M,
+        N,
+        K,
+        a_prec,
+        b_prec,
+        c_prec,
+        element_type,
+        a_arrangement,
+        b_arrangement,
+        c_arrangement,
+        num_threads,
+    )
+
+    if not result:
+        if Path(temp_lto_path).exists():
+            Path(temp_lto_path).unlink()
+        raise RuntimeError("Failed to compile tile_matmul")
+    else:
+        with open(temp_lto_path, "rb") as f:
+            lto_code_data = f.read()
+
+        builder.ltoirs[lto_symbol] = lto_code_data
+        builder.ltoirs_decl[lto_symbol] = f"void {lto_symbol}({c_dtype}, {a_dtype}*, {b_dtype}*, {c_dtype}, {c_dtype}*);"
+
+    # try to move process outputs to cache
+    safe_rename(build_dir, lto_dir)
+
+    if os.path.exists(lto_dir):
+        if not os.path.exists(lto_path):
+            # copy output file to the destination lto dir
+            try:
+                os.rename(temp_lto_path, lto_path)
+            except (OSError, FileExistsError):
+                # another process likely updated the lto dir first
+                pass
+
+    if build_dir:
+        import shutil
+
+        # clean up build_dir used for this process
+        shutil.rmtree(build_dir, ignore_errors=True)
+
+    return lto_symbol, lto_code_data
+
+
+def build_lto_solver(M, N, solver, solver_enum, fill_mode, arch, precision_enum, num_threads, parameter_list, builder):
+    # TODO: MathDx doesn't yet have heuristics for Blackwell
+    arch = min(arch, 90)
+
+    lto_symbol = f"{solver}_{M}_{N}_{arch}_{num_threads}_{precision_enum}_{fill_mode}"
+    ltoir_decl = f"void {lto_symbol}{parameter_list};"
+
+    # early out if LTO for this symbol is already cached in current module
+    if lto_symbol in builder.ltoirs:
+        return lto_symbol, builder.ltoirs[lto_symbol]
+
+    # hash symbol and determine output path
+    h = hash_symbol(lto_symbol)
+
+    lto_dir = get_lto_cache_dir()
+    lto_name = f"{h[:7]}.lto"
+    lto_path = os.path.join(lto_dir, lto_name)
+
+    # we also cache a universal fatbin binary for this symbol
+    universal_fatbin_name = f"{h[:7]}_fatbin.lto"
+    universal_fatbin_path = os.path.join(lto_dir, universal_fatbin_name)
+
+    lto_code_data = get_cached_lto(lto_path)
+    universal_fatbin_code_data = get_cached_lto(universal_fatbin_path)
+
+    # early out if LTO for this symbol is already built but not cached in current module
+    if lto_code_data is not None and universal_fatbin_code_data is not None:
+        builder.ltoirs[lto_symbol] = lto_code_data
+        builder.ltoirs_decl[lto_symbol] = ltoir_decl
+        builder.fatbins[lto_symbol] = universal_fatbin_code_data
+
+        return lto_symbol, lto_code_data
+
+    # create a temporary (process unique) dir for build outputs before moving to the binary dir
+    build_dir = f"{lto_dir}_p{os.getpid()}"
+
+    # dir may exist from previous attempts / runs / archs
+    Path(build_dir).mkdir(parents=True, exist_ok=True)
+
+    # temporary paths to compile to in build_dir
+    temp_lto_path = os.path.join(build_dir, lto_name)
+    temp_universal_fatbin_path = os.path.join(build_dir, universal_fatbin_name)
+
+    # compile LTO
+    result = warp.context.runtime.core.cuda_compile_solver(
+        temp_universal_fatbin_path.encode("utf-8"),
+        temp_lto_path.encode("utf-8"),
+        lto_symbol.encode("utf-8"),
+        0,
+        None,
+        None,
+        arch,
+        M,
+        N,
+        solver_enum,
+        precision_enum,
+        fill_mode,
+        num_threads,
+    )
+
+    if not result:
+        for path in [temp_universal_fatbin_path, temp_lto_path]:
+            if Path(path).exists():
+                Path(path).unlink()
+        raise RuntimeError("Failed to compile tile_cholesky")
+
+    else:
+        with open(temp_lto_path, "rb") as f:
+            lto_code_data = f.read()
+        with open(temp_universal_fatbin_path, "rb") as f:
+            universal_fatbin_code_data = f.read()
+
+        builder.ltoirs[lto_symbol] = lto_code_data
+        builder.ltoirs_decl[lto_symbol] = ltoir_decl
+        builder.fatbins[lto_symbol] = universal_fatbin_code_data
+
+    # try to move process outputs to lto cache
+    safe_rename(build_dir, lto_dir)
+
+    if os.path.exists(lto_dir):
+        for p in [(lto_path, temp_lto_path), (universal_fatbin_path, temp_universal_fatbin_path)]:
+            path, temp_path = p
+            if not os.path.exists(path):
+                # copy output file to the destination lto dir
+                try:
+                    os.rename(temp_path, path)
+                except (OSError, FileExistsError):
+                    # another process likely updated the lto dir first
+                    pass
+
+    if build_dir:
+        import shutil
+
+        # clean up build_dir used for this process
+        shutil.rmtree(build_dir, ignore_errors=True)
+
+    return lto_symbol, lto_code_data
+
+
+def build_lto_fft(arch, size, ept, direction, dir, precision, builder):
+    # TODO: MathDx doesn't yet have heuristics for Blackwell
+    arch = min(arch, 90)
+
+    lto_symbol = f"fft_{size}_{ept}_{arch}_{direction}_{precision}"
+
+    # early out if LTO for this symbol is already cached in current module
+    if lto_symbol in builder.ltoirs:
+        return lto_symbol, builder.ltoirs[lto_symbol], builder.shared_memory_bytes[lto_symbol]
+
+    # hash symbol and determine output path
+    h = hash_symbol(lto_symbol)
+
+    lto_dir = get_lto_cache_dir()
+    lto_name = f"{h[:7]}.lto"
+    lto_path = os.path.join(lto_dir, lto_name)
+
+    # we also cache shared memory requirements for this kernel in a .meta file
+    meta_name = f"{h[:7]}.meta"
+    meta_path = os.path.join(lto_dir, meta_name)
+
+    # early out if LTO for this symbol is already built but not cached in current module
+    lto_code_data = get_cached_lto(lto_path)
+    shared_memory_bytes = get_cached_lto_meta(meta_path, lto_symbol)
+
+    if lto_code_data is not None and shared_memory_bytes is not None:
+        builder.ltoirs[lto_symbol] = lto_code_data
+        builder.shared_memory_bytes[lto_symbol] = shared_memory_bytes
+
+        return lto_symbol, lto_code_data, shared_memory_bytes
+
+    # create a temporary (process unique) dir for build outputs before moving to the binary dir
+    build_dir = f"{lto_dir}_p{os.getpid()}"
+
+    # dir may exist from previous attempts / runs / archs
+    Path(build_dir).mkdir(parents=True, exist_ok=True)
+
+    # temporary paths to compile to in build_dir
+    temp_lto_path = os.path.join(build_dir, lto_name)
+    temp_meta_path = os.path.join(build_dir, meta_name)
+
+    # compile LTO
+    shared_memory_size = ctypes.c_int(0)
+
+    result = warp.context.runtime.core.cuda_compile_fft(
+        temp_lto_path.encode("utf-8"),
+        lto_symbol.encode("utf-8"),
+        0,
+        None,
+        None,
+        arch,
+        size,
+        ept,
+        dir,
+        precision,
+        ctypes.byref(shared_memory_size),
+    )
+
+    shared_memory_bytes = Tile.round_up(shared_memory_size.value)
+
+    if not result:
+        if Path(temp_lto_path).exists():
+            Path(temp_lto_path).unlink()
+        raise RuntimeError("Failed to compile tile_fft")
+
+    else:
+        with open(temp_lto_path, "rb") as f:
+            lto_code_data = f.read()
+
+        # output meta file with shared memory requirements for this lto_symbol
+        meta = {}
+        meta[lto_symbol] = shared_memory_bytes
+
+        with open(temp_meta_path, "w") as meta_file:
+            json.dump(meta, meta_file)
+
+        builder.ltoirs[lto_symbol] = lto_code_data
+        builder.shared_memory_bytes[lto_symbol] = shared_memory_bytes
+
+    # try to move process outputs to cache
+    safe_rename(build_dir, lto_dir)
+
+    if os.path.exists(lto_dir):
+        for p in [(lto_path, temp_lto_path), (meta_path, temp_meta_path)]:
+            path, temp_path = p
+            if not os.path.exists(path):
+                # copy output file to the destination lto dir
+                try:
+                    os.rename(temp_path, path)
+                except (OSError, FileExistsError):
+                    # another process likely updated the lto dir first
+                    pass
+
+    if build_dir:
+        import shutil
+
+        # clean up build_dir used for this process
+        shutil.rmtree(build_dir, ignore_errors=True)
+
+    return lto_symbol, lto_code_data, shared_memory_bytes
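
The functions added above share one disk-cache scheme: an LTO symbol encoding the problem configuration is hashed with SHA-256, and the first seven hex digits name the cached artifact under the "lto" subdirectory of the kernel cache. A small standalone sketch of that key derivation, with an illustrative path and example symbol (not taken from a real cache):

import hashlib
import os


def lto_artifact_path(kernel_cache_dir: str, lto_symbol: str) -> str:
    # mirrors hash_symbol(), get_lto_cache_dir(), and the f"{h[:7]}.lto" naming above
    h = hashlib.sha256(lto_symbol.encode("utf-8")).hexdigest()
    return os.path.join(kernel_cache_dir, "lto", f"{h[:7]}.lto")


# e.g. a tile_matmul symbol as build_lto_dot() would construct it
symbol = "dot_16_16_16_90_128_0_0_0_5_5_5_0"
print(lto_artifact_path("/path/to/warp/kernel_cache", symbol))
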
warp/build_dll.py CHANGED
@@ -147,14 +147,6 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, arch, mode=None
     cuda_home = args.cuda_path
     cuda_cmd = None

-    if args.quick:
-        cutlass_includes = ""
-        cutlass_enabled = "WP_ENABLE_CUTLASS=0"
-    else:
-        cutlass_home = "warp/native/cutlass"
-        cutlass_includes = f'-I"{cutlass_home}/include" -I"{cutlass_home}/tools/util/include"'
-        cutlass_enabled = "WP_ENABLE_CUTLASS=1"
-
     if args.quick or cu_path is None:
         cuda_compat_enabled = "WP_ENABLE_CUDA_COMPATIBILITY=0"
     else:
@@ -270,7 +262,7 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, arch, mode=None
            iter_dbg = "_ITERATOR_DEBUG_LEVEL=2"
            debug = "_DEBUG"

-        cpp_flags = f'/nologo /std:c++17 /GR- {runtime} /D "{debug}" /D "{cuda_enabled}" /D "{cutlass_enabled}" /D "{mathdx_enabled}" /D "{cuda_compat_enabled}" /D "{iter_dbg}" /I"{native_dir}" {includes} '
+        cpp_flags = f'/nologo /std:c++17 /GR- {runtime} /D "{debug}" /D "{cuda_enabled}" /D "{mathdx_enabled}" /D "{cuda_compat_enabled}" /D "{iter_dbg}" /I"{native_dir}" {includes} '

         if args.mode == "debug":
            cpp_flags += "/Zi /Od /D WP_ENABLE_DEBUG=1"
@@ -299,10 +291,10 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, arch, mode=None
            cu_out = cu_path + ".o"

            if mode == "debug":
-                cuda_cmd = f'"{cuda_home}/bin/nvcc" --std=c++17 --compiler-options=/MT,/Zi,/Od -g -G -O0 -DNDEBUG -D_ITERATOR_DEBUG_LEVEL=0 -I"{native_dir}" -line-info {" ".join(nvcc_opts)} -DWP_ENABLE_CUDA=1 -D{cutlass_enabled} {cutlass_includes} -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'
+                cuda_cmd = f'"{cuda_home}/bin/nvcc" --std=c++17 --compiler-options=/MT,/Zi,/Od -g -G -O0 -DNDEBUG -D_ITERATOR_DEBUG_LEVEL=0 -I"{native_dir}" -line-info {" ".join(nvcc_opts)} -DWP_ENABLE_CUDA=1 -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'

            elif mode == "release":
-                cuda_cmd = f'"{cuda_home}/bin/nvcc" --std=c++17 -O3 {" ".join(nvcc_opts)} -I"{native_dir}" -DNDEBUG -DWP_ENABLE_CUDA=1 -D{cutlass_enabled} {cutlass_includes} -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'
+                cuda_cmd = f'"{cuda_home}/bin/nvcc" --std=c++17 -O3 {" ".join(nvcc_opts)} -I"{native_dir}" -DNDEBUG -DWP_ENABLE_CUDA=1 -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'

            with ScopedTimer("build_cuda", active=args.verbose):
                run_cmd(cuda_cmd)
@@ -329,7 +321,7 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, arch, mode=None
        else:
            version = "-fabi-version=13"  # GCC 8.2+

-        cpp_flags = f'{version} --std=c++17 -fno-rtti -D{cuda_enabled} -D{cutlass_enabled} -D{mathdx_enabled} -D{cuda_compat_enabled} -fPIC -fvisibility=hidden -D_GLIBCXX_USE_CXX11_ABI=0 -I"{native_dir}" {includes} '
+        cpp_flags = f'{version} --std=c++17 -fno-rtti -D{cuda_enabled} -D{mathdx_enabled} -D{cuda_compat_enabled} -fPIC -fvisibility=hidden -D_GLIBCXX_USE_CXX11_ABI=0 -I"{native_dir}" {includes} '

        if mode == "debug":
            cpp_flags += "-O0 -g -D_DEBUG -DWP_ENABLE_DEBUG=1 -fkeep-inline-functions"
@@ -357,10 +349,10 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, arch, mode=None
            cu_out = cu_path + ".o"

            if mode == "debug":
-                cuda_cmd = f'"{cuda_home}/bin/nvcc" --std=c++17 -g -G -O0 --compiler-options -fPIC,-fvisibility=hidden -D_DEBUG -D_ITERATOR_DEBUG_LEVEL=0 -line-info {" ".join(nvcc_opts)} -DWP_ENABLE_CUDA=1 -I"{native_dir}" -D{cutlass_enabled} {cutlass_includes} -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'
+                cuda_cmd = f'"{cuda_home}/bin/nvcc" --std=c++17 -g -G -O0 --compiler-options -fPIC,-fvisibility=hidden -D_DEBUG -D_ITERATOR_DEBUG_LEVEL=0 -line-info {" ".join(nvcc_opts)} -DWP_ENABLE_CUDA=1 -I"{native_dir}" -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'

            elif mode == "release":
-                cuda_cmd = f'"{cuda_home}/bin/nvcc" --std=c++17 -O3 --compiler-options -fPIC,-fvisibility=hidden {" ".join(nvcc_opts)} -DNDEBUG -DWP_ENABLE_CUDA=1 -I"{native_dir}" -D{cutlass_enabled} {cutlass_includes} -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'
+                cuda_cmd = f'"{cuda_home}/bin/nvcc" --std=c++17 -O3 --compiler-options -fPIC,-fvisibility=hidden {" ".join(nvcc_opts)} -DNDEBUG -DWP_ENABLE_CUDA=1 -I"{native_dir}" -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'

            with ScopedTimer("build_cuda", active=args.verbose):
                run_cmd(cuda_cmd)