ai-edge-torch-nightly 0.3.0.dev20241117__py3-none-any.whl → 0.3.0.dev20241120__py3-none-any.whl
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +2 -2
- ai_edge_torch/_convert/test/test_convert_composites.py +1 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_multi_prefills.py +6 -20
- ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py +80 -0
- ai_edge_torch/generative/examples/paligemma/image_encoder.py +3 -1
- ai_edge_torch/generative/examples/paligemma/paligemma.py +25 -9
- ai_edge_torch/generative/layers/model_config.py +23 -20
- ai_edge_torch/generative/test/test_model_conversion.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion_large.py +50 -1
- ai_edge_torch/generative/test/utils.py +6 -3
- ai_edge_torch/generative/utilities/converter.py +71 -87
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241117.dist-info → ai_edge_torch_nightly-0.3.0.dev20241120.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241117.dist-info → ai_edge_torch_nightly-0.3.0.dev20241120.dist-info}/RECORD +17 -16
- {ai_edge_torch_nightly-0.3.0.dev20241117.dist-info → ai_edge_torch_nightly-0.3.0.dev20241120.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241117.dist-info → ai_edge_torch_nightly-0.3.0.dev20241120.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241117.dist-info → ai_edge_torch_nightly-0.3.0.dev20241120.dist-info}/top_level.txt +0 -0
@@ -49,7 +49,7 @@ def _get_upsample_bilinear2d_pattern():
   output = internal_match.returning_nodes[0]
   output_h, output_w = output.meta["val"].shape[-2:]
   return {
-      "
+      "size": (int(output_h), int(output_w)),
       "align_corners": False,
       "is_nchw_op": True,
   }
@@ -73,7 +73,7 @@ def _get_upsample_bilinear2d_align_corners_pattern():
   output = internal_match.returning_nodes[0]
   output_h, output_w = output.meta["val"].shape[-2:]
   return {
-      "
+      "size": (int(output_h), int(output_w)),
       "align_corners": True,
       "is_nchw_op": True,
   }
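Both hunks cast the matched output height and width to plain Python ints before recording them as composite attributes. A minimal, hedged sketch of the resulting attribute dictionary (a standalone helper written for illustration, not the pass itself):

```python
import torch

def upsample_composite_attrs(val: torch.Tensor, align_corners: bool) -> dict:
  # Hypothetical helper mirroring the changed lines: shape entries can be
  # symbolic under export, so int() pins them to plain Python ints before
  # they are stored as composite attributes.
  output_h, output_w = val.shape[-2:]
  return {
      "size": (int(output_h), int(output_w)),
      "align_corners": align_corners,
      "is_nchw_op": True,
  }

print(upsample_composite_attrs(torch.zeros(1, 3, 16, 16), align_corners=False))
# {'size': (16, 16), 'align_corners': False, 'is_nchw_op': True}
```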
@@ -39,6 +39,7 @@ def _func_to_torch_module(func: Callable[..., torch.Tensor]):
   return TestModule(func).eval()


+@googletest.skip('Temporary outage due to changes for b/377531086')
 class TestConvertComposites(googletest.TestCase):
   """Tests conversion modules that are meant to be wrapped as composites."""

@@ -13,9 +13,8 @@
 # limitations under the License.
 # ==============================================================================

-"""Example
+"""Example to convert a Gemma2 model to multiple prefill length tflite model."""

-import logging
 import os
 import pathlib

@@ -35,9 +34,9 @@ _TFLITE_PATH = flags.DEFINE_string(
     'The tflite file path to export.',
 )
 _PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
-    '
+    'prefill_seq_lens',
     (8, 64, 128, 256, 512, 1024),
-    '
+    'List of the maximum sizes of prefill input tensors.',
 )
 _KV_CACHE_MAX_LEN = flags.DEFINE_integer(
     'kv_cache_max_len',
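For context on the flag type: `flags.DEFINE_multi_integer` collects repeated occurrences of the flag into a list, which is what lets a single invocation request several prefill lengths. A small, self-contained sketch (default list shortened for illustration):

```python
from absl import app
from absl import flags

_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
    'prefill_seq_lens',
    (8, 64),  # shortened default, for illustration only
    'List of the maximum sizes of prefill input tensors.',
)

def main(_):
  # `--prefill_seq_lens=256 --prefill_seq_lens=1024` would print [256, 1024].
  print(_PREFILL_SEQ_LENS.value)

if __name__ == '__main__':
  app.run(main)
```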
@@ -51,32 +50,19 @@ _QUANTIZE = flags.DEFINE_bool(
 )


-
-# now. The main purpose for this function is to allow you export a tflite model
-# with multiple prefill signatures for different prefill lengths for faster
-# inference.
-def convert_to_tflite_multi_prefill_lens():
+def main(_):
   pytorch_model = gemma2.build_2b_model(
       _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
   )
   quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
   output_filename = f'gemma2_{quant_suffix}_multi-prefill-seq_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
-  converter.
+  converter.convert_to_tflite(
       pytorch_model,
       tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
-
+      prefill_seq_len=_PREFILL_SEQ_LENS.value,
       quantize=_QUANTIZE.value,
   )


-def main(_):
-  if len(_PREFILL_SEQ_LENS.value) > 1:
-    # If multiple prefill lengths are provided, export a model with multiple
-    # prefill signatures each for a different prefill length.
-    convert_to_tflite_multi_prefill_lens()
-  else:
-    logging.warning('Need more than one prefill lengths to be specified.')
-
-
 if __name__ == '__main__':
   app.run(main)
@@ -0,0 +1,80 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting a PaliGemma model to multi-signature tflite model.
+
+DISCLAIMER: It works only with ODML Torch conversion backend. Refer to
+https://github.com/google-ai-edge/ai-edge-torch/blob/main/docs/pytorch_converter/README.md#use-odml-torch-conversion-backend-experimental.
+"""
+
+import os
+import pathlib
+
+from absl import app
+from absl import flags
+from ai_edge_torch.generative.examples.paligemma import paligemma
+from ai_edge_torch.generative.utilities import converter
+import torch
+
+_CHECKPOINT_PATH = flags.DEFINE_string(
+    'checkpoint_path',
+    os.path.join(pathlib.Path.home(), 'Downloads/llm_data/paligemma-3b-224'),
+    'The path to the model checkpoint, or directory holding the checkpoint.',
+)
+_TFLITE_PATH = flags.DEFINE_string(
+    'tflite_path',
+    '/tmp/',
+    'The tflite file path to export.',
+)
+_PREFILL_SEQ_LEN = flags.DEFINE_integer(
+    'prefill_seq_len',
+    1024,
+    'The maximum size of prefill input tensor.',
+)
+_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
+    'kv_cache_max_len',
+    1280,
+    'The maximum size of KV cache buffer, including both prefill and decode.',
+)
+_PIXEL_VALUES_SIZE = flags.DEFINE_multi_integer(
+    'pixel_values_size',
+    [3, 224, 224],
+    'The size of prefill pixel values except the batch dimension.',
+)
+_QUANTIZE = flags.DEFINE_bool(
+    'quantize',
+    True,
+    'Whether the model should be quantized.',
+)
+
+
+def main(_):
+  pytorch_model = paligemma.build_model(
+      _CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
+  )
+  quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
+  output_filename = f'paligemma_{quant_suffix}_seq{_PREFILL_SEQ_LEN.value}_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
+  converter.convert_to_tflite(
+      pytorch_model,
+      tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
+      prefill_seq_len=_PREFILL_SEQ_LEN.value,
+      pixel_values_size=torch.Size(_PIXEL_VALUES_SIZE.value),
+      quantize=_QUANTIZE.value,
+      config=pytorch_model.config.decoder_config,
+  )
+
+
+if __name__ == '__main__':
+  app.run(main)
@@ -59,7 +59,7 @@ class SiglipVisionEncoder(nn.Module):
         out_channels=config.embedding_dim,
         kernel_size=config.image_embedding.patch_size,
         stride=config.image_embedding.patch_size,
-        padding=
+        padding=0,
     )
     num_patches = (
         config.image_embedding.image_size // config.image_embedding.patch_size
@@ -144,6 +144,8 @@ def get_fake_image_encoder_config() -> cfg.ModelConfig:
   config = get_image_encoder_config()
   # PaliGemma image encoder has only one block config.
   config.block_config(0).ff_config.intermediate_size = 128
+  config.image_embedding.image_size = 8
+  config.image_embedding.patch_size = 2
   config.num_layers = 2
   return config

@@ -54,6 +54,10 @@ class PaliGemma(nn.Module):
         bias=config.image_projection_use_bias,
     )
     self.decoder = decoder.Decoder(config.decoder_config)
+    image_embedding_config = config.image_encoder_config.image_embedding
+    self.num_patches = (
+        image_embedding_config.image_size // image_embedding_config.patch_size
+    ) ** 2
     self.config = config

   @torch.inference_mode
@@ -74,10 +78,22 @@ class PaliGemma(nn.Module):
     if self.config.decoder_config.embedding_scale is not None:
       image_embeds = image_embeds / self.config.decoder_config.embedding_scale

-    #
-
-
-
+    # Merging image_embeds into text_embeds as PaliGemmaForConditionalGeneration
+    # can be done like:
+    #
+    # image_mask = tokens == self.config.image_token_id
+    # image_mask = image_mask.unsqueeze(-1).expand_as(input_embeds)
+    # input_embeds = input_embeds.masked_scatter(image_mask, image_embeds)
+    #
+    # Unfortunately, torch.Tensor.masked_scatter can't be lowered on CPU.
+    # Since PaliGemma token embedder reserves the first [num_patches] tokens
+    # for image tokens, we can use this property to merge image_embeds into
+    # input_embeds by concatenating them.
+    assert image_embeds.shape[1] == self.num_patches
+    assert input_embeds.shape[1] >= self.num_patches
+    input_embeds = torch.cat(
+        (image_embeds, input_embeds[:, self.num_patches:, :]), dim=1
+    )

     return self.decoder(
         tokens=None,
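A hedged, self-contained illustration of the merge described in the comment above, using made-up shapes: because the first `num_patches` positions of the text embeddings are reserved for image tokens, replacing that prefix by concatenation matches the `masked_scatter`-based merge without the op that fails to lower on CPU.

```python
import torch

batch, num_patches, seq_len, dim = 1, 4, 12, 8  # illustrative sizes only
image_embeds = torch.randn(batch, num_patches, dim)
input_embeds = torch.randn(batch, seq_len, dim)

# Drop the reserved image-token prefix of the text embeddings and put the
# image embeddings in its place, as the new forward() does.
merged = torch.cat((image_embeds, input_embeds[:, num_patches:, :]), dim=1)
assert merged.shape == (batch, seq_len, dim)
```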
@@ -87,7 +103,7 @@ class PaliGemma(nn.Module):
     )


-def get_model_config() -> PaliGemmaConfig:
+def get_model_config(**kwargs) -> PaliGemmaConfig:
   """Returns the model config for a PaliGemma 3B-224 model.

   Returns:
@@ -95,13 +111,13 @@ def get_model_config() -> PaliGemmaConfig:
   """
   return PaliGemmaConfig(
       image_encoder_config=image_encoder.get_image_encoder_config(),
-      decoder_config=decoder.get_decoder_config(),
+      decoder_config=decoder.get_decoder_config(**kwargs),
       image_projection_use_bias=True,
       image_token_id=257152,
   )


-def
+def get_fake_model_config() -> PaliGemmaConfig:
   return PaliGemmaConfig(
       image_encoder_config=image_encoder.get_fake_image_encoder_config(),
       decoder_config=decoder.get_fake_decoder_config(),
@@ -110,8 +126,8 @@ def get_fake_image_encoder_config() -> PaliGemmaConfig:
   )


-def build_model(checkpoint_path: str) -> PaliGemma:
-  config = get_model_config()
+def build_model(checkpoint_path: str, **kwargs) -> PaliGemma:
+  config = get_model_config(**kwargs)
   model = PaliGemma(config)
   # Load the parameters of image encoder.
   loader = loading_utils.ModelLoader(
@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
-
-
+
+"""Model configuration class."""
+
+import dataclasses
 import enum
 from typing import Optional, Sequence, Union

@@ -35,7 +36,7 @@ class ActivationType(enum.Enum):

 @enum.unique
 class NormalizationType(enum.Enum):
-  """Different normalization functions"""
+  """Different normalization functions."""

   # No normalization is applied.
   NONE = enum.auto()
@@ -59,7 +60,7 @@ class AttentionType(enum.Enum):
   LOCAL_SLIDING = enum.auto()


-@dataclass
+@dataclasses.dataclass
 class NormalizationConfig:
   """Normalizater parameters."""

@@ -71,7 +72,7 @@ class NormalizationConfig:
   group_num: Optional[float] = None


-@dataclass
+@dataclasses.dataclass
 class AttentionConfig:
   """Attention model's parameters."""

@@ -90,18 +91,20 @@ class AttentionConfig:
   # Whether to use bias with Query, Key, and Value projection.
   qkv_use_bias: bool = False
   # Whether the fused q, k, v projection weights interleaves q, k, v heads.
-  # If True, the projection weights are in format
-  #
+  # If True, the projection weights are in format:
+  # `[q_head_0, k_head_0, v_head_0, q_head_1, k_head_1, v_head_1, ...]`
+  # If False, the projection weights are in format:
+  # `[q_head_0, q_head_1, ..., k_head_0, k_head_1, ... v_head_0, v_head_1, ...]`
   qkv_fused_interleaved: bool = True
   # Whether to use bias with attention output projection.
   output_proj_use_bias: bool = False
   enable_kv_cache: bool = True
   # The normalization applied to query projection's output.
-  query_norm_config: NormalizationConfig = field(
+  query_norm_config: NormalizationConfig = dataclasses.field(
       default_factory=NormalizationConfig
   )
   # The normalization applied to key projection's output.
-  key_norm_config: NormalizationConfig = field(
+  key_norm_config: NormalizationConfig = dataclasses.field(
       default_factory=NormalizationConfig
   )
   relative_attention_num_buckets: int = 0
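The expanded comment spells out the two fused-QKV weight layouts, and the same hunk moves the nested defaults to `dataclasses.field(default_factory=...)` to match the new module-level `import dataclasses`. A small hedged sketch of the layout difference, with made-up head weights standing in for the real projection rows:

```python
import torch

# Stand-ins for two heads' worth of q/k/v projection rows (made-up sizes).
rows, head_dim = 4, 2
q0, k0, v0 = (torch.full((rows, head_dim), float(i)) for i in range(3))
q1, k1, v1 = (torch.full((rows, head_dim), float(i)) for i in range(3, 6))

# qkv_fused_interleaved=True: q, k, v blocks alternate per head.
interleaved = torch.cat([q0, k0, v0, q1, k1, v1], dim=0)
# qkv_fused_interleaved=False: all q heads, then all k heads, then all v heads.
grouped = torch.cat([q0, q1, k0, k1, v0, v1], dim=0)

assert interleaved.shape == grouped.shape == (6 * rows, head_dim)
```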
@@ -114,7 +117,7 @@ class AttentionConfig:
   sliding_window_size: Optional[int] = None


-@dataclass
+@dataclasses.dataclass
 class ActivationConfig:
   type: ActivationType = ActivationType.LINEAR
   # Dimension of input and output, used in GeGLU.
@@ -122,7 +125,7 @@ class ActivationConfig:
   dim_out: Optional[int] = None


-@dataclass
+@dataclasses.dataclass
 class FeedForwardConfig:
   """FeedForward module's parameters."""

@@ -131,27 +134,27 @@ class FeedForwardConfig:
   intermediate_size: int
   use_bias: bool = False
   # The normalization applied to feed forward's input.
-  pre_ff_norm_config: NormalizationConfig = field(
+  pre_ff_norm_config: NormalizationConfig = dataclasses.field(
       default_factory=NormalizationConfig
   )
   # The normalization applied to feed forward's output.
-  post_ff_norm_config: NormalizationConfig = field(
+  post_ff_norm_config: NormalizationConfig = dataclasses.field(
       default_factory=NormalizationConfig
   )


-@dataclass
+@dataclasses.dataclass
 class TransformerBlockConfig:
   """TransformerBlock module's parameters."""

   attn_config: AttentionConfig
   ff_config: FeedForwardConfig
   # The normalization applied to attention's input.
-  pre_attention_norm_config: NormalizationConfig = field(
+  pre_attention_norm_config: NormalizationConfig = dataclasses.field(
       default_factory=NormalizationConfig
   )
   # The normalization applied to attentions's output.
-  post_attention_norm_config: NormalizationConfig = field(
+  post_attention_norm_config: NormalizationConfig = dataclasses.field(
       default_factory=NormalizationConfig
   )
   # If set to True, only attn_config.pre_attention_norm is applied to the input
@@ -163,7 +166,7 @@ class TransformerBlockConfig:
   relative_attention: bool = False


-@dataclass
+@dataclasses.dataclass
 class ImageEmbeddingConfig:
   """Image embedding parameters."""

@@ -173,7 +176,7 @@ class ImageEmbeddingConfig:
   patch_size: int


-@dataclass
+@dataclasses.dataclass
 class ModelConfig:
   """Base configurations for building a transformer architecture."""

@@ -187,7 +190,7 @@ class ModelConfig:
   block_configs: Union[TransformerBlockConfig, Sequence[TransformerBlockConfig]]

   # The normalization applied before LM head.
-  final_norm_config: NormalizationConfig = field(
+  final_norm_config: NormalizationConfig = dataclasses.field(
       default_factory=NormalizationConfig
   )

@@ -117,7 +117,7 @@ class TestModelConversion(googletest.TestCase):
   def _test_multisig_model(self, config, pytorch_model, atol, rtol):
     # prefill
     seq_len = 10
-    prefill_tokens = torch.
+    prefill_tokens = torch.zeros((1, seq_len), dtype=torch.int, device="cpu")
     prompt_token = torch.from_numpy(np.array([1, 2, 3, 4]))
     prefill_tokens[0, : len(prompt_token)] = prompt_token
     prefill_input_pos = torch.arange(0, seq_len, dtype=torch.int)
@@ -22,6 +22,7 @@ from ai_edge_torch.generative.examples.gemma import gemma1
 from ai_edge_torch.generative.examples.gemma import gemma2
 from ai_edge_torch.generative.examples.llama import llama
 from ai_edge_torch.generative.examples.openelm import openelm
+from ai_edge_torch.generative.examples.paligemma import paligemma
 from ai_edge_torch.generative.examples.phi import phi2
 from ai_edge_torch.generative.examples.phi import phi3
 from ai_edge_torch.generative.examples.qwen import qwen
@@ -55,7 +56,7 @@ class TestModelConversion(googletest.TestCase):

   def _test_model(self, config, model, signature_name, atol, rtol):
     idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
-    tokens = torch.
+    tokens = torch.zeros((1, 10), dtype=torch.int, device="cpu")
     tokens[0, :4] = idx
     input_pos = torch.arange(0, 10, dtype=torch.int)
     kv = kv_cache.KVCache.from_model_config(config)
@@ -171,6 +172,54 @@ class TestModelConversion(googletest.TestCase):
     pytorch_model = model_builder.DecoderOnlyModel(config).eval()
     self._test_model(config, pytorch_model, "prefill", atol=1e-3, rtol=1e-5)

+  @googletest.skipIf(
+      ai_edge_config.Config.use_torch_xla,
+      reason="tests with custom ops are not supported on oss",
+  )
+  def test_paligemma(self):
+    config = paligemma.get_fake_model_config()
+    pytorch_model = paligemma.PaliGemma(config).eval()
+    idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+    image_embedding_config = config.image_encoder_config.image_embedding
+    num_patches = (
+        image_embedding_config.image_size // image_embedding_config.patch_size
+    ) ** 2
+    # Make sure the token size is longer than the number of image patches.
+    tokens_len = num_patches + 10
+    tokens = torch.zeros((1, tokens_len), dtype=torch.int, device="cpu")
+    tokens[0, :4] = idx
+    input_pos = torch.arange(0, tokens_len, dtype=torch.int)
+    kv = kv_cache.KVCache.from_model_config(config.decoder_config)
+    pixel_values = torch.zeros((1, 3, 8, 8), dtype=torch.float32, device="cpu")
+
+    edge_model = ai_edge_torch.signature(
+        "prefill_pixel",
+        pytorch_model,
+        sample_kwargs={
+            "tokens": tokens,
+            "input_pos": input_pos,
+            "kv_cache": kv,
+            "pixel_values": pixel_values,
+        },
+    ).convert()
+    edge_model.set_interpreter_builder(
+        self._interpreter_builder(edge_model.tflite_model())
+    )
+
+    self.assertTrue(
+        test_utils.compare_tflite_torch(
+            edge_model,
+            pytorch_model,
+            tokens,
+            input_pos,
+            kv,
+            pixel_values=pixel_values,
+            signature_name="prefill_pixel",
+            atol=1e-3,
+            rtol=1e-5,
+        )
+    )
+
   @googletest.skipIf(
       ai_edge_config.Config.use_torch_xla,
       reason="tests with custom ops are not supported on oss",
@@ -32,18 +32,21 @@ def compare_tflite_torch(
     signature_name: str,
     atol: float = 1e-5,
     rtol: float = 1e-5,
+    **kwargs,
 ):
   """Compares torch models and TFLite models."""
   values, spec = pytree.tree_flatten({"kv_cache": kv_cache})
   flat_names = common_utils.flat_dict_names(spec.children_specs, spec.context)
-  torch_output = torch_model(tokens, input_pos, kv_cache)
+  torch_output = torch_model(tokens, input_pos, kv_cache, **kwargs)

-
+  if "pixel_values" in kwargs:
+    kwargs["pixel_values"] = kwargs["pixel_values"].numpy()
+  kwargs.update({k: v.numpy() for k, v in zip(flat_names, values)})
   edge_output = edge_model(
       signature_name=signature_name,
       tokens=tokens.numpy(),
       input_pos=input_pos.numpy(),
-      **
+      **kwargs,
   )

   return np.allclose(
@@ -15,9 +15,11 @@

 """Common utility functions for model conversion."""

-import
+from typing import Union
+
 from ai_edge_torch._convert import converter as converter_utils
-
+import ai_edge_torch.generative.layers.kv_cache as kv_utils
+import ai_edge_torch.generative.layers.model_config as cfg
 from ai_edge_torch.generative.quantize import quant_recipes
 import torch

@@ -25,109 +27,74 @@ import torch
 def convert_to_tflite(
     pytorch_model: torch.nn.Module,
     tflite_path: str,
-    prefill_seq_len: int
+    prefill_seq_len: Union[int, list[int]],
+    pixel_values_size: torch.Size = None,
     quantize: bool = True,
+    config: cfg.ModelConfig = None,
 ):
   """Converts a nn.Module model to multi-signature tflite model.

-  A PyTorch model will be converted to a tflite model with
-
-
-
-
-
-
-  "
-
-
+  A PyTorch model will be converted to a tflite model with several signatures:
+    * "prefill_[prefill_seq_len]" (or "prefill" if only one prefill_seq_len is
+      passed),
+    * "prefill_[preill_seq_len]_pixel" (or "prefill_pixel" if only one
+      prefill_seq_len is passed) if num_pixel_values > 0, and
+    * "decode".
+
+  "prefill_[prefill_seq_len]" (or "prefill" if only one prefill_seq_len is
+  passed) signature takes as a sample input:
+    * a tensor of shape [1, prefill_seq_len] of token sequence,
+    * a tensor of shape [1, prefill_seq_len] of token positions, and
+    * an external KV cache.
+
+  If num_pixel_values > 0, "prefill_[prefill_seq_len]_pixel" (or "prefill_pixel"
+  if only one prefill_seq_len is passed) signature takes as a sample input:
+    * a tensor of shape [1, prefill_seq_len] of token sequence,
+    * a tensor of shape [1, prefill_seq_len] of token positions,
+    * an external KV cache, and
+    * a tensor of shape [1, num_pixel_values] of pixel values.
+
+  "decode" signature takes as a sample input:
+    * a tensor of shape [1, 1] of token sequence,
+    * a tensor of shape [1, 1] of the token position, and
+    * an external KV cache.

   The final tflite model will be exported to tflite_path.

   Args:
     pytorch_model (torch.nn.Module): PyTorch model to convert to tflite.
     tflite_path (str): The tflite file path to export.
-    prefill_seq_len (int,
-
+    prefill_seq_len (Union[int, list[int]]): A list of prefill lengths to
+      export.
+    pixel_values_size (torch.Size, optional): The size of pixel values to pass
+      to the model. If None, the model is not expected to take pixel values.
     quantize (bool, optional): Whether the model should be quanized. Defaults
       to True.
+    config (cfg.ModelConfig, optional): The model config used to configure KV
+      cache. If None, it uses the config of the pytorch_model.
   """
-
-
-  prefill_input_pos = torch.arange(0, prefill_seq_len, dtype=torch.int)
-  decode_token = torch.tensor([[0]], dtype=torch.int)
-  decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
-
-  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
-  edge_model = (
-      ai_edge_torch.signature(
-          'prefill',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': prefill_tokens,
-              'input_pos': prefill_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .signature(
-          'decode',
-          pytorch_model,
-          sample_kwargs={
-              'tokens': decode_token,
-              'input_pos': decode_input_pos,
-              'kv_cache': kv,
-          },
-      )
-      .convert(quant_config=quant_config)
+  prefill_seq_lens = (
+      [prefill_seq_len] if isinstance(prefill_seq_len, int) else prefill_seq_len
   )
-  edge_model.export(tflite_path)
-
-
-def convert_to_tflite_multi_prefill_lens(
-    pytorch_model: torch.nn.Module,
-    tflite_path: str,
-    prefill_seq_lens: list[int],
-    quantize: bool = True,
-):
-  """Converts a nn.Module model to multi-signature tflite model with different
-
-  prefill lengths.
-
-  A PyTorch model will be converted to a tflite model with several signatures:
-  "prefill_[prefill_seq_len]" and "decode".
-
-  "prefill_[prefill_seq_len]" signature takes a tensor of shape [1,
-  prefill_seq_len] of token
-  sequence, a tensor of shape [1, prefill_seq_len] of token positions, and an
-  external KV cache as a sample input.
-
-  "decode" signature takes a tensor of shape [1, 1] of token sequence, a tensor
-  of shape [1, 1] of the token position, and an external KV cache as a sample
-  input.
-
-  The final tflite model will be exported to tflite_path.

-  Args:
-    pytorch_model (torch.nn.Module): PyTorch model to convert to tflite.
-    tflite_path (str): The tflite file path to export.
-    prefill_seq_lens (list[int]): A list of prefill lengths to export.
-    quantize (bool, optional): Whether the model should be quanized. Defaults
-      to True.
-  """
   # Tensors used to trace the model graph during conversion.
   prefill_tokens_list = []
   prefill_input_pos_list = []
-  for
-  prefill_tokens_list.append(
-
-
-
-
-
+  for seq_len in prefill_seq_lens:
+    prefill_tokens_list.append(torch.full((1, seq_len), 0, dtype=torch.int))
+    prefill_input_pos_list.append(torch.arange(0, seq_len, dtype=torch.int))
+
+  prefill_pixel_values = (
+      torch.full((1,) + pixel_values_size, 0, dtype=torch.float32)
+      if pixel_values_size
+      else None
+  )

   decode_token = torch.tensor([[0]], dtype=torch.int)
   decode_input_pos = torch.tensor([0], dtype=torch.int)
-  kv = kv_utils.KVCache.from_model_config(
+  kv = kv_utils.KVCache.from_model_config(
+      config if config else pytorch_model.config
+  )

   quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
   converter = converter_utils.Converter()
@@ -135,8 +102,12 @@ def convert_to_tflite_multi_prefill_lens(
     prefill_seq_len = prefill_seq_lens[i]
     prefill_tokens = prefill_tokens_list[i]
     prefill_input_pos = prefill_input_pos_list[i]
+    if i == 0 and len(prefill_seq_lens) == 1:
+      prefill_signature_name = 'prefill'
+    else:
+      prefill_signature_name = f'prefill_{prefill_seq_len}'
     converter.add_signature(
-
+        prefill_signature_name,
         pytorch_model,
         sample_kwargs={
             'tokens': prefill_tokens,
@@ -144,8 +115,19 @@ def convert_to_tflite_multi_prefill_lens(
            'input_pos': prefill_input_pos,
            'kv_cache': kv,
         },
     )
+    if prefill_pixel_values is not None:
+      converter.add_signature(
+          prefill_signature_name + '_pixel',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+              'pixel_values': prefill_pixel_values,
+          },
+      )

-
+  converter.add_signature(
       'decode',
       pytorch_model,
       sample_kwargs={
@@ -153,5 +135,7 @@ def convert_to_tflite_multi_prefill_lens(
           'input_pos': decode_input_pos,
           'kv_cache': kv,
       },
-  )
+  )
+
+  edge_model = converter.convert(quant_config=quant_config)
   edge_model.export(tflite_path)
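Taken together, the reworked `convert_to_tflite` above folds the old single-prefill and multi-prefill entry points into one function. A hedged usage sketch based on the new signature and the example converters in this release (checkpoint and output paths are placeholders):

```python
from ai_edge_torch.generative.examples.gemma import gemma2
from ai_edge_torch.generative.utilities import converter

# Placeholder checkpoint path; kv_cache_max_len mirrors the example converters.
pytorch_model = gemma2.build_2b_model(
    '/path/to/gemma2_checkpoint', kv_cache_max_len=1280
)

# A list of prefill lengths yields prefill_256/prefill_512/prefill_1024 plus
# "decode" signatures; a single int yields just "prefill" and "decode".
converter.convert_to_tflite(
    pytorch_model,
    tflite_path='/tmp/gemma2_multi_prefill.tflite',
    prefill_seq_len=[256, 512, 1024],
    quantize=True,
)
```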
ai_edge_torch/version.py CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.
+Version: 0.3.0.dev20241120
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=52sF7t2CBQE8RcB2Hcmo-f6_BLyCW9NzWZ-wTKM9ho4,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -12,7 +12,7 @@ ai_edge_torch/_convert/signature.py,sha256=rGpBNss3Y9FCCCcdBwDo16KqavJi8N5P0M_6W
 ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
 ai_edge_torch/_convert/fx_passes/__init__.py,sha256=NVe-eGcm7j8jZpP2pcMhC8j5dVjgR1pPzyXhHdvKH4E,1267
 ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=doaww8KqrgRTD5LotBVAIRFsEqzPn9R5lcGehBJOczA,9098
-ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=
+ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=qb4JBDi4Xca14JJUIcaaZQIJiyqKyHJF49jsRCIFCVA,4335
 ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=IlZuK42kfVcRqAWZp4j2k_81T2uWo9T2558U_GPJAlU,2327
 ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=f1IUVWyhioOClsMiZzLyynoW2R17U83vA-7Q-3pGPM4,2126
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=lxnoH-WGLeiQIF8XjMGodjiZEFTxucl7g05N7MR9OPk,796
@@ -27,7 +27,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=mzfL9cf0qBnpmxM_OlMQFvQsEZV2B_Mia9yEJV4J7rI,7135
 ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/test/test_convert.py,sha256=yXfeWDw9u_rTS3B6kvvFPo5E4XNT3zKTSLFSBSAI9Fc,15502
-ai_edge_torch/_convert/test/test_convert_composites.py,sha256=
+ai_edge_torch/_convert/test/test_convert_composites.py,sha256=ELwHxTdTTCJm30aWg_PZXxg9HvDM4Hnf9lT0wwOWT6s,8060
 ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
 ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
 ai_edge_torch/debug/__init__.py,sha256=N05Mmvi41KgSuK0JhuMejERESgP8QekiGdp9_PEyuKU,742
@@ -45,7 +45,7 @@ ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=-n7
 ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=-9Nb9D818YSJR3olVtBwoLNeMMD5qE58YBnsA67hlHg,2421
 ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=evmUj_4yygQthSRU-ke-Xn1qFNDCZKbegqINWfruKwU,2184
-ai_edge_torch/generative/examples/gemma/convert_gemma2_multi_prefills.py,sha256=
+ai_edge_torch/generative/examples/gemma/convert_gemma2_multi_prefills.py,sha256=6d9wG5MnStEys34_gFXwKTMRXUBFLTW1jEzCoWkAtwM,2224
 ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=RZDs6oY-NLYrPNtfuJDweIHzGUL2kzpIc3AW_1p8gGg,2186
 ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=oSbysiPvwp5efMbNYZop3HrxDMGiD15Tmz-HiQuTr2E,3315
 ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=RQFQDMEnIVp8PefcCTr7P0CvllKI7FVoIJLXbPLLIsc,9056
@@ -61,9 +61,10 @@ ai_edge_torch/generative/examples/openelm/convert_to_tflite.py,sha256=85FVEt6cKF
 ai_edge_torch/generative/examples/openelm/openelm.py,sha256=sFakstoPDcOHSak0IGFEEq_HQMBBSMcx-WVCDZqcVDo,4411
 ai_edge_torch/generative/examples/openelm/verify.py,sha256=VkigoqhAr8ew95neb3TifYv-SLOSheaWKv2AH0iKDrc,2441
 ai_edge_torch/generative/examples/paligemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ai_edge_torch/generative/examples/paligemma/convert_to_tflite.py,sha256=dT7dnx1dzGzFiH5gQJ4M6zcTLSRFvSDpi3IuZ9_vd78,2706
 ai_edge_torch/generative/examples/paligemma/decoder.py,sha256=XMeznGBbjRJidv725L6_7XzkYskS2cDjf8NGB18FNhg,4944
-ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=
-ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=
+ai_edge_torch/generative/examples/paligemma/image_encoder.py,sha256=yKPWG8aBp-GuzeyQntlzwTTcGBBjvUywVGRjnlNprmo,5574
+ai_edge_torch/generative/examples/paligemma/paligemma.py,sha256=pIjsS-IUFevRjFA9153YT1vtWXATGWHsgVQQX_nWaZQ,5280
 ai_edge_torch/generative/examples/paligemma/verify.py,sha256=Bkbgy-GFjnMNYjduWUM7YLWarPTwmj1v38eHY-PdBlM,4874
 ai_edge_torch/generative/examples/paligemma/verify_decoder.py,sha256=al5wMPWri4IRVWrLmCplPi6uoCzwh0vBHMGnCt-XUqo,2690
 ai_edge_torch/generative/examples/paligemma/verify_image_encoder.py,sha256=pSekf1BybhieQz3cQx_llbRQHxczXbTqool8fOyGj_0,3114
@@ -117,7 +118,7 @@ ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5
 ai_edge_torch/generative/layers/builder.py,sha256=Z5LyzCEThgnYZeyViakaE3yJVzTGHtw13acHsAQR15U,5050
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
 ai_edge_torch/generative/layers/kv_cache.py,sha256=lbm-yJ1jGPtcgWS4C3FmSnB1IlxqDE7g0BLRh3PN4N4,6324
-ai_edge_torch/generative/layers/model_config.py,sha256=
+ai_edge_torch/generative/layers/model_config.py,sha256=viX51T_naJ9sPpPxPoMnSueBPYE2zxWNOD0xn0f-_bM,7510
 ai_edge_torch/generative/layers/normalization.py,sha256=eKAGst9rPuyRFExMcQFJO7R3iHdCtlmjeF_lITjLhwE,6498
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
@@ -135,12 +136,12 @@ ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVu
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
 ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
-ai_edge_torch/generative/test/test_model_conversion.py,sha256=
-ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=
+ai_edge_torch/generative/test/test_model_conversion.py,sha256=aZFaheg2sq7rEccch1TZM6W4BSfpJZjrM9Gyp4hVGYs,6351
+ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=xWV9O2wuRHc4VNBWuWipiuqXa3AJhiV1nmjewAZHHWM,11177
 ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0lZKNtCB04SOLO18y4,5980
-ai_edge_torch/generative/test/utils.py,sha256=
+ai_edge_torch/generative/test/utils.py,sha256=eQ-hjd1eXuHJF3SJK6_CrjgOZVzmG_4VEdH7Z1gH_lA,1897
 ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
-ai_edge_torch/generative/utilities/converter.py,sha256=
+ai_edge_torch/generative/utilities/converter.py,sha256=S14STbyxV6A9HKy1BdUo49f2jS6Ij0RL9mVAFUMWYV8,5291
 ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
 ai_edge_torch/generative/utilities/model_builder.py,sha256=OcHJhEqc3LjI3STli6cyn71m1mdzr7QbzF9fqSNCXrg,5730
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
@@ -193,8 +194,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
-ai_edge_torch_nightly-0.3.0.
+ai_edge_torch_nightly-0.3.0.dev20241120.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20241120.dist-info/METADATA,sha256=1Nv_QeerPRw888sOTf4jHx5Ihu-PJD9rL8GOpRHSTa4,1897
+ai_edge_torch_nightly-0.3.0.dev20241120.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
+ai_edge_torch_nightly-0.3.0.dev20241120.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20241120.dist-info/RECORD,,