gen-worker 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. gen_worker/__init__.py +19 -0
  2. gen_worker/decorators.py +66 -0
  3. gen_worker/default_model_manager/__init__.py +5 -0
  4. gen_worker/downloader.py +84 -0
  5. gen_worker/entrypoint.py +135 -0
  6. gen_worker/errors.py +10 -0
  7. gen_worker/model_interface.py +48 -0
  8. gen_worker/pb/__init__.py +27 -0
  9. gen_worker/pb/frontend_pb2.py +53 -0
  10. gen_worker/pb/frontend_pb2_grpc.py +189 -0
  11. gen_worker/pb/worker_scheduler_pb2.py +69 -0
  12. gen_worker/pb/worker_scheduler_pb2_grpc.py +100 -0
  13. gen_worker/py.typed +0 -0
  14. gen_worker/testing/__init__.py +1 -0
  15. gen_worker/testing/stub_manager.py +69 -0
  16. gen_worker/torch_manager/__init__.py +4 -0
  17. gen_worker/torch_manager/manager.py +2059 -0
  18. gen_worker/torch_manager/utils/base_types/architecture.py +145 -0
  19. gen_worker/torch_manager/utils/base_types/common.py +52 -0
  20. gen_worker/torch_manager/utils/base_types/config.py +46 -0
  21. gen_worker/torch_manager/utils/config.py +321 -0
  22. gen_worker/torch_manager/utils/db/database.py +46 -0
  23. gen_worker/torch_manager/utils/device.py +26 -0
  24. gen_worker/torch_manager/utils/diffusers_fix.py +10 -0
  25. gen_worker/torch_manager/utils/flashpack_loader.py +262 -0
  26. gen_worker/torch_manager/utils/globals.py +59 -0
  27. gen_worker/torch_manager/utils/load_models.py +238 -0
  28. gen_worker/torch_manager/utils/local_cache.py +340 -0
  29. gen_worker/torch_manager/utils/model_downloader.py +763 -0
  30. gen_worker/torch_manager/utils/parse_cli.py +98 -0
  31. gen_worker/torch_manager/utils/paths.py +22 -0
  32. gen_worker/torch_manager/utils/repository.py +141 -0
  33. gen_worker/torch_manager/utils/utils.py +43 -0
  34. gen_worker/types.py +47 -0
  35. gen_worker/worker.py +1720 -0
  36. gen_worker-0.1.4.dist-info/METADATA +113 -0
  37. gen_worker-0.1.4.dist-info/RECORD +38 -0
  38. gen_worker-0.1.4.dist-info/WHEEL +4 -0
@@ -0,0 +1,98 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import logging
5
+ from .base_types.config import RuntimeConfig, PipelineConfig
6
+ from typing import Optional
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ # ====== Parse cli arguments ======
11
+
12
+
13
def parse_pipeline_defs(value: Optional[str]) -> dict[str, PipelineConfig]:
    """Decode the --pipeline-defs CLI value into a pipeline-config mapping.

    Missing input, malformed JSON, and JSON that is not an object all yield
    an empty dict; failures are logged instead of raised.
    """
    if not value:
        return {}

    try:
        parsed = json.loads(value)
    except json.JSONDecodeError as exc:
        logger.error(f"Failed to parse pipeline definitions: {exc}")
        return {}

    if not isinstance(parsed, dict):
        logger.error("Pipeline definitions are not a dictionary")
        return {}

    return parsed
28
+
29
+
30
def parse_enabled_models(value: Optional[str]) -> list[str]:
    """Interpret the --enabled-models CLI value.

    Accepts either a JSON array or a comma-separated string; an empty or
    missing value yields an empty list.
    """
    if not value:
        return []

    try:
        parsed = json.loads(value)
    except json.JSONDecodeError:
        # Not JSON — treat it as a plain comma-separated list.
        parsed = value.split(",")
    return parsed
39
+
40
+
41
def parse_arguments() -> RuntimeConfig:
    """Build a RuntimeConfig from command-line arguments.

    Defaults come from a freshly constructed RuntimeConfig; --pipeline-defs
    and --enabled-models are decoded from their string forms, and the models
    path falls back to <home-dir>/models when not supplied.
    """
    defaults = RuntimeConfig()
    parser = argparse.ArgumentParser(description="Cozy Creator")

    # (flag, kwargs) pairs, in the order they should appear in --help.
    flag_specs = [
        ("--home-dir", {"default": defaults.home_dir,
                        "help": "Cozy creator's home directory"}),
        ("--environment", {"default": defaults.environment,
                           "help": "Server environment (dev/prod)"}),
        ("--host", {"default": defaults.host,
                    "help": "Hostname or IP-address"}),
        ("--port", {"type": int, "default": defaults.port,
                    "help": "Port to bind Python runtime to"}),
        ("--pipeline-defs", {"type": str, "default": defaults.pipeline_defs,
                             "help": "JSON string of pipeline definitions"}),
        ("--enabled-models", {"type": str, "default": defaults.enabled_models,
                              "help": "Comma-separated list or JSON array of models to warm up"}),
        ("--models-path", {"type": str, "default": defaults.models_path,
                           "help": "Path to models directory"}),
    ]
    for flag, kwargs in flag_specs:
        parser.add_argument(flag, **kwargs)

    args = parser.parse_args()

    # NOTE(review): when a flag is omitted, argparse passes the RuntimeConfig
    # default straight through, so parse_pipeline_defs/parse_enabled_models
    # receive whatever type that default is — confirm those defaults are
    # strings (or falsy) rather than already-parsed containers.
    return RuntimeConfig(
        home_dir=args.home_dir,
        environment=args.environment,
        host=args.host,
        port=args.port,
        pipeline_defs=parse_pipeline_defs(args.pipeline_defs),
        enabled_models=parse_enabled_models(args.enabled_models),
        models_path=args.models_path or os.path.join(args.home_dir, "models"),
    )
@@ -0,0 +1,22 @@
1
+ import os
2
+ from .config import get_config
3
+
4
+
5
+
6
def get_assets_dir():
    """Resolve the assets directory: configured path if set, else <home>/assets."""
    configured = get_config().assets_path
    return (
        os.path.expanduser(configured)
        if configured
        else os.path.join(get_home_dir(), "assets")
    )
11
+
12
+
13
def get_models_dir():
    """Resolve the models directory: configured path if set, else <home>/models."""
    configured = get_config().models_path
    return (
        os.path.expanduser(configured)
        if configured
        else os.path.join(get_home_dir(), "models")
    )
19
+
20
+
21
def get_home_dir():
    """Return the user-expanded home directory from the active config."""
    home = get_config().home_dir
    return os.path.expanduser(home)
@@ -0,0 +1,141 @@
1
+ import json
2
+ import logging
3
+ from typing import List, Dict, Any, Optional
4
+ from decimal import Decimal
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
class PipelineDef:
    """
    In-memory model of a `pipeline_defs` database row.

    Dict-valued fields default to fresh empty dicts so instances never share
    mutable state.
    """
    def __init__(self, name: str, source: str = "", class_name: str = "",
                 custom_pipeline: str = "", default_args: Dict = None,
                 metadata: Dict = None, components: Dict = None,
                 prompt_def: Optional['PromptDef'] = None, estimated_size_gb: Optional[float] = None):
        self.name = name
        self.source = source
        self.class_name = class_name
        self.custom_pipeline = custom_pipeline
        # `or {}` gives each instance its own dict when the arg is None/falsy.
        self.default_args = default_args or {}
        self.metadata = metadata or {}
        self.components = components or {}
        # Optional link to the pipeline's prompt definition row.
        self.prompt_def = prompt_def
        self.estimated_size_gb = estimated_size_gb
25
+
26
class PromptDef:
    """
    In-memory model of a `prompt_defs` database row, keyed to its pipeline.
    """
    def __init__(self, pipeline_id: int, positive_prompt: str = "", negative_prompt: str = ""):
        # Id of the owning pipeline_defs row.
        self.pipeline_id = pipeline_id
        self.positive_prompt = positive_prompt
        self.negative_prompt = negative_prompt
34
+
35
def get_pipeline_defs(db_conn, pipeline_names: List[str]) -> List[PipelineDef]:
    """
    Retrieves pipeline definitions from the database based on their names.
    Similar to the Go GetPipelineDefs function.

    Args:
        db_conn: Database connection (must support cursor() as a context
            manager, rollback(), and mapping-style rows — e.g. psycopg with
            a dict row factory)
        pipeline_names: List of pipeline names to retrieve

    Returns:
        List of PipelineDef objects, in database result order

    Raises:
        Re-raises any database or conversion error after rolling back the
        current transaction.
    """
    if not pipeline_names:
        return []

    def _json_field(row, key: str) -> dict:
        """Decode a JSON(B) column that may arrive as NULL, str, or dict."""
        raw = row[key]
        if not raw:
            return {}
        if not isinstance(raw, str):
            # Driver already decoded the column (e.g. jsonb -> dict).
            return raw
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            logger.warning(f"Failed to parse {key} for pipeline {row['name']}")
            return {}

    def _size_gb(row) -> Optional[float]:
        """Convert the estimated_size_bytes column to float, or None."""
        raw = row['estimated_size_bytes']
        if raw is None:
            return None
        try:
            # NUMERIC columns typically surface as Decimal; anything else is
            # stringified first so ints/strs convert uniformly.
            return float(raw) if isinstance(raw, Decimal) else float(str(raw))
        except (ValueError, TypeError) as e:
            logger.warning(f"Could not convert estimated_size_bytes ('{raw}') to float for {row['name']}: {e}")
            return None

    try:
        pipeline_defs = []
        with db_conn.cursor() as cur:
            # Single query: each pipeline joined to its optional prompt def.
            query = """
                SELECT
                    p.id,
                    p.name,
                    p.source,
                    p.class_name,
                    p.custom_pipeline,
                    p.default_args,
                    p.metadata,
                    p.components,
                    p.estimated_size_bytes,
                    pr.positive_prompt,
                    pr.negative_prompt
                FROM
                    pipeline_defs p
                LEFT JOIN
                    prompt_defs pr ON p.prompt_def_id = pr.id
                WHERE
                    p.name = ANY(%s)
            """

            cur.execute(query, (pipeline_names,))

            for row in cur.fetchall():
                # Rows with no prompt text at all get no PromptDef.
                prompt_def = None
                if row['positive_prompt'] or row['negative_prompt']:
                    prompt_def = PromptDef(
                        pipeline_id=row['id'],
                        positive_prompt=row['positive_prompt'] or "",
                        negative_prompt=row['negative_prompt'] or "",
                    )

                pipeline_defs.append(PipelineDef(
                    name=row['name'],
                    source=row['source'] or "",
                    class_name=row['class_name'] or "",
                    custom_pipeline=row['custom_pipeline'] or "",
                    default_args=_json_field(row, 'default_args'),
                    metadata=_json_field(row, 'metadata'),
                    components=_json_field(row, 'components'),
                    prompt_def=prompt_def,
                    estimated_size_gb=_size_gb(row),
                ))

        return pipeline_defs

    except Exception as e:
        logger.error(f"Error retrieving pipeline definitions: {e}")
        # Leave the connection in a clean state before propagating.
        db_conn.rollback()
        raise
@@ -0,0 +1,43 @@
1
+ import re
2
+ from typing import Dict, Any
3
+ from dataclasses import asdict, is_dataclass
4
+ from .base_types.config import RuntimeConfig
5
+
6
+
7
def flatten_architectures(architectures):
    """Flatten an {arch_id: class-or-list-of-classes} mapping.

    A list entry is expanded so each class gets its own key of the form
    "<arch_id>:<ClassName>"; scalar entries keep their original id.

    Bug fix: the expanded key now maps to the individual class (`arch`),
    not the whole containing list as before.
    """
    flat_architectures = {}
    for arch_id, entry in architectures.items():
        if isinstance(entry, list):
            for arch in entry:
                # Suffix with the class name to disambiguate multiple
                # architectures registered under a single id.
                flat_architectures[f"{arch_id}:{arch.__name__}"] = arch
        else:
            flat_architectures[arch_id] = entry

    return flat_architectures
17
+
18
+
19
def to_snake_case(value):
    """
    Convert CamelCase to snake_case
    """
    # Insert "_" at lower->Upper boundaries and before the final capital of
    # an acronym run (e.g. "HTTPServer" -> "http_server").
    boundary = re.compile(r"(?<=[a-z])(?=[A-Z])|(?<=[A-Z])(?=[A-Z][a-z])")
    return boundary.sub("_", value).lower()
25
+
26
+
27
def serialize_config(config: RuntimeConfig) -> dict[str, Any]:
    """
    Serialize a dataclass (like RuntimeConfig) into a plain dictionary,
    recursing through nested dataclasses, lists, and dicts.
    """

    def _convert(value):
        # asdict already recurses into nested dataclass fields; the extra
        # walk also normalizes dataclasses hiding inside plain lists/dicts.
        if is_dataclass(value):
            return {key: _convert(item) for key, item in asdict(value).items()}
        if isinstance(value, list):
            return [_convert(item) for item in value]
        if isinstance(value, dict):
            return {key: _convert(item) for key, item in value.items()}
        return value

    return _convert(config)
gen_worker/types.py ADDED
@@ -0,0 +1,47 @@
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ from typing import IO, Optional
5
+
6
+ import msgspec
7
+
8
+
9
class Asset(msgspec.Struct):
    """Reference to a file in the invoking tenant's file store.

    The worker runtime should populate `local_path` before invoking tenant code
    so the function can open/read the file efficiently.
    """

    # Reference to the file in the tenant's store (the only required field).
    ref: str
    # Owning tenant id, when known — presumably set by the runtime; confirm.
    tenant_id: Optional[str] = None
    # Path of the locally materialized copy; None until the file is fetched.
    local_path: Optional[str] = None
    # Optional file metadata; NOTE(review): looks like these are filled in
    # alongside local_path by the runtime — verify against the worker code.
    mime_type: Optional[str] = None
    size_bytes: Optional[int] = None
    sha256: Optional[str] = None

    def __fspath__(self) -> str:
        """Expose the local path so an Asset works with os.fspath()/open().

        Raises:
            ValueError: if the file has not been materialized locally.
        """
        if self.local_path is None:
            raise ValueError("Asset.local_path is not set (file not materialized)")
        return self.local_path

    def open(self, mode: str = "rb") -> IO[bytes]:
        """Open the materialized file; binary modes only.

        This method shadows the builtin `open` in the class namespace, but
        the call below still resolves to the builtin.

        Raises:
            ValueError: on a text mode, or if the file is not materialized.
        """
        if "b" not in mode:
            raise ValueError("Asset.open only supports binary modes")
        if self.local_path is None:
            raise ValueError("Asset.local_path is not set (file not materialized)")
        return open(self.local_path, mode)

    def exists(self) -> bool:
        """Return True iff the materialized file exists on local disk."""
        if self.local_path is None:
            return False
        return os.path.exists(self.local_path)

    def read_bytes(self, max_bytes: Optional[int] = None) -> bytes:
        """Read the whole file into memory, optionally capped at max_bytes.

        Reads max_bytes + 1 so an oversized file is detected and rejected
        rather than silently truncated.

        Raises:
            ValueError: if not materialized, or the file exceeds max_bytes.
        """
        if self.local_path is None:
            raise ValueError("Asset.local_path is not set (file not materialized)")
        with open(self.local_path, "rb") as f:
            data = f.read() if max_bytes is None else f.read(max_bytes + 1)
            if max_bytes is not None and len(data) > max_bytes:
                raise ValueError("asset too large to read into memory")
            return data