warp-lang 1.8.0__py3-none-macosx_10_13_universal2.whl → 1.9.0__py3-none-macosx_10_13_universal2.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of warp-lang might be problematic; see the package registry's page for this release for more details.

Files changed (153)
  1. warp/__init__.py +282 -103
  2. warp/__init__.pyi +482 -110
  3. warp/bin/libwarp-clang.dylib +0 -0
  4. warp/bin/libwarp.dylib +0 -0
  5. warp/build.py +93 -30
  6. warp/build_dll.py +48 -63
  7. warp/builtins.py +955 -137
  8. warp/codegen.py +327 -209
  9. warp/config.py +1 -1
  10. warp/context.py +1363 -800
  11. warp/examples/core/example_marching_cubes.py +1 -0
  12. warp/examples/core/example_render_opengl.py +100 -3
  13. warp/examples/fem/example_apic_fluid.py +98 -52
  14. warp/examples/fem/example_convection_diffusion_dg.py +25 -4
  15. warp/examples/fem/example_diffusion_mgpu.py +8 -3
  16. warp/examples/fem/utils.py +68 -22
  17. warp/examples/interop/example_jax_callable.py +34 -4
  18. warp/examples/interop/example_jax_kernel.py +27 -1
  19. warp/fabric.py +1 -1
  20. warp/fem/cache.py +27 -19
  21. warp/fem/domain.py +2 -2
  22. warp/fem/field/nodal_field.py +2 -2
  23. warp/fem/field/virtual.py +266 -166
  24. warp/fem/geometry/geometry.py +5 -5
  25. warp/fem/integrate.py +200 -91
  26. warp/fem/space/restriction.py +4 -0
  27. warp/fem/space/shape/tet_shape_function.py +3 -10
  28. warp/jax_experimental/custom_call.py +1 -1
  29. warp/jax_experimental/ffi.py +203 -54
  30. warp/marching_cubes.py +708 -0
  31. warp/native/array.h +103 -8
  32. warp/native/builtin.h +90 -9
  33. warp/native/bvh.cpp +64 -28
  34. warp/native/bvh.cu +58 -58
  35. warp/native/bvh.h +2 -2
  36. warp/native/clang/clang.cpp +7 -7
  37. warp/native/coloring.cpp +13 -3
  38. warp/native/crt.cpp +2 -2
  39. warp/native/crt.h +3 -5
  40. warp/native/cuda_util.cpp +42 -11
  41. warp/native/cuda_util.h +10 -4
  42. warp/native/exports.h +1842 -1908
  43. warp/native/fabric.h +2 -1
  44. warp/native/hashgrid.cpp +37 -37
  45. warp/native/hashgrid.cu +2 -2
  46. warp/native/initializer_array.h +1 -1
  47. warp/native/intersect.h +4 -4
  48. warp/native/mat.h +1913 -119
  49. warp/native/mathdx.cpp +43 -43
  50. warp/native/mesh.cpp +24 -24
  51. warp/native/mesh.cu +26 -26
  52. warp/native/mesh.h +5 -3
  53. warp/native/nanovdb/GridHandle.h +179 -12
  54. warp/native/nanovdb/HostBuffer.h +8 -7
  55. warp/native/nanovdb/NanoVDB.h +517 -895
  56. warp/native/nanovdb/NodeManager.h +323 -0
  57. warp/native/nanovdb/PNanoVDB.h +2 -2
  58. warp/native/quat.h +337 -16
  59. warp/native/rand.h +7 -7
  60. warp/native/range.h +7 -1
  61. warp/native/reduce.cpp +10 -10
  62. warp/native/reduce.cu +13 -14
  63. warp/native/runlength_encode.cpp +2 -2
  64. warp/native/runlength_encode.cu +5 -5
  65. warp/native/scan.cpp +3 -3
  66. warp/native/scan.cu +4 -4
  67. warp/native/sort.cpp +10 -10
  68. warp/native/sort.cu +22 -22
  69. warp/native/sparse.cpp +8 -8
  70. warp/native/sparse.cu +14 -14
  71. warp/native/spatial.h +366 -17
  72. warp/native/svd.h +23 -8
  73. warp/native/temp_buffer.h +2 -2
  74. warp/native/tile.h +303 -70
  75. warp/native/tile_radix_sort.h +5 -1
  76. warp/native/tile_reduce.h +16 -25
  77. warp/native/tuple.h +2 -2
  78. warp/native/vec.h +385 -18
  79. warp/native/volume.cpp +54 -54
  80. warp/native/volume.cu +1 -1
  81. warp/native/volume.h +2 -1
  82. warp/native/volume_builder.cu +30 -37
  83. warp/native/warp.cpp +150 -149
  84. warp/native/warp.cu +337 -193
  85. warp/native/warp.h +227 -226
  86. warp/optim/linear.py +736 -271
  87. warp/render/imgui_manager.py +289 -0
  88. warp/render/render_opengl.py +137 -57
  89. warp/render/render_usd.py +0 -1
  90. warp/sim/collide.py +1 -2
  91. warp/sim/graph_coloring.py +2 -2
  92. warp/sim/integrator_vbd.py +10 -2
  93. warp/sparse.py +559 -176
  94. warp/tape.py +2 -0
  95. warp/tests/aux_test_module_aot.py +7 -0
  96. warp/tests/cuda/test_async.py +3 -3
  97. warp/tests/cuda/test_conditional_captures.py +101 -0
  98. warp/tests/geometry/test_marching_cubes.py +233 -12
  99. warp/tests/sim/test_cloth.py +89 -6
  100. warp/tests/sim/test_coloring.py +82 -7
  101. warp/tests/test_array.py +56 -5
  102. warp/tests/test_assert.py +53 -0
  103. warp/tests/test_atomic_cas.py +127 -114
  104. warp/tests/test_codegen.py +3 -2
  105. warp/tests/test_context.py +8 -15
  106. warp/tests/test_enum.py +136 -0
  107. warp/tests/test_examples.py +2 -2
  108. warp/tests/test_fem.py +45 -2
  109. warp/tests/test_fixedarray.py +229 -0
  110. warp/tests/test_func.py +18 -15
  111. warp/tests/test_future_annotations.py +7 -5
  112. warp/tests/test_linear_solvers.py +30 -0
  113. warp/tests/test_map.py +1 -1
  114. warp/tests/test_mat.py +1540 -378
  115. warp/tests/test_mat_assign_copy.py +178 -0
  116. warp/tests/test_mat_constructors.py +574 -0
  117. warp/tests/test_module_aot.py +287 -0
  118. warp/tests/test_print.py +69 -0
  119. warp/tests/test_quat.py +162 -34
  120. warp/tests/test_quat_assign_copy.py +145 -0
  121. warp/tests/test_reload.py +2 -1
  122. warp/tests/test_sparse.py +103 -0
  123. warp/tests/test_spatial.py +140 -34
  124. warp/tests/test_spatial_assign_copy.py +160 -0
  125. warp/tests/test_static.py +48 -0
  126. warp/tests/test_struct.py +43 -3
  127. warp/tests/test_tape.py +38 -0
  128. warp/tests/test_types.py +0 -20
  129. warp/tests/test_vec.py +216 -441
  130. warp/tests/test_vec_assign_copy.py +143 -0
  131. warp/tests/test_vec_constructors.py +325 -0
  132. warp/tests/tile/test_tile.py +206 -152
  133. warp/tests/tile/test_tile_cholesky.py +605 -0
  134. warp/tests/tile/test_tile_load.py +169 -0
  135. warp/tests/tile/test_tile_mathdx.py +2 -558
  136. warp/tests/tile/test_tile_matmul.py +179 -0
  137. warp/tests/tile/test_tile_mlp.py +1 -1
  138. warp/tests/tile/test_tile_reduce.py +100 -11
  139. warp/tests/tile/test_tile_shared_memory.py +16 -16
  140. warp/tests/tile/test_tile_sort.py +59 -55
  141. warp/tests/unittest_suites.py +16 -0
  142. warp/tests/walkthrough_debug.py +1 -1
  143. warp/thirdparty/unittest_parallel.py +108 -9
  144. warp/types.py +554 -264
  145. warp/utils.py +68 -86
  146. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/METADATA +28 -65
  147. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/RECORD +150 -138
  148. warp/native/marching.cpp +0 -19
  149. warp/native/marching.cu +0 -514
  150. warp/native/marching.h +0 -19
  151. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/WHEEL +0 -0
  152. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/licenses/LICENSE.md +0 -0
  153. {warp_lang-1.8.0.dist-info → warp_lang-1.9.0.dist-info}/top_level.txt +0 -0
warp/__init__.pyi CHANGED
@@ -36,120 +36,299 @@ FabricArray = Generic[DType]
36
36
  IndexedFabricArray = Generic[DType]
37
37
  Tile = Generic[DType, Shape]
38
38
 
39
- from warp.types import array, array1d, array2d, array3d, array4d, constant, from_ptr
40
- from warp.types import indexedarray, indexedarray1d, indexedarray2d, indexedarray3d, indexedarray4d
41
- from warp.fabric import fabricarray, fabricarrayarray, indexedfabricarray, indexedfabricarrayarray
42
- from warp.types import tile
43
-
44
- from warp.types import bool, int8, uint8, int16, uint16, int32, uint32, int64, uint64, float16, float32, float64
45
- from warp.types import vec2, vec2b, vec2ub, vec2s, vec2us, vec2i, vec2ui, vec2l, vec2ul, vec2h, vec2f, vec2d
46
- from warp.types import vec3, vec3b, vec3ub, vec3s, vec3us, vec3i, vec3ui, vec3l, vec3ul, vec3h, vec3f, vec3d
47
- from warp.types import vec4, vec4b, vec4ub, vec4s, vec4us, vec4i, vec4ui, vec4l, vec4ul, vec4h, vec4f, vec4d
48
- from warp.types import mat22, mat22h, mat22f, mat22d
49
- from warp.types import mat33, mat33h, mat33f, mat33d
50
- from warp.types import mat44, mat44h, mat44f, mat44d
51
- from warp.types import quat, quath, quatf, quatd
52
- from warp.types import transform, transformh, transformf, transformd
53
- from warp.types import spatial_vector, spatial_vectorh, spatial_vectorf, spatial_vectord
54
- from warp.types import spatial_matrix, spatial_matrixh, spatial_matrixf, spatial_matrixd
55
-
56
- from warp.types import Int, Float, Scalar
57
-
58
- from warp.types import Bvh, Mesh, HashGrid, Volume, MarchingCubes
59
- from warp.types import BvhQuery, HashGridQuery, MeshQueryAABB, MeshQueryPoint, MeshQueryRay
60
-
61
- from warp.types import matmul, adj_matmul, batched_matmul, adj_batched_matmul
39
+ from warp.types import array as array
40
+ from warp.types import array1d as array1d
41
+ from warp.types import array2d as array2d
42
+ from warp.types import array3d as array3d
43
+ from warp.types import array4d as array4d
44
+ from warp.types import constant as constant
45
+ from warp.types import from_ptr as from_ptr
46
+ from warp.types import fixedarray as fixedarray
47
+ from warp.types import indexedarray as indexedarray
48
+ from warp.types import indexedarray1d as indexedarray1d
49
+ from warp.types import indexedarray2d as indexedarray2d
50
+ from warp.types import indexedarray3d as indexedarray3d
51
+ from warp.types import indexedarray4d as indexedarray4d
52
+ from warp.fabric import fabricarray as fabricarray
53
+ from warp.fabric import fabricarrayarray as fabricarrayarray
54
+ from warp.fabric import indexedfabricarray as indexedfabricarray
55
+ from warp.fabric import indexedfabricarrayarray as indexedfabricarrayarray
56
+ from warp.types import tile as tile
57
+
58
+ from warp.types import bool as bool
59
+ from warp.types import int8 as int8
60
+ from warp.types import uint8 as uint8
61
+ from warp.types import int16 as int16
62
+ from warp.types import uint16 as uint16
63
+ from warp.types import int32 as int32
64
+ from warp.types import uint32 as uint32
65
+ from warp.types import int64 as int64
66
+ from warp.types import uint64 as uint64
67
+ from warp.types import float16 as float16
68
+ from warp.types import float32 as float32
69
+ from warp.types import float64 as float64
70
+
71
+ from warp.types import vec2 as vec2
72
+ from warp.types import vec2b as vec2b
73
+ from warp.types import vec2ub as vec2ub
74
+ from warp.types import vec2s as vec2s
75
+ from warp.types import vec2us as vec2us
76
+ from warp.types import vec2i as vec2i
77
+ from warp.types import vec2ui as vec2ui
78
+ from warp.types import vec2l as vec2l
79
+ from warp.types import vec2ul as vec2ul
80
+ from warp.types import vec2h as vec2h
81
+ from warp.types import vec2f as vec2f
82
+ from warp.types import vec2d as vec2d
83
+
84
+ from warp.types import vec3 as vec3
85
+ from warp.types import vec3b as vec3b
86
+ from warp.types import vec3ub as vec3ub
87
+ from warp.types import vec3s as vec3s
88
+ from warp.types import vec3us as vec3us
89
+ from warp.types import vec3i as vec3i
90
+ from warp.types import vec3ui as vec3ui
91
+ from warp.types import vec3l as vec3l
92
+ from warp.types import vec3ul as vec3ul
93
+ from warp.types import vec3h as vec3h
94
+ from warp.types import vec3f as vec3f
95
+ from warp.types import vec3d as vec3d
96
+
97
+ from warp.types import vec4 as vec4
98
+ from warp.types import vec4b as vec4b
99
+ from warp.types import vec4ub as vec4ub
100
+ from warp.types import vec4s as vec4s
101
+ from warp.types import vec4us as vec4us
102
+ from warp.types import vec4i as vec4i
103
+ from warp.types import vec4ui as vec4ui
104
+ from warp.types import vec4l as vec4l
105
+ from warp.types import vec4ul as vec4ul
106
+ from warp.types import vec4h as vec4h
107
+ from warp.types import vec4f as vec4f
108
+ from warp.types import vec4d as vec4d
109
+
110
+ from warp.types import mat22 as mat22
111
+ from warp.types import mat22h as mat22h
112
+ from warp.types import mat22f as mat22f
113
+ from warp.types import mat22d as mat22d
114
+
115
+ from warp.types import mat33 as mat33
116
+ from warp.types import mat33h as mat33h
117
+ from warp.types import mat33f as mat33f
118
+ from warp.types import mat33d as mat33d
119
+
120
+ from warp.types import mat44 as mat44
121
+ from warp.types import mat44h as mat44h
122
+ from warp.types import mat44f as mat44f
123
+ from warp.types import mat44d as mat44d
124
+
125
+ from warp.types import quat as quat
126
+ from warp.types import quath as quath
127
+ from warp.types import quatf as quatf
128
+ from warp.types import quatd as quatd
129
+
130
+ from warp.types import transform as transform
131
+ from warp.types import transformh as transformh
132
+ from warp.types import transformf as transformf
133
+ from warp.types import transformd as transformd
134
+
135
+ from warp.types import spatial_vector as spatial_vector
136
+ from warp.types import spatial_vectorh as spatial_vectorh
137
+ from warp.types import spatial_vectorf as spatial_vectorf
138
+ from warp.types import spatial_vectord as spatial_vectord
139
+
140
+ from warp.types import spatial_matrix as spatial_matrix
141
+ from warp.types import spatial_matrixh as spatial_matrixh
142
+ from warp.types import spatial_matrixf as spatial_matrixf
143
+ from warp.types import spatial_matrixd as spatial_matrixd
144
+
145
+ from warp.types import Int as Int
146
+ from warp.types import Float as Float
147
+ from warp.types import Scalar as Scalar
148
+
149
+ from warp.types import Bvh as Bvh
150
+ from warp.types import Mesh as Mesh
151
+ from warp.types import HashGrid as HashGrid
152
+ from warp.types import Volume as Volume
153
+ from warp.types import BvhQuery as BvhQuery
154
+ from warp.types import HashGridQuery as HashGridQuery
155
+ from warp.types import MeshQueryAABB as MeshQueryAABB
156
+ from warp.types import MeshQueryPoint as MeshQueryPoint
157
+ from warp.types import MeshQueryRay as MeshQueryRay
158
+
159
+ from warp.types import matmul as matmul
160
+ from warp.types import adj_matmul as adj_matmul
161
+ from warp.types import batched_matmul as batched_matmul
162
+ from warp.types import adj_batched_matmul as adj_batched_matmul
62
163
 
63
164
  from warp.types import vector as vec
64
165
  from warp.types import matrix as mat
65
166
 
66
- from warp.types import dtype_from_numpy, dtype_to_numpy
67
-
68
- from warp.types import from_ipc_handle
69
-
70
- from warp.context import init, func, func_grad, func_replay, func_native, kernel, struct, overload
71
- from warp.context import is_cpu_available, is_cuda_available, is_device_available
72
- from warp.context import get_devices, get_preferred_device
73
- from warp.context import get_cuda_devices, get_cuda_device_count, get_cuda_device, map_cuda_device, unmap_cuda_device
74
- from warp.context import get_device, set_device, synchronize_device
75
- from warp.context import (
76
- zeros,
77
- zeros_like,
78
- ones,
79
- ones_like,
80
- full,
81
- full_like,
82
- clone,
83
- empty,
84
- empty_like,
85
- copy,
86
- from_numpy,
87
- launch,
88
- launch_tiled,
89
- synchronize,
90
- force_load,
91
- load_module,
92
- event_from_ipc_handle,
93
- )
94
- from warp.context import set_module_options, get_module_options, get_module
95
- from warp.context import capture_begin, capture_end, capture_launch, capture_if, capture_while, capture_debug_dot_print
96
- from warp.context import Kernel, Function, Launch
97
- from warp.context import Stream, get_stream, set_stream, wait_stream, synchronize_stream
98
- from warp.context import Event, record_event, wait_event, synchronize_event, get_event_elapsed_time
99
- from warp.context import RegisteredGLBuffer
100
- from warp.context import is_mempool_supported, is_mempool_enabled, set_mempool_enabled
101
- from warp.context import (
102
- set_mempool_release_threshold,
103
- get_mempool_release_threshold,
104
- get_mempool_used_mem_current,
105
- get_mempool_used_mem_high,
106
- )
107
- from warp.context import is_mempool_access_supported, is_mempool_access_enabled, set_mempool_access_enabled
108
- from warp.context import is_peer_access_supported, is_peer_access_enabled, set_peer_access_enabled
109
-
110
- from warp.tape import Tape
111
- from warp.utils import ScopedTimer, ScopedDevice, ScopedStream
112
- from warp.utils import ScopedMempool, ScopedMempoolAccess, ScopedPeerAccess
113
- from warp.utils import ScopedCapture
114
- from warp.utils import transform_expand, quat_between_vectors
115
- from warp.utils import TimingResult, timing_begin, timing_end, timing_print
116
- from warp.utils import (
117
- TIMING_KERNEL,
118
- TIMING_KERNEL_BUILTIN,
119
- TIMING_MEMCPY,
120
- TIMING_MEMSET,
121
- TIMING_GRAPH,
122
- TIMING_ALL,
123
- )
124
- from warp.utils import map
125
-
126
- from warp.torch import from_torch, to_torch
127
- from warp.torch import dtype_from_torch, dtype_to_torch
128
- from warp.torch import device_from_torch, device_to_torch
129
- from warp.torch import stream_from_torch, stream_to_torch
130
-
131
- from warp.jax import from_jax, to_jax
132
- from warp.jax import dtype_from_jax, dtype_to_jax
133
- from warp.jax import device_from_jax, device_to_jax
134
-
135
- from warp.dlpack import from_dlpack, to_dlpack
136
-
137
- from warp.paddle import from_paddle, to_paddle
138
- from warp.paddle import dtype_from_paddle, dtype_to_paddle
139
- from warp.paddle import device_from_paddle, device_to_paddle
140
- from warp.paddle import stream_from_paddle
141
-
142
- from warp.build import clear_kernel_cache
143
- from warp.build import clear_lto_cache
167
+ from warp.types import matrix_from_cols as matrix_from_cols
168
+ from warp.types import matrix_from_rows as matrix_from_rows
169
+
170
+ from warp.types import dtype_from_numpy as dtype_from_numpy
171
+ from warp.types import dtype_to_numpy as dtype_to_numpy
172
+
173
+ from warp.types import from_ipc_handle as from_ipc_handle
174
+
175
+ from warp.context import init as init
176
+ from warp.context import func as func
177
+ from warp.context import func_grad as func_grad
178
+ from warp.context import func_replay as func_replay
179
+ from warp.context import func_native as func_native
180
+ from warp.context import kernel as kernel
181
+ from warp.context import struct as struct
182
+ from warp.context import overload as overload
183
+
184
+ from warp.context import is_cpu_available as is_cpu_available
185
+ from warp.context import is_cuda_available as is_cuda_available
186
+ from warp.context import is_device_available as is_device_available
187
+ from warp.context import get_devices as get_devices
188
+ from warp.context import get_preferred_device as get_preferred_device
189
+ from warp.context import get_cuda_devices as get_cuda_devices
190
+ from warp.context import get_cuda_device_count as get_cuda_device_count
191
+ from warp.context import get_cuda_device as get_cuda_device
192
+ from warp.context import map_cuda_device as map_cuda_device
193
+ from warp.context import unmap_cuda_device as unmap_cuda_device
194
+ from warp.context import get_device as get_device
195
+ from warp.context import set_device as set_device
196
+ from warp.context import synchronize_device as synchronize_device
197
+
198
+ from warp.context import zeros as zeros
199
+ from warp.context import zeros_like as zeros_like
200
+ from warp.context import ones as ones
201
+ from warp.context import ones_like as ones_like
202
+ from warp.context import full as full
203
+ from warp.context import full_like as full_like
204
+ from warp.context import clone as clone
205
+ from warp.context import empty as empty
206
+ from warp.context import empty_like as empty_like
207
+ from warp.context import copy as copy
208
+ from warp.context import from_numpy as from_numpy
209
+
210
+ from warp.context import launch as launch
211
+ from warp.context import launch_tiled as launch_tiled
212
+ from warp.context import synchronize as synchronize
213
+ from warp.context import compile_aot_module as compile_aot_module
214
+ from warp.context import force_load as force_load
215
+ from warp.context import load_module as load_module
216
+ from warp.context import load_aot_module as load_aot_module
217
+ from warp.context import event_from_ipc_handle as event_from_ipc_handle
218
+
219
+ from warp.context import set_module_options as set_module_options
220
+ from warp.context import get_module_options as get_module_options
221
+ from warp.context import get_module as get_module
222
+
223
+ from warp.context import capture_begin as capture_begin
224
+ from warp.context import capture_end as capture_end
225
+ from warp.context import capture_launch as capture_launch
226
+ from warp.context import capture_if as capture_if
227
+ from warp.context import capture_while as capture_while
228
+ from warp.context import capture_debug_dot_print as capture_debug_dot_print
229
+
230
+ from warp.context import Kernel as Kernel
231
+ from warp.context import Function as Function
232
+ from warp.context import Launch as Launch
233
+
234
+ from warp.context import Stream as Stream
235
+ from warp.context import get_stream as get_stream
236
+ from warp.context import set_stream as set_stream
237
+ from warp.context import wait_stream as wait_stream
238
+ from warp.context import synchronize_stream as synchronize_stream
239
+
240
+ from warp.context import Event as Event
241
+ from warp.context import record_event as record_event
242
+ from warp.context import wait_event as wait_event
243
+ from warp.context import synchronize_event as synchronize_event
244
+ from warp.context import get_event_elapsed_time as get_event_elapsed_time
245
+
246
+ from warp.context import RegisteredGLBuffer as RegisteredGLBuffer
247
+
248
+ from warp.context import is_mempool_supported as is_mempool_supported
249
+ from warp.context import is_mempool_enabled as is_mempool_enabled
250
+ from warp.context import set_mempool_enabled as set_mempool_enabled
251
+
252
+ from warp.context import set_mempool_release_threshold as set_mempool_release_threshold
253
+ from warp.context import get_mempool_release_threshold as get_mempool_release_threshold
254
+ from warp.context import get_mempool_used_mem_current as get_mempool_used_mem_current
255
+ from warp.context import get_mempool_used_mem_high as get_mempool_used_mem_high
256
+
257
+ from warp.context import is_mempool_access_supported as is_mempool_access_supported
258
+ from warp.context import is_mempool_access_enabled as is_mempool_access_enabled
259
+ from warp.context import set_mempool_access_enabled as set_mempool_access_enabled
260
+
261
+ from warp.context import is_peer_access_supported as is_peer_access_supported
262
+ from warp.context import is_peer_access_enabled as is_peer_access_enabled
263
+ from warp.context import set_peer_access_enabled as set_peer_access_enabled
264
+
265
+ from warp.tape import Tape as Tape
266
+
267
+ from warp.utils import ScopedTimer as ScopedTimer
268
+ from warp.utils import ScopedDevice as ScopedDevice
269
+ from warp.utils import ScopedStream as ScopedStream
270
+ from warp.utils import ScopedMempool as ScopedMempool
271
+ from warp.utils import ScopedMempoolAccess as ScopedMempoolAccess
272
+ from warp.utils import ScopedPeerAccess as ScopedPeerAccess
273
+ from warp.utils import ScopedCapture as ScopedCapture
274
+
275
+ from warp.utils import transform_expand as transform_expand
276
+ from warp.utils import quat_between_vectors as quat_between_vectors
277
+
278
+ from warp.utils import TimingResult as TimingResult
279
+ from warp.utils import timing_begin as timing_begin
280
+ from warp.utils import timing_end as timing_end
281
+ from warp.utils import timing_print as timing_print
282
+
283
+ from warp.utils import TIMING_KERNEL as TIMING_KERNEL
284
+ from warp.utils import TIMING_KERNEL_BUILTIN as TIMING_KERNEL_BUILTIN
285
+ from warp.utils import TIMING_MEMCPY as TIMING_MEMCPY
286
+ from warp.utils import TIMING_MEMSET as TIMING_MEMSET
287
+ from warp.utils import TIMING_GRAPH as TIMING_GRAPH
288
+ from warp.utils import TIMING_ALL as TIMING_ALL
289
+
290
+ from warp.utils import map as map
291
+
292
+ from warp.marching_cubes import MarchingCubes as MarchingCubes
293
+
294
+ from warp.torch import from_torch as from_torch
295
+ from warp.torch import to_torch as to_torch
296
+ from warp.torch import dtype_from_torch as dtype_from_torch
297
+ from warp.torch import dtype_to_torch as dtype_to_torch
298
+ from warp.torch import device_from_torch as device_from_torch
299
+ from warp.torch import device_to_torch as device_to_torch
300
+ from warp.torch import stream_from_torch as stream_from_torch
301
+ from warp.torch import stream_to_torch as stream_to_torch
302
+
303
+ from warp.jax import from_jax as from_jax
304
+ from warp.jax import to_jax as to_jax
305
+ from warp.jax import dtype_from_jax as dtype_from_jax
306
+ from warp.jax import dtype_to_jax as dtype_to_jax
307
+ from warp.jax import device_from_jax as device_from_jax
308
+ from warp.jax import device_to_jax as device_to_jax
309
+
310
+ from warp.dlpack import from_dlpack as from_dlpack
311
+ from warp.dlpack import to_dlpack as to_dlpack
312
+
313
+ from warp.paddle import from_paddle as from_paddle
314
+ from warp.paddle import to_paddle as to_paddle
315
+ from warp.paddle import dtype_from_paddle as dtype_from_paddle
316
+ from warp.paddle import dtype_to_paddle as dtype_to_paddle
317
+ from warp.paddle import device_from_paddle as device_from_paddle
318
+ from warp.paddle import device_to_paddle as device_to_paddle
319
+ from warp.paddle import stream_from_paddle as stream_from_paddle
320
+
321
+ from warp.build import clear_kernel_cache as clear_kernel_cache
322
+ from warp.build import clear_lto_cache as clear_lto_cache
144
323
 
145
324
  from warp.constants import *
146
325
 
147
326
  from . import builtins
148
- from warp.builtins import static
327
+ from warp.builtins import static as static
149
328
 
150
329
  from warp.math import *
151
330
 
152
- import warp.config as config
331
+ from . import config as config
153
332
 
154
333
  __version__ = config.version
155
334
 
@@ -924,7 +1103,7 @@ def tile_arange(*args: Scalar, dtype: Scalar, storage: str) -> Tile[Scalar, Tupl
924
1103
 
925
1104
  @over
926
1105
  def tile_load(
927
- a: Array[Any], shape: Tuple[int, ...], offset: Tuple[int, ...], storage: str
1106
+ a: Array[Any], shape: Tuple[int, ...], offset: Tuple[int, ...], storage: str, bounds_check: bool
928
1107
  ) -> Tile[Any, Tuple[int, ...]]:
929
1108
  """Loads a tile from a global memory array.
930
1109
 
@@ -935,12 +1114,80 @@ def tile_load(
935
1114
  :param offset: Offset in the source array to begin reading from (optional)
936
1115
  :param storage: The storage location for the tile: ``"register"`` for registers
937
1116
  (default) or ``"shared"`` for shared memory.
1117
+ :param bounds_check: Needed for unaligned tiles, but can disable for memory-aligned tiles for faster load times
938
1118
  :returns: A tile with shape as specified and data type the same as the source array
939
1119
  """
940
1120
  ...
941
1121
 
942
1122
  @over
943
- def tile_store(a: Array[Any], t: Tile[Any, Tuple[int, ...]], offset: Tuple[int, ...]):
1123
+ def tile_load_indexed(
1124
+ a: Array[Any],
1125
+ indices: Tile[int32, Tuple[int]],
1126
+ shape: Tuple[int, ...],
1127
+ offset: Tuple[int, ...],
1128
+ axis: int32,
1129
+ storage: str,
1130
+ ) -> Tile[Any, Tuple[int, ...]]:
1131
+ """Loads a tile from a global memory array, with loads along a specified axis mapped according to a 1D tile of indices.
1132
+
1133
+ :param a: The source array in global memory
1134
+ :param indices: A 1D tile of integer indices mapping to elements in ``a``.
1135
+ :param shape: Shape of the tile to load, must have the same number of dimensions as ``a``, and along ``axis``, it must have the same number of elements as the ``indices`` tile.
1136
+ :param offset: Offset in the source array to begin reading from (optional)
1137
+ :param axis: Axis of ``a`` that indices refer to
1138
+ :param storage: The storage location for the tile: ``"register"`` for registers (default) or ``"shared"`` for shared memory.
1139
+ :returns: A tile with shape as specified and data type the same as the source array
1140
+
1141
+ This example shows how to select and store the even indexed rows from a 2D array:
1142
+
1143
+ .. code-block:: python
1144
+
1145
+ TILE_M = wp.constant(2)
1146
+ TILE_N = wp.constant(2)
1147
+ HALF_M = wp.constant(TILE_M // 2)
1148
+ HALF_N = wp.constant(TILE_N // 2)
1149
+
1150
+ @wp.kernel
1151
+ def compute(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
1152
+ i, j = wp.tid()
1153
+
1154
+ evens = wp.tile_arange(HALF_M, dtype=int, storage="shared") * 2
1155
+
1156
+ t0 = wp.tile_load_indexed(
1157
+ x, indices=evens, shape=(HALF_M, TILE_N), offset=(i * TILE_M, j * TILE_N), axis=0, storage="register"
1158
+ )
1159
+ wp.tile_store(y, t0, offset=(i * HALF_M, j * TILE_N))
1160
+
1161
+ M = TILE_M * 2
1162
+ N = TILE_N * 2
1163
+
1164
+ arr = np.arange(M * N).reshape(M, N)
1165
+
1166
+ x = wp.array(arr, dtype=float)
1167
+ y = wp.zeros((M // 2, N), dtype=float)
1168
+
1169
+ wp.launch_tiled(compute, dim=[2, 2], inputs=[x], outputs=[y], block_dim=32, device=device)
1170
+
1171
+ print(x.numpy())
1172
+ print(y.numpy())
1173
+
1174
+ Prints:
1175
+
1176
+ .. code-block:: text
1177
+
1178
+ [[ 0. 1. 2. 3.]
1179
+ [ 4. 5. 6. 7.]
1180
+ [ 8. 9. 10. 11.]
1181
+ [12. 13. 14. 15.]]
1182
+
1183
+ [[ 0. 1. 2. 3.]
1184
+ [ 8. 9. 10. 11.]]
1185
+
1186
+ """
1187
+ ...
1188
+
1189
+ @over
1190
+ def tile_store(a: Array[Any], t: Tile[Any, Tuple[int, ...]], offset: Tuple[int, ...], bounds_check: bool):
944
1191
  """Store a tile to a global memory array.
945
1192
 
946
1193
  This method will cooperatively store a tile to global memory using all threads in the block.
@@ -948,22 +1195,147 @@ def tile_store(a: Array[Any], t: Tile[Any, Tuple[int, ...]], offset: Tuple[int,
948
1195
  :param a: The destination array in global memory
949
1196
  :param t: The source tile to store data from, must have the same data type and number of dimensions as the destination array
950
1197
  :param offset: Offset in the destination array (optional)
1198
+ :param bounds_check: Required for unaligned tiles; can be disabled for memory-aligned tiles for faster writes
1199
+
1200
+ """
1201
+ ...
1202
+
1203
+ @over
1204
+ def tile_store_indexed(
1205
+ a: Array[Any], indices: Tile[int32, Tuple[int]], t: Tile[Any, Tuple[int, ...]], offset: Tuple[int, ...], axis: int32
1206
+ ):
1207
+ """Store a tile to a global memory array, with storage along a specified axis mapped according to a 1D tile of indices.
1208
+
1209
+ :param a: The destination array in global memory
1210
+ :param indices: A 1D tile of integer indices mapping to elements in ``a``.
1211
+ :param t: The source tile to store data from, must have the same data type and number of dimensions as the destination array, and along ``axis``, it must have the same number of elements as the ``indices`` tile.
1212
+ :param offset: Offset in the destination array (optional)
1213
+ :param axis: Axis of ``a`` that indices refer to
1214
+
1215
+ This example shows how to map tile rows to the even rows of a 2D array:
1216
+
1217
+ .. code-block:: python
1218
+
1219
+ TILE_M = wp.constant(2)
1220
+ TILE_N = wp.constant(2)
1221
+ TWO_M = wp.constant(TILE_M * 2)
1222
+ TWO_N = wp.constant(TILE_N * 2)
1223
+
1224
+ @wp.kernel
1225
+ def compute(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
1226
+ i, j = wp.tid()
1227
+
1228
+ t = wp.tile_load(x, shape=(TILE_M, TILE_N), offset=(i * TILE_M, j * TILE_N), storage="register")
1229
+
1230
+ evens_M = wp.tile_arange(TILE_M, dtype=int, storage="shared") * 2
1231
+
1232
+ wp.tile_store_indexed(y, indices=evens_M, t=t, offset=(i * TWO_M, j * TILE_N), axis=0)
1233
+
1234
+ M = TILE_M * 2
1235
+ N = TILE_N * 2
1236
+
1237
+ arr = np.arange(M * N, dtype=float).reshape(M, N)
1238
+
1239
+ x = wp.array(arr, dtype=float, requires_grad=True, device=device)
1240
+ y = wp.zeros((M * 2, N), dtype=float, requires_grad=True, device=device)
1241
+
1242
+ wp.launch_tiled(compute, dim=[2, 2], inputs=[x], outputs=[y], block_dim=32, device=device)
1243
+
1244
+ print(x.numpy())
1245
+ print(y.numpy())
1246
+
1247
+ Prints:
1248
+
1249
+ .. code-block:: text
1250
+
1251
+ [[ 0. 1. 2. 3.]
1252
+ [ 4. 5. 6. 7.]
1253
+ [ 8. 9. 10. 11.]
1254
+ [12. 13. 14. 15.]]
1255
+
1256
+ [[ 0. 1. 2. 3.]
1257
+ [ 0. 0. 0. 0.]
1258
+ [ 4. 5. 6. 7.]
1259
+ [ 0. 0. 0. 0.]
1260
+ [ 8. 9. 10. 11.]
1261
+ [ 0. 0. 0. 0.]
1262
+ [12. 13. 14. 15.]
1263
+ [ 0. 0. 0. 0.]]
1264
+
951
1265
  """
952
1266
  ...
953
1267
 
954
1268
  @over
955
1269
  def tile_atomic_add(
956
- a: Array[Any], t: Tile[Any, Tuple[int, ...]], offset: Tuple[int, ...]
1270
+ a: Array[Any], t: Tile[Any, Tuple[int, ...]], offset: Tuple[int, ...], bounds_check: bool
957
1271
  ) -> Tile[Any, Tuple[int, ...]]:
958
1272
  """Atomically add a tile onto the array `a`, each element will be updated atomically.
959
1273
 
960
1274
  :param a: Array in global memory, should have the same ``dtype`` as the input tile
961
1275
  :param t: Source tile to add to the destination array
962
1276
  :param offset: Offset in the destination array (optional)
1277
+ :param bounds_check: Required for unaligned tiles; can be disabled for memory-aligned tiles for faster writes
963
1278
  :returns: A tile with the same dimensions and data type as the source tile, holding the original value of the destination elements
964
1279
  """
965
1280
  ...
966
1281
 
1282
+ @over
1283
+ def tile_atomic_add_indexed(
1284
+ a: Array[Any], indices: Tile[int32, Tuple[int]], t: Tile[Any, Tuple[int, ...]], offset: Tuple[int, ...], axis: int32
1285
+ ) -> Tile[Any, Tuple[int, ...]]:
1286
+ """Atomically add a tile to a global memory array, with storage along a specified axis mapped according to a 1D tile of indices.
1287
+
1288
+ :param a: The destination array in global memory
1289
+ :param indices: A 1D tile of integer indices mapping to elements in ``a``.
1290
+ :param t: The source tile to extract data from, must have the same data type and number of dimensions as the destination array, and along ``axis``, it must have the same number of elements as the ``indices`` tile.
1291
+ :param offset: Offset in the destination array (optional)
1292
+ :param axis: Axis of ``a`` that indices refer to
1293
+
1294
+ This example shows how to compute a blocked reduction, summing each block of ``TILE_M`` rows into a single output row:
1295
+
1296
+ .. code-block:: python
1297
+
1298
+ TILE_M = wp.constant(2)
1299
+ TILE_N = wp.constant(2)
1300
+
1301
+ @wp.kernel
1302
+ def tile_atomic_add_indexed(x: wp.array2d(dtype=float), y: wp.array2d(dtype=float)):
1303
+ i, j = wp.tid()
1304
+
1305
+ t = wp.tile_load(x, shape=(TILE_M, TILE_N), offset=(i * TILE_M, j * TILE_N), storage="register")
1306
+
1307
+ zeros = wp.tile_zeros(TILE_M, dtype=int, storage="shared")
1308
+
1309
+ wp.tile_atomic_add_indexed(y, indices=zeros, t=t, offset=(i, j * TILE_N), axis=0)
1310
+
1311
+ M = TILE_M * 2
1312
+ N = TILE_N * 2
1313
+
1314
+ arr = np.arange(M * N, dtype=float).reshape(M, N)
1315
+
1316
+ x = wp.array(arr, dtype=float, requires_grad=True, device=device)
1317
+ y = wp.zeros((2, N), dtype=float, requires_grad=True, device=device)
1318
+
1319
+ wp.launch_tiled(tile_atomic_add_indexed, dim=[2, 2], inputs=[x], outputs=[y], block_dim=32, device=device)
1320
+
1321
+ print(x.numpy())
1322
+ print(y.numpy())
1323
+
1324
+ Prints:
1325
+
1326
+ .. code-block:: text
1327
+
1328
+ [[ 0. 1. 2. 3.]
1329
+ [ 4. 5. 6. 7.]
1330
+ [ 8. 9. 10. 11.]
1331
+ [12. 13. 14. 15.]]
1332
+
1333
+ [[ 4. 6. 8. 10.]
1334
+ [20. 22. 24. 26.]]
1335
+
1336
+ """
1337
+ ...
1338
+
967
1339
  @over
968
1340
  def tile_view(
969
1341
  t: Tile[Any, Tuple[int, ...]], offset: Tuple[int, ...], shape: Tuple[int, ...]
@@ -1370,7 +1742,7 @@ def tile_map(op: Callable, a: Tile[Scalar, Tuple[int, ...]]) -> Tile[Scalar, Tup
1370
1742
 
1371
1743
  :param op: A callable function that accepts one argument and returns one argument, may be a user function or builtin
1372
1744
  :param a: The input tile, the operator (or one of its overloads) must be able to accept the tile's data type
1373
- :returns: A tile with the same dimensions and data type as the input tile.
1745
+ :returns: A tile with the same dimensions as the input tile. Its datatype is determined by the return type of ``op``.
1374
1746
 
1375
1747
  Example:
1376
1748
 
@@ -1401,12 +1773,12 @@ def tile_map(
1401
1773
  """Apply a binary function onto the tile.
1402
1774
 
1403
1775
  This function cooperatively applies a binary function to each element of the tiles using all threads in the block.
1404
- Both input tiles must have the same dimensions and datatype.
1776
+ Both input tiles must have the same dimensions, and if using a builtin op, the same datatypes.
1405
1777
 
1406
1778
  :param op: A callable function that accepts two arguments and returns one argument, may be a user function or builtin; builtin ops require both arguments to have the same type
1407
1779
  :param a: The first input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
1408
1780
  :param b: The second input tile, the operator (or one of its overloads) must be able to accept the tile's dtype
1409
- :returns: A tile with the same dimensions and datatype as the input tiles.
1781
+ :returns: A tile with the same dimensions as the input tiles. Its datatype is specified by the return type of op
1410
1782
 
1411
1783
  Example:
1412
1784
 
@@ -2971,7 +3343,7 @@ def mod(a: Scalar, b: Scalar) -> Scalar:
2971
3343
  ...
2972
3344
 
2973
3345
  @over
2974
- def mod(a: Vector[Any, Scalar], b: Vector[Any, Scalar]) -> Scalar:
3346
+ def mod(a: Vector[Any, Scalar], b: Vector[Any, Scalar]) -> Vector[Any, Scalar]:
2975
3347
  """Modulo operation using truncated division."""
2976
3348
  ...
2977
3349
 
Binary file
warp/bin/libwarp.dylib CHANGED
Binary file