ai-edge-torch-nightly 0.2.0.dev20240608__py3-none-any.whl → 0.2.0.dev20240610__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


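This version of ai-edge-torch-nightly might be problematic; see the registry's page for this release for more details. One concrete issue is visible in the diff below: `ai_edge_torch/generative/layers/model_config.py` replaces the `ActivationType.GELU_QUICK` enum member with `GE_GLU`, yet `ai_edge_torch/generative/layers/builder.py` (`get_activation`) and `ai_edge_torch/generative/examples/stable_diffusion/clip.py` still reference `ActivationType.GELU_QUICK`, which would raise `AttributeError` when those lines are evaluated.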

Files changed (19)
  1. ai_edge_torch/generative/examples/gemma/gemma.py +1 -1
  2. ai_edge_torch/generative/examples/phi2/phi2.py +1 -1
  3. ai_edge_torch/generative/examples/stable_diffusion/clip.py +1 -1
  4. ai_edge_torch/generative/examples/stable_diffusion/decoder.py +4 -4
  5. ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +6 -1
  6. ai_edge_torch/generative/examples/t5/t5.py +1 -1
  7. ai_edge_torch/generative/examples/test_models/toy_model.py +1 -1
  8. ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py +1 -1
  9. ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py +1 -1
  10. ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py +1 -1
  11. ai_edge_torch/generative/layers/builder.py +31 -9
  12. ai_edge_torch/generative/layers/model_config.py +10 -2
  13. ai_edge_torch/generative/layers/unet/blocks_2d.py +4 -4
  14. ai_edge_torch/generative/layers/unet/model_config.py +4 -4
  15. {ai_edge_torch_nightly-0.2.0.dev20240608.dist-info → ai_edge_torch_nightly-0.2.0.dev20240610.dist-info}/METADATA +1 -1
  16. {ai_edge_torch_nightly-0.2.0.dev20240608.dist-info → ai_edge_torch_nightly-0.2.0.dev20240610.dist-info}/RECORD +19 -19
  17. {ai_edge_torch_nightly-0.2.0.dev20240608.dist-info → ai_edge_torch_nightly-0.2.0.dev20240610.dist-info}/LICENSE +0 -0
  18. {ai_edge_torch_nightly-0.2.0.dev20240608.dist-info → ai_edge_torch_nightly-0.2.0.dev20240610.dist-info}/WHEEL +0 -0
  19. {ai_edge_torch_nightly-0.2.0.dev20240608.dist-info → ai_edge_torch_nightly-0.2.0.dev20240610.dist-info}/top_level.txt +0 -0
@@ -116,7 +116,7 @@ def get_model_config_2b(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
116
116
  )
117
117
  ff_config = cfg.FeedForwardConfig(
118
118
  type=cfg.FeedForwardType.GATED,
119
- activation=cfg.ActivationType.GELU_TANH,
119
+ activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
120
120
  intermediate_size=16384,
121
121
  )
122
122
  norm_config = cfg.NormalizationConfig(
@@ -112,7 +112,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
112
112
  )
113
113
  ff_config = cfg.FeedForwardConfig(
114
114
  type=cfg.FeedForwardType.SEQUENTIAL,
115
- activation=cfg.ActivationType.GELU_TANH,
115
+ activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
116
116
  intermediate_size=10240,
117
117
  use_bias=True,
118
118
  )
@@ -90,7 +90,7 @@ def get_model_config() -> cfg.ModelConfig:
90
90
 
91
91
  ff_config = cfg.FeedForwardConfig(
92
92
  type=cfg.FeedForwardType.SEQUENTIAL,
93
- activation=cfg.ActivationType.GELU_QUICK,
93
+ activation=cfg.ActivationConfig(cfg.ActivationType.GELU_QUICK),
94
94
  intermediate_size=embedding_dim * 4,
95
95
  use_bias=True,
96
96
  )
@@ -221,7 +221,7 @@ class Decoder(nn.Module):
221
221
  in_channels=prev_output_channel,
222
222
  out_channels=block_out_channels,
223
223
  normalization_config=config.normalization_config,
224
- activation_type=config.activation_type,
224
+ activation_config=config.activation_config,
225
225
  num_layers=config.layers_per_block,
226
226
  add_upsample=not_final_block,
227
227
  upsample_conv=True,
@@ -235,7 +235,7 @@ class Decoder(nn.Module):
235
235
  self.final_norm = layers_builder.build_norm(
236
236
  block_out_channels, config.normalization_config
237
237
  )
238
- self.act_fn = layers_builder.get_activation(config.activation_type)
238
+ self.act_fn = layers_builder.get_activation(config.activation_config)
239
239
  self.conv_out = nn.Conv2d(
240
240
  block_out_channels,
241
241
  config.out_channels,
@@ -287,7 +287,7 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
287
287
  mid_block_config = unet_cfg.MidBlock2DConfig(
288
288
  in_channels=block_out_channels[-1],
289
289
  normalization_config=norm_config,
290
- activation_type=layers_cfg.ActivationType.SILU,
290
+ activation_config=layers_cfg.ActivationConfig(layers_cfg.ActivationType.SILU),
291
291
  num_layers=1,
292
292
  attention_block_config=att_config,
293
293
  )
@@ -296,7 +296,7 @@ def get_model_config() -> unet_cfg.AutoEncoderConfig:
296
296
  in_channels=in_channels,
297
297
  latent_channels=latent_channels,
298
298
  out_channels=out_channels,
299
- activation_type=layers_cfg.ActivationType.SILU,
299
+ activation_config=layers_cfg.ActivationConfig(layers_cfg.ActivationType.SILU),
300
300
  block_out_channels=block_out_channels,
301
301
  scaling_factor=scaling_factor,
302
302
  layers_per_block=layers_per_block,
@@ -130,7 +130,7 @@ class Upsample(nn.Module):
130
130
  self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
131
131
 
132
132
  def forward(self, x):
133
- x = F.interpolate(x, scale_factor=2, mode='nearest')
133
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
134
134
  return self.conv(x)
135
135
 
136
136
 
@@ -237,3 +237,8 @@ class Diffusion(nn.Module):
237
237
  output = self.unet(latent, context, time)
238
238
  output = self.final(output)
239
239
  return output
240
+
241
+
242
+ if __name__ == "__main__":
243
+ diffusion = Diffusion()
244
+ print(diffusion.state_dict().keys())
@@ -349,7 +349,7 @@ def get_model_config_t5() -> cfg.ModelConfig:
349
349
  )
350
350
  ff_config = cfg.FeedForwardConfig(
351
351
  type=cfg.FeedForwardType.SEQUENTIAL,
352
- activation=cfg.ActivationType.RELU,
352
+ activation=cfg.ActivationConfig(cfg.ActivationType.RELU),
353
353
  intermediate_size=3072,
354
354
  )
355
355
  # T5 Confirmed as RMS Norm and eps = 1e-6 TJA.
@@ -76,7 +76,7 @@ def define_and_run() -> None:
76
76
  )
77
77
  ff_config = cfg.FeedForwardConfig(
78
78
  type=cfg.FeedForwardType.GATED,
79
- activation=cfg.ActivationType.SILU,
79
+ activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
80
80
  intermediate_size=256,
81
81
  )
82
82
  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
@@ -95,7 +95,7 @@ def get_model_config() -> cfg.ModelConfig:
95
95
  )
96
96
  ff_config = cfg.FeedForwardConfig(
97
97
  type=cfg.FeedForwardType.GATED,
98
- activation=cfg.ActivationType.SILU,
98
+ activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
99
99
  intermediate_size=256,
100
100
  )
101
101
  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
@@ -83,7 +83,7 @@ def get_model_config() -> cfg.ModelConfig:
83
83
  )
84
84
  ff_config = cfg.FeedForwardConfig(
85
85
  type=cfg.FeedForwardType.GATED,
86
- activation=cfg.ActivationType.SILU,
86
+ activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
87
87
  intermediate_size=256,
88
88
  )
89
89
  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
@@ -112,7 +112,7 @@ def get_model_config(kv_cache_max_len: int = 1024) -> cfg.ModelConfig:
112
112
  )
113
113
  ff_config = cfg.FeedForwardConfig(
114
114
  type=cfg.FeedForwardType.GATED,
115
- activation=cfg.ActivationType.SILU,
115
+ activation=cfg.ActivationConfig(cfg.ActivationType.SILU),
116
116
  intermediate_size=5632,
117
117
  )
118
118
  norm_config = cfg.NormalizationConfig(type=cfg.NormalizationType.RMS_NORM)
@@ -13,6 +13,7 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
  # Builder class for individual components.
16
+ import torch
16
17
  from torch import nn
17
18
  import torch.nn.functional as F
18
19
 
@@ -21,6 +22,23 @@ import ai_edge_torch.generative.layers.model_config as cfg
21
22
  import ai_edge_torch.generative.layers.normalization as normalization
22
23
 
23
24
 
25
+ class GeGLU(nn.Module):
26
+ """GeGLU is an activation function which is a variant of GELU.
27
+
28
+ GeGLU(x) = (xW+b) * GELU(xV+c)
29
+ See: https://arxiv.org/abs/2002.05202v1
30
+
31
+ """
32
+
33
+ def __init__(self, d_in: int, d_out: int):
34
+ super().__init__()
35
+ self.proj = nn.Linear(d_in, d_out * 2)
36
+
37
+ def forward(self, x: torch.Tensor):
38
+ x, gate = self.proj(x).chunk(2, dim=-1)
39
+ return x * F.gelu(gate)
40
+
41
+
24
42
  def build_norm(dim: int, config: cfg.NormalizationConfig):
25
43
  """Builder function for normalizers.
26
44
 
@@ -81,29 +99,33 @@ def build_ff(dim: int, config: cfg.FeedForwardConfig):
81
99
  )
82
100
 
83
101
 
84
- def get_activation(type_: cfg.ActivationType):
85
- """Get pytorch callable activation from the name.
102
+ def get_activation(config: cfg.ActivationConfig):
103
+ """Get pytorch callable activation from the activation config.
86
104
 
87
105
  Args:
88
- name (string): activation's name.
106
+ config (cfg.ActivationConfig): activation config.
89
107
 
90
108
  Returns:
91
109
  Activation function.
92
110
 
93
111
  Raises:
94
- ValueError: If activation name is not supported.
112
+ ValueError: If activation config is not supported.
95
113
  """
96
- if type_ == cfg.ActivationType.SILU:
114
+ if config.type == cfg.ActivationType.LINEAR:
115
+ return lambda x: x
116
+ elif config.type == cfg.ActivationType.SILU:
97
117
  return F.silu
98
- elif type_ == cfg.ActivationType.GELU:
118
+ elif config.type == cfg.ActivationType.GELU:
99
119
  return F.gelu
100
- elif type_ == cfg.ActivationType.GELU_TANH:
120
+ elif config.type == cfg.ActivationType.GELU_TANH:
101
121
  return lambda x: F.gelu(x, approximate="tanh")
102
- elif type_ == cfg.ActivationType.GELU_QUICK:
122
+ elif config.type == cfg.ActivationType.GELU_QUICK:
103
123
  # GELU approximation that is fast but somewhat inaccurate.
104
124
  # See: https://github.com/hendrycks/GELUs
105
125
  return lambda x: x * F.sigmoid(1.702 * x)
106
- elif type_ == cfg.ActivationType.RELU:
126
+ elif config.type == cfg.ActivationType.GE_GLU:
127
+ return GeGLU(config.dim_in, config.dim_out)
128
+ elif config.type == cfg.ActivationType.RELU:
107
129
  return F.relu
108
130
  else:
109
131
  raise ValueError("Unsupported activation type.")
@@ -27,7 +27,7 @@ class ActivationType(enum.Enum):
27
27
  SILU = enum.auto()
28
28
  GELU = enum.auto()
29
29
  GELU_TANH = enum.auto()
30
- GELU_QUICK = enum.auto()
30
+ GE_GLU = enum.auto()
31
31
  RELU = enum.auto()
32
32
 
33
33
 
@@ -74,12 +74,20 @@ class AttentionConfig:
74
74
  relative_attention_max_distance: int = 0
75
75
 
76
76
 
77
+ @dataclass
78
+ class ActivationConfig:
79
+ type: ActivationType = ActivationType.LINEAR
80
+ # Dimension of input and output, used in GeGLU.
81
+ dim_in: Optional[int] = None
82
+ dim_out: Optional[int] = None
83
+
84
+
77
85
  @dataclass
78
86
  class FeedForwardConfig:
79
87
  """FeedForward module's parameters."""
80
88
 
81
89
  type: FeedForwardType
82
- activation: ActivationType
90
+ activation: ActivationConfig
83
91
  intermediate_size: int
84
92
  use_bias: bool = False
85
93
 
@@ -53,7 +53,7 @@ class ResidualBlock2D(nn.Module):
53
53
  self.conv_2 = nn.Conv2d(
54
54
  config.out_channels, config.out_channels, kernel_size=3, stride=1, padding=1
55
55
  )
56
- self.act_fn = layers_builder.get_activation(config.activation_type)
56
+ self.act_fn = layers_builder.get_activation(config.activation_config)
57
57
  if config.in_channels == config.out_channels:
58
58
  self.residual_layer = nn.Identity()
59
59
  else:
@@ -167,7 +167,7 @@ class UpDecoderBlock2D(nn.Module):
167
167
  out_channels=config.out_channels,
168
168
  time_embedding_channels=config.time_embedding_channels,
169
169
  normalization_config=config.normalization_config,
170
- activation_type=config.activation_type,
170
+ activation_config=config.activation_config,
171
171
  )
172
172
  )
173
173
  )
@@ -244,7 +244,7 @@ class MidBlock2D(nn.Module):
244
244
  out_channels=config.in_channels,
245
245
  time_embedding_channels=config.time_embedding_channels,
246
246
  normalization_config=config.normalization_config,
247
- activation_type=config.activation_type,
247
+ activation_config=config.activation_config,
248
248
  )
249
249
  )
250
250
  ]
@@ -259,7 +259,7 @@ class MidBlock2D(nn.Module):
259
259
  out_channels=config.in_channels,
260
260
  time_embedding_channels=config.time_embedding_channels,
261
261
  normalization_config=config.normalization_config,
262
- activation_type=config.activation_type,
262
+ activation_config=config.activation_config,
263
263
  )
264
264
  )
265
265
  )
@@ -39,7 +39,7 @@ class ResidualBlock2DConfig:
39
39
  in_channels: int
40
40
  out_channels: int
41
41
  normalization_config: layers_cfg.NormalizationConfig
42
- activation_type: layers_cfg.ActivationType
42
+ activation_config: layers_cfg.ActivationConfig
43
43
  # Optional time embedding channels if the residual block takes a time embedding context as input
44
44
  time_embedding_channels: Optional[int] = None
45
45
 
@@ -56,7 +56,7 @@ class UpDecoderBlock2DConfig:
56
56
  in_channels: int
57
57
  out_channels: int
58
58
  normalization_config: layers_cfg.NormalizationConfig
59
- activation_type: layers_cfg.ActivationType
59
+ activation_config: layers_cfg.ActivationConfig
60
60
  num_layers: int
61
61
  # Optional time embedding channels if the residual blocks take a time embedding context as input
62
62
  time_embedding_channels: Optional[int] = None
@@ -72,7 +72,7 @@ class UpDecoderBlock2DConfig:
72
72
  class MidBlock2DConfig:
73
73
  in_channels: int
74
74
  normalization_config: layers_cfg.NormalizationConfig
75
- activation_type: layers_cfg.ActivationType
75
+ activation_config: layers_cfg.ActivationConfig
76
76
  num_layers: int
77
77
  # Optional time embedding channels if the residual blocks take a time embedding context as input
78
78
  time_embedding_channels: Optional[int] = None
@@ -85,7 +85,7 @@ class AutoEncoderConfig:
85
85
  """Configurations of encoder/decoder in the autoencoder model."""
86
86
 
87
87
  # The activation type of encoder/decoder blocks.
88
- activation_type: layers_cfg.ActivationType
88
+ activation_config: layers_cfg.ActivationConfig
89
89
 
90
90
  # The output channels of each block.
91
91
  block_out_channels: List[int]
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.2.0.dev20240608
3
+ Version: 0.2.0.dev20240610
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -34,16 +34,16 @@ ai_edge_torch/generative/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf
34
34
  ai_edge_torch/generative/examples/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
35
35
  ai_edge_torch/generative/examples/gemma/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
36
36
  ai_edge_torch/generative/examples/gemma/convert_to_tflite.py,sha256=dZv3r24uHsTMokEdnl3nf7LpmV0q7FLnVtCuHn5AuUs,2538
37
- ai_edge_torch/generative/examples/gemma/gemma.py,sha256=YF4Ua-1lnL3qhQnh1sY5-HlYw2Dq6ZRm227XyDe7WAw,5913
37
+ ai_edge_torch/generative/examples/gemma/gemma.py,sha256=1lZfXGHmbII4rFu0U2B9NzlJCRhphxtmQtkCHQ39_uw,5935
38
38
  ai_edge_torch/generative/examples/phi2/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
39
39
  ai_edge_torch/generative/examples/phi2/convert_to_tflite.py,sha256=6nOuwx9q3AUlYcjXRRXSr_3M2JKqdJ-vUf-uE3VFYHE,2512
40
- ai_edge_torch/generative/examples/phi2/phi2.py,sha256=VvigzPQ_LJHeADTsMliwFwPe2BcnOhFgKDqr_WZ2JQ8,5540
40
+ ai_edge_torch/generative/examples/phi2/phi2.py,sha256=PMhKC6JCAMYSj2F3UmWHWK4rTcXD-B6PuehaoDccRqk,5562
41
41
  ai_edge_torch/generative/examples/stable_diffusion/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
42
42
  ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=Lo4Dq7a3Kg-lyH56iqGtqCo5UaClQHRCTDdNagXGTo8,3535
43
- ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=yUCJemEh4n8ez-yLgVU0HZAki-PZ9nY04DFjgpx9PUc,3698
43
+ ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=qU1wVEcn_biwCuDguZljhlLGzpLIqgqC31Dh_lXquQc,3720
44
44
  ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=NmgDo5uAefrhMUbYku0TKHlqzO0NVWI_M1ue8tddQR4,4024
45
- ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=Z1bnZvYtPdwNy706kixVDfL32X-R87B_WF3CcHwiz0o,11038
46
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=TfbfsmuKoGsBENF9fYIAN_SMEQNhj-kjNdqQXFJGxpg,7784
45
+ ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=meW8t-3BDdjFs5vCAf76cn6lGx49a_GcEvnVa9R5if4,11106
46
+ ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=_gEeUxa9Xyd3iLb_fyeUefHKuELVDorDlQs8e7wdXKg,7878
47
47
  ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=mgbxkeFDMkNIGmnbcFTIFPu8EWKokghiviYIOB2lE3Q,3437
48
48
  ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=FCbnwlkpYYb-tF7KscbSYjNEdg7XnuLju1cDuIRoQv8,8277
49
49
  ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=r9RqbyNvuvXOGu3ojtl7ZmbC7o4Pt8aUKAhN1yCdtEc,3397
@@ -55,29 +55,29 @@ ai_edge_torch/generative/examples/stable_diffusion/samplers/k_lms.py,sha256=iPYX
55
55
  ai_edge_torch/generative/examples/stable_diffusion/samplers/sampler.py,sha256=5iRfU5MO6GR6K3WrdddIU_9U7ZZGEEb7zGKVY1WFl-8,1340
56
56
  ai_edge_torch/generative/examples/t5/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
57
57
  ai_edge_torch/generative/examples/t5/convert_to_tflite.py,sha256=bWtwtUacvJOEDUpuYvLTgkP7oTkXKJA-Tf4FPxlD1Cw,4536
58
- ai_edge_torch/generative/examples/t5/t5.py,sha256=q2gG5RRo7RgNzvHXYC0Juh6Tgt5d_RTMSWFaYvOKiZU,21065
58
+ ai_edge_torch/generative/examples/t5/t5.py,sha256=L6YrVzUEzP-Imb8W28LdukFGrx1aWSzz1kyYK_9RFZM,21087
59
59
  ai_edge_torch/generative/examples/t5/t5_attention.py,sha256=rRgwCEdVtzcJEaGbbBjw8HxCxrCX3pXA5nelawdYiME,9036
60
60
  ai_edge_torch/generative/examples/test_models/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
61
- ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=EV07_MEG3fv9g0ZGu9gbBd5BjjrGkxCT1pv7dvhz4TI,3791
62
- ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=rzL5h7Z5DIEgfpc1pWgYHdKt2aR8ha_CUqTKQBSPBaU,5521
63
- ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=MUr6fSj2hBuYSlNbZtrBBpzqB_0WY-l_xYcd_TFFUjY,4831
61
+ ai_edge_torch/generative/examples/test_models/toy_model.py,sha256=CUXsQ_IU96NaCg9jyfeKI0Zz2iWDkJUsPJyPR1Pgz7I,3813
62
+ ai_edge_torch/generative/examples/test_models/toy_model_with_external_kv_cache.py,sha256=zwCmCnhr-vhBwHqv9i7xMasdBGVNqAGxZvWsncsJn58,5543
63
+ ai_edge_torch/generative/examples/test_models/toy_model_with_kv_cache.py,sha256=lfYUiem_Pbn3vGgPx84BeI8n7rN3-1fImwCLm8Eo2U8,4853
64
64
  ai_edge_torch/generative/examples/tiny_llama/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
65
65
  ai_edge_torch/generative/examples/tiny_llama/convert_to_tflite.py,sha256=E4I5OlC4zyl5cxiiu7uTED-zcwYRu210lP1zuT3xLBE,2566
66
- ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=hVGpuI8gpj4Rn9k4otsRE22MSLFHBDlUOgioY6Ru6VI,5629
66
+ ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=IFRLPG9wz_aLl_zV_6CETCjSM03ukA6bZqqyDLVACuw,5651
67
67
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
68
68
  ai_edge_torch/generative/layers/attention.py,sha256=Z8gXHYs6h8gaRiYAdvYUbHzg_2EmqfxiChsf_SYraAc,7902
69
69
  ai_edge_torch/generative/layers/attention_utils.py,sha256=hXhuyKblPPxKIRzlAf1YNlwHgpbj-6nReRLhRHELx5k,6350
70
- ai_edge_torch/generative/layers/builder.py,sha256=8cPL1NAutjT6Dwtyy2X7NSaTl9WCUJM5SIrBIDcEvVY,3520
70
+ ai_edge_torch/generative/layers/builder.py,sha256=jAyrR5hsSI0aimKZumyvxdJ1GovERIfsK0g-dezX2gs,4163
71
71
  ai_edge_torch/generative/layers/feed_forward.py,sha256=4j2QaSCw59Jkk_ixKDpKEj7FLRauzuExTiSNRzAjAhE,2820
72
72
  ai_edge_torch/generative/layers/kv_cache.py,sha256=4uiZLO3om5G3--kT04Jt0esEYznbkJ7QLzSHfb8mjc4,3090
73
- ai_edge_torch/generative/layers/model_config.py,sha256=72DXOsFC0buvzZp6YyVjuTVrpphAubBJ5NJWfs3kEwk,4362
73
+ ai_edge_torch/generative/layers/model_config.py,sha256=g_XJXcQOCkE-mt58fSH4-T4GY_uLeMilg6mxwDMCfz4,4557
74
74
  ai_edge_torch/generative/layers/normalization.py,sha256=M27eW3TcNK20oaXClXtfnu0lLWrAGrSKSsbegRWnj3c,1867
75
75
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=12SsCuoRuLNCwnFGe_pHDOZEBwBcqXs87Aj0PaWWw4E,1383
76
76
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=dYafGC205QE5CLIbBTCI-7eVvEGZEHzs1toPEhemeDs,3391
77
77
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
78
- ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=KuZd2oZhkCQSknSgXMBla-sfYBPUv5bZNf9RYKXHfGg,10052
78
+ ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=7mHyJYq9lq5zVYp4mEz-R8Az3FFngi711YC20KP6ED8,10066
79
79
  ai_edge_torch/generative/layers/unet/builder.py,sha256=iH0_nuY9TF2ap5h1JbGNCOonPTfrXQHcF8U0slrIREM,1210
80
- ai_edge_torch/generative/layers/unet/model_config.py,sha256=LeRGB34fQ73UlknlFpjM9U-SZIRcQDnSmDltJivX-UA,4044
80
+ ai_edge_torch/generative/layers/unet/model_config.py,sha256=sbtbDEHmMV9GLKngwjsNvqm8wovLxnlidkQbXdXkXKs,4060
81
81
  ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
82
82
  ai_edge_torch/generative/quantize/example.py,sha256=t-YwyKSPAG-OZC1DfH-0vfie2RHHpTSQjxUY-tmhu5g,1543
83
83
  ai_edge_torch/generative/quantize/quant_attrs.py,sha256=ffBALrrbrfiG_mrOr-f3B1Gc6PlAma9gtvVnfP7SDzI,1862
@@ -107,8 +107,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=ExThdTXqnWmGC3-F6sdXbXr8nYzkEe_qCz
107
107
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
108
108
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
109
109
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=EIyKz-HY70DguWuSrJal8LpYXQ5ZSEUf3ZrVl7jikFM,4286
110
- ai_edge_torch_nightly-0.2.0.dev20240608.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
111
- ai_edge_torch_nightly-0.2.0.dev20240608.dist-info/METADATA,sha256=eM8KLGmQ4Kc6bVDcSDSJ12O-jFwr0ARV7uGeh-T8nvk,1748
112
- ai_edge_torch_nightly-0.2.0.dev20240608.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
113
- ai_edge_torch_nightly-0.2.0.dev20240608.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
114
- ai_edge_torch_nightly-0.2.0.dev20240608.dist-info/RECORD,,
110
+ ai_edge_torch_nightly-0.2.0.dev20240610.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
111
+ ai_edge_torch_nightly-0.2.0.dev20240610.dist-info/METADATA,sha256=6hL5PV3S56VU2l6xqS-YrmzMZeajtXsikIdR7kDYcWE,1748
112
+ ai_edge_torch_nightly-0.2.0.dev20240610.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
113
+ ai_edge_torch_nightly-0.2.0.dev20240610.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
114
+ ai_edge_torch_nightly-0.2.0.dev20240610.dist-info/RECORD,,