edsl 0.1.54__py3-none-any.whl → 0.1.56__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__init__.py +8 -1
- edsl/__init__original.py +134 -0
- edsl/__version__.py +1 -1
- edsl/agents/agent.py +29 -0
- edsl/agents/agent_list.py +36 -1
- edsl/base/base_class.py +281 -151
- edsl/base/data_transfer_models.py +15 -4
- edsl/buckets/__init__.py +8 -3
- edsl/buckets/bucket_collection.py +9 -3
- edsl/buckets/model_buckets.py +4 -2
- edsl/buckets/token_bucket.py +2 -2
- edsl/buckets/token_bucket_client.py +5 -3
- edsl/caching/cache.py +131 -62
- edsl/caching/cache_entry.py +70 -58
- edsl/caching/sql_dict.py +17 -0
- edsl/cli.py +99 -0
- edsl/config/config_class.py +16 -0
- edsl/conversation/__init__.py +31 -0
- edsl/coop/coop.py +276 -242
- edsl/coop/coop_jobs_objects.py +59 -0
- edsl/coop/coop_objects.py +29 -0
- edsl/coop/coop_regular_objects.py +26 -0
- edsl/coop/utils.py +24 -19
- edsl/dataset/dataset.py +338 -101
- edsl/dataset/dataset_operations_mixin.py +216 -180
- edsl/db_list/sqlite_list.py +349 -0
- edsl/inference_services/__init__.py +40 -5
- edsl/inference_services/exceptions.py +11 -0
- edsl/inference_services/services/anthropic_service.py +5 -2
- edsl/inference_services/services/aws_bedrock.py +6 -2
- edsl/inference_services/services/azure_ai.py +6 -2
- edsl/inference_services/services/google_service.py +7 -3
- edsl/inference_services/services/mistral_ai_service.py +6 -2
- edsl/inference_services/services/open_ai_service.py +6 -2
- edsl/inference_services/services/perplexity_service.py +6 -2
- edsl/inference_services/services/test_service.py +94 -5
- edsl/interviews/answering_function.py +167 -59
- edsl/interviews/interview.py +124 -72
- edsl/interviews/interview_task_manager.py +10 -0
- edsl/interviews/request_token_estimator.py +8 -0
- edsl/invigilators/invigilators.py +35 -13
- edsl/jobs/async_interview_runner.py +146 -104
- edsl/jobs/data_structures.py +6 -4
- edsl/jobs/decorators.py +61 -0
- edsl/jobs/fetch_invigilator.py +61 -18
- edsl/jobs/html_table_job_logger.py +14 -2
- edsl/jobs/jobs.py +180 -104
- edsl/jobs/jobs_component_constructor.py +2 -2
- edsl/jobs/jobs_interview_constructor.py +2 -0
- edsl/jobs/jobs_pricing_estimation.py +154 -113
- edsl/jobs/jobs_remote_inference_logger.py +4 -0
- edsl/jobs/jobs_runner_status.py +30 -25
- edsl/jobs/progress_bar_manager.py +79 -0
- edsl/jobs/remote_inference.py +35 -1
- edsl/key_management/key_lookup_builder.py +6 -1
- edsl/language_models/language_model.py +110 -12
- edsl/language_models/model.py +10 -3
- edsl/language_models/price_manager.py +176 -71
- edsl/language_models/registry.py +5 -0
- edsl/notebooks/notebook.py +77 -10
- edsl/questions/VALIDATION_README.md +134 -0
- edsl/questions/__init__.py +24 -1
- edsl/questions/exceptions.py +21 -0
- edsl/questions/question_dict.py +201 -16
- edsl/questions/question_multiple_choice_with_other.py +624 -0
- edsl/questions/question_registry.py +2 -1
- edsl/questions/templates/multiple_choice_with_other/__init__.py +0 -0
- edsl/questions/templates/multiple_choice_with_other/answering_instructions.jinja +15 -0
- edsl/questions/templates/multiple_choice_with_other/question_presentation.jinja +17 -0
- edsl/questions/validation_analysis.py +185 -0
- edsl/questions/validation_cli.py +131 -0
- edsl/questions/validation_html_report.py +404 -0
- edsl/questions/validation_logger.py +136 -0
- edsl/results/result.py +115 -46
- edsl/results/results.py +702 -171
- edsl/scenarios/construct_download_link.py +16 -3
- edsl/scenarios/directory_scanner.py +226 -226
- edsl/scenarios/file_methods.py +5 -0
- edsl/scenarios/file_store.py +150 -9
- edsl/scenarios/handlers/__init__.py +5 -1
- edsl/scenarios/handlers/mp4_file_store.py +104 -0
- edsl/scenarios/handlers/webm_file_store.py +104 -0
- edsl/scenarios/scenario.py +120 -101
- edsl/scenarios/scenario_list.py +800 -727
- edsl/scenarios/scenario_list_gc_test.py +146 -0
- edsl/scenarios/scenario_list_memory_test.py +214 -0
- edsl/scenarios/scenario_list_source_refactor.md +35 -0
- edsl/scenarios/scenario_selector.py +5 -4
- edsl/scenarios/scenario_source.py +1990 -0
- edsl/scenarios/tests/test_scenario_list_sources.py +52 -0
- edsl/surveys/survey.py +22 -0
- edsl/tasks/__init__.py +4 -2
- edsl/tasks/task_history.py +198 -36
- edsl/tests/scenarios/test_ScenarioSource.py +51 -0
- edsl/tests/scenarios/test_scenario_list_sources.py +51 -0
- edsl/utilities/__init__.py +2 -1
- edsl/utilities/decorators.py +121 -0
- edsl/utilities/memory_debugger.py +1010 -0
- {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/METADATA +51 -76
- {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/RECORD +103 -79
- edsl/jobs/jobs_runner_asyncio.py +0 -281
- edsl/language_models/unused/fake_openai_service.py +0 -60
- {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/LICENSE +0 -0
- {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/WHEEL +0 -0
- {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/entry_points.txt +0 -0
edsl/notebooks/notebook.py
CHANGED
@@ -2,6 +2,10 @@
|
|
2
2
|
|
3
3
|
from __future__ import annotations
|
4
4
|
import json
|
5
|
+
import subprocess
|
6
|
+
import tempfile
|
7
|
+
import os
|
8
|
+
import shutil
|
5
9
|
from typing import Dict, List, Optional, TYPE_CHECKING
|
6
10
|
|
7
11
|
if TYPE_CHECKING:
|
@@ -17,12 +21,56 @@ class Notebook(Base):
|
|
17
21
|
"""
|
18
22
|
|
19
23
|
default_name = "notebook"
|
24
|
+
|
25
|
+
@staticmethod
|
26
|
+
def _lint_code(code: str) -> str:
|
27
|
+
"""
|
28
|
+
Lint Python code using ruff.
|
29
|
+
|
30
|
+
:param code: The Python code to lint
|
31
|
+
:return: The linted code
|
32
|
+
"""
|
33
|
+
try:
|
34
|
+
# Check if ruff is installed
|
35
|
+
if shutil.which("ruff") is None:
|
36
|
+
# If ruff is not installed, return original code
|
37
|
+
return code
|
38
|
+
|
39
|
+
with tempfile.NamedTemporaryFile(mode='w+', suffix='.py', delete=False) as temp_file:
|
40
|
+
temp_file.write(code)
|
41
|
+
temp_file_path = temp_file.name
|
42
|
+
|
43
|
+
# Run ruff to format the code
|
44
|
+
try:
|
45
|
+
result = subprocess.run(
|
46
|
+
["ruff", "format", temp_file_path],
|
47
|
+
check=True,
|
48
|
+
stdout=subprocess.PIPE,
|
49
|
+
stderr=subprocess.PIPE
|
50
|
+
)
|
51
|
+
|
52
|
+
# Read the formatted code
|
53
|
+
with open(temp_file_path, 'r') as f:
|
54
|
+
linted_code = f.read()
|
55
|
+
|
56
|
+
return linted_code
|
57
|
+
except subprocess.CalledProcessError:
|
58
|
+
# If ruff fails, return the original code
|
59
|
+
return code
|
60
|
+
except FileNotFoundError:
|
61
|
+
# If ruff is not installed, return the original code
|
62
|
+
return code
|
63
|
+
finally:
|
64
|
+
# Clean up temporary file
|
65
|
+
if 'temp_file_path' in locals() and os.path.exists(temp_file_path):
|
66
|
+
os.unlink(temp_file_path)
|
20
67
|
|
21
68
|
def __init__(
|
22
69
|
self,
|
23
70
|
path: Optional[str] = None,
|
24
71
|
data: Optional[Dict] = None,
|
25
72
|
name: Optional[str] = None,
|
73
|
+
lint: bool = True,
|
26
74
|
):
|
27
75
|
"""
|
28
76
|
Initialize a new Notebook.
|
@@ -32,6 +80,7 @@ class Notebook(Base):
|
|
32
80
|
:param path: A filepath from which to load the notebook.
|
33
81
|
If no path is provided, assume this code is run in a notebook and try to load the current notebook from file.
|
34
82
|
:param name: A name for the Notebook.
|
83
|
+
:param lint: Whether to lint Python code cells using ruff. Defaults to True.
|
35
84
|
"""
|
36
85
|
import nbformat
|
37
86
|
|
@@ -54,6 +103,16 @@ class Notebook(Base):
|
|
54
103
|
raise NotebookEnvironmentError(
|
55
104
|
"Cannot create a notebook from within itself in this development environment"
|
56
105
|
)
|
106
|
+
|
107
|
+
# Store the lint parameter
|
108
|
+
self.lint = lint
|
109
|
+
|
110
|
+
# Apply linting to code cells if enabled
|
111
|
+
if self.lint and self.data and "cells" in self.data:
|
112
|
+
for cell in self.data["cells"]:
|
113
|
+
if cell.get("cell_type") == "code" and "source" in cell:
|
114
|
+
# Only lint Python code cells
|
115
|
+
cell["source"] = self._lint_code(cell["source"])
|
57
116
|
|
58
117
|
# TODO: perhaps add sanity check function
|
59
118
|
# 1. could check if the notebook is a valid notebook
|
@@ -63,7 +122,7 @@ class Notebook(Base):
|
|
63
122
|
self.name = name or self.default_name
|
64
123
|
|
65
124
|
@classmethod
|
66
|
-
def from_script(cls, path: str, name: Optional[str] = None) -> "Notebook":
|
125
|
+
def from_script(cls, path: str, name: Optional[str] = None, lint: bool = True) -> "Notebook":
|
67
126
|
import nbformat
|
68
127
|
|
69
128
|
# Read the script file
|
@@ -78,12 +137,12 @@ class Notebook(Base):
|
|
78
137
|
nb.cells.append(first_cell)
|
79
138
|
|
80
139
|
# Create a Notebook instance with the notebook data
|
81
|
-
notebook_instance = cls(nb)
|
140
|
+
notebook_instance = cls(data=nb, name=name, lint=lint)
|
82
141
|
|
83
142
|
return notebook_instance
|
84
143
|
|
85
144
|
@classmethod
|
86
|
-
def from_current_script(cls) -> "Notebook":
|
145
|
+
def from_current_script(cls, lint: bool = True) -> "Notebook":
|
87
146
|
import inspect
|
88
147
|
import os
|
89
148
|
|
@@ -93,7 +152,7 @@ class Notebook(Base):
|
|
93
152
|
current_file_path = os.path.abspath(caller_frame[1].filename)
|
94
153
|
|
95
154
|
# Use from_script to create the notebook
|
96
|
-
return cls.from_script(current_file_path)
|
155
|
+
return cls.from_script(current_file_path, lint=lint)
|
97
156
|
|
98
157
|
def __eq__(self, other):
|
99
158
|
"""
|
@@ -114,7 +173,7 @@ class Notebook(Base):
|
|
114
173
|
"""
|
115
174
|
Serialize to a dictionary.
|
116
175
|
"""
|
117
|
-
d = {"name": self.name, "data": self.data}
|
176
|
+
d = {"name": self.name, "data": self.data, "lint": self.lint}
|
118
177
|
if add_edsl_version:
|
119
178
|
from .. import __version__
|
120
179
|
|
@@ -124,11 +183,17 @@ class Notebook(Base):
|
|
124
183
|
|
125
184
|
@classmethod
|
126
185
|
@remove_edsl_version
|
127
|
-
def from_dict(cls, d: Dict) -> "Notebook":
|
186
|
+
def from_dict(cls, d: Dict, lint: bool = None) -> "Notebook":
|
128
187
|
"""
|
129
188
|
Convert a dictionary representation of a Notebook to a Notebook object.
|
189
|
+
|
190
|
+
:param d: Dictionary containing notebook data and name
|
191
|
+
:param lint: Whether to lint Python code cells. If None, uses the value from the dictionary or defaults to True.
|
192
|
+
:return: A new Notebook instance
|
130
193
|
"""
|
131
|
-
|
194
|
+
# Use the lint parameter from the dictionary if none is provided, otherwise default to True
|
195
|
+
notebook_lint = lint if lint is not None else d.get("lint", True)
|
196
|
+
return cls(data=d["data"], name=d["name"], lint=notebook_lint)
|
132
197
|
|
133
198
|
def to_file(self, path: str):
|
134
199
|
"""
|
@@ -205,11 +270,13 @@ class Notebook(Base):
|
|
205
270
|
return table
|
206
271
|
|
207
272
|
@classmethod
|
208
|
-
def example(cls, randomize: bool = False) -> Notebook:
|
273
|
+
def example(cls, randomize: bool = False, lint: bool = True) -> Notebook:
|
209
274
|
"""
|
210
275
|
Returns an example Notebook instance.
|
211
276
|
|
212
277
|
:param randomize: If True, adds a random string one of the cells' output.
|
278
|
+
:param lint: Whether to lint Python code cells. Defaults to True.
|
279
|
+
:return: An example Notebook instance
|
213
280
|
"""
|
214
281
|
addition = "" if not randomize else str(uuid4())
|
215
282
|
cells = [
|
@@ -238,7 +305,7 @@ class Notebook(Base):
|
|
238
305
|
"nbformat_minor": 4,
|
239
306
|
"cells": cells,
|
240
307
|
}
|
241
|
-
return cls(data=data)
|
308
|
+
return cls(data=data, lint=lint)
|
242
309
|
|
243
310
|
def code(self) -> List[str]:
|
244
311
|
"""
|
@@ -246,7 +313,7 @@ class Notebook(Base):
|
|
246
313
|
"""
|
247
314
|
lines = []
|
248
315
|
lines.append("from edsl import Notebook") # Keep as absolute for code generation
|
249
|
-
lines.append(f'nb = Notebook(data={self.data}, name="""{self.name}""")')
|
316
|
+
lines.append(f'nb = Notebook(data={self.data}, name="""{self.name}""", lint={self.lint})')
|
250
317
|
return lines
|
251
318
|
|
252
319
|
def to_latex(self, filename: str):
|
@@ -0,0 +1,134 @@
|
|
1
|
+
# EDSL Validation Logging System
|
2
|
+
|
3
|
+
This system logs validation failures that occur during question answering and provides tools for analyzing those failures, so that the "fix" methods of the various question types can be improved.
|
4
|
+
|
5
|
+
## Background
|
6
|
+
|
7
|
+
When a language model's response to a question fails validation (e.g., the response doesn't match the expected format or constraints), EDSL raises a `QuestionAnswerValidationError`. To help make the question types' "fix" methods more robust, we've added a system that logs these failures and analyzes them for common patterns.
|
8
|
+
|
9
|
+
## Features
|
10
|
+
|
11
|
+
- **Validation Logging**: Automatically logs each validation failure to a local file whenever a `QuestionAnswerValidationError` is created
|
12
|
+
- **Log Analysis**: Tools to analyze validation failures by question type and error message
|
13
|
+
- **Fix Method Suggestions**: Generates suggestions for improving fix methods based on common failure patterns
|
14
|
+
- **CLI Interface**: Command-line tools for managing and analyzing validation logs
|
15
|
+
|
16
|
+
## Usage
|
17
|
+
|
18
|
+
### Command Line Interface
|
19
|
+
|
20
|
+
The validation logging system is integrated with the EDSL CLI:
|
21
|
+
|
22
|
+
```bash
|
23
|
+
# Show recent validation failure logs
|
24
|
+
edsl validation logs
|
25
|
+
|
26
|
+
# Show recent logs filtered by question type
|
27
|
+
edsl validation logs --type QuestionMultipleChoice
|
28
|
+
|
29
|
+
# Save logs to a file
|
30
|
+
edsl validation logs --output validation_logs.json
|
31
|
+
|
32
|
+
# Clear all validation logs
|
33
|
+
edsl validation clear
|
34
|
+
|
35
|
+
# Show validation failure statistics
|
36
|
+
edsl validation stats
|
37
|
+
|
38
|
+
# Get suggestions for improving fix methods
|
39
|
+
edsl validation suggest
|
40
|
+
|
41
|
+
# Filter suggestions for a specific question type
|
42
|
+
edsl validation suggest --type QuestionMultipleChoice
|
43
|
+
|
44
|
+
# Generate a comprehensive JSON report
|
45
|
+
edsl validation report
|
46
|
+
|
47
|
+
# Generate an HTML report and open it in browser
|
48
|
+
edsl validation html-report
|
49
|
+
|
50
|
+
# Generate HTML report without opening browser
|
51
|
+
edsl validation html-report --no-open
|
52
|
+
```
|
53
|
+
|
54
|
+
You can also use the `make` command to generate reports:
|
55
|
+
|
56
|
+
```bash
|
57
|
+
# Generate and open HTML validation report
|
58
|
+
make validation-report
|
59
|
+
|
60
|
+
# Show validation statistics
|
61
|
+
make validation-stats
|
62
|
+
```
|
63
|
+
|
64
|
+
### Programmatic Usage
|
65
|
+
|
66
|
+
You can also use the validation logging system programmatically:
|
67
|
+
|
68
|
+
```python
|
69
|
+
from edsl.questions import (
|
70
|
+
log_validation_failure,
|
71
|
+
get_validation_failure_logs,
|
72
|
+
clear_validation_logs,
|
73
|
+
get_validation_failure_stats,
|
74
|
+
suggest_fix_improvements,
|
75
|
+
export_improvements_report,
|
76
|
+
generate_html_report,
|
77
|
+
generate_and_open_report
|
78
|
+
)
|
79
|
+
|
80
|
+
# Get recent validation failure logs
|
81
|
+
logs = get_validation_failure_logs(n=10)
|
82
|
+
|
83
|
+
# Get validation failure statistics
|
84
|
+
stats = get_validation_failure_stats()
|
85
|
+
|
86
|
+
# Get suggestions for improving fix methods
|
87
|
+
suggestions = suggest_fix_improvements()
|
88
|
+
|
89
|
+
# Generate a JSON report
|
90
|
+
report_path = export_improvements_report()
|
91
|
+
|
92
|
+
# Generate an HTML report
|
93
|
+
html_report_path = generate_html_report()
|
94
|
+
|
95
|
+
# Generate and open HTML report in browser
|
96
|
+
generate_and_open_report()
|
97
|
+
```
|
98
|
+
|
99
|
+
## Implementation Details
|
100
|
+
|
101
|
+
The validation logging system consists of the following components:
|
102
|
+
|
103
|
+
1. **Validation Logger**: Logs validation failures to a local file
|
104
|
+
2. **Validation Analysis**: Analyzes logs to identify patterns and suggest improvements
|
105
|
+
3. **HTML Report Generator**: Creates user-friendly HTML reports with visualizations
|
106
|
+
4. **CLI Integration**: Provides command-line tools for working with validation logs
|
107
|
+
|
108
|
+
### Log Format
|
109
|
+
|
110
|
+
Validation failure logs include the following information:
|
111
|
+
|
112
|
+
- Timestamp
|
113
|
+
- Question type and name
|
114
|
+
- Error message
|
115
|
+
- Invalid data that failed validation
|
116
|
+
- Model schema used for validation
|
117
|
+
- Question details (if available)
|
118
|
+
- Stack trace
|
119
|
+
|
120
|
+
### Storage Location
|
121
|
+
|
122
|
+
Logs are stored in the default EDSL log directory:
|
123
|
+
|
124
|
+
- Linux/macOS: `~/.edsl/logs/validation_failures.log`
|
125
|
+
- Windows: `%USERPROFILE%\.edsl\logs\validation_failures.log`
|
126
|
+
|
127
|
+
## Future Improvements
|
128
|
+
|
129
|
+
Potential future improvements to the validation logging system:
|
130
|
+
|
131
|
+
1. Integration with Coop for cloud storage and analysis of validation failures
|
132
|
+
2. Machine-learning-based suggestions for fix-method improvements
|
133
|
+
3. Automated tests built from common validation-failure patterns
|
134
|
+
4. A web-based dashboard for visualizing validation failure statistics
|
edsl/questions/__init__.py
CHANGED
@@ -34,6 +34,7 @@ Derived Question Types:
|
|
34
34
|
- QuestionLinearScale: Linear scale with customizable range and labels
|
35
35
|
- QuestionYesNo: Simple binary yes/no response
|
36
36
|
- QuestionTopK: Selection of top K items from a list of options
|
37
|
+
- QuestionMultipleChoiceWithOther: Multiple choice with option to specify "Other" custom response
|
37
38
|
|
38
39
|
Technical Architecture:
|
39
40
|
---------------------
|
@@ -124,9 +125,19 @@ from .question_likert_five import QuestionLikertFive
|
|
124
125
|
from .question_linear_scale import QuestionLinearScale
|
125
126
|
from .question_yes_no import QuestionYesNo
|
126
127
|
from .question_top_k import QuestionTopK
|
128
|
+
from .question_multiple_choice_with_other import QuestionMultipleChoiceWithOther
|
127
129
|
|
128
130
|
from .exceptions import QuestionScenarioRenderError
|
129
131
|
|
132
|
+
# Import validation modules
|
133
|
+
from .validation_logger import log_validation_failure, get_validation_failure_logs, clear_validation_logs
|
134
|
+
from .validation_analysis import (
|
135
|
+
get_validation_failure_stats,
|
136
|
+
suggest_fix_improvements,
|
137
|
+
export_improvements_report
|
138
|
+
)
|
139
|
+
from .validation_html_report import generate_html_report, generate_and_open_report
|
140
|
+
|
130
141
|
__all__ = [
|
131
142
|
# Exceptions
|
132
143
|
"QuestionScenarioRenderError",
|
@@ -156,4 +167,16 @@ __all__ = [
|
|
156
167
|
"QuestionTopK",
|
157
168
|
"QuestionLikertFive",
|
158
169
|
"QuestionYesNo",
|
159
|
-
|
170
|
+
"QuestionMultipleChoiceWithOther",
|
171
|
+
"QuestionMultipleChoiceWithOther",
|
172
|
+
|
173
|
+
# Validation utilities
|
174
|
+
"log_validation_failure",
|
175
|
+
"get_validation_failure_logs",
|
176
|
+
"clear_validation_logs",
|
177
|
+
"get_validation_failure_stats",
|
178
|
+
"suggest_fix_improvements",
|
179
|
+
"export_improvements_report",
|
180
|
+
"generate_html_report",
|
181
|
+
"generate_and_open_report",
|
182
|
+
]
|
edsl/questions/exceptions.py
CHANGED
@@ -72,6 +72,27 @@ class QuestionAnswerValidationError(QuestionErrors):
|
|
72
72
|
self.data = data
|
73
73
|
self.model = model
|
74
74
|
super().__init__(self.message)
|
75
|
+
|
76
|
+
# Log validation failure for analysis
|
77
|
+
try:
|
78
|
+
from .validation_logger import log_validation_failure
|
79
|
+
|
80
|
+
# Get question type and name if available
|
81
|
+
question_type = getattr(model, "question_type", "unknown")
|
82
|
+
question_name = getattr(model, "question_name", "unknown")
|
83
|
+
|
84
|
+
# Log the validation failure
|
85
|
+
log_validation_failure(
|
86
|
+
question_type=question_type,
|
87
|
+
question_name=question_name,
|
88
|
+
error_message=str(message),
|
89
|
+
invalid_data=data,
|
90
|
+
model_schema=model.model_json_schema(),
|
91
|
+
question_dict=getattr(model, "to_dict", lambda: None)(),
|
92
|
+
)
|
93
|
+
except Exception:
|
94
|
+
# Silently ignore logging errors to not disrupt normal operation
|
95
|
+
pass
|
75
96
|
|
76
97
|
def __str__(self):
|
77
98
|
if isinstance(self.message, ValidationError):
|
edsl/questions/question_dict.py
CHANGED
@@ -76,7 +76,6 @@ def create_dict_response(
|
|
76
76
|
field_definitions = {}
|
77
77
|
if len(value_types) == 0:
|
78
78
|
value_types = ["str"] * len(answer_keys) # Default to str if no types provided
|
79
|
-
|
80
79
|
for key, t_str in zip(answer_keys, value_types):
|
81
80
|
python_type = _parse_type_string(t_str)
|
82
81
|
field_definitions[key] = (python_type, Field(...))
|
@@ -193,27 +192,177 @@ class DictResponseValidator(ResponseValidatorABC):
|
|
193
192
|
'not a dictionary'
|
194
193
|
"""
|
195
194
|
# First try to separate dictionary from trailing comment if they're on the same line
|
195
|
+
original_response = response
|
196
196
|
if isinstance(response, str):
|
197
197
|
# Try to find where the dictionary ends and comment begins
|
198
198
|
try:
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
199
|
+
# Find the first opening brace
|
200
|
+
response_str = response.strip()
|
201
|
+
if response_str.startswith("{"):
|
202
|
+
# Count braces to find proper JSON ending
|
203
|
+
brace_count = 0
|
204
|
+
dict_end_pos = None
|
205
|
+
|
206
|
+
for i, char in enumerate(response_str):
|
207
|
+
if char == "{":
|
208
|
+
brace_count += 1
|
209
|
+
elif char == "}":
|
210
|
+
brace_count -= 1
|
211
|
+
if brace_count == 0:
|
212
|
+
dict_end_pos = i + 1
|
213
|
+
break
|
214
|
+
|
215
|
+
if dict_end_pos is not None:
|
216
|
+
dict_str = response_str[:dict_end_pos]
|
217
|
+
comment = response_str[dict_end_pos:].strip()
|
218
|
+
|
219
|
+
try:
|
220
|
+
answer_dict = ast.literal_eval(dict_str)
|
221
|
+
response = {
|
222
|
+
"answer": answer_dict,
|
223
|
+
"comment": comment if comment else None,
|
224
|
+
}
|
225
|
+
if verbose:
|
226
|
+
print(
|
227
|
+
f"Successfully split answer from comment. Comment length: {len(comment) if comment else 0}"
|
228
|
+
)
|
229
|
+
except (ValueError, SyntaxError) as e:
|
230
|
+
if verbose:
|
231
|
+
print(f"Failed to parse dictionary: {e}")
|
232
|
+
except Exception as e:
|
233
|
+
if verbose:
|
234
|
+
print(f"Exception during dictionary parsing: {e}")
|
212
235
|
|
213
236
|
# Continue with existing fix logic
|
214
237
|
if "answer" not in response or not isinstance(response["answer"], dict):
|
215
238
|
if verbose:
|
216
239
|
print("Cannot fix response: 'answer' field missing or not a dictionary")
|
240
|
+
|
241
|
+
# Special case: if we have the original string response, try a more direct parsing approach
|
242
|
+
if isinstance(
|
243
|
+
original_response, str
|
244
|
+
) and original_response.strip().startswith("{"):
|
245
|
+
try:
|
246
|
+
# Try to parse the JSON part directly, skipping nested comments
|
247
|
+
response_str = original_response.strip()
|
248
|
+
import json
|
249
|
+
|
250
|
+
# Find where the dict ends by tracking nested braces
|
251
|
+
brace_count = 0
|
252
|
+
dict_end_pos = None
|
253
|
+
|
254
|
+
for i, char in enumerate(response_str):
|
255
|
+
if char == "{":
|
256
|
+
brace_count += 1
|
257
|
+
elif char == "}":
|
258
|
+
brace_count -= 1
|
259
|
+
if brace_count == 0:
|
260
|
+
dict_end_pos = i + 1
|
261
|
+
break
|
262
|
+
|
263
|
+
if dict_end_pos is not None:
|
264
|
+
dict_str = response_str[:dict_end_pos]
|
265
|
+
comment = response_str[dict_end_pos:].strip()
|
266
|
+
|
267
|
+
# Try parsing with JSON first (faster but stricter)
|
268
|
+
try:
|
269
|
+
dict_str = dict_str.replace(
|
270
|
+
"'", '"'
|
271
|
+
) # Convert Python quotes to JSON quotes
|
272
|
+
dict_str = dict_str.replace("False", "false").replace(
|
273
|
+
"True", "true"
|
274
|
+
) # Fix booleans
|
275
|
+
answer_dict = json.loads(dict_str)
|
276
|
+
except json.JSONDecodeError:
|
277
|
+
# Fall back to ast.literal_eval (safer)
|
278
|
+
try:
|
279
|
+
answer_dict = ast.literal_eval(dict_str)
|
280
|
+
except (ValueError, SyntaxError):
|
281
|
+
if verbose:
|
282
|
+
print("Could not parse the dictionary part")
|
283
|
+
return original_response
|
284
|
+
|
285
|
+
# Now fix types
|
286
|
+
fixed_answer = {}
|
287
|
+
for key, type_str in zip(
|
288
|
+
self.answer_keys, getattr(self, "value_types", [])
|
289
|
+
):
|
290
|
+
if key in answer_dict:
|
291
|
+
value = answer_dict[key]
|
292
|
+
# Convert types
|
293
|
+
if type_str == "int" and not isinstance(value, int):
|
294
|
+
try:
|
295
|
+
fixed_answer[key] = int(value)
|
296
|
+
if verbose:
|
297
|
+
print(
|
298
|
+
f"Converted '{key}' from {type(value).__name__} to int"
|
299
|
+
)
|
300
|
+
except (ValueError, TypeError):
|
301
|
+
fixed_answer[key] = value
|
302
|
+
|
303
|
+
elif type_str == "float" and not isinstance(
|
304
|
+
value, float
|
305
|
+
):
|
306
|
+
try:
|
307
|
+
fixed_answer[key] = float(value)
|
308
|
+
if verbose:
|
309
|
+
print(
|
310
|
+
f"Converted '{key}' from {type(value).__name__} to float"
|
311
|
+
)
|
312
|
+
except (ValueError, TypeError):
|
313
|
+
fixed_answer[key] = value
|
314
|
+
|
315
|
+
elif (
|
316
|
+
type_str.startswith("list[") or type_str == "list"
|
317
|
+
) and not isinstance(value, list):
|
318
|
+
# Convert string to list by splitting
|
319
|
+
if isinstance(value, str):
|
320
|
+
items = [
|
321
|
+
item.strip() for item in value.split(",")
|
322
|
+
]
|
323
|
+
fixed_answer[key] = items
|
324
|
+
if verbose:
|
325
|
+
print(
|
326
|
+
f"Converted '{key}' from string to list: {items}"
|
327
|
+
)
|
328
|
+
else:
|
329
|
+
fixed_answer[key] = value
|
330
|
+
else:
|
331
|
+
fixed_answer[key] = value
|
332
|
+
else:
|
333
|
+
# Key not in answer, set a default
|
334
|
+
if type_str == "int":
|
335
|
+
fixed_answer[key] = 0
|
336
|
+
elif type_str == "float":
|
337
|
+
fixed_answer[key] = 0.0
|
338
|
+
elif type_str.startswith("list") or type_str == "list":
|
339
|
+
fixed_answer[key] = []
|
340
|
+
else:
|
341
|
+
fixed_answer[key] = ""
|
342
|
+
|
343
|
+
# Construct final fixed response
|
344
|
+
fixed_response = {
|
345
|
+
"answer": fixed_answer,
|
346
|
+
"comment": comment if comment else None,
|
347
|
+
"generated_tokens": None,
|
348
|
+
}
|
349
|
+
|
350
|
+
if verbose:
|
351
|
+
print(f"Directly fixed response with type conversion")
|
352
|
+
|
353
|
+
try:
|
354
|
+
# Try to validate
|
355
|
+
self.response_model.model_validate(fixed_response)
|
356
|
+
if verbose:
|
357
|
+
print("Successfully validated fixed response")
|
358
|
+
return fixed_response
|
359
|
+
except Exception as e:
|
360
|
+
if verbose:
|
361
|
+
print(f"Validation of direct fix failed: {e}")
|
362
|
+
except Exception as e:
|
363
|
+
if verbose:
|
364
|
+
print(f"Error during direct parsing: {e}")
|
365
|
+
|
217
366
|
return response
|
218
367
|
|
219
368
|
answer_dict = response["answer"]
|
@@ -246,7 +395,9 @@ class DictResponseValidator(ResponseValidatorABC):
|
|
246
395
|
except (ValueError, TypeError):
|
247
396
|
pass
|
248
397
|
|
249
|
-
elif
|
398
|
+
elif (
|
399
|
+
type_str.startswith("list[") or type_str == "list"
|
400
|
+
) and not isinstance(value, list):
|
250
401
|
# Try to convert string to list by splitting
|
251
402
|
if isinstance(value, str):
|
252
403
|
items = [item.strip() for item in value.split(",")]
|
@@ -267,7 +418,8 @@ class DictResponseValidator(ResponseValidatorABC):
|
|
267
418
|
fixed_response = {
|
268
419
|
"answer": fixed_answer,
|
269
420
|
"comment": response.get("comment"),
|
270
|
-
"generated_tokens": response.get("generated_tokens")
|
421
|
+
"generated_tokens": response.get("generated_tokens")
|
422
|
+
or response, # Ensure generated_tokens is captured
|
271
423
|
}
|
272
424
|
|
273
425
|
try:
|
@@ -279,6 +431,39 @@ class DictResponseValidator(ResponseValidatorABC):
|
|
279
431
|
except Exception as e:
|
280
432
|
if verbose:
|
281
433
|
print(f"Validation failed for fixed answer: {e}")
|
434
|
+
|
435
|
+
# If still failing, try one more time with default values for missing keys
|
436
|
+
if hasattr(self, "answer_keys") and hasattr(self, "value_types"):
|
437
|
+
for key, type_str in zip(
|
438
|
+
self.answer_keys, getattr(self, "value_types", [])
|
439
|
+
):
|
440
|
+
if key not in fixed_answer:
|
441
|
+
if type_str == "int":
|
442
|
+
fixed_answer[key] = 0
|
443
|
+
elif type_str == "float":
|
444
|
+
fixed_answer[key] = 0.0
|
445
|
+
elif type_str.startswith("list") or type_str == "list":
|
446
|
+
fixed_answer[key] = []
|
447
|
+
else:
|
448
|
+
fixed_answer[key] = ""
|
449
|
+
|
450
|
+
# Try again with all keys
|
451
|
+
fixed_response = {
|
452
|
+
"answer": fixed_answer,
|
453
|
+
"comment": response.get("comment"),
|
454
|
+
"generated_tokens": response.get("generated_tokens"),
|
455
|
+
}
|
456
|
+
|
457
|
+
try:
|
458
|
+
# Validate the fixed answer
|
459
|
+
self.response_model.model_validate(fixed_response)
|
460
|
+
if verbose:
|
461
|
+
print("Successfully fixed response with defaults")
|
462
|
+
return fixed_response
|
463
|
+
except Exception as e:
|
464
|
+
if verbose:
|
465
|
+
print(f"Validation still failed after adding defaults: {e}")
|
466
|
+
|
282
467
|
return response
|
283
468
|
|
284
469
|
valid_examples = [
|