dgenerate-ultralytics-headless 8.3.137__py3-none-any.whl → 8.3.224__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/METADATA +41 -34
- dgenerate_ultralytics_headless-8.3.224.dist-info/RECORD +285 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/WHEEL +1 -1
- tests/__init__.py +7 -6
- tests/conftest.py +15 -39
- tests/test_cli.py +17 -17
- tests/test_cuda.py +17 -8
- tests/test_engine.py +36 -10
- tests/test_exports.py +98 -37
- tests/test_integrations.py +12 -15
- tests/test_python.py +126 -82
- tests/test_solutions.py +319 -135
- ultralytics/__init__.py +27 -9
- ultralytics/cfg/__init__.py +83 -87
- ultralytics/cfg/datasets/Argoverse.yaml +4 -4
- ultralytics/cfg/datasets/DOTAv1.5.yaml +2 -2
- ultralytics/cfg/datasets/DOTAv1.yaml +2 -2
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +2 -2
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +4 -5
- ultralytics/cfg/datasets/ImageNet.yaml +3 -3
- ultralytics/cfg/datasets/Objects365.yaml +24 -20
- ultralytics/cfg/datasets/SKU-110K.yaml +9 -9
- ultralytics/cfg/datasets/VOC.yaml +10 -13
- ultralytics/cfg/datasets/VisDrone.yaml +43 -33
- ultralytics/cfg/datasets/african-wildlife.yaml +5 -5
- ultralytics/cfg/datasets/brain-tumor.yaml +4 -5
- ultralytics/cfg/datasets/carparts-seg.yaml +5 -5
- ultralytics/cfg/datasets/coco-pose.yaml +26 -4
- ultralytics/cfg/datasets/coco.yaml +4 -4
- ultralytics/cfg/datasets/coco128-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco128.yaml +2 -2
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/coco8-pose.yaml +23 -2
- ultralytics/cfg/datasets/coco8-seg.yaml +2 -2
- ultralytics/cfg/datasets/coco8.yaml +2 -2
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +5 -5
- ultralytics/cfg/datasets/dog-pose.yaml +32 -4
- ultralytics/cfg/datasets/dota8-multispectral.yaml +2 -2
- ultralytics/cfg/datasets/dota8.yaml +2 -2
- ultralytics/cfg/datasets/hand-keypoints.yaml +29 -4
- ultralytics/cfg/datasets/lvis.yaml +9 -9
- ultralytics/cfg/datasets/medical-pills.yaml +4 -5
- ultralytics/cfg/datasets/open-images-v7.yaml +7 -10
- ultralytics/cfg/datasets/package-seg.yaml +5 -5
- ultralytics/cfg/datasets/signature.yaml +4 -4
- ultralytics/cfg/datasets/tiger-pose.yaml +20 -4
- ultralytics/cfg/datasets/xView.yaml +5 -5
- ultralytics/cfg/default.yaml +96 -93
- ultralytics/cfg/trackers/botsort.yaml +16 -17
- ultralytics/cfg/trackers/bytetrack.yaml +9 -11
- ultralytics/data/__init__.py +4 -4
- ultralytics/data/annotator.py +12 -12
- ultralytics/data/augment.py +531 -564
- ultralytics/data/base.py +76 -81
- ultralytics/data/build.py +206 -42
- ultralytics/data/converter.py +179 -78
- ultralytics/data/dataset.py +121 -121
- ultralytics/data/loaders.py +114 -91
- ultralytics/data/split.py +28 -15
- ultralytics/data/split_dota.py +67 -48
- ultralytics/data/utils.py +110 -89
- ultralytics/engine/exporter.py +422 -460
- ultralytics/engine/model.py +224 -252
- ultralytics/engine/predictor.py +94 -89
- ultralytics/engine/results.py +345 -595
- ultralytics/engine/trainer.py +231 -134
- ultralytics/engine/tuner.py +279 -73
- ultralytics/engine/validator.py +53 -46
- ultralytics/hub/__init__.py +26 -28
- ultralytics/hub/auth.py +30 -16
- ultralytics/hub/google/__init__.py +34 -36
- ultralytics/hub/session.py +53 -77
- ultralytics/hub/utils.py +23 -109
- ultralytics/models/__init__.py +1 -1
- ultralytics/models/fastsam/__init__.py +1 -1
- ultralytics/models/fastsam/model.py +36 -18
- ultralytics/models/fastsam/predict.py +33 -44
- ultralytics/models/fastsam/utils.py +4 -5
- ultralytics/models/fastsam/val.py +12 -14
- ultralytics/models/nas/__init__.py +1 -1
- ultralytics/models/nas/model.py +16 -20
- ultralytics/models/nas/predict.py +12 -14
- ultralytics/models/nas/val.py +4 -5
- ultralytics/models/rtdetr/__init__.py +1 -1
- ultralytics/models/rtdetr/model.py +9 -9
- ultralytics/models/rtdetr/predict.py +22 -17
- ultralytics/models/rtdetr/train.py +20 -16
- ultralytics/models/rtdetr/val.py +79 -59
- ultralytics/models/sam/__init__.py +8 -2
- ultralytics/models/sam/amg.py +53 -38
- ultralytics/models/sam/build.py +29 -31
- ultralytics/models/sam/model.py +33 -38
- ultralytics/models/sam/modules/blocks.py +159 -182
- ultralytics/models/sam/modules/decoders.py +38 -47
- ultralytics/models/sam/modules/encoders.py +114 -133
- ultralytics/models/sam/modules/memory_attention.py +38 -31
- ultralytics/models/sam/modules/sam.py +114 -93
- ultralytics/models/sam/modules/tiny_encoder.py +268 -291
- ultralytics/models/sam/modules/transformer.py +59 -66
- ultralytics/models/sam/modules/utils.py +55 -72
- ultralytics/models/sam/predict.py +745 -341
- ultralytics/models/utils/loss.py +118 -107
- ultralytics/models/utils/ops.py +118 -71
- ultralytics/models/yolo/__init__.py +1 -1
- ultralytics/models/yolo/classify/predict.py +28 -26
- ultralytics/models/yolo/classify/train.py +50 -81
- ultralytics/models/yolo/classify/val.py +68 -61
- ultralytics/models/yolo/detect/predict.py +12 -15
- ultralytics/models/yolo/detect/train.py +56 -46
- ultralytics/models/yolo/detect/val.py +279 -223
- ultralytics/models/yolo/model.py +167 -86
- ultralytics/models/yolo/obb/predict.py +7 -11
- ultralytics/models/yolo/obb/train.py +23 -25
- ultralytics/models/yolo/obb/val.py +107 -99
- ultralytics/models/yolo/pose/__init__.py +1 -1
- ultralytics/models/yolo/pose/predict.py +12 -14
- ultralytics/models/yolo/pose/train.py +31 -69
- ultralytics/models/yolo/pose/val.py +119 -254
- ultralytics/models/yolo/segment/predict.py +21 -25
- ultralytics/models/yolo/segment/train.py +12 -66
- ultralytics/models/yolo/segment/val.py +126 -305
- ultralytics/models/yolo/world/train.py +53 -45
- ultralytics/models/yolo/world/train_world.py +51 -32
- ultralytics/models/yolo/yoloe/__init__.py +7 -7
- ultralytics/models/yolo/yoloe/predict.py +30 -37
- ultralytics/models/yolo/yoloe/train.py +89 -71
- ultralytics/models/yolo/yoloe/train_seg.py +15 -17
- ultralytics/models/yolo/yoloe/val.py +56 -41
- ultralytics/nn/__init__.py +9 -11
- ultralytics/nn/autobackend.py +179 -107
- ultralytics/nn/modules/__init__.py +67 -67
- ultralytics/nn/modules/activation.py +8 -7
- ultralytics/nn/modules/block.py +302 -323
- ultralytics/nn/modules/conv.py +61 -104
- ultralytics/nn/modules/head.py +488 -186
- ultralytics/nn/modules/transformer.py +183 -123
- ultralytics/nn/modules/utils.py +15 -20
- ultralytics/nn/tasks.py +327 -203
- ultralytics/nn/text_model.py +81 -65
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +12 -12
- ultralytics/solutions/ai_gym.py +19 -27
- ultralytics/solutions/analytics.py +36 -26
- ultralytics/solutions/config.py +29 -28
- ultralytics/solutions/distance_calculation.py +23 -24
- ultralytics/solutions/heatmap.py +17 -19
- ultralytics/solutions/instance_segmentation.py +21 -19
- ultralytics/solutions/object_blurrer.py +16 -17
- ultralytics/solutions/object_counter.py +48 -53
- ultralytics/solutions/object_cropper.py +22 -16
- ultralytics/solutions/parking_management.py +61 -58
- ultralytics/solutions/queue_management.py +19 -19
- ultralytics/solutions/region_counter.py +63 -50
- ultralytics/solutions/security_alarm.py +22 -25
- ultralytics/solutions/similarity_search.py +107 -60
- ultralytics/solutions/solutions.py +343 -262
- ultralytics/solutions/speed_estimation.py +35 -31
- ultralytics/solutions/streamlit_inference.py +104 -40
- ultralytics/solutions/templates/similarity-search.html +31 -24
- ultralytics/solutions/trackzone.py +24 -24
- ultralytics/solutions/vision_eye.py +11 -12
- ultralytics/trackers/__init__.py +1 -1
- ultralytics/trackers/basetrack.py +18 -27
- ultralytics/trackers/bot_sort.py +48 -39
- ultralytics/trackers/byte_tracker.py +94 -94
- ultralytics/trackers/track.py +7 -16
- ultralytics/trackers/utils/gmc.py +37 -69
- ultralytics/trackers/utils/kalman_filter.py +68 -76
- ultralytics/trackers/utils/matching.py +13 -17
- ultralytics/utils/__init__.py +251 -275
- ultralytics/utils/autobatch.py +19 -7
- ultralytics/utils/autodevice.py +68 -38
- ultralytics/utils/benchmarks.py +169 -130
- ultralytics/utils/callbacks/base.py +12 -13
- ultralytics/utils/callbacks/clearml.py +14 -15
- ultralytics/utils/callbacks/comet.py +139 -66
- ultralytics/utils/callbacks/dvc.py +19 -27
- ultralytics/utils/callbacks/hub.py +8 -6
- ultralytics/utils/callbacks/mlflow.py +6 -10
- ultralytics/utils/callbacks/neptune.py +11 -19
- ultralytics/utils/callbacks/platform.py +73 -0
- ultralytics/utils/callbacks/raytune.py +3 -4
- ultralytics/utils/callbacks/tensorboard.py +9 -12
- ultralytics/utils/callbacks/wb.py +33 -30
- ultralytics/utils/checks.py +163 -114
- ultralytics/utils/cpu.py +89 -0
- ultralytics/utils/dist.py +24 -20
- ultralytics/utils/downloads.py +176 -146
- ultralytics/utils/errors.py +11 -13
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/{export.py → export/engine.py} +81 -63
- ultralytics/utils/export/imx.py +294 -0
- ultralytics/utils/export/tensorflow.py +217 -0
- ultralytics/utils/files.py +33 -36
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +105 -120
- ultralytics/utils/logger.py +404 -0
- ultralytics/utils/loss.py +99 -61
- ultralytics/utils/metrics.py +649 -478
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +263 -451
- ultralytics/utils/patches.py +70 -31
- ultralytics/utils/plotting.py +253 -223
- ultralytics/utils/tal.py +48 -61
- ultralytics/utils/torch_utils.py +244 -251
- ultralytics/utils/tqdm.py +438 -0
- ultralytics/utils/triton.py +22 -23
- ultralytics/utils/tuner.py +11 -10
- dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD +0 -272
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/entry_points.txt +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/licenses/LICENSE +0 -0
- {dgenerate_ultralytics_headless-8.3.137.dist-info → dgenerate_ultralytics_headless-8.3.224.dist-info}/top_level.txt +0 -0
ultralytics/nn/text_model.py
CHANGED
|
@@ -1,10 +1,13 @@
|
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
|
2
2
|
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
3
5
|
from abc import abstractmethod
|
|
4
6
|
from pathlib import Path
|
|
5
7
|
|
|
6
8
|
import torch
|
|
7
9
|
import torch.nn as nn
|
|
10
|
+
from PIL import Image
|
|
8
11
|
|
|
9
12
|
from ultralytics.utils import checks
|
|
10
13
|
from ultralytics.utils.torch_utils import smart_inference_mode
|
|
@@ -17,15 +20,14 @@ except ImportError:
|
|
|
17
20
|
|
|
18
21
|
|
|
19
22
|
class TextModel(nn.Module):
|
|
20
|
-
"""
|
|
21
|
-
Abstract base class for text encoding models.
|
|
23
|
+
"""Abstract base class for text encoding models.
|
|
22
24
|
|
|
23
25
|
This class defines the interface for text encoding models used in vision-language tasks. Subclasses must implement
|
|
24
|
-
the tokenize and encode_text methods.
|
|
26
|
+
the tokenize and encode_text methods to provide text tokenization and encoding functionality.
|
|
25
27
|
|
|
26
28
|
Methods:
|
|
27
|
-
tokenize: Convert input texts to tokens.
|
|
28
|
-
encode_text: Encode tokenized texts into feature vectors.
|
|
29
|
+
tokenize: Convert input texts to tokens for model processing.
|
|
30
|
+
encode_text: Encode tokenized texts into normalized feature vectors.
|
|
29
31
|
"""
|
|
30
32
|
|
|
31
33
|
def __init__(self):
|
|
@@ -33,22 +35,21 @@ class TextModel(nn.Module):
|
|
|
33
35
|
super().__init__()
|
|
34
36
|
|
|
35
37
|
@abstractmethod
|
|
36
|
-
def tokenize(texts):
|
|
38
|
+
def tokenize(self, texts):
|
|
37
39
|
"""Convert input texts to tokens for model processing."""
|
|
38
40
|
pass
|
|
39
41
|
|
|
40
42
|
@abstractmethod
|
|
41
|
-
def encode_text(texts, dtype):
|
|
43
|
+
def encode_text(self, texts, dtype):
|
|
42
44
|
"""Encode tokenized texts into normalized feature vectors."""
|
|
43
45
|
pass
|
|
44
46
|
|
|
45
47
|
|
|
46
48
|
class CLIP(TextModel):
|
|
47
|
-
"""
|
|
48
|
-
Implements OpenAI's CLIP (Contrastive Language-Image Pre-training) text encoder.
|
|
49
|
+
"""Implements OpenAI's CLIP (Contrastive Language-Image Pre-training) text encoder.
|
|
49
50
|
|
|
50
|
-
This class provides a text encoder based on OpenAI's CLIP model, which can convert text into feature vectors
|
|
51
|
-
|
|
51
|
+
This class provides a text encoder based on OpenAI's CLIP model, which can convert text into feature vectors that
|
|
52
|
+
are aligned with corresponding image features in a shared embedding space.
|
|
52
53
|
|
|
53
54
|
Attributes:
|
|
54
55
|
model (clip.model.CLIP): The loaded CLIP model.
|
|
@@ -59,7 +60,6 @@ class CLIP(TextModel):
|
|
|
59
60
|
encode_text: Encode tokenized texts into normalized feature vectors.
|
|
60
61
|
|
|
61
62
|
Examples:
|
|
62
|
-
>>> from ultralytics.models.sam import CLIP
|
|
63
63
|
>>> import torch
|
|
64
64
|
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
65
65
|
>>> clip_model = CLIP(size="ViT-B/32", device=device)
|
|
@@ -68,12 +68,11 @@ class CLIP(TextModel):
|
|
|
68
68
|
>>> print(text_features.shape)
|
|
69
69
|
"""
|
|
70
70
|
|
|
71
|
-
def __init__(self, size, device):
|
|
72
|
-
"""
|
|
73
|
-
Initialize the CLIP text encoder.
|
|
71
|
+
def __init__(self, size: str, device: torch.device) -> None:
|
|
72
|
+
"""Initialize the CLIP text encoder.
|
|
74
73
|
|
|
75
|
-
This class implements the TextModel interface using OpenAI's CLIP model for text encoding. It loads
|
|
76
|
-
|
|
74
|
+
This class implements the TextModel interface using OpenAI's CLIP model for text encoding. It loads a
|
|
75
|
+
pre-trained CLIP model of the specified size and prepares it for text encoding tasks.
|
|
77
76
|
|
|
78
77
|
Args:
|
|
79
78
|
size (str): Model size identifier (e.g., 'ViT-B/32').
|
|
@@ -81,22 +80,20 @@ class CLIP(TextModel):
|
|
|
81
80
|
|
|
82
81
|
Examples:
|
|
83
82
|
>>> import torch
|
|
84
|
-
>>> from ultralytics.models.sam.modules.clip import CLIP
|
|
85
83
|
>>> clip_model = CLIP("ViT-B/32", device=torch.device("cuda:0"))
|
|
86
84
|
>>> text_features = clip_model.encode_text(["a photo of a cat", "a photo of a dog"])
|
|
87
85
|
"""
|
|
88
86
|
super().__init__()
|
|
89
|
-
self.model = clip.load(size, device=device)
|
|
87
|
+
self.model, self.image_preprocess = clip.load(size, device=device)
|
|
90
88
|
self.to(device)
|
|
91
89
|
self.device = device
|
|
92
90
|
self.eval()
|
|
93
91
|
|
|
94
|
-
def tokenize(self, texts):
|
|
95
|
-
"""
|
|
96
|
-
Convert input texts to CLIP tokens.
|
|
92
|
+
def tokenize(self, texts: str | list[str]) -> torch.Tensor:
|
|
93
|
+
"""Convert input texts to CLIP tokens.
|
|
97
94
|
|
|
98
95
|
Args:
|
|
99
|
-
texts (str |
|
|
96
|
+
texts (str | list[str]): Input text or list of texts to tokenize.
|
|
100
97
|
|
|
101
98
|
Returns:
|
|
102
99
|
(torch.Tensor): Tokenized text tensor with shape (batch_size, context_length) ready for model processing.
|
|
@@ -109,16 +106,15 @@ class CLIP(TextModel):
|
|
|
109
106
|
return clip.tokenize(texts).to(self.device)
|
|
110
107
|
|
|
111
108
|
@smart_inference_mode()
|
|
112
|
-
def encode_text(self, texts, dtype=torch.float32):
|
|
113
|
-
"""
|
|
114
|
-
Encode tokenized texts into normalized feature vectors.
|
|
109
|
+
def encode_text(self, texts: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
|
|
110
|
+
"""Encode tokenized texts into normalized feature vectors.
|
|
115
111
|
|
|
116
112
|
This method processes tokenized text inputs through the CLIP model to generate feature vectors, which are then
|
|
117
113
|
normalized to unit length. These normalized vectors can be used for text-image similarity comparisons.
|
|
118
114
|
|
|
119
115
|
Args:
|
|
120
116
|
texts (torch.Tensor): Tokenized text inputs, typically created using the tokenize() method.
|
|
121
|
-
dtype (torch.dtype, optional): Data type for output features.
|
|
117
|
+
dtype (torch.dtype, optional): Data type for output features.
|
|
122
118
|
|
|
123
119
|
Returns:
|
|
124
120
|
(torch.Tensor): Normalized text feature vectors with unit length (L2 norm = 1).
|
|
@@ -134,13 +130,43 @@ class CLIP(TextModel):
|
|
|
134
130
|
txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
|
|
135
131
|
return txt_feats
|
|
136
132
|
|
|
133
|
+
@smart_inference_mode()
|
|
134
|
+
def encode_image(self, image: Image.Image | torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
|
|
135
|
+
"""Encode preprocessed images into normalized feature vectors.
|
|
136
|
+
|
|
137
|
+
This method processes preprocessed image inputs through the CLIP model to generate feature vectors, which are
|
|
138
|
+
then normalized to unit length. These normalized vectors can be used for text-image similarity comparisons.
|
|
139
|
+
|
|
140
|
+
Args:
|
|
141
|
+
image (PIL.Image | torch.Tensor): Preprocessed image input. If a PIL Image is provided, it will be converted
|
|
142
|
+
to a tensor using the model's image preprocessing function.
|
|
143
|
+
dtype (torch.dtype, optional): Data type for output features.
|
|
144
|
+
|
|
145
|
+
Returns:
|
|
146
|
+
(torch.Tensor): Normalized image feature vectors with unit length (L2 norm = 1).
|
|
147
|
+
|
|
148
|
+
Examples:
|
|
149
|
+
>>> from ultralytics.nn.text_model import CLIP
|
|
150
|
+
>>> from PIL import Image
|
|
151
|
+
>>> clip_model = CLIP("ViT-B/32", device="cuda")
|
|
152
|
+
>>> image = Image.open("path/to/image.jpg")
|
|
153
|
+
>>> image_tensor = clip_model.image_preprocess(image).unsqueeze(0).to("cuda")
|
|
154
|
+
>>> features = clip_model.encode_image(image_tensor)
|
|
155
|
+
>>> features.shape
|
|
156
|
+
torch.Size([1, 512])
|
|
157
|
+
"""
|
|
158
|
+
if isinstance(image, Image.Image):
|
|
159
|
+
image = self.image_preprocess(image).unsqueeze(0).to(self.device)
|
|
160
|
+
img_feats = self.model.encode_image(image).to(dtype)
|
|
161
|
+
img_feats = img_feats / img_feats.norm(p=2, dim=-1, keepdim=True)
|
|
162
|
+
return img_feats
|
|
163
|
+
|
|
137
164
|
|
|
138
165
|
class MobileCLIP(TextModel):
|
|
139
|
-
"""
|
|
140
|
-
Implement Apple's MobileCLIP text encoder for efficient text encoding.
|
|
166
|
+
"""Implement Apple's MobileCLIP text encoder for efficient text encoding.
|
|
141
167
|
|
|
142
168
|
This class implements the TextModel interface using Apple's MobileCLIP model, providing efficient text encoding
|
|
143
|
-
capabilities for vision-language tasks.
|
|
169
|
+
capabilities for vision-language tasks with reduced computational requirements compared to standard CLIP models.
|
|
144
170
|
|
|
145
171
|
Attributes:
|
|
146
172
|
model (mobileclip.model.MobileCLIP): The loaded MobileCLIP model.
|
|
@@ -161,9 +187,8 @@ class MobileCLIP(TextModel):
|
|
|
161
187
|
|
|
162
188
|
config_size_map = {"s0": "s0", "s1": "s1", "s2": "s2", "b": "b", "blt": "b"}
|
|
163
189
|
|
|
164
|
-
def __init__(self, size, device):
|
|
165
|
-
"""
|
|
166
|
-
Initialize the MobileCLIP text encoder.
|
|
190
|
+
def __init__(self, size: str, device: torch.device) -> None:
|
|
191
|
+
"""Initialize the MobileCLIP text encoder.
|
|
167
192
|
|
|
168
193
|
This class implements the TextModel interface using Apple's MobileCLIP model for efficient text encoding.
|
|
169
194
|
|
|
@@ -172,7 +197,6 @@ class MobileCLIP(TextModel):
|
|
|
172
197
|
device (torch.device): Device to load the model on.
|
|
173
198
|
|
|
174
199
|
Examples:
|
|
175
|
-
>>> from ultralytics.nn.modules import MobileCLIP
|
|
176
200
|
>>> import torch
|
|
177
201
|
>>> model = MobileCLIP("s0", device=torch.device("cpu"))
|
|
178
202
|
>>> tokens = model.tokenize(["a photo of a cat", "a photo of a dog"])
|
|
@@ -203,9 +227,8 @@ class MobileCLIP(TextModel):
|
|
|
203
227
|
self.device = device
|
|
204
228
|
self.eval()
|
|
205
229
|
|
|
206
|
-
def tokenize(self, texts):
|
|
207
|
-
"""
|
|
208
|
-
Convert input texts to MobileCLIP tokens.
|
|
230
|
+
def tokenize(self, texts: list[str]) -> torch.Tensor:
|
|
231
|
+
"""Convert input texts to MobileCLIP tokens.
|
|
209
232
|
|
|
210
233
|
Args:
|
|
211
234
|
texts (list[str]): List of text strings to tokenize.
|
|
@@ -220,9 +243,8 @@ class MobileCLIP(TextModel):
|
|
|
220
243
|
return self.tokenizer(texts).to(self.device)
|
|
221
244
|
|
|
222
245
|
@smart_inference_mode()
|
|
223
|
-
def encode_text(self, texts, dtype=torch.float32):
|
|
224
|
-
"""
|
|
225
|
-
Encode tokenized texts into normalized feature vectors.
|
|
246
|
+
def encode_text(self, texts: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
|
|
247
|
+
"""Encode tokenized texts into normalized feature vectors.
|
|
226
248
|
|
|
227
249
|
Args:
|
|
228
250
|
texts (torch.Tensor): Tokenized text inputs.
|
|
@@ -244,14 +266,13 @@ class MobileCLIP(TextModel):
|
|
|
244
266
|
|
|
245
267
|
|
|
246
268
|
class MobileCLIPTS(TextModel):
|
|
247
|
-
"""
|
|
248
|
-
Load a TorchScript traced version of MobileCLIP.
|
|
269
|
+
"""Load a TorchScript traced version of MobileCLIP.
|
|
249
270
|
|
|
250
|
-
This class implements the TextModel interface using Apple's MobileCLIP model
|
|
251
|
-
capabilities for vision-language tasks.
|
|
271
|
+
This class implements the TextModel interface using Apple's MobileCLIP model in TorchScript format, providing
|
|
272
|
+
efficient text encoding capabilities for vision-language tasks with optimized inference performance.
|
|
252
273
|
|
|
253
274
|
Attributes:
|
|
254
|
-
encoder (
|
|
275
|
+
encoder (torch.jit.ScriptModule): The loaded TorchScript MobileCLIP text encoder.
|
|
255
276
|
tokenizer (callable): Tokenizer function for processing text inputs.
|
|
256
277
|
device (torch.device): Device where the model is loaded.
|
|
257
278
|
|
|
@@ -261,24 +282,22 @@ class MobileCLIPTS(TextModel):
|
|
|
261
282
|
|
|
262
283
|
Examples:
|
|
263
284
|
>>> device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
264
|
-
>>> text_encoder =
|
|
285
|
+
>>> text_encoder = MobileCLIPTS(device=device)
|
|
265
286
|
>>> tokens = text_encoder.tokenize(["a photo of a cat", "a photo of a dog"])
|
|
266
287
|
>>> features = text_encoder.encode_text(tokens)
|
|
267
288
|
"""
|
|
268
289
|
|
|
269
|
-
def __init__(self, device):
|
|
270
|
-
"""
|
|
271
|
-
Initialize the MobileCLIP text encoder.
|
|
290
|
+
def __init__(self, device: torch.device):
|
|
291
|
+
"""Initialize the MobileCLIP TorchScript text encoder.
|
|
272
292
|
|
|
273
|
-
This class implements the TextModel interface using Apple's MobileCLIP model for efficient
|
|
293
|
+
This class implements the TextModel interface using Apple's MobileCLIP model in TorchScript format for efficient
|
|
294
|
+
text encoding with optimized inference performance.
|
|
274
295
|
|
|
275
296
|
Args:
|
|
276
297
|
device (torch.device): Device to load the model on.
|
|
277
298
|
|
|
278
299
|
Examples:
|
|
279
|
-
>>>
|
|
280
|
-
>>> import torch
|
|
281
|
-
>>> model = MobileCLIP(device=torch.device("cpu"))
|
|
300
|
+
>>> model = MobileCLIPTS(device=torch.device("cpu"))
|
|
282
301
|
>>> tokens = model.tokenize(["a photo of a cat", "a photo of a dog"])
|
|
283
302
|
>>> features = model.encode_text(tokens)
|
|
284
303
|
"""
|
|
@@ -289,9 +308,8 @@ class MobileCLIPTS(TextModel):
|
|
|
289
308
|
self.tokenizer = clip.clip.tokenize
|
|
290
309
|
self.device = device
|
|
291
310
|
|
|
292
|
-
def tokenize(self, texts):
|
|
293
|
-
"""
|
|
294
|
-
Convert input texts to MobileCLIP tokens.
|
|
311
|
+
def tokenize(self, texts: list[str]) -> torch.Tensor:
|
|
312
|
+
"""Convert input texts to MobileCLIP tokens.
|
|
295
313
|
|
|
296
314
|
Args:
|
|
297
315
|
texts (list[str]): List of text strings to tokenize.
|
|
@@ -300,15 +318,14 @@ class MobileCLIPTS(TextModel):
|
|
|
300
318
|
(torch.Tensor): Tokenized text inputs with shape (batch_size, sequence_length).
|
|
301
319
|
|
|
302
320
|
Examples:
|
|
303
|
-
>>> model =
|
|
321
|
+
>>> model = MobileCLIPTS("cpu")
|
|
304
322
|
>>> tokens = model.tokenize(["a photo of a cat", "a photo of a dog"])
|
|
305
323
|
"""
|
|
306
324
|
return self.tokenizer(texts).to(self.device)
|
|
307
325
|
|
|
308
326
|
@smart_inference_mode()
|
|
309
|
-
def encode_text(self, texts, dtype=torch.float32):
|
|
310
|
-
"""
|
|
311
|
-
Encode tokenized texts into normalized feature vectors.
|
|
327
|
+
def encode_text(self, texts: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
|
|
328
|
+
"""Encode tokenized texts into normalized feature vectors.
|
|
312
329
|
|
|
313
330
|
Args:
|
|
314
331
|
texts (torch.Tensor): Tokenized text inputs.
|
|
@@ -318,19 +335,18 @@ class MobileCLIPTS(TextModel):
|
|
|
318
335
|
(torch.Tensor): Normalized text feature vectors with L2 normalization applied.
|
|
319
336
|
|
|
320
337
|
Examples:
|
|
321
|
-
>>> model =
|
|
338
|
+
>>> model = MobileCLIPTS(device="cpu")
|
|
322
339
|
>>> tokens = model.tokenize(["a photo of a cat", "a photo of a dog"])
|
|
323
340
|
>>> features = model.encode_text(tokens)
|
|
324
341
|
>>> features.shape
|
|
325
342
|
torch.Size([2, 512]) # Actual dimension depends on model size
|
|
326
343
|
"""
|
|
327
344
|
# NOTE: no need to do normalization here as it's embedded in the torchscript model
|
|
328
|
-
return self.encoder(texts)
|
|
345
|
+
return self.encoder(texts).to(dtype)
|
|
329
346
|
|
|
330
347
|
|
|
331
|
-
def build_text_model(variant, device=None):
|
|
332
|
-
"""
|
|
333
|
-
Build a text encoding model based on the specified variant.
|
|
348
|
+
def build_text_model(variant: str, device: torch.device = None) -> TextModel:
|
|
349
|
+
"""Build a text encoding model based on the specified variant.
|
|
334
350
|
|
|
335
351
|
Args:
|
|
336
352
|
variant (str): Model variant in format "base:size" (e.g., "clip:ViT-B/32" or "mobileclip:s0").
|
ultralytics/py.typed
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
partial
|
ultralytics/solutions/__init__.py
CHANGED
|
@@ -19,23 +19,23 @@ from .trackzone import TrackZone
|
|
|
19
19
|
from .vision_eye import VisionEye
|
|
20
20
|
|
|
21
21
|
__all__ = (
|
|
22
|
-
"ObjectCounter",
|
|
23
|
-
"ObjectCropper",
|
|
24
|
-
"ObjectBlurrer",
|
|
25
22
|
"AIGym",
|
|
26
|
-
"
|
|
27
|
-
"
|
|
23
|
+
"Analytics",
|
|
24
|
+
"DistanceCalculation",
|
|
28
25
|
"Heatmap",
|
|
26
|
+
"Inference",
|
|
29
27
|
"InstanceSegmentation",
|
|
30
|
-
"
|
|
31
|
-
"
|
|
32
|
-
"
|
|
33
|
-
"QueueManager",
|
|
28
|
+
"ObjectBlurrer",
|
|
29
|
+
"ObjectCounter",
|
|
30
|
+
"ObjectCropper",
|
|
34
31
|
"ParkingManagement",
|
|
35
32
|
"ParkingPtsSelection",
|
|
36
|
-
"
|
|
37
|
-
"
|
|
38
|
-
"TrackZone",
|
|
33
|
+
"QueueManager",
|
|
34
|
+
"RegionCounter",
|
|
39
35
|
"SearchApp",
|
|
36
|
+
"SecurityAlarm",
|
|
37
|
+
"SpeedEstimator",
|
|
38
|
+
"TrackZone",
|
|
39
|
+
"VisionEye",
|
|
40
40
|
"VisualAISearch",
|
|
41
41
|
)
|
ultralytics/solutions/ai_gym.py
CHANGED
|
@@ -1,25 +1,25 @@
|
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
|
2
2
|
|
|
3
3
|
from collections import defaultdict
|
|
4
|
+
from typing import Any
|
|
4
5
|
|
|
5
6
|
from ultralytics.solutions.solutions import BaseSolution, SolutionAnnotator, SolutionResults
|
|
6
7
|
|
|
7
8
|
|
|
8
9
|
class AIGym(BaseSolution):
|
|
9
|
-
"""
|
|
10
|
-
A class to manage gym steps of people in a real-time video stream based on their poses.
|
|
10
|
+
"""A class to manage gym steps of people in a real-time video stream based on their poses.
|
|
11
11
|
|
|
12
12
|
This class extends BaseSolution to monitor workouts using YOLO pose estimation models. It tracks and counts
|
|
13
13
|
repetitions of exercises based on predefined angle thresholds for up and down positions.
|
|
14
14
|
|
|
15
15
|
Attributes:
|
|
16
|
-
states (
|
|
16
|
+
states (dict[float, int, str]): Stores per-track angle, count, and stage for workout monitoring.
|
|
17
17
|
up_angle (float): Angle threshold for considering the 'up' position of an exercise.
|
|
18
18
|
down_angle (float): Angle threshold for considering the 'down' position of an exercise.
|
|
19
|
-
kpts (
|
|
19
|
+
kpts (list[int]): Indices of keypoints used for angle calculation.
|
|
20
20
|
|
|
21
21
|
Methods:
|
|
22
|
-
process:
|
|
22
|
+
process: Process a frame to detect poses, calculate angles, and count repetitions.
|
|
23
23
|
|
|
24
24
|
Examples:
|
|
25
25
|
>>> gym = AIGym(model="yolo11n-pose.pt")
|
|
@@ -30,13 +30,12 @@ class AIGym(BaseSolution):
|
|
|
30
30
|
>>> cv2.waitKey(0)
|
|
31
31
|
"""
|
|
32
32
|
|
|
33
|
-
def __init__(self, **kwargs):
|
|
34
|
-
"""
|
|
35
|
-
Initialize AIGym for workout monitoring using pose estimation and predefined angles.
|
|
33
|
+
def __init__(self, **kwargs: Any) -> None:
|
|
34
|
+
"""Initialize AIGym for workout monitoring using pose estimation and predefined angles.
|
|
36
35
|
|
|
37
36
|
Args:
|
|
38
|
-
**kwargs (Any): Keyword arguments passed to the parent class constructor
|
|
39
|
-
model (str): Model name or path, defaults to "yolo11n-pose.pt".
|
|
37
|
+
**kwargs (Any): Keyword arguments passed to the parent class constructor including:
|
|
38
|
+
- model (str): Model name or path, defaults to "yolo11n-pose.pt".
|
|
40
39
|
"""
|
|
41
40
|
kwargs["model"] = kwargs.get("model", "yolo11n-pose.pt")
|
|
42
41
|
super().__init__(**kwargs)
|
|
@@ -47,23 +46,19 @@ class AIGym(BaseSolution):
|
|
|
47
46
|
self.down_angle = float(self.CFG["down_angle"]) # Pose down predefined angle to consider down pose
|
|
48
47
|
self.kpts = self.CFG["kpts"] # User selected kpts of workouts storage for further usage
|
|
49
48
|
|
|
50
|
-
def process(self, im0):
|
|
51
|
-
"""
|
|
52
|
-
Monitor workouts using Ultralytics YOLO Pose Model.
|
|
49
|
+
def process(self, im0) -> SolutionResults:
|
|
50
|
+
"""Monitor workouts using Ultralytics YOLO Pose Model.
|
|
53
51
|
|
|
54
|
-
This function processes an input image to track and analyze human poses for workout monitoring. It uses
|
|
55
|
-
|
|
56
|
-
angle thresholds.
|
|
52
|
+
This function processes an input image to track and analyze human poses for workout monitoring. It uses the YOLO
|
|
53
|
+
Pose model to detect keypoints, estimate angles, and count repetitions based on predefined angle thresholds.
|
|
57
54
|
|
|
58
55
|
Args:
|
|
59
56
|
im0 (np.ndarray): Input image for processing.
|
|
60
57
|
|
|
61
58
|
Returns:
|
|
62
|
-
(SolutionResults): Contains processed image `plot_im`,
|
|
63
|
-
'
|
|
64
|
-
|
|
65
|
-
'workout_angle' (list of angles), and
|
|
66
|
-
'total_tracks' (total number of tracked individuals).
|
|
59
|
+
(SolutionResults): Contains processed image `plot_im`, 'workout_count' (list of completed reps),
|
|
60
|
+
'workout_stage' (list of current stages), 'workout_angle' (list of angles), and 'total_tracks' (total
|
|
61
|
+
number of tracked individuals).
|
|
67
62
|
|
|
68
63
|
Examples:
|
|
69
64
|
>>> gym = AIGym()
|
|
@@ -74,15 +69,12 @@ class AIGym(BaseSolution):
|
|
|
74
69
|
annotator = SolutionAnnotator(im0, line_width=self.line_width) # Initialize annotator
|
|
75
70
|
|
|
76
71
|
self.extract_tracks(im0) # Extract tracks (bounding boxes, classes, and masks)
|
|
77
|
-
tracks = self.tracks[0]
|
|
78
72
|
|
|
79
|
-
if
|
|
80
|
-
|
|
81
|
-
kpt_data = tracks.keypoints.data.cpu() # Avoid repeated .cpu() calls
|
|
73
|
+
if len(self.boxes):
|
|
74
|
+
kpt_data = self.tracks.keypoints.data
|
|
82
75
|
|
|
83
76
|
for i, k in enumerate(kpt_data):
|
|
84
|
-
|
|
85
|
-
state = self.states[track_id] # get state details
|
|
77
|
+
state = self.states[self.track_ids[i]] # get state details
|
|
86
78
|
# Get keypoints and estimate the angle
|
|
87
79
|
state["angle"] = annotator.estimate_pose_angle(*[k[int(idx)] for idx in self.kpts])
|
|
88
80
|
annotator.draw_specific_kpts(k, self.kpts, radius=self.line_width * 3)
|
ultralytics/solutions/analytics.py
CHANGED
|
@@ -1,6 +1,9 @@
|
|
|
1
1
|
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
|
|
2
2
|
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
3
5
|
from itertools import cycle
|
|
6
|
+
from typing import Any
|
|
4
7
|
|
|
5
8
|
import cv2
|
|
6
9
|
import numpy as np
|
|
@@ -9,11 +12,10 @@ from ultralytics.solutions.solutions import BaseSolution, SolutionResults # Imp
|
|
|
9
12
|
|
|
10
13
|
|
|
11
14
|
class Analytics(BaseSolution):
|
|
12
|
-
"""
|
|
13
|
-
A class for creating and updating various types of charts for visual analytics.
|
|
15
|
+
"""A class for creating and updating various types of charts for visual analytics.
|
|
14
16
|
|
|
15
|
-
This class extends BaseSolution to provide functionality for generating line, bar, pie, and area charts
|
|
16
|
-
|
|
17
|
+
This class extends BaseSolution to provide functionality for generating line, bar, pie, and area charts based on
|
|
18
|
+
object detection and tracking data.
|
|
17
19
|
|
|
18
20
|
Attributes:
|
|
19
21
|
type (str): The type of analytics chart to generate ('line', 'bar', 'pie', or 'area').
|
|
@@ -26,12 +28,12 @@ class Analytics(BaseSolution):
|
|
|
26
28
|
fontsize (int): Font size for text display.
|
|
27
29
|
color_cycle (cycle): Cyclic iterator for chart colors.
|
|
28
30
|
total_counts (int): Total count of detected objects (used for line charts).
|
|
29
|
-
clswise_count (
|
|
31
|
+
clswise_count (dict[str, int]): Dictionary for class-wise object counts.
|
|
30
32
|
fig (Figure): Matplotlib figure object for the chart.
|
|
31
33
|
ax (Axes): Matplotlib axes object for the chart.
|
|
32
34
|
canvas (FigureCanvasAgg): Canvas for rendering the chart.
|
|
33
35
|
lines (dict): Dictionary to store line objects for area charts.
|
|
34
|
-
color_mapping (
|
|
36
|
+
color_mapping (dict[str, str]): Dictionary mapping class labels to colors for consistent visualization.
|
|
35
37
|
|
|
36
38
|
Methods:
|
|
37
39
|
process: Process image data and update the chart.
|
|
@@ -44,7 +46,7 @@ class Analytics(BaseSolution):
|
|
|
44
46
|
>>> cv2.imshow("Analytics", results.plot_im)
|
|
45
47
|
"""
|
|
46
48
|
|
|
47
|
-
def __init__(self, **kwargs):
|
|
49
|
+
def __init__(self, **kwargs: Any) -> None:
|
|
48
50
|
"""Initialize Analytics class with various chart types for visual data representation."""
|
|
49
51
|
super().__init__(**kwargs)
|
|
50
52
|
|
|
@@ -67,6 +69,8 @@ class Analytics(BaseSolution):
|
|
|
67
69
|
|
|
68
70
|
self.total_counts = 0 # count variable for storing total counts i.e. for line
|
|
69
71
|
self.clswise_count = {} # dictionary for class-wise counts
|
|
72
|
+
self.update_every = kwargs.get("update_every", 30) # Only update graph every 30 frames by default
|
|
73
|
+
self.last_plot_im = None # Cache of the last rendered chart
|
|
70
74
|
|
|
71
75
|
# Ensure line and area chart
|
|
72
76
|
if self.type in {"line", "area"}:
|
|
@@ -86,9 +90,8 @@ class Analytics(BaseSolution):
|
|
|
86
90
|
if self.type == "pie": # Ensure pie chart is circular
|
|
87
91
|
self.ax.axis("equal")
|
|
88
92
|
|
|
89
|
-
def process(self, im0, frame_number):
|
|
90
|
-
"""
|
|
91
|
-
Process image data and run object tracking to update analytics charts.
|
|
93
|
+
def process(self, im0: np.ndarray, frame_number: int) -> SolutionResults:
|
|
94
|
+
"""Process image data and run object tracking to update analytics charts.
|
|
92
95
|
|
|
93
96
|
Args:
|
|
94
97
|
im0 (np.ndarray): Input image for processing.
|
|
@@ -110,29 +113,35 @@ class Analytics(BaseSolution):
|
|
|
110
113
|
if self.type == "line":
|
|
111
114
|
for _ in self.boxes:
|
|
112
115
|
self.total_counts += 1
|
|
113
|
-
|
|
116
|
+
update_required = frame_number % self.update_every == 0 or self.last_plot_im is None
|
|
117
|
+
if update_required:
|
|
118
|
+
self.last_plot_im = self.update_graph(frame_number=frame_number)
|
|
119
|
+
plot_im = self.last_plot_im
|
|
114
120
|
self.total_counts = 0
|
|
115
121
|
elif self.type in {"pie", "bar", "area"}:
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
122
|
+
from collections import Counter
|
|
123
|
+
|
|
124
|
+
self.clswise_count = Counter(self.names[int(cls)] for cls in self.clss)
|
|
125
|
+
update_required = frame_number % self.update_every == 0 or self.last_plot_im is None
|
|
126
|
+
if update_required:
|
|
127
|
+
self.last_plot_im = self.update_graph(
|
|
128
|
+
frame_number=frame_number, count_dict=self.clswise_count, plot=self.type
|
|
129
|
+
)
|
|
130
|
+
plot_im = self.last_plot_im
|
|
123
131
|
else:
|
|
124
132
|
raise ModuleNotFoundError(f"{self.type} chart is not supported ❌")
|
|
125
133
|
|
|
126
134
|
# return output dictionary with summary for more usage
|
|
127
135
|
return SolutionResults(plot_im=plot_im, total_tracks=len(self.track_ids), classwise_count=self.clswise_count)
|
|
128
136
|
|
|
129
|
-
def update_graph(
|
|
130
|
-
""
|
|
131
|
-
|
|
137
|
+
def update_graph(
|
|
138
|
+
self, frame_number: int, count_dict: dict[str, int] | None = None, plot: str = "line"
|
|
139
|
+
) -> np.ndarray:
|
|
140
|
+
"""Update the graph with new data for single or multiple classes.
|
|
132
141
|
|
|
133
142
|
Args:
|
|
134
143
|
frame_number (int): The current frame number.
|
|
135
|
-
count_dict (
|
|
144
|
+
count_dict (dict[str, int], optional): Dictionary with class names as keys and counts as values for multiple
|
|
136
145
|
classes. If None, updates a single line graph.
|
|
137
146
|
plot (str): Type of the plot. Options are 'line', 'bar', 'pie', or 'area'.
|
|
138
147
|
|
|
@@ -184,7 +193,7 @@ class Analytics(BaseSolution):
|
|
|
184
193
|
self.ax.clear()
|
|
185
194
|
for key, y_data in y_data_dict.items():
|
|
186
195
|
color = next(color_cycle)
|
|
187
|
-
self.ax.fill_between(x_data, y_data, color=color, alpha=0.
|
|
196
|
+
self.ax.fill_between(x_data, y_data, color=color, alpha=0.55)
|
|
188
197
|
self.ax.plot(
|
|
189
198
|
x_data,
|
|
190
199
|
y_data,
|
|
@@ -194,7 +203,7 @@ class Analytics(BaseSolution):
|
|
|
194
203
|
markersize=self.line_width * 5,
|
|
195
204
|
label=f"{key} Data Points",
|
|
196
205
|
)
|
|
197
|
-
|
|
206
|
+
elif plot == "bar":
|
|
198
207
|
self.ax.clear() # clear bar data
|
|
199
208
|
for label in labels: # Map labels to colors
|
|
200
209
|
if label not in self.color_mapping:
|
|
@@ -214,12 +223,12 @@ class Analytics(BaseSolution):
|
|
|
214
223
|
for bar, label in zip(bars, labels):
|
|
215
224
|
bar.set_label(label) # Assign label to each bar
|
|
216
225
|
self.ax.legend(loc="upper left", fontsize=13, facecolor=self.fg_color, edgecolor=self.fg_color)
|
|
217
|
-
|
|
226
|
+
elif plot == "pie":
|
|
218
227
|
total = sum(counts)
|
|
219
228
|
percentages = [size / total * 100 for size in counts]
|
|
220
|
-
start_angle = 90
|
|
221
229
|
self.ax.clear()
|
|
222
230
|
|
|
231
|
+
start_angle = 90
|
|
223
232
|
# Create pie chart and create legend labels with percentages
|
|
224
233
|
wedges, _ = self.ax.pie(
|
|
225
234
|
counts, labels=labels, startangle=start_angle, textprops={"color": self.fg_color}, autopct=None
|
|
@@ -232,6 +241,7 @@ class Analytics(BaseSolution):
|
|
|
232
241
|
|
|
233
242
|
# Common plot settings
|
|
234
243
|
self.ax.set_facecolor("#f0f0f0") # Set to light gray or any other color you like
|
|
244
|
+
self.ax.grid(True, linestyle="--", linewidth=0.5, alpha=0.5) # Display grid for more data insights
|
|
235
245
|
self.ax.set_title(self.title, color=self.fg_color, fontsize=self.fontsize)
|
|
236
246
|
self.ax.set_xlabel(self.x_label, color=self.fg_color, fontsize=self.fontsize - 3)
|
|
237
247
|
self.ax.set_ylabel(self.y_label, color=self.fg_color, fontsize=self.fontsize - 3)
|