warp-lang 1.8.1__py3-none-manylinux_2_34_aarch64.whl → 1.9.1__py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang might be problematic. (The original page links to more details here.)

Files changed (141)
  1. warp/__init__.py +282 -103
  2. warp/__init__.pyi +1904 -114
  3. warp/bin/warp-clang.so +0 -0
  4. warp/bin/warp.so +0 -0
  5. warp/build.py +93 -30
  6. warp/build_dll.py +331 -101
  7. warp/builtins.py +1244 -160
  8. warp/codegen.py +317 -206
  9. warp/config.py +1 -1
  10. warp/context.py +1465 -789
  11. warp/examples/core/example_marching_cubes.py +1 -0
  12. warp/examples/core/example_render_opengl.py +100 -3
  13. warp/examples/fem/example_apic_fluid.py +98 -52
  14. warp/examples/fem/example_convection_diffusion_dg.py +25 -4
  15. warp/examples/fem/example_diffusion_mgpu.py +8 -3
  16. warp/examples/fem/utils.py +68 -22
  17. warp/examples/interop/example_jax_kernel.py +2 -1
  18. warp/fabric.py +1 -1
  19. warp/fem/cache.py +27 -19
  20. warp/fem/domain.py +2 -2
  21. warp/fem/field/nodal_field.py +2 -2
  22. warp/fem/field/virtual.py +264 -166
  23. warp/fem/geometry/geometry.py +5 -5
  24. warp/fem/integrate.py +129 -51
  25. warp/fem/space/restriction.py +4 -0
  26. warp/fem/space/shape/tet_shape_function.py +3 -10
  27. warp/jax_experimental/custom_call.py +25 -2
  28. warp/jax_experimental/ffi.py +22 -1
  29. warp/jax_experimental/xla_ffi.py +16 -7
  30. warp/marching_cubes.py +708 -0
  31. warp/native/array.h +99 -4
  32. warp/native/builtin.h +86 -9
  33. warp/native/bvh.cpp +64 -28
  34. warp/native/bvh.cu +58 -58
  35. warp/native/bvh.h +2 -2
  36. warp/native/clang/clang.cpp +7 -7
  37. warp/native/coloring.cpp +8 -2
  38. warp/native/crt.cpp +2 -2
  39. warp/native/crt.h +3 -5
  40. warp/native/cuda_util.cpp +41 -10
  41. warp/native/cuda_util.h +10 -4
  42. warp/native/exports.h +1842 -1908
  43. warp/native/fabric.h +2 -1
  44. warp/native/hashgrid.cpp +37 -37
  45. warp/native/hashgrid.cu +2 -2
  46. warp/native/initializer_array.h +1 -1
  47. warp/native/intersect.h +2 -2
  48. warp/native/mat.h +1910 -116
  49. warp/native/mathdx.cpp +43 -43
  50. warp/native/mesh.cpp +24 -24
  51. warp/native/mesh.cu +26 -26
  52. warp/native/mesh.h +4 -2
  53. warp/native/nanovdb/GridHandle.h +179 -12
  54. warp/native/nanovdb/HostBuffer.h +8 -7
  55. warp/native/nanovdb/NanoVDB.h +517 -895
  56. warp/native/nanovdb/NodeManager.h +323 -0
  57. warp/native/nanovdb/PNanoVDB.h +2 -2
  58. warp/native/quat.h +331 -14
  59. warp/native/range.h +7 -1
  60. warp/native/reduce.cpp +10 -10
  61. warp/native/reduce.cu +13 -14
  62. warp/native/runlength_encode.cpp +2 -2
  63. warp/native/runlength_encode.cu +5 -5
  64. warp/native/scan.cpp +3 -3
  65. warp/native/scan.cu +4 -4
  66. warp/native/sort.cpp +10 -10
  67. warp/native/sort.cu +40 -31
  68. warp/native/sort.h +2 -0
  69. warp/native/sparse.cpp +8 -8
  70. warp/native/sparse.cu +13 -13
  71. warp/native/spatial.h +366 -17
  72. warp/native/temp_buffer.h +2 -2
  73. warp/native/tile.h +471 -82
  74. warp/native/vec.h +328 -14
  75. warp/native/volume.cpp +54 -54
  76. warp/native/volume.cu +1 -1
  77. warp/native/volume.h +2 -1
  78. warp/native/volume_builder.cu +30 -37
  79. warp/native/warp.cpp +150 -149
  80. warp/native/warp.cu +377 -216
  81. warp/native/warp.h +227 -226
  82. warp/optim/linear.py +736 -271
  83. warp/render/imgui_manager.py +289 -0
  84. warp/render/render_opengl.py +99 -18
  85. warp/render/render_usd.py +1 -0
  86. warp/sim/graph_coloring.py +2 -2
  87. warp/sparse.py +558 -175
  88. warp/tests/aux_test_module_aot.py +7 -0
  89. warp/tests/cuda/test_async.py +3 -3
  90. warp/tests/cuda/test_conditional_captures.py +101 -0
  91. warp/tests/geometry/test_hash_grid.py +38 -0
  92. warp/tests/geometry/test_marching_cubes.py +233 -12
  93. warp/tests/interop/test_jax.py +608 -28
  94. warp/tests/sim/test_coloring.py +6 -6
  95. warp/tests/test_array.py +58 -5
  96. warp/tests/test_codegen.py +4 -3
  97. warp/tests/test_context.py +8 -15
  98. warp/tests/test_enum.py +136 -0
  99. warp/tests/test_examples.py +2 -2
  100. warp/tests/test_fem.py +49 -6
  101. warp/tests/test_fixedarray.py +229 -0
  102. warp/tests/test_func.py +18 -15
  103. warp/tests/test_future_annotations.py +7 -5
  104. warp/tests/test_linear_solvers.py +30 -0
  105. warp/tests/test_map.py +15 -1
  106. warp/tests/test_mat.py +1518 -378
  107. warp/tests/test_mat_assign_copy.py +178 -0
  108. warp/tests/test_mat_constructors.py +574 -0
  109. warp/tests/test_module_aot.py +287 -0
  110. warp/tests/test_print.py +69 -0
  111. warp/tests/test_quat.py +140 -34
  112. warp/tests/test_quat_assign_copy.py +145 -0
  113. warp/tests/test_reload.py +2 -1
  114. warp/tests/test_sparse.py +71 -0
  115. warp/tests/test_spatial.py +140 -34
  116. warp/tests/test_spatial_assign_copy.py +160 -0
  117. warp/tests/test_struct.py +43 -3
  118. warp/tests/test_tuple.py +96 -0
  119. warp/tests/test_types.py +61 -20
  120. warp/tests/test_vec.py +179 -34
  121. warp/tests/test_vec_assign_copy.py +143 -0
  122. warp/tests/tile/test_tile.py +245 -18
  123. warp/tests/tile/test_tile_cholesky.py +605 -0
  124. warp/tests/tile/test_tile_load.py +169 -0
  125. warp/tests/tile/test_tile_mathdx.py +2 -558
  126. warp/tests/tile/test_tile_matmul.py +1 -1
  127. warp/tests/tile/test_tile_mlp.py +1 -1
  128. warp/tests/tile/test_tile_shared_memory.py +5 -5
  129. warp/tests/unittest_suites.py +6 -0
  130. warp/tests/walkthrough_debug.py +1 -1
  131. warp/thirdparty/unittest_parallel.py +108 -9
  132. warp/types.py +571 -267
  133. warp/utils.py +68 -86
  134. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/METADATA +29 -69
  135. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/RECORD +138 -128
  136. warp/native/marching.cpp +0 -19
  137. warp/native/marching.cu +0 -514
  138. warp/native/marching.h +0 -19
  139. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/WHEEL +0 -0
  140. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/licenses/LICENSE.md +0 -0
  141. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/top_level.txt +0 -0
warp/bin/warp-clang.so CHANGED
Binary file
warp/bin/warp.so CHANGED
Binary file
warp/build.py CHANGED
@@ -51,7 +51,7 @@ def build_cuda(
51
51
  output_path = output_path.encode("utf-8")
52
52
 
53
53
  if warp.config.llvm_cuda:
54
- warp.context.runtime.llvm.compile_cuda(src, cu_path_bytes, inc_path, output_path, False)
54
+ warp.context.runtime.llvm.wp_compile_cuda(src, cu_path_bytes, inc_path, output_path, False)
55
55
 
56
56
  else:
57
57
  if ltoirs is None:
@@ -67,7 +67,7 @@ def build_cuda(
67
67
  fatbins
68
68
  )
69
69
  arr_link_input_types = (ctypes.c_int * num_link)(*link_input_types)
70
- err = warp.context.runtime.core.cuda_compile_program(
70
+ err = warp.context.runtime.core.wp_cuda_compile_program(
71
71
  src,
72
72
  program_name_bytes,
73
73
  arch,
@@ -96,7 +96,7 @@ def load_cuda(input_path, device):
96
96
  if not device.is_cuda:
97
97
  raise RuntimeError("Not a CUDA device")
98
98
 
99
- return warp.context.runtime.core.cuda_load_module(device.context, input_path.encode("utf-8"))
99
+ return warp.context.runtime.core.wp_cuda_load_module(device.context, input_path.encode("utf-8"))
100
100
 
101
101
 
102
102
  def build_cpu(obj_path, cpp_path, mode="release", verify_fp=False, fast_math=False, fuse_fp=True):
@@ -106,7 +106,7 @@ def build_cpu(obj_path, cpp_path, mode="release", verify_fp=False, fast_math=Fal
106
106
  inc_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "native").encode("utf-8")
107
107
  obj_path = obj_path.encode("utf-8")
108
108
 
109
- err = warp.context.runtime.llvm.compile_cpp(
109
+ err = warp.context.runtime.llvm.wp_compile_cpp(
110
110
  src, cpp_path, inc_path, obj_path, mode == "debug", verify_fp, fuse_fp
111
111
  )
112
112
  if err != 0:
@@ -129,6 +129,15 @@ def init_kernel_cache(path=None):
129
129
  else:
130
130
  cache_root_dir = appdirs.user_cache_dir(appname="warp", appauthor="NVIDIA", version=warp.config.version)
131
131
 
132
+ if os.name == "nt" and os.path.isabs(cache_root_dir) and not cache_root_dir.startswith("\\\\?\\"):
133
+ # Add Windows long-path prefix, accounting for UNC shares.
134
+ if cache_root_dir.startswith("\\\\"):
135
+ # UNC path \\server\share\… → \\?\UNC\server\share\…
136
+ cache_root_dir = "\\\\?\\UNC\\" + cache_root_dir.lstrip("\\")
137
+ else:
138
+ # Drive-letter path C:\… → \\?\C:\…
139
+ cache_root_dir = "\\\\?\\" + cache_root_dir
140
+
132
141
  warp.config.kernel_cache_dir = cache_root_dir
133
142
 
134
143
  os.makedirs(warp.config.kernel_cache_dir, exist_ok=True)
@@ -246,7 +255,12 @@ def _build_lto_base(lto_symbol, compile_func, builder, extra_files=None):
246
255
  the cached file data.
247
256
 
248
257
  Returns:
249
- Tuple containing lto_code_data followed by any extra data from extra_files
258
+ Tuple where the first element is a success flag (``bool``). The second
259
+ element is the LTO code as bytes (or ``None`` on failure).
260
+ If ``extra_files`` is provided, additional elements follow in the same
261
+ order as the keys in ``extra_files``:
262
+ - ``".meta"``: int (shared memory bytes).
263
+ - ``"_fatbin.lto"``: bytes (universal fatbin).
250
264
  """
251
265
  if extra_files is None:
252
266
  extra_files = {}
@@ -283,9 +297,9 @@ def _build_lto_base(lto_symbol, compile_func, builder, extra_files=None):
283
297
 
284
298
  if all_files_cached:
285
299
  if not extra_files:
286
- return (lto_code_data,)
300
+ return (True, lto_code_data)
287
301
  else:
288
- return (lto_code_data, *[extra_files[ext] for ext in extra_files.keys()])
302
+ return (True, lto_code_data, *[extra_files[ext] for ext in extra_files.keys()])
289
303
 
290
304
  # Create process-dependent temporary build directory
291
305
  build_dir = f"{lto_dir}_p{os.getpid()}"
@@ -303,21 +317,24 @@ def _build_lto_base(lto_symbol, compile_func, builder, extra_files=None):
303
317
  for path in temp_file_paths.values():
304
318
  if Path(path).exists():
305
319
  Path(path).unlink()
306
- raise RuntimeError(f"Failed to compile {lto_symbol}")
307
-
308
- # Move outputs to cache
309
- safe_rename(build_dir, lto_dir)
310
-
311
- # If build_dir couldn't be moved by a rename, move the outputs one-by-one to lto_dir
312
- if os.path.exists(lto_dir):
313
- for ext, path in file_paths.items():
314
- if not os.path.exists(path):
315
- try:
316
- # copy output file to the destination lto dir
317
- os.rename(temp_file_paths[ext], path)
318
- except (OSError, FileExistsError):
319
- # another process likely updated the lto dir first
320
- pass
320
+
321
+ outputs[".lto"] = None
322
+ for ext in extra_files.keys():
323
+ outputs[ext] = None
324
+ else:
325
+ # Move outputs to cache
326
+ safe_rename(build_dir, lto_dir)
327
+
328
+ # If build_dir couldn't be moved by a rename, move the outputs one-by-one to lto_dir
329
+ if os.path.exists(lto_dir):
330
+ for ext, path in file_paths.items():
331
+ if not os.path.exists(path):
332
+ try:
333
+ # copy output file to the destination lto dir
334
+ os.rename(temp_file_paths[ext], path)
335
+ except (OSError, FileExistsError):
336
+ # another process likely updated the lto dir first
337
+ pass
321
338
 
322
339
  # Clean up the temporary build directory
323
340
  if build_dir:
@@ -326,9 +343,9 @@ def _build_lto_base(lto_symbol, compile_func, builder, extra_files=None):
326
343
  shutil.rmtree(build_dir, ignore_errors=True)
327
344
 
328
345
  if not extra_files:
329
- return (outputs[".lto"],)
346
+ return (result, outputs[".lto"])
330
347
  else:
331
- return (outputs[".lto"], *[outputs[ext] for ext in extra_files.keys()])
348
+ return (result, outputs[".lto"], *[outputs[ext] for ext in extra_files.keys()])
332
349
 
333
350
 
334
351
  def build_lto_dot(M, N, K, adtype, bdtype, cdtype, alayout, blayout, clayout, arch, num_threads, builder):
@@ -372,7 +389,7 @@ def build_lto_dot(M, N, K, adtype, bdtype, cdtype, alayout, blayout, clayout, ar
372
389
  lto_symbol = f"dot_{M}_{N}_{K}_{arch}_{num_threads}_{a_arrangement}_{b_arrangement}_{c_arrangement}_{a_prec}_{b_prec}_{c_prec}_{element_type}"
373
390
 
374
391
  def compile_lto_dot(temp_paths):
375
- result = warp.context.runtime.core.cuda_compile_dot(
392
+ result = warp.context.runtime.core.wp_cuda_compile_dot(
376
393
  temp_paths[".lto"].encode("utf-8"),
377
394
  lto_symbol.encode("utf-8"),
378
395
  0,
@@ -402,7 +419,13 @@ def build_lto_dot(M, N, K, adtype, bdtype, cdtype, alayout, blayout, clayout, ar
402
419
  if lto_symbol in builder.ltoirs:
403
420
  lto_code_data = builder.ltoirs[lto_symbol]
404
421
  else:
405
- (lto_code_data,) = _build_lto_base(lto_symbol, compile_lto_dot, builder, {})
422
+ (result, lto_code_data) = _build_lto_base(lto_symbol, compile_lto_dot, builder, {})
423
+
424
+ if not result:
425
+ raise RuntimeError(
426
+ f"Failed to compile LTO '{lto_symbol}'. "
427
+ "Set the environment variable LIBMATHDX_LOG_LEVEL=5 and rerun for more details."
428
+ )
406
429
 
407
430
  # Update builder
408
431
  builder.ltoirs[lto_symbol] = lto_code_data
@@ -429,6 +452,7 @@ def build_lto_solver(
429
452
  num_threads,
430
453
  parameter_list,
431
454
  builder,
455
+ smem_estimate_bytes=None,
432
456
  ):
433
457
  arch = 120 if arch > 121 else arch
434
458
 
@@ -446,7 +470,7 @@ def build_lto_solver(
446
470
 
447
471
  def compile_lto_solver(temp_paths):
448
472
  # compile LTO
449
- result = warp.context.runtime.core.cuda_compile_solver(
473
+ result = warp.context.runtime.core.wp_cuda_compile_solver(
450
474
  temp_paths["_fatbin.lto"].encode("utf-8"),
451
475
  temp_paths[".lto"].encode("utf-8"),
452
476
  lto_symbol.encode("utf-8"),
@@ -479,10 +503,43 @@ def build_lto_solver(
479
503
  if lto_symbol in builder.ltoirs:
480
504
  lto_code_data = builder.ltoirs[lto_symbol]
481
505
  else:
482
- lto_code_data, universal_fatbin_code_data = _build_lto_base(
506
+ (result, lto_code_data, universal_fatbin_code_data) = _build_lto_base(
483
507
  lto_symbol, compile_lto_solver, builder, {"_fatbin.lto": get_cached_lto}
484
508
  )
485
509
 
510
+ if not result:
511
+ hint = ""
512
+ if smem_estimate_bytes:
513
+ max_smem_bytes = 232448
514
+ max_smem_is_estimate = True
515
+ for d in warp.get_cuda_devices():
516
+ if d.arch == arch:
517
+ # We can directly query the max shared memory for this device
518
+ queried_bytes = warp.context.runtime.core.wp_cuda_get_max_shared_memory(d.context)
519
+ if queried_bytes > 0:
520
+ max_smem_bytes = queried_bytes
521
+ max_smem_is_estimate = False
522
+ break
523
+ if smem_estimate_bytes > max_smem_bytes:
524
+ source = "estimated limit" if max_smem_is_estimate else "device-reported limit"
525
+ hint = (
526
+ f"Estimated shared memory requirement is {smem_estimate_bytes}B, "
527
+ f"but the {source} is {max_smem_bytes}B. "
528
+ "The tile size(s) may be too large for this device."
529
+ )
530
+
531
+ if warp.context.runtime.toolkit_version < (12, 6):
532
+ raise RuntimeError(
533
+ "cuSolverDx requires CUDA Toolkit 12.6.3 or later. This version of Warp was built against CUDA Toolkit "
534
+ f"{warp.context.runtime.toolkit_version[0]}.{warp.context.runtime.toolkit_version[1]}. "
535
+ "Upgrade your CUDA Toolkit and rebuild Warp, or install a Warp wheel built with CUDA >= 12.6.3."
536
+ )
537
+ else:
538
+ raise RuntimeError(
539
+ f"Failed to compile LTO '{lto_symbol}'. {hint}"
540
+ " Set the environment variable LIBMATHDX_LOG_LEVEL=5 and rerun for more details."
541
+ )
542
+
486
543
  # Update builder
487
544
  builder.ltoirs[lto_symbol] = lto_code_data
488
545
  builder.ltoirs_decl[lto_symbol] = f"void {lto_symbol}{parameter_list};"
@@ -499,7 +556,7 @@ def build_lto_fft(arch, size, ept, direction, dir, precision, builder):
499
556
  def compile_lto_fft(temp_paths):
500
557
  shared_memory_size = ctypes.c_int(0)
501
558
 
502
- result = warp.context.runtime.core.cuda_compile_fft(
559
+ result = warp.context.runtime.core.wp_cuda_compile_fft(
503
560
  temp_paths[".lto"].encode("utf-8"),
504
561
  lto_symbol.encode("utf-8"),
505
562
  0,
@@ -535,10 +592,16 @@ def build_lto_fft(arch, size, ept, direction, dir, precision, builder):
535
592
  lto_code_data = builder.ltoirs[lto_symbol]
536
593
  shared_memory_bytes = builder.shared_memory_bytes[lto_symbol]
537
594
  else:
538
- lto_code_data, shared_memory_bytes = _build_lto_base(
595
+ (result, lto_code_data, shared_memory_bytes) = _build_lto_base(
539
596
  lto_symbol, compile_lto_fft, builder, {".meta": lambda path: get_cached_lto_meta(path, lto_symbol)}
540
597
  )
541
598
 
599
+ if not result:
600
+ raise RuntimeError(
601
+ f"Failed to compile LTO '{lto_symbol}'."
602
+ "Set the environment variable LIBMATHDX_LOG_LEVEL=5 and rerun for more details."
603
+ )
604
+
542
605
  # Update builder
543
606
  builder.ltoirs[lto_symbol] = lto_code_data
544
607
  builder.shared_memory_bytes[lto_symbol] = shared_memory_bytes