ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240913__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/gemma/gemma.py +50 -30
- ai_edge_torch/generative/examples/gemma/gemma2.py +85 -58
- ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
- ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +46 -43
- ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py +12 -14
- ai_edge_torch/generative/examples/smallm/smallm.py +122 -0
- ai_edge_torch/generative/examples/stable_diffusion/clip.py +11 -5
- ai_edge_torch/generative/examples/t5/t5.py +35 -22
- ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
- ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +74 -33
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +55 -34
- ai_edge_torch/generative/layers/attention.py +77 -73
- ai_edge_torch/generative/layers/builder.py +5 -3
- ai_edge_torch/generative/layers/kv_cache.py +163 -51
- ai_edge_torch/generative/layers/model_config.py +38 -19
- ai_edge_torch/generative/layers/normalization.py +158 -0
- ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
- ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
- ai_edge_torch/generative/test/test_loader.py +1 -1
- ai_edge_torch/generative/test/test_model_conversion.py +72 -34
- ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
- ai_edge_torch/generative/test/utils.py +54 -0
- ai_edge_torch/generative/utilities/loader.py +15 -15
- ai_edge_torch/generative/utilities/t5_loader.py +21 -20
- ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
- ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/RECORD +39 -45
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
- ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
- ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
- ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
- /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
- /ai_edge_torch/generative/examples/{experimental/gemma → smallm}/__init__.py +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/top_level.txt +0 -0
@@ -221,7 +221,8 @@ class ModelLoader:
|
|
221
221
|
converted_state: Dict[str, torch.Tensor],
|
222
222
|
):
|
223
223
|
prefix = f"transformer_blocks.{idx}"
|
224
|
-
|
224
|
+
ff_config = config.block_config(idx).ff_config
|
225
|
+
if ff_config.type == model_config.FeedForwardType.SEQUENTIAL:
|
225
226
|
ff_up_proj_name = self._names.ff_up_proj.format(idx)
|
226
227
|
ff_down_proj_name = self._names.ff_down_proj.format(idx)
|
227
228
|
converted_state[f"{prefix}.ff.w1.weight"] = state.pop(
|
@@ -230,7 +231,7 @@ class ModelLoader:
|
|
230
231
|
converted_state[f"{prefix}.ff.w2.weight"] = state.pop(
|
231
232
|
f"{ff_down_proj_name}.weight"
|
232
233
|
)
|
233
|
-
if
|
234
|
+
if ff_config.use_bias:
|
234
235
|
converted_state[f"{prefix}.ff.w1.bias"] = state.pop(
|
235
236
|
f"{ff_up_proj_name}.bias"
|
236
237
|
)
|
@@ -250,7 +251,7 @@ class ModelLoader:
|
|
250
251
|
converted_state[f"{prefix}.ff.w1.weight"] = state.pop(
|
251
252
|
f"{ff_gate_proj_name}.weight"
|
252
253
|
)
|
253
|
-
if
|
254
|
+
if ff_config.use_bias:
|
254
255
|
converted_state[f"{prefix}.ff.w3.bias"] = state.pop(
|
255
256
|
f"{ff_up_proj_name}.bias"
|
256
257
|
)
|
@@ -289,6 +290,7 @@ class ModelLoader:
|
|
289
290
|
converted_state: Dict[str, torch.Tensor],
|
290
291
|
):
|
291
292
|
prefix = f"transformer_blocks.{idx}"
|
293
|
+
attn_config = config.block_config(idx).attn_config
|
292
294
|
if self._names.attn_fused_qkv_proj:
|
293
295
|
fused_qkv_name = self._names.attn_fused_qkv_proj.format(idx)
|
294
296
|
converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = state.pop(
|
@@ -300,13 +302,13 @@ class ModelLoader:
|
|
300
302
|
v_name = self._names.attn_value_proj.format(idx)
|
301
303
|
converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = (
|
302
304
|
self._fuse_qkv(
|
303
|
-
|
305
|
+
attn_config,
|
304
306
|
state.pop(f"{q_name}.weight"),
|
305
307
|
state.pop(f"{k_name}.weight"),
|
306
308
|
state.pop(f"{v_name}.weight"),
|
307
309
|
)
|
308
310
|
)
|
309
|
-
if
|
311
|
+
if attn_config.qkv_use_bias:
|
310
312
|
if self._names.attn_fused_qkv_proj:
|
311
313
|
converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = state.pop(
|
312
314
|
f"{fused_qkv_name}.bias"
|
@@ -314,7 +316,7 @@ class ModelLoader:
|
|
314
316
|
else:
|
315
317
|
converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = (
|
316
318
|
self._fuse_qkv(
|
317
|
-
|
319
|
+
attn_config,
|
318
320
|
state.pop(f"{q_name}.bias"),
|
319
321
|
state.pop(f"{k_name}.bias"),
|
320
322
|
state.pop(f"{v_name}.bias"),
|
@@ -325,7 +327,7 @@ class ModelLoader:
|
|
325
327
|
converted_state[f"{prefix}.atten_func.output_projection.weight"] = (
|
326
328
|
state.pop(f"{o_name}.weight")
|
327
329
|
)
|
328
|
-
if
|
330
|
+
if attn_config.output_proj_use_bias:
|
329
331
|
converted_state[f"{prefix}.atten_func.output_projection.bias"] = (
|
330
332
|
state.pop(f"{o_name}.bias")
|
331
333
|
)
|
@@ -360,18 +362,16 @@ class ModelLoader:
|
|
360
362
|
|
361
363
|
def _fuse_qkv(
|
362
364
|
self,
|
363
|
-
|
365
|
+
attn_config: model_config.AttentionConfig,
|
364
366
|
q: torch.Tensor,
|
365
367
|
k: torch.Tensor,
|
366
368
|
v: torch.Tensor,
|
367
369
|
) -> torch.Tensor:
|
368
|
-
if
|
369
|
-
q_per_kv =
|
370
|
-
|
371
|
-
)
|
372
|
-
|
373
|
-
ks = torch.split(k, config.attn_config.head_dim)
|
374
|
-
vs = torch.split(v, config.attn_config.head_dim)
|
370
|
+
if attn_config.qkv_fused_interleaved:
|
371
|
+
q_per_kv = attn_config.num_heads // attn_config.num_query_groups
|
372
|
+
qs = torch.split(q, attn_config.head_dim * q_per_kv)
|
373
|
+
ks = torch.split(k, attn_config.head_dim)
|
374
|
+
vs = torch.split(v, attn_config.head_dim)
|
375
375
|
cycled = [t for group in zip(qs, ks, vs) for t in group]
|
376
376
|
return torch.cat(cycled)
|
377
377
|
else:
|
@@ -279,7 +279,8 @@ class ModelLoader:
|
|
279
279
|
prefix = additional_prefix + f"transformer_blocks.{idx}"
|
280
280
|
if names.ff_up_proj is None or names.ff_down_proj is None:
|
281
281
|
return
|
282
|
-
|
282
|
+
ff_config = config.block_config(idx).ff_config
|
283
|
+
if ff_config.type == model_config.FeedForwardType.SEQUENTIAL:
|
283
284
|
ff_up_proj_name = names.ff_up_proj.format(idx)
|
284
285
|
ff_down_proj_name = names.ff_down_proj.format(idx)
|
285
286
|
converted_state[f"{prefix}.ff.w1.weight"] = state.pop(
|
@@ -288,7 +289,7 @@ class ModelLoader:
|
|
288
289
|
converted_state[f"{prefix}.ff.w2.weight"] = state.pop(
|
289
290
|
f"{ff_down_proj_name}.weight"
|
290
291
|
)
|
291
|
-
if
|
292
|
+
if ff_config.use_bias:
|
292
293
|
converted_state[f"{prefix}.ff.w1.bias"] = state.pop(
|
293
294
|
f"{ff_up_proj_name}.bias"
|
294
295
|
)
|
@@ -309,7 +310,7 @@ class ModelLoader:
|
|
309
310
|
converted_state[f"{prefix}.ff.w1.weight"] = state.pop(
|
310
311
|
f"{ff_gate_proj_name}.weight"
|
311
312
|
)
|
312
|
-
if
|
313
|
+
if ff_config.use_bias:
|
313
314
|
converted_state[f"{prefix}.ff.w3.bias"] = state.pop(
|
314
315
|
f"{ff_up_proj_name}.bias"
|
315
316
|
)
|
@@ -337,20 +338,21 @@ class ModelLoader:
|
|
337
338
|
):
|
338
339
|
return
|
339
340
|
prefix = additional_prefix + f"transformer_blocks.{idx}"
|
341
|
+
attn_config = config.block_config(idx).attn_config
|
340
342
|
q_name = names.attn_query_proj.format(idx)
|
341
343
|
k_name = names.attn_key_proj.format(idx)
|
342
344
|
v_name = names.attn_value_proj.format(idx)
|
343
345
|
# model.encoder.transformer_blocks[0].atten_func.q_projection.weight
|
344
346
|
if fuse_attention:
|
345
347
|
converted_state[f"{prefix}.atten_func.attn.weight"] = self._fuse_qkv(
|
346
|
-
|
348
|
+
attn_config,
|
347
349
|
state.pop(f"{q_name}.weight"),
|
348
350
|
state.pop(f"{k_name}.weight"),
|
349
351
|
state.pop(f"{v_name}.weight"),
|
350
352
|
)
|
351
|
-
if
|
353
|
+
if attn_config.qkv_use_bias:
|
352
354
|
converted_state[f"{prefix}.atten_func.attn.bias"] = self._fuse_qkv(
|
353
|
-
|
355
|
+
attn_config,
|
354
356
|
state.pop(f"{q_name}.bias"),
|
355
357
|
state.pop(f"{k_name}.bias"),
|
356
358
|
state.pop(f"{v_name}.bias"),
|
@@ -365,7 +367,7 @@ class ModelLoader:
|
|
365
367
|
converted_state[f"{prefix}.atten_func.v_projection.weight"] = state.pop(
|
366
368
|
f"{v_name}.weight"
|
367
369
|
)
|
368
|
-
if
|
370
|
+
if attn_config.qkv_use_bias:
|
369
371
|
converted_state[f"{prefix}.atten_func.q_projection.bias"] = state.pop(
|
370
372
|
f"{q_name}.bias"
|
371
373
|
)
|
@@ -380,7 +382,7 @@ class ModelLoader:
|
|
380
382
|
converted_state[f"{prefix}.atten_func.output_projection.weight"] = (
|
381
383
|
state.pop(f"{o_name}.weight")
|
382
384
|
)
|
383
|
-
if
|
385
|
+
if attn_config.output_proj_use_bias:
|
384
386
|
converted_state[f"{prefix}.atten_func.output_projection.bias"] = (
|
385
387
|
state.pop(f"{o_name}.bias")
|
386
388
|
)
|
@@ -402,6 +404,7 @@ class ModelLoader:
|
|
402
404
|
):
|
403
405
|
return
|
404
406
|
prefix = additional_prefix + f"transformer_blocks.{idx}"
|
407
|
+
attn_config = config.block_config(idx).attn_config
|
405
408
|
q_name = names.cross_attn_query_proj.format(idx)
|
406
409
|
k_name = names.cross_attn_key_proj.format(idx)
|
407
410
|
v_name = names.cross_attn_value_proj.format(idx)
|
@@ -409,16 +412,16 @@ class ModelLoader:
|
|
409
412
|
if fuse_attention:
|
410
413
|
converted_state[f"{prefix}.cross_atten_func.attn.weight"] = (
|
411
414
|
self._fuse_qkv(
|
412
|
-
|
415
|
+
attn_config,
|
413
416
|
state.pop(f"{q_name}.weight"),
|
414
417
|
state.pop(f"{k_name}.weight"),
|
415
418
|
state.pop(f"{v_name}.weight"),
|
416
419
|
)
|
417
420
|
)
|
418
|
-
if
|
421
|
+
if attn_config.qkv_use_bias:
|
419
422
|
converted_state[f"{prefix}.cross_atten_func.attn.bias"] = (
|
420
423
|
self._fuse_qkv(
|
421
|
-
|
424
|
+
attn_config,
|
422
425
|
state.pop(f"{q_name}.bias"),
|
423
426
|
state.pop(f"{k_name}.bias"),
|
424
427
|
state.pop(f"{v_name}.bias"),
|
@@ -434,7 +437,7 @@ class ModelLoader:
|
|
434
437
|
converted_state[f"{prefix}.cross_atten_func.v_projection.weight"] = (
|
435
438
|
state.pop(f"{v_name}.weight")
|
436
439
|
)
|
437
|
-
if
|
440
|
+
if attn_config.qkv_use_bias:
|
438
441
|
converted_state[f"{prefix}.cross_atten_func.q_projection.bias"] = (
|
439
442
|
state.pop(f"{q_name}.bias")
|
440
443
|
)
|
@@ -449,7 +452,7 @@ class ModelLoader:
|
|
449
452
|
converted_state[f"{prefix}.cross_atten_func.output_projection.weight"] = (
|
450
453
|
state.pop(f"{o_name}.weight")
|
451
454
|
)
|
452
|
-
if
|
455
|
+
if attn_config.output_proj_use_bias:
|
453
456
|
converted_state[f"{prefix}.cross_atten_func.output_projection.bias"] = (
|
454
457
|
state.pop(f"{o_name}.bias")
|
455
458
|
)
|
@@ -496,16 +499,14 @@ class ModelLoader:
|
|
496
499
|
|
497
500
|
def _fuse_qkv(
|
498
501
|
self,
|
499
|
-
|
502
|
+
attn_config: model_config.AttentionConfig,
|
500
503
|
q: torch.Tensor,
|
501
504
|
k: torch.Tensor,
|
502
505
|
v: torch.Tensor,
|
503
506
|
) -> torch.Tensor:
|
504
|
-
q_per_kv =
|
505
|
-
|
506
|
-
)
|
507
|
-
|
508
|
-
ks = torch.split(k, config.attn_config.head_dim)
|
509
|
-
vs = torch.split(v, config.attn_config.head_dim)
|
507
|
+
q_per_kv = attn_config.num_heads // attn_config.num_query_groups
|
508
|
+
qs = torch.split(q, attn_config.head_dim * q_per_kv)
|
509
|
+
ks = torch.split(k, attn_config.head_dim)
|
510
|
+
vs = torch.split(v, attn_config.head_dim)
|
510
511
|
cycled = [t for group in zip(qs, ks, vs) for t in group]
|
511
512
|
return torch.cat(cycled)
|
@@ -167,7 +167,6 @@ lower_by_torch_xla2(torch.ops.aten.mul.Scalar)
|
|
167
167
|
lower_by_torch_xla2(torch.ops.aten.mul.Tensor)
|
168
168
|
lower_by_torch_xla2(torch.ops.aten.native_batch_norm)
|
169
169
|
lower_by_torch_xla2(torch.ops.aten.native_group_norm)
|
170
|
-
lower_by_torch_xla2(torch.ops.aten.native_layer_norm)
|
171
170
|
lower_by_torch_xla2(torch.ops.aten.native_layer_norm_backward)
|
172
171
|
lower_by_torch_xla2(torch.ops.aten.ne)
|
173
172
|
lower_by_torch_xla2(torch.ops.aten.neg)
|
@@ -0,0 +1,78 @@
|
|
1
|
+
# Copyright 2024 The AI Edge Torch Authors.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
# ==============================================================================
|
15
|
+
"""Provides lowering for coreaten to stablehlo for LayerNorm."""
|
16
|
+
|
17
|
+
import math
|
18
|
+
from typing import Optional
|
19
|
+
from ai_edge_torch.odml_torch.lowerings import registry
|
20
|
+
from ai_edge_torch.odml_torch.lowerings import utils
|
21
|
+
from jax._src.lib.mlir import ir
|
22
|
+
from jax._src.lib.mlir.dialects import hlo as stablehlo
|
23
|
+
import torch
|
24
|
+
|
25
|
+
|
26
|
+
# native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight,
|
27
|
+
# Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)
|
28
|
+
@registry.lower(torch.ops.aten.native_layer_norm)
|
29
|
+
def _aten_native_layer_norm(
|
30
|
+
lctx,
|
31
|
+
data: ir.Value,
|
32
|
+
normalized_shape: list[int],
|
33
|
+
weight: Optional[ir.Value],
|
34
|
+
bias: Optional[ir.Value],
|
35
|
+
eps: float,
|
36
|
+
):
|
37
|
+
data_type: ir.RankedTensorType = data.type
|
38
|
+
unnormalized_count = math.prod(data_type.shape) // math.prod(normalized_shape)
|
39
|
+
dest_shape = [
|
40
|
+
1,
|
41
|
+
unnormalized_count,
|
42
|
+
math.prod(normalized_shape),
|
43
|
+
]
|
44
|
+
dest_type = ir.RankedTensorType.get(dest_shape, data_type.element_type)
|
45
|
+
|
46
|
+
reshaped_data = stablehlo.reshape(dest_type, data)
|
47
|
+
|
48
|
+
one = utils.splat(1, data_type.element_type, [unnormalized_count])
|
49
|
+
zero = utils.splat(0, data_type.element_type, [unnormalized_count])
|
50
|
+
output, mean, var = stablehlo.batch_norm_training(
|
51
|
+
reshaped_data, one, zero, eps, 1
|
52
|
+
)
|
53
|
+
eps_splat = utils.splat(eps, var.type.element_type, var.type.shape)
|
54
|
+
rstd = stablehlo.rsqrt(stablehlo.add(var, eps_splat))
|
55
|
+
|
56
|
+
stats_shape = data_type.shape[: -1 * len(normalized_shape)] + [1] * len(
|
57
|
+
normalized_shape
|
58
|
+
)
|
59
|
+
stats_type = ir.RankedTensorType.get(stats_shape, data_type.element_type)
|
60
|
+
mean = stablehlo.reshape(stats_type, mean)
|
61
|
+
rstd = stablehlo.reshape(stats_type, rstd)
|
62
|
+
|
63
|
+
output = stablehlo.reshape(data_type, output)
|
64
|
+
|
65
|
+
data_rank = len(data_type.shape)
|
66
|
+
normalized_rank = len(normalized_shape)
|
67
|
+
if weight is not None:
|
68
|
+
weight = stablehlo.broadcast_in_dim(
|
69
|
+
data_type, weight, list(range(data_rank - normalized_rank, data_rank))
|
70
|
+
)
|
71
|
+
output = stablehlo.multiply(weight, output)
|
72
|
+
if bias is not None:
|
73
|
+
bias = stablehlo.broadcast_in_dim(
|
74
|
+
data_type, bias, list(range(data_rank - normalized_rank, data_rank))
|
75
|
+
)
|
76
|
+
output = stablehlo.add(bias, output)
|
77
|
+
|
78
|
+
return output, mean, rstd
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.dev20240910
|
3
|
+
Version: 0.3.0.dev20240913
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
|
|
2
2
|
ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
|
5
|
-
ai_edge_torch/version.py,sha256=
|
5
|
+
ai_edge_torch/version.py,sha256=2_ahYhvytovu9mWRifMKeqx6-0JbD7-iV5FXU890d7Y,706
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
7
7
|
ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -39,27 +39,20 @@ ai_edge_torch/debug/test/test_search_model.py,sha256=-RuU0QsjqkfzZF2IbeA55MoeVOa
|
|
39
39
|
ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
40
40
|
ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
41
41
|
ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
42
|
-
ai_edge_torch/generative/examples/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
43
|
-
ai_edge_torch/generative/examples/experimental/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
44
|
-
ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py,sha256=lpiPFSh3SJd6WwuZ0QegSva3__iSz2tUD7L7QfkAe4I,3085
|
45
|
-
ai_edge_torch/generative/examples/experimental/gemma/gemma.py,sha256=aCoD86pf4nuquUMk7MOR-jsN5FqvySSEuMx9Psxjblk,7261
|
46
|
-
ai_edge_torch/generative/examples/experimental/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
47
|
-
ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py,sha256=DavrdGmqUgoThsGNRv3LXMW5tvJdYEvj66Hf1XRqkXU,3055
|
48
|
-
ai_edge_torch/generative/examples/experimental/phi/phi2.py,sha256=Jxf3ZyYDpS78l6uh4_LGGIcHawrOhZ1vHoHFVxRaK40,6789
|
49
|
-
ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
50
|
-
ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py,sha256=xPVvHQjLJHFiRv_-Fy2sDm0Aft7SG8SXiV6o3rF03cQ,3108
|
51
|
-
ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py,sha256=nUm0SQbCTmNAc5u-C9gbQRFPt7GDvUt6UjH6doTvH-I,6817
|
52
42
|
ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
53
|
-
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=
|
54
|
-
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=
|
55
|
-
ai_edge_torch/generative/examples/gemma/gemma.py,sha256=
|
56
|
-
ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=
|
57
|
-
ai_edge_torch/generative/examples/
|
58
|
-
ai_edge_torch/generative/examples/
|
59
|
-
ai_edge_torch/generative/examples/
|
43
|
+
ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=ZJvw8uFVu7FEJ7eXfpzn-pPKgPELoxkGz4Zg7LKKMSI,3048
|
44
|
+
ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=hM-fwjZG53p1UE_lkovLMmHRDHleJsb6_0ib0_k0v54,3040
|
45
|
+
ai_edge_torch/generative/examples/gemma/gemma.py,sha256=uejk9Mi85uRuFYIUi5XI58rf4K7TFeE5cZ1flejF8EE,7473
|
46
|
+
ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=H0scyAdqRyV2wwaFx1LAa3A5oYn1C5tTdPWvbDTd_SQ,10256
|
47
|
+
ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
48
|
+
ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=vqEpZVmB0_wMKcAl6RXm7W57DqPTzEdVVN6W2Z-QYzI,3011
|
49
|
+
ai_edge_torch/generative/examples/phi/phi2.py,sha256=wjTLCfCUDcLqvVsrPH-Wx04pOKeuigZCWHO3gL1WOEA,7072
|
50
|
+
ai_edge_torch/generative/examples/smallm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
51
|
+
ai_edge_torch/generative/examples/smallm/convert_to_tflite.py,sha256=aqqxQMBBO_dtGB1iZ1tpF8hbGpdZkx0VIz62ZqfVMCc,3036
|
52
|
+
ai_edge_torch/generative/examples/smallm/smallm.py,sha256=mzlbXxCCB10FN03QDRoPXw-cbucQM_O_Hs8hqLZAvck,4002
|
60
53
|
ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
61
54
|
ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
|
62
|
-
ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=
|
55
|
+
ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=evl5Rn_Hlp9-BsNmcf6liXa2syET3-Fz-zVaWjqPKx8,4657
|
63
56
|
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=7ra36nM5tQwSw-vi6QCFLx5IssZhT-6yVK4H3XsAc4w,5044
|
64
57
|
ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=slieF2-QcDCwd4DRZ7snsZIphT97IXpp4plRRsRSwL8,13983
|
65
58
|
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=7oUIJ6HO0vmlhFdkXpqGm9KTB-eM4Ob9VrHSDlIGFOg,30926
|
@@ -74,29 +67,28 @@ ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py,sha256=ZE6H
|
|
74
67
|
ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py,sha256=RxR5rw0wFFm_5CfAY-3-EIz83vhM9EKye8Bb5zBb0Ok,1341
|
75
68
|
ai_edge_torch/generative/examples/t5/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
76
69
|
ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=CZVuNEL8OHPkdsz70WOvNpTJ9LFkiDnlwgJiXfUZCVk,4548
|
77
|
-
ai_edge_torch/generative/examples/t5/t5.py,sha256=
|
78
|
-
ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=
|
70
|
+
ai_edge_torch/generative/examples/t5/t5.py,sha256=Ekg92OwIXSkSRii9OY-mp3-SExtsxOdoIDTFxm25hso,21304
|
71
|
+
ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=l01oYyJo77INzRwN4xqXquaFQPvCFBFF5zOnmGVb3Hg,8731
|
79
72
|
ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
80
|
-
ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=
|
81
|
-
ai_edge_torch/generative/examples/test_models/
|
82
|
-
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=mQkcpSe6HlRLMkIRCEHc9ZXL7jxEp9RWSGUQjjd-r2w,4841
|
73
|
+
ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=QyLeCqDnk71WvvFH68g9UeF-HytonSk1ItGF9dc7Zj8,5854
|
74
|
+
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=oX_D_kU9PegBX3Fx9z_J3a1Oh2PF05F0nwZNxyLgQNA,5880
|
83
75
|
ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
84
|
-
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=
|
85
|
-
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=
|
76
|
+
ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=y4LiWhwgflqrg4WWh3wq5ei3VOT_cV0A62x62qptQiM,3070
|
77
|
+
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=Mnn_aMImR1CpC_T0CMKlp3XgoLyR7N56VR3blVSnMHQ,7007
|
86
78
|
ai_edge_torch/generative/fx_passes/__init__.py,sha256=fmNNXawJ722M4cTUuTx289rT0NHxBEsOy_k8baqCOms,1173
|
87
79
|
ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=sXis0U4u-RoIp_NyrmWJNnqFqpqRuZOrhfsJIO6rMps,2028
|
88
80
|
ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
89
|
-
ai_edge_torch/generative/layers/attention.py,sha256=
|
81
|
+
ai_edge_torch/generative/layers/attention.py,sha256=d9yLaqxPCtClhNUmauOEFBKxhLnsXdN3NiYy1WspIPI,12826
|
90
82
|
ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
|
91
|
-
ai_edge_torch/generative/layers/builder.py,sha256=
|
83
|
+
ai_edge_torch/generative/layers/builder.py,sha256=6jDNaa_djF32AjxIJtaDGBzlj3zlvl1yZivK3gC4j94,4424
|
92
84
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=uto7xtwx6jPkk1GZ2x7pSTentQzRrPSKw4_PSE12ahA,3525
|
93
|
-
ai_edge_torch/generative/layers/kv_cache.py,sha256=
|
94
|
-
ai_edge_torch/generative/layers/model_config.py,sha256=
|
95
|
-
ai_edge_torch/generative/layers/normalization.py,sha256=
|
85
|
+
ai_edge_torch/generative/layers/kv_cache.py,sha256=FveTTO0z_yi0-ZdGMuamzSvuInn6B4lesKZ4PHT2Vmg,6088
|
86
|
+
ai_edge_torch/generative/layers/model_config.py,sha256=mil4RkGuNFBDKo3gPd9QnfGKLKPZWX9Gz2_q9hX8sNU,6407
|
87
|
+
ai_edge_torch/generative/layers/normalization.py,sha256=iod9oNkoDS5m-yFY_Y_XMyvCU5a88ESd_s5WY34ErKA,6129
|
96
88
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
|
97
89
|
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=VW-VP8e7FTSPCdu-6DVxpwNrIdgX0R_kq6F6MSEiyXE,3848
|
98
90
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
99
|
-
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=
|
91
|
+
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=cpygyJccLq6KHKxV7oz4YKh529YLjC9isupnsVmPi0A,27190
|
100
92
|
ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
|
101
93
|
ai_edge_torch/generative/layers/unet/model_config.py,sha256=NvBJj09a7ZC-ChGE_ex-_kLnE_fjzrY6txbLSh1pMKA,9208
|
102
94
|
ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
@@ -107,15 +99,16 @@ ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=4fgmP_GgeiFUOkIaC
|
|
107
99
|
ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FBxkHM6RJ3C14B2I1mjItjc,2030
|
108
100
|
ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
|
109
101
|
ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
110
|
-
ai_edge_torch/generative/test/
|
111
|
-
ai_edge_torch/generative/test/test_loader.py,sha256=
|
112
|
-
ai_edge_torch/generative/test/test_model_conversion.py,sha256=
|
113
|
-
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=
|
102
|
+
ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
|
103
|
+
ai_edge_torch/generative/test/test_loader.py,sha256=8y74ChO3CZCfEi1eCf3-w47kRgAI4qPYCXpi8rTQXMA,3378
|
104
|
+
ai_edge_torch/generative/test/test_model_conversion.py,sha256=SIv7_sc5qHvbHFN8SbAfY00iXGvH7J6cJLkERU_cd5k,5888
|
105
|
+
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=F3q3K9ZgWBzlLy4WpE8-w6UWSuJ-UoJwMm3N6Zb3Y14,5016
|
114
106
|
ai_edge_torch/generative/test/test_quantize.py,sha256=kY_NRpF-v1i4clqI1CFFWEagJv-5PzBDkeJ2fInl9_w,5913
|
107
|
+
ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
|
115
108
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
116
|
-
ai_edge_torch/generative/utilities/loader.py,sha256=
|
109
|
+
ai_edge_torch/generative/utilities/loader.py,sha256=kn4TCgGAG8s4mdvPITimOBCaVyn04Ksz4gZIleFYF1o,12754
|
117
110
|
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=pKp3AMSbS3otCvgwJRF5M1l4JRNKk-aCKimXzIMSrds,35679
|
118
|
-
ai_edge_torch/generative/utilities/t5_loader.py,sha256=
|
111
|
+
ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
|
119
112
|
ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
|
120
113
|
ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
|
121
114
|
ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
|
@@ -145,11 +138,12 @@ ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW
|
|
145
138
|
ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkgH7x38qPlwUUpD7S15Q,730
|
146
139
|
ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=drN3L0uTsSjkluKgt6Ngq7b5HLReE_7iAitHpZ9PKqE,5428
|
147
140
|
ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
|
148
|
-
ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=
|
141
|
+
ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=dE_qzh-OnCNjWzqs1-PHs5PNlRF726qMQKM3tkwAzEs,959
|
149
142
|
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=wV8AUK8dvjLUy3qjqw_IxpiYVDWUMPNZRfi3XYE_hDs,6972
|
150
143
|
ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
|
151
144
|
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=v1VdKmL8YLJv3PR9VgyNghO83A25PpTzY2ZUAJqlq3Q,6847
|
152
|
-
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=
|
145
|
+
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=Ii1akrKLhRTkZ715JxXBBGKv3jGfXReXMQCYNzSnxmM,10567
|
146
|
+
ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=1ePJs7oIdUkVdMddFsXMc53qTkEKqGz0ZhQQoNzBa10,2862
|
153
147
|
ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
|
154
148
|
ai_edge_torch/odml_torch/lowerings/registry.py,sha256=ES3x_RJ22T5rlmMrlomex2DdcZbhlyVJ7_HS3rjz3Uk,2851
|
155
149
|
ai_edge_torch/odml_torch/lowerings/utils.py,sha256=NczqpsSd3Fn7yVcPC3qllemiZxxDAZgcW1T5l8-W9fE,5593
|
@@ -161,8 +155,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
161
155
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
162
156
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
163
157
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
164
|
-
ai_edge_torch_nightly-0.3.0.
|
165
|
-
ai_edge_torch_nightly-0.3.0.
|
166
|
-
ai_edge_torch_nightly-0.3.0.
|
167
|
-
ai_edge_torch_nightly-0.3.0.
|
168
|
-
ai_edge_torch_nightly-0.3.0.
|
158
|
+
ai_edge_torch_nightly-0.3.0.dev20240913.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
159
|
+
ai_edge_torch_nightly-0.3.0.dev20240913.dist-info/METADATA,sha256=ahbsMN1e0Tuq_LmrkB6NE-VgVTC65KEiZX3VVmTbcWQ,1859
|
160
|
+
ai_edge_torch_nightly-0.3.0.dev20240913.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
161
|
+
ai_edge_torch_nightly-0.3.0.dev20240913.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
162
|
+
ai_edge_torch_nightly-0.3.0.dev20240913.dist-info/RECORD,,
|