gstaichi 0.1.25.dev0__cp311-cp311-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (138)
  1. gstaichi/CHANGELOG.md +9 -0
  2. gstaichi/__init__.py +40 -0
  3. gstaichi/__main__.py +5 -0
  4. gstaichi/_funcs.py +706 -0
  5. gstaichi/_kernels.py +420 -0
  6. gstaichi/_lib/__init__.py +3 -0
  7. gstaichi/_lib/core/__init__.py +0 -0
  8. gstaichi/_lib/core/gstaichi_python.cp311-win_amd64.pyd +0 -0
  9. gstaichi/_lib/core/gstaichi_python.pyi +2937 -0
  10. gstaichi/_lib/core/py.typed +0 -0
  11. gstaichi/_lib/runtime/runtime_cuda.bc +0 -0
  12. gstaichi/_lib/runtime/runtime_x64.bc +0 -0
  13. gstaichi/_lib/runtime/slim_libdevice.10.bc +0 -0
  14. gstaichi/_lib/utils.py +249 -0
  15. gstaichi/_logging.py +131 -0
  16. gstaichi/_main.py +545 -0
  17. gstaichi/_snode/__init__.py +5 -0
  18. gstaichi/_snode/fields_builder.py +187 -0
  19. gstaichi/_snode/snode_tree.py +34 -0
  20. gstaichi/_test_tools/__init__.py +0 -0
  21. gstaichi/_test_tools/load_kernel_string.py +30 -0
  22. gstaichi/_version.py +1 -0
  23. gstaichi/_version_check.py +103 -0
  24. gstaichi/ad/__init__.py +3 -0
  25. gstaichi/ad/_ad.py +530 -0
  26. gstaichi/algorithms/__init__.py +3 -0
  27. gstaichi/algorithms/_algorithms.py +117 -0
  28. gstaichi/assets/.git +1 -0
  29. gstaichi/assets/Go-Regular.ttf +0 -0
  30. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  31. gstaichi/examples/minimal.py +28 -0
  32. gstaichi/experimental.py +16 -0
  33. gstaichi/lang/__init__.py +50 -0
  34. gstaichi/lang/_ndarray.py +352 -0
  35. gstaichi/lang/_ndrange.py +152 -0
  36. gstaichi/lang/_template_mapper.py +199 -0
  37. gstaichi/lang/_texture.py +172 -0
  38. gstaichi/lang/_wrap_inspect.py +189 -0
  39. gstaichi/lang/any_array.py +99 -0
  40. gstaichi/lang/argpack.py +411 -0
  41. gstaichi/lang/ast/__init__.py +5 -0
  42. gstaichi/lang/ast/ast_transformer.py +1318 -0
  43. gstaichi/lang/ast/ast_transformer_utils.py +341 -0
  44. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  45. gstaichi/lang/ast/ast_transformers/call_transformer.py +267 -0
  46. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +320 -0
  47. gstaichi/lang/ast/checkers.py +106 -0
  48. gstaichi/lang/ast/symbol_resolver.py +57 -0
  49. gstaichi/lang/ast/transform.py +9 -0
  50. gstaichi/lang/common_ops.py +310 -0
  51. gstaichi/lang/exception.py +80 -0
  52. gstaichi/lang/expr.py +180 -0
  53. gstaichi/lang/field.py +466 -0
  54. gstaichi/lang/impl.py +1241 -0
  55. gstaichi/lang/kernel_arguments.py +157 -0
  56. gstaichi/lang/kernel_impl.py +1382 -0
  57. gstaichi/lang/matrix.py +1881 -0
  58. gstaichi/lang/matrix_ops.py +341 -0
  59. gstaichi/lang/matrix_ops_utils.py +190 -0
  60. gstaichi/lang/mesh.py +687 -0
  61. gstaichi/lang/misc.py +778 -0
  62. gstaichi/lang/ops.py +1494 -0
  63. gstaichi/lang/runtime_ops.py +13 -0
  64. gstaichi/lang/shell.py +35 -0
  65. gstaichi/lang/simt/__init__.py +5 -0
  66. gstaichi/lang/simt/block.py +94 -0
  67. gstaichi/lang/simt/grid.py +7 -0
  68. gstaichi/lang/simt/subgroup.py +191 -0
  69. gstaichi/lang/simt/warp.py +96 -0
  70. gstaichi/lang/snode.py +489 -0
  71. gstaichi/lang/source_builder.py +150 -0
  72. gstaichi/lang/struct.py +855 -0
  73. gstaichi/lang/util.py +381 -0
  74. gstaichi/linalg/__init__.py +8 -0
  75. gstaichi/linalg/matrixfree_cg.py +310 -0
  76. gstaichi/linalg/sparse_cg.py +59 -0
  77. gstaichi/linalg/sparse_matrix.py +303 -0
  78. gstaichi/linalg/sparse_solver.py +123 -0
  79. gstaichi/math/__init__.py +11 -0
  80. gstaichi/math/_complex.py +205 -0
  81. gstaichi/math/mathimpl.py +886 -0
  82. gstaichi/profiler/__init__.py +6 -0
  83. gstaichi/profiler/kernel_metrics.py +260 -0
  84. gstaichi/profiler/kernel_profiler.py +586 -0
  85. gstaichi/profiler/memory_profiler.py +15 -0
  86. gstaichi/profiler/scoped_profiler.py +36 -0
  87. gstaichi/sparse/__init__.py +3 -0
  88. gstaichi/sparse/_sparse_grid.py +77 -0
  89. gstaichi/tools/__init__.py +12 -0
  90. gstaichi/tools/diagnose.py +117 -0
  91. gstaichi/tools/np2ply.py +364 -0
  92. gstaichi/tools/vtk.py +38 -0
  93. gstaichi/types/__init__.py +19 -0
  94. gstaichi/types/annotations.py +47 -0
  95. gstaichi/types/compound_types.py +90 -0
  96. gstaichi/types/enums.py +49 -0
  97. gstaichi/types/ndarray_type.py +147 -0
  98. gstaichi/types/primitive_types.py +206 -0
  99. gstaichi/types/quant.py +88 -0
  100. gstaichi/types/texture_type.py +85 -0
  101. gstaichi/types/utils.py +13 -0
  102. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +5 -0
  103. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +29 -0
  104. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +113 -0
  105. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +5 -0
  106. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +19 -0
  107. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +122 -0
  108. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +5 -0
  109. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +19 -0
  110. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +122 -0
  111. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +5 -0
  112. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +19 -0
  113. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +122 -0
  114. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +5 -0
  115. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +19 -0
  116. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +122 -0
  117. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +5 -0
  118. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  119. gstaichi-0.1.25.dev0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +122 -0
  120. gstaichi-0.1.25.dev0.data/data/bin/SPIRV-Tools-shared.dll +0 -0
  121. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/instrument.hpp +268 -0
  122. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/libspirv.h +907 -0
  123. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/libspirv.hpp +375 -0
  124. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/linker.hpp +97 -0
  125. gstaichi-0.1.25.dev0.data/data/include/spirv-tools/optimizer.hpp +970 -0
  126. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-diff.lib +0 -0
  127. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-link.lib +0 -0
  128. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-lint.lib +0 -0
  129. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-opt.lib +0 -0
  130. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-reduce.lib +0 -0
  131. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools-shared.lib +0 -0
  132. gstaichi-0.1.25.dev0.data/data/lib/SPIRV-Tools.lib +0 -0
  133. gstaichi-0.1.25.dev0.dist-info/METADATA +105 -0
  134. gstaichi-0.1.25.dev0.dist-info/RECORD +138 -0
  135. gstaichi-0.1.25.dev0.dist-info/WHEEL +5 -0
  136. gstaichi-0.1.25.dev0.dist-info/entry_points.txt +2 -0
  137. gstaichi-0.1.25.dev0.dist-info/licenses/LICENSE +201 -0
  138. gstaichi-0.1.25.dev0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,157 @@
1
+ # type: ignore
2
+
3
+ import inspect
4
+
5
+ import gstaichi.lang
6
+ from gstaichi._lib import core as _ti_core
7
+ from gstaichi.lang import impl, ops
8
+ from gstaichi.lang._texture import RWTextureAccessor, TextureSampler
9
+ from gstaichi.lang.any_array import AnyArray
10
+ from gstaichi.lang.expr import Expr
11
+ from gstaichi.lang.matrix import MatrixType
12
+ from gstaichi.lang.struct import StructType
13
+ from gstaichi.lang.util import cook_dtype
14
+ from gstaichi.types.compound_types import CompoundType
15
+ from gstaichi.types.primitive_types import RefType, u64
16
+
17
+
18
class KernelArgument:
    """Lightweight record describing one parameter of a kernel/function signature.

    Args:
        _annotation: The parameter's annotation, stored unmodified as ``annotation``.
        _name: The parameter's name, stored as ``name``.
        _default: The parameter's default value, stored as ``default``. The
            sentinel ``inspect.Parameter.empty`` (the default) means "no default".
    """

    def __init__(self, _annotation, _name, _default=inspect.Parameter.empty):
        # Assigned left-to-right: annotation, name, default. No validation is done.
        self.annotation, self.name, self.default = _annotation, _name, _default
23
+
24
+
25
class SparseMatrixEntry:
    """A single ``(i, j)`` entry of a sparse-matrix kernel argument.

    Only in-place ``+=`` and ``-=`` are supported. Each one is lowered to a call
    of the internal function ``insert_triplet_<dtype>`` with the matrix pointer,
    the indices, and the (possibly negated) operand.

    Args:
        ptr: Pointer expression identifying the sparse matrix.
        i: Row index.
        j: Column index.
        dtype: Value type of the matrix. It is used both to cast the operand and
            to select the ``insert_triplet_*`` variant by name.
    """

    def __init__(self, ptr, i, j, dtype):
        self.ptr = ptr
        self.i = i
        self.j = j
        self.dtype = dtype

    def _augassign(self, value, op):
        """Lower ``entry += value`` (``op == "Add"``) or ``entry -= value`` (``op == "Sub"``).

        Args:
            value: The right-hand side operand; cast to ``self.dtype``.
            op: Name of the augmented-assignment operator.

        Raises:
            AssertionError: If ``op`` is neither ``"Add"`` nor ``"Sub"``. It is
                raised explicitly rather than via ``assert``, so the check is not
                stripped when Python runs with ``-O``.
        """
        if op not in ("Add", "Sub"):
            raise AssertionError("Only operations '+=' and '-=' are supported on sparse matrices.")
        # Validate ``op`` first so ``ops.cast`` is never invoked for an unsupported operator.
        operand = ops.cast(value, self.dtype)
        if op == "Sub":
            # Subtraction is expressed as inserting the negated value.
            operand = -operand
        impl.call_internal(f"insert_triplet_{self.dtype}", self.ptr, self.i, self.j, operand)
40
+
41
+
42
class SparseMatrixProxy:
    """Kernel-side handle for a sparse matrix argument.

    Args:
        ptr: Pointer expression identifying the sparse matrix.
        dtype: Value type of the matrix entries.
    """

    def __init__(self, ptr, dtype):
        self.ptr, self.dtype = ptr, dtype

    def subscript(self, i, j):
        """Return a ``SparseMatrixEntry`` for position ``(i, j)`` of this matrix."""
        return SparseMatrixEntry(ptr=self.ptr, i=i, j=j, dtype=self.dtype)
49
+
50
+
51
def decl_scalar_arg(dtype, name, arg_depth):
    """Declare a scalar parameter on the compiling callable and return an Expr loading it.

    If ``dtype`` is a ``RefType``, its wrapped type (``dtype.tp``) is registered
    as a pointer parameter. Otherwise ``dtype`` is registered as a by-value scalar.

    Args:
        dtype: Parameter type, optionally wrapped in ``RefType``.
        name: Parameter name.
        arg_depth: Forwarded to ``make_arg_load_expr`` as ``arg_depth``
            (presumably the argpack nesting depth; confirm against callers).

    Returns:
        Expr: The argument-load expression (created with ``create_load=True``).
    """
    is_ref = isinstance(dtype, RefType)
    value_type = cook_dtype(dtype.tp if is_ref else dtype)

    compiling = impl.get_runtime().compiling_callable
    register = compiling.insert_pointer_param if is_ref else compiling.insert_scalar_param
    arg_id = register(value_type, name)

    debug_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
    load_expr = _ti_core.make_arg_load_expr(
        arg_id, value_type, is_ref, create_load=True, arg_depth=arg_depth, dbg_info=debug_info
    )
    return Expr(load_expr)
66
+
67
+
68
def get_type_for_kernel_args(dtype, name):
    """Translate a kernel-argument annotation into the type used to pass it.

    * ``MatrixType``: flattened into a struct of scalar fields named
      ``{name}_{i}`` (1-D) or ``{name}_{i}_{j}`` (otherwise).
    * ``StructType``: rebuilt as a struct type. Members that are
      ``CompoundType`` are converted recursively, using the member name as the
      field prefix. Other members are kept as-is.
    * Anything else is returned unchanged (assumed to be a primitive type).

    Args:
        dtype: The argument's type annotation.
        name: Name used as the prefix for generated field names.
    """
    if isinstance(dtype, MatrixType):
        # Compiling the matrix type to a struct type because the support for the matrix type is not ready yet on
        # SPIR-V based backends.
        if dtype.ndim == 1:
            fields = [(dtype.dtype, f"{name}_{row}") for row in range(dtype.n)]
        else:
            fields = [(dtype.dtype, f"{name}_{row}_{col}") for row in range(dtype.n) for col in range(dtype.m)]
        return _ti_core.get_type_factory_instance().get_struct_type(fields)

    if not isinstance(dtype, StructType):
        # Assuming dtype is a primitive type
        return dtype

    fields = [
        [
            get_type_for_kernel_args(member_type, member_name)
            if isinstance(member_type, CompoundType)
            else member_type,
            member_name,
        ]
        for member_name, member_type in dtype.members.items()
    ]
    return _ti_core.get_type_factory_instance().get_struct_type(fields)
87
+
88
+
89
def decl_matrix_arg(matrixtype, name, arg_depth):
    """Declare a matrix parameter and return a matrix value built from its argument load.

    The matrix is passed as the flattened struct produced by
    ``get_type_for_kernel_args``. The load expression is created with
    ``create_load=False`` and handed to ``matrixtype.from_gstaichi_object``.

    Args:
        matrixtype: The ``MatrixType`` annotation.
        name: Parameter name (also the prefix for the flattened field names).
        arg_depth: Forwarded to ``make_arg_load_expr``.
    """
    flattened_type = get_type_for_kernel_args(matrixtype, name)
    runtime = impl.get_runtime()
    param_id = runtime.compiling_callable.insert_scalar_param(flattened_type, name)
    debug_info = _ti_core.DebugInfo(runtime.get_current_src_info())
    raw_load = _ti_core.make_arg_load_expr(
        param_id, flattened_type, create_load=False, arg_depth=arg_depth, dbg_info=debug_info
    )
    return matrixtype.from_gstaichi_object(Expr(raw_load))
97
+
98
+
99
def decl_struct_arg(structtype, name, arg_depth):
    """Declare a struct parameter and return a struct value built from its argument load.

    The struct layout comes from ``get_type_for_kernel_args``. The load
    expression is created with ``create_load=False`` and passed to
    ``structtype.from_gstaichi_object``.

    Args:
        structtype: The ``StructType`` annotation.
        name: Parameter name.
        arg_depth: Forwarded to ``make_arg_load_expr``.
    """
    param_type = get_type_for_kernel_args(structtype, name)
    compiling = impl.get_runtime().compiling_callable
    param_id = compiling.insert_scalar_param(param_type, name)
    src_info = impl.get_runtime().get_current_src_info()
    struct_load = Expr(
        _ti_core.make_arg_load_expr(
            param_id,
            param_type,
            create_load=False,
            arg_depth=arg_depth,
            dbg_info=_ti_core.DebugInfo(src_info),
        )
    )
    return structtype.from_gstaichi_object(struct_load)
107
+
108
+
109
def push_argpack_arg(name):
    """Register an argpack parameter named ``name`` and push it on the compiling callable's argpack stack."""
    compiling = impl.get_runtime().compiling_callable
    compiling.insert_argpack_param_and_push(name)
111
+
112
+
113
def decl_argpack_arg(argpacktype, member_dict):
    """Close the current argpack and return its value.

    Pops the compiling callable's argpack stack (presumably balancing a prior
    ``push_argpack_arg``; confirm with callers). It then builds the argpack via
    ``argpacktype.from_gstaichi_object(member_dict)``.
    """
    compiling = impl.get_runtime().compiling_callable
    compiling.pop_argpack_stack()
    argpack_value = argpacktype.from_gstaichi_object(member_dict)
    return argpack_value
116
+
117
+
118
def decl_sparse_matrix(dtype, name):
    """Declare a sparse-matrix parameter and return a ``SparseMatrixProxy`` for it.

    Args:
        dtype: Value type of the matrix entries (passed through ``cook_dtype``).
        name: Parameter name.

    Returns:
        SparseMatrixProxy: Wraps the raw ``u64`` argument-load expression
        (not wrapped in ``Expr``) together with the cooked value type.
    """
    entry_type = cook_dtype(dtype)
    pointer_type = cook_dtype(u64)
    # Treat the sparse matrix argument as a scalar since we only need to pass in the base pointer
    param_id = impl.get_runtime().compiling_callable.insert_scalar_param(pointer_type, name)
    debug_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
    pointer_load = _ti_core.make_arg_load_expr(param_id, pointer_type, is_ptr=False, dbg_info=debug_info)
    return SparseMatrixProxy(pointer_load, entry_type)
127
+
128
+
129
def decl_ndarray_arg(element_type, ndim, name, needs_grad, boundary):
    """Declare an ndarray parameter and return an ``AnyArray`` view of it.

    Args:
        element_type: Element type of the ndarray.
        ndim: Number of array dimensions.
        name: Parameter name.
        needs_grad: Forwarded to both the parameter registration and the tensor expression.
        boundary: Forwarded to ``make_external_tensor_expr``.
    """
    compiling = impl.get_runtime().compiling_callable
    param_id = compiling.insert_ndarray_param(element_type, ndim, name, needs_grad)
    # NOTE(review): the literal 0 is passed positionally before ``boundary``; presumably an
    # argpack depth, but confirm against the make_external_tensor_expr binding.
    tensor_expr = _ti_core.make_external_tensor_expr(element_type, ndim, param_id, needs_grad, 0, boundary)
    return AnyArray(tensor_expr)
132
+
133
+
134
def decl_texture_arg(num_dimensions, name):
    """Declare a sampled-texture parameter and return a ``TextureSampler`` for it.

    Args:
        num_dimensions: Dimensionality of the texture.
        name: Parameter name.
    """
    # FIXME: texture_arg doesn't have element_shape so better separate them
    param_id = impl.get_runtime().compiling_callable.insert_texture_param(num_dimensions, name)
    src_info = impl.get_runtime().get_current_src_info()
    # NOTE(review): the literal 0 argument's meaning is not visible here; confirm against the binding.
    texture_ptr = _ti_core.make_texture_ptr_expr(param_id, num_dimensions, 0, _ti_core.DebugInfo(src_info))
    return TextureSampler(texture_ptr, num_dimensions)
139
+
140
+
141
def decl_rw_texture_arg(num_dimensions, buffer_format, lod, name):
    """Declare a read/write texture parameter and return an ``RWTextureAccessor`` for it.

    Args:
        num_dimensions: Dimensionality of the texture.
        buffer_format: Texel format; used for both registration and the pointer expression.
        lod: Forwarded to ``make_rw_texture_ptr_expr``.
        name: Parameter name.
    """
    # FIXME: texture_arg doesn't have element_shape so better separate them
    compiling = impl.get_runtime().compiling_callable
    param_id = compiling.insert_rw_texture_param(num_dimensions, buffer_format, name)
    debug_info = _ti_core.DebugInfo(impl.get_runtime().get_current_src_info())
    # NOTE(review): the literal 0 argument's meaning is not visible here; confirm against the binding.
    texture_ptr = _ti_core.make_rw_texture_ptr_expr(param_id, num_dimensions, 0, buffer_format, lod, debug_info)
    return RWTextureAccessor(texture_ptr, num_dimensions)
148
+
149
+
150
def decl_ret(dtype):
    """Register a return value of type ``dtype`` on the compiling callable.

    * A ``StructType`` is first replaced by its ``.dtype``.
    * A ``MatrixType`` is registered as a tensor type of shape
      ``[dtype.n, dtype.m]`` with element type ``dtype.dtype``. Vector types
      therefore also use ``dtype.m`` here; presumably 1, so confirm.
    * Anything else goes through ``cook_dtype``.
    """
    unwrapped = dtype.dtype if isinstance(dtype, StructType) else dtype
    if isinstance(unwrapped, MatrixType):
        factory = _ti_core.get_type_factory_instance()
        ret_type = factory.get_tensor_type([unwrapped.n, unwrapped.m], unwrapped.dtype)
    else:
        ret_type = cook_dtype(unwrapped)
    impl.get_runtime().compiling_callable.insert_ret(ret_type)