gen-worker 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38):
  1. gen_worker/__init__.py +19 -0
  2. gen_worker/decorators.py +66 -0
  3. gen_worker/default_model_manager/__init__.py +5 -0
  4. gen_worker/downloader.py +84 -0
  5. gen_worker/entrypoint.py +135 -0
  6. gen_worker/errors.py +10 -0
  7. gen_worker/model_interface.py +48 -0
  8. gen_worker/pb/__init__.py +27 -0
  9. gen_worker/pb/frontend_pb2.py +53 -0
  10. gen_worker/pb/frontend_pb2_grpc.py +189 -0
  11. gen_worker/pb/worker_scheduler_pb2.py +69 -0
  12. gen_worker/pb/worker_scheduler_pb2_grpc.py +100 -0
  13. gen_worker/py.typed +0 -0
  14. gen_worker/testing/__init__.py +1 -0
  15. gen_worker/testing/stub_manager.py +69 -0
  16. gen_worker/torch_manager/__init__.py +4 -0
  17. gen_worker/torch_manager/manager.py +2059 -0
  18. gen_worker/torch_manager/utils/base_types/architecture.py +145 -0
  19. gen_worker/torch_manager/utils/base_types/common.py +52 -0
  20. gen_worker/torch_manager/utils/base_types/config.py +46 -0
  21. gen_worker/torch_manager/utils/config.py +321 -0
  22. gen_worker/torch_manager/utils/db/database.py +46 -0
  23. gen_worker/torch_manager/utils/device.py +26 -0
  24. gen_worker/torch_manager/utils/diffusers_fix.py +10 -0
  25. gen_worker/torch_manager/utils/flashpack_loader.py +262 -0
  26. gen_worker/torch_manager/utils/globals.py +59 -0
  27. gen_worker/torch_manager/utils/load_models.py +238 -0
  28. gen_worker/torch_manager/utils/local_cache.py +340 -0
  29. gen_worker/torch_manager/utils/model_downloader.py +763 -0
  30. gen_worker/torch_manager/utils/parse_cli.py +98 -0
  31. gen_worker/torch_manager/utils/paths.py +22 -0
  32. gen_worker/torch_manager/utils/repository.py +141 -0
  33. gen_worker/torch_manager/utils/utils.py +43 -0
  34. gen_worker/types.py +47 -0
  35. gen_worker/worker.py +1720 -0
  36. gen_worker-0.1.4.dist-info/METADATA +113 -0
  37. gen_worker-0.1.4.dist-info/RECORD +38 -0
  38. gen_worker-0.1.4.dist-info/WHEEL +4 -0
@@ -0,0 +1,145 @@
1
+ import torch
2
+ from abc import ABC, abstractmethod
3
+ from typing import Any, Iterable, TypeVar, Optional, TypedDict, Generic
4
+ from spandrel import Architecture as SpandrelArchitecture, ImageModelDescriptor
5
+ from .common import StateDict, TorchDevice
6
+
7
# Type variable for the wrapped PyTorch module. It is bound to torch.nn.Module and
# declared covariant, so Architecture[SubModule] is accepted where
# Architecture[torch.nn.Module] is expected.
T = TypeVar("T", bound=torch.nn.Module, covariant=True)


class ComponentMetadata(TypedDict):
    """
    Metadata describing a recognised model component.

    This is the value returned by ``Architecture.detect`` when a state dict matches.
    """

    # Human-readable name of the component.
    display_name: str
    # NOTE(review): the "space" fields are free-form strings; the set of allowed
    # values is not defined in this module.
    input_space: str
    output_space: str
14
+
15
+
16
# TODO: in the future, maybe we can compare sets of keys, rather than use
# a detect method? That might be more optimized.


class Architecture(ABC, Generic[T]):
    """
    The abstract-base-class that all cozy-creator Architectures should implement.

    An Architecture recognises a PyTorch state dict (``detect``) and builds the
    corresponding ``torch.nn.Module`` from it (``load``).

    The base constructor sets the display name and the input/output spaces to
    ``"default"`` and the config to an empty dict. It does NOT set ``_model``, so
    reading ``model`` raises ``AttributeError`` until a subclass assigns
    ``self._model`` (normally in ``load``).
    """

    @property
    def display_name(self) -> str:
        """Human-readable name of this architecture."""
        return self._display_name

    @property
    def input_space(self) -> str:
        """Input space label (``"default"`` unless a subclass overrides it)."""
        return self._input_space

    @property
    def output_space(self) -> str:
        """Output space label (``"default"`` unless a subclass overrides it)."""
        return self._output_space

    @property
    def model(self) -> T:
        """Access the underlying PyTorch model."""
        return self._model  # type: ignore

    @property
    def config(self) -> Any:
        """Architecture-specific configuration (an empty dict by default)."""
        return self._config

    def __init__(
        self,
        *,
        state_dict: Optional[StateDict] = None,
        metadata: Optional[dict[str, Any]] = None,
    ) -> None:
        """
        Constructor signature should look like this, although this abstract-base
        class does not (and cannot) enforce your constructor signature.

        The base implementation ignores ``state_dict`` and ``metadata``. It only
        initialises the default attribute values read by the properties above.
        """
        self._display_name = "default"
        self._input_space = "default"
        self._output_space = "default"
        self._config = {}

    @classmethod
    @abstractmethod
    def detect(
        cls,
        *,
        state_dict: Optional[StateDict] = None,
        metadata: Optional[dict[str, Any]] = None,
    ) -> Optional[ComponentMetadata]:
        """
        Detects whether the given state dictionary matches the architecture.

        Args:
            state_dict (StateDict): The state dictionary from a PyTorch model.
            metadata (dict[str, Any]): Optional additional metadata to help
                identify the model.

        Returns:
            Optional[ComponentMetadata]: Metadata describing the matched component
            if the state dictionary matches this architecture, otherwise None.
        """
        pass

    @abstractmethod
    def load(
        self,
        state_dict: StateDict,
        device: Optional[TorchDevice] = None,
    ) -> None:
        """
        Loads a model from the given state dictionary according to the architecture.

        Implementations are expected to store the result in ``self._model`` so
        that it is available through the ``model`` property.

        Args:
            state_dict (StateDict): The state dictionary from a PyTorch model.
            device: The device the loaded model is sent to.
        """
        pass
96
+
97
+
98
+
99
class SpandrelArchitectureAdapter(Architecture):
    """
    This class converts architectures from the spandrel library to our own
    Architecture interface.

    The wrapped spandrel architecture does the actual loading. This adapter
    exposes the resulting ``torch.nn.Module`` through ``Architecture.model`` and
    uses spandrel's architecture name as the display name.
    """

    def __init__(self, arch: SpandrelArchitecture):
        """
        Wrap a spandrel ``Architecture`` instance.

        Raises:
            TypeError: If ``arch`` is not a spandrel ``Architecture``.
        """
        super().__init__()
        if not isinstance(arch, SpandrelArchitecture):
            raise TypeError("'arch' must be an instance of spandrel Architecture")

        self.inner = arch
        self._model = None
        self._display_name = self.inner.name

    def load(self, state_dict: StateDict, device: Optional[TorchDevice] = None) -> None:
        """
        Load ``state_dict`` with the wrapped spandrel architecture.

        If ``device`` is given, the model is moved there. Otherwise it is cast to
        float16 when supported, else to bfloat16 when supported.

        Raises:
            TypeError: If spandrel returns something other than an
                ``ImageModelDescriptor``.
            RuntimeError: If no device is given and the model supports neither
                float16 nor bfloat16. (``self._model`` is already set by then.)
        """
        descriptor = self.inner.load(state_dict)
        if not isinstance(descriptor, ImageModelDescriptor):
            raise TypeError("descriptor must be an instance of ImageModelDescriptor")

        model = descriptor.model
        self._model = model
        # NOTE(review): when a device is given, no dtype conversion happens. When
        # it is not, the model is cast to a half-precision dtype but stays on its
        # current device. Confirm that this mix of device and dtype handling is
        # intended.
        if device is not None:
            model.to(device)
        elif descriptor.supports_half:
            model.to(torch.float16)
        elif descriptor.supports_bfloat16:
            model.to(torch.bfloat16)
        else:
            # RuntimeError (an Exception subclass) keeps existing
            # ``except Exception`` handlers working while being more specific.
            raise RuntimeError("Device not provided and could not be inferred")

    @classmethod
    def detect(
        cls,
        state_dict: Optional[StateDict] = None,
        metadata: Optional[dict[str, Any]] = None,
    ) -> Optional[ComponentMetadata]:
        """
        Not implemented for the adapter: always returns None.

        ``detect`` is a classmethod and therefore has no access to the spandrel
        architecture wrapped by an instance (``self.inner``).
        """
        return None
136
+
137
+
138
def architecture_validator(plugin: Any) -> bool:
    """
    Check that ``plugin`` is an ``Architecture`` subclass, or an iterable of them.

    Strings and bytes are treated as single (invalid) values rather than being
    iterated. A one-character string yields itself when iterated, so recursing
    into it never terminates, and an empty string would otherwise count as an
    "empty iterable" and be accepted.

    Returns:
        True if ``plugin`` (or every element of a non-string iterable) is a
        subclass of ``Architecture``. False otherwise, including for non-class
        values, for which a message is printed. An empty non-string iterable
        returns True.
    """
    try:
        if isinstance(plugin, Iterable) and not isinstance(plugin, (str, bytes, bytearray)):
            return all(architecture_validator(p) for p in plugin)
        return issubclass(plugin, Architecture)
    except TypeError:
        # issubclass() raises TypeError when ``plugin`` is not a class.
        print(f"Invalid plugin type: {plugin}")
        return False
@@ -0,0 +1,52 @@
1
+ import torch
2
+ from typing import Union, List, Callable, Any
3
+ import PIL.Image
4
+ import numpy as np
5
+ from enum import Enum
6
+ from multiprocessing.connection import Connection
7
+
8
+
9
class Language(Enum):
    """
    ISO 639-1 language codes; used for localizing text.
    English will be displayed for all text lacking a localization.
    """

    ENGLISH = "en"
    CHINESE = "zh"


class Category(Enum):
    """
    Used to group nodes by category in the client.

    Each member's value is a dict mapping a ``Language`` to the localized label.
    Only LOADER, PIPE and UPSCALER currently provide a Chinese label; the other
    members are English-only.
    """

    LOADER = {Language.ENGLISH: "Loader", Language.CHINESE: "加载器"}
    PIPE = {Language.ENGLISH: "Pipe", Language.CHINESE: "管道"}
    UPSCALER = {Language.ENGLISH: "Upscaler", Language.CHINESE: "升频器"}
    MASK = {Language.ENGLISH: "Mask"}
    INPAINTING = {Language.ENGLISH: "Inpainting"}
    IMAGES = {Language.ENGLISH: "Images"}
30
+
31
+
32
# Shared type aliases. The bare strings after each alias are attribute
# docstrings picked up by documentation tools; they have no runtime effect.

StateDict = dict[str, torch.Tensor]
"""
The parameters of a PyTorch model, serialized as a flat dict.
"""

TorchDevice = Union[str, torch.device]
"""
A string like 'cpu', 'cuda', or 'mps' or a torch device object.
"""

ImageOutputType = Union[List[PIL.Image.Image], np.ndarray]
"""
Static typing for image outputs
"""

# Predicate that takes an arbitrary object and reports whether it is acceptable
# (architecture_validator(plugin: Any) -> bool in architecture.py has this shape).
Validator = Callable[[Any], bool]

# presumably (job payload dict, pipe end used to reply to the submitter);
# confirm against the job-queue producer/consumer.
JobQueueItem = tuple[dict[str, Any], Connection]
"""
Type of items on the job-queue
"""
@@ -0,0 +1,46 @@
1
+ import os
2
+ from dataclasses import dataclass, field
3
+ from typing import Optional, Any
4
+ import logging
5
+
6
logger = logging.getLogger(__name__)

# Default root directory for cozy-creator data, expanded from "~/.cozy-creator/"
# once at import time. Note the trailing slash.
DEFAULT_HOME_DIR = os.path.expanduser("~/.cozy-creator/")
9
+
10
+
11
@dataclass
class PipelineConfig:
    """
    Definition of a pipeline: where to load it from, which class to use, and
    optional per-component overrides.

    Every field is a required constructor argument (none has a default), even
    the ones typed ``Optional``.
    """

    # Where to load the pipeline from. NOTE(review): the format (repo id, URL,
    # local path) is not defined in this module.
    source: str
    # A class name, or presumably a (module, class) pair; confirm against the loader.
    class_name: Optional[str | tuple[str, str]]
    # Per-component overrides, presumably keyed by component name.
    components: Optional[dict[str, "ComponentConfig"]]


@dataclass
class ComponentConfig:
    """
    Definition of a single pipeline component. All fields are required
    constructor arguments.
    """

    # Where to load the component from (format not defined in this module).
    source: str
    # A class name, or presumably a (module, class) pair; confirm against the loader.
    class_name: Optional[str | tuple[str, str]]
    # Extra keyword arguments, presumably forwarded when the component is built.
    kwargs: Optional[dict[str, Any]]
23
+
24
+
25
def default_home_dir() -> str:
    """Return the default cozy-creator home directory (``DEFAULT_HOME_DIR``).

    Used as the ``default_factory`` for ``RuntimeConfig.home_dir``.
    """
    home = DEFAULT_HOME_DIR
    return home
27
+
28
+
29
def default_models_path() -> str:
    """Return ``<DEFAULT_HOME_DIR>/models``, the default models directory."""
    base_dir = DEFAULT_HOME_DIR
    models_dir = os.path.join(base_dir, "models")
    return models_dir
31
+
32
+
33
+
34
def default_host() -> str:
    """
    Return the default bind host.

    When the Docker marker file ``/.dockerenv`` exists, listen on all interfaces
    ("0.0.0.0"); otherwise listen on "localhost" only.
    """
    running_in_docker = os.path.exists("/.dockerenv")
    if running_in_docker:
        return "0.0.0.0"
    return "localhost"
36
+
37
+
38
@dataclass
class RuntimeConfig:
    """
    Runtime configuration. Every field has a default, so ``RuntimeConfig()`` is valid.
    """

    # Root directory for cozy-creator data; defaults to DEFAULT_HOME_DIR.
    home_dir: str = field(default_factory=default_home_dir)
    # Deployment environment name.
    environment: str = "dev"
    # Bind host; "0.0.0.0" when /.dockerenv exists, else "localhost" (default_host).
    host: str = field(default_factory=default_host)
    port: int = 8882
    # Pipeline definitions keyed by model name.
    # NOTE(review): load_config() in utils/config.py fills this with plain dicts,
    # not PipelineConfig instances, despite the annotation.
    pipeline_defs: dict[str, PipelineConfig] = field(default_factory=dict)
    # Names of models to enable.
    enabled_models: list[str] = field(default_factory=list)
    # Models directory; defaults to <DEFAULT_HOME_DIR>/models.
    models_path: str = field(default_factory=default_models_path)
@@ -0,0 +1,321 @@
1
+ import os
2
+ import argparse
3
+ from typing import Optional, List, Callable, Dict, Any
4
+ from .base_types.config import RuntimeConfig, PipelineConfig
5
+ import yaml
6
+ import json
7
+ import logging
8
+
9
logger = logging.getLogger(__name__)

# Process-wide configuration: None until set_config() is called; read through
# get_config() and config_loaded().
# NOTE(review): the ENVIRONMENT global written by set_environment() and read by
# get_environment() is not initialised at module level.
cozy_config: Optional[RuntimeConfig] = None
"""
Global configuration for the Cozy Gen-Server
"""
15
+
16
+
17
def config_loaded() -> bool:
    """
    Report whether a global config has been stored (``cozy_config`` is not None).

    Within the cozy runtime this is True, since the config is loaded at start-up.
    """
    if cozy_config is None:
        return False
    return True
23
+
24
+
25
def _fetch_all_model_names(db_conn: Any) -> List[str]:
    """Return the name of every ``pipeline_defs`` row whose source is non-empty."""
    # Rows are indexed by column name, which requires a dict-row cursor
    # (get_db_connection() configures RealDictCursor).
    with db_conn.cursor() as cursor:
        cursor.execute(
            "SELECT name FROM pipeline_defs WHERE source IS NOT NULL AND source != ''"
        )
        return [row["name"] for row in cursor.fetchall()]


def _pipeline_def_to_dict(model: Any) -> Dict[str, Any]:
    """Convert one repository PipelineDef object into the plain-dict config format."""
    pipeline_def: Dict[str, Any] = {
        "source": model.source,
        "class_name": model.class_name,
        "custom_pipeline": model.custom_pipeline,
        "default_args": model.default_args,
        "metadata": model.metadata,
        "components": {},
    }

    # Only dict-shaped component entries are converted; anything else is skipped.
    for name, comp in (model.components or {}).items():
        if isinstance(comp, dict):
            pipeline_def["components"][name] = {
                "class_name": comp.get("class_name", ""),
                "source": comp.get("source", ""),
                "kwargs": comp.get("kwargs", {}),
            }

    prompt_def = getattr(model, "prompt_def", None)
    if prompt_def:
        # Copy first: writing the prompts into model.default_args directly would
        # mutate the repository object.
        default_args = dict(pipeline_def["default_args"] or {})
        default_args["positive_prompt"] = prompt_def.positive_prompt
        default_args["negative_prompt"] = prompt_def.negative_prompt
        pipeline_def["default_args"] = default_args

    return pipeline_def


def _rollback_quietly(db_conn: Any) -> None:
    """Roll back ``db_conn``'s current transaction, if there is a connection; never raises."""
    if db_conn is None:
        return
    try:
        db_conn.rollback()
    except Exception:
        logger.debug("load_pipeline_defs_from_db: rollback after error failed", exc_info=True)


def load_pipeline_defs_from_db(enabled_models: Optional[List[str]] = None) -> Dict[str, Any]:
    """
    Load pipeline definitions from the database for the enabled models.

    Best-effort: any error (missing DB_DSN, import/connection failure, query
    error) is logged with its traceback and an empty dict is returned. After a
    failure the current transaction is rolled back, so the cached connection is
    not left in an aborted state for later queries.

    Args:
        enabled_models: Model names to fetch. If None or empty, the names of all
            ``pipeline_defs`` rows with a non-empty source are used.

    Returns:
        Dictionary of pipeline definitions keyed by model name. Each value has the
        keys ``source``, ``class_name``, ``custom_pipeline``, ``default_args``,
        ``metadata`` and ``components``.
    """
    db_conn = None
    try:
        # Imported inside the try so that import failures are handled like any
        # other database error below.
        from .db.database import get_db_connection
        from .repository import get_pipeline_defs

        db_conn = get_db_connection()
        if not db_conn:
            logger.error("load_pipeline_defs_from_db: Could not get database connection.")
            return {}  # Cannot proceed without DB connection

        if enabled_models:
            names_to_query: List[str] = enabled_models
        else:
            logger.info("load_pipeline_defs_from_db: No specific model names provided, fetching all model names from DB.")
            names_to_query = _fetch_all_model_names(db_conn)
            if not names_to_query:
                logger.warning("load_pipeline_defs_from_db: No model names found in DB to load definitions for.")
                return {}

        logger.debug(f"load_pipeline_defs_from_db: Fetching definitions for models: {names_to_query}")

        # get_pipeline_defs returns a list of PipelineDef objects.
        pipeline_def_objects = get_pipeline_defs(db_conn, names_to_query)
        if not pipeline_def_objects:
            logger.warning(f"load_pipeline_defs_from_db: get_pipeline_defs returned no objects for names: {names_to_query}")
            return {}

        logger.info(f"load_pipeline_defs_from_db: Loaded {len(pipeline_def_objects)} PipelineDef objects from repository.")

        return {model.name: _pipeline_def_to_dict(model) for model in pipeline_def_objects}
    except Exception as e:
        # logger.exception keeps the traceback that logger.error discarded.
        logger.exception(f"Error loading pipeline definitions from database: {e}")
        _rollback_quietly(db_conn)
        return {}
111
+
112
+
113
def merge_pipeline_defs(existing_defs: Dict[str, Any], incoming_defs: Dict[str, Any]) -> Dict[str, Any]:
    """
    Merge pipeline definitions from different sources, similar to Go implementation.

    Values already present in ``existing_defs`` win. ``incoming_defs`` only fills
    fields that are missing or empty (falsy), both at the pipeline level and per
    component. Pipelines whose entry is not a dict, or that end up without a
    ``source``, are dropped from the result.

    Neither input is modified, and the result shares no nested objects with
    them (it is built from deep copies).

    Args:
        existing_defs: Pipeline definitions from config.yaml
        incoming_defs: Pipeline definitions from database

    Returns:
        Merged pipeline definitions
    """
    import copy

    # Deep copies: filling in fields below must not mutate the callers' dicts,
    # and the result must not alias them.
    merged_defs = copy.deepcopy(existing_defs)

    for model_id, model_def in incoming_defs.items():
        model_def = copy.deepcopy(model_def)
        existing_def = merged_defs.get(model_id)

        if not isinstance(existing_def, dict):
            # New model, or an existing entry that is not a mapping (e.g. a bare
            # ``model_id:`` key in YAML loads as None): take the incoming one.
            merged_defs[model_id] = model_def
            continue

        # Update only empty fields in the existing definition.
        for key, default in (
            ("source", ""),
            ("class_name", ""),
            ("custom_pipeline", ""),
            ("default_args", {}),
            ("metadata", {}),
        ):
            if not existing_def.get(key):
                existing_def[key] = model_def.get(key, default)

        if not existing_def.get("components"):
            existing_def["components"] = {}
        components = existing_def["components"]

        for comp_name, comp_def in (model_def.get("components") or {}).items():
            existing_comp = components.get(comp_name)
            if not isinstance(existing_comp, dict):
                # New component (or a non-mapping placeholder): add it as-is.
                components[comp_name] = comp_def
                continue
            for key, default in (("class_name", ""), ("source", ""), ("kwargs", {})):
                if not existing_comp.get(key):
                    existing_comp[key] = comp_def.get(key, default)

    # Keep only usable definitions (a dict with a non-empty source), preserving order.
    return {
        model_id: def_obj
        for model_id, def_obj in merged_defs.items()
        if isinstance(def_obj, dict) and def_obj.get("source")
    }
179
+
180
+
181
def load_config() -> RuntimeConfig:
    """
    Load the configuration from a YAML file located at COZY_HOME/config.yaml.
    Merges it with default values and database pipeline definitions.

    How values are resolved:

    * ``home_dir`` is ``$COZY_MOUNT_PATH`` when set, else ``~/.cozy-creator``;
      ``models_path`` is ``<home_dir>/models``.
    * ``config.yaml`` is read from ``$COZY_HOME`` when set, else from ``home_dir``.
      Only its ``pipeline_defs`` section is used. A missing, unreadable or
      malformed file is logged and treated as empty.
    * All pipeline definitions stored in the database are merged in via
      ``merge_pipeline_defs`` (config.yaml values take precedence).

    Returns:
        A new ``RuntimeConfig``. The global config is not modified; pass the
        result to ``set_config`` to install it.
    """
    default_home = os.getenv("COZY_MOUNT_PATH") or os.path.expanduser("~/.cozy-creator")

    default_config: Dict[str, Any] = {
        "home_dir": default_home,
        "environment": "dev",
        "host": "localhost",
        "port": 8882,
        "pipeline_defs": {},
        "models_path": os.path.join(default_home, "models"),
    }
    # Diagnostics go through logging rather than print(), so they can be
    # silenced and do not dump every pipeline definition to stdout on start-up.
    logger.debug("default_config: %s", default_config)

    # NOTE(review): config.yaml honours COZY_HOME, but home_dir/models_path above
    # only honour COZY_MOUNT_PATH. Confirm this split is intended.
    home_dir = os.environ.get("COZY_HOME", default_home)
    config_path = os.path.join(home_dir, "config.yaml")

    merged = default_config.copy()

    config_pipeline_defs: Dict[str, Any] = {}
    if os.path.exists(config_path):
        try:
            with open(config_path, "r", encoding="utf-8") as f:
                yaml_config = yaml.safe_load(f) or {}
            # ``pipeline_defs:`` with no value loads as None. Normalise it here,
            # because merge_pipeline_defs runs outside this try and would crash.
            pipeline_defs = yaml_config.get("pipeline_defs") or {}
            if isinstance(pipeline_defs, dict):
                config_pipeline_defs = pipeline_defs
            else:
                logger.error(
                    f"Ignoring 'pipeline_defs' in {config_path}: expected a mapping, "
                    f"got {type(pipeline_defs).__name__}"
                )
        except Exception as e:
            logger.error(f"Error loading config from {config_path}: {e}")
            config_pipeline_defs = {}
    else:
        logger.warning(f"Config file {config_path} not found. Using default configuration.")

    # Load ALL pipeline definitions from the database (None = no name filter).
    logger.info("Loading all pipeline definitions from database...")
    db_pipeline_defs = load_pipeline_defs_from_db(enabled_models=None)
    logger.debug("db_pipeline_defs: %s", db_pipeline_defs)

    # Merge pipeline definitions from config and database.
    merged["pipeline_defs"] = merge_pipeline_defs(config_pipeline_defs, db_pipeline_defs)
    logger.debug("merged: %s", merged)

    return RuntimeConfig(**merged)
252
+
253
+
254
def set_config(config: RuntimeConfig):
    """
    Install ``config`` as the process-wide configuration returned by get_config().
    """
    module_namespace = globals()
    module_namespace["cozy_config"] = config
260
+
261
+
262
def set_environment(environment: str):
    """
    Store ``environment`` in the module-level ``ENVIRONMENT`` global that
    get_environment() reads.
    """
    module_namespace = globals()
    module_namespace["ENVIRONMENT"] = environment
268
+
269
+
270
def get_environment() -> str:
    """
    Return the environment name stored by set_environment().

    Raises:
        ValueError: If set_environment() has not been called yet, or stored None.
    """
    # ENVIRONMENT only comes into existence when set_environment() runs. Look it
    # up defensively so an early call raises the intended ValueError instead of
    # a NameError.
    environment = globals().get("ENVIRONMENT")
    if environment is None:
        raise ValueError("Environment has not been set yet")

    return environment
278
+
279
+
280
def get_config() -> RuntimeConfig:
    """
    Return the global configuration stored by set_config().

    The config is loaded at server start-up; asking for it before then is an error.

    Raises:
        ValueError: If no configuration has been set.
    """
    current = cozy_config
    if current is not None:
        return current
    raise ValueError("Config has not been loaded yet")
289
+
290
+
291
# Signature of a parse-args hook: (parser, argv, namespace) -> namespace or None.
# presumably mirrors ArgumentParser.parse_args(args, namespace) with the parser
# passed explicitly; confirm against the callers that use this alias.
ParseArgsMethod = Callable[
    [argparse.ArgumentParser, Optional[List[str]], Optional[argparse.Namespace]],
    Optional[argparse.Namespace],
]
295
+
296
+
297
def is_model_enabled(model_name: str) -> bool:
    """
    Report whether ``model_name`` has an entry in the global config's ``pipeline_defs``.

    Raises:
        ValueError: Via get_config() if no configuration has been loaded.
    """
    pipeline_defs = get_config().pipeline_defs
    return pipeline_defs is not None and model_name in pipeline_defs.keys()
306
+
307
+
308
def get_mock_config() -> RuntimeConfig:
    """
    Build a test configuration: environment "test", host 127.0.0.1, port 8881.

    Usable outside the cozy server environment. The result is not installed as
    the global config, and all other fields take RuntimeConfig's defaults.
    """
    overrides = {"host": "127.0.0.1", "port": 8881, "environment": "test"}
    return RuntimeConfig(**overrides)
@@ -0,0 +1,46 @@
1
+ import os
2
+ import logging
3
+ from typing import List, Optional
4
+ import psycopg2
5
+ from psycopg2.extras import RealDictCursor
6
+ from dotenv import load_dotenv
7
+
8
logger = logging.getLogger(__name__)

# Load environment variables from .env file.
# NOTE: this runs at import time, so merely importing this module reads .env
# (which may supply DB_DSN for get_db_connection()).
load_dotenv()

# Connection created lazily and cached at module level; managed by
# get_db_connection() and close_db_connection().
_db_connection = None
14
+
15
def get_db_connection():
    """
    Returns a connection to the database. Creates a new connection if none is
    cached yet, or if the cached one has been closed.
    Uses the DB_DSN environment variable for connection details.

    The connection is created with RealDictCursor, so rows come back as dicts
    keyed by column name.

    Raises:
        ValueError: If DB_DSN is not set.
        Exception: Whatever psycopg2.connect raises (logged, then re-raised).
    """
    global _db_connection

    # psycopg2 connections expose ``closed`` (0 while open). Reconnect instead
    # of handing out a connection that has been closed.
    if _db_connection is not None and not _db_connection.closed:
        return _db_connection

    db_dsn = os.getenv("DB_DSN")
    if not db_dsn:
        raise ValueError("DB_DSN environment variable not set")

    try:
        logger.info("Connecting to database...")
        _db_connection = psycopg2.connect(db_dsn, cursor_factory=RealDictCursor)
        logger.info("Database connection established")
    except Exception as e:
        logger.error(f"Error connecting to database: {e}")
        raise

    return _db_connection
36
+
37
def close_db_connection():
    """
    Closes the database connection if it exists.

    The cached connection is forgotten before ``close()`` is called. If
    ``close()`` raises, the exception still propagates, but the next
    get_db_connection() call opens a fresh connection instead of returning the
    broken one.
    """
    global _db_connection

    conn = _db_connection
    _db_connection = None  # forget it first so a failing close() cannot leave it cached

    if conn is not None:
        conn.close()
        logger.info("Database connection closed")
@@ -0,0 +1,26 @@
1
+ import logging
2
+ import torch
3
+
4
# Module logger; get_torch_device() uses it to warn when falling back to CPU.
logger = logging.getLogger(__name__)
5
+
6
+
7
def get_torch_device(index: int = 0) -> torch.device:
    """
    Pick the best available device, in priority order CUDA, MPS, XPU, then CPU.

    Args:
        index: Device index used for the accelerator backends (ignored for CPU).

    Returns:
        The selected ``torch.device``. Falling back to CPU logs a warning.
    """
    if torch.cuda.is_available():
        return torch.device("cuda", index)
    if torch.backends.mps.is_available():
        return torch.device("mps", index)
    # ``torch.xpu`` is absent from older PyTorch releases. Probe for it instead
    # of failing with AttributeError on those builds.
    xpu = getattr(torch, "xpu", None)
    if xpu is not None and xpu.is_available():
        return torch.device("xpu", index)

    logger.warning("No device found, using CPU. This will slow down performance.")
    return torch.device("cpu")
17
+
18
+
19
def get_torch_device_count() -> int:
    """
    Return the number of devices for the first available backend, checked in
    the order CUDA, MPS, XPU. Returns 1 when none is available (CPU).
    """
    if torch.cuda.is_available():
        return torch.cuda.device_count()
    if torch.backends.mps.is_available():
        # ``torch.mps.device_count`` is missing from older PyTorch releases;
        # assume a single MPS device in that case.
        mps_device_count = getattr(getattr(torch, "mps", None), "device_count", None)
        return mps_device_count() if mps_device_count is not None else 1
    # ``torch.xpu`` is absent from older PyTorch releases.
    xpu = getattr(torch, "xpu", None)
    if xpu is not None and xpu.is_available():
        return xpu.device_count()
    return 1
@@ -0,0 +1,10 @@
1
def fix_sdxl_compat():
    """
    Fix SDXL compatibility across diffusers versions.

    If the installed ``StableDiffusionXLPipeline`` has no
    ``do_classifier_free_guidance`` attribute, add it as a class attribute set
    to True. This is best-effort: if diffusers cannot be imported, or the patch
    fails, nothing happens.
    """
    try:
        from diffusers import StableDiffusionXLPipeline

        if not hasattr(StableDiffusionXLPipeline, "do_classifier_free_guidance"):
            StableDiffusionXLPipeline.do_classifier_free_guidance = True
    except Exception:
        # Deliberately best-effort. Catching Exception (not a bare ``except:``)
        # lets KeyboardInterrupt/SystemExit raised during the import propagate.
        pass
9
+
10
+ fix_sdxl_compat() # Auto-run on import