ai-edge-torch-nightly 0.7.0.dev20251007__py3-none-any.whl → 0.8.0.dev20251225__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of ai-edge-torch-nightly might be problematic. Click here for more details.
- ai_edge_torch/_convert/conversion.py +2 -1
- ai_edge_torch/fx_infra/__init__.py +1 -0
- ai_edge_torch/fx_infra/_safe_run_decompositions.py +54 -1
- ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py +1 -3
- ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py +1 -1
- ai_edge_torch/generative/layers/attention.py +25 -2
- ai_edge_torch/generative/layers/attention_test.py +13 -1
- ai_edge_torch/generative/layers/attention_utils.py +62 -1
- ai_edge_torch/generative/layers/attention_utils_test.py +20 -0
- ai_edge_torch/generative/layers/builder.py +4 -2
- ai_edge_torch/generative/layers/model_config.py +5 -0
- ai_edge_torch/generative/layers/normalization.py +8 -2
- ai_edge_torch/generative/layers/scaled_dot_product_attention.py +35 -5
- ai_edge_torch/generative/layers/sdpa_with_kv_update.py +8 -3
- ai_edge_torch/generative/quantize/example.py +1 -1
- ai_edge_torch/generative/quantize/quant_attrs.py +8 -1
- ai_edge_torch/generative/quantize/quant_recipe.py +0 -13
- ai_edge_torch/generative/quantize/quant_recipe_utils.py +12 -19
- ai_edge_torch/generative/quantize/quant_recipes.py +16 -21
- ai_edge_torch/generative/quantize/supported_schemes.py +4 -1
- ai_edge_torch/generative/test/test_kv_cache.py +18 -6
- ai_edge_torch/generative/test/test_quantize.py +17 -26
- ai_edge_torch/generative/utilities/converter.py +97 -22
- ai_edge_torch/generative/utilities/litertlm_builder.py +61 -8
- ai_edge_torch/generative/utilities/loader.py +2 -1
- ai_edge_torch/lowertools/translate_recipe.py +8 -3
- ai_edge_torch/odml_torch/experimental/__init__.py +14 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/__init__.py +20 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py +438 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py +728 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py +371 -0
- ai_edge_torch/odml_torch/experimental/torch_tfl/torch_library_utils.py +37 -0
- ai_edge_torch/odml_torch/export.py +24 -7
- ai_edge_torch/odml_torch/lowerings/_basic.py +155 -0
- ai_edge_torch/odml_torch/lowerings/_decomp_registry.py +94 -2
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py +255 -5
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251225.dist-info}/METADATA +15 -3
- {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251225.dist-info}/RECORD +42 -36
- {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251225.dist-info}/WHEEL +1 -1
- {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251225.dist-info/licenses}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251225.dist-info}/top_level.txt +0 -0
|
@@ -14,13 +14,72 @@
|
|
|
14
14
|
# ==============================================================================
|
|
15
15
|
"""Torch export decompositions to run before lowering."""
|
|
16
16
|
|
|
17
|
+
import functools
|
|
17
18
|
from ai_edge_torch import fx_infra
|
|
18
19
|
import torch
|
|
19
20
|
|
|
20
21
|
|
|
22
|
+
# Fork from pytorch/torch/_decomp/decompositions.py
|
|
23
|
+
def upsample_compute_output_size(input_size, output_size, scale_factors):
|
|
24
|
+
spatial_dimensions = len(input_size) - 2
|
|
25
|
+
if output_size is not None:
|
|
26
|
+
torch._check(
|
|
27
|
+
scale_factors is None,
|
|
28
|
+
lambda: "Must specify exactly one of output_size and scale_factors",
|
|
29
|
+
)
|
|
30
|
+
torch._check(len(output_size) == spatial_dimensions, lambda: "")
|
|
31
|
+
return output_size
|
|
32
|
+
if scale_factors is not None:
|
|
33
|
+
# NB: this isn't necessary lol
|
|
34
|
+
torch._check(
|
|
35
|
+
output_size is None,
|
|
36
|
+
lambda: "Must specify exactly one of output_size and scale_factors",
|
|
37
|
+
)
|
|
38
|
+
torch._check(len(scale_factors) == spatial_dimensions, lambda: "")
|
|
39
|
+
output_size = []
|
|
40
|
+
for i, s in enumerate(scale_factors):
|
|
41
|
+
if int(s) == s:
|
|
42
|
+
output_size.append(input_size[i + 2] * int(s))
|
|
43
|
+
else:
|
|
44
|
+
output_size.append(torch.sym_int(input_size[i + 2] * s))
|
|
45
|
+
return output_size
|
|
46
|
+
torch._check(
|
|
47
|
+
False, lambda: "Must specify exactly one of output_size and scale_factors"
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
# Fork from pytorch/torch/_decomp/decompositions.py
|
|
52
|
+
def _compute_upsample_nearest_indices(input, output_size, scales, exact=False):
  """Builds the source index tensors for nearest-neighbour upsampling.

  Forked from pytorch/torch/_decomp/decompositions.py.

  Args:
    input: Tensor being upsampled; its trailing `len(output_size)` dims are the
      spatial dims.
    output_size: Target size of each spatial dim.
    scales: Per-spatial-dim scale factor, or None for a dim to derive the
      step from `input size / output size`.
    exact: If True, sample at output positions offset by 0.5 before scaling.

  Returns:
    A tuple with one int64 index tensor per spatial dim. The tensor for dim `d`
    has shape `(output_size[d], 1, ..., 1)` with one trailing singleton per
    later spatial dim, so the tensors broadcast against each other.
  """
  n_dims = len(output_size)
  shift = 0.5 if exact else 0.0
  result = []
  for dim, out_len in enumerate(output_size):
    in_len = input.shape[dim - n_dims]
    dim_scale = scales[dim]
    if dim_scale is None:
      step = in_len / out_len
    else:
      step = in_len / (in_len * dim_scale)
    positions = torch.arange(out_len, dtype=torch.float32, device=input.device)
    src = ((positions + shift) * step).to(torch.int64)
    # Append singleton dims so later spatial dims broadcast against this one.
    for _ in range(n_dims - 1 - dim):
      src = src.unsqueeze(-1)
    result.append(src)
  return tuple(result)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
# Fork from pytorch/torch/_decomp/decompositions.py
|
|
75
|
+
def _upsample_nearest2d_common(input, h_indices, w_indices):
  """Gathers rows and columns of `input` at the given indices.

  Forked from pytorch/torch/_decomp/decompositions.py. Dims 0 and 1 are left
  untouched; dims 2 and 3 are indexed by `h_indices` and `w_indices`. The
  result is made contiguous before returning.
  """
  index_spec = (None, None, h_indices, w_indices)
  gathered = torch.ops.aten.index(input, index_spec)
  return gathered.contiguous()
|
|
79
|
+
|
|
80
|
+
|
|
21
81
|
fx_infra.decomp.update_pre_lower_decomp(
|
|
22
82
|
torch._decomp.get_decompositions([
|
|
23
|
-
torch.ops.aten.upsample_nearest2d,
|
|
24
83
|
torch.ops.aten._native_batch_norm_legit.no_stats,
|
|
25
84
|
torch.ops.aten._native_batch_norm_legit_functional,
|
|
26
85
|
torch.ops.aten._adaptive_avg_pool2d,
|
|
@@ -35,11 +94,44 @@ fx_infra.decomp.update_pre_lower_decomp(
|
|
|
35
94
|
torch.ops.aten.replication_pad2d,
|
|
36
95
|
torch.ops.aten.replication_pad3d,
|
|
37
96
|
torch.ops.aten.upsample_bilinear2d.vec,
|
|
38
|
-
torch.ops.aten.upsample_nearest2d.vec,
|
|
39
97
|
torch.ops.aten.addmm,
|
|
40
98
|
])
|
|
41
99
|
)
|
|
42
100
|
|
|
101
|
+
|
|
102
|
+
@functools.partial(
    fx_infra.decomp.add_pre_lower_decomp,
    torch.ops.aten.upsample_nearest2d.default,
)
@fx_infra.annotate_force_decomp
def upsample_nearest2d(input, output_size, scales_h=None, scales_w=None):
  """Pre-lower decomposition of `aten.upsample_nearest2d.default`.

  Computes per-axis nearest-neighbour source indices and gathers them from
  `input` with `aten.index`.
  """
  per_axis_scales = (scales_h, scales_w)
  row_idx, col_idx = _compute_upsample_nearest_indices(
      input, output_size, per_axis_scales
  )
  return _upsample_nearest2d_common(input, row_idx, col_idx)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def get_scale_value(scales, idx):
  """Returns `scales[idx]`, or None when `scales` itself is None."""
  return None if scales is None else scales[idx]
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
@functools.partial(
    fx_infra.decomp.add_pre_lower_decomp,
    torch.ops.aten.upsample_nearest2d.vec,
)
@fx_infra.annotate_force_decomp
def upsample_nearest2d_vec(input, output_size, scale_factors):
  """Pre-lower decomposition of `aten.upsample_nearest2d.vec`.

  Resolves the concrete spatial output size (from either `output_size` or
  `scale_factors`) and re-dispatches to `aten.upsample_nearest2d.default`
  with the per-axis scales.
  """
  resolved_size = upsample_compute_output_size(
      input.size(), output_size, scale_factors
  )
  if scale_factors is None:
    scale_h = scale_w = None
  else:
    scale_h, scale_w = scale_factors[0], scale_factors[1]
  return torch.ops.aten.upsample_nearest2d.default(
      input, resolved_size, scale_h, scale_w
  )
|
|
133
|
+
|
|
134
|
+
|
|
43
135
|
fx_infra.decomp.remove_pre_lower_decomp(torch.ops.aten.roll)
|
|
44
136
|
|
|
45
137
|
# Torch's default einsum impl/decompositions is less efficient and
|
|
@@ -21,6 +21,7 @@ from ai_edge_torch.odml_torch.lowerings import registry
|
|
|
21
21
|
import jax
|
|
22
22
|
import jax.numpy as jnp
|
|
23
23
|
from jax._src.lib.mlir import ir
|
|
24
|
+
import numpy as np
|
|
24
25
|
import torch
|
|
25
26
|
import torch_xla2.ops.jaten # Import to load torch_xla2 ops
|
|
26
27
|
import torch_xla2.ops.ops_registry # Import to load torch_xla2 ops
|
|
@@ -71,8 +72,6 @@ lower_by_torch_xla2(torch.ops.aten._cdist_forward)
|
|
|
71
72
|
lower_by_torch_xla2(torch.ops.aten._local_scalar_dense)
|
|
72
73
|
lower_by_torch_xla2(torch.ops.aten._local_scalar_dense)
|
|
73
74
|
lower_by_torch_xla2(torch.ops.aten._log_softmax)
|
|
74
|
-
lower_by_torch_xla2(torch.ops.aten._native_batch_norm_legit)
|
|
75
|
-
lower_by_torch_xla2(torch.ops.aten._native_batch_norm_legit_no_training)
|
|
76
75
|
lower_by_torch_xla2(torch.ops.aten._pdist_forward)
|
|
77
76
|
lower_by_torch_xla2(torch.ops.aten._softmax)
|
|
78
77
|
lower_by_torch_xla2(torch.ops.aten._unsafe_index)
|
|
@@ -158,10 +157,8 @@ lower_by_torch_xla2(torch.ops.aten.logical_not)
|
|
|
158
157
|
lower_by_torch_xla2(torch.ops.aten.logical_or)
|
|
159
158
|
lower_by_torch_xla2(torch.ops.aten.logical_xor)
|
|
160
159
|
lower_by_torch_xla2(torch.ops.aten.max)
|
|
161
|
-
lower_by_torch_xla2(torch.ops.aten.max_pool2d_with_indices)
|
|
162
160
|
lower_by_torch_xla2(torch.ops.aten.max_pool2d_with_indices_backward)
|
|
163
161
|
lower_by_torch_xla2(torch.ops.aten.max_pool2d_with_indices_backward)
|
|
164
|
-
lower_by_torch_xla2(torch.ops.aten.max_pool3d_with_indices)
|
|
165
162
|
lower_by_torch_xla2(torch.ops.aten.maximum)
|
|
166
163
|
lower_by_torch_xla2(torch.ops.aten.mean)
|
|
167
164
|
lower_by_torch_xla2(torch.ops.aten.min)
|
|
@@ -175,7 +172,6 @@ lower_by_torch_xla2(torch.ops.aten.nonzero)
|
|
|
175
172
|
lower_by_torch_xla2(torch.ops.aten.outer)
|
|
176
173
|
lower_by_torch_xla2(torch.ops.aten.permute)
|
|
177
174
|
lower_by_torch_xla2(torch.ops.aten.permute_copy)
|
|
178
|
-
lower_by_torch_xla2(torch.ops.aten.pixel_shuffle)
|
|
179
175
|
lower_by_torch_xla2(torch.ops.aten.pow)
|
|
180
176
|
lower_by_torch_xla2(torch.ops.aten.prod)
|
|
181
177
|
lower_by_torch_xla2(torch.ops.aten.reciprocal)
|
|
@@ -240,6 +236,249 @@ lower_by_torch_xla2(torch.ops.prims.broadcast_in_dim)
|
|
|
240
236
|
lower_by_torch_xla2(torch.ops.prims.var)
|
|
241
237
|
|
|
242
238
|
|
|
239
|
+
def _ceil_mode_padding(
    padding: list[int],
    input_shape: list[int],
    kernel_size: list[int],
    stride: list[int],
    dilation: list[int],
    ceil_mode: bool,
) -> list[tuple[int, int]]:
  """Creates low and high padding specification for ceil mode.

  `padding` is symmetric per spatial dim. When `ceil_mode` is set and the
  window does not tile the padded input evenly, extra high padding is added,
  but only if the last window would still start inside the input or its low
  padding. Spatial sizes are read from `input_shape[2:]`.

  Returns:
    One `(low, high)` pair per entry of `padding`.
  """
  pairs = []
  for dim, pad in enumerate(padding):
    low = high = pad
    size = input_shape[2 + dim]
    step = stride[dim]
    # Extent of the dilated window, minus one.
    span = (kernel_size[dim] - 1) * dilation[dim]
    remainder = (size + 2 * pad - span - 1) % step
    if ceil_mode and remainder != 0:
      extra = step - remainder
      ceil_output_size = (size + 2 * pad + extra - span - 1 + step - 1) // step + 1
      # Ensure that the last pooling starts inside the image.
      if (ceil_output_size - 1) * step < size + pad:
        high += extra
    pairs.append((low, high))
  return pairs
|
|
281
|
+
|
|
282
|
+
|
|
283
|
+
def max_pool(
    inputs,
    kernel_size,
    strides=None,
    padding=0,
    dilation=1,
    ceil_mode=False,
    with_index=False,
):
  """Max pooling over the trailing spatial dims of `inputs` via reduce_window.

  Args:
    inputs: JAX array whose last `len(kernel_size)` dims are spatial, preceded
      by one non-spatial dim and zero or more batch dims. When there are no
      batch dims, a singleton batch dim is added for the computation and
      squeezed off the result.
    kernel_size: Window size per spatial dim.
    strides: Stride per spatial dim; None or empty means use `kernel_size`.
    padding: Int, or per-spatial-dim symmetric padding; empty means zeros.
    dilation: Int, or per-spatial-dim dilation; empty means ones.
    ceil_mode: If True, `_ceil_mode_padding` may add extra high padding so the
      output size rounds up.
    with_index: If True, also return the argmax position of each window,
      flattened within the spatial plane.

  Returns:
    The pooled array cast back to the input dtype, or `(pooled, indices)` when
    `with_index` is True.
  """
  num_spatial_dims = len(kernel_size)
  num_batch_dims = inputs.ndim - num_spatial_dims - 1
  kernel_size_tup = tuple(kernel_size)
  # Default stride is kernel_size (also when an empty list is passed).
  strides_tup = tuple(strides) if strides else kernel_size_tup

  if isinstance(padding, int):
    padding_list = [padding] * num_spatial_dims
  elif not padding:  # padding can be [], meaning all zeros.
    padding_list = [0] * num_spatial_dims
  else:
    padding_list = list(padding)

  if isinstance(dilation, int):
    dilation_tup = (dilation,) * num_spatial_dims
  elif not dilation:
    dilation_tup = (1,) * num_spatial_dims
  else:
    dilation_tup = tuple(dilation)

  assert len(kernel_size_tup) == len(
      strides_tup
  ), f"len({kernel_size_tup=}) must equal len({strides_tup=})"
  assert len(kernel_size_tup) == len(
      dilation_tup
  ), f"len({kernel_size_tup=}) must equal len({dilation_tup=})"

  # `_ceil_mode_padding` reads spatial sizes from `input_shape[2 + i]`, so give
  # it exactly two leading dims regardless of how many batch dims `inputs` has.
  ceil_shape = (1, 1, *inputs.shape[-num_spatial_dims:])
  padding_pairs = tuple(
      _ceil_mode_padding(
          padding_list,
          ceil_shape,
          kernel_size_tup,
          strides_tup,
          dilation_tup,
          ceil_mode,
      )
  )

  is_single_input = num_batch_dims == 0
  if is_single_input:
    inputs = inputs[None]

  # Non-spatial dims get a window/stride/dilation of 1 and no padding.
  num_leading_dims = inputs.ndim - num_spatial_dims
  reduce_window_dims = (1,) * num_leading_dims + kernel_size_tup
  reduce_window_strides = (1,) * num_leading_dims + strides_tup
  reduce_window_dilation = (1,) * num_leading_dims + dilation_tup
  padding_lax = ((0, 0),) * (inputs.ndim - len(padding_pairs)) + padding_pairs

  return_dtype = inputs.dtype
  if jnp.issubdtype(inputs.dtype, jnp.integer):
    # NOTE(review): integer inputs are reduced in int32; int64 values outside
    # the int32 range would not round-trip — confirm this is intended.
    init_val = jnp.int32(jnp.iinfo(jnp.int32).min)
    inputs = inputs.astype(jnp.int32)
  else:
    init_val = jnp.float32(-jnp.inf)
    inputs = inputs.astype(jnp.float32)

  if not with_index:
    y = jax.lax.reduce_window(
        inputs,
        init_val,
        jax.lax.max,
        reduce_window_dims,
        reduce_window_strides,
        padding_lax,
        window_dilation=reduce_window_dilation,
    )
    if is_single_input:
      y = jnp.squeeze(y, axis=0)
    return y.astype(return_dtype)

  # Flat position of every element within its spatial plane, broadcast over
  # the leading dims. Only built when indices are actually requested.
  spatial_shape = inputs.shape[-num_spatial_dims:]
  plane_indices = jnp.arange(np.prod(spatial_shape), dtype=jnp.int64).reshape(
      spatial_shape
  )
  indices = jnp.broadcast_to(
      plane_indices.reshape(
          (1,) * (inputs.ndim - plane_indices.ndim) + plane_indices.shape
      ),
      inputs.shape,
  )

  def reduce_fn(a, b):
    ai, av = a
    bi, bv = b
    which = av >= bv
    return jnp.where(which, ai, bi), jnp.where(which, av, bv)

  indices, y = jax.lax.reduce_window(
      (indices, inputs),
      (jnp.int64(0), init_val),
      reduce_fn,
      reduce_window_dims,
      reduce_window_strides,
      padding_lax,
      window_dilation=reduce_window_dilation,
  )
  if is_single_input:
    indices = jnp.squeeze(indices, axis=0)
    y = jnp.squeeze(y, axis=0)
  return y.astype(return_dtype), indices
|
|
407
|
+
|
|
408
|
+
|
|
409
|
+
@lower_by_jax(torch.ops.aten.max_pool2d_with_indices)
def _aten_max_pool2d_with_indices(
    self, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False
):
  """JAX lowering of `aten.max_pool2d_with_indices`.

  Only the pooled values are real. TFLite's reduce_window kernel doesn't
  support multiple inputs/outputs, so a single-output reduce_window is emitted
  and the returned indices are all-zero placeholders shaped like the output.
  """
  pooled = max_pool(
      self,
      kernel_size,
      strides=[] if stride is None else stride,
      padding=padding,
      dilation=dilation,
      ceil_mode=ceil_mode,
      with_index=False,
  )
  placeholder_indices = jnp.zeros_like(pooled, dtype=jnp.int64)
  return pooled, placeholder_indices
|
|
426
|
+
|
|
427
|
+
|
|
428
|
+
@lower_by_jax(torch.ops.aten.max_pool3d_with_indices.default)
def _aten_max_pool3d_with_indices(
    self, kernel_size, stride=None, padding=0, dilation=1, ceil_mode=False
):
  """JAX lowering of `aten.max_pool3d_with_indices.default`.

  Only the pooled values are real. TFLite's reduce_window kernel doesn't
  support multiple inputs/outputs, so a single-output reduce_window is emitted
  and the returned indices are all-zero placeholders shaped like the output.
  """
  pooled = max_pool(
      self,
      kernel_size,
      strides=[] if stride is None else stride,
      padding=padding,
      dilation=dilation,
      ceil_mode=ceil_mode,
      with_index=False,
  )
  placeholder_indices = jnp.zeros_like(pooled, dtype=jnp.int64)
  return pooled, placeholder_indices
|
|
445
|
+
|
|
446
|
+
|
|
447
|
+
@lower_by_jax(torch.ops.aten.pixel_shuffle)
def _aten_pixel_shuffle(x, upscale_factor):
  """PixelShuffle implementation in JAX lowering.

  Rearranges `(*, C * r**2, H, W)` into `(*, C, H * r, W * r)` with
  `r = upscale_factor`, where `*` is zero or more leading batch dims (the
  input shape contract of `torch.nn.functional.pixel_shuffle`).

  Args:
    x: Input array with at least three dims; the last three are treated as
      channels, height and width.
    upscale_factor: Integer by which to upscale the spatial dimensions.

  Returns:
    Array after the PixelShuffle operation.

  Raises:
    ValueError: If `x` has fewer than three dims, or its channel count is not
      divisible by the square of `upscale_factor`.
  """
  if x.ndim < 3:
    raise ValueError(
        "pixel_shuffle expects an input with at least 3 dims, got shape"
        f" {x.shape}."
    )
  *batch_shape, channels, height, width = x.shape
  r = upscale_factor

  if channels % (r**2) != 0:
    raise ValueError(
        "Number of channels must be divisible by the square of the upscale"
        " factor."
    )

  new_channels = channels // (r**2)
  nb = len(batch_shape)

  # Split channels into (C, r_h, r_w), then interleave each upscale axis with
  # its spatial axis: (*, C, H, r_h, W, r_w).
  x = x.reshape(*batch_shape, new_channels, r, r, height, width)
  batch_axes = tuple(range(nb))
  x = jnp.transpose(x, batch_axes + (nb, nb + 3, nb + 1, nb + 4, nb + 2))
  return x.reshape(*batch_shape, new_channels, height * r, width * r)
|
|
480
|
+
|
|
481
|
+
|
|
243
482
|
@lower_by_jax(torch.ops.aten.unbind)
def _aten_copy(self, *args, **kwargs):
  """Lowering registered for `aten.unbind`, delegating to `unbind_copy`.

  NOTE(review): despite its name this handles `aten.unbind`; a later
  `def _aten_copy` in this module rebinds the module-level name.
  """
  unbind_impl = _TORCH_XLA2_IMPLS[torch.ops.aten.unbind_copy]
  return unbind_impl(self, *args, **kwargs)
|
|
@@ -250,6 +489,17 @@ def _aten_copy(self, src, **kwargs):
|
|
|
250
489
|
return _TORCH_XLA2_IMPLS[torch.ops.aten.copy](self, src)
|
|
251
490
|
|
|
252
491
|
|
|
492
|
+
@lower_by_jax(torch.ops.aten.elu.default)
def _aten_elu(self, alpha=1.0, scale=1.0, input_scale=1.0):
  """JAX lowering of `aten.elu.default`.

  Returns `scale * x` where `x >= 0`, and
  `alpha * scale * expm1(input_scale * x)` elsewhere.
  """
  positive = self * scale
  negative = jnp.expm1(self * input_scale) * (alpha * scale)
  return jnp.where(self >= 0, positive, negative)
|
|
501
|
+
|
|
502
|
+
|
|
253
503
|
@registry.lower(torch.ops.aten.add.Scalar)
|
|
254
504
|
def _aten_add_scalar(lctx: LoweringContext, self, other):
|
|
255
505
|
_log_usage(torch.ops.aten.add.Scalar)
|
{ai_edge_torch_nightly-0.7.0.dev20251007.dist-info → ai_edge_torch_nightly-0.8.0.dev20251225.dist-info}/METADATA
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: ai-edge-torch-nightly
|
|
3
|
-
Version: 0.7.0.dev20251007
|
|
3
|
+
Version: 0.8.0.dev20251225
|
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
|
@@ -13,6 +13,8 @@ Classifier: Programming Language :: Python :: 3
|
|
|
13
13
|
Classifier: Programming Language :: Python :: 3 :: Only
|
|
14
14
|
Classifier: Programming Language :: Python :: 3.10
|
|
15
15
|
Classifier: Programming Language :: Python :: 3.11
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
16
18
|
Classifier: Topic :: Scientific/Engineering
|
|
17
19
|
Classifier: Topic :: Scientific/Engineering :: Mathematics
|
|
18
20
|
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
@@ -37,7 +39,17 @@ Requires-Dist: ai-edge-quantizer-nightly
|
|
|
37
39
|
Requires-Dist: jax
|
|
38
40
|
Requires-Dist: torch-xla2[odml]>=0.0.1.dev20241201
|
|
39
41
|
Provides-Extra: torch-xla
|
|
40
|
-
Requires-Dist:
|
|
42
|
+
Requires-Dist: torch_xla>=2.4.0; extra == "torch-xla"
|
|
43
|
+
Dynamic: classifier
|
|
44
|
+
Dynamic: description
|
|
45
|
+
Dynamic: description-content-type
|
|
46
|
+
Dynamic: home-page
|
|
47
|
+
Dynamic: keywords
|
|
48
|
+
Dynamic: license-file
|
|
49
|
+
Dynamic: provides-extra
|
|
50
|
+
Dynamic: requires-dist
|
|
51
|
+
Dynamic: requires-python
|
|
52
|
+
Dynamic: summary
|
|
41
53
|
|
|
42
54
|
Library that supports converting PyTorch models into a .tflite format, which can
|
|
43
55
|
then be run with TensorFlow Lite and MediaPipe. This enables applications for
|
|
@@ -2,9 +2,9 @@ ai_edge_torch/__init__.py,sha256=lemyLCNoGYRnJsmDuGZu7qOqLbLqG6CGDFtu3ue1syU,129
|
|
|
2
2
|
ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
|
4
4
|
ai_edge_torch/model.py,sha256=A7loFu8jE9CsXsfMmHYZ-KDFJiaD8Kkqwm_9d3IVzk0,5638
|
|
5
|
-
ai_edge_torch/version.py,sha256=
|
|
5
|
+
ai_edge_torch/version.py,sha256=EqYE0SbfgYjfk194m19-ExhokdnUqLLGwlHCgT7w_rM,806
|
|
6
6
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
7
|
-
ai_edge_torch/_convert/conversion.py,sha256=
|
|
7
|
+
ai_edge_torch/_convert/conversion.py,sha256=JqGZZGbpTmYiT-ta07IQbJ9-gFm-3Vip2aSzW9ulIng,6117
|
|
8
8
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
|
9
9
|
ai_edge_torch/_convert/converter.py,sha256=6MLKELzAwFoiXv-b7KRYi7gc7Z57XOeowcz9ArIl9TM,12100
|
|
10
10
|
ai_edge_torch/_convert/signature.py,sha256=-YKJdLk-eNEHfhdPCtcQVtZf915SoVePEFxKXPPf16c,2572
|
|
@@ -41,9 +41,9 @@ ai_edge_torch/examples/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzN
|
|
|
41
41
|
ai_edge_torch/examples/selfie_segmentation/__init__.py,sha256=JaAnrFoXTl3RJX97XspklkTyqOHVyAgRJsZtzNDd10c,671
|
|
42
42
|
ai_edge_torch/examples/selfie_segmentation/model.py,sha256=5otCH1MzNgSP0fikYq53hgiO1F0ZN1SCVzOIo7cVAcA,17136
|
|
43
43
|
ai_edge_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
44
|
-
ai_edge_torch/fx_infra/__init__.py,sha256=
|
|
44
|
+
ai_edge_torch/fx_infra/__init__.py,sha256=bseeaX7oDyyvl_oAIT2MfDYZWITmNpn660AlBOUcyUc,1418
|
|
45
45
|
ai_edge_torch/fx_infra/_canonicalize_pass.py,sha256=GDRoDdPVQw--QQFTT5J_C3TVuphL31m6K6F1-67SE4s,1097
|
|
46
|
-
ai_edge_torch/fx_infra/_safe_run_decompositions.py,sha256=
|
|
46
|
+
ai_edge_torch/fx_infra/_safe_run_decompositions.py,sha256=3rGVQj7OgrIBd--olhnGpSyF2kvpjXMfADss6VPsvxQ,4167
|
|
47
47
|
ai_edge_torch/fx_infra/decomp.py,sha256=S58SCgwMHYVFl_hJwlJxvu2wcI-AGNn82gel3qmTPrU,2500
|
|
48
48
|
ai_edge_torch/fx_infra/graph_utils.py,sha256=nqGe-xIJ77RamSUh0UYyI2XHOsZqFDWax-vpRAtVR_E,2796
|
|
49
49
|
ai_edge_torch/fx_infra/pass_base.py,sha256=Ic2AlhSoRFscz6l7gJKvWVNMDLQFfAw5kRf84-ZR9qM,2904
|
|
@@ -141,7 +141,7 @@ ai_edge_torch/generative/examples/smollm/verify_util.py,sha256=tHwmo8E474gLKAu4I
|
|
|
141
141
|
ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
142
142
|
ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
|
|
143
143
|
ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=lSCRZsoLjH_kqasRMwCy5IogkhyJdwcHKsPEfyxsXCQ,6112
|
|
144
|
-
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=
|
|
144
|
+
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=Wba7-rEibvc7ii0uQhO87meVqdQC0_3mtYj5PRXXXbE,5584
|
|
145
145
|
ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=afyHXc86h-ij5zTULmZnM1h313N9VWCyIVriH6pqeSo,16368
|
|
146
146
|
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=ylqXOZhYc6XFCaNBKQw0jAnYrCtRFFQKzQzEsFIntvo,34890
|
|
147
147
|
ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=CAPsW84A8f00nS6fLFeh_XUjCPsDCA5UxHOUsMrLfSU,3450
|
|
@@ -151,7 +151,7 @@ ai_edge_torch/generative/examples/stable_diffusion/util.py,sha256=XIXIB0vCvQKOGy
|
|
|
151
151
|
ai_edge_torch/generative/examples/stable_diffusion/samplers/__init__.py,sha256=uQWKzCD_49ackNFrt50H04dkDXxfAwUCtMWWQre5SVE,830
|
|
152
152
|
ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler.py,sha256=wBBNM24waZ57M1rXonwesfUkKe9DqpqO3eW6BfZkrD0,2323
|
|
153
153
|
ai_edge_torch/generative/examples/stable_diffusion/samplers/k_euler_ancestral.py,sha256=c89ldwtuQ2_yspGrGa7oh7fsvTt6A86Whxa6fBK9YOQ,2526
|
|
154
|
-
ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py,sha256=
|
|
154
|
+
ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py,sha256=NDiJPreRFHW1l50Lmb1FHNQcs3wslFAhS4FEFTQKjMU,2805
|
|
155
155
|
ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py,sha256=RxR5rw0wFFm_5CfAY-3-EIz83vhM9EKye8Bb5zBb0Ok,1341
|
|
156
156
|
ai_edge_torch/generative/examples/t5/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
157
157
|
ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=HHtZTtUh3QgE4F74-ru_8n1pt6cqfbObw12xoaMJ7NQ,4596
|
|
@@ -169,49 +169,49 @@ ai_edge_torch/generative/examples/tiny_llama/verify_util.py,sha256=z6vPBXDWAL6gN
|
|
|
169
169
|
ai_edge_torch/generative/fx_passes/__init__.py,sha256=PFSMsA1vfBfrV9ssBCkYJNl8Hx_bLdWjN01iyjPM5jE,1094
|
|
170
170
|
ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=myGjal5A8yIBoqgArd2k40rZmCgD1Ya369KR7182bhI,2129
|
|
171
171
|
ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
172
|
-
ai_edge_torch/generative/layers/attention.py,sha256=
|
|
173
|
-
ai_edge_torch/generative/layers/attention_test.py,sha256=
|
|
174
|
-
ai_edge_torch/generative/layers/attention_utils.py,sha256=
|
|
175
|
-
ai_edge_torch/generative/layers/attention_utils_test.py,sha256=
|
|
176
|
-
ai_edge_torch/generative/layers/builder.py,sha256=
|
|
172
|
+
ai_edge_torch/generative/layers/attention.py,sha256=ZjU3vX-7gOq1KQb3xSZ1NT3xryOTXbYb_vkx_DlcizA,14524
|
|
173
|
+
ai_edge_torch/generative/layers/attention_test.py,sha256=ON9jQRY1r2kFpVq-Qkg6b13Ob95fd4PqHo1hic3RbOQ,5057
|
|
174
|
+
ai_edge_torch/generative/layers/attention_utils.py,sha256=3Ox1XjW_vaqz1-RuVG9RbzRKUqCberFW8P2BQcoNm7A,9659
|
|
175
|
+
ai_edge_torch/generative/layers/attention_utils_test.py,sha256=IHIk39wqaPvxmkZtW27VD3_4xUpyFow_7mScf8OWdqU,3292
|
|
176
|
+
ai_edge_torch/generative/layers/builder.py,sha256=5QL59CbOOW_mk3mlPdcdirGcAxdLee5atbZlnu5Z3ts,5079
|
|
177
177
|
ai_edge_torch/generative/layers/einsum.py,sha256=LH4CNHr-pFfLUuCpwbYL3GpoAMgHJ4nLju3XCqA4VwM,1416
|
|
178
178
|
ai_edge_torch/generative/layers/einsum_test.py,sha256=ltIE773bvvNLv_9aLQxFwe1MgQ762sez0c5E2tejxuA,1079
|
|
179
179
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=_GmtHxwL068l9gh_F_WFcFk7La-Tl5SfoQ9v2hMabZM,5541
|
|
180
180
|
ai_edge_torch/generative/layers/feed_forward_test.py,sha256=Y5l1eC9NgfYixHcfIfE1W4FGh7oC-9UGGyHdKS9tQKc,1880
|
|
181
181
|
ai_edge_torch/generative/layers/kv_cache.py,sha256=A0IFXZ1HD2ZHOWRLfsDO4almgE0KQfjyBOdBFZIGnAs,10893
|
|
182
182
|
ai_edge_torch/generative/layers/lora.py,sha256=hsvWLLOnW7HQ0AysOZu30x_cetMquDd1tjfyLz8HCSU,17892
|
|
183
|
-
ai_edge_torch/generative/layers/model_config.py,sha256=
|
|
184
|
-
ai_edge_torch/generative/layers/normalization.py,sha256=
|
|
183
|
+
ai_edge_torch/generative/layers/model_config.py,sha256=y6nyydMzm7JUV7AKttcFy3tvti-nE6tRXoVbBB9dyiM,10438
|
|
184
|
+
ai_edge_torch/generative/layers/normalization.py,sha256=syasVh3dRDVp2Nwhl0x7zucL-chTnCqWgeV1mb87DFY,7435
|
|
185
185
|
ai_edge_torch/generative/layers/normalization_test.py,sha256=zwurZly-TgFxdgVVdpzu9vCpcLbd5RYt_gKg9Lfg1jI,2248
|
|
186
186
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=975zR202MdIrILJ7blceAcxrNqX1ZCN0ECKG1gz-bV8,2655
|
|
187
|
-
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=
|
|
187
|
+
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=1zhOsJpI4CTn78weOs0uRwkRxYu6wGfBvYVFpGFl0qQ,6681
|
|
188
188
|
ai_edge_torch/generative/layers/scaled_dot_product_attention_test.py,sha256=c6JBMQsq9XeMmR1XvGEIidNsoh-YIvichXo2LwVHgr4,3301
|
|
189
|
-
ai_edge_torch/generative/layers/sdpa_with_kv_update.py,sha256=
|
|
189
|
+
ai_edge_torch/generative/layers/sdpa_with_kv_update.py,sha256=fK_h9M-03ai5dV8ZyQzvB0y84IKlNg9h-4bt9F6bU0g,3833
|
|
190
190
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
191
191
|
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=ZteHZXK6HKyxYji49DQ46sA9aIy7U3Jnz0HZp6hfevY,28996
|
|
192
192
|
ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
|
|
193
193
|
ai_edge_torch/generative/layers/unet/model_config.py,sha256=pPDwLawc23pfMaPVyMJlYmxVVusjMvx-l8wBwOYOH-c,9692
|
|
194
194
|
ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
195
|
-
ai_edge_torch/generative/quantize/example.py,sha256=
|
|
196
|
-
ai_edge_torch/generative/quantize/quant_attrs.py,sha256=
|
|
197
|
-
ai_edge_torch/generative/quantize/quant_recipe.py,sha256=
|
|
198
|
-
ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=
|
|
199
|
-
ai_edge_torch/generative/quantize/quant_recipes.py,sha256=
|
|
200
|
-
ai_edge_torch/generative/quantize/supported_schemes.py,sha256=
|
|
195
|
+
ai_edge_torch/generative/quantize/example.py,sha256=rRm5noeqnCJyaDYG6e1micwaYdA3ooXK0MBYl98neQQ,1742
|
|
196
|
+
ai_edge_torch/generative/quantize/quant_attrs.py,sha256=THKS82G9sR25biEk8AVRFWwB1cjVXpbjZD5IBfNOOMw,2344
|
|
197
|
+
ai_edge_torch/generative/quantize/quant_recipe.py,sha256=kL3UPNscWKzugiyunMn5XK6Lzpw-yXTdtM_NmCNbFUM,5330
|
|
198
|
+
ai_edge_torch/generative/quantize/quant_recipe_utils.py,sha256=bFvY-OnbH6fe5xhS6dfxkAl7NTGzOVtGJ_8417GwGgI,2478
|
|
199
|
+
ai_edge_torch/generative/quantize/quant_recipes.py,sha256=nvADL7wObxJLp651HI1xv5_HeQZ2RaY2swoRC1EOK7k,2783
|
|
200
|
+
ai_edge_torch/generative/quantize/supported_schemes.py,sha256=lEcqwxOIfxhTYCQUelODLky57933r372Uh7sHYNV0Ok,1712
|
|
201
201
|
ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
202
202
|
ai_edge_torch/generative/test/test_custom_dus.py,sha256=MjIhTvkTko872M35XMciobvICcDWTcIDJ3rociko-wM,3267
|
|
203
|
-
ai_edge_torch/generative/test/test_kv_cache.py,sha256=
|
|
203
|
+
ai_edge_torch/generative/test/test_kv_cache.py,sha256=kdoBpyzVOjQsZkEPnbDOh33gRE0RR8tq1XCgJkSYwU0,6157
|
|
204
204
|
ai_edge_torch/generative/test/test_loader.py,sha256=9mQUeeZKOVApOWSWl2cN9c10axZjMKM1-0Zd823CCS4,3449
|
|
205
205
|
ai_edge_torch/generative/test/test_lora.py,sha256=sKnBixmGIXHUtOwh3SA4MIFenPbjK2n-Xknwic_KMDQ,5046
|
|
206
206
|
ai_edge_torch/generative/test/test_model_conversion.py,sha256=T35zdzag2-nmy4qc6AifAjbDXAHU2vyLTE1QCabYBzk,6298
|
|
207
207
|
ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=NkEwrjO8vIcde3XwanpFBhNIw1GSOyJFKNjlvSJmVMY,13271
|
|
208
|
-
ai_edge_torch/generative/test/test_quantize.py,sha256=
|
|
208
|
+
ai_edge_torch/generative/test/test_quantize.py,sha256=Sh82oxDIc_vw-rCd991qMNQxzM3OkI3oiPj2g4hD5MQ,7088
|
|
209
209
|
ai_edge_torch/generative/test/utils.py,sha256=tF6aCfAGJnc9dmzCnZCEOuKNVimfWOqscv9og0DDLHU,2656
|
|
210
210
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
|
211
|
-
ai_edge_torch/generative/utilities/converter.py,sha256=
|
|
211
|
+
ai_edge_torch/generative/utilities/converter.py,sha256=B2DNg5LcDoniHnh22KFDDSHqnmi8w87-dstKCae3CeM,23322
|
|
212
212
|
ai_edge_torch/generative/utilities/export_config.py,sha256=5B15nYyqf96kjjYlHfPctUfsIdsBsh1f8rxKitJpwKQ,2384
|
|
213
|
-
ai_edge_torch/generative/utilities/litertlm_builder.py,sha256=
|
|
214
|
-
ai_edge_torch/generative/utilities/loader.py,sha256=
|
|
213
|
+
ai_edge_torch/generative/utilities/litertlm_builder.py,sha256=Jy6R4S18JCIk77ZwrfzeOCSaGozvRQfXSobGkwOEOKs,6447
|
|
214
|
+
ai_edge_torch/generative/utilities/loader.py,sha256=QQeEu0cTC7gWnB7RkHonjWLdVGjMbDHd1lfYO_TcJyU,16047
|
|
215
215
|
ai_edge_torch/generative/utilities/model_builder.py,sha256=xBvcTxihB9TN88UtQiXA9sAITQgf-pA77R-VZlLgUeU,6950
|
|
216
216
|
ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
|
|
217
217
|
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
|
|
@@ -232,11 +232,11 @@ ai_edge_torch/lowertools/common_utils.py,sha256=4HQtquPZ6oiId8vR_1ykW_uK4ELnyo5z
|
|
|
232
232
|
ai_edge_torch/lowertools/odml_torch_utils.py,sha256=QRuS7S5lULRWEh3J1sWIsnKh-rbX7rd9tt6JJHbMPfo,8317
|
|
233
233
|
ai_edge_torch/lowertools/test_utils.py,sha256=mdxTlhqHABZEQ_GEmPFCL8LIAWtqRtYZUGdSY1ieZjw,1949
|
|
234
234
|
ai_edge_torch/lowertools/torch_xla_utils.py,sha256=1EytIw2R6dthhLhf69wN1L9BaQTeybCD0wga-PhHcMI,9518
|
|
235
|
-
ai_edge_torch/lowertools/translate_recipe.py,sha256=
|
|
235
|
+
ai_edge_torch/lowertools/translate_recipe.py,sha256=WEv8Kr0RZw0pBc5PoNUEPnQGBW_JUmuVWQ7Hf4bw0uY,6545
|
|
236
236
|
ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
|
|
237
237
|
ai_edge_torch/odml_torch/_torch_future.py,sha256=jSYHf1CMTJzMizPMbu2b39hAt0ZTR6gQLq67GMe9KTo,2336
|
|
238
238
|
ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
|
|
239
|
-
ai_edge_torch/odml_torch/export.py,sha256=
|
|
239
|
+
ai_edge_torch/odml_torch/export.py,sha256=PdNdFtEZcFpy4G8ikhElnil8yoYZUWHQ_180E1BGFV4,15347
|
|
240
240
|
ai_edge_torch/odml_torch/export_utils.py,sha256=Eax4QUefIzpmVuQxo1y9FqJ6g0qXjg4C0IVZ5uYPscs,4899
|
|
241
241
|
ai_edge_torch/odml_torch/optimization_barrier.py,sha256=2lmSiu5iXWLFWpupZHvsVeNYNzG5AVGSK3K_CNhS5Sk,2290
|
|
242
242
|
ai_edge_torch/odml_torch/tf_integration.py,sha256=NN29WeXmHZ0S1RPDFHUnBi2DEjMvAtwczStPYIsQ1w8,4849
|
|
@@ -246,15 +246,21 @@ ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py,sha256=h6DQkYV
|
|
|
246
246
|
ai_edge_torch/odml_torch/debuginfo/__init__.py,sha256=3A_lMyj-B-DOhLJG6WmjKvZK5te2rXje8FrfqOhZsN0,959
|
|
247
247
|
ai_edge_torch/odml_torch/debuginfo/_build.py,sha256=6Ns2rlfOilLJEk5cUxlkRwm2uxOgEF2-0S2DMcOqr6A,3319
|
|
248
248
|
ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW-1YElU9BPWzKtJA2eKWOI,1739
|
|
249
|
+
ai_edge_torch/odml_torch/experimental/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
|
250
|
+
ai_edge_torch/odml_torch/experimental/torch_tfl/__init__.py,sha256=MN-8wS8vZmSWr9TOoZivqzZGDPoNpy15dByM3arNnM8,968
|
|
251
|
+
ai_edge_torch/odml_torch/experimental/torch_tfl/_decomps.py,sha256=KEPc8wZeAukFhlfKUA2TSmRdMozqG16msnEw7JsGrTg,13256
|
|
252
|
+
ai_edge_torch/odml_torch/experimental/torch_tfl/_lowerings.py,sha256=9coRYoVQFtGhcWuMgZf7lRGIMpRqupbyueAqShx6ieQ,21065
|
|
253
|
+
ai_edge_torch/odml_torch/experimental/torch_tfl/_ops.py,sha256=gRoAf6QjER-fb4oVB05vVGDV__HiXeJgEo1WXX7TnkM,10778
|
|
254
|
+
ai_edge_torch/odml_torch/experimental/torch_tfl/torch_library_utils.py,sha256=9Oq3Dsc5YAkpBgPNFWTquT4TVQnOG8LOlwIabDIc10k,1236
|
|
249
255
|
ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=e9Oa4J3An9FYr3zM0OzjzyNNitEeg-IoBUGNfUxsDSA,798
|
|
250
256
|
ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=LqwZ1vCJTSOzgzvH8LUAN-sAkF-l_pGj1AMEIzAqHCA,6638
|
|
251
257
|
ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
|
|
252
258
|
ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=uJ-niilt1c-D6QJzLwgvCUf62le_JsxQTlqj_iP_Ps0,1009
|
|
253
|
-
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=
|
|
259
|
+
ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=HOTYfQWin8tqi1yakIyardxhRViZ6rhLV6ZomMSS7zA,17554
|
|
254
260
|
ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
|
|
255
261
|
ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=Q0aDzyUcZMoSzSbOU-r3LJMgPe6fble0QwdYVIOHHHk,6887
|
|
256
|
-
ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=
|
|
257
|
-
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=
|
|
262
|
+
ai_edge_torch/odml_torch/lowerings/_decomp_registry.py,sha256=LmSj5RsZBi00EE7KfF3dI2U0e60LMHA6mDKc-TC2U0U,5486
|
|
263
|
+
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=YmyM-5HJeeYaIhmKTOnCjfX3_A1PPh1gPGUi1d8EBs8,26454
|
|
258
264
|
ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=khJIvDVk2s332Nd2Be-5dM6-wp5DGff61HCV5lskHmQ,3011
|
|
259
265
|
ai_edge_torch/odml_torch/lowerings/_quantized_decomposed.py,sha256=XDZ0zLej_XaQDJnaAAxhNFAd7NfQm5SOVEp_nno_krA,6178
|
|
260
266
|
ai_edge_torch/odml_torch/lowerings/_rand.py,sha256=g6SuqDkuC6hD35lyP1-5H7ASDIzPSmKukeNT5naZSv8,4133
|
|
@@ -270,8 +276,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
|
|
|
270
276
|
ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
|
|
271
277
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
|
272
278
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
|
273
|
-
ai_edge_torch_nightly-0.
|
|
274
|
-
ai_edge_torch_nightly-0.
|
|
275
|
-
ai_edge_torch_nightly-0.
|
|
276
|
-
ai_edge_torch_nightly-0.
|
|
277
|
-
ai_edge_torch_nightly-0.
|
|
279
|
+
ai_edge_torch_nightly-0.8.0.dev20251225.dist-info/licenses/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
|
280
|
+
ai_edge_torch_nightly-0.8.0.dev20251225.dist-info/METADATA,sha256=aFJe9FPOOmsDU7-axoCav2W0-n4fRxf0lApj08AZb0s,2399
|
|
281
|
+
ai_edge_torch_nightly-0.8.0.dev20251225.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
|
|
282
|
+
ai_edge_torch_nightly-0.8.0.dev20251225.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
|
283
|
+
ai_edge_torch_nightly-0.8.0.dev20251225.dist-info/RECORD,,
|
|
File without changes
|