warp-lang 1.1.0__py3-none-macosx_10_13_universal2.whl → 1.2.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (218)
  1. warp/bin/libwarp-clang.dylib +0 -0
  2. warp/bin/libwarp.dylib +0 -0
  3. warp/build.py +10 -37
  4. warp/build_dll.py +2 -2
  5. warp/builtins.py +274 -6
  6. warp/codegen.py +51 -4
  7. warp/config.py +2 -2
  8. warp/constants.py +4 -0
  9. warp/context.py +418 -203
  10. warp/examples/benchmarks/benchmark_api.py +0 -2
  11. warp/examples/benchmarks/benchmark_cloth_warp.py +0 -1
  12. warp/examples/benchmarks/benchmark_launches.py +0 -2
  13. warp/examples/core/example_dem.py +0 -2
  14. warp/examples/core/example_fluid.py +0 -2
  15. warp/examples/core/example_graph_capture.py +0 -2
  16. warp/examples/core/example_marching_cubes.py +0 -2
  17. warp/examples/core/example_mesh.py +0 -2
  18. warp/examples/core/example_mesh_intersect.py +0 -2
  19. warp/examples/core/example_nvdb.py +0 -2
  20. warp/examples/core/example_raycast.py +0 -2
  21. warp/examples/core/example_raymarch.py +0 -2
  22. warp/examples/core/example_render_opengl.py +0 -2
  23. warp/examples/core/example_sph.py +0 -2
  24. warp/examples/core/example_torch.py +0 -3
  25. warp/examples/core/example_wave.py +0 -2
  26. warp/examples/fem/example_apic_fluid.py +140 -115
  27. warp/examples/fem/example_burgers.py +262 -0
  28. warp/examples/fem/example_convection_diffusion.py +0 -2
  29. warp/examples/fem/example_convection_diffusion_dg.py +0 -2
  30. warp/examples/fem/example_deformed_geometry.py +0 -2
  31. warp/examples/fem/example_diffusion.py +0 -2
  32. warp/examples/fem/example_diffusion_3d.py +5 -4
  33. warp/examples/fem/example_diffusion_mgpu.py +0 -2
  34. warp/examples/fem/example_mixed_elasticity.py +0 -2
  35. warp/examples/fem/example_navier_stokes.py +0 -2
  36. warp/examples/fem/example_stokes.py +0 -2
  37. warp/examples/fem/example_stokes_transfer.py +0 -2
  38. warp/examples/optim/example_bounce.py +0 -2
  39. warp/examples/optim/example_cloth_throw.py +0 -2
  40. warp/examples/optim/example_diffray.py +0 -2
  41. warp/examples/optim/example_drone.py +0 -2
  42. warp/examples/optim/example_inverse_kinematics.py +0 -2
  43. warp/examples/optim/example_inverse_kinematics_torch.py +0 -2
  44. warp/examples/optim/example_spring_cage.py +0 -2
  45. warp/examples/optim/example_trajectory.py +0 -2
  46. warp/examples/optim/example_walker.py +0 -2
  47. warp/examples/sim/example_cartpole.py +0 -2
  48. warp/examples/sim/example_cloth.py +0 -2
  49. warp/examples/sim/example_granular.py +0 -2
  50. warp/examples/sim/example_granular_collision_sdf.py +0 -2
  51. warp/examples/sim/example_jacobian_ik.py +0 -2
  52. warp/examples/sim/example_particle_chain.py +0 -2
  53. warp/examples/sim/example_quadruped.py +0 -2
  54. warp/examples/sim/example_rigid_chain.py +0 -2
  55. warp/examples/sim/example_rigid_contact.py +0 -2
  56. warp/examples/sim/example_rigid_force.py +0 -2
  57. warp/examples/sim/example_rigid_gyroscopic.py +0 -2
  58. warp/examples/sim/example_rigid_soft_contact.py +0 -2
  59. warp/examples/sim/example_soft_body.py +0 -2
  60. warp/fem/__init__.py +1 -0
  61. warp/fem/cache.py +3 -1
  62. warp/fem/geometry/__init__.py +1 -0
  63. warp/fem/geometry/element.py +4 -0
  64. warp/fem/geometry/grid_3d.py +0 -4
  65. warp/fem/geometry/nanogrid.py +455 -0
  66. warp/fem/integrate.py +63 -9
  67. warp/fem/space/__init__.py +43 -158
  68. warp/fem/space/basis_space.py +34 -0
  69. warp/fem/space/collocated_function_space.py +1 -1
  70. warp/fem/space/grid_2d_function_space.py +13 -132
  71. warp/fem/space/grid_3d_function_space.py +16 -154
  72. warp/fem/space/hexmesh_function_space.py +37 -134
  73. warp/fem/space/nanogrid_function_space.py +202 -0
  74. warp/fem/space/quadmesh_2d_function_space.py +12 -119
  75. warp/fem/space/restriction.py +4 -1
  76. warp/fem/space/shape/__init__.py +77 -0
  77. warp/fem/space/shape/cube_shape_function.py +5 -15
  78. warp/fem/space/tetmesh_function_space.py +6 -76
  79. warp/fem/space/trimesh_2d_function_space.py +6 -76
  80. warp/native/array.h +12 -3
  81. warp/native/builtin.h +48 -5
  82. warp/native/bvh.cpp +14 -10
  83. warp/native/bvh.cu +23 -15
  84. warp/native/bvh.h +1 -0
  85. warp/native/clang/clang.cpp +2 -1
  86. warp/native/crt.cpp +11 -1
  87. warp/native/crt.h +18 -1
  88. warp/native/exports.h +187 -0
  89. warp/native/mat.h +47 -0
  90. warp/native/mesh.cpp +1 -1
  91. warp/native/mesh.cu +1 -2
  92. warp/native/nanovdb/GridHandle.h +366 -0
  93. warp/native/nanovdb/HostBuffer.h +590 -0
  94. warp/native/nanovdb/NanoVDB.h +3999 -2157
  95. warp/native/nanovdb/PNanoVDB.h +936 -99
  96. warp/native/quat.h +28 -1
  97. warp/native/rand.h +5 -1
  98. warp/native/vec.h +45 -1
  99. warp/native/volume.cpp +335 -103
  100. warp/native/volume.cu +39 -13
  101. warp/native/volume.h +725 -303
  102. warp/native/volume_builder.cu +381 -360
  103. warp/native/volume_builder.h +16 -1
  104. warp/native/volume_impl.h +61 -0
  105. warp/native/warp.cu +8 -2
  106. warp/native/warp.h +15 -7
  107. warp/render/render_opengl.py +191 -52
  108. warp/sim/integrator_featherstone.py +10 -3
  109. warp/sim/integrator_xpbd.py +16 -22
  110. warp/sparse.py +89 -27
  111. warp/stubs.py +83 -0
  112. warp/tests/assets/test_index_grid.nvdb +0 -0
  113. warp/tests/aux_test_dependent.py +0 -2
  114. warp/tests/aux_test_grad_customs.py +0 -2
  115. warp/tests/aux_test_reference.py +0 -2
  116. warp/tests/aux_test_reference_reference.py +0 -2
  117. warp/tests/aux_test_square.py +0 -2
  118. warp/tests/disabled_kinematics.py +0 -2
  119. warp/tests/test_adam.py +0 -2
  120. warp/tests/test_arithmetic.py +0 -36
  121. warp/tests/test_array.py +9 -11
  122. warp/tests/test_array_reduce.py +0 -2
  123. warp/tests/test_async.py +0 -2
  124. warp/tests/test_atomic.py +0 -2
  125. warp/tests/test_bool.py +58 -50
  126. warp/tests/test_builtins_resolution.py +0 -2
  127. warp/tests/test_bvh.py +0 -2
  128. warp/tests/test_closest_point_edge_edge.py +0 -1
  129. warp/tests/test_codegen.py +0 -4
  130. warp/tests/test_compile_consts.py +130 -10
  131. warp/tests/test_conditional.py +0 -2
  132. warp/tests/test_copy.py +0 -2
  133. warp/tests/test_ctypes.py +6 -8
  134. warp/tests/test_dense.py +0 -2
  135. warp/tests/test_devices.py +0 -2
  136. warp/tests/test_dlpack.py +9 -11
  137. warp/tests/test_examples.py +42 -39
  138. warp/tests/test_fabricarray.py +0 -3
  139. warp/tests/test_fast_math.py +0 -2
  140. warp/tests/test_fem.py +75 -54
  141. warp/tests/test_fp16.py +0 -2
  142. warp/tests/test_func.py +0 -2
  143. warp/tests/test_generics.py +27 -2
  144. warp/tests/test_grad.py +147 -8
  145. warp/tests/test_grad_customs.py +0 -2
  146. warp/tests/test_hash_grid.py +1 -3
  147. warp/tests/test_import.py +0 -2
  148. warp/tests/test_indexedarray.py +0 -2
  149. warp/tests/test_intersect.py +0 -2
  150. warp/tests/test_jax.py +0 -2
  151. warp/tests/test_large.py +11 -9
  152. warp/tests/test_launch.py +0 -2
  153. warp/tests/test_lerp.py +10 -54
  154. warp/tests/test_linear_solvers.py +3 -5
  155. warp/tests/test_lvalue.py +0 -2
  156. warp/tests/test_marching_cubes.py +0 -2
  157. warp/tests/test_mat.py +0 -2
  158. warp/tests/test_mat_lite.py +0 -2
  159. warp/tests/test_mat_scalar_ops.py +0 -2
  160. warp/tests/test_math.py +0 -2
  161. warp/tests/test_matmul.py +35 -37
  162. warp/tests/test_matmul_lite.py +29 -31
  163. warp/tests/test_mempool.py +0 -2
  164. warp/tests/test_mesh.py +0 -3
  165. warp/tests/test_mesh_query_aabb.py +0 -2
  166. warp/tests/test_mesh_query_point.py +0 -2
  167. warp/tests/test_mesh_query_ray.py +0 -2
  168. warp/tests/test_mlp.py +0 -2
  169. warp/tests/test_model.py +0 -2
  170. warp/tests/test_module_hashing.py +111 -0
  171. warp/tests/test_modules_lite.py +0 -3
  172. warp/tests/test_multigpu.py +0 -2
  173. warp/tests/test_noise.py +0 -4
  174. warp/tests/test_operators.py +0 -2
  175. warp/tests/test_options.py +0 -2
  176. warp/tests/test_peer.py +0 -2
  177. warp/tests/test_pinned.py +0 -2
  178. warp/tests/test_print.py +0 -2
  179. warp/tests/test_quat.py +0 -2
  180. warp/tests/test_rand.py +41 -5
  181. warp/tests/test_reload.py +0 -10
  182. warp/tests/test_rounding.py +0 -2
  183. warp/tests/test_runlength_encode.py +0 -2
  184. warp/tests/test_sim_grad.py +0 -2
  185. warp/tests/test_sim_kinematics.py +0 -2
  186. warp/tests/test_smoothstep.py +0 -2
  187. warp/tests/test_snippet.py +0 -2
  188. warp/tests/test_sparse.py +0 -2
  189. warp/tests/test_spatial.py +0 -2
  190. warp/tests/test_special_values.py +362 -0
  191. warp/tests/test_streams.py +0 -2
  192. warp/tests/test_struct.py +0 -2
  193. warp/tests/test_tape.py +0 -2
  194. warp/tests/test_torch.py +0 -2
  195. warp/tests/test_transient_module.py +0 -2
  196. warp/tests/test_types.py +0 -2
  197. warp/tests/test_utils.py +0 -2
  198. warp/tests/test_vec.py +0 -2
  199. warp/tests/test_vec_lite.py +0 -2
  200. warp/tests/test_vec_scalar_ops.py +0 -2
  201. warp/tests/test_verify_fp.py +0 -2
  202. warp/tests/test_volume.py +237 -13
  203. warp/tests/test_volume_write.py +86 -3
  204. warp/tests/unittest_serial.py +10 -9
  205. warp/tests/unittest_suites.py +6 -2
  206. warp/tests/unittest_utils.py +2 -171
  207. warp/tests/unused_test_misc.py +0 -2
  208. warp/tests/walkthrough_debug.py +1 -1
  209. warp/thirdparty/unittest_parallel.py +37 -40
  210. warp/types.py +514 -77
  211. {warp_lang-1.1.0.dist-info → warp_lang-1.2.0.dist-info}/METADATA +57 -30
  212. warp_lang-1.2.0.dist-info/RECORD +359 -0
  213. warp/examples/fem/example_convection_diffusion_dg0.py +0 -204
  214. warp/native/nanovdb/PNanoVDBWrite.h +0 -295
  215. warp_lang-1.1.0.dist-info/RECORD +0 -352
  216. {warp_lang-1.1.0.dist-info → warp_lang-1.2.0.dist-info}/LICENSE.md +0 -0
  217. {warp_lang-1.1.0.dist-info → warp_lang-1.2.0.dist-info}/WHEEL +0 -0
  218. {warp_lang-1.1.0.dist-info → warp_lang-1.2.0.dist-info}/top_level.txt +0 -0
warp/context.py CHANGED
@@ -6,6 +6,7 @@
  # license agreement from NVIDIA CORPORATION is strictly prohibited.

  import ast
+ import builtins
  import ctypes
  import functools
  import hashlib
@@ -18,7 +19,8 @@ import platform
  import sys
  import types
  from copy import copy as shallowcopy
- from types import ModuleType
+ from pathlib import Path
+ from struct import pack as struct_pack
  from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union

  import numpy as np
@@ -345,6 +347,8 @@ class Function:
  def call_builtin(func: Function, *params) -> Tuple[bool, Any]:
  uses_non_warp_array_type = False

+ warp.context.init()
+
  # Retrieve the built-in function from Warp's dll.
  c_func = getattr(warp.context.runtime.core, func.mangled_name)

@@ -1168,7 +1172,7 @@ def get_module(name):
  # clear out old kernels, funcs, struct definitions
  old_module.kernels = {}
  old_module.functions = {}
- old_module.constants = []
+ old_module.constants = {}
  old_module.structs = {}
  old_module.loader = parent_loader

@@ -1315,7 +1319,7 @@ class Module:

  self.kernels = {}
  self.functions = {}
- self.constants = []
+ self.constants = {} # Any constants referenced in this module including those defined in other modules
  self.structs = {}

  self.cpu_module = None
@@ -1442,7 +1446,13 @@ class Module:
  if isinstance(arg.type, warp.codegen.Struct) and arg.type.module is not None:
  add_ref(arg.type.module)

- def hash_module(self):
+ def hash_module(self, recompute_content_hash=False):
+ """Recursively compute and return a hash for the module.
+
+ If ``recompute_content_hash`` is False, each module's previously
+ computed ``content_hash`` will be used.
+ """
+
  def get_annotations(obj: Any) -> Mapping[str, Any]:
  """Alternative to `inspect.get_annotations()` for Python 3.9 and older."""
  # See https://docs.python.org/3/howto/annotations.html#accessing-the-annotations-dict-of-an-object-in-python-3-9-and-older
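The new ``recompute_content_hash`` parameter can be driven directly from user code. A minimal sketch (not part of this release; it assumes the kernel's module can be looked up by its Python module name via ``warp.context.get_module()``)::

    import warp as wp

    wp.init()

    SCALE = wp.constant(2.0)

    @wp.kernel
    def scale_kernel(a: wp.array(dtype=float)):
        i = wp.tid()
        a[i] = a[i] * SCALE

    # hash_module() reuses a previously computed content_hash by default;
    # recompute_content_hash=True forces a fresh walk over structs, functions,
    # kernels, and the constants they reference (SCALE above).
    module = wp.context.get_module(__name__)
    digest = module.hash_module(recompute_content_hash=True)
    print(digest.hex()[:7])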
@@ -1461,10 +1471,13 @@ class Module:
  # The visited set tracks modules already visited to avoid circular references.

  # check if we need to update the content hash
- if not module.content_hash:
+ if not module.content_hash or recompute_content_hash:
  # recompute content hash
  ch = hashlib.sha256()

+ # Start with an empty constants dictionary in case any have been removed
+ module.constants = {}
+
  # struct source
  for struct in module.structs.values():
  s = ",".join(
@@ -1474,28 +1487,33 @@ class Module:
  ch.update(bytes(s, "utf-8"))

  # functions source
- for func in module.functions.values():
- s = func.adj.source
- ch.update(bytes(s, "utf-8"))
-
- if func.custom_grad_func:
- s = func.custom_grad_func.adj.source
- ch.update(bytes(s, "utf-8"))
- if func.custom_replay_func:
- s = func.custom_replay_func.adj.source
- if func.replay_snippet:
- s = func.replay_snippet
- if func.native_snippet:
- s = func.native_snippet
- ch.update(bytes(s, "utf-8"))
- if func.adj_native_snippet:
- s = func.adj_native_snippet
+ for function in module.functions.values():
+ # include all overloads
+ for sig, func in function.user_overloads.items():
+ # signature
+ ch.update(bytes(sig, "utf-8"))
+
+ # source
+ s = func.adj.source
  ch.update(bytes(s, "utf-8"))

- # cache func arg types
- for arg, arg_type in func.adj.arg_types.items():
- s = f"{arg}: {get_type_name(arg_type)}"
- ch.update(bytes(s, "utf-8"))
+ if func.custom_grad_func:
+ s = func.custom_grad_func.adj.source
+ ch.update(bytes(s, "utf-8"))
+ if func.custom_replay_func:
+ s = func.custom_replay_func.adj.source
+ if func.replay_snippet:
+ s = func.replay_snippet
+ if func.native_snippet:
+ s = func.native_snippet
+ ch.update(bytes(s, "utf-8"))
+ if func.adj_native_snippet:
+ s = func.adj_native_snippet
+ ch.update(bytes(s, "utf-8"))
+
+ # Populate constants referenced in this function
+ if func.adj:
+ module.constants.update(func.adj.get_constant_references())

  # kernel source
  for kernel in module.kernels.values():
@@ -1511,6 +1529,34 @@ class Module:
  for sig in sorted(kernel.overloads.keys()):
  ch.update(bytes(sig, "utf-8"))

+ # Populate constants referenced in this kernel
+ module.constants.update(kernel.adj.get_constant_references())
+
+ # constants referenced in this module
+ for constant_name, constant_value in module.constants.items():
+ ch.update(bytes(constant_name, "utf-8"))
+
+ # hash the constant value
+ if isinstance(constant_value, builtins.bool):
+ # This needs to come before the check for `int` since all boolean
+ # values are also instances of `int`.
+ ch.update(struct_pack("?", constant_value))
+ elif isinstance(constant_value, int):
+ ch.update(struct_pack("<q", constant_value))
+ elif isinstance(constant_value, float):
+ ch.update(struct_pack("<d", constant_value))
+ elif isinstance(constant_value, warp.types.float16):
+ # float16 is a special case
+ p = ctypes.pointer(ctypes.c_float(constant_value.value))
+ ch.update(p.contents)
+ elif isinstance(constant_value, tuple(warp.types.scalar_types)):
+ p = ctypes.pointer(constant_value._type_(constant_value.value))
+ ch.update(p.contents)
+ elif isinstance(constant_value, ctypes.Array):
+ ch.update(bytes(constant_value))
+ else:
+ raise RuntimeError(f"Invalid constant type: {type(constant_value)}")
+
  module.content_hash = ch.digest()

  h = hashlib.sha256()
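The ordering of the ``isinstance`` checks above matters because ``bool`` is a subclass of ``int`` in Python. A stand-alone sketch of the plain-Python cases (Warp's typed scalars, ``float16``, and ctypes arrays are omitted here)::

    import hashlib
    from struct import pack as struct_pack

    def hash_constant(ch, name, value):
        # mirror the rules above: name as UTF-8 bytes, bool as one byte,
        # int as little-endian int64, float as little-endian float64
        ch.update(bytes(name, "utf-8"))
        if isinstance(value, bool):
            ch.update(struct_pack("?", value))
        elif isinstance(value, int):
            ch.update(struct_pack("<q", value))
        elif isinstance(value, float):
            ch.update(struct_pack("<d", value))
        else:
            raise RuntimeError(f"Unsupported constant type: {type(value)}")

    ch = hashlib.sha256()
    hash_constant(ch, "ITERATIONS", 16)
    hash_constant(ch, "TOLERANCE", 1.0e-6)
    print(ch.hexdigest()[:7])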
@@ -1529,10 +1575,6 @@ class Module:

  h.update(bytes(warp.config.mode, "utf-8"))

- # compile-time constants (global)
- if warp.types._constant_hash:
- h.update(warp.types._constant_hash.digest())
-
  # recurse on references
  visited.add(module)

@@ -1546,7 +1588,7 @@ class Module:

  return hash_recursive(self, visited=set())

- def load(self, device):
+ def load(self, device) -> bool:
  from warp.utils import ScopedTimer

  device = get_device(device)
@@ -1570,68 +1612,19 @@ class Module:
  if not warp.is_cuda_available():
  raise RuntimeError("Failed to build CUDA module because CUDA is not available")

- with ScopedTimer(f"Module {self.name} load on device '{device}'", active=not warp.config.quiet):
- build_path = warp.build.kernel_bin_dir
- gen_path = warp.build.kernel_gen_dir
-
- if not os.path.exists(build_path):
- os.makedirs(build_path)
- if not os.path.exists(gen_path):
- os.makedirs(gen_path)
+ module_name = "wp_" + self.name
+ module_hash = self.hash_module()

- module_name = "wp_" + self.name
- module_path = os.path.join(build_path, module_name)
- module_hash = self.hash_module()
-
- builder = ModuleBuilder(self, self.options)
+ # use a unique module path using the module short hash
+ module_dir = os.path.join(warp.config.kernel_cache_dir, f"{module_name}_{module_hash.hex()[:7]}")

+ with ScopedTimer(
+ f"Module {self.name} {module_hash.hex()[:7]} load on device '{device}'", active=not warp.config.quiet
+ ):
+ # -----------------------------------------------------------
+ # determine output paths
  if device.is_cpu:
- obj_path = os.path.join(build_path, module_name)
- obj_path = obj_path + ".o"
- cpu_hash_path = module_path + ".cpu.hash"
-
- # check cache
- if warp.config.cache_kernels and os.path.isfile(cpu_hash_path) and os.path.isfile(obj_path):
- with open(cpu_hash_path, "rb") as f:
- cache_hash = f.read()
-
- if cache_hash == module_hash:
- runtime.llvm.load_obj(obj_path.encode("utf-8"), module_name.encode("utf-8"))
- self.cpu_module = module_name
- return True
-
- # build
- try:
- cpp_path = os.path.join(gen_path, module_name + ".cpp")
-
- # write cpp sources
- cpp_source = builder.codegen("cpu")
-
- cpp_file = open(cpp_path, "w")
- cpp_file.write(cpp_source)
- cpp_file.close()
-
- # build object code
- with ScopedTimer("Compile x86", active=warp.config.verbose):
- warp.build.build_cpu(
- obj_path,
- cpp_path,
- mode=self.options["mode"],
- fast_math=self.options["fast_math"],
- verify_fp=warp.config.verify_fp,
- )
-
- # update cpu hash
- with open(cpu_hash_path, "wb") as f:
- f.write(module_hash)
-
- # load the object code
- runtime.llvm.load_obj(obj_path.encode("utf-8"), module_name.encode("utf-8"))
- self.cpu_module = module_name
-
- except Exception as e:
- self.cpu_build_failed = True
- raise (e)
+ output_name = "module_codegen.o"

  elif device.is_cuda:
  # determine whether to use PTX or CUBIN
@@ -1650,62 +1643,138 @@ class Module:

  if use_ptx:
  output_arch = min(device.arch, warp.config.ptx_target_arch)
- output_path = module_path + f".sm{output_arch}.ptx"
+ output_name = f"module_codegen.sm{output_arch}.ptx"
  else:
  output_arch = device.arch
- output_path = module_path + f".sm{output_arch}.cubin"
+ output_name = f"module_codegen.sm{output_arch}.cubin"
+
+ # final object binary path
+ binary_path = os.path.join(module_dir, output_name)
+
+ # -----------------------------------------------------------
+ # check cache and build if necessary
+
+ build_dir = None
+
+ if not os.path.exists(binary_path) or not warp.config.cache_kernels:
+ builder = ModuleBuilder(self, self.options)
+
+ # create a temporary (process unique) dir for build outputs before moving to the binary dir
+ build_dir = os.path.join(
+ warp.config.kernel_cache_dir, f"{module_name}_{module_hash.hex()[:7]}_p{os.getpid()}"
+ )
+
+ # dir may exist from previous attempts / runs / archs
+ Path(build_dir).mkdir(parents=True, exist_ok=True)
+
+ # build CPU
+ if device.is_cpu:
+ # build
+ try:
+ source_code_path = os.path.join(build_dir, "module_codegen.cpp")
+
+ # write cpp sources
+ cpp_source = builder.codegen("cpu")

- cuda_hash_path = module_path + f".sm{output_arch}.hash"
+ with open(source_code_path, "w") as cpp_file:
+ cpp_file.write(cpp_source)

- # check cache
- if warp.config.cache_kernels and os.path.isfile(cuda_hash_path) and os.path.isfile(output_path):
- with open(cuda_hash_path, "rb") as f:
- cache_hash = f.read()
+ output_path = os.path.join(build_dir, output_name)

- if cache_hash == module_hash:
- cuda_module = warp.build.load_cuda(output_path, device)
- if cuda_module is not None:
- self.cuda_modules[device.context] = cuda_module
- return True
+ # build object code
+ with ScopedTimer("Compile x86", active=warp.config.verbose):
+ warp.build.build_cpu(
+ output_path,
+ source_code_path,
+ mode=self.options["mode"],
+ fast_math=self.options["fast_math"],
+ verify_fp=warp.config.verify_fp,
+ )
+
+ except Exception as e:
+ self.cpu_build_failed = True
+ raise (e)
+
+ elif device.is_cuda:
+ # build
+ try:
+ source_code_path = os.path.join(build_dir, "module_codegen.cu")
+
+ # write cuda sources
+ cu_source = builder.codegen("cuda")
+
+ with open(source_code_path, "w") as cu_file:
+ cu_file.write(cu_source)
+
+ output_path = os.path.join(build_dir, output_name)
+
+ # generate PTX or CUBIN
+ with ScopedTimer("Compile CUDA", active=warp.config.verbose):
+ warp.build.build_cuda(
+ source_code_path,
+ output_arch,
+ output_path,
+ config=self.options["mode"],
+ fast_math=self.options["fast_math"],
+ verify_fp=warp.config.verify_fp,
+ )
+
+ except Exception as e:
+ self.cuda_build_failed = True
+ raise (e)
+
+ # -----------------------------------------------------------
+ # update cache

- # build
  try:
- cu_path = os.path.join(gen_path, module_name + ".cu")
-
- # write cuda sources
- cu_source = builder.codegen("cuda")
-
- cu_file = open(cu_path, "w")
- cu_file.write(cu_source)
- cu_file.close()
-
- # generate PTX or CUBIN
- with ScopedTimer("Compile CUDA", active=warp.config.verbose):
- warp.build.build_cuda(
- cu_path,
- output_arch,
- output_path,
- config=self.options["mode"],
- fast_math=self.options["fast_math"],
- verify_fp=warp.config.verify_fp,
- )
-
- # update cuda hash
- with open(cuda_hash_path, "wb") as f:
- f.write(module_hash)
-
- # load the module
- cuda_module = warp.build.load_cuda(output_path, device)
- if cuda_module is not None:
- self.cuda_modules[device.context] = cuda_module
- else:
- raise Exception(f"Failed to load CUDA module '{self.name}'")
+ # Copy process-specific build directory to a process-independent location
+ os.rename(build_dir, module_dir)
+ except (OSError, FileExistsError):
+ # another process likely updated the module dir first
+ pass

- except Exception as e:
- self.cuda_build_failed = True
- raise (e)
+ if os.path.exists(module_dir):
+ if not os.path.exists(binary_path):
+ # copy our output file to the destination module
+ # this is necessary in case different processes
+ # have different GPU architectures / devices
+ try:
+ os.rename(output_path, binary_path)
+ except (OSError, FileExistsError):
+ # another process likely updated the module dir first
+ pass

- return True
+ try:
+ final_source_path = os.path.join(module_dir, os.path.basename(source_code_path))
+ if not os.path.exists(final_source_path):
+ os.rename(source_code_path, final_source_path)
+ except (OSError, FileExistsError):
+ # another process likely updated the module dir first
+ pass
+ except Exception as e:
+ # We don't need source_code_path to be copied successfully to proceed, so warn and keep running
+ warp.utils.warn(f"Exception when renaming {source_code_path}: {e}")
+
+ # -----------------------------------------------------------
+ # Load CPU or CUDA binary
+ if device.is_cpu:
+ runtime.llvm.load_obj(binary_path.encode("utf-8"), module_name.encode("utf-8"))
+ self.cpu_module = module_name
+
+ elif device.is_cuda:
+ cuda_module = warp.build.load_cuda(binary_path, device)
+ if cuda_module is not None:
+ self.cuda_modules[device.context] = cuda_module
+ else:
+ raise Exception(f"Failed to load CUDA module '{self.name}'")
+
+ if build_dir:
+ import shutil
+
+ # clean up build_dir used for this process regardless
+ shutil.rmtree(build_dir, ignore_errors=True)
+
+ return True

  def unload(self):
  if self.cpu_module:
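The rewritten ``load()`` drops the per-file ``.hash`` sidecars: each module now gets one kernel-cache directory keyed by its content hash, built in a process-unique scratch directory and renamed into place so concurrent processes do not trample each other. A hypothetical sketch of the resulting layout (the cache root and hash value are placeholders; the real root comes from ``warp.config.kernel_cache_dir``)::

    import os

    kernel_cache_dir = "/home/user/.cache/warp/1.2.0"  # placeholder cache root
    module_name = "wp_example"                         # "wp_" + Python module name
    short_hash = "a1b2c3d"                             # module_hash.hex()[:7]

    # shared, process-independent directory holding generated source and binaries
    module_dir = os.path.join(kernel_cache_dir, f"{module_name}_{short_hash}")
    cpu_binary = os.path.join(module_dir, "module_codegen.o")
    ptx_binary = os.path.join(module_dir, "module_codegen.sm75.ptx")  # arch-specific name

    # process-unique scratch directory that load() renames into module_dir on success
    build_dir = os.path.join(kernel_cache_dir, f"{module_name}_{short_hash}_p{os.getpid()}")

    print(module_dir, cpu_binary, ptx_binary, build_dir, sep="\n")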
@@ -2578,22 +2647,36 @@ class Runtime:
  ]
  self.core.cutlass_gemm.restype = ctypes.c_bool

- self.core.volume_create_host.argtypes = [ctypes.c_void_p, ctypes.c_uint64]
+ self.core.volume_create_host.argtypes = [ctypes.c_void_p, ctypes.c_uint64, ctypes.c_bool, ctypes.c_bool]
  self.core.volume_create_host.restype = ctypes.c_uint64
- self.core.volume_get_buffer_info_host.argtypes = [
+ self.core.volume_get_tiles_host.argtypes = [
  ctypes.c_uint64,
- ctypes.POINTER(ctypes.c_void_p),
- ctypes.POINTER(ctypes.c_uint64),
+ ctypes.c_void_p,
  ]
- self.core.volume_get_tiles_host.argtypes = [
+ self.core.volume_get_voxels_host.argtypes = [
  ctypes.c_uint64,
- ctypes.POINTER(ctypes.c_void_p),
- ctypes.POINTER(ctypes.c_uint64),
+ ctypes.c_void_p,
  ]
  self.core.volume_destroy_host.argtypes = [ctypes.c_uint64]

- self.core.volume_create_device.argtypes = [ctypes.c_void_p, ctypes.c_void_p, ctypes.c_uint64]
+ self.core.volume_create_device.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_uint64,
+ ctypes.c_bool,
+ ctypes.c_bool,
+ ]
  self.core.volume_create_device.restype = ctypes.c_uint64
+ self.core.volume_get_tiles_device.argtypes = [
+ ctypes.c_uint64,
+ ctypes.c_void_p,
+ ]
+ self.core.volume_get_voxels_device.argtypes = [
+ ctypes.c_uint64,
+ ctypes.c_void_p,
+ ]
+ self.core.volume_destroy_device.argtypes = [ctypes.c_uint64]
+
  self.core.volume_f_from_tiles_device.argtypes = [
  ctypes.c_void_p,
  ctypes.c_void_p,
@@ -2632,24 +2715,68 @@ class Runtime:
  ctypes.c_bool,
  ]
  self.core.volume_i_from_tiles_device.restype = ctypes.c_uint64
- self.core.volume_get_buffer_info_device.argtypes = [
- ctypes.c_uint64,
- ctypes.POINTER(ctypes.c_void_p),
- ctypes.POINTER(ctypes.c_uint64),
+ self.core.volume_index_from_tiles_device.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_int,
+ ctypes.c_float,
+ ctypes.c_float,
+ ctypes.c_float,
+ ctypes.c_float,
+ ctypes.c_bool,
  ]
- self.core.volume_get_tiles_device.argtypes = [
+ self.core.volume_index_from_tiles_device.restype = ctypes.c_uint64
+ self.core.volume_from_active_voxels_device.argtypes = [
+ ctypes.c_void_p,
+ ctypes.c_void_p,
+ ctypes.c_int,
+ ctypes.c_float,
+ ctypes.c_float,
+ ctypes.c_float,
+ ctypes.c_float,
+ ctypes.c_bool,
+ ]
+ self.core.volume_from_active_voxels_device.restype = ctypes.c_uint64
+
+ self.core.volume_get_buffer_info.argtypes = [
  ctypes.c_uint64,
  ctypes.POINTER(ctypes.c_void_p),
  ctypes.POINTER(ctypes.c_uint64),
  ]
- self.core.volume_destroy_device.argtypes = [ctypes.c_uint64]
-
  self.core.volume_get_voxel_size.argtypes = [
  ctypes.c_uint64,
  ctypes.POINTER(ctypes.c_float),
  ctypes.POINTER(ctypes.c_float),
  ctypes.POINTER(ctypes.c_float),
  ]
+ self.core.volume_get_tile_and_voxel_count.argtypes = [
+ ctypes.c_uint64,
+ ctypes.POINTER(ctypes.c_uint32),
+ ctypes.POINTER(ctypes.c_uint64),
+ ]
+ self.core.volume_get_grid_info.argtypes = [
+ ctypes.c_uint64,
+ ctypes.POINTER(ctypes.c_uint64),
+ ctypes.POINTER(ctypes.c_uint32),
+ ctypes.POINTER(ctypes.c_uint32),
+ ctypes.c_float * 3,
+ ctypes.c_float * 9,
+ ctypes.c_char * 16,
+ ]
+ self.core.volume_get_grid_info.restype = ctypes.c_char_p
+ self.core.volume_get_blind_data_count.argtypes = [
+ ctypes.c_uint64,
+ ]
+ self.core.volume_get_blind_data_count.restype = ctypes.c_uint64
+ self.core.volume_get_blind_data_info.argtypes = [
+ ctypes.c_uint64,
+ ctypes.c_uint32,
+ ctypes.POINTER(ctypes.c_void_p),
+ ctypes.POINTER(ctypes.c_uint64),
+ ctypes.POINTER(ctypes.c_uint32),
+ ctypes.c_char * 16,
+ ]
+ self.core.volume_get_blind_data_info.restype = ctypes.c_char_p

  bsr_matrix_from_triplets_argtypes = [
  ctypes.c_int,
@@ -3194,12 +3321,10 @@ class Runtime:
  raise RuntimeError(f"CUDA error detected: {err}")


- def assert_initialized():
- assert runtime is not None, "Warp not initialized, call wp.init() before use"
-
-
  # global entry points
  def is_cpu_available():
+ init()
+
  return runtime.llvm


@@ -3221,7 +3346,7 @@ def is_cuda_driver_initialized() -> bool:

  This can be helpful in cases in which ``cuInit()`` was called before a fork.
  """
- assert_initialized()
+ init()

  return runtime.core.cuda_driver_is_initialized()

@@ -3229,7 +3354,7 @@ def is_cuda_driver_initialized() -> bool:
  def get_devices() -> List[Device]:
  """Returns a list of devices supported in this environment."""

- assert_initialized()
+ init()

  devices = []
  if is_cpu_available():
@@ -3242,7 +3367,7 @@ def get_devices() -> List[Device]:
  def get_cuda_device_count() -> int:
  """Returns the number of CUDA devices supported in this environment."""

- assert_initialized()
+ init()

  return len(runtime.cuda_devices)

@@ -3250,7 +3375,7 @@ def get_cuda_device_count() -> int:
  def get_cuda_device(ordinal: Union[int, None] = None) -> Device:
  """Returns the CUDA device with the given ordinal or the current CUDA device if ordinal is None."""

- assert_initialized()
+ init()

  if ordinal is None:
  return runtime.get_current_cuda_device()
@@ -3261,7 +3386,7 @@ def get_cuda_device(ordinal: Union[int, None] = None) -> Device:
  def get_cuda_devices() -> List[Device]:
  """Returns a list of CUDA devices supported in this environment."""

- assert_initialized()
+ init()

  return runtime.cuda_devices

@@ -3269,7 +3394,7 @@ def get_cuda_devices() -> List[Device]:
  def get_preferred_device() -> Device:
  """Returns the preferred compute device, CUDA if available and CPU otherwise."""

- assert_initialized()
+ init()

  if is_cuda_available():
  return runtime.cuda_devices[0]
@@ -3282,7 +3407,7 @@ def get_preferred_device() -> Device:
  def get_device(ident: Devicelike = None) -> Device:
  """Returns the device identified by the argument."""

- assert_initialized()
+ init()

  return runtime.get_device(ident)

@@ -3290,7 +3415,7 @@ def get_device(ident: Devicelike = None) -> Device:
  def set_device(ident: Devicelike):
  """Sets the target device identified by the argument."""

- assert_initialized()
+ init()

  device = runtime.get_device(ident)
  runtime.set_default_device(device)
@@ -3311,7 +3436,7 @@ def map_cuda_device(alias: str, context: ctypes.c_void_p = None) -> Device:
  The associated wp.Device.
  """

- assert_initialized()
+ init()

  return runtime.map_cuda_device(alias, context)

@@ -3319,7 +3444,7 @@ def map_cuda_device(alias: str, context: ctypes.c_void_p = None) -> Device:
  def unmap_cuda_device(alias: str):
  """Remove a CUDA device with the given alias."""

- assert_initialized()
+ init()

  runtime.unmap_cuda_device(alias)

@@ -3327,7 +3452,7 @@ def unmap_cuda_device(alias: str):
  def is_mempool_supported(device: Devicelike):
  """Check if CUDA memory pool allocators are available on the device."""

- assert_initialized()
+ init()

  device = runtime.get_device(device)

@@ -3337,7 +3462,7 @@ def is_mempool_supported(device: Devicelike):
  def is_mempool_enabled(device: Devicelike):
  """Check if CUDA memory pool allocators are enabled on the device."""

- assert_initialized()
+ init()

  device = runtime.get_device(device)

@@ -3357,7 +3482,7 @@ def set_mempool_enabled(device: Devicelike, enable: bool):
  prior to graph capture.
  """

- assert_initialized()
+ init()

  device = runtime.get_device(device)

@@ -3387,7 +3512,7 @@ def set_mempool_release_threshold(device: Devicelike, threshold: Union[int, floa
  For example, 1024**3 means one GiB of memory.
  """

- assert_initialized()
+ init()

  device = runtime.get_device(device)

@@ -3409,7 +3534,7 @@ def set_mempool_release_threshold(device: Devicelike, threshold: Union[int, floa
  def get_mempool_release_threshold(device: Devicelike):
  """Get the CUDA memory pool release threshold on the device."""

- assert_initialized()
+ init()

  device = runtime.get_device(device)

@@ -3432,7 +3557,7 @@ def is_peer_access_supported(target_device: Devicelike, peer_device: Devicelike)
  A Boolean value indicating if this peer access is supported by the system.
  """

- assert_initialized()
+ init()

  target_device = runtime.get_device(target_device)
  peer_device = runtime.get_device(peer_device)
@@ -3453,7 +3578,7 @@ def is_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike):
  A Boolean value indicating if this peer access is currently enabled.
  """

- assert_initialized()
+ init()

  target_device = runtime.get_device(target_device)
  peer_device = runtime.get_device(peer_device)
@@ -3474,7 +3599,7 @@ def set_peer_access_enabled(target_device: Devicelike, peer_device: Devicelike,
  CUDA pooled allocators, use `set_mempool_access_enabled()`.
  """

- assert_initialized()
+ init()

  target_device = runtime.get_device(target_device)
  peer_device = runtime.get_device(peer_device)
@@ -3505,7 +3630,7 @@ def is_mempool_access_supported(target_device: Devicelike, peer_device: Deviceli
  A Boolean value indicating if this memory pool access is supported by the system.
  """

- assert_initialized()
+ init()

  return target_device.is_mempool_supported and is_peer_access_supported(target_device, peer_device)

@@ -3520,7 +3645,7 @@ def is_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelike
  A Boolean value indicating if this peer access is currently enabled.
  """

- assert_initialized()
+ init()

  target_device = runtime.get_device(target_device)
  peer_device = runtime.get_device(peer_device)
@@ -3538,7 +3663,7 @@ def set_mempool_access_enabled(target_device: Devicelike, peer_device: Devicelik
  default CUDA allocators, use `set_peer_access_enabled()`.
  """

- assert_initialized()
+ init()

  target_device = runtime.get_device(target_device)
  peer_device = runtime.get_device(peer_device)
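All of the entry points above now call ``init()`` themselves instead of asserting that the user already did, so device queries work without an explicit ``wp.init()`` (calling it first remains harmless). A minimal sketch using only functions touched by these hunks::

    import warp as wp

    # no wp.init() required: these entry points initialize the runtime on demand
    print(wp.get_devices())
    print(wp.get_cuda_device_count())
    print(wp.get_preferred_device())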
@@ -3640,34 +3765,87 @@ def wait_stream(stream: Stream, event: Event = None):

  class RegisteredGLBuffer:
  """
- Helper object to register a GL buffer with CUDA so that it can be mapped to a Warp array.
+ Helper class to register a GL buffer with CUDA so that it can be mapped to a Warp array.
+
+ Example usage::
+
+ import warp as wp
+ import numpy as np
+ from pyglet.gl import *
+
+ wp.init()
+
+ # create a GL buffer
+ gl_buffer_id = GLuint()
+ glGenBuffers(1, gl_buffer_id)
+
+ # copy some data to the GL buffer
+ glBindBuffer(GL_ARRAY_BUFFER, gl_buffer_id)
+ gl_data = np.arange(1024, dtype=np.float32)
+ glBufferData(GL_ARRAY_BUFFER, gl_data.nbytes, gl_data.ctypes.data, GL_DYNAMIC_DRAW)
+ glBindBuffer(GL_ARRAY_BUFFER, 0)
+
+ # register the GL buffer with CUDA
+ cuda_gl_buffer = wp.RegisteredGLBuffer(gl_buffer_id)
+
+ # map the GL buffer to a Warp array
+ arr = cuda_gl_buffer.map(dtype=wp.float32, shape=(1024,))
+ # launch a Warp kernel to manipulate or read the array
+ wp.launch(my_kernel, dim=1024, inputs=[arr])
+ # unmap the GL buffer
+ cuda_gl_buffer.unmap()
  """

- # Specifies no hints about how this resource will be used.
- # It is therefore assumed that this resource will be
- # read from and written to by CUDA. This is the default value.
  NONE = 0x00
+ """
+ Flag that specifies no hints about how this resource will be used.
+ It is therefore assumed that this resource will be
+ read from and written to by CUDA. This is the default value.
+ """

- # Specifies that CUDA will not write to this resource.
  READ_ONLY = 0x01
+ """
+ Flag that specifies that CUDA will not write to this resource.
+ """

- # Specifies that CUDA will not read from this resource and will write over the
- # entire contents of the resource, so none of the data previously
- # stored in the resource will be preserved.
  WRITE_DISCARD = 0x02
+ """
+ Flag that specifies that CUDA will not read from this resource and will write over the
+ entire contents of the resource, so none of the data previously
+ stored in the resource will be preserved.
+ """

- def __init__(self, gl_buffer_id: int, device: Devicelike = None, flags: int = NONE):
- """Create a new RegisteredGLBuffer object.
+ __fallback_warning_shown = False

+ def __init__(self, gl_buffer_id: int, device: Devicelike = None, flags: int = NONE, fallback_to_copy: bool = True):
+ """
  Args:
  gl_buffer_id: The OpenGL buffer id (GLuint).
  device: The device to register the buffer with. If None, the current device will be used.
- flags: A combination of the flags constants.
+ flags: A combination of the flags constants :attr:`NONE`, :attr:`READ_ONLY`, and :attr:`WRITE_DISCARD`.
+ fallback_to_copy: If True and CUDA/OpenGL interop is not available, fall back to copy operations between the Warp array and the OpenGL buffer. Otherwise, a ``RuntimeError`` will be raised.
+
+ Note:
+
+ The ``fallback_to_copy`` option (to use copy operations if CUDA graphics interop functionality is not available) requires pyglet version 2.0 or later. Install via ``pip install pyglet==2.*``.
  """
  self.gl_buffer_id = gl_buffer_id
  self.device = get_device(device)
  self.context = self.device.context
+ self.flags = flags
+ self.fallback_to_copy = fallback_to_copy
  self.resource = runtime.core.cuda_graphics_register_gl_buffer(self.context, gl_buffer_id, flags)
+ if self.resource is None:
+ if self.fallback_to_copy:
+ self.warp_buffer = None
+ self.warp_buffer_cpu = None
+ if not RegisteredGLBuffer.__fallback_warning_shown:
+ warp.utils.warn(
+ "Could not register GL buffer since CUDA/OpenGL interoperability is not available. Falling back to copy operations between the Warp array and the OpenGL buffer.",
+ )
+ RegisteredGLBuffer.__fallback_warning_shown = True
+ else:
+ raise RuntimeError(f"Failed to register OpenGL buffer {gl_buffer_id} with CUDA")

  def __del__(self):
  if not self.resource:
@@ -3687,18 +3865,48 @@ class RegisteredGLBuffer:
  Returns:
  A Warp array object representing the mapped OpenGL buffer.
  """
- runtime.core.cuda_graphics_map(self.context, self.resource)
- ctypes.POINTER(ctypes.c_uint64), ctypes.POINTER(ctypes.c_size_t)
- ptr = ctypes.c_uint64(0)
- size = ctypes.c_size_t(0)
- runtime.core.cuda_graphics_device_ptr_and_size(
- self.context, self.resource, ctypes.byref(ptr), ctypes.byref(size)
- )
- return warp.array(ptr=ptr.value, dtype=dtype, shape=shape, device=self.device)
+ if self.resource is not None:
+ runtime.core.cuda_graphics_map(self.context, self.resource)
+ ptr = ctypes.c_uint64(0)
+ size = ctypes.c_size_t(0)
+ runtime.core.cuda_graphics_device_ptr_and_size(
+ self.context, self.resource, ctypes.byref(ptr), ctypes.byref(size)
+ )
+ return warp.array(ptr=ptr.value, dtype=dtype, shape=shape, device=self.device)
+ elif self.fallback_to_copy:
+ if self.warp_buffer is None or self.warp_buffer.dtype != dtype or self.warp_buffer.shape != shape:
+ self.warp_buffer = warp.empty(shape, dtype, device=self.device)
+ self.warp_buffer_cpu = warp.empty(shape, dtype, device="cpu", pinned=True)
+
+ if self.flags == self.READ_ONLY or self.flags == self.NONE:
+ # copy from OpenGL buffer to Warp array
+ from pyglet import gl
+
+ gl.glBindBuffer(gl.GL_ARRAY_BUFFER, self.gl_buffer_id)
+ nbytes = self.warp_buffer.size * warp.types.type_size_in_bytes(dtype)
+ gl.glGetBufferSubData(gl.GL_ARRAY_BUFFER, 0, nbytes, self.warp_buffer_cpu.ptr)
+ gl.glBindBuffer(gl.GL_ARRAY_BUFFER, 0)
+ warp.copy(self.warp_buffer, self.warp_buffer_cpu)
+ return self.warp_buffer
+
+ return None

  def unmap(self):
  """Unmap the OpenGL buffer."""
- runtime.core.cuda_graphics_unmap(self.context, self.resource)
+ if self.resource is not None:
+ runtime.core.cuda_graphics_unmap(self.context, self.resource)
+ elif self.fallback_to_copy:
+ if self.warp_buffer is None:
+ raise RuntimeError("RegisteredGLBuffer first has to be mapped")
+
+ if self.flags == self.WRITE_DISCARD or self.flags == self.NONE:
+ # copy from Warp array to OpenGL buffer
+ from pyglet import gl
+
+ gl.glBindBuffer(gl.GL_ARRAY_BUFFER, self.gl_buffer_id)
+ buffer = self.warp_buffer.numpy()
+ gl.glBufferData(gl.GL_ARRAY_BUFFER, buffer.nbytes, buffer.ctypes.data, gl.GL_DYNAMIC_DRAW)
+ gl.glBindBuffer(gl.GL_ARRAY_BUFFER, 0)


  def zeros(
@@ -4253,7 +4461,7 @@ def launch(
  If negative or zero, the maximum hardware value will be used.
  """

- assert_initialized()
+ init()

  # if stream is specified, use the associated device
  if stream is not None:
@@ -4496,7 +4704,7 @@ def force_load(device: Union[Device, str, List[Device], List[str]] = None, modul


  def load_module(
- module: Union[Module, ModuleType, str] = None, device: Union[Device, str] = None, recursive: bool = False
+ module: Union[Module, types.ModuleType, str] = None, device: Union[Device, str] = None, recursive: bool = False
  ):
  """Force user-defined module to be compiled and loaded

@@ -4514,7 +4722,7 @@ def load_module(
  module_name = module.__name__
  elif isinstance(module, Module):
  module_name = module.name
- elif isinstance(module, ModuleType):
+ elif isinstance(module, types.ModuleType):
  module_name = module.__name__
  elif isinstance(module, str):
  module_name = module
@@ -4863,13 +5071,20 @@ def copy(

  # copy gradient, if needed
  if hasattr(src, "grad") and src.grad is not None and hasattr(dest, "grad") and dest.grad is not None:
- copy(dest.grad, src.grad, stream=stream)
+ copy(dest.grad, src.grad, dest_offset=dest_offset, src_offset=src_offset, count=count, stream=stream)

  if runtime.tape:
- runtime.tape.record_func(backward=lambda: adj_copy(dest.grad, src.grad, stream=stream), arrays=[dest, src])
+ runtime.tape.record_func(
+ backward=lambda: adj_copy(
+ dest.grad, src.grad, dest_offset=dest_offset, src_offset=src_offset, count=count, stream=stream
+ ),
+ arrays=[dest, src],
+ )


- def adj_copy(adj_dest: warp.array, adj_src: warp.array, stream: Stream = None):
+ def adj_copy(
+ adj_dest: warp.array, adj_src: warp.array, dest_offset: int, src_offset: int, count: int, stream: Stream = None
+ ):
  """Copy adjoint operation for wp.copy() calls on the tape.

  Args:
@@ -4877,7 +5092,7 @@ def adj_copy(adj_dest: warp.array, adj_src: warp.array, stream: Stream = None):
  adj_src: Source array adjoint
  stream: The stream on which the copy was performed in the forward pass
  """
- copy(adj_src, adj_dest, stream=stream)
+ copy(adj_src, adj_dest, dest_offset=dest_offset, src_offset=src_offset, count=count, stream=stream)


  def type_str(t):
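With this fix, a partial copy recorded on a tape propagates the same ``dest_offset``/``src_offset``/``count`` to the gradient copy and to the adjoint replay. A minimal sketch (the ``grads=`` form of ``Tape.backward()`` is assumed from the current API)::

    import warp as wp

    wp.init()

    src = wp.zeros(8, dtype=wp.float32, requires_grad=True)
    dest = wp.zeros(8, dtype=wp.float32, requires_grad=True)

    tape = wp.Tape()
    with tape:
        # copy only the first 4 elements of src into dest
        wp.copy(dest, src, dest_offset=0, src_offset=0, count=4)

    # seed the adjoint of dest and replay the tape; only the first 4 entries
    # of src.grad receive a gradient
    tape.backward(grads={dest: wp.full(8, 1.0, dtype=wp.float32)})
    print(src.grad.numpy())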