cache-dit 0.2.33__py3-none-any.whl → 0.2.36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +5 -3
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/block_adapters/__init__.py +14 -20
- cache_dit/cache_factory/block_adapters/block_adapters.py +46 -2
- cache_dit/cache_factory/block_adapters/block_registers.py +3 -2
- cache_dit/cache_factory/cache_adapters.py +8 -8
- cache_dit/cache_factory/cache_contexts/cache_context.py +11 -11
- cache_dit/cache_factory/cache_contexts/cache_manager.py +5 -5
- cache_dit/cache_factory/cache_contexts/taylorseer.py +12 -6
- cache_dit/cache_factory/cache_interface.py +9 -9
- cache_dit/cache_factory/patch_functors/__init__.py +1 -0
- cache_dit/cache_factory/patch_functors/functor_chroma.py +142 -52
- cache_dit/cache_factory/patch_functors/functor_dit.py +130 -0
- cache_dit/metrics/clip_score.py +135 -0
- cache_dit/metrics/fid.py +42 -0
- cache_dit/metrics/image_reward.py +177 -0
- cache_dit/metrics/lpips.py +2 -14
- cache_dit/metrics/metrics.py +420 -76
- cache_dit/utils.py +15 -0
- {cache_dit-0.2.33.dist-info → cache_dit-0.2.36.dist-info}/METADATA +261 -52
- {cache_dit-0.2.33.dist-info → cache_dit-0.2.36.dist-info}/RECORD +25 -22
- {cache_dit-0.2.33.dist-info → cache_dit-0.2.36.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.33.dist-info → cache_dit-0.2.36.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.33.dist-info → cache_dit-0.2.36.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.33.dist-info → cache_dit-0.2.36.dist-info}/top_level.txt +0 -0
|
@@ -1,10 +1,9 @@
|
|
|
1
|
-
import inspect
|
|
2
|
-
|
|
3
1
|
import torch
|
|
4
2
|
import numpy as np
|
|
5
3
|
from typing import Tuple, Optional, Dict, Any, Union
|
|
6
4
|
from diffusers import ChromaTransformer2DModel
|
|
7
5
|
from diffusers.models.transformers.transformer_chroma import (
|
|
6
|
+
ChromaTransformerBlock,
|
|
8
7
|
ChromaSingleTransformerBlock,
|
|
9
8
|
Transformer2DModelOutput,
|
|
10
9
|
)
|
|
@@ -27,24 +26,31 @@ class ChromaPatchFunctor(PatchFunctor):
|
|
|
27
26
|
def apply(
|
|
28
27
|
self,
|
|
29
28
|
transformer: ChromaTransformer2DModel,
|
|
30
|
-
blocks: torch.nn.ModuleList = None,
|
|
31
29
|
**kwargs,
|
|
32
30
|
) -> ChromaTransformer2DModel:
|
|
33
31
|
if hasattr(transformer, "_is_patched"):
|
|
34
32
|
return transformer
|
|
35
33
|
|
|
36
|
-
if blocks is None:
|
|
37
|
-
blocks = transformer.single_transformer_blocks
|
|
38
|
-
|
|
39
34
|
is_patched = False
|
|
40
|
-
for block in
|
|
41
|
-
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
35
|
+
for index_block, block in enumerate(transformer.transformer_blocks):
|
|
36
|
+
assert isinstance(block, ChromaTransformerBlock)
|
|
37
|
+
img_offset = 3 * len(transformer.single_transformer_blocks)
|
|
38
|
+
txt_offset = img_offset + 6 * len(transformer.transformer_blocks)
|
|
39
|
+
img_modulation = img_offset + 6 * index_block
|
|
40
|
+
text_modulation = txt_offset + 6 * index_block
|
|
41
|
+
block._img_modulation = img_modulation
|
|
42
|
+
block._text_modulation = text_modulation
|
|
43
|
+
block.forward = __patch_double_forward__.__get__(block)
|
|
44
|
+
|
|
45
|
+
for index_block, block in enumerate(
|
|
46
|
+
transformer.single_transformer_blocks
|
|
47
|
+
):
|
|
48
|
+
assert isinstance(block, ChromaSingleTransformerBlock)
|
|
49
|
+
start_idx = 3 * index_block
|
|
50
|
+
block._start_idx = start_idx
|
|
51
|
+
block.forward = __patch_single_forward__.__get__(block)
|
|
52
|
+
|
|
53
|
+
is_patched = True
|
|
48
54
|
|
|
49
55
|
cls_name = transformer.__class__.__name__
|
|
50
56
|
|
|
@@ -69,25 +75,123 @@ class ChromaPatchFunctor(PatchFunctor):
|
|
|
69
75
|
return transformer
|
|
70
76
|
|
|
71
77
|
|
|
78
|
+
# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_chroma.py
def __patch_double_forward__(
    self: ChromaTransformerBlock,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    pooled_temb: torch.Tensor,
    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    attention_mask: Optional[torch.Tensor] = None,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Patched ``forward`` for a ``ChromaTransformerBlock`` (dual-stream block).

    The block receives the full ``pooled_temb`` and slices its own 6 image
    and 6 text modulation rows (along dim 1) at the offsets stored in
    ``self._img_modulation`` / ``self._text_modulation``, which are assigned
    by ``ChromaPatchFunctor.apply``.

    Args:
        hidden_states: Image-stream tokens.
        encoder_hidden_states: Text-stream tokens.
        pooled_temb: Pooled modulation tensor shared by all blocks.
        image_rotary_emb: Rotary embeddings forwarded to ``self.attn``.
        attention_mask: Optional mask. When given, it is expanded to a pairwise
            mask before being passed to ``self.attn``.
        joint_attention_kwargs: Extra keyword arguments for ``self.attn``.

    Returns:
        ``(encoder_hidden_states, hidden_states)`` after attention and
        feed-forward updates.

    Raises:
        ValueError: If ``self.attn`` returns neither 2 nor 3 outputs.
    """
    # TODO: Fuse controlnet into block forward
    img_modulation = self._img_modulation
    text_modulation = self._text_modulation
    # Gather this block's 6 image rows and 6 text rows of the pooled
    # modulation along dim 1, then split them back into the two streams.
    temb = torch.cat(
        (
            pooled_temb[:, img_modulation : img_modulation + 6],
            pooled_temb[:, text_modulation : text_modulation + 6],
        ),
        dim=1,
    )
    temb_img, temb_txt = temb[:, :6], temb[:, 6:]

    norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
        hidden_states, emb=temb_img
    )
    (
        norm_encoder_hidden_states,
        c_gate_msa,
        c_shift_mlp,
        c_scale_mlp,
        c_gate_mlp,
    ) = self.norm1_context(encoder_hidden_states, emb=temb_txt)

    joint_attention_kwargs = joint_attention_kwargs or {}
    if attention_mask is not None:
        # Outer product of the mask with itself, so entry [b, 0, i, j] is
        # mask[b, i] * mask[b, j]. Assumes a (batch, seq) mask — TODO confirm.
        attention_mask = (
            attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
        )

    # Attention.
    attention_outputs = self.attn(
        hidden_states=norm_hidden_states,
        encoder_hidden_states=norm_encoder_hidden_states,
        image_rotary_emb=image_rotary_emb,
        attention_mask=attention_mask,
        **joint_attention_kwargs,
    )

    # The processor returns (image, text) outputs, plus a third IP-adapter
    # output presumably when ``ip_hidden_states`` are supplied. Anything
    # else is rejected up front instead of failing later on an unbound name.
    ip_attn_output = None
    if len(attention_outputs) == 2:
        attn_output, context_attn_output = attention_outputs
    elif len(attention_outputs) == 3:
        attn_output, context_attn_output, ip_attn_output = attention_outputs
    else:
        raise ValueError(
            f"Expected 2 or 3 attention outputs, got {len(attention_outputs)}."
        )

    # Process attention outputs for the `hidden_states`.
    attn_output = gate_msa.unsqueeze(1) * attn_output
    hidden_states = hidden_states + attn_output

    norm_hidden_states = self.norm2(hidden_states)
    norm_hidden_states = (
        norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
    )

    ff_output = self.ff(norm_hidden_states)
    ff_output = gate_mlp.unsqueeze(1) * ff_output

    hidden_states = hidden_states + ff_output
    if ip_attn_output is not None:
        hidden_states = hidden_states + ip_attn_output

    # Process attention outputs for the `encoder_hidden_states`.
    context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
    encoder_hidden_states = encoder_hidden_states + context_attn_output

    norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
    norm_encoder_hidden_states = (
        norm_encoder_hidden_states * (1 + c_scale_mlp[:, None])
        + c_shift_mlp[:, None]
    )

    context_ff_output = self.ff_context(norm_encoder_hidden_states)
    encoder_hidden_states = (
        encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
    )
    # Keep fp16 text tokens inside the finite float16 range.
    if encoder_hidden_states.dtype == torch.float16:
        encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)

    return encoder_hidden_states, hidden_states
|
|
166
|
+
|
|
167
|
+
|
|
72
168
|
# adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_chroma.py
|
|
73
169
|
def __patch_single_forward__(
|
|
74
170
|
self: ChromaSingleTransformerBlock, # Almost same as FluxSingleTransformerBlock
|
|
75
171
|
hidden_states: torch.Tensor,
|
|
76
|
-
|
|
77
|
-
temb: torch.Tensor,
|
|
172
|
+
pooled_temb: torch.Tensor,
|
|
78
173
|
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
|
174
|
+
attention_mask: Optional[torch.Tensor] = None,
|
|
79
175
|
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
|
|
80
|
-
) ->
|
|
81
|
-
|
|
82
|
-
|
|
176
|
+
) -> torch.Tensor:
|
|
177
|
+
# TODO: Fuse controlnet into block forward
|
|
178
|
+
start_idx = self._start_idx
|
|
179
|
+
temb = pooled_temb[:, start_idx : start_idx + 3]
|
|
83
180
|
|
|
84
181
|
residual = hidden_states
|
|
85
182
|
norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
|
|
86
183
|
mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
|
|
87
184
|
joint_attention_kwargs = joint_attention_kwargs or {}
|
|
185
|
+
|
|
186
|
+
if attention_mask is not None:
|
|
187
|
+
attention_mask = (
|
|
188
|
+
attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
|
|
189
|
+
)
|
|
190
|
+
|
|
88
191
|
attn_output = self.attn(
|
|
89
192
|
hidden_states=norm_hidden_states,
|
|
90
193
|
image_rotary_emb=image_rotary_emb,
|
|
194
|
+
attention_mask=attention_mask,
|
|
91
195
|
**joint_attention_kwargs,
|
|
92
196
|
)
|
|
93
197
|
|
|
@@ -98,11 +202,7 @@ def __patch_single_forward__(
|
|
|
98
202
|
if hidden_states.dtype == torch.float16:
|
|
99
203
|
hidden_states = hidden_states.clip(-65504, 65504)
|
|
100
204
|
|
|
101
|
-
|
|
102
|
-
hidden_states[:, :text_seq_len],
|
|
103
|
-
hidden_states[:, text_seq_len:],
|
|
104
|
-
)
|
|
105
|
-
return encoder_hidden_states, hidden_states
|
|
205
|
+
return hidden_states
|
|
106
206
|
|
|
107
207
|
|
|
108
208
|
# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_chroma.py
|
|
@@ -174,24 +274,13 @@ def __patch_transformer_forward__(
|
|
|
174
274
|
joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
|
|
175
275
|
|
|
176
276
|
for index_block, block in enumerate(self.transformer_blocks):
|
|
177
|
-
img_offset = 3 * len(self.single_transformer_blocks)
|
|
178
|
-
txt_offset = img_offset + 6 * len(self.transformer_blocks)
|
|
179
|
-
img_modulation = img_offset + 6 * index_block
|
|
180
|
-
text_modulation = txt_offset + 6 * index_block
|
|
181
|
-
temb = torch.cat(
|
|
182
|
-
(
|
|
183
|
-
pooled_temb[:, img_modulation : img_modulation + 6],
|
|
184
|
-
pooled_temb[:, text_modulation : text_modulation + 6],
|
|
185
|
-
),
|
|
186
|
-
dim=1,
|
|
187
|
-
)
|
|
188
277
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
|
189
278
|
encoder_hidden_states, hidden_states = (
|
|
190
279
|
self._gradient_checkpointing_func(
|
|
191
280
|
block,
|
|
192
281
|
hidden_states,
|
|
193
282
|
encoder_hidden_states,
|
|
194
|
-
|
|
283
|
+
pooled_temb,
|
|
195
284
|
image_rotary_emb,
|
|
196
285
|
attention_mask,
|
|
197
286
|
)
|
|
@@ -201,12 +290,13 @@ def __patch_transformer_forward__(
|
|
|
201
290
|
encoder_hidden_states, hidden_states = block(
|
|
202
291
|
hidden_states=hidden_states,
|
|
203
292
|
encoder_hidden_states=encoder_hidden_states,
|
|
204
|
-
|
|
293
|
+
pooled_temb=pooled_temb,
|
|
205
294
|
image_rotary_emb=image_rotary_emb,
|
|
206
295
|
attention_mask=attention_mask,
|
|
207
296
|
joint_attention_kwargs=joint_attention_kwargs,
|
|
208
297
|
)
|
|
209
298
|
|
|
299
|
+
# TODO: Fuse controlnet into block forward
|
|
210
300
|
# controlnet residual
|
|
211
301
|
if controlnet_block_samples is not None:
|
|
212
302
|
interval_control = len(self.transformer_blocks) / len(
|
|
@@ -227,43 +317,43 @@ def __patch_transformer_forward__(
|
|
|
227
317
|
+ controlnet_block_samples[index_block // interval_control]
|
|
228
318
|
)
|
|
229
319
|
|
|
320
|
+
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
|
321
|
+
|
|
230
322
|
for index_block, block in enumerate(self.single_transformer_blocks):
|
|
231
|
-
start_idx = 3 * index_block
|
|
232
|
-
temb = pooled_temb[:, start_idx : start_idx + 3]
|
|
233
323
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
)
|
|
324
|
+
hidden_states = self._gradient_checkpointing_func(
|
|
325
|
+
block,
|
|
326
|
+
hidden_states,
|
|
327
|
+
pooled_temb,
|
|
328
|
+
image_rotary_emb,
|
|
329
|
+
attention_mask,
|
|
330
|
+
joint_attention_kwargs,
|
|
242
331
|
)
|
|
243
332
|
|
|
244
333
|
else:
|
|
245
|
-
|
|
334
|
+
hidden_states = block(
|
|
246
335
|
hidden_states=hidden_states,
|
|
247
|
-
|
|
248
|
-
temb=temb,
|
|
336
|
+
pooled_temb=pooled_temb,
|
|
249
337
|
image_rotary_emb=image_rotary_emb,
|
|
250
338
|
attention_mask=attention_mask,
|
|
251
339
|
joint_attention_kwargs=joint_attention_kwargs,
|
|
252
340
|
)
|
|
253
341
|
|
|
342
|
+
# TODO: Fuse controlnet into block forward
|
|
254
343
|
# controlnet residual
|
|
255
344
|
if controlnet_single_block_samples is not None:
|
|
256
345
|
interval_control = len(self.single_transformer_blocks) / len(
|
|
257
346
|
controlnet_single_block_samples
|
|
258
347
|
)
|
|
259
348
|
interval_control = int(np.ceil(interval_control))
|
|
260
|
-
hidden_states = (
|
|
261
|
-
hidden_states
|
|
349
|
+
hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
|
|
350
|
+
hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
|
262
351
|
+ controlnet_single_block_samples[
|
|
263
352
|
index_block // interval_control
|
|
264
353
|
]
|
|
265
354
|
)
|
|
266
355
|
|
|
356
|
+
hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
|
|
267
357
|
temb = pooled_temb[:, -2:]
|
|
268
358
|
hidden_states = self.norm_out(hidden_states, temb)
|
|
269
359
|
output = self.proj_out(hidden_states)
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.nn.functional as F
|
|
3
|
+
|
|
4
|
+
from typing import Optional, Dict, Any
|
|
5
|
+
from diffusers.models.transformers.dit_transformer_2d import (
|
|
6
|
+
DiTTransformer2DModel,
|
|
7
|
+
Transformer2DModelOutput,
|
|
8
|
+
)
|
|
9
|
+
from cache_dit.cache_factory.patch_functors.functor_base import (
|
|
10
|
+
PatchFunctor,
|
|
11
|
+
)
|
|
12
|
+
from cache_dit.logger import init_logger
|
|
13
|
+
|
|
14
|
+
logger = init_logger(__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class DiTPatchFunctor(PatchFunctor):
    """Patch functor that prepares a ``DiTTransformer2DModel`` for cache-dit.

    Applying it stores the first transformer block's ``norm1.emb`` module on
    the transformer as ``_norm1_emb`` and rebinds ``forward`` to
    ``__patch_transformer_forward__``, which uses that stored module for the
    output conditioning.
    """

    def apply(
        self,
        transformer: DiTTransformer2DModel,
        **kwargs,
    ) -> DiTTransformer2DModel:
        """Patch ``transformer`` in place and return it.

        A transformer that already carries an ``_is_patched`` attribute is
        returned unchanged.

        Raises:
            AssertionError: If the transformer is flagged as
                ``_is_parallelized``.
        """
        # Guard clause: never patch the same module twice.
        if hasattr(transformer, "_is_patched"):
            return transformer

        # Direct handle to the embedding module of the first block, read by
        # the patched forward.
        first_block = transformer.transformer_blocks[0]
        transformer._norm1_emb = first_block.norm1.emb
        patched = True

        name = transformer.__class__.__name__

        if patched:
            logger.warning(f"Patched {name} for cache-dit.")
            already_parallel = getattr(transformer, "_is_parallelized", False)
            assert not already_parallel, (
                "Please call `cache_dit.enable_cache` before Parallelize, "
                "the __patch_transformer_forward__ will overwrite the "
                "parallized forward and cause a downgrade of performance."
            )
            bound_forward = __patch_transformer_forward__.__get__(transformer)
            transformer.forward = bound_forward

        transformer._is_patched = patched  # True or False

        functor_name = self.__class__.__name__
        logger.info(f"Applied {functor_name} for {name}, " f"Patch: {patched}.")

        return transformer
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def __patch_transformer_forward__(
    self: DiTTransformer2DModel,
    hidden_states: torch.Tensor,
    timestep: Optional[torch.LongTensor] = None,
    class_labels: Optional[torch.LongTensor] = None,
    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    return_dict: bool = True,
):
    """Replacement ``forward`` for ``DiTTransformer2DModel`` used by cache-dit.

    It follows the stock forward, except that the output conditioning comes
    from ``self._norm1_emb`` (installed by ``DiTPatchFunctor``) rather than
    ``self.transformer_blocks[0].norm1.emb``.

    Args:
        hidden_states: Latent input. Its last two dims are divided by
            ``self.patch_size`` to get the patch grid size.
        timestep: Passed to every block and to ``self._norm1_emb``.
        class_labels: Passed to every block and to ``self._norm1_emb``.
        cross_attention_kwargs: Forwarded to every block.
        return_dict: When False, return the 1-tuple ``(output,)``.

    Returns:
        ``Transformer2DModelOutput(sample=output)`` (or ``(output,)``), where
        ``output`` is reshaped to
        ``(-1, out_channels, height * patch_size, width * patch_size)``.
    """
    # 1. Input: patch-grid size, taken before patch embedding flattens the
    # spatial dimensions.
    height, width = (
        hidden_states.shape[-2] // self.patch_size,
        hidden_states.shape[-1] // self.patch_size,
    )
    hidden_states = self.pos_embed(hidden_states)

    # 2. Blocks
    for block in self.transformer_blocks:
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            # Positional order: hidden_states, attention_mask,
            # encoder_hidden_states, encoder_attention_mask, timestep,
            # cross_attention_kwargs, class_labels.
            hidden_states = self._gradient_checkpointing_func(
                block,
                hidden_states,
                None,
                None,
                None,
                timestep,
                cross_attention_kwargs,
                class_labels,
            )
        else:
            hidden_states = block(
                hidden_states,
                attention_mask=None,
                encoder_hidden_states=None,
                encoder_attention_mask=None,
                timestep=timestep,
                cross_attention_kwargs=cross_attention_kwargs,
                class_labels=class_labels,
            )

    # 3. Output
    # Stock code: self.transformer_blocks[0].norm1.emb(timestep, class_labels, ...)
    conditioning = self._norm1_emb(
        timestep, class_labels, hidden_dtype=hidden_states.dtype
    )
    shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
    hidden_states = (
        self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
    )
    hidden_states = self.proj_out_2(hidden_states)

    # unpatchify
    # Reuse the grid size computed from the input. Re-deriving it as
    # int(sqrt(num_patches)) is only valid for square latents and breaks the
    # reshape below for non-square inputs; for square inputs both agree.
    hidden_states = hidden_states.reshape(
        shape=(
            -1,
            height,
            width,
            self.patch_size,
            self.patch_size,
            self.out_channels,
        )
    )
    hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
    output = hidden_states.reshape(
        shape=(
            -1,
            self.out_channels,
            height * self.patch_size,
            width * self.patch_size,
        )
    )

    if not return_dict:
        return (output,)

    return Transformer2DModelOutput(sample=output)
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import re
|
|
3
|
+
import pathlib
|
|
4
|
+
import numpy as np
|
|
5
|
+
from PIL import Image
|
|
6
|
+
from tqdm import tqdm
|
|
7
|
+
import torch
|
|
8
|
+
from transformers import CLIPProcessor, CLIPModel
|
|
9
|
+
|
|
10
|
+
from typing import Tuple, Union
|
|
11
|
+
from cache_dit.metrics.config import _IMAGE_EXTENSIONS
|
|
12
|
+
from cache_dit.metrics.config import get_metrics_verbose
|
|
13
|
+
from cache_dit.logger import init_logger
|
|
14
|
+
|
|
15
|
+
logger = init_logger(__name__)
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
DISABLE_VERBOSE = not get_metrics_verbose()
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class CLIPScore:
    """Score image/prompt pairs with a Hugging Face CLIP model.

    Args:
        device: Device the model and processed inputs are moved to. Defaults
            to "cuda" when available, otherwise "cpu".
        clip_model_path: Model id or local path for
            ``CLIPModel.from_pretrained`` / ``CLIPProcessor.from_pretrained``.
            When None, the ``CLIP_MODEL_DIR`` environment variable is used,
            falling back to "laion/CLIP-ViT-g-14-laion2B-s12B-b42K".
    """

    def __init__(
        self,
        device="cuda" if torch.cuda.is_available() else "cpu",
        clip_model_path: str = None,
    ):
        self.device = device
        if clip_model_path is None:
            clip_model_path = os.environ.get(
                "CLIP_MODEL_DIR", "laion/CLIP-ViT-g-14-laion2B-s12B-b42K"
            )

        # Load models
        self.clip_model = CLIPModel.from_pretrained(clip_model_path)
        self.clip_model = self.clip_model.to(device)  # type: ignore
        self.clip_processor = CLIPProcessor.from_pretrained(clip_model_path)

    @torch.no_grad()
    def compute_clip_score(
        self,
        img: Image.Image | np.ndarray | str,
        prompt: str,
    ) -> float:
        """Return CLIP's ``logits_per_image`` for one image and one prompt.

        Args:
            img: A PIL image, a numpy array accepted by ``Image.fromarray``,
                or a path opened with ``Image.open``. All are converted to RGB.
            prompt: Text prompt. It is tokenized with padding and truncation.

        Returns:
            The single ``logits_per_image`` value as a Python float.
        """
        if isinstance(img, Image.Image):
            img_pil = img.convert("RGB")
        elif isinstance(img, np.ndarray):
            img_pil = Image.fromarray(img).convert("RGB")
        else:
            # Treat anything else as a path. ``convert`` returns an image that
            # no longer needs the file, so close the file right away instead
            # of leaving the handle open until garbage collection.
            with Image.open(img) as opened:
                img_pil = opened.convert("RGB")

        # Gradients are already disabled by the @torch.no_grad() decorator.
        inputs = self.clip_processor(
            text=prompt,
            images=img_pil,
            return_tensors="pt",
            padding=True,
            truncation=True,
        ).to(self.device)
        outputs = self.clip_model(**inputs)
        return outputs.logits_per_image.item()
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
clip_score_instance: CLIPScore = None
# ``clip_model_path`` the cached instance was built with (None = default/env).
_clip_score_model_path: str = None


def compute_clip_score_img(
    img: Image.Image | np.ndarray | str,
    prompt: str,
    clip_model_path: str = None,
) -> float:
    """Score one image against one prompt with a lazily created CLIPScore.

    The ``CLIPScore`` instance is cached at module level. Passing
    ``clip_model_path=None`` reuses whatever model is already loaded (or
    loads the default on first use). Passing a path different from the one
    the cached instance was built with reloads the model, so a request for a
    specific checkpoint is never silently answered by another one.

    Args:
        img: PIL image, numpy array, or image path.
        prompt: Text prompt.
        clip_model_path: Optional CLIP model id or local path.

    Returns:
        The CLIP score as returned by ``CLIPScore.compute_clip_score``.
    """
    global clip_score_instance, _clip_score_model_path
    needs_load = clip_score_instance is None or (
        clip_model_path is not None and clip_model_path != _clip_score_model_path
    )
    if needs_load:
        clip_score_instance = CLIPScore(clip_model_path=clip_model_path)
        _clip_score_model_path = clip_model_path
    assert clip_score_instance is not None
    return clip_score_instance.compute_clip_score(img, prompt)
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def _natural_sort_key(filename: str):
    """Sort key: order paths by the last run of digits they contain.

    Paths with digits sort first, by that integer. Paths without digits sort
    after them, lexicographically. Keys are tagged tuples so an int key is
    never compared with a str key, which would raise TypeError.
    """
    match = re.search(r"(\d+)\D*$", filename)
    if match:
        return (0, int(match.group(1)))
    return (1, filename)


def compute_clip_score(
    img_dir: Image.Image | np.ndarray | str,
    prompts: str | list[str],
    clip_model_path: str = None,
) -> Union[Tuple[float, int], Tuple[None, None]]:
    """Compute the CLIP score of one image or the mean over a directory.

    Args:
        img_dir: An image directory (searched recursively for
            ``_IMAGE_EXTENSIONS``), or a single image given as a PIL image,
            numpy array, or file path.
        prompts: A single prompt string, a list of prompts, or the path of a
            UTF-8 text file with one prompt per line.
        clip_model_path: Optional CLIP model id or local path.

    Returns:
        ``(score, 1)`` for a single image. For a directory, returns
        ``(mean_score, count)``, where images (in natural sort order) and
        prompts are paired up to the shorter of the two. Returns
        ``(None, None)`` when nothing could be paired.
    """
    # Only path-like values can name a directory or a prompts file; passing
    # images/arrays/lists to os.path.isdir/isfile would raise TypeError.
    is_dir = isinstance(img_dir, (str, os.PathLike)) and os.path.isdir(img_dir)
    prompts_is_file = isinstance(prompts, (str, os.PathLike)) and os.path.isfile(
        prompts
    )

    if not is_dir or (not isinstance(prompts, list) and not prompts_is_file):
        return (
            compute_clip_score_img(
                img_dir,
                prompts,
                clip_model_path=clip_model_path,
            ),
            1,
        )

    # compute dir metric
    img_root = pathlib.Path(img_dir)
    img_files = sorted(
        (
            file.as_posix()
            for ext in _IMAGE_EXTENSIONS
            for file in img_root.rglob("*.{}".format(ext))
        ),
        key=_natural_sort_key,
    )

    if prompts_is_file:
        # Load prompts from file, one per line.
        with open(prompts, "r", encoding="utf-8") as f:
            prompts = [line.strip() for line in f]

    valid_len = min(len(img_files), len(prompts))
    img_files = img_files[:valid_len]
    prompts = prompts[:valid_len]

    clip_scores = [
        compute_clip_score_img(
            img_file,
            prompt,
            clip_model_path=clip_model_path,
        )
        for img_file, prompt in tqdm(
            zip(img_files, prompts),
            total=valid_len,
            disable=not get_metrics_verbose(),
        )
    ]

    if valid_len > 0:
        return np.mean(clip_scores), valid_len
    return None, None
|
cache_dit/metrics/fid.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
|
1
1
|
import os
|
|
2
2
|
import cv2
|
|
3
3
|
import pathlib
|
|
4
|
+
import warnings
|
|
5
|
+
|
|
4
6
|
import numpy as np
|
|
5
7
|
from PIL import Image
|
|
6
8
|
from tqdm import tqdm
|
|
@@ -8,13 +10,21 @@ from scipy import linalg
|
|
|
8
10
|
import torch
|
|
9
11
|
import torchvision.transforms as TF
|
|
10
12
|
from torch.nn.functional import adaptive_avg_pool2d
|
|
13
|
+
|
|
14
|
+
from typing import Tuple, Union
|
|
11
15
|
from cache_dit.metrics.inception import InceptionV3
|
|
12
16
|
from cache_dit.metrics.config import _IMAGE_EXTENSIONS
|
|
13
17
|
from cache_dit.metrics.config import _VIDEO_EXTENSIONS
|
|
18
|
+
from cache_dit.metrics.config import get_metrics_verbose
|
|
19
|
+
from cache_dit.utils import disable_print
|
|
14
20
|
from cache_dit.logger import init_logger
|
|
15
21
|
|
|
22
|
+
warnings.filterwarnings("ignore")
|
|
23
|
+
|
|
16
24
|
logger = init_logger(__name__)
|
|
17
25
|
|
|
26
|
+
DISABLE_VERBOSE = not get_metrics_verbose()
|
|
27
|
+
|
|
18
28
|
|
|
19
29
|
# Adapted from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py
|
|
20
30
|
class ImagePathDataset(torch.utils.data.Dataset):
|
|
@@ -496,3 +506,35 @@ class FrechetInceptionDistance:
|
|
|
496
506
|
return [], [], 0
|
|
497
507
|
|
|
498
508
|
return video_true_frames, video_test_frames, valid_frames
|
|
509
|
+
|
|
510
|
+
|
|
511
|
+
fid_instance: FrechetInceptionDistance = None


def _get_fid_instance() -> FrechetInceptionDistance:
    """Return the shared ``FrechetInceptionDistance``, creating it on first use.

    Construction runs under ``disable_print()``. Progress bars are disabled
    unless ``get_metrics_verbose()`` is true.
    """
    global fid_instance
    if fid_instance is None:
        with disable_print():
            fid_instance = FrechetInceptionDistance(
                disable_tqdm=not get_metrics_verbose(),
            )
    return fid_instance


def compute_fid(
    image_true: np.ndarray | str,
    image_test: np.ndarray | str,
) -> Union[Tuple[float, int], Tuple[None, None]]:
    """Compute FID between reference and test images via the shared instance.

    Args:
        image_true: Reference image(s) as an array or a path.
        image_test: Test image(s) as an array or a path.

    Returns:
        Whatever ``FrechetInceptionDistance.compute_fid`` returns; annotated as
        ``(fid, count)`` or ``(None, None)``.
    """
    return _get_fid_instance().compute_fid(image_true, image_test)


def compute_video_fid(
    # file or dir
    video_true: str,
    video_test: str,
) -> Union[Tuple[float, int], Tuple[None, None]]:
    """Compute FID between reference and test videos via the shared instance.

    Args:
        video_true: Reference video file or directory.
        video_test: Test video file or directory.

    Returns:
        Annotated as ``(fid, count)`` or ``(None, None)``.
    """
    # NOTE(review): this delegates to ``FrechetInceptionDistance.compute_fid``,
    # exactly like ``compute_fid``. If the class exposes a dedicated video
    # entry point (the class builds video frame lists above), that is
    # probably the intended target here — confirm.
    return _get_fid_instance().compute_fid(video_true, video_test)
|