ai-edge-torch-nightly 0.3.0.dev20240910__py3-none-any.whl → 0.3.0.dev20240913__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py +31 -12
  2. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +25 -6
  3. ai_edge_torch/generative/examples/gemma/gemma.py +50 -30
  4. ai_edge_torch/generative/examples/gemma/gemma2.py +85 -58
  5. ai_edge_torch/generative/examples/{experimental/phi → phi}/convert_to_tflite.py +11 -12
  6. ai_edge_torch/generative/examples/{experimental/phi → phi}/phi2.py +46 -43
  7. ai_edge_torch/generative/examples/{experimental/gemma → smallm}/convert_to_tflite.py +12 -14
  8. ai_edge_torch/generative/examples/smallm/smallm.py +122 -0
  9. ai_edge_torch/generative/examples/stable_diffusion/clip.py +11 -5
  10. ai_edge_torch/generative/examples/t5/t5.py +35 -22
  11. ai_edge_torch/generative/examples/t5/t5_attention.py +18 -13
  12. ai_edge_torch/generative/examples/test_models/toy_model.py +15 -13
  13. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +74 -33
  14. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +25 -6
  15. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +55 -34
  16. ai_edge_torch/generative/layers/attention.py +77 -73
  17. ai_edge_torch/generative/layers/builder.py +5 -3
  18. ai_edge_torch/generative/layers/kv_cache.py +163 -51
  19. ai_edge_torch/generative/layers/model_config.py +38 -19
  20. ai_edge_torch/generative/layers/normalization.py +158 -0
  21. ai_edge_torch/generative/layers/unet/blocks_2d.py +0 -2
  22. ai_edge_torch/generative/test/{test_experimental_ekv.py → test_kv_cache.py} +12 -24
  23. ai_edge_torch/generative/test/test_loader.py +1 -1
  24. ai_edge_torch/generative/test/test_model_conversion.py +72 -34
  25. ai_edge_torch/generative/test/test_model_conversion_large.py +51 -23
  26. ai_edge_torch/generative/test/utils.py +54 -0
  27. ai_edge_torch/generative/utilities/loader.py +15 -15
  28. ai_edge_torch/generative/utilities/t5_loader.py +21 -20
  29. ai_edge_torch/odml_torch/lowerings/__init__.py +1 -0
  30. ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +0 -1
  31. ai_edge_torch/odml_torch/lowerings/_layer_norm.py +78 -0
  32. ai_edge_torch/version.py +1 -1
  33. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/METADATA +1 -1
  34. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/RECORD +39 -45
  35. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +0 -219
  36. ai_edge_torch/generative/examples/experimental/phi/__init__.py +0 -14
  37. ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py +0 -14
  38. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +0 -87
  39. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +0 -205
  40. ai_edge_torch/generative/examples/phi2/__init__.py +0 -14
  41. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +0 -67
  42. ai_edge_torch/generative/examples/phi2/phi2.py +0 -189
  43. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +0 -176
  44. /ai_edge_torch/generative/examples/{experimental → phi}/__init__.py +0 -0
  45. /ai_edge_torch/generative/examples/{experimental/gemma → smallm}/__init__.py +0 -0
  46. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/LICENSE +0 -0
  47. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/WHEEL +0 -0
  48. {ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/utilities/loader.py CHANGED
@@ -221,7 +221,8 @@ class ModelLoader:
221
221
  converted_state: Dict[str, torch.Tensor],
222
222
  ):
223
223
  prefix = f"transformer_blocks.{idx}"
224
- if config.ff_config.type == model_config.FeedForwardType.SEQUENTIAL:
224
+ ff_config = config.block_config(idx).ff_config
225
+ if ff_config.type == model_config.FeedForwardType.SEQUENTIAL:
225
226
  ff_up_proj_name = self._names.ff_up_proj.format(idx)
226
227
  ff_down_proj_name = self._names.ff_down_proj.format(idx)
227
228
  converted_state[f"{prefix}.ff.w1.weight"] = state.pop(
@@ -230,7 +231,7 @@ class ModelLoader:
230
231
  converted_state[f"{prefix}.ff.w2.weight"] = state.pop(
231
232
  f"{ff_down_proj_name}.weight"
232
233
  )
233
- if config.ff_config.use_bias:
234
+ if ff_config.use_bias:
234
235
  converted_state[f"{prefix}.ff.w1.bias"] = state.pop(
235
236
  f"{ff_up_proj_name}.bias"
236
237
  )
@@ -250,7 +251,7 @@ class ModelLoader:
250
251
  converted_state[f"{prefix}.ff.w1.weight"] = state.pop(
251
252
  f"{ff_gate_proj_name}.weight"
252
253
  )
253
- if config.ff_config.use_bias:
254
+ if ff_config.use_bias:
254
255
  converted_state[f"{prefix}.ff.w3.bias"] = state.pop(
255
256
  f"{ff_up_proj_name}.bias"
256
257
  )
@@ -289,6 +290,7 @@ class ModelLoader:
289
290
  converted_state: Dict[str, torch.Tensor],
290
291
  ):
291
292
  prefix = f"transformer_blocks.{idx}"
293
+ attn_config = config.block_config(idx).attn_config
292
294
  if self._names.attn_fused_qkv_proj:
293
295
  fused_qkv_name = self._names.attn_fused_qkv_proj.format(idx)
294
296
  converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = state.pop(
@@ -300,13 +302,13 @@ class ModelLoader:
300
302
  v_name = self._names.attn_value_proj.format(idx)
301
303
  converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = (
302
304
  self._fuse_qkv(
303
- config,
305
+ attn_config,
304
306
  state.pop(f"{q_name}.weight"),
305
307
  state.pop(f"{k_name}.weight"),
306
308
  state.pop(f"{v_name}.weight"),
307
309
  )
308
310
  )
309
- if config.attn_config.qkv_use_bias:
311
+ if attn_config.qkv_use_bias:
310
312
  if self._names.attn_fused_qkv_proj:
311
313
  converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = state.pop(
312
314
  f"{fused_qkv_name}.bias"
@@ -314,7 +316,7 @@ class ModelLoader:
314
316
  else:
315
317
  converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = (
316
318
  self._fuse_qkv(
317
- config,
319
+ attn_config,
318
320
  state.pop(f"{q_name}.bias"),
319
321
  state.pop(f"{k_name}.bias"),
320
322
  state.pop(f"{v_name}.bias"),
@@ -325,7 +327,7 @@ class ModelLoader:
325
327
  converted_state[f"{prefix}.atten_func.output_projection.weight"] = (
326
328
  state.pop(f"{o_name}.weight")
327
329
  )
328
- if config.attn_config.output_proj_use_bias:
330
+ if attn_config.output_proj_use_bias:
329
331
  converted_state[f"{prefix}.atten_func.output_projection.bias"] = (
330
332
  state.pop(f"{o_name}.bias")
331
333
  )
@@ -360,18 +362,16 @@ class ModelLoader:
360
362
 
361
363
  def _fuse_qkv(
362
364
  self,
363
- config: model_config.ModelConfig,
365
+ attn_config: model_config.AttentionConfig,
364
366
  q: torch.Tensor,
365
367
  k: torch.Tensor,
366
368
  v: torch.Tensor,
367
369
  ) -> torch.Tensor:
368
- if config.attn_config.qkv_fused_interleaved:
369
- q_per_kv = (
370
- config.attn_config.num_heads // config.attn_config.num_query_groups
371
- )
372
- qs = torch.split(q, config.attn_config.head_dim * q_per_kv)
373
- ks = torch.split(k, config.attn_config.head_dim)
374
- vs = torch.split(v, config.attn_config.head_dim)
370
+ if attn_config.qkv_fused_interleaved:
371
+ q_per_kv = attn_config.num_heads // attn_config.num_query_groups
372
+ qs = torch.split(q, attn_config.head_dim * q_per_kv)
373
+ ks = torch.split(k, attn_config.head_dim)
374
+ vs = torch.split(v, attn_config.head_dim)
375
375
  cycled = [t for group in zip(qs, ks, vs) for t in group]
376
376
  return torch.cat(cycled)
377
377
  else:
ai_edge_torch/generative/utilities/t5_loader.py CHANGED
@@ -279,7 +279,8 @@ class ModelLoader:
279
279
  prefix = additional_prefix + f"transformer_blocks.{idx}"
280
280
  if names.ff_up_proj is None or names.ff_down_proj is None:
281
281
  return
282
- if config.ff_config.type == model_config.FeedForwardType.SEQUENTIAL:
282
+ ff_config = config.block_config(idx).ff_config
283
+ if ff_config.type == model_config.FeedForwardType.SEQUENTIAL:
283
284
  ff_up_proj_name = names.ff_up_proj.format(idx)
284
285
  ff_down_proj_name = names.ff_down_proj.format(idx)
285
286
  converted_state[f"{prefix}.ff.w1.weight"] = state.pop(
@@ -288,7 +289,7 @@ class ModelLoader:
288
289
  converted_state[f"{prefix}.ff.w2.weight"] = state.pop(
289
290
  f"{ff_down_proj_name}.weight"
290
291
  )
291
- if config.ff_config.use_bias:
292
+ if ff_config.use_bias:
292
293
  converted_state[f"{prefix}.ff.w1.bias"] = state.pop(
293
294
  f"{ff_up_proj_name}.bias"
294
295
  )
@@ -309,7 +310,7 @@ class ModelLoader:
309
310
  converted_state[f"{prefix}.ff.w1.weight"] = state.pop(
310
311
  f"{ff_gate_proj_name}.weight"
311
312
  )
312
- if config.ff_config.use_bias:
313
+ if ff_config.use_bias:
313
314
  converted_state[f"{prefix}.ff.w3.bias"] = state.pop(
314
315
  f"{ff_up_proj_name}.bias"
315
316
  )
@@ -337,20 +338,21 @@ class ModelLoader:
337
338
  ):
338
339
  return
339
340
  prefix = additional_prefix + f"transformer_blocks.{idx}"
341
+ attn_config = config.block_config(idx).attn_config
340
342
  q_name = names.attn_query_proj.format(idx)
341
343
  k_name = names.attn_key_proj.format(idx)
342
344
  v_name = names.attn_value_proj.format(idx)
343
345
  # model.encoder.transformer_blocks[0].atten_func.q_projection.weight
344
346
  if fuse_attention:
345
347
  converted_state[f"{prefix}.atten_func.attn.weight"] = self._fuse_qkv(
346
- config,
348
+ attn_config,
347
349
  state.pop(f"{q_name}.weight"),
348
350
  state.pop(f"{k_name}.weight"),
349
351
  state.pop(f"{v_name}.weight"),
350
352
  )
351
- if config.attn_config.qkv_use_bias:
353
+ if attn_config.qkv_use_bias:
352
354
  converted_state[f"{prefix}.atten_func.attn.bias"] = self._fuse_qkv(
353
- config,
355
+ attn_config,
354
356
  state.pop(f"{q_name}.bias"),
355
357
  state.pop(f"{k_name}.bias"),
356
358
  state.pop(f"{v_name}.bias"),
@@ -365,7 +367,7 @@ class ModelLoader:
365
367
  converted_state[f"{prefix}.atten_func.v_projection.weight"] = state.pop(
366
368
  f"{v_name}.weight"
367
369
  )
368
- if config.attn_config.qkv_use_bias:
370
+ if attn_config.qkv_use_bias:
369
371
  converted_state[f"{prefix}.atten_func.q_projection.bias"] = state.pop(
370
372
  f"{q_name}.bias"
371
373
  )
@@ -380,7 +382,7 @@ class ModelLoader:
380
382
  converted_state[f"{prefix}.atten_func.output_projection.weight"] = (
381
383
  state.pop(f"{o_name}.weight")
382
384
  )
383
- if config.attn_config.output_proj_use_bias:
385
+ if attn_config.output_proj_use_bias:
384
386
  converted_state[f"{prefix}.atten_func.output_projection.bias"] = (
385
387
  state.pop(f"{o_name}.bias")
386
388
  )
@@ -402,6 +404,7 @@ class ModelLoader:
402
404
  ):
403
405
  return
404
406
  prefix = additional_prefix + f"transformer_blocks.{idx}"
407
+ attn_config = config.block_config(idx).attn_config
405
408
  q_name = names.cross_attn_query_proj.format(idx)
406
409
  k_name = names.cross_attn_key_proj.format(idx)
407
410
  v_name = names.cross_attn_value_proj.format(idx)
@@ -409,16 +412,16 @@ class ModelLoader:
409
412
  if fuse_attention:
410
413
  converted_state[f"{prefix}.cross_atten_func.attn.weight"] = (
411
414
  self._fuse_qkv(
412
- config,
415
+ attn_config,
413
416
  state.pop(f"{q_name}.weight"),
414
417
  state.pop(f"{k_name}.weight"),
415
418
  state.pop(f"{v_name}.weight"),
416
419
  )
417
420
  )
418
- if config.attn_config.qkv_use_bias:
421
+ if attn_config.qkv_use_bias:
419
422
  converted_state[f"{prefix}.cross_atten_func.attn.bias"] = (
420
423
  self._fuse_qkv(
421
- config,
424
+ attn_config,
422
425
  state.pop(f"{q_name}.bias"),
423
426
  state.pop(f"{k_name}.bias"),
424
427
  state.pop(f"{v_name}.bias"),
@@ -434,7 +437,7 @@ class ModelLoader:
434
437
  converted_state[f"{prefix}.cross_atten_func.v_projection.weight"] = (
435
438
  state.pop(f"{v_name}.weight")
436
439
  )
437
- if config.attn_config.qkv_use_bias:
440
+ if attn_config.qkv_use_bias:
438
441
  converted_state[f"{prefix}.cross_atten_func.q_projection.bias"] = (
439
442
  state.pop(f"{q_name}.bias")
440
443
  )
@@ -449,7 +452,7 @@ class ModelLoader:
449
452
  converted_state[f"{prefix}.cross_atten_func.output_projection.weight"] = (
450
453
  state.pop(f"{o_name}.weight")
451
454
  )
452
- if config.attn_config.output_proj_use_bias:
455
+ if attn_config.output_proj_use_bias:
453
456
  converted_state[f"{prefix}.cross_atten_func.output_projection.bias"] = (
454
457
  state.pop(f"{o_name}.bias")
455
458
  )
@@ -496,16 +499,14 @@ class ModelLoader:
496
499
 
497
500
  def _fuse_qkv(
498
501
  self,
499
- config: model_config.ModelConfig,
502
+ attn_config: model_config.AttentionConfig,
500
503
  q: torch.Tensor,
501
504
  k: torch.Tensor,
502
505
  v: torch.Tensor,
503
506
  ) -> torch.Tensor:
504
- q_per_kv = (
505
- config.attn_config.num_heads // config.attn_config.num_query_groups
506
- )
507
- qs = torch.split(q, config.attn_config.head_dim * q_per_kv)
508
- ks = torch.split(k, config.attn_config.head_dim)
509
- vs = torch.split(v, config.attn_config.head_dim)
507
+ q_per_kv = attn_config.num_heads // attn_config.num_query_groups
508
+ qs = torch.split(q, attn_config.head_dim * q_per_kv)
509
+ ks = torch.split(k, attn_config.head_dim)
510
+ vs = torch.split(v, attn_config.head_dim)
510
511
  cycled = [t for group in zip(qs, ks, vs) for t in group]
511
512
  return torch.cat(cycled)
ai_edge_torch/odml_torch/lowerings/__init__.py CHANGED
@@ -16,6 +16,7 @@ from . import _basic
16
16
  from . import _batch_norm
17
17
  from . import _convolution
18
18
  from . import _jax_lowerings
19
+ from . import _layer_norm
19
20
  from . import context
20
21
  from . import registry
21
22
  from . import utils
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py CHANGED
@@ -167,7 +167,6 @@ lower_by_torch_xla2(torch.ops.aten.mul.Scalar)
167
167
  lower_by_torch_xla2(torch.ops.aten.mul.Tensor)
168
168
  lower_by_torch_xla2(torch.ops.aten.native_batch_norm)
169
169
  lower_by_torch_xla2(torch.ops.aten.native_group_norm)
170
- lower_by_torch_xla2(torch.ops.aten.native_layer_norm)
171
170
  lower_by_torch_xla2(torch.ops.aten.native_layer_norm_backward)
172
171
  lower_by_torch_xla2(torch.ops.aten.ne)
173
172
  lower_by_torch_xla2(torch.ops.aten.neg)
ai_edge_torch/odml_torch/lowerings/_layer_norm.py ADDED
@@ -0,0 +1,78 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Provides lowering for coreaten to stablehlo for LayerNorm."""
16
+
17
+ import math
18
+ from typing import Optional
19
+ from ai_edge_torch.odml_torch.lowerings import registry
20
+ from ai_edge_torch.odml_torch.lowerings import utils
21
+ from jax._src.lib.mlir import ir
22
+ from jax._src.lib.mlir.dialects import hlo as stablehlo
23
+ import torch
24
+
25
+
26
+ # native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight,
27
+ # Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)
28
+ @registry.lower(torch.ops.aten.native_layer_norm)
29
+ def _aten_native_layer_norm(
30
+ lctx,
31
+ data: ir.Value,
32
+ normalized_shape: list[int],
33
+ weight: Optional[ir.Value],
34
+ bias: Optional[ir.Value],
35
+ eps: float,
36
+ ):
37
+ data_type: ir.RankedTensorType = data.type
38
+ unnormalized_count = math.prod(data_type.shape) // math.prod(normalized_shape)
39
+ dest_shape = [
40
+ 1,
41
+ unnormalized_count,
42
+ math.prod(normalized_shape),
43
+ ]
44
+ dest_type = ir.RankedTensorType.get(dest_shape, data_type.element_type)
45
+
46
+ reshaped_data = stablehlo.reshape(dest_type, data)
47
+
48
+ one = utils.splat(1, data_type.element_type, [unnormalized_count])
49
+ zero = utils.splat(0, data_type.element_type, [unnormalized_count])
50
+ output, mean, var = stablehlo.batch_norm_training(
51
+ reshaped_data, one, zero, eps, 1
52
+ )
53
+ eps_splat = utils.splat(eps, var.type.element_type, var.type.shape)
54
+ rstd = stablehlo.rsqrt(stablehlo.add(var, eps_splat))
55
+
56
+ stats_shape = data_type.shape[: -1 * len(normalized_shape)] + [1] * len(
57
+ normalized_shape
58
+ )
59
+ stats_type = ir.RankedTensorType.get(stats_shape, data_type.element_type)
60
+ mean = stablehlo.reshape(stats_type, mean)
61
+ rstd = stablehlo.reshape(stats_type, rstd)
62
+
63
+ output = stablehlo.reshape(data_type, output)
64
+
65
+ data_rank = len(data_type.shape)
66
+ normalized_rank = len(normalized_shape)
67
+ if weight is not None:
68
+ weight = stablehlo.broadcast_in_dim(
69
+ data_type, weight, list(range(data_rank - normalized_rank, data_rank))
70
+ )
71
+ output = stablehlo.multiply(weight, output)
72
+ if bias is not None:
73
+ bias = stablehlo.broadcast_in_dim(
74
+ data_type, bias, list(range(data_rank - normalized_rank, data_rank))
75
+ )
76
+ output = stablehlo.add(bias, output)
77
+
78
+ return output, mean, rstd
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20240910"
16
+ __version__ = "0.3.0.dev20240913"
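{ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@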
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20240910
3
+ Version: 0.3.0.dev20240913
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
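{ai_edge_torch_nightly-0.3.0.dev20240910.dist-info → ai_edge_torch_nightly-0.3.0.dev20240913.dist-info}/RECORD CHANGED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116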
2
2
  ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
5
- ai_edge_torch/version.py,sha256=e4sh_RFYgNHGoVuOeICnFZtLu1MQCNv7qpq94nKFarU,706
5
+ ai_edge_torch/version.py,sha256=2_ahYhvytovu9mWRifMKeqx6-0JbD7-iV5FXU890d7Y,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -39,27 +39,20 @@ ai_edge_torch/debug/test/test_search_model.py,sha256=-RuU0QsjqkfzZF2IbeA55MoeVOa
39
39
  ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
40
40
  ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
41
41
  ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
42
- ai_edge_torch/generative/examples/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
43
- ai_edge_torch/generative/examples/experimental/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
44
- ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py,sha256=lpiPFSh3SJd6WwuZ0QegSva3__iSz2tUD7L7QfkAe4I,3085
45
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py,sha256=aCoD86pf4nuquUMk7MOR-jsN5FqvySSEuMx9Psxjblk,7261
46
- ai_edge_torch/generative/examples/experimental/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
47
- ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py,sha256=DavrdGmqUgoThsGNRv3LXMW5tvJdYEvj66Hf1XRqkXU,3055
48
- ai_edge_torch/generative/examples/experimental/phi/phi2.py,sha256=Jxf3ZyYDpS78l6uh4_LGGIcHawrOhZ1vHoHFVxRaK40,6789
49
- ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
50
- ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py,sha256=xPVvHQjLJHFiRv_-Fy2sDm0Aft7SG8SXiV6o3rF03cQ,3108
51
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py,sha256=nUm0SQbCTmNAc5u-C9gbQRFPt7GDvUt6UjH6doTvH-I,6817
52
42
  ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
53
- ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=pseJExH35lSAK0ZtzSHB1sFtRtF_EuT2xcSpGU0gKVI,2524
54
- ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=w589IJETATd6Z9_1XCIWbrlCV3E92X_5ac3VVCVFXG0,2522
55
- ai_edge_torch/generative/examples/gemma/gemma.py,sha256=lc1-CfIObHj9D5VJy78BOtGTrQM4TYMI6NfVi8KM5qA,6747
56
- ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=OcUQLFR136e3QRVXRnmtYnRHXyHJS9EYEFlJ1ymXyRY,8859
57
- ai_edge_torch/generative/examples/phi2/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
58
- ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=ON6zLO-nFS8eJ2yhyWzT5x2Somr-Ca-VjpjT7OGFU10,2506
59
- ai_edge_torch/generative/examples/phi2/phi2.py,sha256=FFnhv1kx4fHRhSeOreLGj8kAqPnmkz9pD1RRSDVlM_w,6332
43
+ ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=ZJvw8uFVu7FEJ7eXfpzn-pPKgPELoxkGz4Zg7LKKMSI,3048
44
+ ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=hM-fwjZG53p1UE_lkovLMmHRDHleJsb6_0ib0_k0v54,3040
45
+ ai_edge_torch/generative/examples/gemma/gemma.py,sha256=uejk9Mi85uRuFYIUi5XI58rf4K7TFeE5cZ1flejF8EE,7473
46
+ ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=H0scyAdqRyV2wwaFx1LAa3A5oYn1C5tTdPWvbDTd_SQ,10256
47
+ ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
48
+ ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=vqEpZVmB0_wMKcAl6RXm7W57DqPTzEdVVN6W2Z-QYzI,3011
49
+ ai_edge_torch/generative/examples/phi/phi2.py,sha256=wjTLCfCUDcLqvVsrPH-Wx04pOKeuigZCWHO3gL1WOEA,7072
50
+ ai_edge_torch/generative/examples/smallm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
51
+ ai_edge_torch/generative/examples/smallm/convert_to_tflite.py,sha256=aqqxQMBBO_dtGB1iZ1tpF8hbGpdZkx0VIz62ZqfVMCc,3036
52
+ ai_edge_torch/generative/examples/smallm/smallm.py,sha256=mzlbXxCCB10FN03QDRoPXw-cbucQM_O_Hs8hqLZAvck,4002
60
53
  ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
61
54
  ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
62
- ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=0WniBWQ6_NcQc5WycX3YRRX7Os9AGQSxfc1m2HKBqg8,4479
55
+ ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=evl5Rn_Hlp9-BsNmcf6liXa2syET3-Fz-zVaWjqPKx8,4657
63
56
  ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=7ra36nM5tQwSw-vi6QCFLx5IssZhT-6yVK4H3XsAc4w,5044
64
57
  ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=slieF2-QcDCwd4DRZ7snsZIphT97IXpp4plRRsRSwL8,13983
65
58
  ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=7oUIJ6HO0vmlhFdkXpqGm9KTB-eM4Ob9VrHSDlIGFOg,30926
@@ -74,29 +67,28 @@ ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py,sha256=ZE6H
74
67
  ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py,sha256=RxR5rw0wFFm_5CfAY-3-EIz83vhM9EKye8Bb5zBb0Ok,1341
75
68
  ai_edge_torch/generative/examples/t5/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
76
69
  ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=CZVuNEL8OHPkdsz70WOvNpTJ9LFkiDnlwgJiXfUZCVk,4548
77
- ai_edge_torch/generative/examples/t5/t5.py,sha256=Zobw5BV-PC0nlU9Z6fzb2O07rMeU8vGIk-KtKp9D_H0,20871
78
- ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=1lvbSlzyBwmd5Bs7-Up_v4iJQkCPIJx2RmMkLgy7l2Q,8508
70
+ ai_edge_torch/generative/examples/t5/t5.py,sha256=Ekg92OwIXSkSRii9OY-mp3-SExtsxOdoIDTFxm25hso,21304
71
+ ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=l01oYyJo77INzRwN4xqXquaFQPvCFBFF5zOnmGVb3Hg,8731
79
72
  ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
80
- ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=5wj2RmQRIwD6O_R_pp-A_7gKGSdHWDSXyis97r1ELVI,5622
81
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=l9swUKTcDtnTibNSNExaMgLvDeJ4Er2tVh5ZW1EtRgk,5809
82
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=mQkcpSe6HlRLMkIRCEHc9ZXL7jxEp9RWSGUQjjd-r2w,4841
73
+ ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=QyLeCqDnk71WvvFH68g9UeF-HytonSk1ItGF9dc7Zj8,5854
74
+ ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=oX_D_kU9PegBX3Fx9z_J3a1Oh2PF05F0nwZNxyLgQNA,5880
83
75
  ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
84
- ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=CLRqO7ycMbpy7J3_Czp1sLx6hcdwGD9zVq04yRba0e8,2550
85
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=4ku0ni3MOWamhPrzLap0BmtdNFk7CH0hwjPNoRAKpvQ,6278
76
+ ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=y4LiWhwgflqrg4WWh3wq5ei3VOT_cV0A62x62qptQiM,3070
77
+ ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=Mnn_aMImR1CpC_T0CMKlp3XgoLyR7N56VR3blVSnMHQ,7007
86
78
  ai_edge_torch/generative/fx_passes/__init__.py,sha256=fmNNXawJ722M4cTUuTx289rT0NHxBEsOy_k8baqCOms,1173
87
79
  ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=sXis0U4u-RoIp_NyrmWJNnqFqpqRuZOrhfsJIO6rMps,2028
88
80
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
89
- ai_edge_torch/generative/layers/attention.py,sha256=2UujQePRJ1LK02PN-hGcuMu0ooCJC6ETfPvzEYVFyho,12284
81
+ ai_edge_torch/generative/layers/attention.py,sha256=d9yLaqxPCtClhNUmauOEFBKxhLnsXdN3NiYy1WspIPI,12826
90
82
  ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
91
- ai_edge_torch/generative/layers/builder.py,sha256=xb7rjADv3Jm4qfmlYtg6oLLe7ReDE9UjsEqiejPpDD8,4346
83
+ ai_edge_torch/generative/layers/builder.py,sha256=6jDNaa_djF32AjxIJtaDGBzlj3zlvl1yZivK3gC4j94,4424
92
84
  ai_edge_torch/generative/layers/feed_forward.py,sha256=uto7xtwx6jPkk1GZ2x7pSTentQzRrPSKw4_PSE12ahA,3525
93
- ai_edge_torch/generative/layers/kv_cache.py,sha256=Ob8QeXWW5xt-6hcGA0uoC48eRQ8lfvKca8JbWtFx2CE,3082
94
- ai_edge_torch/generative/layers/model_config.py,sha256=WpZ9djUBAZddyeSODHDaVMG37EQqfzGGrlMPi8AA-Hc,5752
95
- ai_edge_torch/generative/layers/normalization.py,sha256=u8lv0p-ktKcRqCDlOqZQa9WQcfDK9JM2IaUQFQdn7xs,1860
85
+ ai_edge_torch/generative/layers/kv_cache.py,sha256=FveTTO0z_yi0-ZdGMuamzSvuInn6B4lesKZ4PHT2Vmg,6088
86
+ ai_edge_torch/generative/layers/model_config.py,sha256=mil4RkGuNFBDKo3gPd9QnfGKLKPZWX9Gz2_q9hX8sNU,6407
87
+ ai_edge_torch/generative/layers/normalization.py,sha256=iod9oNkoDS5m-yFY_Y_XMyvCU5a88ESd_s5WY34ErKA,6129
96
88
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
97
89
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=VW-VP8e7FTSPCdu-6DVxpwNrIdgX0R_kq6F6MSEiyXE,3848
98
90
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
99
- ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=V4zUAqjWeBseMPG9B-93LDv1LM3Dds6Q-H0NxY0koSA,27212
91
+ ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=cpygyJccLq6KHKxV7oz4YKh529YLjC9isupnsVmPi0A,27190
100
92
  ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
101
93
  ai_edge_torch/generative/layers/unet/model_config.py,sha256=NvBJj09a7ZC-ChGE_ex-_kLnE_fjzrY6txbLSh1pMKA,9208
102
94
  ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -107,15 +99,16 @@ ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=4fgmP_GgeiFUOkIaC
107
99
  ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FBxkHM6RJ3C14B2I1mjItjc,2030
108
100
  ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
109
101
  ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
110
- ai_edge_torch/generative/test/test_experimental_ekv.py,sha256=8qv_eVtJW9GPvBEf2hPQe3tpdJ33XShya6MCX1FqrZM,4355
111
- ai_edge_torch/generative/test/test_loader.py,sha256=_y5EHGgoNOmCuYonsB81UJScHVsTAQXUVd44czMAw6k,3379
112
- ai_edge_torch/generative/test/test_model_conversion.py,sha256=b3InJ8Rx03YtHpE9h-j0pSXAY1cCf-dLlx4Y5LSJnRQ,5174
113
- ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=9JXcd-rX8MpsYeEWUFEXf783GOwYOLY64KzDfFdmRJ8,4484
102
+ ai_edge_torch/generative/test/test_kv_cache.py,sha256=W6Bh0gYDzmwb0j9HdD5_D7Z7FPToP2HSyFrmwIXuFqo,3793
103
+ ai_edge_torch/generative/test/test_loader.py,sha256=8y74ChO3CZCfEi1eCf3-w47kRgAI4qPYCXpi8rTQXMA,3378
104
+ ai_edge_torch/generative/test/test_model_conversion.py,sha256=SIv7_sc5qHvbHFN8SbAfY00iXGvH7J6cJLkERU_cd5k,5888
105
+ ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=F3q3K9ZgWBzlLy4WpE8-w6UWSuJ-UoJwMm3N6Zb3Y14,5016
114
106
  ai_edge_torch/generative/test/test_quantize.py,sha256=kY_NRpF-v1i4clqI1CFFWEagJv-5PzBDkeJ2fInl9_w,5913
107
+ ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
115
108
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
116
- ai_edge_torch/generative/utilities/loader.py,sha256=6J0aAP6-6LySeqeYIHKcchr5T9cVtSO34aoDr3V9gxY,12726
109
+ ai_edge_torch/generative/utilities/loader.py,sha256=kn4TCgGAG8s4mdvPITimOBCaVyn04Ksz4gZIleFYF1o,12754
117
110
  ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=pKp3AMSbS3otCvgwJRF5M1l4JRNKk-aCKimXzIMSrds,35679
118
- ai_edge_torch/generative/utilities/t5_loader.py,sha256=_UXcc1QKT-S92hikfo-fTBFhnYLzROqcyRqKonVsqj4,16885
111
+ ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
119
112
  ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
120
113
  ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
121
114
  ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
@@ -145,11 +138,12 @@ ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW
145
138
  ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkgH7x38qPlwUUpD7S15Q,730
146
139
  ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=drN3L0uTsSjkluKgt6Ngq7b5HLReE_7iAitHpZ9PKqE,5428
147
140
  ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
148
- ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=GqYk6oBJw7KWeG4_6gxSu_OvYhjJcC2FpGzWPPEdH6w,933
141
+ ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=dE_qzh-OnCNjWzqs1-PHs5PNlRF726qMQKM3tkwAzEs,959
149
142
  ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=wV8AUK8dvjLUy3qjqw_IxpiYVDWUMPNZRfi3XYE_hDs,6972
150
143
  ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
151
144
  ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=v1VdKmL8YLJv3PR9VgyNghO83A25PpTzY2ZUAJqlq3Q,6847
152
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=s-cT_tIQHu7w5hXl8MCixRxLlHplpXW-UCzHT9TY--o,10621
145
+ ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=Ii1akrKLhRTkZ715JxXBBGKv3jGfXReXMQCYNzSnxmM,10567
146
+ ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=1ePJs7oIdUkVdMddFsXMc53qTkEKqGz0ZhQQoNzBa10,2862
153
147
  ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
154
148
  ai_edge_torch/odml_torch/lowerings/registry.py,sha256=ES3x_RJ22T5rlmMrlomex2DdcZbhlyVJ7_HS3rjz3Uk,2851
155
149
  ai_edge_torch/odml_torch/lowerings/utils.py,sha256=NczqpsSd3Fn7yVcPC3qllemiZxxDAZgcW1T5l8-W9fE,5593
@@ -161,8 +155,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
161
155
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
162
156
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
163
157
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
164
- ai_edge_torch_nightly-0.3.0.dev20240910.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
165
- ai_edge_torch_nightly-0.3.0.dev20240910.dist-info/METADATA,sha256=WFNExTO6eF-tAWPmDdQDlr9dvplcoNB0uPdVxSNXYHk,1859
166
- ai_edge_torch_nightly-0.3.0.dev20240910.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
167
- ai_edge_torch_nightly-0.3.0.dev20240910.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
168
- ai_edge_torch_nightly-0.3.0.dev20240910.dist-info/RECORD,,
158
+ ai_edge_torch_nightly-0.3.0.dev20240913.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
159
+ ai_edge_torch_nightly-0.3.0.dev20240913.dist-info/METADATA,sha256=ahbsMN1e0Tuq_LmrkB6NE-VgVTC65KEiZX3VVmTbcWQ,1859
160
+ ai_edge_torch_nightly-0.3.0.dev20240913.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
161
+ ai_edge_torch_nightly-0.3.0.dev20240913.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
162
+ ai_edge_torch_nightly-0.3.0.dev20240913.dist-info/RECORD,,