atelier-diffusion 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. atelier_diffusion-0.1.0/LICENSE +21 -0
  2. atelier_diffusion-0.1.0/PKG-INFO +378 -0
  3. atelier_diffusion-0.1.0/README.md +321 -0
  4. atelier_diffusion-0.1.0/atelier/__init__.py +5 -0
  5. atelier_diffusion-0.1.0/atelier/adapters/__init__.py +4 -0
  6. atelier_diffusion-0.1.0/atelier/adapters/base.py +99 -0
  7. atelier_diffusion-0.1.0/atelier/adapters/qwen_edit.py +302 -0
  8. atelier_diffusion-0.1.0/atelier/adapters/qwen_image.py +364 -0
  9. atelier_diffusion-0.1.0/atelier/adapters/sdxl.py +286 -0
  10. atelier_diffusion-0.1.0/atelier/callbacks.py +26 -0
  11. atelier_diffusion-0.1.0/atelier/config.py +50 -0
  12. atelier_diffusion-0.1.0/atelier/data/__init__.py +5 -0
  13. atelier_diffusion-0.1.0/atelier/data/cache.py +135 -0
  14. atelier_diffusion-0.1.0/atelier/data/editing.py +142 -0
  15. atelier_diffusion-0.1.0/atelier/data/generation.py +93 -0
  16. atelier_diffusion-0.1.0/atelier/losses/__init__.py +8 -0
  17. atelier_diffusion-0.1.0/atelier/losses/diffusion_cpo.py +59 -0
  18. atelier_diffusion-0.1.0/atelier/losses/diffusion_dpo.py +95 -0
  19. atelier_diffusion-0.1.0/atelier/losses/diffusion_ipo.py +81 -0
  20. atelier_diffusion-0.1.0/atelier/losses/diffusion_kto.py +109 -0
  21. atelier_diffusion-0.1.0/atelier/losses/diffusion_orpo.py +72 -0
  22. atelier_diffusion-0.1.0/atelier/losses/diffusion_simpo.py +54 -0
  23. atelier_diffusion-0.1.0/atelier/losses/epsilon.py +31 -0
  24. atelier_diffusion-0.1.0/atelier/losses/flow_matching.py +70 -0
  25. atelier_diffusion-0.1.0/atelier/losses/utils.py +174 -0
  26. atelier_diffusion-0.1.0/atelier/registry.py +51 -0
  27. atelier_diffusion-0.1.0/atelier/train.py +270 -0
  28. atelier_diffusion-0.1.0/atelier/trainer.py +450 -0
  29. atelier_diffusion-0.1.0/atelier_diffusion.egg-info/PKG-INFO +378 -0
  30. atelier_diffusion-0.1.0/atelier_diffusion.egg-info/SOURCES.txt +39 -0
  31. atelier_diffusion-0.1.0/atelier_diffusion.egg-info/dependency_links.txt +1 -0
  32. atelier_diffusion-0.1.0/atelier_diffusion.egg-info/requires.txt +20 -0
  33. atelier_diffusion-0.1.0/atelier_diffusion.egg-info/top_level.txt +1 -0
  34. atelier_diffusion-0.1.0/pyproject.toml +60 -0
  35. atelier_diffusion-0.1.0/setup.cfg +4 -0
  36. atelier_diffusion-0.1.0/tests/test_adapters.py +172 -0
  37. atelier_diffusion-0.1.0/tests/test_config.py +40 -0
  38. atelier_diffusion-0.1.0/tests/test_data.py +352 -0
  39. atelier_diffusion-0.1.0/tests/test_losses.py +494 -0
  40. atelier_diffusion-0.1.0/tests/test_train_cli.py +179 -0
  41. atelier_diffusion-0.1.0/tests/test_trainer.py +398 -0
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Schneewolf Labs
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,378 @@
1
+ Metadata-Version: 2.4
2
+ Name: atelier-diffusion
3
+ Version: 0.1.0
4
+ Summary: Simple, multi-GPU diffusion model fine-tuning library (Schneewolf Labs)
5
+ Author: Schneewolf Labs
6
+ License: MIT License
7
+
8
+ Copyright (c) 2025 Schneewolf Labs
9
+
10
+ Permission is hereby granted, free of charge, to any person obtaining a copy
11
+ of this software and associated documentation files (the "Software"), to deal
12
+ in the Software without restriction, including without limitation the rights
13
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
14
+ copies of the Software, and to permit persons to whom the Software is
15
+ furnished to do so, subject to the following conditions:
16
+
17
+ The above copyright notice and this permission notice shall be included in all
18
+ copies or substantial portions of the Software.
19
+
20
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
21
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
22
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
23
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
24
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
25
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
26
+ SOFTWARE.
27
+
28
+ Project-URL: Homepage, https://github.com/Schneewolf-Labs/atelier
29
+ Project-URL: Repository, https://github.com/Schneewolf-Labs/atelier
30
+ Project-URL: Issues, https://github.com/Schneewolf-Labs/atelier/issues
31
+ Keywords: diffusion,lora,qwen-image,sdxl,fine-tuning,schneewolf-labs,atelier
32
+ Classifier: License :: OSI Approved :: MIT License
33
+ Classifier: Programming Language :: Python :: 3.10
34
+ Classifier: Programming Language :: Python :: 3.11
35
+ Classifier: Programming Language :: Python :: 3.12
36
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
37
+ Requires-Python: >=3.10
38
+ Description-Content-Type: text/markdown
39
+ License-File: LICENSE
40
+ Requires-Dist: accelerate>=0.24.0
41
+ Requires-Dist: peft>=0.6.0
42
+ Requires-Dist: datasets>=2.14.0
43
+ Requires-Dist: tqdm>=4.60.0
44
+ Requires-Dist: Pillow>=9.0.0
45
+ Requires-Dist: numpy>=1.24.0
46
+ Provides-Extra: quantization
47
+ Requires-Dist: bitsandbytes>=0.41.0; extra == "quantization"
48
+ Provides-Extra: logging
49
+ Requires-Dist: wandb>=0.15.0; extra == "logging"
50
+ Provides-Extra: yaml
51
+ Requires-Dist: PyYAML>=6.0; extra == "yaml"
52
+ Provides-Extra: dev
53
+ Requires-Dist: pytest>=7.0; extra == "dev"
54
+ Requires-Dist: ruff>=0.1.0; extra == "dev"
55
+ Requires-Dist: PyYAML>=6.0; extra == "dev"
56
+ Dynamic: license-file
57
+
58
+ # 🎨 Atelier 🔨
59
+
60
+ A simple, multi-GPU diffusion model fine-tuning library. One training loop, pluggable adapters and loss functions.
61
+
62
+ Sister project to [Grimoire](https://github.com/Schneewolf-Labs/Grimoire) (LLM fine-tuning). Both serve as training engines for [Merlina](https://github.com/Schneewolf-Labs/Merlina).
63
+
64
+ ## Why
65
+
66
+ Diffusion model training scripts tend to be monolithic — model loading, data processing, the training loop, and architecture-specific forward passes all tangled together. Switching from SDXL to Qwen-Image-Edit means rewriting the whole script.
67
+
68
+ Atelier separates what varies (model architecture, training objective) from what doesn't (the training loop, multi-GPU, checkpointing, logging). Adding a new model means writing an adapter. Adding a new training objective means writing a loss function. The trainer never changes.
69
+
70
+ ## Install
71
+
72
+ ```bash
73
+ pip install -e .
74
+
75
+ # With optional dependencies
76
+ pip install -e ".[quantization]" # bitsandbytes for 8-bit optimizers
77
+ pip install -e ".[logging]" # wandb
78
+ pip install -e ".[quantization,logging,yaml]" # all optional runtime extras
79
+ ```
80
+
81
+ ## Quick start
82
+
83
+ ### Qwen-Image-Edit LoRA (flow matching)
84
+
85
+ ```python
86
+ from peft import LoraConfig
87
+ from atelier import AtelierTrainer, TrainingConfig
88
+ from atelier.adapters import QwenEditAdapter
89
+ from atelier.losses import FlowMatchingLoss
90
+ from atelier.data import EditingDataset, cache_embeddings
91
+
92
+ # Load adapter (handles model, VAE, text encoder, scheduler)
93
+ adapter = QwenEditAdapter("Qwen/Qwen-Image-Edit")
94
+
95
+ # Pre-compute embeddings to save VRAM during training
96
+ text_emb, target_emb, control_emb = cache_embeddings(
97
+ raw_dataset, adapter, cache_dir="./output/cache",
98
+ )
99
+ adapter.free_encoders() # reclaim VRAM
100
+
101
+ dataset = EditingDataset(
102
+ raw_dataset,
103
+ cached_text_embeddings=text_emb,
104
+ cached_target_embeddings=target_emb,
105
+ cached_control_embeddings=control_emb,
106
+ )
107
+
108
+ trainer = AtelierTrainer(
109
+ adapter=adapter,
110
+ config=TrainingConfig(
111
+ output_dir="./output",
112
+ num_epochs=50,
113
+ batch_size=1,
114
+ learning_rate=1e-4,
115
+ gradient_accumulation_steps=2,
116
+ ),
117
+ loss_fn=FlowMatchingLoss(),
118
+ train_dataset=dataset,
119
+ peft_config=LoraConfig(
120
+ r=64,
121
+ lora_alpha=128,
122
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
123
+ ),
124
+ )
125
+
126
+ trainer.train()
127
+ trainer.save_model("./my-lora")
128
+ ```
129
+
130
+ ### Qwen-Image LoRA (text-to-image, flow matching)
131
+
132
+ Same loss as Qwen-Image-Edit; the adapter differs because the text encoder
133
+ is not vision-conditioned and the transformer sees only the noised target
134
+ (no control image concat).
135
+
136
+ ```python
137
+ from peft import LoraConfig
138
+ from atelier import AtelierTrainer, TrainingConfig
139
+ from atelier.adapters import QwenImageAdapter
140
+ from atelier.losses import FlowMatchingLoss
141
+ from atelier.data import EditingDataset, cache_embeddings
142
+
143
+ adapter = QwenImageAdapter("Qwen/Qwen-Image")
144
+
145
+ # Dataset only needs (prompt, chosen) — no "rejected" column.
146
+ text_emb, target_emb, _ = cache_embeddings(
147
+ raw_dataset, adapter, cache_dir="./output/cache",
148
+ )
149
+ adapter.free_encoders()
150
+
151
+ dataset = EditingDataset(
152
+ raw_dataset,
153
+ cached_text_embeddings=text_emb,
154
+ cached_target_embeddings=target_emb,
155
+ )
156
+
157
+ trainer = AtelierTrainer(
158
+ adapter=adapter,
159
+ config=TrainingConfig(output_dir="./output", num_epochs=8, batch_size=1),
160
+ loss_fn=FlowMatchingLoss(),
161
+ train_dataset=dataset,
162
+ peft_config=LoraConfig(
163
+ r=32, lora_alpha=64,
164
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
165
+ ),
166
+ )
167
+ trainer.train()
168
+ trainer.save_model("./my-qwen-image-lora")
169
+ ```
170
+
171
+ ### SDXL DPO (preference optimization)
172
+
173
+ Same trainer, different adapter and loss function. `AtelierTrainer` and `TrainingConfig` are imported from `atelier` as in the examples above.
174
+
175
+ ```python
176
+ from atelier.adapters import SDXLAdapter
177
+ from atelier.losses import DiffusionDPOLoss
178
+ from atelier.data import GenerationDataset
179
+
180
+ adapter = SDXLAdapter(
181
+ "stabilityai/stable-diffusion-xl-base-1.0",
182
+ weights="/path/to/model.safetensors",
183
+ )
184
+ adapter.freeze_layers(strategy="color_blocks", layers="0,1")
185
+
186
+ dataset = GenerationDataset(
187
+ raw_dataset,
188
+ tokenizer=adapter.tokenizer,
189
+ tokenizer_2=adapter.tokenizer_2,
190
+ )
191
+
192
+ trainer = AtelierTrainer(
193
+ adapter=adapter,
194
+ config=TrainingConfig(
195
+ output_dir="./output",
196
+ num_epochs=10,
197
+ batch_size=1,
198
+ learning_rate=2e-6,
199
+ optimizer="adamw_8bit",
200
+ mixed_precision="fp16",
201
+ ),
202
+ loss_fn=DiffusionDPOLoss(beta=0.4, sft_weight=0.3),
203
+ train_dataset=dataset,
204
+ )
205
+
206
+ trainer.train()
207
+ trainer.save_model("./my-sdxl")
208
+ ```
209
+
210
+ ### With LoRA
211
+
212
+ Pass a `peft_config` and Atelier handles the rest.
213
+
214
+ ```python
215
+ from peft import LoraConfig
216
+
217
+ trainer = AtelierTrainer(
218
+ adapter=adapter,
219
+ config=TrainingConfig(...),
220
+ loss_fn=FlowMatchingLoss(),
221
+ train_dataset=dataset,
222
+ peft_config=LoraConfig(
223
+ r=64,
224
+ lora_alpha=128,
225
+ target_modules=["to_k", "to_q", "to_v", "to_out.0"],
226
+ ),
227
+ )
228
+ ```
229
+
230
+ ## Guides
231
+
232
+ - **[Loss Formulas](docs/loss-formulas.md)** — Math for flow matching and diffusion DPO
233
+ - **[Adapters](docs/adapters.md)** — Writing a custom adapter for a new model architecture
234
+ - **[Callbacks](docs/callbacks.md)** — Hooking into the training loop
235
+ - **[Multi-GPU and DeepSpeed](docs/deepspeed.md)** — Distributed training setup
236
+
237
+ ## YAML config + CLI
238
+
239
+ For orchestrators (e.g. [Merlina](https://github.com/Schneewolf-Labs/Merlina)) or
240
+ when you just don't want to write a Python wrapper per run, train from a YAML config:
241
+
242
+ ```bash
243
+ pip install -e ".[yaml]"
244
+ python -m atelier.train --config configs/qwen_image_lora_example.yaml
245
+
246
+ # Override anything on the CLI (JSON-decoded values):
247
+ python -m atelier.train --config configs/my.yaml \
248
+ --set training.num_epochs=2 \
249
+ --set training.output_dir=./output-quick \
250
+ --set 'peft.target_modules=["to_q","to_v"]'
251
+ ```
252
+
253
+ The YAML schema mirrors the Python API one-for-one — `model.adapter` picks an
254
+ adapter from `atelier.registry.ADAPTERS`, `loss.type` picks from `LOSSES`,
255
+ `peft` becomes a `LoraConfig`, `training` becomes a `TrainingConfig`, and
256
+ `dataset` accepts an HF hub name, a local JSONL, or a `load_from_disk` path.
257
+ See `atelier/train.py` for the full schema and `configs/qwen_image_lora_example.yaml`
258
+ for a worked example.
259
+
260
+ ## Multi-GPU
261
+
262
+ No code changes. Configure with `accelerate` and launch:
263
+
264
+ ```bash
265
+ accelerate config
266
+ accelerate launch --multi_gpu --num_processes 4 train.py
267
+ accelerate launch --use_deepspeed --deepspeed_config_file ds_config.json train.py
268
+ ```
269
+
270
+ ## Callbacks
271
+
272
+ Subclass `TrainerCallback` and override the hooks you need:
273
+
274
+ ```python
275
+ from atelier import TrainerCallback
276
+
277
+ class MyCallback(TrainerCallback):
278
+ def on_step_end(self, trainer, step, loss, metrics):
279
+ if should_stop():
280
+ trainer.request_stop()
281
+
282
+ def on_log(self, trainer, metrics):
283
+ print(f"Step {trainer.global_step}: {metrics}")
284
+
285
+ trainer = AtelierTrainer(..., callbacks=[MyCallback()])
286
+ ```
287
+
288
+ Available hooks: `on_train_begin`, `on_train_end`, `on_epoch_begin`, `on_epoch_end`, `on_step_end`, `on_log`, `on_evaluate`, `on_save`.
289
+
290
+ ## Configuration
291
+
292
+ `TrainingConfig` fields with defaults:
293
+
294
+ | Field | Default | Description |
295
+ |---|---|---|
296
+ | `output_dir` | `"./output"` | Checkpoints and saved models |
297
+ | `num_epochs` | `3` | Number of training epochs |
298
+ | `batch_size` | `1` | Per-device batch size |
299
+ | `gradient_accumulation_steps` | `1` | Steps before optimizer update |
300
+ | `learning_rate` | `1e-4` | Peak learning rate |
301
+ | `weight_decay` | `0.01` | L2 regularization |
302
+ | `warmup_ratio` | `0.1` | Fraction of steps for LR warmup |
303
+ | `warmup_steps` | `0` | Overrides `warmup_ratio` if > 0 |
304
+ | `max_grad_norm` | `1.0` | Gradient clipping |
305
+ | `mixed_precision` | `"bf16"` | `"no"`, `"fp16"`, or `"bf16"` |
306
+ | `gradient_checkpointing` | `True` | Trade compute for memory |
307
+ | `optimizer` | `"adamw"` | See supported optimizers below |
308
+ | `lr_scheduler` | `"cosine"` | `"linear"`, `"cosine"`, `"constant"`, `"constant_with_warmup"` |
309
+ | `logging_steps` | `10` | Log metrics every N steps |
310
+ | `eval_steps` | `None` | Evaluate every N steps |
311
+ | `save_steps` | `None` | Checkpoint every N steps |
312
+ | `save_total_limit` | `2` | Max checkpoints to keep |
313
+ | `save_on_epoch_end` | `True` | Checkpoint after each epoch |
314
+ | `resume_from_checkpoint` | `None` | Path to resume from |
315
+ | `seed` | `42` | Random seed |
316
+ | `log_with` | `None` | `"wandb"` for W&B tracking |
317
+
318
+ **Supported optimizers:** `adamw`, `adamw_8bit`, `paged_adamw_8bit`, `adafactor`, `sgd`
319
+
320
+ ## Architecture
321
+
322
+ ```
323
+ atelier/
324
+ ├── trainer.py # AtelierTrainer — the training loop
325
+ ├── config.py # TrainingConfig dataclass
326
+ ├── callbacks.py # TrainerCallback base class
327
+ ├── adapters/
328
+ │ ├── base.py # ModelAdapter protocol
329
+ │ ├── qwen_edit.py # Qwen-Image-Edit (DiT + flow matching, image-conditioned)
330
+ │ ├── qwen_image.py # Qwen-Image (DiT + flow matching, text-to-image)
331
+ │ └── sdxl.py # SDXL (UNet + DDPM)
332
+ ├── losses/
333
+ │ ├── flow_matching.py # Flow matching MSE
334
+ │ └── diffusion_dpo.py # DPO + SFT regularization
335
+ └── data/
336
+ ├── editing.py # Paired image editing dataset
337
+ ├── generation.py # Text-to-image dataset
338
+ └── cache.py # Embedding pre-computation
339
+ ```
340
+
341
+ ### How it fits together
342
+
343
+ The **adapter** encapsulates everything that varies per model architecture — loading, encoding, the forward pass, and saving. In Grimoire (LLM training), every model has the same forward signature (`model(input_ids)` → logits). In diffusion training, forward passes vary wildly: Qwen-Image-Edit needs latent packing, control image concatenation, and RoPE shapes; SDXL needs dual CLIP conditioning and time embeddings. The adapter hides this.
344
+
345
+ The **loss function** orchestrates the training objective — sampling noise and timesteps, calling the adapter's forward pass, and computing the loss. Flow matching predicts the velocity field; DPO compares noise predictions for chosen vs rejected images.
346
+
347
+ The **trainer** owns the loop — optimizer, gradient accumulation, checkpointing, logging. It calls `loss_fn(adapter, model, batch)` and never needs to know what model architecture or training objective is being used.
348
+
349
+ ### Loss function interface
350
+
351
+ ```python
352
+ class MyLoss:
353
+ def __call__(self, adapter, model, batch, training=True):
354
+ # Use adapter for noise sampling, forward pass, target computation
355
+ return loss, metrics_dict
356
+
357
+ def create_collator(self):
358
+ return MyCollator()
359
+ ```
360
+
361
+ ### Adapter interface
362
+
363
+ ```python
364
+ class MyAdapter(ModelAdapter):
365
+ def model(self): ... # The trainable model
366
+ def encode_images(self): ... # VAE encode
367
+ def encode_text(self): ... # Text encode
368
+ def sample_timesteps(self): ... # Timestep sampling
369
+ def add_noise(self): ... # Create noisy input
370
+ def compute_target(self): ... # What model should predict
371
+ def forward(self): ... # Architecture-specific forward
372
+ def save_lora(self): ... # Save LoRA weights
373
+ def save_model(self): ... # Save full model
374
+ ```
375
+
376
+ ## License
377
+
378
+ MIT