edsl 0.1.39.dev1__py3-none-any.whl → 0.1.39.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +169 -116
- edsl/__init__.py +14 -6
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +358 -146
- edsl/agents/AgentList.py +211 -73
- edsl/agents/Invigilator.py +88 -36
- edsl/agents/InvigilatorBase.py +59 -70
- edsl/agents/PromptConstructor.py +117 -219
- edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
- edsl/agents/QuestionOptionProcessor.py +172 -0
- edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
- edsl/agents/__init__.py +0 -1
- edsl/agents/prompt_helpers.py +3 -3
- edsl/config.py +22 -2
- edsl/conversation/car_buying.py +2 -1
- edsl/coop/CoopFunctionsMixin.py +15 -0
- edsl/coop/ExpectedParrotKeyHandler.py +125 -0
- edsl/coop/PriceFetcher.py +1 -1
- edsl/coop/coop.py +104 -42
- edsl/coop/utils.py +14 -14
- edsl/data/Cache.py +21 -14
- edsl/data/CacheEntry.py +12 -15
- edsl/data/CacheHandler.py +33 -12
- edsl/data/__init__.py +4 -3
- edsl/data_transfer_models.py +2 -1
- edsl/enums.py +20 -0
- edsl/exceptions/__init__.py +50 -50
- edsl/exceptions/agents.py +12 -0
- edsl/exceptions/inference_services.py +5 -0
- edsl/exceptions/questions.py +24 -6
- edsl/exceptions/scenarios.py +7 -0
- edsl/inference_services/AnthropicService.py +0 -3
- edsl/inference_services/AvailableModelCacheHandler.py +184 -0
- edsl/inference_services/AvailableModelFetcher.py +209 -0
- edsl/inference_services/AwsBedrock.py +0 -2
- edsl/inference_services/AzureAI.py +0 -2
- edsl/inference_services/GoogleService.py +2 -11
- edsl/inference_services/InferenceServiceABC.py +18 -85
- edsl/inference_services/InferenceServicesCollection.py +105 -80
- edsl/inference_services/MistralAIService.py +0 -3
- edsl/inference_services/OpenAIService.py +1 -4
- edsl/inference_services/PerplexityService.py +0 -3
- edsl/inference_services/ServiceAvailability.py +135 -0
- edsl/inference_services/TestService.py +11 -8
- edsl/inference_services/data_structures.py +62 -0
- edsl/jobs/AnswerQuestionFunctionConstructor.py +188 -0
- edsl/jobs/Answers.py +1 -14
- edsl/jobs/FetchInvigilator.py +40 -0
- edsl/jobs/InterviewTaskManager.py +98 -0
- edsl/jobs/InterviewsConstructor.py +48 -0
- edsl/jobs/Jobs.py +102 -243
- edsl/jobs/JobsChecks.py +35 -10
- edsl/jobs/JobsComponentConstructor.py +189 -0
- edsl/jobs/JobsPrompts.py +5 -3
- edsl/jobs/JobsRemoteInferenceHandler.py +128 -80
- edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
- edsl/jobs/RequestTokenEstimator.py +30 -0
- edsl/jobs/buckets/BucketCollection.py +44 -3
- edsl/jobs/buckets/TokenBucket.py +53 -21
- edsl/jobs/buckets/TokenBucketAPI.py +211 -0
- edsl/jobs/buckets/TokenBucketClient.py +191 -0
- edsl/jobs/decorators.py +35 -0
- edsl/jobs/interviews/Interview.py +77 -380
- edsl/jobs/jobs_status_enums.py +9 -0
- edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +4 -49
- edsl/jobs/tasks/QuestionTaskCreator.py +21 -19
- edsl/jobs/tasks/TaskHistory.py +14 -15
- edsl/jobs/tasks/task_status_enum.py +0 -2
- edsl/language_models/ComputeCost.py +63 -0
- edsl/language_models/LanguageModel.py +137 -234
- edsl/language_models/ModelList.py +11 -13
- edsl/language_models/PriceManager.py +127 -0
- edsl/language_models/RawResponseHandler.py +106 -0
- edsl/language_models/ServiceDataSources.py +0 -0
- edsl/language_models/__init__.py +0 -1
- edsl/language_models/key_management/KeyLookup.py +63 -0
- edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
- edsl/language_models/key_management/KeyLookupCollection.py +38 -0
- edsl/language_models/key_management/__init__.py +0 -0
- edsl/language_models/key_management/models.py +131 -0
- edsl/language_models/registry.py +49 -59
- edsl/language_models/repair.py +2 -2
- edsl/language_models/utilities.py +5 -4
- edsl/notebooks/Notebook.py +19 -14
- edsl/notebooks/NotebookToLaTeX.py +142 -0
- edsl/prompts/Prompt.py +29 -39
- edsl/questions/AnswerValidatorMixin.py +47 -2
- edsl/questions/ExceptionExplainer.py +77 -0
- edsl/questions/HTMLQuestion.py +103 -0
- edsl/questions/LoopProcessor.py +149 -0
- edsl/questions/QuestionBase.py +37 -192
- edsl/questions/QuestionBaseGenMixin.py +52 -48
- edsl/questions/QuestionBasePromptsMixin.py +7 -3
- edsl/questions/QuestionCheckBox.py +1 -1
- edsl/questions/QuestionExtract.py +1 -1
- edsl/questions/QuestionFreeText.py +1 -2
- edsl/questions/QuestionList.py +3 -5
- edsl/questions/QuestionMatrix.py +265 -0
- edsl/questions/QuestionMultipleChoice.py +66 -22
- edsl/questions/QuestionNumerical.py +1 -3
- edsl/questions/QuestionRank.py +6 -16
- edsl/questions/ResponseValidatorABC.py +37 -11
- edsl/questions/ResponseValidatorFactory.py +28 -0
- edsl/questions/SimpleAskMixin.py +4 -3
- edsl/questions/__init__.py +1 -0
- edsl/questions/derived/QuestionLinearScale.py +6 -3
- edsl/questions/derived/QuestionTopK.py +1 -1
- edsl/questions/descriptors.py +17 -3
- edsl/questions/question_registry.py +1 -1
- edsl/questions/templates/matrix/__init__.py +1 -0
- edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
- edsl/questions/templates/matrix/question_presentation.jinja +20 -0
- edsl/results/CSSParameterizer.py +1 -1
- edsl/results/Dataset.py +170 -7
- edsl/results/DatasetExportMixin.py +224 -302
- edsl/results/DatasetTree.py +28 -8
- edsl/results/MarkdownToDocx.py +122 -0
- edsl/results/MarkdownToPDF.py +111 -0
- edsl/results/Result.py +192 -206
- edsl/results/Results.py +120 -113
- edsl/results/ResultsExportMixin.py +2 -0
- edsl/results/Selector.py +23 -13
- edsl/results/TableDisplay.py +98 -171
- edsl/results/TextEditor.py +50 -0
- edsl/results/__init__.py +1 -1
- edsl/results/smart_objects.py +96 -0
- edsl/results/table_data_class.py +12 -0
- edsl/results/table_renderers.py +118 -0
- edsl/scenarios/ConstructDownloadLink.py +109 -0
- edsl/scenarios/DirectoryScanner.py +96 -0
- edsl/scenarios/DocumentChunker.py +102 -0
- edsl/scenarios/DocxScenario.py +16 -0
- edsl/scenarios/FileStore.py +118 -239
- edsl/scenarios/PdfExtractor.py +40 -0
- edsl/scenarios/Scenario.py +90 -193
- edsl/scenarios/ScenarioHtmlMixin.py +4 -3
- edsl/scenarios/ScenarioJoin.py +10 -6
- edsl/scenarios/ScenarioList.py +383 -240
- edsl/scenarios/ScenarioListExportMixin.py +0 -7
- edsl/scenarios/ScenarioListPdfMixin.py +15 -37
- edsl/scenarios/ScenarioSelector.py +156 -0
- edsl/scenarios/__init__.py +1 -2
- edsl/scenarios/file_methods.py +85 -0
- edsl/scenarios/handlers/__init__.py +13 -0
- edsl/scenarios/handlers/csv.py +38 -0
- edsl/scenarios/handlers/docx.py +76 -0
- edsl/scenarios/handlers/html.py +37 -0
- edsl/scenarios/handlers/json.py +111 -0
- edsl/scenarios/handlers/latex.py +5 -0
- edsl/scenarios/handlers/md.py +51 -0
- edsl/scenarios/handlers/pdf.py +68 -0
- edsl/scenarios/handlers/png.py +39 -0
- edsl/scenarios/handlers/pptx.py +105 -0
- edsl/scenarios/handlers/py.py +294 -0
- edsl/scenarios/handlers/sql.py +313 -0
- edsl/scenarios/handlers/sqlite.py +149 -0
- edsl/scenarios/handlers/txt.py +33 -0
- edsl/study/ObjectEntry.py +1 -1
- edsl/study/SnapShot.py +1 -1
- edsl/study/Study.py +5 -12
- edsl/surveys/ConstructDAG.py +92 -0
- edsl/surveys/EditSurvey.py +221 -0
- edsl/surveys/InstructionHandler.py +100 -0
- edsl/surveys/MemoryManagement.py +72 -0
- edsl/surveys/Rule.py +5 -4
- edsl/surveys/RuleCollection.py +25 -27
- edsl/surveys/RuleManager.py +172 -0
- edsl/surveys/Simulator.py +75 -0
- edsl/surveys/Survey.py +199 -771
- edsl/surveys/SurveyCSS.py +20 -8
- edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +11 -9
- edsl/surveys/SurveyToApp.py +141 -0
- edsl/surveys/__init__.py +4 -2
- edsl/surveys/descriptors.py +6 -2
- edsl/surveys/instructions/ChangeInstruction.py +1 -2
- edsl/surveys/instructions/Instruction.py +4 -13
- edsl/surveys/instructions/InstructionCollection.py +11 -6
- edsl/templates/error_reporting/interview_details.html +1 -1
- edsl/templates/error_reporting/report.html +1 -1
- edsl/tools/plotting.py +1 -1
- edsl/utilities/PrettyList.py +56 -0
- edsl/utilities/is_notebook.py +18 -0
- edsl/utilities/is_valid_variable_name.py +11 -0
- edsl/utilities/remove_edsl_version.py +24 -0
- edsl/utilities/utilities.py +35 -23
- {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/METADATA +12 -10
- edsl-0.1.39.dev2.dist-info/RECORD +352 -0
- edsl/language_models/KeyLookup.py +0 -30
- edsl/language_models/unused/ReplicateBase.py +0 -83
- edsl/results/ResultsDBMixin.py +0 -238
- edsl-0.1.39.dev1.dist-info/RECORD +0 -277
- {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/LICENSE +0 -0
- {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/WHEEL +0 -0
@@ -0,0 +1,313 @@
|
|
1
|
+
from edsl.scenarios.file_methods import FileMethods
|
2
|
+
import tempfile
|
3
|
+
import re
|
4
|
+
from typing import List, Optional
|
5
|
+
import textwrap
|
6
|
+
|
7
|
+
|
8
|
+
class SqlMethods(FileMethods):
    """File-handler methods for plain-text SQL scripts (``.sql`` files).

    Every method operates on the file at ``self.path``. The formatting and
    analysis helpers are deliberately lightweight heuristics built on regular
    expressions; they are not a SQL parser.
    """

    suffix = "sql"

    # One level of indentation emitted by the formatter.
    _INDENT = "    "

    # Words that ``_format_keywords`` upper-cases.
    _KEYWORDS = frozenset(
        {
            "select",
            "from",
            "where",
            "and",
            "or",
            "insert",
            "update",
            "delete",
            "create",
            "drop",
            "alter",
            "table",
            "into",
            "values",
            "group",
            "by",
            "having",
            "order",
            "limit",
            "join",
            "left",
            "right",
            "inner",
            "outer",
            "on",
            "as",
            "distinct",
            "count",
            "sum",
            "avg",
            "max",
            "min",
            "between",
            "like",
            "in",
            "is",
            "null",
            "not",
            "case",
            "when",
            "then",
            "else",
            "end",
        }
    )

    # Lexical units the keyword formatter distinguishes. Literals, quoted
    # identifiers and comments are matched first so their text is never edited.
    _TOKEN_RE = re.compile(
        r"'(?:[^']|'')*'"  # single-quoted string literal ('' escapes a quote)
        r'|"(?:[^"]|"")*"'  # double-quoted identifier
        r"|--[^\n]*"  # line comment
        r"|[A-Za-z_][A-Za-z0-9_]*"  # bare word
    )

    # Text that must be ignored when scanning for structure.
    _LITERAL_RE = re.compile(r"'(?:[^']|'')*'")
    _NOISE_RE = re.compile(r"'(?:[^']|'')*'|--[^\n]*")

    # A line that *starts* with one of these opens a clause whose continuation
    # lines are indented one extra level.
    _CLAUSE_RE = re.compile(
        r"(?:select|from|where|group\s+by|having|order\s+by)\b", re.IGNORECASE
    )

    # At least one of these is expected in every complete statement.
    _STATEMENT_RE = re.compile(
        r"\b(?:SELECT|INSERT|UPDATE|DELETE|CREATE|DROP|ALTER)\b", re.IGNORECASE
    )

    # Each pattern captures the identifier that follows a table-introducing
    # keyword. ``\b`` stops matches inside longer words such as ``date_from``.
    _TABLE_PATTERNS = tuple(
        re.compile(pattern, re.IGNORECASE)
        for pattern in (
            r"\bFROM\s+([a-zA-Z_][a-zA-Z0-9_]*)",
            r"\bJOIN\s+([a-zA-Z_][a-zA-Z0-9_]*)",
            r"\bUPDATE\s+([a-zA-Z_][a-zA-Z0-9_]*)",
            r"\bINSERT\s+INTO\s+([a-zA-Z_][a-zA-Z0-9_]*)",
            r"\bCREATE\s+TABLE\s+(?:IF\s+NOT\s+EXISTS\s+)?([a-zA-Z_][a-zA-Z0-9_]*)",
        )
    )

    def view_system(self):
        """Open the SQL file with the operating system's default application.

        Prints a message instead of raising when the file is missing or the
        opener fails, so interactive use is never interrupted.
        """
        import os
        import subprocess
        import sys

        if not os.path.exists(self.path):
            print("SQL file was not found.")
            return

        try:
            # ``os.name`` is "posix" on both macOS and Linux, so macOS has to
            # be recognised through ``sys.platform``; otherwise Linux would be
            # sent to the macOS-only ``open`` command.
            if sys.platform == "darwin":
                subprocess.run(["open", self.path], check=True)  # macOS
            elif os.name == "nt":
                os.startfile(self.path)  # Windows
            else:
                subprocess.run(["xdg-open", self.path], check=True)  # Linux
        except Exception as e:  # best-effort viewer: report, never raise
            print(f"Error opening SQL file: {e}")

    def view_notebook(self):
        """Display the SQL file, syntax-highlighted, in a Jupyter notebook.

        Renders the file contents with pygments (``monokai`` style) followed
        by a link to the file. Errors are printed rather than raised.
        """
        from IPython.display import FileLink, display, HTML
        import pygments
        from pygments.lexers import SqlLexer
        from pygments.formatters import HtmlFormatter

        try:
            with open(self.path, "r", encoding="utf-8") as f:
                content = f.read()

            formatter = HtmlFormatter(style="monokai")
            highlighted_sql = pygments.highlight(content, SqlLexer(), formatter)
            css = formatter.get_style_defs(".highlight")
            display(HTML(f"<style>{css}</style>{highlighted_sql}"))
            display(FileLink(self.path))
        except Exception as e:
            print(f"Error displaying SQL: {e}")

    def _format_keywords(self, sql: str) -> str:
        """Return *sql* with the words in ``_KEYWORDS`` upper-cased.

        Only bare words are considered, so a keyword directly followed by
        punctuation (``count(*)``) is still recognised. String literals,
        double-quoted identifiers, ``--`` comments, other identifiers and all
        whitespace are returned unchanged.
        """

        def _upper_if_keyword(match):
            token = match.group(0)
            # Literals, quoted identifiers and comments start with ', " or -.
            if token[0] in "'\"-":
                return token
            return token.upper() if token.lower() in self._KEYWORDS else token

        return self._TOKEN_RE.sub(_upper_if_keyword, sql)

    def _indent_sql(self, sql: str) -> str:
        """Add basic indentation to a SQL script, line by line.

        Rules (heuristic):

        * a line ending in ``(`` indents the following lines one level until
          a line that starts with ``)``;
        * a line starting with SELECT / FROM / WHERE / GROUP BY / HAVING /
          ORDER BY is placed at the current parenthesis depth, and the lines
          continuing that clause are indented one extra level;
        * a line ending in ``;`` ends the statement and the clause indent;
        * ``--`` comment lines are indented like their surroundings but never
          change the indentation state; blank lines stay blank.

        Existing leading/trailing whitespace on each line is discarded.
        """
        indented_lines = []
        depth = 0  # nesting level of "(" ... ")" blocks that span lines
        in_clause = False  # True while continuing a SELECT/FROM/... clause

        for raw_line in sql.split("\n"):
            line = raw_line.strip()

            if not line:
                indented_lines.append("")
                continue

            if line.startswith("--"):
                level = depth + (1 if in_clause else 0)
                indented_lines.append(self._INDENT * level + line)
                continue

            closes_block = line.startswith(")")
            if closes_block:
                depth = max(0, depth - 1)

            starts_clause = self._CLAUSE_RE.match(line) is not None

            level = depth
            if in_clause and not starts_clause and not closes_block:
                level += 1
            indented_lines.append(self._INDENT * level + line)

            if starts_clause:
                in_clause = True
            if line.endswith(";"):
                in_clause = False
            if line.endswith("("):
                depth += 1

        return "\n".join(indented_lines)

    def format_sql(self) -> bool:
        """Format the SQL file in place: keyword capitalization and indentation.

        The original line structure is kept (so ``--`` comments stay on their
        own lines and cannot swallow the SQL that follows them). Lines longer
        than 80 characters are wrapped unless they contain a quote or a
        comment marker, because breaking those could alter a literal or turn
        comment text into code.

        NOTE: lines are stripped individually, so a string literal that spans
        several lines will lose the whitespace around its line breaks.

        Returns:
            True if the file was rewritten, False if any error occurred (the
            error is printed).
        """
        try:
            with open(self.path, "r", encoding="utf-8") as f:
                content = f.read()

            content = self._format_keywords(content)
            content = self._indent_sql(content)

            wrapped_content = []
            for line in content.split("\n"):
                wrappable = not any(mark in line for mark in ("'", '"', "--"))
                if len(line) > 80 and wrappable:
                    lead = line[: len(line) - len(line.lstrip())]
                    line = textwrap.fill(
                        line, width=80, subsequent_indent=lead + self._INDENT
                    )
                wrapped_content.append(line)

            formatted_sql = "\n".join(wrapped_content)

            with open(self.path, "w", encoding="utf-8") as f:
                f.write(formatted_sql)

            return True
        except Exception as e:
            print(f"Error formatting SQL: {e}")
            return False

    @staticmethod
    def _split_statement_text(content: str) -> List[str]:
        """Split SQL text into statements on line-ending ``;`` or ``GO`` lines.

        Blank lines and lines starting with ``--`` are dropped. The lines of
        one statement are joined with single spaces and the terminating
        semicolon is removed. Text after the last terminator is returned as a
        final statement.
        """
        statements = []
        current_stmt = []

        for raw_line in content.split("\n"):
            line = raw_line.strip()

            # Skip empty lines and comments
            if not line or line.startswith("--"):
                continue

            if line.endswith(";"):
                current_stmt.append(line[:-1])  # Remove semicolon
                statements.append(" ".join(current_stmt))
                current_stmt = []
            elif line.upper() == "GO":
                if current_stmt:
                    statements.append(" ".join(current_stmt))
                    current_stmt = []
            else:
                current_stmt.append(line)

        # Add any remaining statement
        if current_stmt:
            statements.append(" ".join(current_stmt))

        return [stmt.strip() for stmt in statements if stmt.strip()]

    def split_statements(self) -> List[str]:
        """Split the SQL file into individual statements.

        Returns:
            The statements without their terminators, or an empty list if the
            file could not be read (the error is printed).
        """
        try:
            with open(self.path, "r", encoding="utf-8") as f:
                content = f.read()
            return self._split_statement_text(content)
        except Exception as e:
            print(f"Error splitting SQL statements: {e}")
            return []

    def validate_basic_syntax(self) -> bool:
        """
        Perform basic SQL syntax validation.
        This is a simple check and doesn't replace proper SQL parsing.

        For every statement: warn (without failing) when no common statement
        keyword is present, and fail when single quotes or parentheses are
        unbalanced. Parentheses inside string literals are ignored.

        Returns:
            True if every statement passes, False on the first failure or if
            the file cannot be read (a message is printed in both cases).
        """
        try:
            # Read here rather than via split_statements(), which swallows
            # read errors and would make a missing file look valid.
            with open(self.path, "r", encoding="utf-8") as f:
                content = f.read()

            for stmt in self._split_statement_text(content):
                # Check for basic SQL keywords
                if not self._STATEMENT_RE.search(stmt):
                    print(f"Warning: Statement might be incomplete: {stmt}")

                # Check for basic quote matching
                if stmt.count("'") % 2 != 0:
                    print(f"Error: Unmatched quotes in statement: {stmt}")
                    return False

                # Check for basic parentheses matching, outside literals only
                code_only = self._LITERAL_RE.sub("''", stmt)
                if code_only.count("(") != code_only.count(")"):
                    print(f"Error: Unmatched parentheses in statement: {stmt}")
                    return False

            return True
        except Exception as e:
            print(f"Error validating SQL: {e}")
            return False

    def extract_table_names(self) -> List[str]:
        """Extract table names from the SQL file.

        Collects the identifier following FROM, JOIN, UPDATE, INSERT INTO and
        CREATE TABLE [IF NOT EXISTS] (case-insensitive). String literals and
        ``--`` comments are removed first so prose such as
        ``-- Update salary`` is not mistaken for a statement.

        Returns:
            The distinct names, sorted; an empty list if the file could not
            be read (the error is printed).
        """
        tables = set()
        try:
            with open(self.path, "r", encoding="utf-8") as f:
                content = f.read()

            code_only = self._NOISE_RE.sub(" ", content)
            for pattern in self._TABLE_PATTERNS:
                tables.update(pattern.findall(code_only))

            return sorted(tables)
        except Exception as e:
            print(f"Error extracting table names: {e}")
            return []

    def example(self):
        """Write a sample SQL script to a temporary file and return its path.

        The file is created with ``delete=False``; removing it is the
        caller's responsibility.
        """
        sample_sql = """-- Sample SQL file with common operations
CREATE TABLE employees (
    id INTEGER PRIMARY KEY,
    name VARCHAR(100) NOT NULL,
    department VARCHAR(50),
    salary DECIMAL(10,2),
    hire_date DATE
);

INSERT INTO employees (name, department, salary, hire_date)
VALUES
    ('John Doe', 'Engineering', 75000.00, '2023-01-15'),
    ('Jane Smith', 'Marketing', 65000.00, '2023-02-01');

-- Query to analyze employee data
SELECT
    department,
    COUNT(*) as employee_count,
    AVG(salary) as avg_salary
FROM employees
GROUP BY department
HAVING COUNT(*) > 0
ORDER BY avg_salary DESC;

-- Update salary with conditions
UPDATE employees
SET salary = salary * 1.1
WHERE department = 'Engineering'
    AND hire_date < '2024-01-01';
"""
        with tempfile.NamedTemporaryFile(
            delete=False, suffix=".sql", mode="w", encoding="utf-8"
        ) as f:
            f.write(sample_sql)
            return f.name
|
@@ -0,0 +1,149 @@
|
|
1
|
+
from edsl.scenarios.file_methods import FileMethods
|
2
|
+
import os
|
3
|
+
import tempfile
|
4
|
+
import sqlite3
|
5
|
+
|
6
|
+
|
7
|
+
class SQLiteMethods(FileMethods):
    """File-handler methods for SQLite database files.

    Every method operates on the database file at ``self.path``.
    """

    suffix = "db"  # or "sqlite", depending on your preference

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return *name* as a double-quoted SQL identifier.

        Table names cannot be bound as query parameters, so they are quoted
        (with embedded double quotes doubled) before being interpolated. This
        keeps names containing spaces, quotes or keywords working.
        """
        return '"' + name.replace('"', '""') + '"'

    def extract_text(self):
        """
        Extracts a text representation of the database schema and table contents.

        For each table listed in ``sqlite_master`` the result contains a
        ``Table:`` line, a ``Schema:`` line (the stored CREATE statement), a
        ``Columns:`` line, one line per row (the ``str`` of the row tuple)
        and a separator.

        Returns:
            All lines joined with newlines.
        """
        from contextlib import closing

        full_text = []

        # sqlite3's own context manager only commits/rolls back; closing()
        # makes sure the connection itself is released.
        with closing(sqlite3.connect(self.path)) as conn:
            cursor = conn.cursor()

            # Get every table name together with its schema
            cursor.execute("SELECT name, sql FROM sqlite_master WHERE type='table';")
            tables = cursor.fetchall()

            for table_name, schema in tables:
                full_text.append(f"Table: {table_name}")
                full_text.append(f"Schema: {schema}")

                # Get table contents
                cursor.execute(
                    f"SELECT * FROM {self._quote_identifier(table_name)};"
                )
                rows = cursor.fetchall()

                # Get column names
                column_names = [description[0] for description in cursor.description]
                full_text.append(f"Columns: {', '.join(column_names)}")

                # Add row data
                for row in rows:
                    full_text.append(str(row))
                full_text.append("\n")

        return "\n".join(full_text)

    def view_system(self):
        """
        Opens the database with the system's default SQLite viewer if available.

        Tries "DB Browser for SQLite" on macOS and Windows and
        ``sqlitebrowser`` elsewhere. Prints a message instead of raising when
        the file is missing or the viewer cannot be started.
        """
        import os
        import subprocess
        import sys

        if not os.path.exists(self.path):
            print("SQLite database file was not found.")
            return

        try:
            # ``os.name`` is "posix" on both macOS and Linux, so macOS has to
            # be recognised through ``sys.platform``; otherwise the Linux
            # branch below could never run.
            if sys.platform == "darwin":
                # Try DB Browser for SQLite on macOS
                subprocess.run(
                    ["open", "-a", "DB Browser for SQLite", self.path], check=True
                )
            elif os.name == "nt":
                # Try DB Browser for SQLite on Windows
                subprocess.run(["DB Browser for SQLite.exe", self.path], check=True)
            else:
                # Try sqlitebrowser on Linux
                subprocess.run(["sqlitebrowser", self.path], check=True)
        except Exception as e:  # best-effort viewer: report, never raise
            print(f"Error opening SQLite database: {e}")

    def view_notebook(self):
        """
        Displays database contents in a Jupyter notebook.

        Each table is loaded into a pandas DataFrame and rendered as an HTML
        table under a heading with the table name; all tables are placed in
        one scrollable container.
        """
        from contextlib import closing
        from html import escape

        import pandas as pd
        from IPython.display import HTML, display

        html_parts = []

        with closing(sqlite3.connect(self.path)) as conn:
            # Get all table names
            cursor = conn.cursor()
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
            tables = cursor.fetchall()

            for (table_name,) in tables:
                # Read table into pandas DataFrame
                df = pd.read_sql_query(
                    f"SELECT * FROM {self._quote_identifier(table_name)}", conn
                )

                # Convert to HTML with styling; the name comes from the file,
                # so it is escaped before being placed in markup.
                table_html = f"""
                <div style="margin-bottom: 20px;">
                <h3>{escape(table_name)}</h3>
                {df.to_html(index=False)}
                </div>
                """
                html_parts.append(table_html)

        # Combine all tables into one scrollable div
        html = f"""
        <div style="width: 800px; height: 800px; padding: 20px;
        border: 1px solid #ccc; overflow-y: auto;">
        {''.join(html_parts)}
        </div>
        """
        display(HTML(html))

    def example(self):
        """
        Creates an example SQLite database for testing.

        The database holds one table, ``survey_responses``, with two rows.
        The file is created with ``delete=False``; removing it is the
        caller's responsibility.

        Returns:
            The path of the new database file.
        """
        # Reserve a file name and release the handle before SQLite opens it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".db") as tmp:
            db_path = tmp.name

        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.cursor()

            # Create a sample table
            cursor.execute(
                """
                CREATE TABLE survey_responses (
                    id INTEGER PRIMARY KEY,
                    question TEXT,
                    response TEXT
                )
                """
            )

            # Insert some sample data
            sample_data = [
                (1, "First Survey Question", "Response 1"),
                (2, "Second Survey Question", "Response 2"),
            ]
            cursor.executemany(
                "INSERT INTO survey_responses (id, question, response) VALUES (?, ?, ?)",
                sample_data,
            )

            conn.commit()
        finally:
            # Close even if table creation or the inserts fail.
            conn.close()

        return db_path
|
143
|
+
|
144
|
+
|
145
|
+
if __name__ == "__main__":
    # Manual smoke test: build the example database and wrap it in a FileStore.
    from edsl.scenarios.FileStore import FileStore

    # ``example`` is an instance method that never reads ``self``; calling it
    # on the class with no argument raises TypeError, so pass a placeholder.
    sqlite_temp = SQLiteMethods.example(None)

    fs = FileStore(sqlite_temp)
|
@@ -0,0 +1,33 @@
|
|
1
|
+
from edsl.scenarios.file_methods import FileMethods
|
2
|
+
import tempfile
|
3
|
+
|
4
|
+
|
5
|
+
class TxtMethods(FileMethods):
    """File-handler methods for plain-text (``.txt``) files.

    Every method operates on the file at ``self.path``.
    """

    suffix = "txt"

    def view_system(self):
        """Open the text file with the operating system's default application.

        Prints a message instead of raising when the file is missing or the
        opener fails.
        """
        import os
        import subprocess
        import sys

        if not os.path.exists(self.path):
            print("TXT file was not found.")
            return

        try:
            # ``os.name`` is "posix" on both macOS and Linux, so macOS has to
            # be recognised through ``sys.platform``; otherwise Linux would be
            # sent to the macOS-only ``open`` command.
            if sys.platform == "darwin":
                subprocess.run(["open", self.path], check=True)  # macOS
            elif os.name == "nt":
                os.startfile(self.path)  # Windows
            else:
                subprocess.run(["xdg-open", self.path], check=True)  # Linux
        except Exception as e:  # best-effort viewer: report, never raise
            print(f"Error opening TXT: {e}")

    def view_notebook(self):
        """Display a link to the text file in a Jupyter notebook."""
        from IPython.display import FileLink, display

        display(FileLink(self.path))

    def example(self):
        """Write ``Hello, World!`` to a temporary ``.txt`` file and return its path.

        The file is created with ``delete=False``; removing it is the
        caller's responsibility.
        """
        with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as f:
            f.write(b"Hello, World!")
            return f.name
|
edsl/study/ObjectEntry.py
CHANGED
edsl/study/SnapShot.py
CHANGED
@@ -32,7 +32,7 @@ class SnapShot:
|
|
32
32
|
{'Cache': <class 'edsl.data.Cache.Cache'>}
|
33
33
|
"""
|
34
34
|
from edsl.Base import RegisterSubclassesMeta
|
35
|
-
from edsl import QuestionBase
|
35
|
+
from edsl.questions.QuestionBase import QuestionBase
|
36
36
|
|
37
37
|
all_edsl_objects = RegisterSubclassesMeta.get_registry()
|
38
38
|
|
edsl/study/Study.py
CHANGED
@@ -7,7 +7,9 @@ import socket
|
|
7
7
|
from datetime import datetime
|
8
8
|
from typing import Dict, Optional, Union
|
9
9
|
from uuid import UUID, uuid4
|
10
|
-
|
10
|
+
|
11
|
+
from edsl.data.Cache import Cache
|
12
|
+
from edsl import set_session_cache, unset_session_cache
|
11
13
|
from edsl.utilities.utilities import dict_hash
|
12
14
|
from edsl.study.ObjectEntry import ObjectEntry
|
13
15
|
from edsl.study.ProofOfWork import ProofOfWork
|
@@ -405,7 +407,7 @@ class Study:
|
|
405
407
|
|
406
408
|
study_file = tempfile.NamedTemporaryFile()
|
407
409
|
with cls(filename=study_file.name, verbose=verbose) as study:
|
408
|
-
from edsl import QuestionFreeText
|
410
|
+
from edsl.questions.QuestionFreeText import QuestionFreeText
|
409
411
|
|
410
412
|
q = QuestionFreeText.example(randomize=randomize)
|
411
413
|
return study
|
@@ -464,7 +466,7 @@ class Study:
|
|
464
466
|
def push(self) -> dict:
|
465
467
|
"""Push the objects to coop."""
|
466
468
|
|
467
|
-
from edsl import Coop
|
469
|
+
from edsl.coop.coop import Coop
|
468
470
|
|
469
471
|
coop = Coop()
|
470
472
|
return coop.create(self, description=self.description)
|
@@ -517,12 +519,3 @@ if __name__ == "__main__":
|
|
517
519
|
import doctest
|
518
520
|
|
519
521
|
doctest.testmod(optionflags=doctest.ELLIPSIS)
|
520
|
-
|
521
|
-
# with Study(name = "cool_study") as study:
|
522
|
-
# from edsl import QuestionFreeText
|
523
|
-
# q = QuestionFreeText.example()
|
524
|
-
|
525
|
-
# assert len(study.objects) == 1
|
526
|
-
|
527
|
-
# print(study.versions())
|
528
|
-
# {'q': [ObjectEntry(variable_name='q', object=Question('free_text', question_name = """how_are_you""", question_text = """How are you?"""), description='Question name: how_are_you', coop_info=None, created_at=1720276402.561273, edsl_class_name='QuestionFreeText')]}
|
@@ -0,0 +1,92 @@
|
|
1
|
+
from edsl.surveys.base import EndOfSurvey
|
2
|
+
from edsl.surveys.DAG import DAG
|
3
|
+
from edsl.exceptions.surveys import SurveyError
|
4
|
+
|
5
|
+
|
6
|
+
class ConstructDAG:
    """Build the dependency graph (DAG) between the questions of a survey.

    The graph maps a question's index to the set of indices of the questions
    it depends on. It combines three sources read from the survey: its memory
    plan, its rule collection, and piping (one question's text referring to
    another question by name).
    """

    def __init__(self, survey):
        """Store the survey and the lookups the DAG construction reads.

        :param survey: Object providing ``questions``,
            ``parameters_by_question``, ``question_name_to_index``,
            ``memory_plan`` and ``rule_collection``.
        """
        self.survey = survey
        self.questions = survey.questions

        self.parameters_by_question = self.survey.parameters_by_question
        self.question_name_to_index = self.survey.question_name_to_index

    def dag(self, textify: bool = False) -> DAG:
        """Return the combined memory, rule and piping DAG of the survey.

        :param textify: If True, each component is first converted from
            question indices to question names (see :meth:`textify`).
        """
        memory_dag = self.survey.memory_plan.dag
        rule_dag = self.survey.rule_collection.dag
        piping_dag = self.piping_dag
        if textify:
            memory_dag = DAG(self.textify(memory_dag))
            rule_dag = DAG(self.textify(rule_dag))
            piping_dag = DAG(self.textify(piping_dag))
        return memory_dag + rule_dag + piping_dag

    @property
    def piping_dag(self) -> DAG:
        """Figures out the DAG of piping dependencies.

        Parameters that are not names of questions in the survey are ignored,
        and questions without question dependencies get no entry.

        NOTE(review): the value returned is a plain ``dict`` although the
        annotation says ``DAG`` -- confirm callers only rely on the mapping
        interface.

        >>> from edsl import Survey
        >>> from edsl import QuestionFreeText
        >>> q0 = QuestionFreeText(question_text="Here is a question", question_name="q0")
        >>> q1 = QuestionFreeText(question_text="You previously answered {{ q0 }}---how do you feel now?", question_name="q1")
        >>> s = Survey([q0, q1])
        >>> ConstructDAG(s).piping_dag
        {1: {0}}
        """
        d = {}
        for question_name, dependencies in self.parameters_by_question.items():
            if not dependencies:
                continue
            question_index = self.question_name_to_index[question_name]
            for dependency in dependencies:
                # Only parameters that name another question are edges.
                if dependency in self.question_name_to_index:
                    dependency_index = self.question_name_to_index[dependency]
                    d.setdefault(question_index, set()).add(dependency_index)
        return d

    def textify(self, index_dag: DAG) -> DAG:
        """Convert the DAG of question indices to a DAG of question names.

        An index at or beyond the number of questions is mapped to
        ``EndOfSurvey``.

        :param index_dag: The DAG of question indices.
        :raises SurveyError: If an index cannot be resolved to a question.

        Example:

        >>> from edsl import Survey
        >>> s = Survey.example()
        >>> d = s.dag()
        >>> d
        {1: {0}, 2: {0}}
        >>> ConstructDAG(s).textify(d)
        {'q1': {'q0'}, 'q2': {'q0'}}
        """

        def get_name(index: int):
            """Return the name of the question given the index."""
            if index >= len(self.questions):
                return EndOfSurvey
            try:
                return self.questions[index].question_name
            except IndexError as e:
                # Carry the diagnostic in the exception (and keep the cause)
                # instead of printing it and raising an empty error.
                raise SurveyError(
                    f"The index is {index} but the length of the questions is {len(self.questions)}"
                ) from e

        text_dag = {}
        for child_index, parent_indices in index_dag.items():
            parent_names = {get_name(index) for index in parent_indices}
            child_name = get_name(child_index)
            text_dag[child_name] = parent_names
        return text_dag
|
87
|
+
|
88
|
+
|
89
|
+
if __name__ == "__main__":
    # Running this module directly executes the doctests in its docstrings.
    import doctest

    doctest.testmod()
|