diffsynth-engine 0.4.3.dev9__py3-none-any.whl → 0.4.3.dev11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffsynth_engine/conf/models/qwen_image/qwen2_5_vl_config.json +2 -1
- diffsynth_engine/conf/tokenizers/qwen_image/qwen2_vl_image_processor.json +29 -0
- diffsynth_engine/configs/pipeline.py +5 -0
- diffsynth_engine/models/basic/attention.py +3 -3
- diffsynth_engine/models/qwen_image/qwen2_5_vl.py +41 -57
- diffsynth_engine/models/qwen_image/qwen_image_dit.py +45 -28
- diffsynth_engine/pipelines/base.py +1 -1
- diffsynth_engine/pipelines/qwen_image.py +125 -13
- diffsynth_engine/pipelines/sd_image.py +3 -3
- diffsynth_engine/pipelines/sdxl_image.py +10 -6
- diffsynth_engine/tokenizers/__init__.py +4 -0
- diffsynth_engine/tokenizers/qwen2_vl_image_processor.py +157 -0
- diffsynth_engine/tokenizers/qwen2_vl_processor.py +100 -0
- diffsynth_engine/utils/constants.py +6 -0
- diffsynth_engine/utils/image.py +213 -0
- diffsynth_engine/utils/offload.py +6 -5
- {diffsynth_engine-0.4.3.dev9.dist-info → diffsynth_engine-0.4.3.dev11.dist-info}/METADATA +2 -2
- {diffsynth_engine-0.4.3.dev9.dist-info → diffsynth_engine-0.4.3.dev11.dist-info}/RECORD +21 -18
- {diffsynth_engine-0.4.3.dev9.dist-info → diffsynth_engine-0.4.3.dev11.dist-info}/WHEEL +0 -0
- {diffsynth_engine-0.4.3.dev9.dist-info → diffsynth_engine-0.4.3.dev11.dist-info}/licenses/LICENSE +0 -0
- {diffsynth_engine-0.4.3.dev9.dist-info → diffsynth_engine-0.4.3.dev11.dist-info}/top_level.txt +0 -0
diffsynth_engine/pipelines/qwen_image.py

@@ -1,10 +1,11 @@
 import json
 import torch
+import torch.distributed as dist
 import math
 from typing import Callable, List, Tuple, Optional, Union, Dict
 from tqdm import tqdm
 from einops import rearrange
-
+from PIL import Image

 from diffsynth_engine.configs import QwenImagePipelineConfig, QwenImageStateDicts
 from diffsynth_engine.models.basic.lora import LoRAContext
@@ -16,13 +17,14 @@ from diffsynth_engine.models.qwen_image import (
     Qwen2_5_VLConfig,
 )
 from diffsynth_engine.models.qwen_image import QwenImageVAE
-from diffsynth_engine.tokenizers import Qwen2TokenizerFast
+from diffsynth_engine.tokenizers import Qwen2TokenizerFast, Qwen2VLProcessor
 from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
 from diffsynth_engine.pipelines.utils import calculate_shift
 from diffsynth_engine.algorithm.noise_scheduler import RecifitedFlowScheduler
 from diffsynth_engine.algorithm.sampler import FlowMatchEulerSampler
 from diffsynth_engine.utils.constants import (
     QWEN_IMAGE_TOKENIZER_CONF_PATH,
+    QWEN_IMAGE_PROCESSOR_CONFIG_FILE,
     QWEN_IMAGE_CONFIG_FILE,
     QWEN_IMAGE_VISION_CONFIG_FILE,
     QWEN_IMAGE_VAE_CONFIG_FILE,
@@ -44,20 +46,23 @@ class QwenImageLoRAConverter(LoRAStateDictConverter):
             lora_a_suffix = None
             if "lora_A.default.weight" in key:
                 lora_a_suffix = "lora_A.default.weight"
+                lora_b_suffix = "lora_B.default.weight"
             elif "lora_A.weight" in key:
                 lora_a_suffix = "lora_A.weight"
+                lora_b_suffix = "lora_B.weight"
+            elif "lora_down.weight" in key:
+                lora_a_suffix = "lora_down.weight"
+                lora_b_suffix = "lora_up.weight"

             if lora_a_suffix is None:
                 continue

             lora_args = {}
             lora_args["down"] = param
-
-            lora_b_suffix = lora_a_suffix.replace("lora_A", "lora_B")
             lora_args["up"] = lora_state_dict[origin_key.replace(lora_a_suffix, lora_b_suffix)]

             lora_args["rank"] = lora_args["up"].shape[1]
-            alpha_key = origin_key.replace(
+            alpha_key = origin_key.replace(lora_a_suffix, "alpha")

             if alpha_key in lora_state_dict:
                 alpha = lora_state_dict[alpha_key]
@@ -83,6 +88,7 @@ class QwenImagePipeline(BasePipeline):
         self,
         config: QwenImagePipelineConfig,
         tokenizer: Qwen2TokenizerFast,
+        processor: Qwen2VLProcessor,
         encoder: Qwen2_5_VLForConditionalGeneration,
         dit: QwenImageDiT,
         vae: QwenImageVAE,
@@ -97,11 +103,15 @@ class QwenImagePipeline(BasePipeline):
         self.config = config
         self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
         self.prompt_template_encode_start_idx = 34
+
+        self.edit_prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
+        self.edit_prompt_template_encode_start_idx = 64
         # sampler
         self.noise_scheduler = RecifitedFlowScheduler(shift=3.0, use_dynamic_shifting=True)
         self.sampler = FlowMatchEulerSampler()
         # models
         self.tokenizer = tokenizer
+        self.processor = processor
         self.encoder = encoder
         self.dit = dit
         self.vae = vae
@@ -155,6 +165,10 @@ class QwenImagePipeline(BasePipeline):

         init_device = "cpu" if config.parallelism > 1 or config.offload_mode is not None else config.device
         tokenizer = Qwen2TokenizerFast.from_pretrained(QWEN_IMAGE_TOKENIZER_CONF_PATH)
+        processor = Qwen2VLProcessor.from_pretrained(
+            tokenizer_config_path=QWEN_IMAGE_TOKENIZER_CONF_PATH,
+            image_processor_config_path=QWEN_IMAGE_PROCESSOR_CONFIG_FILE,
+        )
         with open(QWEN_IMAGE_VISION_CONFIG_FILE, "r") as f:
             vision_config = Qwen2_5_VLVisionConfig(**json.load(f))
         with open(QWEN_IMAGE_CONFIG_FILE, "r") as f:
@@ -201,6 +215,7 @@ class QwenImagePipeline(BasePipeline):
         pipe = cls(
             config=config,
             tokenizer=tokenizer,
+            processor=processor,
             encoder=encoder,
             dit=dit,
             vae=vae,
@@ -209,7 +224,7 @@ class QwenImagePipeline(BasePipeline):

         if config.offload_mode is not None:
             pipe.enable_cpu_offload(config.offload_mode, config.offload_to_disk)
-
+
         if config.model_dtype == torch.float8_e4m3fn:
             pipe.dtype = torch.bfloat16 # compute dtype
             pipe.enable_fp8_autocast(
@@ -302,9 +317,51 @@ class QwenImagePipeline(BasePipeline):

         return prompt_embeds, prompt_embeds_mask

+    def encode_prompt_with_image(
+        self,
+        prompt: Union[str, List[str]],
+        image: torch.Tensor,
+        num_images_per_prompt: int = 1,
+        max_sequence_length: int = 1024,
+    ):
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        batch_size = len(prompt)
+        template = self.edit_prompt_template_encode
+        drop_idx = self.edit_prompt_template_encode_start_idx
+        texts = [template.format(txt) for txt in prompt]
+
+        model_inputs = self.processor(text=texts, images=image, max_length=max_sequence_length + drop_idx)
+        input_ids, attention_mask, pixel_values, image_grid_thw = (
+            model_inputs["input_ids"].to(self.device),
+            model_inputs["attention_mask"].to(self.device),
+            model_inputs["pixel_values"].to(self.device),
+            model_inputs["image_grid_thw"].to(self.device),
+        )
+        outputs = self.encoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            pixel_values=pixel_values,
+            image_grid_thw=image_grid_thw,
+        )
+        hidden_states = outputs["hidden_states"]
+        prompt_embeds = hidden_states[:, drop_idx:]
+        prompt_embeds_mask = attention_mask[:, drop_idx:]
+        seq_len = prompt_embeds.shape[1]
+
+        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+        return prompt_embeds, prompt_embeds_mask
+
     def predict_noise_with_cfg(
         self,
         latents: torch.Tensor,
+        image_latents: torch.Tensor,
         timestep: torch.Tensor,
         prompt_emb: torch.Tensor,
         negative_prompt_emb: torch.Tensor,
@@ -316,6 +373,7 @@ class QwenImagePipeline(BasePipeline):
         if cfg_scale <= 1.0 or negative_prompt_emb is None:
             return self.predict_noise(
                 latents,
+                image_latents,
                 timestep,
                 prompt_emb,
                 prompt_embeds_mask,
@@ -325,12 +383,14 @@ class QwenImagePipeline(BasePipeline):
             h, w = latents.shape[-2:]
             positive_noise_pred = self.predict_noise(
                 latents,
+                image_latents,
                 timestep,
                 prompt_emb,
                 prompt_embeds_mask,
             )
             negative_noise_pred = self.predict_noise(
                 latents,
+                image_latents,
                 timestep,
                 negative_prompt_emb,
                 negative_prompt_embeds_mask,
@@ -346,9 +406,11 @@ class QwenImagePipeline(BasePipeline):
             prompt_emb = torch.cat([prompt_emb, negative_prompt_emb], dim=0)
             prompt_embeds_mask = torch.cat([prompt_embeds_mask, negative_prompt_embeds_mask], dim=0)
             latents = torch.cat([latents, latents], dim=0)
+            image_latents = torch.cat([image_latents, image_latents], dim=0)
             timestep = torch.cat([timestep, timestep], dim=0)
             noise_pred = self.predict_noise(
                 latents,
+                image_latents,
                 timestep,
                 prompt_emb,
                 prompt_embeds_mask,
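
The predict_noise_with_cfg changes above only thread image_latents through the single-call, split, and batched CFG paths; the guidance blend itself is unchanged by this diff and lies outside the hunks shown. For orientation only, true classifier-free guidance conventionally combines the two predictions as (illustrative, not quoted from this file):

    noise_pred = negative_noise_pred + cfg_scale * (positive_noise_pred - negative_noise_pred)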
@@ -363,25 +425,49 @@ class QwenImagePipeline(BasePipeline):
     def predict_noise(
         self,
         latents: torch.Tensor,
+        image_latents: torch.Tensor,
         timestep: torch.Tensor,
         prompt_emb: torch.Tensor,
         prompt_embeds_mask: torch.Tensor,
     ):
         self.load_models_to_device(["dit"])
-
         noise_pred = self.dit(
             image=latents,
+            edit=image_latents,
             text=prompt_emb,
             timestep=timestep,
             txt_seq_lens=prompt_embeds_mask.sum(dim=1),
         )
         return noise_pred

+    def prepare_image_latents(self, input_image: Image.Image):
+        image = self.preprocess_image(input_image).to(
+            device=self.device, dtype=self.vae.model.encoder.conv1.weight.dtype
+        )
+        image = image.unsqueeze(2)
+        image_latents = self.vae.encode(
+            image,
+            device=self.device,
+            tiled=self.vae_tiled,
+            tile_size=self.vae_tile_size,
+            tile_stride=self.vae_tile_stride,
+        )
+        image_latents = image_latents.squeeze(2)
+        return image_latents
+
+    def calculate_dimensions(self, target_area, ratio):
+        width = math.sqrt(target_area * ratio)
+        height = width / ratio
+        width = round(width / 32) * 32
+        height = round(height / 32) * 32
+        return width, height
+
     @torch.no_grad()
     def __call__(
         self,
         prompt: str,
         negative_prompt: str = "",
+        input_image: Image.Image | None = None, # use for img2img
         cfg_scale: float = 4.0, # true cfg
         height: int = 1328,
         width: int = 1328,
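
The new calculate_dimensions helper keeps roughly the requested pixel area at the input aspect ratio and snaps both sides to multiples of 32, which also satisfies the validate_image_size(..., multiple_of=16) check applied later in __call__. A quick worked example (illustrative arithmetic only, not part of the diff):

    import math

    def calculate_dimensions(target_area, ratio):
        # mirrors the method above as a free function
        width = math.sqrt(target_area * ratio)
        height = width / ratio
        return round(width / 32) * 32, round(height / 32) * 32

    # a 1536x1024 (3:2) edit input is remapped to roughly 1 MP:
    print(calculate_dimensions(1024 * 1024, 1536 / 1024))  # (1248, 832)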
@@ -389,29 +475,51 @@ class QwenImagePipeline(BasePipeline):
         seed: int | None = None,
         progress_callback: Optional[Callable] = None, # def progress_callback(current, total, status)
     ):
+        if input_image is not None:
+            width, height = input_image.size
+            width, height = self.calculate_dimensions(1024 * 1024, width / height)
+            input_image = input_image.resize((width, height), Image.LANCZOS)
+
+        self.validate_image_size(height, width, minimum=64, multiple_of=16)
+
         noise = self.generate_noise((1, 16, height // 8, width // 8), seed=seed, device="cpu", dtype=self.dtype).to(
             device=self.device
         )
         # dynamic shift
         image_seq_len = math.ceil(height // 16) * math.ceil(width // 16)
         mu = calculate_shift(image_seq_len, max_shift=0.9, max_seq_len=8192)
+        if input_image:
+            image_latents = self.prepare_image_latents(input_image)
+        else:
+            image_latents = None
         init_latents, latents, sigmas, timesteps = self.prepare_latents(noise, num_inference_steps, mu)
         # Initialize sampler
         self.sampler.initialize(sigmas=sigmas)

         self.load_models_to_device(["encoder"])
-
-
-
+        if image_latents is not None:
+            prompt_embeds, prompt_embeds_mask = self.encode_prompt_with_image(prompt, input_image, 1, 4096)
+            if cfg_scale > 1.0 and negative_prompt != "":
+                negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt_with_image(
+                    negative_prompt, input_image, 1, 4096
+                )
+            else:
+                negative_prompt_embeds, negative_prompt_embeds_mask = None, None
         else:
-
+            prompt_embeds, prompt_embeds_mask = self.encode_prompt(prompt, 1, 4096)
+            if cfg_scale > 1.0 and negative_prompt != "":
+                negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(negative_prompt, 1, 4096)
+            else:
+                negative_prompt_embeds, negative_prompt_embeds_mask = None, None
         self.model_lifecycle_finish(["encoder"])

         hide_progress = dist.is_initialized() and dist.get_rank() != 0
+
         for i, timestep in enumerate(tqdm(timesteps, disable=hide_progress)):
             timestep = timestep.unsqueeze(0).to(dtype=self.dtype)
             noise_pred = self.predict_noise_with_cfg(
                 latents=latents,
+                image_latents=image_latents,
                 timestep=timestep,
                 prompt_emb=prompt_embeds,
                 negative_prompt_emb=negative_prompt_embeds,
@@ -431,12 +539,16 @@ class QwenImagePipeline(BasePipeline):
         latents = rearrange(latents, "B C H W -> B C 1 H W")
         vae_output = rearrange(
             self.vae.decode(
-                latents.to(self.vae.model.encoder.conv1.weight.dtype),
+                latents.to(self.vae.model.encoder.conv1.weight.dtype),
+                device=self.vae.model.encoder.conv1.weight.device,
+                tiled=self.vae_tiled,
+                tile_size=self.vae_tile_size,
+                tile_stride=self.vae_tile_stride,
             )[0],
             "C B H W -> B C H W",
         )
         image = self.vae_output_to_image(vae_output)
         # Offload all models
-        self.model_lifecycle_finish(["vae"])
+        self.model_lifecycle_finish(["vae"])
         self.load_models_to_device([])
         return image
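
Taken together, the qwen_image.py hunks wire an optional input image through prompt encoding (encode_prompt_with_image), latent preparation (prepare_image_latents), the DiT call (edit=image_latents), and CFG. A minimal usage sketch of the new path follows; the config fields and the from_pretrained-style loader are assumptions made for illustration, while the __call__ arguments come from this diff:

    from PIL import Image
    from diffsynth_engine.configs import QwenImagePipelineConfig
    from diffsynth_engine.pipelines.qwen_image import QwenImagePipeline

    # hypothetical checkpoint path and loader call, shown only to frame the new __call__ arguments
    config = QwenImagePipelineConfig(model_path="models/qwen_image.safetensors")
    pipe = QwenImagePipeline.from_pretrained(config)

    source = Image.open("room.png")  # resized internally to ~1 MP with sides snapped to multiples of 32
    edited = pipe(
        prompt="replace the sofa with a blue armchair",
        input_image=source,   # new in this version; enables the edit/img2img branch
        cfg_scale=4.0,        # true CFG, as in the default signature
        num_inference_steps=30,
        seed=42,
    )
    edited.save("room_edited.png")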
diffsynth_engine/pipelines/sd_image.py

@@ -181,21 +181,21 @@ class SDImagePipeline(BasePipeline):
                 raise ValueError("`model_path` cannot be empty")
             logger.info(f"loading state dict from {config.model_path} ...")
             state_dicts.model = cls.load_model_checkpoint(config.model_path, device="cpu", dtype=config.model_dtype)
-
+
         if state_dicts.vae is None:
             if config.vae_path is None:
                 state_dicts.vae = state_dicts.model
             else:
                 logger.info(f"loading state dict from {config.vae_path} ...")
                 state_dicts.vae = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)
-
+
         if state_dicts.clip is None:
             if config.clip_path is None:
                 state_dicts.clip = state_dicts.model
             else:
                 logger.info(f"loading state dict from {config.clip_path} ...")
                 state_dicts.clip = cls.load_model_checkpoint(config.clip_path, device="cpu", dtype=config.clip_dtype)
-
+
         init_device = "cpu" if config.offload_mode is not None else config.device
         tokenizer = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_CONF_PATH)
         with LoRAContext():
diffsynth_engine/pipelines/sdxl_image.py

@@ -159,28 +159,32 @@ class SDXLImagePipeline(BasePipeline):
                 raise ValueError("`model_path` cannot be empty")
             logger.info(f"loading state dict from {config.model_path} ...")
             state_dicts.model = cls.load_model_checkpoint(config.model_path, device="cpu", dtype=config.model_dtype)
-
+
         if state_dicts.vae is None:
             if config.vae_path is None:
                 state_dicts.vae = state_dicts.model
             else:
                 logger.info(f"loading state dict from {config.vae_path} ...")
                 state_dicts.vae = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)
-
+
         if state_dicts.clip_l is None:
             if config.clip_l_path is None:
                 state_dicts.clip_l = state_dicts.model
             else:
                 logger.info(f"loading state dict from {config.clip_l_path} ...")
-                state_dicts.clip_l = cls.load_model_checkpoint(
-
+                state_dicts.clip_l = cls.load_model_checkpoint(
+                    config.clip_l_path, device="cpu", dtype=config.clip_l_dtype
+                )
+
         if state_dicts.clip_g is None:
             if config.clip_g_path is None:
                 state_dicts.clip_g = state_dicts.model
             else:
                 logger.info(f"loading state dict from {config.clip_g_path} ...")
-                state_dicts.clip_g = cls.load_model_checkpoint(
-
+                state_dicts.clip_g = cls.load_model_checkpoint(
+                    config.clip_g_path, device="cpu", dtype=config.clip_g_dtype
+                )
+
         init_device = "cpu" if config.offload_mode else config.device
         tokenizer = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_CONF_PATH)
         tokenizer_2 = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_2_CONF_PATH)
diffsynth_engine/tokenizers/__init__.py

@@ -3,6 +3,8 @@ from .clip import CLIPTokenizer
 from .t5 import T5TokenizerFast
 from .wan import WanT5Tokenizer
 from .qwen2 import Qwen2TokenizerFast
+from .qwen2_vl_image_processor import Qwen2VLImageProcessor
+from .qwen2_vl_processor import Qwen2VLProcessor

 __all__ = [
     "BaseTokenizer",
@@ -10,4 +12,6 @@ __all__ = [
     "T5TokenizerFast",
     "WanT5Tokenizer",
     "Qwen2TokenizerFast",
+    "Qwen2VLImageProcessor",
+    "Qwen2VLProcessor",
 ]
diffsynth_engine/tokenizers/qwen2_vl_image_processor.py

@@ -0,0 +1,157 @@
+# modified from transformers.models.qwen2_vl.image_processing_qwen2_vl
+import os
+import json
+import logging
+import numpy as np
+from typing import List, Optional
+from PIL import Image
+
+from diffsynth_engine.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
+from diffsynth_engine.utils.image import (
+    ChannelDimension,
+    convert_to_rgb,
+    get_image_size,
+    infer_channel_dimension_format,
+    rescale_image,
+    resize_image,
+    smart_resize,
+    normalize_image,
+    to_channel_dimension_format,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class Qwen2VLImageProcessor:
+    def __init__(
+        self,
+        do_resize: bool = True,
+        resample: Image.Resampling = Image.Resampling.BICUBIC,
+        do_rescale: bool = True,
+        rescale_factor: float = 1.0 / 255,
+        do_normalize: bool = True,
+        image_mean: List[float] = OPENAI_CLIP_MEAN,
+        image_std: List[float] = OPENAI_CLIP_STD,
+        do_convert_rgb: bool = True,
+        min_pixels: int = 56 * 56,
+        max_pixels: int = 28 * 28 * 1280,
+        patch_size: int = 14,
+        temporal_patch_size: int = 2,
+        merge_size: int = 2,
+        **kwargs,
+    ):
+        self.do_resize = do_resize
+        self.resample = resample
+        self.do_rescale = do_rescale
+        self.rescale_factor = rescale_factor
+        self.do_normalize = do_normalize
+        self.do_convert_rgb = do_convert_rgb
+        self.size = {"shortest_edge": min_pixels, "longest_edge": max_pixels}
+        self.image_mean = image_mean
+        self.image_std = image_std
+        self.patch_size = patch_size
+        self.merge_size = merge_size
+        self.min_pixels = min_pixels
+        self.max_pixels = max_pixels
+        self.temporal_patch_size = temporal_patch_size
+
+    @classmethod
+    def from_pretrained(cls, config_file_path: str | os.PathLike, **kwargs):
+        init_kwargs = {}
+        if not os.path.exists(config_file_path):
+            logger.warning(f"Cannot find {config_file_path}, init processor with default parameters")
+        else:
+            with open(config_file_path, "r", encoding="utf-8") as kwargs_handler:
+                init_kwargs = json.load(kwargs_handler)
+
+        init_kwargs.update(**kwargs)
+        return cls(**init_kwargs)
+
+    def __call__(
+        self,
+        images: Image.Image | List[Image.Image],
+        videos: Optional[List[List[Image.Image]]] = None,
+        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+    ):
+        pixel_values, image_grid_thws = None, None
+        if images is not None:
+            if isinstance(images, Image.Image):
+                images = [images]
+            pixel_values, image_grid_thws = [], []
+            for image in images:
+                flatten_patches, image_grid_thw = self._preprocess([image], data_format)
+                pixel_values.extend(flatten_patches)
+                image_grid_thws.append(image_grid_thw)
+            pixel_values = np.array(pixel_values)
+            image_grid_thws = np.array(image_grid_thws)
+
+        vision_pixel_values, vision_grid_thws = None, None
+        if videos is not None:
+            vision_pixel_values, vision_grid_thws = [], []
+            for images in videos:
+                flatten_patches, video_grid_thw = self._preprocess(images, data_format)
+                vision_pixel_values.append(flatten_patches)
+                vision_grid_thws.append(video_grid_thw)
+            vision_pixel_values = np.array(vision_pixel_values)
+            vision_grid_thws = np.array(vision_grid_thws)
+
+        return pixel_values, image_grid_thws, vision_pixel_values, vision_grid_thws
+
+    def _preprocess(self, images: List[Image.Image], data_format: Optional[ChannelDimension] = ChannelDimension.FIRST):
+        images = [convert_to_rgb(image) for image in images]
+        image_nps = [np.array(image) for image in images]
+        input_data_format = infer_channel_dimension_format(image_nps[0])
+        height, width = get_image_size(image_nps[0], input_data_format)
+        resized_height, resized_width = height, width
+
+        processed_image_nps = []
+        for image_np in image_nps:
+            if self.do_resize:
+                resized_height, resized_width = smart_resize(
+                    height,
+                    width,
+                    factor=self.patch_size * self.merge_size,
+                    min_pixels=self.min_pixels,
+                    max_pixels=self.max_pixels,
+                )
+                image_np = resize_image(
+                    image_np, resized_height, resized_width, self.resample, input_data_format=input_data_format
+                )
+
+            if self.do_rescale:
+                image_np = rescale_image(image_np, self.rescale_factor)
+
+            if self.do_normalize:
+                image_np = normalize_image(
+                    image_np, self.image_mean, self.image_std, input_data_format=input_data_format
+                )
+            image_np = to_channel_dimension_format(image_np, data_format, input_data_format)
+            processed_image_nps.append(image_np)
+
+        patches = np.array(processed_image_nps)
+        if data_format == ChannelDimension.LAST:
+            patches = patches.transpose(0, 3, 1, 2)
+        if patches.shape[0] % self.temporal_patch_size != 0:
+            repeats = np.repeat(patches[-1][np.newaxis], self.temporal_patch_size - 1, axis=0)
+            patches = np.concatenate([patches, repeats], axis=0)
+        num_channel = patches.shape[1]
+        grid_t = patches.shape[0] // self.temporal_patch_size
+        grid_h = resized_height // self.patch_size
+        grid_w = resized_width // self.patch_size
+        patches = patches.reshape(
+            grid_t,
+            self.temporal_patch_size,
+            num_channel,
+            grid_h // self.merge_size,
+            self.merge_size,
+            self.patch_size,
+            grid_w // self.merge_size,
+            self.merge_size,
+            self.patch_size,
+        )
+        patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
+        flatten_patches = patches.reshape(
+            grid_t * grid_h * grid_w, num_channel * self.temporal_patch_size * self.patch_size * self.patch_size
+        )
+
+        return flatten_patches, (grid_t, grid_h, grid_w)
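
With the defaults above (patch_size=14, merge_size=2, temporal_patch_size=2), a single image is resized so both sides are multiples of 28, duplicated to two temporal frames, and cut into 14x14 patches. Illustrative arithmetic for a 448x448 input (a sketch, not part of the diff):

    # bookkeeping performed by _preprocess above, spelled out for one image
    patch_size, merge_size, temporal_patch_size = 14, 2, 2

    resized_h = resized_w = 448          # already a multiple of patch_size * merge_size = 28
    grid_t = 2 // temporal_patch_size    # the single frame is repeated to 2 -> grid_t = 1
    grid_h = resized_h // patch_size     # 32
    grid_w = resized_w // patch_size     # 32

    num_patches = grid_t * grid_h * grid_w                          # 1024 rows in pixel_values
    patch_dim = 3 * temporal_patch_size * patch_size * patch_size   # 1176 features per row
    num_image_tokens = num_patches // merge_size**2                 # 256 <|image_pad|> tokens in the text

    print((grid_t, grid_h, grid_w), num_patches, patch_dim, num_image_tokens)
    # (1, 32, 32) 1024 1176 256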
diffsynth_engine/tokenizers/qwen2_vl_processor.py

@@ -0,0 +1,100 @@
+import os
+import re
+import torch
+import logging
+from PIL import Image
+from typing import List, Dict, Optional
+
+from diffsynth_engine.tokenizers.qwen2_vl_image_processor import Qwen2VLImageProcessor
+from diffsynth_engine.tokenizers.qwen2 import Qwen2TokenizerFast
+
+logger = logging.getLogger(__name__)
+
+
+class Qwen2VLProcessor:
+    def __init__(
+        self,
+        tokenizer: Qwen2TokenizerFast,
+        image_processor: Qwen2VLImageProcessor,
+        image_token: str = "<|image_pad|>",
+        **kwargs,
+    ):
+        self.tokenizer = tokenizer
+        self.image_processor = image_processor
+        self.image_token = image_token
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        tokenizer_config_path: str | os.PathLike,
+        image_processor_config_path: str | os.PathLike,
+        **kwargs,
+    ):
+        tokenizer = Qwen2TokenizerFast.from_pretrained(tokenizer_config_path)
+        image_processor = Qwen2VLImageProcessor.from_pretrained(image_processor_config_path)
+        return cls(tokenizer=tokenizer, image_processor=image_processor, **kwargs)
+
+    def batch_decode(
+        self,
+        ids: List[List[int]] | List[torch.Tensor],
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: Optional[bool] = None,
+    ):
+        if isinstance(ids[0], torch.Tensor):
+            ids = [id_.tolist() for id_ in ids]
+        decoded = self.tokenizer.batch_decode(ids, skip_special_tokens, clean_up_tokenization_spaces)
+        pattern = r"<\|vision_start\|>.*?<\|vision_end\|>"
+        decoded_with_image_tag = [re.sub(pattern, "<image>", d, flags=re.DOTALL) for d in decoded]
+        decoded_with_image_tag = [re.sub(r"<\|im_end\|>", "", d) for d in decoded_with_image_tag]
+        return decoded_with_image_tag
+
+    def __call__(
+        self,
+        text: str | List[str],
+        images: Optional[List[Image.Image]] = None,
+        videos: Optional[List[List[Image.Image]]] = None,
+        max_length: Optional[int] = None,
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
+        and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
+        the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwrags` arguments to
+        Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.
+
+        Args:
+            text (`List[str]`):
+                The sequence or batch of sequences to be encoded.
+            images (`List[PIL.Image.Image]`):
+                The batch of images to be prepared.
+            videos (`List[List[PIL.Image.Image]]`):
+                The batch of videos to be prepared.
+        """
+        images_pixel_values, images_grid_thws, video_pixels_values, video_grid_thws = self.image_processor(
+            images, videos
+        )
+
+        if not isinstance(text, list):
+            text = [text]
+        if images_grid_thws is not None:
+            merge_length = self.image_processor.merge_size**2
+            index = 0
+            for i in range(len(text)):
+                while self.image_token in text[i]:
+                    text[i] = text[i].replace(
+                        self.image_token, "<|placeholder|>" * (images_grid_thws[index].prod() // merge_length), 1
+                    )
+                    index += 1
+                text[i] = text[i].replace("<|placeholder|>", self.image_token)
+        text_inputs = self.tokenizer(text, max_length=max_length)
+
+        processed_inputs = text_inputs
+        if images_pixel_values is not None:
+            processed_inputs["pixel_values"] = torch.from_numpy(images_pixel_values)
+        if images_grid_thws is not None:
+            processed_inputs["image_grid_thw"] = torch.from_numpy(images_grid_thws)
+        if video_pixels_values is not None:
+            processed_inputs["pixel_values_videos"] = video_pixels_values
+        if video_grid_thws is not None:
+            processed_inputs["video_grid_thw"] = video_grid_thws
+
+        return processed_inputs
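
A minimal sketch of driving the new processor on its own (not part of the diff); the config paths are the constants added below, and the returned keys (input_ids, attention_mask, pixel_values, image_grid_thw) are the ones QwenImagePipeline.encode_prompt_with_image reads back:

    from PIL import Image
    from diffsynth_engine.tokenizers import Qwen2VLProcessor
    from diffsynth_engine.utils.constants import (
        QWEN_IMAGE_TOKENIZER_CONF_PATH,
        QWEN_IMAGE_PROCESSOR_CONFIG_FILE,
    )

    processor = Qwen2VLProcessor.from_pretrained(
        tokenizer_config_path=QWEN_IMAGE_TOKENIZER_CONF_PATH,
        image_processor_config_path=QWEN_IMAGE_PROCESSOR_CONFIG_FILE,
    )
    # the single <|image_pad|> is expanded to grid.prod() // merge_size**2 copies before tokenization
    inputs = processor(
        text=["<|vision_start|><|image_pad|><|vision_end|>Describe this image."],
        images=Image.open("example.png"),  # hypothetical file; a PIL image or list of images
        max_length=1024,
    )
    print(inputs["input_ids"].shape, inputs["pixel_values"].shape, inputs["image_grid_thw"])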
diffsynth_engine/utils/constants.py

@@ -5,6 +5,7 @@ REPO_ROOT = os.path.dirname(PACKAGE_ROOT)

 # conf
 CONF_PATH = os.path.join(PACKAGE_ROOT, "conf")
+
 # tokenizers
 FLUX_TOKENIZER_1_CONF_PATH = os.path.join(CONF_PATH, "tokenizers", "flux", "tokenizer_1")
 FLUX_TOKENIZER_2_CONF_PATH = os.path.join(CONF_PATH, "tokenizers", "flux", "tokenizer_2")
@@ -12,6 +13,8 @@ SDXL_TOKENIZER_CONF_PATH = os.path.join(CONF_PATH, "tokenizers", "sdxl", "tokeni
 SDXL_TOKENIZER_2_CONF_PATH = os.path.join(CONF_PATH, "tokenizers", "sdxl", "tokenizer_2")
 WAN_TOKENIZER_CONF_PATH = os.path.join(CONF_PATH, "tokenizers", "wan", "umt5-xxl")
 QWEN_IMAGE_TOKENIZER_CONF_PATH = os.path.join(CONF_PATH, "tokenizers", "qwen_image", "tokenizer")
+QWEN_IMAGE_PROCESSOR_CONFIG_FILE = os.path.join(CONF_PATH, "tokenizers", "qwen_image", "qwen2_vl_image_processor.json")
+
 # models
 VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "components", "vae.json")
 FLUX_DIT_CONFIG_FILE = os.path.join(CONF_PATH, "models", "flux", "flux_dit.json")
@@ -46,3 +49,6 @@ KB = 1024
 MB = 1024 * KB
 GB = 1024 * MB
 TB = 1024 * GB
+
+OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
+OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711]