oracle-ads 2.13.17rc0__py3-none-any.whl → 2.13.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ads/aqua/cli.py +7 -5
- ads/aqua/common/entities.py +88 -29
- ads/aqua/common/enums.py +7 -0
- ads/aqua/common/errors.py +5 -0
- ads/aqua/common/utils.py +87 -7
- ads/aqua/constants.py +3 -0
- ads/aqua/extension/deployment_handler.py +36 -0
- ads/aqua/modeldeployment/config_loader.py +10 -0
- ads/aqua/modeldeployment/constants.py +1 -0
- ads/aqua/modeldeployment/deployment.py +99 -22
- ads/aqua/modeldeployment/entities.py +4 -0
- ads/aqua/resources/gpu_shapes_index.json +315 -26
- ads/aqua/shaperecommend/__init__.py +6 -0
- ads/aqua/shaperecommend/constants.py +116 -0
- ads/aqua/shaperecommend/estimator.py +384 -0
- ads/aqua/shaperecommend/llm_config.py +283 -0
- ads/aqua/shaperecommend/recommend.py +493 -0
- ads/aqua/shaperecommend/shape_report.py +233 -0
- ads/aqua/version.json +1 -1
- ads/cli.py +9 -1
- ads/jobs/builders/infrastructure/dsc_job.py +1 -0
- ads/jobs/builders/infrastructure/dsc_job_runtime.py +9 -1
- ads/model/service/oci_datascience_model_deployment.py +46 -19
- ads/opctl/operator/lowcode/common/data.py +7 -2
- ads/opctl/operator/lowcode/common/transformations.py +207 -0
- ads/opctl/operator/lowcode/common/utils.py +8 -0
- ads/opctl/operator/lowcode/forecast/__init__.py +3 -0
- ads/opctl/operator/lowcode/forecast/__main__.py +53 -3
- ads/opctl/operator/lowcode/forecast/const.py +2 -0
- ads/opctl/operator/lowcode/forecast/errors.py +5 -0
- ads/opctl/operator/lowcode/forecast/meta_selector.py +310 -0
- ads/opctl/operator/lowcode/forecast/model/automlx.py +1 -1
- ads/opctl/operator/lowcode/forecast/model/base_model.py +119 -30
- ads/opctl/operator/lowcode/forecast/model/factory.py +33 -2
- ads/opctl/operator/lowcode/forecast/model/forecast_datasets.py +54 -17
- ads/opctl/operator/lowcode/forecast/model_evaluator.py +6 -1
- ads/opctl/operator/lowcode/forecast/schema.yaml +1 -0
- ads/pipeline/ads_pipeline.py +13 -9
- {oracle_ads-2.13.17rc0.dist-info → oracle_ads-2.13.18.dist-info}/METADATA +1 -1
- {oracle_ads-2.13.17rc0.dist-info → oracle_ads-2.13.18.dist-info}/RECORD +43 -36
- {oracle_ads-2.13.17rc0.dist-info → oracle_ads-2.13.18.dist-info}/WHEEL +0 -0
- {oracle_ads-2.13.17rc0.dist-info → oracle_ads-2.13.18.dist-info}/entry_points.txt +0 -0
- {oracle_ads-2.13.17rc0.dist-info → oracle_ads-2.13.18.dist-info}/licenses/LICENSE.txt +0 -0
@@ -0,0 +1,493 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
# Copyright (c) 2025 Oracle and/or its affiliates.
|
3
|
+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
|
4
|
+
|
5
|
+
import shutil
|
6
|
+
from typing import List, Union
|
7
|
+
|
8
|
+
from pydantic import ValidationError
|
9
|
+
from rich.table import Table
|
10
|
+
|
11
|
+
from ads.aqua.app import logger
|
12
|
+
from ads.aqua.common.entities import ComputeShapeSummary
|
13
|
+
from ads.aqua.common.errors import (
|
14
|
+
AquaFileNotFoundError,
|
15
|
+
AquaRecommendationError,
|
16
|
+
AquaValueError,
|
17
|
+
)
|
18
|
+
from ads.aqua.common.utils import (
|
19
|
+
build_pydantic_error_message,
|
20
|
+
get_resource_type,
|
21
|
+
load_config,
|
22
|
+
load_gpu_shapes_index,
|
23
|
+
)
|
24
|
+
from ads.aqua.shaperecommend.constants import (
|
25
|
+
BITS_AND_BYTES_4BIT,
|
26
|
+
BITSANDBYTES,
|
27
|
+
SAFETENSORS,
|
28
|
+
SHAPE_MAP,
|
29
|
+
TEXT_GENERATION,
|
30
|
+
TROUBLESHOOT_MSG,
|
31
|
+
)
|
32
|
+
from ads.aqua.shaperecommend.estimator import get_estimator
|
33
|
+
from ads.aqua.shaperecommend.llm_config import LLMConfig
|
34
|
+
from ads.aqua.shaperecommend.shape_report import (
|
35
|
+
ModelConfig,
|
36
|
+
RequestRecommend,
|
37
|
+
ShapeRecommendationReport,
|
38
|
+
ShapeReport,
|
39
|
+
)
|
40
|
+
from ads.model.datascience_model import DataScienceModel
|
41
|
+
from ads.model.service.oci_datascience_model_deployment import (
|
42
|
+
OCIDataScienceModelDeployment,
|
43
|
+
)
|
44
|
+
|
45
|
+
|
46
|
+
class AquaShapeRecommend:
    """
    Interface for recommending GPU shapes for machine learning model deployments
    on Oracle Cloud Infrastructure Data Science service.

    This class provides methods to recommend deployment shapes based on a model's requirements,
    handle recommendation details and troubleshooting, and retrieve specific OCI Machine Learning shapes.
    Must be used within a properly configured and authenticated OCI environment.
    """

    def which_shapes(
        self, request: RequestRecommend
    ) -> Union[ShapeRecommendationReport, Table]:
        """
        Lists valid GPU deployment shapes for the provided model and configuration.

        Validates the model OCID, loads the model's config.json, enumerates the known GPU
        shapes (flagging which are available in the compartment) and summarizes which
        shapes can host the model.

        Parameters
        ----------
        request : RequestRecommend
            The recommendation request. Fields read here:
            - ``model_id``: OCID of the Data Science model to evaluate.
            - ``compartment_id``: compartment used to look up available deployment shapes.
            - ``generate_table``: when truthy, return a rich ``Table`` instead of the report
              (only when the report contains at least one recommendation).

        Returns
        -------
        Table
            When ``request.generate_table`` is truthy and the report has recommendations.
        ShapeRecommendationReport
            Otherwise. If the model is not supported (``AquaRecommendationError``), a report
            with no recommendations and the error text in ``troubleshoot`` is returned.

        Raises
        ------
        AquaValueError
            If the model OCID is invalid, the model artifact or its config.json cannot be
            read or parsed, or no shapes are available for evaluation.
        """
        try:
            # Validate the OCID before issuing the shape-listing service call so an
            # obviously wrong OCID fails fast.
            ds_model = self._validate_model_ocid(request.model_id)
            data = self._get_model_config(ds_model)
            llm_config = LLMConfig.from_raw_config(data)

            shapes = self.valid_compute_shapes(compartment_id=request.compartment_id)

            shape_recommendation_report = self._summarize_shapes_for_seq_lens(
                llm_config, shapes, ds_model.display_name or ""
            )

        # Model incompatibility is reported back to the caller rather than raised.
        except AquaRecommendationError as error:
            return ShapeRecommendationReport(
                recommendations=[], troubleshoot=str(error)
            )

        except ValidationError as ex:
            custom_errors = build_pydantic_error_message(ex)
            raise AquaValueError(
                f"Invalid parameters to read config.json of LLM Artifact. Error details: {custom_errors}."
            ) from ex

        except AquaValueError as ex:
            logger.error(f"Error with LLM config: {ex}")
            raise AquaValueError(
                f"An error occured while producing recommendations: {ex}"
            ) from ex

        if request.generate_table and shape_recommendation_report.recommendations:
            return self._rich_diff_table(shape_recommendation_report)

        return shape_recommendation_report

    def valid_compute_shapes(self, compartment_id: str) -> List["ComputeShapeSummary"]:
        """
        Builds the list of GPU shapes to evaluate, flagging which are available to the user.

        Every shape in the bundled GPU shapes index (``load_gpu_shapes_index``) is returned.
        Shapes that ``OCIDataScienceModelDeployment.shapes`` also reports for the compartment
        are marked ``available=True`` and carry the OCI core count, memory and shape series.
        All others are marked ``available=False`` with shape series "GPU".

        Parameters
        ----------
        compartment_id : str
            OCID of the compartment whose model deployment shapes are queried.

        Returns
        -------
        List[ComputeShapeSummary]
            GPU shapes sorted by ``gpu_specs.gpu_memory_in_gbs`` in descending order.
        """
        oci_shapes = OCIDataScienceModelDeployment.shapes(compartment_id=compartment_id)
        user_shapes = {shape.name: shape for shape in oci_shapes}

        gpu_shapes_metadata = load_gpu_shapes_index().shapes

        valid_shapes = []
        # Only GPU shapes are considered for now; CPU shapes may be added later.
        for name, spec in gpu_shapes_metadata.items():
            if name in user_shapes:
                oci_shape = user_shapes[name]
                compute_shape = ComputeShapeSummary(
                    available=True,
                    core_count=oci_shape.core_count,
                    memory_in_gbs=oci_shape.memory_in_gbs,
                    shape_series=SHAPE_MAP.get(oci_shape.shape_series, "GPU"),
                    name=oci_shape.name,
                    gpu_specs=spec,
                )
            else:
                compute_shape = ComputeShapeSummary(
                    available=False, name=name, shape_series="GPU", gpu_specs=spec
                )
            valid_shapes.append(compute_shape)

        # Largest-first ordering is relied on by the troubleshooting logic.
        valid_shapes.sort(
            key=lambda shape: shape.gpu_specs.gpu_memory_in_gbs, reverse=True
        )
        return valid_shapes

    @staticmethod
    def _rich_diff_table(shape_report: ShapeRecommendationReport) -> Table:
        """
        Generates a rich-formatted table comparing deployment recommendations
        from a ShapeRecommendationReport object.

        Only the first configuration of each recommendation entry is rendered.

        Args:
            shape_report (ShapeRecommendationReport): The report containing shape recommendations.

        Returns:
            Table: A rich Table displaying model deployment recommendations.
        """
        logger.debug(
            "Starting to generate rich diff table from ShapeRecommendationReport."
        )

        name = shape_report.display_name
        header = (
            f"Model Deployment Recommendations: {name}"
            if name
            else "Model Deployment Recommendations"
        )
        logger.debug(f"Table header set to: {header!r}")

        if shape_report.troubleshoot:
            header = f"{header}\n{shape_report.troubleshoot}"
            logger.debug("Appended troubleshoot message to the header.")

        term_columns = shutil.get_terminal_size((120, 20)).columns

        # The other columns take roughly 50 characters. On terminals narrower than that,
        # the computed width would be zero or negative, so keep a usable minimum.
        recs_width = max(20, min(term_columns - 50, 60))
        logger.debug(f"Calculated recommendation column width: {recs_width}")

        table = Table(
            title=header,
            show_lines=True,
        )
        logger.debug("Initialized Table object.")

        table.add_column("Shape Name", max_width=16)
        table.add_column("Available", max_width=7)
        table.add_column("Shape Type", max_width=7)
        table.add_column("GPU Count", max_width=7)
        table.add_column("Total Memory (GB)", max_width=10)
        table.add_column("Model Deployment Size (GB)", max_width=7)
        table.add_column("Deployment Quantization", max_width=10)
        table.add_column("Recommendation", max_width=recs_width)
        logger.debug("Added table columns with specified max widths.")

        recs = getattr(shape_report, "recommendations", None) or []
        logger.debug(f"Number of recommendations: {len(recs)}")

        for entry in recs:
            shape = entry.shape_details
            gpu = shape.gpu_specs
            # ShapeReports built by this class hold a single configuration.
            conf = entry.configurations[0]
            model = conf.model_details
            deploy = conf.deployment_params
            recommendation = conf.recommendation

            if deploy.params:
                recommendation = (
                    f"Suggested PARAMS: {deploy.params}\n\n" + recommendation
                )

            if gpu.gpu_memory_in_gbs and shape.memory_in_gbs:
                total_memory = (
                    f"GPU: {gpu.gpu_memory_in_gbs}\nCPU: {shape.memory_in_gbs}"
                )
            elif gpu.gpu_memory_in_gbs:
                total_memory = f"GPU: {gpu.gpu_memory_in_gbs}"
            else:
                total_memory = f"CPU: {shape.memory_in_gbs}"

            table.add_row(
                shape.name,
                str(shape.available),
                str(shape.shape_series),
                str(gpu.gpu_count),
                total_memory,
                str(model.total_model_gb),
                deploy.quantization,
                recommendation,
            )

        logger.debug("Completed populating table with recommendation rows.")
        return table

    @staticmethod
    def _validate_model_ocid(ocid: str) -> DataScienceModel:
        """
        Ensures the OCID passed is valid for referencing a DataScienceModel resource
        and loads that model.

        Raises
        ------
        AquaValueError
            If ``get_resource_type`` does not identify the OCID as a "datasciencemodel".
        """
        resource_type = get_resource_type(ocid)

        if resource_type != "datasciencemodel":
            raise AquaValueError(
                f"The provided OCID '{ocid}' is not a valid Oracle Cloud Data Science Model OCID. "
                "Please provide an OCID corresponding to a Data Science model resource. "
                "Tip: Data Science model OCIDs typically start with 'ocid1.datasciencemodel...'."
            )

        return DataScienceModel.from_id(ocid)

    @staticmethod
    def _get_model_config(model: DataScienceModel):
        """
        Loads config.json for a given Oracle Cloud Data Science model.

        Checks the model's ``task`` and ``model_format`` freeform tags (case-insensitively)
        to ensure it is a text-generation model in safetensors format, then loads
        config.json from the model artifact.

        Parameters
        ----------
        model : DataScienceModel
            The DataScienceModel representation of the model used in recommendations.

        Returns
        -------
        dict
            The parsed configuration from config.json (as returned by ``load_config``).

        Raises
        ------
        AquaValueError
            If the model has no artifact.
        AquaRecommendationError
            If the model is not a text-generation model, is not in safetensors format,
            or its artifact has no config.json.
        """
        # freeform_tags may be unset on a model; treat that as "no tags".
        tags = model.freeform_tags or {}
        model_task = (tags.get("task") or "").lower()
        model_format = (tags.get("model_format") or "").lower()

        logger.info(f"Current model task type: {model_task}")
        logger.info(f"Current model format: {model_format}")

        if TEXT_GENERATION not in model_task:
            raise AquaRecommendationError(
                "Please provide a decoder-only text-generation model (ex. Llama, Falcon, etc.). "
                f"Only text-generation models are supported in this tool at this time. Current model task type: {model_task}"
            )
        if SAFETENSORS not in model_format:
            msg = "Please provide a model in Safetensor format. "
            if model_format:
                msg += f"The current model format ({model_format}) is not supported by this tool at this time."

            raise AquaRecommendationError(msg)

        if not model.artifact:
            raise AquaValueError(
                "Unable to retrieve model artifact. Ensure model is registered and active."
            )

        try:
            return load_config(model.artifact, "config.json")
        except AquaFileNotFoundError as e:
            logger.error(
                f"config.json not found in model artifact at {model.artifact}: {e}"
            )
            raise AquaRecommendationError(
                "The configuration file 'config.json' was not found in the specified model directory. "
                "Please ensure your model follows the Hugging Face format and includes a 'config.json' with the necessary architecture parameters."
            ) from e

    @staticmethod
    def _build_shape_report(
        shape: ComputeShapeSummary, estimator, allowed_gpu_memory
    ) -> ShapeReport:
        """Wrap the single ModelConfig derived from ``estimator`` in a ShapeReport for ``shape``."""
        best_config = [
            ModelConfig.constuct_model_config(estimator, allowed_gpu_memory)
        ]
        return ShapeReport(shape_details=shape, configurations=best_config)

    @staticmethod
    def _recommend_prequantized(
        config: LLMConfig, shapes: List[ComputeShapeSummary], batch_size: int
    ) -> List[ShapeReport]:
        """
        For a pre-quantized model, record for each compatible shape the first sequence
        length from ``config.calculate_possible_seq_len()`` whose estimate fits in the
        shape's GPU memory.
        """
        recommendations = []
        deployment_config = config.calculate_possible_seq_len()
        for shape in shapes:
            # Skip shapes that do not list support for the model's quantization type.
            if config.quantization_type not in set(shape.gpu_specs.quantization):
                continue
            allowed_gpu_memory = shape.gpu_specs.gpu_memory_in_gbs
            for max_seq_len in deployment_config:
                estimator = get_estimator(
                    llm_config=config,
                    seq_len=max_seq_len,
                    batch_size=batch_size,
                )
                if estimator.validate_shape(allowed_gpu_memory):
                    recommendations.append(
                        AquaShapeRecommend._build_shape_report(
                            shape, estimator, allowed_gpu_memory
                        )
                    )
                    break
        return recommendations

    @staticmethod
    def _recommend_in_flight(
        config: LLMConfig, shapes: List[ComputeShapeSummary], batch_size: int
    ) -> List[ShapeReport]:
        """
        For an unquantized model, record for each shape the first
        ``(quantization, max_seq_len)`` pair from ``config.optimal_config()`` whose estimate
        fits in the shape's GPU memory. Pairs with in-flight 4-bit quantization are skipped
        on shapes that do not list bitsandbytes support.
        """
        recommendations = []
        deployment_config = config.optimal_config()
        # One config copy per in-flight quantization value, built on first use. This also
        # covers a first entry whose quantization is None.
        quantized_configs = {}
        for shape in shapes:
            shape_quantization = set(shape.gpu_specs.quantization)
            allowed_gpu_memory = shape.gpu_specs.gpu_memory_in_gbs
            for quantization, max_seq_len in deployment_config:
                if (
                    quantization == BITS_AND_BYTES_4BIT
                    and BITSANDBYTES not in shape_quantization
                ):
                    continue
                if quantization not in quantized_configs:
                    quantized_configs[quantization] = config.model_copy(
                        update={"in_flight_quantization": quantization}
                    )
                estimator = get_estimator(
                    llm_config=quantized_configs[quantization],
                    seq_len=max_seq_len,
                    batch_size=batch_size,
                )
                if estimator.validate_shape(allowed_gpu_memory):
                    recommendations.append(
                        AquaShapeRecommend._build_shape_report(
                            shape, estimator, allowed_gpu_memory
                        )
                    )
                    break
        return recommendations

    @staticmethod
    def _troubleshoot_recommendations(
        config: LLMConfig, shapes: List[ComputeShapeSummary], batch_size: int
    ) -> List[ShapeReport]:
        """
        Build illustrative reports for the two largest shapes when nothing fits.

        Assumes ``shapes`` is sorted largest-first (as returned by ``valid_compute_shapes``).
        The largest shape is evaluated with ``quantization="fp8"``. The second largest is
        evaluated with in-flight ``"4bit"`` quantization. Both use a 2048-token sequence
        length and 90% of the shape's GPU memory. Returns an empty list when fewer than
        two shapes are given.
        """
        if len(shapes) < 2:
            return []

        reports = []
        # (shape, quantization, apply as in-flight quantization?)
        candidates = [(shapes[0], "fp8", False), (shapes[1], "4bit", True)]
        for shape, quantization, in_flight in candidates:
            field = "in_flight_quantization" if in_flight else "quantization"
            updated_config = config.model_copy(update={field: quantization})
            estimator = get_estimator(
                llm_config=updated_config, seq_len=2048, batch_size=batch_size
            )
            allowed_gpu_memory = shape.gpu_specs.gpu_memory_in_gbs * 0.9
            reports.append(
                AquaShapeRecommend._build_shape_report(
                    shape, estimator, allowed_gpu_memory
                )
            )
        return reports

    @staticmethod
    def _summarize_shapes_for_seq_lens(
        config: LLMConfig,
        shapes: List[ComputeShapeSummary],
        name: str,
        batch_size: int = 1,
    ) -> ShapeRecommendationReport:
        """
        Generate a recommendation report for eligible deployment shapes by evaluating
        model memory consumption and maximum model length for given configurations.

        Parameters
        ----------
        config : LLMConfig
            The loaded model configuration.
        shapes : List[ComputeShapeSummary]
            All candidate deployment shapes, expected largest-first.
        name : str
            Name of the model.
        batch_size : int, optional
            Batch size to evaluate (default is 1).

        Returns
        -------
        ShapeRecommendationReport
            Report containing shape recommendations and troubleshooting advice, if any.

        Raises
        ------
        AquaValueError
            If ``shapes`` is empty.

        Notes
        -----
        - If ``config.quantization_type`` is set, only sequence lengths are varied.
          Otherwise the model's ``optimal_config()`` pairs, including in-flight
          quantization, are tried.
        - ``ShapeReport.pareto_front`` is applied when there are more than two recommendations.
        - If nothing fits, ``TROUBLESHOOT_MSG`` is set and illustrative configurations
          for the largest shapes are returned.
        """
        if not shapes:
            raise AquaValueError(
                "No GPU shapes were passed for recommendation. Ensure shape parsing succeeded."
            )

        if config.quantization_type:
            recommendations = AquaShapeRecommend._recommend_prequantized(
                config, shapes, batch_size
            )
        else:
            recommendations = AquaShapeRecommend._recommend_in_flight(
                config, shapes, batch_size
            )

        troubleshoot_msg = ""

        if len(recommendations) > 2:
            recommendations = ShapeReport.pareto_front(recommendations)

        if not recommendations:
            troubleshoot_msg = TROUBLESHOOT_MSG
            recommendations = AquaShapeRecommend._troubleshoot_recommendations(
                config, shapes, batch_size
            )

        return ShapeRecommendationReport(
            display_name=name,
            recommendations=recommendations,
            troubleshoot=troubleshoot_msg,
        )
|