accelforge 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- accelforge/__init__.py +21 -0
- accelforge/_accelerated_imports.py +16 -0
- accelforge/_deprecate/_simanneal/evalmapping.py +271 -0
- accelforge/_deprecate/_simanneal/mapspaceglobals.py +298 -0
- accelforge/_deprecate/_simanneal/simanneal.py +666 -0
- accelforge/_deprecate/_simanneal/tracking.py +105 -0
- accelforge/_deprecate/_simanneal/wrappers.py +218 -0
- accelforge/_deprecate/_simanneal2/__init__.py +7 -0
- accelforge/_deprecate/_simanneal2/simanneal.py +493 -0
- accelforge/_deprecate/_simanneal2/tracking.py +116 -0
- accelforge/_deprecate/compatibility_util.py +181 -0
- accelforge/_deprecate/layerdeduplication/__init__.py +2 -0
- accelforge/_deprecate/layerdeduplication/group_similar_einsums.py +160 -0
- accelforge/_deprecate/layerdeduplication/grouped_einsums.py +84 -0
- accelforge/_deprecate/mapping_filter_tags/__init__.py +2 -0
- accelforge/_deprecate/mapping_filter_tags/ffmt.py +212 -0
- accelforge/_deprecate/mapping_filter_tags/onesplit.py +24 -0
- accelforge/_deprecate/mapping_filter_tags/util.py +24 -0
- accelforge/_deprecate/tags.py +69 -0
- accelforge/_deprecate/viz/__init__.py +0 -0
- accelforge/_deprecate/viz/interactive.py +159 -0
- accelforge/_deprecate/viz/reservationtree.py +307 -0
- accelforge/_deprecate/viz/ski_slope.py +88 -0
- accelforge/_version.py +15 -0
- accelforge/examples.py +39 -0
- accelforge/frontend/__init__.py +10 -0
- accelforge/frontend/_binding.py +129 -0
- accelforge/frontend/_workload_isl/__init__.py +2 -0
- accelforge/frontend/_workload_isl/_isl.py +149 -0
- accelforge/frontend/_workload_isl/_symbolic.py +141 -0
- accelforge/frontend/arch copy.py +1544 -0
- accelforge/frontend/arch.py +1642 -0
- accelforge/frontend/config.py +63 -0
- accelforge/frontend/mapper/__init__.py +5 -0
- accelforge/frontend/mapper/ffm.py +126 -0
- accelforge/frontend/mapper/mapper.py +7 -0
- accelforge/frontend/mapper/metrics.py +30 -0
- accelforge/frontend/mapping/__init__.py +1 -0
- accelforge/frontend/mapping/mapping.py +1736 -0
- accelforge/frontend/model.py +14 -0
- accelforge/frontend/renames.py +150 -0
- accelforge/frontend/spec copy.py +230 -0
- accelforge/frontend/spec.py +301 -0
- accelforge/frontend/variables.py +12 -0
- accelforge/frontend/workload.py +952 -0
- accelforge/mapper/FFM/__init__.py +9 -0
- accelforge/mapper/FFM/_join_pmappings/__init__.py +0 -0
- accelforge/mapper/FFM/_join_pmappings/compatibility.py +653 -0
- accelforge/mapper/FFM/_join_pmappings/compress_pmappings.py +140 -0
- accelforge/mapper/FFM/_join_pmappings/join_pmappings.py +703 -0
- accelforge/mapper/FFM/_join_pmappings/pmapping_dataframe.py +901 -0
- accelforge/mapper/FFM/_join_pmappings/pmapping_group.py +337 -0
- accelforge/mapper/FFM/_make_pmappings/contraints/__init__.py +0 -0
- accelforge/mapper/FFM/_make_pmappings/contraints/constraints.py +360 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/__init__.py +1 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_loops.py +373 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_pmapping_templates.py +463 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_reservations.py +95 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storage_order.py +382 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmapping_templates/make_storages.py +155 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings.py +411 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/__init__.py +1 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_pmappings_from_templates.py +407 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/make_tile_shapes.py +1681 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/run_model.py +170 -0
- accelforge/mapper/FFM/_make_pmappings/make_pmappings_from_templates/symbol_relations.py +174 -0
- accelforge/mapper/FFM/_make_pmappings/pmapper_job.py +282 -0
- accelforge/mapper/FFM/_pareto_df/df_convention.py +273 -0
- accelforge/mapper/FFM/_pareto_df/pareto copy.py +836 -0
- accelforge/mapper/FFM/_pareto_df/pareto.py +508 -0
- accelforge/mapper/FFM/data.py +61 -0
- accelforge/mapper/FFM/main copy.py +236 -0
- accelforge/mapper/FFM/main.py +208 -0
- accelforge/mapper/FFM/mappings.py +510 -0
- accelforge/mapper/FFM/pmappings.py +310 -0
- accelforge/mapper/__init__.py +4 -0
- accelforge/mapper.py +0 -0
- accelforge/model/__init__.py +1 -0
- accelforge/model/_looptree/__init__.py +0 -0
- accelforge/model/_looptree/accesses.py +335 -0
- accelforge/model/_looptree/capacity/__init__.py +1 -0
- accelforge/model/_looptree/capacity/aggregators.py +36 -0
- accelforge/model/_looptree/capacity/capacity.py +47 -0
- accelforge/model/_looptree/energy.py +150 -0
- accelforge/model/_looptree/equivalent_ranks.py +29 -0
- accelforge/model/_looptree/latency/__init__.py +1 -0
- accelforge/model/_looptree/latency/latency.py +98 -0
- accelforge/model/_looptree/latency/memory.py +120 -0
- accelforge/model/_looptree/latency/processors.py +92 -0
- accelforge/model/_looptree/mapping_utilities.py +71 -0
- accelforge/model/_looptree/reuse/__init__.py +4 -0
- accelforge/model/_looptree/reuse/isl/__init__.py +1 -0
- accelforge/model/_looptree/reuse/isl/des.py +59 -0
- accelforge/model/_looptree/reuse/isl/isl_functions.py +374 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/__init__.py +4 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/analyze_mapping.py +297 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/skews_from_mapping.py +236 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/tiling.py +685 -0
- accelforge/model/_looptree/reuse/isl/mapping_to_isl/types.py +188 -0
- accelforge/model/_looptree/reuse/isl/spatial.py +260 -0
- accelforge/model/_looptree/reuse/isl/temporal.py +182 -0
- accelforge/model/_looptree/reuse/symbolic/__init__.py +1 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic copy 2.py +1346 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic copy.py +1408 -0
- accelforge/model/_looptree/reuse/symbolic/symbolic.py +1396 -0
- accelforge/model/_looptree/run.py +122 -0
- accelforge/model/_looptree/types.py +26 -0
- accelforge/model/_looptree/visualization/__init__.py +0 -0
- accelforge/model/_looptree/visualization/occupancy.py +11 -0
- accelforge/model/main.py +222 -0
- accelforge/plotting/__init__.py +2 -0
- accelforge/plotting/mappings.py +219 -0
- accelforge/plotting/specs.py +57 -0
- accelforge/util/__init__.py +4 -0
- accelforge/util/_base_analysis_types.py +24 -0
- accelforge/util/_basetypes.py +1089 -0
- accelforge/util/_frozenset.py +36 -0
- accelforge/util/_isl.py +29 -0
- accelforge/util/_itertools.py +14 -0
- accelforge/util/_mathfuncs.py +57 -0
- accelforge/util/_parse_expressions.py +339 -0
- accelforge/util/_picklecache.py +32 -0
- accelforge/util/_setexpressions.py +268 -0
- accelforge/util/_sympy/__init__.py +0 -0
- accelforge/util/_sympy/broadcast_max.py +18 -0
- accelforge/util/_visualization.py +112 -0
- accelforge/util/_yaml.py +579 -0
- accelforge/util/parallel.py +193 -0
- accelforge-0.0.1.dist-info/METADATA +64 -0
- accelforge-0.0.1.dist-info/RECORD +258 -0
- accelforge-0.0.1.dist-info/WHEEL +5 -0
- accelforge-0.0.1.dist-info/licenses/LICENSE +19 -0
- accelforge-0.0.1.dist-info/top_level.txt +5 -0
- docs/_build/html/_sources/fastfusion.frontend.mapper.rst.txt +37 -0
- docs/_build/html/_sources/fastfusion.frontend.rst.txt +70 -0
- docs/_build/html/_sources/fastfusion.frontend.workload.rst.txt +21 -0
- docs/_build/html/_sources/fastfusion.mapper.FFM.rst.txt +37 -0
- docs/_build/html/_sources/fastfusion.mapper.rst.txt +18 -0
- docs/_build/html/_sources/fastfusion.rst.txt +20 -0
- docs/_build/html/_sources/fastfusion.util.rst.txt +21 -0
- docs/_build/html/_sources/index.rst.txt +87 -0
- docs/_build/html/_sources/modules.rst.txt +7 -0
- docs/_build/html/_sources/notes/citation.rst.txt +45 -0
- docs/_build/html/_sources/notes/definitions.rst.txt +43 -0
- docs/_build/html/_sources/notes/faqs.rst.txt +39 -0
- docs/_build/html/_sources/notes/modeling/accelerator_energy_latency.rst.txt +72 -0
- docs/_build/html/_sources/notes/modeling/component_energy_area.rst.txt +96 -0
- docs/_build/html/_sources/notes/modeling/mapping.rst.txt +100 -0
- docs/_build/html/_sources/notes/modeling.rst.txt +33 -0
- docs/_build/html/_sources/notes/parsing/arithmetic_parsing.rst.txt +136 -0
- docs/_build/html/_sources/notes/parsing/setexpressions.rst.txt +63 -0
- docs/_build/html/_sources/notes/parsing/yaml_parsing.rst.txt +176 -0
- docs/_build/html/_sources/notes/quickstart_and_installation.rst.txt +9 -0
- docs/_build/html/_sources/notes/spec/architecture.rst.txt +133 -0
- docs/_build/html/_sources/notes/spec/mapping.rst.txt +12 -0
- docs/_build/html/_sources/notes/spec/workload.rst.txt +83 -0
- docs/_build/html/_sources/notes/spec.rst.txt +36 -0
- docs/source/_ext/include_attrs.py +213 -0
- docs/source/_ext/include_docstring.py +364 -0
- docs/source/_ext/include_functions.py +154 -0
- docs/source/_ext/include_notebook.py +131 -0
- docs/source/_ext/include_yaml.py +119 -0
- docs/source/_ext/inherited_attributes.py +222 -0
- docs/source/_ext/paths.py +4 -0
- docs/source/conf.py +79 -0
- examples/arches/compute_in_memory/_include.yaml +74 -0
- examples/arches/compute_in_memory/_include_functions.py +229 -0
- examples/arches/compute_in_memory/_load_spec.py +57 -0
- examples/arches/compute_in_memory/components/c2c_multiplier.py +181 -0
- examples/arches/compute_in_memory/components/dac_c2c_r2r.py +605 -0
- examples/arches/compute_in_memory/components/misc.py +195 -0
- examples/arches/compute_in_memory/components/util/bit_functions.py +51 -0
- examples/arches/compute_in_memory/components/zero_comparator.py +92 -0
- examples/arches/compute_in_memory/isaac.yaml +233 -0
- examples/arches/compute_in_memory/memory_cells/ecram_demo.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_example.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_isaac_isca_2016.yaml +64 -0
- examples/arches/compute_in_memory/memory_cells/rram_neurosim_default.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/rram_raella_isca_2023.yaml +70 -0
- examples/arches/compute_in_memory/memory_cells/rram_wan_nature_2022.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_colonnade_jssc_2021.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_example.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_jia_jssc_2020.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_sinangil_jssc_2021.yaml +63 -0
- examples/arches/compute_in_memory/memory_cells/sram_wang_vlsi_2022.yaml +63 -0
- examples/arches/compute_in_memory/wang_vlsi_2022.yaml +289 -0
- examples/arches/eyeriss.yaml +68 -0
- examples/arches/fanout_variations/at_glb.yaml +31 -0
- examples/arches/fanout_variations/at_glb_with_fanout_node.yaml +34 -0
- examples/arches/fanout_variations/at_mac.yaml +31 -0
- examples/arches/fanout_variations/at_mac_with_constraints.yaml +38 -0
- examples/arches/fanout_variations/at_mac_with_fanout_node.yaml +34 -0
- examples/arches/nvdla.yaml +47 -0
- examples/arches/simple.yaml +28 -0
- examples/arches/tpu_v4i.yaml +67 -0
- examples/mappings/unfused_matmuls_to_simple.yaml +33 -0
- examples/misc/component_annotated.yaml +33 -0
- examples/workloads/gpt3_6.7B.yaml +124 -0
- examples/workloads/matmuls.yaml +20 -0
- examples/workloads/mobilenet_28.yaml +81 -0
- examples/workloads/mobilenet_various_separate.yaml +106 -0
- examples/workloads/three_matmuls_annotated.yaml +59 -0
- notebooks/.ipynb_checkpoints/fastfusion_arch_study_michael-checkpoint.ipynb +359 -0
- notebooks/compute_in_memory/_scripts.py +339 -0
- notebooks/compute_in_memory/isaac.guide.ipynb +270 -0
- notebooks/compute_in_memory/wang_vlsi_2022.ipynb +602 -0
- notebooks/paths.py +4 -0
- notebooks/tutorials/.ipynb_checkpoints/1_FFM-checkpoint.ipynb +3110 -0
- notebooks/tutorials/FFM.ipynb +3498 -0
- notebooks/tutorials/_include.py +48 -0
- notebooks/tutorials/component_energy_area.ipynb +363 -0
- tests/Q_mapping.yaml +38 -0
- tests/__init__.py +0 -0
- tests/conv.mapping.yaml +27 -0
- tests/conv.workload.yaml +13 -0
- tests/conv_sym.mapping.yaml +43 -0
- tests/copy.mapping.yaml +35 -0
- tests/copy.workload.yaml +15 -0
- tests/distribuffers/__init__.py +0 -0
- tests/distribuffers/multicast/test_cases.yaml +482 -0
- tests/distribuffers/spec/binding/valid_bindings.yaml +97 -0
- tests/distribuffers/spec/distributed.yaml +100 -0
- tests/distribuffers/spec/logical_arch.yaml +32 -0
- tests/distribuffers/spec/physical_arch.yaml +69 -0
- tests/distribuffers/test_binding.py +48 -0
- tests/frontend/__init__.py +0 -0
- tests/frontend/test_mapping_viz.py +52 -0
- tests/mapper/__init__.py +0 -0
- tests/mapper/configs/conv1d/conv1d.mapping.yaml +31 -0
- tests/mapper/configs/conv1d/conv1d.workload.yaml +11 -0
- tests/mapper/configs/two_conv1d/two_conv1d.expected.yaml +38 -0
- tests/mapper/configs/two_conv1d/two_conv1d.mapping.yaml +54 -0
- tests/mapper/configs/two_conv1d/two_conv1d.workload.yaml +19 -0
- tests/mapper/test_mapping_to_isl.py +90 -0
- tests/mapper/test_spatial_reuse_analysis.py +67 -0
- tests/mapper/test_temporal_reuse_analysis.py +56 -0
- tests/mapper/util.py +58 -0
- tests/matmul.mapping.yaml +29 -0
- tests/matmul.workload.yaml +12 -0
- tests/matmul_spatial.mapping.yaml +44 -0
- tests/mha.renames.yaml +65 -0
- tests/mha.workload.yaml +67 -0
- tests/mha.yaml +59 -0
- tests/mha_full.workload.yaml +67 -0
- tests/mobilenet.workload.yaml +35 -0
- tests/mobilenet_long.workload.yaml +64 -0
- tests/pmappingcache.py +24 -0
- tests/processing_stage.arch.yaml +40 -0
- tests/snowcat.arch.yaml +36 -0
- tests/test_ffm_join_pmappings.py +106 -0
- tests/test_ffm_make_pmappings.py +82 -0
- tests/test_ffm_make_tile_shapes.py +49 -0
- tests/test_mapper.py +100 -0
- tests/test_model.py +37 -0
- tests/test_plotting.py +72 -0
- tests/test_processing_stage.py +46 -0
- tests/test_symbolic_model.py +248 -0
- tests/test_workload.py +141 -0
|
@@ -0,0 +1,170 @@
|
|
|
1
|
+
from sympy import Symbol
|
|
2
|
+
import accelforge.frontend.arch as arch
|
|
3
|
+
from accelforge.frontend.mapping import TensorHolder
|
|
4
|
+
from accelforge.mapper.FFM._make_pmappings.pmapper_job import Job
|
|
5
|
+
from accelforge.model._looptree.reuse import symbolic
|
|
6
|
+
from accelforge.model._looptree.reuse.symbolic import (
|
|
7
|
+
analyze_reuse_and_add_reservations_to_mapping,
|
|
8
|
+
)
|
|
9
|
+
from accelforge.model._looptree.energy import (
|
|
10
|
+
compute_energy_from_actions,
|
|
11
|
+
gather_actions,
|
|
12
|
+
)
|
|
13
|
+
from accelforge.model._looptree.latency.memory import component_latency
|
|
14
|
+
from accelforge.mapper.FFM._join_pmappings.pmapping_dataframe import (
|
|
15
|
+
nameloop2col,
|
|
16
|
+
tensor2col,
|
|
17
|
+
firstlatency2col,
|
|
18
|
+
action2col,
|
|
19
|
+
energy2col,
|
|
20
|
+
)
|
|
21
|
+
from accelforge.frontend.mapper.metrics import Metrics
|
|
22
|
+
import sympy
|
|
23
|
+
from numbers import Number
|
|
24
|
+
from accelforge.util._sympy.broadcast_max import Max
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def run_model(
    job: Job,
) -> tuple[list[Symbol], dict[str, float], dict[str, float], dict[str, float], dict]:
    """Evaluate a single pmapping and collect its metrics into column dicts.

    Runs symbolic reuse analysis on ``job``'s mapping, then derives latency,
    energy, action counts, and memory/spatial usage according to the metrics
    requested in ``job.metrics``.

    :param job: the pmapper job carrying the mapping, spec, metrics, and
        flattened architecture to evaluate.
    :return: a 5-tuple of
        ``(symbols, df, per_memory_usage_df, usage_df, tensor2mapping)`` where
        ``symbols`` are the free sympy symbols of the reuse analysis, ``df``
        maps metric column names to values, ``per_memory_usage_df`` maps
        per-memory usage columns to fractional occupancy, ``usage_df`` maps
        spatial-fanout usage columns to scaled usage, and ``tensor2mapping``
        is taken directly from the reuse result (type not visible here).
    :raises ValueError: if the latencies returned by ``component_latency``
        cannot be combined into a symbolic max.
    """
    pmapping = job.mapping
    spec = job.spec
    metrics = job.metrics
    is_copy_op = job.is_copy_operation
    workload = spec.workload

    # Column name -> value for the main metric dataframe row.
    df = {}

    # Symbolic reuse analysis; also mutates the mapping by adding reservations.
    reuse = analyze_reuse_and_add_reservations_to_mapping(job)

    # Per-component latency; overall latency is the (symbolic) max over them.
    latency = component_latency(reuse, job.flattened_arch, pmapping, spec)
    try:
        overall_latency = Max(*latency.values()) if latency else 0
    except Exception as e:
        # Surface which latency entry had an unusable type before reporting
        # the generic failure.
        for k, v in latency.items():
            if not isinstance(v, (Number, sympy.Symbol, sympy.Expr)):
                raise ValueError(
                    f"Invalid type for latency: {k}: {type(v)} {str(v).strip()}"
                )

        raise ValueError(
            f"Error calculating latency for {job.einsum_name}. Could not calculate "
            f"a symbolic max of the following latencies:\n\t"
            + "\n\t".join(
                [f"{k}: {type(v)} {str(v).strip()}" for k, v in latency.items()]
            )
        )
    memory_to_size = {}
    # Per-component proportion of instances that are not power gated.
    component_to_non_power_gated_porp = {}
    # Spatial-fanout usage columns.
    usage_df = {}

    # NOTE(review): this accumulator is deliberately NOT reset per node, so
    # each component's entry is the product of usages of all TensorHolders
    # seen so far (i.e. the alive fraction down the hierarchy) — confirm.
    non_power_gated_instances = 1
    for node in job.flattened_arch:
        if isinstance(node, arch.TensorHolder):
            if isinstance(node, arch.Memory):
                memory_to_size[node.name] = node.size

            # If there's no loops that use this spatial fanout, then the model won't output
            # any usage. We still want to reserve at least one spatial instance in this
            # case.
            used_fanout = reuse.fanout.get((node.name, job.einsum_name), {})
            for s in node.spatial:
                usage = used_fanout.get(s.name, 1) / s.fanout
                scaled_usage = usage * s.usage_scale
                usage_df[f"usage<SEP>spatial<SEP>{node.name}<SEP>{s.name}"] = scaled_usage
                non_power_gated_instances *= usage
            component_to_non_power_gated_porp[node.name] = non_power_gated_instances

    # Aggregate action counts and the energy they imply.
    actions = gather_actions(reuse, None, use_name=True)
    energy = compute_energy_from_actions(
        spec, actions, overall_latency, component_to_non_power_gated_porp
    )
    if symbolic.PRINT_FORMULAS:
        for k, v in energy.items():
            print(f"{k}: {v}")

    # A tensor's "backing" is the outermost TensorHolder in the mapping that
    # holds it; only tensors shared across einsums (fusable) are tracked.
    fusable_tensors = workload.tensor_names_used_in_multiple_einsums
    tensor_to_backing = {}
    for node in pmapping.nodes:
        if isinstance(node, TensorHolder):
            for tensor in node.tensors:
                if tensor not in tensor_to_backing and tensor in fusable_tensors:
                    tensor_to_backing[tensor] = node.component

    # memory name -> {n_loops_above -> summed occupancy}
    total_occupancy = {}
    compute_unit = pmapping.nodes[-1].component

    n_instances = workload.n_instances * workload.einsums[job.einsum_name].n_instances

    n_loop_options = set()
    for buffet, stats in reuse.buffet_stats.items():
        # Occupancy at the compute unit itself is not a memory occupancy.
        if buffet.level == compute_unit:
            continue

        occupancy = stats.max_occupancy

        if occupancy == 0:
            continue
        if stats.persistent:
            # Persistent data is resident for every instance simultaneously.
            occupancy *= n_instances

        # Record fractional occupancy of backed fusable tensors. For copy
        # operations every tensor at the backing level is counted.
        for tensor, backing in tensor_to_backing.items():
            if (is_copy_op or buffet.tensor == tensor) and buffet.level == backing:
                df[tensor2col(tensor)] = occupancy / memory_to_size[buffet.level]

        total_occupancy.setdefault(buffet.level, {}).setdefault(stats.n_loops_above, 0)
        total_occupancy[buffet.level][stats.n_loops_above] += occupancy
        n_loop_options.add(stats.n_loops_above)

    # Emit cumulative per-(memory, n_loops) occupancy columns.
    # NOTE(review): n_loop_options is a set, so the cumulative running_total
    # depends on iteration order — confirm whether this should be sorted, and
    # whether the column should also be emitted when n_loop is absent.
    for memory, occupancies in total_occupancy.items():
        if memory not in job.memories_track_all:
            continue
        running_total = 0
        for n_loop in n_loop_options:
            if n_loop in occupancies:
                running_total += occupancies[n_loop]
            df[nameloop2col(memory, n_loop)] = (
                running_total / memory_to_size[memory]
            )

    if metrics & Metrics.ACTIONS:
        # Verbose per-action breakdown, plus the energy of each action.
        detailed_actions = gather_actions(reuse, None, verbose=True, use_name=True)
        for key, count in detailed_actions.items():
            df[action2col(key)] = count.total * n_instances
        detailed_energy = compute_energy_from_actions(
            spec, detailed_actions, overall_latency, component_to_non_power_gated_porp
        )
        for key, energy_val in detailed_energy.items():
            df[energy2col(key)] = energy_val * n_instances
        for component, cur_latency in latency.items():
            df[f"latency<SEP>{component}"] = cur_latency * n_instances

    if metrics & Metrics.LATENCY:
        df["Total<SEP>latency"] = overall_latency * n_instances
        # df[f"latency<SEP>compute"] = comp_latency * n_instances
        # For first latency, we'll follow the convention of treating compute
        # as a component, similarly to memory (see below).
        for compute_level, stats in reuse.compute_stats.items():  # FIRST LATENCY
            for idx, max_first_latency in stats.max_first_latency.items():
                df[firstlatency2col(compute_level.level, idx)] = (
                    max_first_latency * n_instances
                )

    if metrics & Metrics.ENERGY:
        df["Total<SEP>energy"] = sum(energy.values()) * n_instances

    # Fractional usage of each tracked memory.
    # NOTE(review): when job.ignored_resources is None nothing is recorded
    # here — confirm that None means "ignore everything" and not "track all".
    per_memory_usage_df = {}
    for memory, occupancies in total_occupancy.items():
        if job.ignored_resources is not None and memory not in job.ignored_resources:
            key = f"usage<SEP>memory<SEP>{memory}"
            per_memory_usage_df[key] = (
                sum(occupancies.values()) / memory_to_size[memory]
            )

    return (
        reuse.symbols,
        df,
        per_memory_usage_df,
        usage_df,
        reuse.tensor2mapping,
    )
|
|
@@ -0,0 +1,174 @@
|
|
|
1
|
+
from collections import defaultdict
|
|
2
|
+
from numbers import Number
|
|
3
|
+
|
|
4
|
+
from sympy import Symbol
|
|
5
|
+
|
|
6
|
+
from accelforge.frontend.workload import Workload
|
|
7
|
+
from accelforge.frontend._workload_isl._symbolic import get_stride_and_halo
|
|
8
|
+
from accelforge.frontend.mapping import (
|
|
9
|
+
Loop,
|
|
10
|
+
Mapping,
|
|
11
|
+
)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class SymbolRelations:
    """Tracks the relationships between tile-shape symbols of a pmapping.

    Three relations are recorded:

    - ``what_tiles_symbol``: pairs ``(outer, inner)`` where the ``inner`` tile
      shape subdivides the ``outer`` tile (or full rank-variable shape).
    - ``tile_shape_and_initial``: pairs ``(tile_shape, initial_tile_shape)``
      for loops whose first tile differs from the steady-state stride.
    - ``delta_choices``: pairs ``(initial_tile_shape, choices)`` giving the
      candidate initial deltas for the loop's rank variable.
    """

    def __init__(self):
        # (outer tile shape, inner tile shape) pairs; either side may be a
        # concrete int (e.g. the full shape, or the innermost tile size 1).
        self.what_tiles_symbol: list[tuple[Symbol | int, Symbol | int]] = []
        # (steady-state tile shape, initial tile shape) pairs.
        self.tile_shape_and_initial: list[tuple[Symbol | int, Symbol | int]] = []
        # (initial tile shape symbol, possible integer deltas).
        self.delta_choices: list[tuple[Symbol, frozenset[int]]] = []
        # (symbol, lower bound, upper bound) triples; filled by make_bounds().
        self.bounds: tuple[tuple[Symbol, int, int], ...] = ()

    def make_bounds(self):
        """Populate ``bounds`` with ``(symbol, 1, max_size)`` for every symbol
        appearing in ``what_tiles_symbol``."""
        all_symbols = set(
            s for w in self.what_tiles_symbol for s in w if isinstance(s, Symbol)
        )
        self.bounds = tuple((s, 1, self.get_max_size(s)) for s in all_symbols)

    def is_stride(self, symbol: Symbol) -> bool:
        """Check if `symbol` is a stride.

        Returns True for symbols that appear as a steady-state tile shape, and
        also for symbols not listed in ``tile_shape_and_initial`` at all
        (loops whose initial tile equals the stride are never recorded there).
        """
        for tile_shape, initial in self.tile_shape_and_initial:
            if tile_shape == symbol:
                return True
            if initial == symbol:
                return False
        return True

    def is_initial_tile_shape(self, symbol: Symbol) -> bool:
        """Check if `symbol` is a initial tile shape."""
        for tile_shape, initial in self.tile_shape_and_initial:
            if tile_shape == symbol:
                return False
            if initial == symbol:
                return True
        return False

    def get_tile_shape(self, symbol: Symbol) -> Symbol | int:
        """Get the stride corresponding to the initial tile shape `symbol`.

        :raises ValueError: if `symbol` is not recorded as an initial tile
            shape.
        """
        for tile_shape, initial in self.tile_shape_and_initial:
            if initial == symbol:
                return tile_shape
        raise ValueError(f"Symbol {symbol} not found as initial in {self}")

    def get_initial(self, symbol: Symbol, none_if_fail: bool = False) -> Symbol | int:
        """Get the initial tile shape paired with the tile shape `symbol`.

        :param none_if_fail: if True, return None instead of raising when
            `symbol` has no recorded initial tile shape.
        """
        for tile_shape, initial in self.tile_shape_and_initial:
            if tile_shape == symbol:
                return initial
        if not none_if_fail:
            raise ValueError(f"Symbol {symbol} not found as tile_shape in {self}")
        else:
            return None

    def get_delta_choices(self, symbol: Symbol) -> frozenset[int]:
        """Get the possible initial deltas for the rank variable represented by `symbol`."""
        for initial, choices in self.delta_choices:
            if initial == symbol:
                return choices
        raise ValueError(f"Symbol {symbol} not found in {self}")

    def get_inner_tiles(
        self, symbol: Symbol, none_if_fail: bool = False
    ) -> Symbol | int | None:
        """Get tiles within the tile represented by `symbol`."""
        for tiled_by, what_tiles in self.what_tiles_symbol:
            if tiled_by == symbol:
                return what_tiles
        if none_if_fail:
            return None
        raise ValueError(f"Symbol {symbol} not found in {self}")

    def get_outer_tiles(
        self, symbol: Symbol, none_if_fail: bool = False
    ) -> Symbol | int | None:
        """Get the tile that contain the tile represented by `symbol`."""
        for tiled_by, what_tiles in self.what_tiles_symbol:
            if what_tiles == symbol:
                return tiled_by
            if none_if_fail:
                return None
        raise ValueError(f"Symbol {symbol} not found in {self}")

    def get_max_size(self, symbol: Symbol) -> Number:
        """Walk outward through the tiling chain until a concrete size is hit.

        The outermost "tile" of every chain is the integer rank-variable
        shape, so this terminates with the maximum value `symbol` can take.
        """
        while not isinstance(symbol, Number):
            symbol = self.get_outer_tiles(symbol)
        return symbol

    @staticmethod
    def from_pmapping_and_shape(
        pmapping: Mapping, shape: dict[str, int], workload: Workload
    ) -> "SymbolRelations":
        """Build the symbol relations of `pmapping`.

        :param pmapping: the mapping whose Loop nodes define the tiling chain.
        :param shape: full size of each rank variable; seeds the outermost
            level of each tiling chain.
        :param workload: used to derive the initial-delta choices.
        """
        initial_delta_choices = get_initial_delta_choices(
            pmapping.nodes[-1].einsum, workload
        )

        relation = SymbolRelations()
        # Most recent tile shape seen per rank variable, seeded with the full
        # shapes so the outermost loop of each variable tiles the full size.
        last_seen_loop_per_rank_var: dict[str, Symbol | int] = dict(shape)
        for node in pmapping.nodes:
            if not isinstance(node, Loop):
                continue
            prev = last_seen_loop_per_rank_var.get(node.rank_variable, None)
            # If we're a symbol and we've seen an outer loop with the same rank variable,
            # then we tile that one.
            if prev is not None:
                relation.what_tiles_symbol.append((prev, node.tile_shape))
            last_seen_loop_per_rank_var[node.rank_variable] = node.tile_shape

            # Only record an initial tile shape when it is a free symbol that
            # differs from the steady-state stride.
            if (
                isinstance(node.initial_tile_shape, Symbol)
                and node.initial_tile_shape != node.tile_shape
            ):
                relation.tile_shape_and_initial.append(
                    (node.tile_shape, node.initial_tile_shape)
                )
                relation.delta_choices.append(
                    (
                        node.initial_tile_shape,
                        frozenset(initial_delta_choices[node.rank_variable]),
                    )
                )

        # Innermost symbolic tiles are subdivided down to single elements.
        for r, s in last_seen_loop_per_rank_var.items():
            if isinstance(s, Symbol):
                relation.what_tiles_symbol.append((s, 1))

        relation.make_bounds()
        return relation
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def get_initial_delta_choices(einsum_name: str, workload: Workload):
    """Compute the candidate initial tile-shape deltas per rank variable.

    Starting from `einsum_name`, follows producer->consumer chains through the
    workload to the terminal consumers, then propagates deltas backwards: a
    consumer-side delta `d` becomes a producer-side delta `d * stride + halo`
    for every related (producer rank, consumer rank variable) pair. Every rank
    variable always has delta 0 as a choice.

    :param einsum_name: name of the einsum whose chains are traversed.
    :param workload: the workload providing einsums and stride/halo data.
    :return: defaultdict mapping rank variable -> set of integer deltas.
    """
    stride_and_halo = get_stride_and_halo(workload)
    einsum = workload.einsums[einsum_name]

    choices = defaultdict(lambda: set([0]))

    # DFS over producer->consumer edges; each chain element is
    # (tensor connecting to the previous einsum, einsum). The root uses None
    # as its tensor.
    consumer_chains = []
    stack = [[(None, einsum)]]
    while stack:
        cur_chain = stack.pop()
        _, last_einsum = cur_chain[-1]
        for tensor in last_einsum.output_tensor_names:
            einsums_with_tensor_as_input = workload.einsums_with_tensor_as_input(tensor)

            # No consumer for this output tensor: the chain terminates here.
            if len(einsums_with_tensor_as_input) == 0:
                consumer_chains.append(cur_chain)

            for next_einsum in einsums_with_tensor_as_input:
                stack.append(cur_chain + [(tensor, next_einsum)])

    for chain in consumer_chains:
        # Walk each chain from the terminal consumer back toward the root,
        # pairing each einsum (consumer) with its producer.
        for (_, producer), (tensor, consumer) in zip(
            list(reversed(chain))[1:], reversed(chain)
        ):
            # BUGFIX: guard before the lookup — the root entry's tensor is
            # None and (consumer.name, None) is not a stride_and_halo key.
            if tensor is None:
                break  # done
            rank_stride_and_halo = stride_and_halo[(consumer.name, tensor)]

            for cons_rank_var in consumer.rank_variables:
                for prod_rank_var in producer.rank_variables:
                    prod_rank = prod_rank_var.upper()
                    # BUGFIX: iterate a snapshot — when producer and consumer
                    # share a rank-variable name, choices[prod_rank_var] is the
                    # same set we are iterating, and adding to it mid-iteration
                    # raises RuntimeError.
                    for cons_choice in tuple(choices[cons_rank_var]):
                        key = (prod_rank, cons_rank_var)
                        if key not in rank_stride_and_halo:
                            continue
                        stride, halo = rank_stride_and_halo[key]
                        choices[prod_rank_var].add(int(cons_choice * stride + halo))

    return choices
|
|
@@ -0,0 +1,282 @@
|
|
|
1
|
+
from collections.abc import Sequence
|
|
2
|
+
from dataclasses import dataclass, field
|
|
3
|
+
import logging
|
|
4
|
+
from numbers import Number
|
|
5
|
+
from typing import Any, Callable
|
|
6
|
+
from uuid import UUID, uuid4
|
|
7
|
+
|
|
8
|
+
import accelforge.frontend.arch as arch
|
|
9
|
+
from accelforge.frontend.mapping import (
|
|
10
|
+
Mapping,
|
|
11
|
+
)
|
|
12
|
+
from accelforge.frontend.spec import Spec
|
|
13
|
+
from accelforge.frontend._workload_isl._symbolic import Relevant, PartiallyRelevant
|
|
14
|
+
from accelforge.frontend.workload import (
|
|
15
|
+
Einsum,
|
|
16
|
+
EinsumName,
|
|
17
|
+
RankVariable,
|
|
18
|
+
SymbolTable,
|
|
19
|
+
TensorName,
|
|
20
|
+
Workload,
|
|
21
|
+
Rank,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
from accelforge.frontend.mapper import Metrics
|
|
25
|
+
from accelforge.mapper.FFM._join_pmappings.compatibility import (
|
|
26
|
+
Compatibility,
|
|
27
|
+
)
|
|
28
|
+
from accelforge.mapper.FFM._make_pmappings.contraints.constraints import (
|
|
29
|
+
MappingConstraints,
|
|
30
|
+
_ConstraintLambda,
|
|
31
|
+
)
|
|
32
|
+
from accelforge.util.parallel import _expfmt
|
|
33
|
+
from accelforge.util._itertools import first
|
|
34
|
+
from accelforge.frontend.mapping import Reservation as ReservationNode
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
def make_compatibility(
    mapping: Mapping,
    fusable_tensors: set[TensorName],
    workload: Workload,
    rank_variable_bounds: dict[RankVariable, int],
    stride_and_halo,
) -> Compatibility:
    """Build the Compatibility descriptor for `mapping`.

    Looks up the einsum targeted by the mapping's last node, collects the
    rank-variable-to-ranks relation of each of its tensor accesses, and
    delegates to ``Compatibility.from_mapping``.
    """
    target_einsum = workload.einsums[mapping.nodes[-1].einsum]
    rank_var_to_ranks = {}
    for access in target_einsum.tensor_accesses:
        rank_var_to_ranks[access.name] = access.rank_variable2ranks
    return Compatibility.from_mapping(mapping, fusable_tensors, rank_var_to_ranks)
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
@dataclass
class Job:
    """One unit of mapper work: a spec/einsum/constraints bundle plus
    bookkeeping about how many pmappings were generated, kept, and evaluated.
    """

    spec: Spec | None  # Full problem spec; optional until attached.
    metrics: Metrics
    rank_variable_bounds: dict[RankVariable, int]

    job_id: UUID = field(default_factory=uuid4)  # Unique identifier per job.

    # Per-tensor (rank, rank-variable) -> (stride, halo) data; None until known.
    stride_and_halo: (
        dict[TensorName, dict[tuple[Rank, RankVariable], tuple[int, int]]] | None
    ) = None
    mapping: Mapping | None = None
    constraints: MappingConstraints | None = None
    fusable_tensors: set[TensorName] | None = None
    flattened_arch: list[arch.Leaf] | None = None

    einsum_name: EinsumName | None = None
    """If the Job is for a single einsum, this is the einsum name."""

    # Lazily built by _make_compatibility_and_updater; access via the
    # `compatibility` property rather than directly.
    _compatibility: Compatibility | None = None
    memories_track_all: list[str] | None = None
    memories_track_pmappings_only: list[str] | None = None
    ignored_resources: set[str] | None = None
    time_limit: float | int = float("inf")
    memory_limit: float | int = float("inf")
    messages: list[str] = field(default_factory=list)
    # cause -> cumulative proportion of pmappings kept by that cause.
    pmapping_keep_rates: dict[str, float] = field(default_factory=dict)
    tensor_to_relevancy: (
        dict[TensorName, dict[RankVariable, Relevant | PartiallyRelevant]] | None
    ) = None

    # Pmapping counters; totals default to 1 so ratio reporting never
    # divides by zero on a fresh job.
    n_total_pmappings: int = 1
    n_valid_pmappings: int = 1
    n_evaluated_pmappings: int = 0

    # Extra kwargs forwarded to update_compatibility_with_tile_shapes.
    # NOTE(review): _make_compatibility_and_updater (below) does not assign
    # this attribute, yet update_compatibility_with_tile_shapes unpacks it
    # after calling that method — presumably it is set elsewhere; confirm.
    _update_compatibility_with_tile_shapes_args: dict[str, Any] | None = None

    symbol_table: SymbolTable | None = None

    @property
    def einsum(self) -> Einsum:
        """The Einsum object this job targets, looked up by einsum_name."""
        return self.spec.workload.einsums[self.einsum_name]

    @property
    def compatibility(self) -> Compatibility:
        """The job's Compatibility, building it lazily on first access."""
        if self._compatibility is None:
            self._make_compatibility_and_updater()
        return self._compatibility

    @compatibility.setter
    def compatibility(self, compatibility: Compatibility):
        self._compatibility = compatibility

    def update_compatibility_with_tile_shapes(
        self, tile_shapes: Sequence[Number], tensor2size: dict
    ) -> Callable[[Sequence[Number], dict], Compatibility]:
        """Apply concrete tile shapes/tensor sizes to this job's compatibility."""
        if self._update_compatibility_with_tile_shapes_args is None:
            self._make_compatibility_and_updater()
        return update_compatibility_with_tile_shapes(
            self._compatibility,
            tile_shapes=tile_shapes,
            tensor2size=tensor2size,
            **self._update_compatibility_with_tile_shapes_args,
        )

    def _make_compatibility_and_updater(self):
        """Build and cache self._compatibility from the current mapping."""
        # Deferred import to avoid a circular dependency with the model package.
        from accelforge.model._looptree.reuse.symbolic import (
            quick_insert_reservation_nodes,
        )

        with_reservations = quick_insert_reservation_nodes(self)
        self._compatibility = make_compatibility(
            with_reservations,
            self.fusable_tensors,
            self.spec.workload,
            self.rank_variable_bounds,
            self.stride_and_halo,
        )

    @property
    def is_copy_operation(self) -> bool:
        """Whether the targeted einsum is a pure copy operation."""
        return self.spec.workload.einsums[self.einsum_name].is_copy_operation

    @classmethod
    def make_job(
        cls,
        **kwargs,
    ) -> "Job":
        """Alternate constructor that fills in None defaults before init.

        NOTE(review): "workload" and "architecture" are not fields of Job, so
        if a caller does not supply/consume them, cls(**kwargs) would raise
        TypeError on unexpected keyword arguments — confirm intended usage.
        """
        defaults = {
            "spec": None,
            "mapping": None,
            "workload": None,
            "architecture": None,
        }
        kwargs = {**defaults, **kwargs}
        return cls(**kwargs)

    def pretty_str(self) -> str:
        """Human-readable multi-line summary of this job and its statistics."""
        constraints = self.constraints.get_all_constraints()
        # Invert constraint -> node-indices into node-index -> constraints.
        node2constraints: dict[int, list[_ConstraintLambda]] = {}
        for constraint in constraints:
            for target_index in constraint._target_node_indices:
                l = node2constraints.setdefault(target_index, [])
                l.append(constraint)

        # Reservations are added after mapping generation so it messes up the indexing
        mapping = [n for n in self.mapping.nodes if not isinstance(n, ReservationNode)]

        s = ""
        s += "=" * 80 + "\n"
        s += f"Mapper job with ID {self.job_id}\n"
        s += f"Einsum name: {self.einsum_name}\n"
        s += f"Rank variable bounds: {self.rank_variable_bounds}\n"
        s += f"Compute node name: {self.flattened_arch[-1].name}\n"
        s += f"Mapping:\n"
        for i, node in enumerate(mapping):
            # Report constraints by their index in the full constraint list.
            cur_constraints = sorted(
                constraints.index(c) for c in node2constraints.get(i, [])
            )
            s += f"\t{i} {node.compact_str()} constrained by {cur_constraints}\n"
        s += self.constraints.pretty_str()
        s += f"Messages:\n"
        for m in self.messages:
            s += f"\t{m}\n"

        s += f"Total pmappings: {self.n_total_pmappings}\n"
        s += f"Valid pmappings: {self.n_valid_pmappings}\n"
        s += f"One in {_expfmt(self.n_total_pmappings / self.n_valid_pmappings)} pmappings is valid\n"
        s += f"Number of pmappings evaluated: {self.n_evaluated_pmappings}\n"
        # NOTE(review): this ratio (evaluated/total) looks inverted relative to
        # the "is valid" line above (total/valid) — confirm which is intended.
        s += f"One in {_expfmt(self.n_evaluated_pmappings / self.n_total_pmappings)} pmappings was evaluated\n"
        s += f"Pmapping elimination reasons:\n"
        for cause, keep_rate in self.pmapping_keep_rates.items():
            s += f"\t{cause} kept one in {_expfmt(1/keep_rate)} pmappings\n"
        s += "=" * 80 + "\n"
        return s

    def set_total_pmappings(self, n_pmappings: int):
        """Record the total number of pmappings for this job."""
        self.n_total_pmappings = n_pmappings

    def log_porp_pmappings_kept(
        self,
        cause: str,
        porp_kept: float,
        out_of: int | None = None,
    ):
        """Fold the proportion of pmappings kept by `cause` into the keep rates.

        If `out_of` is given, `porp_kept` applies only to that subset; the
        remaining (n_total - out_of) pmappings are counted as kept.
        """
        if out_of is not None:
            n_kept = porp_kept * out_of + (self.n_total_pmappings - out_of)
            porp_kept = n_kept / self.n_total_pmappings

        # Once any cause has eliminated everything, stop accumulating
        # (also avoids the 1/0 in pretty_str's elimination report).
        if any(x == 0 for x in self.pmapping_keep_rates.values()):
            return

        self.pmapping_keep_rates.setdefault(cause, 1)
        self.pmapping_keep_rates[cause] *= porp_kept

    def log_message(self, message: str):
        """Append a message to this job and mirror it to the logging module."""
        self.messages.append(message)
        logging.info(message)

    def __copy__(self) -> "Job":
        """Shallow copy; only the mutable bookkeeping containers are duplicated.

        Relies on dataclass field names matching instance attribute names so
        that __dict__ can be re-fed to __init__.
        """
        new = self.__class__(**self.__dict__)
        new.messages = self.messages.copy()
        new.pmapping_keep_rates = self.pmapping_keep_rates.copy()
        return new
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
class SameSpecJobs(list[Job]):
    """A list of Jobs assumed to share one Spec.

    Shared attributes are read off an arbitrary representative element.
    """

    @property
    def spec(self) -> Spec:
        """The Spec shared by every job in this group."""
        representative = first(self)
        return representative.spec

    @property
    def rank_variable_bounds(self) -> dict[RankVariable, int]:
        """The rank-variable bounds shared by every job in this group."""
        representative = first(self)
        return representative.rank_variable_bounds

    @property
    def metrics(self) -> Metrics:
        """The Metrics shared by every job in this group."""
        representative = first(self)
        return representative.metrics
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
class SameEinsumJobs(SameSpecJobs):
    """Jobs that all target the same einsum.

    Shared attributes are forwarded from the first job; rank_variable_bounds
    is inherited unchanged from SameSpecJobs (a byte-identical override was
    removed as redundant).
    """

    def check_invariance(self):
        """Raise RuntimeError if the jobs do not all share one einsum name."""
        all_einsums = set(job.einsum_name for job in self)
        if len(all_einsums) > 1:
            raise RuntimeError("broken invariance: not all Einsums are equal.")

    @property
    def fusable_tensors(self) -> set[TensorName]:
        """Fusable tensors shared by every job in this group."""
        return first(self).fusable_tensors

    @property
    def einsum_name(self) -> EinsumName | None:
        """The single einsum name shared by this group.

        Annotation fixed: the previous `set[EinsumName]` did not match the
        forwarded value, which is one (possibly None) einsum name.
        """
        return first(self).einsum_name

    @property
    def stride_and_halo(
        self,
    ) -> dict[tuple[str, str], dict[tuple[str, str], tuple[int, int]]]:
        """Stride-and-halo data shared by every job in this group."""
        return first(self).stride_and_halo

    @property
    def is_copy_op(self) -> bool:
        """Whether the shared einsum is a pure copy operation."""
        return first(self).is_copy_operation
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
class SameCompatibilityJobs(SameEinsumJobs):
    """Jobs with the same compatibility before tile shape exploration."""

    def check_invariance(self):
        """Raise RuntimeError unless every job reports the same compatibility."""
        distinct = {job.compatibility for job in self}
        if len(distinct) > 1:
            raise RuntimeError(
                "broken invariance: not all compatibilities are equal."
            )

    @property
    def compatibility(self) -> Compatibility:
        """The Compatibility shared by this group, read from the first job."""
        return first(self).compatibility

    @property
    def update_compatibility_with_tile_shapes(
        self,
    ) -> Callable[[Sequence[Number], dict], Compatibility]:
        """The tile-shape updater of the first (representative) job."""
        return first(self).update_compatibility_with_tile_shapes

    def split(self) -> list["SameCompatibilityJobs"]:
        """Break the group into singleton groups, one per job."""
        singletons: list[SameCompatibilityJobs] = []
        for job in self:
            singletons.append(SameCompatibilityJobs([job]))
        return singletons
|