warp-lang 1.2.2-py3-none-manylinux2014_aarch64.whl → 1.3.0-py3-none-manylinux2014_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of warp-lang has been flagged as a potentially problematic release; see the registry's advisory for details.

Files changed (193)
  1. warp/__init__.py +8 -6
  2. warp/autograd.py +823 -0
  3. warp/bin/warp.so +0 -0
  4. warp/build.py +6 -2
  5. warp/builtins.py +1410 -886
  6. warp/codegen.py +503 -166
  7. warp/config.py +48 -18
  8. warp/context.py +400 -198
  9. warp/dlpack.py +8 -0
  10. warp/examples/assets/bunny.usd +0 -0
  11. warp/examples/benchmarks/benchmark_cloth_warp.py +1 -1
  12. warp/examples/benchmarks/benchmark_interop_torch.py +158 -0
  13. warp/examples/benchmarks/benchmark_launches.py +1 -1
  14. warp/examples/core/example_cupy.py +78 -0
  15. warp/examples/fem/example_apic_fluid.py +17 -36
  16. warp/examples/fem/example_burgers.py +9 -18
  17. warp/examples/fem/example_convection_diffusion.py +7 -17
  18. warp/examples/fem/example_convection_diffusion_dg.py +27 -47
  19. warp/examples/fem/example_deformed_geometry.py +11 -22
  20. warp/examples/fem/example_diffusion.py +7 -18
  21. warp/examples/fem/example_diffusion_3d.py +24 -28
  22. warp/examples/fem/example_diffusion_mgpu.py +7 -14
  23. warp/examples/fem/example_magnetostatics.py +190 -0
  24. warp/examples/fem/example_mixed_elasticity.py +111 -80
  25. warp/examples/fem/example_navier_stokes.py +30 -34
  26. warp/examples/fem/example_nonconforming_contact.py +290 -0
  27. warp/examples/fem/example_stokes.py +17 -32
  28. warp/examples/fem/example_stokes_transfer.py +12 -21
  29. warp/examples/fem/example_streamlines.py +350 -0
  30. warp/examples/fem/utils.py +936 -0
  31. warp/fabric.py +5 -2
  32. warp/fem/__init__.py +13 -3
  33. warp/fem/cache.py +161 -11
  34. warp/fem/dirichlet.py +37 -28
  35. warp/fem/domain.py +105 -14
  36. warp/fem/field/__init__.py +14 -3
  37. warp/fem/field/field.py +454 -11
  38. warp/fem/field/nodal_field.py +33 -18
  39. warp/fem/geometry/deformed_geometry.py +50 -15
  40. warp/fem/geometry/hexmesh.py +12 -24
  41. warp/fem/geometry/nanogrid.py +106 -31
  42. warp/fem/geometry/quadmesh_2d.py +6 -11
  43. warp/fem/geometry/tetmesh.py +103 -61
  44. warp/fem/geometry/trimesh_2d.py +98 -47
  45. warp/fem/integrate.py +231 -186
  46. warp/fem/operator.py +14 -9
  47. warp/fem/quadrature/pic_quadrature.py +35 -9
  48. warp/fem/quadrature/quadrature.py +119 -32
  49. warp/fem/space/basis_space.py +98 -22
  50. warp/fem/space/collocated_function_space.py +3 -1
  51. warp/fem/space/function_space.py +7 -2
  52. warp/fem/space/grid_2d_function_space.py +3 -3
  53. warp/fem/space/grid_3d_function_space.py +4 -4
  54. warp/fem/space/hexmesh_function_space.py +3 -2
  55. warp/fem/space/nanogrid_function_space.py +12 -14
  56. warp/fem/space/partition.py +45 -47
  57. warp/fem/space/restriction.py +19 -16
  58. warp/fem/space/shape/cube_shape_function.py +91 -3
  59. warp/fem/space/shape/shape_function.py +7 -0
  60. warp/fem/space/shape/square_shape_function.py +32 -0
  61. warp/fem/space/shape/tet_shape_function.py +11 -7
  62. warp/fem/space/shape/triangle_shape_function.py +10 -1
  63. warp/fem/space/topology.py +116 -42
  64. warp/fem/types.py +8 -1
  65. warp/fem/utils.py +301 -83
  66. warp/native/array.h +16 -0
  67. warp/native/builtin.h +0 -15
  68. warp/native/cuda_util.cpp +14 -6
  69. warp/native/exports.h +1348 -1308
  70. warp/native/quat.h +79 -0
  71. warp/native/rand.h +27 -4
  72. warp/native/sparse.cpp +83 -81
  73. warp/native/sparse.cu +381 -453
  74. warp/native/vec.h +64 -0
  75. warp/native/volume.cpp +40 -49
  76. warp/native/volume_builder.cu +2 -3
  77. warp/native/volume_builder.h +12 -17
  78. warp/native/warp.cu +3 -3
  79. warp/native/warp.h +69 -59
  80. warp/render/render_opengl.py +17 -9
  81. warp/sim/articulation.py +117 -17
  82. warp/sim/collide.py +35 -29
  83. warp/sim/model.py +123 -18
  84. warp/sim/render.py +3 -1
  85. warp/sparse.py +867 -203
  86. warp/stubs.py +312 -541
  87. warp/tape.py +29 -1
  88. warp/tests/disabled_kinematics.py +1 -1
  89. warp/tests/test_adam.py +1 -1
  90. warp/tests/test_arithmetic.py +1 -1
  91. warp/tests/test_array.py +58 -1
  92. warp/tests/test_array_reduce.py +1 -1
  93. warp/tests/test_async.py +1 -1
  94. warp/tests/test_atomic.py +1 -1
  95. warp/tests/test_bool.py +1 -1
  96. warp/tests/test_builtins_resolution.py +1 -1
  97. warp/tests/test_bvh.py +6 -1
  98. warp/tests/test_closest_point_edge_edge.py +1 -1
  99. warp/tests/test_codegen.py +66 -1
  100. warp/tests/test_compile_consts.py +1 -1
  101. warp/tests/test_conditional.py +1 -1
  102. warp/tests/test_copy.py +1 -1
  103. warp/tests/test_ctypes.py +1 -1
  104. warp/tests/test_dense.py +1 -1
  105. warp/tests/test_devices.py +1 -1
  106. warp/tests/test_dlpack.py +1 -1
  107. warp/tests/test_examples.py +33 -4
  108. warp/tests/test_fabricarray.py +5 -2
  109. warp/tests/test_fast_math.py +1 -1
  110. warp/tests/test_fem.py +213 -6
  111. warp/tests/test_fp16.py +1 -1
  112. warp/tests/test_func.py +1 -1
  113. warp/tests/test_future_annotations.py +90 -0
  114. warp/tests/test_generics.py +1 -1
  115. warp/tests/test_grad.py +1 -1
  116. warp/tests/test_grad_customs.py +1 -1
  117. warp/tests/test_grad_debug.py +247 -0
  118. warp/tests/test_hash_grid.py +6 -1
  119. warp/tests/test_implicit_init.py +354 -0
  120. warp/tests/test_import.py +1 -1
  121. warp/tests/test_indexedarray.py +1 -1
  122. warp/tests/test_intersect.py +1 -1
  123. warp/tests/test_jax.py +1 -1
  124. warp/tests/test_large.py +1 -1
  125. warp/tests/test_launch.py +1 -1
  126. warp/tests/test_lerp.py +1 -1
  127. warp/tests/test_linear_solvers.py +1 -1
  128. warp/tests/test_lvalue.py +1 -1
  129. warp/tests/test_marching_cubes.py +5 -2
  130. warp/tests/test_mat.py +34 -35
  131. warp/tests/test_mat_lite.py +2 -1
  132. warp/tests/test_mat_scalar_ops.py +1 -1
  133. warp/tests/test_math.py +1 -1
  134. warp/tests/test_matmul.py +20 -16
  135. warp/tests/test_matmul_lite.py +1 -1
  136. warp/tests/test_mempool.py +1 -1
  137. warp/tests/test_mesh.py +5 -2
  138. warp/tests/test_mesh_query_aabb.py +1 -1
  139. warp/tests/test_mesh_query_point.py +1 -1
  140. warp/tests/test_mesh_query_ray.py +1 -1
  141. warp/tests/test_mlp.py +1 -1
  142. warp/tests/test_model.py +1 -1
  143. warp/tests/test_module_hashing.py +77 -1
  144. warp/tests/test_modules_lite.py +1 -1
  145. warp/tests/test_multigpu.py +1 -1
  146. warp/tests/test_noise.py +1 -1
  147. warp/tests/test_operators.py +1 -1
  148. warp/tests/test_options.py +1 -1
  149. warp/tests/test_overwrite.py +542 -0
  150. warp/tests/test_peer.py +1 -1
  151. warp/tests/test_pinned.py +1 -1
  152. warp/tests/test_print.py +1 -1
  153. warp/tests/test_quat.py +15 -1
  154. warp/tests/test_rand.py +1 -1
  155. warp/tests/test_reload.py +1 -1
  156. warp/tests/test_rounding.py +1 -1
  157. warp/tests/test_runlength_encode.py +1 -1
  158. warp/tests/test_scalar_ops.py +95 -0
  159. warp/tests/test_sim_grad.py +1 -1
  160. warp/tests/test_sim_kinematics.py +1 -1
  161. warp/tests/test_smoothstep.py +1 -1
  162. warp/tests/test_sparse.py +82 -15
  163. warp/tests/test_spatial.py +1 -1
  164. warp/tests/test_special_values.py +2 -11
  165. warp/tests/test_streams.py +11 -1
  166. warp/tests/test_struct.py +1 -1
  167. warp/tests/test_tape.py +1 -1
  168. warp/tests/test_torch.py +194 -1
  169. warp/tests/test_transient_module.py +1 -1
  170. warp/tests/test_types.py +1 -1
  171. warp/tests/test_utils.py +1 -1
  172. warp/tests/test_vec.py +15 -63
  173. warp/tests/test_vec_lite.py +2 -1
  174. warp/tests/test_vec_scalar_ops.py +65 -1
  175. warp/tests/test_verify_fp.py +1 -1
  176. warp/tests/test_volume.py +28 -2
  177. warp/tests/test_volume_write.py +1 -1
  178. warp/tests/unittest_serial.py +1 -1
  179. warp/tests/unittest_suites.py +9 -1
  180. warp/tests/walkthrough_debug.py +1 -1
  181. warp/thirdparty/unittest_parallel.py +2 -5
  182. warp/torch.py +103 -41
  183. warp/types.py +341 -224
  184. warp/utils.py +11 -2
  185. {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/METADATA +99 -46
  186. warp_lang-1.3.0.dist-info/RECORD +368 -0
  187. warp/examples/fem/bsr_utils.py +0 -378
  188. warp/examples/fem/mesh_utils.py +0 -133
  189. warp/examples/fem/plot_utils.py +0 -292
  190. warp_lang-1.2.2.dist-info/RECORD +0 -359
  191. {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/LICENSE.md +0 -0
  192. {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/WHEEL +0 -0
  193. {warp_lang-1.2.2.dist-info → warp_lang-1.3.0.dist-info}/top_level.txt +0 -0
warp/types.py CHANGED
@@ -12,7 +12,7 @@ import ctypes
 import inspect
 import struct
 import zlib
-from typing import Any, Callable, Generic, List, NamedTuple, Optional, Tuple, TypeVar, Union
+from typing import Any, Callable, Generic, List, NamedTuple, Optional, Sequence, Tuple, TypeVar, Union
 
 import numpy as np
 
@@ -50,6 +50,15 @@ class Array(Generic[DType]):
     pass
 
 
+int_tuple_type_hints = {
+    Tuple[int]: 1,
+    Tuple[int, int]: 2,
+    Tuple[int, int, int]: 3,
+    Tuple[int, int, int, int]: 4,
+    Tuple[int, ...]: -1,
+}
+
+
 def constant(x):
     """Function to declare compile-time constants accessible from Warp kernels
 
@@ -99,6 +108,7 @@ def vector(length, dtype):
         # warp scalar type:
         _wp_scalar_type_ = dtype
         _wp_type_params_ = [length, dtype]
+        _wp_type_args_ = {"length": length, "dtype": dtype}
         _wp_generic_type_str_ = "vec_t"
         _wp_generic_type_hint_ = Vector
         _wp_constructor_ = "vector"
@@ -282,6 +292,7 @@ def matrix(shape, dtype):
         # used in type checking and when writing out c++ code for constructors:
         _wp_scalar_type_ = dtype
         _wp_type_params_ = [shape[0], shape[1], dtype]
+        _wp_type_args_ = {"shape": (shape[0], shape[1]), "dtype": dtype}
         _wp_generic_type_str_ = "mat_t"
         _wp_generic_type_hint_ = Matrix
         _wp_constructor_ = "matrix"
@@ -471,233 +482,130 @@ class void:
     pass
 
 
-class bool:
-    _length_ = 1
-    _type_ = ctypes.c_bool
-
-    def __init__(self, x=False):
+class scalar_base:
+    def __init__(self, x=0):
         self.value = x
 
     def __bool__(self) -> builtins.bool:
         return self.value != 0
 
-    def __float__(self) -> float:
-        return float(self.value != 0)
-
-    def __int__(self) -> int:
-        return int(self.value != 0)
-
-
-class float16:
-    _length_ = 1
-    _type_ = ctypes.c_uint16
-
-    def __init__(self, x=0.0):
-        self.value = x
-
-    def __bool__(self) -> bool:
-        return self.value != 0.0
-
     def __float__(self) -> float:
         return float(self.value)
 
     def __int__(self) -> int:
         return int(self.value)
 
+    def __add__(self, y):
+        return warp.add(self, y)
 
-class float32:
-    _length_ = 1
-    _type_ = ctypes.c_float
+    def __radd__(self, y):
+        return warp.add(y, self)
 
-    def __init__(self, x=0.0):
-        self.value = x
+    def __sub__(self, y):
+        return warp.sub(self, y)
 
-    def __bool__(self) -> bool:
-        return self.value != 0.0
+    def __rsub__(self, y):
+        return warp.sub(y, self)
 
-    def __float__(self) -> float:
-        return float(self.value)
+    def __mul__(self, y):
+        return warp.mul(self, y)
 
-    def __int__(self) -> int:
-        return int(self.value)
+    def __rmul__(self, x):
+        return warp.mul(x, self)
 
+    def __truediv__(self, y):
+        return warp.div(self, y)
 
-class float64:
-    _length_ = 1
-    _type_ = ctypes.c_double
+    def __rtruediv__(self, x):
+        return warp.div(x, self)
 
-    def __init__(self, x=0.0):
-        self.value = x
+    def __pos__(self):
+        return warp.pos(self)
 
-    def __bool__(self) -> bool:
-        return self.value != 0.0
+    def __neg__(self):
+        return warp.neg(self)
 
-    def __float__(self) -> float:
-        return float(self.value)
 
-    def __int__(self) -> int:
-        return int(self.value)
-
-
-class int8:
-    _length_ = 1
-    _type_ = ctypes.c_int8
-
-    def __init__(self, x=0):
-        self.value = x
-
-    def __bool__(self) -> bool:
-        return self.value != 0
+class float_base(scalar_base):
+    pass
 
-    def __float__(self) -> float:
-        return float(self.value)
-
-    def __int__(self) -> int:
-        return int(self.value)
 
+class int_base(scalar_base):
     def __index__(self) -> int:
         return int(self.value)
 
 
-class uint8:
+class bool:
     _length_ = 1
-    _type_ = ctypes.c_uint8
+    _type_ = ctypes.c_bool
 
-    def __init__(self, x=0):
+    def __init__(self, x=False):
         self.value = x
 
-    def __bool__(self) -> bool:
+    def __bool__(self) -> builtins.bool:
         return self.value != 0
 
     def __float__(self) -> float:
-        return float(self.value)
+        return float(self.value != 0)
 
     def __int__(self) -> int:
-        return int(self.value)
-
-    def __index__(self) -> int:
-        return int(self.value)
+        return int(self.value != 0)
 
 
-class int16:
+class float16(float_base):
     _length_ = 1
-    _type_ = ctypes.c_int16
-
-    def __init__(self, x=0):
-        self.value = x
-
-    def __bool__(self) -> bool:
-        return self.value != 0
-
-    def __float__(self) -> float:
-        return float(self.value)
+    _type_ = ctypes.c_uint16
 
-    def __int__(self) -> int:
-        return int(self.value)
 
-    def __index__(self) -> int:
-        return int(self.value)
+class float32(float_base):
+    _length_ = 1
+    _type_ = ctypes.c_float
 
 
-class uint16:
+class float64(float_base):
     _length_ = 1
-    _type_ = ctypes.c_uint16
+    _type_ = ctypes.c_double
 
-    def __init__(self, x=0):
-        self.value = x
 
-    def __bool__(self) -> bool:
-        return self.value != 0
-
-    def __float__(self) -> float:
-        return float(self.value)
+class int8(int_base):
+    _length_ = 1
+    _type_ = ctypes.c_int8
 
-    def __int__(self) -> int:
-        return int(self.value)
 
-    def __index__(self) -> int:
-        return int(self.value)
+class uint8(int_base):
+    _length_ = 1
+    _type_ = ctypes.c_uint8
 
 
-class int32:
+class int16(int_base):
     _length_ = 1
-    _type_ = ctypes.c_int32
+    _type_ = ctypes.c_int16
 
-    def __init__(self, x=0):
-        self.value = x
 
-    def __bool__(self) -> bool:
-        return self.value != 0
+class uint16(int_base):
+    _length_ = 1
+    _type_ = ctypes.c_uint16
 
-    def __float__(self) -> float:
-        return float(self.value)
 
-    def __int__(self) -> int:
-        return int(self.value)
-
-    def __index__(self) -> int:
-        return int(self.value)
+class int32(int_base):
+    _length_ = 1
+    _type_ = ctypes.c_int32
 
 
-class uint32:
+class uint32(int_base):
     _length_ = 1
     _type_ = ctypes.c_uint32
 
-    def __init__(self, x=0):
-        self.value = x
-
-    def __bool__(self) -> bool:
-        return self.value != 0
-
-    def __float__(self) -> float:
-        return float(self.value)
-
-    def __int__(self) -> int:
-        return int(self.value)
-
-    def __index__(self) -> int:
-        return int(self.value)
 
-
-class int64:
+class int64(int_base):
     _length_ = 1
     _type_ = ctypes.c_int64
 
-    def __init__(self, x=0):
-        self.value = x
-
-    def __bool__(self) -> bool:
-        return self.value != 0
 
-    def __float__(self) -> float:
-        return float(self.value)
-
-    def __int__(self) -> int:
-        return int(self.value)
-
-    def __index__(self) -> int:
-        return int(self.value)
-
-
-class uint64:
+class uint64(int_base):
     _length_ = 1
     _type_ = ctypes.c_uint64
 
-    def __init__(self, x=0):
-        self.value = x
-
-    def __bool__(self) -> bool:
-        return self.value != 0
-
-    def __float__(self) -> float:
-        return float(self.value)
-
-    def __int__(self) -> int:
-        return int(self.value)
-
-    def __index__(self) -> int:
-        return int(self.value)
-
 
 
 def quaternion(dtype=Any):
     class quat_t(vector(length=4, dtype=dtype)):
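
The hunk above collapses the per-type boilerplate into a scalar_base hierarchy whose Python operators forward to the corresponding Warp builtins (warp.add, warp.sub, warp.mul, warp.div, warp.pos, warp.neg). A minimal sketch of what this enables on the host side, assuming the usual top-level re-exports of the scalar types:

    import warp as wp

    x = wp.float32(2.0)
    y = wp.float32(3.5)

    # arithmetic on scalar values now dispatches to the Warp builtins
    z = x * y + wp.float32(1.0)
    print(float(z))  # 8.0

    # integer scalars inherit the same operators via int_base -> scalar_base
    i = wp.int32(7) - wp.int32(2)
    print(int(i))  # 5
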
@@ -707,6 +615,7 @@ def quaternion(dtype=Any):
 
     ret = quat_t
     ret._wp_type_params_ = [dtype]
+    ret._wp_type_args_ = {"dtype": dtype}
     ret._wp_generic_type_str_ = "quat_t"
     ret._wp_generic_type_hint_ = Quaternion
     ret._wp_constructor_ = "quaternion"
@@ -743,6 +652,7 @@ def transformation(dtype=Any):
             ),
         )
         _wp_type_params_ = [dtype]
+        _wp_type_args_ = {"dtype": dtype}
         _wp_generic_type_str_ = "transform_t"
         _wp_generic_type_hint_ = Transformation
         _wp_constructor_ = "transformation"
@@ -1150,6 +1060,9 @@ class bvh_query_t:
     pass
 
 
+BvhQuery = bvh_query_t
+
+
 # definition just for kernel type (cannot be a parameter), see mesh.h
 class mesh_query_aabb_t:
     """Object used to track state during mesh traversal."""
@@ -1158,6 +1071,9 @@ class mesh_query_aabb_t:
     pass
 
 
+MeshQueryAABB = mesh_query_aabb_t
+
+
 # definition just for kernel type (cannot be a parameter), see hash_grid.h
 class hash_grid_query_t:
     """Object used to track state during neighbor traversal."""
@@ -1166,6 +1082,9 @@ class hash_grid_query_t:
     pass
 
 
+HashGridQuery = hash_grid_query_t
+
+
 # maximum number of dimensions, must match array.h
 ARRAY_MAX_DIMS = 4
 LAUNCH_MAX_DIMS = 4
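
The kernel-side query types gain PascalCase aliases alongside the original snake_case names. A quick sketch of the equivalence; the assignments above are module level, so they are reachable through warp.types (whether they are also re-exported at the top level depends on the warp/__init__.py changes in this release):

    from warp import types as wpt

    assert wpt.BvhQuery is wpt.bvh_query_t
    assert wpt.MeshQueryAABB is wpt.mesh_query_aabb_t
    assert wpt.HashGridQuery is wpt.hash_grid_query_t

MeshQueryPoint and MeshQueryRay receive the same treatment further down in this file.
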
@@ -1378,7 +1297,8 @@ def type_repr(t):
     if t in scalar_types:
         return t.__name__
 
-    return t.__module__ + "." + t.__qualname__
+    name = getattr(t, "__qualname__", t.__name__)
+    return t.__module__ + "." + name
 
 
 def type_is_int(t):
@@ -1400,6 +1320,11 @@ def type_is_vector(t):
     return getattr(t, "_wp_generic_type_hint_", None) is Vector
 
 
+# returns True if the passed *type* is a quaternion
+def type_is_quaternion(t):
+    return getattr(t, "_wp_generic_type_hint_", None) is Quaternion
+
+
 # returns True if the passed *type* is a matrix
 def type_is_matrix(t):
     return getattr(t, "_wp_generic_type_hint_", None) is Matrix
@@ -1432,9 +1357,30 @@ def is_array(a):
 
 
 def scalars_equal(a, b, match_generic):
+    # convert to canonical types
+    if a == float:
+        a = float32
+    elif a == int:
+        a = int32
+    elif a == builtins.bool:
+        a = bool
+
+    if b == float:
+        b = float32
+    elif b == int:
+        b = int32
+    elif b == builtins.bool:
+        b = bool
+
     if match_generic:
         if a == Any or b == Any:
             return True
+        if a == Int and b in int_types:
+            return True
+        if b == Int and a in int_types:
+            return True
+        if a == Int and b == Int:
+            return True
         if a == Scalar and b in scalar_and_bool_types:
             return True
         if b == Scalar and a in scalar_and_bool_types:
@@ -1448,25 +1394,29 @@ def scalars_equal(a, b, match_generic):
         if a == Float and b == Float:
             return True
 
-    # convert to canonical types
-    if a == float:
-        a = float32
-    elif a == int:
-        a = int32
-    elif a == builtins.bool:
-        a = bool
-
-    if b == float:
-        b = float32
-    elif b == int:
-        b = int32
-    elif b == builtins.bool:
-        b = bool
-
     return a == b
 
 
 def types_equal(a, b, match_generic=False):
+    if match_generic:
+        if a in int_tuple_type_hints and isinstance(b, Sequence):
+            a_length = int_tuple_type_hints[a]
+            if (a_length == -1 or a_length == len(b)) and all(
+                scalars_equal(x, Int, match_generic=match_generic) for x in b
+            ):
+                return True
+        if b in int_tuple_type_hints and isinstance(a, Sequence):
+            b_length = int_tuple_type_hints[b]
+            if (b_length == -1 or b_length == len(a)) and all(
+                scalars_equal(x, Int, match_generic=match_generic) for x in a
+            ):
+                return True
+        if a in int_tuple_type_hints and b in int_tuple_type_hints:
+            a_length = int_tuple_type_hints[a]
+            b_length = int_tuple_type_hints[b]
+            if a_length is None or b_length is None or a_length == b_length:
+                return True
+
     # convert to canonical types
     if a == float:
         a = float32
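
Together with the int_tuple_type_hints table added near the top of the file, types_equal can now match Tuple[int, ...]-style annotations against concrete integer sequences. A rough illustration of the behavior, inferred from the diff (these are internal warp.types helpers, not official API):

    from typing import Tuple

    import warp as wp
    from warp.types import types_equal

    # a fixed-length hint matches any same-length sequence of integer types
    print(types_equal(Tuple[int, int], (wp.int32, wp.int64), match_generic=True))  # True

    # the variadic hint matches a sequence of any length
    print(types_equal(Tuple[int, ...], (wp.int32,), match_generic=True))  # True
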
@@ -1522,6 +1472,61 @@ def check_array_shape(shape: Tuple):
         )
 
 
+def array_ctype_from_interface(interface: dict, dtype=None, owner=None):
+    """Get native array descriptor (array_t) from __array_interface__ or __cuda_array_interface__ dictionary"""
+
+    ptr = interface.get("data")[0]
+    shape = interface.get("shape")
+    strides = interface.get("strides")
+    typestr = interface.get("typestr")
+
+    element_dtype = dtype_from_numpy(np.dtype(typestr))
+
+    if strides is None:
+        strides = strides_from_shape(shape, element_dtype)
+
+    if dtype is None:
+        # accept verbatum
+        pass
+    elif hasattr(dtype, "_shape_"):
+        # vector/matrix types, ensure element dtype matches
+        if element_dtype != dtype._wp_scalar_type_:
+            raise RuntimeError(
+                f"Could not convert array interface with typestr='{typestr}' to Warp array with dtype={dtype}"
+            )
+        dtype_shape = dtype._shape_
+        dtype_dims = len(dtype._shape_)
+        ctype_size = ctypes.sizeof(dtype._type_)
+        # ensure inner shape matches
+        if dtype_dims > len(shape) or dtype_shape != shape[-dtype_dims:]:
+            raise RuntimeError(
+                f"Could not convert array interface with shape {shape} to Warp array with dtype={dtype}, ensure that source inner shape is {dtype_shape}"
+            )
+        # ensure inner strides are contiguous
+        if strides[-1] != ctype_size or (dtype_dims > 1 and strides[-2] != ctype_size * dtype_shape[-1]):
+            raise RuntimeError(
+                f"Could not convert array interface with shape {shape} to Warp array with dtype={dtype}, because the source inner strides are not contiguous"
+            )
+        # trim shape and strides
+        shape = tuple(shape[:-dtype_dims]) or (1,)
+        strides = tuple(strides[:-dtype_dims]) or (ctype_size,)
+    else:
+        # scalar types, ensure dtype matches
+        if element_dtype != dtype:
+            raise RuntimeError(
+                f"Could not convert array interface with typestr='{typestr}' to Warp array with dtype={dtype}"
+            )
+
+    # create array descriptor
+    array_ctype = array_t(ptr, 0, len(shape), shape, strides)
+
+    # keep owner alive
+    if owner is not None:
+        array_ctype._ref = owner
+
+    return array_ctype
+
+
 class array(Array):
     # member attributes available during code-gen (e.g.: d = array.shape[0])
     # (initialized when needed)
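
array_ctype_from_interface builds a native array_t descriptor straight from the standard __array_interface__ / __cuda_array_interface__ protocol, presumably what the updated dlpack/torch interop paths in this release build on. A hedged sketch of direct use with a NumPy source:

    import numpy as np
    import warp as wp
    from warp.types import array_ctype_from_interface

    src = np.arange(12, dtype=np.float32).reshape(3, 4)

    # passing the source as `owner` keeps its memory alive while the descriptor is in use
    desc = array_ctype_from_interface(src.__array_interface__, dtype=wp.float32, owner=src)
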
@@ -1631,6 +1636,9 @@ class array(Array):
         else:
             self._init_annotation(dtype, ndim or 1)
 
+        # initialize read flag
+        self.mark_init()
+
         # initialize gradient, if needed
         if self.device is not None:
             if grad is not None:
@@ -1642,6 +1650,9 @@ class array(Array):
             if requires_grad:
                 self._alloc_grad()
 
+        # reference to other array
+        self._ref = None
+
     def _init_from_data(self, data, dtype, shape, device, copy, pinned):
         if not hasattr(data, "__len__"):
             raise RuntimeError(f"Data must be a sequence or array, got scalar {data}")
@@ -2164,6 +2175,9 @@ class array(Array):
         """
         Enables A @ B syntax for matrix multiplication
         """
+        if not is_array(other):
+            return NotImplemented
+
         if self.ndim != 2 or other.ndim != 2:
             raise RuntimeError(
                 "A has dim = {}, B has dim = {}. If multiplying with @, A and B must have dim = 2.".format(
@@ -2234,6 +2248,33 @@ class array(Array):
         array._vars = {"shape": warp.codegen.Var("shape", shape_t)}
         return array._vars
 
+    def mark_init(self):
+        """Resets this array's read flag"""
+        self._is_read = False
+
+    def mark_read(self):
+        """Marks this array as having been read from in a kernel or recorded function on the tape."""
+        # no additional checks required: it is always safe to set an array to READ
+        self._is_read = True
+
+        # recursively update all parent arrays
+        parent = self._ref
+        while parent is not None:
+            parent._is_read = True
+            parent = parent._ref
+
+    def mark_write(self, **kwargs):
+        """Detect if we are writing to an array that has already been read from"""
+        if self._is_read:
+            if "arg_name" and "kernel_name" and "filename" and "lineno" in kwargs:
+                print(
+                    f"Warning: Array {self} passed to argument {kwargs['arg_name']} in kernel {kwargs['kernel_name']} at {kwargs['filename']}:{kwargs['lineno']} is being written to but has already been read from in a previous launch. This may corrupt gradient computation in the backward pass."
+                )
+            else:
+                print(
+                    f"Warning: Array {self} is being written to but has already been read from in a previous launch. This may corrupt gradient computation in the backward pass."
+                )
+
     def zero_(self):
         """Zeroes-out the array entries."""
         if self.is_contiguous:
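
These mark_* hooks drive the new wp.config.verify_autograd_array_access check (exercised by the new test_overwrite.py in the file list): writing to an array that an earlier launch read from can corrupt the backward pass, so it is reported. A minimal sketch of the failure mode it is meant to flag, assuming the flag is enabled before the launches are recorded:

    import warp as wp

    wp.config.verify_autograd_array_access = True

    @wp.kernel
    def square(x: wp.array(dtype=float), y: wp.array(dtype=float)):
        tid = wp.tid()
        y[tid] = x[tid] * x[tid]

    x = wp.ones(8, dtype=float, requires_grad=True)
    y = wp.zeros(8, dtype=float, requires_grad=True)

    tape = wp.Tape()
    with tape:
        wp.launch(square, dim=8, inputs=[x], outputs=[y])  # x is marked as read
        wp.launch(square, dim=8, inputs=[y], outputs=[x])  # x written after being read: warning expected
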
@@ -2241,6 +2282,7 @@ class array(Array):
             self.device.memset(self.ptr, 0, self.size * type_size_in_bytes(self.dtype))
         else:
             self.fill_(0)
+        self.mark_init()
 
     def fill_(self, value):
         """Set all array entries to `value`
@@ -2315,6 +2357,8 @@ class array(Array):
         else:
             warp.context.runtime.core.array_fill_host(carr_ptr, ARRAY_TYPE_REGULAR, cvalue_ptr, cvalue_size)
 
+        self.mark_init()
+
     def assign(self, src):
         """Wraps ``src`` in an :class:`warp.array` if it is not already one and copies the contents to ``self``."""
         if is_array(src):
@@ -2421,6 +2465,9 @@ class array(Array):
             grad=None if self.grad is None else self.grad.flatten(),
         )
 
+        # transfer read flag
+        a._is_read = self._is_read
+
         # store back-ref to stop data being destroyed
         a._ref = self
         return a
@@ -2482,6 +2529,9 @@ class array(Array):
             grad=None if self.grad is None else self.grad.reshape(shape),
         )
 
+        # transfer read flag
+        a._is_read = self._is_read
+
         # store back-ref to stop data being destroyed
         a._ref = self
         return a
@@ -2505,6 +2555,9 @@ class array(Array):
             grad=None if self.grad is None else self.grad.view(dtype),
        )
 
+        # transfer read flag
+        a._is_read = self._is_read
+
         a._ref = self
         return a
 
@@ -2558,6 +2611,9 @@ class array(Array):
 
         a.is_transposed = not self.is_transposed
 
+        # transfer read flag
+        a._is_read = self._is_read
+
         a._ref = self
         return a
 
@@ -2841,6 +2897,11 @@ def array_type_id(a):
 
 
 class Bvh:
+    def __new__(cls, *args, **kwargs):
+        instance = super(Bvh, cls).__new__(cls)
+        instance.id = None
+        return instance
+
     def __init__(self, lowers, uppers):
         """Class representing a bounding volume hierarchy.
 
@@ -2853,8 +2914,6 @@ class Bvh:
             uppers (:class:`warp.array`): Array of upper bounds :class:`warp.vec3`
         """
 
-        self.id = 0
-
         if len(lowers) != len(uppers):
             raise RuntimeError("Bvh the same number of lower and upper bounds must be provided")
 
@@ -2916,6 +2975,11 @@ class Mesh:
         "indices": Var("indices", array(dtype=int32)),
     }
 
+    def __new__(cls, *args, **kwargs):
+        instance = super(Mesh, cls).__new__(cls)
+        instance.id = None
+        return instance
+
     def __init__(self, points=None, indices=None, velocities=None, support_winding_number=False):
         """Class representing a triangle mesh.
 
@@ -2930,8 +2994,6 @@ class Mesh:
             support_winding_number (bool): If true the mesh will build additional datastructures to support `wp.mesh_query_point_sign_winding_number()` queries
         """
 
-        self.id = 0
-
         if points.device != indices.device:
             raise RuntimeError("Mesh points and indices must live on the same device")
 
@@ -3001,6 +3063,11 @@ class Volume:
     #: Enum value to specify trilinear interpolation during sampling
     LINEAR = constant(1)
 
+    def __new__(cls, *args, **kwargs):
+        instance = super(Volume, cls).__new__(cls)
+        instance.id = None
+        return instance
+
     def __init__(self, data: array, copy: bool = True):
         """Class representing a sparse grid.
 
@@ -3009,8 +3076,6 @@ class Volume:
            copy (bool): Whether the incoming data will be copied or aliased
         """
 
-        self.id = 0
-
         # keep a runtime reference for orderly destruction
         self.runtime = warp.context.runtime
 
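The recurring __new__ override in Bvh, Mesh, Volume (and HashGrid/MarchingCubes below) is an implicit-initialization guard, matching the new test_implicit_init.py: id is set to None before __init__ runs, so a constructor that raises partway through still leaves an object whose destructor can bail out safely. The generic shape of the idiom, with hypothetical allocate/free helpers standing in for Warp's native calls:

    class Resource:
        def __new__(cls, *args, **kwargs):
            # guarantee `id` exists even if __init__ raises
            instance = super().__new__(cls)
            instance.id = None
            return instance

        def __init__(self, size):
            if size <= 0:
                raise ValueError("size must be positive")  # __del__ still sees id == None
            self.id = allocate_native_resource(size)  # hypothetical native allocation

        def __del__(self):
            if self.id is not None:
                free_native_resource(self.id)  # hypothetical native release
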
@@ -3568,9 +3633,39 @@ class Volume:
 
         return cls.allocate_by_tiles(tile_points, voxel_size, bg_value, translation, device)
 
+    @staticmethod
+    def _fill_transform_buffers(
+        voxel_size: Union[float, List[float]],
+        translation,
+        transform,
+    ):
+        if transform is None:
+            if voxel_size is None:
+                raise ValueError("Either 'voxel_size' or 'transform' must be provided")
+
+            if isinstance(voxel_size, float):
+                voxel_size = (voxel_size, voxel_size, voxel_size)
+            transform = mat33f(voxel_size[0], 0.0, 0.0, 0.0, voxel_size[1], 0.0, 0.0, 0.0, voxel_size[2])
+        else:
+            if voxel_size is not None:
+                raise ValueError("Only one of 'voxel_size' or 'transform' must be provided")
+
+            if not isinstance(transform, mat33f):
+                transform = mat33f(transform)
+
+        transform_buf = (ctypes.c_float * 9).from_buffer_copy(transform)
+        translation_buf = (ctypes.c_float * 3)(translation[0], translation[1], translation[2])
+        return transform_buf, translation_buf
+
     @classmethod
     def allocate_by_tiles(
-        cls, tile_points: array, voxel_size: float, bg_value=0.0, translation=(0.0, 0.0, 0.0), device=None
+        cls,
+        tile_points: array,
+        voxel_size: Union[float, List[float]] = None,
+        bg_value=0.0,
+        translation=(0.0, 0.0, 0.0),
+        device=None,
+        transform=None,
     ) -> Volume:
         """Allocate a new Volume with active tiles for each point tile_points.
 
@@ -3588,16 +3683,15 @@ class Volume:
                 The array may use an integer scalar type (2D N-by-3 array of :class:`warp.int32` or 1D array of `warp.vec3i` values), indicating index space positions,
                 or a floating point scalar type (2D N-by-3 array of :class:`warp.float32` or 1D array of `warp.vec3f` values), indicating world space positions.
                 Repeated points per tile are allowed and will be efficiently deduplicated.
-            voxel_size (float): Voxel size of the new volume.
+            voxel_size (float or array-like): Voxel size(s) of the new volume. Ignored if `transform` is given.
             bg_value (array-like, float, int or None): Value of unallocated voxels of the volume, also defines the volume's type. A :class:`warp.vec3` volume is created if this is `array-like`, an index volume will be created if `bg_value` is ``None``.
             translation (array-like): Translation between the index and world spaces.
+            transform (array-like): Linear transform between the index and world spaces. If ``None``, deduced from `voxel_size`.
             device (Devicelike): The CUDA device to create the volume on, e.g.: "cuda" or "cuda:0".
 
         """
         device = warp.get_device(device)
 
-        if voxel_size <= 0.0:
-            raise RuntimeError(f"Voxel size must be positive! Got {voxel_size}")
         if not device.is_cuda:
             raise RuntimeError("Only CUDA devices are supported for allocate_by_tiles")
         if not _is_contiguous_vec_like_array(tile_points, vec_length=3, scalar_types=(float32, int32)):
@@ -3610,15 +3704,16 @@ class Volume:
         volume = cls(data=None)
         volume.device = device
         in_world_space = type_scalar_type(tile_points.dtype) == float32
+
+        transform_buf, translation_buf = Volume._fill_transform_buffers(voxel_size, translation, transform)
+
         if bg_value is None:
             volume.id = volume.runtime.core.volume_index_from_tiles_device(
                 volume.device.context,
                 ctypes.c_void_p(tile_points.ptr),
                 tile_points.shape[0],
-                voxel_size,
-                translation[0],
-                translation[1],
-                translation[2],
+                transform_buf,
+                translation_buf,
                 in_world_space,
             )
         elif hasattr(bg_value, "__len__"):
@@ -3626,38 +3721,30 @@ class Volume:
                 volume.device.context,
                 ctypes.c_void_p(tile_points.ptr),
                 tile_points.shape[0],
-                voxel_size,
-                bg_value[0],
-                bg_value[1],
-                bg_value[2],
-                translation[0],
-                translation[1],
-                translation[2],
+                transform_buf,
+                translation_buf,
                 in_world_space,
+                (ctypes.c_float * 3)(bg_value[0], bg_value[1], bg_value[2]),
             )
         elif isinstance(bg_value, int):
             volume.id = volume.runtime.core.volume_i_from_tiles_device(
                 volume.device.context,
                 ctypes.c_void_p(tile_points.ptr),
                 tile_points.shape[0],
-                voxel_size,
-                bg_value,
-                translation[0],
-                translation[1],
-                translation[2],
+                transform_buf,
+                translation_buf,
                 in_world_space,
+                bg_value,
             )
         else:
             volume.id = volume.runtime.core.volume_f_from_tiles_device(
                 volume.device.context,
                 ctypes.c_void_p(tile_points.ptr),
                 tile_points.shape[0],
-                voxel_size,
-                float(bg_value),
-                translation[0],
-                translation[1],
-                translation[2],
+                transform_buf,
+                translation_buf,
                 in_world_space,
+                float(bg_value),
             )
 
         if volume.id == 0:
@@ -3667,7 +3754,12 @@ class Volume:
 
     @classmethod
     def allocate_by_voxels(
-        cls, voxel_points: array, voxel_size: float, translation=(0.0, 0.0, 0.0), device=None
+        cls,
+        voxel_points: array,
+        voxel_size: Union[float, List[float]] = None,
+        translation=(0.0, 0.0, 0.0),
+        device=None,
+        transform=None,
     ) -> Volume:
         """Allocate a new Volume with active voxel for each point voxel_points.
 
@@ -3682,19 +3774,16 @@ class Volume:
                 The array may use an integer scalar type (2D N-by-3 array of :class:`warp.int32` or 1D array of `warp.vec3i` values), indicating index space positions,
                 or a floating point scalar type (2D N-by-3 array of :class:`warp.float32` or 1D array of `warp.vec3f` values), indicating world space positions.
                 Repeated points per tile are allowed and will be efficiently deduplicated.
-            voxel_size (float): Voxel size of the new volume.
+            voxel_size (float or array-like): Voxel size(s) of the new volume. Ignored if `transform` is given.
             translation (array-like): Translation between the index and world spaces.
+            transform (array-like): Linear transform between the index and world spaces. If ``None``, deduced from `voxel_size`.
             device (Devicelike): The CUDA device to create the volume on, e.g.: "cuda" or "cuda:0".
 
         """
         device = warp.get_device(device)
 
-        if voxel_size <= 0.0:
-            raise RuntimeError(f"Voxel size must be positive! Got {voxel_size}")
         if not device.is_cuda:
             raise RuntimeError("Only CUDA devices are supported for allocate_by_tiles")
-        if not (is_array(voxel_points) and voxel_points.is_contiguous):
-            raise RuntimeError("tile_points must be a contiguous array")
         if not _is_contiguous_vec_like_array(voxel_points, vec_length=3, scalar_types=(float32, int32)):
             raise RuntimeError(
                 "voxel_points must be contiguous and either a 1D warp array of vec3f or vec3i or a 2D n-by-3 array of int32 or float32."
@@ -3706,14 +3795,14 @@ class Volume:
         volume = cls(data=None)
         volume.device = device
         in_world_space = type_scalar_type(voxel_points.dtype) == float32
+        transform_buf, translation_buf = Volume._fill_transform_buffers(voxel_size, translation, transform)
+
         volume.id = volume.runtime.core.volume_from_active_voxels_device(
             volume.device.context,
             ctypes.c_void_p(voxel_points.ptr),
             voxel_points.shape[0],
-            voxel_size,
-            translation[0],
-            translation[1],
-            translation[2],
+            transform_buf,
+            translation_buf,
             in_world_space,
         )
 
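With the scalar voxel_size checks replaced by _fill_transform_buffers, both allocators accept per-axis voxel sizes or a full 3x3 index-to-world transform (the two options are mutually exclusive). A hedged usage sketch, assuming a CUDA device and integer tile coordinates:

    import numpy as np
    import warp as wp

    # two active tiles given as index-space coordinates (2D N-by-3 int32 form)
    tiles = wp.array(np.array([[0, 0, 0], [16, 0, 0]], dtype=np.int32), device="cuda:0")

    # anisotropic voxels via an array-like voxel_size...
    vol = wp.Volume.allocate_by_tiles(tiles, voxel_size=[0.1, 0.1, 0.5], bg_value=0.0, device="cuda:0")

    # ...or an explicit linear transform instead of voxel_size
    xform = wp.mat33f(0.1, 0.0, 0.0, 0.0, 0.1, 0.0, 0.0, 0.0, 0.5)
    vol2 = wp.Volume.allocate_by_tiles(tiles, transform=xform, bg_value=0.0, device="cuda:0")
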
@@ -3765,6 +3854,9 @@ class mesh_query_point_t:
     }
 
 
+MeshQueryPoint = mesh_query_point_t
+
+
 # definition just for kernel type (cannot be a parameter), see mesh.h
 # NOTE: its layout must match the corresponding struct defined in C.
 class mesh_query_ray_t:
@@ -3796,6 +3888,9 @@ class mesh_query_ray_t:
     }
 
 
+MeshQueryRay = mesh_query_ray_t
+
+
 def matmul(
     a: array2d,
     b: array2d,
@@ -3852,10 +3947,16 @@ def matmul(
             backward=lambda: adj_matmul(a, b, c, a.grad, b.grad, c.grad, d.grad, alpha, beta, allow_tf32x3_arith),
             arrays=[a, b, c, d],
         )
+        if warp.config.verify_autograd_array_access:
+            d.mark_write()
+            a.mark_read()
+            b.mark_read()
+            c.mark_read()
 
     # cpu fallback if no cuda devices found
     if device == "cpu":
-        d.assign(alpha * (a.numpy() @ b.numpy()) + beta * c.numpy())
+        np_dtype = warp_type_to_np_dtype[a.dtype]
+        d.assign(alpha * np.matmul(a.numpy(), b.numpy(), dtype=np_dtype) + beta * c.numpy())
         return
 
     cc = device.arch
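
matmul now participates in the same read/write tracking as kernel launches when wp.config.verify_autograd_array_access is on, and its CPU fallback pins the NumPy accumulation dtype to the operands' own dtype. A short usage sketch, assuming float32 operands:

    import warp as wp

    wp.config.verify_autograd_array_access = True

    a = wp.zeros((4, 8), dtype=wp.float32, requires_grad=True)
    b = wp.zeros((8, 4), dtype=wp.float32, requires_grad=True)
    c = wp.zeros((4, 4), dtype=wp.float32, requires_grad=True)
    d = wp.zeros((4, 4), dtype=wp.float32, requires_grad=True)

    tape = wp.Tape()
    with tape:
        wp.matmul(a, b, c, d)  # marks a, b, c as read and d as written, per the code above
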
@@ -3971,8 +4072,9 @@ def adj_matmul(
 
     # cpu fallback if no cuda devices found
     if device == "cpu":
-        adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose()) + adj_a.numpy())
-        adj_b.assign(alpha * (a.numpy().transpose() @ adj_d.numpy()) + adj_b.numpy())
+        np_dtype = warp_type_to_np_dtype[a.dtype]
+        adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose(), dtype=np_dtype) + adj_a.numpy())
+        adj_b.assign(alpha * np.matmul(a.numpy().transpose(), adj_d.numpy(), dtype=np_dtype) + adj_b.numpy())
         adj_c.assign(beta * adj_d.numpy() + adj_c.numpy())
         return
 
@@ -4135,10 +4237,16 @@ def batched_matmul(
             ),
             arrays=[a, b, c, d],
         )
+        if warp.config.verify_autograd_array_access:
+            d.mark_write()
+            a.mark_read()
+            b.mark_read()
+            c.mark_read()
 
     # cpu fallback if no cuda devices found
     if device == "cpu":
-        d.assign(alpha * np.matmul(a.numpy(), b.numpy()) + beta * c.numpy())
+        np_dtype = warp_type_to_np_dtype[a.dtype]
+        d.assign(alpha * np.matmul(a.numpy(), b.numpy(), dtype=np_dtype) + beta * c.numpy())
         return
 
     # handle case in which batch_count exceeds max_batch_count, which is a CUDA array size maximum
@@ -4282,8 +4390,9 @@ def adj_batched_matmul(
 
     # cpu fallback if no cuda devices found
     if device == "cpu":
-        adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose((0, 2, 1))) + adj_a.numpy())
-        adj_b.assign(alpha * np.matmul(a.numpy().transpose((0, 2, 1)), adj_d.numpy()) + adj_b.numpy())
+        np_dtype = warp_type_to_np_dtype[a.dtype]
+        adj_a.assign(alpha * np.matmul(adj_d.numpy(), b.numpy().transpose((0, 2, 1)), dtype=np_dtype) + adj_a.numpy())
+        adj_b.assign(alpha * np.matmul(a.numpy().transpose((0, 2, 1)), adj_d.numpy(), dtype=np_dtype) + adj_b.numpy())
         adj_c.assign(beta * adj_d.numpy() + adj_c.numpy())
         return
 
@@ -4487,6 +4596,11 @@ def adj_batched_matmul(
 
 
 class HashGrid:
+    def __new__(cls, *args, **kwargs):
+        instance = super(HashGrid, cls).__new__(cls)
+        instance.id = None
+        return instance
+
     def __init__(self, dim_x, dim_y, dim_z, device=None):
         """Class representing a hash grid object for accelerated point queries.
 
@@ -4500,8 +4614,6 @@ class HashGrid:
             dim_z (int): Number of cells in z-axis
         """
 
-        self.id = 0
-
         self.runtime = warp.context.runtime
 
         self.device = self.runtime.get_device(device)
@@ -4559,6 +4671,11 @@ class HashGrid:
 
 
 class MarchingCubes:
+    def __new__(cls, *args, **kwargs):
+        instance = super(MarchingCubes, cls).__new__(cls)
+        instance.id = None
+        return instance
+
     def __init__(self, nx: int, ny: int, nz: int, max_verts: int, max_tris: int, device=None):
         """CUDA-based Marching Cubes algorithm to extract a 2D surface mesh from a 3D volume.