atelier-diffusion 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- atelier/__init__.py +5 -0
- atelier/adapters/__init__.py +4 -0
- atelier/adapters/base.py +99 -0
- atelier/adapters/qwen_edit.py +302 -0
- atelier/adapters/qwen_image.py +364 -0
- atelier/adapters/sdxl.py +286 -0
- atelier/callbacks.py +26 -0
- atelier/config.py +50 -0
- atelier/data/__init__.py +5 -0
- atelier/data/cache.py +135 -0
- atelier/data/editing.py +142 -0
- atelier/data/generation.py +93 -0
- atelier/losses/__init__.py +8 -0
- atelier/losses/diffusion_cpo.py +59 -0
- atelier/losses/diffusion_dpo.py +95 -0
- atelier/losses/diffusion_ipo.py +81 -0
- atelier/losses/diffusion_kto.py +109 -0
- atelier/losses/diffusion_orpo.py +72 -0
- atelier/losses/diffusion_simpo.py +54 -0
- atelier/losses/epsilon.py +31 -0
- atelier/losses/flow_matching.py +70 -0
- atelier/losses/utils.py +174 -0
- atelier/registry.py +51 -0
- atelier/train.py +270 -0
- atelier/trainer.py +450 -0
- atelier_diffusion-0.1.0.dist-info/METADATA +378 -0
- atelier_diffusion-0.1.0.dist-info/RECORD +30 -0
- atelier_diffusion-0.1.0.dist-info/WHEEL +5 -0
- atelier_diffusion-0.1.0.dist-info/licenses/LICENSE +21 -0
- atelier_diffusion-0.1.0.dist-info/top_level.txt +1 -0
atelier/__init__.py
ADDED
atelier/adapters/base.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class ModelAdapter:
    """Common interface implemented by every per-architecture model adapter.

    An adapter hides everything that differs between model families behind a
    single API: loading, image/text encoding, the forward pass, the noise
    schedule, and checkpoint saving. Apart from ``device`` and
    ``free_encoders``, every hook raises ``NotImplementedError`` and must be
    overridden by a concrete subclass.
    """

    # ------------------------------------------------------------------
    # Components
    # ------------------------------------------------------------------
    @property
    def model(self):
        """Trainable network (for example a transformer or a UNet)."""
        raise NotImplementedError()

    @property
    def noise_scheduler(self):
        """Noise scheduler paired with this model."""
        raise NotImplementedError()

    @property
    def device(self):
        """Device on which the encoders and the model run.

        Returns ``self._device`` when the adapter defines it and ``"cpu"``
        otherwise; adapters that keep their device elsewhere should override
        this. Helpers such as embedding caching use this property instead of
        ``adapter.model.device`` because the model may not be loaded yet when
        only the encoders are present.
        """
        try:
            return self._device
        except AttributeError:
            return "cpu"

    # ------------------------------------------------------------------
    # Encoding
    # ------------------------------------------------------------------
    def encode_images(self, images, device=None):
        """Map PIL images into latent space with the VAE; return a latent tensor."""
        raise NotImplementedError()

    def encode_image_tensor(self, image_tensor, device=None):
        """Map a ``[B, C, H, W]`` image batch with values in ``[-1, 1]`` to latents.

        Loss functions call this to encode already-preprocessed tensors on the fly.
        """
        raise NotImplementedError()

    def encode_text(self, prompts, device=None, **kwargs):
        """Turn text prompts into embeddings.

        Should return a dict of tensors (prompt embeddings, attention masks, ...).
        """
        raise NotImplementedError()

    # ------------------------------------------------------------------
    # Noise schedule
    # ------------------------------------------------------------------
    def sample_timesteps(self, batch_size, device):
        """Draw training timesteps and their sigmas; return ``(timesteps, sigmas)``."""
        raise NotImplementedError()

    def add_noise(self, latents, noise, timesteps, sigmas):
        """Build the noisy model input from clean latents.

        Typical implementations:
          * flow matching: ``(1 - sigma) * latents + sigma * noise``
          * DDPM: ``scheduler.add_noise(latents, noise, timesteps)``
        """
        raise NotImplementedError()

    def compute_target(self, noise, latents, sigmas):
        """Return the quantity the model is trained to predict.

        Typical targets:
          * flow matching: ``noise - latents``
          * epsilon prediction: ``noise``
          * v-prediction: ``alpha_t * noise - sigma_t * latents`` with
            ``alpha_t = sqrt(alphas_cumprod[t])`` and
            ``sigma_t = sqrt(1 - alphas_cumprod[t])``
            (cf. ``DDPMScheduler.get_velocity``).
        """
        raise NotImplementedError()

    # ------------------------------------------------------------------
    # Model execution
    # ------------------------------------------------------------------
    def forward(self, model, noisy_latents, timesteps, batch):
        """Run the network and return its prediction tensor.

        Responsible for any architecture-specific arguments (latent packing,
        conditioning inputs, ...).
        """
        raise NotImplementedError()

    # ------------------------------------------------------------------
    # Persistence / memory
    # ------------------------------------------------------------------
    def save_lora(self, model, path):
        """Write LoRA weights to ``path`` in the architecture's own format."""
        raise NotImplementedError()

    def save_model(self, model, path):
        """Write the full model weights to ``path``."""
        raise NotImplementedError()

    @torch.no_grad()
    def free_encoders(self):
        """Release the VAE and text encoder(s) to reclaim VRAM before training.

        Meant to be called once embeddings are pre-computed. The base
        implementation has nothing to free and does nothing.
        """
        return None
|
|
@@ -0,0 +1,302 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import gc
|
|
3
|
+
import logging
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import torch
|
|
7
|
+
from PIL import Image
|
|
8
|
+
|
|
9
|
+
from .base import ModelAdapter
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class QwenEditAdapter(ModelAdapter):
    """Adapter for Qwen-Image-Edit (DiT + video VAE + flow matching).

    Handles:
    - Loading QwenImageTransformer2DModel, AutoencoderKLQwenImage, QwenImageEditPipeline
    - Text encoding via pipeline.encode_prompt (with control image conditioning)
    - Image encoding via video VAE ([B, C, 1, H, W] shape)
    - Latent normalization, packing/unpacking
    - Flow matching timestep sampling and noise schedule
    - LoRA saving via QwenImagePipeline.save_lora_weights

    Latent layout: the encoders return the VAE's native ``[B, C, 1, H, W]``
    tensors, while ``forward`` packs its inputs with the channel axis at dim 2
    (``[B, 1, C, H, W]``). ``normalize_latents`` accepts either layout.
    """

    def __init__(self, pretrained_path, device="cuda", dtype=None):
        """Load the text pipeline, VAE, transformer and noise scheduler.

        Args:
            pretrained_path: Path or hub id of a diffusers-format Qwen-Image-Edit
                checkpoint (with ``vae``, ``transformer`` and ``scheduler`` subfolders).
            device: Device every component is moved to.
            dtype: Weight dtype; defaults to ``torch.bfloat16``.
        """
        from diffusers import (
            AutoencoderKLQwenImage,
            FlowMatchEulerDiscreteScheduler,
            QwenImageEditPipeline,
            QwenImageTransformer2DModel,
        )

        self._dtype = dtype or torch.bfloat16
        self._device = device

        # Load text encoding pipeline (no transformer/VAE — just text encoder)
        self._pipeline = QwenImageEditPipeline.from_pretrained(
            pretrained_path, transformer=None, vae=None, torch_dtype=self._dtype,
        )
        self._pipeline.to(device)
        # The text encoder only produces conditioning and is never trained here;
        # freeze it like the VAE so encoding never records autograd state.
        text_encoder = getattr(self._pipeline, "text_encoder", None)
        if text_encoder is not None:
            text_encoder.requires_grad_(False)

        # Load VAE (frozen, inference only)
        self._vae = AutoencoderKLQwenImage.from_pretrained(pretrained_path, subfolder="vae")
        self._vae.to(device, dtype=self._dtype)
        self._vae.eval()
        self._vae.requires_grad_(False)

        # Load VAE config for latent normalization
        self._vae_config = AutoencoderKLQwenImage.load_config(pretrained_path, subfolder="vae")
        self._init_vae_normalization()

        # Spatial scale factor = 2 ** number of downsampling stages. The config key
        # is spelled "temperal_downsample" in diffusers, so accept both spellings.
        if "temporal_downsample" in self._vae_config:
            self._vae_scale_factor = 2 ** len(self._vae_config["temporal_downsample"])
        elif "temperal_downsample" in self._vae_config:
            self._vae_scale_factor = 2 ** len(self._vae_config["temperal_downsample"])
        else:
            logger.warning("Could not find temporal_downsample in VAE config, using default scale factor of 8")
            self._vae_scale_factor = 8

        # Load transformer (the trainable model)
        self._model = QwenImageTransformer2DModel.from_pretrained(pretrained_path, subfolder="transformer")
        self._model.to(device, dtype=self._dtype)

        # Load noise scheduler. The private copy is only used for timestep/sigma
        # lookups so that set_timesteps() calls on the public scheduler can't affect it.
        self._scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(pretrained_path, subfolder="scheduler")
        self._scheduler_copy = copy.deepcopy(self._scheduler)

        # Warm up VAE (inference only: no autograd needed)
        dummy = torch.zeros(1, 3, 1, 64, 64, device=device, dtype=self._dtype)
        with torch.no_grad():
            self._vae.encode(dummy)

    def _init_vae_normalization(self):
        """Pre-compute latent normalization tensors from the VAE config.

        ``_latents_mean`` and ``_latents_std`` are shaped ``(1, 1, z_dim, 1, 1)``
        (channel axis at dim 2). Note that ``_latents_std`` stores the
        *reciprocal* of the configured std, so normalization is a multiply.
        """
        cfg = self._vae_config
        if "latents_mean" in cfg and "latents_std" in cfg and "z_dim" in cfg:
            self._latents_mean = torch.tensor(cfg["latents_mean"]).view(1, 1, cfg["z_dim"], 1, 1)
            self._latents_std = 1.0 / torch.tensor(cfg["latents_std"]).view(1, 1, cfg["z_dim"], 1, 1)
            self._has_normalization = True
        else:
            self._has_normalization = False
            logger.warning("VAE config missing normalization parameters, skipping latent normalization")

    @property
    def model(self):
        """The trainable QwenImageTransformer2DModel."""
        return self._model

    @property
    def noise_scheduler(self):
        """The FlowMatchEulerDiscreteScheduler loaded from the checkpoint."""
        return self._scheduler

    def encode_image_tensor(self, image_tensor, device=None):
        """Encode image tensors [B, C, H, W] in [-1, 1] to latents via video VAE.

        Returns raw (un-normalized) latents of shape ``[B, C', 1, H', W']``.
        """
        device = device or self._device
        # Video VAE expects [B, C, 1, H, W]
        pixel_values = image_tensor.unsqueeze(2).to(device=device, dtype=self._dtype)
        with torch.no_grad():
            latents = self._vae.encode(pixel_values).latent_dist.sample()
        return latents

    @staticmethod
    def _to_rgb(image):
        """Return ``image`` (PIL image, file path, or uint8-convertible array) as an RGB PIL image."""
        import os

        if isinstance(image, Image.Image):
            return image.convert("RGB")
        if isinstance(image, (str, os.PathLike)):
            # Image.open keeps the file handle open lazily; the context manager
            # closes it. convert() loads pixel data into an independent copy first.
            with Image.open(image) as opened:
                return opened.convert("RGB")
        return Image.fromarray(np.uint8(image)).convert("RGB")

    @torch.no_grad()
    def encode_images(self, images, height=None, width=None, device=None):
        """Encode PIL images to latents via video VAE.

        Args:
            images: Iterable of PIL images, file paths, or uint8-convertible arrays.
            height, width: When both are given, each image is resized with the
                pipeline's image processor before encoding.
            device: Target device (defaults to the adapter's device).

        Returns tensor of shape [B, C, 1, H', W'] (video VAE format), un-normalized.
        """
        device = device or self._device
        latents_list = []

        for image in images:
            image = self._to_rgb(image)

            if height and width:
                image = self._pipeline.image_processor.resize(image, height, width)

            # Convert to tensor [C, H, W] in [-1, 1]
            img_np = np.asarray(image, dtype=np.float32)
            img_tensor = torch.from_numpy(img_np / 127.5 - 1.0).permute(2, 0, 1)

            # Video VAE expects [B, C, 1, H, W]
            pixel_values = img_tensor.unsqueeze(0).unsqueeze(2).to(device=device, dtype=self._dtype)
            latents = self._vae.encode(pixel_values).latent_dist.sample()
            latents_list.append(latents[0])

        return torch.stack(latents_list)

    @staticmethod
    def _select_condition_images(images, num_prompts):
        """Pick the control image argument for ``encode_prompt``.

        Returns ``None`` without images and ``images[0]`` for a single prompt
        (the historical behaviour). For several prompts, a single image is
        broadcast to every prompt and a list of matching length is passed
        through, so each prompt is conditioned on its own image.
        """
        if not images:
            return None
        if num_prompts <= 1:
            return images[0]
        if len(images) == 1:
            return [images[0]] * num_prompts
        if len(images) == num_prompts:
            return list(images)
        raise ValueError(
            f"Got {len(images)} control images for {num_prompts} prompts; "
            f"expected 1 or {num_prompts}."
        )

    def encode_text(self, prompts, images=None, device=None, max_sequence_length=1024, **kwargs):
        """Encode text prompts via QwenImageEditPipeline.encode_prompt.

        Args:
            prompts: Prompt string or list of prompt strings.
            images: Optional list of control images for conditioning — either one
                image shared by all prompts or one image per prompt.
            device: Target device.
            max_sequence_length: Max token length.

        Returns dict with 'prompt_embeds' and (when available) 'prompt_embeds_mask'.

        Raises:
            ValueError: If several prompts are given with a number of images
                other than 1 or ``len(prompts)``.
        """
        device = device or self._device
        prompt_list = [prompts] if isinstance(prompts, str) else list(prompts)
        image = self._select_condition_images(images, len(prompt_list))

        # encode_prompt is not wrapped in no_grad by diffusers; without this the
        # returned embeddings would keep a text-encoder autograd graph alive.
        with torch.no_grad():
            prompt_embeds, prompt_embeds_mask = self._pipeline.encode_prompt(
                image=image,
                prompt=prompt_list,
                device=device,
                num_images_per_prompt=1,
                max_sequence_length=max_sequence_length,
            )

        result = {"prompt_embeds": prompt_embeds}
        if prompt_embeds_mask is not None:
            result["prompt_embeds_mask"] = prompt_embeds_mask
        return result

    def normalize_latents(self, latents):
        """Apply per-channel latent normalization from the VAE config.

        Computes ``(latents - mean) / std`` per latent channel. The channel axis
        is taken to be dim 2 (``[B, 1, C, H, W]``, the layout ``forward`` packs).
        5-D latents in the VAE-native ``[B, C, 1, H, W]`` layout (as returned by
        ``encode_images``) are normalized along dim 1 instead of silently
        broadcasting to ``[B, C, C, H, W]``. Returns ``latents`` unchanged when
        the VAE config has no normalization statistics.
        """
        if not self._has_normalization:
            return latents
        mean = self._latents_mean
        inv_std = self._latents_std  # already 1 / std
        num_channels = mean.shape[2]
        if latents.ndim == 5 and latents.shape[2] != num_channels and latents.shape[1] == num_channels:
            mean = mean.reshape(1, num_channels, 1, 1, 1)
            inv_std = inv_std.reshape(1, num_channels, 1, 1, 1)
        mean = mean.to(latents.device, latents.dtype)
        inv_std = inv_std.to(latents.device, latents.dtype)
        return (latents - mean) * inv_std

    def sample_timesteps(
        self,
        batch_size,
        device,
        *,
        weighting_scheme="none",
        logit_mean=0.0,
        logit_std=1.0,
        mode_scale=1.29,
    ):
        """Sample timesteps using density-based sampling for flow matching.

        Args:
            batch_size: Number of timesteps to draw.
            device: Device for the returned tensors.
            weighting_scheme, logit_mean, logit_std, mode_scale: Forwarded to
                ``diffusers.training_utils.compute_density_for_timestep_sampling``;
                the defaults reproduce uniform sampling.

        Returns:
            ``(timesteps, sigmas)``; sigmas are float32 and shaped ``[B, 1, 1, 1, 1]``.
        """
        from diffusers.training_utils import compute_density_for_timestep_sampling

        u = compute_density_for_timestep_sampling(
            weighting_scheme=weighting_scheme,
            batch_size=batch_size,
            logit_mean=logit_mean,
            logit_std=logit_std,
            mode_scale=mode_scale,
        )
        num_train_timesteps = self._scheduler_copy.config.num_train_timesteps
        # Non-uniform schemes can return u == 1.0 exactly, which would index one
        # past the end of the schedule; clamp to the last valid index.
        indices = (u * num_train_timesteps).long().clamp(0, num_train_timesteps - 1)
        timesteps = self._scheduler_copy.timesteps[indices].to(device=device)
        sigmas = self._get_sigmas(timesteps, device=device, dtype=torch.float32)
        return timesteps, sigmas

    def _get_sigmas(self, timesteps, device, dtype=torch.float32, n_dim=5):
        """Look up the sigma of each timestep and pad it to ``n_dim`` dims for broadcasting.

        Each timestep must occur exactly once in the scheduler's timestep table
        (``.item()`` raises otherwise).
        """
        sigmas = self._scheduler_copy.sigmas.to(device=device, dtype=dtype)
        schedule_timesteps = self._scheduler_copy.timesteps.to(device)
        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < n_dim:
            sigma = sigma.unsqueeze(-1)
        return sigma

    def add_noise(self, latents, noise, timesteps, sigmas):
        """Flow matching noise addition: (1 - sigma) * latents + sigma * noise."""
        sigmas = sigmas.to(latents.device, latents.dtype)
        return (1.0 - sigmas) * latents + sigmas * noise

    def compute_target(self, noise, latents, sigmas):
        """Flow matching target: noise - latents."""
        return noise - latents

    def forward(self, model, noisy_latents, timesteps, batch):
        """Run QwenImageTransformer2DModel with latent packing and control conditioning.

        Args:
            model: The (possibly PEFT-wrapped) transformer.
            noisy_latents: 5-D latents packed with the channel count read from
                dim 2 — i.e. ``[B, 1, C, H, W]``.
            timesteps: Scheduler timesteps (divided by 1000 before the call,
                as in the diffusers Qwen pipelines).
            batch: Dict with ``control_latents`` (same layout as
                ``noisy_latents``), ``prompt_embeds`` and optionally
                ``prompt_embeds_mask``.

        Returns:
            The prediction for the target image tokens, unpacked by
            ``QwenImageEditPipeline._unpack_latents``.
        """
        from diffusers import QwenImageEditPipeline

        control_latents = batch["control_latents"].to(noisy_latents.device, noisy_latents.dtype)
        prompt_embeds = batch["prompt_embeds"].to(noisy_latents.device, noisy_latents.dtype)

        if "prompt_embeds_mask" in batch and isinstance(batch["prompt_embeds_mask"], torch.Tensor):
            prompt_mask = batch["prompt_embeds_mask"].to(dtype=torch.int32, device=noisy_latents.device)
        else:
            # No mask supplied: treat every text token as valid.
            prompt_mask = torch.ones(
                prompt_embeds.shape[:2], dtype=torch.int32, device=noisy_latents.device,
            )

        bsz = noisy_latents.shape[0]

        # Pack latents: args are (latents, batch_size, num_channels, height, width).
        packed_noisy = QwenImageEditPipeline._pack_latents(
            noisy_latents, bsz,
            noisy_latents.shape[2], noisy_latents.shape[3], noisy_latents.shape[4],
        )
        packed_control = QwenImageEditPipeline._pack_latents(
            control_latents, bsz,
            control_latents.shape[2], control_latents.shape[3], control_latents.shape[4],
        )

        # Concatenate target + control tokens along the sequence axis.
        packed_input = torch.cat([packed_noisy, packed_control], dim=1)

        # Image shapes for RoPE: (frames, h/2, w/2) per token group, per sample.
        img_shapes = [
            [
                (1, noisy_latents.shape[3] // 2, noisy_latents.shape[4] // 2),
                (1, control_latents.shape[3] // 2, control_latents.shape[4] // 2),
            ]
        ] * bsz

        txt_seq_lens = prompt_mask.sum(dim=1).tolist()

        # Forward pass
        output = model(
            hidden_states=packed_input,
            timestep=timesteps / 1000,
            guidance=None,
            encoder_hidden_states_mask=prompt_mask,
            encoder_hidden_states=prompt_embeds,
            img_shapes=img_shapes,
            txt_seq_lens=txt_seq_lens,
            return_dict=False,
        )[0]

        # Keep only the prediction for the target tokens (the first group).
        output = output[:, :packed_noisy.size(1)]

        # Unpack.
        # NOTE(review): in current diffusers, _unpack_latents returns [B, C, 1, H, W],
        # while noisy_latents is [B, 1, C, H, W]. The values agree because one axis
        # has size 1, but confirm the loss reconciles the shapes (e.g. by flattening)
        # instead of broadcasting them against each other.
        output = QwenImageEditPipeline._unpack_latents(
            output,
            height=noisy_latents.shape[3] * self._vae_scale_factor,
            width=noisy_latents.shape[4] * self._vae_scale_factor,
            vae_scale_factor=self._vae_scale_factor,
        )

        return output

    def save_lora(self, model, path):
        """Save LoRA weights in the format diffusers + PEFT can load back.

        Drops the legacy ``convert_state_dict_to_diffusers`` call (which
        rewrites PEFT-format keys into the pre-PEFT diffusers layout that
        modern ``pipe.load_lora_weights`` rejects) and explicitly strips
        the ``base_model.model.`` PEFT wrapper prefix that
        ``get_peft_model_state_dict`` retains in some versions. See the
        QwenImageAdapter version for the full bug-tour comment.
        """
        import os

        from diffusers import QwenImagePipeline
        from peft.utils import get_peft_model_state_dict

        from .qwen_image import strip_peft_prefix

        os.makedirs(path, exist_ok=True)
        state_dict = strip_peft_prefix(get_peft_model_state_dict(model))
        QwenImagePipeline.save_lora_weights(path, state_dict, safe_serialization=True)
        logger.info("LoRA weights saved to %s", path)

    def save_model(self, model, path):
        """Save full model weights to ``path`` as safetensors via ``save_pretrained``."""
        import os

        os.makedirs(path, exist_ok=True)
        model.save_pretrained(path, safe_serialization=True)
        logger.info("Model saved to %s", path)

    def free_encoders(self):
        """Free VAE and text encoding pipeline from memory.

        After this call the ``encode_*`` methods can no longer be used.
        """
        # Dropping the references is enough; gc + empty_cache return the memory.
        self._vae = None
        self._pipeline = None
        gc.collect()
        torch.cuda.empty_cache()
        logger.info("Freed VAE and text encoder from memory")
|