drawthings-py 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54) hide show
  1. drawthings_py/__init__.py +19 -0
  2. drawthings_py/_dt_service.py +39 -0
  3. drawthings_py/_errors.py +21 -0
  4. drawthings_py/_metadata.py +338 -0
  5. drawthings_py/_png_writer.py +262 -0
  6. drawthings_py/_preview_decoders.py +657 -0
  7. drawthings_py/_util.py +102 -0
  8. drawthings_py/cli_service.py +25 -0
  9. drawthings_py/configs/__init__.py +31 -0
  10. drawthings_py/configs/config_dict.py +287 -0
  11. drawthings_py/configs/config_prop.py +576 -0
  12. drawthings_py/configs/configs.py +110 -0
  13. drawthings_py/configs/gen_config.py +397 -0
  14. drawthings_py/configs/presets.py +58 -0
  15. drawthings_py/configs/types.py +261 -0
  16. drawthings_py/drawthings.py +34 -0
  17. drawthings_py/filename_pattern.py +136 -0
  18. drawthings_py/generated/dt_grpc/__init__.py +0 -0
  19. drawthings_py/generated/dt_grpc/config_generated.py +1763 -0
  20. drawthings_py/generated/dt_grpc/config_generated.pyi +527 -0
  21. drawthings_py/generated/dt_grpc/image_service/__init__.py +502 -0
  22. drawthings_py/grpc_service.py +303 -0
  23. drawthings_py/image_buffer.py +471 -0
  24. drawthings_py/py.typed +0 -0
  25. drawthings_py/request_builder.py +423 -0
  26. drawthings_py/resources/config_props.yaml +926 -0
  27. drawthings_py/resources/configs/anima_preview_3.json +36 -0
  28. drawthings_py/resources/configs/chroma_hd.json +36 -0
  29. drawthings_py/resources/configs/default.json +34 -0
  30. drawthings_py/resources/configs/ernie_image_base.json +35 -0
  31. drawthings_py/resources/configs/ernie_image_turbo.json +36 -0
  32. drawthings_py/resources/configs/flux_1_dev.json +38 -0
  33. drawthings_py/resources/configs/flux_1_fill_dev.json +34 -0
  34. drawthings_py/resources/configs/flux_1_schnell.json +34 -0
  35. drawthings_py/resources/configs/flux_2_dev_with_turbo.json +43 -0
  36. drawthings_py/resources/configs/flux_2_klein_4b.json +37 -0
  37. drawthings_py/resources/configs/flux_2_klein_4b_base.json +35 -0
  38. drawthings_py/resources/configs/flux_2_klein_9b.json +37 -0
  39. drawthings_py/resources/configs/flux_2_klein_9b_base.json +35 -0
  40. drawthings_py/resources/configs/flux_2_klein_9b_kv.json +36 -0
  41. drawthings_py/resources/configs/qwen_image_2512.json +35 -0
  42. drawthings_py/resources/configs/qwen_image_2512_lightning.json +41 -0
  43. drawthings_py/resources/configs/qwen_image_edit_2511.json +34 -0
  44. drawthings_py/resources/configs/qwen_image_edit_2511_lightning.json +41 -0
  45. drawthings_py/resources/configs/sdxl.json +40 -0
  46. drawthings_py/resources/configs/stable_diffusion.json +29 -0
  47. drawthings_py/resources/configs/z_image_base.json +34 -0
  48. drawthings_py/resources/configs/z_image_turbo.json +36 -0
  49. drawthings_py/resources/root_ca.crt +30 -0
  50. drawthings_py/seed_provider.py +23 -0
  51. drawthings_py-0.1.0.dist-info/METADATA +240 -0
  52. drawthings_py-0.1.0.dist-info/RECORD +54 -0
  53. drawthings_py-0.1.0.dist-info/WHEEL +4 -0
  54. drawthings_py-0.1.0.dist-info/licenses/LICENSE +675 -0
@@ -0,0 +1,19 @@
1
+ """
2
+ Draw Things Python SDK
3
+ """
4
+
5
+ import drawthings_py.drawthings as DrawThings
6
+ from .configs import Configs, ConfigDict, Presets
7
+ from .filename_pattern import FilenamePattern
8
+ from .image_buffer import ImageBuffer
9
+ from .request_builder import RequestBuilder
10
+
11
+ __all__ = [
12
+ "DrawThings",
13
+ "RequestBuilder",
14
+ "ImageBuffer",
15
+ "ConfigDict",
16
+ "Configs",
17
+ "Presets",
18
+ "FilenamePattern",
19
+ ]
@@ -0,0 +1,39 @@
1
+ """
2
+ Primary entry point for using Draw Things services
3
+ """
4
+
5
+ from __future__ import annotations
6
+ from abc import ABC, abstractmethod
7
+ from types import TracebackType
8
+
9
+ from .image_buffer import ImageBuffer
10
+ from .request_builder import RequestBuilder
11
+
12
+
13
+ class DrawThingsService(ABC):
14
+ """
15
+ Base class for grpc and cii service
16
+ """
17
+
18
+ @abstractmethod
19
+ async def generate_image(self, request: RequestBuilder) -> list[ImageBuffer]:
20
+ """
21
+ Generate an image from the provided request builder
22
+ """
23
+
24
+ @abstractmethod
25
+ def _dispose(self):
26
+ """
27
+ dispose of the service
28
+ """
29
+
30
+ async def __aenter__(self) -> "DrawThingsService":
31
+ return self
32
+
33
+ async def __aexit__(
34
+ self,
35
+ exc_type: type[BaseException] | None,
36
+ exc_val: BaseException | None,
37
+ exc_tb: TracebackType | None,
38
+ ) -> None:
39
+ self._dispose()
@@ -0,0 +1,21 @@
1
+ from grpclib import GRPCError, Status
2
+
3
+
4
class DrawThingsServerError(Exception):
    """Raised when the Draw Things gRPC server reports an INTERNAL error."""
6
+
7
+
8
class DrawThingsUnavailableError(Exception):
    """Raised when the Draw Things gRPC server reports status UNAVAILABLE."""
10
+
11
+
12
def raise_grpc_error(e: GRPCError):
    """
    Translate a grpclib ``GRPCError`` into this package's exceptions and raise it.

    - ``Status.INTERNAL``    -> :class:`DrawThingsServerError`
    - ``Status.UNAVAILABLE`` -> :class:`DrawThingsUnavailableError`
    - any other status       -> the original ``GRPCError`` is re-raised

    The translated exceptions are chained to *e* (``raise ... from e``).
    This function never returns normally.
    """
    if e.status == Status.INTERNAL:
        raise DrawThingsServerError(
            f"There was an error on the server. This could be a temporary issue with DT+ or a problem with your request. ({e.status.name}: {e.message})"
        ) from e

    if e.status == Status.UNAVAILABLE:
        raise DrawThingsUnavailableError(
            "The gRPC server is currently unavailable."
        ) from e

    # Unmapped statuses must not be swallowed: falling off the end here would
    # return None and silently discard the error.
    raise e
@@ -0,0 +1,338 @@
1
+ from __future__ import annotations
2
+
3
+ import copy
4
+ import sys
5
+ from collections.abc import Callable
6
+ from pathlib import Path
7
+ from typing import Any, Required, TypedDict, cast
8
+
9
+ from .generated.dt_grpc.config_generated import GenerationConfiguration
10
+
11
# Display names keyed by the value of GenerationConfiguration.Sampler().
# The tuple position IS the sampler value, so the order below must not change.
SAMPLER_NAMES: dict[Any, str] = dict(
    enumerate(
        (
            "DPM++ 2M Karras",  # 0
            "Euler A",  # 1
            "DDIM",  # 2
            "PLMS",  # 3
            "DPM++ SDE Karras",  # 4
            "UniPC",  # 5
            "LCM",  # 6
            "Euler A Substep",  # 7
            "DPM++ SDE Substep",  # 8
            "TCD",  # 9
            "Euler A Trailing",  # 10
            "DPM++ SDE Trailing",  # 11
            "DPM++ 2M AYS",  # 12
            "Euler A AYS",  # 13
            "DPM++ SDE AYS",  # 14
            "DPM++ 2M Trailing",  # 15
            "DDIM Trailing",  # 16
            "UniPC Trailing",  # 17
            "UniPC AYS",  # 18
            "TCD Trailing",  # 19
        )
    )
)
33
+
34
# Display names keyed by GenerationConfiguration.SeedMode(); position == value.
SEED_MODE_NAMES = dict(
    enumerate(("Legacy", "Torch CPU Compatible", "Scale Alike", "Nvidia GPU Compatible"))
)
40
+
41
# Names keyed by GenerationConfiguration.CompressionArtifacts(); position == value.
COMPRESSION_ARTIFACT_NAMES = dict(enumerate(("disabled", "h264", "h265", "jpeg")))
47
+
48
# Names keyed by a LoRA entry's Mode(); position == value.
LORA_MODE_NAMES = dict(enumerate(("all", "base", "refiner")))
53
+
54
# Names keyed by a control entry's ControlMode(); position == value.
CONTROL_MODE_NAMES = dict(enumerate(("balanced", "prompt", "control")))
59
+
60
# Names keyed by a control entry's InputOverride(); position == value, so the
# order below must not change.
CONTROL_INPUT_TYPE_NAMES = dict(
    enumerate(
        (
            "unspecified", "custom", "depth", "canny", "scribble",  # 0-4
            "pose", "normalbae", "color", "lineart", "softedge",  # 5-9
            "seg", "inpaint", "ip2p", "shuffle", "mlsd",  # 10-14
            "tile", "blur", "lowquality", "gray",  # 15-18
        )
    )
)
81
+
82
+
83
class V2(TypedDict, total=False):
    """Detailed ``"v2"`` block of Draw Things image metadata.

    Declares every key that ``_create_v2_metadata`` may emit. All keys are
    optional because entries whose value is ``None`` or ``""`` are dropped.
    """

    # NOTE(review): types for keys copied straight from GenerationConfiguration
    # accessors are best-effort; confirm against config_generated.pyi.
    aestheticScore: float
    batchCount: int
    batchSize: int
    causalInference: bool
    causalInferencePad: int
    cfgZeroInitSteps: int
    cfgZeroStar: float
    # Was annotated float; this is a text field. Confirm the accessor's
    # return type (flatbuffers string accessors may yield bytes).
    clipLText: str
    clipSkip: int
    clipWeight: float
    compressionArtifacts: str
    compressionArtifactsQuality: int
    controls: list[dict[str, Any]]
    cropLeft: int
    cropTop: int
    decodingTileHeight: int
    decodingTileOverlap: int
    decodingTileWidth: int
    diffusionTileHeight: int
    diffusionTileOverlap: int
    diffusionTileWidth: int
    fps: int
    guidanceEmbed: bool
    guidanceScale: float
    guidingFrameNoise: float
    height: int
    hiresFix: bool
    hiresFixHeight: int
    hiresFixStrength: float
    hiresFixWidth: int
    id: int
    imageGuidanceScale: float
    imagePriorSteps: int
    loras: list[dict[str, Any]]
    maskBlur: float
    maskBlurOutset: int
    model: str
    motionScale: int
    negativeAestheticScore: float
    negativeOriginalImageHeight: int
    negativeOriginalImageWidth: int
    negativePromptForImagePrior: bool
    numFrames: int
    originalImageHeight: int
    originalImageWidth: int
    preserveOriginalAfterInpaint: bool
    refinerStart: float
    resolutionDependentShift: bool
    sampler: int
    seed: int
    seedMode: int
    separateClipL: bool
    separateOpenClipG: bool
    separateT5: bool
    sharpness: float
    shift: float
    speedUpWithGuidanceEmbed: bool
    stage2Guidance: float
    stage2Shift: float
    stage2Steps: int
    startFrameGuidance: float
    steps: int
    stochasticSamplingGamma: float
    strength: float
    t5TextEncoder: bool
    targetImageHeight: int
    targetImageWidth: int
    teaCache: bool
    teaCacheEnd: int
    teaCacheMaxSkipSteps: int
    teaCacheStart: int
    teaCacheThreshold: float
    tiledDecoding: bool
    tiledDiffusion: bool
    upscalerScaleFactor: int
    width: int
    zeroNegativePrompt: bool
124
+
125
+
126
+ class ImageMetadata(TypedDict, total=False):
127
+ c: str
128
+ uc: str
129
+ steps: Required[int]
130
+ sampler: Required[str]
131
+ scale: Required[float]
132
+ seed: Required[int]
133
+ size: Required[str]
134
+ model: Required[str]
135
+ strength: Required[float]
136
+ seed_mode: Required[str]
137
+ shift: Required[float]
138
+ v2: Required[V2]
139
+
140
+
141
def create_metadata(
    config: GenerationConfiguration, prompt: str, negative_prompt: str
) -> ImageMetadata:
    """Create Draw Things PNG metadata from the generation configuration.

    Args:
        config: generation configuration to read values from.
        prompt: positive prompt, stored under ``"c"`` (falsy becomes ``""``).
        negative_prompt: negative prompt, stored under ``"uc"``.

    Returns:
        The top-level metadata plus the detailed ``"v2"`` block. Key insertion
        order is kept fixed because it becomes the serialized JSON order.
    """
    width = _pixels(config.StartWidth())
    height = _pixels(config.StartHeight())
    model_name = _decode_string(config.Model()) or ""
    sampler_id = config.Sampler()
    seed_mode_id = config.SeedMode()
    details = _create_v2_metadata(
        config, width, height, model_name, sampler_id, seed_mode_id
    )

    # Unknown enum values fall back to their numeric string form.
    metadata: ImageMetadata = {
        "c": prompt or "",
        "uc": negative_prompt or "",
        "model": model_name,
        "sampler": SAMPLER_NAMES.get(sampler_id, str(sampler_id)),
        "scale": config.GuidanceScale(),
        "seed": config.Seed(),
        "seed_mode": SEED_MODE_NAMES.get(seed_mode_id, str(seed_mode_id)),
        "shift": config.Shift(),
        "size": f"{width}x{height}",
        "steps": config.Steps(),
        "strength": config.Strength(),
        "v2": details,
    }
    return metadata
167
+
168
+
169
def _create_v2_metadata(
    cfg: GenerationConfiguration,
    width: int,
    height: int,
    model: str,
    sampler: int,
    seed_mode: int,
) -> V2:
    """Collect the detailed ``"v2"`` metadata block from *cfg*.

    Args:
        cfg: generation configuration to read from.
        width: image width in pixels.
        height: image height in pixels.
        model: decoded model file name.
        sampler: raw sampler value (stored numerically here).
        seed_mode: raw seed-mode value (stored numerically here).

    Returns:
        A dict whose insertion order is significant (it becomes the JSON key
        order). Entries whose value is ``None`` or ``""`` are dropped.
    """
    fields: dict[str, Any] = {
        "aestheticScore": cfg.AestheticScore(),
        "batchCount": cfg.BatchCount(),
        "batchSize": cfg.BatchSize(),
        "causalInference": cfg.CausalInference(),
        "causalInferencePad": cfg.CausalInferencePad(),
        "cfgZeroInitSteps": cfg.CfgZeroInitSteps(),
        "cfgZeroStar": cfg.CfgZeroStar(),
        "clipLText": cfg.ClipLText(),
        "clipSkip": cfg.ClipSkip(),
        "clipWeight": cfg.ClipWeight(),
        # No fallback here: an unknown value maps to None and is dropped below.
        "compressionArtifacts": COMPRESSION_ARTIFACT_NAMES.get(
            cfg.CompressionArtifacts()
        ),
        "compressionArtifactsQuality": cfg.CompressionArtifactsQuality(),
        "controls": _controls(cfg),
        "cropLeft": cfg.CropLeft(),
        "cropTop": cfg.CropTop(),
        "decodingTileHeight": _pixels(cfg.DecodingTileHeight()),
        "decodingTileOverlap": _pixels(cfg.DecodingTileOverlap()),
        "decodingTileWidth": _pixels(cfg.DecodingTileWidth()),
        "diffusionTileHeight": _pixels(cfg.DiffusionTileHeight()),
        "diffusionTileOverlap": _pixels(cfg.DiffusionTileOverlap()),
        "diffusionTileWidth": _pixels(cfg.DiffusionTileWidth()),
        "fps": cfg.FpsId(),
        "guidanceEmbed": cfg.GuidanceEmbed(),
        "guidanceScale": cfg.GuidanceScale(),
        "guidingFrameNoise": cfg.CondAug(),
        "height": height,
        "hiresFix": cfg.HiresFix(),
        "hiresFixHeight": _pixels(cfg.HiresFixStartHeight()),
        "hiresFixStrength": cfg.HiresFixStrength(),
        "hiresFixWidth": _pixels(cfg.HiresFixStartWidth()),
        "id": cfg.Id(),
        "imageGuidanceScale": cfg.ImageGuidanceScale(),
        "imagePriorSteps": cfg.ImagePriorSteps(),
        "loras": _loras(cfg),
        "maskBlur": cfg.MaskBlur(),
        "maskBlurOutset": cfg.MaskBlurOutset(),
        "model": model,
        "motionScale": cfg.MotionBucketId(),
        "negativeAestheticScore": cfg.NegativeAestheticScore(),
        "negativeOriginalImageHeight": _pixels(cfg.NegativeOriginalImageHeight()),
        "negativeOriginalImageWidth": _pixels(cfg.NegativeOriginalImageWidth()),
        "negativePromptForImagePrior": cfg.NegativePromptForImagePrior(),
        "numFrames": cfg.NumFrames(),
        "originalImageHeight": _pixels(cfg.OriginalImageHeight()),
        "originalImageWidth": _pixels(cfg.OriginalImageWidth()),
        "preserveOriginalAfterInpaint": cfg.PreserveOriginalAfterInpaint(),
        "refinerStart": cfg.RefinerStart(),
        "resolutionDependentShift": cfg.ResolutionDependentShift(),
        "sampler": sampler,
        "seed": cfg.Seed(),
        "seedMode": seed_mode,
        "separateClipL": cfg.SeparateClipL(),
        "separateOpenClipG": cfg.SeparateOpenClipG(),
        "separateT5": cfg.SeparateT5(),
        "sharpness": cfg.Sharpness(),
        "shift": _metadata_float(cfg.Shift()),
        "speedUpWithGuidanceEmbed": cfg.SpeedUpWithGuidanceEmbed(),
        "stage2Guidance": cfg.Stage2Cfg(),
        "stage2Shift": cfg.Stage2Shift(),
        "stage2Steps": cfg.Stage2Steps(),
        "startFrameGuidance": cfg.StartFrameCfg(),
        "steps": cfg.Steps(),
        "stochasticSamplingGamma": _metadata_float(cfg.StochasticSamplingGamma()),
        "strength": cfg.Strength(),
        "t5TextEncoder": cfg.T5TextEncoder(),
        "targetImageHeight": _pixels(cfg.TargetImageHeight()),
        "targetImageWidth": _pixels(cfg.TargetImageWidth()),
        "teaCache": cfg.TeaCache(),
        "teaCacheEnd": cfg.TeaCacheEnd(),
        "teaCacheMaxSkipSteps": cfg.TeaCacheMaxSkipSteps(),
        "teaCacheStart": cfg.TeaCacheStart(),
        "teaCacheThreshold": cfg.TeaCacheThreshold(),
        "tiledDecoding": cfg.TiledDecoding(),
        "tiledDiffusion": cfg.TiledDiffusion(),
        "upscalerScaleFactor": cfg.UpscalerScaleFactor(),
        "width": width,
        "zeroNegativePrompt": cfg.ZeroNegativePrompt(),
    }
    # Keep only meaningful values; relative key order is preserved.
    kept = {
        key: value
        for key, value in fields.items()
        if not (value is None or value == "")
    }
    return cast(V2, kept)  # pyright: ignore[reportInvalidCast]
260
+
261
+
262
+ def _decode_string(value: bytes | str | None) -> str | None:
263
+ if value is None:
264
+ return None
265
+ if isinstance(value, bytes):
266
+ return value.decode("utf-8")
267
+ return value
268
+
269
+
270
+ def _pixels(value: int | None) -> int:
271
+ return (value or 0) * 64
272
+
273
+
274
def _metadata_float(value: float) -> float:
    """Return *value* as a float rounded to 7 decimal places.

    Presumably this hides float32 widening noise; for example, 1.9 stored as
    float32 reads back as 1.899999976158142.
    """
    as_float = float(value)
    return round(as_float, ndigits=7)
276
+
277
+
278
def _loras(cfg: GenerationConfiguration) -> list[dict[str, Any]]:
    """Describe each LoRA in *cfg* as ``{"mode", "file", "weight"}``.

    ``mode`` is the display name, or the numeric string for unknown values.
    A missing file name becomes ``""``.
    """
    entries: list[dict[str, Any]] = []
    for lora in _nested_items(cfg, "Loras", "LorasLength"):
        entries.append(
            {
                "mode": LORA_MODE_NAMES.get(lora.Mode(), str(lora.Mode())),
                "file": _decode_string(lora.File()) or "",
                "weight": lora.Weight(),
            }
        )
    return entries
287
+
288
+
289
def _controls(cfg: GenerationConfiguration) -> list[dict[str, Any]]:
    """Describe each control in *cfg* as a metadata dict.

    Enum fields (``controlMode``, ``inputOverride``) use display names, falling
    back to the numeric string. Missing strings become ``""``.
    """
    return [
        {
            "file": _decode_string(control.File()) or "",
            "weight": control.Weight(),
            "guidanceStart": control.GuidanceStart(),
            "guidanceEnd": control.GuidanceEnd(),
            "noPrompt": control.NoPrompt(),
            "globalAveragePooling": control.GlobalAveragePooling(),
            "downSamplingRate": control.DownSamplingRate(),
            "controlMode": CONTROL_MODE_NAMES.get(
                control.ControlMode(), str(control.ControlMode())
            ),
            "targetBlocks": [
                _decode_string(control.TargetBlocks(position)) or ""
                for position in range(control.TargetBlocksLength())
            ],
            "inputOverride": CONTROL_INPUT_TYPE_NAMES.get(
                control.InputOverride(), str(control.InputOverride())
            ),
        }
        for control in _nested_items(cfg, "Controls", "ControlsLength")
    ]
314
+
315
+
316
def _nested_items(
    cfg: GenerationConfiguration,
    item_name: str,
    length_name: str,
) -> list[Any]:
    """Materialize a flatbuffers vector field of *cfg* into a list.

    Calls ``cfg.<length_name>()`` for the count, then ``cfg.<item_name>(i)`` for
    each index. While doing so, the ``generated/dt_grpc`` directory next to this
    module is appended to ``sys.path``, and removed afterwards only if this call
    added it. Presumably this is because the generated accessors import sibling
    modules by top-level name; TODO confirm against config_generated.py.
    """
    generated_path = str(Path(__file__).parent.joinpath("generated", "dt_grpc"))
    must_cleanup = generated_path not in sys.path
    if must_cleanup:
        sys.path.append(generated_path)
    try:
        get_item: Callable[[int], Any] = getattr(cfg, item_name)
        get_count: Callable[[], int] = getattr(cfg, length_name)
        return list(map(get_item, range(get_count())))
    finally:
        if must_cleanup:
            sys.path.remove(generated_path)
332
+
333
+
334
def copy_with_seed(metadata: ImageMetadata, seed: int) -> ImageMetadata:
    """Return a deep copy of *metadata* with both the top-level and v2 seed set to *seed*.

    The input dict is not modified.
    """
    duplicate = copy.deepcopy(metadata)
    duplicate["seed"] = duplicate["v2"]["seed"] = seed
    return duplicate
@@ -0,0 +1,262 @@
1
+ """
2
+ Utils for writing Draw Things metadata to pngs
3
+ """
4
+
5
+ import io
6
+ import json
7
+ import struct
8
+ import zlib
9
+ from typing import cast
10
+ from typing_extensions import override
11
+
12
+ from PIL import Image
13
+
14
+ from ._metadata import ImageMetadata
15
+
16
+
17
+ class BytesEncoder(json.JSONEncoder):
18
+ """Custom JSON encoder that handles bytes by decoding them to UTF-8 strings."""
19
+
20
+ @override
21
+ def default(self, o: object) -> str: # noqa: ANN401
22
+ if isinstance(o, bytes):
23
+ try:
24
+ return o.decode("utf-8")
25
+ except UnicodeDecodeError:
26
+ return o.hex()
27
+ return cast(str, super().default(o))
28
+
29
+
30
# Fixed 8-byte magic that starts every PNG file: 0x89, "PNG", CR LF, EOF (0x1A), LF.
PNG_SIGNATURE = bytes([0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A])
31
+
32
+
33
def png_chunk(chunk_type: bytes, data: bytes) -> bytes:
    """Serialize one PNG chunk.

    Layout: 4-byte big-endian data length, chunk type, data, then the CRC-32
    of type + data (big-endian, unsigned).
    """
    checksum = zlib.crc32(data, zlib.crc32(chunk_type)) & 0xFFFFFFFF
    return b"".join(
        (struct.pack(">I", len(data)), chunk_type, data, struct.pack(">I", checksum))
    )
43
+
44
+
45
def build_itxt_chunk(keyword: str, text: str) -> bytes:
    """Build the payload of an uncompressed PNG ``iTXt`` chunk.

    Layout: keyword, NUL, compression flag 0, compression method 0, empty
    language tag + NUL, empty translated keyword + NUL, then UTF-8 *text*.
    """
    # Five zero bytes: keyword terminator, flag, method, and two empty NUL-terminated fields.
    header = keyword.encode("utf-8") + bytes(5)
    return header + text.encode("utf-8")
60
+
61
+
62
def build_exif_user_comment(width: int, height: int) -> bytes:
    """Build a big-endian TIFF/EXIF blob for a PNG ``eXIf`` chunk.

    Despite the name, it contains only ExifImageWidth (0xA002) and
    ExifImageHeight (0xA003) as LONG values. IFD0 holds one ExifOffset (0x8769)
    entry pointing at the Exif sub-IFD at byte offset 26. The result is 56 bytes.
    """
    # TIFF header: "MM" (big-endian), magic 42, offset of IFD0 = 8.
    tiff_header = b"MM" + struct.pack(">HI", 42, 8)
    # IFD0: 1 entry (tag, type LONG=4, count 1, value), then "no next IFD".
    ifd0 = struct.pack(">H", 1) + struct.pack(">HHII", 0x8769, 4, 1, 26) + struct.pack(">I", 0)
    # Exif sub-IFD at offset 26: 2 entries, then "no next IFD".
    exif_ifd = (
        struct.pack(">H", 2)
        + struct.pack(">HHII", 0xA002, 4, 1, width)
        + struct.pack(">HHII", 0xA003, 4, 1, height)
        + struct.pack(">I", 0)
    )
    return tiff_header + ifd0 + exif_ifd
101
+
102
+
103
def format_desc_float(f: float) -> str:
    """Format *f* for the human-readable description block.

    ``str(float(...))`` already renders integral values with a trailing ``.0``
    (``5`` -> ``"5.0"``), so its result is used as is. Exponent and non-finite
    forms (``"1e-05"``, ``"inf"``, ``"nan"``) are returned unchanged instead of
    gaining a malformed ``".0"`` suffix (e.g. ``"1e-05.0"``).
    """
    return str(float(f))
110
+
111
+
112
def build_description(metadata: ImageMetadata | None) -> str:
    """Build the newline-separated, human-readable summary of *metadata*.

    Emits ``Label: value`` lines for prompt, negative prompt, steps, sampler,
    guidance scale (from ``v2``), seed, size, model and strength, skipping
    absent ones. Seed is included whenever present, even when it is ``0``;
    the other fields are skipped when falsy. Returns ``""`` for ``None``.
    """
    if metadata is None:
        return ""
    desc: list[str] = []
    if c := metadata.get("c"):
        desc.append("Prompt: " + c)
    if uc := metadata.get("uc"):
        desc.append("Negative Prompt: " + uc)
    if steps := metadata.get("steps"):
        desc.append("Steps: " + str(steps))
    if sampler := metadata.get("sampler"):
        desc.append("Sampler: " + sampler)
    # Tolerate a missing "v2" block instead of raising KeyError.
    if guidance_scale := metadata.get("v2", {}).get("guidanceScale", 0.0):
        desc.append("Guidance Scale: " + format_desc_float(guidance_scale))
    # 0 is a valid seed, so test for presence rather than truthiness.
    if (seed := metadata.get("seed")) is not None:
        desc.append("Seed: " + str(seed))
    if size := metadata.get("size"):
        desc.append("Size: " + size)
    if model := metadata.get("model"):
        desc.append("Model: " + model)
    if strength := metadata.get("strength"):
        desc.append("Strength: " + format_desc_float(strength))

    return "\n".join(desc)
136
+
137
+
138
def build_drawthings_xmp(json_string: str, description: str) -> str:
    """Build the XMP packet embedded in Draw Things PNGs.

    Args:
        json_string: serialized generation metadata, placed in ``exif:UserComment``.
        description: human-readable summary, placed in ``dc:description``.
            Newlines are encoded as ``&#xA;`` character references.

    Both values are XML-escaped (``&``, ``<``, ``>``) before interpolation.
    They carry user-supplied prompt text, and an unescaped ``<`` or ``&``
    would make the packet malformed XML.
    """
    from xml.sax.saxutils import escape

    # Escape first, then add the character reference so its "&" is not re-escaped.
    escaped_description = escape(description).replace("\n", "&#xA;")
    escaped_json = escape(json_string)

    return f"""<x:xmpmeta xmlns:x="adobe:ns:meta/" x:xmptk="XMP Core 6.0.0">
<rdf:RDF xmlns:rdf="http://www.w3.org/1999/02/22-rdf-syntax-ns#">
<rdf:Description rdf:about=""
xmlns:dc="http://purl.org/dc/elements/1.1/"
xmlns:xmp="http://ns.adobe.com/xap/1.0/"
xmlns:exif="http://ns.adobe.com/exif/1.0/">
<dc:description>
<rdf:Alt>
<rdf:li xml:lang="x-default">{escaped_description}</rdf:li>
</rdf:Alt>
</dc:description>
<xmp:CreatorTool>Draw Things</xmp:CreatorTool>
<exif:UserComment>
<rdf:Alt>
<rdf:li xml:lang="x-default">{escaped_json}</rdf:li>
</rdf:Alt>
</exif:UserComment>
</rdf:Description>
</rdf:RDF>
</x:xmpmeta>
"""
162
+
163
+
164
def _split_png_chunks(png_data: bytes) -> list[tuple[bytes, bytes, bytes]]:
    """Split an encoded PNG into ``(type, data, crc)`` triples, skipping the 8-byte signature.

    Each chunk is read as a 4-byte big-endian length, a 4-byte type, ``length``
    data bytes and a 4-byte CRC; the CRC bytes are kept verbatim.
    """
    chunks: list[tuple[bytes, bytes, bytes]] = []
    offset = 8
    while offset < len(png_data):
        (length,) = struct.unpack(">I", png_data[offset : offset + 4])
        data_start = offset + 8
        data_end = data_start + length
        chunks.append(
            (
                png_data[offset + 4 : data_start],
                png_data[data_start:data_end],
                png_data[data_end : data_end + 4],
            )
        )
        offset = data_end + 4
    return chunks


def _metadata_chunks(metadata: ImageMetadata, width: int, height: int) -> bytes:
    """Encode *metadata* as an ``eXIf`` chunk followed by an XMP ``iTXt`` chunk."""
    payload = json.dumps(metadata, separators=(",", ":"), cls=BytesEncoder)
    exif_chunk = png_chunk(b"eXIf", build_exif_user_comment(width, height))
    xmp_packet = build_drawthings_xmp(payload, build_description(metadata))
    xmp_chunk = png_chunk(b"iTXt", build_itxt_chunk("XML:com.adobe.xmp", xmp_packet))
    return exif_chunk + xmp_chunk


def write_png_with_usercomment(
    pixels: bytes,
    width: int,
    height: int,
    channels: int,
    metadata: ImageMetadata | None = None,
) -> bytes:
    """Encode raw pixels as a PNG and inject Draw Things metadata chunks.

    Args:
        pixels: raw interleaved pixel bytes for a ``width`` x ``height`` image.
        width: image width in pixels.
        height: image height in pixels.
        channels: 1 (L), 2 (LA), 3 (RGB) or 4 (RGBA).
        metadata: optional Draw Things metadata to embed.

    Returns:
        The PNG file bytes. Right after ``IHDR``, an ``sRGB`` chunk is always
        inserted. When *metadata* is given, an ``eXIf`` chunk (width/height
        only) and an XMP ``iTXt`` chunk carrying the JSON and a text
        description follow it.

    Raises:
        ValueError: if *channels* is not 1-4.
    """
    modes = {1: "L", 2: "LA", 3: "RGB", 4: "RGBA"}
    if channels not in modes:
        raise ValueError("Unsupported channel count")

    encoded = io.BytesIO()
    Image.frombytes(modes[channels], (width, height), pixels).save(encoded, format="PNG")

    result = bytearray(PNG_SIGNATURE)
    extras_written = False
    for chunk_type, chunk_data, chunk_crc in _split_png_chunks(encoded.getvalue()):
        # Re-emit every original chunk unchanged (original CRC bytes kept).
        result += struct.pack(">I", len(chunk_data)) + chunk_type + chunk_data + chunk_crc
        if chunk_type != b"IHDR" or extras_written:
            continue
        result += png_chunk(b"sRGB", b"\x00")
        if metadata is not None:
            result += _metadata_chunks(metadata, width, height)
        extras_written = True
    return bytes(result)