dgenerate-ultralytics-headless 8.3.135__py3-none-any.whl → 8.3.137__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: dgenerate-ultralytics-headless
- Version: 8.3.135
+ Version: 8.3.137
  Summary: Automatically built Ultralytics package with python-opencv-headless dependency instead of python-opencv
  Author-email: Glenn Jocher <glenn.jocher@ultralytics.com>, Jing Qiu <jing.qiu@ultralytics.com>
  Maintainer-email: Ultralytics <hello@ultralytics.com>
@@ -46,7 +46,6 @@ Requires-Dist: tqdm>=4.64.0
  Requires-Dist: psutil
  Requires-Dist: py-cpuinfo
  Requires-Dist: pandas>=1.1.4
- Requires-Dist: seaborn>=0.11.0
  Requires-Dist: ultralytics-thop>=2.0.0
  Provides-Extra: dev
  Requires-Dist: ipython; extra == "dev"
@@ -1,17 +1,17 @@
- dgenerate_ultralytics_headless-8.3.135.dist-info/licenses/LICENSE,sha256=DZak_2itbUtvHzD3E7GNUYSRK6jdOJ-GqncQ2weavLA,34523
+ dgenerate_ultralytics_headless-8.3.137.dist-info/licenses/LICENSE,sha256=DZak_2itbUtvHzD3E7GNUYSRK6jdOJ-GqncQ2weavLA,34523
  tests/__init__.py,sha256=xnMhv3O_DF1YrW4zk__ZywQzAaoTDjPKPoiI1Ktss1w,670
  tests/conftest.py,sha256=rsIAipRKfrVNoTaJ1LdpYue8AbcJ_fr3d3WIlM_6uXY,2982
  tests/test_cli.py,sha256=vXUC_EK0fa87JRhHsCOZf7AJQ5_Jm1sL8u-yhmsaQh0,5851
- tests/test_cuda.py,sha256=eKwaqLxWTRRYNROnkH24Ch-HmxTRKQLSIxbMYFYq_p0,8123
+ tests/test_cuda.py,sha256=L_2xp2TH-pInsdI8UrbZ5onRtHQGdUVoPXnyX6Ot4_U,7950
  tests/test_engine.py,sha256=aGqZ8P7QO5C_nOa1b4FOyk92Ysdk5WiP-ST310Vyxys,4962
- tests/test_exports.py,sha256=UeeBloqYYGZNh520R3CR80XBxA9XFrNmbK9An6V6C4w,9838
+ tests/test_exports.py,sha256=dhZn86LdbapW15RthQF870LGxDjC1MUZhlGdBgPmgIQ,9716
  tests/test_integrations.py,sha256=dQteeRsRVuT_p5-T88-7jqT65Zm9iAXkyKg-KQ1_TQ8,6341
  tests/test_python.py,sha256=KWsncKpeDdRmjRftmJpsMl7bBLI3TG_I7Lb4kuemZzQ,25618
  tests/test_solutions.py,sha256=IFlqyOUCvGbLe_YZqWmNCe_afg4as0p-SfAv3j7VURI,6205
- ultralytics/__init__.py,sha256=7IMXy8Z7sekeQRLOVZyuYbA-1kse0gieArFyUxQ9dyE,730
+ ultralytics/__init__.py,sha256=8hzZtbr1IMQwOTdqbcNED-RHZiqww--zXivCgQOzujQ,730
  ultralytics/assets/bus.jpg,sha256=wCAZxJecGR63Od3ZRERe9Aja1Weayrb9Ug751DS_vGM,137419
  ultralytics/assets/zidane.jpg,sha256=Ftc4aeMmen1O0A3o6GCDO9FlfBslLpTAw0gnetx7bts,50427
- ultralytics/cfg/__init__.py,sha256=p1dKUDoVsnjJG8Qj5Q-eukb0WH2IoWV3BcJpEmKu2tE,39487
+ ultralytics/cfg/__init__.py,sha256=h0UVCvX6DIpoR4_pthpZD_Ihq7eCaS8HbXsPOm82G0E,39540
  ultralytics/cfg/default.yaml,sha256=oFG6llJO-Py5H-cR9qs-7FieJamroDLwpbrkhmfROOM,8307
  ultralytics/cfg/datasets/Argoverse.yaml,sha256=_xlEDIJ9XkUo0v_iNL7FW079BoSeZtKSuLteKTtGbA8,3275
  ultralytics/cfg/datasets/DOTAv1.5.yaml,sha256=SHND_CFkojxw5iQD5Mcgju2kCZIl0gW2ajuzv1cqoL0,1224
@@ -119,7 +119,7 @@ ultralytics/data/scripts/get_coco.sh,sha256=UuJpJeo3qQpTHVINeOpmP0NYmg8PhEFE3A8J
  ultralytics/data/scripts/get_coco128.sh,sha256=qmRQl_hOKrsdHrTrnyQuFIH01oDz3lfaz138OgGfLt8,650
  ultralytics/data/scripts/get_imagenet.sh,sha256=hr42H16bM47iT27rgS7MpEo-GeOZAYUQXgr0B2cwn48,1705
  ultralytics/engine/__init__.py,sha256=lm6MckFYCPTbqIoX7w0s_daxdjNeBeKW6DXppv1-QUM,70
- ultralytics/engine/exporter.py,sha256=tXqZlcOZnDqtK7A0nwago7FfDAb3ftnYui-VeOFzVs0,70823
+ ultralytics/engine/exporter.py,sha256=JucFVR_RAfzrRWM9kJK6MHALEbdzrf93ReTnAhiRTBo,70823
  ultralytics/engine/model.py,sha256=fWhPNWUQzjjWfTEXzTaqSSearV4THRkEa_fl4dDvzWw,52930
  ultralytics/engine/predictor.py,sha256=AwKpOGY2G-thNNiRw4Kf_MBLamq5tbRhXLNSMRArqFo,21803
  ultralytics/engine/results.py,sha256=MhbyMCwgslmtV53fqii4UJUaLQ4gKTKdkXi7vvmJDAE,79628
@@ -186,17 +186,17 @@ ultralytics/models/yolo/segment/predict.py,sha256=mIC3aHI7Jg4dU1k2UZnjVj4unE-5TW
  ultralytics/models/yolo/segment/train.py,sha256=EIyIAjYp127Mb-DomyjPORaONu57OY_gOTK9p2MwW6E,5359
  ultralytics/models/yolo/segment/val.py,sha256=cXJM1JNuzDraU0SJQRIdzNxabd0bfcxiRE8wozHZChY,18415
  ultralytics/models/yolo/world/__init__.py,sha256=nlh8I6t8hMGz_vZg8QSlsUW1R-2eKvn9CGUoPPQEGhA,131
- ultralytics/models/yolo/world/train.py,sha256=HUJ0XiJIGx_FA9kqNYnSFsaKWMiZUDxgkpfGoBH6UNc,4896
- ultralytics/models/yolo/world/train_world.py,sha256=DSa-t9jDbtwF43SJlvtESh1Ux7M77zo9f945eR2D-5w,8363
+ ultralytics/models/yolo/world/train.py,sha256=4e54RghcrpdtpxG3n2Nicwo-tcj-wI4nLcUo8_4cf30,6898
+ ultralytics/models/yolo/world/train_world.py,sha256=fFhhI-toaEy1_-XcPM1_mF395WRQ26gZ4UxqyUAZmWw,8461
  ultralytics/models/yolo/yoloe/__init__.py,sha256=6SLytdJtwu37qewf7CobG7C7Wl1m-xtNdvCXEasfPDE,760
  ultralytics/models/yolo/yoloe/predict.py,sha256=N0oYcr_mdw8wyUAWprAwJhrA0r23BaTeYXEjw2e8_mI,6993
- ultralytics/models/yolo/yoloe/train.py,sha256=St3zw_XWRol9pODWU4lvKlJnWYr1lmWQNuhLFwWMge4,12989
+ ultralytics/models/yolo/yoloe/train.py,sha256=xRPDJ3nUWxtqjESfmUtsZslVhpgzrZRw8z_QU5hV6nc,11710
  ultralytics/models/yolo/yoloe/train_seg.py,sha256=BYFBd04k5WQaJPcFbCvVIbEf2IOQyW8_sGeoVT_74j0,4632
  ultralytics/models/yolo/yoloe/val.py,sha256=oA8cVT3pBXF6aPZy7ITq0mDcktRuIgks8tTtqMRISyY,8431
  ultralytics/nn/__init__.py,sha256=rjociYD9lo_K-d-1s6TbdWklPLjTcEHk7OIlRDJstIE,615
  ultralytics/nn/autobackend.py,sha256=X2cxCytBu9fmniy8uJ5aZb28IukQ-uxV1INXeS1lclA,39368
- ultralytics/nn/tasks.py,sha256=o7QZvlZyvmECxkITJjtDCPf-hAxXcZOLXP7PKtegOPQ,63594
- ultralytics/nn/text_model.py,sha256=8_7SRejKZA4Pi-ha0gjcWrQDDCDMBhtwlg8pPMWgjDE,13145
+ ultralytics/nn/tasks.py,sha256=iJWpwRr4yZg1dTT-9jXuzIqkdFmbZm1b7hejnO-CiZk,64337
+ ultralytics/nn/text_model.py,sha256=wr5yPRbMqtSr2N5Rzdd0vuv9PcQe8qw4uO596ZHZVGU,13236
  ultralytics/nn/modules/__init__.py,sha256=dXLtIk9rt944WfsTdpgEdWOg3HQEHdwQztuZ6WNJygs,3144
  ultralytics/nn/modules/activation.py,sha256=PvXZkA9AzEntR575JkFORdmtcRwATyy0lje-uHA5_8w,2210
  ultralytics/nn/modules/block.py,sha256=yd6Ao9T2UJNAWc8oB1-CSxyF6-exqbFcN3hTWUZNU3M,66701
@@ -238,7 +238,7 @@ ultralytics/utils/__init__.py,sha256=vac0M-Hx55QXl6Vod3QPjnLBlt87Hwxu1784RXPmeQA
  ultralytics/utils/autobatch.py,sha256=kg05q2qKg74y_Uq2vvr01i3KhLfpVR7sT0IXBt3_kyI,4921
  ultralytics/utils/autodevice.py,sha256=OKZfTbswg6SlsYGCGMqROkA-451CXGG47oeyC5Q1kFM,7232
  ultralytics/utils/benchmarks.py,sha256=lDNNnLeLUzmqKrqrqlCOiau-q7A-gcLooZP2dbxCu-U,30214
- ultralytics/utils/checks.py,sha256=L5G8CiQo8v2842KLGOaLG5y_AYRoa5gxCdtTt48LnS0,33129
+ ultralytics/utils/checks.py,sha256=TGhnnNVT3NEBhSeckWIe1rGlXUyYI3xhFqK6CR0oBiE,33192
  ultralytics/utils/dist.py,sha256=aytW0JEkcA5ZTZucV92ot7Bn-apiej8aLk3QNWicjAc,4103
  ultralytics/utils/downloads.py,sha256=Rn8xDwn2bzgBqiYz3Xn0rm3MWjk4T-QUd2Ajlu1EpQ4,22312
  ultralytics/utils/errors.py,sha256=vY9h2evFSrHnZdHJVVrmm8Zzw4qVDLyo9DeYW5g0dFk,1573
@@ -246,10 +246,10 @@ ultralytics/utils/export.py,sha256=XInnl9AQeik7EuR1492nzDvgDqaV43FlnM5CLamrgd4,8
  ultralytics/utils/files.py,sha256=0K4O1cgqRiXaDw7EQK13TqA5SME_RrvfDVQSPetNr5w,8042
  ultralytics/utils/instance.py,sha256=UOEsXR9V-bXNRk6BTonASBEgeMqvzzAk4S7VdXZJUAM,18090
  ultralytics/utils/loss.py,sha256=Woc_rj7ptCyezHdylEygXMeSEgivYu_B9jJHD4UwxWE,37607
- ultralytics/utils/metrics.py,sha256=pWNq-66VqkMjj05Gqkm8ddoElDK72q_U9cl8y-aEN6k,53963
+ ultralytics/utils/metrics.py,sha256=n8guPEADBMRNpeXNShEX-fxVv9xck8S4QaOIiaW_kl0,56037
  ultralytics/utils/ops.py,sha256=YFwPrKlPcgEmgAWqnJVR0Ccx5NQgp5e3P-YYHwVSP0k,34779
  ultralytics/utils/patches.py,sha256=_dhIU_eDklQE-aWIjpyjPHl_wOwZoGuIUQnXgdSwk_A,5020
- ultralytics/utils/plotting.py,sha256=m9Hsbt6U073jAiztX6clpd9KzznW62oHxCWlBcm0T-s,46920
+ ultralytics/utils/plotting.py,sha256=GKic2OMavjJPT3pOPdU0UcvQTrG1LVt0vHJM-Zuy9Bs,47217
  ultralytics/utils/tal.py,sha256=P5nPoR9qNnFuDIda0fsn8WP6m1V8r7EbvXUuhNRFFTA,20805
  ultralytics/utils/torch_utils.py,sha256=2SJxxg8Qr0YqOoQ-8qAYn6VrzZdQMObqiw3CJZ-rAY0,39611
  ultralytics/utils/triton.py,sha256=xK9Db_ZUVDnIK1u76S2G-6ulIBsLfj9HN_YOaSrnMuU,5304
@@ -265,8 +265,8 @@ ultralytics/utils/callbacks/neptune.py,sha256=yYUgEgSv6L39sSev6vjwhAWU3DlPDsbSDV
  ultralytics/utils/callbacks/raytune.py,sha256=A8amUGpux7dYES-L1iSeMoMXBySGWCD1aUqT7vcG-pU,1284
  ultralytics/utils/callbacks/tensorboard.py,sha256=jgYnym3cUQFAgN1GzTyO7l3jINtfAh8zhrllDvnLuVQ,5339
  ultralytics/utils/callbacks/wb.py,sha256=iDRFXI4IIDm8R5OI89DMTmjs8aHLo1HRCLkOFKdaMG4,7507
- dgenerate_ultralytics_headless-8.3.135.dist-info/METADATA,sha256=8_HSDModHJ24S-bmagyx903_49yFcEJGC63r5nON6g4,38327
- dgenerate_ultralytics_headless-8.3.135.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
- dgenerate_ultralytics_headless-8.3.135.dist-info/entry_points.txt,sha256=YM_wiKyTe9yRrsEfqvYolNO5ngwfoL4-NwgKzc8_7sI,93
- dgenerate_ultralytics_headless-8.3.135.dist-info/top_level.txt,sha256=XP49TwiMw4QGsvTLSYiJhz1xF_k7ev5mQ8jJXaXi45Q,12
- dgenerate_ultralytics_headless-8.3.135.dist-info/RECORD,,
+ dgenerate_ultralytics_headless-8.3.137.dist-info/METADATA,sha256=8ui4ivOJaSEgzcD9bZTlWkJ3-Q_44TdABJCrfpEeLRM,38296
+ dgenerate_ultralytics_headless-8.3.137.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+ dgenerate_ultralytics_headless-8.3.137.dist-info/entry_points.txt,sha256=YM_wiKyTe9yRrsEfqvYolNO5ngwfoL4-NwgKzc8_7sI,93
+ dgenerate_ultralytics_headless-8.3.137.dist-info/top_level.txt,sha256=XP49TwiMw4QGsvTLSYiJhz1xF_k7ev5mQ8jJXaXi45Q,12
+ dgenerate_ultralytics_headless-8.3.137.dist-info/RECORD,,
tests/test_cuda.py CHANGED
@@ -41,7 +41,7 @@ def test_amp():


  @pytest.mark.slow
- # @pytest.mark.skipif(IS_JETSON, reason="Temporary disable ONNX for Jetson")
+ @pytest.mark.skipif(IS_JETSON, reason="Temporary disable ONNX for Jetson")
  @pytest.mark.skipif(not DEVICES, reason="No CUDA devices available")
  @pytest.mark.parametrize(
  "task, dynamic, int8, half, batch, simplify, nms",
@@ -50,12 +50,7 @@ def test_amp():
  for task, dynamic, int8, half, batch, simplify, nms in product(
  TASKS, [True, False], [False], [False], [1, 2], [True, False], [True, False]
  )
- if not (
- (int8 and half)
- or (task == "classify" and nms)
- or (task == "obb" and nms and (not TORCH_1_13 or IS_JETSON)) # obb nms fails on NVIDIA Jetson
- or (simplify and dynamic) # onnxslim is slow when dynamic=True
- )
+ if not ((int8 and half) or (task == "classify" and nms) or (task == "obb" and nms and not TORCH_1_13))
  ],
  )
  def test_export_onnx_matrix(task, dynamic, int8, half, batch, simplify, nms):
tests/test_exports.py CHANGED
@@ -83,12 +83,7 @@ def test_export_openvino_matrix(task, dynamic, int8, half, batch, nms):
  for task, dynamic, int8, half, batch, simplify, nms in product(
  TASKS, [True, False], [False], [False], [1, 2], [True, False], [True, False]
  )
- if not (
- (int8 and half)
- or (task == "classify" and nms)
- or (task == "obb" and nms and not TORCH_1_13)
- or (simplify and dynamic) # onnxslim is slow when dynamic=True
- )
+ if not ((int8 and half) or (task == "classify" and nms) or (task == "obb" and nms and not TORCH_1_13))
  ],
  )
  def test_export_onnx_matrix(task, dynamic, int8, half, batch, simplify, nms):
ultralytics/__init__.py CHANGED
@@ -1,6 +1,6 @@
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

- __version__ = "8.3.135"
+ __version__ = "8.3.137"

  import os

ultralytics/cfg/__init__.py CHANGED
@@ -7,8 +7,6 @@ from pathlib import Path
  from types import SimpleNamespace
  from typing import Any, Dict, List, Union

- import cv2
-
  from ultralytics import __version__
  from ultralytics.utils import (
  ASSETS,
@@ -707,6 +705,8 @@ def handle_yolo_solutions(args: List[str]) -> None:
  ]
  )
  else:
+ import cv2 # Only needed for cap and vw functionality
+
  from ultralytics import solutions

  solution = getattr(solutions, SOLUTION_MAP[solution_name])(is_cli=True, **overrides) # class i.e ObjectCounter
@@ -919,7 +919,7 @@ def entrypoint(debug: str = "") -> None:
  if task not in TASKS:
  if task == "track":
  LOGGER.warning(
- "invalid 'task=track', setting 'task=detect' and 'mode=track'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}."
+ f"invalid 'task=track', setting 'task=detect' and 'mode=track'. Valid tasks are {TASKS}.\n{CLI_HELP_MSG}."
  )
  task, mode = "detect", "track"
  else:
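The `import cv2` move above is a startup-time optimization: OpenCV is now loaded only inside the solutions branch that actually opens a video capture/writer, so `import ultralytics` stays light. A minimal sketch of the same deferred-import pattern (a standalone illustration, not the package's code):

    # lazy_cv2_demo.py -- illustrative only
    def read_image(path: str):
        """OpenCV is imported only when this code path runs, not at module import time."""
        import cv2  # deferred heavy dependency

        return cv2.imread(path)  # returns None if the file does not exist

    if __name__ == "__main__":
        # Importing this module never touches cv2; calling read_image() does.
        img = read_image("bus.jpg")
        print("loaded" if img is not None else "not found")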
ultralytics/engine/exporter.py CHANGED
@@ -557,7 +557,7 @@ class Exporter:
  """YOLO ONNX export."""
  requirements = ["onnx>=1.12.0,<1.18.0"]
  if self.args.simplify:
- requirements += ["onnxslim>=0.1.46", "onnxruntime" + ("-gpu" if torch.cuda.is_available() else "")]
+ requirements += ["onnxslim>=0.1.53", "onnxruntime" + ("-gpu" if torch.cuda.is_available() else "")]
  check_requirements(requirements)
  import onnx # noqa

@@ -928,7 +928,7 @@ class Exporter:
  "ai-edge-litert>=1.2.0", # required by 'onnx2tf' package
  "onnx>=1.12.0,<1.18.0",
  "onnx2tf>=1.26.3",
- "onnxslim>=0.1.46",
+ "onnxslim>=0.1.53",
  "onnxruntime-gpu" if cuda else "onnxruntime",
  "protobuf>=5",
  ),
ultralytics/models/yolo/world/train.py CHANGED
@@ -1,11 +1,14 @@
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

  import itertools
+ from pathlib import Path
+
+ import torch

  from ultralytics.data import build_yolo_dataset
- from ultralytics.models import yolo
+ from ultralytics.models.yolo.detect import DetectionTrainer
  from ultralytics.nn.tasks import WorldModel
- from ultralytics.utils import DEFAULT_CFG, RANK, checks
+ from ultralytics.utils import DEFAULT_CFG, LOGGER, RANK
  from ultralytics.utils.torch_utils import de_parallel


@@ -15,13 +18,9 @@ def on_pretrain_routine_end(trainer):
  # Set class names for evaluation
  names = [name.split("/")[0] for name in list(trainer.test_loader.dataset.data["names"].values())]
  de_parallel(trainer.ema.ema).set_classes(names, cache_clip_model=False)
- device = next(trainer.model.parameters()).device
- trainer.text_model, _ = trainer.clip.load("ViT-B/32", device=device)
- for p in trainer.text_model.parameters():
- p.requires_grad_(False)


- class WorldTrainer(yolo.detect.DetectionTrainer):
+ class WorldTrainer(DetectionTrainer):
  """
  A class to fine-tune a world model on a close-set dataset.

@@ -54,14 +53,7 @@ class WorldTrainer(yolo.detect.DetectionTrainer):
  if overrides is None:
  overrides = {}
  super().__init__(cfg, overrides, _callbacks)
-
- # Import and assign clip
- try:
- import clip
- except ImportError:
- checks.check_requirements("git+https://github.com/ultralytics/CLIP.git")
- import clip
- self.clip = clip
+ self.text_embeddings = None

  def get_model(self, cfg=None, weights=None, verbose=True):
  """
@@ -102,18 +94,72 @@ class WorldTrainer(yolo.detect.DetectionTrainer):
  (Dataset): YOLO dataset configured for training or validation.
  """
  gs = max(int(de_parallel(self.model).stride.max() if self.model else 0), 32)
- return build_yolo_dataset(
+ dataset = build_yolo_dataset(
  self.args, img_path, batch, self.data, mode=mode, rect=mode == "val", stride=gs, multi_modal=mode == "train"
  )
+ if mode == "train":
+ self.set_text_embeddings([dataset], batch) # cache text embeddings to accelerate training
+ return dataset
+
+ def set_text_embeddings(self, datasets, batch):
+ """
+ Set text embeddings for datasets to accelerate training by caching category names.
+
+ This method collects unique category names from all datasets, then generates and caches text embeddings
+ for these categories to improve training efficiency.
+
+ Args:
+ datasets (List[Dataset]): List of datasets from which to extract category names.
+ batch (int | None): Batch size used for processing.
+
+ Notes:
+ This method collects category names from datasets that have the 'category_names' attribute,
+ then uses the first dataset's image path to determine where to cache the generated text embeddings.
+ """
+ text_embeddings = {}
+ for dataset in datasets:
+ if not hasattr(dataset, "category_names"):
+ continue
+ text_embeddings.update(
+ self.generate_text_embeddings(
+ list(dataset.category_names), batch, cache_dir=Path(dataset.img_path).parent
+ )
+ )
+ self.text_embeddings = text_embeddings
+
+ def generate_text_embeddings(self, texts, batch, cache_dir):
+ """
+ Generate text embeddings for a list of text samples.
+
+ Args:
+ texts (List[str]): List of text samples to encode.
+ batch (int): Batch size for processing.
+ cache_dir (Path): Directory to save/load cached embeddings.
+
+ Returns:
+ (dict): Dictionary mapping text samples to their embeddings.
+ """
+ model = "clip:ViT-B/32"
+ cache_path = cache_dir / f"text_embeddings_{model.replace(':', '_').replace('/', '_')}.pt"
+ if cache_path.exists():
+ LOGGER.info(f"Reading existed cache from '{cache_path}'")
+ txt_map = torch.load(cache_path)
+ if sorted(txt_map.keys()) == sorted(texts):
+ return txt_map
+ LOGGER.info(f"Caching text embeddings to '{cache_path}'")
+ assert self.model is not None
+ txt_feats = self.model.get_text_pe(texts, batch, cache_clip_model=False)
+ txt_map = dict(zip(texts, txt_feats.squeeze(0)))
+ torch.save(txt_map, cache_path)
+ return txt_map

  def preprocess_batch(self, batch):
  """Preprocess a batch of images and text for YOLOWorld training."""
- batch = super().preprocess_batch(batch)
+ batch = DetectionTrainer.preprocess_batch(self, batch)

  # Add text features
  texts = list(itertools.chain(*batch["texts"]))
- text_token = self.clip.tokenize(texts).to(batch["img"].device)
- txt_feats = self.text_model.encode_text(text_token).to(dtype=batch["img"].dtype) # torch.float32
+ txt_feats = torch.stack([self.text_embeddings[text] for text in texts]).to(self.device)
  txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
  batch["txt_feats"] = txt_feats.reshape(len(batch["texts"]), -1, txt_feats.shape[-1])
  return batch
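The new `generate_text_embeddings` above computes CLIP text features once per dataset directory, saves them to a `.pt` file, and reuses the file only when its keys exactly match the requested class names. A self-contained sketch of that cache-or-compute pattern (the `encode` callable is a stand-in for the real text encoder):

    from pathlib import Path

    import torch

    def cached_text_embeddings(texts, cache_path: Path, encode):
        """Return {text: embedding}, recomputing when the cache does not cover exactly these texts."""
        if cache_path.exists():
            cached = torch.load(cache_path)
            if sorted(cached.keys()) == sorted(texts):
                return cached  # cache is valid for this class list
        feats = encode(texts)  # stand-in for a CLIP/MobileCLIP text encoder, one row per text
        mapping = dict(zip(texts, feats))
        torch.save(mapping, cache_path)
        return mapping

    emb = cached_text_embeddings(["person", "bus"], Path("text_embeddings_demo.pt"), lambda t: torch.randn(len(t), 512))
    print(len(emb), emb["person"].shape)  # 2 torch.Size([512])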
ultralytics/models/yolo/world/train_world.py CHANGED
@@ -100,6 +100,7 @@ class WorldTrainerFromScratch(WorldTrainer):
  else build_grounding(self.args, im_path["img_path"], im_path["json_file"], batch, stride=gs)
  for im_path in img_path
  ]
+ self.set_text_embeddings(datasets, batch) # cache text embeddings to accelerate training
  return YOLOConcatDataset(datasets) if len(datasets) > 1 else datasets[0]

  def get_dataset(self):
ultralytics/models/yolo/yoloe/train.py CHANGED
@@ -2,7 +2,6 @@

  import itertools
  from copy import copy, deepcopy
- from pathlib import Path

  import torch

@@ -157,40 +156,7 @@ class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
  Returns:
  (YOLOConcatDataset | Dataset): The constructed dataset for training or validation.
  """
- datasets = WorldTrainerFromScratch.build_dataset(self, img_path, mode, batch)
- if mode == "train":
- self.set_text_embeddings(
- datasets.datasets if hasattr(datasets, "datasets") else [datasets], batch
- ) # cache text embeddings to accelerate training
- return datasets
-
- def set_text_embeddings(self, datasets, batch):
- """
- Set text embeddings for datasets to accelerate training by caching category names.
-
- This method collects unique category names from all datasets, then generates and caches text embeddings
- for these categories to improve training efficiency.
-
- Args:
- datasets (List[Dataset]): List of datasets from which to extract category names.
- batch (int | None): Batch size used for processing.
-
- Notes:
- This method collects category names from datasets that have the 'category_names' attribute,
- then uses the first dataset's image path to determine where to cache the generated text embeddings.
- """
- # TODO: open up an interface to determine whether to do cache
- category_names = set()
- for dataset in datasets:
- if not hasattr(dataset, "category_names"):
- continue
- category_names |= dataset.category_names
-
- # TODO: enable to update the path or use a more general way to get the path
- img_path = datasets[0].img_path
- self.text_embeddings = self.generate_text_embeddings(
- category_names, batch, cache_path=Path(img_path).parent / "text_embeddings.pt"
- )
+ return WorldTrainerFromScratch.build_dataset(self, img_path, mode, batch)

  def preprocess_batch(self, batch):
  """Process batch for training, moving text features to the appropriate device."""
@@ -202,23 +168,28 @@ class YOLOETrainerFromScratch(YOLOETrainer, WorldTrainerFromScratch):
  batch["txt_feats"] = txt_feats
  return batch

- def generate_text_embeddings(self, texts, batch, cache_path="embeddings.pt"):
+ def generate_text_embeddings(self, texts, batch, cache_dir):
  """
  Generate text embeddings for a list of text samples.

  Args:
  texts (List[str]): List of text samples to encode.
  batch (int): Batch size for processing.
- cache_path (str | Path): Path to save/load cached embeddings.
+ cache_dir (Path): Directory to save/load cached embeddings.

  Returns:
  (dict): Dictionary mapping text samples to their embeddings.
  """
+ model = "mobileclip:blt"
+ cache_path = cache_dir / f"text_embeddings_{model.replace(':', '_').replace('/', '_')}.pt"
  if cache_path.exists():
  LOGGER.info(f"Reading existed cache from '{cache_path}'")
- return torch.load(cache_path)
+ txt_map = torch.load(cache_path)
+ if sorted(txt_map.keys()) == sorted(texts):
+ return txt_map
+ LOGGER.info(f"Caching text embeddings to '{cache_path}'")
  assert self.model is not None
- txt_feats = self.model.get_text_pe(texts, batch, without_reprta=True)
+ txt_feats = self.model.get_text_pe(texts, batch, without_reprta=True, cache_clip_model=False)
  txt_map = dict(zip(texts, txt_feats.squeeze(0)))
  torch.save(txt_map, cache_path)
  return txt_map
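Both trainers now name the cache file after the text-model identifier ("clip:ViT-B/32" for YOLO-World, "mobileclip:blt" for YOLOE), so embeddings from different encoders never collide in the same directory. A small sketch mirroring the path expression used in the hunks above:

    from pathlib import Path

    def embedding_cache_path(cache_dir: Path, model: str) -> Path:
        """Map a text-model id to a per-model cache file, as in the diffs above."""
        return cache_dir / f"text_embeddings_{model.replace(':', '_').replace('/', '_')}.pt"

    print(embedding_cache_path(Path("coco"), "mobileclip:blt"))  # coco/text_embeddings_mobileclip_blt.pt
    print(embedding_cache_path(Path("coco"), "clip:ViT-B/32"))   # coco/text_embeddings_clip_ViT-B_32.pt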
ultralytics/nn/tasks.py CHANGED
@@ -146,6 +146,8 @@ class BaseModel(torch.nn.Module):
  (torch.Tensor): The last output of the model.
  """
  y, dt, embeddings = [], [], [] # outputs
+ embed = frozenset(embed) if embed is not None else {-1}
+ max_idx = max(embed)
  for m in self.model:
  if m.f != -1: # if not from previous layer
  x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
@@ -155,9 +157,9 @@
  y.append(x if m.i in self.save else None) # save output
  if visualize:
  feature_visualization(x, m.type, m.i, save_dir=visualize)
- if embed and m.i in embed:
+ if m.i in embed:
  embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten
- if m.i == max(embed):
+ if m.i == max_idx:
  return torch.unbind(torch.cat(embeddings, 1), dim=0)
  return x

@@ -677,6 +679,8 @@ class RTDETRDetectionModel(DetectionModel):
  (torch.Tensor): Model's output tensor.
  """
  y, dt, embeddings = [], [], [] # outputs
+ embed = frozenset(embed) if embed is not None else {-1}
+ max_idx = max(embed)
  for m in self.model[:-1]: # except the head part
  if m.f != -1: # if not from previous layer
  x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
@@ -686,9 +690,9 @@
  y.append(x if m.i in self.save else None) # save output
  if visualize:
  feature_visualization(x, m.type, m.i, save_dir=visualize)
- if embed and m.i in embed:
+ if m.i in embed:
  embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten
- if m.i == max(embed):
+ if m.i == max_idx:
  return torch.unbind(torch.cat(embeddings, 1), dim=0)
  head = self.model[-1]
  x = head([y[j] for j in head.f], batch) # head inference
@@ -721,24 +725,33 @@ class WorldModel(DetectionModel):
  batch (int): Batch size for processing text tokens.
  cache_clip_model (bool): Whether to cache the CLIP model.
  """
- try:
- import clip
- except ImportError:
- check_requirements("git+https://github.com/ultralytics/CLIP.git")
- import clip
-
- if (
- not getattr(self, "clip_model", None) and cache_clip_model
- ): # for backwards compatibility of models lacking clip_model attribute
- self.clip_model = clip.load("ViT-B/32")[0]
- model = self.clip_model if cache_clip_model else clip.load("ViT-B/32")[0]
- device = next(model.parameters()).device
- text_token = clip.tokenize(text).to(device)
+ self.txt_feats = self.get_text_pe(text, batch=batch, cache_clip_model=cache_clip_model)
+ self.model[-1].nc = len(text)
+
+ @smart_inference_mode()
+ def get_text_pe(self, text, batch=80, cache_clip_model=True):
+ """
+ Set classes in advance so that model could do offline-inference without clip model.
+
+ Args:
+ text (List[str]): List of class names.
+ batch (int): Batch size for processing text tokens.
+ cache_clip_model (bool): Whether to cache the CLIP model.
+
+ Returns:
+ (torch.Tensor): Text positional embeddings.
+ """
+ from ultralytics.nn.text_model import build_text_model
+
+ device = next(self.model.parameters()).device
+ if not getattr(self, "clip_model", None) and cache_clip_model:
+ # For backwards compatibility of models lacking clip_model attribute
+ self.clip_model = build_text_model("clip:ViT-B/32", device=device)
+ model = self.clip_model if cache_clip_model else build_text_model("clip:ViT-B/32", device=device)
+ text_token = model.tokenize(text)
  txt_feats = [model.encode_text(token).detach() for token in text_token.split(batch)]
  txt_feats = txt_feats[0] if len(txt_feats) == 1 else torch.cat(txt_feats, dim=0)
- txt_feats = txt_feats / txt_feats.norm(p=2, dim=-1, keepdim=True)
- self.txt_feats = txt_feats.reshape(-1, len(text), txt_feats.shape[-1])
- self.model[-1].nc = len(text)
+ return txt_feats.reshape(-1, len(text), txt_feats.shape[-1])

  def predict(self, x, profile=False, visualize=False, txt_feats=None, augment=False, embed=None):
  """
@@ -760,6 +773,8 @@ class WorldModel(DetectionModel):
  txt_feats = txt_feats.expand(x.shape[0], -1, -1)
  ori_txt_feats = txt_feats.clone()
  y, dt, embeddings = [], [], [] # outputs
+ embed = frozenset(embed) if embed is not None else {-1}
+ max_idx = max(embed)
  for m in self.model: # except the head part
  if m.f != -1: # if not from previous layer
  x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
@@ -777,9 +792,9 @@
  y.append(x if m.i in self.save else None) # save output
  if visualize:
  feature_visualization(x, m.type, m.i, save_dir=visualize)
- if embed and m.i in embed:
+ if m.i in embed:
  embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten
- if m.i == max(embed):
+ if m.i == max_idx:
  return torch.unbind(torch.cat(embeddings, 1), dim=0)
  return x

@@ -976,6 +991,8 @@ class YOLOEModel(DetectionModel):
  """
  y, dt, embeddings = [], [], [] # outputs
  b = x.shape[0]
+ embed = frozenset(embed) if embed is not None else {-1}
+ max_idx = max(embed)
  for m in self.model: # except the head part
  if m.f != -1: # if not from previous layer
  x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f] # from earlier layers
@@ -997,9 +1014,9 @@
  y.append(x if m.i in self.save else None) # save output
  if visualize:
  feature_visualization(x, m.type, m.i, save_dir=visualize)
- if embed and m.i in embed:
+ if m.i in embed:
  embeddings.append(torch.nn.functional.adaptive_avg_pool2d(x, (1, 1)).squeeze(-1).squeeze(-1)) # flatten
- if m.i == max(embed):
+ if m.i == max_idx:
  return torch.unbind(torch.cat(embeddings, 1), dim=0)
  return x

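Each of the prediction loops above now normalizes `embed` to a frozenset up front (defaulting to `{-1}` when no embedding layers are requested) and hoists `max(embed)` out of the per-layer loop, so membership checks are O(1) and the early-exit index is computed once. A simplified, self-contained sketch of that control flow, with plain integers standing in for layer outputs:

    def collect_embeddings(layer_outputs, embed=None):
        """Sketch of the loop: frozenset for O(1) lookups, max() computed once rather than per layer."""
        embed = frozenset(embed) if embed is not None else {-1}
        max_idx = max(embed)
        embeddings = []
        for i, x in enumerate(layer_outputs):
            if i in embed:
                embeddings.append(x)
            if i == max_idx:
                return embeddings  # early exit once the deepest requested layer is reached
        return layer_outputs[-1]  # no embed layers requested: behave like a normal forward pass

    print(collect_embeddings(list(range(10)), embed=[2, 5]))  # [2, 5]
    print(collect_embeddings(list(range(10))))  # 9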
ultralytics/nn/text_model.py CHANGED
@@ -324,6 +324,7 @@ class MobileCLIPTS(TextModel):
  >>> features.shape
  torch.Size([2, 512]) # Actual dimension depends on model size
  """
+ # NOTE: no need to do normalization here as it's embedded in the torchscript model
  return self.encoder(texts)


ultralytics/utils/checks.py CHANGED
@@ -80,6 +80,7 @@ def parse_requirements(file_path=ROOT.parent / "requirements.txt", package=""):
  return requirements


+ @functools.lru_cache
  def parse_version(version="0.0.0") -> tuple:
  """
  Convert a version string to a tuple of integers, ignoring any extra non-numeric string attached to the version.
@@ -164,6 +165,7 @@ def check_imgsz(imgsz, stride=32, min_dim=1, max_dim=2, floor=0):
  return sz


+ @functools.lru_cache
  def check_version(
  current: str = "0.0.0",
  required: str = "0.0.0",
@@ -580,6 +582,7 @@ def check_is_path_safe(basedir, path):
  return path_resolved.exists() and path_resolved.parts[: len(base_dir_resolved.parts)] == base_dir_resolved.parts


+ @functools.lru_cache
  def check_imshow(warn=False):
  """
  Check if environment supports image displays.
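`parse_version`, `check_version`, and `check_imshow` are now memoized with `functools.lru_cache`, which pays off because the same version strings are checked repeatedly during export and requirement validation. A minimal sketch of the effect (the parsing logic here is a simplified stand-in, not the package's exact implementation):

    import functools
    import re

    @functools.lru_cache
    def parse_version(version: str = "0.0.0") -> tuple:
        """Memoized: each distinct version string is parsed only once."""
        return tuple(int(x) for x in re.findall(r"\d+", version)[:3])

    for _ in range(1000):
        parse_version("2.1.0+cu121")  # only the first call does any parsing work
    print(parse_version.cache_info())  # e.g. CacheInfo(hits=999, misses=1, maxsize=128, currsize=1)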
ultralytics/utils/metrics.py CHANGED
@@ -409,7 +409,7 @@ class ConfusionMatrix:
  @plt_settings()
  def plot(self, normalize=True, save_dir="", names=(), on_plot=None):
  """
- Plot the confusion matrix using seaborn and save it to a file.
+ Plot the confusion matrix using matplotlib and save it to a file.

  Args:
  normalize (bool): Whether to normalize the confusion matrix.
@@ -418,34 +418,63 @@
  on_plot (func): An optional callback to pass plots path and data when they are rendered.
  """
  import matplotlib.pyplot as plt # scope for faster 'import ultralytics'
- import seaborn

  array = self.matrix / ((self.matrix.sum(0).reshape(1, -1) + 1e-9) if normalize else 1) # normalize columns
  array[array < 0.005] = np.nan # don't annotate (would appear as 0.00)

- fig, ax = plt.subplots(1, 1, figsize=(12, 9), tight_layout=True)
- nc, nn = self.nc, len(names) # number of classes, names
- seaborn.set_theme(font_scale=1.0 if nc < 50 else 0.8) # for label size
- labels = (0 < nn < 99) and (nn == nc) # apply names to ticklabels
- ticklabels = (list(names) + ["background"]) if labels else "auto"
+ names = list(names)
+ fig, ax = plt.subplots(1, 1, figsize=(12, 9))
+ if self.nc >= 100: # downsample for large class count
+ k = max(2, self.nc // 60) # step size for downsampling, always > 1
+ keep_idx = slice(None, None, k) # create slice instead of array
+ names = names[keep_idx] # slice class names
+ array = array[keep_idx, :][:, keep_idx] # slice matrix rows and cols
+ n = (self.nc + k - 1) // k # number of retained classes
+ nc = nn = n if self.task == "classify" else n + 1 # adjust for background if needed
+ else:
+ nc = nn = self.nc if self.task == "classify" else self.nc + 1
+ ticklabels = (names + ["background"]) if (0 < nn < 99) and (nn == nc) else "auto"
+ xy_ticks = np.arange(len(ticklabels))
+ tick_fontsize = max(6, 15 - 0.1 * nc) # Minimum size is 6
+ label_fontsize = max(6, 12 - 0.1 * nc)
+ title_fontsize = max(6, 12 - 0.1 * nc)
+ btm = max(0.1, 0.25 - 0.001 * nc) # Minimum value is 0.1
  with warnings.catch_warnings():
  warnings.simplefilter("ignore") # suppress empty matrix RuntimeWarning: All-NaN slice encountered
- seaborn.heatmap(
- array,
- ax=ax,
- annot=nc < 30,
- annot_kws={"size": 8},
- cmap="Blues",
- fmt=".2f" if normalize else ".0f",
- square=True,
- vmin=0.0,
- xticklabels=ticklabels,
- yticklabels=ticklabels,
- ).set_facecolor((1, 1, 1))
+ im = ax.imshow(array, cmap="Blues", vmin=0.0, interpolation="none")
+ ax.xaxis.set_label_position("bottom")
+ if nc < 30: # Add score for each cell of confusion matrix
+ for i, row in enumerate(array[:nc]):
+ for j, val in enumerate(row[:nc]):
+ val = array[i, j]
+ if np.isnan(val):
+ continue
+ ax.text(
+ j,
+ i,
+ f"{val:.2f}" if normalize else f"{int(val)}",
+ ha="center",
+ va="center",
+ fontsize=10,
+ color="white" if val > (0.7 if normalize else 2) else "black",
+ )
+ cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.05)
  title = "Confusion Matrix" + " Normalized" * normalize
- ax.set_xlabel("True")
- ax.set_ylabel("Predicted")
- ax.set_title(title)
+ ax.set_xlabel("True", fontsize=label_fontsize, labelpad=10)
+ ax.set_ylabel("Predicted", fontsize=label_fontsize, labelpad=10)
+ ax.set_title(title, fontsize=title_fontsize, pad=20)
+ ax.set_xticks(xy_ticks)
+ ax.set_yticks(xy_ticks)
+ ax.tick_params(axis="x", bottom=True, top=False, labelbottom=True, labeltop=False)
+ ax.tick_params(axis="y", left=True, right=False, labelleft=True, labelright=False)
+ if ticklabels != "auto":
+ ax.set_xticklabels(ticklabels, fontsize=tick_fontsize, rotation=90, ha="center")
+ ax.set_yticklabels(ticklabels, fontsize=tick_fontsize)
+ for s in ["left", "right", "bottom", "top", "outline"]:
+ if s != "outline":
+ ax.spines[s].set_visible(False) # Confusion matrix plot don't have outline
+ cbar.ax.spines[s].set_visible(False)
+ fig.subplots_adjust(left=0, right=0.84, top=0.94, bottom=btm) # Adjust layout to ensure equal margins
  plot_fname = Path(save_dir) / f"{title.lower().replace(' ', '_')}.png"
  fig.savefig(plot_fname, dpi=250)
  plt.close(fig)
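`ConfusionMatrix.plot` no longer depends on seaborn: the heatmap is drawn with `imshow`, cell values are written with `ax.text`, and a colorbar replaces seaborn's styling. A stripped-down, matplotlib-only sketch of the same idea (random data and hypothetical class names, not the package's plotting code):

    import matplotlib
    matplotlib.use("Agg")  # headless backend, matching the opencv-headless use case
    import matplotlib.pyplot as plt
    import numpy as np

    array = np.random.rand(4, 4)
    names = ["person", "bus", "car", "background"]  # hypothetical tick labels

    fig, ax = plt.subplots(figsize=(6, 5))
    im = ax.imshow(array, cmap="Blues", vmin=0.0, interpolation="none")
    for i in range(array.shape[0]):  # annotate each cell, as the new plot() does when nc < 30
        for j in range(array.shape[1]):
            ax.text(j, i, f"{array[i, j]:.2f}", ha="center", va="center",
                    color="white" if array[i, j] > 0.7 else "black")
    fig.colorbar(im, ax=ax, fraction=0.046, pad=0.05)
    ax.set_xticks(range(len(names)))
    ax.set_yticks(range(len(names)))
    ax.set_xticklabels(names, rotation=90)
    ax.set_yticklabels(names)
    ax.set_xlabel("True")
    ax.set_ylabel("Predicted")
    fig.savefig("confusion_matrix_demo.png", dpi=150)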
ultralytics/utils/plotting.py CHANGED
@@ -537,9 +537,9 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
  """
  import matplotlib.pyplot as plt # scope for faster 'import ultralytics'
  import pandas
- import seaborn
+ from matplotlib.colors import LinearSegmentedColormap

- # Filter matplotlib>=3.7.2 warning and Seaborn use_inf and is_categorical FutureWarnings
+ # Filter matplotlib>=3.7.2 warning
  warnings.filterwarnings("ignore", category=UserWarning, message="The figure layout has changed to tight")
  warnings.filterwarnings("ignore", category=FutureWarning)

@@ -549,12 +549,17 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
  boxes = boxes[:1000000] # limit to 1M boxes
  x = pandas.DataFrame(boxes, columns=["x", "y", "width", "height"])

- # Seaborn correlogram
- seaborn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
- plt.savefig(save_dir / "labels_correlogram.jpg", dpi=200)
- plt.close()
+ try: # Seaborn correlogram
+ import seaborn
+
+ seaborn.pairplot(x, corner=True, diag_kind="auto", kind="hist", diag_kws=dict(bins=50), plot_kws=dict(pmax=0.9))
+ plt.savefig(save_dir / "labels_correlogram.jpg", dpi=200)
+ plt.close()
+ except ImportError:
+ pass # Skip if seaborn is not installed

  # Matplotlib labels
+ subplot_3_4_color = LinearSegmentedColormap.from_list("white_blue", ["white", "blue"])
  ax = plt.subplots(2, 2, figsize=(8, 8), tight_layout=True)[1].ravel()
  y = ax[0].hist(cls, bins=np.linspace(0, nc, nc + 1) - 0.5, rwidth=0.8)
  for i in range(nc):
@@ -565,18 +570,19 @@ def plot_labels(boxes, cls, names=(), save_dir=Path(""), on_plot=None):
  ax[0].set_xticklabels(list(names.values()), rotation=90, fontsize=10)
  else:
  ax[0].set_xlabel("classes")
- seaborn.histplot(x, x="x", y="y", ax=ax[2], bins=50, pmax=0.9)
- seaborn.histplot(x, x="width", y="height", ax=ax[3], bins=50, pmax=0.9)
-
- # Rectangles
- boxes[:, 0:2] = 0.5 # center
- boxes = ops.xywh2xyxy(boxes) * 1000
+ boxes = np.column_stack([0.5 - boxes[:, 2:4] / 2, 0.5 + boxes[:, 2:4] / 2]) * 1000
  img = Image.fromarray(np.ones((1000, 1000, 3), dtype=np.uint8) * 255)
  for cls, box in zip(cls[:500], boxes[:500]):
  ImageDraw.Draw(img).rectangle(box, width=1, outline=colors(cls)) # plot
  ax[1].imshow(img)
  ax[1].axis("off")

+ ax[2].hist2d(x["x"], x["y"], bins=50, cmap=subplot_3_4_color)
+ ax[2].set_xlabel("x")
+ ax[2].set_ylabel("y")
+ ax[3].hist2d(x["width"], x["height"], bins=50, cmap=subplot_3_4_color)
+ ax[3].set_xlabel("width")
+ ax[3].set_ylabel("height")
  for a in [0, 1, 2, 3]:
  for s in ["top", "right", "left", "bottom"]:
  ax[a].spines[s].set_visible(False)
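Consistent with dropping `seaborn>=0.11.0` from the METADATA requirements, `plot_labels` now treats seaborn as optional: the correlogram is drawn only when seaborn happens to be installed, and the x/y and width/height panels use `Axes.hist2d` with a white-to-blue colormap instead of `seaborn.histplot`. A short sketch of that optional-dependency pattern (random data, illustrative only):

    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas
    from matplotlib.colors import LinearSegmentedColormap

    x = pandas.DataFrame(np.random.rand(1000, 4), columns=["x", "y", "width", "height"])

    try:  # optional extra: only drawn when seaborn is importable
        import seaborn

        seaborn.pairplot(x, corner=True, kind="hist", diag_kws=dict(bins=50))
        plt.savefig("labels_correlogram_demo.jpg", dpi=100)
        plt.close()
    except ImportError:
        pass  # skip silently, matching the new plot_labels behavior

    # Core panels need only matplotlib
    cmap = LinearSegmentedColormap.from_list("white_blue", ["white", "blue"])
    fig, ax = plt.subplots()
    ax.hist2d(x["x"], x["y"], bins=50, cmap=cmap)
    fig.savefig("xy_hist2d_demo.png", dpi=100)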