triton_model_analyzer-1.48.0-py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- model_analyzer/__init__.py +15 -0
- model_analyzer/analyzer.py +448 -0
- model_analyzer/cli/__init__.py +15 -0
- model_analyzer/cli/cli.py +193 -0
- model_analyzer/config/__init__.py +15 -0
- model_analyzer/config/generate/__init__.py +15 -0
- model_analyzer/config/generate/automatic_model_config_generator.py +164 -0
- model_analyzer/config/generate/base_model_config_generator.py +352 -0
- model_analyzer/config/generate/brute_plus_binary_parameter_search_run_config_generator.py +164 -0
- model_analyzer/config/generate/brute_run_config_generator.py +154 -0
- model_analyzer/config/generate/concurrency_sweeper.py +75 -0
- model_analyzer/config/generate/config_generator_interface.py +52 -0
- model_analyzer/config/generate/coordinate.py +143 -0
- model_analyzer/config/generate/coordinate_data.py +86 -0
- model_analyzer/config/generate/generator_utils.py +116 -0
- model_analyzer/config/generate/manual_model_config_generator.py +187 -0
- model_analyzer/config/generate/model_config_generator_factory.py +92 -0
- model_analyzer/config/generate/model_profile_spec.py +74 -0
- model_analyzer/config/generate/model_run_config_generator.py +154 -0
- model_analyzer/config/generate/model_variant_name_manager.py +150 -0
- model_analyzer/config/generate/neighborhood.py +536 -0
- model_analyzer/config/generate/optuna_plus_concurrency_sweep_run_config_generator.py +141 -0
- model_analyzer/config/generate/optuna_run_config_generator.py +838 -0
- model_analyzer/config/generate/perf_analyzer_config_generator.py +312 -0
- model_analyzer/config/generate/quick_plus_concurrency_sweep_run_config_generator.py +130 -0
- model_analyzer/config/generate/quick_run_config_generator.py +753 -0
- model_analyzer/config/generate/run_config_generator_factory.py +329 -0
- model_analyzer/config/generate/search_config.py +112 -0
- model_analyzer/config/generate/search_dimension.py +73 -0
- model_analyzer/config/generate/search_dimensions.py +85 -0
- model_analyzer/config/generate/search_parameter.py +49 -0
- model_analyzer/config/generate/search_parameters.py +388 -0
- model_analyzer/config/input/__init__.py +15 -0
- model_analyzer/config/input/config_command.py +483 -0
- model_analyzer/config/input/config_command_profile.py +1747 -0
- model_analyzer/config/input/config_command_report.py +267 -0
- model_analyzer/config/input/config_defaults.py +236 -0
- model_analyzer/config/input/config_enum.py +83 -0
- model_analyzer/config/input/config_field.py +216 -0
- model_analyzer/config/input/config_list_generic.py +112 -0
- model_analyzer/config/input/config_list_numeric.py +151 -0
- model_analyzer/config/input/config_list_string.py +111 -0
- model_analyzer/config/input/config_none.py +71 -0
- model_analyzer/config/input/config_object.py +129 -0
- model_analyzer/config/input/config_primitive.py +81 -0
- model_analyzer/config/input/config_status.py +75 -0
- model_analyzer/config/input/config_sweep.py +83 -0
- model_analyzer/config/input/config_union.py +113 -0
- model_analyzer/config/input/config_utils.py +128 -0
- model_analyzer/config/input/config_value.py +243 -0
- model_analyzer/config/input/objects/__init__.py +15 -0
- model_analyzer/config/input/objects/config_model_profile_spec.py +325 -0
- model_analyzer/config/input/objects/config_model_report_spec.py +173 -0
- model_analyzer/config/input/objects/config_plot.py +198 -0
- model_analyzer/config/input/objects/config_protobuf_utils.py +101 -0
- model_analyzer/config/input/yaml_config_validator.py +82 -0
- model_analyzer/config/run/__init__.py +15 -0
- model_analyzer/config/run/model_run_config.py +313 -0
- model_analyzer/config/run/run_config.py +168 -0
- model_analyzer/constants.py +76 -0
- model_analyzer/device/__init__.py +15 -0
- model_analyzer/device/device.py +24 -0
- model_analyzer/device/gpu_device.py +87 -0
- model_analyzer/device/gpu_device_factory.py +248 -0
- model_analyzer/entrypoint.py +307 -0
- model_analyzer/log_formatter.py +65 -0
- model_analyzer/model_analyzer_exceptions.py +24 -0
- model_analyzer/model_manager.py +255 -0
- model_analyzer/monitor/__init__.py +15 -0
- model_analyzer/monitor/cpu_monitor.py +69 -0
- model_analyzer/monitor/dcgm/DcgmDiag.py +191 -0
- model_analyzer/monitor/dcgm/DcgmFieldGroup.py +83 -0
- model_analyzer/monitor/dcgm/DcgmGroup.py +815 -0
- model_analyzer/monitor/dcgm/DcgmHandle.py +141 -0
- model_analyzer/monitor/dcgm/DcgmJsonReader.py +69 -0
- model_analyzer/monitor/dcgm/DcgmReader.py +623 -0
- model_analyzer/monitor/dcgm/DcgmStatus.py +57 -0
- model_analyzer/monitor/dcgm/DcgmSystem.py +412 -0
- model_analyzer/monitor/dcgm/__init__.py +15 -0
- model_analyzer/monitor/dcgm/common/__init__.py +13 -0
- model_analyzer/monitor/dcgm/common/dcgm_client_cli_parser.py +194 -0
- model_analyzer/monitor/dcgm/common/dcgm_client_main.py +86 -0
- model_analyzer/monitor/dcgm/dcgm_agent.py +887 -0
- model_analyzer/monitor/dcgm/dcgm_collectd_plugin.py +369 -0
- model_analyzer/monitor/dcgm/dcgm_errors.py +395 -0
- model_analyzer/monitor/dcgm/dcgm_field_helpers.py +546 -0
- model_analyzer/monitor/dcgm/dcgm_fields.py +815 -0
- model_analyzer/monitor/dcgm/dcgm_fields_collectd.py +671 -0
- model_analyzer/monitor/dcgm/dcgm_fields_internal.py +29 -0
- model_analyzer/monitor/dcgm/dcgm_fluentd.py +45 -0
- model_analyzer/monitor/dcgm/dcgm_monitor.py +138 -0
- model_analyzer/monitor/dcgm/dcgm_prometheus.py +326 -0
- model_analyzer/monitor/dcgm/dcgm_structs.py +2357 -0
- model_analyzer/monitor/dcgm/dcgm_telegraf.py +65 -0
- model_analyzer/monitor/dcgm/dcgm_value.py +151 -0
- model_analyzer/monitor/dcgm/dcgmvalue.py +155 -0
- model_analyzer/monitor/dcgm/denylist_recommendations.py +573 -0
- model_analyzer/monitor/dcgm/pydcgm.py +47 -0
- model_analyzer/monitor/monitor.py +143 -0
- model_analyzer/monitor/remote_monitor.py +137 -0
- model_analyzer/output/__init__.py +15 -0
- model_analyzer/output/file_writer.py +63 -0
- model_analyzer/output/output_writer.py +42 -0
- model_analyzer/perf_analyzer/__init__.py +15 -0
- model_analyzer/perf_analyzer/genai_perf_config.py +206 -0
- model_analyzer/perf_analyzer/perf_analyzer.py +882 -0
- model_analyzer/perf_analyzer/perf_config.py +479 -0
- model_analyzer/plots/__init__.py +15 -0
- model_analyzer/plots/detailed_plot.py +266 -0
- model_analyzer/plots/plot_manager.py +224 -0
- model_analyzer/plots/simple_plot.py +213 -0
- model_analyzer/record/__init__.py +15 -0
- model_analyzer/record/gpu_record.py +68 -0
- model_analyzer/record/metrics_manager.py +887 -0
- model_analyzer/record/record.py +280 -0
- model_analyzer/record/record_aggregator.py +256 -0
- model_analyzer/record/types/__init__.py +15 -0
- model_analyzer/record/types/cpu_available_ram.py +93 -0
- model_analyzer/record/types/cpu_used_ram.py +93 -0
- model_analyzer/record/types/gpu_free_memory.py +96 -0
- model_analyzer/record/types/gpu_power_usage.py +107 -0
- model_analyzer/record/types/gpu_total_memory.py +96 -0
- model_analyzer/record/types/gpu_used_memory.py +96 -0
- model_analyzer/record/types/gpu_utilization.py +108 -0
- model_analyzer/record/types/inter_token_latency_avg.py +60 -0
- model_analyzer/record/types/inter_token_latency_base.py +74 -0
- model_analyzer/record/types/inter_token_latency_max.py +60 -0
- model_analyzer/record/types/inter_token_latency_min.py +60 -0
- model_analyzer/record/types/inter_token_latency_p25.py +60 -0
- model_analyzer/record/types/inter_token_latency_p50.py +60 -0
- model_analyzer/record/types/inter_token_latency_p75.py +60 -0
- model_analyzer/record/types/inter_token_latency_p90.py +60 -0
- model_analyzer/record/types/inter_token_latency_p95.py +60 -0
- model_analyzer/record/types/inter_token_latency_p99.py +60 -0
- model_analyzer/record/types/output_token_throughput.py +105 -0
- model_analyzer/record/types/perf_client_response_wait.py +97 -0
- model_analyzer/record/types/perf_client_send_recv.py +97 -0
- model_analyzer/record/types/perf_latency.py +111 -0
- model_analyzer/record/types/perf_latency_avg.py +60 -0
- model_analyzer/record/types/perf_latency_base.py +74 -0
- model_analyzer/record/types/perf_latency_p90.py +60 -0
- model_analyzer/record/types/perf_latency_p95.py +60 -0
- model_analyzer/record/types/perf_latency_p99.py +60 -0
- model_analyzer/record/types/perf_server_compute_infer.py +97 -0
- model_analyzer/record/types/perf_server_compute_input.py +97 -0
- model_analyzer/record/types/perf_server_compute_output.py +97 -0
- model_analyzer/record/types/perf_server_queue.py +97 -0
- model_analyzer/record/types/perf_throughput.py +105 -0
- model_analyzer/record/types/time_to_first_token_avg.py +60 -0
- model_analyzer/record/types/time_to_first_token_base.py +74 -0
- model_analyzer/record/types/time_to_first_token_max.py +60 -0
- model_analyzer/record/types/time_to_first_token_min.py +60 -0
- model_analyzer/record/types/time_to_first_token_p25.py +60 -0
- model_analyzer/record/types/time_to_first_token_p50.py +60 -0
- model_analyzer/record/types/time_to_first_token_p75.py +60 -0
- model_analyzer/record/types/time_to_first_token_p90.py +60 -0
- model_analyzer/record/types/time_to_first_token_p95.py +60 -0
- model_analyzer/record/types/time_to_first_token_p99.py +60 -0
- model_analyzer/reports/__init__.py +15 -0
- model_analyzer/reports/html_report.py +195 -0
- model_analyzer/reports/pdf_report.py +50 -0
- model_analyzer/reports/report.py +86 -0
- model_analyzer/reports/report_factory.py +62 -0
- model_analyzer/reports/report_manager.py +1376 -0
- model_analyzer/reports/report_utils.py +42 -0
- model_analyzer/result/__init__.py +15 -0
- model_analyzer/result/constraint_manager.py +150 -0
- model_analyzer/result/model_config_measurement.py +354 -0
- model_analyzer/result/model_constraints.py +105 -0
- model_analyzer/result/parameter_search.py +246 -0
- model_analyzer/result/result_manager.py +430 -0
- model_analyzer/result/result_statistics.py +159 -0
- model_analyzer/result/result_table.py +217 -0
- model_analyzer/result/result_table_manager.py +646 -0
- model_analyzer/result/result_utils.py +42 -0
- model_analyzer/result/results.py +277 -0
- model_analyzer/result/run_config_measurement.py +658 -0
- model_analyzer/result/run_config_result.py +210 -0
- model_analyzer/result/run_config_result_comparator.py +110 -0
- model_analyzer/result/sorted_results.py +151 -0
- model_analyzer/state/__init__.py +15 -0
- model_analyzer/state/analyzer_state.py +76 -0
- model_analyzer/state/analyzer_state_manager.py +215 -0
- model_analyzer/triton/__init__.py +15 -0
- model_analyzer/triton/client/__init__.py +15 -0
- model_analyzer/triton/client/client.py +234 -0
- model_analyzer/triton/client/client_factory.py +57 -0
- model_analyzer/triton/client/grpc_client.py +104 -0
- model_analyzer/triton/client/http_client.py +107 -0
- model_analyzer/triton/model/__init__.py +15 -0
- model_analyzer/triton/model/model_config.py +556 -0
- model_analyzer/triton/model/model_config_variant.py +29 -0
- model_analyzer/triton/server/__init__.py +15 -0
- model_analyzer/triton/server/server.py +76 -0
- model_analyzer/triton/server/server_config.py +269 -0
- model_analyzer/triton/server/server_docker.py +229 -0
- model_analyzer/triton/server/server_factory.py +306 -0
- model_analyzer/triton/server/server_local.py +158 -0
- triton_model_analyzer-1.48.0.dist-info/METADATA +52 -0
- triton_model_analyzer-1.48.0.dist-info/RECORD +204 -0
- triton_model_analyzer-1.48.0.dist-info/WHEEL +5 -0
- triton_model_analyzer-1.48.0.dist-info/entry_points.txt +2 -0
- triton_model_analyzer-1.48.0.dist-info/licenses/LICENSE +67 -0
- triton_model_analyzer-1.48.0.dist-info/top_level.txt +1 -0
model_analyzer/config/generate/__init__.py
@@ -0,0 +1,15 @@
+#!/usr/bin/env python3
+
+# Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
model_analyzer/config/generate/automatic_model_config_generator.py
@@ -0,0 +1,164 @@
+#!/usr/bin/env python3
+
+# Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Any, Dict, List
+
+from model_analyzer.config.generate.model_variant_name_manager import (
+    ModelVariantNameManager,
+)
+from model_analyzer.config.input.config_command_profile import ConfigCommandProfile
+from model_analyzer.constants import DEFAULT_CONFIG_PARAMS, LOGGER_NAME
+from model_analyzer.device.gpu_device import GPUDevice
+from model_analyzer.model_analyzer_exceptions import TritonModelAnalyzerException
+from model_analyzer.triton.client.client import TritonClient
+from model_analyzer.triton.model.model_config_variant import ModelConfigVariant
+
+from .base_model_config_generator import BaseModelConfigGenerator
+from .model_profile_spec import ModelProfileSpec
+
+logger = logging.getLogger(LOGGER_NAME)
+
+
+class AutomaticModelConfigGenerator(BaseModelConfigGenerator):
+    """Given a model, generates model configs in automatic search mode"""
+
+    _log_first_run = False
+
+    def __init__(
+        self,
+        config: ConfigCommandProfile,
+        gpus: List[GPUDevice],
+        model: ModelProfileSpec,
+        client: TritonClient,
+        model_variant_name_manager: ModelVariantNameManager,
+        default_only: bool,
+        early_exit_enable: bool,
+    ) -> None:
+        """
+        Parameters
+        ----------
+        config: ConfigCommandProfile
+        gpus: List of GPUDevices
+        model: ModelProfileSpec
+            The model to generate ModelConfigs for
+        client: TritonClient
+        model_variant_name_manager: ModelVariantNameManager
+        default_only: Bool
+            If true, only the default config will be generated
+            If false, the default config will NOT be generated
+        early_exit_enable: Bool
+            If true, the generator can early exit if throughput plateaus
+        """
+        super().__init__(
+            config,
+            gpus,
+            model,
+            client,
+            model_variant_name_manager,
+            default_only,
+            early_exit_enable,
+        )
+
+        if not AutomaticModelConfigGenerator._log_first_run:
+            logger.info("")
+            logger.info("Starting automatic brute search")
+            logger.info("")
+            AutomaticModelConfigGenerator._log_first_run = True
+
+        self._max_instance_count = config.run_config_search_max_instance_count
+        self._min_instance_count = config.run_config_search_min_instance_count
+        self._max_model_batch_size = config.run_config_search_max_model_batch_size
+        self._min_model_batch_size = config.run_config_search_min_model_batch_size
+
+        self._instance_kind = "KIND_CPU" if self._cpu_only else "KIND_GPU"
+
+        self._curr_instance_count = self._min_instance_count
+        self._curr_max_batch_size = 0
+
+        self._reset_max_batch_size()
+
+        if not self._early_exit_enable:
+            raise TritonModelAnalyzerException(
+                "Early exit disable is not supported in automatic model config generator"
+            )
+
+    def _done_walking(self) -> bool:
+        return self._curr_instance_count > self._max_instance_count
+
+    def _step(self) -> None:
+        self._step_max_batch_size()
+
+        if self._done_walking_max_batch_size():
+            self._reset_max_batch_size()
+            self._step_instance_count()
+
+    def _step_max_batch_size(self) -> None:
+        self._curr_max_batch_size *= 2
+
+        last_max_throughput = self._get_last_results_max_throughput()
+        if last_max_throughput:
+            self._curr_max_batch_size_throughputs.append(last_max_throughput)
+
+    def _step_instance_count(self) -> None:
+        self._curr_instance_count += 1
+
+    def _done_walking_max_batch_size(self) -> bool:
+        if self._last_results_erroneous():
+            return True
+
+        if self._max_batch_size_limit_reached():
+            return True
+
+        if not self._last_results_increased_throughput():
+            self._print_max_batch_size_plateau_warning()
+            return True
+
+        return False
+
+    def _max_batch_size_limit_reached(self) -> bool:
+        return self._curr_max_batch_size > self._max_model_batch_size
+
+    def _reset_max_batch_size(self) -> None:
+        super()._reset_max_batch_size()
+
+        if self._base_model.supports_batching():
+            self._curr_max_batch_size = self._min_model_batch_size
+        else:
+            self._curr_max_batch_size = self._max_model_batch_size
+
+    def _get_next_model_config_variant(self) -> ModelConfigVariant:
+        param_combo = self._get_curr_param_combo()
+        model_config_variant = self._make_direct_mode_model_config_variant(param_combo)
+        return model_config_variant
+
+    def _get_curr_param_combo(self) -> Dict:
+        if self._default_only:
+            return DEFAULT_CONFIG_PARAMS
+
+        config: Dict[str, Any] = {
+            "instance_group": [
+                {"count": self._curr_instance_count, "kind": self._instance_kind}
+            ]
+        }
+
+        if self._base_model.supports_batching():
+            config["max_batch_size"] = self._curr_max_batch_size
+
+        if self._base_model.supports_dynamic_batching():
+            config["dynamic_batching"] = {}
+
+        return config
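The two overridden hooks above drive a nested walk: `_step_max_batch_size` doubles `max_batch_size` starting from `run_config_search_min_model_batch_size`, and once throughput plateaus, errors out, or the batch size exceeds `run_config_search_max_model_batch_size`, the instance count steps up and the batch-size walk restarts. A minimal driver sketch (not part of the package; `profile_variant` is a hypothetical stand-in for the perf_analyzer measurement step) shows how a caller interleaves `get_configs()` with `set_last_results()`:

# Hypothetical driver loop illustrating the generator protocol used by the
# search: yield a config variant, measure it, feed the results back so the
# generator can decide its next step.
generator = AutomaticModelConfigGenerator(
    config=config,
    gpus=gpus,
    model=model,
    client=client,
    model_variant_name_manager=model_variant_name_manager,
    default_only=False,
    early_exit_enable=True,  # required: the constructor raises otherwise
)
for variant in generator.get_configs():
    measurements = profile_variant(variant)   # hypothetical: run perf_analyzer
    generator.set_last_results(measurements)  # steers _step()/_done_walking()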
model_analyzer/config/generate/base_model_config_generator.py
@@ -0,0 +1,352 @@
+#!/usr/bin/env python3
+
+# Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import abc
+import logging
+from copy import deepcopy
+from typing import Any, Dict, Generator, List, Optional
+
+from model_analyzer.config.generate.model_variant_name_manager import (
+    ModelVariantNameManager,
+)
+from model_analyzer.config.input.config_command_profile import ConfigCommandProfile
+from model_analyzer.constants import LOGGER_NAME
+from model_analyzer.device.gpu_device import GPUDevice
+from model_analyzer.result.run_config_measurement import RunConfigMeasurement
+from model_analyzer.triton.client.client import TritonClient
+from model_analyzer.triton.model.model_config import ModelConfig
+from model_analyzer.triton.model.model_config_variant import ModelConfigVariant
+
+from .config_generator_interface import ConfigGeneratorInterface
+from .model_profile_spec import ModelProfileSpec
+
+logger = logging.getLogger(LOGGER_NAME)
+
+
+class BaseModelConfigGenerator(ConfigGeneratorInterface):
+    """Base class for generating model configs"""
+
+    def __init__(
+        self,
+        config: ConfigCommandProfile,
+        gpus: List[GPUDevice],
+        model: ModelProfileSpec,
+        client: TritonClient,
+        model_variant_name_manager: ModelVariantNameManager,
+        default_only: bool,
+        early_exit_enable: bool,
+    ) -> None:
+        """
+        Parameters
+        ----------
+        config: ConfigCommandProfile
+        gpus: List of GPUDevices
+        model: ModelProfileSpec
+            The model to generate ModelConfigs for
+        client: TritonClient
+        model_variant_name_manager: ModelVariantNameManager
+        default_only: Bool
+            If true, only the default config will be generated
+            If false, the default config will NOT be generated
+        early_exit_enable: Bool
+            If true, the generator can early exit if throughput plateaus
+        """
+        self._config = config
+        self._client = client
+        self._model_variant_name_manager = model_variant_name_manager
+        self._base_model = model
+        self._base_model_name = model.model_name()
+        self._remote_mode = config.triton_launch_mode == "remote"
+        self._c_api_mode = config.triton_launch_mode == "c_api"
+        self._cpu_only = model.cpu_only()
+        self._default_only = default_only
+        self._early_exit_enable = early_exit_enable
+        self._model_name_index = 0
+        self._generator_started = False
+        self._max_batch_size_warning_printed = False
+        self._last_results: List[Optional[RunConfigMeasurement]] = []
+        # Contains the max throughput from each provided list of measurements
+        # since the last time we stepped max_batch_size
+        #
+        self._curr_max_batch_size_throughputs: List[float] = []
+
+    def _is_done(self) -> bool:
+        """Returns true if this generator is done generating configs"""
+        return self._generator_started and (self._default_only or self._done_walking())
+
+    def get_configs(self) -> Generator[ModelConfigVariant, None, None]:
+        """
+        Returns
+        -------
+        ModelConfigVariant
+            The next ModelConfigVariant generated by this class
+        """
+        while True:
+            if self._is_done():
+                break
+
+            self._generator_started = True
+            config = self._get_next_model_config_variant()
+            yield (config)
+            self._step()
+
+    def set_last_results(
+        self, measurements: List[Optional[RunConfigMeasurement]]
+    ) -> None:
+        """
+        Given the results from the last ModelConfig, make decisions
+        about future configurations to generate
+
+        Parameters
+        ----------
+        measurements: List of Measurements from the last run(s)
+        """
+        self._last_results = measurements
+
+    @abc.abstractmethod
+    def _done_walking(self) -> bool:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def _step(self) -> None:
+        raise NotImplementedError
+
+    @abc.abstractmethod
+    def _get_next_model_config_variant(self) -> ModelConfigVariant:
+        raise NotImplementedError
+
+    def _last_results_erroneous(self) -> bool:
+        last_max_throughput = self._get_last_results_max_throughput()
+        return last_max_throughput is None
+
+    def _last_results_increased_throughput(self) -> bool:
+        if len(self._curr_max_batch_size_throughputs) < 2:
+            return True
+
+        latest_throughput = self._curr_max_batch_size_throughputs[-1]
+        return all(
+            latest_throughput > prev_throughput
+            for prev_throughput in self._curr_max_batch_size_throughputs[:-1]
+        )
+
+    def _get_last_results_max_throughput(self) -> Optional[float]:
+        throughputs = [
+            m.get_non_gpu_metric_value("perf_throughput")
+            for m in self._last_results
+            if m is not None
+        ]
+        if not throughputs:
+            return None
+        else:
+            return max(throughputs)
+
+    def _make_remote_model_config_variant(self) -> ModelConfigVariant:
+        if not self._config.reload_model_disable:
+            self._client.load_model(model_name=self._base_model_name)
+        model_config = ModelConfig.create_from_triton_api(
+            self._client, self._base_model_name, self._config.client_max_retries
+        )
+        if not self._config.reload_model_disable:
+            self._client.unload_model(self._base_model_name)
+
+        return ModelConfigVariant(model_config, self._base_model_name, self._cpu_only)
+
+    def _make_direct_mode_model_config_variant(
+        self, param_combo: Dict
+    ) -> ModelConfigVariant:
+        return BaseModelConfigGenerator.make_model_config_variant(
+            param_combo=param_combo,
+            model=self._base_model,
+            model_variant_name_manager=self._model_variant_name_manager,
+            c_api_mode=self._c_api_mode,
+        )
+
+    @staticmethod
+    def make_model_config_variant(
+        param_combo: dict,
+        model: ModelProfileSpec,
+        model_variant_name_manager: ModelVariantNameManager,
+        c_api_mode: bool,
+    ) -> ModelConfigVariant:
+        """
+        Loads the base model config from the model repository, and then applies the
+        parameters in the param_combo on top to create and return a new model config
+
+        Parameters:
+        -----------
+        param_combo: dict
+            dict of key:value pairs to apply to the model config
+        model: ModelProfileSpec
+        model_variant_name_manager: ModelVariantNameManager
+        c_api_mode: Set to true if mode is c_api
+        """
+        logger_str: List[str] = []
+        model_name = model.model_name()
+        model_config_dict = BaseModelConfigGenerator._apply_param_combo_to_model(
+            model, param_combo, logger_str
+        )
+
+        (
+            variant_found,
+            variant_name,
+        ) = model_variant_name_manager.get_model_variant_name(
+            model_name, model_config_dict, param_combo
+        )
+
+        if variant_found:
+            logger.info(f"Found existing model config: {variant_name}")
+        else:
+            logger.info(f"Creating model config: {variant_name}")
+            for str in logger_str:
+                logger.info(str)
+            logger.info("")
+
+        model_config_dict["name"] = variant_name if c_api_mode else model_name
+        model_config = ModelConfig.create_from_dictionary(model_config_dict)
+
+        return ModelConfigVariant(model_config, variant_name, model.cpu_only())
+
+    @staticmethod
+    def make_ensemble_model_config_variant(
+        model: ModelProfileSpec,
+        ensemble_composing_model_config_variants: List[ModelConfigVariant],
+        model_variant_name_manager: ModelVariantNameManager,
+        c_api_mode: bool,
+        param_combo: Dict = {},
+    ) -> ModelConfigVariant:
+        """
+        Loads the ensemble model spec from the model repository, and then mutates
+        the names to match the ensemble composing models
+
+        Parameters
+        ----------
+        model: ModelProfileSpec
+            The top-level ensemble model spec
+        ensemble_composing_model_config_variants: List of ModelConfigVariants
+            The list of composing model ModelConfigs
+        model_variant_name_manager: ModelVariantNameManager
+        c_api_mode: Set to true if mode is c_api
+
+        """
+        logger_str: List[str] = []
+        model_name = model.model_name()
+        model_config_dict = BaseModelConfigGenerator._apply_param_combo_to_model(
+            model, param_combo, logger_str
+        )
+
+        ensemble_key = ModelVariantNameManager.make_ensemble_composing_model_key(
+            ensemble_composing_model_config_variants
+        )
+
+        (
+            variant_found,
+            variant_name,
+        ) = model_variant_name_manager.get_ensemble_model_variant_name(
+            model_name, ensemble_key
+        )
+
+        if variant_found:
+            logger.info(f"Found existing ensemble model config: {variant_name}")
+        else:
+            logger.info(f"Creating ensemble model config: {variant_name}")
+            for str in logger_str:
+                logger.info(str)
+
+        model_config_dict["name"] = variant_name if c_api_mode else model_name
+        model_config = ModelConfig.create_from_dictionary(model_config_dict)
+
+        return ModelConfigVariant(model_config, variant_name)
+
+    @staticmethod
+    def _apply_param_combo_to_model(
+        model: ModelProfileSpec, param_combo: dict, logger_str: List[str]
+    ) -> dict:
+        """
+        Given a model, apply any parameters and return a model config dictionary
+        """
+        model_config_dict = model.get_default_config()
+        if param_combo is not None:
+            for key, value in param_combo.items():
+                if value is not None:
+                    BaseModelConfigGenerator._apply_value_to_dict(
+                        key, value, model_config_dict
+                    )
+
+                    if value == {}:
+                        logger_str.append(f"  Enabling {key}")
+                    else:
+                        logger_str.append(f"  Setting {key} to {value}")
+
+        return model_config_dict
+
+    def _reset_max_batch_size(self) -> None:
+        self._max_batch_size_warning_printed = False
+        self._curr_max_batch_size_throughputs = []
+
+    def _print_max_batch_size_plateau_warning(self) -> None:
+        if not self._max_batch_size_warning_printed:
+            logger.info(
+                "No longer increasing max_batch_size because throughput has plateaued"
+            )
+            self._max_batch_size_warning_printed = True
+
+    @staticmethod
+    def extract_model_name_from_variant_name(variant_name: str) -> str:
+        """
+        Removes '_config_#/default' from the variant name and returns
+        the model name, eg. model_name_config_10 -> model_name
+        """
+        model_name = variant_name
+        config_index = variant_name.find("_config_")
+
+        if config_index != -1:
+            model_name = variant_name[:config_index]
+
+        return model_name
+
+    @staticmethod
+    def create_original_config_from_variant(variant_config: ModelConfig) -> ModelConfig:
+        """
+        Removes 'config_#/default' from the variant config and returns
+        a new model config
+        """
+        original_config = deepcopy(variant_config)
+
+        original_config.set_model_name(
+            BaseModelConfigGenerator.extract_model_name_from_variant_name(
+                variant_config.get_field("name")
+            )
+        )
+
+        return original_config
+
+    @staticmethod
+    def _apply_value_to_dict(key: Any, value: Any, dict_in: Dict) -> None:
+        """
+        Apply the supplied value at the given key into the provided dict.
+
+        If the key already exists in the dict and both the existing value as well
+        as the new input value are dicts, only overwrite the subkeys (recursively)
+        provided in the value
+        """
+
+        if type(dict_in.get(key, None)) is dict and type(value) is dict:
+            for subkey, subvalue in value.items():
+                BaseModelConfigGenerator._apply_value_to_dict(
+                    subkey, subvalue, dict_in.get(key, None)
+                )
+        else:
+            dict_in[key] = value
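`_apply_value_to_dict` is what lets a param combo overlay the default model config without clobbering unrelated nested settings. A standalone sketch of the same merge semantics (re-implemented here outside the class, purely for illustration):

from typing import Any, Dict

def apply_value_to_dict(key: Any, value: Any, dict_in: Dict) -> None:
    # Recursive overlay: if both sides are dicts, merge subkeys one by one;
    # otherwise the new value replaces whatever is currently at the key.
    if isinstance(dict_in.get(key), dict) and isinstance(value, dict):
        for subkey, subvalue in value.items():
            apply_value_to_dict(subkey, subvalue, dict_in[key])
    else:
        dict_in[key] = value

base = {
    "max_batch_size": 8,
    "dynamic_batching": {"max_queue_delay_microseconds": 100},
}
apply_value_to_dict("dynamic_batching", {"preferred_batch_size": [4, 8]}, base)
# base["dynamic_batching"] now holds both keys:
# {"max_queue_delay_microseconds": 100, "preferred_batch_size": [4, 8]}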
model_analyzer/config/generate/brute_plus_binary_parameter_search_run_config_generator.py
@@ -0,0 +1,164 @@
+#!/usr/bin/env python3
+
+# Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from copy import deepcopy
+from typing import Dict, Generator, List, Optional
+
+from model_analyzer.config.generate.brute_run_config_generator import (
+    BruteRunConfigGenerator,
+)
+from model_analyzer.config.generate.model_profile_spec import ModelProfileSpec
+from model_analyzer.config.generate.model_variant_name_manager import (
+    ModelVariantNameManager,
+)
+from model_analyzer.config.input.config_command_profile import ConfigCommandProfile
+from model_analyzer.config.run.run_config import RunConfig
+from model_analyzer.constants import LOGGER_NAME
+from model_analyzer.device.gpu_device import GPUDevice
+from model_analyzer.result.parameter_search import ParameterSearch
+from model_analyzer.result.result_manager import ResultManager
+from model_analyzer.result.run_config_measurement import RunConfigMeasurement
+from model_analyzer.triton.client.client import TritonClient
+
+from .config_generator_interface import ConfigGeneratorInterface
+
+logger = logging.getLogger(LOGGER_NAME)
+
+
+class BrutePlusBinaryParameterSearchRunConfigGenerator(ConfigGeneratorInterface):
+    """
+    First run BruteRunConfigGenerator for a brute search, then for
+    automatic searches use ParameterSearch to perform a binary search
+    """
+
+    def __init__(
+        self,
+        config: ConfigCommandProfile,
+        gpus: List[GPUDevice],
+        models: List[ModelProfileSpec],
+        client: TritonClient,
+        result_manager: ResultManager,
+        model_variant_name_manager: ModelVariantNameManager,
+    ):
+        """
+        Parameters
+        ----------
+        config: ConfigCommandProfile
+            Profile configuration information
+        gpus: List of GPUDevices
+        models: List of ModelProfileSpec
+            List of models to profile
+        client: TritonClient
+        result_manager: ResultManager
+            The object that handles storing and sorting the results from the perf analyzer
+        model_variant_name_manager: ModelVariantNameManager
+            Maps model variants to config names
+        """
+        self._config = config
+        self._gpus = gpus
+        self._models = models
+        self._client = client
+        self._result_manager = result_manager
+        self._model_variant_name_manager = model_variant_name_manager
+
+    def set_last_results(
+        self, measurements: List[Optional[RunConfigMeasurement]]
+    ) -> None:
+        self._last_measurement = measurements[-1]
+        self._rcg.set_last_results(measurements)
+
+    def get_configs(self) -> Generator[RunConfig, None, None]:
+        """
+        Returns
+        -------
+        RunConfig
+            The next RunConfig generated by this class
+        """
+
+        yield from self._execute_brute_search()
+        logger.info("")
+        logger.info("Done with brute mode search.")
+        logger.info("")
+
+        if self._can_binary_search_top_results():
+            yield from self._binary_search_over_top_results()
+            logger.info("")
+            logger.info("Done gathering concurrency sweep measurements for reports")
+            logger.info("")
+
+    def _execute_brute_search(self) -> Generator[RunConfig, None, None]:
+        self._rcg: ConfigGeneratorInterface = self._create_brute_run_config_generator()
+
+        yield from self._rcg.get_configs()
+
+    def _create_brute_run_config_generator(self) -> BruteRunConfigGenerator:
+        return BruteRunConfigGenerator(
+            config=self._config,
+            gpus=self._gpus,
+            models=self._models,
+            client=self._client,
+            model_variant_name_manager=self._model_variant_name_manager,
+        )
+
+    def _can_binary_search_top_results(self) -> bool:
+        for model in self._models:
+            if model.parameters()["concurrency"] or model.parameters()["request_rate"]:
+                return False
+
+        return True
+
+    def _binary_search_over_top_results(self) -> Generator[RunConfig, None, None]:
+        for model_name in self._result_manager.get_model_names():
+            top_results = self._result_manager.top_n_results(
+                model_name=model_name,
+                n=self._config.num_configs_per_model,
+                include_default=True,
+            )
+
+            for result in top_results:
+                run_config = deepcopy(result.run_config())
+                model_parameters = self._get_model_parameters(model_name)
+                parameter_search = ParameterSearch(
+                    config=self._config,
+                    model_parameters=model_parameters,
+                    skip_parameter_sweep=True,
+                )
+                for parameter in parameter_search.search_parameters():
+                    run_config = self._set_parameter(
+                        run_config, model_parameters, parameter
+                    )
+                    yield run_config
+                    parameter_search.add_run_config_measurement(self._last_measurement)
+
+    def _get_model_parameters(self, model_name: str) -> Dict:
+        for model in self._models:
+            if model_name == model.model_name():
+                return model.parameters()
+
+        return {}
+
+    def _set_parameter(
+        self, run_config: RunConfig, model_parameters: Dict, parameter: int
+    ) -> RunConfig:
+        for model_run_config in run_config.model_run_configs():
+            perf_config = model_run_config.perf_config()
+            if self._config.is_request_rate_specified(model_parameters):
+                perf_config.update_config({"request-rate-range": parameter})
+            else:
+                perf_config.update_config({"concurrency-range": parameter})
+
+        return run_config
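This generator runs in two phases: a brute sweep, then, provided the user did not pin concurrency or request rate themselves, a binary search over each model's top N configs. `set_last_results` keeps `measurements[-1]` so that the phase-2 loop can hand the most recent measurement to `ParameterSearch`, which picks the next concurrency (or request rate) to try; `_set_parameter` maps that value onto perf_analyzer's `concurrency-range` / `request-rate-range` options. A sketch of the phase-2 feedback loop reduced to its control flow (`set_concurrency` and `measure` are hypothetical stand-ins for `_set_parameter` and the perf_analyzer run):

# Hypothetical reduction of _binary_search_over_top_results for one config:
# each measurement fed back narrows ParameterSearch's binary search.
parameter_search = ParameterSearch(
    config=config,
    model_parameters=model_parameters,
    skip_parameter_sweep=True,
)
for parameter in parameter_search.search_parameters():
    run_config = set_concurrency(run_config, parameter)  # hypothetical helper
    measurement = measure(run_config)                    # hypothetical helper
    parameter_search.add_run_config_measurement(measurement)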