gstaichi 2.1.1rc3__cp313-cp313-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (179) hide show
  1. gstaichi/CHANGELOG.md +4 -0
  2. gstaichi/__init__.py +40 -0
  3. gstaichi/_funcs.py +706 -0
  4. gstaichi/_kernels.py +420 -0
  5. gstaichi/_lib/__init__.py +3 -0
  6. gstaichi/_lib/core/__init__.py +0 -0
  7. gstaichi/_lib/core/gstaichi_python.cpython-313-darwin.so +0 -0
  8. gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
  9. gstaichi/_lib/core/py.typed +0 -0
  10. gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
  11. gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
  12. gstaichi/_lib/utils.py +243 -0
  13. gstaichi/_logging.py +131 -0
  14. gstaichi/_snode/__init__.py +5 -0
  15. gstaichi/_snode/fields_builder.py +187 -0
  16. gstaichi/_snode/snode_tree.py +34 -0
  17. gstaichi/_test_tools/__init__.py +18 -0
  18. gstaichi/_test_tools/dataclass_test_tools.py +36 -0
  19. gstaichi/_test_tools/load_kernel_string.py +30 -0
  20. gstaichi/_test_tools/textwrap2.py +6 -0
  21. gstaichi/_version.py +1 -0
  22. gstaichi/_version_check.py +100 -0
  23. gstaichi/ad/__init__.py +3 -0
  24. gstaichi/ad/_ad.py +530 -0
  25. gstaichi/algorithms/__init__.py +3 -0
  26. gstaichi/algorithms/_algorithms.py +117 -0
  27. gstaichi/assets/.git +1 -0
  28. gstaichi/assets/Go-Regular.ttf +0 -0
  29. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  30. gstaichi/examples/lcg_python.py +26 -0
  31. gstaichi/examples/lcg_taichi.py +34 -0
  32. gstaichi/examples/minimal.py +28 -0
  33. gstaichi/experimental.py +16 -0
  34. gstaichi/lang/__init__.py +50 -0
  35. gstaichi/lang/_dataclass_util.py +31 -0
  36. gstaichi/lang/_fast_caching/__init__.py +3 -0
  37. gstaichi/lang/_fast_caching/args_hasher.py +110 -0
  38. gstaichi/lang/_fast_caching/config_hasher.py +30 -0
  39. gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
  40. gstaichi/lang/_fast_caching/function_hasher.py +57 -0
  41. gstaichi/lang/_fast_caching/hash_utils.py +11 -0
  42. gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
  43. gstaichi/lang/_fast_caching/src_hasher.py +75 -0
  44. gstaichi/lang/_kernel_impl_dataclass.py +212 -0
  45. gstaichi/lang/_ndarray.py +352 -0
  46. gstaichi/lang/_ndrange.py +152 -0
  47. gstaichi/lang/_template_mapper.py +195 -0
  48. gstaichi/lang/_texture.py +172 -0
  49. gstaichi/lang/_wrap_inspect.py +215 -0
  50. gstaichi/lang/any_array.py +99 -0
  51. gstaichi/lang/ast/__init__.py +5 -0
  52. gstaichi/lang/ast/ast_transformer.py +1323 -0
  53. gstaichi/lang/ast/ast_transformer_utils.py +346 -0
  54. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  55. gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
  56. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
  57. gstaichi/lang/ast/checkers.py +106 -0
  58. gstaichi/lang/ast/symbol_resolver.py +57 -0
  59. gstaichi/lang/ast/transform.py +9 -0
  60. gstaichi/lang/common_ops.py +310 -0
  61. gstaichi/lang/exception.py +80 -0
  62. gstaichi/lang/expr.py +180 -0
  63. gstaichi/lang/field.py +428 -0
  64. gstaichi/lang/impl.py +1243 -0
  65. gstaichi/lang/kernel_arguments.py +155 -0
  66. gstaichi/lang/kernel_impl.py +1341 -0
  67. gstaichi/lang/matrix.py +1835 -0
  68. gstaichi/lang/matrix_ops.py +341 -0
  69. gstaichi/lang/matrix_ops_utils.py +190 -0
  70. gstaichi/lang/mesh.py +687 -0
  71. gstaichi/lang/misc.py +782 -0
  72. gstaichi/lang/ops.py +1494 -0
  73. gstaichi/lang/runtime_ops.py +13 -0
  74. gstaichi/lang/shell.py +35 -0
  75. gstaichi/lang/simt/__init__.py +5 -0
  76. gstaichi/lang/simt/block.py +94 -0
  77. gstaichi/lang/simt/grid.py +7 -0
  78. gstaichi/lang/simt/subgroup.py +191 -0
  79. gstaichi/lang/simt/warp.py +96 -0
  80. gstaichi/lang/snode.py +489 -0
  81. gstaichi/lang/source_builder.py +150 -0
  82. gstaichi/lang/struct.py +810 -0
  83. gstaichi/lang/util.py +312 -0
  84. gstaichi/linalg/__init__.py +8 -0
  85. gstaichi/linalg/matrixfree_cg.py +310 -0
  86. gstaichi/linalg/sparse_cg.py +59 -0
  87. gstaichi/linalg/sparse_matrix.py +303 -0
  88. gstaichi/linalg/sparse_solver.py +123 -0
  89. gstaichi/math/__init__.py +11 -0
  90. gstaichi/math/_complex.py +205 -0
  91. gstaichi/math/mathimpl.py +886 -0
  92. gstaichi/profiler/__init__.py +6 -0
  93. gstaichi/profiler/kernel_metrics.py +260 -0
  94. gstaichi/profiler/kernel_profiler.py +586 -0
  95. gstaichi/profiler/memory_profiler.py +15 -0
  96. gstaichi/profiler/scoped_profiler.py +36 -0
  97. gstaichi/sparse/__init__.py +3 -0
  98. gstaichi/sparse/_sparse_grid.py +77 -0
  99. gstaichi/tools/__init__.py +12 -0
  100. gstaichi/tools/diagnose.py +117 -0
  101. gstaichi/tools/np2ply.py +364 -0
  102. gstaichi/tools/vtk.py +38 -0
  103. gstaichi/types/__init__.py +19 -0
  104. gstaichi/types/annotations.py +52 -0
  105. gstaichi/types/compound_types.py +71 -0
  106. gstaichi/types/enums.py +49 -0
  107. gstaichi/types/ndarray_type.py +169 -0
  108. gstaichi/types/primitive_types.py +206 -0
  109. gstaichi/types/quant.py +88 -0
  110. gstaichi/types/texture_type.py +85 -0
  111. gstaichi/types/utils.py +11 -0
  112. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3.h +6389 -0
  113. gstaichi-2.1.1rc3.data/data/include/GLFW/glfw3native.h +594 -0
  114. gstaichi-2.1.1rc3.data/data/include/spirv-tools/instrument.hpp +268 -0
  115. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.h +907 -0
  116. gstaichi-2.1.1rc3.data/data/include/spirv-tools/libspirv.hpp +375 -0
  117. gstaichi-2.1.1rc3.data/data/include/spirv-tools/linker.hpp +97 -0
  118. gstaichi-2.1.1rc3.data/data/include/spirv-tools/optimizer.hpp +970 -0
  119. gstaichi-2.1.1rc3.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
  120. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.h +2568 -0
  121. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv.hpp +2579 -0
  122. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
  123. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
  124. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
  125. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
  126. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
  127. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
  128. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
  129. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
  130. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
  131. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
  132. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
  133. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
  134. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
  135. gstaichi-2.1.1rc3.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
  136. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
  137. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
  138. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
  139. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
  140. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
  141. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
  142. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
  143. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
  144. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
  145. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
  146. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
  147. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
  148. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
  149. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
  150. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
  151. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
  152. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  153. gstaichi-2.1.1rc3.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
  154. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  155. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  156. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  157. gstaichi-2.1.1rc3.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  158. gstaichi-2.1.1rc3.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  159. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
  160. gstaichi-2.1.1rc3.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
  161. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
  162. gstaichi-2.1.1rc3.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
  163. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
  164. gstaichi-2.1.1rc3.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
  165. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
  166. gstaichi-2.1.1rc3.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
  167. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
  168. gstaichi-2.1.1rc3.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
  169. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
  170. gstaichi-2.1.1rc3.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
  171. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
  172. gstaichi-2.1.1rc3.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
  173. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
  174. gstaichi-2.1.1rc3.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
  175. gstaichi-2.1.1rc3.dist-info/METADATA +106 -0
  176. gstaichi-2.1.1rc3.dist-info/RECORD +179 -0
  177. gstaichi-2.1.1rc3.dist-info/WHEEL +5 -0
  178. gstaichi-2.1.1rc3.dist-info/licenses/LICENSE +201 -0
  179. gstaichi-2.1.1rc3.dist-info/top_level.txt +1 -0
@@ -0,0 +1,77 @@
1
+ # type: ignore
2
+
3
+ from gstaichi.lang.impl import grouped, root, static
4
+ from gstaichi.lang.kernel_impl import kernel
5
+ from gstaichi.lang.misc import ij, ijk
6
+ from gstaichi.lang.snode import is_active
7
+ from gstaichi.lang.struct import Struct
8
+ from gstaichi.types.annotations import template
9
+ from gstaichi.types.primitive_types import f32
10
+
11
+
12
def grid(field_dict, shape):
    """Create a 2D/3D sparse grid whose elements are structs placed on a bitmasked SNode.

    Args:
        field_dict (dict): a dict, each item is like `name: type`.
        shape (Tuple[int]): shape of the field; must have exactly 2 or 3 entries.
    Returns:
        x: the created sparse grid, which is a bitmasked `ti.Struct.field`.
    Raises:
        ValueError: if `shape` is neither 2D nor 3D (a subclass of the previously
            raised bare `Exception`, so existing handlers still catch it).

    Examples::
        # create a 2D sparse grid
        >>> grid = ti.sparse.grid({'pos': ti.math.vec2, 'mass': ti.f32, 'grid2particles': ti.types.vector(20, ti.i32)}, shape=(10, 10))

        # access
        >>> grid[0, 0].pos = ti.math.vec2(1.0, 2.0)
        >>> grid[0, 0].mass = 1.0
        >>> grid[0, 0].grid2particles[2] = 123

        # print the usage of the sparse grid, which is in [0,1]
        >>> print(ti.sparse.usage(grid))
        # 0.009999999776482582
    """
    ndim = len(shape)
    # Validate before creating the field, so a bad shape does not leave a
    # created-but-never-placed field behind.
    if ndim not in (2, 3):
        raise ValueError("Only 2D and 3D sparse grids are supported")
    x = Struct.field(field_dict)
    axes = ij if ndim == 2 else ijk
    root.bitmasked(axes, shape).place(x)
    return x
44
+
45
+
46
@kernel
def usage(x: template()) -> f32:
    """Return the fraction of cells in a sparse grid that are currently active.

    Every cell of ``x.parent()`` is visited, and the ones reported by ``is_active``
    are counted. The count is divided by the total cell count, which is the product
    of ``x.shape``. Only 2D and 3D grids are accepted.

    Args:
        x(struct field): the sparse grid to be checked.
    Returns:
        usage(f32): the usage of the sparse grid, which is in [0,1]

    Examples::
        >>> grid = ti.sparse.grid({'pos': ti.math.vec2, 'mass': ti.f32, 'grid2particles': ti.types.vector(20, ti.i32)}, shape=(10, 10))
        >>> grid[0, 0].mass = 1.0
        >>> print(ti.sparse.usage(grid))
        # 0.009999999776482582
    """
    active_count = 0
    for idx in grouped(x.parent()):
        if is_active(x.parent(), idx):
            active_count += 1
    total = 1.0
    if static(len(x.shape) == 2):
        total = x.shape[0] * x.shape[1]
    elif static(len(x.shape) == 3):
        total = x.shape[0] * x.shape[1] * x.shape[2]
    else:
        raise ValueError("The dimension of the sparse grid should be 2 or 3")
    return active_count / total


__all__ = ["grid", "usage"]
@@ -0,0 +1,12 @@
1
+ # type: ignore
2
+
3
+ """GsTaichi utility module.
4
+
5
+ - `np2ply` submodule for exporting numpy arrays to PLY files.
6
+ - `vtk` submodule for exporting scalar fields to VTK files.
7
+ - `diagnose` submodule for printing system environment information.
8
+ """
9
+
10
+ from gstaichi.tools.diagnose import *
11
+ from gstaichi.tools.np2ply import *
12
+ from gstaichi.tools.vtk import *
@@ -0,0 +1,117 @@
1
+ # type: ignore
2
+
3
+ import locale
4
+ import os
5
+ import platform
6
+ import subprocess
7
+ import sys
8
+
9
+
10
# OpenGL extensions whose availability is echoed from the `glewinfo` output.
_GL_EXTENSIONS = (
    "GL_ARB_compute_shader",
    "GL_ARB_gpu_shader_int64",
    "GL_NV_shader_atomic_float",
    "GL_NV_shader_atomic_float64",
    "GL_NV_shader_atomic_int64",
)


def _print_command_output(cmd, failure_message):
    """Run ``cmd`` and print its decoded stdout; on any failure print ``"<failure_message>: <error>"``.

    Every probe in the report is best-effort. A missing executable or a non-zero exit
    status is reported rather than raised, so the rest of the report still runs.
    """
    try:
        output = subprocess.check_output(cmd)
    except Exception as e:  # best-effort probe: report and continue
        print(f"{failure_message}: {e}")
    else:
        print(output.decode())


def _describe_locale():
    """Return the default locale as ``"<language>.<encoding>"`` without crashing.

    ``locale.getdefaultlocale()`` may return ``None`` for either part, in which case
    a plain ``".".join`` raises ``TypeError``. It can also raise ``ValueError`` for
    locale names it does not recognise. It has been deprecated since Python 3.11,
    so its ``DeprecationWarning`` is silenced here.
    """
    import warnings

    try:
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", DeprecationWarning)
            parts = locale.getdefaultlocale()
    except ValueError as e:
        return f"unknown ({e})"
    return ".".join(part for part in parts if part) or "unknown"


def _try_print(executable, tag, expr):
    """Evaluate ``expr`` in a child interpreter after ``import gstaichi as ti`` and print ``tag: result``.

    The child prints a ``====`` marker before the result. Everything printed before
    the marker, such as output produced by the import itself, is discarded.
    """
    try:
        cmd = f'import gstaichi as ti; print("===="); print({expr}, end="")'
        ret = subprocess.check_output([executable, "-c", cmd]).decode()
        ret = ret.split("====" + os.linesep, maxsplit=1)[1]
        print(f"{tag}: {ret}")
    except Exception as e:
        print(f"{tag}: ERROR {e}")


def main():
    """Print a system and environment report to help diagnose GsTaichi installation problems.

    The report covers the interpreter, the platform, the locale, PATH, ``TI_*``
    environment variables, backend availability, and the output of several external
    tools and child interpreters. Each probe is best-effort: a failure is printed and
    the report continues.
    """
    print("GsTaichi system diagnose:")
    print("")
    executable = sys.executable

    print(f"python: {sys.version}")
    print(f"system: {sys.platform}")
    print(f"executable: {executable}")
    print(f"platform: {platform.platform()}")
    print(f'architecture: {" ".join(platform.architecture())}')
    print(f"uname: {platform.uname()}")

    print(f"locale: {_describe_locale()}")
    print(f'PATH: {os.environ.get("PATH")}')
    print(f"PYTHONPATH: {sys.path}")
    print("")

    _print_command_output(["lsb_release", "-a"], "`lsb_release` not available")

    for k, v in os.environ.items():
        if k.startswith("TI_"):
            print(f"{k}={v}")
    print("")

    print("")
    _try_print(executable, "import", "ti")
    print("")
    for arch in ["cpu", "metal", "cuda", "vulkan"]:
        _try_print(executable, arch, f"ti.lang.misc.is_arch_supported(ti.{arch})")
    print("")

    try:
        glewinfo = subprocess.check_output(["glewinfo"])
    except Exception as e:
        print(f"`glewinfo` not available: {e}")
    else:
        for line in glewinfo.decode().splitlines():
            if line.startswith("OpenGL version") or line.split(":")[0] in _GL_EXTENSIONS:
                print(line)

    print("")
    _print_command_output(["nvidia-smi"], "`nvidia-smi` not available")
    _print_command_output([executable, "-c", "import gstaichi"], "`import gstaichi` failed")
    _print_command_output([executable, "-c", "import gstaichi as ti; ti.init()"], "`ti.init()` failed")
    _print_command_output(
        [executable, "-c", "import gstaichi as ti; ti.init(arch=ti.cuda)"],
        "GsTaichi CUDA test failed",
    )
    # The failure message names the command that is actually run.
    _print_command_output(
        [executable, "-m", "gstaichi", "example", "minimal"],
        "`python -m gstaichi example minimal` failed",
    )

    print("Consider attaching this log when maintainers ask about system information.")


if __name__ == "__main__":
    main()

__all__ = []
@@ -0,0 +1,364 @@
1
+ # type: ignore
2
+
3
+ # convert numpy array to ply files
4
+ import sys
5
+
6
+ import numpy as np
7
+
8
+
9
class PLYWriter:
    """Writes `numpy.array` data to `ply` files.

    Vertex and face properties are added as named channels. The data is then written
    either as ASCII (`export_ascii`) or as binary PLY in the host's byte order
    (`export`).

    Args:
        num_vertices (int): number of vertices.
        num_faces (int, optional): number of faces.
        face_type (str): `tri` or `quad`.
        comment (str): comment message.
    """

    def __init__(
        self,
        num_vertices: int,
        num_faces=0,
        face_type="tri",
        comment="created by PLYWriter",
    ):
        assert num_vertices > 0, "num_vertices should be greater than 0"
        assert num_faces >= 0, "num_faces shouldn't be less than 0"
        assert face_type == "tri" or face_type == "quad", "Only tri and quad faces are supported for now"

        self.ply_supported_types = [
            "char",
            "uchar",
            "short",
            "ushort",
            "int",
            "uint",
            "float",
            "double",
        ]
        self.corresponding_numpy_types = [
            np.int8,
            np.uint8,
            np.int16,
            np.uint16,
            np.int32,
            np.uint32,
            np.float32,
            np.float64,
        ]
        # PLY type name -> numpy scalar type used to cast channel data.
        self.type_map = dict(zip(self.ply_supported_types, self.corresponding_numpy_types))

        self.num_vertices = num_vertices
        self.num_vertex_channels = 0
        self.vertex_channels = []
        self.vertex_data_type = []
        self.vertex_data = []
        self.num_faces = num_faces
        self.num_face_channels = 0
        self.face_channels = []
        self.face_data_type = []
        self.face_data = []
        self.face_type = face_type
        # -1 marks "not set yet"; sanity_check rejects it until add_faces is called.
        self.face_indices = -np.ones((self.num_faces, self._verts_per_face()), dtype=np.int32)
        self.comment = comment

    def _verts_per_face(self):
        """Number of vertex indices per face: 3 for `tri`, otherwise 4."""
        return 3 if self.face_type == "tri" else 4

    def _append_channels(self, element, count, names, types, columns, key, data_type, data):
        """Append `data` as channel(s) of `element` ("vertex" or "face") and return how many were added.

        1-D `data` becomes a single channel named `key`. 2-D `data` is viewed as
        `(count, num_col)` and becomes one channel per column, named `key_1`,
        `key_2`, and so on. Unsupported `data_type`s are skipped with a message, and
        0 is returned.
        """
        if data_type not in self.ply_supported_types:
            print("Unknown type " + data_type + " detected, skipping this channel")
            return 0
        cast = self.type_map[data_type]
        message = f"The dimension of the {element} channel is not correct"
        if data.ndim == 1:
            assert data.size == count, message
            items = [(key, data)]
        else:
            num_col = data.size // count
            assert data.ndim == 2 and data.size == num_col * count, message
            # np.reshape returns a new view/copy; assigning `data.shape` would
            # silently reshape the caller's array (and fails for non-contiguous input).
            data = np.reshape(data, (count, num_col))
            items = [(key + "_" + str(i + 1), data[:, i]) for i in range(num_col)]
        for name, column in items:
            if name in names:
                print("WARNING: duplicate key " + name + " detected")
            names.append(name)
            types.append(data_type)
            columns.append(cast(column))
        return len(items)

    def add_vertex_channel(self, key: str, data_type: str, data: np.array):
        """Add a per-vertex property.

        Args:
            key (str): property name (2-D data gets `_1`, `_2`, ... suffixes per column).
            data_type (str): one of the PLY type names in `ply_supported_types`.
            data (`numpy.array`): 1-D with `num_vertices` entries, or 2-D with
                `num_vertices * k` entries (one channel per column). Not modified.
        """
        self.num_vertex_channels += self._append_channels(
            "vertex",
            self.num_vertices,
            self.vertex_channels,
            self.vertex_data_type,
            self.vertex_data,
            key,
            data_type,
            data,
        )

    def add_vertex_pos(self, x: np.array, y: np.array, z: np.array):
        """Set the (x, y, z) coordinates of the vertices.

        Args:
            x (`numpy.array(float)`): x-coordinates of the vertices.
            y (`numpy.array(float)`): y-coordinates of the vertices.
            z (`numpy.array(float)`): z-coordinates of the vertices.
        """
        self.add_vertex_channel("x", "float", x)
        self.add_vertex_channel("y", "float", y)
        self.add_vertex_channel("z", "float", z)

    # TODO: consider accepting a single (n, 2|3) array / ti vector field for
    # positions, normals and colors if user feedback asks for a compact API.

    def add_vertex_normal(self, nx: np.array, ny: np.array, nz: np.array):
        """Add normal vectors at the vertices.

        The three arguments are all numpy arrays of float type and have
        the same length.

        Args:
            nx (`numpy.array(float)`): x-coordinates of the normal vectors.
            ny (`numpy.array(float)`): y-coordinates of the normal vectors.
            nz (`numpy.array(float)`): z-coordinates of the normal vectors.
        """
        self.add_vertex_channel("nx", "float", nx)
        self.add_vertex_channel("ny", "float", ny)
        self.add_vertex_channel("nz", "float", nz)

    def add_vertex_vel(self, vx: np.array, vy: np.array, vz: np.array):
        """Add velocity vectors at the vertices.

        Args:
            vx (`numpy.array(float)`): x-coordinates of the velocity vectors.
            vy (`numpy.array(float)`): y-coordinates of the velocity vectors.
            vz (`numpy.array(float)`): z-coordinates of the velocity vectors.
        """
        self.add_vertex_channel("vx", "float", vx)
        self.add_vertex_channel("vy", "float", vy)
        self.add_vertex_channel("vz", "float", vz)

    def add_vertex_color(self, r: np.array, g: np.array, b: np.array):
        """Sets the (r, g, b) channels of the colors at the vertices.

        The three arguments are all numpy arrays of float type and have
        the same length.

        Args:
            r (`numpy.array(float)`): the r-channel (red) of the colors.
            g (`numpy.array(float)`): the g-channel (green) of the colors.
            b (`numpy.array(float)`): the b-channel (blue) of the colors.
        """
        self.add_vertex_channel("red", "float", r)
        self.add_vertex_channel("green", "float", g)
        self.add_vertex_channel("blue", "float", b)

    def add_vertex_alpha(self, alpha: np.array):
        """Sets the alpha-channel (transparency) of the vertex colors.

        Args:
            alpha (`numpy.array(float)`): the alpha-channel (transparency) of the colors.
        """
        self.add_vertex_channel("Alpha", "float", alpha)

    def add_vertex_rgba(self, r: np.array, g: np.array, b: np.array, a: np.array):
        """Sets the (r, g, b, a) channels of the colors at the vertices.

        Args:
            r (`numpy.array(float)`): the r-channel (red) of the colors.
            g (`numpy.array(float)`): the g-channel (green) of the colors.
            b (`numpy.array(float)`): the b-channel (blue) of the colors.
            a (`numpy.array(float)`): the a-channel (alpha) of the colors.
        """
        self.add_vertex_channel("red", "float", r)
        self.add_vertex_channel("green", "float", g)
        self.add_vertex_channel("blue", "float", b)
        self.add_vertex_channel("Alpha", "float", a)

    def add_vertex_id(self):
        """Sets the ids of the vertices.

        The id of a vertex is equal to its index in the vertex array.
        """
        self.add_vertex_channel("id", "int", np.arange(self.num_vertices))

    def add_vertex_piece(self, piece: np.array):
        """Add an integer `piece` channel to the vertices."""
        self.add_vertex_channel("piece", "int", piece)

    def add_faces(self, indices: np.array):
        """Set the vertex indices of all faces.

        Args:
            indices (`numpy.array(int)`): `num_faces * 3` (tri) or `num_faces * 4`
                (quad) entries. They are reshaped to `(num_faces, verts_per_face)`
                and cast to int32.
        """
        vert_per_face = self._verts_per_face()
        assert vert_per_face * self.num_faces == indices.size, "The dimension of the face vertices is not correct"
        self.face_indices = np.reshape(indices, (self.num_faces, vert_per_face)).astype(np.int32)

    def add_face_channel(self, key: str, data_type: str, data: np.array):
        """Add a per-face property; same conventions as `add_vertex_channel` with `num_faces` rows."""
        self.num_face_channels += self._append_channels(
            "face",
            self.num_faces,
            self.face_channels,
            self.face_data_type,
            self.face_data,
            key,
            data_type,
            data,
        )

    def add_face_id(self):
        """Add an `id` channel equal to each face's index."""
        self.add_face_channel("id", "int", np.arange(self.num_faces))

    def add_face_piece(self, piece: np.array):
        """Add an integer `piece` channel to the faces."""
        self.add_face_channel("piece", "int", piece)

    def sanity_check(self):
        """Assert that x/y/z channels exist and every face index refers to an existing vertex."""
        for axis in ("x", "y", "z"):
            assert axis in self.vertex_channels, "The vertex pos channel is missing"
        if self.num_faces > 0:
            indices = self.face_indices
            assert np.all((indices >= 0) & (indices < self.num_vertices)), "The face indices are invalid"

    def print_header(self, path: str, _format: str):
        """Write (overwriting `path`) the PLY header for the given `_format`."""
        lines = [
            "ply\n",
            f"format {_format} 1.0\n",
            f"comment {self.comment}\n",
            f"element vertex {self.num_vertices}\n",
        ]
        lines += [f"property {t} {name}\n" for t, name in zip(self.vertex_data_type, self.vertex_channels)]
        if self.num_faces != 0:
            lines.append(f"element face {self.num_faces}\n")
            lines.append("property list uchar int vertex_indices\n")
            lines += [f"property {t} {name}\n" for t, name in zip(self.face_data_type, self.face_channels)]
        lines.append("end_header\n")
        with open(path, "w") as f:
            f.writelines(lines)

    @staticmethod
    def _pack_records(count, columns):
        """Interleave `columns` row by row into packed bytes in native byte order.

        Row `i` is `columns[0][i]`, `columns[1][i]`, ... with no padding. These bytes
        are identical to writing each element in turn, but the work is done in a
        single vectorized pass.
        """
        names = [f"f{i}" for i in range(len(columns))]
        formats = [col.dtype if col.ndim == 1 else (col.dtype, col.shape[1:]) for col in columns]
        records = np.empty(count, dtype=np.dtype({"names": names, "formats": formats}))
        for name, col in zip(names, columns):
            records[name] = col
        return records.tobytes()

    def export(self, path):
        """Write a binary PLY file in the host's byte order to `path`."""
        self.sanity_check()
        self.print_header(path, "binary_" + sys.byteorder + "_endian")
        face_columns = [
            np.full(self.num_faces, self._verts_per_face(), dtype=np.uint8),  # PLY list length (uchar)
            self.face_indices,
            *self.face_data,
        ]
        with open(path, "ab") as f:
            f.write(self._pack_records(self.num_vertices, self.vertex_data))
            f.write(self._pack_records(self.num_faces, face_columns))

    def export_ascii(self, path):
        """Write an ASCII PLY file to `path` (each value followed by a space, one element per line)."""
        self.sanity_check()
        self.print_header(path, "ascii")
        vert_per_face = self._verts_per_face()
        with open(path, "a") as f:
            for i in range(self.num_vertices):
                # str() (not format()) keeps numpy's shortest float32 representation.
                f.write("".join(str(col[i]) + " " for col in self.vertex_data) + "\n")
            for i in range(self.num_faces):
                indices = " ".join(map(str, self.face_indices[i, :]))
                channels = "".join(str(col[i]) + " " for col in self.face_data)
                f.write(f"{vert_per_face} {indices} {channels}\n")

    @staticmethod
    def _frame_path(series_num: int, path: str):
        """Return `<path without .ply>_<series_num zero-padded to 6>.ply`."""
        if path.endswith(".ply"):
            path = path[:-4]
        return f"{path}_{series_num:0=6d}.ply"

    def export_frame_ascii(self, series_num: int, path: str):
        """ASCII-export one frame of a series to `<path>_<series_num:06d>.ply`."""
        self.export_ascii(self._frame_path(series_num, path))

    def export_frame(self, series_num: int, path: str):
        """Binary-export one frame of a series to `<path>_<series_num:06d>.ply`."""
        self.export(self._frame_path(series_num, path))


__all__ = ["PLYWriter"]
gstaichi/tools/vtk.py ADDED
@@ -0,0 +1,38 @@
1
+ # type: ignore
2
+
3
+ import numpy as np
4
+
5
+
6
def write_vtk(scalar_field, filename):
    """Export a 2D or 3D scalar field as a rectilinear VTK grid using ``pyevtk``.

    Each field element becomes one VTK cell. A 2D field is written as a single layer
    of cells along z.

    Args:
        scalar_field: object whose ``to_numpy()`` returns a 2D or 3D array.
        filename: output path handed to ``pyevtk.hl.gridToVTK``. It is also used as
            the name of the cell-data array.
    Raises:
        RuntimeError: if ``pyevtk`` is not installed.
        ValueError: if the field is neither 2D nor 3D.
    """
    try:
        from pyevtk.hl import gridToVTK  # pylint: disable=import-outside-toplevel
    except ImportError as err:
        raise RuntimeError("Failed to import pyevtk. Please install it via `pip install pyevtk` first.") from err

    scalar_field_np = scalar_field.to_numpy()
    if scalar_field_np.ndim not in (2, 3):
        raise ValueError("The input field must be a 2D or 3D scalar field.")
    if scalar_field_np.ndim == 2:
        # Append a trailing z axis of length 1 so cell data is laid out as (nx, ny, 1).
        scalar_field_np = scalar_field_np[:, :, np.newaxis]

    nx, ny, nz = scalar_field_np.shape
    # gridToVTK takes node coordinates per axis; n cells along an axis need n + 1 nodes.
    gridToVTK(
        filename,
        x=np.arange(nx + 1),
        y=np.arange(ny + 1),
        z=np.arange(nz + 1),
        cellData={filename: scalar_field_np},
    )


__all__ = ["write_vtk"]
@@ -0,0 +1,19 @@
1
+ # type: ignore
2
+
3
+ """
4
+ This module defines data types in GsTaichi:
5
+
6
+ - primitive: int, float, etc.
7
+ - compound: matrix, vector, struct.
8
+ - template: for reference types.
9
+ - ndarray: for arbitrary arrays.
10
+ - quant: for quantized types, see "https://yuanming.taichi.graphics/publication/2021-quantaichi/quantaichi.pdf"
11
+ """
12
+
13
+ from gstaichi.types import quant
14
+ from gstaichi.types.annotations import *
15
+ from gstaichi.types.compound_types import *
16
+ from gstaichi.types.ndarray_type import *
17
+ from gstaichi.types.primitive_types import *
18
+ from gstaichi.types.texture_type import *
19
+ from gstaichi.types.utils import *
@@ -0,0 +1,52 @@
1
+ from typing import Any, Generic, TypeVar
2
+
3
+ T = TypeVar("T")
4
+
5
+
6
+ class Template(Generic[T]):
7
+ """Type annotation for template kernel parameter.
8
+ Useful for passing parameters to kernels by reference.
9
+
10
+ See also https://docs.taichi-lang.org/docs/meta.
11
+
12
+ Args:
13
+ tensor (Any): unused
14
+ dim (Any): unused
15
+
16
+ Example::
17
+
18
+ >>> a = 1
19
+ >>>
20
+ >>> @ti.kernel
21
+ >>> def test():
22
+ >>> print(a)
23
+ >>>
24
+ >>> @ti.kernel
25
+ >>> def test_template(a: ti.template()):
26
+ >>> print(a)
27
+ >>>
28
+ >>> test(a) # will print 1
29
+ >>> test_template(a) # will also print 1
30
+ >>> a = 2
31
+ >>> test(a) # will still print 1
32
+ >>> test_template(a) # will print 2
33
+ """
34
+
35
+ def __init__(self, element_type: type[T] = object, ndim: int | None = None):
36
+ self.element_type = element_type
37
+ self.ndim = ndim
38
+
39
+ def __getitem__(self, i: Any) -> T:
40
+ raise NotImplemented
41
+
42
+
43
+ template = Template
44
+ """Alias for :class:`~gstaichi.types.annotations.Template`.
45
+ """
46
+
47
+
48
+ class sparse_matrix_builder:
49
+ pass
50
+
51
+
52
+ __all__ = ["template", "sparse_matrix_builder", "Template"]