ai-edge-torch-nightly 0.3.0.dev20240911-py3-none-any.whl → 0.3.0.dev20240912-py3-none-any.whl

ai_edge_torch/generative/examples/smallm/__init__.py ADDED
@@ -0,0 +1,14 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
ai_edge_torch/generative/examples/smallm/convert_to_tflite.py ADDED
@@ -0,0 +1,86 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of converting SmalLM model to multi-signature tflite model."""
+
+import os
+import pathlib
+
+import ai_edge_torch
+from ai_edge_torch.generative.examples.smallm import smallm
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+from ai_edge_torch.generative.quantize import quant_recipes
+import torch
+
+
+def convert_smallm_to_tflite(
+    checkpoint_path: str,
+    prefill_seq_len: int = 512,
+    kv_cache_max_len: int = 1024,
+    quantize: bool = True,
+):
+  """Converts SmalLM model to multi-signature tflite model.
+
+  Args:
+    checkpoint_path (str): The filepath to the model checkpoint, or directory
+      holding the checkpoint.
+    prefill_seq_len (int, optional): The maximum size of prefill input tensor.
+      Defaults to 512.
+    kv_cache_max_len (int, optional): The maximum size of KV cache buffer,
+      including both prefill and decode. Defaults to 1024.
+    quantize (bool, optional): Whether the model should be quantized. Defaults
+      to True.
+  """
+  pytorch_model = smallm.build_model(
+      checkpoint_path, kv_cache_max_len=kv_cache_max_len
+  )
+  # Tensors used to trace the model graph during conversion.
+  prefill_tokens = torch.full((1, prefill_seq_len), 0, dtype=torch.long)
+  prefill_input_pos = torch.arange(0, prefill_seq_len)
+  decode_token = torch.tensor([[0]], dtype=torch.long)
+  decode_input_pos = torch.tensor([0], dtype=torch.int64)
+  kv = kv_utils.KVCache.from_model_config(pytorch_model.config)
+
+  quant_config = quant_recipes.full_int8_dynamic_recipe() if quantize else None
+  edge_model = (
+      ai_edge_torch.signature(
+          'prefill',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': prefill_tokens,
+              'input_pos': prefill_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .signature(
+          'decode',
+          pytorch_model,
+          sample_kwargs={
+              'tokens': decode_token,
+              'input_pos': decode_input_pos,
+              'kv_cache': kv,
+          },
+      )
+      .convert(quant_config=quant_config)
+  )
+  quant_suffix = 'q8' if quantize else 'f32'
+  edge_model.export(
+      f'/tmp/smallm_{quant_suffix}_seq{prefill_seq_len}_ekv{kv_cache_max_len}.tflite'
+  )
+
+
+if __name__ == '__main__':
+  path = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smallm')
+  convert_smallm_to_tflite(path)
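A minimal usage sketch of the new conversion entry point (illustrative only, not part of the wheel). It assumes a SmalLM checkpoint has already been downloaded to the same default location the script's __main__ block uses, and it writes /tmp/smallm_q8_seq512_ekv1024.tflite containing the 'prefill' and 'decode' signatures.

import os
import pathlib

from ai_edge_torch.generative.examples.smallm import convert_to_tflite

# Same default checkpoint location as the script above; adjust as needed.
checkpoint = os.path.join(pathlib.Path.home(), 'Downloads/llm_data/smallm')

# quantize=True selects the dynamic int8 recipe; pass quantize=False for a
# float32 export (the output file is then suffixed 'f32' instead of 'q8').
convert_to_tflite.convert_smallm_to_tflite(
    checkpoint,
    prefill_seq_len=512,
    kv_cache_max_len=1024,
    quantize=True,
)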
ai_edge_torch/generative/examples/smallm/smallm.py ADDED
@@ -0,0 +1,119 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Example of building a SmalLM model."""
+
+import copy
+import os
+import pathlib
+
+from ai_edge_torch.generative.examples.tiny_llama import tiny_llama
+from ai_edge_torch.generative.layers import kv_cache as kv_utils
+import ai_edge_torch.generative.layers.model_config as cfg
+import ai_edge_torch.generative.utilities.loader as loading_utils
+import numpy as np
+import torch
+from torch import nn
+
+TENSOR_NAMES = copy.copy(tiny_llama.TENSOR_NAMES)
+# SmalLM re-uses the embedding as the head projection layer.
+TENSOR_NAMES.lm_head = None
+
+
+class SmalLM(tiny_llama.TinyLlama):
+  """A SmalLM model built from the Edge Generative API layers.
+
+  SmalLM shares the same architecture as TinyLlama, but with different model
+  sizes.
+  """
+
+  def __init__(self, config: cfg.ModelConfig):
+    super().__init__(config)
+    # SmalLM re-uses the embedding as the head projection layer.
+    self.lm_head.weight.data = self.tok_embedding.weight.data
+
+
+def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
+  """Returns the model config for a SmalLM 135M model.
+
+  Args:
+    kv_cache_max_len (int): The maximum sequence length of the KV cache. Default
+      is 1024.
+
+  Returns:
+    The model config for a SmalLM model.
+  """
+  attn_config = cfg.AttentionConfig(
+      num_heads=9,
+      head_dim=64,
+      num_query_groups=3,
+      rotary_percentage=1.0,
+  )
+  ff_config = cfg.FeedForwardConfig(
+      type=cfg.FeedForwardType.GATED,
+      activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
+      intermediate_size=1536,
+  )
+  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
+  config = cfg.ModelConfig(
+      vocab_size=49152,
+      num_layers=30,
+      max_seq_len=2048,
+      embedding_dim=576,
+      kv_cache_max_len=kv_cache_max_len,
+      attn_config=attn_config,
+      ff_config=ff_config,
+      pre_attention_norm_config=norm_config,
+      post_attention_norm_config=norm_config,
+      final_norm_config=norm_config,
+      enable_hlfb=True,
+  )
+  return config
+
+
+def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
+  config = get_model_config(**kwargs)
+  model = SmalLM(config)
+  loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
+  # since embedding and lm-head use the same weight, we need to set strict
+  # to False.
+  loader.load(model, strict=False)
+  model.eval()
+  return model
+
+
+def define_and_run(checkpoint_path: str) -> None:
+  """Instantiates and runs a SmalLM model."""
+
+  current_dir = pathlib.Path(__file__).parent.resolve()
+  smallm_goldens = torch.load(current_dir / "smallm_lm_logits.pt")
+  kv_cache_max_len = 1024
+  model = build_model(checkpoint_path, kv_cache_max_len=kv_cache_max_len)
+  idx = torch.from_numpy(np.array([[1, 2, 3, 4]]))
+  tokens = torch.full((1, kv_cache_max_len), 0, dtype=torch.long, device="cpu")
+  tokens[0, :4] = idx
+  input_pos = torch.arange(0, kv_cache_max_len)
+  kv = kv_utils.KVCache.from_model_config(model.config)
+  output = model.forward(tokens, input_pos, kv)
+  assert torch.allclose(
+      smallm_goldens, output["logits"][0, idx.shape[1] - 1, :], atol=1e-05
+  )
+
+
+if __name__ == "__main__":
+  input_checkpoint_path = os.path.join(
+      pathlib.Path.home(), "Downloads/llm_data/smallm"
+  )
+  define_and_run(input_checkpoint_path)
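Because SmalLM ties lm_head to the token embedding (and TENSOR_NAMES.lm_head is set to None), the checkpoint is loaded with strict=False. A hypothetical sanity check of that weight tying, using only the module added above:

from ai_edge_torch.generative.examples.smallm import smallm

# Building the model without a checkpoint is enough to observe the tying done
# in SmalLM.__init__: lm_head and tok_embedding share the same storage.
model = smallm.SmalLM(smallm.get_model_config(kv_cache_max_len=1024))
assert (
    model.lm_head.weight.data_ptr() == model.tok_embedding.weight.data_ptr()
)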
ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py CHANGED
@@ -44,7 +44,7 @@ TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
 )
 
 
-class TinyLLamma(nn.Module):
+class TinyLlama(nn.Module):
   """A TinyLlama model built from the Edge Generative API layers."""
 
   def __init__(self, config: cfg.ModelConfig):
@@ -169,7 +169,7 @@ def get_fake_model_config(**kwargs) -> cfg.ModelConfig:
 
 def build_model(checkpoint_path: str, **kwargs) -> nn.Module:
   config = get_model_config(**kwargs)
-  model = TinyLLamma(config)
+  model = TinyLlama(config)
   loader = loading_utils.ModelLoader(checkpoint_path, TENSOR_NAMES)
   loader.load(model)
   model.eval()
ai_edge_torch/generative/layers/builder.py CHANGED
@@ -59,9 +59,11 @@ def build_norm(dim: int, config: cfg.NormalizationConfig):
         zero_centered_gamma=config.zero_centered,
     )
   elif config.type == cfg.NormalizationType.LAYER_NORM:
-    return nn.LayerNorm(dim, eps=config.epsilon)
+    return normalization.LayerNorm(dim, config.epsilon, config.enable_hlfb)
   elif config.type == cfg.NormalizationType.GROUP_NORM:
-    return nn.GroupNorm(config.group_num, dim, config.epsilon)
+    return normalization.GroupNorm(
+        config.group_num, dim, config.epsilon, config.enable_hlfb
+    )
   else:
     raise ValueError("Unsupported norm type.")
 
ai_edge_torch/generative/layers/model_config.py CHANGED
@@ -104,6 +104,7 @@ class NormalizationConfig:
   """Normalizater parameters."""
 
   type: NormalizationType = NormalizationType.NONE
+  enable_hlfb: bool = False
   epsilon: float = 1e-5
   zero_centered: bool = False
   # Number of groups used in group normalization.
ai_edge_torch/generative/layers/normalization.py CHANGED
@@ -14,7 +14,10 @@
 # ==============================================================================
 # Common normalization layers.
 
+from ai_edge_torch.hlfb import StableHLOCompositeBuilder
 import torch
+from torch import nn
+import torch.nn.functional as F
 
 
 # Implementation for RMSNorm from: https://arxiv.org/abs/1910.07467
@@ -58,3 +61,158 @@ class RMSNorm(torch.nn.Module):
       return output * (1 + self.weight)
     else:
       return output * self.weight
+
+
+class GroupNorm(torch.nn.Module):
+
+  def __init__(
+      self,
+      group_num: int,
+      dim: int,
+      eps: float = 1e-5,
+      enable_hlfb: bool = False,
+  ):
+    """Initialize the GroupNorm layer.
+
+    Args:
+      group_num (int): Number of groups to separate the channels into.
+      dim (int): Dimension of the input tensor.
+      eps (float): A small float value to ensure numerical stability (default:
+        1e-5).
+      enable_hlfb (bool): Whether to convert this normalization into a single
+        op.
+    """
+    super().__init__()
+    self.enable_hlfb = enable_hlfb
+    self.group_num = group_num
+    self.eps = eps
+    self.weight = torch.nn.Parameter(torch.ones(dim))
+    self.bias = torch.nn.Parameter(torch.ones(dim))
+
+  def forward(self, x):
+    """Running the forward pass of GroupNorm layer.
+
+    Args:
+      x (torch.Tensor): input tensor.
+
+    Returns:
+      torch.Tensor: output tensor after applying GroupNorm.
+    """
+    if self.enable_hlfb:
+      return group_norm_with_hlfb(
+          x,
+          self.weight,
+          self.bias,
+          self.group_num,
+          self.eps,
+      )
+    else:
+      return F.group_norm(x, self.group_num, self.weight, self.bias, self.eps)
+
+
+class LayerNorm(torch.nn.Module):
+
+  def __init__(self, dim: int, eps: float = 1e-5, enable_hlfb: bool = False):
+    """Initialize the LayerNorm layer.
+
+    Args:
+      dim (int): dimension of the input tensor.
+      eps (float): A small float value to ensure numerical stability (default:
+        1e-5).
+      enable_hlfb (bool): Whether to convert this normalization into a single
+        op.
+    """
+    super().__init__()
+    self.enable_hlfb = enable_hlfb
+    self.eps = eps
+    self.weight = torch.nn.Parameter(torch.ones(dim))
+    self.bias = torch.nn.Parameter(torch.ones(dim))
+
+  def forward(self, x):
+    """Running the forward pass of LayerNorm layer.
+
+    Args:
+      x (torch.Tensor): input tensor.
+
+    Returns:
+      torch.Tensor: output tensor after applying LayerNorm.
+    """
+    if self.enable_hlfb:
+      return layer_norm_with_hlfb(
+          x,
+          self.weight,
+          self.bias,
+          self.eps,
+      )
+    else:
+      return F.layer_norm(
+          x,
+          x.shape,
+          self.weight.broadcast_to(x.shape),
+          self.bias.broadcast_to(x.shape),
+          self.eps,
+      )
+
+
+def group_norm_with_hlfb(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    b: torch.Tensor,
+    num_groups: int,
+    eps: float,
+):
+  """Group Normalization with high-level function boundary enabled.
+
+  Args:
+    x (torch.Tensor): Input tensor for Group Normalization, with BCHW shape.
+    w (torch.Tensor): The weight tensor for the normalization.
+    b (torch.Tensor): The bias tensor for the normalization.
+    num_groups (int): Number of groups to separate the channels into.
+    eps (float): A small float value to ensure numerical stability.
+
+  Returns:
+    The output tensor of Group Normalization.
+  """
+  x = torch.permute(x, (0, 2, 3, 1))
+
+  builder = StableHLOCompositeBuilder(
+      name="odml.group_norm", attr={"num_groups": num_groups, "eps": eps}
+  )
+  x, w, b = builder.mark_inputs(x, w, b)
+  x = torch.permute(x, (0, 3, 1, 2))
+  y = F.group_norm(x, num_groups, weight=w, bias=b, eps=eps)
+  y = torch.permute(y, (0, 2, 3, 1))
+  y = builder.mark_outputs(y)
+
+  y = torch.permute(y, (0, 3, 1, 2))
+  return y
+
+
+def layer_norm_with_hlfb(
+    x: torch.Tensor,
+    w: torch.Tensor,
+    b: torch.Tensor,
+    eps: float,
+):
+  """Layer Normalization with high-level function boundary enabled.
+
+  Args:
+    x (torch.Tensor): Input tensor for Layer Normalization.
+    w (torch.Tensor): The weight tensor for the normalization.
+    b (torch.Tensor): The bias tensor for the normalization.
+    eps (float): A small float value to ensure numerical stability.
+
+  Returns:
+    The output tensor of Layer Normalization.
+  """
+  builder = StableHLOCompositeBuilder(name="odml.layer_norm", attr={"eps": eps})
+  x, w, b = builder.mark_inputs(x, w, b)
+  y = F.layer_norm(
+      x,
+      x.shape,
+      weight=w.broadcast_to(x.shape),
+      bias=b.broadcast_to(x.shape),
+      eps=eps,
+  )
+  y = builder.mark_outputs(y)
+  return y
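A short sketch of how the new enable_hlfb plumbing fits together (illustrative, using only the APIs shown in this diff): NormalizationConfig gains an enable_hlfb field, build_norm forwards it to the new LayerNorm/GroupNorm wrappers, and with the flag set their forward pass is marked as a single odml.layer_norm / odml.group_norm StableHLO composite instead of the decomposed ops.

import ai_edge_torch.generative.layers.builder as builder
import ai_edge_torch.generative.layers.model_config as cfg

norm_config = cfg.NormalizationConfig(
    type=cfg.NormalizationType.LAYER_NORM,
    enable_hlfb=True,  # new field introduced in this release
)
# build_norm now returns the package's own LayerNorm wrapper; when the model
# is traced during ai_edge_torch conversion, layer_norm_with_hlfb wraps the
# whole normalization in one "odml.layer_norm" composite.
norm = builder.build_norm(dim=576, config=norm_config)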
ai_edge_torch/generative/layers/unet/blocks_2d.py CHANGED
@@ -122,7 +122,6 @@ class AttentionBlock2D(nn.Module):
         config.attention_batch_size,
         config.dim,
         config.attention_config,
-        0,
         enable_hlfb=config.enable_hlfb,
     )
 
@@ -180,7 +179,6 @@ class CrossAttentionBlock2D(nn.Module):
         config.query_dim,
         config.cross_dim,
         config.attention_config,
-        0,
         enable_hlfb=config.enable_hlfb,
     )
 
ai_edge_torch/generative/test/test_loader.py CHANGED
@@ -71,7 +71,7 @@ class TestLoader(googletest.TestCase):
     safetensors.torch.save_file(test_weights, file_path)
     cfg = tiny_llama.get_model_config()
     cfg.num_layers = 1
-    model = tiny_llama.TinyLLamma(cfg)
+    model = tiny_llama.TinyLlama(cfg)
 
     loader = loading_utils.ModelLoader(file_path, tiny_llama.TENSOR_NAMES)
     # if returns successfully, it means all the tensors were initiallized.
ai_edge_torch/generative/test/test_model_conversion.py CHANGED
@@ -123,7 +123,7 @@ class TestModelConversion(googletest.TestCase):
   )
   def test_tiny_llama_multisig(self):
     config = tiny_llama.get_fake_model_config()
-    pytorch_model = tiny_llama.TinyLLamma(config).eval()
+    pytorch_model = tiny_llama.TinyLlama(config).eval()
 
     # prefill
     seq_len = 10
ai_edge_torch/odml_torch/lowerings/__init__.py CHANGED
@@ -16,6 +16,7 @@ from . import _basic
 from . import _batch_norm
 from . import _convolution
 from . import _jax_lowerings
+from . import _layer_norm
 from . import context
 from . import registry
 from . import utils
ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py CHANGED
@@ -167,7 +167,6 @@ lower_by_torch_xla2(torch.ops.aten.mul.Scalar)
 lower_by_torch_xla2(torch.ops.aten.mul.Tensor)
 lower_by_torch_xla2(torch.ops.aten.native_batch_norm)
 lower_by_torch_xla2(torch.ops.aten.native_group_norm)
-lower_by_torch_xla2(torch.ops.aten.native_layer_norm)
 lower_by_torch_xla2(torch.ops.aten.native_layer_norm_backward)
 lower_by_torch_xla2(torch.ops.aten.ne)
 lower_by_torch_xla2(torch.ops.aten.neg)
ai_edge_torch/odml_torch/lowerings/_layer_norm.py ADDED
@@ -0,0 +1,78 @@
+# Copyright 2024 The AI Edge Torch Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Provides lowering for coreaten to stablehlo for LayerNorm."""
+
+import math
+from typing import Optional
+from ai_edge_torch.odml_torch.lowerings import registry
+from ai_edge_torch.odml_torch.lowerings import utils
+from jax._src.lib.mlir import ir
+from jax._src.lib.mlir.dialects import hlo as stablehlo
+import torch
+
+
+# native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight,
+# Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)
+@registry.lower(torch.ops.aten.native_layer_norm)
+def _aten_native_layer_norm(
+    lctx,
+    data: ir.Value,
+    normalized_shape: list[int],
+    weight: Optional[ir.Value],
+    bias: Optional[ir.Value],
+    eps: float,
+):
+  data_type: ir.RankedTensorType = data.type
+  unnormalized_count = math.prod(data_type.shape) // math.prod(normalized_shape)
+  dest_shape = [
+      1,
+      unnormalized_count,
+      math.prod(normalized_shape),
+  ]
+  dest_type = ir.RankedTensorType.get(dest_shape, data_type.element_type)
+
+  reshaped_data = stablehlo.reshape(dest_type, data)
+
+  one = utils.splat(1, data_type.element_type, [unnormalized_count])
+  zero = utils.splat(0, data_type.element_type, [unnormalized_count])
+  output, mean, var = stablehlo.batch_norm_training(
+      reshaped_data, one, zero, eps, 1
+  )
+  eps_splat = utils.splat(eps, var.type.element_type, var.type.shape)
+  rstd = stablehlo.rsqrt(stablehlo.add(var, eps_splat))
+
+  stats_shape = data_type.shape[: -1 * len(normalized_shape)] + [1] * len(
+      normalized_shape
+  )
+  stats_type = ir.RankedTensorType.get(stats_shape, data_type.element_type)
+  mean = stablehlo.reshape(stats_type, mean)
+  rstd = stablehlo.reshape(stats_type, rstd)
+
+  output = stablehlo.reshape(data_type, output)
+
+  data_rank = len(data_type.shape)
+  normalized_rank = len(normalized_shape)
+  if weight is not None:
+    weight = stablehlo.broadcast_in_dim(
+        data_type, weight, list(range(data_rank - normalized_rank, data_rank))
+    )
+    output = stablehlo.multiply(weight, output)
+  if bias is not None:
+    bias = stablehlo.broadcast_in_dim(
+        data_type, bias, list(range(data_rank - normalized_rank, data_rank))
+    )
+    output = stablehlo.add(bias, output)
+
+  return output, mean, rstd
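The lowering above reshapes the input to [1, unnormalized_count, prod(normalized_shape)] and reuses stablehlo.batch_norm_training (feature dimension 1, scale 1, offset 0) to obtain a per-row mean and variance, then rstd = 1/sqrt(var + eps). A small PyTorch sanity check of that reshape trick (the check itself is illustrative and not part of the wheel):

import math

import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8)  # normalize over the trailing axis
normalized_shape = [8]
eps = 1e-5

# Flatten to one "feature" per normalized row, as the lowering does, and
# normalize each row with its own (biased) mean/variance.
rows = x.reshape(1, -1, math.prod(normalized_shape))
mean = rows.mean(dim=-1, keepdim=True)
var = rows.var(dim=-1, unbiased=False, keepdim=True)
y = ((rows - mean) * torch.rsqrt(var + eps)).reshape(x.shape)

assert torch.allclose(y, F.layer_norm(x, normalized_shape, eps=eps), atol=1e-5)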
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
 # limitations under the License.
 # ==============================================================================
 
-__version__ = "0.3.0.dev20240911"
+__version__ = "0.3.0.dev20240912"
{ai_edge_torch_nightly-0.3.0.dev20240911.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.3.0.dev20240911
+Version: 0.3.0.dev20240912
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.3.0.dev20240911.dist-info → ai_edge_torch_nightly-0.3.0.dev20240912.dist-info}/RECORD RENAMED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=48qP37uHT90YPs4eIUQxCiWVwqGEX3idCUs6mQKvX1U,116
 ai_edge_torch/config.py,sha256=PCd9PVrbUNeVIUDFUCnW4goDWU4bjouK28yMYU6VOi0,877
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
-ai_edge_torch/version.py,sha256=vCTKdj1Lc6r2UbJhIZpLdXauJSS0KfBLzgy9e3D16AA,706
+ai_edge_torch/version.py,sha256=Li1VzlXx5ExydpfV93yVAd78cF1L_g3x30-daYdgsLA,706
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=kcv_QgNgeyDmrqwdzHicGNP68w6zF7GJg7YkMEIXp4Q,3759
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -47,6 +47,9 @@ ai_edge_torch/generative/examples/gemma/gemma2.py,sha256=X6WfUCDJDEqyyEAYGq1lmKt
 ai_edge_torch/generative/examples/phi/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/phi/convert_to_tflite.py,sha256=vqEpZVmB0_wMKcAl6RXm7W57DqPTzEdVVN6W2Z-QYzI,3011
 ai_edge_torch/generative/examples/phi/phi2.py,sha256=BzvUrClFx5HKf6PYzJc7ba2O3AwYUJE485u5GSOiPy4,6851
+ai_edge_torch/generative/examples/smallm/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
+ai_edge_torch/generative/examples/smallm/convert_to_tflite.py,sha256=aqqxQMBBO_dtGB1iZ1tpF8hbGpdZkx0VIz62ZqfVMCc,3036
+ai_edge_torch/generative/examples/smallm/smallm.py,sha256=j7SDdcX0WvgQWgpaAi7Gi39Jf0-w9D9PftDbugNrN1M,3919
 ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIGa89zC5KSRcJlw2c4ITuw8KcchtfmF55f4g,3545
 ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=0WniBWQ6_NcQc5WycX3YRRX7Os9AGQSxfc1m2HKBqg8,4479
@@ -71,21 +74,21 @@ ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=5wj2RmQRIwD6O_
 ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=PbWpfg3AOEZjI1FlnZCxRD-kIKtdkR9AOZ6l-9-TpRA,5664
 ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=y4LiWhwgflqrg4WWh3wq5ei3VOT_cV0A62x62qptQiM,3070
-ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=RK7oisSwIPqUWwwE1P-hDJlEnRJJ_V29UjUCxt4xETE,6780
+ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=HwoEWls-uJ7oHj0HYxJtgZZhgiBR_OQPXlR6l14vm5E,6778
 ai_edge_torch/generative/fx_passes/__init__.py,sha256=fmNNXawJ722M4cTUuTx289rT0NHxBEsOy_k8baqCOms,1173
 ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=sXis0U4u-RoIp_NyrmWJNnqFqpqRuZOrhfsJIO6rMps,2028
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/attention.py,sha256=ee0KHRakhjLjawP32FY2EntxOkyPvjiEZChLnBn_HPc,12601
 ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
-ai_edge_torch/generative/layers/builder.py,sha256=xb7rjADv3Jm4qfmlYtg6oLLe7ReDE9UjsEqiejPpDD8,4346
+ai_edge_torch/generative/layers/builder.py,sha256=KMwMfZ08r5CXHhcPVZ72nZnIAcsMAIKsv7-QPntlqgI,4418
 ai_edge_torch/generative/layers/feed_forward.py,sha256=uto7xtwx6jPkk1GZ2x7pSTentQzRrPSKw4_PSE12ahA,3525
 ai_edge_torch/generative/layers/kv_cache.py,sha256=WDu03NQwkDCrrrT9Du_3ZOxlURZz3XDbS1PLzFozhMI,6013
-ai_edge_torch/generative/layers/model_config.py,sha256=WpZ9djUBAZddyeSODHDaVMG37EQqfzGGrlMPi8AA-Hc,5752
-ai_edge_torch/generative/layers/normalization.py,sha256=u8lv0p-ktKcRqCDlOqZQa9WQcfDK9JM2IaUQFQdn7xs,1860
+ai_edge_torch/generative/layers/model_config.py,sha256=03tjidDM1uo_H0jsHNjYEUR5R1FEckc1GIxSoE7ItQQ,5780
+ai_edge_torch/generative/layers/normalization.py,sha256=iod9oNkoDS5m-yFY_Y_XMyvCU5a88ESd_s5WY34ErKA,6129
 ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
 ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=VW-VP8e7FTSPCdu-6DVxpwNrIdgX0R_kq6F6MSEiyXE,3848
 ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
-ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=V4zUAqjWeBseMPG9B-93LDv1LM3Dds6Q-H0NxY0koSA,27212
+ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=cpygyJccLq6KHKxV7oz4YKh529YLjC9isupnsVmPi0A,27190
 ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
 ai_edge_torch/generative/layers/unet/model_config.py,sha256=NvBJj09a7ZC-ChGE_ex-_kLnE_fjzrY6txbLSh1pMKA,9208
 ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -97,8 +100,8 @@ ai_edge_torch/generative/quantize/quant_recipes.py,sha256=0Kvr_o7pbMnE8VMe6Ml0FB
 ai_edge_torch/generative/quantize/supported_schemes.py,sha256=FjdycEOvxRgBmQdZVufetPvkDoD7rUowIOSKV9oV5Kk,1418
 ai_edge_torch/generative/test/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/test/test_kv_cache.py,sha256=FU2rmU03Lp-vZ5wWXXCao1WEw7xbpqebFMANL_O2chA,3713
-ai_edge_torch/generative/test/test_loader.py,sha256=_y5EHGgoNOmCuYonsB81UJScHVsTAQXUVd44czMAw6k,3379
-ai_edge_torch/generative/test/test_model_conversion.py,sha256=OmAHSGkxTNzDX5_kYjK7pxlPk0YZLqL9YiVIJQfuvPc,5889
+ai_edge_torch/generative/test/test_loader.py,sha256=8y74ChO3CZCfEi1eCf3-w47kRgAI4qPYCXpi8rTQXMA,3378
+ai_edge_torch/generative/test/test_model_conversion.py,sha256=SIv7_sc5qHvbHFN8SbAfY00iXGvH7J6cJLkERU_cd5k,5888
 ai_edge_torch/generative/test/test_model_conversion_large.py,sha256=F3q3K9ZgWBzlLy4WpE8-w6UWSuJ-UoJwMm3N6Zb3Y14,5016
 ai_edge_torch/generative/test/test_quantize.py,sha256=kY_NRpF-v1i4clqI1CFFWEagJv-5PzBDkeJ2fInl9_w,5913
 ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
@@ -135,11 +138,12 @@ ai_edge_torch/odml_torch/debuginfo/_op_polyfill.py,sha256=IvOBQyROI9WHS3umHRxsDW
 ai_edge_torch/odml_torch/jax_bridge/__init__.py,sha256=Jco5zvejxuyl9xHQxZICAKbkgH7x38qPlwUUpD7S15Q,730
 ai_edge_torch/odml_torch/jax_bridge/_wrap.py,sha256=drN3L0uTsSjkluKgt6Ngq7b5HLReE_7iAitHpZ9PKqE,5428
 ai_edge_torch/odml_torch/jax_bridge/utils.py,sha256=T8isGc896VrHZ6c_L5pYmLpolQ7ibcOlgWfPuVFPzIg,2264
-ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=GqYk6oBJw7KWeG4_6gxSu_OvYhjJcC2FpGzWPPEdH6w,933
+ai_edge_torch/odml_torch/lowerings/__init__.py,sha256=dE_qzh-OnCNjWzqs1-PHs5PNlRF726qMQKM3tkwAzEs,959
 ai_edge_torch/odml_torch/lowerings/_basic.py,sha256=wV8AUK8dvjLUy3qjqw_IxpiYVDWUMPNZRfi3XYE_hDs,6972
 ai_edge_torch/odml_torch/lowerings/_batch_norm.py,sha256=PaLI0BB6pdBW1VyfW8VTOT_Be-ZcqYdNOsyfzKfq8Cg,2064
 ai_edge_torch/odml_torch/lowerings/_convolution.py,sha256=v1VdKmL8YLJv3PR9VgyNghO83A25PpTzY2ZUAJqlq3Q,6847
-ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=s-cT_tIQHu7w5hXl8MCixRxLlHplpXW-UCzHT9TY--o,10621
+ai_edge_torch/odml_torch/lowerings/_jax_lowerings.py,sha256=Ii1akrKLhRTkZ715JxXBBGKv3jGfXReXMQCYNzSnxmM,10567
+ai_edge_torch/odml_torch/lowerings/_layer_norm.py,sha256=1ePJs7oIdUkVdMddFsXMc53qTkEKqGz0ZhQQoNzBa10,2862
 ai_edge_torch/odml_torch/lowerings/context.py,sha256=jslcCv7r_HtImSRTxJwHAUV_QCu9Jub51lovmoBkmFA,1295
 ai_edge_torch/odml_torch/lowerings/registry.py,sha256=ES3x_RJ22T5rlmMrlomex2DdcZbhlyVJ7_HS3rjz3Uk,2851
 ai_edge_torch/odml_torch/lowerings/utils.py,sha256=NczqpsSd3Fn7yVcPC3qllemiZxxDAZgcW1T5l8-W9fE,5593
@@ -151,8 +155,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
 ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.3.0.dev20240911.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
-ai_edge_torch_nightly-0.3.0.dev20240911.dist-info/METADATA,sha256=caHeAQX6pxEskue_BvgwkTfZEsG55rXHFwPDcV9oCN8,1859
-ai_edge_torch_nightly-0.3.0.dev20240911.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
-ai_edge_torch_nightly-0.3.0.dev20240911.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
-ai_edge_torch_nightly-0.3.0.dev20240911.dist-info/RECORD,,
+ai_edge_torch_nightly-0.3.0.dev20240912.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.3.0.dev20240912.dist-info/METADATA,sha256=EjeMjRJ5PeW8Azc8hoiJeMP_WaHUDlCend4DFIeQnzc,1859
+ai_edge_torch_nightly-0.3.0.dev20240912.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
+ai_edge_torch_nightly-0.3.0.dev20240912.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.3.0.dev20240912.dist-info/RECORD,,