ninetoothed 0.5.0.tar.gz → 0.6.0.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (22)
  1. {ninetoothed-0.5.0 → ninetoothed-0.6.0}/PKG-INFO +1 -1
  2. {ninetoothed-0.5.0 → ninetoothed-0.6.0}/pyproject.toml +1 -1
  3. {ninetoothed-0.5.0 → ninetoothed-0.6.0}/src/ninetoothed/jit.py +34 -14
  4. {ninetoothed-0.5.0 → ninetoothed-0.6.0}/src/ninetoothed/tensor.py +4 -1
  5. ninetoothed-0.6.0/tests/test_addmm.py +104 -0
  6. {ninetoothed-0.5.0 → ninetoothed-0.6.0}/.github/workflows/pytest.yml +0 -0
  7. {ninetoothed-0.5.0 → ninetoothed-0.6.0}/.github/workflows/ruff.yml +0 -0
  8. {ninetoothed-0.5.0 → ninetoothed-0.6.0}/.gitignore +0 -0
  9. {ninetoothed-0.5.0 → ninetoothed-0.6.0}/LICENSE +0 -0
  10. {ninetoothed-0.5.0 → ninetoothed-0.6.0}/README.md +0 -0
  11. {ninetoothed-0.5.0 → ninetoothed-0.6.0}/docs/README.zh.md +0 -0
  12. {ninetoothed-0.5.0 → ninetoothed-0.6.0}/docs/source/_static/ninetoothed-logo.png +0 -0
  13. {ninetoothed-0.5.0 → ninetoothed-0.6.0}/requirements.txt +0 -0
  14. {ninetoothed-0.5.0 → ninetoothed-0.6.0}/src/ninetoothed/__init__.py +0 -0
  15. {ninetoothed-0.5.0 → ninetoothed-0.6.0}/src/ninetoothed/language.py +0 -0
  16. {ninetoothed-0.5.0 → ninetoothed-0.6.0}/src/ninetoothed/symbol.py +0 -0
  17. {ninetoothed-0.5.0 → ninetoothed-0.6.0}/src/ninetoothed/torchifier.py +0 -0
  18. {ninetoothed-0.5.0 → ninetoothed-0.6.0}/tests/__init__.py +0 -0
  19. {ninetoothed-0.5.0 → ninetoothed-0.6.0}/tests/skippers.py +0 -0
  20. {ninetoothed-0.5.0 → ninetoothed-0.6.0}/tests/test_add.py +0 -0
  21. {ninetoothed-0.5.0 → ninetoothed-0.6.0}/tests/test_matmul.py +0 -0
  22. {ninetoothed-0.5.0 → ninetoothed-0.6.0}/tests/test_softmax.py +0 -0
{ninetoothed-0.5.0 → ninetoothed-0.6.0}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: ninetoothed
-Version: 0.5.0
+Version: 0.6.0
 Summary: A domain-specific language based on Triton but providing higher-level abstraction.
 Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
 Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
{ninetoothed-0.5.0 → ninetoothed-0.6.0}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
 
 [project]
 name = "ninetoothed"
-version = "0.5.0"
+version = "0.6.0"
 authors = [{ name = "Jiacheng Huang", email = "huangjiacheng0709@outlook.com" }]
 description = "A domain-specific language based on Triton but providing higher-level abstraction."
 readme = "README.md"
{ninetoothed-0.5.0 → ninetoothed-0.6.0}/src/ninetoothed/jit.py
@@ -1,9 +1,11 @@
 import ast
 import collections
 import functools
+import importlib.util
 import inspect
 import itertools
 import math
+import sys
 import tempfile
 
 import triton
@@ -41,24 +43,20 @@ class JIT:
         ast.fix_missing_locations(tree)
 
         unparsed = ast.unparse(tree).replace("None:", ":").replace(":None", ":")
+        dependencies = self._find_dependencies()
+        source = "\n\n".join((unparsed, dependencies)).strip()
 
         with tempfile.NamedTemporaryFile(delete=False, suffix=".py") as temp_file:
-            temp_file.write(unparsed.encode("utf-8"))
+            temp_file.write(source.encode("utf-8"))
             temp_file_name = temp_file.name
 
-        with open(temp_file_name, "r") as temp_file:
-            code = compile(
-                source=temp_file.read(),
-                filename=temp_file_name,
-                mode="exec",
-            )
-
-            namespace = {}
-            exec(code, namespace)
+        module = type(self)._import_from_path(temp_file_name, temp_file_name)
+        module_vars = vars(module)
 
         handle = _Handle(
-            namespace[self.func.__name__],
-            namespace[f"launch_{self.func.__name__}"],
+            module_vars[self.func.__name__],
+            module_vars[f"launch_{self.func.__name__}"],
+            source,
         )
 
         type(self).handles[source_file][source_line] = handle
@@ -74,6 +72,24 @@ class JIT:
 
         return ast.Module(body=[finder.result], type_ignores=[])
 
+    def _find_dependencies(self):
+        dependencies = set()
+
+        for obj in self.func.__globals__.values():
+            if isinstance(obj, triton.runtime.JITFunction):
+                dependencies.add(obj.src)
+
+        return "\n".join(f"@triton.jit\n{dependency}" for dependency in dependencies)
+
+    @staticmethod
+    def _import_from_path(module_name, file_path):
+        spec = importlib.util.spec_from_file_location(module_name, file_path)
+        module = importlib.util.module_from_spec(spec)
+        sys.modules[module_name] = module
+        spec.loader.exec_module(module)
+
+        return module
+
 
 class CodeGenerator(ast.NodeTransformer):
     def __init__(self, context):
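As the hunk above shows, `_find_dependencies` scans the decorated function's globals and inlines the source of any `triton.runtime.JITFunction` it finds into the generated module. A hypothetical sketch of the situation this covers, assuming a user-defined helper named `_silu` in the same module as the ninetoothed kernel (the helper name and body are illustrative, not part of the package):

# Hypothetical @triton.jit helper living in the caller's module; any
# triton.runtime.JITFunction found in func.__globals__ is handled the same way.
import triton
import triton.language as tl

@triton.jit
def _silu(x):
    # x / (1 + e^-x); purely illustrative
    return x / (1 + tl.exp(-x))

# In 0.6.0, JIT._find_dependencies collects _silu.src and re-emits it, prefixed
# with @triton.jit, into the generated source, so a ninetoothed kernel that
# calls _silu(...) still resolves once the temp file is imported as a module.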
@@ -102,7 +118,7 @@ class CodeGenerator(ast.NodeTransformer):
         self.generic_visit(node)
 
         for arg in self._args:
-            if not isinstance(arg, Tensor):
+            if not isinstance(arg, Tensor) or arg.ndim == 0:
                 continue
 
             offsets = arg.offsets()
@@ -355,6 +371,9 @@ class CodeGenerator(ast.NodeTransformer):
 
     @staticmethod
     def _generate_load(tensor, intermediate_indices=()):
+        if tensor.ndim == 0:
+            return Symbol(tensor.original.name).node
+
         pointers, mask = CodeGenerator._generate_pointers_and_mask(
             tensor, intermediate_indices
         )
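Together with the `arg.ndim == 0` check two hunks up, this lets zero-dimensional `Tensor` annotations stand in for plain scalars: no pointers or masks are generated, and the argument is read directly by name. A minimal sketch based on the `beta`/`alpha` parameters in the new tests/test_addmm.py; the kernel itself is hypothetical:

import ninetoothed
from ninetoothed import Symbol, Tensor

BLOCK_SIZE = Symbol("BLOCK_SIZE", meta=True)

@ninetoothed.jit
def scale_kernel(
    x: Tensor(1).tile((BLOCK_SIZE,)),
    alpha: Tensor(0),  # 0-dim annotation: passed through as a scalar
    y: Tensor(1).tile((BLOCK_SIZE,)),
):
    y = alpha * x  # the scalar participates directly, as beta/alpha do in addmm_kernel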
@@ -459,9 +478,10 @@ class Tritonizer(ast.NodeTransformer):
 
 
 class _Handle:
-    def __init__(self, kernel, launch):
+    def __init__(self, kernel, launch, source):
         self._kernel = kernel
         self._launch = launch
+        self._source = source
 
     def __call__(self, *args, **kwargs):
         return self._launch(*args, **kwargs)
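Keeping the generated source on the handle makes it possible to inspect what ninetoothed emitted for a given kernel. A hedged sketch using the `add_kernel` example from the README; `_source` is a private attribute, so this is debugging aid only, not a documented API:

import ninetoothed
from ninetoothed import Symbol, Tensor

BLOCK_SIZE = Symbol("BLOCK_SIZE", meta=True)

@ninetoothed.jit
def add_kernel(
    x: Tensor(1).tile((BLOCK_SIZE,)),
    y: Tensor(1).tile((BLOCK_SIZE,)),
    z: Tensor(1).tile((BLOCK_SIZE,)),
):
    z = x + y

# The decorated object is a _Handle; in 0.6.0 it also carries the generated
# Triton source (the kernel plus any inlined @triton.jit dependencies).
print(add_kernel._source)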
{ninetoothed-0.5.0 → ninetoothed-0.6.0}/src/ninetoothed/tensor.py
@@ -21,7 +21,7 @@ class Tensor:
 
         self.dtype = dtype
 
-        self.name = f"tensor_{type(self).num_instances}"
+        self.name = f"_ninetoothed_tensor_{type(self).num_instances}"
 
         if ndim is not None:
             self.shape = (Symbol(self.size_string(i)) for i in range(ndim))
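The new `_ninetoothed_` prefix keeps auto-generated parameter names clearly separated from user identifiers, which matters now that user-defined `@triton.jit` dependencies are inlined into the same generated module. The observable effect, sketched (the exact counter value depends on how many `Tensor` objects have already been created):

from ninetoothed import Tensor

t = Tensor(2)
print(t.name)  # e.g. "_ninetoothed_tensor_0" in 0.6.0, "tensor_0" in 0.5.0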
@@ -103,6 +103,9 @@ class Tensor:
         )
 
     def names(self):
+        if self.ndim == 0:
+            return {self.original.name}
+
         return (
             {self.original.pointer_string()}
             | {
ninetoothed-0.6.0/tests/test_addmm.py (new file)
@@ -0,0 +1,104 @@
+import random
+
+import torch
+
+import ninetoothed
+import ninetoothed.language as ntl
+from ninetoothed import Symbol, Tensor
+from tests.skippers import skip_if_cuda_not_available, skip_if_float8_e5m2_not_supported
+
+
+def addmm(input, mat1, mat2, beta=1, alpha=1):
+    BLOCK_SIZE_M = Symbol("BLOCK_SIZE_M", meta=True)
+    BLOCK_SIZE_N = Symbol("BLOCK_SIZE_N", meta=True)
+    BLOCK_SIZE_K = Symbol("BLOCK_SIZE_K", meta=True)
+
+    input_tiled = Tensor(2).tile((BLOCK_SIZE_M, BLOCK_SIZE_N))
+
+    output_tiled = Tensor(2).tile((BLOCK_SIZE_M, BLOCK_SIZE_N))
+
+    mat1_tiled = (
+        Tensor(2)
+        .tile((BLOCK_SIZE_M, BLOCK_SIZE_K))
+        .tile((1, -1))
+        .expand((-1, output_tiled.shape[1]))
+    )
+    mat1_tiled.dtype = mat1_tiled.dtype.squeeze(0)
+
+    mat2_tiled = (
+        Tensor(2)
+        .tile((BLOCK_SIZE_K, BLOCK_SIZE_N))
+        .tile((-1, 1))
+        .expand((output_tiled.shape[0], -1))
+    )
+    mat2_tiled.dtype = mat2_tiled.dtype.squeeze(1)
+
+    @ninetoothed.jit
+    def addmm_kernel(
+        input: input_tiled,
+        mat1: mat1_tiled,
+        mat2: mat2_tiled,
+        beta: Tensor(0),
+        alpha: Tensor(0),
+        output: output_tiled,
+    ):
+        accumulator = ntl.zeros(output.shape, dtype=ntl.float32)
+        for k in range(mat1.shape[0]):
+            accumulator += ntl.dot(mat1[k], mat2[k])
+        output = beta * input + alpha * accumulator.to(ntl.float16)
+
+    output = torch.empty(
+        (mat1.shape[0], mat2.shape[1]), device=mat1.device, dtype=torch.float16
+    )
+
+    addmm_kernel(input, mat1, mat2, beta, alpha, output)
+
+    return output
+
+
+@skip_if_cuda_not_available
+class TestCUDA:
+    @classmethod
+    def setup_class(cls):
+        torch.manual_seed(0)
+
+        shape = (512, 512)
+
+        cls.input = torch.randn(shape, device="cuda")
+        cls.mat1 = torch.randn(shape, device="cuda")
+        cls.mat2 = torch.randn(shape, device="cuda")
+        cls.beta = random.uniform(0, 1)
+        cls.alpha = random.uniform(0, 1)
+
+    def test_fp16(self):
+        input = type(self).input.to(torch.float16)
+        mat1 = type(self).mat1.to(torch.float16)
+        mat2 = type(self).mat2.to(torch.float16)
+        beta = type(self).beta
+        alpha = type(self).alpha
+
+        assert torch.allclose(
+            addmm(input, mat1, mat2, beta=beta, alpha=alpha),
+            torch.addmm(input, mat1, mat2, beta=beta, alpha=alpha),
+            atol=0.075,
+        )
+
+    @skip_if_float8_e5m2_not_supported
+    def test_fp8(self):
+        input = type(self).input.to(torch.float8_e5m2)
+        mat1 = type(self).mat1.to(torch.float8_e5m2)
+        mat2 = type(self).mat2.T.to(torch.float8_e5m2)
+        beta = type(self).beta
+        alpha = type(self).alpha
+
+        assert torch.allclose(
+            addmm(input, mat1, mat2, beta=beta, alpha=alpha),
+            torch.addmm(
+                input.to(torch.float16),
+                mat1.to(torch.float16),
+                mat2.to(torch.float16),
+                beta=beta,
+                alpha=alpha,
+            ),
+            atol=0.125,
+        )
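For reference, a hedged sketch of how the new `addmm` wrapper from this test might be exercised outside pytest, assuming a CUDA device and float16 inputs as in `TestCUDA.test_fp16`:

import torch

from tests.test_addmm import addmm

mat1 = torch.randn(512, 512, device="cuda", dtype=torch.float16)
mat2 = torch.randn(512, 512, device="cuda", dtype=torch.float16)
bias = torch.randn(512, 512, device="cuda", dtype=torch.float16)

out = addmm(bias, mat1, mat2, beta=0.5, alpha=2.0)
ref = torch.addmm(bias, mat1, mat2, beta=0.5, alpha=2.0)
print(torch.allclose(out, ref, atol=0.075))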