gen-worker 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gen_worker/__init__.py +19 -0
- gen_worker/decorators.py +66 -0
- gen_worker/default_model_manager/__init__.py +5 -0
- gen_worker/downloader.py +84 -0
- gen_worker/entrypoint.py +135 -0
- gen_worker/errors.py +10 -0
- gen_worker/model_interface.py +48 -0
- gen_worker/pb/__init__.py +27 -0
- gen_worker/pb/frontend_pb2.py +53 -0
- gen_worker/pb/frontend_pb2_grpc.py +189 -0
- gen_worker/pb/worker_scheduler_pb2.py +69 -0
- gen_worker/pb/worker_scheduler_pb2_grpc.py +100 -0
- gen_worker/py.typed +0 -0
- gen_worker/testing/__init__.py +1 -0
- gen_worker/testing/stub_manager.py +69 -0
- gen_worker/torch_manager/__init__.py +4 -0
- gen_worker/torch_manager/manager.py +2059 -0
- gen_worker/torch_manager/utils/base_types/architecture.py +145 -0
- gen_worker/torch_manager/utils/base_types/common.py +52 -0
- gen_worker/torch_manager/utils/base_types/config.py +46 -0
- gen_worker/torch_manager/utils/config.py +321 -0
- gen_worker/torch_manager/utils/db/database.py +46 -0
- gen_worker/torch_manager/utils/device.py +26 -0
- gen_worker/torch_manager/utils/diffusers_fix.py +10 -0
- gen_worker/torch_manager/utils/flashpack_loader.py +262 -0
- gen_worker/torch_manager/utils/globals.py +59 -0
- gen_worker/torch_manager/utils/load_models.py +238 -0
- gen_worker/torch_manager/utils/local_cache.py +340 -0
- gen_worker/torch_manager/utils/model_downloader.py +763 -0
- gen_worker/torch_manager/utils/parse_cli.py +98 -0
- gen_worker/torch_manager/utils/paths.py +22 -0
- gen_worker/torch_manager/utils/repository.py +141 -0
- gen_worker/torch_manager/utils/utils.py +43 -0
- gen_worker/types.py +47 -0
- gen_worker/worker.py +1720 -0
- gen_worker-0.1.4.dist-info/METADATA +113 -0
- gen_worker-0.1.4.dist-info/RECORD +38 -0
- gen_worker-0.1.4.dist-info/WHEEL +4 -0
gen_worker/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
# Public API of the gen_worker package: re-export the user-facing names below.
|
|
2
|
+
from .decorators import worker_function, ResourceRequirements
|
|
3
|
+
from .worker import ActionContext
|
|
4
|
+
from .errors import RetryableError, FatalError
|
|
5
|
+
from .types import Asset
|
|
6
|
+
from .model_interface import ModelManager
|
|
7
|
+
from .downloader import ModelDownloader, CozyHubDownloader
|
|
8
|
+
|
|
9
|
+
# Names exported by ``from gen_worker import *``, grouped by source module.
__all__ = [
    # from .decorators
    "worker_function",
    "ResourceRequirements",
    # from .worker
    "ActionContext",
    # from .errors
    "RetryableError",
    "FatalError",
    # from .types
    "Asset",
    # from .model_interface
    "ModelManager",
    # from .downloader
    "ModelDownloader",
    "CozyHubDownloader",
]
|
gen_worker/decorators.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
from typing import Callable, Any, Dict, Optional, TypeVar, cast
|
|
3
|
+
|
|
4
|
+
# Type of the decorated callable; lets worker_function return exactly the
# type it was given so static checkers keep the original signature.
F = TypeVar("F", bound=Callable[..., Any])


class ResourceRequirements:
    """
    Specifies the resource requirements for a worker function.

    Only options that were explicitly given (i.e. are not None) are recorded
    in the dictionary returned by ``to_dict``; ``requires_gpu`` and
    ``expects_pipeline_arg`` are booleans and therefore always recorded.
    """
    def __init__(
        self,
        model_name: Optional[str] = None,
        model_family: Optional[str] = None,
        min_vram_gb: Optional[float] = None,
        recommended_vram_gb: Optional[float] = None,
        requires_gpu: bool = False,
        expects_pipeline_arg: bool = False,
        max_concurrency: Optional[int] = None
        # Add other potential requirements here:
        # e.g., cpu_cores: Optional[int] = None,
        # specific_accelerators: Optional[list[str]] = None,
        # etc.
    ) -> None:
        self.model_name = model_name
        self.model_family = model_family
        self.min_vram_gb = min_vram_gb
        self.recommended_vram_gb = recommended_vram_gb
        self.requires_gpu = requires_gpu
        self.expects_pipeline_arg = expects_pipeline_arg
        self.max_concurrency = max_concurrency
        # Snapshot of the defined (non-None) requirements, in parameter order.
        # Built from an explicit list rather than locals() so that adding a
        # helper variable to __init__ can never leak into the requirements.
        self._requirements: Dict[str, Any] = {
            name: value
            for name, value in (
                ("model_name", model_name),
                ("model_family", model_family),
                ("min_vram_gb", min_vram_gb),
                ("recommended_vram_gb", recommended_vram_gb),
                ("requires_gpu", requires_gpu),
                ("expects_pipeline_arg", expects_pipeline_arg),
                ("max_concurrency", max_concurrency),
            )
            if value is not None
        }

    def to_dict(self) -> Dict[str, Any]:
        """Return a dictionary of the defined requirements.

        A fresh copy is returned each time, so callers may modify it without
        altering this object.
        """
        return dict(self._requirements)

    def __repr__(self) -> str:
        return f"ResourceRequirements({self._requirements})"


def worker_function(
    resources: Optional[ResourceRequirements] = None,
) -> Callable[[F], F]:
    """
    Decorator to mark a function as a worker task and associate resource requirements.

    The decorated function is returned unchanged (same object), with two
    attributes added: ``_is_worker_function = True`` and
    ``_worker_resources`` (the ResourceRequirements instance).

    Args:
        resources: An optional ResourceRequirements object describing the function's needs.
            When omitted, an empty ResourceRequirements() is attached.

    Returns:
        A decorator that tags and returns the function it is applied to.
    """
    if resources is None:
        resources = ResourceRequirements()  # Default empty requirements

    def decorator(func: F) -> F:
        # Attach metadata directly to the function object.
        # The SDK's runner component will look for these attributes.
        setattr(func, '_is_worker_function', True)
        setattr(func, '_worker_resources', resources)
        # Return the very same object. functools.wraps is deliberately NOT
        # used: wrapping a function with itself sets func.__wrapped__ = func,
        # a self-loop that makes inspect.signature() raise "wrapper loop" and
        # can make typing.get_type_hints() (which follows __wrapped__) spin.
        return func

    return decorator
|
gen_worker/downloader.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
import hashlib
|
|
3
|
+
import os
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Optional
|
|
7
|
+
from urllib.parse import urlparse
|
|
8
|
+
|
|
9
|
+
import aiohttp
|
|
10
|
+
import backoff
|
|
11
|
+
from tqdm import tqdm
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ModelDownloader(ABC):
    """Abstract source of model artifacts.

    Concrete subclasses implement the coroutine ``download``, which fetches
    one artifact onto local disk and resolves to the written file's path.
    """

    @abstractmethod
    async def download(self, model_ref: str, dest_dir: str, filename: Optional[str] = None) -> str:
        """Fetch the artifact named by ``model_ref`` into ``dest_dir``.

        ``filename`` optionally names the written file (CozyHubDownloader
        uses it as the target file name). Resolves to the local file path.
        The base implementation only raises NotImplementedError.
        """
        raise NotImplementedError
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
class DownloadResult:
    """Record describing one downloaded artifact.

    Note: ``CozyHubDownloader.download`` in this module returns a plain path
    string, not this record.
    """

    # Local filesystem path of the downloaded file.
    path: str
    # SHA-256 digest of the content, if one was computed (format not fixed here).
    sha256: Optional[str] = None
    # Number of bytes written to ``path``.
    bytes_written: int = 0
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _is_permanent_http_error(exc: BaseException) -> bool:
    """Return True when retrying ``exc`` cannot help.

    Used as the ``giveup`` predicate of the download retry loop: HTTP 4xx
    responses (e.g. unknown model_ref, bad token) fail immediately instead of
    being retried, except 408 Request Timeout and 429 Too Many Requests.
    Network errors, timeouts and 5xx responses are still retried.
    """
    if not isinstance(exc, aiohttp.ClientResponseError):
        return False
    return 400 <= exc.status < 500 and exc.status not in (408, 429)


class CozyHubDownloader(ModelDownloader):
    """
    Simple async downloader for Cozy hub model artifacts.

    If model_ref is a full URL (http/https), it is fetched directly.
    Otherwise, it is appended to base_url as a path segment.

    The body is streamed into ``<target>.part`` and renamed into place only
    once it is complete. An interrupted or failed download therefore never
    leaves a truncated file under the final name, and never overwrites a
    previously downloaded one.
    """

    def __init__(self, base_url: str, token: Optional[str] = None, timeout_seconds: int = 120) -> None:
        # Trailing slashes are dropped so _resolve_url can always join with "/".
        self.base_url = base_url.rstrip("/")
        # Optional bearer token sent as the Authorization header.
        self.token = token
        # Total time budget per attempt, passed to aiohttp.ClientTimeout(total=...).
        self.timeout_seconds = timeout_seconds

    def _resolve_url(self, model_ref: str) -> str:
        """Turn ``model_ref`` into an absolute URL.

        Raises:
            ValueError: if ``model_ref`` is not an http(s) URL and no base_url is set.
        """
        ref = model_ref.strip()
        if ref.startswith(("http://", "https://")):
            return ref
        if not self.base_url:
            raise ValueError("COZY_HUB_URL is required for non-URL model_ref")
        return f"{self.base_url}/{ref.lstrip('/')}"

    def _default_filename(self, url: str) -> str:
        """Use the last path segment of ``url`` as the file name, else "model.bin"."""
        name = os.path.basename(urlparse(url).path)
        return name or "model.bin"

    @backoff.on_exception(
        backoff.expo,
        (aiohttp.ClientError, asyncio.TimeoutError),
        max_tries=5,
        giveup=_is_permanent_http_error,
    )
    async def download(self, model_ref: str, dest_dir: str, filename: Optional[str] = None) -> str:
        """Download ``model_ref`` into ``dest_dir`` and return the written path.

        Args:
            model_ref: A full http(s) URL, or a path relative to ``base_url``.
            dest_dir: Directory to write into; created if missing.
            filename: Name of the file to write. Defaults to the URL's last path
                segment, or "model.bin".

        Returns:
            ``os.path.join(dest_dir, <filename>)`` once the file is complete.

        Raises:
            ValueError: see ``_resolve_url``.
            aiohttp.ClientError / asyncio.TimeoutError: after up to 5 attempts.
                4xx responses other than 408/429 are raised without retrying.
        """
        url = self._resolve_url(model_ref)
        os.makedirs(dest_dir, exist_ok=True)
        target_name = filename or self._default_filename(url)
        target_path = os.path.join(dest_dir, target_name)
        part_path = f"{target_path}.part"

        headers = {}
        if self.token:
            headers["Authorization"] = f"Bearer {self.token}"
        timeout = aiohttp.ClientTimeout(total=self.timeout_seconds)

        try:
            async with aiohttp.ClientSession(timeout=timeout, headers=headers) as session:
                async with session.get(url) as resp:
                    resp.raise_for_status()
                    # content_length is None when the server sends no
                    # Content-Length; tqdm then shows an open-ended counter.
                    # Both the bar and the file are closed even on failure.
                    with tqdm(
                        total=resp.content_length,
                        unit="B",
                        unit_scale=True,
                        desc=f"download {target_name}",
                    ) as progress, open(part_path, "wb") as f:
                        async for chunk in resp.content.iter_chunked(1 << 20):
                            if not chunk:
                                continue
                            f.write(chunk)
                            progress.update(len(chunk))
            # Atomic on the same filesystem: readers see either the old file or
            # the complete new one, never a partial write.
            os.replace(part_path, target_path)
        except BaseException:
            # BaseException so task cancellation also cleans up. Removal is
            # best-effort (the .part file may never have been created); the
            # original error is always re-raised.
            try:
                os.remove(part_path)
            except OSError:
                pass
            raise
        return target_path
|
gen_worker/entrypoint.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
import os
|
|
3
|
+
import sys
|
|
4
|
+
import asyncio
|
|
5
|
+
from typing import cast
|
|
6
|
+
|
|
7
|
+
# Ensure the package source is potentially discoverable if running locally
|
|
8
|
+
# In a proper install, this might not be strictly necessary
|
|
9
|
+
# but helps during development if the current dir is the repo root.
|
|
10
|
+
# script_dir = os.path.dirname(os.path.abspath(__file__))
|
|
11
|
+
# sys.path.insert(0, os.path.dirname(script_dir))
|
|
12
|
+
|
|
13
|
+
# Core imports. If they fail, the package is not importable as installed, so
# explain the likely cause on stderr and exit with status 1.
try:
    from .worker import Worker
    from .model_interface import ModelManagementInterface
except ImportError as import_err:
    print(
        f"Error importing Worker: {import_err}",
        "Please ensure the gen_worker package is installed or accessible in PYTHONPATH.",
        sep="\n",
        file=sys.stderr,
    )
    raise SystemExit(1)
|
|
20
|
+
|
|
21
|
+
# Optional Default Model Management Components.
# These come from the .torch_manager subpackage; its import may fail
# (presumably when torch or another optional dependency is missing), in which
# case the worker simply runs without built-in model management.
DMM_AVAILABLE = False
DefaultModelManager_cls = None
dmm_load_config_func = None
dmm_set_config_func = None
# Text of the ImportError that disabled the components, kept for the log below.
_dmm_import_error = None

try:
    from .torch_manager import (
        DefaultModelManager,  # The class itself
        load_config,  # The config loading utility
        set_config  # The config setting utility
    )
    DefaultModelManager_cls = DefaultModelManager
    dmm_load_config_func = load_config
    dmm_set_config_func = set_config
    DMM_AVAILABLE = True
except ImportError as exc:
    # Not necessarily an error if the user doesn't intend to use the default
    # manager. Keep only the message: storing the exception object would keep
    # its traceback frames alive for the life of the process.
    _dmm_import_error = str(exc)

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger('WorkerEntrypoint')

if DMM_AVAILABLE:
    logger.info("DefaultModelManager components are available in this gen-worker installation.")
else:
    # Setting ENABLE_DEFAULT_MODEL_MANAGER cannot help here: the main block
    # only warns and continues without a manager when the imports failed.
    logger.info(
        "DefaultModelManager components not found (%s). Built-in dynamic model management "
        "will be unavailable even if ENABLE_DEFAULT_MODEL_MANAGER is set to true.",
        _dmm_import_error,
    )
|
|
51
|
+
|
|
52
|
+
# --- Configuration ---
# Read from environment variables or set defaults


def _env_flag(name: str, default: str = 'false') -> bool:
    """Return True when env var ``name`` is 'true', '1' or 't' (case-insensitive)."""
    return os.getenv(name, default).lower() in ('true', '1', 't')


def _env_int(name: str, default: int) -> int:
    """Read an integer env var, treating an unset or blank value as ``default``.

    Blank values are common when orchestrators pass through an unset variable
    (``FOO=``). Without this, ``int('')`` would abort the import with an
    anonymous ValueError.

    Raises:
        ValueError: naming the variable, when the value is not an integer.
    """
    raw = os.getenv(name, '').strip()
    if not raw:
        return default
    try:
        return int(raw)
    except ValueError:
        raise ValueError(f"Environment variable {name} must be an integer, got {raw!r}") from None


SCHEDULER_ADDR = os.getenv('SCHEDULER_ADDR', 'localhost:8080')
SCHEDULER_ADDRS = os.getenv('SCHEDULER_ADDRS', '')
# Comma-separated seed list; blank entries are ignored.
SEED_ADDRS = [addr.strip() for addr in SCHEDULER_ADDRS.split(',') if addr.strip()]

# Default user module name, can be overridden by environment variable
default_user_modules = 'functions'  # A sensible default
user_modules_str = os.getenv('USER_MODULES', default_user_modules)
USER_MODULES = [mod.strip() for mod in user_modules_str.split(',') if mod.strip()]

# Defaults to "worker-1", so it is never None here. NOTE(review): workers
# sharing a scheduler presumably need distinct IDs; set WORKER_ID per worker.
WORKER_ID = os.getenv('WORKER_ID', "worker-1")
AUTH_TOKEN = os.getenv('AUTH_TOKEN') or os.getenv('WORKER_JWT')  # Optional
USE_TLS = _env_flag('USE_TLS')
RECONNECT_DELAY = _env_int('RECONNECT_DELAY', 5)
MAX_RECONNECT_ATTEMPTS = _env_int('MAX_RECONNECT_ATTEMPTS', 0)  # 0 is logged as "Infinite"
ENABLE_DEFAULT_MODEL_MANAGER = _env_flag('ENABLE_DEFAULT_MODEL_MANAGER')
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _log_startup_config() -> None:
    """Log the effective worker configuration read from the environment."""
    logger.info('Starting worker...')
    logger.info(f' Scheduler Address: {SCHEDULER_ADDR}')
    if SEED_ADDRS:
        logger.info(f' Scheduler Seeds: {SEED_ADDRS}')
    logger.info(f' User Function Modules: {USER_MODULES}')
    logger.info(f' Worker ID: {WORKER_ID or "(generated)"}')
    logger.info(f' Use TLS: {USE_TLS}')
    logger.info(f' Reconnect Delay: {RECONNECT_DELAY}s')
    logger.info(f' Max Reconnect Attempts: {MAX_RECONNECT_ATTEMPTS or "Infinite"}')
    logger.info(f' Enable Default Model Manager: {ENABLE_DEFAULT_MODEL_MANAGER}')


def _build_model_manager():
    """Create the DefaultModelManager when enabled and importable; else return None.

    Initialization failures are logged and swallowed on purpose, so the worker
    still starts (without dynamic model management) instead of exiting.
    """
    if not ENABLE_DEFAULT_MODEL_MANAGER:
        logger.info("ENABLE_DEFAULT_MODEL_MANAGER is false. Worker will run without built-in dynamic model management.")
        return None

    if not (DMM_AVAILABLE and DefaultModelManager_cls and dmm_load_config_func and dmm_set_config_func):
        logger.warning("ENABLE_DEFAULT_MODEL_MANAGER is true, but DefaultModelManager components "
                       "are not fully available/imported. Proceeding without dynamic model management.")
        return None

    logger.info("DefaultModelManager is ENABLED and AVAILABLE. Initializing...")
    try:
        # Load config (e.g., from DB/YAML) needed by DefaultModelManager and its utils,
        # then install it globally for the utils the manager relies on.
        app_cfg = dmm_load_config_func()
        dmm_set_config_func(app_cfg)
        logger.info("Application configuration loaded for DefaultModelManager.")

        manager = cast(ModelManagementInterface, DefaultModelManager_cls())
        logger.info("DefaultModelManager instance created.")
        return manager
    except Exception as e_dmm_init:
        logger.exception(f"Failed to initialize DefaultModelManager: {e_dmm_init}. "
                         "Proceeding without dynamic model management.")
        return None


def main() -> int:
    """Run the worker until it stops; return the process exit code.

    0 on a normal stop, 1 on configuration/startup/runtime failure, and 130
    when interrupted with Ctrl-C (128 + SIGINT, the shell convention).
    """
    _log_startup_config()

    if not USER_MODULES:
        logger.error("No user function modules specified. Set the USER_MODULES environment variable.")
        return 1

    model_manager_instance_to_pass = _build_model_manager()

    try:
        worker = Worker(
            scheduler_addr=SCHEDULER_ADDR,
            scheduler_addrs=SEED_ADDRS,
            user_module_names=USER_MODULES,
            worker_id=WORKER_ID,
            auth_token=AUTH_TOKEN,
            use_tls=USE_TLS,
            reconnect_delay=RECONNECT_DELAY,
            max_reconnect_attempts=MAX_RECONNECT_ATTEMPTS,
            model_manager=model_manager_instance_to_pass
        )
        # This blocks until the worker stops
        worker.run()
    except ImportError as e:
        logger.exception(f"Failed to import user module(s) or dependencies: {e}. Make sure modules '{USER_MODULES}' and their requirements are installed.")
        return 1
    except KeyboardInterrupt:
        # KeyboardInterrupt is not an Exception; without this a Ctrl-C would
        # end the process with a raw traceback.
        logger.info("Worker interrupted; shutting down.")
        return 130
    except Exception as e:
        logger.exception(f"Worker failed unexpectedly: {e}")
        return 1

    logger.info('Worker process finished gracefully.')
    return 0


if __name__ == '__main__':
    sys.exit(main())
|
gen_worker/errors.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
from typing import List, Any, Optional, Dict
|
|
3
|
+
|
|
4
|
+
# Type accepted for the downloader argument below. It is an alias for Any,
# so this module does not constrain which downloader object is passed.
DownloaderType = Any


class ModelManagementInterface(ABC):
    """Contract for a component that manages model loading for the worker.

    Subclasses must override all four hooks. Three of them are coroutines;
    ``get_vram_loaded_models`` is a plain method.
    """

    @abstractmethod
    async def process_supported_models_config(
        self,
        supported_model_ids: List[str],
        downloader_instance: Optional[DownloaderType]
    ) -> None:
        """Receive the model IDs this worker supports and an optional downloader."""

    @abstractmethod
    async def load_model_into_vram(self, model_id: str) -> bool:
        """Load ``model_id`` into VRAM; return a success flag."""

    @abstractmethod
    async def get_active_pipeline(self, model_id: str) -> Optional[Any]:
        """Return the loaded pipeline object for ``model_id``, or None."""

    @abstractmethod
    def get_vram_loaded_models(self) -> List[str]:
        """Return the IDs of the models currently loaded in VRAM."""
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class ModelManager(ABC):
    """Minimal model-manager contract with no torch dependency.

    Implementations own loading models into memory, looking them up,
    releasing them, and reporting statistics.
    """

    @abstractmethod
    def load(self, model_ref: str, local_path: Optional[str] = None, **opts: Any) -> Any:
        """Load ``model_ref`` (optionally from ``local_path``) and return the model."""

    @abstractmethod
    def get(self, model_ref: str) -> Optional[Any]:
        """Return the loaded model for ``model_ref``, or None."""

    @abstractmethod
    def unload(self, model_ref: str) -> None:
        """Release the model loaded for ``model_ref``."""

    @abstractmethod
    def stats(self) -> Dict[str, Any]:
        """Return implementation-defined statistics as a dict."""
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
"""Protocol buffer module for worker communication."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import sys
|
|
6
|
+
from typing import Any
|
|
7
|
+
|
|
8
|
+
from . import frontend_pb2, worker_scheduler_pb2
|
|
9
|
+
|
|
10
|
+
__all__ = [
    "worker_scheduler_pb2",
    "worker_scheduler_pb2_grpc",
    "frontend_pb2",
    "frontend_pb2_grpc",
]

# The generated *_grpc modules use bare absolute imports
# (e.g. `import frontend_pb2 as frontend__pb2`). Publish the package-local
# message modules under those bare names BEFORE importing the stubs, so the
# stubs' imports resolve to them. setdefault leaves an existing entry untouched.
for _alias, _module in (
    ("worker_scheduler_pb2", worker_scheduler_pb2),
    ("frontend_pb2", frontend_pb2),
):
    sys.modules.setdefault(_alias, _module)  # type: ignore[arg-type]

# Must run after the aliases above are registered.
from . import frontend_pb2_grpc, worker_scheduler_pb2_grpc  # noqa: E402

# Publish the stubs under their bare names as well, for convenience.
for _alias, _module in (
    ("worker_scheduler_pb2_grpc", worker_scheduler_pb2_grpc),
    ("frontend_pb2_grpc", frontend_pb2_grpc),
):
    sys.modules.setdefault(_alias, _module)  # type: ignore[arg-type]

# Don't leave the loop variables behind as package attributes.
del _alias, _module
|
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
# -*- coding: utf-8 -*-
|
|
2
|
+
# Generated by the protocol buffer compiler. DO NOT EDIT!
|
|
3
|
+
# NO CHECKED-IN PROTOBUF GENCODE
|
|
4
|
+
# source: frontend.proto
|
|
5
|
+
# Protobuf Python Version: 5.27.4
|
|
6
|
+
"""Generated protocol buffer code."""
|
|
7
|
+
from google.protobuf import descriptor as _descriptor
|
|
8
|
+
from google.protobuf import descriptor_pool as _descriptor_pool
|
|
9
|
+
from google.protobuf import runtime_version as _runtime_version
|
|
10
|
+
from google.protobuf import symbol_database as _symbol_database
|
|
11
|
+
from google.protobuf.internal import builder as _builder
|
|
12
|
+
# Generated code (see header): regenerate from frontend.proto, don't hand-edit.
# Checks the installed protobuf runtime against the 5.27.4 gencode version this
# module was produced with, before any descriptors are built.
_runtime_version.ValidateProtobufRuntimeVersion(
    _runtime_version.Domain.PUBLIC,
    5,
    27,
    4,
    '',
    'frontend.proto'
)
|
|
20
|
+
# @@protoc_insertion_point(imports)
|
|
21
|
+
|
|
22
|
+
_sym_db = _symbol_database.Default()




# Serialized FileDescriptorProto for frontend.proto, registered with the default
# descriptor pool. Generated bytes: regenerate from frontend.proto, never edit.
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0e\x66rontend.proto\x12\x08\x66rontend\"\\\n\rActionOptions\x12\n\n\x02id\x18\x01 \x01(\t\x12\x12\n\ntimeout_ms\x18\x02 \x01(\x03\x12+\n\x0cretry_policy\x18\x03 \x01(\x0b\x32\x15.frontend.RetryPolicy\"\xa1\x01\n\x0bRetryPolicy\x12\x1b\n\x13initial_interval_ms\x18\x01 \x01(\x03\x12\x1b\n\x13\x62\x61\x63koff_coefficient\x18\x02 \x01(\x02\x12\x1b\n\x13maximum_interval_ms\x18\x03 \x01(\x03\x12\x18\n\x10maximum_attempts\x18\x04 \x01(\x05\x12!\n\x19non_retryable_error_types\x18\x05 \x03(\t\"\xb1\x01\n\x14\x45xecuteActionRequest\x12\x15\n\rfunction_name\x18\x01 \x01(\t\x12\x15\n\rinput_payload\x18\x02 \x01(\x0c\x12(\n\x07options\x18\x03 \x01(\x0b\x32\x17.frontend.ActionOptions\x12\x15\n\rdeployment_id\x18\x04 \x01(\t\x12\x19\n\x11required_model_id\x18\x05 \x01(\t\x12\x0f\n\x07user_id\x18\x06 \x01(\t\"\'\n\x15\x45xecuteActionResponse\x12\x0e\n\x06run_id\x18\x01 \x01(\t\"\x1f\n\rGetRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\t\"P\n\x0eGetRunResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x16\n\x0eoutput_payload\x18\x02 \x01(\x0c\x12\x15\n\rerror_message\x18\x03 \x01(\t\"\"\n\x10\x43\x61ncelRunRequest\x12\x0e\n\x06run_id\x18\x01 \x01(\t\"\x1f\n\x11\x43\x61ncelRunResponse\x12\n\n\x02ok\x18\x01 \x01(\x08\x32\xe6\x01\n\x0f\x46rontendService\x12P\n\rExecuteAction\x12\x1e.frontend.ExecuteActionRequest\x1a\x1f.frontend.ExecuteActionResponse\x12;\n\x06GetRun\x12\x17.frontend.GetRunRequest\x1a\x18.frontend.GetRunResponse\x12\x44\n\tCancelRun\x12\x1a.frontend.CancelRunRequest\x1a\x1b.frontend.CancelRunResponseB:Z8github.com/cozy-creator/gen-orchestrator/pkg/pb/frontendb\x06proto3')

# Populate this module's globals with the descriptors and message classes
# (ExecuteActionRequest, GetRunResponse, ...) that frontend_pb2_grpc refers to.
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'frontend_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
  # Pure-Python descriptor backend only: record file options and each
  # message's byte span within the serialized DESCRIPTOR above.
  _globals['DESCRIPTOR']._loaded_options = None
  _globals['DESCRIPTOR']._serialized_options = b'Z8github.com/cozy-creator/gen-orchestrator/pkg/pb/frontend'
  _globals['_ACTIONOPTIONS']._serialized_start=28
  _globals['_ACTIONOPTIONS']._serialized_end=120
  _globals['_RETRYPOLICY']._serialized_start=123
  _globals['_RETRYPOLICY']._serialized_end=284
  _globals['_EXECUTEACTIONREQUEST']._serialized_start=287
  _globals['_EXECUTEACTIONREQUEST']._serialized_end=464
  _globals['_EXECUTEACTIONRESPONSE']._serialized_start=466
  _globals['_EXECUTEACTIONRESPONSE']._serialized_end=505
  _globals['_GETRUNREQUEST']._serialized_start=507
  _globals['_GETRUNREQUEST']._serialized_end=538
  _globals['_GETRUNRESPONSE']._serialized_start=540
  _globals['_GETRUNRESPONSE']._serialized_end=620
  _globals['_CANCELRUNREQUEST']._serialized_start=622
  _globals['_CANCELRUNREQUEST']._serialized_end=656
  _globals['_CANCELRUNRESPONSE']._serialized_start=658
  _globals['_CANCELRUNRESPONSE']._serialized_end=689
  _globals['_FRONTENDSERVICE']._serialized_start=692
  _globals['_FRONTENDSERVICE']._serialized_end=922
# @@protoc_insertion_point(module_scope)
|
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
|
2
|
+
"""Client and server classes corresponding to protobuf-defined services."""
|
|
3
|
+
import grpc
|
|
4
|
+
import warnings
|
|
5
|
+
|
|
6
|
+
import frontend_pb2 as frontend__pb2
|
|
7
|
+
|
|
8
|
+
# grpcio version this module was generated for. The guard below refuses to
# import under an older installed grpcio runtime.
GRPC_GENERATED_VERSION = '1.76.0'
GRPC_VERSION = grpc.__version__
_version_not_supported = False

try:
    from grpc._utilities import first_version_is_lower
    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
    # The private comparison helper is missing (presumably an older grpcio),
    # so treat the runtime as unsupported.
    _version_not_supported = True

if _version_not_supported:
    raise RuntimeError(
        f'The grpc package installed is at version {GRPC_VERSION},'
        + ' but the generated code in frontend_pb2_grpc.py depends on'
        + f' grpcio>={GRPC_GENERATED_VERSION}.'
        + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}'
        + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.'
    )
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class FrontendServiceStub(object):
    """The gRPC service for client <-> scheduler.
    """

    def __init__(self, channel):
        """Constructor.

        Args:
            channel: A grpc.Channel.
        """
        # Generated client stub: each attribute is a unary-unary callable for
        # one RPC of frontend.FrontendService. It serializes the request and
        # parses the response with the frontend_pb2 message classes.
        self.ExecuteAction = channel.unary_unary(
                '/frontend.FrontendService/ExecuteAction',
                request_serializer=frontend__pb2.ExecuteActionRequest.SerializeToString,
                response_deserializer=frontend__pb2.ExecuteActionResponse.FromString,
                _registered_method=True)
        self.GetRun = channel.unary_unary(
                '/frontend.FrontendService/GetRun',
                request_serializer=frontend__pb2.GetRunRequest.SerializeToString,
                response_deserializer=frontend__pb2.GetRunResponse.FromString,
                _registered_method=True)
        self.CancelRun = channel.unary_unary(
                '/frontend.FrontendService/CancelRun',
                request_serializer=frontend__pb2.CancelRunRequest.SerializeToString,
                response_deserializer=frontend__pb2.CancelRunResponse.FromString,
                _registered_method=True)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class FrontendServiceServicer(object):
    """The gRPC service for client <-> scheduler.
    """
    # Generated server base class: subclass it and override the RPC methods.
    # Each default marks the call UNIMPLEMENTED on the context and raises
    # NotImplementedError.

    def ExecuteAction(self, request, context):
        """1) Submit a new action/job to the scheduler.
        """
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def GetRun(self, request, context):
        """2) Wait/await the final result of an existing action. Blocks until completed or error.
        """
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')

    def CancelRun(self, request, context):
        """3) Cancel an in-flight action/job.
        """
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details('Method not implemented!')
        raise NotImplementedError('Method not implemented!')
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
def add_FrontendServiceServicer_to_server(servicer, server):
    """Register ``servicer``'s methods as the frontend.FrontendService handlers on ``server``.

    Generated code: requests are parsed and responses serialized with the
    frontend_pb2 message classes.
    """
    rpc_method_handlers = {
            'ExecuteAction': grpc.unary_unary_rpc_method_handler(
                    servicer.ExecuteAction,
                    request_deserializer=frontend__pb2.ExecuteActionRequest.FromString,
                    response_serializer=frontend__pb2.ExecuteActionResponse.SerializeToString,
            ),
            'GetRun': grpc.unary_unary_rpc_method_handler(
                    servicer.GetRun,
                    request_deserializer=frontend__pb2.GetRunRequest.FromString,
                    response_serializer=frontend__pb2.GetRunResponse.SerializeToString,
            ),
            'CancelRun': grpc.unary_unary_rpc_method_handler(
                    servicer.CancelRun,
                    request_deserializer=frontend__pb2.CancelRunRequest.FromString,
                    response_serializer=frontend__pb2.CancelRunResponse.SerializeToString,
            ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
            'frontend.FrontendService', rpc_method_handlers)
    server.add_generic_rpc_handlers((generic_handler,))
    server.add_registered_method_handlers('frontend.FrontendService', rpc_method_handlers)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
# This class is part of an EXPERIMENTAL API.
# Generated one-shot helpers: each static method performs a single RPC against
# ``target`` via grpc.experimental.unary_unary, without a pre-built stub.
class FrontendService(object):
    """The gRPC service for client <-> scheduler.
    """

    @staticmethod
    def ExecuteAction(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/frontend.FrontendService/ExecuteAction',
            frontend__pb2.ExecuteActionRequest.SerializeToString,
            frontend__pb2.ExecuteActionResponse.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def GetRun(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/frontend.FrontendService/GetRun',
            frontend__pb2.GetRunRequest.SerializeToString,
            frontend__pb2.GetRunResponse.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)

    @staticmethod
    def CancelRun(request,
            target,
            options=(),
            channel_credentials=None,
            call_credentials=None,
            insecure=False,
            compression=None,
            wait_for_ready=None,
            timeout=None,
            metadata=None):
        return grpc.experimental.unary_unary(
            request,
            target,
            '/frontend.FrontendService/CancelRun',
            frontend__pb2.CancelRunRequest.SerializeToString,
            frontend__pb2.CancelRunResponse.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True)
|