markback-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- markback/__init__.py +86 -0
- markback/cli.py +435 -0
- markback/config.py +181 -0
- markback/linter.py +312 -0
- markback/llm.py +175 -0
- markback/parser.py +587 -0
- markback/types.py +270 -0
- markback/workflow.py +351 -0
- markback/writer.py +249 -0
- markback-0.1.0.dist-info/METADATA +251 -0
- markback-0.1.0.dist-info/RECORD +14 -0
- markback-0.1.0.dist-info/WHEEL +4 -0
- markback-0.1.0.dist-info/entry_points.txt +2 -0
- markback-0.1.0.dist-info/licenses/LICENSE +21 -0
markback/types.py
ADDED
@@ -0,0 +1,270 @@
"""Core types for MarkBack format."""

from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Optional, Union
from urllib.parse import urlparse


class Severity(Enum):
    """Diagnostic severity levels."""
    ERROR = "error"
    WARNING = "warning"


class ErrorCode(Enum):
    """Lint error codes (MUST fix)."""
    E001 = "E001"  # Missing feedback (no <<< delimiter found)
    E002 = "E002"  # Multiple <<< delimiters in one record
    E003 = "E003"  # Malformed URI in @uri header
    E004 = "E004"  # Content after <<< delimiter
    E005 = "E005"  # Content present when @source specified
    E006 = "E006"  # Malformed header syntax
    E007 = "E007"  # Invalid JSON after json: prefix
    E008 = "E008"  # Unclosed quote in structured attribute value
    E009 = "E009"  # Empty feedback (nothing after <<<)
    E010 = "E010"  # Missing blank line before inline content


class WarningCode(Enum):
    """Lint warning codes (SHOULD fix)."""
    W001 = "W001"  # Duplicate URI within same file
    W002 = "W002"  # Unknown header keyword
    W003 = "W003"  # @source file not found
    W004 = "W004"  # Trailing whitespace on line
    W005 = "W005"  # Multiple blank lines
    W006 = "W006"  # Missing @uri (record has no identifier)
    W007 = "W007"  # Paired feedback file not found
    W008 = "W008"  # Non-canonical formatting detected


@dataclass
class Diagnostic:
    """A lint diagnostic message."""
    file: Optional[Path]
    line: Optional[int]
    column: Optional[int]
    severity: Severity
    code: Union[ErrorCode, WarningCode]
    message: str
    record_index: Optional[int] = None

    def __str__(self) -> str:
        parts = []
        if self.file:
            parts.append(str(self.file))
        if self.line is not None:
            parts.append(str(self.line))
        if self.column is not None:
            parts.append(str(self.column))

        location = ":".join(parts) if parts else "<unknown>"
        return f"{location}: {self.code.value} {self.message}"

    def to_dict(self) -> dict:
        """Convert to JSON-serializable dict."""
        return {
            "file": str(self.file) if self.file else None,
            "line": self.line,
            "column": self.column,
            "severity": self.severity.value,
            "code": self.code.value,
            "message": self.message,
            "record_index": self.record_index,
        }


@dataclass
class SourceRef:
    """Reference to external content (file path or URI)."""
    value: str
    is_uri: bool = False

    def __post_init__(self):
        # Determine if this is a URI or file path
        if not self.is_uri:
            parsed = urlparse(self.value)
            # Consider it a URI if it has a scheme that's not a Windows drive letter
            self.is_uri = bool(parsed.scheme) and len(parsed.scheme) > 1

    def resolve(self, base_path: Optional[Path] = None) -> Path:
        """Resolve to a file path (relative paths resolved against base_path)."""
        if self.is_uri:
            parsed = urlparse(self.value)
            if parsed.scheme == "file":
                # file:// URI
                return Path(parsed.path)
            raise ValueError(f"Cannot resolve non-file URI to path: {self.value}")

        path = Path(self.value)
        if path.is_absolute():
            return path
        if base_path:
            return base_path / path
        return path

    def __str__(self) -> str:
        return self.value

    def __eq__(self, other: object) -> bool:
        if isinstance(other, SourceRef):
            return self.value == other.value
        return False

    def __hash__(self) -> int:
        return hash(self.value)


@dataclass
class Record:
    """A MarkBack record containing content and feedback."""
    feedback: str
    uri: Optional[str] = None
    source: Optional[SourceRef] = None
    content: Optional[str] = None
    metadata: dict = field(default_factory=dict)

    # Parsing metadata (not part of logical record)
    _source_file: Optional[Path] = field(default=None, repr=False, compare=False)
    _start_line: Optional[int] = field(default=None, repr=False, compare=False)
    _end_line: Optional[int] = field(default=None, repr=False, compare=False)
    _is_compact: bool = field(default=False, repr=False, compare=False)

    def __post_init__(self):
        # Validate: must have either content or source
        if self.content is None and self.source is None:
            # This is allowed - feedback-only record
            pass

    def get_identifier(self) -> Optional[str]:
        """Get the record identifier (URI or source path)."""
        if self.uri:
            return self.uri
        if self.source:
            return str(self.source)
        return None

    def has_inline_content(self) -> bool:
        """Check if record has inline content (vs external source)."""
        return self.content is not None and len(self.content.strip()) > 0

    def to_dict(self) -> dict:
        """Convert to JSON-serializable dict."""
        return {
            "uri": self.uri,
            "source": str(self.source) if self.source else None,
            "content": self.content,
            "feedback": self.feedback,
            "metadata": self.metadata,
        }


@dataclass
class ParseResult:
    """Result of parsing a MarkBack file or set of files."""
    records: list[Record]
    diagnostics: list[Diagnostic]
    source_file: Optional[Path] = None

    @property
    def has_errors(self) -> bool:
        return any(d.severity == Severity.ERROR for d in self.diagnostics)

    @property
    def has_warnings(self) -> bool:
        return any(d.severity == Severity.WARNING for d in self.diagnostics)

    @property
    def error_count(self) -> int:
        return sum(1 for d in self.diagnostics if d.severity == Severity.ERROR)

    @property
    def warning_count(self) -> int:
        return sum(1 for d in self.diagnostics if d.severity == Severity.WARNING)


@dataclass
class FeedbackParsed:
    """Parsed structured feedback."""
    raw: str
    label: Optional[str] = None
    attributes: dict = field(default_factory=dict)
    comment: Optional[str] = None
    is_json: bool = False
    json_data: Optional[dict] = None


def parse_feedback(feedback: str) -> FeedbackParsed:
    """Parse feedback string into structured components.

    Supports:
    - Simple label: "positive"
    - Label + comment: "negative; use more formal language"
    - Attributes: "sentiment=positive; confidence=0.9"
    - Mixed: "good; quality=high; needs more detail"
    - JSON: "json:{...}"
    """
    import json as json_module

    result = FeedbackParsed(raw=feedback)

    # Check for JSON mode
    if feedback.startswith("json:"):
        result.is_json = True
        try:
            result.json_data = json_module.loads(feedback[5:])
        except json_module.JSONDecodeError:
            pass  # Invalid JSON, leave as raw
        return result

    # Split on "; " (semicolon + space)
    segments = []
    current = []
    in_quotes = False
    i = 0

    while i < len(feedback):
        char = feedback[i]

        if char == '"' and (i == 0 or feedback[i - 1] != '\\'):
            in_quotes = not in_quotes
            current.append(char)
        elif char == ';' and not in_quotes and i + 1 < len(feedback) and feedback[i + 1] == ' ':
            segments.append(''.join(current))
            current = []
            i += 1  # Skip the space after semicolon
        else:
            current.append(char)
        i += 1

    if current:
        segments.append(''.join(current))

    # Classify segments
    for segment in segments:
        segment = segment.strip()
        if not segment:
            continue

        if '=' in segment:
            # Key-value attribute
            eq_pos = segment.index('=')
            key = segment[:eq_pos]
            value = segment[eq_pos + 1:]
            # Remove quotes if present
            if value.startswith('"') and value.endswith('"'):
                value = value[1:-1].replace('\\"', '"').replace('\\\\', '\\')
            result.attributes[key] = value
        else:
            # Label or comment
            if result.label is None:
                result.label = segment
            else:
                # Additional non-attribute segment is a comment
                if result.comment:
                    result.comment += "; " + segment
                else:
                    result.comment = segment

    return result
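The parse_feedback docstring lists the feedback shapes the parser accepts. A minimal usage sketch (illustrative only, not part of the packaged files; it assumes the wheel is installed so that markback.types is importable):

from markback.types import parse_feedback

fb = parse_feedback("positive")                            # simple label
assert fb.label == "positive"

fb = parse_feedback("negative; use more formal language")  # label + comment
assert fb.label == "negative" and fb.comment == "use more formal language"

fb = parse_feedback("sentiment=positive; confidence=0.9")  # key=value attributes
assert fb.attributes == {"sentiment": "positive", "confidence": "0.9"}

fb = parse_feedback('json:{"score": 5}')                   # JSON mode
assert fb.is_json and fb.json_data == {"score": 5}

Note that attribute values stay strings; any numeric interpretation is left to the caller.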
markback/workflow.py
ADDED
@@ -0,0 +1,351 @@
"""Editor/Operator workflow for prompt refinement and evaluation."""

import json
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Optional

from .config import Config
from .llm import LLMClient, LLMClientFactory, LLMResponse
from .types import Record, parse_feedback


@dataclass
class WorkflowResult:
    """Result of a workflow run."""
    refined_prompt: str
    outputs: list[dict]  # {record_idx, output, ...}
    evaluation: dict
    timestamp: str = field(default_factory=lambda: datetime.now().isoformat())


@dataclass
class EvaluationResult:
    """Result of evaluating operator outputs against expected feedback."""
    total: int
    correct: int
    incorrect: int
    accuracy: float
    details: list[dict]  # Per-record evaluation details


# Default prompts
EDITOR_SYSTEM_PROMPT = """You are a prompt engineer. Your task is to refine and improve prompts based on examples and feedback.

Given:
1. An initial prompt (may be empty)
2. A set of examples with their expected outputs/labels
3. Feedback on what works and what doesn't

Produce an improved prompt that will help an LLM generate better outputs for similar examples.

Output ONLY the refined prompt, no explanations."""

EDITOR_USER_TEMPLATE = """Initial prompt:
{initial_prompt}

Examples and feedback:
{examples}

Based on this feedback, produce an improved prompt that addresses the issues noted."""

OPERATOR_SYSTEM_PROMPT = """Follow the instructions in the prompt exactly. Respond with the output only."""


def format_examples_for_editor(records: list[Record]) -> str:
    """Format records as examples for the editor prompt."""
    parts = []

    for i, record in enumerate(records):
        parts.append(f"--- Example {i + 1} ---")

        if record.content:
            parts.append(f"Input: {record.content[:500]}{'...' if len(record.content) > 500 else ''}")
        elif record.source:
            parts.append(f"Input: [from {record.source}]")

        # Parse feedback for structured info
        parsed = parse_feedback(record.feedback)

        if parsed.label:
            parts.append(f"Label: {parsed.label}")
        if parsed.comment:
            parts.append(f"Feedback: {parsed.comment}")
        if parsed.attributes:
            parts.append(f"Attributes: {parsed.attributes}")

        parts.append(f"Raw feedback: {record.feedback}")
        parts.append("")

    return "\n".join(parts)


def run_editor(
    client: LLMClient,
    initial_prompt: str,
    records: list[Record],
    system_prompt: Optional[str] = None,
) -> str:
    """Run the editor LLM to refine a prompt.

    Args:
        client: LLM client to use
        initial_prompt: The starting prompt (may be empty)
        records: Training records with content and feedback
        system_prompt: Optional custom system prompt

    Returns:
        Refined prompt string
    """
    examples = format_examples_for_editor(records)

    user_prompt = EDITOR_USER_TEMPLATE.format(
        initial_prompt=initial_prompt or "(No initial prompt provided)",
        examples=examples,
    )

    response = client.complete(
        prompt=user_prompt,
        system=system_prompt or EDITOR_SYSTEM_PROMPT,
    )

    return response.content.strip()


def run_operator(
    client: LLMClient,
    prompt: str,
    input_content: str,
    system_prompt: Optional[str] = None,
) -> str:
    """Run the operator LLM with a prompt on input content.

    Args:
        client: LLM client to use
        prompt: The prompt to apply
        input_content: The input content to process
        system_prompt: Optional custom system prompt

    Returns:
        Operator output
    """
    user_prompt = f"{prompt}\n\nInput:\n{input_content}"

    response = client.complete(
        prompt=user_prompt,
        system=system_prompt or OPERATOR_SYSTEM_PROMPT,
    )

    return response.content.strip()


def run_operator_batch(
    client: LLMClient,
    prompt: str,
    records: list[Record],
    system_prompt: Optional[str] = None,
) -> list[dict]:
    """Run the operator LLM on multiple records.

    Args:
        client: LLM client to use
        prompt: The prompt to apply
        records: Records to process
        system_prompt: Optional custom system prompt

    Returns:
        List of outputs with record info
    """
    outputs = []

    for i, record in enumerate(records):
        content = record.content or ""

        if not content and record.source:
            # Try to load content from source
            try:
                source_path = record.source.resolve()
                if source_path.exists():
                    content = source_path.read_text(encoding="utf-8")
            except Exception:
                content = f"[Content from {record.source}]"

        output = run_operator(client, prompt, content, system_prompt)

        outputs.append({
            "record_idx": i,
            "uri": record.uri,
            "output": output,
            "input_preview": content[:200] if content else None,
        })

    return outputs


def evaluate_outputs(
    outputs: list[dict],
    records: list[Record],
    config: Config,
) -> EvaluationResult:
    """Evaluate operator outputs against expected feedback.

    Simple evaluation:
    - Parse feedback for label
    - Check if output "matches" the expected label semantically

    For v1, we use a simple heuristic:
    - If feedback label is positive, output should not contain negative indicators
    - If feedback label is negative, output should acknowledge issues
    """
    positive_labels = set(config.positive_labels)
    negative_labels = set(config.negative_labels)

    details = []
    correct = 0
    incorrect = 0

    for output_info in outputs:
        idx = output_info["record_idx"]
        record = records[idx]
        operator_output = output_info["output"]

        # Parse expected feedback
        parsed = parse_feedback(record.feedback)
        expected_label = parsed.label.lower() if parsed.label else None

        # Determine expected sentiment
        expected_positive = expected_label in positive_labels if expected_label else None
        expected_negative = expected_label in negative_labels if expected_label else None

        # Simple output analysis
        output_lower = operator_output.lower()

        # Check for obvious positive/negative indicators in output
        output_has_positive = any(word in output_lower for word in ["good", "correct", "yes", "approved", "success"])
        output_has_negative = any(word in output_lower for word in ["bad", "wrong", "no", "error", "fail", "issue"])

        # Determine if match
        match = None
        if expected_positive is True:
            match = output_has_positive or not output_has_negative
        elif expected_negative is True:
            match = output_has_negative or not output_has_positive
        else:
            # Unknown expected sentiment - can't evaluate
            match = None

        if match is True:
            correct += 1
        elif match is False:
            incorrect += 1

        details.append({
            "record_idx": idx,
            "uri": record.uri,
            "expected_label": expected_label,
            "expected_positive": expected_positive,
            "operator_output_preview": operator_output[:200],
            "match": match,
        })

    total = len(outputs)
    evaluated = correct + incorrect
    accuracy = correct / evaluated if evaluated > 0 else 0.0

    return EvaluationResult(
        total=total,
        correct=correct,
        incorrect=incorrect,
        accuracy=accuracy,
        details=details,
    )


def run_workflow(
    config: Config,
    initial_prompt: str,
    records: list[Record],
    editor_client: Optional[LLMClient] = None,
    operator_client: Optional[LLMClient] = None,
) -> WorkflowResult:
    """Run the full editor/operator workflow.

    1. Editor refines the prompt using examples and feedback
    2. Operator runs the refined prompt on examples
    3. Evaluate operator performance

    Args:
        config: Configuration
        initial_prompt: Starting prompt (may be empty)
        records: Training records with content and feedback
        editor_client: LLM client for editor (created from config if None)
        operator_client: LLM client for operator (created from config if None)

    Returns:
        Workflow result with refined prompt, outputs, and evaluation
    """
    # Create clients if not provided
    if editor_client is None:
        if config.editor is None:
            raise ValueError("Editor LLM not configured")
        editor_client = LLMClientFactory.create(config.editor)

    if operator_client is None:
        if config.operator is None:
            raise ValueError("Operator LLM not configured")
        operator_client = LLMClientFactory.create(config.operator)

    # Step 1: Editor refines prompt
    refined_prompt = run_editor(editor_client, initial_prompt, records)

    # Step 2: Operator processes examples with refined prompt
    outputs = run_operator_batch(operator_client, refined_prompt, records)

    # Step 3: Evaluate
    evaluation = evaluate_outputs(outputs, records, config)

    return WorkflowResult(
        refined_prompt=refined_prompt,
        outputs=outputs,
        evaluation={
            "total": evaluation.total,
            "correct": evaluation.correct,
            "incorrect": evaluation.incorrect,
            "accuracy": evaluation.accuracy,
            "details": evaluation.details,
        },
    )


def save_workflow_result(
    result: WorkflowResult,
    output_path: Path,
    config: Config,
) -> Path:
    """Save workflow result to file.

    Args:
        result: The workflow result
        output_path: Base path for output
        config: Configuration (for file mode)

    Returns:
        Path to saved file
    """
    data = {
        "timestamp": result.timestamp,
        "refined_prompt": result.refined_prompt,
        "outputs": result.outputs,
        "evaluation": result.evaluation,
    }

    if config.file_mode == "versioned":
        # Add timestamp to filename
        stem = output_path.stem
        suffix = output_path.suffix
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        output_path = output_path.with_name(f"{stem}_{ts}{suffix}")

    output_path.write_text(json.dumps(data, indent=2), encoding="utf-8")
    return output_path
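run_workflow chains the three steps documented in its docstring (editor refines the prompt, operator applies it to each record, evaluate_outputs scores the results with the positive/negative label heuristic). A hedged end-to-end sketch, relying only on what workflow.py itself uses: a client whose complete(prompt=..., system=...) returns an object with a .content attribute, and a config exposing positive_labels / negative_labels. The EchoClient and SimpleNamespace config below are illustrative stand-ins, not the package's real LLMClient or Config classes:

from dataclasses import dataclass
from types import SimpleNamespace

from markback.types import Record
from markback.workflow import run_workflow

@dataclass
class EchoResponse:
    content: str

class EchoClient:
    # Stand-in for an LLM client; workflow.py only calls complete(prompt=..., system=...)
    def complete(self, prompt, system=None):
        return EchoResponse(content="good")

# Stand-in for Config: evaluate_outputs only reads positive_labels / negative_labels here,
# and run_workflow skips config.editor / config.operator when clients are passed in.
config = SimpleNamespace(positive_labels=["positive", "good"],
                         negative_labels=["negative", "bad"])

records = [
    Record(feedback="positive", content="The summary is accurate and concise.", uri="ex:1"),
    Record(feedback="negative; misses the main point", content="Lorem ipsum.", uri="ex:2"),
]

result = run_workflow(config, initial_prompt="Review the text.", records=records,
                      editor_client=EchoClient(), operator_client=EchoClient())
print(result.refined_prompt, result.evaluation["accuracy"])

With real clients, LLMClientFactory.create(...) builds them from the editor/operator sections of the configuration (see markback/llm.py and markback/config.py in the file list above).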