edsl 0.1.39.dev1__py3-none-any.whl → 0.1.39.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (194):
  1. edsl/Base.py +169 -116
  2. edsl/__init__.py +14 -6
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +358 -146
  5. edsl/agents/AgentList.py +211 -73
  6. edsl/agents/Invigilator.py +88 -36
  7. edsl/agents/InvigilatorBase.py +59 -70
  8. edsl/agents/PromptConstructor.py +117 -219
  9. edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
  10. edsl/agents/QuestionOptionProcessor.py +172 -0
  11. edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
  12. edsl/agents/__init__.py +0 -1
  13. edsl/agents/prompt_helpers.py +3 -3
  14. edsl/config.py +22 -2
  15. edsl/conversation/car_buying.py +2 -1
  16. edsl/coop/CoopFunctionsMixin.py +15 -0
  17. edsl/coop/ExpectedParrotKeyHandler.py +125 -0
  18. edsl/coop/PriceFetcher.py +1 -1
  19. edsl/coop/coop.py +104 -42
  20. edsl/coop/utils.py +14 -14
  21. edsl/data/Cache.py +21 -14
  22. edsl/data/CacheEntry.py +12 -15
  23. edsl/data/CacheHandler.py +33 -12
  24. edsl/data/__init__.py +4 -3
  25. edsl/data_transfer_models.py +2 -1
  26. edsl/enums.py +20 -0
  27. edsl/exceptions/__init__.py +50 -50
  28. edsl/exceptions/agents.py +12 -0
  29. edsl/exceptions/inference_services.py +5 -0
  30. edsl/exceptions/questions.py +24 -6
  31. edsl/exceptions/scenarios.py +7 -0
  32. edsl/inference_services/AnthropicService.py +0 -3
  33. edsl/inference_services/AvailableModelCacheHandler.py +184 -0
  34. edsl/inference_services/AvailableModelFetcher.py +209 -0
  35. edsl/inference_services/AwsBedrock.py +0 -2
  36. edsl/inference_services/AzureAI.py +0 -2
  37. edsl/inference_services/GoogleService.py +2 -11
  38. edsl/inference_services/InferenceServiceABC.py +18 -85
  39. edsl/inference_services/InferenceServicesCollection.py +105 -80
  40. edsl/inference_services/MistralAIService.py +0 -3
  41. edsl/inference_services/OpenAIService.py +1 -4
  42. edsl/inference_services/PerplexityService.py +0 -3
  43. edsl/inference_services/ServiceAvailability.py +135 -0
  44. edsl/inference_services/TestService.py +11 -8
  45. edsl/inference_services/data_structures.py +62 -0
  46. edsl/jobs/AnswerQuestionFunctionConstructor.py +188 -0
  47. edsl/jobs/Answers.py +1 -14
  48. edsl/jobs/FetchInvigilator.py +40 -0
  49. edsl/jobs/InterviewTaskManager.py +98 -0
  50. edsl/jobs/InterviewsConstructor.py +48 -0
  51. edsl/jobs/Jobs.py +102 -243
  52. edsl/jobs/JobsChecks.py +35 -10
  53. edsl/jobs/JobsComponentConstructor.py +189 -0
  54. edsl/jobs/JobsPrompts.py +5 -3
  55. edsl/jobs/JobsRemoteInferenceHandler.py +128 -80
  56. edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
  57. edsl/jobs/RequestTokenEstimator.py +30 -0
  58. edsl/jobs/buckets/BucketCollection.py +44 -3
  59. edsl/jobs/buckets/TokenBucket.py +53 -21
  60. edsl/jobs/buckets/TokenBucketAPI.py +211 -0
  61. edsl/jobs/buckets/TokenBucketClient.py +191 -0
  62. edsl/jobs/decorators.py +35 -0
  63. edsl/jobs/interviews/Interview.py +77 -380
  64. edsl/jobs/jobs_status_enums.py +9 -0
  65. edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
  66. edsl/jobs/runners/JobsRunnerAsyncio.py +4 -49
  67. edsl/jobs/tasks/QuestionTaskCreator.py +21 -19
  68. edsl/jobs/tasks/TaskHistory.py +14 -15
  69. edsl/jobs/tasks/task_status_enum.py +0 -2
  70. edsl/language_models/ComputeCost.py +63 -0
  71. edsl/language_models/LanguageModel.py +137 -234
  72. edsl/language_models/ModelList.py +11 -13
  73. edsl/language_models/PriceManager.py +127 -0
  74. edsl/language_models/RawResponseHandler.py +106 -0
  75. edsl/language_models/ServiceDataSources.py +0 -0
  76. edsl/language_models/__init__.py +0 -1
  77. edsl/language_models/key_management/KeyLookup.py +63 -0
  78. edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
  79. edsl/language_models/key_management/KeyLookupCollection.py +38 -0
  80. edsl/language_models/key_management/__init__.py +0 -0
  81. edsl/language_models/key_management/models.py +131 -0
  82. edsl/language_models/registry.py +49 -59
  83. edsl/language_models/repair.py +2 -2
  84. edsl/language_models/utilities.py +5 -4
  85. edsl/notebooks/Notebook.py +19 -14
  86. edsl/notebooks/NotebookToLaTeX.py +142 -0
  87. edsl/prompts/Prompt.py +29 -39
  88. edsl/questions/AnswerValidatorMixin.py +47 -2
  89. edsl/questions/ExceptionExplainer.py +77 -0
  90. edsl/questions/HTMLQuestion.py +103 -0
  91. edsl/questions/LoopProcessor.py +149 -0
  92. edsl/questions/QuestionBase.py +37 -192
  93. edsl/questions/QuestionBaseGenMixin.py +52 -48
  94. edsl/questions/QuestionBasePromptsMixin.py +7 -3
  95. edsl/questions/QuestionCheckBox.py +1 -1
  96. edsl/questions/QuestionExtract.py +1 -1
  97. edsl/questions/QuestionFreeText.py +1 -2
  98. edsl/questions/QuestionList.py +3 -5
  99. edsl/questions/QuestionMatrix.py +265 -0
  100. edsl/questions/QuestionMultipleChoice.py +66 -22
  101. edsl/questions/QuestionNumerical.py +1 -3
  102. edsl/questions/QuestionRank.py +6 -16
  103. edsl/questions/ResponseValidatorABC.py +37 -11
  104. edsl/questions/ResponseValidatorFactory.py +28 -0
  105. edsl/questions/SimpleAskMixin.py +4 -3
  106. edsl/questions/__init__.py +1 -0
  107. edsl/questions/derived/QuestionLinearScale.py +6 -3
  108. edsl/questions/derived/QuestionTopK.py +1 -1
  109. edsl/questions/descriptors.py +17 -3
  110. edsl/questions/question_registry.py +1 -1
  111. edsl/questions/templates/matrix/__init__.py +1 -0
  112. edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
  113. edsl/questions/templates/matrix/question_presentation.jinja +20 -0
  114. edsl/results/CSSParameterizer.py +1 -1
  115. edsl/results/Dataset.py +170 -7
  116. edsl/results/DatasetExportMixin.py +224 -302
  117. edsl/results/DatasetTree.py +28 -8
  118. edsl/results/MarkdownToDocx.py +122 -0
  119. edsl/results/MarkdownToPDF.py +111 -0
  120. edsl/results/Result.py +192 -206
  121. edsl/results/Results.py +120 -113
  122. edsl/results/ResultsExportMixin.py +2 -0
  123. edsl/results/Selector.py +23 -13
  124. edsl/results/TableDisplay.py +98 -171
  125. edsl/results/TextEditor.py +50 -0
  126. edsl/results/__init__.py +1 -1
  127. edsl/results/smart_objects.py +96 -0
  128. edsl/results/table_data_class.py +12 -0
  129. edsl/results/table_renderers.py +118 -0
  130. edsl/scenarios/ConstructDownloadLink.py +109 -0
  131. edsl/scenarios/DirectoryScanner.py +96 -0
  132. edsl/scenarios/DocumentChunker.py +102 -0
  133. edsl/scenarios/DocxScenario.py +16 -0
  134. edsl/scenarios/FileStore.py +118 -239
  135. edsl/scenarios/PdfExtractor.py +40 -0
  136. edsl/scenarios/Scenario.py +90 -193
  137. edsl/scenarios/ScenarioHtmlMixin.py +4 -3
  138. edsl/scenarios/ScenarioJoin.py +10 -6
  139. edsl/scenarios/ScenarioList.py +383 -240
  140. edsl/scenarios/ScenarioListExportMixin.py +0 -7
  141. edsl/scenarios/ScenarioListPdfMixin.py +15 -37
  142. edsl/scenarios/ScenarioSelector.py +156 -0
  143. edsl/scenarios/__init__.py +1 -2
  144. edsl/scenarios/file_methods.py +85 -0
  145. edsl/scenarios/handlers/__init__.py +13 -0
  146. edsl/scenarios/handlers/csv.py +38 -0
  147. edsl/scenarios/handlers/docx.py +76 -0
  148. edsl/scenarios/handlers/html.py +37 -0
  149. edsl/scenarios/handlers/json.py +111 -0
  150. edsl/scenarios/handlers/latex.py +5 -0
  151. edsl/scenarios/handlers/md.py +51 -0
  152. edsl/scenarios/handlers/pdf.py +68 -0
  153. edsl/scenarios/handlers/png.py +39 -0
  154. edsl/scenarios/handlers/pptx.py +105 -0
  155. edsl/scenarios/handlers/py.py +294 -0
  156. edsl/scenarios/handlers/sql.py +313 -0
  157. edsl/scenarios/handlers/sqlite.py +149 -0
  158. edsl/scenarios/handlers/txt.py +33 -0
  159. edsl/study/ObjectEntry.py +1 -1
  160. edsl/study/SnapShot.py +1 -1
  161. edsl/study/Study.py +5 -12
  162. edsl/surveys/ConstructDAG.py +92 -0
  163. edsl/surveys/EditSurvey.py +221 -0
  164. edsl/surveys/InstructionHandler.py +100 -0
  165. edsl/surveys/MemoryManagement.py +72 -0
  166. edsl/surveys/Rule.py +5 -4
  167. edsl/surveys/RuleCollection.py +25 -27
  168. edsl/surveys/RuleManager.py +172 -0
  169. edsl/surveys/Simulator.py +75 -0
  170. edsl/surveys/Survey.py +199 -771
  171. edsl/surveys/SurveyCSS.py +20 -8
  172. edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +11 -9
  173. edsl/surveys/SurveyToApp.py +141 -0
  174. edsl/surveys/__init__.py +4 -2
  175. edsl/surveys/descriptors.py +6 -2
  176. edsl/surveys/instructions/ChangeInstruction.py +1 -2
  177. edsl/surveys/instructions/Instruction.py +4 -13
  178. edsl/surveys/instructions/InstructionCollection.py +11 -6
  179. edsl/templates/error_reporting/interview_details.html +1 -1
  180. edsl/templates/error_reporting/report.html +1 -1
  181. edsl/tools/plotting.py +1 -1
  182. edsl/utilities/PrettyList.py +56 -0
  183. edsl/utilities/is_notebook.py +18 -0
  184. edsl/utilities/is_valid_variable_name.py +11 -0
  185. edsl/utilities/remove_edsl_version.py +24 -0
  186. edsl/utilities/utilities.py +35 -23
  187. {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/METADATA +12 -10
  188. edsl-0.1.39.dev2.dist-info/RECORD +352 -0
  189. edsl/language_models/KeyLookup.py +0 -30
  190. edsl/language_models/unused/ReplicateBase.py +0 -83
  191. edsl/results/ResultsDBMixin.py +0 -238
  192. edsl-0.1.39.dev1.dist-info/RECORD +0 -277
  193. {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/LICENSE +0 -0
  194. {edsl-0.1.39.dev1.dist-info → edsl-0.1.39.dev2.dist-info}/WHEEL +0 -0
edsl/results/Results.py CHANGED
@@ -9,13 +9,7 @@ import random
9
9
  from collections import UserList, defaultdict
10
10
  from typing import Optional, Callable, Any, Type, Union, List, TYPE_CHECKING
11
11
 
12
- if TYPE_CHECKING:
13
- from edsl import Survey, Cache, AgentList, ModelList, ScenarioList
14
- from edsl.results.Result import Result
15
- from edsl.jobs.tasks.TaskHistory import TaskHistory
16
-
17
- from simpleeval import EvalWithCompoundTypes
18
-
12
+ from edsl.Base import Base
19
13
  from edsl.exceptions.results import (
20
14
  ResultsError,
21
15
  ResultsBadMutationstringError,
@@ -26,25 +20,27 @@ from edsl.exceptions.results import (
26
20
  ResultsDeserializationError,
27
21
  )
28
22
 
23
+ if TYPE_CHECKING:
24
+ from edsl.surveys.Survey import Survey
25
+ from edsl.data.Cache import Cache
26
+ from edsl.agents.AgentList import AgentList
27
+ from edsl.language_models.registry import Model
28
+ from edsl.scenarios.ScenarioList import ScenarioList
29
+ from edsl.results.Result import Result
30
+ from edsl.jobs.tasks.TaskHistory import TaskHistory
31
+ from edsl.language_models.ModelList import ModelList
32
+ from simpleeval import EvalWithCompoundTypes
33
+
29
34
  from edsl.results.ResultsExportMixin import ResultsExportMixin
30
- from edsl.results.ResultsToolsMixin import ResultsToolsMixin
31
- from edsl.results.ResultsDBMixin import ResultsDBMixin
32
35
  from edsl.results.ResultsGGMixin import ResultsGGMixin
33
36
  from edsl.results.ResultsFetchMixin import ResultsFetchMixin
34
-
35
- from edsl.utilities.decorators import remove_edsl_version
36
- from edsl.utilities.utilities import dict_hash
37
-
38
-
39
- from edsl.Base import Base
37
+ from edsl.utilities.remove_edsl_version import remove_edsl_version
40
38
 
41
39
 
42
40
  class Mixins(
43
41
  ResultsExportMixin,
44
- ResultsDBMixin,
45
42
  ResultsFetchMixin,
46
43
  ResultsGGMixin,
47
- ResultsToolsMixin,
48
44
  ):
49
45
  def long(self):
50
46
  return self.table().long()
@@ -91,6 +87,7 @@ class Results(UserList, Mixins, Base):
91
87
  "question_type",
92
88
  "comment",
93
89
  "generated_tokens",
90
+ "cache_used",
94
91
  ]
95
92
 
96
93
  def __init__(
@@ -129,18 +126,13 @@ class Results(UserList, Mixins, Base):
129
126
  def _summary(self) -> dict:
130
127
  import reprlib
131
128
 
132
- # import yaml
133
-
134
129
  d = {
135
- "EDSL Class": "Results",
136
- # "docs_url": self.__documentation__,
137
- "# of agents": len(set(self.agents)),
138
- "# of distinct models": len(set(self.models)),
139
- "# of observations": len(self),
140
- "# Scenarios": len(set(self.scenarios)),
141
- "Survey Length (# questions)": len(self.survey),
130
+ "observations": len(self),
131
+ "agents": len(set(self.agents)),
132
+ "models": len(set(self.models)),
133
+ "scenarios": len(set(self.scenarios)),
134
+ "questions": len(self.survey),
142
135
  "Survey question names": reprlib.repr(self.survey.question_names),
143
- "Object hash": hash(self),
144
136
  }
145
137
  return d
146
138
 
@@ -258,23 +250,23 @@ class Results(UserList, Mixins, Base):
258
250
 
259
251
  raise TypeError("Invalid argument type")
260
252
 
261
- def _update_results(self) -> None:
262
- from edsl import Agent, Scenario
263
- from edsl.language_models import LanguageModel
264
- from edsl.results import Result
265
-
266
- if self._job_uuid and len(self.data) < self._total_results:
267
- results = [
268
- Result(
269
- agent=Agent.from_dict(json.loads(r.agent)),
270
- scenario=Scenario.from_dict(json.loads(r.scenario)),
271
- model=LanguageModel.from_dict(json.loads(r.model)),
272
- iteration=1,
273
- answer=json.loads(r.answer),
274
- )
275
- for r in CRUD.read_results(self._job_uuid)
276
- ]
277
- self.data = results
253
+ # def _update_results(self) -> None:
254
+ # from edsl import Agent, Scenario
255
+ # from edsl.language_models import LanguageModel
256
+ # from edsl.results import Result
257
+
258
+ # if self._job_uuid and len(self.data) < self._total_results:
259
+ # results = [
260
+ # Result(
261
+ # agent=Agent.from_dict(json.loads(r.agent)),
262
+ # scenario=Scenario.from_dict(json.loads(r.scenario)),
263
+ # model=LanguageModel.from_dict(json.loads(r.model)),
264
+ # iteration=1,
265
+ # answer=json.loads(r.answer),
266
+ # )
267
+ # for r in CRUD.read_results(self._job_uuid)
268
+ # ]
269
+ # self.data = results
278
270
 
279
271
  def __add__(self, other: Results) -> Results:
280
272
  """Add two Results objects together.
@@ -303,9 +295,9 @@ class Results(UserList, Mixins, Base):
303
295
  )
304
296
 
305
297
  def __repr__(self) -> str:
306
- import reprlib
298
+ # import reprlib
307
299
 
308
- return f"Results(data = {reprlib.repr(self.data)}, survey = {repr(self.survey)}, created_columns = {self.created_columns})"
300
+ return f"Results(data = {self.data}, survey = {repr(self.survey)}, created_columns = {self.created_columns})"
309
301
 
310
302
  def table(
311
303
  self,
@@ -345,21 +337,6 @@ class Results(UserList, Mixins, Base):
345
337
  print_parameters=print_parameters,
346
338
  )
347
339
  )
348
- # return (
349
- # self.select(f"{selector_string}")
350
- # .to_scenario_list()
351
- # .table(*fields, tablefmt=tablefmt)
352
- # )
353
-
354
- def _repr_html_(self) -> str:
355
- d = self._summary()
356
- from edsl import Scenario
357
-
358
- footer = f"<a href={self.__documentation__}>(docs)</a>"
359
-
360
- s = Scenario(d)
361
- td = s.to_dataset().table(tablefmt="html")
362
- return td._repr_html_() + footer
363
340
 
364
341
  def to_dict(
365
342
  self,
@@ -367,6 +344,7 @@ class Results(UserList, Mixins, Base):
367
344
  add_edsl_version=False,
368
345
  include_cache=False,
369
346
  include_task_history=False,
347
+ include_cache_info=True,
370
348
  ) -> dict[str, Any]:
371
349
  from edsl.data.Cache import Cache
372
350
 
@@ -377,7 +355,11 @@ class Results(UserList, Mixins, Base):
377
355
 
378
356
  d = {
379
357
  "data": [
380
- result.to_dict(add_edsl_version=add_edsl_version) for result in data
358
+ result.to_dict(
359
+ add_edsl_version=add_edsl_version,
360
+ include_cache_info=include_cache_info,
361
+ )
362
+ for result in data
381
363
  ],
382
364
  "survey": self.survey.to_dict(add_edsl_version=add_edsl_version),
383
365
  "created_columns": self.created_columns,
@@ -426,7 +408,11 @@ class Results(UserList, Mixins, Base):
426
408
  return self.task_history.has_unfixed_exceptions
427
409
 
428
410
  def __hash__(self) -> int:
429
- return dict_hash(self.to_dict(sort=True, add_edsl_version=False))
411
+ from edsl.utilities.utilities import dict_hash
412
+
413
+ return dict_hash(
414
+ self.to_dict(sort=True, add_edsl_version=False, include_cache_info=False)
415
+ )
430
416
 
431
417
  @property
432
418
  def hashes(self) -> set:
@@ -472,24 +458,31 @@ class Results(UserList, Mixins, Base):
472
458
  >>> r == r2
473
459
  True
474
460
  """
475
- from edsl import Survey, Cache
461
+ from edsl.surveys.Survey import Survey
462
+ from edsl.data.Cache import Cache
476
463
  from edsl.results.Result import Result
477
464
  from edsl.jobs.tasks.TaskHistory import TaskHistory
465
+ from edsl.agents.Agent import Agent
466
+
467
+ survey = Survey.from_dict(data["survey"])
468
+ results_data = [Result.from_dict(r) for r in data["data"]]
469
+ created_columns = data.get("created_columns", None)
470
+ cache = Cache.from_dict(data.get("cache")) if "cache" in data else Cache()
471
+ task_history = (
472
+ TaskHistory.from_dict(data.get("task_history"))
473
+ if "task_history" in data
474
+ else TaskHistory(interviews=[])
475
+ )
476
+ params = {
477
+ "survey": survey,
478
+ "data": results_data,
479
+ "created_columns": created_columns,
480
+ "cache": cache,
481
+ "task_history": task_history,
482
+ }
478
483
 
479
484
  try:
480
- results = cls(
481
- survey=Survey.from_dict(data["survey"]),
482
- data=[Result.from_dict(r) for r in data["data"]],
483
- created_columns=data.get("created_columns", None),
484
- cache=(
485
- Cache.from_dict(data.get("cache")) if "cache" in data else Cache()
486
- ),
487
- task_history=(
488
- TaskHistory.from_dict(data.get("task_history"))
489
- if "task_history" in data
490
- else TaskHistory(interviews=[])
491
- ),
492
- )
485
+ results = cls(**params)
493
486
  except Exception as e:
494
487
  raise ResultsDeserializationError(f"Error in Results.from_dict: {e}")
495
488
  return results
@@ -544,10 +537,12 @@ class Results(UserList, Mixins, Base):
544
537
 
545
538
  >>> r = Results.example()
546
539
  >>> r.columns
547
- ['agent.agent_instruction', ...]
540
+ ['agent.agent_index', ...]
548
541
  """
549
542
  column_names = [f"{v}.{k}" for k, v in self._key_to_data_type.items()]
550
- return sorted(column_names)
543
+ from edsl.utilities.PrettyList import PrettyList
544
+
545
+ return PrettyList(sorted(column_names))
551
546
 
552
547
  @property
553
548
  def answer_keys(self) -> dict[str, str]:
@@ -567,7 +562,7 @@ class Results(UserList, Mixins, Base):
567
562
  answer_keys = self._data_type_to_keys["answer"]
568
563
  answer_keys = {k for k in answer_keys if "_comment" not in k}
569
564
  questions_text = [
570
- self.survey.get_question(k).question_text for k in answer_keys
565
+ self.survey._get_question_by_name(k).question_text for k in answer_keys
571
566
  ]
572
567
  short_question_text = [shorten_string(q, 80) for q in questions_text]
573
568
  initial_dict = dict(zip(answer_keys, short_question_text))
@@ -584,7 +579,7 @@ class Results(UserList, Mixins, Base):
584
579
  >>> r.agents
585
580
  AgentList([Agent(traits = {'status': 'Joyful'}), Agent(traits = {'status': 'Joyful'}), Agent(traits = {'status': 'Sad'}), Agent(traits = {'status': 'Sad'})])
586
581
  """
587
- from edsl import AgentList
582
+ from edsl.agents.AgentList import AgentList
588
583
 
589
584
  return AgentList([r.agent for r in self.data])
590
585
 
@@ -598,10 +593,13 @@ class Results(UserList, Mixins, Base):
598
593
  >>> r.models[0]
599
594
  Model(model_name = ...)
600
595
  """
601
- from edsl import ModelList
596
+ from edsl.language_models.ModelList import ModelList
602
597
 
603
598
  return ModelList([r.model for r in self.data])
604
599
 
600
+ def __eq__(self, other):
601
+ return hash(self) == hash(other)
602
+
605
603
  @property
606
604
  def scenarios(self) -> ScenarioList:
607
605
  """Return a list of all of the scenarios in the Results.
@@ -610,9 +608,9 @@ class Results(UserList, Mixins, Base):
610
608
 
611
609
  >>> r = Results.example()
612
610
  >>> r.scenarios
613
- ScenarioList([Scenario({'period': 'morning'}), Scenario({'period': 'afternoon'}), Scenario({'period': 'morning'}), Scenario({'period': 'afternoon'})])
611
+ ScenarioList([Scenario({'period': 'morning', 'scenario_index': 0}), Scenario({'period': 'afternoon', 'scenario_index': 1}), Scenario({'period': 'morning', 'scenario_index': 0}), Scenario({'period': 'afternoon', 'scenario_index': 1})])
614
612
  """
615
- from edsl import ScenarioList
613
+ from edsl.scenarios.ScenarioList import ScenarioList
616
614
 
617
615
  return ScenarioList([r.scenario for r in self.data])
618
616
 
@@ -624,7 +622,7 @@ class Results(UserList, Mixins, Base):
624
622
 
625
623
  >>> r = Results.example()
626
624
  >>> r.agent_keys
627
- ['agent_instruction', 'agent_name', 'status']
625
+ ['agent_index', 'agent_instruction', 'agent_name', 'status']
628
626
  """
629
627
  return sorted(self._data_type_to_keys["agent"])
630
628
 
@@ -634,7 +632,7 @@ class Results(UserList, Mixins, Base):
634
632
 
635
633
  >>> r = Results.example()
636
634
  >>> r.model_keys
637
- ['frequency_penalty', 'logprobs', 'max_tokens', 'model', 'presence_penalty', 'temperature', 'top_logprobs', 'top_p']
635
+ ['frequency_penalty', 'logprobs', 'max_tokens', 'model', 'model_index', 'presence_penalty', 'temperature', 'top_logprobs', 'top_p']
638
636
  """
639
637
  return sorted(self._data_type_to_keys["model"])
640
638
 
@@ -644,7 +642,7 @@ class Results(UserList, Mixins, Base):
644
642
 
645
643
  >>> r = Results.example()
646
644
  >>> r.scenario_keys
647
- ['period']
645
+ ['period', 'scenario_index']
648
646
  """
649
647
  return sorted(self._data_type_to_keys["scenario"])
650
648
 
@@ -670,7 +668,7 @@ class Results(UserList, Mixins, Base):
670
668
 
671
669
  >>> r = Results.example()
672
670
  >>> r.all_keys
673
- ['agent_instruction', 'agent_name', 'frequency_penalty', 'how_feeling', 'how_feeling_yesterday', 'logprobs', 'max_tokens', 'model', 'period', 'presence_penalty', 'status', 'temperature', 'top_logprobs', 'top_p']
671
+ ['agent_index', ...]
674
672
  """
675
673
  answer_keys = set(self.answer_keys)
676
674
  all_keys = (
@@ -777,7 +775,7 @@ class Results(UserList, Mixins, Base):
777
775
  @staticmethod
778
776
  def _create_evaluator(
779
777
  result: Result, functions_dict: Optional[dict] = None
780
- ) -> EvalWithCompoundTypes:
778
+ ) -> "EvalWithCompoundTypes":
781
779
  """Create an evaluator for the expression.
782
780
 
783
781
  >>> from unittest.mock import Mock
@@ -800,6 +798,8 @@ class Results(UserList, Mixins, Base):
800
798
  ...
801
799
  simpleeval.NameNotDefined: 'how_feeling' is not defined for expression 'how_feeling== 'OK''
802
800
  """
801
+ from simpleeval import EvalWithCompoundTypes
802
+
803
803
  if functions_dict is None:
804
804
  functions_dict = {}
805
805
  evaluator = EvalWithCompoundTypes(
@@ -858,6 +858,26 @@ class Results(UserList, Mixins, Base):
858
858
  created_columns=self.created_columns + [var_name],
859
859
  )
860
860
 
861
+ def add_column(self, column_name: str, values: list) -> Results:
862
+ """Adds columns to Results
863
+
864
+ >>> r = Results.example()
865
+ >>> r.add_column('a', [1,2,3, 4]).select('a')
866
+ Dataset([{'answer.a': [1, 2, 3, 4]}])
867
+ """
868
+
869
+ assert len(values) == len(
870
+ self.data
871
+ ), "The number of values must match the number of results."
872
+ new_results = self.data.copy()
873
+ for i, result in enumerate(new_results):
874
+ result["answer"][column_name] = values[i]
875
+ return Results(
876
+ survey=self.survey,
877
+ data=new_results,
878
+ created_columns=self.created_columns + [column_name],
879
+ )
880
+
861
881
  def rename(self, old_name: str, new_name: str) -> Results:
862
882
  """Rename an answer column in a Results object.
863
883
 
@@ -987,20 +1007,12 @@ class Results(UserList, Mixins, Base):
987
1007
  Example:
988
1008
 
989
1009
  >>> r = Results.example()
990
- >>> r.sort_by('how_feeling', reverse=False).select('how_feeling').print()
991
- answer.how_feeling
992
- --------------------
993
- Great
994
- OK
995
- OK
996
- Terrible
997
- >>> r.sort_by('how_feeling', reverse=True).select('how_feeling').print()
998
- answer.how_feeling
999
- --------------------
1000
- Terrible
1001
- OK
1002
- OK
1003
- Great
1010
+ >>> r.sort_by('how_feeling', reverse=False).select('how_feeling')
1011
+ Dataset([{'answer.how_feeling': ['Great', 'OK', 'OK', 'Terrible']}])
1012
+
1013
+ >>> r.sort_by('how_feeling', reverse=True).select('how_feeling')
1014
+ Dataset([{'answer.how_feeling': ['Terrible', 'OK', 'OK', 'Great']}])
1015
+
1004
1016
  """
1005
1017
 
1006
1018
  def to_numeric_if_possible(v):
@@ -1032,24 +1044,19 @@ class Results(UserList, Mixins, Base):
1032
1044
  Example usage: Create an example `Results` instance and apply filters to it:
1033
1045
 
1034
1046
  >>> r = Results.example()
1035
- >>> r.filter("how_feeling == 'Great'").select('how_feeling').print()
1036
- answer.how_feeling
1037
- --------------------
1038
- Great
1047
+ >>> r.filter("how_feeling == 'Great'").select('how_feeling')
1048
+ Dataset([{'answer.how_feeling': ['Great']}])
1039
1049
 
1040
1050
  Example usage: Using an OR operator in the filter expression.
1041
1051
 
1042
- >>> r = Results.example().filter("how_feeling = 'Great'").select('how_feeling').print()
1052
+ >>> r = Results.example().filter("how_feeling = 'Great'").select('how_feeling')
1043
1053
  Traceback (most recent call last):
1044
1054
  ...
1045
1055
  edsl.exceptions.results.ResultsFilterError: You must use '==' instead of '=' in the filter expression.
1046
1056
  ...
1047
1057
 
1048
- >>> r.filter("how_feeling == 'Great' or how_feeling == 'Terrible'").select('how_feeling').print()
1049
- answer.how_feeling
1050
- --------------------
1051
- Great
1052
- Terrible
1058
+ >>> r.filter("how_feeling == 'Great' or how_feeling == 'Terrible'").select('how_feeling')
1059
+ Dataset([{'answer.how_feeling': ['Great', 'Terrible']}])
1053
1060
  """
1054
1061
 
1055
1062
  def has_single_equals(string):
edsl/results/ResultsExportMixin.py CHANGED
@@ -14,6 +14,8 @@ def to_dataset(func):
14
14
  """Return the function with the Results object converted to a Dataset object."""
15
15
  if self.__class__.__name__ == "Results":
16
16
  return func(self.select(), *args, **kwargs)
17
+ elif self.__class__.__name__ == "AgentList":
18
+ return func(self.to_dataset(), *args, **kwargs)
17
19
  else:
18
20
  return func(self, *args, **kwargs)
19
21
 
edsl/results/Selector.py CHANGED
@@ -1,7 +1,12 @@
1
- from typing import Union, List, Dict, Any
1
+ from typing import Union, List, Dict, Any, Optional
2
+ import sys
2
3
  from collections import defaultdict
3
4
  from edsl.results.Dataset import Dataset
4
5
 
6
+ from edsl.exceptions.results import ResultsColumnNotFoundError
7
+
8
+ from edsl.utilities.is_notebook import is_notebook
9
+
5
10
 
6
11
  class Selector:
7
12
  def __init__(
@@ -19,11 +24,17 @@ class Selector:
19
24
  self._fetch_list = fetch_list_func
20
25
  self.columns = columns
21
26
 
22
- def select(self, *columns: Union[str, List[str]]) -> "Dataset":
23
- columns = self._normalize_columns(columns)
24
- to_fetch = self._get_columns_to_fetch(columns)
25
- # breakpoint()
26
- new_data = self._fetch_data(to_fetch)
27
+ def select(self, *columns: Union[str, List[str]]) -> Optional[Dataset]:
28
+ try:
29
+ columns = self._normalize_columns(columns)
30
+ to_fetch = self._get_columns_to_fetch(columns)
31
+ new_data = self._fetch_data(to_fetch)
32
+ except ResultsColumnNotFoundError as e:
33
+ if is_notebook():
34
+ print("Error:", e, file=sys.stderr)
35
+ return None
36
+ else:
37
+ raise e
27
38
  return Dataset(new_data)
28
39
 
29
40
  def _normalize_columns(self, columns: Union[str, List[str]]) -> tuple:
@@ -63,17 +74,16 @@ class Selector:
63
74
  search_in_list = self.columns
64
75
  else:
65
76
  search_in_list = [s.split(".")[1] for s in self.columns]
66
- # breakpoint()
67
77
  matches = [s for s in search_in_list if s.startswith(partial_name)]
68
78
  return [partial_name] if partial_name in matches else matches
69
79
 
70
80
  def _validate_matches(self, column: str, matches: List[str]):
71
81
  if len(matches) > 1:
72
- raise ValueError(
82
+ raise ResultsColumnNotFoundError(
73
83
  f"Column '{column}' is ambiguous. Did you mean one of {matches}?"
74
84
  )
75
85
  if len(matches) == 0 and ".*" not in column:
76
- raise ValueError(f"Column '{column}' not found in data.")
86
+ raise ResultsColumnNotFoundError(f"Column '{column}' not found in data.")
77
87
 
78
88
  def _parse_column(self, column: str) -> tuple[str, str]:
79
89
  if "." in column:
@@ -89,11 +99,11 @@ class Selector:
89
99
  close_matches = difflib.get_close_matches(column, self._key_to_data_type.keys())
90
100
  if close_matches:
91
101
  suggestions = ", ".join(close_matches)
92
- raise KeyError(
102
+ raise ResultsColumnNotFoundError(
93
103
  f"Column '{column}' not found in data. Did you mean: {suggestions}?"
94
104
  )
95
105
  else:
96
- raise KeyError(f"Column {column} not found in data")
106
+ raise ResultsColumnNotFoundError(f"Column {column} not found in data")
97
107
 
98
108
  def _process_column(self, data_type: str, key: str, to_fetch: Dict[str, List[str]]):
99
109
  data_types = self._get_data_types_to_return(data_type)
@@ -108,13 +118,13 @@ class Selector:
108
118
  self.items_in_order.append(f"{dt}.{k}")
109
119
 
110
120
  if not found_once:
111
- raise ValueError(f"Key {key} not found in data.")
121
+ raise ResultsColumnNotFoundError(f"Key {key} not found in data.")
112
122
 
113
123
  def _get_data_types_to_return(self, parsed_data_type: str) -> List[str]:
114
124
  if parsed_data_type == "*":
115
125
  return self.known_data_types
116
126
  if parsed_data_type not in self.known_data_types:
117
- raise ValueError(
127
+ raise ResultsColumnNotFoundError(
118
128
  f"Data type {parsed_data_type} not found in data. Did you mean one of {self.known_data_types}"
119
129
  )
120
130
  return [parsed_data_type]