ai-edge-torch-nightly 0.3.0.dev20240827__py3-none-any.whl → 0.3.0.dev20240829__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +6 -1
- ai_edge_torch/_convert/test/test_convert.py +1 -1
- ai_edge_torch/_convert/test/test_convert_composites.py +1 -1
- ai_edge_torch/_convert/test/test_convert_multisig.py +71 -31
- ai_edge_torch/_convert/test/test_to_channel_last_io.py +1 -1
- ai_edge_torch/debug/test/test_culprit.py +1 -1
- ai_edge_torch/debug/test/test_search_model.py +1 -1
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +43 -59
- ai_edge_torch/generative/test/test_experimental_ekv.py +1 -1
- ai_edge_torch/generative/test/test_loader.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion.py +1 -1
- ai_edge_torch/generative/test/test_quantize.py +1 -1
- ai_edge_torch/hlfb/test/test_mark_pattern.py +1 -1
- ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +1 -1
- ai_edge_torch/lowertools/odml_torch_utils.py +5 -1
- ai_edge_torch/lowertools/test_utils.py +1 -1
- ai_edge_torch/odml_torch/__init__.py +20 -0
- ai_edge_torch/odml_torch/_torch_future.py +61 -0
- ai_edge_torch/odml_torch/_torch_library.py +19 -0
- ai_edge_torch/odml_torch/composite/__init__.py +16 -0
- ai_edge_torch/odml_torch/composite/mark_tensor.py +120 -0
- ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py +106 -0
- ai_edge_torch/odml_torch/debuginfo/__init__.py +16 -0
- ai_edge_torch/odml_torch/debuginfo/_build.py +43 -0
- ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py +55 -0
- ai_edge_torch/odml_torch/export.py +320 -0
- ai_edge_torch/odml_torch/export_utils.py +168 -0
- ai_edge_torch/odml_torch/jax_bridge/__init__.py +15 -0
- ai_edge_torch/odml_torch/jax_bridge/_wrap.py +152 -0
- ai_edge_torch/odml_torch/jax_bridge/utils.py +75 -0
- ai_edge_torch/odml_torch/lowerings/__init__.py +24 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +204 -0
- ai_edge_torch/odml_torch/lowerings/_batch_norm.py +65 -0
- ai_edge_torch/odml_torch/lowerings/_convolution.py +119 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +255 -0
- ai_edge_torch/odml_torch/lowerings/context.py +42 -0
- ai_edge_torch/odml_torch/lowerings/registry.py +87 -0
- ai_edge_torch/odml_torch/lowerings/utils.py +185 -0
- ai_edge_torch/odml_torch/passes/__init__.py +38 -0
- ai_edge_torch/odml_torch/tf_integration.py +194 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/RECORD +46 -22
- {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240827.dist-info → ai_edge_torch_nightly-0.3.0.dev20240829.dist-info}/top_level.txt +0 -0
|
@@ -331,7 +331,12 @@ def _aten__native_batch_norm_legit_no_training(node):
|
|
|
331
331
|
def batch_norm(input, weight, bias, running_mean, running_var, momentum, eps):
|
|
332
332
|
a = input - running_mean
|
|
333
333
|
b = torch.sqrt(running_var + eps)
|
|
334
|
-
|
|
334
|
+
out = a / b
|
|
335
|
+
if weight is not None:
|
|
336
|
+
out = out * weight
|
|
337
|
+
if bias is not None:
|
|
338
|
+
out = out + bias
|
|
339
|
+
return out, None, None
|
|
335
340
|
|
|
336
341
|
node.target = batch_norm
|
|
337
342
|
|
|
@@ -21,7 +21,7 @@ from ai_edge_torch.testing import model_coverage
|
|
|
21
21
|
import parameterized
|
|
22
22
|
import torch
|
|
23
23
|
|
|
24
|
-
from
|
|
24
|
+
from absl.testing import absltest as googletest
|
|
25
25
|
|
|
26
26
|
|
|
27
27
|
def _func_to_torch_module(func: Callable[..., torch.Tensor]):
|
|
@@ -17,9 +17,45 @@
|
|
|
17
17
|
import ai_edge_torch
|
|
18
18
|
from ai_edge_torch.testing import model_coverage
|
|
19
19
|
import torch
|
|
20
|
-
import
|
|
20
|
+
from torch import nn
|
|
21
21
|
|
|
22
|
-
from
|
|
22
|
+
from absl.testing import absltest as googletest
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class FullyConnectedModel(nn.Module):
|
|
26
|
+
"""A simple fully connected model with two fully connected layers."""
|
|
27
|
+
|
|
28
|
+
def __init__(self, input_size, hidden_size, output_size):
|
|
29
|
+
super(FullyConnectedModel, self).__init__()
|
|
30
|
+
self.fc = nn.Linear(input_size, hidden_size) # Fully connected layer
|
|
31
|
+
self.relu = nn.ReLU() # Activation function
|
|
32
|
+
self.output = nn.Linear(hidden_size, output_size)
|
|
33
|
+
|
|
34
|
+
def forward(self, x):
|
|
35
|
+
x = self.fc(x)
|
|
36
|
+
x = self.relu(x)
|
|
37
|
+
x = self.output(x)
|
|
38
|
+
return x
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class FullyConvModel(nn.Module):
|
|
42
|
+
"""A simple fully convolutional model with two convolutions."""
|
|
43
|
+
|
|
44
|
+
def __init__(self):
|
|
45
|
+
super(FullyConvModel, self).__init__()
|
|
46
|
+
self.conv1 = nn.Conv2d(
|
|
47
|
+
3, 16, kernel_size=3, padding=1
|
|
48
|
+
) # Input channels: 3 (RGB), Output channels: 16
|
|
49
|
+
self.relu = nn.ReLU(inplace=True)
|
|
50
|
+
self.conv2 = nn.Conv2d(
|
|
51
|
+
16, 1, kernel_size=1
|
|
52
|
+
) # Output channels: 1 (single channel output)
|
|
53
|
+
|
|
54
|
+
def forward(self, x):
|
|
55
|
+
x = self.conv1(x)
|
|
56
|
+
x = self.relu(x)
|
|
57
|
+
x = self.conv2(x)
|
|
58
|
+
return x
|
|
23
59
|
|
|
24
60
|
|
|
25
61
|
class TestConvertMultiSignature(googletest.TestCase):
|
|
@@ -29,12 +65,12 @@ class TestConvertMultiSignature(googletest.TestCase):
|
|
|
29
65
|
super().setUp()
|
|
30
66
|
torch.manual_seed(0)
|
|
31
67
|
|
|
32
|
-
def
|
|
68
|
+
def test_convert_with_default(self):
|
|
33
69
|
"""Tests conversion of a model with two signatures one of which is the default."""
|
|
34
|
-
torch_module =
|
|
70
|
+
torch_module = FullyConvModel().eval()
|
|
35
71
|
|
|
36
|
-
args = (torch.randn(4, 3,
|
|
37
|
-
large_args = (torch.randn(4, 3,
|
|
72
|
+
args = (torch.randn(4, 3, 12, 12),)
|
|
73
|
+
large_args = (torch.randn(4, 3, 24, 24),)
|
|
38
74
|
|
|
39
75
|
signature_name = "large_input"
|
|
40
76
|
|
|
@@ -51,12 +87,12 @@ class TestConvertMultiSignature(googletest.TestCase):
|
|
|
51
87
|
)
|
|
52
88
|
)
|
|
53
89
|
|
|
54
|
-
def
|
|
90
|
+
def test_convert_no_default(self):
|
|
55
91
|
"""Tests conversion of a model with two signatures none of which is the default."""
|
|
56
|
-
torch_module =
|
|
92
|
+
torch_module = FullyConvModel().eval()
|
|
57
93
|
|
|
58
|
-
args = (torch.randn(4, 3,
|
|
59
|
-
large_args = (torch.randn(4, 3,
|
|
94
|
+
args = (torch.randn(4, 3, 12, 12),)
|
|
95
|
+
large_args = (torch.randn(4, 3, 24, 24),)
|
|
60
96
|
|
|
61
97
|
signature_name_1 = "input"
|
|
62
98
|
signature_name_2 = "large_input"
|
|
@@ -84,12 +120,12 @@ class TestConvertMultiSignature(googletest.TestCase):
|
|
|
84
120
|
)
|
|
85
121
|
)
|
|
86
122
|
|
|
87
|
-
def
|
|
123
|
+
def test_convert_signature_helper(self):
|
|
88
124
|
"""Tests the ai_edge_torch.signature helper function works."""
|
|
89
|
-
torch_module =
|
|
125
|
+
torch_module = FullyConvModel().eval()
|
|
90
126
|
|
|
91
|
-
args = (torch.randn(4, 3,
|
|
92
|
-
large_args = (torch.randn(4, 3,
|
|
127
|
+
args = (torch.randn(4, 3, 12, 12),)
|
|
128
|
+
large_args = (torch.randn(4, 3, 24, 24),)
|
|
93
129
|
|
|
94
130
|
signature_name = "large_input"
|
|
95
131
|
|
|
@@ -108,39 +144,43 @@ class TestConvertMultiSignature(googletest.TestCase):
|
|
|
108
144
|
|
|
109
145
|
def test_convert_separate_modules(self):
|
|
110
146
|
"""Tests conversion of two completely different modules as separate signatures."""
|
|
111
|
-
|
|
112
|
-
|
|
147
|
+
fully_conv = FullyConvModel().eval()
|
|
148
|
+
fully_connected = FullyConnectedModel(10, 5, 10).eval()
|
|
113
149
|
|
|
114
|
-
|
|
115
|
-
|
|
150
|
+
fully_conv_args = (torch.randn(4, 3, 12, 12),)
|
|
151
|
+
fully_connected_args = (torch.randn(10),)
|
|
116
152
|
|
|
117
|
-
|
|
118
|
-
|
|
153
|
+
fully_conv_signature_name = "fully_conv"
|
|
154
|
+
fully_connected_signature_name = "fully_connected"
|
|
119
155
|
|
|
120
156
|
edge_model = (
|
|
121
157
|
ai_edge_torch.signature(
|
|
122
|
-
|
|
158
|
+
fully_conv_signature_name, fully_conv, fully_conv_args
|
|
159
|
+
)
|
|
160
|
+
.signature(
|
|
161
|
+
fully_connected_signature_name,
|
|
162
|
+
fully_connected,
|
|
163
|
+
fully_connected_args,
|
|
123
164
|
)
|
|
124
|
-
.
|
|
125
|
-
.convert(resnet18, resnet_args)
|
|
165
|
+
.convert(fully_connected, fully_connected_args)
|
|
126
166
|
)
|
|
127
167
|
|
|
128
|
-
|
|
129
|
-
|
|
168
|
+
fully_conv_inference_args = (torch.randn(4, 3, 12, 12),)
|
|
169
|
+
fully_connected_inference_args = (torch.randn(10),)
|
|
130
170
|
self.assertTrue(
|
|
131
171
|
model_coverage.compare_tflite_torch(
|
|
132
172
|
edge_model,
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
signature_name=
|
|
173
|
+
fully_conv,
|
|
174
|
+
fully_conv_inference_args,
|
|
175
|
+
signature_name=fully_conv_signature_name,
|
|
136
176
|
)
|
|
137
177
|
)
|
|
138
178
|
self.assertTrue(
|
|
139
179
|
model_coverage.compare_tflite_torch(
|
|
140
180
|
edge_model,
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
signature_name=
|
|
181
|
+
fully_connected,
|
|
182
|
+
fully_connected_inference_args,
|
|
183
|
+
signature_name=fully_connected_signature_name,
|
|
144
184
|
)
|
|
145
185
|
)
|
|
146
186
|
|
|
@@ -15,16 +15,16 @@
|
|
|
15
15
|
|
|
16
16
|
import argparse
|
|
17
17
|
import os
|
|
18
|
-
|
|
19
|
-
from typing import
|
|
18
|
+
import pathlib
|
|
19
|
+
from typing import Optional
|
|
20
20
|
|
|
21
|
-
import ai_edge_torch
|
|
22
|
-
from ai_edge_torch.generative.examples.stable_diffusion
|
|
23
|
-
|
|
24
|
-
from ai_edge_torch.
|
|
21
|
+
import ai_edge_torch
|
|
22
|
+
from ai_edge_torch.generative.examples.stable_diffusion import samplers
|
|
23
|
+
from ai_edge_torch.generative.examples.stable_diffusion import tokenizer
|
|
24
|
+
from ai_edge_torch.generative.examples.stable_diffusion import util
|
|
25
25
|
import numpy as np
|
|
26
26
|
from PIL import Image
|
|
27
|
-
|
|
27
|
+
import tqdm
|
|
28
28
|
|
|
29
29
|
arg_parser = argparse.ArgumentParser()
|
|
30
30
|
arg_parser.add_argument(
|
|
@@ -104,12 +104,12 @@ class StableDiffusion:
|
|
|
104
104
|
diffusion_ckpt: str,
|
|
105
105
|
decoder_ckpt: str
|
|
106
106
|
):
|
|
107
|
-
self.tokenizer = Tokenizer(tokenizer_vocab_dir)
|
|
108
|
-
self.clip = TfLiteModel.load(clip_ckpt)
|
|
109
|
-
self.decoder = TfLiteModel.load(decoder_ckpt)
|
|
110
|
-
self.diffusion = TfLiteModel.load(diffusion_ckpt)
|
|
107
|
+
self.tokenizer = tokenizer.Tokenizer(tokenizer_vocab_dir)
|
|
108
|
+
self.clip = ai_edge_torch.model.TfLiteModel.load(clip_ckpt)
|
|
109
|
+
self.decoder = ai_edge_torch.model.TfLiteModel.load(decoder_ckpt)
|
|
110
|
+
self.diffusion = ai_edge_torch.model.TfLiteModel.load(diffusion_ckpt)
|
|
111
111
|
if encoder_ckpt is not None:
|
|
112
|
-
self.encoder = TfLiteModel.load(encoder_ckpt)
|
|
112
|
+
self.encoder = ai_edge_torch.model.TfLiteModel.load(encoder_ckpt)
|
|
113
113
|
|
|
114
114
|
|
|
115
115
|
def run_tflite_pipeline(
|
|
@@ -128,48 +128,32 @@ def run_tflite_pipeline(
|
|
|
128
128
|
):
|
|
129
129
|
"""Run stable diffusion pipeline with tflite model.
|
|
130
130
|
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
The prompt to guide the image generation.
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
139
|
-
The
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
154
|
-
|
|
155
|
-
|
|
156
|
-
|
|
157
|
-
expense of slower inference. This parameter will be modulated by `strength`.
|
|
158
|
-
seed:
|
|
159
|
-
A seed to make generation deterministic.
|
|
160
|
-
strength:
|
|
161
|
-
Conceptually, indicates how much to transform the reference `input_image`.
|
|
162
|
-
Must be between 0 and 1.
|
|
163
|
-
`input_image` will be used as a starting point, adding more noise to it the
|
|
164
|
-
larger the `strength`.
|
|
165
|
-
The number of denoising steps depends on the amount of noise initially
|
|
166
|
-
added. When `strength` is 1,
|
|
167
|
-
added noise will be maximum and the denoising process will run for the full
|
|
168
|
-
number of iterations
|
|
169
|
-
specified in `n_inference_steps`. A value of 1, therefore, essentially
|
|
170
|
-
ignores `input_image`.
|
|
171
|
-
input_image:
|
|
172
|
-
Image which is served as the starting point for the image generation.
|
|
131
|
+
Args:
|
|
132
|
+
model: StableDiffsuion model.
|
|
133
|
+
prompt: The prompt to guide the image generation.
|
|
134
|
+
output_path: The path to the generated output image.
|
|
135
|
+
uncond_prompt: The prompt not to guide the image generation.
|
|
136
|
+
cfg_scale: Guidance scale of classifier-free guidance. Higher guidance scale
|
|
137
|
+
encourages to generate images that are closely linked to the text
|
|
138
|
+
`prompt`, usually at the expense of lower image quality.
|
|
139
|
+
height: The height in pixels of the generated image.
|
|
140
|
+
width: The width in pixels of the generated image.
|
|
141
|
+
sampler: A sampler to be used to denoise the encoded image latents. Can be
|
|
142
|
+
one of `k_lms, `k_euler`, or `k_euler_ancestral`.
|
|
143
|
+
n_inference_steps: The number of denoising steps. More denoising steps
|
|
144
|
+
usually lead to a higher quality image at the expense of slower inference.
|
|
145
|
+
This parameter will be modulated by `strength`.
|
|
146
|
+
seed: A seed to make generation deterministic.
|
|
147
|
+
strength: Conceptually, indicates how much to transform the reference
|
|
148
|
+
`input_image`. Must be between 0 and 1. `input_image` will be used as a
|
|
149
|
+
starting point, adding more noise to it the larger the `strength`. The
|
|
150
|
+
number of denoising steps depends on the amount of noise initially added.
|
|
151
|
+
When `strength` is 1, added noise will be maximum and the denoising
|
|
152
|
+
process will run for the full number of iterations specified in
|
|
153
|
+
`n_inference_steps`. A value of 1, therefore, essentially ignores
|
|
154
|
+
`input_image`.
|
|
155
|
+
input_image: Image which is served as the starting point for the image
|
|
156
|
+
generation.
|
|
173
157
|
"""
|
|
174
158
|
if not 0 < strength < 1:
|
|
175
159
|
raise ValueError('strength must be between 0 and 1')
|
|
@@ -202,7 +186,8 @@ def run_tflite_pipeline(
|
|
|
202
186
|
context = np.concatenate([cond_context, uncond_context], axis=0)
|
|
203
187
|
noise_shape = (1, 4, height // 8, width // 8)
|
|
204
188
|
|
|
205
|
-
# Initialization starts from input_image if any, otherwise, starts from a
|
|
189
|
+
# Initialization starts from input_image if any, otherwise, starts from a
|
|
190
|
+
# random sampling.
|
|
206
191
|
if input_image:
|
|
207
192
|
if not hasattr(model, 'encoder'):
|
|
208
193
|
raise AttributeError(
|
|
@@ -210,7 +195,6 @@ def run_tflite_pipeline(
|
|
|
210
195
|
' input_image.'
|
|
211
196
|
)
|
|
212
197
|
input_image = input_image.resize((width, height))
|
|
213
|
-
input_image_np = np.array(input_image).astype(np.float32)
|
|
214
198
|
input_image_np = util.rescale(input_image, (0, 255), (-1, 1))
|
|
215
199
|
input_image_np = util.move_channel(input_image_np, to='first')
|
|
216
200
|
encoder_noise = np.random.normal(size=noise_shape).astype(np.float32)
|
|
@@ -223,8 +207,8 @@ def run_tflite_pipeline(
|
|
|
223
207
|
latents *= sampler.initial_scale
|
|
224
208
|
|
|
225
209
|
# Diffusion process.
|
|
226
|
-
timesteps = tqdm(sampler.timesteps)
|
|
227
|
-
for
|
|
210
|
+
timesteps = tqdm.tqdm(sampler.timesteps)
|
|
211
|
+
for _, timestep in enumerate(timesteps):
|
|
228
212
|
time_embedding = util.get_time_embedding(timestep)
|
|
229
213
|
|
|
230
214
|
input_latents = latents * sampler.get_input_scale()
|
|
@@ -242,7 +226,7 @@ def run_tflite_pipeline(
|
|
|
242
226
|
images = util.rescale(images, (-1, 1), (0, 255), clamp=True)
|
|
243
227
|
images = util.move_channel(images, to='last')
|
|
244
228
|
if not os.path.exists(output_path):
|
|
245
|
-
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
|
229
|
+
pathlib.Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
|
246
230
|
Image.fromarray(images[0].astype(np.uint8)).save(output_path)
|
|
247
231
|
|
|
248
232
|
|
|
@@ -21,7 +21,7 @@ from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
|
|
|
21
21
|
import ai_edge_torch.generative.layers.model_config as cfg
|
|
22
22
|
import torch
|
|
23
23
|
|
|
24
|
-
from
|
|
24
|
+
from absl.testing import absltest as googletest
|
|
25
25
|
|
|
26
26
|
|
|
27
27
|
class TestExternalKVLayers(googletest.TestCase):
|
|
@@ -22,7 +22,7 @@ from ai_edge_torch.generative.utilities import loader as loading_utils
|
|
|
22
22
|
import safetensors.torch
|
|
23
23
|
import torch
|
|
24
24
|
|
|
25
|
-
from
|
|
25
|
+
from absl.testing import absltest as googletest
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
class TestLoader(googletest.TestCase):
|
|
@@ -28,7 +28,7 @@ from ai_edge_torch.testing import model_coverage
|
|
|
28
28
|
from parameterized import parameterized
|
|
29
29
|
import torch
|
|
30
30
|
|
|
31
|
-
from
|
|
31
|
+
from absl.testing import absltest as googletest
|
|
32
32
|
|
|
33
33
|
|
|
34
34
|
class TestVerifyRecipes(googletest.TestCase):
|
|
@@ -19,7 +19,7 @@ from ai_edge_torch.hlfb import mark_pattern
|
|
|
19
19
|
from ai_edge_torch.hlfb.mark_pattern import pattern as pattern_module
|
|
20
20
|
import torch
|
|
21
21
|
|
|
22
|
-
from
|
|
22
|
+
from absl.testing import absltest as googletest
|
|
23
23
|
|
|
24
24
|
|
|
25
25
|
def _export_stablehlo_mlir(model, args=None):
|
|
@@ -22,7 +22,7 @@ from ai_edge_torch.hlfb import StableHLOCompositeBuilder
|
|
|
22
22
|
import torch
|
|
23
23
|
import torch.nn.functional as F
|
|
24
24
|
|
|
25
|
-
from
|
|
25
|
+
from absl.testing import absltest as googletest
|
|
26
26
|
|
|
27
27
|
|
|
28
28
|
def _export_stablehlo_mlir(model, args):
|
|
@@ -84,13 +84,17 @@ def _wrap_as_tf_func(
|
|
|
84
84
|
t_outs = [torch_dtype_to_tf(sig.dtype) for sig in bundle.output_signature]
|
|
85
85
|
s_outs = [_get_shape_with_dynamic(sig) for sig in bundle.output_signature]
|
|
86
86
|
call_args = _extract_call_args(bundle, args, tf_state_dict)
|
|
87
|
+
# HACK: In OSS, we use MLIR pybinding and StableHLO dialect from JAX's
|
|
88
|
+
# build, which may not have the same StableHLO version as what used in
|
|
89
|
+
# TFLite converter. Therefore we always serialize MLIR module in VHLO.
|
|
90
|
+
# TODO(b/362798610) Build MLIR pybinding in ai-edge-torch release.
|
|
87
91
|
call_module_return = tfxla.call_module(
|
|
88
92
|
tuple(call_args),
|
|
89
93
|
version=5,
|
|
90
94
|
Tout=t_outs, # dtype information
|
|
91
95
|
Sout=s_outs, # Shape information
|
|
92
96
|
function_list=[],
|
|
93
|
-
module=bundle.
|
|
97
|
+
module=bundle.module_bytecode_vhlo,
|
|
94
98
|
)
|
|
95
99
|
spec = exported_program.call_spec.out_spec
|
|
96
100
|
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
from . import composite
|
|
16
|
+
from . import debuginfo
|
|
17
|
+
from . import export
|
|
18
|
+
from . import export_utils
|
|
19
|
+
from . import lowerings
|
|
20
|
+
from . import passes
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Wrappers for latest torch APIs/utilities to maintain backward compatibility with older torch releases."""
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
from torch.fx import _pytree as fx_pytree
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def graph_module_flat_inputs(ep: torch.export.ExportedProgram, args, kwargs):
|
|
22
|
+
"""Transform args, kwargs of __call__ to args for graph_module.
|
|
23
|
+
|
|
24
|
+
self.graph_module takes stuff from state dict as inputs.
|
|
25
|
+
The invariant is for ep: ExportedProgram is
|
|
26
|
+
ep(args, kwargs) ==
|
|
27
|
+
ep.postprocess(ep.graph_module(ep.graph_module_flat_inputs(args, kwargs)))
|
|
28
|
+
"""
|
|
29
|
+
if hasattr(ep, "_graph_module_flat_inputs"):
|
|
30
|
+
return ep._graph_module_flat_inputs(args, kwargs)
|
|
31
|
+
|
|
32
|
+
if args is None:
|
|
33
|
+
args = tuple()
|
|
34
|
+
if kwargs is None:
|
|
35
|
+
kwargs = {}
|
|
36
|
+
|
|
37
|
+
flat_args = args
|
|
38
|
+
if (in_spec := ep.call_spec.in_spec) is not None:
|
|
39
|
+
if (
|
|
40
|
+
in_spec.type == tuple
|
|
41
|
+
and len(in_spec.children_specs) == 2
|
|
42
|
+
and in_spec.children_specs[0].type == tuple
|
|
43
|
+
and in_spec.children_specs[1].type == dict
|
|
44
|
+
):
|
|
45
|
+
# NOTE: this is the case where in_spec is for both args and kwargs
|
|
46
|
+
flat_args = fx_pytree.tree_flatten_spec((args, kwargs), in_spec)
|
|
47
|
+
else:
|
|
48
|
+
flat_args = fx_pytree.tree_flatten_spec(args, in_spec)
|
|
49
|
+
|
|
50
|
+
param_buffer_keys = ep.graph_signature.parameters + ep.graph_signature.buffers
|
|
51
|
+
param_buffer_values = tuple(ep.state_dict[key] for key in param_buffer_keys)
|
|
52
|
+
|
|
53
|
+
if hasattr(ep.graph_signature, "lifted_tensor_constants"):
|
|
54
|
+
ordered_tensor_constants = tuple(
|
|
55
|
+
ep.tensor_constants[name]
|
|
56
|
+
for name in ep.graph_signature.lifted_tensor_constants
|
|
57
|
+
)
|
|
58
|
+
else:
|
|
59
|
+
ordered_tensor_constants = tuple()
|
|
60
|
+
|
|
61
|
+
return (*param_buffer_values, *flat_args, *ordered_tensor_constants)
|
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
"""Torch library for registering ODML Torch custom ops."""
|
|
16
|
+
|
|
17
|
+
import torch
|
|
18
|
+
|
|
19
|
+
ODML_TORCH_LIB = torch.library.Library("odml_torch", "DEF")
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
# ==============================================================================
|
|
15
|
+
from .mark_tensor import mark_tensor_op
|
|
16
|
+
from .stablehlo_composite_builder import StableHLOCompositeBuilder
|