gstaichi 2.1.1__cp310-cp310-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi/__init__.py +40 -0
- gstaichi/_funcs.py +706 -0
- gstaichi/_kernels.py +420 -0
- gstaichi/_lib/__init__.py +3 -0
- gstaichi/_lib/core/__init__.py +0 -0
- gstaichi/_lib/core/gstaichi_python.cpython-310-darwin.so +0 -0
- gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
- gstaichi/_lib/core/py.typed +0 -0
- gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
- gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
- gstaichi/_lib/utils.py +243 -0
- gstaichi/_logging.py +131 -0
- gstaichi/_snode/__init__.py +5 -0
- gstaichi/_snode/fields_builder.py +187 -0
- gstaichi/_snode/snode_tree.py +34 -0
- gstaichi/_test_tools/__init__.py +18 -0
- gstaichi/_test_tools/dataclass_test_tools.py +36 -0
- gstaichi/_test_tools/load_kernel_string.py +30 -0
- gstaichi/_test_tools/textwrap2.py +6 -0
- gstaichi/_version.py +1 -0
- gstaichi/_version_check.py +100 -0
- gstaichi/ad/__init__.py +3 -0
- gstaichi/ad/_ad.py +530 -0
- gstaichi/algorithms/__init__.py +3 -0
- gstaichi/algorithms/_algorithms.py +117 -0
- gstaichi/assets/.git +1 -0
- gstaichi/assets/Go-Regular.ttf +0 -0
- gstaichi/assets/static/imgs/ti_gallery.png +0 -0
- gstaichi/examples/lcg_python.py +26 -0
- gstaichi/examples/lcg_taichi.py +34 -0
- gstaichi/examples/minimal.py +28 -0
- gstaichi/experimental.py +16 -0
- gstaichi/lang/__init__.py +50 -0
- gstaichi/lang/_dataclass_util.py +31 -0
- gstaichi/lang/_fast_caching/__init__.py +3 -0
- gstaichi/lang/_fast_caching/args_hasher.py +110 -0
- gstaichi/lang/_fast_caching/config_hasher.py +30 -0
- gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
- gstaichi/lang/_fast_caching/function_hasher.py +57 -0
- gstaichi/lang/_fast_caching/hash_utils.py +11 -0
- gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
- gstaichi/lang/_fast_caching/src_hasher.py +75 -0
- gstaichi/lang/_kernel_impl_dataclass.py +212 -0
- gstaichi/lang/_ndarray.py +352 -0
- gstaichi/lang/_ndrange.py +152 -0
- gstaichi/lang/_template_mapper.py +195 -0
- gstaichi/lang/_texture.py +172 -0
- gstaichi/lang/_wrap_inspect.py +215 -0
- gstaichi/lang/any_array.py +99 -0
- gstaichi/lang/ast/__init__.py +5 -0
- gstaichi/lang/ast/ast_transformer.py +1323 -0
- gstaichi/lang/ast/ast_transformer_utils.py +346 -0
- gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
- gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
- gstaichi/lang/ast/checkers.py +106 -0
- gstaichi/lang/ast/symbol_resolver.py +57 -0
- gstaichi/lang/ast/transform.py +9 -0
- gstaichi/lang/common_ops.py +310 -0
- gstaichi/lang/exception.py +80 -0
- gstaichi/lang/expr.py +180 -0
- gstaichi/lang/field.py +428 -0
- gstaichi/lang/impl.py +1245 -0
- gstaichi/lang/kernel_arguments.py +155 -0
- gstaichi/lang/kernel_impl.py +1341 -0
- gstaichi/lang/matrix.py +1835 -0
- gstaichi/lang/matrix_ops.py +341 -0
- gstaichi/lang/matrix_ops_utils.py +190 -0
- gstaichi/lang/mesh.py +687 -0
- gstaichi/lang/misc.py +780 -0
- gstaichi/lang/ops.py +1494 -0
- gstaichi/lang/runtime_ops.py +13 -0
- gstaichi/lang/shell.py +35 -0
- gstaichi/lang/simt/__init__.py +5 -0
- gstaichi/lang/simt/block.py +94 -0
- gstaichi/lang/simt/grid.py +7 -0
- gstaichi/lang/simt/subgroup.py +191 -0
- gstaichi/lang/simt/warp.py +96 -0
- gstaichi/lang/snode.py +489 -0
- gstaichi/lang/source_builder.py +150 -0
- gstaichi/lang/struct.py +810 -0
- gstaichi/lang/util.py +312 -0
- gstaichi/linalg/__init__.py +8 -0
- gstaichi/linalg/matrixfree_cg.py +310 -0
- gstaichi/linalg/sparse_cg.py +59 -0
- gstaichi/linalg/sparse_matrix.py +303 -0
- gstaichi/linalg/sparse_solver.py +123 -0
- gstaichi/math/__init__.py +11 -0
- gstaichi/math/_complex.py +205 -0
- gstaichi/math/mathimpl.py +886 -0
- gstaichi/profiler/__init__.py +6 -0
- gstaichi/profiler/kernel_metrics.py +260 -0
- gstaichi/profiler/kernel_profiler.py +586 -0
- gstaichi/profiler/memory_profiler.py +15 -0
- gstaichi/profiler/scoped_profiler.py +36 -0
- gstaichi/sparse/__init__.py +3 -0
- gstaichi/sparse/_sparse_grid.py +77 -0
- gstaichi/tools/__init__.py +12 -0
- gstaichi/tools/diagnose.py +117 -0
- gstaichi/tools/np2ply.py +364 -0
- gstaichi/tools/vtk.py +38 -0
- gstaichi/types/__init__.py +19 -0
- gstaichi/types/annotations.py +52 -0
- gstaichi/types/compound_types.py +71 -0
- gstaichi/types/enums.py +49 -0
- gstaichi/types/ndarray_type.py +169 -0
- gstaichi/types/primitive_types.py +206 -0
- gstaichi/types/quant.py +88 -0
- gstaichi/types/texture_type.py +85 -0
- gstaichi/types/utils.py +11 -0
- gstaichi-2.1.1.data/data/include/GLFW/glfw3.h +6389 -0
- gstaichi-2.1.1.data/data/include/GLFW/glfw3native.h +594 -0
- gstaichi-2.1.1.data/data/include/spirv-tools/instrument.hpp +268 -0
- gstaichi-2.1.1.data/data/include/spirv-tools/libspirv.h +907 -0
- gstaichi-2.1.1.data/data/include/spirv-tools/libspirv.hpp +375 -0
- gstaichi-2.1.1.data/data/include/spirv-tools/linker.hpp +97 -0
- gstaichi-2.1.1.data/data/include/spirv-tools/optimizer.hpp +970 -0
- gstaichi-2.1.1.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
- gstaichi-2.1.1.data/data/include/spirv_cross/spirv.h +2568 -0
- gstaichi-2.1.1.data/data/include/spirv_cross/spirv.hpp +2579 -0
- gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
- gstaichi-2.1.1.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
- gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
- gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
- gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
- gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
- gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
- gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
- gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
- gstaichi-2.1.1.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
- gstaichi-2.1.1.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
- gstaichi-2.1.1.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
- gstaichi-2.1.1.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
- gstaichi-2.1.1.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
- gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
- gstaichi-2.1.1.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
- gstaichi-2.1.1.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
- gstaichi-2.1.1.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
- gstaichi-2.1.1.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
- gstaichi-2.1.1.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
- gstaichi-2.1.1.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
- gstaichi-2.1.1.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
- gstaichi-2.1.1.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
- gstaichi-2.1.1.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
- gstaichi-2.1.1.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
- gstaichi-2.1.1.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
- gstaichi-2.1.1.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
- gstaichi-2.1.1.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
- gstaichi-2.1.1.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
- gstaichi-2.1.1.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
- gstaichi-2.1.1.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
- gstaichi-2.1.1.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
- gstaichi-2.1.1.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
- gstaichi-2.1.1.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
- gstaichi-2.1.1.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
- gstaichi-2.1.1.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
- gstaichi-2.1.1.dist-info/METADATA +106 -0
- gstaichi-2.1.1.dist-info/RECORD +178 -0
- gstaichi-2.1.1.dist-info/WHEEL +5 -0
- gstaichi-2.1.1.dist-info/licenses/LICENSE +201 -0
- gstaichi-2.1.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,117 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
from gstaichi._kernels import (
|
4
|
+
blit_from_field_to_field,
|
5
|
+
scan_add_inclusive,
|
6
|
+
sort_stage,
|
7
|
+
uniform_add,
|
8
|
+
warp_shfl_up_i32,
|
9
|
+
)
|
10
|
+
from gstaichi.lang.impl import current_cfg, field
|
11
|
+
from gstaichi.lang.kernel_impl import data_oriented
|
12
|
+
from gstaichi.lang.misc import cuda, vulkan
|
13
|
+
from gstaichi.lang.runtime_ops import sync
|
14
|
+
from gstaichi.lang.simt import subgroup
|
15
|
+
from gstaichi.types.primitive_types import i32
|
16
|
+
|
17
|
+
|
18
|
+
def parallel_sort(keys, values=None):
    """Drive the stages of a Batcher odd-even merge sort over ``keys``.

    The stage schedule is computed on the host: for every merge width ``p``
    (1, 2, 4, ... while ``p < N``) and every comparator distance ``k``
    (``p``, ``p // 2``, ..., 1) one ``sort_stage`` kernel is launched and then
    ``sync()`` is called before the next stage is issued.

    Args:
        keys: Container whose ``shape[0]`` gives the element count ``N``; it is
            handed to ``sort_stage`` as the key buffer.
        values: Optional companion container. When given it is passed to
            ``sort_stage`` together with the flag ``1``; when omitted ``keys``
            is passed a second time together with the flag ``0``.
            NOTE(review): presumably the flag tells the kernel whether to
            permute ``values`` alongside ``keys`` -- confirm in ``sort_stage``.

    References:
        https://developer.nvidia.com/gpugems/gpugems2/part-vi-simulation-and-numerical-algorithms/chapter-46-improved-gpu-sorting
        https://en.wikipedia.org/wiki/Batcher_odd%E2%80%93even_mergesort
    """
    N = keys.shape[0]

    # Resolve the "has a separate value buffer" choice once instead of
    # duplicating the kernel call in both branches of an if/else.
    has_values = 0 if values is None else 1
    value_buffer = keys if values is None else values

    p = 1
    while p < N:
        k = p
        while k >= 1:
            # Exact integer arithmetic: the numerator is never negative here
            # (k <= p < N), so floor division equals the truncating
            # int(a / b) without a float round-trip.
            invocations = (N - k - k % p) // (2 * k) + 1
            sort_stage(keys, has_values, value_buffer, N, p, k, invocations)
            sync()
            k //= 2
        p *= 2
|
41
|
+
|
42
|
+
|
43
|
+
@data_oriented
class PrefixSumExecutor:
    """Parallel Prefix Sum (Scan) Helper

    Use this helper to perform an inclusive, in-place parallel prefix sum.

    The constructor sizes one scratch field (``large_arr``) that holds every
    level of the scan hierarchy back to back; ``run`` copies the input in,
    scans level by level, propagates the block sums back down and copies the
    result out again.

    Attributes (set in ``__init__``):
        sorting_length: The ``length`` passed to the constructor.
        ele_nums: Element count per level. Entry 0 is ``length``; every
            following entry is ``ceil(previous / 64)`` until it reaches 1.
        ele_nums_pos: One offset per level, passed to the scan kernels as the
            level boundaries inside ``large_arr``.
        large_arr: ``i32`` scratch field whose size is the final offset.

    NOTE(review): for ``length <= 1`` the sizing loop never runs, so
    ``large_arr`` is created with ``shape=0`` -- confirm whether that case is
    meant to be supported.

    References:
        https://developer.download.nvidia.com/compute/cuda/1.1-Beta/x86_website/projects/scan/doc/scan.pdf
        https://github.com/NVIDIA/cuda-samples/blob/master/Samples/2_Concepts_and_Techniques/shfl_scan/shfl_scan.cu
    """

    def __init__(self, length):
        """Pre-compute the level layout and allocate the scratch field.

        Args:
            length: Number of elements that ``run`` will scan.
        """
        self.sorting_length = length

        BLOCK_SZ = 64

        # Buffer position and length
        # This is a single buffer implementation for ease of aot usage
        ele_num = length
        self.ele_nums = [ele_num]
        start_pos = 0
        self.ele_nums_pos = [start_pos]

        while ele_num > 1:
            # Each level holds one partial sum per 64-element block of the
            # level below; every level is padded to a multiple of BLOCK_SZ.
            ele_num = int((ele_num + BLOCK_SZ - 1) / BLOCK_SZ)
            self.ele_nums.append(ele_num)
            start_pos += BLOCK_SZ * ele_num
            self.ele_nums_pos.append(start_pos)

        self.large_arr = field(i32, shape=start_pos)

    def run(self, input_arr):
        """Replace the contents of ``input_arr`` with its inclusive prefix sum.

        Args:
            input_arr: Array-like with ``dtype == i32``; the first
                ``sorting_length`` entries are read and then overwritten.

        Raises:
            RuntimeError: If ``input_arr.dtype`` is not ``i32``, or if the
                current arch is neither ``cuda`` nor ``vulkan``.
        """
        length = self.sorting_length
        ele_nums = self.ele_nums
        ele_nums_pos = self.ele_nums_pos

        if input_arr.dtype != i32:
            raise RuntimeError("Only ti.i32 type is supported for prefix sum.")

        # Pick the backend-specific warp/subgroup primitive once.
        arch = current_cfg().arch
        if arch == cuda:
            inclusive_add = warp_shfl_up_i32
        elif arch == vulkan:
            inclusive_add = subgroup.inclusive_add
        else:
            raise RuntimeError(f"{str(arch)} is not supported for prefix sum.")

        blit_from_field_to_field(self.large_arr, input_arr, 0, length)

        # Kogge-Stone construction
        num_levels = len(ele_nums) - 1
        for i in range(num_levels):
            scan_add_inclusive(
                self.large_arr,
                ele_nums_pos[i],
                ele_nums_pos[i + 1],
                i == num_levels - 1,  # True only for the last (top) level
                inclusive_add,
            )

        # Walk back down, skipping the top level, adding block sums in.
        for i in range(len(ele_nums) - 3, -1, -1):
            uniform_add(self.large_arr, ele_nums_pos[i], ele_nums_pos[i + 1])

        blit_from_field_to_field(input_arr, self.large_arr, 0, length)
|
115
|
+
|
116
|
+
|
117
|
+
__all__ = ["parallel_sort", "PrefixSumExecutor"]
|
gstaichi/assets/.git
ADDED
@@ -0,0 +1 @@
|
|
1
|
+
gitdir: ../../.git/modules/external/assets
|
Binary file
|
Binary file
|
@@ -0,0 +1,26 @@
|
|
1
|
+
import time
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
import numpy.typing as npt
|
5
|
+
|
6
|
+
|
7
|
+
def lcg_np(B: int, lcg_its: int, a: npt.NDArray) -> None:
    """Advance the first ``B`` entries of ``a`` by ``lcg_its`` LCG steps, in place.

    Each step computes ``x = (1664525 * x + 1013904223) % 2147483647``.

    Args:
        B: Number of leading elements of ``a`` to update.
        lcg_its: Number of generator steps applied to every element.
        a: Array that is read and overwritten element by element.
    """
    for idx in range(B):
        state = a[idx]
        for _ in range(lcg_its):
            state = (1664525 * state + 1013904223) % 2147483647
        a[idx] = state
|
13
|
+
|
14
|
+
|
15
|
+
def main() -> None:
    """Time 1000 LCG steps over 16000 ``int32`` elements using ``lcg_np``.

    Prints the elapsed wall-clock time in seconds.
    """
    B = 16000
    # np.ndarray((B,), ...) hands back uninitialised memory, so every run would
    # start from different garbage. Start from zeros to make runs reproducible.
    a = np.zeros((B,), np.int32)

    start = time.time()
    lcg_np(B, 1000, a)
    end = time.time()
    print("elapsed", end - start)
|
23
|
+
# elapsed 5.552601099014282 on macbook air m4
|
24
|
+
|
25
|
+
|
26
|
+
main()
|
@@ -0,0 +1,34 @@
|
|
1
|
+
import time
|
2
|
+
|
3
|
+
import gstaichi as ti
|
4
|
+
|
5
|
+
|
6
|
+
@ti.kernel
def lcg_ti(B: int, lcg_its: int, a: ti.types.NDArray[ti.i32, 1]) -> None:
    """
    Apply ``lcg_its`` linear-congruential steps to each of the first ``B`` entries of ``a``.

    Every step is ``x = (1664525 * x + 1013904223) % 2147483647``; the result is
    written back to the same slot.
    """
    for idx in range(B):
        state = a[idx]
        for step in range(lcg_its):
            state = (1664525 * state + 1013904223) % 2147483647
        a[idx] = state
|
13
|
+
|
14
|
+
|
15
|
+
def main() -> None:
    """Initialise GsTaichi on ``ti.gpu`` and time one ``lcg_ti`` launch.

    Runs 1000 LCG steps over a 16000-element ``int32`` ndarray and prints the
    elapsed wall-clock time in seconds.
    """
    ti.init(arch=ti.gpu)

    num_elements = 16000
    buffer = ti.ndarray(ti.int32, (num_elements,))

    # ti.sync() brackets the timed region on both sides.
    # NOTE(review): presumably so that pending device work is excluded from,
    # and the kernel's own work included in, the measurement -- confirm.
    ti.sync()
    t_begin = time.time()
    lcg_ti(num_elements, 1000, buffer)
    ti.sync()
    t_end = time.time()
    print("elapsed", t_end - t_begin)
|
27
|
+
|
28
|
+
# [GsTaichi] version 1.8.0, llvm 15.0.7, commit 5afed1c9, osx, python 3.10.16
|
29
|
+
# [GsTaichi] Starting on arch=metal
|
30
|
+
# elapsed 0.04660296440124512
|
31
|
+
# (on mac air m4)
|
32
|
+
|
33
|
+
|
34
|
+
main()
|
@@ -0,0 +1,28 @@
|
|
1
|
+
import gstaichi as ti
|
2
|
+
|
3
|
+
|
4
|
+
@ti.kernel
def lcg_ti(B: int, lcg_its: int, a: ti.types.NDArray[ti.i32, 1]) -> None:
    """
    Linear congruential generator https://en.wikipedia.org/wiki/Linear_congruential_generator

    Applies ``lcg_its`` steps of ``x = (1664525 * x + 1013904223) % 2147483647``
    to each of the first ``B`` entries of ``a`` and stores the result in place.
    """
    for idx in range(B):
        state = a[idx]
        for step in range(lcg_its):
            state = (1664525 * state + 1013904223) % 2147483647
        a[idx] = state
|
14
|
+
|
15
|
+
|
16
|
+
def main() -> None:
    """Run ``lcg_ti`` on ``ti.cpu`` for 10 elements and 10 steps, then print the array."""
    ti.init(arch=ti.cpu)

    num_elements = 10
    num_steps = 10
    values = ti.ndarray(ti.int32, (num_elements,))

    lcg_ti(num_elements, num_steps, values)

    as_numpy = values.to_numpy()  # pylint: disable=no-member
    print(f"LCG for B={num_elements}, lcg_its={num_steps}: ", as_numpy)
|
26
|
+
|
27
|
+
|
28
|
+
main()
|
gstaichi/experimental.py
ADDED
@@ -0,0 +1,16 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
import warnings
|
4
|
+
|
5
|
+
from gstaichi.lang.kernel_impl import real_func as _real_func
|
6
|
+
|
7
|
+
|
8
|
+
def real_func(func):
    """Deprecated alias kept for backwards compatibility.

    Emits a ``DeprecationWarning`` and then forwards ``func`` to
    ``gstaichi.lang.kernel_impl.real_func``.

    Args:
        func: The callable to hand to ``real_func``.

    Returns:
        Whatever ``gstaichi.lang.kernel_impl.real_func(func)`` returns.
    """
    warnings.warn(
        "ti.experimental.real_func is deprecated because it is no longer experimental. Use ti.real_func instead.",
        DeprecationWarning,
        # Attribute the warning to the code that applied the decorator rather
        # than to this wrapper; Python's default filters only show a
        # DeprecationWarning when it is attributed to __main__.
        stacklevel=2,
    )
    return _real_func(func)
|
14
|
+
|
15
|
+
|
16
|
+
__all__ = ["real_func"]
|
@@ -0,0 +1,50 @@
|
|
1
|
+
# type: ignore
|
2
|
+
|
3
|
+
from gstaichi.lang import impl, simt
|
4
|
+
from gstaichi.lang._fast_caching.function_hasher import pure
|
5
|
+
from gstaichi.lang._ndarray import *
|
6
|
+
from gstaichi.lang._ndrange import ndrange
|
7
|
+
from gstaichi.lang._texture import Texture
|
8
|
+
from gstaichi.lang.exception import *
|
9
|
+
from gstaichi.lang.field import *
|
10
|
+
from gstaichi.lang.impl import *
|
11
|
+
from gstaichi.lang.kernel_impl import *
|
12
|
+
from gstaichi.lang.matrix import *
|
13
|
+
from gstaichi.lang.mesh import *
|
14
|
+
from gstaichi.lang.misc import * # pylint: disable=W0622
|
15
|
+
from gstaichi.lang.ops import * # pylint: disable=W0622
|
16
|
+
from gstaichi.lang.runtime_ops import *
|
17
|
+
from gstaichi.lang.snode import *
|
18
|
+
from gstaichi.lang.source_builder import *
|
19
|
+
from gstaichi.lang.struct import *
|
20
|
+
from gstaichi.types.enums import DeviceCapability, Format, Layout
|
21
|
+
|
22
|
+
# Public API of the package: every name currently bound in this module that is
# not private (leading underscore) and is not one of the submodule / helper
# names listed below.
__all__ = [
    name
    for name in dir()
    if not (
        name.startswith("_")
        or name
        in {
            "any_array",
            "ast",
            "common_ops",
            "enums",
            "exception",
            "expr",
            "impl",
            "inspect",
            "kernel_arguments",
            "kernel_impl",
            "matrix",
            "mesh",
            "misc",
            "ops",
            "platform",
            "runtime_ops",
            "shell",
            "snode",
            "source_builder",
            "struct",
            "util",
        }
    )
]
|
@@ -0,0 +1,31 @@
|
|
1
|
+
def create_flat_name(basename: str, child_name: str) -> str:
    """
    Join ``basename`` and ``child_name`` with the delimiter ``__ti_``.

    If the joined string does not already begin with ``__ti_``, that prefix is
    added, so the result always starts with exactly one leading ``__ti_``
    contributed by this function (an already-prefixed ``basename`` is not
    prefixed a second time).

    This is used when expanding python dataclass members, e.g.

        @dataclasses.dataclass
        class Foo:
            a: int
            b: int

        foo = Foo(a=5, b=3)

    Expanding ``foo`` replaces it with the following names:
    - __ti_foo__ti_a
    - __ti_foo__ti_b

    The ``__ti_`` marker keeps generated names apart from user-defined ones:
    users are required not to create fields or variables that are themselves
    prefixed with ``__ti_``, and under that constraint the generated names
    cannot clash with user-chosen names.
    """
    delimiter = "__ti_"
    joined = basename + delimiter + child_name
    if joined.startswith(delimiter):
        return joined
    return delimiter + joined
|
@@ -0,0 +1,110 @@
|
|
1
|
+
import dataclasses
|
2
|
+
import enum
|
3
|
+
import numbers
|
4
|
+
import time
|
5
|
+
from typing import Any, Sequence
|
6
|
+
|
7
|
+
import numpy as np
|
8
|
+
|
9
|
+
from .._ndarray import ScalarNdarray
|
10
|
+
from ..field import ScalarField
|
11
|
+
from ..matrix import MatrixField, MatrixNdarray, VectorNdarray
|
12
|
+
from ..util import is_data_oriented
|
13
|
+
from .hash_utils import hash_iterable_strings
|
14
|
+
|
15
|
+
g_num_calls = 0
|
16
|
+
g_num_args = 0
|
17
|
+
g_hashing_time = 0
|
18
|
+
g_repr_time = 0
|
19
|
+
g_num_ignored_calls = 0
|
20
|
+
|
21
|
+
|
22
|
+
FIELD_METADATA_CACHE_VALUE = "add_value_to_cache_key"
|
23
|
+
|
24
|
+
|
25
|
+
def dataclass_to_repr(path: tuple[str, ...], arg: Any) -> str:
    """
    Build the type representation string of a dataclass instance.

    Each dataclass field contributes ``"<name>: (<type repr>)"``, where the type
    repr comes from ``stringify_obj_type``. When the field's metadata sets
    ``FIELD_METADATA_CACHE_VALUE`` to a truthy value, ``" = <value>"`` is
    appended so the field's value also becomes part of the representation.
    The entries are joined with ``","`` and wrapped in square brackets.

    `path` is extended with the field name for each recursive call.
    """
    # NOTE(review): if stringify_obj_type returns None for a member, the text
    # "(None)" is embedded here instead of None being propagated as the
    # data-oriented branch of stringify_obj_type does -- confirm this is intended.
    entries = []
    for member in dataclasses.fields(arg):
        member_value = getattr(arg, member.name)
        type_repr = stringify_obj_type(path + (member.name,), member_value)
        entry = f"{member.name}: ({type_repr})"
        if member.metadata.get(FIELD_METADATA_CACHE_VALUE, False):
            entry += f" = {member_value}"
        entries.append(entry)
    return "[" + ",".join(entries) + "]"
|
35
|
+
|
36
|
+
|
37
|
+
def stringify_obj_type(path: tuple[str, ...], obj: Any) -> str | None:
    """
    Turn an object into a string that identifies its type for cache-key purposes.

    The string does not have to be a hash, nor the actual python type name; it
    only has to be representative of the type and must not collide between
    different (allowed) types. Returns None when the object cannot be
    represented.

    The checks below are tried in order and the first match wins, so their
    order is significant.

    `path` is used during debugging.
    """
    # TODO: We should have a way of printing this without having to hack the code really. Using logger perhaps?
    # (I have another PR that addresses this https://github.com/Genesis-Embodied-AI/gstaichi/pull/144/files)
    if isinstance(obj, ScalarNdarray):
        return f"[nd-{obj.dtype}-{len(obj.shape)}]"
    if isinstance(obj, VectorNdarray):
        return f"[ndv-{obj.n}-{obj.dtype}-{len(obj.shape)}]"
    if isinstance(obj, ScalarField):
        return f"[f-{obj.snode._id}-{obj.dtype}-{obj.shape}]"
    if isinstance(obj, MatrixNdarray):
        return f"[ndm-{obj.m}-{obj.n}-{obj.dtype}-{len(obj.shape)}]"

    obj_type = type(obj)
    # Matched by name so that torch does not have to be imported here.
    if "torch.Tensor" in str(obj_type):
        return f"[pt-{obj.dtype}-{obj.ndim}]"
    if isinstance(obj, np.ndarray):
        return f"[np-{obj.dtype}-{obj.ndim}]"
    if isinstance(obj, MatrixField):
        return f"[fm-{obj.m}-{obj.n}-{obj.snode._id}-{obj.dtype}-{obj.shape}]"
    if dataclasses.is_dataclass(obj):
        return dataclass_to_repr(path, obj)

    if is_data_oriented(obj):
        # Represent every instance attribute; one unrepresentable attribute
        # makes the whole object unrepresentable.
        member_reprs = []
        for name, member in obj.__dict__.items():
            member_repr = stringify_obj_type((*path, name), member)
            if member_repr is None:
                print("not representable child", name, type(member), "path", path)
                return None
            member_reprs.append(f"{name}: {member_repr}")
        return ", ".join(member_reprs)

    if issubclass(obj_type, (numbers.Number, np.number)):
        return str(obj_type)
    if obj_type is np.bool_:
        # np is deprecating bool. Treat specially/carefully
        return "np.bool_"
    if isinstance(obj, enum.Enum):
        return f"enum-{obj.name}-{obj.value}"
    return None
|
83
|
+
|
84
|
+
|
85
|
+
def hash_args(args: Sequence[Any]) -> str | None:
    """
    Hash the type representations of ``args`` into a single string.

    Returns None as soon as one argument yields no (or an empty)
    representation. Updates the module-level counters and timers that
    ``dump_stats`` reports.
    """
    global g_num_calls, g_num_args, g_hashing_time, g_repr_time, g_num_ignored_calls
    g_num_calls += 1
    g_num_args += len(args)

    arg_reprs = []
    for position, arg in enumerate(args):
        t0 = time.time()
        arg_repr = stringify_obj_type((str(position),), arg)
        g_repr_time += time.time() - t0
        if not arg_repr:
            g_num_ignored_calls += 1
            return None
        arg_reprs.append(arg_repr)

    t0 = time.time()
    digest = hash_iterable_strings(arg_reprs)
    g_hashing_time += time.time() - t0
    return digest
|
102
|
+
|
103
|
+
|
104
|
+
def dump_stats() -> None:
    """Print the module-level call counters and timers gathered by ``hash_args``."""
    print("args hasher dump stats")
    report = (
        ("total calls", g_num_calls),
        ("ignored calls", g_num_ignored_calls),
        ("total args", g_num_args),
        ("hashing time", g_hashing_time),
        ("arg representation time", g_repr_time),
    )
    for label, value in report:
        print(label, value)
|
@@ -0,0 +1,30 @@
|
|
1
|
+
from gstaichi.lang import impl
|
2
|
+
|
3
|
+
from .hash_utils import hash_iterable_strings
|
4
|
+
|
5
|
+
EXCLUDE_PREFIXES = ["_", "offline_cache", "print_", "verbose_"]
|
6
|
+
|
7
|
+
|
8
|
+
def hash_compile_config() -> str:
    """
    Calculate a hash string for the current compiler config.

    Every name returned by ``dir(config)`` is included as ``"<name>=<value>"``,
    except names that start with one of ``EXCLUDE_PREFIXES``. The entries are
    hashed with a newline separator, so a change to any included value changes
    the resulting hash string.
    """
    config = impl.get_runtime().prog.config()
    entries = []
    for key in dir(config):
        excluded = any(key.startswith(prefix) or key in [""] for prefix in EXCLUDE_PREFIXES)
        if excluded:
            continue
        entries.append(f"{key}={getattr(config, key)}")
    return hash_iterable_strings(entries, separator="\n")
|
@@ -0,0 +1,21 @@
|
|
1
|
+
from pydantic import BaseModel
|
2
|
+
|
3
|
+
from .._wrap_inspect import FunctionSourceInfo
|
4
|
+
|
5
|
+
|
6
|
+
class HashedFunctionSourceInfo(BaseModel):
    """
    Pairs a function source info with the hash string of that function.

    Keeping the hash out of FunctionSourceInfo itself avoids making the hash
    an optional field that has to be checked for emptiness.

    If you have a HashedFunctionSourceInfo object, then you are guaranteed
    to have the hash string.

    If you only have the FunctionSourceInfo object, you are guaranteed that it
    does not have a hash string.
    """

    # Location of the function's source.
    function_source_info: FunctionSourceInfo
    # Hash string computed for that function.
    hash: str
|
@@ -0,0 +1,57 @@
|
|
1
|
+
import os
|
2
|
+
from itertools import islice
|
3
|
+
from typing import TYPE_CHECKING, Iterable
|
4
|
+
|
5
|
+
from .._wrap_inspect import FunctionSourceInfo
|
6
|
+
from .fast_caching_types import HashedFunctionSourceInfo
|
7
|
+
from .hash_utils import hash_iterable_strings
|
8
|
+
|
9
|
+
if TYPE_CHECKING:
|
10
|
+
from gstaichi.lang.kernel_impl import GsTaichiCallable
|
11
|
+
|
12
|
+
|
13
|
+
def pure(fn: "GsTaichiCallable") -> "GsTaichiCallable":
    """Mark ``fn`` by setting ``fn.is_pure = True`` and return the same object.

    Usable as a decorator.
    """
    setattr(fn, "is_pure", True)
    return fn
|
16
|
+
|
17
|
+
|
18
|
+
def _read_file(function_info: FunctionSourceInfo) -> list[str]:
    """Return the source lines of the function described by ``function_info``.

    Lines are taken from ``function_info.filepath`` using
    ``islice(f, start_lineno, end_lineno + 1)``, i.e. both bounds are used as
    0-based positions into the file's line sequence.

    The file is opened with ``tokenize.open`` so it is decoded the way Python
    itself decodes source files (UTF-8 by default, or the PEP 263 coding
    declaration), rather than with the platform's locale encoding.
    """
    import tokenize  # local import: only needed here

    with tokenize.open(function_info.filepath) as f:
        return list(islice(f, function_info.start_lineno, function_info.end_lineno + 1))
|
21
|
+
|
22
|
+
|
23
|
+
def _hash_function(function_info: FunctionSourceInfo) -> str:
    """Hash the source lines that ``_read_file`` returns for ``function_info``."""
    source_lines = _read_file(function_info)
    return hash_iterable_strings(source_lines)
|
25
|
+
|
26
|
+
|
27
|
+
def hash_functions(function_infos: Iterable[FunctionSourceInfo]) -> list[HashedFunctionSourceInfo]:
    """Hash every function in ``function_infos``, preserving order.

    Returns one ``HashedFunctionSourceInfo`` per input, pairing the source info
    with its hash string.
    """
    return [
        HashedFunctionSourceInfo(function_source_info=info, hash=_hash_function(info))
        for info in function_infos
    ]
|
33
|
+
|
34
|
+
|
35
|
+
def hash_kernel(kernel_info: FunctionSourceInfo) -> str:
    """Return the source hash of the kernel described by ``kernel_info``."""
    kernel_hash = _hash_function(kernel_info)
    return kernel_hash
|
37
|
+
|
38
|
+
|
39
|
+
def dump_stats() -> None:
    """Print the function hasher's stats header (no counters are tracked here)."""
    header = "function hasher dump stats"
    print(header)
|
41
|
+
|
42
|
+
|
43
|
+
def _validate_hashed_function_info(hashed_function_info: HashedFunctionSourceInfo) -> bool:
    """
    Check a stored hash against the source file as it is now.

    Returns False when the source file no longer exists; otherwise re-hashes
    the function and returns whether the result equals the stored hash.
    """
    source_info = hashed_function_info.function_source_info
    if os.path.isfile(source_info.filepath):
        return _hash_function(source_info) == hashed_function_info.hash
    return False
|
51
|
+
|
52
|
+
|
53
|
+
def validate_hashed_function_infos(function_infos: Iterable[HashedFunctionSourceInfo]) -> bool:
    """Return True when every entry still validates; stop at the first that does not.

    An empty iterable validates as True.
    """
    return all(_validate_hashed_function_info(info) for info in function_infos)
|
@@ -0,0 +1,11 @@
|
|
1
|
+
import hashlib
|
2
|
+
from typing import Iterable
|
3
|
+
|
4
|
+
|
5
|
+
def hash_iterable_strings(strings: Iterable[str], separator: str = "_") -> str:
    """Return the SHA-256 hex digest of ``strings``.

    Each string is UTF-8 encoded and followed by the UTF-8 encoded
    ``separator`` (so the separator also trails the last element). Elements
    are not escaped: a string that itself contains the separator is fed to
    the hash unchanged.

    Args:
        strings: The strings to hash, consumed once in iteration order.
        separator: Text appended after every element.
    """
    digest = hashlib.sha256()
    terminator = separator.encode("utf-8")
    for text in strings:
        digest.update(text.encode("utf-8") + terminator)
    return digest.hexdigest()
|
@@ -0,0 +1,52 @@
|
|
1
|
+
import os
|
2
|
+
|
3
|
+
from .. import impl
|
4
|
+
|
5
|
+
|
6
|
+
class PythonSideCache:
|
7
|
+
"""
|
8
|
+
Manages a cache that is managed from the python side (we also have c++-side caches)
|
9
|
+
|
10
|
+
The cache is disk-based. When we create the PythonSideCache object, the cache
|
11
|
+
path is created as a sub-folder of CompileConfig.offline_cache_file_path.
|
12
|
+
|
13
|
+
Note that constructing this object is cheap, so there is no need to maintain some
|
14
|
+
kind of conceptual singleton instance or similar.
|
15
|
+
|
16
|
+
Each cache key value is stored to a single file, with the cache key as the filename.
|
17
|
+
|
18
|
+
No metadata is associated with the file, making management very lightweight.
|
19
|
+
|
20
|
+
We update the file date/time when we read from a particular file, so we can easily
|
21
|
+
implement an LRU cleaning strategy at some point in the future, based on the file
|
22
|
+
date/times.
|
23
|
+
"""
|
24
|
+
|
25
|
+
def __init__(self) -> None:
|
26
|
+
_cache_parent_folder = impl.get_runtime().prog.config().offline_cache_file_path
|
27
|
+
self.cache_folder = os.path.join(_cache_parent_folder, "python_side_cache")
|
28
|
+
os.makedirs(self.cache_folder, exist_ok=True)
|
29
|
+
|
30
|
+
def _get_filepath(self, key: str) -> str:
|
31
|
+
filepath = os.path.join(self.cache_folder, f"{key}.cache.txt")
|
32
|
+
return filepath
|
33
|
+
|
34
|
+
def _touch(self, filepath):
|
35
|
+
"""
|
36
|
+
Updates file date/time.
|
37
|
+
"""
|
38
|
+
with open(filepath, "a"):
|
39
|
+
os.utime(filepath, None)
|
40
|
+
|
41
|
+
def store(self, key: str, value: str) -> None:
|
42
|
+
filepath = self._get_filepath(key)
|
43
|
+
with open(filepath, "w") as f:
|
44
|
+
f.write(value)
|
45
|
+
|
46
|
+
def try_load(self, key: str) -> str | None:
|
47
|
+
filepath = self._get_filepath(key)
|
48
|
+
if not os.path.isfile(filepath):
|
49
|
+
return None
|
50
|
+
self._touch(filepath)
|
51
|
+
with open(filepath) as f:
|
52
|
+
return f.read()
|
@@ -0,0 +1,75 @@
|
|
1
|
+
from typing import Any, Iterable, Sequence
|
2
|
+
|
3
|
+
from pydantic import BaseModel
|
4
|
+
|
5
|
+
from .._wrap_inspect import FunctionSourceInfo
|
6
|
+
from . import args_hasher, config_hasher, function_hasher
|
7
|
+
from .fast_caching_types import HashedFunctionSourceInfo
|
8
|
+
from .hash_utils import hash_iterable_strings
|
9
|
+
from .python_side_cache import PythonSideCache
|
10
|
+
|
11
|
+
|
12
|
+
def create_cache_key(kernel_source_info: FunctionSourceInfo, args: Sequence[Any]) -> str | None:
    """
    Build the cache key for one kernel call, or None if the args cannot be hashed.

    The key combines, in this order:
    - the hash of the kernel function itself (but not of its sub functions)
    - the hash of the arg types (and of cache-value arg values)
    - the hash of the compilation config (which includes arch, and debug)
    """
    args_hash = args_hasher.hash_args(args)
    if args_hash is None:
        return None
    return hash_iterable_strings(
        (
            function_hasher.hash_kernel(kernel_source_info),
            args_hash,
            config_hasher.hash_compile_config(),
        )
    )
|
27
|
+
|
28
|
+
|
29
|
+
class CacheValue(BaseModel):
    # Payload stored in the python side cache for one cache key: the hashed
    # source infos that are re-checked when the key is validated.
    hashed_function_source_infos: list[HashedFunctionSourceInfo]
|
31
|
+
|
32
|
+
|
33
|
+
def store(cache_key: str, function_source_infos: Iterable[FunctionSourceInfo]) -> None:
    """
    Record the hashed function source infos for ``cache_key``. Does nothing
    when ``cache_key`` is empty.

    Unlike other caches, this one does not hold the value we ultimately want;
    it only holds what is needed to verify that a cache key is still valid.
    Big picture:
    - we have a cache key, based on args and top level kernel function
    - we want to use this to look up LLVM IR, in C++ side cache
    - however, before doing that, we first want to validate that the source code didn't change
        - i.e. is our cache key still valid?
    - the python side cache contains information we will use to verify that our cache key is valid
        - ie the list of function source infos
    """
    if not cache_key:
        return
    disk_cache = PythonSideCache()
    hashed_infos = function_hasher.hash_functions(function_source_infos)
    payload = CacheValue(hashed_function_source_infos=list(hashed_infos)).json()
    disk_cache.store(cache_key, payload)
|
50
|
+
|
51
|
+
|
52
|
+
def _try_load(cache_key: str) -> Sequence[HashedFunctionSourceInfo] | None:
    """
    Load the hashed function source infos stored for ``cache_key``.

    Returns None when there is no entry for the key, or when the stored text
    cannot be parsed into a ``CacheValue`` (for example a truncated or
    otherwise corrupt cache file). Callers treat None as a cache miss.
    """
    raw_value = PythonSideCache().try_load(cache_key)
    if raw_value is None:
        return None
    try:
        cache_value = CacheValue.parse_raw(raw_value)
    except ValueError:
        # pydantic's ValidationError derives from ValueError. An unreadable
        # entry must not abort the call; report it as "not cached" instead.
        return None
    return cache_value.hashed_function_source_infos
|
59
|
+
|
60
|
+
|
61
|
+
def validate_cache_key(cache_key: str) -> bool:
    """
    Load the function source infos stored for ``cache_key``, if any, and check
    their hashes against the current source code.

    Returns False when nothing (or an empty list) is stored for the key.
    """
    stored_infos = _try_load(cache_key)
    if stored_infos:
        return function_hasher.validate_hashed_function_infos(stored_infos)
    return False
|
70
|
+
|
71
|
+
|
72
|
+
def dump_stats() -> None:
    """Print a header, then the stats of the args hasher and the function hasher."""
    print("dump stats")
    for hasher in (args_hasher, function_hasher):
        hasher.dump_stats()
|