warp-lang 1.5.0__py3-none-macosx_10_13_universal2.whl → 1.5.1__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

warp/bin/libwarp.dylib CHANGED
Binary file
warp/builtins.py CHANGED
@@ -399,11 +399,11 @@ def scalar_infer_type(arg_types: Mapping[str, type]):
 
     scalar_types = set()
     for t in arg_types:
-        t = strip_reference(t)
-        if hasattr(t, "_wp_scalar_type_"):
-            scalar_types.add(t._wp_scalar_type_)
-        elif t in scalar_and_bool_types:
-            scalar_types.add(t)
+        t_val = strip_reference(t)
+        if hasattr(t_val, "_wp_scalar_type_"):
+            scalar_types.add(t_val._wp_scalar_type_)
+        elif t_val in scalar_and_bool_types:
+            scalar_types.add(t_val)
 
     if len(scalar_types) > 1:
         raise RuntimeError(
@@ -1852,6 +1852,7 @@ def tile_arange_value_func(arg_types: Mapping[str, type], arg_values: Mapping[st
     step = args[2]
 
     if start is None or stop is None or step is None:
+        print(args)
         raise RuntimeError("wp.tile_arange() arguments must be compile time constants")
 
     if "dtype" in arg_values:
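
Note: the added print(args) executes just before the error is raised; wp.tile_arange() itself still requires its bounds to be compile-time constants. A minimal sketch of a valid call (tile size chosen arbitrarily):

    import warp as wp

    @wp.kernel
    def make_range():
        # start/stop are literals, so they are known at compile time
        t = wp.tile_arange(64, 128, dtype=float)
        print(t)

    wp.launch_tiled(make_range, dim=[1], inputs=[], block_dim=64)
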
@@ -2083,7 +2084,7 @@ def tile_store_1d_value_func(arg_types, arg_values):
 
 add_builtin(
     "tile_store",
-    input_types={"a": array(dtype=Any), "i": int, "t": Any},
+    input_types={"a": array(dtype=Any), "i": int, "t": Tile(dtype=Any, M=Any, N=Any)},
     value_func=tile_store_1d_value_func,
     variadic=False,
     skip_replay=True,
@@ -2132,7 +2133,7 @@ def tile_store_2d_value_func(arg_types, arg_values):
 
 add_builtin(
     "tile_store",
-    input_types={"a": array(dtype=Any), "i": int, "j": int, "t": Any},
+    input_types={"a": array(dtype=Any), "i": int, "j": int, "t": Tile(dtype=Any, M=Any, N=Any)},
     value_func=tile_store_2d_value_func,
     variadic=False,
     skip_replay=True,
@@ -2177,7 +2178,7 @@ def tile_atomic_add_value_func(arg_types, arg_values):
 
 add_builtin(
     "tile_atomic_add",
-    input_types={"a": array(dtype=Any), "x": int, "y": int, "t": Any},
+    input_types={"a": array(dtype=Any), "x": int, "y": int, "t": Tile(dtype=Any, M=Any, N=Any)},
     value_func=tile_atomic_add_value_func,
     variadic=True,
     skip_replay=True,
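
Note: with the Tile(dtype=Any, M=Any, N=Any) annotations above, passing a non-tile value for the t argument of tile_store()/tile_atomic_add() should now be rejected when the kernel is built rather than failing later in codegen. A minimal sketch of the intended usage (array and tile sizes are illustrative):

    import warp as wp

    @wp.kernel
    def store_ones(a: wp.array2d(dtype=float)):
        # build a 16x16 tile of ones and store it at block offset (0, 0)
        t = wp.tile_ones(dtype=float, m=16, n=16)
        wp.tile_store(a, 0, 0, t)

    out = wp.zeros((16, 16), dtype=float)
    wp.launch_tiled(store_ones, dim=[1], inputs=[out], block_dim=64)
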
@@ -2365,7 +2366,7 @@ def untile_value_func(arg_types, arg_values):
 
 add_builtin(
     "untile",
-    input_types={"a": Any},
+    input_types={"a": Tile(dtype=Any, M=Any, N=Any)},
     value_func=untile_value_func,
     variadic=True,
     doc="""Convert a Tile back to per-thread values.
@@ -2390,7 +2391,7 @@ add_builtin(
             t = wp.tile(i)*2
 
             # convert back to per-thread values
-            s = wp.untile()
+            s = wp.untile(t)
 
             print(s)
 
@@ -2562,7 +2563,7 @@ add_builtin(
     variadic=True,
     doc="""Broadcast a tile.
 
-    This method will attempt to broadcast the input tile ``a`` to the destination shape (m, n), broadcasting follows NumPy broadcast rules.
+    This function will attempt to broadcast the input tile ``a`` to the destination shape (m, n), broadcasting follows NumPy broadcast rules.
 
     :param a: Tile to broadcast
     :returns: Tile with broadcast ``shape=(m, n)``""",
@@ -2654,9 +2655,9 @@ add_builtin(
             t = wp.tile_ones(dtype=float, m=16, n=16)
             s = wp.tile_sum(t)
 
-            print(t)
+            print(s)
 
-        wp.launch(compute, dim=[64], inputs=[])
+        wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
 
     Prints:
 
@@ -2703,18 +2704,19 @@ add_builtin(
         @wp.kernel
         def compute():
 
-            t = wp.tile_arange(start=--10, stop=10, dtype=float)
+            t = wp.tile_arange(64, 128)
             s = wp.tile_min(t)
 
-            print(t)
+            print(s)
 
-        wp.launch(compute, dim=[64], inputs=[])
+
+        wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
 
     Prints:
 
     .. code-block:: text
 
-        tile(m=1, n=1, storage=register) = [[-10]]
+        tile(m=1, n=1, storage=register) = [[64 ]]
 
     """,
     group="Tile Primitives",
@@ -2755,18 +2757,18 @@ add_builtin(
         @wp.kernel
         def compute():
 
-            t = wp.tile_arange(start=--10, stop=10, dtype=float)
-            s = wp.tile_min(t)
+            t = wp.tile_arange(64, 128)
+            s = wp.tile_max(t)
 
-            print(t)
+            print(s)
 
-        wp.launch(compute, dim=[64], inputs=[])
+        wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64)
 
     Prints:
 
     .. code-block:: text
 
-        tile(m=1, n=1, storage=register) = [[10]]
+        tile(m=1, n=1, storage=register) = [[127 ]]
 
     """,
     group="Tile Primitives",
@@ -2796,7 +2798,7 @@ def tile_reduce_dispatch_func(input_types: Mapping[str, type], return_type: Any,
 
 add_builtin(
     "tile_reduce",
-    input_types={"op": Callable, "a": Any},
+    input_types={"op": Callable, "a": Tile(dtype=Any, M=Any, N=Any)},
     value_func=tile_reduce_value_func,
     native_func="tile_reduce",
     doc="""Apply a custom reduction operator across the tile.
@@ -2819,7 +2821,7 @@ add_builtin(
 
             print(s)
 
-        wp.launch(factorial, dim=[16], inputs=[], block_dim=16)
+        wp.launch_tiled(factorial, dim=[1], inputs=[], block_dim=16)
 
     Prints:
 
@@ -2856,7 +2858,7 @@ def tile_unary_map_value_func(arg_types, arg_values):
 
 add_builtin(
     "tile_map",
-    input_types={"op": Callable, "a": Any},
+    input_types={"op": Callable, "a": Tile(dtype=Any, M=Any, N=Any)},
     value_func=tile_unary_map_value_func,
     # dispatch_func=tile_map_dispatch_func,
     # variadic=True,
@@ -2881,7 +2883,7 @@ add_builtin(
 
             print(s)
 
-        wp.launch(compute, dim=[16], inputs=[])
+        wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=16)
 
     Prints:
 
@@ -2923,7 +2925,7 @@ def tile_binary_map_value_func(arg_types, arg_values):
 
 add_builtin(
     "tile_map",
-    input_types={"op": Callable, "a": Any, "b": Any},
+    input_types={"op": Callable, "a": Tile(dtype=Any, M=Any, N=Any), "b": Tile(dtype=Any, M=Any, N=Any)},
     value_func=tile_binary_map_value_func,
     # dispatch_func=tile_map_dispatch_func,
     # variadic=True,
@@ -2952,7 +2954,7 @@ add_builtin(
 
             print(s)
 
-        wp.launch(compute, dim=[16], inputs=[])
+        wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=16)
 
     Prints:
 
@@ -4665,6 +4667,19 @@ def atomic_op_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str,
     return arr_type.dtype
 
 
+def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
+    # as this is a codegen callback, we can mark the fact that this func writes to an array here
+    if warp.config.verify_autograd_array_access:
+        arr = args["arr"]
+        arr.mark_write()
+
+    func_args = tuple(args.values())
+    # we don't need to specify template arguments for atomic ops
+    template_args = ()
+
+    return (func_args, template_args)
+
+
 for array_type in array_types:
     # don't list indexed array operations explicitly in docs
     hidden = array_type == indexedarray
@@ -4675,6 +4690,7 @@ for array_type in array_types:
         input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
         constraint=atomic_op_constraint,
         value_func=atomic_op_value_func,
+        dispatch_func=atomic_op_dispatch_func,
         doc="Atomically add ``value`` onto ``arr[i]`` and return the old value.",
         group="Utility",
         skip_replay=True,
@@ -4685,6 +4701,7 @@ for array_type in array_types:
         input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
         constraint=atomic_op_constraint,
         value_func=atomic_op_value_func,
+        dispatch_func=atomic_op_dispatch_func,
         doc="Atomically add ``value`` onto ``arr[i,j]`` and return the old value.",
         group="Utility",
         skip_replay=True,
@@ -4695,6 +4712,7 @@ for array_type in array_types:
         input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
         constraint=atomic_op_constraint,
         value_func=atomic_op_value_func,
+        dispatch_func=atomic_op_dispatch_func,
         doc="Atomically add ``value`` onto ``arr[i,j,k]`` and return the old value.",
         group="Utility",
         skip_replay=True,
@@ -4705,6 +4723,7 @@ for array_type in array_types:
         input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
         constraint=atomic_op_constraint,
         value_func=atomic_op_value_func,
+        dispatch_func=atomic_op_dispatch_func,
         doc="Atomically add ``value`` onto ``arr[i,j,k,l]`` and return the old value.",
         group="Utility",
         skip_replay=True,
@@ -4716,6 +4735,7 @@ for array_type in array_types:
         input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
         constraint=atomic_op_constraint,
         value_func=atomic_op_value_func,
+        dispatch_func=atomic_op_dispatch_func,
         doc="Atomically subtract ``value`` onto ``arr[i]`` and return the old value.",
         group="Utility",
         skip_replay=True,
@@ -4726,6 +4746,7 @@ for array_type in array_types:
         input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
         constraint=atomic_op_constraint,
         value_func=atomic_op_value_func,
+        dispatch_func=atomic_op_dispatch_func,
         doc="Atomically subtract ``value`` onto ``arr[i,j]`` and return the old value.",
         group="Utility",
         skip_replay=True,
@@ -4736,6 +4757,7 @@ for array_type in array_types:
         input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
         constraint=atomic_op_constraint,
         value_func=atomic_op_value_func,
+        dispatch_func=atomic_op_dispatch_func,
         doc="Atomically subtract ``value`` onto ``arr[i,j,k]`` and return the old value.",
         group="Utility",
         skip_replay=True,
@@ -4746,6 +4768,7 @@ for array_type in array_types:
         input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
         constraint=atomic_op_constraint,
         value_func=atomic_op_value_func,
+        dispatch_func=atomic_op_dispatch_func,
         doc="Atomically subtract ``value`` onto ``arr[i,j,k,l]`` and return the old value.",
         group="Utility",
         skip_replay=True,
@@ -4757,6 +4780,7 @@ for array_type in array_types:
         input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
         constraint=atomic_op_constraint,
         value_func=atomic_op_value_func,
+        dispatch_func=atomic_op_dispatch_func,
         doc="""Compute the minimum of ``value`` and ``arr[i]``, atomically update the array, and return the old value.
 
         The operation is only atomic on a per-component basis for vectors and matrices.""",
@@ -4769,6 +4793,7 @@ for array_type in array_types:
         input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
         constraint=atomic_op_constraint,
         value_func=atomic_op_value_func,
+        dispatch_func=atomic_op_dispatch_func,
         doc="""Compute the minimum of ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
 
         The operation is only atomic on a per-component basis for vectors and matrices.""",
@@ -4781,6 +4806,7 @@ for array_type in array_types:
         input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
         constraint=atomic_op_constraint,
         value_func=atomic_op_value_func,
+        dispatch_func=atomic_op_dispatch_func,
         doc="""Compute the minimum of ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
 
         The operation is only atomic on a per-component basis for vectors and matrices.""",
@@ -4793,6 +4819,7 @@ for array_type in array_types:
         input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
         constraint=atomic_op_constraint,
         value_func=atomic_op_value_func,
+        dispatch_func=atomic_op_dispatch_func,
         doc="""Compute the minimum of ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
 
         The operation is only atomic on a per-component basis for vectors and matrices.""",
@@ -4806,6 +4833,7 @@ for array_type in array_types:
         input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
         constraint=atomic_op_constraint,
         value_func=atomic_op_value_func,
+        dispatch_func=atomic_op_dispatch_func,
         doc="""Compute the maximum of ``value`` and ``arr[i]``, atomically update the array, and return the old value.
 
         The operation is only atomic on a per-component basis for vectors and matrices.""",
@@ -4818,6 +4846,7 @@ for array_type in array_types:
         input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
         constraint=atomic_op_constraint,
         value_func=atomic_op_value_func,
+        dispatch_func=atomic_op_dispatch_func,
         doc="""Compute the maximum of ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
 
         The operation is only atomic on a per-component basis for vectors and matrices.""",
@@ -4830,6 +4859,7 @@ for array_type in array_types:
         input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
         constraint=atomic_op_constraint,
         value_func=atomic_op_value_func,
+        dispatch_func=atomic_op_dispatch_func,
         doc="""Compute the maximum of ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
 
         The operation is only atomic on a per-component basis for vectors and matrices.""",
@@ -4842,6 +4872,7 @@ for array_type in array_types:
         input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
         constraint=atomic_op_constraint,
         value_func=atomic_op_value_func,
+        dispatch_func=atomic_op_dispatch_func,
         doc="""Compute the maximum of ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
 
         The operation is only atomic on a per-component basis for vectors and matrices.""",
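
Note: routing every wp.atomic_add/atomic_sub/atomic_min/atomic_max overload through atomic_op_dispatch_func means those calls are now recorded as array writes when autograd array-access verification is enabled. A hedged sketch of turning the check on (the kernel body is illustrative):

    import warp as wp

    wp.config.verify_autograd_array_access = True

    @wp.kernel
    def accumulate(src: wp.array(dtype=float), dst: wp.array(dtype=float)):
        i = wp.tid()
        # the atomic add is now marked as a write to dst during code generation
        wp.atomic_add(dst, 0, src[i])
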
warp/codegen.py CHANGED
@@ -1175,25 +1175,25 @@ class Adjoint:
         left = adj.load(left)
         s = output.emit() + " = " + ("(" * len(comps)) + left.emit() + " "
 
-        prev_comp = None
+        prev_comp_var = None
 
         for op, comp in zip(op_strings, comps):
             comp_chainable = op_str_is_chainable(op)
-            if comp_chainable and prev_comp:
-                # We restrict chaining to operands of the same type
-                if prev_comp.type is comp.type:
-                    prev_comp = adj.load(prev_comp)
-                    comp = adj.load(comp)
-                    s += "&& (" + prev_comp.emit() + " " + op + " " + comp.emit() + ")) "
+            if comp_chainable and prev_comp_var:
+                # We restrict chaining to operands of the same type
+                if prev_comp_var.type is comp.type:
+                    prev_comp_var = adj.load(prev_comp_var)
+                    comp_var = adj.load(comp)
+                    s += "&& (" + prev_comp_var.emit() + " " + op + " " + comp_var.emit() + ")) "
                 else:
                     raise WarpCodegenTypeError(
-                        f"Cannot chain comparisons of unequal types: {prev_comp.type} {op} {comp.type}."
+                        f"Cannot chain comparisons of unequal types: {prev_comp_var.type} {op} {comp.type}."
                     )
             else:
-                comp = adj.load(comp)
-                s += op + " " + comp.emit() + ") "
+                comp_var = adj.load(comp)
+                s += op + " " + comp_var.emit() + ") "
 
-            prev_comp = comp
+            prev_comp_var = comp_var
 
         s = s.rstrip() + ";"
 
@@ -1366,13 +1366,15 @@ class Adjoint:
         fwd_args = []
         for func_arg in func_args:
             if not isinstance(func_arg, (Reference, warp.context.Function)):
-                func_arg = adj.load(func_arg)
+                func_arg_var = adj.load(func_arg)
+            else:
+                func_arg_var = func_arg
 
             # if the argument is a function (and not a builtin), then build it recursively
-            if isinstance(func_arg, warp.context.Function) and not func_arg.is_builtin():
-                adj.builder.build_function(func_arg)
+            if isinstance(func_arg_var, warp.context.Function) and not func_arg_var.is_builtin():
+                adj.builder.build_function(func_arg_var)
 
-            fwd_args.append(strip_reference(func_arg))
+            fwd_args.append(strip_reference(func_arg_var))
 
         if return_type is None:
             # handles expression (zero output) functions, e.g.: void do_something();
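
Note: the two hunks above only rename locals so that loop variables are no longer reused; chained comparisons in kernels still lower to a conjunction of pairwise tests over operands of the same type, e.g.:

    import warp as wp

    @wp.kernel
    def in_unit_range(values: wp.array(dtype=float)):
        i = wp.tid()
        # emitted as (0.0 < values[i]) && (values[i] < 1.0)
        if 0.0 < values[i] < 1.0:
            wp.printf("in range: %f\n", values[i])
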
@@ -2569,8 +2571,10 @@ class Adjoint:
         adj.return_var = ()
         for ret in var:
             if is_reference(ret.type):
-                ret = adj.add_builtin_call("copy", [ret])
-            adj.return_var += (ret,)
+                ret_var = adj.add_builtin_call("copy", [ret])
+            else:
+                ret_var = ret
+            adj.return_var += (ret_var,)
 
         adj.add_return(adj.return_var)
 
warp/config.py CHANGED
@@ -7,7 +7,7 @@
 
 from typing import Optional
 
-version: str = "1.5.0"
+version: str = "1.5.1"
 """Warp version string"""
 
 verify_fp: bool = False
warp/context.py CHANGED
@@ -7,6 +7,7 @@
 
 import ast
 import ctypes
+import errno
 import functools
 import hashlib
 import inspect
@@ -17,6 +18,7 @@ import operator
 import os
 import platform
 import sys
+import time
 import types
 import typing
 import weakref
@@ -238,24 +240,23 @@ class Function:
         # in a way that is compatible with Python's semantics.
         signature_params = []
         signature_default_param_kind = inspect.Parameter.POSITIONAL_OR_KEYWORD
-        for param_name in self.input_types.keys():
-            if param_name.startswith("**"):
-                param_name = param_name[2:]
+        for raw_param_name in self.input_types.keys():
+            if raw_param_name.startswith("**"):
+                param_name = raw_param_name[2:]
                 param_kind = inspect.Parameter.VAR_KEYWORD
-            elif param_name.startswith("*"):
-                param_name = param_name[1:]
+            elif raw_param_name.startswith("*"):
+                param_name = raw_param_name[1:]
                 param_kind = inspect.Parameter.VAR_POSITIONAL
 
                 # Once a variadic argument like `*args` is found, any following
                 # arguments need to be passed using keywords.
                 signature_default_param_kind = inspect.Parameter.KEYWORD_ONLY
             else:
+                param_name = raw_param_name
                 param_kind = signature_default_param_kind
 
-            param = param = inspect.Parameter(
-                param_name,
-                param_kind,
-                default=self.defaults.get(param_name, inspect.Parameter.empty),
+            param = inspect.Parameter(
+                param_name, param_kind, default=self.defaults.get(param_name, inspect.Parameter.empty)
             )
             signature_params.append(param)
         self.signature = inspect.Signature(signature_params)
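
Note: the inspect.Signature built here, together with the overload changes in the next hunk, lets user-defined overloads be resolved when called from Python with keyword arguments (previously a RuntimeError). A hedged sketch, assuming the overloads are invoked from Python scope:

    import warp as wp

    @wp.func
    def scale(x: float, factor: float):
        return x * factor

    @wp.func
    def scale(v: wp.vec3, factor: float):
        return v * factor

    # keyword arguments are now bound against the overload signatures
    print(scale(x=2.0, factor=3.0))
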
@@ -294,22 +295,22 @@ class Function:
 
         if hasattr(self, "user_overloads") and len(self.user_overloads):
             # user-defined function with overloads
+            bound_args = self.signature.bind(*args, **kwargs)
+            if self.defaults:
+                warp.codegen.apply_defaults(bound_args, self.defaults)
 
-            if len(kwargs):
-                raise RuntimeError(
-                    f"Error calling function '{self.key}', keyword arguments are not supported for user-defined overloads."
-                )
+            arguments = tuple(bound_args.arguments.values())
 
             # try and find a matching overload
             for overload in self.user_overloads.values():
-                if len(overload.input_types) != len(args):
+                if len(overload.input_types) != len(arguments):
                     continue
                 template_types = list(overload.input_types.values())
                 arg_names = list(overload.input_types.keys())
                 try:
                     # attempt to unify argument types with function template types
-                    warp.types.infer_argument_types(args, template_types, arg_names)
-                    return overload.func(*args)
+                    warp.types.infer_argument_types(arguments, template_types, arg_names)
+                    return overload.func(*arguments)
                 except Exception:
                     continue
 
@@ -509,11 +510,10 @@ def call_builtin(func: Function, *params) -> Tuple[bool, Any]:
             if elem_count != arg_type._length_:
                 return (False, None)
 
-            # Retrieve the element type of the sequence while ensuring
-            # that it's homogeneous.
+            # Retrieve the element type of the sequence while ensuring that it's homogeneous.
             elem_type = type(arr[0])
-            for i in range(1, elem_count):
-                if type(arr[i]) is not elem_type:
+            for array_index in range(1, elem_count):
+                if type(arr[array_index]) is not elem_type:
                     raise ValueError("All array elements must share the same type.")
 
             expected_elem_type = arg_type._wp_scalar_type_
@@ -543,10 +543,10 @@ def call_builtin(func: Function, *params) -> Tuple[bool, Any]:
             c_param = arg_type()
             if warp.types.type_is_matrix(arg_type):
                 rows, cols = arg_type._shape_
-                for i in range(rows):
-                    idx_start = i * cols
+                for row_index in range(rows):
+                    idx_start = row_index * cols
                     idx_end = idx_start + cols
-                    c_param[i] = arr[idx_start:idx_end]
+                    c_param[row_index] = arr[idx_start:idx_end]
             else:
                 c_param[:] = arr
 
@@ -1239,16 +1239,16 @@ def add_builtin(
         typelists.append(l)
 
     for arg_types in itertools.product(*typelists):
-        arg_types = dict(zip(input_types.keys(), arg_types))
+        concrete_arg_types = dict(zip(input_types.keys(), arg_types))
 
         # Some of these argument lists won't work, eg if the function is mul(), we won't be
         # able to do a matrix vector multiplication for a mat22 and a vec3. The `constraint`
         # function determines which combinations are valid:
         if constraint:
-            if constraint(arg_types) is False:
+            if constraint(concrete_arg_types) is False:
                 continue
 
-        return_type = value_func(arg_types, None)
+        return_type = value_func(concrete_arg_types, None)
 
         # The return_type might just be vector_t(length=3,dtype=wp.float32), so we've got to match that
         # in the list of hard coded types so it knows it's returning one of them:
@@ -1266,7 +1266,7 @@ add_builtin(
         # finally we can generate a function call for these concrete types:
         add_builtin(
             key,
-            input_types=arg_types,
+            input_types=concrete_arg_types,
             value_type=return_type,
             value_func=value_func if return_type is Any else None,
             export_func=export_func,
@@ -2133,12 +2133,34 @@ class Module:
         # -----------------------------------------------------------
         # update cache
 
-        try:
-            # Copy process-specific build directory to a process-independent location
-            os.rename(build_dir, module_dir)
-        except (OSError, FileExistsError):
-            # another process likely updated the module dir first
-            pass
+        def safe_rename(src, dst, attempts=5, delay=0.1):
+            for i in range(attempts):
+                try:
+                    os.rename(src, dst)
+                    return
+                except FileExistsError:
+                    return
+                except OSError as e:
+                    if e.errno == errno.ENOTEMPTY:
+                        # if directory exists we assume another process
+                        # got there first, in which case we will copy
+                        # our output to the directory manually in second step
+                        return
+                    else:
+                        # otherwise assume directory creation failed e.g.: access denied
+                        # on Windows we see occasional failures to rename directories due to
+                        # some process holding a lock on a file to be moved to workaround
+                        # this we make multiple attempts to rename with some delay
+                        if i < attempts - 1:
+                            time.sleep(delay)
+                        else:
+                            print(
+                                f"Could not update Warp cache with module binaries, trying to rename {build_dir} to {module_dir}, error {e}"
+                            )
+                            raise e
+
+        # try to move process outputs to cache
+        safe_rename(build_dir, module_dir)
 
         if os.path.exists(module_dir):
             if not os.path.exists(binary_path):
@@ -4074,7 +4096,7 @@ def set_mempool_enabled(device: Devicelike, enable: bool) -> None:
     They should generally be enabled, but there is a rare caveat. Copying data between different GPUs
     may fail during graph capture if the memory was allocated using pooled allocators and memory pool
     access is not enabled between the two GPUs. This is an internal CUDA limitation that is not related
-    to Warp. The preferred solution is to enable memory pool access using `warp.set_mempool_access_enabled()`.
+    to Warp. The preferred solution is to enable memory pool access using :func:`set_mempool_access_enabled`.
     If peer access is not supported, then the default CUDA allocators must be used to pre-allocate the memory
     prior to graph capture.
     """
@@ -5272,6 +5294,8 @@ def launch(
             params_addr=kernel_params,
             bounds=bounds,
             device=device,
+            max_blocks=max_blocks,
+            block_dim=block_dim,
         )
         return launch
 
@@ -5355,7 +5379,7 @@ def launch_tiled(*args, **kwargs):
         kwargs["dim"] = dim + [kwargs["block_dim"]]
 
     # forward to original launch method
-    launch(*args, **kwargs)
+    return launch(*args, **kwargs)
 
 
 def synchronize():
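
Note: launch_tiled() now forwards the value returned by launch(), so a recorded launch can be captured and replayed; a hedged sketch, assuming record_cmd behaves as it does for wp.launch():

    import warp as wp

    @wp.kernel
    def compute():
        t = wp.tile_ones(dtype=float, m=16, n=16)
        s = wp.tile_sum(t)
        print(s)

    # the object returned by the recorded launch can now be captured
    cmd = wp.launch_tiled(compute, dim=[1], inputs=[], block_dim=64, record_cmd=True)
    cmd.launch()
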
@@ -100,7 +100,6 @@ class Example:
             tri_ka=1e4,
             tri_kd=1e-5,
             edge_ke=100,
-            color_particles=True,
         )
 
         usd_stage = Usd.Stage.Open(os.path.join(warp.examples.get_asset_directory(), "bunny.usd"))
@@ -122,6 +121,9 @@ class Example:
             kf=1.0e1,
         )
 
+        if self.integrator_type == IntegratorType.VBD:
+            builder.color()
+
         self.model = builder.finalize()
         self.model.ground = True
         self.model.soft_contact_ke = 1.0e4
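
Note: the cloth example now colors particles on the builder only when the VBD integrator is selected, replacing the removed color_particles flag. A condensed sketch of that ordering (builder setup is abbreviated):

    builder = wp.sim.ModelBuilder()
    # ... add cloth and collider geometry ...

    if integrator_type == IntegratorType.VBD:
        # VBD needs per-particle graph coloring, computed before finalize()
        builder.color()

    model = builder.finalize()
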
@@ -59,7 +59,6 @@ class Geometry:
     SideIndexArg: wp.codegen.Struct
     """Structure containing arguments to be passed to device functions for indexing sides"""
 
-    @staticmethod
     def cell_arg_value(self, device) -> "Geometry.CellArg":
         """Value of the arguments to be passed to cell-related device functions"""
         raise NotImplementedError
@@ -107,7 +106,6 @@ class Geometry:
         For elements with the same dimension as the embedding space, this will be zero."""
         raise NotImplementedError
 
-    @staticmethod
     def side_arg_value(self, device) -> "Geometry.SideArg":
         """Value of the arguments to be passed to side-related device functions"""
         raise NotImplementedError
warp/native/coloring.cpp CHANGED
@@ -590,7 +590,11 @@ extern "C"
         if (num_colors > 1) {
             std::vector<std::vector<int>> color_groups;
             convert_to_color_groups(num_colors, graph.node_colors, color_groups);
-            return balance_color_groups(target_max_min_ratio, graph, color_groups);
+
+            float max_min_ratio = balance_color_groups(target_max_min_ratio, graph, color_groups);
+            memcpy(node_colors.data, graph.node_colors.data(), num_nodes * sizeof(int));
+
+            return max_min_ratio;
         }
         else
         {