ai-edge-torch-nightly 0.3.0.dev20240902__py3-none-any.whl → 0.3.0.dev20240905__py3-none-any.whl

Files changed (22)
  1. ai_edge_torch/generative/examples/experimental/gemma/gemma.py +32 -42
  2. ai_edge_torch/generative/examples/experimental/phi/phi2.py +33 -15
  3. ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py +29 -14
  4. ai_edge_torch/generative/examples/gemma/gemma.py +26 -41
  5. ai_edge_torch/generative/examples/gemma/gemma2.py +29 -48
  6. ai_edge_torch/generative/examples/phi2/phi2.py +25 -37
  7. ai_edge_torch/generative/examples/test_models/toy_model.py +50 -0
  8. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +21 -9
  9. ai_edge_torch/generative/test/test_model_conversion.py +0 -29
  10. ai_edge_torch/generative/test/test_quantize.py +23 -10
  11. ai_edge_torch/generative/utilities/loader.py +1 -1
  12. ai_edge_torch/lowertools/odml_torch_utils.py +20 -0
  13. ai_edge_torch/lowertools/torch_xla_utils.py +1 -1
  14. ai_edge_torch/{generative/quantize/ai_edge_quantizer_glue → lowertools}/translate_recipe.py +24 -12
  15. ai_edge_torch/odml_torch/export.py +1 -1
  16. ai_edge_torch/version.py +1 -1
  17. {ai_edge_torch_nightly-0.3.0.dev20240902.dist-info → ai_edge_torch_nightly-0.3.0.dev20240905.dist-info}/METADATA +2 -2
  18. {ai_edge_torch_nightly-0.3.0.dev20240902.dist-info → ai_edge_torch_nightly-0.3.0.dev20240905.dist-info}/RECORD +21 -22
  19. ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py +0 -14
  20. {ai_edge_torch_nightly-0.3.0.dev20240902.dist-info → ai_edge_torch_nightly-0.3.0.dev20240905.dist-info}/LICENSE +0 -0
  21. {ai_edge_torch_nightly-0.3.0.dev20240902.dist-info → ai_edge_torch_nightly-0.3.0.dev20240905.dist-info}/WHEEL +0 -0
  22. {ai_edge_torch_nightly-0.3.0.dev20240902.dist-info → ai_edge_torch_nightly-0.3.0.dev20240905.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/experimental/gemma/gemma.py CHANGED
@@ -21,15 +21,16 @@ import os
21
21
  from pathlib import Path
22
22
  from typing import Tuple
23
23
 
24
+ from ai_edge_torch.generative.layers import builder
24
25
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
25
- import ai_edge_torch.generative.layers.builder as builder
26
+ from ai_edge_torch.generative.layers.experimental import attention
26
27
  from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
27
- from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock # NOQA
28
28
  import ai_edge_torch.generative.layers.model_config as cfg
29
29
  import ai_edge_torch.generative.utilities.loader as loading_utils
30
30
  import numpy as np
31
31
  import torch
32
- import torch.nn as nn
32
+ from torch import nn
33
+
33
34
 
34
35
  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
35
36
  ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -48,6 +49,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
48
49
 
49
50
 
50
51
  class Gemma(nn.Module):
52
+ """A Gemma model built from the Edge Generative API layers."""
51
53
 
52
54
  def __init__(self, config: cfg.ModelConfig):
53
55
  super().__init__()
@@ -65,7 +67,7 @@ class Gemma(nn.Module):
65
67
  # Gemma re-uses the embedding as the head projection layer.
66
68
  self.lm_head.weight.data = self.tok_embedding.weight.data
67
69
  self.transformer_blocks = nn.ModuleList(
68
- TransformerBlock(config) for _ in range(config.num_layers)
70
+ attention.TransformerBlock(config) for _ in range(config.num_layers)
69
71
  )
70
72
  self.final_norm = builder.build_norm(
71
73
  config.embedding_dim,
@@ -95,9 +97,9 @@ class Gemma(nn.Module):
95
97
  input_pos: torch.Tensor,
96
98
  kv_cache: kv_utils.EKVCache,
97
99
  ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
98
- B, T = tokens.size()
99
- assert self.config.max_seq_len >= T, (
100
- f"Cannot forward sequence of length {T}, max seq length is only"
100
+ _, seq_len = tokens.size()
101
+ assert self.config.max_seq_len >= seq_len, (
102
+ f"Cannot forward sequence of length {seq_len}, max seq length is only"
101
103
  f" {self.config.max_seq_len}"
102
104
  )
103
105
 
@@ -125,6 +127,15 @@ class Gemma(nn.Module):
125
127
 
126
128
 
127
129
  def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
130
+ """Returns the model config for a Gemma 2B model.
131
+
132
+ Args:
133
+ kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
134
+ is 1024.
135
+
136
+ Returns:
137
+ The model config for a Gemma 2B model.
138
+ """
128
139
  attn_config = cfg.AttentionConfig(
129
140
  num_heads=8,
130
141
  head_dim=256,
@@ -160,41 +171,18 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
160
171
 
161
172
 
162
173
  def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
163
- attn_config = cfg.AttentionConfig(
164
- num_heads=8,
165
- head_dim=256,
166
- num_query_groups=1,
167
- rotary_percentage=1.0,
168
- )
169
- ff_config = cfg.FeedForwardConfig(
170
- type=cfg.FeedForwardType.GATED,
171
- activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
172
- intermediate_size=128,
173
- )
174
- norm_config = cfg.NormalizationConfig(
175
- type=cfg.NormalizationType.RMS_NORM,
176
- epsilon=1e-6,
177
- zero_centered=True,
178
- )
179
- config = cfg.ModelConfig(
180
- vocab_size=128,
181
- num_layers=2,
182
- max_seq_len=2 * kv_cache_max_len,
183
- embedding_dim=2048,
184
- kv_cache_max_len=kv_cache_max_len,
185
- attn_config=attn_config,
186
- ff_config=ff_config,
187
- pre_attention_norm_config=norm_config,
188
- post_attention_norm_config=norm_config,
189
- final_norm_config=norm_config,
190
- parallel_residual=False,
191
- lm_head_use_bias=False,
192
- enable_hlfb=True,
193
- )
174
+ config = get_model_config_2b(kv_cache_max_len)
175
+ config.ff_config.intermediate_size = 128
176
+ config.vocab_size = 128
177
+ config.num_layers = 2
178
+ config.max_seq_len = 2 * kv_cache_max_len
194
179
  return config
195
180
 
196
181
 
197
- def build_2b_model(checkpoint_path, test_model=False, **kwargs) -> nn.Module:
182
+ def build_2b_model(
183
+ checkpoint_path: str, test_model: bool = False, **kwargs
184
+ ) -> nn.Module:
185
+ """Instantiates the model instance and load checkpoint if provided."""
198
186
  config = (
199
187
  get_fake_model_config(**kwargs)
200
188
  if test_model
@@ -210,7 +198,9 @@ def build_2b_model(checkpoint_path, test_model=False, **kwargs) -> nn.Module:
210
198
  return model
211
199
 
212
200
 
213
- def define_and_run_2b(checkpoint_path, test_model=False) -> None:
201
+ def define_and_run_2b(checkpoint_path: str, test_model: bool = False) -> None:
202
+ """Instantiates and runs a Gemma 2B model."""
203
+
214
204
  kv_cache_max_len = 1024
215
205
  model = build_2b_model(
216
206
  checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
@@ -225,5 +215,5 @@ def define_and_run_2b(checkpoint_path, test_model=False) -> None:
225
215
 
226
216
 
227
217
  if __name__ == "__main__":
228
- checkpoint_path = os.path.join(Path.home(), "Downloads/gemma-2b")
229
- define_and_run_2b(checkpoint_path)
218
+ input_checkpoint_path = os.path.join(Path.home(), "Downloads/gemma-2b")
219
+ define_and_run_2b(input_checkpoint_path)
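The same refactor is applied to all three experimental examples (Gemma, Phi-2, TinyLlama below): the test-sized "fake" config is now derived from the real config instead of being spelled out field by field, and the builders gain typed signatures with a `test_model` flag. A minimal, hedged sketch of what that buys, using only the functions shown in this hunk:

```python
# Hedged sketch based on the hunk above: the reduced test config is derived
# from the full Gemma 2B config, so only the overridden fields differ.
from ai_edge_torch.generative.examples.experimental.gemma import gemma

config = gemma.get_fake_model_config(kv_cache_max_len=128)
assert config.num_layers == 2
assert config.vocab_size == 128
assert config.ff_config.intermediate_size == 128
assert config.max_seq_len == 2 * 128

# The reduced model is small enough to build directly, without a checkpoint;
# build_2b_model(checkpoint_path, test_model=True, ...) selects the same
# config when a checkpoint is available.
model = gemma.Gemma(config).eval()
```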
ai_edge_torch/generative/examples/experimental/phi/phi2.py CHANGED
@@ -17,20 +17,20 @@
17
17
  # Note: This is an experimental version of phi2 with external KV cache.
18
18
  # Please use with caution.
19
19
 
20
-
21
20
  import os
22
21
  from pathlib import Path
23
22
  from typing import Tuple
24
23
 
24
+ from ai_edge_torch.generative.layers import builder
25
25
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
26
- import ai_edge_torch.generative.layers.builder as builder
26
+ from ai_edge_torch.generative.layers.experimental import attention
27
27
  from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
28
- from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock # NOQA
29
28
  import ai_edge_torch.generative.layers.model_config as cfg
30
29
  import ai_edge_torch.generative.utilities.loader as loading_utils
31
30
  import numpy as np
32
31
  import torch
33
- import torch.nn as nn
32
+ from torch import nn
33
+
34
34
 
35
35
  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
36
36
  ff_up_proj="model.layers.{}.mlp.fc1",
@@ -47,6 +47,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
47
47
 
48
48
 
49
49
  class Phi2(nn.Module):
50
+ """A Phi-2 model built from the Edge Generative API layers."""
50
51
 
51
52
  def __init__(self, config: cfg.ModelConfig):
52
53
  super().__init__()
@@ -60,7 +61,7 @@ class Phi2(nn.Module):
60
61
  config.vocab_size, config.embedding_dim, padding_idx=0
61
62
  )
62
63
  self.transformer_blocks = nn.ModuleList(
63
- TransformerBlock(config) for _ in range(config.num_layers)
64
+ attention.TransformerBlock(config) for _ in range(config.num_layers)
64
65
  )
65
66
  self.final_norm = builder.build_norm(
66
67
  config.embedding_dim,
@@ -90,9 +91,9 @@ class Phi2(nn.Module):
90
91
  input_pos: torch.Tensor,
91
92
  kv_cache: kv_utils.EKVCache,
92
93
  ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
93
- B, T = tokens.size()
94
- assert self.config.max_seq_len >= T, (
95
- f"Cannot forward sequence of length {T}, max seq length is only"
94
+ _, seq_len = tokens.size()
95
+ assert self.config.max_seq_len >= seq_len, (
96
+ f"Cannot forward sequence of length {seq_len}, max seq length is only"
96
97
  f" {self.config.max_seq_len}"
97
98
  )
98
99
 
@@ -118,6 +119,15 @@ class Phi2(nn.Module):
118
119
 
119
120
 
120
121
  def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
122
+ """Returns the model config for a Phi-2 model.
123
+
124
+ Args:
125
+ kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
126
+ is 1024.
127
+
128
+ Returns:
129
+ The model config for a Phi-2 model.
130
+ """
121
131
  attn_config = cfg.AttentionConfig(
122
132
  num_heads=32,
123
133
  head_dim=80,
@@ -150,15 +160,21 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
150
160
  return config
151
161
 
152
162
 
153
- def get_fake_model_config_for_test(**kwargs) -> cfg.ModelConfig:
154
- config = get_model_config(**kwargs)
163
+ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
164
+ config = get_model_config(kv_cache_max_len)
165
+ config.vocab_size = 128
155
166
  config.num_layers = 2
167
+ config.max_seq_len = 2 * kv_cache_max_len
168
+ config.ff_config.intermediate_size = 128
156
169
  return config
157
170
 
158
171
 
159
- def build_model(checkpoint_path, test_model=False, **kwargs) -> nn.Module:
172
+ def build_model(
173
+ checkpoint_path: str, test_model: bool = False, **kwargs
174
+ ) -> nn.Module:
175
+ """Instantiates the model instance and load checkpoint if provided."""
160
176
  config = (
161
- get_fake_model_config_for_test(**kwargs)
177
+ get_fake_model_config(**kwargs)
162
178
  if test_model
163
179
  else get_model_config(**kwargs)
164
180
  )
@@ -170,7 +186,9 @@ def build_model(checkpoint_path, test_model=False, **kwargs) -> nn.Module:
170
186
  return model
171
187
 
172
188
 
173
- def define_and_run(checkpoint_path, test_model=False) -> None:
189
+ def define_and_run(checkpoint_path: str, test_model: bool = False) -> None:
190
+ """Instantiates and runs a Phi-2 model."""
191
+
174
192
  kv_cache_max_len = 1024
175
193
  model = build_model(
176
194
  checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
@@ -185,5 +203,5 @@ def define_and_run(checkpoint_path, test_model=False) -> None:
185
203
 
186
204
 
187
205
  if __name__ == "__main__":
188
- checkpoint_path = os.path.join(Path.home(), "Downloads/phi2")
189
- define_and_run(checkpoint_path)
206
+ input_checkpoint_path = os.path.join(Path.home(), "Downloads/phi2")
207
+ define_and_run(input_checkpoint_path)
ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py CHANGED
@@ -17,20 +17,20 @@
17
17
  # Note: This is an experimental version of TinyLlama with external KV cache.
18
18
  # Please use with caution.
19
19
 
20
-
21
20
  import os
22
21
  from pathlib import Path
23
22
  from typing import Tuple
24
23
 
24
+ from ai_edge_torch.generative.layers import builder
25
25
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
26
- import ai_edge_torch.generative.layers.builder as builder
26
+ from ai_edge_torch.generative.layers.experimental import attention
27
27
  from ai_edge_torch.generative.layers.experimental import ekv_cache as kv_utils
28
- from ai_edge_torch.generative.layers.experimental.attention import TransformerBlock # NOQA
29
28
  import ai_edge_torch.generative.layers.model_config as cfg
30
29
  import ai_edge_torch.generative.utilities.loader as loading_utils
31
30
  import numpy as np
32
31
  import torch
33
- import torch.nn as nn
32
+ from torch import nn
33
+
34
34
 
35
35
  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
36
36
  ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -49,6 +49,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
49
49
 
50
50
 
51
51
  class TinyLLamma(nn.Module):
52
+ """A TinyLlama model built from the Edge Generative API layers."""
52
53
 
53
54
  def __init__(self, config: cfg.ModelConfig):
54
55
  super().__init__()
@@ -62,7 +63,7 @@ class TinyLLamma(nn.Module):
62
63
  config.vocab_size, config.embedding_dim, padding_idx=0
63
64
  )
64
65
  self.transformer_blocks = nn.ModuleList(
65
- TransformerBlock(config) for _ in range(config.num_layers)
66
+ attention.TransformerBlock(config) for _ in range(config.num_layers)
66
67
  )
67
68
  self.final_norm = builder.build_norm(
68
69
  config.embedding_dim,
@@ -92,9 +93,9 @@ class TinyLLamma(nn.Module):
92
93
  input_pos: torch.Tensor,
93
94
  kv_cache: kv_utils.EKVCache,
94
95
  ) -> Tuple[torch.Tensor, kv_utils.EKVCache]:
95
- B, T = tokens.size()
96
- assert self.config.max_seq_len >= T, (
97
- f"Cannot forward sequence of length {T}, max seq length is only"
96
+ _, seq_len = tokens.size()
97
+ assert self.config.max_seq_len >= seq_len, (
98
+ f"Cannot forward sequence of length {seq_len}, max seq length is only"
98
99
  f" {self.config.max_seq_len}"
99
100
  )
100
101
 
@@ -121,6 +122,15 @@ class TinyLLamma(nn.Module):
121
122
 
122
123
 
123
124
  def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
125
+ """Returns the model config for a TinyLlama model.
126
+
127
+ Args:
128
+ kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
129
+ is 1024.
130
+
131
+ Returns:
132
+ The model config for a TinyLlama model.
133
+ """
124
134
  attn_config = cfg.AttentionConfig(
125
135
  num_heads=32,
126
136
  head_dim=64,
@@ -149,7 +159,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
149
159
  return config
150
160
 
151
161
 
152
- def get_fake_model_config_for_test(**kwargs) -> cfg.ModelConfig:
162
+ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
153
163
  config = get_model_config(**kwargs)
154
164
  config.vocab_size = 128
155
165
  config.num_layers = 2
@@ -157,9 +167,12 @@ def get_fake_model_config_for_test(**kwargs) -> cfg.ModelConfig:
157
167
  return config
158
168
 
159
169
 
160
- def build_model(checkpoint_path, test_model=False, **kwargs) -> nn.Module:
170
+ def build_model(
171
+ checkpoint_path: str, test_model: bool = False, **kwargs
172
+ ) -> nn.Module:
173
+ """Instantiates the model instance and load checkpoint if provided."""
161
174
  config = (
162
- get_fake_model_config_for_test(**kwargs)
175
+ get_fake_model_config(**kwargs)
163
176
  if test_model
164
177
  else get_model_config(**kwargs)
165
178
  )
@@ -171,7 +184,9 @@ def build_model(checkpoint_path, test_model=False, **kwargs) -> nn.Module:
171
184
  return model
172
185
 
173
186
 
174
- def define_and_run(checkpoint_path, test_model=False) -> None:
187
+ def define_and_run(checkpoint_path: str, test_model: bool = False) -> None:
188
+ """Instantiates and runs a TinyLlama model."""
189
+
175
190
  kv_cache_max_len = 1024
176
191
  model = build_model(
177
192
  checkpoint_path, test_model=test_model, kv_cache_max_len=kv_cache_max_len
@@ -186,5 +201,5 @@ def define_and_run(checkpoint_path, test_model=False) -> None:
186
201
 
187
202
 
188
203
  if __name__ == "__main__":
189
- checkpoint_path = os.path.join(Path.home(), "Downloads/tiny_llama")
190
- define_and_run(checkpoint_path)
204
+ input_checkpoint_path = os.path.join(Path.home(), "Downloads/tiny_llama")
205
+ define_and_run(input_checkpoint_path)
ai_edge_torch/generative/examples/gemma/gemma.py CHANGED
@@ -17,14 +17,14 @@
17
17
  import os
18
18
  from pathlib import Path
19
19
 
20
- from ai_edge_torch.generative.layers.attention import TransformerBlock
20
+ from ai_edge_torch.generative.layers import attention
21
+ from ai_edge_torch.generative.layers import builder
21
22
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
22
- import ai_edge_torch.generative.layers.builder as builder
23
23
  import ai_edge_torch.generative.layers.model_config as cfg
24
24
  import ai_edge_torch.generative.utilities.loader as loading_utils
25
25
  import numpy as np
26
26
  import torch
27
- import torch.nn as nn
27
+ from torch import nn
28
28
 
29
29
  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
30
30
  ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -43,6 +43,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
43
43
 
44
44
 
45
45
  class Gemma(nn.Module):
46
+ """A Gemma model built from the Edge Generative API layers."""
46
47
 
47
48
  def __init__(self, config: cfg.ModelConfig):
48
49
  super().__init__()
@@ -60,7 +61,7 @@ class Gemma(nn.Module):
60
61
  # Gemma re-uses the embedding as the head projection layer.
61
62
  self.lm_head.weight.data = self.tok_embedding.weight.data
62
63
  self.transformer_blocks = nn.ModuleList(
63
- TransformerBlock(config) for _ in range(config.num_layers)
64
+ attention.TransformerBlock(config) for _ in range(config.num_layers)
64
65
  )
65
66
  self.final_norm = builder.build_norm(
66
67
  config.embedding_dim,
@@ -88,9 +89,9 @@ class Gemma(nn.Module):
88
89
  # This can be eliminated if we handle k/v cache updates inside the model itself.
89
90
  @torch.inference_mode
90
91
  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
91
- B, T = idx.size()
92
- assert self.config.max_seq_len >= T, (
93
- f"Cannot forward sequence of length {T}, max seq length is only"
92
+ _, seq_len = idx.size()
93
+ assert self.config.max_seq_len >= seq_len, (
94
+ f"Cannot forward sequence of length {seq_len}, max seq length is only"
94
95
  f" {self.config.max_seq_len}"
95
96
  )
96
97
 
@@ -104,7 +105,7 @@ class Gemma(nn.Module):
104
105
  x = self.tok_embedding(idx)
105
106
  x = x * (self.config.embedding_dim**0.5)
106
107
 
107
- for i, block in enumerate(self.transformer_blocks):
108
+ for _, block in enumerate(self.transformer_blocks):
108
109
  x = block(x, (cos, sin), mask, input_pos)
109
110
 
110
111
  x = self.final_norm(x)
@@ -113,6 +114,15 @@ class Gemma(nn.Module):
113
114
 
114
115
 
115
116
  def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
117
+ """Returns the model config for a Gemma 2B model.
118
+
119
+ Args:
120
+ kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
121
+ is 1024.
122
+
123
+ Returns:
124
+ The model config for a Gemma 2B model.
125
+ """
116
126
  attn_config = cfg.AttentionConfig(
117
127
  num_heads=8,
118
128
  head_dim=256,
@@ -147,43 +157,16 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
147
157
  return config
148
158
 
149
159
 
150
- # TODO(b/363021962): Clean up this part to streamline fake model config generation.
151
160
  def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
152
- attn_config = cfg.AttentionConfig(
153
- num_heads=8,
154
- head_dim=256,
155
- num_query_groups=1,
156
- rotary_percentage=1.0,
157
- )
158
- ff_config = cfg.FeedForwardConfig(
159
- type=cfg.FeedForwardType.GATED,
160
- activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
161
- intermediate_size=128,
162
- )
163
- norm_config = cfg.NormalizationConfig(
164
- type=cfg.NormalizationType.RMS_NORM,
165
- epsilon=1e-6,
166
- zero_centered=True,
167
- )
168
- config = cfg.ModelConfig(
169
- vocab_size=128,
170
- num_layers=2,
171
- max_seq_len=2 * kv_cache_max_len,
172
- embedding_dim=2048,
173
- kv_cache_max_len=kv_cache_max_len,
174
- attn_config=attn_config,
175
- ff_config=ff_config,
176
- pre_attention_norm_config=norm_config,
177
- post_attention_norm_config=norm_config,
178
- final_norm_config=norm_config,
179
- parallel_residual=False,
180
- lm_head_use_bias=False,
181
- enable_hlfb=True,
182
- )
161
+ config = get_model_config_2b(kv_cache_max_len)
162
+ config.ff_config.intermediate_size = 128
163
+ config.vocab_size = 128
164
+ config.num_layers = 2
165
+ config.max_seq_len = 2 * kv_cache_max_len
183
166
  return config
184
167
 
185
168
 
186
- def build_2b_model(checkpoint_path, **kwargs) -> nn.Module:
169
+ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
187
170
  config = get_model_config_2b(**kwargs)
188
171
  model = Gemma(config)
189
172
  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
@@ -195,6 +178,8 @@ def build_2b_model(checkpoint_path, **kwargs) -> nn.Module:
195
178
 
196
179
 
197
180
  def define_and_run_2b() -> None:
181
+ """Instantiates and runs a Gemma 2B model."""
182
+
198
183
  current_dir = Path(__file__).parent.resolve()
199
184
  gemma_goldens = torch.load(current_dir / "gemma_lm_logits.pt")
200
185
 
ai_edge_torch/generative/examples/gemma/gemma2.py CHANGED
@@ -18,14 +18,14 @@ import os
18
18
  from pathlib import Path
19
19
  from typing import Optional, Tuple
20
20
 
21
- from ai_edge_torch.generative.layers.attention import TransformerBlock
21
+ from ai_edge_torch.generative.layers import attention
22
+ from ai_edge_torch.generative.layers import builder
22
23
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
23
- import ai_edge_torch.generative.layers.builder as builder
24
24
  import ai_edge_torch.generative.layers.model_config as cfg
25
25
  import ai_edge_torch.generative.utilities.loader as loading_utils
26
26
  import numpy as np
27
27
  import torch
28
- import torch.nn as nn
28
+ from torch import nn
29
29
 
30
30
  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
31
31
  ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -43,7 +43,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
43
43
  )
44
44
 
45
45
 
46
- class Gemma2Block(TransformerBlock):
46
+ class Gemma2Block(attention.TransformerBlock):
47
47
 
48
48
  def forward(
49
49
  self,
@@ -76,6 +76,7 @@ class Gemma2Block(TransformerBlock):
76
76
 
77
77
 
78
78
  class Gemma2(nn.Module):
79
+ """A Gemma2 model built from the Edge Generative API layers."""
79
80
 
80
81
  def __init__(self, config: cfg.ModelConfig):
81
82
  super().__init__()
@@ -138,9 +139,9 @@ class Gemma2(nn.Module):
138
139
 
139
140
  @torch.inference_mode
140
141
  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
141
- B, T = idx.size()
142
- assert self.config.max_seq_len >= T, (
143
- f"Cannot forward sequence of length {T}, max seq length is only"
142
+ _, seq_len = idx.size()
143
+ assert self.config.max_seq_len >= seq_len, (
144
+ f"Cannot forward sequence of length {seq_len}, max seq length is only"
144
145
  f" {self.config.max_seq_len}"
145
146
  )
146
147
 
@@ -166,6 +167,15 @@ class Gemma2(nn.Module):
166
167
 
167
168
 
168
169
  def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
170
+ """Returns the model config for a Gemma2 2B model.
171
+
172
+ Args:
173
+ kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
174
+ is 1024.
175
+
176
+ Returns:
177
+ The model config for a Gemma 2B model.
178
+ """
169
179
  attn_config = cfg.AttentionConfig(
170
180
  num_heads=8,
171
181
  head_dim=256,
@@ -210,50 +220,19 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
210
220
 
211
221
 
212
222
  def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
213
- attn_config = cfg.AttentionConfig(
214
- num_heads=4,
215
- head_dim=64,
216
- num_query_groups=4,
217
- rotary_percentage=1.0,
218
- qkv_transpose_before_split=True,
219
- logit_softcap=50.0,
220
- sliding_window_size=64,
221
- attn_types=[cfg.AttentionType.GLOBAL, cfg.AttentionType.LOCAL_SLIDING]
222
- * 13,
223
- )
224
-
225
- norm_config = cfg.NormalizationConfig(
226
- type=cfg.NormalizationType.RMS_NORM,
227
- epsilon=1e-6,
228
- zero_centered=True,
229
- )
230
- ff_config = cfg.FeedForwardConfig(
231
- type=cfg.FeedForwardType.GATED,
232
- activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
233
- intermediate_size=128,
234
- pre_ff_norm_config=norm_config,
235
- post_ff_norm_config=norm_config,
236
- )
237
- config = cfg.ModelConfig(
238
- vocab_size=128,
239
- num_layers=2,
240
- max_seq_len=2 * kv_cache_max_len,
241
- embedding_dim=128,
242
- kv_cache_max_len=kv_cache_max_len,
243
- attn_config=attn_config,
244
- ff_config=ff_config,
245
- pre_attention_norm_config=norm_config,
246
- post_attention_norm_config=norm_config,
247
- final_norm_config=norm_config,
248
- parallel_residual=False,
249
- lm_head_use_bias=False,
250
- enable_hlfb=True,
251
- final_logit_softcap=30.0,
252
- )
223
+ config = get_model_config_2b(kv_cache_max_len)
224
+ config.attn_config.num_heads = 4
225
+ config.attn_config.head_dim = 64
226
+ config.attn_config.sliding_window_size = 64
227
+ config.ff_config.intermediate_size = 128
228
+ config.vocab_size = 128
229
+ config.num_layers = 2
230
+ config.max_seq_len = 2 * kv_cache_max_len
231
+ config.embedding_dim = 128
253
232
  return config
254
233
 
255
234
 
256
- def build_2b_model(checkpoint_path, **kwargs) -> nn.Module:
235
+ def build_2b_model(checkpoint_path: str, **kwargs) -> nn.Module:
257
236
  config = get_model_config_2b(**kwargs)
258
237
  model = Gemma2(config)
259
238
  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
@@ -265,6 +244,8 @@ def build_2b_model(checkpoint_path, **kwargs) -> nn.Module:
265
244
 
266
245
 
267
246
  def define_and_run_2b() -> None:
247
+ """Instantiates and runs a Gemma2 2B model."""
248
+
268
249
  current_dir = Path(__file__).parent.resolve()
269
250
  gemma2_goldens = torch.load(current_dir / "gemma2it_2b_golden.pt")
270
251
  print("Running GEMMA 2")
ai_edge_torch/generative/examples/phi2/phi2.py CHANGED
@@ -18,14 +18,14 @@
18
18
  import os
19
19
  from pathlib import Path
20
20
 
21
- from ai_edge_torch.generative.layers.attention import TransformerBlock
21
+ from ai_edge_torch.generative.layers import attention
22
+ from ai_edge_torch.generative.layers import builder
22
23
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
23
- import ai_edge_torch.generative.layers.builder as builder
24
24
  import ai_edge_torch.generative.layers.model_config as cfg
25
25
  import ai_edge_torch.generative.utilities.loader as loading_utils
26
26
  import numpy as np
27
27
  import torch
28
- import torch.nn as nn
28
+ from torch import nn
29
29
 
30
30
  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
31
31
  ff_up_proj="model.layers.{}.mlp.fc1",
@@ -42,6 +42,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
42
42
 
43
43
 
44
44
  class Phi2(nn.Module):
45
+ """A Phi-2 model built from the Edge Generative API layers."""
45
46
 
46
47
  def __init__(self, config: cfg.ModelConfig):
47
48
  super().__init__()
@@ -55,7 +56,7 @@ class Phi2(nn.Module):
55
56
  config.vocab_size, config.embedding_dim, padding_idx=0
56
57
  )
57
58
  self.transformer_blocks = nn.ModuleList(
58
- TransformerBlock(config) for _ in range(config.num_layers)
59
+ attention.TransformerBlock(config) for _ in range(config.num_layers)
59
60
  )
60
61
  self.final_norm = builder.build_norm(
61
62
  config.embedding_dim,
@@ -83,9 +84,9 @@ class Phi2(nn.Module):
83
84
  # This can be eliminated if we handle k/v cache updates inside the model itself.
84
85
  @torch.inference_mode
85
86
  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
86
- B, T = idx.size()
87
- assert self.config.max_seq_len >= T, (
88
- f"Cannot forward sequence of length {T}, max seq length is only"
87
+ _, seq_len = idx.size()
88
+ assert self.config.max_seq_len >= seq_len, (
89
+ f"Cannot forward sequence of length {seq_len}, max seq length is only"
89
90
  f" {self.config.max_seq_len}"
90
91
  )
91
92
 
@@ -98,7 +99,7 @@ class Phi2(nn.Module):
98
99
  # forward the model itself
99
100
  x = self.tok_embedding(idx) # token embeddings of shape (b, t, n_embd)
100
101
 
101
- for i, block in enumerate(self.transformer_blocks):
102
+ for _, block in enumerate(self.transformer_blocks):
102
103
  x = block(x, (cos, sin), mask, input_pos)
103
104
 
104
105
  x = self.final_norm(x)
@@ -107,6 +108,15 @@ class Phi2(nn.Module):
107
108
 
108
109
 
109
110
  def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
111
+ """Returns the model config for a Phi-2 model.
112
+
113
+ Args:
114
+ kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
115
+ is 1024.
116
+
117
+ Returns:
118
+ The model config for a Phi-2 model.
119
+ """
110
120
  attn_config = cfg.AttentionConfig(
111
121
  num_heads=32,
112
122
  head_dim=80,
@@ -140,35 +150,11 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
140
150
 
141
151
 
142
152
  def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
143
- attn_config = cfg.AttentionConfig(
144
- num_heads=16,
145
- head_dim=80,
146
- num_query_groups=4,
147
- rotary_percentage=0.4,
148
- qkv_use_bias=True,
149
- output_proj_use_bias=True,
150
- )
151
- ff_config = cfg.FeedForwardConfig(
152
- type=cfg.FeedForwardType.SEQUENTIAL,
153
- activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
154
- intermediate_size=128,
155
- use_bias=True,
156
- )
157
- norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.LAYER_NORM)
158
- config = cfg.ModelConfig(
159
- vocab_size=128,
160
- num_layers=2,
161
- max_seq_len=2 * kv_cache_max_len,
162
- kv_cache_max_len=kv_cache_max_len,
163
- embedding_dim=128,
164
- attn_config=attn_config,
165
- ff_config=ff_config,
166
- pre_attention_norm_config=norm_config,
167
- final_norm_config=norm_config,
168
- parallel_residual=True,
169
- lm_head_use_bias=True,
170
- enable_hlfb=True,
171
- )
153
+ config = get_model_config(kv_cache_max_len)
154
+ config.vocab_size = 128
155
+ config.num_layers = 2
156
+ config.max_seq_len = 2 * kv_cache_max_len
157
+ config.ff_config.intermediate_size = 128
172
158
  return config
173
159
 
174
160
 
@@ -181,6 +167,8 @@ def build_model(checkpoint_path, **kwargs) -> nn.Module:
181
167
 
182
168
 
183
169
  def define_and_run() -> None:
170
+ """Instantiates and runs a Phi-2 model."""
171
+
184
172
  current_dir = Path(__file__).parent.resolve()
185
173
  phi2_goldens = torch.load(current_dir / "phi2_lm_logits.pt")
186
174
  kv_cache_max_len = 1024
ai_edge_torch/generative/examples/test_models/toy_model.py CHANGED
@@ -71,6 +71,56 @@ class ToySingleLayerModel(torch.nn.Module):
71
71
  return self.lm_head(x)
72
72
 
73
73
 
74
+ class ToySingleLayerModelWeightSharing(torch.nn.Module):
75
+
76
+ def __init__(self, config: cfg.ModelConfig) -> None:
77
+ super().__init__()
78
+ self.lm_head = nn.Linear(
79
+ config.embedding_dim, config.vocab_size, bias=config.lm_head_use_bias
80
+ )
81
+ self.tok_embedding = nn.Embedding(config.vocab_size, config.embedding_dim)
82
+ self.lm_head = nn.Linear(
83
+ config.embedding_dim,
84
+ config.vocab_size,
85
+ bias=config.lm_head_use_bias,
86
+ )
87
+ self.lm_head.weight.data = self.tok_embedding.weight.data
88
+ self.transformer_block = TransformerBlock(config)
89
+ self.final_norm = builder.build_norm(
90
+ config.embedding_dim,
91
+ config.final_norm_config,
92
+ )
93
+ self.rope_cache = attn_utils.build_rope_cache(
94
+ size=config.max_seq_len,
95
+ dim=int(
96
+ config.attn_config.rotary_percentage * config.attn_config.head_dim
97
+ ),
98
+ base=10_000,
99
+ condense_ratio=1,
100
+ dtype=torch.float32,
101
+ device=torch.device('cpu'),
102
+ )
103
+ self.mask_cache = attn_utils.build_causal_mask_cache(
104
+ size=config.max_seq_len, dtype=torch.float32, device=torch.device('cpu')
105
+ )
106
+ self.config = config
107
+
108
+ @torch.inference_mode
109
+ def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
110
+ x = self.tok_embedding(idx)
111
+ cos, sin = self.rope_cache
112
+
113
+ cos = cos.index_select(0, input_pos)
114
+ sin = sin.index_select(0, input_pos)
115
+ mask = self.mask_cache.index_select(2, input_pos)
116
+ mask = mask[:, :, :, : self.config.max_seq_len]
117
+
118
+ x = self.transformer_block(x, (cos, sin), mask, input_pos)
119
+ x = self.final_norm(x)
120
+ res = self.lm_head(x)
121
+ return res
122
+
123
+
74
124
  def get_model_config() -> cfg.ModelConfig:
75
125
  attn_config = cfg.AttentionConfig(
76
126
  num_heads=32,
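The only addition to `toy_model.py` is the `ToySingleLayerModelWeightSharing` variant, which ties the `lm_head` weights to the token embedding so that weight-sharing paths can be exercised during conversion and quantization. A short usage sketch, mirroring the shapes used by the new test in `test_quantize.py` further down:

```python
# Hedged sketch of exercising the new weight-sharing toy model; the module
# path and shapes come from the diff, the assertion is illustrative.
import torch

from ai_edge_torch.generative.examples.test_models import toy_model

config = toy_model.get_model_config()
model = toy_model.ToySingleLayerModelWeightSharing(config).eval()

idx = torch.unsqueeze(torch.arange(0, 100), 0)  # (1, 100) token ids
input_pos = torch.arange(0, 100)                # positions 0..99
logits = model(idx, input_pos)                  # (1, 100, vocab_size)

# The head reuses the embedding weights, so both parameters should point at
# the same underlying storage.
assert model.lm_head.weight.data_ptr() == model.tok_embedding.weight.data_ptr()
```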
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py CHANGED
@@ -17,14 +17,14 @@
17
17
  import os
18
18
  from pathlib import Path
19
19
 
20
- from ai_edge_torch.generative.layers.attention import TransformerBlock
20
+ from ai_edge_torch.generative.layers import attention
21
+ from ai_edge_torch.generative.layers import builder
21
22
  import ai_edge_torch.generative.layers.attention_utils as attn_utils
22
- import ai_edge_torch.generative.layers.builder as builder
23
23
  import ai_edge_torch.generative.layers.model_config as cfg
24
24
  import ai_edge_torch.generative.utilities.loader as loading_utils
25
25
  import numpy as np
26
26
  import torch
27
- import torch.nn as nn
27
+ from torch import nn
28
28
 
29
29
  TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
30
30
  ff_up_proj="model.layers.{}.mlp.up_proj",
@@ -43,6 +43,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
43
43
 
44
44
 
45
45
  class TinyLLamma(nn.Module):
46
+ """A TinyLlama model built from the Edge Generative API layers."""
46
47
 
47
48
  def __init__(self, config: cfg.ModelConfig):
48
49
  super().__init__()
@@ -56,7 +57,7 @@ class TinyLLamma(nn.Module):
56
57
  config.vocab_size, config.embedding_dim, padding_idx=0
57
58
  )
58
59
  self.transformer_blocks = nn.ModuleList(
59
- TransformerBlock(config) for _ in range(config.num_layers)
60
+ attention.TransformerBlock(config) for _ in range(config.num_layers)
60
61
  )
61
62
  self.final_norm = builder.build_norm(
62
63
  config.embedding_dim,
@@ -84,9 +85,9 @@ class TinyLLamma(nn.Module):
84
85
  # This can be eliminated if we handle k/v cache updates inside the model itself.
85
86
  @torch.inference_mode
86
87
  def forward(self, idx: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
87
- B, T = idx.size()
88
- assert self.config.max_seq_len >= T, (
89
- f"Cannot forward sequence of length {T}, max seq length is only"
88
+ _, seq_len = idx.size()
89
+ assert self.config.max_seq_len >= seq_len, (
90
+ f"Cannot forward sequence of length {seq_len}, max seq length is only"
90
91
  f" {self.config.max_seq_len}"
91
92
  )
92
93
 
@@ -99,7 +100,7 @@ class TinyLLamma(nn.Module):
99
100
  # forward the model itself
100
101
  x = self.tok_embedding(idx) # token embeddings of shape (b, t, n_embd)
101
102
 
102
- for i, block in enumerate(self.transformer_blocks):
103
+ for _, block in enumerate(self.transformer_blocks):
103
104
  x = block(x, (cos, sin), mask, input_pos)
104
105
 
105
106
  x = self.final_norm(x)
@@ -109,6 +110,15 @@ class TinyLLamma(nn.Module):
109
110
 
110
111
 
111
112
  def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
113
+ """Returns the model config for a TinyLlama model.
114
+
115
+ Args:
116
+ kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
117
+ is 1024.
118
+
119
+ Returns:
120
+ The model config for a TinyLlama model.
121
+ """
112
122
  attn_config = cfg.AttentionConfig(
113
123
  num_heads=32,
114
124
  head_dim=64,
@@ -145,7 +155,7 @@ def get_fake_model_config() -> cfg.ModelConfig:
145
155
  return config
146
156
 
147
157
 
148
- def build_model(checkpoint_path, **kwargs) -> nn.Module:
158
+ def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
149
159
  config = get_model_config(**kwargs)
150
160
  model = TinyLLamma(config)
151
161
  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
@@ -154,6 +164,8 @@ def build_model(checkpoint_path, **kwargs) -> nn.Module:
154
164
 
155
165
 
156
166
  def define_and_run() -> None:
167
+ """Instantiates and runs a TinyLlama model."""
168
+
157
169
  current_dir = Path(__file__).parent.resolve()
158
170
  tiny_llama_goldens = torch.load(current_dir / "tiny_llama_lm_logits.pt")
159
171
  kv_cache_max_len = 1024
ai_edge_torch/generative/test/test_model_conversion.py CHANGED
@@ -70,35 +70,6 @@ class TestModelConversion(googletest.TestCase):
70
70
  )
71
71
  )
72
72
 
73
- @googletest.skipIf(
74
- ai_edge_config.Config.use_torch_xla,
75
- reason="tests with custom ops are not supported on oss",
76
- )
77
- def test_toy_model_with_multi_batches(self):
78
- self.skipTest("b/362842043")
79
- config = toy_model_with_kv_cache.get_model_config()
80
- config.batch_size = 2
81
- pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config).eval()
82
- idx, input_pos = torch.tensor([[1], [2]], dtype=torch.long), torch.tensor(
83
- [10], dtype=torch.int64
84
- )
85
-
86
- edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
87
- edge_model.set_interpreter_builder(
88
- self._interpreter_builder(edge_model.tflite_model())
89
- )
90
-
91
- self.assertTrue(
92
- model_coverage.compare_tflite_torch(
93
- edge_model,
94
- pytorch_model,
95
- (idx, input_pos),
96
- num_valid_inputs=1,
97
- atol=1e-5,
98
- rtol=1e-5,
99
- )
100
- )
101
-
102
73
  @googletest.skipIf(
103
74
  ai_edge_config.Config.use_torch_xla,
104
75
  reason="tests with custom ops are not supported on oss",
ai_edge_torch/generative/test/test_quantize.py CHANGED
@@ -25,16 +25,16 @@ from ai_edge_torch.generative.quantize.quant_attrs import Granularity
25
25
  from ai_edge_torch.generative.quantize.quant_attrs import Mode
26
26
  from ai_edge_torch.quantize import quant_config
27
27
  from ai_edge_torch.testing import model_coverage
28
- from parameterized import parameterized
29
28
  import torch
30
29
 
31
30
  from absl.testing import absltest as googletest
31
+ from absl.testing import parameterized
32
32
 
33
33
 
34
- class TestVerifyRecipes(googletest.TestCase):
34
+ class TestVerifyRecipes(parameterized.TestCase):
35
35
  """Unit tests that check for model quantization recipes."""
36
36
 
37
- @parameterized.expand([
37
+ @parameterized.parameters([
38
38
  (Dtype.FP32, Dtype.FP32),
39
39
  (Dtype.INT8, Dtype.INT8),
40
40
  (Dtype.INT8, Dtype.FP16),
@@ -52,7 +52,7 @@ class TestVerifyRecipes(googletest.TestCase):
52
52
  with self.assertRaises(ValueError):
53
53
  quant_recipe.LayerQuantRecipe(activation, weight, m, a, g).verify()
54
54
 
55
- @parameterized.expand([
55
+ @parameterized.parameters([
56
56
  (
57
57
  Dtype.FP32,
58
58
  Dtype.INT8,
@@ -88,7 +88,7 @@ class TestVerifyRecipes(googletest.TestCase):
88
88
  ).verify()
89
89
 
90
90
 
91
- class TestQuantizeConvert(googletest.TestCase):
91
+ class TestQuantizeConvert(parameterized.TestCase):
92
92
  """Test conversion with quantization."""
93
93
 
94
94
  def _attention_int8_dynamic_recipe() -> quant_config.QuantConfig:
@@ -105,17 +105,13 @@ class TestQuantizeConvert(googletest.TestCase):
105
105
  )
106
106
  )
107
107
 
108
- @parameterized.expand([
108
+ @parameterized.parameters([
109
109
  (quant_recipes.full_fp16_recipe()),
110
110
  (quant_recipes.full_int8_dynamic_recipe()),
111
111
  (quant_recipes.full_int8_weight_only_recipe()),
112
112
  (_attention_int8_dynamic_recipe()),
113
113
  (_feedforward_int8_dynamic_recipe()),
114
114
  ])
115
- @googletest.skipIf(
116
- not config.Config.use_torch_xla,
117
- reason="Not working with odml_torch at the moment.",
118
- )
119
115
  def test_quantize_convert_toy_sizes(self, quant_config):
120
116
  config = toy_model.get_model_config()
121
117
  pytorch_model = toy_model.ToySingleLayerModel(config)
@@ -132,6 +128,23 @@ class TestQuantizeConvert(googletest.TestCase):
132
128
  "Quantized model isn't smaller than F32 model.",
133
129
  )
134
130
 
131
+ def test_quantize_convert_toy_weight_sharing(self):
132
+ config = toy_model.get_model_config()
133
+ pytorch_model = toy_model.ToySingleLayerModelWeightSharing(config)
134
+ idx = torch.unsqueeze(torch.arange(0, 100), 0)
135
+ input_pos = torch.arange(0, 100)
136
+
137
+ quant_config = quant_recipes.full_int8_dynamic_recipe()
138
+ quantized_model = ai_edge_torch.convert(
139
+ pytorch_model, (idx, input_pos), quant_config=quant_config
140
+ )
141
+ float_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
142
+ self.assertLess(
143
+ len(quantized_model._tflite_model),
144
+ len(float_model._tflite_model),
145
+ "Quantized model isn't smaller than F32 model.",
146
+ )
147
+
135
148
  def test_quantize_convert_compare_toy(self):
136
149
  self.skipTest("b/338288901")
137
150
  config = toy_model_with_kv_cache.get_model_config()
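The test changes above swap the third-party `parameterized` package for absl's `parameterized` module, which requires the test class to inherit from `parameterized.TestCase` and uses `@parameterized.parameters` instead of `@parameterized.expand`. A standalone illustration of the new pattern (the class and data below are hypothetical, not from this repo):

```python
# Illustration of the absl parameterized pattern adopted in test_quantize.py.
from absl.testing import absltest
from absl.testing import parameterized


class DtypePairTest(parameterized.TestCase):  # hypothetical test class

  @parameterized.parameters(
      ("FP32", "FP32"),
      ("INT8", "FP16"),
  )
  def test_pair(self, activation, weight):  # each tuple is unpacked into args
    self.assertIsInstance(activation, str)
    self.assertIsInstance(weight, str)


if __name__ == "__main__":
  absltest.main()
```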
ai_edge_torch/generative/utilities/loader.py CHANGED
@@ -208,7 +208,7 @@ class ModelLoader:
208
208
  if self._file_name.endswith(".safetensors"):
209
209
  return load_safetensors
210
210
 
211
- if self._file_name.endswith(".bin") or self._file_name.endswith(".pt"):
211
+ if self._file_name.endswith(".bin") or self._file_name.endswith("pt"):
212
212
  return load_pytorch_statedict
213
213
 
214
214
  raise ValueError("File format not supported.")
ai_edge_torch/lowertools/odml_torch_utils.py CHANGED
@@ -21,6 +21,7 @@ from ai_edge_torch import odml_torch
21
21
  from ai_edge_torch._convert import conversion_utils
22
22
  from ai_edge_torch._convert import signature as signature_module
23
23
  from ai_edge_torch.lowertools import common_utils
24
+ from ai_edge_torch.lowertools import translate_recipe
24
25
  from ai_edge_torch.odml_torch import export
25
26
  from ai_edge_torch.odml_torch import export_utils
26
27
  from ai_edge_torch.quantize import quant_config as qcfg
@@ -186,10 +187,29 @@ def merged_bundle_to_tfl_model(
186
187
  converter._experimental_enable_composite_direct_lowering = True
187
188
  converter.model_origin_framework = "PYTORCH"
188
189
 
190
+ conversion_utils.set_tfl_converter_quant_flags(converter, quant_config)
191
+ if (
192
+ quant_config is not None
193
+ and quant_config._quantizer_mode
194
+ == quant_config._QuantizerMode.AI_EDGE_QUANTIZER
195
+ ):
196
+ translated_recipe = translate_recipe.translate_to_ai_edge_recipe(
197
+ quant_config.generative_recipe
198
+ )
199
+
189
200
  conversion_utils.apply_tfl_converter_flags(converter, _tfl_converter_flags)
190
201
 
191
202
  tflite_model = converter.convert()
192
203
 
204
+ if (
205
+ quant_config is not None
206
+ and quant_config._quantizer_mode
207
+ == quant_config._QuantizerMode.AI_EDGE_QUANTIZER
208
+ ):
209
+ tflite_model = translate_recipe.quantize_model(
210
+ tflite_model, translated_recipe
211
+ )
212
+
193
213
  return tflite_model
194
214
 
195
215
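With this change the odml_torch lowering path gains the same AI Edge Quantizer hook the torch_xla path already had: when a `quant_config` carries a generative recipe, it is translated via `translate_recipe` and applied to the converted flatbuffer. A hedged end-to-end sketch of what that enables, reusing the toy model and recipe helpers that appear elsewhere in this diff (the output path is illustrative):

```python
# Hedged sketch: convert a toy model with a generative int8 dynamic-range
# recipe; with the odml_torch backend the recipe is now applied after
# converter.convert(), as shown in the hunk above.
import torch

import ai_edge_torch
from ai_edge_torch.generative.examples.test_models import toy_model
from ai_edge_torch.generative.quantize import quant_recipes

config = toy_model.get_model_config()
pytorch_model = toy_model.ToySingleLayerModel(config).eval()
idx = torch.unsqueeze(torch.arange(0, 100), 0)
input_pos = torch.arange(0, 100)

quant_config = quant_recipes.full_int8_dynamic_recipe()
edge_model = ai_edge_torch.convert(
    pytorch_model, (idx, input_pos), quant_config=quant_config
)
edge_model.export("/tmp/toy_int8.tflite")  # illustrative path
```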
 
ai_edge_torch/lowertools/torch_xla_utils.py CHANGED
@@ -25,8 +25,8 @@ from typing import Any, Dict, Optional, Tuple, Union
25
25
  from ai_edge_torch import model
26
26
  from ai_edge_torch._convert import conversion_utils
27
27
  from ai_edge_torch._convert import signature as signature_module
28
- from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe # NOQA
29
28
  from ai_edge_torch.lowertools import common_utils
29
+ from ai_edge_torch.lowertools import translate_recipe
30
30
  from ai_edge_torch.quantize import quant_config as qcfg
31
31
  import torch
32
32
  from torch_xla import stablehlo
ai_edge_torch/{generative/quantize/ai_edge_quantizer_glue → lowertools}/translate_recipe.py RENAMED
@@ -17,7 +17,8 @@ from ai_edge_quantizer import quantizer
17
17
  from ai_edge_torch.generative.quantize import quant_attrs
18
18
  from ai_edge_torch.generative.quantize import quant_recipe
19
19
 
20
- _OpExecutionMode = quantizer.qtyping.OpExecutionMode
20
+ _ComputePrecision = quantizer.qtyping.ComputePrecision
21
+ _QuantGranularity = quantizer.qtyping.QuantGranularity
21
22
  _OpName = quantizer.qtyping.TFLOperationName
22
23
  _TensorQuantConfig = quantizer.qtyping.TensorQuantizationConfig
23
24
  _OpQuantConfig = quantizer.qtyping.OpQuantizationConfig
@@ -50,21 +51,31 @@ def _get_dtype_from_dtype(
50
51
  return quantizer.qtyping.TensorDataType.INT
51
52
 
52
53
 
53
- def _get_execution_mode_from_mode(mode: quant_attrs.Mode) -> _OpExecutionMode:
54
+ def _get_compute_precision_from_mode(
55
+ mode: quant_attrs.Mode,
56
+ ) -> _ComputePrecision:
54
57
  if mode == quant_attrs.Mode.DYNAMIC_RANGE:
55
- return _OpExecutionMode.DRQ
58
+ return _ComputePrecision.INTEGER
56
59
  elif mode == quant_attrs.Mode.WEIGHT_ONLY:
57
- return _OpExecutionMode.WEIGHT_ONLY
60
+ return _ComputePrecision.FLOAT
58
61
  raise ValueError('Unimplemented execution mode')
59
62
 
60
63
 
61
- def _get_channelwise_from_granularity(
64
+ def _get_explicit_dequant_from_mode(mode: quant_attrs.Mode) -> bool:
65
+ if mode == quant_attrs.Mode.DYNAMIC_RANGE:
66
+ return False
67
+ elif mode == quant_attrs.Mode.WEIGHT_ONLY:
68
+ return True
69
+ raise ValueError('Unimplemented execution mode')
70
+
71
+
72
+ def _get_granularity(
62
73
  granularity: quant_attrs.Granularity,
63
74
  ) -> bool:
64
75
  if granularity == quant_attrs.Granularity.CHANNELWISE:
65
- return True
66
- elif granularity == quant_attrs.Granularity.NONE:
67
- return False
76
+ return _QuantGranularity.CHANNELWISE
77
+ if granularity == quant_attrs.Granularity.NONE:
78
+ return _QuantGranularity.TENSORWISE
68
79
  raise ValueError('Unimplemented granularity')
69
80
 
70
81
 
@@ -88,12 +99,13 @@ def _set_quant_config(
88
99
  weight_tensor_config=_TensorQuantConfig(
89
100
  num_bits=_get_nbits_from_dtype(layer_recipe.weight_dtype),
90
101
  symmetric=True,
91
- channel_wise=_get_channelwise_from_granularity(
92
- layer_recipe.granularity
93
- ),
102
+ granularity=_get_granularity(layer_recipe.granularity),
94
103
  dtype=_get_dtype_from_dtype(layer_recipe.weight_dtype),
95
104
  ),
96
- execution_mode=_get_execution_mode_from_mode(layer_recipe.mode),
105
+ compute_precision=_get_compute_precision_from_mode(layer_recipe.mode),
106
+ explicit_dequantize=_get_explicit_dequant_from_mode(
107
+ layer_recipe.mode
108
+ ),
97
109
  ),
98
110
  algorithm_key=_get_algorithm_key_from_algorithm(layer_recipe.algorithm),
99
111
  )
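Besides moving to `ai_edge_torch/lowertools/`, `translate_recipe.py` tracks an ai-edge-quantizer API change: `OpExecutionMode` is replaced by a `ComputePrecision` plus `explicit_dequantize` pair (DYNAMIC_RANGE → INTEGER/False, WEIGHT_ONLY → FLOAT/True), and the boolean `channel_wise` flag becomes a `QuantGranularity` enum (CHANNELWISE or TENSORWISE). A small sketch of calling the relocated module directly, using only entry points visible in this diff:

```python
# Hedged sketch: translate a generative quantization recipe with the relocated
# module. Both helpers appear in this diff; the exact shape of the returned
# recipe object is not asserted here.
from ai_edge_torch.generative.quantize import quant_recipes
from ai_edge_torch.lowertools import translate_recipe

quant_config = quant_recipes.full_int8_dynamic_recipe()
ai_edge_recipe = translate_recipe.translate_to_ai_edge_recipe(
    quant_config.generative_recipe
)
# Under the new qtyping API, the resulting per-op configs carry
# compute_precision=INTEGER with explicit_dequantize=False (dynamic range)
# and granularity=CHANNELWISE, instead of the old execution_mode and
# channel_wise fields.
print(ai_edge_recipe)
```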
ai_edge_torch/odml_torch/export.py CHANGED
@@ -277,7 +277,7 @@ def exported_program_to_mlir(
277
277
  main_func.attributes["sym_visibility"] = ir.StringAttr.get("public")
278
278
  temp_func.erase()
279
279
 
280
- module.operation.verify()
280
+ module.operation.verify()
281
281
 
282
282
  input_signature = []
283
283
  state_dict = {}
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20240902"
16
+ __version__ = "0.3.0.dev20240905"
{ai_edge_torch_nightly-0.3.0.dev20240902.dist-info → ai_edge_torch_nightly-0.3.0.dev20240905.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20240902
3
+ Version: 0.3.0.dev20240905
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -30,7 +30,7 @@ Requires-Dist: tabulate
30
30
  Requires-Dist: torch>=2.4.0
31
31
  Requires-Dist: torch-xla>=2.4.0
32
32
  Requires-Dist: tf-nightly>=2.18.0.dev20240722
33
- Requires-Dist: ai-edge-quantizer-nightly==0.0.1.dev20240718
33
+ Requires-Dist: ai-edge-quantizer-nightly
34
34
 
35
35
  Library that supports converting PyTorch models into a .tflite format, which can
36
36
  then be run with TensorFlow Lite and MediaPipe. This enables applications for
{ai_edge_torch_nightly-0.3.0.dev20240902.dist-info → ai_edge_torch_nightly-0.3.0.dev20240905.dist-info}/RECORD RENAMED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
2
2
  ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
5
- ai_edge_torch/version.py,sha256=pl_weDdkMIjqukMxBF4uho_z-MvFlGy_ButOq6tJwVc,706
5
+ ai_edge_torch/version.py,sha256=-vQGdl2EaV-VpHRty3RwZzH0UVntVt1tmjhtKOIDscw,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -42,21 +42,21 @@ ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQe
42
42
  ai_edge_torch/generative/examples/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
43
43
  ai_edge_torch/generative/examples/experimental/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
44
44
  ai_edge_torch/generative/examples/experimental/gemma/convert_to_tflite.py,sha256=lpiPFSh3SJd6WwuZ0QegSva3__iSz2tUD7L7QfkAe4I,3085
45
- ai_edge_torch/generative/examples/experimental/gemma/gemma.py,sha256=EdElPCDLYxnNvkPMJkE3WKvESze1ehgShEk2NnbrXLg,7527
45
+ ai_edge_torch/generative/examples/experimental/gemma/gemma.py,sha256=aCoD86pf4nuquUMk7MOR-jsN5FqvySSEuMx9Psxjblk,7261
46
46
  ai_edge_torch/generative/examples/experimental/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
47
47
  ai_edge_torch/generative/examples/experimental/phi/convert_to_tflite.py,sha256=DavrdGmqUgoThsGNRv3LXMW5tvJdYEvj66Hf1XRqkXU,3055
48
- ai_edge_torch/generative/examples/experimental/phi/phi2.py,sha256=u-VJX5mjzQKspXtAhNi53LCITtag-3nCaRTKdk5Z1sc,6231
48
+ ai_edge_torch/generative/examples/experimental/phi/phi2.py,sha256=Jxf3ZyYDpS78l6uh4_LGGIcHawrOhZ1vHoHFVxRaK40,6789
49
49
  ai_edge_torch/generative/examples/experimental/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
50
50
  ai_edge_torch/generative/examples/experimental/tiny_llama/convert_to_tflite.py,sha256=xPVvHQjLJHFiRv_-Fy2sDm0Aft7SG8SXiV6o3rF03cQ,3108
51
- ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py,sha256=zQYtyk3xYdiRAnzMKN58Q_wgTQFnDujxp6L4RFQjiD4,6383
51
+ ai_edge_torch/generative/examples/experimental/tiny_llama/tiny_llama.py,sha256=nUm0SQbCTmNAc5u-C9gbQRFPt7GDvUt6UjH6doTvH-I,6817
52
52
  ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
53
53
  ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=pseJExH35lSAK0ZtzSHB1sFtRtF_EuT2xcSpGU0gKVI,2524
54
54
  ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=w589IJETATd6Z9_1XCIWbrlCV3E92X_5ac3VVCVFXG0,2522
55
- ai_edge_torch/generative/examples/gemma/gemma.py,sha256=pzD9dYUYg8E6fFACh-8B8G9NHFXOVEWBjf5aDeipU2s,7202
56
- ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=ypd6uBb4FgDpuWm_w8JNYBAf4eFxWbYccs8vCgBhi-I,9374
55
+ ai_edge_torch/generative/examples/gemma/gemma.py,sha256=lc1-CfIObHj9D5VJy78BOtGTrQM4TYMI6NfVi8KM5qA,6747
56
+ ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=OcUQLFR136e3QRVXRnmtYnRHXyHJS9EYEFlJ1ymXyRY,8859
57
57
  ai_edge_torch/generative/examples/phi2/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
58
58
  ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=ON6zLO-nFS8eJ2yhyWzT5x2Somr-Ca-VjpjT7OGFU10,2506
59
- ai_edge_torch/generative/examples/phi2/phi2.py,sha256=91mWxEtKgDtUhCAewWNwH_UOOCzy6tPdf6LNRlxZhrc,6700
59
+ ai_edge_torch/generative/examples/phi2/phi2.py,sha256=FFnhv1kx4fHRhSeOreLGj8kAqPnmkz9pD1RRSDVlM_w,6332
60
60
  ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
61
61
  ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
62
62
  ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=0WniBWQ6_NcQc5WycX3YRRX7Os9AGQSxfc1m2HKBqg8,4479
@@ -77,12 +77,12 @@ ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=CZVuNEL8OHPkdsz
77
77
  ai_edge_torch/generative/examples/t5/t5.py,sha256=Zobw5BV-PC0nlU9Z6fzb2O07rMeU8vGIk-KtKp9D_H0,20871
78
78
  ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=1lvbSlzyBwmd5Bs7-Up_v4iJQkCPIJx2RmMkLgy7l2Q,8508
79
79
  ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
80
- ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=LfWO_gSr1f66V1pxAc6yh21mtaJs7TVeuO9748zXBnE,3963
80
+ ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=5wj2RmQRIwD6O_R_pp-A_7gKGSdHWDSXyis97r1ELVI,5622
81
81
  ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=l9swUKTcDtnTibNSNExaMgLvDeJ4Er2tVh5ZW1EtRgk,5809
82
82
  ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=mQkcpSe6HlRLMkIRCEHc9ZXL7jxEp9RWSGUQjjd-r2w,4841
83
83
  ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
84
84
  ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=CLRqO7ycMbpy7J3_Czp1sLx6hcdwGD9zVq04yRba0e8,2550
85
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=JmwU1sniO37vnCFc8dklbd-0ofTZK0PaBv_Ksn1Vq6M,5930
85
+ ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=4ku0ni3MOWamhPrzLap0BmtdNFk7CH0hwjPNoRAKpvQ,6278
86
86
  ai_edge_torch/generative/fx_passes/__init__.py,sha256=fmNNXawJ722M4cTUuTx289rT0NHxBEsOy_k8baqCOms,1173
87
87
  ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=sXis0U4u-RoIp_NyrmWJNnqFqpqRuZOrhfsJIO6rMps,2028
88
88
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -106,16 +106,14 @@ ai_edge_torch/generative/quantize/quant_recipe.py,sha256=tKnuJq6hPD23JPCB9nPAlE1
106
106
  ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=4fgmP_GgeiFUOkIaC9ZZXC12eO3DQZdrWDXRz5YXiwU,2270
107
107
  ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FBxkHM6RJ3C14B2I1mjItjc,2030
108
108
  ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
109
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
110
- ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/translate_recipe.py,sha256=sSHc_4hUEvi-3KmqbpqWbrRKBjCI1AOctM3dr2EH3vk,5263
111
109
  ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
112
110
  ai_edge_torch/generative/test/test_experimental_ekv.py,sha256=8qv_eVtJW9GPvBEf2hPQe3tpdJ33XShya6MCX1FqrZM,4355
113
111
  ai_edge_torch/generative/test/test_loader.py,sha256=_y5EHGgoNOmCuYonsB81UJScHVsTAQXUVd44czMAw6k,3379
114
- ai_edge_torch/generative/test/test_model_conversion.py,sha256=wQLVjMnKHBCVCU_I-xAUZvlOFoDiwYwKQDvCZ2mjtOM,6193
112
+ ai_edge_torch/generative/test/test_model_conversion.py,sha256=KZ0uCeOdKMKyW8jBE8aOjweZmws4mvz37u8zH4XayVU,5285
115
113
  ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=o3l7HFHP-sg8aHeLNTSpMF91YovPODjp4QzYUnSJiIE,4479
116
- ai_edge_torch/generative/test/test_quantize.py,sha256=JEsk9SAkHK0SFm44K_quISc5yBBS6yvtBP1MDyFHdFw,5344
114
+ ai_edge_torch/generative/test/test_quantize.py,sha256=kY_NRpF-v1i4clqI1CFFWEagJv-5PzBDkeJ2fInl9_w,5913
117
115
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
118
- ai_edge_torch/generative/utilities/loader.py,sha256=QFZ2lkeoYQ9MZ1CAFVxBHG4OT192SH74UtJCvbDsdeI,12727
116
+ ai_edge_torch/generative/utilities/loader.py,sha256=6J0aAP6-6LySeqeYIHKcchr5T9cVtSO34aoDr3V9gxY,12726
119
117
  ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=pKp3AMSbS3otCvgwJRF5M1l4JRNKk-aCKimXzIMSrds,35679
120
118
  ai_edge_torch/generative/utilities/t5_loader.py,sha256=_UXcc1QKT-S92hikfo-fTBFhnYLzROqcyRqKonVsqj4,16885
121
119
  ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
@@ -128,13 +126,14 @@ ai_edge_torch/hlfb/test/test_stablehlo_composite_builder.py,sha256=j8WpeS-mz3Zr4
128
126
  ai_edge_torch/lowertools/__init__.py,sha256=A8WBXvWtuFYYWtNTqPD7waVntLaSVAnSMwx5ugjZBIw,761
129
127
  ai_edge_torch/lowertools/_shim.py,sha256=ilL7x1ebUBj1clg7bagrX4y_nVSHiGrvDrOVfuTeenE,3039
130
128
  ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
131
- ai_edge_torch/lowertools/odml_torch_utils.py,sha256=GKfW1X-QSFffQdVlBuD-bNpP265xcdUlfBY3-9I4f_o,7447
129
+ ai_edge_torch/lowertools/odml_torch_utils.py,sha256=K5dZ_fFDL3GWKo0IoY4OC_GX5MY-guY-MqteolyV9hg,8098
132
130
  ai_edge_torch/lowertools/test_utils.py,sha256=bPgc2iXX16KYtMNvmsRdKfrCY6UJmcfitfCOvHoD7Oc,1930
133
- ai_edge_torch/lowertools/torch_xla_utils.py,sha256=-SRm9YNsIGsaVd5Cyp2PP-tdLBJH8EDoMFAa2y89a1w,9043
131
+ ai_edge_torch/lowertools/torch_xla_utils.py,sha256=n6G3pFGmHar7kgKDsdTB74kv1PUuTTu1XjV7R-QizzE,9003
132
+ ai_edge_torch/lowertools/translate_recipe.py,sha256=DNzD0VD35YZDqiZjAF1IyIPSzUGPDpE0jvFCCYIzpnc,5667
134
133
  ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
135
134
  ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
136
135
  ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
137
- ai_edge_torch/odml_torch/export.py,sha256=hIGT-JKYbIa6e_G0AD-k4MSTIAMGdHC1hNHMn9CxsYw,10467
136
+ ai_edge_torch/odml_torch/export.py,sha256=OXN6jipwFtBvQ9XdyeDGQTQ_-UnCxPYnLc_WW7xF0aI,10469
138
137
  ai_edge_torch/odml_torch/export_utils.py,sha256=q84U69ZQ82hLXw-xncJ8IW-K71Xux-NWlzZTs7hdZWA,5127
139
138
  ai_edge_torch/odml_torch/tf_integration.py,sha256=lTFJPPEijLPFmn6qq2jbpVTQOo0YaOTK36kK6rCiyIE,5956
140
139
  ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
@@ -162,8 +161,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
162
161
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
163
162
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
164
163
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
165
- ai_edge_torch_nightly-0.3.0.dev20240902.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
166
- ai_edge_torch_nightly-0.3.0.dev20240902.dist-info/METADATA,sha256=Bvc6_uRgjiaqUsVareqJETErsc5rU7NNTPTDqn0JwoA,1878
167
- ai_edge_torch_nightly-0.3.0.dev20240902.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
168
- ai_edge_torch_nightly-0.3.0.dev20240902.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
169
- ai_edge_torch_nightly-0.3.0.dev20240902.dist-info/RECORD,,
164
+ ai_edge_torch_nightly-0.3.0.dev20240905.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
165
+ ai_edge_torch_nightly-0.3.0.dev20240905.dist-info/METADATA,sha256=8yrrm7TEYgaRhKdUwgStjCqrTWs8YcnnlzoTJt2NrJg,1859
166
+ ai_edge_torch_nightly-0.3.0.dev20240905.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
167
+ ai_edge_torch_nightly-0.3.0.dev20240905.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
168
+ ai_edge_torch_nightly-0.3.0.dev20240905.dist-info/RECORD,,
ai_edge_torch/generative/quantize/ai_edge_quantizer_glue/__init__.py DELETED
@@ -1,14 +0,0 @@
1
- # Copyright 2024 The AI Edge Torch Authors.
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- # ==============================================================================