triton-model-analyzer 1.48.0__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (204)
  1. model_analyzer/__init__.py +15 -0
  2. model_analyzer/analyzer.py +448 -0
  3. model_analyzer/cli/__init__.py +15 -0
  4. model_analyzer/cli/cli.py +193 -0
  5. model_analyzer/config/__init__.py +15 -0
  6. model_analyzer/config/generate/__init__.py +15 -0
  7. model_analyzer/config/generate/automatic_model_config_generator.py +164 -0
  8. model_analyzer/config/generate/base_model_config_generator.py +352 -0
  9. model_analyzer/config/generate/brute_plus_binary_parameter_search_run_config_generator.py +164 -0
  10. model_analyzer/config/generate/brute_run_config_generator.py +154 -0
  11. model_analyzer/config/generate/concurrency_sweeper.py +75 -0
  12. model_analyzer/config/generate/config_generator_interface.py +52 -0
  13. model_analyzer/config/generate/coordinate.py +143 -0
  14. model_analyzer/config/generate/coordinate_data.py +86 -0
  15. model_analyzer/config/generate/generator_utils.py +116 -0
  16. model_analyzer/config/generate/manual_model_config_generator.py +187 -0
  17. model_analyzer/config/generate/model_config_generator_factory.py +92 -0
  18. model_analyzer/config/generate/model_profile_spec.py +74 -0
  19. model_analyzer/config/generate/model_run_config_generator.py +154 -0
  20. model_analyzer/config/generate/model_variant_name_manager.py +150 -0
  21. model_analyzer/config/generate/neighborhood.py +536 -0
  22. model_analyzer/config/generate/optuna_plus_concurrency_sweep_run_config_generator.py +141 -0
  23. model_analyzer/config/generate/optuna_run_config_generator.py +838 -0
  24. model_analyzer/config/generate/perf_analyzer_config_generator.py +312 -0
  25. model_analyzer/config/generate/quick_plus_concurrency_sweep_run_config_generator.py +130 -0
  26. model_analyzer/config/generate/quick_run_config_generator.py +753 -0
  27. model_analyzer/config/generate/run_config_generator_factory.py +329 -0
  28. model_analyzer/config/generate/search_config.py +112 -0
  29. model_analyzer/config/generate/search_dimension.py +73 -0
  30. model_analyzer/config/generate/search_dimensions.py +85 -0
  31. model_analyzer/config/generate/search_parameter.py +49 -0
  32. model_analyzer/config/generate/search_parameters.py +388 -0
  33. model_analyzer/config/input/__init__.py +15 -0
  34. model_analyzer/config/input/config_command.py +483 -0
  35. model_analyzer/config/input/config_command_profile.py +1747 -0
  36. model_analyzer/config/input/config_command_report.py +267 -0
  37. model_analyzer/config/input/config_defaults.py +236 -0
  38. model_analyzer/config/input/config_enum.py +83 -0
  39. model_analyzer/config/input/config_field.py +216 -0
  40. model_analyzer/config/input/config_list_generic.py +112 -0
  41. model_analyzer/config/input/config_list_numeric.py +151 -0
  42. model_analyzer/config/input/config_list_string.py +111 -0
  43. model_analyzer/config/input/config_none.py +71 -0
  44. model_analyzer/config/input/config_object.py +129 -0
  45. model_analyzer/config/input/config_primitive.py +81 -0
  46. model_analyzer/config/input/config_status.py +75 -0
  47. model_analyzer/config/input/config_sweep.py +83 -0
  48. model_analyzer/config/input/config_union.py +113 -0
  49. model_analyzer/config/input/config_utils.py +128 -0
  50. model_analyzer/config/input/config_value.py +243 -0
  51. model_analyzer/config/input/objects/__init__.py +15 -0
  52. model_analyzer/config/input/objects/config_model_profile_spec.py +325 -0
  53. model_analyzer/config/input/objects/config_model_report_spec.py +173 -0
  54. model_analyzer/config/input/objects/config_plot.py +198 -0
  55. model_analyzer/config/input/objects/config_protobuf_utils.py +101 -0
  56. model_analyzer/config/input/yaml_config_validator.py +82 -0
  57. model_analyzer/config/run/__init__.py +15 -0
  58. model_analyzer/config/run/model_run_config.py +313 -0
  59. model_analyzer/config/run/run_config.py +168 -0
  60. model_analyzer/constants.py +76 -0
  61. model_analyzer/device/__init__.py +15 -0
  62. model_analyzer/device/device.py +24 -0
  63. model_analyzer/device/gpu_device.py +87 -0
  64. model_analyzer/device/gpu_device_factory.py +248 -0
  65. model_analyzer/entrypoint.py +307 -0
  66. model_analyzer/log_formatter.py +65 -0
  67. model_analyzer/model_analyzer_exceptions.py +24 -0
  68. model_analyzer/model_manager.py +255 -0
  69. model_analyzer/monitor/__init__.py +15 -0
  70. model_analyzer/monitor/cpu_monitor.py +69 -0
  71. model_analyzer/monitor/dcgm/DcgmDiag.py +191 -0
  72. model_analyzer/monitor/dcgm/DcgmFieldGroup.py +83 -0
  73. model_analyzer/monitor/dcgm/DcgmGroup.py +815 -0
  74. model_analyzer/monitor/dcgm/DcgmHandle.py +141 -0
  75. model_analyzer/monitor/dcgm/DcgmJsonReader.py +69 -0
  76. model_analyzer/monitor/dcgm/DcgmReader.py +623 -0
  77. model_analyzer/monitor/dcgm/DcgmStatus.py +57 -0
  78. model_analyzer/monitor/dcgm/DcgmSystem.py +412 -0
  79. model_analyzer/monitor/dcgm/__init__.py +15 -0
  80. model_analyzer/monitor/dcgm/common/__init__.py +13 -0
  81. model_analyzer/monitor/dcgm/common/dcgm_client_cli_parser.py +194 -0
  82. model_analyzer/monitor/dcgm/common/dcgm_client_main.py +86 -0
  83. model_analyzer/monitor/dcgm/dcgm_agent.py +887 -0
  84. model_analyzer/monitor/dcgm/dcgm_collectd_plugin.py +369 -0
  85. model_analyzer/monitor/dcgm/dcgm_errors.py +395 -0
  86. model_analyzer/monitor/dcgm/dcgm_field_helpers.py +546 -0
  87. model_analyzer/monitor/dcgm/dcgm_fields.py +815 -0
  88. model_analyzer/monitor/dcgm/dcgm_fields_collectd.py +671 -0
  89. model_analyzer/monitor/dcgm/dcgm_fields_internal.py +29 -0
  90. model_analyzer/monitor/dcgm/dcgm_fluentd.py +45 -0
  91. model_analyzer/monitor/dcgm/dcgm_monitor.py +138 -0
  92. model_analyzer/monitor/dcgm/dcgm_prometheus.py +326 -0
  93. model_analyzer/monitor/dcgm/dcgm_structs.py +2357 -0
  94. model_analyzer/monitor/dcgm/dcgm_telegraf.py +65 -0
  95. model_analyzer/monitor/dcgm/dcgm_value.py +151 -0
  96. model_analyzer/monitor/dcgm/dcgmvalue.py +155 -0
  97. model_analyzer/monitor/dcgm/denylist_recommendations.py +573 -0
  98. model_analyzer/monitor/dcgm/pydcgm.py +47 -0
  99. model_analyzer/monitor/monitor.py +143 -0
  100. model_analyzer/monitor/remote_monitor.py +137 -0
  101. model_analyzer/output/__init__.py +15 -0
  102. model_analyzer/output/file_writer.py +63 -0
  103. model_analyzer/output/output_writer.py +42 -0
  104. model_analyzer/perf_analyzer/__init__.py +15 -0
  105. model_analyzer/perf_analyzer/genai_perf_config.py +206 -0
  106. model_analyzer/perf_analyzer/perf_analyzer.py +882 -0
  107. model_analyzer/perf_analyzer/perf_config.py +479 -0
  108. model_analyzer/plots/__init__.py +15 -0
  109. model_analyzer/plots/detailed_plot.py +266 -0
  110. model_analyzer/plots/plot_manager.py +224 -0
  111. model_analyzer/plots/simple_plot.py +213 -0
  112. model_analyzer/record/__init__.py +15 -0
  113. model_analyzer/record/gpu_record.py +68 -0
  114. model_analyzer/record/metrics_manager.py +887 -0
  115. model_analyzer/record/record.py +280 -0
  116. model_analyzer/record/record_aggregator.py +256 -0
  117. model_analyzer/record/types/__init__.py +15 -0
  118. model_analyzer/record/types/cpu_available_ram.py +93 -0
  119. model_analyzer/record/types/cpu_used_ram.py +93 -0
  120. model_analyzer/record/types/gpu_free_memory.py +96 -0
  121. model_analyzer/record/types/gpu_power_usage.py +107 -0
  122. model_analyzer/record/types/gpu_total_memory.py +96 -0
  123. model_analyzer/record/types/gpu_used_memory.py +96 -0
  124. model_analyzer/record/types/gpu_utilization.py +108 -0
  125. model_analyzer/record/types/inter_token_latency_avg.py +60 -0
  126. model_analyzer/record/types/inter_token_latency_base.py +74 -0
  127. model_analyzer/record/types/inter_token_latency_max.py +60 -0
  128. model_analyzer/record/types/inter_token_latency_min.py +60 -0
  129. model_analyzer/record/types/inter_token_latency_p25.py +60 -0
  130. model_analyzer/record/types/inter_token_latency_p50.py +60 -0
  131. model_analyzer/record/types/inter_token_latency_p75.py +60 -0
  132. model_analyzer/record/types/inter_token_latency_p90.py +60 -0
  133. model_analyzer/record/types/inter_token_latency_p95.py +60 -0
  134. model_analyzer/record/types/inter_token_latency_p99.py +60 -0
  135. model_analyzer/record/types/output_token_throughput.py +105 -0
  136. model_analyzer/record/types/perf_client_response_wait.py +97 -0
  137. model_analyzer/record/types/perf_client_send_recv.py +97 -0
  138. model_analyzer/record/types/perf_latency.py +111 -0
  139. model_analyzer/record/types/perf_latency_avg.py +60 -0
  140. model_analyzer/record/types/perf_latency_base.py +74 -0
  141. model_analyzer/record/types/perf_latency_p90.py +60 -0
  142. model_analyzer/record/types/perf_latency_p95.py +60 -0
  143. model_analyzer/record/types/perf_latency_p99.py +60 -0
  144. model_analyzer/record/types/perf_server_compute_infer.py +97 -0
  145. model_analyzer/record/types/perf_server_compute_input.py +97 -0
  146. model_analyzer/record/types/perf_server_compute_output.py +97 -0
  147. model_analyzer/record/types/perf_server_queue.py +97 -0
  148. model_analyzer/record/types/perf_throughput.py +105 -0
  149. model_analyzer/record/types/time_to_first_token_avg.py +60 -0
  150. model_analyzer/record/types/time_to_first_token_base.py +74 -0
  151. model_analyzer/record/types/time_to_first_token_max.py +60 -0
  152. model_analyzer/record/types/time_to_first_token_min.py +60 -0
  153. model_analyzer/record/types/time_to_first_token_p25.py +60 -0
  154. model_analyzer/record/types/time_to_first_token_p50.py +60 -0
  155. model_analyzer/record/types/time_to_first_token_p75.py +60 -0
  156. model_analyzer/record/types/time_to_first_token_p90.py +60 -0
  157. model_analyzer/record/types/time_to_first_token_p95.py +60 -0
  158. model_analyzer/record/types/time_to_first_token_p99.py +60 -0
  159. model_analyzer/reports/__init__.py +15 -0
  160. model_analyzer/reports/html_report.py +195 -0
  161. model_analyzer/reports/pdf_report.py +50 -0
  162. model_analyzer/reports/report.py +86 -0
  163. model_analyzer/reports/report_factory.py +62 -0
  164. model_analyzer/reports/report_manager.py +1376 -0
  165. model_analyzer/reports/report_utils.py +42 -0
  166. model_analyzer/result/__init__.py +15 -0
  167. model_analyzer/result/constraint_manager.py +150 -0
  168. model_analyzer/result/model_config_measurement.py +354 -0
  169. model_analyzer/result/model_constraints.py +105 -0
  170. model_analyzer/result/parameter_search.py +246 -0
  171. model_analyzer/result/result_manager.py +430 -0
  172. model_analyzer/result/result_statistics.py +159 -0
  173. model_analyzer/result/result_table.py +217 -0
  174. model_analyzer/result/result_table_manager.py +646 -0
  175. model_analyzer/result/result_utils.py +42 -0
  176. model_analyzer/result/results.py +277 -0
  177. model_analyzer/result/run_config_measurement.py +658 -0
  178. model_analyzer/result/run_config_result.py +210 -0
  179. model_analyzer/result/run_config_result_comparator.py +110 -0
  180. model_analyzer/result/sorted_results.py +151 -0
  181. model_analyzer/state/__init__.py +15 -0
  182. model_analyzer/state/analyzer_state.py +76 -0
  183. model_analyzer/state/analyzer_state_manager.py +215 -0
  184. model_analyzer/triton/__init__.py +15 -0
  185. model_analyzer/triton/client/__init__.py +15 -0
  186. model_analyzer/triton/client/client.py +234 -0
  187. model_analyzer/triton/client/client_factory.py +57 -0
  188. model_analyzer/triton/client/grpc_client.py +104 -0
  189. model_analyzer/triton/client/http_client.py +107 -0
  190. model_analyzer/triton/model/__init__.py +15 -0
  191. model_analyzer/triton/model/model_config.py +556 -0
  192. model_analyzer/triton/model/model_config_variant.py +29 -0
  193. model_analyzer/triton/server/__init__.py +15 -0
  194. model_analyzer/triton/server/server.py +76 -0
  195. model_analyzer/triton/server/server_config.py +269 -0
  196. model_analyzer/triton/server/server_docker.py +229 -0
  197. model_analyzer/triton/server/server_factory.py +306 -0
  198. model_analyzer/triton/server/server_local.py +158 -0
  199. triton_model_analyzer-1.48.0.dist-info/METADATA +52 -0
  200. triton_model_analyzer-1.48.0.dist-info/RECORD +204 -0
  201. triton_model_analyzer-1.48.0.dist-info/WHEEL +5 -0
  202. triton_model_analyzer-1.48.0.dist-info/entry_points.txt +2 -0
  203. triton_model_analyzer-1.48.0.dist-info/licenses/LICENSE +67 -0
  204. triton_model_analyzer-1.48.0.dist-info/top_level.txt +1 -0
model_analyzer/result/parameter_search.py
@@ -0,0 +1,246 @@
+ #!/usr/bin/env python3
+
+ # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import logging
+ from math import log2
+ from typing import Generator, List, Optional
+
+ from model_analyzer.config.input.config_command_profile import ConfigCommandProfile
+ from model_analyzer.constants import (
+     LOGGER_NAME,
+     THROUGHPUT_MINIMUM_CONSECUTIVE_PARAMETER_TRIES,
+     THROUGHPUT_MINIMUM_GAIN,
+ )
+ from model_analyzer.model_analyzer_exceptions import TritonModelAnalyzerException
+ from model_analyzer.result.run_config_measurement import RunConfigMeasurement
+
+ logger = logging.getLogger(LOGGER_NAME)
+
+
+ class ParameterSearch:
+     """
+     Generates the next parameter value to use when searching through
+     RunConfigMeasurements for the best value (according to the user's objective)
+       - Will sweep by powers of two from min to max parameter
+       - If the user specifies a constraint, the algorithm will perform a binary search
+         around the boundary if the constraint is violated
+
+     Invariant: It is necessary for the user to add new measurements as they are taken
+     """
+
+     def __init__(
+         self,
+         config: ConfigCommandProfile,
+         model_parameters: dict = {},
+         skip_parameter_sweep: bool = False,
+     ) -> None:
+         """
+         Parameters
+         ----------
+         config: ConfigCommandProfile
+             Profile configuration information
+         skip_parameter_sweep: bool
+             If true, skips the parameter sweep and only does the binary search
+         """
+         self._skip_parameter_sweep = skip_parameter_sweep
+         self._parameter_is_request_rate = config.is_request_rate_specified(
+             model_parameters
+         )
+
+         if self._parameter_is_request_rate:
+             self._min_parameter_index = int(
+                 log2(config.run_config_search_min_request_rate)
+             )
+             self._max_parameter_index = int(
+                 log2(config.run_config_search_max_request_rate)
+             )
+
+         else:
+             self._min_parameter_index = int(
+                 log2(config.run_config_search_min_concurrency)
+             )
+             self._max_parameter_index = int(
+                 log2(config.run_config_search_max_concurrency)
+             )
+
+         self._max_binary_search_steps = config.run_config_search_max_binary_search_steps
+
+         self._run_config_measurements: List[Optional[RunConfigMeasurement]] = []
+         self._parameters: List[int] = []
+         self._last_failing_parameter = 0
+         self._last_passing_parameter = 0
+
+     def add_run_config_measurement(
+         self, run_config_measurement: Optional[RunConfigMeasurement]
+     ) -> None:
+         """
+         Adds a new RunConfigMeasurement
+         Invariant: Assumed that RCMs are added in the same order they are measured
+         """
+         self._run_config_measurements.append(run_config_measurement)
+
+     def search_parameters(self) -> Generator[int, None, None]:
+         """
+         First performs a parameter sweep, and then, if necessary, performs
+         a binary parameter search around the point where the constraint
+         was violated
+         """
+         yield from self._perform_parameter_sweep()
+
+         if self._was_constraint_violated():
+             yield from self._perform_binary_parameter_search()
+
+     def _perform_parameter_sweep(self) -> Generator[int, None, None]:
+         for parameter in (
+             2**i
+             for i in range(self._min_parameter_index, self._max_parameter_index + 1)
+         ):
+             if self._should_continue_parameter_sweep():
+                 self._parameters.append(parameter)
+                 yield parameter
+             else:
+                 # We can't actually skip the sweep because the results need to be added,
+                 # but we can suppress the logging messages
+                 if not self._skip_parameter_sweep:
+                     if self._parameter_is_request_rate:
+                         logger.info(
+                             "Terminating request rate sweep - throughput is decreasing"
+                         )
+                     else:
+                         logger.info(
+                             "Terminating concurrency sweep - throughput is decreasing"
+                         )
+                 return
+
+     def _should_continue_parameter_sweep(self) -> bool:
+         self._check_measurement_count()
+
+         if not self._are_minimum_tries_reached():
+             return True
+         else:
+             return not self._has_objective_gain_saturated()
+
+     def _check_measurement_count(self) -> None:
+         if len(self._run_config_measurements) != len(self._parameters):
+             raise TritonModelAnalyzerException(
+                 f"Internal parameter count: {len(self._parameters)} doesn't match the number "
+                 f"of measurements added: {len(self._run_config_measurements)}."
+             )
+
+     def _are_minimum_tries_reached(self) -> bool:
+         if (
+             len(self._run_config_measurements)
+             < THROUGHPUT_MINIMUM_CONSECUTIVE_PARAMETER_TRIES
+         ):
+             return False
+         else:
+             return True
+
+     def _has_objective_gain_saturated(self) -> bool:
+         gain = self._calculate_gain()
+         return gain < THROUGHPUT_MINIMUM_GAIN
+
+     def _calculate_gain(self) -> float:
+         first_rcm = self._run_config_measurements[
+             -THROUGHPUT_MINIMUM_CONSECUTIVE_PARAMETER_TRIES
+         ]
+
+         best_rcm = self._get_best_rcm()
+
+         # These cover the cases where we don't get a result from PA
+         if not first_rcm and not best_rcm:
+             return 0
+         if not first_rcm:
+             return 1
+         elif not best_rcm:
+             return -1
+         else:
+             gain = first_rcm.compare_measurements(best_rcm)
+
+         return gain
+
+     def _get_best_rcm(self) -> Optional[RunConfigMeasurement]:
+         # Need to remove entries (None) with no result from PA before sorting
+         pruned_rcms = [
+             rcm
+             for rcm in self._run_config_measurements[
+                 -THROUGHPUT_MINIMUM_CONSECUTIVE_PARAMETER_TRIES:
+             ]
+             if rcm
+         ]
+         best_rcm = max(pruned_rcms) if pruned_rcms else None
+
+         return best_rcm
+
+     def _was_constraint_violated(self) -> bool:
+         for i in range(len(self._run_config_measurements) - 1, 1, -1):
+             if self._at_constraint_failure_boundary(i):
+                 self._last_failing_parameter = self._parameters[i]
+                 self._last_passing_parameter = self._parameters[i - 1]
+                 return True
+
+         if (
+             self._run_config_measurements[0]
+             and not self._run_config_measurements[0].is_passing_constraints()
+         ):
+             self._last_failing_parameter = self._parameters[i]
+             self._last_passing_parameter = 0
+             return True
+         else:
+             return False
+
+     def _at_constraint_failure_boundary(self, index: int) -> bool:
+         if (
+             not self._run_config_measurements[index]
+             or not self._run_config_measurements[index - 1]
+         ):
+             return False
+
+         at_failure_boundary = (
+             not self._run_config_measurements[  # type: ignore
+                 index
+             ].is_passing_constraints()
+             and self._run_config_measurements[
+                 index - 1  # type: ignore
+             ].is_passing_constraints()
+         )
+
+         return at_failure_boundary
+
+     def _perform_binary_parameter_search(self) -> Generator[int, None, None]:
+         # This is needed because we are going to restart the search from the
+         # parameter that failed - so we expect this to be at the end of the list
+         self._parameters.append(self._last_failing_parameter)
+
+         for i in range(0, self._max_binary_search_steps):
+             parameter = self._determine_next_binary_parameter()
+
+             if parameter != self._parameters[-1]:
+                 self._parameters.append(parameter)
+                 yield parameter
+
+     def _determine_next_binary_parameter(self) -> int:
+         if not self._run_config_measurements[-1]:
+             return 0
+
+         if self._run_config_measurements[-1].is_passing_constraints():
+             self._last_passing_parameter = self._parameters[-1]
+             parameter = int((self._last_failing_parameter + self._parameters[-1]) / 2)
+         else:
+             self._last_failing_parameter = self._parameters[-1]
+             parameter = int((self._last_passing_parameter + self._parameters[-1]) / 2)
+
+         return parameter
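
ParameterSearch is driven as a generator: each value yielded by search_parameters() must be answered with a call to add_run_config_measurement() before the next value is produced, which is the invariant stated in the class docstring. The sweep tries powers of two between the configured min and max; if a constraint is violated, the follow-up binary search halves the interval between the last passing and last failing values, so a pass at concurrency 32 followed by a failure at 64 would try int((32 + 64) / 2) = 48 next. Below is a minimal usage sketch, not taken from the package: profile_config is assumed to be a populated ConfigCommandProfile, and measure() is a hypothetical helper that runs the measurement step and returns a RunConfigMeasurement (or None when no result is available).

from model_analyzer.result.parameter_search import ParameterSearch

parameter_search = ParameterSearch(config=profile_config)  # profile_config: assumed ConfigCommandProfile

for parameter in parameter_search.search_parameters():
    # Run perf_analyzer (or any measurement step) at this concurrency / request rate.
    measurement = measure(parameter)  # hypothetical helper, returns RunConfigMeasurement or None

    # Invariant from the class docstring: every yielded parameter must be answered
    # with a measurement before the next one is generated.
    parameter_search.add_run_config_measurement(measurement)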
model_analyzer/result/result_manager.py
@@ -0,0 +1,430 @@
+ #!/usr/bin/env python3
+
+ # Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from collections import defaultdict
+ from typing import DefaultDict, Union
+
+ from model_analyzer.config.generate.base_model_config_generator import (
+     BaseModelConfigGenerator,
+ )
+ from model_analyzer.config.input.config_command_profile import ConfigCommandProfile
+ from model_analyzer.config.input.config_command_report import ConfigCommandReport
+ from model_analyzer.config.run.run_config import RunConfig
+ from model_analyzer.constants import TOP_MODELS_REPORT_KEY
+ from model_analyzer.model_analyzer_exceptions import TritonModelAnalyzerException
+ from model_analyzer.result.constraint_manager import ConstraintManager
+ from model_analyzer.result.result_statistics import ResultStatistics
+ from model_analyzer.state.analyzer_state_manager import AnalyzerStateManager
+
+ from .results import Results
+ from .run_config_measurement import RunConfigMeasurement
+ from .run_config_result import RunConfigResult
+ from .run_config_result_comparator import RunConfigResultComparator
+ from .sorted_results import SortedResults
+
+
+ class ResultManager:
+     """
+     This class provides methods to hold
+     and sort results
+     """
+
+     def __init__(
+         self,
+         config: Union[ConfigCommandProfile, ConfigCommandReport],
+         state_manager: AnalyzerStateManager,
+         constraint_manager: ConstraintManager,
+     ):
+         """
+         Parameters
+         ----------
+         config : ConfigCommandProfile/ConfigCommandReport
+             The model analyzer config
+         state_manager: AnalyzerStateManager
+             The object that allows control and update of state
+         constraint_manager: ConstraintManager
+             The object that handles processing and applying
+             constraints on given measurements
+         """
+
+         self._config = config
+         self._state_manager = state_manager
+         self._constraint_manager = constraint_manager
+
+         # Data structures for sorting results
+         self._per_model_sorted_results: DefaultDict[str, SortedResults] = defaultdict(
+             SortedResults
+         )
+         self._across_model_sorted_results: SortedResults = SortedResults()
+
+         if state_manager.starting_fresh_run():
+             self._init_state()
+
+         self._complete_setup()
+
+     def get_model_names(self):
+         """
+         Returns a list of model names that have sorted results
+         """
+         return list(self._per_model_sorted_results.keys())
+
+     def get_model_sorted_results(self, model_name):
+         """
+         Returns a list of sorted results for the requested model
+         """
+         if model_name not in self._per_model_sorted_results:
+             raise TritonModelAnalyzerException(
+                 f"model name {model_name} not found in result manager"
+             )
+         return self._per_model_sorted_results[model_name]
+
+     def get_across_model_sorted_results(self):
+         """
+         Returns a list of sorted results across all models
+         """
+         return self._across_model_sorted_results
+
+     def get_results(self):
+         """Returns all results (return type is Results)"""
+         return self._state_manager.get_state_variable("ResultManager.results")
+
+     def get_server_only_data(self):
+         """
+         Returns : dict
+             keys are gpu ids and values are lists of metric values
+         """
+         return self._state_manager.get_state_variable("ResultManager.server_only_data")
+
+     def add_server_data(self, data):
+         """
+         Adds data directly to the server only table
+
+         Parameters
+         ----------
+         data : dict
+             keys are gpu ids and values are lists of metric values
+         """
+
+         self._state_manager.set_state_variable("ResultManager.server_only_data", data)
+
+     def add_run_config_measurement(
+         self, run_config: RunConfig, run_config_measurement: RunConfigMeasurement
+     ) -> None:
+         """
+         Add measurement to individual result heap,
+         global result heap and results class
+         """
+         model_name = run_config.models_name()
+
+         run_config_result = RunConfigResult(
+             model_name=model_name,
+             run_config=run_config,
+             comparator=self._run_comparators[model_name],
+             constraint_manager=self._constraint_manager,
+         )
+
+         run_config_measurement.set_metric_weightings(
+             self._run_comparators[model_name].get_metric_weights()
+         )
+
+         run_config_measurement.set_model_config_weighting(
+             self._run_comparators[model_name].get_model_weights()
+         )
+
+         self._add_rcm_to_results(run_config, run_config_measurement)
+         run_config_result.add_run_config_measurement(run_config_measurement)
+
+         self._per_model_sorted_results[model_name].add_result(run_config_result)
+         self._across_model_sorted_results.add_result(run_config_result)
+
+     def get_model_configs_run_config_measurements(self, model_variants_name):
+         """
+         Unsorted list of RunConfigMeasurements for a config
+
+         Parameters
+         ----------
+         model_variants_name: str
+
+         Returns
+         -------
+         (RunConfig, list of RunConfigMeasurements)
+             The measurements for a particular config, in the order
+             they were obtained.
+         """
+
+         results = self._state_manager.get_state_variable("ResultManager.results")
+
+         # Name format is <base_model_name>_config_<number_or_default>
+         #
+         model_name = BaseModelConfigGenerator.extract_model_name_from_variant_name(
+             model_variants_name
+         )
+
+         # Remote mode has model_name == model_config_name
+         #
+         if not results.contains_model(model_name):
+             model_name = model_variants_name
+
+         if results.contains_model(model_name) and results.contains_model_variant(
+             model_name, model_variants_name
+         ):
+             return results.get_all_model_variant_measurements(
+                 model_name, model_variants_name
+             )
+         else:
+             raise TritonModelAnalyzerException(
+                 f"RunConfig {model_variants_name} requested for report step but no results were found. "
+                 "Double check the name and ensure that this model config was actually profiled."
+             )
+
+     def top_n_results(
+         self, model_name=None, n=SortedResults.GET_ALL_RESULTS, include_default=False
+     ):
+         """
+         Parameters
+         ----------
+         model_name: str
+             The name of the model
+             for which we need the top
+             n results.
+         n : int
+             The number of top results
+             to retrieve. Returns all by
+             default
+         include_default : bool
+             If true, the model's default config results will
+             be included in the returned results. In the case
+             that the default isn't one of the top n results,
+             then n+1 results will be returned
+         Returns
+         -------
+         list of RunConfigResults
+             The n best results for this model,
+             must all be passing results
+         """
+
+         if model_name:
+             results = self._per_model_sorted_results[model_name]
+         else:
+             results = self._across_model_sorted_results
+
+         top_results = results.top_n_results(n)
+
+         if include_default:
+             self._add_default_to_results(model_name, top_results, results)
+
+         return top_results
+
+     def get_result_statistics(self):
+         """
+         This function computes statistics
+         with results currently in the result
+         manager's heap
+         """
+
+         def _update_stats(statistics, sorted_results, stats_key):
+             passing_measurements = 0
+             failing_measurements = 0
+             total_configs = 0
+             for result in sorted_results.results():
+                 total_configs += 1
+                 passing_measurements += len(result.passing_measurements())
+                 failing_measurements += len(result.failing_measurements())
+
+             statistics.set_total_configurations(stats_key, total_configs)
+             statistics.set_passing_measurements(stats_key, passing_measurements)
+             statistics.set_failing_measurements(stats_key, failing_measurements)
+
+         result_stats = ResultStatistics()
+         for model_name, sorted_results in self._per_model_sorted_results.items():
+             _update_stats(result_stats, sorted_results, model_name)
+
+         _update_stats(
+             result_stats, self._across_model_sorted_results, TOP_MODELS_REPORT_KEY
+         )
+
+         return result_stats
+
+     def _init_state(self):
+         """
+         Sets ResultManager object managed
+         state variables in AnalyzerState
+         """
+
+         self._state_manager.set_state_variable("ResultManager.results", Results())
+         self._state_manager.set_state_variable("ResultManager.server_only_data", {})
+
+     def _complete_setup(self):
+         # The Report subcommand can init, but nothing needs to be done
+         if isinstance(self._config, ConfigCommandProfile):
+             self._complete_profile_setup()
+         elif isinstance(self._config, ConfigCommandReport):
+             pass
+         else:
+             raise TritonModelAnalyzerException(
+                 f"Expected config of type ConfigCommandProfile/ConfigCommandReport,"
+                 f" got {type(self._config)}."
+             )
+
+     def _complete_profile_setup(self):
+         self._create_concurrent_profile_model_name()
+
+         if self._config.run_config_profile_models_concurrently_enable:
+             self._setup_for_concurrent_profile()
+         else:
+             self._setup_for_sequential_profile()
+
+         self._add_results_to_heaps(suppress_warnings=True)
+
+     def _create_concurrent_profile_model_name(self):
+         profile_model_names = [
+             model.model_name() for model in self._config.profile_models
+         ]
+
+         self._concurrent_profile_model_name = ",".join(profile_model_names)
+
+     def _profiling_models_concurrently(self):
+         """
+         Returns
+         -------
+         bool: True if we are doing concurrent model profile
+         """
+         results = self._state_manager.get_state_variable("ResultManager.results")
+
+         return bool(
+             results.get_model_measurements_dict(
+                 models_name=self._concurrent_profile_model_name, suppress_warning=True
+             )
+             and len(self._config.profile_models) > 1
+         )
+
+     def _setup_for_concurrent_profile(self):
+         self._profile_model_names = [self._concurrent_profile_model_name]
+
+         model_objectives_list = [
+             model.objectives() for model in self._config.profile_models
+         ]
+         model_weighting_list = [
+             model.weighting() for model in self._config.profile_models
+         ]
+
+         self._run_comparators = {
+             self._concurrent_profile_model_name: RunConfigResultComparator(
+                 metric_objectives_list=model_objectives_list,
+                 model_weights=model_weighting_list,
+             )
+         }
+
+     def _setup_for_sequential_profile(self):
+         self._profile_model_names = [
+             model.model_name() for model in self._config.profile_models
+         ]
+
+         self._run_comparators = {
+             model.model_name(): RunConfigResultComparator(
+                 metric_objectives_list=[model.objectives()],
+                 model_weights=[model.weighting()],
+             )
+             for model in self._config.profile_models
+         }
+
+     def _add_rcm_to_results(self, run_config, run_config_measurement):
+         """
+         This function adds model inference
+         measurements to the required result
+
+         Parameters
+         ----------
+         run_config : RunConfig
+             Contains the parameters used to generate the measurement
+         run_config_measurement: RunConfigMeasurement
+             the measurement to be added
+         """
+
+         # Get reference to results state and modify it
+         results = self._state_manager.get_state_variable("ResultManager.results")
+
+         results.add_run_config_measurement(run_config, run_config_measurement)
+
+         # Use set_state_variable to record that state may have been changed
+         self._state_manager.set_state_variable(
+             name="ResultManager.results", value=results
+         )
+
+     def _add_results_to_heaps(self, suppress_warnings=False):
+         """
+         Construct and add results to individual result heaps
+         as well as global result heap
+         """
+         results = self._state_manager.get_state_variable("ResultManager.results")
+
+         for model_name in self._profile_model_names:
+             model_measurements = results.get_model_measurements_dict(
+                 model_name, suppress_warnings
+             )
+
+             # Only add in models that exist in the checkpoint
+             if not model_measurements:
+                 continue
+
+             for run_config, run_config_measurements in model_measurements.values():
+                 run_config_result = RunConfigResult(
+                     model_name=model_name,
+                     run_config=run_config,
+                     comparator=self._run_comparators[model_name],
+                     constraint_manager=self._constraint_manager,
+                 )
+
+                 for run_config_measurement in run_config_measurements.values():
+                     run_config_measurement.set_metric_weightings(
+                         self._run_comparators[model_name].get_metric_weights()
+                     )
+
+                     run_config_measurement.set_model_config_weighting(
+                         self._run_comparators[model_name].get_model_weights()
+                     )
+
+                     run_config_result.add_run_config_measurement(run_config_measurement)
+
+                 self._per_model_sorted_results[model_name].add_result(run_config_result)
+                 self._across_model_sorted_results.add_result(run_config_result)
+
+     def _add_default_to_results(self, model_name, results, sorted_results):
+         """
+         If the default config is already in results, keep it there. Else, find and
+         add it from the result heap
+         """
+         if not model_name:
+             return
+
+         model_names = model_name.split(",")
+         model_names = [model_name + "_config_default" for model_name in model_names]
+         default_model_name = ",".join(model_names)
+
+         for run_config_result in results:
+             if (
+                 run_config_result.run_config().model_variants_name()
+                 == default_model_name
+             ):
+                 return
+
+         for run_config_result in sorted_results.results():
+             if (
+                 run_config_result.run_config().model_variants_name()
+                 == default_model_name
+             ):
+                 results.append(run_config_result)
+                 return
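
ResultManager keeps two sets of sorted results: a per-model heap keyed by model name and a single heap across all models, both fed through add_run_config_measurement() and persisted via the AnalyzerStateManager so that results from a checkpoint can be reloaded into the heaps. The sketch below is not taken from the package; it only illustrates the flow shown above, assuming that profile_config, state_manager, constraint_manager, and the (run_config, measurement) pairs come from the surrounding profiling machinery, and that "my_model" is a placeholder model name.

from model_analyzer.result.result_manager import ResultManager

result_manager = ResultManager(
    config=profile_config,                  # assumed ConfigCommandProfile
    state_manager=state_manager,            # assumed AnalyzerStateManager
    constraint_manager=constraint_manager,  # assumed ConstraintManager
)

# During profiling, each measurement lands in the per-model heap, the
# across-model heap, and the checkpointed Results object.
for run_config, measurement in measured_pairs:  # hypothetical iterable of profiling results
    result_manager.add_run_config_measurement(run_config, measurement)

# After profiling, pull the three best passing configs for one model, keeping
# the default config in the list even if it did not make the top three.
best = result_manager.top_n_results(model_name="my_model", n=3, include_default=True)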