triton-model-analyzer 1.48.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- model_analyzer/__init__.py +15 -0
- model_analyzer/analyzer.py +448 -0
- model_analyzer/cli/__init__.py +15 -0
- model_analyzer/cli/cli.py +193 -0
- model_analyzer/config/__init__.py +15 -0
- model_analyzer/config/generate/__init__.py +15 -0
- model_analyzer/config/generate/automatic_model_config_generator.py +164 -0
- model_analyzer/config/generate/base_model_config_generator.py +352 -0
- model_analyzer/config/generate/brute_plus_binary_parameter_search_run_config_generator.py +164 -0
- model_analyzer/config/generate/brute_run_config_generator.py +154 -0
- model_analyzer/config/generate/concurrency_sweeper.py +75 -0
- model_analyzer/config/generate/config_generator_interface.py +52 -0
- model_analyzer/config/generate/coordinate.py +143 -0
- model_analyzer/config/generate/coordinate_data.py +86 -0
- model_analyzer/config/generate/generator_utils.py +116 -0
- model_analyzer/config/generate/manual_model_config_generator.py +187 -0
- model_analyzer/config/generate/model_config_generator_factory.py +92 -0
- model_analyzer/config/generate/model_profile_spec.py +74 -0
- model_analyzer/config/generate/model_run_config_generator.py +154 -0
- model_analyzer/config/generate/model_variant_name_manager.py +150 -0
- model_analyzer/config/generate/neighborhood.py +536 -0
- model_analyzer/config/generate/optuna_plus_concurrency_sweep_run_config_generator.py +141 -0
- model_analyzer/config/generate/optuna_run_config_generator.py +838 -0
- model_analyzer/config/generate/perf_analyzer_config_generator.py +312 -0
- model_analyzer/config/generate/quick_plus_concurrency_sweep_run_config_generator.py +130 -0
- model_analyzer/config/generate/quick_run_config_generator.py +753 -0
- model_analyzer/config/generate/run_config_generator_factory.py +329 -0
- model_analyzer/config/generate/search_config.py +112 -0
- model_analyzer/config/generate/search_dimension.py +73 -0
- model_analyzer/config/generate/search_dimensions.py +85 -0
- model_analyzer/config/generate/search_parameter.py +49 -0
- model_analyzer/config/generate/search_parameters.py +388 -0
- model_analyzer/config/input/__init__.py +15 -0
- model_analyzer/config/input/config_command.py +483 -0
- model_analyzer/config/input/config_command_profile.py +1747 -0
- model_analyzer/config/input/config_command_report.py +267 -0
- model_analyzer/config/input/config_defaults.py +236 -0
- model_analyzer/config/input/config_enum.py +83 -0
- model_analyzer/config/input/config_field.py +216 -0
- model_analyzer/config/input/config_list_generic.py +112 -0
- model_analyzer/config/input/config_list_numeric.py +151 -0
- model_analyzer/config/input/config_list_string.py +111 -0
- model_analyzer/config/input/config_none.py +71 -0
- model_analyzer/config/input/config_object.py +129 -0
- model_analyzer/config/input/config_primitive.py +81 -0
- model_analyzer/config/input/config_status.py +75 -0
- model_analyzer/config/input/config_sweep.py +83 -0
- model_analyzer/config/input/config_union.py +113 -0
- model_analyzer/config/input/config_utils.py +128 -0
- model_analyzer/config/input/config_value.py +243 -0
- model_analyzer/config/input/objects/__init__.py +15 -0
- model_analyzer/config/input/objects/config_model_profile_spec.py +325 -0
- model_analyzer/config/input/objects/config_model_report_spec.py +173 -0
- model_analyzer/config/input/objects/config_plot.py +198 -0
- model_analyzer/config/input/objects/config_protobuf_utils.py +101 -0
- model_analyzer/config/input/yaml_config_validator.py +82 -0
- model_analyzer/config/run/__init__.py +15 -0
- model_analyzer/config/run/model_run_config.py +313 -0
- model_analyzer/config/run/run_config.py +168 -0
- model_analyzer/constants.py +76 -0
- model_analyzer/device/__init__.py +15 -0
- model_analyzer/device/device.py +24 -0
- model_analyzer/device/gpu_device.py +87 -0
- model_analyzer/device/gpu_device_factory.py +248 -0
- model_analyzer/entrypoint.py +307 -0
- model_analyzer/log_formatter.py +65 -0
- model_analyzer/model_analyzer_exceptions.py +24 -0
- model_analyzer/model_manager.py +255 -0
- model_analyzer/monitor/__init__.py +15 -0
- model_analyzer/monitor/cpu_monitor.py +69 -0
- model_analyzer/monitor/dcgm/DcgmDiag.py +191 -0
- model_analyzer/monitor/dcgm/DcgmFieldGroup.py +83 -0
- model_analyzer/monitor/dcgm/DcgmGroup.py +815 -0
- model_analyzer/monitor/dcgm/DcgmHandle.py +141 -0
- model_analyzer/monitor/dcgm/DcgmJsonReader.py +69 -0
- model_analyzer/monitor/dcgm/DcgmReader.py +623 -0
- model_analyzer/monitor/dcgm/DcgmStatus.py +57 -0
- model_analyzer/monitor/dcgm/DcgmSystem.py +412 -0
- model_analyzer/monitor/dcgm/__init__.py +15 -0
- model_analyzer/monitor/dcgm/common/__init__.py +13 -0
- model_analyzer/monitor/dcgm/common/dcgm_client_cli_parser.py +194 -0
- model_analyzer/monitor/dcgm/common/dcgm_client_main.py +86 -0
- model_analyzer/monitor/dcgm/dcgm_agent.py +887 -0
- model_analyzer/monitor/dcgm/dcgm_collectd_plugin.py +369 -0
- model_analyzer/monitor/dcgm/dcgm_errors.py +395 -0
- model_analyzer/monitor/dcgm/dcgm_field_helpers.py +546 -0
- model_analyzer/monitor/dcgm/dcgm_fields.py +815 -0
- model_analyzer/monitor/dcgm/dcgm_fields_collectd.py +671 -0
- model_analyzer/monitor/dcgm/dcgm_fields_internal.py +29 -0
- model_analyzer/monitor/dcgm/dcgm_fluentd.py +45 -0
- model_analyzer/monitor/dcgm/dcgm_monitor.py +138 -0
- model_analyzer/monitor/dcgm/dcgm_prometheus.py +326 -0
- model_analyzer/monitor/dcgm/dcgm_structs.py +2357 -0
- model_analyzer/monitor/dcgm/dcgm_telegraf.py +65 -0
- model_analyzer/monitor/dcgm/dcgm_value.py +151 -0
- model_analyzer/monitor/dcgm/dcgmvalue.py +155 -0
- model_analyzer/monitor/dcgm/denylist_recommendations.py +573 -0
- model_analyzer/monitor/dcgm/pydcgm.py +47 -0
- model_analyzer/monitor/monitor.py +143 -0
- model_analyzer/monitor/remote_monitor.py +137 -0
- model_analyzer/output/__init__.py +15 -0
- model_analyzer/output/file_writer.py +63 -0
- model_analyzer/output/output_writer.py +42 -0
- model_analyzer/perf_analyzer/__init__.py +15 -0
- model_analyzer/perf_analyzer/genai_perf_config.py +206 -0
- model_analyzer/perf_analyzer/perf_analyzer.py +882 -0
- model_analyzer/perf_analyzer/perf_config.py +479 -0
- model_analyzer/plots/__init__.py +15 -0
- model_analyzer/plots/detailed_plot.py +266 -0
- model_analyzer/plots/plot_manager.py +224 -0
- model_analyzer/plots/simple_plot.py +213 -0
- model_analyzer/record/__init__.py +15 -0
- model_analyzer/record/gpu_record.py +68 -0
- model_analyzer/record/metrics_manager.py +887 -0
- model_analyzer/record/record.py +280 -0
- model_analyzer/record/record_aggregator.py +256 -0
- model_analyzer/record/types/__init__.py +15 -0
- model_analyzer/record/types/cpu_available_ram.py +93 -0
- model_analyzer/record/types/cpu_used_ram.py +93 -0
- model_analyzer/record/types/gpu_free_memory.py +96 -0
- model_analyzer/record/types/gpu_power_usage.py +107 -0
- model_analyzer/record/types/gpu_total_memory.py +96 -0
- model_analyzer/record/types/gpu_used_memory.py +96 -0
- model_analyzer/record/types/gpu_utilization.py +108 -0
- model_analyzer/record/types/inter_token_latency_avg.py +60 -0
- model_analyzer/record/types/inter_token_latency_base.py +74 -0
- model_analyzer/record/types/inter_token_latency_max.py +60 -0
- model_analyzer/record/types/inter_token_latency_min.py +60 -0
- model_analyzer/record/types/inter_token_latency_p25.py +60 -0
- model_analyzer/record/types/inter_token_latency_p50.py +60 -0
- model_analyzer/record/types/inter_token_latency_p75.py +60 -0
- model_analyzer/record/types/inter_token_latency_p90.py +60 -0
- model_analyzer/record/types/inter_token_latency_p95.py +60 -0
- model_analyzer/record/types/inter_token_latency_p99.py +60 -0
- model_analyzer/record/types/output_token_throughput.py +105 -0
- model_analyzer/record/types/perf_client_response_wait.py +97 -0
- model_analyzer/record/types/perf_client_send_recv.py +97 -0
- model_analyzer/record/types/perf_latency.py +111 -0
- model_analyzer/record/types/perf_latency_avg.py +60 -0
- model_analyzer/record/types/perf_latency_base.py +74 -0
- model_analyzer/record/types/perf_latency_p90.py +60 -0
- model_analyzer/record/types/perf_latency_p95.py +60 -0
- model_analyzer/record/types/perf_latency_p99.py +60 -0
- model_analyzer/record/types/perf_server_compute_infer.py +97 -0
- model_analyzer/record/types/perf_server_compute_input.py +97 -0
- model_analyzer/record/types/perf_server_compute_output.py +97 -0
- model_analyzer/record/types/perf_server_queue.py +97 -0
- model_analyzer/record/types/perf_throughput.py +105 -0
- model_analyzer/record/types/time_to_first_token_avg.py +60 -0
- model_analyzer/record/types/time_to_first_token_base.py +74 -0
- model_analyzer/record/types/time_to_first_token_max.py +60 -0
- model_analyzer/record/types/time_to_first_token_min.py +60 -0
- model_analyzer/record/types/time_to_first_token_p25.py +60 -0
- model_analyzer/record/types/time_to_first_token_p50.py +60 -0
- model_analyzer/record/types/time_to_first_token_p75.py +60 -0
- model_analyzer/record/types/time_to_first_token_p90.py +60 -0
- model_analyzer/record/types/time_to_first_token_p95.py +60 -0
- model_analyzer/record/types/time_to_first_token_p99.py +60 -0
- model_analyzer/reports/__init__.py +15 -0
- model_analyzer/reports/html_report.py +195 -0
- model_analyzer/reports/pdf_report.py +50 -0
- model_analyzer/reports/report.py +86 -0
- model_analyzer/reports/report_factory.py +62 -0
- model_analyzer/reports/report_manager.py +1376 -0
- model_analyzer/reports/report_utils.py +42 -0
- model_analyzer/result/__init__.py +15 -0
- model_analyzer/result/constraint_manager.py +150 -0
- model_analyzer/result/model_config_measurement.py +354 -0
- model_analyzer/result/model_constraints.py +105 -0
- model_analyzer/result/parameter_search.py +246 -0
- model_analyzer/result/result_manager.py +430 -0
- model_analyzer/result/result_statistics.py +159 -0
- model_analyzer/result/result_table.py +217 -0
- model_analyzer/result/result_table_manager.py +646 -0
- model_analyzer/result/result_utils.py +42 -0
- model_analyzer/result/results.py +277 -0
- model_analyzer/result/run_config_measurement.py +658 -0
- model_analyzer/result/run_config_result.py +210 -0
- model_analyzer/result/run_config_result_comparator.py +110 -0
- model_analyzer/result/sorted_results.py +151 -0
- model_analyzer/state/__init__.py +15 -0
- model_analyzer/state/analyzer_state.py +76 -0
- model_analyzer/state/analyzer_state_manager.py +215 -0
- model_analyzer/triton/__init__.py +15 -0
- model_analyzer/triton/client/__init__.py +15 -0
- model_analyzer/triton/client/client.py +234 -0
- model_analyzer/triton/client/client_factory.py +57 -0
- model_analyzer/triton/client/grpc_client.py +104 -0
- model_analyzer/triton/client/http_client.py +107 -0
- model_analyzer/triton/model/__init__.py +15 -0
- model_analyzer/triton/model/model_config.py +556 -0
- model_analyzer/triton/model/model_config_variant.py +29 -0
- model_analyzer/triton/server/__init__.py +15 -0
- model_analyzer/triton/server/server.py +76 -0
- model_analyzer/triton/server/server_config.py +269 -0
- model_analyzer/triton/server/server_docker.py +229 -0
- model_analyzer/triton/server/server_factory.py +306 -0
- model_analyzer/triton/server/server_local.py +158 -0
- triton_model_analyzer-1.48.0.dist-info/METADATA +52 -0
- triton_model_analyzer-1.48.0.dist-info/RECORD +204 -0
- triton_model_analyzer-1.48.0.dist-info/WHEEL +5 -0
- triton_model_analyzer-1.48.0.dist-info/entry_points.txt +2 -0
- triton_model_analyzer-1.48.0.dist-info/licenses/LICENSE +67 -0
- triton_model_analyzer-1.48.0.dist-info/top_level.txt +1 -0
model_analyzer/state/analyzer_state_manager.py
@@ -0,0 +1,215 @@
+#!/usr/bin/env python3
+
+# Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import glob
+import json
+import logging
+import os
+import signal
+import sys
+import traceback
+
+from model_analyzer.constants import LOGGER_NAME, MAX_NUMBER_OF_INTERRUPTS
+from model_analyzer.model_analyzer_exceptions import TritonModelAnalyzerException
+from model_analyzer.state.analyzer_state import AnalyzerState
+
+logger = logging.getLogger(LOGGER_NAME)
+
+
+class AnalyzerStateManager:
+    """
+    Maintains the state of the Model Analyzer
+    """
+
+    def __init__(self, config, server):
+        """
+        Parameters
+        ----------
+        config : ConfigCommand
+            The analyzer's config
+        server : TritonServer
+            Handle for the tritonserver instance
+        """
+
+        self._config = config
+        self._server = server
+        self._exiting = 0
+        self._checkpoint_dir = config.checkpoint_directory
+        self._state_changed = False
+
+        if os.path.exists(self._checkpoint_dir):
+            self._checkpoint_index = self._latest_checkpoint() + 1
+        else:
+            os.makedirs(self._checkpoint_dir)
+            self._checkpoint_index = 0
+        signal.signal(signal.SIGINT, self.interrupt_handler)
+
+        self._current_state = AnalyzerState()
+        self._starting_fresh_run = True
+
+    def starting_fresh_run(self):
+        """
+        Returns
+        -------
+        True if starting a fresh run,
+        False if a checkpoint was found and loaded
+        """
+
+        return self._starting_fresh_run
+
+    def exiting(self):
+        """
+        Returns
+        -------
+        True if the interrupt handler has run
+        at least once, False otherwise
+        """
+
+        return self._exiting > 0
+
+    def get_state_variable(self, name):
+        """
+        Get a named variable from
+        the current AnalyzerState
+
+        Parameters
+        ----------
+        name : str
+            The name of the variable
+        """
+
+        return self._current_state.get(name)
+
+    def set_state_variable(self, name, value):
+        """
+        Set a named variable in
+        the current AnalyzerState
+
+        Parameters
+        ----------
+        name : str
+            The name of the variable
+        value : Any
+            The value to set for that variable
+        """
+
+        self._state_changed = True
+        self._current_state.set(name, value)
+
+    def load_checkpoint(self, checkpoint_required):
+        """
+        Load the state of the Model Analyzer from the
+        most recent checkpoint file, and record whether
+        we are starting a fresh run
+
+        Parameters
+        ----------
+        checkpoint_required : bool
+            If true, an existing checkpoint is required to run MA
+        """
+
+        latest_checkpoint_file = os.path.join(
+            self._checkpoint_dir, f"{self._latest_checkpoint()}.ckpt"
+        )
+        if os.path.exists(latest_checkpoint_file):
+            logger.info(f"Loaded checkpoint from file {latest_checkpoint_file}")
+            with open(latest_checkpoint_file, "r") as f:
+                try:
+                    self._current_state = AnalyzerState.from_dict(json.load(f))
+                except EOFError:
+                    raise TritonModelAnalyzerException(
+                        f"Checkpoint file {latest_checkpoint_file} is"
+                        " empty or corrupted. Remove it from the checkpoint"
+                        " directory."
+                    )
+            self._starting_fresh_run = False
+        else:
+            if checkpoint_required:
+                raise TritonModelAnalyzerException("No checkpoint file found")
+            else:
+                logger.info("No checkpoint file found, starting a fresh run.")
+
+    def default_encode(self, obj):
+        if isinstance(obj, bytes):
+            return obj.decode("utf-8")
+        elif hasattr(obj, "to_dict"):
+            return obj.to_dict()
+        else:
+            return obj.__dict__
+
+    def save_checkpoint(self):
+        """
+        Saves the state of the model analyzer to disk
+        if there has been a change since the last checkpoint
+        """
+
+        ckpt_filename = os.path.join(
+            self._checkpoint_dir, f"{self._checkpoint_index}.ckpt"
+        )
+        if self._state_changed:
+            with open(ckpt_filename, "w") as f:
+                json.dump(self._current_state, f, default=self.default_encode)
+            logger.info(f"Saved checkpoint to {ckpt_filename}")
+
+            self._state_changed = False
+        else:
+            logger.info("No changes made to analyzer data, no checkpoint saved.")
+
+    def interrupt_handler(self, signal, frame):
+        """
+        A signal handler to properly
+        shut down the model analyzer on
+        interrupt
+        """
+
+        self._exiting += 1
+        if logger.getEffectiveLevel() <= logging.DEBUG:
+            traceback.print_stack(limit=15)
+        logger.info(
+            f"Received SIGINT {self._exiting}/{MAX_NUMBER_OF_INTERRUPTS}. "
+            "Will attempt to exit after the current measurement."
+        )
+        if self._exiting >= MAX_NUMBER_OF_INTERRUPTS:
+            logger.info(
+                "Received SIGINT the maximum number of times. Saving state and exiting immediately. "
+                "perf_analyzer may still be running."
+            )
+            self.save_checkpoint()
+
+            # Exit server
+            if self._server:
+                self._server.stop()
+            sys.exit(1)
+
+    def _latest_checkpoint(self):
+        """
+        Get the highest-index checkpoint file in the
+        checkpoint directory and return its index.
+        """
+
+        checkpoint_files = glob.glob(os.path.join(self._checkpoint_dir, "*.ckpt"))
+        if not checkpoint_files:
+            return -1
+        try:
+            return max(
+                [int(os.path.split(f)[1].split(".")[0]) for f in checkpoint_files]
+            )
+        except Exception as e:
+            raise TritonModelAnalyzerException(e)
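
A minimal sketch of the checkpoint cycle this file implements, assuming the wheel is installed. The SimpleNamespace stub standing in for a ConfigCommand and the state-variable name are hypothetical, supplying only the checkpoint_directory attribute the constructor actually reads:

from types import SimpleNamespace

from model_analyzer.state.analyzer_state_manager import AnalyzerStateManager

# Hypothetical stand-in for a ConfigCommand; only checkpoint_directory is used here.
config = SimpleNamespace(checkpoint_directory="./checkpoints")
state_manager = AnalyzerStateManager(config=config, server=None)

# Loads the highest-index <n>.ckpt if one exists, otherwise starts a fresh run.
state_manager.load_checkpoint(checkpoint_required=False)
print("fresh run:", state_manager.starting_fresh_run())

# set_state_variable marks the state dirty, so the next save_checkpoint writes
# a JSON file named <checkpoint_index>.ckpt; a second save without changes is a no-op.
state_manager.set_state_variable("example_variable", {"swept": [1, 2, 4]})
state_manager.save_checkpoint()
state_manager.save_checkpoint()  # logs that no changes were made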
model_analyzer/triton/__init__.py
@@ -0,0 +1,15 @@
+#!/usr/bin/env python3
+
+# Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
model_analyzer/triton/client/__init__.py
@@ -0,0 +1,15 @@
+#!/usr/bin/env python3
+
+# Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
model_analyzer/triton/client/client.py
@@ -0,0 +1,234 @@
+#!/usr/bin/env python3
+
+# Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import time
+from subprocess import DEVNULL
+
+from model_analyzer.constants import LOGGER_NAME
+from model_analyzer.model_analyzer_exceptions import TritonModelAnalyzerException
+
+logger = logging.getLogger(LOGGER_NAME)
+
+
+class TritonClient:
+    """
+    Defines the interface for the objects created by
+    TritonClientFactory
+    """
+
+    def wait_for_server_ready(
+        self,
+        num_retries,
+        sleep_time=1,
+        log_file=None,
+    ):
+        """
+        Parameters
+        ----------
+        num_retries : int
+            number of times to send a ready status
+            request to the server before raising
+            an exception
+        sleep_time : int
+            amount of time in seconds to sleep between retries
+        log_file : TextIOWrapper
+            file that contains the server's output log
+
+        Raises
+        ------
+        TritonModelAnalyzerException
+            If server readiness could not be
+            determined in the given num_retries
+        """
+
+        retries = num_retries
+        while retries > 0:
+            try:
+                if self._client.is_server_ready():
+                    time.sleep(sleep_time)
+                    return
+                else:
+                    self._check_for_triton_log_errors(log_file)
+                    time.sleep(sleep_time)
+                    retries -= 1
+            except Exception as e:
+                # Log connection failures with more detail for debugging
+                if retries == num_retries or retries % 10 == 0:
+                    logger.debug(
+                        f"Failed to connect to Triton server (attempt {num_retries - retries + 1}/{num_retries}): {e}"
+                    )
+                self._check_for_triton_log_errors(log_file)
+                time.sleep(sleep_time)
+                retries -= 1
+                if retries == 0:
+                    raise TritonModelAnalyzerException(e)
+        raise TritonModelAnalyzerException(
+            "Could not determine server readiness. Number of retries exceeded."
+        )
+
+    def load_model(self, model_name, variant_name="", config_str=None):
+        """
+        Request the inference server to load
+        a particular model in explicit model
+        control mode.
+
+        Parameters
+        ----------
+        model_name : str
+            Name of the model
+
+        variant_name : str
+            Name of the model variant
+
+        config_str : str
+            Optional config string used to load the model
+
+        Returns
+        -------
+        int or None
+            Returns -1 if the load failed.
+        """
+
+        variant_name = variant_name if variant_name else model_name
+
+        try:
+            self._client.load_model(model_name, config=config_str)
+            logger.debug(f"Model {variant_name} loaded")
+            return None
+        except Exception as e:
+            logger.info(f"Model {variant_name} load failed: {e}")
+            if "polling is enabled" in e.message():
+                raise TritonModelAnalyzerException(
+                    "The remote Tritonserver needs to be launched in EXPLICIT mode"
+                )
+            return -1
+
+    def unload_model(self, model_name):
+        """
+        Request the inference server to unload
+        a particular model in explicit model
+        control mode.
+
+        Parameters
+        ----------
+        model_name : str
+            name of the model to unload from the repository
+
+        Raises
+        ------
+        TritonModelAnalyzerException
+            If the server throws an Exception
+
+        Returns
+        -------
+        int or None
+            Returns -1 if the unload failed.
+        """
+
+        try:
+            self._client.unload_model(model_name)
+            logger.debug(f"Model {model_name} unloaded")
+            return None
+        except Exception as e:
+            logger.info(f"Model {model_name} unload failed: {e}")
+            return -1
+
+    def wait_for_model_ready(self, model_name, num_retries, sleep_time=1):
+        """
+        Returns when the model is ready.
+
+        Parameters
+        ----------
+        model_name : str
+            name of the model to wait for
+        num_retries : int
+            number of times to send a ready status
+            request to the server before raising
+            an exception
+
+        Raises
+        ------
+        TritonModelAnalyzerException
+            If model readiness could not be determined
+            in the given num_retries
+
+        Returns
+        -------
+        int or None
+            Returns -1 if the model was not ready.
+        """
+
+        retries = num_retries
+        error = None
+        while retries > 0:
+            try:
+                if self._client.is_model_ready(model_name):
+                    return None
+                else:
+                    time.sleep(sleep_time)
+                    retries -= 1
+            except Exception as e:
+                error = e
+                time.sleep(sleep_time)
+                retries -= 1
+
+        logger.info(f"Model readiness failed for model {model_name}. Error {error}")
+        return -1
+
+    def get_model_config(self, model_name, num_retries):
+        """
+        Get the config for the named model.
+
+        Parameters
+        ----------
+        model_name : str
+            Name of the model to find the config for.
+
+        num_retries : int
+            Number of times to wait for the model load
+
+        Returns
+        -------
+        dict or None
+            A dictionary containing the model config.
+        """
+
+        self.wait_for_model_ready(model_name, num_retries)
+        model_config_dict = self._client.get_model_config(model_name)
+        return model_config_dict
+
+    def is_server_ready(self):
+        """
+        Returns True if the server is ready, else False.
+        """
+        return self._client.is_server_ready()
+
+    def _check_for_triton_log_errors(self, log_file):
+        if not log_file or log_file == DEVNULL:
+            return
+
+        log_file.seek(0)
+        log_output = log_file.read()
+
+        if not isinstance(log_output, str):
+            log_output = log_output.decode("utf-8")
+
+        if log_output:
+            if "Unexpected argument:" in log_output:
+                error_start = log_output.find("Unexpected argument:")
+                raise TritonModelAnalyzerException(
+                    f"Error: TritonServer did not launch successfully\n\n{log_output[error_start:]}"
+                )
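
The base class above never constructs its own tritonclient handle (concrete subclasses set self._client), and its load/unload/ready helpers report failure by returning -1 rather than raising. A small sketch of that convention, using a unittest.mock stand-in instead of a live server; the model name "add_sub" is a placeholder:

from unittest.mock import MagicMock

from model_analyzer.triton.client.client import TritonClient

client = TritonClient()
client._client = MagicMock()  # stand-in for the handle a subclass would create

client._client.is_model_ready.return_value = True
assert client.wait_for_model_ready("add_sub", num_retries=3) is None  # success -> None

client._client.unload_model.side_effect = RuntimeError("model not found")
assert client.unload_model("add_sub") == -1  # failure -> -1, logged but not raised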
model_analyzer/triton/client/client_factory.py
@@ -0,0 +1,57 @@
+#!/usr/bin/env python3
+
+# Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .grpc_client import TritonGRPCClient
+from .http_client import TritonHTTPClient
+
+
+class TritonClientFactory:
+    """
+    Base client creator class that declares
+    a factory method
+    """
+
+    @staticmethod
+    def create_grpc_client(server_url, ssl_options={}):
+        """
+        Parameters
+        ----------
+        server_url : str
+            The url for Triton server's GRPC endpoint
+        ssl_options : dict
+            Dictionary of SSL options for gRPC python client
+
+        Returns
+        -------
+        TritonGRPCClient
+        """
+        return TritonGRPCClient(server_url=server_url, ssl_options=ssl_options)
+
+    @staticmethod
+    def create_http_client(server_url, ssl_options={}):
+        """
+        Parameters
+        ----------
+        server_url : str
+            The url for Triton server's HTTP endpoint
+        ssl_options : dict
+            Dictionary of SSL options for HTTP python client
+
+        Returns
+        -------
+        TritonHTTPClient
+        """
+        return TritonHTTPClient(server_url=server_url, ssl_options=ssl_options)
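
A short sketch of how these factory methods are meant to be called, assuming a Triton server is already serving on the default local ports (8001 for gRPC, 8000 for HTTP); the model name "add_sub" is a placeholder:

from model_analyzer.triton.client.client_factory import TritonClientFactory

# gRPC client against the default local gRPC endpoint
grpc_client = TritonClientFactory.create_grpc_client(server_url="localhost:8001")
grpc_client.wait_for_server_ready(num_retries=30, sleep_time=1)

# HTTP client against the default local HTTP endpoint
http_client = TritonClientFactory.create_http_client(server_url="localhost:8000")
if http_client.is_server_ready():
    print(http_client.get_model_config("add_sub", num_retries=10))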
model_analyzer/triton/client/grpc_client.py
@@ -0,0 +1,104 @@
+#!/usr/bin/env python3
+
+# Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tritonclient.grpc as grpcclient
+
+from .client import TritonClient
+
+
+class TritonGRPCClient(TritonClient):
+    """
+    Concrete implementation of TritonClient
+    for GRPC
+    """
+
+    def __init__(self, server_url, ssl_options={}):
+        """
+        Parameters
+        ----------
+        server_url : str
+            The url for Triton server's GRPC endpoint
+        ssl_options : dict
+            Dictionary of SSL options for gRPC python client
+        """
+
+        ssl = False
+        root_certificates = None
+        private_key = None
+        certificate_chain = None
+
+        if "ssl-grpc-use-ssl" in ssl_options:
+            ssl = ssl_options["ssl-grpc-use-ssl"].lower() == "true"
+        if "ssl-grpc-root-certifications-file" in ssl_options:
+            root_certificates = ssl_options["ssl-grpc-root-certifications-file"]
+        if "ssl-grpc-private-key-file" in ssl_options:
+            private_key = ssl_options["ssl-grpc-private-key-file"]
+        if "ssl-grpc-certificate-chain-file" in ssl_options:
+            certificate_chain = ssl_options["ssl-grpc-certificate-chain-file"]
+
+        # Fix for gRPC 1.60.0+: Force IPv4 resolution for localhost connections.
+        # gRPC 1.60.0+ prefers IPv6, causing "localhost" to resolve to [::1].
+        # On systems where IPv6 is not properly configured, this causes connection failures.
+        # Force IPv4 by using 127.0.0.1, which is more reliable across environments.
+        channel_args = None
+        if "localhost" in server_url:
+            server_url = server_url.replace("localhost", "127.0.0.1")
+            # For SSL connections, override target name to match certificate
+            if ssl:
+                channel_args = [("grpc.ssl_target_name_override", "localhost")]
+
+        self._client = grpcclient.InferenceServerClient(
+            url=server_url,
+            ssl=ssl,
+            root_certificates=root_certificates,
+            private_key=private_key,
+            certificate_chain=certificate_chain,
+            channel_args=channel_args,
+        )
+
+    def get_model_config(self, model_name, num_retries):
+        """
+        Get the config for the named model.
+
+        Parameters
+        ----------
+        model_name : str
+            Name of the model to find the config for.
+
+        num_retries : int
+            Number of times to wait for the model load
+
+        Returns
+        -------
+        dict
+            A dictionary containing the model config.
+        """
+
+        self.wait_for_model_ready(model_name, num_retries)
+        model_config_dict = self._client.get_model_config(model_name, as_json=True)
+        return model_config_dict["config"]
+
+    def get_model_repository_index(self):
+        """
+        Returns the JSON dict holding the model repository index.
+        """
+        return self._client.get_model_repository_index(as_json=True)["models"]
+
+    def is_model_ready(self, model_name: str) -> bool:
+        """
+        Returns True if the model is loaded on the server.
+        """
+        return self._client.is_model_ready(model_name)
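
The ssl-grpc-* keys consumed in __init__ above map onto the SSL arguments of tritonclient's gRPC InferenceServerClient. A sketch of constructing an SSL-enabled client; the certificate paths are placeholders:

from model_analyzer.triton.client.grpc_client import TritonGRPCClient

ssl_options = {
    "ssl-grpc-use-ssl": "true",
    "ssl-grpc-root-certifications-file": "/certs/ca.crt",  # placeholder paths
    "ssl-grpc-private-key-file": "/certs/client.key",
    "ssl-grpc-certificate-chain-file": "/certs/client.crt",
}
# "localhost" is rewritten to 127.0.0.1 internally; with SSL on, the target
# name is overridden back to "localhost" so certificate validation still passes.
client = TritonGRPCClient(server_url="localhost:8001", ssl_options=ssl_options)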