warp-lang 1.7.1__py3-none-manylinux_2_34_aarch64.whl → 1.7.2rc1__py3-none-manylinux_2_34_aarch64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of warp-lang might be problematic. Click here for more details.
- warp/bin/warp-clang.so +0 -0
- warp/bin/warp.so +0 -0
- warp/builtins.py +92 -56
- warp/codegen.py +31 -22
- warp/config.py +1 -1
- warp/context.py +106 -49
- warp/fem/cache.py +1 -1
- warp/jax_experimental/ffi.py +95 -66
- warp/native/builtin.h +91 -65
- warp/native/svd.h +59 -49
- warp/native/tile.h +46 -17
- warp/native/volume.cpp +2 -2
- warp/native/volume_builder.cu +33 -22
- warp/render/render_opengl.py +22 -17
- warp/render/render_usd.py +3 -3
- warp/sim/model.py +29 -21
- warp/sparse.py +1 -1
- warp/stubs.py +72 -24
- warp/tests/cuda/test_streams.py +1 -1
- warp/tests/sim/test_model.py +5 -3
- warp/tests/sim/test_sim_grad.py +1 -8
- warp/tests/test_array.py +8 -7
- warp/tests/test_atomic.py +181 -2
- warp/tests/test_builtins_resolution.py +38 -38
- warp/tests/test_fem.py +20 -6
- warp/tests/test_func.py +1 -1
- warp/tests/test_mat.py +46 -16
- warp/tests/test_struct.py +116 -0
- warp/tests/tile/test_tile.py +27 -0
- warp/tests/tile/test_tile_load.py +27 -0
- warp/types.py +42 -1
- {warp_lang-1.7.1.dist-info → warp_lang-1.7.2rc1.dist-info}/METADATA +26 -16
- {warp_lang-1.7.1.dist-info → warp_lang-1.7.2rc1.dist-info}/RECORD +36 -36
- {warp_lang-1.7.1.dist-info → warp_lang-1.7.2rc1.dist-info}/WHEEL +1 -1
- {warp_lang-1.7.1.dist-info → warp_lang-1.7.2rc1.dist-info}/licenses/LICENSE.md +0 -0
- {warp_lang-1.7.1.dist-info → warp_lang-1.7.2rc1.dist-info}/top_level.txt +0 -0
warp/bin/warp-clang.so
CHANGED
|
Binary file
|
warp/bin/warp.so
CHANGED
|
Binary file
|
warp/builtins.py
CHANGED
|
@@ -2264,7 +2264,7 @@ def tile_atomic_add_value_func(arg_types, arg_values):
|
|
|
2264
2264
|
f"tile_atomic_add() 'a' and 't' arguments must have the same dtype, got {arg_types['a'].dtype} and {arg_types['t'].dtype}"
|
|
2265
2265
|
)
|
|
2266
2266
|
|
|
2267
|
-
return Tile(dtype=arg_types["t"].dtype, shape=arg_types["t"].shape)
|
|
2267
|
+
return Tile(dtype=arg_types["t"].dtype, shape=arg_types["t"].shape, storage=arg_types["t"].storage)
|
|
2268
2268
|
|
|
2269
2269
|
|
|
2270
2270
|
def tile_atomic_add_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
@@ -2423,7 +2423,6 @@ add_builtin(
|
|
|
2423
2423
|
group="Tile Primitives",
|
|
2424
2424
|
export=False,
|
|
2425
2425
|
hidden=True,
|
|
2426
|
-
missing_grad=True,
|
|
2427
2426
|
)
|
|
2428
2427
|
|
|
2429
2428
|
add_builtin(
|
|
@@ -2433,7 +2432,6 @@ add_builtin(
|
|
|
2433
2432
|
group="Tile Primitives",
|
|
2434
2433
|
export=False,
|
|
2435
2434
|
hidden=True,
|
|
2436
|
-
missing_grad=True,
|
|
2437
2435
|
)
|
|
2438
2436
|
|
|
2439
2437
|
add_builtin(
|
|
@@ -2443,7 +2441,6 @@ add_builtin(
|
|
|
2443
2441
|
group="Tile Primitives",
|
|
2444
2442
|
export=False,
|
|
2445
2443
|
hidden=True,
|
|
2446
|
-
missing_grad=True,
|
|
2447
2444
|
)
|
|
2448
2445
|
|
|
2449
2446
|
add_builtin(
|
|
@@ -2453,7 +2450,6 @@ add_builtin(
|
|
|
2453
2450
|
group="Tile Primitives",
|
|
2454
2451
|
export=False,
|
|
2455
2452
|
hidden=True,
|
|
2456
|
-
missing_grad=True,
|
|
2457
2453
|
)
|
|
2458
2454
|
|
|
2459
2455
|
|
|
@@ -4896,46 +4892,78 @@ add_builtin(
|
|
|
4896
4892
|
)
|
|
4897
4893
|
|
|
4898
4894
|
|
|
4895
|
+
SUPPORTED_ATOMIC_TYPES = (
|
|
4896
|
+
warp.int32,
|
|
4897
|
+
warp.int64,
|
|
4898
|
+
warp.uint32,
|
|
4899
|
+
warp.uint64,
|
|
4900
|
+
warp.float32,
|
|
4901
|
+
warp.float64,
|
|
4902
|
+
)
|
|
4903
|
+
|
|
4904
|
+
|
|
4899
4905
|
def atomic_op_constraint(arg_types: Mapping[str, Any]):
|
|
4900
4906
|
idx_types = tuple(arg_types[x] for x in "ijkl" if arg_types.get(x, None) is not None)
|
|
4901
4907
|
return all(types_equal(idx_types[0], t) for t in idx_types[1:]) and arg_types["arr"].ndim == len(idx_types)
|
|
4902
4908
|
|
|
4903
4909
|
|
|
4904
|
-
def
|
|
4905
|
-
|
|
4906
|
-
|
|
4910
|
+
def create_atomic_op_value_func(op: str):
|
|
4911
|
+
def fn(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
|
|
4912
|
+
if arg_types is None:
|
|
4913
|
+
return Any
|
|
4907
4914
|
|
|
4908
|
-
|
|
4909
|
-
|
|
4910
|
-
|
|
4915
|
+
arr_type = arg_types["arr"]
|
|
4916
|
+
value_type = arg_types["value"]
|
|
4917
|
+
idx_types = tuple(arg_types[x] for x in "ijkl" if arg_types.get(x, None) is not None)
|
|
4911
4918
|
|
|
4912
|
-
|
|
4913
|
-
|
|
4919
|
+
if not is_array(arr_type):
|
|
4920
|
+
raise RuntimeError(f"atomic_{op}() first argument must be an array")
|
|
4914
4921
|
|
|
4915
|
-
|
|
4922
|
+
idx_count = len(idx_types)
|
|
4916
4923
|
|
|
4917
|
-
|
|
4918
|
-
|
|
4919
|
-
|
|
4920
|
-
|
|
4924
|
+
if idx_count < arr_type.ndim:
|
|
4925
|
+
raise RuntimeError(
|
|
4926
|
+
f"Num indices < num dimensions for atomic_{op}(), this is a codegen error, should have generated a view instead"
|
|
4927
|
+
)
|
|
4921
4928
|
|
|
4922
|
-
|
|
4923
|
-
|
|
4924
|
-
|
|
4925
|
-
|
|
4929
|
+
if idx_count > arr_type.ndim:
|
|
4930
|
+
raise RuntimeError(
|
|
4931
|
+
f"Num indices > num dimensions for atomic_{op}(), received {idx_count}, but array only has {arr_type.ndim}"
|
|
4932
|
+
)
|
|
4926
4933
|
|
|
4927
|
-
|
|
4928
|
-
|
|
4929
|
-
|
|
4930
|
-
|
|
4934
|
+
# check index types
|
|
4935
|
+
for t in idx_types:
|
|
4936
|
+
if not type_is_int(t):
|
|
4937
|
+
raise RuntimeError(
|
|
4938
|
+
f"atomic_{op}() index arguments must be of integer type, got index of type {type_repr(t)}"
|
|
4939
|
+
)
|
|
4931
4940
|
|
|
4932
|
-
|
|
4933
|
-
|
|
4934
|
-
|
|
4935
|
-
|
|
4936
|
-
|
|
4941
|
+
# check value type
|
|
4942
|
+
if not types_equal(arr_type.dtype, value_type):
|
|
4943
|
+
raise RuntimeError(
|
|
4944
|
+
f"atomic_{op}() value argument type ({type_repr(value_type)}) must be of the same type as the array ({type_repr(arr_type.dtype)})"
|
|
4945
|
+
)
|
|
4946
|
+
|
|
4947
|
+
scalar_type = getattr(arr_type.dtype, "_wp_scalar_type_", arr_type.dtype)
|
|
4948
|
+
if op in ("add", "sub"):
|
|
4949
|
+
supported_atomic_types = (*SUPPORTED_ATOMIC_TYPES, warp.float16)
|
|
4950
|
+
if not any(types_equal(scalar_type, x, match_generic=True) for x in supported_atomic_types):
|
|
4951
|
+
raise RuntimeError(
|
|
4952
|
+
f"atomic_{op}() operations only work on arrays with [u]int32, [u]int64, float16, float32, or float64 "
|
|
4953
|
+
f"as the underlying scalar types, but got {type_repr(arr_type.dtype)} (with scalar type {type_repr(scalar_type)})"
|
|
4954
|
+
)
|
|
4955
|
+
elif op in ("min", "max"):
|
|
4956
|
+
if not any(types_equal(scalar_type, x, match_generic=True) for x in SUPPORTED_ATOMIC_TYPES):
|
|
4957
|
+
raise RuntimeError(
|
|
4958
|
+
f"atomic_{op}() operations only work on arrays with [u]int32, [u]int64, float32, or float64 "
|
|
4959
|
+
f"as the underlying scalar types, but got {type_repr(arr_type.dtype)} (with scalar type {type_repr(scalar_type)})"
|
|
4960
|
+
)
|
|
4961
|
+
else:
|
|
4962
|
+
raise NotImplementedError
|
|
4937
4963
|
|
|
4938
|
-
|
|
4964
|
+
return arr_type.dtype
|
|
4965
|
+
|
|
4966
|
+
return fn
|
|
4939
4967
|
|
|
4940
4968
|
|
|
4941
4969
|
def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
|
|
@@ -4960,9 +4988,10 @@ for array_type in array_types:
|
|
|
4960
4988
|
hidden=hidden,
|
|
4961
4989
|
input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
|
|
4962
4990
|
constraint=atomic_op_constraint,
|
|
4963
|
-
value_func=
|
|
4991
|
+
value_func=create_atomic_op_value_func("add"),
|
|
4964
4992
|
dispatch_func=atomic_op_dispatch_func,
|
|
4965
|
-
doc="Atomically
|
|
4993
|
+
doc="""Atomically adds ``value`` onto ``arr[i]`` and returns the original value of ``arr[i]``.
|
|
4994
|
+
This function is automatically invoked when using the syntax ``arr[i] += value``.""",
|
|
4966
4995
|
group="Utility",
|
|
4967
4996
|
skip_replay=True,
|
|
4968
4997
|
)
|
|
@@ -4971,9 +5000,10 @@ for array_type in array_types:
|
|
|
4971
5000
|
hidden=hidden,
|
|
4972
5001
|
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
|
|
4973
5002
|
constraint=atomic_op_constraint,
|
|
4974
|
-
value_func=
|
|
5003
|
+
value_func=create_atomic_op_value_func("add"),
|
|
4975
5004
|
dispatch_func=atomic_op_dispatch_func,
|
|
4976
|
-
doc="Atomically
|
|
5005
|
+
doc="""Atomically adds ``value`` onto ``arr[i,j]`` and returns the original value of ``arr[i,j]``.
|
|
5006
|
+
This function is automatically invoked when using the syntax ``arr[i,j] += value``.""",
|
|
4977
5007
|
group="Utility",
|
|
4978
5008
|
skip_replay=True,
|
|
4979
5009
|
)
|
|
@@ -4982,9 +5012,10 @@ for array_type in array_types:
|
|
|
4982
5012
|
hidden=hidden,
|
|
4983
5013
|
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
|
|
4984
5014
|
constraint=atomic_op_constraint,
|
|
4985
|
-
value_func=
|
|
5015
|
+
value_func=create_atomic_op_value_func("add"),
|
|
4986
5016
|
dispatch_func=atomic_op_dispatch_func,
|
|
4987
|
-
doc="Atomically
|
|
5017
|
+
doc="""Atomically adds ``value`` onto ``arr[i,j,k]`` and returns the original value of ``arr[i,j,k]``.
|
|
5018
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k] += value``.""",
|
|
4988
5019
|
group="Utility",
|
|
4989
5020
|
skip_replay=True,
|
|
4990
5021
|
)
|
|
@@ -4993,9 +5024,10 @@ for array_type in array_types:
|
|
|
4993
5024
|
hidden=hidden,
|
|
4994
5025
|
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
|
|
4995
5026
|
constraint=atomic_op_constraint,
|
|
4996
|
-
value_func=
|
|
5027
|
+
value_func=create_atomic_op_value_func("add"),
|
|
4997
5028
|
dispatch_func=atomic_op_dispatch_func,
|
|
4998
|
-
doc="Atomically
|
|
5029
|
+
doc="""Atomically adds ``value`` onto ``arr[i,j,k,l]`` and returns the original value of ``arr[i,j,k,l]``.
|
|
5030
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k,l] += value``.""",
|
|
4999
5031
|
group="Utility",
|
|
5000
5032
|
skip_replay=True,
|
|
5001
5033
|
)
|
|
@@ -5005,9 +5037,10 @@ for array_type in array_types:
|
|
|
5005
5037
|
hidden=hidden,
|
|
5006
5038
|
input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
|
|
5007
5039
|
constraint=atomic_op_constraint,
|
|
5008
|
-
value_func=
|
|
5040
|
+
value_func=create_atomic_op_value_func("sub"),
|
|
5009
5041
|
dispatch_func=atomic_op_dispatch_func,
|
|
5010
|
-
doc="Atomically
|
|
5042
|
+
doc="""Atomically subtracts ``value`` onto ``arr[i]`` and returns the original value of ``arr[i]``.
|
|
5043
|
+
This function is automatically invoked when using the syntax ``arr[i] -= value``.""",
|
|
5011
5044
|
group="Utility",
|
|
5012
5045
|
skip_replay=True,
|
|
5013
5046
|
)
|
|
@@ -5016,9 +5049,10 @@ for array_type in array_types:
|
|
|
5016
5049
|
hidden=hidden,
|
|
5017
5050
|
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
|
|
5018
5051
|
constraint=atomic_op_constraint,
|
|
5019
|
-
value_func=
|
|
5052
|
+
value_func=create_atomic_op_value_func("sub"),
|
|
5020
5053
|
dispatch_func=atomic_op_dispatch_func,
|
|
5021
|
-
doc="Atomically
|
|
5054
|
+
doc="""Atomically subtracts ``value`` onto ``arr[i,j]`` and returns the original value of ``arr[i,j]``.
|
|
5055
|
+
This function is automatically invoked when using the syntax ``arr[i,j] -= value``.""",
|
|
5022
5056
|
group="Utility",
|
|
5023
5057
|
skip_replay=True,
|
|
5024
5058
|
)
|
|
@@ -5027,9 +5061,10 @@ for array_type in array_types:
|
|
|
5027
5061
|
hidden=hidden,
|
|
5028
5062
|
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
|
|
5029
5063
|
constraint=atomic_op_constraint,
|
|
5030
|
-
value_func=
|
|
5064
|
+
value_func=create_atomic_op_value_func("sub"),
|
|
5031
5065
|
dispatch_func=atomic_op_dispatch_func,
|
|
5032
|
-
doc="Atomically
|
|
5066
|
+
doc="""Atomically subtracts ``value`` onto ``arr[i,j,k]`` and returns the original value of ``arr[i,j,k]``.
|
|
5067
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k] -= value``.""",
|
|
5033
5068
|
group="Utility",
|
|
5034
5069
|
skip_replay=True,
|
|
5035
5070
|
)
|
|
@@ -5038,9 +5073,10 @@ for array_type in array_types:
|
|
|
5038
5073
|
hidden=hidden,
|
|
5039
5074
|
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
|
|
5040
5075
|
constraint=atomic_op_constraint,
|
|
5041
|
-
value_func=
|
|
5076
|
+
value_func=create_atomic_op_value_func("sub"),
|
|
5042
5077
|
dispatch_func=atomic_op_dispatch_func,
|
|
5043
|
-
doc="Atomically
|
|
5078
|
+
doc="""Atomically subtracts ``value`` onto ``arr[i,j,k,l]`` and returns the original value of ``arr[i,j,k,l]``.
|
|
5079
|
+
This function is automatically invoked when using the syntax ``arr[i,j,k,l] -= value``.""",
|
|
5044
5080
|
group="Utility",
|
|
5045
5081
|
skip_replay=True,
|
|
5046
5082
|
)
|
|
@@ -5050,7 +5086,7 @@ for array_type in array_types:
|
|
|
5050
5086
|
hidden=hidden,
|
|
5051
5087
|
input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
|
|
5052
5088
|
constraint=atomic_op_constraint,
|
|
5053
|
-
value_func=
|
|
5089
|
+
value_func=create_atomic_op_value_func("min"),
|
|
5054
5090
|
dispatch_func=atomic_op_dispatch_func,
|
|
5055
5091
|
doc="""Compute the minimum of ``value`` and ``arr[i]``, atomically update the array, and return the old value.
|
|
5056
5092
|
|
|
@@ -5063,7 +5099,7 @@ for array_type in array_types:
|
|
|
5063
5099
|
hidden=hidden,
|
|
5064
5100
|
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
|
|
5065
5101
|
constraint=atomic_op_constraint,
|
|
5066
|
-
value_func=
|
|
5102
|
+
value_func=create_atomic_op_value_func("min"),
|
|
5067
5103
|
dispatch_func=atomic_op_dispatch_func,
|
|
5068
5104
|
doc="""Compute the minimum of ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
|
|
5069
5105
|
|
|
@@ -5076,7 +5112,7 @@ for array_type in array_types:
|
|
|
5076
5112
|
hidden=hidden,
|
|
5077
5113
|
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
|
|
5078
5114
|
constraint=atomic_op_constraint,
|
|
5079
|
-
value_func=
|
|
5115
|
+
value_func=create_atomic_op_value_func("min"),
|
|
5080
5116
|
dispatch_func=atomic_op_dispatch_func,
|
|
5081
5117
|
doc="""Compute the minimum of ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
|
|
5082
5118
|
|
|
@@ -5089,7 +5125,7 @@ for array_type in array_types:
|
|
|
5089
5125
|
hidden=hidden,
|
|
5090
5126
|
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
|
|
5091
5127
|
constraint=atomic_op_constraint,
|
|
5092
|
-
value_func=
|
|
5128
|
+
value_func=create_atomic_op_value_func("min"),
|
|
5093
5129
|
dispatch_func=atomic_op_dispatch_func,
|
|
5094
5130
|
doc="""Compute the minimum of ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
|
|
5095
5131
|
|
|
@@ -5103,7 +5139,7 @@ for array_type in array_types:
|
|
|
5103
5139
|
hidden=hidden,
|
|
5104
5140
|
input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
|
|
5105
5141
|
constraint=atomic_op_constraint,
|
|
5106
|
-
value_func=
|
|
5142
|
+
value_func=create_atomic_op_value_func("max"),
|
|
5107
5143
|
dispatch_func=atomic_op_dispatch_func,
|
|
5108
5144
|
doc="""Compute the maximum of ``value`` and ``arr[i]``, atomically update the array, and return the old value.
|
|
5109
5145
|
|
|
@@ -5116,7 +5152,7 @@ for array_type in array_types:
|
|
|
5116
5152
|
hidden=hidden,
|
|
5117
5153
|
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
|
|
5118
5154
|
constraint=atomic_op_constraint,
|
|
5119
|
-
value_func=
|
|
5155
|
+
value_func=create_atomic_op_value_func("max"),
|
|
5120
5156
|
dispatch_func=atomic_op_dispatch_func,
|
|
5121
5157
|
doc="""Compute the maximum of ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
|
|
5122
5158
|
|
|
@@ -5129,7 +5165,7 @@ for array_type in array_types:
|
|
|
5129
5165
|
hidden=hidden,
|
|
5130
5166
|
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
|
|
5131
5167
|
constraint=atomic_op_constraint,
|
|
5132
|
-
value_func=
|
|
5168
|
+
value_func=create_atomic_op_value_func("max"),
|
|
5133
5169
|
dispatch_func=atomic_op_dispatch_func,
|
|
5134
5170
|
doc="""Compute the maximum of ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
|
|
5135
5171
|
|
|
@@ -5142,7 +5178,7 @@ for array_type in array_types:
|
|
|
5142
5178
|
hidden=hidden,
|
|
5143
5179
|
input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
|
|
5144
5180
|
constraint=atomic_op_constraint,
|
|
5145
|
-
value_func=
|
|
5181
|
+
value_func=create_atomic_op_value_func("max"),
|
|
5146
5182
|
dispatch_func=atomic_op_dispatch_func,
|
|
5147
5183
|
doc="""Compute the maximum of ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
|
|
5148
5184
|
|
warp/codegen.py
CHANGED
|
@@ -240,7 +240,7 @@ class StructInstance:
|
|
|
240
240
|
# create Python attributes for each of the struct's variables
|
|
241
241
|
for field, var in cls.vars.items():
|
|
242
242
|
if isinstance(var.type, warp.codegen.Struct):
|
|
243
|
-
self.__dict__[field] =
|
|
243
|
+
self.__dict__[field] = var.type.instance_type(ctype=getattr(self._ctype, field))
|
|
244
244
|
elif isinstance(var.type, warp.types.array):
|
|
245
245
|
self.__dict__[field] = None
|
|
246
246
|
else:
|
|
@@ -288,6 +288,11 @@ class StructInstance:
|
|
|
288
288
|
)
|
|
289
289
|
setattr(self._ctype, name, value.__ctype__())
|
|
290
290
|
|
|
291
|
+
# workaround to prevent gradient buffers being garbage collected
|
|
292
|
+
# since users can do struct.array.requires_grad = False the gradient array
|
|
293
|
+
# would be collected while the struct ctype still holds a reference to it
|
|
294
|
+
super().__setattr__("_" + name + "_grad", value.grad)
|
|
295
|
+
|
|
291
296
|
elif isinstance(var.type, Struct):
|
|
292
297
|
# assign structs by-value, otherwise we would have problematic cases transferring ownership
|
|
293
298
|
# of the underlying ctypes data between shared Python struct instances
|
|
@@ -413,11 +418,14 @@ class StructInstance:
|
|
|
413
418
|
class Struct:
|
|
414
419
|
hash: bytes
|
|
415
420
|
|
|
416
|
-
def __init__(self,
|
|
421
|
+
def __init__(self, key: str, cls: type, module: warp.context.Module):
|
|
422
|
+
self.key = key
|
|
417
423
|
self.cls = cls
|
|
418
424
|
self.module = module
|
|
419
|
-
self.
|
|
420
|
-
|
|
425
|
+
self.vars: dict[str, Var] = {}
|
|
426
|
+
|
|
427
|
+
if isinstance(self.cls, Sequence):
|
|
428
|
+
raise RuntimeError("Warp structs must be defined as base classes")
|
|
421
429
|
|
|
422
430
|
annotations = get_annotations(self.cls)
|
|
423
431
|
for label, type in annotations.items():
|
|
@@ -489,34 +497,35 @@ class Struct:
|
|
|
489
497
|
|
|
490
498
|
self.default_constructor.add_overload(self.value_constructor)
|
|
491
499
|
|
|
492
|
-
if module:
|
|
500
|
+
if isinstance(module, warp.context.Module):
|
|
493
501
|
module.register_struct(self)
|
|
494
502
|
|
|
495
|
-
|
|
496
|
-
|
|
497
|
-
|
|
498
|
-
s uses self.cls as template.
|
|
499
|
-
To enable autocomplete on s, we inherit from self.cls.
|
|
500
|
-
For example,
|
|
501
|
-
|
|
502
|
-
@wp.struct
|
|
503
|
-
class A:
|
|
504
|
-
# annotations
|
|
505
|
-
...
|
|
503
|
+
# Define class for instances of this struct
|
|
504
|
+
# To enable autocomplete on s, we inherit from self.cls.
|
|
505
|
+
# For example,
|
|
506
506
|
|
|
507
|
-
|
|
508
|
-
|
|
509
|
-
#
|
|
507
|
+
# @wp.struct
|
|
508
|
+
# class A:
|
|
509
|
+
# # annotations
|
|
510
|
+
# ...
|
|
510
511
|
|
|
512
|
+
# The type annotations are inherited in A(), allowing autocomplete in kernels
|
|
511
513
|
class NewStructInstance(self.cls, StructInstance):
|
|
512
|
-
def __init__(inst):
|
|
513
|
-
StructInstance.__init__(inst, self,
|
|
514
|
+
def __init__(inst, ctype=None):
|
|
515
|
+
StructInstance.__init__(inst, self, ctype)
|
|
514
516
|
|
|
515
517
|
# make sure warp.types.get_type_code works with this StructInstance
|
|
516
518
|
NewStructInstance.cls = self.cls
|
|
517
519
|
NewStructInstance.native_name = self.native_name
|
|
518
520
|
|
|
519
|
-
|
|
521
|
+
self.instance_type = NewStructInstance
|
|
522
|
+
|
|
523
|
+
def __call__(self):
|
|
524
|
+
"""
|
|
525
|
+
This function returns s = StructInstance(self)
|
|
526
|
+
s uses self.cls as template.
|
|
527
|
+
"""
|
|
528
|
+
return self.instance_type()
|
|
520
529
|
|
|
521
530
|
def initializer(self):
|
|
522
531
|
return self.default_constructor
|
warp/config.py
CHANGED
warp/context.py
CHANGED
|
@@ -457,6 +457,24 @@ class Function:
|
|
|
457
457
|
return f"<Function {self.key}({inputs_str})>"
|
|
458
458
|
|
|
459
459
|
|
|
460
|
+
def get_builtin_type(return_type: type) -> type:
|
|
461
|
+
# The return_type might just be vector_t(length=3,dtype=wp.float32), so we've got to match that
|
|
462
|
+
# in the list of hard coded types so it knows it's returning one of them:
|
|
463
|
+
if hasattr(return_type, "_wp_generic_type_hint_"):
|
|
464
|
+
return_type_match = tuple(
|
|
465
|
+
x
|
|
466
|
+
for x in generic_vtypes
|
|
467
|
+
if x._wp_generic_type_hint_ == return_type._wp_generic_type_hint_
|
|
468
|
+
and x._wp_type_params_ == return_type._wp_type_params_
|
|
469
|
+
)
|
|
470
|
+
if not return_type_match:
|
|
471
|
+
raise RuntimeError("No match")
|
|
472
|
+
|
|
473
|
+
return return_type_match[0]
|
|
474
|
+
|
|
475
|
+
return return_type
|
|
476
|
+
|
|
477
|
+
|
|
460
478
|
def call_builtin(func: Function, *params: Any) -> Tuple[bool, Any]:
|
|
461
479
|
uses_non_warp_array_type = False
|
|
462
480
|
|
|
@@ -1074,7 +1092,7 @@ def kernel(
|
|
|
1074
1092
|
# decorator to register struct, @struct
|
|
1075
1093
|
def struct(c: type):
|
|
1076
1094
|
m = get_module(c.__module__)
|
|
1077
|
-
s = warp.codegen.Struct(
|
|
1095
|
+
s = warp.codegen.Struct(key=warp.codegen.make_full_qualified_name(c), cls=c, module=m)
|
|
1078
1096
|
s = functools.update_wrapper(s, c)
|
|
1079
1097
|
return s
|
|
1080
1098
|
|
|
@@ -1445,6 +1463,24 @@ def register_api_function(
|
|
|
1445
1463
|
"""
|
|
1446
1464
|
function.group = group
|
|
1447
1465
|
function.hidden = hidden
|
|
1466
|
+
|
|
1467
|
+
# Update the docstring to mark these functions as being available from kernels and Python's runtime.
|
|
1468
|
+
assert function.__doc__.startswith("\n")
|
|
1469
|
+
leading_space_count = sum(1 for _ in itertools.takewhile(str.isspace, function.__doc__[1:]))
|
|
1470
|
+
assert leading_space_count % 4 == 0
|
|
1471
|
+
indent_level = leading_space_count // 4
|
|
1472
|
+
indent = " "
|
|
1473
|
+
function.__doc__ = (
|
|
1474
|
+
f"\n"
|
|
1475
|
+
f"{indent * indent_level}.. hlist::\n"
|
|
1476
|
+
f"{indent * (indent_level + 1)}:columns: 8\n"
|
|
1477
|
+
f"\n"
|
|
1478
|
+
f"{indent * (indent_level + 1)}* Kernel\n"
|
|
1479
|
+
f"{indent * (indent_level + 1)}* Python\n"
|
|
1480
|
+
f"{indent * (indent_level + 1)}* Differentiable\n"
|
|
1481
|
+
f"{function.__doc__}"
|
|
1482
|
+
)
|
|
1483
|
+
|
|
1448
1484
|
builtin_functions[function.key] = function
|
|
1449
1485
|
|
|
1450
1486
|
|
|
@@ -6510,12 +6546,46 @@ def type_str(t):
|
|
|
6510
6546
|
return t.__name__
|
|
6511
6547
|
|
|
6512
6548
|
|
|
6513
|
-
def
|
|
6549
|
+
def ctype_ret_str(t):
|
|
6550
|
+
return get_builtin_type(t).__name__
|
|
6551
|
+
|
|
6552
|
+
|
|
6553
|
+
def resolve_exported_function_sig(f):
|
|
6554
|
+
if not f.export or f.generic:
|
|
6555
|
+
return None
|
|
6556
|
+
|
|
6557
|
+
# only export simple types that don't use arrays or templated types
|
|
6558
|
+
if not f.is_simple():
|
|
6559
|
+
return None
|
|
6560
|
+
|
|
6561
|
+
# Runtime arguments that are to be passed to the function, not its template signature.
|
|
6562
|
+
if f.export_func is not None:
|
|
6563
|
+
func_args = f.export_func(f.input_types)
|
|
6564
|
+
else:
|
|
6565
|
+
func_args = f.input_types
|
|
6566
|
+
|
|
6567
|
+
# todo: construct a default value for each of the functions args
|
|
6568
|
+
# so we can generate the return type for overloaded functions
|
|
6569
|
+
return_type = f.value_func(func_args, None)
|
|
6570
|
+
|
|
6571
|
+
try:
|
|
6572
|
+
return_type_str = ctype_ret_str(return_type)
|
|
6573
|
+
except Exception:
|
|
6574
|
+
return None
|
|
6575
|
+
|
|
6576
|
+
if return_type_str.startswith("Tuple"):
|
|
6577
|
+
return None
|
|
6578
|
+
|
|
6579
|
+
return (func_args, return_type)
|
|
6580
|
+
|
|
6581
|
+
|
|
6582
|
+
def print_function(f, file, is_exported, noentry=False): # pragma: no cover
|
|
6514
6583
|
"""Writes a function definition to a file for use in reST documentation
|
|
6515
6584
|
|
|
6516
6585
|
Args:
|
|
6517
6586
|
f: The function being written
|
|
6518
6587
|
file: The file object for output
|
|
6588
|
+
is_exported: Whether the function is available in Python's runtime
|
|
6519
6589
|
noentry: If True, then the :noindex: and :nocontentsentry: directive
|
|
6520
6590
|
options will be added
|
|
6521
6591
|
|
|
@@ -6543,11 +6613,21 @@ def print_function(f, file, noentry=False): # pragma: no cover
|
|
|
6543
6613
|
print(" :nocontentsentry:", file=file)
|
|
6544
6614
|
print("", file=file)
|
|
6545
6615
|
|
|
6616
|
+
print(" .. hlist::", file=file)
|
|
6617
|
+
print(" :columns: 8", file=file)
|
|
6618
|
+
print("", file=file)
|
|
6619
|
+
print(" * Kernel", file=file)
|
|
6620
|
+
|
|
6621
|
+
if is_exported:
|
|
6622
|
+
print(" * Python", file=file)
|
|
6623
|
+
|
|
6624
|
+
if not f.missing_grad:
|
|
6625
|
+
print(" * Differentiable", file=file)
|
|
6626
|
+
|
|
6627
|
+
print("", file=file)
|
|
6628
|
+
|
|
6546
6629
|
if f.doc != "":
|
|
6547
|
-
|
|
6548
|
-
print(f" {f.doc}", file=file)
|
|
6549
|
-
else:
|
|
6550
|
-
print(f" {f.doc} [1]_", file=file)
|
|
6630
|
+
print(f" {f.doc}", file=file)
|
|
6551
6631
|
print("", file=file)
|
|
6552
6632
|
|
|
6553
6633
|
print(file=file)
|
|
@@ -6563,8 +6643,10 @@ def export_functions_rst(file): # pragma: no cover
|
|
|
6563
6643
|
".. functions:\n"
|
|
6564
6644
|
".. currentmodule:: warp\n"
|
|
6565
6645
|
"\n"
|
|
6566
|
-
"
|
|
6567
|
-
"
|
|
6646
|
+
"Built-Ins Reference\n"
|
|
6647
|
+
"===================\n"
|
|
6648
|
+
"This section lists the Warp types and functions available to use from Warp kernels and optionally also from the Warp Python runtime API.\n"
|
|
6649
|
+
"For a listing of the API that is exclusively intended to be used at the *Python Scope* and run inside the CPython interpreter, see the :doc:`runtime` section.\n"
|
|
6568
6650
|
)
|
|
6569
6651
|
|
|
6570
6652
|
print(header, file=file)
|
|
@@ -6609,9 +6691,12 @@ def export_functions_rst(file): # pragma: no cover
|
|
|
6609
6691
|
if hasattr(f, "overloads"):
|
|
6610
6692
|
# append all overloads to the group
|
|
6611
6693
|
for o in f.overloads:
|
|
6612
|
-
|
|
6694
|
+
sig = resolve_exported_function_sig(f)
|
|
6695
|
+
is_exported = sig is not None
|
|
6696
|
+
groups[f.group].append((o, is_exported))
|
|
6613
6697
|
else:
|
|
6614
|
-
|
|
6698
|
+
is_exported = False
|
|
6699
|
+
groups[f.group].append((f, is_exported))
|
|
6615
6700
|
|
|
6616
6701
|
# Keep track of what function and query types have been written
|
|
6617
6702
|
written_functions = set()
|
|
@@ -6630,7 +6715,7 @@ def export_functions_rst(file): # pragma: no cover
|
|
|
6630
6715
|
print(k, file=file)
|
|
6631
6716
|
print("---------------", file=file)
|
|
6632
6717
|
|
|
6633
|
-
for f in g:
|
|
6718
|
+
for f, is_exported in g:
|
|
6634
6719
|
if f.func:
|
|
6635
6720
|
# f is a Warp function written in Python, we can use autofunction
|
|
6636
6721
|
print(f".. autofunction:: {f.func.__module__}.{f.key}", file=file)
|
|
@@ -6643,15 +6728,11 @@ def export_functions_rst(file): # pragma: no cover
|
|
|
6643
6728
|
|
|
6644
6729
|
if f.key in written_functions:
|
|
6645
6730
|
# Add :noindex: + :nocontentsentry: since Sphinx gets confused
|
|
6646
|
-
print_function(f, file
|
|
6731
|
+
print_function(f, file, is_exported, noentry=True)
|
|
6647
6732
|
else:
|
|
6648
|
-
if print_function(f, file
|
|
6733
|
+
if print_function(f, file, is_exported):
|
|
6649
6734
|
written_functions.add(f.key)
|
|
6650
6735
|
|
|
6651
|
-
# footnotes
|
|
6652
|
-
print(".. rubric:: Footnotes", file=file)
|
|
6653
|
-
print(".. [1] Function gradients have not been implemented for backpropagation.", file=file)
|
|
6654
|
-
|
|
6655
6736
|
|
|
6656
6737
|
def export_stubs(file): # pragma: no cover
|
|
6657
6738
|
"""Generates stub file for auto-complete of builtin functions"""
|
|
@@ -6751,14 +6832,6 @@ def export_builtins(file: io.TextIOBase): # pragma: no cover
|
|
|
6751
6832
|
else:
|
|
6752
6833
|
return t.__name__
|
|
6753
6834
|
|
|
6754
|
-
def ctype_ret_str(t):
|
|
6755
|
-
if isinstance(t, int):
|
|
6756
|
-
return "int"
|
|
6757
|
-
elif isinstance(t, float):
|
|
6758
|
-
return "float"
|
|
6759
|
-
else:
|
|
6760
|
-
return t.__name__
|
|
6761
|
-
|
|
6762
6835
|
file.write("namespace wp {\n\n")
|
|
6763
6836
|
file.write('extern "C" {\n\n')
|
|
6764
6837
|
|
|
@@ -6766,40 +6839,24 @@ def export_builtins(file: io.TextIOBase): # pragma: no cover
|
|
|
6766
6839
|
if not hasattr(g, "overloads"):
|
|
6767
6840
|
continue
|
|
6768
6841
|
for f in g.overloads:
|
|
6769
|
-
|
|
6770
|
-
|
|
6771
|
-
|
|
6772
|
-
# only export simple types that don't use arrays
|
|
6773
|
-
# or templated types
|
|
6774
|
-
if not f.is_simple():
|
|
6775
|
-
continue
|
|
6776
|
-
|
|
6777
|
-
try:
|
|
6778
|
-
# todo: construct a default value for each of the functions args
|
|
6779
|
-
# so we can generate the return type for overloaded functions
|
|
6780
|
-
return_type = ctype_ret_str(f.value_func(None, None))
|
|
6781
|
-
except Exception:
|
|
6782
|
-
continue
|
|
6783
|
-
|
|
6784
|
-
if return_type.startswith("Tuple"):
|
|
6842
|
+
sig = resolve_exported_function_sig(f)
|
|
6843
|
+
if sig is None:
|
|
6785
6844
|
continue
|
|
6786
6845
|
|
|
6787
|
-
|
|
6788
|
-
if f.export_func is not None:
|
|
6789
|
-
func_args = f.export_func(f.input_types)
|
|
6790
|
-
else:
|
|
6791
|
-
func_args = f.input_types
|
|
6846
|
+
func_args, return_type = sig
|
|
6792
6847
|
|
|
6793
6848
|
args = ", ".join(f"{ctype_arg_str(v)} {k}" for k, v in func_args.items())
|
|
6794
6849
|
params = ", ".join(func_args.keys())
|
|
6795
6850
|
|
|
6851
|
+
return_str = ctype_ret_str(return_type)
|
|
6852
|
+
|
|
6796
6853
|
if args == "":
|
|
6797
|
-
file.write(f"WP_API void {f.mangled_name}({
|
|
6798
|
-
elif return_type
|
|
6854
|
+
file.write(f"WP_API void {f.mangled_name}({return_str}* ret) {{ *ret = wp::{f.key}({params}); }}\n")
|
|
6855
|
+
elif return_type is None:
|
|
6799
6856
|
file.write(f"WP_API void {f.mangled_name}({args}) {{ wp::{f.key}({params}); }}\n")
|
|
6800
6857
|
else:
|
|
6801
6858
|
file.write(
|
|
6802
|
-
f"WP_API void {f.mangled_name}({args}, {
|
|
6859
|
+
f"WP_API void {f.mangled_name}({args}, {return_str}* ret) {{ *ret = wp::{f.key}({params}); }}\n"
|
|
6803
6860
|
)
|
|
6804
6861
|
|
|
6805
6862
|
file.write('\n} // extern "C"\n\n')
|