gen-worker 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gen_worker/__init__.py +19 -0
- gen_worker/decorators.py +66 -0
- gen_worker/default_model_manager/__init__.py +5 -0
- gen_worker/downloader.py +84 -0
- gen_worker/entrypoint.py +135 -0
- gen_worker/errors.py +10 -0
- gen_worker/model_interface.py +48 -0
- gen_worker/pb/__init__.py +27 -0
- gen_worker/pb/frontend_pb2.py +53 -0
- gen_worker/pb/frontend_pb2_grpc.py +189 -0
- gen_worker/pb/worker_scheduler_pb2.py +69 -0
- gen_worker/pb/worker_scheduler_pb2_grpc.py +100 -0
- gen_worker/py.typed +0 -0
- gen_worker/testing/__init__.py +1 -0
- gen_worker/testing/stub_manager.py +69 -0
- gen_worker/torch_manager/__init__.py +4 -0
- gen_worker/torch_manager/manager.py +2059 -0
- gen_worker/torch_manager/utils/base_types/architecture.py +145 -0
- gen_worker/torch_manager/utils/base_types/common.py +52 -0
- gen_worker/torch_manager/utils/base_types/config.py +46 -0
- gen_worker/torch_manager/utils/config.py +321 -0
- gen_worker/torch_manager/utils/db/database.py +46 -0
- gen_worker/torch_manager/utils/device.py +26 -0
- gen_worker/torch_manager/utils/diffusers_fix.py +10 -0
- gen_worker/torch_manager/utils/flashpack_loader.py +262 -0
- gen_worker/torch_manager/utils/globals.py +59 -0
- gen_worker/torch_manager/utils/load_models.py +238 -0
- gen_worker/torch_manager/utils/local_cache.py +340 -0
- gen_worker/torch_manager/utils/model_downloader.py +763 -0
- gen_worker/torch_manager/utils/parse_cli.py +98 -0
- gen_worker/torch_manager/utils/paths.py +22 -0
- gen_worker/torch_manager/utils/repository.py +141 -0
- gen_worker/torch_manager/utils/utils.py +43 -0
- gen_worker/types.py +47 -0
- gen_worker/worker.py +1720 -0
- gen_worker-0.1.4.dist-info/METADATA +113 -0
- gen_worker-0.1.4.dist-info/RECORD +38 -0
- gen_worker-0.1.4.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,98 @@
|
|
|
1
|
+
import argparse
|
|
2
|
+
import json
|
|
3
|
+
import os
|
|
4
|
+
import logging
|
|
5
|
+
from .base_types.config import RuntimeConfig, PipelineConfig
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
logger = logging.getLogger(__name__)
|
|
9
|
+
|
|
10
|
+
# ====== Parse cli arguments ======
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def parse_pipeline_defs(value: Optional[str]) -> dict[str, PipelineConfig]:
    """Parse pipeline definitions from command line argument.

    Args:
        value: JSON text encoding an object that maps pipeline names to their
            definitions. ``None`` or an empty value yields ``{}``. An
            already-decoded dict is returned unchanged: argparse does not run
            ``type=`` on non-string defaults, so a ``RuntimeConfig`` default
            may reach this function already parsed.

    Returns:
        The decoded mapping. Invalid JSON, or JSON that is not an object, is
        logged and yields ``{}``.
        NOTE(review): values are the raw decoded JSON; they are not converted
        to ``PipelineConfig`` instances here despite the annotation.
    """
    if not value:
        return {}
    if isinstance(value, dict):
        return value

    try:
        loaded = json.loads(value)
    except (json.JSONDecodeError, TypeError) as e:
        # TypeError covers non-str/bytes inputs that json.loads rejects outright.
        logger.error(f"Failed to parse pipeline definitions: {e}")
        return {}

    if not isinstance(loaded, dict):
        logger.error("Pipeline definitions are not a dictionary")
        return {}
    return loaded
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def _clean_model_names(items) -> list[str]:
    """Stringify and strip each item, dropping ``None`` and empty entries."""
    names = []
    for item in items:
        if item is None:
            continue
        name = str(item).strip()
        if name:
            names.append(name)
    return names


def parse_enabled_models(value: Optional[str]) -> list[str]:
    """Parse enabled models from command line argument.

    Accepts either a JSON array (``'["a", "b"]'``) or a comma-separated list
    (``"a, b"``). Entries are stripped of surrounding whitespace and empty
    entries are dropped, so the result is always a ``list[str]``.

    Args:
        value: Raw flag value. ``None``/empty yields ``[]``. An already-parsed
            list/tuple is also accepted, because argparse passes non-string
            defaults through without applying ``type=``.

    Returns:
        The list of model names.
    """
    if not value:
        return []
    if isinstance(value, (list, tuple)):
        return _clean_model_names(value)

    try:
        loaded = json.loads(value)
    except json.JSONDecodeError:
        # Not JSON: treat as a comma-separated list.
        return _clean_model_names(value.split(","))

    if isinstance(loaded, list):
        return _clean_model_names(loaded)
    if isinstance(loaded, str):
        # A JSON string such as '"a,b"' holds the comma-separated list itself.
        return _clean_model_names(loaded.split(","))
    # Valid JSON that is neither an array nor a string (e.g. "123"):
    # fall back to the raw text rather than returning a non-list.
    return _clean_model_names(value.split(","))
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
def parse_arguments(argv: Optional[list[str]] = None) -> RuntimeConfig:
    """Parse command line arguments and return configuration.

    Args:
        argv: Argument strings to parse, excluding the program name. ``None``
            (the default) makes argparse read ``sys.argv[1:]``, exactly as
            before; passing a list allows programmatic use without touching
            ``sys.argv``.

    Returns:
        A new ``RuntimeConfig``. Flags that are not supplied fall back to the
        matching attribute of a default-constructed ``RuntimeConfig()``.
        ``pipeline_defs`` and ``enabled_models`` are decoded with
        ``parse_pipeline_defs`` / ``parse_enabled_models``; an empty
        ``models_path`` falls back to ``<home_dir>/models``.
    """
    parser = argparse.ArgumentParser(description="Cozy Creator")

    # CLI defaults mirror RuntimeConfig's own defaults.
    default_config = RuntimeConfig()

    parser.add_argument(
        "--home-dir",
        default=default_config.home_dir,
        help="Cozy creator's home directory",
    )
    parser.add_argument(
        "--environment",
        default=default_config.environment,
        help="Server environment (dev/prod)",
    )
    parser.add_argument(
        "--host", default=default_config.host, help="Hostname or IP-address"
    )
    parser.add_argument(
        "--port",
        type=int,
        default=default_config.port,
        help="Port to bind Python runtime to",
    )
    parser.add_argument(
        "--pipeline-defs",
        type=str,
        default=default_config.pipeline_defs,
        help="JSON string of pipeline definitions",
    )
    parser.add_argument(
        "--enabled-models",
        type=str,
        default=default_config.enabled_models,
        help="Comma-separated list or JSON array of models to warm up",
    )
    parser.add_argument(
        "--models-path",
        type=str,
        default=default_config.models_path,
        help="Path to models directory",
    )

    args = parser.parse_args(argv)

    # Build a fresh config from the parsed arguments.
    return RuntimeConfig(
        home_dir=args.home_dir,
        environment=args.environment,
        host=args.host,
        port=args.port,
        pipeline_defs=parse_pipeline_defs(args.pipeline_defs),
        enabled_models=parse_enabled_models(args.enabled_models),
        models_path=args.models_path or os.path.join(args.home_dir, "models"),
    )
|
|
@@ -0,0 +1,22 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from .config import get_config
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def get_assets_dir():
    """Return the assets directory.

    Uses the configured ``assets_path`` (with ``~`` expanded) when it is set
    and non-empty; otherwise falls back to ``<home_dir>/assets``.
    """
    assets_path = get_config().assets_path
    if not assets_path:
        return os.path.join(get_home_dir(), "assets")
    return os.path.expanduser(assets_path)
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def get_models_dir():
    """Return the models directory.

    Uses the configured ``models_path`` (with ``~`` expanded) when it is set
    and non-empty; otherwise falls back to ``<home_dir>/models``.
    """
    models_path = get_config().models_path
    return (
        os.path.expanduser(models_path)
        if models_path
        else os.path.join(get_home_dir(), "models")
    )
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def get_home_dir():
    """Return the configured home directory with a leading ``~`` expanded."""
    home_dir = get_config().home_dir
    return os.path.expanduser(home_dir)
|
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
import json
|
|
2
|
+
import logging
|
|
3
|
+
from typing import List, Dict, Any, Optional
|
|
4
|
+
from decimal import Decimal
|
|
5
|
+
|
|
6
|
+
logger = logging.getLogger(__name__)
|
|
7
|
+
|
|
8
|
+
class PipelineDef:
    """
    Model class for pipeline definitions from the database.

    Attributes mirror the constructor arguments. ``default_args``, ``metadata``
    and ``components`` are replaced by a fresh empty dict when a falsy value
    (``None`` or ``{}``) is passed, so instances never share a mutable default.
    """

    def __init__(self, name: str, source: str = "", class_name: str = "",
                 custom_pipeline: str = "",
                 default_args: Optional[Dict[str, Any]] = None,
                 metadata: Optional[Dict[str, Any]] = None,
                 components: Optional[Dict[str, Any]] = None,
                 prompt_def: Optional['PromptDef'] = None,
                 estimated_size_gb: Optional[float] = None):
        self.name = name
        self.source = source
        self.class_name = class_name
        self.custom_pipeline = custom_pipeline
        self.default_args = default_args or {}
        self.metadata = metadata or {}
        self.components = components or {}
        self.prompt_def = prompt_def
        self.estimated_size_gb = estimated_size_gb

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(name={self.name!r}, source={self.source!r}, "
            f"class_name={self.class_name!r}, "
            f"estimated_size_gb={self.estimated_size_gb!r})"
        )
|
|
25
|
+
|
|
26
|
+
class PromptDef:
    """
    Prompt definition record loaded from the database.

    Holds the owning pipeline's id together with positive and negative prompt
    text; both prompts default to the empty string.
    """

    def __init__(self, pipeline_id: int, positive_prompt: str = "", negative_prompt: str = ""):
        # Plain attribute storage; no validation is performed.
        (self.pipeline_id,
         self.positive_prompt,
         self.negative_prompt) = (pipeline_id, positive_prompt, negative_prompt)
|
|
34
|
+
|
|
35
|
+
def _parse_json_object(raw: Any, field: str, pipeline_name: str) -> Dict[str, Any]:
    """Decode a JSON column into a dict; return {} (with a warning) when empty or invalid."""
    if not raw:
        return {}
    if isinstance(raw, str):
        try:
            raw = json.loads(raw)
        except json.JSONDecodeError:
            logger.warning(f"Failed to parse {field} for pipeline {pipeline_name}")
            return {}
    # PipelineDef treats these fields as dicts; reject arrays/scalars/null.
    if not isinstance(raw, dict):
        logger.warning(
            f"Expected a JSON object for {field} of pipeline {pipeline_name}, "
            f"got {type(raw).__name__}"
        )
        return {}
    return raw


def _parse_estimated_size(raw: Any, pipeline_name: str) -> Optional[float]:
    """Convert the estimated_size_bytes column to float, or None if missing/unparseable."""
    if raw is None:
        return None
    try:
        # Decimal (e.g. from a NUMERIC column) converts directly; other types go via str.
        return float(raw) if isinstance(raw, Decimal) else float(str(raw))
    except (ValueError, TypeError) as e:
        logger.warning(f"Could not convert estimated_size_bytes ('{raw}') to float for {pipeline_name}: {e}")
        return None


def _row_to_pipeline_def(row) -> PipelineDef:
    """Build a PipelineDef (with an optional PromptDef) from one query result row."""
    name = row['name']

    # Only attach a prompt definition when at least one prompt is non-empty.
    prompt_def = None
    if row['positive_prompt'] or row['negative_prompt']:
        prompt_def = PromptDef(
            pipeline_id=row['id'],
            positive_prompt=row['positive_prompt'] or "",
            negative_prompt=row['negative_prompt'] or "",
        )

    return PipelineDef(
        name=name,
        source=row['source'] or "",
        class_name=row['class_name'] or "",
        custom_pipeline=row['custom_pipeline'] or "",
        default_args=_parse_json_object(row['default_args'], "default_args", name),
        metadata=_parse_json_object(row['metadata'], "metadata", name),
        components=_parse_json_object(row['components'], "components", name),
        prompt_def=prompt_def,
        # NOTE(review): the DB column is named estimated_size_bytes, but the value
        # is stored as estimated_size_gb unchanged — confirm the column's unit.
        estimated_size_gb=_parse_estimated_size(row['estimated_size_bytes'], name),
    )


def get_pipeline_defs(db_conn, pipeline_names: List[str]) -> List[PipelineDef]:
    """
    Retrieves pipeline definitions from the database based on their names.
    Similar to the Go GetPipelineDefs function.

    Args:
        db_conn: Database connection. Rows are read by column name, so the
            cursor presumably yields dict-like rows (e.g. a dict-row cursor
            factory) — confirm against the connection setup.
        pipeline_names: Pipeline names to retrieve. Any iterable is accepted;
            it is passed to the query as a list so it can bind to ``ANY(%s)``.

    Returns:
        List of PipelineDef objects, one per matching row. Names with no
        matching row are simply absent. The query has no ORDER BY, so result
        order is not guaranteed.

    Raises:
        Re-raises any error from the query or row conversion after logging it
        and attempting to roll back the transaction.
    """
    if not pipeline_names:
        return []

    # Query to get pipeline definitions with their prompt definitions
    query = """
        SELECT
            p.id,
            p.name,
            p.source,
            p.class_name,
            p.custom_pipeline,
            p.default_args,
            p.metadata,
            p.components,
            p.estimated_size_bytes,
            pr.positive_prompt,
            pr.negative_prompt
        FROM
            pipeline_defs p
        LEFT JOIN
            prompt_defs pr ON p.prompt_def_id = pr.id
        WHERE
            p.name = ANY(%s)
    """

    try:
        with db_conn.cursor() as cur:
            # A list (not a tuple/set) is required to bind as an SQL array.
            cur.execute(query, (list(pipeline_names),))
            rows = cur.fetchall()
        return [_row_to_pipeline_def(row) for row in rows]
    except Exception as e:
        logger.exception("Error retrieving pipeline definitions: %s", e)
        # Ensure transaction is rolled back, without letting a failed rollback
        # (e.g. on a broken connection) mask the original error.
        try:
            db_conn.rollback()
        except Exception:
            logger.exception("Rollback after failed pipeline definition query also failed")
        raise
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
import re
|
|
2
|
+
from typing import Dict, Any
|
|
3
|
+
from dataclasses import asdict, is_dataclass
|
|
4
|
+
from .base_types.config import RuntimeConfig
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
def flatten_architectures(architectures):
    """
    Flatten a mapping of architecture ids to architecture entries.

    List values are expanded into one entry per element, keyed
    ``"<arch_id>:<element.__name__>"``; any other value is kept under
    ``arch_id`` unchanged.

    Args:
        architectures: Mapping of id -> single architecture or list of
            architectures (list elements must have ``__name__``, presumably
            classes).

    Returns:
        A new flat dict mapping each key to a single architecture.
    """
    flat_architectures = {}
    for arch_id, architecture in architectures.items():
        if isinstance(architecture, list):
            for arch in architecture:
                # Map each key to its own entry, not to the whole list.
                flat_architectures[f"{arch_id}:{arch.__name__}"] = arch
        else:
            flat_architectures[arch_id] = architecture

    return flat_architectures
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
# Zero-width split points: lower->Upper ("camelCase") and the last capital of
# an acronym that is followed by a capitalised word ("HTTPServer" -> "HTTP|Server").
_SNAKE_CASE_BOUNDARY = re.compile(r"(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])")


def to_snake_case(value):
    """
    Convert CamelCase to snake_case.

    Inserts ``_`` at each word boundary and lower-cases the result. Digits
    are not treated as boundaries (``"Model2Name"`` -> ``"model2name"``).
    """
    underscored = _SNAKE_CASE_BOUNDARY.sub("_", value)
    return underscored.lower()
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def serialize_config(config: RuntimeConfig) -> dict[str, Any]:
    """
    Serialize a dataclass (like RuntimeConfig) into a dictionary.

    Nested dataclass instances, lists and dicts are converted recursively;
    all other values are returned as-is.

    Args:
        config: The dataclass instance to serialize.

    Returns:
        A plain dict representation of ``config``.
    """

    def serialize(obj):
        # is_dataclass() is also True for dataclass *classes*, but asdict()
        # only accepts instances; a field holding a dataclass type is kept as-is.
        if is_dataclass(obj) and not isinstance(obj, type):
            return {k: serialize(v) for k, v in asdict(obj).items()}
        if isinstance(obj, list):
            return [serialize(item) for item in obj]
        if isinstance(obj, dict):
            return {k: serialize(v) for k, v in obj.items()}
        return obj

    return serialize(config)
|
gen_worker/types.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import os
|
|
4
|
+
from typing import IO, Optional
|
|
5
|
+
|
|
6
|
+
import msgspec
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Asset(msgspec.Struct):
    """Reference to a file in the invoking tenant's file store.

    The worker runtime should populate `local_path` before invoking tenant code
    so the function can open/read the file efficiently.
    """

    ref: str
    tenant_id: Optional[str] = None
    local_path: Optional[str] = None
    mime_type: Optional[str] = None
    size_bytes: Optional[int] = None
    sha256: Optional[str] = None

    def _require_local_path(self) -> str:
        """Return `local_path`, raising ValueError if the file is not materialized."""
        if self.local_path is None:
            raise ValueError("Asset.local_path is not set (file not materialized)")
        return self.local_path

    def __fspath__(self) -> str:
        """Return the materialized path so the asset works with os.fspath()/open()."""
        return self._require_local_path()

    def open(self, mode: str = "rb") -> IO[bytes]:
        """Open the materialized file; only binary modes are allowed.

        Raises:
            ValueError: if `mode` is not binary or the file is not materialized.
        """
        if "b" not in mode:
            raise ValueError("Asset.open only supports binary modes")
        # Inside a method body `open` resolves to the builtin, not this method.
        return open(self._require_local_path(), mode)

    def exists(self) -> bool:
        """Return True if `local_path` is set and exists on disk."""
        if self.local_path is None:
            return False
        return os.path.exists(self.local_path)

    def read_bytes(self, max_bytes: Optional[int] = None) -> bytes:
        """Read the whole file into memory.

        Args:
            max_bytes: Optional non-negative upper bound on the file size.

        Raises:
            ValueError: if `max_bytes` is negative, the file is not
                materialized, or the file is larger than `max_bytes`.
        """
        if max_bytes is not None and max_bytes < 0:
            raise ValueError("max_bytes must be non-negative")
        path = self._require_local_path()
        with open(path, "rb") as f:
            if max_bytes is None:
                return f.read()
            # Read one extra byte so an oversized file is detected without reading it all.
            data = f.read(max_bytes + 1)
        if len(data) > max_bytes:
            raise ValueError("asset too large to read into memory")
        return data
|