warp-lang 1.8.1-py3-none-macosx_10_13_universal2.whl → 1.9.1-py3-none-macosx_10_13_universal2.whl

This diff covers publicly available package versions as released to their public registry. It is provided for informational purposes only and reflects the changes between those versions as published.

Files changed (141)
  1. warp/__init__.py +282 -103
  2. warp/__init__.pyi +1904 -114
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +93 -30
  6. warp/build_dll.py +331 -101
  7. warp/builtins.py +1244 -160
  8. warp/codegen.py +317 -206
  9. warp/config.py +1 -1
  10. warp/context.py +1465 -789
  11. warp/examples/core/example_marching_cubes.py +1 -0
  12. warp/examples/core/example_render_opengl.py +100 -3
  13. warp/examples/fem/example_apic_fluid.py +98 -52
  14. warp/examples/fem/example_convection_diffusion_dg.py +25 -4
  15. warp/examples/fem/example_diffusion_mgpu.py +8 -3
  16. warp/examples/fem/utils.py +68 -22
  17. warp/examples/interop/example_jax_kernel.py +2 -1
  18. warp/fabric.py +1 -1
  19. warp/fem/cache.py +27 -19
  20. warp/fem/domain.py +2 -2
  21. warp/fem/field/nodal_field.py +2 -2
  22. warp/fem/field/virtual.py +264 -166
  23. warp/fem/geometry/geometry.py +5 -5
  24. warp/fem/integrate.py +129 -51
  25. warp/fem/space/restriction.py +4 -0
  26. warp/fem/space/shape/tet_shape_function.py +3 -10
  27. warp/jax_experimental/custom_call.py +25 -2
  28. warp/jax_experimental/ffi.py +22 -1
  29. warp/jax_experimental/xla_ffi.py +16 -7
  30. warp/marching_cubes.py +708 -0
  31. warp/native/array.h +99 -4
  32. warp/native/builtin.h +86 -9
  33. warp/native/bvh.cpp +64 -28
  34. warp/native/bvh.cu +58 -58
  35. warp/native/bvh.h +2 -2
  36. warp/native/clang/clang.cpp +7 -7
  37. warp/native/coloring.cpp +8 -2
  38. warp/native/crt.cpp +2 -2
  39. warp/native/crt.h +3 -5
  40. warp/native/cuda_util.cpp +41 -10
  41. warp/native/cuda_util.h +10 -4
  42. warp/native/exports.h +1842 -1908
  43. warp/native/fabric.h +2 -1
  44. warp/native/hashgrid.cpp +37 -37
  45. warp/native/hashgrid.cu +2 -2
  46. warp/native/initializer_array.h +1 -1
  47. warp/native/intersect.h +2 -2
  48. warp/native/mat.h +1910 -116
  49. warp/native/mathdx.cpp +43 -43
  50. warp/native/mesh.cpp +24 -24
  51. warp/native/mesh.cu +26 -26
  52. warp/native/mesh.h +4 -2
  53. warp/native/nanovdb/GridHandle.h +179 -12
  54. warp/native/nanovdb/HostBuffer.h +8 -7
  55. warp/native/nanovdb/NanoVDB.h +517 -895
  56. warp/native/nanovdb/NodeManager.h +323 -0
  57. warp/native/nanovdb/PNanoVDB.h +2 -2
  58. warp/native/quat.h +331 -14
  59. warp/native/range.h +7 -1
  60. warp/native/reduce.cpp +10 -10
  61. warp/native/reduce.cu +13 -14
  62. warp/native/runlength_encode.cpp +2 -2
  63. warp/native/runlength_encode.cu +5 -5
  64. warp/native/scan.cpp +3 -3
  65. warp/native/scan.cu +4 -4
  66. warp/native/sort.cpp +10 -10
  67. warp/native/sort.cu +40 -31
  68. warp/native/sort.h +2 -0
  69. warp/native/sparse.cpp +8 -8
  70. warp/native/sparse.cu +13 -13
  71. warp/native/spatial.h +366 -17
  72. warp/native/temp_buffer.h +2 -2
  73. warp/native/tile.h +471 -82
  74. warp/native/vec.h +328 -14
  75. warp/native/volume.cpp +54 -54
  76. warp/native/volume.cu +1 -1
  77. warp/native/volume.h +2 -1
  78. warp/native/volume_builder.cu +30 -37
  79. warp/native/warp.cpp +150 -149
  80. warp/native/warp.cu +377 -216
  81. warp/native/warp.h +227 -226
  82. warp/optim/linear.py +736 -271
  83. warp/render/imgui_manager.py +289 -0
  84. warp/render/render_opengl.py +99 -18
  85. warp/render/render_usd.py +1 -0
  86. warp/sim/graph_coloring.py +2 -2
  87. warp/sparse.py +558 -175
  88. warp/tests/aux_test_module_aot.py +7 -0
  89. warp/tests/cuda/test_async.py +3 -3
  90. warp/tests/cuda/test_conditional_captures.py +101 -0
  91. warp/tests/geometry/test_hash_grid.py +38 -0
  92. warp/tests/geometry/test_marching_cubes.py +233 -12
  93. warp/tests/interop/test_jax.py +608 -28
  94. warp/tests/sim/test_coloring.py +6 -6
  95. warp/tests/test_array.py +58 -5
  96. warp/tests/test_codegen.py +4 -3
  97. warp/tests/test_context.py +8 -15
  98. warp/tests/test_enum.py +136 -0
  99. warp/tests/test_examples.py +2 -2
  100. warp/tests/test_fem.py +49 -6
  101. warp/tests/test_fixedarray.py +229 -0
  102. warp/tests/test_func.py +18 -15
  103. warp/tests/test_future_annotations.py +7 -5
  104. warp/tests/test_linear_solvers.py +30 -0
  105. warp/tests/test_map.py +15 -1
  106. warp/tests/test_mat.py +1518 -378
  107. warp/tests/test_mat_assign_copy.py +178 -0
  108. warp/tests/test_mat_constructors.py +574 -0
  109. warp/tests/test_module_aot.py +287 -0
  110. warp/tests/test_print.py +69 -0
  111. warp/tests/test_quat.py +140 -34
  112. warp/tests/test_quat_assign_copy.py +145 -0
  113. warp/tests/test_reload.py +2 -1
  114. warp/tests/test_sparse.py +71 -0
  115. warp/tests/test_spatial.py +140 -34
  116. warp/tests/test_spatial_assign_copy.py +160 -0
  117. warp/tests/test_struct.py +43 -3
  118. warp/tests/test_tuple.py +96 -0
  119. warp/tests/test_types.py +61 -20
  120. warp/tests/test_vec.py +179 -34
  121. warp/tests/test_vec_assign_copy.py +143 -0
  122. warp/tests/tile/test_tile.py +245 -18
  123. warp/tests/tile/test_tile_cholesky.py +605 -0
  124. warp/tests/tile/test_tile_load.py +169 -0
  125. warp/tests/tile/test_tile_mathdx.py +2 -558
  126. warp/tests/tile/test_tile_matmul.py +1 -1
  127. warp/tests/tile/test_tile_mlp.py +1 -1
  128. warp/tests/tile/test_tile_shared_memory.py +5 -5
  129. warp/tests/unittest_suites.py +6 -0
  130. warp/tests/walkthrough_debug.py +1 -1
  131. warp/thirdparty/unittest_parallel.py +108 -9
  132. warp/types.py +571 -267
  133. warp/utils.py +68 -86
  134. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/METADATA +29 -69
  135. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/RECORD +138 -128
  136. warp/native/marching.cpp +0 -19
  137. warp/native/marching.cu +0 -514
  138. warp/native/marching.h +0 -19
  139. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/WHEEL +0 -0
  140. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/licenses/LICENSE.md +0 -0
  141. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/top_level.txt +0 -0
warp/build_dll.py CHANGED
@@ -13,16 +13,19 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.

+ from __future__ import annotations
+
  import os
  import platform
  import subprocess
  import sys
- from typing import List, Optional

  from warp.utils import ScopedTimer

  verbose_cmd = True # print command lines before executing them

+ MIN_CTK_VERSION = (12, 0)
+

  def machine_architecture() -> str:
  """Return a canonical machine architecture string.
@@ -120,7 +123,7 @@ def find_host_compiler():
  return run_cmd("which g++").decode()


- def get_cuda_toolkit_version(cuda_home):
+ def get_cuda_toolkit_version(cuda_home) -> tuple[int, int]:
  try:
  # the toolkit version can be obtained by running "nvcc --version"
  nvcc_path = os.path.join(cuda_home, "bin", "nvcc")
@@ -128,14 +131,16 @@ def get_cuda_toolkit_version(cuda_home):
  # search for release substring (e.g., "release 11.5")
  import re

- m = re.search(r"(?<=release )\d+\.\d+", nvcc_version_output)
+ m = re.search(r"release (\d+)\.(\d+)", nvcc_version_output)
  if m is not None:
- return tuple(int(x) for x in m.group(0).split("."))
+ major, minor = map(int, m.groups())
+ return (major, minor)
  else:
  raise Exception("Failed to parse NVCC output")

  except Exception as e:
- print(f"Failed to determine CUDA Toolkit version: {e}")
+ print(f"Warning: Failed to determine CUDA Toolkit version: {e}")
+ return MIN_CTK_VERSION


  def quote(path):
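
Net effect of these two hunks: `get_cuda_toolkit_version` now returns a `(major, minor)` tuple in every case, printing a warning and falling back to `MIN_CTK_VERSION` instead of returning `None` when parsing fails. A minimal standalone sketch of the parsing behavior; the function name `parse_ctk_version` and the sample string are illustrative, not from the package:

import re

MIN_CTK_VERSION = (12, 0)  # same fallback constant the module now defines


def parse_ctk_version(nvcc_version_output: str) -> tuple[int, int]:
    # Mirrors the new regex: capture major/minor from "release X.Y" in nvcc's output.
    m = re.search(r"release (\d+)\.(\d+)", nvcc_version_output)
    if m is not None:
        major, minor = map(int, m.groups())
        return (major, minor)
    # The real function also prints a warning before falling back.
    return MIN_CTK_VERSION


sample = "Cuda compilation tools, release 12.8, V12.8.61"  # hypothetical nvcc --version line
assert parse_ctk_version(sample) == (12, 8)
assert parse_ctk_version(sample) >= MIN_CTK_VERSION  # tuple comparison used for the minimum-version check
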
@@ -169,138 +174,363 @@ def add_llvm_bin_to_path(args):
  return True


- def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, arch, libs: Optional[List[str]] = None, mode=None):
- mode = args.mode if (mode is None) else mode
- cuda_home = args.cuda_path
- cuda_cmd = None
+ def _get_architectures_cu12(
+ ctk_version: tuple[int, int], arch: str, target_platform: str, quick_build: bool = False
+ ) -> tuple[list[str], list[str]]:
+ """Get architecture flags for CUDA 12.x."""
+ gencode_opts = []
+ clang_arch_flags = []

- # Add LLVM bin directory to PATH
- add_llvm_bin_to_path(args)
-
- if args.quick or cu_path is None:
- cuda_compat_enabled = "WP_ENABLE_CUDA_COMPATIBILITY=0"
+ if quick_build:
+ gencode_opts = ["-gencode=arch=compute_52,code=compute_52", "-gencode=arch=compute_75,code=compute_75"]
+ clang_arch_flags = ["--cuda-gpu-arch=sm_52", "--cuda-gpu-arch=sm_75"]
  else:
- cuda_compat_enabled = "WP_ENABLE_CUDA_COMPATIBILITY=1"
+ if arch == "aarch64" and target_platform == "linux" and ctk_version == (12, 9):
+ # Skip certain architectures for aarch64 with CUDA 12.9 due to CCCL bug
+ print(
+ "[INFO] Skipping sm_52, sm_60, sm_61, and sm_70 targets for ARM due to a CUDA Toolkit bug. "
+ "See https://nvidia.github.io/warp/installation.html#cuda-12-9-limitation-on-linux-arm-platforms "
+ "for details."
+ )
+ else:
+ gencode_opts.extend(
+ [
+ "-gencode=arch=compute_52,code=sm_52", # Maxwell
+ "-gencode=arch=compute_60,code=sm_60", # Pascal
+ "-gencode=arch=compute_61,code=sm_61",
+ "-gencode=arch=compute_70,code=sm_70", # Volta
+ ]
+ )
+ clang_arch_flags.extend(
+ [
+ "--cuda-gpu-arch=sm_52",
+ "--cuda-gpu-arch=sm_60",
+ "--cuda-gpu-arch=sm_61",
+ "--cuda-gpu-arch=sm_70",
+ ]
+ )

- if libs is None:
- libs = []
+ # Desktop architectures
+ gencode_opts.extend(
+ [
+ "-gencode=arch=compute_75,code=sm_75", # Turing
+ "-gencode=arch=compute_75,code=compute_75", # Turing (PTX)
+ "-gencode=arch=compute_80,code=sm_80", # Ampere
+ "-gencode=arch=compute_86,code=sm_86",
+ "-gencode=arch=compute_89,code=sm_89", # Ada
+ "-gencode=arch=compute_90,code=sm_90", # Hopper
+ ]
+ )
+ clang_arch_flags.extend(
+ [
+ "--cuda-gpu-arch=sm_75", # Turing
+ "--cuda-gpu-arch=sm_80", # Ampere
+ "--cuda-gpu-arch=sm_86",
+ "--cuda-gpu-arch=sm_89", # Ada
+ "--cuda-gpu-arch=sm_90", # Hopper
+ ]
+ )

- import pathlib
+ if ctk_version >= (12, 8):
+ gencode_opts.extend(["-gencode=arch=compute_100,code=sm_100", "-gencode=arch=compute_120,code=sm_120"])
+ clang_arch_flags.extend(["--cuda-gpu-arch=sm_100", "--cuda-gpu-arch=sm_120"])

- warp_home_path = pathlib.Path(__file__).parent
- warp_home = warp_home_path.resolve()
+ # Mobile architectures for aarch64 Linux
+ if arch == "aarch64" and target_platform == "linux":
+ gencode_opts.extend(
+ [
+ "-gencode=arch=compute_87,code=sm_87", # Orin
+ "-gencode=arch=compute_53,code=sm_53", # X1
+ "-gencode=arch=compute_62,code=sm_62", # X2
+ "-gencode=arch=compute_72,code=sm_72", # Xavier
+ ]
+ )
+ clang_arch_flags.extend(
+ [
+ "--cuda-gpu-arch=sm_87",
+ "--cuda-gpu-arch=sm_53",
+ "--cuda-gpu-arch=sm_62",
+ "--cuda-gpu-arch=sm_72",
+ ]
+ )

- if args.verbose:
- print(f"Building {dll_path}")
+ # Thor support in CUDA 12.8+
+ if ctk_version >= (12, 8):
+ gencode_opts.append("-gencode=arch=compute_101,code=sm_101") # Thor (CUDA 12 numbering)
+ clang_arch_flags.append("--cuda-gpu-arch=sm_101")
+
+ if ctk_version >= (12, 9):
+ gencode_opts.append("-gencode=arch=compute_121,code=sm_121")
+ clang_arch_flags.append("--cuda-gpu-arch=sm_121")
+
+ # PTX for future hardware (use highest available compute capability)
+ if ctk_version >= (12, 9):
+ gencode_opts.extend(["-gencode=arch=compute_121,code=compute_121"])
+ elif ctk_version >= (12, 8):
+ gencode_opts.extend(["-gencode=arch=compute_120,code=compute_120"])
+ else:
+ gencode_opts.append("-gencode=arch=compute_90,code=compute_90")

- native_dir = os.path.join(warp_home, "native")
+ return gencode_opts, clang_arch_flags

- if cu_path:
- # check CUDA Toolkit version
- min_ctk_version = (11, 5)
- ctk_version = get_cuda_toolkit_version(cuda_home) or min_ctk_version
- if ctk_version < min_ctk_version:
- raise Exception(
- f"CUDA Toolkit version {min_ctk_version[0]}.{min_ctk_version[1]}+ is required (found {ctk_version[0]}.{ctk_version[1]} in {cuda_home})"
+
+ def _get_architectures_cu13(
+ ctk_version: tuple[int, int], arch: str, target_platform: str, quick_build: bool = False
+ ) -> tuple[list[str], list[str]]:
+ """Get architecture flags for CUDA 13.x."""
+ gencode_opts = []
+ clang_arch_flags = []
+
+ if quick_build:
+ gencode_opts = ["-gencode=arch=compute_75,code=compute_75"]
+ clang_arch_flags = ["--cuda-gpu-arch=sm_75"]
+ else:
+ # Desktop architectures
+ gencode_opts.extend(
+ [
+ "-gencode=arch=compute_75,code=sm_75", # Turing
+ "-gencode=arch=compute_75,code=compute_75", # Turing (PTX)
+ "-gencode=arch=compute_80,code=sm_80", # Ampere
+ "-gencode=arch=compute_86,code=sm_86",
+ "-gencode=arch=compute_89,code=sm_89", # Ada
+ "-gencode=arch=compute_90,code=sm_90", # Hopper
+ "-gencode=arch=compute_100,code=sm_100", # Blackwell
+ "-gencode=arch=compute_120,code=sm_120", # Blackwell
+ ]
+ )
+ clang_arch_flags.extend(
+ [
+ "--cuda-gpu-arch=sm_75", # Turing
+ "--cuda-gpu-arch=sm_80", # Ampere
+ "--cuda-gpu-arch=sm_86",
+ "--cuda-gpu-arch=sm_89", # Ada
+ "--cuda-gpu-arch=sm_90", # Hopper
+ "--cuda-gpu-arch=sm_100", # Blackwell
+ "--cuda-gpu-arch=sm_120", # Blackwell
+ ]
+ )
+
+ # Mobile architectures for aarch64 Linux
+ if arch == "aarch64" and target_platform == "linux":
+ gencode_opts.extend(
+ [
+ "-gencode=arch=compute_87,code=sm_87", # Orin
+ "-gencode=arch=compute_110,code=sm_110", # Thor
+ "-gencode=arch=compute_121,code=sm_121", # Spark
+ ]
+ )
+ clang_arch_flags.extend(
+ [
+ "--cuda-gpu-arch=sm_87",
+ "--cuda-gpu-arch=sm_110",
+ "--cuda-gpu-arch=sm_121",
+ ]
  )

- if ctk_version[0] < 12 and args.libmathdx_path:
- print("MathDx support requires at least CUDA 12, skipping")
- args.libmathdx_path = None
+ # PTX for future hardware (use highest available compute capability)
+ gencode_opts.extend(["-gencode=arch=compute_121,code=compute_121"])

- # NVCC gencode options
- gencode_opts = []
+ return gencode_opts, clang_arch_flags

- # Clang architecture flags
- clang_arch_flags = []

- if args.quick:
- # minimum supported architectures (PTX)
- gencode_opts += ["-gencode=arch=compute_52,code=compute_52", "-gencode=arch=compute_75,code=compute_75"]
- clang_arch_flags += ["--cuda-gpu-arch=sm_52", "--cuda-gpu-arch=sm_75"]
+ def _get_architectures_cu12(
+ ctk_version: tuple[int, int], arch: str, target_platform: str, quick_build: bool = False
+ ) -> tuple[list[str], list[str]]:
+ """Get architecture flags for CUDA 12.x."""
+ gencode_opts = []
+ clang_arch_flags = []
+
+ if quick_build:
+ gencode_opts = ["-gencode=arch=compute_52,code=compute_52", "-gencode=arch=compute_75,code=compute_75"]
+ clang_arch_flags = ["--cuda-gpu-arch=sm_52", "--cuda-gpu-arch=sm_75"]
+ else:
+ if arch == "aarch64" and target_platform == "linux" and ctk_version == (12, 9):
+ # Skip certain architectures for aarch64 with CUDA 12.9 due to CCCL bug
+ print(
+ "[INFO] Skipping sm_52, sm_60, sm_61, and sm_70 targets for ARM due to a CUDA Toolkit bug. "
+ "See https://nvidia.github.io/warp/installation.html#cuda-12-9-limitation-on-linux-arm-platforms "
+ "for details."
+ )
  else:
- # generate code for all supported architectures
- gencode_opts += [
- # SASS for supported desktop/datacenter architectures
- "-gencode=arch=compute_52,code=sm_52", # Maxwell
- "-gencode=arch=compute_60,code=sm_60", # Pascal
- "-gencode=arch=compute_61,code=sm_61",
- "-gencode=arch=compute_70,code=sm_70", # Volta
+ gencode_opts.extend(
+ [
+ "-gencode=arch=compute_52,code=sm_52", # Maxwell
+ "-gencode=arch=compute_60,code=sm_60", # Pascal
+ "-gencode=arch=compute_61,code=sm_61",
+ "-gencode=arch=compute_70,code=sm_70", # Volta
+ ]
+ )
+ clang_arch_flags.extend(
+ [
+ "--cuda-gpu-arch=sm_52",
+ "--cuda-gpu-arch=sm_60",
+ "--cuda-gpu-arch=sm_61",
+ "--cuda-gpu-arch=sm_70",
+ ]
+ )
+
+ # Desktop architectures
+ gencode_opts.extend(
+ [
  "-gencode=arch=compute_75,code=sm_75", # Turing
  "-gencode=arch=compute_75,code=compute_75", # Turing (PTX)
  "-gencode=arch=compute_80,code=sm_80", # Ampere
  "-gencode=arch=compute_86,code=sm_86",
+ "-gencode=arch=compute_89,code=sm_89", # Ada
+ "-gencode=arch=compute_90,code=sm_90", # Hopper
  ]
-
- # TODO: Get this working with sm_52, sm_60, sm_61
- clang_arch_flags += [
- # SASS for supported desktop/datacenter architectures
- "--cuda-gpu-arch=sm_52",
- "--cuda-gpu-arch=sm_60",
- "--cuda-gpu-arch=sm_61",
- "--cuda-gpu-arch=sm_70", # Volta
+ )
+ clang_arch_flags.extend(
+ [
  "--cuda-gpu-arch=sm_75", # Turing
  "--cuda-gpu-arch=sm_80", # Ampere
  "--cuda-gpu-arch=sm_86",
+ "--cuda-gpu-arch=sm_89", # Ada
+ "--cuda-gpu-arch=sm_90", # Hopper
  ]
+ )
+
+ if ctk_version >= (12, 8):
+ gencode_opts.extend(["-gencode=arch=compute_100,code=sm_100", "-gencode=arch=compute_120,code=sm_120"])
+ clang_arch_flags.extend(["--cuda-gpu-arch=sm_100", "--cuda-gpu-arch=sm_120"])

- if arch == "aarch64" and sys.platform == "linux":
- gencode_opts += [
- # SASS for supported mobile architectures (e.g. Tegra/Jetson)
+ # Mobile architectures for aarch64 Linux
+ if arch == "aarch64" and target_platform == "linux":
+ gencode_opts.extend(
+ [
+ "-gencode=arch=compute_87,code=sm_87", # Orin
  "-gencode=arch=compute_53,code=sm_53", # X1
  "-gencode=arch=compute_62,code=sm_62", # X2
  "-gencode=arch=compute_72,code=sm_72", # Xavier
- "-gencode=arch=compute_87,code=sm_87", # Orin
  ]
-
- clang_arch_flags += [
- # SASS for supported mobile architectures
- "--cuda-gpu-arch=sm_53", # X1
- "--cuda-gpu-arch=sm_62", # X2
- "--cuda-gpu-arch=sm_72", # Xavier
- "--cuda-gpu-arch=sm_87", # Orin
+ )
+ clang_arch_flags.extend(
+ [
+ "--cuda-gpu-arch=sm_87",
+ "--cuda-gpu-arch=sm_53",
+ "--cuda-gpu-arch=sm_62",
+ "--cuda-gpu-arch=sm_72",
  ]
+ )

- if ctk_version >= (12, 8):
- gencode_opts += ["-gencode=arch=compute_101,code=sm_101"] # Thor (CUDA 12 numbering)
- clang_arch_flags += ["--cuda-gpu-arch=sm_101"]
-
+ # Thor support in CUDA 12.8+
  if ctk_version >= (12, 8):
- # Support for Blackwell is available with CUDA Toolkit 12.8+
- gencode_opts += [
- "-gencode=arch=compute_89,code=sm_89", # Ada
- "-gencode=arch=compute_90,code=sm_90", # Hopper
- "-gencode=arch=compute_100,code=sm_100", # Blackwell
- "-gencode=arch=compute_120,code=sm_120", # Blackwell
- "-gencode=arch=compute_120,code=compute_120", # PTX for future hardware
- ]
+ gencode_opts.append("-gencode=arch=compute_101,code=sm_101") # Thor (CUDA 12 numbering)
+ clang_arch_flags.append("--cuda-gpu-arch=sm_101")
+
+ if ctk_version >= (12, 9):
+ gencode_opts.append("-gencode=arch=compute_121,code=sm_121")
+ clang_arch_flags.append("--cuda-gpu-arch=sm_121")
+
+ # PTX for future hardware (use highest available compute capability)
+ if ctk_version >= (12, 9):
+ gencode_opts.extend(["-gencode=arch=compute_121,code=compute_121"])
+ elif ctk_version >= (12, 8):
+ gencode_opts.extend(["-gencode=arch=compute_120,code=compute_120"])
+ else:
+ gencode_opts.append("-gencode=arch=compute_90,code=compute_90")
+
+ return gencode_opts, clang_arch_flags

- clang_arch_flags += [
- "--cuda-gpu-arch=sm_89", # Ada
- "--cuda-gpu-arch=sm_90", # Hopper
- "--cuda-gpu-arch=sm_100", # Blackwell
- "--cuda-gpu-arch=sm_120", # Blackwell
- ]
- elif ctk_version >= (11, 8):
- # Support for Ada and Hopper is available with CUDA Toolkit 11.8+
- gencode_opts += [
- "-gencode=arch=compute_89,code=sm_89", # Ada
- "-gencode=arch=compute_90,code=sm_90", # Hopper
- "-gencode=arch=compute_90,code=compute_90", # PTX for future hardware
- ]

- clang_arch_flags += [
- "--cuda-gpu-arch=sm_89", # Ada
- "--cuda-gpu-arch=sm_90", # Hopper
+ def _get_architectures_cu13(
+ ctk_version: tuple[int, int], arch: str, target_platform: str, quick_build: bool = False
+ ) -> tuple[list[str], list[str]]:
+ """Get architecture flags for CUDA 13.x."""
+ gencode_opts = []
+ clang_arch_flags = []
+
+ if quick_build:
+ gencode_opts = ["-gencode=arch=compute_75,code=compute_75"]
+ clang_arch_flags = ["--cuda-gpu-arch=sm_75"]
+ else:
+ # Desktop architectures
+ gencode_opts.extend(
+ [
+ "-gencode=arch=compute_75,code=sm_75", # Turing
+ "-gencode=arch=compute_75,code=compute_75", # Turing (PTX)
+ "-gencode=arch=compute_80,code=sm_80", # Ampere
+ "-gencode=arch=compute_86,code=sm_86",
+ "-gencode=arch=compute_89,code=sm_89", # Ada
+ "-gencode=arch=compute_90,code=sm_90", # Hopper
+ "-gencode=arch=compute_100,code=sm_100", # Blackwell
+ "-gencode=arch=compute_120,code=sm_120", # Blackwell
+ ]
+ )
+ clang_arch_flags.extend(
+ [
+ "--cuda-gpu-arch=sm_75", # Turing
+ "--cuda-gpu-arch=sm_80", # Ampere
+ "--cuda-gpu-arch=sm_86",
+ "--cuda-gpu-arch=sm_89", # Ada
+ "--cuda-gpu-arch=sm_90", # Hopper
+ "--cuda-gpu-arch=sm_100", # Blackwell
+ "--cuda-gpu-arch=sm_120", # Blackwell
+ ]
+ )
+
+ # Mobile architectures for aarch64 Linux
+ if arch == "aarch64" and target_platform == "linux":
+ gencode_opts.extend(
+ [
+ "-gencode=arch=compute_87,code=sm_87", # Orin
+ "-gencode=arch=compute_110,code=sm_110", # Thor
+ "-gencode=arch=compute_121,code=sm_121", # Spark
  ]
- else:
- gencode_opts += [
- "-gencode=arch=compute_86,code=compute_86", # PTX for future hardware
+ )
+ clang_arch_flags.extend(
+ [
+ "--cuda-gpu-arch=sm_87",
+ "--cuda-gpu-arch=sm_110",
+ "--cuda-gpu-arch=sm_121",
  ]
+ )

- clang_arch_flags += [
- "--cuda-gpu-arch=sm_86", # PTX for future hardware
- ]
+ # PTX for future hardware (use highest available compute capability)
+ gencode_opts.extend(["-gencode=arch=compute_121,code=compute_121"])
+
+ return gencode_opts, clang_arch_flags
+
+
+ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, arch, libs: list[str] | None = None, mode=None):
+ mode = args.mode if (mode is None) else mode
+ cuda_home = args.cuda_path
+ cuda_cmd = None
+
+ # Add LLVM bin directory to PATH
+ add_llvm_bin_to_path(args)
+
+ if args.quick or cu_path is None:
+ cuda_compat_enabled = "WP_ENABLE_CUDA_COMPATIBILITY=0"
+ else:
+ cuda_compat_enabled = "WP_ENABLE_CUDA_COMPATIBILITY=1"
+
+ if libs is None:
+ libs = []
+
+ import pathlib
+
+ warp_home_path = pathlib.Path(__file__).parent
+ warp_home = warp_home_path.resolve()
+
+ if args.verbose:
+ print(f"Building {dll_path}")
+
+ native_dir = os.path.join(warp_home, "native")
+
+ if cu_path:
+ # check CUDA Toolkit version
+ ctk_version = get_cuda_toolkit_version(cuda_home)
+ if ctk_version < MIN_CTK_VERSION:
+ raise Exception(
+ f"CUDA Toolkit version {MIN_CTK_VERSION[0]}.{MIN_CTK_VERSION[1]}+ is required (found {ctk_version[0]}.{ctk_version[1]} in {cuda_home})"
+ )
+
+ # Get architecture flags based on CUDA version
+ if ctk_version >= (13, 0):
+ gencode_opts, clang_arch_flags = _get_architectures_cu13(ctk_version, arch, sys.platform, args.quick)
+ else:
+ gencode_opts, clang_arch_flags = _get_architectures_cu12(ctk_version, arch, sys.platform, args.quick)

  nvcc_opts = [
  *gencode_opts,
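
The remainder of the hunk (beyond this excerpt) feeds the returned flags into `nvcc_opts`. The refactor splits architecture selection into per-toolkit helpers: the CUDA 13 table starts at sm_75, while the CUDA 12 table still covers Maxwell through Volta, which is why the dispatch keys on `ctk_version >= (13, 0)`. A hedged sketch of exercising that dispatch; `arch_flags_for` is an illustrative wrapper (not part of the package), the helpers are private, and the snippet assumes warp-lang 1.9.1 is installed:

import sys

from warp.build_dll import _get_architectures_cu12, _get_architectures_cu13


def arch_flags_for(ctk_version: tuple[int, int], arch: str, quick: bool = False):
    # Same dispatch as the new build_dll_for_arch: CUDA 13.x toolkits use the cu13 table.
    if ctk_version >= (13, 0):
        return _get_architectures_cu13(ctk_version, arch, sys.platform, quick)
    return _get_architectures_cu12(ctk_version, arch, sys.platform, quick)


gencode_opts, clang_arch_flags = arch_flags_for((12, 8), "x86_64")
print("-gencode=arch=compute_120,code=sm_120" in gencode_opts)  # True: Blackwell SASS with CTK 12.8+
print("--cuda-gpu-arch=sm_90" in clang_arch_flags)              # True: Hopper is in the default (non-quick) set
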