lightly-studio 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lightly-studio might be problematic. Click here for more details.
- lightly_studio/__init__.py +11 -0
- lightly_studio/api/__init__.py +0 -0
- lightly_studio/api/app.py +110 -0
- lightly_studio/api/cache.py +77 -0
- lightly_studio/api/db.py +133 -0
- lightly_studio/api/db_tables.py +32 -0
- lightly_studio/api/features.py +7 -0
- lightly_studio/api/routes/api/annotation.py +233 -0
- lightly_studio/api/routes/api/annotation_label.py +90 -0
- lightly_studio/api/routes/api/annotation_task.py +38 -0
- lightly_studio/api/routes/api/classifier.py +387 -0
- lightly_studio/api/routes/api/dataset.py +182 -0
- lightly_studio/api/routes/api/dataset_tag.py +257 -0
- lightly_studio/api/routes/api/exceptions.py +96 -0
- lightly_studio/api/routes/api/features.py +17 -0
- lightly_studio/api/routes/api/metadata.py +37 -0
- lightly_studio/api/routes/api/metrics.py +80 -0
- lightly_studio/api/routes/api/sample.py +196 -0
- lightly_studio/api/routes/api/settings.py +45 -0
- lightly_studio/api/routes/api/status.py +19 -0
- lightly_studio/api/routes/api/text_embedding.py +48 -0
- lightly_studio/api/routes/api/validators.py +17 -0
- lightly_studio/api/routes/healthz.py +13 -0
- lightly_studio/api/routes/images.py +104 -0
- lightly_studio/api/routes/webapp.py +51 -0
- lightly_studio/api/server.py +82 -0
- lightly_studio/core/__init__.py +0 -0
- lightly_studio/core/dataset.py +523 -0
- lightly_studio/core/sample.py +77 -0
- lightly_studio/core/start_gui.py +15 -0
- lightly_studio/dataset/__init__.py +0 -0
- lightly_studio/dataset/edge_embedding_generator.py +144 -0
- lightly_studio/dataset/embedding_generator.py +91 -0
- lightly_studio/dataset/embedding_manager.py +163 -0
- lightly_studio/dataset/env.py +16 -0
- lightly_studio/dataset/file_utils.py +35 -0
- lightly_studio/dataset/loader.py +622 -0
- lightly_studio/dataset/mobileclip_embedding_generator.py +144 -0
- lightly_studio/dist_lightly_studio_view_app/_app/env.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/0.DenzbfeK.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/LightlyLogo.BNjCIww-.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans- +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Bold.DGvYQtcs.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Italic-VariableFont_wdth_wght.B4AZ-wl6.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Regular.DxJTClRG.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-SemiBold.D3TTYgdB.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-VariableFont_wdth_wght.BZBpG5Iz.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.OwPEPQZu.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.b653GmVf.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.T-zjSUd3.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/useFeatureFlags.CV-KWLNP.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/69_IOA4Y.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B2FVR0s0.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B90CZVMX.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B9zumHo5.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BJXwVxaE.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bsi3UGy5.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bu7uvVrG.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bx1xMsFy.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BylOuP6i.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C8I8rFJQ.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CDnpyLsT.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CWj6FrbW.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CYgJF_JY.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CcaPhhk3.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CvOmgdoc.js +93 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CxtLVaYz.js +3 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D5-A_Ffd.js +4 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6RI2Zrd.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6su9Aln.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D98V7j6A.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIRAtgl0.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIeogL5L.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DOlTMNyt.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjUWrjOv.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjfY96ND.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/H7C68rOM.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/O-EABkf9.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/XO7A28GO.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/hQVEETDE.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/l7KrR96u.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/nAHhluT7.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/r64xT6ao.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/vC4nQVEB.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/x9G_hzyY.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.CjnvpsmS.js +2 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.0o1H7wM9.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.XRq_TUwu.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/1.B4rNYwVp.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.DfBwOEhN.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/11.CWG1ehzT.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.CwF2_8mP.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.CS4muRY-.js +6 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/3.CWHpKonm.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/4.OUWOLQeV.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.Dm6t9F5W.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/6.Bw5ck4gK.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.CF0EDTR6.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.Cw30LEcV.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/9.CPu3CiBc.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/version.json +1 -0
- lightly_studio/dist_lightly_studio_view_app/apple-touch-icon-precomposed.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/apple-touch-icon.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/favicon.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/index.html +44 -0
- lightly_studio/examples/example.py +23 -0
- lightly_studio/examples/example_metadata.py +338 -0
- lightly_studio/examples/example_selection.py +39 -0
- lightly_studio/examples/example_split_work.py +67 -0
- lightly_studio/examples/example_v2.py +21 -0
- lightly_studio/export_schema.py +18 -0
- lightly_studio/few_shot_classifier/__init__.py +0 -0
- lightly_studio/few_shot_classifier/classifier.py +80 -0
- lightly_studio/few_shot_classifier/classifier_manager.py +663 -0
- lightly_studio/few_shot_classifier/random_forest_classifier.py +489 -0
- lightly_studio/metadata/complex_metadata.py +47 -0
- lightly_studio/metadata/gps_coordinate.py +41 -0
- lightly_studio/metadata/metadata_protocol.py +17 -0
- lightly_studio/metrics/__init__.py +0 -0
- lightly_studio/metrics/detection/__init__.py +0 -0
- lightly_studio/metrics/detection/map.py +268 -0
- lightly_studio/models/__init__.py +1 -0
- lightly_studio/models/annotation/__init__.py +0 -0
- lightly_studio/models/annotation/annotation_base.py +171 -0
- lightly_studio/models/annotation/instance_segmentation.py +56 -0
- lightly_studio/models/annotation/links.py +17 -0
- lightly_studio/models/annotation/object_detection.py +47 -0
- lightly_studio/models/annotation/semantic_segmentation.py +44 -0
- lightly_studio/models/annotation_label.py +47 -0
- lightly_studio/models/annotation_task.py +28 -0
- lightly_studio/models/classifier.py +20 -0
- lightly_studio/models/dataset.py +84 -0
- lightly_studio/models/embedding_model.py +30 -0
- lightly_studio/models/metadata.py +208 -0
- lightly_studio/models/sample.py +180 -0
- lightly_studio/models/sample_embedding.py +37 -0
- lightly_studio/models/settings.py +60 -0
- lightly_studio/models/tag.py +96 -0
- lightly_studio/py.typed +0 -0
- lightly_studio/resolvers/__init__.py +7 -0
- lightly_studio/resolvers/annotation_label_resolver/__init__.py +21 -0
- lightly_studio/resolvers/annotation_label_resolver/create.py +27 -0
- lightly_studio/resolvers/annotation_label_resolver/delete.py +28 -0
- lightly_studio/resolvers/annotation_label_resolver/get_all.py +22 -0
- lightly_studio/resolvers/annotation_label_resolver/get_by_id.py +24 -0
- lightly_studio/resolvers/annotation_label_resolver/get_by_ids.py +25 -0
- lightly_studio/resolvers/annotation_label_resolver/get_by_label_name.py +24 -0
- lightly_studio/resolvers/annotation_label_resolver/names_by_ids.py +25 -0
- lightly_studio/resolvers/annotation_label_resolver/update.py +38 -0
- lightly_studio/resolvers/annotation_resolver/__init__.py +33 -0
- lightly_studio/resolvers/annotation_resolver/count_annotations_by_dataset.py +120 -0
- lightly_studio/resolvers/annotation_resolver/create.py +19 -0
- lightly_studio/resolvers/annotation_resolver/create_many.py +96 -0
- lightly_studio/resolvers/annotation_resolver/delete_annotation.py +45 -0
- lightly_studio/resolvers/annotation_resolver/delete_annotations.py +56 -0
- lightly_studio/resolvers/annotation_resolver/get_all.py +74 -0
- lightly_studio/resolvers/annotation_resolver/get_by_id.py +18 -0
- lightly_studio/resolvers/annotation_resolver/update_annotation_label.py +144 -0
- lightly_studio/resolvers/annotation_resolver/update_bounding_box.py +68 -0
- lightly_studio/resolvers/annotation_task_resolver.py +31 -0
- lightly_studio/resolvers/annotations/__init__.py +1 -0
- lightly_studio/resolvers/annotations/annotations_filter.py +89 -0
- lightly_studio/resolvers/dataset_resolver.py +278 -0
- lightly_studio/resolvers/embedding_model_resolver.py +100 -0
- lightly_studio/resolvers/metadata_resolver/__init__.py +15 -0
- lightly_studio/resolvers/metadata_resolver/metadata_filter.py +163 -0
- lightly_studio/resolvers/metadata_resolver/sample/__init__.py +21 -0
- lightly_studio/resolvers/metadata_resolver/sample/bulk_set_metadata.py +48 -0
- lightly_studio/resolvers/metadata_resolver/sample/get_by_sample_id.py +24 -0
- lightly_studio/resolvers/metadata_resolver/sample/get_metadata_info.py +104 -0
- lightly_studio/resolvers/metadata_resolver/sample/get_value_for_sample.py +27 -0
- lightly_studio/resolvers/metadata_resolver/sample/set_value_for_sample.py +53 -0
- lightly_studio/resolvers/sample_embedding_resolver.py +86 -0
- lightly_studio/resolvers/sample_resolver.py +249 -0
- lightly_studio/resolvers/samples_filter.py +81 -0
- lightly_studio/resolvers/settings_resolver.py +58 -0
- lightly_studio/resolvers/tag_resolver.py +276 -0
- lightly_studio/selection/README.md +6 -0
- lightly_studio/selection/mundig.py +105 -0
- lightly_studio/selection/select.py +96 -0
- lightly_studio/selection/select_via_db.py +93 -0
- lightly_studio/selection/selection_config.py +31 -0
- lightly_studio/services/annotations_service/__init__.py +21 -0
- lightly_studio/services/annotations_service/get_annotation_by_id.py +31 -0
- lightly_studio/services/annotations_service/update_annotation.py +65 -0
- lightly_studio/services/annotations_service/update_annotation_label.py +48 -0
- lightly_studio/services/annotations_service/update_annotations.py +29 -0
- lightly_studio/setup_logging.py +19 -0
- lightly_studio/type_definitions.py +19 -0
- lightly_studio/vendor/ACKNOWLEDGEMENTS +422 -0
- lightly_studio/vendor/LICENSE +31 -0
- lightly_studio/vendor/LICENSE_weights_data +50 -0
- lightly_studio/vendor/README.md +5 -0
- lightly_studio/vendor/__init__.py +1 -0
- lightly_studio/vendor/mobileclip/__init__.py +96 -0
- lightly_studio/vendor/mobileclip/clip.py +77 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_b.json +18 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_s0.json +18 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_s1.json +18 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_s2.json +18 -0
- lightly_studio/vendor/mobileclip/image_encoder.py +67 -0
- lightly_studio/vendor/mobileclip/logger.py +154 -0
- lightly_studio/vendor/mobileclip/models/__init__.py +10 -0
- lightly_studio/vendor/mobileclip/models/mci.py +933 -0
- lightly_studio/vendor/mobileclip/models/vit.py +433 -0
- lightly_studio/vendor/mobileclip/modules/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/common/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/common/mobileone.py +341 -0
- lightly_studio/vendor/mobileclip/modules/common/transformer.py +451 -0
- lightly_studio/vendor/mobileclip/modules/image/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/image/image_projection.py +113 -0
- lightly_studio/vendor/mobileclip/modules/image/replknet.py +188 -0
- lightly_studio/vendor/mobileclip/modules/text/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/text/repmixer.py +281 -0
- lightly_studio/vendor/mobileclip/modules/text/tokenizer.py +38 -0
- lightly_studio/vendor/mobileclip/text_encoder.py +245 -0
- lightly_studio-0.3.1.dist-info/METADATA +520 -0
- lightly_studio-0.3.1.dist-info/RECORD +219 -0
- lightly_studio-0.3.1.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,451 @@
|
|
|
1
|
+
#
|
|
2
|
+
# For licensing see accompanying LICENSE file.
|
|
3
|
+
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
|
4
|
+
#
|
|
5
|
+
"""
|
|
6
|
+
Implementation of the following modules is borrowed from ml-cvnets repo:
|
|
7
|
+
https://github.com/apple/ml-cvnets/blob/main/cvnets/layers/multi_head_attention.py
|
|
8
|
+
https://github.com/apple/ml-cvnets/blob/main/cvnets/text_encoders/transformer.py
|
|
9
|
+
|
|
10
|
+
Please see ACKNOWLEDGEMENTS for license details.
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from typing import List, Optional, Union
|
|
14
|
+
|
|
15
|
+
import torch
|
|
16
|
+
from torch import Size, Tensor, nn
|
|
17
|
+
from torch.nn import functional as F
|
|
18
|
+
from torchvision.ops import StochasticDepth
|
|
19
|
+
|
|
20
|
+
from ... import logger
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class LayerNormFP32(nn.LayerNorm):
    """
    Applies `Layer Normalization <https://arxiv.org/abs/1607.06450>`_ over an input tensor with FP32 precision.

    The input is upcast to float32, normalized, and the result is cast back to the
    input's original dtype. The affine parameters (when present) are upcast for the
    computation as well. This keeps the layer usable after the whole module has been
    converted to reduced precision (e.g. ``module.half()``). Otherwise float16
    ``weight``/``bias`` would be combined with a float32 input, which the layer-norm
    kernel may reject.
    """

    def __init__(
        self,
        normalized_shape: Union[int, List[int], Size],
        eps: Optional[float] = 1e-5,
        elementwise_affine: Optional[bool] = True,
        *args,
        **kwargs,
    ):
        super().__init__(
            normalized_shape=normalized_shape,
            eps=eps,
            elementwise_affine=elementwise_affine,
            *args,
            **kwargs,
        )

    def forward(self, x: Tensor) -> Tensor:
        """Normalize ``x`` in float32 and return a tensor with ``x``'s original dtype."""
        # Normalizing in float32 may help with the underflow/overflow issues that are
        # typically seen with normalization layers in reduced precision.
        inp_dtype = x.dtype
        # ``.to`` returns the same tensor for parameters that are already float32, so the
        # default (fp32-parameter) behaviour is unchanged. Parameters are None when the
        # layer has no affine transform.
        weight = self.weight.to(torch.float32) if self.weight is not None else None
        bias = self.bias.to(torch.float32) if self.bias is not None else None
        out = F.layer_norm(
            x.to(torch.float32), self.normalized_shape, weight, bias, self.eps
        )
        return out.to(inp_dtype)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
def get_normalization_layer(norm_type, num_features):
    """Build the normalization layer named by ``norm_type``.

    Args:
        norm_type: ``"layer_norm"`` for ``nn.LayerNorm`` or ``"layer_norm_fp32"`` for
            ``LayerNormFP32`` (normalization computed in float32).
        num_features: Size of the normalized (last) dimension.

    Returns:
        The constructed normalization module.

    Raises:
        NotImplementedError: If ``norm_type`` is not one of the supported names.
    """
    if norm_type == "layer_norm_fp32":
        return LayerNormFP32(num_features)
    if norm_type == "layer_norm":
        return nn.LayerNorm(num_features)
    raise NotImplementedError(f"Option: {norm_type} not supported.")
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class PositionalEmbedding(nn.Module):
    """Wrapper that delegates to a concrete positional-embedding implementation.

    Only ``LearnablePositionalEmbedding`` exists today, so it is always used. The
    ``is_learnable`` argument is accepted for interface compatibility but is not
    consulted.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        is_learnable: Optional[bool] = False,
        interpolation_mode: Optional[str] = "bilinear",
        *args,
        **kwargs,
    ):
        super().__init__()
        # Other positional-embedding variants (and a selector between them) would go here.
        self.pos_embed = LearnablePositionalEmbedding(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            padding_idx=padding_idx,
            interpolation_mode=interpolation_mode,
            *args,
            **kwargs,
        )

    def forward(self, seq_len: int, *args, **kwargs) -> Tensor:
        """Return the positional embedding for a sequence of length ``seq_len``."""
        inner = self.pos_embed
        return inner(seq_len, *args, **kwargs)

    def __repr__(self):
        # Present the wrapper exactly as the wrapped implementation.
        return repr(self.pos_embed)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
class LearnablePositionalEmbedding(nn.Module):
    """Learnable positional embedding table that is resampled to the requested length.

    The table is a parameter of shape ``(1, 1, num_embeddings, embedding_dim)``. When
    ``padding_idx`` is set, that row is forced to zero at initialization and again on
    every forward call. When the requested ``seq_len`` differs from ``num_embeddings``,
    the table is resized with ``F.interpolate`` using ``interpolation_mode``.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int] = None,
        interpolation_mode: Optional[str] = "bilinear",
        *args,
        **kwargs,
    ):
        super().__init__()
        table = torch.empty(1, 1, num_embeddings, embedding_dim)
        self.pos_embed = nn.Parameter(table)
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.padding_idx = padding_idx
        self.interpolation_mode = interpolation_mode

        self.reset_parameters()

    def _zero_padding_row(self) -> None:
        """Overwrite the padding row of the table with zeros (no-op without padding)."""
        if self.padding_idx is None:
            return
        with torch.no_grad():
            self.pos_embed[:, :, self.padding_idx, ...] = 0.0

    def reset_parameters(self) -> None:
        """Re-draw the table from a truncated normal with std ``embedding_dim ** -0.5``."""
        std = self.embedding_dim**-0.5
        nn.init.trunc_normal_(self.pos_embed, mean=0, std=std)
        self._zero_padding_row()

    def forward(self, seq_len: int, *args, **kwargs) -> Tensor:
        """Return a ``(1, seq_len, embedding_dim)`` positional embedding."""
        # Keep the padding row at zero even if an optimizer step has moved it.
        self._zero_padding_row()
        table = self.pos_embed

        if seq_len == self.num_embeddings:
            resized = table
        else:
            resized = F.interpolate(
                table,
                size=(seq_len, self.embedding_dim),
                mode=self.interpolation_mode,
            )

        # Callers add this to inputs of the form [Batch, Seq_len, Embedding_dim].
        return resized.reshape(1, seq_len, self.embedding_dim)

    def __repr__(self):
        return "{}(num_embeddings={}, embedding_dim={}, padding_idx={})".format(
            self.__class__.__name__,
            self.num_embeddings,
            self.embedding_dim,
            self.padding_idx,
        )
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
class MultiHeadAttention(nn.Module):
    """
    This layer applies a multi-head self- or cross-attention as described in
    `Attention is all you need <https://arxiv.org/abs/1706.03762>`_ paper

    Args:
        embed_dim (int): :math:`C_{in}` from an expected input of size :math:`(N, S, C_{in})`
        num_heads (int): Number of heads in multi-head attention. ``embed_dim`` is expected
            to be divisible by ``num_heads``.
        attn_dropout (Optional[float]): Attention dropout. Default: 0.0
        bias (Optional[bool]): Use bias or not. Default: ``True``
        output_dim (Optional[int]): Output features of the final projection.
            Default: ``None``, meaning ``embed_dim``.

    Shape:
        - Input:
           - Query tensor (x_q) :math:`(N, S, C_{in})` where :math:`N` is batch size, :math:`S` is number of source tokens,
             and :math:`C_{in}` is input embedding dim
           - Optional Key-Value tensor (x_kv) :math:`(N, T, C_{in})` where :math:`T` is number of target tokens
        - Output: :math:`(N, S, C_{out})` where :math:`C_{out}` is ``output_dim``
          (the same shape as the query when ``output_dim`` is left at its default)
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        attn_dropout: Optional[float] = 0.0,
        bias: Optional[bool] = True,
        output_dim: Optional[int] = None,
        *args,
        **kwargs,
    ) -> None:
        # Extra *args / **kwargs are accepted for call-site compatibility and ignored.
        if output_dim is None:
            output_dim = embed_dim
        super().__init__()
        if embed_dim % num_heads != 0:
            # NOTE(review): this reports through the vendored logger. Whether it also aborts
            # depends on logger.error, which is not visible here. If it returns, the reshapes
            # in _forward_impl will fail later with a less descriptive error.
            logger.error(
                "Embedding dim must be divisible by number of heads in {}. Got: embed_dim={} and num_heads={}".format(
                    self.__class__.__name__, embed_dim, num_heads
                )
            )

        # One fused projection producing Q, K and V stacked along the output features:
        # rows [0, embed_dim) -> Q, rows [embed_dim, 3 * embed_dim) -> K then V.
        # The cross-attention path slices these rows instead of owning separate weights.
        self.qkv_proj = nn.Linear(
            in_features=embed_dim, out_features=3 * embed_dim, bias=bias
        )

        self.attn_dropout = nn.Dropout(p=attn_dropout)
        self.out_proj = nn.Linear(
            in_features=embed_dim, out_features=output_dim, bias=bias
        )

        self.head_dim = embed_dim // num_heads
        # 1 / sqrt(head_dim); multiplied into the query before computing Q K^T.
        self.scaling = self.head_dim**-0.5
        self.softmax = nn.Softmax(dim=-1)
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        # Not read anywhere in this class.
        self.use_separate_proj_weight = embed_dim != output_dim

    def __repr__(self):
        return "{}(head_dim={}, num_heads={}, attn_dropout={})".format(
            self.__class__.__name__, self.head_dim, self.num_heads, self.attn_dropout.p
        )

    def _forward_impl(
        self,
        x_q: Tensor,
        x_kv: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
        attn_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """Compute scaled dot-product attention over ``num_heads`` heads.

        Args:
            x_q: Query input of shape [N, S, C].
            x_kv: Optional key/value input of shape [N, T, C]. When ``None``, keys and
                values come from ``x_q`` (self-attention, T == S).
            key_padding_mask: Optional [N, T] mask. Positions that are nonzero/True are
                excluded from attention (their logits are set to -inf).
            attn_mask: Optional [N, S, T] mask that is broadcast over heads and added to
                the attention logits.

        Returns:
            Tensor of shape [N, S, output_dim].
        """
        # [N, S, C]
        b_sz, S_len, in_channels = x_q.shape

        if x_kv is None:
            # self-attention
            # [N, S, C] --> [N, S, 3C] --> [N, S, 3, h, c] where C = hc
            qkv = self.qkv_proj(x_q).reshape(b_sz, S_len, 3, self.num_heads, -1)
            # [N, S, 3, h, c] --> [N, h, 3, S, c]
            qkv = qkv.transpose(1, 3).contiguous()

            # [N, h, 3, S, c] --> [N, h, S, c] x 3
            query, key, value = qkv[:, :, 0], qkv[:, :, 1], qkv[:, :, 2]
        else:
            T_len = x_kv.shape[1]

            # cross-attention
            # Query uses only the Q rows of the fused projection.
            # [N, S, C]
            query = F.linear(
                x_q,
                weight=self.qkv_proj.weight[: self.embed_dim, ...],
                bias=self.qkv_proj.bias[: self.embed_dim]
                if self.qkv_proj.bias is not None
                else None,
            )
            # [N, S, C] --> [N, S, h, c] --> [N, h, S, c]
            query = (
                query.reshape(b_sz, S_len, self.num_heads, self.head_dim)
                .transpose(1, 2)
                .contiguous()
            )

            # Keys and values use the remaining K and V rows of the fused projection.
            # [N, T, C] --> [N, T, 2C]
            kv = F.linear(
                x_kv,
                weight=self.qkv_proj.weight[self.embed_dim :, ...],
                bias=self.qkv_proj.bias[self.embed_dim :]
                if self.qkv_proj.bias is not None
                else None,
            )
            # [N, T, 2C] --> [N, T, 2, h, c]
            kv = kv.reshape(b_sz, T_len, 2, self.num_heads, self.head_dim)
            # [N, T, 2, h, c] --> [N, h, 2, T, c]
            kv = kv.transpose(1, 3).contiguous()
            key, value = kv[:, :, 0], kv[:, :, 1]

        query = query * self.scaling

        # [N, h, T, c] --> [N, h, c, T]
        key = key.transpose(-1, -2)

        # QK^T
        # [N, h, S, c] x [N, h, c, T] --> [N, h, S, T]
        attn = torch.matmul(query, key)

        batch_size, num_heads, num_src_tokens, num_tgt_tokens = attn.shape
        if attn_mask is not None:
            # attn_mask must be [N, S, T] (no head dimension); it is broadcast over heads
            # and added to the logits.
            # NOTE(review): these shape checks use `assert` and are stripped under `python -O`.
            assert list(attn_mask.shape) == [
                batch_size,
                num_src_tokens,
                num_tgt_tokens,
            ], "Shape of attention mask should be [{}, {}, {}]. Got: {}".format(
                batch_size, num_src_tokens, num_tgt_tokens, attn_mask.shape
            )
            # [N, S, T] --> [N, 1, S, T]
            attn_mask = attn_mask.unsqueeze(1)
            attn = attn + attn_mask

        if key_padding_mask is not None:
            # Do not attend to padding positions
            # key padding mask size is [N, T]
            assert key_padding_mask.dim() == 2 and list(key_padding_mask.shape) == [
                batch_size,
                num_tgt_tokens,
            ], "Key_padding_mask should be 2-dimension with shape [{}, {}]. Got: {}".format(
                batch_size, num_tgt_tokens, key_padding_mask.shape
            )
            # NOTE: a query row whose keys are all masked becomes all -inf, so softmax
            # yields NaN for that row.
            attn = attn.masked_fill(
                key_padding_mask.unsqueeze(1)
                .unsqueeze(2)
                .to(torch.bool),  # [N, T] --> [N, 1, 1, T]
                float("-inf"),
            )

        # Softmax is evaluated in float32 for numerical stability, then cast back to the
        # logits' dtype.
        attn_dtype = attn.dtype
        attn_as_float = self.softmax(attn.float())
        attn = attn_as_float.to(attn_dtype)
        attn = self.attn_dropout(attn)

        # weighted sum
        # [N, h, S, T] x [N, h, T, c] --> [N, h, S, c]
        out = torch.matmul(attn, value)

        # [N, h, S, c] --> [N, S, h, c] --> [N, S, C]
        out = out.transpose(1, 2).reshape(b_sz, S_len, -1)
        # [N, S, C] --> [N, S, output_dim]
        out = self.out_proj(out)

        return out

    def forward(
        self,
        x_q: Tensor,
        x_kv: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
        attn_mask: Optional[Tensor] = None,
        *args,
        **kwargs,
    ) -> Tensor:
        """Run attention; extra *args / **kwargs are accepted but not used."""
        # [Batch , Sequence, Hidden_dim]
        return self._forward_impl(
            x_q=x_q,
            x_kv=x_kv,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
        )
|
|
327
|
+
|
|
328
|
+
|
|
329
|
+
class TransformerEncoder(nn.Module):
    """
    This class defines the pre-norm `Transformer encoder <https://arxiv.org/abs/1706.03762>`_

    Args:
        embed_dim: :math:`C_{in}` from an expected input of size :math:`(N, P, C_{in})`.
        ffn_latent_dim: Inner dimension of the FFN.
        num_heads: Number of heads in multi-head attention. Default: 8.
        attn_dropout: Dropout rate for attention in multi-head attention. Default: 0.0
        dropout: Dropout rate. Default: 0.0.
        ffn_dropout: Dropout between FFN layers. Default: 0.0.
        transformer_norm_layer: Normalization layer name passed to
            ``get_normalization_layer`` ("layer_norm" or "layer_norm_fp32").
            Default: layer_norm.
        stochastic_dropout: Stochastic-depth drop probability applied (row mode) to both
            residual branches. ``None`` is treated as 0. Default: 0.0.

    Shape:
        - Input: :math:`(N, P, C_{in})` where :math:`N` is batch size, :math:`P` is number of patches,
          and :math:`C_{in}` is input embedding dim
        - Output: same shape as the input
    """

    def __init__(
        self,
        embed_dim: int,
        ffn_latent_dim: int,
        num_heads: Optional[int] = 8,
        attn_dropout: Optional[float] = 0.0,
        dropout: Optional[float] = 0.0,
        ffn_dropout: Optional[float] = 0.0,
        transformer_norm_layer: Optional[str] = "layer_norm",
        stochastic_dropout: Optional[float] = 0.0,
        *args,
        **kwargs,
    ) -> None:

        super().__init__()

        # Build attention layer
        attn_unit = MultiHeadAttention(
            embed_dim,
            num_heads,
            attn_dropout=attn_dropout,
            bias=True,
        )

        # The Sequential indices below determine the state_dict keys
        # (pre_norm_mha.0/.1/.2, pre_norm_ffn.0..5). Keep the layout unchanged so
        # previously saved weights still load.
        self.pre_norm_mha = nn.Sequential(
            get_normalization_layer(
                norm_type=transformer_norm_layer, num_features=embed_dim
            ),
            attn_unit,
            nn.Dropout(p=dropout),
        )

        act_name = nn.GELU()
        self.pre_norm_ffn = nn.Sequential(
            get_normalization_layer(
                norm_type=transformer_norm_layer, num_features=embed_dim
            ),
            nn.Linear(in_features=embed_dim, out_features=ffn_latent_dim, bias=True),
            act_name,
            nn.Dropout(p=ffn_dropout),
            nn.Linear(in_features=ffn_latent_dim, out_features=embed_dim, bias=True),
            nn.Dropout(p=dropout),
        )

        self.drop_path = nn.Identity()
        if stochastic_dropout is not None and stochastic_dropout > 0.0:
            if dropout > 0.0:
                logger.error(
                    "Stochastic dropout and dropout are mutually exclusive. "
                    "Use either of them, but not both. "
                    "Got: {} and {}".format(stochastic_dropout, dropout)
                )
            self.drop_path = StochasticDepth(p=stochastic_dropout, mode="row")

        self.embed_dim = embed_dim
        self.ffn_dim = ffn_latent_dim
        self.ffn_dropout = ffn_dropout
        self.stochastic_dropout = stochastic_dropout
        self.std_dropout = dropout
        self.attn_fn_name = attn_unit.__class__.__name__
        self.act_fn_name = act_name.__class__.__name__
        self.norm_type = transformer_norm_layer

    def __repr__(self) -> str:
        return "{}(embed_dim={}, ffn_dim={}, dropout={}, ffn_dropout={}, stochastic_dropout={}, attn_fn={}, act_fn={}, norm_fn={})".format(
            self.__class__.__name__,
            self.embed_dim,
            self.ffn_dim,
            self.std_dropout,
            self.ffn_dropout,
            self.stochastic_dropout,
            self.attn_fn_name,
            self.act_fn_name,
            self.norm_type,
        )

    def forward(
        self,
        x: Tensor,
        x_prev: Optional[Tensor] = None,
        key_padding_mask: Optional[Tensor] = None,
        attn_mask: Optional[Tensor] = None,
        *args,
        **kwargs,
    ) -> Tensor:
        """Apply one pre-norm encoder layer: ``x + MHA(norm(x))`` then ``x + FFN(norm(x))``.

        Args:
            x: Input of shape [N, S, C].
            x_prev: Optional key/value source for cross-attention; ``None`` means
                self-attention over ``x``.
            key_padding_mask: Optional [N, T] mask of key positions to ignore.
            attn_mask: Optional [N, S, T] additive attention mask.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Multi-head attention
        res = x
        x = self.pre_norm_mha[0](x)  # norm
        # Extra positional ``args`` are deliberately not forwarded. MultiHeadAttention.forward
        # ignores them, and passing them positionally after keyword arguments would bind
        # them to ``x_q`` and raise "got multiple values for argument 'x_q'".
        x = self.pre_norm_mha[1](
            x_q=x,
            x_kv=x_prev,
            key_padding_mask=key_padding_mask,
            attn_mask=attn_mask,
            **kwargs,
        )  # mha

        x = self.drop_path(self.pre_norm_mha[2](x))  # applying stochastic depth
        x = x + res

        # Feed forward network (norm is the first module of pre_norm_ffn)
        x = x + self.drop_path(self.pre_norm_ffn(x))
        return x
|
|
@@ -0,0 +1,113 @@
|
|
|
1
|
+
#
|
|
2
|
+
# For licensing see accompanying LICENSE file.
|
|
3
|
+
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
|
4
|
+
#
|
|
5
|
+
from typing import List, Optional
|
|
6
|
+
|
|
7
|
+
import torch
|
|
8
|
+
import torch.nn as nn
|
|
9
|
+
from torch import Tensor
|
|
10
|
+
|
|
11
|
+
from ... import logger
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class GlobalPool(nn.Module):
    """
    This layer applies global pooling over a 4D or 5D input tensor.

    Args:
        pool_type (Optional[str]): Pooling type. It can be mean, rms, or abs. Default: `mean`
        keep_dim (Optional[bool]): Do not squeeze the dimensions of a tensor. Default: `False`

    Shape:
        - Input: :math:`(N, C, H, W)` or :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, 1, 1)` or :math:`(N, C, 1, 1, 1)` if keep_dim else :math:`(N, C)`

    Pooling modes (reduced over the spatial dimensions):
        - ``mean``: arithmetic mean.
        - ``rms``: root mean square, ``sqrt(mean(x ** 2))``.
        - ``abs``: mean of absolute values.
    """

    pool_types = ["mean", "rms", "abs"]

    def __init__(
        self,
        pool_type: Optional[str] = "mean",
        keep_dim: Optional[bool] = False,
        *args,
        **kwargs
    ) -> None:
        super().__init__()
        if pool_type not in self.pool_types:
            # NOTE(review): `logger` comes from the parent package, and whether
            # `logger.error` aborts is not visible here. If it only logs, an
            # unsupported type falls through to mean pooling in `_global_pool`.
            # Confirm.
            logger.error(
                "Supported pool types are: {}. Got {}".format(
                    self.pool_types, pool_type
                )
            )
        self.pool_type = pool_type
        self.keep_dim = keep_dim

    def _global_pool(self, x: Tensor, dims: List) -> Tensor:
        """Reduce ``x`` over ``dims`` according to ``self.pool_type``.

        Any pool type other than ``rms`` or ``abs`` uses the plain mean.
        """
        if self.pool_type == "rms":
            # Root mean square is the *square root* of the mean of squares.
            # Raising to the power -0.5 would give the reciprocal of the RMS.
            x = torch.mean(x**2, dim=dims, keepdim=self.keep_dim)
            x = torch.sqrt(x)
        elif self.pool_type == "abs":
            x = torch.mean(torch.abs(x), dim=dims, keepdim=self.keep_dim)
        else:
            # Default: mean pooling. This matches AdaptiveAvgPool with output
            # size 1; the singleton dims are squeezed unless keep_dim is set.
            x = torch.mean(x, dim=dims, keepdim=self.keep_dim)
        return x

    def forward(self, x: Tensor) -> Tensor:
        """Pool a 4D ``(N, C, H, W)`` or 5D ``(N, C, D, H, W)`` tensor over its spatial dims.

        Raises:
            NotImplementedError: if ``x`` is neither 4D nor 5D.
        """
        if x.dim() == 4:
            dims = [-2, -1]
        elif x.dim() == 5:
            dims = [-3, -2, -1]
        else:
            raise NotImplementedError("Currently 2D and 3D global pooling supported")
        return self._global_pool(x, dims=dims)
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
class GlobalPool2D(nn.Module):
    """Spatial mean pooling followed by a learned linear projection.

    Args:
        in_dim: Channel dimension of the 4D input; also the row count of ``proj``.
        out_dim: Column count of ``proj``, i.e. the size of the output features.
    """

    def __init__(self, in_dim: int, out_dim: int, *args, **kwargs) -> None:
        super().__init__()
        # Mean over the two trailing spatial dims, squeezed to [batch, in_dim].
        self.pool = GlobalPool(pool_type="mean", keep_dim=False)
        # Random projection scaled by 1/sqrt(in_dim).
        self.proj = nn.Parameter(in_dim**-0.5 * torch.randn(size=(in_dim, out_dim)))
        self.in_dim = in_dim
        self.out_dim = out_dim

    def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
        """Map ``[batch, in_dim, H, W]`` to ``[batch, out_dim]``."""
        assert x.dim() == 4, "Input should be 4-dimensional (Batch x in_dim x in_height x in_width). Got: {}".format(
            x.shape
        )
        # [batch, in_dim, H, W] -> [batch, in_dim]
        pooled = self.pool(x)
        # [batch, in_dim] x [in_dim, out_dim] -> [batch, out_dim]
        return torch.matmul(pooled, self.proj)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
class SimpleImageProjectionHead(nn.Module):
    """Linear projection head: multiplies ``[batch, in_dim]`` features by a learned matrix.

    Args:
        in_dim: Size of the incoming feature vectors.
        out_dim: Size of the projected feature vectors.
    """

    def __init__(self, in_dim: int, out_dim: int) -> None:
        super().__init__()
        # Random projection scaled by 1/sqrt(in_dim).
        self.proj = nn.Parameter(in_dim**-0.5 * torch.randn(size=(in_dim, out_dim)))
        self.in_dim = in_dim
        self.out_dim = out_dim

    def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
        """Map ``[batch, in_dim]`` to ``[batch, out_dim]``."""
        assert x.dim() == 2, "Input should be 2-dimensional (Batch x in_dim). Got: {}".format(x.shape)
        return torch.matmul(x, self.proj)
|