warp-lang 0.11.0__py3-none-manylinux2014_x86_64.whl → 1.0.0__py3-none-manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the registry's page for this release for more details.

Files changed (170)
  1. warp/__init__.py +8 -0
  2. warp/bin/warp-clang.so +0 -0
  3. warp/bin/warp.so +0 -0
  4. warp/build.py +7 -6
  5. warp/build_dll.py +70 -79
  6. warp/builtins.py +10 -6
  7. warp/codegen.py +51 -19
  8. warp/config.py +7 -8
  9. warp/constants.py +3 -0
  10. warp/context.py +948 -245
  11. warp/dlpack.py +198 -113
  12. warp/examples/assets/bunny.usd +0 -0
  13. warp/examples/assets/cartpole.urdf +110 -0
  14. warp/examples/assets/crazyflie.usd +0 -0
  15. warp/examples/assets/cube.usda +42 -0
  16. warp/examples/assets/nv_ant.xml +92 -0
  17. warp/examples/assets/nv_humanoid.xml +183 -0
  18. warp/examples/assets/quadruped.urdf +268 -0
  19. warp/examples/assets/rocks.nvdb +0 -0
  20. warp/examples/assets/rocks.usd +0 -0
  21. warp/examples/assets/sphere.usda +56 -0
  22. warp/examples/assets/torus.usda +105 -0
  23. warp/examples/benchmarks/benchmark_api.py +383 -0
  24. warp/examples/benchmarks/benchmark_cloth.py +279 -0
  25. warp/examples/benchmarks/benchmark_cloth_cupy.py +88 -0
  26. warp/examples/benchmarks/benchmark_cloth_jax.py +100 -0
  27. warp/examples/benchmarks/benchmark_cloth_numba.py +142 -0
  28. warp/examples/benchmarks/benchmark_cloth_numpy.py +77 -0
  29. warp/examples/benchmarks/benchmark_cloth_pytorch.py +86 -0
  30. warp/examples/benchmarks/benchmark_cloth_taichi.py +112 -0
  31. warp/examples/benchmarks/benchmark_cloth_warp.py +146 -0
  32. warp/examples/benchmarks/benchmark_launches.py +295 -0
  33. warp/examples/core/example_dem.py +221 -0
  34. warp/examples/core/example_fluid.py +267 -0
  35. warp/examples/core/example_graph_capture.py +129 -0
  36. warp/examples/core/example_marching_cubes.py +177 -0
  37. warp/examples/core/example_mesh.py +154 -0
  38. warp/examples/core/example_mesh_intersect.py +193 -0
  39. warp/examples/core/example_nvdb.py +169 -0
  40. warp/examples/core/example_raycast.py +89 -0
  41. warp/examples/core/example_raymarch.py +178 -0
  42. warp/examples/core/example_render_opengl.py +141 -0
  43. warp/examples/core/example_sph.py +389 -0
  44. warp/examples/core/example_torch.py +181 -0
  45. warp/examples/core/example_wave.py +249 -0
  46. warp/examples/fem/bsr_utils.py +380 -0
  47. warp/examples/fem/example_apic_fluid.py +391 -0
  48. warp/examples/fem/example_convection_diffusion.py +168 -0
  49. warp/examples/fem/example_convection_diffusion_dg.py +209 -0
  50. warp/examples/fem/example_convection_diffusion_dg0.py +194 -0
  51. warp/examples/fem/example_deformed_geometry.py +159 -0
  52. warp/examples/fem/example_diffusion.py +173 -0
  53. warp/examples/fem/example_diffusion_3d.py +152 -0
  54. warp/examples/fem/example_diffusion_mgpu.py +214 -0
  55. warp/examples/fem/example_mixed_elasticity.py +222 -0
  56. warp/examples/fem/example_navier_stokes.py +243 -0
  57. warp/examples/fem/example_stokes.py +192 -0
  58. warp/examples/fem/example_stokes_transfer.py +249 -0
  59. warp/examples/fem/mesh_utils.py +109 -0
  60. warp/examples/fem/plot_utils.py +287 -0
  61. warp/examples/optim/example_bounce.py +248 -0
  62. warp/examples/optim/example_cloth_throw.py +210 -0
  63. warp/examples/optim/example_diffray.py +535 -0
  64. warp/examples/optim/example_drone.py +850 -0
  65. warp/examples/optim/example_inverse_kinematics.py +169 -0
  66. warp/examples/optim/example_inverse_kinematics_torch.py +170 -0
  67. warp/examples/optim/example_spring_cage.py +234 -0
  68. warp/examples/optim/example_trajectory.py +201 -0
  69. warp/examples/sim/example_cartpole.py +128 -0
  70. warp/examples/sim/example_cloth.py +184 -0
  71. warp/examples/sim/example_granular.py +113 -0
  72. warp/examples/sim/example_granular_collision_sdf.py +185 -0
  73. warp/examples/sim/example_jacobian_ik.py +213 -0
  74. warp/examples/sim/example_particle_chain.py +106 -0
  75. warp/examples/sim/example_quadruped.py +179 -0
  76. warp/examples/sim/example_rigid_chain.py +191 -0
  77. warp/examples/sim/example_rigid_contact.py +176 -0
  78. warp/examples/sim/example_rigid_force.py +126 -0
  79. warp/examples/sim/example_rigid_gyroscopic.py +97 -0
  80. warp/examples/sim/example_rigid_soft_contact.py +124 -0
  81. warp/examples/sim/example_soft_body.py +178 -0
  82. warp/fabric.py +29 -20
  83. warp/fem/cache.py +0 -1
  84. warp/fem/dirichlet.py +0 -2
  85. warp/fem/integrate.py +0 -1
  86. warp/jax.py +45 -0
  87. warp/jax_experimental.py +339 -0
  88. warp/native/builtin.h +12 -0
  89. warp/native/bvh.cu +18 -18
  90. warp/native/clang/clang.cpp +8 -3
  91. warp/native/cuda_util.cpp +94 -5
  92. warp/native/cuda_util.h +35 -6
  93. warp/native/cutlass_gemm.cpp +1 -1
  94. warp/native/cutlass_gemm.cu +4 -1
  95. warp/native/error.cpp +66 -0
  96. warp/native/error.h +27 -0
  97. warp/native/mesh.cu +2 -2
  98. warp/native/reduce.cu +4 -4
  99. warp/native/runlength_encode.cu +2 -2
  100. warp/native/scan.cu +2 -2
  101. warp/native/sparse.cu +0 -1
  102. warp/native/temp_buffer.h +2 -2
  103. warp/native/warp.cpp +95 -60
  104. warp/native/warp.cu +1053 -218
  105. warp/native/warp.h +49 -32
  106. warp/optim/linear.py +33 -16
  107. warp/render/render_opengl.py +202 -101
  108. warp/render/render_usd.py +82 -40
  109. warp/sim/__init__.py +13 -4
  110. warp/sim/articulation.py +4 -5
  111. warp/sim/collide.py +320 -175
  112. warp/sim/import_mjcf.py +25 -30
  113. warp/sim/import_urdf.py +94 -63
  114. warp/sim/import_usd.py +51 -36
  115. warp/sim/inertia.py +3 -2
  116. warp/sim/integrator.py +233 -0
  117. warp/sim/integrator_euler.py +447 -469
  118. warp/sim/integrator_featherstone.py +1991 -0
  119. warp/sim/integrator_xpbd.py +1420 -640
  120. warp/sim/model.py +765 -487
  121. warp/sim/particles.py +2 -1
  122. warp/sim/render.py +35 -13
  123. warp/sim/utils.py +222 -11
  124. warp/stubs.py +8 -0
  125. warp/tape.py +16 -1
  126. warp/tests/aux_test_grad_customs.py +23 -0
  127. warp/tests/test_array.py +190 -1
  128. warp/tests/test_async.py +656 -0
  129. warp/tests/test_bool.py +50 -0
  130. warp/tests/test_dlpack.py +164 -11
  131. warp/tests/test_examples.py +166 -74
  132. warp/tests/test_fem.py +8 -1
  133. warp/tests/test_generics.py +15 -5
  134. warp/tests/test_grad.py +1 -1
  135. warp/tests/test_grad_customs.py +172 -12
  136. warp/tests/test_jax.py +254 -0
  137. warp/tests/test_large.py +29 -6
  138. warp/tests/test_launch.py +25 -0
  139. warp/tests/test_linear_solvers.py +20 -3
  140. warp/tests/test_matmul.py +61 -16
  141. warp/tests/test_matmul_lite.py +13 -13
  142. warp/tests/test_mempool.py +186 -0
  143. warp/tests/test_multigpu.py +3 -0
  144. warp/tests/test_options.py +16 -2
  145. warp/tests/test_peer.py +137 -0
  146. warp/tests/test_print.py +3 -1
  147. warp/tests/test_quat.py +23 -0
  148. warp/tests/test_sim_kinematics.py +97 -0
  149. warp/tests/test_snippet.py +126 -3
  150. warp/tests/test_streams.py +108 -79
  151. warp/tests/test_torch.py +16 -8
  152. warp/tests/test_utils.py +32 -27
  153. warp/tests/test_verify_fp.py +65 -0
  154. warp/tests/test_volume.py +1 -1
  155. warp/tests/unittest_serial.py +2 -0
  156. warp/tests/unittest_suites.py +12 -0
  157. warp/tests/unittest_utils.py +14 -7
  158. warp/thirdparty/unittest_parallel.py +15 -3
  159. warp/torch.py +10 -8
  160. warp/types.py +363 -246
  161. warp/utils.py +143 -19
  162. warp_lang-1.0.0.dist-info/LICENSE.md +126 -0
  163. warp_lang-1.0.0.dist-info/METADATA +394 -0
  164. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/RECORD +167 -86
  165. warp/sim/optimizer.py +0 -138
  166. warp_lang-0.11.0.dist-info/LICENSE.md +0 -36
  167. warp_lang-0.11.0.dist-info/METADATA +0 -238
  168. warp/tests/{walkthough_debug.py → walkthrough_debug.py} +0 -0
  169. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/WHEEL +0 -0
  170. {warp_lang-0.11.0.dist-info → warp_lang-1.0.0.dist-info}/top_level.txt +0 -0
warp/__init__.py CHANGED
@@ -45,6 +45,8 @@ from warp.context import get_device, set_device, synchronize_device
45
45
  from warp.context import (
46
46
  zeros,
47
47
  zeros_like,
48
+ ones,
49
+ ones_like,
48
50
  full,
49
51
  full_like,
50
52
  clone,
@@ -63,9 +65,15 @@ from warp.context import Kernel, Function, Launch
63
65
  from warp.context import Stream, get_stream, set_stream, synchronize_stream
64
66
  from warp.context import Event, record_event, wait_event, wait_stream
65
67
  from warp.context import RegisteredGLBuffer
68
+ from warp.context import is_mempool_supported, is_mempool_enabled, set_mempool_enabled
69
+ from warp.context import set_mempool_release_threshold, get_mempool_release_threshold
70
+ from warp.context import is_mempool_access_supported, is_mempool_access_enabled, set_mempool_access_enabled
71
+ from warp.context import is_peer_access_supported, is_peer_access_enabled, set_peer_access_enabled
66
72
 
67
73
  from warp.tape import Tape
68
74
  from warp.utils import ScopedTimer, ScopedDevice, ScopedStream
75
+ from warp.utils import ScopedMempool, ScopedMempoolAccess, ScopedPeerAccess
76
+ from warp.utils import ScopedCapture
69
77
  from warp.utils import transform_expand, quat_between_vectors
70
78
 
71
79
  from warp.torch import from_torch, to_torch
warp/bin/warp-clang.so CHANGED
Binary file
warp/bin/warp.so CHANGED
Binary file
warp/build.py CHANGED
@@ -45,7 +45,7 @@ def build_cpu(obj_path, cpp_path, mode="release", verify_fp=False, fast_math=Fal
45
45
  inc_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "native").encode("utf-8")
46
46
  obj_path = obj_path.encode("utf-8")
47
47
 
48
- err = warp.context.runtime.llvm.compile_cpp(src, cpp_path, inc_path, obj_path, mode == "debug")
48
+ err = warp.context.runtime.llvm.compile_cpp(src, cpp_path, inc_path, obj_path, mode == "debug", verify_fp)
49
49
  if err != 0:
50
50
  raise Exception(f"CPU kernel build failed with error code {err}")
51
51
 
@@ -66,9 +66,7 @@ def init_kernel_cache(path=None):
66
66
  if path is not None:
67
67
  cache_root_dir = os.path.realpath(path)
68
68
  else:
69
- cache_root_dir = appdirs.user_cache_dir(
70
- appname="warp", appauthor="NVIDIA Corporation", version=warp.config.version
71
- )
69
+ cache_root_dir = appdirs.user_cache_dir(appname="warp", appauthor="NVIDIA", version=warp.config.version)
72
70
 
73
71
  cache_bin_dir = os.path.join(cache_root_dir, "bin")
74
72
  cache_gen_dir = os.path.join(cache_root_dir, "gen")
@@ -95,15 +93,18 @@ def init_kernel_cache(path=None):
95
93
  def clear_kernel_cache():
96
94
  """Clear the kernel cache."""
97
95
 
96
+ is_intialized = kernel_bin_dir is not None and kernel_gen_dir is not None
97
+ assert is_intialized, "The kernel cache directory is not configured; wp.init() has not been called yet or failed."
98
+
98
99
  import glob
99
100
 
100
101
  paths = []
101
102
 
102
- if kernel_bin_dir is not None and os.path.isdir(kernel_bin_dir):
103
+ if os.path.isdir(kernel_bin_dir):
103
104
  pattern = os.path.join(kernel_bin_dir, "wp_*")
104
105
  paths += glob.glob(pattern)
105
106
 
106
- if kernel_gen_dir is not None and os.path.isdir(kernel_gen_dir):
107
+ if os.path.isdir(kernel_gen_dir):
107
108
  pattern = os.path.join(kernel_gen_dir, "wp_*")
108
109
  paths += glob.glob(pattern)
109
110
 
warp/build_dll.py CHANGED
@@ -10,9 +10,10 @@ import os
10
10
  import subprocess
11
11
  import platform
12
12
 
13
- import warp.config
14
13
  from warp.utils import ScopedTimer
15
14
 
15
+ verbose_cmd = True # print command lines before executing them
16
+
16
17
 
17
18
  # returns a canonical machine architecture string
18
19
  # - "x86_64" for x86-64, aka. AMD64, aka. x64
@@ -26,8 +27,8 @@ def machine_architecture() -> str:
26
27
  raise RuntimeError(f"Unrecognized machine architecture {machine}")
27
28
 
28
29
 
29
- def run_cmd(cmd, capture=False):
30
- if warp.config.verbose:
30
+ def run_cmd(cmd):
31
+ if verbose_cmd:
31
32
  print(cmd)
32
33
 
33
34
  try:
@@ -41,8 +42,8 @@ def run_cmd(cmd, capture=False):
41
42
 
42
43
 
43
44
  # cut-down version of vcvars64.bat that allows using
44
- # custom toolchain locations
45
- def set_msvc_compiler(msvc_path, sdk_path):
45
+ # custom toolchain locations, returns the compiler program path
46
+ def set_msvc_env(msvc_path, sdk_path):
46
47
  if "INCLUDE" not in os.environ:
47
48
  os.environ["INCLUDE"] = ""
48
49
 
@@ -65,58 +66,51 @@ def set_msvc_compiler(msvc_path, sdk_path):
65
66
  os.environ["PATH"] += os.pathsep + os.path.join(msvc_path, "bin/HostX64/x64")
66
67
  os.environ["PATH"] += os.pathsep + os.path.join(sdk_path, "bin/x64")
67
68
 
68
- warp.config.host_compiler = os.path.join(msvc_path, "bin", "HostX64", "x64", "cl.exe")
69
+ return os.path.join(msvc_path, "bin", "HostX64", "x64", "cl.exe")
69
70
 
70
71
 
71
72
  def find_host_compiler():
72
73
  if os.name == "nt":
73
- try:
74
- # try and find an installed host compiler (msvc)
75
- # runs vcvars and copies back the build environment
76
-
77
- vswhere_path = r"%ProgramFiles(x86)%/Microsoft Visual Studio/Installer/vswhere.exe"
78
- vswhere_path = os.path.expandvars(vswhere_path)
79
- if not os.path.exists(vswhere_path):
80
- return ""
81
-
82
- vs_path = run_cmd(f'"{vswhere_path}" -latest -property installationPath').decode().rstrip()
83
- vsvars_path = os.path.join(vs_path, "VC\\Auxiliary\\Build\\vcvars64.bat")
84
-
85
- output = run_cmd(f'"{vsvars_path}" && set').decode()
86
-
87
- for line in output.splitlines():
88
- pair = line.split("=", 1)
89
- if len(pair) >= 2:
90
- os.environ[pair[0]] = pair[1]
91
-
92
- cl_path = run_cmd("where cl.exe").decode("utf-8").rstrip()
93
- cl_version = os.environ["VCToolsVersion"].split(".")
94
-
95
- # ensure at least VS2019 version, see list of MSVC versions here https://en.wikipedia.org/wiki/Microsoft_Visual_C%2B%2B
96
- cl_required_major = 14
97
- cl_required_minor = 29
98
-
99
- if (
100
- (int(cl_version[0]) < cl_required_major)
101
- or (int(cl_version[0]) == cl_required_major)
102
- and int(cl_version[1]) < cl_required_minor
103
- ):
104
- print(
105
- f"Warp: MSVC found but compiler version too old, found {cl_version[0]}.{cl_version[1]}, but must be {cl_required_major}.{cl_required_minor} or higher, kernel host compilation will be disabled."
106
- )
107
- return ""
74
+ # try and find an installed host compiler (msvc)
75
+ # runs vcvars and copies back the build environment
76
+
77
+ vswhere_path = r"%ProgramFiles(x86)%/Microsoft Visual Studio/Installer/vswhere.exe"
78
+ vswhere_path = os.path.expandvars(vswhere_path)
79
+ if not os.path.exists(vswhere_path):
80
+ return ""
81
+
82
+ vs_path = run_cmd(f'"{vswhere_path}" -latest -property installationPath').decode().rstrip()
83
+ vsvars_path = os.path.join(vs_path, "VC\\Auxiliary\\Build\\vcvars64.bat")
84
+
85
+ output = run_cmd(f'"{vsvars_path}" && set').decode()
108
86
 
109
- return cl_path
87
+ for line in output.splitlines():
88
+ pair = line.split("=", 1)
89
+ if len(pair) >= 2:
90
+ os.environ[pair[0]] = pair[1]
110
91
 
111
- except Exception as e:
112
- # couldn't find host compiler
92
+ cl_path = run_cmd("where cl.exe").decode("utf-8").rstrip()
93
+ cl_version = os.environ["VCToolsVersion"].split(".")
94
+
95
+ # ensure at least VS2019 version, see list of MSVC versions here https://en.wikipedia.org/wiki/Microsoft_Visual_C%2B%2B
96
+ cl_required_major = 14
97
+ cl_required_minor = 29
98
+
99
+ if (
100
+ (int(cl_version[0]) < cl_required_major)
101
+ or (int(cl_version[0]) == cl_required_major)
102
+ and int(cl_version[1]) < cl_required_minor
103
+ ):
104
+ print(
105
+ f"Warp: MSVC found but compiler version too old, found {cl_version[0]}.{cl_version[1]}, but must be {cl_required_major}.{cl_required_minor} or higher, kernel host compilation will be disabled."
106
+ )
113
107
  return ""
108
+
109
+ return cl_path
110
+
114
111
  else:
115
112
  # try and find g++
116
- try:
117
- return run_cmd("which g++").decode()
118
- except:
119
- return ""
113
+ return run_cmd("which g++").decode()
120
114
 
121
115
 
122
116
  def get_cuda_toolkit_version(cuda_home):
@@ -141,11 +135,12 @@ def quote(path):
141
135
  return '"' + path + '"'
142
136
 
143
137
 
144
- def build_dll_for_arch(dll_path, cpp_paths, cu_path, libs, mode, arch, verify_fp=False, fast_math=False, quick=False):
145
- cuda_home = warp.config.cuda_path
138
+ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, arch, mode=None):
139
+ mode = args.mode if (mode is None) else mode
140
+ cuda_home = args.cuda_path
146
141
  cuda_cmd = None
147
142
 
148
- if quick:
143
+ if args.quick:
149
144
  cutlass_includes = ""
150
145
  cutlass_enabled = "WP_ENABLE_CUTLASS=0"
151
146
  else:
@@ -153,7 +148,7 @@ def build_dll_for_arch(dll_path, cpp_paths, cu_path, libs, mode, arch, verify_fp
153
148
  cutlass_includes = f'-I"{cutlass_home}/include" -I"{cutlass_home}/tools/util/include"'
154
149
  cutlass_enabled = "WP_ENABLE_CUTLASS=1"
155
150
 
156
- if quick or cu_path is None:
151
+ if args.quick or cu_path is None:
157
152
  cuda_compat_enabled = "WP_ENABLE_CUDA_COMPATIBILITY=0"
158
153
  else:
159
154
  cuda_compat_enabled = "WP_ENABLE_CUDA_COMPATIBILITY=1"
@@ -165,7 +160,7 @@ def build_dll_for_arch(dll_path, cpp_paths, cu_path, libs, mode, arch, verify_fp
165
160
  nanovdb_home = warp_home_path.parent / "_build/host-deps/nanovdb/include"
166
161
 
167
162
  # output stale, rebuild
168
- if warp.config.verbose:
163
+ if args.verbose:
169
164
  print(f"Building {dll_path}")
170
165
 
171
166
  native_dir = os.path.join(warp_home, "native")
@@ -181,7 +176,7 @@ def build_dll_for_arch(dll_path, cpp_paths, cu_path, libs, mode, arch, verify_fp
181
176
 
182
177
  gencode_opts = []
183
178
 
184
- if quick:
179
+ if args.quick:
185
180
  # minimum supported architectures (PTX)
186
181
  gencode_opts += ["-gencode=arch=compute_52,code=compute_52", "-gencode=arch=compute_75,code=compute_75"]
187
182
  else:
@@ -224,15 +219,15 @@ def build_dll_for_arch(dll_path, cpp_paths, cu_path, libs, mode, arch, verify_fp
224
219
  "--extended-lambda",
225
220
  ]
226
221
 
227
- if fast_math:
222
+ if args.fast_math:
228
223
  nvcc_opts.append("--use_fast_math")
229
224
 
230
225
  # is the library being built with CUDA enabled?
231
226
  cuda_enabled = "WP_ENABLE_CUDA=1" if (cu_path is not None) else "WP_ENABLE_CUDA=0"
232
227
 
233
228
  if os.name == "nt":
234
- if warp.config.host_compiler:
235
- host_linker = os.path.join(os.path.dirname(warp.config.host_compiler), "link.exe")
229
+ if args.host_compiler:
230
+ host_linker = os.path.join(os.path.dirname(args.host_compiler), "link.exe")
236
231
  else:
237
232
  raise RuntimeError("Warp build error: No host compiler was found")
238
233
 
@@ -251,27 +246,27 @@ def build_dll_for_arch(dll_path, cpp_paths, cu_path, libs, mode, arch, verify_fp
251
246
  iter_dbg = "_ITERATOR_DEBUG_LEVEL=2"
252
247
  debug = "_DEBUG"
253
248
 
254
- if warp.config.mode == "debug":
249
+ if args.mode == "debug":
255
250
  cpp_flags = f'/nologo {runtime} /Zi /Od /D "{debug}" /D WP_ENABLE_DEBUG=1 /D "{cuda_enabled}" /D "{cutlass_enabled}" /D "{cuda_compat_enabled}" /D "{iter_dbg}" /I"{native_dir}" /I"{nanovdb_home}" {includes}'
256
251
  linkopts = ["/DLL", "/DEBUG"]
257
- elif warp.config.mode == "release":
252
+ elif args.mode == "release":
258
253
  cpp_flags = f'/nologo {runtime} /Ox /D "{debug}" /D WP_ENABLE_DEBUG=0 /D "{cuda_enabled}" /D "{cutlass_enabled}" /D "{cuda_compat_enabled}" /D "{iter_dbg}" /I"{native_dir}" /I"{nanovdb_home}" {includes}'
259
254
  linkopts = ["/DLL"]
260
255
  else:
261
- raise RuntimeError(f"Unrecognized build configuration (debug, release), got: {mode}")
256
+ raise RuntimeError(f"Unrecognized build configuration (debug, release), got: {args.mode}")
262
257
 
263
- if verify_fp:
258
+ if args.verify_fp:
264
259
  cpp_flags += ' /D "WP_VERIFY_FP"'
265
260
 
266
- if fast_math:
261
+ if args.fast_math:
267
262
  cpp_flags += " /fp:fast"
268
263
 
269
- with ScopedTimer("build", active=warp.config.verbose):
264
+ with ScopedTimer("build", active=args.verbose):
270
265
  for cpp_path in cpp_paths:
271
266
  cpp_out = cpp_path + ".obj"
272
267
  linkopts.append(quote(cpp_out))
273
268
 
274
- cpp_cmd = f'"{warp.config.host_compiler}" {cpp_flags} -c "{cpp_path}" /Fo"{cpp_out}"'
269
+ cpp_cmd = f'"{args.host_compiler}" {cpp_flags} -c "{cpp_path}" /Fo"{cpp_out}"'
275
270
  run_cmd(cpp_cmd)
276
271
 
277
272
  if cu_path:
@@ -283,14 +278,14 @@ def build_dll_for_arch(dll_path, cpp_paths, cu_path, libs, mode, arch, verify_fp
283
278
  elif mode == "release":
284
279
  cuda_cmd = f'"{cuda_home}/bin/nvcc" -O3 {" ".join(nvcc_opts)} -I"{native_dir}" -I"{nanovdb_home}" -DNDEBUG -DWP_ENABLE_CUDA=1 -D{cutlass_enabled} {cutlass_includes} -o "{cu_out}" -c "{cu_path}"'
285
280
 
286
- with ScopedTimer("build_cuda", active=warp.config.verbose):
281
+ with ScopedTimer("build_cuda", active=args.verbose):
287
282
  run_cmd(cuda_cmd)
288
283
  linkopts.append(quote(cu_out))
289
284
  linkopts.append(
290
285
  f'cudart_static.lib nvrtc_static.lib nvrtc-builtins_static.lib nvptxcompiler_static.lib ws2_32.lib user32.lib /LIBPATH:"{cuda_home}/lib/x64"'
291
286
  )
292
287
 
293
- with ScopedTimer("link", active=warp.config.verbose):
288
+ with ScopedTimer("link", active=args.verbose):
294
289
  link_cmd = f'"{host_linker}" {" ".join(linkopts + libs)} /out:"{dll_path}"'
295
290
  run_cmd(link_cmd)
296
291
 
@@ -311,15 +306,15 @@ def build_dll_for_arch(dll_path, cpp_paths, cu_path, libs, mode, arch, verify_fp
311
306
  if mode == "release":
312
307
  cpp_flags = f'{target} -O3 -DNDEBUG -DWP_ENABLE_DEBUG=0 -D{cuda_enabled} -D{cutlass_enabled} -D{cuda_compat_enabled} -fPIC -fvisibility=hidden --std=c++14 -D_GLIBCXX_USE_CXX11_ABI=0 -I"{native_dir}" {includes}'
313
308
 
314
- if verify_fp:
309
+ if args.verify_fp:
315
310
  cpp_flags += " -DWP_VERIFY_FP"
316
311
 
317
- if fast_math:
312
+ if args.fast_math:
318
313
  cpp_flags += " -ffast-math"
319
314
 
320
315
  ld_inputs = []
321
316
 
322
- with ScopedTimer("build", active=warp.config.verbose):
317
+ with ScopedTimer("build", active=args.verbose):
323
318
  for cpp_path in cpp_paths:
324
319
  cpp_out = cpp_path + ".o"
325
320
  ld_inputs.append(quote(cpp_out))
@@ -336,7 +331,7 @@ def build_dll_for_arch(dll_path, cpp_paths, cu_path, libs, mode, arch, verify_fp
336
331
  elif mode == "release":
337
332
  cuda_cmd = f'"{cuda_home}/bin/nvcc" -O3 --compiler-options -fPIC,-fvisibility=hidden {" ".join(nvcc_opts)} -DNDEBUG -DWP_ENABLE_CUDA=1 -I"{native_dir}" -D{cutlass_enabled} {cutlass_includes} -o "{cu_out}" -c "{cu_path}"'
338
333
 
339
- with ScopedTimer("build_cuda", active=warp.config.verbose):
334
+ with ScopedTimer("build_cuda", active=args.verbose):
340
335
  run_cmd(cuda_cmd)
341
336
 
342
337
  ld_inputs.append(quote(cu_out))
@@ -351,7 +346,7 @@ def build_dll_for_arch(dll_path, cpp_paths, cu_path, libs, mode, arch, verify_fp
351
346
  opt_no_undefined = "-Wl,--no-undefined"
352
347
  opt_exclude_libs = "-Wl,--exclude-libs,ALL"
353
348
 
354
- with ScopedTimer("link", active=warp.config.verbose):
349
+ with ScopedTimer("link", active=args.verbose):
355
350
  origin = "@loader_path" if (sys.platform == "darwin") else "$ORIGIN"
356
351
  link_cmd = f"g++ {target} -shared -Wl,-rpath,'{origin}' {opt_no_undefined} {opt_exclude_libs} -o '{dll_path}' {' '.join(ld_inputs + libs)}"
357
352
  run_cmd(link_cmd)
@@ -366,19 +361,15 @@ def build_dll_for_arch(dll_path, cpp_paths, cu_path, libs, mode, arch, verify_fp
366
361
  )
367
362
 
368
363
 
369
- def build_dll(dll_path, cpp_paths, cu_path, libs=[], mode="release", verify_fp=False, fast_math=False, quick=False):
364
+ def build_dll(args, dll_path, cpp_paths, cu_path, libs=[]):
370
365
  if sys.platform == "darwin":
371
366
  # create a universal binary by combining x86-64 and AArch64 builds
372
- build_dll_for_arch(dll_path + "-x86_64", cpp_paths, cu_path, libs, mode, "x86_64", verify_fp, fast_math, quick)
373
- build_dll_for_arch(
374
- dll_path + "-aarch64", cpp_paths, cu_path, libs, mode, "aarch64", verify_fp, fast_math, quick
375
- )
367
+ build_dll_for_arch(args, dll_path + "-x86_64", cpp_paths, cu_path, libs, "x86_64")
368
+ build_dll_for_arch(args, dll_path + "-aarch64", cpp_paths, cu_path, libs, "aarch64")
376
369
 
377
370
  run_cmd(f"lipo -create -output {dll_path} {dll_path}-x86_64 {dll_path}-aarch64")
378
371
  os.remove(f"{dll_path}-x86_64")
379
372
  os.remove(f"{dll_path}-aarch64")
380
373
 
381
374
  else:
382
- build_dll_for_arch(
383
- dll_path, cpp_paths, cu_path, libs, mode, machine_architecture(), verify_fp, fast_math, quick
384
- )
375
+ build_dll_for_arch(args, dll_path, cpp_paths, cu_path, libs, machine_architecture())
warp/builtins.py CHANGED
@@ -612,16 +612,20 @@ add_builtin(
612
612
 
613
613
 
614
614
  # scalar type constructors between all storage / compute types
615
- scalar_types_all = [*scalar_types, int, float]
615
+ scalar_types_all = [*scalar_types, bool, int, float]
616
616
  for t in scalar_types_all:
617
617
  for u in scalar_types_all:
618
618
  add_builtin(
619
- t.__name__, input_types={"u": u}, value_type=t, doc="", hidden=True, group="Scalar Math", export=False
619
+ t.__name__,
620
+ input_types={"u": u},
621
+ value_type=t,
622
+ doc="",
623
+ hidden=True,
624
+ group="Scalar Math",
625
+ export=False,
626
+ namespace="wp::" if t is not bool else "",
620
627
  )
621
628
 
622
- for u in [bool, builtins.bool]:
623
- add_builtin(bool.__name__, input_types={"u": u}, value_type=bool, doc="", hidden=True, export=False, namespace="")
624
-
625
629
 
626
630
  def vector_constructor_func(arg_types, kwds, templates):
627
631
  if arg_types is None:
@@ -2852,7 +2856,7 @@ add_builtin(
2852
2856
  skip_replay=True,
2853
2857
  )
2854
2858
 
2855
- for t in scalar_types + vector_types + [builtins.bool]:
2859
+ for t in scalar_types + vector_types + [bool, builtins.bool]:
2856
2860
  if "vec" in t.__name__ or "mat" in t.__name__:
2857
2861
  continue
2858
2862
  add_builtin(
warp/codegen.py CHANGED
@@ -418,7 +418,10 @@ def compute_type_str(base_name, template_params):
418
418
  if isinstance(p, int):
419
419
  return str(p)
420
420
  elif hasattr(p, "_type_"):
421
- return f"wp::{p.__name__}"
421
+ if p.__name__ == "bool":
422
+ return "bool"
423
+ else:
424
+ return f"wp::{p.__name__}"
422
425
  return p.__name__
423
426
 
424
427
  return f"{base_name}<{','.join(map(param2str, template_params))}>"
@@ -595,12 +598,17 @@ class Adjoint:
595
598
  adj.skip_build = False
596
599
 
597
600
  # generate function ssa form and adjoint
598
- def build(adj, builder):
601
+ def build(adj, builder, default_builder_options={}):
599
602
  if adj.skip_build:
600
603
  return
601
604
 
602
605
  adj.builder = builder
603
606
 
607
+ if adj.builder:
608
+ adj.builder_options = adj.builder.options
609
+ else:
610
+ adj.builder_options = default_builder_options
611
+
604
612
  adj.symbols = {} # map from symbols to adjoint variables
605
613
  adj.variables = [] # list of local variables (in order)
606
614
 
@@ -911,8 +919,16 @@ class Adjoint:
911
919
  break
912
920
 
913
921
  # if it is a user-function then build it recursively
914
- if not func.is_builtin():
922
+ if not func.is_builtin() and func not in adj.builder.functions:
915
923
  adj.builder.build_function(func)
924
+ # add custom grad, replay functions to the list of functions
925
+ # to be built later (invalid code could be generated if we built them now)
926
+ # so that they are not missed when only the forward function is imported
927
+ # from another module
928
+ if func.custom_grad_func:
929
+ adj.builder.deferred_functions.append(func.custom_grad_func)
930
+ if func.custom_replay_func:
931
+ adj.builder.deferred_functions.append(func.custom_replay_func)
916
932
 
917
933
  # evaluate the function type based on inputs
918
934
  arg_types = [strip_reference(a.type) for a in args if not isinstance(a, warp.context.Function)]
@@ -924,9 +940,11 @@ class Adjoint:
924
940
  use_initializer_list = func.initializer_list_func(args, templates)
925
941
 
926
942
  args_var = [
927
- adj.load(a)
928
- if not ((param_types[i] == Reference or param_types[i] == Callable) if i < len(param_types) else False)
929
- else a
943
+ (
944
+ adj.load(a)
945
+ if not ((param_types[i] == Reference or param_types[i] == Callable) if i < len(param_types) else False)
946
+ else a
947
+ )
930
948
  for i, a in enumerate(args)
931
949
  ]
932
950
 
@@ -940,7 +958,7 @@ class Adjoint:
940
958
  f"{func.namespace}{func_name}({adj.format_forward_call_args(args_var, use_initializer_list)});"
941
959
  )
942
960
  replay_call = forward_call
943
- if func.custom_replay_func is not None:
961
+ if func.custom_replay_func is not None or func.replay_snippet is not None:
944
962
  replay_call = f"{func.namespace}replay_{func_name}({adj.format_forward_call_args(args_var, use_initializer_list)});"
945
963
 
946
964
  elif not isinstance(return_type, list) or len(return_type) == 1:
@@ -1539,7 +1557,11 @@ class Adjoint:
1539
1557
 
1540
1558
  # test if we're above max unroll count
1541
1559
  max_iters = abs(end - start) // abs(step)
1542
- max_unroll = adj.builder.options["max_unroll"]
1560
+
1561
+ if "max_unroll" in adj.builder_options:
1562
+ max_unroll = adj.builder_options["max_unroll"]
1563
+ else:
1564
+ max_unroll = warp.config.max_unroll
1543
1565
 
1544
1566
  ok_to_unroll = True
1545
1567
 
@@ -1722,9 +1744,7 @@ class Adjoint:
1722
1744
 
1723
1745
  target = adj.eval(node.value)
1724
1746
  if not is_local_value(target):
1725
- raise RuntimeError(
1726
- "Cannot reference a global variable from a kernel unless `wp.constant()` is being used"
1727
- )
1747
+ raise RuntimeError("Cannot reference a global variable from a kernel unless `wp.constant()` is being used")
1728
1748
 
1729
1749
  indices = []
1730
1750
 
@@ -2008,11 +2028,9 @@ class Adjoint:
2008
2028
  # Look up the closure info and append it to adj.func.__globals__
2009
2029
  # in case you want to define a kernel inside a function and refer
2010
2030
  # to variables you've declared inside that function:
2011
- extract_contents = (
2012
- lambda contents: contents
2013
- if isinstance(contents, warp.context.Function) or not callable(contents)
2014
- else contents
2015
- )
2031
+ def extract_contents(contents):
2032
+ return contents if isinstance(contents, warp.context.Function) or not callable(contents) else contents
2033
+
2016
2034
  capturedvars = dict(
2017
2035
  zip(
2018
2036
  adj.func.__code__.co_freevars,
@@ -2343,9 +2361,12 @@ def constant_str(value):
2343
2361
  initlist = []
2344
2362
  for i in range(value._length_):
2345
2363
  x = ctypes.Array.__getitem__(value, i)
2346
- initlist.append(str(scalar_value(x)))
2364
+ initlist.append(str(scalar_value(x)).lower())
2347
2365
 
2348
- dtypestr = f"wp::initializer_array<{value._length_},wp::{value._wp_scalar_type_.__name__}>"
2366
+ if value._wp_scalar_type_ is bool:
2367
+ dtypestr = f"wp::initializer_array<{value._length_},{value._wp_scalar_type_.__name__}>"
2368
+ else:
2369
+ dtypestr = f"wp::initializer_array<{value._length_},wp::{value._wp_scalar_type_.__name__}>"
2349
2370
 
2350
2371
  # construct value from initializer array, e.g. wp::initializer_array<4,wp::float32>{1.0, 2.0, 3.0, 4.0}
2351
2372
  return f"{dtypestr}{{{', '.join(initlist)}}}"
@@ -2614,7 +2635,7 @@ def codegen_func(adj, c_func_name: str, device="cpu", options={}):
2614
2635
  return s
2615
2636
 
2616
2637
 
2617
- def codegen_snippet(adj, name, snippet, adj_snippet):
2638
+ def codegen_snippet(adj, name, snippet, adj_snippet, replay_snippet):
2618
2639
  forward_args = []
2619
2640
  reverse_args = []
2620
2641
 
@@ -2633,6 +2654,7 @@ def codegen_snippet(adj, name, snippet, adj_snippet):
2633
2654
  reverse_args.append(arg.ctype() + " & adj_" + arg.label)
2634
2655
 
2635
2656
  forward_template = cuda_forward_function_template
2657
+ replay_template = cuda_forward_function_template
2636
2658
  reverse_template = cuda_reverse_function_template
2637
2659
 
2638
2660
  s = ""
@@ -2645,6 +2667,16 @@ def codegen_snippet(adj, name, snippet, adj_snippet):
2645
2667
  lineno=adj.fun_lineno,
2646
2668
  )
2647
2669
 
2670
+ if replay_snippet is not None:
2671
+ s += replay_template.format(
2672
+ name="replay_" + name,
2673
+ return_type="void",
2674
+ forward_args=indent(forward_args),
2675
+ forward_body=replay_snippet,
2676
+ filename=adj.filename,
2677
+ lineno=adj.fun_lineno,
2678
+ )
2679
+
2648
2680
  if adj_snippet:
2649
2681
  reverse_body = adj_snippet
2650
2682
  else:
warp/config.py CHANGED
@@ -5,11 +5,7 @@
5
5
  # distribution of this software and related documentation without an express
6
6
  # license agreement from NVIDIA CORPORATION is strictly prohibited.
7
7
 
8
- version = "0.11.0"
9
-
10
- cuda_path = (
11
- None # path to local CUDA toolchain, if None at init time warp will attempt to find the SDK using CUDA_PATH env var
12
- )
8
+ version = "1.0.0"
13
9
 
14
10
  verify_fp = False # verify inputs and outputs are finite after each launch
15
11
  verify_cuda = False # if true will check CUDA errors after each kernel launch / memory operation
@@ -17,10 +13,9 @@ print_launches = False # if true will print out launch information
17
13
 
18
14
  mode = "release"
19
15
  verbose = False # print extra informative messages
16
+ verbose_warnings = False # whether file and line info gets included in Warp warnings
20
17
  quiet = False # suppress all output except errors and warnings
21
18
 
22
- host_compiler = None # user can specify host compiler here, otherwise will attempt to find one automatically
23
-
24
19
  cache_kernels = True
25
20
  kernel_cache_dir = None # path to kernel cache directory, if None a default path will be used
26
21
 
@@ -34,4 +29,8 @@ enable_backward = True # whether to compiler the backward passes of the kernels
34
29
 
35
30
  llvm_cuda = False # use Clang/LLVM instead of NVRTC to compile CUDA
36
31
 
37
- graph_capture_module_load_default = True # Default value of force_module_load for capture_begin()
32
+ enable_graph_capture_module_load_by_default = True # Default value of force_module_load for capture_begin()
33
+
34
+ enable_mempools_at_init = True # Whether CUDA devices will be initialized with mempools enabled (if supported)
35
+
36
+ max_unroll = 16
warp/constants.py CHANGED
@@ -26,6 +26,8 @@ __all__ = [
26
26
  "phi",
27
27
  "PI",
28
28
  "pi",
29
+ "HALF_PI",
30
+ "half_pi",
29
31
  "TAU",
30
32
  "tau",
31
33
  ]
@@ -37,6 +39,7 @@ LN2 = ln2 = constant(0.69314718055994530942) # ln(2)
37
39
  LN10 = ln10 = constant(2.30258509299404568402) # ln(10)
38
40
  PHI = phi = constant(1.61803398874989484820) # golden constant
39
41
  PI = pi = constant(3.14159265358979323846) # pi
42
+ HALF_PI = half_pi = constant(1.57079632679489661923) # half pi
40
43
  TAU = tau = constant(6.28318530717958647692) # 2 * pi
41
44
 
42
45
  INF = inf = constant(math.inf)