gen-worker 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gen_worker/__init__.py +19 -0
- gen_worker/decorators.py +66 -0
- gen_worker/default_model_manager/__init__.py +5 -0
- gen_worker/downloader.py +84 -0
- gen_worker/entrypoint.py +135 -0
- gen_worker/errors.py +10 -0
- gen_worker/model_interface.py +48 -0
- gen_worker/pb/__init__.py +27 -0
- gen_worker/pb/frontend_pb2.py +53 -0
- gen_worker/pb/frontend_pb2_grpc.py +189 -0
- gen_worker/pb/worker_scheduler_pb2.py +69 -0
- gen_worker/pb/worker_scheduler_pb2_grpc.py +100 -0
- gen_worker/py.typed +0 -0
- gen_worker/testing/__init__.py +1 -0
- gen_worker/testing/stub_manager.py +69 -0
- gen_worker/torch_manager/__init__.py +4 -0
- gen_worker/torch_manager/manager.py +2059 -0
- gen_worker/torch_manager/utils/base_types/architecture.py +145 -0
- gen_worker/torch_manager/utils/base_types/common.py +52 -0
- gen_worker/torch_manager/utils/base_types/config.py +46 -0
- gen_worker/torch_manager/utils/config.py +321 -0
- gen_worker/torch_manager/utils/db/database.py +46 -0
- gen_worker/torch_manager/utils/device.py +26 -0
- gen_worker/torch_manager/utils/diffusers_fix.py +10 -0
- gen_worker/torch_manager/utils/flashpack_loader.py +262 -0
- gen_worker/torch_manager/utils/globals.py +59 -0
- gen_worker/torch_manager/utils/load_models.py +238 -0
- gen_worker/torch_manager/utils/local_cache.py +340 -0
- gen_worker/torch_manager/utils/model_downloader.py +763 -0
- gen_worker/torch_manager/utils/parse_cli.py +98 -0
- gen_worker/torch_manager/utils/paths.py +22 -0
- gen_worker/torch_manager/utils/repository.py +141 -0
- gen_worker/torch_manager/utils/utils.py +43 -0
- gen_worker/types.py +47 -0
- gen_worker/worker.py +1720 -0
- gen_worker-0.1.4.dist-info/METADATA +113 -0
- gen_worker-0.1.4.dist-info/RECORD +38 -0
- gen_worker-0.1.4.dist-info/WHEEL +4 -0
+++ gen_worker/torch_manager/utils/base_types/architecture.py
@@ -0,0 +1,145 @@
+import torch
+from abc import ABC, abstractmethod
+from typing import Any, Iterable, TypeVar, Optional, TypedDict, Generic
+from spandrel import Architecture as SpandrelArchitecture, ImageModelDescriptor
+from .common import StateDict, TorchDevice
+
+T = TypeVar("T", bound=torch.nn.Module, covariant=True)
+
+
+class ComponentMetadata(TypedDict):
+    display_name: str
+    input_space: str
+    output_space: str
+
+
+# TODO: in the future, maybe we can compare sets of keys rather than use a
+# detect method? That might be faster.
+
+
+class Architecture(ABC, Generic[T]):
+    """
+    The abstract base class that all cozy-creator Architectures should implement.
+    """
+
+    @property
+    def display_name(self) -> str:
+        return self._display_name
+
+    @property
+    def input_space(self) -> str:
+        return self._input_space
+
+    @property
+    def output_space(self) -> str:
+        return self._output_space
+
+    @property
+    def model(self) -> T:
+        """Access the underlying PyTorch model."""
+        return self._model  # type: ignore
+
+    @property
+    def config(self) -> Any:
+        return self._config
+
+    def __init__(
+        self,
+        *,
+        state_dict: Optional[StateDict] = None,
+        metadata: Optional[dict[str, Any]] = None,
+    ) -> None:
+        """
+        Constructor signature should look like this, although this abstract base
+        class does not (and cannot) enforce your constructor signature.
+        """
+        self._display_name = "default"
+        self._input_space = "default"
+        self._output_space = "default"
+        self._config = {}
+
+    @classmethod
+    @abstractmethod
+    def detect(
+        cls,
+        *,
+        state_dict: Optional[StateDict] = None,
+        metadata: Optional[dict[str, Any]] = None,
+    ) -> Optional[ComponentMetadata]:
+        """
+        Detects whether the given state dictionary matches the architecture.
+
+        Args:
+            state_dict (StateDict): The state dictionary from a PyTorch model.
+            metadata (dict[str, Any]): Optional additional metadata to help identify the model.
+
+        Returns:
+            ComponentMetadata if the state dictionary matches the architecture,
+            None otherwise.
+        """
+        pass
+
+    @abstractmethod
+    def load(
+        self,
+        state_dict: StateDict,
+        device: Optional[TorchDevice] = None,
+    ) -> None:
+        """
+        Loads a model from the given state dictionary according to the architecture.
+
+        Args:
+            state_dict (StateDict): The state dictionary from a PyTorch model.
+            device: The device the loaded model is sent to.
+        """
+        pass
+
+
+class SpandrelArchitectureAdapter(Architecture):
+    """
+    This class converts architectures from the spandrel library to our own
+    Architecture interface.
+    """
+
+    def __init__(self, arch: SpandrelArchitecture):
+        super().__init__()
+        if not isinstance(arch, SpandrelArchitecture):
+            raise TypeError("'arch' must be an instance of spandrel Architecture")
+
+        self.inner = arch
+        self._model = None
+        self._display_name = self.inner.name
+
+    def load(self, state_dict: StateDict, device: Optional[TorchDevice] = None) -> None:
+        descriptor = self.inner.load(state_dict)
+        if not isinstance(descriptor, ImageModelDescriptor):
+            raise TypeError("descriptor must be an instance of ImageModelDescriptor")
+
+        self._model = descriptor.model
+        if device is not None:
+            self._model.to(device)
+        elif descriptor.supports_half:
+            self._model.to(torch.float16)
+        elif descriptor.supports_bfloat16:
+            self._model.to(torch.bfloat16)
+        else:
+            raise Exception("Device not provided and could not be inferred")
+
+    @classmethod
+    def detect(
+        cls,
+        *,
+        state_dict: Optional[StateDict] = None,
+        metadata: Optional[dict[str, Any]] = None,
+    ) -> Optional[ComponentMetadata]:
+        pass
+
+
+def architecture_validator(plugin: Any) -> bool:
+    try:
+        if isinstance(plugin, Iterable):
+            return all(architecture_validator(p) for p in plugin)
+        return issubclass(plugin, Architecture)
+    except TypeError:
+        print(f"Invalid plugin type: {plugin}")
+        return False
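For orientation, a minimal concrete Architecture might look like the sketch below. It assumes the names defined in this file (Architecture, StateDict, ComponentMetadata, TorchDevice); the class name, detection key, and stand-in network are hypothetical.

from typing import Any, Optional

import torch

# Hypothetical subclass: detect() matches a checkpoint key that only this
# architecture's state dicts would contain, and load() builds the model.
class ToyUpscaler(Architecture[torch.nn.Module]):
    @classmethod
    def detect(
        cls,
        *,
        state_dict: Optional[StateDict] = None,
        metadata: Optional[dict[str, Any]] = None,
    ) -> Optional[ComponentMetadata]:
        if state_dict is not None and "toy.head.weight" in state_dict:
            return ComponentMetadata(
                display_name="Toy Upscaler", input_space="RGB", output_space="RGB"
            )
        return None  # not a match

    def load(self, state_dict: StateDict, device: Optional[TorchDevice] = None) -> None:
        model = torch.nn.Linear(4, 4)  # stand-in for a real network
        model.load_state_dict(state_dict, strict=False)
        self._model = model.to(device) if device is not None else model

Note that architecture_validator(ToyUpscaler) would return True for such a subclass.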
+++ gen_worker/torch_manager/utils/base_types/common.py
@@ -0,0 +1,52 @@
+import torch
+from typing import Union, List, Callable, Any
+import PIL.Image
+import numpy as np
+from enum import Enum
+from multiprocessing.connection import Connection
+
+
+class Language(Enum):
+    """
+    ISO 639-1 language codes; used for localizing text.
+    English will be displayed for all text lacking a localization.
+    """
+
+    ENGLISH = "en"
+    CHINESE = "zh"
+
+
+class Category(Enum):
+    """
+    Used to group nodes by category in the client.
+    """
+
+    LOADER = {Language.ENGLISH: "Loader", Language.CHINESE: "加载器"}
+    PIPE = {Language.ENGLISH: "Pipe", Language.CHINESE: "管道"}
+    UPSCALER = {Language.ENGLISH: "Upscaler", Language.CHINESE: "升频器"}
+    MASK = {Language.ENGLISH: "Mask"}
+    INPAINTING = {Language.ENGLISH: "Inpainting"}
+    IMAGES = {Language.ENGLISH: "Images"}
+
+
+StateDict = dict[str, torch.Tensor]
+"""
+The parameters of a PyTorch model, serialized as a flat dict.
+"""
+
+TorchDevice = Union[str, torch.device]
+"""
+A string like 'cpu', 'cuda', or 'mps', or a torch device object.
+"""
+
+ImageOutputType = Union[List[PIL.Image.Image], np.ndarray]
+"""
+Static typing for image outputs
+"""
+
+Validator = Callable[[Any], bool]
+
+JobQueueItem = tuple[dict[str, Any], Connection]
+"""
+Type of items on the job-queue
+"""
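As a quick illustration of the fallback rule described in the Language docstring, a hypothetical helper (not part of the package) could resolve a Category label like this, assuming the two enums above:

def localize(category: Category, lang: Language) -> str:
    # Fall back to English when a label has no localization for `lang`.
    return category.value.get(lang, category.value[Language.ENGLISH])

print(localize(Category.LOADER, Language.CHINESE))  # 加载器
print(localize(Category.MASK, Language.CHINESE))    # Mask (English fallback)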
+++ gen_worker/torch_manager/utils/base_types/config.py
@@ -0,0 +1,46 @@
+import os
+from dataclasses import dataclass, field
+from typing import Optional, Any
+import logging
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_HOME_DIR = os.path.expanduser("~/.cozy-creator/")
+
+
+@dataclass
+class PipelineConfig:
+    source: str
+    class_name: Optional[str | tuple[str, str]]
+    components: Optional[dict[str, "ComponentConfig"]]
+
+
+@dataclass
+class ComponentConfig:
+    source: str
+    class_name: Optional[str | tuple[str, str]]
+    kwargs: Optional[dict[str, Any]]
+
+
+def default_home_dir() -> str:
+    return DEFAULT_HOME_DIR
+
+
+def default_models_path() -> str:
+    return os.path.join(DEFAULT_HOME_DIR, "models")
+
+
+def default_host() -> str:
+    return "0.0.0.0" if os.path.exists("/.dockerenv") else "localhost"
+
+
+@dataclass
+class RuntimeConfig:
+    home_dir: str = field(default_factory=default_home_dir)
+    environment: str = "dev"
+    host: str = field(default_factory=default_host)
+    port: int = 8882
+    pipeline_defs: dict[str, PipelineConfig] = field(default_factory=dict)
+    enabled_models: list[str] = field(default_factory=list)
+    models_path: str = field(default_factory=default_models_path)
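A sketch of constructing these config objects by hand; the "sdxl" key and source value are illustrative, not defaults shipped with the package:

cfg = RuntimeConfig(
    environment="prod",
    port=9000,
    pipeline_defs={
        "sdxl": PipelineConfig(
            source="stabilityai/stable-diffusion-xl-base-1.0",  # example source
            class_name="StableDiffusionXLPipeline",
            components=None,
        )
    },
)
print(cfg.home_dir)  # defaults to ~/.cozy-creator/
print(cfg.host)      # "0.0.0.0" inside Docker, else "localhost"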
+++ gen_worker/torch_manager/utils/config.py
@@ -0,0 +1,321 @@
+import os
+import argparse
+from typing import Optional, List, Callable, Dict, Any
+from .base_types.config import RuntimeConfig, PipelineConfig
+import yaml
+import json
+import logging
+
+logger = logging.getLogger(__name__)
+
+cozy_config: Optional[RuntimeConfig] = None
+"""
+Global configuration for the Cozy Gen-Server
+"""
+
+ENVIRONMENT: Optional[str] = None
+"""
+Global environment name; None until set_environment() is called.
+"""
+
+
+def config_loaded() -> bool:
+    """
+    Returns a boolean indicating whether the config has been loaded.
+    This will return True if called within the cozy runtime, since the config
+    is loaded at server start.
+    """
+    return cozy_config is not None
+
+
+def load_pipeline_defs_from_db(enabled_models: Optional[List[str]]) -> Dict[str, Any]:
+    """
+    Load pipeline definitions from the database for the enabled models.
+
+    Args:
+        enabled_models: List of model names to fetch from the database. If None
+            or empty, definitions for all usable models are fetched.
+
+    Returns:
+        Dictionary of pipeline definitions keyed by model name
+    """
+    try:
+        from .db.database import get_db_connection
+        from .repository import get_pipeline_defs
+
+        # if not enabled_models:
+        #     return {}
+
+        db_conn = get_db_connection()
+
+        if not db_conn:
+            logger.error("load_pipeline_defs_from_db: Could not get database connection.")
+            return {}  # Cannot proceed without DB connection
+
+        names_to_query: List[str]
+        if not enabled_models:  # None or empty list
+            logger.info("load_pipeline_defs_from_db: No specific model names provided, fetching all model names from DB.")
+            all_db_model_names = []
+            with db_conn.cursor() as cursor:
+                # Fetch only usable models (those with a non-empty source)
+                cursor.execute("SELECT name FROM pipeline_defs WHERE source IS NOT NULL AND source != ''")
+                rows = cursor.fetchall()
+                for row in rows:
+                    all_db_model_names.append(row['name'])
+
+            if not all_db_model_names:
+                logger.warning("load_pipeline_defs_from_db: No model names found in DB to load definitions for.")
+                return {}
+            names_to_query = all_db_model_names
+        else:
+            names_to_query = enabled_models
+
+        logger.debug(f"load_pipeline_defs_from_db: Fetching definitions for models: {names_to_query}")
+
+        # get_pipeline_defs returns a list of PipelineDef objects
+        pipeline_def_objects = get_pipeline_defs(db_conn, names_to_query)
+
+        if not pipeline_def_objects:
+            logger.warning(f"load_pipeline_defs_from_db: get_pipeline_defs returned no objects for names: {names_to_query}")
+            return {}
+
+        logger.info(f"load_pipeline_defs_from_db: Loaded {len(pipeline_def_objects)} PipelineDef objects from repository.")
+
+        # Convert DB models to dictionary format
+        db_pipeline_defs = {}
+        for model in pipeline_def_objects:
+            pipeline_def = {
+                "source": model.source,
+                "class_name": model.class_name,
+                "custom_pipeline": model.custom_pipeline,
+                "default_args": model.default_args,
+                "metadata": model.metadata,
+                "components": {}
+            }
+
+            # Convert components
+            if model.components:
+                for name, comp in model.components.items():
+                    if isinstance(comp, dict):
+                        pipeline_def["components"][name] = {
+                            "class_name": comp.get("class_name", ""),
+                            "source": comp.get("source", ""),
+                            "kwargs": comp.get("kwargs", {})
+                        }
+
+            # Add prompt definitions if available
+            if hasattr(model, "prompt_def") and model.prompt_def:
+                if not pipeline_def.get("default_args"):
+                    pipeline_def["default_args"] = {}
+                pipeline_def["default_args"]["positive_prompt"] = model.prompt_def.positive_prompt
+                pipeline_def["default_args"]["negative_prompt"] = model.prompt_def.negative_prompt
+
+            db_pipeline_defs[model.name] = pipeline_def
+
+        return db_pipeline_defs
+    except Exception as e:
+        logger.error(f"Error loading pipeline definitions from database: {e}")
+        return {}
+
+
+def merge_pipeline_defs(existing_defs: Dict[str, Any], incoming_defs: Dict[str, Any]) -> Dict[str, Any]:
+    """
+    Merge pipeline definitions from different sources, mirroring the Go implementation.
+
+    Args:
+        existing_defs: Pipeline definitions from config.yaml
+        incoming_defs: Pipeline definitions from database
+
+    Returns:
+        Merged pipeline definitions
+    """
+    merged_defs = existing_defs.copy()
+
+    # Merge incoming defs into existing defs
+    for model_id, model_def in incoming_defs.items():
+        if model_id in merged_defs:
+            # Update only empty fields in existing definition
+            existing_def = merged_defs[model_id]
+
+            if not existing_def.get("source"):
+                existing_def["source"] = model_def.get("source", "")
+
+            if not existing_def.get("class_name"):
+                existing_def["class_name"] = model_def.get("class_name", "")
+
+            if not existing_def.get("custom_pipeline"):
+                existing_def["custom_pipeline"] = model_def.get("custom_pipeline", "")
+
+            if not existing_def.get("default_args"):
+                existing_def["default_args"] = model_def.get("default_args", {})
+
+            if not existing_def.get("metadata"):
+                existing_def["metadata"] = model_def.get("metadata", {})
+
+            # Handle components
+            if not existing_def.get("components"):
+                existing_def["components"] = {}
+
+            # Merge components
+            for comp_name, comp_def in model_def.get("components", {}).items():
+                if comp_name in existing_def["components"]:
+                    # Update component fields if empty
+                    existing_comp = existing_def["components"][comp_name]
+
+                    if not existing_comp.get("class_name"):
+                        existing_comp["class_name"] = comp_def.get("class_name", "")
+
+                    if not existing_comp.get("source"):
+                        existing_comp["source"] = comp_def.get("source", "")
+
+                    if not existing_comp.get("kwargs"):
+                        existing_comp["kwargs"] = comp_def.get("kwargs", {})
+                else:
+                    # Add new component
+                    existing_def["components"][comp_name] = comp_def
+        else:
+            # Add new model definition
+            merged_defs[model_id] = model_def
+
+    # Remove models without a source
+    models_to_remove = [model_id for model_id, def_obj in merged_defs.items()
+                        if not def_obj.get("source")]
+    for model_id in models_to_remove:
+        del merged_defs[model_id]
+
+    return merged_defs
+
+
+def load_config() -> RuntimeConfig:
+    """
+    Load the configuration from a YAML file located at COZY_HOME/config.yaml.
+    Merges it with default values and database pipeline definitions.
+    """
+    default_home = os.path.expanduser("~/.cozy-creator")
+
+    cozy_mount_path = os.getenv("COZY_MOUNT_PATH")
+    if cozy_mount_path:
+        default_home = cozy_mount_path
+
+    default_models_path = os.path.join(default_home, "models")
+
+    default_config = {
+        "home_dir": default_home,
+        "environment": "dev",
+        "host": "localhost",
+        "port": 8882,
+        "pipeline_defs": {},
+        # "enabled_models": [],
+        "models_path": default_models_path,
+    }
+
+    logger.debug(f"default_config: {default_config}")
+
+    # Use COZY_HOME if set, else default.
+    home_dir = os.environ.get("COZY_HOME", default_home)
+    config_path = os.path.join(home_dir, "config.yaml")
+
+    merged = default_config.copy()
+
+    if os.path.exists(config_path):
+        try:
+            with open(config_path, "r") as f:
+                yaml_config = yaml.safe_load(f) or {}
+
+            # Merge basic config values
+            # for key, value in yaml_config.items():
+            #     if key != "pipeline_defs":
+            #         merged[key] = value
+
+            # Get pipeline defs from config
+            config_pipeline_defs = yaml_config.get("pipeline_defs", {})
+        except Exception as e:
+            logger.error(f"Error loading config from {config_path}: {e}")
+            config_pipeline_defs = {}
+    else:
+        logger.warning(f"Config file {config_path} not found. Using default configuration.")
+        config_pipeline_defs = {}
+
+    # Get enabled models from environment if provided
+    # enabled_models_env = os.environ.get("ENABLED_MODELS")
+    # if enabled_models_env:
+    #     try:
+    #         # If provided as JSON.
+    #         merged["enabled_models"] = json.loads(enabled_models_env)
+    #     except Exception:
+    #         # Otherwise, assume comma-separated.
+    #         merged["enabled_models"] = [m.strip() for m in enabled_models_env.split(",") if m.strip()]
+
+    # Load ALL pipeline definitions from the database;
+    # load_pipeline_defs_from_db with no names fetches everything.
+    logger.info("Loading all pipeline definitions from database...")
+    db_pipeline_defs = load_pipeline_defs_from_db(enabled_models=None)
+    logger.debug(f"db_pipeline_defs: {db_pipeline_defs}")
+
+    # Merge pipeline definitions from config and database
+    merged["pipeline_defs"] = merge_pipeline_defs(config_pipeline_defs, db_pipeline_defs)
+    logger.debug(f"merged: {merged}")
+
+    return RuntimeConfig(**merged)
+
+
+def set_config(config: RuntimeConfig):
+    """
+    Sets the global configuration object.
+    """
+    global cozy_config
+    cozy_config = config
+
+
+def set_environment(environment: str):
+    """
+    Sets the global environment variable.
+    """
+    global ENVIRONMENT
+    ENVIRONMENT = environment
+
+
+def get_environment() -> str:
+    """
+    Returns the global environment variable.
+    """
+    if ENVIRONMENT is None:
+        raise ValueError("Environment has not been set yet")
+
+    return ENVIRONMENT
+
+
+def get_config() -> RuntimeConfig:
+    """
+    Returns the global configuration object. The config is loaded at server
+    start; calling this before it has been loaded raises an error.
+    """
+    if cozy_config is None:
+        raise ValueError("Config has not been loaded yet")
+
+    return cozy_config
+
+
+ParseArgsMethod = Callable[
+    [argparse.ArgumentParser, Optional[List[str]], Optional[argparse.Namespace]],
+    Optional[argparse.Namespace],
+]
+
+
+def is_model_enabled(model_name: str) -> bool:
+    """
+    Returns a boolean indicating whether a model is enabled in the global configuration.
+    """
+    config = get_config()
+    if config.pipeline_defs is None:
+        return False
+
+    return model_name in config.pipeline_defs
+
+
+def get_mock_config() -> RuntimeConfig:
+    """
+    Returns a mock (or test) version of the global configuration object.
+    This can be used outside of the cozy server environment.
+    """
+
+    environment = "test"
+    # home_dir = DEFAULT_HOME_DIR
+
+    return RuntimeConfig(
+        port=8881,
+        host="127.0.0.1",
+        environment=environment,
+    )
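Pieced together from the functions above, the intended call order appears to be: load once at startup, publish via set_config/set_environment, then read from anywhere with the getters. A minimal sketch (the "sdxl" key is illustrative):

config = load_config()               # defaults + config.yaml + DB pipeline defs
set_config(config)
set_environment(config.environment)

cfg = get_config()                   # usable anywhere after set_config()
if is_model_enabled("sdxl"):         # True once "sdxl" is in cfg.pipeline_defs
    ...

Note the merge precedence in merge_pipeline_defs: config.yaml wins wherever its fields are non-empty, the database fills in the gaps, and any definition still lacking a source is dropped.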
+++ gen_worker/torch_manager/utils/db/database.py
@@ -0,0 +1,46 @@
+import os
+import logging
+from typing import List, Optional
+import psycopg2
+from psycopg2.extras import RealDictCursor
+from dotenv import load_dotenv
+
+logger = logging.getLogger(__name__)
+
+# Load environment variables from .env file
+load_dotenv()
+
+_db_connection = None
+
+
+def get_db_connection():
+    """
+    Returns a connection to the database. Creates a new connection if one doesn't exist.
+    Uses the DB_DSN environment variable for connection details.
+    """
+    global _db_connection
+
+    if _db_connection is None:
+        db_dsn = os.getenv("DB_DSN")
+        if not db_dsn:
+            raise ValueError("DB_DSN environment variable not set")
+
+        try:
+            logger.info("Connecting to database...")
+            _db_connection = psycopg2.connect(db_dsn, cursor_factory=RealDictCursor)
+            logger.info("Database connection established")
+        except Exception as e:
+            logger.error(f"Error connecting to database: {e}")
+            raise
+
+    return _db_connection
+
+
+def close_db_connection():
+    """
+    Closes the database connection if it exists.
+    """
+    global _db_connection
+
+    if _db_connection is not None:
+        _db_connection.close()
+        _db_connection = None
+        logger.info("Database connection closed")
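A short usage sketch, assuming a DSN such as DB_DSN=postgresql://user:pass@localhost:5432/cozy in the environment or .env (the pipeline_defs table and name column come from utils/config.py above):

conn = get_db_connection()
with conn.cursor() as cur:                 # RealDictCursor: rows come back as dicts
    cur.execute("SELECT name FROM pipeline_defs")
    names = [row["name"] for row in cur.fetchall()]
close_db_connection()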
+++ gen_worker/torch_manager/utils/device.py
@@ -0,0 +1,26 @@
+import logging
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def get_torch_device(index: int = 0) -> torch.device:
+    if torch.cuda.is_available():
+        return torch.device("cuda", index)
+    if torch.backends.mps.is_available():
+        return torch.device("mps", index)
+    if torch.xpu.is_available():
+        return torch.device("xpu", index)
+
+    logger.warning("No device found, using CPU. This will slow down performance.")
+    return torch.device("cpu")
+
+
+def get_torch_device_count() -> int:
+    if torch.cuda.is_available():
+        return torch.cuda.device_count()
+    if torch.backends.mps.is_available():
+        return torch.mps.device_count()
+    if torch.xpu.is_available():
+        return torch.xpu.device_count()
+    return 1
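These helpers pick the first available backend in a fixed priority order (CUDA, then MPS, then XPU, then CPU). A small sketch of how a caller might use them:

device = get_torch_device()        # e.g. device(type='cuda', index=0)
x = torch.zeros(8, device=device)  # allocate directly on the chosen device

for i in range(get_torch_device_count()):
    print(get_torch_device(i))     # enumerate one device per index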
+++ gen_worker/torch_manager/utils/diffusers_fix.py
@@ -0,0 +1,10 @@
+def fix_sdxl_compat():
+    """Fix SDXL compatibility across diffusers versions"""
+    try:
+        from diffusers import StableDiffusionXLPipeline
+        if not hasattr(StableDiffusionXLPipeline, 'do_classifier_free_guidance'):
+            # Some diffusers versions lack this attribute; default it on.
+            StableDiffusionXLPipeline.do_classifier_free_guidance = True
+    except ImportError:  # diffusers not installed; nothing to patch
+        pass
+
+fix_sdxl_compat()  # Auto-run on import