kalavai-client 0.5.14__py3-none-any.whl → 0.5.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- kalavai_client/__init__.py +1 -1
- kalavai_client/assets/apps.yaml +2 -2
- kalavai_client/assets/docker-compose-gui.yaml +10 -0
- kalavai_client/assets/docker-compose-template.yaml +5 -3
- kalavai_client/cli.py +143 -597
- kalavai_client/cluster.py +25 -2
- kalavai_client/core.py +655 -4
- kalavai_client/env.py +41 -2
- kalavai_client/utils.py +55 -19
- {kalavai_client-0.5.14.dist-info → kalavai_client-0.5.16.dist-info}/METADATA +6 -4
- kalavai_client-0.5.16.dist-info/RECORD +23 -0
- {kalavai_client-0.5.14.dist-info → kalavai_client-0.5.16.dist-info}/WHEEL +1 -1
- kalavai_client-0.5.14.dist-info/RECORD +0 -22
- {kalavai_client-0.5.14.dist-info → kalavai_client-0.5.16.dist-info}/LICENSE +0 -0
- {kalavai_client-0.5.14.dist-info → kalavai_client-0.5.16.dist-info}/entry_points.txt +0 -0
    
        kalavai_client/core.py
    CHANGED
    
    | @@ -1,17 +1,70 @@ | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import time
         | 
| 1 3 | 
             
            from collections import defaultdict
         | 
| 2 4 | 
             
            import math
         | 
| 5 | 
            +
            import uuid
         | 
| 6 | 
            +
            import socket
         | 
| 7 | 
            +
            import ipaddress
         | 
| 8 | 
            +
            import netifaces as ni
         | 
| 3 9 |  | 
| 4 10 | 
             
            from pydantic import BaseModel
         | 
| 5 11 |  | 
| 12 | 
            +
            from kalavai_client.cluster import CLUSTER
         | 
| 6 13 | 
             
            from kalavai_client.utils import (
         | 
| 7 14 | 
             
                request_to_server,
         | 
| 8 | 
            -
                load_server_info
         | 
| 15 | 
            +
                load_server_info,
         | 
| 16 | 
            +
                decode_dict,
         | 
| 17 | 
            +
                get_vpn_details,
         | 
| 18 | 
            +
                validate_join_public_seed,
         | 
| 19 | 
            +
                generate_compose_config,
         | 
| 20 | 
            +
                store_server_info,
         | 
| 21 | 
            +
                is_watcher_alive,
         | 
| 22 | 
            +
                run_cmd,
         | 
| 23 | 
            +
                leave_vpn,
         | 
| 24 | 
            +
                safe_remove,
         | 
| 25 | 
            +
                get_public_seeds,
         | 
| 26 | 
            +
                load_template,
         | 
| 27 | 
            +
                is_storage_compatible,
         | 
| 28 | 
            +
                NODE_NAME_KEY,
         | 
| 29 | 
            +
                MANDATORY_TOKEN_FIELDS,
         | 
| 30 | 
            +
                PUBLIC_LOCATION_KEY,
         | 
| 31 | 
            +
                CLUSTER_IP_KEY,
         | 
| 32 | 
            +
                CLUSTER_NAME_KEY,
         | 
| 33 | 
            +
                AUTH_KEY,
         | 
| 34 | 
            +
                WATCHER_SERVICE_KEY,
         | 
| 35 | 
            +
                CLUSTER_TOKEN_KEY,
         | 
| 36 | 
            +
                READONLY_AUTH_KEY,
         | 
| 37 | 
            +
                WRITE_AUTH_KEY,
         | 
| 38 | 
            +
                WATCHER_PORT_KEY,
         | 
| 39 | 
            +
                WATCHER_SERVICE_KEY,
         | 
| 40 | 
            +
                USER_NODE_LABEL_KEY,
         | 
| 41 | 
            +
                ALLOW_UNREGISTERED_USER_KEY
         | 
| 42 | 
            +
            )
         | 
| 43 | 
            +
            from kalavai_client.auth import (
         | 
| 44 | 
            +
                KalavaiAuthClient
         | 
| 9 45 | 
             
            )
         | 
| 10 46 | 
             
            from kalavai_client.env import (
         | 
| 11 47 | 
             
                USER_COOKIE,
         | 
| 12 48 | 
             
                USER_LOCAL_SERVER_FILE,
         | 
| 13 49 | 
             
                TEMPLATE_LABEL,
         | 
| 14 | 
            -
                SERVER_IP_KEY
         | 
| 50 | 
            +
                SERVER_IP_KEY,
         | 
| 51 | 
            +
                USER_COMPOSE_FILE,
         | 
| 52 | 
            +
                DEFAULT_VPN_CONTAINER_NAME,
         | 
| 53 | 
            +
                CONTAINER_HOST_PATH,
         | 
| 54 | 
            +
                USER_VPN_COMPOSE_FILE,
         | 
| 55 | 
            +
                USER_HELM_APPS_FILE,
         | 
| 56 | 
            +
                USER_KUBECONFIG_FILE,
         | 
| 57 | 
            +
                USER_TEMPLATES_FOLDER,
         | 
| 58 | 
            +
                USER_WORKSPACE_TEMPLATE,
         | 
| 59 | 
            +
                DEFAULT_USER_WORKSPACE_VALUES,
         | 
| 60 | 
            +
                STORAGE_CLASS_LABEL,
         | 
| 61 | 
            +
                USER_NODE_LABEL,
         | 
| 62 | 
            +
                DEFAULT_WATCHER_PORT,
         | 
| 63 | 
            +
                HELM_APPS_FILE,
         | 
| 64 | 
            +
                HELM_APPS_VALUES,
         | 
| 65 | 
            +
                POOL_CONFIG_DEFAULT_VALUES,
         | 
| 66 | 
            +
                POOL_CONFIG_TEMPLATE,
         | 
| 67 | 
            +
                FORBIDEDEN_IPS
         | 
| 15 68 | 
             
            )
         | 
| 16 69 |  | 
| 17 70 | 
             
            class Job(BaseModel):
         | 
| @@ -19,6 +72,7 @@ class Job(BaseModel): | |
| 19 72 | 
             
                name: str = None
         | 
| 20 73 | 
             
                workers: str = None
         | 
| 21 74 | 
             
                endpoint: str = None
         | 
| 75 | 
            +
                status: str = None
         | 
| 22 76 |  | 
| 23 77 | 
             
            class DeviceStatus(BaseModel):
         | 
| 24 78 | 
             
                name: str
         | 
| @@ -36,6 +90,76 @@ class GPU(BaseModel): | |
| 36 90 | 
             
                model: str
         | 
| 37 91 |  | 
| 38 92 |  | 
| 93 | 
            +
            def init_user_workspace(force_namespace=None):
                """Request creation of the user workspace on the kalavai service.

                Renders USER_WORKSPACE_TEMPLATE with its default values and posts the
                result to the `/v1/create_user_space` endpoint.

                Args:
                    force_namespace: optional value forwarded to the server under the
                        `force_namespace` key; left out of the request when None.

                Returns:
                    The set `{"success"}` when the request does not raise, otherwise a
                    dict with an `error` message.
                """
                # Template rendering happens outside the try block, so any error raised
                # by load_template propagates to the caller.
                workspace_config = load_template(
                    template_path=USER_WORKSPACE_TEMPLATE,
                    values={},
                    default_values_path=DEFAULT_USER_WORKSPACE_VALUES)

                payload = {"config": workspace_config}
                if force_namespace is not None:
                    payload["force_namespace"] = force_namespace

                try:
                    request_to_server(
                        method="post",
                        endpoint="/v1/create_user_space",
                        data=payload,
                        server_creds=USER_LOCAL_SERVER_FILE,
                        user_cookie=USER_COOKIE
                    )
                except Exception as e:
                    return {"error": f"Error when connecting to kalavai service: {str(e)}"}

                # NOTE(review): this is a set literal, not a dict; kept as-is because
                # callers may rely on it -- confirm before changing the shape.
                return {"success"}
         | 
| 115 | 
            +
             | 
| 116 | 
            +
            def check_seed_compatibility():
                """Check required packages to start pools.

                Runs `docker version` through run_cmd and records install instructions
                when that call raises.

                Returns:
                    dict with an `issues` list of message strings; the list is empty
                    when the docker check does not raise.
                """
                logs = []
                # docker
                try:
                    run_cmd("docker version >/dev/null 2>&1")
                except Exception:
                    # Only real errors count as "docker missing"; KeyboardInterrupt and
                    # SystemExit are allowed to propagate instead of being misreported.
                    logs.extend([
                        "[red]Docker not installed. Install instructions:\n",
                        "   Linux: https://docs.docker.com/engine/install/\n",
                        "   Windows/MacOS: https://docs.docker.com/desktop/\n",
                    ])

                return {"issues": logs}
         | 
| 128 | 
            +
             | 
| 129 | 
            +
            def check_worker_compatibility():
                """Check required packages to join pools.

                Runs `docker version` through run_cmd and records install instructions
                when that call raises.

                Returns:
                    dict with an `issues` list of message strings; the list is empty
                    when the docker check does not raise.
                """
                logs = []
                # docker
                try:
                    run_cmd("docker version >/dev/null 2>&1")
                except Exception:
                    # Only real errors count as "docker missing"; KeyboardInterrupt and
                    # SystemExit are allowed to propagate instead of being misreported.
                    logs.extend([
                        "[red]Docker not installed. Install instructions:\n",
                        "   Linux: https://docs.docker.com/engine/install/\n",
                        "   Windows/MacOS: https://docs.docker.com/desktop/\n",
                    ])

                return {"issues": logs}
         | 
| 141 | 
            +
             | 
| 142 | 
            +
            def get_ip_addresses(subnet=None):
                """Collect the IPv4 addresses of the local network interfaces.

                Each interface's first AF_INET address is considered; addresses listed
                in FORBIDEDEN_IPS are skipped and, when `subnet` is given, only
                addresses inside that network are kept. If nothing matches, the scan is
                repeated up to 3 more times with a 2 second pause between attempts.

                Args:
                    subnet: optional network in a form accepted by
                        `ipaddress.ip_network` (e.g. "192.168.0.0/24"); None accepts
                        every address.

                Returns:
                    The single matching address as a string when exactly one is found,
                    otherwise the list of all matching addresses.

                Raises:
                    ValueError: when no address matches after all retries, or when
                        `subnet` is not a valid network (raised immediately rather
                        than after exhausting the retries).
                """
                # Parse the subnet once; it does not change between interfaces or retries.
                network = None if subnet is None else ipaddress.ip_network(subnet)
                retry = 3
                while True:
                    ips = []
                    for iface in ni.interfaces():
                        try:
                            ip = ni.ifaddresses(iface)[ni.AF_INET][0]['addr']
                            if ip in FORBIDEDEN_IPS:
                                continue
                            if network is None or ipaddress.ip_address(ip) in network:
                                ips.append(ip)
                        except Exception:
                            # Best effort: interfaces without a usable IPv4 address
                            # are simply ignored.
                            continue
                    if ips:
                        # Return as soon as anything is found, so a multi-address
                        # result is neither delayed nor lost on the last retry.
                        return ips[0] if len(ips) == 1 else ips
                    time.sleep(2)
                    retry -= 1
                    if retry < 0:
                        raise ValueError(f"No IPs available on subnet {subnet}")
         | 
| 162 | 
            +
             | 
| 39 163 | 
             
            def fetch_resources():
         | 
| 40 164 | 
             
                try:
         | 
| 41 165 | 
             
                    total = request_to_server(
         | 
| @@ -57,6 +181,35 @@ def fetch_resources(): | |
| 57 181 |  | 
| 58 182 | 
             
                return {"total": total, "available": available}
         | 
| 59 183 |  | 
| 184 | 
            +
            def fetch_job_defaults(name):
                """Fetch the default values of a job template from the server.

                Args:
                    name: template name, sent to `/v1/job_defaults` under the
                        `template` key.

                Returns:
                    Whatever request_to_server returns, or a dict with an `error`
                    message when the request raises.
                """
                try:
                    return request_to_server(
                        method="get",
                        endpoint="/v1/job_defaults",
                        data={"template": name},
                        server_creds=USER_LOCAL_SERVER_FILE,
                        user_cookie=USER_COOKIE
                    )
                except Exception as e:
                    return {"error": str(e)}
         | 
| 199 | 
            +
                
         | 
| 200 | 
            +
            def fetch_job_templates():
                """Fetch the job templates exposed by `/v1/get_job_templates`.

                Returns:
                    Whatever request_to_server returns, or a dict with an `error`
                    message when the request raises.
                """
                try:
                    return request_to_server(
                        method="get",
                        endpoint="/v1/get_job_templates",
                        server_creds=USER_LOCAL_SERVER_FILE,
                        data=None,
                        user_cookie=USER_COOKIE
                    )
                except Exception as e:
                    return {"error": str(e)}
         | 
| 212 | 
            +
             | 
| 60 213 | 
             
            def fetch_job_names():
         | 
| 61 214 | 
             
                data = {
         | 
| 62 215 | 
             
                    "group": "batch.volcano.sh",
         | 
| @@ -126,11 +279,18 @@ def fetch_job_details(jobs: list[Job]): | |
| 126 279 | 
             
                        node_ports = [f"{p['node_port']} (mapped to {p['port']})" for s in result.values() for p in s["ports"]]
         | 
| 127 280 |  | 
| 128 281 | 
             
                        urls = [f"http://{load_server_info(data_key=SERVER_IP_KEY, file=USER_LOCAL_SERVER_FILE)}:{node_port}" for node_port in node_ports]
         | 
| 282 | 
            +
                        if "Ready" in workers_status and len(workers_status) == 1:
         | 
| 283 | 
            +
                            status = "running"
         | 
| 284 | 
            +
                        elif any([st in workers_status for st in ["Failed", "Completed"]]):
         | 
| 285 | 
            +
                            status = "error"
         | 
| 286 | 
            +
                        else:
         | 
| 287 | 
            +
                            status = "pending"
         | 
| 129 288 | 
             
                        job_details.append(
         | 
| 130 289 | 
             
                            Job(owner=namespace,
         | 
| 131 290 | 
             
                                name=deployment,
         | 
| 132 291 | 
             
                                workers=workers,
         | 
| 133 | 
            -
                                endpoint="\n".join(urls) | 
| 292 | 
            +
                                endpoint="\n".join(urls),
         | 
| 293 | 
            +
                                status=str(status))
         | 
| 134 294 | 
             
                        )
         | 
| 135 295 |  | 
| 136 296 | 
             
                    except Exception as e:
         | 
| @@ -138,6 +298,47 @@ def fetch_job_details(jobs: list[Job]): | |
| 138 298 |  | 
| 139 299 | 
             
                return job_details
         | 
| 140 300 |  | 
| 301 | 
            +
            def deploy_job(template_name, values_dict, force_namespace=None):
                """Deploy a job template through the `/v1/deploy_job` endpoint.

                Args:
                    template_name: sent to the server under the `template` key.
                    values_dict: sent to the server under the `template_values` key.
                    force_namespace: optional value forwarded under `force_namespace`;
                        left out of the request when None.

                Returns:
                    Whatever request_to_server returns, or a dict with an `error`
                    message when the request raises.
                """
                payload = {
                    "template": template_name,
                    "template_values": values_dict
                }
                if force_namespace is not None:
                    payload["force_namespace"] = force_namespace

                try:
                    return request_to_server(
                        method="post",
                        endpoint="/v1/deploy_job",
                        data=payload,
                        server_creds=USER_LOCAL_SERVER_FILE,
                        user_cookie=USER_COOKIE
                    )
                except Exception as e:
                    return {"error": str(e)}
         | 
| 322 | 
            +
                
         | 
| 323 | 
            +
            def delete_job(name, force_namespace=None):
                """Delete the resources labeled with a job name.

                Posts to `/v1/delete_labeled_resources`, selecting resources whose
                TEMPLATE_LABEL equals `name`.

                Args:
                    name: job name used as the label value.
                    force_namespace: optional value forwarded under `force_namespace`;
                        left out of the request when None.

                Returns:
                    Whatever request_to_server returns, or a dict with an `error`
                    message when the request raises.
                """
                # selecting by TEMPLATE_LABEL ensures that both lws template and services are deleted
                payload = {"label": TEMPLATE_LABEL, "value": name}
                if force_namespace is not None:
                    payload["force_namespace"] = force_namespace

                try:
                    return request_to_server(
                        method="post",
                        endpoint="/v1/delete_labeled_resources",
                        data=payload,
                        server_creds=USER_LOCAL_SERVER_FILE,
                        user_cookie=USER_COOKIE
                    )
                except Exception as e:
                    return {"error": str(e)}
         | 
| 341 | 
            +
             | 
| 141 342 | 
             
            def fetch_devices():
         | 
| 142 343 | 
             
                """Load devices status info for all hosts"""
         | 
| 143 344 | 
             
                try:
         | 
| @@ -224,4 +425,454 @@ def fetch_gpus(available=False): | |
| 224 425 | 
             
                    return all_gpus
         | 
| 225 426 |  | 
| 226 427 | 
             
                except Exception as e:
         | 
| 227 | 
            -
                    return {"error": str(e)}
         | 
| 428 | 
            +
                    return {"error": str(e)}
         | 
| 429 | 
            +
             | 
| 430 | 
            +
            def load_user_session():
                """Return the result of `KalavaiAuthClient.load_user_session` for USER_COOKIE."""
                return KalavaiAuthClient(user_cookie_file=USER_COOKIE).load_user_session()
         | 
| 435 | 
            +
                
         | 
| 436 | 
            +
            def authenticate_user(username=None, password=None):
                """Return the current user, logging in when no session is stored.

                Args:
                    username: passed to `auth.login` when no session could be loaded.
                    password: passed to `auth.login` when no session could be loaded.

                Returns:
                    The user object from the loaded session or from the login call, or
                    a dict with an `error` message when both yield None.
                """
                client = KalavaiAuthClient(user_cookie_file=USER_COOKIE)

                session_user = client.load_user_session()
                if session_user is not None:
                    return session_user

                # no stored session: fall back to an explicit login
                logged_user = client.login(username=username, password=password)
                if logged_user is None:
                    return {"error": "Username or password incorrect"}
                return logged_user
         | 
| 447 | 
            +
             | 
| 448 | 
            +
            def user_logout():
                """Call `logout` on the auth client bound to USER_COOKIE; always returns True."""
                KalavaiAuthClient(user_cookie_file=USER_COOKIE).logout()
                return True
         | 
| 454 | 
            +
             | 
| 455 | 
            +
            def check_token(token, public=False):
                """Validate that a join token decodes and carries the mandatory fields.

                Args:
                    token: encoded token, decoded with decode_dict.
                    public: when True, additionally require a non-None value under
                        PUBLIC_LOCATION_KEY.

                Returns:
                    `{"status": True}` when the token is valid, otherwise a dict with
                    an `error` message describing the problem.
                """
                try:
                    data = decode_dict(token)
                    # Explicit check instead of `assert`: assertions are stripped under
                    # `python -O` and produce an empty error message.
                    missing = [field for field in MANDATORY_TOKEN_FIELDS if field not in data]
                    if missing:
                        raise ValueError(
                            f"Token is missing mandatory fields: {', '.join(str(field) for field in missing)}")
                    if public and data[PUBLIC_LOCATION_KEY] is None:
                        raise ValueError("Token is not valid for public pools. Did you start the cluster with a public_location?")
                    return {"status": True}
                except Exception as e:
                    return {"error": str(e)}
         | 
| 466 | 
            +
             | 
| 467 | 
            +
            def attach_to_pool(token, node_name=None):
                """Attach this machine to an existing pool using a join token.

                Validates and decodes the token, resolves VPN details for public pools,
                writes the compose recipe and the local server info, starts the compose
                stack and then waits until the watcher service responds.

                Args:
                    token: encoded join token (see check_token).
                    node_name: name for this node; defaults to
                        "<hostname>-<6 hex chars>" when None.

                Returns:
                    The cluster name from the token on success, or a dict with an
                    `error` message when validation, authentication or the public
                    network join fails.
                """
                if node_name is None:
                    node_name = f"{socket.gethostname()}-{uuid.uuid4().hex[:6]}"

                # check token
                valid = check_token(token=token)
                if "error" in valid:
                    return {"error": f"Invalid token: {valid}"}

                try:
                    data = decode_dict(token)
                    kalavai_seed_ip = data[CLUSTER_IP_KEY]
                    cluster_name = data[CLUSTER_NAME_KEY]
                    auth_key = data[AUTH_KEY]
                    watcher_service = data[WATCHER_SERVICE_KEY]
                    public_location = data[PUBLIC_LOCATION_KEY]
                    # defaults to None for every key so vpn["key"] is safe on private pools
                    vpn = defaultdict(lambda: None)
                except Exception as e:
                    return {"error": f"Invalid token. {str(e)}"}

                # same trick as for vpn: user["api_key"] resolves to None on private pools
                user = defaultdict(lambda: None)
                if public_location is not None:
                    user = load_user_session()
                    if user is None:
                        # key must be exactly "error" so callers checking
                        # `"error" in result` detect the failure
                        return {"error": "Must be logged in to join public pools"}
                    try:
                        vpn = get_vpn_details(
                            location=public_location,
                            user_cookie=USER_COOKIE)
                    except Exception as e:
                        return {"error": f"Are you authenticated? {str(e)}"}
                    try:
                        validate_join_public_seed(
                            cluster_name=cluster_name,
                            join_key=token,
                            user_cookie=USER_COOKIE
                        )
                    except Exception as e:
                        return {"error": f"Error when joining network: {str(e)}"}

                # local agent join
                # 1. Generate local cache files
                # Generate docker compose recipe
                generate_compose_config(
                    role="",
                    vpn_token=vpn["key"],
                    node_name=node_name,
                    is_public=public_location is not None)

                store_server_info(
                    server_ip=kalavai_seed_ip,
                    auth_key=auth_key,
                    file=USER_LOCAL_SERVER_FILE,
                    watcher_service=watcher_service,
                    node_name=node_name,
                    cluster_name=cluster_name,
                    public_location=public_location,
                    user_api_key=user["api_key"])

                run_cmd(f"docker compose -f {USER_COMPOSE_FILE} up -d")
                # ensure we are connected
                # NOTE(review): polls every 30 seconds with no upper bound; this call
                # blocks for as long as the watcher is unreachable.
                while True:
                    time.sleep(30)
                    if is_watcher_alive(server_creds=USER_LOCAL_SERVER_FILE, user_cookie=USER_COOKIE):
                        break

                return cluster_name
         | 
| 534 | 
            +
             | 
| 535 | 
            +
            def join_pool(token, num_gpus=0, node_name=None):
         | 
| 536 | 
            +
                compatibility = check_worker_compatibility()
         | 
| 537 | 
            +
                if len(compatibility["issues"]) > 0:
         | 
| 538 | 
            +
                    return {"error": compatibility["issues"]}
         | 
| 539 | 
            +
                
         | 
| 540 | 
            +
                if node_name is None:
         | 
| 541 | 
            +
                    node_name = f"{socket.gethostname()}-{uuid.uuid4().hex[:6]}"
         | 
| 542 | 
            +
                
         | 
| 543 | 
            +
                # check token
         | 
| 544 | 
            +
                valid = check_token(token=token)
         | 
| 545 | 
            +
                if "error" in valid:
         | 
| 546 | 
            +
                    return {"error": f"Invalid token: {valid}"}
         | 
| 547 | 
            +
             | 
| 548 | 
            +
                try:
         | 
| 549 | 
            +
                    data = decode_dict(token)
         | 
| 550 | 
            +
                    kalavai_seed_ip = data[CLUSTER_IP_KEY]
         | 
| 551 | 
            +
                    kalavai_token = data[CLUSTER_TOKEN_KEY]
         | 
| 552 | 
            +
                    cluster_name = data[CLUSTER_NAME_KEY]
         | 
| 553 | 
            +
                    auth_key = data[AUTH_KEY]
         | 
| 554 | 
            +
                    watcher_service = data[WATCHER_SERVICE_KEY]
         | 
| 555 | 
            +
                    public_location = data[PUBLIC_LOCATION_KEY]
         | 
| 556 | 
            +
                    vpn = defaultdict(lambda: None)
         | 
| 557 | 
            +
                except Exception as e:
         | 
| 558 | 
            +
                    return {"error": f"Invalid token. {str(e)}"} 
         | 
| 559 | 
            +
                
         | 
| 560 | 
            +
                # join private network if provided
         | 
| 561 | 
            +
                node_labels = {
         | 
| 562 | 
            +
                    STORAGE_CLASS_LABEL: is_storage_compatible()
         | 
| 563 | 
            +
                }
         | 
| 564 | 
            +
                user = defaultdict(lambda: None)
         | 
| 565 | 
            +
                if public_location is not None:
         | 
| 566 | 
            +
                    user = authenticate_user()
         | 
| 567 | 
            +
                    if user is None:
         | 
| 568 | 
            +
                        return {"error": "Must be logged in to join public pools"}
         | 
| 569 | 
            +
                    try:
         | 
| 570 | 
            +
                        vpn = get_vpn_details(
         | 
| 571 | 
            +
                            location=public_location,
         | 
| 572 | 
            +
                            user_cookie=USER_COOKIE)
         | 
| 573 | 
            +
                        node_labels[USER_NODE_LABEL] = user["username"]
         | 
| 574 | 
            +
                    except Exception as e:
         | 
| 575 | 
            +
                        return {"error": f"Are you authenticated? Error: {str(e)}"}
         | 
| 576 | 
            +
                    try:
         | 
| 577 | 
            +
                        validate_join_public_seed(
         | 
| 578 | 
            +
                            cluster_name=cluster_name,
         | 
| 579 | 
            +
                            join_key=token,
         | 
| 580 | 
            +
                            user_cookie=USER_COOKIE
         | 
| 581 | 
            +
                        )
         | 
| 582 | 
            +
                    except Exception as e:
         | 
| 583 | 
            +
                        return {"error": f"Error when joining network: {str(e)}"}
         | 
| 584 | 
            +
                    
         | 
| 585 | 
            +
                # local agent join
         | 
| 586 | 
            +
                # Generate docker compose recipe
         | 
| 587 | 
            +
                generate_compose_config(
         | 
| 588 | 
            +
                    role="agent",
         | 
| 589 | 
            +
                    pool_ip=f"https://{kalavai_seed_ip}:6443",
         | 
| 590 | 
            +
                    pool_token=kalavai_token,
         | 
| 591 | 
            +
                    num_gpus=num_gpus,
         | 
| 592 | 
            +
                    vpn_token=vpn["key"],
         | 
| 593 | 
            +
                    node_name=node_name,
         | 
| 594 | 
            +
                    node_labels=node_labels,
         | 
| 595 | 
            +
                    is_public=public_location is not None)
         | 
| 596 | 
            +
                
         | 
| 597 | 
            +
                store_server_info(
         | 
| 598 | 
            +
                    server_ip=kalavai_seed_ip,
         | 
| 599 | 
            +
                    auth_key=auth_key,
         | 
| 600 | 
            +
                    file=USER_LOCAL_SERVER_FILE,
         | 
| 601 | 
            +
                    watcher_service=watcher_service,
         | 
| 602 | 
            +
                    node_name=node_name,
         | 
| 603 | 
            +
                    cluster_name=cluster_name,
         | 
| 604 | 
            +
                    public_location=public_location,
         | 
| 605 | 
            +
                    user_api_key=user["api_key"])
         | 
| 606 | 
            +
                
         | 
| 607 | 
            +
                try:
         | 
| 608 | 
            +
                    CLUSTER.start_worker_node()
         | 
| 609 | 
            +
                except Exception as e:
         | 
| 610 | 
            +
                    return {"error": f"Error connecting to {cluster_name} @ {kalavai_seed_ip}. Check with the admin if the token is still valid."}
         | 
| 611 | 
            +
             | 
| 612 | 
            +
                # ensure we are connected
         | 
| 613 | 
            +
                while True:
         | 
| 614 | 
            +
                    time.sleep(30)
         | 
| 615 | 
            +
                    if is_watcher_alive(server_creds=USER_LOCAL_SERVER_FILE, user_cookie=USER_COOKIE):
         | 
| 616 | 
            +
                        break
         | 
| 617 | 
            +
                
         | 
| 618 | 
            +
                # check the node has connected successfully
         | 
| 619 | 
            +
                try:
         | 
| 620 | 
            +
                    while not CLUSTER.is_agent_running():
         | 
| 621 | 
            +
                        time.sleep(30)
         | 
| 622 | 
            +
                except KeyboardInterrupt:
         | 
| 623 | 
            +
                    return {"error": "Installation aborted. Leaving pool."}
         | 
| 624 | 
            +
                
         | 
| 625 | 
            +
                result = init_user_workspace()
         | 
| 626 | 
            +
                if "error" in result:
         | 
| 627 | 
            +
                    return {"error": f"Error when creating user workspace: {result}"}
         | 
| 628 | 
            +
                
         | 
| 629 | 
            +
                return cluster_name
         | 
| 630 | 
            +
             | 
| 631 | 
            +
            def create_pool(cluster_name: str, ip_address: str, app_values: str=None, pool_config_values: str=None, num_gpus: int=0, node_name: str=None, only_registered_users: bool=False, location: str=None):
                """Create a new pool using the local machine as the seed (server) node.

                Steps, in order: generate the docker compose recipe, start the seed
                node, store the local server credentials, render and apply the helm
                apps file, wait for the watcher service to respond, deploy the pool
                configuration and initialise the workspaces.

                Note that the two waiting loops (agent running, watcher alive) have no
                timeout, so this call blocks until the services come up.

                Args:
                    cluster_name: Name recorded for the new pool.
                    ip_address: IP address of the seed node. Replaced by
                        ``CLUSTER.get_vpn_ip()`` when it is None or a location is given.
                    app_values: Default values path for the helm apps template;
                        falls back to ``HELM_APPS_VALUES`` when None.
                    pool_config_values: Values path handed to ``pool_init``; falls back
                        to ``POOL_CONFIG_DEFAULT_VALUES`` when None.
                    num_gpus: Forwarded to ``generate_compose_config``.
                    node_name: Name for the seed node. When None, a name is generated
                        from the hostname plus a random 6-character hex suffix.
                    only_registered_users: Requires the host to be authenticated; the
                        inverse is written under ``ALLOW_UNREGISTERED_USER_KEY`` and a
                        user workspace is initialised in addition to the default one.
                    location: Public location. Requires the host to be authenticated
                        and triggers fetching VPN details for that location.

                Returns:
                    ``{"success"}`` (a set) on success, or a dict with an ``"error"``
                    key describing the failure. Exceptions raised by
                    ``CLUSTER.start_seed_node()`` are not caught here.
                """
                if not check_seed_compatibility():
                    return {"error": "Requirements failed"}

                if app_values is None:
                    app_values = HELM_APPS_VALUES

                if pool_config_values is None:
                    pool_config_values = POOL_CONFIG_DEFAULT_VALUES

                # Honour a caller-supplied node name; only generate one when none is
                # given (previously the argument was silently overwritten).
                if node_name is None:
                    node_name = f"{socket.gethostname()}-{uuid.uuid4().hex[:6]}"

                # if only registered users are allowed, check user has logged in
                user = defaultdict(lambda: None)
                if only_registered_users or location is not None:
                    user = authenticate_user()
                    if user is None:
                        return {"error": "[white]--only-registered-users [red]or [white]--location[red] can only be used if the host is authenticated. Run [yellow]kalavai login[red] to authenticate"}

                # join private network if provided
                vpn = defaultdict(lambda: None)
                node_labels = {
                    STORAGE_CLASS_LABEL: is_storage_compatible()
                }
                if location is not None:
                    try:
                        vpn = get_vpn_details(
                            location=location,
                            user_cookie=USER_COOKIE)
                        node_labels[USER_NODE_LABEL] = user["username"]
                    except Exception as e:
                        return {"error": f"[red]Error when joining network: {str(e)}"}

                # Generate docker compose recipe
                generate_compose_config(
                    role="server",
                    vpn_token=vpn["key"],
                    node_ip_address=ip_address,
                    num_gpus=num_gpus,
                    node_name=node_name,
                    node_labels=node_labels,
                    is_public=location is not None
                )

                # start server and block until the local agent reports running
                CLUSTER.start_seed_node()

                while not CLUSTER.is_agent_running():
                    time.sleep(10)

                # select IP address (for external discovery)
                if ip_address is None or location is not None:
                    # load VPN ip
                    ip_address = CLUSTER.get_vpn_ip()

                # populate local cred files
                auth_key = str(uuid.uuid4())
                write_auth_key = str(uuid.uuid4())
                readonly_auth_key = str(uuid.uuid4())

                watcher_service = f"{ip_address}:{DEFAULT_WATCHER_PORT}"
                values = {
                    CLUSTER_NAME_KEY: cluster_name,
                    CLUSTER_IP_KEY: ip_address,
                    AUTH_KEY: auth_key,
                    READONLY_AUTH_KEY: readonly_auth_key,
                    WRITE_AUTH_KEY: write_auth_key,
                    WATCHER_PORT_KEY: DEFAULT_WATCHER_PORT,
                    WATCHER_SERVICE_KEY: watcher_service,
                    USER_NODE_LABEL_KEY: USER_NODE_LABEL,
                    ALLOW_UNREGISTERED_USER_KEY: not only_registered_users
                }

                store_server_info(
                    server_ip=ip_address,
                    auth_key=auth_key,
                    readonly_auth_key=readonly_auth_key,
                    write_auth_key=write_auth_key,
                    file=USER_LOCAL_SERVER_FILE,
                    watcher_service=watcher_service,
                    node_name=node_name,
                    cluster_name=cluster_name,
                    public_location=location,
                    user_api_key=user["api_key"])

                # Generate helmfile recipe
                helm_yaml = load_template(
                    template_path=HELM_APPS_FILE,
                    values=values,
                    default_values_path=app_values,
                    force_defaults=True)
                with open(USER_HELM_APPS_FILE, "w") as f:
                    f.write(helm_yaml)

                # set template values in helmfile
                try:
                    CLUSTER.update_dependencies(
                        dependencies_file=USER_HELM_APPS_FILE
                    )
                except Exception as e:
                    return {"error": f"Error when updating dependencies: {str(e)}"}

                # TODO: register with kalavai if it's a public cluster
                # (i.e. when location is not None)

                # wait until the server is ready to create objects
                while True:
                    time.sleep(30)
                    if is_watcher_alive(server_creds=USER_LOCAL_SERVER_FILE, user_cookie=USER_COOKIE):
                        break

                result = pool_init(pool_config_values_path=pool_config_values)
                # pool_init returns None when it has no config values to deploy;
                # guard against it so the membership tests below cannot raise.
                if result is not None and ("error" in result or ("failed" in result and len(result['failed']) > 0)):
                    return {"error": f"Error when initialising pool: {result}"}
                # init default namespace
                init_user_workspace(force_namespace="default")
                if only_registered_users:
                    # init user namespace
                    init_user_workspace()

                return {"success"}
         | 
| 755 | 
            +
             | 
| 756 | 
            +
            def pool_init(pool_config_values_path=None):
                """Deploy the configured objects that initialise the pool.

                Renders ``POOL_CONFIG_TEMPLATE`` with the defaults found at
                ``pool_config_values_path`` and posts the rendered config to the
                ``/v1/deploy_generic_model`` endpoint.

                Returns:
                    None when no values path is given; otherwise the server response,
                    or a dict with an ``"error"`` key if the request raised.
                """
                if pool_config_values_path is None:
                    return None

                # load template config and populate with values
                rendered_config = load_template(
                    template_path=POOL_CONFIG_TEMPLATE,
                    values={},
                    default_values_path=pool_config_values_path)

                try:
                    return request_to_server(
                        method="post",
                        endpoint="/v1/deploy_generic_model",
                        data={"config": rendered_config},
                        server_creds=USER_LOCAL_SERVER_FILE,
                        user_cookie=USER_COOKIE
                    )
                except Exception as err:
                    return {"error": f"[red]Error when connecting to kalavai service: {str(err)}"}
         | 
| 778 | 
            +
             | 
| 779 | 
            +
            def is_connected():
                """Return whether the watcher service answers for the stored server.

                Returns False straight away when the local server credentials file
                does not exist; otherwise returns the result of ``is_watcher_alive``
                called with a timeout of 10.
                """
                if os.path.isfile(USER_LOCAL_SERVER_FILE):
                    return is_watcher_alive(
                        server_creds=USER_LOCAL_SERVER_FILE,
                        user_cookie=USER_COOKIE,
                        timeout=10)
                return False
         | 
| 783 | 
            +
             | 
| 784 | 
            +
            def is_agent_running():
                """Return the value reported by ``CLUSTER.is_agent_running()``."""
                agent_running = CLUSTER.is_agent_running()
                return agent_running
         | 
| 786 | 
            +
             | 
| 787 | 
            +
            def is_server():
                """Return the value reported by ``CLUSTER.is_seed_node()``."""
                seed_node = CLUSTER.is_seed_node()
                return seed_node
         | 
| 789 | 
            +
             | 
| 790 | 
            +
            def pause_agent(retries=3):
                """Pause the local agent, retrying up to ``retries`` times.

                Calls ``CLUSTER.pause_agent()`` until it returns a truthy state,
                waiting 5 seconds between attempts.

                Args:
                    retries: Maximum number of attempts.

                Returns:
                    ``{"success"}`` (a set) once an attempt succeeds; otherwise
                    ``{"error": "Could not pause agent"}`` — when an attempt raises,
                    or when every attempt fails (previously that case fell through
                    and returned None).
                """
                failure = {"error": "Could not pause agent"}
                remaining = retries
                try:
                    while remaining > 0:
                        if CLUSTER.pause_agent():
                            return {"success"}
                        remaining -= 1
                        # no point in sleeping after the final attempt
                        if remaining > 0:
                            time.sleep(5)
                except Exception:
                    # Exception (not a bare except) so Ctrl-C still propagates
                    return failure
                return failure
         | 
| 800 | 
            +
             | 
| 801 | 
            +
            def resume_agent(retries=3):
                """Resume the local agent, retrying up to ``retries`` times.

                Calls ``CLUSTER.restart_agent()`` until it returns a truthy state,
                waiting 5 seconds between attempts.

                Args:
                    retries: Maximum number of attempts.

                Returns:
                    ``{"success"}`` (a set) once an attempt succeeds; otherwise
                    ``{"error": "Could not resume agent"}`` — when an attempt raises,
                    or when every attempt fails (previously that case fell through
                    and returned None).
                """
                failure = {"error": "Could not resume agent"}
                remaining = retries
                try:
                    while remaining > 0:
                        if CLUSTER.restart_agent():
                            return {"success"}
                        remaining -= 1
                        # no point in sleeping after the final attempt
                        if remaining > 0:
                            time.sleep(5)
                except Exception:
                    # Exception (not a bare except) so Ctrl-C still propagates
                    return failure
                return failure
         | 
| 811 | 
            +
             | 
| 812 | 
            +
            def cleanup_local():
                """Remove the local pool artefacts via ``safe_remove``, in a fixed order."""
                local_artefacts = (
                    CONTAINER_HOST_PATH,
                    USER_COMPOSE_FILE,
                    USER_VPN_COMPOSE_FILE,
                    USER_HELM_APPS_FILE,
                    USER_KUBECONFIG_FILE,
                    USER_LOCAL_SERVER_FILE,
                    USER_TEMPLATES_FOLDER,
                )
                for artefact in local_artefacts:
                    safe_remove(artefact)
         | 
| 820 | 
            +
             | 
| 821 | 
            +
            def delete_node(name):
                """Ask the server to delete the node called ``name``.

                Posts ``{"node_names": [name]}`` to ``/v1/delete_nodes``.

                Returns:
                    A one-element set with a confirmation message when the server
                    answers None or True; otherwise a dict with an ``"error"`` key
                    holding the server response or the text of the raised exception.
                """
                payload = {"node_names": [name]}
                try:
                    response = request_to_server(
                        method="post",
                        endpoint="/v1/delete_nodes",
                        data=payload,
                        server_creds=USER_LOCAL_SERVER_FILE,
                        user_cookie=USER_COOKIE
                    )
                    deleted = response is None or response is True
                    if not deleted:
                        return {"error": response}
                    return {f"Node {name} deleted successfully"}
                except Exception as err:
                    return {"error": str(err)}
         | 
| 839 | 
            +
             | 
| 840 | 
            +
            def stop_pool(skip_node_deletion=False):
                """Leave the current pool and clean up the local machine.

                Order of operations: optionally delete the local node from the
                server, leave the VPN (best effort), remove the local agent, then
                remove the local files.

                Args:
                    skip_node_deletion: When True, the local node is not deleted from
                        the server before tearing down.

                Returns:
                    A list of log entries: the result of ``delete_node`` (unless
                    skipped) followed by one message per VPN that was left.
                """
                logs = []
                # delete local node from server
                if not skip_node_deletion:
                    local_node = load_server_info(data_key=NODE_NAME_KEY, file=USER_LOCAL_SERVER_FILE)
                    logs.append(delete_node(local_node))

                # TODO: unpublishing a public pool (seed node only) should be done
                # via the platform, not here.

                # disconnect from VPN first, then remove agent, then remove local files
                try:
                    vpns = leave_vpn(container_name=DEFAULT_VPN_CONTAINER_NAME)
                    if vpns is not None:
                        for vpn in vpns:
                            logs.append(f"You have left {vpn} VPN")
                except Exception:
                    # Best effort: there may be no VPN to leave. Exception (not a
                    # bare except) so that Ctrl-C / SystemExit still propagate.
                    pass

                # remove local node agent
                CLUSTER.remove_agent()

                # clean local files
                cleanup_local()

                return logs
         | 
| 875 | 
            +
             | 
| 876 | 
            +
            def list_available_pools(user_only=False):
                """Return the result of ``get_public_seeds`` for the current user cookie.

                Args:
                    user_only: Forwarded unchanged to ``get_public_seeds``.
                """
                return get_public_seeds(user_only=user_only, user_cookie=USER_COOKIE)
         |