ai-edge-torch-nightly 0.3.0.dev20240827__py3-none-any.whl → 0.3.0.dev20240829__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (46)
  1. ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +6 -1
  2. ai_edge_torch/_convert/test/test_convert.py +1 -1
  3. ai_edge_torch/_convert/test/test_convert_composites.py +1 -1
  4. ai_edge_torch/_convert/test/test_convert_multisig.py +71 -31
  5. ai_edge_torch/_convert/test/test_to_channel_last_io.py +1 -1
  6. ai_edge_torch/debug/test/test_culprit.py +1 -1
  7. ai_edge_torch/debug/test/test_search_model.py +1 -1
  8. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +43 -59
  9. ai_edge_torch/generative/test/test_experimental_ekv.py +1 -1
  10. ai_edge_torch/generative/test/test_loader.py +1 -1
  11. ai_edge_torch/generative/test/test_model_conversion.py +1 -1
  12. ai_edge_torch/generative/test/test_quantize.py +1 -1
  13. ai_edge_torch/hlfb/test/test_mark_pattern.py +1 -1
  14. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +1 -1
  15. ai_edge_torch/lowertools/odml_torch_utils.py +5 -1
  16. ai_edge_torch/lowertools/test_utils.py +1 -1
  17. ai_edge_torch/odml_torch/__init__.py +20 -0
  18. ai_edge_torch/odml_torch/_torch_future.py +61 -0
  19. ai_edge_torch/odml_torch/_torch_library.py +19 -0
  20. ai_edge_torch/odml_torch/composite/__init__.py +16 -0
  21. ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
  22. ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
  23. ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
  24. ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
  25. ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
  26. ai_edge_torch/odml_torch/export.py +320 -0
  27. ai_edge_torch/odml_torch/export_utils.py +168 -0
  28. ai_edge_torch/odml_torch/jax_bridge/__init__.py +15 -0
  29. ai_edge_torch/odml_torch/jax_bridge/_wrap.py +152 -0
  30. ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
  31. ai_edge_torch/odml_torch/lowerings/__init__.py +24 -0
  32. ai_edge_torch/odml_torch/lowerings/_basic.py +204 -0
  33. ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
  34. ai_edge_torch/odml_torch/lowerings/_convolution.py +119 -0
  35. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +255 -0
  36. ai_edge_torch/odml_torch/lowerings/context.py +42 -0
  37. ai_edge_torch/odml_torch/lowerings/registry.py +87 -0
  38. ai_edge_torch/odml_torch/lowerings/utils.py +185 -0
  39. ai_edge_torch/odml_torch/passes/__init__.py +38 -0
  40. ai_edge_torch/odml_torch/tf_integration.py +194 -0
  41. ai_edge_torch/version.py +1 -1
  42. {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/METADATA +1 -1
  43. {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/RECORD +46 -22
  44. {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/LICENSE +0 -0
  45. {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/WHEEL +0 -0
  46. {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/top_level.txt +0 -0
@@ -331,7 +331,12 @@ def _aten__native_batch_norm_legit_no_training(node):
   def batch_norm(input, weight, bias, running_mean, running_var, momentum, eps):
     a = input - running_mean
     b = torch.sqrt(running_var + eps)
-    return a / b * weight + bias, None, None
+    out = a / b
+    if weight is not None:
+      out = out * weight
+    if bias is not None:
+      out = out + bias
+    return out, None, None

   node.target = batch_norm

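Why the lowering now guards `weight` and `bias`: batch norm nodes produced from modules like `nn.BatchNorm2d(..., affine=False)` carry `None` for both, which the old one-liner `a / b * weight + bias` would crash on. A small standalone sanity sketch of the guarded math (illustrative only, not the pass's code; the pass applies the same formula in its own layout):

import torch
from torch import nn

# affine=False => no learnable weight/bias; the batch norm op sees None for both.
bn = nn.BatchNorm2d(3, affine=False).eval()
x = torch.randn(1, 3, 8, 8)

# The guarded lowering reduces to plain normalization when weight/bias are None.
out = (x - bn.running_mean.view(1, -1, 1, 1)) / torch.sqrt(
    bn.running_var.view(1, -1, 1, 1) + bn.eps
)
assert torch.allclose(out, bn(x), atol=1e-5)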
@@ -27,7 +27,7 @@ import tensorflow as tf
 import torch
 import torchvision

-from tensorflow.python.platform import googletest
+from absl.testing import absltest as googletest


 @dataclasses.dataclass
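The `tensorflow.python.platform.googletest` → `absl.testing.absltest` swap, repeated across the test files below, drops the TensorFlow dependency from the test runners. The alias works because `absltest` is API-compatible for these usages: both expose `TestCase` and `main()`. A minimal sketch of the pattern the tests rely on:

from absl.testing import absltest as googletest


class ExampleTest(googletest.TestCase):

  def test_addition(self):
    self.assertEqual(1 + 1, 2)


if __name__ == "__main__":
  googletest.main()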
@@ -21,7 +21,7 @@ from ai_edge_torch.testing import model_coverage
 import parameterized
 import torch

-from tensorflow.python.platform import googletest
+from absl.testing import absltest as googletest


 def _func_to_torch_module(func: Callable[..., torch.Tensor]):
@@ -17,9 +17,45 @@
 import ai_edge_torch
 from ai_edge_torch.testing import model_coverage
 import torch
-import torchvision
+from torch import nn

-from tensorflow.python.platform import googletest
+from absl.testing import absltest as googletest
+
+
+class FullyConnectedModel(nn.Module):
+  """A simple fully connected model with two fully connected layers."""
+
+  def __init__(self, input_size, hidden_size, output_size):
+    super(FullyConnectedModel, self).__init__()
+    self.fc = nn.Linear(input_size, hidden_size)  # Fully connected layer
+    self.relu = nn.ReLU()  # Activation function
+    self.output = nn.Linear(hidden_size, output_size)
+
+  def forward(self, x):
+    x = self.fc(x)
+    x = self.relu(x)
+    x = self.output(x)
+    return x
+
+
+class FullyConvModel(nn.Module):
+  """A simple fully convolutional model with two convolutions."""
+
+  def __init__(self):
+    super(FullyConvModel, self).__init__()
+    self.conv1 = nn.Conv2d(
+        3, 16, kernel_size=3, padding=1
+    )  # Input channels: 3 (RGB), Output channels: 16
+    self.relu = nn.ReLU(inplace=True)
+    self.conv2 = nn.Conv2d(
+        16, 1, kernel_size=1
+    )  # Output channels: 1 (single channel output)
+
+  def forward(self, x):
+    x = self.conv1(x)
+    x = self.relu(x)
+    x = self.conv2(x)
+    return x


 class TestConvertMultiSignature(googletest.TestCase):
@@ -29,12 +65,12 @@ class TestConvertMultiSignature(googletest.TestCase):
     super().setUp()
     torch.manual_seed(0)

-  def test_convert_mobilenet_v2_with_default(self):
+  def test_convert_with_default(self):
     """Tests conversion of a model with two signatures one of which is the default."""
-    torch_module = torchvision.models.mobilenet_v2().eval()
+    torch_module = FullyConvModel().eval()

-    args = (torch.randn(4, 3, 224, 224),)
-    large_args = (torch.randn(4, 3, 336, 336),)
+    args = (torch.randn(4, 3, 12, 12),)
+    large_args = (torch.randn(4, 3, 24, 24),)

     signature_name = "large_input"

@@ -51,12 +87,12 @@ class TestConvertMultiSignature(googletest.TestCase):
         )
     )

-  def test_convert_mobilenet_v2_no_default(self):
+  def test_convert_no_default(self):
     """Tests conversion of a model with two signatures none of which is the default."""
-    torch_module = torchvision.models.mobilenet_v2().eval()
+    torch_module = FullyConvModel().eval()

-    args = (torch.randn(4, 3, 224, 224),)
-    large_args = (torch.randn(4, 3, 336, 336),)
+    args = (torch.randn(4, 3, 12, 12),)
+    large_args = (torch.randn(4, 3, 24, 24),)

     signature_name_1 = "input"
     signature_name_2 = "large_input"
@@ -84,12 +120,12 @@ class TestConvertMultiSignature(googletest.TestCase):
         )
     )

-  def test_convert_mobilenet_v2_signature_helper(self):
+  def test_convert_signature_helper(self):
     """Tests the ai_edge_torch.signature helper function works."""
-    torch_module = torchvision.models.mobilenet_v2().eval()
+    torch_module = FullyConvModel().eval()

-    args = (torch.randn(4, 3, 224, 224),)
-    large_args = (torch.randn(4, 3, 336, 336),)
+    args = (torch.randn(4, 3, 12, 12),)
+    large_args = (torch.randn(4, 3, 24, 24),)

     signature_name = "large_input"

@@ -108,39 +144,43 @@ class TestConvertMultiSignature(googletest.TestCase):

   def test_convert_separate_modules(self):
     """Tests conversion of two completely different modules as separate signatures."""
-    mobilentv2 = torchvision.models.mobilenet_v2().eval()
-    resnet18 = torchvision.models.resnet18().eval()
+    fully_conv = FullyConvModel().eval()
+    fully_connected = FullyConnectedModel(10, 5, 10).eval()

-    mobilenet_args = (torch.randn(4, 3, 224, 224),)
-    resnet_args = (torch.randn(4, 3, 224, 224),)
+    fully_conv_args = (torch.randn(4, 3, 12, 12),)
+    fully_connected_args = (torch.randn(10),)

-    mobilenet_signature_name = "mobilentv2"
-    resnet_signature_name = "resnet18"
+    fully_conv_signature_name = "fully_conv"
+    fully_connected_signature_name = "fully_connected"

     edge_model = (
         ai_edge_torch.signature(
-            mobilenet_signature_name, mobilentv2, mobilenet_args
+            fully_conv_signature_name, fully_conv, fully_conv_args
+        )
+        .signature(
+            fully_connected_signature_name,
+            fully_connected,
+            fully_connected_args,
         )
-        .signature(resnet_signature_name, resnet18, resnet_args)
-        .convert(resnet18, resnet_args)
+        .convert(fully_connected, fully_connected_args)
     )

-    mobilenet_inference_args = (torch.randn(4, 3, 224, 224),)
-    resnet_inference_args = (torch.randn(4, 3, 224, 224),)
+    fully_conv_inference_args = (torch.randn(4, 3, 12, 12),)
+    fully_connected_inference_args = (torch.randn(10),)
     self.assertTrue(
         model_coverage.compare_tflite_torch(
             edge_model,
-            mobilentv2,
-            mobilenet_inference_args,
-            signature_name=mobilenet_signature_name,
+            fully_conv,
+            fully_conv_inference_args,
+            signature_name=fully_conv_signature_name,
         )
     )
     self.assertTrue(
         model_coverage.compare_tflite_torch(
             edge_model,
-            resnet18,
-            resnet_inference_args,
-            signature_name=resnet_signature_name,
+            fully_connected,
+            fully_connected_inference_args,
+            signature_name=fully_connected_signature_name,
         )
     )

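For reference, the call side of a multi-signature edge model looks like the sketch below; `signature_name` on `__call__` selects among the converted signatures (the module and shapes here are illustrative, the chaining API is the one exercised by the tests above):

import ai_edge_torch
import torch
from torch import nn

model = nn.Conv2d(3, 1, kernel_size=1).eval()
args = (torch.randn(4, 3, 12, 12),)
large_args = (torch.randn(4, 3, 24, 24),)

# "large_input" is a named signature; convert() adds the default one.
edge_model = ai_edge_torch.signature("large_input", model, large_args).convert(
    model, args
)

out = edge_model(*args)  # default signature
large_out = edge_model(*large_args, signature_name="large_input")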
@@ -17,7 +17,7 @@
 import ai_edge_torch
 import torch

-from tensorflow.python.platform import googletest
+from absl.testing import absltest as googletest


 class Identity(torch.nn.Module):
@@ -21,7 +21,7 @@ import sys
 from ai_edge_torch.debug import find_culprits
 import torch

-from tensorflow.python.platform import googletest
+from absl.testing import absltest as googletest

 _test_culprit_lib = torch.library.Library("test_culprit", "DEF")

@@ -17,7 +17,7 @@
 from ai_edge_torch.debug import _search_model
 import torch

-from tensorflow.python.platform import googletest
+from absl.testing import absltest as googletest


 class TestSearchModel(googletest.TestCase):
@@ -15,16 +15,16 @@

 import argparse
 import os
-from pathlib import Path
-from typing import Dict, Optional
+import pathlib
+from typing import Optional

-import ai_edge_torch.generative.examples.stable_diffusion.samplers as samplers
-from ai_edge_torch.generative.examples.stable_diffusion.tokenizer import Tokenizer  # NOQA
-import ai_edge_torch.generative.examples.stable_diffusion.util as util
-from ai_edge_torch.model import TfLiteModel
+import ai_edge_torch
+from ai_edge_torch.generative.examples.stable_diffusion import samplers
+from ai_edge_torch.generative.examples.stable_diffusion import tokenizer
+from ai_edge_torch.generative.examples.stable_diffusion import util
 import numpy as np
 from PIL import Image
-from tqdm import tqdm
+import tqdm

 arg_parser = argparse.ArgumentParser()
 arg_parser.add_argument(
@@ -104,12 +104,12 @@ class StableDiffusion:
       diffusion_ckpt: str,
       decoder_ckpt: str
   ):
-    self.tokenizer = Tokenizer(tokenizer_vocab_dir)
-    self.clip = TfLiteModel.load(clip_ckpt)
-    self.decoder = TfLiteModel.load(decoder_ckpt)
-    self.diffusion = TfLiteModel.load(diffusion_ckpt)
+    self.tokenizer = tokenizer.Tokenizer(tokenizer_vocab_dir)
+    self.clip = ai_edge_torch.model.TfLiteModel.load(clip_ckpt)
+    self.decoder = ai_edge_torch.model.TfLiteModel.load(decoder_ckpt)
+    self.diffusion = ai_edge_torch.model.TfLiteModel.load(diffusion_ckpt)
     if encoder_ckpt is not None:
-      self.encoder = TfLiteModel.load(encoder_ckpt)
+      self.encoder = ai_edge_torch.model.TfLiteModel.load(encoder_ckpt)


 def run_tflite_pipeline(
@@ -128,48 +128,32 @@ def run_tflite_pipeline(
 ):
   """Run stable diffusion pipeline with tflite model.

-  model:
-
-  StableDiffsuion model.
-  prompt:
-  The prompt to guide the image generation.
-  output_path:
-  The path to the generated output image.
-  uncond_prompt:
-  The prompt not to guide the image generation.
-  cfg_scale:
-  Guidance scale of classifier-free guidance. Higher guidance scale encourages
-  to generate
-  images that are closely linked to the text `prompt`, usually at the expense
-  of lower
-  image quality.
-  height:
-  The height in pixels of the generated image.
-  width:
-  The width in pixels of the generated image.
-  sampler:
-  A sampler to be used to denoise the encoded image latents. Can be one of
-  `k_lms, `k_euler`,
-  or `k_euler_ancestral`.
-  n_inference_steps:
-  The number of denoising steps. More denoising steps usually lead to a higher
-  quality image at the
-  expense of slower inference. This parameter will be modulated by `strength`.
-  seed:
-  A seed to make generation deterministic.
-  strength:
-  Conceptually, indicates how much to transform the reference `input_image`.
-  Must be between 0 and 1.
-  `input_image` will be used as a starting point, adding more noise to it the
-  larger the `strength`.
-  The number of denoising steps depends on the amount of noise initially
-  added. When `strength` is 1,
-  added noise will be maximum and the denoising process will run for the full
-  number of iterations
-  specified in `n_inference_steps`. A value of 1, therefore, essentially
-  ignores `input_image`.
-  input_image:
-  Image which is served as the starting point for the image generation.
+  Args:
+    model: StableDiffsuion model.
+    prompt: The prompt to guide the image generation.
+    output_path: The path to the generated output image.
+    uncond_prompt: The prompt not to guide the image generation.
+    cfg_scale: Guidance scale of classifier-free guidance. Higher guidance scale
+      encourages to generate images that are closely linked to the text
+      `prompt`, usually at the expense of lower image quality.
+    height: The height in pixels of the generated image.
+    width: The width in pixels of the generated image.
+    sampler: A sampler to be used to denoise the encoded image latents. Can be
+      one of `k_lms, `k_euler`, or `k_euler_ancestral`.
+    n_inference_steps: The number of denoising steps. More denoising steps
+      usually lead to a higher quality image at the expense of slower inference.
+      This parameter will be modulated by `strength`.
+    seed: A seed to make generation deterministic.
+    strength: Conceptually, indicates how much to transform the reference
+      `input_image`. Must be between 0 and 1. `input_image` will be used as a
+      starting point, adding more noise to it the larger the `strength`. The
+      number of denoising steps depends on the amount of noise initially added.
+      When `strength` is 1, added noise will be maximum and the denoising
+      process will run for the full number of iterations specified in
+      `n_inference_steps`. A value of 1, therefore, essentially ignores
+      `input_image`.
+    input_image: Image which is served as the starting point for the image
+      generation.
   """
   if not 0 < strength < 1:
     raise ValueError('strength must be between 0 and 1')
@@ -202,7 +186,8 @@ def run_tflite_pipeline(
   context = np.concatenate([cond_context, uncond_context], axis=0)
   noise_shape = (1, 4, height // 8, width // 8)

-  # Initialization starts from input_image if any, otherwise, starts from a random sampling.
+  # Initialization starts from input_image if any, otherwise, starts from a
+  # random sampling.
   if input_image:
     if not hasattr(model, 'encoder'):
       raise AttributeError(
@@ -210,7 +195,6 @@ def run_tflite_pipeline(
         ' input_image.'
     )
     input_image = input_image.resize((width, height))
-    input_image_np = np.array(input_image).astype(np.float32)
     input_image_np = util.rescale(input_image, (0, 255), (-1, 1))
     input_image_np = util.move_channel(input_image_np, to='first')
     encoder_noise = np.random.normal(size=noise_shape).astype(np.float32)
@@ -223,8 +207,8 @@ def run_tflite_pipeline(
   latents *= sampler.initial_scale

   # Diffusion process.
-  timesteps = tqdm(sampler.timesteps)
-  for i, timestep in enumerate(timesteps):
+  timesteps = tqdm.tqdm(sampler.timesteps)
+  for _, timestep in enumerate(timesteps):
     time_embedding = util.get_time_embedding(timestep)

     input_latents = latents * sampler.get_input_scale()
@@ -242,7 +226,7 @@ def run_tflite_pipeline(
   images = util.rescale(images, (-1, 1), (0, 255), clamp=True)
   images = util.move_channel(images, to='last')
   if not os.path.exists(output_path):
-    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
+    pathlib.Path(output_path).parent.mkdir(parents=True, exist_ok=True)
   Image.fromarray(images[0].astype(np.uint8)).save(output_path)

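For orientation, a hedged sketch of how these pieces compose end to end. The checkpoint paths are placeholders (the real ones come from the `arg_parser` flags above), and the keyword arguments are taken from the constructor and docstring shown in this diff:

from ai_edge_torch.generative.examples.stable_diffusion import pipeline

# Hypothetical checkpoint locations, for illustration only.
model = pipeline.StableDiffusion(
    tokenizer_vocab_dir='/tmp/sd/tokenizer',
    clip_ckpt='/tmp/sd/clip.tflite',
    encoder_ckpt=None,  # only needed when passing an input_image
    diffusion_ckpt='/tmp/sd/diffusion.tflite',
    decoder_ckpt='/tmp/sd/decoder.tflite',
)

pipeline.run_tflite_pipeline(
    model,
    prompt='a photo of a cat',
    output_path='/tmp/sd/out.jpg',
    n_inference_steps=20,
    seed=0,
)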
@@ -21,7 +21,7 @@ from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
 import ai_edge_torch.generative.layers.model_config as cfg
 import torch

-from tensorflow.python.platform import googletest
+from absl.testing import absltest as googletest


 class TestExternalKVLayers(googletest.TestCase):
@@ -22,7 +22,7 @@ from ai_edge_torch.generative.utilities import loader as loading_utils
 import safetensors.torch
 import torch

-from tensorflow.python.platform import googletest
+from absl.testing import absltest as googletest


 class TestLoader(googletest.TestCase):
@@ -24,7 +24,7 @@ from ai_edge_torch.testing import model_coverage
 import numpy as np
 import torch

-from tensorflow.python.platform import googletest
+from absl.testing import absltest as googletest


 class TestModelConversion(googletest.TestCase):
@@ -28,7 +28,7 @@ from ai_edge_torch.testing import model_coverage
 from parameterized import parameterized
 import torch

-from tensorflow.python.platform import googletest
+from absl.testing import absltest as googletest


 class TestVerifyRecipes(googletest.TestCase):
@@ -19,7 +19,7 @@ from ai_edge_torch.hlfb import mark_pattern
 from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
 import torch

-from tensorflow.python.platform import googletest
+from absl.testing import absltest as googletest


 def _export_stablehlo_mlir(model, args=None):
@@ -22,7 +22,7 @@ from ai_edge_torch.hlfb import StableHLOCompositeBuilder
 import torch
 import torch.nn.functional as F

-from tensorflow.python.platform import googletest
+from absl.testing import absltest as googletest


 def _export_stablehlo_mlir(model, args):
@@ -84,13 +84,17 @@ def _wrap_as_tf_func(
   t_outs = [torch_dtype_to_tf(sig.dtype) for sig in bundle.output_signature]
   s_outs = [_get_shape_with_dynamic(sig) for sig in bundle.output_signature]
   call_args = _extract_call_args(bundle, args, tf_state_dict)
+  # HACK: In OSS, we use MLIR pybinding and StableHLO dialect from JAX's
+  # build, which may not have the same StableHLO version as what used in
+  # TFLite converter. Therefore we always serialize MLIR module in VHLO.
+  # TODO(b/362798610) Build MLIR pybinding in ai-edge-torch release.
   call_module_return = tfxla.call_module(
       tuple(call_args),
       version=5,
       Tout=t_outs,  # dtype information
       Sout=s_outs,  # Shape information
       function_list=[],
-      module=bundle.module_bytecode,
+      module=bundle.module_bytecode_vhlo,
   )
   spec = exported_program.call_spec.out_spec

@@ -16,7 +16,7 @@
 import re
 from typing import Optional
 from ai_edge_torch import config
-from tensorflow.python.platform import googletest
+from absl.testing import absltest as googletest


 def _extract_backend_configs(mlir):
@@ -0,0 +1,20 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from . import composite
+from . import debuginfo
+from . import export
+from . import export_utils
+from . import lowerings
+from . import passes
@@ -0,0 +1,61 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Wrappers for latest torch APIs/utilities to maintain backward compatibility with older torch releases."""
+
+import torch
+from torch.fx import _pytree as fx_pytree
+
+
+def graph_module_flat_inputs(ep: torch.export.ExportedProgram, args, kwargs):
+  """Transform args, kwargs of __call__ to args for graph_module.
+
+  self.graph_module takes stuff from state dict as inputs.
+  The invariant is for ep: ExportedProgram is
+  ep(args, kwargs) ==
+    ep.postprocess(ep.graph_module(ep.graph_module_flat_inputs(args, kwargs)))
+  """
+  if hasattr(ep, "_graph_module_flat_inputs"):
+    return ep._graph_module_flat_inputs(args, kwargs)
+
+  if args is None:
+    args = tuple()
+  if kwargs is None:
+    kwargs = {}
+
+  flat_args = args
+  if (in_spec := ep.call_spec.in_spec) is not None:
+    if (
+        in_spec.type == tuple
+        and len(in_spec.children_specs) == 2
+        and in_spec.children_specs[0].type == tuple
+        and in_spec.children_specs[1].type == dict
+    ):
+      # NOTE: this is the case where in_spec is for both args and kwargs
+      flat_args = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
+    else:
+      flat_args = fx_pytree.tree_flatten_spec(args, in_spec)
+
+  param_buffer_keys = ep.graph_signature.parameters + ep.graph_signature.buffers
+  param_buffer_values = tuple(ep.state_dict[key] for key in param_buffer_keys)
+
+  if hasattr(ep.graph_signature, "lifted_tensor_constants"):
+    ordered_tensor_constants = tuple(
+        ep.tensor_constants[name]
+        for name in ep.graph_signature.lifted_tensor_constants
+    )
+  else:
+    ordered_tensor_constants = tuple()
+
+  return (*param_buffer_values, *flat_args, *ordered_tensor_constants)
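A quick orientation to what `graph_module_flat_inputs` is for, as a hedged sketch; the wrapper exists because older `torch.export` releases lack `ExportedProgram._graph_module_flat_inputs`, and the model and shapes below are illustrative:

import torch
from ai_edge_torch.odml_torch import _torch_future

model = torch.nn.Linear(4, 2).eval()
args = (torch.randn(1, 4),)
ep = torch.export.export(model, args)

# Prepend lifted params/buffers (and tensor constants) to the user args so
# the underlying fx GraphModule can be invoked directly.
flat = _torch_future.graph_module_flat_inputs(ep, args, {})
raw_outputs = ep.graph_module(*flat)  # pre-postprocessing outputs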
@@ -0,0 +1,19 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Torch library for registering ODML Torch custom ops."""
+
+import torch
+
+ODML_TORCH_LIB = torch.library.Library("odml_torch", "DEF")
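`ODML_TORCH_LIB` is a `torch.library.Library` handle for the `odml_torch` op namespace; ops such as `mark_tensor` (registered in `composite/mark_tensor.py`, not shown in this diff) attach to it. A hedged sketch of the standard `torch.library` registration pattern, using a made-up `identity` op for illustration:

import torch
from ai_edge_torch.odml_torch._torch_library import ODML_TORCH_LIB

# Hypothetical op: define the schema, then attach an implementation.
ODML_TORCH_LIB.define("identity(Tensor x) -> Tensor")


def _identity_impl(x: torch.Tensor) -> torch.Tensor:
  return x.clone()


ODML_TORCH_LIB.impl("identity", _identity_impl, "CompositeExplicitAutograd")

# The op is now callable through the torch.ops namespace.
y = torch.ops.odml_torch.identity(torch.randn(2, 2))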
@@ -0,0 +1,16 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from .mark_tensor import mark_tensor_op
+from .stablehlo_composite_builder import StableHLOCompositeBuilder
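The re-exported `StableHLOCompositeBuilder` mirrors the existing `ai_edge_torch.hlfb` builder: it wraps a region of a model in a named StableHLO composite during export. A minimal sketch of the established `hlfb` usage pattern (the composite name and model are illustrative):

import torch
from torch import nn
from ai_edge_torch.hlfb import StableHLOCompositeBuilder


class GeluBlock(nn.Module):
  """Marks its GELU as a single `odml.gelu` composite in the exported IR."""

  def __init__(self):
    super().__init__()
    self.builder = StableHLOCompositeBuilder(name="odml.gelu")

  def forward(self, x):
    x = self.builder.mark_inputs(x)  # boundary: composite inputs
    x = torch.nn.functional.gelu(x)
    x = self.builder.mark_outputs(x)  # boundary: composite outputs
    return x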