atelier-diffusion 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
atelier/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .callbacks import TrainerCallback as TrainerCallback
2
+ from .config import TrainingConfig as TrainingConfig
3
+ from .trainer import AtelierTrainer as AtelierTrainer
4
+
5
# Package version string; mirrors the wheel release (atelier-diffusion 0.1.0).
__version__: str = "0.1.0"
@@ -0,0 +1,4 @@
1
+ from .base import ModelAdapter as ModelAdapter
2
+ from .qwen_edit import QwenEditAdapter as QwenEditAdapter
3
+ from .qwen_image import QwenImageAdapter as QwenImageAdapter
4
+ from .sdxl import SDXLAdapter as SDXLAdapter
@@ -0,0 +1,99 @@
1
+ import torch
2
+
3
+
4
class ModelAdapter:
    """Base class for model adapters.

    Adapters encapsulate everything that varies per model architecture:
    loading, encoding, forward pass, noise scheduling, and saving.
    Every hook without a sensible default raises ``NotImplementedError``;
    subclasses override the ones they support.
    """

    @property
    def model(self):
        """The trainable model (transformer or UNet)."""
        raise NotImplementedError

    @property
    def noise_scheduler(self):
        """The noise scheduler for this model."""
        raise NotImplementedError

    @property
    def device(self):
        """The active device for encoders + model.

        Returns ``self._device`` when the adapter defines it, otherwise
        ``"cpu"``. Override if the adapter doesn't store ``self._device``.
        Used by cache_embeddings + other utilities that need a device to
        run on but can't ask ``adapter.model.device`` (the model may not be
        loaded yet when only encoders are present).
        """
        return getattr(self, "_device", "cpu")

    def encode_images(self, images, device=None):
        """Encode PIL images to latent space via VAE.

        Returns a tensor of latents.

        NOTE: some subclasses (e.g. ``QwenEditAdapter``) accept extra
        parameters positioned before ``device``; callers should pass
        ``device`` by keyword.
        """
        raise NotImplementedError

    def encode_image_tensor(self, image_tensor, device=None):
        """Encode a batch of image tensors [B, C, H, W] in [-1, 1] to latents.

        Used by loss functions for on-the-fly encoding from pre-processed tensors.
        """
        raise NotImplementedError

    def encode_text(self, prompts, device=None, **kwargs):
        """Encode text prompts to embeddings.

        Returns a dict of tensors (prompt_embeds, masks, etc.).
        """
        raise NotImplementedError

    def sample_timesteps(self, batch_size, device):
        """Sample timesteps and compute sigmas.

        Returns (timesteps, sigmas) tensors.
        """
        raise NotImplementedError

    def add_noise(self, latents, noise, timesteps, sigmas):
        """Create noisy input from clean latents.

        Flow matching: ``(1 - sigma) * latents + sigma * noise``
        DDPM: ``scheduler.add_noise(latents, noise, timesteps)``
        """
        raise NotImplementedError

    def compute_target(self, noise, latents, sigmas):
        """Compute what the model should predict.

        Flow matching: ``noise - latents``
        Epsilon prediction: ``noise``
        V-prediction: ``alpha_t * noise - sigma_t * latents`` where
        ``alpha_t = sqrt(alphas_cumprod[t])`` and
        ``sigma_t = sqrt(1 - alphas_cumprod[t])`` (the velocity computed by
        diffusers' ``DDPMScheduler.get_velocity``).
        """
        raise NotImplementedError

    def forward(self, model, noisy_latents, timesteps, batch):
        """Run the model forward pass.

        Handles architecture-specific kwargs (packing, conditioning, etc.).
        Returns the model prediction tensor.
        """
        raise NotImplementedError

    def save_lora(self, model, path):
        """Save LoRA weights in architecture-specific format."""
        raise NotImplementedError

    def save_model(self, model, path):
        """Save full model weights."""
        raise NotImplementedError

    @torch.no_grad()
    def free_encoders(self):
        """Free VAE and text encoder(s) from memory.

        Call after pre-computing embeddings to reclaim VRAM before training.
        The base implementation is a no-op.
        """
        pass
@@ -0,0 +1,302 @@
1
+ import copy
2
+ import gc
3
+ import logging
4
+
5
+ import numpy as np
6
+ import torch
7
+ from PIL import Image
8
+
9
+ from .base import ModelAdapter
10
+
11
logger = logging.getLogger(__name__)


class QwenEditAdapter(ModelAdapter):
    """Adapter for Qwen-Image-Edit (DiT + video VAE + flow matching).

    Handles:
    - Loading QwenImageTransformer2DModel, AutoencoderKLQwenImage, QwenImageEditPipeline
    - Text encoding via pipeline.encode_prompt (with control image conditioning)
    - Image encoding via video VAE ([B, C, 1, H, W] shape)
    - Latent normalization, packing/unpacking
    - Flow matching timestep sampling and noise schedule
    - LoRA saving via QwenImagePipeline.save_lora_weights
    """

    def __init__(self, pretrained_path, device="cuda", dtype=None):
        """Load every Qwen-Image-Edit component from ``pretrained_path``.

        Args:
            pretrained_path: Path/identifier handed to the diffusers
                ``from_pretrained`` loaders; must contain ``vae``,
                ``transformer`` and ``scheduler`` subfolders.
            device: Device all components are moved to.
            dtype: Weight dtype for pipeline, VAE and transformer
                (defaults to ``torch.bfloat16``).
        """
        from diffusers import (
            AutoencoderKLQwenImage,
            FlowMatchEulerDiscreteScheduler,
            QwenImageEditPipeline,
            QwenImageTransformer2DModel,
        )

        self._dtype = dtype or torch.bfloat16

        # Load text encoding pipeline (no transformer/VAE — just text encoder)
        self._pipeline = QwenImageEditPipeline.from_pretrained(
            pretrained_path, transformer=None, vae=None, torch_dtype=self._dtype,
        )
        self._pipeline.to(device)

        # Load VAE (frozen; only used for encoding)
        self._vae = AutoencoderKLQwenImage.from_pretrained(pretrained_path, subfolder="vae")
        self._vae.to(device, dtype=self._dtype)
        self._vae.eval()
        self._vae.requires_grad_(False)

        # Load VAE config for latent normalization
        self._vae_config = AutoencoderKLQwenImage.load_config(pretrained_path, subfolder="vae")
        self._init_vae_normalization()

        # Compute VAE scale factor. Diffusers' config key is historically
        # misspelled ("temperal"), so both spellings are accepted.
        if "temporal_downsample" in self._vae_config:
            self._vae_scale_factor = 2 ** len(self._vae_config["temporal_downsample"])
        elif "temperal_downsample" in self._vae_config:
            self._vae_scale_factor = 2 ** len(self._vae_config["temperal_downsample"])
        else:
            logger.warning("Could not find temporal_downsample in VAE config, using default scale factor of 8")
            self._vae_scale_factor = 8

        # Load transformer (the trainable model)
        self._model = QwenImageTransformer2DModel.from_pretrained(pretrained_path, subfolder="transformer")
        self._model.to(device, dtype=self._dtype)

        # Load noise scheduler. The private copy is only read (timesteps /
        # sigmas lookup) so training-time sampling is isolated from any
        # set_timesteps() calls made on the public scheduler.
        self._scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(pretrained_path, subfolder="scheduler")
        self._scheduler_copy = copy.deepcopy(self._scheduler)

        self._device = device

        # Warm up VAE; the output is discarded, so skip autograd bookkeeping.
        dummy = torch.zeros(1, 3, 1, 64, 64, device=device, dtype=self._dtype)
        with torch.no_grad():
            self._vae.encode(dummy)

    def _init_vae_normalization(self):
        """Pre-compute latent normalization tensors from VAE config.

        ``_latents_std`` holds the *reciprocal* of the configured std so
        ``normalize_latents`` can multiply rather than divide. Both tensors
        are viewed as ``(1, 1, z_dim, 1, 1)``, i.e. they broadcast over
        dim 2. NOTE(review): that assumes latents are laid out as
        [B, T, C, H, W] when normalized, while ``encode_images`` returns
        [B, C, 1, H, W] — confirm the caller permutes before normalizing.
        """
        cfg = self._vae_config
        if "latents_mean" in cfg and "latents_std" in cfg and "z_dim" in cfg:
            self._latents_mean = torch.tensor(cfg["latents_mean"]).view(1, 1, cfg["z_dim"], 1, 1)
            self._latents_std = 1.0 / torch.tensor(cfg["latents_std"]).view(1, 1, cfg["z_dim"], 1, 1)
            self._has_normalization = True
        else:
            self._has_normalization = False
            logger.warning("VAE config missing normalization parameters, skipping latent normalization")

    @property
    def model(self):
        """The trainable QwenImageTransformer2DModel."""
        return self._model

    @property
    def noise_scheduler(self):
        """The FlowMatchEulerDiscreteScheduler loaded with the model."""
        return self._scheduler

    def encode_image_tensor(self, image_tensor, device=None):
        """Encode image tensors [B, C, H, W] in [-1, 1] to latents via video VAE.

        Returns latents of shape [B, C', 1, H', W'] (un-normalized).
        """
        device = device or self._device
        # Video VAE expects [B, C, 1, H, W]
        pixel_values = image_tensor.unsqueeze(2).to(device=device, dtype=self._dtype)
        with torch.no_grad():
            latents = self._vae.encode(pixel_values).latent_dist.sample()
        return latents

    def encode_images(self, images, height=None, width=None, device=None):
        """Encode images to latents via video VAE.

        Args:
            images: Iterable of PIL images, file paths (``str`` or
                ``os.PathLike``), or array-likes accepted by
                ``Image.fromarray(np.uint8(...))``.
            height, width: When both are truthy, each image is resized with
                the pipeline's image processor before encoding.
            device: Device to encode on (defaults to the adapter device).

        Returns:
            Tensor of shape [B, C, 1, H', W'] (video VAE format, un-normalized).

        NOTE: unlike the base-class signature, ``height``/``width`` come
        before ``device``; pass ``device`` by keyword.
        """
        import os

        device = device or self._device
        latents_list = []

        for image in images:
            if isinstance(image, (str, os.PathLike)):
                # Close the file handle once the RGB copy has been decoded.
                with Image.open(image) as opened:
                    image = opened.convert("RGB")
            else:
                if not isinstance(image, Image.Image):
                    image = Image.fromarray(np.uint8(image))
                image = image.convert("RGB")

            if height and width:
                image = self._pipeline.image_processor.resize(image, height, width)

            # Convert to tensor [C, H, W] in [-1, 1]
            img_np = np.array(image).astype(np.float32)
            img_tensor = torch.from_numpy(img_np / 127.5 - 1.0).permute(2, 0, 1)

            # Video VAE expects [B, C, 1, H, W]
            pixel_values = img_tensor.unsqueeze(0).unsqueeze(2).to(device=device, dtype=self._dtype)
            with torch.no_grad():
                latents = self._vae.encode(pixel_values).latent_dist.sample()
            latents_list.append(latents[0])

        return torch.stack(latents_list)

    def encode_text(self, prompts, images=None, device=None, max_sequence_length=1024, **kwargs):
        """Encode text prompts via QwenImageEditPipeline.encode_prompt.

        Args:
            prompts: List of prompt strings.
            images: Optional list of control images for conditioning.
                NOTE(review): only ``images[0]`` is passed to the pipeline,
                so every prompt in the batch is conditioned on the first image.
            device: Target device.
            max_sequence_length: Max token length.

        Returns dict with 'prompt_embeds' and, when the pipeline returns one,
        'prompt_embeds_mask'.
        """
        device = device or self._device
        image = images[0] if images else None

        prompt_embeds, prompt_embeds_mask = self._pipeline.encode_prompt(
            image=image,
            prompt=prompts,
            device=device,
            num_images_per_prompt=1,
            max_sequence_length=max_sequence_length,
        )

        result = {"prompt_embeds": prompt_embeds}
        if prompt_embeds_mask is not None:
            result["prompt_embeds_mask"] = prompt_embeds_mask
        return result

    def normalize_latents(self, latents):
        """Apply latent normalization from VAE config: ``(latents - mean) / std``.

        Returns ``latents`` unchanged when the VAE config has no
        normalization parameters.
        """
        if not self._has_normalization:
            return latents
        mean = self._latents_mean.to(latents.device, latents.dtype)
        inv_std = self._latents_std.to(latents.device, latents.dtype)  # stored as 1 / std
        return (latents - mean) * inv_std

    def sample_timesteps(self, batch_size, device):
        """Sample flow-matching timesteps and their sigmas.

        Draws ``u`` with ``weighting_scheme="none"`` and maps it to indices
        into the scheduler's training schedule.

        Returns:
            ``(timesteps, sigmas)``: ``timesteps`` has shape [batch_size] on
            ``device``; ``sigmas`` is float32 with shape
            [batch_size, 1, 1, 1, 1] so it broadcasts over 5-D latents.
        """
        from diffusers.training_utils import compute_density_for_timestep_sampling

        u = compute_density_for_timestep_sampling(
            weighting_scheme="none",
            batch_size=batch_size,
            logit_mean=0.0,
            logit_std=1.0,
            mode_scale=1.29,
        )
        schedule = self._scheduler_copy
        indices = (u * schedule.config.num_train_timesteps).long()
        timesteps = schedule.timesteps[indices].to(device=device)
        # The sigma for timesteps[i] lives at the same schedule index, so
        # gather it directly instead of searching the schedule per element
        # (which forces one host-device sync per sample).
        sigmas = schedule.sigmas[indices].to(device=device, dtype=torch.float32)
        sigmas = sigmas.flatten().reshape(-1, 1, 1, 1, 1)
        return timesteps, sigmas

    def _get_sigmas(self, timesteps, device, dtype=torch.float32, n_dim=5):
        """Look up the sigma for each timestep in the training schedule.

        Each entry of ``timesteps`` must match exactly one schedule entry
        (``.item()`` raises otherwise). The result is unsqueezed to
        ``n_dim`` dims so it broadcasts over latents.
        """
        sigmas = self._scheduler_copy.sigmas.to(device=device, dtype=dtype)
        schedule_timesteps = self._scheduler_copy.timesteps.to(device)
        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < n_dim:
            sigma = sigma.unsqueeze(-1)
        return sigma

    def add_noise(self, latents, noise, timesteps, sigmas):
        """Flow matching noise addition: (1 - sigma) * latents + sigma * noise."""
        sigmas = sigmas.to(latents.device, latents.dtype)
        return (1.0 - sigmas) * latents + sigmas * noise

    def compute_target(self, noise, latents, sigmas):
        """Flow matching target: noise - latents (``sigmas`` is unused)."""
        return noise - latents

    def forward(self, model, noisy_latents, timesteps, batch):
        """Run QwenImageTransformer2DModel with latent packing and control conditioning.

        ``noisy_latents`` and ``batch["control_latents"]`` are packed with
        ``shape[2]`` as the channel count, i.e. they are expected as
        [B, 1, C, H, W]. Target and control tokens are concatenated along
        the sequence axis; only the target tokens of the output are kept.

        NOTE(review): the prediction comes back from
        ``QwenImageEditPipeline._unpack_latents``, which in current diffusers
        returns [B, C, 1, H, W] — a different layout from ``noisy_latents``.
        Confirm the loss aligns it with ``compute_target``'s output.
        """
        from diffusers import QwenImageEditPipeline

        control_latents = batch["control_latents"].to(noisy_latents.device, noisy_latents.dtype)
        prompt_embeds = batch["prompt_embeds"].to(noisy_latents.device, noisy_latents.dtype)

        if "prompt_embeds_mask" in batch and isinstance(batch["prompt_embeds_mask"], torch.Tensor):
            prompt_mask = batch["prompt_embeds_mask"].to(dtype=torch.int32, device=noisy_latents.device)
        else:
            # No mask cached: treat every text token as valid.
            prompt_mask = torch.ones(
                prompt_embeds.shape[:2], dtype=torch.int32, device=noisy_latents.device,
            )

        bsz = noisy_latents.shape[0]

        # Pack latents into 2x2 patch tokens
        packed_noisy = QwenImageEditPipeline._pack_latents(
            noisy_latents, bsz,
            noisy_latents.shape[2], noisy_latents.shape[3], noisy_latents.shape[4],
        )
        packed_control = QwenImageEditPipeline._pack_latents(
            control_latents, bsz,
            control_latents.shape[2], control_latents.shape[3], control_latents.shape[4],
        )

        # Concatenate target + control along the token axis
        packed_input = torch.cat([packed_noisy, packed_control], dim=1)

        # Image shapes for RoPE (one entry per packed image, in token grid units)
        img_shapes = [
            [
                (1, noisy_latents.shape[3] // 2, noisy_latents.shape[4] // 2),
                (1, control_latents.shape[3] // 2, control_latents.shape[4] // 2),
            ]
        ] * bsz

        txt_seq_lens = prompt_mask.sum(dim=1).tolist()

        # Forward pass
        output = model(
            hidden_states=packed_input,
            timestep=timesteps / 1000,
            guidance=None,
            encoder_hidden_states_mask=prompt_mask,
            encoder_hidden_states=prompt_embeds,
            img_shapes=img_shapes,
            txt_seq_lens=txt_seq_lens,
            return_dict=False,
        )[0]

        # Keep only the prediction for the target tokens (first segment)
        output = output[:, :packed_noisy.size(1)]

        # Unpack back to a spatial latent tensor
        output = QwenImageEditPipeline._unpack_latents(
            output,
            height=noisy_latents.shape[3] * self._vae_scale_factor,
            width=noisy_latents.shape[4] * self._vae_scale_factor,
            vae_scale_factor=self._vae_scale_factor,
        )

        return output

    def save_lora(self, model, path):
        """Save LoRA weights in the format diffusers + PEFT can load back.

        Drops the legacy ``convert_state_dict_to_diffusers`` call (which
        rewrites PEFT-format keys into the pre-PEFT diffusers layout that
        modern ``pipe.load_lora_weights`` rejects) and explicitly strips
        the ``base_model.model.`` PEFT wrapper prefix that
        ``get_peft_model_state_dict`` retains in some versions. See the
        QwenImageAdapter version for the full bug-tour comment.
        """
        import os

        from diffusers import QwenImagePipeline
        from peft.utils import get_peft_model_state_dict

        os.makedirs(path, exist_ok=True)
        from .qwen_image import strip_peft_prefix
        state_dict = strip_peft_prefix(get_peft_model_state_dict(model))
        QwenImagePipeline.save_lora_weights(path, state_dict, safe_serialization=True)
        logger.info("LoRA weights saved to %s", path)

    def save_model(self, model, path):
        """Save full model weights to ``path`` as safetensors."""
        import os

        os.makedirs(path, exist_ok=True)
        model.save_pretrained(path, safe_serialization=True)
        logger.info("Model saved to %s", path)

    def free_encoders(self):
        """Free VAE and text encoding pipeline from memory.

        Afterwards ``self._vae`` and ``self._pipeline`` are ``None``, so the
        ``encode_*`` methods can no longer be used.
        """
        # Dropping the last references lets gc reclaim the modules.
        self._vae = None
        self._pipeline = None
        gc.collect()
        torch.cuda.empty_cache()
        logger.info("Freed VAE and text encoder from memory")