gmicloud 0.1.4__py3-none-any.whl → 0.1.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gmicloud/__init__.py +12 -1
- gmicloud/_internal/_client/_artifact_client.py +126 -56
- gmicloud/_internal/_client/_http_client.py +5 -1
- gmicloud/_internal/_client/_iam_client.py +107 -42
- gmicloud/_internal/_client/_task_client.py +75 -30
- gmicloud/_internal/_enums.py +13 -0
- gmicloud/_internal/_manager/_artifact_manager.py +100 -5
- gmicloud/_internal/_manager/_iam_manager.py +36 -0
- gmicloud/_internal/_manager/_task_manager.py +88 -12
- gmicloud/_internal/_models.py +121 -12
- gmicloud/client.py +194 -69
- gmicloud/tests/test_artifacts.py +14 -15
- gmicloud/tests/test_tasks.py +1 -1
- gmicloud-0.1.6.dist-info/METADATA +147 -0
- gmicloud-0.1.6.dist-info/RECORD +27 -0
- {gmicloud-0.1.4.dist-info → gmicloud-0.1.6.dist-info}/WHEEL +1 -1
- gmicloud-0.1.4.dist-info/METADATA +0 -250
- gmicloud-0.1.4.dist-info/RECORD +0 -26
- {gmicloud-0.1.4.dist-info → gmicloud-0.1.6.dist-info}/top_level.txt +0 -0
| @@ -1,9 +1,14 @@ | |
| 1 | 
            +
            import logging
         | 
| 2 | 
            +
            from requests.exceptions import RequestException
         | 
| 3 | 
            +
             | 
| 1 4 | 
             
            from ._http_client import HTTPClient
         | 
| 2 5 | 
             
            from ._decorator import handle_refresh_token
         | 
| 3 6 | 
             
            from ._iam_client import IAMClient
         | 
| 4 7 | 
             
            from .._config import TASK_SERVICE_BASE_URL
         | 
| 5 8 | 
             
            from .._models import *
         | 
| 6 9 |  | 
| 10 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 11 | 
            +
             | 
| 7 12 |  | 
| 8 13 | 
             
            class TaskClient:
         | 
| 9 14 | 
             
                """
         | 
| @@ -21,17 +26,19 @@ class TaskClient: | |
| 21 26 | 
             
                    self.iam_client = iam_client
         | 
| 22 27 |  | 
| 23 28 | 
             
                @handle_refresh_token
                def get_task(self, task_id: str) -> Optional[Task]:
                    """
                    Fetch a single task by its ID from the task service.

                    :param task_id: Identifier of the task to look up.
                    :return: The parsed Task, or None when the service returns an empty payload
                             or when the request/validation raises RequestException or ValueError
                             (the failure is logged).
                    """
                    query = {"task_id": task_id}
                    try:
                        payload = self.client.get("/get_task", self.iam_client.get_custom_headers(), query)
                        if not payload:
                            return None
                        return Task.model_validate(payload)
                    except (RequestException, ValueError) as e:
                        logger.error(f"Failed to retrieve task {task_id}: {e}")
                        return None

                @handle_refresh_token
                def get_all_tasks(self, user_id: Optional[str] = None) -> GetAllTasksResponse:
                    """
                    Retrieves all tasks from the task service.

                    :param user_id: Optional; accepted so callers that pass the current user's ID
                                    positionally (TaskManager.get_all_tasks does) no longer raise
                                    TypeError. It is not forwarded to the service.
                    :return: An instance of GetAllTasksResponse containing the retrieved tasks. The
                             task list is empty when the service returns an empty payload or when
                             the request/validation raises RequestException or ValueError (logged).
                    """
                    # NOTE(review): /get_tasks is still called without user_id, exactly as before;
                    # confirm whether the endpoint is meant to be filtered by it.
                    try:
                        response = self.client.get("/get_tasks", self.iam_client.get_custom_headers())
                        if not response:
                            logger.error("Empty response from /get_tasks")
                            return GetAllTasksResponse(tasks=[])
                        return GetAllTasksResponse.model_validate(response)
                    except (RequestException, ValueError) as e:
                        logger.error(f"Failed to retrieve all tasks: {e}")
                        return GetAllTasksResponse(tasks=[])

                @handle_refresh_token
                def create_task(self, task: Task) -> Optional[CreateTaskResponse]:
                    """
                    Submit a new task to the task service.

                    :param task: The Task whose ``model_dump()`` is posted to ``/create_task``.
                    :return: The parsed CreateTaskResponse, or None when the service returns an
                             empty payload or when the request/validation raises RequestException
                             or ValueError (the failure is logged).
                    """
                    try:
                        headers = self.iam_client.get_custom_headers()
                        response = self.client.post("/create_task", headers, task.model_dump())
                        if not response:
                            return None
                        return CreateTaskResponse.model_validate(response)
                    except (RequestException, ValueError) as e:
                        logger.error(f"Failed to create task: {e}")
                        return None

                @handle_refresh_token
                def update_task_schedule(self, task: Task) -> bool:
                    """
                    Push a task's updated schedule to the task service.

                    :param task: The Task whose ``model_dump()`` is sent to ``/update_schedule``.
                    :return: True when the service returned a non-None response; False when it
                             returned None or the request raised RequestException (logged).
                    """
                    try:
                        response = self.client.put(
                            "/update_schedule",
                            self.iam_client.get_custom_headers(),
                            task.model_dump(),
                        )
                    except RequestException as e:
                        logger.error(f"Failed to update schedule for task {task.task_id}: {e}")
                        return False
                    return response is not None

                @handle_refresh_token
                def start_task(self, task_id: str) -> bool:
                    """
                    Ask the task service to start the task with the given ID.

                    :param task_id: Identifier of the task to start.
                    :return: True when the service returned a non-None response; False when it
                             returned None or the request raised RequestException (logged).
                    """
                    payload = {"task_id": task_id}
                    try:
                        response = self.client.post("/start_task", self.iam_client.get_custom_headers(), payload)
                    except RequestException as e:
                        logger.error(f"Failed to start task {task_id}: {e}")
                        return False
                    return response is not None

                @handle_refresh_token
                def stop_task(self, task_id: str) -> bool:
                    """
                    Ask the task service to stop the task with the given ID.

                    :param task_id: Identifier of the task to stop.
                    :return: True when the service returned a non-None response; False when it
                             returned None or the request raised RequestException (logged).
                    """
                    payload = {"task_id": task_id}
                    try:
                        response = self.client.post("/stop_task", self.iam_client.get_custom_headers(), payload)
                    except RequestException as e:
                        logger.error(f"Failed to stop task {task_id}: {e}")
                        return False
                    return response is not None

                @handle_refresh_token
                def get_usage_data(self, start_timestamp: str, end_timestamp: str) -> Optional[GetUsageDataResponse]:
                    """
                    Retrieves usage data for the window between two timestamps.

                    Both timestamps are forwarded unchanged in the request parameters of
                    ``/get_usage_data``; no task ID is involved.

                    :param start_timestamp: The start timestamp of the usage data.
                    :param end_timestamp: The end timestamp of the usage data.
                    :return: An instance of GetUsageDataResponse, or None when the service returns
                             an empty payload or when the request/validation raises
                             RequestException or ValueError (the failure is logged).
                    """
                    params = {"start_timestamp": start_timestamp, "end_timestamp": end_timestamp}
                    try:
                        response = self.client.get("/get_usage_data", self.iam_client.get_custom_headers(), params)
                        return GetUsageDataResponse.model_validate(response) if response else None
                    except (RequestException, ValueError) as e:
                        logger.error(f"Failed to retrieve usage data from {start_timestamp} to {end_timestamp}: {e}")
                        return None

                @handle_refresh_token
                def archive_task(self, task_id: str) -> bool:
                    """
                    Ask the task service to archive the task with the given ID.

                    :param task_id: Identifier of the task to archive.
                    :return: True when the service returned a non-None response; False when it
                             returned None or the request raised RequestException (logged).
                    """
                    payload = {"task_id": task_id}
                    try:
                        response = self.client.post("/archive_task", self.iam_client.get_custom_headers(), payload)
                    except RequestException as e:
                        logger.error(f"Failed to archive task {task_id}: {e}")
                        return False
                    return response is not None
    
        gmicloud/_internal/_enums.py
    CHANGED
    
    | @@ -23,3 +23,16 @@ class TaskEndpointStatus(str, Enum): | |
| 23 23 | 
             
                READY = "ready"
         | 
| 24 24 | 
             
                UNREADY = "unready"
         | 
| 25 25 | 
             
                NEW = "new"
         | 
| 26 | 
            +
             | 
| 27 | 
            +
            class TaskStatus(str, Enum):
                """
                Status values a task can report.

                Because the enum mixes in ``str``, each member compares equal to its string value.
                """

                IDLE = "idle"
                STARTING = "starting"
                IN_QUEUE = "in-queue"  # the value is hyphenated, unlike the member name
                RUNNING = "running"
                NEEDSTOP = "needstop"
                ARCHIVED = "archived"

            class ModelParameterType(str, Enum):
                """Type tags for model parameters; string-valued, so members equal their values."""

                NUMERIC = "numeric"
                TEXT = "text"
                BOOLEAN = "boolean"
| @@ -1,4 +1,5 @@ | |
| 1 1 | 
             
            import os
         | 
| 2 | 
            +
            import time
         | 
| 2 3 | 
             
            from typing import List
         | 
| 3 4 | 
             
            import mimetypes
         | 
| 4 5 |  | 
| @@ -7,6 +8,9 @@ from .._client._artifact_client import ArtifactClient | |
| 7 8 | 
             
            from .._client._file_upload_client import FileUploadClient
         | 
| 8 9 | 
             
            from .._models import *
         | 
| 9 10 |  | 
| 11 | 
            +
            import logging
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 10 14 |  | 
| 11 15 | 
             
            class ArtifactManager:
         | 
| 12 16 | 
             
                """
         | 
| @@ -81,7 +85,46 @@ class ArtifactManager: | |
| 81 85 | 
             
                    if not artifact_template_id or not artifact_template_id.strip():
         | 
| 82 86 | 
             
                        raise ValueError("Artifact template ID is required and cannot be empty.")
         | 
| 83 87 |  | 
| 84 | 
            -
                     | 
| 88 | 
            +
                    resp = self.artifact_client.create_artifact_from_template(artifact_template_id)
         | 
| 89 | 
            +
                    if not resp or not resp.artifact_id:
         | 
| 90 | 
            +
                        raise ValueError("Failed to create artifact from template.")
         | 
| 91 | 
            +
             | 
| 92 | 
            +
                    return resp.artifact_id
         | 
| 93 | 
            +
                
         | 
| 94 | 
            +
                def create_artifact_from_template_name(self, artifact_template_name: str) -> tuple[str, ReplicaResource]:
                    """
                    Create an artifact from the public template with the given name and wait for its build.

                    The first public template whose ``template_data.name`` matches is used. After
                    the artifact is created, this call blocks in ``wait_for_artifact_ready``.

                    :param artifact_template_name: The name of the template to use.
                    :return: A tuple of the new artifact ID and the recommended replica resources
                             copied from the template's ``resources`` (cpu, memory as ram_gb,
                             gpu, gpu_name).
                    :rtype: tuple[str, ReplicaResource]
                    :raises ValueError: If no public template has that name, or the match has no
                                        template ID.
                    :raises Exception: Whatever fetching the templates, creating the artifact, or
                                       waiting for the build raises (logged, then re-raised).
                    """
                    try:
                        templates = self.get_public_templates()
                    except Exception as e:
                        logger.error(f"Failed to get artifact templates, Error: {e}")
                        # Re-raise: continuing without `templates` would fail with
                        # UnboundLocalError and hide the real cause.
                        raise

                    template = next(
                        (
                            t for t in templates
                            if t.template_data and t.template_data.name == artifact_template_name
                        ),
                        None,
                    )
                    if template is None or not template.template_id:
                        raise ValueError(f"Template with name {artifact_template_name} not found.")

                    resources_template = template.template_data.resources
                    recommended_replica_resources = ReplicaResource(
                        cpu=resources_template.cpu,
                        ram_gb=resources_template.memory,
                        gpu=resources_template.gpu,
                        gpu_name=resources_template.gpu_name,
                    )

                    try:
                        artifact_id = self.create_artifact_from_template(template.template_id)
                        self.wait_for_artifact_ready(artifact_id)
                    except Exception as e:
                        logger.error(f"Failed to create artifact from template, Error: {e}")
                        raise
                    return artifact_id, recommended_replica_resources
         | 
| 85 128 |  | 
| 86 129 | 
             
                def rebuild_artifact(self, artifact_id: str) -> RebuildArtifactResponse:
         | 
| 87 130 | 
             
                    """
         | 
| @@ -170,7 +213,11 @@ class ArtifactManager: | |
| 170 213 |  | 
| 171 214 | 
             
                    req = GetBigFileUploadUrlRequest(artifact_id=artifact_id, file_name=model_file_name, file_type=model_file_type)
         | 
| 172 215 |  | 
| 173 | 
            -
                     | 
| 216 | 
            +
                    resp = self.artifact_client.get_bigfile_upload_url(req)
         | 
| 217 | 
            +
                    if not resp or not resp.upload_link:
         | 
| 218 | 
            +
                        raise ValueError("Failed to get bigfile upload URL.")
         | 
| 219 | 
            +
             | 
| 220 | 
            +
                    return resp.upload_link
         | 
| 174 221 |  | 
| 175 222 | 
             
                def delete_bigfile(self, artifact_id: str, file_name: str) -> str:
         | 
| 176 223 | 
             
                    """
         | 
| @@ -182,7 +229,11 @@ class ArtifactManager: | |
| 182 229 | 
             
                    self._validate_artifact_id(artifact_id)
         | 
| 183 230 | 
             
                    self._validate_file_name(file_name)
         | 
| 184 231 |  | 
| 185 | 
            -
                     | 
| 232 | 
            +
                    resp = self.artifact_client.delete_bigfile(artifact_id, file_name)
         | 
| 233 | 
            +
                    if not resp or not resp.status:
         | 
| 234 | 
            +
                        raise ValueError("Failed to delete bigfile.")
         | 
| 235 | 
            +
             | 
| 236 | 
            +
                    return resp.status
         | 
| 186 237 |  | 
| 187 238 | 
             
                def upload_large_file(self, upload_link: str, file_path: str) -> None:
         | 
| 188 239 | 
             
                    """
         | 
| @@ -228,15 +279,59 @@ class ArtifactManager: | |
| 228 279 | 
             
                        FileUploadClient.upload_large_file(bigfile_upload_url_resp.upload_link, model_file_path)
         | 
| 229 280 |  | 
| 230 281 | 
             
                    return artifact_id
         | 
| 282 | 
            +
                
         | 
| 283 | 
            +
             | 
| 284 | 
            +
                def wait_for_artifact_ready(self, artifact_id: str, timeout_s: int = 900,
                                            poll_interval_s: float = 10) -> None:
                    """
                    Poll an artifact until its build succeeds, fails, or the timeout elapses.

                    Errors while fetching the artifact are logged and polling continues. A terminal
                    build status (FAILED, TIMEOUT, CANCELLED) is raised on the poll that sees it.

                    :param artifact_id: The ID of the artifact to wait for.
                    :param timeout_s: Maximum time to wait, in seconds.
                    :param poll_interval_s: Seconds to sleep between polls.
                    :return: None once the build status is SUCCESS.
                    :raises Exception: If the build reaches a terminal failure status, or if
                                       ``timeout_s`` elapses first.
                    """
                    failed_statuses = (BuildStatus.FAILED, BuildStatus.TIMEOUT, BuildStatus.CANCELLED)
                    # monotonic() is immune to wall-clock adjustments while waiting.
                    start_time = time.monotonic()
                    while True:
                        try:
                            artifact = self.get_artifact(artifact_id)
                            build_status = artifact.build_status
                        except Exception as e:
                            # Lookup errors are retried until the timeout.
                            logger.error(f"Failed to get artifact, Error: {e}")
                        else:
                            if build_status == BuildStatus.SUCCESS:
                                return
                            # Raised outside the try so the generic handler above cannot swallow it
                            # and turn a failed build into a timeout.
                            if build_status in failed_statuses:
                                raise Exception(f"Artifact build failed, status: {build_status}")
                        if time.monotonic() - start_time > timeout_s:
                            raise Exception(f"Artifact build takes more than {timeout_s // 60} minutes. Testing aborted.")
                        time.sleep(poll_interval_s)

| 231 305 |  | 
| 232 | 
            -
                 | 
| 306 | 
            +
                
         | 
| 307 | 
            +
                def get_public_templates(self) -> List[ArtifactTemplate]:
                    """
                    Return every public artifact template, as reported by the artifact client.

                    Delegates to ``self.artifact_client.get_public_templates()``; any exception the
                    client raises propagates unchanged.

                    :return: A list of ArtifactTemplate objects.
                    :rtype: List[ArtifactTemplate]
                    """
                    client = self.artifact_client
                    return client.get_public_templates()

                def list_public_template_names(self) -> list[str]:
                    """
                    Collect the names of all public artifact templates.

                    Templates without ``template_data`` or with an empty name are skipped. If
                    fetching or iterating the templates raises, the error is logged and an empty
                    list is returned instead.

                    :return: Template names, in the order the templates were returned.
                    :rtype: list[str]
                    """
                    try:
                        return [
                            tpl.template_data.name
                            for tpl in self.get_public_templates()
                            if tpl.template_data and tpl.template_data.name
                        ]
                    except Exception as e:
                        logger.error(f"Failed to get artifact templates, Error: {e}")
                        return []
         | 
| 334 | 
            +
             | 
| 240 335 |  | 
| 241 336 | 
             
                @staticmethod
         | 
| 242 337 | 
             
                def _validate_file_name(file_name: str) -> None:
         | 
| @@ -0,0 +1,36 @@ | |
| 1 | 
            +
            from datetime import datetime
         | 
| 2 | 
            +
            from .._client._iam_client import IAMClient
         | 
| 3 | 
            +
            from .._models import *
         | 
| 4 | 
            +
             | 
| 5 | 
            +
             | 
| 6 | 
            +
            class IAMManager:
                """
                IAMManager handles IAM-related operations by delegating to an IAMClient;
                currently it creates and lists organization API keys.
                """

                # Lifetime (in seconds) applied when create_org_api_key gets no expires_at: 30 days.
                DEFAULT_API_KEY_LIFETIME_S = 30 * 24 * 60 * 60

                def __init__(self, iam_client: IAMClient):
                    """
                    Initialize the IAMManager with the IAMClient it delegates to.

                    :param iam_client: Client used for the IAM API calls.
                    """
                    self.iam_client = iam_client

                def create_org_api_key(self, name: str, expires_at: Optional[int] = None) -> str:
                    """
                    Create a new organization API key of type "ie_model".

                    :param name: Name for the key; must be non-empty.
                    :param expires_at: Expiry as a Unix timestamp in seconds. When omitted (or any
                                       falsy value such as 0), the key expires 30 days from now.
                    :return: The value returned by IAMClient.create_org_api_key (annotated as str).
                    :raises ValueError: If ``name`` is empty.
                    """
                    if not name:
                        raise ValueError("API key name cannot be empty")
                    if not expires_at:
                        expires_at = int(datetime.now().timestamp()) + self.DEFAULT_API_KEY_LIFETIME_S

                    return self.iam_client.create_org_api_key(
                        CreateAPIKeyRequest(name=name, type="ie_model", expiresAt=expires_at))

                def get_org_api_keys(self) -> List[APIKey]:
                    """
                    Fetch all API keys of the current organization.

                    :return: The ``keys`` list from IAMClient.get_org_api_keys, or an empty list if
                             the client returns no response.
                    """
                    resp = self.iam_client.get_org_api_keys()
                    # NOTE(review): guards against the client returning None on failure, as the
                    # TaskClient methods in this package do; confirm IAMClient's actual behavior.
                    return resp.keys if resp else []
         | 
| @@ -4,6 +4,10 @@ from .._client._iam_client import IAMClient | |
| 4 4 | 
             
            from .._client._task_client import TaskClient
         | 
| 5 5 | 
             
            from .._models import *
         | 
| 6 6 |  | 
| 7 | 
            +
            import time
         | 
| 8 | 
            +
            import logging
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            logger = logging.getLogger(__name__)
         | 
| 7 11 |  | 
| 8 12 | 
             
            class TaskManager:
         | 
| 9 13 | 
             
                """
         | 
| @@ -37,7 +41,11 @@ class TaskManager: | |
| 37 41 |  | 
| 38 42 | 
             
                    :return: A list of `Task` objects.
         | 
| 39 43 | 
             
                    """
         | 
| 40 | 
            -
                     | 
| 44 | 
            +
                    resp = self.task_client.get_all_tasks(self.iam_client.get_user_id())
         | 
| 45 | 
            +
                    if not resp or not resp.tasks:
         | 
| 46 | 
            +
                        return []
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                    return resp.tasks
         | 
| 41 49 |  | 
| 42 50 | 
             
                def create_task(self, task: Task) -> Task:
         | 
| 43 51 | 
             
                    """
         | 
| @@ -51,8 +59,11 @@ class TaskManager: | |
| 51 59 | 
             
                    self._validate_task(task)
         | 
| 52 60 | 
             
                    if not task.owner:
         | 
| 53 61 | 
             
                        task.owner = TaskOwner(user_id=self.iam_client.get_user_id())
         | 
| 62 | 
            +
                    resp = self.task_client.create_task(task)
         | 
| 63 | 
            +
                    if not resp or not resp.task:
         | 
| 64 | 
            +
                        raise ValueError("Failed to create task.")
         | 
| 54 65 |  | 
| 55 | 
            -
                    return  | 
| 66 | 
            +
                    return resp.task
         | 
| 56 67 |  | 
| 57 68 | 
             
                def create_task_from_file(self, artifact_id: str, config_file_path: str, trigger_timestamp: int = None) -> Task:
         | 
| 58 69 | 
             
                    """
         | 
| @@ -76,7 +87,7 @@ class TaskManager: | |
| 76 87 |  | 
| 77 88 | 
             
                    return self.create_task(task)
         | 
| 78 89 |  | 
| 79 | 
            -
                def update_task_schedule(self, task: Task):
         | 
| 90 | 
            +
                def update_task_schedule(self, task: Task) -> bool:
         | 
| 80 91 | 
             
                    """
         | 
| 81 92 | 
             
                    Update the schedule of an existing task.
         | 
| 82 93 |  | 
| @@ -87,10 +98,10 @@ class TaskManager: | |
| 87 98 | 
             
                    self._validate_task(task)
         | 
| 88 99 | 
             
                    self._validate_not_empty(task.task_id, "Task ID")
         | 
| 89 100 |  | 
| 90 | 
            -
                    self.task_client.update_task_schedule(task)
         | 
| 101 | 
            +
                    return self.task_client.update_task_schedule(task)
         | 
| 91 102 |  | 
| 92 103 | 
             
                def update_task_schedule_from_file(self, artifact_id: str, task_id: str, config_file_path: str,
         | 
| 93 | 
            -
                                                   trigger_timestamp: int = None):
         | 
| 104 | 
            +
                                                   trigger_timestamp: int = None) -> bool:
         | 
| 94 105 | 
             
                    """
         | 
| 95 106 | 
             
                    Update the schedule of an existing task using data from a file. The file should contain a valid task definition.
         | 
| 96 107 |  | 
| @@ -112,9 +123,9 @@ class TaskManager: | |
| 112 123 | 
             
                    if trigger_timestamp:
         | 
| 113 124 | 
             
                        task.config.task_scheduling.scheduling_oneoff.trigger_timestamp = trigger_timestamp
         | 
| 114 125 |  | 
| 115 | 
            -
                    self.update_task_schedule(task)
         | 
| 126 | 
            +
                    return self.update_task_schedule(task)
         | 
| 116 127 |  | 
| 117 | 
            -
                def start_task(self, task_id: str):
         | 
| 128 | 
            +
                def start_task(self, task_id: str) -> bool:
         | 
| 118 129 | 
             
                    """
         | 
| 119 130 | 
             
                    Start a task by its ID.
         | 
| 120 131 |  | 
| @@ -124,9 +135,53 @@ class TaskManager: | |
| 124 135 | 
             
                    """
         | 
| 125 136 | 
             
                    self._validate_not_empty(task_id, "Task ID")
         | 
| 126 137 |  | 
| 127 | 
            -
                    self.task_client.start_task(task_id)
         | 
| 138 | 
            +
                    return self.task_client.start_task(task_id)
         | 
| 139 | 
            +
                
         | 
| 128 140 |  | 
| 129 | 
            -
                def  | 
| 141 | 
            +
                def start_task_and_wait(self, task_id: str, timeout_s: int = 900) -> Task:
         | 
| 142 | 
            +
                    """
         | 
| 143 | 
            +
                    Start a task and wait for it to be ready.
         | 
| 144 | 
            +
             | 
| 145 | 
            +
                    :param task_id: The ID of the task to start.
         | 
| 146 | 
            +
                    :param timeout_s: The timeout in seconds.
         | 
| 147 | 
            +
                    :return: The task object.
         | 
| 148 | 
            +
                    :rtype: Task
         | 
| 149 | 
            +
                    """
         | 
| 150 | 
            +
                    # trigger start task
         | 
| 151 | 
            +
                    try:
         | 
| 152 | 
            +
                        self.start_task(task_id)
         | 
| 153 | 
            +
                        logger.info(f"Started task ID: {task_id}")
         | 
| 154 | 
            +
                    except Exception as e:
         | 
| 155 | 
            +
                        logger.error(f"Failed to start task, Error: {e}")
         | 
| 156 | 
            +
                        raise e
         | 
| 157 | 
            +
                    
         | 
| 158 | 
            +
                    start_time = time.time()
         | 
| 159 | 
            +
                    while True:
         | 
| 160 | 
            +
                        try:
         | 
| 161 | 
            +
                            task = self.get_task(task_id)
         | 
| 162 | 
            +
                            if task.task_status == TaskStatus.RUNNING:
         | 
| 163 | 
            +
                                return task
         | 
| 164 | 
            +
                            elif task.task_status in [TaskStatus.NEEDSTOP, TaskStatus.ARCHIVED]:
         | 
| 165 | 
            +
                                raise Exception(f"Unexpected task status after starting: {task.task_status}")
         | 
| 166 | 
            +
                            # Also check endpoint status. 
         | 
| 167 | 
            +
                            elif task.task_status == TaskStatus.RUNNING:
         | 
| 168 | 
            +
                                if task.endpoint_info and task.endpoint_info.endpoint_status == TaskEndpointStatus.RUNNING:
         | 
| 169 | 
            +
                                    return task
         | 
| 170 | 
            +
                                elif task.endpoint_info and task.endpoint_info.endpoint_status in [TaskEndpointStatus.UNKNOWN, TaskEndpointStatus.ARCHIVED]:
         | 
| 171 | 
            +
                                    raise Exception(f"Unexpected endpoint status after starting: {task.endpoint_info.endpoint_status}")
         | 
| 172 | 
            +
                                else:
         | 
| 173 | 
            +
                                    logger.info(f"Pending endpoint starting. endpoint status: {task.endpoint_info.endpoint_status}")
         | 
| 174 | 
            +
                            else:
         | 
| 175 | 
            +
                                logger.info(f"Pending task starting. Task status: {task.task_status}")
         | 
| 176 | 
            +
             | 
| 177 | 
            +
                        except Exception as e:
         | 
| 178 | 
            +
                            logger.error(f"Failed to get task, Error: {e}")
         | 
| 179 | 
            +
                        if time.time() - start_time > timeout_s:
         | 
| 180 | 
            +
                            raise Exception(f"Task creation takes more than {timeout_s // 60} minutes. Testing aborted.")
         | 
| 181 | 
            +
                        time.sleep(10)
         | 
| 182 | 
            +
             | 
| 183 | 
            +
             | 
| 184 | 
            +
                def stop_task(self, task_id: str) -> bool:
         | 
| 130 185 | 
             
                    """
         | 
| 131 186 | 
             
                    Stop a task by its ID.
         | 
| 132 187 |  | 
| @@ -136,7 +191,28 @@ class TaskManager: | |
| 136 191 | 
             
                    """
         | 
| 137 192 | 
             
                    self._validate_not_empty(task_id, "Task ID")
         | 
| 138 193 |  | 
| 139 | 
            -
                     | 
| 194 | 
            +
                    
         | 
| 195 | 
            +
                def stop_task_and_wait(self, task_id: str, timeout_s: int = 900):
         | 
| 196 | 
            +
                    task_manager = self.task_manager
         | 
| 197 | 
            +
                    try:
         | 
| 198 | 
            +
                        self.task_manager.stop_task(task_id)
         | 
| 199 | 
            +
                        logger.info(f"Stopping task ID: {task_id}")
         | 
| 200 | 
            +
                    except Exception as e:
         | 
| 201 | 
            +
                        logger.error(f"Failed to stop task, Error: {e}")
         | 
| 202 | 
            +
                    task_manager = self.task_manager
         | 
| 203 | 
            +
                    start_time = time.time()
         | 
| 204 | 
            +
                    while True:
         | 
| 205 | 
            +
                        try:
         | 
| 206 | 
            +
                            task = self.get_task(task_id)
         | 
| 207 | 
            +
                            if task.task_status == TaskStatus.IDLE:
         | 
| 208 | 
            +
                                break
         | 
| 209 | 
            +
                        except Exception as e:
         | 
| 210 | 
            +
                            logger.error(f"Failed to get task, Error: {e}")
         | 
| 211 | 
            +
                        if time.time() - start_time > timeout_s:
         | 
| 212 | 
            +
                            raise Exception(f"Task stopping takes more than {timeout_s // 60} minutes. Testing aborted.")
         | 
| 213 | 
            +
                        time.sleep(10)
         | 
| 214 | 
            +
             | 
| 215 | 
            +
                    return self.task_client.stop_task(task_id)
         | 
| 140 216 |  | 
| 141 217 | 
             
                def get_usage_data(self, start_timestamp: str, end_timestamp: str) -> GetUsageDataResponse:
         | 
| 142 218 | 
             
                    """
         | 
| @@ -151,7 +227,7 @@ class TaskManager: | |
| 151 227 |  | 
| 152 228 | 
             
                    return self.task_client.get_usage_data(start_timestamp, end_timestamp)
         | 
| 153 229 |  | 
| 154 | 
            -
                def archive_task(self, task_id: str):
         | 
| 230 | 
            +
                def archive_task(self, task_id: str) -> bool:
         | 
| 155 231 | 
             
                    """
         | 
| 156 232 | 
             
                    Archive a task by its ID.
         | 
| 157 233 |  | 
| @@ -161,7 +237,7 @@ class TaskManager: | |
| 161 237 | 
             
                    """
         | 
| 162 238 | 
             
                    self._validate_not_empty(task_id, "Task ID")
         | 
| 163 239 |  | 
| 164 | 
            -
                    self.task_client.archive_task(task_id)
         | 
| 240 | 
            +
                    return self.task_client.archive_task(task_id)
         | 
| 165 241 |  | 
| 166 242 | 
             
                @staticmethod
         | 
| 167 243 | 
             
                def _validate_not_empty(value: str, name: str):
         |