edsl 0.1.42__py3-none-any.whl → 0.1.44__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +15 -6
- edsl/__version__.py +1 -1
- edsl/agents/Invigilator.py +1 -1
- edsl/agents/PromptConstructor.py +92 -21
- edsl/agents/QuestionInstructionPromptBuilder.py +68 -9
- edsl/agents/prompt_helpers.py +2 -2
- edsl/coop/coop.py +100 -22
- edsl/enums.py +3 -1
- edsl/exceptions/coop.py +4 -0
- edsl/inference_services/AnthropicService.py +2 -0
- edsl/inference_services/AvailableModelFetcher.py +4 -1
- edsl/inference_services/GoogleService.py +2 -0
- edsl/inference_services/GrokService.py +11 -0
- edsl/inference_services/InferenceServiceABC.py +1 -0
- edsl/inference_services/OpenAIService.py +1 -0
- edsl/inference_services/TestService.py +1 -0
- edsl/inference_services/registry.py +2 -0
- edsl/jobs/Jobs.py +54 -35
- edsl/jobs/JobsChecks.py +7 -7
- edsl/jobs/JobsPrompts.py +57 -6
- edsl/jobs/JobsRemoteInferenceHandler.py +41 -25
- edsl/jobs/buckets/BucketCollection.py +30 -0
- edsl/jobs/data_structures.py +1 -0
- edsl/language_models/LanguageModel.py +5 -2
- edsl/language_models/key_management/KeyLookupBuilder.py +47 -20
- edsl/language_models/key_management/models.py +10 -4
- edsl/language_models/model.py +43 -11
- edsl/prompts/Prompt.py +124 -61
- edsl/questions/descriptors.py +32 -18
- edsl/questions/question_base_gen_mixin.py +1 -0
- edsl/results/DatasetExportMixin.py +35 -6
- edsl/results/Results.py +180 -1
- edsl/results/ResultsGGMixin.py +117 -60
- edsl/scenarios/FileStore.py +19 -8
- edsl/scenarios/Scenario.py +33 -0
- edsl/scenarios/ScenarioList.py +22 -3
- edsl/scenarios/ScenarioListPdfMixin.py +9 -3
- edsl/surveys/Survey.py +27 -6
- {edsl-0.1.42.dist-info → edsl-0.1.44.dist-info}/METADATA +3 -4
- {edsl-0.1.42.dist-info → edsl-0.1.44.dist-info}/RECORD +42 -41
- {edsl-0.1.42.dist-info → edsl-0.1.44.dist-info}/LICENSE +0 -0
- {edsl-0.1.42.dist-info → edsl-0.1.44.dist-info}/WHEEL +0 -0
edsl/language_models/model.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
import textwrap
|
2
2
|
from random import random
|
3
|
-
from typing import Optional, TYPE_CHECKING, List
|
3
|
+
from typing import Optional, TYPE_CHECKING, List, Callable
|
4
4
|
|
5
5
|
from edsl.utilities.PrettyList import PrettyList
|
6
6
|
from edsl.config import CONFIG
|
@@ -11,17 +11,21 @@ from edsl.inference_services.InferenceServicesCollection import (
|
|
11
11
|
from edsl.inference_services.data_structures import AvailableModels
|
12
12
|
from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
|
13
13
|
from edsl.enums import InferenceServiceLiteral
|
14
|
+
from edsl.exceptions.inference_services import InferenceServiceError
|
14
15
|
|
15
16
|
if TYPE_CHECKING:
|
16
17
|
from edsl.results.Dataset import Dataset
|
17
18
|
|
18
19
|
|
19
|
-
def get_model_class(
    model_name,
    registry: Optional[InferenceServicesCollection] = None,
    service_name: Optional[InferenceServiceLiteral] = None,
):
    """Return the model factory registered for ``model_name``.

    :param model_name: Name of the model to look up.
    :param registry: Registry to search. When not given (or falsy), the
        ``default`` registry from ``edsl.inference_services.registry`` is used.
    :param service_name: Optional service name used to disambiguate the lookup.
    :returns: Whatever ``registry.create_model_factory`` returns, or, if that
        call raises, the result of ``Model._handle_model_error``.
    """
    from edsl.inference_services.registry import default

    lookup_registry = registry or default
    try:
        return lookup_registry.create_model_factory(model_name, service_name=service_name)
    except (InferenceServiceError, Exception) as exc:
        # Delegate so that notebook users get a printed hint instead of a traceback.
        return Model._handle_model_error(model_name, exc)
|
25
29
|
|
26
30
|
|
27
31
|
class Meta(type):
|
@@ -58,6 +62,33 @@ class Model(metaclass=Meta):
|
|
58
62
|
"""Set a new registry"""
|
59
63
|
cls._registry = registry
|
60
64
|
|
65
|
+
@classmethod
def _handle_model_error(cls, model_name: str, error: Exception):
    """Report a model-creation failure: print it in a notebook, raise it elsewhere.

    :param model_name: The model name the caller tried to resolve.
    :param error: The exception raised while creating the model.
    :returns: ``None`` when running under IPython/Jupyter (after printing a message).
    :raises InferenceServiceError: Outside a notebook, when ``error`` is an
        ``InferenceServiceError``. The new exception carries a friendlier message
        listing the available services, and chains the original via ``__cause__``.
    :raises Exception: Outside a notebook, any other ``error`` is re-raised as-is.
    """
    if isinstance(error, InferenceServiceError):
        services = [s._inference_service_ for s in cls.get_registry().services]
        message = (
            f"Model '{model_name}' not found in any services.\n"
            "It is likely that our registry is just out of date.\n"
            "Simply adding the service name to your model call should fix this.\n"
            f"Available services are: {services}\n"
            "To specify a model with a service, use:\n"
            f'Model("{model_name}", service_name="<service_name>")'
        )
    else:
        message = f"An error occurred: {str(error)}"

    # ``get_ipython`` only resolves when running under IPython/Jupyter; elsewhere
    # the name lookup raises NameError. Probe first, then act *outside* the
    # ``except NameError`` clause. Raising inside that clause would attach the
    # probe's NameError to the user's error as misleading implicit context.
    try:
        get_ipython()  # noqa: F821 - provided by IPython at runtime
        in_notebook = True
    except NameError:
        in_notebook = False

    if in_notebook:
        print(message)
        return None

    if isinstance(error, InferenceServiceError):
        # Keep the original failure reachable via __cause__ for debugging.
        raise InferenceServiceError(message) from error
    raise error
|
91
|
+
|
61
92
|
def __new__(
|
62
93
|
cls,
|
63
94
|
model_name: Optional[str] = None,
|
@@ -69,9 +100,7 @@ class Model(metaclass=Meta):
|
|
69
100
|
"Instantiate a new language model."
|
70
101
|
# Map index to the respective subclass
|
71
102
|
if model_name is None:
|
72
|
-
model_name =
|
73
|
-
cls.default_model
|
74
|
-
) # when model_name is None, use the default model, set in the config file
|
103
|
+
model_name = cls.default_model
|
75
104
|
|
76
105
|
if registry is not None:
|
77
106
|
cls.set_registry(registry)
|
@@ -79,10 +108,13 @@ class Model(metaclass=Meta):
|
|
79
108
|
if isinstance(model_name, int): # can refer to a model by index
|
80
109
|
model_name = cls.available(name_only=True)[model_name]
|
81
110
|
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
111
|
+
try:
|
112
|
+
factory = cls.get_registry().create_model_factory(
|
113
|
+
model_name, service_name=service_name
|
114
|
+
)
|
115
|
+
return factory(*args, **kwargs)
|
116
|
+
except (InferenceServiceError, Exception) as e:
|
117
|
+
return cls._handle_model_error(model_name, e)
|
86
118
|
|
87
119
|
@classmethod
|
88
120
|
def add_model(cls, service_name, model_name) -> None:
|
edsl/prompts/Prompt.py
CHANGED
@@ -10,6 +10,48 @@ from edsl.Base import PersistenceMixin, RepresentationMixin
|
|
10
10
|
|
11
11
|
MAX_NESTING = 100
|
12
12
|
|
13
|
+
from jinja2 import Environment, meta, TemplateSyntaxError, Undefined
|
14
|
+
from functools import lru_cache
|
15
|
+
|
16
|
+
class PreserveUndefined(Undefined):
    """Jinja2 ``Undefined`` that renders a missing variable back as its placeholder.

    An unknown name is not rendered as an empty string. Instead, ``{{ name }}``
    is emitted verbatim, so a later render pass can still substitute it.
    """

    def __str__(self):
        placeholder_name = str(self._undefined_name)
        return "{{ " + placeholder_name + " }}"


# One shared Environment for the whole module, built once at import time, so
# templates are not compiled against a freshly constructed environment per call.
_env = Environment(undefined=PreserveUndefined)


@lru_cache(maxsize=1024)
def _compile_template(text: str):
    """Compile ``text`` with the shared environment, memoizing up to 1024 templates."""
    compiled = _env.from_string(text)
    return compiled
|
26
|
+
|
27
|
+
@lru_cache(maxsize=1024)
def _find_template_variables(template: str) -> list[str]:
    """Return the undeclared variable names referenced by a Jinja2 template.

    ``jinja2.meta.find_undeclared_variables`` yields an unordered set. The names
    are sorted so the result order is deterministic across interpreter runs;
    string set ordering otherwise varies with hash randomization.

    The result is memoized with ``lru_cache``, so every caller passing the same
    ``template`` receives the *same* list object. Treat it as read-only, and copy
    it before mutating, or the cached entry will be corrupted.

    :param template: Template source text.
    :returns: Sorted list of variable names.
    """
    ast = _env.parse(template)
    return sorted(meta.find_undeclared_variables(ast))
|
32
|
+
|
33
|
+
def _make_hashable(value):
|
34
|
+
"""Convert unhashable types to hashable ones."""
|
35
|
+
if isinstance(value, list):
|
36
|
+
return tuple(_make_hashable(item) for item in value)
|
37
|
+
if isinstance(value, dict):
|
38
|
+
return frozenset((k, _make_hashable(v)) for k, v in value.items())
|
39
|
+
return value
|
40
|
+
|
41
|
+
@lru_cache(maxsize=1024)
def _cached_render(text: str, frozen_replacements: frozenset) -> str:
    """Render ``text`` once with the given replacements, memoized per argument pair.

    :param text: Jinja2 template source.
    :param frozen_replacements: Hashable form of the replacement mapping. This is
        an iterable of ``(name, value)`` pairs, e.g. ``_make_hashable(some_dict)``.
    :returns: The rendered string. This is a single pass; placeholders produced
        by the substitution are not re-rendered.

    Only the top level is turned back into a ``dict``. Nested values keep their
    hashable form (lists as tuples, dicts as frozensets of pairs), so templates
    that index into nested mappings may not resolve as they would with the
    original objects.
    """
    # Deliberately no printing/logging here: this runs on every render, and
    # cache statistics are available on demand via ``_cached_render.cache_info()``.
    replacements = dict(frozen_replacements)
    template = _compile_template(text)
    return template.render(replacements)
|
13
55
|
|
14
56
|
class Prompt(PersistenceMixin, RepresentationMixin):
|
15
57
|
"""Class for creating a prompt to be used in a survey."""
|
@@ -145,33 +187,8 @@ class Prompt(PersistenceMixin, RepresentationMixin):
|
|
145
187
|
return f'Prompt(text="""{self.text}""")'
|
146
188
|
|
147
189
|
def template_variables(self) -> list[str]:
    """Return the names of the template variables used in this prompt's text.

    Example:

    >>> p = Prompt("Hello, {{person}}")
    >>> p.template_variables()
    ['person']

    A fresh list is returned on every call. The underlying lookup is memoized,
    and handing out the cached list itself would let a caller that mutates the
    result corrupt the cache for every other prompt with the same text.
    """
    return list(_find_template_variables(self.text))
|
175
192
|
|
176
193
|
def undefined_template_variables(self, replacement_dict: dict):
|
177
194
|
"""Return the variables in the template that are not in the replacement_dict.
|
@@ -239,45 +256,39 @@ class Prompt(PersistenceMixin, RepresentationMixin):
|
|
239
256
|
return self
|
240
257
|
|
241
258
|
@staticmethod
|
242
|
-
def _render(
|
243
|
-
text
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
Allows for nested variable resolution up to a specified maximum nesting depth.
|
252
|
-
|
253
|
-
Example:
|
254
|
-
|
255
|
-
>>> codebook = {"age": "Age"}
|
256
|
-
>>> p = Prompt("You are an agent named {{ name }}. {{ codebook['age']}}: {{ age }}")
|
257
|
-
>>> p.render({"name": "John", "age": 44}, codebook=codebook)
|
258
|
-
Prompt(text=\"""You are an agent named John. Age: 44\""")
|
259
|
-
"""
|
260
|
-
from jinja2 import Environment, meta, TemplateSyntaxError, Undefined
|
261
|
-
|
262
|
-
class PreserveUndefined(Undefined):
|
263
|
-
def __str__(self):
|
264
|
-
return "{{ " + str(self._undefined_name) + " }}"
|
265
|
-
|
266
|
-
env = Environment(undefined=PreserveUndefined)
|
259
|
+
def _render(text: str, primary_replacement, **additional_replacements) -> "PromptBase":
|
260
|
+
"""Render the template text with variables replaced."""
|
261
|
+
import time
|
262
|
+
|
263
|
+
# if there are no replacements, return the text
|
264
|
+
if not primary_replacement and not additional_replacements:
|
265
|
+
return text
|
266
|
+
|
267
267
|
try:
|
268
|
+
variables = _find_template_variables(text)
|
269
|
+
|
270
|
+
if not variables: # if there are no variables, return the text
|
271
|
+
return text
|
272
|
+
|
273
|
+
# Combine all replacements
|
274
|
+
all_replacements = {**primary_replacement, **additional_replacements}
|
275
|
+
|
268
276
|
previous_text = None
|
277
|
+
current_text = text
|
278
|
+
iteration = 0
|
279
|
+
|
269
280
|
for _ in range(MAX_NESTING):
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
)
|
274
|
-
|
275
|
-
|
281
|
+
iteration += 1
|
282
|
+
|
283
|
+
template = _compile_template(current_text)
|
284
|
+
rendered_text = template.render(all_replacements)
|
285
|
+
|
286
|
+
if rendered_text == current_text:
|
276
287
|
return rendered_text
|
277
|
-
|
278
|
-
|
288
|
+
|
289
|
+
previous_text = current_text
|
290
|
+
current_text = rendered_text
|
279
291
|
|
280
|
-
# If the loop exits without returning, it indicates too much nesting
|
281
292
|
raise TemplateRenderError(
|
282
293
|
"Too much nesting - you created an infinite loop here, pal"
|
283
294
|
)
|
@@ -331,6 +342,58 @@ class Prompt(PersistenceMixin, RepresentationMixin):
|
|
331
342
|
"""Return an example of the prompt."""
|
332
343
|
return cls(cls.default_instructions)
|
333
344
|
|
345
|
+
def get_prompts(self) -> "Dict[str, Any]":
    """Build the prompt components and return the prompt dictionary.

    Steps:

    1. Read ``agent_instructions_prompt``, ``agent_persona_prompt``,
       ``question_instructions_prompt`` and ``prior_question_memory_prompt``
       from ``self``. Each must expose ``.text``.
    2. Pass their texts as keyword arguments to ``self.prompt_plan.get_prompts``.
    3. When ``self.question_file_keys`` exists and is non-empty, add
       ``prompts["files_list"]`` with the matching ``self.scenario`` entries.

    The time spent in each step is logged at DEBUG level.

    NOTE(review): none of these attributes are visible on ``Prompt`` in this
    module; this looks like it belongs to a prompt-constructor class. Confirm.

    :returns: The dict returned by ``self.prompt_plan.get_prompts``, possibly
        extended with a ``"files_list"`` entry.
    """
    # Local imports: ``time`` and ``logger`` are not visibly imported at module
    # level (``_render`` imports ``time`` locally), so referencing them bare
    # risks a NameError on the first call.
    import logging
    import time

    logger = logging.getLogger(__name__)
    start = time.perf_counter()

    def timed(label, compute):
        # Run ``compute`` and log the elapsed wall time under ``label``.
        t0 = time.perf_counter()
        result = compute()
        logger.debug("Time taken for %s: %.4fs", label, time.perf_counter() - t0)
        return result

    # Build all the components, in the original order.
    agent_instructions = timed("agent instructions", lambda: self.agent_instructions_prompt)
    agent_persona = timed("agent persona", lambda: self.agent_persona_prompt)
    question_instructions = timed("question instructions", lambda: self.question_instructions_prompt)
    prior_question_memory = timed("prior question memory", lambda: self.prior_question_memory_prompt)

    components = {
        "agent_instructions": agent_instructions.text,
        "agent_persona": agent_persona.text,
        "question_instructions": question_instructions.text,
        "prior_question_memory": prior_question_memory.text,
    }

    prompts = timed("prompt processing", lambda: self.prompt_plan.get_prompts(**components))

    # Attach any file-backed scenario entries referenced by the question.
    file_keys = getattr(self, "question_file_keys", None)
    if file_keys:
        prompts["files_list"] = timed(
            "file key processing",
            lambda: [self.scenario[key] for key in file_keys],
        )

    logger.debug("Total time in get_prompts: %.4fs", time.perf_counter() - start)
    return prompts
|
334
397
|
|
335
398
|
if __name__ == "__main__":
|
336
399
|
print("Running doctests...")
|
edsl/questions/descriptors.py
CHANGED
@@ -249,7 +249,28 @@ class QuestionNameDescriptor(BaseDescriptor):
|
|
249
249
|
|
250
250
|
|
251
251
|
class QuestionOptionsDescriptor(BaseDescriptor):
|
252
|
-
"""Validate that `question_options` is a list, does not exceed the min/max lengths, and has unique items.
|
252
|
+
"""Validate that `question_options` is a list, does not exceed the min/max lengths, and has unique items.
|
253
|
+
|
254
|
+
>>> import warnings
|
255
|
+
>>> q_class = QuestionOptionsDescriptor.example()
|
256
|
+
>>> with warnings.catch_warnings(record=True) as w:
|
257
|
+
... _ = q_class(["a ", "b", "c"]) # Has trailing space
|
258
|
+
... assert len(w) == 1
|
259
|
+
... assert "trailing whitespace" in str(w[0].message)
|
260
|
+
|
261
|
+
>>> _ = q_class(["a", "b", "c", "d", "d"])
|
262
|
+
Traceback (most recent call last):
|
263
|
+
...
|
264
|
+
edsl.exceptions.questions.QuestionCreationValidationError: Question options must be unique (got ['a', 'b', 'c', 'd', 'd']).
|
265
|
+
|
266
|
+
We allow dynamic question options, which are strings of the form '{{ question_options }}'.
|
267
|
+
|
268
|
+
>>> _ = q_class("{{dynamic_options}}")
|
269
|
+
>>> _ = q_class("dynamic_options")
|
270
|
+
Traceback (most recent call last):
|
271
|
+
...
|
272
|
+
edsl.exceptions.questions.QuestionCreationValidationError: ...
|
273
|
+
"""
|
253
274
|
|
254
275
|
@classmethod
|
255
276
|
def example(cls):
|
@@ -273,23 +294,7 @@ class QuestionOptionsDescriptor(BaseDescriptor):
|
|
273
294
|
self.q_budget = q_budget
|
274
295
|
|
275
296
|
def validate(self, value: Any, instance) -> None:
|
276
|
-
"""Validate the question options.
|
277
|
-
|
278
|
-
>>> q_class = QuestionOptionsDescriptor.example()
|
279
|
-
>>> _ = q_class(["a", "b", "c"])
|
280
|
-
>>> _ = q_class(["a", "b", "c", "d", "d"])
|
281
|
-
Traceback (most recent call last):
|
282
|
-
...
|
283
|
-
edsl.exceptions.questions.QuestionCreationValidationError: Question options must be unique (got ['a', 'b', 'c', 'd', 'd']).
|
284
|
-
|
285
|
-
We allow dynamic question options, which are strings of the form '{{ question_options }}'.
|
286
|
-
|
287
|
-
>>> _ = q_class("{{dynamic_options}}")
|
288
|
-
>>> _ = q_class("dynamic_options")
|
289
|
-
Traceback (most recent call last):
|
290
|
-
...
|
291
|
-
edsl.exceptions.questions.QuestionCreationValidationError: ...
|
292
|
-
"""
|
297
|
+
"""Validate the question options."""
|
293
298
|
if isinstance(value, str):
|
294
299
|
# Check if the string is a dynamic question option
|
295
300
|
if "{{" in value and "}}" in value:
|
@@ -343,6 +348,15 @@ class QuestionOptionsDescriptor(BaseDescriptor):
|
|
343
348
|
f"All question options must be at least 1 character long but less than {Settings.MAX_OPTION_LENGTH} characters long (got {value})."
|
344
349
|
)
|
345
350
|
|
351
|
+
# Check for trailing whitespace in string options
|
352
|
+
if any(isinstance(x, str) and (x != x.strip()) for x in value):
|
353
|
+
import warnings
|
354
|
+
|
355
|
+
warnings.warn(
|
356
|
+
"Some question options contain trailing whitespace. This may cause unexpected behavior.",
|
357
|
+
UserWarning,
|
358
|
+
)
|
359
|
+
|
346
360
|
if hasattr(instance, "min_selections") and instance.min_selections != None:
|
347
361
|
if instance.min_selections > len(value):
|
348
362
|
raise QuestionCreationValidationError(
|
@@ -7,7 +7,6 @@ from typing import Optional, Tuple, Union, List
|
|
7
7
|
|
8
8
|
from edsl.results.file_exports import CSVExport, ExcelExport, JSONLExport, SQLiteExport
|
9
9
|
|
10
|
-
|
11
10
|
class DatasetExportMixin:
|
12
11
|
"""Mixin class for exporting Dataset objects."""
|
13
12
|
|
@@ -220,23 +219,45 @@ class DatasetExportMixin:
|
|
220
219
|
)
|
221
220
|
return exporter.export()
|
222
221
|
|
223
|
-
def _db(self, remove_prefix: bool = True):
|
222
|
+
def _db(self, remove_prefix: bool = True, shape: str = "wide") -> "sqlalchemy.engine.Engine":
|
224
223
|
"""Create a SQLite database in memory and return the connection.
|
225
224
|
|
226
225
|
Args:
|
227
|
-
shape: The shape of the data in the database (wide or long)
|
228
226
|
remove_prefix: Whether to remove the prefix from the column names
|
227
|
+
shape: The shape of the data in the database ("wide" or "long")
|
229
228
|
|
230
229
|
Returns:
|
231
230
|
A database connection
|
231
|
+
>>> from sqlalchemy import text
|
232
|
+
>>> from edsl import Results
|
233
|
+
>>> engine = Results.example()._db()
|
234
|
+
>>> len(engine.execute(text("SELECT * FROM self")).fetchall())
|
235
|
+
4
|
236
|
+
>>> engine = Results.example()._db(shape = "long")
|
237
|
+
>>> len(engine.execute(text("SELECT * FROM self")).fetchall())
|
238
|
+
172
|
232
239
|
"""
|
233
|
-
from sqlalchemy import create_engine
|
240
|
+
from sqlalchemy import create_engine, text
|
234
241
|
|
235
242
|
engine = create_engine("sqlite:///:memory:")
|
236
|
-
if remove_prefix:
|
243
|
+
if remove_prefix and shape == "wide":
|
237
244
|
df = self.remove_prefix().to_pandas(lists_as_strings=True)
|
238
245
|
else:
|
239
246
|
df = self.to_pandas(lists_as_strings=True)
|
247
|
+
|
248
|
+
if shape == "long":
|
249
|
+
# Melt the dataframe to convert it to long format
|
250
|
+
df = df.melt(
|
251
|
+
var_name='key',
|
252
|
+
value_name='value'
|
253
|
+
)
|
254
|
+
# Add a row number column for reference
|
255
|
+
df.insert(0, 'row_number', range(1, len(df) + 1))
|
256
|
+
|
257
|
+
# Split the key into data_type and key
|
258
|
+
df['data_type'] = df['key'].apply(lambda x: x.split('.')[0] if '.' in x else None)
|
259
|
+
df['key'] = df['key'].apply(lambda x: '.'.join(x.split('.')[1:]) if '.' in x else x)
|
260
|
+
|
240
261
|
df.to_sql(
|
241
262
|
"self",
|
242
263
|
engine,
|
@@ -251,6 +272,7 @@ class DatasetExportMixin:
|
|
251
272
|
transpose: bool = None,
|
252
273
|
transpose_by: str = None,
|
253
274
|
remove_prefix: bool = True,
|
275
|
+
shape: str = "wide",
|
254
276
|
) -> Union["pd.DataFrame", str]:
|
255
277
|
"""Execute a SQL query and return the results as a DataFrame.
|
256
278
|
|
@@ -268,10 +290,17 @@ class DatasetExportMixin:
|
|
268
290
|
Returns:
|
269
291
|
DataFrame, CSV string, list, or LaTeX string depending on parameters
|
270
292
|
|
293
|
+
Examples:
|
294
|
+
>>> from edsl import Results
|
295
|
+
>>> r = Results.example();
|
296
|
+
>>> len(r.sql("SELECT * FROM self", shape = "wide"))
|
297
|
+
4
|
298
|
+
>>> len(r.sql("SELECT * FROM self", shape = "long"))
|
299
|
+
172
|
271
300
|
"""
|
272
301
|
import pandas as pd
|
273
302
|
|
274
|
-
conn = self._db(remove_prefix=remove_prefix)
|
303
|
+
conn = self._db(remove_prefix=remove_prefix, shape=shape)
|
275
304
|
df = pd.read_sql_query(query, conn)
|
276
305
|
|
277
306
|
# Transpose the DataFrame if transpose is True
|