ai-edge-torch-nightly 0.3.0.dev20241110__py3-none-any.whl → 0.3.0.dev20241115__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +2 -0
- ai_edge_torch/generative/examples/gemma/verify_gemma2.py +1 -14
- ai_edge_torch/generative/examples/gemma/verify_util.py +28 -3
- ai_edge_torch/generative/examples/paligemma/image_encoder.py +158 -0
- ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py +82 -0
- ai_edge_torch/generative/layers/attention.py +8 -6
- ai_edge_torch/generative/layers/model_config.py +14 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +53 -17
- ai_edge_torch/generative/layers/unet/model_config.py +8 -0
- ai_edge_torch/generative/utilities/loader.py +4 -0
- ai_edge_torch/generative/utilities/verifier.py +16 -4
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241110.dist-info → ai_edge_torch_nightly-0.3.0.dev20241115.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241110.dist-info → ai_edge_torch_nightly-0.3.0.dev20241115.dist-info}/RECORD +17 -15
- {ai_edge_torch_nightly-0.3.0.dev20241110.dist-info → ai_edge_torch_nightly-0.3.0.dev20241115.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241110.dist-info → ai_edge_torch_nightly-0.3.0.dev20241115.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241110.dist-info → ai_edge_torch_nightly-0.3.0.dev20241115.dist-info}/top_level.txt +0 -0
@@ -51,6 +51,7 @@ def _get_upsample_bilinear2d_pattern():
|
|
51
51
|
return {
|
52
52
|
"output": (int(output_h), int(output_w)),
|
53
53
|
"align_corners": False,
|
54
|
+
"is_nchw_op": True,
|
54
55
|
}
|
55
56
|
|
56
57
|
return pattern
|
@@ -74,6 +75,7 @@ def _get_upsample_bilinear2d_align_corners_pattern():
|
|
74
75
|
return {
|
75
76
|
"output": (int(output_h), int(output_w)),
|
76
77
|
"align_corners": True,
|
78
|
+
"is_nchw_op": True,
|
77
79
|
}
|
78
80
|
|
79
81
|
return pattern
|
@@ -15,10 +15,8 @@
|
|
15
15
|
|
16
16
|
"""Verifies the reauthored Gemma2 model."""
|
17
17
|
|
18
|
-
import logging
|
19
18
|
from absl import app
|
20
19
|
from absl import flags
|
21
|
-
from ai_edge_torch.generative.examples.gemma import gemma2
|
22
20
|
from ai_edge_torch.generative.examples.gemma import verify_util
|
23
21
|
import kagglehub
|
24
22
|
|
@@ -38,18 +36,7 @@ _MAX_NEW_TOKENS = flags.DEFINE_integer(
|
|
38
36
|
def main(_):
|
39
37
|
checkpoint = kagglehub.model_download("google/gemma-2/pyTorch/gemma-2-2b-it")
|
40
38
|
|
41
|
-
|
42
|
-
reauthored_model = gemma2.build_2b_model(checkpoint)
|
43
|
-
|
44
|
-
verify_util.verify_reauthored_gemma_model(
|
45
|
-
checkpoint=checkpoint,
|
46
|
-
variant="2b-v2",
|
47
|
-
reauthored_model=reauthored_model,
|
48
|
-
generate_prompts=_PROMPTS.value,
|
49
|
-
forward_input_ids=[[2, 651, 9456, 576, 573, 3520, 3858, 603, 235248]],
|
50
|
-
max_new_tokens=_MAX_NEW_TOKENS.value,
|
51
|
-
atol=1e-04,
|
52
|
-
)
|
39
|
+
verify_util.verify_gemma2(checkpoint, _PROMPTS.value, _MAX_NEW_TOKENS.value)
|
53
40
|
|
54
41
|
|
55
42
|
if __name__ == "__main__":
|
@@ -19,6 +19,7 @@ import logging
|
|
19
19
|
import os
|
20
20
|
from typing import List, Tuple
|
21
21
|
|
22
|
+
from ai_edge_torch.generative.examples.gemma import gemma2
|
22
23
|
import ai_edge_torch.generative.layers.attention_utils as attn_utils
|
23
24
|
from ai_edge_torch.generative.utilities import verifier
|
24
25
|
from gemma import config as gemma_config
|
@@ -109,8 +110,11 @@ def verify_reauthored_gemma_model(
|
|
109
110
|
max_new_tokens: int = 20,
|
110
111
|
rtol: float = 1e-05,
|
111
112
|
atol: float = 1e-05,
|
112
|
-
):
|
113
|
-
"""Verifies the reauthored Gemma model against the original model.
|
113
|
+
) -> bool:
|
114
|
+
"""Verifies the reauthored Gemma model against the original model.
|
115
|
+
|
116
|
+
Returns True if the verification passes, False otherwise.
|
117
|
+
"""
|
114
118
|
config = gemma_config.get_model_config(variant)
|
115
119
|
config.tokenizer = os.path.join(checkpoint, tokenizer_filename)
|
116
120
|
# Use float32 to be compatible with the reauthored model.
|
@@ -120,7 +124,7 @@ def verify_reauthored_gemma_model(
|
|
120
124
|
original_model = gemma_model.GemmaForCausalLM(config).eval()
|
121
125
|
original_model.load_weights(os.path.join(checkpoint, weight_filename))
|
122
126
|
|
123
|
-
verifier.verify_reauthored_model(
|
127
|
+
return verifier.verify_reauthored_model(
|
124
128
|
original_model=GemmaWrapper(original_model),
|
125
129
|
reauthored_model=verifier.ReauthoredModelWrapper(reauthored_model),
|
126
130
|
tokenizer=GemmaTokenizerWrapper(original_model.tokenizer),
|
@@ -130,3 +134,24 @@ def verify_reauthored_gemma_model(
|
|
130
134
|
rtol=rtol,
|
131
135
|
atol=atol,
|
132
136
|
)
|
137
|
+
|
138
|
+
|
139
|
+
def verify_gemma2(
|
140
|
+
gemma2_model_path: str, prompts: List[str], max_new_tokens: int
|
141
|
+
) -> bool:
|
142
|
+
"""Verifies the reauthored Gemma2 model.
|
143
|
+
|
144
|
+
Return True if the verification passes, False otherwise.
|
145
|
+
"""
|
146
|
+
logging.info("Building the reauthored model from: %s", gemma2_model_path)
|
147
|
+
reauthored_model = gemma2.build_2b_model(gemma2_model_path)
|
148
|
+
|
149
|
+
return verify_reauthored_gemma_model(
|
150
|
+
checkpoint=gemma2_model_path,
|
151
|
+
variant="2b-v2",
|
152
|
+
reauthored_model=reauthored_model,
|
153
|
+
generate_prompts=prompts,
|
154
|
+
forward_input_ids=[[2, 651, 9456, 576, 573, 3520, 3858, 603, 235248]],
|
155
|
+
max_new_tokens=max_new_tokens,
|
156
|
+
atol=1e-04,
|
157
|
+
)
|
@@ -0,0 +1,158 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Example of building an image encoder of PaliGemma model which is Siglip."""
|
17
|
+
|
18
|
+
from ai_edge_torch.generative.layers import attention
|
19
|
+
from ai_edge_torch.generative.layers import builder
|
20
|
+
import ai_edge_torch.generative.layers.model_config as cfg
|
21
|
+
import ai_edge_torch.generative.utilities.loader as loading_utils
|
22
|
+
import torch
|
23
|
+
from torch import nn
|
24
|
+
|
25
|
+
TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
|
26
|
+
ff_up_proj="vision_tower.vision_model.encoder.layers.{}.mlp.fc1",
|
27
|
+
ff_down_proj="vision_tower.vision_model.encoder.layers.{}.mlp.fc2",
|
28
|
+
attn_query_proj=(
|
29
|
+
"vision_tower.vision_model.encoder.layers.{}.self_attn.q_proj"
|
30
|
+
),
|
31
|
+
attn_key_proj=(
|
32
|
+
"vision_tower.vision_model.encoder.layers.{}.self_attn.k_proj"
|
33
|
+
),
|
34
|
+
attn_value_proj=(
|
35
|
+
"vision_tower.vision_model.encoder.layers.{}.self_attn.v_proj"
|
36
|
+
),
|
37
|
+
attn_output_proj=(
|
38
|
+
"vision_tower.vision_model.encoder.layers.{}.self_attn.out_proj"
|
39
|
+
),
|
40
|
+
pre_attn_norm="vision_tower.vision_model.encoder.layers.{}.layer_norm1",
|
41
|
+
post_attn_norm="vision_tower.vision_model.encoder.layers.{}.layer_norm2",
|
42
|
+
embedding="vision_tower.vision_model.embeddings.patch_embedding",
|
43
|
+
embedding_position=(
|
44
|
+
"vision_tower.vision_model.embeddings.position_embedding.weight"
|
45
|
+
),
|
46
|
+
final_norm="vision_tower.vision_model.post_layernorm",
|
47
|
+
)
|
48
|
+
|
49
|
+
|
50
|
+
class SiglipVisionEncoder(nn.Module):
|
51
|
+
"""Signlip vision encoder from the Edge Generative API."""
|
52
|
+
|
53
|
+
def __init__(self, config: cfg.ModelConfig):
|
54
|
+
super().__init__()
|
55
|
+
|
56
|
+
# Construct model layers.
|
57
|
+
self.tok_embedding = nn.Conv2d(
|
58
|
+
in_channels=config.image_embedding.channels,
|
59
|
+
out_channels=config.embedding_dim,
|
60
|
+
kernel_size=config.image_embedding.patch_size,
|
61
|
+
stride=config.image_embedding.patch_size,
|
62
|
+
padding="valid",
|
63
|
+
)
|
64
|
+
num_patches = (
|
65
|
+
config.image_embedding.image_size // config.image_embedding.patch_size
|
66
|
+
) ** 2
|
67
|
+
self.tok_embedding_position = nn.Parameter(
|
68
|
+
torch.zeros((num_patches, config.embedding_dim))
|
69
|
+
)
|
70
|
+
|
71
|
+
self.transformer_blocks = nn.ModuleList(
|
72
|
+
attention.TransformerBlock(config.block_config(idx), config)
|
73
|
+
for idx in range(config.num_layers)
|
74
|
+
)
|
75
|
+
self.final_norm = builder.build_norm(
|
76
|
+
config.embedding_dim,
|
77
|
+
config.final_norm_config,
|
78
|
+
)
|
79
|
+
self.config = config
|
80
|
+
|
81
|
+
@torch.inference_mode
|
82
|
+
def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
|
83
|
+
# Embed the image according to SiplipVisionEmbeddings.
|
84
|
+
x = self.tok_embedding(pixel_values)
|
85
|
+
x = x.flatten(2).transpose(1, 2) + self.tok_embedding_position
|
86
|
+
|
87
|
+
# Pass a dummy mask because SDPA attention impl expects non-None mask.
|
88
|
+
mask = torch.zeros(x.shape[:2])
|
89
|
+
for _, block in enumerate(self.transformer_blocks):
|
90
|
+
x = block(x, mask=mask)
|
91
|
+
return self.final_norm(x)
|
92
|
+
|
93
|
+
|
94
|
+
def get_image_encoder_config() -> cfg.ModelConfig:
|
95
|
+
"""Returns the model config for the image encoder of a PaliGemma 3B-224 model.
|
96
|
+
|
97
|
+
Returns:
|
98
|
+
The model config for the image encoder of a PaliGemma 3B model.
|
99
|
+
"""
|
100
|
+
image_embedding_config = cfg.ImageEmbeddingConfig(
|
101
|
+
channels=3,
|
102
|
+
image_size=224,
|
103
|
+
patch_size=14,
|
104
|
+
)
|
105
|
+
attn_config = cfg.AttentionConfig(
|
106
|
+
num_heads=16,
|
107
|
+
head_dim=72,
|
108
|
+
num_query_groups=16,
|
109
|
+
qkv_use_bias=True,
|
110
|
+
output_proj_use_bias=True,
|
111
|
+
)
|
112
|
+
ff_config = cfg.FeedForwardConfig(
|
113
|
+
type=cfg.FeedForwardType.SEQUENTIAL,
|
114
|
+
activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
|
115
|
+
intermediate_size=4304,
|
116
|
+
use_bias=True,
|
117
|
+
)
|
118
|
+
norm_config = cfg.NormalizationConfig(
|
119
|
+
type=cfg.NormalizationType.LAYER_NORM,
|
120
|
+
epsilon=1e-6,
|
121
|
+
enable_hlfb=True,
|
122
|
+
)
|
123
|
+
block_config = cfg.TransformerBlockConfig(
|
124
|
+
attn_config=attn_config,
|
125
|
+
ff_config=ff_config,
|
126
|
+
pre_attention_norm_config=norm_config,
|
127
|
+
post_attention_norm_config=norm_config,
|
128
|
+
)
|
129
|
+
config = cfg.ModelConfig(
|
130
|
+
vocab_size=0, # Not used in image encoder.
|
131
|
+
num_layers=27,
|
132
|
+
max_seq_len=0, # Not used in image encoder.
|
133
|
+
embedding_dim=1152,
|
134
|
+
embedding_use_bias=True,
|
135
|
+
image_embedding=image_embedding_config,
|
136
|
+
block_configs=block_config,
|
137
|
+
final_norm_config=norm_config,
|
138
|
+
enable_hlfb=True,
|
139
|
+
)
|
140
|
+
return config
|
141
|
+
|
142
|
+
|
143
|
+
def get_fake_image_encoder_config() -> cfg.ModelConfig:
|
144
|
+
config = get_image_encoder_config()
|
145
|
+
# PaliGemma image encoder has only one block config.
|
146
|
+
config.block_config(0).ff_config.intermediate_size = 128
|
147
|
+
config.num_layers = 2
|
148
|
+
return config
|
149
|
+
|
150
|
+
|
151
|
+
def build_image_encoder(checkpoint_path: str) -> SiglipVisionEncoder:
|
152
|
+
config = get_image_encoder_config()
|
153
|
+
encoder = SiglipVisionEncoder(config)
|
154
|
+
loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
|
155
|
+
# Loose the strictness because only image encoder is being loaded.
|
156
|
+
loader.load(encoder, strict=False)
|
157
|
+
encoder.eval()
|
158
|
+
return encoder
|
@@ -0,0 +1,82 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Verifies the reauthored image encoder of PaliGemma 3B model."""
|
17
|
+
|
18
|
+
import logging
|
19
|
+
import pathlib
|
20
|
+
from absl import app
|
21
|
+
from absl import flags
|
22
|
+
from ai_edge_torch.generative.examples.paligemma import image_encoder
|
23
|
+
from PIL import Image
|
24
|
+
import requests
|
25
|
+
import torch
|
26
|
+
import transformers
|
27
|
+
|
28
|
+
_IMAGE_URL = flags.DEFINE_string(
|
29
|
+
"image_url",
|
30
|
+
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/car.jpg?download=true",
|
31
|
+
"The image URI to encode.",
|
32
|
+
)
|
33
|
+
|
34
|
+
|
35
|
+
def main(_):
|
36
|
+
checkpoint = "google/paligemma-3b-mix-224"
|
37
|
+
logging.info("Loading the original model from: %s", checkpoint)
|
38
|
+
original_full_model = (
|
39
|
+
transformers.PaliGemmaForConditionalGeneration.from_pretrained(checkpoint)
|
40
|
+
)
|
41
|
+
original_vision_model = original_full_model.eval().vision_tower
|
42
|
+
|
43
|
+
# Locate the cached dir.
|
44
|
+
cached_config_file = transformers.utils.cached_file(
|
45
|
+
checkpoint, transformers.utils.CONFIG_NAME
|
46
|
+
)
|
47
|
+
reauthored_checkpoint = pathlib.Path(cached_config_file).parent
|
48
|
+
logging.info("Building the reauthored model from: %s", reauthored_checkpoint)
|
49
|
+
reauthored_model = image_encoder.build_image_encoder(reauthored_checkpoint)
|
50
|
+
|
51
|
+
logging.info("Loading the processor from: %s", checkpoint)
|
52
|
+
# It works only when GemmaTokenizerFast is available. In some environments,
|
53
|
+
# use_fast=False doeesn't work either if the tokenizer cannot load the
|
54
|
+
# sentencepiece model file properly.
|
55
|
+
processor = transformers.AutoProcessor.from_pretrained(checkpoint)
|
56
|
+
|
57
|
+
logging.info("Loading the image from: %s", _IMAGE_URL.value)
|
58
|
+
image = Image.open(requests.get(_IMAGE_URL.value, stream=True).raw)
|
59
|
+
pixel_values = processor(images=image, return_tensors="pt")["pixel_values"]
|
60
|
+
|
61
|
+
logging.info("Forwarding the original model...")
|
62
|
+
outputs_original = original_vision_model.forward(pixel_values=pixel_values)
|
63
|
+
outputs_original = outputs_original.last_hidden_state
|
64
|
+
logging.info("outputs_original: %s", outputs_original)
|
65
|
+
|
66
|
+
logging.info("Forwarding the reauthored model...")
|
67
|
+
outputs_reauthored = reauthored_model.forward(pixel_values=pixel_values)
|
68
|
+
logging.info("outputs_reauthored: %s", outputs_reauthored)
|
69
|
+
|
70
|
+
try:
|
71
|
+
assert torch.allclose(
|
72
|
+
outputs_original, outputs_reauthored, atol=1e-04, rtol=1e-04
|
73
|
+
)
|
74
|
+
except AssertionError as e:
|
75
|
+
logging.error("*** FAILED *** verify with an image")
|
76
|
+
raise e
|
77
|
+
else:
|
78
|
+
logging.info("*** PASSED *** verify with an image")
|
79
|
+
|
80
|
+
|
81
|
+
if __name__ == "__main__":
|
82
|
+
app.run(main)
|
@@ -235,9 +235,10 @@ class CausalSelfAttention(nn.Module):
|
|
235
235
|
k = k.reshape(B, T, -1, self.config.head_dim)
|
236
236
|
v = v.reshape(B, T, -1, self.config.head_dim)
|
237
237
|
|
238
|
-
|
239
|
-
|
240
|
-
|
238
|
+
if rope is not None:
|
239
|
+
# Compute rotary positional embedding for query and key.
|
240
|
+
n_elem = int(self.config.rotary_percentage * self.config.head_dim)
|
241
|
+
q, k = _embed_rope(q, k, n_elem, rope)
|
241
242
|
|
242
243
|
if kv_cache is not None:
|
243
244
|
kv_cache = kv_utils.update(
|
@@ -372,9 +373,10 @@ class CrossAttention(nn.Module):
|
|
372
373
|
k = k.view(interim_shape)
|
373
374
|
v = v.view(interim_shape)
|
374
375
|
|
375
|
-
|
376
|
-
|
377
|
-
|
376
|
+
if rope is not None:
|
377
|
+
# Compute rotary positional embedding for query and key.
|
378
|
+
n_elem = int(self.config.rotary_percentage * self.config.head_dim)
|
379
|
+
q, k = _embed_rope(q, k, n_elem, rope)
|
378
380
|
|
379
381
|
if kv_cache is not None:
|
380
382
|
kv_cache = kv_utils.update(
|
@@ -163,6 +163,16 @@ class TransformerBlockConfig:
|
|
163
163
|
relative_attention: bool = False
|
164
164
|
|
165
165
|
|
166
|
+
@dataclass
|
167
|
+
class ImageEmbeddingConfig:
|
168
|
+
"""Image embedding parameters."""
|
169
|
+
|
170
|
+
channels: int
|
171
|
+
# All images should be normalized to the size of [image_size * image_size].
|
172
|
+
image_size: int
|
173
|
+
patch_size: int
|
174
|
+
|
175
|
+
|
166
176
|
@dataclass
|
167
177
|
class ModelConfig:
|
168
178
|
"""Base configurations for building a transformer architecture."""
|
@@ -183,6 +193,10 @@ class ModelConfig:
|
|
183
193
|
|
184
194
|
# Scale factor of the embedding.
|
185
195
|
embedding_scale: Optional[float] = None
|
196
|
+
# Use bias term within embedding.
|
197
|
+
embedding_use_bias: bool = False
|
198
|
+
# Image embedding parameters.
|
199
|
+
image_embedding: Optional[ImageEmbeddingConfig] = None
|
186
200
|
|
187
201
|
# Use bias term within LLM's HEAD.
|
188
202
|
lm_head_use_bias: bool = False
|
@@ -115,12 +115,15 @@ class AttentionBlock2D(nn.Module):
|
|
115
115
|
"""
|
116
116
|
super().__init__()
|
117
117
|
self.config = config
|
118
|
+
hidden_dim = config.hidden_dim
|
119
|
+
if not hidden_dim:
|
120
|
+
hidden_dim = config.dim
|
118
121
|
self.norm = layers_builder.build_norm(
|
119
|
-
|
122
|
+
hidden_dim, config.normalization_config
|
120
123
|
)
|
121
124
|
self.attention = SelfAttention(
|
122
125
|
config.attention_batch_size,
|
123
|
-
|
126
|
+
hidden_dim,
|
124
127
|
config.attention_config,
|
125
128
|
enable_hlfb=config.enable_hlfb,
|
126
129
|
)
|
@@ -172,7 +175,7 @@ class CrossAttentionBlock2D(nn.Module):
|
|
172
175
|
super().__init__()
|
173
176
|
self.config = config
|
174
177
|
self.norm = layers_builder.build_norm(
|
175
|
-
config.
|
178
|
+
config.output_dim, config.normalization_config
|
176
179
|
)
|
177
180
|
self.attention = CrossAttention(
|
178
181
|
config.attention_batch_size,
|
@@ -294,21 +297,30 @@ class TransformerBlock2D(nn.Module):
|
|
294
297
|
hidden_states
|
295
298
|
"""
|
296
299
|
|
297
|
-
def __init__(
|
300
|
+
def __init__(
|
301
|
+
self, config: unet_cfg.TransformerBlock2DConfig, dim_override=None
|
302
|
+
):
|
298
303
|
"""Initialize an instance of the TransformerBlock2D.
|
299
304
|
|
300
305
|
Args:
|
301
306
|
config (unet_cfg.TransformerBlock2Dconfig): the configuration of this
|
302
307
|
block.
|
308
|
+
dim_override: in case specified, overrides config.attention_block_config.hidden_dim. Set to None by default.
|
303
309
|
"""
|
304
310
|
super().__init__()
|
305
311
|
self.config = config
|
312
|
+
attention_block_config_dim = config.attention_block_config.dim
|
313
|
+
attention_block_config_hidden_dim = config.attention_block_config.hidden_dim
|
314
|
+
if dim_override:
|
315
|
+
attention_block_config_dim = dim_override
|
316
|
+
if not attention_block_config_hidden_dim:
|
317
|
+
attention_block_config_hidden_dim = attention_block_config_dim
|
306
318
|
self.pre_conv_norm = layers_builder.build_norm(
|
307
|
-
|
319
|
+
attention_block_config_dim, config.pre_conv_normalization_config
|
308
320
|
)
|
309
321
|
self.conv_in = nn.Conv2d(
|
310
|
-
|
311
|
-
|
322
|
+
attention_block_config_dim,
|
323
|
+
attention_block_config_hidden_dim,
|
312
324
|
kernel_size=1,
|
313
325
|
padding=0,
|
314
326
|
)
|
@@ -318,8 +330,8 @@ class TransformerBlock2D(nn.Module):
|
|
318
330
|
)
|
319
331
|
self.feed_forward = FeedForwardBlock2D(config.feed_forward_block_config)
|
320
332
|
self.conv_out = nn.Conv2d(
|
321
|
-
|
322
|
-
|
333
|
+
attention_block_config_hidden_dim,
|
334
|
+
attention_block_config_dim,
|
323
335
|
kernel_size=1,
|
324
336
|
padding=0,
|
325
337
|
)
|
@@ -385,14 +397,18 @@ class DownEncoderBlock2D(nn.Module):
|
|
385
397
|
self.config = config
|
386
398
|
resnets = []
|
387
399
|
transformers = []
|
400
|
+
hidden_channels = config.hidden_channels
|
401
|
+
if not hidden_channels:
|
402
|
+
hidden_channels = config.out_channels
|
388
403
|
for i in range(config.num_layers):
|
389
404
|
input_channels = config.in_channels if i == 0 else config.out_channels
|
390
405
|
resnets.append(
|
391
406
|
ResidualBlock2D(
|
392
407
|
unet_cfg.ResidualBlock2DConfig(
|
393
408
|
in_channels=input_channels,
|
394
|
-
hidden_channels=
|
409
|
+
hidden_channels=hidden_channels,
|
395
410
|
out_channels=config.out_channels,
|
411
|
+
residual_out_channels=config.out_channels,
|
396
412
|
time_embedding_channels=config.time_embedding_channels,
|
397
413
|
normalization_config=config.normalization_config,
|
398
414
|
activation_config=config.activation_config,
|
@@ -589,23 +605,37 @@ class SkipUpDecoderBlock2D(nn.Module):
|
|
589
605
|
"""
|
590
606
|
super().__init__()
|
591
607
|
self.config = config
|
608
|
+
hidden_channels = config.hidden_channels
|
609
|
+
if not hidden_channels:
|
610
|
+
hidden_channels = config.out_channels
|
611
|
+
sub_block_channels = config.sub_block_channels
|
612
|
+
if sub_block_channels:
|
613
|
+
assert len(sub_block_channels) == config.num_layers, (
|
614
|
+
"Assertion failed: The length of 'sub_block_channels'"
|
615
|
+
f" ({len(sub_block_channels)}) does not match 'config.num_layers'"
|
616
|
+
f" ({config.num_layers})."
|
617
|
+
)
|
618
|
+
else:
|
619
|
+
sub_block_channels = [config.out_channels] * config.num_layers
|
592
620
|
resnets = []
|
593
621
|
transformers = []
|
594
622
|
for i in range(config.num_layers):
|
623
|
+
resnet_in_channels = (
|
624
|
+
config.prev_out_channels if i == 0 else sub_block_channels[i - 1]
|
625
|
+
)
|
595
626
|
res_skip_channels = (
|
596
627
|
config.in_channels
|
597
628
|
if (i == config.num_layers - 1)
|
598
629
|
else config.out_channels
|
599
630
|
)
|
600
|
-
|
601
|
-
config.prev_out_channels if i == 0 else config.out_channels
|
602
|
-
)
|
631
|
+
residual_out_channel = sub_block_channels[i]
|
603
632
|
resnets.append(
|
604
633
|
ResidualBlock2D(
|
605
634
|
unet_cfg.ResidualBlock2DConfig(
|
606
635
|
in_channels=resnet_in_channels + res_skip_channels,
|
607
|
-
hidden_channels=
|
608
|
-
out_channels=
|
636
|
+
hidden_channels=hidden_channels,
|
637
|
+
out_channels=sub_block_channels[i],
|
638
|
+
residual_out_channels=residual_out_channel,
|
609
639
|
time_embedding_channels=config.time_embedding_channels,
|
610
640
|
normalization_config=config.normalization_config,
|
611
641
|
activation_config=config.activation_config,
|
@@ -613,7 +643,12 @@ class SkipUpDecoderBlock2D(nn.Module):
|
|
613
643
|
)
|
614
644
|
)
|
615
645
|
if config.transformer_block_config:
|
616
|
-
transformers.append(
|
646
|
+
transformers.append(
|
647
|
+
TransformerBlock2D(
|
648
|
+
config.transformer_block_config,
|
649
|
+
dim_override=sub_block_channels[i],
|
650
|
+
)
|
651
|
+
)
|
617
652
|
self.resnets = nn.ModuleList(resnets)
|
618
653
|
self.transformers = (
|
619
654
|
nn.ModuleList(transformers) if len(transformers) > 0 else None
|
@@ -623,7 +658,7 @@ class SkipUpDecoderBlock2D(nn.Module):
|
|
623
658
|
if config.upsample_conv:
|
624
659
|
self.upsample_conv = nn.Conv2d(
|
625
660
|
config.out_channels,
|
626
|
-
|
661
|
+
sub_block_channels[0],
|
627
662
|
kernel_size=3,
|
628
663
|
stride=1,
|
629
664
|
padding=1,
|
@@ -711,6 +746,7 @@ class MidBlock2D(nn.Module):
|
|
711
746
|
in_channels=config.in_channels,
|
712
747
|
hidden_channels=config.in_channels,
|
713
748
|
out_channels=config.in_channels,
|
749
|
+
residual_out_channels=config.in_channels,
|
714
750
|
time_embedding_channels=config.time_embedding_channels,
|
715
751
|
normalization_config=config.normalization_config,
|
716
752
|
activation_config=config.activation_config,
|
@@ -50,10 +50,12 @@ class ResidualBlock2DConfig:
|
|
50
50
|
in_channels: int
|
51
51
|
hidden_channels: int
|
52
52
|
out_channels: int
|
53
|
+
hidden_channels: int
|
53
54
|
normalization_config: layers_cfg.NormalizationConfig
|
54
55
|
activation_config: layers_cfg.ActivationConfig
|
55
56
|
# Optional time embedding channels if the residual block takes a time embedding context as input
|
56
57
|
time_embedding_channels: Optional[int] = None
|
58
|
+
residual_out_channels: Optional[int] = None
|
57
59
|
|
58
60
|
|
59
61
|
@dataclasses.dataclass
|
@@ -63,6 +65,7 @@ class AttentionBlock2DConfig:
|
|
63
65
|
attention_config: layers_cfg.AttentionConfig
|
64
66
|
enable_hlfb: bool = True
|
65
67
|
attention_batch_size: int = 1
|
68
|
+
hidden_dim: Optional[int] = None
|
66
69
|
|
67
70
|
|
68
71
|
@dataclasses.dataclass
|
@@ -101,6 +104,8 @@ class UpDecoderBlock2DConfig:
|
|
101
104
|
normalization_config: layers_cfg.NormalizationConfig
|
102
105
|
activation_config: layers_cfg.ActivationConfig
|
103
106
|
num_layers: int
|
107
|
+
# The dimension of output channels of previous connected block
|
108
|
+
prev_out_channels: Optional[int] = None
|
104
109
|
# Optional time embedding channels if the residual blocks take a time embedding as input
|
105
110
|
time_embedding_channels: Optional[int] = None
|
106
111
|
# Whether to add upsample operation after residual blocks
|
@@ -136,6 +141,8 @@ class SkipUpDecoderBlock2DConfig:
|
|
136
141
|
transformer_block_config: Optional[TransformerBlock2DConfig] = None
|
137
142
|
# Optional dimension of context tensor if context tensor is given as input.
|
138
143
|
context_dim: Optional[int] = None
|
144
|
+
sub_block_channels: Optional[tuple] = None
|
145
|
+
hidden_channels: Optional[int] = None
|
139
146
|
|
140
147
|
|
141
148
|
@dataclasses.dataclass
|
@@ -157,6 +164,7 @@ class DownEncoderBlock2DConfig:
|
|
157
164
|
transformer_block_config: Optional[TransformerBlock2DConfig] = None
|
158
165
|
# Optional dimension of context tensor if context tensor is given as input.
|
159
166
|
context_dim: Optional[int] = None
|
167
|
+
hidden_channels: Optional[int] = None
|
160
168
|
|
161
169
|
|
162
170
|
@dataclasses.dataclass
|
@@ -157,6 +157,10 @@ class ModelLoader:
|
|
157
157
|
converted_state["tok_embedding.weight"] = state.pop(
|
158
158
|
f"{self._names.embedding}.weight"
|
159
159
|
)
|
160
|
+
if model.config.embedding_use_bias:
|
161
|
+
converted_state["tok_embedding.bias"] = state.pop(
|
162
|
+
f"{self._names.embedding}.bias"
|
163
|
+
)
|
160
164
|
if self._names.embedding_position is not None:
|
161
165
|
converted_state["tok_embedding_position"] = state.pop(
|
162
166
|
f"{self._names.embedding_position}"
|
@@ -228,7 +228,7 @@ def verify_reauthored_model(
|
|
228
228
|
rtol: float = 1e-05,
|
229
229
|
atol: float = 1e-05,
|
230
230
|
continue_on_failure: bool = False,
|
231
|
-
):
|
231
|
+
) -> bool:
|
232
232
|
"""Verifies the reauthored model against the original model.
|
233
233
|
|
234
234
|
It verifies the reauthored model with two methods:
|
@@ -237,7 +237,8 @@ def verify_reauthored_model(
|
|
237
237
|
2. It compares the answer generated by the original and the reauthored model
|
238
238
|
with a prompt.
|
239
239
|
|
240
|
-
It prints out "PASS" or "FAILED" to the console.
|
240
|
+
It prints out "PASS" or "FAILED" to the console. It returns True if all
|
241
|
+
verification passes, False otherwise.
|
241
242
|
|
242
243
|
Args:
|
243
244
|
original_model (ModelWrapper): The original model.
|
@@ -253,6 +254,8 @@ def verify_reauthored_model(
|
|
253
254
|
continue_on_failure (bool): If True, it continues to verify the next prompt
|
254
255
|
or input IDs even if a previous one fails.
|
255
256
|
"""
|
257
|
+
failure_count = 0
|
258
|
+
|
256
259
|
for input_ids in forward_input_ids:
|
257
260
|
logging.info("Verifying the reauthored model with input IDs: %s", input_ids)
|
258
261
|
try:
|
@@ -261,8 +264,9 @@ def verify_reauthored_model(
|
|
261
264
|
)
|
262
265
|
except AssertionError as e:
|
263
266
|
logging.error("*** FAILED *** verify with input IDs: %s", input_ids)
|
267
|
+
failure_count += 1
|
264
268
|
if not continue_on_failure:
|
265
|
-
|
269
|
+
return False
|
266
270
|
else:
|
267
271
|
logging.info("*** PASSED *** verify with input IDs: %s", input_ids)
|
268
272
|
|
@@ -274,7 +278,15 @@ def verify_reauthored_model(
|
|
274
278
|
)
|
275
279
|
except AssertionError as e:
|
276
280
|
logging.error("*** FAILED *** verify with prompts: %s", prompts)
|
281
|
+
failure_count += 1
|
277
282
|
if not continue_on_failure:
|
278
|
-
|
283
|
+
return False
|
279
284
|
else:
|
280
285
|
logging.info("*** PASSED *** verify with prompts: %s", prompts)
|
286
|
+
|
287
|
+
if failure_count == 0:
|
288
|
+
logging.info("*** PASSED *** verify_reauthored_model")
|
289
|
+
return True
|
290
|
+
else:
|
291
|
+
logging.error("*** FAILED *** verify_reauthored_model")
|
292
|
+
return False
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20241115
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
|
5
5
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
6
|
-
ai_edge_torch/version.py,sha256=
|
6
|
+
ai_edge_torch/version.py,sha256=pp4KVtq0a8ju4UB5nOeiv7QDkmgpHmz5XUokSR86qfI,706
|
7
7
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
8
|
ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
|
9
9
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -12,7 +12,7 @@ ai_edge_torch/_convert/signature.py,sha256=rGpBNss3Y9FCCCcdBwDo16KqavJi8N5P0M_6W
|
|
12
12
|
ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
|
13
13
|
ai_edge_torch/_convert/fx_passes/__init__.py,sha256=NVe-eGcm7j8jZpP2pcMhC8j5dVjgR1pPzyXhHdvKH4E,1267
|
14
14
|
ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=doaww8KqrgRTD5LotBVAIRFsEqzPn9R5lcGehBJOczA,9098
|
15
|
-
ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=
|
15
|
+
ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=m_yj66V11LmWCYgA7yLtr__cy14IbC5WEJe0BE0_IPE,4339
|
16
16
|
ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=IlZuK42kfVcRqAWZp4j2k_81T2uWo9T2558U_GPJAlU,2327
|
17
17
|
ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=f1IUVWyhioOClsMiZzLyynoW2R17U83vA-7Q-3pGPM4,2126
|
18
18
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=lxnoH-WGLeiQIF8XjMGodjiZEFTxucl7g05N7MR9OPk,796
|
@@ -50,8 +50,8 @@ ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=RZDs6
|
|
50
50
|
ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=oSbysiPvwp5efMbNYZop3HrxDMGiD15Tmz-HiQuTr2E,3315
|
51
51
|
ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=RQFQDMEnIVp8PefcCTr7P0CvllKI7FVoIJLXbPLLIsc,9056
|
52
52
|
ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
|
53
|
-
ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=
|
54
|
-
ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=
|
53
|
+
ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
|
54
|
+
ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
|
55
55
|
ai_edge_torch/generative/examples/llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
56
56
|
ai_edge_torch/generative/examples/llama/convert_to_tflite.py,sha256=P0-pByTM5tslE23ILgo7nd0nOGE25ciBRG5wKJj0bBk,2411
|
57
57
|
ai_edge_torch/generative/examples/llama/llama.py,sha256=AMcCbuDBxEfbO-l3KiEXbUaXEJ3RLLwkHii7to7UhVo,6854
|
@@ -62,7 +62,9 @@ ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sFakstoPDcOHSak0IGFE
|
|
62
62
|
ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
|
63
63
|
ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
64
64
|
ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=JSb9h3gcIh5oYrbLU6rI8OU8FzfWeTCFJT5XRWu4btE,3675
|
65
|
+
ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=v19_EKALhAP9FjkINKqpv8JsVaQ6iH_7X5FpnhE6abw,5500
|
65
66
|
ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
|
67
|
+
ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=pSekf1BybhieQz3cQx_llbRQHxczXbTqool8fOyGj_0,3114
|
66
68
|
ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
67
69
|
ai_edge_torch/generative/examples/phi/convert_phi3_to_tflite.py,sha256=rkbTtMaqSVG48cm-NTxR_LDgZmXAEBqayTm9O49oMXc,2171
|
68
70
|
ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=3go690yX6PFeXMdpY7y4JZorAwxX0HT_b_pKZieauvk,2169
|
@@ -108,19 +110,19 @@ ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299f
|
|
108
110
|
ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
|
109
111
|
ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
|
110
112
|
ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
111
|
-
ai_edge_torch/generative/layers/attention.py,sha256=
|
113
|
+
ai_edge_torch/generative/layers/attention.py,sha256=zN3BQjA25Ej_aRU0rFnyx--K74xf5ykc02zGvUpYHeE,13295
|
112
114
|
ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
|
113
115
|
ai_edge_torch/generative/layers/builder.py,sha256=Z5LyzCEThgnYZeyViakaE3yJVzTGHtw13acHsAQR15U,5050
|
114
116
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
|
115
117
|
ai_edge_torch/generative/layers/kv_cache.py,sha256=lbm-yJ1jGPtcgWS4C3FmSnB1IlxqDE7g0BLRh3PN4N4,6324
|
116
|
-
ai_edge_torch/generative/layers/model_config.py,sha256=
|
118
|
+
ai_edge_torch/generative/layers/model_config.py,sha256=xqa7ZBEjgK4UWJAThRXb_VBFZ5KCGtDu-QaY5GXar9s,7366
|
117
119
|
ai_edge_torch/generative/layers/normalization.py,sha256=eKAGst9rPuyRFExMcQFJO7R3iHdCtlmjeF_lITjLhwE,6498
|
118
120
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
|
119
121
|
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
|
120
122
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
121
|
-
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=
|
123
|
+
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=9jKzOfiBQ66bp1ZnVIAoREIifVNFx4aTlQeYMAx2_pA,29062
|
122
124
|
ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
|
123
|
-
ai_edge_torch/generative/layers/unet/model_config.py,sha256=
|
125
|
+
ai_edge_torch/generative/layers/unet/model_config.py,sha256=pPDwLawc23pfMaPVyMJlYmxVVusjMvx-l8wBwOYOH-c,9692
|
124
126
|
ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
125
127
|
ai_edge_torch/generative/quantize/example.py,sha256=1lfVNUd2cEyRUnoZ7BLbRJ9IN-FTKiWBtZNPFUzAiWE,1747
|
126
128
|
ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
|
@@ -137,12 +139,12 @@ ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0
|
|
137
139
|
ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
|
138
140
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
139
141
|
ai_edge_torch/generative/utilities/converter.py,sha256=17O83wVifH1vQJCI4WC3DaNiCIOyK2gys1GzohbLrRs,5554
|
140
|
-
ai_edge_torch/generative/utilities/loader.py,sha256=
|
142
|
+
ai_edge_torch/generative/utilities/loader.py,sha256=k5fjCokNomte4ymy9IJrEWAuCSMhsPCJfmv1y5s0ZEc,13452
|
141
143
|
ai_edge_torch/generative/utilities/model_builder.py,sha256=89jt80UUfDzYBi-x077HBavWeuNJuYPXym9fiKCY1Tk,5278
|
142
144
|
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
|
143
145
|
ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
|
144
146
|
ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
|
145
|
-
ai_edge_torch/generative/utilities/verifier.py,sha256=
|
147
|
+
ai_edge_torch/generative/utilities/verifier.py,sha256=h5hGyIpYGyPZwvelbzpdkjy99Kpd4JkvhqWtQN9cm-M,10413
|
146
148
|
ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
|
147
149
|
ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
|
148
150
|
ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
|
@@ -189,8 +191,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
189
191
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
190
192
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
191
193
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
192
|
-
ai_edge_torch_nightly-0.3.0.
|
193
|
-
ai_edge_torch_nightly-0.3.0.
|
194
|
-
ai_edge_torch_nightly-0.3.0.
|
195
|
-
ai_edge_torch_nightly-0.3.0.
|
196
|
-
ai_edge_torch_nightly-0.3.0.
|
194
|
+
ai_edge_torch_nightly-0.3.0.dev20241115.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
195
|
+
ai_edge_torch_nightly-0.3.0.dev20241115.dist-info/METADATA,sha256=epuuYZFnqVvLzIS0X27XMCFQpnc-dO8JJQ8DXVNv5IE,1897
|
196
|
+
ai_edge_torch_nightly-0.3.0.dev20241115.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
|
197
|
+
ai_edge_torch_nightly-0.3.0.dev20241115.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
198
|
+
ai_edge_torch_nightly-0.3.0.dev20241115.dist-info/RECORD,,
|
File without changes
|
File without changes
|