gstaichi 0.1.18.dev1__cp310-cp310-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3.h +6389 -0
- gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3native.h +594 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/instrument.hpp +268 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.h +907 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.hpp +375 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/linker.hpp +97 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/optimizer.hpp +970 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv.h +2568 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv.hpp +2579 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
- gstaichi-0.1.18.dev1.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
- gstaichi-0.1.18.dev1.dist-info/METADATA +108 -0
- gstaichi-0.1.18.dev1.dist-info/RECORD +219 -0
- gstaichi-0.1.18.dev1.dist-info/WHEEL +5 -0
- gstaichi-0.1.18.dev1.dist-info/entry_points.txt +2 -0
- gstaichi-0.1.18.dev1.dist-info/licenses/LICENSE +201 -0
- gstaichi-0.1.18.dev1.dist-info/top_level.txt +1 -0
- taichi/__init__.py +44 -0
- taichi/__main__.py +5 -0
- taichi/_funcs.py +706 -0
- taichi/_kernels.py +420 -0
- taichi/_lib/__init__.py +3 -0
- taichi/_lib/c_api/include/taichi/cpp/taichi.hpp +1401 -0
- taichi/_lib/c_api/include/taichi/taichi.h +29 -0
- taichi/_lib/c_api/include/taichi/taichi_core.h +1111 -0
- taichi/_lib/c_api/include/taichi/taichi_cpu.h +29 -0
- taichi/_lib/c_api/include/taichi/taichi_metal.h +72 -0
- taichi/_lib/c_api/include/taichi/taichi_platform.h +55 -0
- taichi/_lib/c_api/include/taichi/taichi_unity.h +64 -0
- taichi/_lib/c_api/include/taichi/taichi_vulkan.h +151 -0
- taichi/_lib/c_api/lib/libtaichi_c_api.dylib +0 -0
- taichi/_lib/c_api/runtime/libMoltenVK.dylib +0 -0
- taichi/_lib/c_api/runtime/runtime_arm64.bc +0 -0
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfig.cmake +29 -0
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfigVersion.cmake +65 -0
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiTargets.cmake +121 -0
- taichi/_lib/core/__init__.py +0 -0
- taichi/_lib/core/py.typed +0 -0
- taichi/_lib/core/taichi_python.cpython-310-darwin.so +0 -0
- taichi/_lib/core/taichi_python.pyi +3077 -0
- taichi/_lib/runtime/libMoltenVK.dylib +0 -0
- taichi/_lib/runtime/runtime_arm64.bc +0 -0
- taichi/_lib/utils.py +249 -0
- taichi/_logging.py +131 -0
- taichi/_main.py +552 -0
- taichi/_snode/__init__.py +5 -0
- taichi/_snode/fields_builder.py +189 -0
- taichi/_snode/snode_tree.py +34 -0
- taichi/_ti_module/__init__.py +3 -0
- taichi/_ti_module/cppgen.py +309 -0
- taichi/_ti_module/module.py +145 -0
- taichi/_version.py +1 -0
- taichi/_version_check.py +100 -0
- taichi/ad/__init__.py +3 -0
- taichi/ad/_ad.py +530 -0
- taichi/algorithms/__init__.py +3 -0
- taichi/algorithms/_algorithms.py +117 -0
- taichi/aot/__init__.py +12 -0
- taichi/aot/_export.py +28 -0
- taichi/aot/conventions/__init__.py +3 -0
- taichi/aot/conventions/gfxruntime140/__init__.py +38 -0
- taichi/aot/conventions/gfxruntime140/dr.py +244 -0
- taichi/aot/conventions/gfxruntime140/sr.py +613 -0
- taichi/aot/module.py +253 -0
- taichi/aot/utils.py +151 -0
- taichi/assets/.git +1 -0
- taichi/assets/Go-Regular.ttf +0 -0
- taichi/assets/static/imgs/ti_gallery.png +0 -0
- taichi/examples/minimal.py +28 -0
- taichi/experimental.py +16 -0
- taichi/graph/__init__.py +3 -0
- taichi/graph/_graph.py +292 -0
- taichi/lang/__init__.py +50 -0
- taichi/lang/_ndarray.py +348 -0
- taichi/lang/_ndrange.py +152 -0
- taichi/lang/_texture.py +172 -0
- taichi/lang/_wrap_inspect.py +189 -0
- taichi/lang/any_array.py +99 -0
- taichi/lang/argpack.py +411 -0
- taichi/lang/ast/__init__.py +5 -0
- taichi/lang/ast/ast_transformer.py +1806 -0
- taichi/lang/ast/ast_transformer_utils.py +328 -0
- taichi/lang/ast/checkers.py +106 -0
- taichi/lang/ast/symbol_resolver.py +57 -0
- taichi/lang/ast/transform.py +9 -0
- taichi/lang/common_ops.py +310 -0
- taichi/lang/exception.py +80 -0
- taichi/lang/expr.py +180 -0
- taichi/lang/field.py +464 -0
- taichi/lang/impl.py +1246 -0
- taichi/lang/kernel_arguments.py +157 -0
- taichi/lang/kernel_impl.py +1415 -0
- taichi/lang/matrix.py +1877 -0
- taichi/lang/matrix_ops.py +341 -0
- taichi/lang/matrix_ops_utils.py +190 -0
- taichi/lang/mesh.py +687 -0
- taichi/lang/misc.py +807 -0
- taichi/lang/ops.py +1489 -0
- taichi/lang/runtime_ops.py +13 -0
- taichi/lang/shell.py +35 -0
- taichi/lang/simt/__init__.py +5 -0
- taichi/lang/simt/block.py +94 -0
- taichi/lang/simt/grid.py +7 -0
- taichi/lang/simt/subgroup.py +191 -0
- taichi/lang/simt/warp.py +96 -0
- taichi/lang/snode.py +487 -0
- taichi/lang/source_builder.py +150 -0
- taichi/lang/struct.py +855 -0
- taichi/lang/util.py +381 -0
- taichi/linalg/__init__.py +8 -0
- taichi/linalg/matrixfree_cg.py +310 -0
- taichi/linalg/sparse_cg.py +59 -0
- taichi/linalg/sparse_matrix.py +303 -0
- taichi/linalg/sparse_solver.py +123 -0
- taichi/math/__init__.py +11 -0
- taichi/math/_complex.py +204 -0
- taichi/math/mathimpl.py +886 -0
- taichi/profiler/__init__.py +6 -0
- taichi/profiler/kernel_metrics.py +260 -0
- taichi/profiler/kernel_profiler.py +592 -0
- taichi/profiler/memory_profiler.py +15 -0
- taichi/profiler/scoped_profiler.py +36 -0
- taichi/shaders/Circles_vk.frag +29 -0
- taichi/shaders/Circles_vk.vert +45 -0
- taichi/shaders/Circles_vk_frag.spv +0 -0
- taichi/shaders/Circles_vk_vert.spv +0 -0
- taichi/shaders/Lines_vk.frag +9 -0
- taichi/shaders/Lines_vk.vert +11 -0
- taichi/shaders/Lines_vk_frag.spv +0 -0
- taichi/shaders/Lines_vk_vert.spv +0 -0
- taichi/shaders/Mesh_vk.frag +71 -0
- taichi/shaders/Mesh_vk.vert +68 -0
- taichi/shaders/Mesh_vk_frag.spv +0 -0
- taichi/shaders/Mesh_vk_vert.spv +0 -0
- taichi/shaders/Particles_vk.frag +95 -0
- taichi/shaders/Particles_vk.vert +73 -0
- taichi/shaders/Particles_vk_frag.spv +0 -0
- taichi/shaders/Particles_vk_vert.spv +0 -0
- taichi/shaders/SceneLines2quad_vk_comp.spv +0 -0
- taichi/shaders/SceneLines_vk.frag +9 -0
- taichi/shaders/SceneLines_vk.vert +12 -0
- taichi/shaders/SceneLines_vk_frag.spv +0 -0
- taichi/shaders/SceneLines_vk_vert.spv +0 -0
- taichi/shaders/SetImage_vk.frag +21 -0
- taichi/shaders/SetImage_vk.vert +15 -0
- taichi/shaders/SetImage_vk_frag.spv +0 -0
- taichi/shaders/SetImage_vk_vert.spv +0 -0
- taichi/shaders/Triangles_vk.frag +16 -0
- taichi/shaders/Triangles_vk.vert +29 -0
- taichi/shaders/Triangles_vk_frag.spv +0 -0
- taichi/shaders/Triangles_vk_vert.spv +0 -0
- taichi/shaders/lines2quad_vk_comp.spv +0 -0
- taichi/sparse/__init__.py +3 -0
- taichi/sparse/_sparse_grid.py +77 -0
- taichi/tools/__init__.py +12 -0
- taichi/tools/diagnose.py +124 -0
- taichi/tools/np2ply.py +364 -0
- taichi/tools/vtk.py +38 -0
- taichi/types/__init__.py +19 -0
- taichi/types/annotations.py +47 -0
- taichi/types/compound_types.py +90 -0
- taichi/types/enums.py +49 -0
- taichi/types/ndarray_type.py +147 -0
- taichi/types/primitive_types.py +203 -0
- taichi/types/quant.py +88 -0
- taichi/types/texture_type.py +85 -0
- taichi/types/utils.py +13 -0
taichi/_kernels.py
ADDED
@@ -0,0 +1,420 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
from taichi._funcs import field_fill_taichi_scope
|
4
|
+
from taichi._lib.utils import get_os_name
|
5
|
+
from taichi.lang import ops
|
6
|
+
from taichi.lang._ndrange import ndrange
|
7
|
+
from taichi.lang.expr import Expr
|
8
|
+
from taichi.lang.field import ScalarField
|
9
|
+
from taichi.lang.impl import grouped, static, static_assert
|
10
|
+
from taichi.lang.kernel_impl import func, kernel
|
11
|
+
from taichi.lang.misc import loop_config
|
12
|
+
from taichi.lang.simt import block, warp
|
13
|
+
from taichi.lang.snode import deactivate
|
14
|
+
from taichi.math import vec3
|
15
|
+
from taichi.types import ndarray_type, texture_type, vector
|
16
|
+
from taichi.types.annotations import template
|
17
|
+
from taichi.types.enums import Format
|
18
|
+
from taichi.types.primitive_types import f16, f32, f64, i32, u8
|
19
|
+
|
20
|
+
|
21
|
+
# A set of helper (meta)functions
|
22
|
+
@kernel
def fill_field(field: template(), val: template()):
    # Set every active element of `field` to `val`.
    # Cast once, outside the loop, so each store writes the same converted value.
    fill_value = ops.cast(val, field.dtype)
    for I in grouped(field):
        field[I] = fill_value
|
27
|
+
|
28
|
+
|
29
|
+
@kernel
def fill_ndarray(ndarray: ndarray_type.ndarray(), val: template()):
    # Write `val` into every element of a scalar ndarray.
    for J in grouped(ndarray):
        ndarray[J] = val
|
33
|
+
|
34
|
+
|
35
|
+
@kernel
def fill_ndarray_matrix(ndarray: ndarray_type.ndarray(), val: template()):
    # Write `val` (a matrix/vector value) into every element of a matrix ndarray.
    for J in grouped(ndarray):
        ndarray[J] = val
|
39
|
+
|
40
|
+
|
41
|
+
@kernel
def tensor_to_ext_arr(tensor: template(), arr: ndarray_type.ndarray()):
    # Copy a (possibly offset) scalar field into an external array.
    # The snode offset defaults to []; normalize it to [0] * ndim so the
    # index subtraction below is always well-formed.
    base = static(tensor.snode.ptr.offset if len(tensor.snode.ptr.offset) != 0 else [0] * len(tensor.shape))

    for I in grouped(tensor):
        arr[I - base] = tensor[I]
|
48
|
+
|
49
|
+
|
50
|
+
@kernel
def ndarray_to_ext_arr(ndarray: ndarray_type.ndarray(), arr: ndarray_type.ndarray()):
    # Element-wise copy from a scalar ndarray into an external array.
    for J in grouped(ndarray):
        arr[J] = ndarray[J]
|
54
|
+
|
55
|
+
|
56
|
+
@kernel
def ndarray_matrix_to_ext_arr(
    ndarray: ndarray_type.ndarray(),
    arr: ndarray_type.ndarray(),
    layout_is_aos: template(),
    as_vector: template(),
):
    # Copy a vector/matrix ndarray into a flat external array.  The matrix
    # component axes are appended to the element index for AOS layout and
    # prepended for SOA layout; vectors use one component axis, matrices two.
    for I in grouped(ndarray):
        for r in static(range(ndarray[I].n)):
            if static(as_vector):
                if static(layout_is_aos):
                    arr[I, r] = ndarray[I][r]
                else:
                    arr[r, I] = ndarray[I][r]
            else:
                for c in static(range(ndarray[I].m)):
                    if static(layout_is_aos):
                        arr[I, r, c] = ndarray[I][r, c]
                    else:
                        arr[r, c, I] = ndarray[I][r, c]
|
76
|
+
|
77
|
+
|
78
|
+
@kernel
def vector_to_fast_image(img: template(), out: ndarray_type.ndarray()):
    # Pack a 2-D color field into a flat i32 buffer of packed pixels.
    static_assert(len(img.shape) == 2)
    # snode offset defaults to []; normalize to [0, 0] for a 2-D image.
    base = static(img.snode.ptr.offset if len(img.snode.ptr.offset) != 0 else [0, 0])
    row_off = static(base[0])
    col_off = static(base[1])
    # FIXME: Why is ``for i, j in img:`` slower than:
    for i, j in ndrange(*img.shape):
        r, g, b = 0, 0, 0
        # Read bottom-up so the output image is vertically flipped.
        color = img[i + row_off, (img.shape[1] + col_off) - 1 - j]
        if static(img.dtype in [f16, f32, f64]):
            # Float images are in [0, 1]; scale then clamp to a byte.
            r, g, b = ops.min(255, ops.max(0, int(color * 255)))[:3]
        else:
            static_assert(img.dtype == u8)
            r, g, b = color[:3]

        idx = j * img.shape[0] + i
        # We use i32 for |out| since OpenGL and Metal doesn't support u8 types
        if static(get_os_name() != "osx"):
            out[idx] = (r << 16) + (g << 8) + b
        else:
            # On Mac the GUI buffer is big-endian ABGR and the alpha channel
            # must be 0xff.  0xff000000 is out of the positive i32 range, so
            # add its two's-complement value -16777216 instead.
            alpha = -16777216
            out[idx] = (b << 16) + (g << 8) + r + alpha
|
107
|
+
|
108
|
+
|
109
|
+
@kernel
def tensor_to_image(tensor: template(), arr: ndarray_type.ndarray()):
    # Render a scalar field as a grayscale image: broadcast each value
    # into all three RGB channels as f32.
    # snode offset defaults to []; normalize to [0] * ndim.
    base = static(tensor.snode.ptr.offset if len(tensor.snode.ptr.offset) != 0 else [0] * len(tensor.shape))
    for I in grouped(tensor):
        gray = ops.cast(tensor[I], f32)
        arr[I - base, 0] = gray
        arr[I - base, 1] = gray
        arr[I - base, 2] = gray
|
118
|
+
|
119
|
+
|
120
|
+
@kernel
def vector_to_image(mat: template(), arr: ndarray_type.ndarray()):
    # Render a vector field as an image, one channel per component.
    # snode offset defaults to []; normalize to [0] * ndim.
    base = static(mat.snode.ptr.offset if len(mat.snode.ptr.offset) != 0 else [0] * len(mat.shape))
    for I in grouped(mat):
        for ch in static(range(mat.n)):
            arr[I - base, ch] = ops.cast(mat[I][ch], f32)
            # Pad the blue channel with zero for 1- or 2-component vectors.
            if static(mat.n <= 2):
                arr[I - base, 2] = 0
|
129
|
+
|
130
|
+
|
131
|
+
@kernel
def tensor_to_tensor(tensor: template(), other: template()):
    # Copy `other` into `tensor`; shapes must match exactly.
    static_assert(tensor.shape == other.shape)
    shape = static(tensor.shape)
    # Each field may carry its own snode offset; normalize missing ones to 0.
    dst_base = static(tensor.snode.ptr.offset if len(tensor.snode.ptr.offset) != 0 else [0] * len(shape))
    src_base = static(other.snode.ptr.offset if len(other.snode.ptr.offset) != 0 else [0] * len(shape))

    for I in grouped(ndrange(*shape)):
        tensor[I + dst_base] = other[I + src_base]
|
140
|
+
|
141
|
+
|
142
|
+
@kernel
def ext_arr_to_tensor(arr: ndarray_type.ndarray(), tensor: template()):
    # Copy an external array into a (possibly offset) scalar field.
    # snode offset defaults to []; normalize to [0] * ndim.
    base = static(tensor.snode.ptr.offset if len(tensor.snode.ptr.offset) != 0 else [0] * len(tensor.shape))
    for I in grouped(tensor):
        tensor[I] = arr[I - base]
|
148
|
+
|
149
|
+
|
150
|
+
@kernel
def ndarray_to_ndarray(ndarray: ndarray_type.ndarray(), other: ndarray_type.ndarray()):
    # Element-wise copy from `other` into `ndarray`.
    for J in grouped(ndarray):
        ndarray[J] = other[J]
|
154
|
+
|
155
|
+
|
156
|
+
@kernel
def ext_arr_to_ndarray(arr: ndarray_type.ndarray(), ndarray: ndarray_type.ndarray()):
    # Element-wise copy from an external array into an ndarray.
    for J in grouped(ndarray):
        ndarray[J] = arr[J]
|
160
|
+
|
161
|
+
|
162
|
+
@kernel
def ext_arr_to_ndarray_matrix(
    arr: ndarray_type.ndarray(),
    ndarray: ndarray_type.ndarray(),
    layout_is_aos: template(),
    as_vector: template(),
):
    # Inverse of `ndarray_matrix_to_ext_arr`: scatter a flat external array
    # back into a vector/matrix ndarray.  Component axes follow the element
    # index for AOS layout and precede it for SOA layout.
    for I in grouped(ndarray):
        for r in static(range(ndarray[I].n)):
            if static(as_vector):
                if static(layout_is_aos):
                    ndarray[I][r] = arr[I, r]
                else:
                    ndarray[I][r] = arr[r, I]
            else:
                for c in static(range(ndarray[I].m)):
                    if static(layout_is_aos):
                        ndarray[I][r, c] = arr[I, r, c]
                    else:
                        ndarray[I][r, c] = arr[r, c, I]
|
182
|
+
|
183
|
+
|
184
|
+
@kernel
def matrix_to_ext_arr(mat: template(), arr: ndarray_type.ndarray(), as_vector: template()):
    # Copy a matrix field into an external array, flattening vector
    # elements (ndim == 1) with one component index and matrices with two.
    # snode offset defaults to []; normalize to [0] * ndim.
    base = static(mat.snode.ptr.offset if len(mat.snode.ptr.offset) != 0 else [0] * len(mat.shape))

    for I in grouped(mat):
        for r in static(range(mat.n)):
            for c in static(range(mat.m)):
                if static(as_vector):
                    if static(getattr(mat, "ndim", 2) == 1):
                        arr[I - base, r] = mat[I][r]
                    else:
                        arr[I - base, r] = mat[I][r, c]
                else:
                    if static(getattr(mat, "ndim", 2) == 1):
                        arr[I - base, r, c] = mat[I][r]
                    else:
                        arr[I - base, r, c] = mat[I][r, c]
|
202
|
+
|
203
|
+
|
204
|
+
@kernel
def ext_arr_to_matrix(arr: ndarray_type.ndarray(), mat: template(), as_vector: template()):
    # Inverse of `matrix_to_ext_arr`: scatter an external array back into a
    # matrix field, using one component index for vectors (ndim == 1) and
    # two for matrices.
    # snode offset defaults to []; normalize to [0] * ndim.
    base = static(mat.snode.ptr.offset if len(mat.snode.ptr.offset) != 0 else [0] * len(mat.shape))

    for I in grouped(mat):
        for r in static(range(mat.n)):
            for c in static(range(mat.m)):
                if static(getattr(mat, "ndim", 2) == 1):
                    if static(as_vector):
                        mat[I][r] = arr[I - base, r]
                    else:
                        mat[I][r] = arr[I - base, r, c]
                else:
                    if static(as_vector):
                        mat[I][r, c] = arr[I - base, r]
                    else:
                        mat[I][r, c] = arr[I - base, r, c]
|
222
|
+
|
223
|
+
|
224
|
+
# Convert an ndarray holding a raw Vulkan image into normal (i, j) layout.
# The Vulkan buffer is flattened row-by-row along the destination's
# shape[1] (height) axis and its rows are stored flipped vertically, so
# source index (h - 1 - j) * w + i undoes both the flattening and the flip.
@kernel
def arr_vulkan_layout_to_arr_normal_layout(vk_arr: ndarray_type.ndarray(), normal_arr: ndarray_type.ndarray()):
    static_assert(len(normal_arr.shape) == 2)
    width = normal_arr.shape[0]
    height = normal_arr.shape[1]
    for i, j in ndrange(width, height):
        normal_arr[i, j] = vk_arr[(height - 1 - j) * width + i]
|
236
|
+
|
237
|
+
|
238
|
+
# Unpack a raw-Vulkan-layout ndarray (flattened and vertically flipped)
# into a taichi field with normal memory layout, honoring any snode offset
# the destination field may carry.
@kernel
def arr_vulkan_layout_to_field_normal_layout(vk_arr: ndarray_type.ndarray(), normal_field: template()):
    static_assert(len(normal_field.shape) == 2)
    width = static(normal_field.shape[0])
    height = static(normal_field.shape[1])
    # snode offset defaults to []; normalize to [0, 0].
    base = static(normal_field.snode.ptr.offset if len(normal_field.snode.ptr.offset) != 0 else [0, 0])
    di = static(base[0])
    dj = static(base[1])

    for i, j in ndrange(width, height):
        normal_field[i + di, j + dj] = vk_arr[(height - 1 - j) * width + i]
|
251
|
+
|
252
|
+
|
253
|
+
@kernel
def clear_gradients(_vars: template()):
    # Zero every gradient field in `_vars`.  All of them share one shape,
    # so iterate the first field's index space and clear each at that index.
    for I in grouped(ScalarField(Expr(_vars[0]))):
        for grad_var in static(_vars):
            ScalarField(Expr(grad_var))[I] = ops.cast(0, dtype=grad_var.get_dt())
|
258
|
+
|
259
|
+
|
260
|
+
@kernel
def field_fill_python_scope(F: template(), val: template()):
    # Python-scope entry point: defer to the taichi-scope fill helper.
    field_fill_taichi_scope(F, val)
|
263
|
+
|
264
|
+
|
265
|
+
@kernel
def snode_deactivate(b: template()):
    # Deactivate every active cell of the snode.
    for J in grouped(b):
        deactivate(b, J)
|
269
|
+
|
270
|
+
|
271
|
+
@kernel
def snode_deactivate_dynamic(b: template()):
    # Dynamic snodes are deactivated through their parent's index space.
    for J in grouped(b.parent()):
        deactivate(b, J)
|
275
|
+
|
276
|
+
|
277
|
+
@kernel
def load_texture_from_numpy(
    tex: texture_type.rw_texture(num_dimensions=2, fmt=Format.rgba8, lod=0),
    img: ndarray_type.ndarray(dtype=vec3, ndim=2),
):
    # Upload an RGB image (values in [0, 255]) into an rgba8 texture,
    # rescaling channels to [0, 1]; alpha is written as 0.
    for i, j in img:
        tex.store(
            vector(2, i32)([i, j]),
            vector(4, f32)([img[i, j][0], img[i, j][1], img[i, j][2], 0]) / 255.0,
        )
|
287
|
+
|
288
|
+
|
289
|
+
@kernel
def save_texture_to_numpy(
    tex: texture_type.rw_texture(num_dimensions=2, fmt=Format.rgba8, lod=0),
    img: ndarray_type.ndarray(dtype=vec3, ndim=2),
):
    # Read back the RGB channels of an rgba8 texture, rescaled from [0, 1]
    # to [0, 255] and rounded.
    for i, j in img:
        img[i, j] = ops.round(tex.load(vector(2, i32)([i, j])).rgb * 255)
|
296
|
+
|
297
|
+
|
298
|
+
# Odd-even merge sort
#
# One compare-exchange stage of the sorting network over `keys`; when
# `use_values` is non-zero the matching `values` entries are swapped
# identically so key/value pairs stay aligned.
@kernel
def sort_stage(
    keys: template(),
    use_values: int,
    values: template(),
    N: int,
    p: int,
    k: int,
    invocations: int,
):
    # Fields may carry a snode offset; default to 0 when absent.
    keys_base = static(keys.snode.ptr.offset if len(keys.snode.ptr.offset) != 0 else 0)
    values_base = static(values.snode.ptr.offset if len(values.snode.ptr.offset) != 0 else 0)
    for inv in range(invocations):
        j = k % p + inv * 2 * k
        for i in range(0, ops.min(k, N - j - k)):
            lo = i + j
            hi = i + j + k
            # Only exchange pairs that fall inside the same 2p-wide block.
            if int(lo / (p * 2)) == int(hi / (p * 2)):
                key_lo = keys[lo + keys_base]
                key_hi = keys[hi + keys_base]
                if key_lo > key_hi:
                    keys[lo + keys_base] = key_hi
                    keys[hi + keys_base] = key_lo
                    if use_values != 0:
                        tmp = values[lo + values_base]
                        values[lo + values_base] = values[hi + values_base]
                        values[hi + values_base] = tmp
|
326
|
+
|
327
|
+
|
328
|
+
# Parallel Prefix Sum (Scan)
@func
def warp_shfl_up_i32(val: template()):
    # Intra-warp inclusive scan: each lane accumulates the values of lanes
    # below it via shuffle-up, manually unrolled for strides 1, 2, 4, 8, 16.
    global_tid = block.global_thread_idx()
    WARP_SZ = 32
    lane_id = global_tid % WARP_SZ
    stride = 1
    incoming = warp.shfl_up_i32(warp.active_mask(), val, stride)
    if lane_id >= stride:
        val += incoming
    stride = 2
    incoming = warp.shfl_up_i32(warp.active_mask(), val, stride)
    if lane_id >= stride:
        val += incoming
    stride = 4
    incoming = warp.shfl_up_i32(warp.active_mask(), val, stride)
    if lane_id >= stride:
        val += incoming
    stride = 8
    incoming = warp.shfl_up_i32(warp.active_mask(), val, stride)
    if lane_id >= stride:
        val += incoming
    stride = 16
    incoming = warp.shfl_up_i32(warp.active_mask(), val, stride)
    if lane_id >= stride:
        val += incoming
    return val
|
356
|
+
|
357
|
+
|
358
|
+
@kernel
def scan_add_inclusive(
    arr_in: template(),
    in_beg: i32,
    in_end: i32,
    single_block: template(),
    inclusive_add: template(),
):
    # Block-wise inclusive prefix sum over arr_in[in_beg:in_end].
    # Each 64-thread block scans its own chunk (warp-level scan via
    # `inclusive_add`, then an inter-warp scan through shared memory).
    # Unless `single_block`, the last thread of each block also stashes the
    # block total past `in_end` so a follow-up `uniform_add` pass can
    # propagate partial sums across blocks.
    WARP_SZ = 32
    BLOCK_SZ = 64
    loop_config(block_dim=64)
    for i in range(in_beg, in_end):
        val = arr_in[i]

        thread_id = i % BLOCK_SZ
        block_id = int((i - in_beg) // BLOCK_SZ)
        lane_id = thread_id % WARP_SZ
        warp_id = thread_id // WARP_SZ

        pad_shared = block.SharedArray((65,), i32)

        # Intra-warp scan.
        val = inclusive_add(val)
        block.sync()

        # Put warp scan results to smem
        # TODO replace smem with real smem when available
        if thread_id % WARP_SZ == WARP_SZ - 1:
            pad_shared[warp_id] = val
        block.sync()

        # Inter-warp scan, use the first thread in the first warp.
        # Fix: the range bound is a count of warps per block, so use floor
        # division instead of true division (which yields a float).
        if warp_id == 0 and lane_id == 0:
            for k in range(1, BLOCK_SZ // WARP_SZ):
                pad_shared[k] += pad_shared[k - 1]
        block.sync()

        # Update data with warp sums
        warp_sum = 0
        if warp_id > 0:
            warp_sum = pad_shared[warp_id - 1]
        val += warp_sum
        arr_in[i] = val

        # Update partial sums except the final block
        if not single_block and (thread_id == BLOCK_SZ - 1):
            arr_in[in_end + block_id] = val
|
404
|
+
|
405
|
+
|
406
|
+
@kernel
def uniform_add(arr_in: template(), in_beg: i32, in_end: i32):
    # Second pass of the prefix-sum: add each block's predecessor partial
    # sum (stashed past `in_end` by scan_add_inclusive) to every element
    # after the first block.
    BLOCK_SZ = 64
    loop_config(block_dim=64)
    for i in range(in_beg + BLOCK_SZ, in_end):
        block_id = int((i - in_beg) // BLOCK_SZ)
        arr_in[i] += arr_in[in_end + block_id - 1]
|
413
|
+
|
414
|
+
|
415
|
+
@kernel
def blit_from_field_to_field(dst: template(), src: template(), offset: i32, size: i32):
    # Copy `size` elements of `src` into `dst` starting at `offset`,
    # honoring any snode offset either 1-D field may carry.
    dst_base = static(dst.snode.ptr.offset if len(dst.snode.ptr.offset) != 0 else 0)
    src_base = static(src.snode.ptr.offset if len(src.snode.ptr.offset) != 0 else 0)
    for i in range(size):
        dst[i + dst_base + offset] = src[i + src_base]
|
taichi/_lib/__init__.py
ADDED