ai-edge-torch-nightly 0.3.0.dev20240916__py3-none-any.whl → 0.3.0.dev20240917__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
@@ -336,6 +336,8 @@ class Diffusion(nn.Module):
336
336
  cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
337
337
  query_dim=output_channel,
338
338
  cross_dim=config.transformer_cross_attention_dim,
339
+ hidden_dim=output_channel,
340
+ output_dim=output_channel,
339
341
  attention_batch_size=config.transformer_batch_size,
340
342
  normalization_config=config.transformer_norm_config,
341
343
  attention_config=build_attention_config(
@@ -406,6 +408,8 @@ class Diffusion(nn.Module):
406
408
  cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
407
409
  query_dim=mid_block_channels,
408
410
  cross_dim=config.transformer_cross_attention_dim,
411
+ hidden_dim=mid_block_channels,
412
+ output_dim=mid_block_channels,
409
413
  attention_batch_size=config.transformer_batch_size,
410
414
  normalization_config=config.transformer_norm_config,
411
415
  attention_config=build_attention_config(
@@ -477,6 +481,8 @@ class Diffusion(nn.Module):
477
481
  cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
478
482
  query_dim=output_channel,
479
483
  cross_dim=config.transformer_cross_attention_dim,
484
+ hidden_dim=output_channel,
485
+ output_dim=output_channel,
480
486
  attention_batch_size=config.transformer_batch_size,
481
487
  normalization_config=config.transformer_norm_config,
482
488
  attention_config=build_attention_config(
@@ -298,6 +298,8 @@ class CrossAttention(nn.Module):
298
298
  batch_size: int,
299
299
  query_dim: int,
300
300
  cross_dim: int,
301
+ hidden_dim: int,
302
+ output_dim: int,
301
303
  config: cfg.AttentionConfig,
302
304
  enable_hlfb: bool,
303
305
  ):
@@ -307,6 +309,8 @@ class CrossAttention(nn.Module):
307
309
  batch_size (int): batch size of the input tensor.
308
310
  query_dim (int): query tensor's dimension.
309
311
  cross_dim (int): cross attention's dimensions, for key and value tensors.
312
+ hidden_dim (int): hidden dimension that q, k, v tensors project to.
313
+ output_dim (int): output tensor's dimension.
310
314
  config (cfg.AttentionConfig): attention specific configurations.
311
315
  enable_hlfb (bool): whether hlfb is enabled or not.
312
316
  """
@@ -314,16 +318,16 @@ class CrossAttention(nn.Module):
314
318
  self.config = config
315
319
  self.n_heads = config.num_heads
316
320
  self.q_projection = nn.Linear(
317
- query_dim, query_dim, bias=config.qkv_use_bias
321
+ query_dim, hidden_dim, bias=config.qkv_use_bias
318
322
  )
319
323
  self.k_projection = nn.Linear(
320
- cross_dim, query_dim, bias=config.qkv_use_bias
324
+ cross_dim, hidden_dim, bias=config.qkv_use_bias
321
325
  )
322
326
  self.v_projection = nn.Linear(
323
- cross_dim, query_dim, bias=config.qkv_use_bias
327
+ cross_dim, hidden_dim, bias=config.qkv_use_bias
324
328
  )
325
329
  self.output_projection = nn.Linear(
326
- query_dim, query_dim, bias=config.output_proj_use_bias
330
+ hidden_dim, output_dim, bias=config.output_proj_use_bias
327
331
  )
328
332
 
329
333
  self.sdpa_func = (
@@ -178,6 +178,8 @@ class CrossAttentionBlock2D(nn.Module):
178
178
  config.attention_batch_size,
179
179
  config.query_dim,
180
180
  config.cross_dim,
181
+ config.hidden_dim,
182
+ config.output_dim,
181
183
  config.attention_config,
182
184
  enable_hlfb=config.enable_hlfb,
183
185
  )
@@ -68,6 +68,8 @@ class AttentionBlock2DConfig:
68
68
  class CrossAttentionBlock2DConfig:
69
69
  query_dim: int
70
70
  cross_dim: int
71
+ hidden_dim: int
72
+ output_dim: int
71
73
  normalization_config: layers_cfg.NormalizationConfig
72
74
  attention_config: layers_cfg.AttentionConfig
73
75
  enable_hlfb: bool = True
@@ -811,6 +811,8 @@ class DiffusionModelLoader(BaseLoader):
811
811
  cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
812
812
  query_dim=output_channel,
813
813
  cross_dim=config.transformer_cross_attention_dim,
814
+ hidden_dim=output_channel,
815
+ output_dim=output_channel,
814
816
  normalization_config=config.transformer_norm_config,
815
817
  attention_config=build_attention_config(
816
818
  num_heads=config.transformer_num_attention_heads,
@@ -877,6 +879,8 @@ class DiffusionModelLoader(BaseLoader):
877
879
  cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
878
880
  query_dim=mid_block_channels,
879
881
  cross_dim=config.transformer_cross_attention_dim,
882
+ hidden_dim=mid_block_channels,
883
+ output_dim=mid_block_channels,
880
884
  normalization_config=config.transformer_norm_config,
881
885
  attention_config=build_attention_config(
882
886
  num_heads=config.transformer_num_attention_heads,
@@ -950,6 +954,8 @@ class DiffusionModelLoader(BaseLoader):
950
954
  cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
951
955
  query_dim=output_channel,
952
956
  cross_dim=config.transformer_cross_attention_dim,
957
+ hidden_dim=output_channel,
958
+ output_dim=output_channel,
953
959
  normalization_config=config.transformer_norm_config,
954
960
  attention_config=build_attention_config(
955
961
  num_heads=config.transformer_num_attention_heads,
ai_edge_torch/version.py CHANGED
@@ -13,4 +13,4 @@
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
15
 
16
- __version__ = "0.3.0.dev20240916"
16
+ __version__ = "0.3.0.dev20240917"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ai-edge-torch-nightly
3
- Version: 0.3.0.dev20240916
3
+ Version: 0.3.0.dev20240917
4
4
  Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
5
5
  Home-page: https://github.com/google-ai-edge/ai-edge-torch
6
6
  Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
3
3
  ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
4
4
  ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
5
5
  ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
6
- ai_edge_torch/version.py,sha256=nRUErTd6i3Pxfpnp3BacFfEH5cQbDvxrA6YeTzKNOxU,706
6
+ ai_edge_torch/version.py,sha256=Dpo7ejWCykjMm-XsI6Dfr_UDCz2nOg_GGvS1If4XnfA,706
7
7
  ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
8
8
  ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
9
9
  ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -57,7 +57,7 @@ ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIG
57
57
  ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=tL6w2dr6VP66IXjSKo9StDNP-wl0RO3fh6dIliiYlFA,4656
58
58
  ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=vfMGI03UL_gfB561t2kzIHuScwnsUmqaPWxgvq_1T5A,5043
59
59
  ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=slieF2-QcDCwd4DRZ7snsZIphT97IXpp4plRRsRSwL8,13983
60
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=7oUIJ6HO0vmlhFdkXpqGm9KTB-eM4Ob9VrHSDlIGFOg,30926
60
+ ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=7o-5oJARCm4fhRwmNv84ofmajP5MMIS102vj4d8eeRQ,31248
61
61
  ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=CAPsW84A8f00nS6fLFeh_XUjCPsDCA5UxHOUsMrLfSU,3450
62
62
  ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=x9lEEENGNbpx6VTf_LTVudd9d6bs9tLvFUKTl252zEY,8623
63
63
  ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=xychak9hdLd6ieXBYEwrK2BkF8NRZWZSSCijIsESpBA,3420
@@ -80,7 +80,7 @@ ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=Upo8jjqR0VKvkd
80
80
  ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
81
81
  ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
82
82
  ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
83
- ai_edge_torch/generative/layers/attention.py,sha256=37Fua94dQSiBA9Y5XvHxGb5IfN8p8UgNgu5YwM1Rmrw,13057
83
+ ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
84
84
  ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
85
85
  ai_edge_torch/generative/layers/builder.py,sha256=iuAv8D7HY-azBDy7-UBILMdjuKjpe38rE2gK4H3erwE,5092
86
86
  ai_edge_torch/generative/layers/feed_forward.py,sha256=dfS1psdmomgs4EbwzkYyV_xx1xl3P1lU-3GoS8m0Avw,4221
@@ -90,9 +90,9 @@ ai_edge_torch/generative/layers/normalization.py,sha256=iod9oNkoDS5m-yFY_Y_XMyvC
90
90
  ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
91
91
  ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=VW-VP8e7FTSPCdu-6DVxpwNrIdgX0R_kq6F6MSEiyXE,3848
92
92
  ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
93
- ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=cpygyJccLq6KHKxV7oz4YKh529YLjC9isupnsVmPi0A,27190
93
+ ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=ZA--ohqmTfXeYQaBP1WpwFOf-TGHZmUMONocPL_hlFc,27244
94
94
  ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
95
- ai_edge_torch/generative/layers/unet/model_config.py,sha256=NvBJj09a7ZC-ChGE_ex-_kLnE_fjzrY6txbLSh1pMKA,9208
95
+ ai_edge_torch/generative/layers/unet/model_config.py,sha256=EzF2qpuoW_qBTYO2uuThh4PN0BqF2vXQHgmfJJKVOSg,9244
96
96
  ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
97
97
  ai_edge_torch/generative/quantize/example.py,sha256=n_YFFP3dpKjeNKYZicDGL5LqtjqwhYEIaDrC6-Ci2vE,1539
98
98
  ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
@@ -109,7 +109,7 @@ ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0
109
109
  ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
110
110
  ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
111
111
  ai_edge_torch/generative/utilities/loader.py,sha256=b9iotIhVDX-Zc9XjIDUaLxnV395AyBnkQe3dV5YA7Co,13297
112
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=pKp3AMSbS3otCvgwJRF5M1l4JRNKk-aCKimXzIMSrds,35679
112
+ ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=sMMidBhGxD-0bJw5FYNVMLb7uIre3zszJ1xBAsyeDGQ,35961
113
113
  ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
114
114
  ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
115
115
  ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
@@ -157,8 +157,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
157
157
  ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
158
158
  ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
159
159
  ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
160
- ai_edge_torch_nightly-0.3.0.dev20240916.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
161
- ai_edge_torch_nightly-0.3.0.dev20240916.dist-info/METADATA,sha256=yK-gW8Z98p5-9PvIsfCu3f5FAACNAPH5_BecOImrfKo,1859
162
- ai_edge_torch_nightly-0.3.0.dev20240916.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
163
- ai_edge_torch_nightly-0.3.0.dev20240916.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
164
- ai_edge_torch_nightly-0.3.0.dev20240916.dist-info/RECORD,,
160
+ ai_edge_torch_nightly-0.3.0.dev20240917.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
161
+ ai_edge_torch_nightly-0.3.0.dev20240917.dist-info/METADATA,sha256=Aca58kkMgjlXanUgCx2pk-2dS-ZWe9CSAVLpJozD7V4,1859
162
+ ai_edge_torch_nightly-0.3.0.dev20240917.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
163
+ ai_edge_torch_nightly-0.3.0.dev20240917.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
164
+ ai_edge_torch_nightly-0.3.0.dev20240917.dist-info/RECORD,,