ai-edge-torch-nightly 0.3.0.dev20240902__py3-none-any.whl → 0.3.0.dev20240905__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (22)
  1. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +32 -42
  2. ai_edge_torch/generative/examples/experimental/phi/phi2.py +33 -15
  3. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +29 -14
  4. ai_edge_torch/generative/examples/gemma/gemma.py +26 -41
  5. ai_edge_torch/generative/examples/gemma/gemma2.py +29 -48
  6. ai_edge_torch/generative/examples/phi2/phi2.py +25 -37
  7. ai_edge_torch/generative/examples/test_models/toy_model.py +50 -0
  8. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +21 -9
  9. ai_edge_torch/generative/test/test_model_conversion.py +0 -29
  10. ai_edge_torch/generative/test/test_quantize.py +23 -10
  11. ai_edge_torch/generative/utilities/loader.py +1 -1
  12. ai_edge_torch/lowertools/odml_torch_utils.py +20 -0
  13. ai_edge_torch/lowertools/torch_xla_utils.py +1 -1
  14. ai_edge_torch/{generative/quantize/ai_edge_quantizer_glue → lowertools}/translate_recipe.py +24 -12
  15. ai_edge_torch/odml_torch/export.py +1 -1
  16. ai_edge_torch/version.py +1 -1
  17. {ai_edge_torch_nightly-0.3.0.dev20240902.dist-info → ai_edge_torch_nightly-0.3.0.dev20240905.dist-info}/METADATA +2 -2
  18. {ai_edge_torch_nightly-0.3.0.dev20240902.dist-info → ai_edge_torch_nightly-0.3.0.dev20240905.dist-info}/RECORD +21 -22
  19. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -14
  20. {ai_edge_torch_nightly-0.3.0.dev20240902.dist-info → ai_edge_torch_nightly-0.3.0.dev20240905.dist-info}/LICENSE +0 -0
  21. {ai_edge_torch_nightly-0.3.0.dev20240902.dist-info → ai_edge_torch_nightly-0.3.0.dev20240905.dist-info}/WHEEL +0 -0
  22. {ai_edge_torch_nightly-0.3.0.dev20240902.dist-info → ai_edge_torch_nightly-0.3.0.dev20240905.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/experimental/gemma/gemma.py CHANGED
@@ -21,15 +21,16 @@ import os
  from pathlib import Path
  from typing import Tuple
 
+ from ai_edge_torch.generative.layers import builder
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
- import ai_edge_torch.generative.layers.builder as builder
+ from ai_edge_torch.generative.layers.experimental import attention
  from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
- from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock # NOQA
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.utilities.loader as loading_utils
  import numpy as np
  import torch
- import torch.nn as nn
+ from torch import nn
+
 
  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
  ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -48,6 +49,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 
 
  class Gemma(nn.Module):
+ """A Gemma model built from the Edge Generative API layers."""
 
  def __init__(self, config: cfg.ModelConfig):
  super().__init__()
@@ -65,7 +67,7 @@ class Gemma(nn.Module):
  # Gemma re-uses the embedding as the head projection layer.
  self.lm_head.weight.data = self.tok_embedding.weight.data
  self.transformer_blocks = nn.ModuleList(
- TransformerBlock(config) for _ in range(config.num_layers)
+ attention.TransformerBlock(config) for _ in range(config.num_layers)
  )
  self.final_norm = builder.build_norm(
  config.embedding_dim,
@@ -95,9 +97,9 @@ class Gemma(nn.Module):
  input_pos: torch.Tensor,
  kv_cache: kv_utils.EKVCache,
  ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
- B, T = tokens.size()
- assert self.config.max_seq_len >= T, (
- f"Cannot forward sequence of length {T}, max seq length is only"
+ _, seq_len = tokens.size()
+ assert self.config.max_seq_len >= seq_len, (
+ f"Cannot forward sequence of length {seq_len}, max seq length is only"
  f" {self.config.max_seq_len}"
  )
 
@@ -125,6 +127,15 @@ class Gemma(nn.Module):
 
 
  def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+ """Returns the model config for a Gemma 2B model.
+
+ Args:
+ kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+ is 1024.
+
+ Returns:
+ The model config for a Gemma 2B model.
+ """
  attn_config = cfg.AttentionConfig(
  num_heads=8,
  head_dim=256,
@@ -160,41 +171,18 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
 
 
  def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
- attn_config = cfg.AttentionConfig(
- num_heads=8,
- head_dim=256,
- num_query_groups=1,
- rotary_percentage=1.0,
- )
- ff_config = cfg.FeedForwardConfig(
- type=cfg.FeedForwardType.GATED,
- activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
- intermediate_size=128,
- )
- norm_config = cfg.NormalizationConfig(
- type=cfg.NormalizationType.RMS_NORM,
- epsilon=1e-6,
- zero_centered=True,
- )
- config = cfg.ModelConfig(
- vocab_size=128,
- num_layers=2,
- max_seq_len=2 * kv_cache_max_len,
- embedding_dim=2048,
- kv_cache_max_len=kv_cache_max_len,
- attn_config=attn_config,
- ff_config=ff_config,
- pre_attention_norm_config=norm_config,
- post_attention_norm_config=norm_config,
- final_norm_config=norm_config,
- parallel_residual=False,
- lm_head_use_bias=False,
- enable_hlfb=True,
- )
+ config = get_model_config_2b(kv_cache_max_len)
+ config.ff_config.intermediate_size = 128
+ config.vocab_size = 128
+ config.num_layers = 2
+ config.max_seq_len = 2 * kv_cache_max_len
  return config
 
 
- def build_2b_model(checkpoint_path, test_model=False, **kwargs) -> nn.Module:
+ def build_2b_model(
+ checkpoint_path: str, test_model: bool = False, **kwargs
+ ) -> nn.Module:
+ """Instantiates the model instance and load checkpoint if provided."""
  config = (
  get_fake_model_config(**kwargs)
  if test_model
@@ -210,7 +198,9 @@ def build_2b_model(checkpoint_path, test_model=False, **kwargs) -> nn.Module:
  return model
 
 
- def define_and_run_2b(checkpoint_path, test_model=False) -> None:
+ def define_and_run_2b(checkpoint_path: str, test_model: bool = False) -> None:
+ """Instantiates and runs a Gemma 2B model."""
+
  kv_cache_max_len = 1024
  model = build_2b_model(
  checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
@@ -225,5 +215,5 @@ def define_and_run_2b(checkpoint_path, test_model=False) -> None:
 
 
  if __name__ == "__main__":
- checkpoint_path = os.path.join(Path.home(), "Downloads/gemma-2b")
- define_and_run_2b(checkpoint_path)
+ input_checkpoint_path = os.path.join(Path.home(), "Downloads/gemma-2b")
+ define_and_run_2b(input_checkpoint_path)
ai_edge_torch/generative/examples/experimental/phi/phi2.py CHANGED
@@ -17,20 +17,20 @@
  # Note: This is an experimental version of phi2 with external KV cache.
  # Please use with caution.
 
-
  import os
  from pathlib import Path
  from typing import Tuple
 
+ from ai_edge_torch.generative.layers import builder
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
- import ai_edge_torch.generative.layers.builder as builder
+ from ai_edge_torch.generative.layers.experimental import attention
  from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
- from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock # NOQA
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.utilities.loader as loading_utils
  import numpy as np
  import torch
- import torch.nn as nn
+ from torch import nn
+
 
  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
  ff_up_proj="model.layers.{}.mlp.fc1",
@@ -47,6 +47,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 
 
  class Phi2(nn.Module):
+ """A Phi-2 model built from the Edge Generative API layers."""
 
  def __init__(self, config: cfg.ModelConfig):
  super().__init__()
@@ -60,7 +61,7 @@ class Phi2(nn.Module):
  config.vocab_size, config.embedding_dim, padding_idx=0
  )
  self.transformer_blocks = nn.ModuleList(
- TransformerBlock(config) for _ in range(config.num_layers)
+ attention.TransformerBlock(config) for _ in range(config.num_layers)
  )
  self.final_norm = builder.build_norm(
  config.embedding_dim,
@@ -90,9 +91,9 @@ class Phi2(nn.Module):
  input_pos: torch.Tensor,
  kv_cache: kv_utils.EKVCache,
  ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
- B, T = tokens.size()
- assert self.config.max_seq_len >= T, (
- f"Cannot forward sequence of length {T}, max seq length is only"
+ _, seq_len = tokens.size()
+ assert self.config.max_seq_len >= seq_len, (
+ f"Cannot forward sequence of length {seq_len}, max seq length is only"
  f" {self.config.max_seq_len}"
  )
 
@@ -118,6 +119,15 @@ class Phi2(nn.Module):
 
 
  def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+ """Returns the model config for a Phi-2 model.
+
+ Args:
+ kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+ is 1024.
+
+ Returns:
+ The model config for a Phi-2 model.
+ """
  attn_config = cfg.AttentionConfig(
  num_heads=32,
  head_dim=80,
@@ -150,15 +160,21 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  return config
 
 
- def get_fake_model_config_for_test(**kwargs) -> cfg.ModelConfig:
- config = get_model_config(**kwargs)
+ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
+ config = get_model_config(kv_cache_max_len)
+ config.vocab_size = 128
  config.num_layers = 2
+ config.max_seq_len = 2 * kv_cache_max_len
+ config.ff_config.intermediate_size = 128
  return config
 
 
- def build_model(checkpoint_path, test_model=False, **kwargs) -> nn.Module:
+ def build_model(
+ checkpoint_path: str, test_model: bool = False, **kwargs
+ ) -> nn.Module:
+ """Instantiates the model instance and load checkpoint if provided."""
  config = (
- get_fake_model_config_for_test(**kwargs)
+ get_fake_model_config(**kwargs)
  if test_model
  else get_model_config(**kwargs)
  )
@@ -170,7 +186,9 @@ def build_model(checkpoint_path, test_model=False, **kwargs) -> nn.Module:
  return model
 
 
- def define_and_run(checkpoint_path, test_model=False) -> None:
+ def define_and_run(checkpoint_path: str, test_model: bool = False) -> None:
+ """Instantiates and runs a Phi-2 model."""
+
  kv_cache_max_len = 1024
  model = build_model(
  checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
@@ -185,5 +203,5 @@ def define_and_run(checkpoint_path, test_model=False) -> None:
 
 
  if __name__ == "__main__":
- checkpoint_path = os.path.join(Path.home(), "Downloads/phi2")
- define_and_run(checkpoint_path)
+ input_checkpoint_path = os.path.join(Path.home(), "Downloads/phi2")
+ define_and_run(input_checkpoint_path)
ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py CHANGED
@@ -17,20 +17,20 @@
  # Note: This is an experimental version of TinyLlama with external KV cache.
  # Please use with caution.
 
-
  import os
  from pathlib import Path
  from typing import Tuple
 
+ from ai_edge_torch.generative.layers import builder
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
- import ai_edge_torch.generative.layers.builder as builder
+ from ai_edge_torch.generative.layers.experimental import attention
  from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
- from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock # NOQA
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.utilities.loader as loading_utils
  import numpy as np
  import torch
- import torch.nn as nn
+ from torch import nn
+
 
  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
  ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -49,6 +49,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 
 
  class TinyLLamma(nn.Module):
+ """A TinyLlama model built from the Edge Generative API layers."""
 
  def __init__(self, config: cfg.ModelConfig):
  super().__init__()
@@ -62,7 +63,7 @@ class TinyLLamma(nn.Module):
  config.vocab_size, config.embedding_dim, padding_idx=0
  )
  self.transformer_blocks = nn.ModuleList(
- TransformerBlock(config) for _ in range(config.num_layers)
+ attention.TransformerBlock(config) for _ in range(config.num_layers)
  )
  self.final_norm = builder.build_norm(
  config.embedding_dim,
@@ -92,9 +93,9 @@ class TinyLLamma(nn.Module):
  input_pos: torch.Tensor,
  kv_cache: kv_utils.EKVCache,
  ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
- B, T = tokens.size()
- assert self.config.max_seq_len >= T, (
- f"Cannot forward sequence of length {T}, max seq length is only"
+ _, seq_len = tokens.size()
+ assert self.config.max_seq_len >= seq_len, (
+ f"Cannot forward sequence of length {seq_len}, max seq length is only"
  f" {self.config.max_seq_len}"
  )
 
@@ -121,6 +122,15 @@ class TinyLLamma(nn.Module):
 
 
  def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+ """Returns the model config for a TinyLlama model.
+
+ Args:
+ kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+ is 1024.
+
+ Returns:
+ The model config for a TinyLlama model.
+ """
  attn_config = cfg.AttentionConfig(
  num_heads=32,
  head_dim=64,
@@ -149,7 +159,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  return config
 
 
- def get_fake_model_config_for_test(**kwargs) -> cfg.ModelConfig:
+ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
  config = get_model_config(**kwargs)
  config.vocab_size = 128
  config.num_layers = 2
@@ -157,9 +167,12 @@ def get_fake_model_config_for_test(**kwargs) -> cfg.ModelConfig:
  return config
 
 
- def build_model(checkpoint_path, test_model=False, **kwargs) -> nn.Module:
+ def build_model(
+ checkpoint_path: str, test_model: bool = False, **kwargs
+ ) -> nn.Module:
+ """Instantiates the model instance and load checkpoint if provided."""
  config = (
- get_fake_model_config_for_test(**kwargs)
+ get_fake_model_config(**kwargs)
  if test_model
  else get_model_config(**kwargs)
  )
@@ -171,7 +184,9 @@ def build_model(checkpoint_path, test_model=False, **kwargs) -> nn.Module:
  return model
 
 
- def define_and_run(checkpoint_path, test_model=False) -> None:
+ def define_and_run(checkpoint_path: str, test_model: bool = False) -> None:
+ """Instantiates and runs a TinyLlama model."""
+
  kv_cache_max_len = 1024
  model = build_model(
  checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
@@ -186,5 +201,5 @@ def define_and_run(checkpoint_path, test_model=False) -> None:
 
 
  if __name__ == "__main__":
- checkpoint_path = os.path.join(Path.home(), "Downloads/tiny_llama")
- define_and_run(checkpoint_path)
+ input_checkpoint_path = os.path.join(Path.home(), "Downloads/tiny_llama")
+ define_and_run(input_checkpoint_path)
ai_edge_torch/generative/examples/gemma/gemma.py CHANGED
@@ -17,14 +17,14 @@
  import os
  from pathlib import Path
 
- from ai_edge_torch.generative.layers.attention import TransformerBlock
+ from ai_edge_torch.generative.layers import attention
+ from ai_edge_torch.generative.layers import builder
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
- import ai_edge_torch.generative.layers.builder as builder
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.utilities.loader as loading_utils
  import numpy as np
  import torch
- import torch.nn as nn
+ from torch import nn
 
  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
  ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -43,6 +43,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 
 
  class Gemma(nn.Module):
+ """A Gemma model built from the Edge Generative API layers."""
 
  def __init__(self, config: cfg.ModelConfig):
  super().__init__()
@@ -60,7 +61,7 @@ class Gemma(nn.Module):
  # Gemma re-uses the embedding as the head projection layer.
  self.lm_head.weight.data = self.tok_embedding.weight.data
  self.transformer_blocks = nn.ModuleList(
- TransformerBlock(config) for _ in range(config.num_layers)
+ attention.TransformerBlock(config) for _ in range(config.num_layers)
  )
  self.final_norm = builder.build_norm(
  config.embedding_dim,
@@ -88,9 +89,9 @@ class Gemma(nn.Module):
  # This can be eliminated if we handle k/v cache updates inside the model itself.
  @torch.inference_mode
  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
- B, T = idx.size()
- assert self.config.max_seq_len >= T, (
- f"Cannot forward sequence of length {T}, max seq length is only"
+ _, seq_len = idx.size()
+ assert self.config.max_seq_len >= seq_len, (
+ f"Cannot forward sequence of length {seq_len}, max seq length is only"
  f" {self.config.max_seq_len}"
  )
 
@@ -104,7 +105,7 @@ class Gemma(nn.Module):
  x = self.tok_embedding(idx)
  x = x * (self.config.embedding_dim**0.5)
 
- for i, block in enumerate(self.transformer_blocks):
+ for _, block in enumerate(self.transformer_blocks):
  x = block(x, (cos, sin), mask, input_pos)
 
  x = self.final_norm(x)
@@ -113,6 +114,15 @@ class Gemma(nn.Module):
 
 
  def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+ """Returns the model config for a Gemma 2B model.
+
+ Args:
+ kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+ is 1024.
+
+ Returns:
+ The model config for a Gemma 2B model.
+ """
  attn_config = cfg.AttentionConfig(
  num_heads=8,
  head_dim=256,
@@ -147,43 +157,16 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  return config
 
 
- # TODO(b/363021962): Clean up this part to streamline fake model config generation.
  def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
- attn_config = cfg.AttentionConfig(
- num_heads=8,
- head_dim=256,
- num_query_groups=1,
- rotary_percentage=1.0,
- )
- ff_config = cfg.FeedForwardConfig(
- type=cfg.FeedForwardType.GATED,
- activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
- intermediate_size=128,
- )
- norm_config = cfg.NormalizationConfig(
- type=cfg.NormalizationType.RMS_NORM,
- epsilon=1e-6,
- zero_centered=True,
- )
- config = cfg.ModelConfig(
- vocab_size=128,
- num_layers=2,
- max_seq_len=2 * kv_cache_max_len,
- embedding_dim=2048,
- kv_cache_max_len=kv_cache_max_len,
- attn_config=attn_config,
- ff_config=ff_config,
- pre_attention_norm_config=norm_config,
- post_attention_norm_config=norm_config,
- final_norm_config=norm_config,
- parallel_residual=False,
- lm_head_use_bias=False,
- enable_hlfb=True,
- )
+ config = get_model_config_2b(kv_cache_max_len)
+ config.ff_config.intermediate_size = 128
+ config.vocab_size = 128
+ config.num_layers = 2
+ config.max_seq_len = 2 * kv_cache_max_len
  return config
 
 
- def build_2b_model(checkpoint_path, **kwargs) -> nn.Module:
+ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
  config = get_model_config_2b(**kwargs)
  model = Gemma(config)
  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
@@ -195,6 +178,8 @@ def build_2b_model(checkpoint_path, **kwargs) -> nn.Module:
 
 
  def define_and_run_2b() -> None:
+ """Instantiates and runs a Gemma 2B model."""
+
  current_dir = Path(__file__).parent.resolve()
  gemma_goldens = torch.load(current_dir / "gemma_lm_logits.pt")
 
ai_edge_torch/generative/examples/gemma/gemma2.py CHANGED
@@ -18,14 +18,14 @@ import os
  from pathlib import Path
  from typing import Optional, Tuple
 
- from ai_edge_torch.generative.layers.attention import TransformerBlock
+ from ai_edge_torch.generative.layers import attention
+ from ai_edge_torch.generative.layers import builder
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
- import ai_edge_torch.generative.layers.builder as builder
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.utilities.loader as loading_utils
  import numpy as np
  import torch
- import torch.nn as nn
+ from torch import nn
 
  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
  ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -43,7 +43,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
  )
 
 
- class Gemma2Block(TransformerBlock):
+ class Gemma2Block(attention.TransformerBlock):
 
  def forward(
  self,
@@ -76,6 +76,7 @@ class Gemma2Block(TransformerBlock):
 
 
  class Gemma2(nn.Module):
+ """A Gemma2 model built from the Edge Generative API layers."""
 
  def __init__(self, config: cfg.ModelConfig):
  super().__init__()
@@ -138,9 +139,9 @@ class Gemma2(nn.Module):
 
  @torch.inference_mode
  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
- B, T = idx.size()
- assert self.config.max_seq_len >= T, (
- f"Cannot forward sequence of length {T}, max seq length is only"
+ _, seq_len = idx.size()
+ assert self.config.max_seq_len >= seq_len, (
+ f"Cannot forward sequence of length {seq_len}, max seq length is only"
  f" {self.config.max_seq_len}"
  )
 
@@ -166,6 +167,15 @@ class Gemma2(nn.Module):
 
 
  def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+ """Returns the model config for a Gemma2 2B model.
+
+ Args:
+ kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+ is 1024.
+
+ Returns:
+ The model config for a Gemma 2B model.
+ """
  attn_config = cfg.AttentionConfig(
  num_heads=8,
  head_dim=256,
@@ -210,50 +220,19 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
 
 
  def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
- attn_config = cfg.AttentionConfig(
- num_heads=4,
- head_dim=64,
- num_query_groups=4,
- rotary_percentage=1.0,
- qkv_transpose_before_split=True,
- logit_softcap=50.0,
- sliding_window_size=64,
- attn_types=[cfg.AttentionType.GLOBAL, cfg.AttentionType.LOCAL_SLIDING]
- * 13,
- )
-
- norm_config = cfg.NormalizationConfig(
- type=cfg.NormalizationType.RMS_NORM,
- epsilon=1e-6,
- zero_centered=True,
- )
- ff_config = cfg.FeedForwardConfig(
- type=cfg.FeedForwardType.GATED,
- activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
- intermediate_size=128,
- pre_ff_norm_config=norm_config,
- post_ff_norm_config=norm_config,
- )
- config = cfg.ModelConfig(
- vocab_size=128,
- num_layers=2,
- max_seq_len=2 * kv_cache_max_len,
- embedding_dim=128,
- kv_cache_max_len=kv_cache_max_len,
- attn_config=attn_config,
- ff_config=ff_config,
- pre_attention_norm_config=norm_config,
- post_attention_norm_config=norm_config,
- final_norm_config=norm_config,
- parallel_residual=False,
- lm_head_use_bias=False,
- enable_hlfb=True,
- final_logit_softcap=30.0,
- )
+ config = get_model_config_2b(kv_cache_max_len)
+ config.attn_config.num_heads = 4
+ config.attn_config.head_dim = 64
+ config.attn_config.sliding_window_size = 64
+ config.ff_config.intermediate_size = 128
+ config.vocab_size = 128
+ config.num_layers = 2
+ config.max_seq_len = 2 * kv_cache_max_len
+ config.embedding_dim = 128
  return config
 
 
- def build_2b_model(checkpoint_path, **kwargs) -> nn.Module:
+ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
  config = get_model_config_2b(**kwargs)
  model = Gemma2(config)
  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
@@ -265,6 +244,8 @@ def build_2b_model(checkpoint_path, **kwargs) -> nn.Module:
 
 
  def define_and_run_2b() -> None:
+ """Instantiates and runs a Gemma2 2B model."""
+
  current_dir = Path(__file__).parent.resolve()
  gemma2_goldens = torch.load(current_dir / "gemma2it_2b_golden.pt")
  print("Running GEMMA 2")
ai_edge_torch/generative/examples/phi2/phi2.py CHANGED
@@ -18,14 +18,14 @@
  import os
  from pathlib import Path
 
- from ai_edge_torch.generative.layers.attention import TransformerBlock
+ from ai_edge_torch.generative.layers import attention
+ from ai_edge_torch.generative.layers import builder
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
- import ai_edge_torch.generative.layers.builder as builder
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.utilities.loader as loading_utils
  import numpy as np
  import torch
- import torch.nn as nn
+ from torch import nn
 
  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
  ff_up_proj="model.layers.{}.mlp.fc1",
@@ -42,6 +42,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 
 
  class Phi2(nn.Module):
+ """A Phi-2 model built from the Edge Generative API layers."""
 
  def __init__(self, config: cfg.ModelConfig):
  super().__init__()
@@ -55,7 +56,7 @@ class Phi2(nn.Module):
  config.vocab_size, config.embedding_dim, padding_idx=0
  )
  self.transformer_blocks = nn.ModuleList(
- TransformerBlock(config) for _ in range(config.num_layers)
+ attention.TransformerBlock(config) for _ in range(config.num_layers)
  )
  self.final_norm = builder.build_norm(
  config.embedding_dim,
@@ -83,9 +84,9 @@ class Phi2(nn.Module):
  # This can be eliminated if we handle k/v cache updates inside the model itself.
  @torch.inference_mode
  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
- B, T = idx.size()
- assert self.config.max_seq_len >= T, (
- f"Cannot forward sequence of length {T}, max seq length is only"
+ _, seq_len = idx.size()
+ assert self.config.max_seq_len >= seq_len, (
+ f"Cannot forward sequence of length {seq_len}, max seq length is only"
  f" {self.config.max_seq_len}"
  )
 
@@ -98,7 +99,7 @@ class Phi2(nn.Module):
  # forward the model itself
  x = self.tok_embedding(idx)  # token embeddings of shape (b, t, n_embd)
 
- for i, block in enumerate(self.transformer_blocks):
+ for _, block in enumerate(self.transformer_blocks):
  x = block(x, (cos, sin), mask, input_pos)
 
  x = self.final_norm(x)
@@ -107,6 +108,15 @@ class Phi2(nn.Module):
 
 
  def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+ """Returns the model config for a Phi-2 model.
+
+ Args:
+ kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+ is 1024.
+
+ Returns:
+ The model config for a Phi-2 model.
+ """
  attn_config = cfg.AttentionConfig(
  num_heads=32,
  head_dim=80,
@@ -140,35 +150,11 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
 
 
  def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
- attn_config = cfg.AttentionConfig(
- num_heads=16,
- head_dim=80,
- num_query_groups=4,
- rotary_percentage=0.4,
- qkv_use_bias=True,
- output_proj_use_bias=True,
- )
- ff_config = cfg.FeedForwardConfig(
- type=cfg.FeedForwardType.SEQUENTIAL,
- activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
- intermediate_size=128,
- use_bias=True,
- )
- norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
- config = cfg.ModelConfig(
- vocab_size=128,
- num_layers=2,
- max_seq_len=2 * kv_cache_max_len,
- kv_cache_max_len=kv_cache_max_len,
- embedding_dim=128,
- attn_config=attn_config,
- ff_config=ff_config,
- pre_attention_norm_config=norm_config,
- final_norm_config=norm_config,
- parallel_residual=True,
- lm_head_use_bias=True,
- enable_hlfb=True,
- )
+ config = get_model_config(kv_cache_max_len)
+ config.vocab_size = 128
+ config.num_layers = 2
+ config.max_seq_len = 2 * kv_cache_max_len
+ config.ff_config.intermediate_size = 128
  return config
 
 
@@ -181,6 +167,8 @@ def build_model(checkpoint_path, **kwargs) -> nn.Module:
 
 
  def define_and_run() -> None:
+ """Instantiates and runs a Phi-2 model."""
+
  current_dir = Path(__file__).parent.resolve()
  phi2_goldens = torch.load(current_dir / "phi2_lm_logits.pt")
  kv_cache_max_len = 1024
ai_edge_torch/generative/examples/test_models/toy_model.py CHANGED
@@ -71,6 +71,56 @@ class ToySingleLayerModel(torch.nn.Module):
  return self.lm_head(x)
 
 
+ class ToySingleLayerModelWeightSharing(torch.nn.Module):
+
+ def __init__(self, config: cfg.ModelConfig) -> None:
+ super().__init__()
+ self.lm_head = nn.Linear(
+ config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
+ )
+ self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
+ self.lm_head = nn.Linear(
+ config.embedding_dim,
+ config.vocab_size,
+ bias=config.lm_head_use_bias,
+ )
+ self.lm_head.weight.data = self.tok_embedding.weight.data
+ self.transformer_block = TransformerBlock(config)
+ self.final_norm = builder.build_norm(
+ config.embedding_dim,
+ config.final_norm_config,
+ )
+ self.rope_cache = attn_utils.build_rope_cache(
+ size=config.max_seq_len,
+ dim=int(
+ config.attn_config.rotary_percentage * config.attn_config.head_dim
+ ),
+ base=10_000,
+ condense_ratio=1,
+ dtype=torch.float32,
+ device=torch.device('cpu'),
+ )
+ self.mask_cache = attn_utils.build_causal_mask_cache(
+ size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
+ )
+ self.config = config
+
+ @torch.inference_mode
+ def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
+ x = self.tok_embedding(idx)
+ cos, sin = self.rope_cache
+
+ cos = cos.index_select(0, input_pos)
+ sin = sin.index_select(0, input_pos)
+ mask = self.mask_cache.index_select(2, input_pos)
+ mask = mask[:, :, :, : self.config.max_seq_len]
+
+ x = self.transformer_block(x, (cos, sin), mask, input_pos)
+ x = self.final_norm(x)
+ res = self.lm_head(x)
+ return res
+
+
  def get_model_config() -> cfg.ModelConfig:
  attn_config = cfg.AttentionConfig(
  num_heads=32,
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py CHANGED
@@ -17,14 +17,14 @@
  import os
  from pathlib import Path
 
- from ai_edge_torch.generative.layers.attention import TransformerBlock
+ from ai_edge_torch.generative.layers import attention
+ from ai_edge_torch.generative.layers import builder
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
- import ai_edge_torch.generative.layers.builder as builder
  import ai_edge_torch.generative.layers.model_config as cfg
  import ai_edge_torch.generative.utilities.loader as loading_utils
  import numpy as np
  import torch
- import torch.nn as nn
+ from torch import nn
 
  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
  ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -43,6 +43,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 
 
  class TinyLLamma(nn.Module):
+ """A TinyLlama model built from the Edge Generative API layers."""
 
  def __init__(self, config: cfg.ModelConfig):
  super().__init__()
@@ -56,7 +57,7 @@ class TinyLLamma(nn.Module):
  config.vocab_size, config.embedding_dim, padding_idx=0
  )
  self.transformer_blocks = nn.ModuleList(
- TransformerBlock(config) for _ in range(config.num_layers)
+ attention.TransformerBlock(config) for _ in range(config.num_layers)
  )
  self.final_norm = builder.build_norm(
  config.embedding_dim,
@@ -84,9 +85,9 @@ class TinyLLamma(nn.Module):
  # This can be eliminated if we handle k/v cache updates inside the model itself.
  @torch.inference_mode
  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
- B, T = idx.size()
- assert self.config.max_seq_len >= T, (
- f"Cannot forward sequence of length {T}, max seq length is only"
+ _, seq_len = idx.size()
+ assert self.config.max_seq_len >= seq_len, (
+ f"Cannot forward sequence of length {seq_len}, max seq length is only"
  f" {self.config.max_seq_len}"
  )
 
@@ -99,7 +100,7 @@ class TinyLLamma(nn.Module):
  # forward the model itself
  x = self.tok_embedding(idx)  # token embeddings of shape (b, t, n_embd)
 
- for i, block in enumerate(self.transformer_blocks):
+ for _, block in enumerate(self.transformer_blocks):
  x = block(x, (cos, sin), mask, input_pos)
 
  x = self.final_norm(x)
@@ -109,6 +110,15 @@ class TinyLLamma(nn.Module):
 
 
  def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+ """Returns the model config for a TinyLlama model.
+
+ Args:
+ kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+ is 1024.
+
+ Returns:
+ The model config for a TinyLlama model.
+ """
  attn_config = cfg.AttentionConfig(
  num_heads=32,
  head_dim=64,
@@ -145,7 +155,7 @@ def get_fake_model_config() -> cfg.ModelConfig:
  return config
 
 
- def build_model(checkpoint_path, **kwargs) -> nn.Module:
+ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
  config = get_model_config(**kwargs)
  model = TinyLLamma(config)
  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
@@ -154,6 +164,8 @@ def build_model(checkpoint_path, **kwargs) -> nn.Module:
 
 
  def define_and_run() -> None:
+ """Instantiates and runs a TinyLlama model."""
+
  current_dir = Path(__file__).parent.resolve()
  tiny_llama_goldens = torch.load(current_dir / "tiny_llama_lm_logits.pt")
  kv_cache_max_len = 1024
ai_edge_torch/generative/test/test_model_conversion.py CHANGED
@@ -70,35 +70,6 @@ class TestModelConversion(googletest.TestCase):
  )
  )
 
- @googletest.skipIf(
- ai_edge_config.Config.use_torch_xla,
- reason="tests with custom ops are not supported on oss",
- )
- def test_toy_model_with_multi_batches(self):
- self.skipTest("b/362842043")
- config = toy_model_with_kv_cache.get_model_config()
- config.batch_size = 2
- pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config).eval()
- idx, input_pos = torch.tensor([[1], [2]], dtype=torch.long), torch.tensor(
- [10], dtype=torch.int64
- )
-
- edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
- edge_model.set_interpreter_builder(
- self._interpreter_builder(edge_model.tflite_model())
- )
-
- self.assertTrue(
- model_coverage.compare_tflite_torch(
- edge_model,
- pytorch_model,
- (idx, input_pos),
- num_valid_inputs=1,
- atol=1e-5,
- rtol=1e-5,
- )
- )
-
  @googletest.skipIf(
  ai_edge_config.Config.use_torch_xla,
  reason="tests with custom ops are not supported on oss",
ai_edge_torch/generative/test/test_quantize.py CHANGED
@@ -25,16 +25,16 @@ from ai_edge_torch.generative.quantize.quant_attrs import Granularity
  from ai_edge_torch.generative.quantize.quant_attrs import Mode
  from ai_edge_torch.quantize import quant_config
  from ai_edge_torch.testing import model_coverage
- from parameterized import parameterized
  import torch
 
  from absl.testing import absltest as googletest
+ from absl.testing import parameterized
 
 
- class TestVerifyRecipes(googletest.TestCase):
+ class TestVerifyRecipes(parameterized.TestCase):
  """Unit tests that check for model quantization recipes."""
 
- @parameterized.expand([
+ @parameterized.parameters([
  (Dtype.FP32, Dtype.FP32),
  (Dtype.INT8, Dtype.INT8),
  (Dtype.INT8, Dtype.FP16),
@@ -52,7 +52,7 @@ class TestVerifyRecipes(googletest.TestCase):
  with self.assertRaises(ValueError):
  quant_recipe.LayerQuantRecipe(activation, weight, m, a, g).verify()
 
- @parameterized.expand([
+ @parameterized.parameters([
  (
  Dtype.FP32,
  Dtype.INT8,
@@ -88,7 +88,7 @@ class TestVerifyRecipes(googletest.TestCase):
  ).verify()
 
 
- class TestQuantizeConvert(googletest.TestCase):
+ class TestQuantizeConvert(parameterized.TestCase):
  """Test conversion with quantization."""
 
  def _attention_int8_dynamic_recipe() -> quant_config.QuantConfig:
@@ -105,17 +105,13 @@ class TestQuantizeConvert(googletest.TestCase):
  )
  )
 
- @parameterized.expand([
+ @parameterized.parameters([
  (quant_recipes.full_fp16_recipe()),
  (quant_recipes.full_int8_dynamic_recipe()),
  (quant_recipes.full_int8_weight_only_recipe()),
  (_attention_int8_dynamic_recipe()),
  (_feedforward_int8_dynamic_recipe()),
  ])
- @googletest.skipIf(
- not config.Config.use_torch_xla,
- reason="Not working with odml_torch at the moment.",
- )
  def test_quantize_convert_toy_sizes(self, quant_config):
  config = toy_model.get_model_config()
  pytorch_model = toy_model.ToySingleLayerModel(config)
@@ -132,6 +128,23 @@ class TestQuantizeConvert(googletest.TestCase):
  "Quantized model isn't smaller than F32 model.",
  )
 
+ def test_quantize_convert_toy_weight_sharing(self):
+ config = toy_model.get_model_config()
+ pytorch_model = toy_model.ToySingleLayerModelWeightSharing(config)
+ idx = torch.unsqueeze(torch.arange(0, 100), 0)
+ input_pos = torch.arange(0, 100)
+
+ quant_config = quant_recipes.full_int8_dynamic_recipe()
+ quantized_model = ai_edge_torch.convert(
+ pytorch_model, (idx, input_pos), quant_config=quant_config
+ )
+ float_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
+ self.assertLess(
+ len(quantized_model._tflite_model),
+ len(float_model._tflite_model),
+ "Quantized model isn't smaller than F32 model.",
+ )
+
  def test_quantize_convert_compare_toy(self):
  self.skipTest("b/338288901")
  config = toy_model_with_kv_cache.get_model_config()
ai_edge_torch/generative/utilities/loader.py CHANGED
@@ -208,7 +208,7 @@ class ModelLoader:
  if self._file_name.endswith(".safetensors"):
  return load_safetensors
 
- if self._file_name.endswith(".bin") or self._file_name.endswith(".pt"):
+ if self._file_name.endswith(".bin") or self._file_name.endswith("pt"):
  return load_pytorch_statedict
 
  raise ValueError("File format not supported.")
ai_edge_torch/lowertools/odml_torch_utils.py CHANGED
@@ -21,6 +21,7 @@ from ai_edge_torch import odml_torch
  from ai_edge_torch._convert import conversion_utils
  from ai_edge_torch._convert import signature as signature_module
  from ai_edge_torch.lowertools import common_utils
+ from ai_edge_torch.lowertools import translate_recipe
  from ai_edge_torch.odml_torch import export
  from ai_edge_torch.odml_torch import export_utils
  from ai_edge_torch.quantize import quant_config as qcfg
@@ -186,10 +187,29 @@ def merged_bundle_to_tfl_model(
  converter._experimental_enable_composite_direct_lowering = True
  converter.model_origin_framework = "PYTORCH"
 
+ conversion_utils.set_tfl_converter_quant_flags(converter, quant_config)
+ if (
+ quant_config is not None
+ and quant_config._quantizer_mode
+ == quant_config._QuantizerMode.AI_EDGE_QUANTIZER
+ ):
+ translated_recipe = translate_recipe.translate_to_ai_edge_recipe(
+ quant_config.generative_recipe
+ )
+
  conversion_utils.apply_tfl_converter_flags(converter, _tfl_converter_flags)
 
  tflite_model = converter.convert()
 
+ if (
+ quant_config is not None
+ and quant_config._quantizer_mode
+ == quant_config._QuantizerMode.AI_EDGE_QUANTIZER
+ ):
+ tflite_model = translate_recipe.quantize_model(
+ tflite_model, translated_recipe
+ )
+
  return tflite_model
 
 
ai_edge_torch/lowertools/torch_xla_utils.py CHANGED
@@ -25,8 +25,8 @@ from typing import Any, Dict, Optional, Tuple, Union
  from ai_edge_torch import model
  from ai_edge_torch._convert import conversion_utils
  from ai_edge_torch._convert import signature as signature_module
- from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe # NOQA
  from ai_edge_torch.lowertools import common_utils
+ from ai_edge_torch.lowertools import translate_recipe
  from ai_edge_torch.quantize import quant_config as qcfg
  import torch
  from torch_xla import stablehlo
ai_edge_torch/{generative/quantize/ai_edge_quantizer_glue → lowertools}/translate_recipe.py RENAMED
@@ -17,7 +17,8 @@ from ai_edge_quantizer import quantizer
  from ai_edge_torch.generative.quantize import quant_attrs
  from ai_edge_torch.generative.quantize import quant_recipe
 
- _OpExecutionMode = quantizer.qtyping.OpExecutionMode
+ _ComputePrecision = quantizer.qtyping.ComputePrecision
+ _QuantGranularity = quantizer.qtyping.QuantGranularity
  _OpName = quantizer.qtyping.TFLOperationName
  _TensorQuantConfig = quantizer.qtyping.TensorQuantizationConfig
  _OpQuantConfig = quantizer.qtyping.OpQuantizationConfig
@@ -50,21 +51,31 @@ def _get_dtype_from_dtype(
  return quantizer.qtyping.TensorDataType.INT
 
 
- def _get_execution_mode_from_mode(mode: quant_attrs.Mode) -> _OpExecutionMode:
+ def _get_compute_precision_from_mode(
+ mode: quant_attrs.Mode,
+ ) -> _ComputePrecision:
  if mode == quant_attrs.Mode.DYNAMIC_RANGE:
- return _OpExecutionMode.DRQ
+ return _ComputePrecision.INTEGER
  elif mode == quant_attrs.Mode.WEIGHT_ONLY:
- return _OpExecutionMode.WEIGHT_ONLY
+ return _ComputePrecision.FLOAT
  raise ValueError('Unimplemented execution mode')
 
 
- def _get_channelwise_from_granularity(
+ def _get_explicit_dequant_from_mode(mode: quant_attrs.Mode) -> bool:
+ if mode == quant_attrs.Mode.DYNAMIC_RANGE:
+ return False
+ elif mode == quant_attrs.Mode.WEIGHT_ONLY:
+ return True
+ raise ValueError('Unimplemented execution mode')
+
+
+ def _get_granularity(
  granularity: quant_attrs.Granularity,
  ) -> bool:
  if granularity == quant_attrs.Granularity.CHANNELWISE:
- return True
- elif granularity == quant_attrs.Granularity.NONE:
- return False
+ return _QuantGranularity.CHANNELWISE
+ if granularity == quant_attrs.Granularity.NONE:
+ return _QuantGranularity.TENSORWISE
  raise ValueError('Unimplemented granularity')
 
 
@@ -88,12 +99,13 @@ def _set_quant_config(
  weight_tensor_config=_TensorQuantConfig(
  num_bits=_get_nbits_from_dtype(layer_recipe.weight_dtype),
  symmetric=True,
- channel_wise=_get_channelwise_from_granularity(
- layer_recipe.granularity
- ),
+ granularity=_get_granularity(layer_recipe.granularity),
  dtype=_get_dtype_from_dtype(layer_recipe.weight_dtype),
  ),
- execution_mode=_get_execution_mode_from_mode(layer_recipe.mode),
+ compute_precision=_get_compute_precision_from_mode(layer_recipe.mode),
+ explicit_dequantize=_get_explicit_dequant_from_mode(
+ layer_recipe.mode
+ ),
  ),
  algorithm_key=_get_algorithm_key_from_algorithm(layer_recipe.algorithm),
  )
ai_edge_torch/odml_torch/export.py CHANGED
@@ -277,7 +277,7 @@ def exported_program_to_mlir(
  main_func.attributes["sym_visibility"] = ir.StringAttr.get("public")
  temp_func.erase()
 
- module.operation.verify()
+ module.operation.verify()
 
  input_signature = []
  state_dict = {}
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
  # limitations under the License.
  # ==============================================================================
 
- __version__ = "0.3.0.dev20240902"
+ __version__ = "0.3.0.dev20240905"
{ai_edge_torch_nightly-0.3.0.dev20240902.dist-info → ai_edge_torch_nightly-0.3.0.dev20240905.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ai-edge-torch-nightly
- Version: 0.3.0.dev20240902
+ Version: 0.3.0.dev20240905
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -30,7 +30,7 @@ Requires-Dist: tabulate
  Requires-Dist: torch>=2.4.0
  Requires-Dist: torch-xla>=2.4.0
  Requires-Dist: tf-nightly>=2.18.0.dev20240722
- Requires-Dist: ai-edge-quantizer-nightly==0.0.1.dev20240718
+ Requires-Dist: ai-edge-quantizer-nightly
 
  Library that supports converting PyTorch models into a .tflite format, which can
  then be run with TensorFlow Lite and MediaPipe. This enables applications for
{ai_edge_torch_nightly-0.3.0.dev20240902.dist-info → ai_edge_torch_nightly-0.3.0.dev20240905.dist-info}/RECORD RENAMED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
  ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
  ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
- ai_edge_torch/version.py,sha256=pl_weDdkMIjqukMxBF4uho_z-MvFlGy_ButOq6tJwVc,706
+ ai_edge_torch/version.py,sha256=-vQGdl2EaV-VpHRty3RwZzH0UVntVt1tmjhtKOIDscw,706
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -42,21 +42,21 @@ ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQe
  ai_edge_torch/generative/examples/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/experimental/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py,sha256=lpiPFSh3SJd6WwuZ0QegSva3__iSz2tUD7L7QfkAe4I,3085
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py,sha256=EdElPCDLYxnNvkPMJkE3WKvESze1ehgShEk2NnbrXLg,7527
+ ai_edge_torch/generative/examples/experimental/gemma/gemma.py,sha256=aCoD86pf4nuquUMk7MOR-jsN5FqvySSEuMx9Psxjblk,7261
  ai_edge_torch/generative/examples/experimental/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py,sha256=DavrdGmqUgoThsGNRv3LXMW5tvJdYEvj66Hf1XRqkXU,3055
- ai_edge_torch/generative/examples/experimental/phi/phi2.py,sha256=u-VJX5mjzQKspXtAhNi53LCITtag-3nCaRTKdk5Z1sc,6231
+ ai_edge_torch/generative/examples/experimental/phi/phi2.py,sha256=Jxf3ZyYDpS78l6uh4_LGGIcHawrOhZ1vHoHFVxRaK40,6789
  ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py,sha256=xPVvHQjLJHFiRv_-Fy2sDm0Aft7SG8SXiV6o3rF03cQ,3108
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py,sha256=zQYtyk3xYdiRAnzMKN58Q_wgTQFnDujxp6L4RFQjiD4,6383
+ ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py,sha256=nUm0SQbCTmNAc5u-C9gbQRFPt7GDvUt6UjH6doTvH-I,6817
  ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=pseJExH35lSAK0ZtzSHB1sFtRtF_EuT2xcSpGU0gKVI,2524
  ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=w589IJETATd6Z9_1XCIWbrlCV3E92X_5ac3VVCVFXG0,2522
- ai_edge_torch/generative/examples/gemma/gemma.py,sha256=pzD9dYUYg8E6fFACh-8B8G9NHFXOVEWBjf5aDeipU2s,7202
- ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=ypd6uBb4FgDpuWm_w8JNYBAf4eFxWbYccs8vCgBhi-I,9374
+ ai_edge_torch/generative/examples/gemma/gemma.py,sha256=lc1-CfIObHj9D5VJy78BOtGTrQM4TYMI6NfVi8KM5qA,6747
+ ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=OcUQLFR136e3QRVXRnmtYnRHXyHJS9EYEFlJ1ymXyRY,8859
  ai_edge_torch/generative/examples/phi2/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=ON6zLO-nFS8eJ2yhyWzT5x2Somr-Ca-VjpjT7OGFU10,2506
- ai_edge_torch/generative/examples/phi2/phi2.py,sha256=91mWxEtKgDtUhCAewWNwH_UOOCzy6tPdf6LNRlxZhrc,6700
+ ai_edge_torch/generative/examples/phi2/phi2.py,sha256=FFnhv1kx4fHRhSeOreLGj8kAqPnmkz9pD1RRSDVlM_w,6332
  ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
  ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=0WniBWQ6_NcQc5WycX3YRRX7Os9AGQSxfc1m2HKBqg8,4479
@@ -77,12 +77,12 @@ ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=CZVuNEL8OHPkdsz
  ai_edge_torch/generative/examples/t5/t5.py,sha256=Zobw5BV-PC0nlU9Z6fzb2O07rMeU8vGIk-KtKp9D_H0,20871
  ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=1lvbSlzyBwmd5Bs7-Up_v4iJQkCPIJx2RmMkLgy7l2Q,8508
  ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=LfWO_gSr1f66V1pxAc6yh21mtaJs7TVeuO9748zXBnE,3963
+ ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=5wj2RmQRIwD6O_R_pp-A_7gKGSdHWDSXyis97r1ELVI,5622
  ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=l9swUKTcDtnTibNSNExaMgLvDeJ4Er2tVh5ZW1EtRgk,5809
  ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=mQkcpSe6HlRLMkIRCEHc9ZXL7jxEp9RWSGUQjjd-r2w,4841
  ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=CLRqO7ycMbpy7J3_Czp1sLx6hcdwGD9zVq04yRba0e8,2550
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=JmwU1sniO37vnCFc8dklbd-0ofTZK0PaBv_Ksn1Vq6M,5930
+ ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=4ku0ni3MOWamhPrzLap0BmtdNFk7CH0hwjPNoRAKpvQ,6278
  ai_edge_torch/generative/fx_passes/__init__.py,sha256=fmNNXawJ722M4cTUuTx289rT0NHxBEsOy_k8baqCOms,1173
  ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=sXis0U4u-RoIp_NyrmWJNnqFqpqRuZOrhfsJIO6rMps,2028
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -106,16 +106,14 @@ ai_edge_torch/generative/quantize/quant_recipe.py,sha256=tKnuJq6hPD23JPCB9nPAlE1
  ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=4fgmP_GgeiFUOkIaC9ZZXC12eO3DQZdrWDXRz5YXiwU,2270
  ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FBxkHM6RJ3C14B2I1mjItjc,2030
  ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py,sha256=sSHc_4hUEvi-3KmqbpqWbrRKBjCI1AOctM3dr2EH3vk,5263
  ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/generative/test/test_experimental_ekv.py,sha256=8qv_eVtJW9GPvBEf2hPQe3tpdJ33XShya6MCX1FqrZM,4355
  ai_edge_torch/generative/test/test_loader.py,sha256=_y5EHGgoNOmCuYonsB81UJScHVsTAQXUVd44czMAw6k,3379
- ai_edge_torch/generative/test/test_model_conversion.py,sha256=wQLVjMnKHBCVCU_I-xAUZvlOFoDiwYwKQDvCZ2mjtOM,6193
+ ai_edge_torch/generative/test/test_model_conversion.py,sha256=KZ0uCeOdKMKyW8jBE8aOjweZmws4mvz37u8zH4XayVU,5285
  ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=o3l7HFHP-sg8aHeLNTSpMF91YovPODjp4QzYUnSJiIE,4479
- ai_edge_torch/generative/test/test_quantize.py,sha256=JEsk9SAkHK0SFm44K_quISc5yBBS6yvtBP1MDyFHdFw,5344
+ ai_edge_torch/generative/test/test_quantize.py,sha256=kY_NRpF-v1i4clqI1CFFWEagJv-5PzBDkeJ2fInl9_w,5913
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
- ai_edge_torch/generative/utilities/loader.py,sha256=QFZ2lkeoYQ9MZ1CAFVxBHG4OT192SH74UtJCvbDsdeI,12727
+ ai_edge_torch/generative/utilities/loader.py,sha256=6J0aAP6-6LySeqeYIHKcchr5T9cVtSO34aoDr3V9gxY,12726
  ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=pKp3AMSbS3otCvgwJRF5M1l4JRNKk-aCKimXzIMSrds,35679
  ai_edge_torch/generative/utilities/t5_loader.py,sha256=_UXcc1QKT-S92hikfo-fTBFhnYLzROqcyRqKonVsqj4,16885
  ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
@@ -128,13 +126,14 @@ ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=j8WpeS-mz3Zr4
  ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
  ai_edge_torch/lowertools/_shim.py,sha256=ilL7x1ebUBj1clg7bagrX4y_nVSHiGrvDrOVfuTeenE,3039
  ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
- ai_edge_torch/lowertools/odml_torch_utils.py,sha256=GKfW1X-QSFffQdVlBuD-bNpP265xcdUlfBY3-9I4f_o,7447
+ ai_edge_torch/lowertools/odml_torch_utils.py,sha256=K5dZ_fFDL3GWKo0IoY4OC_GX5MY-guY-MqteolyV9hg,8098
  ai_edge_torch/lowertools/test_utils.py,sha256=bPgc2iXX16KYtMNvmsRdKfrCY6UJmcfitfCOvHoD7Oc,1930
- ai_edge_torch/lowertools/torch_xla_utils.py,sha256=-SRm9YNsIGsaVd5Cyp2PP-tdLBJH8EDoMFAa2y89a1w,9043
+ ai_edge_torch/lowertools/torch_xla_utils.py,sha256=n6G3pFGmHar7kgKDsdTB74kv1PUuTTu1XjV7R-QizzE,9003
+ ai_edge_torch/lowertools/translate_recipe.py,sha256=DNzD0VD35YZDqiZjAF1IyIPSzUGPDpE0jvFCCYIzpnc,5667
  ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
  ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
  ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
- ai_edge_torch/odml_torch/export.py,sha256=hIGT-JKYbIa6e_G0AD-k4MSTIAMGdHC1hNHMn9CxsYw,10467
+ ai_edge_torch/odml_torch/export.py,sha256=OXN6jipwFtBvQ9XdyeDGQTQ_-UnCxPYnLc_WW7xF0aI,10469
  ai_edge_torch/odml_torch/export_utils.py,sha256=q84U69ZQ82hLXw-xncJ8IW-K71Xux-NWlzZTs7hdZWA,5127
  ai_edge_torch/odml_torch/tf_integration.py,sha256=lTFJPPEijLPFmn6qq2jbpVTQOo0YaOTK36kK6rCiyIE,5956
  ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
@@ -162,8 +161,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
- ai_edge_torch_nightly-0.3.0.dev20240902.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
- ai_edge_torch_nightly-0.3.0.dev20240902.dist-info/METADATA,sha256=Bvc6_uRgjiaqUsVareqJETErsc5rU7NNTPTDqn0JwoA,1878
- ai_edge_torch_nightly-0.3.0.dev20240902.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
- ai_edge_torch_nightly-0.3.0.dev20240902.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
- ai_edge_torch_nightly-0.3.0.dev20240902.dist-info/RECORD,,
+ ai_edge_torch_nightly-0.3.0.dev20240905.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ ai_edge_torch_nightly-0.3.0.dev20240905.dist-info/METADATA,sha256=8yrrm7TEYgaRhKdUwgStjCqrTWs8YcnnlzoTJt2NrJg,1859
+ ai_edge_torch_nightly-0.3.0.dev20240905.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ ai_edge_torch_nightly-0.3.0.dev20240905.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ ai_edge_torch_nightly-0.3.0.dev20240905.dist-info/RECORD,,
ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py DELETED
@@ -1,14 +0,0 @@
- # Copyright 2024 The AI Edge Torch Authors.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ==============================================================================