ai-edge-torch-nightly 0.3.0.dev20241119__py3-none-any.whl → 0.3.0.dev20241121__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py +2 -2
- ai_edge_torch/_convert/test/test_convert_composites.py +1 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_multi_prefills.py +68 -0
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +3 -3
- ai_edge_torch/generative/layers/model_config.py +23 -20
- ai_edge_torch/odml_torch/lowerings/_basic.py +30 -1
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241119.dist-info → ai_edge_torch_nightly-0.3.0.dev20241121.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241119.dist-info → ai_edge_torch_nightly-0.3.0.dev20241121.dist-info}/RECORD +13 -12
- {ai_edge_torch_nightly-0.3.0.dev20241119.dist-info → ai_edge_torch_nightly-0.3.0.dev20241121.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241119.dist-info → ai_edge_torch_nightly-0.3.0.dev20241121.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241119.dist-info → ai_edge_torch_nightly-0.3.0.dev20241121.dist-info}/top_level.txt +0 -0
@@ -49,7 +49,7 @@ def _get_upsample_bilinear2d_pattern():
|
|
49
49
|
output = internal_match.returning_nodes[0]
|
50
50
|
output_h, output_w = output.meta["val"].shape[-2:]
|
51
51
|
return {
|
52
|
-
"
|
52
|
+
"size": (int(output_h), int(output_w)),
|
53
53
|
"align_corners": False,
|
54
54
|
"is_nchw_op": True,
|
55
55
|
}
|
@@ -73,7 +73,7 @@ def _get_upsample_bilinear2d_align_corners_pattern():
|
|
73
73
|
output = internal_match.returning_nodes[0]
|
74
74
|
output_h, output_w = output.meta["val"].shape[-2:]
|
75
75
|
return {
|
76
|
-
"
|
76
|
+
"size": (int(output_h), int(output_w)),
|
77
77
|
"align_corners": True,
|
78
78
|
"is_nchw_op": True,
|
79
79
|
}
|
@@ -39,6 +39,7 @@ def _func_to_torch_module(func: Callable[..., torch.Tensor]):
|
|
39
39
|
return TestModule(func).eval()
|
40
40
|
|
41
41
|
|
42
|
+
@googletest.skip('Temporary outage due to changes for b/377531086')
|
42
43
|
class TestConvertComposites(googletest.TestCase):
|
43
44
|
"""Tests conversion modules that are meant to be wrapped as composites."""
|
44
45
|
|
@@ -0,0 +1,68 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
|
16
|
+
"""Example to convert a Gemma2 model to multiple prefill length tflite model."""
|
17
|
+
|
18
|
+
import os
|
19
|
+
import pathlib
|
20
|
+
|
21
|
+
from absl import app
|
22
|
+
from absl import flags
|
23
|
+
from ai_edge_torch.generative.examples.gemma import gemma2
|
24
|
+
from ai_edge_torch.generative.utilities import converter
|
25
|
+
|
26
|
+
_CHECKPOINT_PATH = flags.DEFINE_string(
|
27
|
+
'checkpoint_path',
|
28
|
+
os.path.join(pathlib.Path.home(), 'Downloads/llm_data/gemma2-2b'),
|
29
|
+
'The path to the model checkpoint, or directory holding the checkpoint.',
|
30
|
+
)
|
31
|
+
_TFLITE_PATH = flags.DEFINE_string(
|
32
|
+
'tflite_path',
|
33
|
+
'/tmp/',
|
34
|
+
'The tflite file path to export.',
|
35
|
+
)
|
36
|
+
_PREFILL_SEQ_LENS = flags.DEFINE_multi_integer(
|
37
|
+
'prefill_seq_lens',
|
38
|
+
(8, 64, 128, 256, 512, 1024),
|
39
|
+
'List of the maximum sizes of prefill input tensors.',
|
40
|
+
)
|
41
|
+
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
|
42
|
+
'kv_cache_max_len',
|
43
|
+
1280,
|
44
|
+
'The maximum size of KV cache buffer, including both prefill and decode.',
|
45
|
+
)
|
46
|
+
_QUANTIZE = flags.DEFINE_bool(
|
47
|
+
'quantize',
|
48
|
+
True,
|
49
|
+
'Whether the model should be quantized.',
|
50
|
+
)
|
51
|
+
|
52
|
+
|
53
|
+
def main(_):
|
54
|
+
pytorch_model = gemma2.build_2b_model(
|
55
|
+
_CHECKPOINT_PATH.value, kv_cache_max_len=_KV_CACHE_MAX_LEN.value
|
56
|
+
)
|
57
|
+
quant_suffix = 'q8' if _QUANTIZE.value else 'f32'
|
58
|
+
output_filename = f'gemma2_{quant_suffix}_multi-prefill-seq_ekv{_KV_CACHE_MAX_LEN.value}.tflite'
|
59
|
+
converter.convert_to_tflite(
|
60
|
+
pytorch_model,
|
61
|
+
tflite_path=os.path.join(_TFLITE_PATH.value, output_filename),
|
62
|
+
prefill_seq_len=_PREFILL_SEQ_LENS.value,
|
63
|
+
quantize=_QUANTIZE.value,
|
64
|
+
)
|
65
|
+
|
66
|
+
|
67
|
+
if __name__ == '__main__':
|
68
|
+
app.run(main)
|
@@ -33,10 +33,10 @@ _TFLITE_PATH = flags.DEFINE_string(
|
|
33
33
|
'/tmp/',
|
34
34
|
'The tflite file path to export.',
|
35
35
|
)
|
36
|
-
_PREFILL_SEQ_LEN = flags.
|
36
|
+
_PREFILL_SEQ_LEN = flags.DEFINE_integer(
|
37
37
|
'prefill_seq_len',
|
38
|
-
|
39
|
-
'
|
38
|
+
1024,
|
39
|
+
'The maximum size of prefill input tensor.',
|
40
40
|
)
|
41
41
|
_KV_CACHE_MAX_LEN = flags.DEFINE_integer(
|
42
42
|
'kv_cache_max_len',
|
@@ -12,9 +12,10 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
|
16
|
-
|
17
|
-
|
15
|
+
|
16
|
+
"""Model configuration class."""
|
17
|
+
|
18
|
+
import dataclasses
|
18
19
|
import enum
|
19
20
|
from typing import Optional, Sequence, Union
|
20
21
|
|
@@ -35,7 +36,7 @@ class ActivationType(enum.Enum):
|
|
35
36
|
|
36
37
|
@enum.unique
|
37
38
|
class NormalizationType(enum.Enum):
|
38
|
-
"""Different normalization functions"""
|
39
|
+
"""Different normalization functions."""
|
39
40
|
|
40
41
|
# No normalization is applied.
|
41
42
|
NONE = enum.auto()
|
@@ -59,7 +60,7 @@ class AttentionType(enum.Enum):
|
|
59
60
|
LOCAL_SLIDING = enum.auto()
|
60
61
|
|
61
62
|
|
62
|
-
@dataclass
|
63
|
+
@dataclasses.dataclass
|
63
64
|
class NormalizationConfig:
|
64
65
|
"""Normalizater parameters."""
|
65
66
|
|
@@ -71,7 +72,7 @@ class NormalizationConfig:
|
|
71
72
|
group_num: Optional[float] = None
|
72
73
|
|
73
74
|
|
74
|
-
@dataclass
|
75
|
+
@dataclasses.dataclass
|
75
76
|
class AttentionConfig:
|
76
77
|
"""Attention model's parameters."""
|
77
78
|
|
@@ -90,18 +91,20 @@ class AttentionConfig:
|
|
90
91
|
# Whether to use bias with Query, Key, and Value projection.
|
91
92
|
qkv_use_bias: bool = False
|
92
93
|
# Whether the fused q, k, v projection weights interleaves q, k, v heads.
|
93
|
-
# If True, the projection weights are in format
|
94
|
-
#
|
94
|
+
# If True, the projection weights are in format:
|
95
|
+
# `[q_head_0, k_head_0, v_head_0, q_head_1, k_head_1, v_head_1, ...]`
|
96
|
+
# If False, the projection weights are in format:
|
97
|
+
# `[q_head_0, q_head_1, ..., k_head_0, k_head_1, ... v_head_0, v_head_1, ...]`
|
95
98
|
qkv_fused_interleaved: bool = True
|
96
99
|
# Whether to use bias with attention output projection.
|
97
100
|
output_proj_use_bias: bool = False
|
98
101
|
enable_kv_cache: bool = True
|
99
102
|
# The normalization applied to query projection's output.
|
100
|
-
query_norm_config: NormalizationConfig = field(
|
103
|
+
query_norm_config: NormalizationConfig = dataclasses.field(
|
101
104
|
default_factory=NormalizationConfig
|
102
105
|
)
|
103
106
|
# The normalization applied to key projection's output.
|
104
|
-
key_norm_config: NormalizationConfig = field(
|
107
|
+
key_norm_config: NormalizationConfig = dataclasses.field(
|
105
108
|
default_factory=NormalizationConfig
|
106
109
|
)
|
107
110
|
relative_attention_num_buckets: int = 0
|
@@ -114,7 +117,7 @@ class AttentionConfig:
|
|
114
117
|
sliding_window_size: Optional[int] = None
|
115
118
|
|
116
119
|
|
117
|
-
@dataclass
|
120
|
+
@dataclasses.dataclass
|
118
121
|
class ActivationConfig:
|
119
122
|
type: ActivationType = ActivationType.LINEAR
|
120
123
|
# Dimension of input and output, used in GeGLU.
|
@@ -122,7 +125,7 @@ class ActivationConfig:
|
|
122
125
|
dim_out: Optional[int] = None
|
123
126
|
|
124
127
|
|
125
|
-
@dataclass
|
128
|
+
@dataclasses.dataclass
|
126
129
|
class FeedForwardConfig:
|
127
130
|
"""FeedForward module's parameters."""
|
128
131
|
|
@@ -131,27 +134,27 @@ class FeedForwardConfig:
|
|
131
134
|
intermediate_size: int
|
132
135
|
use_bias: bool = False
|
133
136
|
# The normalization applied to feed forward's input.
|
134
|
-
pre_ff_norm_config: NormalizationConfig = field(
|
137
|
+
pre_ff_norm_config: NormalizationConfig = dataclasses.field(
|
135
138
|
default_factory=NormalizationConfig
|
136
139
|
)
|
137
140
|
# The normalization applied to feed forward's output.
|
138
|
-
post_ff_norm_config: NormalizationConfig = field(
|
141
|
+
post_ff_norm_config: NormalizationConfig = dataclasses.field(
|
139
142
|
default_factory=NormalizationConfig
|
140
143
|
)
|
141
144
|
|
142
145
|
|
143
|
-
@dataclass
|
146
|
+
@dataclasses.dataclass
|
144
147
|
class TransformerBlockConfig:
|
145
148
|
"""TransformerBlock module's parameters."""
|
146
149
|
|
147
150
|
attn_config: AttentionConfig
|
148
151
|
ff_config: FeedForwardConfig
|
149
152
|
# The normalization applied to attention's input.
|
150
|
-
pre_attention_norm_config: NormalizationConfig = field(
|
153
|
+
pre_attention_norm_config: NormalizationConfig = dataclasses.field(
|
151
154
|
default_factory=NormalizationConfig
|
152
155
|
)
|
153
156
|
# The normalization applied to attentions's output.
|
154
|
-
post_attention_norm_config: NormalizationConfig = field(
|
157
|
+
post_attention_norm_config: NormalizationConfig = dataclasses.field(
|
155
158
|
default_factory=NormalizationConfig
|
156
159
|
)
|
157
160
|
# If set to True, only attn_config.pre_attention_norm is applied to the input
|
@@ -163,7 +166,7 @@ class TransformerBlockConfig:
|
|
163
166
|
relative_attention: bool = False
|
164
167
|
|
165
168
|
|
166
|
-
@dataclass
|
169
|
+
@dataclasses.dataclass
|
167
170
|
class ImageEmbeddingConfig:
|
168
171
|
"""Image embedding parameters."""
|
169
172
|
|
@@ -173,7 +176,7 @@ class ImageEmbeddingConfig:
|
|
173
176
|
patch_size: int
|
174
177
|
|
175
178
|
|
176
|
-
@dataclass
|
179
|
+
@dataclasses.dataclass
|
177
180
|
class ModelConfig:
|
178
181
|
"""Base configurations for building a transformer architecture."""
|
179
182
|
|
@@ -187,7 +190,7 @@ class ModelConfig:
|
|
187
190
|
block_configs: Union[TransformerBlockConfig, Sequence[TransformerBlockConfig]]
|
188
191
|
|
189
192
|
# The normalization applied before LM head.
|
190
|
-
final_norm_config: NormalizationConfig = field(
|
193
|
+
final_norm_config: NormalizationConfig = dataclasses.field(
|
191
194
|
default_factory=NormalizationConfig
|
192
195
|
)
|
193
196
|
|
@@ -15,13 +15,17 @@
|
|
15
15
|
import math
|
16
16
|
from typing import Optional, Union
|
17
17
|
|
18
|
+
from ai_edge_torch.odml_torch import export_utils
|
19
|
+
from ai_edge_torch.odml_torch.lowerings import context
|
20
|
+
from ai_edge_torch.odml_torch.lowerings import registry
|
18
21
|
from ai_edge_torch.odml_torch.lowerings import utils
|
19
22
|
from jax._src.lib.mlir import ir
|
20
23
|
from jax._src.lib.mlir.dialects import hlo as stablehlo
|
21
24
|
import numpy as np
|
22
25
|
import torch
|
23
26
|
|
24
|
-
|
27
|
+
LoweringContext = context.LoweringContext
|
28
|
+
lower = registry.lower
|
25
29
|
|
26
30
|
|
27
31
|
# add(Tensor self, Tensor other) -> Tensor
|
@@ -211,6 +215,31 @@ def _aten_floor(lctx, x: ir.Value, *, out=None) -> ir.Value:
|
|
211
215
|
return stablehlo.floor(x)
|
212
216
|
|
213
217
|
|
218
|
+
# Schema:
|
219
|
+
# - aten::cat(Tensor[] tensors, int dim=0) -> Tensor
|
220
|
+
# Torch Reference:
|
221
|
+
# - https://pytorch.org/docs/main/generated/torch.cat.html
|
222
|
+
@lower(torch.ops.aten.cat.default)
|
223
|
+
def _aten_cat(lctx: LoweringContext, tensors, dim=0):
|
224
|
+
assert tensors
|
225
|
+
non_empty_tensors = [t for t in tensors if np.prod(t.type.shape) != 0]
|
226
|
+
out_meta = lctx.node.meta["tensor_meta"]
|
227
|
+
if not non_empty_tensors:
|
228
|
+
return utils.splat(
|
229
|
+
0,
|
230
|
+
export_utils.torch_dtype_to_ir_element_type(
|
231
|
+
lctx.ir_context, out_meta.dtype
|
232
|
+
),
|
233
|
+
out_meta.shape,
|
234
|
+
)
|
235
|
+
|
236
|
+
if dim < 0:
|
237
|
+
dim = dim + len(out_meta.shape)
|
238
|
+
dim = ir.IntegerAttr.get(ir.IntegerType.get_signless(64), dim)
|
239
|
+
|
240
|
+
return stablehlo.concatenate(non_empty_tensors, dim)
|
241
|
+
|
242
|
+
|
214
243
|
# Schema:
|
215
244
|
# - aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt?
|
216
245
|
# start=None, SymInt? end=None, SymInt step=1) -> Tensor
|
@@ -105,7 +105,6 @@ lower_by_torch_xla2(torch.ops.aten.bitwise_not)
|
|
105
105
|
lower_by_torch_xla2(torch.ops.aten.bitwise_or)
|
106
106
|
lower_by_torch_xla2(torch.ops.aten.bitwise_xor)
|
107
107
|
lower_by_torch_xla2(torch.ops.aten.bmm)
|
108
|
-
lower_by_torch_xla2(torch.ops.aten.cat)
|
109
108
|
lower_by_torch_xla2(torch.ops.aten.ceil)
|
110
109
|
lower_by_torch_xla2(torch.ops.aten.clamp.Tensor)
|
111
110
|
lower_by_torch_xla2(torch.ops.aten.clamp.default)
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20241121
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
|
5
5
|
ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
|
6
|
-
ai_edge_torch/version.py,sha256=
|
6
|
+
ai_edge_torch/version.py,sha256=6eLmEn5xqmozokHVWP7j-jjFiQlv2a1aZxDucMzXDh8,706
|
7
7
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
8
|
ai_edge_torch/_convert/conversion.py,sha256=HwzfRx_DX5TLtPqwEH1_NOm38_INvHzHl4_mX67KOdQ,5448
|
9
9
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -12,7 +12,7 @@ ai_edge_torch/_convert/signature.py,sha256=rGpBNss3Y9FCCCcdBwDo16KqavJi8N5P0M_6W
|
|
12
12
|
ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDium1lz_BXEQ6IsMFc,2893
|
13
13
|
ai_edge_torch/_convert/fx_passes/__init__.py,sha256=NVe-eGcm7j8jZpP2pcMhC8j5dVjgR1pPzyXhHdvKH4E,1267
|
14
14
|
ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=doaww8KqrgRTD5LotBVAIRFsEqzPn9R5lcGehBJOczA,9098
|
15
|
-
ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=
|
15
|
+
ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=qb4JBDi4Xca14JJUIcaaZQIJiyqKyHJF49jsRCIFCVA,4335
|
16
16
|
ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=IlZuK42kfVcRqAWZp4j2k_81T2uWo9T2558U_GPJAlU,2327
|
17
17
|
ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=f1IUVWyhioOClsMiZzLyynoW2R17U83vA-7Q-3pGPM4,2126
|
18
18
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=lxnoH-WGLeiQIF8XjMGodjiZEFTxucl7g05N7MR9OPk,796
|
@@ -27,7 +27,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
|
|
27
27
|
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=mzfL9cf0qBnpmxM_OlMQFvQsEZV2B_Mia9yEJV4J7rI,7135
|
28
28
|
ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
29
29
|
ai_edge_torch/_convert/test/test_convert.py,sha256=yXfeWDw9u_rTS3B6kvvFPo5E4XNT3zKTSLFSBSAI9Fc,15502
|
30
|
-
ai_edge_torch/_convert/test/test_convert_composites.py,sha256=
|
30
|
+
ai_edge_torch/_convert/test/test_convert_composites.py,sha256=ELwHxTdTTCJm30aWg_PZXxg9HvDM4Hnf9lT0wwOWT6s,8060
|
31
31
|
ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
|
32
32
|
ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
|
33
33
|
ai_edge_torch/debug/__init__.py,sha256=N05Mmvi41KgSuK0JhuMejERESgP8QekiGdp9_PEyuKU,742
|
@@ -45,7 +45,8 @@ ai_edge_torch/generative/examples/amd_llama_135m/convert_to_tflite.py,sha256=-n7
|
|
45
45
|
ai_edge_torch/generative/examples/amd_llama_135m/verify.py,sha256=-9Nb9D818YSJR3olVtBwoLNeMMD5qE58YBnsA67hlHg,2421
|
46
46
|
ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
47
47
|
ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=evmUj_4yygQthSRU-ke-Xn1qFNDCZKbegqINWfruKwU,2184
|
48
|
-
ai_edge_torch/generative/examples/gemma/
|
48
|
+
ai_edge_torch/generative/examples/gemma/convert_gemma2_multi_prefills.py,sha256=6d9wG5MnStEys34_gFXwKTMRXUBFLTW1jEzCoWkAtwM,2224
|
49
|
+
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=RZDs6oY-NLYrPNtfuJDweIHzGUL2kzpIc3AW_1p8gGg,2186
|
49
50
|
ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=oSbysiPvwp5efMbNYZop3HrxDMGiD15Tmz-HiQuTr2E,3315
|
50
51
|
ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=RQFQDMEnIVp8PefcCTr7P0CvllKI7FVoIJLXbPLLIsc,9056
|
51
52
|
ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
|
@@ -117,7 +118,7 @@ ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5
|
|
117
118
|
ai_edge_torch/generative/layers/builder.py,sha256=Z5LyzCEThgnYZeyViakaE3yJVzTGHtw13acHsAQR15U,5050
|
118
119
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
|
119
120
|
ai_edge_torch/generative/layers/kv_cache.py,sha256=lbm-yJ1jGPtcgWS4C3FmSnB1IlxqDE7g0BLRh3PN4N4,6324
|
120
|
-
ai_edge_torch/generative/layers/model_config.py,sha256=
|
121
|
+
ai_edge_torch/generative/layers/model_config.py,sha256=viX51T_naJ9sPpPxPoMnSueBPYE2zxWNOD0xn0f-_bM,7510
|
121
122
|
ai_edge_torch/generative/layers/normalization.py,sha256=eKAGst9rPuyRFExMcQFJO7R3iHdCtlmjeF_lITjLhwE,6498
|
122
123
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
|
123
124
|
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
|
@@ -177,10 +178,10 @@ ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkg
|
|
177
178
|
ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=drN3L0uTsSjkluKgt6Ngq7b5HLReE_7iAitHpZ9PKqE,5428
|
178
179
|
ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
|
179
180
|
ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=dE_qzh-OnCNjWzqs1-PHs5PNlRF726qMQKM3tkwAzEs,959
|
180
|
-
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=
|
181
|
+
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=mxNh20Z4ZeQMu0AAdXnNMXdm2PdAh3RmQPzq2SBpxQs,9954
|
181
182
|
ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
|
182
183
|
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=v1VdKmL8YLJv3PR9VgyNghO83A25PpTzY2ZUAJqlq3Q,6847
|
183
|
-
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=
|
184
|
+
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=MY6FFSJKYtD1M1l2q3hDKf3P4NpODqQ4NyWudYe1tTE,10772
|
184
185
|
ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=1ePJs7oIdUkVdMddFsXMc53qTkEKqGz0ZhQQoNzBa10,2862
|
185
186
|
ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
|
186
187
|
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=gqx3n1Mx8pnGQz3nkIF1T_8bkRabXLJBvUoJJn5kOUY,2911
|
@@ -193,8 +194,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
193
194
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
194
195
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
195
196
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
196
|
-
ai_edge_torch_nightly-0.3.0.
|
197
|
-
ai_edge_torch_nightly-0.3.0.
|
198
|
-
ai_edge_torch_nightly-0.3.0.
|
199
|
-
ai_edge_torch_nightly-0.3.0.
|
200
|
-
ai_edge_torch_nightly-0.3.0.
|
197
|
+
ai_edge_torch_nightly-0.3.0.dev20241121.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
198
|
+
ai_edge_torch_nightly-0.3.0.dev20241121.dist-info/METADATA,sha256=AJUg6jkWACMXVy7gopyMvlD0aJfw1BVnkZKbGS9cXX0,1897
|
199
|
+
ai_edge_torch_nightly-0.3.0.dev20241121.dist-info/WHEEL,sha256=bFJAMchF8aTQGUgMZzHJyDDMPTO3ToJ7x23SLJa1SVo,92
|
200
|
+
ai_edge_torch_nightly-0.3.0.dev20241121.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
201
|
+
ai_edge_torch_nightly-0.3.0.dev20241121.dist-info/RECORD,,
|
File without changes
|
File without changes
|