ai-edge-torch-nightly 0.3.0.dev20240909__py3-none-any.whl → 0.3.0.dev20240910__py3-none-any.whl

ai_edge_torch/_convert/test/test_convert.py CHANGED
@@ -25,6 +25,7 @@ from ai_edge_torch.testing import model_coverage
  import numpy as np
  import tensorflow as tf
  import torch
+ from torch import nn
  import torchvision

  from absl.testing import absltest as googletest
@@ -51,7 +52,7 @@ class TestConvert(googletest.TestCase):
    def test_convert_add(self):
      """Tests conversion of a simple Add module."""

-     class Add(torch.nn.Module):
+     class Add(nn.Module):

        def forward(self, a, b):
          return a + b
@@ -70,7 +71,7 @@ class TestConvert(googletest.TestCase):
    def test_convert_dot_add(self):
      """Tests conversion of a matrix multiplication followed by an add."""

-     class DotAdd(torch.nn.Module):
+     class DotAdd(nn.Module):

        def forward(self, a, b, c):
          return a @ b + c
@@ -99,7 +100,7 @@ class TestConvert(googletest.TestCase):
    def test_signature_args_ordering(self):
      """Tests conversion of a model with more than 10 arguments."""

-     class AddChainWith11Args(torch.nn.Module):
+     class AddChainWith11Args(nn.Module):
        """A model with 11 arguments."""

        def forward(
@@ -152,7 +153,7 @@ class TestConvert(googletest.TestCase):
    def test_multi_output_model(self):
      """Tests conversion of a model that returns multiple outputs."""

-     class BasicAddModelWithMultipleOutputs(torch.nn.Module):
+     class BasicAddModelWithMultipleOutputs(nn.Module):
        """A model that returns multiple outputs."""

        def forward(self, arg0, arg1):
@@ -176,7 +177,7 @@ class TestConvert(googletest.TestCase):
    def test_12_outputs_model(self):
      """Tests conversion of a model that returns more than 10 outputs."""

-     class BasicAddModelWithMultipleOutputs(torch.nn.Module):
+     class BasicAddModelWithMultipleOutputs(nn.Module):
        """A model that returns multiple outputs."""

        def forward(self, arg0, arg1):
@@ -245,7 +246,7 @@ class TestConvert(googletest.TestCase):
    def test_convert_add_converter_flags(self):
      """Tests conversion of an add module setting a tflite converter flag."""

-     class Add(torch.nn.Module):
+     class Add(nn.Module):

        def forward(self, a, b):
          return a + b
@@ -267,6 +268,27 @@ class TestConvert(googletest.TestCase):
      )
      self.assertTrue(os.path.isdir(ir_dump_path))

+   def test_convert_conv_transpose_batch_norm(self):
+     """Tests conversion of a model with ConvTranspose2d and BatchNorm2d."""
+
+     channels = 2
+     size = 2
+     torch_model = nn.Sequential(
+         nn.ConvTranspose2d(
+             channels, channels, 1, stride=2, dilation=1, bias=False
+         ),
+         nn.BatchNorm2d(channels),
+     )
+
+     torch_model.eval()
+     sample_input = (torch.rand(1, channels, size, size),)
+     edge_model = ai_edge_torch.convert(torch_model, sample_input)
+
+     result = model_coverage.compare_tflite_torch(
+         edge_model, torch_model, sample_input
+     )
+     self.assertTrue(result)
+
    @googletest.skipIf(
        not config.Config.use_torch_xla,
        reason="Shape polymorphism is not yet support with odml_torch.",
@@ -274,7 +296,7 @@ class TestConvert(googletest.TestCase):
    def test_convert_model_with_dynamic_batch(self):
      """Test converting a simple model with dynamic batch size."""

-     class SampleModel(torch.nn.Module):
+     class SampleModel(nn.Module):

        def __init__(self):
          super().__init__()
@@ -304,7 +326,7 @@ class TestConvert(googletest.TestCase):
    def test_convert_model_with_kwargs(self):
      """Test converting a simple model with sample_kwargs."""

-     class SampleModel(torch.nn.Module):
+     class SampleModel(nn.Module):

        def forward(self, x, y):
          return x + y
@@ -323,7 +345,7 @@ class TestConvert(googletest.TestCase):
    def test_convert_model_with_args_kwargs(self):
      """Test converting a simple model with both sample_args and sample_kwargs."""

-     class SampleModel(torch.nn.Module):
+     class SampleModel(nn.Module):

        def forward(self, x, y):
          return x + y
@@ -343,7 +365,7 @@ class TestConvert(googletest.TestCase):
    def test_convert_model_with_args_nested_kwargs_1(self):
      """Test converting a simple model with both sample_args and nested sample_kwargs."""

-     class SampleModel(torch.nn.Module):
+     class SampleModel(nn.Module):

        def forward(self, x: torch.Tensor, y: torch.Tensor, z: TestContainer1):
          return x + y + z.data_1 + z.data_2[0] + z.data_2[1]
@@ -370,7 +392,7 @@ class TestConvert(googletest.TestCase):
    def test_convert_model_with_args_nested_kwargs_2(self):
      """Test converting a simple model with both sample_args and nested sample_kwargs."""

-     class SampleModel(torch.nn.Module):
+     class SampleModel(nn.Module):

        def forward(self, x, y, z):
          return x + y + z.data_1 + z.data_2[0][0] + z.data_2[1]
@@ -397,7 +419,7 @@ class TestConvert(googletest.TestCase):
    def test_convert_model_with_args_nested_kwargs_3(self):
      """Test converting a simple model with both sample_args and nested sample_kwargs."""

-     class SampleModel(torch.nn.Module):
+     class SampleModel(nn.Module):

        def forward(self, x, y, z):
          return x + y + z.data_1 + z.data_2[0]["foo"] + z.data_2[1]
@@ -424,7 +446,7 @@ class TestConvert(googletest.TestCase):
    def test_convert_model_non_flat_output_dict(self):
      """Test converting a model with non-flat output structure."""

-     class SampleModel(torch.nn.Module):
+     class SampleModel(nn.Module):

        def forward(self, x, y, z):
          return {"x": x, "y": TestContainer1(data_1=y, data_2=[y, z])}
ai_edge_torch/odml_torch/lowerings/_convolution.py CHANGED
@@ -12,22 +12,171 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- """Provides lowering for coreaten to mlir stablehlo op: Convolution"""
+ """Provides lowering for coreaten to stablehlo for Convolution."""

  import math
  from typing import Optional

+ from ai_edge_torch.odml_torch.lowerings import registry
  from jax._src.lib.mlir import ir
  from jax._src.lib.mlir.dialects import hlo as stablehlo
  import torch

- from .registry import lower
+
+ def make_padding(padding):
+   """Change the padding from pytorch to stablehlo style.
+
+   Stablehlo allows start and end padding for each dimension while aten only
+   allows symmetric padding and so only has one number per dimension.
+
+   Args:
+     padding: The padding of the convolution
+
+   Returns:
+     The padding in stablehlo style
+   """
+   return tuple((p, p) for p in padding)
+
+
+ def create_conv_dimension_numbers(lhs, transposed: bool = False):
+   """Create the dimension numbers for the convolution.
+
+   Args:
+     lhs: The input tensor
+     transposed: Whether the convolution is transposed
+
+   Returns:
+     The dimension numbers for the convolution
+   """
+   num_spatial_dims = len(lhs.type.shape) - 2
+   spatial_dimensions = []
+   for i in range(0, num_spatial_dims):
+     spatial_dimensions.append(i + 2)
+
+   # Regular kernels are OIHW
+   # TransposedConv kernels are IOHW
+   dimension_numbers = stablehlo.ConvDimensionNumbers.get(
+       input_batch_dimension=0,
+       input_feature_dimension=1,
+       input_spatial_dimensions=spatial_dimensions,
+       kernel_input_feature_dimension=0 if transposed else 1,
+       kernel_output_feature_dimension=1 if transposed else 0,
+       kernel_spatial_dimensions=spatial_dimensions,
+       output_batch_dimension=0,
+       output_feature_dimension=1,
+       output_spatial_dimensions=spatial_dimensions,
+   )
+   return dimension_numbers
+
+
+ def infer_output_shape(
+     lhs,
+     rhs,
+     stride,
+     dilation,
+     padding,
+     transposed: bool = False,
+     output_padding: list[int] = 0,
+ ):
+   """Infer the output shape of the convolution.
+
+   Args:
+     lhs: The input tensor
+     rhs: The kernel tensor
+     stride: The stride of the convolution (dilation of input in transposed conv)
+     dilation: The kernel dilation of the convolution
+     padding: The padding of the convolution
+     transposed: Whether the convolution is transposed
+     output_padding: The output padding of the convolution
+
+   Returns:
+     The output shape of the convolution
+   """
+   lhs_type: ir.RankedTensorType = lhs.type
+   lhs_shape: list[int] = lhs_type.shape
+   rhs_shape: list[int] = rhs.type.shape
+
+   # Input layout is: (N)CHW and Kernel layout is: (O)IHW for regular conv
+   # Input layout is: (N)CHW and Kernel layout is: I(O)HW for transposed conv
+   output_shape = (
+       [lhs_shape[0], rhs_shape[1]]
+       if transposed
+       else [lhs_shape[0], rhs_shape[0]]
+   )
+   num_spatial_dims = len(lhs.type.shape) - 2
+
+   # looping over the spatial dims (skipping the first 2 dims which are
+   # batch and features)
+   for spatial_dim in range(0, num_spatial_dims):
+     dim = spatial_dim + 2
+     dim_size = lhs_shape[dim]
+     kernel_dim_size = rhs_shape[dim]
+
+     if transposed:
+       output_dim_size = (
+           (dim_size - 1) * stride[spatial_dim]
+           - 2 * padding[spatial_dim]
+           + dilation[spatial_dim] * (kernel_dim_size - 1)
+           + output_padding[spatial_dim]
+           + 1
+       )
+     else:
+       output_dim_size = math.floor(
+           (
+               (
+                   dim_size
+                   + 2 * padding[spatial_dim]
+                   - dilation[spatial_dim] * (kernel_dim_size - 1)
+                   - 1
+               )
+               / stride[spatial_dim]
+           )
+           + 1
+       )
+
+     output_shape.append(output_dim_size)
+
+   return output_shape
+
+
+ def build_transpose_conv(
+     lctx,
+     output_type: ir.RankedTensorType,
+     lhs: ir.Value,
+     rhs: ir.Value,
+     stride: list[int],
+     padding: list[int],
+     dilation: list[int],
+     output_padding: list[int],
+     groups: int,
+ ):
+   lhs_type: ir.RankedTensorType = lhs.type
+   num_spatial_dims = len(lhs_type.shape) - 2
+   rhs = stablehlo.reverse(rhs, list(range(2, 2 + num_spatial_dims)))
+
+   kernel_size = rhs.type.shape
+   # We need additional padding on the input to get the right output size.
+   adjusted_padding = [
+       dilation[dim] * (kernel_size[dim + 2] - 1) - padding[dim]
+       for dim in range(num_spatial_dims)
+   ]
+   return stablehlo.convolution(
+       result=output_type,
+       lhs=lhs,
+       rhs=rhs,
+       dimension_numbers=create_conv_dimension_numbers(lhs, True),
+       feature_group_count=groups,
+       batch_group_count=1,
+       padding=make_padding(adjusted_padding),
+       lhs_dilation=stride,
+       rhs_dilation=dilation,
+   )


  # convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride,
  # SymInt[] padding, SymInt[] dilation, bool transposed,
  # SymInt[] output_padding, SymInt groups) -> Tensor
- # @lower(torch.ops.aten.convolution)
+ @registry.lower(torch.ops.aten.convolution)
  def _aten_convolution(
      lctx,
      lhs: ir.Value,
@@ -40,80 +189,53 @@ def _aten_convolution(
      output_padding: list[int],
      groups: int,
  ):
-   if transposed:
-     raise NotImplementedError("Transposed convolution is not implemented.")

-   if bias is not None:
-     raise NotImplementedError("Bias on convolution is not implemented.")
-
-   # Stablehlo allows start and end padding for each dimension while aten only
-   # allows symmetric padding and so only has one number per dimension.
-   def make_padding(padding):
-     return tuple((p, p) for p in padding)
-
-   def create_conv_dimension_numbers():
-     num_spatial_dims = len(lhs.type.shape) - 2
-     spatial_dimensions = []
-     for i in range(0, num_spatial_dims):
-       spatial_dimensions.append(i + 2)
-
-     dimension_numbers = stablehlo.ConvDimensionNumbers.get(
-         input_batch_dimension=0,
-         input_feature_dimension=1,
-         input_spatial_dimensions=spatial_dimensions,
-         kernel_input_feature_dimension=1,
-         kernel_output_feature_dimension=0,
-         kernel_spatial_dimensions=spatial_dimensions,
-         output_batch_dimension=0,
-         output_feature_dimension=1,
-         output_spatial_dimensions=spatial_dimensions,
+   # TODO(b/365559296) Add support for output_padding
+   if any(output_padding):
+     raise NotImplementedError(
+         "Output padding on convolution is not implemented."
      )
-     return dimension_numbers
-
-   def infer_output_shape():
-     lhs_type: ir.RankedTensorType = lhs.type
-     lhs_shape: list[int] = lhs_type.shape
-     rhs_shape: list[int] = rhs.type.shape
-
-     # Input layout is: (N)CHW and Kernel layout is: (O)IHW
-     output_shape = [lhs_shape[0], rhs_shape[0]]
-     num_spatial_dims = len(lhs.type.shape) - 2
-
-     # looping over the spatial dims (skipping the first 2 dims which are
-     # batch and features)
-     for spatial_dim in range(0, num_spatial_dims):
-       dim_size = lhs_shape[spatial_dim + 2]
-       kernel_dim_size = rhs_shape[spatial_dim + 2]
-
-       # for example, a dilation of 2 increases the dimension size by 2
-       dim_size *= dilation[spatial_dim]
-
-       # padding added to both sides
-       dim_size += 2 * padding[spatial_dim]
-
-       output_dim_size = math.ceil(
-           (dim_size - kernel_dim_size + 1) / stride[spatial_dim]
-       )
-
-       output_shape.append(output_dim_size)
-
-     return output_shape

    lhs_type: ir.RankedTensorType = lhs.type
-
-   op = stablehlo.ConvolutionOp(
-       result=ir.RankedTensorType.get(
-           infer_output_shape(), lhs_type.element_type
-       ),
-       lhs=lhs,
-       rhs=rhs,
-       dimension_numbers=create_conv_dimension_numbers(),
-       feature_group_count=groups,
-       batch_group_count=1,
-       window_strides=stride,
-       padding=make_padding(padding),
-       lhs_dilation=(1,) * len(stride),
-       rhs_dilation=dilation,
+   output_shape = infer_output_shape(
+       lhs, rhs, stride, dilation, padding, transposed, output_padding
+   )
+   output_type = ir.RankedTensorType.get(
+       output_shape,
+       lhs_type.element_type,
    )

-   return op.result
+   if transposed:
+     res = build_transpose_conv(
+         lctx,
+         output_type,
+         lhs,
+         rhs,
+         stride,
+         padding,
+         dilation,
+         output_padding,
+         groups,
+     )
+   else:
+     res = stablehlo.convolution(
+         result=output_type,
+         lhs=lhs,
+         rhs=rhs,
+         dimension_numbers=create_conv_dimension_numbers(lhs),
+         feature_group_count=groups,
+         batch_group_count=1,
+         window_strides=stride,
+         padding=make_padding(padding),
+         rhs_dilation=dilation,
+     )
+
+   if bias is not None:
+     # broadcast [C] to [NCHW]
+     broadcasted_bias = stablehlo.broadcast_in_dim(output_type, bias, [1])
+     res = stablehlo.add(
+         lhs=res,
+         rhs=broadcasted_bias,
+     )
+
+   return res
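The two size formulas in infer_output_shape are easy to sanity-check in plain Python. A small standalone sketch (hypothetical helper names; same arithmetic as the lowering above):

    import math

    def conv_out(n, k, stride, pad, dil):
      # Regular conv: floor((n + 2p - d*(k - 1) - 1) / s) + 1
      return math.floor((n + 2 * pad - dil * (k - 1) - 1) / stride) + 1

    def conv_transpose_out(n, k, stride, pad, dil, out_pad):
      # Transposed conv: (n - 1)*s - 2p + d*(k - 1) + out_pad + 1
      return (n - 1) * stride - 2 * pad + dil * (k - 1) + out_pad + 1

    assert conv_out(8, 3, 1, 1, 1) == 8  # a 3x3 'same' conv preserves spatial size
    assert conv_transpose_out(2, 1, 2, 0, 1, 0) == 3  # matches the new unit test

Note how build_transpose_conv realizes the transposed case: a regular convolution over a stride-dilated input (lhs_dilation=stride) with a spatially reversed kernel and padding adjusted to dilation * (kernel - 1) - padding.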
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py CHANGED
@@ -105,7 +105,6 @@ lower_by_torch_xla2(torch.ops.aten.clamp.default)
  lower_by_torch_xla2(torch.ops.aten.clone)
  lower_by_torch_xla2(torch.ops.aten.clone.default)
  lower_by_torch_xla2(torch.ops.aten.constant_pad_nd)
- lower_by_torch_xla2(torch.ops.aten.convolution)
  lower_by_torch_xla2(torch.ops.aten.cos)
  lower_by_torch_xla2(torch.ops.aten.cosh)
  lower_by_torch_xla2(torch.ops.aten.cumsum)
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
  # limitations under the License.
  # ==============================================================================

- __version__ = "0.3.0.dev20240909"
+ __version__ = "0.3.0.dev20240910"
ai_edge_torch_nightly-0.3.0.dev20240910.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: ai-edge-torch-nightly
- Version: 0.3.0.dev20240909
+ Version: 0.3.0.dev20240910
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
ai_edge_torch_nightly-0.3.0.dev20240910.dist-info/RECORD CHANGED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
  ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
  ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
- ai_edge_torch/version.py,sha256=r0y6crIySNGhJqtljkzyHxb1XMvLji2VLajLfUjW8b4,706
+ ai_edge_torch/version.py,sha256=e4sh_RFYgNHGoVuOeICnFZtLu1MQCNv7qpq94nKFarU,706
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -26,7 +26,7 @@ ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitio
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/greedy.py,sha256=L_x8BrF7UDah-SYl-pG11I6CIckdU9kBTUHcmwW4cts,2420
  ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_partitioners/min_cut.py,sha256=mzfL9cf0qBnpmxM_OlMQFvQsEZV2B_Mia9yEJV4J7rI,7135
  ai_edge_torch/_convert/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
- ai_edge_torch/_convert/test/test_convert.py,sha256=pUYSXuqFg8CAeJ8JkoYf7S0RDLRPVuZUwVOd0xObM6w,14411
+ ai_edge_torch/_convert/test/test_convert.py,sha256=FSufFZEeTLBpUnzE1Iy-LvNN0mhDynWMNg7Mei8RpLQ,14973
  ai_edge_torch/_convert/test/test_convert_composites.py,sha256=BCIODgxMI_3MxMLfNWYMGjcz-al-J3z5eDHCiZJXNwY,7992
  ai_edge_torch/_convert/test/test_convert_multisig.py,sha256=6_C2R9--KyNR7_oezZIAfyTSR97tOeEWy4XGcbSxBDE,5778
  ai_edge_torch/_convert/test/test_to_channel_last_io.py,sha256=1o-gUiwzIuO67FNAJ8DeyKv8fVUeZVNNNwofNVDjYeU,3024
@@ -148,8 +148,8 @@ ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7i
  ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=GqYk6oBJw7KWeG4_6gxSu_OvYhjJcC2FpGzWPPEdH6w,933
  ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=wV8AUK8dvjLUy3qjqw_IxpiYVDWUMPNZRfi3XYE_hDs,6972
  ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
- ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=B6BILeu-UlwGB1O6g7111X1TaIFznsfxXrB72ygBsBA,3885
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=I0Y4IK7Zap8m6xfxMw7DfQ9Mg4htKOoypdHVAMHqx9c,10669
+ ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=v1VdKmL8YLJv3PR9VgyNghO83A25PpTzY2ZUAJqlq3Q,6847
+ ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=s-cT_tIQHu7w5hXl8MCixRxLlHplpXW-UCzHT9TY--o,10621
  ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
  ai_edge_torch/odml_torch/lowerings/registry.py,sha256=ES3x_RJ22T5rlmMrlomex2DdcZbhlyVJ7_HS3rjz3Uk,2851
  ai_edge_torch/odml_torch/lowerings/utils.py,sha256=NczqpsSd3Fn7yVcPC3qllemiZxxDAZgcW1T5l8-W9fE,5593
@@ -161,8 +161,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
- ai_edge_torch_nightly-0.3.0.dev20240909.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
- ai_edge_torch_nightly-0.3.0.dev20240909.dist-info/METADATA,sha256=s7SAIUvFciy8peNKMHvyhoNQWYx67Jerz4foeV7KiE0,1859
- ai_edge_torch_nightly-0.3.0.dev20240909.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
- ai_edge_torch_nightly-0.3.0.dev20240909.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
- ai_edge_torch_nightly-0.3.0.dev20240909.dist-info/RECORD,,
+ ai_edge_torch_nightly-0.3.0.dev20240910.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ ai_edge_torch_nightly-0.3.0.dev20240910.dist-info/METADATA,sha256=WFNExTO6eF-tAWPmDdQDlr9dvplcoNB0uPdVxSNXYHk,1859
+ ai_edge_torch_nightly-0.3.0.dev20240910.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ ai_edge_torch_nightly-0.3.0.dev20240910.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ ai_edge_torch_nightly-0.3.0.dev20240910.dist-info/RECORD,,