ai-edge-torch-nightly 0.2.0.dev20240710__py3-none-any.whl → 0.2.0.dev20240712__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ai-edge-torch-nightly might be problematic; see the package's registry page for more details.

Files changed (23)
  1. ai_edge_torch/convert/conversion.py +2 -4
  2. ai_edge_torch/convert/conversion_utils.py +61 -3
  3. ai_edge_torch/convert/converter.py +47 -16
  4. ai_edge_torch/convert/test/test_convert.py +39 -0
  5. ai_edge_torch/generative/examples/stable_diffusion/clip.py +12 -10
  6. ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +56 -30
  7. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +72 -69
  8. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +80 -72
  9. ai_edge_torch/generative/examples/stable_diffusion/pipeline.py +1 -1
  10. ai_edge_torch/generative/examples/t5/t5_attention.py +6 -1
  11. ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py +1 -1
  12. ai_edge_torch/generative/layers/model_config.py +4 -0
  13. ai_edge_torch/generative/layers/unet/blocks_2d.py +1 -1
  14. ai_edge_torch/generative/layers/unet/model_config.py +5 -5
  15. ai_edge_torch/generative/utilities/loader.py +9 -6
  16. ai_edge_torch/generative/utilities/stable_diffusion_loader.py +74 -10
  17. ai_edge_torch/model.py +11 -3
  18. ai_edge_torch/testing/model_coverage/model_coverage.py +19 -13
  19. {ai_edge_torch_nightly-0.2.0.dev20240710.dist-info → ai_edge_torch_nightly-0.2.0.dev20240712.dist-info}/METADATA +1 -1
  20. {ai_edge_torch_nightly-0.2.0.dev20240710.dist-info → ai_edge_torch_nightly-0.2.0.dev20240712.dist-info}/RECORD +23 -23
  21. {ai_edge_torch_nightly-0.2.0.dev20240710.dist-info → ai_edge_torch_nightly-0.2.0.dev20240712.dist-info}/LICENSE +0 -0
  22. {ai_edge_torch_nightly-0.2.0.dev20240710.dist-info → ai_edge_torch_nightly-0.2.0.dev20240712.dist-info}/WHEEL +0 -0
  23. {ai_edge_torch_nightly-0.2.0.dev20240710.dist-info → ai_edge_torch_nightly-0.2.0.dev20240712.dist-info}/top_level.txt +0 -0
@@ -37,6 +37,9 @@ class ResidualBlockTensorNames:
37
37
  class AttentionBlockTensorNames:
38
38
  norm: str = None
39
39
  fused_qkv_proj: str = None
40
+ q_proj: str = None
41
+ k_proj: str = None
42
+ v_proj: str = None
40
43
  output_proj: str = None
41
44
 
42
45
 
@@ -106,12 +109,21 @@ def _map_to_converted_state(
106
109
  state_param: str,
107
110
  converted_state: Dict[str, torch.Tensor],
108
111
  converted_state_param: str,
112
+ squeeze_dims: bool = False,
109
113
  ):
110
114
  converted_state[f"{converted_state_param}.weight"] = state.pop(
111
115
  f"{state_param}.weight"
112
116
  )
117
+ if squeeze_dims:
118
+ converted_state[f"{converted_state_param}.weight"] = torch.squeeze(
119
+ converted_state[f"{converted_state_param}.weight"]
120
+ )
113
121
  if f"{state_param}.bias" in state:
114
122
  converted_state[f"{converted_state_param}.bias"] = state.pop(f"{state_param}.bias")
123
+ if squeeze_dims:
124
+ converted_state[f"{converted_state_param}.bias"] = torch.squeeze(
125
+ converted_state[f"{converted_state_param}.bias"]
126
+ )
115
127
 
116
128
 
117
129
  class BaseLoader(loader.ModelLoader):
@@ -179,17 +191,65 @@ class BaseLoader(loader.ModelLoader):
179
191
  f"{converted_state_param_prefix}.norm",
180
192
  )
181
193
  attention_layer_prefix = f"{converted_state_param_prefix}.attention"
182
- _map_to_converted_state(
183
- state,
184
- tensor_names.fused_qkv_proj,
185
- converted_state,
186
- f"{attention_layer_prefix}.qkv_projection",
187
- )
194
+ if tensor_names.fused_qkv_proj is not None:
195
+ _map_to_converted_state(
196
+ state,
197
+ tensor_names.fused_qkv_proj,
198
+ converted_state,
199
+ f"{attention_layer_prefix}.qkv_projection",
200
+ )
201
+ else:
202
+ _map_to_converted_state(
203
+ state,
204
+ tensor_names.q_proj,
205
+ converted_state,
206
+ f"{attention_layer_prefix}.q_projection",
207
+ squeeze_dims=True,
208
+ )
209
+ _map_to_converted_state(
210
+ state,
211
+ tensor_names.k_proj,
212
+ converted_state,
213
+ f"{attention_layer_prefix}.k_projection",
214
+ squeeze_dims=True,
215
+ )
216
+ _map_to_converted_state(
217
+ state,
218
+ tensor_names.v_proj,
219
+ converted_state,
220
+ f"{attention_layer_prefix}.v_projection",
221
+ squeeze_dims=True,
222
+ )
223
+ converted_state[f"{attention_layer_prefix}.qkv_projection.weight"] = torch.concat(
224
+ [
225
+ converted_state[f"{attention_layer_prefix}.q_projection.weight"],
226
+ converted_state[f"{attention_layer_prefix}.k_projection.weight"],
227
+ converted_state[f"{attention_layer_prefix}.v_projection.weight"],
228
+ ],
229
+ axis=0,
230
+ )
231
+ del converted_state[f"{attention_layer_prefix}.q_projection.weight"]
232
+ del converted_state[f"{attention_layer_prefix}.k_projection.weight"]
233
+ del converted_state[f"{attention_layer_prefix}.v_projection.weight"]
234
+ if config.attention_config.qkv_use_bias:
235
+ converted_state[f"{attention_layer_prefix}.qkv_projection.bias"] = torch.concat(
236
+ [
237
+ converted_state[f"{attention_layer_prefix}.q_projection.bias"],
238
+ converted_state[f"{attention_layer_prefix}.k_projection.bias"],
239
+ converted_state[f"{attention_layer_prefix}.v_projection.bias"],
240
+ ],
241
+ axis=0,
242
+ )
243
+ del converted_state[f"{attention_layer_prefix}.q_projection.bias"]
244
+ del converted_state[f"{attention_layer_prefix}.k_projection.bias"]
245
+ del converted_state[f"{attention_layer_prefix}.v_projection.bias"]
246
+
188
247
  _map_to_converted_state(
189
248
  state,
190
249
  tensor_names.output_proj,
191
250
  converted_state,
192
251
  f"{attention_layer_prefix}.output_projection",
252
+ squeeze_dims=True,
193
253
  )
194
254
 
195
255
  def _map_cross_attention_block(
@@ -269,7 +329,7 @@ class BaseLoader(loader.ModelLoader):
269
329
  converted_state: Dict[str, torch.Tensor],
270
330
  tensor_names: TransformerBlockTensorNames,
271
331
  converted_state_param_prefix: str,
272
- config: unet_config.TransformerBlock2Dconfig,
332
+ config: unet_config.TransformerBlock2DConfig,
273
333
  ):
274
334
  _map_to_converted_state(
275
335
  state,
@@ -482,6 +542,10 @@ class BaseLoader(loader.ModelLoader):
482
542
  )
483
543
 
484
544
 
545
+ # Alias class name for better code reading.
546
+ ClipModelLoader = BaseLoader
547
+
548
+
485
549
  class AutoEncoderModelLoader(BaseLoader):
486
550
 
487
551
  @dataclass
@@ -668,7 +732,7 @@ class DiffusionModelLoader(BaseLoader):
668
732
  stride=2,
669
733
  padding=config.downsample_padding,
670
734
  ),
671
- transformer_block_config=unet_config.TransformerBlock2Dconfig(
735
+ transformer_block_config=unet_config.TransformerBlock2DConfig(
672
736
  attention_block_config=unet_config.AttentionBlock2DConfig(
673
737
  dim=output_channel,
674
738
  normalization_config=config.transformer_norm_config,
@@ -726,7 +790,7 @@ class DiffusionModelLoader(BaseLoader):
726
790
  ),
727
791
  num_layers=config.mid_block_layers,
728
792
  time_embedding_channels=config.time_embedding_blocks_dim,
729
- transformer_block_config=unet_config.TransformerBlock2Dconfig(
793
+ transformer_block_config=unet_config.TransformerBlock2DConfig(
730
794
  attention_block_config=unet_config.AttentionBlock2DConfig(
731
795
  dim=mid_block_channels,
732
796
  normalization_config=config.transformer_norm_config,
@@ -789,7 +853,7 @@ class DiffusionModelLoader(BaseLoader):
789
853
  mode=unet_config.SamplingType.NEAREST,
790
854
  scale_factor=2,
791
855
  ),
792
- transformer_block_config=unet_config.TransformerBlock2Dconfig(
856
+ transformer_block_config=unet_config.TransformerBlock2DConfig(
793
857
  attention_block_config=unet_config.AttentionBlock2DConfig(
794
858
  dim=output_channel,
795
859
  normalization_config=config.transformer_norm_config,
ai_edge_torch/model.py CHANGED
@@ -33,7 +33,10 @@ class Model(abc.ABC):
33
33
 
34
34
  @abc.abstractmethod
35
35
  def __call__(
36
- self, *args: npt.ArrayLike, signature_name: str = cutils.DEFAULT_SIGNATURE_NAME
36
+ self,
37
+ *args: npt.ArrayLike,
38
+ signature_name: str = cutils.DEFAULT_SIGNATURE_NAME,
39
+ **kwargs,
37
40
  ) -> npt.ArrayLike | tuple[npt.ArrayLike]:
38
41
  raise NotImplementedError()
39
42
 
@@ -62,12 +65,16 @@ class TfLiteModel(Model):
62
65
  self._tflite_model = tflite_model
63
66
 
64
67
  def __call__(
65
- self, *args: npt.ArrayLike, signature_name: str = cutils.DEFAULT_SIGNATURE_NAME
68
+ self,
69
+ *args: npt.ArrayLike,
70
+ signature_name: str = cutils.DEFAULT_SIGNATURE_NAME,
71
+ **kwargs,
66
72
  ) -> npt.ArrayLike | tuple[npt.ArrayLike]:
67
73
  """Runs inference on the edge model using the provided arguments.
68
74
 
69
75
  Args:
70
76
  *args: The arguments to be passed to the model for inference.
77
+ **kwargs: The arguments with specific names to be passed to the model for inference.
71
78
  signature_name: The name of the signature to be used for inference.
72
79
  The default signature is used if not provided.
73
80
  """
@@ -90,13 +97,14 @@ class TfLiteModel(Model):
90
97
  else:
91
98
  raise exception
92
99
 
93
- if len(signature_list[signature_name]['inputs']) != len(args):
100
+ if len(signature_list[signature_name]['inputs']) != len(args) + len(kwargs):
94
101
  raise ValueError(
95
102
  f"The model requires {len(signature_list[signature_name]['inputs'])} arguments but {len(args)} was provided."
96
103
  )
97
104
 
98
105
  # Gather the input dictionary based on the signature.
99
106
  inputs = {f'args_{idx}': args[idx] for idx in range(len(args))}
107
+ inputs = {**inputs, **kwargs}
100
108
  outputs = runner(**inputs)
101
109
 
102
110
  return (
@@ -60,7 +60,8 @@ def _torch_tensors_to_np(*argv):
60
60
  def compare_tflite_torch(
61
61
  edge_model: Model,
62
62
  torch_eval_func: Callable,
63
- input_data=None,
63
+ args=None,
64
+ kwargs=None,
64
65
  *,
65
66
  num_valid_inputs: int = 1,
66
67
  signature_name: str = None,
@@ -71,8 +72,9 @@ def compare_tflite_torch(
71
72
  Args:
72
73
  edge_model: Serialized ai_edge_torch.model.Model object.
73
74
  torch_eval_func: Callable function to evaluate torch model.
74
- input_data: torch.tensor array or a callable to generate a torch.tensor array
75
+ args: torch.tensor array or a callable to generate a torch.tensor array
75
76
  with random data, to pass into models during inference. (default None).
77
+ kwargs: dict of str to torch.tensor, or a callable to generate such.
76
78
  num_valid_inputs: Defines the number of times the random inputs will be generated (if a callable is provided for input_data).
77
79
  signature_name: If provided, specifies the name for the signature of the edge_model to run.
78
80
  Calls the default signature if not provided.
@@ -86,29 +88,33 @@ def compare_tflite_torch(
86
88
  # The supplied model_def.forward_args() will be executed num_valid_inputs
87
89
  # times to generate num_valid_inputs random inputs.
88
90
  torch_inputs = [
89
- input_data() if callable(input_data) else input_data
91
+ (
92
+ (args() if callable(args) else args) or tuple(),
93
+ (kwargs() if callable(kwargs) else kwargs) or {},
94
+ )
90
95
  for _ in range(num_valid_inputs)
91
96
  ]
92
- torch_outputs = [torch_eval_func(*xs) for xs in torch_inputs]
93
- np_inputs = [_torch_tensors_to_np(xs) for xs in torch_inputs]
97
+ torch_outputs = [torch_eval_func(*args, **kwargs) for args, kwargs in torch_inputs]
98
+ np_inputs = [
99
+ (_torch_tensors_to_np(args), _torch_tensors_to_np(kwargs))
100
+ for args, kwargs in torch_inputs
101
+ ]
94
102
  np_outputs = [_torch_tensors_to_np(_flatten(ys)) for ys in torch_outputs]
95
103
 
96
104
  # Define inline utility function used throughout the function.
97
105
  def equal_fn(actual, expected):
98
106
  return np.allclose(actual, expected, atol=atol, rtol=rtol)
99
107
 
100
- def get_actual_fn(input):
108
+ def get_edge_output(inputs):
109
+ args, kwargs = inputs
101
110
  if signature_name is None:
102
- return _flatten(edge_model(*input))
111
+ return _flatten(edge_model(*args, **kwargs))
103
112
  else:
104
- return _flatten(edge_model(*input, signature_name=signature_name))
105
-
106
- def get_expected_fn(input=None, idx=0):
107
- return np_outputs[idx]
113
+ return _flatten(edge_model(*args, **kwargs, signature_name=signature_name))
108
114
 
109
115
  for idx, np_input in enumerate(np_inputs):
110
- output = get_actual_fn(np_input)
111
- golden_output = get_expected_fn(np_input, idx)
116
+ output = get_edge_output(np_input)
117
+ golden_output = np_outputs[idx]
112
118
 
113
119
  is_output_len_eq = len(golden_output) == len(output)
114
120
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.2.0.dev20240710
3
+ Version: 0.2.0.dev20240712
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -1,9 +1,9 @@
1
1
  ai_edge_torch/__init__.py,sha256=CNDboRP4zQBpz2hznNCQWcQCARvNXUm3DMa1Dw_XXFg,1067
2
- ai_edge_torch/model.py,sha256=kmcgELjsYl8YzF8nUF6P7q4i8MWS-pLGpfsy-yTUXmE,4243
2
+ ai_edge_torch/model.py,sha256=8Ba9ia7TCM_fciulw6qObmzdcxL3IaLQKDqpR7Lxp-Q,4440
3
3
  ai_edge_torch/convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
4
- ai_edge_torch/convert/conversion.py,sha256=8K8jQuaCjlUWoj7jiimxp_zpN6mYThLOcQ858UDcYnE,4159
5
- ai_edge_torch/convert/conversion_utils.py,sha256=9BqCL38DErv1vEVGtT3BIJVhdwZjw2EQ-_m5UpvVVYE,11294
6
- ai_edge_torch/convert/converter.py,sha256=bjj5TV5_g4sGyuSh8ThEDydlNMqhkGSY4SzXK6vwhqI,6927
4
+ ai_edge_torch/convert/conversion.py,sha256=StJHglvx6cii36oi8sj-tZda009e9UqR6ufZOZkP1SY,4137
5
+ ai_edge_torch/convert/conversion_utils.py,sha256=PKXIlSCU-8DhppNBh9ICDNUlEOpV0HgCbt85jDVe3rA,13394
6
+ ai_edge_torch/convert/converter.py,sha256=hSrW6A-kix9cjdD6CuLL7rseWrLKoV6GRy-iUSW_nZc,7875
7
7
  ai_edge_torch/convert/to_channel_last_io.py,sha256=zo5tY3yDhY_EPCkrL1XSXs2uRFS8B4_qu08dSjNsUGk,2778
8
8
  ai_edge_torch/convert/fx_passes/__init__.py,sha256=EPs4PSIDLuRH5EBETi6deaOvaaf_Q4xD3_9NVcR7x8o,2810
9
9
  ai_edge_torch/convert/fx_passes/_pass_base.py,sha256=ijVyDclPnd6a0DWWUJkwR4igj6f82S-cE1-83QGPvgw,1652
@@ -22,7 +22,7 @@ ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partition
22
22
  ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=8uHJbIwPMTgeSfYVba163pkXSQkHLxFwar_8A1AhgAM,2279
23
23
  ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=lklGxE1R32vsjFbhLLBDEFL4pfLi_iTgI9Ftb6Grezk,7156
24
24
  ai_edge_torch/convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
25
- ai_edge_torch/convert/test/test_convert.py,sha256=2qPmmGqnfV_o1gfsSdjGq3-JR1b323ligiy5MdAv9NA,8021
25
+ ai_edge_torch/convert/test/test_convert.py,sha256=h0vOffr8saDQRkiXljNWDZ17EBjnS4xAtxd8DxETleY,9081
26
26
  ai_edge_torch/convert/test/test_convert_composites.py,sha256=_Ojc-H6GOS5s8ek3_8eRBL_AiCs-k3srziPJ2R4Ulrg,7255
27
27
  ai_edge_torch/convert/test/test_convert_multisig.py,sha256=kMaGnHe9ylfyU68qCifYcaGwJqyejKz--QQt9jS2oUA,4537
28
28
  ai_edge_torch/convert/test/test_to_channel_last_io.py,sha256=I8c4ZG3v1vo0yxQYzLK_BTId4AOL9vadHGDtfCUZ4UI,2930
@@ -43,12 +43,12 @@ ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=uF1A2EX8xYie3
43
43
  ai_edge_torch/generative/examples/phi2/phi2.py,sha256=PMhKC6JCAMYSj2F3UmWHWK4rTcXD-B6PuehaoDccRqk,5562
44
44
  ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
45
45
  ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=Lo4Dq7a3Kg-lyH56iqGtqCo5UaClQHRCTDdNagXGTo8,3535
46
- ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=qU1wVEcn_biwCuDguZljhlLGzpLIqgqC31Dh_lXquQc,3720
47
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=wVEjsKd5JCIiYf5GF19rOXs2NHscZh0D69mxaS4f0Sk,4182
48
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=RgxedILk7iNMb0mhE4VkCs6d7BnFzYhR3vspUkC0-1o,11425
49
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=sRevfsmCun7zbceJbOstLKNUsLwzQDsGm7Mi2JmlREg,26021
46
+ ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=P-cUUQaQKGKV2p-7hvLJ--RpCIA7gk8WCDRgg0pNtd0,4331
47
+ ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=XwV1z7cVkQ947k_ERftEeL8n0NUFCJAltLtqDVfzYGI,4704
48
+ ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=xHcmOZaW7hoWlEEEqtB4FWoHMw5AsGHPHXMNiXEfviY,13814
49
+ ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=G-MgiEM_PpegNMePBPuNQDeUfjk42EYrVZAyJHC54AY,28468
50
50
  ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=mgbxkeFDMkNIGmnbcFTIFPu8EWKokghiviYIOB2lE3Q,3437
51
- ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=FCbnwlkpYYb-tF7KscbSYjNEdg7XnuLju1cDuIRoQv8,8277
51
+ ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=AopJ-KE74lzq4QJUP_hYeiXvGth7uWv7nNKqkhtcoF8,8277
52
52
  ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=r9RqbyNvuvXOGu3ojtl7ZmbC7o4Pt8aUKAhN1yCdtEc,3397
53
53
  ai_edge_torch/generative/examples/stable_diffusion/util.py,sha256=NFpOfA4KN0JpShm5QvuYbQYZ844NzexWD8nV3WjMOZM,2397
54
54
  ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py,sha256=uQWKzCD_49ackNFrt50H04dkDXxfAwUCtMWWQre5SVE,830
@@ -59,7 +59,7 @@ ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py,sha256=5i
59
59
  ai_edge_torch/generative/examples/t5/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
60
60
  ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=7RwaZQaKhFt3zKAUbFjq95CSYhL1nd9BVSbSRNJp4-4,4529
61
61
  ai_edge_torch/generative/examples/t5/t5.py,sha256=L6YrVzUEzP-Imb8W28LdukFGrx1aWSzz1kyYK_9RFZM,21087
62
- ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=rkMwi-NJGBXHm5S57Rsj1LbcoVdyRkS7GmIBuU6F_2E,8274
62
+ ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=KaGzCAViNOpJIQbRF-ItouuVPqI9nroWRRGN-KFYKZs,8357
63
63
  ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
64
64
  ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=Sf3ZMYv-iuMRKAKLow47qth8vTF1zl6i8TxJ9uT_StU,3885
65
65
  ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=zwCmCnhr-vhBwHqv9i7xMasdBGVNqAGxZvWsncsJn58,5543
@@ -68,21 +68,21 @@ ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=47DEQpj8HBSa-_TI
68
68
  ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=nT7Fh-f5ZdwaK3dPoCvZflpJ4fRHjLdFMjk1_uw3-b8,2559
69
69
  ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=IFRLPG9wz_aLl_zV_6CETCjSM03ukA6bZqqyDLVACuw,5651
70
70
  ai_edge_torch/generative/fx_passes/__init__.py,sha256=aXvYiaHDvETIrh0Q9DDZA_ZBiazGk80DT6nt7lLtC1o,1172
71
- ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=IehLwFNwa0C9fnk1pmNmyfuAwwWbuwdyKy46BSqNVdI,1948
71
+ ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=BCAcc_OcEjvbaXQSbc8vlKeMad7E3gCA4BNsUdWRwBI,1966
72
72
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
73
73
  ai_edge_torch/generative/layers/attention.py,sha256=AW0Qo3uOIe6p1rJNJ6zR_r4fqL2y-6QJHh0yUd-5Yb0,11966
74
74
  ai_edge_torch/generative/layers/attention_utils.py,sha256=hXhuyKblPPxKIRzlAf1YNlwHgpbj-6nReRLhRHELx5k,6350
75
75
  ai_edge_torch/generative/layers/builder.py,sha256=jAyrR5hsSI0aimKZumyvxdJ1GovERIfsK0g-dezX2gs,4163
76
76
  ai_edge_torch/generative/layers/feed_forward.py,sha256=4j2QaSCw59Jkk_ixKDpKEj7FLRauzuExTiSNRzAjAhE,2820
77
77
  ai_edge_torch/generative/layers/kv_cache.py,sha256=4uiZLO3om5G3--kT04Jt0esEYznbkJ7QLzSHfb8mjc4,3090
78
- ai_edge_torch/generative/layers/model_config.py,sha256=aQLtOPdGpehfnb4aGO-iILLAsRU5t7j6opyezPEUY_w,4673
78
+ ai_edge_torch/generative/layers/model_config.py,sha256=s6aIBib_LhjZC3p1pRxjcg3mf1BUrGqPQdsb6G83U-c,5028
79
79
  ai_edge_torch/generative/layers/normalization.py,sha256=M27eW3TcNK20oaXClXtfnu0lLWrAGrSKSsbegRWnj3c,1867
80
80
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=12SsCuoRuLNCwnFGe_pHDOZEBwBcqXs87Aj0PaWWw4E,1383
81
81
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=dYafGC205QE5CLIbBTCI-7eVvEGZEHzs1toPEhemeDs,3391
82
82
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
83
- ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=H45wsXA6iJi_Mjd66NiQrh7i1fx05r9o_FI-fSnhVts,26538
83
+ ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=T70veX57CC9uNidwzoVGzOu-CwzcYMBr1Zk_0bq5UlM,26538
84
84
  ai_edge_torch/generative/layers/unet/builder.py,sha256=NmJiZ2-e1wbv9jnvI3VCyUJlONV5ZAOz-RTc7ipAZ5U,1872
85
- ai_edge_torch/generative/layers/unet/model_config.py,sha256=FrIO-CR8aRIV2i8aFqom_4S7WCEDLMyYwo6U0oFyn7A,9097
85
+ ai_edge_torch/generative/layers/unet/model_config.py,sha256=GU12QEJwO6ukveMR9JRsrhE0YIPKuhk1U81CylmOQTA,9097
86
86
  ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
87
87
  ai_edge_torch/generative/quantize/example.py,sha256=Oy-Ss1oKXMu5RVOGt8QiUwKtrHEfhbVjTXXjxPcOqDA,1536
88
88
  ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
@@ -97,8 +97,8 @@ ai_edge_torch/generative/test/loader_test.py,sha256=N88CbrLW7Q2x1EyurwdXQ6YjsA-y
97
97
  ai_edge_torch/generative/test/test_model_conversion.py,sha256=LsPTrLC1I4JW2GowTS3V9Eu257vLHr2Yj5f_qaFUX84,7589
98
98
  ai_edge_torch/generative/test/test_quantize.py,sha256=TxZwe2cCTfwq9t2thBuYiLdp5Xu2cspCbQgziZ3Oo7k,5269
99
99
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
100
- ai_edge_torch/generative/utilities/loader.py,sha256=Hs92478j1g4jQGvbdP1aWvOy907HjwqQZE-NFy6HELo,11326
101
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=7ChqrnthD7I-Be6vkRvYTRhbGQ3tqMbikLpjY5HpSzE,30890
100
+ ai_edge_torch/generative/utilities/loader.py,sha256=NTaCrU2qmeJpqdAau13ZgyeOpwATqhZB68GY0LZjU6A,11438
101
+ ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=zixjZryUaCSDKmfPkQvYwbPJhUyTmZ4AK_lWN8iFo68,33324
102
102
  ai_edge_torch/generative/utilities/t5_loader.py,sha256=h1FQzt4x8wiQMX4NzYNVIaJGLr_YKH0sojBvy0amexM,16503
103
103
  ai_edge_torch/hlfb/__init__.py,sha256=rrje8a2iuKboBoV96bVq7nlS9HsnuEMbHE5JiWmCxFA,752
104
104
  ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=2VXnHcGf23VOuP-1GriGIpuL98leBB8twp_qaScMnmc,4799
@@ -113,9 +113,9 @@ ai_edge_torch/quantize/pt2e_quantizer_utils.py,sha256=yjzKoptnfEeW_sN7sODUfj3nCt
113
113
  ai_edge_torch/quantize/quant_config.py,sha256=eO9Ra160ITjQSyRBEGy6nNIVH3gYacSWDdN5XtvHwjc,3148
114
114
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
115
115
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
116
- ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=EIyKz-HY70DguWuSrJal8LpYXQ5ZSEUf3ZrVl7jikFM,4286
117
- ai_edge_torch_nightly-0.2.0.dev20240710.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
118
- ai_edge_torch_nightly-0.2.0.dev20240710.dist-info/METADATA,sha256=6ask_HCsla1Tzx5_ORpPGrdvtwYAwS6BB3jNV31Jo9g,1745
119
- ai_edge_torch_nightly-0.2.0.dev20240710.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
120
- ai_edge_torch_nightly-0.2.0.dev20240710.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
121
- ai_edge_torch_nightly-0.2.0.dev20240710.dist-info/RECORD,,
116
+ ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=kzIulTldq8R9E-lAZsvfSTvLu3FYEX7b9DyYM3qisXM,4485
117
+ ai_edge_torch_nightly-0.2.0.dev20240712.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
118
+ ai_edge_torch_nightly-0.2.0.dev20240712.dist-info/METADATA,sha256=BGHmRLYo3ko7KRysDP59YexTpPn45jrtpRwqQkPAM5s,1745
119
+ ai_edge_torch_nightly-0.2.0.dev20240712.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
120
+ ai_edge_torch_nightly-0.2.0.dev20240712.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
121
+ ai_edge_torch_nightly-0.2.0.dev20240712.dist-info/RECORD,,