warp-lang 1.8.0-py3-none-macosx_10_13_universal2.whl → 1.8.1-py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (58)
  1. warp/bin/libwarp.dylib +0 -0
  2. warp/build_dll.py +5 -0
  3. warp/codegen.py +15 -3
  4. warp/config.py +1 -1
  5. warp/context.py +122 -24
  6. warp/examples/interop/example_jax_callable.py +34 -4
  7. warp/examples/interop/example_jax_kernel.py +27 -1
  8. warp/fem/field/virtual.py +2 -0
  9. warp/fem/integrate.py +78 -47
  10. warp/jax_experimental/ffi.py +201 -53
  11. warp/native/array.h +4 -4
  12. warp/native/builtin.h +8 -4
  13. warp/native/coloring.cpp +5 -1
  14. warp/native/cuda_util.cpp +1 -1
  15. warp/native/intersect.h +2 -2
  16. warp/native/mat.h +3 -3
  17. warp/native/mesh.h +1 -1
  18. warp/native/quat.h +6 -2
  19. warp/native/rand.h +7 -7
  20. warp/native/sparse.cu +1 -1
  21. warp/native/svd.h +23 -8
  22. warp/native/tile.h +20 -1
  23. warp/native/tile_radix_sort.h +5 -1
  24. warp/native/tile_reduce.h +16 -25
  25. warp/native/tuple.h +2 -2
  26. warp/native/vec.h +4 -4
  27. warp/native/warp.cpp +1 -1
  28. warp/native/warp.cu +15 -2
  29. warp/native/warp.h +1 -1
  30. warp/render/render_opengl.py +52 -51
  31. warp/render/render_usd.py +0 -1
  32. warp/sim/collide.py +1 -2
  33. warp/sim/integrator_vbd.py +10 -2
  34. warp/sparse.py +1 -1
  35. warp/tape.py +2 -0
  36. warp/tests/sim/test_cloth.py +89 -6
  37. warp/tests/sim/test_coloring.py +76 -1
  38. warp/tests/test_assert.py +53 -0
  39. warp/tests/test_atomic_cas.py +127 -114
  40. warp/tests/test_mat.py +22 -0
  41. warp/tests/test_quat.py +22 -0
  42. warp/tests/test_sparse.py +32 -0
  43. warp/tests/test_static.py +48 -0
  44. warp/tests/test_tape.py +38 -0
  45. warp/tests/test_vec.py +38 -408
  46. warp/tests/test_vec_constructors.py +325 -0
  47. warp/tests/tile/test_tile.py +31 -143
  48. warp/tests/tile/test_tile_mathdx.py +2 -2
  49. warp/tests/tile/test_tile_matmul.py +179 -0
  50. warp/tests/tile/test_tile_reduce.py +100 -11
  51. warp/tests/tile/test_tile_shared_memory.py +12 -12
  52. warp/tests/tile/test_tile_sort.py +59 -55
  53. warp/tests/unittest_suites.py +10 -0
  54. {warp_lang-1.8.0.dist-info → warp_lang-1.8.1.dist-info}/METADATA +4 -4
  55. {warp_lang-1.8.0.dist-info → warp_lang-1.8.1.dist-info}/RECORD +58 -56
  56. {warp_lang-1.8.0.dist-info → warp_lang-1.8.1.dist-info}/WHEEL +0 -0
  57. {warp_lang-1.8.0.dist-info → warp_lang-1.8.1.dist-info}/licenses/LICENSE.md +0 -0
  58. {warp_lang-1.8.0.dist-info → warp_lang-1.8.1.dist-info}/top_level.txt +0 -0
warp/bin/libwarp.dylib CHANGED
Binary file
warp/build_dll.py CHANGED
@@ -227,6 +227,7 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, arch, libs: Optional[
  "-gencode=arch=compute_61,code=sm_61",
  "-gencode=arch=compute_70,code=sm_70", # Volta
  "-gencode=arch=compute_75,code=sm_75", # Turing
+ "-gencode=arch=compute_75,code=compute_75", # Turing (PTX)
  "-gencode=arch=compute_80,code=sm_80", # Ampere
  "-gencode=arch=compute_86,code=sm_86",
  ]
@@ -260,6 +261,10 @@ def build_dll_for_arch(args, dll_path, cpp_paths, cu_path, arch, libs: Optional[
  "--cuda-gpu-arch=sm_87", # Orin
  ]

+ if ctk_version >= (12, 8):
+ gencode_opts += ["-gencode=arch=compute_101,code=sm_101"] # Thor (CUDA 12 numbering)
+ clang_arch_flags += ["--cuda-gpu-arch=sm_101"]
+
  if ctk_version >= (12, 8):
  # Support for Blackwell is available with CUDA Toolkit 12.8+
  gencode_opts += [
warp/codegen.py CHANGED
@@ -616,6 +616,8 @@ def compute_type_str(base_name, template_params):
  def param2str(p):
  if isinstance(p, int):
  return str(p)
+ elif hasattr(p, "_wp_generic_type_str_"):
+ return compute_type_str(f"wp::{p._wp_generic_type_str_}", p._wp_type_params_)
  elif hasattr(p, "_type_"):
  if p.__name__ == "bool":
  return "bool"
@@ -967,6 +969,11 @@ class Adjoint:
  # this is to avoid registering false references to overshadowed modules
  adj.symbols[name] = arg

+ # Indicates whether there are unresolved static expressions in the function.
+ # These stem from wp.static() expressions that could not be evaluated at declaration time.
+ # This will signal to the module builder that this module needs to be rebuilt even if the module hash is unchanged.
+ adj.has_unresolved_static_expressions = False
+
  # try to replace static expressions by their constant result if the
  # expression can be evaluated at declaration time
  adj.static_expressions: dict[str, Any] = {}
@@ -2322,8 +2329,9 @@ class Adjoint:

  if adj.is_static_expression(func):
  # try to evaluate wp.static() expressions
- obj, _ = adj.evaluate_static_expression(node)
+ obj, code = adj.evaluate_static_expression(node)
  if obj is not None:
+ adj.static_expressions[code] = obj
  if isinstance(obj, warp.context.Function):
  # special handling for wp.static() evaluating to a function
  return obj
@@ -3109,6 +3117,7 @@ class Adjoint:

  # Since this is an expression, we can enforce it to be defined on a single line.
  static_code = static_code.replace("\n", "")
+ code_to_eval = static_code # code to be evaluated

  vars_dict = adj.get_static_evaluation_context()
  # add constant variables to the static call context
@@ -3150,10 +3159,10 @@
  loc = end

  new_static_code += static_code[len_value_locs[-1][2] :]
- static_code = new_static_code
+ code_to_eval = new_static_code

  try:
- value = eval(static_code, vars_dict)
+ value = eval(code_to_eval, vars_dict)
  if warp.config.verbose:
  print(f"Evaluated static command: {static_code} = {value}")
  except NameError as e:
@@ -3206,6 +3215,9 @@
  # (and is therefore not executable and raises this exception), in which
  # case changing the constant, or the code affecting this constant, would lead to
  # a different module hash anyway.
+ # In any case, we mark this Adjoint to have unresolvable static expressions.
+ # This will trigger a code generation step even if the module hash is unchanged.
+ adj.has_unresolved_static_expressions = True
  pass

  return self.generic_visit(node)
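Note: the codegen.py changes above track wp.static() expressions that cannot be evaluated when a kernel or function is declared (typically because a name they reference is not defined yet); such Adjoints are flagged so the module is regenerated even when its hash is unchanged. A minimal, hypothetical sketch of the scenario these changes address, with SCALE and scale_by_constant chosen purely for illustration:

import warp as wp

@wp.kernel
def scale_by_constant(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    tid = wp.tid()
    # SCALE does not exist yet at declaration time, so this wp.static() call
    # raises NameError during the declaration-time evaluation attempt; the
    # Adjoint is then marked has_unresolved_static_expressions = True and the
    # expression is resolved later, during code generation.
    y[tid] = x[tid] * wp.static(SCALE)

# Defined after the kernel declaration; by the time the module is built,
# the name resolves and the static expression folds to a constant.
SCALE = 2.0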
warp/config.py CHANGED
@@ -15,7 +15,7 @@

  from typing import Optional

- version: str = "1.8.0"
+ version: str = "1.8.1"
  """Warp version string"""

  verify_fp: bool = False
warp/context.py CHANGED
@@ -1692,7 +1692,7 @@ class ModuleHasher:
  ch.update(bytes(name, "utf-8"))
  ch.update(self.get_constant_bytes(value))

- # hash wp.static() expressions that were evaluated at declaration time
+ # hash wp.static() expressions
  for k, v in adj.static_expressions.items():
  ch.update(bytes(k, "utf-8"))
  if isinstance(v, Function):
@@ -2011,6 +2011,9 @@ class Module:
  # is retained and later reloaded with the same hash.
  self.cpu_exec_id = 0

+ # Indicates whether the module has functions or kernels with unresolved static expressions.
+ self.has_unresolved_static_expressions = False
+
  self.options = {
  "max_unroll": warp.config.max_unroll,
  "enable_backward": warp.config.enable_backward,
@@ -2018,7 +2021,7 @@
  "fuse_fp": True,
  "lineinfo": warp.config.lineinfo,
  "cuda_output": None, # supported values: "ptx", "cubin", or None (automatic)
- "mode": warp.config.mode,
+ "mode": None,
  "block_dim": 256,
  "compile_time_trace": warp.config.compile_time_trace,
  }
@@ -2047,6 +2050,10 @@
  # track all kernel objects, even if they are duplicates
  self._live_kernels.add(kernel)

+ # Check for unresolved static expressions in the kernel.
+ if kernel.adj.has_unresolved_static_expressions:
+ self.has_unresolved_static_expressions = True
+
  self.find_references(kernel.adj)

  # for a reload of module on next launch
@@ -2106,6 +2113,10 @@
  del func_existing.user_overloads[k]
  func_existing.add_overload(func)

+ # Check for unresolved static expressions in the function.
+ if func.adj.has_unresolved_static_expressions:
+ self.has_unresolved_static_expressions = True
+
  self.find_references(func.adj)

  # for a reload of module on next launch
@@ -2165,7 +2176,7 @@
  self.hashers[block_dim] = ModuleHasher(self)
  return self.hashers[block_dim].get_module_hash()

- def load(self, device, block_dim=None) -> ModuleExec:
+ def load(self, device, block_dim=None) -> ModuleExec | None:
  device = runtime.get_device(device)

  # update module options if launching with a new block dim
@@ -2174,6 +2185,20 @@

  active_block_dim = self.options["block_dim"]

+ if self.has_unresolved_static_expressions:
+ # The module hash currently does not account for unresolved static expressions
+ # (only static expressions evaluated at declaration time so far).
+ # We need to generate the code for the functions and kernels that have
+ # unresolved static expressions and then compute the module hash again.
+ builder_options = {
+ **self.options,
+ "output_arch": None,
+ }
+ # build functions, kernels to resolve static expressions
+ _ = ModuleBuilder(self, builder_options)
+
+ self.has_unresolved_static_expressions = False
+
  # compute the hash if needed
  if active_block_dim not in self.hashers:
  self.hashers[active_block_dim] = ModuleHasher(self)
@@ -2262,6 +2287,8 @@

  module_load_timer.extra_msg = " (compiled)" # For wp.ScopedTimer informational purposes

+ mode = self.options["mode"] if self.options["mode"] is not None else warp.config.mode
+
  # build CPU
  if device.is_cpu:
  # build
@@ -2281,7 +2308,7 @@
  warp.build.build_cpu(
  output_path,
  source_code_path,
- mode=self.options["mode"],
+ mode=mode,
  fast_math=self.options["fast_math"],
  verify_fp=warp.config.verify_fp,
  fuse_fp=self.options["fuse_fp"],
@@ -2311,7 +2338,7 @@
  source_code_path,
  output_arch,
  output_path,
- config=self.options["mode"],
+ config=mode,
  verify_fp=warp.config.verify_fp,
  fast_math=self.options["fast_math"],
  fuse_fp=self.options["fuse_fp"],
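The "mode" hunks above change the per-module default from a value captured at module creation (warp.config.mode) to None, with the global setting read at build time instead; a module can still pin its own mode explicitly. A short, illustrative sketch of how this plays out from user code:

import warp as wp

# With the module option left at its default of None, the global mode that is
# current at build time is used, so changing it before the build takes effect.
wp.config.mode = "debug"
wp.load_module()  # this module compiles in debug mode

# A module can still override the global setting explicitly.
wp.set_module_options({"mode": "release"})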
@@ -3759,6 +3786,7 @@ class Runtime:
  self.core.cuda_graph_end_capture.restype = ctypes.c_bool

  self.core.cuda_graph_create_exec.argtypes = [
+ ctypes.c_void_p,
  ctypes.c_void_p,
  ctypes.c_void_p,
  ctypes.POINTER(ctypes.c_void_p),
@@ -4066,9 +4094,14 @@ class Runtime:
  # Update the default PTX architecture based on devices present in the system.
  # Use the lowest architecture among devices that meet the minimum architecture requirement.
  # Devices below the required minimum will use the highest architecture they support.
- eligible_archs = [d.arch for d in self.cuda_devices if d.arch >= self.default_ptx_arch]
- if eligible_archs:
- self.default_ptx_arch = min(eligible_archs)
+ try:
+ self.default_ptx_arch = min(
+ d.arch
+ for d in self.cuda_devices
+ if d.arch >= self.default_ptx_arch and d.arch in self.nvrtc_supported_archs
+ )
+ except ValueError:
+ pass # no eligible NVRTC-supported arch ≥ default, retain existing
  else:
  # CUDA not available
  self.set_default_device("cpu")
@@ -6255,6 +6288,40 @@ def get_module_options(module: Any = None) -> dict[str, Any]:
  return get_module(m.__name__).options


+ def _unregister_capture(device: Device, stream: Stream, graph: Graph):
+ """Unregister a graph capture from the device and runtime.
+
+ This should be called when a graph capture is no longer active, either because it completed or was paused.
+ The graph should only be registered while it is actively capturing.
+
+ Args:
+ device: The CUDA device the graph was being captured on
+ stream: The CUDA stream the graph was being captured on
+ graph: The Graph object that was being captured
+ """
+ del device.captures[stream]
+ del runtime.captures[graph.capture_id]
+
+
+ def _register_capture(device: Device, stream: Stream, graph: Graph, capture_id: int):
+ """Register a graph capture with the device and runtime.
+
+ Makes the graph discoverable through its capture_id so that retain_module_exec() can be called
+ when launching kernels during graph capture. This ensures modules are retained until graph execution completes.
+
+ Args:
+ device: The CUDA device the graph is being captured on
+ stream: The CUDA stream the graph is being captured on
+ graph: The Graph object being captured
+ capture_id: Unique identifier for this graph capture
+ """
+ # add to ongoing captures on the device
+ device.captures[stream] = graph
+
+ # add to lookup table by globally unique capture id
+ runtime.captures[capture_id] = graph
+
+
  def capture_begin(
  device: Devicelike = None,
  stream: Stream | None = None,
@@ -6320,11 +6387,7 @@ def capture_begin(
  capture_id = runtime.core.cuda_stream_get_capture_id(stream.cuda_stream)
  graph = Graph(device, capture_id)

- # add to ongoing captures on the device
- device.captures[stream] = graph
-
- # add to lookup table by globally unique capture id
- runtime.captures[capture_id] = graph
+ _register_capture(device, stream, graph, capture_id)


  def capture_end(device: Devicelike = None, stream: Stream | None = None) -> Graph:
@@ -6352,8 +6415,7 @@ def capture_end(device: Devicelike = None, stream: Stream | None = None) -> Graph:
  if graph is None:
  raise RuntimeError("Graph capture is not active on this stream")

- del device.captures[stream]
- del runtime.captures[graph.capture_id]
+ _unregister_capture(device, stream, graph)

  # get the graph executable
  g = ctypes.c_void_p()
@@ -6393,7 +6455,7 @@ def assert_conditional_graph_support():
  raise RuntimeError("Conditional graph nodes require CUDA driver 12.4+")


- def capture_pause(device: Devicelike = None, stream: Stream | None = None) -> ctypes.c_void_p:
+ def capture_pause(device: Devicelike = None, stream: Stream | None = None) -> Graph:
  if stream is not None:
  device = stream.device
  else:
@@ -6402,14 +6464,24 @@ def capture_pause(device: Devicelike = None, stream: Stream | None = None) -> ctypes.c_void_p:
  raise RuntimeError("Must be a CUDA device")
  stream = device.stream

- graph = ctypes.c_void_p()
- if not runtime.core.cuda_graph_pause_capture(device.context, stream.cuda_stream, ctypes.byref(graph)):
+ # get the graph being captured
+ graph = device.captures.get(stream)
+
+ if graph is None:
+ raise RuntimeError("Graph capture is not active on this stream")
+
+ _unregister_capture(device, stream, graph)
+
+ g = ctypes.c_void_p()
+ if not runtime.core.cuda_graph_pause_capture(device.context, stream.cuda_stream, ctypes.byref(g)):
  raise RuntimeError(runtime.get_error_string())

+ graph.graph = g
+
  return graph


- def capture_resume(graph: ctypes.c_void_p, device: Devicelike = None, stream: Stream | None = None):
+ def capture_resume(graph: Graph, device: Devicelike = None, stream: Stream | None = None):
  if stream is not None:
  device = stream.device
  else:
@@ -6418,9 +6490,14 @@ def capture_resume(graph: Graph, device: Devicelike = None, stream: Stream | None = None):
  raise RuntimeError("Must be a CUDA device")
  stream = device.stream

- if not runtime.core.cuda_graph_resume_capture(device.context, stream.cuda_stream, graph):
+ if not runtime.core.cuda_graph_resume_capture(device.context, stream.cuda_stream, graph.graph):
  raise RuntimeError(runtime.get_error_string())

+ capture_id = runtime.core.cuda_stream_get_capture_id(stream.cuda_stream)
+ graph.capture_id = capture_id
+
+ _register_capture(device, stream, graph, capture_id)
+

  # reusable pinned readback buffer for conditions
  condition_host = None
@@ -6518,10 +6595,15 @@ def capture_if(

  # pause capturing parent graph
  main_graph = capture_pause(stream=stream)
+ # store the pointer to the cuda graph to restore it later
+ main_graph_ptr = main_graph.graph

  # capture if-graph
  if on_true is not None:
- capture_resume(graph_on_true, stream=stream)
+ # temporarily repurpose the main_graph python object such that all dependencies
+ # added through retain_module_exec() end up in the correct python graph object
+ main_graph.graph = graph_on_true
+ capture_resume(main_graph, stream=stream)
  if isinstance(on_true, Callable):
  on_true(**kwargs)
  elif isinstance(on_true, Graph):
@@ -6541,7 +6623,10 @@

  # capture else-graph
  if on_false is not None:
- capture_resume(graph_on_false, stream=stream)
+ # temporarily repurpose the main_graph python object such that all dependencies
+ # added through retain_module_exec() end up in the correct python graph object
+ main_graph.graph = graph_on_false
+ capture_resume(main_graph, stream=stream)
  if isinstance(on_false, Callable):
  on_false(**kwargs)
  elif isinstance(on_false, Graph):
@@ -6559,6 +6644,9 @@
  raise TypeError("on_false must be a Callable or a Graph")
  capture_pause(stream=stream)

+ # restore the main graph to its original state
+ main_graph.graph = main_graph_ptr
+
  # resume capturing parent graph
  capture_resume(main_graph, stream=stream)

@@ -6641,7 +6729,13 @@ def capture_while(condition: warp.array(dtype=int), while_body: Callable | Graph

  # pause capturing parent graph and start capturing child graph
  main_graph = capture_pause(stream=stream)
- capture_resume(body_graph, stream=stream)
+ # store the pointer to the cuda graph to restore it later
+ main_graph_ptr = main_graph.graph
+
+ # temporarily repurpose the main_graph python object such that all dependencies
+ # added through retain_module_exec() end up in the correct python graph object
+ main_graph.graph = body_graph
+ capture_resume(main_graph, stream=stream)

  # capture while-body
  if isinstance(while_body, Callable):
@@ -6670,6 +6764,8 @@

  # stop capturing child graph and resume capturing parent graph
  capture_pause(stream=stream)
+ # restore the main graph to its original state
+ main_graph.graph = main_graph_ptr
  capture_resume(main_graph, stream=stream)


@@ -6691,7 +6787,9 @@ def capture_launch(graph: Graph, stream: Stream | None = None):

  if graph.graph_exec is None:
  g = ctypes.c_void_p()
- result = runtime.core.cuda_graph_create_exec(graph.device.context, graph.graph, ctypes.byref(g))
+ result = runtime.core.cuda_graph_create_exec(
+ graph.device.context, stream.cuda_stream, graph.graph, ctypes.byref(g)
+ )
  if not result:
  raise RuntimeError(f"Graph creation error: {runtime.get_error_string()}")
  graph.graph_exec = g
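The graph-capture hunks above make capture_pause()/capture_resume() operate on Graph objects rather than raw ctypes handles, and register/unregister captures so that module executables launched while a child graph is being recorded are retained by the correct Graph. These internals back the public conditional-capture API; below is a hedged usage sketch (kernel, sizes, and device string are illustrative, and it assumes a CUDA device with a 12.4+ driver for conditional graph nodes):

import warp as wp

@wp.kernel
def double_kernel(x: wp.array(dtype=float)):
    tid = wp.tid()
    x[tid] = 2.0 * x[tid]

device = wp.get_device("cuda:0")
wp.load_module(device=device)  # compile ahead of capture, as in the bundled examples

x = wp.ones(16, dtype=float, device=device)
cond = wp.ones(1, dtype=int, device=device)  # nonzero -> take the on_true branch

with wp.ScopedCapture(device) as capture:
    # capture_if() pauses the parent capture, records the branch into a child
    # graph, and resumes the parent; the registration helpers in this diff keep
    # the modules used inside the branch referenced by the resulting Graph.
    wp.capture_if(cond, on_true=lambda: wp.launch(double_kernel, dim=x.size, inputs=[x], device=device))

wp.capture_launch(capture.graph)
print(x.numpy())  # expected: all 2.0 when the condition is nonzero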
warp/examples/interop/example_jax_callable.py CHANGED
@@ -42,7 +42,7 @@ def scale_vec_kernel(a: wp.array(dtype=wp.vec2), s: float, output: wp.array(dtype=wp.vec2)):

  # The Python function to call.
  # Note the argument annotations, just like Warp kernels.
- def example_func(
+ def scale_func(
  # inputs
  a: wp.array(dtype=float),
  b: wp.array(dtype=wp.vec2),
@@ -55,8 +55,23 @@ def example_func(
  wp.launch(scale_vec_kernel, dim=b.shape, inputs=[b, s], outputs=[d])


+ @wp.kernel
+ def accum_kernel(a: wp.array(dtype=float), b: wp.array(dtype=float)):
+ tid = wp.tid()
+ b[tid] += a[tid]
+
+
+ def in_out_func(
+ a: wp.array(dtype=float), # input only
+ b: wp.array(dtype=float), # input and output
+ c: wp.array(dtype=float), # output only
+ ):
+ wp.launch(scale_kernel, dim=a.size, inputs=[a, 2.0], outputs=[c])
+ wp.launch(accum_kernel, dim=a.size, inputs=[a, b]) # modifies `b`
+
+
  def example1():
- jax_func = jax_callable(example_func, num_outputs=2, vmap_method="broadcast_all")
+ jax_func = jax_callable(scale_func, num_outputs=2)

  @jax.jit
  def f():
@@ -78,7 +93,7 @@ def example1():


  def example2():
- jax_func = jax_callable(example_func, num_outputs=2, vmap_method="broadcast_all")
+ jax_func = jax_callable(scale_func, num_outputs=2)

  # NOTE: scalar arguments must be static compile-time constants
  @partial(jax.jit, static_argnames=["s"])
@@ -100,11 +115,26 @@ def example2():
  print(r2)


+ def example3():
+ # Using input-output arguments
+
+ jax_func = jax_callable(in_out_func, num_outputs=2, in_out_argnames=["b"])
+
+ f = jax.jit(jax_func)
+
+ a = jnp.ones(10, dtype=jnp.float32)
+ b = jnp.arange(10, dtype=jnp.float32)
+
+ b, c = f(a, b)
+ print(b)
+ print(c)
+
+
  def main():
  wp.init()
  wp.load_module(device=wp.get_device())

- examples = [example1, example2]
+ examples = [example1, example2, example3]

  for example in examples:
  print("\n===========================================================================")
warp/examples/interop/example_jax_kernel.py CHANGED
@@ -72,6 +72,17 @@ def scale_vec_kernel(a: wp.array(dtype=wp.vec2), s: float, output: wp.array(dtype=wp.vec2)):
  output[tid] = a[tid] * s


+ @wp.kernel
+ def in_out_kernel(
+ a: wp.array(dtype=float), # input only
+ b: wp.array(dtype=float), # input and output
+ c: wp.array(dtype=float), # output only
+ ):
+ tid = wp.tid()
+ b[tid] += a[tid]
+ c[tid] = 2.0 * a[tid]
+
+
  def example1():
  # two inputs and one output
  jax_add = jax_kernel(add_kernel)
@@ -189,11 +200,26 @@ def example7():
  print(f())


+ def example8():
+ # Using input-output arguments
+
+ jax_func = jax_kernel(in_out_kernel, num_outputs=2, in_out_argnames=["b"])
+
+ f = jax.jit(jax_func)
+
+ a = jnp.ones(10, dtype=jnp.float32)
+ b = jnp.arange(10, dtype=jnp.float32)
+
+ b, c = f(a, b)
+ print(b)
+ print(c)
+
+
  def main():
  wp.init()
  wp.load_module(device=wp.get_device())

- examples = [example1, example2, example3, example4, example5, example6, example7]
+ examples = [example1, example2, example3, example4, example5, example6, example7, example8]

  for example in examples:
  print("\n===========================================================================")
warp/fem/field/virtual.py CHANGED
@@ -365,6 +365,8 @@ class LocalAdjointField(SpaceField):
  self._TAYLOR_DOF_COUNTS = LocalAdjointField.DofOffsets(0)
  self.TAYLOR_DOF_COUNT = 0

+ cache.setup_dynamic_attributes(self)
+
  def notify_operator_usage(self, ops: Set[operator.Operator]):
  # Rebuild degrees-of-freedom offsets based on used operators