cache-dit 0.2.26__py3-none-any.whl → 0.2.28__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (29)
  1. cache_dit/__init__.py +8 -6
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +17 -4
  4. cache_dit/cache_factory/block_adapters/__init__.py +555 -0
  5. cache_dit/cache_factory/block_adapters/block_adapters.py +538 -0
  6. cache_dit/cache_factory/block_adapters/block_registers.py +77 -0
  7. cache_dit/cache_factory/cache_adapters.py +262 -938
  8. cache_dit/cache_factory/cache_blocks/__init__.py +60 -11
  9. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +2 -2
  10. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +45 -41
  11. cache_dit/cache_factory/cache_blocks/pattern_base.py +106 -80
  12. cache_dit/cache_factory/cache_blocks/utils.py +16 -10
  13. cache_dit/cache_factory/cache_contexts/__init__.py +5 -0
  14. cache_dit/cache_factory/cache_contexts/cache_context.py +327 -0
  15. cache_dit/cache_factory/cache_contexts/cache_manager.py +833 -0
  16. cache_dit/cache_factory/cache_interface.py +31 -31
  17. cache_dit/cache_factory/patch_functors/functor_chroma.py +3 -0
  18. cache_dit/cache_factory/patch_functors/functor_flux.py +4 -0
  19. cache_dit/quantize/quantize_ao.py +1 -0
  20. cache_dit/utils.py +26 -26
  21. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/METADATA +59 -23
  22. cache_dit-0.2.28.dist-info/RECORD +47 -0
  23. cache_dit/cache_factory/cache_context.py +0 -1155
  24. cache_dit-0.2.26.dist-info/RECORD +0 -42
  25. /cache_dit/cache_factory/{taylorseer.py → cache_contexts/taylorseer.py} +0 -0
  26. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/WHEEL +0 -0
  27. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/entry_points.txt +0 -0
  28. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/licenses/LICENSE +0 -0
  29. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/top_level.txt +0 -0
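The bulk of this release is a refactor of the cache factory: the monolithic cache_dit/cache_factory/cache_context.py is removed in favor of a new cache_contexts/ package (cache_context.py, cache_manager.py, taylorseer.py), and a new block_adapters/ package is added, whose two main modules are shown in full below. A minimal sketch of the new import paths, using only the fully qualified module names visible in this diff (whether the package __init__ files also re-export these names is not shown here):

# Import paths taken from the new files listed above; any shorter re-exported
# aliases (e.g. via cache_dit.cache_factory) are not confirmed by this diff.
from cache_dit.cache_factory.block_adapters.block_adapters import (
    BlockAdapter,
    ParamsModifier,
)
from cache_dit.cache_factory.block_adapters.block_registers import (
    BlockAdapterRegistry,
)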
cache_dit/cache_factory/block_adapters/block_adapters.py (new file, +538 lines)
@@ -0,0 +1,538 @@
+ import torch
+
+ import inspect
+ import dataclasses
+
+ from typing import Any, Tuple, List, Optional, Union
+
+ from diffusers import DiffusionPipeline
+ from cache_dit.cache_factory.forward_pattern import ForwardPattern
+ from cache_dit.cache_factory.patch_functors import PatchFunctor
+
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ class ParamsModifier:
+     def __init__(self, **kwargs):
+         self._context_kwargs = kwargs.copy()
+
+
+ @dataclasses.dataclass
+ class BlockAdapter:
+
+     # Transformer configurations.
+     pipe: Union[
+         DiffusionPipeline,
+         Any,
+     ] = None
+
+     # single transformer (most cases) or list of transformers (Wan2.2, etc)
+     transformer: Union[
+         torch.nn.Module,
+         List[torch.nn.Module],
+     ] = None
+
+     # Block Level Flags
+     # Each transformer contains a list of blocks-list,
+     # blocks_name-list, dummy_blocks_names-list, etc.
+     blocks: Union[
+         torch.nn.ModuleList,
+         List[torch.nn.ModuleList],
+         List[List[torch.nn.ModuleList]],
+     ] = None
+
+     # transformer_blocks, blocks, etc.
+     blocks_name: Union[
+         str,
+         List[str],
+         List[List[str]],
+     ] = None
+
+     unique_blocks_name: Union[
+         str,
+         List[str],
+         List[List[str]],
+     ] = dataclasses.field(default_factory=list)
+
+     dummy_blocks_names: Union[
+         List[str],
+         List[List[str]],
+     ] = dataclasses.field(default_factory=list)
+
+     forward_pattern: Union[
+         ForwardPattern,
+         List[ForwardPattern],
+         List[List[ForwardPattern]],
+     ] = None
+
+     # modify cache context params for specific blocks.
+     params_modifiers: Union[
+         ParamsModifier,
+         List[ParamsModifier],
+         List[List[ParamsModifier]],
+     ] = None
+
+     check_num_outputs: bool = True
+
+     # Pipeline Level Flags
+     # Patch Functor: Flux, etc.
+     patch_functor: Optional[PatchFunctor] = None
+     # Flags for separate cfg
+     has_separate_cfg: bool = False
+
+     # Other Flags
+     disable_patch: bool = False
+
+     # Flags to control auto block adapter
+     # NOTE: NOT supported for multi-transformers.
+     auto: bool = False
+     allow_prefixes: List[str] = dataclasses.field(
+         default_factory=lambda: [
+             "transformer",
+             "single_transformer",
+             "blocks",
+             "layers",
+             "single_stream_blocks",
+             "double_stream_blocks",
+         ]
+     )
+     check_prefixes: bool = True
+     allow_suffixes: List[str] = dataclasses.field(
+         default_factory=lambda: ["TransformerBlock"]
+     )
+     check_suffixes: bool = False
+     blocks_policy: str = dataclasses.field(
+         default="max", metadata={"allowed_values": ["max", "min"]}
+     )
+
+     def __post_init__(self):
+         assert any((self.pipe is not None, self.transformer is not None))
+         self.patchify()
+
+     def patchify(self, *args, **kwargs):
+         # Handle some special cases for transformers that have different
+         # forward patterns between single_transformer_blocks and
+         # transformer_blocks, such as Flux (diffusers < 0.35.0).
+         if self.patch_functor is not None and not self.disable_patch:
+             if self.transformer is not None:
+                 self.patch_functor.apply(self.transformer, *args, **kwargs)
+             else:
+                 assert hasattr(self.pipe, "transformer")
+                 self.patch_functor.apply(self.pipe.transformer, *args, **kwargs)
+
+     @staticmethod
+     def auto_block_adapter(
+         adapter: "BlockAdapter",
+     ) -> "BlockAdapter":
+         assert adapter.auto, (
+             "Please manually set `auto` to True, or manually "
+             "set all the transformer blocks configuration."
+         )
+         assert adapter.pipe is not None, "adapter.pipe can not be None."
+         assert (
+             adapter.forward_pattern is not None
+         ), "adapter.forward_pattern can not be None."
+         pipe = adapter.pipe
+
+         assert hasattr(pipe, "transformer"), "pipe.transformer can not be None."
+
+         transformer = pipe.transformer
+
+         # "transformer_blocks", "blocks", "single_transformer_blocks", "layers"
+         blocks, blocks_name = BlockAdapter.find_blocks(
+             transformer=transformer,
+             allow_prefixes=adapter.allow_prefixes,
+             allow_suffixes=adapter.allow_suffixes,
+             check_prefixes=adapter.check_prefixes,
+             check_suffixes=adapter.check_suffixes,
+             blocks_policy=adapter.blocks_policy,
+             forward_pattern=adapter.forward_pattern,
+             check_num_outputs=adapter.check_num_outputs,
+         )
+
+         return BlockAdapter(
+             pipe=pipe,
+             transformer=transformer,
+             blocks=blocks,
+             blocks_name=blocks_name,
+             forward_pattern=adapter.forward_pattern,
+         )
+
+     @staticmethod
+     def check_block_adapter(
+         adapter: "BlockAdapter",
+     ) -> bool:
+         def _check_warning(attr: str):
+             if getattr(adapter, attr, None) is None:
+                 logger.warning(f"{attr} is None!")
+                 return False
+             return True
+
+         if not _check_warning("pipe"):
+             return False
+
+         if not _check_warning("transformer"):
+             return False
+
+         if not _check_warning("blocks"):
+             return False
+
+         if not _check_warning("blocks_name"):
+             return False
+
+         if not _check_warning("forward_pattern"):
+             return False
+
+         if isinstance(adapter.blocks, list):
+             for i, blocks in enumerate(adapter.blocks):
+                 if not isinstance(blocks, torch.nn.ModuleList):
+                     logger.warning(f"blocks[{i}] is not ModuleList.")
+                     return False
+         else:
+             if not isinstance(adapter.blocks, torch.nn.ModuleList):
+                 logger.warning("blocks is not ModuleList.")
+                 return False
+
+         return True
+
+     @staticmethod
+     def find_blocks(
+         transformer: torch.nn.Module,
+         allow_prefixes: List[str] = [
+             "transformer",
+             "single_transformer",
+             "blocks",
+             "layers",
+             "single_stream_blocks",
+             "double_stream_blocks",
+         ],
+         allow_suffixes: List[str] = [
+             "TransformerBlock",
+         ],
+         check_prefixes: bool = True,
+         check_suffixes: bool = False,
+         **kwargs,
+     ) -> Tuple[torch.nn.ModuleList, str]:
+         # Check prefixes
+         if check_prefixes:
+             blocks_names = []
+             for attr_name in dir(transformer):
+                 for prefix in allow_prefixes:
+                     if attr_name.startswith(prefix):
+                         blocks_names.append(attr_name)
+         else:
+             blocks_names = dir(transformer)
+
+         # Check ModuleList
+         valid_names = []
+         valid_count = []
+         forward_pattern = kwargs.pop("forward_pattern", None)
+         for blocks_name in blocks_names:
+             if blocks := getattr(transformer, blocks_name, None):
+                 if isinstance(blocks, torch.nn.ModuleList):
+                     block = blocks[0]
+                     block_cls_name = block.__class__.__name__
+                     # Check suffixes
+                     if isinstance(block, torch.nn.Module) and (
+                         any(
+                             (
+                                 block_cls_name.endswith(allow_suffix)
+                                 for allow_suffix in allow_suffixes
+                             )
+                         )
+                         or (not check_suffixes)
+                     ):
+                         # May check forward pattern
+                         if forward_pattern is not None:
+                             if BlockAdapter.match_blocks_pattern(
+                                 blocks,
+                                 forward_pattern,
+                                 logging=False,
+                                 **kwargs,
+                             ):
+                                 valid_names.append(blocks_name)
+                                 valid_count.append(len(blocks))
+                         else:
+                             valid_names.append(blocks_name)
+                             valid_count.append(len(blocks))
+
+         if not valid_names:
+             raise ValueError(
+                 "Auto selected transformer blocks failed, please set it manually."
+             )
+
+         final_name = valid_names[0]
+         final_count = valid_count[0]
+         block_policy = kwargs.get("blocks_policy", "max")
+
+         for blocks_name, count in zip(valid_names, valid_count):
+             blocks = getattr(transformer, blocks_name)
+             logger.info(
+                 f"Auto selected transformer blocks: {blocks_name}, "
+                 f"class: {blocks[0].__class__.__name__}, "
+                 f"num blocks: {count}"
+             )
+             if block_policy == "max":
+                 if final_count < count:
+                     final_count = count
+                     final_name = blocks_name
+             else:
+                 if final_count > count:
+                     final_count = count
+                     final_name = blocks_name
+
+         final_blocks = getattr(transformer, final_name)
+
+         logger.info(
+             f"Final selected transformer blocks: {final_name}, "
+             f"class: {final_blocks[0].__class__.__name__}, "
+             f"num blocks: {final_count}, block_policy: {block_policy}."
+         )
+
+         return final_blocks, final_name
+
+     @staticmethod
+     def match_block_pattern(
+         block: torch.nn.Module,
+         forward_pattern: ForwardPattern,
+         **kwargs,
+     ) -> bool:
+         assert (
+             forward_pattern.Supported
+             and forward_pattern in ForwardPattern.supported_patterns()
+         ), f"Pattern {forward_pattern} is not supported now!"
+
+         forward_parameters = set(
+             inspect.signature(block.forward).parameters.keys()
+         )
+
+         in_matched = True
+         out_matched = True
+
+         if kwargs.get("check_num_outputs", True):
+             num_outputs = str(
+                 inspect.signature(block.forward).return_annotation
+             ).count("torch.Tensor")
+
+             if num_outputs > 0 and len(forward_pattern.Out) != num_outputs:
+                 # output pattern does not match
+                 out_matched = False
+
+         for required_param in forward_pattern.In:
+             if required_param not in forward_parameters:
+                 in_matched = False
+
+         return in_matched and out_matched
+
+     @staticmethod
+     def match_blocks_pattern(
+         transformer_blocks: torch.nn.ModuleList,
+         forward_pattern: ForwardPattern,
+         logging: bool = True,
+         **kwargs,
+     ) -> bool:
+         assert (
+             forward_pattern.Supported
+             and forward_pattern in ForwardPattern.supported_patterns()
+         ), f"Pattern {forward_pattern} is not supported now!"
+
+         assert isinstance(transformer_blocks, torch.nn.ModuleList)
+
+         pattern_matched_states = []
+         for block in transformer_blocks:
+             pattern_matched_states.append(
+                 BlockAdapter.match_block_pattern(
+                     block,
+                     forward_pattern,
+                     **kwargs,
+                 )
+             )
+
+         pattern_matched = all(pattern_matched_states)  # all blocks match
+         if pattern_matched and logging:
+             block_cls_names = [
+                 block.__class__.__name__ for block in transformer_blocks
+             ]
+             block_cls_names = list(set(block_cls_names))
+             if len(block_cls_names) == 1:
+                 block_cls_names = block_cls_names[0]
+             logger.info(
+                 f"Match Block Forward Pattern: {block_cls_names}, {forward_pattern}"
+                 f"\nIN:{forward_pattern.In}, OUT:{forward_pattern.Out}"
+             )
+
+         return pattern_matched
+
+     @staticmethod
+     def normalize(
+         adapter: "BlockAdapter",
+     ) -> "BlockAdapter":
+
+         if getattr(adapter, "_is_normalized", False):
+             return adapter
+
+         if not isinstance(adapter.transformer, list):
+             adapter.transformer = [adapter.transformer]
+
+         if isinstance(adapter.blocks, torch.nn.ModuleList):
+             # blocks_0 = [[blocks_0,],] -> match [TRN_0,]
+             adapter.blocks = [[adapter.blocks]]
+         elif isinstance(adapter.blocks, list):
+             if isinstance(adapter.blocks[0], torch.nn.ModuleList):
+                 # [blocks_0, blocks_1] -> [[blocks_0, blocks_1],] -> match [TRN_0,]
+                 if len(adapter.blocks) == len(adapter.transformer):
+                     adapter.blocks = [[blocks] for blocks in adapter.blocks]
+                 else:
+                     adapter.blocks = [adapter.blocks]
+             elif isinstance(adapter.blocks[0], list):
+                 # [[blocks_0, blocks_1],[blocks_2, blocks_3],] -> match [TRN_0, TRN_1,]
+                 pass
+
+         if isinstance(adapter.blocks_name, str):
+             adapter.blocks_name = [[adapter.blocks_name]]
+         elif isinstance(adapter.blocks_name, list):
+             if isinstance(adapter.blocks_name[0], str):
+                 if len(adapter.blocks_name) == len(adapter.transformer):
+                     adapter.blocks_name = [
+                         [blocks_name] for blocks_name in adapter.blocks_name
+                     ]
+                 else:
+                     adapter.blocks_name = [adapter.blocks_name]
+             elif isinstance(adapter.blocks_name[0], list):
+                 pass
+
+         if isinstance(adapter.forward_pattern, ForwardPattern):
+             adapter.forward_pattern = [[adapter.forward_pattern]]
+         elif isinstance(adapter.forward_pattern, list):
+             if isinstance(adapter.forward_pattern[0], ForwardPattern):
+                 if len(adapter.forward_pattern) == len(adapter.transformer):
+                     adapter.forward_pattern = [
+                         [forward_pattern]
+                         for forward_pattern in adapter.forward_pattern
+                     ]
+                 else:
+                     adapter.forward_pattern = [adapter.forward_pattern]
+             elif isinstance(adapter.forward_pattern[0], list):
+                 pass
+
+         if isinstance(adapter.dummy_blocks_names, list):
+             if len(adapter.dummy_blocks_names) > 0:
+                 if isinstance(adapter.dummy_blocks_names[0], str):
+                     if len(adapter.dummy_blocks_names) == len(
+                         adapter.transformer
+                     ):
+                         adapter.dummy_blocks_names = [
+                             [dummy_blocks_names]
+                             for dummy_blocks_names in adapter.dummy_blocks_names
+                         ]
+                     else:
+                         adapter.dummy_blocks_names = [
+                             adapter.dummy_blocks_names
+                         ]
+                 elif isinstance(adapter.dummy_blocks_names[0], list):
+                     pass
+             else:
+                 # Empty dummy_blocks_names
+                 adapter.dummy_blocks_names = [
+                     [] for _ in range(len(adapter.transformer))
+                 ]
+
+         if adapter.params_modifiers is not None:
+             if isinstance(adapter.params_modifiers, ParamsModifier):
+                 adapter.params_modifiers = [[adapter.params_modifiers]]
+             elif isinstance(adapter.params_modifiers, list):
+                 if isinstance(adapter.params_modifiers[0], ParamsModifier):
+                     if len(adapter.params_modifiers) == len(
+                         adapter.transformer
+                     ):
+                         adapter.params_modifiers = [
+                             [params_modifiers]
+                             for params_modifiers in adapter.params_modifiers
+                         ]
+                     else:
+                         adapter.params_modifiers = [adapter.params_modifiers]
+                 elif isinstance(adapter.params_modifiers[0], list):
+                     pass
+
+         assert len(adapter.transformer) == len(adapter.blocks)
+         assert len(adapter.transformer) == len(adapter.blocks_name)
+         assert len(adapter.transformer) == len(adapter.forward_pattern)
+         assert len(adapter.transformer) == len(adapter.dummy_blocks_names)
+         if adapter.params_modifiers is not None:
+             assert len(adapter.transformer) == len(adapter.params_modifiers)
+
+         for i in range(len(adapter.blocks)):
+             assert len(adapter.blocks[i]) == len(adapter.blocks_name[i])
+             assert len(adapter.blocks[i]) == len(adapter.forward_pattern[i])
+
+         if len(adapter.unique_blocks_name) == 0:
+             for i in range(len(adapter.transformer)):
+                 # Generate unique blocks names
+                 adapter.unique_blocks_name.append(
+                     [
+                         f"{name}_{hash(id(blocks))}"
+                         for blocks, name in zip(
+                             adapter.blocks[i],
+                             adapter.blocks_name[i],
+                         )
+                     ]
+                 )
+
+         assert len(adapter.transformer) == len(adapter.unique_blocks_name)
+
+         # Match Forward Pattern
+         for i in range(len(adapter.transformer)):
+             for forward_pattern, blocks in zip(
+                 adapter.forward_pattern[i], adapter.blocks[i]
+             ):
+                 assert BlockAdapter.match_blocks_pattern(
+                     blocks,
+                     forward_pattern=forward_pattern,
+                     check_num_outputs=adapter.check_num_outputs,
+                 ), (
+                     "No block forward pattern matched, "
+                     f"supported lists: {ForwardPattern.supported_patterns()}"
+                 )
+
+         adapter._is_normalized = True
+
+         return adapter
+
+     @classmethod
+     def assert_normalized(cls, adapter: "BlockAdapter"):
+         if not getattr(adapter, "_is_normalized", False):
+             raise RuntimeError("block_adapter must be normalized.")
+
+     @classmethod
+     def is_cached(cls, adapter: Any) -> bool:
+         if isinstance(adapter, cls):
+             cls.assert_normalized(adapter)
+             return all(
+                 (
+                     getattr(adapter.pipe, "_is_cached", False),
+                     getattr(adapter.transformer[0], "_is_cached", False),
+                 )
+             )
+         elif isinstance(
+             adapter,
+             (DiffusionPipeline, torch.nn.Module),
+         ):
+             return getattr(adapter, "_is_cached", False)
+         elif isinstance(adapter, list):  # [TRN_0,...]
+             assert isinstance(adapter[0], torch.nn.Module)
+             return getattr(adapter[0], "_is_cached", False)
+         else:
+             raise TypeError(f"Can't check this type: {adapter}!")
+
+     @classmethod
+     def flatten(cls, attr: List[List[Any]]):
+         if isinstance(attr, list):
+             if not isinstance(attr[0], list):
+                 return attr
+             flatten_attr = []
+             for i in range(len(attr)):
+                 flatten_attr.extend(attr[i])
+             return flatten_attr
+         return attr
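BlockAdapter is a plain dataclass plus a set of static helpers: auto_block_adapter discovers the transformer blocks on a pipeline, check_block_adapter validates the result, and normalize rewrites every field into the nested per-transformer list-of-lists layout that the cache blocks expect. Below is a rough usage sketch based only on the code above; `pipe` is assumed to be an already-loaded DiffusionPipeline, and picking the first entry of ForwardPattern.supported_patterns() is an illustrative shortcut, not the documented API.

# Hedged sketch: auto-discover blocks on a loaded pipeline, then normalize.
from cache_dit.cache_factory.forward_pattern import ForwardPattern
from cache_dit.cache_factory.block_adapters.block_adapters import BlockAdapter

pattern = ForwardPattern.supported_patterns()[0]  # assumes a list-like return value
adapter = BlockAdapter(pipe=pipe, auto=True, forward_pattern=pattern)
adapter = BlockAdapter.auto_block_adapter(adapter)  # fills blocks / blocks_name

if BlockAdapter.check_block_adapter(adapter):
    adapter = BlockAdapter.normalize(adapter)
    # After normalization every field is nested per transformer, e.g.
    # adapter.blocks_name == [["transformer_blocks"]] for a single-transformer pipe.

In practice the high-level entry point in cache_interface.py (also touched in this release) is presumably the intended way to enable caching; the sketch above only exercises the adapter plumbing.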
cache_dit/cache_factory/block_adapters/block_registers.py (new file, +77 lines)
@@ -0,0 +1,77 @@
+ from typing import Any, Tuple, List, Dict
+
+ from diffusers import DiffusionPipeline
+ from cache_dit.cache_factory.block_adapters.block_adapters import BlockAdapter
+
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ class BlockAdapterRegistry:
+     _adapters: Dict[str, BlockAdapter] = {}
+     _predefined_adapters_has_spearate_cfg: List[str] = {
+         "QwenImage",
+         "Wan",
+         "CogView4",
+         "Cosmos",
+         "SkyReelsV2",
+         "Chroma",
+     }
+
+     @classmethod
+     def register(cls, name):
+         def decorator(func):
+             cls._adapters[name] = func
+             return func
+
+         return decorator
+
+     @classmethod
+     def get_adapter(
+         cls,
+         pipe: DiffusionPipeline | str | Any,
+         **kwargs,
+     ) -> BlockAdapter:
+         if not isinstance(pipe, str):
+             pipe_cls_name: str = pipe.__class__.__name__
+         else:
+             pipe_cls_name = pipe
+
+         for name in cls._adapters:
+             if pipe_cls_name.startswith(name):
+                 return cls._adapters[name](pipe, **kwargs)
+
+         return BlockAdapter()
+
+     @classmethod
+     def has_separate_cfg(
+         cls,
+         pipe: DiffusionPipeline | str | Any,
+     ) -> bool:
+         if cls.get_adapter(
+             pipe,
+             disable_patch=True,
+         ).has_separate_cfg:
+             return True
+
+         pipe_cls_name = pipe.__class__.__name__
+         for name in cls._predefined_adapters_has_spearate_cfg:
+             if pipe_cls_name.startswith(name):
+                 return True
+
+         return False
+
+     @classmethod
+     def is_supported(cls, pipe) -> bool:
+         pipe_cls_name: str = pipe.__class__.__name__
+
+         for name in cls._adapters:
+             if pipe_cls_name.startswith(name):
+                 return True
+         return False
+
+     @classmethod
+     def supported_pipelines(cls, **kwargs) -> Tuple[int, List[str]]:
+         val_pipelines = cls._adapters.keys()
+         return len(val_pipelines), [p + "*" for p in val_pipelines]
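BlockAdapterRegistry maps a pipeline class-name prefix to a builder function that returns a configured BlockAdapter; get_adapter resolves by prefix match and falls back to an empty BlockAdapter. The following is a hedged sketch of how a registration might look; the pipeline name "MyPipeline", the builder _my_pipeline_adapter, and the transformer_blocks attribute are illustrative assumptions, not adapters shipped by the package.

# Hypothetical registration; only the registry calls themselves come from the code above.
@BlockAdapterRegistry.register("MyPipeline")
def _my_pipeline_adapter(pipe, **kwargs) -> BlockAdapter:
    return BlockAdapter(
        pipe=pipe,
        transformer=pipe.transformer,
        blocks=pipe.transformer.transformer_blocks,  # assumed attribute name
        blocks_name="transformer_blocks",
        forward_pattern=ForwardPattern.supported_patterns()[0],  # illustrative choice
        **kwargs,
    )

# Lookup is by class-name prefix, so any pipeline whose class name starts with
# "MyPipeline" resolves to the builder above; unknown pipelines get BlockAdapter().
adapter = BlockAdapterRegistry.get_adapter(pipe)
count, names = BlockAdapterRegistry.supported_pipelines()  # e.g. (N, ["MyPipeline*", ...])

The prefix match is what lets one registration cover several pipeline variants (for example, a class name and its img2img counterpart sharing the same prefix), which is presumably why supported_pipelines appends a "*" to each registered name.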