triton-model-analyzer 1.48.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (204)
  1. model_analyzer/__init__.py +15 -0
  2. model_analyzer/analyzer.py +448 -0
  3. model_analyzer/cli/__init__.py +15 -0
  4. model_analyzer/cli/cli.py +193 -0
  5. model_analyzer/config/__init__.py +15 -0
  6. model_analyzer/config/generate/__init__.py +15 -0
  7. model_analyzer/config/generate/automatic_model_config_generator.py +164 -0
  8. model_analyzer/config/generate/base_model_config_generator.py +352 -0
  9. model_analyzer/config/generate/brute_plus_binary_parameter_search_run_config_generator.py +164 -0
  10. model_analyzer/config/generate/brute_run_config_generator.py +154 -0
  11. model_analyzer/config/generate/concurrency_sweeper.py +75 -0
  12. model_analyzer/config/generate/config_generator_interface.py +52 -0
  13. model_analyzer/config/generate/coordinate.py +143 -0
  14. model_analyzer/config/generate/coordinate_data.py +86 -0
  15. model_analyzer/config/generate/generator_utils.py +116 -0
  16. model_analyzer/config/generate/manual_model_config_generator.py +187 -0
  17. model_analyzer/config/generate/model_config_generator_factory.py +92 -0
  18. model_analyzer/config/generate/model_profile_spec.py +74 -0
  19. model_analyzer/config/generate/model_run_config_generator.py +154 -0
  20. model_analyzer/config/generate/model_variant_name_manager.py +150 -0
  21. model_analyzer/config/generate/neighborhood.py +536 -0
  22. model_analyzer/config/generate/optuna_plus_concurrency_sweep_run_config_generator.py +141 -0
  23. model_analyzer/config/generate/optuna_run_config_generator.py +838 -0
  24. model_analyzer/config/generate/perf_analyzer_config_generator.py +312 -0
  25. model_analyzer/config/generate/quick_plus_concurrency_sweep_run_config_generator.py +130 -0
  26. model_analyzer/config/generate/quick_run_config_generator.py +753 -0
  27. model_analyzer/config/generate/run_config_generator_factory.py +329 -0
  28. model_analyzer/config/generate/search_config.py +112 -0
  29. model_analyzer/config/generate/search_dimension.py +73 -0
  30. model_analyzer/config/generate/search_dimensions.py +85 -0
  31. model_analyzer/config/generate/search_parameter.py +49 -0
  32. model_analyzer/config/generate/search_parameters.py +388 -0
  33. model_analyzer/config/input/__init__.py +15 -0
  34. model_analyzer/config/input/config_command.py +483 -0
  35. model_analyzer/config/input/config_command_profile.py +1747 -0
  36. model_analyzer/config/input/config_command_report.py +267 -0
  37. model_analyzer/config/input/config_defaults.py +236 -0
  38. model_analyzer/config/input/config_enum.py +83 -0
  39. model_analyzer/config/input/config_field.py +216 -0
  40. model_analyzer/config/input/config_list_generic.py +112 -0
  41. model_analyzer/config/input/config_list_numeric.py +151 -0
  42. model_analyzer/config/input/config_list_string.py +111 -0
  43. model_analyzer/config/input/config_none.py +71 -0
  44. model_analyzer/config/input/config_object.py +129 -0
  45. model_analyzer/config/input/config_primitive.py +81 -0
  46. model_analyzer/config/input/config_status.py +75 -0
  47. model_analyzer/config/input/config_sweep.py +83 -0
  48. model_analyzer/config/input/config_union.py +113 -0
  49. model_analyzer/config/input/config_utils.py +128 -0
  50. model_analyzer/config/input/config_value.py +243 -0
  51. model_analyzer/config/input/objects/__init__.py +15 -0
  52. model_analyzer/config/input/objects/config_model_profile_spec.py +325 -0
  53. model_analyzer/config/input/objects/config_model_report_spec.py +173 -0
  54. model_analyzer/config/input/objects/config_plot.py +198 -0
  55. model_analyzer/config/input/objects/config_protobuf_utils.py +101 -0
  56. model_analyzer/config/input/yaml_config_validator.py +82 -0
  57. model_analyzer/config/run/__init__.py +15 -0
  58. model_analyzer/config/run/model_run_config.py +313 -0
  59. model_analyzer/config/run/run_config.py +168 -0
  60. model_analyzer/constants.py +76 -0
  61. model_analyzer/device/__init__.py +15 -0
  62. model_analyzer/device/device.py +24 -0
  63. model_analyzer/device/gpu_device.py +87 -0
  64. model_analyzer/device/gpu_device_factory.py +248 -0
  65. model_analyzer/entrypoint.py +307 -0
  66. model_analyzer/log_formatter.py +65 -0
  67. model_analyzer/model_analyzer_exceptions.py +24 -0
  68. model_analyzer/model_manager.py +255 -0
  69. model_analyzer/monitor/__init__.py +15 -0
  70. model_analyzer/monitor/cpu_monitor.py +69 -0
  71. model_analyzer/monitor/dcgm/DcgmDiag.py +191 -0
  72. model_analyzer/monitor/dcgm/DcgmFieldGroup.py +83 -0
  73. model_analyzer/monitor/dcgm/DcgmGroup.py +815 -0
  74. model_analyzer/monitor/dcgm/DcgmHandle.py +141 -0
  75. model_analyzer/monitor/dcgm/DcgmJsonReader.py +69 -0
  76. model_analyzer/monitor/dcgm/DcgmReader.py +623 -0
  77. model_analyzer/monitor/dcgm/DcgmStatus.py +57 -0
  78. model_analyzer/monitor/dcgm/DcgmSystem.py +412 -0
  79. model_analyzer/monitor/dcgm/__init__.py +15 -0
  80. model_analyzer/monitor/dcgm/common/__init__.py +13 -0
  81. model_analyzer/monitor/dcgm/common/dcgm_client_cli_parser.py +194 -0
  82. model_analyzer/monitor/dcgm/common/dcgm_client_main.py +86 -0
  83. model_analyzer/monitor/dcgm/dcgm_agent.py +887 -0
  84. model_analyzer/monitor/dcgm/dcgm_collectd_plugin.py +369 -0
  85. model_analyzer/monitor/dcgm/dcgm_errors.py +395 -0
  86. model_analyzer/monitor/dcgm/dcgm_field_helpers.py +546 -0
  87. model_analyzer/monitor/dcgm/dcgm_fields.py +815 -0
  88. model_analyzer/monitor/dcgm/dcgm_fields_collectd.py +671 -0
  89. model_analyzer/monitor/dcgm/dcgm_fields_internal.py +29 -0
  90. model_analyzer/monitor/dcgm/dcgm_fluentd.py +45 -0
  91. model_analyzer/monitor/dcgm/dcgm_monitor.py +138 -0
  92. model_analyzer/monitor/dcgm/dcgm_prometheus.py +326 -0
  93. model_analyzer/monitor/dcgm/dcgm_structs.py +2357 -0
  94. model_analyzer/monitor/dcgm/dcgm_telegraf.py +65 -0
  95. model_analyzer/monitor/dcgm/dcgm_value.py +151 -0
  96. model_analyzer/monitor/dcgm/dcgmvalue.py +155 -0
  97. model_analyzer/monitor/dcgm/denylist_recommendations.py +573 -0
  98. model_analyzer/monitor/dcgm/pydcgm.py +47 -0
  99. model_analyzer/monitor/monitor.py +143 -0
  100. model_analyzer/monitor/remote_monitor.py +137 -0
  101. model_analyzer/output/__init__.py +15 -0
  102. model_analyzer/output/file_writer.py +63 -0
  103. model_analyzer/output/output_writer.py +42 -0
  104. model_analyzer/perf_analyzer/__init__.py +15 -0
  105. model_analyzer/perf_analyzer/genai_perf_config.py +206 -0
  106. model_analyzer/perf_analyzer/perf_analyzer.py +882 -0
  107. model_analyzer/perf_analyzer/perf_config.py +479 -0
  108. model_analyzer/plots/__init__.py +15 -0
  109. model_analyzer/plots/detailed_plot.py +266 -0
  110. model_analyzer/plots/plot_manager.py +224 -0
  111. model_analyzer/plots/simple_plot.py +213 -0
  112. model_analyzer/record/__init__.py +15 -0
  113. model_analyzer/record/gpu_record.py +68 -0
  114. model_analyzer/record/metrics_manager.py +887 -0
  115. model_analyzer/record/record.py +280 -0
  116. model_analyzer/record/record_aggregator.py +256 -0
  117. model_analyzer/record/types/__init__.py +15 -0
  118. model_analyzer/record/types/cpu_available_ram.py +93 -0
  119. model_analyzer/record/types/cpu_used_ram.py +93 -0
  120. model_analyzer/record/types/gpu_free_memory.py +96 -0
  121. model_analyzer/record/types/gpu_power_usage.py +107 -0
  122. model_analyzer/record/types/gpu_total_memory.py +96 -0
  123. model_analyzer/record/types/gpu_used_memory.py +96 -0
  124. model_analyzer/record/types/gpu_utilization.py +108 -0
  125. model_analyzer/record/types/inter_token_latency_avg.py +60 -0
  126. model_analyzer/record/types/inter_token_latency_base.py +74 -0
  127. model_analyzer/record/types/inter_token_latency_max.py +60 -0
  128. model_analyzer/record/types/inter_token_latency_min.py +60 -0
  129. model_analyzer/record/types/inter_token_latency_p25.py +60 -0
  130. model_analyzer/record/types/inter_token_latency_p50.py +60 -0
  131. model_analyzer/record/types/inter_token_latency_p75.py +60 -0
  132. model_analyzer/record/types/inter_token_latency_p90.py +60 -0
  133. model_analyzer/record/types/inter_token_latency_p95.py +60 -0
  134. model_analyzer/record/types/inter_token_latency_p99.py +60 -0
  135. model_analyzer/record/types/output_token_throughput.py +105 -0
  136. model_analyzer/record/types/perf_client_response_wait.py +97 -0
  137. model_analyzer/record/types/perf_client_send_recv.py +97 -0
  138. model_analyzer/record/types/perf_latency.py +111 -0
  139. model_analyzer/record/types/perf_latency_avg.py +60 -0
  140. model_analyzer/record/types/perf_latency_base.py +74 -0
  141. model_analyzer/record/types/perf_latency_p90.py +60 -0
  142. model_analyzer/record/types/perf_latency_p95.py +60 -0
  143. model_analyzer/record/types/perf_latency_p99.py +60 -0
  144. model_analyzer/record/types/perf_server_compute_infer.py +97 -0
  145. model_analyzer/record/types/perf_server_compute_input.py +97 -0
  146. model_analyzer/record/types/perf_server_compute_output.py +97 -0
  147. model_analyzer/record/types/perf_server_queue.py +97 -0
  148. model_analyzer/record/types/perf_throughput.py +105 -0
  149. model_analyzer/record/types/time_to_first_token_avg.py +60 -0
  150. model_analyzer/record/types/time_to_first_token_base.py +74 -0
  151. model_analyzer/record/types/time_to_first_token_max.py +60 -0
  152. model_analyzer/record/types/time_to_first_token_min.py +60 -0
  153. model_analyzer/record/types/time_to_first_token_p25.py +60 -0
  154. model_analyzer/record/types/time_to_first_token_p50.py +60 -0
  155. model_analyzer/record/types/time_to_first_token_p75.py +60 -0
  156. model_analyzer/record/types/time_to_first_token_p90.py +60 -0
  157. model_analyzer/record/types/time_to_first_token_p95.py +60 -0
  158. model_analyzer/record/types/time_to_first_token_p99.py +60 -0
  159. model_analyzer/reports/__init__.py +15 -0
  160. model_analyzer/reports/html_report.py +195 -0
  161. model_analyzer/reports/pdf_report.py +50 -0
  162. model_analyzer/reports/report.py +86 -0
  163. model_analyzer/reports/report_factory.py +62 -0
  164. model_analyzer/reports/report_manager.py +1376 -0
  165. model_analyzer/reports/report_utils.py +42 -0
  166. model_analyzer/result/__init__.py +15 -0
  167. model_analyzer/result/constraint_manager.py +150 -0
  168. model_analyzer/result/model_config_measurement.py +354 -0
  169. model_analyzer/result/model_constraints.py +105 -0
  170. model_analyzer/result/parameter_search.py +246 -0
  171. model_analyzer/result/result_manager.py +430 -0
  172. model_analyzer/result/result_statistics.py +159 -0
  173. model_analyzer/result/result_table.py +217 -0
  174. model_analyzer/result/result_table_manager.py +646 -0
  175. model_analyzer/result/result_utils.py +42 -0
  176. model_analyzer/result/results.py +277 -0
  177. model_analyzer/result/run_config_measurement.py +658 -0
  178. model_analyzer/result/run_config_result.py +210 -0
  179. model_analyzer/result/run_config_result_comparator.py +110 -0
  180. model_analyzer/result/sorted_results.py +151 -0
  181. model_analyzer/state/__init__.py +15 -0
  182. model_analyzer/state/analyzer_state.py +76 -0
  183. model_analyzer/state/analyzer_state_manager.py +215 -0
  184. model_analyzer/triton/__init__.py +15 -0
  185. model_analyzer/triton/client/__init__.py +15 -0
  186. model_analyzer/triton/client/client.py +234 -0
  187. model_analyzer/triton/client/client_factory.py +57 -0
  188. model_analyzer/triton/client/grpc_client.py +104 -0
  189. model_analyzer/triton/client/http_client.py +107 -0
  190. model_analyzer/triton/model/__init__.py +15 -0
  191. model_analyzer/triton/model/model_config.py +556 -0
  192. model_analyzer/triton/model/model_config_variant.py +29 -0
  193. model_analyzer/triton/server/__init__.py +15 -0
  194. model_analyzer/triton/server/server.py +76 -0
  195. model_analyzer/triton/server/server_config.py +269 -0
  196. model_analyzer/triton/server/server_docker.py +229 -0
  197. model_analyzer/triton/server/server_factory.py +306 -0
  198. model_analyzer/triton/server/server_local.py +158 -0
  199. triton_model_analyzer-1.48.0.dist-info/METADATA +52 -0
  200. triton_model_analyzer-1.48.0.dist-info/RECORD +204 -0
  201. triton_model_analyzer-1.48.0.dist-info/WHEEL +5 -0
  202. triton_model_analyzer-1.48.0.dist-info/entry_points.txt +2 -0
  203. triton_model_analyzer-1.48.0.dist-info/licenses/LICENSE +67 -0
  204. triton_model_analyzer-1.48.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,154 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from typing import Dict, Generator, List, Optional
18
+
19
+ from model_analyzer.config.generate.model_profile_spec import ModelProfileSpec
20
+ from model_analyzer.config.generate.model_run_config_generator import (
21
+ ModelRunConfigGenerator,
22
+ )
23
+ from model_analyzer.config.generate.model_variant_name_manager import (
24
+ ModelVariantNameManager,
25
+ )
26
+ from model_analyzer.config.input.config_command_profile import ConfigCommandProfile
27
+ from model_analyzer.config.run.model_run_config import ModelRunConfig
28
+ from model_analyzer.config.run.run_config import RunConfig
29
+ from model_analyzer.device.gpu_device import GPUDevice
30
+ from model_analyzer.model_analyzer_exceptions import TritonModelAnalyzerException
31
+ from model_analyzer.result.run_config_measurement import RunConfigMeasurement
32
+ from model_analyzer.triton.client.client import TritonClient
33
+
34
+ from .config_generator_interface import ConfigGeneratorInterface
35
+
36
+
37
class BruteRunConfigGenerator(ConfigGeneratorInterface):
    """
    Generates all RunConfigs to execute via brute force given a list of models

    Every combination of the per-model ModelRunConfigs is produced: the
    generator for models[0] drives the outermost loop and the generator for
    the last model drives the innermost one.
    """

    def __init__(
        self,
        config: ConfigCommandProfile,
        gpus: List[GPUDevice],
        models: List[ModelProfileSpec],
        client: TritonClient,
        model_variant_name_manager: ModelVariantNameManager,
        skip_default_config: bool = False,
    ):
        """
        Parameters
        ----------
        config: ConfigCommandProfile
            The profile command configuration

        gpus: List of GPUDevices

        models: List of ModelProfileSpec
            The models to generate ModelRunConfigs for. All models must share
            the same triton server environment.

        client: TritonClient

        model_variant_name_manager: ModelVariantNameManager

        skip_default_config: bool
            If False (the default), a pass with default_only=True is run
            before the main pass

        Raises
        ------
        TritonModelAnalyzerException
            If models is empty or the models' triton server environments differ
        """
        self._config = config
        self._gpus = gpus
        self._models = models
        self._client = client
        self._model_variant_name_manager = model_variant_name_manager

        self._triton_env = BruteRunConfigGenerator.determine_triton_server_env(models)

        self._num_models = len(models)

        # The ModelRunConfig currently selected for each model, filled in as
        # _generate_subset walks the nested per-model generators
        self._curr_model_run_configs: List[Optional[ModelRunConfig]] = [
            None
        ] * self._num_models
        # Measurements collected for each model since its generator was last
        # given results (one independent list per model)
        self._curr_results: List[List[Optional[RunConfigMeasurement]]] = [
            [] for _ in range(self._num_models)
        ]
        self._curr_generators: Dict[int, ConfigGeneratorInterface] = {}

        self._skip_default_config = skip_default_config

    def set_last_results(
        self, measurements: List[Optional[RunConfigMeasurement]]
    ) -> None:
        """
        Append measurements to the pending results of every model. The pending
        results are forwarded to each model's generator by
        _send_results_to_generator.
        """
        for pending_results in self._curr_results:
            pending_results.extend(measurements)

    def get_configs(self) -> Generator[RunConfig, None, None]:
        """
        Returns
        -------
        RunConfig
            The next RunConfig generated by this class
        """

        yield from self._get_next_config()

    def _get_next_config(self) -> Generator[RunConfig, None, None]:
        # Unless skipped, first run the default_only=True pass, then the
        # default_only=False pass
        if not self._skip_default_config:
            yield from self._generate_subset(0, default_only=True)

        yield from self._generate_subset(0, default_only=False)

    def _generate_subset(
        self, index: int, default_only: bool
    ) -> Generator[RunConfig, None, None]:
        """
        Recursively yield a RunConfig for every combination of the
        ModelRunConfigs of models[index:], keeping the current selections
        for models[:index] fixed.

        After each ModelRunConfig of models[index] has been combined with all
        later models, the results accumulated for models[index] are sent back
        to its generator (and cleared) before that generator is resumed.
        """
        mrcg = ModelRunConfigGenerator(
            self._config,
            self._gpus,
            self._models[index],
            self._client,
            self._model_variant_name_manager,
            default_only,
        )

        self._curr_generators[index] = mrcg

        for model_run_config in mrcg.get_configs():
            self._curr_model_run_configs[index] = model_run_config

            if index == (self._num_models - 1):
                yield self._make_run_config()
            else:
                yield from self._generate_subset(index + 1, default_only)

            self._send_results_to_generator(index)

    def _make_run_config(self) -> RunConfig:
        """Combine the currently selected ModelRunConfigs into one RunConfig."""
        run_config = RunConfig(self._triton_env, self._models[0].genai_perf_flags())
        for model_run_config in self._curr_model_run_configs:
            run_config.add_model_run_config(model_run_config)
        return run_config

    def _send_results_to_generator(self, index: int) -> None:
        """Hand the pending results of models[index] to its generator and reset them."""
        self._curr_generators[index].set_last_results(self._curr_results[index])
        self._curr_results[index] = []

    @classmethod
    def determine_triton_server_env(cls, models: List[ModelProfileSpec]) -> Dict:
        """
        Given a list of models, return the triton environment shared by all of them

        Raises
        ------
        TritonModelAnalyzerException
            If no models are given, or if the models do not all share the same
            triton server environment
        """
        if not models:
            raise TritonModelAnalyzerException(
                "Cannot determine the triton server environment: no models were provided"
            )

        triton_env = models[0].triton_server_environment()

        # models[0] trivially matches itself, so only the rest need checking
        for model in models[1:]:
            if model.triton_server_environment() != triton_env:
                raise TritonModelAnalyzerException(
                    "Mismatching triton server environments. The triton server environment must be the same for all models when run concurrently"
                )

        return triton_env
@@ -0,0 +1,75 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import logging
18
+ from copy import deepcopy
19
+ from typing import Generator, List, Optional
20
+
21
+ from model_analyzer.config.input.config_command_profile import ConfigCommandProfile
22
+ from model_analyzer.config.run.run_config import RunConfig
23
+ from model_analyzer.constants import LOGGER_NAME
24
+ from model_analyzer.result.parameter_search import ParameterSearch
25
+ from model_analyzer.result.result_manager import ResultManager
26
+ from model_analyzer.result.run_config_measurement import RunConfigMeasurement
27
+
28
# Logger for this module, looked up by the shared LOGGER_NAME constant
logger = logging.getLogger(name=LOGGER_NAME)
29
+
30
+
31
class ConcurrencySweeper:
    """
    Sweeps concurrency for the top-N model configs
    """

    def __init__(
        self,
        config: ConfigCommandProfile,
        result_manager: ResultManager,
    ):
        """
        Parameters
        ----------
        config: ConfigCommandProfile
            Supplies num_configs_per_model and the ParameterSearch settings
        result_manager: ResultManager
            Source of the model names and of the top-N results to sweep
        """
        self._config = config
        self._result_manager = result_manager
        self._last_measurement: Optional[RunConfigMeasurement] = None

    def set_last_results(
        self, measurements: List[Optional[RunConfigMeasurement]]
    ) -> None:
        """
        Record the most recent measurement. It is passed to the active
        ParameterSearch when get_configs resumes. An empty list is recorded
        as no measurement (None) instead of raising IndexError.
        """
        self._last_measurement = measurements[-1] if measurements else None

    def get_configs(self) -> Generator[RunConfig, None, None]:
        """
        A generator which creates RunConfigs based on sweeping
        concurrency over the top-N models

        Each yielded RunConfig is an independent deep copy, so a consumer that
        keeps a reference to an earlier config will not see it modified by
        later iterations. set_last_results() is expected to be called after
        each yielded config has been measured.
        """
        for model_name in self._result_manager.get_model_names():
            top_results = self._result_manager.top_n_results(
                model_name=model_name,
                n=self._config.num_configs_per_model,
                include_default=True,
            )

            for result in top_results:
                base_run_config = result.run_config()
                parameter_search = ParameterSearch(self._config)
                for concurrency in parameter_search.search_parameters():
                    yield self._create_run_config(base_run_config, concurrency)
                    parameter_search.add_run_config_measurement(self._last_measurement)

    def _create_run_config(self, run_config: RunConfig, concurrency: int) -> RunConfig:
        """
        Return a deep copy of run_config in which every model's perf config has
        "concurrency-range" set to concurrency. The input is left unmodified.
        """
        new_run_config = deepcopy(run_config)
        for model_run_config in new_run_config.model_run_configs():
            model_run_config.perf_config().update_config(
                {"concurrency-range": concurrency}
            )

        return new_run_config
@@ -0,0 +1,52 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import abc
18
+ from typing import Any, Generator, List, Optional
19
+
20
+ from model_analyzer.result.run_config_measurement import RunConfigMeasurement
21
+
22
+
23
class ConfigGeneratorInterface(abc.ABC):
    """
    An interface class for config generators

    A generator yields configs from get_configs() and is given measurements
    through set_last_results(). Any class exposing callable __init__,
    get_configs and set_last_results attributes is accepted as a virtual
    subclass by issubclass()/isinstance().
    """

    @classmethod
    def __subclasshook__(cls, subclass: Any) -> bool:
        required_methods = ("__init__", "get_configs", "set_last_results")
        # A missing attribute yields None, which is not callable
        if all(callable(getattr(subclass, name, None)) for name in required_methods):
            return True
        # Defer to the normal ABC subclass check
        return NotImplemented

    @abc.abstractmethod
    def __init__(self) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    def get_configs(self) -> Generator[Any, None, None]:
        """Yield the configs produced by this generator, one at a time."""
        raise NotImplementedError

    @abc.abstractmethod
    def set_last_results(
        self, measurements: List[Optional[RunConfigMeasurement]]
    ) -> None:
        """Provide the measurement(s) for previously yielded config(s)."""
        raise NotImplementedError
@@ -0,0 +1,143 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from copy import deepcopy
18
+ from functools import total_ordering
19
+ from typing import Any, Iterator, List, Union
20
+
21
+
22
@total_ordering
class Coordinate:
    """
    Class to define a coordinate in n-dimension space

    A Coordinate compares element-wise against any iterable of numbers.
    Equality requires equal length, and ordering is lexicographic, exactly
    like Python lists. Because it defines __eq__ without __hash__, a
    Coordinate is unhashable; use tuple(coordinate) as a dictionary key.
    """

    def __init__(self, val: Union["Coordinate", List[int]]):
        """
        val: Coordinate or list
            List of floats or integers corresponding to the location in space.
            The values are copied, so later changes to val do not affect this
            Coordinate (and vice versa).
        """
        if isinstance(val, Coordinate):
            val = val._values

        self._values: List[int] = deepcopy(val)
        # Cursor used only by the legacy __next__ protocol
        self._idx = 0

    def __getitem__(self, idx: int) -> int:
        return self._values[idx]

    def __setitem__(self, idx: int, item: int) -> None:
        self._values[idx] = item

    def __len__(self) -> int:
        return len(self._values)

    def __add__(self, other: Any) -> "Coordinate":
        if isinstance(other, Coordinate):
            return self._add_coordinate(other)
        elif isinstance(other, (int, float)):
            return self._add_number(other)
        else:
            raise TypeError("Unhandled addition type")

    def __sub__(self, other: Any) -> "Coordinate":
        if isinstance(other, Coordinate):
            return self._sub_coordinate(other)
        elif isinstance(other, (int, float)):
            return self._sub_number(other)
        else:
            raise TypeError("Unhandled subtraction type")

    def __truediv__(self, other: Any) -> "Coordinate":
        if isinstance(other, (int, float)):
            return self._div_number(other)
        else:
            raise TypeError("Unhandled division type")

    def __mul__(self, other: Any) -> "Coordinate":
        if isinstance(other, (int, float)):
            return self._mul_number(other)
        else:
            raise TypeError("Unhandled mul type")

    def __eq__(self, other: Any) -> bool:
        other_values = self._as_list(other)
        if other_values is None:
            return NotImplemented
        # List equality also checks that the lengths match
        return self._values == other_values

    def __lt__(self, other: Any) -> bool:
        other_values = self._as_list(other)
        if other_values is None:
            return NotImplemented
        # Lexicographic: first differing element decides, shorter prefix is smaller
        return self._values < other_values

    @staticmethod
    def _as_list(other: Any) -> Union[List[Any], None]:
        """Return the values of other as a list, or None if other is not iterable."""
        try:
            return list(other)
        except TypeError:
            return None

    def round(self) -> None:
        """Rounds the coordinate in-place"""
        self._values[:] = [round(v) for v in self._values]

    def _add_coordinate(self, other: Any) -> "Coordinate":
        return Coordinate([v + other[i] for i, v in enumerate(self._values)])

    def _add_number(self, other: Any) -> "Coordinate":
        return Coordinate([v + other for v in self._values])

    def _sub_coordinate(self, other: Any) -> "Coordinate":
        return Coordinate([v - other[i] for i, v in enumerate(self._values)])

    def _sub_number(self, other: Any) -> "Coordinate":
        return Coordinate([v - other for v in self._values])

    def _mul_number(self, other: Any) -> "Coordinate":
        return Coordinate([v * other for v in self._values])

    def _div_number(self, other: Any) -> "Coordinate":
        return Coordinate([v / other for v in self._values])

    def __iter__(self) -> Iterator:
        # Hand out a fresh iterator each time so nested or simultaneous
        # iteration over the same Coordinate (e.g. zip(c, c)) works
        return iter(self._values)

    def __next__(self) -> int:
        """Legacy single-cursor iteration; prefer iter(coordinate)."""
        if self._idx < len(self._values):
            val = self._values[self._idx]
            self._idx += 1
            return val
        raise StopIteration

    def __str__(self) -> str:
        return str(self._values)

    def __repr__(self) -> str:
        return f"Coordinate({self._values!r})"
@@ -0,0 +1,86 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from typing import Dict, Optional, Tuple
18
+
19
+ from model_analyzer.config.generate.coordinate import Coordinate
20
+ from model_analyzer.result.run_config_measurement import RunConfigMeasurement
21
+
22
# Hashable dictionary key for a coordinate: the tuple of its values, as
# produced by tuple(coordinate)
CoordinateKey = Tuple[int, ...]
23
+
24
+
25
class CoordinateData:
    """
    A class that tracks the measurement data in the current neighborhood
    and the visit counts of all the coordinates in the coordinate space.

    It also remembers every coordinate that has ever had a measurement set.
    Coordinates are keyed by tuple(coordinate), so any iterable with the same
    values refers to the same entry.
    """

    def __init__(self) -> None:
        self._measurements: Dict[CoordinateKey, Optional[RunConfigMeasurement]] = {}
        self._visit_counts: Dict[CoordinateKey, int] = {}
        self._is_measured: Dict[CoordinateKey, bool] = {}

    @staticmethod
    def _key(coordinate: Coordinate) -> CoordinateKey:
        """Convert a coordinate into the hashable key used by the internal dicts."""
        return tuple(coordinate)

    def get_measurement(self, coordinate: Coordinate) -> Optional[RunConfigMeasurement]:
        """
        Return the measurement stored for the coordinate. None is returned
        when nothing is stored or when the stored measurement is None.
        """
        return self._measurements.get(self._key(coordinate))

    def set_measurement(
        self, coordinate: Coordinate, measurement: Optional[RunConfigMeasurement]
    ) -> None:
        """
        Store the measurement for the coordinate and mark it as measured.
        A None measurement still marks the coordinate as measured.
        """
        key = self._key(coordinate)
        self._is_measured[key] = True
        self._measurements[key] = measurement

    def is_measured(self, coordinate: Coordinate) -> bool:
        """
        True once set_measurement() has been called for the coordinate, even
        if the measurement was None or has since been cleared by
        reset_measurements().
        """
        return self._is_measured.get(self._key(coordinate), False)

    def has_valid_measurement(self, coordinate: Coordinate) -> bool:
        """True if a non-None measurement is currently stored for the coordinate."""
        measurement = self.get_measurement(coordinate)
        return measurement is not None

    def reset_measurements(self) -> None:
        """
        Discard all stored measurements.

        Only the measurement values are dropped; is_measured() flags and
        visit counts are kept.
        """
        self._measurements = {}

    def get_visit_count(self, coordinate: Coordinate) -> int:
        """Number of times the coordinate was visited (0 if never)."""
        return self._visit_counts.get(self._key(coordinate), 0)

    def increment_visit_count(self, coordinate: Coordinate) -> None:
        """Add one to the visit count of the coordinate."""
        key = self._key(coordinate)
        self._visit_counts[key] = self.get_visit_count(coordinate) + 1
@@ -0,0 +1,116 @@
1
+ #!/usr/bin/env python3
2
+
3
+ # Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from itertools import product
18
+ from typing import Dict, List
19
+
20
+
21
class GeneratorUtils:
    """Class for utility functions for Generators"""

    @staticmethod
    def generate_combinations(value: object) -> List:
        """
        Generates all the alternative fields for
        a given value.

        Parameters
        ----------
        value : object
            The value to be used for sweeping.

        Returns
        -------
        list
            A list of all the alternatives for the parameters.
        """

        if type(value) is dict:
            swept_dict = {}
            for key, sweep_choices in value.items():
                # This is the list of sweep parameters. When parsing a
                # config every sweepable parameter will be converted
                # to a list of values to make the parameter sweeping easier in
                # here.
                sweep_parameter_list = []
                for sweep_choice in sweep_choices:
                    sweep_parameter_list += GeneratorUtils.generate_combinations(
                        sweep_choice
                    )

                swept_dict[key] = sweep_parameter_list

            # Generate parameter combinations for this field.
            return GeneratorUtils.generate_parameter_combinations(swept_dict)

        # When this line of code is executed the value for this field is
        # a list. This list does NOT represent possible sweep values.
        # Because of this we need to ensure that in every sweep configuration,
        # one item from every list item exists.
        elif type(value) is list:
            # The return value from this branch is a list of lists where each
            # inner list holds one alternative for every item of value.
            per_item_alternatives = [
                GeneratorUtils.generate_combinations(item) for item in value
            ]

            # Cartesian product of the alternatives of every item
            return [list(combination) for combination in product(*per_item_alternatives)]

        # In the default case return a list of the value. This function should
        # always return a list.
        return [value]

    @staticmethod
    def generate_parameter_combinations(params: Dict) -> List[Dict]:
        """
        Generate a list of all possible subdictionaries
        from given dictionary. The subdictionaries will
        have all the same keys, but only one value from
        each key.

        Parameters
        ----------
        params : dict
            keys are strings and the values must be lists

        Returns
        -------
        list of dict
            One dict per element of the cartesian product of the values.
            An empty params dict yields [{}].
        """

        keys = list(params.keys())
        return [dict(zip(keys, values)) for values in product(*params.values())]

    @staticmethod
    def generate_doubled_list(min_value: int, max_value: int) -> List[int]:
        """
        Generates a list of values from min_value -> max_value doubling
        min_value for each entry

        Parameters
        ----------
        min_value: int
            The minimum value for the generated list. A value of 0 is
            treated as 1, since doubling 0 would never grow.
        max_value : int
            The value that the generated list will not exceed

        Returns
        -------
        list of int
            Empty if the starting value already exceeds max_value.

        Raises
        ------
        ValueError
            If min_value is negative (doubling a negative number never
            exceeds max_value, which previously looped forever)
        """

        if min_value < 0:
            raise ValueError(f"min_value must be non-negative, got {min_value}")

        doubled_values = []
        value = 1 if min_value == 0 else min_value
        while value <= max_value:
            doubled_values.append(value)
            value *= 2
        return doubled_values