ai-edge-torch-nightly 0.2.0.dev20240730__py3-none-any.whl → 0.2.0.dev20240805__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.

Note: this release of ai-edge-torch-nightly has been flagged as potentially problematic.

Files changed (92)
  1. ai_edge_torch/__init__.py +1 -0
  2. ai_edge_torch/convert/conversion.py +12 -8
  3. ai_edge_torch/convert/conversion_utils.py +38 -20
  4. ai_edge_torch/convert/converter.py +11 -5
  5. ai_edge_torch/convert/fx_passes/__init__.py +3 -4
  6. ai_edge_torch/convert/fx_passes/_pass_base.py +6 -2
  7. ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py +34 -24
  8. ai_edge_torch/convert/fx_passes/build_interpolate_composite_pass.py +11 -10
  9. ai_edge_torch/convert/fx_passes/canonicalize_pass.py +2 -3
  10. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +18 -7
  11. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py +4 -3
  12. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py +6 -4
  13. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +9 -5
  14. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py +1 -2
  15. ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/pass_body.py +14 -10
  16. ai_edge_torch/convert/test/test_convert.py +39 -16
  17. ai_edge_torch/convert/test/test_convert_composites.py +115 -86
  18. ai_edge_torch/convert/test/test_convert_multisig.py +18 -10
  19. ai_edge_torch/convert/test/test_to_channel_last_io.py +1 -2
  20. ai_edge_torch/convert/to_channel_last_io.py +6 -2
  21. ai_edge_torch/debug/culprit.py +41 -16
  22. ai_edge_torch/debug/test/test_culprit.py +4 -3
  23. ai_edge_torch/debug/test/test_search_model.py +4 -3
  24. ai_edge_torch/debug/utils.py +3 -1
  25. ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py +4 -3
  26. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +10 -8
  27. ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py +7 -4
  28. ai_edge_torch/generative/examples/experimental/phi/phi2.py +10 -8
  29. ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py +1 -2
  30. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +10 -8
  31. ai_edge_torch/generative/examples/gemma/convert_to_tflite.py +4 -3
  32. ai_edge_torch/generative/examples/gemma/gemma.py +13 -9
  33. ai_edge_torch/generative/examples/phi2/convert_to_tflite.py +7 -4
  34. ai_edge_torch/generative/examples/phi2/phi2.py +13 -9
  35. ai_edge_torch/generative/examples/stable_diffusion/attention.py +3 -1
  36. ai_edge_torch/generative/examples/stable_diffusion/clip.py +20 -9
  37. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +26 -13
  38. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +15 -7
  39. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +47 -16
  40. ai_edge_torch/generative/examples/stable_diffusion/encoder.py +4 -3
  41. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +42 -12
  42. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py +4 -3
  43. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py +4 -3
  44. ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +4 -3
  45. ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py +4 -1
  46. ai_edge_torch/generative/examples/stable_diffusion/util.py +9 -3
  47. ai_edge_torch/generative/examples/t5/convert_to_tflite.py +8 -5
  48. ai_edge_torch/generative/examples/t5/t5.py +158 -125
  49. ai_edge_torch/generative/examples/t5/t5_attention.py +15 -7
  50. ai_edge_torch/generative/examples/test_models/toy_model.py +7 -5
  51. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +3 -4
  52. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +4 -5
  53. ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py +4 -3
  54. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +10 -8
  55. ai_edge_torch/generative/fx_passes/__init__.py +1 -2
  56. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +6 -3
  57. ai_edge_torch/generative/layers/attention.py +19 -11
  58. ai_edge_torch/generative/layers/builder.py +3 -4
  59. ai_edge_torch/generative/layers/kv_cache.py +4 -3
  60. ai_edge_torch/generative/layers/model_config.py +6 -2
  61. ai_edge_torch/generative/layers/rotary_position_embedding.py +3 -1
  62. ai_edge_torch/generative/layers/scaled_dot_product_attention.py +1 -2
  63. ai_edge_torch/generative/layers/unet/blocks_2d.py +69 -21
  64. ai_edge_torch/generative/layers/unet/builder.py +7 -4
  65. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py +9 -4
  66. ai_edge_torch/generative/quantize/example.py +2 -3
  67. ai_edge_torch/generative/quantize/quant_recipe.py +2 -1
  68. ai_edge_torch/generative/quantize/quant_recipe_utils.py +10 -0
  69. ai_edge_torch/generative/quantize/quant_recipes.py +8 -0
  70. ai_edge_torch/generative/test/loader_test.py +5 -4
  71. ai_edge_torch/generative/test/test_experimental_ekv.py +22 -11
  72. ai_edge_torch/generative/test/test_model_conversion.py +2 -3
  73. ai_edge_torch/generative/test/test_quantize.py +45 -47
  74. ai_edge_torch/generative/utilities/loader.py +55 -28
  75. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +86 -33
  76. ai_edge_torch/generative/utilities/t5_loader.py +77 -48
  77. ai_edge_torch/hlfb/mark_pattern/__init__.py +2 -3
  78. ai_edge_torch/hlfb/mark_pattern/pattern.py +16 -7
  79. ai_edge_torch/hlfb/test/test_mark_pattern.py +4 -3
  80. ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py +12 -6
  81. ai_edge_torch/model.py +8 -5
  82. ai_edge_torch/quantize/pt2e_quantizer.py +30 -15
  83. ai_edge_torch/quantize/pt2e_quantizer_utils.py +30 -11
  84. ai_edge_torch/quantize/quant_config.py +6 -2
  85. ai_edge_torch/testing/model_coverage/model_coverage.py +11 -7
  86. ai_edge_torch/version.py +16 -0
  87. {ai_edge_torch_nightly-0.2.0.dev20240730.dist-info → ai_edge_torch_nightly-0.2.0.dev20240805.dist-info}/METADATA +5 -5
  88. ai_edge_torch_nightly-0.2.0.dev20240805.dist-info/RECORD +133 -0
  89. {ai_edge_torch_nightly-0.2.0.dev20240730.dist-info → ai_edge_torch_nightly-0.2.0.dev20240805.dist-info}/WHEEL +1 -1
  90. ai_edge_torch_nightly-0.2.0.dev20240730.dist-info/RECORD +0 -132
  91. {ai_edge_torch_nightly-0.2.0.dev20240730.dist-info → ai_edge_torch_nightly-0.2.0.dev20240805.dist-info}/LICENSE +0 -0
  92. {ai_edge_torch_nightly-0.2.0.dev20240730.dist-info → ai_edge_torch_nightly-0.2.0.dev20240805.dist-info}/top_level.txt +0 -0

ai_edge_torch/debug/culprit.py

@@ -23,14 +23,13 @@ import os
  import sys
  from typing import Any, Callable, Generator, List, Optional, Tuple, Union

+ import ai_edge_torch
+ from ai_edge_torch.debug import utils
  from functorch.compile import minifier as fx_minifier
  import torch
  from torch._functorch import aot_autograd
  import torch.utils._pytree as pytree

- import ai_edge_torch
- from ai_edge_torch.debug import utils
-
  _torch_float_dtypes = {
      torch.float32,
      torch.float,
@@ -120,21 +119,29 @@ class Culprit(SearchResult):
    # TODO (b/321263453): Support Python code gen with sample arg tensor values.
    random_inputs = True

-   graph_module_code = self.graph_module.print_readable(print_output=False).rstrip()
+   graph_module_code = self.graph_module.print_readable(
+       print_output=False
+   ).rstrip()

    input_strs = []
    for value in self.inputs:
      if torch.is_tensor(value):
        if not random_inputs:
-         input_strs.append(f"# size={_get_shape_str(value)}, dtype={value.dtype}")
-         input_strs.append(f"torch.load(io.BytesIO({_tensor_to_buffer(value)})),")
+         input_strs.append(
+             f"# size={_get_shape_str(value)}, dtype={value.dtype}"
+         )
+         input_strs.append(
+             f"torch.load(io.BytesIO({_tensor_to_buffer(value)})),"
+         )
        else:
          input_strs.append(_tensor_to_random_tensor_call(value) + ",")
      else:
        input_strs.append(str(value) + ",")

    inputs_code = (
-       "_args = (\n" + "\n".join([" " * 4 + code for code in input_strs]) + "\n)"
+       "_args = (\n"
+       + "\n".join([" " * 4 + code for code in input_strs])
+       + "\n)"
    )

    code = graph_module_code + "\n\n" + inputs_code
@@ -157,7 +164,9 @@ class Culprit(SearchResult):
        + "from torch import device\n"
        + "import ai_edge_torch\n\n"
        + definitions
-       + f"\n\n_edge_model = ai_edge_torch.convert({_CULPRIT_GRAPH_MODULE_NAME}().eval(), _args)\n"
+       + "\n\n_edge_model ="
+       f" ai_edge_torch.convert({_CULPRIT_GRAPH_MODULE_NAME}().eval(),"
+       " _args)\n"
    )
    if self._runtime_errors:
      code += "_edge_model(*_args)\n"
@@ -212,7 +221,9 @@ def _normalize_getitem_nodes(fx_gm: torch.fx.GraphModule):
  return fx_gm


- def _erase_unused_inputs(fx_gm: torch.fx.GraphModule, inputs: Tuple[torch.Tensor]):
+ def _erase_unused_inputs(
+     fx_gm: torch.fx.GraphModule, inputs: Tuple[torch.Tensor]
+ ):
  fx_gm = copy.deepcopy(fx_gm)
  inputs = tuple(inputs)
  args = fx_gm.graph.process_inputs(*inputs)
@@ -316,7 +327,9 @@ def _erase_sub_gm_from_gm(
  return fx_gm, fx_inputs


- def _normalize_minified_fx_gm(fx_gm: torch.fx.GraphModule, inputs: Tuple[torch.Tensor]):
+ def _normalize_minified_fx_gm(
+     fx_gm: torch.fx.GraphModule, inputs: Tuple[torch.Tensor]
+ ):
  fx_gm, inputs = _erase_unused_inputs(fx_gm, inputs)
  fx_gm = _lift_dead_ops_to_outputs(fx_gm)
  fx_gm, _ = aot_autograd.aot_export_module(fx_gm, inputs, trace_joint=False)
@@ -374,7 +387,8 @@ def _search_model(
      ep = torch.export.export(model, export_args)
    except Exception as err:
      raise ValueError(
-         "Your model is not exportable by torch.export.export. Please modify your model to be torch-exportable first."
+         "Your model is not exportable by torch.export.export. Please modify"
+         " your model to be torch-exportable first."
      ) from err
  else:
    ep = model
@@ -392,7 +406,9 @@ def _search_model(
    xla_hlo_debug_value = os.environ["XLA_HLO_DEBUG"]
    del os.environ["XLA_HLO_DEBUG"]

- create_minified_hlo_graph = torch._functorch.fx_minifier.create_minified_hlo_graph
+ create_minified_hlo_graph = (
+     torch._functorch.fx_minifier.create_minified_hlo_graph
+ )
  torch._functorch.fx_minifier.create_minified_hlo_graph = (
      lambda *args, **kwargs: None
  )
@@ -403,7 +419,9 @@ def _search_model(
    if xla_hlo_debug_value is not None:
      os.environ["XLA_HLO_DEBUG"] = xla_hlo_debug_value

- torch._functorch.fx_minifier.create_minified_hlo_graph = create_minified_hlo_graph
+ torch._functorch.fx_minifier.create_minified_hlo_graph = (
+     create_minified_hlo_graph
+ )

  found_culprits_num = 0
  while True:
@@ -420,7 +438,9 @@ def _search_model(
          max_granularity=max_granularity,
      )

-     min_fx_gm, min_inputs = _normalize_minified_fx_gm(raw_min_fx_gm, raw_min_inputs)
+     min_fx_gm, min_inputs = _normalize_minified_fx_gm(
+         raw_min_fx_gm, raw_min_inputs
+     )
      found_culprits_num += 1
      yield SearchResult(min_fx_gm, min_inputs)

@@ -429,7 +449,10 @@ def _search_model(
      )

    except RuntimeError as e:
-     if str(e) == "Input graph did not fail the tester" and found_culprits_num > 0:
+     if (
+         str(e) == "Input graph did not fail the tester"
+         and found_culprits_num > 0
+     ):
        break
      raise e

@@ -467,5 +490,7 @@ def find_culprits(
      enable_fx_minifier_logging=enable_fx_minifier_logging,
  ):
    yield Culprit(
-       search_result.graph_module, search_result.inputs, _runtime_errors=runtime_errors
+       search_result.graph_module,
+       search_result.inputs,
+       _runtime_errors=runtime_errors,
    )

ai_edge_torch/debug/test/test_culprit.py

@@ -19,16 +19,17 @@ import io
  import sys
  import unittest

- import torch
-
  from ai_edge_torch.debug import find_culprits
+ import torch

  _test_culprit_lib = torch.library.Library("test_culprit", "DEF")

  _test_culprit_lib.define("non_lowerable_op(Tensor x) -> Tensor")


- @torch.library.impl(_test_culprit_lib, "non_lowerable_op", "CompositeExplicitAutograd")
+ @torch.library.impl(
+     _test_culprit_lib, "non_lowerable_op", "CompositeExplicitAutograd"
+ )
  def non_lowerable_op(x):
    if x.max() > 10.0:
      return x + 1.0

ai_edge_torch/debug/test/test_search_model.py

@@ -16,9 +16,8 @@

  import unittest

- import torch
-
  from ai_edge_torch.debug import _search_model
+ import torch


  class TestSearchModel(unittest.TestCase):
@@ -43,7 +42,9 @@ class TestSearchModel(unittest.TestCase):

    results = list(_search_model(find_subgraph_with_sub, model, args))
    self.assertEqual(len(results), 2)
-   self.assertIn(torch.ops.aten.sub.Tensor, [n.target for n in results[0].graph.nodes])
+   self.assertIn(
+       torch.ops.aten.sub.Tensor, [n.target for n in results[0].graph.nodes]
+   )


  if __name__ == "__main__":

ai_edge_torch/debug/utils.py

@@ -21,7 +21,9 @@ import torch.fx._pytree as fx_pytree
  from torch.utils import _pytree as pytree


- def exported_program_to_fx_graph_module_and_inputs(ep: torch.export.ExportedProgram):
+ def exported_program_to_fx_graph_module_and_inputs(
+     ep: torch.export.ExportedProgram,
+ ):
  fx_gm = ep.graph_module
  fx_inputs = pytree.tree_map(
      torch.tensor, ep._graph_module_flat_inputs(*ep.example_inputs)

ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py

@@ -20,12 +20,11 @@
  import os
  from pathlib import Path

- import torch
-
  import ai_edge_torch
  from ai_edge_torch.generative.examples.experimental.gemma import gemma
  from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
  from ai_edge_torch.generative.quantize import quant_recipes
+ import torch


  def convert_gemma_to_tflite(
@@ -79,7 +78,9 @@ def convert_gemma_to_tflite(
      )
      .convert(quant_config=quant_config)
  )
- edge_model.export(f'/tmp/gemma_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite')
+ edge_model.export(
+     f'/tmp/gemma_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+ )


  if __name__ == '__main__':

ai_edge_torch/generative/examples/experimental/gemma/gemma.py

@@ -21,16 +21,15 @@ import os
  from pathlib import Path
  from typing import Tuple

- import numpy as np
- import torch
- import torch.nn as nn
-
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.builder as builder
  from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
  from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock  # NOQA
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.utilities.loader as loading_utils
+ import numpy as np
+ import torch
+ import torch.nn as nn

  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
      ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -81,7 +80,9 @@ class Gemma(nn.Module):
        device=torch.device("cpu"),
    )
    self.mask_cache = attn_utils.build_causal_mask_cache(
-       size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
+       size=config.kv_cache_max,
+       dtype=torch.float32,
+       device=torch.device("cpu"),
    )
    self.config = config

@@ -93,9 +94,10 @@ class Gemma(nn.Module):
      kv_cache: kv_utils.EKVCache,
  ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
    B, T = tokens.size()
-   assert (
-       self.config.max_seq_len >= T
-   ), f"Cannot forward sequence of length {T}, max seq length is only {self.config.max_seq_len}"
+   assert self.config.max_seq_len >= T, (
+       f"Cannot forward sequence of length {T}, max seq length is only"
+       f" {self.config.max_seq_len}"
+   )

    cos, sin = self.rope_cache
    cos = cos.index_select(0, input_pos)

ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py

@@ -19,12 +19,11 @@
  import os
  from pathlib import Path

- import torch
-
  import ai_edge_torch
  from ai_edge_torch.generative.examples.experimental.phi import phi2
  from ai_edge_torch.generative.layers.experimental import ekv_cache
  from ai_edge_torch.generative.quantize import quant_recipes
+ import torch


  def convert_phi2_to_tflite(
@@ -46,7 +45,9 @@ def convert_phi2_to_tflite(
      quantize (bool, optional): Whether the model should be quanized.
        Defaults to True.
  """
- pytorch_model = phi2.build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
+ pytorch_model = phi2.build_model(
+     checkpoint_path, kv_cache_max_len=kv_cache_max_len
+ )
  # Tensors used to trace the model graph during conversion.
  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
  prefill_input_pos = torch.arange(0, prefill_seq_len)
@@ -76,7 +77,9 @@ def convert_phi2_to_tflite(
      )
      .convert(quant_config=quant_config)
  )
- edge_model.export(f'/tmp/phi2_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite')
+ edge_model.export(
+     f'/tmp/phi2_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+ )


  if __name__ == '__main__':

ai_edge_torch/generative/examples/experimental/phi/phi2.py

@@ -22,16 +22,15 @@ import os
  from pathlib import Path
  from typing import Tuple

- import numpy as np
- import torch
- import torch.nn as nn
-
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.builder as builder
  from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
  from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock  # NOQA
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.utilities.loader as loading_utils
+ import numpy as np
+ import torch
+ import torch.nn as nn

  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
      ff_up_proj="model.layers.{}.mlp.fc1",
@@ -76,7 +75,9 @@ class Phi2(nn.Module):
        device=torch.device("cpu"),
    )
    self.mask_cache = attn_utils.build_causal_mask_cache(
-       size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
+       size=config.kv_cache_max,
+       dtype=torch.float32,
+       device=torch.device("cpu"),
    )
    self.config = config

@@ -88,9 +89,10 @@ class Phi2(nn.Module):
      kv_cache: kv_utils.EKVCache,
  ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
    B, T = tokens.size()
-   assert (
-       self.config.max_seq_len >= T
-   ), f"Cannot forward sequence of length {T}, max seq length is only {self.config.max_seq_len}"
+   assert self.config.max_seq_len >= T, (
+       f"Cannot forward sequence of length {T}, max seq length is only"
+       f" {self.config.max_seq_len}"
+   )

    cos, sin = self.rope_cache
    cos = cos.index_select(0, input_pos)

ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py

@@ -20,12 +20,11 @@
  import os
  from pathlib import Path

- import torch
-
  import ai_edge_torch
  from ai_edge_torch.generative.examples.experimental.tiny_llama import tiny_llama  # NOQA
  from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
  from ai_edge_torch.generative.quantize import quant_recipes
+ import torch


  def convert_tiny_llama_to_tflite(

ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py

@@ -22,16 +22,15 @@ import os
  from pathlib import Path
  from typing import Tuple

- import numpy as np
- import torch
- import torch.nn as nn
-
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.builder as builder
  from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
  from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock  # NOQA
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.utilities.loader as loading_utils
+ import numpy as np
+ import torch
+ import torch.nn as nn

  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
      ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -78,7 +77,9 @@ class TinyLLamma(nn.Module):
        device=torch.device("cpu"),
    )
    self.mask_cache = attn_utils.build_causal_mask_cache(
-       size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
+       size=config.kv_cache_max,
+       dtype=torch.float32,
+       device=torch.device("cpu"),
    )
    self.config = config

@@ -90,9 +91,10 @@ class TinyLLamma(nn.Module):
      kv_cache: kv_utils.EKVCache,
  ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
    B, T = tokens.size()
-   assert (
-       self.config.max_seq_len >= T
-   ), f"Cannot forward sequence of length {T}, max seq length is only {self.config.max_seq_len}"
+   assert self.config.max_seq_len >= T, (
+       f"Cannot forward sequence of length {T}, max seq length is only"
+       f" {self.config.max_seq_len}"
+   )

    cos, sin = self.rope_cache
    cos = cos.index_select(0, input_pos)

ai_edge_torch/generative/examples/gemma/convert_to_tflite.py

@@ -16,11 +16,10 @@
  import os
  from pathlib import Path

- import torch
-
  import ai_edge_torch
  from ai_edge_torch.generative.examples.gemma import gemma
  from ai_edge_torch.generative.quantize import quant_recipes
+ import torch


  def convert_gemma_to_tflite(
@@ -58,7 +57,9 @@ def convert_gemma_to_tflite(
      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
      .convert(quant_config=quant_config)
  )
- edge_model.export(f'/tmp/gemma_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite')
+ edge_model.export(
+     f'/tmp/gemma_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
+ )


  if __name__ == '__main__':

ai_edge_torch/generative/examples/gemma/gemma.py

@@ -17,15 +17,14 @@
  import os
  from pathlib import Path

- import numpy as np
- import torch
- import torch.nn as nn
-
  from ai_edge_torch.generative.layers.attention import TransformerBlock
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.builder as builder
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.utilities.loader as loading_utils
+ import numpy as np
+ import torch
+ import torch.nn as nn

  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
      ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -76,7 +75,9 @@ class Gemma(nn.Module):
        device=torch.device("cpu"),
    )
    self.mask_cache = attn_utils.build_causal_mask_cache(
-       size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
+       size=config.kv_cache_max,
+       dtype=torch.float32,
+       device=torch.device("cpu"),
    )
    self.config = config

@@ -86,9 +87,10 @@ class Gemma(nn.Module):
  @torch.inference_mode
  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
    B, T = idx.size()
-   assert (
-       self.config.max_seq_len >= T
-   ), f"Cannot forward sequence of length {T}, max seq length is only {self.config.max_seq_len}"
+   assert self.config.max_seq_len >= T, (
+       f"Cannot forward sequence of length {T}, max seq length is only"
+       f" {self.config.max_seq_len}"
+   )

    cos, sin = self.rope_cache
    cos = cos.index_select(0, input_pos)
@@ -171,7 +173,9 @@ def define_and_run_2b() -> None:
  input_pos = torch.arange(0, kv_cache_max_len)
  lm_logits = model.forward(tokens, input_pos)
  print("comparing with goldens..")
- assert torch.allclose(gemma_goldens, lm_logits[0, idx.shape[1] - 1, :], atol=1e-05)
+ assert torch.allclose(
+     gemma_goldens, lm_logits[0, idx.shape[1] - 1, :], atol=1e-05
+ )


  if __name__ == "__main__":

ai_edge_torch/generative/examples/phi2/convert_to_tflite.py

@@ -16,11 +16,10 @@
  import os
  from pathlib import Path

- import torch
-
  import ai_edge_torch
  from ai_edge_torch.generative.examples.phi2 import phi2
  from ai_edge_torch.generative.quantize import quant_recipes
+ import torch


  def convert_phi2_to_tflite(
@@ -41,7 +40,9 @@ def convert_phi2_to_tflite(
      quantize (bool, optional): Whether the model should be quanized.
        Defaults to True.
  """
- pytorch_model = phi2.build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
+ pytorch_model = phi2.build_model(
+     checkpoint_path, kv_cache_max_len=kv_cache_max_len
+ )
  # Tensors used to trace the model graph during conversion.
  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
  prefill_input_pos = torch.arange(0, prefill_seq_len)
@@ -56,7 +57,9 @@ def convert_phi2_to_tflite(
      .signature('decode', pytorch_model, (decode_token, decode_input_pos))
      .convert(quant_config=quant_config)
  )
- edge_model.export(f'/tmp/phi2_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite')
+ edge_model.export(
+     f'/tmp/phi2_seq{prefill_seq_len}_kv{kv_cache_max_len}.tflite'
+ )


  if __name__ == '__main__':

ai_edge_torch/generative/examples/phi2/phi2.py

@@ -18,15 +18,14 @@
  import os
  from pathlib import Path

- import numpy as np
- import torch
- import torch.nn as nn
-
  from ai_edge_torch.generative.layers.attention import TransformerBlock
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
  import ai_edge_torch.generative.layers.builder as builder
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.utilities.loader as loading_utils
+ import numpy as np
+ import torch
+ import torch.nn as nn

  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
      ff_up_proj="model.layers.{}.mlp.fc1",
@@ -71,7 +70,9 @@ class Phi2(nn.Module):
        device=torch.device("cpu"),
    )
    self.mask_cache = attn_utils.build_causal_mask_cache(
-       size=config.kv_cache_max, dtype=torch.float32, device=torch.device("cpu")
+       size=config.kv_cache_max,
+       dtype=torch.float32,
+       device=torch.device("cpu"),
    )
    self.config = config

@@ -81,9 +82,10 @@ class Phi2(nn.Module):
  @torch.inference_mode
  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
    B, T = idx.size()
-   assert (
-       self.config.max_seq_len >= T
-   ), f"Cannot forward sequence of length {T}, max seq length is only {self.config.max_seq_len}"
+   assert self.config.max_seq_len >= T, (
+       f"Cannot forward sequence of length {T}, max seq length is only"
+       f" {self.config.max_seq_len}"
+   )

    cos, sin = self.rope_cache
    cos = cos.index_select(0, input_pos)
@@ -160,7 +162,9 @@ def define_and_run() -> None:
  input_pos = torch.arange(0, kv_cache_max_len)
  lm_logits = model.forward(tokens, input_pos)
  print("comparing with goldens..")
- assert torch.allclose(phi2_goldens, lm_logits[0, idx.shape[1] - 1, :], atol=1e-05)
+ assert torch.allclose(
+     phi2_goldens, lm_logits[0, idx.shape[1] - 1, :], atol=1e-05
+ )


  if __name__ == "__main__":

ai_edge_torch/generative/examples/stable_diffusion/attention.py

@@ -73,7 +73,9 @@ class SelfAttention(nn.Module):

  class CrossAttention(nn.Module):

- def __init__(self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True):
+ def __init__(
+     self, n_heads, d_embed, d_cross, in_proj_bias=True, out_proj_bias=True
+ ):
    super().__init__()
    self.q_proj = nn.Linear(d_embed, d_embed, bias=in_proj_bias)
    self.k_proj = nn.Linear(d_cross, d_embed, bias=in_proj_bias)

ai_edge_torch/generative/examples/stable_diffusion/clip.py

@@ -13,25 +13,34 @@
  # limitations under the License.
  # ==============================================================================

- import torch
- from torch import nn
-
  from ai_edge_torch.generative.layers.attention import TransformerBlock
  import ai_edge_torch.generative.layers.attention_utils as attention_utils
  import ai_edge_torch.generative.layers.builder as builder
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.utilities.loader as loading_utils
+ import torch
+ from torch import nn

  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
-     ff_up_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.mlp.fc1",
-     ff_down_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.mlp.fc2",
+     ff_up_proj=(
+         "cond_stage_model.transformer.text_model.encoder.layers.{}.mlp.fc1"
+     ),
+     ff_down_proj=(
+         "cond_stage_model.transformer.text_model.encoder.layers.{}.mlp.fc2"
+     ),
      attn_query_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.q_proj",
      attn_key_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.k_proj",
      attn_value_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.v_proj",
      attn_output_proj="cond_stage_model.transformer.text_model.encoder.layers.{}.self_attn.out_proj",
-     pre_attn_norm="cond_stage_model.transformer.text_model.encoder.layers.{}.layer_norm1",
-     pre_ff_norm="cond_stage_model.transformer.text_model.encoder.layers.{}.layer_norm2",
-     embedding="cond_stage_model.transformer.text_model.embeddings.token_embedding",
+     pre_attn_norm=(
+         "cond_stage_model.transformer.text_model.encoder.layers.{}.layer_norm1"
+     ),
+     pre_ff_norm=(
+         "cond_stage_model.transformer.text_model.encoder.layers.{}.layer_norm2"
+     ),
+     embedding=(
+         "cond_stage_model.transformer.text_model.embeddings.token_embedding"
+     ),
      embedding_position="cond_stage_model.transformer.text_model.embeddings.position_embedding.weight",
      final_norm="cond_stage_model.transformer.text_model.final_layer_norm",
      lm_head=None,
@@ -54,7 +63,9 @@ class CLIP(nn.Module):
    self.transformer_blocks = nn.ModuleList(
        TransformerBlock(config) for _ in range(config.num_layers)
    )
-   self.final_norm = builder.build_norm(config.embedding_dim, config.final_norm_config)
+   self.final_norm = builder.build_norm(
+       config.embedding_dim, config.final_norm_config
+   )

    self.mask_cache = attention_utils.build_causal_mask_cache(
        size=config.max_seq_len, dtype=torch.float32