ai-edge-torch-nightly 0.3.0.dev20241213__py3-none-any.whl → 0.3.0.dev20241214__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ai_edge_torch/__init__.py CHANGED
@@ -13,13 +13,13 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
+ from ai_edge_torch._config import config
16
17
  from ai_edge_torch._convert.converter import convert
17
18
  from ai_edge_torch._convert.converter import signature
18
19
  from ai_edge_torch._convert.to_channel_last_io import to_channel_last_io
19
20
  from ai_edge_torch.model import Model
20
21
  from ai_edge_torch.version import __version__
21
22
 
22
-
23
23
  def load(path: str) -> Model:
24
24
  """Imports an ai_edge_torch model from disk.
25
25
 
@@ -0,0 +1,52 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Provides a configuration for the ai-edge-torch."""
17
+
18
+ import functools
19
+ import logging
20
+ import os
21
+
22
+ __all__ = ["config"]
23
+
24
+
25
+ class _Config:
26
+ """ai-edge-torch global configs."""
27
+
28
+ @property
29
+ @functools.cache # pylint: disable=method-cache-max-size-none
30
+ def use_torch_xla(self) -> bool:
31
+ """True if using torch_xla to lower torch ops to StableHLO.
32
+
33
+ To use torch_xla as the lowering backend, set environment variable
34
+ `USE_TORCH_XLA` to "true".
35
+ """
36
+ var = os.environ.get("USE_TORCH_XLA", "false")
37
+ var = var.lower().strip()
38
+ if var in ("y", "yes", "t", "true", "on", "1"):
39
+ return True
40
+ elif var in ("n", "no", "f", "false", "off", "0"):
41
+ return False
42
+ else:
43
+ logging.warning("Invalid USE_TORCH_XLA value is ignored: %s.", var)
44
+ return False
45
+
46
+ @property
47
+ def in_oss(self) -> bool:
48
+ """True if the code is not running in google internal environment."""
49
+ return True
50
+
51
+
52
+ config = _Config()
@@ -19,7 +19,6 @@ import os
19
19
  from typing import Tuple
20
20
 
21
21
  import ai_edge_torch
22
- from ai_edge_torch import config
23
22
  from ai_edge_torch._convert import conversion_utils
24
23
  from ai_edge_torch.quantize import pt2e_quantizer
25
24
  from ai_edge_torch.testing import model_coverage
@@ -292,7 +291,7 @@ class TestConvert(googletest.TestCase):
292
291
  self.assertTrue(result)
293
292
 
294
293
  @googletest.skipIf(
295
- not config.Config.use_torch_xla,
294
+ not ai_edge_torch.config.use_torch_xla,
296
295
  reason="Shape polymorphism is not yet support with odml_torch.",
297
296
  )
298
297
  def test_convert_model_with_dynamic_batch(self):
@@ -20,6 +20,7 @@ from typing import List, Tuple
20
20
 
21
21
  from ai_edge_torch import hlfb
22
22
  from ai_edge_torch.generative.layers import model_config
23
+ from ai_edge_torch.generative.utilities.dynamic_update_slice import dynamic_update_slice
23
24
  import torch
24
25
  import torch.utils._pytree as pytree
25
26
 
@@ -159,8 +160,6 @@ def update(
159
160
  Returns:
160
161
  KVCacheEntry: The updated KVCache entry based on the passed inputs.
161
162
  """
162
- # Turn dynamic_update_slice updates off for now.
163
- use_dus=False
164
163
  update_kv_cache = _update_kv_impl if use_dus else _update_kv_base_impl
165
164
  return update_kv_cache(cache, input_pos, k_slice, v_slice)
166
165
 
@@ -16,7 +16,6 @@
16
16
  """Testing model conversion for a few gen-ai models."""
17
17
 
18
18
  import ai_edge_torch
19
- from ai_edge_torch import config as ai_edge_config
20
19
  from ai_edge_torch.generative.examples.test_models import toy_model_with_kv_cache
21
20
  from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
22
21
  from ai_edge_torch.generative.layers import kv_cache
@@ -83,22 +82,22 @@ class TestModelConversion(googletest.TestCase):
83
82
  )
84
83
 
85
84
  @googletest.skipIf(
86
- ai_edge_config.Config.use_torch_xla,
87
- reason="tests with custom ops are not supported on oss",
85
+ ai_edge_torch.config.in_oss,
86
+ reason="tests with custom ops are not supported in oss",
88
87
  )
89
88
  def test_toy_model_with_kv_cache(self):
90
89
  self._test_model_with_kv_cache(enable_hlfb=False)
91
90
 
92
91
  @googletest.skipIf(
93
- ai_edge_config.Config.use_torch_xla,
94
- reason="tests with custom ops are not supported on oss",
92
+ ai_edge_torch.config.in_oss,
93
+ reason="tests with custom ops are not supported in oss",
95
94
  )
96
95
  def test_toy_model_with_kv_cache_with_hlfb(self):
97
96
  self._test_model_with_kv_cache(enable_hlfb=True)
98
97
 
99
98
  @googletest.skipIf(
100
- ai_edge_config.Config.use_torch_xla,
101
- reason="tests with custom ops are not supported on oss",
99
+ ai_edge_torch.config.in_oss,
100
+ reason="tests with custom ops are not supported in oss",
102
101
  )
103
102
  def test_toy_model_has_dus_op(self):
104
103
  """Tests that the model has the dynamic update slice op."""
@@ -179,8 +178,8 @@ class TestModelConversion(googletest.TestCase):
179
178
  )
180
179
 
181
180
  @googletest.skipIf(
182
- ai_edge_config.Config.use_torch_xla,
183
- reason="tests with custom ops are not supported on oss",
181
+ ai_edge_torch.config.in_oss,
182
+ reason="tests with custom ops are not supported in oss",
184
183
  )
185
184
  def test_tiny_llama_multisig(self):
186
185
  config = tiny_llama.get_fake_model_config()
@@ -16,7 +16,6 @@
16
16
  """Testing model conversion for a few gen-ai models."""
17
17
 
18
18
  import ai_edge_torch
19
- from ai_edge_torch import config as ai_edge_config
20
19
  from ai_edge_torch.generative.examples.amd_llama_135m import amd_llama_135m
21
20
  from ai_edge_torch.generative.examples.gemma import gemma1
22
21
  from ai_edge_torch.generative.examples.gemma import gemma2
@@ -91,8 +90,8 @@ class TestModelConversion(googletest.TestCase):
91
90
  )
92
91
 
93
92
  @googletest.skipIf(
94
- ai_edge_config.Config.use_torch_xla,
95
- reason="tests with custom ops are not supported on oss",
93
+ ai_edge_torch.config.in_oss,
94
+ reason="tests with custom ops are not supported in oss",
96
95
  )
97
96
  def test_gemma1(self):
98
97
  config = gemma1.get_fake_model_config()
@@ -100,8 +99,8 @@ class TestModelConversion(googletest.TestCase):
100
99
  self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
101
100
 
102
101
  @googletest.skipIf(
103
- ai_edge_config.Config.use_torch_xla,
104
- reason="tests with custom ops are not supported on oss",
102
+ ai_edge_torch.config.in_oss,
103
+ reason="tests with custom ops are not supported in oss",
105
104
  )
106
105
  def test_gemma2(self):
107
106
  config = gemma2.get_fake_model_config()
@@ -109,8 +108,8 @@ class TestModelConversion(googletest.TestCase):
109
108
  self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
110
109
 
111
110
  @googletest.skipIf(
112
- ai_edge_config.Config.use_torch_xla,
113
- reason="tests with custom ops are not supported on oss",
111
+ ai_edge_torch.config.in_oss,
112
+ reason="tests with custom ops are not supported in oss",
114
113
  )
115
114
  def test_llama(self):
116
115
  config = llama.get_fake_model_config()
@@ -118,8 +117,8 @@ class TestModelConversion(googletest.TestCase):
118
117
  self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
119
118
 
120
119
  @googletest.skipIf(
121
- ai_edge_config.Config.use_torch_xla,
122
- reason="tests with custom ops are not supported on oss",
120
+ ai_edge_torch.config.in_oss,
121
+ reason="tests with custom ops are not supported in oss",
123
122
  )
124
123
  def test_phi2(self):
125
124
  config = phi2.get_fake_model_config()
@@ -128,8 +127,8 @@ class TestModelConversion(googletest.TestCase):
128
127
  self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
129
128
 
130
129
  @googletest.skipIf(
131
- ai_edge_config.Config.use_torch_xla,
132
- reason="tests with custom ops are not supported on oss",
130
+ ai_edge_torch.config.in_oss,
131
+ reason="tests with custom ops are not supported in oss",
133
132
  )
134
133
  def test_phi3(self):
135
134
  config = phi3.get_fake_model_config()
@@ -137,8 +136,8 @@ class TestModelConversion(googletest.TestCase):
137
136
  self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
138
137
 
139
138
  @googletest.skipIf(
140
- ai_edge_config.Config.use_torch_xla,
141
- reason="tests with custom ops are not supported on oss",
139
+ ai_edge_torch.config.in_oss,
140
+ reason="tests with custom ops are not supported in oss",
142
141
  )
143
142
  def test_smollm(self):
144
143
  config = smollm.get_fake_model_config()
@@ -146,8 +145,8 @@ class TestModelConversion(googletest.TestCase):
146
145
  self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
147
146
 
148
147
  @googletest.skipIf(
149
- ai_edge_config.Config.use_torch_xla,
150
- reason="tests with custom ops are not supported on oss",
148
+ ai_edge_torch.config.in_oss,
149
+ reason="tests with custom ops are not supported in oss",
151
150
  )
152
151
  def test_openelm(self):
153
152
  config = openelm.get_fake_model_config()
@@ -155,8 +154,8 @@ class TestModelConversion(googletest.TestCase):
155
154
  self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
156
155
 
157
156
  @googletest.skipIf(
158
- ai_edge_config.Config.use_torch_xla,
159
- reason="tests with custom ops are not supported on oss",
157
+ ai_edge_torch.config.in_oss,
158
+ reason="tests with custom ops are not supported in oss",
160
159
  )
161
160
  def test_qwen(self):
162
161
  config = qwen.get_fake_model_config()
@@ -164,8 +163,8 @@ class TestModelConversion(googletest.TestCase):
164
163
  self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)
165
164
 
166
165
  @googletest.skipIf(
167
- ai_edge_config.Config.use_torch_xla,
168
- reason="tests with custom ops are not supported on oss",
166
+ ai_edge_torch.config.in_oss,
167
+ reason="tests with custom ops are not supported in oss",
169
168
  )
170
169
  def test_amd_llama_135m(self):
171
170
  config = amd_llama_135m.get_fake_model_config()
@@ -173,8 +172,8 @@ class TestModelConversion(googletest.TestCase):
173
172
  self._test_model(config, pytorch_model, "prefill", atol=1e-5, rtol=1e-5)
174
173
 
175
174
  @googletest.skipIf(
176
- ai_edge_config.Config.use_torch_xla,
177
- reason="tests with custom ops are not supported on oss",
175
+ ai_edge_torch.config.in_oss,
176
+ reason="tests with custom ops are not supported in oss",
178
177
  )
179
178
  def disabled_test_paligemma(self):
180
179
  config = paligemma.get_fake_model_config()
@@ -222,8 +221,8 @@ class TestModelConversion(googletest.TestCase):
222
221
  )
223
222
 
224
223
  @googletest.skipIf(
225
- ai_edge_config.Config.use_torch_xla,
226
- reason="tests with custom ops are not supported on oss",
224
+ ai_edge_torch.config.in_oss,
225
+ reason="tests with custom ops are not supported in oss",
227
226
  )
228
227
  def test_stable_diffusion_clip(self):
229
228
  config = sd_clip.get_fake_model_config()
@@ -254,8 +253,8 @@ class TestModelConversion(googletest.TestCase):
254
253
  )
255
254
 
256
255
  @googletest.skipIf(
257
- ai_edge_config.Config.use_torch_xla,
258
- reason="tests with custom ops are not supported on oss",
256
+ ai_edge_torch.config.in_oss,
257
+ reason="tests with custom ops are not supported in oss",
259
258
  )
260
259
  def test_stable_diffusion_diffusion(self):
261
260
  config = sd_diffusion.get_fake_model_config(2)
@@ -296,8 +295,8 @@ class TestModelConversion(googletest.TestCase):
296
295
  )
297
296
 
298
297
  @googletest.skipIf(
299
- ai_edge_config.Config.use_torch_xla,
300
- reason="tests with custom ops are not supported on oss",
298
+ ai_edge_torch.config.in_oss,
299
+ reason="tests with custom ops are not supported in oss",
301
300
  )
302
301
  def test_stable_diffusion_decoder(self):
303
302
  config = sd_decoder.get_fake_model_config()
@@ -15,13 +15,15 @@
15
15
 
16
16
  from typing import Any, Optional
17
17
 
18
- from ai_edge_torch import config
18
+ from ai_edge_torch import _config
19
19
  from ai_edge_torch._convert import signature
20
20
  from ai_edge_torch.quantize import quant_config as qcfg
21
21
  import torch
22
22
 
23
+ config = _config.config
24
+
23
25
  # isort: off
24
- if config.Config.use_torch_xla:
26
+ if config.use_torch_xla:
25
27
  from ai_edge_torch.lowertools import torch_xla_utils as utils
26
28
  from ai_edge_torch.lowertools.torch_xla_utils import exported_program_to_mlir_text
27
29
  from torch_xla.experimental.mark_pattern_utils import StableHLOCompositeBuilder
@@ -15,9 +15,11 @@
15
15
 
16
16
  import re
17
17
  from typing import Optional
18
- from ai_edge_torch import config
18
+ from ai_edge_torch import _config
19
19
  from absl.testing import absltest as googletest
20
20
 
21
+ config = _config.config
22
+
21
23
 
22
24
  def _extract_backend_configs(mlir):
23
25
  mlir = mlir.replace("\\22", '"')
@@ -38,7 +40,7 @@ def assert_string_count(
38
40
  if odml_torch_attr_counter is None:
39
41
  odml_torch_attr_counter = {}
40
42
 
41
- if config.Config.use_torch_xla:
43
+ if config.use_torch_xla:
42
44
  for key in torch_xla_pattern_counter:
43
45
  test_case.assertEqual(
44
46
  mlir.count(key),
@@ -276,11 +276,13 @@ def _aten_slice_scatter(lctx, self, src, dim=0, start=None, end=None, step=1):
276
276
  interior_padding if i == dim else 0 for i in range(rank)
277
277
  ],
278
278
  )
279
- pred = np.ones(self.type.shape, dtype=np.bool_)
280
- pred[*[
279
+
280
+ slices = [
281
281
  slice(start, end, step) if i == dim else slice(None, None, None)
282
282
  for i in range(rank)
283
- ]] = False
283
+ ]
284
+ pred = np.ones(self.type.shape, dtype=np.bool_)
285
+ pred[np.index_exp[tuple(slices)]] = False
284
286
  pred = stablehlo.constant(
285
287
  ir.DenseElementsAttr.get(
286
288
  np.packbits(pred, bitorder="little"),
@@ -232,7 +232,9 @@ def _aten_convolution(
232
232
 
233
233
  if bias is not None:
234
234
  # broadcast [C] to [NCHW]
235
- broadcasted_bias = stablehlo.broadcast_in_dim(output_type, bias, [1])
235
+ broadcasted_bias = stablehlo.broadcast_in_dim(
236
+ output_type, bias, ir.DenseI64ArrayAttr.get([1])
237
+ )
236
238
  res = stablehlo.add(
237
239
  lhs=res,
238
240
  rhs=broadcasted_bias,
@@ -20,6 +20,7 @@ from ai_edge_torch.odml_torch.lowerings import registry
20
20
  from ai_edge_torch.odml_torch.lowerings import utils
21
21
  from jax._src.lib.mlir import ir
22
22
  from jax._src.lib.mlir.dialects import hlo as stablehlo
23
+ import numpy as np
23
24
  import torch
24
25
 
25
26
 
@@ -66,12 +67,20 @@ def _aten_native_layer_norm(
66
67
  normalized_rank = len(normalized_shape)
67
68
  if weight is not None:
68
69
  weight = stablehlo.broadcast_in_dim(
69
- data_type, weight, list(range(data_rank - normalized_rank, data_rank))
70
+ data_type,
71
+ weight,
72
+ ir.DenseI64ArrayAttr.get(
73
+ list(range(data_rank - normalized_rank, data_rank))
74
+ ),
70
75
  )
71
76
  output = stablehlo.multiply(weight, output)
72
77
  if bias is not None:
73
78
  bias = stablehlo.broadcast_in_dim(
74
- data_type, bias, list(range(data_rank - normalized_rank, data_rank))
79
+ data_type,
80
+ bias,
81
+ ir.DenseI64ArrayAttr.get(
82
+ list(range(data_rank - normalized_rank, data_rank))
83
+ ),
75
84
  )
76
85
  output = stablehlo.add(bias, output)
77
86
 
@@ -13,7 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  """Lowerings for PT2E torch.ops.quantized_decomposed ops."""
16
- from typing import Union, cast
16
+ from typing import Optional, Union, cast
17
17
 
18
18
  from ai_edge_torch.odml_torch.lowerings import context
19
19
  from ai_edge_torch.odml_torch.lowerings import utils
@@ -30,15 +30,15 @@ LoweringContext = context.LoweringContext
30
30
 
31
31
 
32
32
  def _uniform_quantized_type(
33
- stored_type: str | ir.Type,
34
- expressed_type: str | ir.Type,
33
+ stored_type: Union[str, ir.Type],
34
+ expressed_type: Union[str, ir.Type],
35
35
  *,
36
- scale=float | list[float] | tuple[float],
37
- zero_point=float | list[float] | tuple[float],
38
- storage_type_min: int | None = None,
39
- storage_type_max: int | None = None,
40
- channel_axis: int | None = None,
41
- channel_axis_size: int | None = None,
36
+ scale=Union[float, list[float], tuple[float]],
37
+ zero_point=Union[float, list[float], tuple[float]],
38
+ storage_type_min: Optional[int] = None,
39
+ storage_type_max: Optional[int] = None,
40
+ channel_axis: Optional[int] = None,
41
+ channel_axis_size: Optional[int] = None,
42
42
  ):
43
43
  """Polyfill for quant.UniformQuantizedType."""
44
44
  if storage_type_min and storage_type_max:
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20241213"
16
+ __version__ = "0.3.0.dev20241214"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20241213
3
+ Version: 0.3.0.dev20241214
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -11,7 +11,6 @@ Classifier: Intended Audience :: Science/Research
11
11
  Classifier: License :: OSI Approved :: Apache Software License
12
12
  Classifier: Programming Language :: Python :: 3
13
13
  Classifier: Programming Language :: Python :: 3 :: Only
14
- Classifier: Programming Language :: Python :: 3.9
15
14
  Classifier: Programming Language :: Python :: 3.10
16
15
  Classifier: Programming Language :: Python :: 3.11
17
16
  Classifier: Topic :: Scientific/Engineering
@@ -20,7 +19,7 @@ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
20
19
  Classifier: Topic :: Software Development
21
20
  Classifier: Topic :: Software Development :: Libraries
22
21
  Classifier: Topic :: Software Development :: Libraries :: Python Modules
23
- Requires-Python: >=3.9
22
+ Requires-Python: >=3.10
24
23
  Description-Content-Type: text/markdown
25
24
  License-File: LICENSE
26
25
  Requires-Dist: numpy
@@ -28,10 +27,13 @@ Requires-Dist: scipy
28
27
  Requires-Dist: safetensors
29
28
  Requires-Dist: tabulate
30
29
  Requires-Dist: torch>=2.4.0
31
- Requires-Dist: torch-xla>=2.4.0
32
30
  Requires-Dist: tf-nightly>=2.19.0.dev20241201
33
31
  Requires-Dist: ai-edge-litert-nightly
34
32
  Requires-Dist: ai-edge-quantizer-nightly
33
+ Requires-Dist: jax
34
+ Requires-Dist: torch-xla2[odml]>=0.0.1.dev20241201
35
+ Provides-Extra: torch-xla
36
+ Requires-Dist: torch-xla>=2.4.0; extra == "torch-xla"
35
37
 
36
38
  Library that supports converting PyTorch models into a .tflite format, which can
37
39
  then be run with TensorFlow Lite and MediaPipe. This enables applications for
@@ -1,9 +1,9 @@
1
- ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,1168
2
- ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
1
+ ai_edge_torch/__init__.py,sha256=rq9ZtMJLG8yYNC4tNE4rpl94UAUClZW7f4GAr6HBVDQ,1208
2
+ ai_edge_torch/_config.py,sha256=QIrerb6uHMahRvMilmhodJ_6jfiRps3qgLOBeidPnS4,1614
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
5
5
  ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
6
- ai_edge_torch/version.py,sha256=ZKmpBJjKnl93nlPyg2RYz15citIW0ntqZ-0diRjwTt8,706
6
+ ai_edge_torch/version.py,sha256=iCH8lnlOrtbGwvxnT3knpY_keeu2UnrJ_ZXNK2LSvf4,706
7
7
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
8
8
  ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
9
9
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -26,7 +26,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
26
26
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
27
27
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=mzfL9cf0qBnpmxM_OlMQFvQsEZV2B_Mia9yEJV4J7rI,7135
28
28
  ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
29
- ai_edge_torch/_convert/test/test_convert.py,sha256=v6AhfWqRBuHT7uBDueTbntaQtDSMMrvQOqlIDXNUaMA,17250
29
+ ai_edge_torch/_convert/test/test_convert.py,sha256=gK9QJuLbpjXt0l6tVnzl9Miq6GLkJR-hB67i3VE13Og,17224
30
30
  ai_edge_torch/_convert/test/test_convert_composites.py,sha256=BCIODgxMI_3MxMLfNWYMGjcz-al-J3z5eDHCiZJXNwY,7992
31
31
  ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
32
32
  ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
@@ -119,7 +119,7 @@ ai_edge_torch/generative/layers/attention.py,sha256=aOoVM1hY7qjvzVQI1-m26p_f9qoT
119
119
  ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
120
120
  ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
121
121
  ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
122
- ai_edge_torch/generative/layers/kv_cache.py,sha256=dOhk3ec21189uPyCDYyxuznYQL6s4od-ln-FoDQ2cE0,6269
122
+ ai_edge_torch/generative/layers/kv_cache.py,sha256=DhHIggaOQ2IAY4aRuMAuCLWZv1dBz5PYtmOEjkx9EQY,6291
123
123
  ai_edge_torch/generative/layers/model_config.py,sha256=viX51T_naJ9sPpPxPoMnSueBPYE2zxWNOD0xn0f-_bM,7510
124
124
  ai_edge_torch/generative/layers/normalization.py,sha256=h2btgRHMMjOcyLm8adEmcT0pG6imq4QcWblKJK5MYXA,7479
125
125
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=xxWtlVsGGJkEyXC6PwznubyhJnLPEfSpHOORE_hgxss,2670
@@ -139,8 +139,8 @@ ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudj
139
139
  ai_edge_torch/generative/test/test_custom_dus.py,sha256=gxG78CcTpXF3iLzDR15Rlz1ey1tNTlSdkp6TeYEijp0,3301
140
140
  ai_edge_torch/generative/test/test_kv_cache.py,sha256=2AulHBS3hC4b_68PNNBkRVOrypy4IM5YjC4p-6dgCMM,3793
141
141
  ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
142
- ai_edge_torch/generative/test/test_model_conversion.py,sha256=4d_UF19KYf5xFa3yhQGe1nu3TKzmXrbr9PFiEZpPlyk,6274
143
- ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=4lfZXjXqfjiyxc2s8vMuYbdOZzD8VuPelr2AQo9PFNI,11656
142
+ ai_edge_torch/generative/test/test_model_conversion.py,sha256=jfqkECCX7XKHeBAuDXrkwQJf0vM72eG3LMc5rluha84,6191
143
+ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=mVuax3MPRmuNjnDRKXqtc9YmswCy7MnhD1CHADK-3nk,11501
144
144
  ai_edge_torch/generative/test/test_quantize.py,sha256=bEJMhpQ9bIDUZVBXTW888728FcH-i3SyE4JSZZUgU0A,6071
145
145
  ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
146
146
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
@@ -159,12 +159,11 @@ ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge
159
159
  ai_edge_torch/hlfb/mark_pattern/pattern.py,sha256=NP2mYhe5D2GjtqQfqqldp-ko3xtNghuFKKJOQskUJFI,10041
160
160
  ai_edge_torch/hlfb/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
161
161
  ai_edge_torch/hlfb/test/test_mark_pattern.py,sha256=ivq0eVjuf31idfNY0E12F4FxdkSI9hwYXapLJBkIf8Q,4831
162
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=j8WpeS-mz3Zr4I7p7NwanQzkQNeH0asZ7lz5y7twgQ4,8447
163
162
  ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
164
- ai_edge_torch/lowertools/_shim.py,sha256=ilL7x1ebUBj1clg7bagrX4y_nVSHiGrvDrOVfuTeenE,3039
163
+ ai_edge_torch/lowertools/_shim.py,sha256=xJIHDSWNoF4PkkT0JkjeJxgguQ9JGEwooJf9xZNkVRU,3058
165
164
  ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
166
165
  ai_edge_torch/lowertools/odml_torch_utils.py,sha256=Smt7p62-lZ_3bBBfnbssAK5GAGxm3U_X7M-1qwsmc68,8161
167
- ai_edge_torch/lowertools/test_utils.py,sha256=bPgc2iXX16KYtMNvmsRdKfrCY6UJmcfitfCOvHoD7Oc,1930
166
+ ai_edge_torch/lowertools/test_utils.py,sha256=mdxTlhqHABZEQ_GEmPFCL8LIAWtqRtYZUGdSY1ieZjw,1949
168
167
  ai_edge_torch/lowertools/torch_xla_utils.py,sha256=XGZE0vZG9WSQT-6dFmPlU8W89z8rfXPRGjuZeuhXCIw,9205
169
168
  ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1K5vHK0RjgBFtxlE,5652
170
169
  ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
@@ -183,12 +182,12 @@ ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkg
183
182
  ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
184
183
  ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
185
184
  ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=0GytV1dGnqe1mKityqQDNFNS8T4QBg3UZuRJcGHwGyA,993
186
- ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=ufvnaAh6rM_yfoc8ybI3VErHEVBv5W_p4iOe9slfwKM,9948
185
+ ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=8mZTp_ybcMO3tDRQdlDP68BVeTw560XsTR4XH-ldTdc,9987
187
186
  ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
188
- ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=v1VdKmL8YLJv3PR9VgyNghO83A25PpTzY2ZUAJqlq3Q,6847
187
+ ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
189
188
  ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=OVmlPGwyhDXKhmG4SAeEsa6iLpJHEHV_jKqwfjYvetA,11643
190
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=1ePJs7oIdUkVdMddFsXMc53qTkEKqGz0ZhQQoNzBa10,2862
191
- ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=rFmzqcdjYrwhcxH8j9zCFStPy21HFF7hkUV_GQ8FPAk,6056
189
+ ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
190
+ ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=GEs83mtEjh8GOW_OATI_ur11VKujrOL2xdZeZ0l1HtM,6100
192
191
  ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
193
192
  ai_edge_torch/odml_torch/lowerings/decomp.py,sha256=UoJeZVcr4zAN_11i-HzfOhxGCxUm-7b1JXPVBxR2hSs,2414
194
193
  ai_edge_torch/odml_torch/lowerings/registry.py,sha256=Tp2h11l5uTifO0aIkuUOWAF_ibEjmd65Xx99w3EXuGE,1924
@@ -201,8 +200,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
201
200
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
202
201
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
203
202
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
204
- ai_edge_torch_nightly-0.3.0.dev20241213.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
205
- ai_edge_torch_nightly-0.3.0.dev20241213.dist-info/METADATA,sha256=Yzw2YkrbFAe1EbxfoKBDME5NQe0GIzzseUOVylqFpnM,1897
206
- ai_edge_torch_nightly-0.3.0.dev20241213.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
207
- ai_edge_torch_nightly-0.3.0.dev20241213.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
208
- ai_edge_torch_nightly-0.3.0.dev20241213.dist-info/RECORD,,
203
+ ai_edge_torch_nightly-0.3.0.dev20241214.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
204
+ ai_edge_torch_nightly-0.3.0.dev20241214.dist-info/METADATA,sha256=fUbq26zB0WUU1l6eUud8vq3Nm3KSIhox74pzFSFTmoM,1966
205
+ ai_edge_torch_nightly-0.3.0.dev20241214.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
206
+ ai_edge_torch_nightly-0.3.0.dev20241214.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
207
+ ai_edge_torch_nightly-0.3.0.dev20241214.dist-info/RECORD,,
ai_edge_torch/config.py DELETED
@@ -1,27 +0,0 @@
1
- # Copyright 2024 The AI Edge Torch Authors.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
-
16
- """Provides a configuration for the AI Edge Torch library."""
17
-
18
- import dataclasses
19
- import os
20
-
21
-
22
- @dataclasses.dataclass
23
- class Config:
24
- use_torch_xla: bool = os.environ.get("USE_TORCH_XLA", "true").lower() in (
25
- "1",
26
- "true",
27
- )
@@ -1,283 +0,0 @@
1
- # Copyright 2024 The AI Edge Torch Authors.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================
15
- """Tests for StableHLOCompositeBuilder."""
16
-
17
- import math
18
-
19
- from ai_edge_torch import config
20
- from ai_edge_torch import lowertools
21
- from ai_edge_torch.hlfb import StableHLOCompositeBuilder
22
- import torch
23
- import torch.nn.functional as F
24
-
25
- from absl.testing import absltest as googletest
26
-
27
-
28
- def _export_stablehlo_mlir(model, args):
29
- ep = torch.export.export(model, args)
30
- return lowertools.exported_program_to_mlir_text(ep)
31
-
32
-
33
- @googletest.skipIf(
34
- not config.Config.use_torch_xla,
35
- reason="The odml_torch counter part is in odml_torch.",
36
- )
37
- class TestStableHLOCompositeBuilder(googletest.TestCase):
38
-
39
- def test_build_composite(self):
40
- class SampleModel(torch.nn.Module):
41
-
42
- def forward(self, x):
43
- builder = StableHLOCompositeBuilder(name="test.plus_two")
44
- y = x + 1
45
- y = builder.mark_inputs(y)
46
- z = y + 2
47
- z = builder.mark_outputs(z)
48
- return z
49
-
50
- mlir = _export_stablehlo_mlir(SampleModel().eval(), (torch.rand((2, 2)),))
51
- self.assertEqual(mlir.count('stablehlo.composite "test.plus_two"'), 1)
52
-
53
- def test_build_multiple_composites(self):
54
- class SampleModel(torch.nn.Module):
55
-
56
- def plus_one(self, x: torch.Tensor):
57
- builder = StableHLOCompositeBuilder("test.plus_one")
58
- x = builder.mark_inputs(x)
59
- y = x + 1
60
- y = builder.mark_outputs(y)
61
- return y
62
-
63
- def plus_two(self, x: torch.Tensor):
64
- builder = StableHLOCompositeBuilder("test.plus_two")
65
- x = builder.mark_inputs(x)
66
- y = x + 2
67
- y = builder.mark_outputs(y)
68
- return y
69
-
70
- def forward(self, x):
71
- x = self.plus_two(x)
72
- x = x + 3
73
- x = self.plus_one(x)
74
- x = x + 4
75
- x = self.plus_two(x)
76
- return x
77
-
78
- mlir = _export_stablehlo_mlir(SampleModel().eval(), (torch.rand((2, 2)),))
79
- self.assertEqual(mlir.count('stablehlo.composite "test.plus_one"'), 1)
80
- self.assertEqual(mlir.count('stablehlo.composite "test.plus_two"'), 2)
81
-
82
- def test_build_composite_with_attr(self):
83
- class SampleModel(torch.nn.Module):
84
-
85
- def __init__(self):
86
- super().__init__()
87
-
88
- def log_softmax(self, x: torch.Tensor, dim: int):
89
- builder = StableHLOCompositeBuilder(
90
- name="test.log_softmax", attr={"dim": dim}
91
- )
92
- x = builder.mark_inputs(x)
93
- y = torch.nn.functional.log_softmax(x, dim=dim)
94
- y = builder.mark_outputs(y)
95
- return y
96
-
97
- def forward(self, x):
98
- x = x + 1
99
- x = self.log_softmax(x, 0)
100
- x = self.log_softmax(x, 1)
101
- return x
102
-
103
- mlir = _export_stablehlo_mlir(SampleModel().eval(), (torch.rand((2, 2)),))
104
- self.assertEqual(mlir.count('stablehlo.composite "test.log_softmax"'), 2)
105
- self.assertEqual(mlir.count("composite_attributes = {dim = 0 : i64}"), 1)
106
- self.assertEqual(mlir.count("composite_attributes = {dim = 1 : i64}"), 1)
107
-
108
- def test_build_composite_with_mix_type_attrs(self):
109
- class SampleModel(torch.nn.Module):
110
-
111
- def __init__(self):
112
- super().__init__()
113
-
114
- def log_softmax(self, x: torch.Tensor, dim: int):
115
- builder = StableHLOCompositeBuilder(
116
- name="test.log_softmax",
117
- attr={
118
- "dim": dim,
119
- "source": "torch.nn",
120
- "version": 1.0,
121
- },
122
- )
123
- x = builder.mark_inputs(x)
124
- y = torch.nn.functional.log_softmax(x, dim=dim)
125
- y = builder.mark_outputs(y)
126
- return y
127
-
128
- def forward(self, x):
129
- x = x + 1
130
- x = self.log_softmax(x, 0)
131
- return x
132
-
133
- mlir = _export_stablehlo_mlir(SampleModel().eval(), (torch.rand((2, 2)),))
134
- self.assertEqual(mlir.count('stablehlo.composite "test.log_softmax"'), 1)
135
- self.assertEqual(
136
- mlir.count(
137
- 'composite_attributes = {dim = 0 : i64, source = "torch.nn",'
138
- " version = 1.000000e+00 : f32}"
139
- ),
140
- 1,
141
- )
142
-
143
- def test_sdpa_composite(self):
144
- class SDPAModel(torch.nn.Module):
145
-
146
- def scaled_dot_product_attention(
147
- self,
148
- q: torch.Tensor,
149
- k: torch.Tensor,
150
- v: torch.Tensor,
151
- head_size: int,
152
- mask: torch.Tensor,
153
- ):
154
- builder = StableHLOCompositeBuilder("test.scaled_dot_product_attention")
155
- q, k, v, mask = builder.mark_inputs(q, k, v, mask)
156
-
157
- scale = 1.0 / math.sqrt(head_size)
158
-
159
- q = q.transpose(1, 2)
160
- k = k.transpose(1, 2)
161
- v = v.transpose(1, 2)
162
- y = F.scaled_dot_product_attention(
163
- q,
164
- k,
165
- v,
166
- attn_mask=mask,
167
- dropout_p=0.0,
168
- is_causal=mask is None,
169
- scale=scale,
170
- )
171
- result = y.transpose(1, 2)
172
- result = builder.mark_outputs(result)
173
- return result
174
-
175
- def forward(self, q, k, v, mask):
176
- x = self.scaled_dot_product_attention(
177
- q,
178
- k,
179
- v,
180
- 8,
181
- mask,
182
- )
183
- return x
184
-
185
- query = torch.rand(1, 1, 32, 4)
186
- key = torch.rand(1, 500, 1, 4)
187
- value = torch.rand(1, 500, 1, 4)
188
- mask = torch.rand(1, 1, 1, 500)
189
-
190
- mlir = _export_stablehlo_mlir(
191
- SDPAModel().eval(),
192
- (query, key, value, mask),
193
- )
194
- self.assertEqual(
195
- mlir.count('stablehlo.composite "test.scaled_dot_product_attention"'), 1
196
- )
197
-
198
- def test_sdpa_composite_with_attr(self):
199
- class SDPAModel(torch.nn.Module):
200
-
201
- def scaled_dot_product_attention(
202
- self,
203
- q: torch.Tensor,
204
- k: torch.Tensor,
205
- v: torch.Tensor,
206
- head_size: int,
207
- include_captanh: bool,
208
- ):
209
- builder = StableHLOCompositeBuilder(
210
- name="test.scaled_dot_product_attention",
211
- attr={"include_captanh": include_captanh},
212
- )
213
- q, k, v = builder.mark_inputs(q, k, v)
214
-
215
- scale = 1.0 / math.sqrt(head_size)
216
-
217
- q = q.transpose(1, 2)
218
- k = k.transpose(1, 2)
219
- v = v.transpose(1, 2)
220
- y = F.scaled_dot_product_attention(
221
- q,
222
- k,
223
- v,
224
- attn_mask=None,
225
- dropout_p=0.0,
226
- is_causal=True,
227
- scale=scale,
228
- )
229
- result = y.transpose(1, 2)
230
- result = builder.mark_outputs(result)
231
- return result
232
-
233
- def forward(self, q, k, v):
234
- x = self.scaled_dot_product_attention(q, k, v, 8, True)
235
- y = self.scaled_dot_product_attention(q, k, v, 8, False)
236
- return x + y
237
-
238
- query = torch.rand(1, 1, 32, 4)
239
- key = torch.rand(1, 500, 1, 4)
240
- value = torch.rand(1, 500, 1, 4)
241
- mlir = _export_stablehlo_mlir(
242
- SDPAModel().eval(),
243
- (query, key, value),
244
- )
245
- self.assertEqual(
246
- mlir.count('stablehlo.composite "test.scaled_dot_product_attention"'), 2
247
- )
248
- self.assertEqual(
249
- mlir.count("composite_attributes = {include_captanh = true}"), 1
250
- )
251
- self.assertEqual(
252
- mlir.count("composite_attributes = {include_captanh = false}"), 1
253
- )
254
-
255
- def test_build_composite_with_multiple_inputs_outputs(self):
256
- class SampleModel(torch.nn.Module):
257
-
258
- def mimo_sample(self, a, b, c):
259
- builder = StableHLOCompositeBuilder(name="test.mimo_sample")
260
-
261
- a, b, c = builder.mark_inputs(a, b, c)
262
- x = a + b + c
263
- y = (a - b) * x
264
- z = (c + 1.0) * a
265
- x, y, z = builder.mark_outputs(x, y, z)
266
-
267
- result = x + y * z
268
- return result
269
-
270
- def forward(self, a, b, c):
271
- x = self.mimo_sample(a, b, c)
272
- x = self.mimo_sample(a, b, x)
273
- x = self.mimo_sample(x, x, c)
274
- return x
275
-
276
- mlir = _export_stablehlo_mlir(
277
- SampleModel().eval(), (torch.rand(2), torch.rand(2), torch.rand(2))
278
- )
279
- self.assertEqual(mlir.count('stablehlo.composite "test.mimo_sample"'), 3)
280
-
281
-
282
- if __name__ == "__main__":
283
- googletest.main()