dgenerate-ultralytics-headless 8.3.218__py3-none-any.whl → 8.3.221__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77) hide show
  1. {dgenerate_ultralytics_headless-8.3.218.dist-info → dgenerate_ultralytics_headless-8.3.221.dist-info}/METADATA +1 -1
  2. {dgenerate_ultralytics_headless-8.3.218.dist-info → dgenerate_ultralytics_headless-8.3.221.dist-info}/RECORD +77 -77
  3. tests/__init__.py +5 -7
  4. tests/conftest.py +3 -7
  5. tests/test_cli.py +9 -2
  6. tests/test_engine.py +1 -1
  7. tests/test_exports.py +37 -9
  8. tests/test_integrations.py +4 -4
  9. tests/test_python.py +37 -44
  10. tests/test_solutions.py +154 -145
  11. ultralytics/__init__.py +1 -1
  12. ultralytics/cfg/__init__.py +7 -5
  13. ultralytics/cfg/default.yaml +1 -1
  14. ultralytics/data/__init__.py +4 -4
  15. ultralytics/data/augment.py +10 -10
  16. ultralytics/data/base.py +1 -1
  17. ultralytics/data/build.py +1 -1
  18. ultralytics/data/converter.py +3 -3
  19. ultralytics/data/dataset.py +3 -3
  20. ultralytics/data/loaders.py +2 -2
  21. ultralytics/data/utils.py +2 -2
  22. ultralytics/engine/exporter.py +73 -20
  23. ultralytics/engine/model.py +1 -1
  24. ultralytics/engine/predictor.py +1 -0
  25. ultralytics/engine/trainer.py +5 -3
  26. ultralytics/engine/tuner.py +4 -4
  27. ultralytics/hub/__init__.py +9 -7
  28. ultralytics/hub/utils.py +2 -2
  29. ultralytics/models/__init__.py +1 -1
  30. ultralytics/models/fastsam/__init__.py +1 -1
  31. ultralytics/models/fastsam/predict.py +10 -16
  32. ultralytics/models/nas/__init__.py +1 -1
  33. ultralytics/models/rtdetr/__init__.py +1 -1
  34. ultralytics/models/sam/__init__.py +1 -1
  35. ultralytics/models/sam/amg.py +2 -2
  36. ultralytics/models/sam/modules/blocks.py +1 -1
  37. ultralytics/models/sam/modules/transformer.py +1 -1
  38. ultralytics/models/sam/predict.py +1 -1
  39. ultralytics/models/yolo/__init__.py +1 -1
  40. ultralytics/models/yolo/pose/__init__.py +1 -1
  41. ultralytics/models/yolo/segment/val.py +1 -1
  42. ultralytics/models/yolo/yoloe/__init__.py +7 -7
  43. ultralytics/nn/__init__.py +7 -7
  44. ultralytics/nn/autobackend.py +32 -5
  45. ultralytics/nn/modules/__init__.py +60 -60
  46. ultralytics/nn/modules/block.py +26 -26
  47. ultralytics/nn/modules/conv.py +7 -7
  48. ultralytics/nn/modules/head.py +1 -1
  49. ultralytics/nn/modules/transformer.py +7 -7
  50. ultralytics/nn/modules/utils.py +1 -1
  51. ultralytics/nn/tasks.py +3 -3
  52. ultralytics/solutions/__init__.py +12 -12
  53. ultralytics/solutions/object_counter.py +3 -6
  54. ultralytics/solutions/queue_management.py +1 -1
  55. ultralytics/solutions/similarity_search.py +3 -3
  56. ultralytics/trackers/__init__.py +1 -1
  57. ultralytics/trackers/byte_tracker.py +2 -2
  58. ultralytics/trackers/utils/matching.py +1 -1
  59. ultralytics/utils/__init__.py +2 -2
  60. ultralytics/utils/benchmarks.py +4 -4
  61. ultralytics/utils/callbacks/comet.py +2 -2
  62. ultralytics/utils/checks.py +2 -2
  63. ultralytics/utils/downloads.py +2 -2
  64. ultralytics/utils/export/__init__.py +1 -1
  65. ultralytics/utils/files.py +1 -1
  66. ultralytics/utils/git.py +1 -1
  67. ultralytics/utils/logger.py +1 -1
  68. ultralytics/utils/metrics.py +13 -9
  69. ultralytics/utils/ops.py +8 -8
  70. ultralytics/utils/plotting.py +2 -1
  71. ultralytics/utils/torch_utils.py +5 -4
  72. ultralytics/utils/triton.py +2 -2
  73. ultralytics/utils/tuner.py +4 -2
  74. {dgenerate_ultralytics_headless-8.3.218.dist-info → dgenerate_ultralytics_headless-8.3.221.dist-info}/WHEEL +0 -0
  75. {dgenerate_ultralytics_headless-8.3.218.dist-info → dgenerate_ultralytics_headless-8.3.221.dist-info}/entry_points.txt +0 -0
  76. {dgenerate_ultralytics_headless-8.3.218.dist-info → dgenerate_ultralytics_headless-8.3.221.dist-info}/licenses/LICENSE +0 -0
  77. {dgenerate_ultralytics_headless-8.3.218.dist-info → dgenerate_ultralytics_headless-8.3.221.dist-info}/top_level.txt +0 -0
@@ -78,7 +78,7 @@ def iou_distance(atracks: list, btracks: list) -> np.ndarray:
78
78
  >>> btracks = [np.array([5, 5, 15, 15]), np.array([25, 25, 35, 35])]
79
79
  >>> cost_matrix = iou_distance(atracks, btracks)
80
80
  """
81
- if atracks and isinstance(atracks[0], np.ndarray) or btracks and isinstance(btracks[0], np.ndarray):
81
+ if (atracks and isinstance(atracks[0], np.ndarray)) or (btracks and isinstance(btracks[0], np.ndarray)):
82
82
  atlbrs = atracks
83
83
  btlbrs = btracks
84
84
  else:
@@ -260,7 +260,7 @@ class SimpleClass:
260
260
  # Display only the module and class name for subclasses
261
261
  s = f"{a}: {v.__module__}.{v.__class__.__name__} object"
262
262
  else:
263
- s = f"{a}: {repr(v)}"
263
+ s = f"{a}: {v!r}"
264
264
  attr.append(s)
265
265
  return f"{self.__module__}.{self.__class__.__name__} object with attributes:\n\n" + "\n".join(attr)
266
266
 
@@ -1137,7 +1137,7 @@ def set_sentry():
1137
1137
  return
1138
1138
  # If sentry_sdk package is not installed then return and do not use Sentry
1139
1139
  try:
1140
- import sentry_sdk # noqa
1140
+ import sentry_sdk
1141
1141
  except ImportError:
1142
1142
  return
1143
1143
 
@@ -286,7 +286,7 @@ class RF100Benchmark:
286
286
  with open(ds_link_txt, encoding="utf-8") as file:
287
287
  for line in file:
288
288
  try:
289
- _, url, workspace, project, version = re.split("/+", line.strip())
289
+ _, _url, workspace, project, version = re.split("/+", line.strip())
290
290
  self.ds_names.append(project)
291
291
  proj_version = f"{project}-{version}"
292
292
  if not Path(proj_version).exists():
@@ -357,7 +357,7 @@ class RF100Benchmark:
357
357
  map_val = lst["map50"]
358
358
  else:
359
359
  LOGGER.info("Single dict found")
360
- map_val = [res["map50"] for res in eval_lines][0]
360
+ map_val = next(res["map50"] for res in eval_lines)
361
361
 
362
362
  with open(eval_log_file, "a", encoding="utf-8") as f:
363
363
  f.write(f"{self.ds_names[list_ind]}: {map_val}\n")
@@ -681,7 +681,7 @@ class ProfileModels:
681
681
  Returns:
682
682
  (str): Formatted table row string with model metrics.
683
683
  """
684
- layers, params, gradients, flops = model_info
684
+ _layers, params, _gradients, flops = model_info
685
685
  return (
686
686
  f"| {model_name:18s} | {self.imgsz} | - | {t_onnx[0]:.1f}±{t_onnx[1]:.1f} ms | {t_engine[0]:.1f}±"
687
687
  f"{t_engine[1]:.1f} ms | {params / 1e6:.1f} | {flops:.1f} |"
@@ -706,7 +706,7 @@ class ProfileModels:
706
706
  Returns:
707
707
  (dict): Dictionary containing profiling results.
708
708
  """
709
- layers, params, gradients, flops = model_info
709
+ _layers, params, _gradients, flops = model_info
710
710
  return {
711
711
  "model/name": model_name,
712
712
  "model/parameters": params,
@@ -261,7 +261,7 @@ def _format_prediction_annotations(image_path, metadata, class_label_map=None, c
261
261
  class_label_map = {class_map[k]: v for k, v in class_label_map.items()}
262
262
  try:
263
263
  # import pycotools utilities to decompress annotations for various tasks, e.g. segmentation
264
- from faster_coco_eval.core.mask import decode # noqa
264
+ from faster_coco_eval.core.mask import decode
265
265
  except ImportError:
266
266
  decode = None
267
267
 
@@ -350,7 +350,7 @@ def _create_prediction_metadata_map(model_predictions) -> dict:
350
350
  def _log_confusion_matrix(experiment, trainer, curr_step, curr_epoch) -> None:
351
351
  """Log the confusion matrix to Comet experiment."""
352
352
  conf_mat = trainer.validator.confusion_matrix.matrix
353
- names = list(trainer.data["names"].values()) + ["background"]
353
+ names = [*list(trainer.data["names"].values()), "background"]
354
354
  experiment.log_confusion_matrix(
355
355
  matrix=conf_mat, labels=names, max_categories=len(names), epoch=curr_epoch, step=curr_step
356
356
  )
@@ -672,7 +672,7 @@ def check_yolo(verbose=True, device=""):
672
672
  # System info
673
673
  gib = 1 << 30 # bytes per GiB
674
674
  ram = psutil.virtual_memory().total
675
- total, used, free = shutil.disk_usage("/")
675
+ total, _used, free = shutil.disk_usage("/")
676
676
  s = f"({os.cpu_count()} CPUs, {ram / gib:.1f} GB RAM, {(total - free) / gib:.1f}/{total / gib:.1f} GB disk)"
677
677
  try:
678
678
  from IPython import display
@@ -705,7 +705,7 @@ def collect_system_info():
705
705
  gib = 1 << 30 # bytes per GiB
706
706
  cuda = torch.cuda.is_available()
707
707
  check_yolo()
708
- total, used, free = shutil.disk_usage("/")
708
+ total, _used, free = shutil.disk_usage("/")
709
709
 
710
710
  info_dict = {
711
711
  "OS": platform.platform(),
@@ -183,7 +183,7 @@ def unzip_file(
183
183
  if unzip_as_dir:
184
184
  # Zip has 1 top-level directory
185
185
  extract_path = path # i.e. ../datasets
186
- path = Path(path) / list(top_level_dirs)[0] # i.e. extract coco8/ dir to ../datasets/
186
+ path = Path(path) / next(iter(top_level_dirs)) # i.e. extract coco8/ dir to ../datasets/
187
187
  else:
188
188
  # Zip has multiple files at top level
189
189
  path = extract_path = Path(path) / Path(file).stem # i.e. extract multiple files to ../datasets/coco8/
@@ -222,7 +222,7 @@ def check_disk_space(
222
222
  Returns:
223
223
  (bool): True if there is sufficient disk space, False otherwise.
224
224
  """
225
- total, used, free = shutil.disk_usage(path) # bytes
225
+ _total, _used, free = shutil.disk_usage(path) # bytes
226
226
  if file_bytes * sf < free:
227
227
  return True # sufficient space
228
228
 
@@ -92,7 +92,7 @@ def onnx2engine(
92
92
  INT8 calibration requires a dataset and generates a calibration cache.
93
93
  Metadata is serialized and written to the engine file if provided.
94
94
  """
95
- import tensorrt as trt # noqa
95
+ import tensorrt as trt
96
96
 
97
97
  engine_file = engine_file or Path(onnx_file).with_suffix(".engine")
98
98
 
@@ -49,7 +49,7 @@ class WorkingDirectory(contextlib.ContextDecorator):
49
49
  """Change the current working directory to the specified directory upon entering the context."""
50
50
  os.chdir(self.dir)
51
51
 
52
- def __exit__(self, exc_type, exc_val, exc_tb): # noqa
52
+ def __exit__(self, exc_type, exc_val, exc_tb):
53
53
  """Restore the original working directory when exiting the context."""
54
54
  os.chdir(self.cwd)
55
55
 
ultralytics/utils/git.py CHANGED
@@ -51,7 +51,7 @@ class GitRepo:
51
51
  @staticmethod
52
52
  def _find_root(p: Path) -> Path | None:
53
53
  """Return repo root or None."""
54
- return next((d for d in [p] + list(p.parents) if (d / ".git").exists()), None)
54
+ return next((d for d in [p, *list(p.parents)] if (d / ".git").exists()), None)
55
55
 
56
56
  @staticmethod
57
57
  def _gitdir(root: Path) -> Path | None:
@@ -200,7 +200,7 @@ class ConsoleLogger:
200
200
  class _ConsoleCapture:
201
201
  """Lightweight stdout/stderr capture."""
202
202
 
203
- __slots__ = ("original", "callback")
203
+ __slots__ = ("callback", "original")
204
204
 
205
205
  def __init__(self, original, callback):
206
206
  self.original = original
@@ -535,7 +535,7 @@ class ConfusionMatrix(DataExportMixin):
535
535
  array = array[keep_idx, :][:, keep_idx] # slice matrix rows and cols
536
536
  n = (self.nc + k - 1) // k # number of retained classes
537
537
  nc = nn = n if self.task == "classify" else n + 1 # adjust for background if needed
538
- ticklabels = (names + ["background"]) if (0 < nn < 99) and (nn == nc) else "auto"
538
+ ticklabels = ([*names, "background"]) if (0 < nn < 99) and (nn == nc) else "auto"
539
539
  xy_ticks = np.arange(len(ticklabels))
540
540
  tick_fontsize = max(6, 15 - 0.1 * nc) # Minimum size is 6
541
541
  label_fontsize = max(6, 12 - 0.1 * nc)
@@ -608,7 +608,7 @@ class ConfusionMatrix(DataExportMixin):
608
608
  """
609
609
  import re
610
610
 
611
- names = list(self.names.values()) if self.task == "classify" else list(self.names.values()) + ["background"]
611
+ names = list(self.names.values()) if self.task == "classify" else [*list(self.names.values()), "background"]
612
612
  clean_names, seen = [], set()
613
613
  for name in names:
614
614
  clean_name = re.sub(r"[^a-zA-Z0-9_]", "_", name)
@@ -1152,8 +1152,8 @@ class DetMetrics(SimpleClass, DataExportMixin):
1152
1152
  @property
1153
1153
  def results_dict(self) -> dict[str, float]:
1154
1154
  """Return dictionary of computed performance metrics and statistics."""
1155
- keys = self.keys + ["fitness"]
1156
- values = ((float(x) if hasattr(x, "item") else x) for x in (self.mean_results() + [self.fitness]))
1155
+ keys = [*self.keys, "fitness"]
1156
+ values = ((float(x) if hasattr(x, "item") else x) for x in ([*self.mean_results(), self.fitness]))
1157
1157
  return dict(zip(keys, values))
1158
1158
 
1159
1159
  @property
@@ -1270,7 +1270,8 @@ class SegmentMetrics(DetMetrics):
1270
1270
  @property
1271
1271
  def keys(self) -> list[str]:
1272
1272
  """Return a list of keys for accessing metrics."""
1273
- return DetMetrics.keys.fget(self) + [
1273
+ return [
1274
+ *DetMetrics.keys.fget(self),
1274
1275
  "metrics/precision(M)",
1275
1276
  "metrics/recall(M)",
1276
1277
  "metrics/mAP50(M)",
@@ -1298,7 +1299,8 @@ class SegmentMetrics(DetMetrics):
1298
1299
  @property
1299
1300
  def curves(self) -> list[str]:
1300
1301
  """Return a list of curves for accessing specific metrics curves."""
1301
- return DetMetrics.curves.fget(self) + [
1302
+ return [
1303
+ *DetMetrics.curves.fget(self),
1302
1304
  "Precision-Recall(M)",
1303
1305
  "F1-Confidence(M)",
1304
1306
  "Precision-Confidence(M)",
@@ -1407,7 +1409,8 @@ class PoseMetrics(DetMetrics):
1407
1409
  @property
1408
1410
  def keys(self) -> list[str]:
1409
1411
  """Return a list of evaluation metric keys."""
1410
- return DetMetrics.keys.fget(self) + [
1412
+ return [
1413
+ *DetMetrics.keys.fget(self),
1411
1414
  "metrics/precision(P)",
1412
1415
  "metrics/recall(P)",
1413
1416
  "metrics/mAP50(P)",
@@ -1435,7 +1438,8 @@ class PoseMetrics(DetMetrics):
1435
1438
  @property
1436
1439
  def curves(self) -> list[str]:
1437
1440
  """Return a list of curves for accessing specific metrics curves."""
1438
- return DetMetrics.curves.fget(self) + [
1441
+ return [
1442
+ *DetMetrics.curves.fget(self),
1439
1443
  "Precision-Recall(B)",
1440
1444
  "F1-Confidence(B)",
1441
1445
  "Precision-Confidence(B)",
@@ -1527,7 +1531,7 @@ class ClassifyMetrics(SimpleClass, DataExportMixin):
1527
1531
  @property
1528
1532
  def results_dict(self) -> dict[str, float]:
1529
1533
  """Return a dictionary with model's performance metrics and fitness score."""
1530
- return dict(zip(self.keys + ["fitness"], [self.top1, self.top5, self.fitness]))
1534
+ return dict(zip([*self.keys, "fitness"], [self.top1, self.top5, self.fitness]))
1531
1535
 
1532
1536
  @property
1533
1537
  def keys(self) -> list[str]:
ultralytics/utils/ops.py CHANGED
@@ -56,7 +56,7 @@ class Profile(contextlib.ContextDecorator):
56
56
  self.start = self.time()
57
57
  return self
58
58
 
59
- def __exit__(self, type, value, traceback): # noqa
59
+ def __exit__(self, type, value, traceback):
60
60
  """Stop timing."""
61
61
  self.dt = self.time() - self.start # delta-time
62
62
  self.t += self.dt # accumulate dt
@@ -236,10 +236,10 @@ def scale_image(masks, im0_shape, ratio_pad=None):
236
236
  pad = ratio_pad[1]
237
237
 
238
238
  pad_w, pad_h = pad
239
- top = int(round(pad_h - 0.1))
240
- left = int(round(pad_w - 0.1))
241
- bottom = im1_h - int(round(pad_h + 0.1))
242
- right = im1_w - int(round(pad_w + 0.1))
239
+ top = round(pad_h - 0.1)
240
+ left = round(pad_w - 0.1)
241
+ bottom = im1_h - round(pad_h + 0.1)
242
+ right = im1_w - round(pad_w + 0.1)
243
243
 
244
244
  if len(masks.shape) < 2:
245
245
  raise ValueError(f'"len of masks shape" should be 2 or 3, but got {len(masks.shape)}')
@@ -599,9 +599,9 @@ def scale_masks(masks, shape, padding: bool = True):
599
599
  if padding:
600
600
  pad_w /= 2
601
601
  pad_h /= 2
602
- top, left = (int(round(pad_h - 0.1)), int(round(pad_w - 0.1))) if padding else (0, 0)
603
- bottom = mh - int(round(pad_h + 0.1))
604
- right = mw - int(round(pad_w + 0.1))
602
+ top, left = (round(pad_h - 0.1), round(pad_w - 0.1)) if padding else (0, 0)
603
+ bottom = mh - round(pad_h + 0.1)
604
+ right = mw - round(pad_w + 0.1)
605
605
  return F.interpolate(masks[..., top:bottom, left:right], shape, mode="bilinear") # NCHW masks
606
606
 
607
607
 
@@ -4,8 +4,9 @@ from __future__ import annotations
4
4
 
5
5
  import math
6
6
  import warnings
7
+ from collections.abc import Callable
7
8
  from pathlib import Path
8
- from typing import Any, Callable
9
+ from typing import Any
9
10
 
10
11
  import cv2
11
12
  import numpy as np
@@ -44,6 +44,7 @@ TORCH_1_13 = check_version(TORCH_VERSION, "1.13.0")
44
44
  TORCH_2_0 = check_version(TORCH_VERSION, "2.0.0")
45
45
  TORCH_2_1 = check_version(TORCH_VERSION, "2.1.0")
46
46
  TORCH_2_4 = check_version(TORCH_VERSION, "2.4.0")
47
+ TORCH_2_9 = check_version(TORCH_VERSION, "2.9.0")
47
48
  TORCHVISION_0_10 = check_version(TORCHVISION_VERSION, "0.10.0")
48
49
  TORCHVISION_0_11 = check_version(TORCHVISION_VERSION, "0.11.0")
49
50
  TORCHVISION_0_13 = check_version(TORCHVISION_VERSION, "0.13.0")
@@ -333,10 +334,10 @@ def model_info(model, detailed=False, verbose=True, imgsz=640):
333
334
  if len(m._parameters):
334
335
  for pn, p in m.named_parameters():
335
336
  LOGGER.info(
336
- f"{i:>5g}{f'{mn}.{pn}':>40}{mt:>20}{p.requires_grad!r:>10}{p.numel():>12g}{str(list(p.shape)):>20}{p.mean():>10.3g}{p.std():>10.3g}{str(p.dtype).replace('torch.', ''):>15}"
337
+ f"{i:>5g}{f'{mn}.{pn}':>40}{mt:>20}{p.requires_grad!r:>10}{p.numel():>12g}{list(p.shape)!s:>20}{p.mean():>10.3g}{p.std():>10.3g}{str(p.dtype).replace('torch.', ''):>15}"
337
338
  )
338
339
  else: # layers with no learnable params
339
- LOGGER.info(f"{i:>5g}{mn:>40}{mt:>20}{False!r:>10}{0:>12g}{str([]):>20}{'-':>10}{'-':>10}{'-':>15}")
340
+ LOGGER.info(f"{i:>5g}{mn:>40}{mt:>20}{False!r:>10}{0:>12g}{[]!s:>20}{'-':>10}{'-':>10}{'-':>15}")
340
341
 
341
342
  flops = get_flops(model, imgsz) # imgsz may be int or list, i.e. imgsz=640 or imgsz=[640, 320]
342
343
  fused = " (fused)" if getattr(model, "is_fused", lambda: False)() else ""
@@ -685,7 +686,7 @@ class ModelEMA:
685
686
  copy_attr(self.ema, model, include, exclude)
686
687
 
687
688
 
688
- def strip_optimizer(f: str | Path = "best.pt", s: str = "", updates: dict[str, Any] = None) -> dict[str, Any]:
689
+ def strip_optimizer(f: str | Path = "best.pt", s: str = "", updates: dict[str, Any] | None = None) -> dict[str, Any]:
689
690
  """
690
691
  Strip optimizer from 'f' to finalize training, optionally save as 's'.
691
692
 
@@ -865,7 +866,7 @@ def profile_ops(input, ops, n=10, device=None, max_num_obj=0):
865
866
  mem += cuda_info["memory"] / 1e9 # (GB)
866
867
  s_in, s_out = (tuple(x.shape) if isinstance(x, torch.Tensor) else "list" for x in (x, y)) # shapes
867
868
  p = sum(x.numel() for x in m.parameters()) if isinstance(m, nn.Module) else 0 # parameters
868
- LOGGER.info(f"{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{str(s_in):>24s}{str(s_out):>24s}")
869
+ LOGGER.info(f"{p:12}{flops:12.4g}{mem:>14.3f}{tf:14.4g}{tb:14.4g}{s_in!s:>24s}{s_out!s:>24s}")
869
870
  results.append([p, flops, mem, tf, tb, s_in, s_out])
870
871
  except Exception as e:
871
872
  LOGGER.info(e)
@@ -64,12 +64,12 @@ class TritonRemoteModel:
64
64
 
65
65
  # Choose the Triton client based on the communication scheme
66
66
  if scheme == "http":
67
- import tritonclient.http as client # noqa
67
+ import tritonclient.http as client
68
68
 
69
69
  self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
70
70
  config = self.triton_client.get_model_config(endpoint)
71
71
  else:
72
- import tritonclient.grpc as client # noqa
72
+ import tritonclient.grpc as client
73
73
 
74
74
  self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
75
75
  config = self.triton_client.get_model_config(endpoint, as_json=True)["config"]
@@ -1,14 +1,16 @@
1
1
  # Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
2
2
 
3
+ from __future__ import annotations
4
+
3
5
  from ultralytics.cfg import TASK2DATA, TASK2METRIC, get_cfg, get_save_dir
4
6
  from ultralytics.utils import DEFAULT_CFG, DEFAULT_CFG_DICT, LOGGER, NUM_THREADS, checks, colorstr
5
7
 
6
8
 
7
9
  def run_ray_tune(
8
10
  model,
9
- space: dict = None,
11
+ space: dict | None = None,
10
12
  grace_period: int = 10,
11
- gpu_per_trial: int = None,
13
+ gpu_per_trial: int | None = None,
12
14
  max_samples: int = 10,
13
15
  **train_args,
14
16
  ):