diffsynth-engine 0.4.3.dev9__py3-none-any.whl → 0.4.3.dev11__py3-none-any.whl

This diff shows the changes between publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the package contents as published in the public registry.
@@ -1,10 +1,11 @@
 import json
 import torch
+import torch.distributed as dist
 import math
 from typing import Callable, List, Tuple, Optional, Union, Dict
 from tqdm import tqdm
 from einops import rearrange
-import torch.distributed as dist
+from PIL import Image
 
 from diffsynth_engine.configs import QwenImagePipelineConfig, QwenImageStateDicts
 from diffsynth_engine.models.basic.lora import LoRAContext
@@ -16,13 +17,14 @@ from diffsynth_engine.models.qwen_image import (
     Qwen2_5_VLConfig,
 )
 from diffsynth_engine.models.qwen_image import QwenImageVAE
-from diffsynth_engine.tokenizers import Qwen2TokenizerFast
+from diffsynth_engine.tokenizers import Qwen2TokenizerFast, Qwen2VLProcessor
 from diffsynth_engine.pipelines import BasePipeline, LoRAStateDictConverter
 from diffsynth_engine.pipelines.utils import calculate_shift
 from diffsynth_engine.algorithm.noise_scheduler import RecifitedFlowScheduler
 from diffsynth_engine.algorithm.sampler import FlowMatchEulerSampler
 from diffsynth_engine.utils.constants import (
     QWEN_IMAGE_TOKENIZER_CONF_PATH,
+    QWEN_IMAGE_PROCESSOR_CONFIG_FILE,
     QWEN_IMAGE_CONFIG_FILE,
     QWEN_IMAGE_VISION_CONFIG_FILE,
     QWEN_IMAGE_VAE_CONFIG_FILE,
@@ -44,20 +46,23 @@ class QwenImageLoRAConverter(LoRAStateDictConverter):
             lora_a_suffix = None
             if "lora_A.default.weight" in key:
                 lora_a_suffix = "lora_A.default.weight"
+                lora_b_suffix = "lora_B.default.weight"
             elif "lora_A.weight" in key:
                 lora_a_suffix = "lora_A.weight"
+                lora_b_suffix = "lora_B.weight"
+            elif "lora_down.weight" in key:
+                lora_a_suffix = "lora_down.weight"
+                lora_b_suffix = "lora_up.weight"
 
             if lora_a_suffix is None:
                 continue
 
             lora_args = {}
             lora_args["down"] = param
-
-            lora_b_suffix = lora_a_suffix.replace("lora_A", "lora_B")
             lora_args["up"] = lora_state_dict[origin_key.replace(lora_a_suffix, lora_b_suffix)]
 
             lora_args["rank"] = lora_args["up"].shape[1]
-            alpha_key = origin_key.replace("lora_up", "lora_A").replace(lora_a_suffix, "alpha")
+            alpha_key = origin_key.replace(lora_a_suffix, "alpha")
 
             if alpha_key in lora_state_dict:
                 alpha = lora_state_dict[alpha_key]
@@ -83,6 +88,7 @@ class QwenImagePipeline(BasePipeline):
         self,
         config: QwenImagePipelineConfig,
         tokenizer: Qwen2TokenizerFast,
+        processor: Qwen2VLProcessor,
         encoder: Qwen2_5_VLForConditionalGeneration,
         dit: QwenImageDiT,
         vae: QwenImageVAE,
@@ -97,11 +103,15 @@ class QwenImagePipeline(BasePipeline):
         self.config = config
         self.prompt_template_encode = "<|im_start|>system\nDescribe the image by detailing the color, shape, size, texture, quantity, text, spatial relationships of the objects and background:<|im_end|>\n<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
         self.prompt_template_encode_start_idx = 34
+
+        self.edit_prompt_template_encode = "<|im_start|>system\nDescribe the key features of the input image (color, shape, size, texture, objects, background), then explain how the user's text instruction should alter or modify the image. Generate a new image that meets the user's requirements while maintaining consistency with the original input where appropriate.<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>{}<|im_end|>\n<|im_start|>assistant\n"
+        self.edit_prompt_template_encode_start_idx = 64
         # sampler
         self.noise_scheduler = RecifitedFlowScheduler(shift=3.0, use_dynamic_shifting=True)
         self.sampler = FlowMatchEulerSampler()
         # models
         self.tokenizer = tokenizer
+        self.processor = processor
         self.encoder = encoder
         self.dit = dit
         self.vae = vae
@@ -155,6 +165,10 @@ class QwenImagePipeline(BasePipeline):
 
         init_device = "cpu" if config.parallelism > 1 or config.offload_mode is not None else config.device
         tokenizer = Qwen2TokenizerFast.from_pretrained(QWEN_IMAGE_TOKENIZER_CONF_PATH)
+        processor = Qwen2VLProcessor.from_pretrained(
+            tokenizer_config_path=QWEN_IMAGE_TOKENIZER_CONF_PATH,
+            image_processor_config_path=QWEN_IMAGE_PROCESSOR_CONFIG_FILE,
+        )
         with open(QWEN_IMAGE_VISION_CONFIG_FILE, "r") as f:
             vision_config = Qwen2_5_VLVisionConfig(**json.load(f))
         with open(QWEN_IMAGE_CONFIG_FILE, "r") as f:
@@ -201,6 +215,7 @@ class QwenImagePipeline(BasePipeline):
         pipe = cls(
             config=config,
             tokenizer=tokenizer,
+            processor=processor,
             encoder=encoder,
             dit=dit,
             vae=vae,
@@ -209,7 +224,7 @@ class QwenImagePipeline(BasePipeline):
 
         if config.offload_mode is not None:
             pipe.enable_cpu_offload(config.offload_mode, config.offload_to_disk)
-
+
         if config.model_dtype == torch.float8_e4m3fn:
             pipe.dtype = torch.bfloat16  # compute dtype
             pipe.enable_fp8_autocast(
@@ -302,9 +317,51 @@ class QwenImagePipeline(BasePipeline):
 
         return prompt_embeds, prompt_embeds_mask
 
+    def encode_prompt_with_image(
+        self,
+        prompt: Union[str, List[str]],
+        image: torch.Tensor,
+        num_images_per_prompt: int = 1,
+        max_sequence_length: int = 1024,
+    ):
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        batch_size = len(prompt)
+        template = self.edit_prompt_template_encode
+        drop_idx = self.edit_prompt_template_encode_start_idx
+        texts = [template.format(txt) for txt in prompt]
+
+        model_inputs = self.processor(text=texts, images=image, max_length=max_sequence_length + drop_idx)
+        input_ids, attention_mask, pixel_values, image_grid_thw = (
+            model_inputs["input_ids"].to(self.device),
+            model_inputs["attention_mask"].to(self.device),
+            model_inputs["pixel_values"].to(self.device),
+            model_inputs["image_grid_thw"].to(self.device),
+        )
+        outputs = self.encoder(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            pixel_values=pixel_values,
+            image_grid_thw=image_grid_thw,
+        )
+        hidden_states = outputs["hidden_states"]
+        prompt_embeds = hidden_states[:, drop_idx:]
+        prompt_embeds_mask = attention_mask[:, drop_idx:]
+        seq_len = prompt_embeds.shape[1]
+
+        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+        prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)
+
+        return prompt_embeds, prompt_embeds_mask
+
     def predict_noise_with_cfg(
         self,
         latents: torch.Tensor,
+        image_latents: torch.Tensor,
         timestep: torch.Tensor,
         prompt_emb: torch.Tensor,
         negative_prompt_emb: torch.Tensor,
@@ -316,6 +373,7 @@ class QwenImagePipeline(BasePipeline):
         if cfg_scale <= 1.0 or negative_prompt_emb is None:
             return self.predict_noise(
                 latents,
+                image_latents,
                 timestep,
                 prompt_emb,
                 prompt_embeds_mask,
@@ -325,12 +383,14 @@ class QwenImagePipeline(BasePipeline):
            h, w = latents.shape[-2:]
            positive_noise_pred = self.predict_noise(
                latents,
+                image_latents,
                timestep,
                prompt_emb,
                prompt_embeds_mask,
            )
            negative_noise_pred = self.predict_noise(
                latents,
+                image_latents,
                timestep,
                negative_prompt_emb,
                negative_prompt_embeds_mask,
@@ -346,9 +406,11 @@ class QwenImagePipeline(BasePipeline):
            prompt_emb = torch.cat([prompt_emb, negative_prompt_emb], dim=0)
            prompt_embeds_mask = torch.cat([prompt_embeds_mask, negative_prompt_embeds_mask], dim=0)
            latents = torch.cat([latents, latents], dim=0)
+            image_latents = torch.cat([image_latents, image_latents], dim=0)
            timestep = torch.cat([timestep, timestep], dim=0)
            noise_pred = self.predict_noise(
                latents,
+                image_latents,
                timestep,
                prompt_emb,
                prompt_embeds_mask,
@@ -363,25 +425,49 @@ class QwenImagePipeline(BasePipeline):
     def predict_noise(
         self,
         latents: torch.Tensor,
+        image_latents: torch.Tensor,
         timestep: torch.Tensor,
         prompt_emb: torch.Tensor,
         prompt_embeds_mask: torch.Tensor,
     ):
         self.load_models_to_device(["dit"])
-
         noise_pred = self.dit(
             image=latents,
+            edit=image_latents,
             text=prompt_emb,
             timestep=timestep,
             txt_seq_lens=prompt_embeds_mask.sum(dim=1),
         )
         return noise_pred
 
+    def prepare_image_latents(self, input_image: Image.Image):
+        image = self.preprocess_image(input_image).to(
+            device=self.device, dtype=self.vae.model.encoder.conv1.weight.dtype
+        )
+        image = image.unsqueeze(2)
+        image_latents = self.vae.encode(
+            image,
+            device=self.device,
+            tiled=self.vae_tiled,
+            tile_size=self.vae_tile_size,
+            tile_stride=self.vae_tile_stride,
+        )
+        image_latents = image_latents.squeeze(2)
+        return image_latents
+
+    def calculate_dimensions(self, target_area, ratio):
+        width = math.sqrt(target_area * ratio)
+        height = width / ratio
+        width = round(width / 32) * 32
+        height = round(height / 32) * 32
+        return width, height
+
     @torch.no_grad()
     def __call__(
         self,
         prompt: str,
         negative_prompt: str = "",
+        input_image: Image.Image | None = None,  # use for img2img
         cfg_scale: float = 4.0,  # true cfg
         height: int = 1328,
         width: int = 1328,
@@ -389,29 +475,51 @@ class QwenImagePipeline(BasePipeline):
         seed: int | None = None,
         progress_callback: Optional[Callable] = None,  # def progress_callback(current, total, status)
     ):
+        if input_image is not None:
+            width, height = input_image.size
+            width, height = self.calculate_dimensions(1024 * 1024, width / height)
+            input_image = input_image.resize((width, height), Image.LANCZOS)
+
+        self.validate_image_size(height, width, minimum=64, multiple_of=16)
+
         noise = self.generate_noise((1, 16, height // 8, width // 8), seed=seed, device="cpu", dtype=self.dtype).to(
             device=self.device
         )
         # dynamic shift
         image_seq_len = math.ceil(height // 16) * math.ceil(width // 16)
         mu = calculate_shift(image_seq_len, max_shift=0.9, max_seq_len=8192)
+        if input_image:
+            image_latents = self.prepare_image_latents(input_image)
+        else:
+            image_latents = None
         init_latents, latents, sigmas, timesteps = self.prepare_latents(noise, num_inference_steps, mu)
         # Initialize sampler
         self.sampler.initialize(sigmas=sigmas)
 
         self.load_models_to_device(["encoder"])
-        prompt_embeds, prompt_embeds_mask = self.encode_prompt(prompt, 1, 4096)
-        if cfg_scale > 1.0 and negative_prompt != "":
-            negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(negative_prompt, 1, 4096)
+        if image_latents is not None:
+            prompt_embeds, prompt_embeds_mask = self.encode_prompt_with_image(prompt, input_image, 1, 4096)
+            if cfg_scale > 1.0 and negative_prompt != "":
+                negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt_with_image(
+                    negative_prompt, input_image, 1, 4096
+                )
+            else:
+                negative_prompt_embeds, negative_prompt_embeds_mask = None, None
         else:
-            negative_prompt_embeds, negative_prompt_embeds_mask = None, None
+            prompt_embeds, prompt_embeds_mask = self.encode_prompt(prompt, 1, 4096)
+            if cfg_scale > 1.0 and negative_prompt != "":
+                negative_prompt_embeds, negative_prompt_embeds_mask = self.encode_prompt(negative_prompt, 1, 4096)
+            else:
+                negative_prompt_embeds, negative_prompt_embeds_mask = None, None
         self.model_lifecycle_finish(["encoder"])
 
         hide_progress = dist.is_initialized() and dist.get_rank() != 0
+
         for i, timestep in enumerate(tqdm(timesteps, disable=hide_progress)):
             timestep = timestep.unsqueeze(0).to(dtype=self.dtype)
             noise_pred = self.predict_noise_with_cfg(
                 latents=latents,
+                image_latents=image_latents,
                 timestep=timestep,
                 prompt_emb=prompt_embeds,
                 negative_prompt_emb=negative_prompt_embeds,
@@ -431,12 +539,16 @@ class QwenImagePipeline(BasePipeline):
         latents = rearrange(latents, "B C H W -> B C 1 H W")
         vae_output = rearrange(
             self.vae.decode(
-                latents.to(self.vae.model.encoder.conv1.weight.dtype), device=self.vae.model.encoder.conv1.weight.device
+                latents.to(self.vae.model.encoder.conv1.weight.dtype),
+                device=self.vae.model.encoder.conv1.weight.device,
+                tiled=self.vae_tiled,
+                tile_size=self.vae_tile_size,
+                tile_stride=self.vae_tile_stride,
             )[0],
             "C B H W -> B C H W",
         )
         image = self.vae_output_to_image(vae_output)
         # Offload all models
-        self.model_lifecycle_finish(["vae"])
+        self.model_lifecycle_finish(["vae"])
         self.load_models_to_device([])
         return image
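
Note: the pipeline changes above add an image-editing (img2img) path. When `input_image` is passed to `__call__`, the image is rescaled to roughly a 1024x1024 pixel area (dimensions rounded to multiples of 32 by `calculate_dimensions`), encoded into `image_latents` through the VAE, and the prompt is encoded jointly with the image via the new `encode_prompt_with_image` / `Qwen2VLProcessor` path; the resulting latents are fed to the DiT as its `edit` input. A minimal usage sketch under assumed setup (import path, pipeline construction, and file names are illustrative, not taken from this diff):

    from PIL import Image
    from diffsynth_engine.configs import QwenImagePipelineConfig
    from diffsynth_engine.pipelines.qwen_image import QwenImagePipeline

    # Assumed construction; configure model paths as in the project's existing examples.
    config = QwenImagePipelineConfig(model_path="path/to/qwen_image_dit.safetensors")
    pipe = QwenImagePipeline.from_pretrained(config)

    # New in this version: passing input_image switches __call__ into the editing path.
    edited = pipe(
        prompt="make the sky look like a sunset",
        input_image=Image.open("photo.png"),  # resized internally to ~1024x1024 area
        num_inference_steps=30,
        seed=42,
    )
    edited.save("photo_edited.png")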
@@ -181,21 +181,21 @@ class SDImagePipeline(BasePipeline):
            raise ValueError("`model_path` cannot be empty")
        logger.info(f"loading state dict from {config.model_path} ...")
        state_dicts.model = cls.load_model_checkpoint(config.model_path, device="cpu", dtype=config.model_dtype)
-
+
        if state_dicts.vae is None:
            if config.vae_path is None:
                state_dicts.vae = state_dicts.model
            else:
                logger.info(f"loading state dict from {config.vae_path} ...")
                state_dicts.vae = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)
-
+
        if state_dicts.clip is None:
            if config.clip_path is None:
                state_dicts.clip = state_dicts.model
            else:
                logger.info(f"loading state dict from {config.clip_path} ...")
                state_dicts.clip = cls.load_model_checkpoint(config.clip_path, device="cpu", dtype=config.clip_dtype)
-
+
        init_device = "cpu" if config.offload_mode is not None else config.device
        tokenizer = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_CONF_PATH)
        with LoRAContext():
@@ -159,28 +159,32 @@ class SDXLImagePipeline(BasePipeline):
            raise ValueError("`model_path` cannot be empty")
        logger.info(f"loading state dict from {config.model_path} ...")
        state_dicts.model = cls.load_model_checkpoint(config.model_path, device="cpu", dtype=config.model_dtype)
-
+
        if state_dicts.vae is None:
            if config.vae_path is None:
                state_dicts.vae = state_dicts.model
            else:
                logger.info(f"loading state dict from {config.vae_path} ...")
                state_dicts.vae = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)
-
+
        if state_dicts.clip_l is None:
            if config.clip_l_path is None:
                state_dicts.clip_l = state_dicts.model
            else:
                logger.info(f"loading state dict from {config.clip_l_path} ...")
-                state_dicts.clip_l = cls.load_model_checkpoint(config.clip_l_path, device="cpu", dtype=config.clip_l_dtype)
-
+                state_dicts.clip_l = cls.load_model_checkpoint(
+                    config.clip_l_path, device="cpu", dtype=config.clip_l_dtype
+                )
+
        if state_dicts.clip_g is None:
            if config.clip_g_path is None:
                state_dicts.clip_g = state_dicts.model
            else:
                logger.info(f"loading state dict from {config.clip_g_path} ...")
-                state_dicts.clip_g = cls.load_model_checkpoint(config.clip_g_path, device="cpu", dtype=config.clip_g_dtype)
-
+                state_dicts.clip_g = cls.load_model_checkpoint(
+                    config.clip_g_path, device="cpu", dtype=config.clip_g_dtype
+                )
+
        init_device = "cpu" if config.offload_mode else config.device
        tokenizer = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_CONF_PATH)
        tokenizer_2 = CLIPTokenizer.from_pretrained(SDXL_TOKENIZER_2_CONF_PATH)
@@ -3,6 +3,8 @@ from .clip import CLIPTokenizer
 from .t5 import T5TokenizerFast
 from .wan import WanT5Tokenizer
 from .qwen2 import Qwen2TokenizerFast
+from .qwen2_vl_image_processor import Qwen2VLImageProcessor
+from .qwen2_vl_processor import Qwen2VLProcessor
 
 __all__ = [
     "BaseTokenizer",
@@ -10,4 +12,6 @@ __all__ = [
     "T5TokenizerFast",
     "WanT5Tokenizer",
     "Qwen2TokenizerFast",
+    "Qwen2VLImageProcessor",
+    "Qwen2VLProcessor",
 ]
@@ -0,0 +1,157 @@
+# modified from transformers.models.qwen2_vl.image_processing_qwen2_vl
+import os
+import json
+import logging
+import numpy as np
+from typing import List, Optional
+from PIL import Image
+
+from diffsynth_engine.utils.constants import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
+from diffsynth_engine.utils.image import (
+    ChannelDimension,
+    convert_to_rgb,
+    get_image_size,
+    infer_channel_dimension_format,
+    rescale_image,
+    resize_image,
+    smart_resize,
+    normalize_image,
+    to_channel_dimension_format,
+)
+
+logger = logging.getLogger(__name__)
+
+
+class Qwen2VLImageProcessor:
+    def __init__(
+        self,
+        do_resize: bool = True,
+        resample: Image.Resampling = Image.Resampling.BICUBIC,
+        do_rescale: bool = True,
+        rescale_factor: float = 1.0 / 255,
+        do_normalize: bool = True,
+        image_mean: List[float] = OPENAI_CLIP_MEAN,
+        image_std: List[float] = OPENAI_CLIP_STD,
+        do_convert_rgb: bool = True,
+        min_pixels: int = 56 * 56,
+        max_pixels: int = 28 * 28 * 1280,
+        patch_size: int = 14,
+        temporal_patch_size: int = 2,
+        merge_size: int = 2,
+        **kwargs,
+    ):
+        self.do_resize = do_resize
+        self.resample = resample
+        self.do_rescale = do_rescale
+        self.rescale_factor = rescale_factor
+        self.do_normalize = do_normalize
+        self.do_convert_rgb = do_convert_rgb
+        self.size = {"shortest_edge": min_pixels, "longest_edge": max_pixels}
+        self.image_mean = image_mean
+        self.image_std = image_std
+        self.patch_size = patch_size
+        self.merge_size = merge_size
+        self.min_pixels = min_pixels
+        self.max_pixels = max_pixels
+        self.temporal_patch_size = temporal_patch_size
+
+    @classmethod
+    def from_pretrained(cls, config_file_path: str | os.PathLike, **kwargs):
+        init_kwargs = {}
+        if not os.path.exists(config_file_path):
+            logger.warning(f"Cannot find {config_file_path}, init processor with default parameters")
+        else:
+            with open(config_file_path, "r", encoding="utf-8") as kwargs_handler:
+                init_kwargs = json.load(kwargs_handler)
+
+        init_kwargs.update(**kwargs)
+        return cls(**init_kwargs)
+
+    def __call__(
+        self,
+        images: Image.Image | List[Image.Image],
+        videos: Optional[List[List[Image.Image]]] = None,
+        data_format: Optional[ChannelDimension] = ChannelDimension.FIRST,
+    ):
+        pixel_values, image_grid_thws = None, None
+        if images is not None:
+            if isinstance(images, Image.Image):
+                images = [images]
+            pixel_values, image_grid_thws = [], []
+            for image in images:
+                flatten_patches, image_grid_thw = self._preprocess([image], data_format)
+                pixel_values.extend(flatten_patches)
+                image_grid_thws.append(image_grid_thw)
+            pixel_values = np.array(pixel_values)
+            image_grid_thws = np.array(image_grid_thws)
+
+        vision_pixel_values, vision_grid_thws = None, None
+        if videos is not None:
+            vision_pixel_values, vision_grid_thws = [], []
+            for images in videos:
+                flatten_patches, video_grid_thw = self._preprocess(images, data_format)
+                vision_pixel_values.append(flatten_patches)
+                vision_grid_thws.append(video_grid_thw)
+            vision_pixel_values = np.array(vision_pixel_values)
+            vision_grid_thws = np.array(vision_grid_thws)
+
+        return pixel_values, image_grid_thws, vision_pixel_values, vision_grid_thws
+
+    def _preprocess(self, images: List[Image.Image], data_format: Optional[ChannelDimension] = ChannelDimension.FIRST):
+        images = [convert_to_rgb(image) for image in images]
+        image_nps = [np.array(image) for image in images]
+        input_data_format = infer_channel_dimension_format(image_nps[0])
+        height, width = get_image_size(image_nps[0], input_data_format)
+        resized_height, resized_width = height, width
+
+        processed_image_nps = []
+        for image_np in image_nps:
+            if self.do_resize:
+                resized_height, resized_width = smart_resize(
+                    height,
+                    width,
+                    factor=self.patch_size * self.merge_size,
+                    min_pixels=self.min_pixels,
+                    max_pixels=self.max_pixels,
+                )
+                image_np = resize_image(
+                    image_np, resized_height, resized_width, self.resample, input_data_format=input_data_format
+                )
+
+            if self.do_rescale:
+                image_np = rescale_image(image_np, self.rescale_factor)
+
+            if self.do_normalize:
+                image_np = normalize_image(
+                    image_np, self.image_mean, self.image_std, input_data_format=input_data_format
+                )
+            image_np = to_channel_dimension_format(image_np, data_format, input_data_format)
+            processed_image_nps.append(image_np)
+
+        patches = np.array(processed_image_nps)
+        if data_format == ChannelDimension.LAST:
+            patches = patches.transpose(0, 3, 1, 2)
+        if patches.shape[0] % self.temporal_patch_size != 0:
+            repeats = np.repeat(patches[-1][np.newaxis], self.temporal_patch_size - 1, axis=0)
+            patches = np.concatenate([patches, repeats], axis=0)
+        num_channel = patches.shape[1]
+        grid_t = patches.shape[0] // self.temporal_patch_size
+        grid_h = resized_height // self.patch_size
+        grid_w = resized_width // self.patch_size
+        patches = patches.reshape(
+            grid_t,
+            self.temporal_patch_size,
+            num_channel,
+            grid_h // self.merge_size,
+            self.merge_size,
+            self.patch_size,
+            grid_w // self.merge_size,
+            self.merge_size,
+            self.patch_size,
+        )
+        patches = patches.transpose(0, 3, 6, 4, 7, 2, 1, 5, 8)
+        flatten_patches = patches.reshape(
+            grid_t * grid_h * grid_w, num_channel * self.temporal_patch_size * self.patch_size * self.patch_size
+        )
+
+        return flatten_patches, (grid_t, grid_h, grid_w)
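
Note: the reshape above follows the upstream Qwen2-VL image processor this file is modified from: images are resized so height and width are multiples of patch_size * merge_size (14 * 2 = 28), cut into 14x14 patches, and a single image is duplicated to fill temporal_patch_size = 2. The returned grid (grid_t, grid_h, grid_w) later determines how many <|image_pad|> tokens the processor expands to (grid product divided by merge_size ** 2). A small arithmetic sketch with the defaults (illustrative, not part of the package):

    patch_size, merge_size, temporal_patch_size = 14, 2, 2

    resized_height, resized_width = 1036, 1036      # already multiples of 28
    grid_t = 1                                      # one image, repeated to fill temporal_patch_size
    grid_h = resized_height // patch_size           # 74
    grid_w = resized_width // patch_size            # 74

    num_patches = grid_t * grid_h * grid_w                          # 5476 flattened patches
    patch_dim = 3 * temporal_patch_size * patch_size * patch_size   # 1176 values per patch
    num_image_tokens = num_patches // merge_size ** 2               # 1369 <|image_pad|> tokens after 2x2 merging
    print(num_patches, patch_dim, num_image_tokens)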
@@ -0,0 +1,100 @@
+import os
+import re
+import torch
+import logging
+from PIL import Image
+from typing import List, Dict, Optional
+
+from diffsynth_engine.tokenizers.qwen2_vl_image_processor import Qwen2VLImageProcessor
+from diffsynth_engine.tokenizers.qwen2 import Qwen2TokenizerFast
+
+logger = logging.getLogger(__name__)
+
+
+class Qwen2VLProcessor:
+    def __init__(
+        self,
+        tokenizer: Qwen2TokenizerFast,
+        image_processor: Qwen2VLImageProcessor,
+        image_token: str = "<|image_pad|>",
+        **kwargs,
+    ):
+        self.tokenizer = tokenizer
+        self.image_processor = image_processor
+        self.image_token = image_token
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        tokenizer_config_path: str | os.PathLike,
+        image_processor_config_path: str | os.PathLike,
+        **kwargs,
+    ):
+        tokenizer = Qwen2TokenizerFast.from_pretrained(tokenizer_config_path)
+        image_processor = Qwen2VLImageProcessor.from_pretrained(image_processor_config_path)
+        return cls(tokenizer=tokenizer, image_processor=image_processor, **kwargs)
+
+    def batch_decode(
+        self,
+        ids: List[List[int]] | List[torch.Tensor],
+        skip_special_tokens: bool = False,
+        clean_up_tokenization_spaces: Optional[bool] = None,
+    ):
+        if isinstance(ids[0], torch.Tensor):
+            ids = [id_.tolist() for id_ in ids]
+        decoded = self.tokenizer.batch_decode(ids, skip_special_tokens, clean_up_tokenization_spaces)
+        pattern = r"<\|vision_start\|>.*?<\|vision_end\|>"
+        decoded_with_image_tag = [re.sub(pattern, "<image>", d, flags=re.DOTALL) for d in decoded]
+        decoded_with_image_tag = [re.sub(r"<\|im_end\|>", "", d) for d in decoded_with_image_tag]
+        return decoded_with_image_tag
+
+    def __call__(
+        self,
+        text: str | List[str],
+        images: Optional[List[Image.Image]] = None,
+        videos: Optional[List[List[Image.Image]]] = None,
+        max_length: Optional[int] = None,
+    ) -> Dict[str, torch.Tensor]:
+        """
+        Main method to prepare one or several sequence(s) and image(s) for the model. This method forwards the `text`
+        and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
+        the text. To prepare the vision inputs, this method forwards the `vision_infos` and `kwargs` arguments to
+        Qwen2VLImageProcessor's [`~Qwen2VLImageProcessor.__call__`] if `vision_infos` is not `None`.
+
+        Args:
+            text (`List[str]`):
+                The sequence or batch of sequences to be encoded.
+            images (`List[PIL.Image.Image]`):
+                The batch of images to be prepared.
+            videos (`List[List[PIL.Image.Image]]`):
+                The batch of videos to be prepared.
+        """
+        images_pixel_values, images_grid_thws, video_pixels_values, video_grid_thws = self.image_processor(
+            images, videos
+        )
+
+        if not isinstance(text, list):
+            text = [text]
+        if images_grid_thws is not None:
+            merge_length = self.image_processor.merge_size**2
+            index = 0
+            for i in range(len(text)):
+                while self.image_token in text[i]:
+                    text[i] = text[i].replace(
+                        self.image_token, "<|placeholder|>" * (images_grid_thws[index].prod() // merge_length), 1
+                    )
+                    index += 1
+                text[i] = text[i].replace("<|placeholder|>", self.image_token)
+        text_inputs = self.tokenizer(text, max_length=max_length)
+
+        processed_inputs = text_inputs
+        if images_pixel_values is not None:
+            processed_inputs["pixel_values"] = torch.from_numpy(images_pixel_values)
+        if images_grid_thws is not None:
+            processed_inputs["image_grid_thw"] = torch.from_numpy(images_grid_thws)
+        if video_pixels_values is not None:
+            processed_inputs["pixel_values_videos"] = video_pixels_values
+        if video_grid_thws is not None:
+            processed_inputs["video_grid_thw"] = video_grid_thws
+
+        return processed_inputs
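
Note: Qwen2VLProcessor pairs the existing Qwen2TokenizerFast with the new Qwen2VLImageProcessor: each <|image_pad|> placeholder in the text is expanded to one token per merged image patch before tokenization, and the pixel values and grid shapes are returned alongside the token ids. A standalone sketch using the bundled config paths (the image file name is a placeholder):

    from PIL import Image
    from diffsynth_engine.tokenizers import Qwen2VLProcessor
    from diffsynth_engine.utils.constants import (
        QWEN_IMAGE_TOKENIZER_CONF_PATH,
        QWEN_IMAGE_PROCESSOR_CONFIG_FILE,
    )

    processor = Qwen2VLProcessor.from_pretrained(
        tokenizer_config_path=QWEN_IMAGE_TOKENIZER_CONF_PATH,
        image_processor_config_path=QWEN_IMAGE_PROCESSOR_CONFIG_FILE,
    )

    # The editing template in QwenImagePipeline wraps the image like this before the instruction.
    text = "<|vision_start|><|image_pad|><|vision_end|>make the sky look like a sunset"
    inputs = processor(text=[text], images=Image.open("photo.png"), max_length=1024)
    # inputs contains input_ids, attention_mask, pixel_values and image_grid_thw tensors.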
@@ -5,6 +5,7 @@ REPO_ROOT = os.path.dirname(PACKAGE_ROOT)
 
 # conf
 CONF_PATH = os.path.join(PACKAGE_ROOT, "conf")
+
 # tokenizers
 FLUX_TOKENIZER_1_CONF_PATH = os.path.join(CONF_PATH, "tokenizers", "flux", "tokenizer_1")
 FLUX_TOKENIZER_2_CONF_PATH = os.path.join(CONF_PATH, "tokenizers", "flux", "tokenizer_2")
@@ -12,6 +13,8 @@ SDXL_TOKENIZER_CONF_PATH = os.path.join(CONF_PATH, "tokenizers", "sdxl", "tokeni
 SDXL_TOKENIZER_2_CONF_PATH = os.path.join(CONF_PATH, "tokenizers", "sdxl", "tokenizer_2")
 WAN_TOKENIZER_CONF_PATH = os.path.join(CONF_PATH, "tokenizers", "wan", "umt5-xxl")
 QWEN_IMAGE_TOKENIZER_CONF_PATH = os.path.join(CONF_PATH, "tokenizers", "qwen_image", "tokenizer")
+QWEN_IMAGE_PROCESSOR_CONFIG_FILE = os.path.join(CONF_PATH, "tokenizers", "qwen_image", "qwen2_vl_image_processor.json")
+
 # models
 VAE_CONFIG_FILE = os.path.join(CONF_PATH, "models", "components", "vae.json")
 FLUX_DIT_CONFIG_FILE = os.path.join(CONF_PATH, "models", "flux", "flux_dit.json")
@@ -46,3 +49,6 @@ KB = 1024
 MB = 1024 * KB
 GB = 1024 * MB
 TB = 1024 * GB
+
+OPENAI_CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073]
+OPENAI_CLIP_STD = [0.26862954, 0.26130258, 0.27577711]