edsl 0.1.38.dev4__py3-none-any.whl → 0.1.39__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212)
  1. edsl/Base.py +197 -116
  2. edsl/__init__.py +15 -7
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +351 -147
  5. edsl/agents/AgentList.py +211 -73
  6. edsl/agents/Invigilator.py +101 -50
  7. edsl/agents/InvigilatorBase.py +62 -70
  8. edsl/agents/PromptConstructor.py +143 -225
  9. edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
  10. edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
  11. edsl/agents/__init__.py +0 -1
  12. edsl/agents/prompt_helpers.py +3 -3
  13. edsl/agents/question_option_processor.py +172 -0
  14. edsl/auto/AutoStudy.py +18 -5
  15. edsl/auto/StageBase.py +53 -40
  16. edsl/auto/StageQuestions.py +2 -1
  17. edsl/auto/utilities.py +0 -6
  18. edsl/config.py +22 -2
  19. edsl/conversation/car_buying.py +2 -1
  20. edsl/coop/CoopFunctionsMixin.py +15 -0
  21. edsl/coop/ExpectedParrotKeyHandler.py +125 -0
  22. edsl/coop/PriceFetcher.py +1 -1
  23. edsl/coop/coop.py +125 -47
  24. edsl/coop/utils.py +14 -14
  25. edsl/data/Cache.py +45 -27
  26. edsl/data/CacheEntry.py +12 -15
  27. edsl/data/CacheHandler.py +31 -12
  28. edsl/data/RemoteCacheSync.py +154 -46
  29. edsl/data/__init__.py +4 -3
  30. edsl/data_transfer_models.py +2 -1
  31. edsl/enums.py +27 -0
  32. edsl/exceptions/__init__.py +50 -50
  33. edsl/exceptions/agents.py +12 -0
  34. edsl/exceptions/inference_services.py +5 -0
  35. edsl/exceptions/questions.py +24 -6
  36. edsl/exceptions/scenarios.py +7 -0
  37. edsl/inference_services/AnthropicService.py +38 -19
  38. edsl/inference_services/AvailableModelCacheHandler.py +184 -0
  39. edsl/inference_services/AvailableModelFetcher.py +215 -0
  40. edsl/inference_services/AwsBedrock.py +0 -2
  41. edsl/inference_services/AzureAI.py +0 -2
  42. edsl/inference_services/GoogleService.py +7 -12
  43. edsl/inference_services/InferenceServiceABC.py +18 -85
  44. edsl/inference_services/InferenceServicesCollection.py +120 -79
  45. edsl/inference_services/MistralAIService.py +0 -3
  46. edsl/inference_services/OpenAIService.py +47 -35
  47. edsl/inference_services/PerplexityService.py +0 -3
  48. edsl/inference_services/ServiceAvailability.py +135 -0
  49. edsl/inference_services/TestService.py +11 -10
  50. edsl/inference_services/TogetherAIService.py +5 -3
  51. edsl/inference_services/data_structures.py +134 -0
  52. edsl/jobs/AnswerQuestionFunctionConstructor.py +223 -0
  53. edsl/jobs/Answers.py +1 -14
  54. edsl/jobs/FetchInvigilator.py +47 -0
  55. edsl/jobs/InterviewTaskManager.py +98 -0
  56. edsl/jobs/InterviewsConstructor.py +50 -0
  57. edsl/jobs/Jobs.py +356 -431
  58. edsl/jobs/JobsChecks.py +35 -10
  59. edsl/jobs/JobsComponentConstructor.py +189 -0
  60. edsl/jobs/JobsPrompts.py +6 -4
  61. edsl/jobs/JobsRemoteInferenceHandler.py +205 -133
  62. edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
  63. edsl/jobs/RequestTokenEstimator.py +30 -0
  64. edsl/jobs/async_interview_runner.py +138 -0
  65. edsl/jobs/buckets/BucketCollection.py +44 -3
  66. edsl/jobs/buckets/TokenBucket.py +53 -21
  67. edsl/jobs/buckets/TokenBucketAPI.py +211 -0
  68. edsl/jobs/buckets/TokenBucketClient.py +191 -0
  69. edsl/jobs/check_survey_scenario_compatibility.py +85 -0
  70. edsl/jobs/data_structures.py +120 -0
  71. edsl/jobs/decorators.py +35 -0
  72. edsl/jobs/interviews/Interview.py +143 -408
  73. edsl/jobs/jobs_status_enums.py +9 -0
  74. edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
  75. edsl/jobs/results_exceptions_handler.py +98 -0
  76. edsl/jobs/runners/JobsRunnerAsyncio.py +88 -403
  77. edsl/jobs/runners/JobsRunnerStatus.py +133 -165
  78. edsl/jobs/tasks/QuestionTaskCreator.py +21 -19
  79. edsl/jobs/tasks/TaskHistory.py +38 -18
  80. edsl/jobs/tasks/task_status_enum.py +0 -2
  81. edsl/language_models/ComputeCost.py +63 -0
  82. edsl/language_models/LanguageModel.py +194 -236
  83. edsl/language_models/ModelList.py +28 -19
  84. edsl/language_models/PriceManager.py +127 -0
  85. edsl/language_models/RawResponseHandler.py +106 -0
  86. edsl/language_models/ServiceDataSources.py +0 -0
  87. edsl/language_models/__init__.py +1 -2
  88. edsl/language_models/key_management/KeyLookup.py +63 -0
  89. edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
  90. edsl/language_models/key_management/KeyLookupCollection.py +38 -0
  91. edsl/language_models/key_management/__init__.py +0 -0
  92. edsl/language_models/key_management/models.py +131 -0
  93. edsl/language_models/model.py +256 -0
  94. edsl/language_models/repair.py +2 -2
  95. edsl/language_models/utilities.py +5 -4
  96. edsl/notebooks/Notebook.py +19 -14
  97. edsl/notebooks/NotebookToLaTeX.py +142 -0
  98. edsl/prompts/Prompt.py +29 -39
  99. edsl/questions/ExceptionExplainer.py +77 -0
  100. edsl/questions/HTMLQuestion.py +103 -0
  101. edsl/questions/QuestionBase.py +68 -214
  102. edsl/questions/QuestionBasePromptsMixin.py +7 -3
  103. edsl/questions/QuestionBudget.py +1 -1
  104. edsl/questions/QuestionCheckBox.py +3 -3
  105. edsl/questions/QuestionExtract.py +5 -7
  106. edsl/questions/QuestionFreeText.py +2 -3
  107. edsl/questions/QuestionList.py +10 -18
  108. edsl/questions/QuestionMatrix.py +265 -0
  109. edsl/questions/QuestionMultipleChoice.py +67 -23
  110. edsl/questions/QuestionNumerical.py +2 -4
  111. edsl/questions/QuestionRank.py +7 -17
  112. edsl/questions/SimpleAskMixin.py +4 -3
  113. edsl/questions/__init__.py +2 -1
  114. edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +47 -2
  115. edsl/questions/data_structures.py +20 -0
  116. edsl/questions/derived/QuestionLinearScale.py +6 -3
  117. edsl/questions/derived/QuestionTopK.py +1 -1
  118. edsl/questions/descriptors.py +17 -3
  119. edsl/questions/loop_processor.py +149 -0
  120. edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +57 -50
  121. edsl/questions/question_registry.py +1 -1
  122. edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +40 -26
  123. edsl/questions/response_validator_factory.py +34 -0
  124. edsl/questions/templates/matrix/__init__.py +1 -0
  125. edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
  126. edsl/questions/templates/matrix/question_presentation.jinja +20 -0
  127. edsl/results/CSSParameterizer.py +1 -1
  128. edsl/results/Dataset.py +170 -7
  129. edsl/results/DatasetExportMixin.py +168 -305
  130. edsl/results/DatasetTree.py +28 -8
  131. edsl/results/MarkdownToDocx.py +122 -0
  132. edsl/results/MarkdownToPDF.py +111 -0
  133. edsl/results/Result.py +298 -206
  134. edsl/results/Results.py +149 -131
  135. edsl/results/ResultsExportMixin.py +2 -0
  136. edsl/results/TableDisplay.py +98 -171
  137. edsl/results/TextEditor.py +50 -0
  138. edsl/results/__init__.py +1 -1
  139. edsl/results/file_exports.py +252 -0
  140. edsl/results/{Selector.py → results_selector.py} +23 -13
  141. edsl/results/smart_objects.py +96 -0
  142. edsl/results/table_data_class.py +12 -0
  143. edsl/results/table_renderers.py +118 -0
  144. edsl/scenarios/ConstructDownloadLink.py +109 -0
  145. edsl/scenarios/DocumentChunker.py +102 -0
  146. edsl/scenarios/DocxScenario.py +16 -0
  147. edsl/scenarios/FileStore.py +150 -239
  148. edsl/scenarios/PdfExtractor.py +40 -0
  149. edsl/scenarios/Scenario.py +90 -193
  150. edsl/scenarios/ScenarioHtmlMixin.py +4 -3
  151. edsl/scenarios/ScenarioList.py +415 -244
  152. edsl/scenarios/ScenarioListExportMixin.py +0 -7
  153. edsl/scenarios/ScenarioListPdfMixin.py +15 -37
  154. edsl/scenarios/__init__.py +1 -2
  155. edsl/scenarios/directory_scanner.py +96 -0
  156. edsl/scenarios/file_methods.py +85 -0
  157. edsl/scenarios/handlers/__init__.py +13 -0
  158. edsl/scenarios/handlers/csv.py +49 -0
  159. edsl/scenarios/handlers/docx.py +76 -0
  160. edsl/scenarios/handlers/html.py +37 -0
  161. edsl/scenarios/handlers/json.py +111 -0
  162. edsl/scenarios/handlers/latex.py +5 -0
  163. edsl/scenarios/handlers/md.py +51 -0
  164. edsl/scenarios/handlers/pdf.py +68 -0
  165. edsl/scenarios/handlers/png.py +39 -0
  166. edsl/scenarios/handlers/pptx.py +105 -0
  167. edsl/scenarios/handlers/py.py +294 -0
  168. edsl/scenarios/handlers/sql.py +313 -0
  169. edsl/scenarios/handlers/sqlite.py +149 -0
  170. edsl/scenarios/handlers/txt.py +33 -0
  171. edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +10 -6
  172. edsl/scenarios/scenario_selector.py +156 -0
  173. edsl/study/ObjectEntry.py +1 -1
  174. edsl/study/SnapShot.py +1 -1
  175. edsl/study/Study.py +5 -12
  176. edsl/surveys/ConstructDAG.py +92 -0
  177. edsl/surveys/EditSurvey.py +221 -0
  178. edsl/surveys/InstructionHandler.py +100 -0
  179. edsl/surveys/MemoryManagement.py +72 -0
  180. edsl/surveys/Rule.py +5 -4
  181. edsl/surveys/RuleCollection.py +25 -27
  182. edsl/surveys/RuleManager.py +172 -0
  183. edsl/surveys/Simulator.py +75 -0
  184. edsl/surveys/Survey.py +270 -791
  185. edsl/surveys/SurveyCSS.py +20 -8
  186. edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +11 -9
  187. edsl/surveys/SurveyToApp.py +141 -0
  188. edsl/surveys/__init__.py +4 -2
  189. edsl/surveys/descriptors.py +6 -2
  190. edsl/surveys/instructions/ChangeInstruction.py +1 -2
  191. edsl/surveys/instructions/Instruction.py +4 -13
  192. edsl/surveys/instructions/InstructionCollection.py +11 -6
  193. edsl/templates/error_reporting/interview_details.html +1 -1
  194. edsl/templates/error_reporting/report.html +1 -1
  195. edsl/tools/plotting.py +1 -1
  196. edsl/utilities/PrettyList.py +56 -0
  197. edsl/utilities/is_notebook.py +18 -0
  198. edsl/utilities/is_valid_variable_name.py +11 -0
  199. edsl/utilities/remove_edsl_version.py +24 -0
  200. edsl/utilities/utilities.py +35 -23
  201. {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/METADATA +12 -10
  202. edsl-0.1.39.dist-info/RECORD +358 -0
  203. {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/WHEEL +1 -1
  204. edsl/language_models/KeyLookup.py +0 -30
  205. edsl/language_models/registry.py +0 -190
  206. edsl/language_models/unused/ReplicateBase.py +0 -83
  207. edsl/results/ResultsDBMixin.py +0 -238
  208. edsl-0.1.38.dev4.dist-info/RECORD +0 -277
  209. /edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +0 -0
  210. /edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +0 -0
  211. /edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +0 -0
  212. {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/LICENSE +0 -0
@@ -0,0 +1,313 @@
1
+ from edsl.scenarios.file_methods import FileMethods
2
+ import tempfile
3
+ import re
4
+ from typing import List, Optional
5
+ import textwrap
6
+
7
+
8
class SqlMethods(FileMethods):
    """File-type handler for ``.sql`` scripts.

    Offers two viewers (system application, notebook) and a handful of
    heuristic, text-based helpers: keyword capitalisation, indentation,
    statement splitting, basic validation and table-name extraction.  None of
    the helpers parse SQL properly; they work on lines and regular
    expressions only.

    Every method reads the file location from ``self.path``.
    """

    suffix = "sql"

    # Words that ``_format_keywords`` upper-cases.
    _KEYWORDS = frozenset(
        {
            "select",
            "from",
            "where",
            "and",
            "or",
            "insert",
            "update",
            "delete",
            "create",
            "drop",
            "alter",
            "table",
            "into",
            "values",
            "group",
            "by",
            "having",
            "order",
            "limit",
            "join",
            "left",
            "right",
            "inner",
            "outer",
            "on",
            "as",
            "distinct",
            "count",
            "sum",
            "avg",
            "max",
            "min",
            "between",
            "like",
            "in",
            "is",
            "null",
            "not",
            "case",
            "when",
            "then",
            "else",
            "end",
        }
    )

    # Single-quoted literal ('' is an escaped quote) or a ``--`` line comment.
    _LITERAL_OR_COMMENT_RE = re.compile(r"'(?:[^']|'')*'|--[^\n]*")

    # One token per match: literal, line comment, or bare word.
    _TOKEN_RE = re.compile(r"'(?:[^']|'')*'|--[^\n]*|\w+")

    # Clause words that make the following lines indent one level.
    _CLAUSE_RE = re.compile(
        r"\b(?:select|from|where|group\s+by|having|order\s+by)\b", re.IGNORECASE
    )

    def view_system(self):
        """Open the SQL file with the operating system's default application.

        Prints a message instead of raising when the file is missing or the
        launcher fails.
        """
        import os
        import subprocess
        import sys

        if not os.path.exists(self.path):
            print("SQL file was not found.")
            return

        try:
            # ``os.name`` is "posix" on both macOS and Linux, so it cannot be
            # used to pick between ``open`` and ``xdg-open``.
            if sys.platform == "darwin":
                subprocess.run(["open", self.path], check=True)  # macOS
            elif os.name == "nt":
                os.startfile(self.path)  # Windows
            else:
                subprocess.run(["xdg-open", self.path], check=True)  # Linux
        except Exception as e:
            print(f"Error opening SQL file: {e}")

    def view_notebook(self):
        """Render the file as syntax-highlighted HTML plus a download link.

        Requires IPython and pygments; any failure is printed, not raised.
        """
        from IPython.display import FileLink, display, HTML
        import pygments
        from pygments.lexers import SqlLexer
        from pygments.formatters import HtmlFormatter

        try:
            with open(self.path, "r", encoding="utf-8") as f:
                content = f.read()

            formatter = HtmlFormatter(style="monokai")
            highlighted_sql = pygments.highlight(content, SqlLexer(), formatter)
            css = formatter.get_style_defs(".highlight")
            display(HTML(f"<style>{css}</style>{highlighted_sql}"))
            display(FileLink(self.path))
        except Exception as e:
            print(f"Error displaying SQL: {e}")

    def _format_keywords(self, sql: str) -> str:
        """Capitalize SQL keywords.

        Only bare words found in ``_KEYWORDS`` are changed.  Single-quoted
        string literals and ``--`` line comments are copied through verbatim,
        and the case of every other token (identifiers, literals) is kept so
        that data and case-sensitive names are not altered.  Whitespace is
        left untouched.
        """
        keywords = self._KEYWORDS

        def _upper_if_keyword(match):
            token = match.group(0)
            # Literals start with a quote, comments with a dash; words never do.
            if token[0] in "'-":
                return token
            return token.upper() if token.lower() in keywords else token

        return self._TOKEN_RE.sub(_upper_if_keyword, sql)

    def _indent_sql(self, sql: str) -> str:
        """Add basic indentation to SQL statement.

        Each line is stripped and re-indented.  The level drops on a line that
        starts with ``)``, rises after a line that ends with ``(``, is set to
        one after a line containing a major clause word, and returns to zero
        after a line ending with ``;`` so the next statement starts flush
        left.  Comment lines are emitted at the current level without
        affecting it.
        """
        # NOTE(review): the indent unit is a single space as it appears in the
        # source this was derived from; confirm that is the intended width.
        indent_unit = " "
        indented_lines = []
        indent_level = 0

        for raw_line in sql.split("\n"):
            line = raw_line.strip()

            if not line:
                indented_lines.append("")
                continue

            if line.startswith("--"):
                indented_lines.append(indent_unit * indent_level + line)
                continue

            # Decrease indent for closing parentheses
            if line.startswith(")"):
                indent_level = max(0, indent_level - 1)

            indented_lines.append(indent_unit * indent_level + line)

            # Increase indent after opening parentheses
            if line.endswith("("):
                indent_level += 1

            # Lines following a major clause are indented one level
            if self._CLAUSE_RE.search(line):
                indent_level = 1

            # A terminated statement ends any indentation it started
            if line.endswith(";"):
                indent_level = 0

        return "\n".join(indented_lines)

    def format_sql(self) -> bool:
        """Format the SQL file with proper indentation and keyword capitalization.

        The file is rewritten in place.  Line structure is preserved: joining
        the file into one line would let the first ``--`` comment swallow the
        SQL that follows it.  Lines longer than 80 characters are wrapped only
        when they contain no quote characters and no ``--`` comment, because
        wrapping those could change a literal or expose commented-out text.

        Returns:
            True on success, False if any step failed (the error is printed).
        """
        try:
            with open(self.path, "r", encoding="utf-8") as f:
                content = f.read()

            content = self._format_keywords(content)
            content = self._indent_sql(content)

            # Wrap long lines
            wrapped_content = []
            for line in content.split("\n"):
                safe_to_wrap = (
                    "'" not in line and '"' not in line and "--" not in line
                )
                if len(line) > 80 and safe_to_wrap:
                    wrapped_content.append(
                        textwrap.fill(line, width=80, subsequent_indent=" ")
                    )
                else:
                    wrapped_content.append(line)

            formatted_sql = "\n".join(wrapped_content)

            with open(self.path, "w", encoding="utf-8") as f:
                f.write(formatted_sql)

            return True
        except Exception as e:
            print(f"Error formatting SQL: {e}")
            return False

    def split_statements(self) -> List[str]:
        """Split the SQL file into individual statements.

        Works line by line: blank lines and lines starting with ``--`` are
        skipped, a line ending in ``;`` closes the current statement (the
        semicolon is dropped), and a line consisting of ``GO`` closes it as
        well.  Lines of one statement are joined with single spaces.

        Returns:
            The non-empty statements in file order, or ``[]`` if the file
            could not be read (the error is printed).
        """
        try:
            with open(self.path, "r", encoding="utf-8") as f:
                content = f.read()

            statements = []
            pending = []

            for raw_line in content.split("\n"):
                line = raw_line.strip()

                # Skip empty lines and comments
                if not line or line.startswith("--"):
                    continue

                if line.endswith(";"):
                    pending.append(line[:-1])  # Remove semicolon
                    statements.append(" ".join(pending))
                    pending = []
                elif line.upper() == "GO":
                    if pending:
                        statements.append(" ".join(pending))
                        pending = []
                else:
                    pending.append(line)

            # Add any remaining statement
            if pending:
                statements.append(" ".join(pending))

            return [stmt.strip() for stmt in statements if stmt.strip()]
        except Exception as e:
            print(f"Error splitting SQL statements: {e}")
            return []

    def validate_basic_syntax(self) -> bool:
        """
        Perform basic SQL syntax validation.
        This is a simple check and doesn't replace proper SQL parsing.

        Returns:
            False if the file cannot be read or any statement has unbalanced
            parentheses or an odd number of single quotes; True otherwise.
            A statement without a common leading keyword only prints a warning.
        """
        try:
            # ``split_statements`` swallows read errors and returns [], which
            # would make an unreadable file look valid; opening the file here
            # routes that failure to the ``except`` branch instead.
            with open(self.path, "r", encoding="utf-8") as f:
                f.read()

            for stmt in self.split_statements():
                # Check for basic SQL keywords
                stmt_upper = stmt.upper()
                if not any(
                    keyword in stmt_upper
                    for keyword in [
                        "SELECT",
                        "INSERT",
                        "UPDATE",
                        "DELETE",
                        "CREATE",
                        "DROP",
                        "ALTER",
                    ]
                ):
                    print(f"Warning: Statement might be incomplete: {stmt}")

                # Check for basic parentheses matching
                if stmt.count("(") != stmt.count(")"):
                    print(f"Error: Unmatched parentheses in statement: {stmt}")
                    return False

                # Check for basic quote matching
                if stmt.count("'") % 2 != 0:
                    print(f"Error: Unmatched quotes in statement: {stmt}")
                    return False

            return True
        except Exception as e:
            print(f"Error validating SQL: {e}")
            return False

    def extract_table_names(self) -> List[str]:
        """Extract table names from the SQL file.

        Looks for an identifier after FROM, JOIN, UPDATE, INSERT INTO and
        CREATE TABLE (case-insensitive).  String literals and ``--`` comments
        are blanked out first so that prose such as "-- Update salary" is not
        reported as a table.

        Returns:
            Sorted, de-duplicated names, or ``[]`` if the file could not be
            read (the error is printed).
        """
        tables = set()
        try:
            with open(self.path, "r", encoding="utf-8") as f:
                content = f.read()

            code_only = self._LITERAL_OR_COMMENT_RE.sub(" ", content)

            patterns = [
                r"\bFROM\s+([a-zA-Z_][a-zA-Z0-9_]*)",
                r"\bJOIN\s+([a-zA-Z_][a-zA-Z0-9_]*)",
                r"\bUPDATE\s+([a-zA-Z_][a-zA-Z0-9_]*)",
                r"\bINSERT\s+INTO\s+([a-zA-Z_][a-zA-Z0-9_]*)",
                r"\bCREATE\s+TABLE\s+([a-zA-Z_][a-zA-Z0-9_]*)",
            ]

            for pattern in patterns:
                tables.update(re.findall(pattern, code_only, re.IGNORECASE))

            return sorted(tables)
        except Exception as e:
            print(f"Error extracting table names: {e}")
            return []

    def example(self):
        """Write a sample ``.sql`` script to a temporary file and return its path.

        The file is created with ``delete=False``; removing it is left to the
        caller.
        """
        sample_sql = """-- Sample SQL file with common operations
CREATE TABLE employees (
id INTEGER PRIMARY KEY,
name VARCHAR(100) NOT NULL,
department VARCHAR(50),
salary DECIMAL(10,2),
hire_date DATE
);

INSERT INTO employees (name, department, salary, hire_date)
VALUES
('John Doe', 'Engineering', 75000.00, '2023-01-15'),
('Jane Smith', 'Marketing', 65000.00, '2023-02-01');

-- Query to analyze employee data
SELECT
department,
COUNT(*) as employee_count,
AVG(salary) as avg_salary
FROM employees
GROUP BY department
HAVING COUNT(*) > 0
ORDER BY avg_salary DESC;

-- Update salary with conditions
UPDATE employees
SET salary = salary * 1.1
WHERE department = 'Engineering'
AND hire_date < '2024-01-01';
"""
        with tempfile.NamedTemporaryFile(
            delete=False, suffix=".sql", mode="w", encoding="utf-8"
        ) as f:
            f.write(sample_sql)
        return f.name
@@ -0,0 +1,149 @@
1
+ from edsl.scenarios.file_methods import FileMethods
2
+ import os
3
+ import tempfile
4
+ import sqlite3
5
+
6
+
7
class SQLiteMethods(FileMethods):
    """File-type handler for SQLite database files.

    Every method reads the database location from ``self.path``.
    """

    suffix = "db"  # or "sqlite", depending on your preference

    @staticmethod
    def _quote_identifier(name: str) -> str:
        """Return ``name`` as a double-quoted SQL identifier.

        Embedded double quotes are doubled, so table names containing spaces,
        quotes or reserved words can be used safely in a statement.
        """
        return '"' + name.replace('"', '""') + '"'

    def extract_text(self):
        """
        Extracts a text representation of the database schema and table contents.

        For every table listed in ``sqlite_master`` the output contains the
        table name, its CREATE statement, the column names and one line per
        row (the ``str`` of the row tuple).
        """
        from contextlib import closing

        full_text = []

        # ``with sqlite3.connect(...)`` alone only manages the transaction;
        # ``closing`` makes sure the connection itself is released.
        with closing(sqlite3.connect(self.path)) as conn:
            cursor = conn.cursor()

            # Get all table names
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
            tables = cursor.fetchall()

            # For each table, get schema and contents
            for (table_name,) in tables:
                # Get table schema (name bound as a parameter, not interpolated)
                cursor.execute(
                    "SELECT sql FROM sqlite_master WHERE type='table' AND name=?;",
                    (table_name,),
                )
                schema = cursor.fetchone()[0]
                full_text.append(f"Table: {table_name}")
                full_text.append(f"Schema: {schema}")

                # Get table contents
                cursor.execute(
                    f"SELECT * FROM {self._quote_identifier(table_name)};"
                )
                rows = cursor.fetchall()

                # Get column names
                column_names = [description[0] for description in cursor.description]
                full_text.append(f"Columns: {', '.join(column_names)}")

                # Add row data
                full_text.extend(str(row) for row in rows)
                full_text.append("\n")

        return "\n".join(full_text)

    def view_system(self):
        """
        Opens the database with the system's default SQLite viewer if available.

        Prints a message instead of raising when the file is missing or the
        viewer cannot be started.
        """
        import os
        import subprocess
        import sys

        if not os.path.exists(self.path):
            print("SQLite database file was not found.")
            return

        try:
            # ``os.name`` is "posix" on both macOS and Linux, so the platform
            # has to be distinguished with ``sys.platform``.
            if sys.platform == "darwin":
                # Try DB Browser for SQLite on macOS
                subprocess.run(
                    ["open", "-a", "DB Browser for SQLite", self.path], check=True
                )
            elif os.name == "nt":
                # Try DB Browser for SQLite on Windows
                subprocess.run(["DB Browser for SQLite.exe", self.path], check=True)
            else:
                # Try sqlitebrowser on Linux
                subprocess.run(["sqlitebrowser", self.path], check=True)
        except Exception as e:
            print(f"Error opening SQLite database: {e}")

    def view_notebook(self):
        """
        Displays database contents in a Jupyter notebook.

        Each table is loaded into a pandas DataFrame and rendered as an HTML
        table inside one scrollable container.
        """
        from contextlib import closing
        from html import escape

        import pandas as pd
        from IPython.display import HTML, display

        html_parts = []
        with closing(sqlite3.connect(self.path)) as conn:
            # Get all table names
            cursor = conn.cursor()
            cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
            tables = cursor.fetchall()

            for (table_name,) in tables:
                # Read table into pandas DataFrame
                df = pd.read_sql_query(
                    f"SELECT * FROM {self._quote_identifier(table_name)}", conn
                )

                # Convert to HTML with styling; the heading is escaped because
                # a table name may contain markup characters.
                table_html = f"""
<div style="margin-bottom: 20px;">
<h3>{escape(table_name)}</h3>
{df.to_html(index=False)}
</div>
"""
                html_parts.append(table_html)

        # Combine all tables into one scrollable div
        html = f"""
<div style="width: 800px; height: 800px; padding: 20px;
border: 1px solid #ccc; overflow-y: auto;">
{''.join(html_parts)}
</div>
"""
        display(HTML(html))

    def example(self):
        """
        Creates an example SQLite database for testing.

        Returns:
            Path of a temporary ``.db`` file holding one ``survey_responses``
            table with two rows.  The file is not deleted automatically.
        """
        # Reserve a file name, then release the handle before SQLite opens it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".db") as tmp:
            db_path = tmp.name

        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.cursor()

            # Create a sample table
            cursor.execute(
                """
CREATE TABLE survey_responses (
id INTEGER PRIMARY KEY,
question TEXT,
response TEXT
)
"""
            )

            # Insert some sample data
            sample_data = [
                (1, "First Survey Question", "Response 1"),
                (2, "Second Survey Question", "Response 2"),
            ]
            cursor.executemany(
                "INSERT INTO survey_responses (id, question, response) VALUES (?, ?, ?)",
                sample_data,
            )

            conn.commit()
        finally:
            # Close even if table creation or the inserts fail.
            conn.close()

        return db_path
143
+
144
+
145
if __name__ == "__main__":
    # ``example`` is declared as an instance method but never reads ``self``.
    # Calling it on the class with no argument raises TypeError, so a
    # placeholder is passed explicitly.
    sqlite_temp = SQLiteMethods.example(None)
    from edsl.scenarios.FileStore import FileStore

    fs = FileStore(sqlite_temp)
@@ -0,0 +1,33 @@
1
+ from edsl.scenarios.file_methods import FileMethods
2
+ import tempfile
3
+
4
+
5
class TxtMethods(FileMethods):
    """File-type handler for plain-text (``.txt``) files.

    Every method reads the file location from ``self.path``.
    """

    suffix = "txt"

    def view_system(self):
        """Open the text file with the operating system's default application.

        Prints a message instead of raising when the file is missing or the
        launcher fails.
        """
        import os
        import subprocess
        import sys

        if not os.path.exists(self.path):
            print("TXT file was not found.")
            return

        try:
            # ``os.name`` is "posix" on both macOS and Linux, so it cannot be
            # used to choose between ``open`` and ``xdg-open``.
            if sys.platform == "darwin":
                subprocess.run(["open", self.path], check=True)  # macOS
            elif os.name == "nt":
                os.startfile(self.path)  # Windows
            else:
                subprocess.run(["xdg-open", self.path], check=True)  # Linux
        except Exception as e:
            print(f"Error opening TXT: {e}")

    def view_notebook(self):
        """Show a clickable link to the file in a notebook (requires IPython)."""
        from IPython.display import FileLink, display

        display(FileLink(self.path))

    def example(self):
        """Write ``Hello, World!`` to a temporary ``.txt`` file and return its path.

        The file is created with ``delete=False``; removing it is left to the
        caller.
        """
        with tempfile.NamedTemporaryFile(delete=False, suffix=".txt") as f:
            f.write(b"Hello, World!")
        return f.name
@@ -1,9 +1,9 @@
1
1
  from __future__ import annotations
2
2
  from typing import Union, TYPE_CHECKING
3
3
 
4
- # if TYPE_CHECKING:
5
- from edsl.scenarios.ScenarioList import ScenarioList
6
- from edsl.scenarios.Scenario import Scenario
4
+ if TYPE_CHECKING:
5
+ from edsl.scenarios.ScenarioList import ScenarioList
6
+ from edsl.scenarios.Scenario import Scenario
7
7
 
8
8
 
9
9
  class ScenarioJoin:
@@ -23,7 +23,7 @@ class ScenarioJoin:
23
23
  self.left = left
24
24
  self.right = right
25
25
 
26
- def left_join(self, by: Union[str, list[str]]) -> ScenarioList:
26
+ def left_join(self, by: Union[str, list[str]]) -> "ScenarioList":
27
27
  """Perform a left join between the two ScenarioLists.
28
28
 
29
29
  Args:
@@ -35,6 +35,8 @@ class ScenarioJoin:
35
35
  Raises:
36
36
  ValueError: If by is empty or if any join keys don't exist in both ScenarioLists
37
37
  """
38
+ from edsl.scenarios.ScenarioList import ScenarioList
39
+
38
40
  self._validate_join_keys(by)
39
41
  by_keys = [by] if isinstance(by, str) else by
40
42
 
@@ -86,6 +88,8 @@ class ScenarioJoin:
86
88
  self, by_keys: list[str], other_dict: dict, all_keys: set
87
89
  ) -> list[Scenario]:
88
90
  """Create the joined scenarios."""
91
+ from edsl.scenarios.Scenario import Scenario
92
+
89
93
  new_scenarios = []
90
94
 
91
95
  for scenario in self.left:
@@ -105,8 +109,8 @@ class ScenarioJoin:
105
109
  def _handle_matching_scenario(
106
110
  self,
107
111
  new_scenario: dict,
108
- left_scenario: Scenario,
109
- right_scenario: Scenario,
112
+ left_scenario: "Scenario",
113
+ right_scenario: "Scenario",
110
114
  by_keys: list[str],
111
115
  ) -> None:
112
116
  """Handle merging of matching scenarios and conflict warnings."""
@@ -0,0 +1,156 @@
1
+ from typing import TYPE_CHECKING
2
+
3
+
4
class ScenarioSelector:
    """
    A class for performing advanced field selection on ScenarioList objects,
    including support for wildcard patterns.

    Args:
        scenario_list: The ScenarioList object to perform selections on

    Examples:
        >>> from edsl import Scenario, ScenarioList
        >>> scenarios = ScenarioList([Scenario({'test_1': 1, 'test_2': 2, 'other': 3}), Scenario({'test_1': 4, 'test_2': 5, 'other': 6})])
        >>> selector = ScenarioSelector(scenarios)
        >>> selector.select('test*')
        ScenarioList([Scenario({'test_1': 1, 'test_2': 2}), Scenario({'test_1': 4, 'test_2': 5})])
    """

    def __init__(self, scenario_list: "ScenarioList"):
        """Initialize with a ScenarioList object.

        The selectable field names are taken from the keys of the first
        scenario only; an empty list yields no fields.
        """
        self.scenario_list = scenario_list
        self.available_fields = (
            list(scenario_list.data[0].keys()) if scenario_list.data else []
        )

    def _match_field_pattern(self, pattern: str, field: str) -> bool:
        """
        Checks if a field name matches a pattern with wildcards.
        Supports '*' as a wildcard anywhere in the pattern; each '*' matches
        any run of characters (including none). Every other character is
        matched literally.

        Args:
            pattern: The pattern to match against, may contain '*'
            field: The field name to check

        Examples:
            >>> from edsl.scenarios import ScenarioList, Scenario
            >>> selector = ScenarioSelector(ScenarioList([]))
            >>> selector._match_field_pattern('test*', 'test_field')
            True
            >>> selector._match_field_pattern('*field', 'test_field')
            True
            >>> selector._match_field_pattern('test', 'test')
            True
            >>> selector._match_field_pattern('*test*', 'my_test_field')
            True
            >>> selector._match_field_pattern('my*field', 'my_test_field')
            True
        """
        import re

        if "*" not in pattern:
            return pattern == field

        # Escape the literal pieces so only '*' acts as a wildcard.
        regex = ".*".join(re.escape(piece) for piece in pattern.split("*"))
        return re.fullmatch(regex, field, flags=re.DOTALL) is not None

    def _get_matching_fields(self, patterns: list[str]) -> list[str]:
        """
        Gets all fields that match any of the given patterns.

        Args:
            patterns: List of field patterns, may contain wildcards

        Returns:
            Sorted list of field names that match at least one pattern

        Examples:
            >>> from edsl import Scenario, ScenarioList
            >>> scenarios = ScenarioList([
            ...     Scenario({'test_1': 1, 'test_2': 2, 'other': 3})
            ... ])
            >>> selector = ScenarioSelector(scenarios)
            >>> selector._get_matching_fields(['test*'])
            ['test_1', 'test_2']
        """
        matching_fields = {
            field
            for pattern in patterns
            for field in self.available_fields
            if self._match_field_pattern(pattern, field)
        }
        return sorted(matching_fields)

    def select(self, *fields) -> "ScenarioList":
        """
        Selects scenarios with only the referenced fields.
        Supports wildcard patterns using '*' in field names.

        Args:
            *fields: Field names or patterns to select. Patterns may include '*' for wildcards.

        Returns:
            A new ScenarioList containing only the matched fields.

        Raises:
            ValueError: If no fields match the given patterns.

        Examples:
            >>> from edsl import Scenario, ScenarioList
            >>> scenarios = ScenarioList([
            ...     Scenario({'test_1': 1, 'test_2': 2, 'other': 3}),
            ...     Scenario({'test_1': 4, 'test_2': 5, 'other': 6})
            ... ])
            >>> selector = ScenarioSelector(scenarios)
            >>> selector.select('test*')  # Selects all fields starting with 'test'
            ScenarioList([Scenario({'test_1': 1, 'test_2': 2}), Scenario({'test_1': 4, 'test_2': 5})])
            >>> selector.select('*_1')  # Selects all fields ending with '_1'
            ScenarioList([Scenario({'test_1': 1}), Scenario({'test_1': 4})])
            >>> selector.select('test_1', '*_2')  # Multiple patterns
            ScenarioList([Scenario({'test_1': 1, 'test_2': 2}), Scenario({'test_1': 4, 'test_2': 5})])
        """
        # Build the result with the same class as the input list.
        list_cls = self.scenario_list.__class__

        if not self.scenario_list.data:
            return list_cls([])

        patterns = list(fields)

        # Get all fields that match the patterns
        fields_to_select = self._get_matching_fields(patterns)

        # If no fields match, raise an informative error
        if not fields_to_select:
            raise ValueError(
                f"No fields matched the given patterns: {patterns}. "
                f"Available fields are: {self.available_fields}"
            )

        return list_cls(
            [scenario.select(fields_to_select) for scenario in self.scenario_list.data]
        )

    def get_available_fields(self) -> list[str]:
        """
        Returns a list of all available fields in the ScenarioList.

        Returns:
            Sorted list of field names available for selection.

        Examples:
            >>> from edsl import Scenario, ScenarioList
            >>> scenarios = ScenarioList([Scenario({'test_1': 1, 'test_2': 2, 'other': 3})])
            >>> selector = ScenarioSelector(scenarios)
            >>> selector.get_available_fields()
            ['other', 'test_1', 'test_2']
        """
        return sorted(self.available_fields)
151
+
152
+
153
if __name__ == "__main__":
    # Run this module's doctests when executed as a script; ELLIPSIS lets
    # expected output use "..." as a wildcard.
    from doctest import ELLIPSIS
    import doctest

    doctest.testmod(optionflags=ELLIPSIS)