drawthings-py 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- drawthings_py/__init__.py +19 -0
- drawthings_py/_dt_service.py +39 -0
- drawthings_py/_errors.py +21 -0
- drawthings_py/_metadata.py +338 -0
- drawthings_py/_png_writer.py +262 -0
- drawthings_py/_preview_decoders.py +657 -0
- drawthings_py/_util.py +102 -0
- drawthings_py/cli_service.py +25 -0
- drawthings_py/configs/__init__.py +31 -0
- drawthings_py/configs/config_dict.py +287 -0
- drawthings_py/configs/config_prop.py +576 -0
- drawthings_py/configs/configs.py +110 -0
- drawthings_py/configs/gen_config.py +397 -0
- drawthings_py/configs/presets.py +58 -0
- drawthings_py/configs/types.py +261 -0
- drawthings_py/drawthings.py +34 -0
- drawthings_py/filename_pattern.py +136 -0
- drawthings_py/generated/dt_grpc/__init__.py +0 -0
- drawthings_py/generated/dt_grpc/config_generated.py +1763 -0
- drawthings_py/generated/dt_grpc/config_generated.pyi +527 -0
- drawthings_py/generated/dt_grpc/image_service/__init__.py +502 -0
- drawthings_py/grpc_service.py +303 -0
- drawthings_py/image_buffer.py +471 -0
- drawthings_py/py.typed +0 -0
- drawthings_py/request_builder.py +423 -0
- drawthings_py/resources/config_props.yaml +926 -0
- drawthings_py/resources/configs/anima_preview_3.json +36 -0
- drawthings_py/resources/configs/chroma_hd.json +36 -0
- drawthings_py/resources/configs/default.json +34 -0
- drawthings_py/resources/configs/ernie_image_base.json +35 -0
- drawthings_py/resources/configs/ernie_image_turbo.json +36 -0
- drawthings_py/resources/configs/flux_1_dev.json +38 -0
- drawthings_py/resources/configs/flux_1_fill_dev.json +34 -0
- drawthings_py/resources/configs/flux_1_schnell.json +34 -0
- drawthings_py/resources/configs/flux_2_dev_with_turbo.json +43 -0
- drawthings_py/resources/configs/flux_2_klein_4b.json +37 -0
- drawthings_py/resources/configs/flux_2_klein_4b_base.json +35 -0
- drawthings_py/resources/configs/flux_2_klein_9b.json +37 -0
- drawthings_py/resources/configs/flux_2_klein_9b_base.json +35 -0
- drawthings_py/resources/configs/flux_2_klein_9b_kv.json +36 -0
- drawthings_py/resources/configs/qwen_image_2512.json +35 -0
- drawthings_py/resources/configs/qwen_image_2512_lightning.json +41 -0
- drawthings_py/resources/configs/qwen_image_edit_2511.json +34 -0
- drawthings_py/resources/configs/qwen_image_edit_2511_lightning.json +41 -0
- drawthings_py/resources/configs/sdxl.json +40 -0
- drawthings_py/resources/configs/stable_diffusion.json +29 -0
- drawthings_py/resources/configs/z_image_base.json +34 -0
- drawthings_py/resources/configs/z_image_turbo.json +36 -0
- drawthings_py/resources/root_ca.crt +30 -0
- drawthings_py/seed_provider.py +23 -0
- drawthings_py-0.1.0.dist-info/METADATA +240 -0
- drawthings_py-0.1.0.dist-info/RECORD +54 -0
- drawthings_py-0.1.0.dist-info/WHEEL +4 -0
- drawthings_py-0.1.0.dist-info/licenses/LICENSE +675 -0
|
@@ -0,0 +1,19 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Draw Things Python SDK
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import drawthings_py.drawthings as DrawThings
|
|
6
|
+
from .configs import Configs, ConfigDict, Presets
|
|
7
|
+
from .filename_pattern import FilenamePattern
|
|
8
|
+
from .image_buffer import ImageBuffer
|
|
9
|
+
from .request_builder import RequestBuilder
|
|
10
|
+
|
|
11
|
+
# Public names re-exported by ``from drawthings_py import *``. "DrawThings" is
# the ``drawthings_py.drawthings`` module, imported above under that alias.
__all__ = [
    "DrawThings",
    "RequestBuilder",
    "ImageBuffer",
    "ConfigDict",
    "Configs",
    "Presets",
    "FilenamePattern",
]
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Primary entry point for using Draw Things services
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
from __future__ import annotations
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
from types import TracebackType
|
|
8
|
+
|
|
9
|
+
from .image_buffer import ImageBuffer
|
|
10
|
+
from .request_builder import RequestBuilder
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class DrawThingsService(ABC):
    """
    Abstract base shared by the gRPC and CLI Draw Things services.

    Concrete services implement ``generate_image`` and ``_dispose``. Instances
    work as ``async with`` context managers: entering yields the service itself
    and leaving always calls ``_dispose``.
    """

    async def __aenter__(self) -> "DrawThingsService":
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        # Returning None means any exception raised in the ``async with`` body
        # keeps propagating after cleanup.
        self._dispose()

    @abstractmethod
    async def generate_image(self, request: RequestBuilder) -> list[ImageBuffer]:
        """
        Run the generation described by ``request`` and return its images.
        """

    @abstractmethod
    def _dispose(self):
        """
        Release whatever resources the concrete service holds.
        """
|
drawthings_py/_errors.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
from grpclib import GRPCError, Status
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class DrawThingsServerError(Exception):
    """Raised when the Draw Things gRPC server reports an INTERNAL error."""
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class DrawThingsUnavailableError(Exception):
    """Raised when the gRPC server answers with an UNAVAILABLE status."""
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def raise_grpc_error(e: GRPCError):
    """
    Translate a ``GRPCError`` into the matching Draw Things exception and raise it.

    ``Status.INTERNAL`` becomes ``DrawThingsServerError`` and
    ``Status.UNAVAILABLE`` becomes ``DrawThingsUnavailableError``; both are
    chained to ``e``. Every other status re-raises ``e`` itself, so this
    function never returns normally.

    Raises:
        DrawThingsServerError: for ``Status.INTERNAL``.
        DrawThingsUnavailableError: for ``Status.UNAVAILABLE``.
        GRPCError: ``e`` unchanged for any other status.
    """
    if e.status == Status.INTERNAL:
        raise DrawThingsServerError(
            f"There was an error on the server. This could be a temporary issue with DT+ or a problem with your request. ({e.status.name}: {e.message})"
        ) from e

    if e.status == Status.UNAVAILABLE:
        raise DrawThingsUnavailableError(
            "The gRPC server is currently unavailable."
        ) from e

    # Unmapped statuses (e.g. UNAUTHENTICATED, INVALID_ARGUMENT) must still
    # surface; falling through would return None and swallow the error.
    raise e
|
|
@@ -0,0 +1,338 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import copy
|
|
4
|
+
import sys
|
|
5
|
+
from collections.abc import Callable
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Any, Required, TypedDict, cast
|
|
8
|
+
|
|
9
|
+
from .generated.dt_grpc.config_generated import GenerationConfiguration
|
|
10
|
+
|
|
11
|
+
# Lookup tables from the integer enum values stored in GenerationConfiguration
# to the names written into the metadata. Each table is keyed by consecutive
# integers starting at 0, so it is built from an ordered tuple of names.

SAMPLER_NAMES: dict[Any, str] = dict(
    enumerate(
        (
            "DPM++ 2M Karras",  # 0
            "Euler A",  # 1
            "DDIM",  # 2
            "PLMS",  # 3
            "DPM++ SDE Karras",  # 4
            "UniPC",  # 5
            "LCM",  # 6
            "Euler A Substep",  # 7
            "DPM++ SDE Substep",  # 8
            "TCD",  # 9
            "Euler A Trailing",  # 10
            "DPM++ SDE Trailing",  # 11
            "DPM++ 2M AYS",  # 12
            "Euler A AYS",  # 13
            "DPM++ SDE AYS",  # 14
            "DPM++ 2M Trailing",  # 15
            "DDIM Trailing",  # 16
            "UniPC Trailing",  # 17
            "UniPC AYS",  # 18
            "TCD Trailing",  # 19
        )
    )
)

SEED_MODE_NAMES = dict(
    enumerate(
        (
            "Legacy",  # 0
            "Torch CPU Compatible",  # 1
            "Scale Alike",  # 2
            "Nvidia GPU Compatible",  # 3
        )
    )
)

COMPRESSION_ARTIFACT_NAMES = dict(enumerate(("disabled", "h264", "h265", "jpeg")))

LORA_MODE_NAMES = dict(enumerate(("all", "base", "refiner")))

CONTROL_MODE_NAMES = dict(enumerate(("balanced", "prompt", "control")))

CONTROL_INPUT_TYPE_NAMES = dict(
    enumerate(
        (
            "unspecified",  # 0
            "custom",  # 1
            "depth",  # 2
            "canny",  # 3
            "scribble",  # 4
            "pose",  # 5
            "normalbae",  # 6
            "color",  # 7
            "lineart",  # 8
            "softedge",  # 9
            "seg",  # 10
            "inpaint",  # 11
            "ip2p",  # 12
            "shuffle",  # 13
            "mlsd",  # 14
            "tile",  # 15
            "blur",  # 16
            "lowquality",  # 17
            "gray",  # 18
        )
    )
)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
class V2(TypedDict, total=False):
    """
    Shape of the ``v2`` section of Draw Things image metadata.

    Declares every key that ``_create_v2_metadata`` writes. All keys are
    optional because that function drops entries whose value is ``None`` or "".
    """

    # NOTE(review): for keys not produced via _pixels/_metadata_float or passed
    # in as typed parameters, the value types are inferred from the
    # GenerationConfiguration accessor names; confirm against
    # config_generated.pyi.
    aestheticScore: float
    batchCount: int
    batchSize: int
    causalInference: bool
    causalInferencePad: int
    cfgZeroInitSteps: int
    cfgZeroStar: float
    clipLText: float  # NOTE(review): accessor name suggests text; confirm type
    clipSkip: int
    clipWeight: float
    compressionArtifacts: str
    compressionArtifactsQuality: int
    controls: list[dict[str, Any]]
    cropLeft: int
    cropTop: int
    decodingTileHeight: int
    decodingTileOverlap: int
    decodingTileWidth: int
    diffusionTileHeight: int
    diffusionTileOverlap: int
    diffusionTileWidth: int
    fps: int
    guidanceEmbed: bool  # NOTE(review): may be a numeric embed value; confirm
    guidanceScale: float
    guidingFrameNoise: float
    height: int
    hiresFix: bool
    hiresFixHeight: int
    hiresFixStrength: float
    hiresFixWidth: int
    id: int
    imageGuidanceScale: float
    imagePriorSteps: int
    loras: list[dict[str, Any]]
    maskBlur: float
    maskBlurOutset: int
    model: str
    motionScale: int
    negativeAestheticScore: float
    negativeOriginalImageHeight: int
    negativeOriginalImageWidth: int
    negativePromptForImagePrior: bool
    numFrames: int
    originalImageHeight: int
    originalImageWidth: int
    preserveOriginalAfterInpaint: bool
    refinerStart: float
    resolutionDependentShift: bool
    sampler: int
    seed: int
    seedMode: int
    separateClipL: bool
    separateOpenClipG: bool
    separateT5: bool
    sharpness: float
    shift: float
    speedUpWithGuidanceEmbed: bool
    stage2Guidance: float
    stage2Shift: float
    stage2Steps: int
    startFrameGuidance: float
    steps: int
    stochasticSamplingGamma: float
    strength: float
    t5TextEncoder: bool
    targetImageHeight: int
    targetImageWidth: int
    teaCache: bool
    teaCacheEnd: int
    teaCacheMaxSkipSteps: int
    teaCacheStart: int
    teaCacheThreshold: float
    tiledDecoding: bool
    tiledDiffusion: bool
    upscalerScaleFactor: int
    width: int
    zeroNegativePrompt: bool
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
class ImageMetadata(TypedDict, total=False):
    """
    Top-level Draw Things metadata dictionary returned by ``create_metadata``.

    ``c`` and ``uc`` may be omitted; every other key is marked ``Required``.
    """

    c: str  # prompt text
    uc: str  # negative prompt text
    steps: Required[int]
    sampler: Required[str]  # SAMPLER_NAMES entry, or the raw value as a string
    scale: Required[float]  # guidance scale
    seed: Required[int]
    size: Required[str]  # "<width>x<height>" in pixels
    model: Required[str]
    strength: Required[float]
    seed_mode: Required[str]  # SEED_MODE_NAMES entry, or the raw value as a string
    shift: Required[float]
    v2: Required[V2]  # extended parameter set
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def create_metadata(
    config: GenerationConfiguration, prompt: str, negative_prompt: str
) -> ImageMetadata:
    """
    Build the Draw Things PNG metadata dictionary for one generation.

    Args:
        config: generation configuration to describe.
        prompt: positive prompt; a falsy value is stored as "".
        negative_prompt: negative prompt; a falsy value is stored as "".

    Returns:
        The top-level metadata. Its ``v2`` entry holds the extended parameters,
        and sampler and seed mode are stored as their display names.
    """
    pixel_width = _pixels(config.StartWidth())
    pixel_height = _pixels(config.StartHeight())
    model_name = _decode_string(config.Model()) or ""
    sampler_id = config.Sampler()
    seed_mode_id = config.SeedMode()

    extended = _create_v2_metadata(
        config, pixel_width, pixel_height, model_name, sampler_id, seed_mode_id
    )

    sampler_label = SAMPLER_NAMES.get(sampler_id, str(sampler_id))
    seed_mode_label = SEED_MODE_NAMES.get(seed_mode_id, str(seed_mode_id))
    return ImageMetadata(
        c=prompt or "",
        uc=negative_prompt or "",
        model=model_name,
        sampler=sampler_label,
        scale=config.GuidanceScale(),
        seed=config.Seed(),
        seed_mode=seed_mode_label,
        shift=config.Shift(),
        size=f"{pixel_width}x{pixel_height}",
        steps=config.Steps(),
        strength=config.Strength(),
        v2=extended,
    )
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def _create_v2_metadata(
    cfg: GenerationConfiguration,
    width: int,
    height: int,
    model: str,
    sampler: int,
    seed_mode: int,
) -> V2:
    """
    Build the ``v2`` metadata section from ``cfg``.

    ``width``/``height`` arrive already converted to pixels; ``model``,
    ``sampler`` and ``seed_mode`` are supplied by the caller rather than
    re-read from ``cfg``. Tile sizes and image dimensions read from ``cfg``
    are converted to pixels with ``_pixels``.

    Keys are inserted in alphabetical (ASCII) order, and the filter below keeps
    that order. Entries whose value is ``None`` or "" are dropped.
    """
    v2: dict[str, Any] = {
        "aestheticScore": cfg.AestheticScore(),
        "batchCount": cfg.BatchCount(),
        "batchSize": cfg.BatchSize(),
        "causalInference": cfg.CausalInference(),
        "causalInferencePad": cfg.CausalInferencePad(),
        "cfgZeroInitSteps": cfg.CfgZeroInitSteps(),
        "cfgZeroStar": cfg.CfgZeroStar(),
        "clipLText": cfg.ClipLText(),
        "clipSkip": cfg.ClipSkip(),
        "clipWeight": cfg.ClipWeight(),
        # No fallback here (unlike the other name lookups): an unknown value
        # maps to None and is dropped by the filter below.
        "compressionArtifacts": COMPRESSION_ARTIFACT_NAMES.get(
            cfg.CompressionArtifacts(),
        ),
        "compressionArtifactsQuality": cfg.CompressionArtifactsQuality(),
        "controls": _controls(cfg),
        "cropLeft": cfg.CropLeft(),
        "cropTop": cfg.CropTop(),
        "decodingTileHeight": _pixels(cfg.DecodingTileHeight()),
        "decodingTileOverlap": _pixels(cfg.DecodingTileOverlap()),
        "decodingTileWidth": _pixels(cfg.DecodingTileWidth()),
        "diffusionTileHeight": _pixels(cfg.DiffusionTileHeight()),
        "diffusionTileOverlap": _pixels(cfg.DiffusionTileOverlap()),
        "diffusionTileWidth": _pixels(cfg.DiffusionTileWidth()),
        # Several metadata keys differ from their accessor names:
        # fps <- FpsId, guidingFrameNoise <- CondAug, motionScale <-
        # MotionBucketId, stage2Guidance <- Stage2Cfg,
        # startFrameGuidance <- StartFrameCfg.
        "fps": cfg.FpsId(),
        "guidanceEmbed": cfg.GuidanceEmbed(),
        "guidanceScale": cfg.GuidanceScale(),
        "guidingFrameNoise": cfg.CondAug(),
        "height": height,
        "hiresFix": cfg.HiresFix(),
        "hiresFixHeight": _pixels(cfg.HiresFixStartHeight()),
        "hiresFixStrength": cfg.HiresFixStrength(),
        "hiresFixWidth": _pixels(cfg.HiresFixStartWidth()),
        "id": cfg.Id(),
        "imageGuidanceScale": cfg.ImageGuidanceScale(),
        "imagePriorSteps": cfg.ImagePriorSteps(),
        "loras": _loras(cfg),
        "maskBlur": cfg.MaskBlur(),
        "maskBlurOutset": cfg.MaskBlurOutset(),
        "model": model,
        "motionScale": cfg.MotionBucketId(),
        "negativeAestheticScore": cfg.NegativeAestheticScore(),
        "negativeOriginalImageHeight": _pixels(cfg.NegativeOriginalImageHeight()),
        "negativeOriginalImageWidth": _pixels(cfg.NegativeOriginalImageWidth()),
        "negativePromptForImagePrior": cfg.NegativePromptForImagePrior(),
        "numFrames": cfg.NumFrames(),
        "originalImageHeight": _pixels(cfg.OriginalImageHeight()),
        "originalImageWidth": _pixels(cfg.OriginalImageWidth()),
        "preserveOriginalAfterInpaint": cfg.PreserveOriginalAfterInpaint(),
        "refinerStart": cfg.RefinerStart(),
        "resolutionDependentShift": cfg.ResolutionDependentShift(),
        # sampler/seedMode are stored as raw integers here; the top-level
        # metadata carries the display names.
        "sampler": sampler,
        "seed": cfg.Seed(),
        "seedMode": seed_mode,
        "separateClipL": cfg.SeparateClipL(),
        "separateOpenClipG": cfg.SeparateOpenClipG(),
        "separateT5": cfg.SeparateT5(),
        "sharpness": cfg.Sharpness(),
        # Rounded to 7 decimals by _metadata_float (e.g. 1.600000023841858 -> 1.6).
        "shift": _metadata_float(cfg.Shift()),
        "speedUpWithGuidanceEmbed": cfg.SpeedUpWithGuidanceEmbed(),
        "stage2Guidance": cfg.Stage2Cfg(),
        "stage2Shift": cfg.Stage2Shift(),
        "stage2Steps": cfg.Stage2Steps(),
        "startFrameGuidance": cfg.StartFrameCfg(),
        "steps": cfg.Steps(),
        "stochasticSamplingGamma": _metadata_float(cfg.StochasticSamplingGamma()),
        "strength": cfg.Strength(),
        "t5TextEncoder": cfg.T5TextEncoder(),
        "targetImageHeight": _pixels(cfg.TargetImageHeight()),
        "targetImageWidth": _pixels(cfg.TargetImageWidth()),
        "teaCache": cfg.TeaCache(),
        "teaCacheEnd": cfg.TeaCacheEnd(),
        "teaCacheMaxSkipSteps": cfg.TeaCacheMaxSkipSteps(),
        "teaCacheStart": cfg.TeaCacheStart(),
        "teaCacheThreshold": cfg.TeaCacheThreshold(),
        "tiledDecoding": cfg.TiledDecoding(),
        "tiledDiffusion": cfg.TiledDiffusion(),
        "upscalerScaleFactor": cfg.UpscalerScaleFactor(),
        "width": width,
        "zeroNegativePrompt": cfg.ZeroNegativePrompt(),
    }
    # Drop unset entries (None or empty string); falsy values such as 0 or
    # False are kept.
    v2 = {k: v for k, v in v2.items() if v is not None and v != ""}
    return cast(V2, v2)  # pyright: ignore[reportInvalidCast]
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def _decode_string(value: bytes | str | None) -> str | None:
|
|
263
|
+
if value is None:
|
|
264
|
+
return None
|
|
265
|
+
if isinstance(value, bytes):
|
|
266
|
+
return value.decode("utf-8")
|
|
267
|
+
return value
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
def _pixels(value: int | None) -> int:
|
|
271
|
+
return (value or 0) * 64
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
def _metadata_float(value: float) -> float:
|
|
275
|
+
return round(float(value), 7)
|
|
276
|
+
|
|
277
|
+
|
|
278
|
+
def _loras(cfg: GenerationConfiguration) -> list[dict[str, Any]]:
    """Describe each LoRA in ``cfg`` as a ``{"mode", "file", "weight"}`` metadata dict."""
    entries: list[dict[str, Any]] = []
    for entry in _nested_items(cfg, "Loras", "LorasLength"):
        mode = entry.Mode()
        entries.append(
            {
                # Unknown modes fall back to the raw value as a string.
                "mode": LORA_MODE_NAMES.get(mode, str(mode)),
                "file": _decode_string(entry.File()) or "",
                "weight": entry.Weight(),
            }
        )
    return entries
|
|
287
|
+
|
|
288
|
+
|
|
289
|
+
def _describe_control(control: Any) -> dict[str, Any]:
    """Translate a single control entry from the configuration into its metadata dict."""
    mode = control.ControlMode()
    input_type = control.InputOverride()
    return {
        "file": _decode_string(control.File()) or "",
        "weight": control.Weight(),
        "guidanceStart": control.GuidanceStart(),
        "guidanceEnd": control.GuidanceEnd(),
        "noPrompt": control.NoPrompt(),
        "globalAveragePooling": control.GlobalAveragePooling(),
        "downSamplingRate": control.DownSamplingRate(),
        # Enum values map to names; unknown values fall back to str(value).
        "controlMode": CONTROL_MODE_NAMES.get(mode, str(mode)),
        "targetBlocks": [
            _decode_string(control.TargetBlocks(position)) or ""
            for position in range(control.TargetBlocksLength())
        ],
        "inputOverride": CONTROL_INPUT_TYPE_NAMES.get(input_type, str(input_type)),
    }


def _controls(cfg: GenerationConfiguration) -> list[dict[str, Any]]:
    """Describe every control attached to ``cfg`` for the ``v2`` metadata."""
    return [
        _describe_control(control)
        for control in _nested_items(cfg, "Controls", "ControlsLength")
    ]
|
|
314
|
+
|
|
315
|
+
|
|
316
|
+
def _nested_items(
|
|
317
|
+
cfg: GenerationConfiguration,
|
|
318
|
+
item_name: str,
|
|
319
|
+
length_name: str,
|
|
320
|
+
) -> list[Any]:
|
|
321
|
+
generated_dir = str(Path(__file__).parent / "generated" / "dt_grpc")
|
|
322
|
+
added_path = generated_dir not in sys.path
|
|
323
|
+
if added_path:
|
|
324
|
+
sys.path.append(generated_dir)
|
|
325
|
+
try:
|
|
326
|
+
item: Callable[[int], Any] = getattr(cfg, item_name)
|
|
327
|
+
length: Callable[[], int] = getattr(cfg, length_name)
|
|
328
|
+
return [item(index) for index in range(length())]
|
|
329
|
+
finally:
|
|
330
|
+
if added_path:
|
|
331
|
+
sys.path.remove(generated_dir)
|
|
332
|
+
|
|
333
|
+
|
|
334
|
+
def copy_with_seed(metadata: ImageMetadata, seed: int) -> ImageMetadata:
    """
    Return a deep copy of ``metadata`` with ``seed`` set at the top level and in ``v2``.

    ``metadata`` itself is left unmodified.
    """
    duplicate = copy.deepcopy(metadata)
    duplicate["v2"]["seed"] = seed
    duplicate["seed"] = seed
    return duplicate
|
|
@@ -0,0 +1,262 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utils for writing Draw Things metadata to pngs
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import io
|
|
6
|
+
import json
|
|
7
|
+
import struct
|
|
8
|
+
import zlib
|
|
9
|
+
from typing import cast
|
|
10
|
+
from typing_extensions import override
|
|
11
|
+
|
|
12
|
+
from PIL import Image
|
|
13
|
+
|
|
14
|
+
from ._metadata import ImageMetadata
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class BytesEncoder(json.JSONEncoder):
    """JSON encoder that emits ``bytes`` as UTF-8 text, or as a hex string when not valid UTF-8."""

    @override
    def default(self, o: object) -> str:  # noqa: ANN401
        if not isinstance(o, bytes):
            # Defer to the base class, which raises TypeError for unsupported types.
            return cast(str, super().default(o))
        try:
            text = o.decode("utf-8")
        except UnicodeDecodeError:
            text = o.hex()
        return text
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# Eight-byte magic that opens every PNG file: 0x89, "PNG", CR LF, 0x1A, LF.
PNG_SIGNATURE = bytes((0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A))
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def png_chunk(chunk_type: bytes, data: bytes) -> bytes:
    """
    Serialize one PNG chunk.

    The layout is a big-endian payload length, the 4-byte type, the payload,
    then a big-endian CRC-32 computed over type and payload.
    """
    checksum = zlib.crc32(chunk_type + data) & 0xFFFFFFFF
    return b"".join(
        (
            struct.pack(">I", len(data)),
            chunk_type,
            data,
            struct.pack(">I", checksum),
        )
    )
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def build_itxt_chunk(keyword: str, text: str) -> bytes:
    """
    Build the payload of an uncompressed iTXt chunk.

    Layout: UTF-8 keyword, NUL terminator, compression flag 0, compression
    method 0, empty language tag + NUL, empty translated keyword + NUL, then
    the UTF-8 text.
    """
    # The five zero bytes after the keyword are, in order: keyword terminator,
    # compression flag, compression method, language-tag terminator, and
    # translated-keyword terminator.
    header = keyword.encode("utf-8") + bytes(5)
    return header + text.encode("utf-8")
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def build_exif_user_comment(width: int, height: int) -> bytes:
    """
    Build a minimal big-endian TIFF/EXIF block that records the image size.

    Layout (byte offsets):
      0   TIFF header: "MM", magic 42, offset of IFD0 (8)
      8   IFD0 with one entry: ExifOffset (0x8769, LONG) pointing to offset 26,
          followed by a zero next-IFD offset
      26  Exif sub-IFD with ExifImageWidth (0xA002) and ExifImageHeight
          (0xA003), both LONG, followed by a zero next-IFD offset

    Despite the name, no UserComment tag is written here; the metadata JSON is
    carried in the XMP packet instead.
    """
    tiff_header = b"MM" + struct.pack(">HI", 42, 8)
    ifd0 = struct.pack(
        ">HHHIII",
        1,  # entry count
        0x8769, 4, 1, 26,  # ExifOffset: type LONG, count 1, value 26
        0,  # no further IFD
    )
    exif_ifd = struct.pack(
        ">HHHIIHHIII",
        2,  # entry count
        0xA002, 4, 1, width,  # ExifImageWidth
        0xA003, 4, 1, height,  # ExifImageHeight
        0,  # no further IFD
    )
    return tiff_header + ifd0 + exif_ifd
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
def format_desc_float(f: float) -> str:
    """Format ``f`` via ``str(float(f))``, appending ".0" when the text has no decimal point."""
    text = str(float(f))
    return text if "." in text else f"{text}.0"
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def build_description(metadata: ImageMetadata | None) -> str:
    """
    Render the human-readable description embedded in the XMP packet.

    Produces one "Label: value" line per populated field, in this order:
    prompt, negative prompt, steps, sampler, guidance scale (read from
    ``metadata["v2"]``), seed, size, model, strength. Falsy fields are
    skipped. The seed is the exception: 0 is a valid seed, so it is written
    whenever it is present.

    Returns:
        The newline-joined lines, or "" when ``metadata`` is ``None``.
    """
    if metadata is None:
        return ""
    desc: list[str] = []
    if c := metadata.get("c"):
        desc.append("Prompt: " + c)
    if uc := metadata.get("uc"):
        desc.append("Negative Prompt: " + uc)
    if steps := metadata.get("steps"):
        desc.append("Steps: " + str(steps))
    if sampler := metadata.get("sampler"):
        desc.append("Sampler: " + sampler)
    if guidance_scale := metadata["v2"].get("guidanceScale", 0.0):
        desc.append("Guidance Scale: " + format_desc_float(guidance_scale))
    # Compare against None rather than testing truthiness so seed 0 is kept.
    seed = metadata.get("seed")
    if seed is not None:
        desc.append("Seed: " + str(seed))
    if size := metadata.get("size"):
        desc.append("Size: " + size)
    if model := metadata.get("model"):
        desc.append("Model: " + model)
    if strength := metadata.get("strength"):
        desc.append("Strength: " + format_desc_float(strength))

    return "\n".join(desc)
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def build_drawthings_xmp(json_string: str, description: str) -> str:
    """
    Wrap the metadata JSON and the description in a Draw Things XMP packet.

    Both values are XML-escaped (``&``, ``<``, ``>``) before insertion, so
    prompts that contain those characters still produce well-formed XML.
    Newlines in the description are written as ``&#10;`` character references.

    Args:
        json_string: serialized metadata, stored in ``exif:UserComment``.
        description: human-readable summary, stored in ``dc:description``.

    Returns:
        The XMP document as a string, ending with a newline.
    """
    from xml.sax.saxutils import escape

    # escape() replaces &, <, > first and applies the extra mapping afterwards,
    # so the "&" introduced by "&#10;" is not escaped a second time.
    escaped_description = escape(description, {"\n": "&#10;"})
    escaped_json = escape(json_string)

    return f"""<x:xmpmeta xmlns:x="adobe:ns:meta/" x:xmptk="XMP Core 6.0.0">
<rdf:RDF xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#">
<rdf:Description rdf:about=""
xmlns:dc="http://purl.org/dc/elements/1.1/"
xmlns:xmp="http://ns.adobe.com/xap/1.0/"
xmlns:exif="http://ns.adobe.com/exif/1.0/">
<dc:description>
<rdf:Alt>
<rdf:li xml:lang="x-default">{escaped_description}</rdf:li>
</rdf:Alt>
</dc:description>
<xmp:CreatorTool>Draw Things</xmp:CreatorTool>
<exif:UserComment>
<rdf:Alt>
<rdf:li xml:lang="x-default">{escaped_json}</rdf:li>
</rdf:Alt>
</exif:UserComment>
</rdf:Description>
</rdf:RDF>
</x:xmpmeta>
"""
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def _split_png_chunks(data: bytes) -> list[tuple[bytes, bytes, bytes]]:
    """Split PNG bytes (after the 8-byte signature) into ``(type, payload, crc)`` tuples."""
    chunks: list[tuple[bytes, bytes, bytes]] = []
    offset = 8
    while offset < len(data):
        (size,) = struct.unpack(">I", data[offset : offset + 4])
        payload_start = offset + 8
        payload_end = payload_start + size
        chunks.append(
            (
                data[offset + 4 : payload_start],
                data[payload_start:payload_end],
                data[payload_end : payload_end + 4],
            )
        )
        # Each chunk is length (4) + type (4) + payload + CRC (4).
        offset = payload_end + 4
    return chunks


def _metadata_chunks(metadata: ImageMetadata, width: int, height: int) -> bytes:
    """Build the eXIf chunk and the XMP iTXt chunk that carry ``metadata``."""
    json_string = json.dumps(metadata, separators=(",", ":"), cls=BytesEncoder)
    exif_chunk = png_chunk(b"eXIf", build_exif_user_comment(width, height))
    xmp = build_drawthings_xmp(json_string, build_description(metadata))
    itxt_chunk = png_chunk(b"iTXt", build_itxt_chunk("XML:com.adobe.xmp", xmp))
    return exif_chunk + itxt_chunk


def write_png_with_usercomment(
    pixels: bytes,
    width: int,
    height: int,
    channels: int,
    metadata: ImageMetadata | None = None,
) -> bytes:
    """
    Encode raw pixels as PNG and splice in Draw Things metadata chunks.

    Pillow produces the base PNG. Its chunks are copied through unchanged, and
    an sRGB chunk is inserted directly after IHDR. When ``metadata`` is given,
    an eXIf chunk (image size) and an iTXt XMP packet (metadata JSON plus a
    text description) follow the sRGB chunk.

    Args:
        pixels: raw interleaved 8-bit samples for Pillow's ``frombytes``.
        width: image width in pixels.
        height: image height in pixels.
        channels: 1 (L), 2 (LA), 3 (RGB) or 4 (RGBA).
        metadata: optional Draw Things metadata to embed.

    Raises:
        ValueError: if ``channels`` is not 1-4.
    """
    modes = {1: "L", 2: "LA", 3: "RGB", 4: "RGBA"}
    if channels not in modes:
        raise ValueError("Unsupported channel count")

    encoded_buffer = io.BytesIO()
    Image.frombytes(modes[channels], (width, height), pixels).save(
        encoded_buffer, format="PNG"
    )
    encoded = encoded_buffer.getvalue()

    result = bytearray(PNG_SIGNATURE)
    header_done = False
    for kind, body, crc in _split_png_chunks(encoded):
        # Re-emit Pillow's chunk verbatim, reusing its original CRC.
        result += struct.pack(">I", len(body)) + kind + body + crc
        if kind != b"IHDR" or header_done:
            continue
        result += png_chunk(b"sRGB", b"\x00")
        if metadata is not None:
            result += _metadata_chunks(metadata, width, height)
        header_done = True

    return bytes(result)
|