warp-lang 1.7.2rc1__py3-none-macosx_10_13_universal2.whl → 1.8.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. Click here for more details.

Files changed (180)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +241 -252
  6. warp/build_dll.py +125 -26
  7. warp/builtins.py +1907 -384
  8. warp/codegen.py +257 -101
  9. warp/config.py +12 -1
  10. warp/constants.py +1 -1
  11. warp/context.py +657 -223
  12. warp/dlpack.py +1 -1
  13. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  14. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  15. warp/examples/core/example_sample_mesh.py +1 -1
  16. warp/examples/core/example_spin_lock.py +93 -0
  17. warp/examples/core/example_work_queue.py +118 -0
  18. warp/examples/fem/example_adaptive_grid.py +5 -5
  19. warp/examples/fem/example_apic_fluid.py +1 -1
  20. warp/examples/fem/example_burgers.py +1 -1
  21. warp/examples/fem/example_convection_diffusion.py +9 -6
  22. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  23. warp/examples/fem/example_deformed_geometry.py +1 -1
  24. warp/examples/fem/example_diffusion.py +2 -2
  25. warp/examples/fem/example_diffusion_3d.py +1 -1
  26. warp/examples/fem/example_distortion_energy.py +1 -1
  27. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  28. warp/examples/fem/example_magnetostatics.py +5 -3
  29. warp/examples/fem/example_mixed_elasticity.py +5 -3
  30. warp/examples/fem/example_navier_stokes.py +11 -9
  31. warp/examples/fem/example_nonconforming_contact.py +5 -3
  32. warp/examples/fem/example_streamlines.py +8 -3
  33. warp/examples/fem/utils.py +9 -8
  34. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  35. warp/examples/optim/example_drone.py +1 -1
  36. warp/examples/sim/example_cloth.py +1 -1
  37. warp/examples/sim/example_cloth_self_contact.py +48 -54
  38. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  39. warp/examples/tile/example_tile_cholesky.py +2 -1
  40. warp/examples/tile/example_tile_convolution.py +1 -1
  41. warp/examples/tile/example_tile_filtering.py +1 -1
  42. warp/examples/tile/example_tile_matmul.py +1 -1
  43. warp/examples/tile/example_tile_mlp.py +2 -0
  44. warp/fabric.py +7 -7
  45. warp/fem/__init__.py +5 -0
  46. warp/fem/adaptivity.py +1 -1
  47. warp/fem/cache.py +152 -63
  48. warp/fem/dirichlet.py +2 -2
  49. warp/fem/domain.py +136 -6
  50. warp/fem/field/field.py +141 -99
  51. warp/fem/field/nodal_field.py +85 -39
  52. warp/fem/field/virtual.py +97 -52
  53. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  54. warp/fem/geometry/closest_point.py +13 -0
  55. warp/fem/geometry/deformed_geometry.py +102 -40
  56. warp/fem/geometry/element.py +56 -2
  57. warp/fem/geometry/geometry.py +323 -22
  58. warp/fem/geometry/grid_2d.py +157 -62
  59. warp/fem/geometry/grid_3d.py +116 -20
  60. warp/fem/geometry/hexmesh.py +86 -20
  61. warp/fem/geometry/nanogrid.py +166 -86
  62. warp/fem/geometry/partition.py +59 -25
  63. warp/fem/geometry/quadmesh.py +86 -135
  64. warp/fem/geometry/tetmesh.py +47 -119
  65. warp/fem/geometry/trimesh.py +77 -270
  66. warp/fem/integrate.py +107 -52
  67. warp/fem/linalg.py +25 -58
  68. warp/fem/operator.py +124 -27
  69. warp/fem/quadrature/pic_quadrature.py +36 -14
  70. warp/fem/quadrature/quadrature.py +40 -16
  71. warp/fem/space/__init__.py +1 -1
  72. warp/fem/space/basis_function_space.py +66 -46
  73. warp/fem/space/basis_space.py +17 -4
  74. warp/fem/space/dof_mapper.py +1 -1
  75. warp/fem/space/function_space.py +2 -2
  76. warp/fem/space/grid_2d_function_space.py +4 -1
  77. warp/fem/space/hexmesh_function_space.py +4 -2
  78. warp/fem/space/nanogrid_function_space.py +3 -1
  79. warp/fem/space/partition.py +11 -2
  80. warp/fem/space/quadmesh_function_space.py +4 -1
  81. warp/fem/space/restriction.py +5 -2
  82. warp/fem/space/shape/__init__.py +10 -8
  83. warp/fem/space/tetmesh_function_space.py +4 -1
  84. warp/fem/space/topology.py +52 -21
  85. warp/fem/space/trimesh_function_space.py +4 -1
  86. warp/fem/utils.py +53 -8
  87. warp/jax.py +1 -2
  88. warp/jax_experimental/ffi.py +12 -17
  89. warp/jax_experimental/xla_ffi.py +37 -24
  90. warp/math.py +171 -1
  91. warp/native/array.h +99 -0
  92. warp/native/builtin.h +174 -31
  93. warp/native/coloring.cpp +1 -1
  94. warp/native/exports.h +118 -63
  95. warp/native/intersect.h +3 -3
  96. warp/native/mat.h +5 -10
  97. warp/native/mathdx.cpp +11 -5
  98. warp/native/matnn.h +1 -123
  99. warp/native/quat.h +28 -4
  100. warp/native/sparse.cpp +121 -258
  101. warp/native/sparse.cu +181 -274
  102. warp/native/spatial.h +305 -17
  103. warp/native/tile.h +583 -72
  104. warp/native/tile_radix_sort.h +1108 -0
  105. warp/native/tile_reduce.h +237 -2
  106. warp/native/tile_scan.h +240 -0
  107. warp/native/tuple.h +189 -0
  108. warp/native/vec.h +6 -16
  109. warp/native/warp.cpp +36 -4
  110. warp/native/warp.cu +574 -51
  111. warp/native/warp.h +47 -74
  112. warp/optim/linear.py +5 -1
  113. warp/paddle.py +7 -8
  114. warp/py.typed +0 -0
  115. warp/render/render_opengl.py +58 -29
  116. warp/render/render_usd.py +124 -61
  117. warp/sim/__init__.py +9 -0
  118. warp/sim/collide.py +252 -78
  119. warp/sim/graph_coloring.py +8 -1
  120. warp/sim/import_mjcf.py +4 -3
  121. warp/sim/import_usd.py +11 -7
  122. warp/sim/integrator.py +5 -2
  123. warp/sim/integrator_euler.py +1 -1
  124. warp/sim/integrator_featherstone.py +1 -1
  125. warp/sim/integrator_vbd.py +751 -320
  126. warp/sim/integrator_xpbd.py +1 -1
  127. warp/sim/model.py +265 -260
  128. warp/sim/utils.py +10 -7
  129. warp/sparse.py +303 -166
  130. warp/tape.py +52 -51
  131. warp/tests/cuda/test_conditional_captures.py +1046 -0
  132. warp/tests/cuda/test_streams.py +1 -1
  133. warp/tests/geometry/test_volume.py +2 -2
  134. warp/tests/interop/test_dlpack.py +9 -9
  135. warp/tests/interop/test_jax.py +0 -1
  136. warp/tests/run_coverage_serial.py +1 -1
  137. warp/tests/sim/disabled_kinematics.py +2 -2
  138. warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
  139. warp/tests/sim/test_collision.py +159 -51
  140. warp/tests/sim/test_coloring.py +15 -1
  141. warp/tests/test_array.py +254 -2
  142. warp/tests/test_array_reduce.py +2 -2
  143. warp/tests/test_atomic_cas.py +299 -0
  144. warp/tests/test_codegen.py +142 -19
  145. warp/tests/test_conditional.py +47 -1
  146. warp/tests/test_ctypes.py +0 -20
  147. warp/tests/test_devices.py +8 -0
  148. warp/tests/test_fabricarray.py +4 -2
  149. warp/tests/test_fem.py +58 -25
  150. warp/tests/test_func.py +42 -1
  151. warp/tests/test_grad.py +1 -1
  152. warp/tests/test_lerp.py +1 -3
  153. warp/tests/test_map.py +481 -0
  154. warp/tests/test_mat.py +1 -24
  155. warp/tests/test_quat.py +6 -15
  156. warp/tests/test_rounding.py +10 -38
  157. warp/tests/test_runlength_encode.py +7 -7
  158. warp/tests/test_smoothstep.py +1 -1
  159. warp/tests/test_sparse.py +51 -2
  160. warp/tests/test_spatial.py +507 -1
  161. warp/tests/test_struct.py +2 -2
  162. warp/tests/test_tuple.py +265 -0
  163. warp/tests/test_types.py +2 -2
  164. warp/tests/test_utils.py +24 -18
  165. warp/tests/tile/test_tile.py +420 -1
  166. warp/tests/tile/test_tile_mathdx.py +518 -14
  167. warp/tests/tile/test_tile_reduce.py +213 -0
  168. warp/tests/tile/test_tile_shared_memory.py +130 -1
  169. warp/tests/tile/test_tile_sort.py +117 -0
  170. warp/tests/unittest_suites.py +4 -6
  171. warp/types.py +462 -308
  172. warp/utils.py +647 -86
  173. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
  174. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/RECORD +177 -165
  175. warp/stubs.py +0 -3381
  176. warp/tests/sim/test_xpbd.py +0 -399
  177. warp/tests/test_mlp.py +0 -282
  178. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
  179. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
  180. {warp_lang-1.7.2rc1.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
warp/build_dll.py CHANGED
@@ -17,16 +17,18 @@ import os
17
17
  import platform
18
18
  import subprocess
19
19
  import sys
20
+ from typing import List, Optional
20
21
 
21
22
  from warp.utils import ScopedTimer
22
23
 
23
24
  verbose_cmd = True # print command lines before executing them
24
25
 
25
26
 
26
- # returns a canonical machine architecture string
27
- # - "x86_64" for x86-64, aka. AMD64, aka. x64
28
- # - "aarch64" for AArch64, aka. ARM64
29
27
  def machine_architecture() -> str:
28
+ """Return a canonical machine architecture string.
29
+ - "x86_64" for x86-64, aka. AMD64, aka. x64
30
+ - "aarch64" for AArch64, aka. ARM64
31
+ """
30
32
  machine = platform.machine()
31
33
  if machine == "x86_64" or machine == "AMD64":
32
34
  return "x86_64"
@@ -103,10 +105,8 @@ def find_host_compiler():
103
105
  cl_required_major = 14
104
106
  cl_required_minor = 29
105
107
 
106
- if (
107
- (int(cl_version[0]) < cl_required_major)
108
- or (int(cl_version[0]) == cl_required_major)
109
- and int(cl_version[1]) < cl_required_minor
108
+ if int(cl_version[0]) < cl_required_major or (
109
+ (int(cl_version[0]) == cl_required_major) and (int(cl_version[1]) < cl_required_minor)
110
110
  ):
111
111
  print(
112
112
  f"Warp: MSVC found but compiler version too old, found {cl_version[0]}.{cl_version[1]}, but must be {cl_required_major}.{cl_required_minor} or higher, kernel host compilation will be disabled."
@@ -142,22 +142,54 @@ def quote(path):
142
142
  return '"' + path + '"'
143
143
 
144
144
 
145
- def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, arch, mode=None):
145
+ def add_llvm_bin_to_path(args):
146
+ """Add the LLVM bin directory to the PATH environment variable if it's set.
147
+
148
+ Args:
149
+ args: The argument namespace containing llvm_path.
150
+
151
+ Returns:
152
+ ``True`` if the PATH was updated, ``False`` otherwise.
153
+ """
154
+ if not hasattr(args, "llvm_path") or not args.llvm_path:
155
+ return False
156
+
157
+ # Construct the bin directory path
158
+ llvm_bin_path = os.path.join(args.llvm_path, "bin")
159
+
160
+ # Check if the directory exists
161
+ if not os.path.isdir(llvm_bin_path):
162
+ print(f"Warning: LLVM bin directory not found at {llvm_bin_path}")
163
+ return False
164
+
165
+ # Add to PATH environment variable
166
+ os.environ["PATH"] = llvm_bin_path + os.pathsep + os.environ.get("PATH", "")
167
+
168
+ print(f"Added {llvm_bin_path} to PATH")
169
+ return True
170
+
171
+
172
+ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, arch, libs: Optional[List[str]] = None, mode=None):
146
173
  mode = args.mode if (mode is None) else mode
147
174
  cuda_home = args.cuda_path
148
175
  cuda_cmd = None
149
176
 
177
+ # Add LLVM bin directory to PATH
178
+ add_llvm_bin_to_path(args)
179
+
150
180
  if args.quick or cu_path is None:
151
181
  cuda_compat_enabled = "WP_ENABLE_CUDA_COMPATIBILITY=0"
152
182
  else:
153
183
  cuda_compat_enabled = "WP_ENABLE_CUDA_COMPATIBILITY=1"
154
184
 
185
+ if libs is None:
186
+ libs = []
187
+
155
188
  import pathlib
156
189
 
157
190
  warp_home_path = pathlib.Path(__file__).parent
158
191
  warp_home = warp_home_path.resolve()
159
192
 
160
- # output stale, rebuild
161
193
  if args.verbose:
162
194
  print(f"Building {dll_path}")
163
195
 
@@ -176,11 +208,16 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, arch, mode=None
176
208
  print("MathDx support requires at least CUDA 12, skipping")
177
209
  args.libmathdx_path = None
178
210
 
211
+ # NVCC gencode options
179
212
  gencode_opts = []
180
213
 
214
+ # Clang architecture flags
215
+ clang_arch_flags = []
216
+
181
217
  if args.quick:
182
218
  # minimum supported architectures (PTX)
183
219
  gencode_opts += ["-gencode=arch=compute_52,code=compute_52", "-gencode=arch=compute_75,code=compute_75"]
220
+ clang_arch_flags += ["--cuda-gpu-arch=sm_52", "--cuda-gpu-arch=sm_75"]
184
221
  else:
185
222
  # generate code for all supported architectures
186
223
  gencode_opts += [
@@ -193,6 +230,19 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, arch, mode=None
193
230
  "-gencode=arch=compute_80,code=sm_80", # Ampere
194
231
  "-gencode=arch=compute_86,code=sm_86",
195
232
  ]
233
+
234
+ # TODO: Get this working with sm_52, sm_60, sm_61
235
+ clang_arch_flags += [
236
+ # SASS for supported desktop/datacenter architectures
237
+ "--cuda-gpu-arch=sm_52",
238
+ "--cuda-gpu-arch=sm_60",
239
+ "--cuda-gpu-arch=sm_61",
240
+ "--cuda-gpu-arch=sm_70", # Volta
241
+ "--cuda-gpu-arch=sm_75", # Turing
242
+ "--cuda-gpu-arch=sm_80", # Ampere
243
+ "--cuda-gpu-arch=sm_86",
244
+ ]
245
+
196
246
  if arch == "aarch64" and sys.platform == "linux":
197
247
  gencode_opts += [
198
248
  # SASS for supported mobile architectures (e.g. Tegra/Jetson)
@@ -202,6 +252,14 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, arch, mode=None
202
252
  "-gencode=arch=compute_87,code=sm_87", # Orin
203
253
  ]
204
254
 
255
+ clang_arch_flags += [
256
+ # SASS for supported mobile architectures
257
+ "--cuda-gpu-arch=sm_53", # X1
258
+ "--cuda-gpu-arch=sm_62", # X2
259
+ "--cuda-gpu-arch=sm_72", # Xavier
260
+ "--cuda-gpu-arch=sm_87", # Orin
261
+ ]
262
+
205
263
  if ctk_version >= (12, 8):
206
264
  # Support for Blackwell is available with CUDA Toolkit 12.8+
207
265
  gencode_opts += [
@@ -211,6 +269,13 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, arch, mode=None
211
269
  "-gencode=arch=compute_120,code=sm_120", # Blackwell
212
270
  "-gencode=arch=compute_120,code=compute_120", # PTX for future hardware
213
271
  ]
272
+
273
+ clang_arch_flags += [
274
+ "--cuda-gpu-arch=sm_89", # Ada
275
+ "--cuda-gpu-arch=sm_90", # Hopper
276
+ "--cuda-gpu-arch=sm_100", # Blackwell
277
+ "--cuda-gpu-arch=sm_120", # Blackwell
278
+ ]
214
279
  elif ctk_version >= (11, 8):
215
280
  # Support for Ada and Hopper is available with CUDA Toolkit 11.8+
216
281
  gencode_opts += [
@@ -218,16 +283,40 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, arch, mode=None
218
283
  "-gencode=arch=compute_90,code=sm_90", # Hopper
219
284
  "-gencode=arch=compute_90,code=compute_90", # PTX for future hardware
220
285
  ]
286
+
287
+ clang_arch_flags += [
288
+ "--cuda-gpu-arch=sm_89", # Ada
289
+ "--cuda-gpu-arch=sm_90", # Hopper
290
+ ]
221
291
  else:
222
292
  gencode_opts += [
223
293
  "-gencode=arch=compute_86,code=compute_86", # PTX for future hardware
224
294
  ]
225
295
 
226
- nvcc_opts = gencode_opts + [
296
+ clang_arch_flags += [
297
+ "--cuda-gpu-arch=sm_86", # PTX for future hardware
298
+ ]
299
+
300
+ nvcc_opts = [
301
+ *gencode_opts,
227
302
  "-t0", # multithreaded compilation
228
303
  "--extended-lambda",
229
304
  ]
230
305
 
306
+ # Clang options
307
+ clang_opts = [
308
+ *clang_arch_flags,
309
+ "-std=c++17",
310
+ "-xcuda",
311
+ f'--cuda-path="{cuda_home}"',
312
+ ]
313
+
314
+ if args.compile_time_trace:
315
+ if ctk_version >= (12, 8):
316
+ nvcc_opts.append("--fdevice-time-trace=build_lib_compile-time-trace")
317
+ else:
318
+ print("Warp warning: CUDA version is less than 12.8, compile_time_trace is not supported")
319
+
231
320
  if args.fast_math:
232
321
  nvcc_opts.append("--use_fast_math")
233
322
 
@@ -304,13 +393,17 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, arch, mode=None
304
393
  )
305
394
 
306
395
  if args.libmathdx_path:
307
- linkopts.append(f'nvJitLink_static.lib /LIBPATH:"{args.libmathdx_path}/lib" mathdx_static.lib')
396
+ linkopts.append(f'nvJitLink_static.lib /LIBPATH:"{args.libmathdx_path}/lib/x64" mathdx_static.lib')
308
397
 
309
398
  with ScopedTimer("link", active=args.verbose):
310
399
  link_cmd = f'"{host_linker}" {" ".join(linkopts + libs)} /out:"{dll_path}"'
311
400
  run_cmd(link_cmd)
312
401
 
313
402
  else:
403
+ # Unix compilation
404
+ cuda_compiler = "clang++" if getattr(args, "clang_build_toolchain", False) else "nvcc"
405
+ cpp_compiler = "clang++" if getattr(args, "clang_build_toolchain", False) else "g++"
406
+
314
407
  cpp_includes = f' -I"{warp_home_path.parent}/external/llvm-project/out/install/{mode}-{arch}/include"'
315
408
  cpp_includes += f' -I"{warp_home_path.parent}/_build/host-deps/llvm-project/release-{arch}/include"'
316
409
  cuda_includes = f' -I"{cuda_home}/include"' if cu_path else ""
@@ -319,9 +412,12 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, arch, mode=None
319
412
  if sys.platform == "darwin":
320
413
  version = f"--target={arch}-apple-macos11"
321
414
  else:
322
- version = "-fabi-version=13" # GCC 8.2+
415
+ if cpp_compiler == "g++":
416
+ version = "-fabi-version=13" # GCC 8.2+
417
+ else:
418
+ version = ""
323
419
 
324
- cpp_flags = f'{version} --std=c++17 -fno-rtti -D{cuda_enabled} -D{mathdx_enabled} -D{cuda_compat_enabled} -fPIC -fvisibility=hidden -D_GLIBCXX_USE_CXX11_ABI=0 -I"{native_dir}" {includes} '
420
+ cpp_flags = f'-Werror -Wuninitialized {version} --std=c++17 -fno-rtti -D{cuda_enabled} -D{mathdx_enabled} -D{cuda_compat_enabled} -fPIC -fvisibility=hidden -D_GLIBCXX_USE_CXX11_ABI=0 -I"{native_dir}" {includes} '
325
421
 
326
422
  if mode == "debug":
327
423
  cpp_flags += "-O0 -g -D_DEBUG -DWP_ENABLE_DEBUG=1 -fkeep-inline-functions"
@@ -342,17 +438,23 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, arch, mode=None
342
438
  cpp_out = cpp_path + ".o"
343
439
  ld_inputs.append(quote(cpp_out))
344
440
 
345
- build_cmd = f'g++ {cpp_flags} -c "{cpp_path}" -o "{cpp_out}"'
441
+ build_cmd = f'{cpp_compiler} {cpp_flags} -c "{cpp_path}" -o "{cpp_out}"'
346
442
  run_cmd(build_cmd)
347
443
 
348
444
  if cu_path:
349
445
  cu_out = cu_path + ".o"
350
446
 
351
- if mode == "debug":
352
- cuda_cmd = f'"{cuda_home}/bin/nvcc" --std=c++17 -g -G -O0 --compiler-options -fPIC,-fvisibility=hidden -D_DEBUG -D_ITERATOR_DEBUG_LEVEL=0 -line-info {" ".join(nvcc_opts)} -DWP_ENABLE_CUDA=1 -I"{native_dir}" -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'
353
-
354
- elif mode == "release":
355
- cuda_cmd = f'"{cuda_home}/bin/nvcc" --std=c++17 -O3 --compiler-options -fPIC,-fvisibility=hidden {" ".join(nvcc_opts)} -DNDEBUG -DWP_ENABLE_CUDA=1 -I"{native_dir}" -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'
447
+ if cuda_compiler == "nvcc":
448
+ if mode == "debug":
449
+ cuda_cmd = f'"{cuda_home}/bin/nvcc" --std=c++17 -g -G -O0 --compiler-options -fPIC,-fvisibility=hidden -D_DEBUG -D_ITERATOR_DEBUG_LEVEL=0 -line-info {" ".join(nvcc_opts)} -DWP_ENABLE_CUDA=1 -I"{native_dir}" -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'
450
+ elif mode == "release":
451
+ cuda_cmd = f'"{cuda_home}/bin/nvcc" --std=c++17 -O3 --compiler-options -fPIC,-fvisibility=hidden {" ".join(nvcc_opts)} -DNDEBUG -DWP_ENABLE_CUDA=1 -I"{native_dir}" -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'
452
+ else:
453
+ # Use Clang compiler
454
+ if mode == "debug":
455
+ cuda_cmd = f'clang++ -Werror -Wuninitialized -Wno-unknown-cuda-version {" ".join(clang_opts)} -g -O0 -fPIC -fvisibility=hidden -D_DEBUG -D_ITERATOR_DEBUG_LEVEL=0 -DWP_ENABLE_CUDA=1 -I"{native_dir}" -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'
456
+ elif mode == "release":
457
+ cuda_cmd = f'clang++ -Werror -Wuninitialized -Wno-unknown-cuda-version {" ".join(clang_opts)} -O3 -fPIC -fvisibility=hidden -DNDEBUG -DWP_ENABLE_CUDA=1 -I"{native_dir}" -D{mathdx_enabled} {libmathdx_includes} -o "{cu_out}" -c "{cu_path}"'
356
458
 
357
459
  with ScopedTimer("build_cuda", active=args.verbose):
358
460
  run_cmd(cuda_cmd)
@@ -374,7 +476,7 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, arch, mode=None
374
476
 
375
477
  with ScopedTimer("link", active=args.verbose):
376
478
  origin = "@loader_path" if (sys.platform == "darwin") else "$ORIGIN"
377
- link_cmd = f"g++ {version} -shared -Wl,-rpath,'{origin}' {opt_no_undefined} {opt_exclude_libs} -o '{dll_path}' {' '.join(ld_inputs + libs)}"
479
+ link_cmd = f"{cpp_compiler} {version} -shared -Wl,-rpath,'{origin}' {opt_no_undefined} {opt_exclude_libs} -o '{dll_path}' {' '.join(ld_inputs + libs)}"
378
480
  run_cmd(link_cmd)
379
481
 
380
482
  # Strip symbols to reduce the binary size
@@ -389,17 +491,14 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, arch, mode=None
389
491
 
390
492
 
391
493
  def build_dll(args, dll_path, cpp_paths, cu_path, libs=None):
392
- if libs is None:
393
- libs = []
394
-
395
494
  if sys.platform == "darwin":
396
495
  # create a universal binary by combining x86-64 and AArch64 builds
397
- build_dll_for_arch(args, dll_path + "-x86_64", cpp_paths, cu_path, libs, "x86_64")
398
- build_dll_for_arch(args, dll_path + "-aarch64", cpp_paths, cu_path, libs, "aarch64")
496
+ build_dll_for_arch(args, dll_path + "-x86_64", cpp_paths, cu_path, "x86_64", libs)
497
+ build_dll_for_arch(args, dll_path + "-aarch64", cpp_paths, cu_path, "aarch64", libs)
399
498
 
400
499
  run_cmd(f"lipo -create -output {dll_path} {dll_path}-x86_64 {dll_path}-aarch64")
401
500
  os.remove(f"{dll_path}-x86_64")
402
501
  os.remove(f"{dll_path}-aarch64")
403
502
 
404
503
  else:
405
- build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, machine_architecture())
504
+ build_dll_for_arch(args, dll_path, cpp_paths, cu_path, machine_architecture(), libs)