cache-dit 0.2.28__py3-none-any.whl → 0.2.29__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.


@@ -2,6 +2,7 @@ import torch
 
 import inspect
 import dataclasses
+from collections.abc import Iterable
 
 from typing import Any, Tuple, List, Optional, Union
 
@@ -82,9 +83,6 @@ class BlockAdapter:
     # Flags for separate cfg
     has_separate_cfg: bool = False
 
-    # Other Flags
-    disable_patch: bool = False
-
     # Flags to control auto block adapter
     # NOTE: NOT support for multi-transformers.
     auto: bool = False
@@ -107,15 +105,92 @@ class BlockAdapter:
         default="max", metadata={"allowed_values": ["max", "min"]}
     )
 
+    # Other Flags
+    skip_post_init: bool = False
+
     def __post_init__(self):
+        if self.skip_post_init:
+            return
         assert any((self.pipe is not None, self.transformer is not None))
-        self.patchify()
+        self.maybe_fill_attrs()
+        self.maybe_patchify()
+
+    def maybe_fill_attrs(self):
+        # NOTE: This func should be call before normalize.
+        # Allow empty `blocks_names`, we will auto fill it.
+        # TODO: preprocess more empty attrs.
+        if (
+            self.transformer is not None
+            and self.blocks is not None
+            and self.blocks_name is None
+        ):
 
-    def patchify(self, *args, **kwargs):
+            def _find(transformer, blocks):
+                attr_names = dir(transformer)
+                assert isinstance(blocks, torch.nn.ModuleList)
+                blocks_name = None
+                for attr_name in attr_names:
+                    if attr := getattr(transformer, attr_name, None):
+                        if isinstance(attr, torch.nn.ModuleList) and id(
+                            attr
+                        ) == id(blocks):
+                            blocks_name = attr_name
+                            break
+                assert (
+                    blocks_name is not None
+                ), "No blocks_name match, please set it manually!"
+                return blocks_name
+
+            if self.nested_depth(self.transformer) == 0:
+                if self.nested_depth(self.blocks) == 0:  # str
+                    self.blocks_name = _find(self.transformer, self.blocks)
+                elif self.nested_depth(self.blocks) == 1:
+                    self.blocks_name = [
+                        _find(self.transformer, blocks)
+                        for blocks in self.blocks
+                    ]
+                else:
+                    raise ValueError(
+                        "Blocks nested depth can't more than 1 if transformer "
+                        f"is not a list, current is: {self.nested_depth(self.blocks)}"
+                    )
+            elif self.nested_depth(self.transformer) == 1:  # List[str]
+                if self.nested_depth(self.blocks) == 1:  # List[str]
+                    assert len(self.transformer) == len(self.blocks)
+                    self.blocks_name = [
+                        _find(transformer, blocks)
+                        for transformer, blocks in zip(
+                            self.transformer, self.blocks
+                        )
+                    ]
+                elif self.nested_depth(self.blocks) == 2:  # List[List[str]]
+                    assert len(self.transformer) == len(self.blocks)
+                    self.blocks_name = []
+                    for i in range(len(self.blocks)):
+                        self.blocks_name.append(
+                            [
+                                _find(self.transformer[i], blocks)
+                                for blocks in self.blocks[i]
+                            ]
+                        )
+                else:
+                    raise ValueError(
+                        "Blocks nested depth can only be 1 or 2 "
+                        "if transformer is a list, current is: "
+                        f"{self.nested_depth(self.blocks)}"
+                    )
+            else:
+                raise ValueError(
+                    "transformer nested depth can't more than 1, "
+                    f"current is: {self.nested_depth(self.transformer)}"
+                )
+            logger.info(f"Auto fill blocks_name: {self.blocks_name}.")
+
+    def maybe_patchify(self, *args, **kwargs):
         # Process some specificial cases, specific for transformers
         # that has different forward patterns between single_transformer_blocks
         # and transformer_blocks , such as Flux (diffusers < 0.35.0).
-        if self.patch_functor is not None and not self.disable_patch:
+        if self.patch_functor is not None:
             if self.transformer is not None:
                 self.patch_functor.apply(self.transformer, *args, **kwargs)
             else:
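
Note: the new `maybe_fill_attrs` above auto-fills a missing `blocks_name` by matching the user-supplied `torch.nn.ModuleList` against the transformer's attributes by object identity. A minimal standalone sketch of that matching idea (the `ToyTransformer` and `find_blocks_name` names are illustrative stand-ins, not cache-dit APIs):

```python
import torch


class ToyTransformer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # The block list whose attribute name we want to recover.
        self.transformer_blocks = torch.nn.ModuleList(
            [torch.nn.Linear(8, 8) for _ in range(2)]
        )


def find_blocks_name(transformer, blocks):
    # Same idea as the nested `_find` helper above: scan attributes and match
    # the ModuleList by id(), so the attribute name can be filled automatically.
    for attr_name in dir(transformer):
        attr = getattr(transformer, attr_name, None)
        if isinstance(attr, torch.nn.ModuleList) and id(attr) == id(blocks):
            return attr_name
    raise ValueError("No blocks_name match, please set it manually!")


toy = ToyTransformer()
assert find_blocks_name(toy, toy.transformer_blocks) == "transformer_blocks"
```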
@@ -141,7 +216,7 @@ class BlockAdapter:
         transformer = pipe.transformer
 
         # "transformer_blocks", "blocks", "single_transformer_blocks", "layers"
-        blocks, blocks_name = BlockAdapter.find_blocks(
+        blocks, blocks_name = BlockAdapter.find_match_blocks(
             transformer=transformer,
             allow_prefixes=adapter.allow_prefixes,
             allow_suffixes=adapter.allow_suffixes,
@@ -164,6 +239,10 @@ class BlockAdapter:
     def check_block_adapter(
         adapter: "BlockAdapter",
     ) -> bool:
+
+        if getattr(adapter, "_is_normlized", False):
+            return True
+
         def _check_warning(attr: str):
             if getattr(adapter, attr, None) is None:
                 logger.warning(f"{attr} is None!")
@@ -185,24 +264,23 @@ class BlockAdapter:
         if not _check_warning("forward_pattern"):
             return False
 
-        if isinstance(adapter.blocks, list):
-            for i, blocks in enumerate(adapter.blocks):
-                if not isinstance(blocks, torch.nn.ModuleList):
-                    logger.warning(f"blocks[{i}] is not ModuleList.")
-                    return False
+        if BlockAdapter.nested_depth(adapter.blocks) == 0:
+            blocks = adapter.blocks
         else:
-            if not isinstance(adapter.blocks, torch.nn.ModuleList):
-                logger.warning("blocks is not ModuleList.")
-                return False
+            blocks = BlockAdapter.flatten(adapter.blocks)[0]
+
+        if not isinstance(blocks, torch.nn.ModuleList):
+            logger.warning("blocks is not ModuleList.")
+            return False
 
         return True
 
     @staticmethod
-    def find_blocks(
+    def find_match_blocks(
         transformer: torch.nn.Module,
         allow_prefixes: List[str] = [
-            "transformer",
-            "single_transformer",
+            "transformer_blocks",
+            "single_transformer_blocks",
             "blocks",
             "layers",
             "single_stream_blocks",
@@ -230,10 +308,10 @@ class BlockAdapter:
         valid_count = []
         forward_pattern = kwargs.pop("forward_pattern", None)
         for blocks_name in blocks_names:
-            if blocks := getattr(transformer, blocks_name, None):
+            if (blocks := getattr(transformer, blocks_name, None)) is not None:
                 if isinstance(blocks, torch.nn.ModuleList):
                     block = blocks[0]
-                    block_cls_name = block.__class__.__name__
+                    block_cls_name: str = block.__class__.__name__
                     # Check suffixes
                     if isinstance(block, torch.nn.Module) and (
                         any(
@@ -293,6 +371,18 @@ class BlockAdapter:
 
         return final_blocks, final_name
 
+    @staticmethod
+    def find_blocks(
+        transformer: torch.nn.Module,
+    ) -> List[torch.nn.ModuleList]:
+        total_blocks = []
+        for attr in dir(transformer):
+            if (blocks := getattr(transformer, attr, None)) is not None:
+                if isinstance(blocks, torch.nn.ModuleList):
+                    if isinstance(blocks[0], torch.nn.Module):
+                        total_blocks.append(blocks)
+        return total_blocks
+
     @staticmethod
     def match_block_pattern(
         block: torch.nn.Module,
@@ -373,103 +463,51 @@ class BlockAdapter:
         if getattr(adapter, "_is_normalized", False):
             return adapter
 
-        if not isinstance(adapter.transformer, list):
+        if BlockAdapter.nested_depth(adapter.transformer) == 0:
             adapter.transformer = [adapter.transformer]
 
-        if isinstance(adapter.blocks, torch.nn.ModuleList):
-            # blocks_0 = [[blocks_0,],] -> match [TRN_0,]
-            adapter.blocks = [[adapter.blocks]]
-        elif isinstance(adapter.blocks, list):
-            if isinstance(adapter.blocks[0], torch.nn.ModuleList):
-                # [blocks_0, blocks_1] -> [[blocks_0, blocks_1],] -> match [TRN_0,]
-                if len(adapter.blocks) == len(adapter.transformer):
-                    adapter.blocks = [[blocks] for blocks in adapter.blocks]
-                else:
-                    adapter.blocks = [adapter.blocks]
-            elif isinstance(adapter.blocks[0], list):
-                # [[blocks_0, blocks_1],[blocks_2, blocks_3],] -> match [TRN_0, TRN_1,]
-                pass
-
-        if isinstance(adapter.blocks_name, str):
-            adapter.blocks_name = [[adapter.blocks_name]]
-        elif isinstance(adapter.blocks_name, list):
-            if isinstance(adapter.blocks_name[0], str):
-                if len(adapter.blocks_name) == len(adapter.transformer):
-                    adapter.blocks_name = [
-                        [blocks_name] for blocks_name in adapter.blocks_name
-                    ]
-                else:
-                    adapter.blocks_name = [adapter.blocks_name]
-            elif isinstance(adapter.blocks_name[0], list):
-                pass
-
-        if isinstance(adapter.forward_pattern, ForwardPattern):
-            adapter.forward_pattern = [[adapter.forward_pattern]]
-        elif isinstance(adapter.forward_pattern, list):
-            if isinstance(adapter.forward_pattern[0], ForwardPattern):
-                if len(adapter.forward_pattern) == len(adapter.transformer):
-                    adapter.forward_pattern = [
-                        [forward_pattern]
-                        for forward_pattern in adapter.forward_pattern
-                    ]
-                else:
-                    adapter.forward_pattern = [adapter.forward_pattern]
-            elif isinstance(adapter.forward_pattern[0], list):
-                pass
-
-        if isinstance(adapter.dummy_blocks_names, list):
-            if len(adapter.dummy_blocks_names) > 0:
-                if isinstance(adapter.dummy_blocks_names[0], str):
-                    if len(adapter.dummy_blocks_names) == len(
-                        adapter.transformer
-                    ):
-                        adapter.dummy_blocks_names = [
-                            [dummy_blocks_names]
-                            for dummy_blocks_names in adapter.dummy_blocks_names
-                        ]
-                    else:
-                        adapter.dummy_blocks_names = [
-                            adapter.dummy_blocks_names
-                        ]
-                elif isinstance(adapter.dummy_blocks_names[0], list):
-                    pass
-            else:
-                # Empty dummy_blocks_names
-                adapter.dummy_blocks_names = [
-                    [] for _ in range(len(adapter.transformer))
-                ]
-
-        if adapter.params_modifiers is not None:
-            if isinstance(adapter.params_modifiers, ParamsModifier):
-                adapter.params_modifiers = [[adapter.params_modifiers]]
-            elif isinstance(adapter.params_modifiers, list):
-                if isinstance(adapter.params_modifiers[0], ParamsModifier):
-                    if len(adapter.params_modifiers) == len(
-                        adapter.transformer
-                    ):
-                        adapter.params_modifiers = [
-                            [params_modifiers]
-                            for params_modifiers in adapter.params_modifiers
-                        ]
+        def _normalize_attr(attr: Any):
+            normalized_attr = attr
+            if attr is None:
+                return normalized_attr
+
+            if BlockAdapter.nested_depth(attr) == 0:
+                normalized_attr = [[attr]]
+            elif BlockAdapter.nested_depth(attr) == 1:  # List
+                if attr:  # not-empty
+                    if len(attr) == len(adapter.transformer):
+                        normalized_attr = [[a] for a in attr]
                     else:
-                        adapter.params_modifiers = [adapter.params_modifiers]
-                elif isinstance(adapter.params_modifiers[0], list):
-                    pass
+                        normalized_attr = [attr]
+                else:  # [] empty
+                    normalized_attr = [
+                        [] for _ in range(len(adapter.transformer))
+                    ]
+
+            assert len(adapter.transformer) == len(normalized_attr)
+            return normalized_attr
+
+        adapter.blocks = _normalize_attr(adapter.blocks)
+        adapter.blocks_name = _normalize_attr(adapter.blocks_name)
+        adapter.forward_pattern = _normalize_attr(adapter.forward_pattern)
+        adapter.dummy_blocks_names = _normalize_attr(adapter.dummy_blocks_names)
+        adapter.params_modifiers = _normalize_attr(adapter.params_modifiers)
+        BlockAdapter.unique(adapter)
 
-        assert len(adapter.transformer) == len(adapter.blocks)
-        assert len(adapter.transformer) == len(adapter.blocks_name)
-        assert len(adapter.transformer) == len(adapter.forward_pattern)
-        assert len(adapter.transformer) == len(adapter.dummy_blocks_names)
-        if adapter.params_modifiers is not None:
-            assert len(adapter.transformer) == len(adapter.params_modifiers)
+        adapter._is_normalized = True
+
+        return adapter
 
+    @classmethod
+    def unique(cls, adapter: "BlockAdapter"):
+        # NOTE: Users should never call this function
         for i in range(len(adapter.blocks)):
             assert len(adapter.blocks[i]) == len(adapter.blocks_name[i])
             assert len(adapter.blocks[i]) == len(adapter.forward_pattern[i])
 
+        # Generate unique blocks names
         if len(adapter.unique_blocks_name) == 0:
             for i in range(len(adapter.transformer)):
-                # Generate unique blocks names
                 adapter.unique_blocks_name.append(
                     [
                         f"{name}_{hash(id(blocks))}"
@@ -479,10 +517,10 @@ class BlockAdapter:
                         )
                     ]
                 )
+        else:
+            assert len(adapter.transformer) == len(adapter.unique_blocks_name)
 
-        assert len(adapter.transformer) == len(adapter.unique_blocks_name)
-
-        # Match Forward Pattern
+        # Also check Match Forward Pattern
         for i in range(len(adapter.transformer)):
             for forward_pattern, blocks in zip(
                 adapter.forward_pattern[i], adapter.blocks[i]
@@ -496,10 +534,6 @@ class BlockAdapter:
                     f"supported lists: {ForwardPattern.supported_patterns()}"
                 )
 
-        adapter._is_normalized = True
-
-        return adapter
-
     @classmethod
     def assert_normalized(cls, adapter: "BlockAdapter"):
         if not getattr(adapter, "_is_normalized", False):
@@ -527,12 +561,46 @@ class BlockAdapter:
             raise TypeError(f"Can't check this type: {adapter}!")
 
     @classmethod
-    def flatten(cls, attr: List[List[Any]]):
-        if isinstance(attr, list):
-            if not isinstance(attr[0], list):
-                return attr
-            flatten_attr = []
-            for i in range(len(attr)):
-                flatten_attr.extend(attr[i])
-            return flatten_attr
-        return attr
+    def nested_depth(cls, obj: Any):
+        # str: 0; List[str]: 1; List[List[str]]: 2
+        atom_types = (
+            str,
+            bytes,
+            torch.nn.ModuleList,
+            torch.nn.Module,
+            torch.Tensor,
+        )
+        if isinstance(obj, atom_types):
+            return 0
+        if not isinstance(obj, Iterable):
+            return 0
+        if isinstance(obj, dict):
+            items = obj.values()
+        else:
+            items = obj
+
+        max_depth = 0
+        for item in items:
+            current_depth = cls.nested_depth(item)
+            if current_depth > max_depth:
+                max_depth = current_depth
+        return 1 + max_depth
+
+    @classmethod
+    def flatten(cls, attr: List[Any]) -> List[Any]:
+        atom_types = (
+            str,
+            bytes,
+            torch.nn.ModuleList,
+            torch.nn.Module,
+            torch.Tensor,
+        )
+        if not isinstance(attr, list):
+            return attr
+        flattened = []
+        for item in attr:
+            if isinstance(item, list) and not isinstance(item, atom_types):
+                flattened.extend(cls.flatten(item))
+            else:
+                flattened.append(item)
+        return flattened
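
Note: `nested_depth` and the recursive `flatten` above replace the old single-level `flatten`, letting `normalize`, `maybe_fill_attrs`, and `check_block_adapter` treat str / List[str] / List[List[str]] style inputs uniformly. A simplified, standalone restatement of the depth rule (it drops the torch atom types for brevity and is not cache-dit code):

```python
from collections.abc import Iterable


def nested_depth(obj, atom_types=(str, bytes)):
    # Atoms count as depth 0; each level of list nesting adds 1.
    if isinstance(obj, atom_types) or not isinstance(obj, Iterable):
        return 0
    items = obj.values() if isinstance(obj, dict) else obj
    return 1 + max((nested_depth(item, atom_types) for item in items), default=0)


assert nested_depth("transformer_blocks") == 0           # single attribute name
assert nested_depth(["blocks_a", "blocks_b"]) == 1       # one transformer, several block lists
assert nested_depth([["blocks_a"], ["blocks_b"]]) == 2   # one inner list per transformer
```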
@@ -47,15 +47,24 @@ class BlockAdapterRegistry:
     @classmethod
     def has_separate_cfg(
         cls,
-        pipe: DiffusionPipeline | str | Any,
+        pipe_or_adapter: DiffusionPipeline | BlockAdapter | Any,
     ) -> bool:
-        if cls.get_adapter(
-            pipe,
-            disable_patch=True,
-        ).has_separate_cfg:
+
+        # Prefer custom setting from block adapter.
+        if isinstance(pipe_or_adapter, BlockAdapter):
+            return pipe_or_adapter.has_separate_cfg
+
+        has_separate_cfg = False
+        if isinstance(pipe_or_adapter, DiffusionPipeline):
+            has_separate_cfg = cls.get_adapter(
+                pipe_or_adapter,
+                skip_post_init=True,  # check cfg setting only
+            ).has_separate_cfg
+
+        if has_separate_cfg:
             return True
 
-        pipe_cls_name = pipe.__class__.__name__
+        pipe_cls_name = pipe_or_adapter.__class__.__name__
         for name in cls._predefined_adapters_has_spearate_cfg:
             if pipe_cls_name.startswith(name):
                 return True
@@ -29,36 +29,39 @@ class CachedAdapter:
     @classmethod
     def apply(
         cls,
-        pipe: DiffusionPipeline = None,
-        block_adapter: BlockAdapter = None,
+        pipe_or_adapter: DiffusionPipeline | BlockAdapter,
         **cache_context_kwargs,
-    ) -> DiffusionPipeline:
+    ) -> BlockAdapter:
         assert (
-            pipe is not None or block_adapter is not None
+            pipe_or_adapter is not None
         ), "pipe or block_adapter can not both None!"
 
-        if pipe is not None:
-            if BlockAdapterRegistry.is_supported(pipe):
+        if isinstance(pipe_or_adapter, DiffusionPipeline):
+            if BlockAdapterRegistry.is_supported(pipe_or_adapter):
                 logger.info(
-                    f"{pipe.__class__.__name__} is officially supported by cache-dit. "
-                    "Use it's pre-defined BlockAdapter directly!"
+                    f"{pipe_or_adapter.__class__.__name__} is officially "
+                    "supported by cache-dit. Use it's pre-defined BlockAdapter "
+                    "directly!"
+                )
+                block_adapter = BlockAdapterRegistry.get_adapter(
+                    pipe_or_adapter
                 )
-                block_adapter = BlockAdapterRegistry.get_adapter(pipe)
                 return cls.cachify(
                     block_adapter,
                     **cache_context_kwargs,
                 )
             else:
                 raise ValueError(
-                    f"{pipe.__class__.__name__} is not officially supported "
+                    f"{pipe_or_adapter.__class__.__name__} is not officially supported "
                     "by cache-dit, please set BlockAdapter instead!"
                 )
         else:
+            assert isinstance(pipe_or_adapter, BlockAdapter)
            logger.info(
-                "Adapting cache acceleration using custom BlockAdapter!"
+                "Adapting Cache Acceleration using custom BlockAdapter!"
            )
            return cls.cachify(
-                block_adapter,
+                pipe_or_adapter,
                **cache_context_kwargs,
            )
 
@@ -67,7 +70,7 @@ class CachedAdapter:
         cls,
         block_adapter: BlockAdapter,
         **cache_context_kwargs,
-    ) -> DiffusionPipeline:
+    ) -> BlockAdapter:
 
         if block_adapter.auto:
             block_adapter = BlockAdapter.auto_block_adapter(
@@ -93,7 +96,7 @@ class CachedAdapter:
                 block_adapter,
             )
 
-        return block_adapter.pipe
+        return block_adapter
 
     @classmethod
     def patch_params(
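
Note: the hunks above change `CachedAdapter.apply` and `cachify` to accept a single positional pipeline-or-adapter and to return the normalized `BlockAdapter` instead of the `DiffusionPipeline`. A hedged before/after sketch of a call site; the import path is an assumption, since this diff does not show module names:

```python
# from cache_dit... import CachedAdapter  # exact module path assumed here

# cache-dit 0.2.28: keyword arguments, the pipeline was returned.
#   pipe = CachedAdapter.apply(pipe=pipe, **cache_context_kwargs)

# cache-dit 0.2.29: one positional pipe-or-adapter, the BlockAdapter is
# returned; the wrapped pipeline is still reachable via `.pipe`.
#   adapter = CachedAdapter.apply(pipe, **cache_context_kwargs)
#   pipe = adapter.pipe
```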
@@ -126,18 +129,29 @@ class CachedAdapter:
                 params_shift += len(blocks)
 
     @classmethod
-    def check_context_kwargs(cls, pipe, **cache_context_kwargs):
+    def check_context_kwargs(
+        cls,
+        block_adapter: BlockAdapter,
+        **cache_context_kwargs,
+    ):
         # Check cache_context_kwargs
         if not cache_context_kwargs["enable_spearate_cfg"]:
             # Check cfg for some specific case if users don't set it as True
-            cache_context_kwargs["enable_spearate_cfg"] = (
-                BlockAdapterRegistry.has_separate_cfg(pipe)
-            )
-            logger.info(
-                f"Use default 'enable_spearate_cfg': "
-                f"{cache_context_kwargs['enable_spearate_cfg']}, "
-                f"Pipeline: {pipe.__class__.__name__}."
-            )
+            if BlockAdapterRegistry.has_separate_cfg(block_adapter):
+                cache_context_kwargs["enable_spearate_cfg"] = True
+                logger.info(
+                    f"Use custom 'enable_spearate_cfg' from BlockAdapter: True. "
+                    f"Pipeline: {block_adapter.pipe.__class__.__name__}."
+                )
+            else:
+                cache_context_kwargs["enable_spearate_cfg"] = (
+                    BlockAdapterRegistry.has_separate_cfg(block_adapter.pipe)
+                )
+                logger.info(
+                    f"Use default 'enable_spearate_cfg' from block adapter "
+                    f"register: {cache_context_kwargs['enable_spearate_cfg']}, "
+                    f"Pipeline: {block_adapter.pipe.__class__.__name__}."
+                )
 
         if cache_type := cache_context_kwargs.pop("cache_type", None):
             assert (
@@ -160,8 +174,7 @@ class CachedAdapter:
 
         # Check cache_context_kwargs
         cache_context_kwargs = cls.check_context_kwargs(
-            block_adapter.pipe,
-            **cache_context_kwargs,
+            block_adapter, **cache_context_kwargs
         )
         # Apply cache on pipeline: wrap cache context
         pipe_cls_name = block_adapter.pipe.__class__.__name__
@@ -23,3 +23,19 @@ def patch_cached_stats(
     module._residual_diffs = cache_manager.get_residual_diffs()
     module._cfg_cached_steps = cache_manager.get_cfg_cached_steps()
    module._cfg_residual_diffs = cache_manager.get_cfg_residual_diffs()
+
+
+def remove_cached_stats(
+    module: torch.nn.Module | Any,
+):
+    if module is None:
+        return
+
+    if hasattr(module, "_cached_steps"):
+        del module._cached_steps
+    if hasattr(module, "_residual_diffs"):
+        del module._residual_diffs
+    if hasattr(module, "_cfg_cached_steps"):
+        del module._cfg_cached_steps
+    if hasattr(module, "_cfg_residual_diffs"):
+        del module._cfg_residual_diffs
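
Note: the new `remove_cached_stats` undoes what `patch_cached_stats` attaches. A toy illustration of the same hasattr/del cleanup pattern on a plain module (not cache-dit code; the attribute names follow the diff above):

```python
import torch

m = torch.nn.Linear(4, 4)
m._cached_steps = [0, 1, 2]  # stand-in for stats patched onto the module

# Delete each stats attribute only if it is present, as remove_cached_stats does.
for attr in (
    "_cached_steps",
    "_residual_diffs",
    "_cfg_cached_steps",
    "_cfg_residual_diffs",
):
    if hasattr(m, attr):
        delattr(m, attr)

assert not hasattr(m, "_cached_steps")
```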
@@ -63,6 +63,20 @@ class CachedContextManager:
         _context = self.new_context(*args, **kwargs)
         return _context
 
+    def remove_context(self, cached_context: CachedContext | str):
+        if isinstance(cached_context, CachedContext):
+            cached_context.clear_buffers()
+            if cached_context.name in self._cached_context_manager:
+                del self._cached_context_manager[cached_context.name]
+        else:
+            if cached_context in self._cached_context_manager:
+                self._cached_context_manager[cached_context].clear_buffers()
+                del self._cached_context_manager[cached_context]
+
+    def clear_contexts(self):
+        for cached_context in self._cached_context_manager:
+            self.remove_context(cached_context)
+
     @contextlib.contextmanager
     def enter_context(self, cached_context: CachedContext | str):
         old_cached_context = self._current_context