ai-edge-torch-nightly 0.3.0.dev20241226__py3-none-any.whl → 0.3.0.dev20250107__py3-none-any.whl
- ai_edge_torch/_config.py +26 -9
- ai_edge_torch/_convert/conversion.py +22 -18
- ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py +1 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py +13 -0
- ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py +36 -0
- ai_edge_torch/generative/examples/gemma/gemma2.py +25 -43
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +3 -3
- ai_edge_torch/generative/layers/attention.py +29 -4
- ai_edge_torch/generative/layers/normalization.py +2 -50
- ai_edge_torch/generative/layers/rotary_position_embedding.py +27 -34
- ai_edge_torch/generative/utilities/model_builder.py +14 -14
- ai_edge_torch/generative/utilities/verifier.py +4 -4
- ai_edge_torch/lowertools/torch_xla_utils.py +3 -0
- ai_edge_torch/odml_torch/export.py +1 -6
- ai_edge_torch/odml_torch/tf_integration.py +12 -50
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241226.dist-info → ai_edge_torch_nightly-0.3.0.dev20250107.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20241226.dist-info → ai_edge_torch_nightly-0.3.0.dev20250107.dist-info}/RECORD +21 -21
- {ai_edge_torch_nightly-0.3.0.dev20241226.dist-info → ai_edge_torch_nightly-0.3.0.dev20250107.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241226.dist-info → ai_edge_torch_nightly-0.3.0.dev20250107.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20241226.dist-info → ai_edge_torch_nightly-0.3.0.dev20250107.dist-info}/top_level.txt +0 -0
ai_edge_torch/_config.py
CHANGED
@@ -22,6 +22,18 @@ import os
 __all__ = ["config"]


+def _get_bool_env_var(name: str, default: bool) -> bool:
+  var = os.environ.get(name, "false")
+  var = var.lower().strip()
+  if var in ("y", "yes", "t", "true", "on", "1"):
+    return True
+  elif var in ("n", "no", "f", "false", "off", "0"):
+    return False
+  else:
+    logging.warning("Invalid %s value is ignored: %s.", name, var)
+    return default
+
+
 class _Config:
   """ai-edge-torch global configs."""

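The new helper centralizes boolean env-var parsing. Note that, as written, an unset variable resolves to the string "false" (the lookup default), so the `default` argument only applies when the variable is set to an unrecognized value. A standalone sketch of that behavior (a reimplementation for illustration, not an import from the wheel):

import logging
import os

def get_bool_env_var(name: str, default: bool) -> bool:  # same logic as _get_bool_env_var
  var = os.environ.get(name, "false").lower().strip()
  if var in ("y", "yes", "t", "true", "on", "1"):
    return True
  elif var in ("n", "no", "f", "false", "off", "0"):
    return False
  logging.warning("Invalid %s value is ignored: %s.", name, var)
  return default

os.environ["USE_TORCH_XLA"] = "ON"
print(get_bool_env_var("USE_TORCH_XLA", default=False))  # True: "on" is accepted
print(get_bool_env_var("SOME_UNSET_VAR", default=True))  # False, not True: unset reads as "false"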
@@ -33,20 +45,25 @@ class _Config:
     To use torch_xla as the lowering backend, set environment variable
     `USE_TORCH_XLA` to "true".
     """
-    var = os.environ.get("USE_TORCH_XLA", "false")
-    var = var.lower().strip()
-    if var in ("y", "yes", "t", "true", "on", "1"):
-      return True
-    elif var in ("n", "no", "f", "false", "off", "0"):
-      return False
-    else:
-      logging.warning("Invalid USE_TORCH_XLA value is ignored: %s.", var)
-      return False
+    return _get_bool_env_var("USE_TORCH_XLA", default=False)

   @property
   def in_oss(self) -> bool:
     """True if the code is not running in google internal environment."""
     return True

+  @property
+  def enable_group_norm_composite(self) -> bool:
+    """True if lowering group norm in StableHLO composite.
+
+    Currently only supports NHWC group norm generated by
+    OptimizeLayoutTransposesPass.
+    """
+    return _get_bool_env_var("ENABLE_GROUP_NORM_COMPOSITE", default=False)
+
+  @enable_group_norm_composite.setter
+  def enable_group_norm_composite(self, value: bool):
+    os.environ["ENABLE_GROUP_NORM_COMPOSITE"] = "y" if value else "n"
+

 config = _Config()
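Because the getter re-reads the environment on every access and the setter writes "y"/"n" back, the flag round-trips through os.environ. A hedged usage sketch (assumes the wheel is installed; the property is shown in the diff above):

import ai_edge_torch

ai_edge_torch.config.enable_group_norm_composite = True   # writes ENABLE_GROUP_NORM_COMPOSITE=y
assert ai_edge_torch.config.enable_group_norm_composite
ai_edge_torch.config.enable_group_norm_composite = False  # writes ENABLE_GROUP_NORM_COMPOSITE=n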
ai_edge_torch/_convert/conversion.py
CHANGED
@@ -14,9 +14,9 @@
 # ==============================================================================

 import logging
-import os
 from typing import Any, Literal, Optional, Union

+import ai_edge_torch
 from ai_edge_torch import fx_pass_base
 from ai_edge_torch import lowertools
 from ai_edge_torch import model
@@ -26,8 +26,6 @@ from ai_edge_torch.generative import fx_passes as generative_fx_passes
 from ai_edge_torch.quantize import quant_config as qcfg
 import torch

-os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"
-

 def _run_convert_passes(
     exported_program: torch.export.ExportedProgram,
@@ -35,21 +33,27 @@ def _run_convert_passes(
   exported_program = generative_fx_passes.run_generative_passes(
       exported_program
   )
-  exported_program = fx_pass_base.run_passes(
-      exported_program,
-      [
-          fx_passes.BuildInterpolateCompositePass(),
-          fx_passes.CanonicalizePass(),
-          fx_passes.OptimizeLayoutTransposesPass(),
-          fx_passes.CanonicalizePass(),
-          fx_passes.BuildAtenCompositePass(),
-          fx_passes.CanonicalizePass(),
-          fx_passes.RemoveNonUserOutputsPass(),
-          fx_passes.CanonicalizePass(),
-          fx_passes.InjectMlirDebuginfoPass(),
-          fx_passes.CanonicalizePass(),
-      ],
-  )
+
+  passes = [
+      fx_passes.BuildInterpolateCompositePass(),
+      fx_passes.CanonicalizePass(),
+      fx_passes.OptimizeLayoutTransposesPass(),
+      fx_passes.CanonicalizePass(),
+      fx_passes.BuildAtenCompositePass(),
+      fx_passes.CanonicalizePass(),
+      fx_passes.RemoveNonUserOutputsPass(),
+      fx_passes.CanonicalizePass(),
+  ]
+
+  # Debuginfo is not injected automatically by odml_torch. Only inject
+  # debuginfo via fx pass when using torch_xla.
+  if ai_edge_torch.config.use_torch_xla:
+    passes += [
+        fx_passes.InjectMlirDebuginfoPass(),
+        fx_passes.CanonicalizePass(),
+    ]
+
+  exported_program = fx_pass_base.run_passes(exported_program, passes)
   return exported_program

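The refactor turns a fixed pass list into one assembled conditionally. The same pattern in a self-contained sketch (the string names here are stand-ins for the real pass classes):

def build_pipeline(use_torch_xla: bool) -> list:
  passes = ["interpolate-composite", "canonicalize", "layout-transposes",
            "canonicalize", "aten-composite", "canonicalize",
            "remove-non-user-outputs", "canonicalize"]
  if use_torch_xla:  # debuginfo injection is only needed on the torch_xla path
    passes += ["inject-mlir-debuginfo", "canonicalize"]
  return passes

assert "inject-mlir-debuginfo" not in build_pipeline(use_torch_xla=False)
assert "inject-mlir-debuginfo" in build_pipeline(use_torch_xla=True)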
ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py
CHANGED
@@ -62,6 +62,7 @@ def _wrap_call_function_node_with_debuginfo_writer(node: torch.fx.GraphModule):


 class InjectMlirDebuginfoPass(fx_pass_base.PassBase):
+  """DEPRECATED: Debuginfo is injected automatically by odml_torch."""

   def call(self, graph_module: torch.fx.GraphModule):
     for node in graph_module.graph.nodes:
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py
CHANGED
@@ -17,6 +17,7 @@
 import dataclasses
 import operator

+import ai_edge_torch
 from ai_edge_torch import lowertools
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_rewrite
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import utils
@@ -155,6 +156,7 @@ def _qdq_layout_sensitive_inputs_getter(node: Node):
 @layout_sensitive_inputs_getters.register(
     aten._native_batch_norm_legit_no_training
 )
+@layout_sensitive_inputs_getters.register(aten.group_norm)
 @layout_sensitive_inputs_getters.register(aten.native_group_norm)
 def _first_arg_getter(node):
   return [node.args[0]]
@@ -188,6 +190,17 @@ def _aten_norm_checker(node):
   return NHWCable(can_be=len(val[0].shape) == 4, must_be=False)


+@nhwcable_node_checkers.register(aten.group_norm)
+def _aten_group_norm_checker(node):
+  val = node.meta.get("val")
+  if not hasattr(val, "shape"):
+    return NHWCable(can_be=False, must_be=False)
+
+  can_be = len(val.shape) == 4
+  must_be = can_be and ai_edge_torch.config.enable_group_norm_composite
+  return NHWCable(can_be=can_be, must_be=must_be)
+
+
 @nhwcable_node_checkers.register(aten.native_group_norm)
 def _aten_native_group_norm_checker(node):
   val = node.meta.get("val")
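The checker only inspects the fake-tensor metadata attached to the FX node. A standalone sketch of the same shape test (the tensor stands in for node.meta["val"]):

import torch

val = torch.empty(1, 8, 16, 16)  # stand-in for node.meta["val"], rank-4 NCHW
can_be = hasattr(val, "shape") and len(val.shape) == 4
must_be = can_be and True        # True models enable_group_norm_composite being set
print(can_be, must_be)           # True True -> the node is forced into NHWC layout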
ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py
CHANGED
@@ -16,6 +16,7 @@

 import operator

+import ai_edge_torch
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import layout_mark
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import op_func_registry
 from ai_edge_torch._convert.fx_passes.optimize_layout_transposes_pass import utils
@@ -23,6 +24,7 @@ import torch
 import torch.utils._pytree as pytree

 aten = torch.ops.aten
+StableHLOCompositeBuilder = ai_edge_torch.hlfb.StableHLOCompositeBuilder

 __all__ = ["rewrite_nhwc_node", "has_nhwc_rewriter"]

@@ -342,6 +344,39 @@ def _aten__native_batch_norm_legit_no_training(node):
   node.target = batch_norm


+@rewriters.register(aten.group_norm.default)
+def _aten_group_norm(node):
+  def group_norm(input, num_groups: int, weight=None, bias=None, eps=1e-5):
+    is_composite_supported = (
+        ai_edge_torch.config.enable_group_norm_composite
+        and weight is not None
+        and bias is not None
+    )
+
+    builder = None
+    if is_composite_supported:
+      builder = StableHLOCompositeBuilder(
+          name="odml.group_norm",
+          attr={
+              "num_groups": num_groups,
+              "epsilon": eps,
+              "reduction_axes": [3],
+              "channel_axis": 3,
+          },
+      )
+      input, weight, bias = builder.mark_inputs(input, weight, bias)
+
+    input = utils.tensor_to_nchw(input)
+    output = aten.group_norm.default(input, num_groups, weight, bias, eps=eps)
+    output = utils.tensor_to_nhwc(output)
+
+    if builder is not None:
+      output = builder.mark_outputs(output)
+    return output
+
+  node.target = group_norm
+
+
 @rewriters.register(aten.native_group_norm.default)
 def _aten_native_group_norm(node):

@@ -354,6 +389,7 @@ def _aten_native_group_norm(node):
       flattened_inner_size: int,
       num_groups: int,
       eps: float,
+      **kwargs,
   ):
     input_reshaped = torch.reshape(
         input,
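The data movement inside the rewritten group_norm, as a standalone torch sketch (without the composite marking, which only makes sense inside the export pipeline):

import torch
import torch.nn.functional as F

x_nhwc = torch.randn(1, 16, 16, 8)   # the rewritten graph keeps tensors in NHWC
x_nchw = x_nhwc.permute(0, 3, 1, 2)  # utils.tensor_to_nchw equivalent
y_nchw = F.group_norm(x_nchw, num_groups=4,
                      weight=torch.ones(8), bias=torch.zeros(8), eps=1e-5)
y_nhwc = y_nchw.permute(0, 2, 3, 1)  # utils.tensor_to_nhwc equivalent
print(y_nhwc.shape)                  # torch.Size([1, 16, 16, 8])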
ai_edge_torch/generative/examples/gemma/gemma2.py
CHANGED
@@ -15,14 +15,13 @@

 """Example of building a Gemma2 model."""

-from typing import List, Optional, Tuple
+from typing import Optional, Tuple

 from ai_edge_torch.generative.layers import attention
 from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 from ai_edge_torch.generative.utilities import model_builder
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
@@ -104,12 +103,17 @@ class Gemma2(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
-    self.mask_cache = attn_utils.build_causal_mask_cache(
-        size=config.kv_cache_max,
-    )
     # Gemma2 has same hyper parameters for each layer except for attention
     # types. Use the first layer.
     attn_config = config.block_config(0).attn_config
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
+        base=attn_config.rotary_base,
+    )
+    self.mask_cache = attn_utils.build_causal_mask_cache(
+        size=config.kv_cache_max,
+    )
     self.sliding_window_mask_cache = attn_utils.build_sliding_window_mask_cache(
         size=config.kv_cache_max,
         window_size=attn_config.sliding_window_size,
@@ -136,48 +140,29 @@ class Gemma2(nn.Module):
         f"Cannot forward sequence of length {seq_len}, max seq length is only"
         f" {self.config.max_seq_len}"
     )
-
-    # token embeddings of shape (b, t, n_embd)
-    input_embeds = self.tok_embedding(tokens)
-    # RoPE parameters are the same for all blocks. Use the first layer.
-    attn_config = self.config.block_config(0).attn_config
-    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
-    rope = rotary_pos_emb.build_rope(
-        input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
-    )
-    mask = [self.get_attention_mask(
-        self.config.block_config(i).attn_config.attn_type, input_pos
-    ) for i in range(self.config.num_layers)]
-
-    return self._forward_with_embeds(
-        input_embeds, rope, mask, input_pos, kv_cache, export_config
-    )
-
-  def _forward_with_embeds(
-      self,
-      input_embeds: torch.Tensor,
-      rope: Tuple[torch.Tensor, torch.Tensor],
-      mask: List[torch.Tensor],
-      input_pos: torch.Tensor,
-      kv_cache: kv_utils.KVCache,
-      export_config: Optional[model_builder.ExportConfig] = None,
-  ) -> dict[torch.Tensor, kv_utils.KVCache]:
-    """Forwards the model with input embeddings."""
     assert len(self.transformer_blocks) == len(kv_cache.caches), (
         "The number of transformer blocks and the number of KV cache entries"
         " must be the same."
     )
-    if self.config.embedding_scale is not None:
-      input_embeds = input_embeds * self.config.embedding_scale
-    x = input_embeds
-    updated_kv_entries = []
+    cos, sin = self.rope_cache
+    cos = cos.index_select(0, input_pos)
+    sin = sin.index_select(0, input_pos)
+
+    # token embeddings of shape (b, t, n_embd)
+    x = self.tok_embedding(tokens)
+    x = x * (self.config.embedding_dim**0.5)
+
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
+      mask = self.get_attention_mask(
+          block.config.attn_config.attn_type, input_pos
+      )
       kv_entry = kv_cache.caches[i] if kv_cache else None
-      x, kv_entry = block(x, rope, mask[i], input_pos, kv_entry)
+      x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entries.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))

     if export_config is not None:
       if (
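The restored rope_cache precomputes cos/sin for every position up to kv_cache_max, then slices per step with index_select. A standalone sketch of that lookup (the table construction below is a generic RoPE recipe, not necessarily build_rope_cache's exact formula):

import torch

size, dim = 1024, 128  # kv_cache_max and rotary dims, values assumed
pos = torch.arange(size, dtype=torch.float32).unsqueeze(-1)
inv_freq = 1.0 / (10_000 ** (torch.arange(0, dim, 2).float() / dim))
angles = pos * inv_freq                        # (size, dim // 2)
cos, sin = torch.cos(angles), torch.sin(angles)

input_pos = torch.tensor([5, 6, 7])            # current decode positions
print(cos.index_select(0, input_pos).shape)    # torch.Size([3, 64])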
@@ -243,13 +228,11 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
   )

   num_layers = 26
-  embedding_dim = 2304
   config = cfg.ModelConfig(
       vocab_size=256000,
       num_layers=num_layers,
       max_seq_len=8192,
-      embedding_dim=embedding_dim,
-      embedding_scale=embedding_dim**0.5,
+      embedding_dim=2304,
       kv_cache_max_len=kv_cache_max_len,
       block_configs=[get_block_config(i) for i in range(num_layers)],
       final_norm_config=norm_config,
@@ -266,7 +249,6 @@ def get_fake_model_config(kv_cache_max_len: int = 128) -> cfg.ModelConfig:
   config.num_layers = 2
   config.max_seq_len = 2 * kv_cache_max_len
   config.embedding_dim = 128
-  config.embedding_scale = config.embedding_dim**0.5
   config.block_configs = config.block_configs[: config.num_layers]
   for block_config in config.block_configs:
     block_config.attn_config.num_heads = 4
ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py
CHANGED
@@ -72,14 +72,14 @@ class ToyModelWithKVCache(torch.nn.Module):
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.max_seq_len]

-    updated_kv_entries = []
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
       kv_entry = kv_cache.caches[i] if kv_cache else None
       x, kv_entry = block(x, (cos, sin), mask, input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entries.append(kv_entry)
+        updated_kv_entires.append(kv_entry)

-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))

     if export_config is not None:
       if (
ai_edge_torch/generative/layers/attention.py
CHANGED
@@ -26,6 +26,33 @@ import torch
 from torch import nn


+def _embed_rope(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    n_elem: int,
+    rope: Tuple[torch.Tensor, torch.Tensor],
+) -> Tuple[torch.Tensor, torch.Tensor]:
+  """Embed rotary positional embedding for query and key.
+
+  Args:
+    q (torch.Tensor): query tensor.
+    k (torch.Tensor): key tensor.
+    n_elem (int): number of elements to embed rotary positional embedding.
+    rope (Tuple[torch.Tensor, torch.Tensor]): the input rope tensor.
+  """
+  if n_elem > 0:
+    cos, sin = rope
+    q_roped = rotary_pos_emb.apply_rope(
+        q[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
+    )
+    k_roped = rotary_pos_emb.apply_rope(
+        k[..., :n_elem], cos.repeat(1, 2), sin.repeat(1, 2)
+    )
+    q = torch.cat((q_roped, q[..., n_elem:]), dim=-1)
+    k = torch.cat((k_roped, k[..., n_elem:]), dim=-1)
+  return q, k
+
+
 class TransformerBlock(nn.Module):

   def __init__(
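_embed_rope applies RoPE to only the first n_elem channels and passes the rest through; the cached cos/sin tables carry n_elem/2 columns, so repeat(1, 2) widens them to match the rotated slice. A shape-only sketch (all shapes assumed for illustration):

import torch

n_elem = 32                          # e.g. rotary_percentage 0.5 of head_dim 64
cos = torch.randn(16, n_elem // 2)   # (T, n_elem/2) cache slice
print(cos.repeat(1, 2).shape)        # torch.Size([16, 32]) -> matches q[..., :32]

q = torch.randn(1, 16, 4, 64)        # (B, T, heads, head_dim)
q_rot, q_pass = q[..., :n_elem], q[..., n_elem:]
print(torch.cat((q_rot, q_pass), dim=-1).shape)  # full head_dim restored after rotation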
@@ -211,8 +238,7 @@ class CausalSelfAttention(nn.Module):
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      cos, sin = rope
-      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
+      q, k = _embed_rope(q, k, n_elem, rope)

     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
@@ -348,8 +374,7 @@ class CrossAttention(nn.Module):
     if rope is not None:
       # Compute rotary positional embedding for query and key.
       n_elem = int(self.config.rotary_percentage * self.config.head_dim)
-      cos, sin = rope
-      q, k = rotary_pos_emb.apply_rope_inline(q, k, cos, sin)
+      q, k = _embed_rope(q, k, n_elem, rope)

     if kv_cache is not None:
       kv_cache = kv_utils.update(kv_cache, input_pos, k, v)
ai_edge_torch/generative/layers/normalization.py
CHANGED
@@ -80,6 +80,7 @@ class RMSNorm(torch.nn.Module):
     output = self._norm(x.float()).type_as(x)
     return output * w

+
 class GroupNorm(torch.nn.Module):

   def __init__(
@@ -115,16 +116,7 @@ class GroupNorm(torch.nn.Module):
     Returns:
       torch.Tensor: output tensor after applying GroupNorm.
     """
-    if self.enable_hlfb:
-      return group_norm_with_hlfb(
-          x,
-          self.weight,
-          self.bias,
-          self.group_num,
-          self.eps,
-      )
-    else:
-      return F.group_norm(x, self.group_num, self.weight, self.bias, self.eps)
+    return F.group_norm(x, self.group_num, self.weight, self.bias, self.eps)


 class LayerNorm(torch.nn.Module):
@@ -169,46 +161,6 @@ class LayerNorm(torch.nn.Module):
     )


-def group_norm_with_hlfb(
-    x: torch.Tensor,
-    w: torch.Tensor,
-    b: torch.Tensor,
-    num_groups: int,
-    eps: float,
-):
-  """Group Normalization with high-level function boundary enabled.
-
-  Args:
-    x (torch.Tensor): Input tensor for Group Normalization, with BCHW shape.
-    w (torch.Tensor): The weight tensor for the normalization.
-    b (torch.Tensor): The bias tensor for the normalization.
-    num_groups (int): Number of groups to separate the channels into.
-    eps (float): A small float value to ensure numerical stability.
-
-  Returns:
-    The output tensor of Group Normalization.
-  """
-  x = torch.permute(x, (0, 2, 3, 1))
-
-  builder = StableHLOCompositeBuilder(
-      name="odml.group_norm",
-      attr={
-          "num_groups": num_groups,
-          "epsilon": eps,
-          "reduction_axes": [3],
-          "channel_axis": 3,
-      },
-  )
-  x, w, b = builder.mark_inputs(x, w, b)
-  x = torch.permute(x, (0, 3, 1, 2))
-  y = F.group_norm(x, num_groups, weight=w, bias=b, eps=eps)
-  y = torch.permute(y, (0, 2, 3, 1))
-  y = builder.mark_outputs(y)
-
-  y = torch.permute(y, (0, 3, 1, 2))
-  return y
-
-
 def rms_norm_with_hlfb(
     x: torch.Tensor,
     w: torch.Tensor,
ai_edge_torch/generative/layers/rotary_position_embedding.py
CHANGED
@@ -32,64 +32,57 @@ def apply_rope(
   """
   x = x.transpose(1, 2)
   head_size = x.size(-1)
-  x1, x2 = torch.split(x, head_size // 2, dim=-1)
-  left = x1 * cos - x2 * sin
-  right = x2 * cos + x1 * sin
-  roped = torch.cat([left, right], dim=-1)
+  x1 = x[..., : head_size // 2]  # (B, nh, T, hs/2)
+  x2 = x[..., head_size // 2 :]  # (B, nh, T, hs/2)
+  rotated = torch.cat((-x2, x1), dim=-1)  # (B, nh, T, hs)
+  roped = (x * cos) + (rotated * sin)
   return roped.transpose(1, 2).type_as(x)


-def build_rope(
+def apply_rope_inline(
+    q: torch.Tensor,
+    k: torch.Tensor,
     input_pos: torch.Tensor,
     n_elem: int,
-    head_dim: int,
     base: int = 10_000,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Computes rotary positional embedding cosine and sine tensors.
+  """Computes rotary positional embedding inline for a query and key.

   Args:
+    q: the query tensor.
+    k: the key tensor.
     input_pos: the sequence indices for the query and key
     n_elem: number of elements of the head dimension for RoPE computation
-    base: the base of the exponentiated value for RoPE.

   Returns:
-    cos, sin tensors for the rotary positional embedding.
+    output the RoPE'd query and key.
   """

   if n_elem <= 0:
-    return None, None
+    return q, k

   theta = 1.0 / (base ** (torch.arange(0, n_elem, 2).float() / n_elem))
   freq_exponents = (2.0 / n_elem) * torch.arange(
-      head_dim // 2, dtype=torch.float32
+      q.shape[-1] // 2, dtype=torch.float32
   )
   timescale = float(base) ** freq_exponents
   radians = input_pos.clone().unsqueeze(0).unsqueeze(-1) / timescale.unsqueeze(
       0
   ).unsqueeze(0)
-  cos = torch.cos(radians)
-  sin = torch.sin(radians)
-  return cos, sin
-
+  cos = torch.cos(radians).type_as(q)
+  sin = torch.sin(radians).type_as(q)

-def apply_rope_inline(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    cos: torch.Tensor,
-    sin: torch.Tensor,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-  """Computes rotary positional embedding inline for a query and key.
-
-  Args:
-    q: the query tensor.
-    k: the key tensor.
-    cos: the cosine tensor.
-    sin: the sine tensor.
-
-  Returns:
-    output the RoPE'd query and key.
-  """
+  def apply(x, sin, cos):
+    x = x.transpose(1, 2)
+    b, h, s, d = x.shape
+    ans = torch.split(x, d // 2, dim=-1)
+    x1, x2 = ans
+    left = x1 * cos - x2 * sin
+    right = x2 * cos + x1 * sin
+    res = torch.cat([left, right], dim=-1)
+    res = res.transpose(1, 2)
+    return res

-  q_roped = apply_rope(q, cos, sin)
-  k_roped = apply_rope(k, cos, sin)
+  q_roped = apply(q, sin, cos)
+  k_roped = apply(k, sin, cos)
   return q_roped, k_roped
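Both the removed and the restored implementations compute the same per-frequency 2D rotation: (x1, x2) maps to (x1·cos − x2·sin, x2·cos + x1·sin). A quick numeric check of that identity:

import torch

t = torch.tensor(0.3)  # rotation angle for one position/frequency pair
x1, x2 = torch.tensor(1.0), torch.tensor(0.0)  # unit vector along the first axis
left = x1 * torch.cos(t) - x2 * torch.sin(t)
right = x2 * torch.cos(t) + x1 * torch.sin(t)
print(torch.allclose(left, torch.cos(t)))   # True: rotating (1, 0) yields (cos t, sin t)
print(torch.allclose(right, torch.sin(t)))  # True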
ai_edge_torch/generative/utilities/model_builder.py
CHANGED
@@ -24,7 +24,6 @@ from ai_edge_torch.generative.layers import builder
 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 import ai_edge_torch.generative.layers.attention_utils as attn_utils
 import ai_edge_torch.generative.layers.model_config as cfg
-import ai_edge_torch.generative.layers.rotary_position_embedding as rotary_pos_emb
 import ai_edge_torch.generative.utilities.loader as loading_utils
 import torch
 from torch import nn
@@ -86,6 +85,13 @@ class DecoderOnlyModel(nn.Module):
         config.embedding_dim,
         config.final_norm_config,
     )
+    # ROPE parameters for all attn_configs are the same. Take the first one.
+    attn_config = config.block_config(0).attn_config
+    self.rope_cache = attn_utils.build_rope_cache(
+        size=config.kv_cache_max,
+        dim=int(attn_config.rotary_percentage * attn_config.head_dim),
+        base=attn_config.rotary_base,
+    )
     self.mask_cache = attn_utils.build_causal_mask_cache(
         size=config.kv_cache_max,
     )
@@ -107,22 +113,16 @@ class DecoderOnlyModel(nn.Module):

     # token embeddings of shape (b, t, n_embd)
     input_embeds = self.tok_embedding(tokens)
-
-    # RoPE parameters are the same for all blocks. Use the first layer.
-    attn_config = self.config.block_config(0).attn_config
-    n_elem = int(attn_config.rotary_percentage * attn_config.head_dim)
-    rope = rotary_pos_emb.build_rope(
-        input_pos, n_elem, attn_config.head_dim, attn_config.rotary_base
-    )
-
+    cos, sin = self.rope_cache
+    rope = (cos.index_select(0, input_pos), sin.index_select(0, input_pos))
     mask = self.mask_cache.index_select(2, input_pos)
     mask = mask[:, :, :, : self.config.kv_cache_max]

-    return self._forward_with_embeds(
+    return self.forward_with_embeds(
         input_embeds, rope, mask, input_pos, kv_cache, export_config
     )

-  def _forward_with_embeds(
+  def forward_with_embeds(
       self,
       input_embeds: torch.Tensor,
       rope: Tuple[torch.Tensor, torch.Tensor],
@@ -141,13 +141,13 @@ class DecoderOnlyModel(nn.Module):
     if self.config.embedding_scale is not None:
       x = x * self.config.embedding_scale

-    updated_kv_entries = []
+    updated_kv_entires = []
     for i, block in enumerate(self.transformer_blocks):
       kv_entry = kv_cache.caches[i] if kv_cache else None
       x, kv_entry = block(x, rope, mask, input_pos, kv_entry)
       if kv_entry:
-        updated_kv_entries.append(kv_entry)
-    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entries))
+        updated_kv_entires.append(kv_entry)
+    updated_kv_cache = kv_utils.KVCache(tuple(updated_kv_entires))

     if export_config is not None:
       if (
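The forward pass never mutates the incoming cache; it collects each block's updated entry and freezes them into a new immutable cache object. The same pattern in a standalone sketch (this KVCache is a stand-in for kv_utils.KVCache, which holds k/v tensors per block):

from typing import NamedTuple, Tuple

class KVCache(NamedTuple):  # stand-in for kv_utils.KVCache
  caches: Tuple[str, ...]

old = KVCache(caches=("block0-kv", "block1-kv"))
updated_entries = []
for entry in old.caches:
  new_entry = entry + "'"   # stand-in for the block returning an updated entry
  if new_entry:
    updated_entries.append(new_entry)
new = KVCache(tuple(updated_entries))
print(new.caches)           # ("block0-kv'", "block1-kv'"); old is untouched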
ai_edge_torch/generative/utilities/verifier.py
CHANGED
@@ -16,7 +16,7 @@
 """Common utility functions to verify the reauthored models."""

 import logging
-from typing import Any, List
+from typing import Any, List, Optional

 from ai_edge_torch.generative.layers import kv_cache as kv_utils
 from ai_edge_torch.generative.utilities.model_builder import ExportConfig
@@ -134,7 +134,7 @@ class ReauthoredModelWrapper(ModelWrapper):
       prompts: torch.Tensor,
       max_new_tokens: int,
       pixel_values: torch.Tensor = None,
-      eos_token_id: int =
+      eos_token_id: Optional[int] = None,
   ) -> torch.IntTensor:
     input_ids = prompts[0].int().tolist()
     tokens = torch.tensor([input_ids])
@@ -146,7 +146,7 @@ class ReauthoredModelWrapper(ModelWrapper):
     )
     generated_token = logits[0][-1].argmax().item()
     input_ids.append(generated_token)
-    if generated_token == eos_token_id:
+    if eos_token_id is not None and generated_token == eos_token_id:
       break
     tokens = torch.tensor([[generated_token]])
     input_pos = torch.tensor([len(input_ids) - 1])
@@ -253,7 +253,7 @@ def verify_model_with_prompts(
   outputs_reauthored = reauthored_model.generate(
       prompt_tokens,
       max_new_tokens,
-      eos_token_id=tokenizer.tokenizer.eos_token_id,
+      eos_token_id=getattr(tokenizer.tokenizer, "eos_token_id", None),
   )
   response_reauthored = tokenizer.decode(outputs_reauthored[0])
   logging.info("outputs from reauthored model: [[%s]]", response_reauthored)
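The getattr fallback keeps verification working for tokenizers that expose no eos_token_id; combined with the `is not None` guard, generation then simply runs to max_new_tokens. A minimal sketch with a hypothetical tokenizer class:

class BareTokenizer:  # hypothetical tokenizer without an eos_token_id attribute
  pass

eos = getattr(BareTokenizer(), "eos_token_id", None)
print(eos)                                      # None
generated_token = 42
if eos is not None and generated_token == eos:  # guard never fires when eos is None
  print("stop")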
ai_edge_torch/lowertools/torch_xla_utils.py
CHANGED
@@ -27,6 +27,9 @@ if "PJRT_DEVICE" not in os.environ:
   # https://github.com/google-ai-edge/ai-edge-torch/issues/326
   os.environ["PJRT_DEVICE"] = "CPU"

+os.environ["EXPERIMENTAL_XLA_UNBOUNDED_DYNAMISM"] = "1"
+
+
 from ai_edge_torch import model
 from ai_edge_torch._convert import conversion_utils
 from ai_edge_torch._convert import signature as signature_module
ai_edge_torch/odml_torch/export.py
CHANGED
@@ -202,7 +202,7 @@ class MlirLowered:
       target_version = stablehlo.get_minimum_version()
     else:
       target_version = stablehlo.get_version_from_compatibility_requirement(
-          stablehlo.StablehloCompatibilityRequirement.WEEK_4
+          stablehlo.StablehloCompatibilityRequirement.WEEK_12
       )
     module_bytecode = xla_extension.mlir.serialize_portable_artifact(
         self.module_bytecode, target_version
@@ -222,11 +222,6 @@ class MlirLowered:
     # Lazy importing TF when execution is needed.
     return self.tf_function(*args)

-  def to_flatbuffer(self):
-    from . import tf_integration
-
-    return tf_integration.mlir_to_flatbuffer(self)
-

 # TODO(b/331481564) Make this a ai_edge_torch FX pass.
 def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
|
def _convert_i64_to_i32(exported_program: torch.export.ExportedProgram):
|
@@ -12,10 +12,9 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
# ==============================================================================
|
15
|
-
"""APIs to convert lowered MLIR from PyTorch to TensorFlow
|
15
|
+
"""APIs to convert lowered MLIR from PyTorch to TensorFlow artifacts."""
|
16
16
|
|
17
17
|
import re
|
18
|
-
import tempfile
|
19
18
|
|
20
19
|
import tensorflow as tf
|
21
20
|
import torch
|
@@ -104,20 +103,26 @@ def _extract_call_args(
|
|
104
103
|
def _wrap_as_tf_func(lowered, tf_state_dict):
|
105
104
|
"""Build tf.function from lowered and tf_state_dict."""
|
106
105
|
|
107
|
-
|
106
|
+
version = 6
|
107
|
+
if hasattr(tfxla, "call_module_maximum_supported_version"):
|
108
|
+
version = tfxla.call_module_maximum_supported_version()
|
109
|
+
|
110
|
+
def tf_func(*args):
|
108
111
|
t_outs = [torch_dtype_to_tf(sig.dtype) for sig in lowered.output_signature]
|
109
112
|
s_outs = [_get_shape_with_dynamic(sig) for sig in lowered.output_signature]
|
110
113
|
call_args = _extract_call_args(lowered, args, tf_state_dict)
|
111
114
|
return tfxla.call_module(
|
112
115
|
tuple(call_args),
|
113
|
-
version=
|
116
|
+
version=version,
|
114
117
|
Tout=t_outs, # dtype information
|
115
|
-
Sout=s_outs, #
|
118
|
+
Sout=s_outs, # shape information
|
116
119
|
function_list=[],
|
117
|
-
module=lowered.
|
120
|
+
module=lowered.module_bytecode_vhlo,
|
121
|
+
has_token_input_output=False,
|
122
|
+
platforms=["CPU"],
|
118
123
|
)
|
119
124
|
|
120
|
-
return
|
125
|
+
return tf_func
|
121
126
|
|
122
127
|
|
123
128
|
def _make_input_signatures(
|
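The hasattr probe is plain feature detection: prefer the runtime's reported maximum call_module version, and fall back to a known-good constant on older TensorFlow builds that lack the query API. The same pattern against a stand-in module:

class _FakeTfxla:  # stand-in for the tfxla module; the real probe targets tfxla itself
  @staticmethod
  def call_module_maximum_supported_version() -> int:
    return 9

version = 6  # conservative fallback when the API is absent
if hasattr(_FakeTfxla, "call_module_maximum_supported_version"):
  version = _FakeTfxla.call_module_maximum_supported_version()
print(version)  # 9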
@@ -149,46 +154,3 @@ def mlir_to_tf_function(lowered: export.MlirLowered):
       _wrap_as_tf_func(lowered, tf_state_dict),
       input_signature=_make_input_signatures(lowered),
   )
-
-
-def mlir_to_flatbuffer(lowered: export.MlirLowered):
-  """Convert the MLIR lowered to a TFLite flatbuffer binary."""
-  tf_state_dict = _build_tf_state_dict(lowered)
-  signature_names = [tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]
-  tf_signatures = [_make_input_signatures(lowered)]
-  tf_functions = [_wrap_as_tf_func(lowered, tf_state_dict)]
-
-  tf_module = tf.Module()
-  tf_module.f = []
-
-  for tf_sig, func in zip(tf_signatures, tf_functions):
-    tf_module.f.append(
-        tf.function(
-            func,
-            input_signature=tf_sig,
-        )
-    )
-
-  tf_module._variables = list(tf_state_dict.values())
-
-  tf_concrete_funcs = [
-      func.get_concrete_function(*tf_sig)
-      for func, tf_sig in zip(tf_module.f, tf_signatures)
-  ]
-
-  # We need to temporarily save since TFLite's from_concrete_functions does not
-  # allow providing names for each of the concrete functions.
-  with tempfile.TemporaryDirectory() as temp_dir_path:
-    tf.saved_model.save(
-        tf_module,
-        temp_dir_path,
-        signatures={
-            sig_name: tf_concrete_funcs[idx]
-            for idx, sig_name in enumerate(signature_names)
-        },
-    )
-
-    converter = tf.lite.TFLiteConverter.from_saved_model(temp_dir_path)
-    tflite_model = converter.convert()
-
-    return tflite_model
ai_edge_torch/version.py
CHANGED
-__version__ = "0.3.0.dev20241226"
+__version__ = "0.3.0.dev20250107"

{ai_edge_torch_nightly-0.3.0.dev20241226.dist-info → ai_edge_torch_nightly-0.3.0.dev20250107.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20241226
+Version: 0.3.0.dev20250107
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20241226.dist-info → ai_edge_torch_nightly-0.3.0.dev20250107.dist-info}/RECORD
CHANGED
@@ -1,11 +1,11 @@
 ai_edge_torch/__init__.py,sha256=rq9ZtMJLG8yYNC4tNE4rpl94UAUClZW7f4GAr6HBVDQ,1208
-ai_edge_torch/_config.py,sha256=
+ai_edge_torch/_config.py,sha256=PKtOtBOup-cM0wBdQxby6HzuhLhIC3oq-TBG8FF4znE,2161
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/fx_pass_base.py,sha256=518ziQ0TUxqum2qZXqlD8qr65pHPh8ZNLnwFC6zvK3k,4253
 ai_edge_torch/model.py,sha256=N-pNpTxzhaFGhWhnSGd70lBzb9VlEhTOq5mddU7bvvI,5542
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=X0ZEB5T3xcR8MsIE8VOHDAdHnCZTzJLBQQ9j2xZ4_qA,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/_convert/conversion.py,sha256=
+ai_edge_torch/_convert/conversion.py,sha256=_PoH0E1gbbsWhLGwDRwUtW2G_IgNzNF7pKQbn9ct6-4,5778
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
 ai_edge_torch/_convert/converter.py,sha256=075F8LRewk_033Ebsnft7FJr3KgtIbtZ_-8udIPy6ho,9980
 ai_edge_torch/_convert/signature.py,sha256=rGpBNss3Y9FCCCcdBwDo16KqavJi8N5P0M_6Wnmoz6Y,2320
@@ -13,12 +13,12 @@ ai_edge_torch/_convert/to_channel_last_io.py,sha256=_31phf7TYgZY2ftpNbrdlB1RhDiu
 ai_edge_torch/_convert/fx_passes/__init__.py,sha256=NVe-eGcm7j8jZpP2pcMhC8j5dVjgR1pPzyXhHdvKH4E,1267
 ai_edge_torch/_convert/fx_passes/build_aten_composite_pass.py,sha256=doaww8KqrgRTD5LotBVAIRFsEqzPn9R5lcGehBJOczA,9098
 ai_edge_torch/_convert/fx_passes/build_interpolate_composite_pass.py,sha256=qb4JBDi4Xca14JJUIcaaZQIJiyqKyHJF49jsRCIFCVA,4335
-ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=
+ai_edge_torch/_convert/fx_passes/inject_mlir_debuginfo_pass.py,sha256=NxT-iCOHq3r3jeZ8qhNoPXV5w8l2eRMu4yEcBri3NxY,2398
 ai_edge_torch/_convert/fx_passes/remove_non_user_outputs_pass.py,sha256=f1IUVWyhioOClsMiZzLyynoW2R17U83vA-7Q-3pGPM4,2126
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/__init__.py,sha256=lxnoH-WGLeiQIF8XjMGodjiZEFTxucl7g05N7MR9OPk,796
-ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=
+ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_check.py,sha256=S_Bniv6jY16oOoFUzlyECQ0I2HDjG2D1MOI-QYPk3jQ,8061
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_mark.py,sha256=4RyGUwR22bZqkn_TnptenFJodc_Q43f4_SBG7gmTbos,1621
-ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=
+ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/layout_rewrite.py,sha256=zoAZ2TXKvxUnWnT11U4tx2uF0J5kkNXydgaW7JzfkXI,13811
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/op_func_registry.py,sha256=bsYnudRlXp1PJlu4GF25KSogSkBGQPSaecBrUTONKaw,1031
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/pass_body.py,sha256=t94Am3iPbYQekg-rrtc-jS_aDWtEgAAj7pAKHrG0-9U,10563
 ai_edge_torch/_convert/fx_passes/optimize_layout_transposes_pass/utils.py,sha256=YLMttMg5PdvXTtQ8lxpKb434UGVvYVALV1-xeuH4UGc,2131
@@ -47,7 +47,7 @@ ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIX
 ai_edge_torch/generative/examples/gemma/convert_gemma1_to_tflite.py,sha256=GhwtQZ1xuMyKJl8qdxU6uKavQnlm5US9xhKJvdmgACc,2309
 ai_edge_torch/generative/examples/gemma/convert_gemma2_to_tflite.py,sha256=hsy4Gd7Inchi0p_Cc5yecH6vr9A7X4MvmQNfTt8N2sQ,2311
 ai_edge_torch/generative/examples/gemma/gemma1.py,sha256=w8oWYibZzvEvCDyp39EYyAWmjgJljhzdYPyFCfAWxZA,3497
-ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=
+ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=whQ6DEnmhmj9hd5OyaoEI-FUNJ4m302vY3Swo_IqQcA,9285
 ai_edge_torch/generative/examples/gemma/verify_gemma1.py,sha256=ip-Gmk4CI5f0GWSdAIdrectxQWJ0t328KCsA4nfHuGg,1736
 ai_edge_torch/generative/examples/gemma/verify_gemma2.py,sha256=IoBhEMwH07-tFm5-U6F2hpCsI8xynglhq1x9tIOdaPQ,1322
 ai_edge_torch/generative/examples/gemma/verify_util.py,sha256=tR8RflXocDZqvuStyw9aFlzuiTllEC8rNnjrxms6_Is,5727
@@ -109,7 +109,7 @@ ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=l01oYyJo77INzRwN4xqX
 ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/test_models/convert_toy_model.py,sha256=6-WaNHckq_LlXMVTh8x90MGWeWq2bu_T_XQd3w9FnGg,3261
 ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=4113jZK-Hu3kYop__WTc8Bq-bG6YzQtADbxHtYPEB4w,5036
-ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=
+ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=WMl1iuCE8So9FDnxPV0OTMzuPngQUTO61g8rfnBLyB4,4664
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=5rgbTIxHoFg8sTnzrGA_ekT-HJEt9p7Dla7cIY874jU,2338
 ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=mhJ18rb9sxrYRzv1YSzhbNs97oUZck99avZDcUO2oV8,2800
@@ -117,14 +117,14 @@ ai_edge_torch/generative/examples/tiny_llama/verify.py,sha256=7Bk8z033M-BCXJ299f
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/attention.py,sha256=
+ai_edge_torch/generative/layers/attention.py,sha256=aOoVM1hY7qjvzVQI1-m26p_f9qoTLzXXIy8dNtU8xC4,13199
 ai_edge_torch/generative/layers/attention_utils.py,sha256=zBVwlBUTs-nStIKCZG0ks5ra7tsqc9ShfakFJKH5rds,7344
 ai_edge_torch/generative/layers/builder.py,sha256=LXGuSHIx6QZAzLFm7aJvlzoMPgQwbXLFchGEKYwOOUA,5090
 ai_edge_torch/generative/layers/feed_forward.py,sha256=hdICat-8gW7-vxDAevJQ8NQ-mynllPiqLdXQMF6JMnc,4189
 ai_edge_torch/generative/layers/kv_cache.py,sha256=DhHIggaOQ2IAY4aRuMAuCLWZv1dBz5PYtmOEjkx9EQY,6291
 ai_edge_torch/generative/layers/model_config.py,sha256=viX51T_naJ9sPpPxPoMnSueBPYE2zxWNOD0xn0f-_bM,7510
-ai_edge_torch/generative/layers/normalization.py,sha256=
-ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=
+ai_edge_torch/generative/layers/normalization.py,sha256=MbwH-n80Fob5YvjBzdqDjBizMHLzSJGYRDdbD-rL5C0,6174
+ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=xxWtlVsGGJkEyXC6PwznubyhJnLPEfSpHOORE_hgxss,2670
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=gXxh3papKy4FBpGEX7VyZ7rZ1Js6aHK70Q6DKrVSckY,4154
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=9jKzOfiBQ66bp1ZnVIAoREIifVNFx4aTlQeYMAx2_pA,29062
@@ -149,12 +149,12 @@ ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5l
 ai_edge_torch/generative/utilities/converter.py,sha256=hIwWUWjgPvWLATtsYYG6RWbFQWhOr2RpPlMrd-4Am9U,5959
 ai_edge_torch/generative/utilities/dynamic_update_slice.py,sha256=e2mhx-Vp8sUK4EXoPtpZLSx3TViqLAKs67EhKcXBjAQ,2121
 ai_edge_torch/generative/utilities/loader.py,sha256=A3SOjPXp--AsvoP1hqj5QKWE4sgxoFc3H5EBUz_Eogc,13531
-ai_edge_torch/generative/utilities/model_builder.py,sha256=
+ai_edge_torch/generative/utilities/model_builder.py,sha256=rfD6INxunvDVdiUfTUxD7yy0dRxL74W7kVmZsxUjpOQ,6379
 ai_edge_torch/generative/utilities/moonshine_loader.py,sha256=_RpFabSqtGH5PHiP3_1f6QfO14qMADUxr_HGRlVDFB0,4891
 ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=dqPD9qRXEWtU3ombslOC-BE2l_dMwHoCNu7NsIJhsso,36158
 ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
 ai_edge_torch/generative/utilities/transformers_verifier.py,sha256=8sp9m_FMcXn7nqOrochtu2jIANkJKhnhIBUmH0ZTDR4,1549
-ai_edge_torch/generative/utilities/verifier.py,sha256=
+ai_edge_torch/generative/utilities/verifier.py,sha256=6lnBU9Cy5GanB8JWK3-2_VU3PxqunDWGe-SgSLba5Yw,12065
 ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
 ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
 ai_edge_torch/hlfb/mark_pattern/passes.py,sha256=pjkKcI1nHECPluAt87cFBrt1DP0f3ge7rHq1NhCkBIE,1936
@@ -166,14 +166,14 @@ ai_edge_torch/lowertools/_shim.py,sha256=Mbg16tnCVK0YsHowfbpqpNX1qySuMLvpGI_-I5S
 ai_edge_torch/lowertools/common_utils.py,sha256=Z7p-ivOHtddktpnHrlDm_dSoTxJOdEjFXIGQbzjgwQo,4504
 ai_edge_torch/lowertools/odml_torch_utils.py,sha256=dxg2pBuVhSZeY2Ouc0F6nHiZilWZmpSPA7I8kGqSkVI,8282
 ai_edge_torch/lowertools/test_utils.py,sha256=mdxTlhqHABZEQ_GEmPFCL8LIAWtqRtYZUGdSY1ieZjw,1949
-ai_edge_torch/lowertools/torch_xla_utils.py,sha256=
+ai_edge_torch/lowertools/torch_xla_utils.py,sha256=1EytIw2R6dthhLhf69wN1L9BaQTeybCD0wga-PhHcMI,9518
 ai_edge_torch/lowertools/translate_recipe.py,sha256=ymkBpFqAUiupRWqrPOWiVphKcXR1K5vHK0RjgBFtxlE,5652
 ai_edge_torch/odml_torch/__init__.py,sha256=S8jOzE9nLof-6es3XDiGJRN-9H_XTxsVm9dE7lD3RWo,812
 ai_edge_torch/odml_torch/_torch_future.py,sha256=AJ0klpsbu2ZBTfiZlqSOoaYzBVITt40a1fYN8xKkEPw,3044
 ai_edge_torch/odml_torch/_torch_library.py,sha256=Lw1gqL2HWNRspdTwNhIkYAHDyafHedHtkXyKKxn-Wss,805
-ai_edge_torch/odml_torch/export.py,sha256=
+ai_edge_torch/odml_torch/export.py,sha256=sqIMXmxK_qIuVC-_DNJ6wKlIWiXq4_WOCKbSqMRFudg,13293
 ai_edge_torch/odml_torch/export_utils.py,sha256=QeA37Irlty6AiIBuqmHmJgn3lqahBQ5xsh6IKRoKm1g,4774
-ai_edge_torch/odml_torch/tf_integration.py,sha256=
+ai_edge_torch/odml_torch/tf_integration.py,sha256=NN29WeXmHZ0S1RPDFHUnBi2DEjMvAtwczStPYIsQ1w8,4849
 ai_edge_torch/odml_torch/composite/__init__.py,sha256=71GM_gDZxJyo38ZSoYSwhZX3xKA9rknO93JS9kw9w_c,778
 ai_edge_torch/odml_torch/composite/mark_tensor.py,sha256=U--rwl-XkWKgkdXCXDn6yySug8FR66o1YFUAIoSaWW4,3523
 ai_edge_torch/odml_torch/composite/stablehlo_composite_builder.py,sha256=2Y52E_gLeoXpMcPpV-svXsgN3JbEIjnPVjm0xkpTUdQ,3319
@@ -203,8 +203,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20241226.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20241226.dist-info/METADATA,sha256=
-ai_edge_torch_nightly-0.3.0.dev20241226.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
-ai_edge_torch_nightly-0.3.0.dev20241226.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20241226.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20250107.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20250107.dist-info/METADATA,sha256=p2F-coQaq7CbpMOkQLVnpFB01cCKqftVRGZ4dCVu8Ck,1966
+ai_edge_torch_nightly-0.3.0.dev20250107.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.3.0.dev20250107.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20250107.dist-info/RECORD,,