ninetoothed-0.3.0-py3-none-any.whl → ninetoothed-0.4.0-py3-none-any.whl

This diff shows the changes between two publicly available versions of this package, both released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
ninetoothed/jit.py CHANGED
@@ -137,7 +137,7 @@ class CodeGenerator(ast.NodeTransformer):
137
137
  node.args = [
138
138
  ast.arg(arg=name)
139
139
  if not Symbol.is_constexpr(name)
140
- else ast.arg(arg=name, annotation=attribute("constexpr"))
140
+ else ast.arg(arg=name, annotation=attribute("constexpr").node)
141
141
  for name in non_meta_names
142
142
  ] + [
143
143
  ast.arg(arg=name, annotation=attribute("constexpr").node)
@@ -287,15 +287,30 @@ class CodeGenerator(ast.NodeTransformer):
287
287
  )
288
288
 
289
289
  def _generate_launch(self, params, meta):
290
+ constexpr_params = [param for param in params if Symbol.is_constexpr(param)]
291
+ constexpr_params_without_prefixes = [
292
+ Symbol.remove_prefix(param) for param in constexpr_params
293
+ ]
294
+
290
295
  launch = ast.FunctionDef(
291
296
  name=f"launch_{self._func_def.name}",
292
297
  args=ast.arguments(
293
298
  posonlyargs=[],
294
- args=[ast.arg(arg.original.name) for arg in self._args],
299
+ args=[ast.arg(arg=arg.original.name) for arg in self._args]
300
+ + [ast.arg(arg=param) for param in constexpr_params_without_prefixes],
295
301
  kwonlyargs=[],
296
302
  defaults=[],
297
303
  ),
298
304
  body=[
305
+ ast.Assign(
306
+ targets=[ast.Name(id=param, ctx=ast.Store())],
307
+ value=ast.Name(id=param_without_prefix, ctx=ast.Load()),
308
+ )
309
+ for param, param_without_prefix in zip(
310
+ constexpr_params, constexpr_params_without_prefixes
311
+ )
312
+ ]
313
+ + [
299
314
  ast.Expr(
300
315
  ast.Call(
301
316
  func=ast.Subscript(
ninetoothed/symbol.py CHANGED
@@ -137,19 +137,39 @@ class Symbol:
137
137
 
138
138
  @staticmethod
139
139
  def is_constexpr(name):
140
- return name.startswith("_ninetoothed_constexpr_") or Symbol.is_meta(name)
140
+ return name.startswith(Symbol._constexpr_prefix()) or Symbol.is_meta(name)
141
141
 
142
142
  @staticmethod
143
143
  def is_meta(name):
144
- return name.startswith("_ninetoothed_meta_")
144
+ return name.startswith(Symbol._meta_prefix())
145
+
146
+ @staticmethod
147
+ def remove_prefix(name):
148
+ if name.startswith(Symbol._constexpr_prefix()):
149
+ return name.removeprefix(Symbol._constexpr_prefix())
150
+
151
+ if name.startswith(Symbol._meta_prefix()):
152
+ return name.removeprefix(Symbol._meta_prefix())
145
153
 
146
154
  @staticmethod
147
155
  def _create_constexpr(name):
148
- return f"_ninetoothed_constexpr_{name}"
156
+ return f"{Symbol._constexpr_prefix()}{name}"
149
157
 
150
158
  @staticmethod
151
159
  def _create_meta(name):
152
- return f"_ninetoothed_meta_{name}"
160
+ return f"{Symbol._meta_prefix()}{name}"
161
+
162
+ @staticmethod
163
+ def _constexpr_prefix():
164
+ return f"{Symbol._ninetoothed_prefix()}constexpr_"
165
+
166
+ @staticmethod
167
+ def _meta_prefix():
168
+ return f"{Symbol._ninetoothed_prefix()}meta_"
169
+
170
+ @staticmethod
171
+ def _ninetoothed_prefix():
172
+ return "_ninetoothed_"
153
173
 
154
174
 
155
175
  class _FindAndReplacer(ast.NodeTransformer):
{ninetoothed-0.3.0.dist-info → ninetoothed-0.4.0.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: ninetoothed
3
- Version: 0.3.0
3
+ Version: 0.4.0
4
4
  Summary: A domain-specific language based on Triton but providing higher-level abstraction.
5
5
  Project-URL: Homepage, https://github.com/InfiniTensor/ninetoothed
6
6
  Project-URL: Issues, https://github.com/InfiniTensor/ninetoothed/issues
ninetoothed-0.4.0.dist-info/RECORD ADDED
@@ -0,0 +1,10 @@
1
+ ninetoothed/__init__.py,sha256=T5UJXlC-wbo8JKPbLUNT65Kccp12xP52WFV5FsugETI,147
2
+ ninetoothed/jit.py,sha256=ECjaHcrVNj1SBxoMdzjGi5iDp3rtv2jUiHjvK0eU6Cs,16188
3
+ ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
4
+ ninetoothed/symbol.py,sha256=8Wg-JQPkVv9mMIxB1Rj4SHzOytHXPgHLkuK0BEFPDkc,5243
5
+ ninetoothed/tensor.py,sha256=_DrjOJ-pBvEbSNUvUoYJduLQXmuKgNcqhe4xUDMVoZw,6275
6
+ ninetoothed/torchifier.py,sha256=8M2PDwyFIfVypX6Z-Vt_bGbsCPqxqKnftL0rXeh9bOM,911
7
+ ninetoothed-0.4.0.dist-info/METADATA,sha256=Wgg0CP-j8VkiJWMpyOLOL7C1kVLkeF4OoZD6eyZsgLQ,6720
8
+ ninetoothed-0.4.0.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
9
+ ninetoothed-0.4.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
10
+ ninetoothed-0.4.0.dist-info/RECORD,,
ninetoothed-0.3.0.dist-info/RECORD REMOVED
@@ -1,10 +0,0 @@
1
- ninetoothed/__init__.py,sha256=T5UJXlC-wbo8JKPbLUNT65Kccp12xP52WFV5FsugETI,147
2
- ninetoothed/jit.py,sha256=nhjZRi8_kcjWZX0eOrnxLlzJfVg5vn12f9oi0Er2ABE,15515
3
- ninetoothed/language.py,sha256=YwjlBENmmKPTnhaQ2uYbj5MwzrCAT7MLJ6VkQ6NeXJE,504
4
- ninetoothed/symbol.py,sha256=Bd54qcI8KQAX0JRE_wPXycswtdSofhZ6Rr5MtZcv9fo,4665
5
- ninetoothed/tensor.py,sha256=_DrjOJ-pBvEbSNUvUoYJduLQXmuKgNcqhe4xUDMVoZw,6275
6
- ninetoothed/torchifier.py,sha256=8M2PDwyFIfVypX6Z-Vt_bGbsCPqxqKnftL0rXeh9bOM,911
7
- ninetoothed-0.3.0.dist-info/METADATA,sha256=CqdtfdV0eHzSwxJmFpD2IG5d4WTc6RDlpqMZue4Ml2Q,6720
8
- ninetoothed-0.3.0.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
9
- ninetoothed-0.3.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
10
- ninetoothed-0.3.0.dist-info/RECORD,,