warp-lang 1.7.2rc1__py3-none-manylinux_2_34_aarch64.whl → 1.8.1__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (193)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/warp-clang.so +0 -0
  5. warp/bin/warp.so +0 -0
  6. warp/build.py +241 -252
  7. warp/build_dll.py +130 -26
  8. warp/builtins.py +1907 -384
  9. warp/codegen.py +272 -104
  10. warp/config.py +12 -1
  11. warp/constants.py +1 -1
  12. warp/context.py +770 -238
  13. warp/dlpack.py +1 -1
  14. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  15. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  16. warp/examples/core/example_sample_mesh.py +1 -1
  17. warp/examples/core/example_spin_lock.py +93 -0
  18. warp/examples/core/example_work_queue.py +118 -0
  19. warp/examples/fem/example_adaptive_grid.py +5 -5
  20. warp/examples/fem/example_apic_fluid.py +1 -1
  21. warp/examples/fem/example_burgers.py +1 -1
  22. warp/examples/fem/example_convection_diffusion.py +9 -6
  23. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  24. warp/examples/fem/example_deformed_geometry.py +1 -1
  25. warp/examples/fem/example_diffusion.py +2 -2
  26. warp/examples/fem/example_diffusion_3d.py +1 -1
  27. warp/examples/fem/example_distortion_energy.py +1 -1
  28. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  29. warp/examples/fem/example_magnetostatics.py +5 -3
  30. warp/examples/fem/example_mixed_elasticity.py +5 -3
  31. warp/examples/fem/example_navier_stokes.py +11 -9
  32. warp/examples/fem/example_nonconforming_contact.py +5 -3
  33. warp/examples/fem/example_streamlines.py +8 -3
  34. warp/examples/fem/utils.py +9 -8
  35. warp/examples/interop/example_jax_callable.py +34 -4
  36. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  37. warp/examples/interop/example_jax_kernel.py +27 -1
  38. warp/examples/optim/example_drone.py +1 -1
  39. warp/examples/sim/example_cloth.py +1 -1
  40. warp/examples/sim/example_cloth_self_contact.py +48 -54
  41. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  42. warp/examples/tile/example_tile_cholesky.py +2 -1
  43. warp/examples/tile/example_tile_convolution.py +1 -1
  44. warp/examples/tile/example_tile_filtering.py +1 -1
  45. warp/examples/tile/example_tile_matmul.py +1 -1
  46. warp/examples/tile/example_tile_mlp.py +2 -0
  47. warp/fabric.py +7 -7
  48. warp/fem/__init__.py +5 -0
  49. warp/fem/adaptivity.py +1 -1
  50. warp/fem/cache.py +152 -63
  51. warp/fem/dirichlet.py +2 -2
  52. warp/fem/domain.py +136 -6
  53. warp/fem/field/field.py +141 -99
  54. warp/fem/field/nodal_field.py +85 -39
  55. warp/fem/field/virtual.py +99 -52
  56. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  57. warp/fem/geometry/closest_point.py +13 -0
  58. warp/fem/geometry/deformed_geometry.py +102 -40
  59. warp/fem/geometry/element.py +56 -2
  60. warp/fem/geometry/geometry.py +323 -22
  61. warp/fem/geometry/grid_2d.py +157 -62
  62. warp/fem/geometry/grid_3d.py +116 -20
  63. warp/fem/geometry/hexmesh.py +86 -20
  64. warp/fem/geometry/nanogrid.py +166 -86
  65. warp/fem/geometry/partition.py +59 -25
  66. warp/fem/geometry/quadmesh.py +86 -135
  67. warp/fem/geometry/tetmesh.py +47 -119
  68. warp/fem/geometry/trimesh.py +77 -270
  69. warp/fem/integrate.py +181 -95
  70. warp/fem/linalg.py +25 -58
  71. warp/fem/operator.py +124 -27
  72. warp/fem/quadrature/pic_quadrature.py +36 -14
  73. warp/fem/quadrature/quadrature.py +40 -16
  74. warp/fem/space/__init__.py +1 -1
  75. warp/fem/space/basis_function_space.py +66 -46
  76. warp/fem/space/basis_space.py +17 -4
  77. warp/fem/space/dof_mapper.py +1 -1
  78. warp/fem/space/function_space.py +2 -2
  79. warp/fem/space/grid_2d_function_space.py +4 -1
  80. warp/fem/space/hexmesh_function_space.py +4 -2
  81. warp/fem/space/nanogrid_function_space.py +3 -1
  82. warp/fem/space/partition.py +11 -2
  83. warp/fem/space/quadmesh_function_space.py +4 -1
  84. warp/fem/space/restriction.py +5 -2
  85. warp/fem/space/shape/__init__.py +10 -8
  86. warp/fem/space/tetmesh_function_space.py +4 -1
  87. warp/fem/space/topology.py +52 -21
  88. warp/fem/space/trimesh_function_space.py +4 -1
  89. warp/fem/utils.py +53 -8
  90. warp/jax.py +1 -2
  91. warp/jax_experimental/ffi.py +210 -67
  92. warp/jax_experimental/xla_ffi.py +37 -24
  93. warp/math.py +171 -1
  94. warp/native/array.h +103 -4
  95. warp/native/builtin.h +182 -35
  96. warp/native/coloring.cpp +6 -2
  97. warp/native/cuda_util.cpp +1 -1
  98. warp/native/exports.h +118 -63
  99. warp/native/intersect.h +5 -5
  100. warp/native/mat.h +8 -13
  101. warp/native/mathdx.cpp +11 -5
  102. warp/native/matnn.h +1 -123
  103. warp/native/mesh.h +1 -1
  104. warp/native/quat.h +34 -6
  105. warp/native/rand.h +7 -7
  106. warp/native/sparse.cpp +121 -258
  107. warp/native/sparse.cu +181 -274
  108. warp/native/spatial.h +305 -17
  109. warp/native/svd.h +23 -8
  110. warp/native/tile.h +603 -73
  111. warp/native/tile_radix_sort.h +1112 -0
  112. warp/native/tile_reduce.h +239 -13
  113. warp/native/tile_scan.h +240 -0
  114. warp/native/tuple.h +189 -0
  115. warp/native/vec.h +10 -20
  116. warp/native/warp.cpp +36 -4
  117. warp/native/warp.cu +588 -52
  118. warp/native/warp.h +47 -74
  119. warp/optim/linear.py +5 -1
  120. warp/paddle.py +7 -8
  121. warp/py.typed +0 -0
  122. warp/render/render_opengl.py +110 -80
  123. warp/render/render_usd.py +124 -62
  124. warp/sim/__init__.py +9 -0
  125. warp/sim/collide.py +253 -80
  126. warp/sim/graph_coloring.py +8 -1
  127. warp/sim/import_mjcf.py +4 -3
  128. warp/sim/import_usd.py +11 -7
  129. warp/sim/integrator.py +5 -2
  130. warp/sim/integrator_euler.py +1 -1
  131. warp/sim/integrator_featherstone.py +1 -1
  132. warp/sim/integrator_vbd.py +761 -322
  133. warp/sim/integrator_xpbd.py +1 -1
  134. warp/sim/model.py +265 -260
  135. warp/sim/utils.py +10 -7
  136. warp/sparse.py +303 -166
  137. warp/tape.py +54 -51
  138. warp/tests/cuda/test_conditional_captures.py +1046 -0
  139. warp/tests/cuda/test_streams.py +1 -1
  140. warp/tests/geometry/test_volume.py +2 -2
  141. warp/tests/interop/test_dlpack.py +9 -9
  142. warp/tests/interop/test_jax.py +0 -1
  143. warp/tests/run_coverage_serial.py +1 -1
  144. warp/tests/sim/disabled_kinematics.py +2 -2
  145. warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
  146. warp/tests/sim/test_collision.py +159 -51
  147. warp/tests/sim/test_coloring.py +91 -2
  148. warp/tests/test_array.py +254 -2
  149. warp/tests/test_array_reduce.py +2 -2
  150. warp/tests/test_assert.py +53 -0
  151. warp/tests/test_atomic_cas.py +312 -0
  152. warp/tests/test_codegen.py +142 -19
  153. warp/tests/test_conditional.py +47 -1
  154. warp/tests/test_ctypes.py +0 -20
  155. warp/tests/test_devices.py +8 -0
  156. warp/tests/test_fabricarray.py +4 -2
  157. warp/tests/test_fem.py +58 -25
  158. warp/tests/test_func.py +42 -1
  159. warp/tests/test_grad.py +1 -1
  160. warp/tests/test_lerp.py +1 -3
  161. warp/tests/test_map.py +481 -0
  162. warp/tests/test_mat.py +23 -24
  163. warp/tests/test_quat.py +28 -15
  164. warp/tests/test_rounding.py +10 -38
  165. warp/tests/test_runlength_encode.py +7 -7
  166. warp/tests/test_smoothstep.py +1 -1
  167. warp/tests/test_sparse.py +83 -2
  168. warp/tests/test_spatial.py +507 -1
  169. warp/tests/test_static.py +48 -0
  170. warp/tests/test_struct.py +2 -2
  171. warp/tests/test_tape.py +38 -0
  172. warp/tests/test_tuple.py +265 -0
  173. warp/tests/test_types.py +2 -2
  174. warp/tests/test_utils.py +24 -18
  175. warp/tests/test_vec.py +38 -408
  176. warp/tests/test_vec_constructors.py +325 -0
  177. warp/tests/tile/test_tile.py +438 -131
  178. warp/tests/tile/test_tile_mathdx.py +518 -14
  179. warp/tests/tile/test_tile_matmul.py +179 -0
  180. warp/tests/tile/test_tile_reduce.py +307 -5
  181. warp/tests/tile/test_tile_shared_memory.py +136 -7
  182. warp/tests/tile/test_tile_sort.py +121 -0
  183. warp/tests/unittest_suites.py +14 -6
  184. warp/types.py +462 -308
  185. warp/utils.py +647 -86
  186. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
  187. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +190 -176
  188. warp/stubs.py +0 -3381
  189. warp/tests/sim/test_xpbd.py +0 -399
  190. warp/tests/test_mlp.py +0 -282
  191. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
  192. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
  193. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
warp/build.py CHANGED
@@ -39,6 +39,7 @@ def build_cuda(
39
39
  fast_math=False,
40
40
  fuse_fp=True,
41
41
  lineinfo=False,
42
+ compile_time_trace=False,
42
43
  ltoirs=None,
43
44
  fatbins=None,
44
45
  ) -> None:
@@ -79,6 +80,7 @@ def build_cuda(
79
80
  fast_math,
80
81
  fuse_fp,
81
82
  lineinfo,
83
+ compile_time_trace,
82
84
  output_path,
83
85
  num_link,
84
86
  arr_link,
@@ -223,7 +225,7 @@ def get_cached_lto(path):
223
225
 
224
226
  def get_cached_lto_meta(path, symbol):
225
227
  if os.path.exists(path):
226
- with open(path, "r") as f:
228
+ with open(path) as f:
227
229
  keys = json.load(f)
228
230
  value = keys[symbol]
229
231
  return value
@@ -231,9 +233,106 @@ def get_cached_lto_meta(path, symbol):
231
233
  return None
232
234
 
233
235
 
236
+ def _build_lto_base(lto_symbol, compile_func, builder, extra_files=None):
237
+ """Generic LTO build function that handles caching, file operations and process management.
238
+
239
+ Args:
240
+ lto_symbol: Unique identifier for the LTO operation
241
+ compile_func: Function to compile the specific LTO
242
+ (receives a dictionary of build paths)
243
+ builder: Builder object to store results
244
+ extra_files: Dictionary of additional file types to handle (e.g.,
245
+ {".meta": None, ".fatbin": None}). Values are the functions to get
246
+ the cached file data.
247
+
248
+ Returns:
249
+ Tuple containing lto_code_data followed by any extra data from extra_files
250
+ """
251
+ if extra_files is None:
252
+ extra_files = {}
253
+
254
+ # Hash symbol and set up paths
255
+ h = hash_symbol(lto_symbol)
256
+ lto_dir = get_lto_cache_dir()
257
+ lto_name = f"{h[:7]}.lto"
258
+ lto_path = os.path.join(lto_dir, lto_name)
259
+
260
+ # Set up paths for extra files
261
+ file_paths = {".lto": lto_path}
262
+ temp_file_paths = {}
263
+
264
+ for ext, _ in extra_files.items():
265
+ name = f"{h[:7]}{ext}"
266
+ file_paths[ext] = os.path.join(lto_dir, name)
267
+
268
+ # Check if already built but not cached
269
+ lto_code_data = get_cached_lto(lto_path)
270
+ if lto_code_data is not None:
271
+ # Get the cached data for the extra files and early return
272
+ all_files_cached = True
273
+ for ext, getter in extra_files.items():
274
+ if getter and os.path.exists(file_paths[ext]):
275
+ cached_data = getter(file_paths[ext])
276
+ if cached_data is None:
277
+ all_files_cached = False
278
+ break
279
+ extra_files[ext] = cached_data
280
+ elif getter: # If there's a getter but file doesn't exist
281
+ all_files_cached = False
282
+ break
283
+
284
+ if all_files_cached:
285
+ if not extra_files:
286
+ return (lto_code_data,)
287
+ else:
288
+ return (lto_code_data, *[extra_files[ext] for ext in extra_files.keys()])
289
+
290
+ # Create process-dependent temporary build directory
291
+ build_dir = f"{lto_dir}_p{os.getpid()}"
292
+ Path(build_dir).mkdir(parents=True, exist_ok=True)
293
+
294
+ # Set up temporary paths for the build outputs
295
+ for ext, path in file_paths.items():
296
+ temp_file_paths[ext] = os.path.join(build_dir, os.path.basename(path))
297
+
298
+ # Compile LTO with the specialized function
299
+ result, outputs = compile_func(temp_file_paths)
300
+
301
+ if not result:
302
+ # Clean up and fail
303
+ for path in temp_file_paths.values():
304
+ if Path(path).exists():
305
+ Path(path).unlink()
306
+ raise RuntimeError(f"Failed to compile {lto_symbol}")
307
+
308
+ # Move outputs to cache
309
+ safe_rename(build_dir, lto_dir)
310
+
311
+ # If build_dir couldn't be moved by a rename, move the outputs one-by-one to lto_dir
312
+ if os.path.exists(lto_dir):
313
+ for ext, path in file_paths.items():
314
+ if not os.path.exists(path):
315
+ try:
316
+ # copy output file to the destination lto dir
317
+ os.rename(temp_file_paths[ext], path)
318
+ except (OSError, FileExistsError):
319
+ # another process likely updated the lto dir first
320
+ pass
321
+
322
+ # Clean up the temporary build directory
323
+ if build_dir:
324
+ import shutil
325
+
326
+ shutil.rmtree(build_dir, ignore_errors=True)
327
+
328
+ if not extra_files:
329
+ return (outputs[".lto"],)
330
+ else:
331
+ return (outputs[".lto"], *[outputs[ext] for ext in extra_files.keys()])
332
+
333
+
234
334
  def build_lto_dot(M, N, K, adtype, bdtype, cdtype, alayout, blayout, clayout, arch, num_threads, builder):
235
- # TODO: MathDx doesn't yet have heuristics for Blackwell
236
- arch = min(arch, 90)
335
+ arch = 120 if arch > 121 else arch
237
336
 
238
337
  # Maps Python/Warp types to C++ types and enums
239
338
  def cublasdx_type_map(dtype):
@@ -266,292 +365,182 @@ def build_lto_dot(M, N, K, adtype, bdtype, cdtype, alayout, blayout, clayout, ar
266
365
  c_arrangement = cublasdx_arrangement_map(clayout)
267
366
 
268
367
  if a_type != b_type or a_type != c_type:
269
- raise TypeError("time_matmul(A, B, C) requires all inputs to be real or complex")
368
+ raise TypeError("tile_matmul(A, B, C) requires all inputs to be real or complex")
270
369
 
271
370
  element_type = a_type
272
371
 
273
372
  lto_symbol = f"dot_{M}_{N}_{K}_{arch}_{num_threads}_{a_arrangement}_{b_arrangement}_{c_arrangement}_{a_prec}_{b_prec}_{c_prec}_{element_type}"
274
373
 
275
- # early out if LTO for this symbol is already cached in current module
276
- if lto_symbol in builder.ltoirs:
277
- return lto_symbol, builder.ltoirs[lto_symbol]
278
-
279
- # hash symbol and determine output path
280
- h = hash_symbol(lto_symbol)
374
+ def compile_lto_dot(temp_paths):
375
+ result = warp.context.runtime.core.cuda_compile_dot(
376
+ temp_paths[".lto"].encode("utf-8"),
377
+ lto_symbol.encode("utf-8"),
378
+ 0,
379
+ None,
380
+ None,
381
+ arch,
382
+ M,
383
+ N,
384
+ K,
385
+ a_prec,
386
+ b_prec,
387
+ c_prec,
388
+ element_type,
389
+ a_arrangement,
390
+ b_arrangement,
391
+ c_arrangement,
392
+ num_threads,
393
+ )
281
394
 
282
- lto_dir = get_lto_cache_dir()
283
- lto_name = f"{h[:7]}.lto"
284
- lto_path = os.path.join(lto_dir, lto_name)
395
+ if result:
396
+ with open(temp_paths[".lto"], "rb") as f:
397
+ lto_code_data = f.read()
398
+ return True, {".lto": lto_code_data}
399
+ return False, {}
285
400
 
286
- # early out if LTO for this symbol is already built but not cached in current module
287
- lto_code_data = get_cached_lto(lto_path)
401
+ # Early out if already cached in module
402
+ if lto_symbol in builder.ltoirs:
403
+ lto_code_data = builder.ltoirs[lto_symbol]
404
+ else:
405
+ (lto_code_data,) = _build_lto_base(lto_symbol, compile_lto_dot, builder, {})
288
406
 
289
- if lto_code_data is not None:
407
+ # Update builder
290
408
  builder.ltoirs[lto_symbol] = lto_code_data
291
409
  builder.ltoirs_decl[lto_symbol] = (
292
- f"void {lto_symbol}({c_dtype}, {a_dtype}*, {b_dtype}*, {c_dtype}, {c_dtype}*);"
410
+ f"void {lto_symbol}({c_dtype}*, {a_dtype}*, {b_dtype}*, {c_dtype}*, {c_dtype}*);"
293
411
  )
294
412
 
295
- return lto_symbol, lto_code_data
296
-
297
- # create a temporary (process unique) dir for build outputs before moving to the binary dir
298
- build_dir = f"{lto_dir}_p{os.getpid()}"
299
-
300
- # dir may exist from previous attempts / runs / archs
301
- Path(build_dir).mkdir(parents=True, exist_ok=True)
302
-
303
- # temporary path to compile to in build_dir
304
- temp_lto_path = os.path.join(build_dir, lto_name)
305
-
306
- # compile LTO
307
- result = warp.context.runtime.core.cuda_compile_dot(
308
- temp_lto_path.encode("utf-8"),
309
- lto_symbol.encode("utf-8"),
310
- 0,
311
- None,
312
- None,
313
- arch,
314
- M,
315
- N,
316
- K,
317
- a_prec,
318
- b_prec,
319
- c_prec,
320
- element_type,
321
- a_arrangement,
322
- b_arrangement,
323
- c_arrangement,
324
- num_threads,
325
- )
326
-
327
- if not result:
328
- if Path(temp_lto_path).exists():
329
- Path(temp_lto_path).unlink()
330
- raise RuntimeError("Failed to compile tile_matmul")
331
- else:
332
- with open(temp_lto_path, "rb") as f:
333
- lto_code_data = f.read()
334
-
335
- builder.ltoirs[lto_symbol] = lto_code_data
336
- builder.ltoirs_decl[lto_symbol] = f"void {lto_symbol}({c_dtype}, {a_dtype}*, {b_dtype}*, {c_dtype}, {c_dtype}*);"
337
-
338
- # try to move process outputs to cache
339
- safe_rename(build_dir, lto_dir)
340
-
341
- if os.path.exists(lto_dir):
342
- if not os.path.exists(lto_path):
343
- # copy output file to the destination lto dir
344
- try:
345
- os.rename(temp_lto_path, lto_path)
346
- except (OSError, FileExistsError):
347
- # another process likely updated the lto dir first
348
- pass
349
-
350
- if build_dir:
351
- import shutil
352
-
353
- # clean up build_dir used for this process
354
- shutil.rmtree(build_dir, ignore_errors=True)
355
-
356
413
  return lto_symbol, lto_code_data
357
414
 
358
415
 
359
- def build_lto_solver(M, N, solver, solver_enum, fill_mode, arch, precision_enum, num_threads, parameter_list, builder):
360
- # TODO: MathDx doesn't yet have heuristics for Blackwell
361
- arch = min(arch, 90)
362
-
363
- lto_symbol = f"{solver}_{M}_{N}_{arch}_{num_threads}_{precision_enum}_{fill_mode}"
364
- ltoir_decl = f"void {lto_symbol}{parameter_list};"
365
-
366
- # early out if LTO for this symbol is already cached in current module
367
- if lto_symbol in builder.ltoirs:
368
- return lto_symbol, builder.ltoirs[lto_symbol]
369
-
370
- # hash symbol and determine output path
371
- h = hash_symbol(lto_symbol)
416
+ def build_lto_solver(
417
+ M,
418
+ N,
419
+ NRHS,
420
+ solver,
421
+ solver_enum,
422
+ side_enum,
423
+ diag_enum,
424
+ alayout,
425
+ blayout,
426
+ fill_mode,
427
+ arch,
428
+ precision_enum,
429
+ num_threads,
430
+ parameter_list,
431
+ builder,
432
+ ):
433
+ arch = 120 if arch > 121 else arch
434
+
435
+ def cusolverdx_arrangement_map(layout):
436
+ if layout == "colmajor":
437
+ return 0 # CUSOLVERDX_ARRANGEMENT_COL_MAJOR
438
+ if layout == "rowmajor":
439
+ return 1 # CUSOLVERDX_ARRANGEMENT_ROW_MAJOR
440
+ raise ValueError("Unsupported layout in tile_matmul")
372
441
 
373
- lto_dir = get_lto_cache_dir()
374
- lto_name = f"{h[:7]}.lto"
375
- lto_path = os.path.join(lto_dir, lto_name)
442
+ a_arrangement = cusolverdx_arrangement_map(alayout)
443
+ b_arrangement = cusolverdx_arrangement_map(blayout)
444
+
445
+ lto_symbol = f"{solver}_{M}_{N}_{NRHS}_{arch}_{num_threads}_{a_arrangement}_{b_arrangement}_{precision_enum}_{side_enum if side_enum >= 0 else 'x'}_{diag_enum if diag_enum >= 0 else 'x'}_{fill_mode}"
446
+
447
+ def compile_lto_solver(temp_paths):
448
+ # compile LTO
449
+ result = warp.context.runtime.core.cuda_compile_solver(
450
+ temp_paths["_fatbin.lto"].encode("utf-8"),
451
+ temp_paths[".lto"].encode("utf-8"),
452
+ lto_symbol.encode("utf-8"),
453
+ 0,
454
+ None,
455
+ None,
456
+ arch,
457
+ M,
458
+ N,
459
+ NRHS,
460
+ solver_enum,
461
+ side_enum,
462
+ diag_enum,
463
+ precision_enum,
464
+ a_arrangement,
465
+ b_arrangement,
466
+ fill_mode,
467
+ num_threads,
468
+ )
376
469
 
377
- # we also cache a universal fatbin binary for this symbol
378
- universal_fatbin_name = f"{h[:7]}_fatbin.lto"
379
- universal_fatbin_path = os.path.join(lto_dir, universal_fatbin_name)
470
+ if result:
471
+ with open(temp_paths[".lto"], "rb") as f:
472
+ lto_code_data = f.read()
473
+ with open(temp_paths["_fatbin.lto"], "rb") as f:
474
+ universal_fatbin_code_data = f.read()
475
+ return True, {".lto": lto_code_data, "_fatbin.lto": universal_fatbin_code_data}
476
+ return False, {}
380
477
 
381
- lto_code_data = get_cached_lto(lto_path)
382
- universal_fatbin_code_data = get_cached_lto(universal_fatbin_path)
478
+ # Early out if already cached in module
479
+ if lto_symbol in builder.ltoirs:
480
+ lto_code_data = builder.ltoirs[lto_symbol]
481
+ else:
482
+ lto_code_data, universal_fatbin_code_data = _build_lto_base(
483
+ lto_symbol, compile_lto_solver, builder, {"_fatbin.lto": get_cached_lto}
484
+ )
383
485
 
384
- # early out if LTO for this symbol is already built but not cached in current module
385
- if lto_code_data is not None and universal_fatbin_code_data is not None:
486
+ # Update builder
386
487
  builder.ltoirs[lto_symbol] = lto_code_data
387
- builder.ltoirs_decl[lto_symbol] = ltoir_decl
488
+ builder.ltoirs_decl[lto_symbol] = f"void {lto_symbol}{parameter_list};"
388
489
  builder.fatbins[lto_symbol] = universal_fatbin_code_data
389
490
 
390
- return lto_symbol, lto_code_data
391
-
392
- # create a temporary (process unique) dir for build outputs before moving to the binary dir
393
- build_dir = f"{lto_dir}_p{os.getpid()}"
394
-
395
- # dir may exist from previous attempts / runs / archs
396
- Path(build_dir).mkdir(parents=True, exist_ok=True)
397
-
398
- # temporary paths to compile to in build_dir
399
- temp_lto_path = os.path.join(build_dir, lto_name)
400
- temp_universal_fatbin_path = os.path.join(build_dir, universal_fatbin_name)
401
-
402
- # compile LTO
403
- result = warp.context.runtime.core.cuda_compile_solver(
404
- temp_universal_fatbin_path.encode("utf-8"),
405
- temp_lto_path.encode("utf-8"),
406
- lto_symbol.encode("utf-8"),
407
- 0,
408
- None,
409
- None,
410
- arch,
411
- M,
412
- N,
413
- solver_enum,
414
- precision_enum,
415
- fill_mode,
416
- num_threads,
417
- )
418
-
419
- if not result:
420
- for path in [temp_universal_fatbin_path, temp_lto_path]:
421
- if Path(path).exists():
422
- Path(path).unlink()
423
- raise RuntimeError("Failed to compile tile_cholesky")
424
-
425
- else:
426
- with open(temp_lto_path, "rb") as f:
427
- lto_code_data = f.read()
428
- with open(temp_universal_fatbin_path, "rb") as f:
429
- universal_fatbin_code_data = f.read()
430
-
431
- builder.ltoirs[lto_symbol] = lto_code_data
432
- builder.ltoirs_decl[lto_symbol] = ltoir_decl
433
- builder.fatbins[lto_symbol] = universal_fatbin_code_data
434
-
435
- # try to move process outputs to lto cache
436
- safe_rename(build_dir, lto_dir)
437
-
438
- if os.path.exists(lto_dir):
439
- for p in [(lto_path, temp_lto_path), (universal_fatbin_path, temp_universal_fatbin_path)]:
440
- path, temp_path = p
441
- if not os.path.exists(path):
442
- # copy output file to the destination lto dir
443
- try:
444
- os.rename(temp_path, path)
445
- except (OSError, FileExistsError):
446
- # another process likely updated the lto dir first
447
- pass
448
-
449
- if build_dir:
450
- import shutil
451
-
452
- # clean up build_dir used for this process
453
- shutil.rmtree(build_dir, ignore_errors=True)
454
-
455
491
  return lto_symbol, lto_code_data
456
492
 
457
493
 
458
494
  def build_lto_fft(arch, size, ept, direction, dir, precision, builder):
459
- # TODO: MathDx doesn't yet have heuristics for Blackwell
460
- arch = min(arch, 90)
495
+ arch = 120 if arch > 121 else arch
461
496
 
462
497
  lto_symbol = f"fft_{size}_{ept}_{arch}_{direction}_{precision}"
463
498
 
464
- # early out if LTO for this symbol is already cached in current module
465
- if lto_symbol in builder.ltoirs:
466
- return lto_symbol, builder.ltoirs[lto_symbol], builder.shared_memory_bytes[lto_symbol]
467
-
468
- # hash symbol and determine output path
469
- h = hash_symbol(lto_symbol)
470
-
471
- lto_dir = get_lto_cache_dir()
472
- lto_name = f"{h[:7]}.lto"
473
- lto_path = os.path.join(lto_dir, lto_name)
474
-
475
- # we also cache shared memory requirements for this kernel in a .meta file
476
- meta_name = f"{h[:7]}.meta"
477
- meta_path = os.path.join(lto_dir, meta_name)
499
+ def compile_lto_fft(temp_paths):
500
+ shared_memory_size = ctypes.c_int(0)
501
+
502
+ result = warp.context.runtime.core.cuda_compile_fft(
503
+ temp_paths[".lto"].encode("utf-8"),
504
+ lto_symbol.encode("utf-8"),
505
+ 0,
506
+ None,
507
+ None,
508
+ arch,
509
+ size,
510
+ ept,
511
+ dir,
512
+ precision,
513
+ ctypes.byref(shared_memory_size),
514
+ )
478
515
 
479
- # early out if LTO for this symbol is already built but not cached in current module
480
- lto_code_data = get_cached_lto(lto_path)
481
- shared_memory_bytes = get_cached_lto_meta(meta_path, lto_symbol)
516
+ if result:
517
+ with open(temp_paths[".lto"], "rb") as f:
518
+ lto_code_data = f.read()
482
519
 
483
- if lto_code_data is not None and shared_memory_bytes is not None:
484
- builder.ltoirs[lto_symbol] = lto_code_data
485
- builder.shared_memory_bytes[lto_symbol] = shared_memory_bytes
520
+ shared_memory_bytes = tile.round_up(shared_memory_size.value)
486
521
 
487
- return lto_symbol, lto_code_data, shared_memory_bytes
522
+ # output meta file with shared memory requirements for this lto_symbol
523
+ meta = {}
524
+ meta[lto_symbol] = shared_memory_bytes
488
525
 
489
- # create a temporary (process unique) dir for build outputs before moving to the binary dir
490
- build_dir = f"{lto_dir}_p{os.getpid()}"
526
+ with open(temp_paths[".meta"], "w") as meta_file:
527
+ json.dump(meta, meta_file)
491
528
 
492
- # dir may exist from previous attempts / runs / archs
493
- Path(build_dir).mkdir(parents=True, exist_ok=True)
494
-
495
- # temporary paths to compile to in build_dir
496
- temp_lto_path = os.path.join(build_dir, lto_name)
497
- temp_meta_path = os.path.join(build_dir, meta_name)
498
-
499
- # compile LTO
500
- shared_memory_size = ctypes.c_int(0)
501
-
502
- result = warp.context.runtime.core.cuda_compile_fft(
503
- temp_lto_path.encode("utf-8"),
504
- lto_symbol.encode("utf-8"),
505
- 0,
506
- None,
507
- None,
508
- arch,
509
- size,
510
- ept,
511
- dir,
512
- precision,
513
- ctypes.byref(shared_memory_size),
514
- )
515
-
516
- shared_memory_bytes = Tile.round_up(shared_memory_size.value)
529
+ return True, {".lto": lto_code_data, ".meta": shared_memory_bytes}
517
530
 
518
- if not result:
519
- if Path(temp_lto_path).exists():
520
- Path(temp_lto_path).unlink()
521
- raise RuntimeError("Failed to compile tile_fft")
531
+ return False, {}
522
532
 
533
+ # Early out if already cached in module
534
+ if lto_symbol in builder.ltoirs and lto_symbol in builder.shared_memory_bytes:
535
+ lto_code_data = builder.ltoirs[lto_symbol]
536
+ shared_memory_bytes = builder.shared_memory_bytes[lto_symbol]
523
537
  else:
524
- with open(temp_lto_path, "rb") as f:
525
- lto_code_data = f.read()
526
-
527
- # output meta file with shared memory requirements for this lto_symbol
528
- meta = {}
529
- meta[lto_symbol] = shared_memory_bytes
530
-
531
- with open(temp_meta_path, "w") as meta_file:
532
- json.dump(meta, meta_file)
533
-
534
- builder.ltoirs[lto_symbol] = lto_code_data
535
- builder.shared_memory_bytes[lto_symbol] = shared_memory_bytes
536
-
537
- # try to move process outputs to cache
538
- safe_rename(build_dir, lto_dir)
539
-
540
- if os.path.exists(lto_dir):
541
- for p in [(lto_path, temp_lto_path), (meta_path, temp_meta_path)]:
542
- path, temp_path = p
543
- if not os.path.exists(path):
544
- # copy output file to the destination lto dir
545
- try:
546
- os.rename(temp_path, path)
547
- except (OSError, FileExistsError):
548
- # another process likely updated the lto dir first
549
- pass
550
-
551
- if build_dir:
552
- import shutil
538
+ lto_code_data, shared_memory_bytes = _build_lto_base(
539
+ lto_symbol, compile_lto_fft, builder, {".meta": lambda path: get_cached_lto_meta(path, lto_symbol)}
540
+ )
553
541
 
554
- # clean up build_dir used for this process
555
- shutil.rmtree(build_dir, ignore_errors=True)
542
+ # Update builder
543
+ builder.ltoirs[lto_symbol] = lto_code_data
544
+ builder.shared_memory_bytes[lto_symbol] = shared_memory_bytes
556
545
 
557
546
  return lto_symbol, lto_code_data, shared_memory_bytes