fba-bench-core 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- fba_bench_core/__init__.py +11 -0
- fba_bench_core/agents/__init__.py +15 -0
- fba_bench_core/agents/base.py +83 -0
- fba_bench_core/agents/registry.py +16 -0
- fba_bench_core/benchmarking/__init__.py +6 -0
- fba_bench_core/benchmarking/core/__init__.py +1 -0
- fba_bench_core/benchmarking/engine/__init__.py +12 -0
- fba_bench_core/benchmarking/engine/core.py +135 -0
- fba_bench_core/benchmarking/engine/models.py +62 -0
- fba_bench_core/benchmarking/metrics/__init__.py +30 -0
- fba_bench_core/benchmarking/metrics/accuracy_score.py +27 -0
- fba_bench_core/benchmarking/metrics/aggregate.py +39 -0
- fba_bench_core/benchmarking/metrics/completeness.py +38 -0
- fba_bench_core/benchmarking/metrics/cost_efficiency.py +32 -0
- fba_bench_core/benchmarking/metrics/custom_scriptable.py +17 -0
- fba_bench_core/benchmarking/metrics/keyword_coverage.py +41 -0
- fba_bench_core/benchmarking/metrics/policy_compliance.py +18 -0
- fba_bench_core/benchmarking/metrics/registry.py +57 -0
- fba_bench_core/benchmarking/metrics/robustness.py +27 -0
- fba_bench_core/benchmarking/metrics/technical_performance.py +16 -0
- fba_bench_core/benchmarking/registry.py +48 -0
- fba_bench_core/benchmarking/scenarios/__init__.py +1 -0
- fba_bench_core/benchmarking/scenarios/base.py +36 -0
- fba_bench_core/benchmarking/scenarios/complex_marketplace.py +181 -0
- fba_bench_core/benchmarking/scenarios/multiturn_tool_use.py +176 -0
- fba_bench_core/benchmarking/scenarios/registry.py +18 -0
- fba_bench_core/benchmarking/scenarios/research_summarization.py +141 -0
- fba_bench_core/benchmarking/validators/__init__.py +24 -0
- fba_bench_core/benchmarking/validators/determinism_check.py +95 -0
- fba_bench_core/benchmarking/validators/fairness_balance.py +75 -0
- fba_bench_core/benchmarking/validators/outlier_detection.py +53 -0
- fba_bench_core/benchmarking/validators/registry.py +57 -0
- fba_bench_core/benchmarking/validators/reproducibility_metadata.py +74 -0
- fba_bench_core/benchmarking/validators/schema_adherence.py +59 -0
- fba_bench_core/benchmarking/validators/structural_consistency.py +74 -0
- fba_bench_core/config.py +154 -0
- fba_bench_core/domain/__init__.py +75 -0
- fba_bench_core/domain/events/__init__.py +230 -0
- fba_bench_core/domain/events/analytics.py +69 -0
- fba_bench_core/domain/events/base.py +59 -0
- fba_bench_core/domain/events/inventory.py +119 -0
- fba_bench_core/domain/events/marketing.py +102 -0
- fba_bench_core/domain/events/pricing.py +179 -0
- fba_bench_core/domain/models.py +296 -0
- fba_bench_core/exceptions/__init__.py +9 -0
- fba_bench_core/exceptions/base.py +46 -0
- fba_bench_core/services/__init__.py +12 -0
- fba_bench_core/services/base.py +52 -0
- fba_bench_core-1.0.0.dist-info/METADATA +152 -0
- fba_bench_core-1.0.0.dist-info/RECORD +52 -0
- fba_bench_core-1.0.0.dist-info/WHEEL +4 -0
- fba_bench_core-1.0.0.dist-info/licenses/LICENSE +21 -0
@@ -0,0 +1,95 @@
|
|
1
|
+
"""Determinism check validator."""
|
2
|
+
|
3
|
+
from typing import Any
|
4
|
+
|
5
|
+
from .registry import register_validator
|
6
|
+
|
7
|
+
|
8
|
+
@register_validator("determinism_check")
def determinism_check(
    runs: list[dict[str, Any]], config: dict[str, Any]
) -> dict[str, Any]:
    """Check that repeated runs with the same runner and seed agree.

    Runs are grouped by ``(runner_key, seed)``. For every group with at least
    two runs and every configured output field, the values produced by the
    group's successful runs are compared: when all values are numeric they
    must lie within ``tolerance`` of each other (``max - min``), otherwise
    they must all be exactly equal.

    Args:
        runs: Run records. Keys read: ``runner_key``, ``seed``, ``status``
            and ``output`` (a mapping of field name to value).
        config: Optional ``tolerance`` (default ``0.0``) and ``fields``
            (output keys to compare, default ``["value"]``).

    Returns:
        ``{"issues": [...], "summary": {...}}``. Issues hold one
        ``determinism_mismatch`` error per offending group/field, or a single
        ``determinism_ok`` info entry when nothing mismatched. The summary
        reports ``total_groups_checked`` (number of groups formed) and
        ``groups_with_issues`` (groups with at least one mismatch).
    """
    from collections import defaultdict

    issues = []
    tolerance = config.get("tolerance", 0.0)
    fields = config.get("fields", ["value"])

    # Group runs by runner_key and seed
    groups = defaultdict(list)
    for run in runs:
        key = (run.get("runner_key"), run.get("seed"))
        groups[key].append(run)

    # Track which groups actually mismatched so the summary counts only them.
    mismatched_groups = set()

    for (runner, seed), group_runs in groups.items():
        if len(group_runs) < 2:
            continue

        for field in fields:
            values = []
            for run in group_runs:
                if run.get("status") != "success":
                    continue
                output = run.get("output")
                # ``output`` may be missing, None or malformed; treat any
                # non-mapping as "field absent" rather than crashing.
                values.append(output.get(field) if isinstance(output, dict) else None)
            if len(values) < 2:
                continue

            # Tolerance comparison only applies when every value is numeric; a
            # mix such as [1.0, None] falls back to exact matching instead of
            # raising TypeError inside min()/max().
            if all(isinstance(v, (int, float)) for v in values):
                min_val = min(values)
                max_val = max(values)
                if max_val - min_val > tolerance:
                    mismatched_groups.add((runner, seed))
                    issues.append(
                        {
                            "id": "determinism_mismatch",
                            "severity": "error",
                            "message": f"Determinism mismatch for runner '{runner}', seed {seed}, field '{field}': values {values} exceed tolerance {tolerance}",
                            "context": {
                                "runner": runner,
                                "seed": seed,
                                "field": field,
                                "values": values,
                            },
                        }
                    )
            elif not all(v == values[0] for v in values):
                # For non-numeric values, require an exact match
                mismatched_groups.add((runner, seed))
                issues.append(
                    {
                        "id": "determinism_mismatch",
                        "severity": "error",
                        "message": f"Determinism mismatch for runner '{runner}', seed {seed}, field '{field}': values {values} not identical",
                        "context": {
                            "runner": runner,
                            "seed": seed,
                            "field": field,
                            "values": values,
                        },
                    }
                )

    if not issues:
        issues.append(
            {
                "id": "determinism_ok",
                "severity": "info",
                "message": "All deterministic checks passed",
            }
        )

    return {
        "issues": issues,
        "summary": {
            "total_groups_checked": len(groups),
            "groups_with_issues": len(mismatched_groups),
        },
    }
|
@@ -0,0 +1,75 @@
|
|
1
|
+
"""Fairness balance validator."""
|
2
|
+
|
3
|
+
from typing import Any
|
4
|
+
|
5
|
+
from .registry import register_validator
|
6
|
+
|
7
|
+
|
8
|
+
@register_validator("fairness_balance")
def fairness_balance(
    runs: list[dict[str, Any]], config: dict[str, Any]
) -> dict[str, Any]:
    """Validate that a metric is balanced within each group of runs.

    Runs are grouped by the value of ``config["group"]`` (default
    ``"runner_key"``); runs without that key are ignored. For every group
    with at least ``min_group_size`` runs (default 2), the metric at the
    dot-separated ``metric_path`` (default ``"metrics.accuracy"``) is read from
    each successful run. A missing or non-numeric metric counts as ``0.0``.
    The group is imbalanced when ``(max - min) / |(max + min) / 2|`` exceeds
    ``threshold`` (default ``0.1``).

    Returns:
        ``{"issues": [...], "summary": {...}}``. There is one
        ``fairness_imbalance`` warning or ``fairness_within_threshold`` info
        entry per evaluated group. The summary reports ``groups_checked``
        (groups formed) and ``groups_with_imbalance`` (imbalance warnings only).
    """
    from collections import defaultdict

    issues = []
    group = config.get("group", "runner_key")
    metric_path = config.get("metric_path", "metrics.accuracy")
    threshold = config.get("threshold", 0.1)
    min_group_size = config.get("min_group_size", 2)

    groups = defaultdict(list)
    for run in runs:
        group_key = run.get(group)
        # Skip only runs lacking a group value; falsy keys such as seed 0 are
        # legitimate groups.
        if group_key is not None:
            groups[group_key].append(run)

    groups_with_imbalance = 0
    for group_key, group_runs in groups.items():
        if len(group_runs) < min_group_size:
            continue

        # Extract metric values for successful runs
        values = []
        for run in group_runs:
            if run.get("status") != "success":
                continue
            current = run
            for key in metric_path.split("."):
                # Stop descending once the path leaves a mapping instead of
                # calling .get() on a scalar or None.
                current = current.get(key) if isinstance(current, dict) else None
            values.append(current if isinstance(current, (int, float)) else 0.0)

        if not values:
            continue

        min_val = min(values)
        max_val = max(values)
        spread = max_val - min_val
        midpoint = abs(min_val + max_val) / 2
        if midpoint == 0:
            # No relative scale (e.g. all-zero metrics): identical values are
            # balanced, any spread is treated as unbounded imbalance.
            relative_spread = 0.0 if spread == 0 else float("inf")
        else:
            relative_spread = spread / midpoint

        if relative_spread > threshold:
            groups_with_imbalance += 1
            issues.append(
                {
                    "id": "fairness_imbalance",
                    "severity": "warning",
                    "message": f"Fairness imbalance in group '{group_key}': range {min_val}-{max_val} exceeds threshold {threshold}",
                    "context": {
                        "group": group_key,
                        "metric": metric_path,
                        "min": min_val,
                        "max": max_val,
                        "threshold": threshold,
                    },
                }
            )
        else:
            issues.append(
                {
                    "id": "fairness_within_threshold",
                    "severity": "info",
                    "message": f"Fairness within threshold for group '{group_key}'",
                }
            )

    return {
        "issues": issues,
        "summary": {
            "groups_checked": len(groups),
            "groups_with_imbalance": groups_with_imbalance,
        },
    }
|
@@ -0,0 +1,53 @@
|
|
1
|
+
"""Outlier detection validator."""
|
2
|
+
|
3
|
+
from typing import Any
|
4
|
+
|
5
|
+
from .registry import register_validator
|
6
|
+
|
7
|
+
|
8
|
+
@register_validator("outlier_detection")
def outlier_detection(
    runs: list[dict[str, Any]], config: dict[str, Any]
) -> dict[str, Any]:
    """Flag successful runs whose ``field`` value is far from the median.

    Uses the median absolute deviation (MAD): a value is an outlier when
    ``|value - median| > k * MAD``. Only runs with ``status == "success"`` are
    considered. A missing field counts as ``0``, and non-numeric values
    (e.g. ``None``) are skipped. Fewer than three samples yields no issues.

    Args:
        runs: Run records.
        config: Optional ``k`` (MAD multiplier, default ``1.5``) and
            ``field`` (run key to inspect, default ``"duration_ms"``).

    Returns:
        ``{"issues": [...], "summary": {...}}``. Each ``duration_outlier``
        issue's ``context.index`` and every entry of ``summary.outliers`` is
        the position of the run in ``runs``, matching the other validators.
    """
    import statistics

    issues = []
    k = config.get("k", 1.5)
    field = config.get("field", "duration_ms")

    # (position in runs, value) for every usable successful run, so outliers
    # can be reported against the caller's list rather than the filtered one.
    samples = []
    for i, run in enumerate(runs):
        if run.get("status") != "success":
            continue
        value = run.get(field, 0)
        if isinstance(value, (int, float)):
            samples.append((i, value))
    durations = [value for _, value in samples]

    if len(durations) < 3:
        return {
            "issues": [],
            "summary": {
                "checked": len(durations),
                "total_runs": len(durations),
                "outliers": [],
            },
        }

    # True median (averages the two middle values for even counts).
    median = statistics.median(durations)
    deviations = [abs(d - median) for d in durations]
    mad = statistics.median(deviations)

    outlier_indices = []
    for (run_index, value), dev in zip(samples, deviations):
        if dev > k * mad:
            outlier_indices.append(run_index)
            issues.append(
                {
                    "id": "duration_outlier",
                    "severity": "warning",
                    "message": f"Outlier duration at index {run_index}: {value} (median: {median}, MAD: {mad})",
                    "context": {
                        "index": run_index,
                        "value": value,
                        "median": median,
                        "mad": mad,
                    },
                }
            )

    return {
        "issues": issues,
        "summary": {
            "checked": len(durations),
            "total_runs": len(durations),
            "outliers": outlier_indices,
            "median_duration": median,
            "mad_duration": mad,
        },
    }
|
@@ -0,0 +1,57 @@
|
|
1
|
+
"""Registry for validators."""
|
2
|
+
|
3
|
+
from collections.abc import Callable
|
4
|
+
|
5
|
+
|
6
|
+
class ValidatorRegistry:
|
7
|
+
_validators: dict[str, Callable] = {}
|
8
|
+
|
9
|
+
@classmethod
|
10
|
+
def register(cls, name: str, validator_class: Callable) -> None:
|
11
|
+
"""Register a validator class."""
|
12
|
+
cls._validators[name] = validator_class
|
13
|
+
|
14
|
+
@classmethod
|
15
|
+
def create_validator(cls, name: str, config=None) -> Callable | None:
|
16
|
+
"""Create a validator instance."""
|
17
|
+
fn = cls._validators.get(name)
|
18
|
+
if fn:
|
19
|
+
return fn(config) if config else fn()
|
20
|
+
return None
|
21
|
+
|
22
|
+
@classmethod
|
23
|
+
def get_validator(cls, name: str) -> Callable | None:
|
24
|
+
"""Get a validator class by name."""
|
25
|
+
return cls._validators.get(name)
|
26
|
+
|
27
|
+
@classmethod
|
28
|
+
def list_validators(cls) -> list[str]:
|
29
|
+
"""List all registered validator names."""
|
30
|
+
return list(cls._validators.keys())
|
31
|
+
|
32
|
+
|
33
|
+
# Global instance for function-based API
|
34
|
+
registry = ValidatorRegistry()
|
35
|
+
|
36
|
+
|
37
|
+
def get_validator(name: str) -> Callable:
    """Return the validator registered as *name*.

    Raises:
        KeyError: if no validator with that name has been registered.
    """
    found = registry.get_validator(name)
    if found is not None:
        return found
    raise KeyError(f"Validator '{name}' not found")
|
43
|
+
|
44
|
+
|
45
|
+
def list_validators() -> list[str]:
    """Return the names of every registered validator."""
    names = registry.list_validators()
    return names
|
48
|
+
|
49
|
+
|
50
|
+
def register_validator(name: str):
    """Build a decorator that records the decorated callable under *name*.

    The decorated callable is returned unchanged, so it stays directly
    callable after registration.
    """

    def _register(func: Callable) -> Callable:
        registry.register(name, func)
        return func

    return _register
|
@@ -0,0 +1,74 @@
|
|
1
|
+
"""Reproducibility metadata validator."""
|
2
|
+
|
3
|
+
from typing import Any
|
4
|
+
|
5
|
+
from .registry import register_validator
|
6
|
+
|
7
|
+
|
8
|
+
@register_validator("reproducibility_metadata")
def reproducibility_metadata(
    runs: list[dict[str, Any]], config: dict[str, Any]
) -> dict[str, Any]:
    """Validate presence of reproducibility metadata on every run.

    For each run:
    - A missing ``seed`` yields ``missing_seed``.
    - A seed not in ``config["expected_seeds"]`` yields ``unexpected_seed``.
    - A missing per-run ``config_digest`` yields ``per_run_digest_missing``.
    - A digest differing from ``config["config_digest"]`` yields
      ``config_digest_mismatch``.

    The summary counters use the same classification as the issues. A run
    without a seed is counted as missing, not also as unexpected. A run
    without a digest is not counted as a mismatch.
    """
    expected_seeds = config.get("expected_seeds", [])
    config_digest = config.get("config_digest", "")
    issues = []

    missing_seeds = 0
    unexpected_seeds = 0
    digest_mismatches = 0

    for i, run in enumerate(runs):
        seed = run.get("seed")
        if seed is None:
            missing_seeds += 1
            issues.append(
                {
                    "id": "missing_seed",
                    "severity": "warning",
                    "message": f"Run {i} missing seed for reproducibility",
                    "context": {"index": i},
                }
            )
        elif seed not in expected_seeds:
            unexpected_seeds += 1
            issues.append(
                {
                    "id": "unexpected_seed",
                    "severity": "warning",
                    "message": f"Run {i} has unexpected seed {seed}, expected {expected_seeds}",
                    "context": {"index": i, "seed": seed, "expected": expected_seeds},
                }
            )

        # Check for per-run config digest
        run_digest = run.get("config_digest")
        if run_digest is None:
            issues.append(
                {
                    "id": "per_run_digest_missing",
                    "severity": "info",
                    "message": f"Run {i} missing per-run config digest",
                    "context": {"index": i},
                }
            )
        elif run_digest != config_digest:
            digest_mismatches += 1
            issues.append(
                {
                    "id": "config_digest_mismatch",
                    "severity": "error",
                    "message": f"Run {i} config digest {run_digest} does not match expected {config_digest}",
                    "context": {
                        "index": i,
                        "expected": config_digest,
                        "actual": run_digest,
                    },
                }
            )

    return {
        "issues": issues,
        "summary": {
            "total_runs": len(runs),
            "missing_seeds": missing_seeds,
            "unexpected_seeds": unexpected_seeds,
            "digest_mismatches": digest_mismatches,
        },
    }
|
@@ -0,0 +1,59 @@
|
|
1
|
+
"""Schema adherence validator."""
|
2
|
+
|
3
|
+
from typing import Any
|
4
|
+
|
5
|
+
from .registry import register_validator
|
6
|
+
|
7
|
+
|
8
|
+
@register_validator("schema_adherence")
def schema_adherence(
    runs: list[dict[str, Any]], config: dict[str, Any]
) -> dict[str, Any]:
    """Validate that runs carry the fields required by a contract.

    ``config["contract"]["required"]`` is either a mapping of field name to
    type name, or a list/tuple of field names (presence only). Recognised
    type names are ``"int"``, ``"float"``, ``"number"``, ``"str"``,
    ``"bool"``, ``"dict"`` and ``"list"``. ``"float"``/``"number"`` accept
    ints, and the numeric types reject ``bool``. Unrecognised type names only
    check presence.

    Returns:
        ``{"issues": [...], "summary": {...}}``, with:
        - ``invalid_run_type`` / ``missing_field`` errors;
        - ``schema_type_mismatch`` warnings;
        - ``validation_errors`` counting all issues.
    """
    contract = config.get("contract", {})
    required_fields = contract.get("required", {})
    # A plain list of names means "must be present", with no type constraint.
    if isinstance(required_fields, (list, tuple)):
        required_fields = dict.fromkeys(required_fields)

    def _is_number(value: Any) -> bool:
        # bool subclasses int in Python, so it is excluded explicitly.
        return isinstance(value, (int, float)) and not isinstance(value, bool)

    type_checks = {
        "int": lambda v: isinstance(v, int) and not isinstance(v, bool),
        "float": _is_number,
        "number": _is_number,
        "str": lambda v: isinstance(v, str),
        "bool": lambda v: isinstance(v, bool),
        "dict": lambda v: isinstance(v, dict),
        "list": lambda v: isinstance(v, list),
    }

    issues = []

    for i, run in enumerate(runs):
        if not isinstance(run, dict):
            issues.append(
                {
                    "id": "invalid_run_type",
                    "severity": "error",
                    "message": f"Run {i} is not a dict: {type(run)}",
                    "context": {"index": i, "type": type(run)},
                }
            )
            continue

        for field_name, field_type in required_fields.items():
            if field_name not in run:
                issues.append(
                    {
                        "id": "missing_field",
                        "severity": "error",
                        "message": f"Missing required field '{field_name}' in run {i}",
                        "context": {"index": i, "field": field_name},
                    }
                )
                continue

            value = run[field_name]
            # Guard the lookup: a non-string (possibly unhashable) type spec
            # is treated as "presence only".
            check = type_checks.get(field_type) if isinstance(field_type, str) else None
            if check is not None and not check(value):
                issues.append(
                    {
                        "id": "schema_type_mismatch",
                        "severity": "warning",
                        "message": f"Field '{field_name}' in run {i} has type {type(value)} but expected {field_type}",
                        "context": {
                            "index": i,
                            "field": field_name,
                            "expected": field_type,
                            "actual": type(value),
                        },
                    }
                )

    return {
        "issues": issues,
        "summary": {"total_runs": len(runs), "validation_errors": len(issues)},
    }
|
@@ -0,0 +1,74 @@
|
|
1
|
+
"""Structural consistency validator."""
|
2
|
+
|
3
|
+
from typing import Any
|
4
|
+
|
5
|
+
from .registry import register_validator
|
6
|
+
|
7
|
+
|
8
|
+
@register_validator("structural_consistency")
def structural_consistency(
    runs: list[dict[str, Any]], config: dict[str, Any]
) -> dict[str, Any]:
    """Check that every run record has the expected shape.

    For each run, in this order:
    1. A non-dict record yields ``invalid_run_type``; nothing else is checked.
    2. Each absent core key yields ``missing_field``.
    3. A negative numeric ``duration_ms`` yields ``negative_duration``.
    4. A non-None ``output`` on a non-success status yields
       ``unexpected_output_on_failure``.

    ``config`` is accepted for interface uniformity but not read.
    """
    mandatory = (
        "scenario_key",
        "runner_key",
        "status",
        "duration_ms",
        "metrics",
        "output",
    )
    found = []

    for idx, record in enumerate(runs):
        if not isinstance(record, dict):
            found.append(
                {
                    "id": "invalid_run_type",
                    "severity": "error",
                    "message": f"Run {idx} is not a dict: {type(record)}",
                    "context": {"index": idx, "type": type(record)},
                }
            )
            continue

        found.extend(
            {
                "id": "missing_field",
                "severity": "error",
                "message": f"Missing required field '{name}' in run {idx}",
                "context": {"index": idx, "field": name},
            }
            for name in mandatory
            if name not in record
        )

        elapsed = record.get("duration_ms")
        is_negative = isinstance(elapsed, (int, float)) and elapsed < 0
        if is_negative:
            found.append(
                {
                    "id": "negative_duration",
                    "severity": "warning",
                    "message": f"Negative duration_ms {elapsed} in run {idx}",
                    "context": {"index": idx, "duration": elapsed},
                }
            )

        status = record.get("status")
        succeeded = status == "success"
        if not succeeded and record.get("output") is not None:
            found.append(
                {
                    "id": "unexpected_output_on_failure",
                    "severity": "info",
                    "message": f"Output present on non-success status '{status}' in run {idx}",
                    "context": {"index": idx, "status": status},
                }
            )

    return {
        "issues": found,
        "summary": {"total_runs": len(runs), "structural_issues": len(found)},
    }
|
fba_bench_core/config.py
ADDED
@@ -0,0 +1,154 @@
|
|
1
|
+
"""Typed configuration contracts for fba_bench_core.
|
2
|
+
|
3
|
+
Phase D:
|
4
|
+
- Introduces BaseAgentConfig and BaseServiceConfig as Pydantic models.
|
5
|
+
- Enforces typed metadata, forbids extra fields, and validates identifiers.
|
6
|
+
- Models are frozen (immutable) to prevent accidental mutation by consumers.
|
7
|
+
|
8
|
+
Downstream guidance:
|
9
|
+
- Subclass BaseAgentConfig / BaseServiceConfig to add domain-specific fields.
|
10
|
+
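- Use model_copy(update={...}) to create modified copies rather than mutating.
  Note that model_copy does not re-run validation on the updated values; when
  validation matters, rebuild via Model.model_validate({**obj.model_dump(), **changes}).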
|
11
|
+
"""
|
12
|
+
|
13
|
+
from __future__ import annotations
|
14
|
+
|
15
|
+
import re
|
16
|
+
|
17
|
+
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
18
|
+
|
19
|
+
# Allowed primitive metadata value types. Nested dicts/lists are rejected so
# that metadata cannot hide arbitrarily deep, effectively untyped structures.
# NOTE(review): keep the member order as-is; Pydantic union validation may take
# member order into account, so reordering is not guaranteed to be neutral.
Primitive = str | int | float | bool
|
22
|
+
|
23
|
+
|
24
|
+
def _validate_slug(value: str, field_name: str) -> str:
|
25
|
+
"""Ensure identifier uses a simple slug format (alphanum, hyphen, underscore)."""
|
26
|
+
if not isinstance(value, str) or not re.match(r"^[a-zA-Z0-9_-]+$", value):
|
27
|
+
raise ValueError(
|
28
|
+
f"{field_name!r} must be a slug (letters, digits, hyphen, underscore)."
|
29
|
+
)
|
30
|
+
return value
|
31
|
+
|
32
|
+
|
33
|
+
class BaseConfigModel(BaseModel):
    """Shared base for configuration models.

    Provides strict Pydantic model settings:
    - extra="forbid": reject unknown fields (prevents accidental additions).
    - validate_assignment=True: re-validate values on attribute assignment.
      Because instances are also frozen, assignment is rejected anyway; the
      setting only takes effect for subclasses that turn ``frozen`` off.
    - frozen=True: make instances immutable to avoid accidental runtime mutation.

    Note: ``model_copy(update=...)`` does not run validation on the updated
    values. To obtain a validated modified copy, re-validate explicitly, e.g.
    ``type(cfg).model_validate({**cfg.model_dump(), **changes})``.
    """

    model_config = ConfigDict(
        extra="forbid",
        validate_assignment=True,
        frozen=True,
    )
|
50
|
+
|
51
|
+
|
52
|
+
class BaseAgentConfig(BaseConfigModel):
    """Base configuration contract shared by agents.

    Fields:
        agent_id: Slug identifier (letters, digits, ``-``, ``_``) of the agent
            instance.
        poll_interval_seconds: Optional non-negative polling interval, in
            seconds, for agents that poll external systems; ``None`` if unused.
        max_concurrent_tasks: Optional non-negative concurrency hint for
            schedulers.
        default_region: Optional region/locale hint (e.g., "us-west-2").
        metadata: Flat mapping from string keys to primitive values
            (str, int, float, bool); nested dicts/lists are not accepted.

    Example:
        class PricingAgentConfig(BaseAgentConfig):
            pricing_tier: Literal["basic", "pro"] = "basic"
    """

    agent_id: str
    poll_interval_seconds: int | None = None
    max_concurrent_tasks: int | None = None
    default_region: str | None = None
    metadata: dict[str, Primitive] = Field(default_factory=dict)

    @field_validator("agent_id")
    @classmethod
    def _check_agent_id(cls, v: str) -> str:
        # Same slug rule as service identifiers.
        return _validate_slug(v, "agent_id")

    @field_validator("poll_interval_seconds", "max_concurrent_tasks")
    @classmethod
    def _non_negative_ints(cls, v: int | None) -> int | None:
        if v is not None and v < 0:
            raise ValueError("must be non-negative")
        return v

    @field_validator("metadata")
    @classmethod
    def _validate_metadata(cls, v: dict[str, Primitive]) -> dict[str, Primitive]:
        # Reject non-mapping input, then check each entry key-first.
        if isinstance(v, dict):
            for key, item in v.items():
                if not isinstance(key, str):
                    raise ValueError("metadata keys must be strings")
                if not isinstance(item, (str, int, float, bool)):
                    raise ValueError(
                        "metadata values must be primitive types (str, int, float, bool)"
                    )
            return v
        raise ValueError("metadata must be a mapping of str -> primitive values")
|
102
|
+
|
103
|
+
|
104
|
+
class BaseServiceConfig(BaseConfigModel):
    """Base configuration contract shared by services.

    Fields:
        service_id: Slug identifier (letters, digits, ``-``, ``_``) of the
            service instance.
        poll_interval_seconds, max_concurrent_tasks, default_region, metadata:
            same types and validation rules as the matching agent fields
            (non-negative ints, optional region hint, flat primitive metadata).

    Example:
        class CacheServiceConfig(BaseServiceConfig):
            ttl_seconds: int = 300
    """

    service_id: str
    poll_interval_seconds: int | None = None
    max_concurrent_tasks: int | None = None
    default_region: str | None = None
    metadata: dict[str, Primitive] = Field(default_factory=dict)

    @field_validator("service_id")
    @classmethod
    def _check_service_id(cls, v: str) -> str:
        # Same slug rule as agent identifiers.
        return _validate_slug(v, "service_id")

    @field_validator("poll_interval_seconds", "max_concurrent_tasks")
    @classmethod
    def _non_negative_ints(cls, v: int | None) -> int | None:
        if v is not None and v < 0:
            raise ValueError("must be non-negative")
        return v

    @field_validator("metadata")
    @classmethod
    def _validate_metadata(cls, v: dict[str, Primitive]) -> dict[str, Primitive]:
        # Mirrors the agent metadata rules: mapping only, str keys, primitive values.
        if isinstance(v, dict):
            for key, item in v.items():
                if not isinstance(key, str):
                    raise ValueError("metadata keys must be strings")
                if not isinstance(item, (str, int, float, bool)):
                    raise ValueError(
                        "metadata values must be primitive types (str, int, float, bool)"
                    )
            return v
        raise ValueError("metadata must be a mapping of str -> primitive values")
|
151
|
+
|
152
|
+
|
153
|
+
# End of file. Subclass these models in downstream packages to add domain-specific
|
154
|
+
# configuration while preserving validation and immutability guarantees.
|