triton-model-analyzer 1.48.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (204):
  1. model_analyzer/__init__.py +15 -0
  2. model_analyzer/analyzer.py +448 -0
  3. model_analyzer/cli/__init__.py +15 -0
  4. model_analyzer/cli/cli.py +193 -0
  5. model_analyzer/config/__init__.py +15 -0
  6. model_analyzer/config/generate/__init__.py +15 -0
  7. model_analyzer/config/generate/automatic_model_config_generator.py +164 -0
  8. model_analyzer/config/generate/base_model_config_generator.py +352 -0
  9. model_analyzer/config/generate/brute_plus_binary_parameter_search_run_config_generator.py +164 -0
  10. model_analyzer/config/generate/brute_run_config_generator.py +154 -0
  11. model_analyzer/config/generate/concurrency_sweeper.py +75 -0
  12. model_analyzer/config/generate/config_generator_interface.py +52 -0
  13. model_analyzer/config/generate/coordinate.py +143 -0
  14. model_analyzer/config/generate/coordinate_data.py +86 -0
  15. model_analyzer/config/generate/generator_utils.py +116 -0
  16. model_analyzer/config/generate/manual_model_config_generator.py +187 -0
  17. model_analyzer/config/generate/model_config_generator_factory.py +92 -0
  18. model_analyzer/config/generate/model_profile_spec.py +74 -0
  19. model_analyzer/config/generate/model_run_config_generator.py +154 -0
  20. model_analyzer/config/generate/model_variant_name_manager.py +150 -0
  21. model_analyzer/config/generate/neighborhood.py +536 -0
  22. model_analyzer/config/generate/optuna_plus_concurrency_sweep_run_config_generator.py +141 -0
  23. model_analyzer/config/generate/optuna_run_config_generator.py +838 -0
  24. model_analyzer/config/generate/perf_analyzer_config_generator.py +312 -0
  25. model_analyzer/config/generate/quick_plus_concurrency_sweep_run_config_generator.py +130 -0
  26. model_analyzer/config/generate/quick_run_config_generator.py +753 -0
  27. model_analyzer/config/generate/run_config_generator_factory.py +329 -0
  28. model_analyzer/config/generate/search_config.py +112 -0
  29. model_analyzer/config/generate/search_dimension.py +73 -0
  30. model_analyzer/config/generate/search_dimensions.py +85 -0
  31. model_analyzer/config/generate/search_parameter.py +49 -0
  32. model_analyzer/config/generate/search_parameters.py +388 -0
  33. model_analyzer/config/input/__init__.py +15 -0
  34. model_analyzer/config/input/config_command.py +483 -0
  35. model_analyzer/config/input/config_command_profile.py +1747 -0
  36. model_analyzer/config/input/config_command_report.py +267 -0
  37. model_analyzer/config/input/config_defaults.py +236 -0
  38. model_analyzer/config/input/config_enum.py +83 -0
  39. model_analyzer/config/input/config_field.py +216 -0
  40. model_analyzer/config/input/config_list_generic.py +112 -0
  41. model_analyzer/config/input/config_list_numeric.py +151 -0
  42. model_analyzer/config/input/config_list_string.py +111 -0
  43. model_analyzer/config/input/config_none.py +71 -0
  44. model_analyzer/config/input/config_object.py +129 -0
  45. model_analyzer/config/input/config_primitive.py +81 -0
  46. model_analyzer/config/input/config_status.py +75 -0
  47. model_analyzer/config/input/config_sweep.py +83 -0
  48. model_analyzer/config/input/config_union.py +113 -0
  49. model_analyzer/config/input/config_utils.py +128 -0
  50. model_analyzer/config/input/config_value.py +243 -0
  51. model_analyzer/config/input/objects/__init__.py +15 -0
  52. model_analyzer/config/input/objects/config_model_profile_spec.py +325 -0
  53. model_analyzer/config/input/objects/config_model_report_spec.py +173 -0
  54. model_analyzer/config/input/objects/config_plot.py +198 -0
  55. model_analyzer/config/input/objects/config_protobuf_utils.py +101 -0
  56. model_analyzer/config/input/yaml_config_validator.py +82 -0
  57. model_analyzer/config/run/__init__.py +15 -0
  58. model_analyzer/config/run/model_run_config.py +313 -0
  59. model_analyzer/config/run/run_config.py +168 -0
  60. model_analyzer/constants.py +76 -0
  61. model_analyzer/device/__init__.py +15 -0
  62. model_analyzer/device/device.py +24 -0
  63. model_analyzer/device/gpu_device.py +87 -0
  64. model_analyzer/device/gpu_device_factory.py +248 -0
  65. model_analyzer/entrypoint.py +307 -0
  66. model_analyzer/log_formatter.py +65 -0
  67. model_analyzer/model_analyzer_exceptions.py +24 -0
  68. model_analyzer/model_manager.py +255 -0
  69. model_analyzer/monitor/__init__.py +15 -0
  70. model_analyzer/monitor/cpu_monitor.py +69 -0
  71. model_analyzer/monitor/dcgm/DcgmDiag.py +191 -0
  72. model_analyzer/monitor/dcgm/DcgmFieldGroup.py +83 -0
  73. model_analyzer/monitor/dcgm/DcgmGroup.py +815 -0
  74. model_analyzer/monitor/dcgm/DcgmHandle.py +141 -0
  75. model_analyzer/monitor/dcgm/DcgmJsonReader.py +69 -0
  76. model_analyzer/monitor/dcgm/DcgmReader.py +623 -0
  77. model_analyzer/monitor/dcgm/DcgmStatus.py +57 -0
  78. model_analyzer/monitor/dcgm/DcgmSystem.py +412 -0
  79. model_analyzer/monitor/dcgm/__init__.py +15 -0
  80. model_analyzer/monitor/dcgm/common/__init__.py +13 -0
  81. model_analyzer/monitor/dcgm/common/dcgm_client_cli_parser.py +194 -0
  82. model_analyzer/monitor/dcgm/common/dcgm_client_main.py +86 -0
  83. model_analyzer/monitor/dcgm/dcgm_agent.py +887 -0
  84. model_analyzer/monitor/dcgm/dcgm_collectd_plugin.py +369 -0
  85. model_analyzer/monitor/dcgm/dcgm_errors.py +395 -0
  86. model_analyzer/monitor/dcgm/dcgm_field_helpers.py +546 -0
  87. model_analyzer/monitor/dcgm/dcgm_fields.py +815 -0
  88. model_analyzer/monitor/dcgm/dcgm_fields_collectd.py +671 -0
  89. model_analyzer/monitor/dcgm/dcgm_fields_internal.py +29 -0
  90. model_analyzer/monitor/dcgm/dcgm_fluentd.py +45 -0
  91. model_analyzer/monitor/dcgm/dcgm_monitor.py +138 -0
  92. model_analyzer/monitor/dcgm/dcgm_prometheus.py +326 -0
  93. model_analyzer/monitor/dcgm/dcgm_structs.py +2357 -0
  94. model_analyzer/monitor/dcgm/dcgm_telegraf.py +65 -0
  95. model_analyzer/monitor/dcgm/dcgm_value.py +151 -0
  96. model_analyzer/monitor/dcgm/dcgmvalue.py +155 -0
  97. model_analyzer/monitor/dcgm/denylist_recommendations.py +573 -0
  98. model_analyzer/monitor/dcgm/pydcgm.py +47 -0
  99. model_analyzer/monitor/monitor.py +143 -0
  100. model_analyzer/monitor/remote_monitor.py +137 -0
  101. model_analyzer/output/__init__.py +15 -0
  102. model_analyzer/output/file_writer.py +63 -0
  103. model_analyzer/output/output_writer.py +42 -0
  104. model_analyzer/perf_analyzer/__init__.py +15 -0
  105. model_analyzer/perf_analyzer/genai_perf_config.py +206 -0
  106. model_analyzer/perf_analyzer/perf_analyzer.py +882 -0
  107. model_analyzer/perf_analyzer/perf_config.py +479 -0
  108. model_analyzer/plots/__init__.py +15 -0
  109. model_analyzer/plots/detailed_plot.py +266 -0
  110. model_analyzer/plots/plot_manager.py +224 -0
  111. model_analyzer/plots/simple_plot.py +213 -0
  112. model_analyzer/record/__init__.py +15 -0
  113. model_analyzer/record/gpu_record.py +68 -0
  114. model_analyzer/record/metrics_manager.py +887 -0
  115. model_analyzer/record/record.py +280 -0
  116. model_analyzer/record/record_aggregator.py +256 -0
  117. model_analyzer/record/types/__init__.py +15 -0
  118. model_analyzer/record/types/cpu_available_ram.py +93 -0
  119. model_analyzer/record/types/cpu_used_ram.py +93 -0
  120. model_analyzer/record/types/gpu_free_memory.py +96 -0
  121. model_analyzer/record/types/gpu_power_usage.py +107 -0
  122. model_analyzer/record/types/gpu_total_memory.py +96 -0
  123. model_analyzer/record/types/gpu_used_memory.py +96 -0
  124. model_analyzer/record/types/gpu_utilization.py +108 -0
  125. model_analyzer/record/types/inter_token_latency_avg.py +60 -0
  126. model_analyzer/record/types/inter_token_latency_base.py +74 -0
  127. model_analyzer/record/types/inter_token_latency_max.py +60 -0
  128. model_analyzer/record/types/inter_token_latency_min.py +60 -0
  129. model_analyzer/record/types/inter_token_latency_p25.py +60 -0
  130. model_analyzer/record/types/inter_token_latency_p50.py +60 -0
  131. model_analyzer/record/types/inter_token_latency_p75.py +60 -0
  132. model_analyzer/record/types/inter_token_latency_p90.py +60 -0
  133. model_analyzer/record/types/inter_token_latency_p95.py +60 -0
  134. model_analyzer/record/types/inter_token_latency_p99.py +60 -0
  135. model_analyzer/record/types/output_token_throughput.py +105 -0
  136. model_analyzer/record/types/perf_client_response_wait.py +97 -0
  137. model_analyzer/record/types/perf_client_send_recv.py +97 -0
  138. model_analyzer/record/types/perf_latency.py +111 -0
  139. model_analyzer/record/types/perf_latency_avg.py +60 -0
  140. model_analyzer/record/types/perf_latency_base.py +74 -0
  141. model_analyzer/record/types/perf_latency_p90.py +60 -0
  142. model_analyzer/record/types/perf_latency_p95.py +60 -0
  143. model_analyzer/record/types/perf_latency_p99.py +60 -0
  144. model_analyzer/record/types/perf_server_compute_infer.py +97 -0
  145. model_analyzer/record/types/perf_server_compute_input.py +97 -0
  146. model_analyzer/record/types/perf_server_compute_output.py +97 -0
  147. model_analyzer/record/types/perf_server_queue.py +97 -0
  148. model_analyzer/record/types/perf_throughput.py +105 -0
  149. model_analyzer/record/types/time_to_first_token_avg.py +60 -0
  150. model_analyzer/record/types/time_to_first_token_base.py +74 -0
  151. model_analyzer/record/types/time_to_first_token_max.py +60 -0
  152. model_analyzer/record/types/time_to_first_token_min.py +60 -0
  153. model_analyzer/record/types/time_to_first_token_p25.py +60 -0
  154. model_analyzer/record/types/time_to_first_token_p50.py +60 -0
  155. model_analyzer/record/types/time_to_first_token_p75.py +60 -0
  156. model_analyzer/record/types/time_to_first_token_p90.py +60 -0
  157. model_analyzer/record/types/time_to_first_token_p95.py +60 -0
  158. model_analyzer/record/types/time_to_first_token_p99.py +60 -0
  159. model_analyzer/reports/__init__.py +15 -0
  160. model_analyzer/reports/html_report.py +195 -0
  161. model_analyzer/reports/pdf_report.py +50 -0
  162. model_analyzer/reports/report.py +86 -0
  163. model_analyzer/reports/report_factory.py +62 -0
  164. model_analyzer/reports/report_manager.py +1376 -0
  165. model_analyzer/reports/report_utils.py +42 -0
  166. model_analyzer/result/__init__.py +15 -0
  167. model_analyzer/result/constraint_manager.py +150 -0
  168. model_analyzer/result/model_config_measurement.py +354 -0
  169. model_analyzer/result/model_constraints.py +105 -0
  170. model_analyzer/result/parameter_search.py +246 -0
  171. model_analyzer/result/result_manager.py +430 -0
  172. model_analyzer/result/result_statistics.py +159 -0
  173. model_analyzer/result/result_table.py +217 -0
  174. model_analyzer/result/result_table_manager.py +646 -0
  175. model_analyzer/result/result_utils.py +42 -0
  176. model_analyzer/result/results.py +277 -0
  177. model_analyzer/result/run_config_measurement.py +658 -0
  178. model_analyzer/result/run_config_result.py +210 -0
  179. model_analyzer/result/run_config_result_comparator.py +110 -0
  180. model_analyzer/result/sorted_results.py +151 -0
  181. model_analyzer/state/__init__.py +15 -0
  182. model_analyzer/state/analyzer_state.py +76 -0
  183. model_analyzer/state/analyzer_state_manager.py +215 -0
  184. model_analyzer/triton/__init__.py +15 -0
  185. model_analyzer/triton/client/__init__.py +15 -0
  186. model_analyzer/triton/client/client.py +234 -0
  187. model_analyzer/triton/client/client_factory.py +57 -0
  188. model_analyzer/triton/client/grpc_client.py +104 -0
  189. model_analyzer/triton/client/http_client.py +107 -0
  190. model_analyzer/triton/model/__init__.py +15 -0
  191. model_analyzer/triton/model/model_config.py +556 -0
  192. model_analyzer/triton/model/model_config_variant.py +29 -0
  193. model_analyzer/triton/server/__init__.py +15 -0
  194. model_analyzer/triton/server/server.py +76 -0
  195. model_analyzer/triton/server/server_config.py +269 -0
  196. model_analyzer/triton/server/server_docker.py +229 -0
  197. model_analyzer/triton/server/server_factory.py +306 -0
  198. model_analyzer/triton/server/server_local.py +158 -0
  199. triton_model_analyzer-1.48.0.dist-info/METADATA +52 -0
  200. triton_model_analyzer-1.48.0.dist-info/RECORD +204 -0
  201. triton_model_analyzer-1.48.0.dist-info/WHEEL +5 -0
  202. triton_model_analyzer-1.48.0.dist-info/entry_points.txt +2 -0
  203. triton_model_analyzer-1.48.0.dist-info/licenses/LICENSE +67 -0
  204. triton_model_analyzer-1.48.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,107 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import logging
18
+
19
+ import gevent.ssl
20
+ import tritonclient.http as httpclient
21
+
22
+ from .client import TritonClient
23
+
24
+
25
class TritonHTTPClient(TritonClient):
    """
    Concrete implementation of TritonClient
    for HTTP
    """

    def __init__(self, server_url, ssl_options=None):
        """
        Parameters
        ----------
        server_url : str
            The url for Triton server's HTTP endpoint
        ssl_options : dict, optional
            Dictionary of SSL options for HTTP python client.
            Defaults to no options.
        """

        # Avoid the mutable-default-argument pitfall: treat None as "no
        # SSL options". Behavior is identical to the previous `{}` default.
        if ssl_options is None:
            ssl_options = {}

        ssl = False
        client_ssl_options = {}
        # Start from an unverified (insecure) TLS context; upgraded below
        # only when both peer and host verification are requested.
        ssl_context_factory = gevent.ssl._create_unverified_context
        insecure = True
        verify_peer = 0
        verify_host = 0

        if server_url.startswith("http://"):
            server_url = server_url.replace("http://", "", 1)
        elif server_url.startswith("https://"):
            ssl = True
            server_url = server_url.replace("https://", "", 1)
            if "ssl-https-ca-certificates-file" in ssl_options:
                client_ssl_options["ca_certs"] = ssl_options[
                    "ssl-https-ca-certificates-file"
                ]
            if "ssl-https-client-certificate-file" in ssl_options:
                # tritonclient's HTTP stack only accepts PEM certificates.
                if (
                    "ssl-https-client-certificate-type" in ssl_options
                    and ssl_options["ssl-https-client-certificate-type"] == "PEM"
                ):
                    client_ssl_options["certfile"] = ssl_options[
                        "ssl-https-client-certificate-file"
                    ]
                else:
                    logging.warning(
                        "model-analyzer with SSL must be passed a client certificate file in PEM format."
                    )
            if "ssl-https-private-key-file" in ssl_options:
                if (
                    "ssl-https-private-key-type" in ssl_options
                    and ssl_options["ssl-https-private-key-type"] == "PEM"
                ):
                    client_ssl_options["keyfile"] = ssl_options[
                        "ssl-https-private-key-file"
                    ]
                else:
                    logging.warning(
                        "model-analyzer with SSL must be passed a private key file in PEM format."
                    )
            if "ssl-https-verify-peer" in ssl_options:
                verify_peer = ssl_options["ssl-https-verify-peer"]
            if "ssl-https-verify-host" in ssl_options:
                verify_host = ssl_options["ssl-https-verify-host"]
            # Only use a fully verified default context when BOTH checks
            # are enabled; otherwise keep the unverified context above.
            if verify_peer != 0 and verify_host != 0:
                ssl_context_factory = None
                insecure = False

        self._client = httpclient.InferenceServerClient(
            url=server_url,
            ssl=ssl,
            ssl_options=client_ssl_options,
            ssl_context_factory=ssl_context_factory,
            insecure=insecure,
        )

    def get_model_repository_index(self):
        """
        Returns the JSON dict holding the model repository index.
        """
        return self._client.get_model_repository_index()

    def is_model_ready(self, model_name: str) -> bool:
        """
        Returns true if the model is loaded on the server
        """
        return self._client.is_model_ready(model_name)
@@ -0,0 +1,15 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
@@ -0,0 +1,556 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import json
18
+ import os
19
+ from copy import deepcopy
20
+ from shutil import copytree
21
+ from typing import Any, Dict, List, Optional
22
+
23
+ from google.protobuf import json_format, text_format
24
+ from google.protobuf.descriptor import FieldDescriptor
25
+ from numba import cuda
26
+ from tritonclient.grpc import model_config_pb2
27
+
28
+ from model_analyzer.config.input.config_command_profile import ConfigCommandProfile
29
+ from model_analyzer.config.input.objects.config_model_profile_spec import (
30
+ ConfigModelProfileSpec,
31
+ )
32
+ from model_analyzer.device.gpu_device import GPUDevice
33
+ from model_analyzer.model_analyzer_exceptions import TritonModelAnalyzerException
34
+ from model_analyzer.triton.client.client import TritonClient
35
+ from model_analyzer.triton.server.server_factory import TritonServerFactory
36
+
37
+
38
class ModelConfig:
    """
    A class that encapsulates all the metadata about a Triton model.

    Wraps a `model_config_pb2.ModelConfig` protobuf message and provides
    conversion to/from dicts, creation from files/server/API, and helpers
    for interrogating batching and instance-group settings.
    """

    # Process-wide cache of base (default) model config dicts, keyed by
    # model name, so the (possibly expensive) server round-trip happens once.
    _default_config_dict: Dict[str, Any] = {}

    def __init__(self, model_config):
        """
        Parameters
        -------
        model_config : protobuf message
            The model_config_pb2.ModelConfig message to wrap.
        """

        self._model_config = model_config

    def to_dict(self):
        """Return the wrapped config as a dict (camelCase field names)."""
        return json_format.MessageToDict(self._model_config)

    @classmethod
    def from_dict(cls, model_config_dict):
        """Construct a ModelConfig from a dict (inverse of to_dict)."""
        return ModelConfig.create_from_dictionary(model_config_dict)

    @staticmethod
    def create_model_config_dict(config, client, gpus, model_repository, model_name):
        """
        Attempts to create a base model config dict from config.pbtxt, if one exists
        If the config.pbtxt is not present, we will load a Triton Server with the
        base model and have it create a default config for MA, if possible

        Parameters:
        -----------
        config: ModelAnalyzerConfig
        client: TritonClient
        gpus: List of GPUDevices
        model_repository: str
            path to the model repository on the file system
        model_name: str
            name of the base model

        Returns
        -------
        dict
            A deep copy of the base config dict (safe for the caller to mutate).
        """

        # Serve cached result when this model was already resolved.
        if (
            ModelConfig._default_config_dict
            and model_name in ModelConfig._default_config_dict
        ):
            return deepcopy(ModelConfig._default_config_dict[model_name])

        model_path = f"{model_repository}/{model_name}"

        try:
            config = ModelConfig._create_from_file(model_path).get_config()
        except Exception:
            # No readable config.pbtxt -- fall back to asking a live Triton
            # server to auto-complete a default config, when the launch mode
            # supports it.
            if ModelConfig._can_launch_mode_get_default_config_from_server(config):
                config = ModelConfig._get_default_config_from_server(
                    config, client, gpus, model_name
                )

                # An auto-completed triton model config will set preferred_batch_size
                # to a default value. We do not want to keep and honor that
                # value when we are searching, so we discard it here
                if (
                    "dynamic_batching" in config
                    and "preferred_batch_size" in config["dynamic_batching"]
                ):
                    del config["dynamic_batching"]["preferred_batch_size"]

            else:
                # Always raises with a mode-specific explanation.
                ModelConfig._check_default_config_exceptions(config, model_path)

        ModelConfig._default_config_dict[model_name] = config
        return deepcopy(config)

    @staticmethod
    def _can_launch_mode_get_default_config_from_server(config):
        """Return True when the launch mode allows server-side default config creation."""
        launch_mode_can_get_default_config = (
            config.triton_launch_mode == "docker"
            or config.triton_launch_mode == "local"
            or config.triton_launch_mode == "remote"
        )
        return launch_mode_can_get_default_config

    @staticmethod
    def _get_default_config_from_server(config, client, gpus, model_name):
        """
        Load a Triton Server with the base model and have it create
        a default config for MA, if possible

        Parameters:
        -----------
        config: ModelAnalyzerConfig
        client: TritonClient
        gpus: List of GPUDevices
        model_name: str
            name of the base model

        Raises
        ------
        TritonModelAnalyzerException
            If the server's config lacks input/output sections.
        """

        server = TritonServerFactory.get_server_handle(
            config, gpus, use_model_repository=True
        )

        server.start()
        client.wait_for_server_ready(
            num_retries=config.client_max_retries, log_file=server.log_file()
        )

        # NOTE(review): on load failure the server is stopped but execution
        # continues into wait_for_model_ready -- looks like a missing
        # return/raise; preserved as-is to avoid a behavior change.
        if client.load_model(model_name=model_name) == -1:
            server.stop()

        client.wait_for_model_ready(model_name, config.client_max_retries)

        config = client.get_model_config(model_name, config.client_max_retries)

        server.stop()

        if "input" not in config or "output" not in config:
            raise TritonModelAnalyzerException(
                "Attempted have Triton create a default config, but this is not possible for this model type."
            )

        return config

    @staticmethod
    def _check_default_config_exceptions(config, model_path):
        """Raise the most specific exception for a missing/invalid base config. Always raises."""
        if config.triton_launch_mode != "remote":
            if not os.path.exists(model_path):
                raise TritonModelAnalyzerException(
                    f'Model path "{model_path}" specified does not exist.'
                )

            if os.path.isfile(model_path):
                raise TritonModelAnalyzerException(
                    f'Model output path "{model_path}" must be a directory.'
                )

        model_config_path = os.path.join(model_path, "config.pbtxt")
        raise TritonModelAnalyzerException(
            f'Path "{model_config_path}" does not exist.'
            f" Triton does not support default config creation for {config.triton_launch_mode} mode."
        )

    @staticmethod
    def _create_from_file(model_path):
        """
        Constructs a ModelConfig from the pbtxt at file

        Parameters
        -------
        model_path : str
            The full path to this model directory

        Returns
        -------
        ModelConfig

        Raises
        ------
        TritonModelAnalyzerException
            If the path doesn't exist, is a file, or has no config.pbtxt.
        """

        if not os.path.exists(model_path):
            raise TritonModelAnalyzerException(
                f'Model path "{model_path}" specified does not exist.'
            )

        if os.path.isfile(model_path):
            raise TritonModelAnalyzerException(
                f'Model output path "{model_path}" must be a directory.'
            )

        model_config_path = os.path.join(model_path, "config.pbtxt")
        if not os.path.isfile(model_config_path):
            raise TritonModelAnalyzerException(
                f'Path "{model_config_path}" does not exist.'
                " Make sure that you have specified the correct model"
                " repository and model name(s)."
            )

        # Open read-only: "r+" required write access for a pure read and
        # failed on read-only model repositories.
        with open(model_config_path, "r") as f:
            config_str = f.read()

        protobuf_message = text_format.Parse(config_str, model_config_pb2.ModelConfig())

        return ModelConfig(protobuf_message)

    @staticmethod
    def create_from_dictionary(model_dict):
        """
        Constructs a ModelConfig from a Python dictionary

        Parameters
        -------
        model_dict : dict
            A dictionary containing the model configuration.

        Returns
        -------
        ModelConfig
        """

        protobuf_message = json_format.ParseDict(
            model_dict, model_config_pb2.ModelConfig()
        )

        return ModelConfig(protobuf_message)

    @staticmethod
    def create_from_triton_api(client, model_name, num_retries):
        """
        Creates the model config from the Triton API.

        Parameters
        ----------
        client : TritonClient
            Triton client to use to call the API
        model_name : str
            Name of the model to request config for.
        num_retries : int
            Number of times to try loading the model.
        """

        model_config_dict = client.get_model_config(model_name, num_retries)

        return ModelConfig.create_from_dictionary(model_config_dict)

    @staticmethod
    def create_from_profile_spec(
        spec: ConfigModelProfileSpec,
        config: ConfigCommandProfile,
        client: TritonClient,
        gpus: List[GPUDevice],
    ) -> "ModelConfig":
        """
        Creates the model config from a ModelProfileSpec, plus assoc. collateral
        """

        model_config_dict = ModelConfig.create_model_config_dict(
            config=config,
            client=client,
            gpus=gpus,
            model_repository=config.model_repository,
            model_name=spec.model_name(),
        )

        model_config = ModelConfig.create_from_dictionary(model_config_dict)

        return model_config

    def is_ensemble(self) -> bool:
        """
        Returns
        -------
        bool
            True if this is an ensemble model
        """

        # Plain attribute access; getattr with a constant name was redundant.
        return self._model_config.platform == "ensemble"

    def get_ensemble_composing_models(self) -> Optional[List[str]]:
        """
        Returns
        -------
        List[str]: Sub-model names

        Raises
        ------
        TritonModelAnalyzerException
            If the model is not an ensemble or has no scheduling steps.
        """

        if not self.is_ensemble():
            raise TritonModelAnalyzerException(
                "Cannot find composing_models. Model platform is not ensemble."
            )

        try:
            composing_models = [
                model["modelName"]
                for model in self.to_dict()["ensembleScheduling"]["step"]
            ]
        except Exception:
            raise TritonModelAnalyzerException(
                "Cannot find composing_models. Ensemble Scheduling and/or step is not present in config protobuf."
            )

        return composing_models

    def set_composing_model_variant_name(
        self, composing_model_name: str, variant_name: str
    ) -> None:
        """
        Replaces the Ensembles composing_model's name with the variant name
        """

        if not self.is_ensemble():
            raise TritonModelAnalyzerException(
                "Cannot find composing_models. Model platform is not ensemble."
            )

        model_config_dict = self.to_dict()

        try:
            for composing_model in model_config_dict["ensembleScheduling"]["step"]:
                if composing_model["modelName"] == composing_model_name:
                    composing_model["modelName"] = variant_name
        except Exception:
            raise TritonModelAnalyzerException(
                "Cannot find composing_models. Ensemble Scheduling and/or step is not present in config protobuf."
            )

        self._model_config = self.from_dict(model_config_dict)._model_config

    def set_model_name(self, model_name: str) -> None:
        """Rename the model by round-tripping the config through a dict."""
        model_config_dict = self.to_dict()
        model_config_dict["name"] = model_name
        self._model_config = self.from_dict(model_config_dict)._model_config

    def write_config_to_file(
        self, model_path, src_model_path, first_variant_model_path
    ):
        """
        Writes a protobuf config file.

        Parameters
        ----------
        model_path : str
            Path to write the model config.

        src_model_path : str
            Path to the source model in the Triton Model Repository

        first_variant_model_path : str
            Indicates the path to the first model variant.

        Raises
        ------
        TritonModelAnalyzerException
            If the path doesn't exist or the path is a file
        """

        if not os.path.exists(model_path):
            raise TritonModelAnalyzerException("Output path specified does not exist.")

        if os.path.isfile(model_path):
            raise TritonModelAnalyzerException("Model output path must be a directory.")

        model_config_bytes = text_format.MessageToBytes(self._model_config)
        # Create current variant model as symlinks to first variant model
        if first_variant_model_path is not None:
            for file in os.listdir(first_variant_model_path):
                # Do not copy the config.pbtxt file
                if file == "config.pbtxt":
                    continue
                os.symlink(
                    os.path.join(
                        os.path.relpath(first_variant_model_path, model_path), file
                    ),
                    os.path.join(model_path, file),
                )
        else:
            # Create first variant model as copy of source model
            copytree(src_model_path, model_path, dirs_exist_ok=True)

        with open(os.path.join(model_path, "config.pbtxt"), "wb") as f:
            f.write(model_config_bytes)

    def get_config(self):
        """
        Get the model config.

        Returns
        -------
        dict
            A dictionary containing the model configuration.
        """

        # preserving_proto_field_name keeps snake_case keys (unlike to_dict).
        return json_format.MessageToDict(
            self._model_config, preserving_proto_field_name=True
        )

    def get_config_str(self):
        """
        Get the model config json str

        Returns
        -------
        str
            A JSON string containing the model configuration.
        """
        return json.dumps(self.get_config())

    def set_config(self, config):
        """
        Set the model config from a dictionary.

        Parameters
        ----------
        config : dict
            The new dictionary containing the model config.
        """

        self._model_config = json_format.ParseDict(
            config, model_config_pb2.ModelConfig()
        )

    def set_field(self, name, value):
        """
        Set a value for a Model Config field.

        Parameters
        ----------
        name : str
            Name of the field
        value : object
            The value to be used for the field.
        """
        model_config = self._model_config

        # Repeated protobuf fields cannot be assigned with setattr; they must
        # be cleared and extended in place.
        if (
            model_config.DESCRIPTOR.fields_by_name[name].label
            == FieldDescriptor.LABEL_REPEATED
        ):
            repeated_field = getattr(model_config, name)
            del repeated_field[:]
            repeated_field.extend(value)
        else:
            setattr(model_config, name, value)

    def get_field(self, name):
        """
        Get the value for the current field.
        """

        return getattr(self._model_config, name)

    def max_batch_size(self) -> int:
        """
        Returns the max batch size (int)
        """

        model_config = self.get_config()
        return model_config.get("max_batch_size", 0)

    def dynamic_batching_string(self) -> str:
        """
        Returns
        -------
        str
            representation of the dynamic batcher
            configuration used to generate this result
        """

        model_config = self.get_config()
        if "dynamic_batching" in model_config:
            return "Enabled"
        else:
            return "Disabled"

    def instance_group_count(self, system_gpu_count: int) -> int:
        """
        Returns:
            int: The total number of instance groups (cpu + gpu)
        """

        kind_to_count = self._get_instance_groups(system_gpu_count)
        # sum() over the values view directly; no intermediate list needed.
        return sum(kind_to_count.values())

    def instance_group_string(self, system_gpu_count: int) -> str:
        """
        Returns
        -------
        str
            representation of the instance group used
            to generate this result

        Format is "GPU:<count> + CPU:<count>"
        """

        kind_to_count = self._get_instance_groups(system_gpu_count)

        ret_str = ""
        for k, v in kind_to_count.items():
            if ret_str != "":
                ret_str += " + "
            ret_str += f"{v}:{k}"
        return ret_str

    def _get_instance_groups(self, system_gpu_count: int) -> Dict[str, int]:
        """
        Returns a dictionary with type of instance (GPU/CPU) and its count
        """
        model_config = self.get_config()

        # TODO change when remote mode is fixed
        default_kind = "GPU" if cuda.is_available() else "CPU"
        default_count = 1

        # A single empty group stands in for "no instance_group specified",
        # so the defaults below are applied exactly once.
        instance_group_list: List[Dict[str, Any]] = [{}]
        if "instance_group" in model_config:
            instance_group_list = model_config["instance_group"]

        kind_to_count: Dict[str, Any] = {}

        for group in instance_group_list:
            group_kind = default_kind
            group_count = default_count
            group_gpus_count = system_gpu_count
            # Update with instance group values
            if "kind" in group:
                # e.g. "KIND_GPU" -> "GPU"
                group_kind = group["kind"].split("_")[1]
            if "count" in group:
                group_count = group["count"]
            if "gpus" in group:
                group_gpus_count = len(group["gpus"])

            # A GPU group runs `count` instances on EACH of its GPUs.
            group_total_count = group_count
            if group_kind == "GPU":
                group_total_count *= group_gpus_count

            if group_kind not in kind_to_count:
                kind_to_count[group_kind] = 0
            kind_to_count[group_kind] += group_total_count

        return kind_to_count
@@ -0,0 +1,29 @@
1
+ # Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+
17
+ from model_analyzer.triton.model.model_config import ModelConfig
18
+
19
+
20
@dataclass
class ModelConfigVariant:
    """
    A dataclass that holds the ModelConfig as well as the variant name
    and cpu_only flag for the model
    """

    # The wrapped Triton model configuration for this variant.
    model_config: ModelConfig
    # Name identifying this particular config variant of the model.
    variant_name: str
    # Whether the variant runs on CPU only; defaults to False.
    cpu_only: bool = False