gstaichi 2.1.1__cp311-cp311-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (178)
  1. gstaichi/__init__.py +40 -0
  2. gstaichi/_funcs.py +706 -0
  3. gstaichi/_kernels.py +420 -0
  4. gstaichi/_lib/__init__.py +3 -0
  5. gstaichi/_lib/core/__init__.py +0 -0
  6. gstaichi/_lib/core/gstaichi_python.cpython-311-darwin.so +0 -0
  7. gstaichi/_lib/core/gstaichi_python.pyi +2909 -0
  8. gstaichi/_lib/core/py.typed +0 -0
  9. gstaichi/_lib/runtime/libMoltenVK.dylib +0 -0
  10. gstaichi/_lib/runtime/runtime_arm64.bc +0 -0
  11. gstaichi/_lib/utils.py +243 -0
  12. gstaichi/_logging.py +131 -0
  13. gstaichi/_snode/__init__.py +5 -0
  14. gstaichi/_snode/fields_builder.py +187 -0
  15. gstaichi/_snode/snode_tree.py +34 -0
  16. gstaichi/_test_tools/__init__.py +18 -0
  17. gstaichi/_test_tools/dataclass_test_tools.py +36 -0
  18. gstaichi/_test_tools/load_kernel_string.py +30 -0
  19. gstaichi/_test_tools/textwrap2.py +6 -0
  20. gstaichi/_version.py +1 -0
  21. gstaichi/_version_check.py +100 -0
  22. gstaichi/ad/__init__.py +3 -0
  23. gstaichi/ad/_ad.py +530 -0
  24. gstaichi/algorithms/__init__.py +3 -0
  25. gstaichi/algorithms/_algorithms.py +117 -0
  26. gstaichi/assets/.git +1 -0
  27. gstaichi/assets/Go-Regular.ttf +0 -0
  28. gstaichi/assets/static/imgs/ti_gallery.png +0 -0
  29. gstaichi/examples/lcg_python.py +26 -0
  30. gstaichi/examples/lcg_taichi.py +34 -0
  31. gstaichi/examples/minimal.py +28 -0
  32. gstaichi/experimental.py +16 -0
  33. gstaichi/lang/__init__.py +50 -0
  34. gstaichi/lang/_dataclass_util.py +31 -0
  35. gstaichi/lang/_fast_caching/__init__.py +3 -0
  36. gstaichi/lang/_fast_caching/args_hasher.py +110 -0
  37. gstaichi/lang/_fast_caching/config_hasher.py +30 -0
  38. gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
  39. gstaichi/lang/_fast_caching/function_hasher.py +57 -0
  40. gstaichi/lang/_fast_caching/hash_utils.py +11 -0
  41. gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
  42. gstaichi/lang/_fast_caching/src_hasher.py +75 -0
  43. gstaichi/lang/_kernel_impl_dataclass.py +212 -0
  44. gstaichi/lang/_ndarray.py +352 -0
  45. gstaichi/lang/_ndrange.py +152 -0
  46. gstaichi/lang/_template_mapper.py +195 -0
  47. gstaichi/lang/_texture.py +172 -0
  48. gstaichi/lang/_wrap_inspect.py +215 -0
  49. gstaichi/lang/any_array.py +99 -0
  50. gstaichi/lang/ast/__init__.py +5 -0
  51. gstaichi/lang/ast/ast_transformer.py +1323 -0
  52. gstaichi/lang/ast/ast_transformer_utils.py +346 -0
  53. gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
  54. gstaichi/lang/ast/ast_transformers/call_transformer.py +324 -0
  55. gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
  56. gstaichi/lang/ast/checkers.py +106 -0
  57. gstaichi/lang/ast/symbol_resolver.py +57 -0
  58. gstaichi/lang/ast/transform.py +9 -0
  59. gstaichi/lang/common_ops.py +310 -0
  60. gstaichi/lang/exception.py +80 -0
  61. gstaichi/lang/expr.py +180 -0
  62. gstaichi/lang/field.py +428 -0
  63. gstaichi/lang/impl.py +1245 -0
  64. gstaichi/lang/kernel_arguments.py +155 -0
  65. gstaichi/lang/kernel_impl.py +1341 -0
  66. gstaichi/lang/matrix.py +1835 -0
  67. gstaichi/lang/matrix_ops.py +341 -0
  68. gstaichi/lang/matrix_ops_utils.py +190 -0
  69. gstaichi/lang/mesh.py +687 -0
  70. gstaichi/lang/misc.py +780 -0
  71. gstaichi/lang/ops.py +1494 -0
  72. gstaichi/lang/runtime_ops.py +13 -0
  73. gstaichi/lang/shell.py +35 -0
  74. gstaichi/lang/simt/__init__.py +5 -0
  75. gstaichi/lang/simt/block.py +94 -0
  76. gstaichi/lang/simt/grid.py +7 -0
  77. gstaichi/lang/simt/subgroup.py +191 -0
  78. gstaichi/lang/simt/warp.py +96 -0
  79. gstaichi/lang/snode.py +489 -0
  80. gstaichi/lang/source_builder.py +150 -0
  81. gstaichi/lang/struct.py +810 -0
  82. gstaichi/lang/util.py +312 -0
  83. gstaichi/linalg/__init__.py +8 -0
  84. gstaichi/linalg/matrixfree_cg.py +310 -0
  85. gstaichi/linalg/sparse_cg.py +59 -0
  86. gstaichi/linalg/sparse_matrix.py +303 -0
  87. gstaichi/linalg/sparse_solver.py +123 -0
  88. gstaichi/math/__init__.py +11 -0
  89. gstaichi/math/_complex.py +205 -0
  90. gstaichi/math/mathimpl.py +886 -0
  91. gstaichi/profiler/__init__.py +6 -0
  92. gstaichi/profiler/kernel_metrics.py +260 -0
  93. gstaichi/profiler/kernel_profiler.py +586 -0
  94. gstaichi/profiler/memory_profiler.py +15 -0
  95. gstaichi/profiler/scoped_profiler.py +36 -0
  96. gstaichi/sparse/__init__.py +3 -0
  97. gstaichi/sparse/_sparse_grid.py +77 -0
  98. gstaichi/tools/__init__.py +12 -0
  99. gstaichi/tools/diagnose.py +117 -0
  100. gstaichi/tools/np2ply.py +364 -0
  101. gstaichi/tools/vtk.py +38 -0
  102. gstaichi/types/__init__.py +19 -0
  103. gstaichi/types/annotations.py +52 -0
  104. gstaichi/types/compound_types.py +71 -0
  105. gstaichi/types/enums.py +49 -0
  106. gstaichi/types/ndarray_type.py +169 -0
  107. gstaichi/types/primitive_types.py +206 -0
  108. gstaichi/types/quant.py +88 -0
  109. gstaichi/types/texture_type.py +85 -0
  110. gstaichi/types/utils.py +11 -0
  111. gstaichi-2.1.1.data/data/include/GLFW/glfw3.h +6389 -0
  112. gstaichi-2.1.1.data/data/include/GLFW/glfw3native.h +594 -0
  113. gstaichi-2.1.1.data/data/include/spirv-tools/instrument.hpp +268 -0
  114. gstaichi-2.1.1.data/data/include/spirv-tools/libspirv.h +907 -0
  115. gstaichi-2.1.1.data/data/include/spirv-tools/libspirv.hpp +375 -0
  116. gstaichi-2.1.1.data/data/include/spirv-tools/linker.hpp +97 -0
  117. gstaichi-2.1.1.data/data/include/spirv-tools/optimizer.hpp +970 -0
  118. gstaichi-2.1.1.data/data/include/spirv_cross/GLSL.std.450.h +114 -0
  119. gstaichi-2.1.1.data/data/include/spirv_cross/spirv.h +2568 -0
  120. gstaichi-2.1.1.data/data/include/spirv_cross/spirv.hpp +2579 -0
  121. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cfg.hpp +168 -0
  122. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_common.hpp +1920 -0
  123. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cpp.hpp +93 -0
  124. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross.hpp +1171 -0
  125. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross_c.h +1074 -0
  126. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross_containers.hpp +754 -0
  127. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross_error_handling.hpp +94 -0
  128. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross_parsed_ir.hpp +256 -0
  129. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_cross_util.hpp +37 -0
  130. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_glsl.hpp +1001 -0
  131. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_hlsl.hpp +406 -0
  132. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_msl.hpp +1273 -0
  133. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_parser.hpp +103 -0
  134. gstaichi-2.1.1.data/data/include/spirv_cross/spirv_reflect.hpp +91 -0
  135. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsConfig.cmake +5 -0
  136. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget-release.cmake +29 -0
  137. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools/SPIRV-ToolsTarget.cmake +114 -0
  138. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffConfig.cmake +5 -0
  139. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets-release.cmake +19 -0
  140. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-diff/SPIRV-Tools-diffTargets.cmake +123 -0
  141. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkConfig.cmake +5 -0
  142. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets-release.cmake +19 -0
  143. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-link/SPIRV-Tools-linkTargets.cmake +123 -0
  144. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintConfig.cmake +5 -0
  145. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets-release.cmake +19 -0
  146. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-lint/SPIRV-Tools-lintTargets.cmake +123 -0
  147. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optConfig.cmake +5 -0
  148. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets-release.cmake +19 -0
  149. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-opt/SPIRV-Tools-optTargets.cmake +123 -0
  150. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceConfig.cmake +5 -0
  151. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget-release.cmake +19 -0
  152. gstaichi-2.1.1.data/data/lib/cmake/SPIRV-Tools-reduce/SPIRV-Tools-reduceTarget.cmake +123 -0
  153. gstaichi-2.1.1.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
  154. gstaichi-2.1.1.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
  155. gstaichi-2.1.1.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
  156. gstaichi-2.1.1.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
  157. gstaichi-2.1.1.data/data/lib/libSPIRV-Tools-shared.dylib +0 -0
  158. gstaichi-2.1.1.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig-release.cmake +19 -0
  159. gstaichi-2.1.1.data/data/share/spirv_cross_c/cmake/spirv_cross_cConfig.cmake +123 -0
  160. gstaichi-2.1.1.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig-release.cmake +19 -0
  161. gstaichi-2.1.1.data/data/share/spirv_cross_core/cmake/spirv_cross_coreConfig.cmake +106 -0
  162. gstaichi-2.1.1.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig-release.cmake +19 -0
  163. gstaichi-2.1.1.data/data/share/spirv_cross_cpp/cmake/spirv_cross_cppConfig.cmake +123 -0
  164. gstaichi-2.1.1.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig-release.cmake +19 -0
  165. gstaichi-2.1.1.data/data/share/spirv_cross_glsl/cmake/spirv_cross_glslConfig.cmake +123 -0
  166. gstaichi-2.1.1.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig-release.cmake +19 -0
  167. gstaichi-2.1.1.data/data/share/spirv_cross_hlsl/cmake/spirv_cross_hlslConfig.cmake +123 -0
  168. gstaichi-2.1.1.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig-release.cmake +19 -0
  169. gstaichi-2.1.1.data/data/share/spirv_cross_msl/cmake/spirv_cross_mslConfig.cmake +123 -0
  170. gstaichi-2.1.1.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig-release.cmake +19 -0
  171. gstaichi-2.1.1.data/data/share/spirv_cross_reflect/cmake/spirv_cross_reflectConfig.cmake +106 -0
  172. gstaichi-2.1.1.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig-release.cmake +19 -0
  173. gstaichi-2.1.1.data/data/share/spirv_cross_util/cmake/spirv_cross_utilConfig.cmake +123 -0
  174. gstaichi-2.1.1.dist-info/METADATA +106 -0
  175. gstaichi-2.1.1.dist-info/RECORD +178 -0
  176. gstaichi-2.1.1.dist-info/WHEEL +5 -0
  177. gstaichi-2.1.1.dist-info/licenses/LICENSE +201 -0
  178. gstaichi-2.1.1.dist-info/top_level.txt +1 -0
gstaichi/_test_tools/textwrap2.py ADDED
@@ -0,0 +1,6 @@
+ def dedent(indent_size: int, v: str) -> str:
+     lines = []
+     for line in v.split("\n"):
+         line = line[indent_size:]
+         lines.append(line)
+     return "\n".join(lines)
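For context, a minimal usage sketch of the dedent helper above (illustrative only; the module path is inferred from entry 19 in the file list, and the call itself is hypothetical, not part of this diff):

    # Sketch: dedent() strips exactly `indent_size` leading characters from each line.
    from gstaichi._test_tools.textwrap2 import dedent

    text = "    a\n    b"
    print(dedent(4, text))  # prints "a" and "b" with the 4-character indent removed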
gstaichi/_version.py ADDED
@@ -0,0 +1 @@
+ __version__ = '2.1.1'
gstaichi/_version_check.py ADDED
@@ -0,0 +1,100 @@
+ # type: ignore
+
+ import datetime
+ import json
+ import os
+ import platform
+ import threading
+ import uuid
+ from urllib import request
+
+ from gstaichi._lib import core as _ti_core
+
+
+ def check_version(cur_uuid):
+     # Check GsTaichi version for the user.
+     major = _ti_core.get_version_major()
+     minor = _ti_core.get_version_minor()
+     patch = _ti_core.get_version_patch()
+     version = f"{major}.{minor}.{patch}"
+     payload = {"version": version, "platform": "", "python": ""}
+
+     system = platform.system()
+     u = platform.uname()
+     if (u.system, u.machine) == ("Linux", "x86_64"):
+         payload["platform"] = "manylinux_2_27_x86_64"
+     elif (u.system, u.machine) in (("Linux", "arm64"), ("Linux", "aarch64")):
+         payload["platform"] = "manylinux_2_27_aarch64"
+     elif system == "Windows":
+         payload["platform"] = "win_amd64"
+     elif system == "Darwin":
+         # we only support arm64
+         assert u.machine == "arm64"
+         payload["platform"] = "macosx_11_0_arm64"
+
+     python_version = platform.python_version().split(".")
+     payload["python"] = "cp" + python_version[0] + python_version[1]
+
+     payload["uuid"] = cur_uuid
+     if os.getenv("TI_CI") == "1":
+         payload["type"] = "CI"
+     # We do not want request exceptions to break users' usage of GsTaichi.
+     try:
+         payload = json.dumps(payload)
+         payload = payload.encode()
+         req = request.Request("https://metadata.gstaichi.graphics/check_version", method="POST")
+         req.add_header("Content-Type", "application/json")
+         with request.urlopen(req, data=payload, timeout=5) as response:
+             response = json.loads(response.read().decode("utf-8"))
+             return response
+     except:
+         return None
+
+
+ def write_version_info(response, cur_uuid, version_info_path, cur_date):
+     if response is None:
+         return
+     with open(version_info_path, "w") as f:
+         f.write(cur_date.strftime("%Y-%m-%d"))
+         f.write("\n")
+         if response["status"] == 1:
+             f.write(response["latest_version"])
+         else:
+             f.write("0.0.0")
+         f.write("\n")
+         f.write(cur_uuid)
+         f.write("\n")
+
+
+ def try_check_version():
+     try:
+         os.makedirs(_ti_core.get_repo_dir(), exist_ok=True)
+         version_info_path = os.path.join(_ti_core.get_repo_dir(), "version_info")
+         cur_date = datetime.date.today()
+         if os.path.exists(version_info_path):
+             with open(version_info_path, "r") as f:
+                 version_info_file = f.readlines()
+                 last_time = version_info_file[0].rstrip()
+                 cur_uuid = version_info_file[2].rstrip()
+             if cur_date.strftime("%Y-%m-%d") > last_time:
+                 response = check_version(cur_uuid)
+                 write_version_info(response, cur_uuid, version_info_path, cur_date)
+         else:
+             cur_uuid = str(uuid.uuid4())
+             write_version_info({"status": 0}, cur_uuid, version_info_path, cur_date)
+             response = check_version(cur_uuid)
+             write_version_info(response, cur_uuid, version_info_path, cur_date)
+     # Wildcard exception to catch potential file writing errors.
+     except:
+         pass
+
+
+ def start_version_check_thread():
+     skip = os.environ.get("TI_SKIP_VERSION_CHECK")
+     if skip != "ON":
+         # We don't join this thread because we do not wish to block users.
+         check_version_thread = threading.Thread(target=try_check_version, daemon=True)
+         check_version_thread.start()
+
+
+ __all__ = []
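For context, start_version_check_thread above skips the check whenever TI_SKIP_VERSION_CHECK is set to "ON". A minimal opt-out sketch (illustrative only; it assumes the check is wired into package import, as the daemon-thread helper suggests):

    # Sketch: set the env var before importing the package so the version-check
    # thread never starts. The variable name and "ON" value come from the code above.
    import os

    os.environ["TI_SKIP_VERSION_CHECK"] = "ON"  # any other value leaves the check enabled

    import gstaichi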
gstaichi/ad/__init__.py ADDED
@@ -0,0 +1,3 @@
+ # type: ignore
+
+ from gstaichi.ad._ad import *
gstaichi/ad/_ad.py ADDED
@@ -0,0 +1,530 @@
+ # type: ignore
+
+ """GsTaichi automatic differentiation module.
+
+ This module supplies two decorators for users to customize their
+ gradient computation task.
+ """
+
+ import warnings
+ from functools import reduce
+
+ import numpy as np
+
+ import gstaichi.types.primitive_types as types
+ from gstaichi import _snode
+ from gstaichi.lang import impl
+ from gstaichi.lang._ndarray import Ndarray
+ from gstaichi.lang.expr import Expr
+ from gstaichi.lang.field import Field, ScalarField
+ from gstaichi.lang.kernel_impl import kernel
+ from gstaichi.lang.snode import SNode
+ from gstaichi.types import ndarray, template
+ from gstaichi.types.enums import AutodiffMode, SNodeGradType
+
+
+ class GradChecker:
+     def __init__(self, loss, to_check):
+         self.to_check = to_check
+         self.loss = loss
+         self.eps_range = 2.0 ** np.arange(-3, -30, -2).astype(np.float64)
+         self.result = [None] * len(to_check)
+         self.all_fields = get_all_fields()
+         self.backups = save_all_fields(self.all_fields)
+
+     def add_calls(self, calls):
+         self.calls = calls
+
+     def check_grad(self):
+         assert self.loss.dtype == types.f64, "Only f64 is supported when checking grad."
+
+         @kernel
+         def x_pos(x: template(), tangent_np: ndarray(), eps: types.f64):
+             for I in impl.grouped(x):
+                 x[I] += eps * tangent_np[I]
+
+         @kernel
+         def x_neg(x: template(), tangent_np: ndarray(), eps: types.f64):
+             for I in impl.grouped(x):
+                 x[I] -= eps * tangent_np[I]
+
+         for i, x in enumerate(self.to_check):
+             if x is self.loss:
+                 self.result[i] = True
+                 continue
+
+             check_pass = False
+
+             re_range = []
+             for eps in self.eps_range:
+                 tangent_np = np.array(np.random.rand(*x.shape)).astype(np.float64)
+
+                 restore_all_fields(self.all_fields, self.backups)
+                 x_pos(x, tangent_np, eps)
+                 for func, args in self.calls:
+                     func(*args)
+                 loss_pos = self.loss.to_numpy()
+
+                 restore_all_fields(self.all_fields, self.backups)
+                 x_neg(x, tangent_np, eps)
+                 for func, args in self.calls:
+                     func(*args)
+                 loss_neg = self.loss.to_numpy()
+
+                 ip_numerical = (loss_pos - loss_neg) * 0.5 / eps
+                 x_grad_np = x.grad.to_numpy()
+                 extra_dim = x_grad_np.ndim - tangent_np.ndim
+                 if extra_dim == 1:
+                     tangent_np = np.expand_dims(tangent_np, axis=-1)
+                 if extra_dim == 2:
+                     tangent_np = np.expand_dims(np.expand_dims(tangent_np, axis=-1), axis=-1)
+
+                 ip_autodiff = np.sum(x_grad_np * tangent_np)
+                 err = abs(ip_autodiff - ip_numerical)
+                 if ip_numerical != 0:
+                     re = err / abs(ip_autodiff)
+                 else:
+                     re = err / (abs(ip_autodiff) + 1e-20)
+                 re_range.append(re)
+
+                 if err * 100 <= abs(ip_autodiff):
+                     check_pass = True
+                     break
+
+             self.result[i] = check_pass
+
+             if not check_pass:
+                 print(
+                     "variable",
+                     i,
+                     "has relative error",
+                     min(re_range),
+                     ", expected relative error 0.01",
+                 )
+             else:
+                 print("variable", i, "passes grad check")
+
+         assert all(self.result), "Grad check failed: Not all variables pass grad check"
+
+         restore_all_fields(self.all_fields, self.backups)
+         for func, args in self.calls:
+             func(*args)
+
+
+ def get_all_fields():
+     def visit(node, fields):
+         for _i in range(node.ptr.get_num_ch()):
+             ch = node.ptr.get_ch(_i)
+             if not ch.is_place():
+                 visit(SNode(ch), fields)
+             else:
+                 if not ch.is_primal():
+                     continue
+                 fields.append(ScalarField(Expr(ch.get_expr())))
+
+     fields = []
+     for root_fb in _snode.FieldsBuilder._finalized_roots():
+         visit(root_fb, fields)
+     return fields
+
+
+ def save_all_fields(all_fields):
+     return [x.to_numpy() for x in all_fields]
+
+
+ def restore_all_fields(all_fields, backups):
+     assert len(all_fields) == len(backups)
+     for f, x in zip(all_fields, backups):
+         f.from_numpy(x)
+
+
+ class Tape:
+     def __init__(self, loss=None, clear_gradients=True, validation=False, grad_check=None):
+         """A context manager for reverse-mode autodiff :class:`~gstaichi.ad.Tape`. The
+         context manager records every call to a function decorated with
+         :func:`~gstaichi.lang.kernel_impl.kernel` or
+         :func:`~gstaichi.ad.grad_replaced` inside the `with` statement, and when
+         the `with` block ends it computes all partial gradients of a given loss
+         variable by calling the gradient functions of the recorded calls in
+         reverse order.
+
+         See also :func:`~gstaichi.lang.kernel_impl.kernel` and
+         :func:`~gstaichi.ad.grad_replaced` for gradient functions.
+
+         Args:
+             loss(:class:`~gstaichi.lang.expr.Expr`): The loss field, whose shape should be ().
+             clear_gradients(Bool): Whether to clear all gradients before the `with` body starts.
+             validation(Bool): Whether to check that the code inside the context manager is autodiff-valid, e.g., that it obeys the global data access rule.
+             grad_check(List[Field]): List of fields whose gradients need checking.
+
+         Example::
+
+             >>> @ti.kernel
+             >>> def sum(a: ti.float32):
+             >>>     for I in ti.grouped(x):
+             >>>         y[None] += x[I] ** a
+             >>>
+             >>> with ti.ad.Tape(loss = y):
+             >>>     sum(2)
+         """
+         self.calls = []
+         self.modes = []
+         self.entered = False
+         self.gradient_evaluated = False
+         self.clear_gradients = clear_gradients
+         self.validation = validation
+         self.runtime = impl.get_runtime()
+         if not self.runtime.prog.config().debug and self.validation:
+             warnings.warn(
+                 "Debug mode is disabled, autodiff valid check will not work. Please specify `ti.init(debug=True)` to enable the check.",
+                 Warning,
+             )
+         self.eval_on_exit = loss is not None
+         self.loss = loss
+         self.grad_checker = None
+         if grad_check:
+             assert isinstance(grad_check, list), "grad_check should be a list of fields whose gradients need checking."
+             self.grad_checker = GradChecker(loss, grad_check)
+
+     def __enter__(self):
+         assert not self.entered, "Tape can be entered only once."
+         self.entered = True
+
+         if isinstance(self.loss, Field):
+             impl.get_runtime().materialize()
+             if len(self.loss.shape) != 0:
+                 raise RuntimeError("The loss of `Tape` must be a 0-D field, i.e. scalar")
+             if not self.loss.snode.ptr.has_adjoint():
+                 raise RuntimeError(
+                     "Gradients of loss are not allocated, please use ti.field(..., needs_grad=True)"
+                     " for all fields that are required by autodiff."
+                 )
+             if self.clear_gradients:
+                 clear_all_gradients()
+             if self.validation:
+                 clear_all_gradients(gradient_type=SNodeGradType.ADJOINT_CHECKBIT)
+
+             self.loss.fill(0.0)
+         elif isinstance(self.loss, Ndarray):
+             if self.loss._get_nelement() != 1:
+                 raise RuntimeError("The loss of `Tape` must be an ndarray with only one element")
+             if self.loss.grad is None:
+                 raise RuntimeError(
+                     "Gradients of loss are not allocated, please set needs_grad=True for all ndarrays that are required by autodiff."
+                 )
+             self.loss.fill(0.0)
+         else:
+             import torch  # pylint: disable=C0415
+
+             if self.loss.numel() != 1:
+                 raise RuntimeError("The loss of `Tape` must be a tensor that contains only one element")
+             if not self.loss.requires_grad:
+                 raise RuntimeError(
+                     "Gradients of loss are not allocated, please set requires_grad=True for all tensors that are required by autodiff."
+                 )
+             with torch.no_grad():
+                 self.loss.fill_(0.0)
+
+         # Attach the context manager to the runtime
+         self.runtime.target_tape = self
+         return self
+
+     def __exit__(self, _type, value, tb):
+         self.runtime.target_tape = None
+         if self.eval_on_exit:
+             self.grad()
+         for calls, mode in zip(self.calls, self.modes):
+             calls[0].autodiff_mode = mode
+
+     def insert(self, func, args):
+         # Kernels with mode `AutodiffMode.NONE` and `AutodiffMode.VALIDATION` are all forward kernels.
+         # The difference is that VALIDATION kernels contain `assert`s for the global data access rule check.
+         assert func.autodiff_mode in (
+             AutodiffMode.NONE,
+             AutodiffMode.VALIDATION,
+         ), "Inserted funcs should be forward kernels."
+         self.modes.append(func.autodiff_mode)
+         if self.validation:
+             func.autodiff_mode = AutodiffMode.VALIDATION
+         self.calls.append((func, args))
+
+     def grad(self):
+         assert self.entered, "The tape must be entered before evaluating gradients."
+         assert not self.gradient_evaluated, "Gradients can be evaluated only once."
+
+         # Set grad for loss
+         if isinstance(self.loss, (Field, Ndarray)):
+             self.loss.grad.fill(1.0)
+         else:
+             import torch  # pylint: disable=C0415
+
+             if self.loss.grad is None:
+                 self.loss.grad = torch.ones_like(self.loss)
+             else:
+                 with torch.no_grad():
+                     self.loss.grad.fill_(1.0)
+
+         for func, args in reversed(self.calls):
+             # We need to check whether `func` has a `grad` attribute,
+             # since we insert write_int and write_float kernels into self.calls;
+             # e.g. for x[None] = 0.0, such a func has no grad attribute.
+             if hasattr(func, "grad"):
+                 func.grad(*args)
+
+         self.gradient_evaluated = True
+         if self.grad_checker:
+             self.grad_checker.add_calls(self.calls)
+             self.grad_checker.check_grad()
+
+
+ def clear_all_gradients(gradient_type=SNodeGradType.ADJOINT):
+     """Sets the gradients of all fields to zero."""
+     impl.get_runtime().materialize()
+
+     def visit(node):
+         places = []
+         for _i in range(node.ptr.get_num_ch()):
+             ch = node.ptr.get_ch(_i)
+             if not ch.is_place():
+                 visit(SNode(ch))
+             else:
+                 if ch.get_snode_grad_type() == gradient_type:
+                     places.append(ch.get_expr())
+
+         places = tuple(places)
+         if places:
+             from gstaichi._kernels import clear_gradients  # pylint: disable=C0415
+
+             clear_gradients(places)
+
+     for root_fb in _snode.FieldsBuilder._finalized_roots():
+         visit(root_fb)
+
+
+ def grad_replaced(func):
+     """A decorator for a Python function to customize its gradient within GsTaichi's
+     autodiff system, e.g. `ti.ad.Tape()` and `kernel.grad()`.
+
+     This decorator forces GsTaichi's autodiff system to use a user-defined gradient
+     function for the decorated function. The customized gradient must be decorated
+     by :func:`~gstaichi.ad.grad_for`.
+
+     Args:
+         func (Callable): The Python function to be decorated.
+
+     Returns:
+         Callable: The decorated function.
+
+     Example::
+
+         >>> @ti.kernel
+         >>> def multiply(a: ti.float32):
+         >>>     for I in ti.grouped(x):
+         >>>         y[I] = x[I] * a
+         >>>
+         >>> @ti.kernel
+         >>> def multiply_grad(a: ti.float32):
+         >>>     for I in ti.grouped(x):
+         >>>         x.grad[I] = y.grad[I] / a
+         >>>
+         >>> @ti.ad.grad_replaced
+         >>> def foo(a):
+         >>>     multiply(a)
+         >>>
+         >>> @ti.ad.grad_for(foo)
+         >>> def foo_grad(a):
+         >>>     multiply_grad(a)"""
+
+     def decorated(*args, **kwargs):
+         # TODO [#3025]: get rid of circular imports and move this to the top.
+         impl.get_runtime().grad_replaced = True
+         if impl.get_runtime().target_tape:
+             impl.get_runtime().target_tape.insert(decorated, args)
+         try:
+             func(*args, **kwargs)
+         finally:
+             impl.get_runtime().grad_replaced = False
+
+     decorated.grad = None
+     decorated.autodiff_mode = AutodiffMode.NONE
+     return decorated
+
+
+ def grad_for(primal):
+     """Generates a decorator to decorate `primal`'s customized gradient function.
+
+     See :func:`~gstaichi.ad.grad_replaced` for examples.
+
+     Args:
+         primal (Callable): The primal function, which must be decorated by :func:`~gstaichi.ad.grad_replaced`.
+
+     Returns:
+         Callable: The decorator used to decorate the customized gradient function."""
+
+     def decorator(func):
+         def decorated(*args, **kwargs):
+             func(*args, **kwargs)
+
+         if not hasattr(primal, "grad"):
+             raise RuntimeError(f"Primal function `{primal.__name__}` must be decorated by ti.ad.grad_replaced")
+         if primal.grad is not None:
+             raise RuntimeError(
+                 "Primal function must be a **python** function instead of a gstaichi kernel. Please wrap the gstaichi kernel in a @ti.ad.grad_replaced decorated python function instead."
+             )
+         primal.grad = decorated
+         return decorated
+
+     return decorator
+
+
+ def no_grad(func):
+     """A decorator for a Python function to skip gradient calculation within GsTaichi's
+     autodiff system, e.g. `ti.ad.Tape()` and `kernel.grad()`.
+     This decorator forces GsTaichi's autodiff system to use an empty gradient function
+     for the decorated function.
+
+     Args:
+         func (Callable): The Python function to be decorated.
+
+     Returns:
+         Callable: The decorated function.
+
+     Example::
+
+         >>> @ti.kernel
+         >>> def multiply(a: ti.float32):
+         >>>     for I in ti.grouped(x):
+         >>>         y[I] = x[I] * a
+         >>>
+         >>> @ti.no_grad
+         >>> def foo(a):
+         >>>     multiply(a)"""
+
+     def decorated(*args, **kwargs):
+         impl.get_runtime().grad_replaced = True
+         if impl.get_runtime().target_tape:
+             impl.get_runtime().target_tape.insert(decorated, args)
+         try:
+             func(*args, **kwargs)
+         finally:
+             impl.get_runtime().grad_replaced = False
+
+     def placeholder(*args, **kwargs):
+         return
+
+     decorated.grad = placeholder
+     decorated.autodiff_mode = AutodiffMode.NONE
+     return decorated
+
+
+ class FwdMode:
+     def __init__(self, loss, param, seed=None, clear_gradients=True):
+         self.calls = []
+         self.modes = []
+         self.entered = False
+         self.kernels_recovered = False
+         self.runtime = impl.get_runtime()
+         self.loss = loss
+         self.param = param
+         self.seed = seed
+         self.clear_gradients = clear_gradients
+
+     def __enter__(self):
+         assert not self.entered, "Forward mode manager can be entered only once."
+         self.entered = True
+         impl.get_runtime().materialize()
+         if not isinstance(self.loss, list):
+             self.loss = [self.loss]
+         for ls in self.loss:
+             assert isinstance(ls, ScalarField)
+
+         # Currently we only support one N-D field as a group of parameters,
+         # which is sufficient for computing the Jacobian-vector product (Jvp).
+         # Cases with multiple groups of parameters require running forward AD multiple times,
+         # which is out of scope of the current design for this interface.
+
+         # TODO: support vector field and matrix field
+         assert isinstance(self.param, ScalarField)
+
+         def shape_flatten(shape):
+             return reduce((lambda x, y: x * y), list(shape))
+
+         # Handle 0-D field
+         if len(self.param.shape) != 0:
+             parameters_shape_flatten = shape_flatten(self.param.shape)
+         else:
+             parameters_shape_flatten = 1
+
+         if not self.seed:
+             if parameters_shape_flatten == 1:
+                 # Compute the derivative with respect to the first variable by default
+                 self.seed = [1.0]
+             else:
+                 raise RuntimeError(
+                     "`seed` is not set for a non-0-D field, please specify it."
+                     " `seed` is a list that specifies which parameters the computed derivatives are taken with respect to. Its length should be the same as that of the `parameters`."
+                     " E.g. Given a loss `loss = ti.field(float, shape=3)` and a parameter `x = ti.field(float, shape=3)`,"
+                     " seed = [0, 0, 1] computes the derivative with respect to the third element of `x`,"
+                     " and seed = [1, 1, 1] computes the sum of derivatives with respect to all three elements of `x`, i.e., the Jacobian-vector product (Jvp) for each element in `loss`."
+                 )
+         else:
+             assert parameters_shape_flatten == len(self.seed)
+
+         # Clear gradients
+         if self.clear_gradients:
+             clear_all_gradients(gradient_type=SNodeGradType.DUAL)
+
+         # Set seed for each variable
+         if len(self.seed) == 1:
+             if len(self.param.shape) == 0:
+                 # e.g., x = ti.field(float, shape=())
+                 self.param.dual[None] = 1.0 * self.seed[0]
+             else:
+                 # e.g., ti.root.dense(ti.i, 1).place(x.dual)
+                 self.param.dual[0] = 1.0 * self.seed[0]
+         else:
+             self.param.dual.from_numpy(np.array(self.seed, dtype=np.float32))
+
+         # Attach the context manager to the runtime
+         self.runtime.fwd_mode_manager = self
+
+     def __exit__(self, _type, value, tb):
+         self.runtime.fwd_mode_manager = None
+         self.clear_seed()
+         self.recover_kernels()
+
+     def insert(self, func):
+         assert (
+             func.autodiff_mode == AutodiffMode.NONE or func.autodiff_mode == AutodiffMode.FORWARD
+         ), "Inserted funcs should be forward or grad kernels (forward mode)."
+         self.modes.append(func.autodiff_mode)
+         func.autodiff_mode = AutodiffMode.FORWARD
+         self.calls.append(func)
+
+     def recover_kernels(self):
+         assert self.entered, "The fwd mode manager must be entered before recovering the kernels."
+         for calls, mode in zip(self.calls, self.modes):
+             calls.autodiff_mode = mode
+         self.kernels_recovered = True
+
+     def clear_seed(self):
+         # clear seed values
+         if len(self.seed) == 1:
+             if len(self.param.shape) == 0:
+                 # e.g., x = ti.field(float, shape=())
+                 self.param.dual[None] = 0.0
+             else:
+                 # e.g., ti.root.dense(ti.i, 1).place(x.dual)
+                 self.param.dual[0] = 0.0
+         else:
+             self.param.dual.fill(0)
+
+
+ __all__ = [
+     "FwdMode",
+     "Tape",
+     "clear_all_gradients",
+     "grad_for",
+     "grad_replaced",
+     "no_grad",
+ ]
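For context, the Tape docstring above assembles into a short end-to-end reverse-mode sketch (illustrative only; it assumes `import gstaichi as ti`, a default backend via ti.init(), and the ti.field(..., needs_grad=True) API named in the error messages above):

    import gstaichi as ti

    ti.init()

    x = ti.field(ti.f32, shape=8, needs_grad=True)
    y = ti.field(ti.f32, shape=(), needs_grad=True)  # Tape requires a 0-D loss field

    @ti.kernel
    def sum_sq():
        for i in x:
            y[None] += x[i] ** 2

    x.fill(1.0)
    with ti.ad.Tape(loss=y):  # records sum_sq, then runs its gradient kernels in reverse
        sum_sq()

    print(x.grad.to_numpy())  # d/dx_i of sum(x**2) is 2 * x_i, so all entries are 2.0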
gstaichi/algorithms/__init__.py ADDED
@@ -0,0 +1,3 @@
+ # type: ignore
+
+ from ._algorithms import *