ai-edge-torch-nightly 0.3.0.dev20241226__py3-none-any.whl → 0.3.0.dev20250107__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ai_edge_torch/_config.py +26 -9
- ai_edge_torch/_convert/conversion.py +22 -18
- ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +1 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +13 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +36 -0
- ai_edge_torch/generative/examples/gemma/gemma2.py +25 -43
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
- ai_edge_torch/generative/layers/attention.py +29 -4
- ai_edge_torch/generative/layers/normalization.py +2 -50
- ai_edge_torch/generative/layers/rotary_position_embedding.py +27 -34
- ai_edge_torch/generative/utilities/model_builder.py +14 -14
- ai_edge_torch/generative/utilities/verifier.py +4 -4
- ai_edge_torch/lowertools/torch_xla_utils.py +3 -0
- ai_edge_torch/odml_torch/export.py +1 -6
- ai_edge_torch/odml_torch/tf_integration.py +12 -50
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241226.dist-info → ai_edge_torch_nightly-0.3.0.dev20250107.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241226.dist-info → ai_edge_torch_nightly-0.3.0.dev20250107.dist-info}/RECORD +21 -21
- {ai_edge_torch_nightly-0.3.0.dev20241226.dist-info → ai_edge_torch_nightly-0.3.0.dev20250107.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241226.dist-info → ai_edge_torch_nightly-0.3.0.dev20250107.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241226.dist-info → ai_edge_torch_nightly-0.3.0.dev20250107.dist-info}/top_level.txt +0 -0
ai_edge_torch/_config.py
CHANGED
@@ -22,6 +22,18 @@ import os
 __all__ = ["config"]
 
 
+def _get_bool_env_var(name: str, default: bool) -> bool:
+  var = os.environ.get(name, "false")
+  var = var.lower().strip()
+  if var in ("y", "yes", "t", "true", "on", "1"):
+    return True
+  elif var in ("n", "no", "f", "false", "off", "0"):
+    return False
+  else:
+    logging.warning("Invalid %s value is ignored: %s.", name, var)
+    return default
+
+
 class _Config:
   """ai-edge-torch global configs."""
 
@@ -33,20 +45,25 @@ class _Config:
     To use torch_xla as the lowering backend, set environment variable
     `USE_TORCH_XLA` to "true".
     """
-    var = os.environ.get("USE_TORCH_XLA", "false")
-    var = var.lower().strip()
-    if var in ("y", "yes", "t", "true", "on", "1"):
-      return True
-    elif var in ("n", "no", "f", "false", "off", "0"):
-      return False
-    else:
-      logging.warning("Invalid USE_TORCH_XLA value is ignored: %s.", var)
-      return False
+    return _get_bool_env_var("USE_TORCH_XLA", default=False)
 
   @property
   def in_oss(self) -> bool:
     """True if the code is not running in google internal environment."""
     return True
 
+  @property
+  def enable_group_norm_composite(self) -> bool:
+    """True if lowering group norm in StableHLO composite.
+
+    Currently only supports NHWC group norm generated by
+    OptimizeLayoutTransposesPass.
+    """
+    return _get_bool_env_var("ENABLE_GROUP_NORM_COMPOSITE", default=False)
+
+  @enable_group_norm_composite.setter
+  def enable_group_norm_composite(self, value: bool):
+    os.environ["ENABLE_GROUP_NORM_COMPOSITE"] = "y" if value else "n"
+
 
 config = _Config()
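Since the new flag is backed by an environment variable, the setter and getter round-trip through `os.environ`. A minimal usage sketch, assuming the wheel above is installed:

    import ai_edge_torch

    # The setter writes "y"/"n" into ENABLE_GROUP_NORM_COMPOSITE; the getter
    # re-parses it with _get_bool_env_var, so the value also reaches any
    # subprocess that inherits the environment.
    ai_edge_torch.config.enable_group_norm_composite = True
    assert ai_edge_torch.config.enable_group_norm_composite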
ai_edge_torch/_convert/conversion.py
CHANGED
@@ -14,9 +14,9 @@
 # ==============================================================================
 
 import logging
-import os
 from typing import Any, Literal, Optional, Union
 
+import ai_edge_torch
 from ai_edge_torch import fx_pass_base
 from ai_edge_torch import lowertools
 from ai_edge_torch import model
@@ -26,8 +26,6 @@ from ai_edge_torch.generative import fx_passes as generative_fx_passes
 from ai_edge_torch.quantize import quant_config as qcfg
 import torch
 
-os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"
-
 
 def _run_convert_passes(
     exported_program: torch.export.ExportedProgram,
@@ -35,21 +33,27 @@ def _run_convert_passes(
   exported_program = generative_fx_passes.run_generative_passes(
       exported_program
   )
-  [15 removed lines not preserved in this view]
+
+  passes = [
+      fx_passes.BuildInterpolateCompositePass(),
+      fx_passes.CanonicalizePass(),
+      fx_passes.OptimizeLayoutTransposesPass(),
+      fx_passes.CanonicalizePass(),
+      fx_passes.BuildAtenCompositePass(),
+      fx_passes.CanonicalizePass(),
+      fx_passes.RemoveNonUserOutputsPass(),
+      fx_passes.CanonicalizePass(),
+  ]
+
+  # Debuginfo is not injected automatically by odml_torch. Only inject
+  # debuginfo via fx pass when using torch_xla.
+  if ai_edge_torch.config.use_torch_xla:
+    passes += [
+        fx_passes.InjectMlirDebuginfoPass(),
+        fx_passes.CanonicalizePass(),
+    ]
+
+  exported_program = fx_pass_base.run_passes(exported_program, passes)
   return exported_program
 
 
ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py
CHANGED
@@ -62,6 +62,7 @@ def _wrap_call_function_node_with_debuginfo_writer(node: torch.fx.GraphModule):
 
 
 class InjectMlirDebuginfoPass(fx_pass_base.PassBase):
+  """DEPRECATED: Debuginfo is injected automatically by odml_torch."""
 
   def call(self, graph_module: torch.fx.GraphModule):
    for node in graph_module.graph.nodes:
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py
CHANGED
@@ -17,6 +17,7 @@
 import dataclasses
 import operator
 
+import ai_edge_torch
 from ai_edge_torch import lowertools
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_rewrite
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import utils
@@ -155,6 +156,7 @@ def _qdq_layout_sensitive_inputs_getter(node: Node):
 @layout_sensitive_inputs_getters.register(
     aten._native_batch_norm_legit_no_training
 )
+@layout_sensitive_inputs_getters.register(aten.group_norm)
 @layout_sensitive_inputs_getters.register(aten.native_group_norm)
 def _first_arg_getter(node):
   return [node.args[0]]
@@ -188,6 +190,17 @@ def _aten_norm_checker(node):
   return NHWCable(can_be=len(val[0].shape) == 4, must_be=False)
 
 
+@nhwcable_node_checkers.register(aten.group_norm)
+def _aten_group_norm_checker(node):
+  val = node.meta.get("val")
+  if not hasattr(val, "shape"):
+    return NHWCable(can_be=False, must_be=False)
+
+  can_be = len(val.shape) == 4
+  must_be = can_be and ai_edge_torch.config.enable_group_norm_composite
+  return NHWCable(can_be=can_be, must_be=must_be)
+
+
 @nhwcable_node_checkers.register(aten.native_group_norm)
 def _aten_native_group_norm_checker(node):
   val = node.meta.get("val")
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py
CHANGED
@@ -16,6 +16,7 @@
 
 import operator
 
+import ai_edge_torch
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import op_func_registry
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import utils
@@ -23,6 +24,7 @@ import torch
 import torch.utils._pytree as pytree
 
 aten = torch.ops.aten
+StableHLOCompositeBuilder = ai_edge_torch.hlfb.StableHLOCompositeBuilder
 
 __all__ = ["rewrite_nhwc_node", "has_nhwc_rewriter"]
 
@@ -342,6 +344,39 @@ def _aten__native_batch_norm_legit_no_training(node):
   node.target = batch_norm
 
 
+@rewriters.register(aten.group_norm.default)
+def _aten_group_norm(node):
+  def group_norm(input, num_groups: int, weight=None, bias=None, eps=1e-5):
+    is_composite_supported = (
+        ai_edge_torch.config.enable_group_norm_composite
+        and weight is not None
+        and bias is not None
+    )
+
+    builder = None
+    if is_composite_supported:
+      builder = StableHLOCompositeBuilder(
+          name="odml.group_norm",
+          attr={
+              "num_groups": num_groups,
+              "epsilon": eps,
+              "reduction_axes": [3],
+              "channel_axis": 3,
+          },
+      )
+      input, weight, bias = builder.mark_inputs(input, weight, bias)
+
+    input = utils.tensor_to_nchw(input)
+    output = aten.group_norm.default(input, num_groups, weight, bias, eps=eps)
+    output = utils.tensor_to_nhwc(output)
+
+    if builder is not None:
+      output = builder.mark_outputs(output)
+    return output
+
+  node.target = group_norm
+
+
 @rewriters.register(aten.native_group_norm.default)
 def _aten_native_group_norm(node):
 
@@ -354,6 +389,7 @@ def _aten_native_group_norm(node):
       flattened_inner_size: int,
       num_groups: int,
       eps: float,
+      **kwargs,
   ):
     input_reshaped = torch.reshape(
         input,
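The `odml.group_norm` composite marks NHWC tensors (`channel_axis: 3`) while the actual computation still goes through the NCHW `aten.group_norm` kernel. A minimal sketch of that layout round-trip, with hypothetical shapes:

    import torch
    import torch.nn.functional as F

    x_nhwc = torch.randn(1, 8, 8, 32)     # (N, H, W, C), as the pass produces
    w, b = torch.ones(32), torch.zeros(32)
    x_nchw = x_nhwc.permute(0, 3, 1, 2)   # to (N, C, H, W) for the aten kernel
    y_nchw = F.group_norm(x_nchw, 4, weight=w, bias=b, eps=1e-5)
    y_nhwc = y_nchw.permute(0, 2, 3, 1)   # back to NHWC, channel axis 3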
ai_edge_torch/generative/examples/gemma/gemma2.py
CHANGED
@@ -15,14 +15,13 @@
 
 """Example of building a Gemma2 model."""
 
-from typing import List, Optional, Tuple
+from typing import Optional, Tuple
 
 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -104,12 +103,17 @@ class Gemma2(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-    )
     # Gemma2 has same hyper parameters for each layer except for attention
     # types. Use the first layer.
     attn_config = config.block_config(0).attn_config
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
+        base=attn_config.rotary_base,
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.kv_cache_max,
+    )
     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
         size=config.kv_cache_max,
         window_size=attn_config.sliding_window_size,
@@ -136,48 +140,29 @@ class Gemma2(nn.Module):
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
-
-    # token embeddings of shape (b, t, n_embd)
-    input_embeds = self.tok_embedding(tokens)
-    # RoPE parameters are the same for all blocks. Use the first layer.
-    attn_config = self.config.block_config(0).attn_config
-    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
-    rope = rotary_pos_emb.build_rope(
-        input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
-    )
-    mask = [self.get_attention_mask(
-        self.config.block_config(i).attn_config.attn_type, input_pos
-    ) for i in range(self.config.num_layers)]
-
-    return self._forward_with_embeds(
-        input_embeds, rope, mask, input_pos, kv_cache, export_config
-    )
-
-  def _forward_with_embeds(
-      self,
-      input_embeds: torch.Tensor,
-      rope: Tuple[torch.Tensor, torch.Tensor],
-      mask: List[torch.Tensor],
-      input_pos: torch.Tensor,
-      kv_cache: kv_utils.KVCache,
-      export_config: Optional[model_builder.ExportConfig] = None,
-  ) -> dict[torch.Tensor, kv_utils.KVCache]:
-    """Forwards the model with input embeddings."""
     assert len(self.transformer_blocks) == len(kv_cache.caches), (
         "The number of transformer blocks and the number of KV cache entries"
         " must be the same."
    )
 
-  [4 removed lines not preserved in this view]
+    cos, sin = self.rope_cache
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+
+    # token embeddings of shape (b, t, n_embd)
+    x = self.tok_embedding(tokens)
+    x = x * (self.config.embedding_dim**0.5)
+
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
+      mask = self.get_attention_mask(
+          block.config.attn_config.attn_type, input_pos
+      )
       kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x,
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
-
-      updated_kv_cache = kv_utils.KVCache(tuple(
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     if export_config is not None:
       if (
@@ -243,13 +228,11 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   )
 
   num_layers = 26
-  embedding_dim = 2304
   config = cfg.ModelConfig(
       vocab_size=256000,
       num_layers=num_layers,
       max_seq_len=8192,
-      embedding_dim=embedding_dim,
-      embedding_scale=embedding_dim**0.5,
+      embedding_dim=2304,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
@@ -266,7 +249,6 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
   config.embedding_dim = 128
-  config.embedding_scale = config.embedding_dim**0.5
   config.block_configs = config.block_configs[: config.num_layers]
   for block_config in config.block_configs:
     block_config.attn_config.num_heads = 4
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py
CHANGED
@@ -72,14 +72,14 @@ class ToyModelWithKVCache(torch.nn.Module):
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.max_seq_len]
 
-
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
       kv_entry = kv_cache.caches[i] if kv_cache else None
       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
-
+        updated_kv_entires.append(kv_entry)
 
-    updated_kv_cache = kv_utils.KVCache(tuple(
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     if export_config is not None:
       if (
ai_edge_torch/generative/layers/attention.py
CHANGED
@@ -26,6 +26,33 @@ import torch
 from torch import nn
 
 
+def _embed_rope(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    n_elem: int,
+    rope: Tuple[torch.Tensor, torch.Tensor],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+  """Embed rotary positional embedding for query and key.
+
+  Args:
+    q (torch.Tensor): query tensor.
+    k (torch.Tensor): key tensor.
+    n_elem (int): number of elements to embed rotary positional embedding.
+    rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
+  """
+  if n_elem > 0:
+    cos, sin = rope
+    q_roped = rotary_pos_emb.apply_rope(
+        q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
+    )
+    k_roped = rotary_pos_emb.apply_rope(
+        k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
+    )
+    q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
+    k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
+  return q, k
+
+
 class TransformerBlock(nn.Module):
 
   def __init__(
@@ -211,8 +238,7 @@ class CausalSelfAttention(nn.Module):
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      cos, sin = rope
-      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
+      q, k = _embed_rope(q, k, n_elem, rope)
 
     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
@@ -348,8 +374,7 @@ class CrossAttention(nn.Module):
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      cos, sin = rope
-      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
+      q, k = _embed_rope(q, k, n_elem, rope)
 
     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
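`_embed_rope` rotates only the first `n_elem` channels of the query and key when `rotary_percentage < 1`, then concatenates the untouched tail back on. A sketch of that split, with hypothetical sizes and a stand-in for the rotation:

    import torch

    head_dim = 8
    n_elem = int(0.5 * head_dim)           # rotary_percentage = 0.5
    q = torch.randn(1, 3, 2, head_dim)     # (B, T, n_heads, head_dim)
    q_roped = -q[..., :n_elem]             # stand-in for the rotated slice
    q_out = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
    assert torch.equal(q_out[..., n_elem:], q[..., n_elem:])  # tail unchanged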
ai_edge_torch/generative/layers/normalization.py
CHANGED
@@ -80,6 +80,7 @@ class RMSNorm(torch.nn.Module):
     output = self._norm(x.float()).type_as(x)
     return output * w
 
+
 class GroupNorm(torch.nn.Module):
 
   def __init__(
@@ -115,16 +116,7 @@ class GroupNorm(torch.nn.Module):
     Returns:
       torch.Tensor: output tensor after applying GroupNorm.
     """
-    if self.enable_hlfb:
-      return group_norm_with_hlfb(
-          x,
-          self.weight,
-          self.bias,
-          self.group_num,
-          self.eps,
-      )
-    else:
-      return F.group_norm(x, self.group_num, self.weight, self.bias, self.eps)
+    return F.group_norm(x, self.group_num, self.weight, self.bias, self.eps)
 
 
 class LayerNorm(torch.nn.Module):
@@ -169,46 +161,6 @@ class LayerNorm(torch.nn.Module):
     )
 
 
-def group_norm_with_hlfb(
-    x: torch.Tensor,
-    w: torch.Tensor,
-    b: torch.Tensor,
-    num_groups: int,
-    eps: float,
-):
-  """Group Normalization with high-level function boundary enabled.
-
-  Args:
-    x (torch.Tensor): Input tensor for Group Normalization, with BCHW shape.
-    w (torch.Tensor): The weight tensor for the normalization.
-    b (torch.Tensor): The bias tensor for the normalization.
-    num_groups (int): Number of groups to separate the channels into.
-    eps (float): A small float value to ensure numerical stability.
-
-  Returns:
-    The output tensor of Group Normalization.
-  """
-  x = torch.permute(x, (0, 2, 3, 1))
-
-  builder = StableHLOCompositeBuilder(
-      name="odml.group_norm",
-      attr={
-          "num_groups": num_groups,
-          "epsilon": eps,
-          "reduction_axes": [3],
-          "channel_axis": 3,
-      },
-  )
-  x, w, b = builder.mark_inputs(x, w, b)
-  x = torch.permute(x, (0, 3, 1, 2))
-  y = F.group_norm(x, num_groups, weight=w, bias=b, eps=eps)
-  y = torch.permute(y, (0, 2, 3, 1))
-  y = builder.mark_outputs(y)
-
-  y = torch.permute(y, (0, 3, 1, 2))
-  return y
-
-
 def rms_norm_with_hlfb(
     x: torch.Tensor,
     w: torch.Tensor,
ai_edge_torch/generative/layers/rotary_position_embedding.py
CHANGED
@@ -32,64 +32,57 @@ def apply_rope(
   """
   x = x.transpose(1, 2)
   head_size = x.size(-1)
-  x1
-
-
-  roped =
+  x1 = x[..., : head_size // 2]  # (B, nh, T, hs/2)
+  x2 = x[..., head_size // 2 :]  # (B, nh, T, hs/2)
+  rotated = torch.cat((-x2, x1), dim=-1)  # (B, nh, T, hs)
+  roped = (x * cos) + (rotated * sin)
   return roped.transpose(1, 2).type_as(x)
 
 
-def build_rope(
+def apply_rope_inline(
+    q: torch.Tensor,
+    k: torch.Tensor,
     input_pos: torch.Tensor,
     n_elem: int,
-    head_dim: int,
     base: int = 10_000,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Computes rotary positional embedding
+  """Computes rotary positional embedding inline for a query and key.
 
   Args:
+    q: the query tensor.
+    k: the key tensor.
     input_pos: the sequence indices for the query and key
     n_elem: number of elements of the head dimension for RoPE computation
-    base: the base of the exponentiated value for RoPE.
 
   Returns:
-
+    output the RoPE'd query and key.
   """
 
   if n_elem <= 0:
-    return
+    return q, k
 
   theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
   freq_exponents = (2.0 / n_elem) * torch.arange(
-      head_dim // 2, dtype=torch.float32
+      q.shape[-1] // 2, dtype=torch.float32
   )
   timescale = float(base) ** freq_exponents
   radians = input_pos.clone().unsqueeze(0).unsqueeze(-1) / timescale.unsqueeze(
       0
   ).unsqueeze(0)
-  cos = torch.cos(radians)
-  sin = torch.sin(radians)
-  return cos, sin
-
+  cos = torch.cos(radians).type_as(q)
+  sin = torch.sin(radians).type_as(q)
 
-def apply_rope_inline(
-  [9 removed lines not preserved in this view]
-    k: the key tensor.
-    cos: the cosine tensor.
-    sin: the sine tensor.
-
-  Returns:
-    output the RoPE'd query and key.
-  """
+  def apply(x, sin, cos):
+    x = x.transpose(1, 2)
+    b, h, s, d = x.shape
+    ans = torch.split(x, d // 2, dim=-1)
+    x1, x2 = ans
+    left = x1 * cos - x2 * sin
+    right = x2 * cos + x1 * sin
+    res = torch.cat([left, right], dim=-1)
+    res = res.transpose(1, 2)
+    return res
 
-  q_roped =
-  k_roped =
+  q_roped = apply(q, sin, cos)
+  k_roped = apply(k, sin, cos)
   return q_roped, k_roped
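Both `apply_rope` (rotate-half across the full width) and the inner `apply` helper in the new `apply_rope_inline` (explicit half split) compute the same 2-D rotation of channel pairs. A toy check of that identity:

    import torch

    x = torch.tensor([1.0, 2.0, 3.0, 4.0])
    x1, x2 = x[:2], x[2:]
    rotated = torch.cat((-x2, x1))        # the "rotate half" trick
    theta = torch.tensor(0.3)
    cos, sin = torch.cos(theta), torch.sin(theta)
    roped = x * cos + rotated * sin
    # Same result as the split-half form used in apply():
    left, right = x1 * cos - x2 * sin, x2 * cos + x1 * sin
    assert torch.allclose(roped, torch.cat([left, right]))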
ai_edge_torch/generative/utilities/model_builder.py
CHANGED
@@ -24,7 +24,6 @@ from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
 from torch import nn
@@ -86,6 +85,13 @@ class DecoderOnlyModel(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = config.block_config(0).attn_config
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
+        base=attn_config.rotary_base,
+    )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
     )
@@ -107,22 +113,16 @@ class DecoderOnlyModel(nn.Module):
 
     # token embeddings of shape (b, t, n_embd)
     input_embeds = self.tok_embedding(tokens)
-
-    # ROPE parameters for all attn_configs are the same. Take the first one.
-    attn_config = self.config.block_config(0).attn_config
-    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
-    rope = rotary_pos_emb.build_rope(
-        input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
-    )
-
+    cos, sin = self.rope_cache
+    rope = (cos.index_select(0, input_pos), sin.index_select(0, input_pos))
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.kv_cache_max]
 
-    return self._forward_with_embeds(
+    return self.forward_with_embeds(
        input_embeds, rope, mask, input_pos, kv_cache, export_config
     )
 
-  def _forward_with_embeds(
+  def forward_with_embeds(
      self,
      input_embeds: torch.Tensor,
      rope: Tuple[torch.Tensor, torch.Tensor],
@@ -141,13 +141,13 @@ class DecoderOnlyModel(nn.Module):
     if self.config.embedding_scale is not None:
       x = x * self.config.embedding_scale
 
-
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
       kv_entry = kv_cache.caches[i] if kv_cache else None
       x, kv_entry = block(x, rope, mask, input_pos, kv_entry)
       if kv_entry:
-
-    updated_kv_cache = kv_utils.KVCache(tuple(
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))
 
     if export_config is not None:
       if (
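Precomputing `build_rope_cache` in `__init__` replaces per-call trigonometry with a table lookup, so the forward pass only needs `index_select`. A sketch of the pattern, assuming one cache row per position (the exact layout of `attn_utils.build_rope_cache` is not shown in this diff):

    import torch

    max_len, dim, base = 16, 4, 10_000
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    angles = torch.arange(max_len).float()[:, None] * inv_freq[None, :]
    cos_cache, sin_cache = torch.cos(angles), torch.sin(angles)

    input_pos = torch.tensor([3, 4, 5])   # active positions this step
    cos = cos_cache.index_select(0, input_pos)
    sin = sin_cache.index_select(0, input_pos)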
ai_edge_torch/generative/utilities/verifier.py
CHANGED
@@ -16,7 +16,7 @@
 """Common utility functions to verify the reauthored models."""
 
 import logging
-from typing import Any, List
+from typing import Any, List, Optional
 
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.utilities.model_builder import ExportConfig
@@ -134,7 +134,7 @@ class ReauthoredModelWrapper(ModelWrapper):
       prompts: torch.Tensor,
       max_new_tokens: int,
       pixel_values: torch.Tensor = None,
-      eos_token_id: int =
+      eos_token_id: Optional[int] = None,
   ) -> torch.IntTensor:
     input_ids = prompts[0].int().tolist()
     tokens = torch.tensor([input_ids])
@@ -146,7 +146,7 @@ class ReauthoredModelWrapper(ModelWrapper):
       )
       generated_token = logits[0][-1].argmax().item()
       input_ids.append(generated_token)
-      if generated_token == eos_token_id:
+      if eos_token_id is not None and generated_token == eos_token_id:
        break
      tokens = torch.tensor([[generated_token]])
      input_pos = torch.tensor([len(input_ids) - 1])
@@ -253,7 +253,7 @@ def verify_model_with_prompts(
   outputs_reauthored = reauthored_model.generate(
       prompt_tokens,
       max_new_tokens,
-      eos_token_id=tokenizer.tokenizer
+      eos_token_id=getattr(tokenizer.tokenizer, "eos_token_id", None),
   )
   response_reauthored = tokenizer.decode(outputs_reauthored[0])
   logging.info("outputs from reauthored model: [[%s]]", response_reauthored)
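The `getattr(..., "eos_token_id", None)` lookup combined with the `eos_token_id is not None` guard makes early stopping optional: a tokenizer without that attribute simply never ends generation early. A tiny sketch with a hypothetical tokenizer:

    class _NoEosTokenizer:  # hypothetical stand-in lacking eos_token_id
        pass

    eos = getattr(_NoEosTokenizer(), "eos_token_id", None)
    generated_token = 42
    assert eos is None
    assert not (eos is not None and generated_token == eos)  # loop never breaks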
ai_edge_torch/lowertools/torch_xla_utils.py
CHANGED
@@ -27,6 +27,9 @@ if "PJRT_DEVICE" not in os.environ:
   # https://github.com/google-ai-edge/ai-edge-torch/issues/326
   os.environ["PJRT_DEVICE"] = "CPU"
 
+os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"
+
+
 from ai_edge_torch import model
 from ai_edge_torch._convert import conversion_utils
 from ai_edge_torch._convert import signature as signature_module
ai_edge_torch/odml_torch/export.py
CHANGED
@@ -202,7 +202,7 @@ class MlirLowered:
       target_version = stablehlo.get_minimum_version()
     else:
       target_version = stablehlo.get_version_from_compatibility_requirement(
-          stablehlo.StablehloCompatibilityRequirement.
+          stablehlo.StablehloCompatibilityRequirement.WEEK_12
       )
     module_bytecode = xla_extension.mlir.serialize_portable_artifact(
         self.module_bytecode, target_version
@@ -222,11 +222,6 @@ class MlirLowered:
     # Lazy importing TF when execution is needed.
     return self.tf_function(*args)
 
-  def to_flatbuffer(self):
-    from . import tf_integration
-
-    return tf_integration.mlir_to_flatbuffer(self)
-
 
 # TODO(b/331481564) Make this a ai_edge_torch FX pass.
 def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
ai_edge_torch/odml_torch/tf_integration.py
CHANGED
@@ -12,10 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""APIs to convert lowered MLIR from PyTorch to TensorFlow
+"""APIs to convert lowered MLIR from PyTorch to TensorFlow artifacts."""
 
 import re
-import tempfile
 
 import tensorflow as tf
 import torch
@@ -104,20 +103,26 @@ def _extract_call_args(
 def _wrap_as_tf_func(lowered, tf_state_dict):
   """Build tf.function from lowered and tf_state_dict."""
 
-
+  version = 6
+  if hasattr(tfxla, "call_module_maximum_supported_version"):
+    version = tfxla.call_module_maximum_supported_version()
+
+
   def tf_func(*args):
     t_outs = [torch_dtype_to_tf(sig.dtype) for sig in lowered.output_signature]
     s_outs = [_get_shape_with_dynamic(sig) for sig in lowered.output_signature]
     call_args = _extract_call_args(lowered, args, tf_state_dict)
     return tfxla.call_module(
         tuple(call_args),
-        version=
+        version=version,
         Tout=t_outs,  # dtype information
-        Sout=s_outs,  #
+        Sout=s_outs,  # shape information
         function_list=[],
-        module=lowered.
+        module=lowered.module_bytecode_vhlo,
+        has_token_input_output=False,
+        platforms=["CPU"],
     )
 
-  return
+  return tf_func
 
 
 def _make_input_signatures(
@@ -149,46 +154,3 @@ def mlir_to_tf_function(lowered: export.MlirLowered):
       _wrap_as_tf_func(lowered, tf_state_dict),
       input_signature=_make_input_signatures(lowered),
   )
-
-
-def mlir_to_flatbuffer(lowered: export.MlirLowered):
-  """Convert the MLIR lowered to a TFLite flatbuffer binary."""
-  tf_state_dict = _build_tf_state_dict(lowered)
-  signature_names = [tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
-  tf_signatures = [_make_input_signatures(lowered)]
-  tf_functions = [_wrap_as_tf_func(lowered, tf_state_dict)]
-
-  tf_module = tf.Module()
-  tf_module.f = []
-
-  for tf_sig, func in zip(tf_signatures, tf_functions):
-    tf_module.f.append(
-        tf.function(
-            func,
-            input_signature=tf_sig,
-        )
-    )
-
-  tf_module._variables = list(tf_state_dict.values())
-
-  tf_concrete_funcs = [
-      func.get_concrete_function(*tf_sig)
-      for func, tf_sig in zip(tf_module.f, tf_signatures)
-  ]
-
-  # We need to temporarily save since TFLite's from_concrete_functions does not
-  # allow providing names for each of the concrete functions.
-  with tempfile.TemporaryDirectory() as temp_dir_path:
-    tf.saved_model.save(
-        tf_module,
-        temp_dir_path,
-        signatures={
-            sig_name: tf_concrete_funcs[idx]
-            for idx, sig_name in enumerate(signature_names)
-        },
-    )
-
-    converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir_path)
-    tflite_model = converter.convert()
-
-    return tflite_model
ai_edge_torch/version.py
CHANGED
[hunk not preserved in this view]

{ai_edge_torch_nightly-0.3.0.dev20241226.dist-info → ai_edge_torch_nightly-0.3.0.dev20250107.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20241226
+Version: 0.3.0.dev20250107
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI

{ai_edge_torch_nightly-0.3.0.dev20241226.dist-info → ai_edge_torch_nightly-0.3.0.dev20250107.dist-info}/RECORD
CHANGED
@@ -1,11 +1,11 @@
 ai_edge_torch/__init__.py,sha256=rq9ZtMJLG8yYNC4tNE4rpl94UAUClZW7f4GAr6HBVDQ,1208
-ai_edge_torch/_config.py,sha256=
+ai_edge_torch/_config.py,sha256=PKtOtBOup-cM0wBdQxby6HzuhLhIC3oq-TBG8FF4znE,2161
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=X0ZEB5T3xcR8MsIE8VOHDAdHnCZTzJLBQQ9j2xZ4_qA,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/_convert/conversion.py,sha256=
+ai_edge_torch/_convert/conversion.py,sha256=_PoH0E1gbbsWhLGwDRwUtW2G_IgNzNF7pKQbn9ct6-4,5778
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
 ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
 ai_edge_torch/_convert/signature.py,sha256=rGpBNss3Y9FCCCcdBwDo16KqavJi8N5P0M_6Wnmoz6Y,2320
@@ -13,12 +13,12 @@ ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDiu
 ai_edge_torch/_convert/fx_passes/__init__.py,sha256=NVe-eGcm7j8jZpP2pcMhC8j5dVjgR1pPzyXhHdvKH4E,1267
 ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=doaww8KqrgRTD5LotBVAIRFsEqzPn9R5lcGehBJOczA,9098
 ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=qb4JBDi4Xca14JJUIcaaZQIJiyqKyHJF49jsRCIFCVA,4335
-ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=
+ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=NxT-iCOHq3r3jeZ8qhNoPXV5w8l2eRMu4yEcBri3NxY,2398
 ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=f1IUVWyhioOClsMiZzLyynoW2R17U83vA-7Q-3pGPM4,2126
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=lxnoH-WGLeiQIF8XjMGodjiZEFTxucl7g05N7MR9OPk,796
-ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=
+ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=S_Bniv6jY16oOoFUzlyECQ0I2HDjG2D1MOI-QYPk3jQ,8061
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=4RyGUwR22bZqkn_TnptenFJodc_Q43f4_SBG7gmTbos,1621
-ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=
+ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=zoAZ2TXKvxUnWnT11U4tx2uF0J5kkNXydgaW7JzfkXI,13811
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=bsYnudRlXp1PJlu4GF25KSogSkBGQPSaecBrUTONKaw,1031
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=t94Am3iPbYQekg-rrtc-jS_aDWtEgAAj7pAKHrG0-9U,10563
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=YLMttMg5PdvXTtQ8lxpKb434UGVvYVALV1-xeuH4UGc,2131
@@ -47,7 +47,7 @@ ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIX
 ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=GhwtQZ1xuMyKJl8qdxU6uKavQnlm5US9xhKJvdmgACc,2309
 ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=hsy4Gd7Inchi0p_Cc5yecH6vr9A7X4MvmQNfTt8N2sQ,2311
 ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=w8oWYibZzvEvCDyp39EYyAWmjgJljhzdYPyFCfAWxZA,3497
-ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=
+ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=whQ6DEnmhmj9hd5OyaoEI-FUNJ4m302vY3Swo_IqQcA,9285
 ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
 ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
 ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
@@ -109,7 +109,7 @@ ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=l01oYyJo77INzRwN4xqX
 ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/test_models/convert_toy_model.py,sha256=6-WaNHckq_LlXMVTh8x90MGWeWq2bu_T_XQd3w9FnGg,3261
 ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=4113jZK-Hu3kYop__WTc8Bq-bG6YzQtADbxHtYPEB4w,5036
-ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=
+ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=WMl1iuCE8So9FDnxPV0OTMzuPngQUTO61g8rfnBLyB4,4664
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=5rgbTIxHoFg8sTnzrGA_ekT-HJEt9p7Dla7cIY874jU,2338
 ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=mhJ18rb9sxrYRzv1YSzhbNs97oUZck99avZDcUO2oV8,2800
@@ -117,14 +117,14 @@ ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299f
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/attention.py,sha256=
+ai_edge_torch/generative/layers/attention.py,sha256=aOoVM1hY7qjvzVQI1-m26p_f9qoTLzXXIy8dNtU8xC4,13199
 ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
 ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
 ai_edge_torch/generative/layers/kv_cache.py,sha256=DhHIggaOQ2IAY4aRuMAuCLWZv1dBz5PYtmOEjkx9EQY,6291
 ai_edge_torch/generative/layers/model_config.py,sha256=viX51T_naJ9sPpPxPoMnSueBPYE2zxWNOD0xn0f-_bM,7510
-ai_edge_torch/generative/layers/normalization.py,sha256=
-ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=
+ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
+ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=xxWtlVsGGJkEyXC6PwznubyhJnLPEfSpHOORE_hgxss,2670
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=9jKzOfiBQ66bp1ZnVIAoREIifVNFx4aTlQeYMAx2_pA,29062
@@ -149,12 +149,12 @@ ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5l
 ai_edge_torch/generative/utilities/converter.py,sha256=hIwWUWjgPvWLATtsYYG6RWbFQWhOr2RpPlMrd-4Am9U,5959
 ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
 ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
-ai_edge_torch/generative/utilities/model_builder.py,sha256=
+ai_edge_torch/generative/utilities/model_builder.py,sha256=rfD6INxunvDVdiUfTUxD7yy0dRxL74W7kVmZsxUjpOQ,6379
 ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
 ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
-ai_edge_torch/generative/utilities/verifier.py,sha256=
+ai_edge_torch/generative/utilities/verifier.py,sha256=6lnBU9Cy5GanB8JWK3-2_VU3PxqunDWGe-SgSLba5Yw,12065
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
 ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
 ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
@@ -166,14 +166,14 @@ ai_edge_torch/lowertools/_shim.py,sha256=Mbg16tnCVK0YsHowfbpqpNX1qySuMLvpGI_-I5S
 ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
 ai_edge_torch/lowertools/odml_torch_utils.py,sha256=dxg2pBuVhSZeY2Ouc0F6nHiZilWZmpSPA7I8kGqSkVI,8282
 ai_edge_torch/lowertools/test_utils.py,sha256=mdxTlhqHABZEQ_GEmPFCL8LIAWtqRtYZUGdSY1ieZjw,1949
-ai_edge_torch/lowertools/torch_xla_utils.py,sha256=
+ai_edge_torch/lowertools/torch_xla_utils.py,sha256=1EytIw2R6dthhLhf69wN1L9BaQTeybCD0wga-PhHcMI,9518
 ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1K5vHK0RjgBFtxlE,5652
 ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
 ai_edge_torch/odml_torch/_torch_future.py,sha256=AJ0klpsbu2ZBTfiZlqSOoaYzBVITt40a1fYN8xKkEPw,3044
 ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
-ai_edge_torch/odml_torch/export.py,sha256=
+ai_edge_torch/odml_torch/export.py,sha256=sqIMXmxK_qIuVC-_DNJ6wKlIWiXq4_WOCKbSqMRFudg,13293
 ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
-ai_edge_torch/odml_torch/tf_integration.py,sha256=
+ai_edge_torch/odml_torch/tf_integration.py,sha256=NN29WeXmHZ0S1RPDFHUnBi2DEjMvAtwczStPYIsQ1w8,4849
 ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
 ai_edge_torch/odml_torch/composite/mark_tensor.py,sha256=U--rwl-XkWKgkdXCXDn6yySug8FR66o1YFUAIoSaWW4,3523
 ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py,sha256=2Y52E_gLeoXpMcPpV-svXsgN3JbEIjnPVjm0xkpTUdQ,3319
@@ -203,8 +203,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20241226.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20241226.dist-info/METADATA,sha256=
-ai_edge_torch_nightly-0.3.0.dev20241226.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_torch_nightly-0.3.0.dev20241226.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20241226.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20250107.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20250107.dist-info/METADATA,sha256=p2F-coQaq7CbpMOkQLVnpFB01cCKqftVRGZ4dCVu8Ck,1966
+ai_edge_torch_nightly-0.3.0.dev20250107.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.3.0.dev20250107.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20250107.dist-info/RECORD,,
File without changes
File without changes