gstaichi 0.0.0__cp311-cp311-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gstaichi/CHANGELOG.md +4 -0
- gstaichi/__init__.py +51 -0
- gstaichi/_funcs.py +706 -0
- gstaichi/_kernels.py +420 -0
- gstaichi/_lib/__init__.py +5 -0
- gstaichi/_lib/core/__init__.py +0 -0
- gstaichi/_lib/core/gstaichi_python.cp311-win_amd64.pyd +0 -0
- gstaichi/_lib/core/gstaichi_python.pyi +2917 -0
- gstaichi/_lib/core/py.typed +0 -0
- gstaichi/_lib/runtime/runtime_cuda.bc +0 -0
- gstaichi/_lib/runtime/runtime_x64.bc +0 -0
- gstaichi/_lib/runtime/slim_libdevice.10.bc +0 -0
- gstaichi/_lib/utils.py +243 -0
- gstaichi/_logging.py +131 -0
- gstaichi/_snode/__init__.py +5 -0
- gstaichi/_snode/fields_builder.py +187 -0
- gstaichi/_snode/snode_tree.py +34 -0
- gstaichi/_test_tools/__init__.py +18 -0
- gstaichi/_test_tools/dataclass_test_tools.py +36 -0
- gstaichi/_test_tools/load_kernel_string.py +30 -0
- gstaichi/_test_tools/textwrap2.py +6 -0
- gstaichi/_version_check.py +100 -0
- gstaichi/ad/__init__.py +3 -0
- gstaichi/ad/_ad.py +530 -0
- gstaichi/algorithms/__init__.py +3 -0
- gstaichi/algorithms/_algorithms.py +117 -0
- gstaichi/assets/.git +1 -0
- gstaichi/assets/Go-Regular.ttf +0 -0
- gstaichi/assets/static/imgs/ti_gallery.png +0 -0
- gstaichi/examples/lcg_python.py +26 -0
- gstaichi/examples/lcg_taichi.py +34 -0
- gstaichi/examples/minimal.py +28 -0
- gstaichi/experimental.py +16 -0
- gstaichi/lang/__init__.py +50 -0
- gstaichi/lang/_dataclass_util.py +31 -0
- gstaichi/lang/_fast_caching/__init__.py +3 -0
- gstaichi/lang/_fast_caching/args_hasher.py +122 -0
- gstaichi/lang/_fast_caching/config_hasher.py +30 -0
- gstaichi/lang/_fast_caching/fast_caching_types.py +21 -0
- gstaichi/lang/_fast_caching/function_hasher.py +57 -0
- gstaichi/lang/_fast_caching/hash_utils.py +11 -0
- gstaichi/lang/_fast_caching/python_side_cache.py +52 -0
- gstaichi/lang/_fast_caching/src_hasher.py +83 -0
- gstaichi/lang/_kernel_impl_dataclass.py +212 -0
- gstaichi/lang/_ndarray.py +366 -0
- gstaichi/lang/_ndrange.py +152 -0
- gstaichi/lang/_template_mapper.py +195 -0
- gstaichi/lang/_texture.py +172 -0
- gstaichi/lang/_wrap_inspect.py +215 -0
- gstaichi/lang/any_array.py +99 -0
- gstaichi/lang/ast/__init__.py +7 -0
- gstaichi/lang/ast/ast_transformer.py +1351 -0
- gstaichi/lang/ast/ast_transformer_utils.py +346 -0
- gstaichi/lang/ast/ast_transformers/__init__.py +0 -0
- gstaichi/lang/ast/ast_transformers/call_transformer.py +327 -0
- gstaichi/lang/ast/ast_transformers/function_def_transformer.py +304 -0
- gstaichi/lang/ast/checkers.py +106 -0
- gstaichi/lang/ast/symbol_resolver.py +57 -0
- gstaichi/lang/ast/transform.py +9 -0
- gstaichi/lang/common_ops.py +310 -0
- gstaichi/lang/exception.py +80 -0
- gstaichi/lang/expr.py +180 -0
- gstaichi/lang/field.py +428 -0
- gstaichi/lang/impl.py +1259 -0
- gstaichi/lang/kernel_arguments.py +155 -0
- gstaichi/lang/kernel_impl.py +1386 -0
- gstaichi/lang/matrix.py +1835 -0
- gstaichi/lang/matrix_ops.py +341 -0
- gstaichi/lang/matrix_ops_utils.py +190 -0
- gstaichi/lang/mesh.py +687 -0
- gstaichi/lang/misc.py +784 -0
- gstaichi/lang/ops.py +1494 -0
- gstaichi/lang/runtime_ops.py +13 -0
- gstaichi/lang/shell.py +35 -0
- gstaichi/lang/simt/__init__.py +5 -0
- gstaichi/lang/simt/block.py +94 -0
- gstaichi/lang/simt/grid.py +7 -0
- gstaichi/lang/simt/subgroup.py +191 -0
- gstaichi/lang/simt/warp.py +96 -0
- gstaichi/lang/snode.py +489 -0
- gstaichi/lang/source_builder.py +150 -0
- gstaichi/lang/struct.py +810 -0
- gstaichi/lang/util.py +312 -0
- gstaichi/linalg/__init__.py +10 -0
- gstaichi/linalg/matrixfree_cg.py +310 -0
- gstaichi/linalg/sparse_cg.py +59 -0
- gstaichi/linalg/sparse_matrix.py +303 -0
- gstaichi/linalg/sparse_solver.py +123 -0
- gstaichi/math/__init__.py +11 -0
- gstaichi/math/_complex.py +205 -0
- gstaichi/math/mathimpl.py +886 -0
- gstaichi/profiler/__init__.py +6 -0
- gstaichi/profiler/kernel_metrics.py +260 -0
- gstaichi/profiler/kernel_profiler.py +586 -0
- gstaichi/profiler/memory_profiler.py +15 -0
- gstaichi/profiler/scoped_profiler.py +36 -0
- gstaichi/sparse/__init__.py +3 -0
- gstaichi/sparse/_sparse_grid.py +77 -0
- gstaichi/tools/__init__.py +12 -0
- gstaichi/tools/diagnose.py +117 -0
- gstaichi/tools/np2ply.py +364 -0
- gstaichi/tools/vtk.py +38 -0
- gstaichi/types/__init__.py +21 -0
- gstaichi/types/annotations.py +52 -0
- gstaichi/types/compound_types.py +71 -0
- gstaichi/types/enums.py +49 -0
- gstaichi/types/ndarray_type.py +169 -0
- gstaichi/types/primitive_types.py +206 -0
- gstaichi/types/quant.py +88 -0
- gstaichi/types/texture_type.py +85 -0
- gstaichi/types/utils.py +11 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +5 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +29 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +113 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +5 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +19 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +122 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +5 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +19 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +122 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +5 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +19 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +122 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +5 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +19 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +122 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +5 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +19 -0
- gstaichi-0.0.0.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +122 -0
- gstaichi-0.0.0.data/data/bin/SPIRV-Tools-shared.dll +0 -0
- gstaichi-0.0.0.data/data/include/GLFW/glfw3.h +6389 -0
- gstaichi-0.0.0.data/data/include/GLFW/glfw3native.h +594 -0
- gstaichi-0.0.0.data/data/include/spirv-tools/instrument.hpp +268 -0
- gstaichi-0.0.0.data/data/include/spirv-tools/libspirv.h +907 -0
- gstaichi-0.0.0.data/data/include/spirv-tools/libspirv.hpp +375 -0
- gstaichi-0.0.0.data/data/include/spirv-tools/linker.hpp +97 -0
- gstaichi-0.0.0.data/data/include/spirv-tools/optimizer.hpp +970 -0
- gstaichi-0.0.0.data/data/lib/SPIRV-Tools-diff.lib +0 -0
- gstaichi-0.0.0.data/data/lib/SPIRV-Tools-link.lib +0 -0
- gstaichi-0.0.0.data/data/lib/SPIRV-Tools-lint.lib +0 -0
- gstaichi-0.0.0.data/data/lib/SPIRV-Tools-opt.lib +0 -0
- gstaichi-0.0.0.data/data/lib/SPIRV-Tools-reduce.lib +0 -0
- gstaichi-0.0.0.data/data/lib/SPIRV-Tools-shared.lib +0 -0
- gstaichi-0.0.0.data/data/lib/SPIRV-Tools.lib +0 -0
- gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
- gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
- gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
- gstaichi-0.0.0.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
- gstaichi-0.0.0.data/data/lib/glfw3.lib +0 -0
- gstaichi-0.0.0.dist-info/METADATA +97 -0
- gstaichi-0.0.0.dist-info/RECORD +154 -0
- gstaichi-0.0.0.dist-info/WHEEL +5 -0
- gstaichi-0.0.0.dist-info/licenses/LICENSE +201 -0
- gstaichi-0.0.0.dist-info/top_level.txt +1 -0
gstaichi/_version_check.py
ADDED
@@ -0,0 +1,100 @@
# type: ignore

import datetime
import json
import os
import platform
import threading
import uuid
from urllib import request

from gstaichi._lib import core as _ti_core


def check_version(cur_uuid):
    # Check GsTaichi version for the user.
    major = _ti_core.get_version_major()
    minor = _ti_core.get_version_minor()
    patch = _ti_core.get_version_patch()
    version = f"{major}.{minor}.{patch}"
    payload = {"version": version, "platform": "", "python": ""}

    system = platform.system()
    u = platform.uname()
    if (u.system, u.machine) == ("Linux", "x86_64"):
        payload["platform"] = "manylinux_2_27_x86_64"
    elif (u.system, u.machine) in (("Linux", "arm64"), ("Linux", "aarch64")):
        payload["platform"] = "manylinux_2_27_aarch64"
    elif system == "Windows":
        payload["platform"] = "win_amd64"
    elif system == "Darwin":
        # we only support arm64
        assert u.machine == "arm64"
        payload["platform"] = "macosx_11_0_arm64"

    python_version = platform.python_version().split(".")
    payload["python"] = "cp" + python_version[0] + python_version[1]

    payload["uuid"] = cur_uuid
    if os.getenv("TI_CI") == "1":
        payload["type"] = "CI"
    # We do not want request exceptions to break users' usage of GsTaichi.
    try:
        payload = json.dumps(payload)
        payload = payload.encode()
        req = request.Request("https://metadata.gstaichi.graphics/check_version", method="POST")
        req.add_header("Content-Type", "application/json")
        with request.urlopen(req, data=payload, timeout=5) as response:
            response = json.loads(response.read().decode("utf-8"))
            return response
    except:
        return None


def write_version_info(response, cur_uuid, version_info_path, cur_date):
    if response is None:
        return
    with open(version_info_path, "w") as f:
        f.write((cur_date).strftime("%Y-%m-%d"))
        f.write("\n")
        if response["status"] == 1:
            f.write(response["latest_version"])
        else:
            f.write("0.0.0")
        f.write("\n")
        f.write(cur_uuid)
        f.write("\n")


def try_check_version():
    try:
        os.makedirs(_ti_core.get_repo_dir(), exist_ok=True)
        version_info_path = os.path.join(_ti_core.get_repo_dir(), "version_info")
        cur_date = datetime.date.today()
        if os.path.exists(version_info_path):
            with open(version_info_path, "r") as f:
                version_info_file = f.readlines()
                last_time = version_info_file[0].rstrip()
                cur_uuid = version_info_file[2].rstrip()
            if cur_date.strftime("%Y-%m-%d") > last_time:
                response = check_version(cur_uuid)
                write_version_info(response, cur_uuid, version_info_path, cur_date)
        else:
            cur_uuid = str(uuid.uuid4())
            write_version_info({"status": 0}, cur_uuid, version_info_path, cur_date)
            response = check_version(cur_uuid)
            write_version_info(response, cur_uuid, version_info_path, cur_date)
    # Wildcard exception to catch potential file writing errors.
    except:
        pass


def start_version_check_thread():
    skip = os.environ.get("TI_SKIP_VERSION_CHECK")
    if skip != "ON":
        # We don't join this thread because we do not wish to block users.
        check_version_thread = threading.Thread(target=try_check_version, daemon=True)
        check_version_thread.start()


__all__ = []
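The module above is self-contained; for orientation, here is a minimal usage sketch (an editorial illustration, not code shipped in the wheel) that relies only on names defined in `_version_check.py` and the `TI_SKIP_VERSION_CHECK` environment variable it reads:

# Hypothetical usage sketch: opt out of the metadata request, or trigger the check manually.
import os

os.environ["TI_SKIP_VERSION_CHECK"] = "ON"  # makes start_version_check_thread() a no-op

from gstaichi import _version_check

# With the variable unset (or set to anything other than "ON"), this spawns a daemon
# thread running try_check_version(), which POSTs version/platform/python/uuid metadata
# to https://metadata.gstaichi.graphics/check_version and caches the reply in
# <repo_dir>/version_info (date, latest version, and uuid on separate lines).
_version_check.start_version_check_thread()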
gstaichi/ad/__init__.py
ADDED
gstaichi/ad/_ad.py
ADDED
@@ -0,0 +1,530 @@
# type: ignore

"""GsTaichi automatic differentiation module.

This module supplies two decorators for users to customize their
gradient computation task.
"""

import warnings
from functools import reduce

import numpy as np

import gstaichi.types.primitive_types as types
from gstaichi import _snode
from gstaichi.lang import impl
from gstaichi.lang._ndarray import Ndarray
from gstaichi.lang.expr import Expr
from gstaichi.lang.field import Field, ScalarField
from gstaichi.lang.kernel_impl import kernel
from gstaichi.lang.snode import SNode
from gstaichi.types import ndarray, template
from gstaichi.types.enums import AutodiffMode, SNodeGradType


class GradChecker:
    def __init__(self, loss, to_check):
        self.to_check = to_check
        self.loss = loss
        self.eps_range = 2.0 ** np.arange(-3, -30, -2).astype(np.float64)
        self.result = [None] * len(to_check)
        self.all_fields = get_all_fields()
        self.backups = save_all_fields(self.all_fields)

    def add_calls(self, calls):
        self.calls = calls

    def check_grad(self):
        assert self.loss.dtype == types.f64, "Only f64 is supported when checking grad."

        @kernel
        def x_pos(x: template(), tangent_np: ndarray(), eps: types.f64):
            for I in impl.grouped(x):
                x[I] += eps * tangent_np[I]

        @kernel
        def x_neg(x: template(), tangent_np: ndarray(), eps: types.f64):
            for I in impl.grouped(x):
                x[I] -= eps * tangent_np[I]

        for i, x in enumerate(self.to_check):
            if x is self.loss:
                self.result[i] = True
                continue

            check_pass = False

            re_range = []
            for eps in self.eps_range:
                tangent_np = np.array(np.random.rand(*x.shape)).astype(np.float64)

                restore_all_fields(self.all_fields, self.backups)
                x_pos(x, tangent_np, eps)
                for func, args in self.calls:
                    func(*args)
                loss_pos = self.loss.to_numpy()

                restore_all_fields(self.all_fields, self.backups)
                x_neg(x, tangent_np, eps)
                for func, args in self.calls:
                    func(*args)
                loss_neg = self.loss.to_numpy()

                ip_numerical = (loss_pos - loss_neg) * 0.5 / eps
                x_grad_np = x.grad.to_numpy()
                extra_dim = x_grad_np.ndim - tangent_np.ndim
                if extra_dim == 1:
                    tangent_np = np.expand_dims(tangent_np, axis=-1)
                if extra_dim == 2:
                    tangent_np = np.expand_dims(tangent_np, axis=-1)

                ip_autodiff = np.sum(x_grad_np * tangent_np)
                err = abs(ip_autodiff - ip_numerical)
                if ip_numerical != 0:
                    re = err / abs(ip_autodiff)
                else:
                    re = err / (abs(ip_autodiff) + 1e-20)
                re_range.append(re)

                if err * 100 <= abs(ip_autodiff):
                    check_pass = True
                    break

            self.result[i] = check_pass

            if not check_pass:
                print(
                    "variable",
                    i,
                    "has relative error",
                    min(re_range),
                    ", expected relative error 0.01",
                )
            else:
                print("variable", i, "passes grad check")

        assert all(self.result), "Grad check failed: Not all variables pass grad check"

        restore_all_fields(self.all_fields, self.backups)
        for func, args in self.calls:
            func(*args)


def get_all_fields():
    def visit(node, fields):
        for _i in range(node.ptr.get_num_ch()):
            ch = node.ptr.get_ch(_i)
            if not ch.is_place():
                visit(SNode(ch), fields)
            else:
                if not ch.is_primal():
                    continue
                fields.append(ScalarField(Expr(ch.get_expr())))

    fields = []
    for root_fb in _snode.FieldsBuilder._finalized_roots():
        visit(root_fb, fields)
    return fields


def save_all_fields(all_fields):
    return [x.to_numpy() for x in all_fields]


def restore_all_fields(all_fields, backups):
    assert len(all_fields) == len(backups)
    for f, x in zip(all_fields, backups):
        f.from_numpy(x)


class Tape:
    def __init__(self, loss=None, clear_gradients=True, validation=False, grad_check=None):
        """A context manager for reverse-mode autodiff :class:`~gstaichi.ad.Tape`. The
        context manager records all calls to functions decorated by
        :func:`~gstaichi.lang.kernel_impl.kernel` or :func:`~gstaichi.ad.grad_replaced`
        inside the `with` statement, and computes all partial gradients of a given
        loss variable by calling the gradient functions of the recorded calls in
        reverse order when the `with` statement ends.

        See also :func:`~gstaichi.lang.kernel_impl.kernel` and
        :func:`~gstaichi.ad.grad_replaced` for gradient functions.

        Args:
            loss(:class:`~gstaichi.lang.expr.Expr`): The loss field, whose shape should be ().
            clear_gradients(Bool): Whether to clear all gradients before the `with` body starts.
            validation(Bool): Check whether the code inside the context manager is autodiff-valid, e.g., agrees with the global data access rule.
            grad_check(List[Field]): List of fields whose gradients need to be checked.

        Example::

            >>> @ti.kernel
            >>> def sum(a: ti.float32):
            >>>     for I in ti.grouped(x):
            >>>         y[None] += x[I] ** a
            >>>
            >>> with ti.ad.Tape(loss = y):
            >>>     sum(2)
        """
        self.calls = []
        self.modes = []
        self.entered = False
        self.gradient_evaluated = False
        self.clear_gradients = clear_gradients
        self.validation = validation
        self.runtime = impl.get_runtime()
        if not self.runtime.prog.config().debug and self.validation:
            warnings.warn(
                "Debug mode is disabled, the autodiff validity check will not work. Please specify `ti.init(debug=True)` to enable the check.",
                Warning,
            )
        self.eval_on_exit = loss is not None
        self.loss = loss
        self.grad_checker = None
        if grad_check:
            assert isinstance(grad_check, list), "grad_check should be a list of fields whose gradients need to be checked."
            self.grad_checker = GradChecker(loss, grad_check)

    def __enter__(self):
        assert not self.entered, "Tape can be entered only once."
        self.entered = True

        if isinstance(self.loss, Field):
            impl.get_runtime().materialize()
            if len(self.loss.shape) != 0:
                raise RuntimeError("The loss of `Tape` must be a 0-D field, i.e. a scalar")
            if not self.loss.snode.ptr.has_adjoint():
                raise RuntimeError(
                    "Gradients of loss are not allocated, please use ti.field(..., needs_grad=True)"
                    " for all fields that are required by autodiff."
                )
            if self.clear_gradients:
                clear_all_gradients()
            if self.validation:
                clear_all_gradients(gradient_type=SNodeGradType.ADJOINT_CHECKBIT)

            self.loss.fill(0.0)
        elif isinstance(self.loss, Ndarray):
            if self.loss._get_nelement() != 1:
                raise RuntimeError("The loss of `Tape` must be an ndarray with only one element")
            if self.loss.grad is None:
                raise RuntimeError(
                    "Gradients of loss are not allocated, please set needs_grad=True for all ndarrays that are required by autodiff."
                )
            self.loss.fill(0.0)
        else:
            import torch  # pylint: disable=C0415

            if self.loss.numel() != 1:
                raise RuntimeError("The loss of `Tape` must be a tensor that contains only one element")
            if not self.loss.requires_grad:
                raise RuntimeError(
                    "Gradients of loss are not allocated, please set requires_grad=True for all tensors that are required by autodiff."
                )
            with torch.no_grad():
                self.loss.fill_(0.0)

        # Attach the context manager to the runtime
        self.runtime.target_tape = self
        return self

    def __exit__(self, _type, value, tb):
        self.runtime.target_tape = None
        if self.eval_on_exit:
            self.grad()
        for calls, mode in zip(self.calls, self.modes):
            calls[0].autodiff_mode = mode

    def insert(self, func, args):
        # Kernels with mode `AutodiffMode.NONE` and `AutodiffMode.VALIDATION` are all forward kernels.
        # The difference is that VALIDATION kernels contain `assert`s for the global data access rule check.
        assert func.autodiff_mode in (
            AutodiffMode.NONE,
            AutodiffMode.VALIDATION,
        ), "Inserted funcs should be forward kernels."
        self.modes.append(func.autodiff_mode)
        if self.validation:
            func.autodiff_mode = AutodiffMode.VALIDATION
        self.calls.append((func, args))

    def grad(self):
        assert self.entered, "Before evaluating gradients, the tape must be entered."
        assert not self.gradient_evaluated, "Gradients can be evaluated only once."

        # Set grad for loss
        if isinstance(self.loss, (Field, Ndarray)):
            self.loss.grad.fill(1.0)
        else:
            import torch  # pylint: disable=C0415

            if self.loss.grad is None:
                self.loss.grad = torch.ones_like(self.loss)
            else:
                with torch.no_grad():
                    self.loss.grad.fill_(1.0)

        for func, args in reversed(self.calls):
            # We need to check whether "func" has a "grad" attribute,
            # since we insert write_int and write_float kernels into self.calls
            # (e.g. x[None] = 0.0); such a func has no grad attribute.
            if hasattr(func, "grad"):
                func.grad(*args)

        self.gradient_evaluated = True
        if self.grad_checker:
            self.grad_checker.add_calls(self.calls)
            self.grad_checker.check_grad()


def clear_all_gradients(gradient_type=SNodeGradType.ADJOINT):
    """Sets the gradients of all fields to zero."""
    impl.get_runtime().materialize()

    def visit(node):
        places = []
        for _i in range(node.ptr.get_num_ch()):
            ch = node.ptr.get_ch(_i)
            if not ch.is_place():
                visit(SNode(ch))
            else:
                if ch.get_snode_grad_type() == gradient_type:
                    places.append(ch.get_expr())

        places = tuple(places)
        if places:
            from gstaichi._kernels import clear_gradients  # pylint: disable=C0415

            clear_gradients(places)

    for root_fb in _snode.FieldsBuilder._finalized_roots():
        visit(root_fb)


def grad_replaced(func):
    """A decorator for a Python function to customize its gradient with GsTaichi's autodiff
    system, e.g. `ti.ad.Tape()` and `kernel.grad()`.

    This decorator forces GsTaichi's autodiff system to use a user-defined gradient
    function for the decorated function. The customized gradient must be decorated
    by :func:`~gstaichi.ad.grad_for`.

    Args:
        func (Callable): The Python function to be decorated.

    Returns:
        Callable: The decorated function.

    Example::

        >>> @ti.kernel
        >>> def multiply(a: ti.float32):
        >>>     for I in ti.grouped(x):
        >>>         y[I] = x[I] * a
        >>>
        >>> @ti.kernel
        >>> def multiply_grad(a: ti.float32):
        >>>     for I in ti.grouped(x):
        >>>         x.grad[I] = y.grad[I] / a
        >>>
        >>> @ti.ad.grad_replaced
        >>> def foo(a):
        >>>     multiply(a)
        >>>
        >>> @ti.ad.grad_for(foo)
        >>> def foo_grad(a):
        >>>     multiply_grad(a)"""

    def decorated(*args, **kwargs):
        # TODO [#3025]: get rid of circular imports and move this to the top.
        impl.get_runtime().grad_replaced = True
        if impl.get_runtime().target_tape:
            impl.get_runtime().target_tape.insert(decorated, args)
        try:
            func(*args, **kwargs)
        finally:
            impl.get_runtime().grad_replaced = False

    decorated.grad = None
    decorated.autodiff_mode = AutodiffMode.NONE
    return decorated


def grad_for(primal):
    """Generates a decorator to decorate `primal`'s customized gradient function.

    See :func:`~gstaichi.ad.grad_replaced` for examples.

    Args:
        primal (Callable): The primal function, which must be decorated by :func:`~gstaichi.ad.grad_replaced`.

    Returns:
        Callable: The decorator used to decorate the customized gradient function."""

    def decorator(func):
        def decorated(*args, **kwargs):
            func(*args, **kwargs)

        if not hasattr(primal, "grad"):
            raise RuntimeError(f"Primal function `{primal.__name__}` must be decorated by ti.ad.grad_replaced")
        if primal.grad is not None:
            raise RuntimeError(
                "Primal function must be a **python** function instead of a gstaichi kernel. Please wrap the gstaichi kernel in a @ti.ad.grad_replaced decorated python function instead."
            )
        primal.grad = decorated
        return decorated

    return decorator


def no_grad(func):
    """A decorator for a Python function to skip gradient calculation within GsTaichi's
    autodiff system, e.g. `ti.ad.Tape()` and `kernel.grad()`.
    This decorator forces GsTaichi's autodiff system to use an empty gradient function
    for the decorated function.

    Args:
        func (Callable): The Python function to be decorated.

    Returns:
        Callable: The decorated function.

    Example::

        >>> @ti.kernel
        >>> def multiply(a: ti.float32):
        >>>     for I in ti.grouped(x):
        >>>         y[I] = x[I] * a
        >>>
        >>> @ti.no_grad
        >>> def foo(a):
        >>>     multiply(a)"""

    def decorated(*args, **kwargs):
        impl.get_runtime().grad_replaced = True
        if impl.get_runtime().target_tape:
            impl.get_runtime().target_tape.insert(decorated, args)
        try:
            func(*args, **kwargs)
        finally:
            impl.get_runtime().grad_replaced = False

    def placeholder(*args, **kwargs):
        return

    decorated.grad = placeholder
    decorated.autodiff_mode = AutodiffMode.NONE
    return decorated


class FwdMode:
    def __init__(self, loss, param, seed=None, clear_gradients=True):
        self.calls = []
        self.modes = []
        self.entered = False
        self.kernels_recovered = False
        self.runtime = impl.get_runtime()
        self.loss = loss
        self.param = param
        self.seed = seed
        self.clear_gradients = clear_gradients

    def __enter__(self):
        assert not self.entered, "Forward mode manager can be entered only once."
        self.entered = True
        impl.get_runtime().materialize()
        if not isinstance(self.loss, list):
            self.loss = [self.loss]
        for ls in self.loss:
            assert isinstance(ls, ScalarField)

        # Currently we support only one N-D field as a group of parameters,
        # which is sufficient for computing a Jacobian-vector product (Jvp).
        # Cases with multiple groups of parameters require running forward AD multiple times,
        # which is out of the scope of the current design of this interface.

        # TODO: support vector field and matrix field
        assert isinstance(self.param, ScalarField)

        def shape_flatten(shape):
            return reduce((lambda x, y: x * y), list(shape))

        # Handle 0-D field
        if len(self.param.shape) != 0:
            parameters_shape_flatten = shape_flatten(self.param.shape)
        else:
            parameters_shape_flatten = 1

        if not self.seed:
            if parameters_shape_flatten == 1:
                # Compute the derivative with respect to the first variable by default
                self.seed = [1.0]
            else:
                raise RuntimeError(
                    "`seed` is not set for a non 0-D field, please specify it."
                    " `seed` is a list specifying which parameters the derivatives are computed with respect to. The length of `seed` should be the same as that of the parameters."
                    " E.g. Given a loss `loss = ti.field(float, shape=3)` and a parameter `x = ti.field(float, shape=3)`,"
                    " seed = [0, 0, 1] computes the derivative with respect to the third element of `x`,"
                    " and seed = [1, 1, 1] computes the sum of derivatives with respect to all three elements of `x`, i.e., the Jacobian-vector product (Jvp) for each element in `loss`."
                )
        else:
            assert parameters_shape_flatten == len(self.seed)

        # Clear gradients
        if self.clear_gradients:
            clear_all_gradients(gradient_type=SNodeGradType.DUAL)

        # Set the seed for each variable
        if len(self.seed) == 1:
            if len(self.param.shape) == 0:
                # e.g., x = ti.field(float, shape=())
                self.param.dual[None] = 1.0 * self.seed[0]
            else:
                # e.g., ti.root.dense(ti.i, 1).place(x.dual)
                self.param.dual[0] = 1.0 * self.seed[0]
        else:
            self.param.dual.from_numpy(np.array(self.seed, dtype=np.float32))

        # Attach the context manager to the runtime
        self.runtime.fwd_mode_manager = self

    def __exit__(self, _type, value, tb):
        self.runtime.fwd_mode_manager = None
        self.clear_seed()
        self.recover_kernels()

    def insert(self, func):
        assert (
            func.autodiff_mode == AutodiffMode.NONE or func.autodiff_mode == AutodiffMode.FORWARD
        ), "Inserted funcs should be forward or grad kernels (forward mode)."
        self.modes.append(func.autodiff_mode)
        func.autodiff_mode = AutodiffMode.FORWARD
        self.calls.append(func)

    def recover_kernels(self):
        assert self.entered, "Before recovering the kernels, the fwd mode manager must be entered."
        for calls, mode in zip(self.calls, self.modes):
            calls.autodiff_mode = mode
        self.kernels_recovered = True

    def clear_seed(self):
        # clear seed values
        if len(self.seed) == 1:
            if len(self.param.shape) == 0:
                # e.g., x = ti.field(float, shape=())
                self.param.dual[None] = 0.0
            else:
                # e.g., ti.root.dense(ti.i, 1).place(x.dual)
                self.param.dual[0] = 0.0
        else:
            self.param.dual.fill(0)


__all__ = [
    "FwdMode",
    "Tape",
    "clear_all_gradients",
    "grad_for",
    "grad_replaced",
    "no_grad",
]
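As a reading aid for the reverse-mode path above (Tape.__enter__ clears gradients and zeroes the loss, Tape.insert records each kernel launch, and Tape.grad replays the recorded launches backwards through their grad kernels), here is a small hypothetical usage sketch. It assumes the Taichi-style top-level API (`ti.init`, `ti.field`, `@ti.kernel`) that the error messages and docstrings in this file refer to; the field names and shapes are illustrative, not part of the package:

# Hypothetical usage sketch (editorial illustration, not shipped in the wheel).
import gstaichi as ti

ti.init(debug=True)  # debug=True is needed for Tape(validation=True) to actually validate

x = ti.field(dtype=ti.f32, shape=16, needs_grad=True)
loss = ti.field(dtype=ti.f32, shape=(), needs_grad=True)


@ti.kernel
def compute_loss():
    for i in x:
        loss[None] += x[i] ** 2


x.fill(1.0)

# Reverse mode: the Tape records compute_loss() into self.calls while the block runs
# and, because a loss was passed, calls self.grad() on exit, which seeds loss.grad
# with 1.0 and replays the recorded kernels' grad functions in reverse order.
with ti.ad.Tape(loss=loss):
    compute_loss()

print(x.grad.to_numpy())  # expected: d(loss)/dx_i = 2 * x_i = 2.0 everywhere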