stackfix 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cloudgym/__init__.py +3 -0
- cloudgym/benchmark/__init__.py +0 -0
- cloudgym/benchmark/dataset.py +188 -0
- cloudgym/benchmark/evaluator.py +275 -0
- cloudgym/cli.py +61 -0
- cloudgym/fixer/__init__.py +1 -0
- cloudgym/fixer/cli.py +521 -0
- cloudgym/fixer/detector.py +81 -0
- cloudgym/fixer/formatter.py +55 -0
- cloudgym/fixer/lambda_handler.py +126 -0
- cloudgym/fixer/repairer.py +237 -0
- cloudgym/generator/__init__.py +0 -0
- cloudgym/generator/formatter.py +142 -0
- cloudgym/generator/pipeline.py +271 -0
- cloudgym/inverter/__init__.py +0 -0
- cloudgym/inverter/_cf_injectors.py +705 -0
- cloudgym/inverter/_cf_utils.py +202 -0
- cloudgym/inverter/_hcl_utils.py +182 -0
- cloudgym/inverter/_tf_injectors.py +641 -0
- cloudgym/inverter/_yaml_cf.py +84 -0
- cloudgym/inverter/agentic.py +90 -0
- cloudgym/inverter/engine.py +258 -0
- cloudgym/inverter/programmatic.py +95 -0
- cloudgym/scraper/__init__.py +0 -0
- cloudgym/scraper/aws_samples.py +159 -0
- cloudgym/scraper/github.py +238 -0
- cloudgym/scraper/registry.py +165 -0
- cloudgym/scraper/validator.py +116 -0
- cloudgym/taxonomy/__init__.py +10 -0
- cloudgym/taxonomy/base.py +102 -0
- cloudgym/taxonomy/cloudformation.py +258 -0
- cloudgym/taxonomy/terraform.py +274 -0
- cloudgym/utils/__init__.py +0 -0
- cloudgym/utils/config.py +57 -0
- cloudgym/utils/ollama.py +66 -0
- cloudgym/validator/__init__.py +0 -0
- cloudgym/validator/cloudformation.py +55 -0
- cloudgym/validator/opentofu.py +103 -0
- cloudgym/validator/terraform.py +115 -0
- stackfix-0.1.0.dist-info/METADATA +182 -0
- stackfix-0.1.0.dist-info/RECORD +44 -0
- stackfix-0.1.0.dist-info/WHEEL +4 -0
- stackfix-0.1.0.dist-info/entry_points.txt +3 -0
- stackfix-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
"""AWS Lambda handler for IaC repair using GGUF model.
|
|
2
|
+
|
|
3
|
+
Deploy with the 0.5B model (~300MB) for sub-$0.001/invocation cost.
|
|
4
|
+
|
|
5
|
+
Model loading:
|
|
6
|
+
- Lambda layer: Place model at /opt/model.gguf
|
|
7
|
+
- S3: Set MODEL_S3_BUCKET and MODEL_S3_KEY env vars
|
|
8
|
+
- Bundled: Set MODEL_PATH env var
|
|
9
|
+
|
|
10
|
+
Example event:
|
|
11
|
+
{
|
|
12
|
+
"config": "resource aws_s3_bucket my_bucket {\\n acl = \\"public\\"\\n}",
|
|
13
|
+
"errors": ["Error: 'acl' argument is deprecated"],
|
|
14
|
+
"format": "terraform"
|
|
15
|
+
}
|
|
16
|
+
|
|
17
|
+
Response:
|
|
18
|
+
{
|
|
19
|
+
"repaired": "resource aws_s3_bucket my_bucket {\\n ...\\n}",
|
|
20
|
+
"original_errors": 1,
|
|
21
|
+
"verified": true,
|
|
22
|
+
"remaining_errors": 0
|
|
23
|
+
}
|
|
24
|
+
"""
|
|
25
|
+
|
|
26
|
+
from __future__ import annotations
|
|
27
|
+
|
|
28
|
+
import json
|
|
29
|
+
import logging
|
|
30
|
+
import os
|
|
31
|
+
import tempfile
|
|
32
|
+
from pathlib import Path
|
|
33
|
+
|
|
34
|
+
logger = logging.getLogger(__name__)
|
|
35
|
+
|
|
36
|
+
# Lazy-loaded repairer (persists across warm invocations)
|
|
37
|
+
_repairer = None
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _get_model_path() -> str:
|
|
41
|
+
"""Resolve the GGUF model path from environment or known locations."""
|
|
42
|
+
# Explicit path
|
|
43
|
+
if path := os.environ.get("MODEL_PATH"):
|
|
44
|
+
return path
|
|
45
|
+
|
|
46
|
+
# Lambda layer
|
|
47
|
+
layer_path = "/opt/model.gguf"
|
|
48
|
+
if Path(layer_path).exists():
|
|
49
|
+
return layer_path
|
|
50
|
+
|
|
51
|
+
# S3 download
|
|
52
|
+
bucket = os.environ.get("MODEL_S3_BUCKET")
|
|
53
|
+
key = os.environ.get("MODEL_S3_KEY")
|
|
54
|
+
if bucket and key:
|
|
55
|
+
local_path = "/tmp/model.gguf"
|
|
56
|
+
if not Path(local_path).exists():
|
|
57
|
+
import boto3
|
|
58
|
+
|
|
59
|
+
logger.info("Downloading model from s3://%s/%s", bucket, key)
|
|
60
|
+
s3 = boto3.client("s3")
|
|
61
|
+
s3.download_file(bucket, key, local_path)
|
|
62
|
+
logger.info("Model downloaded to %s", local_path)
|
|
63
|
+
return local_path
|
|
64
|
+
|
|
65
|
+
raise RuntimeError(
|
|
66
|
+
"No model found. Set MODEL_PATH, place model at /opt/model.gguf, "
|
|
67
|
+
"or set MODEL_S3_BUCKET + MODEL_S3_KEY."
|
|
68
|
+
)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def _get_repairer():
    """Return the process-wide GGUFRepairer, constructing it on first use.

    The instance is cached in a module global so warm Lambda invocations
    reuse the already-loaded model.
    """
    global _repairer
    if _repairer is not None:
        return _repairer

    from cloudgym.fixer.repairer import GGUFRepairer

    _repairer = GGUFRepairer(
        model_path=_get_model_path(),
        n_gpu_layers=0,  # Lambda has no GPU; run fully on CPU
        max_tokens=4096,
    )
    return _repairer
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
def handler(event, context=None):
    """AWS Lambda entry point for IaC repair.

    Args:
        event: Dict with 'config' (str), 'errors' (list[str]), optional 'format' (str).
        context: Lambda context (unused).

    Returns:
        Dict with 'repaired' (str), 'verified' (bool), etc.
    """
    config = event.get("config", "")
    if not config:
        return {"error": "Missing 'config' field", "statusCode": 400}

    errors = event.get("errors", [])
    if not errors:
        # Nothing to fix: echo the input back as already-valid.
        return {"repaired": config, "original_errors": 0, "verified": True, "remaining_errors": 0}

    repaired = _get_repairer().repair(config, errors)

    # Best-effort verification of the model's output; failures here must
    # not mask a successful repair, so they are logged and swallowed.
    verified, remaining = False, -1
    try:
        from cloudgym.fixer.detector import IaCFormat, validate_content_sync

        requested_format = event.get("format")
        fmt = IaCFormat(requested_format) if requested_format else None
        _, result = validate_content_sync(repaired, fmt)
        verified, remaining = result.valid, len(result.errors)
    except Exception:
        logger.exception("Verification failed")

    return {
        "repaired": repaired,
        "original_errors": len(errors),
        "verified": verified,
        "remaining_errors": remaining,
    }
|
|
@@ -0,0 +1,237 @@
|
|
|
1
|
+
"""Model-based IaC repair engine."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
logger = logging.getLogger(__name__)
|
|
9
|
+
|
|
10
|
+
# System prompt for repair mode: instructs the model to emit only the fixed
# configuration so the reply can be written straight back to a file (any
# stray fences are still removed by _strip_markdown_fences as a safety net).
REPAIR_SYSTEM_PROMPT = (
    "You are an expert Infrastructure-as-Code engineer. "
    "Fix the broken configuration below. Return ONLY the fixed configuration "
    "with no explanation, no markdown fences, and no comments about the fix."
)

# System prompt for discuss mode: asks for a structured, human-readable
# explanation of the errors rather than a machine-applyable fix.
DISCUSS_SYSTEM_PROMPT = (
    "You are an expert Infrastructure-as-Code engineer and teacher. "
    "Analyze the configuration and validation errors below. Explain:\n"
    "1. What each error means in plain language\n"
    "2. Why it's a problem (what would happen in production)\n"
    "3. How to fix it\n"
    "4. Best practices to avoid this in future\n"
    "Be concise but thorough."
)

# Default model/adapter paths (relative to package or absolute).
# Used by MLXRepairer when the caller does not override them.
DEFAULT_BASE_MODEL = "mlx-community/Qwen2.5-Coder-3B-Instruct-4bit"
DEFAULT_ADAPTER_PATH = "data/models/iac-repair-adapter"
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
def _strip_markdown_fences(text: str) -> str:
|
|
32
|
+
"""Strip markdown code fences and chat tokens from model output."""
|
|
33
|
+
for stop in ("<|im_end|>", "<|endoftext|>", "<|end|>"):
|
|
34
|
+
if stop in text:
|
|
35
|
+
text = text[:text.index(stop)]
|
|
36
|
+
|
|
37
|
+
text = text.strip()
|
|
38
|
+
lines = text.splitlines()
|
|
39
|
+
if lines and lines[0].startswith("```"):
|
|
40
|
+
lines = lines[1:]
|
|
41
|
+
if lines and lines[-1].strip() == "```":
|
|
42
|
+
lines = lines[:-1]
|
|
43
|
+
return "\n".join(lines)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def _build_prompt(broken_config: str, errors: list[str]) -> str:
|
|
47
|
+
"""Build the repair prompt from config content and error messages."""
|
|
48
|
+
error_text = "\n".join(errors) if errors else "Unknown error"
|
|
49
|
+
return (
|
|
50
|
+
f"This IaC configuration has validation errors:\n\n"
|
|
51
|
+
f"Errors:\n{error_text}\n\n"
|
|
52
|
+
f"Broken config:\n```\n{broken_config}\n```\n\n"
|
|
53
|
+
f"Return the fixed configuration:"
|
|
54
|
+
)
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def _build_discuss_prompt(config: str, errors: list[str]) -> str:
|
|
58
|
+
"""Build the discussion prompt for explaining errors."""
|
|
59
|
+
error_text = "\n".join(errors) if errors else "No validation errors"
|
|
60
|
+
return (
|
|
61
|
+
f"Analyze this IaC configuration:\n\n"
|
|
62
|
+
f"Validation errors:\n{error_text}\n\n"
|
|
63
|
+
f"Configuration:\n```\n{config}\n```\n\n"
|
|
64
|
+
f"Explain what's wrong and how to fix it:"
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class MLXRepairer:
    """Repair IaC configs using a local MLX fine-tuned model."""

    def __init__(
        self,
        base_model: str = DEFAULT_BASE_MODEL,
        adapter_path: str = DEFAULT_ADAPTER_PATH,
        temp: float = 0.3,
        max_tokens: int = 4096,
    ):
        self.base_model = base_model
        self.adapter_path = adapter_path
        self.temp = temp
        self.max_tokens = max_tokens
        # Model and tokenizer are loaded lazily on first generation.
        self._model = None
        self._tokenizer = None

    def _ensure_loaded(self):
        """Load model and tokenizer on first use; later calls are no-ops."""
        if self._model is None:
            from mlx_lm import load

            # Only pass the adapter if it actually exists on disk;
            # otherwise fall back to the bare base model.
            adapter = None
            if Path(self.adapter_path).exists():
                adapter = self.adapter_path
            logger.info("Loading model %s (adapter: %s)", self.base_model, adapter)
            self._model, self._tokenizer = load(self.base_model, adapter_path=adapter)
            logger.info("Model loaded.")

    def _generate(self, system_prompt: str, user_prompt: str) -> str:
        """Render the chat template and sample a completion."""
        self._ensure_loaded()

        from mlx_lm import generate
        from mlx_lm.sample_utils import make_sampler

        rendered = self._tokenizer.apply_chat_template(
            [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            tokenize=False,
            add_generation_prompt=True,
        )
        return generate(
            self._model,
            self._tokenizer,
            prompt=rendered,
            max_tokens=self.max_tokens,
            sampler=make_sampler(temp=self.temp),
        )

    def repair(self, broken_config: str, errors: list[str]) -> str:
        """Generate a repaired config from broken config + errors."""
        raw = self._generate(REPAIR_SYSTEM_PROMPT, _build_prompt(broken_config, errors))
        return _strip_markdown_fences(raw)

    def discuss(self, config: str, errors: list[str]) -> str:
        """Explain errors and suggest fixes in natural language."""
        return self._generate(DISCUSS_SYSTEM_PROMPT, _build_discuss_prompt(config, errors))
|
|
128
|
+
|
|
129
|
+
|
|
130
|
+
class GGUFRepairer:
    """Repair IaC configs using a GGUF model via llama-cpp-python (cross-platform CPU)."""

    def __init__(
        self,
        model_path: str,
        n_gpu_layers: int = 0,
        temp: float = 0.3,
        max_tokens: int = 4096,
        n_ctx: int = 8192,
    ):
        self.model_path = model_path
        self.n_gpu_layers = n_gpu_layers
        self.temp = temp
        self.max_tokens = max_tokens
        self.n_ctx = n_ctx
        # Populated lazily by _ensure_loaded(); None means not loaded.
        self._llm = None

    def _ensure_loaded(self):
        """Load the GGUF model on first use; later calls are no-ops."""
        if self._llm is None:
            from llama_cpp import Llama

            logger.info("Loading GGUF model from %s", self.model_path)
            llm_kwargs = {
                "model_path": self.model_path,
                "n_gpu_layers": self.n_gpu_layers,
                "n_ctx": self.n_ctx,
                "verbose": False,
                # chatml matches the Qwen-family instruction format.
                "chat_format": "chatml",
            }
            self._llm = Llama(**llm_kwargs)
            logger.info("GGUF model loaded.")

    def _chat(self, system_prompt: str, user_prompt: str) -> str:
        """Run one chat completion and return the assistant message text."""
        self._ensure_loaded()

        conversation = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        reply = self._llm.create_chat_completion(
            messages=conversation,
            temperature=self.temp,
            max_tokens=self.max_tokens,
        )
        first_choice = reply["choices"][0]
        return first_choice["message"]["content"]

    def repair(self, broken_config: str, errors: list[str]) -> str:
        """Generate a repaired config from broken config + errors."""
        raw = self._chat(REPAIR_SYSTEM_PROMPT, _build_prompt(broken_config, errors))
        return _strip_markdown_fences(raw)

    def discuss(self, config: str, errors: list[str]) -> str:
        """Explain errors and suggest fixes in natural language."""
        return self._chat(DISCUSS_SYSTEM_PROMPT, _build_discuss_prompt(config, errors))

    def unload(self):
        """Explicitly release model memory."""
        if self._llm is not None:
            del self._llm
            self._llm = None
            logger.info("GGUF model unloaded.")

    def __enter__(self):
        return self

    def __exit__(self, *args):
        # Release the model when leaving a `with` block.
        self.unload()
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
class OllamaRepairer:
    """Repair IaC configs using an Ollama-served model."""

    def __init__(self, model: str = "qwen2.5-coder:3b", base_url: str = "http://localhost:11434"):
        self.model = model
        self.base_url = base_url

    def _chat(self, system_prompt: str, user_prompt: str) -> str:
        """Send a chat request to Ollama."""
        import httpx

        payload = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
            "stream": False,
        }
        # Generous timeout: local models can take a while on large configs.
        resp = httpx.post(f"{self.base_url}/api/chat", json=payload, timeout=120.0)
        resp.raise_for_status()
        return resp.json()["message"]["content"]

    def repair(self, broken_config: str, errors: list[str]) -> str:
        """Generate a repaired config via Ollama API."""
        raw = self._chat(REPAIR_SYSTEM_PROMPT, _build_prompt(broken_config, errors))
        return _strip_markdown_fences(raw)

    def discuss(self, config: str, errors: list[str]) -> str:
        """Explain errors and suggest fixes in natural language."""
        return self._chat(DISCUSS_SYSTEM_PROMPT, _build_discuss_prompt(config, errors))
|
|
File without changes
|
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
"""Output formatter for training data (JSONL with train/val/test splits)."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import hashlib
|
|
6
|
+
import json
|
|
7
|
+
import logging
|
|
8
|
+
from dataclasses import asdict, dataclass, field
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
class TrainingRecord:
    """A single training record (gold + broken pair)."""

    id: str
    format: str  # "terraform" or "cloudformation"
    gold_config: str
    broken_config: str
    errors: list[str]
    warnings: list[str]
    fault_types: list[str]
    fault_description: str
    difficulty: str  # "low", "medium", "high"
    source: str  # "programmatic" or "agentic"
    split: str = ""  # "train", "val", or "test" — assigned by format_and_split
    gold_hash: str = ""  # Hash of gold config for dedup/split

    def __post_init__(self):
        # Derive a stable fingerprint of the gold config unless the caller
        # already supplied one; md5 is used for dedup only, not security.
        self.gold_hash = self.gold_hash or hashlib.md5(self.gold_config.encode()).hexdigest()
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def format_and_split(
    records: list[TrainingRecord],
    output_dir: str | Path,
    ratios: tuple[float, float, float] = (0.8, 0.1, 0.1),
) -> dict:
    """Split records by gold config and write JSONL files.

    Splitting is done by gold config hash (not by record) to prevent
    data leakage — all variants of the same gold config go to the same split.

    Args:
        records: List of TrainingRecord objects.
        output_dir: Directory to write train.jsonl, val.jsonl, test.jsonl.
        ratios: (train, val, test) split ratios.

    Returns:
        Metadata dict with counts and split info.
    """
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Group records by gold_hash so all variants travel together.
    gold_groups: dict[str, list[TrainingRecord]] = {}
    for record in records:
        gold_groups.setdefault(record.gold_hash, []).append(record)

    # Deterministic ordering by hash so reruns produce identical splits.
    sorted_hashes = sorted(gold_groups)
    total_golds = len(sorted_hashes)

    # Boundary indices; int() truncates, so "test" absorbs any remainder.
    train_end = int(total_golds * ratios[0])
    val_end = train_end + int(total_golds * ratios[1])

    # Assign splits by position in the sorted hash order.
    splits: dict[str, list[TrainingRecord]] = {"train": [], "val": [], "test": []}
    for idx, h in enumerate(sorted_hashes):
        if idx < train_end:
            split = "train"
        elif idx < val_end:
            split = "val"
        else:
            split = "test"
        for record in gold_groups[h]:
            record.split = split
            splits[split].append(record)

    # Write one JSONL file per split. Encoding is pinned to UTF-8 because
    # ensure_ascii=False emits raw non-ASCII characters, which would raise
    # UnicodeEncodeError on platforms whose default encoding (e.g. cp1252)
    # cannot represent them.
    counts = {}
    for split_name, split_records in splits.items():
        filepath = output_path / f"{split_name}.jsonl"
        with open(filepath, "w", encoding="utf-8") as f:
            for record in split_records:
                f.write(json.dumps(asdict(record), ensure_ascii=False) + "\n")
        counts[split_name] = len(split_records)
        logger.info("Wrote %d records to %s", len(split_records), filepath)

    # Summary metadata for the generated dataset.
    metadata = {
        "total_records": len(records),
        "total_gold_configs": total_golds,
        "splits": counts,
        "ratios": list(ratios),
        "fault_type_distribution": _count_fault_types(records),
        "format_distribution": _count_formats(records),
        "source_distribution": _count_sources(records),
        "difficulty_distribution": _count_difficulties(records),
    }

    meta_path = output_path / "metadata.json"
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)
    logger.info("Wrote metadata to %s", meta_path)

    return metadata
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def _count_fault_types(records: list[TrainingRecord]) -> dict[str, int]:
|
|
117
|
+
counts: dict[str, int] = {}
|
|
118
|
+
for r in records:
|
|
119
|
+
for ft in r.fault_types:
|
|
120
|
+
counts[ft] = counts.get(ft, 0) + 1
|
|
121
|
+
return dict(sorted(counts.items()))
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def _count_formats(records: list[TrainingRecord]) -> dict[str, int]:
|
|
125
|
+
counts: dict[str, int] = {}
|
|
126
|
+
for r in records:
|
|
127
|
+
counts[r.format] = counts.get(r.format, 0) + 1
|
|
128
|
+
return counts
|
|
129
|
+
|
|
130
|
+
|
|
131
|
+
def _count_sources(records: list[TrainingRecord]) -> dict[str, int]:
|
|
132
|
+
counts: dict[str, int] = {}
|
|
133
|
+
for r in records:
|
|
134
|
+
counts[r.source] = counts.get(r.source, 0) + 1
|
|
135
|
+
return counts
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def _count_difficulties(records: list[TrainingRecord]) -> dict[str, int]:
|
|
139
|
+
counts: dict[str, int] = {}
|
|
140
|
+
for r in records:
|
|
141
|
+
counts[r.difficulty] = counts.get(r.difficulty, 0) + 1
|
|
142
|
+
return counts
|