warp-lang 1.8.1__py3-none-macosx_10_13_universal2.whl → 1.9.1__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (141)
  1. warp/__init__.py +282 -103
  2. warp/__init__.pyi +1904 -114
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +93 -30
  6. warp/build_dll.py +331 -101
  7. warp/builtins.py +1244 -160
  8. warp/codegen.py +317 -206
  9. warp/config.py +1 -1
  10. warp/context.py +1465 -789
  11. warp/examples/core/example_marching_cubes.py +1 -0
  12. warp/examples/core/example_render_opengl.py +100 -3
  13. warp/examples/fem/example_apic_fluid.py +98 -52
  14. warp/examples/fem/example_convection_diffusion_dg.py +25 -4
  15. warp/examples/fem/example_diffusion_mgpu.py +8 -3
  16. warp/examples/fem/utils.py +68 -22
  17. warp/examples/interop/example_jax_kernel.py +2 -1
  18. warp/fabric.py +1 -1
  19. warp/fem/cache.py +27 -19
  20. warp/fem/domain.py +2 -2
  21. warp/fem/field/nodal_field.py +2 -2
  22. warp/fem/field/virtual.py +264 -166
  23. warp/fem/geometry/geometry.py +5 -5
  24. warp/fem/integrate.py +129 -51
  25. warp/fem/space/restriction.py +4 -0
  26. warp/fem/space/shape/tet_shape_function.py +3 -10
  27. warp/jax_experimental/custom_call.py +25 -2
  28. warp/jax_experimental/ffi.py +22 -1
  29. warp/jax_experimental/xla_ffi.py +16 -7
  30. warp/marching_cubes.py +708 -0
  31. warp/native/array.h +99 -4
  32. warp/native/builtin.h +86 -9
  33. warp/native/bvh.cpp +64 -28
  34. warp/native/bvh.cu +58 -58
  35. warp/native/bvh.h +2 -2
  36. warp/native/clang/clang.cpp +7 -7
  37. warp/native/coloring.cpp +8 -2
  38. warp/native/crt.cpp +2 -2
  39. warp/native/crt.h +3 -5
  40. warp/native/cuda_util.cpp +41 -10
  41. warp/native/cuda_util.h +10 -4
  42. warp/native/exports.h +1842 -1908
  43. warp/native/fabric.h +2 -1
  44. warp/native/hashgrid.cpp +37 -37
  45. warp/native/hashgrid.cu +2 -2
  46. warp/native/initializer_array.h +1 -1
  47. warp/native/intersect.h +2 -2
  48. warp/native/mat.h +1910 -116
  49. warp/native/mathdx.cpp +43 -43
  50. warp/native/mesh.cpp +24 -24
  51. warp/native/mesh.cu +26 -26
  52. warp/native/mesh.h +4 -2
  53. warp/native/nanovdb/GridHandle.h +179 -12
  54. warp/native/nanovdb/HostBuffer.h +8 -7
  55. warp/native/nanovdb/NanoVDB.h +517 -895
  56. warp/native/nanovdb/NodeManager.h +323 -0
  57. warp/native/nanovdb/PNanoVDB.h +2 -2
  58. warp/native/quat.h +331 -14
  59. warp/native/range.h +7 -1
  60. warp/native/reduce.cpp +10 -10
  61. warp/native/reduce.cu +13 -14
  62. warp/native/runlength_encode.cpp +2 -2
  63. warp/native/runlength_encode.cu +5 -5
  64. warp/native/scan.cpp +3 -3
  65. warp/native/scan.cu +4 -4
  66. warp/native/sort.cpp +10 -10
  67. warp/native/sort.cu +40 -31
  68. warp/native/sort.h +2 -0
  69. warp/native/sparse.cpp +8 -8
  70. warp/native/sparse.cu +13 -13
  71. warp/native/spatial.h +366 -17
  72. warp/native/temp_buffer.h +2 -2
  73. warp/native/tile.h +471 -82
  74. warp/native/vec.h +328 -14
  75. warp/native/volume.cpp +54 -54
  76. warp/native/volume.cu +1 -1
  77. warp/native/volume.h +2 -1
  78. warp/native/volume_builder.cu +30 -37
  79. warp/native/warp.cpp +150 -149
  80. warp/native/warp.cu +377 -216
  81. warp/native/warp.h +227 -226
  82. warp/optim/linear.py +736 -271
  83. warp/render/imgui_manager.py +289 -0
  84. warp/render/render_opengl.py +99 -18
  85. warp/render/render_usd.py +1 -0
  86. warp/sim/graph_coloring.py +2 -2
  87. warp/sparse.py +558 -175
  88. warp/tests/aux_test_module_aot.py +7 -0
  89. warp/tests/cuda/test_async.py +3 -3
  90. warp/tests/cuda/test_conditional_captures.py +101 -0
  91. warp/tests/geometry/test_hash_grid.py +38 -0
  92. warp/tests/geometry/test_marching_cubes.py +233 -12
  93. warp/tests/interop/test_jax.py +608 -28
  94. warp/tests/sim/test_coloring.py +6 -6
  95. warp/tests/test_array.py +58 -5
  96. warp/tests/test_codegen.py +4 -3
  97. warp/tests/test_context.py +8 -15
  98. warp/tests/test_enum.py +136 -0
  99. warp/tests/test_examples.py +2 -2
  100. warp/tests/test_fem.py +49 -6
  101. warp/tests/test_fixedarray.py +229 -0
  102. warp/tests/test_func.py +18 -15
  103. warp/tests/test_future_annotations.py +7 -5
  104. warp/tests/test_linear_solvers.py +30 -0
  105. warp/tests/test_map.py +15 -1
  106. warp/tests/test_mat.py +1518 -378
  107. warp/tests/test_mat_assign_copy.py +178 -0
  108. warp/tests/test_mat_constructors.py +574 -0
  109. warp/tests/test_module_aot.py +287 -0
  110. warp/tests/test_print.py +69 -0
  111. warp/tests/test_quat.py +140 -34
  112. warp/tests/test_quat_assign_copy.py +145 -0
  113. warp/tests/test_reload.py +2 -1
  114. warp/tests/test_sparse.py +71 -0
  115. warp/tests/test_spatial.py +140 -34
  116. warp/tests/test_spatial_assign_copy.py +160 -0
  117. warp/tests/test_struct.py +43 -3
  118. warp/tests/test_tuple.py +96 -0
  119. warp/tests/test_types.py +61 -20
  120. warp/tests/test_vec.py +179 -34
  121. warp/tests/test_vec_assign_copy.py +143 -0
  122. warp/tests/tile/test_tile.py +245 -18
  123. warp/tests/tile/test_tile_cholesky.py +605 -0
  124. warp/tests/tile/test_tile_load.py +169 -0
  125. warp/tests/tile/test_tile_mathdx.py +2 -558
  126. warp/tests/tile/test_tile_matmul.py +1 -1
  127. warp/tests/tile/test_tile_mlp.py +1 -1
  128. warp/tests/tile/test_tile_shared_memory.py +5 -5
  129. warp/tests/unittest_suites.py +6 -0
  130. warp/tests/walkthrough_debug.py +1 -1
  131. warp/thirdparty/unittest_parallel.py +108 -9
  132. warp/types.py +571 -267
  133. warp/utils.py +68 -86
  134. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/METADATA +29 -69
  135. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/RECORD +138 -128
  136. warp/native/marching.cpp +0 -19
  137. warp/native/marching.cu +0 -514
  138. warp/native/marching.h +0 -19
  139. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/WHEEL +0 -0
  140. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/licenses/LICENSE.md +0 -0
  141. {warp_lang-1.8.1.dist-info → warp_lang-1.9.1.dist-info}/top_level.txt +0 -0
warp/types.py CHANGED
@@ -20,6 +20,8 @@ import ctypes
 import inspect
 import math
 import struct
+import sys
+import types
 import zlib
 from typing import (
     Any,
@@ -176,11 +178,11 @@ def constant(x):
 
 
 def float_to_half_bits(value):
-    return warp.context.runtime.core.float_to_half_bits(value)
+    return warp.context.runtime.core.wp_float_to_half_bits(value)
 
 
 def half_bits_to_float(value):
-    return warp.context.runtime.core.half_bits_to_float(value)
+    return warp.context.runtime.core.wp_half_bits_to_float(value)
 
 
 def safe_len(obj):
@@ -190,10 +192,80 @@ def safe_len(obj):
     return -1
 
 
+def flatten(value: Sequence) -> tuple[list, tuple[int, ...]]:
+    """Flatten an arbitrarily-nested, rectangular iterable."""
+    arr = []
+    shape = []
+
+    depth = 0
+    stack = [(depth, value)]
+
+    while stack:
+        depth, elem = stack.pop(0)
+
+        if isinstance(elem, (str, bytes, bytearray, memoryview)):
+            raise TypeError(f"Got an invalid element of type `{type(elem).__name__}`")
+
+        try:
+            # If `elem` is a sequence, then it should be possible
+            # to add its elements to the stack for later processing.
+            stack.extend((depth + 1, x) for x in elem)
+        except TypeError:
+            # Since `elem` doesn't seem to be a sequence, we must have
+            # a leaf value that we need to add to our resulting array.
+            if depth != len(shape):
+                raise ValueError("Ragged array: scalar found before deepest level.") from None
+
+            arr.append(elem)
+        else:
+            dim = len(elem)
+            if depth == len(shape):
+                # First sequence seen at this depth, record its length.
+                shape.append(dim)
+            elif shape[depth] != dim:
+                # Later sequences must have the same length.
+                raise ValueError(f"Ragged array: expected length {shape[depth]} at depth {depth}, got {dim}.") from None
+
+    return (arr, tuple(shape))
+
+
 # ----------------------
 # built-in types
 
 
+def _unary_op(self, op, t):
+    try:
+        return op(self)
+    except RuntimeError:
+        return t(*(op(a) for a in self))
+
+
+def _binary_op(self, op, x, t, cw=True):
+    try:
+        return op(self, x)
+    except RuntimeError:
+        if is_scalar(x):
+            return t(*(op(a, x) for a in self))
+
+        if cw and types_equal(x, t):
+            return t(*(op(a, b) for a, b in zip(self, x)))
+
+        raise
+
+
+def _rbinary_op(self, op, x, t, cw=True):
+    try:
+        return op(x, self)
+    except RuntimeError:
+        if is_scalar(x):
+            return t(*(op(x, a) for a in self))
+
+        if cw and types_equal(x, t):
+            return t(*(op(b, a) for a, b in zip(self, x)))
+
+        raise
+
+
 def vector(length, dtype):
     # canonicalize dtype
     if dtype == int:
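Note: the new `flatten` helper above is what the slice-assignment paths later in this diff use for shape validation. A minimal sketch of its behavior at the Python scope (names as defined in this module; illustrative, not part of the diff):

    from warp.types import flatten

    values, shape = flatten([[1, 2, 3], [4, 5, 6]])
    assert values == [1, 2, 3, 4, 5, 6]   # leaves in row-major order
    assert shape == (2, 3)                # inferred rectangular shape

    assert flatten(7) == ([7], ())        # a bare scalar has an empty shape
    # flatten([[1, 2], [3]]) raises ValueError: ragged input is rejected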
@@ -260,9 +332,10 @@ def vector(length, dtype):
                 return vec_t.scalar_export(super().__getitem__(key))
             elif isinstance(key, slice):
                 if self._wp_scalar_type_ == float16:
-                    return [vec_t.scalar_export(x) for x in super().__getitem__(key)]
+                    values = tuple(vec_t.scalar_export(x) for x in super().__getitem__(key))
                 else:
-                    return super().__getitem__(key)
+                    values = super().__getitem__(key)
+                return vector(len(values), self._wp_scalar_type_)(*values)
             else:
                 raise KeyError(f"Invalid key {key}, expected int or slice")
 
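Note: with this change, slicing a Warp vector at the Python scope returns another Warp vector rather than a plain Python list. A sketch:

    import warp as wp

    v = wp.vec4(1.0, 2.0, 3.0, 4.0)
    s = v[1:3]          # now a length-2 vector type created via vector(2, float32)
    print(s)            # [2.0, 3.0]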
@@ -276,6 +349,13 @@ def vector(length, dtype):
                     f"but got `{type(value).__name__}` instead"
                 ) from None
             elif isinstance(key, slice):
+                if is_scalar(value):
+                    indices = range(*key.indices(self._length_))
+                    for idx in indices:
+                        super().__setitem__(idx, vec_t.scalar_import(value))
+
+                    return
+
                 try:
                     iter(value)
                 except TypeError:
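Note: slice assignment on vectors now broadcasts scalars. A sketch:

    import warp as wp

    v = wp.vec4(1.0, 2.0, 3.0, 4.0)
    v[1:3] = 0.0        # the scalar is written to every sliced component
    print(v)            # [1.0, 0.0, 0.0, 4.0]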
@@ -325,40 +405,40 @@ def vector(length, dtype):
             return super().__setattr__(name, value)
 
         def __add__(self, y):
-            return warp.add(self, y)
+            return _binary_op(self, warp.add, y, vec_t)
 
         def __radd__(self, y):
-            return warp.add(y, self)
+            return _rbinary_op(self, warp.add, y, vec_t)
 
         def __sub__(self, y):
-            return warp.sub(self, y)
+            return _binary_op(self, warp.sub, y, vec_t)
 
         def __rsub__(self, y):
-            return warp.sub(y, self)
+            return _rbinary_op(self, warp.sub, y, vec_t)
 
         def __mul__(self, y):
-            return warp.mul(self, y)
+            return _binary_op(self, warp.mul, y, vec_t, cw=False)
 
         def __rmul__(self, x):
-            return warp.mul(x, self)
+            return _rbinary_op(self, warp.mul, x, vec_t, cw=False)
 
         def __truediv__(self, y):
-            return warp.div(self, y)
+            return _binary_op(self, warp.div, y, vec_t, cw=False)
 
         def __rtruediv__(self, x):
-            return warp.div(x, self)
+            return _rbinary_op(self, warp.div, x, vec_t, cw=False)
 
         def __mod__(self, x):
-            return warp.mod(self, x)
+            return _binary_op(self, warp.mod, x, vec_t)
 
         def __rmod__(self, x):
-            return warp.mod(x, self)
+            return _rbinary_op(self, warp.mod, x, vec_t)
 
         def __pos__(self):
-            return warp.pos(self)
+            return _unary_op(self, warp.pos, vec_t)
 
         def __neg__(self):
-            return warp.neg(self)
+            return _unary_op(self, warp.neg, vec_t)
 
         def __str__(self):
             return f"[{', '.join(map(str, self))}]"
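Note: the operator dunders now route through `_unary_op`/`_binary_op`, which first try the Warp builtin and fall back to an element-wise loop when the builtin raises `RuntimeError` (scalars broadcast; `cw=False` disables the component-wise fallback for `mul`/`div`, where vector-vector products have a different meaning). A hedged sketch of the resulting Python-scope behavior:

    import warp as wp

    v = wp.vec3(1.0, 2.0, 3.0)
    print(v + 1.0)      # per-component fallback: [2.0, 3.0, 4.0]
    print(v % 2.0)      # [1.0, 0.0, 1.0]
    print(-v)           # unary fallback: [-1.0, -2.0, -3.0]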
@@ -422,6 +502,7 @@ def matrix(shape, dtype):
         _wp_constructor_ = "matrix"
 
         _wp_row_type_ = vector(0 if shape[1] == Any else shape[1], dtype)
+        _wp_col_type_ = vector(0 if shape[0] == Any else shape[0], dtype)
 
         # special handling for float16 type: in this case, data is stored
         # as uint16 but it's actually half precision floating point
@@ -467,22 +548,22 @@ def matrix(shape, dtype):
             return self._shape_[0]
 
         def __add__(self, y):
-            return warp.add(self, y)
+            return _binary_op(self, warp.add, y, mat_t)
 
         def __radd__(self, y):
-            return warp.add(y, self)
+            return _rbinary_op(self, warp.add, y, mat_t)
 
         def __sub__(self, y):
-            return warp.sub(self, y)
+            return _binary_op(self, warp.sub, y, mat_t)
 
         def __rsub__(self, y):
-            return warp.sub(y, self)
+            return _rbinary_op(self, warp.sub, y, mat_t)
 
         def __mul__(self, y):
-            return warp.mul(self, y)
+            return _binary_op(self, warp.mul, y, mat_t, cw=False)
 
         def __rmul__(self, x):
-            return warp.mul(x, self)
+            return _rbinary_op(self, warp.mul, x, mat_t, cw=False)
 
         def __matmul__(self, y):
             return warp.mul(self, y)
@@ -491,16 +572,22 @@ def matrix(shape, dtype):
             return warp.mul(x, self)
 
         def __truediv__(self, y):
-            return warp.div(self, y)
+            return _binary_op(self, warp.div, y, mat_t, cw=False)
 
         def __rtruediv__(self, x):
-            return warp.div(x, self)
+            return _rbinary_op(self, warp.div, x, mat_t, cw=False)
+
+        def __mod__(self, x):
+            return _binary_op(self, warp.mod, x, mat_t)
+
+        def __rmod__(self, x):
+            return _rbinary_op(self, warp.mod, x, mat_t)
 
         def __pos__(self):
-            return warp.pos(self)
+            return _unary_op(self, warp.pos, mat_t)
 
         def __neg__(self):
-            return warp.neg(self)
+            return _unary_op(self, warp.neg, mat_t)
 
         def __str__(self):
             row_str = []
@@ -522,8 +609,10 @@ def matrix(shape, dtype):
             return True
 
         def get_row(self, r):
-            if r < 0 or r >= self._shape_[0]:
+            if r < -self._shape_[0] or r >= self._shape_[0]:
                 raise IndexError("Invalid row index")
+            if r < 0:
+                r += self._shape_[0]
             row_start = r * self._shape_[1]
             row_end = row_start + self._shape_[1]
             row_data = super().__getitem__(slice(row_start, row_end))
@@ -532,9 +621,35 @@ def matrix(shape, dtype):
             else:
                 return self._wp_row_type_(row_data)
 
+        def get_col(self, c):
+            if c < -self._shape_[1] or c >= self._shape_[1]:
+                raise IndexError("Invalid column index")
+            if c < 0:
+                c += self._shape_[1]
+            col_start = c
+            col_end = col_start + self._shape_[0] * self._shape_[1]
+            col_step = self._shape_[1]
+            col_data = super().__getitem__(slice(col_start, col_end, col_step))
+            if self._wp_scalar_type_ == float16:
+                return self._wp_col_type_(*[mat_t.scalar_export(x) for x in col_data])
+            else:
+                return self._wp_col_type_(col_data)
+
         def set_row(self, r, v):
-            if r < 0 or r >= self._shape_[0]:
+            if r < -self._shape_[0] or r >= self._shape_[0]:
                 raise IndexError("Invalid row index")
+            if r < 0:
+                r += self._shape_[0]
+
+            row_start = r * self._shape_[1]
+            row_end = row_start + self._shape_[1]
+
+            if is_scalar(v):
+                for i in range(row_start, row_end):
+                    super().__setitem__(i, mat_t.scalar_import(v))
+
+                return
+
             try:
                 iter(v)
             except TypeError:
@@ -542,8 +657,6 @@ def matrix(shape, dtype):
                     f"Expected to assign a slice from a sequence of values but got `{type(v).__name__}` instead"
                 ) from None
 
-            row_start = r * self._shape_[1]
-            row_end = row_start + self._shape_[1]
             if self._wp_scalar_type_ == float16:
                 converted = []
                 try:
@@ -558,17 +671,86 @@ def matrix(shape, dtype):
                 v = converted
             super().__setitem__(slice(row_start, row_end), v)
 
+        def set_col(self, c, v):
+            if c < -self._shape_[1] or c >= self._shape_[1]:
+                raise IndexError("Invalid col index")
+            if c < 0:
+                c += self._shape_[1]
+
+            col_start = c
+            col_end = col_start + self._shape_[0] * self._shape_[1]
+            col_step = self._shape_[1]
+
+            if is_scalar(v):
+                for i in range(col_start, col_end, col_step):
+                    super().__setitem__(i, mat_t.scalar_import(v))
+
+                return
+
+            try:
+                iter(v)
+            except TypeError:
+                raise TypeError(
+                    f"Expected to assign a slice from a sequence of values but got `{type(v).__name__}` instead"
+                ) from None
+
+            if self._wp_scalar_type_ == float16:
+                converted = []
+                try:
+                    for x in v:
+                        converted.append(mat_t.scalar_import(x))
+                except ctypes.ArgumentError:
+                    raise TypeError(
+                        f"Expected to assign a slice from a sequence of `float16` values "
+                        f"but got `{type(x).__name__}` instead"
+                    ) from None
+
+                v = converted
+            super().__setitem__(slice(col_start, col_end, col_step), v)
+
         def __getitem__(self, key):
             if isinstance(key, Tuple):
                 # element indexing m[i,j]
                 if len(key) != 2:
                     raise KeyError(f"Invalid key, expected one or two indices, got {len(key)}")
-                if any(isinstance(x, slice) for x in key):
-                    raise KeyError("Slices are not supported when indexing matrices using the `m[i, j]` notation")
-                return mat_t.scalar_export(super().__getitem__(key[0] * self._shape_[1] + key[1]))
+
+                # Count how many dimensions the output value will have.
+                ndim = sum(1 for x in key if isinstance(x, slice))
+
+                if ndim == 0:
+                    row = key[0] + self._shape_[0] if key[0] < 0 else key[0]
+                    col = key[1] + self._shape_[1] if key[1] < 0 else key[1]
+                    return mat_t.scalar_export(super().__getitem__(row * self._shape_[1] + col))
+
+                if ndim == 1:
+                    if isinstance(key[1], slice):
+                        # Row vector.
+                        cols = range(*key[1].indices(self._shape_[0]))
+                        row_vec = self.get_row(key[0])
+                        values = tuple(row_vec[x] for x in cols)
+                        return vector(len(values), self._wp_scalar_type_)(*values)
+                    else:
+                        # Column vector.
+                        rows = range(*key[0].indices(self._shape_[1]))
+                        col_vec = self.get_col(key[1])
+                        values = tuple(col_vec[x] for x in rows)
+                        return vector(len(values), self._wp_scalar_type_)(*values)
+
+                assert ndim == 2
+                rows = range(*key[0].indices(self._shape_[1]))
+                cols = range(*key[1].indices(self._shape_[0]))
+                row_vecs = tuple(self.get_row(i) for i in rows)
+                values = tuple(x[j] for x in row_vecs for j in cols)
+                shape = (len(rows), len(cols))
+                return matrix(shape, self._wp_scalar_type_)(*values)
             elif isinstance(key, int):
                 # row vector indexing m[r]
                 return self.get_row(key)
+            elif isinstance(key, slice):
+                indices = range(*key.indices(self._shape_[0]))
+                row_vecs = tuple(self.get_row(x) for x in indices)
+                shape = (len(row_vecs), self._shape_[1])
+                return matrix(shape, self._wp_scalar_type_)(*row_vecs)
             else:
                 raise KeyError(f"Invalid key {key}, expected int or pair of ints")
 
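Note: matrix element access now supports negative indices and slices in all combinations. A sketch of the read side:

    import warp as wp

    m = wp.mat33(1.0, 2.0, 3.0,
                 4.0, 5.0, 6.0,
                 7.0, 8.0, 9.0)

    print(m[0, -1])     # negative index wraps to column 2: 3.0
    print(m[0, 0:2])    # row slice -> length-2 vector: [1.0, 2.0]
    print(m[0:2, 1])    # column slice -> length-2 vector: [2.0, 5.0]
    print(m[0:2])       # leading rows -> a 2x3 matrix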
@@ -577,20 +759,104 @@ def matrix(shape, dtype):
                 # element indexing m[i,j] = x
                 if len(key) != 2:
                     raise KeyError(f"Invalid key, expected one or two indices, got {len(key)}")
-                if any(isinstance(x, slice) for x in key):
-                    raise KeyError("Slices are not supported when indexing matrices using the `m[i, j]` notation")
-                try:
-                    return super().__setitem__(key[0] * self._shape_[1] + key[1], mat_t.scalar_import(value))
-                except (TypeError, ctypes.ArgumentError):
-                    raise TypeError(
-                        f"Expected to assign a `{self._wp_scalar_type_.__name__}` value "
-                        f"but got `{type(value).__name__}` instead"
-                    ) from None
+
+                # Count how many dimensions the output value is expected to have.
+                ndim = sum(1 for x in key if isinstance(x, slice))
+
+                if ndim == 0:
+                    try:
+                        _, v_shape = flatten(value)
+                    except TypeError:
+                        raise TypeError(
+                            f"Expected to assign a `{type_repr(self._wp_scalar_type_)}` value but got `{type(value).__name__}` instead"
+                        ) from None
+
+                    if v_shape:
+                        raise RuntimeError(
+                            f"The provided value is expected to be a scalar but got an object of shape {v_shape} instead"
+                        )
+
+                    row = key[0] + self._shape_[0] if key[0] < 0 else key[0]
+                    col = key[1] + self._shape_[1] if key[1] < 0 else key[1]
+                    idx = row * self._shape_[1] + col
+                    super().__setitem__(idx, mat_t.scalar_import(value))
+                    return
+
+                if ndim == 1:
+                    _, v_shape = flatten(value)
+
+                    if v_shape and len(v_shape) != 1:
+                        raise RuntimeError(
+                            f"The provided value is expected to be a 1D vector but got an object of shape {v_shape} instead"
+                        )
+
+                    if isinstance(key[1], slice):
+                        # Row vector.
+                        cols = range(*key[1].indices(self._shape_[0]))
+                        if v_shape and v_shape[0] != len(cols):
+                            raise RuntimeError(
+                                f"The length of the provided vector ({v_shape[0]}) isn't compatible with the given slice (expected {len(cols)})"
+                            )
+
+                        row = key[0] + self._shape_[0] if key[0] < 0 else key[0]
+                        for i, col in enumerate(cols):
+                            idx = row * self._shape_[1] + col
+                            super().__setitem__(idx, mat_t.scalar_import(value[i] if v_shape else value))
+
+                        return
+                    else:
+                        # Column vector.
+                        rows = range(*key[0].indices(self._shape_[1]))
+                        if v_shape and v_shape[0] != len(rows):
+                            raise RuntimeError(
+                                f"The length of the provided vector ({v_shape[0]}) isn't compatible with the given slice (expected {len(rows)})"
+                            )
+
+                        col = key[1] + self._shape_[1] if key[1] < 0 else key[1]
+                        for i, row in enumerate(rows):
+                            idx = row * self._shape_[1] + col
+                            super().__setitem__(idx, mat_t.scalar_import(value[i] if v_shape else value))
+
+                        return
+
+                assert ndim == 2
+
+                _, v_shape = flatten(value)
+
+                if v_shape and len(v_shape) != 2:
+                    raise RuntimeError(
+                        f"The provided value is expected to be a 2D matrix but got an object of shape {v_shape} instead"
+                    )
+
+                rows = range(*key[0].indices(self._shape_[1]))
+                cols = range(*key[1].indices(self._shape_[0]))
+
+                if v_shape and v_shape != (len(rows), len(cols)):
+                    raise RuntimeError(
+                        f"The shape of the provided matrix ({v_shape}) isn't compatible with the given slice (expected ({len(rows)}, {len(cols)}))"
+                    )
+
+                for i, row in enumerate(rows):
+                    for j, col in enumerate(cols):
+                        idx = row * self._shape_[1] + col
+                        super().__setitem__(idx, mat_t.scalar_import(value[i, j] if v_shape else value))
             elif isinstance(key, int):
                 # row vector indexing m[r] = v
                 return self.set_row(key, value)
             elif isinstance(key, slice):
-                raise KeyError("Slices are not supported when indexing matrices using the `m[start:end]` notation")
+                v_arr, v_shape = flatten(value)
+                indices = range(*key.indices(self._shape_[0]))
+
+                if v_shape and (len(v_shape) != 2 or v_shape[0] != len(indices) or v_shape[1] != self._shape_[1]):
+                    raise RuntimeError(
+                        f"The shape of the provided matrix ({v_shape}) isn't compatible with the given slice (expected ({len(indices)}, {self._shape_[1]}))"
+                    )
+
+                for i, row in enumerate(indices):
+                    offset = i * self._shape_[1]
+                    self.set_row(
+                        row, v_arr[offset : offset + self._shape_[1]] if v_shape else (value,) * self._shape_[1]
+                    )
             else:
                 raise KeyError(f"Invalid key {key}, expected int or pair of ints")
 
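Note: the write side mirrors this; scalars broadcast, and sequence values are shape-checked through `flatten`. A sketch:

    import warp as wp

    m = wp.mat33()                          # zero-initialized
    m[0, 0] = 1.0                           # single element
    m[1] = wp.vec3(4.0, 5.0, 6.0)           # whole row via set_row
    m[:, 2] = 9.0                           # scalar broadcast down a column
    m[0:2, 0:2] = wp.mat22(1.0, 2.0,
                           3.0, 4.0)        # 2x2 block assignment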
@@ -608,6 +874,60 @@ def matrix(shape, dtype):
     return mat_t
 
 
+def matrix_from_cols(*args: Sequence[Vector]):
+    if not all(type_is_vector(x) for x in args):
+        raise RuntimeError("all arguments are expected to be vectors")
+
+    length = args[0]._length_
+    if any(x._length_ != length for x in args):
+        raise RuntimeError("all vectors are expected to have the same length")
+
+    dtype = args[0]._wp_scalar_type_
+    if any(x._wp_scalar_type_ != dtype for x in args):
+        raise RuntimeError("all vectors are expected to have the same dtype")
+
+    row_count = length
+    col_count = len(args)
+    out = matrix(shape=(row_count, col_count), dtype=dtype)()
+    mat_t = type(out)
+
+    for col in range(col_count):
+        v = args[col]
+        for row in range(row_count):
+            idx = col_count * row + col
+            value = mat_t.scalar_import(v[row])
+            super(mat_t, out).__setitem__(idx, value)
+
+    return out
+
+
+def matrix_from_rows(*args: Sequence[Vector]):
+    if not all(type_is_vector(x) for x in args):
+        raise RuntimeError("all arguments are expected to be vectors")
+
+    length = args[0]._length_
+    if any(x._length_ != length for x in args):
+        raise RuntimeError("all vectors are expected to have the same length")
+
+    dtype = args[0]._wp_scalar_type_
+    if any(x._wp_scalar_type_ != dtype for x in args):
+        raise RuntimeError("all vectors are expected to have the same dtype")
+
+    row_count = len(args)
+    col_count = length
+    out = matrix(shape=(row_count, col_count), dtype=dtype)()
+    mat_t = type(out)
+
+    for row in range(row_count):
+        v = args[row]
+        for col in range(col_count):
+            idx = col_count * row + col
+            value = mat_t.scalar_import(v[col])
+            super(mat_t, out).__setitem__(idx, value)
+
+    return out
+
+
 class void:
     def __init__(self):
         pass
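Note: `matrix_from_rows`/`matrix_from_cols` assemble a matrix from same-length, same-dtype vectors (the 1.9 release notes also describe kernel-level builtins of the same names). A Python-scope sketch:

    import warp as wp
    from warp.types import matrix_from_cols, matrix_from_rows

    r0 = wp.vec3(1.0, 2.0, 3.0)
    r1 = wp.vec3(4.0, 5.0, 6.0)

    m = matrix_from_rows(r0, r1)    # 2x3 matrix, vectors become rows
    n = matrix_from_cols(r0, r1)    # 3x2 matrix, vectors become columns

    print(m[0])                     # [1.0, 2.0, 3.0]
    print(n[0])                     # [1.0, 4.0]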
@@ -664,12 +984,22 @@ class scalar_base:
 
 
 class float_base(scalar_base):
-    pass
+    def __str__(self) -> str:
+        return str(self.value)
+
+    def __repr__(self) -> str:
+        return f"{type(self).__name__}({self!s})"
 
 
 class int_base(scalar_base):
     def __index__(self) -> int:
-        return int(self.value)
+        return int(self._type_(self.value).value)
+
+    def __str__(self) -> str:
+        return str(self._type_(self.value).value)
+
+    def __repr__(self) -> str:
+        return f"{type(self).__name__}({self!s})"
 
 
 class bool:
@@ -688,6 +1018,12 @@ class bool:
     def __int__(self) -> int:
         return int(self.value != 0)
 
+    def __str__(self) -> str:
+        return str(self.value != 0)
+
+    def __repr__(self) -> str:
+        return f"{type(self).__name__}({self!s})"
+
 
 class float16(float_base):
     _length_ = 1
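Note: Warp scalars now print as their wrapped values, and `int_base` round-trips through the underlying ctypes type so out-of-range inputs wrap consistently. A sketch (assuming the raw Python value is stored until conversion):

    import warp as wp

    x = wp.float32(0.5)
    print(str(x), repr(x))      # 0.5 float32(0.5)

    y = wp.int8(130)
    print(y)                    # -126 (wrapped modulo 2**8)
    print(wp.bool(2))           # True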
@@ -1308,6 +1644,49 @@ class launch_bounds_t(ctypes.Structure):
                 self.shape[i] = 1
 
 
+INT_WIDTH = ctypes.sizeof(ctypes.c_int) * 8
+SLICE_BEGIN = (1 << (INT_WIDTH - 1)) - 1
+SLICE_END = -(1 << (INT_WIDTH - 1))
+
+
+class slice_t:
+    _wp_native_name_ = "slice_t"
+
+    def __init__(self, start, stop, step):
+        self.start = start
+        self.stop = stop
+        self.step = step
+
+    def get_length(self, parent_length, wrap=False):
+        if any(isinstance(x, warp.codegen.Var) for x in (self.start, self.stop, self.step)):
+            raise RuntimeError("Vector slice indices must be constant values.")
+
+        if self.step == 0:
+            raise RuntimeError(f"Vector slice step {self.step} is invalid.")
+
+        if self.start == SLICE_BEGIN:
+            start = parent_length - 1 if self.step < 0 else 0
+        else:
+            start = min(max(self.start, -parent_length), parent_length)
+            if wrap:
+                start = start + parent_length if start < 0 else start
+
+        if self.stop == SLICE_END:
+            stop = -1 if self.step < 0 else parent_length
+        else:
+            stop = min(max(self.stop, -parent_length), parent_length)
+            if wrap:
+                stop = stop + parent_length if stop < 0 else stop
+
+        if self.step > 0 and start < stop:
+            return 1 + (stop - start - 1) // self.step
+
+        if self.step < 0 and start > stop:
+            return 1 + (start - stop - 1) // (-self.step)
+
+        return 0
+
+
 class shape_t(ctypes.Structure):
     _fields_ = (("dims", ctypes.c_int32 * ARRAY_MAX_DIMS),)
 
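Note: `slice_t.get_length` reproduces Python's own slice-length rule, with `SLICE_BEGIN`/`SLICE_END` standing in for omitted bounds; it backs the new slicing support in kernels. A sketch comparing it against `len(range(...))` (names as defined above):

    from warp.types import slice_t, SLICE_BEGIN, SLICE_END

    # v[1:7:2] over a length-5 parent covers indices 1 and 3
    assert slice_t(1, 7, 2).get_length(5) == len(range(*slice(1, 7, 2).indices(5)))

    # v[::-1] expressed with the sentinels for omitted start/stop
    assert slice_t(SLICE_BEGIN, SLICE_END, -1).get_length(5) == 5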
@@ -1388,6 +1767,13 @@ class tuple_t:
         self.values = values
 
 
+class pointer_t:
+    """Used during codegen to represent pointer types."""
+
+    def __init__(self, dtype):
+        self.dtype = dtype
+
+
 def type_ctype(dtype):
     if dtype == float:
         return ctypes.c_float
@@ -1499,7 +1885,9 @@ def type_typestr(dtype: type) -> str:
 
 
 def scalar_short_name(t):
-    if t == float32:
+    if t == float16:
+        return "h"
+    elif t == float32:
         return "f"
     elif t == float64:
         return "d"
@@ -1525,10 +1913,10 @@ def scalar_short_name(t):
 # converts any known type to a human readable string, good for error messages, reporting etc
 def type_repr(t) -> str:
     if is_array(t):
-        if t.device is None:
+        if hasattr(t, "device") and t.device is None:
             # array is used as a type annotation - display ndim instead of shape
-            return f"array(ndim={t.ndim}, dtype={type_repr(t.dtype)})"
-        return f"array(shape={t.shape}, dtype={type_repr(t.dtype)})"
+            return f"{type(t).__name__}(ndim={t.ndim}, dtype={type_repr(t.dtype)})"
+        return f"{type(t).__name__}(shape={t.shape}, dtype={type_repr(t.dtype)})"
     if is_tuple(t):
         return f"tuple({', '.join(type_repr(x) for x in t.types)})"
     if is_tile(t):
@@ -1584,6 +1972,11 @@ def type_is_float(t):
     return t in float_types
 
 
+# returns True if the passed *type* is a scalar
+def type_is_scalar(t):
+    return type_is_int(t) or type_is_float(t)
+
+
 # returns True if the passed *type* is a vector
 def type_is_vector(t):
     return getattr(t, "_wp_generic_type_hint_", None) is Vector
@@ -1621,6 +2014,10 @@ def is_float(x: Any) -> builtins.bool:
     return type_is_float(type(x))
 
 
+def is_scalar(x: Any) -> builtins.bool:
+    return type_is_scalar(type(x))
+
+
 def is_value(x: Any) -> builtins.bool:
     return type_is_value(type(x))
 
@@ -1750,7 +2147,11 @@ def types_equal(a, b, match_generic=False):
 
         return True
 
-    if is_array(a) and type(a) is type(b) and types_equal(a.dtype, b.dtype, match_generic=match_generic):
+    if (
+        is_array(a)
+        and (issubclass(type(a), type(b)) or issubclass(type(b), type(a)))
+        and types_equal(a.dtype, b.dtype, match_generic=match_generic)
+    ):
         return True
 
     # match NewStructInstance and Struct dtype
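Note: the relaxed check accepts subclass relationships, so an array subclass such as the new `fixedarray` (added below) compares equal to a plain `array` annotation with the same dtype. A sketch:

    import warp as wp
    from warp.types import array, fixedarray, types_equal

    a = array(dtype=wp.float32)               # annotation-style array
    b = fixedarray(dtype=wp.float32, shape=4)
    assert types_equal(a, b)                  # subclass relationship now suffices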
@@ -1863,10 +2264,6 @@ class array(Array[DType]):
         taking two arguments: pointer and size. If ``None``, then no function is called.
     """
 
-    # member attributes available during code-gen (e.g.: d = array.shape[0])
-    # (initialized when needed)
-    _vars = None
-
     def __new__(cls, *args, **kwargs):
         instance = super().__new__(cls)
         instance.deleter = None
@@ -2483,7 +2880,7 @@
 
             # Performance note: avoid wrapping the external stream in a temporary Stream object
             if stream != array_stream.cuda_stream:
-                warp.context.runtime.core.cuda_stream_wait_stream(
+                warp.context.runtime.core.wp_cuda_stream_wait_stream(
                     stream, array_stream.cuda_stream, array_stream.cached_event.cuda_event
                 )
         elif self.device.is_cpu:
@@ -2570,9 +2967,9 @@
 
                 if start < 0 or start >= self.shape[idx]:
                     raise RuntimeError(f"Invalid indexing in slice: {start}:{stop}:{step}")
-                if stop < 1 or stop > self.shape[idx]:
+                if stop < 0 or stop > self.shape[idx]:
                     raise RuntimeError(f"Invalid indexing in slice: {start}:{stop}:{step}")
-                if stop <= start:
+                if stop < start:
                     raise RuntimeError(f"Invalid indexing in slice: {start}:{stop}:{step}")
 
                 new_shape.append(-((stop - start) // -step))  # ceil division
@@ -2722,10 +3119,10 @@
     @property
     def vars(self):
         # member attributes available during code-gen (e.g.: d = array.shape[0])
-        # Note: we use a shared dict for all array instances
-        if array._vars is None:
-            array._vars = {"shape": warp.codegen.Var("shape", shape_t)}
-        return array._vars
+        return {
+            "shape": warp.codegen.Var("shape", shape_t),
+            "ptr": warp.codegen.Var("data", pointer_t(self.dtype)),
+        }
 
     def mark_init(self):
         """Resets this array's read flag"""
@@ -2830,11 +3227,11 @@
         carr_ptr = ctypes.pointer(carr)
 
         if self.device.is_cuda:
-            warp.context.runtime.core.array_fill_device(
+            warp.context.runtime.core.wp_array_fill_device(
                 self.device.context, carr_ptr, ARRAY_TYPE_REGULAR, cvalue_ptr, cvalue_size
             )
         else:
-            warp.context.runtime.core.array_fill_host(carr_ptr, ARRAY_TYPE_REGULAR, cvalue_ptr, cvalue_size)
+            warp.context.runtime.core.wp_array_fill_host(carr_ptr, ARRAY_TYPE_REGULAR, cvalue_ptr, cvalue_size)
 
         self.mark_init()
 
@@ -3138,7 +3535,7 @@
         # Allocate a buffer for the data (64-element char array)
         ipc_handle_buffer = (ctypes.c_char * 64)()
 
-        warp.context.runtime.core.cuda_ipc_get_mem_handle(self.ptr, ipc_handle_buffer)
+        warp.context.runtime.core.wp_cuda_ipc_get_mem_handle(self.ptr, ipc_handle_buffer)
 
         return ipc_handle_buffer.raw
 
@@ -3191,7 +3588,7 @@ def from_ptr(ptr, length, dtype=None, shape=None, device=None):
 
 
 def _close_cuda_ipc_handle(ptr, size):
-    warp.context.runtime.core.cuda_ipc_close_mem_handle(ptr)
+    warp.context.runtime.core.wp_cuda_ipc_close_mem_handle(ptr)
 
 
 def from_ipc_handle(
@@ -3230,11 +3627,59 @@ def from_ipc_handle(
     if device.is_ipc_supported is False:
         raise RuntimeError(f"IPC is not supported on device {device}.")
 
-    ptr = warp.context.runtime.core.cuda_ipc_open_mem_handle(device.context, handle)
+    ptr = warp.context.runtime.core.wp_cuda_ipc_open_mem_handle(device.context, handle)
 
     return array(ptr=ptr, dtype=dtype, shape=shape, strides=strides, device=device, deleter=_close_cuda_ipc_handle)
 
 
+class fixedarray(array):
+    """A fixed-size, stack allocated, array containing values of the same type.
+
+    Only used during codegen, and for type hints, but otherwise not intended to be used
+    at the Python scope.
+
+    Attributes:
+        dtype (DType): The data type of the array.
+        shape (tuple[int]): Dimensions of the array.
+    """
+
+    def __init__(
+        self,
+        dtype: Any = Any,
+        shape: int | tuple[int, ...] | list[int] | None = None,
+    ):
+        # canonicalize dtype
+        if dtype == int:
+            dtype = int32
+        elif dtype == float:
+            dtype = float32
+        elif dtype == builtins.bool:
+            dtype = bool
+
+        if shape is None:
+            self.dtype = dtype
+            self.ndim = 1
+            self.size = 0
+            self.shape = (0,)
+            self.strides = (0,)
+        else:
+            if isinstance(shape, int):
+                shape = (shape,)
+
+            check_array_shape(shape)
+
+            self.dtype = dtype
+            self.ndim = len(shape)
+            self.size = math.prod(shape)
+            self.shape = shape
+            self.strides = strides_from_shape(shape, dtype)
+
+    @property
+    def vars(self):
+        # member attributes available during code-gen (e.g.: d = array.shape[0])
+        return {"shape": warp.codegen.Var("shape", shape_t)}
+
+
 # A base class for non-contiguous arrays, providing the implementation of common methods like
 # contiguous(), to(), numpy(), list(), assign(), zero_(), and fill_().
 class noncontiguous_array_base(Array[T]):
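Note: `fixedarray` describes the fixed-size, stack-allocated arrays added in 1.9 (see test_fixedarray.py in the file list); at the Python scope it carries metadata only. A sketch:

    from warp.types import fixedarray

    arr = fixedarray(dtype=float, shape=8)
    print(arr.dtype)    # canonicalized to float32
    print(arr.shape)    # (8,)
    print(arr.strides)  # (4,), byte strides from strides_from_shape()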
@@ -3315,11 +3760,11 @@ class noncontiguous_array_base(Array[T]):
         ctype_ptr = ctypes.pointer(ctype)
 
         if self.device.is_cuda:
-            warp.context.runtime.core.array_fill_device(
+            warp.context.runtime.core.wp_array_fill_device(
                 self.device.context, ctype_ptr, self.type_id, cvalue_ptr, cvalue_size
             )
         else:
-            warp.context.runtime.core.array_fill_host(ctype_ptr, self.type_id, cvalue_ptr, cvalue_size)
+            warp.context.runtime.core.wp_array_fill_host(ctype_ptr, self.type_id, cvalue_ptr, cvalue_size)
 
 
 # helper to check index array properties
@@ -3462,7 +3907,7 @@ def indexedarray4d(*args, **kwargs):
 
 from warp.fabric import fabricarray, indexedfabricarray  # noqa: E402
 
-array_types = (array, indexedarray, fabricarray, indexedfabricarray)
+array_types = (array, indexedarray, fabricarray, indexedfabricarray, fixedarray)
 
 
 def array_type_id(a):
@@ -3663,11 +4108,11 @@ class Bvh:
             )
             constructor = "sah"
 
-            self.id = self.runtime.core.bvh_create_host(
+            self.id = self.runtime.core.wp_bvh_create_host(
                 get_data(lowers), get_data(uppers), len(lowers), bvh_constructor_values[constructor]
             )
         else:
-            self.id = self.runtime.core.bvh_create_device(
+            self.id = self.runtime.core.wp_bvh_create_device(
                 self.device.context,
                 get_data(lowers),
                 get_data(uppers),
@@ -3680,11 +4125,11 @@ class Bvh:
             return
 
         if self.device.is_cpu:
-            self.runtime.core.bvh_destroy_host(self.id)
+            self.runtime.core.wp_bvh_destroy_host(self.id)
        else:
            # use CUDA context guard to avoid side effects during garbage collection
            with self.device.context_guard:
-                self.runtime.core.bvh_destroy_device(self.id)
+                self.runtime.core.wp_bvh_destroy_device(self.id)
 
    def refit(self):
        """Refit the BVH.
@@ -3693,9 +4138,9 @@
        """
 
        if self.device.is_cpu:
-            self.runtime.core.bvh_refit_host(self.id)
+            self.runtime.core.wp_bvh_refit_host(self.id)
        else:
-            self.runtime.core.bvh_refit_device(self.id)
+            self.runtime.core.wp_bvh_refit_device(self.id)
            self.runtime.verify_cuda_device(self.device)
 
@@ -3777,7 +4222,7 @@ class Mesh:
            )
            bvh_constructor = "sah"
 
-            self.id = self.runtime.core.mesh_create_host(
+            self.id = self.runtime.core.wp_mesh_create_host(
                points.__ctype__(),
                velocities.__ctype__() if velocities else array().__ctype__(),
                indices.__ctype__(),
@@ -3787,7 +4232,7 @@
                bvh_constructor_values[bvh_constructor],
            )
        else:
-            self.id = self.runtime.core.mesh_create_device(
+            self.id = self.runtime.core.wp_mesh_create_device(
                self.device.context,
                points.__ctype__(),
                velocities.__ctype__() if velocities else array().__ctype__(),
@@ -3803,11 +4248,11 @@ class Mesh:
            return
 
        if self.device.is_cpu:
-            self.runtime.core.mesh_destroy_host(self.id)
+            self.runtime.core.wp_mesh_destroy_host(self.id)
        else:
            # use CUDA context guard to avoid side effects during garbage collection
            with self.device.context_guard:
-                self.runtime.core.mesh_destroy_device(self.id)
+                self.runtime.core.wp_mesh_destroy_device(self.id)
 
    def refit(self):
        """Refit the BVH to points.
@@ -3816,9 +4261,9 @@
        """
 
        if self.device.is_cpu:
-            self.runtime.core.mesh_refit_host(self.id)
+            self.runtime.core.wp_mesh_refit_host(self.id)
        else:
-            self.runtime.core.mesh_refit_device(self.id)
+            self.runtime.core.wp_mesh_refit_device(self.id)
            self.runtime.verify_cuda_device(self.device)
 
    @property
@@ -3848,9 +4293,9 @@ class Mesh:
 
        self._points = points_new
        if self.device.is_cpu:
-            self.runtime.core.mesh_set_points_host(self.id, points_new.__ctype__())
+            self.runtime.core.wp_mesh_set_points_host(self.id, points_new.__ctype__())
        else:
-            self.runtime.core.mesh_set_points_device(self.id, points_new.__ctype__())
+            self.runtime.core.wp_mesh_set_points_device(self.id, points_new.__ctype__())
            self.runtime.verify_cuda_device(self.device)
 
    @property
@@ -3878,9 +4323,9 @@ class Mesh:
 
        self._velocities = velocities_new
        if self.device.is_cpu:
-            self.runtime.core.mesh_set_velocities_host(self.id, velocities_new.__ctype__())
+            self.runtime.core.wp_mesh_set_velocities_host(self.id, velocities_new.__ctype__())
        else:
-            self.runtime.core.mesh_set_velocities_device(self.id, velocities_new.__ctype__())
+            self.runtime.core.wp_mesh_set_velocities_device(self.id, velocities_new.__ctype__())
            self.runtime.verify_cuda_device(self.device)
 
@@ -3912,11 +4357,11 @@ class Volume:
 
        owner = False
        if self.device.is_cpu:
-            self.id = self.runtime.core.volume_create_host(
+            self.id = self.runtime.core.wp_volume_create_host(
                ctypes.cast(data.ptr, ctypes.c_void_p), data.size, copy, owner
            )
        else:
-            self.id = self.runtime.core.volume_create_device(
+            self.id = self.runtime.core.wp_volume_create_device(
                self.device.context, ctypes.cast(data.ptr, ctypes.c_void_p), data.size, copy, owner
            )
 
@@ -3928,18 +4373,18 @@
            return
 
        if self.device.is_cpu:
-            self.runtime.core.volume_destroy_host(self.id)
+            self.runtime.core.wp_volume_destroy_host(self.id)
        else:
            # use CUDA context guard to avoid side effects during garbage collection
            with self.device.context_guard:
-                self.runtime.core.volume_destroy_device(self.id)
+                self.runtime.core.wp_volume_destroy_device(self.id)
 
    def array(self) -> array:
        """Return the raw memory buffer of the :class:`Volume` as an array."""
 
        buf = ctypes.c_void_p(0)
        size = ctypes.c_uint64(0)
-        self.runtime.core.volume_get_buffer_info(self.id, ctypes.byref(buf), ctypes.byref(size))
+        self.runtime.core.wp_volume_get_buffer_info(self.id, ctypes.byref(buf), ctypes.byref(size))
        return array(ptr=buf.value, dtype=uint8, shape=size.value, device=self.device)
 
    def get_tile_count(self) -> int:
@@ -3949,7 +4394,9 @@
            ctypes.c_uint64(0),
            ctypes.c_uint32(0),
        )
-        self.runtime.core.volume_get_tile_and_voxel_count(self.id, ctypes.byref(tile_count), ctypes.byref(voxel_count))
+        self.runtime.core.wp_volume_get_tile_and_voxel_count(
+            self.id, ctypes.byref(tile_count), ctypes.byref(voxel_count)
+        )
        return tile_count.value
 
    def get_tiles(self, out: array | None = None) -> array:
@@ -3976,9 +4423,9 @@
        )
 
        if self.device.is_cpu:
-            self.runtime.core.volume_get_tiles_host(self.id, out.ptr)
+            self.runtime.core.wp_volume_get_tiles_host(self.id, out.ptr)
        else:
-            self.runtime.core.volume_get_tiles_device(self.id, out.ptr)
+            self.runtime.core.wp_volume_get_tiles_device(self.id, out.ptr)
 
        return out
 
@@ -3989,7 +4436,9 @@
            ctypes.c_uint64(0),
            ctypes.c_uint32(0),
        )
-        self.runtime.core.volume_get_tile_and_voxel_count(self.id, ctypes.byref(tile_count), ctypes.byref(voxel_count))
+        self.runtime.core.wp_volume_get_tile_and_voxel_count(
+            self.id, ctypes.byref(tile_count), ctypes.byref(voxel_count)
+        )
        return voxel_count.value
 
    def get_voxels(self, out: array | None = None) -> array:
@@ -4015,9 +4464,9 @@
        )
 
        if self.device.is_cpu:
-            self.runtime.core.volume_get_voxels_host(self.id, out.ptr)
+            self.runtime.core.wp_volume_get_voxels_host(self.id, out.ptr)
        else:
-            self.runtime.core.volume_get_voxels_device(self.id, out.ptr)
+            self.runtime.core.wp_volume_get_voxels_device(self.id, out.ptr)
 
        return out
 
@@ -4028,7 +4477,7 @@
            raise RuntimeError("Invalid Volume")
 
        dx, dy, dz = ctypes.c_float(0), ctypes.c_float(0), ctypes.c_float(0)
-        self.runtime.core.volume_get_voxel_size(self.id, ctypes.byref(dx), ctypes.byref(dy), ctypes.byref(dz))
+        self.runtime.core.wp_volume_get_voxel_size(self.id, ctypes.byref(dx), ctypes.byref(dy), ctypes.byref(dz))
        return (dx.value, dy.value, dz.value)
 
    class GridInfo(NamedTuple):
@@ -4061,7 +4510,7 @@
        transform_buffer = (ctypes.c_float * 9)()
        type_str_buffer = (ctypes.c_char * 16)()
 
-        name = self.runtime.core.volume_get_grid_info(
+        name = self.runtime.core.wp_volume_get_grid_info(
            self.id,
            ctypes.byref(grid_size),
            ctypes.byref(grid_index),
@@ -4124,7 +4573,7 @@
    def get_feature_array_count(self) -> int:
        """Return the number of supplemental data arrays stored alongside the grid"""
 
-        return self.runtime.core.volume_get_blind_data_count(self.id)
+        return self.runtime.core.wp_volume_get_blind_data_count(self.id)
 
    class FeatureArrayInfo(NamedTuple):
        """Metadata for a supplemental data array"""
@@ -4149,7 +4598,7 @@
        value_size = ctypes.c_uint32(0)
        type_str_buffer = (ctypes.c_char * 16)()
 
-        name = self.runtime.core.volume_get_blind_data_info(
+        name = self.runtime.core.wp_volume_get_blind_data_info(
            self.id,
            feature_index,
            ctypes.byref(buf),
@@ -4425,7 +4874,7 @@
        # (to allow this we would need to ref-count the volume descriptor)
        existing_buf = ctypes.c_void_p(0)
        existing_size = ctypes.c_uint64(0)
-        warp.context.runtime.core.volume_get_buffer_info(
+        warp.context.runtime.core.wp_volume_get_buffer_info(
            grid_ptr, ctypes.byref(existing_buf), ctypes.byref(existing_size)
        )
 
@@ -4678,7 +5127,7 @@
        transform_buf, translation_buf = Volume._fill_transform_buffers(voxel_size, translation, transform)
 
        if bg_value is None:
-            volume.id = volume.runtime.core.volume_index_from_tiles_device(
+            volume.id = volume.runtime.core.wp_volume_index_from_tiles_device(
                volume.device.context,
                ctypes.c_void_p(tile_points.ptr),
                tile_points.shape[0],
@@ -4717,7 +5166,7 @@
        cvalue_size = ctypes.sizeof(cvalue)
        cvalue_type = nvdb_type.encode("ascii")
 
-        volume.id = volume.runtime.core.volume_from_tiles_device(
+        volume.id = volume.runtime.core.wp_volume_from_tiles_device(
            volume.device.context,
            ctypes.c_void_p(tile_points.ptr),
            tile_points.shape[0],
@@ -4779,7 +5228,7 @@
 
        transform_buf, translation_buf = Volume._fill_transform_buffers(voxel_size, translation, transform)
 
-        volume.id = volume.runtime.core.volume_from_active_voxels_device(
+        volume.id = volume.runtime.core.wp_volume_from_active_voxels_device(
            volume.device.context,
            ctypes.c_void_p(voxel_points.ptr),
            voxel_points.shape[0],
@@ -5019,9 +5468,9 @@ class HashGrid:
        self.device = self.runtime.get_device(device)
 
        if self.device.is_cpu:
-            self.id = self.runtime.core.hash_grid_create_host(dim_x, dim_y, dim_z)
+            self.id = self.runtime.core.wp_hash_grid_create_host(dim_x, dim_y, dim_z)
        else:
-            self.id = self.runtime.core.hash_grid_create_device(self.device.context, dim_x, dim_y, dim_z)
+            self.id = self.runtime.core.wp_hash_grid_create_device(self.device.context, dim_x, dim_y, dim_z)
 
        # indicates whether the grid data has been reserved for use by a kernel
        self.reserved = False
@@ -5046,16 +5495,16 @@
        points = points.contiguous().flatten()
 
        if self.device.is_cpu:
-            self.runtime.core.hash_grid_update_host(self.id, radius, ctypes.byref(points.__ctype__()))
+            self.runtime.core.wp_hash_grid_update_host(self.id, radius, ctypes.byref(points.__ctype__()))
        else:
-            self.runtime.core.hash_grid_update_device(self.id, radius, ctypes.byref(points.__ctype__()))
+            self.runtime.core.wp_hash_grid_update_device(self.id, radius, ctypes.byref(points.__ctype__()))
        self.reserved = True
 
    def reserve(self, num_points):
        if self.device.is_cpu:
-            self.runtime.core.hash_grid_reserve_host(self.id, num_points)
+            self.runtime.core.wp_hash_grid_reserve_host(self.id, num_points)
        else:
-            self.runtime.core.hash_grid_reserve_device(self.id, num_points)
+            self.runtime.core.wp_hash_grid_reserve_device(self.id, num_points)
        self.reserved = True
 
    def __del__(self):
@@ -5063,166 +5512,11 @@ class HashGrid:
            return
 
        if self.device.is_cpu:
-            self.runtime.core.hash_grid_destroy_host(self.id)
+            self.runtime.core.wp_hash_grid_destroy_host(self.id)
        else:
            # use CUDA context guard to avoid side effects during garbage collection
            with self.device.context_guard:
-                self.runtime.core.hash_grid_destroy_device(self.id)
-
-
-class MarchingCubes:
-    def __new__(cls, *args, **kwargs):
-        instance = super().__new__(cls)
-        instance.id = None
-        return instance
-
-    def __init__(self, nx: int, ny: int, nz: int, max_verts: int, max_tris: int, device=None):
-        """CUDA-based Marching Cubes algorithm to extract a 2D surface mesh from a 3D volume.
-
-        Attributes:
-            id: Unique identifier for this object.
-            verts (:class:`warp.array`): Array of vertex positions of type :class:`warp.vec3f`
-                for the output surface mesh.
-                This is populated after running :func:`surface`.
-            indices (:class:`warp.array`): Array containing indices of type :class:`warp.int32`
-                defining triangles for the output surface mesh.
-                This is populated after running :func:`surface`.
-
-                Each set of three consecutive integers in the array represents a single triangle,
-                in which each integer is an index referring to a vertex in the :attr:`verts` array.
-
-        Args:
-            nx: Number of cubes in the x-direction.
-            ny: Number of cubes in the y-direction.
-            nz: Number of cubes in the z-direction.
-            max_verts: Maximum expected number of vertices (used for array preallocation).
-            max_tris: Maximum expected number of triangles (used for array preallocation).
-            device (Devicelike): CUDA device on which to run marching cubes and allocate memory.
-
-        Raises:
-            RuntimeError: ``device`` not a CUDA device.
-
-        .. note::
-            The shape of the marching cubes should match the shape of the scalar field being surfaced.
-
-        """
-
-        self.id = 0
-
-        self.runtime = warp.context.runtime
-
-        self.device = self.runtime.get_device(device)
-
-        if not self.device.is_cuda:
-            raise RuntimeError("Only CUDA devices are supported for marching cubes")
-
-        self.nx = nx
-        self.ny = ny
-        self.nz = nz
-
-        self.max_verts = max_verts
-        self.max_tris = max_tris
-
-        # bindings to warp.so
-        self.alloc = self.runtime.core.marching_cubes_create_device
-        self.alloc.argtypes = [ctypes.c_void_p]
-        self.alloc.restype = ctypes.c_uint64
-        self.free = self.runtime.core.marching_cubes_destroy_device
-
-        from warp.context import zeros
-
-        self.verts = zeros(max_verts, dtype=vec3, device=self.device)
-        self.indices = zeros(max_tris * 3, dtype=warp.int32, device=self.device)
-
-        # alloc surfacer
-        self.id = ctypes.c_uint64(self.alloc(self.device.context))
-
-    def __del__(self):
-        if not self.id:
-            return
-
-        # use CUDA context guard to avoid side effects during garbage collection
-        with self.device.context_guard:
-            # destroy surfacer
-            self.free(self.id)
-
-    def resize(self, nx: int, ny: int, nz: int, max_verts: int, max_tris: int) -> None:
-        """Update the expected input and maximum output sizes for the marching cubes calculation.
-
-        This function has no immediate effect on the underlying buffers.
-        The new values take effect on the next :func:`surface` call.
-
-        Args:
-            nx: Number of cubes in the x-direction.
-            ny: Number of cubes in the y-direction.
-            nz: Number of cubes in the z-direction.
-            max_verts: Maximum expected number of vertices (used for array preallocation).
-            max_tris: Maximum expected number of triangles (used for array preallocation).
-        """
-        # actual allocations will be resized on next call to surface()
-        self.nx = nx
-        self.ny = ny
-        self.nz = nz
-        self.max_verts = max_verts
-        self.max_tris = max_tris
-
-    def surface(self, field: array(dtype=float, ndim=3), threshold: float) -> None:
-        """Compute a 2D surface mesh of a given isosurface from a 3D scalar field.
-
-        The triangles and vertices defining the output mesh are written to the
-        :attr:`indices` and :attr:`verts` arrays.
-
-        Args:
-            field: Scalar field from which to generate a mesh.
-            threshold: Target isosurface value.
-
-        Raises:
-            ValueError: ``field`` is not a 3D array.
-            ValueError: Marching cubes shape does not match the shape of ``field``.
-            RuntimeError: :attr:`max_verts` and/or :attr:`max_tris` might be too small to hold the surface mesh.
-        """
-
-        # WP_API int marching_cubes_surface_host(const float* field, int nx, int ny, int nz, float threshold, wp::vec3* verts, int* triangles, int max_verts, int max_tris, int* out_num_verts, int* out_num_tris);
-        num_verts = ctypes.c_int(0)
-        num_tris = ctypes.c_int(0)
-
-        self.runtime.core.marching_cubes_surface_device.restype = ctypes.c_int
-
-        # For now we require that input field shape matches nx, ny, nz
-        if field.ndim != 3:
-            raise ValueError(f"Input field must be a three-dimensional array (got {field.ndim}).")
-        if field.shape[0] != self.nx or field.shape[1] != self.ny or field.shape[2] != self.nz:
-            raise ValueError(
-                f"Marching cubes shape ({self.nx}, {self.ny}, {self.nz}) does not match the "
-                f"input array shape {field.shape}."
-            )
-
-        error = self.runtime.core.marching_cubes_surface_device(
-            self.id,
-            ctypes.cast(field.ptr, ctypes.c_void_p),
-            self.nx,
-            self.ny,
-            self.nz,
-            ctypes.c_float(threshold),
-            ctypes.cast(self.verts.ptr, ctypes.c_void_p),
-            ctypes.cast(self.indices.ptr, ctypes.c_void_p),
-            self.max_verts,
-            self.max_tris,
-            ctypes.c_void_p(ctypes.addressof(num_verts)),
-            ctypes.c_void_p(ctypes.addressof(num_tris)),
-        )
-
-        if error:
-            raise RuntimeError(
-                f"Buffers may not be large enough, marching cubes required at least {num_verts} vertices, and {num_tris} triangles."
-            )
-
-        # resize the geometry arrays
-        self.verts.shape = (num_verts.value,)
-        self.indices.shape = (num_tris.value * 3,)
-
-        self.verts.size = num_verts.value
-        self.indices.size = num_tris.value * 3
+                self.runtime.core.wp_hash_grid_destroy_device(self.id)
 
 
 generic_types = (Any, Scalar, Float, Int)
@@ -5389,6 +5683,16 @@ def get_type_code(arg_type: type) -> str:
         # special case for generics
         # note: since Python 3.11 Any is a type, so we check for it first
         return "?"
+    elif (
+        sys.version_info < (3, 11)
+        and hasattr(types, "GenericAlias")
+        and isinstance(arg_type, types.GenericAlias)
+        and arg_type.__origin__ is tuple
+    ):
+        # Handle tuple[...] on Python <= 3.10 where it creates types.GenericAlias
+        # This must come before isinstance(arg_type, type) check
+        arg_types = arg_type.__args__
+        return f"tpl{len(arg_types)}{''.join(get_type_code(x) for x in arg_types)}"
     elif isinstance(arg_type, type):
         if hasattr(arg_type, "_wp_scalar_type_"):
            # vector/matrix type
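Note: this new branch gives `tuple[...]` annotations a stable type code on Python 3.10 and earlier, where `tuple[int, float]` produces a `types.GenericAlias` rather than a `typing` object (on 3.11+ the existing path handles it). A sketch, with exact per-element codes left unspecified:

    import sys
    from warp.types import get_type_code, int32, float32

    if sys.version_info < (3, 11):
        code = get_type_code(tuple[int32, float32])
        assert code.startswith("tpl2")   # "tpl" + arity + element type codes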