warp-lang 1.8.1__py3-none-macosx_10_13_universal2.whl → 1.9.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.
Files changed (134)
  1. warp/__init__.py +282 -103
  2. warp/__init__.pyi +482 -110
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +93 -30
  6. warp/build_dll.py +47 -67
  7. warp/builtins.py +955 -137
  8. warp/codegen.py +312 -206
  9. warp/config.py +1 -1
  10. warp/context.py +1249 -784
  11. warp/examples/core/example_marching_cubes.py +1 -0
  12. warp/examples/core/example_render_opengl.py +100 -3
  13. warp/examples/fem/example_apic_fluid.py +98 -52
  14. warp/examples/fem/example_convection_diffusion_dg.py +25 -4
  15. warp/examples/fem/example_diffusion_mgpu.py +8 -3
  16. warp/examples/fem/utils.py +68 -22
  17. warp/fabric.py +1 -1
  18. warp/fem/cache.py +27 -19
  19. warp/fem/domain.py +2 -2
  20. warp/fem/field/nodal_field.py +2 -2
  21. warp/fem/field/virtual.py +264 -166
  22. warp/fem/geometry/geometry.py +5 -5
  23. warp/fem/integrate.py +129 -51
  24. warp/fem/space/restriction.py +4 -0
  25. warp/fem/space/shape/tet_shape_function.py +3 -10
  26. warp/jax_experimental/custom_call.py +1 -1
  27. warp/jax_experimental/ffi.py +2 -1
  28. warp/marching_cubes.py +708 -0
  29. warp/native/array.h +99 -4
  30. warp/native/builtin.h +82 -5
  31. warp/native/bvh.cpp +64 -28
  32. warp/native/bvh.cu +58 -58
  33. warp/native/bvh.h +2 -2
  34. warp/native/clang/clang.cpp +7 -7
  35. warp/native/coloring.cpp +8 -2
  36. warp/native/crt.cpp +2 -2
  37. warp/native/crt.h +3 -5
  38. warp/native/cuda_util.cpp +41 -10
  39. warp/native/cuda_util.h +10 -4
  40. warp/native/exports.h +1842 -1908
  41. warp/native/fabric.h +2 -1
  42. warp/native/hashgrid.cpp +37 -37
  43. warp/native/hashgrid.cu +2 -2
  44. warp/native/initializer_array.h +1 -1
  45. warp/native/intersect.h +2 -2
  46. warp/native/mat.h +1910 -116
  47. warp/native/mathdx.cpp +43 -43
  48. warp/native/mesh.cpp +24 -24
  49. warp/native/mesh.cu +26 -26
  50. warp/native/mesh.h +4 -2
  51. warp/native/nanovdb/GridHandle.h +179 -12
  52. warp/native/nanovdb/HostBuffer.h +8 -7
  53. warp/native/nanovdb/NanoVDB.h +517 -895
  54. warp/native/nanovdb/NodeManager.h +323 -0
  55. warp/native/nanovdb/PNanoVDB.h +2 -2
  56. warp/native/quat.h +331 -14
  57. warp/native/range.h +7 -1
  58. warp/native/reduce.cpp +10 -10
  59. warp/native/reduce.cu +13 -14
  60. warp/native/runlength_encode.cpp +2 -2
  61. warp/native/runlength_encode.cu +5 -5
  62. warp/native/scan.cpp +3 -3
  63. warp/native/scan.cu +4 -4
  64. warp/native/sort.cpp +10 -10
  65. warp/native/sort.cu +22 -22
  66. warp/native/sparse.cpp +8 -8
  67. warp/native/sparse.cu +13 -13
  68. warp/native/spatial.h +366 -17
  69. warp/native/temp_buffer.h +2 -2
  70. warp/native/tile.h +283 -69
  71. warp/native/vec.h +381 -14
  72. warp/native/volume.cpp +54 -54
  73. warp/native/volume.cu +1 -1
  74. warp/native/volume.h +2 -1
  75. warp/native/volume_builder.cu +30 -37
  76. warp/native/warp.cpp +150 -149
  77. warp/native/warp.cu +323 -192
  78. warp/native/warp.h +227 -226
  79. warp/optim/linear.py +736 -271
  80. warp/render/imgui_manager.py +289 -0
  81. warp/render/render_opengl.py +85 -6
  82. warp/sim/graph_coloring.py +2 -2
  83. warp/sparse.py +558 -175
  84. warp/tests/aux_test_module_aot.py +7 -0
  85. warp/tests/cuda/test_async.py +3 -3
  86. warp/tests/cuda/test_conditional_captures.py +101 -0
  87. warp/tests/geometry/test_marching_cubes.py +233 -12
  88. warp/tests/sim/test_coloring.py +6 -6
  89. warp/tests/test_array.py +56 -5
  90. warp/tests/test_codegen.py +3 -2
  91. warp/tests/test_context.py +8 -15
  92. warp/tests/test_enum.py +136 -0
  93. warp/tests/test_examples.py +2 -2
  94. warp/tests/test_fem.py +45 -2
  95. warp/tests/test_fixedarray.py +229 -0
  96. warp/tests/test_func.py +18 -15
  97. warp/tests/test_future_annotations.py +7 -5
  98. warp/tests/test_linear_solvers.py +30 -0
  99. warp/tests/test_map.py +1 -1
  100. warp/tests/test_mat.py +1518 -378
  101. warp/tests/test_mat_assign_copy.py +178 -0
  102. warp/tests/test_mat_constructors.py +574 -0
  103. warp/tests/test_module_aot.py +287 -0
  104. warp/tests/test_print.py +69 -0
  105. warp/tests/test_quat.py +140 -34
  106. warp/tests/test_quat_assign_copy.py +145 -0
  107. warp/tests/test_reload.py +2 -1
  108. warp/tests/test_sparse.py +71 -0
  109. warp/tests/test_spatial.py +140 -34
  110. warp/tests/test_spatial_assign_copy.py +160 -0
  111. warp/tests/test_struct.py +43 -3
  112. warp/tests/test_types.py +0 -20
  113. warp/tests/test_vec.py +179 -34
  114. warp/tests/test_vec_assign_copy.py +143 -0
  115. warp/tests/tile/test_tile.py +184 -18
  116. warp/tests/tile/test_tile_cholesky.py +605 -0
  117. warp/tests/tile/test_tile_load.py +169 -0
  118. warp/tests/tile/test_tile_mathdx.py +2 -558
  119. warp/tests/tile/test_tile_matmul.py +1 -1
  120. warp/tests/tile/test_tile_mlp.py +1 -1
  121. warp/tests/tile/test_tile_shared_memory.py +5 -5
  122. warp/tests/unittest_suites.py +6 -0
  123. warp/tests/walkthrough_debug.py +1 -1
  124. warp/thirdparty/unittest_parallel.py +108 -9
  125. warp/types.py +554 -264
  126. warp/utils.py +68 -86
  127. {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/METADATA +28 -65
  128. {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/RECORD +131 -121
  129. warp/native/marching.cpp +0 -19
  130. warp/native/marching.cu +0 -514
  131. warp/native/marching.h +0 -19
  132. {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/WHEEL +0 -0
  133. {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/licenses/LICENSE.md +0 -0
  134. {warp_lang-1.8.1.dist-info → warp_lang-1.9.0.dist-info}/top_level.txt +0 -0
warp/build.py CHANGED
@@ -51,7 +51,7 @@ def build_cuda(
     output_path = output_path.encode("utf-8")
 
     if warp.config.llvm_cuda:
-        warp.context.runtime.llvm.compile_cuda(src, cu_path_bytes, inc_path, output_path, False)
+        warp.context.runtime.llvm.wp_compile_cuda(src, cu_path_bytes, inc_path, output_path, False)
 
     else:
         if ltoirs is None:
@@ -67,7 +67,7 @@ def build_cuda(
             fatbins
         )
         arr_link_input_types = (ctypes.c_int * num_link)(*link_input_types)
-        err = warp.context.runtime.core.cuda_compile_program(
+        err = warp.context.runtime.core.wp_cuda_compile_program(
             src,
             program_name_bytes,
             arch,
@@ -96,7 +96,7 @@ def load_cuda(input_path, device):
     if not device.is_cuda:
         raise RuntimeError("Not a CUDA device")
 
-    return warp.context.runtime.core.cuda_load_module(device.context, input_path.encode("utf-8"))
+    return warp.context.runtime.core.wp_cuda_load_module(device.context, input_path.encode("utf-8"))
 
 
 def build_cpu(obj_path, cpp_path, mode="release", verify_fp=False, fast_math=False, fuse_fp=True):
@@ -106,7 +106,7 @@ def build_cpu(obj_path, cpp_path, mode="release", verify_fp=False, fast_math=Fal
     inc_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "native").encode("utf-8")
     obj_path = obj_path.encode("utf-8")
 
-    err = warp.context.runtime.llvm.compile_cpp(
+    err = warp.context.runtime.llvm.wp_compile_cpp(
         src, cpp_path, inc_path, obj_path, mode == "debug", verify_fp, fuse_fp
     )
     if err != 0:
@@ -129,6 +129,15 @@ def init_kernel_cache(path=None):
     else:
         cache_root_dir = appdirs.user_cache_dir(appname="warp", appauthor="NVIDIA", version=warp.config.version)
 
+    if os.name == "nt" and os.path.isabs(cache_root_dir) and not cache_root_dir.startswith("\\\\?\\"):
+        # Add Windows long-path prefix, accounting for UNC shares.
+        if cache_root_dir.startswith("\\\\"):
+            # UNC path \\server\share\… → \\?\UNC\server\share\…
+            cache_root_dir = "\\\\?\\UNC\\" + cache_root_dir.lstrip("\\")
+        else:
+            # Drive-letter path C:\… → \\?\C:\…
+            cache_root_dir = "\\\\?\\" + cache_root_dir
+
     warp.config.kernel_cache_dir = cache_root_dir
 
     os.makedirs(warp.config.kernel_cache_dir, exist_ok=True)
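
The new long-path handling in `init_kernel_cache()` maps absolute Windows paths into the `\\?\` extended-length namespace so deeply nested kernel caches don't hit the legacy 260-character path limit. A minimal standalone sketch of the same transformation (the `add_long_path_prefix` helper name is ours, not part of Warp):

```python
# Hypothetical helper mirroring the logic added above; not part of the Warp API.
def add_long_path_prefix(path: str) -> str:
    if path.startswith("\\\\?\\"):
        return path  # already in the extended-length namespace
    if path.startswith("\\\\"):
        # UNC share: \\server\share\dir -> \\?\UNC\server\share\dir
        return "\\\\?\\UNC\\" + path.lstrip("\\")
    # Drive-letter path: C:\dir -> \\?\C:\dir
    return "\\\\?\\" + path


assert add_long_path_prefix("C:\\Users\\me\\cache") == "\\\\?\\C:\\Users\\me\\cache"
assert add_long_path_prefix("\\\\server\\share\\cache") == "\\\\?\\UNC\\server\\share\\cache"
```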
@@ -246,7 +255,12 @@ def _build_lto_base(lto_symbol, compile_func, builder, extra_files=None):
         the cached file data.
 
     Returns:
-        Tuple containing lto_code_data followed by any extra data from extra_files
+        Tuple where the first element is a success flag (``bool``). The second
+        element is the LTO code as bytes (or ``None`` on failure).
+        If ``extra_files`` is provided, additional elements follow in the same
+        order as the keys in ``extra_files``:
+        - ``".meta"``: int (shared memory bytes).
+        - ``"_fatbin.lto"``: bytes (universal fatbin).
     """
     if extra_files is None:
         extra_files = {}
@@ -283,9 +297,9 @@ def _build_lto_base(lto_symbol, compile_func, builder, extra_files=None):
 
     if all_files_cached:
         if not extra_files:
-            return (lto_code_data,)
+            return (True, lto_code_data)
         else:
-            return (lto_code_data, *[extra_files[ext] for ext in extra_files.keys()])
+            return (True, lto_code_data, *[extra_files[ext] for ext in extra_files.keys()])
 
     # Create process-dependent temporary build directory
     build_dir = f"{lto_dir}_p{os.getpid()}"
@@ -303,21 +317,24 @@ def _build_lto_base(lto_symbol, compile_func, builder, extra_files=None):
         for path in temp_file_paths.values():
             if Path(path).exists():
                 Path(path).unlink()
-        raise RuntimeError(f"Failed to compile {lto_symbol}")
-
-    # Move outputs to cache
-    safe_rename(build_dir, lto_dir)
-
-    # If build_dir couldn't be moved by a rename, move the outputs one-by-one to lto_dir
-    if os.path.exists(lto_dir):
-        for ext, path in file_paths.items():
-            if not os.path.exists(path):
-                try:
-                    # copy output file to the destination lto dir
-                    os.rename(temp_file_paths[ext], path)
-                except (OSError, FileExistsError):
-                    # another process likely updated the lto dir first
-                    pass
+
+        outputs[".lto"] = None
+        for ext in extra_files.keys():
+            outputs[ext] = None
+    else:
+        # Move outputs to cache
+        safe_rename(build_dir, lto_dir)
+
+        # If build_dir couldn't be moved by a rename, move the outputs one-by-one to lto_dir
+        if os.path.exists(lto_dir):
+            for ext, path in file_paths.items():
+                if not os.path.exists(path):
+                    try:
+                        # copy output file to the destination lto dir
+                        os.rename(temp_file_paths[ext], path)
+                    except (OSError, FileExistsError):
+                        # another process likely updated the lto dir first
+                        pass
 
     # Clean up the temporary build directory
     if build_dir:
@@ -326,9 +343,9 @@ def _build_lto_base(lto_symbol, compile_func, builder, extra_files=None):
         shutil.rmtree(build_dir, ignore_errors=True)
 
     if not extra_files:
-        return (outputs[".lto"],)
+        return (result, outputs[".lto"])
     else:
-        return (outputs[".lto"], *[outputs[ext] for ext in extra_files.keys()])
+        return (result, outputs[".lto"], *[outputs[ext] for ext in extra_files.keys()])
 
 
 def build_lto_dot(M, N, K, adtype, bdtype, cdtype, alayout, blayout, clayout, arch, num_threads, builder):
  def build_lto_dot(M, N, K, adtype, bdtype, cdtype, alayout, blayout, clayout, arch, num_threads, builder):
@@ -372,7 +389,7 @@ def build_lto_dot(M, N, K, adtype, bdtype, cdtype, alayout, blayout, clayout, ar
372
389
  lto_symbol = f"dot_{M}_{N}_{K}_{arch}_{num_threads}_{a_arrangement}_{b_arrangement}_{c_arrangement}_{a_prec}_{b_prec}_{c_prec}_{element_type}"
373
390
 
374
391
  def compile_lto_dot(temp_paths):
375
- result = warp.context.runtime.core.cuda_compile_dot(
392
+ result = warp.context.runtime.core.wp_cuda_compile_dot(
376
393
  temp_paths[".lto"].encode("utf-8"),
377
394
  lto_symbol.encode("utf-8"),
378
395
  0,
@@ -402,7 +419,13 @@ def build_lto_dot(M, N, K, adtype, bdtype, cdtype, alayout, blayout, clayout, ar
     if lto_symbol in builder.ltoirs:
         lto_code_data = builder.ltoirs[lto_symbol]
     else:
-        (lto_code_data,) = _build_lto_base(lto_symbol, compile_lto_dot, builder, {})
+        (result, lto_code_data) = _build_lto_base(lto_symbol, compile_lto_dot, builder, {})
+
+        if not result:
+            raise RuntimeError(
+                f"Failed to compile LTO '{lto_symbol}'. "
+                "Set the environment variable LIBMATHDX_LOG_LEVEL=5 and rerun for more details."
+            )
 
     # Update builder
     builder.ltoirs[lto_symbol] = lto_code_data
@@ -429,6 +452,7 @@ def build_lto_solver(
     num_threads,
     parameter_list,
     builder,
+    smem_estimate_bytes=None,
 ):
     arch = 120 if arch > 121 else arch
 
@@ -446,7 +470,7 @@ def build_lto_solver(
 
     def compile_lto_solver(temp_paths):
         # compile LTO
-        result = warp.context.runtime.core.cuda_compile_solver(
+        result = warp.context.runtime.core.wp_cuda_compile_solver(
             temp_paths["_fatbin.lto"].encode("utf-8"),
             temp_paths[".lto"].encode("utf-8"),
             lto_symbol.encode("utf-8"),
@@ -479,10 +503,43 @@ def build_lto_solver(
     if lto_symbol in builder.ltoirs:
         lto_code_data = builder.ltoirs[lto_symbol]
     else:
-        lto_code_data, universal_fatbin_code_data = _build_lto_base(
+        (result, lto_code_data, universal_fatbin_code_data) = _build_lto_base(
             lto_symbol, compile_lto_solver, builder, {"_fatbin.lto": get_cached_lto}
         )
 
+        if not result:
+            hint = ""
+            if smem_estimate_bytes:
+                max_smem_bytes = 232448
+                max_smem_is_estimate = True
+                for d in warp.get_cuda_devices():
+                    if d.arch == arch:
+                        # We can directly query the max shared memory for this device
+                        queried_bytes = warp.context.runtime.core.wp_cuda_get_max_shared_memory(d.context)
+                        if queried_bytes > 0:
+                            max_smem_bytes = queried_bytes
+                            max_smem_is_estimate = False
+                        break
+                if smem_estimate_bytes > max_smem_bytes:
+                    source = "estimated limit" if max_smem_is_estimate else "device-reported limit"
+                    hint = (
+                        f"Estimated shared memory requirement is {smem_estimate_bytes}B, "
+                        f"but the {source} is {max_smem_bytes}B. "
+                        "The tile size(s) may be too large for this device."
+                    )
+
+            if warp.context.runtime.toolkit_version < (12, 6):
+                raise RuntimeError(
+                    "cuSolverDx requires CUDA Toolkit 12.6.3 or later. This version of Warp was built against CUDA Toolkit "
+                    f"{warp.context.runtime.toolkit_version[0]}.{warp.context.runtime.toolkit_version[1]}. "
+                    "Upgrade your CUDA Toolkit and rebuild Warp, or install a Warp wheel built with CUDA >= 12.6.3."
+                )
+            else:
+                raise RuntimeError(
+                    f"Failed to compile LTO '{lto_symbol}'. {hint}"
+                    " Set the environment variable LIBMATHDX_LOG_LEVEL=5 and rerun for more details."
+                )
+
     # Update builder
     builder.ltoirs[lto_symbol] = lto_code_data
     builder.ltoirs_decl[lto_symbol] = f"void {lto_symbol}{parameter_list};"
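
All of the new failure messages point at the same libmathdx diagnostic switch. A sketch of how one might enable it from Python before reproducing a failing build (assuming, as the error text implies, that the variable must be set before the LTO compilation runs):

```python
import os

# Set before the failing module is built so libmathdx emits verbose logs.
os.environ["LIBMATHDX_LOG_LEVEL"] = "5"

import warp as wp

wp.init()
# ...rebuild and launch the tile kernel that previously failed to compile...
```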
@@ -499,7 +556,7 @@ def build_lto_fft(arch, size, ept, direction, dir, precision, builder):
     def compile_lto_fft(temp_paths):
         shared_memory_size = ctypes.c_int(0)
 
-        result = warp.context.runtime.core.cuda_compile_fft(
+        result = warp.context.runtime.core.wp_cuda_compile_fft(
             temp_paths[".lto"].encode("utf-8"),
             lto_symbol.encode("utf-8"),
             0,
@@ -535,10 +592,16 @@ def build_lto_fft(arch, size, ept, direction, dir, precision, builder):
         lto_code_data = builder.ltoirs[lto_symbol]
         shared_memory_bytes = builder.shared_memory_bytes[lto_symbol]
     else:
-        lto_code_data, shared_memory_bytes = _build_lto_base(
+        (result, lto_code_data, shared_memory_bytes) = _build_lto_base(
             lto_symbol, compile_lto_fft, builder, {".meta": lambda path: get_cached_lto_meta(path, lto_symbol)}
         )
 
+        if not result:
+            raise RuntimeError(
+                f"Failed to compile LTO '{lto_symbol}'. "
+                "Set the environment variable LIBMATHDX_LOG_LEVEL=5 and rerun for more details."
+            )
+
     # Update builder
     builder.ltoirs[lto_symbol] = lto_code_data
     builder.shared_memory_bytes[lto_symbol] = shared_memory_bytes
warp/build_dll.py CHANGED
@@ -13,16 +13,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
+
 import os
 import platform
 import subprocess
 import sys
-from typing import List, Optional
 
 from warp.utils import ScopedTimer
 
 verbose_cmd = True  # print command lines before executing them
 
+MIN_CTK_VERSION = (12, 0)
+
 
 def machine_architecture() -> str:
     """Return a canonical machine architecture string.
@@ -120,7 +123,7 @@ def find_host_compiler():
         return run_cmd("which g++").decode()
 
 
-def get_cuda_toolkit_version(cuda_home):
+def get_cuda_toolkit_version(cuda_home) -> tuple[int, int]:
     try:
         # the toolkit version can be obtained by running "nvcc --version"
         nvcc_path = os.path.join(cuda_home, "bin", "nvcc")
@@ -128,14 +131,16 @@ def get_cuda_toolkit_version(cuda_home) -> tuple[int, int]:
         # search for release substring (e.g., "release 11.5")
         import re
 
-        m = re.search(r"(?<=release )\d+\.\d+", nvcc_version_output)
+        m = re.search(r"release (\d+)\.(\d+)", nvcc_version_output)
         if m is not None:
-            return tuple(int(x) for x in m.group(0).split("."))
+            major, minor = map(int, m.groups())
+            return (major, minor)
         else:
             raise Exception("Failed to parse NVCC output")
 
     except Exception as e:
-        print(f"Failed to determine CUDA Toolkit version: {e}")
+        print(f"Warning: Failed to determine CUDA Toolkit version: {e}")
+        return MIN_CTK_VERSION
 
 
 def quote(path):
  def quote(path):
@@ -169,7 +174,7 @@ def add_llvm_bin_to_path(args):
169
174
  return True
170
175
 
171
176
 
172
- def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, arch, libs: Optional[List[str]] = None, mode=None):
177
+ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, arch, libs: list[str] | None = None, mode=None):
173
178
  mode = args.mode if (mode is None) else mode
174
179
  cuda_home = args.cuda_path
175
180
  cuda_cmd = None
@@ -197,17 +202,12 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, arch, libs: Optional[
 
     if cu_path:
         # check CUDA Toolkit version
-        min_ctk_version = (11, 5)
-        ctk_version = get_cuda_toolkit_version(cuda_home) or min_ctk_version
-        if ctk_version < min_ctk_version:
+        ctk_version = get_cuda_toolkit_version(cuda_home)
+        if ctk_version < MIN_CTK_VERSION:
             raise Exception(
-                f"CUDA Toolkit version {min_ctk_version[0]}.{min_ctk_version[1]}+ is required (found {ctk_version[0]}.{ctk_version[1]} in {cuda_home})"
+                f"CUDA Toolkit version {MIN_CTK_VERSION[0]}.{MIN_CTK_VERSION[1]}+ is required (found {ctk_version[0]}.{ctk_version[1]} in {cuda_home})"
             )
 
-        if ctk_version[0] < 12 and args.libmathdx_path:
-            print("MathDx support requires at least CUDA 12, skipping")
-            args.libmathdx_path = None
-
         # NVCC gencode options
         gencode_opts = []
@@ -216,91 +216,71 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, arch, libs: Optional[
 
         if args.quick:
             # minimum supported architectures (PTX)
-            gencode_opts += ["-gencode=arch=compute_52,code=compute_52", "-gencode=arch=compute_75,code=compute_75"]
-            clang_arch_flags += ["--cuda-gpu-arch=sm_52", "--cuda-gpu-arch=sm_75"]
+            if ctk_version >= (13, 0):
+                gencode_opts += ["-gencode=arch=compute_75,code=compute_75"]
+                clang_arch_flags += ["--cuda-gpu-arch=sm_75"]
+            else:
+                gencode_opts += ["-gencode=arch=compute_52,code=compute_52", "-gencode=arch=compute_75,code=compute_75"]
+                clang_arch_flags += ["--cuda-gpu-arch=sm_52", "--cuda-gpu-arch=sm_75"]
         else:
             # generate code for all supported architectures
             gencode_opts += [
                 # SASS for supported desktop/datacenter architectures
-                "-gencode=arch=compute_52,code=sm_52",  # Maxwell
-                "-gencode=arch=compute_60,code=sm_60",  # Pascal
-                "-gencode=arch=compute_61,code=sm_61",
-                "-gencode=arch=compute_70,code=sm_70",  # Volta
                 "-gencode=arch=compute_75,code=sm_75",  # Turing
                 "-gencode=arch=compute_75,code=compute_75",  # Turing (PTX)
                 "-gencode=arch=compute_80,code=sm_80",  # Ampere
                 "-gencode=arch=compute_86,code=sm_86",
+                "-gencode=arch=compute_89,code=sm_89",  # Ada
+                "-gencode=arch=compute_90,code=sm_90",  # Hopper
             ]
 
-            # TODO: Get this working with sm_52, sm_60, sm_61
             clang_arch_flags += [
                 # SASS for supported desktop/datacenter architectures
-                "--cuda-gpu-arch=sm_52",
-                "--cuda-gpu-arch=sm_60",
-                "--cuda-gpu-arch=sm_61",
-                "--cuda-gpu-arch=sm_70",  # Volta
                 "--cuda-gpu-arch=sm_75",  # Turing
                 "--cuda-gpu-arch=sm_80",  # Ampere
                 "--cuda-gpu-arch=sm_86",
+                "--cuda-gpu-arch=sm_89",  # Ada
+                "--cuda-gpu-arch=sm_90",  # Hopper
             ]
 
             if arch == "aarch64" and sys.platform == "linux":
-                gencode_opts += [
-                    # SASS for supported mobile architectures (e.g. Tegra/Jetson)
-                    "-gencode=arch=compute_53,code=sm_53",  # X1
-                    "-gencode=arch=compute_62,code=sm_62",  # X2
-                    "-gencode=arch=compute_72,code=sm_72",  # Xavier
-                    "-gencode=arch=compute_87,code=sm_87",  # Orin
-                ]
-
-                clang_arch_flags += [
-                    # SASS for supported mobile architectures
-                    "--cuda-gpu-arch=sm_53",  # X1
-                    "--cuda-gpu-arch=sm_62",  # X2
-                    "--cuda-gpu-arch=sm_72",  # Xavier
-                    "--cuda-gpu-arch=sm_87",  # Orin
-                ]
-
-                if ctk_version >= (12, 8):
-                    gencode_opts += ["-gencode=arch=compute_101,code=sm_101"]  # Thor (CUDA 12 numbering)
-                    clang_arch_flags += ["--cuda-gpu-arch=sm_101"]
+                # SASS for supported mobile architectures (e.g. Tegra/Jetson)
+                gencode_opts += ["-gencode=arch=compute_87,code=sm_87"]  # Orin
+                clang_arch_flags += ["--cuda-gpu-arch=sm_87"]
+
+                if ctk_version >= (13, 0):
+                    gencode_opts += ["-gencode=arch=compute_110,code=sm_110"]  # Thor
+                    clang_arch_flags += ["--cuda-gpu-arch=sm_110"]
+                else:
+                    gencode_opts += [
+                        "-gencode=arch=compute_53,code=sm_53",  # X1
+                        "-gencode=arch=compute_62,code=sm_62",  # X2
+                        "-gencode=arch=compute_72,code=sm_72",  # Xavier
+                    ]
+                    clang_arch_flags += [
+                        "--cuda-gpu-arch=sm_53",
+                        "--cuda-gpu-arch=sm_62",
+                        "--cuda-gpu-arch=sm_72",
+                    ]
+
+                    if ctk_version >= (12, 8):
+                        gencode_opts += ["-gencode=arch=compute_101,code=sm_101"]  # Thor (CUDA 12 numbering)
+                        clang_arch_flags += ["--cuda-gpu-arch=sm_101"]
 
             if ctk_version >= (12, 8):
                 # Support for Blackwell is available with CUDA Toolkit 12.8+
                 gencode_opts += [
-                    "-gencode=arch=compute_89,code=sm_89",  # Ada
-                    "-gencode=arch=compute_90,code=sm_90",  # Hopper
                     "-gencode=arch=compute_100,code=sm_100",  # Blackwell
                     "-gencode=arch=compute_120,code=sm_120",  # Blackwell
                     "-gencode=arch=compute_120,code=compute_120",  # PTX for future hardware
                 ]
 
                 clang_arch_flags += [
-                    "--cuda-gpu-arch=sm_89",  # Ada
-                    "--cuda-gpu-arch=sm_90",  # Hopper
                     "--cuda-gpu-arch=sm_100",  # Blackwell
                     "--cuda-gpu-arch=sm_120",  # Blackwell
                 ]
-            elif ctk_version >= (11, 8):
-                # Support for Ada and Hopper is available with CUDA Toolkit 11.8+
-                gencode_opts += [
-                    "-gencode=arch=compute_89,code=sm_89",  # Ada
-                    "-gencode=arch=compute_90,code=sm_90",  # Hopper
-                    "-gencode=arch=compute_90,code=compute_90",  # PTX for future hardware
-                ]
-
-                clang_arch_flags += [
-                    "--cuda-gpu-arch=sm_89",  # Ada
-                    "--cuda-gpu-arch=sm_90",  # Hopper
-                ]
             else:
-                gencode_opts += [
-                    "-gencode=arch=compute_86,code=compute_86",  # PTX for future hardware
-                ]
-
-                clang_arch_flags += [
-                    "--cuda-gpu-arch=sm_86",  # PTX for future hardware
-                ]
+                gencode_opts += ["-gencode=arch=compute_90,code=compute_90"]  # PTX for future hardware
 
         nvcc_opts = [
             *gencode_opts,
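
The restructured branches gate every architecture purely on lexicographic comparison of the toolkit version tuple. An abbreviated mirror of that gating (reduced flag set, for illustration only):

```python
def select_gencode(ctk_version: tuple[int, int]) -> list[str]:
    # Abbreviated mirror of the branch structure above, not the full flag set.
    opts = ["-gencode=arch=compute_75,code=sm_75"]  # Turing baseline
    if ctk_version >= (12, 8):
        opts += [
            "-gencode=arch=compute_100,code=sm_100",  # Blackwell
            "-gencode=arch=compute_120,code=compute_120",  # PTX for future hardware
        ]
    else:
        opts += ["-gencode=arch=compute_90,code=compute_90"]  # PTX for future hardware
    return opts


assert "-gencode=arch=compute_100,code=sm_100" in select_gencode((13, 0))
assert select_gencode((12, 6))[-1] == "-gencode=arch=compute_90,code=compute_90"
```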