lightly-studio 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of lightly-studio has been flagged by the registry as potentially problematic; see the registry's release page for details.
- lightly_studio/__init__.py +11 -0
- lightly_studio/api/__init__.py +0 -0
- lightly_studio/api/app.py +110 -0
- lightly_studio/api/cache.py +77 -0
- lightly_studio/api/db.py +133 -0
- lightly_studio/api/db_tables.py +32 -0
- lightly_studio/api/features.py +7 -0
- lightly_studio/api/routes/api/annotation.py +233 -0
- lightly_studio/api/routes/api/annotation_label.py +90 -0
- lightly_studio/api/routes/api/annotation_task.py +38 -0
- lightly_studio/api/routes/api/classifier.py +387 -0
- lightly_studio/api/routes/api/dataset.py +182 -0
- lightly_studio/api/routes/api/dataset_tag.py +257 -0
- lightly_studio/api/routes/api/exceptions.py +96 -0
- lightly_studio/api/routes/api/features.py +17 -0
- lightly_studio/api/routes/api/metadata.py +37 -0
- lightly_studio/api/routes/api/metrics.py +80 -0
- lightly_studio/api/routes/api/sample.py +196 -0
- lightly_studio/api/routes/api/settings.py +45 -0
- lightly_studio/api/routes/api/status.py +19 -0
- lightly_studio/api/routes/api/text_embedding.py +48 -0
- lightly_studio/api/routes/api/validators.py +17 -0
- lightly_studio/api/routes/healthz.py +13 -0
- lightly_studio/api/routes/images.py +104 -0
- lightly_studio/api/routes/webapp.py +51 -0
- lightly_studio/api/server.py +82 -0
- lightly_studio/core/__init__.py +0 -0
- lightly_studio/core/dataset.py +523 -0
- lightly_studio/core/sample.py +77 -0
- lightly_studio/core/start_gui.py +15 -0
- lightly_studio/dataset/__init__.py +0 -0
- lightly_studio/dataset/edge_embedding_generator.py +144 -0
- lightly_studio/dataset/embedding_generator.py +91 -0
- lightly_studio/dataset/embedding_manager.py +163 -0
- lightly_studio/dataset/env.py +16 -0
- lightly_studio/dataset/file_utils.py +35 -0
- lightly_studio/dataset/loader.py +622 -0
- lightly_studio/dataset/mobileclip_embedding_generator.py +144 -0
- lightly_studio/dist_lightly_studio_view_app/_app/env.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/0.DenzbfeK.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/LightlyLogo.BNjCIww-.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans- +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Bold.DGvYQtcs.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Italic-VariableFont_wdth_wght.B4AZ-wl6.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-Regular.DxJTClRG.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-SemiBold.D3TTYgdB.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/OpenSans-VariableFont_wdth_wght.BZBpG5Iz.ttf +0 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.OwPEPQZu.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/SelectableSvgGroup.b653GmVf.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/_layout.T-zjSUd3.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/assets/useFeatureFlags.CV-KWLNP.css +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/69_IOA4Y.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B2FVR0s0.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B90CZVMX.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/B9zumHo5.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BJXwVxaE.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bsi3UGy5.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bu7uvVrG.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/Bx1xMsFy.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/BylOuP6i.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/C8I8rFJQ.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CDnpyLsT.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CWj6FrbW.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CYgJF_JY.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CcaPhhk3.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CvOmgdoc.js +93 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/CxtLVaYz.js +3 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D5-A_Ffd.js +4 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6RI2Zrd.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D6su9Aln.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/D98V7j6A.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIRAtgl0.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DIeogL5L.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DOlTMNyt.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjUWrjOv.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/DjfY96ND.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/H7C68rOM.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/O-EABkf9.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/XO7A28GO.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/hQVEETDE.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/l7KrR96u.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/nAHhluT7.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/r64xT6ao.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/vC4nQVEB.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/chunks/x9G_hzyY.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/app.CjnvpsmS.js +2 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/entry/start.0o1H7wM9.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/0.XRq_TUwu.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/1.B4rNYwVp.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/10.DfBwOEhN.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/11.CWG1ehzT.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/12.CwF2_8mP.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/2.CS4muRY-.js +6 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/3.CWHpKonm.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/4.OUWOLQeV.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/5.Dm6t9F5W.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/6.Bw5ck4gK.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/7.CF0EDTR6.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/8.Cw30LEcV.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/immutable/nodes/9.CPu3CiBc.js +1 -0
- lightly_studio/dist_lightly_studio_view_app/_app/version.json +1 -0
- lightly_studio/dist_lightly_studio_view_app/apple-touch-icon-precomposed.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/apple-touch-icon.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/favicon.png +0 -0
- lightly_studio/dist_lightly_studio_view_app/index.html +44 -0
- lightly_studio/examples/example.py +23 -0
- lightly_studio/examples/example_metadata.py +338 -0
- lightly_studio/examples/example_selection.py +39 -0
- lightly_studio/examples/example_split_work.py +67 -0
- lightly_studio/examples/example_v2.py +21 -0
- lightly_studio/export_schema.py +18 -0
- lightly_studio/few_shot_classifier/__init__.py +0 -0
- lightly_studio/few_shot_classifier/classifier.py +80 -0
- lightly_studio/few_shot_classifier/classifier_manager.py +663 -0
- lightly_studio/few_shot_classifier/random_forest_classifier.py +489 -0
- lightly_studio/metadata/complex_metadata.py +47 -0
- lightly_studio/metadata/gps_coordinate.py +41 -0
- lightly_studio/metadata/metadata_protocol.py +17 -0
- lightly_studio/metrics/__init__.py +0 -0
- lightly_studio/metrics/detection/__init__.py +0 -0
- lightly_studio/metrics/detection/map.py +268 -0
- lightly_studio/models/__init__.py +1 -0
- lightly_studio/models/annotation/__init__.py +0 -0
- lightly_studio/models/annotation/annotation_base.py +171 -0
- lightly_studio/models/annotation/instance_segmentation.py +56 -0
- lightly_studio/models/annotation/links.py +17 -0
- lightly_studio/models/annotation/object_detection.py +47 -0
- lightly_studio/models/annotation/semantic_segmentation.py +44 -0
- lightly_studio/models/annotation_label.py +47 -0
- lightly_studio/models/annotation_task.py +28 -0
- lightly_studio/models/classifier.py +20 -0
- lightly_studio/models/dataset.py +84 -0
- lightly_studio/models/embedding_model.py +30 -0
- lightly_studio/models/metadata.py +208 -0
- lightly_studio/models/sample.py +180 -0
- lightly_studio/models/sample_embedding.py +37 -0
- lightly_studio/models/settings.py +60 -0
- lightly_studio/models/tag.py +96 -0
- lightly_studio/py.typed +0 -0
- lightly_studio/resolvers/__init__.py +7 -0
- lightly_studio/resolvers/annotation_label_resolver/__init__.py +21 -0
- lightly_studio/resolvers/annotation_label_resolver/create.py +27 -0
- lightly_studio/resolvers/annotation_label_resolver/delete.py +28 -0
- lightly_studio/resolvers/annotation_label_resolver/get_all.py +22 -0
- lightly_studio/resolvers/annotation_label_resolver/get_by_id.py +24 -0
- lightly_studio/resolvers/annotation_label_resolver/get_by_ids.py +25 -0
- lightly_studio/resolvers/annotation_label_resolver/get_by_label_name.py +24 -0
- lightly_studio/resolvers/annotation_label_resolver/names_by_ids.py +25 -0
- lightly_studio/resolvers/annotation_label_resolver/update.py +38 -0
- lightly_studio/resolvers/annotation_resolver/__init__.py +33 -0
- lightly_studio/resolvers/annotation_resolver/count_annotations_by_dataset.py +120 -0
- lightly_studio/resolvers/annotation_resolver/create.py +19 -0
- lightly_studio/resolvers/annotation_resolver/create_many.py +96 -0
- lightly_studio/resolvers/annotation_resolver/delete_annotation.py +45 -0
- lightly_studio/resolvers/annotation_resolver/delete_annotations.py +56 -0
- lightly_studio/resolvers/annotation_resolver/get_all.py +74 -0
- lightly_studio/resolvers/annotation_resolver/get_by_id.py +18 -0
- lightly_studio/resolvers/annotation_resolver/update_annotation_label.py +144 -0
- lightly_studio/resolvers/annotation_resolver/update_bounding_box.py +68 -0
- lightly_studio/resolvers/annotation_task_resolver.py +31 -0
- lightly_studio/resolvers/annotations/__init__.py +1 -0
- lightly_studio/resolvers/annotations/annotations_filter.py +89 -0
- lightly_studio/resolvers/dataset_resolver.py +278 -0
- lightly_studio/resolvers/embedding_model_resolver.py +100 -0
- lightly_studio/resolvers/metadata_resolver/__init__.py +15 -0
- lightly_studio/resolvers/metadata_resolver/metadata_filter.py +163 -0
- lightly_studio/resolvers/metadata_resolver/sample/__init__.py +21 -0
- lightly_studio/resolvers/metadata_resolver/sample/bulk_set_metadata.py +48 -0
- lightly_studio/resolvers/metadata_resolver/sample/get_by_sample_id.py +24 -0
- lightly_studio/resolvers/metadata_resolver/sample/get_metadata_info.py +104 -0
- lightly_studio/resolvers/metadata_resolver/sample/get_value_for_sample.py +27 -0
- lightly_studio/resolvers/metadata_resolver/sample/set_value_for_sample.py +53 -0
- lightly_studio/resolvers/sample_embedding_resolver.py +86 -0
- lightly_studio/resolvers/sample_resolver.py +249 -0
- lightly_studio/resolvers/samples_filter.py +81 -0
- lightly_studio/resolvers/settings_resolver.py +58 -0
- lightly_studio/resolvers/tag_resolver.py +276 -0
- lightly_studio/selection/README.md +6 -0
- lightly_studio/selection/mundig.py +105 -0
- lightly_studio/selection/select.py +96 -0
- lightly_studio/selection/select_via_db.py +93 -0
- lightly_studio/selection/selection_config.py +31 -0
- lightly_studio/services/annotations_service/__init__.py +21 -0
- lightly_studio/services/annotations_service/get_annotation_by_id.py +31 -0
- lightly_studio/services/annotations_service/update_annotation.py +65 -0
- lightly_studio/services/annotations_service/update_annotation_label.py +48 -0
- lightly_studio/services/annotations_service/update_annotations.py +29 -0
- lightly_studio/setup_logging.py +19 -0
- lightly_studio/type_definitions.py +19 -0
- lightly_studio/vendor/ACKNOWLEDGEMENTS +422 -0
- lightly_studio/vendor/LICENSE +31 -0
- lightly_studio/vendor/LICENSE_weights_data +50 -0
- lightly_studio/vendor/README.md +5 -0
- lightly_studio/vendor/__init__.py +1 -0
- lightly_studio/vendor/mobileclip/__init__.py +96 -0
- lightly_studio/vendor/mobileclip/clip.py +77 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_b.json +18 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_s0.json +18 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_s1.json +18 -0
- lightly_studio/vendor/mobileclip/configs/mobileclip_s2.json +18 -0
- lightly_studio/vendor/mobileclip/image_encoder.py +67 -0
- lightly_studio/vendor/mobileclip/logger.py +154 -0
- lightly_studio/vendor/mobileclip/models/__init__.py +10 -0
- lightly_studio/vendor/mobileclip/models/mci.py +933 -0
- lightly_studio/vendor/mobileclip/models/vit.py +433 -0
- lightly_studio/vendor/mobileclip/modules/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/common/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/common/mobileone.py +341 -0
- lightly_studio/vendor/mobileclip/modules/common/transformer.py +451 -0
- lightly_studio/vendor/mobileclip/modules/image/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/image/image_projection.py +113 -0
- lightly_studio/vendor/mobileclip/modules/image/replknet.py +188 -0
- lightly_studio/vendor/mobileclip/modules/text/__init__.py +4 -0
- lightly_studio/vendor/mobileclip/modules/text/repmixer.py +281 -0
- lightly_studio/vendor/mobileclip/modules/text/tokenizer.py +38 -0
- lightly_studio/vendor/mobileclip/text_encoder.py +245 -0
- lightly_studio-0.3.1.dist-info/METADATA +520 -0
- lightly_studio-0.3.1.dist-info/RECORD +219 -0
- lightly_studio-0.3.1.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,50 @@
|
|
|
1
|
+
ML-MobileCLIP Model Weights and Data
|
|
2
|
+
|
|
3
|
+
Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
|
4
|
+
|
|
5
|
+
IMPORTANT: This Apple software is supplied to you by Apple
|
|
6
|
+
Inc. ("Apple") in consideration of your agreement to the following
|
|
7
|
+
terms, and your use, installation, modification or redistribution of
|
|
8
|
+
this Apple software constitutes acceptance of these terms. If you do
|
|
9
|
+
not agree with these terms, please do not use, install, modify or
|
|
10
|
+
redistribute this Apple software.
|
|
11
|
+
|
|
12
|
+
In consideration of your agreement to abide by the following terms, and
|
|
13
|
+
subject to these terms, Apple grants you a personal, non-exclusive
|
|
14
|
+
license, under Apple's copyrights in this original Apple software (the
|
|
15
|
+
"Apple Software"), to use, reproduce, modify and redistribute the Apple
|
|
16
|
+
Software, with or without modifications, in source and/or binary forms;
|
|
17
|
+
provided that if you redistribute the Apple Software in its entirety and
|
|
18
|
+
without modifications, you must retain this notice and the following
|
|
19
|
+
text and disclaimers in all such redistributions of the Apple Software.
|
|
20
|
+
Neither the name, trademarks, service marks or logos of Apple Inc. may
|
|
21
|
+
be used to endorse or promote products derived from the Apple Software
|
|
22
|
+
without specific prior written permission from Apple. Except as
|
|
23
|
+
expressly stated in this notice, no other rights or licenses, express or
|
|
24
|
+
implied, are granted by Apple herein, including but not limited to any
|
|
25
|
+
patent rights that may be infringed by your derivative works or by other
|
|
26
|
+
works in which the Apple Software may be incorporated.
|
|
27
|
+
|
|
28
|
+
The Apple Software is provided by Apple on an "AS IS" basis. APPLE
|
|
29
|
+
MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION
|
|
30
|
+
THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS
|
|
31
|
+
FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND
|
|
32
|
+
OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS.
|
|
33
|
+
|
|
34
|
+
IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL
|
|
35
|
+
OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
|
36
|
+
SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
|
37
|
+
INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION,
|
|
38
|
+
MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED
|
|
39
|
+
AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE),
|
|
40
|
+
STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE
|
|
41
|
+
POSSIBILITY OF SUCH DAMAGE.
|
|
42
|
+
|
|
43
|
+
-------------------------------------------------------------------------------
|
|
44
|
+
SOFTWARE DISTRIBUTED WITH ML-MobileCLIP:
|
|
45
|
+
|
|
46
|
+
The ML-MobileCLIP software copyright and license terms can be found in LICENSE.
|
|
47
|
+
|
|
48
|
+
The ML-MobileCLIP software includes a number of subcomponents with separate
|
|
49
|
+
copyright notices and license terms - please see the file ACKNOWLEDGEMENTS.
|
|
50
|
+
-------------------------------------------------------------------------------
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Vendor directory for third-party code."""
|
|
@@ -0,0 +1,96 @@
|
|
|
1
|
+
#
|
|
2
|
+
# For licensing see accompanying LICENSE file.
|
|
3
|
+
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
|
4
|
+
#
|
|
5
|
+
import os
|
|
6
|
+
import json
|
|
7
|
+
from typing import Optional, Union, Tuple, Any
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn as nn
|
|
11
|
+
from torchvision.transforms import (
|
|
12
|
+
CenterCrop,
|
|
13
|
+
Compose,
|
|
14
|
+
InterpolationMode,
|
|
15
|
+
Resize,
|
|
16
|
+
ToTensor,
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
from .clip import CLIP
|
|
20
|
+
from .modules.text.tokenizer import (
|
|
21
|
+
ClipTokenizer,
|
|
22
|
+
)
|
|
23
|
+
from .modules.common.mobileone import reparameterize_model
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def create_model_and_transforms(
    model_name: str,
    pretrained: Optional[str] = None,
    reparameterize: Optional[bool] = True,
    device: Union[str, torch.device] = "cpu",
) -> Tuple[nn.Module, Any, Any]:
    """
    Method to instantiate model and pre-processing transforms necessary for inference.

    Args:
        model_name: Model name. Choose from ['mobileclip_s0', 'mobileclip_s1', 'mobileclip_s2', 'mobileclip_b']
        pretrained: Location of pretrained checkpoint.
        reparameterize: When set to True, re-parameterizable branches get folded for faster inference.
        device: Device identifier for model placement.

    Returns:
        Tuple ``(model, None, preprocess)``: the instantiated model (moved to
        ``device`` and put in eval mode), ``None`` (presumably the training-
        transform slot of the open_clip-style API), and the inference
        preprocessing transform (resize, center crop, to tensor).

    Raises:
        ValueError: If no ``configs/<model_name>.json`` exists next to this file.
    """
    # Config files live in the "configs" directory beside this module.
    root_dir = os.path.dirname(os.path.abspath(__file__))
    configs_dir = os.path.join(root_dir, "configs")
    model_cfg_file = os.path.join(configs_dir, model_name + ".json")

    # Load the JSON config; the context manager closes the handle promptly.
    if not os.path.exists(model_cfg_file):
        raise ValueError(f"Unsupported model name: {model_name}")
    with open(model_cfg_file, "r", encoding="utf-8") as cfg_fp:
        model_cfg = json.load(cfg_fp)

    # Build preprocessing transforms for inference
    resolution = model_cfg["image_cfg"]["image_size"]
    resize_size = resolution
    centercrop_size = resolution
    aug_list = [
        Resize(
            resize_size,
            interpolation=InterpolationMode.BILINEAR,
        ),
        CenterCrop(centercrop_size),
        ToTensor(),
    ]
    preprocess = Compose(aug_list)

    # Build model
    model = CLIP(cfg=model_cfg)
    model.to(device)
    model.eval()

    # Load checkpoint. Map tensors to CPU first so a checkpoint saved on a
    # different device (e.g. CUDA) still loads; load_state_dict then copies
    # the values into the parameters already placed on `device`.
    if pretrained is not None:
        chkpt = torch.load(pretrained, map_location="cpu", weights_only=True)
        model.load_state_dict(chkpt)

    # Reparameterize model for inference (if specified)
    if reparameterize:
        model = reparameterize_model(model)

    return model, None, preprocess
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def get_tokenizer(model_name: str) -> nn.Module:
    """Build the text tokenizer for ``model_name``.

    Args:
        model_name: Name of a config in the ``configs`` directory next to this
            module (without the ``.json`` suffix).

    Returns:
        A ``ClipTokenizer`` constructed from the model's JSON config.

    Raises:
        FileNotFoundError: If ``configs/<model_name>.json`` does not exist.
    """
    # Config files live in the "configs" directory beside this module.
    root_dir = os.path.dirname(os.path.abspath(__file__))
    configs_dir = os.path.join(root_dir, "configs")
    model_cfg_file = os.path.join(configs_dir, model_name + ".json")

    # Load the JSON config; the context manager closes the handle promptly.
    with open(model_cfg_file, "r", encoding="utf-8") as cfg_fp:
        model_cfg = json.load(cfg_fp)

    # Build tokenizer
    text_tokenizer = ClipTokenizer(model_cfg)
    return text_tokenizer
|
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
#
|
|
2
|
+
# For licensing see accompanying LICENSE file.
|
|
3
|
+
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
|
4
|
+
#
|
|
5
|
+
""" Model schema in open_clip format for inference only. """
|
|
6
|
+
import math
|
|
7
|
+
from typing import Any, Optional, Dict
|
|
8
|
+
|
|
9
|
+
import torch
|
|
10
|
+
import torch.nn.functional as F
|
|
11
|
+
from torch import nn
|
|
12
|
+
|
|
13
|
+
from .text_encoder import (
|
|
14
|
+
TextTransformer,
|
|
15
|
+
)
|
|
16
|
+
|
|
17
|
+
from .image_encoder import MCi
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class CLIP(nn.Module):
    """Base class for multi-modal image-text data.

    Pairs an ``MCi`` image encoder with a ``TextTransformer`` text encoder,
    both projecting to ``cfg["embed_dim"]`` dimensions, plus a learnable
    logit scale initialised to ``log(1 / 0.07)``.

    Args:
        cfg: Model config. Reads ``embed_dim``, ``image_cfg["model_name"]``
            and ``text_cfg``.
        output_dict: If True, ``forward`` returns a dict instead of a tuple.

    Raises:
        ValueError: If ``embed_dim`` is missing from ``cfg`` or is None.
    """

    def __init__(self, cfg: Dict, output_dict: bool = False, *args, **kwargs) -> None:
        super().__init__()
        self.output_dict = output_dict
        # Use .get so a missing key reaches the descriptive ValueError below
        # instead of surfacing as a bare KeyError.
        self.projection_dim = cfg.get("embed_dim")
        if self.projection_dim is None:
            raise ValueError("Please specify `embed_dim` in model config.")

        self.image_encoder = MCi(
            model_name=cfg["image_cfg"]["model_name"],
            projection_dim=self.projection_dim,
        )
        self.text_encoder = TextTransformer(
            cfg=cfg["text_cfg"], projection_dim=self.projection_dim
        )
        self.logit_scale = nn.Parameter(torch.ones([]) * math.log(1.0 / 0.07))

    def _exponentiate_and_clip_logits(self, max_scale: float = 100.0):
        """Return ``exp(logit_scale)`` clamped to the range ``[0, max_scale]``."""
        scale = self.logit_scale.exp()
        scale = torch.clamp(scale, 0, max_scale)
        return scale

    def encode_image(self, image: torch.Tensor, normalize: bool = False):
        """Encode images; L2-normalise along the last dim when ``normalize``.

        If the image encoder returns a dict, its ``"logits"`` entry is used.
        """
        image_encoder_out = self.image_encoder(image)
        if isinstance(image_encoder_out, dict):
            features = image_encoder_out["logits"]
        else:
            features = image_encoder_out
        return F.normalize(features, dim=-1) if normalize else features

    def encode_text(self, text: torch.Tensor, normalize: bool = False):
        """Encode token tensors; L2-normalise along the last dim when ``normalize``."""
        text_features = self.text_encoder(text_tokens=text, key_padding_mask=None)
        return F.normalize(text_features, dim=-1) if normalize else text_features

    def forward(
        self,
        image: Optional[torch.Tensor] = None,
        text: Optional[torch.Tensor] = None,
        *args,
        **kwargs
    ) -> Any:
        """Encode whichever inputs are given (normalised) plus the clipped logit scale.

        Returns ``(image_features, text_features, logit_scale)``, or a dict with
        those keys when ``output_dict`` is set. A missing input yields ``None``.
        """
        image_embeddings = (
            self.encode_image(image, normalize=True) if image is not None else None
        )
        text_embeddings = (
            self.encode_text(text, normalize=True) if text is not None else None
        )

        if self.output_dict:
            return {
                "image_features": image_embeddings,
                "text_features": text_embeddings,
                "logit_scale": self._exponentiate_and_clip_logits(),
            }
        return image_embeddings, text_embeddings, self._exponentiate_and_clip_logits()
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
{
|
|
2
|
+
"embed_dim": 512,
|
|
3
|
+
"image_cfg": {
|
|
4
|
+
"image_size": 224,
|
|
5
|
+
"model_name": "vit_b16"
|
|
6
|
+
},
|
|
7
|
+
"text_cfg": {
|
|
8
|
+
"context_length": 77,
|
|
9
|
+
"vocab_size": 49408,
|
|
10
|
+
"dim": 512,
|
|
11
|
+
"ffn_multiplier_per_layer": 4.0,
|
|
12
|
+
"n_heads_per_layer": 8,
|
|
13
|
+
"n_transformer_layers": 12,
|
|
14
|
+
"norm_layer": "layer_norm_fp32",
|
|
15
|
+
"causal_masking": true,
|
|
16
|
+
"model_name": "base"
|
|
17
|
+
}
|
|
18
|
+
}
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
{
|
|
2
|
+
"embed_dim": 512,
|
|
3
|
+
"image_cfg": {
|
|
4
|
+
"image_size": 256,
|
|
5
|
+
"model_name": "mci0"
|
|
6
|
+
},
|
|
7
|
+
"text_cfg": {
|
|
8
|
+
"context_length": 77,
|
|
9
|
+
"vocab_size": 49408,
|
|
10
|
+
"dim": 512,
|
|
11
|
+
"ffn_multiplier_per_layer": 4.0,
|
|
12
|
+
"n_heads_per_layer": 8,
|
|
13
|
+
"n_transformer_layers": 4,
|
|
14
|
+
"norm_layer": "layer_norm_fp32",
|
|
15
|
+
"causal_masking": false,
|
|
16
|
+
"model_name": "mct"
|
|
17
|
+
}
|
|
18
|
+
}
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
{
|
|
2
|
+
"embed_dim": 512,
|
|
3
|
+
"image_cfg": {
|
|
4
|
+
"image_size": 256,
|
|
5
|
+
"model_name": "mci1"
|
|
6
|
+
},
|
|
7
|
+
"text_cfg": {
|
|
8
|
+
"context_length": 77,
|
|
9
|
+
"vocab_size": 49408,
|
|
10
|
+
"dim": 512,
|
|
11
|
+
"ffn_multiplier_per_layer": 4.0,
|
|
12
|
+
"n_heads_per_layer": 8,
|
|
13
|
+
"n_transformer_layers": 12,
|
|
14
|
+
"norm_layer": "layer_norm_fp32",
|
|
15
|
+
"causal_masking": false,
|
|
16
|
+
"model_name": "base"
|
|
17
|
+
}
|
|
18
|
+
}
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
{
|
|
2
|
+
"embed_dim": 512,
|
|
3
|
+
"image_cfg": {
|
|
4
|
+
"image_size": 256,
|
|
5
|
+
"model_name": "mci2"
|
|
6
|
+
},
|
|
7
|
+
"text_cfg": {
|
|
8
|
+
"context_length": 77,
|
|
9
|
+
"vocab_size": 49408,
|
|
10
|
+
"dim": 512,
|
|
11
|
+
"ffn_multiplier_per_layer": 4.0,
|
|
12
|
+
"n_heads_per_layer": 8,
|
|
13
|
+
"n_transformer_layers": 12,
|
|
14
|
+
"norm_layer": "layer_norm_fp32",
|
|
15
|
+
"causal_masking": false,
|
|
16
|
+
"model_name": "base"
|
|
17
|
+
}
|
|
18
|
+
}
|
|
@@ -0,0 +1,67 @@
|
|
|
1
|
+
#
|
|
2
|
+
# For licensing see accompanying LICENSE file.
|
|
3
|
+
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
|
4
|
+
#
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import torch.nn as nn
|
|
8
|
+
from timm.models import create_model
|
|
9
|
+
|
|
10
|
+
from . import models # Added to register models
|
|
11
|
+
from .modules.image.image_projection import GlobalPool2D
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class MCi(nn.Module):
    """
    Image encoder wrapping a timm-registered MCi backbone.

    See `MCi Models <https://arxiv.org/pdf/2311.17049.pdf>`_. When a
    ``projection_dim`` keyword is supplied and the backbone exposes a
    ``head`` attribute, that head is replaced by a ``GlobalPool2D``
    projection to ``projection_dim`` outputs.
    """

    def __init__(self, model_name: str, *args, **kwargs) -> None:
        super().__init__()
        # kwargs.get yields None when the keyword is absent.
        self.projection_dim = kwargs.get("projection_dim")

        self.model = create_model(model_name, projection_dim=self.projection_dim)

        needs_projection = self.projection_dim is not None
        if needs_projection and hasattr(self.model, "head"):
            self.model.head = MCi._update_image_classifier(
                image_classifier=self.model.head, projection_dim=self.projection_dim
            )

    def forward(self, x: Any, *args, **kwargs) -> Any:
        """Run the wrapped backbone on ``x`` and return its output unchanged."""
        return self.model(x)

    @staticmethod
    def _get_in_feature_dimension(image_classifier: nn.Module) -> int:
        """Return the input width of a Linear head, or of the first Linear in a Sequential head.

        Raises:
            NotImplementedError: If no ``nn.Linear`` can be located.
        """
        if isinstance(image_classifier, nn.Linear):
            width = image_classifier.in_features
        elif isinstance(image_classifier, nn.Sequential):
            # Sequential heads typically pool first; the first Linear layer
            # defines the feature width entering the head.
            width = next(
                (
                    block.in_features
                    for block in image_classifier
                    if isinstance(block, nn.Linear)
                ),
                None,
            )
        else:
            width = None

        if width is None:
            raise NotImplementedError(
                f"Cannot get input feature dimension of {image_classifier}."
            )
        return width

    @staticmethod
    def _update_image_classifier(
        image_classifier: nn.Module, projection_dim: int, *args, **kwargs
    ) -> nn.Module:
        """Build a ``GlobalPool2D`` head mapping the old head's input width to ``projection_dim``."""
        return GlobalPool2D(
            in_dim=MCi._get_in_feature_dimension(image_classifier),
            out_dim=projection_dim,
        )
|
|
@@ -0,0 +1,154 @@
|
|
|
1
|
+
#
|
|
2
|
+
# For licensing see accompanying LICENSE file.
|
|
3
|
+
# Copyright (C) 2024 Apple Inc. All Rights Reserved.
|
|
4
|
+
#
|
|
5
|
+
|
|
6
|
+
import os
|
|
7
|
+
import sys
|
|
8
|
+
import time
|
|
9
|
+
import traceback
|
|
10
|
+
from typing import Optional, Union
|
|
11
|
+
|
|
12
|
+
# ANSI escape sequences ("\033" is ESC, the number selects the attribute/colour)
# used to decorate log prefixes.
# NOTE(review): despite its key, "light_red" holds code 36, which is cyan in the
# standard ANSI palette; the key is kept because color_text() looks it up.
text_colors = dict(
    logs="\033[34m",
    info="\033[32m",
    warning="\033[33m",
    debug="\033[93m",
    error="\033[31m",
    bold="\033[1m",
    end_color="\033[0m",
    light_red="\033[36m",
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def get_curr_time_stamp() -> str:
    """Return the current local time as ``YYYY-MM-DD HH:MM:SS``."""
    stamp_format = "%Y-%m-%d %H:%M:%S"
    return time.strftime(stamp_format, time.localtime())
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def error(message: str) -> None:
    """Print a traceback, then terminate via ``sys.exit`` with a timestamped message.

    If an exception is currently being handled its traceback is printed;
    otherwise the current call stack is printed.

    Raises:
        SystemExit: Always, carrying the formatted error message.
    """
    time_stamp = get_curr_time_stamp()
    error_str = "".join(
        (text_colors["error"], text_colors["bold"], "ERROR ", text_colors["end_color"])
    )

    # Exiting with code -1 says nothing about the cause (e.g. NaN in the loss),
    # so sys.exit(<message>) is used instead; tests can then inspect the
    # raised SystemExit for the specific error.
    active_exception_type = sys.exc_info()[0]
    if active_exception_type is not None:
        traceback.print_exc()
    else:
        traceback.print_stack()
    sys.exit("{} - {} - {}. Exiting!!!".format(time_stamp, error_str, message))
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def color_text(in_text: str) -> str:
    """Wrap ``in_text`` in the ``light_red`` escape code followed by a colour reset."""
    return "".join((text_colors["light_red"], in_text, text_colors["end_color"]))
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def log(message: str, end="\n") -> None:
    """Print ``message`` after a timestamp and a bold ``LOGS`` tag.

    Args:
        message: Text to print.
        end: Terminator passed through to ``print``.
    """
    stamp = get_curr_time_stamp()
    tag = "".join(
        (text_colors["logs"], text_colors["bold"], "LOGS ", text_colors["end_color"])
    )
    print(f"{stamp} - {tag} - {message}", end=end)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def warning(message: Union[str, Warning]) -> None:
    """Print ``message`` after a timestamp and a bold ``WARNING`` tag.

    Args:
        message: Text to print, or a ``Warning`` instance, which is rendered
            as ``ClassName(arg1,arg2,...)`` using ``repr`` of each argument.
    """
    if isinstance(message, Warning):
        # Close the parenthesis so the rendering mirrors a constructor call.
        args_repr = ",".join(map(repr, message.args))
        message = f"{type(message).__name__}({args_repr})"

    time_stamp = get_curr_time_stamp()
    warn_str = (
        text_colors["warning"]
        + text_colors["bold"]
        + "WARNING"
        + text_colors["end_color"]
    )
    print("{} - {} - {}".format(time_stamp, warn_str, message))
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def ignore_exception_with_warning(message: str) -> None:
    """
    Log the exception currently being handled as a warning rather than an error.

    After catching a tolerable exception E1 (e.g. when Model.forward() fails during
    profiling), logging E1's traceback helps later investigation. But if an
    uncaught exception E2 is raised further on, the log then holds two stack
    traces, and users searching for the real failure (E2) may land on E1.

    To tell E1 apart, every line of the emitted text after the first is
    prefixed with "(WARNING)".

    Args:
        message: Extra explanation and context for debugging. The exception
            itself is obtained automatically via ``traceback.format_exc()``, so
            it does not need to be passed in.
    """
    report = "{}:\n{}".format(message, traceback.format_exc())
    warning(report.replace("\n", "\n(WARNING)"))
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def info(message: str, print_line: Optional[bool] = False) -> None:
    """Print ``message`` after a timestamp and a bold ``INFO`` tag.

    Args:
        message: Text to print.
        print_line: When truthy, follow the message with a 150-wide ``=`` rule.
    """
    stamp = get_curr_time_stamp()
    tag = "".join(
        (text_colors["info"], text_colors["bold"], "INFO ", text_colors["end_color"])
    )
    print(f"{stamp} - {tag} - {message}")
    if not print_line:
        return
    double_dash_line(dashes=150)
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def debug(message: str) -> None:
    """Print ``message`` after a timestamp and a bold ``DEBUG`` tag."""
    stamp = get_curr_time_stamp()
    tag = "".join(
        (text_colors["debug"], text_colors["bold"], "DEBUG ", text_colors["end_color"])
    )
    print(f"{stamp} - {tag} - {message}")
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def double_dash_line(dashes: Optional[int] = 75) -> None:
    """Print a rule of ``dashes`` '=' characters in the ``error`` colour."""
    rule = "=" * dashes
    print("".join((text_colors["error"], rule, text_colors["end_color"])))
|
|
121
|
+
|
|
122
|
+
|
|
123
|
+
def singe_dash_line(dashes: Optional[int] = 67) -> None:
    """Print an uncoloured rule of ``dashes`` '-' characters.

    The misspelled name is kept for backward compatibility; prefer the
    ``single_dash_line`` alias defined right below.
    """
    print("-" * dashes)


# Correctly spelled alias for the historically misspelled public name.
single_dash_line = singe_dash_line
|
|
125
|
+
|
|
126
|
+
|
|
127
|
+
def print_header(header: str) -> None:
    """Print ``header`` (prefixed by 50 '=' in bold ``info`` colour) between two ``=`` rules."""
    double_dash_line()
    banner = "".join(
        (
            text_colors["info"],
            text_colors["bold"],
            "=" * 50,
            str(header),
            text_colors["end_color"],
        )
    )
    print(banner)
    double_dash_line()
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
def print_header_minor(header: str) -> None:
    """Print ``header`` prefixed by 25 '=' characters in bold ``warning`` colour."""
    pieces = (
        text_colors["warning"],
        text_colors["bold"],
        "=" * 25,
        str(header),
        text_colors["end_color"],
    )
    print("".join(pieces))
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def disable_printing():
    """Silence ``print`` by pointing ``sys.stdout`` at the OS null device.

    Each call opens a new write handle to ``os.devnull``.
    """
    devnull_stream = open(os.devnull, "w")
    sys.stdout = devnull_stream
|
|
151
|
+
|
|
152
|
+
|
|
153
|
+
def enable_printing():
    """Restore ``sys.stdout`` to the interpreter's original stream (``sys.__stdout__``).

    If the stream being replaced is an open handle to ``os.devnull`` (the kind
    ``disable_printing`` installs), it is closed so the handle is not leaked.
    Note that this restores ``sys.__stdout__``, not whatever ``sys.stdout``
    was before printing was disabled.
    """
    replaced = sys.stdout
    sys.stdout = sys.__stdout__
    if (
        replaced is not sys.__stdout__
        and getattr(replaced, "name", None) == os.devnull
        and not getattr(replaced, "closed", True)
    ):
        replaced.close()
|