gmicloud 0.1.6__py3-none-any.whl → 0.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,12 +1,17 @@
1
1
  import os
2
2
  import time
3
- from typing import List
3
+ from typing import List, Dict, Any
4
4
  import mimetypes
5
+ import concurrent.futures
6
+ import re
7
+ from tqdm import tqdm
8
+ from tqdm.contrib.logging import logging_redirect_tqdm
5
9
 
6
10
  from .._client._iam_client import IAMClient
7
11
  from .._client._artifact_client import ArtifactClient
8
12
  from .._client._file_upload_client import FileUploadClient
9
13
  from .._models import *
14
+ from .._manager.serve_command_utils import parse_server_command, extract_gpu_num_from_serve_command
10
15
 
11
16
  import logging
12
17
 
@@ -53,7 +58,13 @@ class ArtifactManager:
53
58
  self,
54
59
  artifact_name: str,
55
60
  description: Optional[str] = "",
56
- tags: Optional[List[str]] = None
61
+ tags: Optional[List[str]] = None,
62
+ deployment_type: Optional[str] = "",
63
+ template_id: Optional[str] = "",
64
+ env_parameters: Optional[List["EnvParameter"]] = None,
65
+ model_description: Optional[str] = "",
66
+ model_parameters: Optional[List["ModelParameter"]] = None,
67
+ artifact_volume_path: Optional[str] = "",
57
68
  ) -> CreateArtifactResponse:
58
69
  """
59
70
  Create a new artifact for a user.
@@ -69,11 +80,17 @@ class ArtifactManager:
69
80
 
70
81
  req = CreateArtifactRequest(artifact_name=artifact_name,
71
82
  artifact_description=description,
72
- artifact_tags=tags, )
83
+ artifact_tags=tags,
84
+ deployment_type=deployment_type,
85
+ template_id=template_id,
86
+ env_parameters=env_parameters,
87
+ model_description=model_description,
88
+ model_parameters=model_parameters,
89
+ artifact_volume_path=artifact_volume_path)
73
90
 
74
91
  return self.artifact_client.create_artifact(req)
75
92
 
76
- def create_artifact_from_template(self, artifact_template_id: str) -> str:
93
+ def create_artifact_from_template(self, artifact_template_id: str, env_parameters: Optional[dict[str, str]] = None) -> str:
77
94
  """
78
95
  Create a new artifact for a user using a template.
79
96
 
@@ -85,11 +102,16 @@ class ArtifactManager:
85
102
  if not artifact_template_id or not artifact_template_id.strip():
86
103
  raise ValueError("Artifact template ID is required and cannot be empty.")
87
104
 
105
+
88
106
  resp = self.artifact_client.create_artifact_from_template(artifact_template_id)
89
107
  if not resp or not resp.artifact_id:
90
108
  raise ValueError("Failed to create artifact from template.")
91
109
 
110
+ if env_parameters:
111
+ self.artifact_client.add_env_parameters_to_artifact(resp.artifact_id, env_parameters)
112
+
92
113
  return resp.artifact_id
114
+
93
115
 
94
116
  def create_artifact_from_template_name(self, artifact_template_name: str) -> tuple[str, ReplicaResource]:
95
117
  """
@@ -125,6 +147,70 @@ class ArtifactManager:
125
147
  except Exception as e:
126
148
  logger.error(f"Failed to create artifact from template, Error: {e}")
127
149
  raise e
150
+
151
+ def create_artifact_for_serve_command_and_custom_model(self, template_name: str, artifact_name: str, serve_command: str, gpu_type: str, artifact_description: str = "", pre_download_model: str = "", env_parameters: Optional[Dict[str, Any]] = None) -> tuple[str, ReplicaResource]:
152
+ """
153
+ Create an artifact from a template and support custom model.
154
+ :param template_name: The name of the template to use.
155
+ :return: A tuple containing the artifact ID and the recommended replica resources.
156
+ :rtype: tuple[str, ReplicaResource]
157
+ """
158
+
159
+ recommended_replica_resources = None
160
+ picked_template = None
161
+ try:
162
+ templates = self.get_public_templates()
163
+ except Exception as e:
164
+ logger.error(f"Failed to get artifact templates, Error: {e}")
165
+ for template in templates:
166
+ if template.template_data and template.template_data.name == template_name:
167
+ picked_template = template
168
+ break
169
+ if not picked_template:
170
+ raise ValueError(f"Template with name {template_name} not found.")
171
+
172
+ try:
173
+ if gpu_type not in ["H100", "H200"]:
174
+ raise ValueError("Only support H100 and H200 for now")
175
+
176
+ type, env_vars, serve_args_dict = parse_server_command(serve_command)
177
+ if type.lower() not in template_name.lower():
178
+ raise ValueError(f"Template {template_name} does not support inference with {type}.")
179
+ num_gpus = extract_gpu_num_from_serve_command(serve_args_dict)
180
+ recommended_replica_resources = ReplicaResource(
181
+ cpu=num_gpus * 16,
182
+ ram_gb=num_gpus * 100,
183
+ gpu=num_gpus,
184
+ gpu_name=gpu_type,
185
+ )
186
+ except Exception as e:
187
+ raise ValueError(f"Failed to parse serve command, Error: {e}")
188
+
189
+ try:
190
+ env_vars = []
191
+ if picked_template.template_data and picked_template.template_data.env_parameters:
192
+ env_vars = picked_template.template_data.env_parameters
193
+ env_vars_map = {param.key: param for param in env_vars}
194
+ if env_parameters:
195
+ for key, value in env_parameters.items():
196
+ if key in ['GPU_TYPE', 'SERVE_COMMAND']:
197
+ continue
198
+ if key not in env_vars_map:
199
+ new_param = EnvParameter(key=key, value=value)
200
+ env_vars.append(new_param)
201
+ env_vars_map[key] = new_param
202
+ else:
203
+ env_vars_map[key].value = value
204
+ env_vars.extend([
205
+ EnvParameter(key="SERVE_COMMAND", value=serve_command),
206
+ EnvParameter(key="GPU_TYPE", value=gpu_type),
207
+ ])
208
+ resp = self.create_artifact(artifact_name, artifact_description, deployment_type="template", template_id=picked_template.template_id, env_parameters=env_vars, artifact_volume_path=f"models/{pre_download_model}")
209
+ # Assume Artifact is already with BuildStatus.SUCCESS status
210
+ return resp.artifact_id, recommended_replica_resources
211
+ except Exception as e:
212
+ logger.error(f"Failed to create artifact from template, Error: {e}")
213
+ raise e
128
214
 
129
215
  def rebuild_artifact(self, artifact_id: str) -> RebuildArtifactResponse:
130
216
  """
@@ -211,7 +297,7 @@ class ArtifactManager:
211
297
  model_file_name = os.path.basename(model_file_path)
212
298
  model_file_type = mimetypes.guess_type(model_file_path)[0]
213
299
 
214
- req = GetBigFileUploadUrlRequest(artifact_id=artifact_id, file_name=model_file_name, file_type=model_file_type)
300
+ req = ResumableUploadLinkRequest(artifact_id=artifact_id, file_name=model_file_name, file_type=model_file_type)
215
301
 
216
302
  resp = self.artifact_client.get_bigfile_upload_url(req)
217
303
  if not resp or not resp.upload_link:
@@ -250,36 +336,67 @@ class ArtifactManager:
250
336
 
251
337
  FileUploadClient.upload_large_file(upload_link, file_path)
252
338
 
339
+
340
+ def upload_model_files_to_artifact(self, artifact_id: str, model_directory: str) -> None:
341
+ """
342
+ Upload model files to an existing artifact.
343
+
344
+ :param artifact_id: The ID of the artifact to upload the model files to.
345
+ :param model_directory: The path to the model directory.
346
+ """
347
+
348
+ # List all files in the model directory recursively
349
+ model_file_paths = []
350
+ for root, _, files in os.walk(model_directory):
351
+ # Skip .cache folder
352
+ if '.cache' in root.split(os.path.sep):
353
+ continue
354
+ for file in files:
355
+ model_file_paths.append(os.path.join(root, file))
356
+
357
+ def upload_file(model_file_path):
358
+ self._validate_file_path(model_file_path)
359
+ bigfile_upload_url_resp = self.artifact_client.get_bigfile_upload_url(
360
+ ResumableUploadLinkRequest(artifact_id=artifact_id, file_name=os.path.basename(model_file_path))
361
+ )
362
+ FileUploadClient.upload_large_file(bigfile_upload_url_resp.upload_link, model_file_path)
363
+
364
+ # Upload files in parallel with progress bar
365
+ with tqdm(total=len(model_file_paths), desc="Uploading model files") as progress_bar:
366
+ with logging_redirect_tqdm():
367
+ with concurrent.futures.ThreadPoolExecutor() as executor:
368
+ futures = {executor.submit(upload_file, path): path for path in model_file_paths}
369
+ for future in concurrent.futures.as_completed(futures):
370
+ try:
371
+ future.result()
372
+ except Exception as e:
373
+ logger.error(f"Failed to upload file {futures[future]}, Error: {e}")
374
+ progress_bar.update(1)
375
+
253
376
  def create_artifact_with_model_files(
254
377
  self,
255
378
  artifact_name: str,
256
379
  artifact_file_path: str,
257
- model_file_paths: List[str],
380
+ model_directory: str,
258
381
  description: Optional[str] = "",
259
382
  tags: Optional[str] = None
260
383
  ) -> str:
261
384
  """
262
385
  Create a new artifact for a user and upload model files associated with the artifact.
263
-
264
386
  :param artifact_name: The name of the artifact.
265
387
  :param artifact_file_path: The path to the artifact file(Dockerfile+serve.py).
266
- :param model_file_paths: The paths to the model files.
388
+ :param model_directory: The path to the model directory.
267
389
  :param description: An optional description for the artifact.
268
390
  :param tags: Optional tags associated with the artifact, as a comma-separated string.
269
391
  :return: The `artifact_id` of the created artifact.
270
- :raises FileNotFoundError: If the provided `file_path` does not exist.
271
392
  """
272
393
  artifact_id = self.create_artifact_with_file(artifact_name, artifact_file_path, description, tags)
394
+ logger.info(f"Artifact created: {artifact_id}")
273
395
 
274
- for model_file_path in model_file_paths:
275
- self._validate_file_path(model_file_path)
276
- bigfile_upload_url_resp = self.artifact_client.get_bigfile_upload_url(
277
- GetBigFileUploadUrlRequest(artifact_id=artifact_id, model_file_path=model_file_path)
278
- )
279
- FileUploadClient.upload_large_file(bigfile_upload_url_resp.upload_link, model_file_path)
396
+ self.upload_model_files_to_artifact(artifact_id, model_directory)
280
397
 
281
398
  return artifact_id
282
-
399
+
283
400
 
284
401
  def wait_for_artifact_ready(self, artifact_id: str, timeout_s: int = 900) -> None:
285
402
  """
@@ -295,7 +412,7 @@ class ArtifactManager:
295
412
  artifact = self.get_artifact(artifact_id)
296
413
  if artifact.build_status == BuildStatus.SUCCESS:
297
414
  return
298
- elif artifact.build_status in [BuildStatus.FAILED, BuildStatus.TIMEOUT, BuildStatus.CANCELLED]:
415
+ elif artifact.build_status in [BuildStatus.FAILURE, BuildStatus.TIMEOUT, BuildStatus.CANCELLED]:
299
416
  raise Exception(f"Artifact build failed, status: {artifact.build_status}")
300
417
  except Exception as e:
301
418
  logger.error(f"Failed to get artifact, Error: {e}")
@@ -304,12 +421,12 @@ class ArtifactManager:
304
421
  time.sleep(10)
305
422
 
306
423
 
307
- def get_public_templates(self) -> List[ArtifactTemplate]:
424
+ def get_public_templates(self) -> List[Template]:
308
425
  """
309
426
  Fetch all artifact templates.
310
427
 
311
- :return: A list of ArtifactTemplate objects.
312
- :rtype: List[ArtifactTemplate]
428
+ :return: A list of Template objects.
429
+ :rtype: List[Template]
313
430
  """
314
431
  return self.artifact_client.get_public_templates()
315
432
 
@@ -41,7 +41,7 @@ class TaskManager:
41
41
 
42
42
  :return: A list of `Task` objects.
43
43
  """
44
- resp = self.task_client.get_all_tasks(self.iam_client.get_user_id())
44
+ resp = self.task_client.get_all_tasks()
45
45
  if not resp or not resp.tasks:
46
46
  return []
47
47
 
@@ -63,7 +63,26 @@ class TaskManager:
63
63
  if not resp or not resp.task:
64
64
  raise ValueError("Failed to create task.")
65
65
 
66
+ logger.info(f"Task created: {resp.task.task_id}")
66
67
  return resp.task
68
+
69
+ def create_task_from_artifact_id(self, artifact_id: str, replica_resource: ReplicaResource, task_scheduling: TaskScheduling) -> Task:
70
+ """
71
+ Create a new task using the configuration data from a file.
72
+ """
73
+ # Create Task based on Artifact
74
+ new_task = Task(
75
+ config=TaskConfig(
76
+ ray_task_config=RayTaskConfig(
77
+ artifact_id=artifact_id,
78
+ file_path="serve",
79
+ deployment_name="app",
80
+ replica_resource=replica_resource,
81
+ ),
82
+ task_scheduling = task_scheduling,
83
+ ),
84
+ )
85
+ return self.create_task(new_task).task_id
67
86
 
68
87
  def create_task_from_file(self, artifact_id: str, config_file_path: str, trigger_timestamp: int = None) -> Task:
69
88
  """
@@ -138,48 +157,54 @@ class TaskManager:
138
157
  return self.task_client.start_task(task_id)
139
158
 
140
159
 
141
- def start_task_and_wait(self, task_id: str, timeout_s: int = 900) -> Task:
160
+ def wait_for_task(self, task_id: str, timeout_s: int = 900) -> Task:
142
161
  """
143
- Start a task and wait for it to be ready.
162
+ Wait for a task to reach the RUNNING state or raise an exception if it fails.
144
163
 
145
- :param task_id: The ID of the task to start.
164
+ :param task_id: The ID of the task to wait for.
146
165
  :param timeout_s: The timeout in seconds.
147
166
  :return: The task object.
148
167
  :rtype: Task
149
168
  """
150
- # trigger start task
151
- try:
152
- self.start_task(task_id)
153
- logger.info(f"Started task ID: {task_id}")
154
- except Exception as e:
155
- logger.error(f"Failed to start task, Error: {e}")
156
- raise e
157
-
158
169
  start_time = time.time()
159
170
  while True:
160
171
  try:
161
172
  task = self.get_task(task_id)
162
173
  if task.task_status == TaskStatus.RUNNING:
163
- return task
164
- elif task.task_status in [TaskStatus.NEEDSTOP, TaskStatus.ARCHIVED]:
165
- raise Exception(f"Unexpected task status after starting: {task.task_status}")
166
- # Also check endpoint status.
167
- elif task.task_status == TaskStatus.RUNNING:
168
- if task.endpoint_info and task.endpoint_info.endpoint_status == TaskEndpointStatus.RUNNING:
174
+ if task.endpoint_info is not None and task.endpoint_info.endpoint_status == TaskEndpointStatus.RUNNING:
169
175
  return task
170
- elif task.endpoint_info and task.endpoint_info.endpoint_status in [TaskEndpointStatus.UNKNOWN, TaskEndpointStatus.ARCHIVED]:
171
- raise Exception(f"Unexpected endpoint status after starting: {task.endpoint_info.endpoint_status}")
172
176
  else:
173
- logger.info(f"Pending endpoint starting. endpoint status: {task.endpoint_info.endpoint_status}")
177
+ if task.cluster_endpoints:
178
+ for ce in task.cluster_endpoints:
179
+ if ce.endpoint_status == TaskEndpointStatus.RUNNING:
180
+ return task
181
+ if task.task_status in [TaskStatus.NEEDSTOP, TaskStatus.ARCHIVED]:
182
+ raise Exception(f"Unexpected task status after starting: {task.task_status}")
174
183
  else:
175
184
  logger.info(f"Pending task starting. Task status: {task.task_status}")
176
-
177
185
  except Exception as e:
178
186
  logger.error(f"Failed to get task, Error: {e}")
179
187
  if time.time() - start_time > timeout_s:
180
188
  raise Exception(f"Task creation takes more than {timeout_s // 60} minutes. Testing aborted.")
181
189
  time.sleep(10)
182
190
 
191
+ def start_task_and_wait(self, task_id: str, timeout_s: int = 3600) -> Task:
192
+ """
193
+ Start a task and wait for it to be ready.
194
+
195
+ :param task_id: The ID of the task to start.
196
+ :param timeout_s: The timeout in seconds.
197
+ :return: The task object.
198
+ :rtype: Task
199
+ """
200
+ try:
201
+ self.start_task(task_id)
202
+ logger.info(f"Started task ID: {task_id}")
203
+ except Exception as e:
204
+ logger.error(f"Failed to start task, Error: {e}")
205
+ raise e
206
+
207
+ return self.wait_for_task(task_id, timeout_s)
183
208
 
184
209
  def stop_task(self, task_id: str) -> bool:
185
210
  """
@@ -190,16 +215,15 @@ class TaskManager:
190
215
  :raises ValueError: If `task_id` is invalid (None or empty string).
191
216
  """
192
217
  self._validate_not_empty(task_id, "Task ID")
218
+ return self.task_client.stop_task(task_id)
193
219
 
194
220
 
195
- def stop_task_and_wait(self, task_id: str, timeout_s: int = 900):
196
- task_manager = self.task_manager
221
+ def stop_task_and_wait(self, task_id: str, timeout_s: int = 3600):
197
222
  try:
198
- self.task_manager.stop_task(task_id)
223
+ self.stop_task(task_id)
199
224
  logger.info(f"Stopping task ID: {task_id}")
200
225
  except Exception as e:
201
226
  logger.error(f"Failed to stop task, Error: {e}")
202
- task_manager = self.task_manager
203
227
  start_time = time.time()
204
228
  while True:
205
229
  try:
@@ -212,7 +236,17 @@ class TaskManager:
212
236
  raise Exception(f"Task stopping takes more than {timeout_s // 60} minutes. Testing aborted.")
213
237
  time.sleep(10)
214
238
 
215
- return self.task_client.stop_task(task_id)
239
+ def get_task_endpoint_url(self, task_id: str) -> str:
240
+ task = self.get_task(task_id)
241
+ if task.endpoint_info is not None and task.endpoint_info.endpoint_status == TaskEndpointStatus.RUNNING:
242
+ return task.endpoint_info.endpoint_url
243
+ else:
244
+ if task.cluster_endpoints:
245
+ for ce in task.cluster_endpoints:
246
+ if ce.endpoint_status == TaskEndpointStatus.RUNNING:
247
+ return ce.endpoint_url
248
+ return ""
249
+
216
250
 
217
251
  def get_usage_data(self, start_timestamp: str, end_timestamp: str) -> GetUsageDataResponse:
218
252
  """
@@ -0,0 +1,91 @@
1
+ import os
2
+ import logging
3
+
4
+ from .._client._iam_client import IAMClient
5
+ from .._client._video_client import VideoClient
6
+ from .._models import *
7
+
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ class VideoManager:
12
+ """
13
+ A manager for handling video tasks, providing methods to create, update, and stop tasks.
14
+ """
15
+
16
+ def __init__(self, iam_client: IAMClient):
17
+ """
18
+ Initializes the VideoManager with the given IAM client.
19
+ """
20
+ self.video_client = VideoClient(iam_client)
21
+ self.iam_client = iam_client
22
+
23
+
24
+ def get_request_detail(self, request_id: str) -> GetRequestResponse:
25
+ """
26
+ Retrieves detailed information about a specific request by its ID. This endpoint requires authentication with a bearer token and only returns requests belonging to the authenticated organization.
27
+
28
+ :param request_id: The ID of the request to be retrieved.
29
+ :return: Details of the request successfully retrieved
30
+ """
31
+ self._validate_not_empty(request_id, "request_id")
32
+ return self.video_client.get_request_detail(request_id)
33
+
34
+
35
+ def get_requests(self, model_id: str) -> List[GetRequestResponse]:
36
+ """
37
+ Retrieves a list of requests submitted by the authenticated user for a specific model. This endpoint requires authentication with a bearer token and filters results by the authenticated organization.
38
+
39
+ :param model_id: The ID of the model to be retrieved.
40
+ :return: List of user's requests successfully retrieved
41
+ """
42
+ self._validate_not_empty(model_id, "model_id")
43
+ return self.video_client.get_requests(model_id)
44
+
45
+
46
+ def create_request(self, request: SubmitRequestRequest) -> SubmitRequestResponse:
47
+ """
48
+ Submits a new asynchronous request to process a specified model with provided parameters. This endpoint requires authentication with a bearer token.
49
+
50
+ :param request: The request data to be created.
51
+ :return: The created request data.
52
+ """
53
+ if not request:
54
+ raise ValueError("Request data cannot be None.")
55
+ if not request.model:
56
+ raise ValueError("Model ID is required in the request data.")
57
+ if not request.payload:
58
+ raise ValueError("Payload is required in the request data.")
59
+ return self.video_client.create_request(request)
60
+
61
+
62
+ def get_model_detail(self, model_id: str) -> GetModelResponse:
63
+ """
64
+ Retrieves detailed information about a specific model by its ID.
65
+
66
+ :param model_id: The ID of the model to be retrieved.
67
+ :return: Details of the specified model.
68
+ """
69
+ self._validate_not_empty(model_id, "model_id")
70
+ return self.video_client.get_model_detail(model_id)
71
+
72
+
73
+ def get_models(self) -> List[GetModelResponse]:
74
+ """
75
+ Retrieves a list of available models for video processing.
76
+
77
+ :return: A list of available models.
78
+ """
79
+ return self.video_client.get_models()
80
+
81
+
82
+ @staticmethod
83
+ def _validate_not_empty(value: str, name: str):
84
+ """
85
+ Validate a string is neither None nor empty.
86
+
87
+ :param value: The string to validate.
88
+ :param name: The name of the value for error reporting.
89
+ """
90
+ if not value or not value.strip():
91
+ raise ValueError(f"{name} is required and cannot be empty.")
@@ -0,0 +1,125 @@
1
+ import shlex
2
+ import os
3
+ import logging
4
+
5
+ logger = logging.getLogger(__name__)
6
+
7
+ def parse_server_command(cmd_str: str) -> tuple[str, dict, dict]:
8
+ """
9
+ parse server command
10
+ Maybe there are more than two types of server command
11
+ if not found, we can add more parse function
12
+ """
13
+ if "vllm serve" in cmd_str:
14
+ return ("vllm", *parse_server_vllm_command(cmd_str))
15
+ elif "sglang.launch_server" in cmd_str:
16
+ return ("sglang", *parse_server_sglang_command(cmd_str))
17
+ else:
18
+ raise ValueError(f"Unknown serve command: {cmd_str}")
19
+
20
+ def extract_env_and_args(tokens: list) -> tuple[dict, list]:
21
+ """
22
+ Extract environment variables from the tokens list.
23
+ and add the params or flags to environment variables
24
+ """
25
+ env_vars = {}
26
+ while tokens and '=' in tokens[0] and not tokens[0].startswith('--'):
27
+ key, value = tokens.pop(0).split('=', 1)
28
+ env_vars[key] = value
29
+ for k, v in env_vars.items():
30
+ os.environ[k] = v
31
+ return env_vars, tokens
32
+
33
+ def parse_flags_and_args(tokens: list) -> dict:
34
+ """
35
+ parse flags and args
36
+ include three types: --flag=value and --flag value and --flag
37
+ """
38
+ result = {}
39
+ i = 0
40
+ while i < len(tokens):
41
+ token = tokens[i]
42
+ if token.startswith('--') or token.startswith('-'):
43
+ if '=' in token:
44
+ key, value = token[2:].split('=', 1)
45
+ result[key] = value.strip("'\"")
46
+ elif i + 1 < len(tokens) and not tokens[i + 1].startswith('--'):
47
+ if token.startswith('--'):
48
+ result[token[2:]] = tokens[i + 1].strip("'\"")
49
+ else:
50
+ result[token[1:]] = tokens[i + 1].strip("'\"")
51
+ i += 1
52
+ else:
53
+ if token.startswith('--'):
54
+ result[token[2:]] = True
55
+ else:
56
+ result[token[1:]] = True
57
+ else:
58
+ logger.warning(f"Ignoring unknown token: {token}")
59
+ i += 1
60
+ return result
61
+
62
+ def parse_server_vllm_command(cmd_str: str) -> tuple[dict, dict]:
63
+ """ parse vllm command"""
64
+ tokens = shlex.split(cmd_str)
65
+ result = {}
66
+
67
+ # Extract environment variables
68
+ env_vars, tokens = extract_env_and_args(tokens)
69
+ if env_vars:
70
+ result["env_vars"] = env_vars
71
+
72
+ # vllm serve + model
73
+ if tokens[:2] != ['vllm', 'serve']:
74
+ raise ValueError("Invalid vllm serve command format. Example: vllm serve <model path>")
75
+
76
+ if len(tokens) < 3:
77
+ raise ValueError("Missing model path in vllm serve command. Example: vllm serve <model path>")
78
+
79
+ model_path = tokens[2]
80
+ result["model-path"] = model_path
81
+
82
+ flags = parse_flags_and_args(tokens[3:])
83
+ result.update(flags)
84
+ return (env_vars, result)
85
+
86
+ def parse_server_sglang_command(cmd_str: str) -> tuple[dict, dict]:
87
+ """ parse sglang command"""
88
+ tokens = shlex.split(cmd_str)
89
+ result = {}
90
+
91
+ # Extract environment variables
92
+ env_vars, tokens = extract_env_and_args(tokens)
93
+ if env_vars:
94
+ result["env_vars"] = env_vars
95
+ # python3 -m sglang.launch_server
96
+ if tokens[:3] != ['python3', '-m', 'sglang.launch_server'] and tokens[:3] != ['python', '-m', 'sglang.launch_server']:
97
+ raise ValueError("Invalid sglang command format. Example: python3 -m sglang.launch_server")
98
+
99
+ flags = parse_flags_and_args(tokens[3:])
100
+ result.update(flags)
101
+ return (env_vars, result)
102
+
103
+ def extract_gpu_num_from_serve_command(serve_args_dict: dict) -> int:
104
+ """ extract gpu num from serve command """
105
+ cmd_tp_size = 1
106
+ cmd_dp_size = 1
107
+ if "tensor-parallel-size" in serve_args_dict:
108
+ cmd_tp_size = int(serve_args_dict["tensor-parallel-size"])
109
+ elif "tp" in serve_args_dict:
110
+ cmd_tp_size = int(serve_args_dict["tp"])
111
+ elif "tp-size" in serve_args_dict:
112
+ cmd_tp_size = int(serve_args_dict["tp-size"])
113
+ if "data-parallel-size" in serve_args_dict:
114
+ cmd_dp_size = int(serve_args_dict["data-parallel-size"])
115
+ elif "dp" in serve_args_dict:
116
+ cmd_dp_size = int(serve_args_dict["dp"])
117
+ elif "dp-size" in serve_args_dict:
118
+ cmd_dp_size = int(serve_args_dict["dp-size"])
119
+ if "pipeline_parallel_size" in serve_args_dict or "pp" in serve_args_dict:
120
+ raise ValueError("Pipeline parallel size is not supported.")
121
+ cmd_gpu_num = cmd_tp_size * cmd_dp_size
122
+ if cmd_gpu_num > 8:
123
+ raise ValueError("Only support up to 8 GPUs for single task replica.")
124
+ print(f'cmd_tp_size: {cmd_tp_size}, cmd_dp_size: {cmd_dp_size}, cmd_gpu_num: {cmd_gpu_num}')
125
+ return cmd_gpu_num