gstaichi 0.0.0__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (154) hide show
  1. gstaichi/CHANGELOG.md +4 -0
  2. gstaichi/__init__.py +51 -0
  3. gstaichi/_funcs.py +706 -0
  4. gstaichi/_kernels.py +420 -0
  5. gstaichi/_lib/__init__.py +5 -0
  6. gstaichi/_lib/core/__init__.py +0 -0
  7. gstaichi/_lib/core/gstaichi_python.cp311-win_amd64.pyd +0 -0
  8. gstaichi/_lib/core/gstaichi_python.pyi +2917 -0
  9. gstaichi/_lib/core/py.typed +0 -0
  10. gstaichi/_lib/runtime/runtime_cuda.bc +0 -0
  11. gstaichi/_lib/runtime/runtime_x64.bc +0 -0
  12. gstaichi/_lib/runtime/slim_libdevice.10.bc +0 -0
  13. gstaichi/_lib/utils.py +243 -0
  14. gstaichi/_logging.py +131 -0
  15. gstaichi/_snode/__init__.py +5 -0
  16. gstaichi/_snode/fields_builder.py +187 -0
  17. gstaichi/_snode/snode_tree.py +34 -0
  18. gstaichi/_test_tools/__init__.py +18 -0
  19. gstaichi/_test_tools/dataclass_test_tools.py +36 -0
  20. gstaichi/_test_tools/load_kernel_string.py +30 -0
  21. gstaichi/_test_tools/textwrap2.py +6 -0
  22. gstaichi/_version_check.py +100 -0
  23. gstaichi/ad/__init__.py +3 -0
  24. gstaichi/ad/_ad.py +530 -0
  25. gstaichi/algorithms/__init__.py +3 -0
  26. gstaichi/algorithms/_algorithms.py +117 -0
  27. gstaichi/assets/.git +1 -0
  28. gstaichi/assets/Go-Regular.ttf +0 -0
  29. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  30. gstaichi/examples/lcg_python.py +26 -0
  31. gstaichi/examples/lcg_taichi.py +34 -0
  32. gstaichi/examples/minimal.py +28 -0
  33. gstaichi/experimental.py +16 -0
  34. gstaichi/lang/__init__.py +50 -0
  35. gstaichi/lang/_dataclass_util.py +31 -0
  36. gstaichi/lang/_fast_caching/__init__.py +3 -0
  37. gstaichi/lang/_fast_caching/args_hasher.py +122 -0
  38. gstaichi/lang/_fast_caching/config_hasher.py +30 -0
  39. gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
  40. gstaichi/lang/_fast_caching/function_hasher.py +57 -0
  41. gstaichi/lang/_fast_caching/hash_utils.py +11 -0
  42. gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
  43. gstaichi/lang/_fast_caching/src_hasher.py +83 -0
  44. gstaichi/lang/_kernel_impl_dataclass.py +212 -0
  45. gstaichi/lang/_ndarray.py +366 -0
  46. gstaichi/lang/_ndrange.py +152 -0
  47. gstaichi/lang/_template_mapper.py +195 -0
  48. gstaichi/lang/_texture.py +172 -0
  49. gstaichi/lang/_wrap_inspect.py +215 -0
  50. gstaichi/lang/any_array.py +99 -0
  51. gstaichi/lang/ast/__init__.py +7 -0
  52. gstaichi/lang/ast/ast_transformer.py +1351 -0
  53. gstaichi/lang/ast/ast_transformer_utils.py +346 -0
  54. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  55. gstaichi/lang/ast/ast_transformers/call_transformer.py +327 -0
  56. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
  57. gstaichi/lang/ast/checkers.py +106 -0
  58. gstaichi/lang/ast/symbol_resolver.py +57 -0
  59. gstaichi/lang/ast/transform.py +9 -0
  60. gstaichi/lang/common_ops.py +310 -0
  61. gstaichi/lang/exception.py +80 -0
  62. gstaichi/lang/expr.py +180 -0
  63. gstaichi/lang/field.py +428 -0
  64. gstaichi/lang/impl.py +1259 -0
  65. gstaichi/lang/kernel_arguments.py +155 -0
  66. gstaichi/lang/kernel_impl.py +1386 -0
  67. gstaichi/lang/matrix.py +1835 -0
  68. gstaichi/lang/matrix_ops.py +341 -0
  69. gstaichi/lang/matrix_ops_utils.py +190 -0
  70. gstaichi/lang/mesh.py +687 -0
  71. gstaichi/lang/misc.py +784 -0
  72. gstaichi/lang/ops.py +1494 -0
  73. gstaichi/lang/runtime_ops.py +13 -0
  74. gstaichi/lang/shell.py +35 -0
  75. gstaichi/lang/simt/__init__.py +5 -0
  76. gstaichi/lang/simt/block.py +94 -0
  77. gstaichi/lang/simt/grid.py +7 -0
  78. gstaichi/lang/simt/subgroup.py +191 -0
  79. gstaichi/lang/simt/warp.py +96 -0
  80. gstaichi/lang/snode.py +489 -0
  81. gstaichi/lang/source_builder.py +150 -0
  82. gstaichi/lang/struct.py +810 -0
  83. gstaichi/lang/util.py +312 -0
  84. gstaichi/linalg/__init__.py +10 -0
  85. gstaichi/linalg/matrixfree_cg.py +310 -0
  86. gstaichi/linalg/sparse_cg.py +59 -0
  87. gstaichi/linalg/sparse_matrix.py +303 -0
  88. gstaichi/linalg/sparse_solver.py +123 -0
  89. gstaichi/math/__init__.py +11 -0
  90. gstaichi/math/_complex.py +205 -0
  91. gstaichi/math/mathimpl.py +886 -0
  92. gstaichi/profiler/__init__.py +6 -0
  93. gstaichi/profiler/kernel_metrics.py +260 -0
  94. gstaichi/profiler/kernel_profiler.py +586 -0
  95. gstaichi/profiler/memory_profiler.py +15 -0
  96. gstaichi/profiler/scoped_profiler.py +36 -0
  97. gstaichi/sparse/__init__.py +3 -0
  98. gstaichi/sparse/_sparse_grid.py +77 -0
  99. gstaichi/tools/__init__.py +12 -0
  100. gstaichi/tools/diagnose.py +117 -0
  101. gstaichi/tools/np2ply.py +364 -0
  102. gstaichi/tools/vtk.py +38 -0
  103. gstaichi/types/__init__.py +21 -0
  104. gstaichi/types/annotations.py +52 -0
  105. gstaichi/types/compound_types.py +71 -0
  106. gstaichi/types/enums.py +49 -0
  107. gstaichi/types/ndarray_type.py +169 -0
  108. gstaichi/types/primitive_types.py +206 -0
  109. gstaichi/types/quant.py +88 -0
  110. gstaichi/types/texture_type.py +85 -0
  111. gstaichi/types/utils.py +11 -0
  112. gstaichi-0.0.0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +5 -0
  113. gstaichi-0.0.0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +29 -0
  114. gstaichi-0.0.0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +113 -0
  115. gstaichi-0.0.0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +5 -0
  116. gstaichi-0.0.0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +19 -0
  117. gstaichi-0.0.0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +122 -0
  118. gstaichi-0.0.0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +5 -0
  119. gstaichi-0.0.0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +19 -0
  120. gstaichi-0.0.0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +122 -0
  121. gstaichi-0.0.0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +5 -0
  122. gstaichi-0.0.0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +19 -0
  123. gstaichi-0.0.0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +122 -0
  124. gstaichi-0.0.0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +5 -0
  125. gstaichi-0.0.0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +19 -0
  126. gstaichi-0.0.0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +122 -0
  127. gstaichi-0.0.0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +5 -0
  128. gstaichi-0.0.0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  129. gstaichi-0.0.0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +122 -0
  130. gstaichi-0.0.0.data/data/bin/SPIRV-Tools-shared.dll +0 -0
  131. gstaichi-0.0.0.data/data/include/GLFW/glfw3.h +6389 -0
  132. gstaichi-0.0.0.data/data/include/GLFW/glfw3native.h +594 -0
  133. gstaichi-0.0.0.data/data/include/spirv-tools/instrument.hpp +268 -0
  134. gstaichi-0.0.0.data/data/include/spirv-tools/libspirv.h +907 -0
  135. gstaichi-0.0.0.data/data/include/spirv-tools/libspirv.hpp +375 -0
  136. gstaichi-0.0.0.data/data/include/spirv-tools/linker.hpp +97 -0
  137. gstaichi-0.0.0.data/data/include/spirv-tools/optimizer.hpp +970 -0
  138. gstaichi-0.0.0.data/data/lib/SPIRV-Tools-diff.lib +0 -0
  139. gstaichi-0.0.0.data/data/lib/SPIRV-Tools-link.lib +0 -0
  140. gstaichi-0.0.0.data/data/lib/SPIRV-Tools-lint.lib +0 -0
  141. gstaichi-0.0.0.data/data/lib/SPIRV-Tools-opt.lib +0 -0
  142. gstaichi-0.0.0.data/data/lib/SPIRV-Tools-reduce.lib +0 -0
  143. gstaichi-0.0.0.data/data/lib/SPIRV-Tools-shared.lib +0 -0
  144. gstaichi-0.0.0.data/data/lib/SPIRV-Tools.lib +0 -0
  145. gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  146. gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  147. gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  148. gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  149. gstaichi-0.0.0.data/data/lib/glfw3.lib +0 -0
  150. gstaichi-0.0.0.dist-info/METADATA +97 -0
  151. gstaichi-0.0.0.dist-info/RECORD +154 -0
  152. gstaichi-0.0.0.dist-info/WHEEL +5 -0
  153. gstaichi-0.0.0.dist-info/licenses/LICENSE +201 -0
  154. gstaichi-0.0.0.dist-info/top_level.txt +1 -0
gstaichi/_kernels.py ADDED
@@ -0,0 +1,420 @@
1
+ # type: ignore
2
+
3
+ from gstaichi._funcs import field_fill_gstaichi_scope
4
+ from gstaichi._lib.utils import get_os_name
5
+ from gstaichi.lang import ops
6
+ from gstaichi.lang._ndrange import ndrange
7
+ from gstaichi.lang.expr import Expr
8
+ from gstaichi.lang.field import ScalarField
9
+ from gstaichi.lang.impl import grouped, static, static_assert
10
+ from gstaichi.lang.kernel_impl import func, kernel
11
+ from gstaichi.lang.misc import loop_config
12
+ from gstaichi.lang.simt import block, warp
13
+ from gstaichi.lang.snode import deactivate
14
+ from gstaichi.math import vec3
15
+ from gstaichi.types import ndarray_type, texture_type, vector
16
+ from gstaichi.types.annotations import template
17
+ from gstaichi.types.enums import Format
18
+ from gstaichi.types.primitive_types import f16, f32, f64, i32, u8
19
+
20
+
21
+ # A set of helper (meta)functions
22
@kernel
def fill_field(field: template(), val: template()):
    # Write `val` into every element of `field`. The cast to the field's
    # dtype is done once, before the loop, and the result is reused.
    casted = ops.cast(val, field.dtype)
    for idx in grouped(field):
        field[idx] = casted
27
+
28
+
29
@kernel
def fill_ndarray(ndarray: ndarray_type.ndarray(), val: template()):
    # Store `val` at every index of `ndarray`.
    for idx in grouped(ndarray):
        ndarray[idx] = val
33
+
34
+
35
@kernel
def fill_ndarray_matrix(ndarray: ndarray_type.ndarray(), val: template()):
    # Store `val` at every index of `ndarray`; the body is the same as
    # `fill_ndarray`, kept as a separately named kernel.
    for idx in grouped(ndarray):
        ndarray[idx] = val
39
+
40
+
41
@kernel
def tensor_to_ext_arr(tensor: template(), arr: ndarray_type.ndarray()):
    # Copy every element of the field `tensor` into `arr`, shifting each
    # index by the field's offset.
    #
    # An empty snode offset ([]) is replaced by one zero per dimension.
    base = static(tensor.snode.ptr.offset if len(tensor.snode.ptr.offset) != 0 else [0] * len(tensor.shape))

    for idx in grouped(tensor):
        arr[idx - base] = tensor[idx]
48
+
49
+
50
@kernel
def ndarray_to_ext_arr(ndarray: ndarray_type.ndarray(), arr: ndarray_type.ndarray()):
    # Element-wise copy from `ndarray` into `arr`, iterating over the
    # indices of `ndarray`.
    for idx in grouped(ndarray):
        arr[idx] = ndarray[idx]
54
+
55
+
56
@kernel
def ndarray_matrix_to_ext_arr(
    ndarray: ndarray_type.ndarray(),
    arr: ndarray_type.ndarray(),
    layout_is_aos: template(),
    as_vector: template(),
):
    # Copy a vector- or matrix-valued `ndarray` into the flat array `arr`.
    #
    # `as_vector` selects whether each element is read with one component
    # index or two. `layout_is_aos` selects whether the component indices
    # come after the element index (`arr[I, ...]`) or before it
    # (`arr[..., I]`). Both flags are resolved at compile time.
    for idx in grouped(ndarray):
        for row in static(range(ndarray[idx].n)):
            if static(not as_vector):
                for col in static(range(ndarray[idx].m)):
                    if static(layout_is_aos):
                        arr[idx, row, col] = ndarray[idx][row, col]
                    else:
                        arr[row, col, idx] = ndarray[idx][row, col]
            else:
                if static(layout_is_aos):
                    arr[idx, row] = ndarray[idx][row]
                else:
                    arr[row, idx] = ndarray[idx][row]
76
+
77
+
78
@kernel
def vector_to_fast_image(img: template(), out: ndarray_type.ndarray()):
    # Pack every pixel of the 2D field `img` into one integer of the flat
    # buffer `out`. Axis 1 of `img` is read in reverse order.
    static_assert(len(img.shape) == 2)
    # An empty snode offset is treated as [0, 0].
    origin = static(img.snode.ptr.offset if len(img.snode.ptr.offset) != 0 else [0, 0])
    i_origin = static(origin[0])
    j_origin = static(origin[1])
    # FIXME: Why is ``for i, j in img:`` slower than:
    for i, j in ndrange(*img.shape):
        r, g, b = 0, 0, 0
        pixel = img[i + i_origin, (img.shape[1] + j_origin) - 1 - j]
        if static(img.dtype in [f16, f32, f64]):
            # Floating-point channels: scale by 255, then clamp to [0, 255].
            r, g, b = ops.min(255, ops.max(0, int(pixel * 255)))[:3]
        else:
            # The only other accepted dtype is u8; channels are used as-is.
            static_assert(img.dtype == u8)
            r, g, b = pixel[:3]

        flat = j * img.shape[0] + i
        # We use i32 for |out| since Metal doesn't support u8 types
        if static(get_os_name() == "osx"):
            # What's -16777216?
            #
            # On Mac, the alpha channel has to be set to 0xff. Since Mac's
            # GUI is big-endian, the color is stored in ABGR order, so
            # 0xff000000 must be added; -16777216 is that value expressed
            # within i32's legal range. (Despite being clearer, adding the
            # literal 0xff000000 doesn't work.)
            opaque = -16777216
            out[flat] = (b << 16) + (g << 8) + r + opaque
        else:
            out[flat] = (r << 16) + (g << 8) + b
107
+
108
+
109
@kernel
def tensor_to_image(tensor: template(), arr: ndarray_type.ndarray()):
    # Write each element of the scalar field `tensor`, cast to f32, into
    # channels 0, 1 and 2 of `arr` at the offset-corrected index.
    #
    # An empty snode offset ([]) is replaced by one zero per dimension.
    base = static(tensor.snode.ptr.offset if len(tensor.snode.ptr.offset) != 0 else [0] * len(tensor.shape))
    for idx in grouped(tensor):
        gray = ops.cast(tensor[idx], f32)
        arr[idx - base, 0] = gray
        arr[idx - base, 1] = gray
        arr[idx - base, 2] = gray
118
+
119
+
120
@kernel
def vector_to_image(mat: template(), arr: ndarray_type.ndarray()):
    # Copy each component of the vector field `mat`, cast to f32, into the
    # matching channel of `arr`. When `mat` has at most two components,
    # channel 2 is set to zero.
    #
    # An empty snode offset ([]) is replaced by one zero per dimension.
    base = static(mat.snode.ptr.offset if len(mat.snode.ptr.offset) != 0 else [0] * len(mat.shape))
    for idx in grouped(mat):
        for channel in static(range(mat.n)):
            arr[idx - base, channel] = ops.cast(mat[idx][channel], f32)
        if static(mat.n <= 2):
            arr[idx - base, 2] = 0
129
+
130
+
131
@kernel
def tensor_to_tensor(tensor: template(), other: template()):
    # Copy `other` into `tensor`. Both fields must have the same shape
    # (checked at compile time); each one is addressed through its own
    # snode offset.
    static_assert(tensor.shape == other.shape)
    dims = static(tensor.shape)
    # An empty snode offset ([]) is replaced by one zero per dimension.
    dst_base = static(tensor.snode.ptr.offset if len(tensor.snode.ptr.offset) != 0 else [0] * len(dims))
    src_base = static(other.snode.ptr.offset if len(other.snode.ptr.offset) != 0 else [0] * len(dims))

    for idx in grouped(ndrange(*dims)):
        tensor[idx + dst_base] = other[idx + src_base]
140
+
141
+
142
@kernel
def ext_arr_to_tensor(arr: ndarray_type.ndarray(), tensor: template()):
    # Fill the field `tensor` from `arr`, reading `arr` at the field index
    # minus the field's offset.
    #
    # An empty snode offset ([]) is replaced by one zero per dimension.
    base = static(tensor.snode.ptr.offset if len(tensor.snode.ptr.offset) != 0 else [0] * len(tensor.shape))
    for idx in grouped(tensor):
        tensor[idx] = arr[idx - base]
148
+
149
+
150
@kernel
def ndarray_to_ndarray(ndarray: ndarray_type.ndarray(), other: ndarray_type.ndarray()):
    # Element-wise copy from `other` into `ndarray`, iterating over the
    # indices of `ndarray`.
    for idx in grouped(ndarray):
        ndarray[idx] = other[idx]
154
+
155
+
156
@kernel
def ext_arr_to_ndarray(arr: ndarray_type.ndarray(), ndarray: ndarray_type.ndarray()):
    # Element-wise copy from `arr` into `ndarray`, iterating over the
    # indices of `ndarray`.
    for idx in grouped(ndarray):
        ndarray[idx] = arr[idx]
160
+
161
+
162
@kernel
def ext_arr_to_ndarray_matrix(
    arr: ndarray_type.ndarray(),
    ndarray: ndarray_type.ndarray(),
    layout_is_aos: template(),
    as_vector: template(),
):
    # Fill a vector- or matrix-valued `ndarray` from the flat array `arr`.
    #
    # `as_vector` selects whether each element is written with one
    # component index or two. `layout_is_aos` selects whether the component
    # indices come after the element index (`arr[I, ...]`) or before it
    # (`arr[..., I]`). Both flags are resolved at compile time.
    for idx in grouped(ndarray):
        for row in static(range(ndarray[idx].n)):
            if static(not as_vector):
                for col in static(range(ndarray[idx].m)):
                    if static(layout_is_aos):
                        ndarray[idx][row, col] = arr[idx, row, col]
                    else:
                        ndarray[idx][row, col] = arr[row, col, idx]
            else:
                if static(layout_is_aos):
                    ndarray[idx][row] = arr[idx, row]
                else:
                    ndarray[idx][row] = arr[row, idx]
182
+
183
+
184
@kernel
def matrix_to_ext_arr(mat: template(), arr: ndarray_type.ndarray(), as_vector: template()):
    # Copy the matrix/vector field `mat` into `arr` at offset-corrected
    # indices.
    #
    # `as_vector` decides whether `arr` is indexed with one trailing
    # component index or two. Whether `mat` is read with one component
    # index or two depends on `mat.ndim`, which defaults to 2 when the
    # attribute is missing. Both choices are resolved at compile time.
    #
    # An empty snode offset ([]) is replaced by one zero per dimension.
    base = static(mat.snode.ptr.offset if len(mat.snode.ptr.offset) != 0 else [0] * len(mat.shape))

    for idx in grouped(mat):
        for row in static(range(mat.n)):
            for col in static(range(mat.m)):
                if static(getattr(mat, "ndim", 2) == 1):
                    if static(as_vector):
                        arr[idx - base, row] = mat[idx][row]
                    else:
                        arr[idx - base, row, col] = mat[idx][row]
                else:
                    if static(as_vector):
                        arr[idx - base, row] = mat[idx][row, col]
                    else:
                        arr[idx - base, row, col] = mat[idx][row, col]
202
+
203
+
204
@kernel
def ext_arr_to_matrix(arr: ndarray_type.ndarray(), mat: template(), as_vector: template()):
    # Fill the matrix/vector field `mat` from `arr`, reading `arr` at
    # offset-corrected indices.
    #
    # `as_vector` decides whether `arr` is indexed with one trailing
    # component index or two. Whether `mat` is written with one component
    # index or two depends on `mat.ndim`, which defaults to 2 when the
    # attribute is missing. Both choices are resolved at compile time.
    #
    # An empty snode offset ([]) is replaced by one zero per dimension.
    base = static(mat.snode.ptr.offset if len(mat.snode.ptr.offset) != 0 else [0] * len(mat.shape))

    for idx in grouped(mat):
        for row in static(range(mat.n)):
            for col in static(range(mat.m)):
                if static(as_vector):
                    if static(getattr(mat, "ndim", 2) == 1):
                        mat[idx][row] = arr[idx - base, row]
                    else:
                        mat[idx][row, col] = arr[idx - base, row]
                else:
                    if static(getattr(mat, "ndim", 2) == 1):
                        mat[idx][row] = arr[idx - base, row, col]
                    else:
                        mat[idx][row, col] = arr[idx - base, row, col]
222
+
223
+
224
+ # Extract an ndarray with raw Vulkan memory layout into normal memory layout.
225
+ # The Vulkan layout is stored in the ndarray width-by-width along the
226
+ # ndarray's shape[1], which is the height axis (so use [size // h, size %
227
+ # h]). The height order of the Vulkan layout is flipped upside down (so use
228
+ # [size = (h - 1 - j) * w + i] to get the index).
229
@kernel
def arr_vulkan_layout_to_arr_normal_layout(vk_arr: ndarray_type.ndarray(), normal_arr: ndarray_type.ndarray()):
    # Scatter the flat buffer `vk_arr` into the 2D array `normal_arr`.
    # Element (i, j) is read from flat position (height - 1 - j) * width + i,
    # i.e. axis 1 is traversed in reverse order.
    static_assert(len(normal_arr.shape) == 2)
    width = normal_arr.shape[0]
    height = normal_arr.shape[1]
    for i, j in ndrange(width, height):
        normal_arr[i, j] = vk_arr[(height - 1 - j) * width + i]
236
+
237
+
238
+ # extract ndarray of raw vulkan memory layout into a gstaichi-field data
239
+ # structure with normal memory layout.
240
@kernel
def arr_vulkan_layout_to_field_normal_layout(vk_arr: ndarray_type.ndarray(), normal_field: template()):
    # Scatter the flat buffer `vk_arr` into the 2D field `normal_field`.
    # Element (i, j) is read from flat position (height - 1 - j) * width + i
    # and written at (i, j) shifted by the field's offset.
    static_assert(len(normal_field.shape) == 2)
    width = static(normal_field.shape[0])
    height = static(normal_field.shape[1])
    # An empty snode offset is treated as [0, 0].
    origin = static(normal_field.snode.ptr.offset if len(normal_field.snode.ptr.offset) != 0 else [0, 0])
    i_origin = static(origin[0])
    j_origin = static(origin[1])

    for i, j in ndrange(width, height):
        normal_field[i + i_origin, j + j_origin] = vk_arr[(height - 1 - j) * width + i]
251
+
252
+
253
@kernel
def clear_gradients(_vars: template()):
    # Zero every entry of `_vars`. The iteration domain is taken from the
    # first entry; each entry is wrapped as a ScalarField and receives a
    # zero cast to that entry's own data type.
    for idx in grouped(ScalarField(Expr(_vars[0]))):
        for var in static(_vars):
            ScalarField(Expr(var))[idx] = ops.cast(0, dtype=var.get_dt())
258
+
259
+
260
@kernel
def field_fill_python_scope(F: template(), val: template()):
    # Kernel wrapper that forwards both arguments unchanged to
    # `field_fill_gstaichi_scope`.
    field_fill_gstaichi_scope(F, val)
263
+
264
+
265
@kernel
def snode_deactivate(b: template()):
    # Call `deactivate` on `b` for every index yielded by iterating `b`.
    for idx in grouped(b):
        deactivate(b, idx)
269
+
270
+
271
@kernel
def snode_deactivate_dynamic(b: template()):
    # Call `deactivate` on `b` for every index yielded by iterating the
    # parent of `b` (rather than `b` itself).
    for idx in grouped(b.parent()):
        deactivate(b, idx)
275
+
276
+
277
@kernel
def load_texture_from_numpy(
    tex: texture_type.rw_texture(num_dimensions=2, fmt=Format.rgba8, lod=0),
    img: ndarray_type.ndarray(dtype=vec3, ndim=2),
):
    # Store every pixel of `img` into the rgba8 texture `tex`. The three
    # image channels are divided by 255.0; the fourth stored component is 0.
    for x, y in img:
        tex.store(
            vector(2, i32)([x, y]),
            vector(4, f32)([img[x, y][0], img[x, y][1], img[x, y][2], 0]) / 255.0,
        )
287
+
288
+
289
@kernel
def save_texture_to_numpy(
    tex: texture_type.rw_texture(num_dimensions=2, fmt=Format.rgba8, lod=0),
    img: ndarray_type.ndarray(dtype=vec3, ndim=2),
):
    # Read the rgb part of every texel of `tex`, multiply by 255, round,
    # and write the result to the matching pixel of `img`.
    for x, y in img:
        img[x, y] = ops.round(tex.load(vector(2, i32)([x, y])).rgb * 255)
296
+
297
+
298
+ # Odd-even merge sort
299
@kernel
def sort_stage(
    keys: template(),
    use_values: int,
    values: template(),
    N: int,
    p: int,
    k: int,
    invocations: int,
):
    # One compare-and-swap stage of the odd-even merge sort.
    #
    # keys:        field of sort keys, reordered in place.
    # use_values:  when non-zero, `values` is permuted together with `keys`.
    # values:      field swapped alongside `keys` (read only if use_values).
    # N:           number of elements taking part in the sort.
    # p, k:        parameters of the current stage; elements `k` apart are
    #              compared, and only when both lie in the same group of
    #              2 * p consecutive indices.
    # invocations: number of outer iterations to run for this stage.
    #
    # An empty snode offset is treated as 0.
    keys_offset = static(keys.snode.ptr.offset if len(keys.snode.ptr.offset) != 0 else 0)
    values_offset = static(values.snode.ptr.offset if len(values.snode.ptr.offset) != 0 else 0)
    for inv in range(invocations):
        j = k % p + inv * 2 * k
        for i in range(0, ops.min(k, N - j - k)):
            a = i + j
            b = i + j + k
            # Integer floor division keeps the group test exact. The indices
            # are non-negative, so this equals int(a / (p * 2)) mathematically
            # while avoiding a round trip through floating point, which can
            # lose precision for large indices.
            if a // (p * 2) == b // (p * 2):
                key_a = keys[a + keys_offset]
                key_b = keys[b + keys_offset]
                if key_a > key_b:
                    keys[a + keys_offset] = key_b
                    keys[b + keys_offset] = key_a
                    if use_values != 0:
                        temp = values[a + values_offset]
                        values[a + values_offset] = values[b + values_offset]
                        values[b + values_offset] = temp
326
+
327
+
328
+ # Parallel Prefix Sum (Scan)
329
@func
def warp_shfl_up_i32(val: template()):
    # Intra-warp scan built from `warp.shfl_up_i32`, manually unrolled over
    # the strides 1, 2, 4, 8 and 16. At each step a lane adds the shuffled
    # value only when its lane id is at least the current stride.
    # Returns the accumulated `val`.
    tid = block.global_thread_idx()
    WARP_SZ = 32
    lane = tid % WARP_SZ
    stride = 1
    shifted = warp.shfl_up_i32(warp.active_mask(), val, stride)
    if lane >= stride:
        val += shifted
    stride = 2
    shifted = warp.shfl_up_i32(warp.active_mask(), val, stride)
    if lane >= stride:
        val += shifted
    stride = 4
    shifted = warp.shfl_up_i32(warp.active_mask(), val, stride)
    if lane >= stride:
        val += shifted
    stride = 8
    shifted = warp.shfl_up_i32(warp.active_mask(), val, stride)
    if lane >= stride:
        val += shifted
    stride = 16
    shifted = warp.shfl_up_i32(warp.active_mask(), val, stride)
    if lane >= stride:
        val += shifted
    return val
356
+
357
+
358
@kernel
def scan_add_inclusive(
    arr_in: template(),
    in_beg: i32,
    in_end: i32,
    single_block: template(),
    inclusive_add: template(),
):
    # Block-wise inclusive prefix sum over arr_in[in_beg:in_end], in place.
    #
    # arr_in:        field holding the data; also receives per-block totals.
    # in_beg/in_end: half-open index range to scan.
    # single_block:  when false, the last thread of each block also writes
    #                its value to arr_in[in_end + block_id].
    # inclusive_add: callable applied to each thread's value before the
    #                cross-warp combination (presumably a warp-level
    #                inclusive scan such as `warp_shfl_up_i32` -- confirm
    #                against the caller).
    WARP_SZ = 32
    BLOCK_SZ = 64
    loop_config(block_dim=64)
    for i in range(in_beg, in_end):
        val = arr_in[i]

        thread_id = i % BLOCK_SZ
        block_id = int((i - in_beg) // BLOCK_SZ)
        lane_id = thread_id % WARP_SZ
        warp_id = thread_id // WARP_SZ

        pad_shared = block.SharedArray((65,), i32)

        val = inclusive_add(val)
        block.sync()

        # Put warp scan results to smem
        # TODO replace smem with real smem when available
        if thread_id % WARP_SZ == WARP_SZ - 1:
            pad_shared[warp_id] = val
        block.sync()

        # Inter-warp scan, use the first thread in the first warp
        if warp_id == 0 and lane_id == 0:
            # The number of warps per block is an integer count, so use
            # floor division instead of true division for the loop bound.
            for k in range(1, BLOCK_SZ // WARP_SZ):
                pad_shared[k] += pad_shared[k - 1]
        block.sync()

        # Update data with warp sums
        warp_sum = 0
        if warp_id > 0:
            warp_sum = pad_shared[warp_id - 1]
        val += warp_sum
        arr_in[i] = val

        # Update partial sums except the final block
        if not single_block and (thread_id == BLOCK_SZ - 1):
            arr_in[in_end + block_id] = val
404
+
405
+
406
@kernel
def uniform_add(arr_in: template(), in_beg: i32, in_end: i32):
    # Add to every element from the second block onward the value stored at
    # arr_in[in_end + block_id - 1]. The first BLOCK_SZ elements of the
    # range are skipped by the loop start.
    BLOCK_SZ = 64
    loop_config(block_dim=64)
    for pos in range(in_beg + BLOCK_SZ, in_end):
        blk = int((pos - in_beg) // BLOCK_SZ)
        arr_in[pos] += arr_in[in_end + blk - 1]
413
+
414
+
415
@kernel
def blit_from_field_to_field(dst: template(), src: template(), offset: i32, size: i32):
    # Copy `size` consecutive elements from the start of `src` into `dst`,
    # beginning at position `offset` of `dst`. Each field's own snode offset
    # is added to its index; an empty snode offset is treated as 0.
    dst_base = static(dst.snode.ptr.offset if len(dst.snode.ptr.offset) != 0 else 0)
    src_base = static(src.snode.ptr.offset if len(src.snode.ptr.offset) != 0 else 0)
    for n in range(size):
        dst[n + dst_base + offset] = src[n + src_base]
@@ -0,0 +1,5 @@
1
+ # type: ignore
2
+
3
+ from gstaichi._lib.utils import ti_python_core as core
4
+
5
+ __all__ = ["core"]
File without changes