atelier-diffusion 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atelier_diffusion-0.1.0/LICENSE +21 -0
- atelier_diffusion-0.1.0/PKG-INFO +378 -0
- atelier_diffusion-0.1.0/README.md +321 -0
- atelier_diffusion-0.1.0/atelier/__init__.py +5 -0
- atelier_diffusion-0.1.0/atelier/adapters/__init__.py +4 -0
- atelier_diffusion-0.1.0/atelier/adapters/base.py +99 -0
- atelier_diffusion-0.1.0/atelier/adapters/qwen_edit.py +302 -0
- atelier_diffusion-0.1.0/atelier/adapters/qwen_image.py +364 -0
- atelier_diffusion-0.1.0/atelier/adapters/sdxl.py +286 -0
- atelier_diffusion-0.1.0/atelier/callbacks.py +26 -0
- atelier_diffusion-0.1.0/atelier/config.py +50 -0
- atelier_diffusion-0.1.0/atelier/data/__init__.py +5 -0
- atelier_diffusion-0.1.0/atelier/data/cache.py +135 -0
- atelier_diffusion-0.1.0/atelier/data/editing.py +142 -0
- atelier_diffusion-0.1.0/atelier/data/generation.py +93 -0
- atelier_diffusion-0.1.0/atelier/losses/__init__.py +8 -0
- atelier_diffusion-0.1.0/atelier/losses/diffusion_cpo.py +59 -0
- atelier_diffusion-0.1.0/atelier/losses/diffusion_dpo.py +95 -0
- atelier_diffusion-0.1.0/atelier/losses/diffusion_ipo.py +81 -0
- atelier_diffusion-0.1.0/atelier/losses/diffusion_kto.py +109 -0
- atelier_diffusion-0.1.0/atelier/losses/diffusion_orpo.py +72 -0
- atelier_diffusion-0.1.0/atelier/losses/diffusion_simpo.py +54 -0
- atelier_diffusion-0.1.0/atelier/losses/epsilon.py +31 -0
- atelier_diffusion-0.1.0/atelier/losses/flow_matching.py +70 -0
- atelier_diffusion-0.1.0/atelier/losses/utils.py +174 -0
- atelier_diffusion-0.1.0/atelier/registry.py +51 -0
- atelier_diffusion-0.1.0/atelier/train.py +270 -0
- atelier_diffusion-0.1.0/atelier/trainer.py +450 -0
- atelier_diffusion-0.1.0/atelier_diffusion.egg-info/PKG-INFO +378 -0
- atelier_diffusion-0.1.0/atelier_diffusion.egg-info/SOURCES.txt +39 -0
- atelier_diffusion-0.1.0/atelier_diffusion.egg-info/dependency_links.txt +1 -0
- atelier_diffusion-0.1.0/atelier_diffusion.egg-info/requires.txt +20 -0
- atelier_diffusion-0.1.0/atelier_diffusion.egg-info/top_level.txt +1 -0
- atelier_diffusion-0.1.0/pyproject.toml +60 -0
- atelier_diffusion-0.1.0/setup.cfg +4 -0
- atelier_diffusion-0.1.0/tests/test_adapters.py +172 -0
- atelier_diffusion-0.1.0/tests/test_config.py +40 -0
- atelier_diffusion-0.1.0/tests/test_data.py +352 -0
- atelier_diffusion-0.1.0/tests/test_losses.py +494 -0
- atelier_diffusion-0.1.0/tests/test_train_cli.py +179 -0
- atelier_diffusion-0.1.0/tests/test_trainer.py +398 -0
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Schneewolf Labs
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,378 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: atelier-diffusion
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Simple, multi-GPU diffusion model fine-tuning library (Schneewolf Labs)
|
|
5
|
+
Author: Schneewolf Labs
|
|
6
|
+
License: MIT License
|
|
7
|
+
|
|
8
|
+
Copyright (c) 2025 Schneewolf Labs
|
|
9
|
+
|
|
10
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
11
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
12
|
+
in the Software without restriction, including without limitation the rights
|
|
13
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
14
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
15
|
+
furnished to do so, subject to the following conditions:
|
|
16
|
+
|
|
17
|
+
The above copyright notice and this permission notice shall be included in all
|
|
18
|
+
copies or substantial portions of the Software.
|
|
19
|
+
|
|
20
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
21
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
22
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
23
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
24
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
25
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
26
|
+
SOFTWARE.
|
|
27
|
+
|
|
28
|
+
Project-URL: Homepage, https://github.com/Schneewolf-Labs/atelier
|
|
29
|
+
Project-URL: Repository, https://github.com/Schneewolf-Labs/atelier
|
|
30
|
+
Project-URL: Issues, https://github.com/Schneewolf-Labs/atelier/issues
|
|
31
|
+
Keywords: diffusion,lora,qwen-image,sdxl,fine-tuning,schneewolf-labs,atelier
|
|
32
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
33
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
34
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
35
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
36
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
37
|
+
Requires-Python: >=3.10
|
|
38
|
+
Description-Content-Type: text/markdown
|
|
39
|
+
License-File: LICENSE
|
|
40
|
+
Requires-Dist: accelerate>=0.24.0
|
|
41
|
+
Requires-Dist: peft>=0.6.0
|
|
42
|
+
Requires-Dist: datasets>=2.14.0
|
|
43
|
+
Requires-Dist: tqdm>=4.60.0
|
|
44
|
+
Requires-Dist: Pillow>=9.0.0
|
|
45
|
+
Requires-Dist: numpy>=1.24.0
|
|
46
|
+
Provides-Extra: quantization
|
|
47
|
+
Requires-Dist: bitsandbytes>=0.41.0; extra == "quantization"
|
|
48
|
+
Provides-Extra: logging
|
|
49
|
+
Requires-Dist: wandb>=0.15.0; extra == "logging"
|
|
50
|
+
Provides-Extra: yaml
|
|
51
|
+
Requires-Dist: PyYAML>=6.0; extra == "yaml"
|
|
52
|
+
Provides-Extra: dev
|
|
53
|
+
Requires-Dist: pytest>=7.0; extra == "dev"
|
|
54
|
+
Requires-Dist: ruff>=0.1.0; extra == "dev"
|
|
55
|
+
Requires-Dist: PyYAML>=6.0; extra == "dev"
|
|
56
|
+
Dynamic: license-file
|
|
57
|
+
|
|
58
|
+
# 🎨 Atelier 🔨
|
|
59
|
+
|
|
60
|
+
A simple, multi-GPU diffusion model fine-tuning library. One training loop, pluggable adapters and loss functions.
|
|
61
|
+
|
|
62
|
+
Sister project to [Grimoire](https://github.com/Schneewolf-Labs/Grimoire) (LLM fine-tuning). Both serve as training engines for [Merlina](https://github.com/Schneewolf-Labs/Merlina).
|
|
63
|
+
|
|
64
|
+
## Why
|
|
65
|
+
|
|
66
|
+
Diffusion model training scripts tend to be monolithic — model loading, data processing, the training loop, and architecture-specific forward passes all tangled together. Switching from SDXL to Qwen-Image-Edit means rewriting the whole script.
|
|
67
|
+
|
|
68
|
+
Atelier separates what varies (model architecture, training objective) from what doesn't (the training loop, multi-GPU, checkpointing, logging). Adding a new model means writing an adapter. Adding a new training objective means writing a loss function. The trainer never changes.
|
|
69
|
+
|
|
70
|
+
## Install
|
|
71
|
+
|
|
72
|
+
```bash
|
|
73
|
+
pip install -e .
|
|
74
|
+
|
|
75
|
+
# With optional dependencies
|
|
76
|
+
pip install -e ".[quantization]" # bitsandbytes for 8-bit optimizers
|
|
77
|
+
pip install -e ".[logging]" # wandb
|
|
78
|
+
pip install -e ".[quantization,logging,yaml]" # every optional runtime extra (no "all" extra is defined)
|
|
79
|
+
```
|
|
80
|
+
|
|
81
|
+
## Quick start
|
|
82
|
+
|
|
83
|
+
### Qwen-Image-Edit LoRA (flow matching)
|
|
84
|
+
|
|
85
|
+
```python
|
|
86
|
+
from peft import LoraConfig
|
|
87
|
+
from atelier import AtelierTrainer, TrainingConfig
|
|
88
|
+
from atelier.adapters import QwenEditAdapter
|
|
89
|
+
from atelier.losses import FlowMatchingLoss
|
|
90
|
+
from atelier.data import EditingDataset, cache_embeddings
|
|
91
|
+
|
|
92
|
+
# Load adapter (handles model, VAE, text encoder, scheduler)
|
|
93
|
+
adapter = QwenEditAdapter("Qwen/Qwen-Image-Edit")
|
|
94
|
+
|
|
95
|
+
# Pre-compute embeddings to save VRAM during training
|
|
96
|
+
text_emb, target_emb, control_emb = cache_embeddings(
|
|
97
|
+
raw_dataset, adapter, cache_dir="./output/cache",
|
|
98
|
+
)
|
|
99
|
+
adapter.free_encoders() # reclaim VRAM
|
|
100
|
+
|
|
101
|
+
dataset = EditingDataset(
|
|
102
|
+
raw_dataset,
|
|
103
|
+
cached_text_embeddings=text_emb,
|
|
104
|
+
cached_target_embeddings=target_emb,
|
|
105
|
+
cached_control_embeddings=control_emb,
|
|
106
|
+
)
|
|
107
|
+
|
|
108
|
+
trainer = AtelierTrainer(
|
|
109
|
+
adapter=adapter,
|
|
110
|
+
config=TrainingConfig(
|
|
111
|
+
output_dir="./output",
|
|
112
|
+
num_epochs=50,
|
|
113
|
+
batch_size=1,
|
|
114
|
+
learning_rate=1e-4,
|
|
115
|
+
gradient_accumulation_steps=2,
|
|
116
|
+
),
|
|
117
|
+
loss_fn=FlowMatchingLoss(),
|
|
118
|
+
train_dataset=dataset,
|
|
119
|
+
peft_config=LoraConfig(
|
|
120
|
+
r=64,
|
|
121
|
+
lora_alpha=128,
|
|
122
|
+
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
|
123
|
+
),
|
|
124
|
+
)
|
|
125
|
+
|
|
126
|
+
trainer.train()
|
|
127
|
+
trainer.save_model("./my-lora")
|
|
128
|
+
```
|
|
129
|
+
|
|
130
|
+
### Qwen-Image LoRA (text-to-image, flow matching)
|
|
131
|
+
|
|
132
|
+
Same loss as Qwen-Image-Edit; the adapter differs because the text encoder
|
|
133
|
+
is not vision-conditioned and the transformer sees only the noised target
|
|
134
|
+
(no control image concat).
|
|
135
|
+
|
|
136
|
+
```python
|
|
137
|
+
from peft import LoraConfig
|
|
138
|
+
from atelier import AtelierTrainer, TrainingConfig
|
|
139
|
+
from atelier.adapters import QwenImageAdapter
|
|
140
|
+
from atelier.losses import FlowMatchingLoss
|
|
141
|
+
from atelier.data import EditingDataset, cache_embeddings
|
|
142
|
+
|
|
143
|
+
adapter = QwenImageAdapter("Qwen/Qwen-Image")
|
|
144
|
+
|
|
145
|
+
# Dataset only needs (prompt, chosen) — no "rejected" column.
|
|
146
|
+
text_emb, target_emb, _ = cache_embeddings(
|
|
147
|
+
raw_dataset, adapter, cache_dir="./output/cache",
|
|
148
|
+
)
|
|
149
|
+
adapter.free_encoders()
|
|
150
|
+
|
|
151
|
+
dataset = EditingDataset(
|
|
152
|
+
raw_dataset,
|
|
153
|
+
cached_text_embeddings=text_emb,
|
|
154
|
+
cached_target_embeddings=target_emb,
|
|
155
|
+
)
|
|
156
|
+
|
|
157
|
+
trainer = AtelierTrainer(
|
|
158
|
+
adapter=adapter,
|
|
159
|
+
config=TrainingConfig(output_dir="./output", num_epochs=8, batch_size=1),
|
|
160
|
+
loss_fn=FlowMatchingLoss(),
|
|
161
|
+
train_dataset=dataset,
|
|
162
|
+
peft_config=LoraConfig(
|
|
163
|
+
r=32, lora_alpha=64,
|
|
164
|
+
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
|
165
|
+
),
|
|
166
|
+
)
|
|
167
|
+
trainer.train()
|
|
168
|
+
trainer.save_model("./my-qwen-image-lora")
|
|
169
|
+
```
|
|
170
|
+
|
|
171
|
+
### SDXL DPO (preference optimization)
|
|
172
|
+
|
|
173
|
+
Same trainer, different adapter and loss function.
|
|
174
|
+
|
|
175
|
+
```python
|
|
176
|
+
from atelier import AtelierTrainer, TrainingConfig
from atelier.adapters import SDXLAdapter
|
|
177
|
+
from atelier.losses import DiffusionDPOLoss
|
|
178
|
+
from atelier.data import GenerationDataset
|
|
179
|
+
|
|
180
|
+
adapter = SDXLAdapter(
|
|
181
|
+
"stabilityai/stable-diffusion-xl-base-1.0",
|
|
182
|
+
weights="/path/to/model.safetensors",
|
|
183
|
+
)
|
|
184
|
+
adapter.freeze_layers(strategy="color_blocks", layers="0,1")
|
|
185
|
+
|
|
186
|
+
dataset = GenerationDataset(
|
|
187
|
+
raw_dataset,
|
|
188
|
+
tokenizer=adapter.tokenizer,
|
|
189
|
+
tokenizer_2=adapter.tokenizer_2,
|
|
190
|
+
)
|
|
191
|
+
|
|
192
|
+
trainer = AtelierTrainer(
|
|
193
|
+
adapter=adapter,
|
|
194
|
+
config=TrainingConfig(
|
|
195
|
+
output_dir="./output",
|
|
196
|
+
num_epochs=10,
|
|
197
|
+
batch_size=1,
|
|
198
|
+
learning_rate=2e-6,
|
|
199
|
+
optimizer="adamw_8bit",
|
|
200
|
+
mixed_precision="fp16",
|
|
201
|
+
),
|
|
202
|
+
loss_fn=DiffusionDPOLoss(beta=0.4, sft_weight=0.3),
|
|
203
|
+
train_dataset=dataset,
|
|
204
|
+
)
|
|
205
|
+
|
|
206
|
+
trainer.train()
|
|
207
|
+
trainer.save_model("./my-sdxl")
|
|
208
|
+
```
|
|
209
|
+
|
|
210
|
+
### With LoRA
|
|
211
|
+
|
|
212
|
+
Pass a `peft_config` and Atelier handles the rest.
|
|
213
|
+
|
|
214
|
+
```python
|
|
215
|
+
from peft import LoraConfig
|
|
216
|
+
|
|
217
|
+
trainer = AtelierTrainer(
|
|
218
|
+
adapter=adapter,
|
|
219
|
+
config=TrainingConfig(...),
|
|
220
|
+
loss_fn=FlowMatchingLoss(),
|
|
221
|
+
train_dataset=dataset,
|
|
222
|
+
peft_config=LoraConfig(
|
|
223
|
+
r=64,
|
|
224
|
+
lora_alpha=128,
|
|
225
|
+
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
|
|
226
|
+
),
|
|
227
|
+
)
|
|
228
|
+
```
|
|
229
|
+
|
|
230
|
+
## Guides
|
|
231
|
+
|
|
232
|
+
- **[Loss Formulas](docs/loss-formulas.md)** — Math for flow matching and diffusion DPO
|
|
233
|
+
- **[Adapters](docs/adapters.md)** — Writing a custom adapter for a new model architecture
|
|
234
|
+
- **[Callbacks](docs/callbacks.md)** — Hooking into the training loop
|
|
235
|
+
- **[Multi-GPU and DeepSpeed](docs/deepspeed.md)** — Distributed training setup
|
|
236
|
+
|
|
237
|
+
## YAML config + CLI
|
|
238
|
+
|
|
239
|
+
For orchestrators (e.g. [Merlina](https://github.com/Schneewolf-Labs/Merlina)) or
|
|
240
|
+
when you just don't want to write a Python wrapper per run, train from a YAML config:
|
|
241
|
+
|
|
242
|
+
```bash
|
|
243
|
+
pip install -e ".[yaml]"
|
|
244
|
+
python -m atelier.train --config configs/qwen_image_lora_example.yaml
|
|
245
|
+
|
|
246
|
+
# Override anything on the CLI (JSON-decoded values):
|
|
247
|
+
python -m atelier.train --config configs/my.yaml \
|
|
248
|
+
--set training.num_epochs=2 \
|
|
249
|
+
--set training.output_dir=./output-quick \
|
|
250
|
+
--set 'peft.target_modules=["to_q","to_v"]'
|
|
251
|
+
```
|
|
252
|
+
|
|
253
|
+
The YAML schema mirrors the Python API one-for-one — `model.adapter` picks an
|
|
254
|
+
adapter from `atelier.registry.ADAPTERS`, `loss.type` picks from `LOSSES`,
|
|
255
|
+
`peft` becomes a `LoraConfig`, `training` becomes a `TrainingConfig`, and
|
|
256
|
+
`dataset` accepts an HF hub name, a local JSONL, or a `load_from_disk` path.
|
|
257
|
+
See `atelier/train.py` for the full schema and `configs/qwen_image_lora_example.yaml`
|
|
258
|
+
for a worked example.
|
|
259
|
+
|
|
260
|
+
## Multi-GPU
|
|
261
|
+
|
|
262
|
+
No code changes. Configure with `accelerate` and launch:
|
|
263
|
+
|
|
264
|
+
```bash
|
|
265
|
+
accelerate config
|
|
266
|
+
accelerate launch --multi_gpu --num_processes 4 train.py
|
|
267
|
+
accelerate launch --use_deepspeed --deepspeed_config_file ds_config.json train.py
|
|
268
|
+
```
|
|
269
|
+
|
|
270
|
+
## Callbacks
|
|
271
|
+
|
|
272
|
+
Subclass `TrainerCallback` and override the hooks you need:
|
|
273
|
+
|
|
274
|
+
```python
|
|
275
|
+
from atelier import TrainerCallback
|
|
276
|
+
|
|
277
|
+
class MyCallback(TrainerCallback):
|
|
278
|
+
def on_step_end(self, trainer, step, loss, metrics):
|
|
279
|
+
if should_stop():
|
|
280
|
+
trainer.request_stop()
|
|
281
|
+
|
|
282
|
+
def on_log(self, trainer, metrics):
|
|
283
|
+
print(f"Step {trainer.global_step}: {metrics}")
|
|
284
|
+
|
|
285
|
+
trainer = AtelierTrainer(..., callbacks=[MyCallback()])
|
|
286
|
+
```
|
|
287
|
+
|
|
288
|
+
Available hooks: `on_train_begin`, `on_train_end`, `on_epoch_begin`, `on_epoch_end`, `on_step_end`, `on_log`, `on_evaluate`, `on_save`.
|
|
289
|
+
|
|
290
|
+
## Configuration
|
|
291
|
+
|
|
292
|
+
`TrainingConfig` fields with defaults:
|
|
293
|
+
|
|
294
|
+
| Field | Default | Description |
|
|
295
|
+
|---|---|---|
|
|
296
|
+
| `output_dir` | `"./output"` | Checkpoints and saved models |
|
|
297
|
+
| `num_epochs` | `3` | Number of training epochs |
|
|
298
|
+
| `batch_size` | `1` | Per-device batch size |
|
|
299
|
+
| `gradient_accumulation_steps` | `1` | Steps before optimizer update |
|
|
300
|
+
| `learning_rate` | `1e-4` | Peak learning rate |
|
|
301
|
+
| `weight_decay` | `0.01` | Weight decay coefficient (L2-style penalty; decoupled under the AdamW variants) |
|
|
302
|
+
| `warmup_ratio` | `0.1` | Fraction of steps for LR warmup |
|
|
303
|
+
| `warmup_steps` | `0` | Overrides `warmup_ratio` if > 0 |
|
|
304
|
+
| `max_grad_norm` | `1.0` | Gradient clipping |
|
|
305
|
+
| `mixed_precision` | `"bf16"` | `"no"`, `"fp16"`, or `"bf16"` |
|
|
306
|
+
| `gradient_checkpointing` | `True` | Trade compute for memory |
|
|
307
|
+
| `optimizer` | `"adamw"` | See supported optimizers below |
|
|
308
|
+
| `lr_scheduler` | `"cosine"` | `"linear"`, `"cosine"`, `"constant"`, `"constant_with_warmup"` |
|
|
309
|
+
| `logging_steps` | `10` | Log metrics every N steps |
|
|
310
|
+
| `eval_steps` | `None` | Evaluate every N steps |
|
|
311
|
+
| `save_steps` | `None` | Checkpoint every N steps |
|
|
312
|
+
| `save_total_limit` | `2` | Max checkpoints to keep |
|
|
313
|
+
| `save_on_epoch_end` | `True` | Checkpoint after each epoch |
|
|
314
|
+
| `resume_from_checkpoint` | `None` | Path to resume from |
|
|
315
|
+
| `seed` | `42` | Random seed |
|
|
316
|
+
| `log_with` | `None` | `"wandb"` for W&B tracking |
|
|
317
|
+
|
|
318
|
+
**Supported optimizers:** `adamw`, `adamw_8bit`, `paged_adamw_8bit`, `adafactor`, `sgd`
|
|
319
|
+
|
|
320
|
+
## Architecture
|
|
321
|
+
|
|
322
|
+
```
|
|
323
|
+
atelier/
|
|
324
|
+
├── trainer.py # AtelierTrainer — the training loop
|
|
325
|
+
├── config.py # TrainingConfig dataclass
|
|
326
|
+
├── callbacks.py # TrainerCallback base class
|
|
327
|
+
├── adapters/
|
|
328
|
+
│ ├── base.py # ModelAdapter protocol
|
|
329
|
+
│ ├── qwen_edit.py # Qwen-Image-Edit (DiT + flow matching, image-conditioned)
|
|
330
|
+
│ ├── qwen_image.py # Qwen-Image (DiT + flow matching, text-to-image)
|
|
331
|
+
│ └── sdxl.py # SDXL (UNet + DDPM)
|
|
332
|
+
├── losses/
|
|
333
|
+
│ ├── flow_matching.py # Flow matching MSE
|
|
334
|
+
│ ├── diffusion_dpo.py # DPO + SFT regularization
│ └── ... # also diffusion_{cpo,ipo,kto,orpo,simpo}.py, epsilon.py, utils.py
|
|
335
|
+
└── data/
|
|
336
|
+
├── editing.py # Paired image editing dataset
|
|
337
|
+
├── generation.py # Text-to-image dataset
|
|
338
|
+
└── cache.py # Embedding pre-computation
|
|
339
|
+
```
|
|
340
|
+
|
|
341
|
+
### How it fits together
|
|
342
|
+
|
|
343
|
+
The **adapter** encapsulates everything that varies per model architecture — loading, encoding, the forward pass, and saving. In Grimoire (LLM training), every model has the same forward signature (`model(input_ids)` → logits). In diffusion training, forward passes vary wildly: Qwen-Image-Edit needs latent packing, control image concatenation, and RoPE shapes; SDXL needs dual CLIP conditioning and time embeddings. The adapter hides this.
|
|
344
|
+
|
|
345
|
+
The **loss function** orchestrates the training objective — sampling noise and timesteps, calling the adapter's forward pass, and computing the loss. Flow matching predicts the velocity field; DPO compares noise predictions for chosen vs rejected images.
|
|
346
|
+
|
|
347
|
+
The **trainer** owns the loop — optimizer, gradient accumulation, checkpointing, logging. It calls `loss_fn(adapter, model, batch)` and never needs to know what model architecture or training objective is being used.
|
|
348
|
+
|
|
349
|
+
### Loss function interface
|
|
350
|
+
|
|
351
|
+
```python
|
|
352
|
+
class MyLoss:
|
|
353
|
+
def __call__(self, adapter, model, batch, training=True):
|
|
354
|
+
# Use adapter for noise sampling, forward pass, target computation
|
|
355
|
+
return loss, metrics_dict
|
|
356
|
+
|
|
357
|
+
def create_collator(self):
|
|
358
|
+
return MyCollator()
|
|
359
|
+
```
|
|
360
|
+
|
|
361
|
+
### Adapter interface
|
|
362
|
+
|
|
363
|
+
```python
|
|
364
|
+
class MyAdapter(ModelAdapter):
|
|
365
|
+
def model(self): ... # The trainable model
|
|
366
|
+
def encode_images(self): ... # VAE encode
|
|
367
|
+
def encode_text(self): ... # Text encode
|
|
368
|
+
def sample_timesteps(self): ... # Timestep sampling
|
|
369
|
+
def add_noise(self): ... # Create noisy input
|
|
370
|
+
def compute_target(self): ... # What model should predict
|
|
371
|
+
def forward(self): ... # Architecture-specific forward
|
|
372
|
+
def save_lora(self): ... # Save LoRA weights
|
|
373
|
+
def save_model(self): ... # Save full model
|
|
374
|
+
```
|
|
375
|
+
|
|
376
|
+
## License
|
|
377
|
+
|
|
378
|
+
MIT
|