gstaichi 0.0.0__cp310-cp310-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (178) hide show
  1. gstaichi/CHANGELOG.md +4 -0
  2. gstaichi/__init__.py +51 -0
  3. gstaichi/_funcs.py +706 -0
  4. gstaichi/_kernels.py +420 -0
  5. gstaichi/_lib/__init__.py +5 -0
  6. gstaichi/_lib/core/__init__.py +0 -0
  7. gstaichi/_lib/core/gstaichi_python.cpython-310-darwin.so +0 -0
  8. gstaichi/_lib/core/gstaichi_python.pyi +2917 -0
  9. gstaichi/_lib/core/py.typed +0 -0
  10. gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
  11. gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
  12. gstaichi/_lib/utils.py +243 -0
  13. gstaichi/_logging.py +131 -0
  14. gstaichi/_snode/__init__.py +5 -0
  15. gstaichi/_snode/fields_builder.py +187 -0
  16. gstaichi/_snode/snode_tree.py +34 -0
  17. gstaichi/_test_tools/__init__.py +18 -0
  18. gstaichi/_test_tools/dataclass_test_tools.py +36 -0
  19. gstaichi/_test_tools/load_kernel_string.py +30 -0
  20. gstaichi/_test_tools/textwrap2.py +6 -0
  21. gstaichi/_version_check.py +100 -0
  22. gstaichi/ad/__init__.py +3 -0
  23. gstaichi/ad/_ad.py +530 -0
  24. gstaichi/algorithms/__init__.py +3 -0
  25. gstaichi/algorithms/_algorithms.py +117 -0
  26. gstaichi/assets/.git +1 -0
  27. gstaichi/assets/Go-Regular.ttf +0 -0
  28. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  29. gstaichi/examples/lcg_python.py +26 -0
  30. gstaichi/examples/lcg_taichi.py +34 -0
  31. gstaichi/examples/minimal.py +28 -0
  32. gstaichi/experimental.py +16 -0
  33. gstaichi/lang/__init__.py +50 -0
  34. gstaichi/lang/_dataclass_util.py +31 -0
  35. gstaichi/lang/_fast_caching/__init__.py +3 -0
  36. gstaichi/lang/_fast_caching/args_hasher.py +122 -0
  37. gstaichi/lang/_fast_caching/config_hasher.py +30 -0
  38. gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
  39. gstaichi/lang/_fast_caching/function_hasher.py +57 -0
  40. gstaichi/lang/_fast_caching/hash_utils.py +11 -0
  41. gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
  42. gstaichi/lang/_fast_caching/src_hasher.py +83 -0
  43. gstaichi/lang/_kernel_impl_dataclass.py +212 -0
  44. gstaichi/lang/_ndarray.py +366 -0
  45. gstaichi/lang/_ndrange.py +152 -0
  46. gstaichi/lang/_template_mapper.py +195 -0
  47. gstaichi/lang/_texture.py +172 -0
  48. gstaichi/lang/_wrap_inspect.py +215 -0
  49. gstaichi/lang/any_array.py +99 -0
  50. gstaichi/lang/ast/__init__.py +7 -0
  51. gstaichi/lang/ast/ast_transformer.py +1351 -0
  52. gstaichi/lang/ast/ast_transformer_utils.py +346 -0
  53. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  54. gstaichi/lang/ast/ast_transformers/call_transformer.py +327 -0
  55. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
  56. gstaichi/lang/ast/checkers.py +106 -0
  57. gstaichi/lang/ast/symbol_resolver.py +57 -0
  58. gstaichi/lang/ast/transform.py +9 -0
  59. gstaichi/lang/common_ops.py +310 -0
  60. gstaichi/lang/exception.py +80 -0
  61. gstaichi/lang/expr.py +180 -0
  62. gstaichi/lang/field.py +428 -0
  63. gstaichi/lang/impl.py +1259 -0
  64. gstaichi/lang/kernel_arguments.py +155 -0
  65. gstaichi/lang/kernel_impl.py +1386 -0
  66. gstaichi/lang/matrix.py +1835 -0
  67. gstaichi/lang/matrix_ops.py +341 -0
  68. gstaichi/lang/matrix_ops_utils.py +190 -0
  69. gstaichi/lang/mesh.py +687 -0
  70. gstaichi/lang/misc.py +784 -0
  71. gstaichi/lang/ops.py +1494 -0
  72. gstaichi/lang/runtime_ops.py +13 -0
  73. gstaichi/lang/shell.py +35 -0
  74. gstaichi/lang/simt/__init__.py +5 -0
  75. gstaichi/lang/simt/block.py +94 -0
  76. gstaichi/lang/simt/grid.py +7 -0
  77. gstaichi/lang/simt/subgroup.py +191 -0
  78. gstaichi/lang/simt/warp.py +96 -0
  79. gstaichi/lang/snode.py +489 -0
  80. gstaichi/lang/source_builder.py +150 -0
  81. gstaichi/lang/struct.py +810 -0
  82. gstaichi/lang/util.py +312 -0
  83. gstaichi/linalg/__init__.py +10 -0
  84. gstaichi/linalg/matrixfree_cg.py +310 -0
  85. gstaichi/linalg/sparse_cg.py +59 -0
  86. gstaichi/linalg/sparse_matrix.py +303 -0
  87. gstaichi/linalg/sparse_solver.py +123 -0
  88. gstaichi/math/__init__.py +11 -0
  89. gstaichi/math/_complex.py +205 -0
  90. gstaichi/math/mathimpl.py +886 -0
  91. gstaichi/profiler/__init__.py +6 -0
  92. gstaichi/profiler/kernel_metrics.py +260 -0
  93. gstaichi/profiler/kernel_profiler.py +586 -0
  94. gstaichi/profiler/memory_profiler.py +15 -0
  95. gstaichi/profiler/scoped_profiler.py +36 -0
  96. gstaichi/sparse/__init__.py +3 -0
  97. gstaichi/sparse/_sparse_grid.py +77 -0
  98. gstaichi/tools/__init__.py +12 -0
  99. gstaichi/tools/diagnose.py +117 -0
  100. gstaichi/tools/np2ply.py +364 -0
  101. gstaichi/tools/vtk.py +38 -0
  102. gstaichi/types/__init__.py +21 -0
  103. gstaichi/types/annotations.py +52 -0
  104. gstaichi/types/compound_types.py +71 -0
  105. gstaichi/types/enums.py +49 -0
  106. gstaichi/types/ndarray_type.py +169 -0
  107. gstaichi/types/primitive_types.py +206 -0
  108. gstaichi/types/quant.py +88 -0
  109. gstaichi/types/texture_type.py +85 -0
  110. gstaichi/types/utils.py +11 -0
  111. gstaichi-0.0.0.data/data/include/GLFW/glfw3.h +6389 -0
  112. gstaichi-0.0.0.data/data/include/GLFW/glfw3native.h +594 -0
  113. gstaichi-0.0.0.data/data/include/spirv-tools/instrument.hpp +268 -0
  114. gstaichi-0.0.0.data/data/include/spirv-tools/libspirv.h +907 -0
  115. gstaichi-0.0.0.data/data/include/spirv-tools/libspirv.hpp +375 -0
  116. gstaichi-0.0.0.data/data/include/spirv-tools/linker.hpp +97 -0
  117. gstaichi-0.0.0.data/data/include/spirv-tools/optimizer.hpp +970 -0
  118. gstaichi-0.0.0.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
  119. gstaichi-0.0.0.data/data/include/spirv_cross/spirv.h +2568 -0
  120. gstaichi-0.0.0.data/data/include/spirv_cross/spirv.hpp +2579 -0
  121. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
  122. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
  123. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
  124. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
  125. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
  126. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
  127. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
  128. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
  129. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
  130. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
  131. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
  132. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
  133. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
  134. gstaichi-0.0.0.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
  135. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
  136. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
  137. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
  138. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
  139. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
  140. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
  141. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
  142. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
  143. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
  144. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
  145. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
  146. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
  147. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
  148. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
  149. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
  150. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
  151. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  152. gstaichi-0.0.0.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
  153. gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  154. gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  155. gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  156. gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  157. gstaichi-0.0.0.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  158. gstaichi-0.0.0.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
  159. gstaichi-0.0.0.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
  160. gstaichi-0.0.0.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
  161. gstaichi-0.0.0.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
  162. gstaichi-0.0.0.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
  163. gstaichi-0.0.0.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
  164. gstaichi-0.0.0.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
  165. gstaichi-0.0.0.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
  166. gstaichi-0.0.0.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
  167. gstaichi-0.0.0.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
  168. gstaichi-0.0.0.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
  169. gstaichi-0.0.0.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
  170. gstaichi-0.0.0.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
  171. gstaichi-0.0.0.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
  172. gstaichi-0.0.0.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
  173. gstaichi-0.0.0.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
  174. gstaichi-0.0.0.dist-info/METADATA +97 -0
  175. gstaichi-0.0.0.dist-info/RECORD +178 -0
  176. gstaichi-0.0.0.dist-info/WHEEL +5 -0
  177. gstaichi-0.0.0.dist-info/licenses/LICENSE +201 -0
  178. gstaichi-0.0.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,117 @@
1
+ # type: ignore
2
+
3
+ from gstaichi._kernels import (
4
+ blit_from_field_to_field,
5
+ scan_add_inclusive,
6
+ sort_stage,
7
+ uniform_add,
8
+ warp_shfl_up_i32,
9
+ )
10
+ from gstaichi.lang.impl import current_cfg, field
11
+ from gstaichi.lang.kernel_impl import data_oriented
12
+ from gstaichi.lang.misc import cuda, vulkan
13
+ from gstaichi.lang.runtime_ops import sync
14
+ from gstaichi.lang.simt import subgroup
15
+ from gstaichi.types.primitive_types import i32
16
+
17
+
18
def parallel_sort(keys, values=None):
    """Odd-even merge sort.

    Walks the (p, k) stages of Batcher's odd-even merge network and launches
    one ``sort_stage`` kernel per stage, followed by a ``sync()``.

    Args:
        keys: Container whose first dimension (``keys.shape[0]``) gives the
            number of elements to sort. It is passed to ``sort_stage`` as the
            key buffer.
        values: Optional second container. When given it is passed to
            ``sort_stage`` with the flag ``1``; otherwise ``keys`` is passed
            again with the flag ``0``. Presumably the values are permuted
            together with the keys -- confirm against ``sort_stage``.

    References:
        https://developer.nvidia.com/gpugems/gpugems2/part-vi-simulation-and-numerical-algorithms/chapter-46-improved-gpu-sorting
        https://en.wikipedia.org/wiki/Batcher_odd%E2%80%93even_mergesort
    """
    N = keys.shape[0]

    p = 1
    while p < N:
        k = p
        while k >= 1:
            # Pure integer arithmetic: the numerator is always positive here
            # (k <= p < N), so floor division matches the truncating
            # int(a / b) it replaces while staying exact for very large N.
            invocations = (N - k - k % p) // (2 * k) + 1
            if values is None:
                sort_stage(keys, 0, keys, N, p, k, invocations)
            else:
                sort_stage(keys, 1, values, N, p, k, invocations)
            sync()
            k //= 2
        p *= 2
41
+
42
+
43
@data_oriented
class PrefixSumExecutor:
    """Parallel Prefix Sum (Scan) Helper

    Use this helper to perform an inclusive, in-place parallel prefix sum.

    References:
        https://developer.download.nvidia.com/compute/cuda/1.1-Beta/x86_website/projects/scan/doc/scan.pdf
        https://github.com/NVIDIA/cuda-samples/blob/master/Samples/2_Concepts_and_Techniques/shfl_scan/shfl_scan.cu
    """

    def __init__(self, length):
        """Pre-compute the per-level layout and allocate the scratch field.

        Args:
            length: Number of elements that ``run`` will scan.
        """
        self.sorting_length = length

        BLOCK_SZ = 64

        # Buffer position and length
        # This is a single buffer implementation for ease of aot usage
        ele_num = length
        self.ele_nums = [ele_num]
        start_pos = 0
        self.ele_nums_pos = [start_pos]

        # Each level holds ceil(previous / BLOCK_SZ) elements; every level is
        # given BLOCK_SZ * ele_num slots in the shared buffer.
        while ele_num > 1:
            ele_num = (ele_num + BLOCK_SZ - 1) // BLOCK_SZ
            self.ele_nums.append(ele_num)
            start_pos += BLOCK_SZ * ele_num
            self.ele_nums_pos.append(start_pos)

        self.large_arr = field(i32, shape=start_pos)

    def run(self, input_arr):
        """Run the prefix sum over ``input_arr`` and write the result back into it.

        Args:
            input_arr: Array of dtype ``i32`` to scan; its first
                ``self.sorting_length`` elements are copied into the scratch
                field and overwritten with the result.

        Raises:
            RuntimeError: If ``input_arr.dtype`` is not ``i32``, or if the
                current arch is neither ``cuda`` nor ``vulkan``.
        """
        length = self.sorting_length
        ele_nums = self.ele_nums
        ele_nums_pos = self.ele_nums_pos

        if input_arr.dtype != i32:
            raise RuntimeError("Only ti.i32 type is supported for prefix sum.")

        arch = current_cfg().arch
        if arch == cuda:
            inclusive_add = warp_shfl_up_i32
        elif arch == vulkan:
            inclusive_add = subgroup.inclusive_add
        else:
            raise RuntimeError(f"{str(arch)} is not supported for prefix sum.")

        blit_from_field_to_field(self.large_arr, input_arr, 0, length)

        # Kogge-Stone construction
        num_levels = len(ele_nums) - 1
        for i in range(num_levels):
            # Only the final level is launched with the flag set to True.
            is_last_level = i == num_levels - 1
            scan_add_inclusive(
                self.large_arr,
                ele_nums_pos[i],
                ele_nums_pos[i + 1],
                is_last_level,
                inclusive_add,
            )

        # Walk the levels back down, applying uniform_add between neighbours.
        for i in range(len(ele_nums) - 3, -1, -1):
            uniform_add(self.large_arr, ele_nums_pos[i], ele_nums_pos[i + 1])

        blit_from_field_to_field(input_arr, self.large_arr, 0, length)
115
+
116
+
117
+ __all__ = ["parallel_sort", "PrefixSumExecutor"]
gstaichi/assets/.git ADDED
@@ -0,0 +1 @@
1
+ gitdir: ../../.git/modules/external/assets
Binary file
@@ -0,0 +1,26 @@
1
+ import time
2
+
3
+ import numpy as np
4
+ import numpy.typing as npt
5
+
6
+
7
def lcg_np(B: int, lcg_its: int, a: npt.NDArray) -> None:
    """Advance the first ``B`` entries of ``a`` by ``lcg_its`` LCG steps, in place.

    Each step applies ``x -> (1664525 * x + 1013904223) % 2147483647``.

    Args:
        B: Number of leading entries of ``a`` to update.
        lcg_its: Number of LCG steps applied to every entry.
        a: Array holding the seeds on entry and the results on return.
    """
    for i in range(B):
        # Promote to a Python int: multiplying a fixed-width NumPy scalar by
        # 1664525 can overflow (e.g. int32 under NumPy 2 promotion rules),
        # whereas Python ints are arbitrary precision. The value stored back
        # is always below 2147483647, so it fits an int32 slot.
        x = int(a[i])
        for _ in range(lcg_its):
            x = (1664525 * x + 1013904223) % 2147483647
        a[i] = x
13
+
14
+
15
def main() -> None:
    """Time ``lcg_np`` on 16000 entries with 1000 LCG steps each and print the result."""
    B = 16000
    # np.zeros gives a defined starting state; np.ndarray(...) would hand back
    # uninitialised memory, making the computed values arbitrary.
    a = np.zeros((B,), dtype=np.int32)

    # perf_counter is monotonic, so the interval cannot be skewed by wall-clock
    # adjustments.
    start = time.perf_counter()
    lcg_np(B, 1000, a)
    end = time.perf_counter()
    print("elapsed", end - start)
    # elapsed 5.552601099014282 on macbook air m4
24
+
25
+
26
+ main()
@@ -0,0 +1,34 @@
1
+ import time
2
+
3
+ import gstaichi as ti
4
+
5
+
6
@ti.kernel
def lcg_ti(B: int, lcg_its: int, a: ti.types.NDArray[ti.i32, 1]) -> None:
    # Kernel: for each of the first B entries of `a`, apply lcg_its steps of the
    # linear congruential recurrence below and write the result back in place.
    # NOTE(review): the arithmetic runs on values read from an i32 ndarray, so
    # the multiply presumably wraps at 32 bits -- confirm against Taichi's
    # integer promotion rules.
    for i in range(B):
        x = a[i]
        for j in range(lcg_its):
            x = (1664525 * x + 1013904223) % 2147483647
        a[i] = x
13
+
14
+
15
def main() -> None:
    # Benchmark driver: times a single lcg_ti launch over B ndarray entries
    # with 1000 LCG steps each, then prints the elapsed seconds.
    ti.init(arch=ti.gpu)

    B = 16000
    a = ti.ndarray(ti.int32, (B,))

    # ti.sync() is called right before the clock starts and right before it
    # stops; presumably so the measured interval brackets the whole kernel
    # execution -- confirm against ti.sync semantics.
    ti.sync()
    start = time.time()
    lcg_ti(B, 1000, a)
    ti.sync()
    end = time.time()
    print("elapsed", end - start)

    # [GsTaichi] version 1.8.0, llvm 15.0.7, commit 5afed1c9, osx, python 3.10.16
    # [GsTaichi] Starting on arch=metal
    # elapsed 0.04660296440124512
    # (on mac air m4)
32
+
33
+
34
+ main()
@@ -0,0 +1,28 @@
1
+ import gstaichi as ti
2
+
3
+
4
@ti.kernel
def lcg_ti(B: int, lcg_its: int, a: ti.types.NDArray[ti.i32, 1]) -> None:
    """
    Linear congruential generator https://en.wikipedia.org/wiki/Linear_congruential_generator

    For each of the first ``B`` entries of ``a``, applies ``lcg_its`` steps of
    ``x -> (1664525 * x + 1013904223) % 2147483647`` and stores the result
    back into the same entry.
    """
    for i in range(B):
        x = a[i]
        for j in range(lcg_its):
            x = (1664525 * x + 1013904223) % 2147483647
        a[i] = x
14
+
15
+
16
def main() -> None:
    # Minimal end-to-end example: initialise on the CPU arch, allocate a
    # 10-element i32 ndarray, run the kernel once and print the contents.
    ti.init(arch=ti.cpu)

    B = 10
    lcg_its = 10

    a = ti.ndarray(ti.int32, (B,))

    lcg_ti(B, lcg_its, a)
    print(f"LCG for B={B}, lcg_its={lcg_its}: ", a.to_numpy())  # pylint: disable=no-member
26
+
27
+
28
+ main()
@@ -0,0 +1,16 @@
1
+ # type: ignore
2
+
3
+ import warnings
4
+
5
+ from gstaichi.lang.kernel_impl import real_func as _real_func
6
+
7
+
8
def real_func(func):
    """Deprecated alias that forwards to ``gstaichi.lang.kernel_impl.real_func``.

    Emits a ``DeprecationWarning`` and then returns ``_real_func(func)``.

    Args:
        func: The function to wrap; passed through unchanged to ``_real_func``.

    Returns:
        Whatever ``_real_func(func)`` returns.
    """
    warnings.warn(
        "ti.experimental.real_func is deprecated because it is no longer experimental. " "Use ti.real_func instead.",
        DeprecationWarning,
        # Attribute the warning to the code that applied the decorator rather
        # than to this shim, so users can locate the call site to update.
        stacklevel=2,
    )
    return _real_func(func)
14
+
15
+
16
+ __all__ = ["real_func"]
@@ -0,0 +1,50 @@
1
+ # type: ignore
2
+
3
+ from gstaichi.lang import impl, simt # noqa: F401
4
+ from gstaichi.lang._fast_caching.function_hasher import pure # noqa: F401
5
+ from gstaichi.lang._ndarray import *
6
+ from gstaichi.lang._ndrange import ndrange # noqa: F401
7
+ from gstaichi.lang._texture import Texture # noqa: F401
8
+ from gstaichi.lang.exception import *
9
+ from gstaichi.lang.field import *
10
+ from gstaichi.lang.impl import *
11
+ from gstaichi.lang.kernel_impl import *
12
+ from gstaichi.lang.matrix import *
13
+ from gstaichi.lang.mesh import *
14
+ from gstaichi.lang.misc import * # pylint: disable=W0622
15
+ from gstaichi.lang.ops import * # pylint: disable=W0622
16
+ from gstaichi.lang.runtime_ops import *
17
+ from gstaichi.lang.snode import *
18
+ from gstaichi.lang.source_builder import *
19
+ from gstaichi.lang.struct import *
20
+ from gstaichi.types.enums import DeviceCapability, Format, Layout # noqa: F401
21
+
22
+ __all__ = [
23
+ s
24
+ for s in dir()
25
+ if not s.startswith("_")
26
+ and s
27
+ not in [
28
+ "any_array",
29
+ "ast",
30
+ "common_ops",
31
+ "enums",
32
+ "exception",
33
+ "expr",
34
+ "impl",
35
+ "inspect",
36
+ "kernel_arguments",
37
+ "kernel_impl",
38
+ "matrix",
39
+ "mesh",
40
+ "misc",
41
+ "ops",
42
+ "platform",
43
+ "runtime_ops",
44
+ "shell",
45
+ "snode",
46
+ "source_builder",
47
+ "struct",
48
+ "util",
49
+ ]
50
+ ]
@@ -0,0 +1,31 @@
1
def create_flat_name(basename: str, child_name: str) -> str:
    """
    Build the flattened name for a member of an expanded py dataclass.

    The result is `basename`, the delimiter `__ti_`, then `child_name`. If that
    string does not already begin with `__ti_`, the delimiter is also prepended,
    so a `basename` that already carries the prefix does not get it twice.

    Example:

        @dataclasses.dataclass
        class Foo:
            a: int
            b: int

        foo = Foo(a=5, b=3)

    Expanding `foo` replaces it with the names:
    - __ti_foo__ti_a
    - __ti_foo__ti_b

    The `__ti_` marker keeps generated names apart from user-defined ones: users
    are required not to create fields or variables prefixed with `__ti_`, and
    under that constraint generated names cannot clash with theirs.
    """
    delimiter = "__ti_"
    joined = f"{basename}{delimiter}{child_name}"
    if joined.startswith(delimiter):
        return joined
    return f"{delimiter}{joined}"
@@ -0,0 +1,3 @@
1
+ from .args_hasher import FIELD_METADATA_CACHE_VALUE
2
+
3
+ __all__ = ["FIELD_METADATA_CACHE_VALUE"]
@@ -0,0 +1,122 @@
1
+ import dataclasses
2
+ import enum
3
+ import numbers
4
+ import time
5
+ from typing import Any, Sequence
6
+
7
+ import numpy as np
8
+
9
+ from gstaichi import _logging
10
+
11
+ from .._ndarray import ScalarNdarray
12
+ from ..field import ScalarField
13
+ from ..matrix import MatrixField, MatrixNdarray, VectorNdarray
14
+ from ..util import is_data_oriented
15
+ from .hash_utils import hash_iterable_strings
16
+
17
+ g_num_calls = 0
18
+ g_num_args = 0
19
+ g_hashing_time = 0
20
+ g_repr_time = 0
21
+ g_num_ignored_calls = 0
22
+
23
+
24
+ FIELD_METADATA_CACHE_VALUE = "add_value_to_cache_key"
25
+
26
+
27
def dataclass_to_repr(path: tuple[str, ...], arg: Any) -> str:
    """
    Render a dataclass instance as a bracketed, comma-joined list of member reprs.

    Each member contributes `name: (<type repr>)`, where the type repr comes from
    `stringify_obj_type` called with the member name appended to `path`. Members
    whose field metadata has a truthy FIELD_METADATA_CACHE_VALUE entry also get
    ` = <value>` appended.
    """

    def _describe(member) -> str:
        value = getattr(arg, member.name)
        type_repr = stringify_obj_type(path + (member.name,), value)
        # NOTE(review): a None result from stringify_obj_type is formatted as the
        # text "None" here rather than propagated -- confirm this is intended.
        text = f"{member.name}: ({type_repr})"
        if member.metadata.get(FIELD_METADATA_CACHE_VALUE, False):
            text += f" = {value}"
        return text

    return "[" + ",".join(_describe(member) for member in dataclasses.fields(arg)) + "]"
37
+
38
+
39
def stringify_obj_type(path: tuple[str, ...], obj: object) -> str | None:
    """
    Convert an object into a string representation that only depends on its type.

    The string should somehow represent the type of obj. It doesn't have to be hashed, nor does it
    have to be the actual python type string, just a string that is representative of the type, and
    won't collide with different (allowed) types.

    Note that fields are not included in fast cache.

    Args:
        path: Names leading to `obj` from the top-level argument; extended with the attribute name
            when recursing, and included in the warning for unsupported types.
        obj: The value to describe.

    Returns:
        The representation string, or None when `obj` is a ScalarField / MatrixField, is a
        data-oriented object with a non-representable attribute, or has a type none of the
        branches below handle.
    """
    arg_type = type(obj)
    if isinstance(obj, ScalarNdarray):
        return f"[nd-{obj.dtype}-{len(obj.shape)}]"
    if isinstance(obj, VectorNdarray):
        return f"[ndv-{obj.n}-{obj.dtype}-{len(obj.shape)}]"
    if isinstance(obj, ScalarField):
        # disabled for now, because we need to think about how to handle field offset
        # etc
        # TODO: think about whether there is a way to include fields
        return None
    if isinstance(obj, MatrixNdarray):
        return f"[ndm-{obj.m}-{obj.n}-{obj.dtype}-{len(obj.shape)}]"
    # Matched on the type's string form, so torch itself is never imported here.
    if "torch.Tensor" in str(arg_type):
        return f"[pt-{obj.dtype}-{obj.ndim}]"  # type: ignore
    if isinstance(obj, np.ndarray):
        return f"[np-{obj.dtype}-{obj.ndim}]"
    if isinstance(obj, MatrixField):
        # disabled for now, because we need to think about how to handle field offset
        # etc
        # TODO: think about whether there is a way to include fields
        return None
    if dataclasses.is_dataclass(obj):
        return dataclass_to_repr(path, obj)
    if is_data_oriented(obj):
        # Recurse over every instance attribute; a single non-representable
        # attribute makes the whole object non-representable.
        child_repr_l = []
        for k, v in obj.__dict__.items():
            _child_repr = stringify_obj_type((*path, k), v)
            if _child_repr is None:
                print("not representable child", k, type(v), "path", path)
                return None
            child_repr_l.append(f"{k}: {_child_repr}")
        return ", ".join(child_repr_l)
    if issubclass(arg_type, (numbers.Number, np.number)):
        return str(arg_type)
    if arg_type is np.bool_:
        # np is deprecating bool. Treat specially/carefully
        return "np.bool_"
    if isinstance(obj, enum.Enum):
        return f"enum-{obj.name}-{obj.value}"
    # The bit in caps should not be modified without updating corresponding test
    # The rest of free text can be freely modified
    # (will probably formalize this in more general doc / contributor guidelines at some point)
    _logging.warn(
        f"[FASTCACHE][PARAM_INVALID] Parameter with path {path} and type {arg_type} not allowed by fast cache."
    )
    return None
95
+
96
+
97
def hash_args(args: Sequence[Any]) -> str | None:
    """
    Hash the type representations of a sequence of call arguments.

    Each argument is turned into a string by `stringify_obj_type`, using its
    position as the path, and the strings are combined by
    `hash_iterable_strings`. Module-level counters and timers are updated as a
    side effect.

    Args:
        args: The positional arguments to describe.

    Returns:
        The combined hash, or None as soon as one argument yields a falsy
        representation (None or an empty string).
    """
    global g_num_calls, g_num_args, g_hashing_time, g_repr_time, g_num_ignored_calls
    g_num_calls += 1
    g_num_args += len(args)
    hash_l = []
    for i_arg, arg in enumerate(args):
        # perf_counter is monotonic and high resolution, which suits short
        # interval measurements better than the adjustable wall clock.
        start = time.perf_counter()
        _hash = stringify_obj_type((str(i_arg),), arg)
        g_repr_time += time.perf_counter() - start
        # NOTE(review): an empty string is treated the same as None here --
        # confirm that is intended for objects with an empty representation.
        if not _hash:
            g_num_ignored_calls += 1
            return None
        hash_l.append(_hash)
    start = time.perf_counter()
    res = hash_iterable_strings(hash_l)
    g_hashing_time += time.perf_counter() - start
    return res
114
+
115
+
116
def dump_stats() -> None:
    """Print the module-level args-hasher counters and timers, one per line."""
    print("args hasher dump stats")
    report = (
        ("total calls", g_num_calls),
        ("ignored calls", g_num_ignored_calls),
        ("total args", g_num_args),
        ("hashing time", g_hashing_time),
        ("arg representation time", g_repr_time),
    )
    for label, value in report:
        print(label, value)
@@ -0,0 +1,30 @@
1
+ from gstaichi.lang import impl
2
+
3
+ from .hash_utils import hash_iterable_strings
4
+
5
+ EXCLUDE_PREFIXES = ["_", "offline_cache", "print_", "verbose_"]
6
+
7
+
8
def hash_compile_config() -> str:
    """
    Calculates a hash string for the current compiler config.

    Every attribute name returned by `dir()` on the config object is included
    as `name=value`, except names that are empty or start with one of
    EXCLUDE_PREFIXES. If any included value changes, the hash string changes
    too.

    Returns:
        The hash of the newline-separated `name=value` entries.
    """
    config = impl.get_runtime().prog.config()
    # str.startswith accepts a tuple, so one call covers every excluded prefix.
    excluded_prefixes = tuple(EXCLUDE_PREFIXES)
    config_l = [
        f"{k}={getattr(config, k)}"
        for k in dir(config)
        if k and not k.startswith(excluded_prefixes)
    ]
    return hash_iterable_strings(config_l, separator="\n")
@@ -0,0 +1,21 @@
1
+ from pydantic import BaseModel
2
+
3
+ from .._wrap_inspect import FunctionSourceInfo
4
+
5
+
6
class HashedFunctionSourceInfo(BaseModel):
    """
    Wraps a function source info, and the hash string of that function.

    By not adding the hash directly into function source info, we avoid
    having to make hash an optional type, and checking if it's empty or not.

    If you have a HashedFunctionSourceInfo object, then you are guaranteed
    to have the hash string.

    If you only have the FunctionSourceInfo object, you are guaranteed that it
    does not have a hash string.
    """

    # The source location description that `hash` was computed for.
    function_source_info: FunctionSourceInfo
    # Hash string associated with `function_source_info`.
    hash: str
@@ -0,0 +1,57 @@
1
+ import os
2
+ from itertools import islice
3
+ from typing import TYPE_CHECKING, Iterable
4
+
5
+ from .._wrap_inspect import FunctionSourceInfo
6
+ from .fast_caching_types import HashedFunctionSourceInfo
7
+ from .hash_utils import hash_iterable_strings
8
+
9
+ if TYPE_CHECKING:
10
+ from gstaichi.lang.kernel_impl import GsTaichiCallable
11
+
12
+
13
def pure(fn: "GsTaichiCallable") -> "GsTaichiCallable":
    """Mark `fn` by setting its `is_pure` attribute to True and return the same object."""
    setattr(fn, "is_pure", True)
    return fn
16
+
17
+
18
def _read_file(function_info: FunctionSourceInfo) -> list[str]:
    """
    Return the lines of `function_info.filepath` that belong to the function.

    The lines at zero-based positions `start_lineno` through `end_lineno`
    (inclusive) of the file are returned, line endings included.
    NOTE(review): whether these attributes are meant as zero- or one-based
    line numbers is decided by the producer of FunctionSourceInfo -- confirm.
    """
    # Decode explicitly as UTF-8 (Python's default source encoding) so the
    # result does not depend on the platform's locale encoding.
    with open(function_info.filepath, encoding="utf-8") as f:
        return list(islice(f, function_info.start_lineno, function_info.end_lineno + 1))
21
+
22
+
23
def _hash_function(function_info: FunctionSourceInfo) -> str:
    """Hash the source lines that `_read_file` returns for `function_info`."""
    source_lines = _read_file(function_info)
    return hash_iterable_strings(source_lines)
25
+
26
+
27
def hash_functions(function_infos: Iterable[FunctionSourceInfo]) -> list[HashedFunctionSourceInfo]:
    """Pair every function source info with the hash of its source, preserving input order."""
    return [
        HashedFunctionSourceInfo(function_source_info=info, hash=_hash_function(info))
        for info in function_infos
    ]
33
+
34
+
35
def hash_kernel(kernel_info: FunctionSourceInfo) -> str:
    """Return the source hash of a kernel; computed exactly like any other function's."""
    kernel_hash = _hash_function(kernel_info)
    return kernel_hash
37
+
38
+
39
def dump_stats() -> None:
    """Print the function-hasher stats header; no counters are reported by this function."""
    print("function hasher dump stats")
41
+
42
+
43
def _validate_hashed_function_info(hashed_function_info: HashedFunctionSourceInfo) -> bool:
    """
    Check that the stored hash still matches the function's source on disk.

    Returns False when the source file is not an existing regular file;
    otherwise re-hashes the source and compares against the stored hash.
    """
    source_info = hashed_function_info.function_source_info
    if os.path.isfile(source_info.filepath):
        return _hash_function(source_info) == hashed_function_info.hash
    return False
51
+
52
+
53
def validate_hashed_function_infos(function_infos: Iterable[HashedFunctionSourceInfo]) -> bool:
    """Return True only if every entry validates; stops at the first failing entry."""
    return all(_validate_hashed_function_info(info) for info in function_infos)
@@ -0,0 +1,11 @@
1
+ import hashlib
2
+ from typing import Iterable
3
+
4
+
5
def hash_iterable_strings(strings: Iterable[str], separator: str = "_") -> str:
    """
    Return the SHA-256 hex digest of the given strings.

    Every string is UTF-8 encoded and followed by the UTF-8 encoded separator,
    including the last one, before being fed to the hash.
    """
    sep_bytes = separator.encode("utf-8")
    digest = hashlib.sha256()
    for item in strings:
        digest.update(item.encode("utf-8") + sep_bytes)
    return digest.hexdigest()
@@ -0,0 +1,52 @@
1
+ import os
2
+
3
+ from .. import impl
4
+
5
+
6
+ class PythonSideCache:
7
+ """
8
+ Manages a cache that is managed from the python side (we also have c++-side caches)
9
+
10
+ The cache is disk-based. When we create the PythonSideCache object, the cache
11
+ path is created as a sub-folder of CompileConfig.offline_cache_file_path.
12
+
13
+ Note that constructing this object is cheap, so there is no need to maintain some
14
+ kind of conceptual singleton instance or similar.
15
+
16
+ Each cache key value is stored to a single file, with the cache key as the filename.
17
+
18
+ No metadata is associated with the file, making management very lightweight.
19
+
20
+ We update the file date/time when we read from a particular file, so we can easily
21
+ implement an LRU cleaning strategy at some point in the future, based on the file
22
+ date/times.
23
+ """
24
+
25
+ def __init__(self) -> None:
26
+ _cache_parent_folder = impl.get_runtime().prog.config().offline_cache_file_path
27
+ self.cache_folder = os.path.join(_cache_parent_folder, "python_side_cache")
28
+ os.makedirs(self.cache_folder, exist_ok=True)
29
+
30
+ def _get_filepath(self, key: str) -> str:
31
+ filepath = os.path.join(self.cache_folder, f"{key}.cache.txt")
32
+ return filepath
33
+
34
+ def _touch(self, filepath):
35
+ """
36
+ Updates file date/time.
37
+ """
38
+ with open(filepath, "a"):
39
+ os.utime(filepath, None)
40
+
41
+ def store(self, key: str, value: str) -> None:
42
+ filepath = self._get_filepath(key)
43
+ with open(filepath, "w") as f:
44
+ f.write(value)
45
+
46
+ def try_load(self, key: str) -> str | None:
47
+ filepath = self._get_filepath(key)
48
+ if not os.path.isfile(filepath):
49
+ return None
50
+ self._touch(filepath)
51
+ with open(filepath) as f:
52
+ return f.read()