edsl 0.1.38.dev4__py3-none-any.whl → 0.1.39__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212)
  1. edsl/Base.py +197 -116
  2. edsl/__init__.py +15 -7
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +351 -147
  5. edsl/agents/AgentList.py +211 -73
  6. edsl/agents/Invigilator.py +101 -50
  7. edsl/agents/InvigilatorBase.py +62 -70
  8. edsl/agents/PromptConstructor.py +143 -225
  9. edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
  10. edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
  11. edsl/agents/__init__.py +0 -1
  12. edsl/agents/prompt_helpers.py +3 -3
  13. edsl/agents/question_option_processor.py +172 -0
  14. edsl/auto/AutoStudy.py +18 -5
  15. edsl/auto/StageBase.py +53 -40
  16. edsl/auto/StageQuestions.py +2 -1
  17. edsl/auto/utilities.py +0 -6
  18. edsl/config.py +22 -2
  19. edsl/conversation/car_buying.py +2 -1
  20. edsl/coop/CoopFunctionsMixin.py +15 -0
  21. edsl/coop/ExpectedParrotKeyHandler.py +125 -0
  22. edsl/coop/PriceFetcher.py +1 -1
  23. edsl/coop/coop.py +125 -47
  24. edsl/coop/utils.py +14 -14
  25. edsl/data/Cache.py +45 -27
  26. edsl/data/CacheEntry.py +12 -15
  27. edsl/data/CacheHandler.py +31 -12
  28. edsl/data/RemoteCacheSync.py +154 -46
  29. edsl/data/__init__.py +4 -3
  30. edsl/data_transfer_models.py +2 -1
  31. edsl/enums.py +27 -0
  32. edsl/exceptions/__init__.py +50 -50
  33. edsl/exceptions/agents.py +12 -0
  34. edsl/exceptions/inference_services.py +5 -0
  35. edsl/exceptions/questions.py +24 -6
  36. edsl/exceptions/scenarios.py +7 -0
  37. edsl/inference_services/AnthropicService.py +38 -19
  38. edsl/inference_services/AvailableModelCacheHandler.py +184 -0
  39. edsl/inference_services/AvailableModelFetcher.py +215 -0
  40. edsl/inference_services/AwsBedrock.py +0 -2
  41. edsl/inference_services/AzureAI.py +0 -2
  42. edsl/inference_services/GoogleService.py +7 -12
  43. edsl/inference_services/InferenceServiceABC.py +18 -85
  44. edsl/inference_services/InferenceServicesCollection.py +120 -79
  45. edsl/inference_services/MistralAIService.py +0 -3
  46. edsl/inference_services/OpenAIService.py +47 -35
  47. edsl/inference_services/PerplexityService.py +0 -3
  48. edsl/inference_services/ServiceAvailability.py +135 -0
  49. edsl/inference_services/TestService.py +11 -10
  50. edsl/inference_services/TogetherAIService.py +5 -3
  51. edsl/inference_services/data_structures.py +134 -0
  52. edsl/jobs/AnswerQuestionFunctionConstructor.py +223 -0
  53. edsl/jobs/Answers.py +1 -14
  54. edsl/jobs/FetchInvigilator.py +47 -0
  55. edsl/jobs/InterviewTaskManager.py +98 -0
  56. edsl/jobs/InterviewsConstructor.py +50 -0
  57. edsl/jobs/Jobs.py +356 -431
  58. edsl/jobs/JobsChecks.py +35 -10
  59. edsl/jobs/JobsComponentConstructor.py +189 -0
  60. edsl/jobs/JobsPrompts.py +6 -4
  61. edsl/jobs/JobsRemoteInferenceHandler.py +205 -133
  62. edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
  63. edsl/jobs/RequestTokenEstimator.py +30 -0
  64. edsl/jobs/async_interview_runner.py +138 -0
  65. edsl/jobs/buckets/BucketCollection.py +44 -3
  66. edsl/jobs/buckets/TokenBucket.py +53 -21
  67. edsl/jobs/buckets/TokenBucketAPI.py +211 -0
  68. edsl/jobs/buckets/TokenBucketClient.py +191 -0
  69. edsl/jobs/check_survey_scenario_compatibility.py +85 -0
  70. edsl/jobs/data_structures.py +120 -0
  71. edsl/jobs/decorators.py +35 -0
  72. edsl/jobs/interviews/Interview.py +143 -408
  73. edsl/jobs/jobs_status_enums.py +9 -0
  74. edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
  75. edsl/jobs/results_exceptions_handler.py +98 -0
  76. edsl/jobs/runners/JobsRunnerAsyncio.py +88 -403
  77. edsl/jobs/runners/JobsRunnerStatus.py +133 -165
  78. edsl/jobs/tasks/QuestionTaskCreator.py +21 -19
  79. edsl/jobs/tasks/TaskHistory.py +38 -18
  80. edsl/jobs/tasks/task_status_enum.py +0 -2
  81. edsl/language_models/ComputeCost.py +63 -0
  82. edsl/language_models/LanguageModel.py +194 -236
  83. edsl/language_models/ModelList.py +28 -19
  84. edsl/language_models/PriceManager.py +127 -0
  85. edsl/language_models/RawResponseHandler.py +106 -0
  86. edsl/language_models/ServiceDataSources.py +0 -0
  87. edsl/language_models/__init__.py +1 -2
  88. edsl/language_models/key_management/KeyLookup.py +63 -0
  89. edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
  90. edsl/language_models/key_management/KeyLookupCollection.py +38 -0
  91. edsl/language_models/key_management/__init__.py +0 -0
  92. edsl/language_models/key_management/models.py +131 -0
  93. edsl/language_models/model.py +256 -0
  94. edsl/language_models/repair.py +2 -2
  95. edsl/language_models/utilities.py +5 -4
  96. edsl/notebooks/Notebook.py +19 -14
  97. edsl/notebooks/NotebookToLaTeX.py +142 -0
  98. edsl/prompts/Prompt.py +29 -39
  99. edsl/questions/ExceptionExplainer.py +77 -0
  100. edsl/questions/HTMLQuestion.py +103 -0
  101. edsl/questions/QuestionBase.py +68 -214
  102. edsl/questions/QuestionBasePromptsMixin.py +7 -3
  103. edsl/questions/QuestionBudget.py +1 -1
  104. edsl/questions/QuestionCheckBox.py +3 -3
  105. edsl/questions/QuestionExtract.py +5 -7
  106. edsl/questions/QuestionFreeText.py +2 -3
  107. edsl/questions/QuestionList.py +10 -18
  108. edsl/questions/QuestionMatrix.py +265 -0
  109. edsl/questions/QuestionMultipleChoice.py +67 -23
  110. edsl/questions/QuestionNumerical.py +2 -4
  111. edsl/questions/QuestionRank.py +7 -17
  112. edsl/questions/SimpleAskMixin.py +4 -3
  113. edsl/questions/__init__.py +2 -1
  114. edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +47 -2
  115. edsl/questions/data_structures.py +20 -0
  116. edsl/questions/derived/QuestionLinearScale.py +6 -3
  117. edsl/questions/derived/QuestionTopK.py +1 -1
  118. edsl/questions/descriptors.py +17 -3
  119. edsl/questions/loop_processor.py +149 -0
  120. edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +57 -50
  121. edsl/questions/question_registry.py +1 -1
  122. edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +40 -26
  123. edsl/questions/response_validator_factory.py +34 -0
  124. edsl/questions/templates/matrix/__init__.py +1 -0
  125. edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
  126. edsl/questions/templates/matrix/question_presentation.jinja +20 -0
  127. edsl/results/CSSParameterizer.py +1 -1
  128. edsl/results/Dataset.py +170 -7
  129. edsl/results/DatasetExportMixin.py +168 -305
  130. edsl/results/DatasetTree.py +28 -8
  131. edsl/results/MarkdownToDocx.py +122 -0
  132. edsl/results/MarkdownToPDF.py +111 -0
  133. edsl/results/Result.py +298 -206
  134. edsl/results/Results.py +149 -131
  135. edsl/results/ResultsExportMixin.py +2 -0
  136. edsl/results/TableDisplay.py +98 -171
  137. edsl/results/TextEditor.py +50 -0
  138. edsl/results/__init__.py +1 -1
  139. edsl/results/file_exports.py +252 -0
  140. edsl/results/{Selector.py → results_selector.py} +23 -13
  141. edsl/results/smart_objects.py +96 -0
  142. edsl/results/table_data_class.py +12 -0
  143. edsl/results/table_renderers.py +118 -0
  144. edsl/scenarios/ConstructDownloadLink.py +109 -0
  145. edsl/scenarios/DocumentChunker.py +102 -0
  146. edsl/scenarios/DocxScenario.py +16 -0
  147. edsl/scenarios/FileStore.py +150 -239
  148. edsl/scenarios/PdfExtractor.py +40 -0
  149. edsl/scenarios/Scenario.py +90 -193
  150. edsl/scenarios/ScenarioHtmlMixin.py +4 -3
  151. edsl/scenarios/ScenarioList.py +415 -244
  152. edsl/scenarios/ScenarioListExportMixin.py +0 -7
  153. edsl/scenarios/ScenarioListPdfMixin.py +15 -37
  154. edsl/scenarios/__init__.py +1 -2
  155. edsl/scenarios/directory_scanner.py +96 -0
  156. edsl/scenarios/file_methods.py +85 -0
  157. edsl/scenarios/handlers/__init__.py +13 -0
  158. edsl/scenarios/handlers/csv.py +49 -0
  159. edsl/scenarios/handlers/docx.py +76 -0
  160. edsl/scenarios/handlers/html.py +37 -0
  161. edsl/scenarios/handlers/json.py +111 -0
  162. edsl/scenarios/handlers/latex.py +5 -0
  163. edsl/scenarios/handlers/md.py +51 -0
  164. edsl/scenarios/handlers/pdf.py +68 -0
  165. edsl/scenarios/handlers/png.py +39 -0
  166. edsl/scenarios/handlers/pptx.py +105 -0
  167. edsl/scenarios/handlers/py.py +294 -0
  168. edsl/scenarios/handlers/sql.py +313 -0
  169. edsl/scenarios/handlers/sqlite.py +149 -0
  170. edsl/scenarios/handlers/txt.py +33 -0
  171. edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +10 -6
  172. edsl/scenarios/scenario_selector.py +156 -0
  173. edsl/study/ObjectEntry.py +1 -1
  174. edsl/study/SnapShot.py +1 -1
  175. edsl/study/Study.py +5 -12
  176. edsl/surveys/ConstructDAG.py +92 -0
  177. edsl/surveys/EditSurvey.py +221 -0
  178. edsl/surveys/InstructionHandler.py +100 -0
  179. edsl/surveys/MemoryManagement.py +72 -0
  180. edsl/surveys/Rule.py +5 -4
  181. edsl/surveys/RuleCollection.py +25 -27
  182. edsl/surveys/RuleManager.py +172 -0
  183. edsl/surveys/Simulator.py +75 -0
  184. edsl/surveys/Survey.py +270 -791
  185. edsl/surveys/SurveyCSS.py +20 -8
  186. edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +11 -9
  187. edsl/surveys/SurveyToApp.py +141 -0
  188. edsl/surveys/__init__.py +4 -2
  189. edsl/surveys/descriptors.py +6 -2
  190. edsl/surveys/instructions/ChangeInstruction.py +1 -2
  191. edsl/surveys/instructions/Instruction.py +4 -13
  192. edsl/surveys/instructions/InstructionCollection.py +11 -6
  193. edsl/templates/error_reporting/interview_details.html +1 -1
  194. edsl/templates/error_reporting/report.html +1 -1
  195. edsl/tools/plotting.py +1 -1
  196. edsl/utilities/PrettyList.py +56 -0
  197. edsl/utilities/is_notebook.py +18 -0
  198. edsl/utilities/is_valid_variable_name.py +11 -0
  199. edsl/utilities/remove_edsl_version.py +24 -0
  200. edsl/utilities/utilities.py +35 -23
  201. {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/METADATA +12 -10
  202. edsl-0.1.39.dist-info/RECORD +358 -0
  203. {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/WHEEL +1 -1
  204. edsl/language_models/KeyLookup.py +0 -30
  205. edsl/language_models/registry.py +0 -190
  206. edsl/language_models/unused/ReplicateBase.py +0 -83
  207. edsl/results/ResultsDBMixin.py +0 -238
  208. edsl-0.1.38.dev4.dist-info/RECORD +0 -277
  209. /edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +0 -0
  210. /edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +0 -0
  211. /edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +0 -0
  212. {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/LICENSE +0 -0
@@ -1,190 +0,0 @@
1
- import textwrap
2
- from random import random
3
- from edsl.config import CONFIG
4
-
5
- # if "EDSL_DEFAULT_MODEL" not in CONFIG:
6
- # default_model = "test"
7
- # else:
8
- # default_model = CONFIG.get("EDSL_DEFAULT_MODEL")
9
-
10
- from collections import UserList
11
-
12
-
13
class PrettyList(UserList):
    """A list that renders itself as an HTML table in rich front-ends.

    :param data: Initial items. Each item is either a scalar (rendered as a
        one-column table) or a list/tuple (rendered one cell per element).
    :param columns: Optional header labels. When omitted (or empty), the
        column positions ``0..n-1`` are used as headers.
    """

    def __init__(self, data=None, columns=None):
        super().__init__(data)
        self.columns = columns

    def _repr_html_(self):
        """Return the items as an HTML ``<table>`` wrapped in ``<pre>``.

        The number of columns is taken from the first item only: a list or
        tuple contributes ``len(item)`` columns, anything else counts as one.
        An empty list renders as a table with only the header cells (if any)
        instead of raising ``IndexError``.

        :return: The HTML markup as a string.
        """
        # Local import so the block does not depend on a file-level import.
        from html import escape

        def cell(tag, value):
            # Values are escaped so text such as "<b>" shows up literally
            # instead of being interpreted as markup by the front-end.
            return f"<{tag}>{escape(str(value))}</{tag}>"

        rows = self.data
        if not rows:
            num_cols = 0
        elif isinstance(rows[0], (list, tuple)):
            num_cols = len(rows[0])
        else:
            num_cols = 1

        columns = self.columns if self.columns else list(range(num_cols))
        header = "".join(cell("th", column) for column in columns)

        if num_cols > 1:
            # Multi-column: one <td> per element of each row.
            body = "".join(
                "<tr>" + "".join(cell("td", x) for x in row) + "</tr>"
                for row in rows
            )
        else:
            # Single column: the whole item is rendered into one <td>.
            body = "".join("<tr>" + cell("td", row) + "</tr>" for row in rows)

        return "<pre><table>" + header + body + "</table></pre>"
52
-
53
-
54
def get_model_class(model_name, registry=None):
    """Return the model factory that ``registry`` creates for ``model_name``.

    :param model_name: Name handed to ``registry.create_model_factory``.
    :param registry: Object exposing ``create_model_factory(model_name)``.
        When omitted (or falsy), the ``default`` registry from
        ``edsl.inference_services.registry`` is used.
    :return: Whatever ``create_model_factory`` returns for ``model_name``.
    """
    if not registry:
        # Imported only when actually needed, so a caller-supplied registry
        # never triggers the import of the default one.
        from edsl.inference_services.registry import default as registry

    return registry.create_model_factory(model_name)
60
-
61
-
62
class Meta(type):
    """Metaclass whose ``repr`` is a short usage guide for the class."""

    def __repr__(cls):
        """Return help text that embeds the result of ``cls.available()``.

        The class using this metaclass must therefore provide a callable
        ``available`` attribute.
        """
        # The f-string is interpolated first and dedented afterwards, so the
        # common leading whitespace of the template is stripped from the text.
        # NOTE(review): if the repr of ``cls.available()`` ever spans several
        # lines, ``dedent`` may strip less than intended — confirm.
        return textwrap.dedent(
            f"""\
            Available models: {cls.available()}

            To create an instance, you can do:
            >>> m = Model('gpt-4-1106-preview', temperature=0.5, ...)

            To get the default model, you can leave out the model name.
            To see the available models, you can do:
            >>> Model.available()
            """
        )
76
-
77
-
78
class Model(metaclass=Meta):
    """Factory front-end: ``Model(name, **params)`` builds a model instance.

    ``__new__`` never returns a ``Model``; it asks a registry for the factory
    registered under the requested name and returns what that factory builds.
    """

    # Name used when the caller does not pass ``model_name``.
    default_model = CONFIG.get("EDSL_DEFAULT_MODEL")

    # Header labels for the rows of ``registry.available()``.
    # NOTE(review): assumes each row is (model name, service name, code);
    # only positions 0 and 1 are actually read in this class — confirm.
    _COLUMNS = ["Model Name", "Service Name", "Code"]

    @staticmethod
    def _resolve_registry(registry=None):
        """Return ``registry``, or the package default when it is falsy."""
        if not registry:
            from edsl.inference_services.registry import default as registry
        return registry

    def __new__(
        cls, model_name=None, registry=None, service_name=None, *args, **kwargs
    ):
        """Create a model instance through the registry's factory.

        :param model_name: Model name, or an ``int`` index into
            ``available(name_only=True)``. ``None`` selects ``default_model``.
        :param registry: Registry to use; the package default when omitted.
        :param service_name: Forwarded to ``registry.create_model_factory``.
        :param args: Positional arguments forwarded to the factory.
        :param kwargs: Keyword arguments forwarded to the factory.
        :return: The object built by the factory.
        """
        if model_name is None:
            # No name given: fall back to the default set in the config.
            model_name = cls.default_model

        registry = cls._resolve_registry(registry)

        if isinstance(model_name, int):
            # A model may be referred to by index. The lookup must use the
            # same registry that will build the model, otherwise the index
            # is resolved against a different model list.
            model_name = cls.available(name_only=True, registry=registry)[
                model_name
            ]

        factory = registry.create_model_factory(model_name, service_name=service_name)
        return factory(*args, **kwargs)

    @classmethod
    def add_model(cls, service_name, model_name):
        """Register ``model_name`` under ``service_name`` in the default registry."""
        from edsl.inference_services.registry import default

        default.add_model(service_name, model_name)

    @classmethod
    def services(cls, registry=None):
        """Return the ``_inference_service_`` value of every registered service.

        :param registry: Registry to inspect; the package default when omitted.
        """
        registry = cls._resolve_registry(registry)
        return [r._inference_service_ for r in registry.services]

    @classmethod
    def available(cls, search_term=None, name_only=False, registry=None, service=None):
        """List the models known to the registry.

        :param search_term: When given, keep only rows whose first or second
            field contains this substring.
        :param name_only: When true, return only the first field of each row.
        :param registry: Registry to query; the package default when omitted.
        :param service: When given, keep only rows whose second field equals
            this value.
        :raises ValueError: If ``service`` is not one of ``services()``.
        :return: A ``PrettyList`` of rows, or of names when ``name_only``.
        """
        registry = cls._resolve_registry(registry)
        rows = registry.available()

        if service is not None:
            if service not in cls.services(registry=registry):
                raise ValueError(f"Service {service} not found in available services.")
            rows = [m for m in rows if m[1] == service]

        if search_term is not None:
            rows = [m for m in rows if search_term in m[0] or search_term in m[1]]

        if name_only:
            # One value per item, so only one header label applies.
            return PrettyList([m[0] for m in rows], columns=cls._COLUMNS[:1])
        return PrettyList(rows, columns=cls._COLUMNS)

    @classmethod
    def check_models(cls, verbose=False):
        """Instantiate every available model and call ``hello`` on it.

        Progress and errors are printed; failures never propagate, the loop
        simply moves on to the next model.

        :param verbose: Passed to ``hello``; when true the result is printed.
        """
        print("Checking all available models...\n")
        for model in cls.available(name_only=True):
            print(f"Now checking: {model}")
            try:
                m = cls(model)
            except Exception as e:
                # Deliberately broad: this is a best-effort diagnostic sweep.
                print(f"Error creating instance of {model}: {e}")
                continue
            try:
                results = m.hello(verbose)
                if verbose:
                    print(f"Results from model call: {results}")
            except Exception as e:
                print(f"Error calling 'hello' on {model}: {e}")
                continue
            print("OK!")
            print("\n")

    @classmethod
    def example(cls, randomize: bool = False) -> "Model":
        """
        Returns an example Model instance.

        :param randomize: If True, the temperature is set to a random decimal between 0 and 1.
        """
        temperature = 0.5 if not randomize else round(random(), 2)
        model_name = cls.default_model
        return cls(model_name, temperature=temperature)
180
-
181
-
182
if __name__ == "__main__":
    # Script entry point: run the module's doctests, then a manual smoke test.
    import doctest

    doctest.testmod(optionflags=doctest.ELLIPSIS)

    # Manual smoke test: list the models, build one by name and send a prompt.
    # NOTE(review): presumably this performs a live call to the model's
    # service and needs valid credentials — confirm before running.
    available = Model.available()
    m = Model("gpt-4-1106-preview")
    results = m.execute_model_call("Hello world")
    print(results)
@@ -1,83 +0,0 @@
1
- import asyncio
2
- import aiohttp
3
- import json
4
- from typing import Any
5
-
6
- from edsl import CONFIG
7
-
8
- from edsl.language_models.LanguageModel import LanguageModel
9
-
10
-
11
def replicate_model_factory(model_name, base_url, api_token):
    """Create a ``LanguageModel`` subclass bound to one prediction endpoint.

    :param model_name: Stored on the generated class as ``_model_``.
    :param base_url: URL the prediction request is POSTed to.
    :param api_token: Token sent in the ``Authorization`` header.
    :return: The generated ``ReplicateLanguageModelBase`` class (not an
        instance).
    """

    class ReplicateLanguageModelBase(LanguageModel):
        """Language model that creates a prediction and polls for its result."""

        _model_ = model_name
        # Defaults read when the request payload is built.
        # NOTE(review): "use_cache" is declared here but never sent in the
        # request below — presumably consumed by the base class; confirm.
        _parameters_ = {
            "temperature": 0.1,
            "topK": 50,
            "topP": 0.9,
            "max_new_tokens": 500,
            "min_new_tokens": -1,
            "repetition_penalty": 1.15,
            "use_cache": True,
        }
        _api_token = api_token
        _base_url = base_url

        async def async_execute_model_call(
            self, user_prompt: str, system_prompt: str = ""
        ) -> dict[str, Any]:
            """Create a prediction and poll once per second until it finishes.

            :param user_prompt: Sent as ``input.prompt``.
            :param system_prompt: Sent as ``input.system_prompt``.
            :return: The decoded JSON of the poll response whose ``status`` is
                ``"succeeded"``, or ``None`` when a poll request answers with
                a non-200 HTTP status.
            :raises RuntimeError: If the prediction reports the terminal
                status ``"failed"`` or ``"canceled"`` (previously this case
                polled forever).
            :raises KeyError: If the creation response has no ``urls.get``
                entry.
            """
            self.api_token = self._api_token
            self.headers = {
                "Authorization": f"Token {self.api_token}",
                "Content-Type": "application/json",
            }
            payload = {
                "input": {
                    "debug": False,
                    "top_k": self._parameters_["topK"],
                    "top_p": self._parameters_["topP"],
                    "prompt": user_prompt,
                    "system_prompt": system_prompt,
                    "temperature": self._parameters_["temperature"],
                    "max_new_tokens": self._parameters_["max_new_tokens"],
                    "min_new_tokens": self._parameters_["min_new_tokens"],
                    "prompt_template": "{prompt}",
                    "repetition_penalty": self._parameters_["repetition_penalty"],
                },
            }

            async with aiohttp.ClientSession() as session:
                async with session.post(
                    self._base_url, headers=self.headers, data=json.dumps(payload)
                ) as response:
                    raw_response_text = await response.text()
                    created = json.loads(raw_response_text)
                    print(f"This was the data returned by the model:{created}")
                    prediction_url = created["urls"]["get"]

                while True:
                    async with session.get(
                        prediction_url, headers=self.headers
                    ) as get_response:
                        if get_response.status != 200:
                            # Non-success HTTP status: give up without a result.
                            return None

                        get_data = json.loads(await get_response.text())

                    status = get_data["status"]
                    if status == "succeeded":
                        return get_data
                    if status in ("failed", "canceled"):
                        # Terminal states never turn into "succeeded"; stop
                        # polling instead of looping forever.
                        raise RuntimeError(
                            f"Prediction ended with status {status!r}: "
                            f"{get_data.get('error')}"
                        )
                    await asyncio.sleep(1)

        def parse_response(self, raw_response: dict[str, Any]) -> str:
            """Join the chunks in ``raw_response["output"]`` into one string."""
            data = "".join(raw_response["output"])
            print(f"This is what the model returned: {data}")
            return data

    return ReplicateLanguageModelBase
@@ -1,238 +0,0 @@
1
- """Mixin for working with the SQLite representation of a 'Results' object."""
2
-
3
- import sqlite3
4
- from enum import Enum
5
- from typing import Literal, Union, Optional
6
-
7
-
8
class SQLDataShape(Enum):
    """Enum for the shape of the data in the SQL database."""

    WIDE = "wide"
    LONG = "long"


class ResultsDBMixin:
    """Mixin for interacting with a Results object as if it were a SQL database.

    The host class must be iterable over objects exposing ``rows(index)``
    (used for the long shape) and must provide
    ``to_pandas(remove_prefix=..., lists_as_strings=...)`` (used for the wide
    shape). In both shapes the data is loaded into a table named ``self``.
    """

    def _rows(self):
        """Yield the rows of the `Results` object, one result after another."""
        for index, result in enumerate(self):
            yield from result.rows(index)

    def export_sql_dump(self, shape: Literal["wide", "long"], filename: str):
        """Export the SQL database to a file.

        :param shape: The shape of the data in the database (wide or long)
        :param filename: The filename to save the database to
        """
        conn = self._db(shape=self._get_shape_enum(shape))
        try:
            with open(filename, "w") as f:
                f.writelines(f"{line}\n" for line in conn.iterdump())
        finally:
            # Release the in-memory database even when writing fails.
            conn.close()

    def backup_db_to_file(self, shape: Literal["wide", "long"], filename: str):
        """Backup the in-memory database to a file.


        :param shape: The shape of the data in the database (wide or long)
        :param filename: The filename to save the database to

        >>> from edsl.results import Results
        >>> r = Results.example()
        >>> r.backup_db_to_file(filename="backup.db", shape="long")

        """
        # Source database connection (in-memory)
        source_conn = self._db(shape=self._get_shape_enum(shape))
        try:
            # Destination database connection (file)
            dest_conn = sqlite3.connect(filename)
            try:
                with source_conn:
                    source_conn.backup(dest_conn)
            finally:
                dest_conn.close()
        finally:
            source_conn.close()

    def _db(self, shape: SQLDataShape, remove_prefix=False):
        """Create a SQLite database in memory and return the connection.

        Both shapes return a plain ``sqlite3.Connection`` so that callers can
        rely on ``iterdump()`` and ``backup()`` regardless of the shape.

        :param shape: The shape of the data in the database (wide or long)
        :param remove_prefix: Whether to remove the prefix from the column names
            (only forwarded to ``to_pandas`` for the wide shape)
        :raises Exception: If ``shape`` is not a ``SQLDataShape`` member.
        """
        if shape not in (SQLDataShape.LONG, SQLDataShape.WIDE):
            raise Exception("Invalid SQLDataShape")

        conn = sqlite3.connect(":memory:")
        try:
            if shape == SQLDataShape.LONG:
                conn.execute(
                    """
                    CREATE TABLE self (
                        id INTEGER,
                        data_type TEXT,
                        key TEXT,
                        value TEXT
                    )
                    """
                )
                conn.executemany(
                    "INSERT INTO self (id, data_type, key, value) VALUES (?, ?, ?, ?)",
                    list(self._rows()),
                )
                conn.commit()
            else:
                df = self.to_pandas(remove_prefix=remove_prefix, lists_as_strings=True)
                # pandas accepts a sqlite3 connection directly for to_sql.
                df.to_sql("self", conn, index=False, if_exists="replace")
        except Exception:
            # Do not leak the connection when loading the data fails.
            conn.close()
            raise
        return conn

    def _get_shape_enum(self, shape: Literal["wide", "long"]):
        """Convert the shape string to a SQLDataShape enum.

        :raises Exception: If ``shape`` is ``None`` or not 'wide'/'long'.
        """
        if shape is None:
            raise Exception("Must select either 'wide' or 'long' format")
        if shape == "wide":
            return SQLDataShape.WIDE
        if shape == "long":
            return SQLDataShape.LONG
        raise Exception("Invalid shape: must be either 'long' or 'wide'")

    def sql(
        self,
        query: str,
        shape: Literal["wide", "long"] = "wide",
        remove_prefix: bool = True,
        transpose: Optional[bool] = None,
        transpose_by: Optional[str] = None,
        csv: bool = False,
        to_list=False,
        to_latex=False,
        filename: Optional[str] = None,
    ) -> Union["pd.DataFrame", str, list, None]:
        """Execute a SQL query and return the results as a DataFrame.

        :param query: The SQL query to execute
        :param shape: The shape of the data in the database (wide or long)
        :param remove_prefix: Whether to remove the prefix from the column names
        :param transpose: Whether to transpose the DataFrame
        :param transpose_by: The column to use as the index when transposing
        :param csv: Whether to return the DataFrame as a CSV string
        :param to_list: Whether to return the rows as a list of lists
        :param to_latex: Whether to return the DataFrame as LaTeX
        :param filename: With ``csv`` or ``to_latex``, write the output to
            this file and return ``None`` instead of the string
        :raises Exception: If ``shape`` is invalid, or both ``csv`` and
            ``to_list`` are requested.


        Example usage:

        >>> from edsl.results import Results
        >>> r = Results.example()
        >>> d = r.sql("select data_type, key, value from self where data_type = 'answer' order by value limit 3", shape="long")
        >>> sorted(list(d['value']))
        ['Good', 'Great', 'Great']

        We can also return the data in wide format.
        Note the use of single quotes to escape the column names, as required by sql.

        >>> from edsl.results import Results
        >>> Results.example().sql("select how_feeling from self", shape = 'wide', remove_prefix=True)
           how_feeling
        0           OK
        1        Great
        2     Terrible
        3           OK
        """
        import pandas as pd

        shape_enum = self._get_shape_enum(shape)

        if csv and to_list:
            raise Exception("Cannot return both CSV and list")

        conn = self._db(shape=shape_enum, remove_prefix=remove_prefix)
        try:
            df = pd.read_sql_query(query, conn)
        finally:
            # The result is fully materialized; the database is not needed.
            conn.close()

        # Transpose the DataFrame if requested, using the chosen (or the
        # first) column as the index.
        if transpose or transpose_by:
            index_column = transpose_by if transpose_by else df.columns[0]
            df = df.set_index(index_column).transpose()

        if to_list:
            return df.values.tolist()

        if to_latex:
            # str() guards against non-string labels (e.g. after a transpose).
            df.columns = [str(col).replace("_", " ") for col in df.columns]

            latex_output = df.to_latex(index=False)
            if filename:
                with open(filename, "w") as f:
                    f.write(latex_output)
                return None
            return latex_output

        if csv:
            if filename:
                df.to_csv(filename, index=False)
                return None

            return df.to_csv(index=False)

        return df

    def show_schema(
        self, shape: Literal["wide", "long"], remove_prefix: bool = False
    ) -> Optional["pd.DataFrame"]:
        """Show the schema of the Results database.

        For the long shape the schema is printed and ``None`` is returned; for
        the wide shape the ``PRAGMA table_info`` result is returned as a
        DataFrame.

        :param shape: The shape of the data in the database (wide or long)
        :param remove_prefix: Whether to remove the prefix from the column names

        >>> from edsl.results import Results
        >>> r = Results.example()
        >>> r.show_schema(shape="long")
        Type: table, Name: self, SQL: CREATE TABLE self (
        ...
        <BLANKLINE>
        """
        import pandas as pd

        shape_enum = self._get_shape_enum(shape)
        conn = self._db(shape=shape_enum, remove_prefix=remove_prefix)
        try:
            if shape_enum == SQLDataShape.LONG:
                # Query to get the schema of all tables
                query = "SELECT type, name, sql FROM sqlite_master WHERE type='table'"
                schema = conn.execute(query).fetchall()
                print(
                    "".join(
                        f"Type: {row[0]}, Name: {row[1]}, SQL: {row[2]}\n"
                        for row in schema
                    )
                )
                return None

            return pd.read_sql("PRAGMA table_info(self)", conn)
        finally:
            conn.close()
233
-
234
-
235
if __name__ == "__main__":
    # Running the module as a script executes the doctests embedded in its
    # docstrings; ELLIPSIS lets "..." in expected output match any text.
    import doctest

    doctest.testmod(optionflags=doctest.ELLIPSIS)