torch-weighttracker 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_weighttracker/__init__.py +5 -0
- torch_weighttracker/calculations/__init__.py +77 -0
- torch_weighttracker/calculations/base.py +54 -0
- torch_weighttracker/calculations/cached_calc.py +23 -0
- torch_weighttracker/calculations/calcs/__init__.py +68 -0
- torch_weighttracker/calculations/calcs/active_macs_pr_module.py +58 -0
- torch_weighttracker/calculations/calcs/active_units.py +31 -0
- torch_weighttracker/calculations/calcs/baseline_group_sizes.py +51 -0
- torch_weighttracker/calculations/calcs/baseline_macs_pr_module.py +100 -0
- torch_weighttracker/calculations/calcs/baseline_module_axes.py +102 -0
- torch_weighttracker/calculations/calcs/baseline_param_pr_unit_pr_group.py +144 -0
- torch_weighttracker/calculations/calcs/bitrate_pr_module.py +91 -0
- torch_weighttracker/calculations/calcs/group_change_effect.py +136 -0
- torch_weighttracker/calculations/calcs/group_sizes.py +49 -0
- torch_weighttracker/calculations/calcs/group_unit_param_change.py +179 -0
- torch_weighttracker/calculations/calcs/groups_to_units.py +79 -0
- torch_weighttracker/calculations/calcs/l2_norm_pr_unit.py +47 -0
- torch_weighttracker/calculations/calcs/param_pr_unit.py +64 -0
- torch_weighttracker/calculations/calcs/structured_unit_sum.py +31 -0
- torch_weighttracker/calculations/calcs/unit_active_mask.py +38 -0
- torch_weighttracker/calculations/calcs/unit_delta_to_module_axis.py +185 -0
- torch_weighttracker/calculations/calcs/units_to_group.py +64 -0
- torch_weighttracker/calculations/calcs/units_to_module_axis.py +118 -0
- torch_weighttracker/calculations/calculations.py +75 -0
- torch_weighttracker/calculations/context.py +56 -0
- torch_weighttracker/calculations/pipeline_calc.py +136 -0
- torch_weighttracker/calculations/reduction_calc.py +149 -0
- torch_weighttracker/calculations/registry.py +84 -0
- torch_weighttracker/calculations/spec.py +18 -0
- torch_weighttracker/calculations/static_calc.py +21 -0
- torch_weighttracker/canonical_units.py +448 -0
- torch_weighttracker/consumer_ignore.py +53 -0
- torch_weighttracker/extractors/codeq_bitrate_extractor.py +222 -0
- torch_weighttracker/extractors/extractor.py +106 -0
- torch_weighttracker/extractors/parameter_extractor.py +103 -0
- torch_weighttracker/operations/__init__.py +53 -0
- torch_weighttracker/operations/base.py +41 -0
- torch_weighttracker/operations/conv.py +29 -0
- torch_weighttracker/operations/generic.py +176 -0
- torch_weighttracker/operations/linear.py +24 -0
- torch_weighttracker/operations/mha.py +285 -0
- torch_weighttracker/operations/norm.py +39 -0
- torch_weighttracker/operations/resolver.py +44 -0
- torch_weighttracker/plans/mapping_plan.py +90 -0
- torch_weighttracker/plans/unit_weight_operation_plan.py +226 -0
- torch_weighttracker/py.typed +1 -0
- torch_weighttracker/reductions/builder.py +498 -0
- torch_weighttracker/reductions/compiler.py +92 -0
- torch_weighttracker/reductions/helpers.py +75 -0
- torch_weighttracker/reductions/ops.py +96 -0
- torch_weighttracker/regularizers/__init__.py +17 -0
- torch_weighttracker/regularizers/base.py +79 -0
- torch_weighttracker/regularizers/group_lasso.py +43 -0
- torch_weighttracker/regularizers/group_lasso_with_bitrate.py +9 -0
- torch_weighttracker/torch_pruning/__init__.py +75 -0
- torch_weighttracker/torch_pruning/_helpers.py +79 -0
- torch_weighttracker/torch_pruning/dependency/__init__.py +9 -0
- torch_weighttracker/torch_pruning/dependency/constants.py +10 -0
- torch_weighttracker/torch_pruning/dependency/dependency.py +109 -0
- torch_weighttracker/torch_pruning/dependency/graph.py +613 -0
- torch_weighttracker/torch_pruning/dependency/group.py +143 -0
- torch_weighttracker/torch_pruning/dependency/index_mapping.py +403 -0
- torch_weighttracker/torch_pruning/dependency/node.py +60 -0
- torch_weighttracker/torch_pruning/dependency/shape_infer.py +124 -0
- torch_weighttracker/torch_pruning/ops.py +346 -0
- torch_weighttracker/torch_pruning/pruner/__init__.py +53 -0
- torch_weighttracker/torch_pruning/pruner/function.py +574 -0
- torch_weighttracker/torch_pruning/utils/__init__.py +3 -0
- torch_weighttracker/torch_pruning/utils/utils.py +19 -0
- torch_weighttracker/trackers/__init__.py +11 -0
- torch_weighttracker/trackers/base.py +96 -0
- torch_weighttracker/trackers/structured_bops.py +122 -0
- torch_weighttracker/weight_tracker.py +593 -0
- torch_weighttracker-0.1.0.dist-info/METADATA +93 -0
- torch_weighttracker-0.1.0.dist-info/RECORD +78 -0
- torch_weighttracker-0.1.0.dist-info/WHEEL +5 -0
- torch_weighttracker-0.1.0.dist-info/licenses/LICENSE +22 -0
- torch_weighttracker-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,77 @@
|
|
|
1
|
+
from torch_weighttracker.calculations.base import (
|
|
2
|
+
BaseCalculation,
|
|
3
|
+
CalcType,
|
|
4
|
+
Calculation,
|
|
5
|
+
)
|
|
6
|
+
from torch_weighttracker.calculations.cached_calc import CachedCalculation
|
|
7
|
+
from torch_weighttracker.calculations.calculations import (
|
|
8
|
+
CALCULATION_SPECS,
|
|
9
|
+
ActiveMacsPrModuleCalc,
|
|
10
|
+
BaselineMacsPrModuleCalc,
|
|
11
|
+
BaselineModuleAxesCalc,
|
|
12
|
+
BaselineParamPrUnitPrGroup,
|
|
13
|
+
CalculationContext,
|
|
14
|
+
CalculationSpec,
|
|
15
|
+
InitialUnitCountPrGroup,
|
|
16
|
+
ParamPrUnit,
|
|
17
|
+
UnitActiveMaskCalc,
|
|
18
|
+
UnitPrGroup,
|
|
19
|
+
create_active_macs_pr_module_calc,
|
|
20
|
+
create_active_units_calc,
|
|
21
|
+
create_baseline_group_sizes_calc,
|
|
22
|
+
create_baseline_macs_pr_module_calc,
|
|
23
|
+
create_baseline_module_axes_calc,
|
|
24
|
+
create_baseline_param_pr_unit_pr_group_calc,
|
|
25
|
+
create_bitrate_pr_module_calc,
|
|
26
|
+
create_group_change_effect_calc,
|
|
27
|
+
create_group_unit_param_change_calc,
|
|
28
|
+
create_group_sizes_calc,
|
|
29
|
+
create_groups_to_units_calc,
|
|
30
|
+
create_l2_norm_pr_unit_calc,
|
|
31
|
+
create_param_pr_unit_calc,
|
|
32
|
+
create_structured_unit_sum_calc,
|
|
33
|
+
create_unit_active_mask_calc,
|
|
34
|
+
create_unit_delta_to_module_axis_calc,
|
|
35
|
+
create_units_to_group_calc,
|
|
36
|
+
create_units_to_module_axis_calc,
|
|
37
|
+
)
|
|
38
|
+
from torch_weighttracker.calculations.pipeline_calc import PipelineCalc
|
|
39
|
+
from torch_weighttracker.calculations.reduction_calc import ReductionCalc
|
|
40
|
+
|
|
41
|
+
# Public API of the calculations package: the calculation base classes, the
# context/spec types, CALCULATION_SPECS, concrete calculation classes, and one
# create_*_calc factory per calculation (all imported above).
__all__ = [
    "ActiveMacsPrModuleCalc",
    "BaseCalculation",
    "BaselineMacsPrModuleCalc",
    "BaselineModuleAxesCalc",
    "BaselineParamPrUnitPrGroup",
    "CALCULATION_SPECS",
    "CachedCalculation",
    "CalcType",
    "Calculation",
    "CalculationContext",
    "CalculationSpec",
    "InitialUnitCountPrGroup",
    "ParamPrUnit",
    "PipelineCalc",
    "ReductionCalc",
    "UnitActiveMaskCalc",
    "UnitPrGroup",
    "create_active_macs_pr_module_calc",
    "create_active_units_calc",
    "create_baseline_group_sizes_calc",
    "create_baseline_macs_pr_module_calc",
    "create_baseline_module_axes_calc",
    "create_baseline_param_pr_unit_pr_group_calc",
    "create_bitrate_pr_module_calc",
    "create_group_change_effect_calc",
    "create_group_unit_param_change_calc",
    "create_group_sizes_calc",
    "create_groups_to_units_calc",
    "create_l2_norm_pr_unit_calc",
    "create_param_pr_unit_calc",
    "create_structured_unit_sum_calc",
    "create_unit_active_mask_calc",
    "create_unit_delta_to_module_axis_calc",
    "create_units_to_group_calc",
    "create_units_to_module_axis_calc",
]
|
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Mapping
|
|
4
|
+
from enum import Enum
|
|
5
|
+
|
|
6
|
+
import torch
|
|
7
|
+
from torch import nn
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class CalcType(str, Enum):
    """Identifiers for every available calculation.

    Because this is a ``str`` Enum, each member compares equal to its string
    value, and ``CalcType("active_units")`` looks a member up by value. This is
    how ``BaseCalculation.calc`` accepts either form.
    """

    STRUCTURED_UNIT_SUM = "structured_unit_sum"
    ACTIVE_UNITS = "active_units"
    L2_NORM_PR_UNIT = "l2_norm_pr_unit"
    BITRATE_PR_MODULE = "bitrate_pr_module"
    UNITS_TO_MODULE_AXIS = "units_to_module_axis"
    UNIT_DELTA_TO_MODULE_AXIS = "unit_delta_to_module_axis"
    ACTIVE_MACS_PR_MODULE = "active_macs_pr_module"
    BASELINE_MACS_PR_MODULE = "baseline_macs_pr_module"
    BASELINE_MODULE_AXES = "baseline_module_axes"
    UNITS_TO_GROUP = "units_to_group"
    GROUPS_TO_UNITS = "groups_to_units"
    UNIT_ACTIVE_MASK = "unit_active_mask"
    GROUP_CHANGE_EFFECT = "group_change_effect"
    GROUP_UNIT_PARAM_CHANGE = "group_unit_param_change"
    BASELINE_PARAM_PR_UNIT_PR_GROUP = "baseline_param_pr_unit_pr_group"
    PARAM_PR_UNIT = "param_pr_unit"
    GROUP_SIZES = "group_sizes"
    # NOTE: the member name differs from its value. This identifies the initial
    # (baseline) unit count per group, implemented by InitialUnitCountPrGroup in
    # the baseline_group_sizes calculation module.
    INIT_UNIT_PR_GROUP_COUNT = "baseline_group_sizes"
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class BaseCalculation(nn.Module):
    """Base class for calculation modules that may depend on other calculations.

    Dependencies live in an ``nn.ModuleDict`` keyed by the ``CalcType`` member
    *name* (e.g. ``"UNIT_ACTIVE_MASK"``). They are therefore registered as
    submodules and move together with this module on ``.to()``.

    Every instance also registers a non-persistent scalar buffer,
    ``output_anchor``. Other code (e.g. ``CachedCalculation``) stores results
    in it.
    """

    # Concrete subclasses set this to the CalcType they implement.
    calculation_type: CalcType | None = None

    def __init__(
        self,
        dependencies: Mapping[CalcType | str, nn.Module] | None = None,
    ) -> None:
        """
        Args:
            dependencies: Mapping from calculation type to the module computing
                it. Keys may be ``CalcType`` members or their string values,
                matching what ``calc()`` accepts.

        Raises:
            ValueError: If a dependency key is not a valid ``CalcType`` value.
        """
        super().__init__()
        # Normalize keys through CalcType so plain string keys work too.
        # Previously ``.name`` was read directly and str keys raised AttributeError.
        self.dependencies = nn.ModuleDict(
            {}
            if dependencies is None
            else {
                CalcType(calc_type).name: module
                for calc_type, module in dependencies.items()
            }
        )
        # Zero-initialized rather than torch.empty(()): the buffer may be read
        # before anything has been written to it, and an uninitialized tensor
        # would then return arbitrary memory.
        self.register_buffer("output_anchor", torch.zeros(()), persistent=False)

    def calc(self, calc_type: CalcType | str) -> nn.Module:
        """Return the dependency module registered for ``calc_type``.

        Raises:
            ValueError: If ``calc_type`` is not a valid ``CalcType`` value.
            KeyError: If no dependency of that type was registered.
        """
        calc_type = CalcType(calc_type)
        return self.dependencies[calc_type.name]

    def compute(self, calc_type: CalcType | str, *args, **kwargs) -> torch.Tensor:
        """Call the dependency registered for ``calc_type`` with the given arguments."""
        return self.calc(calc_type)(*args, **kwargs)
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
# Short alias for BaseCalculation; concrete calculations subclass it
# (e.g. ``class ActiveMacsPrModuleCalc(Calculation)``).
Calculation = BaseCalculation
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from typing import Any
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
from torch import nn
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class CachedCalculation(nn.Module):
    """Wrap a calculation and serve its last stored result without recomputing.

    The cached value lives in the wrapped calculation's ``output_anchor``
    attribute (a non-persistent buffer on ``BaseCalculation`` subclasses):
    - ``refresh_cache`` recomputes the calculation and writes the result there.
    - Calling the wrapper returns whatever is currently stored there. Before
      the first refresh, that is the anchor's initial value.
    """

    def __init__(self, calculation: nn.Module) -> None:
        """
        Args:
            calculation: Module to wrap. It must expose ``output_anchor``.

        Raises:
            ValueError: If ``calculation`` has no ``output_anchor`` attribute.
        """
        super().__init__()
        self.calculation = calculation
        # Mirror the wrapped calculation's type tag (None when it has none).
        self.calculation_type = getattr(calculation, "calculation_type", None)

        has_anchor = hasattr(calculation, "output_anchor")
        if has_anchor:
            return
        raise ValueError(
            f"{calculation.__class__.__name__} must define output_anchor"
        )

    @torch.no_grad()
    def refresh_cache(self, *args: Any, **kwargs: Any) -> None:
        """Recompute the wrapped calculation (no autograd graph) and store the result."""
        fresh_value = self.calculation(*args, **kwargs)
        self.calculation.output_anchor = fresh_value

    def forward(self) -> torch.Tensor:
        """Return the currently stored value of the wrapped calculation."""
        return self.calculation.output_anchor
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
from torch_weighttracker.calculations.calcs.active_macs_pr_module import (
|
|
2
|
+
ActiveMacsPrModuleCalc,
|
|
3
|
+
create_active_macs_pr_module_calc,
|
|
4
|
+
)
|
|
5
|
+
from torch_weighttracker.calculations.calcs.active_units import create_active_units_calc
|
|
6
|
+
from torch_weighttracker.calculations.calcs.baseline_macs_pr_module import (
|
|
7
|
+
BaselineMacsPrModuleCalc,
|
|
8
|
+
create_baseline_macs_pr_module_calc,
|
|
9
|
+
)
|
|
10
|
+
from torch_weighttracker.calculations.calcs.baseline_group_sizes import (
|
|
11
|
+
InitialUnitCountPrGroup,
|
|
12
|
+
create_baseline_group_sizes_calc,
|
|
13
|
+
)
|
|
14
|
+
from torch_weighttracker.calculations.calcs.baseline_module_axes import (
|
|
15
|
+
BaselineModuleAxesCalc,
|
|
16
|
+
create_baseline_module_axes_calc,
|
|
17
|
+
)
|
|
18
|
+
from torch_weighttracker.calculations.calcs.baseline_param_pr_unit_pr_group import (
|
|
19
|
+
BaselineParamPrUnitPrGroup,
|
|
20
|
+
create_baseline_param_pr_unit_pr_group_calc,
|
|
21
|
+
)
|
|
22
|
+
from torch_weighttracker.calculations.calcs.bitrate_pr_module import create_bitrate_pr_module_calc
|
|
23
|
+
from torch_weighttracker.calculations.calcs.group_change_effect import create_group_change_effect_calc
|
|
24
|
+
from torch_weighttracker.calculations.calcs.group_unit_param_change import create_group_unit_param_change_calc
|
|
25
|
+
from torch_weighttracker.calculations.calcs.group_sizes import UnitPrGroup, create_group_sizes_calc
|
|
26
|
+
from torch_weighttracker.calculations.calcs.groups_to_units import create_groups_to_units_calc
|
|
27
|
+
from torch_weighttracker.calculations.calcs.l2_norm_pr_unit import (
|
|
28
|
+
L2NormPrUnit,
|
|
29
|
+
create_l2_norm_pr_unit_calc,
|
|
30
|
+
)
|
|
31
|
+
from torch_weighttracker.calculations.calcs.param_pr_unit import ParamPrUnit, create_param_pr_unit_calc
|
|
32
|
+
from torch_weighttracker.calculations.calcs.structured_unit_sum import create_structured_unit_sum_calc
|
|
33
|
+
from torch_weighttracker.calculations.calcs.unit_active_mask import UnitActiveMaskCalc, create_unit_active_mask_calc
|
|
34
|
+
from torch_weighttracker.calculations.calcs.unit_delta_to_module_axis import (
|
|
35
|
+
create_unit_delta_to_module_axis_calc,
|
|
36
|
+
)
|
|
37
|
+
from torch_weighttracker.calculations.calcs.units_to_group import create_units_to_group_calc
|
|
38
|
+
from torch_weighttracker.calculations.calcs.units_to_module_axis import create_units_to_module_axis_calc
|
|
39
|
+
|
|
40
|
+
# Public API of the calcs subpackage: concrete calculation classes plus one
# create_*_calc factory per calculation module imported above.
# NOTE(review): L2NormPrUnit is exported here but not re-exported by the
# parent calculations package — confirm whether that is intentional.
__all__ = [
    "ActiveMacsPrModuleCalc",
    "InitialUnitCountPrGroup",
    "L2NormPrUnit",
    "BaselineMacsPrModuleCalc",
    "BaselineModuleAxesCalc",
    "BaselineParamPrUnitPrGroup",
    "ParamPrUnit",
    "UnitPrGroup",
    "UnitActiveMaskCalc",
    "create_active_macs_pr_module_calc",
    "create_active_units_calc",
    "create_baseline_group_sizes_calc",
    "create_baseline_macs_pr_module_calc",
    "create_baseline_module_axes_calc",
    "create_baseline_param_pr_unit_pr_group_calc",
    "create_bitrate_pr_module_calc",
    "create_group_change_effect_calc",
    "create_group_unit_param_change_calc",
    "create_group_sizes_calc",
    "create_groups_to_units_calc",
    "create_l2_norm_pr_unit_calc",
    "create_param_pr_unit_calc",
    "create_structured_unit_sum_calc",
    "create_unit_active_mask_calc",
    "create_unit_delta_to_module_axis_calc",
    "create_units_to_group_calc",
    "create_units_to_module_axis_calc",
]
|
|
@@ -0,0 +1,58 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Mapping
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import torch.nn as nn
|
|
7
|
+
|
|
8
|
+
from torch_weighttracker.calculations.base import CalcType, Calculation
|
|
9
|
+
from torch_weighttracker.calculations.context import CalculationContext
|
|
10
|
+
from torch_weighttracker.calculations.spec import CalculationSpec
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class ActiveMacsPrModuleCalc(Calculation):
    """
    Returns the active runtime MAC count for each weighted module.

    Each module's baseline MAC count is multiplied by the product, over its two
    axes (input, output), of ``active axis size / baseline axis size``. The
    active axis sizes are the baseline sizes plus the per-axis delta derived
    from the current unit-active mask.

    NOTE(review): A baseline axis of size 0 would yield nan/inf here. For
    modules whose two axes are the same feature width (norm layers, Embedding),
    the ratio enters the product twice if the delta calculation updates both
    axes — confirm that is intended.

    Output: 1D tensor with length `len(weighted_modules)`.
    Input: none.
    """

    calculation_type = CalcType.ACTIVE_MACS_PR_MODULE

    def __init__(self, dependencies: Mapping[CalcType, nn.Module]) -> None:
        super().__init__(dependencies)

    def forward(self) -> torch.Tensor:
        mask = self.compute(CalcType.UNIT_ACTIVE_MASK)
        axes_before = self.compute(CalcType.BASELINE_MODULE_AXES)
        macs_before = self.compute(CalcType.BASELINE_MACS_PR_MODULE)
        delta = self.compute(CalcType.UNIT_DELTA_TO_MODULE_AXIS, mask)

        axes_now = axes_before + delta.view_as(axes_before)
        scale = torch.prod(axes_now / axes_before, dim=1)
        return macs_before * scale
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def create_active_macs_pr_module_calc(
    ctx: CalculationContext,
    *,
    dependencies: Mapping[CalcType, nn.Module],
) -> ActiveMacsPrModuleCalc:
    """Build an ActiveMacsPrModuleCalc from its already-resolved dependencies.

    Args:
        ctx: Calculation context. It is not used by this factory.
        dependencies: Modules for UNIT_ACTIVE_MASK, UNIT_DELTA_TO_MODULE_AXIS,
            BASELINE_MACS_PR_MODULE and BASELINE_MODULE_AXES (as listed in
            this module's CALCULATION_SPEC).
    """
    calc = ActiveMacsPrModuleCalc(dependencies=dependencies)
    return calc
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _build_active_macs_pr_module_calc(ctx, deps):
    # Adapter from the spec's create(ctx, deps) hook to the keyword-only factory.
    return create_active_macs_pr_module_calc(ctx, dependencies=deps)


CALCULATION_SPEC = CalculationSpec(
    calculation_type=CalcType.ACTIVE_MACS_PR_MODULE,
    required_calculations=(
        CalcType.UNIT_ACTIVE_MASK,
        CalcType.UNIT_DELTA_TO_MODULE_AXIS,
        CalcType.BASELINE_MACS_PR_MODULE,
        CalcType.BASELINE_MODULE_AXES,
    ),
    create=_build_active_macs_pr_module_calc,
)
|
|
@@ -0,0 +1,31 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Iterable
|
|
4
|
+
|
|
5
|
+
from torch_weighttracker.calculations.base import CalcType
|
|
6
|
+
from torch_weighttracker.calculations.reduction_calc import ReductionCalc
|
|
7
|
+
from torch_weighttracker.calculations.spec import CalculationSpec
|
|
8
|
+
from torch_weighttracker.canonical_units import CanonicalUnitGroup
|
|
9
|
+
from torch_weighttracker.operations import WeightOperationType
|
|
10
|
+
from torch_weighttracker.plans.unit_weight_operation_plan import create_group_member_plan
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def create_active_units_calc(
    groups: Iterable[CanonicalUnitGroup],
) -> ReductionCalc:
    """
    Returns the active parameter count for each canonical unit.

    Builds a group-member plan over ``groups`` for the ACTIVE weight operation
    and wraps it in a ReductionCalc tagged as ``CalcType.ACTIVE_UNITS``.

    Output: 1D tensor with length equal to the total canonical unit count.
    Input: none.
    """
    member_plan = create_group_member_plan(groups, WeightOperationType.ACTIVE)
    return ReductionCalc(member_plan, calculation_type=CalcType.ACTIVE_UNITS)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def _build_active_units_calc(ctx, deps):
    # Needs only the context's canonical groups; deps is unused.
    return create_active_units_calc(ctx.canonical_groups)


CALCULATION_SPEC = CalculationSpec(
    calculation_type=CalcType.ACTIVE_UNITS,
    create=_build_active_units_calc,
)
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Iterable
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
from torch_weighttracker.calculations.base import CalcType
|
|
8
|
+
from torch_weighttracker.calculations.context import calculation_device, calculation_dtype
|
|
9
|
+
from torch_weighttracker.calculations.spec import CalculationSpec
|
|
10
|
+
from torch_weighttracker.calculations.static_calc import StaticCalc
|
|
11
|
+
from torch_weighttracker.canonical_units import CanonicalUnitGroup
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class InitialUnitCountPrGroup(StaticCalc):
    """
    Returns the initial count of canonical units for each canonical group.

    The counts are supplied once at construction (see
    ``create_baseline_group_sizes_calc``) and handed to StaticCalc.

    Output: 1D tensor with length `len(canonical_groups)`.
    Input: none.
    """

    calculation_type = CalcType.INIT_UNIT_PR_GROUP_COUNT

    def __init__(self, initial_unit_count_pr_group: torch.Tensor) -> None:
        super().__init__(initial_unit_count_pr_group)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def create_baseline_group_sizes_calc(
    groups: Iterable[CanonicalUnitGroup],
    *,
    device: torch.device | str,
    dtype: torch.dtype,
) -> InitialUnitCountPrGroup:
    """Build the static per-group unit-count calculation.

    Args:
        groups: Canonical unit groups. Each contributes its ``length``.
        device: Device for the resulting tensor.
        dtype: Dtype for the resulting tensor.

    Returns:
        An InitialUnitCountPrGroup holding one entry per group, in iteration order.
    """
    counts = torch.tensor(
        [grp.length for grp in groups],
        device=device,
        dtype=dtype,
    )
    return InitialUnitCountPrGroup(counts)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
def _build_baseline_group_sizes_calc(ctx, deps):
    # Materialize group lengths on the context's calculation device/dtype; deps is unused.
    return create_baseline_group_sizes_calc(
        ctx.canonical_groups,
        device=calculation_device(ctx),
        dtype=calculation_dtype(ctx),
    )


CALCULATION_SPEC = CalculationSpec(
    calculation_type=CalcType.INIT_UNIT_PR_GROUP_COUNT,
    cache_constant=True,
    create=_build_baseline_group_sizes_calc,
)
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from torch_weighttracker.calculations.base import CalcType
|
|
6
|
+
from torch_weighttracker.calculations.context import (
|
|
7
|
+
CalculationContext,
|
|
8
|
+
calculation_device,
|
|
9
|
+
calculation_dtype,
|
|
10
|
+
)
|
|
11
|
+
from torch_weighttracker.calculations.spec import CalculationSpec
|
|
12
|
+
from torch_weighttracker.calculations.static_calc import StaticCalc
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class BaselineMacsPrModuleCalc(StaticCalc):
    """
    Returns the initial runtime MAC count for each weighted module.

    The per-module values are measured by ``create_baseline_macs_pr_module_calc``
    and handed to StaticCalc as a tensor.

    Output: 1D tensor with length `len(weighted_modules)`.
    Input: none.
    """

    calculation_type = CalcType.BASELINE_MACS_PR_MODULE

    def __init__(self, baseline_macs: torch.Tensor) -> None:
        super().__init__(baseline_macs)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def create_baseline_macs_pr_module_calc(
    ctx: CalculationContext,
    *,
    device: torch.device | str,
    dtype: torch.dtype,
) -> BaselineMacsPrModuleCalc:
    """Measure per-module MACs with fvcore and freeze them into a static calculation.

    Args:
        ctx: Context providing ``model``, ``example_inputs`` and ``weighted_modules``.
        device: Device for the resulting tensor.
        dtype: Dtype for the resulting tensor.

    Raises:
        ValueError: If ``ctx.example_inputs`` is None, or fvcore did not report
            a weighted module.
        RuntimeError: If fvcore is not installed.
    """
    if ctx.example_inputs is None:
        raise ValueError(
            "BASELINE_MACS_PR_MODULE requires example_inputs so fvcore can "
            "compute runtime MACs."
        )

    # fvcore is an optional dependency, so it is imported lazily.
    try:
        from fvcore.nn import FlopCountAnalysis
    except ImportError as error:
        raise RuntimeError(
            "BASELINE_MACS_PR_MODULE requires fvcore. Install fvcore or disable "
            "structured BOPs MAC accounting."
        ) from error

    per_module_macs = torch.tensor(
        _fvcore_macs_by_weighted_module(ctx, FlopCountAnalysis),
        dtype=dtype,
        device=device,
    )
    return BaselineMacsPrModuleCalc(per_module_macs)
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def _fvcore_macs_by_weighted_module(
    ctx: CalculationContext,
    flop_count_analysis,
) -> list[float]:
    """Return the fvcore-reported MAC count of every ``ctx.weighted_modules`` entry.

    Args:
        ctx: Context providing ``model``, ``example_inputs`` and ``weighted_modules``.
        flop_count_analysis: The ``fvcore.nn.FlopCountAnalysis`` class. It is
            injected so the fvcore import stays optional, and is called as
            ``flop_count_analysis(model, example_inputs)``.

    Returns:
        One float per weighted module, in ``ctx.weighted_modules`` order.

    Raises:
        ValueError: If any weighted module has no qualified name in
            ``ctx.model``, is absent from fvcore's per-module report, or was
            not called during the trace.
    """
    names_by_module = {module: name for name, module in ctx.model.named_modules()}

    # fvcore counts by actually running (tracing) the model. In training mode
    # that updates buffers in place (e.g. BatchNorm running statistics and
    # num_batches_tracked). Snapshot every buffer and restore it afterwards so
    # measuring the baseline does not alter the model being tracked.
    # Lazy (uninitialized) buffers cannot be cloned and are skipped.
    buffer_snapshots = [
        (buffer, buffer.detach().clone())
        for buffer in ctx.model.buffers()
        if not torch.nn.parameter.is_lazy(buffer)
    ]
    try:
        analysis = flop_count_analysis(ctx.model, ctx.example_inputs)
        # Silence fvcore warnings. The reassignment relies on these toggles
        # returning the analysis object.
        if hasattr(analysis, "unsupported_ops_warnings"):
            analysis = analysis.unsupported_ops_warnings(False)
        if hasattr(analysis, "uncalled_modules_warnings"):
            analysis = analysis.uncalled_modules_warnings(False)

        # fvcore runs its trace lazily on the first query, so every query that
        # may trigger it stays inside this try block.
        by_module = analysis.by_module()
        uncalled = (
            set(analysis.uncalled_modules())
            if hasattr(analysis, "uncalled_modules")
            else set()
        )
    finally:
        with torch.no_grad():
            for buffer, snapshot in buffer_snapshots:
                buffer.copy_(snapshot)

    values: list[float] = []
    missing: list[str] = []
    for module in ctx.weighted_modules:
        name = names_by_module.get(module)
        if name is None or name not in by_module or name in uncalled:
            missing.append("<unnamed>" if name is None else name)
            continue
        values.append(float(by_module[name]))

    if missing:
        raise ValueError(
            "fvcore did not report MACs for weighted modules: "
            + ", ".join(missing)
        )

    return values
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def _build_baseline_macs_pr_module_calc(ctx, deps):
    # Measure MACs on the context's calculation device/dtype; deps is unused.
    return create_baseline_macs_pr_module_calc(
        ctx,
        device=calculation_device(ctx),
        dtype=calculation_dtype(ctx),
    )


CALCULATION_SPEC = CalculationSpec(
    calculation_type=CalcType.BASELINE_MACS_PR_MODULE,
    requires_groups=False,
    cache_constant=True,
    create=_build_baseline_macs_pr_module_calc,
)
|
|
@@ -0,0 +1,102 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from collections.abc import Iterable
|
|
4
|
+
|
|
5
|
+
import torch
|
|
6
|
+
import torch.nn as nn
|
|
7
|
+
|
|
8
|
+
from torch_weighttracker.calculations.base import CalcType
|
|
9
|
+
from torch_weighttracker.calculations.context import calculation_device, calculation_dtype
|
|
10
|
+
from torch_weighttracker.calculations.spec import CalculationSpec
|
|
11
|
+
from torch_weighttracker.calculations.static_calc import StaticCalc
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class BaselineModuleAxesCalc(StaticCalc):
    """
    Returns the initial input-axis and output-axis size for each weighted module.

    Column 0 holds the input-axis size and column 1 the output-axis size, as
    produced by ``create_baseline_module_axes_calc``.

    Output: 2D tensor with shape `(len(weighted_modules), 2)`.
    Input: none.
    """

    calculation_type = CalcType.BASELINE_MODULE_AXES

    def __init__(self, baseline_axes: torch.Tensor) -> None:
        super().__init__(baseline_axes)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def create_baseline_module_axes_calc(
    modules: Iterable[nn.Module],
    *,
    device: torch.device | str,
    dtype: torch.dtype,
) -> BaselineModuleAxesCalc:
    """Build the static (input axis, output axis) size table for ``modules``.

    Args:
        modules: Weighted modules, in the order their rows should appear.
        device: Device for the resulting tensor.
        dtype: Dtype for the resulting tensor.

    Returns:
        A BaselineModuleAxesCalc holding a ``(len(modules), 2)`` tensor. The
        shape is kept even when ``modules`` is empty.

    Raises:
        ValueError: Propagated from ``_module_axis_sizes`` for unsupported modules.
    """
    rows = [_module_axis_sizes(module) for module in modules]
    # torch.tensor([]) is 1D with shape (0,). reshape(-1, 2) keeps the
    # documented 2D shape in that case, so consumers that reduce over dim=1
    # (e.g. ActiveMacsPrModuleCalc) do not fail on an empty module list.
    # For non-empty input the data is unchanged.
    baseline_axes = torch.tensor(rows, dtype=dtype, device=device).reshape(-1, 2)
    return BaselineModuleAxesCalc(baseline_axes)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _module_axis_sizes(module: nn.Module) -> tuple[float, float]:
    """Return ``(input_axis_size, output_axis_size)`` of a weighted module as floats.

    Conv2d and Linear report distinct input and output widths. Normalization
    layers (BatchNorm, LayerNorm, GroupNorm, InstanceNorm) and Embedding report
    their single feature width on both axes.

    Raises:
        ValueError: For nn.MultiheadAttention, grouped Conv2d (via
            ``_conv_axis_sizes``), an empty LayerNorm shape, or any other
            module type.
    """
    if isinstance(module, nn.Conv2d):
        return _conv_axis_sizes(module)
    if isinstance(module, nn.Linear):
        return float(module.in_features), float(module.out_features)
    if isinstance(module, nn.MultiheadAttention):
        raise ValueError(
            "ACTIVE_MACS_PR_MODULE does not support nn.MultiheadAttention parent "
            "modules in V1. Use projection Linear modules or add explicit MHA "
            "operation terms."
        )

    # Single-width modules: the same feature width is reported on both axes.
    if isinstance(module, nn.modules.batchnorm._BatchNorm):
        width = module.num_features
    elif isinstance(module, nn.LayerNorm):
        width = _layernorm_feature_dim(module)
    elif isinstance(module, nn.GroupNorm):
        width = module.num_channels
    elif isinstance(module, nn.modules.instancenorm._InstanceNorm):
        width = module.num_features
    elif isinstance(module, nn.Embedding):
        width = module.embedding_dim
    else:
        raise ValueError(
            "Baseline module-axis sizes are not implemented for "
            f"{module.__class__.__name__}."
        )
    return float(width), float(width)
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def _conv_axis_sizes(module: nn.Conv2d) -> tuple[float, float]:
    """Return ``(in_channels, out_channels)`` of an ungrouped Conv2d as floats.

    Raises:
        ValueError: If the convolution is grouped (``groups != 1``).
    """
    if module.groups == 1:
        return float(module.in_channels), float(module.out_channels)
    raise ValueError(
        "ACTIVE_MACS_PR_MODULE currently supports only Conv2d(groups=1)."
    )
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def _layernorm_feature_dim(module: nn.LayerNorm) -> int:
    """Return the size of the last (innermost) normalized dimension of a LayerNorm.

    Raises:
        ValueError: If ``normalized_shape`` is empty.
    """
    dims = tuple(module.normalized_shape)
    if not dims:
        raise ValueError("LayerNorm normalized_shape must not be empty.")
    *_, innermost = dims
    return int(innermost)
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def _build_baseline_module_axes_calc(ctx, deps):
    # Axis sizes come from the context's weighted modules; deps is unused.
    return create_baseline_module_axes_calc(
        ctx.weighted_modules,
        device=calculation_device(ctx),
        dtype=calculation_dtype(ctx),
    )


CALCULATION_SPEC = CalculationSpec(
    calculation_type=CalcType.BASELINE_MODULE_AXES,
    requires_groups=False,
    cache_constant=True,
    create=_build_baseline_module_axes_calc,
)
|