warp-lang 1.0.0b5-py3-none-manylinux2014_x86_64.whl → 1.0.0b6-py3-none-manylinux2014_x86_64.whl

This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Files changed (187)
  1. docs/conf.py +3 -4
  2. examples/env/env_ant.py +1 -1
  3. examples/env/env_cartpole.py +1 -1
  4. examples/env/env_humanoid.py +1 -1
  5. examples/example_dem.py +28 -26
  6. examples/example_diffray.py +37 -30
  7. examples/example_fluid.py +7 -3
  8. examples/example_jacobian_ik.py +1 -1
  9. examples/example_mesh_intersect.py +10 -7
  10. examples/example_nvdb.py +3 -3
  11. examples/example_render_opengl.py +19 -10
  12. examples/example_sim_cartpole.py +9 -5
  13. examples/example_sim_cloth.py +29 -25
  14. examples/example_sim_fk_grad.py +2 -2
  15. examples/example_sim_fk_grad_torch.py +3 -3
  16. examples/example_sim_grad_bounce.py +11 -8
  17. examples/example_sim_grad_cloth.py +12 -9
  18. examples/example_sim_granular.py +2 -2
  19. examples/example_sim_granular_collision_sdf.py +13 -13
  20. examples/example_sim_neo_hookean.py +3 -3
  21. examples/example_sim_particle_chain.py +2 -2
  22. examples/example_sim_quadruped.py +8 -5
  23. examples/example_sim_rigid_chain.py +8 -5
  24. examples/example_sim_rigid_contact.py +13 -10
  25. examples/example_sim_rigid_fem.py +2 -2
  26. examples/example_sim_rigid_gyroscopic.py +2 -2
  27. examples/example_sim_rigid_kinematics.py +1 -1
  28. examples/example_sim_trajopt.py +3 -2
  29. examples/fem/example_apic_fluid.py +5 -7
  30. examples/fem/example_diffusion_mgpu.py +18 -16
  31. warp/__init__.py +3 -2
  32. warp/bin/warp.so +0 -0
  33. warp/build_dll.py +29 -9
  34. warp/builtins.py +206 -7
  35. warp/codegen.py +58 -38
  36. warp/config.py +3 -1
  37. warp/context.py +234 -128
  38. warp/fem/__init__.py +2 -2
  39. warp/fem/cache.py +2 -1
  40. warp/fem/field/nodal_field.py +18 -17
  41. warp/fem/geometry/hexmesh.py +11 -6
  42. warp/fem/geometry/quadmesh_2d.py +16 -12
  43. warp/fem/geometry/tetmesh.py +19 -8
  44. warp/fem/geometry/trimesh_2d.py +18 -7
  45. warp/fem/integrate.py +341 -196
  46. warp/fem/quadrature/__init__.py +1 -1
  47. warp/fem/quadrature/pic_quadrature.py +138 -53
  48. warp/fem/quadrature/quadrature.py +81 -9
  49. warp/fem/space/__init__.py +1 -1
  50. warp/fem/space/basis_space.py +169 -51
  51. warp/fem/space/grid_2d_function_space.py +2 -2
  52. warp/fem/space/grid_3d_function_space.py +2 -2
  53. warp/fem/space/hexmesh_function_space.py +2 -2
  54. warp/fem/space/partition.py +9 -6
  55. warp/fem/space/quadmesh_2d_function_space.py +2 -2
  56. warp/fem/space/shape/cube_shape_function.py +27 -15
  57. warp/fem/space/shape/square_shape_function.py +29 -18
  58. warp/fem/space/tetmesh_function_space.py +2 -2
  59. warp/fem/space/topology.py +10 -0
  60. warp/fem/space/trimesh_2d_function_space.py +2 -2
  61. warp/fem/utils.py +10 -5
  62. warp/native/array.h +49 -8
  63. warp/native/builtin.h +31 -14
  64. warp/native/cuda_util.cpp +8 -3
  65. warp/native/cuda_util.h +1 -0
  66. warp/native/exports.h +1177 -1108
  67. warp/native/intersect.h +4 -4
  68. warp/native/intersect_adj.h +8 -8
  69. warp/native/mat.h +65 -6
  70. warp/native/mesh.h +126 -5
  71. warp/native/quat.h +28 -4
  72. warp/native/vec.h +76 -14
  73. warp/native/warp.cu +1 -6
  74. warp/render/render_opengl.py +261 -109
  75. warp/sim/import_mjcf.py +13 -7
  76. warp/sim/import_urdf.py +14 -14
  77. warp/sim/inertia.py +17 -18
  78. warp/sim/model.py +67 -67
  79. warp/sim/render.py +1 -1
  80. warp/sparse.py +6 -6
  81. warp/stubs.py +19 -81
  82. warp/tape.py +1 -1
  83. warp/tests/__main__.py +3 -6
  84. warp/tests/{test_class_kernel.py → aux_test_class_kernel.py} +9 -1
  85. warp/tests/aux_test_conditional_unequal_types_kernels.py +21 -0
  86. warp/tests/{test_dependent.py → aux_test_dependent.py} +2 -2
  87. warp/tests/{test_reference.py → aux_test_reference.py} +1 -1
  88. warp/tests/aux_test_unresolved_func.py +14 -0
  89. warp/tests/aux_test_unresolved_symbol.py +14 -0
  90. warp/tests/{test_kinematics.py → disabled_kinematics.py} +10 -12
  91. warp/tests/run_coverage_serial.py +31 -0
  92. warp/tests/test_adam.py +102 -106
  93. warp/tests/test_arithmetic.py +39 -40
  94. warp/tests/test_array.py +46 -48
  95. warp/tests/test_array_reduce.py +25 -19
  96. warp/tests/test_atomic.py +62 -26
  97. warp/tests/test_bool.py +16 -11
  98. warp/tests/test_builtins_resolution.py +1292 -0
  99. warp/tests/test_bvh.py +9 -12
  100. warp/tests/test_closest_point_edge_edge.py +53 -57
  101. warp/tests/test_codegen.py +164 -134
  102. warp/tests/test_compile_consts.py +13 -19
  103. warp/tests/test_conditional.py +30 -32
  104. warp/tests/test_copy.py +9 -12
  105. warp/tests/test_ctypes.py +90 -98
  106. warp/tests/test_dense.py +20 -14
  107. warp/tests/test_devices.py +34 -35
  108. warp/tests/test_dlpack.py +74 -75
  109. warp/tests/test_examples.py +215 -97
  110. warp/tests/test_fabricarray.py +15 -21
  111. warp/tests/test_fast_math.py +14 -11
  112. warp/tests/test_fem.py +280 -97
  113. warp/tests/test_fp16.py +19 -15
  114. warp/tests/test_func.py +177 -194
  115. warp/tests/test_generics.py +71 -77
  116. warp/tests/test_grad.py +83 -32
  117. warp/tests/test_grad_customs.py +7 -9
  118. warp/tests/test_hash_grid.py +6 -10
  119. warp/tests/test_import.py +9 -23
  120. warp/tests/test_indexedarray.py +19 -21
  121. warp/tests/test_intersect.py +15 -9
  122. warp/tests/test_large.py +17 -19
  123. warp/tests/test_launch.py +14 -17
  124. warp/tests/test_lerp.py +63 -63
  125. warp/tests/test_lvalue.py +84 -35
  126. warp/tests/test_marching_cubes.py +9 -13
  127. warp/tests/test_mat.py +388 -3004
  128. warp/tests/test_mat_lite.py +9 -12
  129. warp/tests/test_mat_scalar_ops.py +2889 -0
  130. warp/tests/test_math.py +10 -11
  131. warp/tests/test_matmul.py +104 -100
  132. warp/tests/test_matmul_lite.py +72 -98
  133. warp/tests/test_mesh.py +35 -32
  134. warp/tests/test_mesh_query_aabb.py +18 -25
  135. warp/tests/test_mesh_query_point.py +39 -23
  136. warp/tests/test_mesh_query_ray.py +9 -21
  137. warp/tests/test_mlp.py +8 -9
  138. warp/tests/test_model.py +89 -93
  139. warp/tests/test_modules_lite.py +15 -25
  140. warp/tests/test_multigpu.py +87 -114
  141. warp/tests/test_noise.py +10 -12
  142. warp/tests/test_operators.py +14 -21
  143. warp/tests/test_options.py +10 -11
  144. warp/tests/test_pinned.py +16 -18
  145. warp/tests/test_print.py +16 -20
  146. warp/tests/test_quat.py +121 -88
  147. warp/tests/test_rand.py +12 -13
  148. warp/tests/test_reload.py +27 -32
  149. warp/tests/test_rounding.py +7 -10
  150. warp/tests/test_runlength_encode.py +105 -106
  151. warp/tests/test_smoothstep.py +8 -9
  152. warp/tests/test_snippet.py +13 -22
  153. warp/tests/test_sparse.py +30 -29
  154. warp/tests/test_spatial.py +179 -174
  155. warp/tests/test_streams.py +100 -107
  156. warp/tests/test_struct.py +98 -67
  157. warp/tests/test_tape.py +11 -17
  158. warp/tests/test_torch.py +89 -86
  159. warp/tests/test_transient_module.py +9 -12
  160. warp/tests/test_types.py +328 -50
  161. warp/tests/test_utils.py +217 -218
  162. warp/tests/test_vec.py +133 -2133
  163. warp/tests/test_vec_lite.py +8 -11
  164. warp/tests/test_vec_scalar_ops.py +2099 -0
  165. warp/tests/test_volume.py +391 -382
  166. warp/tests/test_volume_write.py +122 -135
  167. warp/tests/unittest_serial.py +35 -0
  168. warp/tests/unittest_suites.py +291 -0
  169. warp/tests/{test_base.py → unittest_utils.py} +138 -25
  170. warp/tests/{test_misc.py → unused_test_misc.py} +13 -5
  171. warp/tests/{test_debug.py → walkthough_debug.py} +2 -15
  172. warp/thirdparty/unittest_parallel.py +257 -54
  173. warp/types.py +119 -98
  174. warp/utils.py +14 -0
  175. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/METADATA +2 -1
  176. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/RECORD +182 -178
  177. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/WHEEL +1 -1
  178. warp/tests/test_all.py +0 -239
  179. warp/tests/test_conditional_unequal_types_kernels.py +0 -14
  180. warp/tests/test_coverage.py +0 -38
  181. warp/tests/test_unresolved_func.py +0 -7
  182. warp/tests/test_unresolved_symbol.py +0 -7
  183. /warp/tests/{test_compile_consts_dummy.py → aux_test_compile_consts_dummy.py} +0 -0
  184. /warp/tests/{test_reference_reference.py → aux_test_reference_reference.py} +0 -0
  185. /warp/tests/{test_square.py → aux_test_square.py} +0 -0
  186. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/LICENSE.md +0 -0
  187. {warp_lang-1.0.0b5.dist-info → warp_lang-1.0.0b6.dist-info}/top_level.txt +0 -0
warp/fem/space/shape/square_shape_function.py CHANGED
@@ -25,6 +25,7 @@ class SquareBipolynomialShapeFunctions:
         self.LOBATTO_COORDS = wp.constant(NodeVec(lobatto_coords))
         self.LOBATTO_WEIGHT = wp.constant(NodeVec(lobatto_weight))
         self.LAGRANGE_SCALE = wp.constant(NodeVec(lagrange_scale))
+        self.ORDER_PLUS_ONE = wp.constant(self.ORDER + 1)

     @property
     def name(self) -> str:
@@ -93,13 +94,21 @@ class SquareBipolynomialShapeFunctions:
         ):
             return 0.5

+        def trace_node_quadrature_weight_open(
+            node_index_in_elt: int,
+        ):
+            return 0.0
+
+        if not is_closed(self.family):
+            return cache.get_func(trace_node_quadrature_weight_open, self.name)
+
         if ORDER == 1:
             return cache.get_func(trace_node_quadrature_weight_linear, self.name)

         return cache.get_func(trace_node_quadrature_weight, self.name)

     def make_element_inner_weight(self):
-        ORDER = self.ORDER
+        ORDER_PLUS_ONE = self.ORDER_PLUS_ONE
         LOBATTO_COORDS = self.LOBATTO_COORDS
         LAGRANGE_SCALE = self.LAGRANGE_SCALE

@@ -107,11 +116,11 @@ class SquareBipolynomialShapeFunctions:
             coords: Coords,
             node_index_in_elt: int,
         ):
-            node_i = node_index_in_elt // (ORDER + 1)
-            node_j = node_index_in_elt - (ORDER + 1) * node_i
+            node_i = node_index_in_elt // ORDER_PLUS_ONE
+            node_j = node_index_in_elt - ORDER_PLUS_ONE * node_i

             w = float(1.0)
-            for k in range(ORDER + 1):
+            for k in range(ORDER_PLUS_ONE):
                 if k != node_i:
                     w *= coords[0] - LOBATTO_COORDS[k]
                 if k != node_j:
@@ -131,13 +140,13 @@ class SquareBipolynomialShapeFunctions:
             wy = (1.0 - coords[1]) * (1.0 - v[1]) + v[1] * coords[1]
             return wx * wy

-        if ORDER == 1:
+        if self.ORDER == 1 and is_closed(self.family):
             return cache.get_func(element_inner_weight_linear, self.name)

         return cache.get_func(element_inner_weight, self.name)

     def make_element_inner_weight_gradient(self):
-        ORDER = self.ORDER
+        ORDER_PLUS_ONE = self.ORDER_PLUS_ONE
         LOBATTO_COORDS = self.LOBATTO_COORDS
         LAGRANGE_SCALE = self.LAGRANGE_SCALE

@@ -145,12 +154,12 @@ class SquareBipolynomialShapeFunctions:
             coords: Coords,
             node_index_in_elt: int,
         ):
-            node_i = node_index_in_elt // (ORDER + 1)
-            node_j = node_index_in_elt - (ORDER + 1) * node_i
+            node_i = node_index_in_elt // ORDER_PLUS_ONE
+            node_j = node_index_in_elt - ORDER_PLUS_ONE * node_i

             prefix_x = float(1.0)
             prefix_y = float(1.0)
-            for k in range(ORDER + 1):
+            for k in range(ORDER_PLUS_ONE):
                 if k != node_i:
                     prefix_y *= coords[0] - LOBATTO_COORDS[k]
                 if k != node_j:
@@ -159,7 +168,7 @@ class SquareBipolynomialShapeFunctions:
             grad_x = float(0.0)
             grad_y = float(0.0)

-            for k in range(ORDER + 1):
+            for k in range(ORDER_PLUS_ONE):
                 if k != node_i:
                     delta_x = coords[0] - LOBATTO_COORDS[k]
                     grad_x = grad_x * delta_x + prefix_x
@@ -187,7 +196,7 @@ class SquareBipolynomialShapeFunctions:

             return wp.vec2(dx * wy, dy * wx)

-        if ORDER == 1:
+        if self.ORDER == 1 and is_closed(self.family):
             return cache.get_func(element_inner_weight_gradient_linear, self.name)

         return cache.get_func(element_inner_weight_gradient, self.name)
@@ -230,6 +239,7 @@ class SquareSerendipityShapeFunctions:
         self.LOBATTO_COORDS = wp.constant(NodeVec(lobatto_coords))
         self.LOBATTO_WEIGHT = wp.constant(NodeVec(lobatto_weight))
         self.LAGRANGE_SCALE = wp.constant(NodeVec(lagrange_scale))
+        self.ORDER_PLUS_ONE = wp.constant(self.ORDER + 1)

         self.node_type_and_type_index = self._get_node_type_and_type_index()
         self._node_lobatto_indices = self._get_node_lobatto_indices()
@@ -328,6 +338,7 @@ class SquareSerendipityShapeFunctions:

     def make_element_inner_weight(self):
         ORDER = self.ORDER
+        ORDER_PLUS_ONE = self.ORDER_PLUS_ONE

         LOBATTO_COORDS = self.LOBATTO_COORDS
         LAGRANGE_SCALE = self.LAGRANGE_SCALE
@@ -361,7 +372,7 @@ class SquareSerendipityShapeFunctions:
             if node_type == SquareSerendipityShapeFunctions.EDGE_Y:
                 w *= wp.select(node_i == 0, coords[0], 1.0 - coords[0])
             else:
-                for k in range(ORDER + 1):
+                for k in range(ORDER_PLUS_ONE):
                     if k != node_i:
                         w *= coords[0] - LOBATTO_COORDS[k]

@@ -370,7 +381,7 @@ class SquareSerendipityShapeFunctions:
             if node_type == SquareSerendipityShapeFunctions.EDGE_X:
                 w *= wp.select(node_j == 0, coords[1], 1.0 - coords[1])
             else:
-                for k in range(ORDER + 1):
+                for k in range(ORDER_PLUS_ONE):
                     if k != node_j:
                         w *= coords[1] - LOBATTO_COORDS[k]
                 w *= LAGRANGE_SCALE[node_j]
@@ -381,6 +392,7 @@ class SquareSerendipityShapeFunctions:

     def make_element_inner_weight_gradient(self):
         ORDER = self.ORDER
+        ORDER_PLUS_ONE = self.ORDER_PLUS_ONE
         LOBATTO_COORDS = self.LOBATTO_COORDS
         LAGRANGE_SCALE = self.LAGRANGE_SCALE

@@ -424,7 +436,7 @@ class SquareSerendipityShapeFunctions:
                 prefix_x = wp.select(node_j == 0, coords[1], 1.0 - coords[1])
             else:
                 prefix_x = LAGRANGE_SCALE[node_j]
-                for k in range(ORDER + 1):
+                for k in range(ORDER_PLUS_ONE):
                     if k != node_j:
                         prefix_x *= coords[1] - LOBATTO_COORDS[k]

@@ -432,7 +444,7 @@ class SquareSerendipityShapeFunctions:
                 prefix_y = wp.select(node_i == 0, coords[0], 1.0 - coords[0])
             else:
                 prefix_y = LAGRANGE_SCALE[node_i]
-                for k in range(ORDER + 1):
+                for k in range(ORDER_PLUS_ONE):
                     if k != node_i:
                         prefix_y *= coords[0] - LOBATTO_COORDS[k]

@@ -441,7 +453,7 @@ class SquareSerendipityShapeFunctions:
             else:
                 prefix_y *= LAGRANGE_SCALE[node_j]
             grad_y = float(0.0)
-            for k in range(ORDER + 1):
+            for k in range(ORDER_PLUS_ONE):
                 if k != node_j:
                     delta_y = coords[1] - LOBATTO_COORDS[k]
                     grad_y = grad_y * delta_y + prefix_y
@@ -452,7 +464,7 @@ class SquareSerendipityShapeFunctions:
             else:
                 prefix_x *= LAGRANGE_SCALE[node_i]
             grad_x = float(0.0)
-            for k in range(ORDER + 1):
+            for k in range(ORDER_PLUS_ONE):
                 if k != node_i:
                     delta_x = coords[0] - LOBATTO_COORDS[k]
                     grad_x = grad_x * delta_x + prefix_x
@@ -530,7 +542,6 @@ class SquareNonConformingPolynomialShapeFunctions:
         NODES_PER_ELEMENT = self.NODES_PER_ELEMENT

         if self.ORDER == 2:
-
             # Intrinsic quadrature (order 2)
             @cache.dynamic_func(suffix=self.name)
             def node_quadrature_weight_quadratic(
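
The recurring change in this file hoists the ORDER + 1 expression into a wp.constant (ORDER_PLUS_ONE) computed once when the shape function is constructed and then referenced inside the dynamically generated functions. A minimal sketch of the general pattern, using a hypothetical kernel that is not part of the package:

    import warp as wp

    wp.init()

    ORDER = 3
    ORDER_PLUS_ONE = wp.constant(ORDER + 1)  # evaluated once in Python, baked into generated code

    @wp.kernel
    def poly_eval(x: wp.array(dtype=float), out: wp.array(dtype=float)):
        tid = wp.tid()
        w = float(1.0)
        for k in range(ORDER_PLUS_ONE):  # fixed trip count known at code-generation time
            w *= x[tid] - float(k)
        out[tid] = w

    # usage: wp.launch(poly_eval, dim=n, inputs=[x, out])
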
warp/fem/space/tetmesh_function_space.py CHANGED
@@ -5,7 +5,7 @@ from warp.fem.geometry import Tetmesh
 from warp.fem import cache

 from .topology import SpaceTopology, DiscontinuousSpaceTopologyMixin, forward_base_topology
-from .basis_space import BasisSpace, TraceBasisSpace
+from .basis_space import ShapeBasisSpace, TraceBasisSpace

 from .shape import ShapeFunction, ConstantShapeFunction
 from .shape import TetrahedronPolynomialShapeFunctions, TetrahedronNonConformingPolynomialShapeFunctions
@@ -136,7 +136,7 @@ class TetmeshDiscontinuousSpaceTopology(
         super().__init__(mesh, shape.NODES_PER_ELEMENT)


-class TetmeshBasisSpace(BasisSpace):
+class TetmeshBasisSpace(ShapeBasisSpace):
     def __init__(self, topology: TetmeshSpaceTopology, shape: ShapeFunction):
         super().__init__(topology, shape)

warp/fem/space/topology.py CHANGED
@@ -227,6 +227,10 @@ class DiscontinuousSpaceTopologyMixin:
     def node_count(self):
         return self.geometry.cell_count() * self.NODES_PER_ELEMENT

+    @property
+    def name(self):
+        return f"{self.geometry.name}_D{self.NODES_PER_ELEMENT}"
+
     def _make_element_node_index(self):
         NODES_PER_ELEMENT = self.NODES_PER_ELEMENT

@@ -242,6 +246,12 @@
         return element_node_index


+class DiscontinuousSpaceTopology(DiscontinuousSpaceTopologyMixin, SpaceTopology):
+    """Topology for generic discontinuous spaces"""
+
+    pass
+
+
 class DeformedGeometrySpaceTopology(SpaceTopology):
     def __init__(self, geometry: DeformedGeometry, base_topology: SpaceTopology):
         super().__init__(geometry, base_topology.NODES_PER_ELEMENT)
warp/fem/space/trimesh_2d_function_space.py CHANGED
@@ -5,7 +5,7 @@ from warp.fem.geometry import Trimesh2D
 from warp.fem import cache

 from .topology import SpaceTopology, DiscontinuousSpaceTopologyMixin, forward_base_topology
-from .basis_space import BasisSpace, TraceBasisSpace
+from .basis_space import ShapeBasisSpace, TraceBasisSpace

 from .shape import ShapeFunction, ConstantShapeFunction
 from .shape import Triangle2DPolynomialShapeFunctions, Triangle2DNonConformingPolynomialShapeFunctions
@@ -101,7 +101,7 @@ class Trimesh2DDiscontinuousSpaceTopology(
         super().__init__(mesh, shape.NODES_PER_ELEMENT)


-class Trimesh2DBasisSpace(BasisSpace):
+class Trimesh2DBasisSpace(ShapeBasisSpace):
     def __init__(self, topology: Trimesh2DSpaceTopology, shape: ShapeFunction):
         super().__init__(topology, shape)

warp/fem/utils.py CHANGED
@@ -1,10 +1,15 @@
 from typing import Any, Tuple

-import warp as wp
 import numpy as np

-from warp.utils import radix_sort_pairs, runlength_encode, array_scan
-from warp.fem.cache import borrow_temporary, borrow_temporary_like, TemporaryStore, Temporary
+import warp as wp
+from warp.fem.cache import (
+    Temporary,
+    TemporaryStore,
+    borrow_temporary,
+    borrow_temporary_like,
+)
+from warp.utils import array_scan, radix_sort_pairs, runlength_encode


 @wp.func
@@ -168,7 +173,7 @@ def compress_node_indices(
     if node_indices.device.is_cuda:
         unique_node_count_host = borrow_temporary(temporary_store, shape=(1,), dtype=int, pinned=True, device="cpu")
         wp.copy(src=unique_node_count_dev.array, dest=unique_node_count_host.array, count=1)
-        wp.synchronize_stream(wp.get_stream())
+        wp.synchronize_stream(wp.get_stream(node_indices.device))
         unique_node_count_dev.release()
         unique_node_count = int(unique_node_count_host.array.numpy()[0])
         unique_node_count_host.release()
@@ -217,7 +222,7 @@ def masked_indices(
     if offsets.device.is_cuda:
         masked_count_temp = borrow_temporary(temporary_store, shape=1, dtype=int, pinned=True, device="cpu")
         wp.copy(dest=masked_count_temp.array, src=offsets, src_offset=offsets.shape[0] - 1, count=1)
-        wp.synchronize_stream(wp.get_stream())
+        wp.synchronize_stream(wp.get_stream(offsets.device))
         masked_count = int(masked_count_temp.array.numpy()[0])
         masked_count_temp.release()
     else:
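
The wp.get_stream() → wp.get_stream(node_indices.device) / wp.get_stream(offsets.device) changes synchronize the stream of the device that owns the source array, rather than the current default device, before the pinned host buffer is read. A minimal sketch of the same readback pattern, assuming src lives on a CUDA device (read_back_scalar is a hypothetical helper, not part of the package):

    import warp as wp

    wp.init()

    def read_back_scalar(src: wp.array) -> int:
        # copy one element into pinned host memory, then wait on the stream of the
        # source device before reading the value on the CPU
        host = wp.empty(shape=(1,), dtype=int, device="cpu", pinned=True)
        wp.copy(dest=host, src=src, count=1)
        wp.synchronize_stream(wp.get_stream(src.device))
        return int(host.numpy()[0])
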
warp/native/array.h CHANGED
@@ -951,23 +951,64 @@ inline CUDA_CALLABLE void adj_atomic_sub(const A1<T>& buf, int i, int j, int k,
 template<template<typename> class A1, template<typename> class A2, typename T>
 inline CUDA_CALLABLE void adj_atomic_sub(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret) {}

+// generic handler for scalar values
 template<template<typename> class A1, template<typename> class A2, typename T>
-inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret) {}
+inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret) {
+    if (buf.grad)
+        adj_atomic_minmax(&index(buf, i), &index_grad(buf, i), value, adj_value);
+
+    FP_VERIFY_ADJ_1(value, adj_value)
+}
 template<template<typename> class A1, template<typename> class A2, typename T>
-inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value, const T& adj_ret) {}
+inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value, const T& adj_ret) {
+    if (buf.grad)
+        adj_atomic_minmax(&index(buf, i, j), &index_grad(buf, i, j), value, adj_value);
+
+    FP_VERIFY_ADJ_2(value, adj_value)
+}
 template<template<typename> class A1, template<typename> class A2, typename T>
-inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value, const T& adj_ret) {}
+inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value, const T& adj_ret) {
+    if (buf.grad)
+        adj_atomic_minmax(&index(buf, i, j, k), &index_grad(buf, i, j, k), value, adj_value);
+
+    FP_VERIFY_ADJ_3(value, adj_value)
+}
 template<template<typename> class A1, template<typename> class A2, typename T>
-inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret) {}
+inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret) {
+    if (buf.grad)
+        adj_atomic_minmax(&index(buf, i, j, k, l), &index_grad(buf, i, j, k, l), value, adj_value);
+
+    FP_VERIFY_ADJ_4(value, adj_value)
+}

 template<template<typename> class A1, template<typename> class A2, typename T>
-inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret) {}
+inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret) {
+    if (buf.grad)
+        adj_atomic_minmax(&index(buf, i), &index_grad(buf, i), value, adj_value);
+
+    FP_VERIFY_ADJ_1(value, adj_value)
+}
 template<template<typename> class A1, template<typename> class A2, typename T>
-inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value, const T& adj_ret) {}
+inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value, const T& adj_ret) {
+    if (buf.grad)
+        adj_atomic_minmax(&index(buf, i, j), &index_grad(buf, i, j), value, adj_value);
+
+    FP_VERIFY_ADJ_2(value, adj_value)
+}
 template<template<typename> class A1, template<typename> class A2, typename T>
-inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value, const T& adj_ret) {}
+inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value, const T& adj_ret) {
+    if (buf.grad)
+        adj_atomic_minmax(&index(buf, i, j, k), &index_grad(buf, i, j, k), value, adj_value);
+
+    FP_VERIFY_ADJ_3(value, adj_value)
+}
 template<template<typename> class A1, template<typename> class A2, typename T>
-inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret) {}
+inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret) {
+    if (buf.grad)
+        adj_atomic_minmax(&index(buf, i, j, k, l), &index_grad(buf, i, j, k, l), value, adj_value);
+
+    FP_VERIFY_ADJ_4(value, adj_value)
+}

 } // namespace wp

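
With the adjoints above no longer empty, wp.atomic_min / wp.atomic_max propagate gradients back to the inputs whose value matches the stored min/max (see the adj_atomic_minmax helper added in builtin.h below). A small usage sketch with a hypothetical kernel and data, assuming that behavior:

    import numpy as np
    import warp as wp

    wp.init()

    @wp.kernel
    def min_reduce(values: wp.array(dtype=float), result: wp.array(dtype=float)):
        tid = wp.tid()
        wp.atomic_min(result, 0, values[tid])

    values = wp.array(np.array([3.0, 1.0, 2.0], dtype=np.float32), requires_grad=True)
    result = wp.array(np.array([1.0e9], dtype=np.float32), requires_grad=True)

    tape = wp.Tape()
    with tape:
        wp.launch(min_reduce, dim=3, inputs=[values, result])

    # seed the output gradient and run the adjoint pass
    tape.backward(grads={result: wp.array(np.array([1.0], dtype=np.float32))})
    print(values.grad.numpy())  # gradient lands on the entries equal to the minimum
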
warp/native/builtin.h CHANGED
@@ -295,7 +295,7 @@ inline CUDA_CALLABLE T rshift(T a, T b) { return a>>b; } \
 inline CUDA_CALLABLE T invert(T x) { return ~x; } \
 inline CUDA_CALLABLE bool isfinite(T x) { return true; } \
 inline CUDA_CALLABLE void adj_mul(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
-inline CUDA_CALLABLE void adj_div(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
+inline CUDA_CALLABLE void adj_div(T a, T b, T ret, T& adj_a, T& adj_b, T adj_ret) { } \
 inline CUDA_CALLABLE void adj_add(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
 inline CUDA_CALLABLE void adj_sub(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
 inline CUDA_CALLABLE void adj_mod(T a, T b, T& adj_a, T& adj_b, T adj_ret) { } \
@@ -443,10 +443,10 @@ inline CUDA_CALLABLE T div(T a, T b)\
     })\
     return a/b;\
 }\
-inline CUDA_CALLABLE void adj_div(T a, T b, T& adj_a, T& adj_b, T adj_ret)\
+inline CUDA_CALLABLE void adj_div(T a, T b, T ret, T& adj_a, T& adj_b, T adj_ret)\
 {\
     adj_a += adj_ret/b;\
-    adj_b -= adj_ret*(a/b)/b;\
+    adj_b -= adj_ret*(ret)/b;\
     DO_IF_FPCHECK(\
     if (!isfinite(adj_a) || !isfinite(adj_b))\
     {\
@@ -859,11 +859,11 @@ inline CUDA_CALLABLE void adj_log10(T a, T& adj_a, T adj_ret)\
     assert(0);\
     })\
 }\
-inline CUDA_CALLABLE void adj_exp(T a, T& adj_a, T adj_ret) { adj_a += exp(a)*adj_ret; }\
-inline CUDA_CALLABLE void adj_pow(T a, T b, T& adj_a, T& adj_b, T adj_ret)\
+inline CUDA_CALLABLE void adj_exp(T a, T ret, T& adj_a, T adj_ret) { adj_a += ret*adj_ret; }\
+inline CUDA_CALLABLE void adj_pow(T a, T b, T ret, T& adj_a, T& adj_b, T adj_ret)\
 { \
     adj_a += b*pow(a, b-T(1))*adj_ret;\
-    adj_b += log(a)*pow(a, b)*adj_ret;\
+    adj_b += log(a)*ret*adj_ret;\
     DO_IF_FPCHECK(if (!isfinite(adj_a) || !isfinite(adj_b))\
     {\
     printf("%s:%d - adj_pow(%f, %f, %f, %f, %f)\n", __FILE__, __LINE__, float(a), float(b), float(adj_a), float(adj_b), float(adj_ret));\
@@ -962,24 +962,22 @@ inline CUDA_CALLABLE void adj_cosh(T x, T& adj_x, T adj_ret)\
 {\
     adj_x += sinh(x)*adj_ret;\
 }\
-inline CUDA_CALLABLE void adj_tanh(T x, T& adj_x, T adj_ret)\
+inline CUDA_CALLABLE void adj_tanh(T x, T ret, T& adj_x, T adj_ret)\
 {\
-    T tanh_x = tanh(x);\
-    adj_x += (T(1) - tanh_x*tanh_x)*adj_ret;\
+    adj_x += (T(1) - ret*ret)*adj_ret;\
 }\
-inline CUDA_CALLABLE void adj_sqrt(T x, T& adj_x, T adj_ret)\
+inline CUDA_CALLABLE void adj_sqrt(T x, T ret, T& adj_x, T adj_ret)\
 {\
-    adj_x += T(0.5)*(T(1)/sqrt(x))*adj_ret;\
+    adj_x += T(0.5)*(T(1)/ret)*adj_ret;\
     DO_IF_FPCHECK(if (!isfinite(adj_x))\
     {\
     printf("%s:%d - adj_sqrt(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret));\
     assert(0);\
     })\
 }\
-inline CUDA_CALLABLE void adj_cbrt(T x, T& adj_x, T adj_ret)\
+inline CUDA_CALLABLE void adj_cbrt(T x, T ret, T& adj_x, T adj_ret)\
 {\
-    T cbrt_x = cbrt(x);\
-    adj_x += (T(1)/T(3))*(T(1)/(cbrt_x*cbrt_x))*adj_ret;\
+    adj_x += (T(1)/T(3))*(T(1)/(ret*ret))*adj_ret;\
     DO_IF_FPCHECK(if (!isfinite(adj_x))\
     {\
     printf("%s:%d - adj_cbrt(%f, %f, %f)\n", __FILE__, __LINE__, float(x), float(adj_x), float(adj_ret));\
@@ -1273,6 +1271,25 @@ inline CUDA_CALLABLE int atomic_min(int* address, int val)
 #endif
 }

+// default behavior for adjoint of atomic min/max operation that accumulates gradients for all elements matching the min/max value
+template <typename T>
+CUDA_CALLABLE inline void adj_atomic_minmax(T *addr, T *adj_addr, const T &value, T &adj_value)
+{
+    if (value == *addr)
+        adj_value += *adj_addr;
+}
+
+// for integral types we do not accumulate gradients
+CUDA_CALLABLE inline void adj_atomic_minmax(int8* buf, int8* adj_buf, const int8 &value, int8 &adj_value) { }
+CUDA_CALLABLE inline void adj_atomic_minmax(uint8* buf, uint8* adj_buf, const uint8 &value, uint8 &adj_value) { }
+CUDA_CALLABLE inline void adj_atomic_minmax(int16* buf, int16* adj_buf, const int16 &value, int16 &adj_value) { }
+CUDA_CALLABLE inline void adj_atomic_minmax(uint16* buf, uint16* adj_buf, const uint16 &value, uint16 &adj_value) { }
+CUDA_CALLABLE inline void adj_atomic_minmax(int32* buf, int32* adj_buf, const int32 &value, int32 &adj_value) { }
+CUDA_CALLABLE inline void adj_atomic_minmax(uint32* buf, uint32* adj_buf, const uint32 &value, uint32 &adj_value) { }
+CUDA_CALLABLE inline void adj_atomic_minmax(int64* buf, int64* adj_buf, const int64 &value, int64 &adj_value) { }
+CUDA_CALLABLE inline void adj_atomic_minmax(uint64* buf, uint64* adj_buf, const uint64 &value, uint64 &adj_value) { }
+CUDA_CALLABLE inline void adj_atomic_minmax(bool* buf, bool* adj_buf, const bool &value, bool &adj_value) { }
+

 } // namespace wp

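
A pattern in these hunks: adjoint functions such as adj_div, adj_exp, adj_tanh, adj_sqrt and adj_cbrt gain an extra ret parameter carrying the forward result, so the backward pass reuses it instead of recomputing the primal expression. For division this relies on d(a/b)/db = -a/b^2 = -(a/b)/b, so the saved ret = a/b substitutes directly. A quick plain-Python sanity check of that identity (illustrative only, not part of the package):

    a, b = 3.0, 4.0
    ret = a / b                                           # forward value saved by the code generator
    eps = 1.0e-6
    fd = (a / (b + eps) - a / (b - eps)) / (2.0 * eps)    # finite-difference estimate of d(a/b)/db
    assert abs(fd - (-ret / b)) < 1.0e-5                  # matches adj_b -= adj_ret*(ret)/b with adj_ret = 1
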
warp/native/cuda_util.cpp CHANGED
@@ -89,6 +89,7 @@ static PFN_cuGraphicsResourceGetMappedPointer_v3020 pfn_cuGraphicsResourceGetMap
 static PFN_cuGraphicsGLRegisterBuffer_v3000 pfn_cuGraphicsGLRegisterBuffer;
 static PFN_cuGraphicsUnregisterResource_v3000 pfn_cuGraphicsUnregisterResource;

+static bool cuda_driver_initialized = false;

 bool ContextGuard::always_restore = false;

@@ -196,11 +197,15 @@ bool init_cuda_driver()
     get_driver_entry_point("cuGraphicsUnregisterResource", &(void*&)pfn_cuGraphicsUnregisterResource);

     if (pfn_cuInit)
-        return check_cu(pfn_cuInit(0));
-    else
-        return false;
+        cuda_driver_initialized = check_cu(pfn_cuInit(0));
+
+    return cuda_driver_initialized;
 }

+bool is_cuda_driver_initialized()
+{
+    return cuda_driver_initialized;
+}

 bool check_cuda_result(cudaError_t code, const char* file, int line)
 {
warp/native/cuda_util.h CHANGED
@@ -83,6 +83,7 @@ CUresult cuGraphicsUnregisterResource_f(CUgraphicsResource resource);


 bool init_cuda_driver();
+bool is_cuda_driver_initialized();

 bool check_cuda_result(cudaError_t code, const char* file, int line);
 inline bool check_cuda_result(uint64_t code, const char* file, int line)