ai-edge-torch-nightly 0.6.0.dev20250619__py3-none-any.whl → 0.6.0.dev20250620__py3-none-any.whl
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- ai_edge_torch/generative/examples/gemma3/image_encoder.py +11 -10
- ai_edge_torch/generative/layers/attention_utils.py +9 -3
- ai_edge_torch/version.py +1 -1
- {ai_edge_torch_nightly-0.6.0.dev20250619.dist-info → ai_edge_torch_nightly-0.6.0.dev20250620.dist-info}/METADATA +1 -1
- {ai_edge_torch_nightly-0.6.0.dev20250619.dist-info → ai_edge_torch_nightly-0.6.0.dev20250620.dist-info}/RECORD +8 -8
- {ai_edge_torch_nightly-0.6.0.dev20250619.dist-info → ai_edge_torch_nightly-0.6.0.dev20250620.dist-info}/LICENSE +0 -0
- {ai_edge_torch_nightly-0.6.0.dev20250619.dist-info → ai_edge_torch_nightly-0.6.0.dev20250620.dist-info}/WHEEL +0 -0
- {ai_edge_torch_nightly-0.6.0.dev20250619.dist-info → ai_edge_torch_nightly-0.6.0.dev20250620.dist-info}/top_level.txt +0 -0
ai_edge_torch/generative/examples/gemma3/image_encoder.py CHANGED
@@ -24,26 +24,27 @@ import torch.nn.functional as F
 
 
 TENSOR_NAMES = loading_utils.ModelLoader.TensorNames(
-    ff_up_proj="
-    ff_down_proj="
+    ff_up_proj="siglip_vision_model.encoder_blocks.{}.mlp.fc1",
+    ff_down_proj="siglip_vision_model.encoder_blocks.{}.mlp.fc2",
     attn_query_proj=(
-        "
+        "siglip_vision_model.encoder_blocks.{}.self_attn.q_proj"
     ),
     attn_key_proj=(
-        "
+        "siglip_vision_model.encoder_blocks.{}.self_attn.k_proj"
     ),
     attn_value_proj=(
-        "
+        "siglip_vision_model.encoder_blocks.{}.self_attn.v_proj"
     ),
     attn_output_proj=(
-        "
+        "siglip_vision_model.encoder_blocks.{}.self_attn.o_proj"
     ),
-    pre_attn_norm="
-
+    pre_attn_norm="siglip_vision_model.encoder_blocks.{}.layer_norm1",
+    pre_ff_norm="siglip_vision_model.encoder_blocks.{}.layer_norm2",
+    embedding="siglip_vision_model.patch_embedding",
     embedding_position=(
-        "
+        "siglip_vision_model.position_embedding.weight"
     ),
-    final_norm="
+    final_norm="siglip_vision_model.final_norm",
 )
 
 
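The new tensor names above are per-block templates: the `{}` placeholder stands for the encoder block index. As a rough illustration (not part of the diff, and assuming the checkpoint loader expands the templates with `str.format`), the sketch below shows how one block's names would resolve. The helper and the subset of keys are hypothetical.

```python
# Hypothetical helper, for illustration only: expand a few of the templated
# tensor names from TENSOR_NAMES above for a single SigLIP encoder block.
BLOCK_NAME_TEMPLATES = {
    "ff_up_proj": "siglip_vision_model.encoder_blocks.{}.mlp.fc1",
    "ff_down_proj": "siglip_vision_model.encoder_blocks.{}.mlp.fc2",
    "attn_query_proj": "siglip_vision_model.encoder_blocks.{}.self_attn.q_proj",
    "pre_attn_norm": "siglip_vision_model.encoder_blocks.{}.layer_norm1",
}


def expand_for_block(block_idx: int) -> dict[str, str]:
  """Fills the {} placeholder with a concrete encoder block index."""
  return {k: tmpl.format(block_idx) for k, tmpl in BLOCK_NAME_TEMPLATES.items()}


# expand_for_block(0)["ff_up_proj"]
#   -> "siglip_vision_model.encoder_blocks.0.mlp.fc1"
```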
ai_edge_torch/generative/layers/attention_utils.py CHANGED
@@ -61,6 +61,7 @@ def build_causal_mask_cache(
     size: int,
     dtype: torch.dtype = torch.float32,
     device: torch.device = None,
+    mask_value: float = float('-inf'),
 ) -> torch.Tensor:
   """Build a cache for causal attention mask.
 
@@ -70,6 +71,8 @@ def build_causal_mask_cache(
       torch.float32.
     device (torch.device, optional): Output tensor's data type. Defaults to
       None in which case "cpu" is used.
+    mask_value (float, optional): The value to set the mask to. Defaults to
+      float('-inf').
 
   Returns:
     torch.Tensor: Causal attention mask.
@@ -77,7 +80,7 @@ def build_causal_mask_cache(
 
   if device is None:
     device = torch.device('cpu')
-  mask = torch.full((size, size),
+  mask = torch.full((size, size), mask_value, dtype=dtype, device=device)
   return torch.triu(mask, diagonal=1).unsqueeze(0).unsqueeze(0)
 
 
@@ -86,6 +89,7 @@ def build_sliding_window_mask_cache(
     window_size: int,
     dtype: torch.dtype = torch.float32,
     device: torch.device = None,
+    mask_value: float = float('-inf'),
 ) -> torch.Tensor:
   """Build a cache for a sliding window mask.
 
@@ -96,18 +100,20 @@ def build_sliding_window_mask_cache(
       torch.float32.
     device (torch.device, optional): Output tensor's data type. Defaults to
       None in which case "cpu" is used.
+    mask_value (float, optional): The value to set the mask to. Defaults to
+      float('-inf').
 
   Returns:
     torch.Tensor: Causal attention mask.
   """
 
-  mask = build_causal_mask_cache(size, dtype, device)
+  mask = build_causal_mask_cache(size, dtype, device, mask_value)
   all_ones = torch.ones_like(mask)
   window_size = min(size, window_size)
   sliding_mask = torch.triu(all_ones, -1 * window_size + 1) * torch.tril(
       all_ones, window_size - 1
   )
-  return torch.where(sliding_mask == 1, mask,
+  return torch.where(sliding_mask == 1, mask, mask_value)
 
 
 def relative_position_bucket(
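For reference, a short usage sketch of the new `mask_value` parameter. This is not taken from the package; the call sites and argument values below are illustrative only and assume the updated `attention_utils` module shown in this diff.

```python
import torch

from ai_edge_torch.generative.layers import attention_utils

# Default behavior is unchanged: masked positions are filled with float('-inf').
causal = attention_utils.build_causal_mask_cache(size=4)
print(causal.shape)  # torch.Size([1, 1, 4, 4]); upper triangle is -inf

# A finite fill value, e.g. for backends that handle -inf poorly.
finite_min = torch.finfo(torch.float32).min
causal_finite = attention_utils.build_causal_mask_cache(
    size=4, mask_value=finite_min
)

# The sliding-window mask forwards mask_value to build_causal_mask_cache and
# also uses it for positions outside the window.
sliding = attention_utils.build_sliding_window_mask_cache(
    size=8, window_size=3, mask_value=finite_min
)
```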
ai_edge_torch/version.py CHANGED

{ai_edge_torch_nightly-0.6.0.dev20250619.dist-info → ai_edge_torch_nightly-0.6.0.dev20250620.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: ai-edge-torch-nightly
-Version: 0.6.0.
+Version: 0.6.0.dev20250620
 Summary: Supporting PyTorch models with the Google AI Edge TFLite runtime.
 Home-page: https://github.com/google-ai-edge/ai-edge-torch
 Keywords: On-Device ML,AI,Google,TFLite,PyTorch,LLMs,GenAI
{ai_edge_torch_nightly-0.6.0.dev20250619.dist-info → ai_edge_torch_nightly-0.6.0.dev20250620.dist-info}/RECORD CHANGED
@@ -2,7 +2,7 @@ ai_edge_torch/__init__.py,sha256=lemyLCNoGYRnJsmDuGZu7qOqLbLqG6CGDFtu3ue1syU,129
 ai_edge_torch/_config.py,sha256=AiqhbcheF7j_ozIGDLC89k1we95aVgFDa-tR6h7UI0s,2529
 ai_edge_torch/conftest.py,sha256=r0GTrhMRhlmOGrrkvumHN8hkmyug6WvF60vWq8wRIBI,758
 ai_edge_torch/model.py,sha256=A7loFu8jE9CsXsfMmHYZ-KDFJiaD8Kkqwm_9d3IVzk0,5638
-ai_edge_torch/version.py,sha256=
+ai_edge_torch/version.py,sha256=L-kCkN9vMaBYCuarY-Y8kGgttAEHZdyPWaupqKVMLiA,806
 ai_edge_torch/_convert/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/_convert/conversion.py,sha256=iQk3R-pLq4c1nfLqPB4xTRj78gghxPGzJCJtILLdg5o,6123
 ai_edge_torch/_convert/conversion_utils.py,sha256=Sr8qXVcTwc-ZnZmK7yxVrIOOp1S_vNrwzC0zUvLTI2o,2160
@@ -74,7 +74,7 @@ ai_edge_torch/generative/examples/gemma3/__init__.py,sha256=JaAnrFoXTl3RJX97Xspk
 ai_edge_torch/generative/examples/gemma3/convert_gemma3_to_tflite.py,sha256=UEDNN3JmI31WfE2pvacxeJpqumKK86L2dEus3yTURaY,2114
 ai_edge_torch/generative/examples/gemma3/decoder.py,sha256=1UVv9SFFg5degX3wf-Fefx7nor1AzJj2NWBVuo8bRnM,15540
 ai_edge_torch/generative/examples/gemma3/gemma3.py,sha256=fFMyIS8si3GdwW8EsdhYk1OKyg_27xDv1HTQ2Gv4N8E,6616
-ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=
+ai_edge_torch/generative/examples/gemma3/image_encoder.py,sha256=OCMIAQfNmPR4uQUAtlYL6j4xkG0dw2Ays4-lnThcWqQ,5110
 ai_edge_torch/generative/examples/gemma3/verify_gemma3.py,sha256=v8oNXFICmVOtQxfO7IhZ8GnbvotEkDi9lzYHjoQyOso,2464
 ai_edge_torch/generative/examples/gemma3/verify_util.py,sha256=5OmUwz38kVHYLA-v8U8evvDN9da2WioZtGo-XK6yq1o,10067
 ai_edge_torch/generative/examples/hammer/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
@@ -171,7 +171,7 @@ ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py,sha256=myGjal5A
 ai_edge_torch/generative/layers/__init__.py,sha256=hHLluseD2R0Hh4W6XZRIXY_dRQeYudjsrKGf6LZz65g,671
 ai_edge_torch/generative/layers/attention.py,sha256=RaXENRRQo1MsLdt3U8h3kYTCmd6imHQ-aCXtmPXCh_o,13911
 ai_edge_torch/generative/layers/attention_test.py,sha256=9v8v96TLyFPdqxEylU1JOAeRFAp2s0YoDHZN83SFJJA,4764
-ai_edge_torch/generative/layers/attention_utils.py,sha256=
+ai_edge_torch/generative/layers/attention_utils.py,sha256=2qfg7Tzk9ikKph5w3geOHC1I6EyOCdDsWXMr7F7IOZM,7630
 ai_edge_torch/generative/layers/attention_utils_test.py,sha256=22gQ1gcRPkwqFG3_p82GZfRKVE3udEssSy58wNOqv0w,2431
 ai_edge_torch/generative/layers/builder.py,sha256=2bUgkyowDkDznkF8XaHyZs4nowHr1QEHYLM7pMaFmIk,4921
 ai_edge_torch/generative/layers/einsum.py,sha256=EsZSWNVWUs0-1plp4TBnhP4ZhaRDBa2VlDO6hWpUAqU,1288
@@ -268,8 +268,8 @@ ai_edge_torch/testing/__init__.py,sha256=_yGgvnBZWb7T3IN3mc4x1sS4vM96HZwM8pwIcPG
 ai_edge_torch/testing/export.py,sha256=k5mGDGzwc23Z4zaIVDs8CNh-oOt64gsf9MS9NjhbPy4,3293
 ai_edge_torch/testing/model_coverage/__init__.py,sha256=5P8J6Zk5YYtDvTBucFvB9NGSRI7Gw_24WnrbhXgycEE,765
 ai_edge_torch/testing/model_coverage/model_coverage.py,sha256=UPB448aMDUyC0HNYVqio2rcJPnDN0tBQMP08J6vPYew,4718
-ai_edge_torch_nightly-0.6.0.
-ai_edge_torch_nightly-0.6.0.
-ai_edge_torch_nightly-0.6.0.
-ai_edge_torch_nightly-0.6.0.
-ai_edge_torch_nightly-0.6.0.
+ai_edge_torch_nightly-0.6.0.dev20250620.dist-info/LICENSE,sha256=z8d0m5b2O9McPEK1xHG_dWgUBT6EfBDz6wA0F7xSPTA,11358
+ai_edge_torch_nightly-0.6.0.dev20250620.dist-info/METADATA,sha256=W4cZLDBaywmznVjn7haIlLen5cKvXF7VYVIapnx4h0E,2074
+ai_edge_torch_nightly-0.6.0.dev20250620.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+ai_edge_torch_nightly-0.6.0.dev20250620.dist-info/top_level.txt,sha256=5KXRaF2hwkApYxf7Y8y_tVb9aulGTlbOoNdbx1aKRkE,14
+ai_edge_torch_nightly-0.6.0.dev20250620.dist-info/RECORD,,