gstaichi 0.1.25.dev0__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi/CHANGELOG.md +9 -0
- gstaichi/__init__.py +40 -0
- gstaichi/__main__.py +5 -0
- gstaichi/_funcs.py +706 -0
- gstaichi/_kernels.py +420 -0
- gstaichi/_lib/__init__.py +3 -0
- gstaichi/_lib/core/__init__.py +0 -0
- gstaichi/_lib/core/gstaichi_python.cp311-win_amd64.pyd +0 -0
- gstaichi/_lib/core/gstaichi_python.pyi +2937 -0
- gstaichi/_lib/core/py.typed +0 -0
- gstaichi/_lib/runtime/runtime_cuda.bc +0 -0
- gstaichi/_lib/runtime/runtime_x64.bc +0 -0
- gstaichi/_lib/runtime/slim_libdevice.10.bc +0 -0
- gstaichi/_lib/utils.py +249 -0
- gstaichi/_logging.py +131 -0
- gstaichi/_main.py +545 -0
- gstaichi/_snode/__init__.py +5 -0
- gstaichi/_snode/fields_builder.py +187 -0
- gstaichi/_snode/snode_tree.py +34 -0
- gstaichi/_test_tools/__init__.py +0 -0
- gstaichi/_test_tools/load_kernel_string.py +30 -0
- gstaichi/_version.py +1 -0
- gstaichi/_version_check.py +103 -0
- gstaichi/ad/__init__.py +3 -0
- gstaichi/ad/_ad.py +530 -0
- gstaichi/algorithms/__init__.py +3 -0
- gstaichi/algorithms/_algorithms.py +117 -0
- gstaichi/assets/.git +1 -0
- gstaichi/assets/Go-Regular.ttf +0 -0
- gstaichi/assets/static/imgs/ti_gallery.png +0 -0
- gstaichi/examples/minimal.py +28 -0
- gstaichi/experimental.py +16 -0
- gstaichi/lang/__init__.py +50 -0
- gstaichi/lang/_ndarray.py +352 -0
- gstaichi/lang/_ndrange.py +152 -0
- gstaichi/lang/_template_mapper.py +199 -0
- gstaichi/lang/_texture.py +172 -0
- gstaichi/lang/_wrap_inspect.py +189 -0
- gstaichi/lang/any_array.py +99 -0
- gstaichi/lang/argpack.py +411 -0
- gstaichi/lang/ast/__init__.py +5 -0
- gstaichi/lang/ast/ast_transformer.py +1318 -0
- gstaichi/lang/ast/ast_transformer_utils.py +341 -0
- gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
- gstaichi/lang/ast/ast_transformers/call_transformer.py +267 -0
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +320 -0
- gstaichi/lang/ast/checkers.py +106 -0
- gstaichi/lang/ast/symbol_resolver.py +57 -0
- gstaichi/lang/ast/transform.py +9 -0
- gstaichi/lang/common_ops.py +310 -0
- gstaichi/lang/exception.py +80 -0
- gstaichi/lang/expr.py +180 -0
- gstaichi/lang/field.py +466 -0
- gstaichi/lang/impl.py +1241 -0
- gstaichi/lang/kernel_arguments.py +157 -0
- gstaichi/lang/kernel_impl.py +1382 -0
- gstaichi/lang/matrix.py +1881 -0
- gstaichi/lang/matrix_ops.py +341 -0
- gstaichi/lang/matrix_ops_utils.py +190 -0
- gstaichi/lang/mesh.py +687 -0
- gstaichi/lang/misc.py +778 -0
- gstaichi/lang/ops.py +1494 -0
- gstaichi/lang/runtime_ops.py +13 -0
- gstaichi/lang/shell.py +35 -0
- gstaichi/lang/simt/__init__.py +5 -0
- gstaichi/lang/simt/block.py +94 -0
- gstaichi/lang/simt/grid.py +7 -0
- gstaichi/lang/simt/subgroup.py +191 -0
- gstaichi/lang/simt/warp.py +96 -0
- gstaichi/lang/snode.py +489 -0
- gstaichi/lang/source_builder.py +150 -0
- gstaichi/lang/struct.py +855 -0
- gstaichi/lang/util.py +381 -0
- gstaichi/linalg/__init__.py +8 -0
- gstaichi/linalg/matrixfree_cg.py +310 -0
- gstaichi/linalg/sparse_cg.py +59 -0
- gstaichi/linalg/sparse_matrix.py +303 -0
- gstaichi/linalg/sparse_solver.py +123 -0
- gstaichi/math/__init__.py +11 -0
- gstaichi/math/_complex.py +205 -0
- gstaichi/math/mathimpl.py +886 -0
- gstaichi/profiler/__init__.py +6 -0
- gstaichi/profiler/kernel_metrics.py +260 -0
- gstaichi/profiler/kernel_profiler.py +586 -0
- gstaichi/profiler/memory_profiler.py +15 -0
- gstaichi/profiler/scoped_profiler.py +36 -0
- gstaichi/sparse/__init__.py +3 -0
- gstaichi/sparse/_sparse_grid.py +77 -0
- gstaichi/tools/__init__.py +12 -0
- gstaichi/tools/diagnose.py +117 -0
- gstaichi/tools/np2ply.py +364 -0
- gstaichi/tools/vtk.py +38 -0
- gstaichi/types/__init__.py +19 -0
- gstaichi/types/annotations.py +47 -0
- gstaichi/types/compound_types.py +90 -0
- gstaichi/types/enums.py +49 -0
- gstaichi/types/ndarray_type.py +147 -0
- gstaichi/types/primitive_types.py +206 -0
- gstaichi/types/quant.py +88 -0
- gstaichi/types/texture_type.py +85 -0
- gstaichi/types/utils.py +13 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +5 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +29 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +113 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +5 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +19 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +122 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +5 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +19 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +122 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +5 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +19 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +122 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +5 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +19 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +122 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +5 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +19 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +122 -0
- gstaichi-0.1.25.dev0.data/data/bin/SPIRV-Tools-shared.dll +0 -0
- gstaichi-0.1.25.dev0.data/data/include/spirv-tools/instrument.hpp +268 -0
- gstaichi-0.1.25.dev0.data/data/include/spirv-tools/libspirv.h +907 -0
- gstaichi-0.1.25.dev0.data/data/include/spirv-tools/libspirv.hpp +375 -0
- gstaichi-0.1.25.dev0.data/data/include/spirv-tools/linker.hpp +97 -0
- gstaichi-0.1.25.dev0.data/data/include/spirv-tools/optimizer.hpp +970 -0
- gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-diff.lib +0 -0
- gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-link.lib +0 -0
- gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-lint.lib +0 -0
- gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-opt.lib +0 -0
- gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-reduce.lib +0 -0
- gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-shared.lib +0 -0
- gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools.lib +0 -0
- gstaichi-0.1.25.dev0.dist-info/METADATA +105 -0
- gstaichi-0.1.25.dev0.dist-info/RECORD +138 -0
- gstaichi-0.1.25.dev0.dist-info/WHEEL +5 -0
- gstaichi-0.1.25.dev0.dist-info/entry_points.txt +2 -0
- gstaichi-0.1.25.dev0.dist-info/licenses/LICENSE +201 -0
- gstaichi-0.1.25.dev0.dist-info/top_level.txt +1 -0
gstaichi/_kernels.py
ADDED
@@ -0,0 +1,420 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
from gstaichi._funcs import field_fill_gstaichi_scope
|
4
|
+
from gstaichi._lib.utils import get_os_name
|
5
|
+
from gstaichi.lang import ops
|
6
|
+
from gstaichi.lang._ndrange import ndrange
|
7
|
+
from gstaichi.lang.expr import Expr
|
8
|
+
from gstaichi.lang.field import ScalarField
|
9
|
+
from gstaichi.lang.impl import grouped, static, static_assert
|
10
|
+
from gstaichi.lang.kernel_impl import func, kernel
|
11
|
+
from gstaichi.lang.misc import loop_config
|
12
|
+
from gstaichi.lang.simt import block, warp
|
13
|
+
from gstaichi.lang.snode import deactivate
|
14
|
+
from gstaichi.math import vec3
|
15
|
+
from gstaichi.types import ndarray_type, texture_type, vector
|
16
|
+
from gstaichi.types.annotations import template
|
17
|
+
from gstaichi.types.enums import Format
|
18
|
+
from gstaichi.types.primitive_types import f16, f32, f64, i32, u8
|
19
|
+
|
20
|
+
|
21
|
+
# A set of helper (meta)functions
|
22
|
+
@kernel
def fill_field(field: template(), val: template()):
    """Set every element visited by ``grouped(field)`` to ``val``.

    ``val`` is converted to ``field.dtype`` once, before the loop, and the
    converted value is written to each index.

    Args:
        field: Destination field (compile-time template argument).
        val: Fill value (compile-time template argument).
    """
    value = ops.cast(val, field.dtype)
    # ``I`` is a grouped index vector covering all of ``field``'s dimensions.
    for I in grouped(field):
        field[I] = value
|
27
|
+
|
28
|
+
|
29
|
+
@kernel
def fill_ndarray(ndarray: ndarray_type.ndarray(), val: template()):
    """Assign ``val`` to every element of ``ndarray``.

    Unlike ``fill_field``, no explicit ``ops.cast`` is applied; ``val`` is
    assigned as-is.

    Args:
        ndarray: Destination ndarray.
        val: Fill value (compile-time template argument).
    """
    for I in grouped(ndarray):
        ndarray[I] = val
|
33
|
+
|
34
|
+
|
35
|
+
@kernel
def fill_ndarray_matrix(ndarray: ndarray_type.ndarray(), val: template()):
    """Assign ``val`` to every element of an ndarray of vectors/matrices.

    The body is identical to ``fill_ndarray``.
    NOTE(review): presumably kept as a separate kernel so matrix-element
    ndarrays get their own compiled instance; confirm whether it is still
    needed.

    Args:
        ndarray: Destination ndarray (matrix/vector elements assumed).
        val: Fill value (compile-time template argument).
    """
    for I in grouped(ndarray):
        ndarray[I] = val
|
39
|
+
|
40
|
+
|
41
|
+
@kernel
def tensor_to_ext_arr(tensor: template(), arr: ndarray_type.ndarray()):
    """Copy a scalar field into an external array indexed from zero.

    The field's index offset is subtracted, so ``arr[I - offset]`` receives
    ``tensor[I]``.

    Args:
        tensor: Source field (compile-time template argument).
        arr: Destination external array with the same shape as ``tensor``.
    """
    # An empty snode offset means "no offset"; substitute one zero per
    # dimension so ``I - offset`` is well-defined.
    offset = static(tensor.snode.ptr.offset if len(tensor.snode.ptr.offset) != 0 else [0] * len(tensor.shape))

    for I in grouped(tensor):
        arr[I - offset] = tensor[I]
|
48
|
+
|
49
|
+
|
50
|
+
@kernel
def ndarray_to_ext_arr(ndarray: ndarray_type.ndarray(), arr: ndarray_type.ndarray()):
    """Copy ``ndarray`` element-wise into ``arr`` using identical indices.

    Args:
        ndarray: Source ndarray; its index space drives the loop.
        arr: Destination external array.
    """
    for I in grouped(ndarray):
        arr[I] = ndarray[I]
|
54
|
+
|
55
|
+
|
56
|
+
@kernel
def ndarray_matrix_to_ext_arr(
    ndarray: ndarray_type.ndarray(),
    arr: ndarray_type.ndarray(),
    layout_is_aos: template(),
    as_vector: template(),
):
    """Scatter an ndarray of vectors/matrices into a scalar external array.

    Args:
        ndarray: Source ndarray whose elements have ``n`` rows (and ``m``
            columns unless ``as_vector``).
        arr: Destination external array.
        layout_is_aos: If true, the element axes come after the ndarray axes
            (``arr[I, p(, q)]``); otherwise they come first
            (``arr[p(, q), I]``).
        as_vector: If true, elements are indexed as vectors (``[p]`` only);
            otherwise as matrices (``[p, q]``).
    """
    for I in grouped(ndarray):
        # Element loops are unrolled at compile time via ``static``.
        for p in static(range(ndarray[I].n)):
            if static(as_vector):
                if static(layout_is_aos):
                    arr[I, p] = ndarray[I][p]
                else:
                    arr[p, I] = ndarray[I][p]
            else:
                for q in static(range(ndarray[I].m)):
                    if static(layout_is_aos):
                        arr[I, p, q] = ndarray[I][p, q]
                    else:
                        arr[p, q, I] = ndarray[I][p, q]
|
76
|
+
|
77
|
+
|
78
|
+
@kernel
def vector_to_fast_image(img: template(), out: ndarray_type.ndarray()):
    """Pack a 2-D color field into a flat buffer of 32-bit pixels.

    Pixel ``(i, j)`` is written to ``out[j * img.shape[0] + i]``, taking its
    color from ``img`` with the second axis flipped. Floating-point colors
    (f16/f32/f64) are scaled by 255, truncated to int, and clamped to
    [0, 255]. u8 colors are used directly. Only the first three channels
    are used.

    Args:
        img: 2-D field of color vectors (static_assert enforces 2-D; dtype
            must be f16/f32/f64 or u8).
        out: Flat destination buffer of packed pixels.
    """
    static_assert(len(img.shape) == 2)
    offset = static(img.snode.ptr.offset if len(img.snode.ptr.offset) != 0 else [0, 0])
    i_offset = static(offset[0])
    j_offset = static(offset[1])
    # FIXME: Why is ``for i, j in img:`` slower than:
    for i, j in ndrange(*img.shape):
        # Pre-declared so the values assigned inside the static branches
        # below stay visible afterwards (presumably because branch bodies
        # get their own variable scope).
        r, g, b = 0, 0, 0
        # Read with the second axis reversed: j == 0 maps to the last column.
        color = img[i + i_offset, (img.shape[1] + j_offset) - 1 - j]
        if static(img.dtype in [f16, f32, f64]):
            r, g, b = ops.min(255, ops.max(0, int(color * 255)))[:3]
        else:
            static_assert(img.dtype == u8)
            r, g, b = color[:3]

        idx = j * img.shape[0] + i
        # We use i32 for |out| since Metal doesn't support u8 types
        if static(get_os_name() != "osx"):
            # Packed as 0x00RRGGBB.
            out[idx] = (r << 16) + (g << 8) + b
        else:
            # On macOS the alpha byte must be 0xff and the channels are packed
            # in the opposite order (B in bits 16-23, R in bits 0-7); the
            # original author attributed this to the GUI's byte order.
            # 0xff000000 does not fit in a signed i32, so it is written as its
            # two's-complement value -16777216 (the original note reports that
            # adding 0xff000000 directly does not work).
            alpha = -16777216
            out[idx] = (b << 16) + (g << 8) + r + alpha
|
107
|
+
|
108
|
+
|
109
|
+
@kernel
def tensor_to_image(tensor: template(), arr: ndarray_type.ndarray()):
    """Expand a scalar field into a 3-channel (grey) f32 image array.

    Each value is cast to f32 and written to channels 0, 1 and 2 of
    ``arr[I - offset]``.

    Args:
        tensor: Source scalar field (compile-time template argument).
        arr: Destination array with a trailing channel axis of at least 3.
    """
    # An empty snode offset means "no offset"; substitute one zero per dimension.
    offset = static(tensor.snode.ptr.offset if len(tensor.snode.ptr.offset) != 0 else [0] * len(tensor.shape))
    for I in grouped(tensor):
        t = ops.cast(tensor[I], f32)
        arr[I - offset, 0] = t
        arr[I - offset, 1] = t
        arr[I - offset, 2] = t
|
118
|
+
|
119
|
+
|
120
|
+
@kernel
def vector_to_image(mat: template(), arr: ndarray_type.ndarray()):
    """Copy a vector field into an f32 image array, one channel per component.

    Component ``p`` of ``mat[I]`` goes to ``arr[I - offset, p]``. When the
    vector has at most 2 components, channel 2 is set to 0.
    NOTE(review): for ``mat.n == 1`` channel 1 is not written by this
    kernel; confirm callers do not rely on it.

    Args:
        mat: Source vector field (compile-time template argument).
        arr: Destination array with a trailing channel axis.
    """
    # An empty snode offset means "no offset"; substitute one zero per dimension.
    offset = static(mat.snode.ptr.offset if len(mat.snode.ptr.offset) != 0 else [0] * len(mat.shape))
    for I in grouped(mat):
        for p in static(range(mat.n)):
            arr[I - offset, p] = ops.cast(mat[I][p], f32)
            if static(mat.n <= 2):
                arr[I - offset, 2] = 0
|
129
|
+
|
130
|
+
|
131
|
+
@kernel
def tensor_to_tensor(tensor: template(), other: template()):
    """Copy field ``other`` into field ``tensor`` (destination first).

    Both fields must have the same shape (checked at compile time). Each
    field's own index offset is applied. The loop runs over the full
    ``ndrange`` of the shape rather than over either field's indices.

    Args:
        tensor: Destination field.
        other: Source field.
    """
    static_assert(tensor.shape == other.shape)
    shape = static(tensor.shape)
    tensor_offset = static(tensor.snode.ptr.offset if len(tensor.snode.ptr.offset) != 0 else [0] * len(shape))
    other_offset = static(other.snode.ptr.offset if len(other.snode.ptr.offset) != 0 else [0] * len(shape))

    for I in grouped(ndrange(*shape)):
        tensor[I + tensor_offset] = other[I + other_offset]
|
140
|
+
|
141
|
+
|
142
|
+
@kernel
def ext_arr_to_tensor(arr: ndarray_type.ndarray(), tensor: template()):
    """Copy a zero-based external array into a (possibly offset) field.

    ``tensor[I]`` receives ``arr[I - offset]``.

    Args:
        arr: Source external array with the same shape as ``tensor``.
        tensor: Destination field (compile-time template argument).
    """
    # An empty snode offset means "no offset"; substitute one zero per dimension.
    offset = static(tensor.snode.ptr.offset if len(tensor.snode.ptr.offset) != 0 else [0] * len(tensor.shape))
    for I in grouped(tensor):
        tensor[I] = arr[I - offset]
|
148
|
+
|
149
|
+
|
150
|
+
@kernel
def ndarray_to_ndarray(ndarray: ndarray_type.ndarray(), other: ndarray_type.ndarray()):
    """Copy ``other`` into ``ndarray`` element-wise (destination first).

    Args:
        ndarray: Destination ndarray; its index space drives the loop.
        other: Source ndarray.
    """
    for I in grouped(ndarray):
        ndarray[I] = other[I]
|
154
|
+
|
155
|
+
|
156
|
+
@kernel
def ext_arr_to_ndarray(arr: ndarray_type.ndarray(), ndarray: ndarray_type.ndarray()):
    """Copy external array ``arr`` into ``ndarray`` using identical indices.

    Args:
        arr: Source external array.
        ndarray: Destination ndarray; its index space drives the loop.
    """
    for I in grouped(ndarray):
        ndarray[I] = arr[I]
|
160
|
+
|
161
|
+
|
162
|
+
@kernel
def ext_arr_to_ndarray_matrix(
    arr: ndarray_type.ndarray(),
    ndarray: ndarray_type.ndarray(),
    layout_is_aos: template(),
    as_vector: template(),
):
    """Gather a scalar external array into an ndarray of vectors/matrices.

    This is the inverse of ``ndarray_matrix_to_ext_arr``.

    Args:
        arr: Source scalar external array.
        ndarray: Destination ndarray whose elements have ``n`` rows (and
            ``m`` columns unless ``as_vector``).
        layout_is_aos: If true, the element axes come after the ndarray axes
            (``arr[I, p(, q)]``); otherwise they come first
            (``arr[p(, q), I]``).
        as_vector: If true, elements are written as vectors (``[p]`` only);
            otherwise as matrices (``[p, q]``).
    """
    for I in grouped(ndarray):
        # Element loops are unrolled at compile time via ``static``.
        for p in static(range(ndarray[I].n)):
            if static(as_vector):
                if static(layout_is_aos):
                    ndarray[I][p] = arr[I, p]
                else:
                    ndarray[I][p] = arr[p, I]
            else:
                for q in static(range(ndarray[I].m)):
                    if static(layout_is_aos):
                        ndarray[I][p, q] = arr[I, p, q]
                    else:
                        ndarray[I][p, q] = arr[p, q, I]
|
182
|
+
|
183
|
+
|
184
|
+
@kernel
def matrix_to_ext_arr(mat: template(), arr: ndarray_type.ndarray(), as_vector: template()):
    """Copy a vector/matrix field into a zero-based scalar external array.

    Element axes follow the field axes (``arr[I - offset, p(, q)]``). Fields
    without an ``ndim`` attribute are treated as matrices (``ndim`` 2).

    NOTE(review): the ``q`` loop always runs ``mat.m`` times. With
    ``as_vector`` the write ``arr[I - offset, p]`` repeats for every ``q``,
    so for ``ndim == 2`` only column ``m - 1`` survives. Presumably
    ``as_vector`` is only used with single-column data; confirm.

    Args:
        mat: Source field of vectors/matrices (compile-time template argument).
        arr: Destination external array.
        as_vector: If true, write ``arr[..., p]``; otherwise ``arr[..., p, q]``.
    """
    # An empty snode offset means "no offset"; substitute one zero per dimension.
    offset = static(mat.snode.ptr.offset if len(mat.snode.ptr.offset) != 0 else [0] * len(mat.shape))

    for I in grouped(mat):
        for p in static(range(mat.n)):
            for q in static(range(mat.m)):
                if static(as_vector):
                    if static(getattr(mat, "ndim", 2) == 1):
                        arr[I - offset, p] = mat[I][p]
                    else:
                        arr[I - offset, p] = mat[I][p, q]
                else:
                    if static(getattr(mat, "ndim", 2) == 1):
                        arr[I - offset, p, q] = mat[I][p]
                    else:
                        arr[I - offset, p, q] = mat[I][p, q]
|
202
|
+
|
203
|
+
|
204
|
+
@kernel
def ext_arr_to_matrix(arr: ndarray_type.ndarray(), mat: template(), as_vector: template()):
    """Copy a zero-based scalar external array into a vector/matrix field.

    This is the inverse of ``matrix_to_ext_arr``. Fields without an ``ndim``
    attribute are treated as matrices (``ndim`` 2).

    NOTE(review): two combinations write the same target once per ``q``:
    - ``ndim == 2`` with ``as_vector``: ``arr[I - offset, p]`` is broadcast
      to every column of row ``p``.
    - ``ndim == 1`` without ``as_vector``: ``mat[I][p]`` keeps
      ``arr[..., p, m - 1]``.
    Confirm this is intended.

    Args:
        arr: Source external array.
        mat: Destination field of vectors/matrices (compile-time template
            argument).
        as_vector: If true, read ``arr[..., p]``; otherwise ``arr[..., p, q]``.
    """
    # An empty snode offset means "no offset"; substitute one zero per dimension.
    offset = static(mat.snode.ptr.offset if len(mat.snode.ptr.offset) != 0 else [0] * len(mat.shape))

    for I in grouped(mat):
        for p in static(range(mat.n)):
            for q in static(range(mat.m)):
                if static(getattr(mat, "ndim", 2) == 1):
                    if static(as_vector):
                        mat[I][p] = arr[I - offset, p]
                    else:
                        mat[I][p] = arr[I - offset, p, q]
                else:
                    if static(as_vector):
                        mat[I][p, q] = arr[I - offset, p]
                    else:
                        mat[I][p, q] = arr[I - offset, p, q]
|
222
|
+
|
223
|
+
|
224
|
+
# Convert an ndarray stored in the raw Vulkan memory layout to the normal
# layout. ``vk_arr`` is read as a flat buffer: rows of ``w`` elements are stored
# one after another, with the row order flipped along the height axis. Element
# ``normal_arr[i, j]`` (``i`` along the width = shape[0], ``j`` along the
# height = shape[1]) is therefore read from flat index ``(h - 1 - j) * w + i``.
@kernel
def arr_vulkan_layout_to_arr_normal_layout(vk_arr: ndarray_type.ndarray(), normal_arr: ndarray_type.ndarray()):
    """Un-flip and reshape a flat Vulkan-layout buffer into a 2-D ndarray.

    Args:
        vk_arr: Flat source buffer of at least ``w * h`` elements.
        normal_arr: 2-D destination ndarray of shape ``(w, h)``.
    """
    static_assert(len(normal_arr.shape) == 2)
    w = normal_arr.shape[0]
    h = normal_arr.shape[1]
    for i, j in ndrange(w, h):
        normal_arr[i, j] = vk_arr[(h - 1 - j) * w + i]
|
236
|
+
|
237
|
+
|
238
|
+
# Convert an ndarray stored in the raw Vulkan memory layout into a gstaichi
# field with the normal layout (same flat-index mapping as
# ``arr_vulkan_layout_to_arr_normal_layout``, plus the field's index offset).
@kernel
def arr_vulkan_layout_to_field_normal_layout(vk_arr: ndarray_type.ndarray(), normal_field: template()):
    """Un-flip and reshape a flat Vulkan-layout buffer into a 2-D field.

    Logical element ``(i, j)`` is read from ``vk_arr[(h - 1 - j) * w + i]``
    and stored at ``normal_field[i + i_offset, j + j_offset]``.

    Args:
        vk_arr: Flat source buffer of at least ``w * h`` elements.
        normal_field: 2-D destination field; its shape gives ``(w, h)``.
    """
    static_assert(len(normal_field.shape) == 2)
    w = static(normal_field.shape[0])
    h = static(normal_field.shape[1])
    offset = static(normal_field.snode.ptr.offset if len(normal_field.snode.ptr.offset) != 0 else [0, 0])
    i_offset = static(offset[0])
    j_offset = static(offset[1])

    for i, j in ndrange(w, h):
        normal_field[i + i_offset, j + j_offset] = vk_arr[(h - 1 - j) * w + i]
|
251
|
+
|
252
|
+
|
253
|
+
@kernel
def clear_gradients(_vars: template()):
    """Zero every variable in ``_vars``, using the indices of ``_vars[0]``.

    Each entry is wrapped as ``ScalarField(Expr(...))`` and assigned 0, cast
    to that entry's own dtype (``s.get_dt()``).

    NOTE(review): assumes every entry of ``_vars`` shares the index space of
    ``_vars[0]``. The entries are presumably gradient SNode expressions;
    confirm against callers.

    Args:
        _vars: Static, non-empty sequence of variables to clear.
    """
    for I in grouped(ScalarField(Expr(_vars[0]))):
        # Unrolled at compile time: one store per variable.
        for s in static(_vars):
            ScalarField(Expr(s))[I] = ops.cast(0, dtype=s.get_dt())
|
258
|
+
|
259
|
+
|
260
|
+
@kernel
def field_fill_python_scope(F: template(), val: template()):
    """Kernel entry point that fills ``F`` with ``val``.

    Delegates to ``field_fill_gstaichi_scope``, presumably so Python-scope
    callers can reuse the gstaichi-scope fill helper.

    Args:
        F: Field to fill (compile-time template argument).
        val: Fill value (compile-time template argument).
    """
    field_fill_gstaichi_scope(F, val)
|
263
|
+
|
264
|
+
|
265
|
+
@kernel
def snode_deactivate(b: template()):
    """Call ``deactivate(b, I)`` for every index ``I`` yielded by ``grouped(b)``.

    Args:
        b: SNode to deactivate (compile-time template argument).
    """
    for I in grouped(b):
        deactivate(b, I)
|
269
|
+
|
270
|
+
|
271
|
+
@kernel
def snode_deactivate_dynamic(b: template()):
    """Deactivate ``b`` at every index of its parent SNode.

    Unlike ``snode_deactivate``, the loop runs over ``b.parent()``.
    Presumably this is for dynamic SNodes, where each parent cell owns one
    list; confirm against callers.

    Args:
        b: SNode to deactivate (compile-time template argument).
    """
    for I in grouped(b.parent()):
        deactivate(b, I)
|
275
|
+
|
276
|
+
|
277
|
+
@kernel
def load_texture_from_numpy(
    tex: texture_type.rw_texture(num_dimensions=2, fmt=Format.rgba8, lod=0),
    img: ndarray_type.ndarray(dtype=vec3, ndim=2),
):
    """Store a 2-D ``vec3`` image into an rgba8 read/write texture.

    Each texel ``(i, j)`` receives ``(r, g, b, 0) / 255.0``, so the alpha
    channel is written as 0.

    Args:
        tex: Destination 2-D rgba8 texture (lod 0).
        img: Source 2-D ndarray of ``vec3`` colors; values are divided by
            255.0 (presumably in the 0-255 range).
    """
    for i, j in img:
        tex.store(
            vector(2, i32)([i, j]),
            vector(4, f32)([img[i, j][0], img[i, j][1], img[i, j][2], 0]) / 255.0,
        )
|
287
|
+
|
288
|
+
|
289
|
+
@kernel
def save_texture_to_numpy(
    tex: texture_type.rw_texture(num_dimensions=2, fmt=Format.rgba8, lod=0),
    img: ndarray_type.ndarray(dtype=vec3, ndim=2),
):
    """Read an rgba8 texture back into a 2-D ``vec3`` image.

    Each texel's ``rgb`` is multiplied by 255 and rounded; alpha is dropped.

    Args:
        tex: Source 2-D rgba8 texture (lod 0).
        img: Destination 2-D ndarray of ``vec3``; its index space drives
            the loop.
    """
    for i, j in img:
        img[i, j] = ops.round(tex.load(vector(2, i32)([i, j])).rgb * 255)
|
296
|
+
|
297
|
+
|
298
|
+
# Odd-even merge sort
@kernel
def sort_stage(
    keys: template(),
    use_values: int,
    values: template(),
    N: int,
    p: int,
    k: int,
    invocations: int,
):
    """Run one compare-and-swap stage of an odd-even merge sort on ``keys``.

    For each ``inv`` in ``range(invocations)``, let ``j = k % p + inv * 2 * k``.
    For ``0 <= i < min(k, N - j - k)``, the elements ``a = i + j`` and
    ``b = a + k`` are compared when both lie in the same span of ``2 * p``
    elements. If ``keys[a] > keys[b]`` the keys are swapped, and when
    ``use_values`` is non-zero the matching ``values`` entries are swapped
    too.

    Args:
        keys: 1-D field of sort keys, modified in place.
        use_values: Non-zero to permute ``values`` alongside ``keys``.
        values: 1-D field permuted together with ``keys`` (untouched when
            ``use_values == 0``).
        N: Number of elements being sorted.
        p: Merge parameter; pairs are only compared inside one ``2 * p`` span.
        k: Distance between the two compared elements in this stage.
        invocations: Number of compare groups; each outermost-loop iteration
            handles one.
    """
    # Field offsets are compile-time constants; an empty offset means zero.
    keys_offset = static(keys.snode.ptr.offset if len(keys.snode.ptr.offset) != 0 else 0)
    values_offset = static(values.snode.ptr.offset if len(values.snode.ptr.offset) != 0 else 0)
    for inv in range(invocations):
        j = k % p + inv * 2 * k
        for i in range(0, ops.min(k, N - j - k)):
            a = i + j
            b = i + j + k
            # Integer floor division. The previous ``int(a / (p * 2))`` used
            # float (true) division, which loses precision once indices exceed
            # the float mantissa (about 2**24 for f32) and could place a and b
            # in the wrong span. For non-negative a and b (positive p and k, as
            # in the sort schedule) ``//`` equals the intended truncation
            # exactly.
            if a // (p * 2) == b // (p * 2):
                key_a = keys[a + keys_offset]
                key_b = keys[b + keys_offset]
                if key_a > key_b:
                    keys[a + keys_offset] = key_b
                    keys[b + keys_offset] = key_a
                    if use_values != 0:
                        temp = values[a + values_offset]
                        values[a + values_offset] = values[b + values_offset]
                        values[b + values_offset] = temp
|
326
|
+
|
327
|
+
|
328
|
+
# Parallel Prefix Sum (Scan)
@func
def warp_shfl_up_i32(val: template()):
    """Return the inclusive sum of ``val`` over lanes ``0..lane_id`` of the warp.

    Five shuffle-up steps are applied, with distances 1, 2, 4, 8 and 16. At
    each step a lane adds the value shuffled from ``offset_j`` lanes below
    it, but only if ``lane_id >= offset_j``.

    NOTE(review): assumes a 32-lane warp (``WARP_SZ``) and that
    ``block.global_thread_idx() % 32`` equals the hardware lane, which holds
    when block sizes are multiples of 32; confirm for new callers.

    Args:
        val: This thread's i32 value.

    Returns:
        The warp-local inclusive prefix sum for this thread.
    """
    global_tid = block.global_thread_idx()
    WARP_SZ = 32
    lane_id = global_tid % WARP_SZ
    # Intra-warp scan, manually unrolled
    offset_j = 1
    n = warp.shfl_up_i32(warp.active_mask(), val, offset_j)
    if lane_id >= offset_j:
        val += n
    offset_j = 2
    n = warp.shfl_up_i32(warp.active_mask(), val, offset_j)
    if lane_id >= offset_j:
        val += n
    offset_j = 4
    n = warp.shfl_up_i32(warp.active_mask(), val, offset_j)
    if lane_id >= offset_j:
        val += n
    offset_j = 8
    n = warp.shfl_up_i32(warp.active_mask(), val, offset_j)
    if lane_id >= offset_j:
        val += n
    offset_j = 16
    n = warp.shfl_up_i32(warp.active_mask(), val, offset_j)
    if lane_id >= offset_j:
        val += n
    return val
|
356
|
+
|
357
|
+
|
358
|
+
@kernel
def scan_add_inclusive(
    arr_in: template(),
    in_beg: i32,
    in_end: i32,
    single_block: template(),
    inclusive_add: template(),
):
    """In-place, block-wise inclusive prefix sum over ``arr_in[in_beg:in_end]``.

    Each run of ``BLOCK_SZ`` consecutive elements, counted from ``in_beg``,
    is scanned independently:
    1. ``inclusive_add`` scans within each warp.
    2. The per-warp totals are combined through a block shared array.
    3. The scanned values are written back to ``arr_in``.

    Unless ``single_block`` is true, each block's last scanned value (its
    total) is also stored at ``arr_in[in_end + block_id]`` for a later pass.

    Args:
        arr_in: 1-D field holding the data. When ``single_block`` is false it
            must have room for one total per block past ``in_end``.
        in_beg: First index (inclusive) of the range to scan.
        in_end: End index (exclusive) of the range to scan.
        single_block: Compile-time flag; when true, block totals are not
            written.
        inclusive_add: ``@func`` computing a warp-local inclusive scan of one
            i32 value (presumably ``warp_shfl_up_i32``).
    """
    WARP_SZ = 32
    BLOCK_SZ = 64
    loop_config(block_dim=64)
    for i in range(in_beg, in_end):
        val = arr_in[i]

        # Position inside the block is measured from ``in_beg`` so it agrees
        # with ``block_id`` below. The former ``i % BLOCK_SZ`` only matched when
        # ``in_beg`` was a multiple of BLOCK_SZ.
        thread_id = (i - in_beg) % BLOCK_SZ
        block_id = int((i - in_beg) // BLOCK_SZ)
        lane_id = thread_id % WARP_SZ
        warp_id = thread_id // WARP_SZ

        # One slot per warp is used (BLOCK_SZ // WARP_SZ of them).
        # NOTE(review): a size of 65 looks oversized; kept unchanged.
        pad_shared = block.SharedArray((65,), i32)

        val = inclusive_add(val)
        block.sync()

        # Put warp scan results to smem: the last lane of each warp holds
        # that warp's total.
        # NOTE(review): pad_shared already comes from block.SharedArray, so
        # the original TODO below may be stale.
        # TODO replace smem with real smem when available
        if thread_id % WARP_SZ == WARP_SZ - 1:
            pad_shared[warp_id] = val
        block.sync()

        # Inter-warp scan, use the first thread in the first warp.
        # Integer division keeps the range bound integral (``/`` gives a float).
        if warp_id == 0 and lane_id == 0:
            for k in range(1, BLOCK_SZ // WARP_SZ):
                pad_shared[k] += pad_shared[k - 1]
        block.sync()

        # Update data with the sum of all preceding warps in this block
        warp_sum = 0
        if warp_id > 0:
            warp_sum = pad_shared[warp_id - 1]
        val += warp_sum
        arr_in[i] = val

        # Record each block's total (its last element) for the later
        # uniform-add pass
        if not single_block and (thread_id == BLOCK_SZ - 1):
            arr_in[in_end + block_id] = val
|
404
|
+
|
405
|
+
|
406
|
+
@kernel
def uniform_add(arr_in: template(), in_beg: i32, in_end: i32):
    """Add each block's preceding running total to its elements, in place.

    Every element from ``in_beg + BLOCK_SZ`` to ``in_end`` is incremented by
    ``arr_in[in_end + block_id - 1]``, the total stored for the previous
    block. The first block is skipped because it has no predecessor (the
    loop starts at ``in_beg + BLOCK_SZ``).

    Args:
        arr_in: 1-D field with per-block totals stored starting at
            ``in_end``.
        in_beg: First index of the range.
        in_end: End index (exclusive) of the range; also the base of the
            block totals.
    """
    BLOCK_SZ = 64
    loop_config(block_dim=64)
    for i in range(in_beg + BLOCK_SZ, in_end):
        block_id = int((i - in_beg) // BLOCK_SZ)
        arr_in[i] += arr_in[in_end + block_id - 1]
|
413
|
+
|
414
|
+
|
415
|
+
@kernel
def blit_from_field_to_field(dst: template(), src: template(), offset: i32, size: i32):
    """Copy ``size`` elements from ``src`` into ``dst`` starting at ``offset``.

    Logical element ``i`` of ``src`` goes to position ``i + offset`` of
    ``dst``. Each field's own snode offset is applied on top.

    Args:
        dst: 1-D destination field.
        src: 1-D source field.
        offset: Destination start position (logical, before the snode offset).
        size: Number of elements to copy.
    """
    # Field offsets are compile-time constants; an empty offset means zero.
    dst_offset = static(dst.snode.ptr.offset if len(dst.snode.ptr.offset) != 0 else 0)
    src_offset = static(src.snode.ptr.offset if len(src.snode.ptr.offset) != 0 else 0)
    for i in range(size):
        dst[i + dst_offset + offset] = src[i + src_offset]
|
File without changes
|
Binary file
|