ai-edge-torch-nightly 0.2.0.dev20240604__py3-none-any.whl → 0.2.0.dev20240606__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic. Click here for more details.

@@ -33,7 +33,6 @@ class TestModelConversion(unittest.TestCase):
33
33
  """Unit tests that check for model conversion and correctness."""
34
34
 
35
35
  def test_toy_model_with_kv_cache(self):
36
- self.skipTest("b/338288901")
37
36
  config = toy_model_with_kv_cache.get_model_config()
38
37
  pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
39
38
  idx, input_pos = torch.tensor([[1]], dtype=torch.long), torch.tensor(
@@ -42,19 +41,21 @@ class TestModelConversion(unittest.TestCase):
42
41
 
43
42
  edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
44
43
 
45
- self.assertTrue(
46
- model_coverage.compare_tflite_torch(
47
- edge_model,
48
- pytorch_model,
49
- (idx, input_pos),
50
- num_valid_inputs=1,
51
- atol=1e-5,
52
- rtol=1e-5,
53
- )
54
- )
44
+ # TODO(b/338288901): re-enable test to check output tensors.
45
+ skip_output_check = True
46
+ if skip_output_check is False:
47
+ self.assertTrue(
48
+ model_coverage.compare_tflite_torch(
49
+ edge_model,
50
+ pytorch_model,
51
+ (idx, input_pos),
52
+ num_valid_inputs=1,
53
+ atol=1e-5,
54
+ rtol=1e-5,
55
+ )
56
+ )
55
57
 
56
58
  def test_toy_model_with_kv_cache_with_hlfb(self):
57
- self.skipTest("b/338288901")
58
59
  config = toy_model_with_kv_cache.get_model_config()
59
60
  config.enable_hlfb = True
60
61
  pytorch_model = toy_model_with_kv_cache.ToyModelWithKV(config)
@@ -64,16 +65,19 @@ class TestModelConversion(unittest.TestCase):
64
65
 
65
66
  edge_model = ai_edge_torch.convert(pytorch_model, (idx, input_pos))
66
67
 
67
- self.assertTrue(
68
- model_coverage.compare_tflite_torch(
69
- edge_model,
70
- pytorch_model,
71
- (idx, input_pos),
72
- num_valid_inputs=1,
73
- atol=1e-5,
74
- rtol=1e-5,
75
- )
76
- )
68
+ # TODO(b/338288901): re-enable test to check output tensors.
69
+ skip_output_check = True
70
+ if skip_output_check is False:
71
+ self.assertTrue(
72
+ model_coverage.compare_tflite_torch(
73
+ edge_model,
74
+ pytorch_model,
75
+ (idx, input_pos),
76
+ num_valid_inputs=1,
77
+ atol=1e-5,
78
+ rtol=1e-5,
79
+ )
80
+ )
77
81
 
78
82
  def test_tiny_llama(self):
79
83
  self.skipTest("b/338288901")
@@ -87,19 +91,21 @@ class TestModelConversion(unittest.TestCase):
87
91
 
88
92
  edge_model = ai_edge_torch.convert(pytorch_model, (tokens, input_pos))
89
93
 
90
- self.assertTrue(
91
- model_coverage.compare_tflite_torch(
92
- edge_model,
93
- pytorch_model,
94
- (tokens, input_pos),
95
- num_valid_inputs=1,
96
- atol=1e-5,
97
- rtol=1e-5,
98
- )
99
- )
94
+ # TODO(b/338288901): re-enable test to check output tensors.
95
+ skip_output_check = True
96
+ if skip_output_check is False:
97
+ self.assertTrue(
98
+ model_coverage.compare_tflite_torch(
99
+ edge_model,
100
+ pytorch_model,
101
+ (tokens, input_pos),
102
+ num_valid_inputs=1,
103
+ atol=1e-5,
104
+ rtol=1e-5,
105
+ )
106
+ )
100
107
 
101
108
  def test_tiny_llama_multisig(self):
102
- self.skipTest("b/338288901")
103
109
  config = tiny_llama.get_fake_model_config_for_test()
104
110
  pytorch_model = tiny_llama.TinyLLamma(config)
105
111
 
@@ -122,32 +128,30 @@ class TestModelConversion(unittest.TestCase):
122
128
  .convert()
123
129
  )
124
130
 
125
- # For the pytorch model, the KV cache is a persistent state internal to the model, and it
126
- # will be shared for prefill and decode. However, for tflite, currently we can't share
127
- # kv-cache between the two signatures. prefill will change the content in kv-cache,
128
- # but it won't be readable by the decode tflite model. This means the output of running `decode` after
129
- # running `prefill` in pytorch will be different from the output of running `decode` after `prefill` via ai_edge_torch.
130
- copied_model = copy.deepcopy(pytorch_model)
131
-
132
- self.assertTrue(
133
- model_coverage.compare_tflite_torch(
134
- edge_model,
135
- pytorch_model,
136
- (prefill_tokens, prefill_input_pos),
137
- signature_name="prefill",
138
- num_valid_inputs=1,
139
- )
140
- )
141
-
142
- self.assertTrue(
143
- model_coverage.compare_tflite_torch(
144
- edge_model,
145
- copied_model,
146
- (decode_token, decode_input_pos),
147
- signature_name="decode",
148
- num_valid_inputs=1,
149
- )
150
- )
131
+ # TODO(b/338288901): re-enable test to check output tensors.
132
+ skip_output_check = True
133
+ if skip_output_check is False:
134
+ copied_model = copy.deepcopy(pytorch_model)
135
+
136
+ self.assertTrue(
137
+ model_coverage.compare_tflite_torch(
138
+ edge_model,
139
+ pytorch_model,
140
+ (prefill_tokens, prefill_input_pos),
141
+ signature_name="prefill",
142
+ num_valid_inputs=1,
143
+ )
144
+ )
145
+
146
+ self.assertTrue(
147
+ model_coverage.compare_tflite_torch(
148
+ edge_model,
149
+ copied_model,
150
+ (decode_token, decode_input_pos),
151
+ signature_name="decode",
152
+ num_valid_inputs=1,
153
+ )
154
+ )
151
155
 
152
156
  def test_gemma(self):
153
157
  self.skipTest("b/338288901")
@@ -161,17 +165,20 @@ class TestModelConversion(unittest.TestCase):
161
165
 
162
166
  edge_model = ai_edge_torch.convert(model, (tokens, input_pos))
163
167
 
164
- # TODO(talumbau, haoliang): debug numerical diff.
165
- self.assertTrue(
166
- model_coverage.compare_tflite_torch(
167
- edge_model,
168
- model,
169
- (tokens, input_pos),
170
- num_valid_inputs=1,
171
- atol=1e-2,
172
- rtol=1e-5,
173
- )
174
- )
168
+ # TODO(b/338288901): re-enable test to check output tensors.
169
+ skip_output_check = True
170
+ if skip_output_check is False:
171
+ # TODO(talumbau, haoliang): debug numerical diff.
172
+ self.assertTrue(
173
+ model_coverage.compare_tflite_torch(
174
+ edge_model,
175
+ model,
176
+ (tokens, input_pos),
177
+ num_valid_inputs=1,
178
+ atol=1e-2,
179
+ rtol=1e-5,
180
+ )
181
+ )
175
182
 
176
183
  def test_phi2(self):
177
184
  self.skipTest("b/338288901")
@@ -185,16 +192,19 @@ class TestModelConversion(unittest.TestCase):
185
192
 
186
193
  edge_model = ai_edge_torch.convert(pytorch_model, (tokens, input_pos))
187
194
 
188
- self.assertTrue(
189
- model_coverage.compare_tflite_torch(
190
- edge_model,
191
- pytorch_model,
192
- (tokens, input_pos),
193
- num_valid_inputs=1,
194
- atol=1e-5,
195
- rtol=1e-5,
196
- )
197
- )
195
+ # TODO(b/338288901): re-enable test to check output tensors.
196
+ skip_output_check = True
197
+ if skip_output_check is False:
198
+ self.assertTrue(
199
+ model_coverage.compare_tflite_torch(
200
+ edge_model,
201
+ pytorch_model,
202
+ (tokens, input_pos),
203
+ num_valid_inputs=1,
204
+ atol=1e-5,
205
+ rtol=1e-5,
206
+ )
207
+ )
198
208
 
199
209
 
200
210
  if __name__ == "__main__":
@@ -69,10 +69,16 @@ def load_pytorch_statedict(full_path: str):
69
69
  Raises:
70
70
  ValueError: If no tensors are loaded from the provided directory or file.
71
71
  """
72
- pattern = os.path.join(full_path, "*.bin") if os.path.isdir(full_path) else full_path
73
72
  files = []
74
- for file in glob.glob(pattern):
75
- files.append(file)
73
+ patterns = []
74
+ if os.path.isdir(full_path):
75
+ patterns.append(os.path.join(full_path, "*.bin"))
76
+ patterns.append(os.path.join(full_path, "*.pt"))
77
+ else:
78
+ patterns.append(full_path)
79
+ for pattern in patterns:
80
+ for file in glob.glob(pattern):
81
+ files.append(file)
76
82
 
77
83
  tensors = {}
78
84
  for file in files:
@@ -93,18 +99,20 @@ class ModelLoader:
93
99
 
94
100
  @dataclass
95
101
  class TensorNames:
96
- attn_query_proj: str
97
- attn_key_proj: str
98
- attn_value_proj: str
99
- attn_output_proj: str
100
-
101
- ff_up_proj: str
102
- ff_down_proj: str
102
+ attn_query_proj: str = None
103
+ attn_key_proj: str = None
104
+ attn_value_proj: str = None
105
+ attn_fused_qkv_proj: str = None
106
+ attn_output_proj: str = None
107
+
108
+ ff_up_proj: str = None
109
+ ff_down_proj: str = None
103
110
  ff_gate_proj: str = None
104
111
 
105
112
  pre_attn_norm: str = None
106
113
  pre_ff_norm: str = None
107
114
  embedding: str = None
115
+ embedding_position: str = None
108
116
  final_norm: str = None
109
117
  lm_head: str = None
110
118
 
@@ -129,6 +137,10 @@ class ModelLoader:
129
137
  strict (bool, optional): Whether the converted keys are strictly
130
138
  matched. Defaults to True.
131
139
 
140
+ Returns:
141
+ missing_keys (List[str]): a list of str containing the missing keys
142
+ unexpected_keys (List[str]): a list of str containing the unexpected keys
143
+
132
144
  Raises:
133
145
  ValueError: If conversion results in unmapped tensors and strict mode is
134
146
  enabled.
@@ -139,6 +151,10 @@ class ModelLoader:
139
151
  converted_state["tok_embedding.weight"] = state.pop(
140
152
  f"{self._names.embedding}.weight"
141
153
  )
154
+ if self._names.embedding_position is not None:
155
+ converted_state["tok_embedding_position"] = state.pop(
156
+ f"{self._names.embedding_position}"
157
+ )
142
158
  if self._names.lm_head is not None:
143
159
  converted_state["lm_head.weight"] = state.pop(f"{self._names.lm_head}.weight")
144
160
  if model.config.lm_head_use_bias:
@@ -158,7 +174,7 @@ class ModelLoader:
158
174
  raise ValueError(
159
175
  f"Failed to map all tensor. Remaing tensor are: {list(state.keys())}"
160
176
  )
161
- model.load_state_dict(converted_state, strict=strict)
177
+ return model.load_state_dict(converted_state, strict=strict)
162
178
 
163
179
  def _get_loader(self) -> Callable[[str], Dict[str, torch.Tensor]]:
164
180
  """A best effort method for finding appropriate state loader.
@@ -172,13 +188,15 @@ class ModelLoader:
172
188
  if os.path.isdir(self._file_name):
173
189
  if glob.glob(os.path.join(self._file_name, "*.safetensors")):
174
190
  return load_safetensors
175
- if glob.glob(os.path.join(self._file_name, "*.bin")):
191
+ if glob.glob(os.path.join(self._file_name, "*.bin")) or glob.glob(
192
+ os.path.join(self._file_name, "*.pt")
193
+ ):
176
194
  return load_pytorch_statedict
177
195
 
178
196
  if self._file_name.endswith(".safetensors"):
179
197
  return load_safetensors
180
198
 
181
- if self._file_name.endswith(".bin"):
199
+ if self._file_name.endswith(".bin") or self._file_name.endswith(".pt"):
182
200
  return load_pytorch_statedict
183
201
 
184
202
  raise ValueError(f"File format not supported.")
@@ -225,22 +243,33 @@ class ModelLoader:
225
243
  converted_state: Dict[str, torch.Tensor],
226
244
  ):
227
245
  prefix = f"transformer_blocks.{idx}"
228
- q_name = self._names.attn_query_proj.format(idx)
229
- k_name = self._names.attn_key_proj.format(idx)
230
- v_name = self._names.attn_value_proj.format(idx)
231
- converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = self._fuse_qkv(
232
- config,
233
- state.pop(f"{q_name}.weight"),
234
- state.pop(f"{k_name}.weight"),
235
- state.pop(f"{v_name}.weight"),
236
- )
237
- if config.attn_config.qkv_use_bias:
238
- converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = self._fuse_qkv(
246
+ if self._names.attn_fused_qkv_proj:
247
+ fused_qkv_name = self._names.attn_fused_qkv_proj.format(idx)
248
+ converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = state.pop(
249
+ f"{fused_qkv_name}.weight"
250
+ )
251
+ else:
252
+ q_name = self._names.attn_query_proj.format(idx)
253
+ k_name = self._names.attn_key_proj.format(idx)
254
+ v_name = self._names.attn_value_proj.format(idx)
255
+ converted_state[f"{prefix}.atten_func.qkv_projection.weight"] = self._fuse_qkv(
239
256
  config,
240
- state.pop(f"{q_name}.bias"),
241
- state.pop(f"{k_name}.bias"),
242
- state.pop(f"{v_name}.bias"),
257
+ state.pop(f"{q_name}.weight"),
258
+ state.pop(f"{k_name}.weight"),
259
+ state.pop(f"{v_name}.weight"),
243
260
  )
261
+ if config.attn_config.qkv_use_bias:
262
+ if self._names.attn_fused_qkv_proj:
263
+ converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = state.pop(
264
+ f"{fused_qkv_name}.bias"
265
+ )
266
+ else:
267
+ converted_state[f"{prefix}.atten_func.qkv_projection.bias"] = self._fuse_qkv(
268
+ config,
269
+ state.pop(f"{q_name}.bias"),
270
+ state.pop(f"{k_name}.bias"),
271
+ state.pop(f"{v_name}.bias"),
272
+ )
244
273
 
245
274
  o_name = self._names.attn_output_proj.format(idx)
246
275
  converted_state[f"{prefix}.atten_func.output_projection.weight"] = state.pop(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.2.0.dev20240604
3
+ Version: 0.2.0.dev20240606
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -6,7 +6,7 @@ ai_edge_torch/convert/conversion_utils.py,sha256=NpVm3Ms81_cIW5IYgGsr0BVganJJgBK
6
6
  ai_edge_torch/convert/converter.py,sha256=bjj5TV5_g4sGyuSh8ThEDydlNMqhkGSY4SzXK6vwhqI,6927
7
7
  ai_edge_torch/convert/fx_passes/__init__.py,sha256=Ll2nNwufjcV5nSruQPXiloq7F1E7pWJ2T5clXmy1lk8,2825
8
8
  ai_edge_torch/convert/fx_passes/_pass_base.py,sha256=ijVyDclPnd6a0DWWUJkwR4igj6f82S-cE1-83QGPvgw,1652
9
- ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py,sha256=wHVWNNMu5h_ya6GnnJn0cNif9xmdSqr8Vm-R7lllxZM,6213
9
+ ai_edge_torch/convert/fx_passes/build_aten_composite_pass.py,sha256=2yqUwJJ2R233_X9FNMOP9oYRTTzH34TR_BIUj-wfnKw,7080
10
10
  ai_edge_torch/convert/fx_passes/build_upsample_bilinear2d_composite_pass.py,sha256=76XYoIlFDgrzp5QemoaEalPFcEbfszkEH_PLvO1ASCk,2607
11
11
  ai_edge_torch/convert/fx_passes/canonicalize_pass.py,sha256=UX6dJsxCqSkftXXvNBV-i7Bjk6H7qTyqzUnE640Itfg,1673
12
12
  ai_edge_torch/convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=aRT8hTS3n9ie28lgu6mygtFO6Ypwu0qjNb0c81v9HLs,2448
@@ -22,7 +22,7 @@ ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partition
22
22
  ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=FlNKt2EhIKnlVEeUWTiv5sz446YKU6Yy1H0Gd6VRgkU,6432
23
23
  ai_edge_torch/convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
24
24
  ai_edge_torch/convert/test/test_convert.py,sha256=2qPmmGqnfV_o1gfsSdjGq3-JR1b323ligiy5MdAv9NA,8021
25
- ai_edge_torch/convert/test/test_convert_composites.py,sha256=SrVn_cEMtQhYYCMOUKK0K7M57MQNQX-lOUwieln0HGA,6616
25
+ ai_edge_torch/convert/test/test_convert_composites.py,sha256=_Ojc-H6GOS5s8ek3_8eRBL_AiCs-k3srziPJ2R4Ulrg,7255
26
26
  ai_edge_torch/convert/test/test_convert_multisig.py,sha256=kMaGnHe9ylfyU68qCifYcaGwJqyejKz--QQt9jS2oUA,4537
27
27
  ai_edge_torch/debug/__init__.py,sha256=TKvmnjVk3asvYcVh6C-LPr6srgAF_nppSAupWEXqwPY,707
28
28
  ai_edge_torch/debug/culprit.py,sha256=vklaxBUfINdo44OsH7csILK70N41gEThCGchGEfbTZw,12789
@@ -40,10 +40,10 @@ ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=6nOuwx9q3AUlY
40
40
  ai_edge_torch/generative/examples/phi2/phi2.py,sha256=VvigzPQ_LJHeADTsMliwFwPe2BcnOhFgKDqr_WZ2JQ8,5540
41
41
  ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
42
42
  ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=Lo4Dq7a3Kg-lyH56iqGtqCo5UaClQHRCTDdNagXGTo8,3535
43
- ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=KR1Ci4rlJeeGfsFRliCxUve9K7RTJLZfTRMgFtfQ4MU,2434
44
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=6REAYy1Bv-Iv5zcmA_m_W6fH6jt5a3IS6Vge18jS_Wo,3633
43
+ ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=yUCJemEh4n8ez-yLgVU0HZAki-PZ9nY04DFjgpx9PUc,3698
44
+ ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=MI73RjOeD4Kh7AL0j5_QXiZq-rl_qCdibSE6eCQCyeY,3804
45
45
  ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=AgVAdUbSkHXONVUjAyBQEXhIUUlinf9kNljcBpWnj3A,3276
46
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=nq94VpQ103eOimnmdyg7u3Xk1LH1IxGlmIbr2AttRIk,16224
46
+ ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=TfbfsmuKoGsBENF9fYIAN_SMEQNhj-kjNdqQXFJGxpg,7784
47
47
  ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=L6hLaMQGb8-_BwSvTLIuDnZwfTqn0K4swBUjfPnYWZo,2341
48
48
  ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=FCbnwlkpYYb-tF7KscbSYjNEdg7XnuLju1cDuIRoQv8,8277
49
49
  ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=r9RqbyNvuvXOGu3ojtl7ZmbC7o4Pt8aUKAhN1yCdtEc,3397
@@ -56,22 +56,24 @@ ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py,sha256=5i
56
56
  ai_edge_torch/generative/examples/t5/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
57
57
  ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=bWtwtUacvJOEDUpuYvLTgkP7oTkXKJA-Tf4FPxlD1Cw,4536
58
58
  ai_edge_torch/generative/examples/t5/t5.py,sha256=q2gG5RRo7RgNzvHXYC0Juh6Tgt5d_RTMSWFaYvOKiZU,21065
59
- ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=anR99IrzR21x6yswFHYG5QQtPDZ7rVicf6STfMp54fU,8998
59
+ ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=rRgwCEdVtzcJEaGbbBjw8HxCxrCX3pXA5nelawdYiME,9036
60
60
  ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
61
61
  ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=EV07_MEG3fv9g0ZGu9gbBd5BjjrGkxCT1pv7dvhz4TI,3791
62
+ ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=rzL5h7Z5DIEgfpc1pWgYHdKt2aR8ha_CUqTKQBSPBaU,5521
62
63
  ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=MUr6fSj2hBuYSlNbZtrBBpzqB_0WY-l_xYcd_TFFUjY,4831
63
64
  ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
64
65
  ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=E4I5OlC4zyl5cxiiu7uTED-zcwYRu210lP1zuT3xLBE,2566
65
66
  ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=hVGpuI8gpj4Rn9k4otsRE22MSLFHBDlUOgioY6Ru6VI,5629
66
67
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
67
- ai_edge_torch/generative/layers/attention.py,sha256=PxixRZb00v5BQkWbDwaJgke4Rd5LwzdWe0zH9SG4Tj0,9127
68
+ ai_edge_torch/generative/layers/attention.py,sha256=zNIBXxCOA5Mz_F_dfBbKpIovhtcB6q5a-i8oAxls1d0,7071
68
69
  ai_edge_torch/generative/layers/attention_utils.py,sha256=hXhuyKblPPxKIRzlAf1YNlwHgpbj-6nReRLhRHELx5k,6350
69
- ai_edge_torch/generative/layers/builder.py,sha256=ZSBVLv5EOtCkSW_Z8C2Hd7jN52nIAA2as1-qpmHGbCg,3201
70
+ ai_edge_torch/generative/layers/builder.py,sha256=WLTeDId9t3Xwt0h1zxzqoYyFvfrNzPKLskcl39q8Aqw,3403
70
71
  ai_edge_torch/generative/layers/feed_forward.py,sha256=4j2QaSCw59Jkk_ixKDpKEj7FLRauzuExTiSNRzAjAhE,2820
71
72
  ai_edge_torch/generative/layers/kv_cache.py,sha256=4uiZLO3om5G3--kT04Jt0esEYznbkJ7QLzSHfb8mjc4,3090
72
- ai_edge_torch/generative/layers/model_config.py,sha256=KpJRIHV5BJH8QOa7h6LXLZyC7UDWgbCEsw0CvArz49Q,4064
73
+ ai_edge_torch/generative/layers/model_config.py,sha256=2zT9nyoyuuyk5ziiww0VSJ6_JO7pDf7uOYbO9O3OQc4,4249
73
74
  ai_edge_torch/generative/layers/normalization.py,sha256=M27eW3TcNK20oaXClXtfnu0lLWrAGrSKSsbegRWnj3c,1867
74
75
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=12SsCuoRuLNCwnFGe_pHDOZEBwBcqXs87Aj0PaWWw4E,1383
76
+ ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=dYafGC205QE5CLIbBTCI-7eVvEGZEHzs1toPEhemeDs,3391
75
77
  ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
76
78
  ai_edge_torch/generative/quantize/example.py,sha256=t-YwyKSPAG-OZC1DfH-0vfie2RHHpTSQjxUY-tmhu5g,1543
77
79
  ai_edge_torch/generative/quantize/quant_attrs.py,sha256=ffBALrrbrfiG_mrOr-f3B1Gc6PlAma9gtvVnfP7SDzI,1862
@@ -81,10 +83,10 @@ ai_edge_torch/generative/quantize/quant_recipes.py,sha256=CRA2ENevS-3usHqidWDe2w
81
83
  ai_edge_torch/generative/quantize/supported_schemes.py,sha256=OQ4ghQXknA1PPjuY-xBgAmOpaIBgYFM8F2YAIot06hE,1345
82
84
  ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
83
85
  ai_edge_torch/generative/test/loader_test.py,sha256=N88CbrLW7Q2x1EyurwdXQ6YjsA-ySQcPxpZH3QOGp-M,3317
84
- ai_edge_torch/generative/test/test_model_conversion.py,sha256=1NfZxKo9Gx6CmVfd86K1FkmsNQnjzIV1ojBS85UGvT0,6500
86
+ ai_edge_torch/generative/test/test_model_conversion.py,sha256=i_SAW-hD8SaHuopMZI9IuXXDFn5uSTJa1nKZhaC3dAQ,6811
85
87
  ai_edge_torch/generative/test/test_quantize.py,sha256=f70sH1ZFzdCwYj0MG-eg54WOC4LasR0D8CTUYpjxZYM,3728
86
88
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
87
- ai_edge_torch/generative/utilities/loader.py,sha256=c-ZOIDBVnat_5l2W5sWU7HQm7CL-wducS8poSu5PlUg,10107
89
+ ai_edge_torch/generative/utilities/loader.py,sha256=r-_hSanSjLZ_YXFpZUb0Up94u5F8JHp70Vf2nlONPSg,11269
88
90
  ai_edge_torch/generative/utilities/t5_loader.py,sha256=guDTv-12UUvJGl4eDvvZX3t4rRKewfXO8SpcYXM6gbc,16156
89
91
  ai_edge_torch/hlfb/__init__.py,sha256=rrje8a2iuKboBoV96bVq7nlS9HsnuEMbHE5JiWmCxFA,752
90
92
  ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=2VXnHcGf23VOuP-1GriGIpuL98leBB8twp_qaScMnmc,4799
@@ -100,8 +102,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=ExThdTXqnWmGC3-F6sdXbXr8nYzkEe_qCz
100
102
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
101
103
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
102
104
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=EIyKz-HY70DguWuSrJal8LpYXQ5ZSEUf3ZrVl7jikFM,4286
103
- ai_edge_torch_nightly-0.2.0.dev20240604.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
104
- ai_edge_torch_nightly-0.2.0.dev20240604.dist-info/METADATA,sha256=0Hjm-wBxxZRbVdxcFhmOKVwsU52_l5F6IblrflOWMWk,1748
105
- ai_edge_torch_nightly-0.2.0.dev20240604.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
106
- ai_edge_torch_nightly-0.2.0.dev20240604.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
107
- ai_edge_torch_nightly-0.2.0.dev20240604.dist-info/RECORD,,
105
+ ai_edge_torch_nightly-0.2.0.dev20240606.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
106
+ ai_edge_torch_nightly-0.2.0.dev20240606.dist-info/METADATA,sha256=2yFjQFvs93hoppwWWSJi-B9HC2n-h1s2rzjsyGXU1zI,1748
107
+ ai_edge_torch_nightly-0.2.0.dev20240606.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
108
+ ai_edge_torch_nightly-0.2.0.dev20240606.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
109
+ ai_edge_torch_nightly-0.2.0.dev20240606.dist-info/RECORD,,