warp-lang 1.7.0-py3-none-manylinux_2_34_aarch64.whl → 1.7.2-py3-none-manylinux_2_34_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of warp-lang might be problematic.

Files changed (60)
  1. warp/autograd.py +12 -2
  2. warp/bin/warp-clang.so +0 -0
  3. warp/bin/warp.so +0 -0
  4. warp/build.py +1 -1
  5. warp/builtins.py +103 -66
  6. warp/codegen.py +48 -27
  7. warp/config.py +1 -1
  8. warp/context.py +112 -49
  9. warp/examples/benchmarks/benchmark_cloth.py +1 -1
  10. warp/examples/distributed/example_jacobi_mpi.py +507 -0
  11. warp/fem/cache.py +1 -1
  12. warp/fem/field/field.py +11 -1
  13. warp/fem/field/nodal_field.py +36 -22
  14. warp/fem/geometry/adaptive_nanogrid.py +7 -3
  15. warp/fem/geometry/trimesh.py +4 -12
  16. warp/jax_experimental/custom_call.py +14 -2
  17. warp/jax_experimental/ffi.py +100 -67
  18. warp/native/builtin.h +91 -65
  19. warp/native/svd.h +59 -49
  20. warp/native/tile.h +55 -26
  21. warp/native/volume.cpp +2 -2
  22. warp/native/volume_builder.cu +33 -22
  23. warp/native/warp.cu +1 -1
  24. warp/render/render_opengl.py +41 -34
  25. warp/render/render_usd.py +96 -6
  26. warp/sim/collide.py +11 -9
  27. warp/sim/inertia.py +189 -156
  28. warp/sim/integrator_euler.py +3 -0
  29. warp/sim/integrator_xpbd.py +3 -0
  30. warp/sim/model.py +56 -31
  31. warp/sim/render.py +4 -0
  32. warp/sparse.py +1 -1
  33. warp/stubs.py +73 -25
  34. warp/tests/assets/torus.usda +1 -1
  35. warp/tests/cuda/test_streams.py +1 -1
  36. warp/tests/sim/test_collision.py +237 -206
  37. warp/tests/sim/test_inertia.py +161 -0
  38. warp/tests/sim/test_model.py +5 -3
  39. warp/tests/sim/{flaky_test_sim_grad.py → test_sim_grad.py} +1 -4
  40. warp/tests/sim/test_xpbd.py +399 -0
  41. warp/tests/test_array.py +8 -7
  42. warp/tests/test_atomic.py +181 -2
  43. warp/tests/test_builtins_resolution.py +38 -38
  44. warp/tests/test_codegen.py +24 -3
  45. warp/tests/test_examples.py +16 -6
  46. warp/tests/test_fem.py +93 -14
  47. warp/tests/test_func.py +1 -1
  48. warp/tests/test_mat.py +416 -119
  49. warp/tests/test_quat.py +321 -137
  50. warp/tests/test_struct.py +116 -0
  51. warp/tests/test_vec.py +320 -174
  52. warp/tests/tile/test_tile.py +27 -0
  53. warp/tests/tile/test_tile_load.py +124 -0
  54. warp/tests/unittest_suites.py +2 -5
  55. warp/types.py +107 -9
  56. {warp_lang-1.7.0.dist-info → warp_lang-1.7.2.dist-info}/METADATA +41 -19
  57. {warp_lang-1.7.0.dist-info → warp_lang-1.7.2.dist-info}/RECORD +60 -57
  58. {warp_lang-1.7.0.dist-info → warp_lang-1.7.2.dist-info}/WHEEL +1 -1
  59. {warp_lang-1.7.0.dist-info → warp_lang-1.7.2.dist-info}/licenses/LICENSE.md +0 -26
  60. {warp_lang-1.7.0.dist-info → warp_lang-1.7.2.dist-info}/top_level.txt +0 -0
warp/autograd.py CHANGED
@@ -52,7 +52,12 @@ def gradcheck(
 ) -> bool:
     """
     Checks whether the autodiff gradient of a Warp kernel matches finite differences.
-    Fails if the relative or absolute errors between the autodiff and finite difference gradients exceed the specified tolerance, or if the autodiff gradients contain NaN values.
+
+    Given the autodiff (:math:`\\nabla_\\text{AD}`) and finite difference gradients (:math:`\\nabla_\\text{FD}`), the check succeeds if the autodiff gradients contain no NaN values and the following condition holds:
+
+    .. math::
+
+        |\\nabla_\\text{AD} - \\nabla_\\text{FD}| \\leq atol + rtol \\cdot |\\nabla_\\text{FD}|.
 
     The kernel function and its adjoint version are launched with the given inputs and outputs, as well as the provided
     ``dim``, ``max_blocks``, and ``block_dim`` arguments (see :func:`warp.launch` for more details).
@@ -250,7 +255,12 @@ def gradcheck_tape(
 ) -> bool:
     """
     Checks whether the autodiff gradients for kernels recorded on the Warp tape match finite differences.
-    Fails if the relative or absolute errors between the autodiff and finite difference gradients exceed the specified tolerance, or if the autodiff gradients contain NaN values.
+
+    Given the autodiff (:math:`\\nabla_\\text{AD}`) and finite difference gradients (:math:`\\nabla_\\text{FD}`), the check succeeds if the autodiff gradients contain no NaN values and the following condition holds:
+
+    .. math::
+
+        |\\nabla_\\text{AD} - \\nabla_\\text{FD}| \\leq atol + rtol \\cdot |\\nabla_\\text{FD}|.
 
     Note:
         Only Warp kernels recorded on the tape are checked but not arbitrary functions that have been recorded, e.g. via :meth:`Tape.record_func`.
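The updated docstrings state the pass criterion explicitly instead of describing it as a failure condition. A minimal usage sketch of the check (the kernel, data, and tolerance values below are illustrative, not taken from the package; argument names follow the docstring shown above):

import warp as wp
from warp.autograd import gradcheck

@wp.kernel
def square(x: wp.array(dtype=float), y: wp.array(dtype=float)):
    i = wp.tid()
    y[i] = x[i] * x[i]

x = wp.array([1.0, 2.0, 3.0], dtype=float, requires_grad=True)
y = wp.zeros(3, dtype=float, requires_grad=True)

# Passes if the autodiff gradients contain no NaNs and, elementwise,
# |grad_AD - grad_FD| <= atol + rtol * |grad_FD|.
passed = gradcheck(square, dim=3, inputs=[x], outputs=[y], atol=1e-3, rtol=1e-2)
print(passed)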
warp/bin/warp-clang.so CHANGED
Binary file
warp/bin/warp.so CHANGED
Binary file
warp/build.py CHANGED
@@ -360,7 +360,7 @@ def build_lto_solver(M, N, solver, solver_enum, fill_mode, arch, precision_enum,
     # TODO: MathDx doesn't yet have heuristics for Blackwell
     arch = min(arch, 90)
 
-    lto_symbol = f"{solver}_{M}_{N}_{arch}_{precision_enum}"
+    lto_symbol = f"{solver}_{M}_{N}_{arch}_{num_threads}_{precision_enum}_{fill_mode}"
     ltoir_decl = f"void {lto_symbol}{parameter_list};"
 
     # early out if LTO for this symbol is already cached in current module
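The widened symbol avoids cache collisions between LTO variants that share solver, shape, and architecture but differ in thread count or fill mode. With hypothetical values (all names and numbers below are made up for illustration only):

# hypothetical values, for illustration only
solver, M, N, arch, num_threads, precision_enum, fill_mode = "potrf", 8, 8, 90, 128, 0, 1

old_symbol = f"{solver}_{M}_{N}_{arch}_{precision_enum}"                            # "potrf_8_8_90_0"
new_symbol = f"{solver}_{M}_{N}_{arch}_{num_threads}_{precision_enum}_{fill_mode}"  # "potrf_8_8_90_128_0_1"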
warp/builtins.py CHANGED
@@ -836,7 +836,7 @@ def vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
 
     if dtype is None:
         dtype = value_type
-    elif value_type != dtype:
+    elif not warp.types.scalars_equal(value_type, dtype):
         raise RuntimeError(
             f"the value used to fill this vector is expected to be of the type `{dtype.__name__}`"
         )
@@ -857,9 +857,9 @@ def vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
 
     if dtype is None:
         dtype = value_type
-    elif value_type != dtype:
+    elif not warp.types.scalars_equal(value_type, dtype):
         raise RuntimeError(
-            f"all values used to initialize this vector matrix are expected to be of the type `{dtype.__name__}`"
+            f"all values used to initialize this vector are expected to be of the type `{dtype.__name__}`"
         )
 
     if length is None:
@@ -940,7 +940,7 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
 
     if dtype is None:
         dtype = value_type
-    elif value_type != dtype:
+    elif not warp.types.scalars_equal(value_type, dtype):
         raise RuntimeError(
             f"the value used to fill this matrix is expected to be of the type `{dtype.__name__}`"
         )
@@ -979,7 +979,7 @@ def matrix_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, An
 
     if dtype is None:
         dtype = value_type
-    elif value_type != dtype:
+    elif not warp.types.scalars_equal(value_type, dtype):
         raise RuntimeError(
             f"all values used to initialize this matrix are expected to be of the type `{dtype.__name__}`"
         )
@@ -1170,7 +1170,7 @@ def matrix_transform_value_func(arg_types: Mapping[str, type], arg_values: Mappi
 
     if dtype is None:
         dtype = value_type
-    elif value_type != dtype:
+    elif not warp.types.scalars_equal(value_type, dtype):
         raise RuntimeError(
             f"all values used to initialize this transformation matrix are expected to be of the type `{dtype.__name__}`"
         )
@@ -1305,7 +1305,7 @@ def quaternion_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str
 
     if dtype is None:
         dtype = value_type
-    elif value_type != dtype:
+    elif not warp.types.scalars_equal(value_type, dtype):
         raise RuntimeError(
             f"all values used to initialize this quaternion are expected to be of the type `{dtype.__name__}`"
         )
@@ -1345,7 +1345,8 @@ add_builtin(
 )
 add_builtin(
     "quaternion",
-    input_types={"x": Float, "y": Float, "z": Float, "w": Float},
+    input_types={"x": Float, "y": Float, "z": Float, "w": Float, "dtype": Scalar},
+    defaults={"dtype": None},
     value_func=quaternion_value_func,
     export_func=lambda input_types: {k: v for k, v in input_types.items() if k != "dtype"},
     dispatch_func=quaternion_dispatch_func,
@@ -1515,7 +1516,7 @@ def transformation_value_func(arg_types: Mapping[str, type], arg_values: Mapping
     dtype = arg_values.get("dtype", None)
     if dtype is None:
         dtype = value_type
-    elif value_type != dtype:
+    elif not warp.types.scalars_equal(value_type, dtype):
         raise RuntimeError(
             f"all values used to initialize this transformation matrix are expected to be of the type `{dtype.__name__}`"
         )
@@ -1682,7 +1683,7 @@ def spatial_vector_value_func(arg_types: Mapping[str, type], arg_values: Mapping
 
     if dtype is None:
         dtype = value_type
-    elif value_type != dtype:
+    elif not warp.types.scalars_equal(value_type, dtype):
         raise RuntimeError(
             f"all values used to initialize this spatial vector are expected to be of the type `{dtype.__name__}`"
         )
@@ -2263,7 +2264,7 @@ def tile_atomic_add_value_func(arg_types, arg_values):
             f"tile_atomic_add() 'a' and 't' arguments must have the same dtype, got {arg_types['a'].dtype} and {arg_types['t'].dtype}"
         )
 
-    return Tile(dtype=arg_types["t"].dtype, shape=arg_types["t"].shape)
+    return Tile(dtype=arg_types["t"].dtype, shape=arg_types["t"].shape, storage=arg_types["t"].storage)
 
 
 def tile_atomic_add_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
@@ -2422,7 +2423,6 @@ add_builtin(
     group="Tile Primitives",
     export=False,
     hidden=True,
-    missing_grad=True,
 )
 
 add_builtin(
@@ -2432,7 +2432,6 @@ add_builtin(
     group="Tile Primitives",
     export=False,
     hidden=True,
-    missing_grad=True,
 )
 
 add_builtin(
@@ -2442,7 +2441,6 @@ add_builtin(
     group="Tile Primitives",
     export=False,
     hidden=True,
-    missing_grad=True,
 )
 
 add_builtin(
@@ -2452,7 +2450,6 @@ add_builtin(
     group="Tile Primitives",
     export=False,
     hidden=True,
-    missing_grad=True,
 )
 
 
@@ -4895,46 +4892,78 @@ add_builtin(
 )
 
 
+SUPPORTED_ATOMIC_TYPES = (
+    warp.int32,
+    warp.int64,
+    warp.uint32,
+    warp.uint64,
+    warp.float32,
+    warp.float64,
+)
+
+
 def atomic_op_constraint(arg_types: Mapping[str, Any]):
     idx_types = tuple(arg_types[x] for x in "ijkl" if arg_types.get(x, None) is not None)
     return all(types_equal(idx_types[0], t) for t in idx_types[1:]) and arg_types["arr"].ndim == len(idx_types)
 
 
-def atomic_op_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
-    if arg_types is None:
-        return Any
+def create_atomic_op_value_func(op: str):
+    def fn(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
+        if arg_types is None:
+            return Any
 
-    arr_type = arg_types["arr"]
-    value_type = arg_types["value"]
-    idx_types = tuple(arg_types[x] for x in "ijkl" if arg_types.get(x, None) is not None)
+        arr_type = arg_types["arr"]
+        value_type = arg_types["value"]
+        idx_types = tuple(arg_types[x] for x in "ijkl" if arg_types.get(x, None) is not None)
 
-    if not is_array(arr_type):
-        raise RuntimeError("atomic() first argument must be an array")
+        if not is_array(arr_type):
+            raise RuntimeError(f"atomic_{op}() first argument must be an array")
 
-    idx_count = len(idx_types)
+        idx_count = len(idx_types)
 
-    if idx_count < arr_type.ndim:
-        raise RuntimeError(
-            "Num indices < num dimensions for atomic, this is a codegen error, should have generated a view instead"
-        )
+        if idx_count < arr_type.ndim:
+            raise RuntimeError(
+                f"Num indices < num dimensions for atomic_{op}(), this is a codegen error, should have generated a view instead"
+            )
 
-    if idx_count > arr_type.ndim:
-        raise RuntimeError(
-            f"Num indices > num dimensions for atomic, received {idx_count}, but array only has {arr_type.ndim}"
-        )
+        if idx_count > arr_type.ndim:
+            raise RuntimeError(
+                f"Num indices > num dimensions for atomic_{op}(), received {idx_count}, but array only has {arr_type.ndim}"
+            )
 
-    # check index types
-    for t in idx_types:
-        if not type_is_int(t):
-            raise RuntimeError(f"atomic() index arguments must be of integer type, got index of type {type_repr(t)}")
+        # check index types
+        for t in idx_types:
+            if not type_is_int(t):
+                raise RuntimeError(
+                    f"atomic_{op}() index arguments must be of integer type, got index of type {type_repr(t)}"
+                )
 
-    # check value type
-    if not types_equal(arr_type.dtype, value_type):
-        raise RuntimeError(
-            f"atomic() value argument type ({type_repr(value_type)}) must be of the same type as the array ({type_repr(arr_type.dtype)})"
-        )
+        # check value type
+        if not types_equal(arr_type.dtype, value_type):
+            raise RuntimeError(
+                f"atomic_{op}() value argument type ({type_repr(value_type)}) must be of the same type as the array ({type_repr(arr_type.dtype)})"
+            )
 
-    return arr_type.dtype
+        scalar_type = getattr(arr_type.dtype, "_wp_scalar_type_", arr_type.dtype)
+        if op in ("add", "sub"):
+            supported_atomic_types = (*SUPPORTED_ATOMIC_TYPES, warp.float16)
+            if not any(types_equal(scalar_type, x, match_generic=True) for x in supported_atomic_types):
+                raise RuntimeError(
+                    f"atomic_{op}() operations only work on arrays with [u]int32, [u]int64, float16, float32, or float64 "
+                    f"as the underlying scalar types, but got {type_repr(arr_type.dtype)} (with scalar type {type_repr(scalar_type)})"
                )
+        elif op in ("min", "max"):
+            if not any(types_equal(scalar_type, x, match_generic=True) for x in SUPPORTED_ATOMIC_TYPES):
+                raise RuntimeError(
+                    f"atomic_{op}() operations only work on arrays with [u]int32, [u]int64, float32, or float64 "
+                    f"as the underlying scalar types, but got {type_repr(arr_type.dtype)} (with scalar type {type_repr(scalar_type)})"
                )
+        else:
+            raise NotImplementedError
+
+        return arr_type.dtype
+
+    return fn
 
 
 def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
@@ -4959,9 +4988,10 @@ for array_type in array_types:
         hidden=hidden,
         input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
         constraint=atomic_op_constraint,
-        value_func=atomic_op_value_func,
+        value_func=create_atomic_op_value_func("add"),
         dispatch_func=atomic_op_dispatch_func,
-        doc="Atomically add ``value`` onto ``arr[i]`` and return the old value.",
+        doc="""Atomically adds ``value`` onto ``arr[i]`` and returns the original value of ``arr[i]``.
+    This function is automatically invoked when using the syntax ``arr[i] += value``.""",
         group="Utility",
         skip_replay=True,
     )
@@ -4970,9 +5000,10 @@ for array_type in array_types:
         hidden=hidden,
         input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
        constraint=atomic_op_constraint,
-        value_func=atomic_op_value_func,
+        value_func=create_atomic_op_value_func("add"),
         dispatch_func=atomic_op_dispatch_func,
-        doc="Atomically add ``value`` onto ``arr[i,j]`` and return the old value.",
+        doc="""Atomically adds ``value`` onto ``arr[i,j]`` and returns the original value of ``arr[i,j]``.
+    This function is automatically invoked when using the syntax ``arr[i,j] += value``.""",
         group="Utility",
         skip_replay=True,
     )
@@ -4981,9 +5012,10 @@ for array_type in array_types:
         hidden=hidden,
        input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
        constraint=atomic_op_constraint,
-        value_func=atomic_op_value_func,
+        value_func=create_atomic_op_value_func("add"),
         dispatch_func=atomic_op_dispatch_func,
-        doc="Atomically add ``value`` onto ``arr[i,j,k]`` and return the old value.",
+        doc="""Atomically adds ``value`` onto ``arr[i,j,k]`` and returns the original value of ``arr[i,j,k]``.
+    This function is automatically invoked when using the syntax ``arr[i,j,k] += value``.""",
         group="Utility",
         skip_replay=True,
     )
@@ -4992,9 +5024,10 @@ for array_type in array_types:
         hidden=hidden,
         input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
         constraint=atomic_op_constraint,
-        value_func=atomic_op_value_func,
+        value_func=create_atomic_op_value_func("add"),
         dispatch_func=atomic_op_dispatch_func,
-        doc="Atomically add ``value`` onto ``arr[i,j,k,l]`` and return the old value.",
+        doc="""Atomically adds ``value`` onto ``arr[i,j,k,l]`` and returns the original value of ``arr[i,j,k,l]``.
+    This function is automatically invoked when using the syntax ``arr[i,j,k,l] += value``.""",
         group="Utility",
         skip_replay=True,
     )
@@ -5004,9 +5037,10 @@ for array_type in array_types:
         hidden=hidden,
         input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
         constraint=atomic_op_constraint,
-        value_func=atomic_op_value_func,
+        value_func=create_atomic_op_value_func("sub"),
         dispatch_func=atomic_op_dispatch_func,
-        doc="Atomically subtract ``value`` onto ``arr[i]`` and return the old value.",
+        doc="""Atomically subtracts ``value`` onto ``arr[i]`` and returns the original value of ``arr[i]``.
+    This function is automatically invoked when using the syntax ``arr[i] -= value``.""",
         group="Utility",
         skip_replay=True,
     )
@@ -5015,9 +5049,10 @@ for array_type in array_types:
         hidden=hidden,
         input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
         constraint=atomic_op_constraint,
-        value_func=atomic_op_value_func,
+        value_func=create_atomic_op_value_func("sub"),
         dispatch_func=atomic_op_dispatch_func,
-        doc="Atomically subtract ``value`` onto ``arr[i,j]`` and return the old value.",
+        doc="""Atomically subtracts ``value`` onto ``arr[i,j]`` and returns the original value of ``arr[i,j]``.
+    This function is automatically invoked when using the syntax ``arr[i,j] -= value``.""",
         group="Utility",
         skip_replay=True,
     )
@@ -5026,9 +5061,10 @@ for array_type in array_types:
         hidden=hidden,
         input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
         constraint=atomic_op_constraint,
-        value_func=atomic_op_value_func,
+        value_func=create_atomic_op_value_func("sub"),
         dispatch_func=atomic_op_dispatch_func,
-        doc="Atomically subtract ``value`` onto ``arr[i,j,k]`` and return the old value.",
+        doc="""Atomically subtracts ``value`` onto ``arr[i,j,k]`` and returns the original value of ``arr[i,j,k]``.
+    This function is automatically invoked when using the syntax ``arr[i,j,k] -= value``.""",
         group="Utility",
         skip_replay=True,
     )
@@ -5037,9 +5073,10 @@ for array_type in array_types:
         hidden=hidden,
         input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
         constraint=atomic_op_constraint,
-        value_func=atomic_op_value_func,
+        value_func=create_atomic_op_value_func("sub"),
         dispatch_func=atomic_op_dispatch_func,
-        doc="Atomically subtract ``value`` onto ``arr[i,j,k,l]`` and return the old value.",
+        doc="""Atomically subtracts ``value`` onto ``arr[i,j,k,l]`` and returns the original value of ``arr[i,j,k,l]``.
+    This function is automatically invoked when using the syntax ``arr[i,j,k,l] -= value``.""",
         group="Utility",
         skip_replay=True,
     )
@@ -5049,7 +5086,7 @@ for array_type in array_types:
         hidden=hidden,
         input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
         constraint=atomic_op_constraint,
-        value_func=atomic_op_value_func,
+        value_func=create_atomic_op_value_func("min"),
         dispatch_func=atomic_op_dispatch_func,
         doc="""Compute the minimum of ``value`` and ``arr[i]``, atomically update the array, and return the old value.
 
@@ -5062,7 +5099,7 @@ for array_type in array_types:
         hidden=hidden,
         input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
         constraint=atomic_op_constraint,
-        value_func=atomic_op_value_func,
+        value_func=create_atomic_op_value_func("min"),
         dispatch_func=atomic_op_dispatch_func,
         doc="""Compute the minimum of ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
 
@@ -5075,7 +5112,7 @@ for array_type in array_types:
         hidden=hidden,
         input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
         constraint=atomic_op_constraint,
-        value_func=atomic_op_value_func,
+        value_func=create_atomic_op_value_func("min"),
         dispatch_func=atomic_op_dispatch_func,
         doc="""Compute the minimum of ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
 
@@ -5088,7 +5125,7 @@ for array_type in array_types:
         hidden=hidden,
         input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
         constraint=atomic_op_constraint,
-        value_func=atomic_op_value_func,
+        value_func=create_atomic_op_value_func("min"),
         dispatch_func=atomic_op_dispatch_func,
         doc="""Compute the minimum of ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
 
@@ -5102,7 +5139,7 @@ for array_type in array_types:
         hidden=hidden,
         input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
         constraint=atomic_op_constraint,
-        value_func=atomic_op_value_func,
+        value_func=create_atomic_op_value_func("max"),
         dispatch_func=atomic_op_dispatch_func,
         doc="""Compute the maximum of ``value`` and ``arr[i]``, atomically update the array, and return the old value.
 
@@ -5115,7 +5152,7 @@ for array_type in array_types:
         hidden=hidden,
         input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
         constraint=atomic_op_constraint,
-        value_func=atomic_op_value_func,
+        value_func=create_atomic_op_value_func("max"),
         dispatch_func=atomic_op_dispatch_func,
         doc="""Compute the maximum of ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
 
@@ -5128,7 +5165,7 @@ for array_type in array_types:
         hidden=hidden,
         input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
         constraint=atomic_op_constraint,
-        value_func=atomic_op_value_func,
+        value_func=create_atomic_op_value_func("max"),
         dispatch_func=atomic_op_dispatch_func,
         doc="""Compute the maximum of ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
 
@@ -5141,7 +5178,7 @@ for array_type in array_types:
         hidden=hidden,
         input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
         constraint=atomic_op_constraint,
-        value_func=atomic_op_value_func,
+        value_func=create_atomic_op_value_func("max"),
         dispatch_func=atomic_op_dispatch_func,
         doc="""Compute the maximum of ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
 
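The atomic built-ins now validate the array's scalar type up front ([u]int32, [u]int64, float32, and float64, plus float16 for add/sub) and their docs spell out that augmented assignment lowers to the same operation. A minimal sketch of that behavior (kernel and data below are illustrative, not from the package):

import warp as wp

@wp.kernel
def histogram(values: wp.array(dtype=wp.int32), counts: wp.array(dtype=wp.int32)):
    tid = wp.tid()
    b = values[tid]
    # the augmented assignment below is automatically turned into an atomic add;
    # calling wp.atomic_add(counts, b, 1) directly would also return the previous value
    counts[b] += 1

values = wp.array([0, 1, 1, 2, 2, 2], dtype=wp.int32)
counts = wp.zeros(3, dtype=wp.int32)
wp.launch(histogram, dim=values.shape[0], inputs=[values, counts])
print(counts.numpy())  # expected: [1 2 3]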
warp/codegen.py CHANGED
@@ -202,7 +202,7 @@ def get_full_arg_spec(func: Callable) -> inspect.FullArgSpec:
     return spec._replace(annotations=eval_annotations(spec.annotations, func))
 
 
-def struct_instance_repr_recursive(inst: StructInstance, depth: int) -> str:
+def struct_instance_repr_recursive(inst: StructInstance, depth: int, use_repr: bool) -> str:
     indent = "\t"
 
     # handle empty structs
@@ -216,9 +216,12 @@ def struct_instance_repr_recursive(inst: StructInstance, depth: int) -> str:
         field_value = getattr(inst, field_name, None)
 
         if isinstance(field_value, StructInstance):
-            field_value = struct_instance_repr_recursive(field_value, depth + 1)
+            field_value = struct_instance_repr_recursive(field_value, depth + 1, use_repr)
 
-        lines.append(f"{indent * (depth + 1)}{field_name}={field_value},")
+        if use_repr:
+            lines.append(f"{indent * (depth + 1)}{field_name}={field_value!r},")
+        else:
+            lines.append(f"{indent * (depth + 1)}{field_name}={field_value!s},")
 
     lines.append(f"{indent * depth})")
     return "\n".join(lines)
@@ -237,7 +240,7 @@ class StructInstance:
         # create Python attributes for each of the struct's variables
         for field, var in cls.vars.items():
             if isinstance(var.type, warp.codegen.Struct):
-                self.__dict__[field] = StructInstance(var.type, getattr(self._ctype, field))
+                self.__dict__[field] = var.type.instance_type(ctype=getattr(self._ctype, field))
             elif isinstance(var.type, warp.types.array):
                 self.__dict__[field] = None
             else:
@@ -285,6 +288,11 @@ class StructInstance:
                 )
                 setattr(self._ctype, name, value.__ctype__())
 
+                # workaround to prevent gradient buffers being garbage collected
+                # since users can do struct.array.requires_grad = False the gradient array
+                # would be collected while the struct ctype still holds a reference to it
+                super().__setattr__("_" + name + "_grad", value.grad)
+
             elif isinstance(var.type, Struct):
                 # assign structs by-value, otherwise we would have problematic cases transferring ownership
                 # of the underlying ctypes data between shared Python struct instances
@@ -341,7 +349,10 @@ class StructInstance:
         return self._ctype
 
     def __repr__(self):
-        return struct_instance_repr_recursive(self, 0)
+        return struct_instance_repr_recursive(self, 0, use_repr=True)
+
+    def __str__(self):
+        return struct_instance_repr_recursive(self, 0, use_repr=False)
 
     def to(self, device):
         """Copies this struct with all array members moved onto the given device.
@@ -407,11 +418,14 @@ class StructInstance:
 class Struct:
     hash: bytes
 
-    def __init__(self, cls: type, key: str, module: warp.context.Module):
+    def __init__(self, key: str, cls: type, module: warp.context.Module):
+        self.key = key
         self.cls = cls
         self.module = module
-        self.key = key
-        self.vars: Dict[str, Var] = {}
+        self.vars: dict[str, Var] = {}
+
+        if isinstance(self.cls, Sequence):
+            raise RuntimeError("Warp structs must be defined as base classes")
 
         annotations = get_annotations(self.cls)
         for label, type in annotations.items():
@@ -483,34 +497,35 @@ class Struct:
 
         self.default_constructor.add_overload(self.value_constructor)
 
-        if module:
+        if isinstance(module, warp.context.Module):
             module.register_struct(self)
 
-    def __call__(self):
-        """
-        This function returns s = StructInstance(self)
-        s uses self.cls as template.
-        To enable autocomplete on s, we inherit from self.cls.
-        For example,
+        # Define class for instances of this struct
+        # To enable autocomplete on s, we inherit from self.cls.
+        # For example,
 
-        @wp.struct
-        class A:
-            # annotations
-            ...
-
-        The type annotations are inherited in A(), allowing autocomplete in kernels
-        """
-        # return StructInstance(self)
+        # @wp.struct
+        # class A:
+        #     # annotations
+        #     ...
 
+        # The type annotations are inherited in A(), allowing autocomplete in kernels
         class NewStructInstance(self.cls, StructInstance):
-            def __init__(inst):
-                StructInstance.__init__(inst, self, None)
+            def __init__(inst, ctype=None):
+                StructInstance.__init__(inst, self, ctype)
 
         # make sure warp.types.get_type_code works with this StructInstance
         NewStructInstance.cls = self.cls
         NewStructInstance.native_name = self.native_name
 
-        return NewStructInstance()
+        self.instance_type = NewStructInstance
+
+    def __call__(self):
+        """
+        This function returns s = StructInstance(self)
+        s uses self.cls as template.
+        """
+        return self.instance_type()
 
     def initializer(self):
         return self.default_constructor
@@ -1492,6 +1507,8 @@ class Adjoint:
 
     def add_return(adj, var):
        if var is None or len(var) == 0:
+            # NOTE: If this kernel gets compiled for a CUDA device, then we need
+            # to convert the return; into a continue; in codegen_func_forward()
             adj.add_forward("return;", f"goto label{adj.label_count};")
         elif len(var) == 1:
             adj.add_forward(f"return {var[0].emit()};", f"goto label{adj.label_count};")
@@ -3549,7 +3566,11 @@ def codegen_func_forward(adj, func_type="kernel", device="cpu"):
     lines += ["// forward\n"]
 
     for f in adj.blocks[0].body_forward:
-        lines += [f + "\n"]
+        if func_type == "kernel" and device == "cuda" and f.lstrip().startswith("return;"):
+            # Use of grid-stride loops in CUDA kernels requires that we convert return; to continue;
+            lines += [f.replace("return;", "continue;") + "\n"]
+        else:
+            lines += [f + "\n"]
 
     return "".join(l.lstrip() if l.lstrip().startswith("#line") else indent_block + l for l in lines)
 
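On the Python side, the visible effects of these codegen changes are that struct instances now distinguish repr() from str(), and that nested struct fields are built from the struct's registered instance type; the return-to-continue conversion exists because CUDA kernels are emitted as grid-stride loops, where an early return from one logical thread must not end the whole loop. A small sketch of the struct behavior (struct names below are illustrative):

import warp as wp

@wp.struct
class Inner:
    a: wp.float32

@wp.struct
class Outer:
    inner: Inner
    count: wp.int32

inner = Inner()
inner.a = 1.5

o = Outer()
o.inner = inner  # structs are assigned by value
o.count = 3

print(repr(o))  # fields rendered with repr()
print(o)        # fields rendered with str()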
warp/config.py CHANGED
@@ -15,7 +15,7 @@
 
 from typing import Optional
 
-version: str = "1.7.0"
+version: str = "1.7.2"
 """Warp version string"""
 
 verify_fp: bool = False
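After upgrading, the bumped version string is visible at runtime:

import warp as wp

print(wp.config.version)  # "1.7.2" for this release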