gstaichi 0.1.25.dev0__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi/CHANGELOG.md +9 -0
- gstaichi/__init__.py +40 -0
- gstaichi/__main__.py +5 -0
- gstaichi/_funcs.py +706 -0
- gstaichi/_kernels.py +420 -0
- gstaichi/_lib/__init__.py +3 -0
- gstaichi/_lib/core/__init__.py +0 -0
- gstaichi/_lib/core/gstaichi_python.cp311-win_amd64.pyd +0 -0
- gstaichi/_lib/core/gstaichi_python.pyi +2937 -0
- gstaichi/_lib/core/py.typed +0 -0
- gstaichi/_lib/runtime/runtime_cuda.bc +0 -0
- gstaichi/_lib/runtime/runtime_x64.bc +0 -0
- gstaichi/_lib/runtime/slim_libdevice.10.bc +0 -0
- gstaichi/_lib/utils.py +249 -0
- gstaichi/_logging.py +131 -0
- gstaichi/_main.py +545 -0
- gstaichi/_snode/__init__.py +5 -0
- gstaichi/_snode/fields_builder.py +187 -0
- gstaichi/_snode/snode_tree.py +34 -0
- gstaichi/_test_tools/__init__.py +0 -0
- gstaichi/_test_tools/load_kernel_string.py +30 -0
- gstaichi/_version.py +1 -0
- gstaichi/_version_check.py +103 -0
- gstaichi/ad/__init__.py +3 -0
- gstaichi/ad/_ad.py +530 -0
- gstaichi/algorithms/__init__.py +3 -0
- gstaichi/algorithms/_algorithms.py +117 -0
- gstaichi/assets/.git +1 -0
- gstaichi/assets/Go-Regular.ttf +0 -0
- gstaichi/assets/static/imgs/ti_gallery.png +0 -0
- gstaichi/examples/minimal.py +28 -0
- gstaichi/experimental.py +16 -0
- gstaichi/lang/__init__.py +50 -0
- gstaichi/lang/_ndarray.py +352 -0
- gstaichi/lang/_ndrange.py +152 -0
- gstaichi/lang/_template_mapper.py +199 -0
- gstaichi/lang/_texture.py +172 -0
- gstaichi/lang/_wrap_inspect.py +189 -0
- gstaichi/lang/any_array.py +99 -0
- gstaichi/lang/argpack.py +411 -0
- gstaichi/lang/ast/__init__.py +5 -0
- gstaichi/lang/ast/ast_transformer.py +1318 -0
- gstaichi/lang/ast/ast_transformer_utils.py +341 -0
- gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
- gstaichi/lang/ast/ast_transformers/call_transformer.py +267 -0
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +320 -0
- gstaichi/lang/ast/checkers.py +106 -0
- gstaichi/lang/ast/symbol_resolver.py +57 -0
- gstaichi/lang/ast/transform.py +9 -0
- gstaichi/lang/common_ops.py +310 -0
- gstaichi/lang/exception.py +80 -0
- gstaichi/lang/expr.py +180 -0
- gstaichi/lang/field.py +466 -0
- gstaichi/lang/impl.py +1241 -0
- gstaichi/lang/kernel_arguments.py +157 -0
- gstaichi/lang/kernel_impl.py +1382 -0
- gstaichi/lang/matrix.py +1881 -0
- gstaichi/lang/matrix_ops.py +341 -0
- gstaichi/lang/matrix_ops_utils.py +190 -0
- gstaichi/lang/mesh.py +687 -0
- gstaichi/lang/misc.py +778 -0
- gstaichi/lang/ops.py +1494 -0
- gstaichi/lang/runtime_ops.py +13 -0
- gstaichi/lang/shell.py +35 -0
- gstaichi/lang/simt/__init__.py +5 -0
- gstaichi/lang/simt/block.py +94 -0
- gstaichi/lang/simt/grid.py +7 -0
- gstaichi/lang/simt/subgroup.py +191 -0
- gstaichi/lang/simt/warp.py +96 -0
- gstaichi/lang/snode.py +489 -0
- gstaichi/lang/source_builder.py +150 -0
- gstaichi/lang/struct.py +855 -0
- gstaichi/lang/util.py +381 -0
- gstaichi/linalg/__init__.py +8 -0
- gstaichi/linalg/matrixfree_cg.py +310 -0
- gstaichi/linalg/sparse_cg.py +59 -0
- gstaichi/linalg/sparse_matrix.py +303 -0
- gstaichi/linalg/sparse_solver.py +123 -0
- gstaichi/math/__init__.py +11 -0
- gstaichi/math/_complex.py +205 -0
- gstaichi/math/mathimpl.py +886 -0
- gstaichi/profiler/__init__.py +6 -0
- gstaichi/profiler/kernel_metrics.py +260 -0
- gstaichi/profiler/kernel_profiler.py +586 -0
- gstaichi/profiler/memory_profiler.py +15 -0
- gstaichi/profiler/scoped_profiler.py +36 -0
- gstaichi/sparse/__init__.py +3 -0
- gstaichi/sparse/_sparse_grid.py +77 -0
- gstaichi/tools/__init__.py +12 -0
- gstaichi/tools/diagnose.py +117 -0
- gstaichi/tools/np2ply.py +364 -0
- gstaichi/tools/vtk.py +38 -0
- gstaichi/types/__init__.py +19 -0
- gstaichi/types/annotations.py +47 -0
- gstaichi/types/compound_types.py +90 -0
- gstaichi/types/enums.py +49 -0
- gstaichi/types/ndarray_type.py +147 -0
- gstaichi/types/primitive_types.py +206 -0
- gstaichi/types/quant.py +88 -0
- gstaichi/types/texture_type.py +85 -0
- gstaichi/types/utils.py +13 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +5 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +29 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +113 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +5 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +19 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +122 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +5 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +19 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +122 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +5 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +19 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +122 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +5 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +19 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +122 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +5 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +19 -0
- gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +122 -0
- gstaichi-0.1.25.dev0.data/data/bin/SPIRV-Tools-shared.dll +0 -0
- gstaichi-0.1.25.dev0.data/data/include/spirv-tools/instrument.hpp +268 -0
- gstaichi-0.1.25.dev0.data/data/include/spirv-tools/libspirv.h +907 -0
- gstaichi-0.1.25.dev0.data/data/include/spirv-tools/libspirv.hpp +375 -0
- gstaichi-0.1.25.dev0.data/data/include/spirv-tools/linker.hpp +97 -0
- gstaichi-0.1.25.dev0.data/data/include/spirv-tools/optimizer.hpp +970 -0
- gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-diff.lib +0 -0
- gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-link.lib +0 -0
- gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-lint.lib +0 -0
- gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-opt.lib +0 -0
- gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-reduce.lib +0 -0
- gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-shared.lib +0 -0
- gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools.lib +0 -0
- gstaichi-0.1.25.dev0.dist-info/METADATA +105 -0
- gstaichi-0.1.25.dev0.dist-info/RECORD +138 -0
- gstaichi-0.1.25.dev0.dist-info/WHEEL +5 -0
- gstaichi-0.1.25.dev0.dist-info/entry_points.txt +2 -0
- gstaichi-0.1.25.dev0.dist-info/licenses/LICENSE +201 -0
- gstaichi-0.1.25.dev0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,50 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
from gstaichi.lang import impl, simt
|
4
|
+
from gstaichi.lang._ndarray import *
|
5
|
+
from gstaichi.lang._ndrange import ndrange
|
6
|
+
from gstaichi.lang._texture import Texture
|
7
|
+
from gstaichi.lang.argpack import *
|
8
|
+
from gstaichi.lang.exception import *
|
9
|
+
from gstaichi.lang.field import *
|
10
|
+
from gstaichi.lang.impl import *
|
11
|
+
from gstaichi.lang.kernel_impl import *
|
12
|
+
from gstaichi.lang.matrix import *
|
13
|
+
from gstaichi.lang.mesh import *
|
14
|
+
from gstaichi.lang.misc import * # pylint: disable=W0622
|
15
|
+
from gstaichi.lang.ops import * # pylint: disable=W0622
|
16
|
+
from gstaichi.lang.runtime_ops import *
|
17
|
+
from gstaichi.lang.snode import *
|
18
|
+
from gstaichi.lang.source_builder import *
|
19
|
+
from gstaichi.lang.struct import *
|
20
|
+
from gstaichi.types.enums import DeviceCapability, Format, Layout
|
21
|
+
|
22
|
+
# Module names that appear in this namespace through the imports above (for
# example the explicitly imported ``impl``) and are deliberately kept out of
# ``from gstaichi.lang import *``. Names starting with "_" are always excluded.
_NOT_EXPORTED = frozenset(
    (
        "any_array",
        "ast",
        "common_ops",
        "enums",
        "exception",
        "expr",
        "impl",
        "inspect",
        "kernel_arguments",
        "kernel_impl",
        "matrix",
        "mesh",
        "misc",
        "ops",
        "platform",
        "runtime_ops",
        "shell",
        "snode",
        "source_builder",
        "struct",
        "util",
    )
)
__all__ = [name for name in dir() if not (name.startswith("_") or name in _NOT_EXPORTED)]
del _NOT_EXPORTED
|
@@ -0,0 +1,352 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
from typing import TYPE_CHECKING, Union
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
|
7
|
+
from gstaichi._lib import core as _ti_core
|
8
|
+
from gstaichi.lang import impl
|
9
|
+
from gstaichi.lang.exception import GsTaichiIndexError
|
10
|
+
from gstaichi.lang.util import cook_dtype, get_traceback, python_scope, to_numpy_type
|
11
|
+
from gstaichi.types import primitive_types
|
12
|
+
from gstaichi.types.enums import Layout
|
13
|
+
from gstaichi.types.ndarray_type import NdarrayTypeMetadata
|
14
|
+
from gstaichi.types.utils import is_real, is_signed
|
15
|
+
|
16
|
+
if TYPE_CHECKING:
|
17
|
+
from gstaichi.lang.matrix import MatrixNdarray, VectorNdarray
|
18
|
+
|
19
|
+
TensorNdarray = Union["ScalarNdarray", VectorNdarray, MatrixNdarray]
|
20
|
+
|
21
|
+
|
22
|
+
class Ndarray:
    """GsTaichi ndarray class.

    Base class shared by the concrete ndarray types. It stores the runtime
    ndarray handle and provides numpy conversion, copy and fill helpers.
    ``element_shape``, ``__getitem__``, ``__setitem__``, ``__deepcopy__`` and
    ``_fill_by_kernel`` raise ``NotImplementedError`` here and are meant to be
    overridden by subclasses such as ``ScalarNdarray``.

    Attributes:
        dtype (DataType): Data type of each value.
        shape (Tuple[int]): Shape of the Ndarray.
        element_type: Element type reported through ``get_type``.
        arr: Runtime ndarray handle, created by the subclass.
        layout (Layout): Memory layout, initialized to ``Layout.AOS``.
        grad (Optional[TensorNdarray]): Gradient ndarray, set via ``_set_grad``.
        host_accessor (Optional[NdarrayHostAccessor]): Created lazily by
            ``_initialize_host_accessor`` for Python-scope element access.
    """

    def __init__(self):
        self.host_accessor = None
        self.shape = None
        self.element_type = None
        self.dtype = None
        self.arr = None
        self.layout = Layout.AOS
        self.grad: "TensorNdarray | None" = None

    def get_type(self):
        """Returns the ``NdarrayTypeMetadata`` describing this ndarray.

        The metadata is built from ``element_type``, ``shape`` and whether a
        gradient ndarray is attached.
        """
        return NdarrayTypeMetadata(self.element_type, self.shape, self.grad is not None)

    @property
    def element_shape(self):
        """Gets ndarray element shape.

        Returns:
            Tuple[Int]: Ndarray element shape.
        """
        raise NotImplementedError()

    @python_scope
    def __setitem__(self, key, value):
        """Sets ndarray element in Python scope.

        Args:
            key (Union[List[int], int, None]): Coordinates of the ndarray element.
            value (element type): Value to set.
        """
        raise NotImplementedError()

    @python_scope
    def __getitem__(self, key):
        """Gets ndarray element in Python scope.

        Args:
            key (Union[List[int], int, None]): Coordinates of the ndarray element.

        Returns:
            element type: Value retrieved.
        """
        raise NotImplementedError()

    @python_scope
    def fill(self, val):
        """Fills ndarray with a specific scalar value.

        On the CUDA and x64 archs, non-tensor ndarrays whose dtype is f32, i32
        or u32 are filled through the runtime's ``fill_float`` / ``fill_int`` /
        ``fill_uint``. Every other case falls back to ``_fill_by_kernel``.

        Args:
            val (Union[int, float]): Value to fill.
        """
        arch = impl.current_cfg().arch
        if arch != _ti_core.Arch.cuda and arch != _ti_core.Arch.x64:
            self._fill_by_kernel(val)
        elif _ti_core.is_tensor(self.element_type):
            self._fill_by_kernel(val)
        elif self.dtype == primitive_types.f32:
            impl.get_runtime().prog.fill_float(self.arr, val)
        elif self.dtype == primitive_types.i32:
            impl.get_runtime().prog.fill_int(self.arr, val)
        elif self.dtype == primitive_types.u32:
            impl.get_runtime().prog.fill_uint(self.arr, val)
        else:
            self._fill_by_kernel(val)

    @python_scope
    def _ndarray_to_numpy(self):
        """Converts ndarray to a numpy array.

        Returns:
            numpy.ndarray: The result numpy array, shaped like ``self.arr.total_shape()``.
        """
        arr = np.zeros(shape=self.arr.total_shape(), dtype=to_numpy_type(self.dtype))
        from gstaichi._kernels import ndarray_to_ext_arr  # pylint: disable=C0415

        ndarray_to_ext_arr(self, arr)
        impl.get_runtime().sync()
        return arr

    @python_scope
    def _ndarray_matrix_to_numpy(self, as_vector):
        """Converts matrix ndarray to a numpy array.

        Args:
            as_vector: Flag forwarded unchanged to ``ndarray_matrix_to_ext_arr``
                (presumably: elements are vectors rather than matrices — confirm).

        Returns:
            numpy.ndarray: The result numpy array, shaped like ``self.arr.total_shape()``.
        """
        arr = np.zeros(shape=self.arr.total_shape(), dtype=to_numpy_type(self.dtype))
        from gstaichi._kernels import ndarray_matrix_to_ext_arr  # pylint: disable=C0415

        layout_is_aos = 1
        ndarray_matrix_to_ext_arr(self, arr, layout_is_aos, as_vector)
        impl.get_runtime().sync()
        return arr

    @python_scope
    def _ndarray_from_numpy(self, arr):
        """Loads all values from a numpy array.

        Args:
            arr (numpy.ndarray): The source numpy array. Its shape must equal
                ``self.arr.total_shape()``. Non C-contiguous inputs are copied
                into a contiguous array first.

        Raises:
            TypeError: If ``arr`` is not a numpy array.
            ValueError: If the shapes do not match.
        """
        if not isinstance(arr, np.ndarray):
            raise TypeError(f"{np.ndarray} expected, but {type(arr)} provided")
        # Report the same shape that is compared (total_shape, not shape), so
        # the message is accurate and matches _ndarray_matrix_from_numpy.
        expected_shape = tuple(self.arr.total_shape())
        if expected_shape != tuple(arr.shape):
            raise ValueError(f"Mismatch shape: {expected_shape} expected, but {tuple(arr.shape)} provided")
        if not arr.flags.c_contiguous:
            arr = np.ascontiguousarray(arr)

        from gstaichi._kernels import ext_arr_to_ndarray  # pylint: disable=C0415

        ext_arr_to_ndarray(arr, self)
        impl.get_runtime().sync()

    @python_scope
    def _ndarray_matrix_from_numpy(self, arr, as_vector):
        """Loads all values from a numpy array into a matrix ndarray.

        Args:
            arr (numpy.ndarray): The source numpy array. Its shape must equal
                ``self.arr.total_shape()``.
            as_vector: Flag forwarded unchanged to ``ext_arr_to_ndarray_matrix``.

        Raises:
            TypeError: If ``arr`` is not a numpy array.
            ValueError: If the shapes do not match.
        """
        if not isinstance(arr, np.ndarray):
            raise TypeError(f"{np.ndarray} expected, but {type(arr)} provided")
        expected_shape = tuple(self.arr.total_shape())
        if expected_shape != tuple(arr.shape):
            raise ValueError(f"Mismatch shape: {expected_shape} expected, but {tuple(arr.shape)} provided")
        if not arr.flags.c_contiguous:
            arr = np.ascontiguousarray(arr)

        from gstaichi._kernels import ext_arr_to_ndarray_matrix  # pylint: disable=C0415

        layout_is_aos = 1
        ext_arr_to_ndarray_matrix(arr, self, layout_is_aos, as_vector)
        impl.get_runtime().sync()

    @python_scope
    def _get_element_size(self):
        """Returns the size of one element in bytes.

        Returns:
            Size in bytes.
        """
        return self.arr.element_size()

    @python_scope
    def _get_nelement(self):
        """Returns the total number of elements.

        Returns:
            Total number of elements.
        """
        return self.arr.nelement()

    @python_scope
    def copy_from(self, other):
        """Copies all elements from another ndarray.

        The shape of the other ndarray needs to be the same as `self`.

        Args:
            other (Ndarray): The source ndarray.
        """
        assert isinstance(other, Ndarray), f"Ndarray expected, but {type(other)} provided"
        assert tuple(self.arr.shape) == tuple(
            other.arr.shape
        ), f"Mismatch shape: {tuple(self.arr.shape)} expected, but {tuple(other.arr.shape)} provided"
        from gstaichi._kernels import ndarray_to_ndarray  # pylint: disable=C0415

        ndarray_to_ndarray(self, other)
        impl.get_runtime().sync()

    def _set_grad(self, grad: "TensorNdarray"):
        """Sets the gradient ndarray.

        Args:
            grad (Ndarray): The gradient ndarray.
        """
        self.grad = grad

    def __deepcopy__(self, memo=None):
        """Copies all elements to a new ndarray.

        Returns:
            Ndarray: The result ndarray.
        """
        raise NotImplementedError()

    def _fill_by_kernel(self, val):
        """Fills ndarray with a specific scalar value using a ti.kernel.

        Args:
            val (Union[int, float]): Value to fill.
        """
        raise NotImplementedError()

    @python_scope
    def _pad_key(self, key):
        """Normalizes an index key for Python-scope element access.

        ``None`` becomes ``()`` and a bare scalar becomes a 1-tuple. Tuples
        and lists are returned unchanged.

        Raises:
            GsTaichiIndexError: If the number of indices differs from the
                number of dimensions in ``self.arr.total_shape()``.
        """
        if key is None:
            key = ()
        if not isinstance(key, (tuple, list)):
            key = (key,)
        ndim = len(self.arr.total_shape())
        if len(key) != ndim:
            raise GsTaichiIndexError(f"{ndim}d ndarray indexed with {len(key)}d indices: {key}")
        return key

    @python_scope
    def _initialize_host_accessor(self):
        """Creates ``self.host_accessor`` on first use, after materializing the runtime."""
        if self.host_accessor is not None:
            return
        impl.get_runtime().materialize()
        self.host_accessor = NdarrayHostAccessor(self.arr)
|
237
|
+
|
238
|
+
|
239
|
+
class ScalarNdarray(Ndarray):
    """GsTaichi ndarray with scalar elements.

    Args:
        dtype (DataType): Data type of each value, normalized with ``cook_dtype``.
        arr_shape (Tuple[int]): Shape of the ndarray.
    """

    def __init__(self, dtype, arr_shape):
        super().__init__()
        self.dtype = cook_dtype(dtype)
        self.arr = impl.get_runtime().prog.create_ndarray(
            self.dtype, arr_shape, layout=Layout.NULL, zero_fill=True, dbg_info=_ti_core.DebugInfo(get_traceback())
        )
        self.shape = tuple(self.arr.shape)
        # Record the normalized dtype (not the raw argument) so element_type
        # always agrees with self.dtype. fill() passes element_type to
        # _ti_core.is_tensor, and get_type() exposes it to callers.
        self.element_type = self.dtype

    def __del__(self):
        # self.arr is still None when __init__ failed before create_ndarray
        # returned. There is nothing to release then, and passing None to
        # delete_ndarray would only raise inside __del__.
        if getattr(self, "arr", None) is None:
            return
        # The None-checks on impl/get_runtime presumably guard against
        # interpreter shutdown, when module globals may already be cleared.
        if impl is not None and impl.get_runtime is not None and impl.get_runtime() is not None:
            prog = impl.get_runtime()._prog
            if prog is not None:
                prog.delete_ndarray(self.arr)

    @property
    def element_shape(self):
        """Scalar elements have an empty element shape."""
        return ()

    @python_scope
    def __setitem__(self, key, value):
        """Writes ``value`` at index ``key`` (see ``Ndarray._pad_key``)."""
        self._initialize_host_accessor()
        self.host_accessor.setter(value, *self._pad_key(key))

    @python_scope
    def __getitem__(self, key):
        """Reads the value at index ``key`` (see ``Ndarray._pad_key``)."""
        self._initialize_host_accessor()
        return self.host_accessor.getter(*self._pad_key(key))

    @python_scope
    def to_numpy(self):
        """Returns a numpy copy of the ndarray contents."""
        return self._ndarray_to_numpy()

    @python_scope
    def from_numpy(self, arr):
        """Loads all values from the numpy array ``arr``."""
        self._ndarray_from_numpy(arr)

    def __deepcopy__(self, memo=None):
        """Returns a new ScalarNdarray with the same dtype, shape and contents."""
        ret_arr = ScalarNdarray(self.dtype, self.shape)
        ret_arr.copy_from(self)
        return ret_arr

    def _fill_by_kernel(self, val):
        """Fills every element with ``val`` using the ``fill_ndarray`` kernel."""
        from gstaichi._kernels import fill_ndarray  # pylint: disable=C0415

        fill_ndarray(self, val)

    def __repr__(self):
        return "<ti.ndarray>"
|
296
|
+
|
297
|
+
|
298
|
+
class NdarrayHostAccessor:
    """Python-scope element reader/writer for a runtime ndarray.

    The ``getter(*key)`` and ``setter(value, *key)`` callbacks are chosen from
    the ndarray's element data type:

    * real types use ``read_float`` / ``write_float``;
    * signed integers read with ``read_int``, unsigned ones with ``read_uint``;
    * both integer kinds write with ``write_int``.

    Args:
        ndarray: Runtime ndarray handle (e.g. ``Ndarray.arr``).
    """

    def __init__(self, ndarray):
        dtype = ndarray.element_data_type()
        if not is_real(dtype):
            if is_signed(dtype):
                self.getter = lambda *key: ndarray.read_int(key)
            else:
                self.getter = lambda *key: ndarray.read_uint(key)
            self.setter = lambda value, *key: ndarray.write_int(key, value)
            return
        self.getter = lambda *key: ndarray.read_float(key)
        self.setter = lambda value, *key: ndarray.write_float(key, value)
|
325
|
+
|
326
|
+
|
327
|
+
class NdarrayHostAccess:
    """Class for accessing VectorNdarray/MatrixNdarray in Python scope.

    After construction, ``getter()`` reads and ``setter(value)`` writes the
    single scalar addressed by ``indices_first + indices_second``. Both go
    through the ndarray's lazily-created host accessor.

    Args:
        arr (Union[VectorNdarray, MatrixNdarray]): The ndarray being accessed.
        indices_first (Tuple[Int]): Indices of first-level access (coordinates in the field).
        indices_second (Tuple[Int]): Indices of second-level access (indices in the vector/matrix).
    """

    def __init__(self, arr, indices_first, indices_second):
        self.ndarr = arr
        self.arr = arr.arr
        self.indices = indices_first + indices_second

        def host_accessor():
            # Make sure the owning ndarray has a host accessor before using it.
            owner = self.ndarr
            owner._initialize_host_accessor()
            return owner.host_accessor

        self.getter = lambda: host_accessor().getter(*self.ndarr._pad_key(self.indices))
        self.setter = lambda value: host_accessor().setter(value, *self.ndarr._pad_key(self.indices))
|
350
|
+
|
351
|
+
|
352
|
+
# Public names exported by ``from gstaichi.lang._ndarray import *``.
__all__ = [
    "Ndarray",
    "ScalarNdarray",
]
|
@@ -0,0 +1,152 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
import collections.abc
|
4
|
+
from typing import Iterable
|
5
|
+
|
6
|
+
import numpy as np
|
7
|
+
|
8
|
+
from gstaichi.lang import ops
|
9
|
+
from gstaichi.lang.exception import GsTaichiSyntaxError, GsTaichiTypeError
|
10
|
+
from gstaichi.lang.expr import Expr
|
11
|
+
from gstaichi.lang.matrix import Matrix
|
12
|
+
from gstaichi.types.utils import is_integral
|
13
|
+
|
14
|
+
|
15
|
+
class _Ndrange:
    """Multi-dimensional index range produced by :func:`ndrange`.

    Each positional argument describes one dimension. It is either a single
    integer ``n`` (meaning ``[0, n)``) or a ``(begin, end)`` pair. Bounds may be
    Python/NumPy integers or integral ``Expr`` values.

    Attributes:
        bounds (List[Tuple]): One ``(begin, end)`` pair per dimension. ``end`` is
            replaced by ``ops.max(begin, end)``, so a reversed pair is empty.
        dimensions (List): ``end - begin`` for each dimension.
        acc_dimensions (List): Suffix products of ``dimensions``.
            ``acc_dimensions[i]`` is the product of ``dimensions[i:]``; it is
            ``[1]`` for a zero-dimensional range.

    Raises:
        GsTaichiSyntaxError: If a sequence argument does not have exactly two entries.
        GsTaichiTypeError: If a bound is not an integer scalar or integral ``Expr``.
    """

    def __init__(self, *args):
        # Pass 1: normalize every argument to a (begin, end) pair and check arity.
        pairs = []
        for arg in args:
            if not isinstance(arg, collections.abc.Sequence):
                arg = (0, arg)
            if len(arg) != 2:
                raise GsTaichiSyntaxError(
                    "Every argument of ndrange should be a scalar or a tuple/list like (begin, end)"
                )
            pairs.append(arg)
        # Pass 2: validate bound types BEFORE combining them with ops.max, so a
        # malformed bound (e.g. None) is reported as GsTaichiTypeError rather
        # than failing inside ops.max.
        for pair in pairs:
            for bound in pair:
                if not _Ndrange._is_integer_bound(bound):
                    raise GsTaichiTypeError(
                        "Every argument of ndrange should be an integer scalar or a tuple/list of (int, int)"
                    )
        # Clamp end to be >= begin so reversed pairs become empty ranges.
        self.bounds = [(begin, ops.max(begin, end)) for begin, end in pairs]

        self.dimensions = [end - begin for begin, end in self.bounds]

        self.acc_dimensions = self.dimensions.copy()
        for i in reversed(range(len(self.bounds) - 1)):
            self.acc_dimensions[i] = self.acc_dimensions[i] * self.acc_dimensions[i + 1]
        if not self.acc_dimensions:  # for the empty case, e.g. ti.ndrange()
            self.acc_dimensions = [1]

    @staticmethod
    def _is_integer_bound(bound):
        """Returns True if ``bound`` is a Python/NumPy integer or an integral ``Expr``."""
        if isinstance(bound, (int, np.integer)):
            return True
        return isinstance(bound, Expr) and is_integral(bound.ptr.get_rvalue_type())

    def __iter__(self):
        """Yields index tuples in lexicographic order (Python scope only)."""

        def gen(d, prefix):
            if d == len(self.bounds):
                yield prefix
            else:
                for t in range(self.bounds[d][0], self.bounds[d][1]):
                    yield from gen(d + 1, prefix + (t,))

        yield from gen(0, ())

    def grouped(self):
        """Returns a view that yields each multi-index as a ``Matrix``."""
        return GroupedNDRange(self)
|
58
|
+
|
59
|
+
|
60
|
+
def ndrange(*args) -> Iterable:
    """Return an iterable object for looping over multi-dimensional indices.

    The returned set of multi-dimensional indices is the direct product (in the
    set-theory sense) of n groups of integers, where n equals the number of
    arguments, and looks like

        range(x1, y1) x range(x2, y2) x ... x range(xn, yn)

    The k-th argument corresponds to the k-th `range()` factor in the above
    product, and each argument must be an integer or a pair of two integers.
    An integer argument n is interpreted as `range(0, n)`, and a pair of two
    integers (start, end) is interpreted as `range(start, end)`. A pair whose
    end is smaller than its start yields an empty range.

    You can loop over these multi-dimensional indices in different ways; see
    the examples below.

    Args:
        *args (Union[int, Tuple[int, int], List[int]]): One argument per
            dimension. Each must be either an integer, or a tuple/list of two
            integers.

    Returns:
        An `_Ndrange` iterable over the indices. Iterating it in Python scope
        yields one tuple of ints per multi-index.

    Example::

        You can loop over 1-D integers in range [start, end), as in native Python

        >>> @ti.kernel
        >>> def loop_1d():
        >>>     start = 2
        >>>     end = 5
        >>>     for i in ti.ndrange((start, end)):
        >>>         print(i)  # will print 2 3 4

        Note the parentheses around `(start, end)` in the above code. Without
        them, the argument `2` would be interpreted as `range(0, 2)` and `5` as
        `range(0, 5)`. You would then get a set of 2-D indices containing
        2x5=10 elements, and would need two indices i, j to loop over them:

        >>> @ti.kernel
        >>> def loop_2d():
        >>>     for i, j in ti.ndrange(2, 5):
        >>>         print(i, j)
        0 0
        ...
        0 4
        ...
        1 4

        You can also use a single index i to loop over these 2-D indices. In
        this case the indices are returned as a 1-D array `(0, 1, ..., 9)`:

        >>> @ti.kernel
        >>> def loop_2d_as_1d():
        >>>     for i in ti.ndrange(2, 5):
        >>>         print(i)
        will print 0 1 2 3 4 5 6 7 8 9

        In general, you can use any `1 <= k <= n` iterators to loop over a set of
        n-D indices. For `k=n` all the indices are n-dimensional, and they are
        returned in lexicographic order. For `k<n` iterators, the last n-k+1
        dimensions are collapsed into a 1-D array of consecutive integers
        `(0, 1, 2, ...)` whose length equals the total number of indices in the
        last n-k+1 dimensions:

        >>> @ti.kernel
        >>> def loop_3d_as_2d():
        >>>     # use two iterators to loop over a set of 3-D indices
        >>>     # the last two dimensions for 4, 5 will collapse into
        >>>     # the array [0, 1, 2, ..., 19]
        >>>     for i, j in ti.ndrange(3, 4, 5):
        >>>         print(i, j)
        will print 0 0, 0 1, ..., 0 19, ..., 2 19.

        A typical use of `ndrange` is to loop over a tensor and process its
        entries in parallel. Avoid writing nested `for` loops here, because only
        top-level `for` loops are parallelized in gstaichi. Instead, use
        `ndrange` to cover all entries in one top-level loop:

        >>> @ti.kernel
        >>> def loop_tensor():
        >>>     for row, col, channel in ti.ndrange(image_height, image_width, channels):
        >>>         image[row, col, channel] = ...
    """
    return _Ndrange(*args)
|
141
|
+
|
142
|
+
|
143
|
+
class GroupedNDRange:
    """Wraps an ndrange so that each multi-index is yielded as a ``Matrix``.

    Args:
        r: Iterable of index tuples to wrap (``_Ndrange.grouped`` passes itself).
    """

    def __init__(self, r):
        self.r = r

    def __iter__(self):
        for indices in self.r:
            yield Matrix([*indices])
|
150
|
+
|
151
|
+
|
152
|
+
# ``ndrange`` is the only public name exported by this module.
__all__ = [
    "ndrange",
]
|