triton-model-analyzer 1.48.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (204) hide show
  1. model_analyzer/__init__.py +15 -0
  2. model_analyzer/analyzer.py +448 -0
  3. model_analyzer/cli/__init__.py +15 -0
  4. model_analyzer/cli/cli.py +193 -0
  5. model_analyzer/config/__init__.py +15 -0
  6. model_analyzer/config/generate/__init__.py +15 -0
  7. model_analyzer/config/generate/automatic_model_config_generator.py +164 -0
  8. model_analyzer/config/generate/base_model_config_generator.py +352 -0
  9. model_analyzer/config/generate/brute_plus_binary_parameter_search_run_config_generator.py +164 -0
  10. model_analyzer/config/generate/brute_run_config_generator.py +154 -0
  11. model_analyzer/config/generate/concurrency_sweeper.py +75 -0
  12. model_analyzer/config/generate/config_generator_interface.py +52 -0
  13. model_analyzer/config/generate/coordinate.py +143 -0
  14. model_analyzer/config/generate/coordinate_data.py +86 -0
  15. model_analyzer/config/generate/generator_utils.py +116 -0
  16. model_analyzer/config/generate/manual_model_config_generator.py +187 -0
  17. model_analyzer/config/generate/model_config_generator_factory.py +92 -0
  18. model_analyzer/config/generate/model_profile_spec.py +74 -0
  19. model_analyzer/config/generate/model_run_config_generator.py +154 -0
  20. model_analyzer/config/generate/model_variant_name_manager.py +150 -0
  21. model_analyzer/config/generate/neighborhood.py +536 -0
  22. model_analyzer/config/generate/optuna_plus_concurrency_sweep_run_config_generator.py +141 -0
  23. model_analyzer/config/generate/optuna_run_config_generator.py +838 -0
  24. model_analyzer/config/generate/perf_analyzer_config_generator.py +312 -0
  25. model_analyzer/config/generate/quick_plus_concurrency_sweep_run_config_generator.py +130 -0
  26. model_analyzer/config/generate/quick_run_config_generator.py +753 -0
  27. model_analyzer/config/generate/run_config_generator_factory.py +329 -0
  28. model_analyzer/config/generate/search_config.py +112 -0
  29. model_analyzer/config/generate/search_dimension.py +73 -0
  30. model_analyzer/config/generate/search_dimensions.py +85 -0
  31. model_analyzer/config/generate/search_parameter.py +49 -0
  32. model_analyzer/config/generate/search_parameters.py +388 -0
  33. model_analyzer/config/input/__init__.py +15 -0
  34. model_analyzer/config/input/config_command.py +483 -0
  35. model_analyzer/config/input/config_command_profile.py +1747 -0
  36. model_analyzer/config/input/config_command_report.py +267 -0
  37. model_analyzer/config/input/config_defaults.py +236 -0
  38. model_analyzer/config/input/config_enum.py +83 -0
  39. model_analyzer/config/input/config_field.py +216 -0
  40. model_analyzer/config/input/config_list_generic.py +112 -0
  41. model_analyzer/config/input/config_list_numeric.py +151 -0
  42. model_analyzer/config/input/config_list_string.py +111 -0
  43. model_analyzer/config/input/config_none.py +71 -0
  44. model_analyzer/config/input/config_object.py +129 -0
  45. model_analyzer/config/input/config_primitive.py +81 -0
  46. model_analyzer/config/input/config_status.py +75 -0
  47. model_analyzer/config/input/config_sweep.py +83 -0
  48. model_analyzer/config/input/config_union.py +113 -0
  49. model_analyzer/config/input/config_utils.py +128 -0
  50. model_analyzer/config/input/config_value.py +243 -0
  51. model_analyzer/config/input/objects/__init__.py +15 -0
  52. model_analyzer/config/input/objects/config_model_profile_spec.py +325 -0
  53. model_analyzer/config/input/objects/config_model_report_spec.py +173 -0
  54. model_analyzer/config/input/objects/config_plot.py +198 -0
  55. model_analyzer/config/input/objects/config_protobuf_utils.py +101 -0
  56. model_analyzer/config/input/yaml_config_validator.py +82 -0
  57. model_analyzer/config/run/__init__.py +15 -0
  58. model_analyzer/config/run/model_run_config.py +313 -0
  59. model_analyzer/config/run/run_config.py +168 -0
  60. model_analyzer/constants.py +76 -0
  61. model_analyzer/device/__init__.py +15 -0
  62. model_analyzer/device/device.py +24 -0
  63. model_analyzer/device/gpu_device.py +87 -0
  64. model_analyzer/device/gpu_device_factory.py +248 -0
  65. model_analyzer/entrypoint.py +307 -0
  66. model_analyzer/log_formatter.py +65 -0
  67. model_analyzer/model_analyzer_exceptions.py +24 -0
  68. model_analyzer/model_manager.py +255 -0
  69. model_analyzer/monitor/__init__.py +15 -0
  70. model_analyzer/monitor/cpu_monitor.py +69 -0
  71. model_analyzer/monitor/dcgm/DcgmDiag.py +191 -0
  72. model_analyzer/monitor/dcgm/DcgmFieldGroup.py +83 -0
  73. model_analyzer/monitor/dcgm/DcgmGroup.py +815 -0
  74. model_analyzer/monitor/dcgm/DcgmHandle.py +141 -0
  75. model_analyzer/monitor/dcgm/DcgmJsonReader.py +69 -0
  76. model_analyzer/monitor/dcgm/DcgmReader.py +623 -0
  77. model_analyzer/monitor/dcgm/DcgmStatus.py +57 -0
  78. model_analyzer/monitor/dcgm/DcgmSystem.py +412 -0
  79. model_analyzer/monitor/dcgm/__init__.py +15 -0
  80. model_analyzer/monitor/dcgm/common/__init__.py +13 -0
  81. model_analyzer/monitor/dcgm/common/dcgm_client_cli_parser.py +194 -0
  82. model_analyzer/monitor/dcgm/common/dcgm_client_main.py +86 -0
  83. model_analyzer/monitor/dcgm/dcgm_agent.py +887 -0
  84. model_analyzer/monitor/dcgm/dcgm_collectd_plugin.py +369 -0
  85. model_analyzer/monitor/dcgm/dcgm_errors.py +395 -0
  86. model_analyzer/monitor/dcgm/dcgm_field_helpers.py +546 -0
  87. model_analyzer/monitor/dcgm/dcgm_fields.py +815 -0
  88. model_analyzer/monitor/dcgm/dcgm_fields_collectd.py +671 -0
  89. model_analyzer/monitor/dcgm/dcgm_fields_internal.py +29 -0
  90. model_analyzer/monitor/dcgm/dcgm_fluentd.py +45 -0
  91. model_analyzer/monitor/dcgm/dcgm_monitor.py +138 -0
  92. model_analyzer/monitor/dcgm/dcgm_prometheus.py +326 -0
  93. model_analyzer/monitor/dcgm/dcgm_structs.py +2357 -0
  94. model_analyzer/monitor/dcgm/dcgm_telegraf.py +65 -0
  95. model_analyzer/monitor/dcgm/dcgm_value.py +151 -0
  96. model_analyzer/monitor/dcgm/dcgmvalue.py +155 -0
  97. model_analyzer/monitor/dcgm/denylist_recommendations.py +573 -0
  98. model_analyzer/monitor/dcgm/pydcgm.py +47 -0
  99. model_analyzer/monitor/monitor.py +143 -0
  100. model_analyzer/monitor/remote_monitor.py +137 -0
  101. model_analyzer/output/__init__.py +15 -0
  102. model_analyzer/output/file_writer.py +63 -0
  103. model_analyzer/output/output_writer.py +42 -0
  104. model_analyzer/perf_analyzer/__init__.py +15 -0
  105. model_analyzer/perf_analyzer/genai_perf_config.py +206 -0
  106. model_analyzer/perf_analyzer/perf_analyzer.py +882 -0
  107. model_analyzer/perf_analyzer/perf_config.py +479 -0
  108. model_analyzer/plots/__init__.py +15 -0
  109. model_analyzer/plots/detailed_plot.py +266 -0
  110. model_analyzer/plots/plot_manager.py +224 -0
  111. model_analyzer/plots/simple_plot.py +213 -0
  112. model_analyzer/record/__init__.py +15 -0
  113. model_analyzer/record/gpu_record.py +68 -0
  114. model_analyzer/record/metrics_manager.py +887 -0
  115. model_analyzer/record/record.py +280 -0
  116. model_analyzer/record/record_aggregator.py +256 -0
  117. model_analyzer/record/types/__init__.py +15 -0
  118. model_analyzer/record/types/cpu_available_ram.py +93 -0
  119. model_analyzer/record/types/cpu_used_ram.py +93 -0
  120. model_analyzer/record/types/gpu_free_memory.py +96 -0
  121. model_analyzer/record/types/gpu_power_usage.py +107 -0
  122. model_analyzer/record/types/gpu_total_memory.py +96 -0
  123. model_analyzer/record/types/gpu_used_memory.py +96 -0
  124. model_analyzer/record/types/gpu_utilization.py +108 -0
  125. model_analyzer/record/types/inter_token_latency_avg.py +60 -0
  126. model_analyzer/record/types/inter_token_latency_base.py +74 -0
  127. model_analyzer/record/types/inter_token_latency_max.py +60 -0
  128. model_analyzer/record/types/inter_token_latency_min.py +60 -0
  129. model_analyzer/record/types/inter_token_latency_p25.py +60 -0
  130. model_analyzer/record/types/inter_token_latency_p50.py +60 -0
  131. model_analyzer/record/types/inter_token_latency_p75.py +60 -0
  132. model_analyzer/record/types/inter_token_latency_p90.py +60 -0
  133. model_analyzer/record/types/inter_token_latency_p95.py +60 -0
  134. model_analyzer/record/types/inter_token_latency_p99.py +60 -0
  135. model_analyzer/record/types/output_token_throughput.py +105 -0
  136. model_analyzer/record/types/perf_client_response_wait.py +97 -0
  137. model_analyzer/record/types/perf_client_send_recv.py +97 -0
  138. model_analyzer/record/types/perf_latency.py +111 -0
  139. model_analyzer/record/types/perf_latency_avg.py +60 -0
  140. model_analyzer/record/types/perf_latency_base.py +74 -0
  141. model_analyzer/record/types/perf_latency_p90.py +60 -0
  142. model_analyzer/record/types/perf_latency_p95.py +60 -0
  143. model_analyzer/record/types/perf_latency_p99.py +60 -0
  144. model_analyzer/record/types/perf_server_compute_infer.py +97 -0
  145. model_analyzer/record/types/perf_server_compute_input.py +97 -0
  146. model_analyzer/record/types/perf_server_compute_output.py +97 -0
  147. model_analyzer/record/types/perf_server_queue.py +97 -0
  148. model_analyzer/record/types/perf_throughput.py +105 -0
  149. model_analyzer/record/types/time_to_first_token_avg.py +60 -0
  150. model_analyzer/record/types/time_to_first_token_base.py +74 -0
  151. model_analyzer/record/types/time_to_first_token_max.py +60 -0
  152. model_analyzer/record/types/time_to_first_token_min.py +60 -0
  153. model_analyzer/record/types/time_to_first_token_p25.py +60 -0
  154. model_analyzer/record/types/time_to_first_token_p50.py +60 -0
  155. model_analyzer/record/types/time_to_first_token_p75.py +60 -0
  156. model_analyzer/record/types/time_to_first_token_p90.py +60 -0
  157. model_analyzer/record/types/time_to_first_token_p95.py +60 -0
  158. model_analyzer/record/types/time_to_first_token_p99.py +60 -0
  159. model_analyzer/reports/__init__.py +15 -0
  160. model_analyzer/reports/html_report.py +195 -0
  161. model_analyzer/reports/pdf_report.py +50 -0
  162. model_analyzer/reports/report.py +86 -0
  163. model_analyzer/reports/report_factory.py +62 -0
  164. model_analyzer/reports/report_manager.py +1376 -0
  165. model_analyzer/reports/report_utils.py +42 -0
  166. model_analyzer/result/__init__.py +15 -0
  167. model_analyzer/result/constraint_manager.py +150 -0
  168. model_analyzer/result/model_config_measurement.py +354 -0
  169. model_analyzer/result/model_constraints.py +105 -0
  170. model_analyzer/result/parameter_search.py +246 -0
  171. model_analyzer/result/result_manager.py +430 -0
  172. model_analyzer/result/result_statistics.py +159 -0
  173. model_analyzer/result/result_table.py +217 -0
  174. model_analyzer/result/result_table_manager.py +646 -0
  175. model_analyzer/result/result_utils.py +42 -0
  176. model_analyzer/result/results.py +277 -0
  177. model_analyzer/result/run_config_measurement.py +658 -0
  178. model_analyzer/result/run_config_result.py +210 -0
  179. model_analyzer/result/run_config_result_comparator.py +110 -0
  180. model_analyzer/result/sorted_results.py +151 -0
  181. model_analyzer/state/__init__.py +15 -0
  182. model_analyzer/state/analyzer_state.py +76 -0
  183. model_analyzer/state/analyzer_state_manager.py +215 -0
  184. model_analyzer/triton/__init__.py +15 -0
  185. model_analyzer/triton/client/__init__.py +15 -0
  186. model_analyzer/triton/client/client.py +234 -0
  187. model_analyzer/triton/client/client_factory.py +57 -0
  188. model_analyzer/triton/client/grpc_client.py +104 -0
  189. model_analyzer/triton/client/http_client.py +107 -0
  190. model_analyzer/triton/model/__init__.py +15 -0
  191. model_analyzer/triton/model/model_config.py +556 -0
  192. model_analyzer/triton/model/model_config_variant.py +29 -0
  193. model_analyzer/triton/server/__init__.py +15 -0
  194. model_analyzer/triton/server/server.py +76 -0
  195. model_analyzer/triton/server/server_config.py +269 -0
  196. model_analyzer/triton/server/server_docker.py +229 -0
  197. model_analyzer/triton/server/server_factory.py +306 -0
  198. model_analyzer/triton/server/server_local.py +158 -0
  199. triton_model_analyzer-1.48.0.dist-info/METADATA +52 -0
  200. triton_model_analyzer-1.48.0.dist-info/RECORD +204 -0
  201. triton_model_analyzer-1.48.0.dist-info/WHEEL +5 -0
  202. triton_model_analyzer-1.48.0.dist-info/entry_points.txt +2 -0
  203. triton_model_analyzer-1.48.0.dist-info/licenses/LICENSE +67 -0
  204. triton_model_analyzer-1.48.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,215 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import glob
18
+ import json
19
+ import logging
20
+ import os
21
+ import signal
22
+ import sys
23
+ import traceback
24
+
25
+ from model_analyzer.constants import LOGGER_NAME, MAX_NUMBER_OF_INTERRUPTS
26
+ from model_analyzer.model_analyzer_exceptions import TritonModelAnalyzerException
27
+ from model_analyzer.state.analyzer_state import AnalyzerState
28
+
29
+ logger = logging.getLogger(LOGGER_NAME)
30
+
31
+
32
class AnalyzerStateManager:
    """
    Maintains the state of the Model Analyzer across runs.

    Persists AnalyzerState to numbered checkpoint files
    ("<index>.ckpt") inside the configured checkpoint directory and
    installs a SIGINT handler for graceful shutdown.
    """

    def __init__(self, config, server):
        """
        Parameters
        ----------
        config: ConfigCommand
            The analyzer's config
        server : TritonServer
            Handle for tritonserver instance
        """

        self._config = config
        self._server = server
        # Count of SIGINTs received so far; > 0 means shutdown requested.
        self._exiting = 0
        self._checkpoint_dir = config.checkpoint_directory
        self._state_changed = False

        if os.path.exists(self._checkpoint_dir):
            # Resume numbering after the most recent existing checkpoint.
            self._checkpoint_index = self._latest_checkpoint() + 1
        else:
            os.makedirs(self._checkpoint_dir)
            self._checkpoint_index = 0
        signal.signal(signal.SIGINT, self.interrupt_handler)

        self._current_state = AnalyzerState()
        self._starting_fresh_run = True

    def starting_fresh_run(self):
        """
        Returns
        -------
        True if starting a fresh run
        False if checkpoint found and loaded
        """

        return self._starting_fresh_run

    def exiting(self):
        """
        Returns
        -------
        True if interrupt handler ran
        even once, False otherwise
        """

        return self._exiting > 0

    def get_state_variable(self, name):
        """
        Get a named variable from
        the current AnalyzerState

        Parameters
        ----------
        name : str
            The name of the variable
        """
        return self._current_state.get(name)

    def set_state_variable(self, name, value):
        """
        Set a named variable in
        the current AnalyzerState

        Parameters
        ----------
        name: str
            The name of the variable
        value: Any
            the value to set for that variable
        """

        # Mark dirty so the next save_checkpoint() actually writes to disk.
        self._state_changed = True
        self._current_state.set(name, value)

    def load_checkpoint(self, checkpoint_required):
        """
        Load the state of the Model Analyzer from
        most recent checkpoint file, also
        set whether we are starting a fresh run

        Parameters
        ----------
        checkpoint_required : bool
            If true, an existing checkpoint is required to run MA

        Raises
        ------
        TritonModelAnalyzerException
            If the checkpoint file is empty/corrupted, or if
            checkpoint_required is True and no checkpoint exists.
        """

        latest_checkpoint_file = os.path.join(
            self._checkpoint_dir, f"{self._latest_checkpoint()}.ckpt"
        )
        if os.path.exists(latest_checkpoint_file):
            logger.info(f"Loaded checkpoint from file {latest_checkpoint_file}")
            with open(latest_checkpoint_file, "r") as f:
                try:
                    self._current_state = AnalyzerState.from_dict(json.load(f))
                # json.load raises json.JSONDecodeError (not EOFError) on an
                # empty or truncated file; catch both so corrupted checkpoints
                # produce the actionable message instead of a raw traceback.
                except (EOFError, json.JSONDecodeError):
                    raise TritonModelAnalyzerException(
                        f"Checkpoint file {latest_checkpoint_file} is"
                        " empty or corrupted. Remove it from checkpoint"
                        " directory."
                    )
            self._starting_fresh_run = False
        else:
            if checkpoint_required:
                raise TritonModelAnalyzerException("No checkpoint file found")
            else:
                logger.info("No checkpoint file found, starting a fresh run.")

    def default_encode(self, obj):
        """
        JSON fallback encoder used by save_checkpoint():
        bytes -> utf-8 str, objects exposing to_dict() -> dict,
        anything else -> its __dict__.
        """
        if isinstance(obj, bytes):
            return obj.decode("utf-8")
        elif hasattr(obj, "to_dict"):
            return obj.to_dict()
        else:
            return obj.__dict__

    def save_checkpoint(self):
        """
        Saves the state of the model analyzer to disk
        if there has been a change since the last checkpoint
        """

        ckpt_filename = os.path.join(
            self._checkpoint_dir, f"{self._checkpoint_index}.ckpt"
        )
        if self._state_changed:
            with open(ckpt_filename, "w") as f:
                json.dump(self._current_state, f, default=self.default_encode)
            logger.info(f"Saved checkpoint to {ckpt_filename}")

            self._state_changed = False
        else:
            logger.info("No changes made to analyzer data, no checkpoint saved.")

    def interrupt_handler(self, signal, frame):
        """
        A signal handler to properly
        shutdown the model analyzer on
        interrupt. After MAX_NUMBER_OF_INTERRUPTS SIGINTs, saves state,
        stops the server, and exits immediately.
        """

        self._exiting += 1
        if logger.getEffectiveLevel() <= logging.DEBUG:
            traceback.print_stack(limit=15)
        logger.info(
            f"Received SIGINT {self._exiting}/{MAX_NUMBER_OF_INTERRUPTS}. "
            "Will attempt to exit after current measurement."
        )
        if self._exiting >= MAX_NUMBER_OF_INTERRUPTS:
            logger.info(
                "Received SIGINT maximum number of times. Saving state and exiting immediately. "
                "perf_analyzer may still be running"
            )
            self.save_checkpoint()

            # Exit server
            if self._server:
                self._server.stop()
            sys.exit(1)

    def _latest_checkpoint(self):
        """
        Get the highest index checkpoint file in the
        checkpoint directory, return its index (-1 if none exist).
        """

        checkpoint_files = glob.glob(os.path.join(self._checkpoint_dir, "*.ckpt"))
        if not checkpoint_files:
            return -1
        try:
            # Checkpoint files are named "<index>.ckpt"; extract the index.
            return max(
                int(os.path.split(f)[1].split(".")[0]) for f in checkpoint_files
            )
        except Exception as e:
            raise TritonModelAnalyzerException(e)
@@ -0,0 +1,15 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
@@ -0,0 +1,15 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
@@ -0,0 +1,234 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import logging
18
+ import time
19
+ from subprocess import DEVNULL
20
+
21
+ from model_analyzer.constants import LOGGER_NAME
22
+ from model_analyzer.model_analyzer_exceptions import TritonModelAnalyzerException
23
+
24
+ logger = logging.getLogger(LOGGER_NAME)
25
+
26
+
27
class TritonClient:
    """
    Defines the interface for the objects created by
    TritonClientFactory
    """

    def wait_for_server_ready(
        self,
        num_retries,
        sleep_time=1,
        log_file=None,
    ):
        """
        Block until the server reports ready, retrying up to num_retries.

        Parameters
        ----------
        num_retries : int
            number of times to send a ready status
            request to the server before raising
            an exception
        sleep_time: int
            amount of time in seconds to sleep between retries
        log_file: TextIOWrapper
            file that contains the server's output log
        Raises
        ------
        TritonModelAnalyzerException
            If server readiness could not be
            determined in given num_retries
        """

        retries = num_retries
        while retries > 0:
            try:
                if self._client.is_server_ready():
                    time.sleep(sleep_time)
                    return
                else:
                    # Server reachable but not ready: scan its log for fatal
                    # launch errors before sleeping and retrying.
                    self._check_for_triton_log_errors(log_file)
                    time.sleep(sleep_time)
                    retries -= 1
            except Exception as e:
                # Log connection failures with more detail for debugging
                # (first attempt and every 10th retry, to avoid log spam)
                if retries == num_retries or retries % 10 == 0:
                    logger.debug(
                        f"Failed to connect to Triton server (attempt {num_retries - retries + 1}/{num_retries}): {e}"
                    )
                self._check_for_triton_log_errors(log_file)
                time.sleep(sleep_time)
                retries -= 1
                if retries == 0:
                    raise TritonModelAnalyzerException(e)
        raise TritonModelAnalyzerException(
            "Could not determine server readiness. " "Number of retries exceeded."
        )

    def load_model(self, model_name, variant_name="", config_str=None):
        """
        Request the inference server to load
        a particular model in explicit model
        control mode.

        Parameters
        ----------
        model_name : str
            Name of the model

        variant_name: str
            Name of the model variant (defaults to model_name when empty)

        config_str: str
            Optional config string used to load the model

        Returns
        ------
        int or None
            Returns -1 if the load failed, None on success.

        Raises
        ------
        TritonModelAnalyzerException
            If the remote server is not running in EXPLICIT mode
        """

        variant_name = variant_name if variant_name else model_name

        try:
            self._client.load_model(model_name, config=config_str)
            logger.debug(f"Model {variant_name} loaded")
            return None
        except Exception as e:
            logger.info(f"Model {variant_name} load failed: {e}")
            # Use str(e), not e.message(): this handler catches any
            # Exception, and only tritonclient's InferenceServerException
            # provides a message() method — calling it on other exception
            # types would raise AttributeError and mask the real error.
            if "polling is enabled" in str(e):
                raise TritonModelAnalyzerException(
                    "The remote Tritonserver needs to be launched in EXPLICIT mode"
                )
            return -1

    def unload_model(self, model_name):
        """
        Request the inference server to unload
        a particular model in explicit model
        control mode.

        Parameters
        ----------
        model_name : str
            name of the model to unload

        Returns
        ------
        int or None
            Returns -1 if the unload failed, None on success.
        """

        try:
            self._client.unload_model(model_name)
            logger.debug(f"Model {model_name} unloaded")
            return None
        except Exception as e:
            logger.info(f"Model {model_name} unload failed: {e}")
            return -1

    def wait_for_model_ready(self, model_name, num_retries, sleep_time=1):
        """
        Returns when model is ready.

        Parameters
        ----------
        model_name : str
            name of the model to load from repository
        num_retries : int
            number of times to send a ready status
            request to the server before giving up
        sleep_time: int
            amount of time in seconds to sleep between retries

        Returns
        ------
        int or None
            Returns -1 if readiness could not be determined in
            num_retries, None on success.
        """

        retries = num_retries
        # Remember the last exception so it can be reported on failure.
        error = None
        while retries > 0:
            try:
                if self._client.is_model_ready(model_name):
                    return None
                else:
                    time.sleep(sleep_time)
                    retries -= 1
            except Exception as e:
                error = e
                time.sleep(sleep_time)
                retries -= 1

        logger.info(f"Model readiness failed for model {model_name}. Error {error}")
        return -1

    def get_model_config(self, model_name, num_retries):
        """
        Get the config for a model.

        Parameters
        ----------
        model_name : str
            Name of the model to find the config.

        num_retries : int
            Number of times to wait for the model load

        Returns
        -------
        dict or None
            A dictionary containing the model config.
        """

        # Ensure the model is loaded before querying its config.
        self.wait_for_model_ready(model_name, num_retries)
        model_config_dict = self._client.get_model_config(model_name)
        return model_config_dict

    def is_server_ready(self):
        """
        Returns true if the server is ready. Else False
        """
        return self._client.is_server_ready()

    def _check_for_triton_log_errors(self, log_file):
        """
        Scan the server's log file for known fatal launch errors and
        raise TritonModelAnalyzerException if one is found.
        No-op when no readable log file was provided.
        """
        if not log_file or log_file == DEVNULL:
            return

        # Re-read the whole log from the beginning on each call.
        log_file.seek(0)
        log_output = log_file.read()

        # Logs opened in binary mode yield bytes; normalize to str.
        if not isinstance(log_output, str):
            log_output = log_output.decode("utf-8")

        if log_output:
            if "Unexpected argument:" in log_output:
                error_start = log_output.find("Unexpected argument:")
                raise TritonModelAnalyzerException(
                    f"Error: TritonServer did not launch successfully\n\n{log_output[error_start:]}"
                )
@@ -0,0 +1,57 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from .grpc_client import TritonGRPCClient
18
+ from .http_client import TritonHTTPClient
19
+
20
+
21
class TritonClientFactory:
    """
    Base client creator class that declares
    a factory method
    """

    @staticmethod
    def create_grpc_client(server_url, ssl_options=None):
        """
        Parameters
        ----------
        server_url : str
            The url for Triton server's GRPC endpoint
        ssl_options : dict, optional
            Dictionary of SSL options for gRPC python client
            (defaults to no SSL options)

        Returns
        -------
        TritonGRPCClient
        """
        # None default avoids the shared mutable-default-argument pitfall;
        # an empty dict keeps the previous behavior for omitted options.
        if ssl_options is None:
            ssl_options = {}
        return TritonGRPCClient(server_url=server_url, ssl_options=ssl_options)

    @staticmethod
    def create_http_client(server_url, ssl_options=None):
        """
        Parameters
        ----------
        server_url : str
            The url for Triton server's HTTP endpoint
        ssl_options : dict, optional
            Dictionary of SSL options for HTTP python client
            (defaults to no SSL options)

        Returns
        -------
        TritonHTTPClient
        """
        # None default avoids the shared mutable-default-argument pitfall.
        if ssl_options is None:
            ssl_options = {}
        return TritonHTTPClient(server_url=server_url, ssl_options=ssl_options)
@@ -0,0 +1,104 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import tritonclient.grpc as grpcclient
18
+
19
+ from .client import TritonClient
20
+
21
+
22
class TritonGRPCClient(TritonClient):
    """
    Concrete implementation of TritonClient
    for GRPC
    """

    def __init__(self, server_url, ssl_options=None):
        """
        Parameters
        ----------
        server_url : str
            The url for Triton server's GRPC endpoint
        ssl_options : dict, optional
            Dictionary of SSL options for gRPC python client
            (defaults to no SSL options)
        """

        # None default avoids the shared mutable-default-argument pitfall;
        # an empty dict keeps the previous behavior for omitted options.
        if ssl_options is None:
            ssl_options = {}

        ssl = False
        root_certificates = None
        private_key = None
        certificate_chain = None

        if "ssl-grpc-use-ssl" in ssl_options:
            ssl = ssl_options["ssl-grpc-use-ssl"].lower() == "true"
        if "ssl-grpc-root-certifications-file" in ssl_options:
            root_certificates = ssl_options["ssl-grpc-root-certifications-file"]
        if "ssl-grpc-private-key-file" in ssl_options:
            private_key = ssl_options["ssl-grpc-private-key-file"]
        if "ssl-grpc-certificate-chain-file" in ssl_options:
            certificate_chain = ssl_options["ssl-grpc-certificate-chain-file"]

        # Fix for gRPC 1.60.0+: Force IPv4 resolution for localhost connections
        # gRPC 1.60.0+ prefers IPv6, causing "localhost" to resolve to [::1]
        # On systems where IPv6 is not properly configured, this causes connection failures
        # Force IPv4 by using 127.0.0.1, which is more reliable across environments
        channel_args = None
        if "localhost" in server_url:
            server_url = server_url.replace("localhost", "127.0.0.1")
            # For SSL connections, override target name to match certificate
            if ssl:
                channel_args = [("grpc.ssl_target_name_override", "localhost")]

        self._client = grpcclient.InferenceServerClient(
            url=server_url,
            ssl=ssl,
            root_certificates=root_certificates,
            private_key=private_key,
            certificate_chain=certificate_chain,
            channel_args=channel_args,
        )

    def get_model_config(self, model_name, num_retries):
        """
        Get the config for a model.

        Parameters
        ----------
        model_name : str
            Name of the model to find the config.

        num_retries : int
            Number of times to wait for the model load

        Returns
        -------
        dict
            A dictionary containing the model config.
        """

        # Ensure the model is loaded before querying its config.
        self.wait_for_model_ready(model_name, num_retries)
        model_config_dict = self._client.get_model_config(model_name, as_json=True)
        # The gRPC client wraps the config under a top-level "config" key.
        return model_config_dict["config"]

    def get_model_repository_index(self):
        """
        Returns the JSON dict holding the model repository index.
        """
        return self._client.get_model_repository_index(as_json=True)["models"]

    def is_model_ready(self, model_name: str) -> bool:
        """
        Returns true if the model is loaded on the server
        """
        return self._client.is_model_ready(model_name)