warp-lang 1.1.0-py3-none-manylinux2014_x86_64.whl → 1.2.1-py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of warp-lang might be problematic.

Files changed (218)
  1. warp/bin/warp-clang.so +0 -0
  2. warp/bin/warp.so +0 -0
  3. warp/build.py +10 -37
  4. warp/build_dll.py +2 -2
  5. warp/builtins.py +274 -6
  6. warp/codegen.py +51 -4
  7. warp/config.py +2 -2
  8. warp/constants.py +4 -0
  9. warp/context.py +422 -203
  10. warp/examples/benchmarks/benchmark_api.py +0 -2
  11. warp/examples/benchmarks/benchmark_cloth_warp.py +0 -1
  12. warp/examples/benchmarks/benchmark_launches.py +0 -2
  13. warp/examples/core/example_dem.py +0 -2
  14. warp/examples/core/example_fluid.py +0 -2
  15. warp/examples/core/example_graph_capture.py +0 -2
  16. warp/examples/core/example_marching_cubes.py +0 -2
  17. warp/examples/core/example_mesh.py +0 -2
  18. warp/examples/core/example_mesh_intersect.py +0 -2
  19. warp/examples/core/example_nvdb.py +0 -2
  20. warp/examples/core/example_raycast.py +0 -2
  21. warp/examples/core/example_raymarch.py +0 -2
  22. warp/examples/core/example_render_opengl.py +0 -2
  23. warp/examples/core/example_sph.py +0 -2
  24. warp/examples/core/example_torch.py +0 -3
  25. warp/examples/core/example_wave.py +0 -2
  26. warp/examples/fem/example_apic_fluid.py +140 -115
  27. warp/examples/fem/example_burgers.py +262 -0
  28. warp/examples/fem/example_convection_diffusion.py +0 -2
  29. warp/examples/fem/example_convection_diffusion_dg.py +0 -2
  30. warp/examples/fem/example_deformed_geometry.py +0 -2
  31. warp/examples/fem/example_diffusion.py +0 -2
  32. warp/examples/fem/example_diffusion_3d.py +5 -4
  33. warp/examples/fem/example_diffusion_mgpu.py +0 -2
  34. warp/examples/fem/example_mixed_elasticity.py +0 -2
  35. warp/examples/fem/example_navier_stokes.py +0 -2
  36. warp/examples/fem/example_stokes.py +0 -2
  37. warp/examples/fem/example_stokes_transfer.py +0 -2
  38. warp/examples/optim/example_bounce.py +0 -2
  39. warp/examples/optim/example_cloth_throw.py +0 -2
  40. warp/examples/optim/example_diffray.py +0 -2
  41. warp/examples/optim/example_drone.py +0 -2
  42. warp/examples/optim/example_inverse_kinematics.py +0 -2
  43. warp/examples/optim/example_inverse_kinematics_torch.py +0 -2
  44. warp/examples/optim/example_spring_cage.py +0 -2
  45. warp/examples/optim/example_trajectory.py +0 -2
  46. warp/examples/optim/example_walker.py +0 -2
  47. warp/examples/sim/example_cartpole.py +0 -2
  48. warp/examples/sim/example_cloth.py +0 -2
  49. warp/examples/sim/example_granular.py +0 -2
  50. warp/examples/sim/example_granular_collision_sdf.py +0 -2
  51. warp/examples/sim/example_jacobian_ik.py +0 -2
  52. warp/examples/sim/example_particle_chain.py +0 -2
  53. warp/examples/sim/example_quadruped.py +0 -2
  54. warp/examples/sim/example_rigid_chain.py +0 -2
  55. warp/examples/sim/example_rigid_contact.py +0 -2
  56. warp/examples/sim/example_rigid_force.py +0 -2
  57. warp/examples/sim/example_rigid_gyroscopic.py +0 -2
  58. warp/examples/sim/example_rigid_soft_contact.py +0 -2
  59. warp/examples/sim/example_soft_body.py +0 -2
  60. warp/fem/__init__.py +1 -0
  61. warp/fem/cache.py +3 -1
  62. warp/fem/geometry/__init__.py +1 -0
  63. warp/fem/geometry/element.py +4 -0
  64. warp/fem/geometry/grid_3d.py +0 -4
  65. warp/fem/geometry/nanogrid.py +455 -0
  66. warp/fem/integrate.py +63 -9
  67. warp/fem/space/__init__.py +43 -158
  68. warp/fem/space/basis_space.py +34 -0
  69. warp/fem/space/collocated_function_space.py +1 -1
  70. warp/fem/space/grid_2d_function_space.py +13 -132
  71. warp/fem/space/grid_3d_function_space.py +16 -154
  72. warp/fem/space/hexmesh_function_space.py +37 -134
  73. warp/fem/space/nanogrid_function_space.py +202 -0
  74. warp/fem/space/quadmesh_2d_function_space.py +12 -119
  75. warp/fem/space/restriction.py +4 -1
  76. warp/fem/space/shape/__init__.py +77 -0
  77. warp/fem/space/shape/cube_shape_function.py +5 -15
  78. warp/fem/space/tetmesh_function_space.py +6 -76
  79. warp/fem/space/trimesh_2d_function_space.py +6 -76
  80. warp/native/array.h +12 -3
  81. warp/native/builtin.h +48 -5
  82. warp/native/bvh.cpp +14 -10
  83. warp/native/bvh.cu +23 -15
  84. warp/native/bvh.h +1 -0
  85. warp/native/clang/clang.cpp +2 -1
  86. warp/native/crt.cpp +11 -1
  87. warp/native/crt.h +18 -1
  88. warp/native/exports.h +187 -0
  89. warp/native/mat.h +47 -0
  90. warp/native/mesh.cpp +1 -1
  91. warp/native/mesh.cu +1 -2
  92. warp/native/nanovdb/GridHandle.h +366 -0
  93. warp/native/nanovdb/HostBuffer.h +590 -0
  94. warp/native/nanovdb/NanoVDB.h +3999 -2157
  95. warp/native/nanovdb/PNanoVDB.h +936 -99
  96. warp/native/quat.h +28 -1
  97. warp/native/rand.h +5 -1
  98. warp/native/vec.h +45 -1
  99. warp/native/volume.cpp +335 -103
  100. warp/native/volume.cu +39 -13
  101. warp/native/volume.h +725 -303
  102. warp/native/volume_builder.cu +381 -360
  103. warp/native/volume_builder.h +16 -1
  104. warp/native/volume_impl.h +61 -0
  105. warp/native/warp.cu +8 -2
  106. warp/native/warp.h +15 -7
  107. warp/render/render_opengl.py +191 -52
  108. warp/sim/integrator_featherstone.py +10 -3
  109. warp/sim/integrator_xpbd.py +16 -22
  110. warp/sparse.py +89 -27
  111. warp/stubs.py +83 -0
  112. warp/tests/assets/test_index_grid.nvdb +0 -0
  113. warp/tests/aux_test_dependent.py +0 -2
  114. warp/tests/aux_test_grad_customs.py +0 -2
  115. warp/tests/aux_test_reference.py +0 -2
  116. warp/tests/aux_test_reference_reference.py +0 -2
  117. warp/tests/aux_test_square.py +0 -2
  118. warp/tests/disabled_kinematics.py +0 -2
  119. warp/tests/test_adam.py +0 -2
  120. warp/tests/test_arithmetic.py +0 -36
  121. warp/tests/test_array.py +9 -11
  122. warp/tests/test_array_reduce.py +0 -2
  123. warp/tests/test_async.py +0 -2
  124. warp/tests/test_atomic.py +0 -2
  125. warp/tests/test_bool.py +58 -50
  126. warp/tests/test_builtins_resolution.py +0 -2
  127. warp/tests/test_bvh.py +0 -2
  128. warp/tests/test_closest_point_edge_edge.py +0 -1
  129. warp/tests/test_codegen.py +0 -4
  130. warp/tests/test_compile_consts.py +130 -10
  131. warp/tests/test_conditional.py +0 -2
  132. warp/tests/test_copy.py +0 -2
  133. warp/tests/test_ctypes.py +6 -8
  134. warp/tests/test_dense.py +0 -2
  135. warp/tests/test_devices.py +0 -2
  136. warp/tests/test_dlpack.py +9 -11
  137. warp/tests/test_examples.py +42 -39
  138. warp/tests/test_fabricarray.py +0 -3
  139. warp/tests/test_fast_math.py +0 -2
  140. warp/tests/test_fem.py +75 -54
  141. warp/tests/test_fp16.py +0 -2
  142. warp/tests/test_func.py +0 -2
  143. warp/tests/test_generics.py +27 -2
  144. warp/tests/test_grad.py +147 -8
  145. warp/tests/test_grad_customs.py +0 -2
  146. warp/tests/test_hash_grid.py +1 -3
  147. warp/tests/test_import.py +0 -2
  148. warp/tests/test_indexedarray.py +0 -2
  149. warp/tests/test_intersect.py +0 -2
  150. warp/tests/test_jax.py +0 -2
  151. warp/tests/test_large.py +11 -9
  152. warp/tests/test_launch.py +0 -2
  153. warp/tests/test_lerp.py +10 -54
  154. warp/tests/test_linear_solvers.py +3 -5
  155. warp/tests/test_lvalue.py +0 -2
  156. warp/tests/test_marching_cubes.py +0 -2
  157. warp/tests/test_mat.py +0 -2
  158. warp/tests/test_mat_lite.py +0 -2
  159. warp/tests/test_mat_scalar_ops.py +0 -2
  160. warp/tests/test_math.py +0 -2
  161. warp/tests/test_matmul.py +35 -37
  162. warp/tests/test_matmul_lite.py +29 -31
  163. warp/tests/test_mempool.py +0 -2
  164. warp/tests/test_mesh.py +0 -3
  165. warp/tests/test_mesh_query_aabb.py +0 -2
  166. warp/tests/test_mesh_query_point.py +0 -2
  167. warp/tests/test_mesh_query_ray.py +0 -2
  168. warp/tests/test_mlp.py +0 -2
  169. warp/tests/test_model.py +0 -2
  170. warp/tests/test_module_hashing.py +111 -0
  171. warp/tests/test_modules_lite.py +0 -3
  172. warp/tests/test_multigpu.py +0 -2
  173. warp/tests/test_noise.py +0 -4
  174. warp/tests/test_operators.py +0 -2
  175. warp/tests/test_options.py +0 -2
  176. warp/tests/test_peer.py +0 -2
  177. warp/tests/test_pinned.py +0 -2
  178. warp/tests/test_print.py +0 -2
  179. warp/tests/test_quat.py +0 -2
  180. warp/tests/test_rand.py +41 -5
  181. warp/tests/test_reload.py +0 -10
  182. warp/tests/test_rounding.py +0 -2
  183. warp/tests/test_runlength_encode.py +0 -2
  184. warp/tests/test_sim_grad.py +0 -2
  185. warp/tests/test_sim_kinematics.py +0 -2
  186. warp/tests/test_smoothstep.py +0 -2
  187. warp/tests/test_snippet.py +0 -2
  188. warp/tests/test_sparse.py +0 -2
  189. warp/tests/test_spatial.py +0 -2
  190. warp/tests/test_special_values.py +362 -0
  191. warp/tests/test_streams.py +0 -2
  192. warp/tests/test_struct.py +0 -2
  193. warp/tests/test_tape.py +0 -2
  194. warp/tests/test_torch.py +0 -2
  195. warp/tests/test_transient_module.py +0 -2
  196. warp/tests/test_types.py +0 -2
  197. warp/tests/test_utils.py +0 -2
  198. warp/tests/test_vec.py +0 -2
  199. warp/tests/test_vec_lite.py +0 -2
  200. warp/tests/test_vec_scalar_ops.py +0 -2
  201. warp/tests/test_verify_fp.py +0 -2
  202. warp/tests/test_volume.py +237 -13
  203. warp/tests/test_volume_write.py +86 -3
  204. warp/tests/unittest_serial.py +10 -9
  205. warp/tests/unittest_suites.py +6 -2
  206. warp/tests/unittest_utils.py +2 -171
  207. warp/tests/unused_test_misc.py +0 -2
  208. warp/tests/walkthrough_debug.py +1 -1
  209. warp/thirdparty/unittest_parallel.py +37 -40
  210. warp/types.py +526 -85
  211. {warp_lang-1.1.0.dist-info → warp_lang-1.2.1.dist-info}/METADATA +61 -31
  212. warp_lang-1.2.1.dist-info/RECORD +359 -0
  213. warp/examples/fem/example_convection_diffusion_dg0.py +0 -204
  214. warp/native/nanovdb/PNanoVDBWrite.h +0 -295
  215. warp_lang-1.1.0.dist-info/RECORD +0 -352
  216. {warp_lang-1.1.0.dist-info → warp_lang-1.2.1.dist-info}/LICENSE.md +0 -0
  217. {warp_lang-1.1.0.dist-info → warp_lang-1.2.1.dist-info}/WHEEL +0 -0
  218. {warp_lang-1.1.0.dist-info → warp_lang-1.2.1.dist-info}/top_level.txt +0 -0
warp/context.py CHANGED
@@ -6,6 +6,7 @@
  # license agreement from NVIDIA CORPORATION is strictly prohibited.

  import ast
+ import builtins
  import ctypes
  import functools
  import hashlib
@@ -18,7 +19,8 @@ import platform
  import sys
  import types
  from copy import copy as shallowcopy
- from types import ModuleType
+ from pathlib import Path
+ from struct import pack as struct_pack
  from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union

  import numpy as np
@@ -345,6 +347,8 @@ class Function:
  def call_builtin(func: Function, *params) -> Tuple[bool, Any]:
  uses_non_warp_array_type = False

+ warp.context.init()
+
  # Retrieve the built-in function from Warp's dll.
  c_func = getattr(warp.context.runtime.core, func.mangled_name)

@@ -1168,7 +1172,7 @@ def get_module(name):
  # clear out old kernels, funcs, struct definitions
  old_module.kernels = {}
  old_module.functions = {}
- old_module.constants = []
+ old_module.constants = {}
  old_module.structs = {}
  old_module.loader = parent_loader

@@ -1315,7 +1319,7 @@ class Module:

  self.kernels = {}
  self.functions = {}
- self.constants = []
+ self.constants = {} # Any constants referenced in this module including those defined in other modules
  self.structs = {}

  self.cpu_module = None
@@ -1442,7 +1446,13 @@ class Module:
  if isinstance(arg.type, warp.codegen.Struct) and arg.type.module is not None:
  add_ref(arg.type.module)

- def hash_module(self):
+ def hash_module(self, recompute_content_hash=False):
+ """Recursively compute and return a hash for the module.
+
+ If ``recompute_content_hash`` is False, each module's previously
+ computed ``content_hash`` will be used.
+ """
+
  def get_annotations(obj: Any) -> Mapping[str, Any]:
  """Alternative to `inspect.get_annotations()` for Python 3.9 and older."""
  # See https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
@@ -1461,10 +1471,13 @@ class Module:
  # The visited set tracks modules already visited to avoid circular references.

  # check if we need to update the content hash
- if not module.content_hash:
+ if not module.content_hash or recompute_content_hash:
  # recompute content hash
  ch = hashlib.sha256()

+ # Start with an empty constants dictionary in case any have been removed
+ module.constants = {}
+
  # struct source
  for struct in module.structs.values():
  s = ",".join(
@@ -1474,28 +1487,34 @@ class Module:
  ch.update(bytes(s, "utf-8"))

  # functions source
- for func in module.functions.values():
- s = func.adj.source
- ch.update(bytes(s, "utf-8"))
-
- if func.custom_grad_func:
- s = func.custom_grad_func.adj.source
- ch.update(bytes(s, "utf-8"))
- if func.custom_replay_func:
- s = func.custom_replay_func.adj.source
- if func.replay_snippet:
- s = func.replay_snippet
- if func.native_snippet:
- s = func.native_snippet
- ch.update(bytes(s, "utf-8"))
- if func.adj_native_snippet:
- s = func.adj_native_snippet
+ for function in module.functions.values():
+ # include all concrete and generic overloads
+ overloads = itertools.chain(function.user_overloads.items(), function.user_templates.items())
+ for sig, func in overloads:
+ # signature
+ ch.update(bytes(sig, "utf-8"))
+
+ # source
+ s = func.adj.source
  ch.update(bytes(s, "utf-8"))

- # cache func arg types
- for arg, arg_type in func.adj.arg_types.items():
- s = f"{arg}: {get_type_name(arg_type)}"
- ch.update(bytes(s, "utf-8"))
+ if func.custom_grad_func:
+ s = func.custom_grad_func.adj.source
+ ch.update(bytes(s, "utf-8"))
+ if func.custom_replay_func:
+ s = func.custom_replay_func.adj.source
+ if func.replay_snippet:
+ s = func.replay_snippet
+ if func.native_snippet:
+ s = func.native_snippet
+ ch.update(bytes(s, "utf-8"))
+ if func.adj_native_snippet:
+ s = func.adj_native_snippet
+ ch.update(bytes(s, "utf-8"))
+
+ # Populate constants referenced in this function
+ if func.adj:
+ module.constants.update(func.adj.get_constant_references())

  # kernel source
  for kernel in module.kernels.values():
@@ -1511,6 +1530,34 @@ class Module:
  for sig in sorted(kernel.overloads.keys()):
  ch.update(bytes(sig, "utf-8"))

+ # Populate constants referenced in this kernel
+ module.constants.update(kernel.adj.get_constant_references())
+
+ # constants referenced in this module
+ for constant_name, constant_value in module.constants.items():
+ ch.update(bytes(constant_name, "utf-8"))
+
+ # hash the constant value
+ if isinstance(constant_value, builtins.bool):
+ # This needs to come before the check for `int` since all boolean
+ # values are also instances of `int`.
+ ch.update(struct_pack("?", constant_value))
+ elif isinstance(constant_value, int):
+ ch.update(struct_pack("<q", constant_value))
+ elif isinstance(constant_value, float):
+ ch.update(struct_pack("<d", constant_value))
+ elif isinstance(constant_value, warp.types.float16):
+ # float16 is a special case
+ p = ctypes.pointer(ctypes.c_float(constant_value.value))
+ ch.update(p.contents)
+ elif isinstance(constant_value, tuple(warp.types.scalar_types)):
+ p = ctypes.pointer(constant_value._type_(constant_value.value))
+ ch.update(p.contents)
+ elif isinstance(constant_value, ctypes.Array):
+ ch.update(bytes(constant_value))
+ else:
+ raise RuntimeError(f"Invalid constant type: {type(constant_value)}")
+
  module.content_hash = ch.digest()

  h = hashlib.sha256()
@@ -1529,10 +1576,6 @@ class Module:

  h.update(bytes(warp.config.mode, "utf-8"))

- # compile-time constants (global)
- if warp.types._constant_hash:
- h.update(warp.types._constant_hash.digest())
-
  # recurse on references
  visited.add(module)

@@ -1546,7 +1589,7 @@

  return hash_recursive(self, visited=set())

- def load(self, device):
+ def load(self, device) -> bool:
  from warp.utils import ScopedTimer

  device = get_device(device)
@@ -1570,68 +1613,19 @@ class Module:
  if not warp.is_cuda_available():
  raise RuntimeError("Failed to build CUDA module because CUDA is not available")

- with ScopedTimer(f"Module {self.name} load on device '{device}'", active=not warp.config.quiet):
- build_path = warp.build.kernel_bin_dir
- gen_path = warp.build.kernel_gen_dir
-
- if not os.path.exists(build_path):
- os.makedirs(build_path)
- if not os.path.exists(gen_path):
- os.makedirs(gen_path)
+ module_name = "wp_" + self.name
+ module_hash = self.hash_module()

- module_name = "wp_" + self.name
- module_path = os.path.join(build_path, module_name)
- module_hash = self.hash_module()
-
- builder = ModuleBuilder(self, self.options)
+ # use a unique module path using the module short hash
+ module_dir = os.path.join(warp.config.kernel_cache_dir, f"{module_name}_{module_hash.hex()[:7]}")

+ with ScopedTimer(
+ f"Module {self.name} {module_hash.hex()[:7]} load on device '{device}'", active=not warp.config.quiet
+ ):
+ # -----------------------------------------------------------
+ # determine output paths
  if device.is_cpu:
- obj_path = os.path.join(build_path, module_name)
- obj_path = obj_path + ".o"
- cpu_hash_path = module_path + ".cpu.hash"
-
- # check cache
- if warp.config.cache_kernels and os.path.isfile(cpu_hash_path) and os.path.isfile(obj_path):
- with open(cpu_hash_path, "rb") as f:
- cache_hash = f.read()
-
- if cache_hash == module_hash:
- runtime.llvm.load_obj(obj_path.encode("utf-8"), module_name.encode("utf-8"))
- self.cpu_module = module_name
- return True
-
- # build
- try:
- cpp_path = os.path.join(gen_path, module_name + ".cpp")
-
- # write cpp sources
- cpp_source = builder.codegen("cpu")
-
- cpp_file = open(cpp_path, "w")
- cpp_file.write(cpp_source)
- cpp_file.close()
-
- # build object code
- with ScopedTimer("Compile x86", active=warp.config.verbose):
- warp.build.build_cpu(
- obj_path,
- cpp_path,
- mode=self.options["mode"],
- fast_math=self.options["fast_math"],
- verify_fp=warp.config.verify_fp,
- )
-
- # update cpu hash
- with open(cpu_hash_path, "wb") as f:
- f.write(module_hash)
-
- # load the object code
- runtime.llvm.load_obj(obj_path.encode("utf-8"), module_name.encode("utf-8"))
- self.cpu_module = module_name
-
- except Exception as e:
- self.cpu_build_failed = True
- raise (e)
+ output_name = "module_codegen.o"

  elif device.is_cuda:
  # determine whether to use PTX or CUBIN
@@ -1650,62 +1644,138 @@ class Module:

  if use_ptx:
  output_arch = min(device.arch, warp.config.ptx_target_arch)
- output_path = module_path + f".sm{output_arch}.ptx"
+ output_name = f"module_codegen.sm{output_arch}.ptx"
  else:
  output_arch = device.arch
- output_path = module_path + f".sm{output_arch}.cubin"
+ output_name = f"module_codegen.sm{output_arch}.cubin"
+
+ # final object binary path
+ binary_path = os.path.join(module_dir, output_name)
+
+ # -----------------------------------------------------------
+ # check cache and build if necessary
+
+ build_dir = None
+
+ if not os.path.exists(binary_path) or not warp.config.cache_kernels:
+ builder = ModuleBuilder(self, self.options)
+
+ # create a temporary (process unique) dir for build outputs before moving to the binary dir
+ build_dir = os.path.join(
+ warp.config.kernel_cache_dir, f"{module_name}_{module_hash.hex()[:7]}_p{os.getpid()}"
+ )
+
+ # dir may exist from previous attempts / runs / archs
+ Path(build_dir).mkdir(parents=True, exist_ok=True)
+
+ # build CPU
+ if device.is_cpu:
+ # build
+ try:
+ source_code_path = os.path.join(build_dir, "module_codegen.cpp")
+
+ # write cpp sources
+ cpp_source = builder.codegen("cpu")
+
+ with open(source_code_path, "w") as cpp_file:
+ cpp_file.write(cpp_source)

- cuda_hash_path = module_path + f".sm{output_arch}.hash"
+ output_path = os.path.join(build_dir, output_name)

- # check cache
- if warp.config.cache_kernels and os.path.isfile(cuda_hash_path) and os.path.isfile(output_path):
- with open(cuda_hash_path, "rb") as f:
- cache_hash = f.read()
+ # build object code
+ with ScopedTimer("Compile x86", active=warp.config.verbose):
+ warp.build.build_cpu(
+ output_path,
+ source_code_path,
+ mode=self.options["mode"],
+ fast_math=self.options["fast_math"],
+ verify_fp=warp.config.verify_fp,
+ )
+
+ except Exception as e:
+ self.cpu_build_failed = True
+ raise (e)
+
+ elif device.is_cuda:
+ # build
+ try:
+ source_code_path = os.path.join(build_dir, "module_codegen.cu")
+
+ # write cuda sources
+ cu_source = builder.codegen("cuda")
+
+ with open(source_code_path, "w") as cu_file:
+ cu_file.write(cu_source)
+
+ output_path = os.path.join(build_dir, output_name)
+
+ # generate PTX or CUBIN
+ with ScopedTimer("Compile CUDA", active=warp.config.verbose):
+ warp.build.build_cuda(
+ source_code_path,
+ output_arch,
+ output_path,
+ config=self.options["mode"],
+ fast_math=self.options["fast_math"],
+ verify_fp=warp.config.verify_fp,
+ )
+
+ except Exception as e:
+ self.cuda_build_failed = True
+ raise (e)

- if cache_hash == module_hash:
- cuda_module = warp.build.load_cuda(output_path, device)
- if cuda_module is not None:
- self.cuda_modules[device.context] = cuda_module
- return True
+ # -----------------------------------------------------------
+ # update cache

- # build
  try:
- cu_path = os.path.join(gen_path, module_name + ".cu")
-
- # write cuda sources
- cu_source = builder.codegen("cuda")
-
- cu_file = open(cu_path, "w")
- cu_file.write(cu_source)
- cu_file.close()
-
- # generate PTX or CUBIN
- with ScopedTimer("Compile CUDA", active=warp.config.verbose):
- warp.build.build_cuda(
- cu_path,
- output_arch,
- output_path,
- config=self.options["mode"],
- fast_math=self.options["fast_math"],
- verify_fp=warp.config.verify_fp,
- )
-
- # update cuda hash
- with open(cuda_hash_path, "wb") as f:
- f.write(module_hash)
-
- # load the module
- cuda_module = warp.build.load_cuda(output_path, device)
- if cuda_module is not None:
- self.cuda_modules[device.context] = cuda_module
- else:
- raise Exception(f"Failed to load CUDA module '{self.name}'")
+ # Copy process-specific build directory to a process-independent location
+ os.rename(build_dir, module_dir)
+ except (OSError, FileExistsError):
+ # another process likely updated the module dir first
+ pass

- except Exception as e:
- self.cuda_build_failed = True
- raise (e)
+ if os.path.exists(module_dir):
+ if not os.path.exists(binary_path):
+ # copy our output file to the destination module
+ # this is necessary in case different processes
+ # have different GPU architectures / devices
+ try:
+ os.rename(output_path, binary_path)
+ except (OSError, FileExistsError):
+ # another process likely updated the module dir first
+ pass

- return True
+ try:
+ final_source_path = os.path.join(module_dir, os.path.basename(source_code_path))
+ if not os.path.exists(final_source_path):
+ os.rename(source_code_path, final_source_path)
+ except (OSError, FileExistsError):
+ # another process likely updated the module dir first
+ pass
+ except Exception as e:
+ # We don't need source_code_path to be copied successfully to proceed, so warn and keep running
+ warp.utils.warn(f"Exception when renaming {source_code_path}: {e}")
+
+ # -----------------------------------------------------------
+ # Load CPU or CUDA binary
+ if device.is_cpu:
+ runtime.llvm.load_obj(binary_path.encode("utf-8"), module_name.encode("utf-8"))
+ self.cpu_module = module_name
+
+ elif device.is_cuda:
+ cuda_module = warp.build.load_cuda(binary_path, device)
+ if cuda_module is not None:
+ self.cuda_modules[device.context] = cuda_module
+ else:
+ raise Exception(f"Failed to load CUDA module '{self.name}'")
+
+ if build_dir:
+ import shutil
+
+ # clean up build_dir used for this process regardless
+ shutil.rmtree(build_dir, ignore_errors=True)
+
+ return True

  def unload(self):
  if self.cpu_module:
@@ -2578,22 +2648,36 @@ class Runtime:
  ]
  self.core.cutlass_gemm.restype = ctypes.c_bool

- self.core.volume_create_host.argtypes = [ctypes.c_void_p, ctypes.c_uint64]
+ self.core.volume_create_host.argtypes = [ctypes.c_void_p, ctypes.c_uint64, ctypes.c_bool, ctypes.c_bool]
  self.core.volume_create_host.restype = ctypes.c_uint64
- self.core.volume_get_buffer_info_host.argtypes = [
+ self.core.volume_get_tiles_host.argtypes = [
  ctypes.c_uint64,
- ctypes.POINTER(ctypes.c_void_p),
- ctypes.POINTER(ctypes.c_uint64),
+ ctypes.c_void_p,
  ]
- self.core.volume_get_tiles_host.argtypes = [
+ self.core.volume_get_voxels_host.argtypes = [
  ctypes.c_uint64,
- ctypes.POINTER(ctypes.c_void_p),
- ctypes.POINTER(ctypes.c_uint64),
+ ctypes.c_void_p,
  ]
  self.core.volume_destroy_host.argtypes = [ctypes.c_uint64]

- self.core.volume_create_device.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint64]
+ self.core.volume_create_device.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_uint64,
+ ctypes.c_bool,
+ ctypes.c_bool,
+ ]
  self.core.volume_create_device.restype = ctypes.c_uint64
+ self.core.volume_get_tiles_device.argtypes = [
+ ctypes.c_uint64,
+ ctypes.c_void_p,
+ ]
+ self.core.volume_get_voxels_device.argtypes = [
+ ctypes.c_uint64,
+ ctypes.c_void_p,
+ ]
+ self.core.volume_destroy_device.argtypes = [ctypes.c_uint64]
+
  self.core.volume_f_from_tiles_device.argtypes = [
  ctypes.c_void_p,
  ctypes.c_void_p,
@@ -2632,24 +2716,68 @@ class Runtime:
  ctypes.c_bool,
  ]
  self.core.volume_i_from_tiles_device.restype = ctypes.c_uint64
- self.core.volume_get_buffer_info_device.argtypes = [
- ctypes.c_uint64,
- ctypes.POINTER(ctypes.c_void_p),
- ctypes.POINTER(ctypes.c_uint64),
+ self.core.volume_index_from_tiles_device.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_int,
+ ctypes.c_float,
+ ctypes.c_float,
+ ctypes.c_float,
+ ctypes.c_float,
+ ctypes.c_bool,
  ]
- self.core.volume_get_tiles_device.argtypes = [
+ self.core.volume_index_from_tiles_device.restype = ctypes.c_uint64
+ self.core.volume_from_active_voxels_device.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_int,
+ ctypes.c_float,
+ ctypes.c_float,
+ ctypes.c_float,
+ ctypes.c_float,
+ ctypes.c_bool,
+ ]
+ self.core.volume_from_active_voxels_device.restype = ctypes.c_uint64
+
+ self.core.volume_get_buffer_info.argtypes = [
  ctypes.c_uint64,
  ctypes.POINTER(ctypes.c_void_p),
  ctypes.POINTER(ctypes.c_uint64),
  ]
- self.core.volume_destroy_device.argtypes = [ctypes.c_uint64]
-
  self.core.volume_get_voxel_size.argtypes = [
  ctypes.c_uint64,
  ctypes.POINTER(ctypes.c_float),
  ctypes.POINTER(ctypes.c_float),
  ctypes.POINTER(ctypes.c_float),
  ]
+ self.core.volume_get_tile_and_voxel_count.argtypes = [
+ ctypes.c_uint64,
+ ctypes.POINTER(ctypes.c_uint32),
+ ctypes.POINTER(ctypes.c_uint64),
+ ]
+ self.core.volume_get_grid_info.argtypes = [
+ ctypes.c_uint64,
+ ctypes.POINTER(ctypes.c_uint64),
+ ctypes.POINTER(ctypes.c_uint32),
+ ctypes.POINTER(ctypes.c_uint32),
+ ctypes.c_float * 3,
+ ctypes.c_float * 9,
+ ctypes.c_char * 16,
+ ]
+ self.core.volume_get_grid_info.restype = ctypes.c_char_p
+ self.core.volume_get_blind_data_count.argtypes = [
+ ctypes.c_uint64,
+ ]
+ self.core.volume_get_blind_data_count.restype = ctypes.c_uint64
+ self.core.volume_get_blind_data_info.argtypes = [
+ ctypes.c_uint64,
+ ctypes.c_uint32,
+ ctypes.POINTER(ctypes.c_void_p),
+ ctypes.POINTER(ctypes.c_uint64),
+ ctypes.POINTER(ctypes.c_uint32),
+ ctypes.c_char * 16,
+ ]
+ self.core.volume_get_blind_data_info.restype = ctypes.c_char_p

  bsr_matrix_from_triplets_argtypes = [
  ctypes.c_int,
@@ -3194,12 +3322,10 @@ class Runtime:
  raise RuntimeError(f"CUDA error detected: {err}")


- def assert_initialized():
- assert runtime is not None, "Warp not initialized, call wp.init() before use"
-
-
  # global entry points
  def is_cpu_available():
+ init()
+
  return runtime.llvm


@@ -3221,7 +3347,7 @@ def is_cuda_driver_initialized() -> bool:

  This can be helpful in cases in which ``cuInit()`` was called before a fork.
  """
- assert_initialized()
+ init()

  return runtime.core.cuda_driver_is_initialized()

@@ -3229,7 +3355,7 @@ def is_cuda_driver_initialized() -> bool:
  def get_devices() -> List[Device]:
  """Returns a list of devices supported in this environment."""

- assert_initialized()
+ init()

  devices = []
  if is_cpu_available():
@@ -3242,7 +3368,7 @@ def get_devices() -> List[Device]:
  def get_cuda_device_count() -> int:
  """Returns the number of CUDA devices supported in this environment."""

- assert_initialized()
+ init()

  return len(runtime.cuda_devices)

@@ -3250,7 +3376,7 @@ def get_cuda_device_count() -> int:
  def get_cuda_device(ordinal: Union[int, None] = None) -> Device:
  """Returns the CUDA device with the given ordinal or the current CUDA device if ordinal is None."""

- assert_initialized()
+ init()

  if ordinal is None:
  return runtime.get_current_cuda_device()
@@ -3261,7 +3387,7 @@ def get_cuda_device(ordinal: Union[int, None] = None) -> Device:
  def get_cuda_devices() -> List[Device]:
  """Returns a list of CUDA devices supported in this environment."""

- assert_initialized()
+ init()

  return runtime.cuda_devices

@@ -3269,7 +3395,7 @@ def get_cuda_devices() -> List[Device]:
  def get_preferred_device() -> Device:
  """Returns the preferred compute device, CUDA if available and CPU otherwise."""

- assert_initialized()
+ init()

  if is_cuda_available():
  return runtime.cuda_devices[0]
@@ -3282,7 +3408,7 @@ def get_preferred_device() -> Device:
  def get_device(ident: Devicelike = None) -> Device:
  """Returns the device identified by the argument."""

- assert_initialized()
+ init()

  return runtime.get_device(ident)

@@ -3290,7 +3416,7 @@ def get_device(ident: Devicelike = None) -> Device:
  def set_device(ident: Devicelike):
  """Sets the target device identified by the argument."""

- assert_initialized()
+ init()

  device = runtime.get_device(ident)
  runtime.set_default_device(device)
@@ -3311,7 +3437,7 @@ def map_cuda_device(alias: str, context: ctypes.c_void_p = None) -> Device:
  The associated wp.Device.
  """

- assert_initialized()
+ init()

  return runtime.map_cuda_device(alias, context)

@@ -3319,7 +3445,7 @@ def map_cuda_device(alias: str, context: ctypes.c_void_p = None) -> Device:
  def unmap_cuda_device(alias: str):
  """Remove a CUDA device with the given alias."""

- assert_initialized()
+ init()

  runtime.unmap_cuda_device(alias)

@@ -3327,7 +3453,7 @@ def unmap_cuda_device(alias: str):
  def is_mempool_supported(device: Devicelike):
  """Check if CUDA memory pool allocators are available on the device."""

- assert_initialized()
+ init()

  device = runtime.get_device(device)

@@ -3337,7 +3463,7 @@ def is_mempool_supported(device: Devicelike):
  def is_mempool_enabled(device: Devicelike):
  """Check if CUDA memory pool allocators are enabled on the device."""

- assert_initialized()
+ init()

  device = runtime.get_device(device)

@@ -3357,7 +3483,7 @@ def set_mempool_enabled(device: Devicelike, enable: bool):
  prior to graph capture.
  """

- assert_initialized()
+ init()

  device = runtime.get_device(device)

@@ -3387,7 +3513,7 @@ def set_mempool_release_threshold(device: Devicelike, threshold: Union[int, floa
  For example, 1024**3 means one GiB of memory.
  """

- assert_initialized()
+ init()

  device = runtime.get_device(device)

@@ -3409,7 +3535,7 @@ def set_mempool_release_threshold(device: Devicelike, threshold: Union[int, floa
  def get_mempool_release_threshold(device: Devicelike):
  """Get the CUDA memory pool release threshold on the device."""

- assert_initialized()
+ init()

  device = runtime.get_device(device)

@@ -3432,7 +3558,7 @@ def is_peer_access_supported(target_device: Devicelike, peer_device: Devicelike)
  A Boolean value indicating if this peer access is supported by the system.
  """

- assert_initialized()
+ init()

  target_device = runtime.get_device(target_device)
  peer_device = runtime.get_device(peer_device)
@@ -3453,7 +3579,7 @@ def is_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike):
  A Boolean value indicating if this peer access is currently enabled.
  """

- assert_initialized()
+ init()

  target_device = runtime.get_device(target_device)
  peer_device = runtime.get_device(peer_device)
@@ -3474,7 +3600,7 @@ def set_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike,
  CUDA pooled allocators, use `set_mempool_access_enabled()`.
  """

- assert_initialized()
+ init()

  target_device = runtime.get_device(target_device)
  peer_device = runtime.get_device(peer_device)
@@ -3505,7 +3631,10 @@ def is_mempool_access_supported(target_device: Devicelike, peer_device: Deviceli
  A Boolean value indicating if this memory pool access is supported by the system.
  """

- assert_initialized()
+ init()
+
+ target_device = runtime.get_device(target_device)
+ peer_device = runtime.get_device(peer_device)

  return target_device.is_mempool_supported and is_peer_access_supported(target_device, peer_device)

@@ -3520,7 +3649,7 @@ def is_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike
  A Boolean value indicating if this peer access is currently enabled.
  """

- assert_initialized()
+ init()

  target_device = runtime.get_device(target_device)
  peer_device = runtime.get_device(peer_device)
@@ -3538,7 +3667,7 @@ def set_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelik
  default CUDA allocators, use `set_peer_access_enabled()`.
  """

- assert_initialized()
+ init()

  target_device = runtime.get_device(target_device)
  peer_device = runtime.get_device(peer_device)
@@ -3640,34 +3769,87 @@ def wait_stream(stream: Stream, event: Event = None):

  class RegisteredGLBuffer:
  """
- Helper object to register a GL buffer with CUDA so that it can be mapped to a Warp array.
+ Helper class to register a GL buffer with CUDA so that it can be mapped to a Warp array.
+
+ Example usage::
+
+ import warp as wp
+ import numpy as np
+ from pyglet.gl import *
+
+ wp.init()
+
+ # create a GL buffer
+ gl_buffer_id = GLuint()
+ glGenBuffers(1, gl_buffer_id)
+
+ # copy some data to the GL buffer
+ glBindBuffer(GL_ARRAY_BUFFER, gl_buffer_id)
+ gl_data = np.arange(1024, dtype=np.float32)
+ glBufferData(GL_ARRAY_BUFFER, gl_data.nbytes, gl_data.ctypes.data, GL_DYNAMIC_DRAW)
+ glBindBuffer(GL_ARRAY_BUFFER, 0)
+
+ # register the GL buffer with CUDA
+ cuda_gl_buffer = wp.RegisteredGLBuffer(gl_buffer_id)
+
+ # map the GL buffer to a Warp array
+ arr = cuda_gl_buffer.map(dtype=wp.float32, shape=(1024,))
+ # launch a Warp kernel to manipulate or read the array
+ wp.launch(my_kernel, dim=1024, inputs=[arr])
+ # unmap the GL buffer
+ cuda_gl_buffer.unmap()
  """

- # Specifies no hints about how this resource will be used.
- # It is therefore assumed that this resource will be
- # read from and written to by CUDA. This is the default value.
  NONE = 0x00
+ """
+ Flag that specifies no hints about how this resource will be used.
+ It is therefore assumed that this resource will be
+ read from and written to by CUDA. This is the default value.
+ """

- # Specifies that CUDA will not write to this resource.
  READ_ONLY = 0x01
+ """
+ Flag that specifies that CUDA will not write to this resource.
+ """

- # Specifies that CUDA will not read from this resource and will write over the
- # entire contents of the resource, so none of the data previously
- # stored in the resource will be preserved.
  WRITE_DISCARD = 0x02
+ """
+ Flag that specifies that CUDA will not read from this resource and will write over the
+ entire contents of the resource, so none of the data previously
+ stored in the resource will be preserved.
+ """

- def __init__(self, gl_buffer_id: int, device: Devicelike = None, flags: int = NONE):
- """Create a new RegisteredGLBuffer object.
+ __fallback_warning_shown = False

+ def __init__(self, gl_buffer_id: int, device: Devicelike = None, flags: int = NONE, fallback_to_copy: bool = True):
+ """
  Args:
  gl_buffer_id: The OpenGL buffer id (GLuint).
  device: The device to register the buffer with. If None, the current device will be used.
- flags: A combination of the flags constants.
+ flags: A combination of the flags constants :attr:`NONE`, :attr:`READ_ONLY`, and :attr:`WRITE_DISCARD`.
+ fallback_to_copy: If True and CUDA/OpenGL interop is not available, fall back to copy operations between the Warp array and the OpenGL buffer. Otherwise, a ``RuntimeError`` will be raised.
+
+ Note:
+
+ The ``fallback_to_copy`` option (to use copy operations if CUDA graphics interop functionality is not available) requires pyglet version 2.0 or later. Install via ``pip install pyglet==2.*``.
  """
  self.gl_buffer_id = gl_buffer_id
  self.device = get_device(device)
  self.context = self.device.context
+ self.flags = flags
+ self.fallback_to_copy = fallback_to_copy
  self.resource = runtime.core.cuda_graphics_register_gl_buffer(self.context, gl_buffer_id, flags)
+ if self.resource is None:
+ if self.fallback_to_copy:
+ self.warp_buffer = None
+ self.warp_buffer_cpu = None
+ if not RegisteredGLBuffer.__fallback_warning_shown:
+ warp.utils.warn(
+ "Could not register GL buffer since CUDA/OpenGL interoperability is not available. Falling back to copy operations between the Warp array and the OpenGL buffer.",
+ )
+ RegisteredGLBuffer.__fallback_warning_shown = True
+ else:
+ raise RuntimeError(f"Failed to register OpenGL buffer {gl_buffer_id} with CUDA")

  def __del__(self):
  if not self.resource:
@@ -3687,18 +3869,48 @@ class RegisteredGLBuffer:
  Returns:
  A Warp array object representing the mapped OpenGL buffer.
  """
- runtime.core.cuda_graphics_map(self.context, self.resource)
- ctypes.POINTER(ctypes.c_uint64), ctypes.POINTER(ctypes.c_size_t)
- ptr = ctypes.c_uint64(0)
- size = ctypes.c_size_t(0)
- runtime.core.cuda_graphics_device_ptr_and_size(
- self.context, self.resource, ctypes.byref(ptr), ctypes.byref(size)
- )
- return warp.array(ptr=ptr.value, dtype=dtype, shape=shape, device=self.device)
+ if self.resource is not None:
+ runtime.core.cuda_graphics_map(self.context, self.resource)
+ ptr = ctypes.c_uint64(0)
+ size = ctypes.c_size_t(0)
+ runtime.core.cuda_graphics_device_ptr_and_size(
+ self.context, self.resource, ctypes.byref(ptr), ctypes.byref(size)
+ )
+ return warp.array(ptr=ptr.value, dtype=dtype, shape=shape, device=self.device)
+ elif self.fallback_to_copy:
+ if self.warp_buffer is None or self.warp_buffer.dtype != dtype or self.warp_buffer.shape != shape:
+ self.warp_buffer = warp.empty(shape, dtype, device=self.device)
+ self.warp_buffer_cpu = warp.empty(shape, dtype, device="cpu", pinned=True)
+
+ if self.flags == self.READ_ONLY or self.flags == self.NONE:
+ # copy from OpenGL buffer to Warp array
+ from pyglet import gl
+
+ gl.glBindBuffer(gl.GL_ARRAY_BUFFER, self.gl_buffer_id)
+ nbytes = self.warp_buffer.size * warp.types.type_size_in_bytes(dtype)
+ gl.glGetBufferSubData(gl.GL_ARRAY_BUFFER, 0, nbytes, self.warp_buffer_cpu.ptr)
+ gl.glBindBuffer(gl.GL_ARRAY_BUFFER, 0)
+ warp.copy(self.warp_buffer, self.warp_buffer_cpu)
+ return self.warp_buffer
+
+ return None

  def unmap(self):
  """Unmap the OpenGL buffer."""
- runtime.core.cuda_graphics_unmap(self.context, self.resource)
+ if self.resource is not None:
+ runtime.core.cuda_graphics_unmap(self.context, self.resource)
+ elif self.fallback_to_copy:
+ if self.warp_buffer is None:
+ raise RuntimeError("RegisteredGLBuffer first has to be mapped")
+
+ if self.flags == self.WRITE_DISCARD or self.flags == self.NONE:
+ # copy from Warp array to OpenGL buffer
+ from pyglet import gl
+
+ gl.glBindBuffer(gl.GL_ARRAY_BUFFER, self.gl_buffer_id)
+ buffer = self.warp_buffer.numpy()
+ gl.glBufferData(gl.GL_ARRAY_BUFFER, buffer.nbytes, buffer.ctypes.data, gl.GL_DYNAMIC_DRAW)
+ gl.glBindBuffer(gl.GL_ARRAY_BUFFER, 0)


  def zeros(
@@ -4253,7 +4465,7 @@ def launch(
  If negative or zero, the maximum hardware value will be used.
  """

- assert_initialized()
+ init()

  # if stream is specified, use the associated device
  if stream is not None:
@@ -4496,7 +4708,7 @@ def force_load(device: Union[Device, str, List[Device], List[str]] = None, modul


  def load_module(
- module: Union[Module, ModuleType, str] = None, device: Union[Device, str] = None, recursive: bool = False
+ module: Union[Module, types.ModuleType, str] = None, device: Union[Device, str] = None, recursive: bool = False
  ):
  """Force user-defined module to be compiled and loaded

@@ -4514,7 +4726,7 @@ def load_module(
  module_name = module.__name__
  elif isinstance(module, Module):
  module_name = module.name
- elif isinstance(module, ModuleType):
+ elif isinstance(module, types.ModuleType):
  module_name = module.__name__
  elif isinstance(module, str):
  module_name = module
@@ -4863,13 +5075,20 @@ def copy(

  # copy gradient, if needed
  if hasattr(src, "grad") and src.grad is not None and hasattr(dest, "grad") and dest.grad is not None:
- copy(dest.grad, src.grad, stream=stream)
+ copy(dest.grad, src.grad, dest_offset=dest_offset, src_offset=src_offset, count=count, stream=stream)

  if runtime.tape:
- runtime.tape.record_func(backward=lambda: adj_copy(dest.grad, src.grad, stream=stream), arrays=[dest, src])
+ runtime.tape.record_func(
+ backward=lambda: adj_copy(
+ dest.grad, src.grad, dest_offset=dest_offset, src_offset=src_offset, count=count, stream=stream
+ ),
+ arrays=[dest, src],
+ )


- def adj_copy(adj_dest: warp.array, adj_src: warp.array, stream: Stream = None):
+ def adj_copy(
+ adj_dest: warp.array, adj_src: warp.array, dest_offset: int, src_offset: int, count: int, stream: Stream = None
+ ):
  """Copy adjoint operation for wp.copy() calls on the tape.

  Args:
@@ -4877,7 +5096,7 @@ def adj_copy(adj_dest: warp.array, adj_src: warp.array, stream: Stream = None):
  adj_src: Source array adjoint
  stream: The stream on which the copy was performed in the forward pass
  """
- copy(adj_src, adj_dest, stream=stream)
+ copy(adj_src, adj_dest, dest_offset=dest_offset, src_offset=src_offset, count=count, stream=stream)


  def type_str(t):
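
The last hunks above change how ``wp.copy()`` propagates gradients: the destination/source offsets and element count are now forwarded both to the gradient copy and to the adjoint recorded on the tape, and public entry points call ``init()`` lazily instead of asserting prior initialization. A minimal sketch of what this enables, assuming arbitrary array sizes and contents (illustrative only, not part of the package)::

    import warp as wp

    wp.init()  # optional here; in 1.2.x most entry points call init() lazily

    src = wp.array([1.0, 2.0, 3.0, 4.0], dtype=wp.float32, requires_grad=True)
    dst = wp.zeros(8, dtype=wp.float32, requires_grad=True)

    with wp.Tape() as tape:
        # copy two elements starting at src[1] into dst[4:6]; the recorded
        # adjoint copy now reuses the same dest_offset, src_offset, and count
        wp.copy(dst, src, dest_offset=4, src_offset=1, count=2)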