gstaichi 2.1.1__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (178) hide show
  1. gstaichi/__init__.py +40 -0
  2. gstaichi/_funcs.py +706 -0
  3. gstaichi/_kernels.py +420 -0
  4. gstaichi/_lib/__init__.py +3 -0
  5. gstaichi/_lib/core/__init__.py +0 -0
  6. gstaichi/_lib/core/gstaichi_python.cpython-313-darwin.so +0 -0
  7. gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
  8. gstaichi/_lib/core/py.typed +0 -0
  9. gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
  10. gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
  11. gstaichi/_lib/utils.py +243 -0
  12. gstaichi/_logging.py +131 -0
  13. gstaichi/_snode/__init__.py +5 -0
  14. gstaichi/_snode/fields_builder.py +187 -0
  15. gstaichi/_snode/snode_tree.py +34 -0
  16. gstaichi/_test_tools/__init__.py +18 -0
  17. gstaichi/_test_tools/dataclass_test_tools.py +36 -0
  18. gstaichi/_test_tools/load_kernel_string.py +30 -0
  19. gstaichi/_test_tools/textwrap2.py +6 -0
  20. gstaichi/_version.py +1 -0
  21. gstaichi/_version_check.py +100 -0
  22. gstaichi/ad/__init__.py +3 -0
  23. gstaichi/ad/_ad.py +530 -0
  24. gstaichi/algorithms/__init__.py +3 -0
  25. gstaichi/algorithms/_algorithms.py +117 -0
  26. gstaichi/assets/.git +1 -0
  27. gstaichi/assets/Go-Regular.ttf +0 -0
  28. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  29. gstaichi/examples/lcg_python.py +26 -0
  30. gstaichi/examples/lcg_taichi.py +34 -0
  31. gstaichi/examples/minimal.py +28 -0
  32. gstaichi/experimental.py +16 -0
  33. gstaichi/lang/__init__.py +50 -0
  34. gstaichi/lang/_dataclass_util.py +31 -0
  35. gstaichi/lang/_fast_caching/__init__.py +3 -0
  36. gstaichi/lang/_fast_caching/args_hasher.py +110 -0
  37. gstaichi/lang/_fast_caching/config_hasher.py +30 -0
  38. gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
  39. gstaichi/lang/_fast_caching/function_hasher.py +57 -0
  40. gstaichi/lang/_fast_caching/hash_utils.py +11 -0
  41. gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
  42. gstaichi/lang/_fast_caching/src_hasher.py +75 -0
  43. gstaichi/lang/_kernel_impl_dataclass.py +212 -0
  44. gstaichi/lang/_ndarray.py +352 -0
  45. gstaichi/lang/_ndrange.py +152 -0
  46. gstaichi/lang/_template_mapper.py +195 -0
  47. gstaichi/lang/_texture.py +172 -0
  48. gstaichi/lang/_wrap_inspect.py +215 -0
  49. gstaichi/lang/any_array.py +99 -0
  50. gstaichi/lang/ast/__init__.py +5 -0
  51. gstaichi/lang/ast/ast_transformer.py +1323 -0
  52. gstaichi/lang/ast/ast_transformer_utils.py +346 -0
  53. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  54. gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
  55. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
  56. gstaichi/lang/ast/checkers.py +106 -0
  57. gstaichi/lang/ast/symbol_resolver.py +57 -0
  58. gstaichi/lang/ast/transform.py +9 -0
  59. gstaichi/lang/common_ops.py +310 -0
  60. gstaichi/lang/exception.py +80 -0
  61. gstaichi/lang/expr.py +180 -0
  62. gstaichi/lang/field.py +428 -0
  63. gstaichi/lang/impl.py +1245 -0
  64. gstaichi/lang/kernel_arguments.py +155 -0
  65. gstaichi/lang/kernel_impl.py +1341 -0
  66. gstaichi/lang/matrix.py +1835 -0
  67. gstaichi/lang/matrix_ops.py +341 -0
  68. gstaichi/lang/matrix_ops_utils.py +190 -0
  69. gstaichi/lang/mesh.py +687 -0
  70. gstaichi/lang/misc.py +780 -0
  71. gstaichi/lang/ops.py +1494 -0
  72. gstaichi/lang/runtime_ops.py +13 -0
  73. gstaichi/lang/shell.py +35 -0
  74. gstaichi/lang/simt/__init__.py +5 -0
  75. gstaichi/lang/simt/block.py +94 -0
  76. gstaichi/lang/simt/grid.py +7 -0
  77. gstaichi/lang/simt/subgroup.py +191 -0
  78. gstaichi/lang/simt/warp.py +96 -0
  79. gstaichi/lang/snode.py +489 -0
  80. gstaichi/lang/source_builder.py +150 -0
  81. gstaichi/lang/struct.py +810 -0
  82. gstaichi/lang/util.py +312 -0
  83. gstaichi/linalg/__init__.py +8 -0
  84. gstaichi/linalg/matrixfree_cg.py +310 -0
  85. gstaichi/linalg/sparse_cg.py +59 -0
  86. gstaichi/linalg/sparse_matrix.py +303 -0
  87. gstaichi/linalg/sparse_solver.py +123 -0
  88. gstaichi/math/__init__.py +11 -0
  89. gstaichi/math/_complex.py +205 -0
  90. gstaichi/math/mathimpl.py +886 -0
  91. gstaichi/profiler/__init__.py +6 -0
  92. gstaichi/profiler/kernel_metrics.py +260 -0
  93. gstaichi/profiler/kernel_profiler.py +586 -0
  94. gstaichi/profiler/memory_profiler.py +15 -0
  95. gstaichi/profiler/scoped_profiler.py +36 -0
  96. gstaichi/sparse/__init__.py +3 -0
  97. gstaichi/sparse/_sparse_grid.py +77 -0
  98. gstaichi/tools/__init__.py +12 -0
  99. gstaichi/tools/diagnose.py +117 -0
  100. gstaichi/tools/np2ply.py +364 -0
  101. gstaichi/tools/vtk.py +38 -0
  102. gstaichi/types/__init__.py +19 -0
  103. gstaichi/types/annotations.py +52 -0
  104. gstaichi/types/compound_types.py +71 -0
  105. gstaichi/types/enums.py +49 -0
  106. gstaichi/types/ndarray_type.py +169 -0
  107. gstaichi/types/primitive_types.py +206 -0
  108. gstaichi/types/quant.py +88 -0
  109. gstaichi/types/texture_type.py +85 -0
  110. gstaichi/types/utils.py +11 -0
  111. gstaichi-2.1.1.data/data/include/GLFW/glfw3.h +6389 -0
  112. gstaichi-2.1.1.data/data/include/GLFW/glfw3native.h +594 -0
  113. gstaichi-2.1.1.data/data/include/spirv-tools/instrument.hpp +268 -0
  114. gstaichi-2.1.1.data/data/include/spirv-tools/libspirv.h +907 -0
  115. gstaichi-2.1.1.data/data/include/spirv-tools/libspirv.hpp +375 -0
  116. gstaichi-2.1.1.data/data/include/spirv-tools/linker.hpp +97 -0
  117. gstaichi-2.1.1.data/data/include/spirv-tools/optimizer.hpp +970 -0
  118. gstaichi-2.1.1.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
  119. gstaichi-2.1.1.data/data/include/spirv_cross/spirv.h +2568 -0
  120. gstaichi-2.1.1.data/data/include/spirv_cross/spirv.hpp +2579 -0
  121. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
  122. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
  123. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
  124. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
  125. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
  126. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
  127. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
  128. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
  129. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
  130. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
  131. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
  132. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
  133. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
  134. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
  135. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
  136. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
  137. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
  138. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
  139. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
  140. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
  141. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
  142. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
  143. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
  144. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
  145. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
  146. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
  147. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
  148. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
  149. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
  150. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
  151. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  152. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
  153. gstaichi-2.1.1.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  154. gstaichi-2.1.1.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  155. gstaichi-2.1.1.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  156. gstaichi-2.1.1.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  157. gstaichi-2.1.1.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  158. gstaichi-2.1.1.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
  159. gstaichi-2.1.1.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
  160. gstaichi-2.1.1.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
  161. gstaichi-2.1.1.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
  162. gstaichi-2.1.1.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
  163. gstaichi-2.1.1.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
  164. gstaichi-2.1.1.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
  165. gstaichi-2.1.1.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
  166. gstaichi-2.1.1.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
  167. gstaichi-2.1.1.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
  168. gstaichi-2.1.1.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
  169. gstaichi-2.1.1.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
  170. gstaichi-2.1.1.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
  171. gstaichi-2.1.1.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
  172. gstaichi-2.1.1.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
  173. gstaichi-2.1.1.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
  174. gstaichi-2.1.1.dist-info/METADATA +106 -0
  175. gstaichi-2.1.1.dist-info/RECORD +178 -0
  176. gstaichi-2.1.1.dist-info/WHEEL +5 -0
  177. gstaichi-2.1.1.dist-info/licenses/LICENSE +201 -0
  178. gstaichi-2.1.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,117 @@
1
+ # type: ignore
2
+
3
+ from gstaichi._kernels import (
4
+ blit_from_field_to_field,
5
+ scan_add_inclusive,
6
+ sort_stage,
7
+ uniform_add,
8
+ warp_shfl_up_i32,
9
+ )
10
+ from gstaichi.lang.impl import current_cfg, field
11
+ from gstaichi.lang.kernel_impl import data_oriented
12
+ from gstaichi.lang.misc import cuda, vulkan
13
+ from gstaichi.lang.runtime_ops import sync
14
+ from gstaichi.lang.simt import subgroup
15
+ from gstaichi.types.primitive_types import i32
16
+
17
+
18
def parallel_sort(keys, values=None):
    """Odd-even merge sort

    Drives the ``sort_stage`` kernel once per (p, k) stage of the network and
    calls ``sync()`` after every stage.

    Args:
        keys: Container whose first-dimension length is read from
            ``keys.shape[0]``; passed to every ``sort_stage`` launch.
        values: Optional companion container. When given it is passed to
            ``sort_stage`` with the flag argument set to 1; when omitted,
            ``keys`` is passed in its place with the flag set to 0.

    References:
        https://developer.nvidia.com/gpugems/gpugems2/part-vi-simulation-and-numerical-algorithms/chapter-46-improved-gpu-sorting
        https://en.wikipedia.org/wiki/Batcher_odd%E2%80%93even_mergesort
    """
    N = keys.shape[0]

    # Resolve the key-only / key-value variant once instead of per stage.
    has_values = 0 if values is None else 1
    payload = keys if values is None else values

    p = 1
    while p < N:
        k = p
        while k >= 1:
            # Inside the loops 1 <= k <= p < N, so the numerator is positive
            # and floor division matches the truncating int(a / b) exactly
            # without going through a float.
            invocations = (N - k - k % p) // (2 * k) + 1
            sort_stage(keys, has_values, payload, N, p, k, invocations)
            sync()
            k //= 2
        p *= 2
41
+
42
+
43
@data_oriented
class PrefixSumExecutor:
    """Parallel Prefix Sum (Scan) Helper

    Use this helper to perform an inclusive, in-place parallel prefix sum.
    ``run`` only accepts ``i32`` inputs and only supports the cuda and vulkan
    archs; anything else raises ``RuntimeError``.

    References:
        https://developer.download.nvidia.com/compute/cuda/1.1-Beta/x86_website/projects/scan/doc/scan.pdf
        https://github.com/NVIDIA/cuda-samples/blob/master/Samples/2_Concepts_and_Techniques/shfl_scan/shfl_scan.cu
    """

    def __init__(self, length):
        """Precompute the per-level layout and allocate the scratch field.

        Args:
            length: Number of elements that ``run`` will scan.
        """
        self.sorting_length = length

        BLOCK_SZ = 64

        # Buffer position and length
        # This is a single buffer implementation for ease of aot usage
        ele_num = length
        self.ele_nums = [ele_num]
        start_pos = 0
        self.ele_nums_pos = [start_pos]

        # Each level holds ceil(previous / BLOCK_SZ) entries and reserves
        # BLOCK_SZ * ele_num slots in the shared buffer.
        while ele_num > 1:
            ele_num = int((ele_num + BLOCK_SZ - 1) / BLOCK_SZ)
            self.ele_nums.append(ele_num)
            start_pos += BLOCK_SZ * ele_num
            self.ele_nums_pos.append(start_pos)

        self.large_arr = field(i32, shape=start_pos)

    def run(self, input_arr):
        """Scan ``input_arr`` and write the result back into it.

        Args:
            input_arr: Container with ``dtype == i32``; its first
                ``self.sorting_length`` entries are copied into the scratch
                field, scanned, and copied back.

        Raises:
            RuntimeError: If the dtype is not ``i32`` or the current arch is
                neither cuda nor vulkan.
        """
        length = self.sorting_length
        ele_nums = self.ele_nums
        ele_nums_pos = self.ele_nums_pos

        if input_arr.dtype != i32:
            raise RuntimeError("Only ti.i32 type is supported for prefix sum.")

        arch = current_cfg().arch
        if arch == cuda:
            inclusive_add = warp_shfl_up_i32
        elif arch == vulkan:
            inclusive_add = subgroup.inclusive_add
        else:
            raise RuntimeError(f"{str(arch)} is not supported for prefix sum.")

        blit_from_field_to_field(self.large_arr, input_arr, 0, length)

        # Kogge-Stone construction
        num_levels = len(ele_nums) - 1
        for i in range(num_levels):
            scan_add_inclusive(
                self.large_arr,
                ele_nums_pos[i],
                ele_nums_pos[i + 1],
                i == num_levels - 1,  # True only on the final level
                inclusive_add,
            )

        # Walk back from the second-to-last level down to level 0.
        for i in range(len(ele_nums) - 3, -1, -1):
            uniform_add(self.large_arr, ele_nums_pos[i], ele_nums_pos[i + 1])

        blit_from_field_to_field(input_arr, self.large_arr, 0, length)
115
+
116
+
117
+ __all__ = ["parallel_sort", "PrefixSumExecutor"]
gstaichi/assets/.git ADDED
@@ -0,0 +1 @@
1
+ gitdir: ../../.git/modules/external/assets
Binary file
@@ -0,0 +1,26 @@
1
+ import time
2
+
3
+ import numpy as np
4
+ import numpy.typing as npt
5
+
6
+
7
def lcg_np(B: int, lcg_its: int, a: npt.NDArray) -> None:
    """Advance the first ``B`` entries of ``a`` by ``lcg_its`` LCG steps, in place.

    Each step maps ``x`` to ``(1664525 * x + 1013904223) % 2147483647``.
    """
    multiplier = 1664525
    increment = 1013904223
    modulus = 2147483647
    for idx in range(B):
        state = a[idx]
        for _ in range(lcg_its):
            state = (multiplier * state + increment) % modulus
        a[idx] = state
13
+
14
+
15
def main() -> None:
    """Time ``lcg_np`` over 16000 entries with 1000 steps each and print the elapsed seconds."""
    B = 16000
    # Start from a defined state: np.ndarray((B,), ...) would hand back
    # uninitialized memory, making the benchmark input arbitrary.
    a = np.zeros((B,), np.int32)

    start = time.time()
    lcg_np(B, 1000, a)
    end = time.time()
    print("elapsed", end - start)
    # elapsed 5.552601099014282 on macbook air m4
24
+
25
+
26
+ main()
@@ -0,0 +1,34 @@
1
+ import time
2
+
3
+ import gstaichi as ti
4
+
5
+
6
@ti.kernel
def lcg_ti(B: int, lcg_its: int, a: ti.types.NDArray[ti.i32, 1]) -> None:
    # For each of the first B entries of `a`, apply the recurrence below
    # lcg_its times and write the final value back to the same slot.
    for i in range(B):
        x = a[i]
        for j in range(lcg_its):
            # One linear congruential generator step.
            x = (1664525 * x + 1013904223) % 2147483647
        a[i] = x
13
+
14
+
15
def main() -> None:
    """Run ``lcg_ti`` on 16000 entries with 1000 steps each and print the elapsed seconds."""
    ti.init(arch=ti.gpu)

    num_entries = 16000
    states = ti.ndarray(ti.int32, (num_entries,))

    # Sync on both sides of the launch so the timed interval is bracketed by
    # explicit ti.sync() calls.
    ti.sync()
    t_begin = time.time()
    lcg_ti(num_entries, 1000, states)
    ti.sync()
    elapsed = time.time() - t_begin
    print("elapsed", elapsed)

    # [GsTaichi] version 1.8.0, llvm 15.0.7, commit 5afed1c9, osx, python 3.10.16
    # [GsTaichi] Starting on arch=metal
    # elapsed 0.04660296440124512
    # (on mac air m4)
32
+
33
+
34
+ main()
@@ -0,0 +1,28 @@
1
+ import gstaichi as ti
2
+
3
+
4
@ti.kernel
def lcg_ti(B: int, lcg_its: int, a: ti.types.NDArray[ti.i32, 1]) -> None:
    """
    Linear congruential generator https://en.wikipedia.org/wiki/Linear_congruential_generator

    Applies the recurrence `lcg_its` times to each of the first `B` entries of `a`, in place.
    """
    for i in range(B):
        x = a[i]
        for j in range(lcg_its):
            # One LCG step.
            x = (1664525 * x + 1013904223) % 2147483647
        a[i] = x
14
+
15
+
16
def main() -> None:
    """Initialise the CPU arch, run ``lcg_ti`` on a 10-entry ndarray and print the result."""
    ti.init(arch=ti.cpu)

    num_entries, num_steps = 10, 10
    states = ti.ndarray(ti.int32, (num_entries,))

    lcg_ti(num_entries, num_steps, states)
    result = states.to_numpy()  # pylint: disable=no-member
    print(f"LCG for B={num_entries}, lcg_its={num_steps}: ", result)
26
+
27
+
28
+ main()
@@ -0,0 +1,16 @@
1
+ # type: ignore
2
+
3
+ import warnings
4
+
5
+ from gstaichi.lang.kernel_impl import real_func as _real_func
6
+
7
+
8
def real_func(func):
    """Deprecated wrapper that forwards ``func`` to the non-experimental ``real_func``.

    Emits a ``DeprecationWarning`` on every call and returns whatever the
    underlying ``real_func`` returns.
    """
    warnings.warn(
        "ti.experimental.real_func is deprecated because it is no longer experimental. Use ti.real_func instead.",
        DeprecationWarning,
        # Attribute the warning to the caller's line rather than to this wrapper.
        stacklevel=2,
    )
    return _real_func(func)
14
+
15
+
16
+ __all__ = ["real_func"]
@@ -0,0 +1,50 @@
1
+ # type: ignore
2
+
3
+ from gstaichi.lang import impl, simt
4
+ from gstaichi.lang._fast_caching.function_hasher import pure
5
+ from gstaichi.lang._ndarray import *
6
+ from gstaichi.lang._ndrange import ndrange
7
+ from gstaichi.lang._texture import Texture
8
+ from gstaichi.lang.exception import *
9
+ from gstaichi.lang.field import *
10
+ from gstaichi.lang.impl import *
11
+ from gstaichi.lang.kernel_impl import *
12
+ from gstaichi.lang.matrix import *
13
+ from gstaichi.lang.mesh import *
14
+ from gstaichi.lang.misc import * # pylint: disable=W0622
15
+ from gstaichi.lang.ops import * # pylint: disable=W0622
16
+ from gstaichi.lang.runtime_ops import *
17
+ from gstaichi.lang.snode import *
18
+ from gstaichi.lang.source_builder import *
19
+ from gstaichi.lang.struct import *
20
+ from gstaichi.types.enums import DeviceCapability, Format, Layout
21
+
22
# Public API: every non-underscore name currently in this namespace, minus the
# names listed below.
__all__ = [
    name
    for name in dir()
    if not (
        name.startswith("_")
        or name
        in {
            "any_array",
            "ast",
            "common_ops",
            "enums",
            "exception",
            "expr",
            "impl",
            "inspect",
            "kernel_arguments",
            "kernel_impl",
            "matrix",
            "mesh",
            "misc",
            "ops",
            "platform",
            "runtime_ops",
            "shell",
            "snode",
            "source_builder",
            "struct",
            "util",
        }
    )
]
@@ -0,0 +1,31 @@
1
def create_flat_name(basename: str, child_name: str) -> str:
    """
    Joins basename and child_name with the delimiter __ti_, and makes sure the
    result starts with __ti_.

    The prefix is only added when the joined string does not already start with
    __ti_, so the delimiter is never duplicated at the front.

    This is used when expanding py dataclass members, e.g.

    @dataclasses.dataclass
    class Foo:
        a: int
        b: int

    foo = Foo(a=5, b=3)

    Expanding foo replaces it with the following names:
    - __ti_foo__ti_a
    - __ti_foo__ti_b

    The __ti_ marker keeps generated names apart from user-defined ones: users
    are required not to create fields or variables prefixed with __ti_, and
    under that constraint generated names cannot clash with theirs.
    """
    delimiter = "__ti_"
    joined = f"{basename}{delimiter}{child_name}"
    return joined if joined.startswith(delimiter) else f"{delimiter}{joined}"
@@ -0,0 +1,3 @@
1
+ from .args_hasher import FIELD_METADATA_CACHE_VALUE
2
+
3
+ __all__ = ["FIELD_METADATA_CACHE_VALUE"]
@@ -0,0 +1,110 @@
1
+ import dataclasses
2
+ import enum
3
+ import numbers
4
+ import time
5
+ from typing import Any, Sequence
6
+
7
+ import numpy as np
8
+
9
+ from .._ndarray import ScalarNdarray
10
+ from ..field import ScalarField
11
+ from ..matrix import MatrixField, MatrixNdarray, VectorNdarray
12
+ from ..util import is_data_oriented
13
+ from .hash_utils import hash_iterable_strings
14
+
15
+ g_num_calls = 0
16
+ g_num_args = 0
17
+ g_hashing_time = 0
18
+ g_repr_time = 0
19
+ g_num_ignored_calls = 0
20
+
21
+
22
+ FIELD_METADATA_CACHE_VALUE = "add_value_to_cache_key"
23
+
24
+
25
def dataclass_to_repr(path: tuple[str, ...], arg: Any) -> str:
    """
    Build the type representation of a dataclass from its members.

    Every member contributes "<name>: (<type repr>)". Members whose field metadata
    sets FIELD_METADATA_CACHE_VALUE to a truthy value also contribute " = <value>".
    Entries are comma-joined and wrapped in square brackets.

    NOTE(review): a member for which stringify_obj_type returns None is rendered as
    the literal "(None)" rather than making the whole result None - confirm intended.
    """
    entries = []
    for member in dataclasses.fields(arg):
        member_value = getattr(arg, member.name)
        type_repr = stringify_obj_type(path + (member.name,), member_value)
        entry = f"{member.name}: ({type_repr})"
        if member.metadata.get(FIELD_METADATA_CACHE_VALUE, False):
            entry = f"{entry} = {member_value}"
        entries.append(entry)
    joined = ",".join(entries)
    return f"[{joined}]"
35
+
36
+
37
def stringify_obj_type(path: tuple[str, ...], obj: Any) -> str | None:
    """
    Convert an object into a string representation that only depends on its type.

    String should somehow represent the type of obj. Doesn't have to be hashed, nor does it have
    to be the actual python type string, just a string that is representative of the type, and won't collide
    with different (allowed) types.

    `path` is used during debugging.

    Returns None when obj matches none of the branches below, or when any attribute of a
    data-oriented object cannot itself be represented.
    """
    # TODO: We should have a way of printing this without having to hack the code really. Using logger perhaps?
    # (I have another PR that addresses this https://github.com/Genesis-Embodied-AI/gstaichi/pull/144/files)
    arg_type = type(obj)
    # Ndarrays: dtype and number of dimensions (plus n / m,n for vector / matrix ndarrays).
    if isinstance(obj, ScalarNdarray):
        return f"[nd-{obj.dtype}-{len(obj.shape)}]"
    if isinstance(obj, VectorNdarray):
        return f"[ndv-{obj.n}-{obj.dtype}-{len(obj.shape)}]"
    # Fields: unlike ndarrays, the snode id and the full shape are part of the string.
    if isinstance(obj, ScalarField):
        return f"[f-{obj.snode._id}-{obj.dtype}-{obj.shape}]"
    if isinstance(obj, MatrixNdarray):
        return f"[ndm-{obj.m}-{obj.n}-{obj.dtype}-{len(obj.shape)}]"
    # Torch tensors are recognised from the type's string form, so torch is not imported here.
    if "torch.Tensor" in str(arg_type):
        return f"[pt-{obj.dtype}-{obj.ndim}]"
    if isinstance(obj, np.ndarray):
        return f"[np-{obj.dtype}-{obj.ndim}]"
    if isinstance(obj, MatrixField):
        return f"[fm-{obj.m}-{obj.n}-{obj.snode._id}-{obj.dtype}-{obj.shape}]"
    if dataclasses.is_dataclass(obj):
        return dataclass_to_repr(path, obj)
    if is_data_oriented(obj):
        # Recurse over the instance attributes; a single unrepresentable child
        # makes the whole object unrepresentable.
        child_repr_l = []
        for k, v in obj.__dict__.items():
            _child_repr = stringify_obj_type((*path, k), v)
            if _child_repr is None:
                print("not representable child", k, type(v), "path", path)
                return None
            child_repr_l.append(f"{k}: {_child_repr}")
        return ", ".join(child_repr_l)
    # NOTE(review): an enum member that is also a number (e.g. an IntEnum) presumably
    # matches this branch before the enum branch below - confirm that is intended.
    if issubclass(arg_type, (numbers.Number, np.number)):
        return str(arg_type)
    if arg_type is np.bool_:
        # np is deprecating bool. Treat specially/carefully
        return "np.bool_"
    # Enums contribute both the member name and its value.
    if isinstance(obj, enum.Enum):
        return f"enum-{obj.name}-{obj.value}"
    return None
83
+
84
+
85
def hash_args(args: Sequence[Any]) -> str | None:
    """
    Hash the type representations of all args into a single string.

    Returns None as soon as one arg has no (or an empty) representation; such calls
    are counted in g_num_ignored_calls. The module-level counters and timers are
    updated as a side effect.
    """
    global g_num_calls, g_num_args, g_hashing_time, g_repr_time, g_num_ignored_calls
    g_num_calls += 1
    g_num_args += len(args)

    arg_reprs = []
    for position, arg in enumerate(args):
        repr_started = time.time()
        arg_repr = stringify_obj_type((str(position),), arg)
        g_repr_time += time.time() - repr_started
        if not arg_repr:
            g_num_ignored_calls += 1
            return None
        arg_reprs.append(arg_repr)

    hash_started = time.time()
    digest = hash_iterable_strings(arg_reprs)
    g_hashing_time += time.time() - hash_started
    return digest
102
+
103
+
104
def dump_stats() -> None:
    """Print the module-level call counters and accumulated timings."""
    print("args hasher dump stats")
    stats = (
        ("total calls", g_num_calls),
        ("ignored calls", g_num_ignored_calls),
        ("total args", g_num_args),
        ("hashing time", g_hashing_time),
        ("arg representation time", g_repr_time),
    )
    for label, value in stats:
        print(label, value)
@@ -0,0 +1,30 @@
1
+ from gstaichi.lang import impl
2
+
3
+ from .hash_utils import hash_iterable_strings
4
+
5
+ EXCLUDE_PREFIXES = ["_", "offline_cache", "print_", "verbose_"]
6
+
7
+
8
def hash_compile_config() -> str:
    """
    Calculates a hash string for the current compiler config.

    Every attribute name reported by dir() on the config object is included as
    "name=value", except names starting with one of EXCLUDE_PREFIXES. If any
    included value changes, the hash string changes too.

    NOTE(review): values are formatted with str(); this presumably relies on every
    non-excluded attribute having a stable string form across runs - confirm.
    """
    config = impl.get_runtime().prog.config()
    # str.startswith accepts a tuple, which replaces the manual prefix loop.
    excluded_prefixes = tuple(EXCLUDE_PREFIXES)
    config_l = [f"{k}={getattr(config, k)}" for k in dir(config) if not k.startswith(excluded_prefixes)]
    return hash_iterable_strings(config_l, separator="\n")
@@ -0,0 +1,21 @@
1
+ from pydantic import BaseModel
2
+
3
+ from .._wrap_inspect import FunctionSourceInfo
4
+
5
+
6
class HashedFunctionSourceInfo(BaseModel):
    """
    Wraps a function source info, and the hash string of that function.

    By not adding the hash directly into function source info, we avoid
    having to make hash an optional type, and checking if it's empty or not.

    If you have a HashedFunctionSourceInfo object, then you are guaranteed
    to have the hash string.

    If you only have the FunctionSourceInfo object, you are guaranteed that it
    does not have a hash string.
    """

    # Location of the function's source.
    function_source_info: FunctionSourceInfo
    # Hash string computed for that source.
    hash: str
@@ -0,0 +1,57 @@
1
+ import os
2
+ from itertools import islice
3
+ from typing import TYPE_CHECKING, Iterable
4
+
5
+ from .._wrap_inspect import FunctionSourceInfo
6
+ from .fast_caching_types import HashedFunctionSourceInfo
7
+ from .hash_utils import hash_iterable_strings
8
+
9
+ if TYPE_CHECKING:
10
+ from gstaichi.lang.kernel_impl import GsTaichiCallable
11
+
12
+
13
def pure(fn: "GsTaichiCallable") -> "GsTaichiCallable":
    """Decorator that sets ``is_pure = True`` on ``fn`` and returns the same object."""
    setattr(fn, "is_pure", True)
    return fn
16
+
17
+
18
def _read_file(function_info: FunctionSourceInfo) -> list[str]:
    """
    Return the lines of ``function_info.filepath`` selected by its line range.

    The file is decoded as UTF-8 (Python's default source encoding) instead of the
    platform locale encoding, so the same source yields the same lines everywhere.

    NOTE(review): islice is 0-based, so this yields the lines at 0-based indices
    start_lineno..end_lineno inclusive - confirm that matches how
    FunctionSourceInfo numbers lines.
    """
    with open(function_info.filepath, encoding="utf-8") as f:
        return list(islice(f, function_info.start_lineno, function_info.end_lineno + 1))
21
+
22
+
23
def _hash_function(function_info: FunctionSourceInfo) -> str:
    """Hash the source lines of the function described by ``function_info``."""
    source_lines = _read_file(function_info)
    return hash_iterable_strings(source_lines)
25
+
26
+
27
def hash_functions(function_infos: Iterable[FunctionSourceInfo]) -> list[HashedFunctionSourceInfo]:
    """Pair every function source info with the hash of its source, preserving order."""
    return [
        HashedFunctionSourceInfo(function_source_info=info, hash=_hash_function(info))
        for info in function_infos
    ]
33
+
34
+
35
def hash_kernel(kernel_info: FunctionSourceInfo) -> str:
    """Hash the source of a kernel; same computation as for any other function."""
    kernel_hash = _hash_function(kernel_info)
    return kernel_hash
37
+
38
+
39
def dump_stats() -> None:
    """Print the stats header for the function hasher (no counters are printed)."""
    header = "function hasher dump stats"
    print(header)
41
+
42
+
43
def _validate_hashed_function_info(hashed_function_info: HashedFunctionSourceInfo) -> bool:
    """
    Re-hash the function's source and compare against the stored hash.

    Returns False when the source file no longer exists, otherwise whether the
    freshly computed hash equals the stored one.
    """
    source_info = hashed_function_info.function_source_info
    if os.path.isfile(source_info.filepath):
        return _hash_function(source_info) == hashed_function_info.hash
    return False
51
+
52
+
53
def validate_hashed_function_infos(function_infos: Iterable[HashedFunctionSourceInfo]) -> bool:
    """Return True only if every entry still validates; stops at the first failure."""
    return all(_validate_hashed_function_info(info) for info in function_infos)
@@ -0,0 +1,11 @@
1
+ import hashlib
2
+ from typing import Iterable
3
+
4
+
5
def hash_iterable_strings(strings: Iterable[str], separator: str = "_") -> str:
    """
    Return the SHA-256 hex digest of the given strings.

    Each string is UTF-8 encoded and fed to the hasher followed by the encoded
    separator, so the digest equals that of the concatenation
    "s0<sep>s1<sep>...". An empty iterable hashes the empty byte string.
    """
    digest = hashlib.sha256()
    sep_bytes = separator.encode("utf-8")
    for text in strings:
        digest.update(text.encode("utf-8") + sep_bytes)
    return digest.hexdigest()
@@ -0,0 +1,52 @@
1
+ import os
2
+
3
+ from .. import impl
4
+
5
+
6
+ class PythonSideCache:
7
+ """
8
+ Manages a cache that is managed from the python side (we also have c++-side caches)
9
+
10
+ The cache is disk-based. When we create the PythonSideCache object, the cache
11
+ path is created as a sub-folder of CompileConfig.offline_cache_file_path.
12
+
13
+ Note that constructing this object is cheap, so there is no need to maintain some
14
+ kind of conceptual singleton instance or similar.
15
+
16
+ Each cache key value is stored to a single file, with the cache key as the filename.
17
+
18
+ No metadata is associated with the file, making management very lightweight.
19
+
20
+ We update the file date/time when we read from a particular file, so we can easily
21
+ implement an LRU cleaning strategy at some point in the future, based on the file
22
+ date/times.
23
+ """
24
+
25
+ def __init__(self) -> None:
26
+ _cache_parent_folder = impl.get_runtime().prog.config().offline_cache_file_path
27
+ self.cache_folder = os.path.join(_cache_parent_folder, "python_side_cache")
28
+ os.makedirs(self.cache_folder, exist_ok=True)
29
+
30
+ def _get_filepath(self, key: str) -> str:
31
+ filepath = os.path.join(self.cache_folder, f"{key}.cache.txt")
32
+ return filepath
33
+
34
+ def _touch(self, filepath):
35
+ """
36
+ Updates file date/time.
37
+ """
38
+ with open(filepath, "a"):
39
+ os.utime(filepath, None)
40
+
41
+ def store(self, key: str, value: str) -> None:
42
+ filepath = self._get_filepath(key)
43
+ with open(filepath, "w") as f:
44
+ f.write(value)
45
+
46
+ def try_load(self, key: str) -> str | None:
47
+ filepath = self._get_filepath(key)
48
+ if not os.path.isfile(filepath):
49
+ return None
50
+ self._touch(filepath)
51
+ with open(filepath) as f:
52
+ return f.read()
@@ -0,0 +1,75 @@
1
+ from typing import Any, Iterable, Sequence
2
+
3
+ from pydantic import BaseModel
4
+
5
+ from .._wrap_inspect import FunctionSourceInfo
6
+ from . import args_hasher, config_hasher, function_hasher
7
+ from .fast_caching_types import HashedFunctionSourceInfo
8
+ from .hash_utils import hash_iterable_strings
9
+ from .python_side_cache import PythonSideCache
10
+
11
+
12
def create_cache_key(kernel_source_info: FunctionSourceInfo, args: Sequence[Any]) -> str | None:
    """
    Combine three hashes into the cache key:
    - the args hash (arg types, plus values of cache-value args)
    - the hash of the kernel function source (but not sub functions)
    - the hash of the compilation config (which includes arch, and debug)

    Returns None when the args cannot be hashed.
    """
    args_hash = args_hasher.hash_args(args)
    if args_hash is None:
        return None
    key_parts = (
        function_hasher.hash_kernel(kernel_source_info),
        args_hash,
        config_hasher.hash_compile_config(),
    )
    return hash_iterable_strings(key_parts)
27
+
28
+
29
class CacheValue(BaseModel):
    """Payload stored in the python-side cache for one cache key."""

    # Source infos, with their hashes, recorded when the entry was stored.
    hashed_function_source_infos: list[HashedFunctionSourceInfo]
31
+
32
+
33
def store(cache_key: str, function_source_infos: Iterable[FunctionSourceInfo]) -> None:
    """
    Record the hashed function source infos under cache_key. Does nothing for an
    empty cache key.

    Unlike other caches, this one does not hold the value we ultimately want; it
    only holds what is needed to check that a cache key is still valid. Big picture:
    - the cache key is based on args and the top level kernel function
    - it is meant to look up LLVM IR in the C++ side cache
    - before doing that, we want to check that the source code didn't change,
      i.e. that the cache key is still valid
    - the python side cache holds the list of function source infos used for
      that check
    """
    if not cache_key:
        return
    cache = PythonSideCache()
    hashed_infos = function_hasher.hash_functions(function_source_infos)
    payload = CacheValue(hashed_function_source_infos=list(hashed_infos)).json()
    cache.store(cache_key, payload)
50
+
51
+
52
def _try_load(cache_key: str) -> Sequence[HashedFunctionSourceInfo] | None:
    """Load and parse the entry stored for cache_key; None if there is no entry."""
    raw_json = PythonSideCache().try_load(cache_key)
    if raw_json is None:
        return None
    return CacheValue.parse_raw(raw_json).hashed_function_source_infos
59
+
60
+
61
def validate_cache_key(cache_key: str) -> bool:
    """
    Load the function source infos stored for cache_key, if any, and check their
    hashes against the current source code.

    Returns False when nothing is stored for the key or the stored list is empty.
    """
    stored_infos = _try_load(cache_key)
    if stored_infos:
        return function_hasher.validate_hashed_function_infos(stored_infos)
    return False
70
+
71
+
72
def dump_stats() -> None:
    """Print a header, then the stats of the args hasher and the function hasher."""
    print("dump stats")
    for hasher in (args_hasher, function_hasher):
        hasher.dump_stats()