dgenerate-ultralytics-headless 8.3.253 (py3-none-any wheel)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dgenerate_ultralytics_headless-8.3.253.dist-info/METADATA +405 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/RECORD +299 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/WHEEL +5 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/entry_points.txt +3 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/licenses/LICENSE +661 -0
- dgenerate_ultralytics_headless-8.3.253.dist-info/top_level.txt +1 -0
- tests/__init__.py +23 -0
- tests/conftest.py +59 -0
- tests/test_cli.py +131 -0
- tests/test_cuda.py +216 -0
- tests/test_engine.py +157 -0
- tests/test_exports.py +309 -0
- tests/test_integrations.py +151 -0
- tests/test_python.py +777 -0
- tests/test_solutions.py +371 -0
- ultralytics/__init__.py +48 -0
- ultralytics/assets/bus.jpg +0 -0
- ultralytics/assets/zidane.jpg +0 -0
- ultralytics/cfg/__init__.py +1028 -0
- ultralytics/cfg/datasets/Argoverse.yaml +78 -0
- ultralytics/cfg/datasets/DOTAv1.5.yaml +37 -0
- ultralytics/cfg/datasets/DOTAv1.yaml +36 -0
- ultralytics/cfg/datasets/GlobalWheat2020.yaml +68 -0
- ultralytics/cfg/datasets/HomeObjects-3K.yaml +32 -0
- ultralytics/cfg/datasets/ImageNet.yaml +2025 -0
- ultralytics/cfg/datasets/Objects365.yaml +447 -0
- ultralytics/cfg/datasets/SKU-110K.yaml +58 -0
- ultralytics/cfg/datasets/TT100K.yaml +346 -0
- ultralytics/cfg/datasets/VOC.yaml +102 -0
- ultralytics/cfg/datasets/VisDrone.yaml +87 -0
- ultralytics/cfg/datasets/african-wildlife.yaml +25 -0
- ultralytics/cfg/datasets/brain-tumor.yaml +22 -0
- ultralytics/cfg/datasets/carparts-seg.yaml +44 -0
- ultralytics/cfg/datasets/coco-pose.yaml +64 -0
- ultralytics/cfg/datasets/coco.yaml +118 -0
- ultralytics/cfg/datasets/coco128-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco128.yaml +101 -0
- ultralytics/cfg/datasets/coco8-grayscale.yaml +103 -0
- ultralytics/cfg/datasets/coco8-multispectral.yaml +104 -0
- ultralytics/cfg/datasets/coco8-pose.yaml +47 -0
- ultralytics/cfg/datasets/coco8-seg.yaml +101 -0
- ultralytics/cfg/datasets/coco8.yaml +101 -0
- ultralytics/cfg/datasets/construction-ppe.yaml +32 -0
- ultralytics/cfg/datasets/crack-seg.yaml +22 -0
- ultralytics/cfg/datasets/dog-pose.yaml +52 -0
- ultralytics/cfg/datasets/dota8-multispectral.yaml +38 -0
- ultralytics/cfg/datasets/dota8.yaml +35 -0
- ultralytics/cfg/datasets/hand-keypoints.yaml +50 -0
- ultralytics/cfg/datasets/kitti.yaml +27 -0
- ultralytics/cfg/datasets/lvis.yaml +1240 -0
- ultralytics/cfg/datasets/medical-pills.yaml +21 -0
- ultralytics/cfg/datasets/open-images-v7.yaml +663 -0
- ultralytics/cfg/datasets/package-seg.yaml +22 -0
- ultralytics/cfg/datasets/signature.yaml +21 -0
- ultralytics/cfg/datasets/tiger-pose.yaml +41 -0
- ultralytics/cfg/datasets/xView.yaml +155 -0
- ultralytics/cfg/default.yaml +130 -0
- ultralytics/cfg/models/11/yolo11-cls-resnet18.yaml +17 -0
- ultralytics/cfg/models/11/yolo11-cls.yaml +33 -0
- ultralytics/cfg/models/11/yolo11-obb.yaml +50 -0
- ultralytics/cfg/models/11/yolo11-pose.yaml +51 -0
- ultralytics/cfg/models/11/yolo11-seg.yaml +50 -0
- ultralytics/cfg/models/11/yolo11.yaml +50 -0
- ultralytics/cfg/models/11/yoloe-11-seg.yaml +48 -0
- ultralytics/cfg/models/11/yoloe-11.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-cls.yaml +32 -0
- ultralytics/cfg/models/12/yolo12-obb.yaml +48 -0
- ultralytics/cfg/models/12/yolo12-pose.yaml +49 -0
- ultralytics/cfg/models/12/yolo12-seg.yaml +48 -0
- ultralytics/cfg/models/12/yolo12.yaml +48 -0
- ultralytics/cfg/models/rt-detr/rtdetr-l.yaml +53 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet101.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-resnet50.yaml +45 -0
- ultralytics/cfg/models/rt-detr/rtdetr-x.yaml +57 -0
- ultralytics/cfg/models/v10/yolov10b.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10l.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10m.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10n.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10s.yaml +45 -0
- ultralytics/cfg/models/v10/yolov10x.yaml +45 -0
- ultralytics/cfg/models/v3/yolov3-spp.yaml +49 -0
- ultralytics/cfg/models/v3/yolov3-tiny.yaml +40 -0
- ultralytics/cfg/models/v3/yolov3.yaml +49 -0
- ultralytics/cfg/models/v5/yolov5-p6.yaml +62 -0
- ultralytics/cfg/models/v5/yolov5.yaml +51 -0
- ultralytics/cfg/models/v6/yolov6.yaml +56 -0
- ultralytics/cfg/models/v8/yoloe-v8-seg.yaml +48 -0
- ultralytics/cfg/models/v8/yoloe-v8.yaml +48 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet101.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls-resnet50.yaml +28 -0
- ultralytics/cfg/models/v8/yolov8-cls.yaml +32 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p2.yaml +58 -0
- ultralytics/cfg/models/v8/yolov8-ghost-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-ghost.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-obb.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-p2.yaml +57 -0
- ultralytics/cfg/models/v8/yolov8-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-pose-p6.yaml +60 -0
- ultralytics/cfg/models/v8/yolov8-pose.yaml +50 -0
- ultralytics/cfg/models/v8/yolov8-rtdetr.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-seg-p6.yaml +59 -0
- ultralytics/cfg/models/v8/yolov8-seg.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8-world.yaml +51 -0
- ultralytics/cfg/models/v8/yolov8-worldv2.yaml +49 -0
- ultralytics/cfg/models/v8/yolov8.yaml +49 -0
- ultralytics/cfg/models/v9/yolov9c-seg.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9c.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9e-seg.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9e.yaml +64 -0
- ultralytics/cfg/models/v9/yolov9m.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9s.yaml +41 -0
- ultralytics/cfg/models/v9/yolov9t.yaml +41 -0
- ultralytics/cfg/trackers/botsort.yaml +21 -0
- ultralytics/cfg/trackers/bytetrack.yaml +12 -0
- ultralytics/data/__init__.py +26 -0
- ultralytics/data/annotator.py +66 -0
- ultralytics/data/augment.py +2801 -0
- ultralytics/data/base.py +435 -0
- ultralytics/data/build.py +437 -0
- ultralytics/data/converter.py +855 -0
- ultralytics/data/dataset.py +834 -0
- ultralytics/data/loaders.py +704 -0
- ultralytics/data/scripts/download_weights.sh +18 -0
- ultralytics/data/scripts/get_coco.sh +61 -0
- ultralytics/data/scripts/get_coco128.sh +18 -0
- ultralytics/data/scripts/get_imagenet.sh +52 -0
- ultralytics/data/split.py +138 -0
- ultralytics/data/split_dota.py +344 -0
- ultralytics/data/utils.py +798 -0
- ultralytics/engine/__init__.py +1 -0
- ultralytics/engine/exporter.py +1580 -0
- ultralytics/engine/model.py +1125 -0
- ultralytics/engine/predictor.py +508 -0
- ultralytics/engine/results.py +1522 -0
- ultralytics/engine/trainer.py +977 -0
- ultralytics/engine/tuner.py +449 -0
- ultralytics/engine/validator.py +387 -0
- ultralytics/hub/__init__.py +166 -0
- ultralytics/hub/auth.py +151 -0
- ultralytics/hub/google/__init__.py +174 -0
- ultralytics/hub/session.py +422 -0
- ultralytics/hub/utils.py +162 -0
- ultralytics/models/__init__.py +9 -0
- ultralytics/models/fastsam/__init__.py +7 -0
- ultralytics/models/fastsam/model.py +79 -0
- ultralytics/models/fastsam/predict.py +169 -0
- ultralytics/models/fastsam/utils.py +23 -0
- ultralytics/models/fastsam/val.py +38 -0
- ultralytics/models/nas/__init__.py +7 -0
- ultralytics/models/nas/model.py +98 -0
- ultralytics/models/nas/predict.py +56 -0
- ultralytics/models/nas/val.py +38 -0
- ultralytics/models/rtdetr/__init__.py +7 -0
- ultralytics/models/rtdetr/model.py +63 -0
- ultralytics/models/rtdetr/predict.py +88 -0
- ultralytics/models/rtdetr/train.py +89 -0
- ultralytics/models/rtdetr/val.py +216 -0
- ultralytics/models/sam/__init__.py +25 -0
- ultralytics/models/sam/amg.py +275 -0
- ultralytics/models/sam/build.py +365 -0
- ultralytics/models/sam/build_sam3.py +377 -0
- ultralytics/models/sam/model.py +169 -0
- ultralytics/models/sam/modules/__init__.py +1 -0
- ultralytics/models/sam/modules/blocks.py +1067 -0
- ultralytics/models/sam/modules/decoders.py +495 -0
- ultralytics/models/sam/modules/encoders.py +794 -0
- ultralytics/models/sam/modules/memory_attention.py +298 -0
- ultralytics/models/sam/modules/sam.py +1160 -0
- ultralytics/models/sam/modules/tiny_encoder.py +979 -0
- ultralytics/models/sam/modules/transformer.py +344 -0
- ultralytics/models/sam/modules/utils.py +512 -0
- ultralytics/models/sam/predict.py +3940 -0
- ultralytics/models/sam/sam3/__init__.py +3 -0
- ultralytics/models/sam/sam3/decoder.py +546 -0
- ultralytics/models/sam/sam3/encoder.py +529 -0
- ultralytics/models/sam/sam3/geometry_encoders.py +415 -0
- ultralytics/models/sam/sam3/maskformer_segmentation.py +286 -0
- ultralytics/models/sam/sam3/model_misc.py +199 -0
- ultralytics/models/sam/sam3/necks.py +129 -0
- ultralytics/models/sam/sam3/sam3_image.py +339 -0
- ultralytics/models/sam/sam3/text_encoder_ve.py +307 -0
- ultralytics/models/sam/sam3/vitdet.py +547 -0
- ultralytics/models/sam/sam3/vl_combiner.py +160 -0
- ultralytics/models/utils/__init__.py +1 -0
- ultralytics/models/utils/loss.py +466 -0
- ultralytics/models/utils/ops.py +315 -0
- ultralytics/models/yolo/__init__.py +7 -0
- ultralytics/models/yolo/classify/__init__.py +7 -0
- ultralytics/models/yolo/classify/predict.py +90 -0
- ultralytics/models/yolo/classify/train.py +202 -0
- ultralytics/models/yolo/classify/val.py +216 -0
- ultralytics/models/yolo/detect/__init__.py +7 -0
- ultralytics/models/yolo/detect/predict.py +122 -0
- ultralytics/models/yolo/detect/train.py +227 -0
- ultralytics/models/yolo/detect/val.py +507 -0
- ultralytics/models/yolo/model.py +430 -0
- ultralytics/models/yolo/obb/__init__.py +7 -0
- ultralytics/models/yolo/obb/predict.py +56 -0
- ultralytics/models/yolo/obb/train.py +79 -0
- ultralytics/models/yolo/obb/val.py +302 -0
- ultralytics/models/yolo/pose/__init__.py +7 -0
- ultralytics/models/yolo/pose/predict.py +65 -0
- ultralytics/models/yolo/pose/train.py +110 -0
- ultralytics/models/yolo/pose/val.py +248 -0
- ultralytics/models/yolo/segment/__init__.py +7 -0
- ultralytics/models/yolo/segment/predict.py +109 -0
- ultralytics/models/yolo/segment/train.py +69 -0
- ultralytics/models/yolo/segment/val.py +307 -0
- ultralytics/models/yolo/world/__init__.py +5 -0
- ultralytics/models/yolo/world/train.py +173 -0
- ultralytics/models/yolo/world/train_world.py +178 -0
- ultralytics/models/yolo/yoloe/__init__.py +22 -0
- ultralytics/models/yolo/yoloe/predict.py +162 -0
- ultralytics/models/yolo/yoloe/train.py +287 -0
- ultralytics/models/yolo/yoloe/train_seg.py +122 -0
- ultralytics/models/yolo/yoloe/val.py +206 -0
- ultralytics/nn/__init__.py +27 -0
- ultralytics/nn/autobackend.py +964 -0
- ultralytics/nn/modules/__init__.py +182 -0
- ultralytics/nn/modules/activation.py +54 -0
- ultralytics/nn/modules/block.py +1947 -0
- ultralytics/nn/modules/conv.py +669 -0
- ultralytics/nn/modules/head.py +1183 -0
- ultralytics/nn/modules/transformer.py +793 -0
- ultralytics/nn/modules/utils.py +159 -0
- ultralytics/nn/tasks.py +1768 -0
- ultralytics/nn/text_model.py +356 -0
- ultralytics/py.typed +1 -0
- ultralytics/solutions/__init__.py +41 -0
- ultralytics/solutions/ai_gym.py +108 -0
- ultralytics/solutions/analytics.py +264 -0
- ultralytics/solutions/config.py +107 -0
- ultralytics/solutions/distance_calculation.py +123 -0
- ultralytics/solutions/heatmap.py +125 -0
- ultralytics/solutions/instance_segmentation.py +86 -0
- ultralytics/solutions/object_blurrer.py +89 -0
- ultralytics/solutions/object_counter.py +190 -0
- ultralytics/solutions/object_cropper.py +87 -0
- ultralytics/solutions/parking_management.py +280 -0
- ultralytics/solutions/queue_management.py +93 -0
- ultralytics/solutions/region_counter.py +133 -0
- ultralytics/solutions/security_alarm.py +151 -0
- ultralytics/solutions/similarity_search.py +219 -0
- ultralytics/solutions/solutions.py +828 -0
- ultralytics/solutions/speed_estimation.py +114 -0
- ultralytics/solutions/streamlit_inference.py +260 -0
- ultralytics/solutions/templates/similarity-search.html +156 -0
- ultralytics/solutions/trackzone.py +88 -0
- ultralytics/solutions/vision_eye.py +67 -0
- ultralytics/trackers/__init__.py +7 -0
- ultralytics/trackers/basetrack.py +115 -0
- ultralytics/trackers/bot_sort.py +257 -0
- ultralytics/trackers/byte_tracker.py +469 -0
- ultralytics/trackers/track.py +116 -0
- ultralytics/trackers/utils/__init__.py +1 -0
- ultralytics/trackers/utils/gmc.py +339 -0
- ultralytics/trackers/utils/kalman_filter.py +482 -0
- ultralytics/trackers/utils/matching.py +154 -0
- ultralytics/utils/__init__.py +1450 -0
- ultralytics/utils/autobatch.py +118 -0
- ultralytics/utils/autodevice.py +205 -0
- ultralytics/utils/benchmarks.py +728 -0
- ultralytics/utils/callbacks/__init__.py +5 -0
- ultralytics/utils/callbacks/base.py +233 -0
- ultralytics/utils/callbacks/clearml.py +146 -0
- ultralytics/utils/callbacks/comet.py +625 -0
- ultralytics/utils/callbacks/dvc.py +197 -0
- ultralytics/utils/callbacks/hub.py +110 -0
- ultralytics/utils/callbacks/mlflow.py +134 -0
- ultralytics/utils/callbacks/neptune.py +126 -0
- ultralytics/utils/callbacks/platform.py +453 -0
- ultralytics/utils/callbacks/raytune.py +42 -0
- ultralytics/utils/callbacks/tensorboard.py +123 -0
- ultralytics/utils/callbacks/wb.py +188 -0
- ultralytics/utils/checks.py +1020 -0
- ultralytics/utils/cpu.py +85 -0
- ultralytics/utils/dist.py +123 -0
- ultralytics/utils/downloads.py +529 -0
- ultralytics/utils/errors.py +35 -0
- ultralytics/utils/events.py +113 -0
- ultralytics/utils/export/__init__.py +7 -0
- ultralytics/utils/export/engine.py +237 -0
- ultralytics/utils/export/imx.py +325 -0
- ultralytics/utils/export/tensorflow.py +231 -0
- ultralytics/utils/files.py +219 -0
- ultralytics/utils/git.py +137 -0
- ultralytics/utils/instance.py +484 -0
- ultralytics/utils/logger.py +506 -0
- ultralytics/utils/loss.py +849 -0
- ultralytics/utils/metrics.py +1563 -0
- ultralytics/utils/nms.py +337 -0
- ultralytics/utils/ops.py +664 -0
- ultralytics/utils/patches.py +201 -0
- ultralytics/utils/plotting.py +1047 -0
- ultralytics/utils/tal.py +404 -0
- ultralytics/utils/torch_utils.py +984 -0
- ultralytics/utils/tqdm.py +443 -0
- ultralytics/utils/triton.py +112 -0
- ultralytics/utils/tuner.py +168 -0
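
The listing above covers the wheel's dist-info (METADATA, RECORD, entry_points.txt, top_level.txt), the bundled tests, and the ultralytics package itself. As a minimal sketch that is not part of the package, assuming the wheel has been installed under the distribution name dgenerate-ultralytics-headless (e.g. via `pip install dgenerate-ultralytics-headless==8.3.253`), the standard library can confirm what the listed dist-info files declare:

# Sketch: inspect the installed distribution with the standard library (assumes the wheel is installed).
from importlib.metadata import distribution

dist = distribution("dgenerate-ultralytics-headless")
print(dist.version)                      # expected: 8.3.253
print(dist.read_text("top_level.txt"))   # top-level package(s), per top_level.txt in the listing
for ep in dist.entry_points:             # console scripts declared in entry_points.txt
    print(ep.name, "->", ep.value)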
ultralytics/utils/tqdm.py (new file)
@@ -0,0 +1,443 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

import os
import sys
import time
from functools import lru_cache
from typing import IO, Any


@lru_cache(maxsize=1)
def is_noninteractive_console() -> bool:
    """Check for known non-interactive console environments."""
    return "GITHUB_ACTIONS" in os.environ or "RUNPOD_POD_ID" in os.environ


class TQDM:
    """Lightweight zero-dependency progress bar for Ultralytics.

    Provides clean, rich-style progress bars suitable for various environments including Weights & Biases, console
    outputs, and other logging systems. Features zero external dependencies, clean single-line output, rich-style
    progress bars with Unicode block characters, context manager support, iterator protocol support, and dynamic
    description updates.

    Attributes:
        iterable (object): Iterable to wrap with progress bar.
        desc (str): Prefix description for the progress bar.
        total (int): Expected number of iterations.
        disable (bool): Whether to disable the progress bar.
        unit (str): String for units of iteration.
        unit_scale (bool): Auto-scale units flag.
        unit_divisor (int): Divisor for unit scaling.
        leave (bool): Whether to leave the progress bar after completion.
        mininterval (float): Minimum time interval between updates.
        initial (int): Initial counter value.
        n (int): Current iteration count.
        closed (bool): Whether the progress bar is closed.
        bar_format (str): Custom bar format string.
        file (object): Output file stream.

    Methods:
        update: Update progress by n steps.
        set_description: Set or update the description.
        set_postfix: Set postfix for the progress bar.
        close: Close the progress bar and clean up.
        refresh: Refresh the progress bar display.
        clear: Clear the progress bar from display.
        write: Write a message without breaking the progress bar.

    Examples:
        Basic usage with iterator:
        >>> for i in TQDM(range(100)):
        ...     time.sleep(0.01)

        With custom description:
        >>> pbar = TQDM(range(100), desc="Processing")
        >>> for i in pbar:
        ...     pbar.set_description(f"Processing item {i}")

        Context manager usage:
        >>> with TQDM(total=100, unit="B", unit_scale=True) as pbar:
        ...     for i in range(100):
        ...         pbar.update(1)

        Manual updates:
        >>> pbar = TQDM(total=100, desc="Training")
        >>> for epoch in range(100):
        ...     # Do work
        ...     pbar.update(1)
        >>> pbar.close()
    """

    # Constants
    MIN_RATE_CALC_INTERVAL = 0.01  # Minimum time interval for rate calculation
    RATE_SMOOTHING_FACTOR = 0.3  # Factor for exponential smoothing of rates
    MAX_SMOOTHED_RATE = 1000000  # Maximum rate to apply smoothing to
    NONINTERACTIVE_MIN_INTERVAL = 60.0  # Minimum interval for non-interactive environments

    def __init__(
        self,
        iterable: Any = None,
        desc: str | None = None,
        total: int | None = None,
        leave: bool = True,
        file: IO[str] | None = None,
        mininterval: float = 0.1,
        disable: bool | None = None,
        unit: str = "it",
        unit_scale: bool = True,
        unit_divisor: int = 1000,
        bar_format: str | None = None,  # kept for API compatibility; not used for formatting
        initial: int = 0,
        **kwargs,
    ) -> None:
        """Initialize the TQDM progress bar with specified configuration options.

        Args:
            iterable (object, optional): Iterable to wrap with progress bar.
            desc (str, optional): Prefix description for the progress bar.
            total (int, optional): Expected number of iterations.
            leave (bool, optional): Whether to leave the progress bar after completion.
            file (object, optional): Output file stream for progress display.
            mininterval (float, optional): Minimum time interval between updates (default 0.1s, 60s in GitHub Actions).
            disable (bool, optional): Whether to disable the progress bar. Auto-detected if None.
            unit (str, optional): String for units of iteration (default "it" for items).
            unit_scale (bool, optional): Auto-scale units for bytes/data units.
            unit_divisor (int, optional): Divisor for unit scaling (default 1000).
            bar_format (str, optional): Custom bar format string.
            initial (int, optional): Initial counter value.
            **kwargs (Any): Additional keyword arguments for compatibility (ignored).
        """
        # Disable if not verbose
        if disable is None:
            try:
                from ultralytics.utils import LOGGER, VERBOSE

                disable = not VERBOSE or LOGGER.getEffectiveLevel() > 20
            except ImportError:
                disable = False

        self.iterable = iterable
        self.desc = desc or ""
        self.total = total or (len(iterable) if hasattr(iterable, "__len__") else None) or None  # prevent total=0
        self.disable = disable
        self.unit = unit
        self.unit_scale = unit_scale
        self.unit_divisor = unit_divisor
        self.leave = leave
        self.noninteractive = is_noninteractive_console()
        self.mininterval = max(mininterval, self.NONINTERACTIVE_MIN_INTERVAL) if self.noninteractive else mininterval
        self.initial = initial

        # Kept for API compatibility (unused for f-string formatting)
        self.bar_format = bar_format

        self.file = file or sys.stdout

        # Internal state
        self.n = self.initial
        self.last_print_n = self.initial
        self.last_print_t = time.time()
        self.start_t = time.time()
        self.last_rate = 0.0
        self.closed = False
        self.is_bytes = unit_scale and unit in {"B", "bytes"}
        self.scales = (
            [(1073741824, "GB/s"), (1048576, "MB/s"), (1024, "KB/s")]
            if self.is_bytes
            else [(1e9, f"G{self.unit}/s"), (1e6, f"M{self.unit}/s"), (1e3, f"K{self.unit}/s")]
        )

        if not self.disable and self.total and not self.noninteractive:
            self._display()

    def _format_rate(self, rate: float) -> str:
        """Format rate with units, switching between it/s and s/it for readability."""
        if rate <= 0:
            return ""

        inv_rate = 1 / rate if rate else None

        # Use s/it format when inv_rate > 1 (i.e., rate < 1 it/s) for better readability
        if inv_rate and inv_rate > 1:
            return f"{inv_rate:.1f}s/B" if self.is_bytes else f"{inv_rate:.1f}s/{self.unit}"

        # Use it/s format for fast iterations
        fallback = f"{rate:.1f}B/s" if self.is_bytes else f"{rate:.1f}{self.unit}/s"
        return next((f"{rate / t:.1f}{u}" for t, u in self.scales if rate >= t), fallback)

    def _format_num(self, num: int | float) -> str:
        """Format number with optional unit scaling."""
        if not self.unit_scale or not self.is_bytes:
            return str(num)

        for unit in ("", "K", "M", "G", "T"):
            if abs(num) < self.unit_divisor:
                return f"{num:3.1f}{unit}B" if unit else f"{num:.0f}B"
            num /= self.unit_divisor
        return f"{num:.1f}PB"

    @staticmethod
    def _format_time(seconds: float) -> str:
        """Format time duration."""
        if seconds < 60:
            return f"{seconds:.1f}s"
        elif seconds < 3600:
            return f"{int(seconds // 60)}:{seconds % 60:02.0f}"
        else:
            h, m = int(seconds // 3600), int((seconds % 3600) // 60)
            return f"{h}:{m:02d}:{seconds % 60:02.0f}"

    def _generate_bar(self, width: int = 12) -> str:
        """Generate progress bar."""
        if self.total is None:
            return "━" * width if self.closed else "─" * width

        frac = min(1.0, self.n / self.total)
        filled = int(frac * width)
        bar = "━" * filled + "─" * (width - filled)
        if filled < width and frac * width - filled > 0.5:
            bar = f"{bar[:filled]}╸{bar[filled + 1 :]}"
        return bar

    def _should_update(self, dt: float, dn: int) -> bool:
        """Check if display should update."""
        if self.noninteractive:
            return False
        return (self.total is not None and self.n >= self.total) or (dt >= self.mininterval)

    def _display(self, final: bool = False) -> None:
        """Display progress bar."""
        if self.disable or (self.closed and not final):
            return

        current_time = time.time()
        dt = current_time - self.last_print_t
        dn = self.n - self.last_print_n

        if not final and not self._should_update(dt, dn):
            return

        # Calculate rate (avoid crazy numbers)
        if dt > self.MIN_RATE_CALC_INTERVAL:
            rate = dn / dt if dt else 0.0
            # Smooth rate for reasonable values, use raw rate for very high values
            if rate < self.MAX_SMOOTHED_RATE:
                self.last_rate = self.RATE_SMOOTHING_FACTOR * rate + (1 - self.RATE_SMOOTHING_FACTOR) * self.last_rate
                rate = self.last_rate
        else:
            rate = self.last_rate

        # At completion, use overall rate
        if self.total and self.n >= self.total:
            overall_elapsed = current_time - self.start_t
            if overall_elapsed > 0:
                rate = self.n / overall_elapsed

        # Update counters
        self.last_print_n = self.n
        self.last_print_t = current_time
        elapsed = current_time - self.start_t

        # Remaining time
        remaining_str = ""
        if self.total and 0 < self.n < self.total and elapsed > 0:
            est_rate = rate or (self.n / elapsed)
            remaining_str = f"<{self._format_time((self.total - self.n) / est_rate)}"

        # Numbers and percent
        if self.total:
            percent = (self.n / self.total) * 100
            n_str = self._format_num(self.n)
            t_str = self._format_num(self.total)
            if self.is_bytes and n_str[-2] == t_str[-2]:  # Collapse suffix only when identical (e.g. "5.4/5.4MB")
                n_str = n_str.rstrip("KMGTPB")
        else:
            percent = 0.0
            n_str, t_str = self._format_num(self.n), "?"

        elapsed_str = self._format_time(elapsed)
        rate_str = self._format_rate(rate) or (self._format_rate(self.n / elapsed) if elapsed > 0 else "")

        bar = self._generate_bar()

        # Compose progress line via f-strings (two shapes: with/without total)
        if self.total:
            if self.is_bytes and self.n >= self.total:
                # Completed bytes: show only final size
                progress_str = f"{self.desc}: {percent:.0f}% {bar} {t_str} {rate_str} {elapsed_str}"
            else:
                progress_str = (
                    f"{self.desc}: {percent:.0f}% {bar} {n_str}/{t_str} {rate_str} {elapsed_str}{remaining_str}"
                )
        else:
            progress_str = f"{self.desc}: {bar} {n_str} {rate_str} {elapsed_str}"

        # Write to output
        try:
            if self.noninteractive:
                # In non-interactive environments, avoid carriage return which creates empty lines
                self.file.write(progress_str)
            else:
                # In interactive terminals, use carriage return and clear line for updating display
                self.file.write(f"\r\033[K{progress_str}")
            self.file.flush()
        except Exception:
            pass

    def update(self, n: int = 1) -> None:
        """Update progress by n steps."""
        if not self.disable and not self.closed:
            self.n += n
            self._display()

    def set_description(self, desc: str | None) -> None:
        """Set description."""
        self.desc = desc or ""
        if not self.disable:
            self._display()

    def set_postfix(self, **kwargs: Any) -> None:
        """Set postfix (appends to description)."""
        if kwargs:
            postfix = ", ".join(f"{k}={v}" for k, v in kwargs.items())
            base_desc = self.desc.split(" | ")[0] if " | " in self.desc else self.desc
            self.set_description(f"{base_desc} | {postfix}")

    def close(self) -> None:
        """Close progress bar."""
        if self.closed:
            return

        self.closed = True

        if not self.disable:
            # Final display
            if self.total and self.n >= self.total:
                self.n = self.total
                if self.n != self.last_print_n:  # Skip if 100% already shown
                    self._display(final=True)
            else:
                self._display(final=True)

            # Cleanup
            if self.leave:
                self.file.write("\n")
            else:
                self.file.write("\r\033[K")

            try:
                self.file.flush()
            except Exception:
                pass

    def __enter__(self) -> TQDM:
        """Enter context manager."""
        return self

    def __exit__(self, *args: Any) -> None:
        """Exit context manager and close progress bar."""
        self.close()

    def __iter__(self) -> Any:
        """Iterate over the wrapped iterable with progress updates."""
        if self.iterable is None:
            raise TypeError("'NoneType' object is not iterable")

        try:
            for item in self.iterable:
                yield item
                self.update(1)
        finally:
            self.close()

    def __del__(self) -> None:
        """Destructor to ensure cleanup."""
        try:
            self.close()
        except Exception:
            pass

    def refresh(self) -> None:
        """Refresh display."""
        if not self.disable:
            self._display()

    def clear(self) -> None:
        """Clear progress bar."""
        if not self.disable:
            try:
                self.file.write("\r\033[K")
                self.file.flush()
            except Exception:
                pass

    @staticmethod
    def write(s: str, file: IO[str] | None = None, end: str = "\n") -> None:
        """Static method to write without breaking progress bar."""
        file = file or sys.stdout
        try:
            file.write(s + end)
            file.flush()
        except Exception:
            pass


if __name__ == "__main__":
    import time

    print("1. Basic progress bar with known total:")
    for i in TQDM(range(3), desc="Known total"):
        time.sleep(0.05)

    print("\n2. Manual updates with known total:")
    pbar = TQDM(total=300, desc="Manual updates", unit="files")
    for i in range(300):
        time.sleep(0.03)
        pbar.update(1)
        if i % 10 == 9:
            pbar.set_description(f"Processing batch {i // 10 + 1}")
    pbar.close()

    print("\n3. Progress bar with unknown total:")
    pbar = TQDM(desc="Unknown total", unit="items")
    for i in range(25):
        time.sleep(0.08)
        pbar.update(1)
        if i % 5 == 4:
            pbar.set_postfix(processed=i + 1, status="OK")
    pbar.close()

    print("\n4. Context manager with unknown total:")
    with TQDM(desc="Processing stream", unit="B", unit_scale=True, unit_divisor=1024) as pbar:
        for i in range(30):
            time.sleep(0.1)
            pbar.update(1024 * 1024 * i)  # Simulate processing MB of data

    print("\n5. Iterator with unknown length:")

    def data_stream():
        """Simulate a data stream of unknown length."""
        import random

        for i in range(random.randint(10, 20)):
            yield f"data_chunk_{i}"

    for chunk in TQDM(data_stream(), desc="Stream processing", unit="chunks"):
        time.sleep(0.1)

    print("\n6. File processing simulation (unknown size):")

    def process_files():
        """Simulate processing files of unknown count."""
        return [f"file_{i}.txt" for i in range(18)]

    pbar = TQDM(desc="Scanning files", unit="files")
    files = process_files()
    for i, filename in enumerate(files):
        time.sleep(0.06)
        pbar.update(1)
        pbar.set_description(f"Processing {filename}")
    pbar.close()
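
The TQDM class above also exposes a static write() helper for logging while a bar is active, which the __main__ demo does not exercise. A minimal sketch of that pattern, assuming the module is importable as ultralytics.utils.tqdm (matching the file path in the listing):

# Sketch: log a message without breaking an active progress bar (assumes the wheel is installed).
import time

from ultralytics.utils.tqdm import TQDM

pbar = TQDM(total=50, desc="Indexing", unit="files")
for i in range(50):
    time.sleep(0.02)
    pbar.update(1)
    if i == 24:
        TQDM.write("halfway there")  # plain newline-terminated write; the bar redraws on the next update
pbar.close()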
ultralytics/utils/triton.py (new file)
@@ -0,0 +1,112 @@
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

from __future__ import annotations

import ast
from urllib.parse import urlsplit

import numpy as np


class TritonRemoteModel:
    """Client for interacting with a remote Triton Inference Server model.

    This class provides a convenient interface for sending inference requests to a Triton Inference Server and
    processing the responses. Supports both HTTP and gRPC communication protocols.

    Attributes:
        endpoint (str): The name of the model on the Triton server.
        url (str): The URL of the Triton server.
        triton_client: The Triton client (either HTTP or gRPC).
        InferInput: The input class for the Triton client.
        InferRequestedOutput: The output request class for the Triton client.
        input_formats (list[str]): The data types of the model inputs.
        np_input_formats (list[type]): The numpy data types of the model inputs.
        input_names (list[str]): The names of the model inputs.
        output_names (list[str]): The names of the model outputs.
        metadata: The metadata associated with the model.

    Methods:
        __call__: Call the model with the given inputs and return the outputs.

    Examples:
        Initialize a Triton client with HTTP
        >>> model = TritonRemoteModel(url="localhost:8000", endpoint="yolov8", scheme="http")

        Make inference with numpy arrays
        >>> outputs = model(np.random.rand(1, 3, 640, 640).astype(np.float32))
    """

    def __init__(self, url: str, endpoint: str = "", scheme: str = ""):
        """Initialize the TritonRemoteModel for interacting with a remote Triton Inference Server.

        Arguments may be provided individually or parsed from a collective 'url' argument of the form
            <scheme>://<netloc>/<endpoint>/<task_name>

        Args:
            url (str): The URL of the Triton server.
            endpoint (str, optional): The name of the model on the Triton server.
            scheme (str, optional): The communication scheme ('http' or 'grpc').
        """
        if not endpoint and not scheme:  # Parse all args from URL string
            splits = urlsplit(url)
            endpoint = splits.path.strip("/").split("/", 1)[0]
            scheme = splits.scheme
            url = splits.netloc

        self.endpoint = endpoint
        self.url = url

        # Choose the Triton client based on the communication scheme
        if scheme == "http":
            import tritonclient.http as client

            self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
            config = self.triton_client.get_model_config(endpoint)
        else:
            import tritonclient.grpc as client

            self.triton_client = client.InferenceServerClient(url=self.url, verbose=False, ssl=False)
            config = self.triton_client.get_model_config(endpoint, as_json=True)["config"]

        # Sort output names alphabetically, i.e. 'output0', 'output1', etc.
        config["output"] = sorted(config["output"], key=lambda x: x.get("name"))

        # Define model attributes
        type_map = {"TYPE_FP32": np.float32, "TYPE_FP16": np.float16, "TYPE_UINT8": np.uint8}
        self.InferRequestedOutput = client.InferRequestedOutput
        self.InferInput = client.InferInput
        self.input_formats = [x["data_type"] for x in config["input"]]
        self.np_input_formats = [type_map[x] for x in self.input_formats]
        self.input_names = [x["name"] for x in config["input"]]
        self.output_names = [x["name"] for x in config["output"]]
        self.metadata = ast.literal_eval(config.get("parameters", {}).get("metadata", {}).get("string_value", "None"))

    def __call__(self, *inputs: np.ndarray) -> list[np.ndarray]:
        """Call the model with the given inputs and return inference results.

        Args:
            *inputs (np.ndarray): Input data to the model. Each array should match the expected shape and type for the
                corresponding model input.

        Returns:
            (list[np.ndarray]): Model outputs with the same dtype as the input. Each element in the list corresponds to
                one of the model's output tensors.

        Examples:
            >>> model = TritonRemoteModel(url="localhost:8000", endpoint="yolov8", scheme="http")
            >>> outputs = model(np.random.rand(1, 3, 640, 640).astype(np.float32))
        """
        infer_inputs = []
        input_format = inputs[0].dtype
        for i, x in enumerate(inputs):
            if x.dtype != self.np_input_formats[i]:
                x = x.astype(self.np_input_formats[i])
            infer_input = self.InferInput(self.input_names[i], [*x.shape], self.input_formats[i].replace("TYPE_", ""))
            infer_input.set_data_from_numpy(x)
            infer_inputs.append(infer_input)

        infer_outputs = [self.InferRequestedOutput(output_name) for output_name in self.output_names]
        outputs = self.triton_client.infer(model_name=self.endpoint, inputs=infer_inputs, outputs=infer_outputs)

        return [outputs.as_numpy(output_name).astype(input_format) for output_name in self.output_names]
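
Besides the explicit url/endpoint/scheme form shown in the docstring, __init__ above can parse everything from a single URL via urlsplit. A minimal sketch of that form, assuming tritonclient is installed and a Triton server is reachable; the address localhost:8000 and the model name yolov8 are placeholders taken from the docstring example, not values the package guarantees:

# Sketch: single-URL initialization parsed by TritonRemoteModel.__init__ (requires a running Triton server).
import numpy as np

from ultralytics.utils.triton import TritonRemoteModel

model = TritonRemoteModel("http://localhost:8000/yolov8")  # scheme, netloc, and endpoint parsed from the URL
print(model.url, model.endpoint)  # -> localhost:8000 yolov8
outputs = model(np.random.rand(1, 3, 640, 640).astype(np.float32))
print([o.shape for o in outputs])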