edsl 0.1.39.dev1__py3-none-any.whl → 0.1.39.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +169 -116
- edsl/__init__.py +14 -6
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +358 -146
- edsl/agents/AgentList.py +211 -73
- edsl/agents/Invigilator.py +88 -36
- edsl/agents/InvigilatorBase.py +59 -70
- edsl/agents/PromptConstructor.py +117 -219
- edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
- edsl/agents/QuestionOptionProcessor.py +172 -0
- edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
- edsl/agents/__init__.py +0 -1
- edsl/agents/prompt_helpers.py +3 -3
- edsl/config.py +22 -2
- edsl/conversation/car_buying.py +2 -1
- edsl/coop/CoopFunctionsMixin.py +15 -0
- edsl/coop/ExpectedParrotKeyHandler.py +125 -0
- edsl/coop/PriceFetcher.py +1 -1
- edsl/coop/coop.py +104 -42
- edsl/coop/utils.py +14 -14
- edsl/data/Cache.py +21 -14
- edsl/data/CacheEntry.py +12 -15
- edsl/data/CacheHandler.py +33 -12
- edsl/data/__init__.py +4 -3
- edsl/data_transfer_models.py +2 -1
- edsl/enums.py +20 -0
- edsl/exceptions/__init__.py +50 -50
- edsl/exceptions/agents.py +12 -0
- edsl/exceptions/inference_services.py +5 -0
- edsl/exceptions/questions.py +24 -6
- edsl/exceptions/scenarios.py +7 -0
- edsl/inference_services/AnthropicService.py +0 -3
- edsl/inference_services/AvailableModelCacheHandler.py +184 -0
- edsl/inference_services/AvailableModelFetcher.py +209 -0
- edsl/inference_services/AwsBedrock.py +0 -2
- edsl/inference_services/AzureAI.py +0 -2
- edsl/inference_services/GoogleService.py +2 -11
- edsl/inference_services/InferenceServiceABC.py +18 -85
- edsl/inference_services/InferenceServicesCollection.py +105 -80
- edsl/inference_services/MistralAIService.py +0 -3
- edsl/inference_services/OpenAIService.py +1 -4
- edsl/inference_services/PerplexityService.py +0 -3
- edsl/inference_services/ServiceAvailability.py +135 -0
- edsl/inference_services/TestService.py +11 -8
- edsl/inference_services/data_structures.py +62 -0
- edsl/jobs/AnswerQuestionFunctionConstructor.py +188 -0
- edsl/jobs/Answers.py +1 -14
- edsl/jobs/FetchInvigilator.py +40 -0
- edsl/jobs/InterviewTaskManager.py +98 -0
- edsl/jobs/InterviewsConstructor.py +48 -0
- edsl/jobs/Jobs.py +102 -243
- edsl/jobs/JobsChecks.py +35 -10
- edsl/jobs/JobsComponentConstructor.py +189 -0
- edsl/jobs/JobsPrompts.py +5 -3
- edsl/jobs/JobsRemoteInferenceHandler.py +128 -80
- edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
- edsl/jobs/RequestTokenEstimator.py +30 -0
- edsl/jobs/buckets/BucketCollection.py +44 -3
- edsl/jobs/buckets/TokenBucket.py +53 -21
- edsl/jobs/buckets/TokenBucketAPI.py +211 -0
- edsl/jobs/buckets/TokenBucketClient.py +191 -0
- edsl/jobs/decorators.py +35 -0
- edsl/jobs/interviews/Interview.py +77 -380
- edsl/jobs/jobs_status_enums.py +9 -0
- edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +4 -49
- edsl/jobs/tasks/QuestionTaskCreator.py +21 -19
- edsl/jobs/tasks/TaskHistory.py +14 -15
- edsl/jobs/tasks/task_status_enum.py +0 -2
- edsl/language_models/ComputeCost.py +63 -0
- edsl/language_models/LanguageModel.py +137 -234
- edsl/language_models/ModelList.py +11 -13
- edsl/language_models/PriceManager.py +127 -0
- edsl/language_models/RawResponseHandler.py +106 -0
- edsl/language_models/ServiceDataSources.py +0 -0
- edsl/language_models/__init__.py +0 -1
- edsl/language_models/key_management/KeyLookup.py +63 -0
- edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
- edsl/language_models/key_management/KeyLookupCollection.py +38 -0
- edsl/language_models/key_management/__init__.py +0 -0
- edsl/language_models/key_management/models.py +131 -0
- edsl/language_models/registry.py +49 -59
- edsl/language_models/repair.py +2 -2
- edsl/language_models/utilities.py +5 -4
- edsl/notebooks/Notebook.py +19 -14
- edsl/notebooks/NotebookToLaTeX.py +142 -0
- edsl/prompts/Prompt.py +29 -39
- edsl/questions/AnswerValidatorMixin.py +47 -2
- edsl/questions/ExceptionExplainer.py +77 -0
- edsl/questions/HTMLQuestion.py +103 -0
- edsl/questions/LoopProcessor.py +149 -0
- edsl/questions/QuestionBase.py +37 -192
- edsl/questions/QuestionBaseGenMixin.py +52 -48
- edsl/questions/QuestionBasePromptsMixin.py +7 -3
- edsl/questions/QuestionCheckBox.py +1 -1
- edsl/questions/QuestionExtract.py +1 -1
- edsl/questions/QuestionFreeText.py +1 -2
- edsl/questions/QuestionList.py +3 -5
- edsl/questions/QuestionMatrix.py +265 -0
- edsl/questions/QuestionMultipleChoice.py +66 -22
- edsl/questions/QuestionNumerical.py +1 -3
- edsl/questions/QuestionRank.py +6 -16
- edsl/questions/ResponseValidatorABC.py +37 -11
- edsl/questions/ResponseValidatorFactory.py +28 -0
- edsl/questions/SimpleAskMixin.py +4 -3
- edsl/questions/__init__.py +1 -0
- edsl/questions/derived/QuestionLinearScale.py +6 -3
- edsl/questions/derived/QuestionTopK.py +1 -1
- edsl/questions/descriptors.py +17 -3
- edsl/questions/question_registry.py +1 -1
- edsl/questions/templates/matrix/__init__.py +1 -0
- edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
- edsl/questions/templates/matrix/question_presentation.jinja +20 -0
- edsl/results/CSSParameterizer.py +1 -1
- edsl/results/Dataset.py +170 -7
- edsl/results/DatasetExportMixin.py +224 -302
- edsl/results/DatasetTree.py +28 -8
- edsl/results/MarkdownToDocx.py +122 -0
- edsl/results/MarkdownToPDF.py +111 -0
- edsl/results/Result.py +192 -206
- edsl/results/Results.py +120 -113
- edsl/results/ResultsExportMixin.py +2 -0
- edsl/results/Selector.py +23 -13
- edsl/results/TableDisplay.py +98 -171
- edsl/results/TextEditor.py +50 -0
- edsl/results/__init__.py +1 -1
- edsl/results/smart_objects.py +96 -0
- edsl/results/table_data_class.py +12 -0
- edsl/results/table_renderers.py +118 -0
- edsl/scenarios/ConstructDownloadLink.py +109 -0
- edsl/scenarios/DirectoryScanner.py +96 -0
- edsl/scenarios/DocumentChunker.py +102 -0
- edsl/scenarios/DocxScenario.py +16 -0
- edsl/scenarios/FileStore.py +118 -239
- edsl/scenarios/PdfExtractor.py +40 -0
- edsl/scenarios/Scenario.py +90 -193
- edsl/scenarios/ScenarioHtmlMixin.py +4 -3
- edsl/scenarios/ScenarioJoin.py +10 -6
- edsl/scenarios/ScenarioList.py +383 -240
- edsl/scenarios/ScenarioListExportMixin.py +0 -7
- edsl/scenarios/ScenarioListPdfMixin.py +15 -37
- edsl/scenarios/ScenarioSelector.py +156 -0
- edsl/scenarios/__init__.py +1 -2
- edsl/scenarios/file_methods.py +85 -0
- edsl/scenarios/handlers/__init__.py +13 -0
- edsl/scenarios/handlers/csv.py +38 -0
- edsl/scenarios/handlers/docx.py +76 -0
- edsl/scenarios/handlers/html.py +37 -0
- edsl/scenarios/handlers/json.py +111 -0
- edsl/scenarios/handlers/latex.py +5 -0
- edsl/scenarios/handlers/md.py +51 -0
- edsl/scenarios/handlers/pdf.py +68 -0
- edsl/scenarios/handlers/png.py +39 -0
- edsl/scenarios/handlers/pptx.py +105 -0
- edsl/scenarios/handlers/py.py +294 -0
- edsl/scenarios/handlers/sql.py +313 -0
- edsl/scenarios/handlers/sqlite.py +149 -0
- edsl/scenarios/handlers/txt.py +33 -0
- edsl/study/ObjectEntry.py +1 -1
- edsl/study/SnapShot.py +1 -1
- edsl/study/Study.py +5 -12
- edsl/surveys/ConstructDAG.py +92 -0
- edsl/surveys/EditSurvey.py +221 -0
- edsl/surveys/InstructionHandler.py +100 -0
- edsl/surveys/MemoryManagement.py +72 -0
- edsl/surveys/Rule.py +5 -4
- edsl/surveys/RuleCollection.py +25 -27
- edsl/surveys/RuleManager.py +172 -0
- edsl/surveys/Simulator.py +75 -0
- edsl/surveys/Survey.py +199 -771
- edsl/surveys/SurveyCSS.py +20 -8
- edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +11 -9
- edsl/surveys/SurveyToApp.py +141 -0
- edsl/surveys/__init__.py +4 -2
- edsl/surveys/descriptors.py +6 -2
- edsl/surveys/instructions/ChangeInstruction.py +1 -2
- edsl/surveys/instructions/Instruction.py +4 -13
- edsl/surveys/instructions/InstructionCollection.py +11 -6
- edsl/templates/error_reporting/interview_details.html +1 -1
- edsl/templates/error_reporting/report.html +1 -1
- edsl/tools/plotting.py +1 -1
- edsl/utilities/PrettyList.py +56 -0
- edsl/utilities/is_notebook.py +18 -0
- edsl/utilities/is_valid_variable_name.py +11 -0
- edsl/utilities/remove_edsl_version.py +24 -0
- edsl/utilities/utilities.py +35 -23
- {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/METADATA +12 -10
- edsl-0.1.39.dev2.dist-info/RECORD +352 -0
- edsl/language_models/KeyLookup.py +0 -30
- edsl/language_models/unused/ReplicateBase.py +0 -83
- edsl/results/ResultsDBMixin.py +0 -238
- edsl-0.1.39.dev1.dist-info/RECORD +0 -277
- {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/LICENSE +0 -0
- {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/WHEEL +0 -0
edsl/results/Results.py
CHANGED
@@ -9,13 +9,7 @@ import random
|
|
9
9
|
from collections import UserList, defaultdict
|
10
10
|
from typing import Optional, Callable, Any, Type, Union, List, TYPE_CHECKING
|
11
11
|
|
12
|
-
|
13
|
-
from edsl import Survey, Cache, AgentList, ModelList, ScenarioList
|
14
|
-
from edsl.results.Result import Result
|
15
|
-
from edsl.jobs.tasks.TaskHistory import TaskHistory
|
16
|
-
|
17
|
-
from simpleeval import EvalWithCompoundTypes
|
18
|
-
|
12
|
+
from edsl.Base import Base
|
19
13
|
from edsl.exceptions.results import (
|
20
14
|
ResultsError,
|
21
15
|
ResultsBadMutationstringError,
|
@@ -26,25 +20,27 @@ from edsl.exceptions.results import (
|
|
26
20
|
ResultsDeserializationError,
|
27
21
|
)
|
28
22
|
|
23
|
+
if TYPE_CHECKING:
|
24
|
+
from edsl.surveys.Survey import Survey
|
25
|
+
from edsl.data.Cache import Cache
|
26
|
+
from edsl.agents.AgentList import AgentList
|
27
|
+
from edsl.language_models.registry import Model
|
28
|
+
from edsl.scenarios.ScenarioList import ScenarioList
|
29
|
+
from edsl.results.Result import Result
|
30
|
+
from edsl.jobs.tasks.TaskHistory import TaskHistory
|
31
|
+
from edsl.language_models.ModelList import ModelList
|
32
|
+
from simpleeval import EvalWithCompoundTypes
|
33
|
+
|
29
34
|
from edsl.results.ResultsExportMixin import ResultsExportMixin
|
30
|
-
from edsl.results.ResultsToolsMixin import ResultsToolsMixin
|
31
|
-
from edsl.results.ResultsDBMixin import ResultsDBMixin
|
32
35
|
from edsl.results.ResultsGGMixin import ResultsGGMixin
|
33
36
|
from edsl.results.ResultsFetchMixin import ResultsFetchMixin
|
34
|
-
|
35
|
-
from edsl.utilities.decorators import remove_edsl_version
|
36
|
-
from edsl.utilities.utilities import dict_hash
|
37
|
-
|
38
|
-
|
39
|
-
from edsl.Base import Base
|
37
|
+
from edsl.utilities.remove_edsl_version import remove_edsl_version
|
40
38
|
|
41
39
|
|
42
40
|
class Mixins(
|
43
41
|
ResultsExportMixin,
|
44
|
-
ResultsDBMixin,
|
45
42
|
ResultsFetchMixin,
|
46
43
|
ResultsGGMixin,
|
47
|
-
ResultsToolsMixin,
|
48
44
|
):
|
49
45
|
def long(self):
|
50
46
|
return self.table().long()
|
@@ -91,6 +87,7 @@ class Results(UserList, Mixins, Base):
|
|
91
87
|
"question_type",
|
92
88
|
"comment",
|
93
89
|
"generated_tokens",
|
90
|
+
"cache_used",
|
94
91
|
]
|
95
92
|
|
96
93
|
def __init__(
|
@@ -129,18 +126,13 @@ class Results(UserList, Mixins, Base):
|
|
129
126
|
def _summary(self) -> dict:
|
130
127
|
import reprlib
|
131
128
|
|
132
|
-
# import yaml
|
133
|
-
|
134
129
|
d = {
|
135
|
-
"
|
136
|
-
|
137
|
-
"
|
138
|
-
"
|
139
|
-
"
|
140
|
-
"# Scenarios": len(set(self.scenarios)),
|
141
|
-
"Survey Length (# questions)": len(self.survey),
|
130
|
+
"observations": len(self),
|
131
|
+
"agents": len(set(self.agents)),
|
132
|
+
"models": len(set(self.models)),
|
133
|
+
"scenarios": len(set(self.scenarios)),
|
134
|
+
"questions": len(self.survey),
|
142
135
|
"Survey question names": reprlib.repr(self.survey.question_names),
|
143
|
-
"Object hash": hash(self),
|
144
136
|
}
|
145
137
|
return d
|
146
138
|
|
@@ -258,23 +250,23 @@ class Results(UserList, Mixins, Base):
|
|
258
250
|
|
259
251
|
raise TypeError("Invalid argument type")
|
260
252
|
|
261
|
-
def _update_results(self) -> None:
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
253
|
+
# def _update_results(self) -> None:
|
254
|
+
# from edsl import Agent, Scenario
|
255
|
+
# from edsl.language_models import LanguageModel
|
256
|
+
# from edsl.results import Result
|
257
|
+
|
258
|
+
# if self._job_uuid and len(self.data) < self._total_results:
|
259
|
+
# results = [
|
260
|
+
# Result(
|
261
|
+
# agent=Agent.from_dict(json.loads(r.agent)),
|
262
|
+
# scenario=Scenario.from_dict(json.loads(r.scenario)),
|
263
|
+
# model=LanguageModel.from_dict(json.loads(r.model)),
|
264
|
+
# iteration=1,
|
265
|
+
# answer=json.loads(r.answer),
|
266
|
+
# )
|
267
|
+
# for r in CRUD.read_results(self._job_uuid)
|
268
|
+
# ]
|
269
|
+
# self.data = results
|
278
270
|
|
279
271
|
def __add__(self, other: Results) -> Results:
|
280
272
|
"""Add two Results objects together.
|
@@ -303,9 +295,9 @@ class Results(UserList, Mixins, Base):
|
|
303
295
|
)
|
304
296
|
|
305
297
|
def __repr__(self) -> str:
|
306
|
-
import reprlib
|
298
|
+
# import reprlib
|
307
299
|
|
308
|
-
return f"Results(data = {
|
300
|
+
return f"Results(data = {self.data}, survey = {repr(self.survey)}, created_columns = {self.created_columns})"
|
309
301
|
|
310
302
|
def table(
|
311
303
|
self,
|
@@ -345,21 +337,6 @@ class Results(UserList, Mixins, Base):
|
|
345
337
|
print_parameters=print_parameters,
|
346
338
|
)
|
347
339
|
)
|
348
|
-
# return (
|
349
|
-
# self.select(f"{selector_string}")
|
350
|
-
# .to_scenario_list()
|
351
|
-
# .table(*fields, tablefmt=tablefmt)
|
352
|
-
# )
|
353
|
-
|
354
|
-
def _repr_html_(self) -> str:
|
355
|
-
d = self._summary()
|
356
|
-
from edsl import Scenario
|
357
|
-
|
358
|
-
footer = f"<a href={self.__documentation__}>(docs)</a>"
|
359
|
-
|
360
|
-
s = Scenario(d)
|
361
|
-
td = s.to_dataset().table(tablefmt="html")
|
362
|
-
return td._repr_html_() + footer
|
363
340
|
|
364
341
|
def to_dict(
|
365
342
|
self,
|
@@ -367,6 +344,7 @@ class Results(UserList, Mixins, Base):
|
|
367
344
|
add_edsl_version=False,
|
368
345
|
include_cache=False,
|
369
346
|
include_task_history=False,
|
347
|
+
include_cache_info=True,
|
370
348
|
) -> dict[str, Any]:
|
371
349
|
from edsl.data.Cache import Cache
|
372
350
|
|
@@ -377,7 +355,11 @@ class Results(UserList, Mixins, Base):
|
|
377
355
|
|
378
356
|
d = {
|
379
357
|
"data": [
|
380
|
-
result.to_dict(
|
358
|
+
result.to_dict(
|
359
|
+
add_edsl_version=add_edsl_version,
|
360
|
+
include_cache_info=include_cache_info,
|
361
|
+
)
|
362
|
+
for result in data
|
381
363
|
],
|
382
364
|
"survey": self.survey.to_dict(add_edsl_version=add_edsl_version),
|
383
365
|
"created_columns": self.created_columns,
|
@@ -426,7 +408,11 @@ class Results(UserList, Mixins, Base):
|
|
426
408
|
return self.task_history.has_unfixed_exceptions
|
427
409
|
|
428
410
|
def __hash__(self) -> int:
|
429
|
-
|
411
|
+
from edsl.utilities.utilities import dict_hash
|
412
|
+
|
413
|
+
return dict_hash(
|
414
|
+
self.to_dict(sort=True, add_edsl_version=False, include_cache_info=False)
|
415
|
+
)
|
430
416
|
|
431
417
|
@property
|
432
418
|
def hashes(self) -> set:
|
@@ -472,24 +458,31 @@ class Results(UserList, Mixins, Base):
|
|
472
458
|
>>> r == r2
|
473
459
|
True
|
474
460
|
"""
|
475
|
-
from edsl import Survey
|
461
|
+
from edsl.surveys.Survey import Survey
|
462
|
+
from edsl.data.Cache import Cache
|
476
463
|
from edsl.results.Result import Result
|
477
464
|
from edsl.jobs.tasks.TaskHistory import TaskHistory
|
465
|
+
from edsl.agents.Agent import Agent
|
466
|
+
|
467
|
+
survey = Survey.from_dict(data["survey"])
|
468
|
+
results_data = [Result.from_dict(r) for r in data["data"]]
|
469
|
+
created_columns = data.get("created_columns", None)
|
470
|
+
cache = Cache.from_dict(data.get("cache")) if "cache" in data else Cache()
|
471
|
+
task_history = (
|
472
|
+
TaskHistory.from_dict(data.get("task_history"))
|
473
|
+
if "task_history" in data
|
474
|
+
else TaskHistory(interviews=[])
|
475
|
+
)
|
476
|
+
params = {
|
477
|
+
"survey": survey,
|
478
|
+
"data": results_data,
|
479
|
+
"created_columns": created_columns,
|
480
|
+
"cache": cache,
|
481
|
+
"task_history": task_history,
|
482
|
+
}
|
478
483
|
|
479
484
|
try:
|
480
|
-
results = cls(
|
481
|
-
survey=Survey.from_dict(data["survey"]),
|
482
|
-
data=[Result.from_dict(r) for r in data["data"]],
|
483
|
-
created_columns=data.get("created_columns", None),
|
484
|
-
cache=(
|
485
|
-
Cache.from_dict(data.get("cache")) if "cache" in data else Cache()
|
486
|
-
),
|
487
|
-
task_history=(
|
488
|
-
TaskHistory.from_dict(data.get("task_history"))
|
489
|
-
if "task_history" in data
|
490
|
-
else TaskHistory(interviews=[])
|
491
|
-
),
|
492
|
-
)
|
485
|
+
results = cls(**params)
|
493
486
|
except Exception as e:
|
494
487
|
raise ResultsDeserializationError(f"Error in Results.from_dict: {e}")
|
495
488
|
return results
|
@@ -544,10 +537,12 @@ class Results(UserList, Mixins, Base):
|
|
544
537
|
|
545
538
|
>>> r = Results.example()
|
546
539
|
>>> r.columns
|
547
|
-
['agent.
|
540
|
+
['agent.agent_index', ...]
|
548
541
|
"""
|
549
542
|
column_names = [f"{v}.{k}" for k, v in self._key_to_data_type.items()]
|
550
|
-
|
543
|
+
from edsl.utilities.PrettyList import PrettyList
|
544
|
+
|
545
|
+
return PrettyList(sorted(column_names))
|
551
546
|
|
552
547
|
@property
|
553
548
|
def answer_keys(self) -> dict[str, str]:
|
@@ -567,7 +562,7 @@ class Results(UserList, Mixins, Base):
|
|
567
562
|
answer_keys = self._data_type_to_keys["answer"]
|
568
563
|
answer_keys = {k for k in answer_keys if "_comment" not in k}
|
569
564
|
questions_text = [
|
570
|
-
self.survey.
|
565
|
+
self.survey._get_question_by_name(k).question_text for k in answer_keys
|
571
566
|
]
|
572
567
|
short_question_text = [shorten_string(q, 80) for q in questions_text]
|
573
568
|
initial_dict = dict(zip(answer_keys, short_question_text))
|
@@ -584,7 +579,7 @@ class Results(UserList, Mixins, Base):
|
|
584
579
|
>>> r.agents
|
585
580
|
AgentList([Agent(traits = {'status': 'Joyful'}), Agent(traits = {'status': 'Joyful'}), Agent(traits = {'status': 'Sad'}), Agent(traits = {'status': 'Sad'})])
|
586
581
|
"""
|
587
|
-
from edsl import AgentList
|
582
|
+
from edsl.agents.AgentList import AgentList
|
588
583
|
|
589
584
|
return AgentList([r.agent for r in self.data])
|
590
585
|
|
@@ -598,10 +593,13 @@ class Results(UserList, Mixins, Base):
|
|
598
593
|
>>> r.models[0]
|
599
594
|
Model(model_name = ...)
|
600
595
|
"""
|
601
|
-
from edsl import ModelList
|
596
|
+
from edsl.language_models.ModelList import ModelList
|
602
597
|
|
603
598
|
return ModelList([r.model for r in self.data])
|
604
599
|
|
600
|
+
def __eq__(self, other):
|
601
|
+
return hash(self) == hash(other)
|
602
|
+
|
605
603
|
@property
|
606
604
|
def scenarios(self) -> ScenarioList:
|
607
605
|
"""Return a list of all of the scenarios in the Results.
|
@@ -610,9 +608,9 @@ class Results(UserList, Mixins, Base):
|
|
610
608
|
|
611
609
|
>>> r = Results.example()
|
612
610
|
>>> r.scenarios
|
613
|
-
ScenarioList([Scenario({'period': 'morning'}), Scenario({'period': 'afternoon'}), Scenario({'period': 'morning'}), Scenario({'period': 'afternoon'})])
|
611
|
+
ScenarioList([Scenario({'period': 'morning', 'scenario_index': 0}), Scenario({'period': 'afternoon', 'scenario_index': 1}), Scenario({'period': 'morning', 'scenario_index': 0}), Scenario({'period': 'afternoon', 'scenario_index': 1})])
|
614
612
|
"""
|
615
|
-
from edsl import ScenarioList
|
613
|
+
from edsl.scenarios.ScenarioList import ScenarioList
|
616
614
|
|
617
615
|
return ScenarioList([r.scenario for r in self.data])
|
618
616
|
|
@@ -624,7 +622,7 @@ class Results(UserList, Mixins, Base):
|
|
624
622
|
|
625
623
|
>>> r = Results.example()
|
626
624
|
>>> r.agent_keys
|
627
|
-
['agent_instruction', 'agent_name', 'status']
|
625
|
+
['agent_index', 'agent_instruction', 'agent_name', 'status']
|
628
626
|
"""
|
629
627
|
return sorted(self._data_type_to_keys["agent"])
|
630
628
|
|
@@ -634,7 +632,7 @@ class Results(UserList, Mixins, Base):
|
|
634
632
|
|
635
633
|
>>> r = Results.example()
|
636
634
|
>>> r.model_keys
|
637
|
-
['frequency_penalty', 'logprobs', 'max_tokens', 'model', 'presence_penalty', 'temperature', 'top_logprobs', 'top_p']
|
635
|
+
['frequency_penalty', 'logprobs', 'max_tokens', 'model', 'model_index', 'presence_penalty', 'temperature', 'top_logprobs', 'top_p']
|
638
636
|
"""
|
639
637
|
return sorted(self._data_type_to_keys["model"])
|
640
638
|
|
@@ -644,7 +642,7 @@ class Results(UserList, Mixins, Base):
|
|
644
642
|
|
645
643
|
>>> r = Results.example()
|
646
644
|
>>> r.scenario_keys
|
647
|
-
['period']
|
645
|
+
['period', 'scenario_index']
|
648
646
|
"""
|
649
647
|
return sorted(self._data_type_to_keys["scenario"])
|
650
648
|
|
@@ -670,7 +668,7 @@ class Results(UserList, Mixins, Base):
|
|
670
668
|
|
671
669
|
>>> r = Results.example()
|
672
670
|
>>> r.all_keys
|
673
|
-
['
|
671
|
+
['agent_index', ...]
|
674
672
|
"""
|
675
673
|
answer_keys = set(self.answer_keys)
|
676
674
|
all_keys = (
|
@@ -777,7 +775,7 @@ class Results(UserList, Mixins, Base):
|
|
777
775
|
@staticmethod
|
778
776
|
def _create_evaluator(
|
779
777
|
result: Result, functions_dict: Optional[dict] = None
|
780
|
-
) -> EvalWithCompoundTypes:
|
778
|
+
) -> "EvalWithCompoundTypes":
|
781
779
|
"""Create an evaluator for the expression.
|
782
780
|
|
783
781
|
>>> from unittest.mock import Mock
|
@@ -800,6 +798,8 @@ class Results(UserList, Mixins, Base):
|
|
800
798
|
...
|
801
799
|
simpleeval.NameNotDefined: 'how_feeling' is not defined for expression 'how_feeling== 'OK''
|
802
800
|
"""
|
801
|
+
from simpleeval import EvalWithCompoundTypes
|
802
|
+
|
803
803
|
if functions_dict is None:
|
804
804
|
functions_dict = {}
|
805
805
|
evaluator = EvalWithCompoundTypes(
|
@@ -858,6 +858,26 @@ class Results(UserList, Mixins, Base):
|
|
858
858
|
created_columns=self.created_columns + [var_name],
|
859
859
|
)
|
860
860
|
|
861
|
+
def add_column(self, column_name: str, values: list) -> Results:
|
862
|
+
"""Adds columns to Results
|
863
|
+
|
864
|
+
>>> r = Results.example()
|
865
|
+
>>> r.add_column('a', [1,2,3, 4]).select('a')
|
866
|
+
Dataset([{'answer.a': [1, 2, 3, 4]}])
|
867
|
+
"""
|
868
|
+
|
869
|
+
assert len(values) == len(
|
870
|
+
self.data
|
871
|
+
), "The number of values must match the number of results."
|
872
|
+
new_results = self.data.copy()
|
873
|
+
for i, result in enumerate(new_results):
|
874
|
+
result["answer"][column_name] = values[i]
|
875
|
+
return Results(
|
876
|
+
survey=self.survey,
|
877
|
+
data=new_results,
|
878
|
+
created_columns=self.created_columns + [column_name],
|
879
|
+
)
|
880
|
+
|
861
881
|
def rename(self, old_name: str, new_name: str) -> Results:
|
862
882
|
"""Rename an answer column in a Results object.
|
863
883
|
|
@@ -987,20 +1007,12 @@ class Results(UserList, Mixins, Base):
|
|
987
1007
|
Example:
|
988
1008
|
|
989
1009
|
>>> r = Results.example()
|
990
|
-
>>> r.sort_by('how_feeling', reverse=False).select('how_feeling')
|
991
|
-
answer.how_feeling
|
992
|
-
|
993
|
-
|
994
|
-
OK
|
995
|
-
|
996
|
-
Terrible
|
997
|
-
>>> r.sort_by('how_feeling', reverse=True).select('how_feeling').print()
|
998
|
-
answer.how_feeling
|
999
|
-
--------------------
|
1000
|
-
Terrible
|
1001
|
-
OK
|
1002
|
-
OK
|
1003
|
-
Great
|
1010
|
+
>>> r.sort_by('how_feeling', reverse=False).select('how_feeling')
|
1011
|
+
Dataset([{'answer.how_feeling': ['Great', 'OK', 'OK', 'Terrible']}])
|
1012
|
+
|
1013
|
+
>>> r.sort_by('how_feeling', reverse=True).select('how_feeling')
|
1014
|
+
Dataset([{'answer.how_feeling': ['Terrible', 'OK', 'OK', 'Great']}])
|
1015
|
+
|
1004
1016
|
"""
|
1005
1017
|
|
1006
1018
|
def to_numeric_if_possible(v):
|
@@ -1032,24 +1044,19 @@ class Results(UserList, Mixins, Base):
|
|
1032
1044
|
Example usage: Create an example `Results` instance and apply filters to it:
|
1033
1045
|
|
1034
1046
|
>>> r = Results.example()
|
1035
|
-
>>> r.filter("how_feeling == 'Great'").select('how_feeling')
|
1036
|
-
answer.how_feeling
|
1037
|
-
--------------------
|
1038
|
-
Great
|
1047
|
+
>>> r.filter("how_feeling == 'Great'").select('how_feeling')
|
1048
|
+
Dataset([{'answer.how_feeling': ['Great']}])
|
1039
1049
|
|
1040
1050
|
Example usage: Using an OR operator in the filter expression.
|
1041
1051
|
|
1042
|
-
>>> r = Results.example().filter("how_feeling = 'Great'").select('how_feeling')
|
1052
|
+
>>> r = Results.example().filter("how_feeling = 'Great'").select('how_feeling')
|
1043
1053
|
Traceback (most recent call last):
|
1044
1054
|
...
|
1045
1055
|
edsl.exceptions.results.ResultsFilterError: You must use '==' instead of '=' in the filter expression.
|
1046
1056
|
...
|
1047
1057
|
|
1048
|
-
>>> r.filter("how_feeling == 'Great' or how_feeling == 'Terrible'").select('how_feeling')
|
1049
|
-
answer.how_feeling
|
1050
|
-
--------------------
|
1051
|
-
Great
|
1052
|
-
Terrible
|
1058
|
+
>>> r.filter("how_feeling == 'Great' or how_feeling == 'Terrible'").select('how_feeling')
|
1059
|
+
Dataset([{'answer.how_feeling': ['Great', 'Terrible']}])
|
1053
1060
|
"""
|
1054
1061
|
|
1055
1062
|
def has_single_equals(string):
|
@@ -14,6 +14,8 @@ def to_dataset(func):
|
|
14
14
|
"""Return the function with the Results object converted to a Dataset object."""
|
15
15
|
if self.__class__.__name__ == "Results":
|
16
16
|
return func(self.select(), *args, **kwargs)
|
17
|
+
elif self.__class__.__name__ == "AgentList":
|
18
|
+
return func(self.to_dataset(), *args, **kwargs)
|
17
19
|
else:
|
18
20
|
return func(self, *args, **kwargs)
|
19
21
|
|
edsl/results/Selector.py
CHANGED
@@ -1,7 +1,12 @@
|
|
1
|
-
from typing import Union, List, Dict, Any
|
1
|
+
from typing import Union, List, Dict, Any, Optional
|
2
|
+
import sys
|
2
3
|
from collections import defaultdict
|
3
4
|
from edsl.results.Dataset import Dataset
|
4
5
|
|
6
|
+
from edsl.exceptions.results import ResultsColumnNotFoundError
|
7
|
+
|
8
|
+
from edsl.utilities.is_notebook import is_notebook
|
9
|
+
|
5
10
|
|
6
11
|
class Selector:
|
7
12
|
def __init__(
|
@@ -19,11 +24,17 @@ class Selector:
|
|
19
24
|
self._fetch_list = fetch_list_func
|
20
25
|
self.columns = columns
|
21
26
|
|
22
|
-
def select(self, *columns: Union[str, List[str]]) ->
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
+
def select(self, *columns: Union[str, List[str]]) -> Optional[Dataset]:
|
28
|
+
try:
|
29
|
+
columns = self._normalize_columns(columns)
|
30
|
+
to_fetch = self._get_columns_to_fetch(columns)
|
31
|
+
new_data = self._fetch_data(to_fetch)
|
32
|
+
except ResultsColumnNotFoundError as e:
|
33
|
+
if is_notebook():
|
34
|
+
print("Error:", e, file=sys.stderr)
|
35
|
+
return None
|
36
|
+
else:
|
37
|
+
raise e
|
27
38
|
return Dataset(new_data)
|
28
39
|
|
29
40
|
def _normalize_columns(self, columns: Union[str, List[str]]) -> tuple:
|
@@ -63,17 +74,16 @@ class Selector:
|
|
63
74
|
search_in_list = self.columns
|
64
75
|
else:
|
65
76
|
search_in_list = [s.split(".")[1] for s in self.columns]
|
66
|
-
# breakpoint()
|
67
77
|
matches = [s for s in search_in_list if s.startswith(partial_name)]
|
68
78
|
return [partial_name] if partial_name in matches else matches
|
69
79
|
|
70
80
|
def _validate_matches(self, column: str, matches: List[str]):
|
71
81
|
if len(matches) > 1:
|
72
|
-
raise
|
82
|
+
raise ResultsColumnNotFoundError(
|
73
83
|
f"Column '{column}' is ambiguous. Did you mean one of {matches}?"
|
74
84
|
)
|
75
85
|
if len(matches) == 0 and ".*" not in column:
|
76
|
-
raise
|
86
|
+
raise ResultsColumnNotFoundError(f"Column '{column}' not found in data.")
|
77
87
|
|
78
88
|
def _parse_column(self, column: str) -> tuple[str, str]:
|
79
89
|
if "." in column:
|
@@ -89,11 +99,11 @@ class Selector:
|
|
89
99
|
close_matches = difflib.get_close_matches(column, self._key_to_data_type.keys())
|
90
100
|
if close_matches:
|
91
101
|
suggestions = ", ".join(close_matches)
|
92
|
-
raise
|
102
|
+
raise ResultsColumnNotFoundError(
|
93
103
|
f"Column '{column}' not found in data. Did you mean: {suggestions}?"
|
94
104
|
)
|
95
105
|
else:
|
96
|
-
raise
|
106
|
+
raise ResultsColumnNotFoundError(f"Column {column} not found in data")
|
97
107
|
|
98
108
|
def _process_column(self, data_type: str, key: str, to_fetch: Dict[str, List[str]]):
|
99
109
|
data_types = self._get_data_types_to_return(data_type)
|
@@ -108,13 +118,13 @@ class Selector:
|
|
108
118
|
self.items_in_order.append(f"{dt}.{k}")
|
109
119
|
|
110
120
|
if not found_once:
|
111
|
-
raise
|
121
|
+
raise ResultsColumnNotFoundError(f"Key {key} not found in data.")
|
112
122
|
|
113
123
|
def _get_data_types_to_return(self, parsed_data_type: str) -> List[str]:
|
114
124
|
if parsed_data_type == "*":
|
115
125
|
return self.known_data_types
|
116
126
|
if parsed_data_type not in self.known_data_types:
|
117
|
-
raise
|
127
|
+
raise ResultsColumnNotFoundError(
|
118
128
|
f"Data type {parsed_data_type} not found in data. Did you mean one of {self.known_data_types}"
|
119
129
|
)
|
120
130
|
return [parsed_data_type]
|