ai-edge-torch-nightly 0.3.0.dev20240915__py3-none-any.whl → 0.3.0.dev20240917__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- ai_edge_torch/generative/examples/stable_diffusion/diffusion.py +6 -0
- ai_edge_torch/generative/layers/attention.py +8 -4
- ai_edge_torch/generative/layers/unet/blocks_2d.py +2 -0
- ai_edge_torch/generative/layers/unet/model_config.py +2 -0
- ai_edge_torch/generative/utilities/stable_diffusion_loader.py +6 -0
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240915.dist-info → ai_edge_torch_nightly-0.3.0.dev20240917.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.3.0.dev20240915.dist-info → ai_edge_torch_nightly-0.3.0.dev20240917.dist-info}/RECORD +11 -11
- {ai_edge_torch_nightly-0.3.0.dev20240915.dist-info → ai_edge_torch_nightly-0.3.0.dev20240917.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240915.dist-info → ai_edge_torch_nightly-0.3.0.dev20240917.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.3.0.dev20240915.dist-info → ai_edge_torch_nightly-0.3.0.dev20240917.dist-info}/top_level.txt +0 -0
@@ -336,6 +336,8 @@ class Diffusion(nn.Module):
|
|
336
336
|
cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
|
337
337
|
query_dim=output_channel,
|
338
338
|
cross_dim=config.transformer_cross_attention_dim,
|
339
|
+
hidden_dim=output_channel,
|
340
|
+
output_dim=output_channel,
|
339
341
|
attention_batch_size=config.transformer_batch_size,
|
340
342
|
normalization_config=config.transformer_norm_config,
|
341
343
|
attention_config=build_attention_config(
|
@@ -406,6 +408,8 @@ class Diffusion(nn.Module):
|
|
406
408
|
cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
|
407
409
|
query_dim=mid_block_channels,
|
408
410
|
cross_dim=config.transformer_cross_attention_dim,
|
411
|
+
hidden_dim=mid_block_channels,
|
412
|
+
output_dim=mid_block_channels,
|
409
413
|
attention_batch_size=config.transformer_batch_size,
|
410
414
|
normalization_config=config.transformer_norm_config,
|
411
415
|
attention_config=build_attention_config(
|
@@ -477,6 +481,8 @@ class Diffusion(nn.Module):
|
|
477
481
|
cross_attention_block_config=unet_cfg.CrossAttentionBlock2DConfig(
|
478
482
|
query_dim=output_channel,
|
479
483
|
cross_dim=config.transformer_cross_attention_dim,
|
484
|
+
hidden_dim=output_channel,
|
485
|
+
output_dim=output_channel,
|
480
486
|
attention_batch_size=config.transformer_batch_size,
|
481
487
|
normalization_config=config.transformer_norm_config,
|
482
488
|
attention_config=build_attention_config(
|
@@ -298,6 +298,8 @@ class CrossAttention(nn.Module):
|
|
298
298
|
batch_size: int,
|
299
299
|
query_dim: int,
|
300
300
|
cross_dim: int,
|
301
|
+
hidden_dim: int,
|
302
|
+
output_dim: int,
|
301
303
|
config: cfg.AttentionConfig,
|
302
304
|
enable_hlfb: bool,
|
303
305
|
):
|
@@ -307,6 +309,8 @@ class CrossAttention(nn.Module):
|
|
307
309
|
batch_size (int): batch size of the input tensor.
|
308
310
|
query_dim (int): query tensor's dimension.
|
309
311
|
cross_dim (int): cross attention's dimensions, for key and value tensors.
|
312
|
+
hidden_dim (int): hidden dimension that q, k, v tensors project to.
|
313
|
+
output_dim (int): output tensor's dimension.
|
310
314
|
config (cfg.AttentionConfig): attention specific configurations.
|
311
315
|
enable_hlfb (bool): whether hlfb is enabled or not.
|
312
316
|
"""
|
@@ -314,16 +318,16 @@ class CrossAttention(nn.Module):
|
|
314
318
|
self.config = config
|
315
319
|
self.n_heads = config.num_heads
|
316
320
|
self.q_projection = nn.Linear(
|
317
|
-
query_dim,
|
321
|
+
query_dim, hidden_dim, bias=config.qkv_use_bias
|
318
322
|
)
|
319
323
|
self.k_projection = nn.Linear(
|
320
|
-
cross_dim,
|
324
|
+
cross_dim, hidden_dim, bias=config.qkv_use_bias
|
321
325
|
)
|
322
326
|
self.v_projection = nn.Linear(
|
323
|
-
cross_dim,
|
327
|
+
cross_dim, hidden_dim, bias=config.qkv_use_bias
|
324
328
|
)
|
325
329
|
self.output_projection = nn.Linear(
|
326
|
-
|
330
|
+
hidden_dim, output_dim, bias=config.output_proj_use_bias
|
327
331
|
)
|
328
332
|
|
329
333
|
self.sdpa_func = (
|
@@ -68,6 +68,8 @@ class AttentionBlock2DConfig:
|
|
68
68
|
class CrossAttentionBlock2DConfig:
|
69
69
|
query_dim: int
|
70
70
|
cross_dim: int
|
71
|
+
hidden_dim: int
|
72
|
+
output_dim: int
|
71
73
|
normalization_config: layers_cfg.NormalizationConfig
|
72
74
|
attention_config: layers_cfg.AttentionConfig
|
73
75
|
enable_hlfb: bool = True
|
@@ -811,6 +811,8 @@ class DiffusionModelLoader(BaseLoader):
|
|
811
811
|
cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
|
812
812
|
query_dim=output_channel,
|
813
813
|
cross_dim=config.transformer_cross_attention_dim,
|
814
|
+
hidden_dim=output_channel,
|
815
|
+
output_dim=output_channel,
|
814
816
|
normalization_config=config.transformer_norm_config,
|
815
817
|
attention_config=build_attention_config(
|
816
818
|
num_heads=config.transformer_num_attention_heads,
|
@@ -877,6 +879,8 @@ class DiffusionModelLoader(BaseLoader):
|
|
877
879
|
cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
|
878
880
|
query_dim=mid_block_channels,
|
879
881
|
cross_dim=config.transformer_cross_attention_dim,
|
882
|
+
hidden_dim=mid_block_channels,
|
883
|
+
output_dim=mid_block_channels,
|
880
884
|
normalization_config=config.transformer_norm_config,
|
881
885
|
attention_config=build_attention_config(
|
882
886
|
num_heads=config.transformer_num_attention_heads,
|
@@ -950,6 +954,8 @@ class DiffusionModelLoader(BaseLoader):
|
|
950
954
|
cross_attention_block_config=unet_config.CrossAttentionBlock2DConfig(
|
951
955
|
query_dim=output_channel,
|
952
956
|
cross_dim=config.transformer_cross_attention_dim,
|
957
|
+
hidden_dim=output_channel,
|
958
|
+
output_dim=output_channel,
|
953
959
|
normalization_config=config.transformer_norm_config,
|
954
960
|
attention_config=build_attention_config(
|
955
961
|
num_heads=config.transformer_num_attention_heads,
|
ai_edge_torch/version.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: ai-edge-torch-nightly
|
3
|
-
Version: 0.3.0.
|
3
|
+
Version: 0.3.0.dev20240917
|
4
4
|
Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
|
5
5
|
Home-page: https://github.com/google-ai-edge/ai-edge-torch
|
6
6
|
Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
|
@@ -3,7 +3,7 @@ ai_edge_torch/config.py,sha256=FMWeCH2b7HYILBvaI1iZNnYCO4WAhDOwBZBmIE-xrF0,909
|
|
3
3
|
ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
|
4
4
|
ai_edge_torch/fx_pass_base.py,sha256=D86Gw3pIRcpnTebUPKlnPbPGJae1S6Fw4DZZ3ZkD0zw,3730
|
5
5
|
ai_edge_torch/model.py,sha256=NYV6Mkaje_ditIEI_s_7nLP_-8i4kbGM8nRzieVkbUI,5397
|
6
|
-
ai_edge_torch/version.py,sha256=
|
6
|
+
ai_edge_torch/version.py,sha256=Dpo7ejWCykjMm-XsI6Dfr_UDCz2nOg_GGvS1If4XnfA,706
|
7
7
|
ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
8
8
|
ai_edge_torch/_convert/conversion.py,sha256=5uPwHhmc6kwiIz-CqaiHDejf2SOWMHrb-rYEHm69wKc,3801
|
9
9
|
ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
|
@@ -57,7 +57,7 @@ ai_edge_torch/generative/examples/stable_diffusion/attention.py,sha256=kDWG6MlIG
|
|
57
57
|
ai_edge_torch/generative/examples/stable_diffusion/clip.py,sha256=tL6w2dr6VP66IXjSKo9StDNP-wl0RO3fh6dIliiYlFA,4656
|
58
58
|
ai_edge_torch/generative/examples/stable_diffusion/convert_to_tflite.py,sha256=vfMGI03UL_gfB561t2kzIHuScwnsUmqaPWxgvq_1T5A,5043
|
59
59
|
ai_edge_torch/generative/examples/stable_diffusion/decoder.py,sha256=slieF2-QcDCwd4DRZ7snsZIphT97IXpp4plRRsRSwL8,13983
|
60
|
-
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=
|
60
|
+
ai_edge_torch/generative/examples/stable_diffusion/diffusion.py,sha256=7o-5oJARCm4fhRwmNv84ofmajP5MMIS102vj4d8eeRQ,31248
|
61
61
|
ai_edge_torch/generative/examples/stable_diffusion/encoder.py,sha256=CAPsW84A8f00nS6fLFeh_XUjCPsDCA5UxHOUsMrLfSU,3450
|
62
62
|
ai_edge_torch/generative/examples/stable_diffusion/pipeline.py,sha256=x9lEEENGNbpx6VTf_LTVudd9d6bs9tLvFUKTl252zEY,8623
|
63
63
|
ai_edge_torch/generative/examples/stable_diffusion/tokenizer.py,sha256=xychak9hdLd6ieXBYEwrK2BkF8NRZWZSSCijIsESpBA,3420
|
@@ -80,7 +80,7 @@ ai_edge_torch/generative/examples/tiny_llama/tiny_llama.py,sha256=Upo8jjqR0VKvkd
|
|
80
80
|
ai_edge_torch/generative/fx_passes/__init__.py,sha256=jrzCB3ZyY_t5jJM1e2Czdt3DjAIL43R0_a-T-I7wOzw,1155
|
81
81
|
ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=hhxSQvkDMv0isZJhmuLiod66ZODaJ8uSPSVTJVHBabQ,1931
|
82
82
|
ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
83
|
-
ai_edge_torch/generative/layers/attention.py,sha256=
|
83
|
+
ai_edge_torch/generative/layers/attention.py,sha256=Z0Y_G8IG0LmvLX2u9D8__Fkr22szB-az6wMNnZpzhkA,13233
|
84
84
|
ai_edge_torch/generative/layers/attention_utils.py,sha256=68GXGR2HSWBFViTxX7cHifzVG-kcLS2IL2tQJPIpupg,7344
|
85
85
|
ai_edge_torch/generative/layers/builder.py,sha256=iuAv8D7HY-azBDy7-UBILMdjuKjpe38rE2gK4H3erwE,5092
|
86
86
|
ai_edge_torch/generative/layers/feed_forward.py,sha256=dfS1psdmomgs4EbwzkYyV_xx1xl3P1lU-3GoS8m0Avw,4221
|
@@ -90,9 +90,9 @@ ai_edge_torch/generative/layers/normalization.py,sha256=iod9oNkoDS5m-yFY_Y_XMyvC
|
|
90
90
|
ai_edge_torch/generative/layers/rotary_position_embedding.py,sha256=CZqOoibLcHvUgrgaIIWAlmk3XgE2inzx340MN-npLoU,1347
|
91
91
|
ai_edge_torch/generative/layers/scaled_dot_product_attention.py,sha256=VW-VP8e7FTSPCdu-6DVxpwNrIdgX0R_kq6F6MSEiyXE,3848
|
92
92
|
ai_edge_torch/generative/layers/unet/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
93
|
-
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=
|
93
|
+
ai_edge_torch/generative/layers/unet/blocks_2d.py,sha256=ZA--ohqmTfXeYQaBP1WpwFOf-TGHZmUMONocPL_hlFc,27244
|
94
94
|
ai_edge_torch/generative/layers/unet/builder.py,sha256=zAqWXdimmMrQRhmE_t9XkS68mh6PSrzwb-2NZZXrR5I,1901
|
95
|
-
ai_edge_torch/generative/layers/unet/model_config.py,sha256=
|
95
|
+
ai_edge_torch/generative/layers/unet/model_config.py,sha256=EzF2qpuoW_qBTYO2uuThh4PN0BqF2vXQHgmfJJKVOSg,9244
|
96
96
|
ai_edge_torch/generative/quantize/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
97
97
|
ai_edge_torch/generative/quantize/example.py,sha256=n_YFFP3dpKjeNKYZicDGL5LqtjqwhYEIaDrC6-Ci2vE,1539
|
98
98
|
ai_edge_torch/generative/quantize/quant_attrs.py,sha256=n1Fm8BFC8gJa_oiwwAOOghJyHtOXYZ4q-5ZRy4pHrIw,1957
|
@@ -109,7 +109,7 @@ ai_edge_torch/generative/test/test_quantize.py,sha256=8geJhKwYBU20m0mdGPD1BUFwQ0
|
|
109
109
|
ai_edge_torch/generative/test/utils.py,sha256=YvEhO2HIj1LkBs5du1UxY-cGRW9HMyAYsOUhgsTrTpA,1796
|
110
110
|
ai_edge_torch/generative/utilities/__init__.py,sha256=-_jxnnFnCgnTU4oTm4MnRsvL5lqhomBNdFBbqfmfHPo,720
|
111
111
|
ai_edge_torch/generative/utilities/loader.py,sha256=b9iotIhVDX-Zc9XjIDUaLxnV395AyBnkQe3dV5YA7Co,13297
|
112
|
-
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=
|
112
|
+
ai_edge_torch/generative/utilities/stable_diffusion_loader.py,sha256=sMMidBhGxD-0bJw5FYNVMLb7uIre3zszJ1xBAsyeDGQ,35961
|
113
113
|
ai_edge_torch/generative/utilities/t5_loader.py,sha256=tEsfy8-ymzbbjOIc-oesXF3yGyyWtJgFXn2s7VOavt8,16961
|
114
114
|
ai_edge_torch/hlfb/__init__.py,sha256=sH4um75na-O8tzxN6chFyp6Y4xnexsE7kUQpZySv6dE,735
|
115
115
|
ai_edge_torch/hlfb/mark_pattern/__init__.py,sha256=cjTprggj_cuktSCm7-A25e7Shop3k63ylp7sdZmtZ8o,4790
|
@@ -157,8 +157,8 @@ ai_edge_torch/quantize/quant_config.py,sha256=U0KisSW-uZkoMJcy-ZP9W57p3tsa594fr9
|
|
157
157
|
ai_edge_torch/testing/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
|
158
158
|
ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
|
159
159
|
ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
|
160
|
-
ai_edge_torch_nightly-0.3.0.
|
161
|
-
ai_edge_torch_nightly-0.3.0.
|
162
|
-
ai_edge_torch_nightly-0.3.0.
|
163
|
-
ai_edge_torch_nightly-0.3.0.
|
164
|
-
ai_edge_torch_nightly-0.3.0.
|
160
|
+
ai_edge_torch_nightly-0.3.0.dev20240917.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
|
161
|
+
ai_edge_torch_nightly-0.3.0.dev20240917.dist-info/METADATA,sha256=Aca58kkMgjlXanUgCx2pk-2dS-ZWe9CSAVLpJozD7V4,1859
|
162
|
+
ai_edge_torch_nightly-0.3.0.dev20240917.dist-info/WHEEL,sha256=eOLhNAGa2EW3wWl_TU484h7q1UNgy0JXjjoqKoxAAQc,92
|
163
|
+
ai_edge_torch_nightly-0.3.0.dev20240917.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
|
164
|
+
ai_edge_torch_nightly-0.3.0.dev20240917.dist-info/RECORD,,
|
File without changes
|
File without changes
|