cache-dit 0.2.33__py3-none-any.whl → 0.2.36__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,10 +1,9 @@
-import inspect
-
 import torch
 import numpy as np
 from typing import Tuple, Optional, Dict, Any, Union
 from diffusers import ChromaTransformer2DModel
 from diffusers.models.transformers.transformer_chroma import (
+    ChromaTransformerBlock,
     ChromaSingleTransformerBlock,
     Transformer2DModelOutput,
 )
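
Note on the apply() hunk below: instead of probing each block's forward signature with inspect, the functor now precomputes, per block, which slice of Chroma's packed pooled_temb holds that block's modulation vectors (6 image plus 6 text vectors per double block, 3 per single block, with the single-block group laid out first). A minimal sketch of that index arithmetic, using hypothetical block counts rather than the real model configuration:

    # Hypothetical sketch of the pooled_temb layout assumed by the patch.
    num_single, num_double = 38, 19                 # hypothetical counts
    img_offset = 3 * num_single                     # double-block image mods start here
    txt_offset = img_offset + 6 * num_double        # double-block text mods start here

    def double_block_slices(index_block: int):
        img = slice(img_offset + 6 * index_block, img_offset + 6 * index_block + 6)
        txt = slice(txt_offset + 6 * index_block, txt_offset + 6 * index_block + 6)
        return img, txt

    def single_block_slice(index_block: int):
        return slice(3 * index_block, 3 * index_block + 3)

    print(double_block_slices(0))   # (slice(114, 120, None), slice(228, 234, None))
    print(single_block_slice(2))    # slice(6, 9, None)

Each block stores its own indices (_img_modulation, _text_modulation, _start_idx), so the patched forwards slice pooled_temb themselves instead of receiving a pre-sliced temb.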
@@ -27,24 +26,31 @@ class ChromaPatchFunctor(PatchFunctor):
     def apply(
         self,
         transformer: ChromaTransformer2DModel,
-        blocks: torch.nn.ModuleList = None,
         **kwargs,
     ) -> ChromaTransformer2DModel:
         if hasattr(transformer, "_is_patched"):
             return transformer
 
-        if blocks is None:
-            blocks = transformer.single_transformer_blocks
-
         is_patched = False
-        for block in blocks:
-            if isinstance(block, ChromaSingleTransformerBlock):
-                forward_parameters = inspect.signature(
-                    block.forward
-                ).parameters.keys()
-                if "encoder_hidden_states" not in forward_parameters:
-                    block.forward = __patch_single_forward__.__get__(block)
-                    is_patched = True
+        for index_block, block in enumerate(transformer.transformer_blocks):
+            assert isinstance(block, ChromaTransformerBlock)
+            img_offset = 3 * len(transformer.single_transformer_blocks)
+            txt_offset = img_offset + 6 * len(transformer.transformer_blocks)
+            img_modulation = img_offset + 6 * index_block
+            text_modulation = txt_offset + 6 * index_block
+            block._img_modulation = img_modulation
+            block._text_modulation = text_modulation
+            block.forward = __patch_double_forward__.__get__(block)
+
+        for index_block, block in enumerate(
+            transformer.single_transformer_blocks
+        ):
+            assert isinstance(block, ChromaSingleTransformerBlock)
+            start_idx = 3 * index_block
+            block._start_idx = start_idx
+            block.forward = __patch_single_forward__.__get__(block)
+
+        is_patched = True
 
         cls_name = transformer.__class__.__name__
 
@@ -69,25 +75,123 @@ class ChromaPatchFunctor(PatchFunctor):
     return transformer
 
 
+# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_chroma.py
+def __patch_double_forward__(
+    self: ChromaTransformerBlock,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor,
+    pooled_temb: torch.Tensor,
+    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    attention_mask: Optional[torch.Tensor] = None,
+    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    # TODO: Fuse controlnet into block forward
+    img_modulation = self._img_modulation
+    text_modulation = self._text_modulation
+    temb = torch.cat(
+        (
+            pooled_temb[:, img_modulation : img_modulation + 6],
+            pooled_temb[:, text_modulation : text_modulation + 6],
+        ),
+        dim=1,
+    )
+
+    temb_img, temb_txt = temb[:, :6], temb[:, 6:]
+    norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+        hidden_states, emb=temb_img
+    )
+
+    (
+        norm_encoder_hidden_states,
+        c_gate_msa,
+        c_shift_mlp,
+        c_scale_mlp,
+        c_gate_mlp,
+    ) = self.norm1_context(encoder_hidden_states, emb=temb_txt)
+    joint_attention_kwargs = joint_attention_kwargs or {}
+    if attention_mask is not None:
+        attention_mask = (
+            attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
+        )
+
+    # Attention.
+    attention_outputs = self.attn(
+        hidden_states=norm_hidden_states,
+        encoder_hidden_states=norm_encoder_hidden_states,
+        image_rotary_emb=image_rotary_emb,
+        attention_mask=attention_mask,
+        **joint_attention_kwargs,
+    )
+
+    if len(attention_outputs) == 2:
+        attn_output, context_attn_output = attention_outputs
+    elif len(attention_outputs) == 3:
+        attn_output, context_attn_output, ip_attn_output = attention_outputs
+
+    # Process attention outputs for the `hidden_states`.
+    attn_output = gate_msa.unsqueeze(1) * attn_output
+    hidden_states = hidden_states + attn_output
+
+    norm_hidden_states = self.norm2(hidden_states)
+    norm_hidden_states = (
+        norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+    )
+
+    ff_output = self.ff(norm_hidden_states)
+    ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+    hidden_states = hidden_states + ff_output
+    if len(attention_outputs) == 3:
+        hidden_states = hidden_states + ip_attn_output
+
+    # Process attention outputs for the `encoder_hidden_states`.
+
+    context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+    encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+    norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+    norm_encoder_hidden_states = (
+        norm_encoder_hidden_states * (1 + c_scale_mlp[:, None])
+        + c_shift_mlp[:, None]
+    )
+
+    context_ff_output = self.ff_context(norm_encoder_hidden_states)
+    encoder_hidden_states = (
+        encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+    )
+    if encoder_hidden_states.dtype == torch.float16:
+        encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+
+    return encoder_hidden_states, hidden_states
+
+
 # adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_chroma.py
 def __patch_single_forward__(
     self: ChromaSingleTransformerBlock,  # Almost same as FluxSingleTransformerBlock
     hidden_states: torch.Tensor,
-    encoder_hidden_states: torch.Tensor,
-    temb: torch.Tensor,
+    pooled_temb: torch.Tensor,
     image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    attention_mask: Optional[torch.Tensor] = None,
     joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-) -> Tuple[torch.Tensor, torch.Tensor]:
-    text_seq_len = encoder_hidden_states.shape[1]
-    hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+) -> torch.Tensor:
+    # TODO: Fuse controlnet into block forward
+    start_idx = self._start_idx
+    temb = pooled_temb[:, start_idx : start_idx + 3]
 
     residual = hidden_states
     norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
     mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
     joint_attention_kwargs = joint_attention_kwargs or {}
+
+    if attention_mask is not None:
+        attention_mask = (
+            attention_mask[:, None, None, :] * attention_mask[:, None, :, None]
+        )
+
     attn_output = self.attn(
         hidden_states=norm_hidden_states,
         image_rotary_emb=image_rotary_emb,
+        attention_mask=attention_mask,
         **joint_attention_kwargs,
     )
 
@@ -98,11 +202,7 @@ def __patch_single_forward__(
     if hidden_states.dtype == torch.float16:
         hidden_states = hidden_states.clip(-65504, 65504)
 
-    encoder_hidden_states, hidden_states = (
-        hidden_states[:, :text_seq_len],
-        hidden_states[:, text_seq_len:],
-    )
-    return encoder_hidden_states, hidden_states
+    return hidden_states
 
 
 # Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_chroma.py
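
Both patched forwards are attached with function.__get__(block), which binds a plain module-level function as an instance method on one specific block without touching the class. A self-contained sketch of the idiom (generic example, not package code):

    import torch

    class Block(torch.nn.Module):
        def forward(self, x):
            return x

    def patched_forward(self, x):
        # `self` is the instance the function was bound to via __get__.
        return x * 2

    block = Block()
    block.forward = patched_forward.__get__(block)  # per-instance override
    print(block.forward(torch.ones(2)))             # tensor([2., 2.])

Because the override lives on the instance, other blocks of the same class keep their original forward.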
@@ -174,24 +274,13 @@ def __patch_transformer_forward__(
         joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
 
     for index_block, block in enumerate(self.transformer_blocks):
-        img_offset = 3 * len(self.single_transformer_blocks)
-        txt_offset = img_offset + 6 * len(self.transformer_blocks)
-        img_modulation = img_offset + 6 * index_block
-        text_modulation = txt_offset + 6 * index_block
-        temb = torch.cat(
-            (
-                pooled_temb[:, img_modulation : img_modulation + 6],
-                pooled_temb[:, text_modulation : text_modulation + 6],
-            ),
-            dim=1,
-        )
         if torch.is_grad_enabled() and self.gradient_checkpointing:
             encoder_hidden_states, hidden_states = (
                 self._gradient_checkpointing_func(
                     block,
                     hidden_states,
                     encoder_hidden_states,
-                    temb,
+                    pooled_temb,
                     image_rotary_emb,
                     attention_mask,
                 )
@@ -201,12 +290,13 @@ def __patch_transformer_forward__(
             encoder_hidden_states, hidden_states = block(
                 hidden_states=hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
-                temb=temb,
+                pooled_temb=pooled_temb,
                 image_rotary_emb=image_rotary_emb,
                 attention_mask=attention_mask,
                 joint_attention_kwargs=joint_attention_kwargs,
             )
 
+        # TODO: Fuse controlnet into block forward
         # controlnet residual
         if controlnet_block_samples is not None:
             interval_control = len(self.transformer_blocks) / len(
  interval_control = len(self.transformer_blocks) / len(
@@ -227,43 +317,43 @@ def __patch_transformer_forward__(
227
317
  + controlnet_block_samples[index_block // interval_control]
228
318
  )
229
319
 
320
+ hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
321
+
230
322
  for index_block, block in enumerate(self.single_transformer_blocks):
231
- start_idx = 3 * index_block
232
- temb = pooled_temb[:, start_idx : start_idx + 3]
233
323
  if torch.is_grad_enabled() and self.gradient_checkpointing:
234
- encoder_hidden_states, hidden_states = (
235
- self._gradient_checkpointing_func(
236
- block,
237
- hidden_states,
238
- encoder_hidden_states,
239
- temb,
240
- image_rotary_emb,
241
- )
324
+ hidden_states = self._gradient_checkpointing_func(
325
+ block,
326
+ hidden_states,
327
+ pooled_temb,
328
+ image_rotary_emb,
329
+ attention_mask,
330
+ joint_attention_kwargs,
242
331
  )
243
332
 
244
333
  else:
245
- encoder_hidden_states, hidden_states = block(
334
+ hidden_states = block(
246
335
  hidden_states=hidden_states,
247
- encoder_hidden_states=encoder_hidden_states,
248
- temb=temb,
336
+ pooled_temb=pooled_temb,
249
337
  image_rotary_emb=image_rotary_emb,
250
338
  attention_mask=attention_mask,
251
339
  joint_attention_kwargs=joint_attention_kwargs,
252
340
  )
253
341
 
342
+ # TODO: Fuse controlnet into block forward
254
343
  # controlnet residual
255
344
  if controlnet_single_block_samples is not None:
256
345
  interval_control = len(self.single_transformer_blocks) / len(
257
346
  controlnet_single_block_samples
258
347
  )
259
348
  interval_control = int(np.ceil(interval_control))
260
- hidden_states = (
261
- hidden_states
349
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
350
+ hidden_states[:, encoder_hidden_states.shape[1] :, ...]
262
351
  + controlnet_single_block_samples[
263
352
  index_block // interval_control
264
353
  ]
265
354
  )
266
355
 
356
+ hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
267
357
  temb = pooled_temb[:, -2:]
268
358
  hidden_states = self.norm_out(hidden_states, temb)
269
359
  output = self.proj_out(hidden_states)
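
With the single blocks no longer splitting and re-concatenating text tokens on every call, the transformer forward now concatenates encoder and image tokens once before the single-block loop, adds controlnet residuals only to the image-token slice, and strips the text tokens off again before the output projection. A toy illustration of that slicing with dummy tensors (shapes invented for the example):

    import torch

    # Dummy shapes: batch 2, 8 text tokens, 24 image tokens, hidden dim 16.
    encoder_hidden_states = torch.randn(2, 8, 16)
    hidden_states = torch.randn(2, 24, 16)
    residual = torch.randn(2, 24, 16)        # stand-in for a controlnet sample

    # Concatenate once before the single-block loop ...
    hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

    # ... apply residuals only to the image-token part ...
    text_len = encoder_hidden_states.shape[1]
    hidden_states[:, text_len:, ...] = hidden_states[:, text_len:, ...] + residual

    # ... and keep only the image tokens for norm_out/proj_out.
    hidden_states = hidden_states[:, text_len:, ...]
    print(hidden_states.shape)               # torch.Size([2, 24, 16])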
@@ -0,0 +1,130 @@
+import torch
+import torch.nn.functional as F
+
+from typing import Optional, Dict, Any
+from diffusers.models.transformers.dit_transformer_2d import (
+    DiTTransformer2DModel,
+    Transformer2DModelOutput,
+)
+from cache_dit.cache_factory.patch_functors.functor_base import (
+    PatchFunctor,
+)
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class DiTPatchFunctor(PatchFunctor):
+
+    def apply(
+        self,
+        transformer: DiTTransformer2DModel,
+        **kwargs,
+    ) -> DiTTransformer2DModel:
+        if hasattr(transformer, "_is_patched"):
+            return transformer
+
+        is_patched = False
+
+        transformer._norm1_emb = transformer.transformer_blocks[0].norm1.emb
+
+        is_patched = True
+
+        cls_name = transformer.__class__.__name__
+
+        if is_patched:
+            logger.warning(f"Patched {cls_name} for cache-dit.")
+            assert not getattr(transformer, "_is_parallelized", False), (
+                "Please call `cache_dit.enable_cache` before Parallelize, "
+                "the __patch_transformer_forward__ will overwrite the "
+                "parallized forward and cause a downgrade of performance."
+            )
+            transformer.forward = __patch_transformer_forward__.__get__(
+                transformer
+            )
+
+        transformer._is_patched = is_patched  # True or False
+
+        logger.info(
+            f"Applied {self.__class__.__name__} for {cls_name}, "
+            f"Patch: {is_patched}."
+        )
+
+        return transformer
+
+
+def __patch_transformer_forward__(
+    self: DiTTransformer2DModel,
+    hidden_states: torch.Tensor,
+    timestep: Optional[torch.LongTensor] = None,
+    class_labels: Optional[torch.LongTensor] = None,
+    cross_attention_kwargs: Dict[str, Any] = None,
+    return_dict: bool = True,
+):
+    height, width = (
+        hidden_states.shape[-2] // self.patch_size,
+        hidden_states.shape[-1] // self.patch_size,
+    )
+    hidden_states = self.pos_embed(hidden_states)
+
+    # 2. Blocks
+    for block in self.transformer_blocks:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            hidden_states = self._gradient_checkpointing_func(
+                block,
+                hidden_states,
+                None,
+                None,
+                None,
+                timestep,
+                cross_attention_kwargs,
+                class_labels,
+            )
+        else:
+            hidden_states = block(
+                hidden_states,
+                attention_mask=None,
+                encoder_hidden_states=None,
+                encoder_attention_mask=None,
+                timestep=timestep,
+                cross_attention_kwargs=cross_attention_kwargs,
+                class_labels=class_labels,
+            )
+
+    # 3. Output
+    # conditioning = self.transformer_blocks[0].norm1.emb(timestep, class_labels, hidden_dtype=hidden_states.dtype)
+    conditioning = self._norm1_emb(
+        timestep, class_labels, hidden_dtype=hidden_states.dtype
+    )
+    shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
+    hidden_states = (
+        self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+    )
+    hidden_states = self.proj_out_2(hidden_states)
+
+    # unpatchify
+    height = width = int(hidden_states.shape[1] ** 0.5)
+    hidden_states = hidden_states.reshape(
+        shape=(
+            -1,
+            height,
+            width,
+            self.patch_size,
+            self.patch_size,
+            self.out_channels,
+        )
+    )
+    hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+    output = hidden_states.reshape(
+        shape=(
+            -1,
+            self.out_channels,
+            height * self.patch_size,
+            width * self.patch_size,
+        )
+    )
+
+    if not return_dict:
+        return (output,)
+
+    return Transformer2DModelOutput(sample=output)
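
The new functor caches transformer_blocks[0].norm1.emb as transformer._norm1_emb, so the final conditioning no longer indexes into the block list (the original line is kept as a comment above). The closing unpatchify step reassembles the patch tokens into an image; a standalone shape check with dummy sizes, not the real model configuration:

    import torch

    # Dummy sizes: a 4x4 grid of 2x2 patches with 8 output channels.
    height, width, patch, channels = 4, 4, 2, 8
    tokens = torch.randn(1, height * width, patch * patch * channels)

    x = tokens.reshape(-1, height, width, patch, patch, channels)
    x = torch.einsum("nhwpqc->nchpwq", x)
    image = x.reshape(-1, channels, height * patch, width * patch)
    print(image.shape)  # torch.Size([1, 8, 8, 8])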
@@ -0,0 +1,135 @@
+import os
+import re
+import pathlib
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+import torch
+from transformers import CLIPProcessor, CLIPModel
+
+from typing import Tuple, Union
+from cache_dit.metrics.config import _IMAGE_EXTENSIONS
+from cache_dit.metrics.config import get_metrics_verbose
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+DISABLE_VERBOSE = not get_metrics_verbose()
+
+
+class CLIPScore:
+    def __init__(
+        self,
+        device="cuda" if torch.cuda.is_available() else "cpu",
+        clip_model_path: str = None,
+    ):
+        self.device = device
+        if clip_model_path is None:
+            clip_model_path = os.environ.get(
+                "CLIP_MODEL_DIR", "laion/CLIP-ViT-g-14-laion2B-s12B-b42K"
+            )
+
+        # Load models
+        self.clip_model = CLIPModel.from_pretrained(clip_model_path)
+        self.clip_model = self.clip_model.to(device)  # type: ignore
+        self.clip_processor = CLIPProcessor.from_pretrained(clip_model_path)
+
+    @torch.no_grad()
+    def compute_clip_score(
+        self,
+        img: Image.Image | np.ndarray,
+        prompt: str,
+    ) -> float:
+        if isinstance(img, Image.Image):
+            img_pil = img.convert("RGB")
+        elif isinstance(img, np.ndarray):
+            img_pil = Image.fromarray(img).convert("RGB")
+        else:
+            img_pil = Image.open(img).convert("RGB")
+        with torch.no_grad():
+            inputs = self.clip_processor(
+                text=prompt,
+                images=img_pil,
+                return_tensors="pt",
+                padding=True,
+                truncation=True,
+            ).to(self.device)
+            outputs = self.clip_model(**inputs)
+            return outputs.logits_per_image.item()
+
+
+clip_score_instance: CLIPScore = None
+
+
+def compute_clip_score_img(
+    img: Image.Image | np.ndarray | str,
+    prompt: str,
+    clip_model_path: str = None,
+) -> float:
+    global clip_score_instance
+    if clip_score_instance is None:
+        clip_score_instance = CLIPScore(clip_model_path=clip_model_path)
+    assert clip_score_instance is not None
+    return clip_score_instance.compute_clip_score(img, prompt)
+
+
+def compute_clip_score(
+    img_dir: Image.Image | np.ndarray | str,
+    prompts: str | list[str],
+    clip_model_path: str = None,
+) -> Union[Tuple[float, int], Tuple[None, None]]:
+    if not os.path.isdir(img_dir) or (
+        not isinstance(prompts, list) and not os.path.isfile(prompts)
+    ):
+        return (
+            compute_clip_score_img(
+                img_dir,
+                prompts,
+                clip_model_path=clip_model_path,
+            ),
+            1,
+        )
+
+    # compute dir metric
+    def natural_sort_key(filename):
+        match = re.search(r"(\d+)\D*$", filename)
+        return int(match.group(1)) if match else filename
+
+    img_dir: pathlib.Path = pathlib.Path(img_dir)
+    img_files = [
+        file
+        for ext in _IMAGE_EXTENSIONS
+        for file in img_dir.rglob("*.{}".format(ext))
+    ]
+    img_files = [file.as_posix() for file in img_files]
+    img_files = sorted(img_files, key=natural_sort_key)
+
+    if os.path.isfile(prompts):
+        """Load prompts from file"""
+        with open(prompts, "r", encoding="utf-8") as f:
+            prompts_load = [line.strip() for line in f.readlines()]
+        prompts = prompts_load.copy()
+
+    vaild_len = min(len(img_files), len(prompts))
+    img_files = img_files[:vaild_len]
+    prompts = prompts[:vaild_len]
+
+    clip_scores = []
+
+    for img_file, prompt in tqdm(
+        zip(img_files, prompts),
+        total=vaild_len,
+        disable=not get_metrics_verbose(),
+    ):
+        clip_scores.append(
+            compute_clip_score_img(
+                img_file,
+                prompt,
+                clip_model_path=clip_model_path,
+            )
+        )
+
+    if vaild_len > 0:
+        return np.mean(clip_scores), vaild_len
+    return None, None
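
A short driver for the new CLIP-score helpers (not part of the diff; file paths and prompts are placeholders, and it assumes it runs in the same module where the helpers above are defined):

    if __name__ == "__main__":
        # Single image vs. a single prompt.
        score = compute_clip_score_img("sample_0.png", "a photo of an astronaut")
        print(f"single-image CLIP score: {score:.2f}")

        # Directory mode: images are matched line-by-line to prompts in a text file.
        mean_score, n = compute_clip_score("outputs/", "prompts.txt")
        if mean_score is not None:
            print(f"mean CLIP score over {n} images: {mean_score:.2f}")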
cache_dit/metrics/fid.py CHANGED
@@ -1,6 +1,8 @@
 import os
 import cv2
 import pathlib
+import warnings
+
 import numpy as np
 from PIL import Image
 from tqdm import tqdm
@@ -8,13 +10,21 @@ from scipy import linalg
 import torch
 import torchvision.transforms as TF
 from torch.nn.functional import adaptive_avg_pool2d
+
+from typing import Tuple, Union
 from cache_dit.metrics.inception import InceptionV3
 from cache_dit.metrics.config import _IMAGE_EXTENSIONS
 from cache_dit.metrics.config import _VIDEO_EXTENSIONS
+from cache_dit.metrics.config import get_metrics_verbose
+from cache_dit.utils import disable_print
 from cache_dit.logger import init_logger
 
+warnings.filterwarnings("ignore")
+
 logger = init_logger(__name__)
 
+DISABLE_VERBOSE = not get_metrics_verbose()
+
 
 # Adapted from: https://github.com/mseitzer/pytorch-fid/blob/master/src/pytorch_fid/fid_score.py
 class ImagePathDataset(torch.utils.data.Dataset):
@@ -496,3 +506,35 @@ class FrechetInceptionDistance:
             return [], [], 0
 
         return video_true_frames, video_test_frames, valid_frames
+
+
+fid_instance: FrechetInceptionDistance = None
+
+
+def compute_fid(
+    image_true: np.ndarray | str,
+    image_test: np.ndarray | str,
+) -> Union[Tuple[float, int], Tuple[None, None]]:
+    global fid_instance
+    if fid_instance is None:
+        with disable_print():
+            fid_instance = FrechetInceptionDistance(
+                disable_tqdm=not get_metrics_verbose(),
+            )
+    assert fid_instance is not None
+    return fid_instance.compute_fid(image_true, image_test)
+
+
+def compute_video_fid(
+    # file or dir
+    video_true: str,
+    video_test: str,
+) -> Union[Tuple[float, int], Tuple[None, None]]:
+    global fid_instance
+    if fid_instance is None:
+        with disable_print():
+            fid_instance = FrechetInceptionDistance(
+                disable_tqdm=not get_metrics_verbose(),
+            )
+    assert fid_instance is not None
+    return fid_instance.compute_fid(video_true, video_test)
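
Similarly, a small driver for the new module-level FID wrappers (not part of the diff; paths are placeholders, and it assumes it runs in the same module). Both wrappers lazily build one shared FrechetInceptionDistance instance and reuse it across calls:

    if __name__ == "__main__":
        fid, n = compute_fid("references/", "generated/")   # two image directories
        if fid is not None:
            print(f"FID over {n} images: {fid:.3f}")

        vfid, m = compute_video_fid("ref.mp4", "gen.mp4")    # two video files
        if vfid is not None:
            print(f"video FID over {m} samples: {vfid:.3f}")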