ultralytics 8.1.38__py3-none-any.whl → 8.1.40__py3-none-any.whl

This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of ultralytics might be problematic. Click here for more details.

Files changed (58)
  1. ultralytics/__init__.py +1 -1
  2. ultralytics/cfg/__init__.py +3 -3
  3. ultralytics/cfg/datasets/lvis.yaml +1239 -0
  4. ultralytics/data/__init__.py +18 -2
  5. ultralytics/data/augment.py +124 -3
  6. ultralytics/data/base.py +2 -2
  7. ultralytics/data/build.py +25 -3
  8. ultralytics/data/converter.py +24 -6
  9. ultralytics/data/dataset.py +142 -27
  10. ultralytics/data/loaders.py +11 -8
  11. ultralytics/data/split_dota.py +1 -1
  12. ultralytics/data/utils.py +33 -8
  13. ultralytics/engine/exporter.py +3 -3
  14. ultralytics/engine/model.py +6 -3
  15. ultralytics/engine/results.py +2 -2
  16. ultralytics/engine/trainer.py +59 -55
  17. ultralytics/engine/validator.py +2 -2
  18. ultralytics/hub/utils.py +1 -1
  19. ultralytics/models/fastsam/model.py +1 -1
  20. ultralytics/models/fastsam/prompt.py +4 -5
  21. ultralytics/models/nas/model.py +1 -1
  22. ultralytics/models/sam/model.py +1 -1
  23. ultralytics/models/sam/modules/tiny_encoder.py +1 -1
  24. ultralytics/models/yolo/__init__.py +2 -2
  25. ultralytics/models/yolo/classify/train.py +1 -1
  26. ultralytics/models/yolo/detect/train.py +1 -1
  27. ultralytics/models/yolo/detect/val.py +36 -17
  28. ultralytics/models/yolo/model.py +1 -0
  29. ultralytics/models/yolo/world/__init__.py +5 -0
  30. ultralytics/models/yolo/world/train.py +92 -0
  31. ultralytics/models/yolo/world/train_world.py +108 -0
  32. ultralytics/nn/autobackend.py +5 -5
  33. ultralytics/nn/modules/block.py +4 -2
  34. ultralytics/nn/modules/conv.py +1 -1
  35. ultralytics/nn/modules/head.py +13 -4
  36. ultralytics/nn/tasks.py +30 -14
  37. ultralytics/solutions/ai_gym.py +1 -1
  38. ultralytics/solutions/heatmap.py +85 -47
  39. ultralytics/solutions/object_counter.py +79 -64
  40. ultralytics/trackers/byte_tracker.py +1 -1
  41. ultralytics/trackers/track.py +1 -1
  42. ultralytics/trackers/utils/gmc.py +1 -1
  43. ultralytics/utils/__init__.py +4 -4
  44. ultralytics/utils/benchmarks.py +2 -2
  45. ultralytics/utils/callbacks/comet.py +1 -1
  46. ultralytics/utils/callbacks/mlflow.py +1 -1
  47. ultralytics/utils/checks.py +3 -3
  48. ultralytics/utils/downloads.py +2 -2
  49. ultralytics/utils/loss.py +1 -1
  50. ultralytics/utils/metrics.py +1 -1
  51. ultralytics/utils/plotting.py +36 -22
  52. ultralytics/utils/torch_utils.py +17 -3
  53. {ultralytics-8.1.38.dist-info → ultralytics-8.1.40.dist-info}/METADATA +1 -1
  54. {ultralytics-8.1.38.dist-info → ultralytics-8.1.40.dist-info}/RECORD +58 -54
  55. {ultralytics-8.1.38.dist-info → ultralytics-8.1.40.dist-info}/LICENSE +0 -0
  56. {ultralytics-8.1.38.dist-info → ultralytics-8.1.40.dist-info}/WHEEL +0 -0
  57. {ultralytics-8.1.38.dist-info → ultralytics-8.1.40.dist-info}/entry_points.txt +0 -0
  58. {ultralytics-8.1.38.dist-info → ultralytics-8.1.40.dist-info}/top_level.txt +0 -0
@@ -43,16 +43,19 @@ class ObjectCounter:
43
43
  # Object counting Information
44
44
  self.in_counts = 0
45
45
  self.out_counts = 0
46
- self.counting_dict = {}
46
+ self.count_ids = []
47
+ self.class_wise_count = {}
47
48
  self.count_txt_thickness = 0
48
- self.count_txt_color = (0, 0, 0)
49
- self.count_color = (255, 255, 255)
49
+ self.count_txt_color = (255, 255, 255)
50
+ self.line_color = (255, 255, 255)
51
+ self.cls_txtdisplay_gap = 50
52
+ self.fontsize = 0.6
50
53
 
51
54
  # Tracks info
52
55
  self.track_history = defaultdict(list)
53
56
  self.track_thickness = 2
54
57
  self.draw_tracks = False
55
- self.track_color = (0, 255, 0)
58
+ self.track_color = None
56
59
 
57
60
  # Check if environment support imshow
58
61
  self.env_check = check_imshow(warn=True)
@@ -68,12 +71,14 @@ class ObjectCounter:
68
71
  view_in_counts=True,
69
72
  view_out_counts=True,
70
73
  draw_tracks=False,
71
- count_txt_thickness=2,
72
- count_txt_color=(0, 0, 0),
73
- count_color=(255, 255, 255),
74
- track_color=(0, 255, 0),
74
+ count_txt_thickness=3,
75
+ count_txt_color=(255, 255, 255),
76
+ fontsize=0.8,
77
+ line_color=(255, 255, 255),
78
+ track_color=None,
75
79
  region_thickness=5,
76
80
  line_dist_thresh=15,
81
+ cls_txtdisplay_gap=50,
77
82
  ):
78
83
  """
79
84
  Configures the Counter's image, bounding box line thickness, and counting region points.
@@ -89,11 +94,13 @@ class ObjectCounter:
89
94
  draw_tracks (Bool): draw tracks
90
95
  count_txt_thickness (int): Text thickness for object counting display
91
96
  count_txt_color (RGB color): count text color value
92
- count_color (RGB color): count text background color value
97
+ fontsize (float): Text display font size
98
+ line_color (RGB color): count highlighter line color
93
99
  count_reg_color (RGB color): Color of object counting region
94
100
  track_color (RGB color): color for tracks
95
101
  region_thickness (int): Object counting Region thickness
96
102
  line_dist_thresh (int): Euclidean Distance threshold for line counter
103
+ cls_txtdisplay_gap (int): Display gap between each class count
97
104
  """
98
105
  self.tf = line_thickness
99
106
  self.view_img = view_img
@@ -108,7 +115,7 @@ class ObjectCounter:
108
115
  self.reg_pts = reg_pts
109
116
  self.counting_region = LineString(self.reg_pts)
110
117
  elif len(reg_pts) >= 3:
111
- print("Region Counter Initiated.")
118
+ print("Polygon Counter Initiated.")
112
119
  self.reg_pts = reg_pts
113
120
  self.counting_region = Polygon(self.reg_pts)
114
121
  else:
@@ -120,10 +127,12 @@ class ObjectCounter:
120
127
  self.track_color = track_color
121
128
  self.count_txt_thickness = count_txt_thickness
122
129
  self.count_txt_color = count_txt_color
123
- self.count_color = count_color
130
+ self.fontsize = fontsize
131
+ self.line_color = line_color
124
132
  self.region_color = count_reg_color
125
133
  self.region_thickness = region_thickness
126
134
  self.line_dist_thresh = line_dist_thresh
135
+ self.cls_txtdisplay_gap = cls_txtdisplay_gap
127
136
 
128
137
  def mouse_event_for_region(self, event, x, y, flags, params):
129
138
  """
@@ -171,7 +180,13 @@ class ObjectCounter:
171
180
  # Extract tracks
172
181
  for box, track_id, cls in zip(boxes, track_ids, clss):
173
182
  # Draw bounding box
174
- self.annotator.box_label(box, label=f"{track_id}:{self.names[cls]}", color=colors(int(track_id), True))
183
+ self.annotator.box_label(box, label=f"{self.names[cls]}#{track_id}", color=colors(int(track_id), True))
184
+
185
+ # Store class info
186
+ if self.names[cls] not in self.class_wise_count:
187
+ if len(self.names[cls]) > 5:
188
+ self.names[cls] = self.names[cls][:5]
189
+ self.class_wise_count[self.names[cls]] = {"in": 0, "out": 0}
175
190
 
176
191
  # Draw Tracks
177
192
  track_line = self.track_history[track_id]
@@ -182,68 +197,68 @@ class ObjectCounter:
182
197
  # Draw track trails
183
198
  if self.draw_tracks:
184
199
  self.annotator.draw_centroid_and_tracks(
185
- track_line, color=self.track_color, track_thickness=self.track_thickness
200
+ track_line,
201
+ color=self.track_color if self.track_color else colors(int(track_id), True),
202
+ track_thickness=self.track_thickness,
186
203
  )
187
204
 
188
205
  prev_position = self.track_history[track_id][-2] if len(self.track_history[track_id]) > 1 else None
189
- centroid = Point((box[:2] + box[2:]) / 2)
190
-
191
- # Count objects
192
- if len(self.reg_pts) >= 3: # any polygon
193
- is_inside = self.counting_region.contains(centroid)
194
- current_position = "in" if is_inside else "out"
195
206
 
196
- if prev_position is not None:
197
- if self.counting_dict[track_id] != current_position and is_inside:
198
- self.in_counts += 1
199
- self.counting_dict[track_id] = "in"
200
- elif self.counting_dict[track_id] != current_position and not is_inside:
201
- self.out_counts += 1
202
- self.counting_dict[track_id] = "out"
203
- else:
204
- self.counting_dict[track_id] = current_position
207
+ # Count objects in any polygon
208
+ if len(self.reg_pts) >= 3:
209
+ is_inside = self.counting_region.contains(Point(track_line[-1]))
205
210
 
206
- else:
207
- self.counting_dict[track_id] = current_position
211
+ if prev_position is not None and is_inside and track_id not in self.count_ids:
212
+ self.count_ids.append(track_id)
208
213
 
209
- elif len(self.reg_pts) == 2:
210
- if prev_position is not None:
211
- is_inside = (box[0] - prev_position[0]) * (
212
- self.counting_region.centroid.x - prev_position[0]
213
- ) > 0
214
- current_position = "in" if is_inside else "out"
215
-
216
- if self.counting_dict[track_id] != current_position and is_inside:
214
+ if (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0:
217
215
  self.in_counts += 1
218
- self.counting_dict[track_id] = "in"
219
- elif self.counting_dict[track_id] != current_position and not is_inside:
220
- self.out_counts += 1
221
- self.counting_dict[track_id] = "out"
216
+ self.class_wise_count[self.names[cls]]["in"] += 1
222
217
  else:
223
- self.counting_dict[track_id] = current_position
224
- else:
225
- self.counting_dict[track_id] = None
226
-
227
- incount_label = f"In Count : {self.in_counts}"
228
- outcount_label = f"OutCount : {self.out_counts}"
229
-
230
- # Display counts based on user choice
231
- counts_label = None
232
- if not self.view_in_counts and not self.view_out_counts:
233
- counts_label = None
234
- elif not self.view_in_counts:
235
- counts_label = outcount_label
236
- elif not self.view_out_counts:
237
- counts_label = incount_label
238
- else:
239
- counts_label = f"{incount_label} {outcount_label}"
218
+ self.out_counts += 1
219
+ self.class_wise_count[self.names[cls]]["out"] += 1
240
220
 
241
- if counts_label is not None:
242
- self.annotator.count_labels(
243
- counts=counts_label,
244
- count_txt_size=self.count_txt_thickness,
221
+ # Count objects using line
222
+ elif len(self.reg_pts) == 2:
223
+ is_inside = (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0
224
+
225
+ if prev_position is not None and is_inside and track_id not in self.count_ids:
226
+ distance = Point(track_line[-1]).distance(self.counting_region)
227
+
228
+ if distance < self.line_dist_thresh and track_id not in self.count_ids:
229
+ self.count_ids.append(track_id)
230
+
231
+ if (box[0] - prev_position[0]) * (self.counting_region.centroid.x - prev_position[0]) > 0:
232
+ self.in_counts += 1
233
+ self.class_wise_count[self.names[cls]]["in"] += 1
234
+ else:
235
+ self.out_counts += 1
236
+ self.class_wise_count[self.names[cls]]["out"] += 1
237
+
238
+ label = "Ultralytics Analytics \t"
239
+
240
+ for key, value in self.class_wise_count.items():
241
+ if value["in"] != 0 or value["out"] != 0:
242
+ if not self.view_in_counts and not self.view_out_counts:
243
+ label = None
244
+ elif not self.view_in_counts:
245
+ label += f"{str.capitalize(key)}: IN {value['in']} \t"
246
+ elif not self.view_out_counts:
247
+ label += f"{str.capitalize(key)}: OUT {value['out']} \t"
248
+ else:
249
+ label += f"{str.capitalize(key)}: IN {value['in']} OUT {value['out']} \t"
250
+
251
+ label = label.rstrip()
252
+ label = label.split("\t")
253
+
254
+ if label is not None:
255
+ self.annotator.display_counts(
256
+ counts=label,
257
+ tf=self.count_txt_thickness,
258
+ fontScale=self.fontsize,
245
259
  txt_color=self.count_txt_color,
246
- color=self.count_color,
260
+ line_color=self.line_color,
261
+ classwise_txtgap=self.cls_txtdisplay_gap,
247
262
  )
248
263
 
249
264
  def display_frames(self):
@@ -47,7 +47,7 @@ class STrack(BaseTrack):
47
47
  """Initialize new STrack instance."""
48
48
  super().__init__()
49
49
  # xywh+idx or xywha+idx
50
- assert len(xywh) in [5, 6], f"expected 5 or 6 values but got {len(xywh)}"
50
+ assert len(xywh) in {5, 6}, f"expected 5 or 6 values but got {len(xywh)}"
51
51
  self._tlwh = np.asarray(xywh2ltwh(xywh[:4]), dtype=np.float32)
52
52
  self.kalman_filter = None
53
53
  self.mean, self.covariance = None, None
@@ -31,7 +31,7 @@ def on_predict_start(predictor: object, persist: bool = False) -> None:
31
31
  tracker = check_yaml(predictor.args.tracker)
32
32
  cfg = IterableSimpleNamespace(**yaml_load(tracker))
33
33
 
34
- if cfg.tracker_type not in ["bytetrack", "botsort"]:
34
+ if cfg.tracker_type not in {"bytetrack", "botsort"}:
35
35
  raise AssertionError(f"Only 'bytetrack' and 'botsort' are supported for now, but got '{cfg.tracker_type}'")
36
36
 
37
37
  trackers = []
@@ -94,7 +94,7 @@ class GMC:
94
94
  array([[1, 2, 3],
95
95
  [4, 5, 6]])
96
96
  """
97
- if self.method in ["orb", "sift"]:
97
+ if self.method in {"orb", "sift"}:
98
98
  return self.applyFeatures(raw_frame, detections)
99
99
  elif self.method == "ecc":
100
100
  return self.applyEcc(raw_frame)
@@ -41,7 +41,7 @@ VERBOSE = str(os.getenv("YOLO_VERBOSE", True)).lower() == "true" # global verbo
41
41
  TQDM_BAR_FORMAT = "{l_bar}{bar:10}{r_bar}" if VERBOSE else None # tqdm bar format
42
42
  LOGGING_NAME = "ultralytics"
43
43
  MACOS, LINUX, WINDOWS = (platform.system() == x for x in ["Darwin", "Linux", "Windows"]) # environment booleans
44
- ARM64 = platform.machine() in ("arm64", "aarch64") # ARM64 booleans
44
+ ARM64 = platform.machine() in {"arm64", "aarch64"} # ARM64 booleans
45
45
  HELP_MSG = """
46
46
  Usage examples for running YOLOv8:
47
47
 
@@ -359,7 +359,7 @@ def yaml_load(file="data.yaml", append_filename=False):
359
359
  Returns:
360
360
  (dict): YAML data and file name.
361
361
  """
362
- assert Path(file).suffix in (".yaml", ".yml"), f"Attempting to load non-YAML file {file} with yaml_load()"
362
+ assert Path(file).suffix in {".yaml", ".yml"}, f"Attempting to load non-YAML file {file} with yaml_load()"
363
363
  with open(file, errors="ignore", encoding="utf-8") as f:
364
364
  s = f.read() # string
365
365
 
@@ -866,7 +866,7 @@ def set_sentry():
866
866
  """
867
867
  if "exc_info" in hint:
868
868
  exc_type, exc_value, tb = hint["exc_info"]
869
- if exc_type in (KeyboardInterrupt, FileNotFoundError) or "out of memory" in str(exc_value):
869
+ if exc_type in {KeyboardInterrupt, FileNotFoundError} or "out of memory" in str(exc_value):
870
870
  return None # do not send event
871
871
 
872
872
  event["tags"] = {
@@ -879,7 +879,7 @@ def set_sentry():
879
879
 
880
880
  if (
881
881
  SETTINGS["sync"]
882
- and RANK in (-1, 0)
882
+ and RANK in {-1, 0}
883
883
  and Path(ARGV[0]).name == "yolo"
884
884
  and not TESTS_RUNNING
885
885
  and ONLINE
@@ -115,7 +115,7 @@ def benchmark(
115
115
 
116
116
  # Predict
117
117
  assert model.task != "pose" or i != 7, "GraphDef Pose inference is not supported"
118
- assert i not in (9, 10), "inference not supported" # Edge TPU and TF.js are unsupported
118
+ assert i not in {9, 10}, "inference not supported" # Edge TPU and TF.js are unsupported
119
119
  assert i != 5 or platform.system() == "Darwin", "inference only supported on macOS>=10.13" # CoreML
120
120
  exported_model.predict(ASSETS / "bus.jpg", imgsz=imgsz, device=device, half=half)
121
121
 
@@ -220,7 +220,7 @@ class ProfileModels:
220
220
  output = []
221
221
  for file in files:
222
222
  engine_file = file.with_suffix(".engine")
223
- if file.suffix in (".pt", ".yaml", ".yml"):
223
+ if file.suffix in {".pt", ".yaml", ".yml"}:
224
224
  model = YOLO(str(file))
225
225
  model.fuse() # to report correct params and GFLOPs in model.info()
226
226
  model_info = model.info()
@@ -71,7 +71,7 @@ def _get_experiment_type(mode, project_name):
71
71
 
72
72
  def _create_experiment(args):
73
73
  """Ensures that the experiment object is only created in a single process during distributed training."""
74
- if RANK not in (-1, 0):
74
+ if RANK not in {-1, 0}:
75
75
  return
76
76
  try:
77
77
  comet_mode = _get_comet_mode()
@@ -108,7 +108,7 @@ def on_train_end(trainer):
108
108
  for f in trainer.save_dir.glob("*"): # log all other files in save_dir
109
109
  if f.suffix in {".png", ".jpg", ".csv", ".pt", ".yaml"}:
110
110
  mlflow.log_artifact(str(f))
111
- keep_run_active = os.environ.get("MLFLOW_KEEP_RUN_ACTIVE", "False").lower() in ("true")
111
+ keep_run_active = os.environ.get("MLFLOW_KEEP_RUN_ACTIVE", "False").lower() == "true"
112
112
  if keep_run_active:
113
113
  LOGGER.info(f"{PREFIX}mlflow run still alive, remember to close it using mlflow.end_run()")
114
114
  else:
@@ -237,7 +237,7 @@ def check_version(
237
237
  result = False
238
238
  elif op == "!=" and c == v:
239
239
  result = False
240
- elif op in (">=", "") and not (c >= v): # if no constraint passed assume '>=required'
240
+ elif op in {">=", ""} and not (c >= v): # if no constraint passed assume '>=required'
241
241
  result = False
242
242
  elif op == "<=" and not (c <= v):
243
243
  result = False
@@ -500,7 +500,7 @@ def check_file(file, suffix="", download=True, hard=True):
500
500
  raise FileNotFoundError(f"'{file}' does not exist")
501
501
  elif len(files) > 1 and hard:
502
502
  raise FileNotFoundError(f"Multiple files match '{file}', specify exact path: {files}")
503
- return files[0] if len(files) else [] if hard else file # return file
503
+ return files[0] if len(files) else [] # return file
504
504
 
505
505
 
506
506
  def check_yaml(file, suffix=(".yaml", ".yml"), hard=True):
@@ -632,7 +632,7 @@ def check_amp(model):
632
632
  (bool): Returns True if the AMP functionality works correctly with YOLOv8 model, else False.
633
633
  """
634
634
  device = next(model.parameters()).device # get model device
635
- if device.type in ("cpu", "mps"):
635
+ if device.type in {"cpu", "mps"}:
636
636
  return False # AMP only used on CUDA devices
637
637
 
638
638
  def amp_allclose(m, im):
@@ -356,13 +356,13 @@ def safe_download(
356
356
  raise ConnectionError(emojis(f"❌ Download failure for {url}. Retry limit reached.")) from e
357
357
  LOGGER.warning(f"⚠️ Download failure, retrying {i + 1}/{retry} {url}...")
358
358
 
359
- if unzip and f.exists() and f.suffix in ("", ".zip", ".tar", ".gz"):
359
+ if unzip and f.exists() and f.suffix in {"", ".zip", ".tar", ".gz"}:
360
360
  from zipfile import is_zipfile
361
361
 
362
362
  unzip_dir = (dir or f.parent).resolve() # unzip to dir if provided else unzip in place
363
363
  if is_zipfile(f):
364
364
  unzip_dir = unzip_file(file=f, path=unzip_dir, exist_ok=exist_ok, progress=progress) # unzip
365
- elif f.suffix in (".tar", ".gz"):
365
+ elif f.suffix in {".tar", ".gz"}:
366
366
  LOGGER.info(f"Unzipping {f} to {unzip_dir}...")
367
367
  subprocess.run(["tar", "xf" if f.suffix == ".tar" else "xfz", f, "--directory", unzip_dir], check=True)
368
368
  if delete:
ultralytics/utils/loss.py CHANGED
@@ -157,7 +157,7 @@ class v8DetectionLoss:
157
157
  self.hyp = h
158
158
  self.stride = m.stride # model strides
159
159
  self.nc = m.nc # number of classes
160
- self.no = m.no
160
+ self.no = m.nc + m.reg_max * 4
161
161
  self.reg_max = m.reg_max
162
162
  self.device = device
163
163
 
@@ -298,7 +298,7 @@ class ConfusionMatrix:
298
298
  self.task = task
299
299
  self.matrix = np.zeros((nc + 1, nc + 1)) if self.task == "detect" else np.zeros((nc, nc))
300
300
  self.nc = nc # number of classes
301
- self.conf = 0.25 if conf in (None, 0.001) else conf # apply 0.25 if default val conf is passed
301
+ self.conf = 0.25 if conf in {None, 0.001} else conf # apply 0.25 if default val conf is passed
302
302
  self.iou_thres = iou_thres
303
303
 
304
304
  def process_cls_preds(self, preds, targets):
@@ -363,35 +363,49 @@ class Annotator:
363
363
  cv2.polylines(self.im, [points], isClosed=False, color=color, thickness=track_thickness)
364
364
  cv2.circle(self.im, (int(track[-1][0]), int(track[-1][1])), track_thickness * 2, color, -1)
365
365
 
366
- def count_labels(self, counts=0, count_txt_size=2, color=(255, 255, 255), txt_color=(0, 0, 0)):
366
+ def display_counts(
367
+ self, counts=None, tf=2, fontScale=0.6, line_color=(0, 0, 0), txt_color=(255, 255, 255), classwise_txtgap=55
368
+ ):
367
369
  """
368
- Plot counts for object counter.
370
+ Display counts on im0.
369
371
 
370
372
  Args:
371
- counts (int): objects counts value
372
- count_txt_size (int): text size for counts display
373
- color (tuple): background color of counts display
374
- txt_color (tuple): text color of counts display
373
+ counts (str): objects count data
374
+ tf (int): text thickness for display
375
+ fontScale (float): text fontsize for display
376
+ line_color (RGB Color): counts highlighter color
377
+ txt_color (RGB Color): counts display color
378
+ classwise_txtgap (int): Gap between each class count data
375
379
  """
376
- self.tf = count_txt_size
377
- tl = self.tf or round(0.002 * (self.im.shape[0] + self.im.shape[1]) / 2) + 1
380
+
381
+ tl = tf or round(0.002 * (self.im.shape[0] + self.im.shape[1]) / 2) + 1
378
382
  tf = max(tl - 1, 1)
379
383
 
380
- # Get text size for in_count and out_count
381
- t_size_in = cv2.getTextSize(str(counts), 0, fontScale=tl / 2, thickness=tf)[0]
384
+ t_sizes = [cv2.getTextSize(str(count), 0, fontScale=0.8, thickness=tf)[0] for count in counts]
382
385
 
383
- # Calculate positions for counts label
384
- text_width = t_size_in[0]
385
- text_x = (self.im.shape[1] - text_width) // 2 # Center x-coordinate
386
- text_y = t_size_in[1]
386
+ max_text_width = max([size[0] for size in t_sizes])
387
+ max_text_height = max([size[1] for size in t_sizes])
387
388
 
388
- # Create a rounded rectangle for in_count
389
- cv2.rectangle(
390
- self.im, (text_x - 5, text_y - 5), (text_x + text_width + 7, text_y + t_size_in[1] + 7), color, -1
391
- )
392
- cv2.putText(
393
- self.im, str(counts), (text_x, text_y + t_size_in[1]), 0, tl / 2, txt_color, self.tf, lineType=cv2.LINE_AA
394
- )
389
+ text_x = self.im.shape[1] - max_text_width - 20
390
+ text_y = classwise_txtgap
391
+
392
+ for i, count in enumerate(counts):
393
+ text_x_pos = text_x
394
+ text_y_pos = text_y + i * classwise_txtgap
395
+
396
+ cv2.putText(
397
+ self.im,
398
+ str(count),
399
+ (text_x_pos, text_y_pos),
400
+ cv2.FONT_HERSHEY_SIMPLEX,
401
+ fontScale=fontScale,
402
+ color=txt_color,
403
+ thickness=tf,
404
+ lineType=cv2.LINE_AA,
405
+ )
406
+
407
+ line_y_pos = text_y_pos + max_text_height + 5
408
+ cv2.line(self.im, (text_x_pos, line_y_pos), (text_x_pos + max_text_width, line_y_pos), line_color, tf)
395
409
 
396
410
  @staticmethod
397
411
  def estimate_pose_angle(a, b, c):
@@ -890,7 +904,7 @@ def plot_results(file="path/to/results.csv", dir="", segment=False, pose=False,
890
904
  ax[i].plot(x, y, marker=".", label=f.stem, linewidth=2, markersize=8) # actual results
891
905
  ax[i].plot(x, gaussian_filter1d(y, sigma=3), ":", label="smooth", linewidth=2) # smoothing line
892
906
  ax[i].set_title(s[j], fontsize=12)
893
- # if j in [8, 9, 10]: # share train and val loss y axes
907
+ # if j in {8, 9, 10}: # share train and val loss y axes
894
908
  # ax[i].get_shared_y_axes().join(ax[i], ax[i - 5])
895
909
  except Exception as e:
896
910
  LOGGER.warning(f"WARNING: Plotting error for {f}: {e}")
@@ -37,7 +37,7 @@ TORCHVISION_0_13 = check_version(torchvision.__version__, "0.13.0")
37
37
  def torch_distributed_zero_first(local_rank: int):
38
38
  """Decorator to make all processes in distributed training wait for each local_master to do something."""
39
39
  initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
40
- if initialized and local_rank not in (-1, 0):
40
+ if initialized and local_rank not in {-1, 0}:
41
41
  dist.barrier(device_ids=[local_rank])
42
42
  yield
43
43
  if initialized and local_rank == 0:
@@ -109,7 +109,7 @@ def select_device(device="", batch=0, newline=False, verbose=True):
109
109
  for remove in "cuda:", "none", "(", ")", "[", "]", "'", " ":
110
110
  device = device.replace(remove, "") # to string, 'cuda:0' -> '0' and '(0, 1)' -> '0,1'
111
111
  cpu = device == "cpu"
112
- mps = device in ("mps", "mps:0") # Apple Metal Performance Shaders (MPS)
112
+ mps = device in {"mps", "mps:0"} # Apple Metal Performance Shaders (MPS)
113
113
  if cpu or mps:
114
114
  os.environ["CUDA_VISIBLE_DEVICES"] = "-1" # force torch.cuda.is_available() = False
115
115
  elif device: # non-cpu device requested
@@ -347,7 +347,7 @@ def initialize_weights(model):
347
347
  elif t is nn.BatchNorm2d:
348
348
  m.eps = 1e-3
349
349
  m.momentum = 0.03
350
- elif t in [nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU]:
350
+ elif t in {nn.Hardswish, nn.LeakyReLU, nn.ReLU, nn.ReLU6, nn.SiLU}:
351
351
  m.inplace = True
352
352
 
353
353
 
@@ -505,6 +505,20 @@ def strip_optimizer(f: Union[str, Path] = "best.pt", s: str = "") -> None:
505
505
  LOGGER.info(f"Optimizer stripped from {f},{f' saved as {s},' if s else ''} {mb:.1f}MB")
506
506
 
507
507
 
508
+ def convert_optimizer_state_dict_to_fp16(state_dict):
509
+ """
510
+ Converts the state_dict of a given optimizer to FP16, focusing on the 'state' key for tensor conversions.
511
+
512
+ This method aims to reduce storage size without altering 'param_groups' as they contain non-tensor data.
513
+ """
514
+ for state in state_dict["state"].values():
515
+ for k, v in state.items():
516
+ if isinstance(v, torch.Tensor) and v.dtype is torch.float32:
517
+ state[k] = v.half()
518
+
519
+ return state_dict
520
+
521
+
508
522
  def profile(input, ops, n=10, device=None):
509
523
  """
510
524
  Ultralytics speed, memory and FLOPs profiler.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: ultralytics
3
- Version: 8.1.38
3
+ Version: 8.1.40
4
4
  Summary: Ultralytics YOLOv8 for SOTA object detection, multi-object tracking, instance segmentation, pose estimation and image classification.
5
5
  Author: Glenn Jocher, Ayush Chaurasia, Jing Qiu
6
6
  Maintainer: Glenn Jocher, Ayush Chaurasia, Jing Qiu