ninetoothed 0.5.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ninetoothed/jit.py +34 -14
- ninetoothed/tensor.py +4 -1
- {ninetoothed-0.5.0.dist-info → ninetoothed-0.6.0.dist-info}/METADATA +1 -1
- ninetoothed-0.6.0.dist-info/RECORD +10 -0
- ninetoothed-0.5.0.dist-info/RECORD +0 -10
- {ninetoothed-0.5.0.dist-info → ninetoothed-0.6.0.dist-info}/WHEEL +0 -0
- {ninetoothed-0.5.0.dist-info → ninetoothed-0.6.0.dist-info}/licenses/LICENSE +0 -0
ninetoothed/jit.py
CHANGED
@@ -1,9 +1,11 @@
|
|
1
1
|
import ast
|
2
2
|
import collections
|
3
3
|
import functools
|
4
|
+
import importlib.util
|
4
5
|
import inspect
|
5
6
|
import itertools
|
6
7
|
import math
|
8
|
+
import sys
|
7
9
|
import tempfile
|
8
10
|
|
9
11
|
import triton
|
@@ -41,24 +43,20 @@ class JIT:
|
|
41
43
|
ast.fix_missing_locations(tree)
|
42
44
|
|
43
45
|
unparsed = ast.unparse(tree).replace("None:", ":").replace(":None", ":")
|
46
|
+
dependencies = self._find_dependencies()
|
47
|
+
source = "\n\n".join((unparsed, dependencies)).strip()
|
44
48
|
|
45
49
|
with tempfile.NamedTemporaryFile(delete=False, suffix=".py") as temp_file:
|
46
|
-
temp_file.write(
|
50
|
+
temp_file.write(source.encode("utf-8"))
|
47
51
|
temp_file_name = temp_file.name
|
48
52
|
|
49
|
-
|
50
|
-
|
51
|
-
source=temp_file.read(),
|
52
|
-
filename=temp_file_name,
|
53
|
-
mode="exec",
|
54
|
-
)
|
55
|
-
|
56
|
-
namespace = {}
|
57
|
-
exec(code, namespace)
|
53
|
+
module = type(self)._import_from_path(temp_file_name, temp_file_name)
|
54
|
+
module_vars = vars(module)
|
58
55
|
|
59
56
|
handle = _Handle(
|
60
|
-
|
61
|
-
|
57
|
+
module_vars[self.func.__name__],
|
58
|
+
module_vars[f"launch_{self.func.__name__}"],
|
59
|
+
source,
|
62
60
|
)
|
63
61
|
|
64
62
|
type(self).handles[source_file][source_line] = handle
|
@@ -74,6 +72,24 @@ class JIT:
|
|
74
72
|
|
75
73
|
return ast.Module(body=[finder.result], type_ignores=[])
|
76
74
|
|
75
|
+
def _find_dependencies(self):
|
76
|
+
dependencies = set()
|
77
|
+
|
78
|
+
for obj in self.func.__globals__.values():
|
79
|
+
if isinstance(obj, triton.runtime.JITFunction):
|
80
|
+
dependencies.add(obj.src)
|
81
|
+
|
82
|
+
return "\n".join(f"@triton.jit\n{dependency}" for dependency in dependencies)
|
83
|
+
|
84
|
+
@staticmethod
|
85
|
+
def _import_from_path(module_name, file_path):
|
86
|
+
spec = importlib.util.spec_from_file_location(module_name, file_path)
|
87
|
+
module = importlib.util.module_from_spec(spec)
|
88
|
+
sys.modules[module_name] = module
|
89
|
+
spec.loader.exec_module(module)
|
90
|
+
|
91
|
+
return module
|
92
|
+
|
77
93
|
|
78
94
|
class CodeGenerator(ast.NodeTransformer):
|
79
95
|
def __init__(self, context):
|
@@ -102,7 +118,7 @@ class CodeGenerator(ast.NodeTransformer):
|
|
102
118
|
self.generic_visit(node)
|
103
119
|
|
104
120
|
for arg in self._args:
|
105
|
-
if not isinstance(arg, Tensor):
|
121
|
+
if not isinstance(arg, Tensor) or arg.ndim == 0:
|
106
122
|
continue
|
107
123
|
|
108
124
|
offsets = arg.offsets()
|
@@ -355,6 +371,9 @@ class CodeGenerator(ast.NodeTransformer):
|
|
355
371
|
|
356
372
|
@staticmethod
|
357
373
|
def _generate_load(tensor, intermediate_indices=()):
|
374
|
+
if tensor.ndim == 0:
|
375
|
+
return Symbol(tensor.original.name).node
|
376
|
+
|
358
377
|
pointers, mask = CodeGenerator._generate_pointers_and_mask(
|
359
378
|
tensor, intermediate_indices
|
360
379
|
)
|
@@ -459,9 +478,10 @@ class Tritonizer(ast.NodeTransformer):
|
|
459
478
|
|
460
479
|
|
461
480
|
class _Handle:
|
462
|
-
def __init__(self, kernel, launch):
|
481
|
+
def __init__(self, kernel, launch, source):
|
463
482
|
self._kernel = kernel
|
464
483
|
self._launch = launch
|
484
|
+
self._source = source
|
465
485
|
|
466
486
|
def __call__(self, *args, **kwargs):
|
467
487
|
return self._launch(*args, **kwargs)
|
ninetoothed/tensor.py
CHANGED
@@ -21,7 +21,7 @@ class Tensor:
|
|
21
21
|
|
22
22
|
self.dtype = dtype
|
23
23
|
|
24
|
-
self.name = f"
|
24
|
+
self.name = f"_ninetoothed_tensor_{type(self).num_instances}"
|
25
25
|
|
26
26
|
if ndim is not None:
|
27
27
|
self.shape = (Symbol(self.size_string(i)) for i in range(ndim))
|
@@ -103,6 +103,9 @@ class Tensor:
|
|
103
103
|
)
|
104
104
|
|
105
105
|
def names(self):
|
106
|
+
if self.ndim == 0:
|
107
|
+
return {self.original.name}
|
108
|
+
|
106
109
|
return (
|
107
110
|
{self.original.pointer_string()}
|
108
111
|
| {
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.3
|
2
2
|
Name: ninetoothed
|
3
|
-
Version: 0.5.0
|
3
|
+
Version: 0.6.0
|
4
4
|
Summary: A domain-specific language based on Triton but providing higher-level abstraction.
|
5
5
|
Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
|
6
6
|
Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
|
@@ -0,0 +1,10 @@
|
|
1
|
+
ninetoothed/__init__.py,sha256=T5UJXlC-wbo8JKPbLUNT65Kccp12xP52WFV5FsugETI,147
|
2
|
+
ninetoothed/jit.py,sha256=5gNp4HixCkural_Ns3DxwT4LL3OUcG0ECj4NLjb-EYk,16959
|
3
|
+
ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
|
4
|
+
ninetoothed/symbol.py,sha256=8Wg-JQPkVv9mMIxB1Rj4SHzOytHXPgHLkuK0BEFPDkc,5243
|
5
|
+
ninetoothed/tensor.py,sha256=L-9LhwnM4uRtRvj3tqrzerUijEfKeTQvFBcmS1hQilI,6656
|
6
|
+
ninetoothed/torchifier.py,sha256=8M2PDwyFIfVypX6Z-Vt_bGbsCPqxqKnftL0rXeh9bOM,911
|
7
|
+
ninetoothed-0.6.0.dist-info/METADATA,sha256=zvY4nvKt7R8kWDYrGnApem_C07trLgOj1-7zXPfqD9U,6785
|
8
|
+
ninetoothed-0.6.0.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
|
9
|
+
ninetoothed-0.6.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
10
|
+
ninetoothed-0.6.0.dist-info/RECORD,,
|
@@ -1,10 +0,0 @@
|
|
1
|
-
ninetoothed/__init__.py,sha256=T5UJXlC-wbo8JKPbLUNT65Kccp12xP52WFV5FsugETI,147
|
2
|
-
ninetoothed/jit.py,sha256=q-TRGF81rUEwV1TGDrew3ijwvzCWenR8EejZbYteZSI,16188
|
3
|
-
ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
|
4
|
-
ninetoothed/symbol.py,sha256=8Wg-JQPkVv9mMIxB1Rj4SHzOytHXPgHLkuK0BEFPDkc,5243
|
5
|
-
ninetoothed/tensor.py,sha256=UO79yYwHMfdqv32Ww2mtcl-ki1C9zInC0vBNwDtzlHU,6575
|
6
|
-
ninetoothed/torchifier.py,sha256=8M2PDwyFIfVypX6Z-Vt_bGbsCPqxqKnftL0rXeh9bOM,911
|
7
|
-
ninetoothed-0.5.0.dist-info/METADATA,sha256=ObwfQtwBk3x90adbQfiSo5wK11qUG9f4NdmunyjC--0,6785
|
8
|
-
ninetoothed-0.5.0.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
|
9
|
-
ninetoothed-0.5.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
10
|
-
ninetoothed-0.5.0.dist-info/RECORD,,
|
File without changes
|
File without changes
|