tff-core 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tff/core/__init__.py ADDED
File without changes
File without changes
@@ -0,0 +1,174 @@
1
+ """Custom dependency exclusion rules for layer/domain boundaries."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import logging
7
+ from pathlib import Path
8
+
9
+ from tff.core.model import ModelRepresentation
10
+ from tff.core.config import FitnessFunctionsConfig, resolve_project_path
11
+ from tff.core.report import LintFinding
12
+ from tff.core.utils.paths import get_layer_and_domain, model_path_relative
13
+
14
logger = logging.getLogger(__name__)


class CustomExclusionsChecker:
    """Enforce custom exclusions for model dependencies between layers.

    The exclusions file is a JSON object with two optional lists:

    * ``"exclusions"``: rules that may set any of ``source_layer``,
      ``source_domain``, ``target_layer`` and ``target_domain``. A rule
      matches when every key it sets equals the corresponding value; keys it
      omits match anything (so an empty rule matches every dependency).
      *source* is the dependency and *target* is the model depending on it.
    * ``"allowed_exceptions"``: ``{"model": ..., "dependency": ...}`` pairs,
      in ``schema.name`` form, that are exempt from every exclusion rule.

    Entries that are not JSON objects are ignored.
    """

    def __init__(self, models: dict[str, ModelRepresentation], exclusions_path: Path):
        """Store the model lookup and load exclusions from ``exclusions_path``.

        Args:
            models: All models, keyed by the names used in ``depends_on``.
            exclusions_path: JSON exclusions file. A missing, unreadable or
                malformed file results in no exclusions being enforced.
        """
        self.models = models
        self.exclusions_path = exclusions_path
        self.exclusions = self._load_exclusions()

    def _load_exclusions(self) -> dict:
        """Return the parsed exclusions config, or ``{}`` when unusable.

        A missing file, an I/O error, undecodable or invalid JSON, or a
        top-level value that is not a JSON object each log a warning and
        yield ``{}``.
        """
        if not self.exclusions_path.exists():
            logger.warning(
                "Config file %s not found. No exclusions will be enforced.",
                self.exclusions_path,
            )
            return {}

        try:
            with open(self.exclusions_path, encoding="utf-8") as f:
                loaded = json.load(f)
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError
        # (file not valid UTF-8).
        except (ValueError, OSError) as e:
            logger.warning(
                "Could not load exclusions config from %s: %s",
                self.exclusions_path,
                e,
            )
            return {}

        # A non-object top level (e.g. a bare list) would make every later
        # ``.get`` lookup fail, so treat it like an unreadable file.
        if not isinstance(loaded, dict):
            logger.warning(
                "Exclusions config %s must be a JSON object, got %s. "
                "No exclusions will be enforced.",
                self.exclusions_path,
                type(loaded).__name__,
            )
            return {}
        return loaded

    def _normalize_model_name(self, name: str) -> str:
        """Reduce a possibly quoted, qualified name to ``schema.name``.

        Double quotes are removed and only the last two dot-separated parts
        are kept, e.g. ``'"db"."sales"."orders"'`` -> ``"sales.orders"``.
        Names with fewer than two parts are returned with quotes removed.
        """
        unquoted = name.replace('"', "")
        parts = unquoted.split(".")
        if len(parts) >= 2:
            return f"{parts[-2]}.{parts[-1]}"
        return unquoted

    def _is_allowed_exception(self, model_name: str, dependency_name: str) -> bool:
        """Return True if the (model, dependency) pair is an allowed exception."""
        normalized_model = self._normalize_model_name(model_name)
        normalized_dependency = self._normalize_model_name(dependency_name)

        for entry in self.exclusions.get("allowed_exceptions", []):
            if not isinstance(entry, dict):
                continue
            if (
                entry.get("model") == normalized_model
                and entry.get("dependency") == normalized_dependency
            ):
                return True
        return False

    @staticmethod
    def _rule_field_matches(rule: dict, key: str, actual: str) -> bool:
        """Return True if ``rule`` does not set ``key`` or sets it to ``actual``."""
        return key not in rule or rule[key] == actual

    def _is_excluded_dependency(
        self,
        source_layer: str,
        source_domain: str,
        target_layer: str,
        target_domain: str,
        model_name: str | None = None,
        dependency_name: str | None = None,
    ) -> bool:
        """Return True if a source -> target dependency hits an exclusion rule.

        When both ``model_name`` and ``dependency_name`` are given and the
        pair is an allowed exception, the dependency is never excluded.
        """
        if model_name and dependency_name:
            if self._is_allowed_exception(model_name, dependency_name):
                return False

        for exclusion in self.exclusions.get("exclusions", []):
            # A non-object entry (e.g. a stray string) sets no keys, so it
            # would otherwise exclude every dependency.
            if not isinstance(exclusion, dict):
                continue
            if all(
                self._rule_field_matches(exclusion, key, actual)
                for key, actual in (
                    ("source_layer", source_layer),
                    ("source_domain", source_domain),
                    ("target_layer", target_layer),
                    ("target_domain", target_domain),
                )
            ):
                return True

        return False

    def check_model(self, model: ModelRepresentation) -> list[str]:
        """Return violation messages for ``model``'s dependencies.

        The following are skipped:

        * symbolic models;
        * models whose path yields no layer;
        * dependencies that are not in ``self.models`` or have no layer.

        Each message starts with ``"Model '<name>' "``. An unexpected error
        while checking one dependency is logged and that dependency is
        skipped.
        """
        if model.is_symbolic:
            return []

        model_layer, model_domain = get_layer_and_domain(model.path)
        if not model_layer:
            return []

        violations: list[str] = []
        for dependency_name in model.depends_on:
            try:
                dependency_model = self.models.get(dependency_name)
                if not dependency_model:
                    continue

                dep_layer, dep_domain = get_layer_and_domain(dependency_model.path)
                if not dep_layer:
                    continue

                # The dependency is the source; the model reading it is the target.
                if self._is_excluded_dependency(
                    source_layer=dep_layer,
                    source_domain=dep_domain or "",
                    target_layer=model_layer,
                    target_domain=model_domain or "",
                    model_name=str(model.name),
                    dependency_name=str(dependency_name),
                ):
                    violations.append(
                        f"Model '{model.name}' in layer '{model_layer}"
                        f"{f'/{model_domain}' if model_domain else ''}' "
                        f"depends on '{dependency_name}' in layer '{dep_layer}"
                        f"{f'/{dep_domain}' if dep_domain else ''}', "
                        f"which is not allowed by custom exclusions"
                    )
            except Exception as e:
                # Best effort: one problematic dependency must not abort the check.
                logger.error(
                    "Unexpected error checking dependency %s for model %s: %s",
                    dependency_name,
                    model.name,
                    e,
                    exc_info=True,
                )

        return violations
150
+
151
+
152
def collect_custom_exclusion_findings(
    models: dict[str, ModelRepresentation], config: FitnessFunctionsConfig
) -> list[LintFinding]:
    """Run the custom-exclusions checker over ``models`` and wrap the results.

    The exclusions file is ``config.exclusions_path`` resolved against the
    project root. Symbolic models are skipped. Every violation becomes an
    ``error`` finding whose message drops the leading ``"Model '<name>' "``,
    since the model is already recorded on the finding itself.
    """
    checker = CustomExclusionsChecker(
        models, resolve_project_path(config, config.exclusions_path)
    )
    return [
        LintFinding(
            check="custom_exclusions",
            severity="error",
            model=str(model.name),
            path=model_path_relative(model),
            message=message.removeprefix(f"Model '{model.name}' ").strip(),
        )
        for model in models.values()
        if not model.is_symbolic
        for message in checker.check_model(model)
    ]
@@ -0,0 +1,78 @@
1
+ """Dependency graph fan-in/fan-out metrics."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from collections import defaultdict
6
+
7
+ from tff.core.model import ModelRepresentation
8
+ from tff.core.config import FitnessFunctionsConfig
9
+ from tff.core.report import LintFinding
10
+ from tff.core.utils.paths import model_path_relative
11
+
12
+
13
def _graph_finding(
    model_name: str, model: ModelRepresentation, severity: str, message: str
) -> LintFinding:
    """Build one ``dependency_graph`` finding for ``model``."""
    return LintFinding(
        check="dependency_graph",
        severity=severity,
        model=str(model_name),
        path=model_path_relative(model),
        message=message,
    )


def collect_dependency_graph_findings(
    models: dict[str, ModelRepresentation], config: FitnessFunctionsConfig
) -> list[LintFinding]:
    """Report models whose dependency fan-in or fan-out exceeds thresholds.

    For each non-symbolic model:

    * ``fan_in`` is ``len(model.depends_on)``, the number of upstream inputs.
    * ``fan_out`` is the number of non-symbolic models that list it in their
      ``depends_on``, i.e. how many downstream models a change can reach.

    Thresholds and layer filtering come from ``config.checks.dependency_graph``.
    Each finding is raised when a count is strictly greater than its
    threshold:

    * ``fan_out`` above ``fan_out_fail`` is an error.
    * Otherwise, ``fan_out`` above ``fan_out_warn`` is a warning.
    * Independently, ``fan_in`` above ``fan_in_warn`` is a warning.
    """
    # Function-local import as before, but executed once rather than per model.
    from tff.core.utils.paths import get_layer_from_path

    graph_config = config.checks.dependency_graph
    layer_order = config.layers.order

    # Reverse adjacency: dependency name -> names of models that depend on it.
    dependents: dict[str, set[str]] = defaultdict(set)
    for model_name, model in models.items():
        if model.is_symbolic:
            continue
        for dependency in model.depends_on:
            dependents[str(dependency)].add(str(model_name))

    findings: list[LintFinding] = []
    for model_name, model in models.items():
        if model.is_symbolic:
            continue

        # Pass the configured layer order (as the layer-integrity check does)
        # so projects with custom layer names get working skip/only filters.
        layer = get_layer_from_path(model.path, layer_order)
        if not graph_config.should_run(layer):
            continue

        fan_in = len(model.depends_on)
        fan_out = len(dependents.get(str(model_name), ()))

        if fan_out > graph_config.fan_out_fail:
            findings.append(
                _graph_finding(
                    model_name,
                    model,
                    "error",
                    f"fan_out={fan_out} (fail>{graph_config.fan_out_fail}) — "
                    "high blast-radius hub model",
                )
            )
        elif fan_out > graph_config.fan_out_warn:
            findings.append(
                _graph_finding(
                    model_name,
                    model,
                    "warning",
                    f"fan_out={fan_out} (warn>{graph_config.fan_out_warn}) — "
                    "run impact analysis before changing",
                )
            )

        if fan_in > graph_config.fan_in_warn:
            findings.append(
                _graph_finding(
                    model_name,
                    model,
                    "warning",
                    f"fan_in={fan_in} (warn>{graph_config.fan_in_warn}) — "
                    "consider decomposing upstream dependencies",
                )
            )

    return findings
@@ -0,0 +1,98 @@
1
+ """Layer integrity check — unidirectional flow and mart domain isolation."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from tff.core.model import ModelRepresentation
6
+ from tff.core.config import FitnessFunctionsConfig
7
+ from tff.core.report import LintFinding, normalize_model_name
8
+ from tff.core.utils.paths import (
9
+ get_layer_from_path,
10
+ get_marts_domain_from_path,
11
+ model_path_relative,
12
+ )
13
+
14
+
15
def _layer_index(
    layer: str | None, dependency_model_kind: str, layer_order: list[str]
) -> int | None:
    """Return the position of ``layer`` within ``layer_order``.

    * An unknown layer name yields ``None``.
    * With no layer (``None`` or empty), ``"EXTERNAL"`` models rank before
      every layer (``-1``); any other kind yields ``None``.
    * If ``layer_order`` lists a name twice, its last position is used.
    """
    if not layer:
        return -1 if dependency_model_kind == "EXTERNAL" else None
    positions = dict(zip(layer_order, range(len(layer_order))))
    return positions.get(layer)
24
+
25
+
26
def collect_layer_integrity_findings(
    models: dict[str, ModelRepresentation], config: FitnessFunctionsConfig
) -> list[LintFinding]:
    """Report dependencies that break layer flow or mart domain isolation.

    Each non-external, non-symbolic model is compared with every dependency
    that is present in ``models``:

    * Flow: an ``error`` when the dependency's layer comes later in
      ``config.layers.order`` than the model's own layer. Pairs where either
      position is unknown are not flagged.
    * Mart isolation: when both sit in the marts layer
      (``config.rules.mart_naming.layer_name``), an ``error`` when their mart
      domains are both known and differ.
    """
    order = config.layers.order
    marts = config.rules.mart_naming.layer_name
    results: list[LintFinding] = []

    for model in models.values():
        if model.is_external or model.is_symbolic:
            continue

        own_layer = get_layer_from_path(model.path, order)
        own_kind = "EXTERNAL" if model.is_external else "STANDARD"
        own_index = _layer_index(own_layer, own_kind, order)
        own_domain = None
        if own_layer == marts:
            own_domain = get_marts_domain_from_path(model.path, marts)

        for dependency in model.depends_on:
            dep_model = models.get(dependency)
            if not dep_model:
                continue

            dep_layer = get_layer_from_path(dep_model.path, order)
            dep_kind = "EXTERNAL" if dep_model.is_external else "STANDARD"
            dep_index = _layer_index(dep_layer, dep_kind, order)

            # A higher index means a later (more downstream) layer.
            reads_downstream = (
                own_index is not None
                and dep_index is not None
                and dep_index > own_index
            )
            if reads_downstream:
                results.append(
                    LintFinding(
                        check="layer_integrity",
                        severity="error",
                        model=str(model.name),
                        path=model_path_relative(model),
                        message=(
                            f"depends on {normalize_model_name(str(dependency))} "
                            "in a downstream layer"
                        ),
                    )
                )

            if own_layer != marts or dep_layer != marts:
                continue

            dep_domain = get_marts_domain_from_path(dep_model.path, marts)
            if own_domain and dep_domain and own_domain != dep_domain:
                results.append(
                    LintFinding(
                        check="layer_integrity",
                        severity="error",
                        model=str(model.name),
                        path=model_path_relative(model),
                        message=(
                            f"{marts}/{own_domain} depends on "
                            f"{normalize_model_name(str(dependency))} "
                            f"({marts}/{dep_domain})"
                        ),
                    )
                )

    return results
@@ -0,0 +1,106 @@
1
+ """Cross-model schema contract parity checks."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ from pathlib import Path
7
+
8
+ from tff.core.config import (
9
+ FitnessFunctionsConfig,
10
+ _ensure_under_root,
11
+ resolve_project_path,
12
+ )
13
+ from tff.core.report import LintFinding
14
+ from tff.core.utils.schema_contract_utils import (
15
+ check_column_list_parity,
16
+ check_dimension_set_parity,
17
+ extract_final_select_columns,
18
+ normalize_columns,
19
+ read_model_sql,
20
+ )
21
+
22
+
23
def _resolve_path(project_root: Path, models_dir: str, filename: str) -> Path:
    """Return ``project_root / models_dir / filename`` resolved.

    Raises:
        ValueError: if the resolved path lies outside ``project_root``.
    """
    candidate = project_root.joinpath(models_dir, filename)
    return _ensure_under_root(candidate, project_root)
25
+
26
+
27
def _column_parity_errors(project_root: Path, group: dict) -> list[str]:
    """Compare each member's final SELECT columns with the group's reference."""
    models_dir = group["models_dir"]
    reference_path = _resolve_path(project_root, models_dir, group["reference"])
    if not reference_path.exists():
        return [f"{reference_path.name} not found"]

    exclude = set(group.get("exclude_columns", []))
    reference_cols = normalize_columns(
        extract_final_select_columns(read_model_sql(reference_path)),
        substitutions=group.get("reference_substitutions", {}),
        exclude=exclude,
    )

    errors: list[str] = []
    for member in group["members"]:
        member_path = _resolve_path(project_root, models_dir, member["file"])
        if not member_path.exists():
            errors.append(f"{member_path.name} not found")
            continue
        member_cols = normalize_columns(
            extract_final_select_columns(read_model_sql(member_path)),
            substitutions=member.get("substitutions", {}),
            exclude=exclude,
        )
        errors.extend(
            check_column_list_parity(
                reference_cols,
                member_cols,
                group["reference"],
                member["file"],
            )
        )
    return errors


def _dimension_parity_errors(project_root: Path, group: dict) -> list[str]:
    """Compare the final SELECT column sets of a group's left and right models."""
    models_dir = group["models_dir"]
    left_cfg = group["left"]
    right_cfg = group["right"]
    left_path = _resolve_path(project_root, models_dir, left_cfg["file"])
    right_path = _resolve_path(project_root, models_dir, right_cfg["file"])

    # Report missing files the same way column-parity groups do; silently
    # skipping would let a misspelled file name disable the check unnoticed.
    missing = [
        f"{path.name} not found" for path in (left_path, right_path) if not path.exists()
    ]
    if missing:
        return missing

    left_dims = set(extract_final_select_columns(read_model_sql(left_path))) - set(
        left_cfg.get("exclude_columns", [])
    )
    right_dims = set(extract_final_select_columns(read_model_sql(right_path))) - set(
        right_cfg.get("exclude_columns", [])
    )
    return list(
        check_dimension_set_parity(
            left_dims,
            right_dims,
            left_cfg["file"],
            right_cfg["file"],
        )
    )


def _schema_contract_errors(project_root: Path, contract_config: dict) -> list[str]:
    """Collect schema contract violation messages.

    ``contract_config`` may contain:

    * ``column_parity_groups``: each with ``models_dir``, ``reference``,
      ``members`` (each with ``file`` and optional ``substitutions``), and
      optional ``exclude_columns`` / ``reference_substitutions``. Each member's
      normalized final SELECT columns are checked against the reference with
      ``check_column_list_parity``.
    * ``dimension_parity_groups``: each with ``models_dir`` and ``left`` /
      ``right`` (each with ``file`` and optional ``exclude_columns``). The two
      final SELECT column sets are checked with ``check_dimension_set_parity``.

    Missing model files are reported as ``"<file name> not found"``.
    """
    errors: list[str] = []
    for group in contract_config.get("column_parity_groups", []):
        errors.extend(_column_parity_errors(project_root, group))
    for group in contract_config.get("dimension_parity_groups", []):
        errors.extend(_dimension_parity_errors(project_root, group))
    return errors
88
+
89
+
90
def collect_schema_contract_findings(
    config: FitnessFunctionsConfig,
) -> list[LintFinding]:
    """Return ``error`` findings for schema contract violations.

    Contract groups are read from ``config.contract_groups_path``, resolved
    against the project root. A missing file means there is nothing to check.
    A file that cannot be read, is not valid JSON, or is not a JSON object
    yields a single error finding instead of aborting the whole lint run.
    Multi-line error texts are flattened with ``" — "`` separators.
    """
    project_root: Path = getattr(config, "_project_root", Path.cwd())
    contract_path = resolve_project_path(config, config.contract_groups_path)
    if not contract_path.exists():
        return []

    def _finding(text: str) -> LintFinding:
        return LintFinding(
            check="schema_contracts",
            severity="error",
            message=text.replace("\n", " — "),
        )

    try:
        contract_config = json.loads(contract_path.read_text(encoding="utf-8"))
    # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
    except (OSError, ValueError) as exc:
        return [_finding(f"Could not read {contract_path.name}: {exc}")]
    if not isinstance(contract_config, dict):
        return [_finding(f"{contract_path.name} must contain a JSON object")]

    return [
        _finding(error) for error in _schema_contract_errors(project_root, contract_config)
    ]
tff/core/config.py ADDED
@@ -0,0 +1,200 @@
1
+ """Pydantic models and loaders for fitness_functions.yaml."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import yaml
9
+ from pydantic import BaseModel, Field
10
+
11
+
12
class LayersConfig(BaseModel):
    """Layer names in pipeline order; earlier entries are further upstream."""

    # A model must not depend on a model whose layer appears later in this
    # list (enforced by the layer-integrity check).
    order: list[str] = Field(
        default_factory=lambda: [
            "sources",
            "derived",
            "core",
            "marts",
            "export",
        ]
    )
16
+
17
+
18
class CheckEnabled(BaseModel):
    """On/off switch for a check that has no other options."""

    enabled: bool = True  # on unless explicitly disabled
20
+
21
+
22
class LayerFilterConfig(BaseModel):
    """Base options for a check or rule that can be limited to some layers.

    When ``only_layers`` is set it takes precedence over ``skip_layers``.
    """

    enabled: bool = True
    # Layers to leave out; ignored while ``only_layers`` is set.
    skip_layers: list[str] = Field(default_factory=list)
    # Allow-list of layers; ``None`` means every layer not in ``skip_layers``.
    only_layers: list[str] | None = None

    def should_run(self, layer: str | None) -> bool:
        """Return whether this check/rule applies to a model in ``layer``.

        A disabled config never runs. When enabled, a model with no known
        layer (``None``) is always checked, regardless of the layer filters.
        """
        if self.enabled:
            if layer is None:
                return True
            if self.only_layers is None:
                return layer not in self.skip_layers
            return layer in self.only_layers
        return False
35
+
36
+
37
class DependencyGraphCheckConfig(LayerFilterConfig):
    """Fan-in/fan-out thresholds for the dependency-graph check.

    A finding is raised when a count is strictly greater than its threshold.
    """

    # Thresholds on the number of downstream dependents.
    fan_out_warn: int = 15
    fan_out_fail: int = 25
    # Threshold on the number of upstream dependencies (``depends_on`` size).
    fan_in_warn: int = 10
41
+
42
+
43
class ChecksConfig(BaseModel):
    """Options for the ``checks`` section of the fitness config."""

    layer_integrity: CheckEnabled = Field(default_factory=CheckEnabled)
    custom_exclusions: CheckEnabled = Field(default_factory=CheckEnabled)
    schema_contracts: CheckEnabled = Field(default_factory=CheckEnabled)
    # Unlike the plain toggles above, this one has layer filters and thresholds.
    dependency_graph: DependencyGraphCheckConfig = Field(
        default_factory=DependencyGraphCheckConfig,
    )
50
+
51
+
52
class ClassificationMacrosRuleConfig(LayerFilterConfig):
    """Options for the classification-macros rule; skips ``sources`` by default."""

    skip_layers: list[str] = Field(default_factory=lambda: ["sources"])
    # Column name -> regex of the macro reference(s) accepted for it.
    # NOTE(review): the rule implementation is not visible here; confirm how
    # these patterns are applied.
    columns: dict[str, str] = Field(
        default_factory=lambda: {
            "product_type": r"@product_type\b|@PRODUCT_TYPE\b",
            "billing_segment": r"@BILLING_SEGMENT\b|@billing_segment\b",
            "industry": r"@INDUSTRY\b|@industry\b",
        },
    )
61
+
62
+
63
class SqlComplexityRuleConfig(LayerFilterConfig):
    """Thresholds for the SQL-complexity rule."""

    # NOTE(review): presumably limits the rule to warnings; confirm in the rule.
    warn_only: bool = True
    # Metric name -> two thresholds, presumably [warn, fail]; confirm in the rule.
    thresholds: dict[str, list[int]] = Field(
        default_factory=lambda: {
            "decision_points": [15, 25],
            "cte_count": [8, 12],
            "join_count": [8, 12],
            "line_count": [250, 400],
        },
    )
73
+
74
+
75
class MartNamingRuleConfig(LayerFilterConfig):
    """Mart-naming rule options; ``layer_name`` also names the marts layer."""

    # Also read by the layer-integrity check to identify the marts layer.
    layer_name: str = "marts"
    # Naming strategy id. NOTE(review): other accepted values are not visible here.
    rule: str = "prefix_with_subdirectory"
78
+
79
+
80
class ColumnNamesRuleConfig(LayerFilterConfig):
    """Options for the column-names rule."""

    # Presumably disallowed name -> preferred replacement; confirm in the rule.
    replacements: dict[str, str] = Field(default_factory=lambda: {})
82
+
83
+
84
class ColumnTypeRuleEntry(BaseModel):
    """One named column-type rule.

    Presumably columns matching ``pattern`` must have ``data_type``; confirm
    against the column-types rule implementation.
    """

    name: str
    pattern: str
    data_type: str
88
+
89
+
90
class ColumnTypesRuleConfig(LayerFilterConfig):
    """Options for the column-types rule."""

    rules: list[ColumnTypeRuleEntry] = Field(default_factory=lambda: [])
    # Type name -> spellings treated as the same type.
    equivalent_types: dict[str, list[str]] = Field(
        default_factory=lambda: {
            "text": ["text", "varchar"],
        },
    )
95
+
96
+
97
class MetadataRuleConfig(LayerFilterConfig):
    """Per-property switches for the model-metadata rule (all on by default).

    NOTE(review): presumably each flag makes the rule require that property;
    confirm in the rule implementation.
    """

    owner: bool = True
    description: bool = True
    grain: bool = True
    not_null: bool = True
    unique_values: bool = True
103
+
104
+
105
class FilenameEqualsModelnameRuleConfig(LayerFilterConfig):
    """Filename-equals-model-name rule; only the common layer filters apply."""
107
+
108
+
109
class NoSelectStarRuleConfig(LayerFilterConfig):
    """No-``SELECT *`` rule options; skips the ``sources`` layer by default."""

    skip_layers: list[str] = Field(default_factory=lambda: ["sources"])
111
+
112
+
113
class NoPositionalGroupByOrOrderByRuleConfig(LayerFilterConfig):
    """No-positional GROUP BY / ORDER BY rule options; skips ``sources`` by default."""

    skip_layers: list[str] = Field(default_factory=lambda: ["sources"])
115
+
116
+
117
class RulesConfig(BaseModel):
    """Options for the ``rules`` section: one entry per per-model lint rule."""

    classification_macros: ClassificationMacrosRuleConfig = Field(
        default_factory=ClassificationMacrosRuleConfig,
    )
    sql_complexity: SqlComplexityRuleConfig = Field(
        default_factory=SqlComplexityRuleConfig,
    )
    mart_naming: MartNamingRuleConfig = Field(
        default_factory=MartNamingRuleConfig,
    )
    column_names: ColumnNamesRuleConfig = Field(
        default_factory=ColumnNamesRuleConfig,
    )
    column_types: ColumnTypesRuleConfig = Field(
        default_factory=ColumnTypesRuleConfig,
    )
    metadata: MetadataRuleConfig = Field(
        default_factory=MetadataRuleConfig,
    )
    filename_equals_modelname: FilenameEqualsModelnameRuleConfig = Field(
        default_factory=FilenameEqualsModelnameRuleConfig,
    )
    no_select_star: NoSelectStarRuleConfig = Field(
        default_factory=NoSelectStarRuleConfig,
    )
    no_positional_group_by_or_order_by: NoPositionalGroupByOrOrderByRuleConfig = (
        Field(default_factory=NoPositionalGroupByOrOrderByRuleConfig)
    )
137
+
138
+
139
class FitnessFunctionsConfig(BaseModel):
    """Root model of ``fitness_functions.yaml``.

    Instances built by ``load_fitness_config`` also carry an undeclared
    ``_project_root`` attribute, which ``resolve_project_path`` uses to
    resolve the file paths below.
    """

    # JSON file with schema contract groups (project-relative by default).
    contract_groups_path: str = "linter_contract_groups.json"
    # JSON file with custom dependency exclusions (project-relative by default).
    exclusions_path: str = "linter_exclusions.json"
    layers: LayersConfig = Field(default_factory=LayersConfig)
    checks: ChecksConfig = Field(default_factory=ChecksConfig)
    rules: RulesConfig = Field(default_factory=RulesConfig)
145
+
146
+
147
def _deep_merge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
    """Return a new dict with ``override`` recursively layered onto ``base``.

    Dicts present under the same key on both sides are merged key by key;
    any other override value replaces the base value outright. ``base`` is
    not mutated, though values taken unchanged from either side are shared
    rather than copied.
    """
    result = {**base}
    for key, incoming in override.items():
        existing = result.get(key)
        if isinstance(existing, dict) and isinstance(incoming, dict):
            result[key] = _deep_merge(existing, incoming)
        else:
            result[key] = incoming
    return result
155
+
156
+
157
def load_fitness_config(
    project_root: Path,
    config_path: str | Path | None = "fitness_functions.yaml",
    overrides: dict[str, Any] | None = None,
) -> FitnessFunctionsConfig:
    """Load fitness config with defaults, yaml file, and optional overrides.

    Sources are applied in order of increasing precedence:

    1. Model defaults.
    2. The YAML file at ``config_path``. A relative path is taken from
       ``project_root``; the file is skipped when it is missing or when
       ``config_path`` is ``None``.
    3. ``overrides``, deep-merged on top of the file contents.

    The returned config gets ``_project_root`` attached so path helpers can
    resolve project-relative files.

    Raises:
        ValueError: if the YAML document is not a mapping.
    """
    raw: dict[str, Any] = {}

    if config_path is not None:
        candidate = Path(config_path)
        yaml_path = candidate if candidate.is_absolute() else project_root / candidate
        if yaml_path.exists():
            # An empty/falsy document is treated as no settings at all.
            parsed = yaml.safe_load(yaml_path.read_text(encoding="utf-8")) or {}
            if not isinstance(parsed, dict):
                raise ValueError(f"Expected mapping in {yaml_path}")
            raw = parsed

    merged = _deep_merge(raw, overrides) if overrides else raw
    config = FitnessFunctionsConfig.model_validate(merged)
    # Not a declared field; readers fetch it with getattr(..., "_project_root", ...).
    config._project_root = project_root  # type: ignore[attr-defined]
    return config
181
+
182
+
183
def _ensure_under_root(path: Path, root: Path) -> Path:
    """Return ``path`` fully resolved, provided it stays inside ``root``.

    Both paths are resolved first, so ``..`` segments and symlinks cannot be
    used to escape ``root``. ``root`` itself counts as inside.

    Raises:
        ValueError: if the resolved ``path`` lies outside the resolved ``root``.
    """
    target = path.resolve()
    if not target.is_relative_to(root.resolve()):
        raise ValueError(f"Path {path} resolves outside project root {root}")
    return target
193
+
194
+
195
def resolve_project_path(config: FitnessFunctionsConfig, relative: str) -> Path:
    """Resolve ``relative`` against the config's project root.

    The root is the ``_project_root`` attached by ``load_fitness_config``,
    falling back to the current working directory. Absolute paths are
    accepted as-is but must still lie under the root.

    Raises:
        ValueError: if the resolved path falls outside the project root.
    """
    root: Path = getattr(config, "_project_root", Path.cwd())
    candidate = Path(relative)
    if not candidate.is_absolute():
        candidate = root / candidate
    return _ensure_under_root(candidate, root)
tff/core/context.py ADDED
@@ -0,0 +1,30 @@
1
+ """Thread-local fitness function configuration for SQLMesh rule classes."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import threading
6
+ from typing import TYPE_CHECKING
7
+
8
+ if TYPE_CHECKING:
9
+ from tff.core.config import FitnessFunctionsConfig
10
+
11
# Holds a separate ``config`` attribute for each thread.
_config_local = threading.local()


def set_ff_config(config: FitnessFunctionsConfig) -> None:
    """Install ``config`` as the fitness config for the calling thread only."""
    _config_local.config = config
16
+
17
+
18
def get_ff_config() -> FitnessFunctionsConfig:
    """Return the calling thread's fitness config.

    A thread that never called ``set_ff_config`` (or cleared it) gets a
    default ``FitnessFunctionsConfig``, which is then cached for that thread.
    """
    # Runtime import; at module level the class is only imported under
    # TYPE_CHECKING (presumably to avoid an import cycle).
    from tff.core.config import FitnessFunctionsConfig

    current = getattr(_config_local, "config", None)
    if current is not None:
        return current

    current = FitnessFunctionsConfig()
    _config_local.config = current
    return current
26
+
27
+
28
def clear_ff_config() -> None:
    """Drop the calling thread's stored config; a no-op if none is set."""
    try:
        del _config_local.config
    except AttributeError:
        pass