cache_dit-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of cache-dit might be problematic.
- cache_dit/__init__.py +0 -0
- cache_dit/_version.py +21 -0
- cache_dit/cache_factory/__init__.py +166 -0
- cache_dit/cache_factory/dual_block_cache/__init__.py +0 -0
- cache_dit/cache_factory/dual_block_cache/cache_context.py +1361 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +45 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +89 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +100 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +88 -0
- cache_dit/cache_factory/dynamic_block_prune/__init__.py +0 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +45 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +89 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +100 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +89 -0
- cache_dit/cache_factory/dynamic_block_prune/prune_context.py +979 -0
- cache_dit/cache_factory/first_block_cache/__init__.py +0 -0
- cache_dit/cache_factory/first_block_cache/cache_context.py +727 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +53 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py +89 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +100 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py +89 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +98 -0
- cache_dit/cache_factory/taylorseer.py +76 -0
- cache_dit/cache_factory/utils.py +0 -0
- cache_dit/logger.py +97 -0
- cache_dit/primitives.py +152 -0
- cache_dit-0.1.0.dist-info/METADATA +350 -0
- cache_dit-0.1.0.dist-info/RECORD +31 -0
- cache_dit-0.1.0.dist-info/WHEEL +5 -0
- cache_dit-0.1.0.dist-info/licenses/LICENSE +53 -0
- cache_dit-0.1.0.dist-info/top_level.txt +1 -0
cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py
ADDED
@@ -0,0 +1,53 @@
# Adapted from: https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache/__init__.py

import importlib
from typing import Callable

from diffusers import DiffusionPipeline


def apply_fb_cache_on_transformer(transformer, *args, **kwargs):
    transformer_cls_name: str = transformer.__class__.__name__
    if transformer_cls_name.startswith("Flux"):
        adapter_name = "flux"
    elif transformer_cls_name.startswith("Mochi"):
        adapter_name = "mochi"
    elif transformer_cls_name.startswith("CogVideoX"):
        adapter_name = "cogvideox"
    elif transformer_cls_name.startswith("Wan"):
        adapter_name = "wan"
    else:
        raise ValueError(
            f"Unknown transformer class name: {transformer_cls_name}"
        )

    adapter_module = importlib.import_module(f".{adapter_name}", __package__)
    apply_cache_on_transformer_fn = getattr(
        adapter_module, "apply_cache_on_transformer"
    )
    return apply_cache_on_transformer_fn(transformer, *args, **kwargs)


def apply_fb_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
    assert isinstance(pipe, DiffusionPipeline)

    pipe_cls_name: str = pipe.__class__.__name__
    if pipe_cls_name.startswith("Flux"):
        adapter_name = "flux"
    elif pipe_cls_name.startswith("Mochi"):
        adapter_name = "mochi"
    elif pipe_cls_name.startswith("CogVideoX"):
        adapter_name = "cogvideox"
    elif pipe_cls_name.startswith("Wan"):
        adapter_name = "wan"
    else:
        raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")

    adapter_module = importlib.import_module(f".{adapter_name}", __package__)
    apply_cache_on_pipe_fn = getattr(adapter_module, "apply_cache_on_pipe")
    return apply_cache_on_pipe_fn(pipe, *args, **kwargs)


# re-export functions for compatibility
apply_cache_on_transformer: Callable = apply_fb_cache_on_transformer
apply_cache_on_pipe: Callable = apply_fb_cache_on_pipe
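Since both entry points dispatch purely on the class-name prefix, any Diffusers pipeline whose class name starts with "Flux", "Mochi", "CogVideoX", or "Wan" resolves to the matching adapter submodule, and anything else hits the ValueError branch. A minimal usage sketch (model id and dtype are illustrative, not taken from this package):

import torch
from diffusers import FluxPipeline

from cache_dit.cache_factory.first_block_cache.diffusers_adapters import (
    apply_fb_cache_on_pipe,
)

# "FluxPipeline".startswith("Flux") -> the "flux" adapter is imported lazily
# via importlib and applied; extra args/kwargs are forwarded to it unchanged.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative model id
    torch_dtype=torch.bfloat16,
)
apply_fb_cache_on_pipe(pipe)

# A StableDiffusionPipeline, by contrast, would raise
# ValueError("Unknown pipeline class name: StableDiffusionPipeline").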
cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py
ADDED
@@ -0,0 +1,89 @@
# Adapted from: https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache/cogvideox.py

import functools
import unittest.mock

import torch
from diffusers import CogVideoXTransformer3DModel, DiffusionPipeline

from cache_dit.cache_factory.first_block_cache import cache_context


def apply_cache_on_transformer(
    transformer: CogVideoXTransformer3DModel,
):
    if getattr(transformer, "_is_cached", False):
        return transformer

    cached_transformer_blocks = torch.nn.ModuleList(
        [
            cache_context.CachedTransformerBlocks(
                transformer.transformer_blocks,
                transformer=transformer,
            )
        ]
    )

    original_forward = transformer.forward

    @functools.wraps(transformer.__class__.forward)
    def new_forward(
        self,
        *args,
        **kwargs,
    ):
        with unittest.mock.patch.object(
            self,
            "transformer_blocks",
            cached_transformer_blocks,
        ):
            return original_forward(
                *args,
                **kwargs,
            )

    transformer.forward = new_forward.__get__(transformer)

    transformer._is_cached = True

    return transformer


def apply_cache_on_pipe(
    pipe: DiffusionPipeline,
    *,
    shallow_patch: bool = False,
    residual_diff_threshold=0.04,
    downsample_factor=1,
    warmup_steps=0,
    max_cached_steps=-1,
    **kwargs,
):
    cache_kwargs, kwargs = cache_context.collect_cache_kwargs(
        default_attrs={
            "residual_diff_threshold": residual_diff_threshold,
            "downsample_factor": downsample_factor,
            "warmup_steps": warmup_steps,
            "max_cached_steps": max_cached_steps,
        },
        **kwargs,
    )
    if not getattr(pipe, "_is_cached", False):
        original_call = pipe.__class__.__call__

        @functools.wraps(original_call)
        def new_call(self, *args, **kwargs):
            with cache_context.cache_context(
                cache_context.create_cache_context(
                    **cache_kwargs,
                )
            ):
                return original_call(self, *args, **kwargs)

        pipe.__class__.__call__ = new_call
        pipe.__class__._is_cached = True

    if not shallow_patch:
        apply_cache_on_transformer(pipe.transformer, **kwargs)

    return pipe
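Both patch sites here lean on the same two idioms: unittest.mock.patch.object swaps transformer_blocks for the cached wrapper only for the duration of one forward call, and new_forward.__get__(transformer) uses the descriptor protocol to bind the plain function as a method of this one instance, leaving the class and other instances untouched. A self-contained sketch of that pattern (toy class, nothing from this package):

import unittest.mock


class Model:
    blocks = "original"

    def forward(self):
        return self.blocks


def new_forward(self):
    # Swap the attribute only while the original forward runs.
    with unittest.mock.patch.object(self, "blocks", "cached"):
        return Model.forward(self)


m = Model()
m.forward = new_forward.__get__(m)  # bind to this instance only

print(m.forward())        # "cached"
print(Model().forward())  # "original" -- other instances are unaffected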
cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py
ADDED
@@ -0,0 +1,100 @@
# Adapted from: https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache/flux.py

import functools
import unittest.mock

import torch
from diffusers import DiffusionPipeline, FluxTransformer2DModel

from cache_dit.cache_factory.first_block_cache import cache_context


def apply_cache_on_transformer(
    transformer: FluxTransformer2DModel,
):
    if getattr(transformer, "_is_cached", False):
        return transformer

    cached_transformer_blocks = torch.nn.ModuleList(
        [
            cache_context.CachedTransformerBlocks(
                transformer.transformer_blocks,
                transformer.single_transformer_blocks,
                transformer=transformer,
                return_hidden_states_first=False,
            )
        ]
    )
    # The cached wrapper runs both block lists itself, so the model's own
    # single_transformer_blocks loop is swapped for an empty ModuleList below.
    dummy_single_transformer_blocks = torch.nn.ModuleList()

    original_forward = transformer.forward

    @functools.wraps(original_forward)
    def new_forward(
        self,
        *args,
        **kwargs,
    ):
        with (
            unittest.mock.patch.object(
                self,
                "transformer_blocks",
                cached_transformer_blocks,
            ),
            unittest.mock.patch.object(
                self,
                "single_transformer_blocks",
                dummy_single_transformer_blocks,
            ),
        ):
            return original_forward(
                *args,
                **kwargs,
            )

    transformer.forward = new_forward.__get__(transformer)

    transformer._is_cached = True

    return transformer


def apply_cache_on_pipe(
    pipe: DiffusionPipeline,
    *,
    shallow_patch: bool = False,
    residual_diff_threshold=0.05,
    downsample_factor=1,
    warmup_steps=0,
    max_cached_steps=-1,
    **kwargs,
):
    cache_kwargs, kwargs = cache_context.collect_cache_kwargs(
        default_attrs={
            "residual_diff_threshold": residual_diff_threshold,
            "downsample_factor": downsample_factor,
            "warmup_steps": warmup_steps,
            "max_cached_steps": max_cached_steps,
        },
        **kwargs,
    )  # noqa

    if not getattr(pipe, "_is_cached", False):
        original_call = pipe.__class__.__call__

        @functools.wraps(original_call)
        def new_call(self, *args, **kwargs):
            with cache_context.cache_context(
                cache_context.create_cache_context(
                    **cache_kwargs,
                )
            ):
                return original_call(self, *args, **kwargs)

        pipe.__class__.__call__ = new_call
        pipe.__class__._is_cached = True

    if not shallow_patch:
        apply_cache_on_transformer(pipe.transformer, **kwargs)

    return pipe
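Flux is the one adapter that handles two block lists: both transformer_blocks and single_transformer_blocks are handed to a single CachedTransformerBlocks wrapper, and the model's own single_transformer_blocks loop is replaced with an empty ModuleList so those blocks are not run twice. A hedged usage sketch; the tuning semantics assumed here are the usual first-block-cache ones (a larger residual_diff_threshold lets more steps reuse cached residuals, trading quality for speed, while warmup_steps forces the first N steps to run in full):

import torch
from diffusers import FluxPipeline

from cache_dit.cache_factory.first_block_cache.diffusers_adapters.flux import (
    apply_cache_on_pipe,
)

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative model id
    torch_dtype=torch.bfloat16,
).to("cuda")

apply_cache_on_pipe(
    pipe,
    residual_diff_threshold=0.08,  # the Flux default above is 0.05
    warmup_steps=2,
)
image = pipe("a cabin in the snow", num_inference_steps=28).images[0]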
cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py
ADDED
@@ -0,0 +1,89 @@
# Adapted from: https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache/mochi.py

import functools
import unittest.mock

import torch
from diffusers import DiffusionPipeline, MochiTransformer3DModel

from cache_dit.cache_factory.first_block_cache import cache_context


def apply_cache_on_transformer(
    transformer: MochiTransformer3DModel,
):
    if getattr(transformer, "_is_cached", False):
        return transformer

    cached_transformer_blocks = torch.nn.ModuleList(
        [
            cache_context.CachedTransformerBlocks(
                transformer.transformer_blocks,
                transformer=transformer,
            )
        ]
    )

    original_forward = transformer.forward

    @functools.wraps(transformer.__class__.forward)
    def new_forward(
        self,
        *args,
        **kwargs,
    ):
        with unittest.mock.patch.object(
            self,
            "transformer_blocks",
            cached_transformer_blocks,
        ):
            return original_forward(
                *args,
                **kwargs,
            )

    transformer.forward = new_forward.__get__(transformer)

    transformer._is_cached = True

    return transformer


def apply_cache_on_pipe(
    pipe: DiffusionPipeline,
    *,
    shallow_patch: bool = False,
    residual_diff_threshold=0.06,
    downsample_factor=1,
    warmup_steps=0,
    max_cached_steps=-1,
    **kwargs,
):
    cache_kwargs, kwargs = cache_context.collect_cache_kwargs(
        default_attrs={
            "residual_diff_threshold": residual_diff_threshold,
            "downsample_factor": downsample_factor,
            "warmup_steps": warmup_steps,
            "max_cached_steps": max_cached_steps,
        },
        **kwargs,
    )
    if not getattr(pipe, "_is_cached", False):
        original_call = pipe.__class__.__call__

        @functools.wraps(original_call)
        def new_call(self, *args, **kwargs):
            with cache_context.cache_context(
                cache_context.create_cache_context(
                    **cache_kwargs,
                )
            ):
                return original_call(self, *args, **kwargs)

        pipe.__class__.__call__ = new_call
        pipe.__class__._is_cached = True

    if not shallow_patch:
        apply_cache_on_transformer(pipe.transformer, **kwargs)

    return pipe
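One subtlety, visible in this adapter and shared by all of them: apply_cache_on_transformer records _is_cached on the transformer instance, while apply_cache_on_pipe records it on the pipeline class. With two pipelines of the same class, the second call therefore skips the __call__ re-wrap (and keeps the first call's cache_kwargs, which live in the new_call closure) while still patching the second pipeline's own transformer. A sketch with hypothetical pipelines:

# pipe_a and pipe_b are two MochiPipeline instances (illustrative).
apply_cache_on_pipe(pipe_a, residual_diff_threshold=0.06)

# MochiPipeline._is_cached is already True, so __call__ is not re-wrapped
# and residual_diff_threshold=0.10 below is collected but never installed;
# pipe_b's transformer, however, is still freshly patched here.
apply_cache_on_pipe(pipe_b, residual_diff_threshold=0.10)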
cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py
ADDED
@@ -0,0 +1,98 @@
# Adapted from: https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache/wan.py

import functools
import unittest.mock

import torch
# The released file imports and annotates with HunyuanVideoTransformer3DModel,
# which looks like a copy-paste slip: what is actually patched below is the
# Wan transformer and its `blocks` attribute.
from diffusers import DiffusionPipeline, WanTransformer3DModel

from cache_dit.cache_factory.first_block_cache import cache_context


def apply_cache_on_transformer(
    transformer: WanTransformer3DModel,
):
    if getattr(transformer, "_is_cached", False):
        return transformer

    blocks = torch.nn.ModuleList(
        [
            cache_context.CachedTransformerBlocks(
                transformer.blocks,
                transformer=transformer,
                return_hidden_states_only=True,
            )
        ]
    )

    original_forward = transformer.forward

    @functools.wraps(transformer.__class__.forward)
    def new_forward(
        self,
        *args,
        **kwargs,
    ):
        with unittest.mock.patch.object(
            self,
            "blocks",
            blocks,
        ):
            return original_forward(
                *args,
                **kwargs,
            )

    transformer.forward = new_forward.__get__(transformer)

    transformer._is_cached = True

    return transformer


def apply_cache_on_pipe(
    pipe: DiffusionPipeline,
    *,
    shallow_patch: bool = False,
    residual_diff_threshold=0.03,
    downsample_factor=1,
    slg_layers=None,
    slg_start: float = 0.0,
    slg_end: float = 0.1,
    warmup_steps=0,
    max_cached_steps=-1,
    **kwargs,
):
    cache_kwargs, kwargs = cache_context.collect_cache_kwargs(
        default_attrs={
            "residual_diff_threshold": residual_diff_threshold,
            "downsample_factor": downsample_factor,
            "enable_alter_cache": True,
            "slg_layers": slg_layers,
            "slg_start": slg_start,
            "slg_end": slg_end,
            "num_inference_steps": kwargs.get("num_inference_steps", 50),
            "warmup_steps": warmup_steps,
            "max_cached_steps": max_cached_steps,
        },
        **kwargs,
    )
    if not getattr(pipe, "_is_cached", False):
        original_call = pipe.__class__.__call__

        @functools.wraps(original_call)
        def new_call(self, *args, **kwargs):
            with cache_context.cache_context(
                cache_context.create_cache_context(
                    **cache_kwargs,
                )
            ):
                return original_call(self, *args, **kwargs)

        pipe.__class__.__call__ = new_call
        pipe.__class__._is_cached = True

    if not shallow_patch:
        apply_cache_on_transformer(pipe.transformer, **kwargs)

    return pipe
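Beyond the shared knobs, the Wan adapter switches on enable_alter_cache and threads skip-layer-guidance settings (slg_layers, slg_start, slg_end) plus the step count into the cache context; note that num_inference_steps is read out of **kwargs with a default of 50, so passing the sampler's real step count keeps the cache schedule aligned. A sketch under the assumption that collect_cache_kwargs consumes the keys named in default_attrs, as the wiring above suggests; the model id is illustrative:

import torch
from diffusers import WanPipeline

from cache_dit.cache_factory.first_block_cache.diffusers_adapters.wan import (
    apply_cache_on_pipe,
)

pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",  # illustrative model id
    torch_dtype=torch.bfloat16,
)

apply_cache_on_pipe(
    pipe,
    residual_diff_threshold=0.03,
    slg_layers=[9],          # illustrative choice of guidance-skipped block
    num_inference_steps=30,  # forwarded into the cache context (default 50)
)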
cache_dit/cache_factory/taylorseer.py
ADDED
@@ -0,0 +1,76 @@
# Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/taylorseer.py
import math


class TaylorSeer:
    def __init__(
        self,
        n_derivatives=2,
        warmup_steps=1,
        skip_interval_steps=1,
        compute_step_map=None,
    ):
        self.n_derivatives = n_derivatives
        self.ORDER = n_derivatives + 1
        self.warmup_steps = warmup_steps
        self.skip_interval_steps = skip_interval_steps
        self.compute_step_map = compute_step_map
        self.reset_cache()

    def reset_cache(self):
        self.state = {
            "dY_prev": [None] * self.ORDER,
            "dY_current": [None] * self.ORDER,
        }
        self.current_step = -1
        self.last_non_approximated_step = -1

    def should_compute_full(self, step=None):
        step = self.current_step if step is None else step
        if self.compute_step_map is not None:
            return self.compute_step_map[step]
        if (
            step < self.warmup_steps
            or (step - self.warmup_steps + 1) % self.skip_interval_steps == 0
        ):
            return True
        return False

    def approximate_derivative(self, Y):
        dY_current = [None] * self.ORDER
        dY_current[0] = Y
        window = self.current_step - self.last_non_approximated_step
        for i in range(self.n_derivatives):
            if self.state["dY_prev"][i] is not None and self.current_step > 1:
                dY_current[i + 1] = (
                    dY_current[i] - self.state["dY_prev"][i]
                ) / window
            else:
                break
        return dY_current

    def approximate_value(self):
        elapsed = self.current_step - self.last_non_approximated_step
        output = 0
        for i, derivative in enumerate(self.state["dY_current"]):
            if derivative is not None:
                output += (1 / math.factorial(i)) * derivative * (elapsed**i)
            else:
                break
        return output

    def mark_step_begin(self):
        self.current_step += 1

    def update(self, Y):
        self.state["dY_prev"] = self.state["dY_current"]
        self.state["dY_current"] = self.approximate_derivative(Y)
        self.last_non_approximated_step = self.current_step

    def step(self, Y):
        self.mark_step_begin()
        if self.should_compute_full():
            self.update(Y)
            return Y
        else:
            return self.approximate_value()
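TaylorSeer keeps the last fully computed value together with finite-difference estimates of its first n_derivatives derivatives, then fills skipped steps with a truncated Taylor expansion: for k steps past the last full computation, y is approximated as sum_i dY[i] * k**i / i!. A small numeric walk-through with plain floats (in the cache these would be residual tensors):

seer = TaylorSeer(n_derivatives=2, warmup_steps=1, skip_interval_steps=2)

outputs = [seer.step(float(t * t)) for t in range(6)]  # feed y(t) = t^2

# Steps 0, 2, 4 are computed in full (warmup, then every 2nd step) and
# returned as-is; steps 1, 3, 5 are extrapolated from the last full step.
print(outputs)  # [0.0, 0.0, 4.0, 6.0, 16.0, 23.0]

The extrapolations drift from the true values (1, 9, 25) because the derivative estimates are backward finite differences; shrinking skip_interval_steps tightens them at the cost of more full computations.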
cache_dit/cache_factory/utils.py
File without changes
cache_dit/logger.py
ADDED
@@ -0,0 +1,97 @@
import logging
import os
import sys

_FORMAT = "%(levelname)s %(asctime)s [%(filename)s:%(lineno)d] %(message)s"
_DATE_FORMAT = "%m-%d %H:%M:%S"

_LOG_LEVEL = os.environ.get("CACHE_DIT_LOG_LEVEL", "info")
_LOG_LEVEL = getattr(logging, _LOG_LEVEL.upper(), 0)
_LOG_DIR = os.environ.get("CACHE_DIT_LOG_DIR", None)


class NewLineFormatter(logging.Formatter):
    """Adds logging prefix to newlines to align multi-line messages."""

    def __init__(self, fmt, datefmt=None):
        logging.Formatter.__init__(self, fmt, datefmt)

    def format(self, record):
        msg = logging.Formatter.format(self, record)
        if record.message != "":
            parts = msg.split(record.message)
            msg = msg.replace("\n", "\r\n" + parts[0])
        return msg


_root_logger = logging.getLogger("CACHE_DIT")
_default_handler = None
_default_file_handler = None
_inference_log_file_handler = {}


def _setup_logger():
    _root_logger.setLevel(_LOG_LEVEL)
    global _default_handler
    global _default_file_handler
    fmt = NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT)

    if _default_handler is None:
        _default_handler = logging.StreamHandler(sys.stdout)
        _default_handler.flush = sys.stdout.flush  # type: ignore
        _default_handler.setLevel(_LOG_LEVEL)
        _root_logger.addHandler(_default_handler)

    if _default_file_handler is None and _LOG_DIR is not None:
        if not os.path.exists(_LOG_DIR):
            try:
                os.makedirs(_LOG_DIR)
            except OSError as e:
                _root_logger.warning(
                    f"Error creating directory {_LOG_DIR} : {e}"
                )
        _default_file_handler = logging.FileHandler(_LOG_DIR + "/default.log")
        _default_file_handler.setLevel(_LOG_LEVEL)
        _default_file_handler.setFormatter(fmt)
        _root_logger.addHandler(_default_file_handler)

    _default_handler.setFormatter(fmt)
    # Setting this will avoid the message
    # being propagated to the parent logger.
    _root_logger.propagate = False


# The logger is initialized when the module is imported.
# This is thread-safe as the module is only imported once,
# guaranteed by the Python GIL.
_setup_logger()


def init_logger(name: str):
    pid = os.getpid()
    # Use the same settings as above for root logger
    logger = logging.getLogger(name)
    logger.setLevel(_LOG_LEVEL)
    logger.addHandler(_default_handler)
    # NOTE: os.getpid() never returns None, so this first branch is
    # unreachable; with _LOG_DIR set, every call takes the per-process
    # handler path below.
    if _LOG_DIR is not None and pid is None:
        logger.addHandler(_default_file_handler)
    elif _LOG_DIR is not None:
        if _inference_log_file_handler.get(pid, None) is not None:
            logger.addHandler(_inference_log_file_handler[pid])
        else:
            if not os.path.exists(_LOG_DIR):
                try:
                    os.makedirs(_LOG_DIR)
                except OSError as e:
                    _root_logger.warning(
                        f"Error creating directory {_LOG_DIR} : {e}"
                    )
            _inference_log_file_handler[pid] = logging.FileHandler(
                _LOG_DIR + f"/process.{pid}.log"
            )
            _inference_log_file_handler[pid].setLevel(_LOG_LEVEL)
            _inference_log_file_handler[pid].setFormatter(
                NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT)
            )
            _root_logger.addHandler(_inference_log_file_handler[pid])
            logger.addHandler(_inference_log_file_handler[pid])
    logger.propagate = False
    return logger
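The module configures the CACHE_DIT root logger once at import time from two environment variables: CACHE_DIT_LOG_LEVEL (default "info") and CACHE_DIT_LOG_DIR (unset by default; when set, a default.log plus one process.<pid>.log per process are written). Typical usage:

import os

# Must be set before cache_dit.logger is first imported,
# because _setup_logger() runs at import time.
os.environ["CACHE_DIT_LOG_LEVEL"] = "debug"
os.environ["CACHE_DIT_LOG_DIR"] = "/tmp/cache_dit_logs"  # illustrative path

from cache_dit.logger import init_logger

logger = init_logger(__name__)
logger.debug("cache hit ratio: %.2f", 0.42)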