warp-lang 1.7.1__py3-none-macosx_10_13_universal2.whl → 1.7.2rc1__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package registry's page for this release for more details.

Binary file
warp/bin/libwarp.dylib CHANGED
Binary file
warp/builtins.py CHANGED
@@ -2264,7 +2264,7 @@ def tile_atomic_add_value_func(arg_types, arg_values):
2264
2264
  f"tile_atomic_add() 'a' and 't' arguments must have the same dtype, got {arg_types['a'].dtype} and {arg_types['t'].dtype}"
2265
2265
  )
2266
2266
 
2267
- return Tile(dtype=arg_types["t"].dtype, shape=arg_types["t"].shape)
2267
+ return Tile(dtype=arg_types["t"].dtype, shape=arg_types["t"].shape, storage=arg_types["t"].storage)
2268
2268
 
2269
2269
 
2270
2270
  def tile_atomic_add_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
@@ -2423,7 +2423,6 @@ add_builtin(
2423
2423
  group="Tile Primitives",
2424
2424
  export=False,
2425
2425
  hidden=True,
2426
- missing_grad=True,
2427
2426
  )
2428
2427
 
2429
2428
  add_builtin(
@@ -2433,7 +2432,6 @@ add_builtin(
2433
2432
  group="Tile Primitives",
2434
2433
  export=False,
2435
2434
  hidden=True,
2436
- missing_grad=True,
2437
2435
  )
2438
2436
 
2439
2437
  add_builtin(
@@ -2443,7 +2441,6 @@ add_builtin(
2443
2441
  group="Tile Primitives",
2444
2442
  export=False,
2445
2443
  hidden=True,
2446
- missing_grad=True,
2447
2444
  )
2448
2445
 
2449
2446
  add_builtin(
@@ -2453,7 +2450,6 @@ add_builtin(
2453
2450
  group="Tile Primitives",
2454
2451
  export=False,
2455
2452
  hidden=True,
2456
- missing_grad=True,
2457
2453
  )
2458
2454
 
2459
2455
 
@@ -4896,46 +4892,78 @@ add_builtin(
4896
4892
  )
4897
4893
 
4898
4894
 
4895
+ SUPPORTED_ATOMIC_TYPES = (
4896
+ warp.int32,
4897
+ warp.int64,
4898
+ warp.uint32,
4899
+ warp.uint64,
4900
+ warp.float32,
4901
+ warp.float64,
4902
+ )
4903
+
4904
+
4899
4905
  def atomic_op_constraint(arg_types: Mapping[str, Any]):
4900
4906
  idx_types = tuple(arg_types[x] for x in "ijkl" if arg_types.get(x, None) is not None)
4901
4907
  return all(types_equal(idx_types[0], t) for t in idx_types[1:]) and arg_types["arr"].ndim == len(idx_types)
4902
4908
 
4903
4909
 
4904
- def atomic_op_value_func(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
4905
- if arg_types is None:
4906
- return Any
4910
+ def create_atomic_op_value_func(op: str):
4911
+ def fn(arg_types: Mapping[str, type], arg_values: Mapping[str, Any]):
4912
+ if arg_types is None:
4913
+ return Any
4907
4914
 
4908
- arr_type = arg_types["arr"]
4909
- value_type = arg_types["value"]
4910
- idx_types = tuple(arg_types[x] for x in "ijkl" if arg_types.get(x, None) is not None)
4915
+ arr_type = arg_types["arr"]
4916
+ value_type = arg_types["value"]
4917
+ idx_types = tuple(arg_types[x] for x in "ijkl" if arg_types.get(x, None) is not None)
4911
4918
 
4912
- if not is_array(arr_type):
4913
- raise RuntimeError("atomic() first argument must be an array")
4919
+ if not is_array(arr_type):
4920
+ raise RuntimeError(f"atomic_{op}() first argument must be an array")
4914
4921
 
4915
- idx_count = len(idx_types)
4922
+ idx_count = len(idx_types)
4916
4923
 
4917
- if idx_count < arr_type.ndim:
4918
- raise RuntimeError(
4919
- "Num indices < num dimensions for atomic, this is a codegen error, should have generated a view instead"
4920
- )
4924
+ if idx_count < arr_type.ndim:
4925
+ raise RuntimeError(
4926
+ f"Num indices < num dimensions for atomic_{op}(), this is a codegen error, should have generated a view instead"
4927
+ )
4921
4928
 
4922
- if idx_count > arr_type.ndim:
4923
- raise RuntimeError(
4924
- f"Num indices > num dimensions for atomic, received {idx_count}, but array only has {arr_type.ndim}"
4925
- )
4929
+ if idx_count > arr_type.ndim:
4930
+ raise RuntimeError(
4931
+ f"Num indices > num dimensions for atomic_{op}(), received {idx_count}, but array only has {arr_type.ndim}"
4932
+ )
4926
4933
 
4927
- # check index types
4928
- for t in idx_types:
4929
- if not type_is_int(t):
4930
- raise RuntimeError(f"atomic() index arguments must be of integer type, got index of type {type_repr(t)}")
4934
+ # check index types
4935
+ for t in idx_types:
4936
+ if not type_is_int(t):
4937
+ raise RuntimeError(
4938
+ f"atomic_{op}() index arguments must be of integer type, got index of type {type_repr(t)}"
4939
+ )
4931
4940
 
4932
- # check value type
4933
- if not types_equal(arr_type.dtype, value_type):
4934
- raise RuntimeError(
4935
- f"atomic() value argument type ({type_repr(value_type)}) must be of the same type as the array ({type_repr(arr_type.dtype)})"
4936
- )
4941
+ # check value type
4942
+ if not types_equal(arr_type.dtype, value_type):
4943
+ raise RuntimeError(
4944
+ f"atomic_{op}() value argument type ({type_repr(value_type)}) must be of the same type as the array ({type_repr(arr_type.dtype)})"
4945
+ )
4946
+
4947
+ scalar_type = getattr(arr_type.dtype, "_wp_scalar_type_", arr_type.dtype)
4948
+ if op in ("add", "sub"):
4949
+ supported_atomic_types = (*SUPPORTED_ATOMIC_TYPES, warp.float16)
4950
+ if not any(types_equal(scalar_type, x, match_generic=True) for x in supported_atomic_types):
4951
+ raise RuntimeError(
4952
+ f"atomic_{op}() operations only work on arrays with [u]int32, [u]int64, float16, float32, or float64 "
4953
+ f"as the underlying scalar types, but got {type_repr(arr_type.dtype)} (with scalar type {type_repr(scalar_type)})"
4954
+ )
4955
+ elif op in ("min", "max"):
4956
+ if not any(types_equal(scalar_type, x, match_generic=True) for x in SUPPORTED_ATOMIC_TYPES):
4957
+ raise RuntimeError(
4958
+ f"atomic_{op}() operations only work on arrays with [u]int32, [u]int64, float32, or float64 "
4959
+ f"as the underlying scalar types, but got {type_repr(arr_type.dtype)} (with scalar type {type_repr(scalar_type)})"
4960
+ )
4961
+ else:
4962
+ raise NotImplementedError
4937
4963
 
4938
- return arr_type.dtype
4964
+ return arr_type.dtype
4965
+
4966
+ return fn
4939
4967
 
4940
4968
 
4941
4969
  def atomic_op_dispatch_func(input_types: Mapping[str, type], return_type: Any, args: Mapping[str, Var]):
@@ -4960,9 +4988,10 @@ for array_type in array_types:
4960
4988
  hidden=hidden,
4961
4989
  input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
4962
4990
  constraint=atomic_op_constraint,
4963
- value_func=atomic_op_value_func,
4991
+ value_func=create_atomic_op_value_func("add"),
4964
4992
  dispatch_func=atomic_op_dispatch_func,
4965
- doc="Atomically add ``value`` onto ``arr[i]`` and return the old value.",
4993
+ doc="""Atomically adds ``value`` onto ``arr[i]`` and returns the original value of ``arr[i]``.
4994
+ This function is automatically invoked when using the syntax ``arr[i] += value``.""",
4966
4995
  group="Utility",
4967
4996
  skip_replay=True,
4968
4997
  )
@@ -4971,9 +5000,10 @@ for array_type in array_types:
4971
5000
  hidden=hidden,
4972
5001
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
4973
5002
  constraint=atomic_op_constraint,
4974
- value_func=atomic_op_value_func,
5003
+ value_func=create_atomic_op_value_func("add"),
4975
5004
  dispatch_func=atomic_op_dispatch_func,
4976
- doc="Atomically add ``value`` onto ``arr[i,j]`` and return the old value.",
5005
+ doc="""Atomically adds ``value`` onto ``arr[i,j]`` and returns the original value of ``arr[i,j]``.
5006
+ This function is automatically invoked when using the syntax ``arr[i,j] += value``.""",
4977
5007
  group="Utility",
4978
5008
  skip_replay=True,
4979
5009
  )
@@ -4982,9 +5012,10 @@ for array_type in array_types:
4982
5012
  hidden=hidden,
4983
5013
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
4984
5014
  constraint=atomic_op_constraint,
4985
- value_func=atomic_op_value_func,
5015
+ value_func=create_atomic_op_value_func("add"),
4986
5016
  dispatch_func=atomic_op_dispatch_func,
4987
- doc="Atomically add ``value`` onto ``arr[i,j,k]`` and return the old value.",
5017
+ doc="""Atomically adds ``value`` onto ``arr[i,j,k]`` and returns the original value of ``arr[i,j,k]``.
5018
+ This function is automatically invoked when using the syntax ``arr[i,j,k] += value``.""",
4988
5019
  group="Utility",
4989
5020
  skip_replay=True,
4990
5021
  )
@@ -4993,9 +5024,10 @@ for array_type in array_types:
4993
5024
  hidden=hidden,
4994
5025
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
4995
5026
  constraint=atomic_op_constraint,
4996
- value_func=atomic_op_value_func,
5027
+ value_func=create_atomic_op_value_func("add"),
4997
5028
  dispatch_func=atomic_op_dispatch_func,
4998
- doc="Atomically add ``value`` onto ``arr[i,j,k,l]`` and return the old value.",
5029
+ doc="""Atomically adds ``value`` onto ``arr[i,j,k,l]`` and returns the original value of ``arr[i,j,k,l]``.
5030
+ This function is automatically invoked when using the syntax ``arr[i,j,k,l] += value``.""",
4999
5031
  group="Utility",
5000
5032
  skip_replay=True,
5001
5033
  )
@@ -5005,9 +5037,10 @@ for array_type in array_types:
5005
5037
  hidden=hidden,
5006
5038
  input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
5007
5039
  constraint=atomic_op_constraint,
5008
- value_func=atomic_op_value_func,
5040
+ value_func=create_atomic_op_value_func("sub"),
5009
5041
  dispatch_func=atomic_op_dispatch_func,
5010
- doc="Atomically subtract ``value`` onto ``arr[i]`` and return the old value.",
5042
+ doc="""Atomically subtracts ``value`` onto ``arr[i]`` and returns the original value of ``arr[i]``.
5043
+ This function is automatically invoked when using the syntax ``arr[i] -= value``.""",
5011
5044
  group="Utility",
5012
5045
  skip_replay=True,
5013
5046
  )
@@ -5016,9 +5049,10 @@ for array_type in array_types:
5016
5049
  hidden=hidden,
5017
5050
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
5018
5051
  constraint=atomic_op_constraint,
5019
- value_func=atomic_op_value_func,
5052
+ value_func=create_atomic_op_value_func("sub"),
5020
5053
  dispatch_func=atomic_op_dispatch_func,
5021
- doc="Atomically subtract ``value`` onto ``arr[i,j]`` and return the old value.",
5054
+ doc="""Atomically subtracts ``value`` onto ``arr[i,j]`` and returns the original value of ``arr[i,j]``.
5055
+ This function is automatically invoked when using the syntax ``arr[i,j] -= value``.""",
5022
5056
  group="Utility",
5023
5057
  skip_replay=True,
5024
5058
  )
@@ -5027,9 +5061,10 @@ for array_type in array_types:
5027
5061
  hidden=hidden,
5028
5062
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
5029
5063
  constraint=atomic_op_constraint,
5030
- value_func=atomic_op_value_func,
5064
+ value_func=create_atomic_op_value_func("sub"),
5031
5065
  dispatch_func=atomic_op_dispatch_func,
5032
- doc="Atomically subtract ``value`` onto ``arr[i,j,k]`` and return the old value.",
5066
+ doc="""Atomically subtracts ``value`` onto ``arr[i,j,k]`` and returns the original value of ``arr[i,j,k]``.
5067
+ This function is automatically invoked when using the syntax ``arr[i,j,k] -= value``.""",
5033
5068
  group="Utility",
5034
5069
  skip_replay=True,
5035
5070
  )
@@ -5038,9 +5073,10 @@ for array_type in array_types:
5038
5073
  hidden=hidden,
5039
5074
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
5040
5075
  constraint=atomic_op_constraint,
5041
- value_func=atomic_op_value_func,
5076
+ value_func=create_atomic_op_value_func("sub"),
5042
5077
  dispatch_func=atomic_op_dispatch_func,
5043
- doc="Atomically subtract ``value`` onto ``arr[i,j,k,l]`` and return the old value.",
5078
+ doc="""Atomically subtracts ``value`` onto ``arr[i,j,k,l]`` and returns the original value of ``arr[i,j,k,l]``.
5079
+ This function is automatically invoked when using the syntax ``arr[i,j,k,l] -= value``.""",
5044
5080
  group="Utility",
5045
5081
  skip_replay=True,
5046
5082
  )
@@ -5050,7 +5086,7 @@ for array_type in array_types:
5050
5086
  hidden=hidden,
5051
5087
  input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
5052
5088
  constraint=atomic_op_constraint,
5053
- value_func=atomic_op_value_func,
5089
+ value_func=create_atomic_op_value_func("min"),
5054
5090
  dispatch_func=atomic_op_dispatch_func,
5055
5091
  doc="""Compute the minimum of ``value`` and ``arr[i]``, atomically update the array, and return the old value.
5056
5092
 
@@ -5063,7 +5099,7 @@ for array_type in array_types:
5063
5099
  hidden=hidden,
5064
5100
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
5065
5101
  constraint=atomic_op_constraint,
5066
- value_func=atomic_op_value_func,
5102
+ value_func=create_atomic_op_value_func("min"),
5067
5103
  dispatch_func=atomic_op_dispatch_func,
5068
5104
  doc="""Compute the minimum of ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
5069
5105
 
@@ -5076,7 +5112,7 @@ for array_type in array_types:
5076
5112
  hidden=hidden,
5077
5113
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
5078
5114
  constraint=atomic_op_constraint,
5079
- value_func=atomic_op_value_func,
5115
+ value_func=create_atomic_op_value_func("min"),
5080
5116
  dispatch_func=atomic_op_dispatch_func,
5081
5117
  doc="""Compute the minimum of ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
5082
5118
 
@@ -5089,7 +5125,7 @@ for array_type in array_types:
5089
5125
  hidden=hidden,
5090
5126
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
5091
5127
  constraint=atomic_op_constraint,
5092
- value_func=atomic_op_value_func,
5128
+ value_func=create_atomic_op_value_func("min"),
5093
5129
  dispatch_func=atomic_op_dispatch_func,
5094
5130
  doc="""Compute the minimum of ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
5095
5131
 
@@ -5103,7 +5139,7 @@ for array_type in array_types:
5103
5139
  hidden=hidden,
5104
5140
  input_types={"arr": array_type(dtype=Any), "i": Int, "value": Any},
5105
5141
  constraint=atomic_op_constraint,
5106
- value_func=atomic_op_value_func,
5142
+ value_func=create_atomic_op_value_func("max"),
5107
5143
  dispatch_func=atomic_op_dispatch_func,
5108
5144
  doc="""Compute the maximum of ``value`` and ``arr[i]``, atomically update the array, and return the old value.
5109
5145
 
@@ -5116,7 +5152,7 @@ for array_type in array_types:
5116
5152
  hidden=hidden,
5117
5153
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "value": Any},
5118
5154
  constraint=atomic_op_constraint,
5119
- value_func=atomic_op_value_func,
5155
+ value_func=create_atomic_op_value_func("max"),
5120
5156
  dispatch_func=atomic_op_dispatch_func,
5121
5157
  doc="""Compute the maximum of ``value`` and ``arr[i,j]``, atomically update the array, and return the old value.
5122
5158
 
@@ -5129,7 +5165,7 @@ for array_type in array_types:
5129
5165
  hidden=hidden,
5130
5166
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "value": Any},
5131
5167
  constraint=atomic_op_constraint,
5132
- value_func=atomic_op_value_func,
5168
+ value_func=create_atomic_op_value_func("max"),
5133
5169
  dispatch_func=atomic_op_dispatch_func,
5134
5170
  doc="""Compute the maximum of ``value`` and ``arr[i,j,k]``, atomically update the array, and return the old value.
5135
5171
 
@@ -5142,7 +5178,7 @@ for array_type in array_types:
5142
5178
  hidden=hidden,
5143
5179
  input_types={"arr": array_type(dtype=Any), "i": Int, "j": Int, "k": Int, "l": Int, "value": Any},
5144
5180
  constraint=atomic_op_constraint,
5145
- value_func=atomic_op_value_func,
5181
+ value_func=create_atomic_op_value_func("max"),
5146
5182
  dispatch_func=atomic_op_dispatch_func,
5147
5183
  doc="""Compute the maximum of ``value`` and ``arr[i,j,k,l]``, atomically update the array, and return the old value.
5148
5184
 
warp/codegen.py CHANGED
@@ -240,7 +240,7 @@ class StructInstance:
240
240
  # create Python attributes for each of the struct's variables
241
241
  for field, var in cls.vars.items():
242
242
  if isinstance(var.type, warp.codegen.Struct):
243
- self.__dict__[field] = StructInstance(var.type, getattr(self._ctype, field))
243
+ self.__dict__[field] = var.type.instance_type(ctype=getattr(self._ctype, field))
244
244
  elif isinstance(var.type, warp.types.array):
245
245
  self.__dict__[field] = None
246
246
  else:
@@ -288,6 +288,11 @@ class StructInstance:
288
288
  )
289
289
  setattr(self._ctype, name, value.__ctype__())
290
290
 
291
+ # workaround to prevent gradient buffers being garbage collected
292
+ # since users can do struct.array.requires_grad = False the gradient array
293
+ # would be collected while the struct ctype still holds a reference to it
294
+ super().__setattr__("_" + name + "_grad", value.grad)
295
+
291
296
  elif isinstance(var.type, Struct):
292
297
  # assign structs by-value, otherwise we would have problematic cases transferring ownership
293
298
  # of the underlying ctypes data between shared Python struct instances
@@ -413,11 +418,14 @@ class StructInstance:
413
418
  class Struct:
414
419
  hash: bytes
415
420
 
416
- def __init__(self, cls: type, key: str, module: warp.context.Module):
421
+ def __init__(self, key: str, cls: type, module: warp.context.Module):
422
+ self.key = key
417
423
  self.cls = cls
418
424
  self.module = module
419
- self.key = key
420
- self.vars: Dict[str, Var] = {}
425
+ self.vars: dict[str, Var] = {}
426
+
427
+ if isinstance(self.cls, Sequence):
428
+ raise RuntimeError("Warp structs must be defined as base classes")
421
429
 
422
430
  annotations = get_annotations(self.cls)
423
431
  for label, type in annotations.items():
@@ -489,34 +497,35 @@ class Struct:
489
497
 
490
498
  self.default_constructor.add_overload(self.value_constructor)
491
499
 
492
- if module:
500
+ if isinstance(module, warp.context.Module):
493
501
  module.register_struct(self)
494
502
 
495
- def __call__(self):
496
- """
497
- This function returns s = StructInstance(self)
498
- s uses self.cls as template.
499
- To enable autocomplete on s, we inherit from self.cls.
500
- For example,
501
-
502
- @wp.struct
503
- class A:
504
- # annotations
505
- ...
503
+ # Define class for instances of this struct
504
+ # To enable autocomplete on s, we inherit from self.cls.
505
+ # For example,
506
506
 
507
- The type annotations are inherited in A(), allowing autocomplete in kernels
508
- """
509
- # return StructInstance(self)
507
+ # @wp.struct
508
+ # class A:
509
+ # # annotations
510
+ # ...
510
511
 
512
+ # The type annotations are inherited in A(), allowing autocomplete in kernels
511
513
  class NewStructInstance(self.cls, StructInstance):
512
- def __init__(inst):
513
- StructInstance.__init__(inst, self, None)
514
+ def __init__(inst, ctype=None):
515
+ StructInstance.__init__(inst, self, ctype)
514
516
 
515
517
  # make sure warp.types.get_type_code works with this StructInstance
516
518
  NewStructInstance.cls = self.cls
517
519
  NewStructInstance.native_name = self.native_name
518
520
 
519
- return NewStructInstance()
521
+ self.instance_type = NewStructInstance
522
+
523
+ def __call__(self):
524
+ """
525
+ This function returns s = StructInstance(self)
526
+ s uses self.cls as template.
527
+ """
528
+ return self.instance_type()
520
529
 
521
530
  def initializer(self):
522
531
  return self.default_constructor
warp/config.py CHANGED
@@ -15,7 +15,7 @@
15
15
 
16
16
  from typing import Optional
17
17
 
18
- version: str = "1.7.1"
18
+ version: str = "1.7.2"
19
19
  """Warp version string"""
20
20
 
21
21
  verify_fp: bool = False
warp/context.py CHANGED
@@ -457,6 +457,24 @@ class Function:
457
457
  return f"<Function {self.key}({inputs_str})>"
458
458
 
459
459
 
460
+ def get_builtin_type(return_type: type) -> type:
461
+ # The return_type might just be vector_t(length=3,dtype=wp.float32), so we've got to match that
462
+ # in the list of hard coded types so it knows it's returning one of them:
463
+ if hasattr(return_type, "_wp_generic_type_hint_"):
464
+ return_type_match = tuple(
465
+ x
466
+ for x in generic_vtypes
467
+ if x._wp_generic_type_hint_ == return_type._wp_generic_type_hint_
468
+ and x._wp_type_params_ == return_type._wp_type_params_
469
+ )
470
+ if not return_type_match:
471
+ raise RuntimeError("No match")
472
+
473
+ return return_type_match[0]
474
+
475
+ return return_type
476
+
477
+
460
478
  def call_builtin(func: Function, *params: Any) -> Tuple[bool, Any]:
461
479
  uses_non_warp_array_type = False
462
480
 
@@ -1074,7 +1092,7 @@ def kernel(
1074
1092
  # decorator to register struct, @struct
1075
1093
  def struct(c: type):
1076
1094
  m = get_module(c.__module__)
1077
- s = warp.codegen.Struct(cls=c, key=warp.codegen.make_full_qualified_name(c), module=m)
1095
+ s = warp.codegen.Struct(key=warp.codegen.make_full_qualified_name(c), cls=c, module=m)
1078
1096
  s = functools.update_wrapper(s, c)
1079
1097
  return s
1080
1098
 
@@ -1445,6 +1463,24 @@ def register_api_function(
1445
1463
  """
1446
1464
  function.group = group
1447
1465
  function.hidden = hidden
1466
+
1467
+ # Update the docstring to mark these functions as being available from kernels and Python's runtime.
1468
+ assert function.__doc__.startswith("\n")
1469
+ leading_space_count = sum(1 for _ in itertools.takewhile(str.isspace, function.__doc__[1:]))
1470
+ assert leading_space_count % 4 == 0
1471
+ indent_level = leading_space_count // 4
1472
+ indent = " "
1473
+ function.__doc__ = (
1474
+ f"\n"
1475
+ f"{indent * indent_level}.. hlist::\n"
1476
+ f"{indent * (indent_level + 1)}:columns: 8\n"
1477
+ f"\n"
1478
+ f"{indent * (indent_level + 1)}* Kernel\n"
1479
+ f"{indent * (indent_level + 1)}* Python\n"
1480
+ f"{indent * (indent_level + 1)}* Differentiable\n"
1481
+ f"{function.__doc__}"
1482
+ )
1483
+
1448
1484
  builtin_functions[function.key] = function
1449
1485
 
1450
1486
 
@@ -6510,12 +6546,46 @@ def type_str(t):
6510
6546
  return t.__name__
6511
6547
 
6512
6548
 
6513
- def print_function(f, file, noentry=False): # pragma: no cover
6549
+ def ctype_ret_str(t):
6550
+ return get_builtin_type(t).__name__
6551
+
6552
+
6553
+ def resolve_exported_function_sig(f):
6554
+ if not f.export or f.generic:
6555
+ return None
6556
+
6557
+ # only export simple types that don't use arrays or templated types
6558
+ if not f.is_simple():
6559
+ return None
6560
+
6561
+ # Runtime arguments that are to be passed to the function, not its template signature.
6562
+ if f.export_func is not None:
6563
+ func_args = f.export_func(f.input_types)
6564
+ else:
6565
+ func_args = f.input_types
6566
+
6567
+ # todo: construct a default value for each of the functions args
6568
+ # so we can generate the return type for overloaded functions
6569
+ return_type = f.value_func(func_args, None)
6570
+
6571
+ try:
6572
+ return_type_str = ctype_ret_str(return_type)
6573
+ except Exception:
6574
+ return None
6575
+
6576
+ if return_type_str.startswith("Tuple"):
6577
+ return None
6578
+
6579
+ return (func_args, return_type)
6580
+
6581
+
6582
+ def print_function(f, file, is_exported, noentry=False): # pragma: no cover
6514
6583
  """Writes a function definition to a file for use in reST documentation
6515
6584
 
6516
6585
  Args:
6517
6586
  f: The function being written
6518
6587
  file: The file object for output
6588
+ is_exported: Whether the function is available in Python's runtime
6519
6589
  noentry: If True, then the :noindex: and :nocontentsentry: directive
6520
6590
  options will be added
6521
6591
 
@@ -6543,11 +6613,21 @@ def print_function(f, file, noentry=False): # pragma: no cover
6543
6613
  print(" :nocontentsentry:", file=file)
6544
6614
  print("", file=file)
6545
6615
 
6616
+ print(" .. hlist::", file=file)
6617
+ print(" :columns: 8", file=file)
6618
+ print("", file=file)
6619
+ print(" * Kernel", file=file)
6620
+
6621
+ if is_exported:
6622
+ print(" * Python", file=file)
6623
+
6624
+ if not f.missing_grad:
6625
+ print(" * Differentiable", file=file)
6626
+
6627
+ print("", file=file)
6628
+
6546
6629
  if f.doc != "":
6547
- if not f.missing_grad:
6548
- print(f" {f.doc}", file=file)
6549
- else:
6550
- print(f" {f.doc} [1]_", file=file)
6630
+ print(f" {f.doc}", file=file)
6551
6631
  print("", file=file)
6552
6632
 
6553
6633
  print(file=file)
@@ -6563,8 +6643,10 @@ def export_functions_rst(file): # pragma: no cover
6563
6643
  ".. functions:\n"
6564
6644
  ".. currentmodule:: warp\n"
6565
6645
  "\n"
6566
- "Kernel Reference\n"
6567
- "================"
6646
+ "Built-Ins Reference\n"
6647
+ "===================\n"
6648
+ "This section lists the Warp types and functions available to use from Warp kernels and optionally also from the Warp Python runtime API.\n"
6649
+ "For a listing of the API that is exclusively intended to be used at the *Python Scope* and run inside the CPython interpreter, see the :doc:`runtime` section.\n"
6568
6650
  )
6569
6651
 
6570
6652
  print(header, file=file)
@@ -6609,9 +6691,12 @@ def export_functions_rst(file): # pragma: no cover
6609
6691
  if hasattr(f, "overloads"):
6610
6692
  # append all overloads to the group
6611
6693
  for o in f.overloads:
6612
- groups[f.group].append(o)
6694
+ sig = resolve_exported_function_sig(f)
6695
+ is_exported = sig is not None
6696
+ groups[f.group].append((o, is_exported))
6613
6697
  else:
6614
- groups[f.group].append(f)
6698
+ is_exported = False
6699
+ groups[f.group].append((f, is_exported))
6615
6700
 
6616
6701
  # Keep track of what function and query types have been written
6617
6702
  written_functions = set()
@@ -6630,7 +6715,7 @@ def export_functions_rst(file): # pragma: no cover
6630
6715
  print(k, file=file)
6631
6716
  print("---------------", file=file)
6632
6717
 
6633
- for f in g:
6718
+ for f, is_exported in g:
6634
6719
  if f.func:
6635
6720
  # f is a Warp function written in Python, we can use autofunction
6636
6721
  print(f".. autofunction:: {f.func.__module__}.{f.key}", file=file)
@@ -6643,15 +6728,11 @@ def export_functions_rst(file): # pragma: no cover
6643
6728
 
6644
6729
  if f.key in written_functions:
6645
6730
  # Add :noindex: + :nocontentsentry: since Sphinx gets confused
6646
- print_function(f, file=file, noentry=True)
6731
+ print_function(f, file, is_exported, noentry=True)
6647
6732
  else:
6648
- if print_function(f, file=file):
6733
+ if print_function(f, file, is_exported):
6649
6734
  written_functions.add(f.key)
6650
6735
 
6651
- # footnotes
6652
- print(".. rubric:: Footnotes", file=file)
6653
- print(".. [1] Function gradients have not been implemented for backpropagation.", file=file)
6654
-
6655
6736
 
6656
6737
  def export_stubs(file): # pragma: no cover
6657
6738
  """Generates stub file for auto-complete of builtin functions"""
@@ -6751,14 +6832,6 @@ def export_builtins(file: io.TextIOBase): # pragma: no cover
6751
6832
  else:
6752
6833
  return t.__name__
6753
6834
 
6754
- def ctype_ret_str(t):
6755
- if isinstance(t, int):
6756
- return "int"
6757
- elif isinstance(t, float):
6758
- return "float"
6759
- else:
6760
- return t.__name__
6761
-
6762
6835
  file.write("namespace wp {\n\n")
6763
6836
  file.write('extern "C" {\n\n')
6764
6837
 
@@ -6766,40 +6839,24 @@ def export_builtins(file: io.TextIOBase): # pragma: no cover
6766
6839
  if not hasattr(g, "overloads"):
6767
6840
  continue
6768
6841
  for f in g.overloads:
6769
- if not f.export or f.generic:
6770
- continue
6771
-
6772
- # only export simple types that don't use arrays
6773
- # or templated types
6774
- if not f.is_simple():
6775
- continue
6776
-
6777
- try:
6778
- # todo: construct a default value for each of the functions args
6779
- # so we can generate the return type for overloaded functions
6780
- return_type = ctype_ret_str(f.value_func(None, None))
6781
- except Exception:
6782
- continue
6783
-
6784
- if return_type.startswith("Tuple"):
6842
+ sig = resolve_exported_function_sig(f)
6843
+ if sig is None:
6785
6844
  continue
6786
6845
 
6787
- # Runtime arguments that are to be passed to the function, not its template signature.
6788
- if f.export_func is not None:
6789
- func_args = f.export_func(f.input_types)
6790
- else:
6791
- func_args = f.input_types
6846
+ func_args, return_type = sig
6792
6847
 
6793
6848
  args = ", ".join(f"{ctype_arg_str(v)} {k}" for k, v in func_args.items())
6794
6849
  params = ", ".join(func_args.keys())
6795
6850
 
6851
+ return_str = ctype_ret_str(return_type)
6852
+
6796
6853
  if args == "":
6797
- file.write(f"WP_API void {f.mangled_name}({return_type}* ret) {{ *ret = wp::{f.key}({params}); }}\n")
6798
- elif return_type == "None":
6854
+ file.write(f"WP_API void {f.mangled_name}({return_str}* ret) {{ *ret = wp::{f.key}({params}); }}\n")
6855
+ elif return_type is None:
6799
6856
  file.write(f"WP_API void {f.mangled_name}({args}) {{ wp::{f.key}({params}); }}\n")
6800
6857
  else:
6801
6858
  file.write(
6802
- f"WP_API void {f.mangled_name}({args}, {return_type}* ret) {{ *ret = wp::{f.key}({params}); }}\n"
6859
+ f"WP_API void {f.mangled_name}({args}, {return_str}* ret) {{ *ret = wp::{f.key}({params}); }}\n"
6803
6860
  )
6804
6861
 
6805
6862
  file.write('\n} // extern "C"\n\n')