ai-edge-torch-nightly 0.3.0.dev20240916__py3-none-any.whl → 0.3.0.dev20240919__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +36 -56
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +36 -56
- ai_edge_torch/generative/examples/openelm/convert_to_tflite.py +36 -56
- ai_edge_torch/generative/examples/openelm/openelm.py +0 -29
- ai_edge_torch/generative/examples/openelm/verify.py +61 -0
- ai_edge_torch/generative/examples/phi/convert_to_tflite.py +36 -56
- ai_edge_torch/generative/examples/phi/phi2.py +4 -31
- ai_edge_torch/generative/examples/phi/verify.py +53 -0
- ai_edge_torch/generative/examples/smollm/convert_to_tflite.py +36 -56
- ai_edge_torch/generative/examples/smollm/smollm.py +0 -30
- ai_edge_torch/generative/examples/smollm/verify.py +59 -0
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +6 -0
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +36 -56
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +0 -29
- ai_edge_torch/generative/examples/tiny_llama/verify.py +61 -0
- ai_edge_torch/generative/layers/attention.py +8 -4
- ai_edge_torch/generative/layers/builder.py +3 -1
- ai_edge_torch/generative/layers/model_config.py +3 -0
- ai_edge_torch/generative/layers/normalization.py +31 -20
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +19 -9
- ai_edge_torch/generative/layers/unet/blocks_2d.py +11 -4
- ai_edge_torch/generative/layers/unet/model_config.py +3 -0
- ai_edge_torch/generative/test/test_model_conversion_large.py +1 -1
- ai_edge_torch/generative/utilities/converter.py +82 -0
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +10 -0
- ai_edge_torch/generative/utilities/verifier.py +200 -0
- ai_edge_torch/odml_torch/lowerings/_basic.py +14 -4
- ai_edge_torch/odml_torch/lowerings/registry.py +1 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/RECORD +34 -28
- {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240916.dist-info → ai_edge_torch_nightly-0.3.0.dev20240919.dist-info}/top_level.txt +0 -0
@@ -41,22 +41,22 @@ class ResidualBlock2D(nn.Module):
|
|
41
41
|
)
|
42
42
|
self.conv_1 = nn.Conv2d(
|
43
43
|
config.in_channels,
|
44
|
-
config.
|
44
|
+
config.hidden_channels,
|
45
45
|
kernel_size=3,
|
46
46
|
stride=1,
|
47
47
|
padding=1,
|
48
48
|
)
|
49
49
|
if config.time_embedding_channels is not None:
|
50
50
|
self.time_emb_proj = nn.Linear(
|
51
|
-
config.time_embedding_channels, config.
|
51
|
+
config.time_embedding_channels, config.hidden_channels
|
52
52
|
)
|
53
53
|
else:
|
54
54
|
self.time_emb_proj = None
|
55
55
|
self.norm_2 = layers_builder.build_norm(
|
56
|
-
config.
|
56
|
+
config.hidden_channels, config.normalization_config
|
57
57
|
)
|
58
58
|
self.conv_2 = nn.Conv2d(
|
59
|
-
config.
|
59
|
+
config.hidden_channels,
|
60
60
|
config.out_channels,
|
61
61
|
kernel_size=3,
|
62
62
|
stride=1,
|
@@ -178,6 +178,8 @@ class CrossAttentionBlock2D(nn.Module):
|
|
178
178
|
config.attention_batch_size,
|
179
179
|
config.query_dim,
|
180
180
|
config.cross_dim,
|
181
|
+
config.hidden_dim,
|
182
|
+
config.output_dim,
|
181
183
|
config.attention_config,
|
182
184
|
enable_hlfb=config.enable_hlfb,
|
183
185
|
)
|
@@ -389,6 +391,7 @@ class DownEncoderBlock2D(nn.Module):
|
|
389
391
|
ResidualBlock2D(
|
390
392
|
unet_cfg.ResidualBlock2DConfig(
|
391
393
|
in_channels=input_channels,
|
394
|
+
hidden_channels=config.out_channels,
|
392
395
|
out_channels=config.out_channels,
|
393
396
|
time_embedding_channels=config.time_embedding_channels,
|
394
397
|
normalization_config=config.normalization_config,
|
@@ -490,6 +493,7 @@ class UpDecoderBlock2D(nn.Module):
|
|
490
493
|
ResidualBlock2D(
|
491
494
|
unet_cfg.ResidualBlock2DConfig(
|
492
495
|
in_channels=input_channels,
|
496
|
+
hidden_channels=config.out_channels,
|
493
497
|
out_channels=config.out_channels,
|
494
498
|
time_embedding_channels=config.time_embedding_channels,
|
495
499
|
normalization_config=config.normalization_config,
|
@@ -600,6 +604,7 @@ class SkipUpDecoderBlock2D(nn.Module):
|
|
600
604
|
ResidualBlock2D(
|
601
605
|
unet_cfg.ResidualBlock2DConfig(
|
602
606
|
in_channels=resnet_in_channels + res_skip_channels,
|
607
|
+
hidden_channels=config.out_channels,
|
603
608
|
out_channels=config.out_channels,
|
604
609
|
time_embedding_channels=config.time_embedding_channels,
|
605
610
|
normalization_config=config.normalization_config,
|
@@ -704,6 +709,7 @@ class MidBlock2D(nn.Module):
|
|
704
709
|
ResidualBlock2D(
|
705
710
|
unet_cfg.ResidualBlock2DConfig(
|
706
711
|
in_channels=config.in_channels,
|
712
|
+
hidden_channels=config.in_channels,
|
707
713
|
out_channels=config.in_channels,
|
708
714
|
time_embedding_channels=config.time_embedding_channels,
|
709
715
|
normalization_config=config.normalization_config,
|
@@ -722,6 +728,7 @@ class MidBlock2D(nn.Module):
|
|
722
728
|
ResidualBlock2D(
|
723
729
|
unet_cfg.ResidualBlock2DConfig(
|
724
730
|
in_channels=config.in_channels,
|
731
|
+
hidden_channels=config.in_channels,
|
725
732
|
out_channels=config.in_channels,
|
726
733
|
time_embedding_channels=config.time_embedding_channels,
|
727
734
|
normalization_config=config.normalization_config,
|
@@ -48,6 +48,7 @@ class DownSamplingConfig:
|
|
48
48
|
@dataclasses.dataclass
|
49
49
|
class ResidualBlock2DConfig:
|
50
50
|
in_channels: int
|
51
|
+
hidden_channels: int
|
51
52
|
out_channels: int
|
52
53
|
normalization_config: layers_cfg.NormalizationConfig
|
53
54
|
activation_config: layers_cfg.ActivationConfig
|
@@ -68,6 +69,8 @@ class AttentionBlock2DConfig:
|
|
68
69
|
class CrossAttentionBlock2DConfig:
|
69
70
|
query_dim: int
|
70
71
|
cross_dim: int
|
72
|
+
hidden_dim: int
|
73
|
+
output_dim: int
|
71
74
|
normalization_config: layers_cfg.NormalizationConfig
|
72
75
|
attention_config: layers_cfg.AttentionConfig
|
73
76
|
enable_hlfb: bool = True
|
@@ -96,7 +96,7 @@ class TestModelConversion(googletest.TestCase):
|
|
96
96
|
def test_gemma2(self):
|
97
97
|
config = gemma2.get_fake_model_config()
|
98
98
|
pytorch_model = gemma2.Gemma2(config).eval()
|
99
|
-
self._test_model(config, pytorch_model, "prefill", atol=1e-
|
99
|
+
self._test_model(config, pytorch_model, "prefill", atol=1e-4, rtol=1e-5)
|
100
100
|
|
101
101
|
@googletest.skipIf(
|
102
102
|
ai_edge_config.Config.use_torch_xla,
|
@@ -0,0 +1,82 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Common utility functions for model conversion."""
|
17
|
+
|
18
|
+
import ai_edge_torch
|
19
|
+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
20
|
+
from ai_edge_torch.generative.quantize import quant_recipes
|
21
|
+
import torch
|
22
|
+
|
23
|
+
|
24
|
+
def convert_to_tflite(
|
25
|
+
pytorch_model: torch.nn.Module,
|
26
|
+
tflite_path: str,
|
27
|
+
prefill_seq_len: int = 512,
|
28
|
+
quantize: bool = True,
|
29
|
+
):
|
30
|
+
"""Converts a nn.Module model to multi-signature tflite model.
|
31
|
+
|
32
|
+
A PyTorch model will be converted to a tflite model with two signatures:
|
33
|
+
"prefill" and "decode".
|
34
|
+
|
35
|
+
"prefill" signature takes a tensor of shape [1, prefill_seq_len] of token
|
36
|
+
sequence, a tensor of shape [1, prefill_seq_len] of token positions, and an
|
37
|
+
external KV cache as a sample input.
|
38
|
+
|
39
|
+
"decode" signature takes a tensor of shape [1, 1] of token sequence, a tensor
|
40
|
+
of shape [1, 1] of the token position, and an external KV cache as a sample
|
41
|
+
input.
|
42
|
+
|
43
|
+
The final tflite model will be exported to tflite_path.
|
44
|
+
|
45
|
+
Args:
|
46
|
+
pytorch_model (torch.nn.Module): PyTorch model to convert to tflite.
|
47
|
+
tflite_path (str): The tflite file path to export.
|
48
|
+
prefill_seq_len (int, optional): The maximum size of prefill input tensor.
|
49
|
+
Defaults to 512.
|
50
|
+
quantize (bool, optional): Whether the model should be quantized. Defaults
|
51
|
+
to True.
|
52
|
+
"""
|
53
|
+
# Tensors used to trace the model graph during conversion.
|
54
|
+
prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.int)
|
55
|
+
prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
|
56
|
+
decode_token = torch.tensor([[0]], dtype=torch.int)
|
57
|
+
decode_input_pos = torch.tensor([0], dtype=torch.int)
|
58
|
+
kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
|
59
|
+
|
60
|
+
quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
|
61
|
+
edge_model = (
|
62
|
+
ai_edge_torch.signature(
|
63
|
+
'prefill',
|
64
|
+
pytorch_model,
|
65
|
+
sample_kwargs={
|
66
|
+
'tokens': prefill_tokens,
|
67
|
+
'input_pos': prefill_input_pos,
|
68
|
+
'kv_cache': kv,
|
69
|
+
},
|
70
|
+
)
|
71
|
+
.signature(
|
72
|
+
'decode',
|
73
|
+
pytorch_model,
|
74
|
+
sample_kwargs={
|
75
|
+
'tokens': decode_token,
|
76
|
+
'input_pos': decode_input_pos,
|
77
|
+
'kv_cache': kv,
|
78
|
+
},
|
79
|
+
)
|
80
|
+
.convert(quant_config=quant_config)
|
81
|
+
)
|
82
|
+
edge_model.export(tflite_path)
|
@@ -412,6 +412,7 @@ class BaseLoader(loader.ModelLoader):
|
|
412
412
|
):
|
413
413
|
residual_block_config = unet_config.ResidualBlock2DConfig(
|
414
414
|
in_channels=config.in_channels,
|
415
|
+
hidden_channels=config.in_channels,
|
415
416
|
out_channels=config.in_channels,
|
416
417
|
time_embedding_channels=config.time_embedding_channels,
|
417
418
|
normalization_config=config.normalization_config,
|
@@ -466,6 +467,7 @@ class BaseLoader(loader.ModelLoader):
|
|
466
467
|
f"{converted_state_param_prefix}.resnets.{i}",
|
467
468
|
unet_config.ResidualBlock2DConfig(
|
468
469
|
in_channels=input_channels,
|
470
|
+
hidden_channels=config.out_channels,
|
469
471
|
out_channels=config.out_channels,
|
470
472
|
time_embedding_channels=config.time_embedding_channels,
|
471
473
|
normalization_config=config.normalization_config,
|
@@ -508,6 +510,7 @@ class BaseLoader(loader.ModelLoader):
|
|
508
510
|
f"{converted_state_param_prefix}.resnets.{i}",
|
509
511
|
unet_config.ResidualBlock2DConfig(
|
510
512
|
in_channels=input_channels,
|
513
|
+
hidden_channels=config.out_channels,
|
511
514
|
out_channels=config.out_channels,
|
512
515
|
time_embedding_channels=config.time_embedding_channels,
|
513
516
|
normalization_config=config.normalization_config,
|
@@ -554,6 +557,7 @@ class BaseLoader(loader.ModelLoader):
|
|
554
557
|
f"{converted_state_param_prefix}.resnets.{i}",
|
555
558
|
unet_config.ResidualBlock2DConfig(
|
556
559
|
in_channels=resnet_in_channels + res_skip_channels,
|
560
|
+
hidden_channels=config.out_channels,
|
557
561
|
out_channels=config.out_channels,
|
558
562
|
time_embedding_channels=config.time_embedding_channels,
|
559
563
|
normalization_config=config.normalization_config,
|
@@ -811,6 +815,8 @@ class DiffusionModelLoader(BaseLoader):
|
|
811
815
|
cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
|
812
816
|
query_dim=output_channel,
|
813
817
|
cross_dim=config.transformer_cross_attention_dim,
|
818
|
+
hidden_dim=output_channel,
|
819
|
+
output_dim=output_channel,
|
814
820
|
normalization_config=config.transformer_norm_config,
|
815
821
|
attention_config=build_attention_config(
|
816
822
|
num_heads=config.transformer_num_attention_heads,
|
@@ -877,6 +883,8 @@ class DiffusionModelLoader(BaseLoader):
|
|
877
883
|
cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
|
878
884
|
query_dim=mid_block_channels,
|
879
885
|
cross_dim=config.transformer_cross_attention_dim,
|
886
|
+
hidden_dim=mid_block_channels,
|
887
|
+
output_dim=mid_block_channels,
|
880
888
|
normalization_config=config.transformer_norm_config,
|
881
889
|
attention_config=build_attention_config(
|
882
890
|
num_heads=config.transformer_num_attention_heads,
|
@@ -950,6 +958,8 @@ class DiffusionModelLoader(BaseLoader):
|
|
950
958
|
cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
|
951
959
|
query_dim=output_channel,
|
952
960
|
cross_dim=config.transformer_cross_attention_dim,
|
961
|
+
hidden_dim=output_channel,
|
962
|
+
output_dim=output_channel,
|
953
963
|
normalization_config=config.transformer_norm_config,
|
954
964
|
attention_config=build_attention_config(
|
955
965
|
num_heads=config.transformer_num_attention_heads,
|
@@ -0,0 +1,200 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Common utility functions to verify the reauthored models."""
|
17
|
+
|
18
|
+
import datetime
|
19
|
+
from typing import List
|
20
|
+
|
21
|
+
from ai_edge_torch.generative.layers import kv_cache as kv_utils
|
22
|
+
import numpy as np
|
23
|
+
import torch
|
24
|
+
|
25
|
+
|
26
|
+
def log_msg(*args):
|
27
|
+
print("[%s]" % datetime.datetime.now(), *args)
|
28
|
+
|
29
|
+
|
30
|
+
def forward(
|
31
|
+
model: torch.nn.Module,
|
32
|
+
tokens: torch.Tensor,
|
33
|
+
kv_cache: kv_utils.KVCache,
|
34
|
+
) -> tuple[torch.Tensor, kv_utils.KVCache]:
|
35
|
+
"""Forwards the model reauthored with ai_edge_torch Generative API.
|
36
|
+
|
37
|
+
Args:
|
38
|
+
model (torch.nn.Module): The model to forward. It should be a model built
|
39
|
+
with ai_edge_torch Generative API.
|
40
|
+
tokens (torch.Tensor): The input tokens to forward.
|
41
|
+
kv_cache (KVCache): The KV cache to forward.
|
42
|
+
|
43
|
+
Returns:
|
44
|
+
The output logits and the updated KV cache.
|
45
|
+
"""
|
46
|
+
input_pos = torch.arange(0, tokens.shape[1], dtype=torch.int)
|
47
|
+
output = model.forward(tokens, input_pos, kv_cache)
|
48
|
+
return output["logits"], output["kv_cache"]
|
49
|
+
|
50
|
+
|
51
|
+
def generate(
|
52
|
+
model: torch.nn.Module, prompts: torch.Tensor, response_len: int
|
53
|
+
) -> torch.Tensor:
|
54
|
+
"""Generates the response to the prompts.
|
55
|
+
|
56
|
+
It appends tokens output by the model to the prompts and feeds them back to
|
57
|
+
the model up to response_len.
|
58
|
+
|
59
|
+
Args:
|
60
|
+
model (torch.nn.Module): The model to generate. It should be a model built
|
61
|
+
with ai_edge_torch Generative API.
|
62
|
+
prompts (torch.Tensor): The prompts to generate.
|
63
|
+
response_len (int): The total length of the output, prompt included.
|
64
|
+
|
65
|
+
Returns:
|
66
|
+
The generated tokens.
|
67
|
+
"""
|
68
|
+
input_ids = prompts[0].int().tolist()
|
69
|
+
kv_cache = kv_utils.KVCache.from_model_config(model.config)
|
70
|
+
for _ in range(response_len - len(input_ids)):
|
71
|
+
logits, kv_cache = forward(model, torch.tensor([input_ids]), kv_cache)
|
72
|
+
generated_token = logits[0][-1].argmax().item()
|
73
|
+
input_ids.append(generated_token)
|
74
|
+
return torch.tensor([input_ids])
|
75
|
+
|
76
|
+
|
77
|
+
def verify_with_input_ids(
|
78
|
+
original_model: torch.nn.Module,
|
79
|
+
reauthored_model: torch.nn.Module,
|
80
|
+
input_ids: torch.Tensor = torch.from_numpy(np.array([[1, 2, 3, 4]])).int(),
|
81
|
+
kv_cache_max_len: int = 1024,
|
82
|
+
rtol: float = 1e-05,
|
83
|
+
atol: float = 1e-05,
|
84
|
+
) -> bool:
|
85
|
+
"""Verifies if the model reauthored generates the same output of the original.
|
86
|
+
|
87
|
+
It compares only one output from the original and the reauthored model.
|
88
|
+
|
89
|
+
Args:
|
90
|
+
original_model (torch.nn.Module): The original model.
|
91
|
+
reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
|
92
|
+
Generative API.
|
93
|
+
input_ids (torch.Tensor): The input token IDs to forward.
|
94
|
+
kv_cache_max_len (int): The maximum sequence length of the KV cache.
|
95
|
+
rtol (float): The relative tolerance for the comparison.
|
96
|
+
atol (float): The absolute tolerance for the comparison.
|
97
|
+
|
98
|
+
Returns:
|
99
|
+
True if the model reauthored generates the same output of the original.
|
100
|
+
"""
|
101
|
+
tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.int, device="cpu")
|
102
|
+
input_ids_len = input_ids.shape[1]
|
103
|
+
tokens[0, :input_ids_len] = input_ids
|
104
|
+
|
105
|
+
log_msg("Forwarding the original model...")
|
106
|
+
outputs_original = original_model.forward(tokens)
|
107
|
+
logits_original = outputs_original.logits[0, input_ids_len - 1, :]
|
108
|
+
log_msg("logits_original: ", logits_original)
|
109
|
+
|
110
|
+
log_msg("Forwarding the reauthored model...")
|
111
|
+
kv_cache = kv_utils.KVCache.from_model_config(reauthored_model.config)
|
112
|
+
outputs_reauthored = forward(reauthored_model, tokens, kv_cache)
|
113
|
+
logits_reauthored = outputs_reauthored[0][0, input_ids_len - 1, :]
|
114
|
+
log_msg("logits_reauthored:", logits_reauthored)
|
115
|
+
|
116
|
+
return torch.allclose(
|
117
|
+
logits_original, logits_reauthored, rtol=rtol, atol=atol
|
118
|
+
)
|
119
|
+
|
120
|
+
|
121
|
+
def verify_model_with_prompts(
|
122
|
+
original_model: torch.nn.Module,
|
123
|
+
reauthored_model: torch.nn.Module,
|
124
|
+
tokenizer: torch.nn.Module,
|
125
|
+
prompts: str,
|
126
|
+
) -> bool:
|
127
|
+
"""Verifies if the model reauthored generates the same answer of the original.
|
128
|
+
|
129
|
+
It compares an answer, i.e. multiple continuous outputs generated by the
|
130
|
+
original and the reauthored model.
|
131
|
+
|
132
|
+
Args:
|
133
|
+
original_model (torch.nn.Module): The original model.
|
134
|
+
reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
|
135
|
+
Generative API.
|
136
|
+
tokenizer (torch.nn.Module): The tokenizer.
|
137
|
+
prompts (str): The input prompts to generate answers.
|
138
|
+
|
139
|
+
Returns:
|
140
|
+
True if the model reauthored generates the same answer of the original.
|
141
|
+
"""
|
142
|
+
prompt_tokens = tokenizer.encode(prompts, return_tensors="pt")
|
143
|
+
|
144
|
+
log_msg("Generating answer with the original model...")
|
145
|
+
outputs_original = original_model.generate(prompt_tokens)
|
146
|
+
response_original = tokenizer.decode(outputs_original[0])
|
147
|
+
log_msg("outputs_from_original_model: [[", response_original, "]]")
|
148
|
+
|
149
|
+
log_msg("Generating answer with the reauthored model...")
|
150
|
+
generate_len = len(outputs_original[0])
|
151
|
+
outputs_reauthored = generate(reauthored_model, prompt_tokens, generate_len)
|
152
|
+
response_reauthored = tokenizer.decode(outputs_reauthored[0])
|
153
|
+
log_msg("outputs from reauthored model: [[", response_reauthored, "]]")
|
154
|
+
|
155
|
+
return response_original == response_reauthored
|
156
|
+
|
157
|
+
|
158
|
+
def verify_reauthored_model(
|
159
|
+
original_model: torch.nn.Module,
|
160
|
+
reauthored_model: torch.nn.Module,
|
161
|
+
tokenizer: torch.nn.Module,
|
162
|
+
prompts: List[str],
|
163
|
+
rtol: float = 1e-05,
|
164
|
+
atol: float = 1e-05,
|
165
|
+
):
|
166
|
+
"""Verifies the reauthored model against the original model.
|
167
|
+
|
168
|
+
It verifies the reauthored model with two methods:
|
169
|
+
1. It compares the output of the original and the reauthored model with an
|
170
|
+
arbitrary input.
|
171
|
+
2. It compares the answer generated by the original and the reauthored model
|
172
|
+
with a prompt.
|
173
|
+
|
174
|
+
It prints out "PASS" or "FAILED" to the console.
|
175
|
+
|
176
|
+
Args:
|
177
|
+
original_model (torch.nn.Module): The original model.
|
178
|
+
reauthored_model (torch.nn.Module): The model reauthored with ai_edge_torch
|
179
|
+
Generative API.
|
180
|
+
tokenizer (torch.nn.Module): The tokenizer.
|
181
|
+
prompts (List[str]): List of the input prompts to generate answers.
|
182
|
+
rtol (float): The relative tolerance for the comparison.
|
183
|
+
atol (float): The absolute tolerance for the comparison.
|
184
|
+
"""
|
185
|
+
log_msg("Verifying the reauthored model with an arbitrary input...")
|
186
|
+
if verify_with_input_ids(
|
187
|
+
original_model, reauthored_model, rtol=rtol, atol=atol
|
188
|
+
):
|
189
|
+
log_msg("PASS")
|
190
|
+
else:
|
191
|
+
log_msg("FAILED")
|
192
|
+
|
193
|
+
for p in prompts:
|
194
|
+
log_msg("Verifying the reauthored model with prompts:", p)
|
195
|
+
if verify_model_with_prompts(
|
196
|
+
original_model, reauthored_model, tokenizer, p
|
197
|
+
):
|
198
|
+
log_msg("PASS")
|
199
|
+
else:
|
200
|
+
log_msg("FAILED")
|
@@ -212,17 +212,25 @@ def _aten_div(mod, x, y, *, rounding_mode=None, out=None) -> ir.Value:
|
|
212
212
|
# - https://github.com/pytorch/pytorch/blob/18f9331e5deb4c02ae5c206e133a9b4add49bd97/aten/src/ATen/native/TensorShape.cpp#L4002
|
213
213
|
@lower(torch.ops.aten.slice_scatter)
|
214
214
|
def _aten_slice_scatter(lctx, self, src, dim=0, start=None, end=None, step=1):
|
215
|
-
start = start
|
216
|
-
end = end
|
215
|
+
start = start if start is not None else 0
|
216
|
+
end = end if end is not None else self.type.shape[dim]
|
217
|
+
|
218
|
+
start, end = np.clip(
|
219
|
+
[start, end], -self.type.shape[dim], self.type.shape[dim]
|
220
|
+
)
|
221
|
+
|
217
222
|
if start < 0:
|
218
223
|
start = self.type.shape[dim] + start
|
219
224
|
if end < 0:
|
220
225
|
end = self.type.shape[dim] + end
|
221
226
|
|
222
|
-
end
|
227
|
+
if end <= start or np.prod(src.type.shape) == 0:
|
228
|
+
return self
|
223
229
|
|
230
|
+
end = start + step * math.ceil((end - start) / step) - (step - 1)
|
224
231
|
padding_low = start
|
225
232
|
padding_high = self.type.shape[dim] - end
|
233
|
+
interior_padding = step - 1
|
226
234
|
|
227
235
|
rank = len(self.type.shape)
|
228
236
|
src = stablehlo.pad(
|
@@ -230,7 +238,9 @@ def _aten_slice_scatter(lctx, self, src, dim=0, start=None, end=None, step=1):
|
|
230
238
|
utils.splat(0, src.type.element_type, []),
|
231
239
|
edge_padding_low=[padding_low if i == dim else 0 for i in range(rank)],
|
232
240
|
edge_padding_high=[padding_high if i == dim else 0 for i in range(rank)],
|
233
|
-
interior_padding=[
|
241
|
+
interior_padding=[
|
242
|
+
interior_padding if i == dim else 0 for i in range(rank)
|
243
|
+
],
|
234
244
|
)
|
235
245
|
pred = np.ones(self.type.shape, dtype=np.bool_)
|
236
246
|
pred[*[
|
@@ -57,6 +57,7 @@ global_registry.decompositions.update(
|
|
57
57
|
torch._decomp.get_decompositions([
|
58
58
|
torch.ops.aten.upsample_nearest2d,
|
59
59
|
torch.ops.aten._native_batch_norm_legit.no_stats,
|
60
|
+
torch.ops.aten._native_batch_norm_legit_functional,
|
60
61
|
torch.ops.aten._adaptive_avg_pool2d,
|
61
62
|
torch.ops.aten._adaptive_avg_pool3d,
|
62
63
|
torch.ops.aten.grid_sampler_2d,
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20240919
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
|
5
5
|
ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
|
6
|
-
ai_edge_torch/version.py,sha256=
|
6
|
+
ai_edge_torch/version.py,sha256=N5hYc9s2RU44J1_oe0UfJhTFo0d4JvMlKvxNlYtK0GI,706
|
7
7
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
8
|
ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
|
9
9
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -39,25 +39,28 @@ ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrK
|
|
39
39
|
ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
40
40
|
ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
41
41
|
ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
42
|
-
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=
|
43
|
-
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=
|
42
|
+
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=09VbyWErOMP9BXGwZpwvqzN5RaOqRigsELfxNRVeWns,2024
|
43
|
+
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=qJKQu6lKuSVhn8JR7KUeInq0u6yqgxEi7hfKCrZrIqY,2019
|
44
44
|
ai_edge_torch/generative/examples/gemma/gemma.py,sha256=hjpSPzEjPHuxwRJ-vHHtCCf2PSTnm30Mp0ajYYtDivo,7489
|
45
45
|
ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=gCLOti-4xHunjphNBbx9St6faRteSakm8Oex6R1Xek0,10272
|
46
46
|
ai_edge_torch/generative/examples/openelm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
47
|
-
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=
|
48
|
-
ai_edge_torch/generative/examples/openelm/openelm.py,sha256=
|
47
|
+
ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=HnqP3te1Qvy4SKaaqPrsG05eojiKDJShp4H3jPC9tYg,2023
|
48
|
+
ai_edge_torch/generative/examples/openelm/openelm.py,sha256=gGkHELNrt4xqnu11fCh3sJbZ7OsPyvoiF1J1aKCs5r8,7532
|
49
|
+
ai_edge_torch/generative/examples/openelm/verify.py,sha256=2qFdyLfcefdA3s1KQ-ZGWo4XReMXkEQAvpUEyJE5iqM,2057
|
49
50
|
ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
50
|
-
ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=
|
51
|
-
ai_edge_torch/generative/examples/phi/phi2.py,sha256=
|
51
|
+
ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=viIkbAgknE3zxavTZtib87cMIG2_-jJXtxJPcmB2pGQ,2007
|
52
|
+
ai_edge_torch/generative/examples/phi/phi2.py,sha256=YwAszA53aOjvaMJ5wua2-5rP79N21Un_Y5yBCfFSYNU,6189
|
53
|
+
ai_edge_torch/generative/examples/phi/verify.py,sha256=R9BjOArnn-3svoIApmP1NwO47n8KIFikOF0_MEgTOa4,1770
|
52
54
|
ai_edge_torch/generative/examples/smollm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
53
|
-
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=
|
54
|
-
ai_edge_torch/generative/examples/smollm/smollm.py,sha256=
|
55
|
+
ai_edge_torch/generative/examples/smollm/convert_to_tflite.py,sha256=86hvBleyFXWmwy3Ke5J7x7WcCtG20D2kiBNrodE0R4w,2017
|
56
|
+
ai_edge_torch/generative/examples/smollm/smollm.py,sha256=hyhMk-b5762Q2xmjdD47g85dcbBSNJXNPIsifm1DRto,3239
|
57
|
+
ai_edge_torch/generative/examples/smollm/verify.py,sha256=JzidfVMMFDXzDdwn7ToDPuMo6eaoENNZGpEzX3f61Jk,1976
|
55
58
|
ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
56
59
|
ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
|
57
60
|
ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=tL6w2dr6VP66IXjSKo9StDNP-wl0RO3fh6dIliiYlFA,4656
|
58
61
|
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=vfMGI03UL_gfB561t2kzIHuScwnsUmqaPWxgvq_1T5A,5043
|
59
62
|
ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=slieF2-QcDCwd4DRZ7snsZIphT97IXpp4plRRsRSwL8,13983
|
60
|
-
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=
|
63
|
+
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=7o-5oJARCm4fhRwmNv84ofmajP5MMIS102vj4d8eeRQ,31248
|
61
64
|
ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=CAPsW84A8f00nS6fLFeh_XUjCPsDCA5UxHOUsMrLfSU,3450
|
62
65
|
ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=x9lEEENGNbpx6VTf_LTVudd9d6bs9tLvFUKTl252zEY,8623
|
63
66
|
ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=xychak9hdLd6ieXBYEwrK2BkF8NRZWZSSCijIsESpBA,3420
|
@@ -75,24 +78,25 @@ ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W
|
|
75
78
|
ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=QyLeCqDnk71WvvFH68g9UeF-HytonSk1ItGF9dc7Zj8,5854
|
76
79
|
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=e_Kqm5dStSrNE9_aIYC-vYJRsqLn-hJVkmR4QjYqZI0,5913
|
77
80
|
ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
78
|
-
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=
|
79
|
-
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=
|
81
|
+
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=Yg5G1LePoryeTib35lqICqaDW6foLUzSRgwJ2FlklIw,2040
|
82
|
+
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=tlWpa7Aun3u3w5b-9EBtW7olhmSf8W-tn5bKUIwC-ys,6044
|
83
|
+
ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=jld5PlGOQXMIWc1WoDYL_1nnsoVzRfrg-WgnsxRgaEU,2041
|
80
84
|
ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
|
81
85
|
ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
|
82
86
|
ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
83
|
-
ai_edge_torch/generative/layers/attention.py,sha256=
|
87
|
+
ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
|
84
88
|
ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
|
85
|
-
ai_edge_torch/generative/layers/builder.py,sha256=
|
89
|
+
ai_edge_torch/generative/layers/builder.py,sha256=toT9Tl1x9o5KbG-eGOEViUr4fd_4f-XLZdMQT0Ae5_8,5130
|
86
90
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=dfS1psdmomgs4EbwzkYyV_xx1xl3P1lU-3GoS8m0Avw,4221
|
87
91
|
ai_edge_torch/generative/layers/kv_cache.py,sha256=2El7kZYnQRCRcVc63xgiAdBh9oVOksDu35p9XggvaGE,6148
|
88
|
-
ai_edge_torch/generative/layers/model_config.py,sha256=
|
89
|
-
ai_edge_torch/generative/layers/normalization.py,sha256=
|
92
|
+
ai_edge_torch/generative/layers/model_config.py,sha256=d0Y-EFb4Rr7iLZ4Bsdf1i92KuhY1BXRqyeUN2kuu510,6923
|
93
|
+
ai_edge_torch/generative/layers/normalization.py,sha256=l_36uFdruJwqqyubnBTM0M-iGiJfeFafyXKPPK8KHVo,6713
|
90
94
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
|
91
|
-
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=
|
95
|
+
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
|
92
96
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
93
|
-
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=
|
97
|
+
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=c8rtlfDaeKmUfiiTKPmQhNW-U5vW9jFB2pPPcvT6qsc,27527
|
94
98
|
ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
|
95
|
-
ai_edge_torch/generative/layers/unet/model_config.py,sha256=
|
99
|
+
ai_edge_torch/generative/layers/unet/model_config.py,sha256=8ze9kVWMuyZVQcgK7hWYw9TM1W9lXD-2j0iMHlxoGX4,9267
|
96
100
|
ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
97
101
|
ai_edge_torch/generative/quantize/example.py,sha256=n_YFFP3dpKjeNKYZicDGL5LqtjqwhYEIaDrC6-Ci2vE,1539
|
98
102
|
ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
|
@@ -104,13 +108,15 @@ ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudj
|
|
104
108
|
ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
|
105
109
|
ai_edge_torch/generative/test/test_loader.py,sha256=8y74ChO3CZCfEi1eCf3-w47kRgAI4qPYCXpi8rTQXMA,3378
|
106
110
|
ai_edge_torch/generative/test/test_model_conversion.py,sha256=DBlqxW2IT-dZYzEfOMAp86Wtqiu6kgSWZ9BKZR1Clrw,5467
|
107
|
-
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=
|
111
|
+
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=dUYFarOldejqbMpa0j0vIDvXlWPAancuI8di3XkGxm8,4498
|
108
112
|
ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
|
109
113
|
ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
|
110
114
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
115
|
+
ai_edge_torch/generative/utilities/converter.py,sha256=MQUg2ZLmfk_2csWmQWKD_II0bXq4X3McI5i-qWraieE,2987
|
111
116
|
ai_edge_torch/generative/utilities/loader.py,sha256=b9iotIhVDX-Zc9XjIDUaLxnV395AyBnkQe3dV5YA7Co,13297
|
112
|
-
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=
|
117
|
+
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
|
113
118
|
ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
|
119
|
+
ai_edge_torch/generative/utilities/verifier.py,sha256=QAv1uJdI5o1yfphr_DpzxhZswKa4VG3JZUpqbCCWKMk,7114
|
114
120
|
ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
|
115
121
|
ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
|
116
122
|
ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
|
@@ -141,13 +147,13 @@ ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkg
|
|
141
147
|
ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=drN3L0uTsSjkluKgt6Ngq7b5HLReE_7iAitHpZ9PKqE,5428
|
142
148
|
ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
|
143
149
|
ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=dE_qzh-OnCNjWzqs1-PHs5PNlRF726qMQKM3tkwAzEs,959
|
144
|
-
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=
|
150
|
+
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=VvB050UCjB17h6-UNtsaqzVF13MGI01fPFkdmmghTj4,8790
|
145
151
|
ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
|
146
152
|
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=v1VdKmL8YLJv3PR9VgyNghO83A25PpTzY2ZUAJqlq3Q,6847
|
147
153
|
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=RN6BwMHuFj_rFgLCZ6Tu32XHbS2HGjPJeir2nROQ2rA,10517
|
148
154
|
ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=1ePJs7oIdUkVdMddFsXMc53qTkEKqGz0ZhQQoNzBa10,2862
|
149
155
|
ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
|
150
|
-
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=
|
156
|
+
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=gqx3n1Mx8pnGQz3nkIF1T_8bkRabXLJBvUoJJn5kOUY,2911
|
151
157
|
ai_edge_torch/odml_torch/lowerings/utils.py,sha256=NczqpsSd3Fn7yVcPC3qllemiZxxDAZgcW1T5l8-W9fE,5593
|
152
158
|
ai_edge_torch/odml_torch/passes/__init__.py,sha256=AVwIwUTMx7rXacKjGy4kwrtMd3XB2v_ncdc40KOjUqQ,1245
|
153
159
|
ai_edge_torch/quantize/__init__.py,sha256=aB5dXot04bqyUhpsDFvxt9CIi15QAC4euvqOndJ0XLU,714
|
@@ -157,8 +163,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
157
163
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
158
164
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
159
165
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
160
|
-
ai_edge_torch_nightly-0.3.0.
|
161
|
-
ai_edge_torch_nightly-0.3.0.
|
162
|
-
ai_edge_torch_nightly-0.3.0.
|
163
|
-
ai_edge_torch_nightly-0.3.0.
|
164
|
-
ai_edge_torch_nightly-0.3.0.
|
166
|
+
ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
167
|
+
ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/METADATA,sha256=NkHYIOMz-5DNKJuSQ8wE-3Nz1R6a9YZ59M-Nq8sAnJg,1859
|
168
|
+
ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
169
|
+
ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
170
|
+
ai_edge_torch_nightly-0.3.0.dev20240919.dist-info/RECORD,,
|
File without changes
|
File without changes
|