warp-lang 1.7.2__py3-none-win_amd64.whl → 1.8.0__py3-none-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic. See the registry's page for this release for more details.

Files changed (181)
  1. warp/__init__.py +3 -1
  2. warp/__init__.pyi +3489 -1
  3. warp/autograd.py +45 -122
  4. warp/bin/warp-clang.dll +0 -0
  5. warp/bin/warp.dll +0 -0
  6. warp/build.py +241 -252
  7. warp/build_dll.py +125 -26
  8. warp/builtins.py +1907 -384
  9. warp/codegen.py +257 -101
  10. warp/config.py +12 -1
  11. warp/constants.py +1 -1
  12. warp/context.py +657 -223
  13. warp/dlpack.py +1 -1
  14. warp/examples/benchmarks/benchmark_cloth.py +2 -2
  15. warp/examples/benchmarks/benchmark_tile_sort.py +155 -0
  16. warp/examples/core/example_sample_mesh.py +1 -1
  17. warp/examples/core/example_spin_lock.py +93 -0
  18. warp/examples/core/example_work_queue.py +118 -0
  19. warp/examples/fem/example_adaptive_grid.py +5 -5
  20. warp/examples/fem/example_apic_fluid.py +1 -1
  21. warp/examples/fem/example_burgers.py +1 -1
  22. warp/examples/fem/example_convection_diffusion.py +9 -6
  23. warp/examples/fem/example_darcy_ls_optimization.py +489 -0
  24. warp/examples/fem/example_deformed_geometry.py +1 -1
  25. warp/examples/fem/example_diffusion.py +2 -2
  26. warp/examples/fem/example_diffusion_3d.py +1 -1
  27. warp/examples/fem/example_distortion_energy.py +1 -1
  28. warp/examples/fem/example_elastic_shape_optimization.py +387 -0
  29. warp/examples/fem/example_magnetostatics.py +5 -3
  30. warp/examples/fem/example_mixed_elasticity.py +5 -3
  31. warp/examples/fem/example_navier_stokes.py +11 -9
  32. warp/examples/fem/example_nonconforming_contact.py +5 -3
  33. warp/examples/fem/example_streamlines.py +8 -3
  34. warp/examples/fem/utils.py +9 -8
  35. warp/examples/interop/example_jax_ffi_callback.py +2 -2
  36. warp/examples/optim/example_drone.py +1 -1
  37. warp/examples/sim/example_cloth.py +1 -1
  38. warp/examples/sim/example_cloth_self_contact.py +48 -54
  39. warp/examples/tile/example_tile_block_cholesky.py +502 -0
  40. warp/examples/tile/example_tile_cholesky.py +2 -1
  41. warp/examples/tile/example_tile_convolution.py +1 -1
  42. warp/examples/tile/example_tile_filtering.py +1 -1
  43. warp/examples/tile/example_tile_matmul.py +1 -1
  44. warp/examples/tile/example_tile_mlp.py +2 -0
  45. warp/fabric.py +7 -7
  46. warp/fem/__init__.py +5 -0
  47. warp/fem/adaptivity.py +1 -1
  48. warp/fem/cache.py +152 -63
  49. warp/fem/dirichlet.py +2 -2
  50. warp/fem/domain.py +136 -6
  51. warp/fem/field/field.py +141 -99
  52. warp/fem/field/nodal_field.py +85 -39
  53. warp/fem/field/virtual.py +97 -52
  54. warp/fem/geometry/adaptive_nanogrid.py +91 -86
  55. warp/fem/geometry/closest_point.py +13 -0
  56. warp/fem/geometry/deformed_geometry.py +102 -40
  57. warp/fem/geometry/element.py +56 -2
  58. warp/fem/geometry/geometry.py +323 -22
  59. warp/fem/geometry/grid_2d.py +157 -62
  60. warp/fem/geometry/grid_3d.py +116 -20
  61. warp/fem/geometry/hexmesh.py +86 -20
  62. warp/fem/geometry/nanogrid.py +166 -86
  63. warp/fem/geometry/partition.py +59 -25
  64. warp/fem/geometry/quadmesh.py +86 -135
  65. warp/fem/geometry/tetmesh.py +47 -119
  66. warp/fem/geometry/trimesh.py +77 -270
  67. warp/fem/integrate.py +107 -52
  68. warp/fem/linalg.py +25 -58
  69. warp/fem/operator.py +124 -27
  70. warp/fem/quadrature/pic_quadrature.py +36 -14
  71. warp/fem/quadrature/quadrature.py +40 -16
  72. warp/fem/space/__init__.py +1 -1
  73. warp/fem/space/basis_function_space.py +66 -46
  74. warp/fem/space/basis_space.py +17 -4
  75. warp/fem/space/dof_mapper.py +1 -1
  76. warp/fem/space/function_space.py +2 -2
  77. warp/fem/space/grid_2d_function_space.py +4 -1
  78. warp/fem/space/hexmesh_function_space.py +4 -2
  79. warp/fem/space/nanogrid_function_space.py +3 -1
  80. warp/fem/space/partition.py +11 -2
  81. warp/fem/space/quadmesh_function_space.py +4 -1
  82. warp/fem/space/restriction.py +5 -2
  83. warp/fem/space/shape/__init__.py +10 -8
  84. warp/fem/space/tetmesh_function_space.py +4 -1
  85. warp/fem/space/topology.py +52 -21
  86. warp/fem/space/trimesh_function_space.py +4 -1
  87. warp/fem/utils.py +53 -8
  88. warp/jax.py +1 -2
  89. warp/jax_experimental/ffi.py +12 -17
  90. warp/jax_experimental/xla_ffi.py +37 -24
  91. warp/math.py +171 -1
  92. warp/native/array.h +99 -0
  93. warp/native/builtin.h +174 -31
  94. warp/native/coloring.cpp +1 -1
  95. warp/native/exports.h +118 -63
  96. warp/native/intersect.h +3 -3
  97. warp/native/mat.h +5 -10
  98. warp/native/mathdx.cpp +11 -5
  99. warp/native/matnn.h +1 -123
  100. warp/native/quat.h +28 -4
  101. warp/native/sparse.cpp +121 -258
  102. warp/native/sparse.cu +181 -274
  103. warp/native/spatial.h +305 -17
  104. warp/native/tile.h +583 -72
  105. warp/native/tile_radix_sort.h +1108 -0
  106. warp/native/tile_reduce.h +237 -2
  107. warp/native/tile_scan.h +240 -0
  108. warp/native/tuple.h +189 -0
  109. warp/native/vec.h +6 -16
  110. warp/native/warp.cpp +36 -4
  111. warp/native/warp.cu +574 -51
  112. warp/native/warp.h +47 -74
  113. warp/optim/linear.py +5 -1
  114. warp/paddle.py +7 -8
  115. warp/py.typed +0 -0
  116. warp/render/render_opengl.py +58 -29
  117. warp/render/render_usd.py +124 -61
  118. warp/sim/__init__.py +9 -0
  119. warp/sim/collide.py +252 -78
  120. warp/sim/graph_coloring.py +8 -1
  121. warp/sim/import_mjcf.py +4 -3
  122. warp/sim/import_usd.py +11 -7
  123. warp/sim/integrator.py +5 -2
  124. warp/sim/integrator_euler.py +1 -1
  125. warp/sim/integrator_featherstone.py +1 -1
  126. warp/sim/integrator_vbd.py +751 -320
  127. warp/sim/integrator_xpbd.py +1 -1
  128. warp/sim/model.py +265 -260
  129. warp/sim/utils.py +10 -7
  130. warp/sparse.py +303 -166
  131. warp/tape.py +52 -51
  132. warp/tests/cuda/test_conditional_captures.py +1046 -0
  133. warp/tests/cuda/test_streams.py +1 -1
  134. warp/tests/geometry/test_volume.py +2 -2
  135. warp/tests/interop/test_dlpack.py +9 -9
  136. warp/tests/interop/test_jax.py +0 -1
  137. warp/tests/run_coverage_serial.py +1 -1
  138. warp/tests/sim/disabled_kinematics.py +2 -2
  139. warp/tests/sim/{test_vbd.py → test_cloth.py} +296 -113
  140. warp/tests/sim/test_collision.py +159 -51
  141. warp/tests/sim/test_coloring.py +15 -1
  142. warp/tests/test_array.py +254 -2
  143. warp/tests/test_array_reduce.py +2 -2
  144. warp/tests/test_atomic_cas.py +299 -0
  145. warp/tests/test_codegen.py +142 -19
  146. warp/tests/test_conditional.py +47 -1
  147. warp/tests/test_ctypes.py +0 -20
  148. warp/tests/test_devices.py +8 -0
  149. warp/tests/test_fabricarray.py +4 -2
  150. warp/tests/test_fem.py +58 -25
  151. warp/tests/test_func.py +42 -1
  152. warp/tests/test_grad.py +1 -1
  153. warp/tests/test_lerp.py +1 -3
  154. warp/tests/test_map.py +481 -0
  155. warp/tests/test_mat.py +1 -24
  156. warp/tests/test_quat.py +6 -15
  157. warp/tests/test_rounding.py +10 -38
  158. warp/tests/test_runlength_encode.py +7 -7
  159. warp/tests/test_smoothstep.py +1 -1
  160. warp/tests/test_sparse.py +51 -2
  161. warp/tests/test_spatial.py +507 -1
  162. warp/tests/test_struct.py +2 -2
  163. warp/tests/test_tuple.py +265 -0
  164. warp/tests/test_types.py +2 -2
  165. warp/tests/test_utils.py +24 -18
  166. warp/tests/tile/test_tile.py +420 -1
  167. warp/tests/tile/test_tile_mathdx.py +518 -14
  168. warp/tests/tile/test_tile_reduce.py +213 -0
  169. warp/tests/tile/test_tile_shared_memory.py +130 -1
  170. warp/tests/tile/test_tile_sort.py +117 -0
  171. warp/tests/unittest_suites.py +4 -6
  172. warp/types.py +462 -308
  173. warp/utils.py +647 -86
  174. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/METADATA +20 -6
  175. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/RECORD +178 -166
  176. warp/stubs.py +0 -3381
  177. warp/tests/sim/test_xpbd.py +0 -399
  178. warp/tests/test_mlp.py +0 -282
  179. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/WHEEL +0 -0
  180. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/licenses/LICENSE.md +0 -0
  181. {warp_lang-1.7.2.dist-info → warp_lang-1.8.0.dist-info}/top_level.txt +0 -0
@@ -13,7 +13,7 @@
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
15
 
16
- from typing import Optional, Tuple, Type
16
+ from typing import ClassVar, Optional, Tuple, Type
17
17
 
18
18
  import warp as wp
19
19
  from warp.fem import cache
@@ -39,6 +39,12 @@ class SpaceTopology:
39
39
  .. note:: This will change to be defined per-element in future versions
40
40
  """
41
41
 
42
+ _dynamic_attribute_constructors: ClassVar = {
43
+ "element_node_count": lambda obj: obj._make_constant_element_node_count(),
44
+ "element_node_sign": lambda obj: obj._make_constant_element_node_sign(),
45
+ "side_neighbor_node_counts": lambda obj: obj._make_constant_side_neighbor_node_counts(),
46
+ }
47
+
42
48
  @wp.struct
43
49
  class TopologyArg:
44
50
  """Structure containing arguments to be passed to device functions"""
@@ -51,8 +57,7 @@ class SpaceTopology:
51
57
  self.MAX_NODES_PER_ELEMENT = wp.constant(max_nodes_per_element)
52
58
  self.ElementArg = geometry.CellArg
53
59
 
54
- self._make_constant_element_node_count()
55
- self._make_constant_element_node_sign()
60
+ cache.setup_dynamic_attributes(self, cls=__class__)
56
61
 
57
62
  @property
58
63
  def geometry(self) -> Geometry:
@@ -67,6 +72,9 @@ class SpaceTopology:
67
72
  """Value of the topology argument structure to be passed to device functions"""
68
73
  return SpaceTopology.TopologyArg()
69
74
 
75
+ def fill_topo_arg(self, arg, device):
76
+ pass
77
+
70
78
  @property
71
79
  def name(self):
72
80
  return f"{self.__class__.__name__}_{self.MAX_NODES_PER_ELEMENT}"
@@ -182,6 +190,11 @@ class SpaceTopology:
182
190
  ):
183
191
  return NODES_PER_ELEMENT
184
192
 
193
+ return constant_element_node_count
194
+
195
+ def _make_constant_side_neighbor_node_counts(self):
196
+ NODES_PER_ELEMENT = wp.constant(self.MAX_NODES_PER_ELEMENT)
197
+
185
198
  @cache.dynamic_func(suffix=self.name)
186
199
  def constant_side_neighbor_node_counts(
187
200
  side_arg: self.geometry.SideArg,
@@ -189,8 +202,7 @@ class SpaceTopology:
189
202
  ):
190
203
  return NODES_PER_ELEMENT, NODES_PER_ELEMENT
191
204
 
192
- self.element_node_count = constant_element_node_count
193
- self.side_neighbor_node_counts = constant_side_neighbor_node_counts
205
+ return constant_side_neighbor_node_counts
194
206
 
195
207
  def _make_constant_element_node_sign(self):
196
208
  @cache.dynamic_func(suffix=self.name)
@@ -202,12 +214,21 @@ class SpaceTopology:
202
214
  ):
203
215
  return 1.0
204
216
 
205
- self.element_node_sign = constant_element_node_sign
217
+ return constant_element_node_sign
206
218
 
207
219
 
208
220
  class TraceSpaceTopology(SpaceTopology):
209
221
  """Auto-generated trace topology defining the node indices associated to the geometry sides"""
210
222
 
223
+ _dynamic_attribute_constructors: ClassVar = {
224
+ "inner_cell_index": lambda obj: obj._make_inner_cell_index(),
225
+ "outer_cell_index": lambda obj: obj._make_outer_cell_index(),
226
+ "neighbor_cell_index": lambda obj: obj._make_neighbor_cell_index(),
227
+ "element_node_index": lambda obj: obj._make_element_node_index(),
228
+ "element_node_count": lambda obj: obj._make_element_node_count(),
229
+ "element_node_sign": lambda obj: obj._make_element_node_sign(),
230
+ }
231
+
211
232
  def __init__(self, topo: SpaceTopology):
212
233
  self._topo = topo
213
234
 
@@ -218,14 +239,10 @@ class TraceSpaceTopology(SpaceTopology):
218
239
 
219
240
  self.TopologyArg = topo.TopologyArg
220
241
  self.topo_arg_value = topo.topo_arg_value
242
+ self.fill_topo_arg = topo.fill_topo_arg
221
243
 
222
- self.inner_cell_index = self._make_inner_cell_index()
223
- self.outer_cell_index = self._make_outer_cell_index()
224
- self.neighbor_cell_index = self._make_neighbor_cell_index()
225
-
226
- self.element_node_index = self._make_element_node_index()
227
- self.element_node_count = self._make_element_node_count()
228
244
  self.side_neighbor_node_counts = None
245
+ cache.setup_dynamic_attributes(self, cls=__class__)
229
246
 
230
247
  def node_count(self) -> int:
231
248
  return self._topo.node_count()
@@ -354,21 +371,29 @@ class RegularDiscontinuousSpaceTopology(RegularDiscontinuousSpaceTopologyMixin,
354
371
 
355
372
 
356
373
  class DeformedGeometrySpaceTopology(SpaceTopology):
374
+ _dynamic_attribute_constructors: ClassVar = {
375
+ "element_node_index": lambda obj: obj._make_element_node_index(),
376
+ "element_node_count": lambda obj: obj._make_element_node_count(),
377
+ "element_node_sign": lambda obj: obj._make_element_node_sign(),
378
+ "side_neighbor_node_counts": lambda obj: obj._make_side_neighbor_node_counts(),
379
+ }
380
+
357
381
  def __init__(self, geometry: DeformedGeometry, base_topology: SpaceTopology):
358
382
  self.base = base_topology
359
383
  super().__init__(geometry, base_topology.MAX_NODES_PER_ELEMENT)
360
384
 
361
385
  self.node_count = self.base.node_count
362
386
  self.topo_arg_value = self.base.topo_arg_value
387
+ self.fill_topo_arg = self.base.fill_topo_arg
363
388
  self.TopologyArg = self.base.TopologyArg
364
389
 
365
- self._make_passthrough_functions()
390
+ cache.setup_dynamic_attributes(self, cls=__class__)
366
391
 
367
392
  @property
368
393
  def name(self):
369
394
  return f"{self.base.name}_{self.geometry.field.name}"
370
395
 
371
- def _make_passthrough_functions(self):
396
+ def _make_element_node_index(self):
372
397
  @cache.dynamic_func(suffix=self.name)
373
398
  def element_node_index(
374
399
  elt_arg: self.geometry.CellArg,
@@ -376,16 +401,22 @@ class DeformedGeometrySpaceTopology(SpaceTopology):
376
401
  element_index: ElementIndex,
377
402
  node_index_in_elt: int,
378
403
  ):
379
- return self.base.element_node_index(elt_arg.elt_arg, topo_arg, element_index, node_index_in_elt)
404
+ return self.base.element_node_index(elt_arg.base_arg, topo_arg, element_index, node_index_in_elt)
405
+
406
+ return element_node_index
380
407
 
408
+ def _make_element_node_count(self):
381
409
  @cache.dynamic_func(suffix=self.name)
382
410
  def element_node_count(
383
411
  elt_arg: self.geometry.CellArg,
384
412
  topo_arg: self.TopologyArg,
385
413
  element_count: ElementIndex,
386
414
  ):
387
- return self.base.element_node_count(elt_arg.elt_arg, topo_arg, element_count)
415
+ return self.base.element_node_count(elt_arg.base_arg, topo_arg, element_count)
416
+
417
+ return element_node_count
388
418
 
419
+ def _make_side_neighbor_node_counts(self):
389
420
  @cache.dynamic_func(suffix=self.name)
390
421
  def side_neighbor_node_counts(
391
422
  side_arg: self.geometry.SideArg,
@@ -394,6 +425,9 @@ class DeformedGeometrySpaceTopology(SpaceTopology):
394
425
  inner_count, outer_count = self.base.side_neighbor_node_counts(side_arg.base_arg, element_index)
395
426
  return inner_count, outer_count
396
427
 
428
+ return side_neighbor_node_counts
429
+
430
+ def _make_element_node_sign(self):
397
431
  @cache.dynamic_func(suffix=self.name)
398
432
  def element_node_sign(
399
433
  elt_arg: self.geometry.CellArg,
@@ -401,12 +435,9 @@ class DeformedGeometrySpaceTopology(SpaceTopology):
401
435
  element_index: ElementIndex,
402
436
  node_index_in_elt: int,
403
437
  ):
404
- return self.base.element_node_sign(elt_arg.elt_arg, topo_arg, element_index, node_index_in_elt)
438
+ return self.base.element_node_sign(elt_arg.base_arg, topo_arg, element_index, node_index_in_elt)
405
439
 
406
- self.element_node_index = element_node_index
407
- self.element_node_count = element_node_count
408
- self.element_node_sign = element_node_sign
409
- self.side_neighbor_node_counts = side_neighbor_node_counts
440
+ return element_node_sign
410
441
 
411
442
 
412
443
  def forward_base_topology(topology_class: Type[SpaceTopology], geometry: Geometry, *args, **kwargs) -> SpaceTopology:
@@ -50,12 +50,15 @@ class TrimeshSpaceTopology(SpaceTopology):
50
50
  @cache.cached_arg_value
51
51
  def topo_arg_value(self, device):
52
52
  arg = TrimeshTopologyArg()
53
+ self.fill_topo_arg(arg, device)
54
+ return arg
55
+
56
+ def fill_topo_arg(self, arg: TrimeshTopologyArg, device):
53
57
  arg.tri_edge_indices = self._tri_edge_indices.to(device)
54
58
  arg.edge_vertex_indices = self._mesh.edge_vertex_indices.to(device)
55
59
 
56
60
  arg.vertex_count = self._mesh.vertex_count()
57
61
  arg.edge_count = self._mesh.side_count()
58
- return arg
59
62
 
60
63
  def _compute_tri_edge_indices(self):
61
64
  self._tri_edge_indices = wp.empty(
warp/fem/utils.py CHANGED
@@ -19,6 +19,7 @@ import numpy as np
19
19
 
20
20
  import warp as wp
21
21
  import warp.fem.cache as cache
22
+ import warp.types
22
23
  from warp.fem.linalg import ( # noqa: F401 (for backward compatibility, not part of public API but used in examples)
23
24
  array_axpy,
24
25
  inverse_qr,
@@ -28,6 +29,57 @@ from warp.fem.types import NULL_NODE_INDEX
28
29
  from warp.utils import array_scan, radix_sort_pairs, runlength_encode
29
30
 
30
31
 
32
+ def type_zero_element(dtype):
33
+ suffix = warp.types.get_type_code(dtype)
34
+
35
+ if dtype in warp.types.scalar_types:
36
+
37
+ @cache.dynamic_func(suffix=suffix)
38
+ def zero_element():
39
+ return dtype(0.0)
40
+
41
+ return zero_element
42
+
43
+ @cache.dynamic_func(suffix=suffix)
44
+ def zero_element():
45
+ return dtype()
46
+
47
+ return zero_element
48
+
49
+
50
+ def type_basis_element(dtype):
51
+ suffix = warp.types.get_type_code(dtype)
52
+
53
+ if dtype in warp.types.scalar_types:
54
+
55
+ @cache.dynamic_func(suffix=suffix)
56
+ def basis_element(coord: int):
57
+ return dtype(1.0)
58
+
59
+ return basis_element
60
+
61
+ if warp.types.type_is_matrix(dtype):
62
+ cols = dtype._shape_[1]
63
+
64
+ @cache.dynamic_func(suffix=suffix)
65
+ def basis_element(coord: int):
66
+ v = dtype()
67
+ i = coord // cols
68
+ j = coord - i * cols
69
+ v[i, j] = v.dtype(1.0)
70
+ return v
71
+
72
+ return basis_element
73
+
74
+ @cache.dynamic_func(suffix=suffix)
75
+ def basis_element(coord: int):
76
+ v = dtype()
77
+ v[coord] = v.dtype(1.0)
78
+ return v
79
+
80
+ return basis_element
81
+
82
+
31
83
  def compress_node_indices(
32
84
  node_count: int,
33
85
  node_indices: wp.array(dtype=int),
@@ -126,14 +178,7 @@ def host_read_at_index(array: wp.array, index: int = -1, temporary_store: cache.
126
178
 
127
179
  if index < 0:
128
180
  index += array.shape[0]
129
-
130
- if array.device.is_cuda:
131
- temp = cache.borrow_temporary(temporary_store, shape=1, dtype=int, pinned=True, device="cpu")
132
- wp.copy(dest=temp.array, src=array, src_offset=index, count=1)
133
- wp.synchronize_stream(wp.get_stream(array.device))
134
- return temp.array.numpy()[0]
135
-
136
- return array.numpy()[index]
181
+ return array[index : index + 1].numpy()[0]
137
182
 
138
183
 
139
184
  def masked_indices(
warp/jax.py CHANGED
@@ -182,6 +182,5 @@ def from_jax(jax_array, dtype=None) -> warp.array:
182
182
  Returns:
183
183
  warp.array: The converted Warp array.
184
184
  """
185
- import jax.dlpack
186
185
 
187
- return warp.from_dlpack(jax.dlpack.to_dlpack(jax_array), dtype=dtype)
186
+ return warp.from_dlpack(jax_array, dtype=dtype)
@@ -306,7 +306,6 @@ class FfiCallable:
306
306
  self.graph_compatible = graph_compatible
307
307
  self.output_dims = output_dims
308
308
  self.first_array_arg = None
309
- self.has_static_args = False
310
309
  self.call_id = 0
311
310
  self.call_descriptors = {}
312
311
 
@@ -335,8 +334,6 @@ class FfiCallable:
335
334
  if arg.is_array:
336
335
  if arg_idx < self.num_inputs and self.first_array_arg is None:
337
336
  self.first_array_arg = arg_idx
338
- else:
339
- self.has_static_args = True
340
337
  self.args.append(arg)
341
338
  arg_idx += 1
342
339
 
@@ -425,14 +422,11 @@ class FfiCallable:
425
422
  module = wp.get_module(self.func.__module__)
426
423
  module.load(device)
427
424
 
428
- if self.has_static_args:
429
- # save call data to be retrieved by callback
430
- call_id = self.call_id
431
- self.call_descriptors[call_id] = FfiCallDesc(static_inputs)
432
- self.call_id += 1
433
- return call(*args, call_id=call_id)
434
- else:
435
- return call(*args)
425
+ # save call data to be retrieved by callback
426
+ call_id = self.call_id
427
+ self.call_descriptors[call_id] = FfiCallDesc(static_inputs)
428
+ self.call_id += 1
429
+ return call(*args, call_id=call_id)
436
430
 
437
431
  def ffi_callback(self, call_frame):
438
432
  try:
@@ -454,11 +448,10 @@ class FfiCallable:
454
448
  )
455
449
  return None
456
450
 
457
- if self.has_static_args:
458
- # retrieve call info
459
- attrs = decode_attrs(call_frame.contents.attrs)
460
- call_id = int(attrs["call_id"])
461
- call_desc = self.call_descriptors[call_id]
451
+ # retrieve call info
452
+ attrs = decode_attrs(call_frame.contents.attrs)
453
+ call_id = int(attrs["call_id"])
454
+ call_desc = self.call_descriptors[call_id]
462
455
 
463
456
  num_inputs = call_frame.contents.args.size
464
457
  inputs = ctypes.cast(call_frame.contents.args.args, ctypes.POINTER(ctypes.POINTER(XLA_FFI_Buffer)))
@@ -500,8 +493,10 @@ class FfiCallable:
500
493
  # call the Python function with reconstructed arguments
501
494
  with wp.ScopedStream(stream, sync_enter=False):
502
495
  if stream.is_capturing:
503
- with wp.ScopedCapture(stream=stream, external=True):
496
+ with wp.ScopedCapture(stream=stream, external=True) as capture:
504
497
  self.func(*arg_list)
498
+ # keep a reference to the capture object to prevent required modules getting unloaded
499
+ call_desc.capture = capture
505
500
  else:
506
501
  self.func(*arg_list)
507
502
 
@@ -130,14 +130,14 @@ class XLA_FFI_DataType(enum.IntEnum):
130
130
  # int64_t* dims; // length == rank
131
131
  # };
132
132
  class XLA_FFI_Buffer(ctypes.Structure):
133
- _fields_ = [
133
+ _fields_ = (
134
134
  ("struct_size", ctypes.c_size_t),
135
135
  ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
136
136
  ("dtype", ctypes.c_int), # XLA_FFI_DataType
137
137
  ("data", ctypes.c_void_p),
138
138
  ("rank", ctypes.c_int64),
139
139
  ("dims", ctypes.POINTER(ctypes.c_int64)),
140
- ]
140
+ )
141
141
 
142
142
 
143
143
  # typedef enum {
@@ -162,13 +162,13 @@ class XLA_FFI_RetType(enum.IntEnum):
162
162
  # void** args; // length == size
163
163
  # };
164
164
  class XLA_FFI_Args(ctypes.Structure):
165
- _fields_ = [
165
+ _fields_ = (
166
166
  ("struct_size", ctypes.c_size_t),
167
167
  ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
168
168
  ("size", ctypes.c_int64),
169
169
  ("types", ctypes.POINTER(ctypes.c_int)), # XLA_FFI_ArgType*
170
170
  ("args", ctypes.POINTER(ctypes.c_void_p)),
171
- ]
171
+ )
172
172
 
173
173
 
174
174
  # struct XLA_FFI_Rets {
@@ -179,13 +179,13 @@ class XLA_FFI_Args(ctypes.Structure):
179
179
  # void** rets; // length == size
180
180
  # };
181
181
  class XLA_FFI_Rets(ctypes.Structure):
182
- _fields_ = [
182
+ _fields_ = (
183
183
  ("struct_size", ctypes.c_size_t),
184
184
  ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
185
185
  ("size", ctypes.c_int64),
186
186
  ("types", ctypes.POINTER(ctypes.c_int)), # XLA_FFI_RetType*
187
187
  ("rets", ctypes.POINTER(ctypes.c_void_p)),
188
- ]
188
+ )
189
189
 
190
190
 
191
191
  # typedef struct XLA_FFI_ByteSpan {
@@ -193,7 +193,10 @@ class XLA_FFI_Rets(ctypes.Structure):
193
193
  # size_t len;
194
194
  # } XLA_FFI_ByteSpan;
195
195
  class XLA_FFI_ByteSpan(ctypes.Structure):
196
- _fields_ = [("ptr", ctypes.POINTER(ctypes.c_char)), ("len", ctypes.c_size_t)]
196
+ _fields_ = (
197
+ ("ptr", ctypes.POINTER(ctypes.c_char)),
198
+ ("len", ctypes.c_size_t),
199
+ )
197
200
 
198
201
 
199
202
  # typedef struct XLA_FFI_Scalar {
@@ -201,7 +204,10 @@ class XLA_FFI_ByteSpan(ctypes.Structure):
201
204
  # void* value;
202
205
  # } XLA_FFI_Scalar;
203
206
  class XLA_FFI_Scalar(ctypes.Structure):
204
- _fields_ = [("dtype", ctypes.c_int), ("value", ctypes.c_void_p)]
207
+ _fields_ = (
208
+ ("dtype", ctypes.c_int),
209
+ ("value", ctypes.c_void_p),
210
+ )
205
211
 
206
212
 
207
213
  # typedef struct XLA_FFI_Array {
@@ -210,7 +216,11 @@ class XLA_FFI_Scalar(ctypes.Structure):
210
216
  # void* data;
211
217
  # } XLA_FFI_Array;
212
218
  class XLA_FFI_Array(ctypes.Structure):
213
- _fields_ = [("dtype", ctypes.c_int), ("size", ctypes.c_size_t), ("data", ctypes.c_void_p)]
219
+ _fields_ = (
220
+ ("dtype", ctypes.c_int),
221
+ ("size", ctypes.c_size_t),
222
+ ("data", ctypes.c_void_p),
223
+ )
214
224
 
215
225
 
216
226
  # typedef enum {
@@ -235,14 +245,14 @@ class XLA_FFI_AttrType(enum.IntEnum):
235
245
  # void** attrs; // length == size
236
246
  # };
237
247
  class XLA_FFI_Attrs(ctypes.Structure):
238
- _fields_ = [
248
+ _fields_ = (
239
249
  ("struct_size", ctypes.c_size_t),
240
250
  ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
241
251
  ("size", ctypes.c_int64),
242
252
  ("types", ctypes.POINTER(ctypes.c_int)), # XLA_FFI_AttrType*
243
253
  ("names", ctypes.POINTER(ctypes.POINTER(XLA_FFI_ByteSpan))),
244
254
  ("attrs", ctypes.POINTER(ctypes.c_void_p)),
245
- ]
255
+ )
246
256
 
247
257
 
248
258
  # struct XLA_FFI_Api_Version {
@@ -252,12 +262,12 @@ class XLA_FFI_Attrs(ctypes.Structure):
252
262
  # int minor_version; // out
253
263
  # };
254
264
  class XLA_FFI_Api_Version(ctypes.Structure):
255
- _fields_ = [
265
+ _fields_ = (
256
266
  ("struct_size", ctypes.c_size_t),
257
267
  ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
258
268
  ("major_version", ctypes.c_int),
259
269
  ("minor_version", ctypes.c_int),
260
- ]
270
+ )
261
271
 
262
272
 
263
273
  # enum XLA_FFI_Handler_TraitsBits {
@@ -276,11 +286,11 @@ class XLA_FFI_Handler_TraitsBits(enum.IntEnum):
276
286
  # XLA_FFI_Handler_Traits traits;
277
287
  # };
278
288
  class XLA_FFI_Metadata(ctypes.Structure):
279
- _fields_ = [
289
+ _fields_ = (
280
290
  ("struct_size", ctypes.c_size_t),
281
291
  ("api_version", XLA_FFI_Api_Version), # XLA_FFI_Extension_Type
282
292
  ("traits", ctypes.c_uint32), # XLA_FFI_Handler_Traits
283
- ]
293
+ )
284
294
 
285
295
 
286
296
  # struct XLA_FFI_Metadata_Extension {
@@ -288,7 +298,10 @@ class XLA_FFI_Metadata(ctypes.Structure):
288
298
  # XLA_FFI_Metadata* metadata;
289
299
  # };
290
300
  class XLA_FFI_Metadata_Extension(ctypes.Structure):
291
- _fields_ = [("extension_base", XLA_FFI_Extension_Base), ("metadata", ctypes.POINTER(XLA_FFI_Metadata))]
301
+ _fields_ = (
302
+ ("extension_base", XLA_FFI_Extension_Base),
303
+ ("metadata", ctypes.POINTER(XLA_FFI_Metadata)),
304
+ )
292
305
 
293
306
 
294
307
  # typedef enum {
@@ -337,12 +350,12 @@ class XLA_FFI_Error_Code(enum.IntEnum):
337
350
  # XLA_FFI_Error_Code errc;
338
351
  # };
339
352
  class XLA_FFI_Error_Create_Args(ctypes.Structure):
340
- _fields_ = [
353
+ _fields_ = (
341
354
  ("struct_size", ctypes.c_size_t),
342
355
  ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
343
356
  ("message", ctypes.c_char_p),
344
357
  ("errc", ctypes.c_int),
345
- ] # XLA_FFI_Error_Code
358
+ ) # XLA_FFI_Error_Code
346
359
 
347
360
 
348
361
  XLA_FFI_Error_Create = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_Error_Create_Args))
@@ -355,12 +368,12 @@ XLA_FFI_Error_Create = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_
355
368
  # void* stream; // out
356
369
  # };
357
370
  class XLA_FFI_Stream_Get_Args(ctypes.Structure):
358
- _fields_ = [
371
+ _fields_ = (
359
372
  ("struct_size", ctypes.c_size_t),
360
373
  ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
361
374
  ("ctx", ctypes.c_void_p), # XLA_FFI_ExecutionContext*
362
375
  ("stream", ctypes.c_void_p),
363
- ] # // out
376
+ ) # // out
364
377
 
365
378
 
366
379
  XLA_FFI_Stream_Get = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_Stream_Get_Args))
@@ -391,7 +404,7 @@ XLA_FFI_Stream_Get = ctypes.CFUNCTYPE(ctypes.c_void_p, ctypes.POINTER(XLA_FFI_St
391
404
  # _XLA_FFI_API_STRUCT_FIELD(XLA_FFI_Future_SetError);
392
405
  # };
393
406
  class XLA_FFI_Api(ctypes.Structure):
394
- _fields_ = [
407
+ _fields_ = (
395
408
  ("struct_size", ctypes.c_size_t),
396
409
  ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
397
410
  ("api_version", XLA_FFI_Api_Version),
@@ -412,7 +425,7 @@ class XLA_FFI_Api(ctypes.Structure):
412
425
  ("XLA_FFI_Future_Create", ctypes.c_void_p), # XLA_FFI_Future_Create
413
426
  ("XLA_FFI_Future_SetAvailable", ctypes.c_void_p), # XLA_FFI_Future_SetAvailable
414
427
  ("XLA_FFI_Future_SetError", ctypes.c_void_p), # XLA_FFI_Future_SetError
415
- ]
428
+ )
416
429
 
417
430
 
418
431
  # struct XLA_FFI_CallFrame {
@@ -431,7 +444,7 @@ class XLA_FFI_Api(ctypes.Structure):
431
444
  # XLA_FFI_Future* future; // out
432
445
  # };
433
446
  class XLA_FFI_CallFrame(ctypes.Structure):
434
- _fields_ = [
447
+ _fields_ = (
435
448
  ("struct_size", ctypes.c_size_t),
436
449
  ("extension_start", ctypes.POINTER(XLA_FFI_Extension_Base)),
437
450
  ("api", ctypes.POINTER(XLA_FFI_Api)),
@@ -441,7 +454,7 @@ class XLA_FFI_CallFrame(ctypes.Structure):
441
454
  ("rets", XLA_FFI_Rets),
442
455
  ("attrs", XLA_FFI_Attrs),
443
456
  ("future", ctypes.c_void_p), # XLA_FFI_Future* // out
444
- ]
457
+ )
445
458
 
446
459
 
447
460
  _xla_data_type_to_constructor = {