triton-model-analyzer 1.48.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- model_analyzer/__init__.py +15 -0
- model_analyzer/analyzer.py +448 -0
- model_analyzer/cli/__init__.py +15 -0
- model_analyzer/cli/cli.py +193 -0
- model_analyzer/config/__init__.py +15 -0
- model_analyzer/config/generate/__init__.py +15 -0
- model_analyzer/config/generate/automatic_model_config_generator.py +164 -0
- model_analyzer/config/generate/base_model_config_generator.py +352 -0
- model_analyzer/config/generate/brute_plus_binary_parameter_search_run_config_generator.py +164 -0
- model_analyzer/config/generate/brute_run_config_generator.py +154 -0
- model_analyzer/config/generate/concurrency_sweeper.py +75 -0
- model_analyzer/config/generate/config_generator_interface.py +52 -0
- model_analyzer/config/generate/coordinate.py +143 -0
- model_analyzer/config/generate/coordinate_data.py +86 -0
- model_analyzer/config/generate/generator_utils.py +116 -0
- model_analyzer/config/generate/manual_model_config_generator.py +187 -0
- model_analyzer/config/generate/model_config_generator_factory.py +92 -0
- model_analyzer/config/generate/model_profile_spec.py +74 -0
- model_analyzer/config/generate/model_run_config_generator.py +154 -0
- model_analyzer/config/generate/model_variant_name_manager.py +150 -0
- model_analyzer/config/generate/neighborhood.py +536 -0
- model_analyzer/config/generate/optuna_plus_concurrency_sweep_run_config_generator.py +141 -0
- model_analyzer/config/generate/optuna_run_config_generator.py +838 -0
- model_analyzer/config/generate/perf_analyzer_config_generator.py +312 -0
- model_analyzer/config/generate/quick_plus_concurrency_sweep_run_config_generator.py +130 -0
- model_analyzer/config/generate/quick_run_config_generator.py +753 -0
- model_analyzer/config/generate/run_config_generator_factory.py +329 -0
- model_analyzer/config/generate/search_config.py +112 -0
- model_analyzer/config/generate/search_dimension.py +73 -0
- model_analyzer/config/generate/search_dimensions.py +85 -0
- model_analyzer/config/generate/search_parameter.py +49 -0
- model_analyzer/config/generate/search_parameters.py +388 -0
- model_analyzer/config/input/__init__.py +15 -0
- model_analyzer/config/input/config_command.py +483 -0
- model_analyzer/config/input/config_command_profile.py +1747 -0
- model_analyzer/config/input/config_command_report.py +267 -0
- model_analyzer/config/input/config_defaults.py +236 -0
- model_analyzer/config/input/config_enum.py +83 -0
- model_analyzer/config/input/config_field.py +216 -0
- model_analyzer/config/input/config_list_generic.py +112 -0
- model_analyzer/config/input/config_list_numeric.py +151 -0
- model_analyzer/config/input/config_list_string.py +111 -0
- model_analyzer/config/input/config_none.py +71 -0
- model_analyzer/config/input/config_object.py +129 -0
- model_analyzer/config/input/config_primitive.py +81 -0
- model_analyzer/config/input/config_status.py +75 -0
- model_analyzer/config/input/config_sweep.py +83 -0
- model_analyzer/config/input/config_union.py +113 -0
- model_analyzer/config/input/config_utils.py +128 -0
- model_analyzer/config/input/config_value.py +243 -0
- model_analyzer/config/input/objects/__init__.py +15 -0
- model_analyzer/config/input/objects/config_model_profile_spec.py +325 -0
- model_analyzer/config/input/objects/config_model_report_spec.py +173 -0
- model_analyzer/config/input/objects/config_plot.py +198 -0
- model_analyzer/config/input/objects/config_protobuf_utils.py +101 -0
- model_analyzer/config/input/yaml_config_validator.py +82 -0
- model_analyzer/config/run/__init__.py +15 -0
- model_analyzer/config/run/model_run_config.py +313 -0
- model_analyzer/config/run/run_config.py +168 -0
- model_analyzer/constants.py +76 -0
- model_analyzer/device/__init__.py +15 -0
- model_analyzer/device/device.py +24 -0
- model_analyzer/device/gpu_device.py +87 -0
- model_analyzer/device/gpu_device_factory.py +248 -0
- model_analyzer/entrypoint.py +307 -0
- model_analyzer/log_formatter.py +65 -0
- model_analyzer/model_analyzer_exceptions.py +24 -0
- model_analyzer/model_manager.py +255 -0
- model_analyzer/monitor/__init__.py +15 -0
- model_analyzer/monitor/cpu_monitor.py +69 -0
- model_analyzer/monitor/dcgm/DcgmDiag.py +191 -0
- model_analyzer/monitor/dcgm/DcgmFieldGroup.py +83 -0
- model_analyzer/monitor/dcgm/DcgmGroup.py +815 -0
- model_analyzer/monitor/dcgm/DcgmHandle.py +141 -0
- model_analyzer/monitor/dcgm/DcgmJsonReader.py +69 -0
- model_analyzer/monitor/dcgm/DcgmReader.py +623 -0
- model_analyzer/monitor/dcgm/DcgmStatus.py +57 -0
- model_analyzer/monitor/dcgm/DcgmSystem.py +412 -0
- model_analyzer/monitor/dcgm/__init__.py +15 -0
- model_analyzer/monitor/dcgm/common/__init__.py +13 -0
- model_analyzer/monitor/dcgm/common/dcgm_client_cli_parser.py +194 -0
- model_analyzer/monitor/dcgm/common/dcgm_client_main.py +86 -0
- model_analyzer/monitor/dcgm/dcgm_agent.py +887 -0
- model_analyzer/monitor/dcgm/dcgm_collectd_plugin.py +369 -0
- model_analyzer/monitor/dcgm/dcgm_errors.py +395 -0
- model_analyzer/monitor/dcgm/dcgm_field_helpers.py +546 -0
- model_analyzer/monitor/dcgm/dcgm_fields.py +815 -0
- model_analyzer/monitor/dcgm/dcgm_fields_collectd.py +671 -0
- model_analyzer/monitor/dcgm/dcgm_fields_internal.py +29 -0
- model_analyzer/monitor/dcgm/dcgm_fluentd.py +45 -0
- model_analyzer/monitor/dcgm/dcgm_monitor.py +138 -0
- model_analyzer/monitor/dcgm/dcgm_prometheus.py +326 -0
- model_analyzer/monitor/dcgm/dcgm_structs.py +2357 -0
- model_analyzer/monitor/dcgm/dcgm_telegraf.py +65 -0
- model_analyzer/monitor/dcgm/dcgm_value.py +151 -0
- model_analyzer/monitor/dcgm/dcgmvalue.py +155 -0
- model_analyzer/monitor/dcgm/denylist_recommendations.py +573 -0
- model_analyzer/monitor/dcgm/pydcgm.py +47 -0
- model_analyzer/monitor/monitor.py +143 -0
- model_analyzer/monitor/remote_monitor.py +137 -0
- model_analyzer/output/__init__.py +15 -0
- model_analyzer/output/file_writer.py +63 -0
- model_analyzer/output/output_writer.py +42 -0
- model_analyzer/perf_analyzer/__init__.py +15 -0
- model_analyzer/perf_analyzer/genai_perf_config.py +206 -0
- model_analyzer/perf_analyzer/perf_analyzer.py +882 -0
- model_analyzer/perf_analyzer/perf_config.py +479 -0
- model_analyzer/plots/__init__.py +15 -0
- model_analyzer/plots/detailed_plot.py +266 -0
- model_analyzer/plots/plot_manager.py +224 -0
- model_analyzer/plots/simple_plot.py +213 -0
- model_analyzer/record/__init__.py +15 -0
- model_analyzer/record/gpu_record.py +68 -0
- model_analyzer/record/metrics_manager.py +887 -0
- model_analyzer/record/record.py +280 -0
- model_analyzer/record/record_aggregator.py +256 -0
- model_analyzer/record/types/__init__.py +15 -0
- model_analyzer/record/types/cpu_available_ram.py +93 -0
- model_analyzer/record/types/cpu_used_ram.py +93 -0
- model_analyzer/record/types/gpu_free_memory.py +96 -0
- model_analyzer/record/types/gpu_power_usage.py +107 -0
- model_analyzer/record/types/gpu_total_memory.py +96 -0
- model_analyzer/record/types/gpu_used_memory.py +96 -0
- model_analyzer/record/types/gpu_utilization.py +108 -0
- model_analyzer/record/types/inter_token_latency_avg.py +60 -0
- model_analyzer/record/types/inter_token_latency_base.py +74 -0
- model_analyzer/record/types/inter_token_latency_max.py +60 -0
- model_analyzer/record/types/inter_token_latency_min.py +60 -0
- model_analyzer/record/types/inter_token_latency_p25.py +60 -0
- model_analyzer/record/types/inter_token_latency_p50.py +60 -0
- model_analyzer/record/types/inter_token_latency_p75.py +60 -0
- model_analyzer/record/types/inter_token_latency_p90.py +60 -0
- model_analyzer/record/types/inter_token_latency_p95.py +60 -0
- model_analyzer/record/types/inter_token_latency_p99.py +60 -0
- model_analyzer/record/types/output_token_throughput.py +105 -0
- model_analyzer/record/types/perf_client_response_wait.py +97 -0
- model_analyzer/record/types/perf_client_send_recv.py +97 -0
- model_analyzer/record/types/perf_latency.py +111 -0
- model_analyzer/record/types/perf_latency_avg.py +60 -0
- model_analyzer/record/types/perf_latency_base.py +74 -0
- model_analyzer/record/types/perf_latency_p90.py +60 -0
- model_analyzer/record/types/perf_latency_p95.py +60 -0
- model_analyzer/record/types/perf_latency_p99.py +60 -0
- model_analyzer/record/types/perf_server_compute_infer.py +97 -0
- model_analyzer/record/types/perf_server_compute_input.py +97 -0
- model_analyzer/record/types/perf_server_compute_output.py +97 -0
- model_analyzer/record/types/perf_server_queue.py +97 -0
- model_analyzer/record/types/perf_throughput.py +105 -0
- model_analyzer/record/types/time_to_first_token_avg.py +60 -0
- model_analyzer/record/types/time_to_first_token_base.py +74 -0
- model_analyzer/record/types/time_to_first_token_max.py +60 -0
- model_analyzer/record/types/time_to_first_token_min.py +60 -0
- model_analyzer/record/types/time_to_first_token_p25.py +60 -0
- model_analyzer/record/types/time_to_first_token_p50.py +60 -0
- model_analyzer/record/types/time_to_first_token_p75.py +60 -0
- model_analyzer/record/types/time_to_first_token_p90.py +60 -0
- model_analyzer/record/types/time_to_first_token_p95.py +60 -0
- model_analyzer/record/types/time_to_first_token_p99.py +60 -0
- model_analyzer/reports/__init__.py +15 -0
- model_analyzer/reports/html_report.py +195 -0
- model_analyzer/reports/pdf_report.py +50 -0
- model_analyzer/reports/report.py +86 -0
- model_analyzer/reports/report_factory.py +62 -0
- model_analyzer/reports/report_manager.py +1376 -0
- model_analyzer/reports/report_utils.py +42 -0
- model_analyzer/result/__init__.py +15 -0
- model_analyzer/result/constraint_manager.py +150 -0
- model_analyzer/result/model_config_measurement.py +354 -0
- model_analyzer/result/model_constraints.py +105 -0
- model_analyzer/result/parameter_search.py +246 -0
- model_analyzer/result/result_manager.py +430 -0
- model_analyzer/result/result_statistics.py +159 -0
- model_analyzer/result/result_table.py +217 -0
- model_analyzer/result/result_table_manager.py +646 -0
- model_analyzer/result/result_utils.py +42 -0
- model_analyzer/result/results.py +277 -0
- model_analyzer/result/run_config_measurement.py +658 -0
- model_analyzer/result/run_config_result.py +210 -0
- model_analyzer/result/run_config_result_comparator.py +110 -0
- model_analyzer/result/sorted_results.py +151 -0
- model_analyzer/state/__init__.py +15 -0
- model_analyzer/state/analyzer_state.py +76 -0
- model_analyzer/state/analyzer_state_manager.py +215 -0
- model_analyzer/triton/__init__.py +15 -0
- model_analyzer/triton/client/__init__.py +15 -0
- model_analyzer/triton/client/client.py +234 -0
- model_analyzer/triton/client/client_factory.py +57 -0
- model_analyzer/triton/client/grpc_client.py +104 -0
- model_analyzer/triton/client/http_client.py +107 -0
- model_analyzer/triton/model/__init__.py +15 -0
- model_analyzer/triton/model/model_config.py +556 -0
- model_analyzer/triton/model/model_config_variant.py +29 -0
- model_analyzer/triton/server/__init__.py +15 -0
- model_analyzer/triton/server/server.py +76 -0
- model_analyzer/triton/server/server_config.py +269 -0
- model_analyzer/triton/server/server_docker.py +229 -0
- model_analyzer/triton/server/server_factory.py +306 -0
- model_analyzer/triton/server/server_local.py +158 -0
- triton_model_analyzer-1.48.0.dist-info/METADATA +52 -0
- triton_model_analyzer-1.48.0.dist-info/RECORD +204 -0
- triton_model_analyzer-1.48.0.dist-info/WHEEL +5 -0
- triton_model_analyzer-1.48.0.dist-info/entry_points.txt +2 -0
- triton_model_analyzer-1.48.0.dist-info/licenses/LICENSE +67 -0
- triton_model_analyzer-1.48.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,536 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
|
|
3
|
+
# Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
4
|
+
#
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
|
|
17
|
+
import math
|
|
18
|
+
from copy import deepcopy
|
|
19
|
+
from typing import Dict, List, Optional, Tuple, Union
|
|
20
|
+
|
|
21
|
+
from model_analyzer.config.generate.coordinate import Coordinate
|
|
22
|
+
from model_analyzer.config.generate.coordinate_data import CoordinateData
|
|
23
|
+
from model_analyzer.config.generate.search_config import NeighborhoodConfig
|
|
24
|
+
from model_analyzer.result.run_config_measurement import RunConfigMeasurement
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
class Neighborhood:
|
|
28
|
+
"""
|
|
29
|
+
Defines and operates on a set of coordinates within a radius around
|
|
30
|
+
a 'home' coordinate
|
|
31
|
+
"""
|
|
32
|
+
|
|
33
|
+
# This defines the bounds of how the vector calculated from
|
|
34
|
+
# measurements is converted to a step vector.
|
|
35
|
+
#
|
|
36
|
+
# The translation will return the lowest index that has a value greater
|
|
37
|
+
# than the input value.
|
|
38
|
+
#
|
|
39
|
+
# For example, if the input is greater than the value in index 1 but less than
|
|
40
|
+
# the value in index 2, the resulting step will be 1
|
|
41
|
+
#
|
|
42
|
+
TRANSLATION_LIST = [0.09, 0.3, 1.0]
|
|
43
|
+
|
|
44
|
+
def __init__(
|
|
45
|
+
self,
|
|
46
|
+
neighborhood_config: NeighborhoodConfig,
|
|
47
|
+
home_coordinate: Coordinate,
|
|
48
|
+
coordinate_data: CoordinateData,
|
|
49
|
+
):
|
|
50
|
+
"""
|
|
51
|
+
Parameters
|
|
52
|
+
----------
|
|
53
|
+
neighborhood_config:
|
|
54
|
+
NeighborhoodConfig object
|
|
55
|
+
home_coordinate:
|
|
56
|
+
Coordinate object to center the neighborhood around
|
|
57
|
+
"""
|
|
58
|
+
assert type(neighborhood_config) == NeighborhoodConfig
|
|
59
|
+
|
|
60
|
+
self._config = neighborhood_config
|
|
61
|
+
self._home_coordinate = home_coordinate
|
|
62
|
+
self._coordinate_data = coordinate_data
|
|
63
|
+
|
|
64
|
+
self._radius = self._config.get_radius()
|
|
65
|
+
self._neighborhood = self._create_neighborhood()
|
|
66
|
+
|
|
67
|
+
self._force_slow_mode = False
|
|
68
|
+
|
|
69
|
+
@classmethod
|
|
70
|
+
def calc_distance(
|
|
71
|
+
cls,
|
|
72
|
+
coordinate1: Union[Coordinate, List[int]],
|
|
73
|
+
coordinate2: Union[Coordinate, List[int]],
|
|
74
|
+
) -> float:
|
|
75
|
+
"""
|
|
76
|
+
Return the euclidean distance between two coordinates
|
|
77
|
+
"""
|
|
78
|
+
|
|
79
|
+
distance = 0.0
|
|
80
|
+
for i, _ in enumerate(coordinate1):
|
|
81
|
+
diff = coordinate1[i] - coordinate2[i]
|
|
82
|
+
distance += math.pow(diff, 2)
|
|
83
|
+
distance = math.sqrt(distance)
|
|
84
|
+
return distance
|
|
85
|
+
|
|
86
|
+
def enough_coordinates_initialized(self) -> bool:
|
|
87
|
+
"""
|
|
88
|
+
Returns true if enough coordinates inside of the neighborhood
|
|
89
|
+
have been initialized with valid measurements. Else false
|
|
90
|
+
|
|
91
|
+
If the neighborhood is in slow mode, this means all adjacent neighbors
|
|
92
|
+
must be visited
|
|
93
|
+
"""
|
|
94
|
+
if self._is_slow_mode():
|
|
95
|
+
return self._are_all_adjacent_neighbors_measured()
|
|
96
|
+
else:
|
|
97
|
+
min_initialized = self._config.get_min_initialized()
|
|
98
|
+
num_initialized = len(self._get_coordinates_with_valid_measurements())
|
|
99
|
+
return num_initialized >= min_initialized
|
|
100
|
+
|
|
101
|
+
def force_slow_mode(self) -> None:
|
|
102
|
+
"""
|
|
103
|
+
When called, forces the neighborhood into slow mode
|
|
104
|
+
"""
|
|
105
|
+
self._force_slow_mode = True
|
|
106
|
+
|
|
107
|
+
def determine_new_home(self) -> Coordinate:
|
|
108
|
+
"""
|
|
109
|
+
Based on the measurements in the neighborhood, determine where
|
|
110
|
+
the next location should be.
|
|
111
|
+
|
|
112
|
+
If the neighborhood is in slow mode, return the best found measurement
|
|
113
|
+
Otherwise calculate a new coordinate from the measurements
|
|
114
|
+
|
|
115
|
+
Returns
|
|
116
|
+
-------
|
|
117
|
+
new_coordinate
|
|
118
|
+
The new coordinate computed based on the neighborhood measurements.
|
|
119
|
+
"""
|
|
120
|
+
|
|
121
|
+
if self._is_slow_mode():
|
|
122
|
+
return self._get_best_coordinate_found()
|
|
123
|
+
else:
|
|
124
|
+
return self._calculate_new_home()
|
|
125
|
+
|
|
126
|
+
def _get_best_coordinate_found(self) -> Coordinate:
|
|
127
|
+
vectors, measurements = self._get_measurements_passing_constraints()
|
|
128
|
+
|
|
129
|
+
if len(vectors) == 0:
|
|
130
|
+
return self._home_coordinate
|
|
131
|
+
|
|
132
|
+
home_measurement = self._get_home_measurement()
|
|
133
|
+
|
|
134
|
+
if home_measurement and home_measurement.is_passing_constraints():
|
|
135
|
+
vectors.append(Coordinate([0] * self._config.get_num_dimensions()))
|
|
136
|
+
measurements.append(home_measurement)
|
|
137
|
+
|
|
138
|
+
_, best_vector = sorted(zip(measurements, vectors))[-1]
|
|
139
|
+
|
|
140
|
+
best_coordinate = self._home_coordinate + best_vector
|
|
141
|
+
return best_coordinate
|
|
142
|
+
|
|
143
|
+
def _calculate_new_home(self) -> Coordinate:
|
|
144
|
+
step_vector = self._get_step_vector()
|
|
145
|
+
step_vector_coordinate = self._translate_step_vector(
|
|
146
|
+
step_vector, Neighborhood.TRANSLATION_LIST
|
|
147
|
+
)
|
|
148
|
+
tmp_new_coordinate = self._home_coordinate + step_vector_coordinate
|
|
149
|
+
new_coordinate = self._clamp_coordinate_to_bounds(tmp_new_coordinate)
|
|
150
|
+
return new_coordinate
|
|
151
|
+
|
|
152
|
+
def _translate_step_vector(
|
|
153
|
+
self, step_vector: List[float], translate_list: List[float]
|
|
154
|
+
) -> Coordinate:
|
|
155
|
+
translated_step_vector = Coordinate([0] * len(step_vector))
|
|
156
|
+
for i, v in enumerate(step_vector):
|
|
157
|
+
translated_step_vector[i] = self._translate_value(v, translate_list)
|
|
158
|
+
|
|
159
|
+
return translated_step_vector
|
|
160
|
+
|
|
161
|
+
def _translate_value(self, value: float, translation_list: List[float]) -> int:
|
|
162
|
+
ret = 0
|
|
163
|
+
for index, bound in enumerate(translation_list):
|
|
164
|
+
if value > 0 and value > bound:
|
|
165
|
+
ret = index + 1
|
|
166
|
+
if value < 0 and value < -1 * bound:
|
|
167
|
+
ret = -1 * (index + 1)
|
|
168
|
+
return ret
|
|
169
|
+
|
|
170
|
+
def pick_coordinate_to_initialize(self) -> Optional[Coordinate]:
|
|
171
|
+
"""
|
|
172
|
+
Based on the initialized coordinate values, pick an unvisited
|
|
173
|
+
coordinate to initialize next.
|
|
174
|
+
|
|
175
|
+
If the neighborhood is in slow mode, only pick from within the adjacent neighbors
|
|
176
|
+
"""
|
|
177
|
+
|
|
178
|
+
if self._is_slow_mode():
|
|
179
|
+
return self._pick_slow_mode_coordinate_to_initialize()
|
|
180
|
+
else:
|
|
181
|
+
return self._pick_fast_mode_coordinate_to_initialize()
|
|
182
|
+
|
|
183
|
+
def _pick_slow_mode_coordinate_to_initialize(self) -> Coordinate:
|
|
184
|
+
for neighbor in self._get_all_adjacent_neighbors():
|
|
185
|
+
if not self._is_coordinate_measured(neighbor):
|
|
186
|
+
return neighbor
|
|
187
|
+
|
|
188
|
+
raise Exception("Picking slow mode coordinate, but none are unvisited")
|
|
189
|
+
|
|
190
|
+
def _pick_fast_mode_coordinate_to_initialize(self) -> Optional[Coordinate]:
|
|
191
|
+
covered_values_per_dimension = self._get_covered_values_per_dimension()
|
|
192
|
+
|
|
193
|
+
max_num_uncovered = -1
|
|
194
|
+
best_coordinate = None
|
|
195
|
+
for coordinate in self._neighborhood:
|
|
196
|
+
if not self._is_coordinate_measured(coordinate):
|
|
197
|
+
num_uncovered = self._get_num_uncovered_values(
|
|
198
|
+
coordinate, covered_values_per_dimension
|
|
199
|
+
)
|
|
200
|
+
|
|
201
|
+
if num_uncovered > max_num_uncovered:
|
|
202
|
+
max_num_uncovered = num_uncovered
|
|
203
|
+
best_coordinate = coordinate
|
|
204
|
+
|
|
205
|
+
return best_coordinate
|
|
206
|
+
|
|
207
|
+
def get_nearest_neighbor(self, coordinate_in: Coordinate) -> Coordinate:
|
|
208
|
+
"""
|
|
209
|
+
Find the nearest coordinate to the `coordinate_in` among the
|
|
210
|
+
coordinates within the current neighborhood.
|
|
211
|
+
"""
|
|
212
|
+
min_distance = float("inf")
|
|
213
|
+
nearest_neighbor = self._home_coordinate
|
|
214
|
+
|
|
215
|
+
for coordinate in self._neighborhood:
|
|
216
|
+
distance = Neighborhood.calc_distance(coordinate, coordinate_in)
|
|
217
|
+
if distance < min_distance:
|
|
218
|
+
nearest_neighbor = coordinate
|
|
219
|
+
min_distance = distance
|
|
220
|
+
|
|
221
|
+
return nearest_neighbor
|
|
222
|
+
|
|
223
|
+
def _create_neighborhood(self) -> List[Coordinate]:
|
|
224
|
+
"""
|
|
225
|
+
Create and return a neighborhood of all Coordinates within
|
|
226
|
+
range <_radius> that are also within all bounds
|
|
227
|
+
"""
|
|
228
|
+
|
|
229
|
+
neighborhood = []
|
|
230
|
+
potential_steps = self._get_potential_steps(
|
|
231
|
+
self._config.get_num_dimensions(), self._radius
|
|
232
|
+
)
|
|
233
|
+
|
|
234
|
+
for potential_step in potential_steps:
|
|
235
|
+
for i, v in enumerate(self._home_coordinate):
|
|
236
|
+
potential_step[i] += v
|
|
237
|
+
if self._is_in_bounds(potential_step):
|
|
238
|
+
neighborhood.append(Coordinate(potential_step))
|
|
239
|
+
return neighborhood
|
|
240
|
+
|
|
241
|
+
def _is_in_bounds(self, potential_coordinate: List[int]) -> bool:
|
|
242
|
+
for i, v in enumerate(potential_coordinate):
|
|
243
|
+
dim = self._config.get_dimension(i)
|
|
244
|
+
if v > dim.get_max_idx() or v < dim.get_min_idx():
|
|
245
|
+
return False
|
|
246
|
+
return True
|
|
247
|
+
|
|
248
|
+
def _get_potential_steps(
|
|
249
|
+
self, num_coordinates: int, radius: int
|
|
250
|
+
) -> List[List[int]]:
|
|
251
|
+
"""
|
|
252
|
+
Create and return a list of all possible step vectors that are
|
|
253
|
+
within <_radius> distance
|
|
254
|
+
"""
|
|
255
|
+
|
|
256
|
+
result_list: List[List[int]] = []
|
|
257
|
+
v = [0] * num_coordinates
|
|
258
|
+
self._permute_steps_in_range(v, radius, 0, result_list)
|
|
259
|
+
return result_list
|
|
260
|
+
|
|
261
|
+
def _append_combinations_to_results(
|
|
262
|
+
self, curr_val: List[int], index: int, result_list: List[List[int]]
|
|
263
|
+
) -> None:
|
|
264
|
+
"""
|
|
265
|
+
Given a List of integers (a potential step vector) with all positive
|
|
266
|
+
values, permutate all combinations of positive/negative values and
|
|
267
|
+
append it to the result_list
|
|
268
|
+
|
|
269
|
+
For example, an input of [1,0,2] will append the following:
|
|
270
|
+
[1,0,2], [1,0,-2], [-1,0,2], [-1,0,-2]
|
|
271
|
+
"""
|
|
272
|
+
if index + 1 == len(curr_val):
|
|
273
|
+
result_list.append(deepcopy(curr_val))
|
|
274
|
+
if curr_val[index]:
|
|
275
|
+
curr_val[index] = -curr_val[index]
|
|
276
|
+
result_list.append(deepcopy(curr_val))
|
|
277
|
+
else:
|
|
278
|
+
self._append_combinations_to_results(curr_val, index + 1, result_list)
|
|
279
|
+
if curr_val[index]:
|
|
280
|
+
curr_val[index] = -curr_val[index]
|
|
281
|
+
self._append_combinations_to_results(curr_val, index + 1, result_list)
|
|
282
|
+
|
|
283
|
+
def _permute_steps_in_range(
|
|
284
|
+
self,
|
|
285
|
+
curr_step: List[int],
|
|
286
|
+
radius: int,
|
|
287
|
+
index: int,
|
|
288
|
+
result_list: List[List[int]],
|
|
289
|
+
) -> None:
|
|
290
|
+
"""
|
|
291
|
+
Recursively walk all combinations of steps within the desired radius
|
|
292
|
+
"""
|
|
293
|
+
base = [0] * len(curr_step)
|
|
294
|
+
|
|
295
|
+
for i in range(radius + 1):
|
|
296
|
+
curr_step[index] = i
|
|
297
|
+
|
|
298
|
+
# Leaf (rightmost) coordinate index: Add to results if in range
|
|
299
|
+
if index == len(curr_step) - 1:
|
|
300
|
+
d = Neighborhood.calc_distance(base, curr_step)
|
|
301
|
+
if d <= radius:
|
|
302
|
+
self._append_combinations_to_results(curr_step, 0, result_list)
|
|
303
|
+
else:
|
|
304
|
+
return
|
|
305
|
+
# Non-leaf coordinate index: Recurse
|
|
306
|
+
else:
|
|
307
|
+
self._permute_steps_in_range(curr_step, radius, index + 1, result_list)
|
|
308
|
+
|
|
309
|
+
def _get_coordinates_with_valid_measurements(self) -> List[Coordinate]:
|
|
310
|
+
initialized_coordinates = []
|
|
311
|
+
for coordinate in self._neighborhood:
|
|
312
|
+
if (
|
|
313
|
+
coordinate != self._home_coordinate
|
|
314
|
+
and self._coordinate_data.has_valid_measurement(coordinate)
|
|
315
|
+
):
|
|
316
|
+
initialized_coordinates.append(deepcopy(coordinate))
|
|
317
|
+
return initialized_coordinates
|
|
318
|
+
|
|
319
|
+
def _get_step_vector(self) -> List[float]:
|
|
320
|
+
"""
|
|
321
|
+
Calculate a vector that indicates a direction to step from the
|
|
322
|
+
home coordinate (current center).
|
|
323
|
+
|
|
324
|
+
Returns
|
|
325
|
+
-------
|
|
326
|
+
step_vector
|
|
327
|
+
a coordinate that tells the direction to move.
|
|
328
|
+
"""
|
|
329
|
+
|
|
330
|
+
compare_constraints = not self._is_home_passing_constraints()
|
|
331
|
+
return self._calculate_step_vector_from_measurements(
|
|
332
|
+
compare_constraints=compare_constraints
|
|
333
|
+
)
|
|
334
|
+
|
|
335
|
+
def _calculate_step_vector_from_measurements(
|
|
336
|
+
self, compare_constraints: bool
|
|
337
|
+
) -> List[float]:
|
|
338
|
+
home_measurement = self._get_home_measurement()
|
|
339
|
+
if not home_measurement:
|
|
340
|
+
raise Exception("Can't step from home if it has no measurement")
|
|
341
|
+
|
|
342
|
+
vectors, measurements = self._get_all_measurements()
|
|
343
|
+
|
|
344
|
+
# This function should only ever be called if all are passing or none are passing
|
|
345
|
+
_, p = self._get_measurements_passing_constraints()
|
|
346
|
+
assert len(p) == 0 or len(p) == len(measurements)
|
|
347
|
+
|
|
348
|
+
if not vectors:
|
|
349
|
+
return [0.0] * self._config.get_num_dimensions()
|
|
350
|
+
|
|
351
|
+
weights = []
|
|
352
|
+
for m in measurements:
|
|
353
|
+
if compare_constraints:
|
|
354
|
+
weight = home_measurement.compare_constraints(m)
|
|
355
|
+
else:
|
|
356
|
+
weight = home_measurement.compare_measurements(m)
|
|
357
|
+
if not weight:
|
|
358
|
+
weight = 0.0
|
|
359
|
+
weights.append(weight)
|
|
360
|
+
|
|
361
|
+
return self._calculate_step_vector_from_vectors_and_weights(vectors, weights)
|
|
362
|
+
|
|
363
|
+
def _calculate_step_vector_from_vectors_and_weights(
|
|
364
|
+
self, vectors: List[Coordinate], weights: List[float]
|
|
365
|
+
) -> List[float]:
|
|
366
|
+
step_vector = [0.0] * self._config.get_num_dimensions()
|
|
367
|
+
dim_sum_vector = [0.0] * self._config.get_num_dimensions()
|
|
368
|
+
|
|
369
|
+
# For each dimension -
|
|
370
|
+
# if non zero, add weight (inverting if dimension is negative)
|
|
371
|
+
# divide by sum of coordinate of that dimension
|
|
372
|
+
for vector, weight in zip(vectors, weights):
|
|
373
|
+
for dim, v in enumerate(vector):
|
|
374
|
+
if v:
|
|
375
|
+
if v > 0:
|
|
376
|
+
step_vector[dim] += weight
|
|
377
|
+
dim_sum_vector[dim] += v
|
|
378
|
+
else:
|
|
379
|
+
step_vector[dim] -= weight
|
|
380
|
+
dim_sum_vector[dim] -= v
|
|
381
|
+
|
|
382
|
+
for dim, v in enumerate(dim_sum_vector):
|
|
383
|
+
if v:
|
|
384
|
+
step_vector[dim] /= v
|
|
385
|
+
|
|
386
|
+
return step_vector
|
|
387
|
+
|
|
388
|
+
def _get_all_measurements(
|
|
389
|
+
self,
|
|
390
|
+
) -> Tuple[List[Coordinate], List[RunConfigMeasurement]]:
|
|
391
|
+
"""
|
|
392
|
+
Gather all the visited vectors (directions from the home coordinate)
|
|
393
|
+
and their corresponding measurements.
|
|
394
|
+
|
|
395
|
+
Returns
|
|
396
|
+
-------
|
|
397
|
+
(vectors, measurements)
|
|
398
|
+
collection of vectors and their measurements.
|
|
399
|
+
"""
|
|
400
|
+
coordinates = self._get_coordinates_with_valid_measurements()
|
|
401
|
+
|
|
402
|
+
vectors = []
|
|
403
|
+
measurements = []
|
|
404
|
+
for coordinate in coordinates:
|
|
405
|
+
measurement = self._coordinate_data.get_measurement(coordinate)
|
|
406
|
+
if measurement:
|
|
407
|
+
vectors.append(coordinate - self._home_coordinate)
|
|
408
|
+
measurements.append(measurement)
|
|
409
|
+
return vectors, measurements
|
|
410
|
+
|
|
411
|
+
def _get_measurements_passing_constraints(
|
|
412
|
+
self,
|
|
413
|
+
) -> Tuple[List[Coordinate], List[RunConfigMeasurement]]:
|
|
414
|
+
"""
|
|
415
|
+
Gather all the vectors (directions from the home coordinate)
|
|
416
|
+
and their corresponding measurements that are passing constraints.
|
|
417
|
+
|
|
418
|
+
Returns
|
|
419
|
+
-------
|
|
420
|
+
(vectors, measurements)
|
|
421
|
+
collection of vectors and their measurements.
|
|
422
|
+
"""
|
|
423
|
+
coordinates = self._get_coordinates_with_valid_measurements()
|
|
424
|
+
|
|
425
|
+
vectors = []
|
|
426
|
+
measurements = []
|
|
427
|
+
for coordinate in coordinates:
|
|
428
|
+
measurement = self._coordinate_data.get_measurement(coordinate)
|
|
429
|
+
if measurement and measurement.is_passing_constraints():
|
|
430
|
+
vectors.append(coordinate - self._home_coordinate)
|
|
431
|
+
measurements.append(measurement)
|
|
432
|
+
return vectors, measurements
|
|
433
|
+
|
|
434
|
+
def _is_coordinate_measured(self, coordinate: Coordinate) -> bool:
|
|
435
|
+
return self._coordinate_data.is_measured(coordinate)
|
|
436
|
+
|
|
437
|
+
def _clamp_coordinate_to_bounds(self, coordinate: Coordinate) -> Coordinate:
|
|
438
|
+
clamped_coordinate = deepcopy(coordinate)
|
|
439
|
+
|
|
440
|
+
for i, v in enumerate(coordinate):
|
|
441
|
+
sd = self._config.get_dimension(i)
|
|
442
|
+
|
|
443
|
+
v = min(sd.get_max_idx(), v)
|
|
444
|
+
v = max(sd.get_min_idx(), v)
|
|
445
|
+
clamped_coordinate[i] = v
|
|
446
|
+
return clamped_coordinate
|
|
447
|
+
|
|
448
|
+
def _get_covered_values_per_dimension(self) -> List[Dict[Coordinate, bool]]:
|
|
449
|
+
"""
|
|
450
|
+
Returns a list of dicts that indicates which values have been
|
|
451
|
+
covered in each dimension.
|
|
452
|
+
|
|
453
|
+
(e.g.)
|
|
454
|
+
covered_values_per_dimension[dimension][value] = bool
|
|
455
|
+
"""
|
|
456
|
+
measured_coordinates = self._get_coordinates_with_valid_measurements()
|
|
457
|
+
|
|
458
|
+
covered_values_per_dimension: List[Dict[Coordinate, bool]] = [
|
|
459
|
+
{} for _ in range(self._config.get_num_dimensions())
|
|
460
|
+
]
|
|
461
|
+
|
|
462
|
+
for coordinate in measured_coordinates:
|
|
463
|
+
for i, v in enumerate(coordinate):
|
|
464
|
+
covered_values_per_dimension[i][v] = True
|
|
465
|
+
|
|
466
|
+
return covered_values_per_dimension
|
|
467
|
+
|
|
468
|
+
def _get_num_uncovered_values(
|
|
469
|
+
self,
|
|
470
|
+
coordinate: Coordinate,
|
|
471
|
+
covered_values_per_dimension: List[Dict[Coordinate, bool]],
|
|
472
|
+
) -> int:
|
|
473
|
+
"""
|
|
474
|
+
Determine how many of the coordinate dimensions in the input coordinate have values
|
|
475
|
+
that are not covered in covered_values_per_dimension
|
|
476
|
+
"""
|
|
477
|
+
num_uncovered = 0
|
|
478
|
+
|
|
479
|
+
for i, v in enumerate(coordinate):
|
|
480
|
+
if not covered_values_per_dimension[i].get(v, False):
|
|
481
|
+
num_uncovered += 1
|
|
482
|
+
|
|
483
|
+
return num_uncovered
|
|
484
|
+
|
|
485
|
+
def _is_slow_mode(self) -> bool:
|
|
486
|
+
if self._force_slow_mode:
|
|
487
|
+
return True
|
|
488
|
+
|
|
489
|
+
if not self._is_home_measured():
|
|
490
|
+
return False
|
|
491
|
+
|
|
492
|
+
passing_vectors, _ = self._get_measurements_passing_constraints()
|
|
493
|
+
all_vectors, _ = self._get_all_measurements()
|
|
494
|
+
|
|
495
|
+
any_failing = len(all_vectors) != len(passing_vectors)
|
|
496
|
+
any_passing = len(passing_vectors) != 0
|
|
497
|
+
home_passing = self._is_home_passing_constraints()
|
|
498
|
+
|
|
499
|
+
return (home_passing and any_failing) or (not home_passing and any_passing)
|
|
500
|
+
|
|
501
|
+
def _are_all_adjacent_neighbors_measured(self) -> bool:
    """
    Return True when every coordinate from _get_all_adjacent_neighbors()
    has been measured.

    Checking stops at the first unmeasured neighbor. Returns True when
    there are no adjacent neighbors at all.
    """
    return all(
        self._is_coordinate_measured(neighbor)
        for neighbor in self._get_all_adjacent_neighbors()
    )
|
|
506
|
+
|
|
507
|
+
def _get_all_adjacent_neighbors(self) -> List[Coordinate]:
    """
    Build the coordinates that differ from the home coordinate by exactly
    one step (-1 or +1) in a single dimension.

    Dimensions are visited in order. For each one, the "down" neighbor comes
    first and is kept only if it is not below the dimension's min index.
    The "up" neighbor follows and is kept only if it is not above the
    dimension's max index.
    """
    neighbors: List[Coordinate] = []

    for dim_index in range(self._config.get_num_dimensions()):
        dim_spec = self._config.get_dimension(dim_index)

        for step in (-1, 1):
            # A new Coordinate is built from home for every candidate
            candidate = Coordinate(self._home_coordinate)
            candidate[dim_index] += step

            # Only the bound in the direction of the step is checked
            if step < 0:
                within_bounds = candidate[dim_index] >= dim_spec.get_min_idx()
            else:
                within_bounds = candidate[dim_index] <= dim_spec.get_max_idx()

            if within_bounds:
                neighbors.append(candidate)

    return neighbors
|
|
524
|
+
|
|
525
|
+
def _get_home_measurement(self) -> Optional[RunConfigMeasurement]:
    """
    Look up the measurement stored in the coordinate data for the home
    coordinate (None when nothing has been recorded for it).
    """
    home = self._home_coordinate
    return self._coordinate_data.get_measurement(coordinate=home)
|
|
527
|
+
|
|
528
|
+
def _is_home_measured(self) -> bool:
    """Return True once a measurement exists for the home coordinate."""
    measurement = self._get_home_measurement()
    return measurement is not None
|
|
530
|
+
|
|
531
|
+
def _is_home_passing_constraints(self) -> bool:
    """
    Return whether the home coordinate's measurement passes the constraints.

    Raises
    ------
    Exception
        If the home coordinate has not been measured yet
    """
    home_measurement = self._get_home_measurement()

    # Compare against None explicitly (as _is_home_measured does) rather than
    # relying on the measurement object's truthiness
    if home_measurement is None:
        raise Exception("Can't check home passing if it isn't measured yet")

    return home_measurement.is_passing_constraints()
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
#!/usr/bin/env python3
|
|
2
|
+
|
|
3
|
+
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
4
|
+
#
|
|
5
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
6
|
+
# you may not use this file except in compliance with the License.
|
|
7
|
+
# You may obtain a copy of the License at
|
|
8
|
+
#
|
|
9
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
10
|
+
#
|
|
11
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
12
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
13
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
14
|
+
# See the License for the specific language governing permissions and
|
|
15
|
+
# limitations under the License.
|
|
16
|
+
|
|
17
|
+
import logging
|
|
18
|
+
from copy import deepcopy
|
|
19
|
+
from typing import Dict, Generator, List, Optional
|
|
20
|
+
|
|
21
|
+
from model_analyzer.config.generate.concurrency_sweeper import ConcurrencySweeper
|
|
22
|
+
from model_analyzer.config.generate.model_profile_spec import ModelProfileSpec
|
|
23
|
+
from model_analyzer.config.generate.model_variant_name_manager import (
|
|
24
|
+
ModelVariantNameManager,
|
|
25
|
+
)
|
|
26
|
+
from model_analyzer.config.generate.optuna_run_config_generator import (
|
|
27
|
+
OptunaRunConfigGenerator,
|
|
28
|
+
)
|
|
29
|
+
from model_analyzer.config.generate.search_parameters import SearchParameters
|
|
30
|
+
from model_analyzer.config.input.config_command_profile import ConfigCommandProfile
|
|
31
|
+
from model_analyzer.config.run.run_config import RunConfig
|
|
32
|
+
from model_analyzer.constants import LOGGER_NAME
|
|
33
|
+
from model_analyzer.result.parameter_search import ParameterSearch
|
|
34
|
+
from model_analyzer.result.result_manager import ResultManager
|
|
35
|
+
from model_analyzer.result.run_config_measurement import RunConfigMeasurement
|
|
36
|
+
from model_analyzer.state.analyzer_state_manager import AnalyzerStateManager
|
|
37
|
+
|
|
38
|
+
from .config_generator_interface import ConfigGeneratorInterface
|
|
39
|
+
|
|
40
|
+
# Logger named by model_analyzer.constants.LOGGER_NAME; get_configs() below
# uses it for its progress messages
logger = logging.getLogger(LOGGER_NAME)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class OptunaPlusConcurrencySweepRunConfigGenerator(ConfigGeneratorInterface):
    """
    First run OptunaRunConfigGenerator for an Optuna search, then (unless
    concurrency_sweep_disable is set in the config) use ConcurrencySweeper
    to gather concurrency sweep measurements for the reports
    """

    def __init__(
        self,
        config: ConfigCommandProfile,
        state_manager: AnalyzerStateManager,
        gpu_count: int,
        models: List[ModelProfileSpec],
        composing_models: List[ModelProfileSpec],
        result_manager: ResultManager,
        model_variant_name_manager: ModelVariantNameManager,
        search_parameters: Dict[str, SearchParameters],
        composing_search_parameters: Dict[str, SearchParameters],
    ):
        """
        Parameters
        ----------
        config: ConfigCommandProfile
            Profile configuration information
        state_manager: AnalyzerStateManager
            The object that allows control and update of checkpoint state
        gpu_count: int
            Number of gpus in the system
        models: List of ModelProfileSpec
            List of models to profile
        composing_models: List of ModelProfileSpec
            List of composing models that exist inside of the supplied models
        result_manager: ResultManager
            The object that handles storing and sorting the results from the perf analyzer
        model_variant_name_manager: ModelVariantNameManager
            Maps model variants to config names
        search_parameters: Dict[str, SearchParameters]
            The user's configuration search parameters
            (str-keyed; presumably by model name - confirm)
        composing_search_parameters: Dict[str, SearchParameters]
            The user's configuration search parameters for composing models
            (str-keyed; presumably by model name - confirm)
        """
        self._config = config
        self._state_manager = state_manager
        self._gpu_count = gpu_count
        self._models = models
        self._composing_models = composing_models
        self._result_manager = result_manager
        self._model_variant_name_manager = model_variant_name_manager
        self._search_parameters = search_parameters
        self._composing_search_parameters = composing_search_parameters

        # The Optuna generator is only created by _execute_optuna_search() once
        # get_configs() starts running. Declare both attributes here so they
        # always exist on the instance.
        self._rcg: Optional[ConfigGeneratorInterface] = None
        self._last_measurement: Optional[RunConfigMeasurement] = None

    def set_last_results(
        self, measurements: List[Optional[RunConfigMeasurement]]
    ) -> None:
        """
        Record the most recent measurement and forward all measurements to
        the Optuna run config generator.

        NOTE(review): _rcg is never reset after the Optuna search, so results
        gathered during the concurrency sweep are also forwarded to the Optuna
        generator - confirm this is intended.

        Raises
        ------
        RuntimeError
            If called before get_configs() has created the Optuna generator
        """
        if self._rcg is None:
            raise RuntimeError(
                "set_last_results() was called before get_configs() "
                "started the Optuna search"
            )

        self._last_measurement = measurements[-1]
        self._rcg.set_last_results(measurements)

    def get_configs(self) -> Generator[RunConfig, None, None]:
        """
        Returns
        -------
        RunConfig
            The next RunConfig generated by this class
        """

        logger.info("")
        logger.info("Starting Optuna mode search to find optimal configs")
        logger.info("")
        yield from self._execute_optuna_search()
        logger.info("")
        if self._config.concurrency_sweep_disable:
            logger.info("Done with Optuna mode search.")
        else:
            logger.info(
                "Done with Optuna mode search. Gathering concurrency sweep measurements for reports"
            )
            logger.info("")
            yield from ConcurrencySweeper(
                config=self._config, result_manager=self._result_manager
            ).get_configs()
            logger.info("")
            logger.info("Done gathering concurrency sweep measurements for reports")
            logger.info("")

    def _execute_optuna_search(self) -> Generator[RunConfig, None, None]:
        """
        Create the Optuna generator, keep it for set_last_results(), and
        yield every config it produces
        """
        rcg = self._create_optuna_run_config_generator()
        self._rcg = rcg

        yield from rcg.get_configs()

    def _create_optuna_run_config_generator(self) -> OptunaRunConfigGenerator:
        """Build an OptunaRunConfigGenerator from this object's stored inputs"""
        return OptunaRunConfigGenerator(
            config=self._config,
            state_manager=self._state_manager,
            gpu_count=self._gpu_count,
            models=self._models,
            composing_models=self._composing_models,
            model_variant_name_manager=self._model_variant_name_manager,
            search_parameters=self._search_parameters,
            composing_search_parameters=self._composing_search_parameters,
        )
|