gstaichi-0.1.18.dev1-cp310-cp310-win_amd64.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools/cmake/SPIRV-ToolsConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget-release.cmake +29 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools/cmake/SPIRV-ToolsTarget.cmake +113 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-diff/cmake/SPIRV-Tools-diffTargets.cmake +122 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-link/cmake/SPIRV-Tools-linkTargets.cmake +122 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-lint/cmake/SPIRV-Tools-lintTargets.cmake +122 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-opt/cmake/SPIRV-Tools-optTargets.cmake +122 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceConfig.cmake +5 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/SPIRV-Tools-reduce/cmake/SPIRV-Tools-reduceTarget.cmake +122 -0
- gstaichi-0.1.18.dev1.data/data/bin/SPIRV-Tools-shared.dll +0 -0
- gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3.h +6389 -0
- gstaichi-0.1.18.dev1.data/data/include/GLFW/glfw3native.h +594 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/instrument.hpp +268 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.h +907 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/libspirv.hpp +375 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/linker.hpp +97 -0
- gstaichi-0.1.18.dev1.data/data/include/spirv-tools/optimizer.hpp +970 -0
- gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-diff.lib +0 -0
- gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-link.lib +0 -0
- gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-lint.lib +0 -0
- gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-opt.lib +0 -0
- gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-reduce.lib +0 -0
- gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools-shared.lib +0 -0
- gstaichi-0.1.18.dev1.data/data/lib/SPIRV-Tools.lib +0 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Config.cmake +3 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3ConfigVersion.cmake +65 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets-release.cmake +19 -0
- gstaichi-0.1.18.dev1.data/data/lib/cmake/glfw3/glfw3Targets.cmake +107 -0
- gstaichi-0.1.18.dev1.data/data/lib/glfw3.lib +0 -0
- gstaichi-0.1.18.dev1.dist-info/METADATA +108 -0
- gstaichi-0.1.18.dev1.dist-info/RECORD +198 -0
- gstaichi-0.1.18.dev1.dist-info/WHEEL +5 -0
- gstaichi-0.1.18.dev1.dist-info/entry_points.txt +2 -0
- gstaichi-0.1.18.dev1.dist-info/licenses/LICENSE +201 -0
- gstaichi-0.1.18.dev1.dist-info/top_level.txt +1 -0
- taichi/CHANGELOG.md +15 -0
- taichi/__init__.py +44 -0
- taichi/__main__.py +5 -0
- taichi/_funcs.py +706 -0
- taichi/_kernels.py +420 -0
- taichi/_lib/__init__.py +3 -0
- taichi/_lib/c_api/bin/taichi_c_api.dll +0 -0
- taichi/_lib/c_api/include/taichi/cpp/taichi.hpp +1401 -0
- taichi/_lib/c_api/include/taichi/taichi.h +29 -0
- taichi/_lib/c_api/include/taichi/taichi_core.h +1111 -0
- taichi/_lib/c_api/include/taichi/taichi_cpu.h +29 -0
- taichi/_lib/c_api/include/taichi/taichi_cuda.h +36 -0
- taichi/_lib/c_api/include/taichi/taichi_platform.h +55 -0
- taichi/_lib/c_api/include/taichi/taichi_unity.h +64 -0
- taichi/_lib/c_api/include/taichi/taichi_vulkan.h +151 -0
- taichi/_lib/c_api/lib/taichi_c_api.lib +0 -0
- taichi/_lib/c_api/runtime/runtime_cuda.bc +0 -0
- taichi/_lib/c_api/runtime/runtime_x64.bc +0 -0
- taichi/_lib/c_api/runtime/slim_libdevice.10.bc +0 -0
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfig.cmake +29 -0
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiConfigVersion.cmake +65 -0
- taichi/_lib/c_api/taichi/lib/cmake/taichi/TaichiTargets.cmake +121 -0
- taichi/_lib/core/__init__.py +0 -0
- taichi/_lib/core/py.typed +0 -0
- taichi/_lib/core/taichi_python.cp310-win_amd64.pyd +0 -0
- taichi/_lib/core/taichi_python.pyi +3077 -0
- taichi/_lib/runtime/runtime_cuda.bc +0 -0
- taichi/_lib/runtime/runtime_x64.bc +0 -0
- taichi/_lib/runtime/slim_libdevice.10.bc +0 -0
- taichi/_lib/utils.py +249 -0
- taichi/_logging.py +131 -0
- taichi/_main.py +552 -0
- taichi/_snode/__init__.py +5 -0
- taichi/_snode/fields_builder.py +189 -0
- taichi/_snode/snode_tree.py +34 -0
- taichi/_ti_module/__init__.py +3 -0
- taichi/_ti_module/cppgen.py +309 -0
- taichi/_ti_module/module.py +145 -0
- taichi/_version.py +1 -0
- taichi/_version_check.py +100 -0
- taichi/ad/__init__.py +3 -0
- taichi/ad/_ad.py +530 -0
- taichi/algorithms/__init__.py +3 -0
- taichi/algorithms/_algorithms.py +117 -0
- taichi/aot/__init__.py +12 -0
- taichi/aot/_export.py +28 -0
- taichi/aot/conventions/__init__.py +3 -0
- taichi/aot/conventions/gfxruntime140/__init__.py +38 -0
- taichi/aot/conventions/gfxruntime140/dr.py +244 -0
- taichi/aot/conventions/gfxruntime140/sr.py +613 -0
- taichi/aot/module.py +253 -0
- taichi/aot/utils.py +151 -0
- taichi/assets/.git +1 -0
- taichi/assets/Go-Regular.ttf +0 -0
- taichi/assets/static/imgs/ti_gallery.png +0 -0
- taichi/examples/minimal.py +28 -0
- taichi/experimental.py +16 -0
- taichi/graph/__init__.py +3 -0
- taichi/graph/_graph.py +292 -0
- taichi/lang/__init__.py +50 -0
- taichi/lang/_ndarray.py +348 -0
- taichi/lang/_ndrange.py +152 -0
- taichi/lang/_texture.py +172 -0
- taichi/lang/_wrap_inspect.py +189 -0
- taichi/lang/any_array.py +99 -0
- taichi/lang/argpack.py +411 -0
- taichi/lang/ast/__init__.py +5 -0
- taichi/lang/ast/ast_transformer.py +1806 -0
- taichi/lang/ast/ast_transformer_utils.py +328 -0
- taichi/lang/ast/checkers.py +106 -0
- taichi/lang/ast/symbol_resolver.py +57 -0
- taichi/lang/ast/transform.py +9 -0
- taichi/lang/common_ops.py +310 -0
- taichi/lang/exception.py +80 -0
- taichi/lang/expr.py +180 -0
- taichi/lang/field.py +464 -0
- taichi/lang/impl.py +1246 -0
- taichi/lang/kernel_arguments.py +157 -0
- taichi/lang/kernel_impl.py +1415 -0
- taichi/lang/matrix.py +1877 -0
- taichi/lang/matrix_ops.py +341 -0
- taichi/lang/matrix_ops_utils.py +190 -0
- taichi/lang/mesh.py +687 -0
- taichi/lang/misc.py +807 -0
- taichi/lang/ops.py +1489 -0
- taichi/lang/runtime_ops.py +13 -0
- taichi/lang/shell.py +35 -0
- taichi/lang/simt/__init__.py +5 -0
- taichi/lang/simt/block.py +94 -0
- taichi/lang/simt/grid.py +7 -0
- taichi/lang/simt/subgroup.py +191 -0
- taichi/lang/simt/warp.py +96 -0
- taichi/lang/snode.py +487 -0
- taichi/lang/source_builder.py +150 -0
- taichi/lang/struct.py +855 -0
- taichi/lang/util.py +381 -0
- taichi/linalg/__init__.py +8 -0
- taichi/linalg/matrixfree_cg.py +310 -0
- taichi/linalg/sparse_cg.py +59 -0
- taichi/linalg/sparse_matrix.py +303 -0
- taichi/linalg/sparse_solver.py +123 -0
- taichi/math/__init__.py +11 -0
- taichi/math/_complex.py +204 -0
- taichi/math/mathimpl.py +886 -0
- taichi/profiler/__init__.py +6 -0
- taichi/profiler/kernel_metrics.py +260 -0
- taichi/profiler/kernel_profiler.py +592 -0
- taichi/profiler/memory_profiler.py +15 -0
- taichi/profiler/scoped_profiler.py +36 -0
- taichi/shaders/Circles_vk.frag +29 -0
- taichi/shaders/Circles_vk.vert +45 -0
- taichi/shaders/Circles_vk_frag.spv +0 -0
- taichi/shaders/Circles_vk_vert.spv +0 -0
- taichi/shaders/Lines_vk.frag +9 -0
- taichi/shaders/Lines_vk.vert +11 -0
- taichi/shaders/Lines_vk_frag.spv +0 -0
- taichi/shaders/Lines_vk_vert.spv +0 -0
- taichi/shaders/Mesh_vk.frag +71 -0
- taichi/shaders/Mesh_vk.vert +68 -0
- taichi/shaders/Mesh_vk_frag.spv +0 -0
- taichi/shaders/Mesh_vk_vert.spv +0 -0
- taichi/shaders/Particles_vk.frag +95 -0
- taichi/shaders/Particles_vk.vert +73 -0
- taichi/shaders/Particles_vk_frag.spv +0 -0
- taichi/shaders/Particles_vk_vert.spv +0 -0
- taichi/shaders/SceneLines2quad_vk_comp.spv +0 -0
- taichi/shaders/SceneLines_vk.frag +9 -0
- taichi/shaders/SceneLines_vk.vert +12 -0
- taichi/shaders/SceneLines_vk_frag.spv +0 -0
- taichi/shaders/SceneLines_vk_vert.spv +0 -0
- taichi/shaders/SetImage_vk.frag +21 -0
- taichi/shaders/SetImage_vk.vert +15 -0
- taichi/shaders/SetImage_vk_frag.spv +0 -0
- taichi/shaders/SetImage_vk_vert.spv +0 -0
- taichi/shaders/Triangles_vk.frag +16 -0
- taichi/shaders/Triangles_vk.vert +29 -0
- taichi/shaders/Triangles_vk_frag.spv +0 -0
- taichi/shaders/Triangles_vk_vert.spv +0 -0
- taichi/shaders/lines2quad_vk_comp.spv +0 -0
- taichi/sparse/__init__.py +3 -0
- taichi/sparse/_sparse_grid.py +77 -0
- taichi/tools/__init__.py +12 -0
- taichi/tools/diagnose.py +124 -0
- taichi/tools/np2ply.py +364 -0
- taichi/tools/vtk.py +38 -0
- taichi/types/__init__.py +19 -0
- taichi/types/annotations.py +47 -0
- taichi/types/compound_types.py +90 -0
- taichi/types/enums.py +49 -0
- taichi/types/ndarray_type.py +147 -0
- taichi/types/primitive_types.py +203 -0
- taichi/types/quant.py +88 -0
- taichi/types/texture_type.py +85 -0
- taichi/types/utils.py +13 -0
taichi/algorithms/_algorithms.py
ADDED
@@ -0,0 +1,117 @@
# type: ignore

from taichi._kernels import (
    blit_from_field_to_field,
    scan_add_inclusive,
    sort_stage,
    uniform_add,
    warp_shfl_up_i32,
)
from taichi.lang.impl import current_cfg, field
from taichi.lang.kernel_impl import data_oriented
from taichi.lang.misc import cuda, vulkan
from taichi.lang.runtime_ops import sync
from taichi.lang.simt import subgroup
from taichi.types.primitive_types import i32


def parallel_sort(keys, values=None):
    """Odd-even merge sort

    References:
        https://developer.nvidia.com/gpugems/gpugems2/part-vi-simulation-and-numerical-algorithms/chapter-46-improved-gpu-sorting
        https://en.wikipedia.org/wiki/Batcher_odd%E2%80%93even_mergesort
    """
    N = keys.shape[0]

    num_stages = 0
    p = 1
    while p < N:
        k = p
        while k >= 1:
            invocations = int((N - k - k % p) / (2 * k)) + 1
            if values is None:
                sort_stage(keys, 0, keys, N, p, k, invocations)
            else:
                sort_stage(keys, 1, values, N, p, k, invocations)
            num_stages += 1
            sync()
            k = int(k / 2)
        p = int(p * 2)


@data_oriented
class PrefixSumExecutor:
    """Parallel Prefix Sum (Scan) Helper

    Use this helper to perform an inclusive, in-place parallel prefix sum.

    References:
        https://developer.download.nvidia.com/compute/cuda/1.1-Beta/x86_website/projects/scan/doc/scan.pdf
        https://github.com/NVIDIA/cuda-samples/blob/master/Samples/2_Concepts_and_Techniques/shfl_scan/shfl_scan.cu
    """

    def __init__(self, length):
        self.sorting_length = length

        BLOCK_SZ = 64
        GRID_SZ = int((length + BLOCK_SZ - 1) / BLOCK_SZ)

        # Buffer position and length
        # This is a single-buffer implementation for ease of AOT usage
        ele_num = length
        self.ele_nums = [ele_num]
        start_pos = 0
        self.ele_nums_pos = [start_pos]

        while ele_num > 1:
            ele_num = int((ele_num + BLOCK_SZ - 1) / BLOCK_SZ)
            self.ele_nums.append(ele_num)
            start_pos += BLOCK_SZ * ele_num
            self.ele_nums_pos.append(start_pos)

        self.large_arr = field(i32, shape=start_pos)

    def run(self, input_arr):
        length = self.sorting_length
        ele_nums = self.ele_nums
        ele_nums_pos = self.ele_nums_pos

        if input_arr.dtype != i32:
            raise RuntimeError("Only ti.i32 type is supported for prefix sum.")

        if current_cfg().arch == cuda:
            inclusive_add = warp_shfl_up_i32
        elif current_cfg().arch == vulkan:
            inclusive_add = subgroup.inclusive_add
        else:
            raise RuntimeError(f"{str(current_cfg().arch)} is not supported for prefix sum.")

        blit_from_field_to_field(self.large_arr, input_arr, 0, length)

        # Kogge-Stone construction
        for i in range(len(ele_nums) - 1):
            if i == len(ele_nums) - 2:
                scan_add_inclusive(
                    self.large_arr,
                    ele_nums_pos[i],
                    ele_nums_pos[i + 1],
                    True,
                    inclusive_add,
                )
            else:
                scan_add_inclusive(
                    self.large_arr,
                    ele_nums_pos[i],
                    ele_nums_pos[i + 1],
                    False,
                    inclusive_add,
                )

        for i in range(len(ele_nums) - 3, -1, -1):
            uniform_add(self.large_arr, ele_nums_pos[i], ele_nums_pos[i + 1])

        blit_from_field_to_field(input_arr, self.large_arr, 0, length)


__all__ = ["parallel_sort", "PrefixSumExecutor"]
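For orientation, a minimal usage sketch of the two exports above (not part of the package; it assumes the `ti.algorithms` re-export from `taichi/algorithms/__init__.py` and a CUDA or Vulkan backend, since `PrefixSumExecutor.run` rejects anything else):

# Hypothetical usage sketch; field shape and values are illustrative.
import taichi as ti

ti.init(arch=ti.cuda)

n = 1024
keys = ti.field(ti.i32, shape=n)

@ti.kernel
def fill():
    for i in keys:
        keys[i] = n - i

fill()
ti.algorithms.parallel_sort(keys)  # in-place odd-even merge sort

executor = ti.algorithms.PrefixSumExecutor(n)
executor.run(keys)  # in-place inclusive prefix sum; i32 fields only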
taichi/aot/__init__.py
ADDED
@@ -0,0 +1,12 @@
# type: ignore

"""Taichi's AOT (ahead of time) module.

Users can use Taichi as a GPU compute shader/kernel compiler by compiling their
Taichi kernels into an AOT module.
"""

import taichi.aot.conventions
from taichi.aot._export import export, export_as
from taichi.aot.conventions.gfxruntime140 import GfxRuntime140
from taichi.aot.module import Module
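The docstring describes the AOT workflow only in prose; a sketch of the typical flow follows, assuming this package keeps upstream Taichi's `Module.add_kernel`/`Module.save` API (kernel body and output path are illustrative):

# Hypothetical usage sketch; assumes upstream Taichi's Module API.
import taichi as ti

ti.init(arch=ti.vulkan)

@ti.kernel
def fill(x: ti.types.ndarray(dtype=ti.f32, ndim=1)):
    for i in x:
        x[i] = i

m = ti.aot.Module()
m.add_kernel(fill)
m.save("my_module")  # emits the metadata.json/graphs.json that GfxRuntime140 reads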
taichi/aot/_export.py
ADDED
@@ -0,0 +1,28 @@
# type: ignore

from typing import Any, Dict, List, Optional


class AotExportKernel:
    def __init__(self, f, name: str, template_types: Dict[str, Any]) -> None:
        self.kernel = f
        self.name = name
        self.template_types = template_types


_aot_kernels: List[AotExportKernel] = []


def export_as(name: str, *, template_types: Optional[Dict[str, Any]] = None):
    def inner(f):
        assert hasattr(f, "_is_wrapped_kernel"), "Only Taichi kernels can be exported"

        record = AotExportKernel(f, name, template_types or {})
        _aot_kernels.append(record)
        return f

    return inner


def export(f):
    return export_as(f.__name__)(f)
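Because `export_as` asserts `_is_wrapped_kernel`, the decorator has to sit above an already-built `@ti.kernel`; a hedged sketch (kernel bodies are illustrative):

# Hypothetical usage sketch, not part of the diffed package.
import taichi as ti

@ti.aot.export  # registers under the function's own name via export_as(f.__name__)
@ti.kernel
def add_one(x: ti.types.ndarray(dtype=ti.i32, ndim=1)):
    for i in x:
        x[i] += 1

@ti.aot.export_as("scale2")  # registers under an explicit name
@ti.kernel
def scale_by_two(x: ti.types.ndarray(dtype=ti.i32, ndim=1)):
    for i in x:
        x[i] *= 2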
taichi/aot/conventions/gfxruntime140/__init__.py
ADDED
@@ -0,0 +1,38 @@
# type: ignore

import json
import zipfile
from pathlib import Path
from typing import Any, List

from taichi.aot.conventions.gfxruntime140 import dr, sr


class GfxRuntime140:
    def __init__(self, metadata_json: Any, graphs_json: Any) -> None:
        metadata = dr.from_json_metadata(metadata_json)
        graphs = [dr.from_json_graph(x) for x in graphs_json]
        self.metadata = sr.from_dr_metadata(metadata)
        self.graphs = [sr.from_dr_graph(self.metadata, x) for x in graphs]

    @staticmethod
    def from_module(module_path: str) -> "GfxRuntime140":
        if Path(module_path).is_file():
            with zipfile.ZipFile(module_path) as z:
                with z.open("metadata.json") as f:
                    metadata_json = json.load(f)
                with z.open("graphs.json") as f:
                    graphs_json = json.load(f)
        else:
            with open(f"{module_path}/metadata.json") as f:
                metadata_json = json.load(f)
            with open(f"{module_path}/graphs.json") as f:
                graphs_json = json.load(f)

        return GfxRuntime140(metadata_json, graphs_json)

    def to_metadata_json(self) -> Any:
        return dr.to_json_metadata(sr.to_dr_metadata(self.metadata))

    def to_graphs_json(self) -> List[Any]:
        return [dr.to_json_graph(sr.to_dr_graph(x)) for x in self.graphs]
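A round-trip sketch using only the methods defined above ("my_module" is a placeholder for a real AOT module, either the zip archive or the unpacked directory):

# Hypothetical usage sketch, not part of the diffed package.
from taichi.aot import GfxRuntime140

rt = GfxRuntime140.from_module("my_module")  # zip file or directory holding metadata.json/graphs.json
metadata_json = rt.to_metadata_json()        # structured repr -> data repr -> plain JSON
graphs_json = rt.to_graphs_json()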
taichi/aot/conventions/gfxruntime140/dr.py
ADDED
@@ -0,0 +1,244 @@
# type: ignore

"""
Data representation of all JSON data structures following the GfxRuntime140
convention.
"""

from typing import Any, Dict, List, Optional

from taichi.aot.utils import dump_json_data_model, json_data_model


@json_data_model
class FieldAttributes:
    def __init__(self, j: Dict[str, Any]) -> None:
        dtype = j["dtype"]
        dtype_name = j["dtype_name"]
        element_shape = j["element_shape"]
        field_name = j["field_name"]
        is_scalar = j["is_scalar"]
        mem_offset_in_parent = j["mem_offset_in_parent"]
        shape = j["shape"]

        self.dtype: int = int(dtype)
        self.dtype_name: str = str(dtype_name)
        self.element_shape: List[int] = [int(x) for x in element_shape]
        self.field_name: str = str(field_name)
        self.is_scalar: bool = bool(is_scalar)
        self.mem_offset_in_parent: int = int(mem_offset_in_parent)
        self.shape: List[int] = [int(x) for x in shape]


@json_data_model
class ArgumentAttributes:
    def __init__(self, j: Dict[str, Any]) -> None:
        index = j["key"][0]
        dtype = j["value"]["dtype"]
        element_shape = j["value"]["element_shape"]
        field_dim = j["value"]["field_dim"]
        fmt = j["value"]["format"]
        is_array = j["value"]["is_array"]
        offset_in_mem = j["value"]["offset_in_mem"]
        stride = j["value"]["stride"]
        # (penguinliong) Note that the name field is optional for kernels.
        # Kernels are always launched by indexed arguments and this is for
        # debugging and header generation only.
        name = j["value"]["name"] if "name" in j["value"] and len(j["value"]["name"]) > 0 else None
        ptype = j["value"]["ptype"] if "ptype" in j["value"] else None

        self.dtype: int = int(dtype)
        self.element_shape: List[int] = [int(x) for x in element_shape]
        self.field_dim: int = int(field_dim)
        self.format: int = int(fmt)
        self.index: int = int(index)
        self.is_array: bool = bool(is_array)
        self.offset_in_mem: int = int(offset_in_mem)
        self.stride: int = int(stride)
        self.name: Optional[str] = str(name) if name is not None else None
        self.ptype: Optional[int] = int(ptype) if ptype is not None else None


@json_data_model
class ContextAttributes:
    def __init__(self, j: Dict[str, Any]) -> None:
        arg_attribs_vec_ = j["arg_attribs_vec_"]
        args_bytes_ = j["args_bytes_"]
        arr_access = j["arr_access"]
        ret_attribs_vec_ = j["ret_attribs_vec_"]
        rets_bytes_ = j["rets_bytes_"]

        self.arg_attribs_vec_: List[ArgumentAttributes] = [ArgumentAttributes(x) for x in arg_attribs_vec_]
        self.arg_attribs_vec_.sort(key=lambda x: x.index)
        self.args_bytes_: int = int(args_bytes_)
        self.arr_access: List[int] = [int(x["value"]) for x in arr_access]
        self.ret_attribs_vec_: List[ArgumentAttributes] = [ArgumentAttributes(x) for x in ret_attribs_vec_]
        self.rets_bytes_: int = int(rets_bytes_)


@json_data_model
class Buffer:
    def __init__(self, j: Dict[str, Any]) -> None:
        root_id = j["root_id"][0]
        ty = j["type"]

        self.root_id: int = int(root_id)
        self.type: int = int(ty)


@json_data_model
class BufferBinding:
    def __init__(self, j: Dict[str, Any]) -> None:
        binding = j["binding"]
        buffer = j["buffer"]

        self.binding: int = int(binding)
        self.buffer: Buffer = Buffer(buffer)


@json_data_model
class TextureBinding:
    def __init__(self, j: Dict[str, Any]) -> None:
        arg_id = j["arg_id"]
        binding = j["binding"]
        is_storage = j["is_storage"]

        self.arg_id: int = int(arg_id)
        self.binding: int = int(binding)
        self.is_storage: bool = bool(is_storage)


@json_data_model
class RangeForAttributes:
    def __init__(self, j: Dict[str, Any]) -> None:
        begin = j["begin"]
        const_begin = j["const_begin"]
        const_end = j["const_end"]
        end = j["end"]

        self.begin: int = int(begin)
        self.const_begin: bool = bool(const_begin)
        self.const_end: bool = bool(const_end)
        self.end: int = int(end)


@json_data_model
class TaskAttributes:
    def __init__(self, j: Dict[str, Any]) -> None:
        advisory_num_threads_per_group = j["advisory_num_threads_per_group"]
        advisory_total_num_threads = j["advisory_total_num_threads"]
        buffer_binds = j["buffer_binds"]
        name = j["name"]
        range_for_attribs = j["range_for_attribs"] if "range_for_attribs" in j else None
        task_type = j["task_type"]
        texture_binds = j["texture_binds"]

        self.advisory_num_threads_per_group: int = int(advisory_num_threads_per_group)
        self.advisory_total_num_threads: int = int(advisory_total_num_threads)
        self.buffer_binds: List[BufferBinding] = [BufferBinding(x) for x in buffer_binds]
        self.name: str = str(name)
        self.range_for_attribs: Optional[RangeForAttributes] = (
            RangeForAttributes(range_for_attribs) if range_for_attribs is not None else None
        )
        self.task_type: int = int(task_type)
        self.texture_binds: List[TextureBinding] = [TextureBinding(x) for x in texture_binds]


@json_data_model
class DeviceCapabilityLevel:
    def __init__(self, j: Dict[str, Any]) -> None:
        key = j["key"]
        value = j["value"]

        self.key: str = str(key)
        self.value: int = int(value)


@json_data_model
class KernelAttributes:
    def __init__(self, j: Dict[str, Any]) -> None:
        ctx_attribs = j["ctx_attribs"]
        is_jit_evaluator = j["is_jit_evaluator"]
        name = j["name"]
        tasks_attribs = j["tasks_attribs"]

        self.ctx_attribs: ContextAttributes = ContextAttributes(ctx_attribs)
        self.is_jit_evaluator: bool = bool(is_jit_evaluator)
        self.name: str = str(name)
        self.tasks_attribs: List[TaskAttributes] = [TaskAttributes(x) for x in tasks_attribs]


@json_data_model
class Metadata:
    def __init__(self, j: Dict[str, Any]) -> None:
        fields = j["fields"]
        kernels = j["kernels"]
        required_caps = j["required_caps"]
        root_buffer_size = j["root_buffer_size"]

        self.fields: List[FieldAttributes] = [FieldAttributes(x) for x in fields]
        self.kernels: List[KernelAttributes] = [KernelAttributes(x) for x in kernels]
        self.required_caps: List[DeviceCapabilityLevel] = [DeviceCapabilityLevel(x) for x in required_caps]
        self.root_buffer_size: int = int(root_buffer_size)


def from_json_metadata(j: Dict[str, Any]) -> Metadata:
    return Metadata(j)


def to_json_metadata(meta_data: Metadata) -> Dict[str, Any]:
    return dump_json_data_model(meta_data)


@json_data_model
class SymbolicArgument:
    def __init__(self, j: Dict[str, Any]) -> None:
        dtype_id = j["dtype_id"]
        element_shape = j["element_shape"]
        field_dim = j["field_dim"]
        name = j["name"]
        num_channels = j["num_channels"]
        tag = j["tag"]

        self.dtype_id: int = int(dtype_id)
        self.element_shape: List[int] = [int(x) for x in element_shape]
        self.field_dim: int = int(field_dim)
        self.name: str = str(name)
        self.num_channels: int = int(num_channels)
        self.tag: int = int(tag)


@json_data_model
class Dispatch:
    def __init__(self, j: Dict[str, Any]) -> None:
        kernel_name = j["kernel_name"]
        symbolic_args = j["symbolic_args"]

        self.kernel_name: str = str(kernel_name)
        self.symbolic_args: List[SymbolicArgument] = [SymbolicArgument(x) for x in symbolic_args]


@json_data_model
class GraphData:
    def __init__(self, j: Dict[str, Any]) -> None:
        dispatches = j["dispatches"]

        self.dispatches = [Dispatch(x) for x in dispatches]


@json_data_model
class Graph:
    def __init__(self, j: Dict[str, Any]) -> None:
        key = j["key"]
        value = j["value"]

        self.key = str(key)
        self.value = GraphData(value)


def from_json_graph(j: Dict[str, Any]) -> Graph:
    return Graph(j)


def to_json_graph(graph: Graph) -> Dict[str, Any]:
    return dump_json_data_model(graph)
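Given the constructors above, even an empty module's metadata round-trips; a minimal (hypothetical) sketch:

# Hypothetical round-trip sketch, not part of the diffed package.
from taichi.aot.conventions.gfxruntime140 import dr

minimal = {"fields": [], "kernels": [], "required_caps": [], "root_buffer_size": 0}
meta = dr.from_json_metadata(minimal)
assert meta.root_buffer_size == 0 and meta.kernels == []
dumped = dr.to_json_metadata(meta)  # dump_json_data_model serializes the model back to plain JSON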