ai-edge-torch-nightly 0.3.0.dev20241213__py3-none-any.whl → 0.3.0.dev20241214__py3-none-any.whl

ai_edge_torch/__init__.py CHANGED
@@ -13,13 +13,13 @@
  # limitations under the License.
  # ==============================================================================
 
+ from ai_edge_torch._config import config
  from ai_edge_torch._convert.converter import convert
  from ai_edge_torch._convert.converter import signature
  from ai_edge_torch._convert.to_channel_last_io import to_channel_last_io
  from ai_edge_torch.model import Model
  from ai_edge_torch.version import __version__
 
-
  def load(path: str) -> Model:
    """Imports an ai_edge_torch model from disk.
 
ai_edge_torch/_config.py ADDED
@@ -0,0 +1,52 @@
+ # Copyright 2024 The AI Edge Torch Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ """Provides a configuration for the ai-edge-torch."""
+
+ import functools
+ import logging
+ import os
+
+ __all__ = ["config"]
+
+
+ class _Config:
+   """ai-edge-torch global configs."""
+
+   @property
+   @functools.cache  # pylint: disable=method-cache-max-size-none
+   def use_torch_xla(self) -> bool:
+     """True if using torch_xla to lower torch ops to StableHLO.
+
+     To use torch_xla as the lowering backend, set environment variable
+     `USE_TORCH_XLA` to "true".
+     """
+     var = os.environ.get("USE_TORCH_XLA", "false")
+     var = var.lower().strip()
+     if var in ("y", "yes", "t", "true", "on", "1"):
+       return True
+     elif var in ("n", "no", "f", "false", "off", "0"):
+       return False
+     else:
+       logging.warning("Invalid USE_TORCH_XLA value is ignored: %s.", var)
+       return False
+
+   @property
+   def in_oss(self) -> bool:
+     """True if the code is not running in google internal environment."""
+     return True
+
+
+ config = _Config()
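
Editor's note: a minimal usage sketch for the new config singleton, assuming only what the module above shows. Because `functools.cache` memoizes the property per instance and `config` is a module-level singleton, `USE_TORCH_XLA` must be set before the first read:

  import os
  os.environ["USE_TORCH_XLA"] = "true"  # read once on first access, then cached

  import ai_edge_torch

  if ai_edge_torch.config.use_torch_xla:
    print("lowering with torch_xla")
  else:
    print("lowering with odml_torch")  # the default when the variable is unset

The migration pattern throughout this release follows from this: `config.Config.use_torch_xla` (a class attribute on a dataclass) becomes `ai_edge_torch.config.use_torch_xla` (a property on a shared instance).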
ai_edge_torch/_convert/test/test_convert.py CHANGED
@@ -19,7 +19,6 @@ import os
  from typing import Tuple
 
  import ai_edge_torch
- from ai_edge_torch import config
  from ai_edge_torch._convert import conversion_utils
  from ai_edge_torch.quantize import pt2e_quantizer
  from ai_edge_torch.testing import model_coverage
@@ -292,7 +291,7 @@ class TestConvert(googletest.TestCase):
      self.assertTrue(result)
 
    @googletest.skipIf(
-       not config.Config.use_torch_xla,
+       not ai_edge_torch.config.use_torch_xla,
        reason="Shape polymorphism is not yet support with odml_torch.",
    )
    def test_convert_model_with_dynamic_batch(self):
ai_edge_torch/generative/layers/kv_cache.py CHANGED
@@ -20,6 +20,7 @@ from typing import List, Tuple
 
  from ai_edge_torch import hlfb
  from ai_edge_torch.generative.layers import model_config
+ from ai_edge_torch.generative.utilities.dynamic_update_slice import dynamic_update_slice
  import torch
  import torch.utils._pytree as pytree
 
@@ -159,8 +160,6 @@ def update(
    Returns:
      KVCacheEntry: The updated KVCache entry based on the passed inputs.
    """
-   # Turn dynamic_update_slice updates off for now.
-   use_dus=False
    update_kv_cache = _update_kv_impl if use_dus else _update_kv_base_impl
    return update_kv_cache(cache, input_pos, k_slice, v_slice)
 
ai_edge_torch/generative/test/test_model_conversion.py CHANGED
@@ -16,7 +16,6 @@
  """Testing model conversion for a few gen-ai models."""
 
  import ai_edge_torch
- from ai_edge_torch import config as ai_edge_config
  from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
  from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
  from ai_edge_torch.generative.layers import kv_cache
@@ -83,22 +82,22 @@ class TestModelConversion(googletest.TestCase):
      )
 
    @googletest.skipIf(
-       ai_edge_config.Config.use_torch_xla,
-       reason="tests with custom ops are not supported on oss",
+       ai_edge_torch.config.in_oss,
+       reason="tests with custom ops are not supported in oss",
    )
    def test_toy_model_with_kv_cache(self):
      self._test_model_with_kv_cache(enable_hlfb=False)
 
    @googletest.skipIf(
-       ai_edge_config.Config.use_torch_xla,
-       reason="tests with custom ops are not supported on oss",
+       ai_edge_torch.config.in_oss,
+       reason="tests with custom ops are not supported in oss",
    )
    def test_toy_model_with_kv_cache_with_hlfb(self):
      self._test_model_with_kv_cache(enable_hlfb=True)
 
    @googletest.skipIf(
-       ai_edge_config.Config.use_torch_xla,
-       reason="tests with custom ops are not supported on oss",
+       ai_edge_torch.config.in_oss,
+       reason="tests with custom ops are not supported in oss",
    )
    def test_toy_model_has_dus_op(self):
      """Tests that the model has the dynamic update slice op."""
@@ -179,8 +178,8 @@ class TestModelConversion(googletest.TestCase):
      )
 
    @googletest.skipIf(
-       ai_edge_config.Config.use_torch_xla,
-       reason="tests with custom ops are not supported on oss",
+       ai_edge_torch.config.in_oss,
+       reason="tests with custom ops are not supported in oss",
    )
    def test_tiny_llama_multisig(self):
      config = tiny_llama.get_fake_model_config()
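
Editor's note on the guard change above (an inference from the code, not stated in the diff): the old condition skipped these custom-op tests whenever torch_xla was the active backend, while the new condition keys off `in_oss`, which the new `_config` module hard-codes to True. In open-source builds the tests are therefore always skipped now, regardless of the lowering backend. A hypothetical side-by-side:

  import unittest

  in_oss = True          # _Config.in_oss is hard-coded to True in OSS builds
  use_torch_xla = False  # new default; see _config.py above

  class Demo(unittest.TestCase):

    @unittest.skipIf(in_oss, "tests with custom ops are not supported in oss")
    def test_custom_op(self):  # skipped in OSS no matter the backend
      pass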
ai_edge_torch/generative/test/test_model_conversion_large.py CHANGED
@@ -16,7 +16,6 @@
  """Testing model conversion for a few gen-ai models."""
 
  import ai_edge_torch
- from ai_edge_torch import config as ai_edge_config
  from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
  from ai_edge_torch.generative.examples.gemma import gemma1
  from ai_edge_torch.generative.examples.gemma import gemma2
@@ -91,8 +90,8 @@ class TestModelConversion(googletest.TestCase):
      )
 
    @googletest.skipIf(
-       ai_edge_config.Config.use_torch_xla,
-       reason="tests with custom ops are not supported on oss",
+       ai_edge_torch.config.in_oss,
+       reason="tests with custom ops are not supported in oss",
    )
    def test_gemma1(self):
      config = gemma1.get_fake_model_config()
@@ -100,8 +99,8 @@ class TestModelConversion(googletest.TestCase):
      self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
    @googletest.skipIf(
-       ai_edge_config.Config.use_torch_xla,
-       reason="tests with custom ops are not supported on oss",
+       ai_edge_torch.config.in_oss,
+       reason="tests with custom ops are not supported in oss",
    )
    def test_gemma2(self):
      config = gemma2.get_fake_model_config()
@@ -109,8 +108,8 @@ class TestModelConversion(googletest.TestCase):
      self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
    @googletest.skipIf(
-       ai_edge_config.Config.use_torch_xla,
-       reason="tests with custom ops are not supported on oss",
+       ai_edge_torch.config.in_oss,
+       reason="tests with custom ops are not supported in oss",
    )
    def test_llama(self):
      config = llama.get_fake_model_config()
@@ -118,8 +117,8 @@ class TestModelConversion(googletest.TestCase):
      self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
    @googletest.skipIf(
-       ai_edge_config.Config.use_torch_xla,
-       reason="tests with custom ops are not supported on oss",
+       ai_edge_torch.config.in_oss,
+       reason="tests with custom ops are not supported in oss",
    )
    def test_phi2(self):
      config = phi2.get_fake_model_config()
@@ -128,8 +127,8 @@ class TestModelConversion(googletest.TestCase):
      self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
    @googletest.skipIf(
-       ai_edge_config.Config.use_torch_xla,
-       reason="tests with custom ops are not supported on oss",
+       ai_edge_torch.config.in_oss,
+       reason="tests with custom ops are not supported in oss",
    )
    def test_phi3(self):
      config = phi3.get_fake_model_config()
@@ -137,8 +136,8 @@ class TestModelConversion(googletest.TestCase):
      self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
 
    @googletest.skipIf(
-       ai_edge_config.Config.use_torch_xla,
-       reason="tests with custom ops are not supported on oss",
+       ai_edge_torch.config.in_oss,
+       reason="tests with custom ops are not supported in oss",
    )
    def test_smollm(self):
      config = smollm.get_fake_model_config()
@@ -146,8 +145,8 @@ class TestModelConversion(googletest.TestCase):
      self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
    @googletest.skipIf(
-       ai_edge_config.Config.use_torch_xla,
-       reason="tests with custom ops are not supported on oss",
+       ai_edge_torch.config.in_oss,
+       reason="tests with custom ops are not supported in oss",
    )
    def test_openelm(self):
      config = openelm.get_fake_model_config()
@@ -155,8 +154,8 @@ class TestModelConversion(googletest.TestCase):
      self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
 
    @googletest.skipIf(
-       ai_edge_config.Config.use_torch_xla,
-       reason="tests with custom ops are not supported on oss",
+       ai_edge_torch.config.in_oss,
+       reason="tests with custom ops are not supported in oss",
    )
    def test_qwen(self):
      config = qwen.get_fake_model_config()
@@ -164,8 +163,8 @@ class TestModelConversion(googletest.TestCase):
      self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
 
    @googletest.skipIf(
-       ai_edge_config.Config.use_torch_xla,
-       reason="tests with custom ops are not supported on oss",
+       ai_edge_torch.config.in_oss,
+       reason="tests with custom ops are not supported in oss",
    )
    def test_amd_llama_135m(self):
      config = amd_llama_135m.get_fake_model_config()
@@ -173,8 +172,8 @@ class TestModelConversion(googletest.TestCase):
      self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
 
    @googletest.skipIf(
-       ai_edge_config.Config.use_torch_xla,
-       reason="tests with custom ops are not supported on oss",
+       ai_edge_torch.config.in_oss,
+       reason="tests with custom ops are not supported in oss",
    )
    def disabled_test_paligemma(self):
      config = paligemma.get_fake_model_config()
@@ -222,8 +221,8 @@ class TestModelConversion(googletest.TestCase):
      )
 
    @googletest.skipIf(
-       ai_edge_config.Config.use_torch_xla,
-       reason="tests with custom ops are not supported on oss",
+       ai_edge_torch.config.in_oss,
+       reason="tests with custom ops are not supported in oss",
    )
    def test_stable_diffusion_clip(self):
      config = sd_clip.get_fake_model_config()
@@ -254,8 +253,8 @@ class TestModelConversion(googletest.TestCase):
      )
 
    @googletest.skipIf(
-       ai_edge_config.Config.use_torch_xla,
-       reason="tests with custom ops are not supported on oss",
+       ai_edge_torch.config.in_oss,
+       reason="tests with custom ops are not supported in oss",
    )
    def test_stable_diffusion_diffusion(self):
      config = sd_diffusion.get_fake_model_config(2)
@@ -296,8 +295,8 @@ class TestModelConversion(googletest.TestCase):
      )
 
    @googletest.skipIf(
-       ai_edge_config.Config.use_torch_xla,
-       reason="tests with custom ops are not supported on oss",
+       ai_edge_torch.config.in_oss,
+       reason="tests with custom ops are not supported in oss",
    )
    def test_stable_diffusion_decoder(self):
      config = sd_decoder.get_fake_model_config()
ai_edge_torch/lowertools/_shim.py CHANGED
@@ -15,13 +15,15 @@
 
  from typing import Any, Optional
 
- from ai_edge_torch import config
+ from ai_edge_torch import _config
  from ai_edge_torch._convert import signature
  from ai_edge_torch.quantize import quant_config as qcfg
  import torch
 
+ config = _config.config
+
  # isort: off
- if config.Config.use_torch_xla:
+ if config.use_torch_xla:
    from ai_edge_torch.lowertools import torch_xla_utils as utils
    from ai_edge_torch.lowertools.torch_xla_utils import exported_program_to_mlir_text
    from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder
ai_edge_torch/lowertools/test_utils.py CHANGED
@@ -15,9 +15,11 @@
 
  import re
  from typing import Optional
- from ai_edge_torch import config
+ from ai_edge_torch import _config
  from absl.testing import absltest as googletest
 
+ config = _config.config
+
 
  def _extract_backend_configs(mlir):
    mlir = mlir.replace("\\22", '"')
@@ -38,7 +40,7 @@ def assert_string_count(
    if odml_torch_attr_counter is None:
      odml_torch_attr_counter = {}
 
-   if config.Config.use_torch_xla:
+   if config.use_torch_xla:
      for key in torch_xla_pattern_counter:
        test_case.assertEqual(
            mlir.count(key),
ai_edge_torch/odml_torch/lowerings/_basic.py CHANGED
@@ -276,11 +276,13 @@ def _aten_slice_scatter(lctx, self, src, dim=0, start=None, end=None, step=1):
          interior_padding if i == dim else 0 for i in range(rank)
      ],
  )
- pred = np.ones(self.type.shape, dtype=np.bool_)
- pred[*[
+
+ slices = [
      slice(start, end, step) if i == dim else slice(None, None, None)
      for i in range(rank)
- ]] = False
+ ]
+ pred = np.ones(self.type.shape, dtype=np.bool_)
+ pred[np.index_exp[tuple(slices)]] = False
  pred = stablehlo.constant(
      ir.DenseElementsAttr.get(
          np.packbits(pred, bitorder="little"),
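
Editor's note: the point of this rewrite is Python-version portability (an inference from the syntax, not stated in the diff). `pred[*[...]]` unpacks a list inside a subscript, which is only legal syntax on Python 3.11+ (PEP 646), while this release lowers the supported floor to Python 3.10; building the index as a tuple behaves identically on both. A self-contained check with illustrative shapes:

  import numpy as np

  rank, dim, start, end, step = 3, 1, 1, 4, 2
  slices = [
      slice(start, end, step) if i == dim else slice(None, None, None)
      for i in range(rank)
  ]
  pred = np.ones((2, 5, 3), dtype=np.bool_)
  pred[np.index_exp[tuple(slices)]] = False  # np.index_exp passes a tuple through unchanged
  assert not pred[:, 1:4:2, :].any()  # indices 1 and 3 along dim were cleared
  assert pred[:, 0, :].all()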
ai_edge_torch/odml_torch/lowerings/_convolution.py CHANGED
@@ -232,7 +232,9 @@ def _aten_convolution(
 
  if bias is not None:
    # broadcast [C] to [NCHW]
-   broadcasted_bias = stablehlo.broadcast_in_dim(output_type, bias, [1])
+   broadcasted_bias = stablehlo.broadcast_in_dim(
+       output_type, bias, ir.DenseI64ArrayAttr.get([1])
+   )
    res = stablehlo.add(
        lhs=res,
        rhs=broadcasted_bias,
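
Editor's note: the explicit wrapping suggests the newer StableHLO Python bindings expect `broadcast_dimensions` as a `DenseI64ArrayAttr` rather than a bare Python list; treat that motivation as an inference. As for the semantics, `broadcast_in_dim(out_type, operand, dims=[1])` places the single axis of a `[C]` bias onto axis 1 of an `[N, C, H, W]` output. A NumPy analogy of the operation, not the MLIR API:

  import numpy as np

  def broadcast_in_dim_ref(operand, out_shape, broadcast_dims):
    # Axis i of `operand` lands on axis broadcast_dims[i] of the output;
    # every other output axis is filled by broadcasting.
    expanded = [1] * len(out_shape)
    for src_axis, dst_axis in enumerate(broadcast_dims):
      expanded[dst_axis] = operand.shape[src_axis]
    return np.broadcast_to(operand.reshape(expanded), out_shape)

  bias = np.arange(3.0)                                # [C] with C=3
  out = broadcast_in_dim_ref(bias, (2, 3, 4, 4), [1])  # [N, C, H, W]
  assert (out[0, :, 0, 0] == bias).all()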
ai_edge_torch/odml_torch/lowerings/_layer_norm.py CHANGED
@@ -20,6 +20,7 @@ from ai_edge_torch.odml_torch.lowerings import registry
  from ai_edge_torch.odml_torch.lowerings import utils
  from jax._src.lib.mlir import ir
  from jax._src.lib.mlir.dialects import hlo as stablehlo
+ import numpy as np
  import torch
 
 
@@ -66,12 +67,20 @@ def _aten_native_layer_norm(
    normalized_rank = len(normalized_shape)
    if weight is not None:
      weight = stablehlo.broadcast_in_dim(
-         data_type, weight, list(range(data_rank - normalized_rank, data_rank))
+         data_type,
+         weight,
+         ir.DenseI64ArrayAttr.get(
+             list(range(data_rank - normalized_rank, data_rank))
+         ),
      )
      output = stablehlo.multiply(weight, output)
    if bias is not None:
      bias = stablehlo.broadcast_in_dim(
-         data_type, bias, list(range(data_rank - normalized_rank, data_rank))
+         data_type,
+         bias,
+         ir.DenseI64ArrayAttr.get(
+             list(range(data_rank - normalized_rank, data_rank))
+         ),
      )
      output = stablehlo.add(bias, output)
 
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py CHANGED
@@ -13,7 +13,7 @@
  # limitations under the License.
  # ==============================================================================
  """Lowerings for PT2E torch.ops.quantized_decomposed ops."""
- from typing import Union, cast
+ from typing import Optional, Union, cast
 
  from ai_edge_torch.odml_torch.lowerings import context
  from ai_edge_torch.odml_torch.lowerings import utils
@@ -30,15 +30,15 @@ LoweringContext = context.LoweringContext
 
 
  def _uniform_quantized_type(
-     stored_type: str | ir.Type,
-     expressed_type: str | ir.Type,
+     stored_type: Union[str, ir.Type],
+     expressed_type: Union[str, ir.Type],
      *,
-     scale=float | list[float] | tuple[float],
-     zero_point=float | list[float] | tuple[float],
-     storage_type_min: int | None = None,
-     storage_type_max: int | None = None,
-     channel_axis: int | None = None,
-     channel_axis_size: int | None = None,
+     scale=Union[float, list[float], tuple[float]],
+     zero_point=Union[float, list[float], tuple[float]],
+     storage_type_min: Optional[int] = None,
+     storage_type_max: Optional[int] = None,
+     channel_axis: Optional[int] = None,
+     channel_axis_size: Optional[int] = None,
  ):
    """Polyfill for quant.UniformQuantizedType."""
    if storage_type_min and storage_type_max:
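
Editor's note: the hunk above trades PEP 604 union syntax (`X | Y`, `int | None`) for `typing.Union`/`typing.Optional`. On Python 3.10+ the two spell the same types; the diff does not state why the explicit form was chosen, so any motivation (for example a type checker or internal toolchain that predates PEP 604) is a guess. A sketch of the equivalence:

  from typing import Optional, Union

  def f(x: Union[str, int], n: Optional[int] = None) -> str:
    return f"{x}:{n}"

  def g(x: str | int, n: int | None = None) -> str:  # identical types, PEP 604 spelling
    return f"{x}:{n}"

  assert f("a", 1) == g("a", 1)

One pre-existing quirk survives the edit: `scale` and `zero_point` use `=` (a type object as default value) where an annotation with `:` would be expected; the diff leaves that as-is.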
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
  # limitations under the License.
  # ==============================================================================
 
- __version__ = "0.3.0.dev20241213"
+ __version__ = "0.3.0.dev20241214"
{ai_edge_torch_nightly-0.3.0.dev20241213.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ai-edge-torch-nightly
- Version: 0.3.0.dev20241213
+ Version: 0.3.0.dev20241214
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -11,7 +11,6 @@ Classifier: Intended Audience :: Science/Research
  Classifier: License :: OSI Approved :: Apache Software License
  Classifier: Programming Language :: Python :: 3
  Classifier: Programming Language :: Python :: 3 :: Only
- Classifier: Programming Language :: Python :: 3.9
  Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
  Classifier: Topic :: Scientific/Engineering
@@ -20,7 +19,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
  Classifier: Topic :: Software Development
  Classifier: Topic :: Software Development :: Libraries
  Classifier: Topic :: Software Development :: Libraries :: Python Modules
- Requires-Python: >=3.9
+ Requires-Python: >=3.10
  Description-Content-Type: text/markdown
  License-File: LICENSE
  Requires-Dist: numpy
@@ -28,10 +27,13 @@ Requires-Dist: scipy
  Requires-Dist: safetensors
  Requires-Dist: tabulate
  Requires-Dist: torch>=2.4.0
- Requires-Dist: torch-xla>=2.4.0
  Requires-Dist: tf-nightly>=2.19.0.dev20241201
  Requires-Dist: ai-edge-litert-nightly
  Requires-Dist: ai-edge-quantizer-nightly
+ Requires-Dist: jax
+ Requires-Dist: torch-xla2[odml]>=0.0.1.dev20241201
+ Provides-Extra: torch-xla
+ Requires-Dist: torch-xla>=2.4.0; extra == "torch-xla"
 
  Library that supports converting PyTorch models into a .tflite format, which can
  then be run with TensorFlow Lite and MediaPipe. This enables applications for
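
Editor's note, packaging summary for this hunk: Python 3.9 support is dropped (classifier and Requires-Python both move to 3.10), jax and torch-xla2[odml] become hard requirements, and torch-xla moves from the required set to an opt-in extra, installable with the standard extras syntax `pip install "ai-edge-torch-nightly[torch-xla]"`. A small runtime probe (illustrative only):

  import importlib.util

  # With the base install, torch_xla may now be absent; USE_TORCH_XLA only
  # has an effect when the extra is installed.
  if importlib.util.find_spec("torch_xla") is None:
    print("torch_xla not installed; odml_torch is the only lowering path")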
{ai_edge_torch_nightly-0.3.0.dev20241213.dist-info → ai_edge_torch_nightly-0.3.0.dev20241214.dist-info}/RECORD RENAMED
@@ -1,9 +1,9 @@
- ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,1168
- ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
+ ai_edge_torch/__init__.py,sha256=rq9ZtMJLG8yYNC4tNE4rpl94UAUClZW7f4GAr6HBVDQ,1208
+ ai_edge_torch/_config.py,sha256=QIrerb6uHMahRvMilmhodJ_6jfiRps3qgLOBeidPnS4,1614
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
  ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
- ai_edge_torch/version.py,sha256=ZKmpBJjKnl93nlPyg2RYz15citIW0ntqZ-0diRjwTt8,706
+ ai_edge_torch/version.py,sha256=iCH8lnlOrtbGwvxnT3knpY_keeu2UnrJ_ZXNK2LSvf4,706
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -26,7 +26,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=mzfL9cf0qBnpmxM_OlMQFvQsEZV2B_Mia9yEJV4J7rI,7135
  ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/_convert/test/test_convert.py,sha256=v6AhfWqRBuHT7uBDueTbntaQtDSMMrvQOqlIDXNUaMA,17250
+ ai_edge_torch/_convert/test/test_convert.py,sha256=gK9QJuLbpjXt0l6tVnzl9Miq6GLkJR-hB67i3VE13Og,17224
  ai_edge_torch/_convert/test/test_convert_composites.py,sha256=BCIODgxMI_3MxMLfNWYMGjcz-al-J3z5eDHCiZJXNwY,7992
  ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
  ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
@@ -119,7 +119,7 @@ ai_edge_torch/generative/layers/attention.py,sha256=aOoVM1hY7qjvzVQI1-m26p_f9qoT
  ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
  ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
  ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
- ai_edge_torch/generative/layers/kv_cache.py,sha256=dOhk3ec21189uPyCDYyxuznYQL6s4od-ln-FoDQ2cE0,6269
+ ai_edge_torch/generative/layers/kv_cache.py,sha256=DhHIggaOQ2IAY4aRuMAuCLWZv1dBz5PYtmOEjkx9EQY,6291
  ai_edge_torch/generative/layers/model_config.py,sha256=viX51T_naJ9sPpPxPoMnSueBPYE2zxWNOD0xn0f-_bM,7510
  ai_edge_torch/generative/layers/normalization.py,sha256=h2btgRHMMjOcyLm8adEmcT0pG6imq4QcWblKJK5MYXA,7479
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=xxWtlVsGGJkEyXC6PwznubyhJnLPEfSpHOORE_hgxss,2670
@@ -139,8 +139,8 @@ ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudj
  ai_edge_torch/generative/test/test_custom_dus.py,sha256=gxG78CcTpXF3iLzDR15Rlz1ey1tNTlSdkp6TeYEijp0,3301
  ai_edge_torch/generative/test/test_kv_cache.py,sha256=2AulHBS3hC4b_68PNNBkRVOrypy4IM5YjC4p-6dgCMM,3793
  ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
- ai_edge_torch/generative/test/test_model_conversion.py,sha256=4d_UF19KYf5xFa3yhQGe1nu3TKzmXrbr9PFiEZpPlyk,6274
- ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=4lfZXjXqfjiyxc2s8vMuYbdOZzD8VuPelr2AQo9PFNI,11656
+ ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
+ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=mVuax3MPRmuNjnDRKXqtc9YmswCy7MnhD1CHADK-3nk,11501
  ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
  ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
@@ -159,12 +159,11 @@ ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge
  ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=NP2mYhe5D2GjtqQfqqldp-ko3xtNghuFKKJOQskUJFI,10041
  ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=ivq0eVjuf31idfNY0E12F4FxdkSI9hwYXapLJBkIf8Q,4831
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=j8WpeS-mz3Zr4I7p7NwanQzkQNeH0asZ7lz5y7twgQ4,8447
  ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
- ai_edge_torch/lowertools/_shim.py,sha256=ilL7x1ebUBj1clg7bagrX4y_nVSHiGrvDrOVfuTeenE,3039
+ ai_edge_torch/lowertools/_shim.py,sha256=xJIHDSWNoF4PkkT0JkjeJxgguQ9JGEwooJf9xZNkVRU,3058
  ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
  ai_edge_torch/lowertools/odml_torch_utils.py,sha256=Smt7p62-lZ_3bBBfnbssAK5GAGxm3U_X7M-1qwsmc68,8161
- ai_edge_torch/lowertools/test_utils.py,sha256=bPgc2iXX16KYtMNvmsRdKfrCY6UJmcfitfCOvHoD7Oc,1930
+ ai_edge_torch/lowertools/test_utils.py,sha256=mdxTlhqHABZEQ_GEmPFCL8LIAWtqRtYZUGdSY1ieZjw,1949
  ai_edge_torch/lowertools/torch_xla_utils.py,sha256=XGZE0vZG9WSQT-6dFmPlU8W89z8rfXPRGjuZeuhXCIw,9205
  ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1K5vHK0RjgBFtxlE,5652
  ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
@@ -183,12 +182,12 @@ ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkg
  ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
  ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
  ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=0GytV1dGnqe1mKityqQDNFNS8T4QBg3UZuRJcGHwGyA,993
- ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=ufvnaAh6rM_yfoc8ybI3VErHEVBv5W_p4iOe9slfwKM,9948
+ ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=8mZTp_ybcMO3tDRQdlDP68BVeTw560XsTR4XH-ldTdc,9987
  ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
- ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=v1VdKmL8YLJv3PR9VgyNghO83A25PpTzY2ZUAJqlq3Q,6847
+ ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
  ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=OVmlPGwyhDXKhmG4SAeEsa6iLpJHEHV_jKqwfjYvetA,11643
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=1ePJs7oIdUkVdMddFsXMc53qTkEKqGz0ZhQQoNzBa10,2862
- ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=rFmzqcdjYrwhcxH8j9zCFStPy21HFF7hkUV_GQ8FPAk,6056
+ ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
+ ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=GEs83mtEjh8GOW_OATI_ur11VKujrOL2xdZeZ0l1HtM,6100
  ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
  ai_edge_torch/odml_torch/lowerings/decomp.py,sha256=UoJeZVcr4zAN_11i-HzfOhxGCxUm-7b1JXPVBxR2hSs,2414
  ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
@@ -201,8 +200,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
- ai_edge_torch_nightly-0.3.0.dev20241213.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
- ai_edge_torch_nightly-0.3.0.dev20241213.dist-info/METADATA,sha256=Yzw2YkrbFAe1EbxfoKBDME5NQe0GIzzseUOVylqFpnM,1897
- ai_edge_torch_nightly-0.3.0.dev20241213.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
- ai_edge_torch_nightly-0.3.0.dev20241213.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
- ai_edge_torch_nightly-0.3.0.dev20241213.dist-info/RECORD,,
+ ai_edge_torch_nightly-0.3.0.dev20241214.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ ai_edge_torch_nightly-0.3.0.dev20241214.dist-info/METADATA,sha256=fUbq26zB0WUU1l6eUud8vq3Nm3KSIhox74pzFSFTmoM,1966
+ ai_edge_torch_nightly-0.3.0.dev20241214.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ ai_edge_torch_nightly-0.3.0.dev20241214.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ ai_edge_torch_nightly-0.3.0.dev20241214.dist-info/RECORD,,
ai_edge_torch/config.py DELETED
@@ -1,27 +0,0 @@
- # Copyright 2024 The AI Edge Torch Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
-
- """Provides a configuration for the AI Edge Torch library."""
-
- import dataclasses
- import os
-
-
- @dataclasses.dataclass
- class Config:
-   use_torch_xla: bool = os.environ.get("USE_TORCH_XLA", "true").lower() in (
-       "1",
-       "true",
-   )
ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py DELETED
@@ -1,283 +0,0 @@
- # Copyright 2024 The AI Edge Torch Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================
- """Tests for StableHLOCompositeBuilder."""
-
- import math
-
- from ai_edge_torch import config
- from ai_edge_torch import lowertools
- from ai_edge_torch.hlfb import StableHLOCompositeBuilder
- import torch
- import torch.nn.functional as F
-
- from absl.testing import absltest as googletest
-
-
- def _export_stablehlo_mlir(model, args):
-   ep = torch.export.export(model, args)
-   return lowertools.exported_program_to_mlir_text(ep)
-
-
- @googletest.skipIf(
-     not config.Config.use_torch_xla,
-     reason="The odml_torch counter part is in odml_torch.",
- )
- class TestStableHLOCompositeBuilder(googletest.TestCase):
-
-   def test_build_composite(self):
-     class SampleModel(torch.nn.Module):
-
-       def forward(self, x):
-         builder = StableHLOCompositeBuilder(name="test.plus_two")
-         y = x + 1
-         y = builder.mark_inputs(y)
-         z = y + 2
-         z = builder.mark_outputs(z)
-         return z
-
-     mlir = _export_stablehlo_mlir(SampleModel().eval(), (torch.rand((2, 2)),))
-     self.assertEqual(mlir.count('stablehlo.composite "test.plus_two"'), 1)
-
-   def test_build_multiple_composites(self):
-     class SampleModel(torch.nn.Module):
-
-       def plus_one(self, x: torch.Tensor):
-         builder = StableHLOCompositeBuilder("test.plus_one")
-         x = builder.mark_inputs(x)
-         y = x + 1
-         y = builder.mark_outputs(y)
-         return y
-
-       def plus_two(self, x: torch.Tensor):
-         builder = StableHLOCompositeBuilder("test.plus_two")
-         x = builder.mark_inputs(x)
-         y = x + 2
-         y = builder.mark_outputs(y)
-         return y
-
-       def forward(self, x):
-         x = self.plus_two(x)
-         x = x + 3
-         x = self.plus_one(x)
-         x = x + 4
-         x = self.plus_two(x)
-         return x
-
-     mlir = _export_stablehlo_mlir(SampleModel().eval(), (torch.rand((2, 2)),))
-     self.assertEqual(mlir.count('stablehlo.composite "test.plus_one"'), 1)
-     self.assertEqual(mlir.count('stablehlo.composite "test.plus_two"'), 2)
-
-   def test_build_composite_with_attr(self):
-     class SampleModel(torch.nn.Module):
-
-       def __init__(self):
-         super().__init__()
-
-       def log_softmax(self, x: torch.Tensor, dim: int):
-         builder = StableHLOCompositeBuilder(
-             name="test.log_softmax", attr={"dim": dim}
-         )
-         x = builder.mark_inputs(x)
-         y = torch.nn.functional.log_softmax(x, dim=dim)
-         y = builder.mark_outputs(y)
-         return y
-
-       def forward(self, x):
-         x = x + 1
-         x = self.log_softmax(x, 0)
-         x = self.log_softmax(x, 1)
-         return x
-
-     mlir = _export_stablehlo_mlir(SampleModel().eval(), (torch.rand((2, 2)),))
-     self.assertEqual(mlir.count('stablehlo.composite "test.log_softmax"'), 2)
-     self.assertEqual(mlir.count("composite_attributes = {dim = 0 : i64}"), 1)
-     self.assertEqual(mlir.count("composite_attributes = {dim = 1 : i64}"), 1)
-
-   def test_build_composite_with_mix_type_attrs(self):
-     class SampleModel(torch.nn.Module):
-
-       def __init__(self):
-         super().__init__()
-
-       def log_softmax(self, x: torch.Tensor, dim: int):
-         builder = StableHLOCompositeBuilder(
-             name="test.log_softmax",
-             attr={
-                 "dim": dim,
-                 "source": "torch.nn",
-                 "version": 1.0,
-             },
-         )
-         x = builder.mark_inputs(x)
-         y = torch.nn.functional.log_softmax(x, dim=dim)
-         y = builder.mark_outputs(y)
-         return y
-
-       def forward(self, x):
-         x = x + 1
-         x = self.log_softmax(x, 0)
-         return x
-
-     mlir = _export_stablehlo_mlir(SampleModel().eval(), (torch.rand((2, 2)),))
-     self.assertEqual(mlir.count('stablehlo.composite "test.log_softmax"'), 1)
-     self.assertEqual(
-         mlir.count(
-             'composite_attributes = {dim = 0 : i64, source = "torch.nn",'
-             " version = 1.000000e+00 : f32}"
-         ),
-         1,
-     )
-
-   def test_sdpa_composite(self):
-     class SDPAModel(torch.nn.Module):
-
-       def scaled_dot_product_attention(
-           self,
-           q: torch.Tensor,
-           k: torch.Tensor,
-           v: torch.Tensor,
-           head_size: int,
-           mask: torch.Tensor,
-       ):
-         builder = StableHLOCompositeBuilder("test.scaled_dot_product_attention")
-         q, k, v, mask = builder.mark_inputs(q, k, v, mask)
-
-         scale = 1.0 / math.sqrt(head_size)
-
-         q = q.transpose(1, 2)
-         k = k.transpose(1, 2)
-         v = v.transpose(1, 2)
-         y = F.scaled_dot_product_attention(
-             q,
-             k,
-             v,
-             attn_mask=mask,
-             dropout_p=0.0,
-             is_causal=mask is None,
-             scale=scale,
-         )
-         result = y.transpose(1, 2)
-         result = builder.mark_outputs(result)
-         return result
-
-       def forward(self, q, k, v, mask):
-         x = self.scaled_dot_product_attention(
-             q,
-             k,
-             v,
-             8,
-             mask,
-         )
-         return x
-
-     query = torch.rand(1, 1, 32, 4)
-     key = torch.rand(1, 500, 1, 4)
-     value = torch.rand(1, 500, 1, 4)
-     mask = torch.rand(1, 1, 1, 500)
-
-     mlir = _export_stablehlo_mlir(
-         SDPAModel().eval(),
-         (query, key, value, mask),
-     )
-     self.assertEqual(
-         mlir.count('stablehlo.composite "test.scaled_dot_product_attention"'), 1
-     )
-
-   def test_sdpa_composite_with_attr(self):
-     class SDPAModel(torch.nn.Module):
-
-       def scaled_dot_product_attention(
-           self,
-           q: torch.Tensor,
-           k: torch.Tensor,
-           v: torch.Tensor,
-           head_size: int,
-           include_captanh: bool,
-       ):
-         builder = StableHLOCompositeBuilder(
-             name="test.scaled_dot_product_attention",
-             attr={"include_captanh": include_captanh},
-         )
-         q, k, v = builder.mark_inputs(q, k, v)
-
-         scale = 1.0 / math.sqrt(head_size)
-
-         q = q.transpose(1, 2)
-         k = k.transpose(1, 2)
-         v = v.transpose(1, 2)
-         y = F.scaled_dot_product_attention(
-             q,
-             k,
-             v,
-             attn_mask=None,
-             dropout_p=0.0,
-             is_causal=True,
-             scale=scale,
-         )
-         result = y.transpose(1, 2)
-         result = builder.mark_outputs(result)
-         return result
-
-       def forward(self, q, k, v):
-         x = self.scaled_dot_product_attention(q, k, v, 8, True)
-         y = self.scaled_dot_product_attention(q, k, v, 8, False)
-         return x + y
-
-     query = torch.rand(1, 1, 32, 4)
-     key = torch.rand(1, 500, 1, 4)
-     value = torch.rand(1, 500, 1, 4)
-     mlir = _export_stablehlo_mlir(
-         SDPAModel().eval(),
-         (query, key, value),
-     )
-     self.assertEqual(
-         mlir.count('stablehlo.composite "test.scaled_dot_product_attention"'), 2
-     )
-     self.assertEqual(
-         mlir.count("composite_attributes = {include_captanh = true}"), 1
-     )
-     self.assertEqual(
-         mlir.count("composite_attributes = {include_captanh = false}"), 1
-     )
-
-   def test_build_composite_with_multiple_inputs_outputs(self):
-     class SampleModel(torch.nn.Module):
-
-       def mimo_sample(self, a, b, c):
-         builder = StableHLOCompositeBuilder(name="test.mimo_sample")
-
-         a, b, c = builder.mark_inputs(a, b, c)
-         x = a + b + c
-         y = (a - b) * x
-         z = (c + 1.0) * a
-         x, y, z = builder.mark_outputs(x, y, z)
-
-         result = x + y * z
-         return result
-
-       def forward(self, a, b, c):
-         x = self.mimo_sample(a, b, c)
-         x = self.mimo_sample(a, b, x)
-         x = self.mimo_sample(x, x, c)
-         return x
-
-     mlir = _export_stablehlo_mlir(
-         SampleModel().eval(), (torch.rand(2), torch.rand(2), torch.rand(2))
-     )
-     self.assertEqual(mlir.count('stablehlo.composite "test.mimo_sample"'), 3)
-
-
- if __name__ == "__main__":
-   googletest.main()