ai-edge-torch-nightly 0.3.0.dev20240916__py3-none-any.whl → 0.3.0.dev20240919__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +36 -56
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +36 -56
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +36 -56
- ai_edge_torch/generative/examples/openelm/openelm.py +0 -29
- ai_edge_torch/generative/examples/openelm/verify.py +61 -0
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +36 -56
- ai_edge_torch/generative/examples/phi/phi2.py +4 -31
- ai_edge_torch/generative/examples/phi/verify.py +53 -0
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +36 -56
- ai_edge_torch/generative/examples/smollm/smollm.py +0 -30
- ai_edge_torch/generative/examples/smollm/verify.py +59 -0
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +6 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +36 -56
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +0 -29
- ai_edge_torch/generative/examples/tiny_llama/verify.py +61 -0
- ai_edge_torch/generative/layers/attention.py +8 -4
- ai_edge_torch/generative/layers/builder.py +3 -1
- ai_edge_torch/generative/layers/model_config.py +3 -0
- ai_edge_torch/generative/layers/normalization.py +31 -20
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +19 -9
- ai_edge_torch/generative/layers/unet/blocks_2d.py +11 -4
- ai_edge_torch/generative/layers/unet/model_config.py +3 -0
- ai_edge_torch/generative/test/test_model_conversion_large.py +1 -1
- ai_edge_torch/generative/utilities/converter.py +82 -0
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +10 -0
- ai_edge_torch/generative/utilities/verifier.py +200 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +14 -4
- ai_edge_torch/odml_torch/lowerings/registry.py +1 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/RECORD +34 -28
- {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/top_level.txt +0 -0
@@ -41,22 +41,22 @@ class ResidualBlock2D(nn.Module):
|
|
41
41
|
)
|
42
42
|
self.conv_1 = nn.Conv2d(
|
43
43
|
config.in_channels,
|
44
|
-
config.
|
44
|
+
config.hidden_channels,
|
45
45
|
kernel_size=3,
|
46
46
|
stride=1,
|
47
47
|
padding=1,
|
48
48
|
)
|
49
49
|
if config.time_embedding_channels is not None:
|
50
50
|
self.time_emb_proj = nn.Linear(
|
51
|
-
config.time_embedding_channels, config.
|
51
|
+
config.time_embedding_channels, config.hidden_channels
|
52
52
|
)
|
53
53
|
else:
|
54
54
|
self.time_emb_proj = None
|
55
55
|
self.norm_2 = layers_builder.build_norm(
|
56
|
-
config.
|
56
|
+
config.hidden_channels, config.normalization_config
|
57
57
|
)
|
58
58
|
self.conv_2 = nn.Conv2d(
|
59
|
-
config.
|
59
|
+
config.hidden_channels,
|
60
60
|
config.out_channels,
|
61
61
|
kernel_size=3,
|
62
62
|
stride=1,
|
@@ -178,6 +178,8 @@ class CrossAttentionBlock2D(nn.Module):
|
|
178
178
|
config.attention_batch_size,
|
179
179
|
config.query_dim,
|
180
180
|
config.cross_dim,
|
181
|
+
config.hidden_dim,
|
182
|
+
config.output_dim,
|
181
183
|
config.attention_config,
|
182
184
|
enable_hlfb=config.enable_hlfb,
|
183
185
|
)
|
@@ -389,6 +391,7 @@ class DownEncoderBlock2D(nn.Module):
|
|
389
391
|
ResidualBlock2D(
|
390
392
|
unet_cfg.ResidualBlock2DConfig(
|
391
393
|
in_channels=input_channels,
|
394
|
+
hidden_channels=config.out_channels,
|
392
395
|
out_channels=config.out_channels,
|
393
396
|
time_embedding_channels=config.time_embedding_channels,
|
394
397
|
normalization_config=config.normalization_config,
|
@@ -490,6 +493,7 @@ class UpDecoderBlock2D(nn.Module):
|
|
490
493
|
ResidualBlock2D(
|
491
494
|
unet_cfg.ResidualBlock2DConfig(
|
492
495
|
in_channels=input_channels,
|
496
|
+
hidden_channels=config.out_channels,
|
493
497
|
out_channels=config.out_channels,
|
494
498
|
time_embedding_channels=config.time_embedding_channels,
|
495
499
|
normalization_config=config.normalization_config,
|
@@ -600,6 +604,7 @@ class SkipUpDecoderBlock2D(nn.Module):
|
|
600
604
|
ResidualBlock2D(
|
601
605
|
unet_cfg.ResidualBlock2DConfig(
|
602
606
|
in_channels=resnet_in_channels + res_skip_channels,
|
607
|
+
hidden_channels=config.out_channels,
|
603
608
|
out_channels=config.out_channels,
|
604
609
|
time_embedding_channels=config.time_embedding_channels,
|
605
610
|
normalization_config=config.normalization_config,
|
@@ -704,6 +709,7 @@ class MidBlock2D(nn.Module):
|
|
704
709
|
ResidualBlock2D(
|
705
710
|
unet_cfg.ResidualBlock2DConfig(
|
706
711
|
in_channels=config.in_channels,
|
712
|
+
hidden_channels=config.in_channels,
|
707
713
|
out_channels=config.in_channels,
|
708
714
|
time_embedding_channels=config.time_embedding_channels,
|
709
715
|
normalization_config=config.normalization_config,
|
@@ -722,6 +728,7 @@ class MidBlock2D(nn.Module):
|
|
722
728
|
ResidualBlock2D(
|
723
729
|
unet_cfg.ResidualBlock2DConfig(
|
724
730
|
in_channels=config.in_channels,
|
731
|
+
hidden_channels=config.in_channels,
|
725
732
|
out_channels=config.in_channels,
|
726
733
|
time_embedding_channels=config.time_embedding_channels,
|
727
734
|
normalization_config=config.normalization_config,
|
@@ -48,6 +48,7 @@ class DownSamplingConfig:
|
|
48
48
|
@dataclasses.dataclass
|
49
49
|
class ResidualBlock2DConfig:
|
50
50
|
in_channels: int
|
51
|
+
hidden_channels: int
|
51
52
|
out_channels: int
|
52
53
|
normalization_config: layers_cfg.NormalizationConfig
|
53
54
|
activation_config: layers_cfg.ActivationConfig
|
@@ -68,6 +69,8 @@ class AttentionBlock2DConfig:
|
|
68
69
|
class CrossAttentionBlock2DConfig:
|
69
70
|
query_dim: int
|
70
71
|
cross_dim: int
|
72
|
+
hidden_dim: int
|
73
|
+
output_dim: int
|
71
74
|
normalization_config: layers_cfg.NormalizationConfig
|
72
75
|
attention_config: layers_cfg.AttentionConfig
|
73
76
|
enable_hlfb: bool = True
|
@@ -96,7 +96,7 @@ class TestModelConversion(googletest.TestCase):
|
|
96
96
|
def test_gemma2(self):
|
97
97
|
config = gemma2.get_fake_model_config()
|
98
98
|
pytorch_model = gemma2.Gemma2(config).eval()
|
99
|
-
self._test_model(config, pytorch_model, "prefill", atol=1e-
|
99
|
+
self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
|
100
100
|
|
101
101
|
@googletest.skipIf(
|
102
102
|
ai_edge_config.Config.use_torch_xla,
|
@@ -0,0 +1,82 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Common utility functions for model conversion."""
|
17
|
+
|
18
|
+
import ai_edge_torch
|
19
|
+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
20
|
+
from ai_edge_torch.generative.quantize import quant_recipes
|
21
|
+
import torch
|
22
|
+
|
23
|
+
|
24
|
+
def convert_to_tflite(
|
25
|
+
pytorch_model: torch.nn.Module,
|
26
|
+
tflite_path: str,
|
27
|
+
prefill_seq_len: int = 512,
|
28
|
+
quantize: bool = True,
|
29
|
+
):
|
30
|
+
"""Converts a nn.Module model to multi-signature tflite model.
|
31
|
+
|
32
|
+
A PyTorch model will be converted to a tflite model with two signatures:
|
33
|
+
"prefill" and "decode".
|
34
|
+
|
35
|
+
"prefill" signature takes a tensor of shape [1, prefill_seq_len] of token
|
36
|
+
sequence, a tensor of shape [1, prefill_seq_len] of token positions, and an
|
37
|
+
external KV cache as a sample input.
|
38
|
+
|
39
|
+
"decode" signature takes a tensor of shape [1, 1] of token sequence, a tensor
|
40
|
+
of shape [1, 1] of the token position, and an external KV cache as a sample
|
41
|
+
input.
|
42
|
+
|
43
|
+
The final tflite model will be exported to tflite_path.
|
44
|
+
|
45
|
+
Args:
|
46
|
+
pytorch_model (torch.nn.Module): PyTorch model to convert to tflite.
|
47
|
+
tflite_path (str): The tflite file path to export.
|
48
|
+
prefill_seq_len (int, optional): The maximum size of prefill input tensor.
|
49
|
+
Defaults to 512.
|
50
|
+
quantize (bool, optional): Whether the model should be quanized. Defaults
|
51
|
+
to True.
|
52
|
+
"""
|
53
|
+
# Tensors used to trace the model graph during conversion.
|
54
|
+
prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
|
55
|
+
prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
|
56
|
+
decode_token = torch.tensor([[0]], dtype=torch.int)
|
57
|
+
decode_input_pos = torch.tensor([0], dtype=torch.int)
|
58
|
+
kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
|
59
|
+
|
60
|
+
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
61
|
+
edge_model = (
|
62
|
+
ai_edge_torch.signature(
|
63
|
+
'prefill',
|
64
|
+
pytorch_model,
|
65
|
+
sample_kwargs={
|
66
|
+
'tokens': prefill_tokens,
|
67
|
+
'input_pos': prefill_input_pos,
|
68
|
+
'kv_cache': kv,
|
69
|
+
},
|
70
|
+
)
|
71
|
+
.signature(
|
72
|
+
'decode',
|
73
|
+
pytorch_model,
|
74
|
+
sample_kwargs={
|
75
|
+
'tokens': decode_token,
|
76
|
+
'input_pos': decode_input_pos,
|
77
|
+
'kv_cache': kv,
|
78
|
+
},
|
79
|
+
)
|
80
|
+
.convert(quant_config=quant_config)
|
81
|
+
)
|
82
|
+
edge_model.export(tflite_path)
|
@@ -412,6 +412,7 @@ class BaseLoader(loader.ModelLoader):
|
|
412
412
|
):
|
413
413
|
residual_block_config = unet_config.ResidualBlock2DConfig(
|
414
414
|
in_channels=config.in_channels,
|
415
|
+
hidden_channels=config.in_channels,
|
415
416
|
out_channels=config.in_channels,
|
416
417
|
time_embedding_channels=config.time_embedding_channels,
|
417
418
|
normalization_config=config.normalization_config,
|
@@ -466,6 +467,7 @@ class BaseLoader(loader.ModelLoader):
|
|
466
467
|
f"{converted_state_param_prefix}.resnets.{i}",
|
467
468
|
unet_config.ResidualBlock2DConfig(
|
468
469
|
in_channels=input_channels,
|
470
|
+
hidden_channels=config.out_channels,
|
469
471
|
out_channels=config.out_channels,
|
470
472
|
time_embedding_channels=config.time_embedding_channels,
|
471
473
|
normalization_config=config.normalization_config,
|
@@ -508,6 +510,7 @@ class BaseLoader(loader.ModelLoader):
|
|
508
510
|
f"{converted_state_param_prefix}.resnets.{i}",
|
509
511
|
unet_config.ResidualBlock2DConfig(
|
510
512
|
in_channels=input_channels,
|
513
|
+
hidden_channels=config.out_channels,
|
511
514
|
out_channels=config.out_channels,
|
512
515
|
time_embedding_channels=config.time_embedding_channels,
|
513
516
|
normalization_config=config.normalization_config,
|
@@ -554,6 +557,7 @@ class BaseLoader(loader.ModelLoader):
|
|
554
557
|
f"{converted_state_param_prefix}.resnets.{i}",
|
555
558
|
unet_config.ResidualBlock2DConfig(
|
556
559
|
in_channels=resnet_in_channels + res_skip_channels,
|
560
|
+
hidden_channels=config.out_channels,
|
557
561
|
out_channels=config.out_channels,
|
558
562
|
time_embedding_channels=config.time_embedding_channels,
|
559
563
|
normalization_config=config.normalization_config,
|
@@ -811,6 +815,8 @@ class DiffusionModelLoader(BaseLoader):
|
|
811
815
|
cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
|
812
816
|
query_dim=output_channel,
|
813
817
|
cross_dim=config.transformer_cross_attention_dim,
|
818
|
+
hidden_dim=output_channel,
|
819
|
+
output_dim=output_channel,
|
814
820
|
normalization_config=config.transformer_norm_config,
|
815
821
|
attention_config=build_attention_config(
|
816
822
|
num_heads=config.transformer_num_attention_heads,
|
@@ -877,6 +883,8 @@ class DiffusionModelLoader(BaseLoader):
|
|
877
883
|
cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
|
878
884
|
query_dim=mid_block_channels,
|
879
885
|
cross_dim=config.transformer_cross_attention_dim,
|
886
|
+
hidden_dim=mid_block_channels,
|
887
|
+
output_dim=mid_block_channels,
|
880
888
|
normalization_config=config.transformer_norm_config,
|
881
889
|
attention_config=build_attention_config(
|
882
890
|
num_heads=config.transformer_num_attention_heads,
|
@@ -950,6 +958,8 @@ class DiffusionModelLoader(BaseLoader):
|
|
950
958
|
cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
|
951
959
|
query_dim=output_channel,
|
952
960
|
cross_dim=config.transformer_cross_attention_dim,
|
961
|
+
hidden_dim=output_channel,
|
962
|
+
output_dim=output_channel,
|
953
963
|
normalization_config=config.transformer_norm_config,
|
954
964
|
attention_config=build_attention_config(
|
955
965
|
num_heads=config.transformer_num_attention_heads,
|
@@ -0,0 +1,200 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Common utility functions to verify the reauthored models."""
|
17
|
+
|
18
|
+
import datetime
|
19
|
+
from typing import List
|
20
|
+
|
21
|
+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
22
|
+
import numpy as np
|
23
|
+
import torch
|
24
|
+
|
25
|
+
|
26
|
+
def log_msg(*args):
|
27
|
+
print("[%s]" % datetime.datetime.now(), *args)
|
28
|
+
|
29
|
+
|
30
|
+
def forward(
|
31
|
+
model: torch.nn.Module,
|
32
|
+
tokens: torch.Tensor,
|
33
|
+
kv_cache: kv_utils.KVCache,
|
34
|
+
) -> tuple[torch.Tensor, kv_utils.KVCache]:
|
35
|
+
"""Forwards the model reauthored with ai_edge_torch Generative API.
|
36
|
+
|
37
|
+
Args:
|
38
|
+
model (torch.nn.Module): The model to forward. It should be a model built
|
39
|
+
with ai_edge_torch Generative API.
|
40
|
+
tokens (torch.Tensor): The input tokens to forward.
|
41
|
+
kv_cache (KVCache): The KV cache to forward.
|
42
|
+
|
43
|
+
Returns:
|
44
|
+
The output logits and the updated KV cache.
|
45
|
+
"""
|
46
|
+
input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
|
47
|
+
output = model.forward(tokens, input_pos, kv_cache)
|
48
|
+
return output["logits"], output["kv_cache"]
|
49
|
+
|
50
|
+
|
51
|
+
def generate(
|
52
|
+
model: torch.nn.Module, prompts: torch.Tensor, response_len: int
|
53
|
+
) -> torch.Tensor:
|
54
|
+
"""Generates the response to the prompts.
|
55
|
+
|
56
|
+
It appends tokens output by the model to the prompts and feeds them back to
|
57
|
+
the model up to decode_len.
|
58
|
+
|
59
|
+
Args:
|
60
|
+
model (torch.nn.Module): The model to generate. It should be a model built
|
61
|
+
with ai_edge_torch Generative API.
|
62
|
+
prompts (torch.Tensor): The prompts to generate.
|
63
|
+
response_len (int): The number of tokens to generate.
|
64
|
+
|
65
|
+
Returns:
|
66
|
+
The generated tokens.
|
67
|
+
"""
|
68
|
+
input_ids = prompts[0].int().tolist()
|
69
|
+
kv_cache = kv_utils.KVCache.from_model_config(model.config)
|
70
|
+
for _ in range(response_len - len(input_ids)):
|
71
|
+
logits, kv_cache = forward(model, torch.tensor([input_ids]), kv_cache)
|
72
|
+
generated_token = logits[0][-1].argmax().item()
|
73
|
+
input_ids.append(generated_token)
|
74
|
+
return torch.tensor([input_ids])
|
75
|
+
|
76
|
+
|
77
|
+
def verify_with_input_ids(
|
78
|
+
original_model: torch.nn.Module,
|
79
|
+
reauthored_model: torch.nn.Module,
|
80
|
+
input_ids: torch.Tensor = torch.from_numpy(np.array([[1, 2, 3, 4]])).int(),
|
81
|
+
kv_cache_max_len: int = 1024,
|
82
|
+
rtol: float = 1e-05,
|
83
|
+
atol: float = 1e-05,
|
84
|
+
) -> bool:
|
85
|
+
"""Verifies if the model reauthored generates the same output of the oringal.
|
86
|
+
|
87
|
+
It compares only one outputs from the original and the reauthored model.
|
88
|
+
|
89
|
+
Args:
|
90
|
+
original_model (torch.nn.Module): The original model.
|
91
|
+
reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
|
92
|
+
Generative API.
|
93
|
+
input_ids (torch.Tensor): The input token IDs to forward.
|
94
|
+
kv_cache_max_len (int): The maximum sequence length of the KV cache.
|
95
|
+
rtol (float): The relative tolerance for the comparison.
|
96
|
+
atol (float): The absolute tolerance for the comparison.
|
97
|
+
|
98
|
+
Returns:
|
99
|
+
True if the model reauthored generates the same output of the original.
|
100
|
+
"""
|
101
|
+
tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
|
102
|
+
input_ids_len = input_ids.shape[1]
|
103
|
+
tokens[0, :input_ids_len] = input_ids
|
104
|
+
|
105
|
+
log_msg("Forwarding the original model...")
|
106
|
+
outputs_original = original_model.forward(tokens)
|
107
|
+
logits_original = outputs_original.logits[0, input_ids_len - 1, :]
|
108
|
+
log_msg("logits_original: ", logits_original)
|
109
|
+
|
110
|
+
log_msg("Forwarding the reauthored model...")
|
111
|
+
kv_cache = kv_utils.KVCache.from_model_config(reauthored_model.config)
|
112
|
+
outputs_reauthored = forward(reauthored_model, tokens, kv_cache)
|
113
|
+
logits_reauthored = outputs_reauthored[0][0, input_ids_len - 1, :]
|
114
|
+
log_msg("logits_reauthored:", logits_reauthored)
|
115
|
+
|
116
|
+
return torch.allclose(
|
117
|
+
logits_original, logits_reauthored, rtol=rtol, atol=atol
|
118
|
+
)
|
119
|
+
|
120
|
+
|
121
|
+
def verify_model_with_prompts(
|
122
|
+
original_model: torch.nn.Module,
|
123
|
+
reauthored_model: torch.nn.Module,
|
124
|
+
tokenizer: torch.nn.Module,
|
125
|
+
prompts: str,
|
126
|
+
) -> bool:
|
127
|
+
"""Verifies if the model reauthored generates the same answer of the oringal.
|
128
|
+
|
129
|
+
It compares an answer, i.e. multiple continuous outputs generated by the
|
130
|
+
original and the reauthored model.
|
131
|
+
|
132
|
+
Args:
|
133
|
+
original_model (torch.nn.Module): The original model.
|
134
|
+
reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
|
135
|
+
Generative API.
|
136
|
+
tokenizer (torch.nn.Module): The tokenizer.
|
137
|
+
prompts (str): The input prompts to generate answers.
|
138
|
+
|
139
|
+
Returns:
|
140
|
+
True if the model reauthored generates the same answer of the original.
|
141
|
+
"""
|
142
|
+
prompt_tokens = tokenizer.encode(prompts, return_tensors="pt")
|
143
|
+
|
144
|
+
log_msg("Generating answer with the original model...")
|
145
|
+
outputs_original = original_model.generate(prompt_tokens)
|
146
|
+
response_original = tokenizer.decode(outputs_original[0])
|
147
|
+
log_msg("outputs_from_original_model: [[", response_original, "]]")
|
148
|
+
|
149
|
+
log_msg("Generating answer with the reauthored model...")
|
150
|
+
generate_len = len(outputs_original[0])
|
151
|
+
outputs_reauthored = generate(reauthored_model, prompt_tokens, generate_len)
|
152
|
+
response_reauthored = tokenizer.decode(outputs_reauthored[0])
|
153
|
+
log_msg("outputs from reauthored model: [[", response_reauthored, "]]")
|
154
|
+
|
155
|
+
return response_original == response_reauthored
|
156
|
+
|
157
|
+
|
158
|
+
def verify_reauthored_model(
|
159
|
+
original_model: torch.nn.Module,
|
160
|
+
reauthored_model: torch.nn.Module,
|
161
|
+
tokenizer: torch.nn.Module,
|
162
|
+
prompts: List[str],
|
163
|
+
rtol: float = 1e-05,
|
164
|
+
atol: float = 1e-05,
|
165
|
+
):
|
166
|
+
"""Verifies the reauthored model against the original model.
|
167
|
+
|
168
|
+
It verifies the reauthored model with two methods:
|
169
|
+
1. It compares the output of the original and the reauthored model with an
|
170
|
+
arbitrary input.
|
171
|
+
2. It compares the answer generated by the original and the reauthored model
|
172
|
+
with a prompt.
|
173
|
+
|
174
|
+
It prints out "PASS" or "FAILED" to the console.
|
175
|
+
|
176
|
+
Args:
|
177
|
+
original_model (torch.nn.Module): The original model.
|
178
|
+
reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
|
179
|
+
Generative API.
|
180
|
+
tokenizer (torch.nn.Module): The tokenizer.
|
181
|
+
prompts (List[str]): List of the input prompts to generate answers.
|
182
|
+
rtol (float): The relative tolerance for the comparison.
|
183
|
+
atol (float): The absolute tolerance for the comparison.
|
184
|
+
"""
|
185
|
+
log_msg("Verifying the reauthored model with an arbitrary input...")
|
186
|
+
if verify_with_input_ids(
|
187
|
+
original_model, reauthored_model, rtol=rtol, atol=atol
|
188
|
+
):
|
189
|
+
log_msg("PASS")
|
190
|
+
else:
|
191
|
+
log_msg("FAILED")
|
192
|
+
|
193
|
+
for p in prompts:
|
194
|
+
log_msg("Verifying the reauthored model with prompts:", p)
|
195
|
+
if verify_model_with_prompts(
|
196
|
+
original_model, reauthored_model, tokenizer, p
|
197
|
+
):
|
198
|
+
log_msg("PASS")
|
199
|
+
else:
|
200
|
+
log_msg("FAILED")
|
@@ -212,17 +212,25 @@ def _aten_div(mod, x, y, *, rounding_mode=None, out=None) -> ir.Value:
|
|
212
212
|
# - https://github.com/pytorch/pytorch/blob/18f9331e5deb4c02ae5c206e133a9b4add49bd97/aten/src/ATen/native/TensorShape.cpp#L4002
|
213
213
|
@lower(torch.ops.aten.slice_scatter)
|
214
214
|
def _aten_slice_scatter(lctx, self, src, dim=0, start=None, end=None, step=1):
|
215
|
-
start = start
|
216
|
-
end = end
|
215
|
+
start = start if start is not None else 0
|
216
|
+
end = end if end is not None else self.type.shape[dim]
|
217
|
+
|
218
|
+
start, end = np.clip(
|
219
|
+
[start, end], -self.type.shape[dim], self.type.shape[dim]
|
220
|
+
)
|
221
|
+
|
217
222
|
if start < 0:
|
218
223
|
start = self.type.shape[dim] + start
|
219
224
|
if end < 0:
|
220
225
|
end = self.type.shape[dim] + end
|
221
226
|
|
222
|
-
end
|
227
|
+
if end <= start or np.prod(src.type.shape) == 0:
|
228
|
+
return self
|
223
229
|
|
230
|
+
end = start + step * math.ceil((end - start) / step) - (step - 1)
|
224
231
|
padding_low = start
|
225
232
|
padding_high = self.type.shape[dim] - end
|
233
|
+
interior_padding = step - 1
|
226
234
|
|
227
235
|
rank = len(self.type.shape)
|
228
236
|
src = stablehlo.pad(
|
@@ -230,7 +238,9 @@ def _aten_slice_scatter(lctx, self, src, dim=0, start=None, end=None, step=1):
|
|
230
238
|
utils.splat(0, src.type.element_type, []),
|
231
239
|
edge_padding_low=[padding_low if i == dim else 0 for i in range(rank)],
|
232
240
|
edge_padding_high=[padding_high if i == dim else 0 for i in range(rank)],
|
233
|
-
interior_padding=[
|
241
|
+
interior_padding=[
|
242
|
+
interior_padding if i == dim else 0 for i in range(rank)
|
243
|
+
],
|
234
244
|
)
|
235
245
|
pred = np.ones(self.type.shape, dtype=np.bool_)
|
236
246
|
pred[*[
|
@@ -57,6 +57,7 @@ global_registry.decompositions.update(
|
|
57
57
|
torch._decomp.get_decompositions([
|
58
58
|
torch.ops.aten.upsample_nearest2d,
|
59
59
|
torch.ops.aten._native_batch_norm_legit.no_stats,
|
60
|
+
torch.ops.aten._native_batch_norm_legit_functional,
|
60
61
|
torch.ops.aten._adaptive_avg_pool2d,
|
61
62
|
torch.ops.aten._adaptive_avg_pool3d,
|
62
63
|
torch.ops.aten.grid_sampler_2d,
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20240919
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
|
5
5
|
ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
|
6
|
-
ai_edge_torch/version.py,sha256=
|
6
|
+
ai_edge_torch/version.py,sha256=N5hYc9s2RU44J1_oe0UfJhTFo0d4JvMlKvxNlYtK0GI,706
|
7
7
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
8
|
ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
|
9
9
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -39,25 +39,28 @@ ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrK
|
|
39
39
|
ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
40
40
|
ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
41
41
|
ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
42
|
-
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=
|
43
|
-
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=
|
42
|
+
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=09VbyWErOMP9BXGwZpwvqzN5RaOqRigsELfxNRVeWns,2024
|
43
|
+
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=qJKQu6lKuSVhn8JR7KUeInq0u6yqgxEi7hfKCrZrIqY,2019
|
44
44
|
ai_edge_torch/generative/examples/gemma/gemma.py,sha256=hjpSPzEjPHuxwRJ-vHHtCCf2PSTnm30Mp0ajYYtDivo,7489
|
45
45
|
ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=gCLOti-4xHunjphNBbx9St6faRteSakm8Oex6R1Xek0,10272
|
46
46
|
ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
47
|
-
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=
|
48
|
-
ai_edge_torch/generative/examples/openelm/openelm.py,sha256=
|
47
|
+
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=HnqP3te1Qvy4SKaaqPrsG05eojiKDJShp4H3jPC9tYg,2023
|
48
|
+
ai_edge_torch/generative/examples/openelm/openelm.py,sha256=gGkHELNrt4xqnu11fCh3sJbZ7OsPyvoiF1J1aKCs5r8,7532
|
49
|
+
ai_edge_torch/generative/examples/openelm/verify.py,sha256=2qFdyLfcefdA3s1KQ-ZGWo4XReMXkEQAvpUEyJE5iqM,2057
|
49
50
|
ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
50
|
-
ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=
|
51
|
-
ai_edge_torch/generative/examples/phi/phi2.py,sha256=
|
51
|
+
ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=viIkbAgknE3zxavTZtib87cMIG2_-jJXtxJPcmB2pGQ,2007
|
52
|
+
ai_edge_torch/generative/examples/phi/phi2.py,sha256=YwAszA53aOjvaMJ5wua2-5rP79N21Un_Y5yBCfFSYNU,6189
|
53
|
+
ai_edge_torch/generative/examples/phi/verify.py,sha256=R9BjOArnn-3svoIApmP1NwO47n8KIFikOF0_MEgTOa4,1770
|
52
54
|
ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
53
|
-
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=
|
54
|
-
ai_edge_torch/generative/examples/smollm/smollm.py,sha256=
|
55
|
+
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=86hvBleyFXWmwy3Ke5J7x7WcCtG20D2kiBNrodE0R4w,2017
|
56
|
+
ai_edge_torch/generative/examples/smollm/smollm.py,sha256=hyhMk-b5762Q2xmjdD47g85dcbBSNJXNPIsifm1DRto,3239
|
57
|
+
ai_edge_torch/generative/examples/smollm/verify.py,sha256=JzidfVMMFDXzDdwn7ToDPuMo6eaoENNZGpEzX3f61Jk,1976
|
55
58
|
ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
56
59
|
ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
|
57
60
|
ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=tL6w2dr6VP66IXjSKo9StDNP-wl0RO3fh6dIliiYlFA,4656
|
58
61
|
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=vfMGI03UL_gfB561t2kzIHuScwnsUmqaPWxgvq_1T5A,5043
|
59
62
|
ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=slieF2-QcDCwd4DRZ7snsZIphT97IXpp4plRRsRSwL8,13983
|
60
|
-
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=
|
63
|
+
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=7o-5oJARCm4fhRwmNv84ofmajP5MMIS102vj4d8eeRQ,31248
|
61
64
|
ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=CAPsW84A8f00nS6fLFeh_XUjCPsDCA5UxHOUsMrLfSU,3450
|
62
65
|
ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=x9lEEENGNbpx6VTf_LTVudd9d6bs9tLvFUKTl252zEY,8623
|
63
66
|
ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=xychak9hdLd6ieXBYEwrK2BkF8NRZWZSSCijIsESpBA,3420
|
@@ -75,24 +78,25 @@ ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W
|
|
75
78
|
ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=QyLeCqDnk71WvvFH68g9UeF-HytonSk1ItGF9dc7Zj8,5854
|
76
79
|
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=e_Kqm5dStSrNE9_aIYC-vYJRsqLn-hJVkmR4QjYqZI0,5913
|
77
80
|
ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
78
|
-
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=
|
79
|
-
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=
|
81
|
+
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=Yg5G1LePoryeTib35lqICqaDW6foLUzSRgwJ2FlklIw,2040
|
82
|
+
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=tlWpa7Aun3u3w5b-9EBtW7olhmSf8W-tn5bKUIwC-ys,6044
|
83
|
+
ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=jld5PlGOQXMIWc1WoDYL_1nnsoVzRfrg-WgnsxRgaEU,2041
|
80
84
|
ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
|
81
85
|
ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
|
82
86
|
ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
83
|
-
ai_edge_torch/generative/layers/attention.py,sha256=
|
87
|
+
ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
|
84
88
|
ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
|
85
|
-
ai_edge_torch/generative/layers/builder.py,sha256=
|
89
|
+
ai_edge_torch/generative/layers/builder.py,sha256=toT9Tl1x9o5KbG-eGOEViUr4fd_4f-XLZdMQT0Ae5_8,5130
|
86
90
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=dfS1psdmomgs4EbwzkYyV_xx1xl3P1lU-3GoS8m0Avw,4221
|
87
91
|
ai_edge_torch/generative/layers/kv_cache.py,sha256=2El7kZYnQRCRcVc63xgiAdBh9oVOksDu35p9XggvaGE,6148
|
88
|
-
ai_edge_torch/generative/layers/model_config.py,sha256=
|
89
|
-
ai_edge_torch/generative/layers/normalization.py,sha256=
|
92
|
+
ai_edge_torch/generative/layers/model_config.py,sha256=d0Y-EFb4Rr7iLZ4Bsdf1i92KuhY1BXRqyeUN2kuu510,6923
|
93
|
+
ai_edge_torch/generative/layers/normalization.py,sha256=l_36uFdruJwqqyubnBTM0M-iGiJfeFafyXKPPK8KHVo,6713
|
90
94
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
|
91
|
-
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=
|
95
|
+
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
|
92
96
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
93
|
-
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=
|
97
|
+
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=c8rtlfDaeKmUfiiTKPmQhNW-U5vW9jFB2pPPcvT6qsc,27527
|
94
98
|
ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
|
95
|
-
ai_edge_torch/generative/layers/unet/model_config.py,sha256=
|
99
|
+
ai_edge_torch/generative/layers/unet/model_config.py,sha256=8ze9kVWMuyZVQcgK7hWYw9TM1W9lXD-2j0iMHlxoGX4,9267
|
96
100
|
ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
97
101
|
ai_edge_torch/generative/quantize/example.py,sha256=n_YFFP3dpKjeNKYZicDGL5LqtjqwhYEIaDrC6-Ci2vE,1539
|
98
102
|
ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
|
@@ -104,13 +108,15 @@ ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudj
|
|
104
108
|
ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
|
105
109
|
ai_edge_torch/generative/test/test_loader.py,sha256=8y74ChO3CZCfEi1eCf3-w47kRgAI4qPYCXpi8rTQXMA,3378
|
106
110
|
ai_edge_torch/generative/test/test_model_conversion.py,sha256=DBlqxW2IT-dZYzEfOMAp86Wtqiu6kgSWZ9BKZR1Clrw,5467
|
107
|
-
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=
|
111
|
+
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=dUYFarOldejqbMpa0j0vIDvXlWPAancuI8di3XkGxm8,4498
|
108
112
|
ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
|
109
113
|
ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
|
110
114
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
115
|
+
ai_edge_torch/generative/utilities/converter.py,sha256=MQUg2ZLmfk_2csWmQWKD_II0bXq4X3McI5i-qWraieE,2987
|
111
116
|
ai_edge_torch/generative/utilities/loader.py,sha256=b9iotIhVDX-Zc9XjIDUaLxnV395AyBnkQe3dV5YA7Co,13297
|
112
|
-
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=
|
117
|
+
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
|
113
118
|
ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
|
119
|
+
ai_edge_torch/generative/utilities/verifier.py,sha256=QAv1uJdI5o1yfphr_DpzxhZswKa4VG3JZUpqbCCWKMk,7114
|
114
120
|
ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
|
115
121
|
ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
|
116
122
|
ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
|
@@ -141,13 +147,13 @@ ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkg
|
|
141
147
|
ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=drN3L0uTsSjkluKgt6Ngq7b5HLReE_7iAitHpZ9PKqE,5428
|
142
148
|
ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
|
143
149
|
ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=dE_qzh-OnCNjWzqs1-PHs5PNlRF726qMQKM3tkwAzEs,959
|
144
|
-
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=
|
150
|
+
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=VvB050UCjB17h6-UNtsaqzVF13MGI01fPFkdmmghTj4,8790
|
145
151
|
ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
|
146
152
|
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=v1VdKmL8YLJv3PR9VgyNghO83A25PpTzY2ZUAJqlq3Q,6847
|
147
153
|
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=RN6BwMHuFj_rFgLCZ6Tu32XHbS2HGjPJeir2nROQ2rA,10517
|
148
154
|
ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=1ePJs7oIdUkVdMddFsXMc53qTkEKqGz0ZhQQoNzBa10,2862
|
149
155
|
ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
|
150
|
-
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=
|
156
|
+
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=gqx3n1Mx8pnGQz3nkIF1T_8bkRabXLJBvUoJJn5kOUY,2911
|
151
157
|
ai_edge_torch/odml_torch/lowerings/utils.py,sha256=NczqpsSd3Fn7yVcPC3qllemiZxxDAZgcW1T5l8-W9fE,5593
|
152
158
|
ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
|
153
159
|
ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
|
@@ -157,8 +163,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
157
163
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
158
164
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
159
165
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
160
|
-
ai_edge_torch_nightly-0.3.0.
|
161
|
-
ai_edge_torch_nightly-0.3.0.
|
162
|
-
ai_edge_torch_nightly-0.3.0.
|
163
|
-
ai_edge_torch_nightly-0.3.0.
|
164
|
-
ai_edge_torch_nightly-0.3.0.
|
166
|
+
ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
167
|
+
ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/METADATA,sha256=NkHYIOMz-5DNKJuSQ8wE-3Nz1R6a9YZ59M-Nq8sAnJg,1859
|
168
|
+
ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
169
|
+
ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
170
|
+
ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/RECORD,,
|
File without changes
|
File without changes
|