ultralytics 8.0.237__py3-none-any.whl → 8.0.239__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: this release of ultralytics has been flagged as potentially problematic.

Files changed (137)
  1. ultralytics/__init__.py +2 -2
  2. ultralytics/cfg/__init__.py +241 -138
  3. ultralytics/cfg/datasets/DOTAv1.5.yaml +1 -1
  4. ultralytics/cfg/datasets/DOTAv1.yaml +1 -1
  5. ultralytics/cfg/datasets/dota8.yaml +34 -0
  6. ultralytics/data/__init__.py +9 -2
  7. ultralytics/data/annotator.py +4 -4
  8. ultralytics/data/augment.py +186 -169
  9. ultralytics/data/base.py +54 -48
  10. ultralytics/data/build.py +34 -23
  11. ultralytics/data/converter.py +242 -70
  12. ultralytics/data/dataset.py +117 -95
  13. ultralytics/data/explorer/__init__.py +5 -0
  14. ultralytics/data/explorer/explorer.py +170 -97
  15. ultralytics/data/explorer/gui/__init__.py +1 -0
  16. ultralytics/data/explorer/gui/dash.py +146 -76
  17. ultralytics/data/explorer/utils.py +87 -25
  18. ultralytics/data/loaders.py +75 -62
  19. ultralytics/data/split_dota.py +44 -36
  20. ultralytics/data/utils.py +160 -142
  21. ultralytics/engine/exporter.py +348 -292
  22. ultralytics/engine/model.py +102 -66
  23. ultralytics/engine/predictor.py +74 -55
  24. ultralytics/engine/results.py +63 -40
  25. ultralytics/engine/trainer.py +192 -144
  26. ultralytics/engine/tuner.py +66 -59
  27. ultralytics/engine/validator.py +31 -26
  28. ultralytics/hub/__init__.py +54 -31
  29. ultralytics/hub/auth.py +28 -25
  30. ultralytics/hub/session.py +282 -133
  31. ultralytics/hub/utils.py +64 -42
  32. ultralytics/models/__init__.py +1 -1
  33. ultralytics/models/fastsam/__init__.py +1 -1
  34. ultralytics/models/fastsam/model.py +6 -6
  35. ultralytics/models/fastsam/predict.py +3 -2
  36. ultralytics/models/fastsam/prompt.py +55 -48
  37. ultralytics/models/fastsam/val.py +1 -1
  38. ultralytics/models/nas/__init__.py +1 -1
  39. ultralytics/models/nas/model.py +9 -8
  40. ultralytics/models/nas/predict.py +8 -6
  41. ultralytics/models/nas/val.py +11 -9
  42. ultralytics/models/rtdetr/__init__.py +1 -1
  43. ultralytics/models/rtdetr/model.py +11 -9
  44. ultralytics/models/rtdetr/train.py +18 -16
  45. ultralytics/models/rtdetr/val.py +25 -19
  46. ultralytics/models/sam/__init__.py +1 -1
  47. ultralytics/models/sam/amg.py +13 -14
  48. ultralytics/models/sam/build.py +44 -42
  49. ultralytics/models/sam/model.py +6 -6
  50. ultralytics/models/sam/modules/decoders.py +6 -4
  51. ultralytics/models/sam/modules/encoders.py +37 -35
  52. ultralytics/models/sam/modules/sam.py +5 -4
  53. ultralytics/models/sam/modules/tiny_encoder.py +95 -73
  54. ultralytics/models/sam/modules/transformer.py +3 -2
  55. ultralytics/models/sam/predict.py +39 -27
  56. ultralytics/models/utils/loss.py +99 -95
  57. ultralytics/models/utils/ops.py +34 -31
  58. ultralytics/models/yolo/__init__.py +1 -1
  59. ultralytics/models/yolo/classify/__init__.py +1 -1
  60. ultralytics/models/yolo/classify/predict.py +8 -6
  61. ultralytics/models/yolo/classify/train.py +37 -31
  62. ultralytics/models/yolo/classify/val.py +26 -24
  63. ultralytics/models/yolo/detect/__init__.py +1 -1
  64. ultralytics/models/yolo/detect/predict.py +8 -6
  65. ultralytics/models/yolo/detect/train.py +47 -37
  66. ultralytics/models/yolo/detect/val.py +100 -82
  67. ultralytics/models/yolo/model.py +31 -25
  68. ultralytics/models/yolo/obb/__init__.py +1 -1
  69. ultralytics/models/yolo/obb/predict.py +13 -12
  70. ultralytics/models/yolo/obb/train.py +3 -3
  71. ultralytics/models/yolo/obb/val.py +80 -58
  72. ultralytics/models/yolo/pose/__init__.py +1 -1
  73. ultralytics/models/yolo/pose/predict.py +17 -12
  74. ultralytics/models/yolo/pose/train.py +28 -25
  75. ultralytics/models/yolo/pose/val.py +91 -64
  76. ultralytics/models/yolo/segment/__init__.py +1 -1
  77. ultralytics/models/yolo/segment/predict.py +10 -8
  78. ultralytics/models/yolo/segment/train.py +16 -15
  79. ultralytics/models/yolo/segment/val.py +90 -68
  80. ultralytics/nn/__init__.py +26 -6
  81. ultralytics/nn/autobackend.py +144 -112
  82. ultralytics/nn/modules/__init__.py +96 -13
  83. ultralytics/nn/modules/block.py +28 -7
  84. ultralytics/nn/modules/conv.py +41 -23
  85. ultralytics/nn/modules/head.py +67 -59
  86. ultralytics/nn/modules/transformer.py +49 -32
  87. ultralytics/nn/modules/utils.py +20 -15
  88. ultralytics/nn/tasks.py +215 -141
  89. ultralytics/solutions/ai_gym.py +59 -47
  90. ultralytics/solutions/distance_calculation.py +22 -15
  91. ultralytics/solutions/heatmap.py +76 -54
  92. ultralytics/solutions/object_counter.py +46 -39
  93. ultralytics/solutions/speed_estimation.py +13 -16
  94. ultralytics/trackers/__init__.py +1 -1
  95. ultralytics/trackers/basetrack.py +1 -0
  96. ultralytics/trackers/bot_sort.py +2 -1
  97. ultralytics/trackers/byte_tracker.py +10 -7
  98. ultralytics/trackers/track.py +7 -7
  99. ultralytics/trackers/utils/gmc.py +25 -25
  100. ultralytics/trackers/utils/kalman_filter.py +85 -42
  101. ultralytics/trackers/utils/matching.py +8 -7
  102. ultralytics/utils/__init__.py +173 -151
  103. ultralytics/utils/autobatch.py +10 -10
  104. ultralytics/utils/benchmarks.py +76 -86
  105. ultralytics/utils/callbacks/__init__.py +1 -1
  106. ultralytics/utils/callbacks/base.py +29 -29
  107. ultralytics/utils/callbacks/clearml.py +51 -43
  108. ultralytics/utils/callbacks/comet.py +81 -66
  109. ultralytics/utils/callbacks/dvc.py +33 -26
  110. ultralytics/utils/callbacks/hub.py +44 -26
  111. ultralytics/utils/callbacks/mlflow.py +31 -24
  112. ultralytics/utils/callbacks/neptune.py +35 -25
  113. ultralytics/utils/callbacks/raytune.py +9 -4
  114. ultralytics/utils/callbacks/tensorboard.py +16 -11
  115. ultralytics/utils/callbacks/wb.py +39 -33
  116. ultralytics/utils/checks.py +189 -141
  117. ultralytics/utils/dist.py +15 -12
  118. ultralytics/utils/downloads.py +112 -96
  119. ultralytics/utils/errors.py +1 -1
  120. ultralytics/utils/files.py +11 -11
  121. ultralytics/utils/instance.py +22 -22
  122. ultralytics/utils/loss.py +117 -67
  123. ultralytics/utils/metrics.py +224 -158
  124. ultralytics/utils/ops.py +39 -29
  125. ultralytics/utils/patches.py +3 -3
  126. ultralytics/utils/plotting.py +217 -120
  127. ultralytics/utils/tal.py +19 -13
  128. ultralytics/utils/torch_utils.py +138 -109
  129. ultralytics/utils/triton.py +12 -10
  130. ultralytics/utils/tuner.py +49 -47
  131. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/METADATA +5 -4
  132. ultralytics-8.0.239.dist-info/RECORD +188 -0
  133. ultralytics-8.0.237.dist-info/RECORD +0 -187
  134. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/LICENSE +0 -0
  135. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/WHEEL +0 -0
  136. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/entry_points.txt +0 -0
  137. {ultralytics-8.0.237.dist-info → ultralytics-8.0.239.dist-info}/top_level.txt +0 -0
ultralytics/engine/tuner.py CHANGED
@@ -73,40 +73,43 @@ class Tuner:
         Args:
             args (dict, optional): Configuration for hyperparameter evolution.
         """
-        self.space = args.pop('space', None) or {  # key: (min, max, gain(optional))
+        self.space = args.pop("space", None) or {  # key: (min, max, gain(optional))
             # 'optimizer': tune.choice(['SGD', 'Adam', 'AdamW', 'NAdam', 'RAdam', 'RMSProp']),
-            'lr0': (1e-5, 1e-1),  # initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
-            'lrf': (0.0001, 0.1),  # final OneCycleLR learning rate (lr0 * lrf)
-            'momentum': (0.7, 0.98, 0.3),  # SGD momentum/Adam beta1
-            'weight_decay': (0.0, 0.001),  # optimizer weight decay 5e-4
-            'warmup_epochs': (0.0, 5.0),  # warmup epochs (fractions ok)
-            'warmup_momentum': (0.0, 0.95),  # warmup initial momentum
-            'box': (1.0, 20.0),  # box loss gain
-            'cls': (0.2, 4.0),  # cls loss gain (scale with pixels)
-            'dfl': (0.4, 6.0),  # dfl loss gain
-            'hsv_h': (0.0, 0.1),  # image HSV-Hue augmentation (fraction)
-            'hsv_s': (0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
-            'hsv_v': (0.0, 0.9),  # image HSV-Value augmentation (fraction)
-            'degrees': (0.0, 45.0),  # image rotation (+/- deg)
-            'translate': (0.0, 0.9),  # image translation (+/- fraction)
-            'scale': (0.0, 0.95),  # image scale (+/- gain)
-            'shear': (0.0, 10.0),  # image shear (+/- deg)
-            'perspective': (0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
-            'flipud': (0.0, 1.0),  # image flip up-down (probability)
-            'fliplr': (0.0, 1.0),  # image flip left-right (probability)
-            'mosaic': (0.0, 1.0),  # image mixup (probability)
-            'mixup': (0.0, 1.0),  # image mixup (probability)
-            'copy_paste': (0.0, 1.0)}  # segment copy-paste (probability)
+            "lr0": (1e-5, 1e-1),  # initial learning rate (i.e. SGD=1E-2, Adam=1E-3)
+            "lrf": (0.0001, 0.1),  # final OneCycleLR learning rate (lr0 * lrf)
+            "momentum": (0.7, 0.98, 0.3),  # SGD momentum/Adam beta1
+            "weight_decay": (0.0, 0.001),  # optimizer weight decay 5e-4
+            "warmup_epochs": (0.0, 5.0),  # warmup epochs (fractions ok)
+            "warmup_momentum": (0.0, 0.95),  # warmup initial momentum
+            "box": (1.0, 20.0),  # box loss gain
+            "cls": (0.2, 4.0),  # cls loss gain (scale with pixels)
+            "dfl": (0.4, 6.0),  # dfl loss gain
+            "hsv_h": (0.0, 0.1),  # image HSV-Hue augmentation (fraction)
+            "hsv_s": (0.0, 0.9),  # image HSV-Saturation augmentation (fraction)
+            "hsv_v": (0.0, 0.9),  # image HSV-Value augmentation (fraction)
+            "degrees": (0.0, 45.0),  # image rotation (+/- deg)
+            "translate": (0.0, 0.9),  # image translation (+/- fraction)
+            "scale": (0.0, 0.95),  # image scale (+/- gain)
+            "shear": (0.0, 10.0),  # image shear (+/- deg)
+            "perspective": (0.0, 0.001),  # image perspective (+/- fraction), range 0-0.001
+            "flipud": (0.0, 1.0),  # image flip up-down (probability)
+            "fliplr": (0.0, 1.0),  # image flip left-right (probability)
+            "mosaic": (0.0, 1.0),  # image mixup (probability)
+            "mixup": (0.0, 1.0),  # image mixup (probability)
+            "copy_paste": (0.0, 1.0),  # segment copy-paste (probability)
+        }
         self.args = get_cfg(overrides=args)
-        self.tune_dir = get_save_dir(self.args, name='tune')
-        self.tune_csv = self.tune_dir / 'tune_results.csv'
+        self.tune_dir = get_save_dir(self.args, name="tune")
+        self.tune_csv = self.tune_dir / "tune_results.csv"
         self.callbacks = _callbacks or callbacks.get_default_callbacks()
-        self.prefix = colorstr('Tuner: ')
+        self.prefix = colorstr("Tuner: ")
         callbacks.add_integration_callbacks(self)
-        LOGGER.info(f"{self.prefix}Initialized Tuner instance with 'tune_dir={self.tune_dir}'\n"
-                    f'{self.prefix}💡 Learn about tuning at https://docs.ultralytics.com/guides/hyperparameter-tuning')
+        LOGGER.info(
+            f"{self.prefix}Initialized Tuner instance with 'tune_dir={self.tune_dir}'\n"
+            f"{self.prefix}💡 Learn about tuning at https://docs.ultralytics.com/guides/hyperparameter-tuning"
+        )

-    def _mutate(self, parent='single', n=5, mutation=0.8, sigma=0.2):
+    def _mutate(self, parent="single", n=5, mutation=0.8, sigma=0.2):
         """
         Mutates the hyperparameters based on bounds and scaling factors specified in `self.space`.

@@ -121,15 +124,15 @@ class Tuner:
         """
         if self.tune_csv.exists():  # if CSV file exists: select best hyps and mutate
             # Select parent(s)
-            x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=',', skiprows=1)
+            x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
             fitness = x[:, 0]  # first column
             n = min(n, len(x))  # number of previous results to consider
             x = x[np.argsort(-fitness)][:n]  # top n mutations
-            w = x[:, 0] - x[:, 0].min() + 1E-6  # weights (sum > 0)
-            if parent == 'single' or len(x) == 1:
+            w = x[:, 0] - x[:, 0].min() + 1e-6  # weights (sum > 0)
+            if parent == "single" or len(x) == 1:
                 # x = x[random.randint(0, n - 1)]  # random selection
                 x = x[random.choices(range(n), weights=w)[0]]  # weighted selection
-            elif parent == 'weighted':
+            elif parent == "weighted":
                 x = (x * w.reshape(n, 1)).sum(0) / w.sum()  # weighted combination

             # Mutate
@@ -174,44 +177,44 @@ class Tuner:

         t0 = time.time()
         best_save_dir, best_metrics = None, None
-        (self.tune_dir / 'weights').mkdir(parents=True, exist_ok=True)
+        (self.tune_dir / "weights").mkdir(parents=True, exist_ok=True)
         for i in range(iterations):
             # Mutate hyperparameters
             mutated_hyp = self._mutate()
-            LOGGER.info(f'{self.prefix}Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}')
+            LOGGER.info(f"{self.prefix}Starting iteration {i + 1}/{iterations} with hyperparameters: {mutated_hyp}")

             metrics = {}
             train_args = {**vars(self.args), **mutated_hyp}
             save_dir = get_save_dir(get_cfg(train_args))
-            weights_dir = save_dir / 'weights'
-            ckpt_file = weights_dir / ('best.pt' if (weights_dir / 'best.pt').exists() else 'last.pt')
+            weights_dir = save_dir / "weights"
+            ckpt_file = weights_dir / ("best.pt" if (weights_dir / "best.pt").exists() else "last.pt")
             try:
                 # Train YOLO model with mutated hyperparameters (run in subprocess to avoid dataloader hang)
-                cmd = ['yolo', 'train', *(f'{k}={v}' for k, v in train_args.items())]
+                cmd = ["yolo", "train", *(f"{k}={v}" for k, v in train_args.items())]
                 return_code = subprocess.run(cmd, check=True).returncode
-                metrics = torch.load(ckpt_file)['train_metrics']
-                assert return_code == 0, 'training failed'
+                metrics = torch.load(ckpt_file)["train_metrics"]
+                assert return_code == 0, "training failed"

             except Exception as e:
-                LOGGER.warning(f'WARNING ❌️ training failure for hyperparameter tuning iteration {i + 1}\n{e}')
+                LOGGER.warning(f"WARNING ❌️ training failure for hyperparameter tuning iteration {i + 1}\n{e}")

             # Save results and mutated_hyp to CSV
-            fitness = metrics.get('fitness', 0.0)
+            fitness = metrics.get("fitness", 0.0)
             log_row = [round(fitness, 5)] + [mutated_hyp[k] for k in self.space.keys()]
-            headers = '' if self.tune_csv.exists() else (','.join(['fitness'] + list(self.space.keys())) + '\n')
-            with open(self.tune_csv, 'a') as f:
-                f.write(headers + ','.join(map(str, log_row)) + '\n')
+            headers = "" if self.tune_csv.exists() else (",".join(["fitness"] + list(self.space.keys())) + "\n")
+            with open(self.tune_csv, "a") as f:
+                f.write(headers + ",".join(map(str, log_row)) + "\n")

             # Get best results
-            x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=',', skiprows=1)
+            x = np.loadtxt(self.tune_csv, ndmin=2, delimiter=",", skiprows=1)
             fitness = x[:, 0]  # first column
             best_idx = fitness.argmax()
             best_is_current = best_idx == i
             if best_is_current:
                 best_save_dir = save_dir
                 best_metrics = {k: round(v, 5) for k, v in metrics.items()}
-                for ckpt in weights_dir.glob('*.pt'):
-                    shutil.copy2(ckpt, self.tune_dir / 'weights')
+                for ckpt in weights_dir.glob("*.pt"):
+                    shutil.copy2(ckpt, self.tune_dir / "weights")
             elif cleanup:
                 shutil.rmtree(ckpt_file.parent)  # remove iteration weights/ dir to reduce storage space

@@ -219,15 +222,19 @@ class Tuner:
             plot_tune_results(self.tune_csv)

             # Save and print tune results
-            header = (f'{self.prefix}{i + 1}/{iterations} iterations complete ✅ ({time.time() - t0:.2f}s)\n'
-                      f'{self.prefix}Results saved to {colorstr("bold", self.tune_dir)}\n'
-                      f'{self.prefix}Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n'
-                      f'{self.prefix}Best fitness metrics are {best_metrics}\n'
-                      f'{self.prefix}Best fitness model is {best_save_dir}\n'
-                      f'{self.prefix}Best fitness hyperparameters are printed below.\n')
-            LOGGER.info('\n' + header)
+            header = (
+                f'{self.prefix}{i + 1}/{iterations} iterations complete ✅ ({time.time() - t0:.2f}s)\n'
+                f'{self.prefix}Results saved to {colorstr("bold", self.tune_dir)}\n'
+                f'{self.prefix}Best fitness={fitness[best_idx]} observed at iteration {best_idx + 1}\n'
+                f'{self.prefix}Best fitness metrics are {best_metrics}\n'
+                f'{self.prefix}Best fitness model is {best_save_dir}\n'
+                f'{self.prefix}Best fitness hyperparameters are printed below.\n'
+            )
+            LOGGER.info("\n" + header)
             data = {k: float(x[best_idx, i + 1]) for i, k in enumerate(self.space.keys())}
-            yaml_save(self.tune_dir / 'best_hyperparameters.yaml',
-                      data=data,
-                      header=remove_colorstr(header.replace(self.prefix, '# ')) + '\n')
-            yaml_print(self.tune_dir / 'best_hyperparameters.yaml')
+            yaml_save(
+                self.tune_dir / "best_hyperparameters.yaml",
+                data=data,
+                header=remove_colorstr(header.replace(self.prefix, "# ")) + "\n",
+            )
+            yaml_print(self.tune_dir / "best_hyperparameters.yaml")
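
Aside from the switch to double quotes, trailing commas, and multi-line call formatting adopted across this release, the evolution logic in tuner.py is unchanged. As a reading aid, here is a minimal, self-contained sketch of the selection-and-mutation scheme `_mutate()` implements: pick a parent from the top-n CSV rows weighted by fitness, perturb each hyperparameter with Gaussian noise, and clip to the search-space bounds. The two-entry `space`, the `mutate()` helper, and the RNG details are illustrative simplifications, not the shipped code:

```python
import random

import numpy as np

# Simplified sketch of Tuner._mutate(): weighted parent selection + clipped Gaussian mutation
space = {"lr0": (1e-5, 1e-1), "momentum": (0.7, 0.98, 0.3)}  # key: (min, max, gain(optional))


def mutate(results: np.ndarray, n: int = 5, mutation: float = 0.8, sigma: float = 0.2) -> dict:
    """`results` holds fitness in column 0 and one column per key in `space`."""
    results = results[np.argsort(-results[:, 0])][: min(n, len(results))]  # keep top-n rows by fitness
    w = results[:, 0] - results[:, 0].min() + 1e-6  # selection weights (sum > 0)
    parent = results[random.choices(range(len(results)), weights=w)[0]]  # fitness-weighted pick
    rng = np.random.default_rng()
    hyp = {}
    for i, (k, bounds) in enumerate(space.items()):
        gain = bounds[2] if len(bounds) == 3 else 1.0  # optional per-key gain
        v = parent[i + 1]
        if rng.random() < mutation:  # mutate each gene with probability `mutation`
            v *= 1 + rng.normal(0.0, sigma) * gain  # Gaussian perturbation
        hyp[k] = float(np.clip(v, bounds[0], bounds[1]))  # respect search-space bounds
    return hyp


# Example: two previous iterations logged as [fitness, lr0, momentum]
print(mutate(np.array([[0.50, 0.01, 0.9], [0.45, 0.02, 0.8]])))
```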
ultralytics/engine/validator.py CHANGED
@@ -89,10 +89,10 @@ class BaseValidator:
         self.nc = None
         self.iouv = None
         self.jdict = None
-        self.speed = {'preprocess': 0.0, 'inference': 0.0, 'loss': 0.0, 'postprocess': 0.0}
+        self.speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}

         self.save_dir = save_dir or get_save_dir(self.args)
-        (self.save_dir / 'labels' if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
+        (self.save_dir / "labels" if self.args.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
         if self.args.conf is None:
             self.args.conf = 0.001  # default conf=0.001
         self.args.imgsz = check_imgsz(self.args.imgsz, max_dim=1)
@@ -110,7 +110,7 @@ class BaseValidator:
         if self.training:
             self.device = trainer.device
             self.data = trainer.data
-            self.args.half = self.device.type != 'cpu'  # force FP16 val during training
+            self.args.half = self.device.type != "cpu"  # force FP16 val during training
             model = trainer.ema.ema or trainer.model
             model = model.half() if self.args.half else model.float()
             # self.model = model
@@ -119,11 +119,13 @@ class BaseValidator:
             model.eval()
         else:
             callbacks.add_integration_callbacks(self)
-            model = AutoBackend(model or self.args.model,
-                                device=select_device(self.args.device, self.args.batch),
-                                dnn=self.args.dnn,
-                                data=self.args.data,
-                                fp16=self.args.half)
+            model = AutoBackend(
+                model or self.args.model,
+                device=select_device(self.args.device, self.args.batch),
+                dnn=self.args.dnn,
+                data=self.args.data,
+                fp16=self.args.half,
+            )
             # self.model = model
             self.device = model.device  # update device
             self.args.half = model.fp16  # update half
@@ -133,16 +135,16 @@ class BaseValidator:
                 self.args.batch = model.batch_size
             elif not pt and not jit:
                 self.args.batch = 1  # export.py models default to batch-size 1
-                LOGGER.info(f'Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models')
+                LOGGER.info(f"Forcing batch=1 square inference (1,3,{imgsz},{imgsz}) for non-PyTorch models")

-            if str(self.args.data).split('.')[-1] in ('yaml', 'yml'):
+            if str(self.args.data).split(".")[-1] in ("yaml", "yml"):
                 self.data = check_det_dataset(self.args.data)
-            elif self.args.task == 'classify':
+            elif self.args.task == "classify":
                 self.data = check_cls_dataset(self.args.data, split=self.args.split)
             else:
                 raise FileNotFoundError(emojis(f"Dataset '{self.args.data}' for task={self.args.task} not found ❌"))

-            if self.device.type in ('cpu', 'mps'):
+            if self.device.type in ("cpu", "mps"):
                 self.args.workers = 0  # faster CPU val as time dominated by inference, not dataloading
             if not pt:
                 self.args.rect = False
@@ -152,13 +154,13 @@ class BaseValidator:
             model.eval()
             model.warmup(imgsz=(1 if pt else self.args.batch, 3, imgsz, imgsz))  # warmup

-        self.run_callbacks('on_val_start')
+        self.run_callbacks("on_val_start")
         dt = Profile(), Profile(), Profile(), Profile()
         bar = TQDM(self.dataloader, desc=self.get_desc(), total=len(self.dataloader))
         self.init_metrics(de_parallel(model))
         self.jdict = []  # empty before each val
         for batch_i, batch in enumerate(bar):
-            self.run_callbacks('on_val_batch_start')
+            self.run_callbacks("on_val_batch_start")
             self.batch_i = batch_i
             # Preprocess
             with dt[0]:
@@ -166,7 +168,7 @@ class BaseValidator:

             # Inference
             with dt[1]:
-                preds = model(batch['img'], augment=augment)
+                preds = model(batch["img"], augment=augment)

             # Loss
             with dt[2]:
@@ -182,23 +184,25 @@ class BaseValidator:
                 self.plot_val_samples(batch, batch_i)
                 self.plot_predictions(batch, preds, batch_i)

-            self.run_callbacks('on_val_batch_end')
+            self.run_callbacks("on_val_batch_end")
         stats = self.get_stats()
         self.check_stats(stats)
-        self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1E3 for x in dt)))
+        self.speed = dict(zip(self.speed.keys(), (x.t / len(self.dataloader.dataset) * 1e3 for x in dt)))
         self.finalize_metrics()
         self.print_results()
-        self.run_callbacks('on_val_end')
+        self.run_callbacks("on_val_end")
         if self.training:
             model.float()
-            results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix='val')}
+            results = {**stats, **trainer.label_loss_items(self.loss.cpu() / len(self.dataloader), prefix="val")}
             return {k: round(float(v), 5) for k, v in results.items()}  # return results as 5 decimal place floats
         else:
-            LOGGER.info('Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image' %
-                        tuple(self.speed.values()))
+            LOGGER.info(
+                "Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image"
+                % tuple(self.speed.values())
+            )
             if self.args.save_json and self.jdict:
-                with open(str(self.save_dir / 'predictions.json'), 'w') as f:
-                    LOGGER.info(f'Saving {f.name}...')
+                with open(str(self.save_dir / "predictions.json"), "w") as f:
+                    LOGGER.info(f"Saving {f.name}...")
                     json.dump(self.jdict, f)  # flatten and save
                 stats = self.eval_json(stats)  # update stats
             if self.args.plots or self.args.save_json:
@@ -228,6 +232,7 @@ class BaseValidator:
             if use_scipy:
                 # WARNING: known issue that reduces mAP in https://github.com/ultralytics/ultralytics/pull/4708
                 import scipy  # scope import to avoid importing for all commands
+
                 cost_matrix = iou * (iou >= threshold)
                 if cost_matrix.any():
                     labels_idx, detections_idx = scipy.optimize.linear_sum_assignment(cost_matrix, maximize=True)
@@ -257,11 +262,11 @@ class BaseValidator:

     def get_dataloader(self, dataset_path, batch_size):
         """Get data loader from dataset path and batch size."""
-        raise NotImplementedError('get_dataloader function not implemented for this validator')
+        raise NotImplementedError("get_dataloader function not implemented for this validator")

     def build_dataset(self, img_path):
         """Build dataset."""
-        raise NotImplementedError('build_dataset function not implemented in validator')
+        raise NotImplementedError("build_dataset function not implemented in validator")

     def preprocess(self, batch):
         """Preprocesses an input batch."""
@@ -306,7 +311,7 @@ class BaseValidator:

     def on_plot(self, name, data=None):
         """Registers plots (e.g. to be consumed in callbacks)"""
-        self.plots[Path(name)] = {'data': data, 'timestamp': time.time()}
+        self.plots[Path(name)] = {"data": data, "timestamp": time.time()}

     # TODO: may need to put these following functions into callback
     def plot_val_samples(self, batch, ni):
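
One computation in these hunks deserves a gloss: each `Profile` context manager accumulates the total seconds spent in its stage in `.t`, and the changed line `x.t / len(self.dataloader.dataset) * 1e3` converts those stage totals into milliseconds per image for the final speed report. A toy reproduction of that arithmetic, using a simplified `Profile` stand-in (an assumption for illustration, not the ultralytics class):

```python
import time


class Profile:
    """Simplified stand-in for ultralytics' Profile: accumulates elapsed seconds in .t."""

    def __init__(self):
        self.t = 0.0

    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, *exc):
        self.t += time.perf_counter() - self.start


speed = {"preprocess": 0.0, "inference": 0.0, "loss": 0.0, "postprocess": 0.0}
dt = Profile(), Profile(), Profile(), Profile()
num_images = 8  # plays the role of len(self.dataloader.dataset)
for _ in range(num_images):
    with dt[0]:
        time.sleep(0.001)  # pretend preprocess work
    with dt[1]:
        time.sleep(0.002)  # pretend inference work

# Same conversion as the diff: stage totals in seconds -> milliseconds per image
speed = dict(zip(speed.keys(), (x.t / num_images * 1e3 for x in dt)))
print("Speed: %.1fms preprocess, %.1fms inference, %.1fms loss, %.1fms postprocess per image" % tuple(speed.values()))
```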
ultralytics/hub/__init__.py CHANGED
@@ -1,28 +1,50 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

 import requests
+from hub_sdk import HUB_API_ROOT, HUB_WEB_ROOT, HUBClient

 from ultralytics.data.utils import HUBDatasetStats
 from ultralytics.hub.auth import Auth
-from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX
+from ultralytics.hub.utils import PREFIX
 from ultralytics.utils import LOGGER, SETTINGS


-def login(api_key=''):
+def login(api_key: str = None, save=True) -> bool:
     """
     Log in to the Ultralytics HUB API using the provided API key.

+    The session is not stored; a new session is created when needed using the saved SETTINGS or the HUB_API_KEY
+    environment variable if successfully authenticated.
+
     Args:
-        api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id
+        api_key (str, optional): The API key to use for authentication. If not provided, it will be retrieved from
+            SETTINGS or HUB_API_KEY environment variable.
+        save (bool, optional): Whether to save the API key to SETTINGS if authentication is successful.

-    Example:
-        ```python
-        from ultralytics import hub
-
-        hub.login('API_KEY')
-        ```
-    """
-    Auth(api_key, verbose=True)
+    Returns:
+        bool: True if authentication is successful, False otherwise.
+    """
+    api_key_url = f"{HUB_WEB_ROOT}/settings?tab=api+keys"  # set the redirect URL
+    saved_key = SETTINGS.get("api_key")
+    active_key = api_key or saved_key
+    credentials = {"api_key": active_key} if active_key and active_key != "" else None  # set credentials
+
+    client = HUBClient(credentials)  # initialize HUBClient
+
+    if client.authenticated:
+        # Successfully authenticated with HUB
+
+        if save and client.api_key != saved_key:
+            SETTINGS.update({"api_key": client.api_key})  # update settings with valid API key
+
+        # Set message based on whether key was provided or retrieved from settings
+        log_message = (
+            "New authentication successful ✅" if client.api_key == api_key or not credentials else "Authenticated ✅"
+        )
+        LOGGER.info(f"{PREFIX}{log_message}")
+
+        return True
+    else:
+        # Failed to authenticate with HUB
+        LOGGER.info(f"{PREFIX}Retrieve API key from {api_key_url}")
+        return False


 def logout():
@@ -36,49 +58,50 @@ def logout():
         hub.logout()
         ```
     """
-    SETTINGS['api_key'] = ''
+    SETTINGS["api_key"] = ""
     SETTINGS.save()
     LOGGER.info(f"{PREFIX}logged out ✅. To log in again, use 'yolo hub login'.")


-def reset_model(model_id=''):
+def reset_model(model_id=""):
     """Reset a trained model to an untrained state."""
-    r = requests.post(f'{HUB_API_ROOT}/model-reset', json={'apiKey': Auth().api_key, 'modelId': model_id})
+    r = requests.post(f"{HUB_API_ROOT}/model-reset", json={"modelId": model_id}, headers={"x-api-key": Auth().api_key})
     if r.status_code == 200:
-        LOGGER.info(f'{PREFIX}Model reset successfully')
+        LOGGER.info(f"{PREFIX}Model reset successfully")
         return
-    LOGGER.warning(f'{PREFIX}Model reset failure {r.status_code} {r.reason}')
+    LOGGER.warning(f"{PREFIX}Model reset failure {r.status_code} {r.reason}")


 def export_fmts_hub():
     """Returns a list of HUB-supported export formats."""
     from ultralytics.engine.exporter import export_formats
-    return list(export_formats()['Argument'][1:]) + ['ultralytics_tflite', 'ultralytics_coreml']
+
+    return list(export_formats()["Argument"][1:]) + ["ultralytics_tflite", "ultralytics_coreml"]


-def export_model(model_id='', format='torchscript'):
+def export_model(model_id="", format="torchscript"):
     """Export a model to all formats."""
     assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
-    r = requests.post(f'{HUB_API_ROOT}/v1/models/{model_id}/export',
-                      json={'format': format},
-                      headers={'x-api-key': Auth().api_key})
-    assert r.status_code == 200, f'{PREFIX}{format} export failure {r.status_code} {r.reason}'
-    LOGGER.info(f'{PREFIX}{format} export started ✅')
+    r = requests.post(
+        f"{HUB_API_ROOT}/v1/models/{model_id}/export", json={"format": format}, headers={"x-api-key": Auth().api_key}
+    )
+    assert r.status_code == 200, f"{PREFIX}{format} export failure {r.status_code} {r.reason}"
+    LOGGER.info(f"{PREFIX}{format} export started ✅")


-def get_export(model_id='', format='torchscript'):
+def get_export(model_id="", format="torchscript"):
     """Get an exported model dictionary with download URL."""
     assert format in export_fmts_hub(), f"Unsupported export format '{format}', valid formats are {export_fmts_hub()}"
-    r = requests.post(f'{HUB_API_ROOT}/get-export',
-                      json={
-                          'apiKey': Auth().api_key,
-                          'modelId': model_id,
-                          'format': format})
-    assert r.status_code == 200, f'{PREFIX}{format} get_export failure {r.status_code} {r.reason}'
+    r = requests.post(
+        f"{HUB_API_ROOT}/get-export",
+        json={"apiKey": Auth().api_key, "modelId": model_id, "format": format},
+        headers={"x-api-key": Auth().api_key},
+    )
+    assert r.status_code == 200, f"{PREFIX}{format} get_export failure {r.status_code} {r.reason}"
     return r.json()


-def check_dataset(path='', task='detect'):
+def check_dataset(path="", task="detect"):
     """
     Function for error-checking HUB dataset Zip file before upload. It checks a dataset for errors before it is uploaded
     to the HUB. Usage examples are given below.
@@ -97,4 +120,4 @@ def check_dataset(path='', task='detect'):
         ```
     """
     HUBDatasetStats(path=path, task=task).get_json()
-    LOGGER.info(f'Checks completed correctly ✅. Upload this dataset to {HUB_WEB_ROOT}/datasets/.')
+    LOGGER.info(f"Checks completed correctly ✅. Upload this dataset to {HUB_WEB_ROOT}/datasets/.")
ultralytics/hub/auth.py CHANGED
@@ -1,11 +1,12 @@
 # Ultralytics YOLO 🚀, AGPL-3.0 license

 import requests
+from hub_sdk import HUB_API_ROOT, HUB_WEB_ROOT

-from ultralytics.hub.utils import HUB_API_ROOT, HUB_WEB_ROOT, PREFIX, request_with_credentials
+from ultralytics.hub.utils import PREFIX, request_with_credentials
 from ultralytics.utils import LOGGER, SETTINGS, emojis, is_colab

-API_KEY_URL = f'{HUB_WEB_ROOT}/settings?tab=api+keys'
+API_KEY_URL = f"{HUB_WEB_ROOT}/settings?tab=api+keys"


 class Auth:
@@ -22,9 +23,10 @@ class Auth:
         api_key (str or bool): API key for authentication, initialized as False.
         model_key (bool): Placeholder for model key, initialized as False.
     """
+
     id_token = api_key = model_key = False

-    def __init__(self, api_key='', verbose=False):
+    def __init__(self, api_key="", verbose=False):
         """
         Initialize the Auth class with an optional API key.

@@ -32,18 +34,18 @@ class Auth:
             api_key (str, optional): May be an API key or a combination API key and model ID, i.e. key_id
         """
         # Split the input API key in case it contains a combined key_model and keep only the API key part
-        api_key = api_key.split('_')[0]
+        api_key = api_key.split("_")[0]

         # Set API key attribute as value passed or SETTINGS API key if none passed
-        self.api_key = api_key or SETTINGS.get('api_key', '')
+        self.api_key = api_key or SETTINGS.get("api_key", "")

         # If an API key is provided
         if self.api_key:
             # If the provided API key matches the API key in the SETTINGS
-            if self.api_key == SETTINGS.get('api_key'):
+            if self.api_key == SETTINGS.get("api_key"):
                 # Log that the user is already logged in
                 if verbose:
-                    LOGGER.info(f'{PREFIX}Authenticated ✅')
+                    LOGGER.info(f"{PREFIX}Authenticated ✅")
                 return
             else:
                 # Attempt to authenticate with the provided API key
@@ -58,12 +60,12 @@ class Auth:

         # Update SETTINGS with the new API key after successful authentication
         if success:
-            SETTINGS.update({'api_key': self.api_key})
+            SETTINGS.update({"api_key": self.api_key})
             # Log that the new login was successful
             if verbose:
-                LOGGER.info(f'{PREFIX}New authentication successful ✅')
+                LOGGER.info(f"{PREFIX}New authentication successful ✅")
         elif verbose:
-            LOGGER.info(f'{PREFIX}Retrieve API key from {API_KEY_URL}')
+            LOGGER.info(f"{PREFIX}Retrieve API key from {API_KEY_URL}")

     def request_api_key(self, max_attempts=3):
         """
@@ -72,13 +74,14 @@ class Auth:
         Returns the model ID.
         """
         import getpass
+
         for attempts in range(max_attempts):
-            LOGGER.info(f'{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}')
-            input_key = getpass.getpass(f'Enter API key from {API_KEY_URL} ')
-            self.api_key = input_key.split('_')[0]  # remove model id if present
+            LOGGER.info(f"{PREFIX}Login. Attempt {attempts + 1} of {max_attempts}")
+            input_key = getpass.getpass(f"Enter API key from {API_KEY_URL} ")
+            self.api_key = input_key.split("_")[0]  # remove model id if present
             if self.authenticate():
                 return True
-        raise ConnectionError(emojis(f'{PREFIX}Failed to authenticate ❌'))
+        raise ConnectionError(emojis(f"{PREFIX}Failed to authenticate ❌"))

     def authenticate(self) -> bool:
         """
@@ -89,14 +92,14 @@ class Auth:
         """
         try:
             if header := self.get_auth_header():
-                r = requests.post(f'{HUB_API_ROOT}/v1/auth', headers=header)
-                if not r.json().get('success', False):
-                    raise ConnectionError('Unable to authenticate.')
+                r = requests.post(f"{HUB_API_ROOT}/v1/auth", headers=header)
+                if not r.json().get("success", False):
+                    raise ConnectionError("Unable to authenticate.")
                 return True
-            raise ConnectionError('User has not authenticated locally.')
+            raise ConnectionError("User has not authenticated locally.")
         except ConnectionError:
             self.id_token = self.api_key = False  # reset invalid
-            LOGGER.warning(f'{PREFIX}Invalid API key ⚠️')
+            LOGGER.warning(f"{PREFIX}Invalid API key ⚠️")
             return False

     def auth_with_cookies(self) -> bool:
@@ -110,12 +113,12 @@ class Auth:
         if not is_colab():
             return False  # Currently only works with Colab
         try:
-            authn = request_with_credentials(f'{HUB_API_ROOT}/v1/auth/auto')
-            if authn.get('success', False):
-                self.id_token = authn.get('data', {}).get('idToken', None)
+            authn = request_with_credentials(f"{HUB_API_ROOT}/v1/auth/auto")
+            if authn.get("success", False):
+                self.id_token = authn.get("data", {}).get("idToken", None)
                 self.authenticate()
                 return True
-            raise ConnectionError('Unable to fetch browser authentication details.')
+            raise ConnectionError("Unable to fetch browser authentication details.")
         except ConnectionError:
             self.id_token = False  # reset invalid
             return False
@@ -128,7 +131,7 @@ class Auth:
             (dict): The authentication header if id_token or API key is set, None otherwise.
         """
         if self.id_token:
-            return {'authorization': f'Bearer {self.id_token}'}
+            return {"authorization": f"Bearer {self.id_token}"}
         elif self.api_key:
-            return {'x-api-key': self.api_key}
+            return {"x-api-key": self.api_key}
         # else returns None
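
auth.py keeps the same two-tier credential scheme, now importing `HUB_API_ROOT`/`HUB_WEB_ROOT` from `hub_sdk` instead of `ultralytics.hub.utils`. For illustration only, a minimal mirror of the fallback logic in `get_auth_header()`, which `reset_model()` and `export_model()` above rely on via the `x-api-key` branch:

```python
def auth_header(id_token=None, api_key=None):
    """Mirror of Auth.get_auth_header(): prefer a Bearer id_token, else fall back to x-api-key."""
    if id_token:
        return {"authorization": f"Bearer {id_token}"}
    elif api_key:
        return {"x-api-key": api_key}
    return None  # callers treat None as "not authenticated locally"


assert auth_header(api_key="abc123") == {"x-api-key": "abc123"}
assert auth_header(id_token="tok")["authorization"] == "Bearer tok"
assert auth_header() is None
```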