ai-edge-torch-nightly 0.2.0.dev20240711-py3-none-any.whl → 0.2.0.dev20240712-py3-none-any.whl

This diff shows the content of publicly available package versions as released to their respective public registries. It is provided for informational purposes only and reflects the changes between the two versions.


ai_edge_torch/convert/conversion.py CHANGED
@@ -88,16 +88,14 @@ def convert_signatures(
   _warn_training_modules(signatures)
 
   exported_programs: torch.export.ExportedProgram = [
-      torch.export.export(
-          sig.module, sig.sample_args, dynamic_shapes=sig.dynamic_shapes
-      )
+      torch.export.export(sig.module, sig.flat_args, dynamic_shapes=sig.dynamic_shapes)
       for sig in signatures
   ]
 
   # Apply default fx passes
   exported_programs = list(map(_run_convert_passes, exported_programs))
   shlo_bundles: list[stablehlo.StableHLOModelBundle] = [
-      cutils.exported_program_to_stablehlo_bundle(exported, sig.sample_args)
+      cutils.exported_program_to_stablehlo_bundle(exported, sig.flat_args)
       for exported, sig in zip(exported_programs, signatures)
   ]
 
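The hunk above switches `torch.export.export` from `sig.sample_args` to `sig.flat_args`, i.e. the signature's positional and keyword sample inputs are flattened into a single positional tuple before export. A minimal standalone sketch of that flattening (illustrative tensors, not code from the package):

```python
import torch
import torch.utils._pytree as pytree

sample_args = (torch.randn(2, 3),)
sample_kwargs = {"y": torch.randn(2, 3)}

# tree_flatten returns (leaves, treespec); flattening (args, kwargs) yields the
# positional tensors first, followed by the keyword-argument tensors.
flat_args, spec = pytree.tree_flatten((sample_args, sample_kwargs))
print(len(flat_args))  # 2 tensors, ready to be passed positionally at export time
```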
ai_edge_torch/convert/conversion_utils.py CHANGED
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ==============================================================================
 
+import collections
 import copy
 from dataclasses import dataclass
 import gc
@@ -22,6 +23,7 @@ import tempfile
 from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
+import torch.utils._pytree as pytree
 from torch_xla import stablehlo
 
 from ai_edge_torch.generative.quantize.ai_edge_quantizer_glue import translate_recipe # NOQA
@@ -47,8 +49,59 @@ class Signature:
   name: str
   module: torch.nn.Module
   sample_args: tuple[torch.Tensor]
+  sample_kwargs: dict[str, torch.Tensor]
   dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None
 
+  @property
+  def _normalized_sample_args_kwargs(self):
+    args, kwargs = self.sample_args, self.sample_kwargs
+    if args is not None:
+      if not isinstance(args, tuple):
+        # TODO(b/352584188): Check value types
+        raise ValueError("sample_args must be a tuple of torch tensors.")
+    if kwargs is not None:
+      if not isinstance(kwargs, dict) or not all(
+          isinstance(key, str) for key in kwargs.keys()
+      ):
+        # TODO(b/352584188): Check value types
+        raise ValueError("sample_kwargs must be a dict of string to tensor.")
+
+    args = args if args is not None else tuple()
+    kwargs = kwargs if kwargs is not None else {}
+    return args, kwargs
+
+  @property
+  def flat_arg_names(self) -> list[str]:
+    spec = pytree.tree_flatten(self._normalized_sample_args_kwargs)[1]
+    args_spec, kwargs_spec = spec.children_specs
+
+    names = []
+    for i in range(args_spec.num_leaves):
+      names.append(f"args_{i}")
+
+    dict_context = (
+        kwargs_spec.context
+        if kwargs_spec.type is not collections.defaultdict
+        # ignore mismatch of `default_factory` for defaultdict
+        else kwargs_spec.context[1]
+    )
+
+    for name, value_spec in zip(dict_context, kwargs_spec.children_specs):
+      if value_spec.num_leaves == 1:
+        names.append(name)
+      else:
+        # value_spec.num_leaves may be greater than 1 when the value is a (nested)
+        # tuple of tensors. We haven't decided how we should support flattenable
+        # tensor containers as inputs.
+        # TODO(b/352584188): Decide the behavior of tensor container as input (flatten or reject)
+        for i in range(value_spec.num_leaves):
+          names.append(f"{name}_{i}")
+    return names
+
+  @property
+  def flat_args(self) -> tuple[torch.Tensor]:
+    return tuple(pytree.tree_flatten(self._normalized_sample_args_kwargs)[0])
+
 
 def exported_program_to_stablehlo_bundle(
     exported_program: torch.export.ExportedProgram, sample_args: tuple[torch.Tensor]
@@ -189,7 +242,9 @@ def _make_tf_function(
 
 def _make_tf_signature(
     meta: stablehlo.StableHLOFunctionMeta,
+    signature: Signature,
 ) -> list[tf.TensorSpec]:
+  input_names = signature.flat_arg_names
   input_pos_to_spec = {
       loc.position: spec
       for loc, spec in itertools.chain(
@@ -197,9 +252,11 @@ def _make_tf_signature(
       )
       if loc.type_ == stablehlo.VariableType.INPUT_ARG
   }
+  assert len(input_pos_to_spec) == len(input_names)
+
   primitive_type_to_tf_type = {"int": "int32", "float": "float32"}
   ret: list[tf.TensorSpec] = []
-  for i in range(len(input_pos_to_spec)):
+  for i, name in enumerate(input_names):
     spec = input_pos_to_spec[i]
     shape = _get_shape_with_dynamic(spec)
     ret.append(
@@ -208,7 +265,7 @@
             dtype=primitive_type_to_tf_type[spec.dtype]
             if spec.dtype in primitive_type_to_tf_type
             else spec.dtype,
-            name=f"args_{i}",
+            name=name,
         )
     )
   return ret
@@ -276,7 +333,8 @@ def convert_stablehlo_to_tflite(
       tf.Variable(v, trainable=False) for v in bundle.additional_constants
   ]
   tf_signatures: list[list[tf.TensorSpec]] = list(
-      _make_tf_signature(func.meta) for func in bundle.stablehlo_funcs
+      _make_tf_signature(func.meta, sig)
+      for func, sig in zip(bundle.stablehlo_funcs, signatures)
   )
 
   tf_functions = _make_tf_function(shlo_graph_module, bundle)
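
For context, the `flat_arg_names` property added above names positional inputs `args_{i}` and keyword inputs after their keywords (falling back to `{name}_{i}` when a keyword value flattens to several tensors). A simplified standalone sketch of the resulting signature input names, ignoring the defaultdict and nested-container cases handled in the hunk (example tensors are illustrative):

```python
import torch
import torch.utils._pytree as pytree

sample_args = (torch.randn(4),)           # becomes "args_0"
sample_kwargs = {"mask": torch.randn(4)}  # becomes "mask"

leaves, spec = pytree.tree_flatten((sample_args, sample_kwargs))
args_spec, kwargs_spec = spec.children_specs

names = [f"args_{i}" for i in range(args_spec.num_leaves)]
names += list(kwargs_spec.context)  # dict keys; one leaf per value in this example
print(names)  # ['args_0', 'mask']
```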
ai_edge_torch/convert/converter.py CHANGED
@@ -34,17 +34,23 @@ class Converter:
       self,
       name: str,
       module: torch.nn.Module,
-      sample_args: tuple[cutils.TracingArg],
+      sample_args=None,
+      sample_kwargs=None,
+      *,
       dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
   ) -> Converter:
     """Alias to `add_signature`"""
-    return self.add_signature(name, module, sample_args, dynamic_shapes)
+    return self.add_signature(
+        name, module, sample_args, sample_kwargs, dynamic_shapes=dynamic_shapes
+    )
 
   def add_signature(
       self,
       name: str,
       module: torch.nn.Module,
-      sample_args: tuple[cutils.TracingArg],
+      sample_args=None,
+      sample_kwargs=None,
+      *,
       dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
   ) -> Converter:
     """Allows adding a new named torch model along with sample args to the conversion.
@@ -52,7 +58,8 @@ class Converter:
     Args:
       name: The name of the signature included in the converted edge model.
       module: The torch module to be converted.
-      sample_args: Tuple of args by which the torch module will be traced prior to conversion.
+      sample_args: Tuple of tensors by which the torch module will be traced with prior to conversion.
+      sample_kwargs: Dict of str to tensor by which the torch module will be traced with prior to conversion.
       dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
         See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.
 
@@ -63,13 +70,21 @@ class Converter:
     if name in [sig.name for sig in self._signatures]:
       raise ValueError(f"A signature with the provided name ({name}) is already added.")
 
-    self._signatures.append(cutils.Signature(name, module, sample_args, dynamic_shapes))
+    if sample_args is None and sample_kwargs is None:
+      raise ValueError("sample_args or sample_kwargs must be provided.")
+
+    self._signatures.append(
+        cutils.Signature(
+            name, module, sample_args, sample_kwargs, dynamic_shapes=dynamic_shapes
+        )
+    )
     return self
 
   def convert(
       self,
       module: torch.nn.Module = None,
-      sample_args: tuple[cutils.TracingArg] = None,
+      sample_args=None,
+      sample_kwargs=None,
       *,
       quant_config: Optional[qcfg.QuantConfig] = None,
       dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
@@ -88,7 +103,8 @@ class Converter:
     Args:
       name: The name of the signature included in the converted edge model.
       module: The torch module to be converted.
-      sample_args: Tuple of args by which the torch module will be traced prior to conversion.
+      sample_args: Tuple of tensors by which the torch module will be traced with prior to conversion.
+      sample_kwargs: Dict of str to tensor by which the torch module will be traced with prior to conversion.
       quant_config: User-defined quantization method and scheme of the model.
       dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
         See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.
@@ -100,12 +116,20 @@ class Converter:
      ValueError: If the arguments are not provided as expected. See the example in this functions's comment.
    """
    if module is not None:
-      if sample_args is not None:  # both module and args provided
+      if (
+          sample_args is not None or sample_kwargs is not None
+      ):  # both module and args provided
        self.add_signature(
-            cutils.DEFAULT_SIGNATURE_NAME, module, sample_args, dynamic_shapes
+            cutils.DEFAULT_SIGNATURE_NAME,
+            module,
+            sample_args,
+            sample_kwargs,
+            dynamic_shapes=dynamic_shapes,
+        )
+      else:  # module is provided but not args
+        raise ValueError(
+            "sample_args or sample_kwargs must be provided if a module is specified."
        )
-      else:  # module is provided but not sample_args
-        raise ValueError("sample_args needs to be provided if a module is specified.")
 
    return conversion.convert_signatures(
        self._signatures,
@@ -117,7 +141,8 @@ class Converter:
 def signature(
     name: str,
     module: torch.nn.Module,
-    sample_args: tuple[cutils.TracingArg],
+    sample_args=None,
+    sample_kwargs=None,
     dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
 ) -> Converter:
   """Initiates a Converter object with the provided signature.
@@ -125,7 +150,8 @@ def signature(
   Args:
     name: The name of the signature included in the converted edge model.
     module: The torch module to be converted.
-    sample_args: Tuple of args by which the torch module will be traced prior to conversion.
+    sample_args: Tuple of tensors by which the torch module will be traced with prior to conversion.
+    sample_kwargs: Dict of str to tensor by which the torch module will be traced with prior to conversion.
     dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
       See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.
 
@@ -134,12 +160,15 @@ def signature(
     edge_model = converter.convert()
 
   """
-  return Converter().signature(name, module, sample_args, dynamic_shapes)
+  return Converter().signature(
+      name, module, sample_args, sample_kwargs, dynamic_shapes=dynamic_shapes
+  )
 
 
 def convert(
     module: torch.nn.Module = None,
-    sample_args: tuple[cutils.TracingArg] = None,
+    sample_args=None,
+    sample_kwargs=None,
     *,
     quant_config: Optional[qcfg.QuantConfig] = None,
     dynamic_shapes: Optional[Union[Dict[str, Any], Tuple[Any]]] = None,
@@ -149,7 +178,8 @@ def convert(
 
   Args:
     module: The torch module to be converted.
-    sample_args: Tuple of args by which the torch module will be traced prior to conversion.
+    sample_args: Tuple of tensors by which the torch module will be traced with prior to conversion.
+    sample_kwargs: Dict of str to tensor by which the torch module will be traced with prior to conversion.
    quant_config: User-defined quantization method and scheme of the model.
    dynamic_shapes: Optional dict or tuple that specify dynamic shape specifications for each input in original order.
      See https://pytorch.org/docs/stable/export.html#expressing-dynamism for more details.
@@ -165,6 +195,7 @@ def convert(
   return Converter().convert(
       module,
       sample_args,
+      sample_kwargs,
       quant_config=quant_config,
       dynamic_shapes=dynamic_shapes,
       _ai_edge_converter_flags=_ai_edge_converter_flags,
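
Taken together, the converter hunks let a model be traced from keyword inputs as well as positional ones. A hedged usage sketch based on the updated signatures (the toy module, shapes, and output path are illustrative only, not part of the package):

```python
import torch
import ai_edge_torch


class AddModel(torch.nn.Module):
  """Toy two-input module used only for illustration."""

  def forward(self, x, y):
    return x + y


model = AddModel().eval()
x, y = torch.randn(10, 10), torch.randn(10, 10)

# At least one of sample_args / sample_kwargs must now be provided;
# they may also be mixed, as in the new tests further below.
edge_model = ai_edge_torch.convert(model, (x,), sample_kwargs={"y": y})
edge_model.export("/tmp/add_model.tflite")
```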
ai_edge_torch/convert/test/test_convert.py CHANGED
@@ -267,6 +267,45 @@ class TestConvert(unittest.TestCase):
         model_coverage.compare_tflite_torch(edge_model, model, validate_input)
     )
 
+  def test_convert_model_with_kwargs(self):
+    """
+    Test converting a simple model with sample_kwargs.
+    """
+
+    class SampleModel(torch.nn.Module):
+
+      def forward(self, x, y):
+        return x + y
+
+    kwargs_gen = lambda: dict(x=torch.randn(10, 10), y=torch.randn(10, 10))
+
+    model = SampleModel().eval()
+    edge_model = ai_edge_torch.convert(model, sample_kwargs=kwargs_gen())
+
+    self.assertTrue(
+        model_coverage.compare_tflite_torch(edge_model, model, kwargs=kwargs_gen)
+    )
+
+  def test_convert_model_with_args_kwargs(self):
+    """
+    Test converting a simple model with both sample_args and sample_kwargs.
+    """
+
+    class SampleModel(torch.nn.Module):
+
+      def forward(self, x, y):
+        return x + y
+
+    args_gen = lambda: (torch.randn(10, 10),)
+    kwargs_gen = lambda: dict(y=torch.randn(10, 10))
+
+    model = SampleModel().eval()
+    edge_model = ai_edge_torch.convert(model, args_gen(), kwargs_gen())
+
+    self.assertTrue(
+        model_coverage.compare_tflite_torch(edge_model, model, args_gen, kwargs_gen)
+    )
+
 
 if __name__ == "__main__":
   unittest.main()
ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py CHANGED
@@ -40,7 +40,7 @@ class RemoveSDPACompositeZeroMaskPass(ExportedProgramPassBase):
        if self.is_zero_tensor_node(source):
          # Remove the mark_tensor call on the mask input by
          # replacing the target with an identity function.
-          node.target = lambda *args, **kwargs: args[0]
+          node.target = lambda *args, **kwargs: torch.zeros_like(args[0])
 
     exported_program.graph_module.graph.lint()
     exported_program.graph_module.recompile()
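
The one-line change in this hunk swaps the pass-through lambda for one that rebuilds the mask with `torch.zeros_like`. Purely to illustrate the difference between the two replacement targets (plain tensors, outside any fx graph):

```python
import torch

mask = torch.zeros(2, 2)

identity = lambda *args, **kwargs: args[0]                  # old target: forwards the marked tensor
zeros = lambda *args, **kwargs: torch.zeros_like(args[0])   # new target: materializes a fresh zero tensor

print(identity(mask) is mask)  # True  -- the same tensor object flows onward
print(zeros(mask) is mask)     # False -- a new all-zeros tensor with the same shape/dtype
```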
ai_edge_torch/model.py CHANGED
@@ -33,7 +33,10 @@ class Model(abc.ABC):
 
   @abc.abstractmethod
   def __call__(
-      self, *args: npt.ArrayLike, signature_name: str = cutils.DEFAULT_SIGNATURE_NAME
+      self,
+      *args: npt.ArrayLike,
+      signature_name: str = cutils.DEFAULT_SIGNATURE_NAME,
+      **kwargs,
   ) -> npt.ArrayLike | tuple[npt.ArrayLike]:
     raise NotImplementedError()
 
@@ -62,12 +65,16 @@ class TfLiteModel(Model):
     self._tflite_model = tflite_model
 
   def __call__(
-      self, *args: npt.ArrayLike, signature_name: str = cutils.DEFAULT_SIGNATURE_NAME
+      self,
+      *args: npt.ArrayLike,
+      signature_name: str = cutils.DEFAULT_SIGNATURE_NAME,
+      **kwargs,
   ) -> npt.ArrayLike | tuple[npt.ArrayLike]:
     """Runs inference on the edge model using the provided arguments.
 
     Args:
       *args: The arguments to be passed to the model for inference.
+      **kwargs: The arguments with specific names to be passed to the model for inference.
       signature_name: The name of the signature to be used for inference.
         The default signature is used if not provided.
     """
@@ -90,13 +97,14 @@ class TfLiteModel(Model):
       else:
         raise exception
 
-    if len(signature_list[signature_name]['inputs']) != len(args):
+    if len(signature_list[signature_name]['inputs']) != len(args) + len(kwargs):
       raise ValueError(
           f"The model requires {len(signature_list[signature_name]['inputs'])} arguments but {len(args)} was provided."
       )
 
     # Gather the input dictionary based on the signature.
     inputs = {f'args_{idx}': args[idx] for idx in range(len(args))}
+    inputs = {**inputs, **kwargs}
     outputs = runner(**inputs)
 
     return (
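
With the `__call__` changes above, a converted model accepts keyword inputs: positional values are still mapped to `args_{i}` signature inputs, while keyword values are forwarded to the TFLite signature runner by name. A hedged call sketch (the toy module and shapes are illustrative, mirroring the converter sketch earlier):

```python
import numpy as np
import torch
import ai_edge_torch


class AddModel(torch.nn.Module):  # toy module for illustration only
  def forward(self, x, y):
    return x + y


edge_model = ai_edge_torch.convert(
    AddModel().eval(), (torch.randn(10, 10),), sample_kwargs={"y": torch.randn(10, 10)}
)

x_np = np.random.randn(10, 10).astype(np.float32)
y_np = np.random.randn(10, 10).astype(np.float32)

# x_np feeds the "args_0" signature input; y_np is matched by name to "y".
out = edge_model(x_np, y=y_np)
```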
ai_edge_torch/testing/model_coverage/model_coverage.py CHANGED
@@ -60,7 +60,8 @@ def _torch_tensors_to_np(*argv):
 def compare_tflite_torch(
     edge_model: Model,
     torch_eval_func: Callable,
-    input_data=None,
+    args=None,
+    kwargs=None,
     *,
     num_valid_inputs: int = 1,
     signature_name: str = None,
@@ -71,8 +72,9 @@
   Args:
     edge_model: Serialized ai_edge_torch.model.Model object.
     torch_eval_func: Callable function to evaluate torch model.
-    input_data: torch.tensor array or a callable to generate a torch.tensor array
+    args: torch.tensor array or a callable to generate a torch.tensor array
       with random data, to pass into models during inference. (default None).
+    kwargs: dict of str to torch.tensor, or a callable to generate such.
     num_valid_inputs: Defines the number of times the random inputs will be generated (if a callable is provided for input_data).
     signature_name: If provided, specifies the name for the signature of the edge_model to run.
       Calls the default signature if not provided.
@@ -86,29 +88,33 @@
   # The supplied model_def.forward_args() will be executed num_valid_inputs
   # times to generate num_valid_inputs random inputs.
   torch_inputs = [
-      input_data() if callable(input_data) else input_data
+      (
+          (args() if callable(args) else args) or tuple(),
+          (kwargs() if callable(kwargs) else kwargs) or {},
+      )
       for _ in range(num_valid_inputs)
   ]
-  torch_outputs = [torch_eval_func(*xs) for xs in torch_inputs]
-  np_inputs = [_torch_tensors_to_np(xs) for xs in torch_inputs]
+  torch_outputs = [torch_eval_func(*args, **kwargs) for args, kwargs in torch_inputs]
+  np_inputs = [
+      (_torch_tensors_to_np(args), _torch_tensors_to_np(kwargs))
+      for args, kwargs in torch_inputs
+  ]
   np_outputs = [_torch_tensors_to_np(_flatten(ys)) for ys in torch_outputs]
 
   # Define inline utility function used throughout the function.
   def equal_fn(actual, expected):
     return np.allclose(actual, expected, atol=atol, rtol=rtol)
 
-  def get_actual_fn(input):
+  def get_edge_output(inputs):
+    args, kwargs = inputs
     if signature_name is None:
-      return _flatten(edge_model(*input))
+      return _flatten(edge_model(*args, **kwargs))
     else:
-      return _flatten(edge_model(*input, signature_name=signature_name))
-
-  def get_expected_fn(input=None, idx=0):
-    return np_outputs[idx]
+      return _flatten(edge_model(*args, **kwargs, signature_name=signature_name))
 
   for idx, np_input in enumerate(np_inputs):
-    output = get_actual_fn(np_input)
-    golden_output = get_expected_fn(np_input, idx)
+    output = get_edge_output(np_input)
+    golden_output = np_outputs[idx]
 
     is_output_len_eq = len(golden_output) == len(output)
 
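The updated `compare_tflite_torch` takes positional and keyword inputs, either as values or as zero-argument callables that regenerate random inputs for each of `num_valid_inputs` runs. A hedged sketch mirroring the new tests (toy module for illustration only):

```python
import torch
import ai_edge_torch
from ai_edge_torch.testing import model_coverage


class AddModel(torch.nn.Module):  # toy module for illustration only
  def forward(self, x, y):
    return x + y


model = AddModel().eval()
args_gen = lambda: (torch.randn(10, 10),)
kwargs_gen = lambda: {"y": torch.randn(10, 10)}

edge_model = ai_edge_torch.convert(model, args_gen(), kwargs_gen())

# Each callable is re-invoked per validation run to produce fresh random inputs.
ok = model_coverage.compare_tflite_torch(
    edge_model, model, args_gen, kwargs_gen, num_valid_inputs=3
)
assert ok
```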
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.2.0.dev20240711
+Version: 0.2.0.dev20240712
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -1,9 +1,9 @@
 ai_edge_torch/__init__.py,sha256=CNDboRP4zQBpz2hznNCQWcQCARvNXUm3DMa1Dw_XXFg,1067
-ai_edge_torch/model.py,sha256=kmcgELjsYl8YzF8nUF6P7q4i8MWS-pLGpfsy-yTUXmE,4243
+ai_edge_torch/model.py,sha256=8Ba9ia7TCM_fciulw6qObmzdcxL3IaLQKDqpR7Lxp-Q,4440
 ai_edge_torch/convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/convert/conversion.py,sha256=8K8jQuaCjlUWoj7jiimxp_zpN6mYThLOcQ858UDcYnE,4159
-ai_edge_torch/convert/conversion_utils.py,sha256=9BqCL38DErv1vEVGtT3BIJVhdwZjw2EQ-_m5UpvVVYE,11294
-ai_edge_torch/convert/converter.py,sha256=bjj5TV5_g4sGyuSh8ThEDydlNMqhkGSY4SzXK6vwhqI,6927
+ai_edge_torch/convert/conversion.py,sha256=StJHglvx6cii36oi8sj-tZda009e9UqR6ufZOZkP1SY,4137
+ai_edge_torch/convert/conversion_utils.py,sha256=PKXIlSCU-8DhppNBh9ICDNUlEOpV0HgCbt85jDVe3rA,13394
+ai_edge_torch/convert/converter.py,sha256=hSrW6A-kix9cjdD6CuLL7rseWrLKoV6GRy-iUSW_nZc,7875
 ai_edge_torch/convert/to_channel_last_io.py,sha256=zo5tY3yDhY_EPCkrL1XSXs2uRFS8B4_qu08dSjNsUGk,2778
 ai_edge_torch/convert/fx_passes/__init__.py,sha256=EPs4PSIDLuRH5EBETi6deaOvaaf_Q4xD3_9NVcR7x8o,2810
 ai_edge_torch/convert/fx_passes/_pass_base.py,sha256=ijVyDclPnd6a0DWWUJkwR4igj6f82S-cE1-83QGPvgw,1652
@@ -22,7 +22,7 @@ ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partition
 ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=8uHJbIwPMTgeSfYVba163pkXSQkHLxFwar_8A1AhgAM,2279
 ai_edge_torch/convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=lklGxE1R32vsjFbhLLBDEFL4pfLi_iTgI9Ftb6Grezk,7156
 ai_edge_torch/convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/convert/test/test_convert.py,sha256=2qPmmGqnfV_o1gfsSdjGq3-JR1b323ligiy5MdAv9NA,8021
+ai_edge_torch/convert/test/test_convert.py,sha256=h0vOffr8saDQRkiXljNWDZ17EBjnS4xAtxd8DxETleY,9081
 ai_edge_torch/convert/test/test_convert_composites.py,sha256=_Ojc-H6GOS5s8ek3_8eRBL_AiCs-k3srziPJ2R4Ulrg,7255
 ai_edge_torch/convert/test/test_convert_multisig.py,sha256=kMaGnHe9ylfyU68qCifYcaGwJqyejKz--QQt9jS2oUA,4537
 ai_edge_torch/convert/test/test_to_channel_last_io.py,sha256=I8c4ZG3v1vo0yxQYzLK_BTId4AOL9vadHGDtfCUZ4UI,2930
@@ -68,7 +68,7 @@ ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=47DEQpj8HBSa-_TI
 ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=nT7Fh-f5ZdwaK3dPoCvZflpJ4fRHjLdFMjk1_uw3-b8,2559
 ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=IFRLPG9wz_aLl_zV_6CETCjSM03ukA6bZqqyDLVACuw,5651
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=aXvYiaHDvETIrh0Q9DDZA_ZBiazGk80DT6nt7lLtC1o,1172
-ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=IehLwFNwa0C9fnk1pmNmyfuAwwWbuwdyKy46BSqNVdI,1948
+ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=BCAcc_OcEjvbaXQSbc8vlKeMad7E3gCA4BNsUdWRwBI,1966
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/attention.py,sha256=AW0Qo3uOIe6p1rJNJ6zR_r4fqL2y-6QJHh0yUd-5Yb0,11966
 ai_edge_torch/generative/layers/attention_utils.py,sha256=hXhuyKblPPxKIRzlAf1YNlwHgpbj-6nReRLhRHELx5k,6350
@@ -113,9 +113,9 @@ ai_edge_torch/quantize/pt2e_quantizer_utils.py,sha256=yjzKoptnfEeW_sN7sODUfj3nCt
 ai_edge_torch/quantize/quant_config.py,sha256=eO9Ra160ITjQSyRBEGy6nNIVH3gYacSWDdN5XtvHwjc,3148
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
-ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=EIyKz-HY70DguWuSrJal8LpYXQ5ZSEUf3ZrVl7jikFM,4286
-ai_edge_torch_nightly-0.2.0.dev20240711.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.2.0.dev20240711.dist-info/METADATA,sha256=GftPz7zSGYCaTvO4gntWftMbj0NCSh4OXJEe1epdBCU,1745
-ai_edge_torch_nightly-0.2.0.dev20240711.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
-ai_edge_torch_nightly-0.2.0.dev20240711.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.2.0.dev20240711.dist-info/RECORD,,
+ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=kzIulTldq8R9E-lAZsvfSTvLu3FYEX7b9DyYM3qisXM,4485
+ai_edge_torch_nightly-0.2.0.dev20240712.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.2.0.dev20240712.dist-info/METADATA,sha256=BGHmRLYo3ko7KRysDP59YexTpPn45jrtpRwqQkPAM5s,1745
+ai_edge_torch_nightly-0.2.0.dev20240712.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
+ai_edge_torch_nightly-0.2.0.dev20240712.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.2.0.dev20240712.dist-info/RECORD,,