abstractvision 0.1.0__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- abstractvision/__init__.py +18 -3
- abstractvision/__main__.py +8 -0
- abstractvision/artifacts.py +320 -0
- abstractvision/assets/vision_model_capabilities.json +406 -0
- abstractvision/backends/__init__.py +43 -0
- abstractvision/backends/base_backend.py +63 -0
- abstractvision/backends/huggingface_diffusers.py +1503 -0
- abstractvision/backends/openai_compatible.py +325 -0
- abstractvision/backends/stable_diffusion_cpp.py +751 -0
- abstractvision/cli.py +778 -0
- abstractvision/errors.py +19 -0
- abstractvision/integrations/__init__.py +5 -0
- abstractvision/integrations/abstractcore.py +263 -0
- abstractvision/integrations/abstractcore_plugin.py +193 -0
- abstractvision/model_capabilities.py +255 -0
- abstractvision/types.py +95 -0
- abstractvision/vision_manager.py +115 -0
- abstractvision-0.2.1.dist-info/METADATA +243 -0
- abstractvision-0.2.1.dist-info/RECORD +23 -0
- {abstractvision-0.1.0.dist-info → abstractvision-0.2.1.dist-info}/WHEEL +1 -1
- abstractvision-0.2.1.dist-info/entry_points.txt +5 -0
- abstractvision-0.1.0.dist-info/METADATA +0 -65
- abstractvision-0.1.0.dist-info/RECORD +0 -6
- {abstractvision-0.1.0.dist-info → abstractvision-0.2.1.dist-info}/licenses/LICENSE +0 -0
- {abstractvision-0.1.0.dist-info → abstractvision-0.2.1.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,255 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import pkgutil
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from typing import Any, Dict, Iterable, List, Optional, Sequence, Union
|
|
7
|
+
|
|
8
|
+
from .errors import CapabilityNotSupportedError, UnknownModelError
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass(frozen=True)
class VisionTaskSpec:
    """One task entry of a model, as declared in the capabilities asset.

    The dataclass is frozen, so attributes cannot be reassigned once the
    instance has been created.
    """

    # Task name (the key of this entry inside the model's `tasks` object).
    task: str
    # Declared input names for the task.
    inputs: List[str]
    # Declared output names for the task.
    outputs: List[str]
    # Parameter specs keyed by parameter name.
    params: Dict[str, Any]
    # Optional prerequisites object; None when the task declares none.
    requires: Optional[Dict[str, Any]] = None
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
@dataclass(frozen=True)
class VisionModelSpec:
    """Parsed capability record for a single model.

    The dataclass is frozen, so attributes cannot be reassigned once the
    instance has been created.
    """

    # Identifier of the model (its key in the asset's `models` object).
    model_id: str
    # Provider name as declared for the model.
    provider: str
    # License name as declared for the model.
    license: str
    # Supported tasks, keyed by task name.
    tasks: Dict[str, VisionTaskSpec]
    # Free-form notes; empty string when none were given.
    notes: str = ""
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class VisionModelCapabilitiesRegistry:
    """Read-only lookup over the packaged model-capabilities JSON asset.

    The asset (`assets/vision_model_capabilities.json` inside the
    `abstractvision` package unless another path is given) is fetched,
    validated and parsed once by the constructor. Every query method then
    answers from the parsed, in-memory copy.
    """

    DEFAULT_ASSET_PATH = "assets/vision_model_capabilities.json"

    def __init__(self, *, asset_path: Optional[str] = None):
        """Create the registry and immediately load the asset.

        Args:
            asset_path: Path of the JSON asset relative to the
                `abstractvision` package. `DEFAULT_ASSET_PATH` is used when
                it is omitted or empty.
        """
        self._asset_path = asset_path if asset_path else self.DEFAULT_ASSET_PATH
        self._schema_version: str = ""
        self._tasks: Dict[str, Dict[str, Any]] = {}
        self._models: Dict[str, VisionModelSpec] = {}
        self._load()

    @classmethod
    def _build_task_spec(cls, task_name: Any, raw_task: Dict[str, Any]) -> VisionTaskSpec:
        """Convert one raw per-model task object into a `VisionTaskSpec`."""
        raw_inputs = list(raw_task.get("inputs", []))
        raw_outputs = list(raw_task.get("outputs", []))
        param_map = dict(raw_task.get("params", {}))
        prerequisites = raw_task.get("requires")
        if not isinstance(prerequisites, dict):
            # Anything that is not an object is treated as "no prerequisites".
            prerequisites = None
        return VisionTaskSpec(
            task=str(task_name),
            inputs=[str(item) for item in raw_inputs],
            outputs=[str(item) for item in raw_outputs],
            params=param_map,
            requires=prerequisites,
        )

    @classmethod
    def _build_model_spec(cls, model_id: Any, raw_model: Dict[str, Any]) -> VisionModelSpec:
        """Convert one raw model object into a `VisionModelSpec`.

        Raises:
            ValueError: If the model's `tasks` entry is not an object.
        """
        provider = str(raw_model.get("provider", "unknown"))
        license_name = str(raw_model.get("license", "unknown"))
        raw_notes = raw_model.get("notes")
        notes = "" if raw_notes is None else str(raw_notes)

        raw_tasks = raw_model.get("tasks", {})
        if not isinstance(raw_tasks, dict):
            raise ValueError(f"Invalid tasks for model {model_id}: must be an object keyed by task.")

        task_specs: Dict[str, VisionTaskSpec] = {
            name: cls._build_task_spec(name, body) for name, body in raw_tasks.items()
        }
        return VisionModelSpec(
            model_id=str(model_id),
            provider=provider,
            license=license_name,
            tasks=task_specs,
            notes=notes,
        )

    def _load(self) -> None:
        """Read, validate and parse the asset into the instance state.

        Raises:
            RuntimeError: If `pkgutil.get_data` returns no data for the asset.
            ValueError: If the document fails validation or `models` / a
                model's `tasks` is not an object.
        """
        payload = pkgutil.get_data("abstractvision", self._asset_path)
        if payload is None:
            raise RuntimeError(f"Capability asset not found: abstractvision/{self._asset_path}")
        document = json.loads(payload.decode("utf-8"))
        validate_capabilities_json(document)

        self._schema_version = str(document.get("schema_version") or "")
        top_level_tasks = document.get("tasks", {})
        self._tasks = top_level_tasks if isinstance(top_level_tasks, dict) else {}

        model_entries = document.get("models", {})
        if not isinstance(model_entries, dict):
            raise ValueError("Invalid capability asset: `models` must be an object keyed by model_id.")

        # Build the full mapping first so `_models` is only replaced when
        # every model parsed successfully.
        self._models = {
            str(model_id): self._build_model_spec(model_id, entry)
            for model_id, entry in model_entries.items()
        }

    def list_models(self) -> List[str]:
        """Return all known model ids in sorted order."""
        return sorted(self._models)

    def schema_version(self) -> str:
        """Return the asset's schema version as a string."""
        return self._schema_version

    def list_tasks(self) -> List[str]:
        """Return the sorted top-level task names, skipping blank or non-string keys."""
        names = [str(name) for name in self._tasks if isinstance(name, str) and name.strip()]
        names.sort()
        return names

    def get_task(self, task: str) -> Dict[str, Any]:
        """Return the top-level definition of `task`, or `{}` when unknown."""
        entry = self._tasks.get(str(task or ""))
        return entry if isinstance(entry, dict) else {}

    def get(self, model_id: str) -> VisionModelSpec:
        """Return the spec for `model_id`.

        Raises:
            UnknownModelError: If the id is not present in the registry.
        """
        try:
            spec = self._models[model_id]
        except KeyError as missing:
            raise UnknownModelError(f"Unknown vision model id: {model_id}") from missing
        return spec

    def supports(self, model_id: str, task: str) -> bool:
        """Return True when `model_id` is known and declares `task`."""
        try:
            model = self.get(model_id)
        except UnknownModelError:
            return False
        return task in model.tasks

    def require_support(self, model_id: str, task: str) -> None:
        """Raise `CapabilityNotSupportedError` unless `supports()` is true."""
        if self.supports(model_id, task):
            return
        raise CapabilityNotSupportedError(f"Model '{model_id}' does not support task '{task}'.")

    def models_for_task(self, task: str) -> List[str]:
        """Return the sorted ids of every model that declares `task`."""
        return sorted(mid for mid, spec in self._models.items() if task in spec.tasks)

    def iter_task_specs(self, model_id: str) -> Iterable[VisionTaskSpec]:
        """Return the task specs of `model_id` (raises like `get()` when unknown)."""
        model = self.get(model_id)
        return model.tasks.values()
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
_PathPart = Union[str, int]
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def _fmt_path(parts: Sequence[_PathPart]) -> str:
|
|
139
|
+
out: List[str] = []
|
|
140
|
+
for p in parts:
|
|
141
|
+
if isinstance(p, int):
|
|
142
|
+
out.append(f"[{p}]")
|
|
143
|
+
else:
|
|
144
|
+
if not out:
|
|
145
|
+
out.append(str(p))
|
|
146
|
+
else:
|
|
147
|
+
out.append(f"[{p!r}]")
|
|
148
|
+
return "".join(out) if out else "<root>"
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
def validate_capabilities_json(data: Any) -> None:
    """Validate the `vision_model_capabilities.json` schema (dependency-light).

    This is intentionally a "soft schema": it enforces required structure and
    internal reference integrity, while allowing additive fields.

    Checks, in order:
      * the document is an object whose `schema_version` is a string or a
        number (booleans are rejected);
      * `tasks` is an object keyed by non-empty strings; each value is an
        object whose optional `description` is a string;
      * `models` is an object keyed by non-empty strings; each model has a
        non-empty string `provider` and `license` and a `tasks` object whose
        keys all exist in the top-level `tasks`;
      * each model task has optional `inputs` / `outputs` lists of non-empty
        strings, a required `params` object whose entries are objects with a
        boolean `required`, and an optional `requires` object whose optional
        `base_model_id` must be a key of `models`.

    Args:
        data: The decoded JSON document.

    Raises:
        ValueError: On the first violation found.
    """
    if not isinstance(data, dict):
        raise ValueError("Invalid capability asset: top-level JSON must be an object.")

    schema_version = data.get("schema_version")
    if schema_version is None:
        raise ValueError("Invalid capability asset: missing required key 'schema_version'.")
    # `bool` is a subclass of `int`, so JSON `true`/`false` would otherwise
    # slip through the numeric check even though it is neither string nor number.
    if isinstance(schema_version, bool) or not isinstance(schema_version, (str, int, float)):
        raise ValueError("Invalid capability asset: 'schema_version' must be a string or number.")

    tasks = data.get("tasks")
    if not isinstance(tasks, dict):
        raise ValueError("Invalid capability asset: 'tasks' must be an object keyed by task name.")
    for task_name, task_spec in tasks.items():
        if not isinstance(task_name, str) or not task_name.strip():
            raise ValueError(f"Invalid capability asset: task key must be a non-empty string (got {task_name!r}).")
        if not isinstance(task_spec, dict):
            raise ValueError(f"Invalid capability asset: tasks[{task_name!r}] must be an object.")
        desc = task_spec.get("description")
        if desc is not None and not isinstance(desc, str):
            raise ValueError(f"Invalid capability asset: tasks[{task_name!r}]['description'] must be a string.")

    models = data.get("models")
    if not isinstance(models, dict):
        raise ValueError("Invalid capability asset: 'models' must be an object keyed by model_id.")

    def _err(path: Sequence[_PathPart], msg: str) -> None:
        # Always raises; the path is rendered into the message for context.
        raise ValueError(f"Invalid capability asset at {_fmt_path(path)}: {msg}")

    def _expect_dict(value: Any, path: Sequence[_PathPart]) -> Dict[str, Any]:
        if not isinstance(value, dict):
            _err(path, "expected object")
        return value

    def _expect_str(value: Any, path: Sequence[_PathPart]) -> str:
        if not isinstance(value, str) or not value.strip():
            _err(path, "expected non-empty string")
        return value

    def _expect_list_of_str(value: Any, path: Sequence[_PathPart]) -> List[str]:
        if not isinstance(value, list):
            _err(path, "expected list of strings")
        out: List[str] = []
        for i, item in enumerate(value):
            if not isinstance(item, str) or not item.strip():
                _err([*path, i], "expected non-empty string")
            out.append(item)
        return out

    for model_id, model_spec in models.items():
        if not isinstance(model_id, str) or not model_id.strip():
            _err(["models"], f"model key must be a non-empty string (got {model_id!r})")
        model_path: List[_PathPart] = ["models", model_id]
        m = _expect_dict(model_spec, model_path)

        provider = m.get("provider")
        if provider is not None:
            _expect_str(provider, [*model_path, "provider"])
        else:
            _err([*model_path, "provider"], "missing required key")

        license_name = m.get("license")
        if license_name is not None:
            _expect_str(license_name, [*model_path, "license"])
        else:
            _err([*model_path, "license"], "missing required key")

        tasks_raw = m.get("tasks")
        if tasks_raw is None:
            _err([*model_path, "tasks"], "missing required key")
        tmap = _expect_dict(tasks_raw, [*model_path, "tasks"])
        # Distinct names from the top-level task loop above, to avoid shadowing.
        for model_task_name, model_task_spec in tmap.items():
            _expect_str(model_task_name, [*model_path, "tasks", model_task_name])
            # Reference integrity: a model may only declare globally known tasks.
            if model_task_name not in tasks:
                _err([*model_path, "tasks", model_task_name], "unknown task (not present in top-level 'tasks')")

            tpath: List[_PathPart] = [*model_path, "tasks", model_task_name]
            t = _expect_dict(model_task_spec, tpath)
            # `inputs` / `outputs` are optional; a missing key counts as empty.
            _expect_list_of_str(t.get("inputs", []), [*tpath, "inputs"])
            _expect_list_of_str(t.get("outputs", []), [*tpath, "outputs"])

            params = t.get("params")
            if params is None:
                _err([*tpath, "params"], "missing required key")
            pmap = _expect_dict(params, [*tpath, "params"])
            for pname, pspec in pmap.items():
                _expect_str(pname, [*tpath, "params", pname])
                pobj = _expect_dict(pspec, [*tpath, "params", pname])
                required = pobj.get("required")
                if not isinstance(required, bool):
                    _err([*tpath, "params", pname, "required"], "expected boolean")

            requires = t.get("requires")
            if requires is not None:
                robj = _expect_dict(requires, [*tpath, "requires"])
                base_model_id = robj.get("base_model_id")
                if base_model_id is not None:
                    _expect_str(base_model_id, [*tpath, "requires", "base_model_id"])
                    # Reference integrity: the base model must itself be listed.
                    if base_model_id not in models:
                        _err([*tpath, "requires", "base_model_id"], f"unknown model id: {base_model_id!r}")
|
abstractvision/types.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import Any, Dict, Optional, Sequence
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass(frozen=True)
class VisionBackendCapabilities:
    """Optional, additive description of what a configured backend accepts.

    The model registry describes what a model *can* do; this record describes
    what a particular runtime/backend *will* do. Every field defaults to
    None, i.e. nothing is declared for that aspect.
    """

    # Task names the backend accepts; None leaves the task set undeclared.
    supported_tasks: Optional[Sequence[str]] = None
    # Whether masked image edits are accepted; None leaves it undeclared.
    supports_mask: Optional[bool] = None
    # Upper bounds on output dimensions (presumably pixels; unit not shown here).
    max_width: Optional[int] = None
    max_height: Optional[int] = None
    # Upper bounds for video output (frame rate and frame count, by name).
    max_fps: Optional[int] = None
    max_frames: Optional[int] = None
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
@dataclass(frozen=True)
class ImageGenerationRequest:
    """Immutable parameter bundle for a text-to-image call.

    Only `prompt` is required; every other field defaults to None (or an
    empty dict for `extra`). How a backend interprets unset fields is not
    defined here.
    """

    # Text describing the desired image.
    prompt: str
    # Optional text describing what to avoid.
    negative_prompt: Optional[str] = None
    # Requested output size (presumably pixels; unit not shown here).
    width: Optional[int] = None
    height: Optional[int] = None
    # Optional generation controls, passed through to the backend.
    seed: Optional[int] = None
    steps: Optional[int] = None
    guidance_scale: Optional[float] = None
    # Free-form additional options; a fresh dict is created per instance.
    extra: Dict[str, Any] = field(default_factory=dict)
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@dataclass(frozen=True)
class ImageEditRequest:
    """Immutable parameter bundle for an image-to-image (edit) call.

    `prompt` and `image` are required; every other field defaults to None
    (or an empty dict for `extra`).
    """

    # Text describing the desired edit.
    prompt: str
    # Source image as raw bytes (encoding is not specified here).
    image: bytes
    # Optional mask as raw bytes; None means no mask was supplied.
    mask: Optional[bytes] = None
    # Optional text describing what to avoid.
    negative_prompt: Optional[str] = None
    # Optional generation controls, passed through to the backend.
    seed: Optional[int] = None
    steps: Optional[int] = None
    guidance_scale: Optional[float] = None
    # Free-form additional options; a fresh dict is created per instance.
    extra: Dict[str, Any] = field(default_factory=dict)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@dataclass(frozen=True)
class MultiAngleRequest:
    """Immutable parameter bundle for a multi-view image call.

    Only `prompt` is required. `angles` defaults to the four views
    front / three_quarter / side / back.
    """

    # Text describing the subject to render.
    prompt: str
    # Optional reference image as raw bytes (encoding is not specified here).
    reference_image: Optional[bytes] = None
    # Names of the views to produce; the default is an immutable tuple.
    angles: Sequence[str] = ("front", "three_quarter", "side", "back")
    # Optional text describing what to avoid.
    negative_prompt: Optional[str] = None
    # Optional generation controls, passed through to the backend.
    seed: Optional[int] = None
    steps: Optional[int] = None
    guidance_scale: Optional[float] = None
    # Free-form additional options; a fresh dict is created per instance.
    extra: Dict[str, Any] = field(default_factory=dict)
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
@dataclass(frozen=True)
class VideoGenerationRequest:
    """Immutable parameter bundle for a text-to-video call.

    Only `prompt` is required; every other field defaults to None (or an
    empty dict for `extra`).
    """

    # Text describing the desired video.
    prompt: str
    # Optional text describing what to avoid.
    negative_prompt: Optional[str] = None
    # Requested frame size (presumably pixels; unit not shown here).
    width: Optional[int] = None
    height: Optional[int] = None
    # Requested frame rate and number of frames.
    fps: Optional[int] = None
    num_frames: Optional[int] = None
    # Optional generation controls, passed through to the backend.
    seed: Optional[int] = None
    steps: Optional[int] = None
    guidance_scale: Optional[float] = None
    # Free-form additional options; a fresh dict is created per instance.
    extra: Dict[str, Any] = field(default_factory=dict)
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
@dataclass(frozen=True)
class ImageToVideoRequest:
    """Immutable parameter bundle for an image-to-video call.

    Only `image` is required; unlike the text-driven requests, `prompt` is
    optional here. Every other field defaults to None (or an empty dict for
    `extra`).
    """

    # Source image as raw bytes (encoding is not specified here).
    image: bytes
    # Optional guiding text and text describing what to avoid.
    prompt: Optional[str] = None
    negative_prompt: Optional[str] = None
    # Requested frame size (presumably pixels; unit not shown here).
    width: Optional[int] = None
    height: Optional[int] = None
    # Requested frame rate and number of frames.
    fps: Optional[int] = None
    num_frames: Optional[int] = None
    # Optional generation controls, passed through to the backend.
    seed: Optional[int] = None
    steps: Optional[int] = None
    guidance_scale: Optional[float] = None
    # Free-form additional options; a fresh dict is created per instance.
    extra: Dict[str, Any] = field(default_factory=dict)
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
@dataclass(frozen=True)
class GeneratedAsset:
    """Immutable container for one piece of generated media.

    Carries the raw bytes together with its media kind, MIME type and an
    optional metadata mapping.
    """

    # Kind of media held in `data`.
    media_type: str  # "image" | "video"
    # Raw encoded media bytes.
    data: bytes
    # MIME type describing `data`.
    mime_type: str
    # Arbitrary extra information; a fresh dict is created per instance.
    metadata: Dict[str, Any] = field(default_factory=dict)
|
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from dataclasses import dataclass
|
|
4
|
+
from typing import Any, Dict, List, Optional, Union
|
|
5
|
+
|
|
6
|
+
from .backends.base_backend import VisionBackend
|
|
7
|
+
from .artifacts import MediaStore
|
|
8
|
+
from .errors import BackendNotConfiguredError, CapabilityNotSupportedError
|
|
9
|
+
from .model_capabilities import VisionModelCapabilitiesRegistry
|
|
10
|
+
from .types import (
|
|
11
|
+
GeneratedAsset,
|
|
12
|
+
ImageEditRequest,
|
|
13
|
+
ImageGenerationRequest,
|
|
14
|
+
ImageToVideoRequest,
|
|
15
|
+
MultiAngleRequest,
|
|
16
|
+
VideoGenerationRequest,
|
|
17
|
+
VisionBackendCapabilities,
|
|
18
|
+
)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@dataclass
class VisionManager:
    """Thin entry point that routes generative vision calls to a backend.

    Each public method runs the same sequence: make sure a backend exists,
    check the configured model (if any) against the capability registry,
    check the backend's own declared capabilities (if it reports any), call
    the backend, and finally hand the result to the media store when one is
    configured.

    Return values are the backend's `GeneratedAsset` objects when no store
    is set, otherwise whatever `store.store_bytes(...)` returns.
    """

    backend: Optional[VisionBackend] = None
    store: Optional[MediaStore] = None
    model_id: Optional[str] = None
    registry: Optional[VisionModelCapabilitiesRegistry] = None

    def __post_init__(self) -> None:
        """Create a default registry when a model id was given without one."""
        if self.registry is None and self.model_id:
            self.registry = VisionModelCapabilitiesRegistry()

    def _require_backend(self) -> VisionBackend:
        """Return the configured backend or raise `BackendNotConfiguredError`."""
        configured = self.backend
        if configured is not None:
            return configured
        raise BackendNotConfiguredError(
            "No vision backend configured. "
            "Provide a backend to VisionManager(backend=...) before calling generation methods."
        )

    def _require_model_support(self, task: str) -> None:
        """Check `task` against the registry; a no-op when no model id is set."""
        if not self.model_id:
            return
        # Cache the registry on the instance so repeated calls don't reload the asset.
        self.registry = self.registry or VisionModelCapabilitiesRegistry()
        self.registry.require_support(str(self.model_id), str(task))

    def _backend_caps(self, backend: VisionBackend) -> Optional[VisionBackendCapabilities]:
        """Return the backend's capabilities, or None if querying them fails."""
        try:
            reported = backend.get_capabilities()
        except Exception:
            # Best effort: any failure here is treated as "nothing declared".
            return None
        return reported

    def _require_backend_support(self, backend: VisionBackend, task: str) -> Optional[VisionBackendCapabilities]:
        """Raise if the backend declares a task list that excludes `task`.

        Returns the capabilities object (or None when none is available) so
        callers can run further checks.
        """
        caps = self._backend_caps(backend)
        if caps is None:
            return None
        if caps.supported_tasks is not None:
            allowed = {str(name) for name in caps.supported_tasks}
            if str(task) not in allowed:
                raise CapabilityNotSupportedError(f"Backend does not support task '{task}'.")
        return caps

    def _authorize(self, backend: VisionBackend, task: str) -> Optional[VisionBackendCapabilities]:
        """Run the model check followed by the backend check for `task`."""
        self._require_model_support(task)
        return self._require_backend_support(backend, task)

    @staticmethod
    def _media_tags(modality: str, task: str) -> Dict[str, str]:
        """Build the tag mapping attached to stored media."""
        return {"kind": "generated_media", "modality": modality, "task": task}

    def _maybe_store(self, asset: GeneratedAsset, *, tags: Optional[Dict[str, str]] = None) -> Union[GeneratedAsset, Dict[str, Any]]:
        """Persist `asset` via the store when one is set; otherwise return it as is."""
        target = self.store
        if target is None:
            return asset
        return target.store_bytes(
            asset.data,
            content_type=asset.mime_type,
            metadata=asset.metadata,
            tags=tags,
        )

    def generate_image(self, prompt: str, **kwargs) -> Union[GeneratedAsset, Dict[str, Any]]:
        """Run text-to-image; `kwargs` become `ImageGenerationRequest` fields."""
        backend = self._require_backend()
        self._authorize(backend, "text_to_image")
        request = ImageGenerationRequest(prompt=prompt, **kwargs)
        produced = backend.generate_image(request)
        return self._maybe_store(produced, tags=self._media_tags("image", "text_to_image"))

    def edit_image(self, prompt: str, image: bytes, **kwargs) -> Union[GeneratedAsset, Dict[str, Any]]:
        """Run image-to-image; `kwargs` become `ImageEditRequest` fields.

        Raises `CapabilityNotSupportedError` when a mask is passed and the
        backend explicitly reports `supports_mask` as False.
        """
        backend = self._require_backend()
        caps = self._authorize(backend, "image_to_image")
        mask_rejected = caps is not None and caps.supports_mask is False
        if mask_rejected and kwargs.get("mask") is not None:
            raise CapabilityNotSupportedError("Backend does not support masked image edits (mask parameter).")
        request = ImageEditRequest(prompt=prompt, image=image, **kwargs)
        produced = backend.edit_image(request)
        return self._maybe_store(produced, tags=self._media_tags("image", "image_to_image"))

    def generate_angles(self, prompt: str, **kwargs) -> Union[List[GeneratedAsset], List[Dict[str, Any]]]:
        """Run multi-view generation; `kwargs` become `MultiAngleRequest` fields."""
        backend = self._require_backend()
        self._authorize(backend, "multi_view_image")
        request = MultiAngleRequest(prompt=prompt, **kwargs)
        produced = backend.generate_angles(request)
        if self.store is None:
            return produced
        return [self._maybe_store(item, tags=self._media_tags("image", "multi_view_image")) for item in produced]  # type: ignore[return-value]

    def generate_video(self, prompt: str, **kwargs) -> Union[GeneratedAsset, Dict[str, Any]]:
        """Run text-to-video; `kwargs` become `VideoGenerationRequest` fields."""
        backend = self._require_backend()
        self._authorize(backend, "text_to_video")
        request = VideoGenerationRequest(prompt=prompt, **kwargs)
        produced = backend.generate_video(request)
        return self._maybe_store(produced, tags=self._media_tags("video", "text_to_video"))

    def image_to_video(self, image: bytes, **kwargs) -> Union[GeneratedAsset, Dict[str, Any]]:
        """Run image-to-video; `kwargs` become `ImageToVideoRequest` fields."""
        backend = self._require_backend()
        self._authorize(backend, "image_to_video")
        request = ImageToVideoRequest(image=image, **kwargs)
        produced = backend.image_to_video(request)
        return self._maybe_store(produced, tags=self._media_tags("video", "image_to_video"))
|