ai-edge-torch-nightly 0.3.0.dev20240911__py3-none-any.whl → 0.3.0.dev20240912__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,14 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
@@ -0,0 +1,86 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Example of converting SmalLM model to multi-signature tflite model."""
17
+
18
+ import os
19
+ import pathlib
20
+
21
+ import ai_edge_torch
22
+ from ai_edge_torch.generative.examples.smallm import smallm
23
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
24
+ from ai_edge_torch.generative.quantize import quant_recipes
25
+ import torch
26
+
27
+
28
def convert_smallm_to_tflite(
    checkpoint_path: str,
    prefill_seq_len: int = 512,
    kv_cache_max_len: int = 1024,
    quantize: bool = True,
    output_dir: str = '/tmp',
):
  """Converts SmalLM model to multi-signature tflite model.

  The exported model carries two signatures that share the same weights:
  'prefill' (traced with a `prefill_seq_len`-token input) and 'decode'
  (traced with a single-token input).

  Args:
    checkpoint_path (str): The filepath to the model checkpoint, or directory
      holding the checkpoint.
    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
      Defaults to 512.
    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
      including both prefill and decode. Defaults to 1024.
    quantize (bool, optional): Whether the model should be quantized. Defaults
      to True.
    output_dir (str, optional): Directory the .tflite file is written to.
      Defaults to '/tmp'.

  Returns:
    str: The path of the exported .tflite file.

  Raises:
    ValueError: If `prefill_seq_len` is larger than `kv_cache_max_len`; the
      prefill positions would not fit in the KV cache.
  """
  if prefill_seq_len > kv_cache_max_len:
    raise ValueError(
        f'prefill_seq_len ({prefill_seq_len}) must not exceed '
        f'kv_cache_max_len ({kv_cache_max_len}).'
    )

  pytorch_model = smallm.build_model(
      checkpoint_path, kv_cache_max_len=kv_cache_max_len
  )
  # Tensors used to trace the model graph during conversion.
  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
  prefill_input_pos = torch.arange(0, prefill_seq_len)
  decode_token = torch.tensor([[0]], dtype=torch.long)
  decode_input_pos = torch.tensor([0], dtype=torch.int64)
  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)

  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
  edge_model = (
      ai_edge_torch.signature(
          'prefill',
          pytorch_model,
          sample_kwargs={
              'tokens': prefill_tokens,
              'input_pos': prefill_input_pos,
              'kv_cache': kv,
          },
      )
      .signature(
          'decode',
          pytorch_model,
          sample_kwargs={
              'tokens': decode_token,
              'input_pos': decode_input_pos,
              'kv_cache': kv,
          },
      )
      .convert(quant_config=quant_config)
  )
  quant_suffix = 'q8' if quantize else 'f32'
  output_filename = (
      f'smallm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
  )
  # os.path.join keeps the default ('/tmp') output path identical to the
  # previous hard-coded one while allowing any target directory.
  output_path = os.path.join(output_dir, output_filename)
  edge_model.export(output_path)
  return output_path
82
+
83
+
84
if __name__ == '__main__':
  # Default local checkpoint location used by this example.
  convert_smallm_to_tflite(
      os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smallm')
  )
@@ -0,0 +1,119 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Example of building a SmalLM model."""
17
+
18
+ import copy
19
+ import os
20
+ import pathlib
21
+
22
+ from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
23
+ from ai_edge_torch.generative.layers import kv_cache as kv_utils
24
+ import ai_edge_torch.generative.layers.model_config as cfg
25
+ import ai_edge_torch.generative.utilities.loader as loading_utils
26
+ import numpy as np
27
+ import torch
28
+ from torch import nn
29
+
30
# Checkpoint tensor names: TinyLlama's mapping, minus the head projection.
# copy.copy makes a (shallow) copy so that clearing lm_head below does not
# mutate tiny_llama.TENSOR_NAMES, which TinyLlama's own loader still uses.
TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
# SmalLM re-uses the embedding as the head projection layer, so no separate
# lm_head tensor name is mapped.
TENSOR_NAMES.lm_head = None
33
+
34
+
35
class SmalLM(tiny_llama.TinyLlama):
  """A SmalLM model built from the Edge Generative API layers.

  Architecturally this is TinyLlama with different model sizes; the only
  structural difference is that the output projection (lm_head) reuses the
  token embedding weights.
  """

  def __init__(self, config: cfg.ModelConfig):
    super().__init__(config)
    # Point lm_head at the embedding's storage so both layers read the same
    # values (e.g. after the embedding is filled from a checkpoint in place).
    shared_embedding = self.tok_embedding.weight.data
    self.lm_head.weight.data = shared_embedding
46
+
47
+
48
def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
  """Returns the model config for a SmalLM 135M model.

  Args:
    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
      is 1024.

  Returns:
    The model config for a SmalLM model.
  """
  # A single RMSNorm config object is shared by all three norm slots.
  rms_norm = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
  return cfg.ModelConfig(
      vocab_size=49152,
      num_layers=30,
      max_seq_len=2048,
      embedding_dim=576,
      kv_cache_max_len=kv_cache_max_len,
      # 9 query heads grouped onto 3 KV groups, 64 dims per head.
      attn_config=cfg.AttentionConfig(
          num_heads=9,
          head_dim=64,
          num_query_groups=3,
          rotary_percentage=1.0,
      ),
      ff_config=cfg.FeedForwardConfig(
          type=cfg.FeedForwardType.GATED,
          activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
          intermediate_size=1536,
      ),
      pre_attention_norm_config=rms_norm,
      post_attention_norm_config=rms_norm,
      final_norm_config=rms_norm,
      enable_hlfb=True,
  )
84
+
85
+
86
def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
  """Builds a SmalLM model and loads its weights.

  Args:
    checkpoint_path (str): The filepath to the model checkpoint, or directory
      holding the checkpoint.
    **kwargs: Forwarded to `get_model_config` (e.g. `kv_cache_max_len`).

  Returns:
    The SmalLM model with loaded weights, switched to eval mode.
  """
  model = SmalLM(get_model_config(**kwargs))
  # TENSOR_NAMES maps no lm_head tensor (the head shares the embedding's
  # weight), so the load must be non-strict.
  loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES).load(
      model, strict=False
  )
  model.eval()
  return model
95
+
96
+
97
def define_and_run(checkpoint_path: str) -> None:
  """Instantiates and runs a SmalLM model.

  Feeds the prompt tokens [1, 2, 3, 4], zero-padded to the KV cache length,
  through the model and asserts that the logits at the last prompt position
  match the golden logits stored in `smallm_lm_logits.pt` next to this file.
  """
  goldens_file = pathlib.Path(__file__).parent.resolve() / "smallm_lm_logits.pt"
  smallm_goldens = torch.load(goldens_file)

  kv_cache_max_len = 1024
  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)

  prompt = torch.from_numpy(np.array([[1, 2, 3, 4]]))
  prompt_len = prompt.shape[1]
  tokens = torch.zeros((1, kv_cache_max_len), dtype=torch.long, device="cpu")
  tokens[0, :prompt_len] = prompt
  input_pos = torch.arange(0, kv_cache_max_len)
  kv = kv_utils.KVCache.from_model_config(model.config)

  output = model.forward(tokens, input_pos, kv)
  last_prompt_logits = output["logits"][0, prompt_len - 1, :]
  assert torch.allclose(smallm_goldens, last_prompt_logits, atol=1e-05)
113
+
114
+
115
if __name__ == "__main__":
  # Default local checkpoint location used by this example.
  define_and_run(
      os.path.join(pathlib.Path.home(), "Downloads/llm_data/smallm")
  )
@@ -44,7 +44,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
44
44
  )
45
45
 
46
46
 
47
- class TinyLLamma(nn.Module):
47
+ class TinyLlama(nn.Module):
48
48
  """A TinyLlama model built from the Edge Generative API layers."""
49
49
 
50
50
  def __init__(self, config: cfg.ModelConfig):
@@ -169,7 +169,7 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
169
169
 
170
170
  def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
171
171
  config = get_model_config(**kwargs)
172
- model = TinyLLamma(config)
172
+ model = TinyLlama(config)
173
173
  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
174
174
  loader.load(model)
175
175
  model.eval()
@@ -59,9 +59,11 @@ def build_norm(dim: int, config: cfg.NormalizationConfig):
59
59
  zero_centered_gamma=config.zero_centered,
60
60
  )
61
61
  elif config.type == cfg.NormalizationType.LAYER_NORM:
62
- return nn.LayerNorm(dim, eps=config.epsilon)
62
+ return normalization.LayerNorm(dim, config.epsilon, config.enable_hlfb)
63
63
  elif config.type == cfg.NormalizationType.GROUP_NORM:
64
- return nn.GroupNorm(config.group_num, dim, config.epsilon)
64
+ return normalization.GroupNorm(
65
+ config.group_num, dim, config.epsilon, config.enable_hlfb
66
+ )
65
67
  else:
66
68
  raise ValueError("Unsupported norm type.")
67
69
 
@@ -104,6 +104,7 @@ class NormalizationConfig:
104
104
  """Normalizater parameters."""
105
105
 
106
106
  type: NormalizationType = NormalizationType.NONE
107
+ enable_hlfb: bool = False
107
108
  epsilon: float = 1e-5
108
109
  zero_centered: bool = False
109
110
  # Number of groups used in group normalization.
@@ -14,7 +14,10 @@
14
14
  # ==============================================================================
15
15
  # Common normalization layers.
16
16
 
17
+ from ai_edge_torch.hlfb import StableHLOCompositeBuilder
17
18
  import torch
19
+ from torch import nn
20
+ import torch.nn.functional as F
18
21
 
19
22
 
20
23
  # Implementation for RMSNorm from: https://arxiv.org/abs/1910.07467
@@ -58,3 +61,158 @@ class RMSNorm(torch.nn.Module):
58
61
  return output * (1 + self.weight)
59
62
  else:
60
63
  return output * self.weight
64
+
65
+
66
class GroupNorm(torch.nn.Module):
  """Group normalization with an optional high-level function boundary."""

  def __init__(
      self,
      group_num: int,
      dim: int,
      eps: float = 1e-5,
      enable_hlfb: bool = False,
  ):
    """Initialize the GroupNorm layer.

    Args:
      group_num (int): Number of groups to separate the channels into.
      dim (int): Number of channels of the input tensor (size of dim 1).
      eps (float): A small float value to ensure numerical stability (default:
        1e-5).
      enable_hlfb (bool): Whether to convert this normalization into a single
        op.

    Raises:
      ValueError: If `dim` is not divisible by `group_num`.
    """
    super().__init__()
    # Fail at construction (as torch.nn.GroupNorm does) instead of at the
    # first forward pass.
    if dim % group_num != 0:
      raise ValueError(
          f"dim ({dim}) must be divisible by group_num ({group_num})."
      )
    self.enable_hlfb = enable_hlfb
    self.group_num = group_num
    self.eps = eps
    # Identity affine transform by default (scale 1, shift 0), matching
    # torch.nn.GroupNorm; real values are normally loaded from a checkpoint.
    self.weight = torch.nn.Parameter(torch.ones(dim))
    self.bias = torch.nn.Parameter(torch.zeros(dim))

  def forward(self, x):
    """Running the forward pass of GroupNorm layer.

    Args:
      x (torch.Tensor): input tensor with channels on dim 1. The HLFB path
        permutes the input and therefore requires a 4-D (B, C, H, W) tensor.

    Returns:
      torch.Tensor: output tensor after applying GroupNorm.
    """
    if self.enable_hlfb:
      return group_norm_with_hlfb(
          x,
          self.weight,
          self.bias,
          self.group_num,
          self.eps,
      )
    return F.group_norm(x, self.group_num, self.weight, self.bias, self.eps)
111
+
112
+
113
class LayerNorm(torch.nn.Module):
  """Layer normalization over the last dimension, with optional HLFB."""

  def __init__(self, dim: int, eps: float = 1e-5, enable_hlfb: bool = False):
    """Initialize the LayerNorm layer.

    Args:
      dim (int): size of the last (normalized) dimension of the input tensor.
      eps (float): A small float value to ensure numerical stability (default:
        1e-5).
      enable_hlfb (bool): Whether to convert this normalization into a single
        op.
    """
    super().__init__()
    self.enable_hlfb = enable_hlfb
    self.eps = eps
    # Normalize over the trailing `dim` features only, like
    # torch.nn.LayerNorm(dim). Normalizing over x.shape would mix statistics
    # across the batch and sequence dimensions.
    self.normalized_shape = (dim,)
    # Identity affine transform by default (scale 1, shift 0), matching
    # torch.nn.LayerNorm; real values are normally loaded from a checkpoint.
    self.weight = torch.nn.Parameter(torch.ones(dim))
    self.bias = torch.nn.Parameter(torch.zeros(dim))

  def forward(self, x):
    """Running the forward pass of LayerNorm layer.

    Args:
      x (torch.Tensor): input tensor whose last dimension has size `dim`.

    Returns:
      torch.Tensor: output tensor after applying LayerNorm.
    """
    if self.enable_hlfb:
      return layer_norm_with_hlfb(
          x,
          self.weight,
          self.bias,
          self.eps,
      )
    return F.layer_norm(
        x,
        self.normalized_shape,
        self.weight,
        self.bias,
        self.eps,
    )
155
+
156
+
157
def group_norm_with_hlfb(
    x: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    num_groups: int,
    eps: float,
):
  """Group Normalization with high-level function boundary enabled.

  The `odml.group_norm` composite boundary is drawn in channels-last (BHWC)
  layout. The input is permuted BCHW -> BHWC before being marked, then permuted
  back to BCHW for the actual `F.group_norm`. The result is permuted to BHWC
  again before being marked as the composite output. The value returned to the
  caller is permuted back to BCHW.

  Args:
    x (torch.Tensor): Input tensor for Group Normalization, with BCHW shape.
    w (torch.Tensor): The weight tensor for the normalization.
    b (torch.Tensor): The bias tensor for the normalization.
    num_groups (int): Number of groups to separate the channels into.
    eps (float): A small float value to ensure numerical stability.

  Returns:
    The output tensor of Group Normalization, in BCHW layout.
  """
  bchw_to_bhwc = (0, 2, 3, 1)
  bhwc_to_bchw = (0, 3, 1, 2)

  x_bhwc = torch.permute(x, bchw_to_bhwc)

  composite = StableHLOCompositeBuilder(
      name="odml.group_norm", attr={"num_groups": num_groups, "eps": eps}
  )
  x_bhwc, w, b = composite.mark_inputs(x_bhwc, w, b)
  y_bchw = F.group_norm(
      torch.permute(x_bhwc, bhwc_to_bchw), num_groups, weight=w, bias=b, eps=eps
  )
  y_bhwc = composite.mark_outputs(torch.permute(y_bchw, bchw_to_bhwc))
  return torch.permute(y_bhwc, bhwc_to_bchw)
189
+
190
+
191
def layer_norm_with_hlfb(
    x: torch.Tensor,
    w: torch.Tensor,
    b: torch.Tensor,
    eps: float,
    normalized_shape=None,
):
  """Layer Normalization with high-level function boundary enabled.

  Args:
    x (torch.Tensor): Input tensor for Layer Normalization.
    w (torch.Tensor): The weight tensor for the normalization; its shape must
      equal the normalized (trailing) shape.
    b (torch.Tensor): The bias tensor for the normalization; same shape as `w`.
    eps (float): A small float value to ensure numerical stability.
    normalized_shape (Sequence[int], optional): Trailing dims of `x` to
      normalize over. Defaults to `w.shape`; for the `(dim,)`-shaped weights
      used by `LayerNorm`, that is the last dimension only.

  Returns:
    The output tensor of Layer Normalization.
  """
  # Normalize over the trailing dims only. Using x.shape here would compute a
  # single mean/variance across the whole tensor (batch and sequence included),
  # which is not layer normalization.
  if normalized_shape is None:
    normalized_shape = tuple(w.shape)
  else:
    normalized_shape = tuple(normalized_shape)

  builder = StableHLOCompositeBuilder(name="odml.layer_norm", attr={"eps": eps})
  x, w, b = builder.mark_inputs(x, w, b)
  y = F.layer_norm(x, normalized_shape, weight=w, bias=b, eps=eps)
  y = builder.mark_outputs(y)
  return y
@@ -122,7 +122,6 @@ class AttentionBlock2D(nn.Module):
122
122
  config.attention_batch_size,
123
123
  config.dim,
124
124
  config.attention_config,
125
- 0,
126
125
  enable_hlfb=config.enable_hlfb,
127
126
  )
128
127
 
@@ -180,7 +179,6 @@ class CrossAttentionBlock2D(nn.Module):
180
179
  config.query_dim,
181
180
  config.cross_dim,
182
181
  config.attention_config,
183
- 0,
184
182
  enable_hlfb=config.enable_hlfb,
185
183
  )
186
184
 
@@ -71,7 +71,7 @@ class TestLoader(googletest.TestCase):
71
71
  safetensors.torch.save_file(test_weights, file_path)
72
72
  cfg = tiny_llama.get_model_config()
73
73
  cfg.num_layers = 1
74
- model = tiny_llama.TinyLLamma(cfg)
74
+ model = tiny_llama.TinyLlama(cfg)
75
75
 
76
76
  loader = loading_utils.ModelLoader(file_path, tiny_llama.TENSOR_NAMES)
77
77
  # if returns successfully, it means all the tensors were initiallized.
@@ -123,7 +123,7 @@ class TestModelConversion(googletest.TestCase):
123
123
  )
124
124
  def test_tiny_llama_multisig(self):
125
125
  config = tiny_llama.get_fake_model_config()
126
- pytorch_model = tiny_llama.TinyLLamma(config).eval()
126
+ pytorch_model = tiny_llama.TinyLlama(config).eval()
127
127
 
128
128
  # prefill
129
129
  seq_len = 10
@@ -16,6 +16,7 @@ from . import _basic
16
16
  from . import _batch_norm
17
17
  from . import _convolution
18
18
  from . import _jax_lowerings
19
+ from . import _layer_norm
19
20
  from . import context
20
21
  from . import registry
21
22
  from . import utils
@@ -167,7 +167,6 @@ lower_by_torch_xla2(torch.ops.aten.mul.Scalar)
167
167
  lower_by_torch_xla2(torch.ops.aten.mul.Tensor)
168
168
  lower_by_torch_xla2(torch.ops.aten.native_batch_norm)
169
169
  lower_by_torch_xla2(torch.ops.aten.native_group_norm)
170
- lower_by_torch_xla2(torch.ops.aten.native_layer_norm)
171
170
  lower_by_torch_xla2(torch.ops.aten.native_layer_norm_backward)
172
171
  lower_by_torch_xla2(torch.ops.aten.ne)
173
172
  lower_by_torch_xla2(torch.ops.aten.neg)
@@ -0,0 +1,78 @@
1
+ # Copyright 2024 The AI Edge Torch Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Provides lowering for coreaten to stablehlo for LayerNorm."""
16
+
17
+ import math
18
+ from typing import Optional
19
+ from ai_edge_torch.odml_torch.lowerings import registry
20
+ from ai_edge_torch.odml_torch.lowerings import utils
21
+ from jax._src.lib.mlir import ir
22
+ from jax._src.lib.mlir.dialects import hlo as stablehlo
23
+ import torch
24
+
25
+
26
+ # native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight,
27
+ # Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)
28
@registry.lower(torch.ops.aten.native_layer_norm)
def _aten_native_layer_norm(
    lctx,
    data: ir.Value,
    normalized_shape: list[int],
    weight: Optional[ir.Value],
    bias: Optional[ir.Value],
    eps: float,
):
  """Lowers `aten.native_layer_norm` to StableHLO.

  The input is viewed as `[1, rows, cols]`. Here `cols` is the number of
  elements covered by `normalized_shape`, and `rows` counts everything else.
  `stablehlo.batch_norm_training` with feature axis 1 then normalizes each row
  and reports that row's mean and variance. The batch-norm scale/offset are an
  identity (ones/zeros). The optional `weight`/`bias` are applied afterwards,
  broadcast over the trailing `len(normalized_shape)` dims.

  Args:
    lctx: Lowering context (not used by this lowering).
    data: Input tensor value.
    normalized_shape: Trailing dims of `data` to normalize over.
    weight: Optional elementwise scale, broadcast over the trailing dims.
    bias: Optional elementwise shift, broadcast over the trailing dims.
    eps: Value added to the variance for numerical stability.

  Returns:
    `(output, mean, rstd)`. `output` has the input's type. `mean` and `rstd`
    keep the input's leading dims, with a 1 for each normalized dim.
  """
  in_type: ir.RankedTensorType = data.type
  elem_type = in_type.element_type
  in_rank = len(in_type.shape)
  norm_rank = len(normalized_shape)
  cols = math.prod(normalized_shape)
  rows = math.prod(in_type.shape) // cols

  # Collapse to [1, rows, cols] so that feature axis 1 enumerates the rows.
  flat = stablehlo.reshape(
      ir.RankedTensorType.get([1, rows, cols], elem_type), data
  )

  # Identity scale/offset: batch_norm_training only has to normalize.
  unit_scale = utils.splat(1, elem_type, [rows])
  zero_shift = utils.splat(0, elem_type, [rows])
  normalized, mean, var = stablehlo.batch_norm_training(
      flat, unit_scale, zero_shift, eps, 1
  )
  # Third result is the reciprocal standard deviation 1 / sqrt(var + eps).
  var_eps = utils.splat(eps, var.type.element_type, var.type.shape)
  rstd = stablehlo.rsqrt(stablehlo.add(var, var_eps))

  # Statistics keep the leading dims and collapse every normalized dim to 1.
  stats_type = ir.RankedTensorType.get(
      in_type.shape[:-norm_rank] + [1] * norm_rank, elem_type
  )
  mean = stablehlo.reshape(stats_type, mean)
  rstd = stablehlo.reshape(stats_type, rstd)

  output = stablehlo.reshape(in_type, normalized)

  # weight/bias map onto the trailing (normalized) dims of the input.
  trailing_dims = list(range(in_rank - norm_rank, in_rank))
  if weight is not None:
    output = stablehlo.multiply(
        stablehlo.broadcast_in_dim(in_type, weight, trailing_dims), output
    )
  if bias is not None:
    output = stablehlo.add(
        stablehlo.broadcast_in_dim(in_type, bias, trailing_dims), output
    )

  return output, mean, rstd
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20240911"
16
+ __version__ = "0.3.0.dev20240912"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20240911
3
+ Version: 0.3.0.dev20240912
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
2
2
  ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
5
- ai_edge_torch/version.py,sha256=vCTKdj1Lc6r2UbJhIZpLdXauJSS0KfBLzgy9e3D16AA,706
5
+ ai_edge_torch/version.py,sha256=Li1VzlXx5ExydpfV93yVAd78cF1L_g3x30-daYdgsLA,706
6
6
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
7
7
  ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
8
8
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -47,6 +47,9 @@ ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=X6WfUCDJDEqyyEAYGq1lmKt
47
47
  ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
48
48
  ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=vqEpZVmB0_wMKcAl6RXm7W57DqPTzEdVVN6W2Z-QYzI,3011
49
49
  ai_edge_torch/generative/examples/phi/phi2.py,sha256=BzvUrClFx5HKf6PYzJc7ba2O3AwYUJE485u5GSOiPy4,6851
50
+ ai_edge_torch/generative/examples/smallm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
51
+ ai_edge_torch/generative/examples/smallm/convert_to_tflite.py,sha256=aqqxQMBBO_dtGB1iZ1tpF8hbGpdZkx0VIz62ZqfVMCc,3036
52
+ ai_edge_torch/generative/examples/smallm/smallm.py,sha256=j7SDdcX0WvgQWgpaAi7Gi39Jf0-w9D9PftDbugNrN1M,3919
50
53
  ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
51
54
  ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
52
55
  ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=0WniBWQ6_NcQc5WycX3YRRX7Os9AGQSxfc1m2HKBqg8,4479
@@ -71,21 +74,21 @@ ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=5wj2RmQRIwD6O_
71
74
  ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=PbWpfg3AOEZjI1FlnZCxRD-kIKtdkR9AOZ6l-9-TpRA,5664
72
75
  ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
73
76
  ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=y4LiWhwgflqrg4WWh3wq5ei3VOT_cV0A62x62qptQiM,3070
74
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=RK7oisSwIPqUWwwE1P-hDJlEnRJJ_V29UjUCxt4xETE,6780
77
+ ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=HwoEWls-uJ7oHj0HYxJtgZZhgiBR_OQPXlR6l14vm5E,6778
75
78
  ai_edge_torch/generative/fx_passes/__init__.py,sha256=fmNNXawJ722M4cTUuTx289rT0NHxBEsOy_k8baqCOms,1173
76
79
  ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=sXis0U4u-RoIp_NyrmWJNnqFqpqRuZOrhfsJIO6rMps,2028
77
80
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
78
81
  ai_edge_torch/generative/layers/attention.py,sha256=ee0KHRakhjLjawP32FY2EntxOkyPvjiEZChLnBn_HPc,12601
79
82
  ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
80
- ai_edge_torch/generative/layers/builder.py,sha256=xb7rjADv3Jm4qfmlYtg6oLLe7ReDE9UjsEqiejPpDD8,4346
83
+ ai_edge_torch/generative/layers/builder.py,sha256=KMwMfZ08r5CXHhcPVZ72nZnIAcsMAIKsv7-QPntlqgI,4418
81
84
  ai_edge_torch/generative/layers/feed_forward.py,sha256=uto7xtwx6jPkk1GZ2x7pSTentQzRrPSKw4_PSE12ahA,3525
82
85
  ai_edge_torch/generative/layers/kv_cache.py,sha256=WDu03NQwkDCrrrT9Du_3ZOxlURZz3XDbS1PLzFozhMI,6013
83
- ai_edge_torch/generative/layers/model_config.py,sha256=WpZ9djUBAZddyeSODHDaVMG37EQqfzGGrlMPi8AA-Hc,5752
84
- ai_edge_torch/generative/layers/normalization.py,sha256=u8lv0p-ktKcRqCDlOqZQa9WQcfDK9JM2IaUQFQdn7xs,1860
86
+ ai_edge_torch/generative/layers/model_config.py,sha256=03tjidDM1uo_H0jsHNjYEUR5R1FEckc1GIxSoE7ItQQ,5780
87
+ ai_edge_torch/generative/layers/normalization.py,sha256=iod9oNkoDS5m-yFY_Y_XMyvCU5a88ESd_s5WY34ErKA,6129
85
88
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
86
89
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=VW-VP8e7FTSPCdu-6DVxpwNrIdgX0R_kq6F6MSEiyXE,3848
87
90
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
88
- ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=V4zUAqjWeBseMPG9B-93LDv1LM3Dds6Q-H0NxY0koSA,27212
91
+ ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=cpygyJccLq6KHKxV7oz4YKh529YLjC9isupnsVmPi0A,27190
89
92
  ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
90
93
  ai_edge_torch/generative/layers/unet/model_config.py,sha256=NvBJj09a7ZC-ChGE_ex-_kLnE_fjzrY6txbLSh1pMKA,9208
91
94
  ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -97,8 +100,8 @@ ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FB
97
100
  ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
98
101
  ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
99
102
  ai_edge_torch/generative/test/test_kv_cache.py,sha256=FU2rmU03Lp-vZ5wWXXCao1WEw7xbpqebFMANL_O2chA,3713
100
- ai_edge_torch/generative/test/test_loader.py,sha256=_y5EHGgoNOmCuYonsB81UJScHVsTAQXUVd44czMAw6k,3379
101
- ai_edge_torch/generative/test/test_model_conversion.py,sha256=OmAHSGkxTNzDX5_kYjK7pxlPk0YZLqL9YiVIJQfuvPc,5889
103
+ ai_edge_torch/generative/test/test_loader.py,sha256=8y74ChO3CZCfEi1eCf3-w47kRgAI4qPYCXpi8rTQXMA,3378
104
+ ai_edge_torch/generative/test/test_model_conversion.py,sha256=SIv7_sc5qHvbHFN8SbAfY00iXGvH7J6cJLkERU_cd5k,5888
102
105
  ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=F3q3K9ZgWBzlLy4WpE8-w6UWSuJ-UoJwMm3N6Zb3Y14,5016
103
106
  ai_edge_torch/generative/test/test_quantize.py,sha256=kY_NRpF-v1i4clqI1CFFWEagJv-5PzBDkeJ2fInl9_w,5913
104
107
  ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
@@ -135,11 +138,12 @@ ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW
135
138
  ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkgH7x38qPlwUUpD7S15Q,730
136
139
  ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=drN3L0uTsSjkluKgt6Ngq7b5HLReE_7iAitHpZ9PKqE,5428
137
140
  ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
138
- ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=GqYk6oBJw7KWeG4_6gxSu_OvYhjJcC2FpGzWPPEdH6w,933
141
+ ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=dE_qzh-OnCNjWzqs1-PHs5PNlRF726qMQKM3tkwAzEs,959
139
142
  ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=wV8AUK8dvjLUy3qjqw_IxpiYVDWUMPNZRfi3XYE_hDs,6972
140
143
  ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
141
144
  ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=v1VdKmL8YLJv3PR9VgyNghO83A25PpTzY2ZUAJqlq3Q,6847
142
- ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=s-cT_tIQHu7w5hXl8MCixRxLlHplpXW-UCzHT9TY--o,10621
145
+ ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=Ii1akrKLhRTkZ715JxXBBGKv3jGfXReXMQCYNzSnxmM,10567
146
+ ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=1ePJs7oIdUkVdMddFsXMc53qTkEKqGz0ZhQQoNzBa10,2862
143
147
  ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
144
148
  ai_edge_torch/odml_torch/lowerings/registry.py,sha256=ES3x_RJ22T5rlmMrlomex2DdcZbhlyVJ7_HS3rjz3Uk,2851
145
149
  ai_edge_torch/odml_torch/lowerings/utils.py,sha256=NczqpsSd3Fn7yVcPC3qllemiZxxDAZgcW1T5l8-W9fE,5593
@@ -151,8 +155,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
151
155
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
152
156
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
153
157
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
154
- ai_edge_torch_nightly-0.3.0.dev20240911.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
155
- ai_edge_torch_nightly-0.3.0.dev20240911.dist-info/METADATA,sha256=caHeAQX6pxEskue_BvgwkTfZEsG55rXHFwPDcV9oCN8,1859
156
- ai_edge_torch_nightly-0.3.0.dev20240911.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
157
- ai_edge_torch_nightly-0.3.0.dev20240911.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
158
- ai_edge_torch_nightly-0.3.0.dev20240911.dist-info/RECORD,,
158
+ ai_edge_torch_nightly-0.3.0.dev20240912.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
159
+ ai_edge_torch_nightly-0.3.0.dev20240912.dist-info/METADATA,sha256=EjeMjRJ5PeW8Azc8hoiJeMP_WaHUDlCend4DFIeQnzc,1859
160
+ ai_edge_torch_nightly-0.3.0.dev20240912.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
161
+ ai_edge_torch_nightly-0.3.0.dev20240912.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
162
+ ai_edge_torch_nightly-0.3.0.dev20240912.dist-info/RECORD,,