warp-lang 1.7.2rc1__py3-none-win_amd64.whl → 1.8.1__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (193)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/warp-clang.dll +0 -0
  5. warp/bin/warp.dll +0 -0
  6. warp/build.py +241 -252
  7. warp/build_dll.py +130 -26
  8. warp/builtins.py +1907 -384
  9. warp/codegen.py +272 -104
  10. warp/config.py +12 -1
  11. warp/constants.py +1 -1
  12. warp/context.py +770 -238
  13. warp/dlpack.py +1 -1
  14. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  15. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  16. warp/examples/core/example_sample_mesh.py +1 -1
  17. warp/examples/core/example_spin_lock.py +93 -0
  18. warp/examples/core/example_work_queue.py +118 -0
  19. warp/examples/fem/example_adaptive_grid.py +5 -5
  20. warp/examples/fem/example_apic_fluid.py +1 -1
  21. warp/examples/fem/example_burgers.py +1 -1
  22. warp/examples/fem/example_convection_diffusion.py +9 -6
  23. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  24. warp/examples/fem/example_deformed_geometry.py +1 -1
  25. warp/examples/fem/example_diffusion.py +2 -2
  26. warp/examples/fem/example_diffusion_3d.py +1 -1
  27. warp/examples/fem/example_distortion_energy.py +1 -1
  28. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  29. warp/examples/fem/example_magnetostatics.py +5 -3
  30. warp/examples/fem/example_mixed_elasticity.py +5 -3
  31. warp/examples/fem/example_navier_stokes.py +11 -9
  32. warp/examples/fem/example_nonconforming_contact.py +5 -3
  33. warp/examples/fem/example_streamlines.py +8 -3
  34. warp/examples/fem/utils.py +9 -8
  35. warp/examples/interop/example_jax_callable.py +34 -4
  36. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  37. warp/examples/interop/example_jax_kernel.py +27 -1
  38. warp/examples/optim/example_drone.py +1 -1
  39. warp/examples/sim/example_cloth.py +1 -1
  40. warp/examples/sim/example_cloth_self_contact.py +48 -54
  41. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  42. warp/examples/tile/example_tile_cholesky.py +2 -1
  43. warp/examples/tile/example_tile_convolution.py +1 -1
  44. warp/examples/tile/example_tile_filtering.py +1 -1
  45. warp/examples/tile/example_tile_matmul.py +1 -1
  46. warp/examples/tile/example_tile_mlp.py +2 -0
  47. warp/fabric.py +7 -7
  48. warp/fem/__init__.py +5 -0
  49. warp/fem/adaptivity.py +1 -1
  50. warp/fem/cache.py +152 -63
  51. warp/fem/dirichlet.py +2 -2
  52. warp/fem/domain.py +136 -6
  53. warp/fem/field/field.py +141 -99
  54. warp/fem/field/nodal_field.py +85 -39
  55. warp/fem/field/virtual.py +99 -52
  56. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  57. warp/fem/geometry/closest_point.py +13 -0
  58. warp/fem/geometry/deformed_geometry.py +102 -40
  59. warp/fem/geometry/element.py +56 -2
  60. warp/fem/geometry/geometry.py +323 -22
  61. warp/fem/geometry/grid_2d.py +157 -62
  62. warp/fem/geometry/grid_3d.py +116 -20
  63. warp/fem/geometry/hexmesh.py +86 -20
  64. warp/fem/geometry/nanogrid.py +166 -86
  65. warp/fem/geometry/partition.py +59 -25
  66. warp/fem/geometry/quadmesh.py +86 -135
  67. warp/fem/geometry/tetmesh.py +47 -119
  68. warp/fem/geometry/trimesh.py +77 -270
  69. warp/fem/integrate.py +181 -95
  70. warp/fem/linalg.py +25 -58
  71. warp/fem/operator.py +124 -27
  72. warp/fem/quadrature/pic_quadrature.py +36 -14
  73. warp/fem/quadrature/quadrature.py +40 -16
  74. warp/fem/space/__init__.py +1 -1
  75. warp/fem/space/basis_function_space.py +66 -46
  76. warp/fem/space/basis_space.py +17 -4
  77. warp/fem/space/dof_mapper.py +1 -1
  78. warp/fem/space/function_space.py +2 -2
  79. warp/fem/space/grid_2d_function_space.py +4 -1
  80. warp/fem/space/hexmesh_function_space.py +4 -2
  81. warp/fem/space/nanogrid_function_space.py +3 -1
  82. warp/fem/space/partition.py +11 -2
  83. warp/fem/space/quadmesh_function_space.py +4 -1
  84. warp/fem/space/restriction.py +5 -2
  85. warp/fem/space/shape/__init__.py +10 -8
  86. warp/fem/space/tetmesh_function_space.py +4 -1
  87. warp/fem/space/topology.py +52 -21
  88. warp/fem/space/trimesh_function_space.py +4 -1
  89. warp/fem/utils.py +53 -8
  90. warp/jax.py +1 -2
  91. warp/jax_experimental/ffi.py +210 -67
  92. warp/jax_experimental/xla_ffi.py +37 -24
  93. warp/math.py +171 -1
  94. warp/native/array.h +103 -4
  95. warp/native/builtin.h +182 -35
  96. warp/native/coloring.cpp +6 -2
  97. warp/native/cuda_util.cpp +1 -1
  98. warp/native/exports.h +118 -63
  99. warp/native/intersect.h +5 -5
  100. warp/native/mat.h +8 -13
  101. warp/native/mathdx.cpp +11 -5
  102. warp/native/matnn.h +1 -123
  103. warp/native/mesh.h +1 -1
  104. warp/native/quat.h +34 -6
  105. warp/native/rand.h +7 -7
  106. warp/native/sparse.cpp +121 -258
  107. warp/native/sparse.cu +181 -274
  108. warp/native/spatial.h +305 -17
  109. warp/native/svd.h +23 -8
  110. warp/native/tile.h +603 -73
  111. warp/native/tile_radix_sort.h +1112 -0
  112. warp/native/tile_reduce.h +239 -13
  113. warp/native/tile_scan.h +240 -0
  114. warp/native/tuple.h +189 -0
  115. warp/native/vec.h +10 -20
  116. warp/native/warp.cpp +36 -4
  117. warp/native/warp.cu +588 -52
  118. warp/native/warp.h +47 -74
  119. warp/optim/linear.py +5 -1
  120. warp/paddle.py +7 -8
  121. warp/py.typed +0 -0
  122. warp/render/render_opengl.py +110 -80
  123. warp/render/render_usd.py +124 -62
  124. warp/sim/__init__.py +9 -0
  125. warp/sim/collide.py +253 -80
  126. warp/sim/graph_coloring.py +8 -1
  127. warp/sim/import_mjcf.py +4 -3
  128. warp/sim/import_usd.py +11 -7
  129. warp/sim/integrator.py +5 -2
  130. warp/sim/integrator_euler.py +1 -1
  131. warp/sim/integrator_featherstone.py +1 -1
  132. warp/sim/integrator_vbd.py +761 -322
  133. warp/sim/integrator_xpbd.py +1 -1
  134. warp/sim/model.py +265 -260
  135. warp/sim/utils.py +10 -7
  136. warp/sparse.py +303 -166
  137. warp/tape.py +54 -51
  138. warp/tests/cuda/test_conditional_captures.py +1046 -0
  139. warp/tests/cuda/test_streams.py +1 -1
  140. warp/tests/geometry/test_volume.py +2 -2
  141. warp/tests/interop/test_dlpack.py +9 -9
  142. warp/tests/interop/test_jax.py +0 -1
  143. warp/tests/run_coverage_serial.py +1 -1
  144. warp/tests/sim/disabled_kinematics.py +2 -2
  145. warp/tests/sim/{test_vbd.py → test_cloth.py} +378 -112
  146. warp/tests/sim/test_collision.py +159 -51
  147. warp/tests/sim/test_coloring.py +91 -2
  148. warp/tests/test_array.py +254 -2
  149. warp/tests/test_array_reduce.py +2 -2
  150. warp/tests/test_assert.py +53 -0
  151. warp/tests/test_atomic_cas.py +312 -0
  152. warp/tests/test_codegen.py +142 -19
  153. warp/tests/test_conditional.py +47 -1
  154. warp/tests/test_ctypes.py +0 -20
  155. warp/tests/test_devices.py +8 -0
  156. warp/tests/test_fabricarray.py +4 -2
  157. warp/tests/test_fem.py +58 -25
  158. warp/tests/test_func.py +42 -1
  159. warp/tests/test_grad.py +1 -1
  160. warp/tests/test_lerp.py +1 -3
  161. warp/tests/test_map.py +481 -0
  162. warp/tests/test_mat.py +23 -24
  163. warp/tests/test_quat.py +28 -15
  164. warp/tests/test_rounding.py +10 -38
  165. warp/tests/test_runlength_encode.py +7 -7
  166. warp/tests/test_smoothstep.py +1 -1
  167. warp/tests/test_sparse.py +83 -2
  168. warp/tests/test_spatial.py +507 -1
  169. warp/tests/test_static.py +48 -0
  170. warp/tests/test_struct.py +2 -2
  171. warp/tests/test_tape.py +38 -0
  172. warp/tests/test_tuple.py +265 -0
  173. warp/tests/test_types.py +2 -2
  174. warp/tests/test_utils.py +24 -18
  175. warp/tests/test_vec.py +38 -408
  176. warp/tests/test_vec_constructors.py +325 -0
  177. warp/tests/tile/test_tile.py +438 -131
  178. warp/tests/tile/test_tile_mathdx.py +518 -14
  179. warp/tests/tile/test_tile_matmul.py +179 -0
  180. warp/tests/tile/test_tile_reduce.py +307 -5
  181. warp/tests/tile/test_tile_shared_memory.py +136 -7
  182. warp/tests/tile/test_tile_sort.py +121 -0
  183. warp/tests/unittest_suites.py +14 -6
  184. warp/types.py +462 -308
  185. warp/utils.py +647 -86
  186. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/METADATA +20 -6
  187. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/RECORD +190 -176
  188. warp/stubs.py +0 -3381
  189. warp/tests/sim/test_xpbd.py +0 -399
  190. warp/tests/test_mlp.py +0 -282
  191. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
  192. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
  193. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
warp/build_dll.py CHANGED
@@ -17,16 +17,18 @@ import os
17
17
  import platform
18
18
  import subprocess
19
19
  import sys
20
+ from typing import List, Optional
20
21
 
21
22
  from warp.utils import ScopedTimer
22
23
 
23
24
  verbose_cmd = True # print command lines before executing them
24
25
 
25
26
 
26
- # returns a canonical machine architecture string
27
- # - "x86_64" for x86-64, aka. AMD64, aka. x64
28
- # - "aarch64" for AArch64, aka. ARM64
29
27
  def machine_architecture() -> str:
28
+ """Return a canonical machine architecture string.
29
+ - "x86_64" for x86-64, aka. AMD64, aka. x64
30
+ - "aarch64" for AArch64, aka. ARM64
31
+ """
30
32
  machine = platform.machine()
31
33
  if machine == "x86_64" or machine == "AMD64":
32
34
  return "x86_64"
@@ -103,10 +105,8 @@ def find_host_compiler():
103
105
  cl_required_major = 14
104
106
  cl_required_minor = 29
105
107
 
106
- if (
107
- (int(cl_version[0]) < cl_required_major)
108
- or (int(cl_version[0]) == cl_required_major)
109
- and int(cl_version[1]) < cl_required_minor
108
+ if int(cl_version[0]) < cl_required_major or (
109
+ (int(cl_version[0]) == cl_required_major) and (int(cl_version[1]) < cl_required_minor)
110
110
  ):
111
111
  print(
112
112
  f"Warp: MSVC found but compiler version too old, found {cl_version[0]}.{cl_version[1]}, but must be {cl_required_major}.{cl_required_minor} or higher, kernel host compilation will be disabled."
@@ -142,22 +142,54 @@ def quote(path):
142
142
  return '"' + path + '"'
143
143
 
144
144
 
145
- def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, arch, mode=None):
145
+ def add_llvm_bin_to_path(args):
146
+ """Add the LLVM bin directory to the PATH environment variable if it's set.
147
+
148
+ Args:
149
+ args: The argument namespace containing llvm_path.
150
+
151
+ Returns:
152
+ ``True`` if the PATH was updated, ``False`` otherwise.
153
+ """
154
+ if not hasattr(args, "llvm_path") or not args.llvm_path:
155
+ return False
156
+
157
+ # Construct the bin directory path
158
+ llvm_bin_path = os.path.join(args.llvm_path, "bin")
159
+
160
+ # Check if the directory exists
161
+ if not os.path.isdir(llvm_bin_path):
162
+ print(f"Warning: LLVM bin directory not found at {llvm_bin_path}")
163
+ return False
164
+
165
+ # Add to PATH environment variable
166
+ os.environ["PATH"] = llvm_bin_path + os.pathsep + os.environ.get("PATH", "")
167
+
168
+ print(f"Added {llvm_bin_path} to PATH")
169
+ return True
170
+
171
+
172
+ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, arch, libs: Optional[List[str]] = None, mode=None):
146
173
  mode = args.mode if (mode is None) else mode
147
174
  cuda_home = args.cuda_path
148
175
  cuda_cmd = None
149
176
 
177
+ # Add LLVM bin directory to PATH
178
+ add_llvm_bin_to_path(args)
179
+
150
180
  if args.quick or cu_path is None:
151
181
  cuda_compat_enabled = "WP_ENABLE_CUDA_COMPATIBILITY=0"
152
182
  else:
153
183
  cuda_compat_enabled = "WP_ENABLE_CUDA_COMPATIBILITY=1"
154
184
 
185
+ if libs is None:
186
+ libs = []
187
+
155
188
  import pathlib
156
189
 
157
190
  warp_home_path = pathlib.Path(__file__).parent
158
191
  warp_home = warp_home_path.resolve()
159
192
 
160
- # output stale, rebuild
161
193
  if args.verbose:
162
194
  print(f"Building {dll_path}")
163
195
 
@@ -176,11 +208,16 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, arch, mode=None
176
208
  print("MathDx support requires at least CUDA 12, skipping")
177
209
  args.libmathdx_path = None
178
210
 
211
+ # NVCC gencode options
179
212
  gencode_opts = []
180
213
 
214
+ # Clang architecture flags
215
+ clang_arch_flags = []
216
+
181
217
  if args.quick:
182
218
  # minimum supported architectures (PTX)
183
219
  gencode_opts += ["-gencode=arch=compute_52,code=compute_52", "-gencode=arch=compute_75,code=compute_75"]
220
+ clang_arch_flags += ["--cuda-gpu-arch=sm_52", "--cuda-gpu-arch=sm_75"]
184
221
  else:
185
222
  # generate code for all supported architectures
186
223
  gencode_opts += [
@@ -190,9 +227,23 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, arch, mode=None
190
227
  "-gencode=arch=compute_61,code=sm_61",
191
228
  "-gencode=arch=compute_70,code=sm_70", # Volta
192
229
  "-gencode=arch=compute_75,code=sm_75", # Turing
230
+ "-gencode=arch=compute_75,code=compute_75", # Turing (PTX)
193
231
  "-gencode=arch=compute_80,code=sm_80", # Ampere
194
232
  "-gencode=arch=compute_86,code=sm_86",
195
233
  ]
234
+
235
+ # TODO: Get this working with sm_52, sm_60, sm_61
236
+ clang_arch_flags += [
237
+ # SASS for supported desktop/datacenter architectures
238
+ "--cuda-gpu-arch=sm_52",
239
+ "--cuda-gpu-arch=sm_60",
240
+ "--cuda-gpu-arch=sm_61",
241
+ "--cuda-gpu-arch=sm_70", # Volta
242
+ "--cuda-gpu-arch=sm_75", # Turing
243
+ "--cuda-gpu-arch=sm_80", # Ampere
244
+ "--cuda-gpu-arch=sm_86",
245
+ ]
246
+
196
247
  if arch == "aarch64" and sys.platform == "linux":
197
248
  gencode_opts += [
198
249
  # SASS for supported mobile architectures (e.g. Tegra/Jetson)
@@ -202,6 +253,18 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, arch, mode=None
202
253
  "-gencode=arch=compute_87,code=sm_87", # Orin
203
254
  ]
204
255
 
256
+ clang_arch_flags += [
257
+ # SASS for supported mobile architectures
258
+ "--cuda-gpu-arch=sm_53", # X1
259
+ "--cuda-gpu-arch=sm_62", # X2
260
+ "--cuda-gpu-arch=sm_72", # Xavier
261
+ "--cuda-gpu-arch=sm_87", # Orin
262
+ ]
263
+
264
+ if ctk_version >= (12, 8):
265
+ gencode_opts += ["-gencode=arch=compute_101,code=sm_101"] # Thor (CUDA 12 numbering)
266
+ clang_arch_flags += ["--cuda-gpu-arch=sm_101"]
267
+
205
268
  if ctk_version >= (12, 8):
206
269
  # Support for Blackwell is available with CUDA Toolkit 12.8+
207
270
  gencode_opts += [
@@ -211,6 +274,13 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, arch, mode=None
211
274
  "-gencode=arch=compute_120,code=sm_120", # Blackwell
212
275
  "-gencode=arch=compute_120,code=compute_120", # PTX for future hardware
213
276
  ]
277
+
278
+ clang_arch_flags += [
279
+ "--cuda-gpu-arch=sm_89", # Ada
280
+ "--cuda-gpu-arch=sm_90", # Hopper
281
+ "--cuda-gpu-arch=sm_100", # Blackwell
282
+ "--cuda-gpu-arch=sm_120", # Blackwell
283
+ ]
214
284
  elif ctk_version >= (11, 8):
215
285
  # Support for Ada and Hopper is available with CUDA Toolkit 11.8+
216
286
  gencode_opts += [
@@ -218,16 +288,40 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, arch, mode=None
218
288
  "-gencode=arch=compute_90,code=sm_90", # Hopper
219
289
  "-gencode=arch=compute_90,code=compute_90", # PTX for future hardware
220
290
  ]
291
+
292
+ clang_arch_flags += [
293
+ "--cuda-gpu-arch=sm_89", # Ada
294
+ "--cuda-gpu-arch=sm_90", # Hopper
295
+ ]
221
296
  else:
222
297
  gencode_opts += [
223
298
  "-gencode=arch=compute_86,code=compute_86", # PTX for future hardware
224
299
  ]
225
300
 
226
- nvcc_opts = gencode_opts + [
301
+ clang_arch_flags += [
302
+ "--cuda-gpu-arch=sm_86", # PTX for future hardware
303
+ ]
304
+
305
+ nvcc_opts = [
306
+ *gencode_opts,
227
307
  "-t0", # multithreaded compilation
228
308
  "--extended-lambda",
229
309
  ]
230
310
 
311
+ # Clang options
312
+ clang_opts = [
313
+ *clang_arch_flags,
314
+ "-std=c++17",
315
+ "-xcuda",
316
+ f'--cuda-path="{cuda_home}"',
317
+ ]
318
+
319
+ if args.compile_time_trace:
320
+ if ctk_version >= (12, 8):
321
+ nvcc_opts.append("--fdevice-time-trace=build_lib_compile-time-trace")
322
+ else:
323
+ print("Warp warning: CUDA version is less than 12.8, compile_time_trace is not supported")
324
+
231
325
  if args.fast_math:
232
326
  nvcc_opts.append("--use_fast_math")
233
327
 
@@ -304,13 +398,17 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, arch, mode=None
304
398
  )
305
399
 
306
400
  if args.libmathdx_path:
307
- linkopts.append(f'nvJitLink_static.lib /LIBPATH:"{args.libmathdx_path}/lib" mathdx_static.lib')
401
+ linkopts.append(f'nvJitLink_static.lib /LIBPATH:"{args.libmathdx_path}/lib/x64" mathdx_static.lib')
308
402
 
309
403
  with ScopedTimer("link", active=args.verbose):
310
404
  link_cmd = f'"{host_linker}" {" ".join(linkopts + libs)} /out:"{dll_path}"'
311
405
  run_cmd(link_cmd)
312
406
 
313
407
  else:
408
+ # Unix compilation
409
+ cuda_compiler = "clang++" if getattr(args, "clang_build_toolchain", False) else "nvcc"
410
+ cpp_compiler = "clang++" if getattr(args, "clang_build_toolchain", False) else "g++"
411
+
314
412
  cpp_includes = f' -I"{warp_home_path.parent}/external/llvm-project/out/install/{mode}-{arch}/include"'
315
413
  cpp_includes += f' -I"{warp_home_path.parent}/_build/host-deps/llvm-project/release-{arch}/include"'
316
414
  cuda_includes = f' -I"{cuda_home}/include"' if cu_path else ""
@@ -319,9 +417,12 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, arch, mode=None
319
417
  if sys.platform == "darwin":
320
418
  version = f"--target={arch}-apple-macos11"
321
419
  else:
322
- version = "-fabi-version=13" # GCC 8.2+
420
+ if cpp_compiler == "g++":
421
+ version = "-fabi-version=13" # GCC 8.2+
422
+ else:
423
+ version = ""
323
424
 
324
- cpp_flags = f'{version} --std=c++17 -fno-rtti -D{cuda_enabled} -D{mathdx_enabled} -D{cuda_compat_enabled} -fPIC -fvisibility=hidden -D_GLIBCXX_USE_CXX11_ABI=0 -I"{native_dir}" {includes} '
425
+ cpp_flags = f'-Werror -Wuninitialized {version} --std=c++17 -fno-rtti -D{cuda_enabled} -D{mathdx_enabled} -D{cuda_compat_enabled} -fPIC -fvisibility=hidden -D_GLIBCXX_USE_CXX11_ABI=0 -I"{native_dir}" {includes} '
325
426
 
326
427
  if mode == "debug":
327
428
  cpp_flags += "-O0 -g -D_DEBUG -DWP_ENABLE_DEBUG=1 -fkeep-inline-functions"
@@ -342,17 +443,23 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, arch, mode=None
342
443
  cpp_out = cpp_path + ".o"
343
444
  ld_inputs.append(quote(cpp_out))
344
445
 
345
- build_cmd = f'g++ {cpp_flags} -c "{cpp_path}" -o "{cpp_out}"'
446
+ build_cmd = f'{cpp_compiler} {cpp_flags} -c "{cpp_path}" -o "{cpp_out}"'
346
447
  run_cmd(build_cmd)
347
448
 
348
449
  if cu_path:
349
450
  cu_out = cu_path + ".o"
350
451
 
351
- if mode == "debug":
352
- cuda_cmd = f'"{cuda_home}/bin/nvcc" --std=c++17 -g -G -O0 --compiler-options -fPIC,-fvisibility=hidden -D_DEBUG -D_ITERATOR_DEBUG_LEVEL=0 -line-info {" ".join(nvcc_opts)} -DWP_ENABLE_CUDA=1 -I"{native_dir}" -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'
353
-
354
- elif mode == "release":
355
- cuda_cmd = f'"{cuda_home}/bin/nvcc" --std=c++17 -O3 --compiler-options -fPIC,-fvisibility=hidden {" ".join(nvcc_opts)} -DNDEBUG -DWP_ENABLE_CUDA=1 -I"{native_dir}" -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'
452
+ if cuda_compiler == "nvcc":
453
+ if mode == "debug":
454
+ cuda_cmd = f'"{cuda_home}/bin/nvcc" --std=c++17 -g -G -O0 --compiler-options -fPIC,-fvisibility=hidden -D_DEBUG -D_ITERATOR_DEBUG_LEVEL=0 -line-info {" ".join(nvcc_opts)} -DWP_ENABLE_CUDA=1 -I"{native_dir}" -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'
455
+ elif mode == "release":
456
+ cuda_cmd = f'"{cuda_home}/bin/nvcc" --std=c++17 -O3 --compiler-options -fPIC,-fvisibility=hidden {" ".join(nvcc_opts)} -DNDEBUG -DWP_ENABLE_CUDA=1 -I"{native_dir}" -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'
457
+ else:
458
+ # Use Clang compiler
459
+ if mode == "debug":
460
+ cuda_cmd = f'clang++ -Werror -Wuninitialized -Wno-unknown-cuda-version {" ".join(clang_opts)} -g -O0 -fPIC -fvisibility=hidden -D_DEBUG -D_ITERATOR_DEBUG_LEVEL=0 -DWP_ENABLE_CUDA=1 -I"{native_dir}" -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'
461
+ elif mode == "release":
462
+ cuda_cmd = f'clang++ -Werror -Wuninitialized -Wno-unknown-cuda-version {" ".join(clang_opts)} -O3 -fPIC -fvisibility=hidden -DNDEBUG -DWP_ENABLE_CUDA=1 -I"{native_dir}" -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'
356
463
 
357
464
  with ScopedTimer("build_cuda", active=args.verbose):
358
465
  run_cmd(cuda_cmd)
@@ -374,7 +481,7 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, arch, mode=None
374
481
 
375
482
  with ScopedTimer("link", active=args.verbose):
376
483
  origin = "@loader_path" if (sys.platform == "darwin") else "$ORIGIN"
377
- link_cmd = f"g++ {version} -shared -Wl,-rpath,'{origin}' {opt_no_undefined} {opt_exclude_libs} -o '{dll_path}' {' '.join(ld_inputs + libs)}"
484
+ link_cmd = f"{cpp_compiler} {version} -shared -Wl,-rpath,'{origin}' {opt_no_undefined} {opt_exclude_libs} -o '{dll_path}' {' '.join(ld_inputs + libs)}"
378
485
  run_cmd(link_cmd)
379
486
 
380
487
  # Strip symbols to reduce the binary size
@@ -389,17 +496,14 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, arch, mode=None
389
496
 
390
497
 
391
498
  def build_dll(args, dll_path, cpp_paths, cu_path, libs=None):
392
- if libs is None:
393
- libs = []
394
-
395
499
  if sys.platform == "darwin":
396
500
  # create a universal binary by combining x86-64 and AArch64 builds
397
- build_dll_for_arch(args, dll_path + "-x86_64", cpp_paths, cu_path, libs, "x86_64")
398
- build_dll_for_arch(args, dll_path + "-aarch64", cpp_paths, cu_path, libs, "aarch64")
501
+ build_dll_for_arch(args, dll_path + "-x86_64", cpp_paths, cu_path, "x86_64", libs)
502
+ build_dll_for_arch(args, dll_path + "-aarch64", cpp_paths, cu_path, "aarch64", libs)
399
503
 
400
504
  run_cmd(f"lipo -create -output {dll_path} {dll_path}-x86_64 {dll_path}-aarch64")
401
505
  os.remove(f"{dll_path}-x86_64")
402
506
  os.remove(f"{dll_path}-aarch64")
403
507
 
404
508
  else:
405
- build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, machine_architecture())
509
+ build_dll_for_arch(args, dll_path, cpp_paths, cu_path, machine_architecture(), libs)