edsl 0.1.32__py3-none-any.whl → 0.1.33__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (181)
  1. edsl/Base.py +9 -3
  2. edsl/TemplateLoader.py +24 -0
  3. edsl/__init__.py +8 -3
  4. edsl/__version__.py +1 -1
  5. edsl/agents/Agent.py +40 -8
  6. edsl/agents/AgentList.py +43 -0
  7. edsl/agents/Invigilator.py +135 -219
  8. edsl/agents/InvigilatorBase.py +148 -59
  9. edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +138 -89
  10. edsl/agents/__init__.py +1 -0
  11. edsl/auto/AutoStudy.py +117 -0
  12. edsl/auto/StageBase.py +230 -0
  13. edsl/auto/StageGenerateSurvey.py +178 -0
  14. edsl/auto/StageLabelQuestions.py +125 -0
  15. edsl/auto/StagePersona.py +61 -0
  16. edsl/auto/StagePersonaDimensionValueRanges.py +88 -0
  17. edsl/auto/StagePersonaDimensionValues.py +74 -0
  18. edsl/auto/StagePersonaDimensions.py +69 -0
  19. edsl/auto/StageQuestions.py +73 -0
  20. edsl/auto/SurveyCreatorPipeline.py +21 -0
  21. edsl/auto/utilities.py +224 -0
  22. edsl/config.py +47 -56
  23. edsl/coop/PriceFetcher.py +58 -0
  24. edsl/coop/coop.py +50 -7
  25. edsl/data/Cache.py +35 -1
  26. edsl/data_transfer_models.py +73 -38
  27. edsl/enums.py +4 -0
  28. edsl/exceptions/language_models.py +25 -1
  29. edsl/exceptions/questions.py +62 -5
  30. edsl/exceptions/results.py +4 -0
  31. edsl/inference_services/AnthropicService.py +13 -11
  32. edsl/inference_services/AwsBedrock.py +19 -17
  33. edsl/inference_services/AzureAI.py +37 -20
  34. edsl/inference_services/GoogleService.py +16 -12
  35. edsl/inference_services/GroqService.py +2 -0
  36. edsl/inference_services/InferenceServiceABC.py +58 -3
  37. edsl/inference_services/MistralAIService.py +120 -0
  38. edsl/inference_services/OpenAIService.py +48 -54
  39. edsl/inference_services/TestService.py +80 -0
  40. edsl/inference_services/TogetherAIService.py +170 -0
  41. edsl/inference_services/models_available_cache.py +0 -6
  42. edsl/inference_services/registry.py +6 -0
  43. edsl/jobs/Answers.py +10 -12
  44. edsl/jobs/FailedQuestion.py +78 -0
  45. edsl/jobs/Jobs.py +37 -22
  46. edsl/jobs/buckets/BucketCollection.py +24 -15
  47. edsl/jobs/buckets/TokenBucket.py +93 -14
  48. edsl/jobs/interviews/Interview.py +366 -78
  49. edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +14 -68
  50. edsl/jobs/interviews/InterviewExceptionEntry.py +85 -19
  51. edsl/jobs/runners/JobsRunnerAsyncio.py +146 -175
  52. edsl/jobs/runners/JobsRunnerStatus.py +331 -0
  53. edsl/jobs/tasks/QuestionTaskCreator.py +30 -23
  54. edsl/jobs/tasks/TaskHistory.py +148 -213
  55. edsl/language_models/LanguageModel.py +261 -156
  56. edsl/language_models/ModelList.py +2 -2
  57. edsl/language_models/RegisterLanguageModelsMeta.py +14 -29
  58. edsl/language_models/fake_openai_call.py +15 -0
  59. edsl/language_models/fake_openai_service.py +61 -0
  60. edsl/language_models/registry.py +23 -6
  61. edsl/language_models/repair.py +0 -19
  62. edsl/language_models/utilities.py +61 -0
  63. edsl/notebooks/Notebook.py +20 -2
  64. edsl/prompts/Prompt.py +52 -2
  65. edsl/questions/AnswerValidatorMixin.py +23 -26
  66. edsl/questions/QuestionBase.py +330 -249
  67. edsl/questions/QuestionBaseGenMixin.py +133 -0
  68. edsl/questions/QuestionBasePromptsMixin.py +266 -0
  69. edsl/questions/QuestionBudget.py +99 -41
  70. edsl/questions/QuestionCheckBox.py +227 -35
  71. edsl/questions/QuestionExtract.py +98 -27
  72. edsl/questions/QuestionFreeText.py +52 -29
  73. edsl/questions/QuestionFunctional.py +7 -0
  74. edsl/questions/QuestionList.py +141 -22
  75. edsl/questions/QuestionMultipleChoice.py +159 -65
  76. edsl/questions/QuestionNumerical.py +88 -46
  77. edsl/questions/QuestionRank.py +182 -24
  78. edsl/questions/Quick.py +41 -0
  79. edsl/questions/RegisterQuestionsMeta.py +31 -12
  80. edsl/questions/ResponseValidatorABC.py +170 -0
  81. edsl/questions/__init__.py +3 -4
  82. edsl/questions/decorators.py +21 -0
  83. edsl/questions/derived/QuestionLikertFive.py +10 -5
  84. edsl/questions/derived/QuestionLinearScale.py +15 -2
  85. edsl/questions/derived/QuestionTopK.py +10 -1
  86. edsl/questions/derived/QuestionYesNo.py +24 -3
  87. edsl/questions/descriptors.py +43 -7
  88. edsl/questions/prompt_templates/question_budget.jinja +13 -0
  89. edsl/questions/prompt_templates/question_checkbox.jinja +32 -0
  90. edsl/questions/prompt_templates/question_extract.jinja +11 -0
  91. edsl/questions/prompt_templates/question_free_text.jinja +3 -0
  92. edsl/questions/prompt_templates/question_linear_scale.jinja +11 -0
  93. edsl/questions/prompt_templates/question_list.jinja +17 -0
  94. edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -0
  95. edsl/questions/prompt_templates/question_numerical.jinja +37 -0
  96. edsl/questions/question_registry.py +6 -2
  97. edsl/questions/templates/__init__.py +0 -0
  98. edsl/questions/templates/budget/__init__.py +0 -0
  99. edsl/questions/templates/budget/answering_instructions.jinja +7 -0
  100. edsl/questions/templates/budget/question_presentation.jinja +7 -0
  101. edsl/questions/templates/checkbox/__init__.py +0 -0
  102. edsl/questions/templates/checkbox/answering_instructions.jinja +10 -0
  103. edsl/questions/templates/checkbox/question_presentation.jinja +22 -0
  104. edsl/questions/templates/extract/__init__.py +0 -0
  105. edsl/questions/templates/extract/answering_instructions.jinja +7 -0
  106. edsl/questions/templates/extract/question_presentation.jinja +1 -0
  107. edsl/questions/templates/free_text/__init__.py +0 -0
  108. edsl/questions/templates/free_text/answering_instructions.jinja +0 -0
  109. edsl/questions/templates/free_text/question_presentation.jinja +1 -0
  110. edsl/questions/templates/likert_five/__init__.py +0 -0
  111. edsl/questions/templates/likert_five/answering_instructions.jinja +10 -0
  112. edsl/questions/templates/likert_five/question_presentation.jinja +12 -0
  113. edsl/questions/templates/linear_scale/__init__.py +0 -0
  114. edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -0
  115. edsl/questions/templates/linear_scale/question_presentation.jinja +5 -0
  116. edsl/questions/templates/list/__init__.py +0 -0
  117. edsl/questions/templates/list/answering_instructions.jinja +4 -0
  118. edsl/questions/templates/list/question_presentation.jinja +5 -0
  119. edsl/questions/templates/multiple_choice/__init__.py +0 -0
  120. edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -0
  121. edsl/questions/templates/multiple_choice/html.jinja +0 -0
  122. edsl/questions/templates/multiple_choice/question_presentation.jinja +12 -0
  123. edsl/questions/templates/numerical/__init__.py +0 -0
  124. edsl/questions/templates/numerical/answering_instructions.jinja +8 -0
  125. edsl/questions/templates/numerical/question_presentation.jinja +7 -0
  126. edsl/questions/templates/rank/__init__.py +0 -0
  127. edsl/questions/templates/rank/answering_instructions.jinja +11 -0
  128. edsl/questions/templates/rank/question_presentation.jinja +15 -0
  129. edsl/questions/templates/top_k/__init__.py +0 -0
  130. edsl/questions/templates/top_k/answering_instructions.jinja +8 -0
  131. edsl/questions/templates/top_k/question_presentation.jinja +22 -0
  132. edsl/questions/templates/yes_no/__init__.py +0 -0
  133. edsl/questions/templates/yes_no/answering_instructions.jinja +6 -0
  134. edsl/questions/templates/yes_no/question_presentation.jinja +12 -0
  135. edsl/results/Dataset.py +20 -0
  136. edsl/results/DatasetExportMixin.py +46 -48
  137. edsl/results/DatasetTree.py +145 -0
  138. edsl/results/Result.py +32 -5
  139. edsl/results/Results.py +135 -46
  140. edsl/results/ResultsDBMixin.py +3 -3
  141. edsl/results/Selector.py +118 -0
  142. edsl/results/tree_explore.py +115 -0
  143. edsl/scenarios/FileStore.py +71 -10
  144. edsl/scenarios/Scenario.py +96 -25
  145. edsl/scenarios/ScenarioImageMixin.py +2 -2
  146. edsl/scenarios/ScenarioList.py +361 -39
  147. edsl/scenarios/ScenarioListExportMixin.py +9 -0
  148. edsl/scenarios/ScenarioListPdfMixin.py +150 -4
  149. edsl/study/SnapShot.py +8 -1
  150. edsl/study/Study.py +32 -0
  151. edsl/surveys/Rule.py +10 -1
  152. edsl/surveys/RuleCollection.py +21 -5
  153. edsl/surveys/Survey.py +637 -311
  154. edsl/surveys/SurveyExportMixin.py +71 -9
  155. edsl/surveys/SurveyFlowVisualizationMixin.py +2 -1
  156. edsl/surveys/SurveyQualtricsImport.py +75 -4
  157. edsl/surveys/instructions/ChangeInstruction.py +47 -0
  158. edsl/surveys/instructions/Instruction.py +34 -0
  159. edsl/surveys/instructions/InstructionCollection.py +77 -0
  160. edsl/surveys/instructions/__init__.py +0 -0
  161. edsl/templates/error_reporting/base.html +24 -0
  162. edsl/templates/error_reporting/exceptions_by_model.html +35 -0
  163. edsl/templates/error_reporting/exceptions_by_question_name.html +17 -0
  164. edsl/templates/error_reporting/exceptions_by_type.html +17 -0
  165. edsl/templates/error_reporting/interview_details.html +116 -0
  166. edsl/templates/error_reporting/interviews.html +10 -0
  167. edsl/templates/error_reporting/overview.html +5 -0
  168. edsl/templates/error_reporting/performance_plot.html +2 -0
  169. edsl/templates/error_reporting/report.css +74 -0
  170. edsl/templates/error_reporting/report.html +118 -0
  171. edsl/templates/error_reporting/report.js +25 -0
  172. edsl/utilities/utilities.py +9 -1
  173. {edsl-0.1.32.dist-info → edsl-0.1.33.dist-info}/METADATA +5 -2
  174. edsl-0.1.33.dist-info/RECORD +295 -0
  175. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +0 -286
  176. edsl/jobs/interviews/retry_management.py +0 -37
  177. edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -333
  178. edsl/utilities/gcp_bucket/simple_example.py +0 -9
  179. edsl-0.1.32.dist-info/RECORD +0 -209
  180. {edsl-0.1.32.dist-info → edsl-0.1.33.dist-info}/LICENSE +0 -0
  181. {edsl-0.1.32.dist-info → edsl-0.1.33.dist-info}/WHEEL +0 -0
@@ -2,24 +2,62 @@ import traceback
2
2
  import datetime
3
3
  import time
4
4
  from collections import UserDict
5
-
6
- # traceback=traceback.format_exc(),
7
- # traceback = frame_summary_to_dict(traceback.extract_tb(e.__traceback__))
8
- # traceback = [frame_summary_to_dict(f) for f in traceback.extract_tb(e.__traceback__)]
5
+ from edsl.jobs.FailedQuestion import FailedQuestion
9
6
 
10
7
 
11
8
  class InterviewExceptionEntry:
12
- """Class to record an exception that occurred during the interview.
13
-
14
- >>> entry = InterviewExceptionEntry.example()
15
- >>> entry.to_dict()['exception']
16
- "ValueError('An error occurred.')"
17
- """
18
-
19
- def __init__(self, exception: Exception, traceback_format="html"):
9
+ """Class to record an exception that occurred during the interview."""
10
+
11
+ def __init__(
12
+ self,
13
+ *,
14
+ exception: Exception,
15
+ # failed_question: FailedQuestion,
16
+ invigilator: "Invigilator",
17
+ traceback_format="text",
18
+ answers=None,
19
+ ):
20
20
  self.time = datetime.datetime.now().isoformat()
21
21
  self.exception = exception
22
+ # self.failed_question = failed_question
23
+ self.invigilator = invigilator
22
24
  self.traceback_format = traceback_format
25
+ self.answers = answers
26
+
27
+ @property
28
+ def question_type(self):
29
+ # return self.failed_question.question.question_type
30
+ return self.invigilator.question.question_type
31
+
32
+ @property
33
+ def name(self):
34
+ return repr(self.exception)
35
+
36
+ @property
37
+ def rendered_prompts(self):
38
+ return self.invigilator.get_prompts()
39
+
40
+ @property
41
+ def key_sequence(self):
42
+ return self.invigilator.model.key_sequence
43
+
44
+ @property
45
+ def generated_token_string(self):
46
+ # return "POO"
47
+ if self.invigilator.raw_model_response is None:
48
+ return "No raw model response available."
49
+ else:
50
+ return self.invigilator.model.get_generated_token_string(
51
+ self.invigilator.raw_model_response
52
+ )
53
+
54
+ @property
55
+ def raw_model_response(self):
56
+ import json
57
+
58
+ if self.invigilator.raw_model_response is None:
59
+ return "No raw model response available."
60
+ return json.dumps(self.invigilator.raw_model_response, indent=2)
23
61
 
24
62
  def __getitem__(self, key):
25
63
  # Support dict-like access obj['a']
@@ -27,11 +65,37 @@ class InterviewExceptionEntry:
27
65
 
28
66
  @classmethod
29
67
  def example(cls):
30
- try:
31
- raise ValueError("An error occurred.")
32
- except Exception as e:
33
- entry = InterviewExceptionEntry(e)
34
- return entry
68
+ from edsl import QuestionFreeText
69
+ from edsl.language_models import LanguageModel
70
+
71
+ m = LanguageModel.example(test_model=True)
72
+ q = QuestionFreeText.example(exception_to_throw=ValueError)
73
+ results = q.by(m).run(
74
+ skip_retry=True, print_exceptions=False, raise_validation_errors=True
75
+ )
76
+ return results.task_history.exceptions[0]["how_are_you"][0]
77
+
78
+ @property
79
+ def code_to_reproduce(self):
80
+ return self.code(run=False)
81
+
82
+ def code(self, run=True):
83
+ lines = []
84
+ lines.append("from edsl import Question, Model, Scenario, Agent")
85
+
86
+ lines.append(f"q = {repr(self.invigilator.question)}")
87
+ lines.append(f"scenario = {repr(self.invigilator.scenario)}")
88
+ lines.append(f"agent = {repr(self.invigilator.agent)}")
89
+ lines.append(f"m = Model('{self.invigilator.model.model}')")
90
+ lines.append("results = q.by(m).by(agent).by(scenario).run()")
91
+ code_str = "\n".join(lines)
92
+
93
+ if run:
94
+ # Create a new namespace to avoid polluting the global namespace
95
+ namespace = {}
96
+ exec(code_str, namespace)
97
+ return namespace["results"]
98
+ return code_str
35
99
 
36
100
  @property
37
101
  def traceback(self):
@@ -78,13 +142,15 @@ class InterviewExceptionEntry:
78
142
 
79
143
  >>> entry = InterviewExceptionEntry.example()
80
144
  >>> entry.to_dict()['exception']
81
- "ValueError('An error occurred.')"
145
+ ValueError()
82
146
 
83
147
  """
84
148
  return {
85
- "exception": repr(self.exception),
149
+ "exception": self.exception,
86
150
  "time": self.time,
87
151
  "traceback": self.traceback,
152
+ # "failed_question": self.failed_question.to_dict(),
153
+ "invigilator": self.invigilator.to_dict(),
88
154
  }
89
155
 
90
156
  def push(self):
@@ -1,42 +1,27 @@
1
1
  from __future__ import annotations
2
2
  import time
3
+ import math
3
4
  import asyncio
4
- import time
5
+ import functools
6
+ import threading
7
+ from typing import Coroutine, List, AsyncGenerator, Optional, Union, Generator
5
8
  from contextlib import contextmanager
9
+ from collections import UserList
6
10
 
7
- from typing import Coroutine, List, AsyncGenerator, Optional, Union
11
+ from edsl.results.Results import Results
12
+ from rich.live import Live
13
+ from rich.console import Console
8
14
 
9
15
  from edsl import shared_globals
10
16
  from edsl.jobs.interviews.Interview import Interview
11
- from edsl.jobs.runners.JobsRunnerStatusMixin import JobsRunnerStatusMixin
17
+ from edsl.jobs.runners.JobsRunnerStatus import JobsRunnerStatus
18
+
12
19
  from edsl.jobs.tasks.TaskHistory import TaskHistory
13
20
  from edsl.jobs.buckets.BucketCollection import BucketCollection
14
21
  from edsl.utilities.decorators import jupyter_nb_handler
15
-
16
- import time
17
- import functools
18
-
19
-
20
- def cache_with_timeout(timeout):
21
- def decorator(func):
22
- cached_result = {}
23
- last_computation_time = [0] # Using list to store mutable value
24
-
25
- @functools.wraps(func)
26
- def wrapper(*args, **kwargs):
27
- current_time = time.time()
28
- if (current_time - last_computation_time[0]) >= timeout:
29
- cached_result["value"] = func(*args, **kwargs)
30
- last_computation_time[0] = current_time
31
- return cached_result["value"]
32
-
33
- return wrapper
34
-
35
- return decorator
36
-
37
-
38
- # from queue import Queue
39
- from collections import UserList
22
+ from edsl.data.Cache import Cache
23
+ from edsl.results.Result import Result
24
+ from edsl.results.Results import Results
40
25
 
41
26
 
42
27
  class StatusTracker(UserList):
@@ -48,99 +33,87 @@ class StatusTracker(UserList):
48
33
  return print(f"Completed: {len(self.data)} of {self.total_tasks}", end="\r")
49
34
 
50
35
 
51
- class JobsRunnerAsyncio(JobsRunnerStatusMixin):
36
+ class JobsRunnerAsyncio:
52
37
  """A class for running a collection of interviews asynchronously.
53
38
 
54
39
  It gets instaniated from a Jobs object.
55
40
  The Jobs object is a collection of interviews that are to be run.
56
41
  """
57
42
 
58
- def __init__(self, jobs: Jobs):
43
+ def __init__(self, jobs: "Jobs"):
59
44
  self.jobs = jobs
60
- # this creates the interviews, which can take a while
61
45
  self.interviews: List["Interview"] = jobs.interviews()
62
46
  self.bucket_collection: "BucketCollection" = jobs.bucket_collection
63
47
  self.total_interviews: List["Interview"] = []
64
48
 
49
+ # self.jobs_runner_status = JobsRunnerStatus(self, n=1)
50
+
65
51
  async def run_async_generator(
66
52
  self,
67
53
  cache: "Cache",
68
54
  n: int = 1,
69
- debug: bool = False,
70
55
  stop_on_exception: bool = False,
71
- sidecar_model: "LanguageModel" = None,
56
+ sidecar_model: Optional["LanguageModel"] = None,
72
57
  total_interviews: Optional[List["Interview"]] = None,
58
+ raise_validation_errors: bool = False,
73
59
  ) -> AsyncGenerator["Result", None]:
74
60
  """Creates the tasks, runs them asynchronously, and returns the results as a Results object.
75
61
 
76
62
  Completed tasks are yielded as they are completed.
77
63
 
78
64
  :param n: how many times to run each interview
79
- :param debug:
80
65
  :param stop_on_exception: Whether to stop the interview if an exception is raised
81
66
  :param sidecar_model: a language model to use in addition to the interview's model
82
67
  :param total_interviews: A list of interviews to run can be provided instead.
68
+ :param raise_validation_errors: Whether to raise validation errors
83
69
  """
84
70
  tasks = []
85
- if total_interviews:
71
+ if total_interviews: # was already passed in total interviews
86
72
  self.total_interviews = total_interviews
87
73
  else:
88
- self._populate_total_interviews(
89
- n=n
74
+ self.total_interviews = list(
75
+ self._populate_total_interviews(n=n)
90
76
  ) # Populate self.total_interviews before creating tasks
91
77
 
92
78
  for interview in self.total_interviews:
93
79
  interviewing_task = self._build_interview_task(
94
80
  interview=interview,
95
- debug=debug,
96
81
  stop_on_exception=stop_on_exception,
97
82
  sidecar_model=sidecar_model,
83
+ raise_validation_errors=raise_validation_errors,
98
84
  )
99
85
  tasks.append(asyncio.create_task(interviewing_task))
100
86
 
101
87
  for task in asyncio.as_completed(tasks):
102
88
  result = await task
89
+ self.jobs_runner_status.add_completed_interview(result)
103
90
  yield result
104
91
 
105
- def _populate_total_interviews(self, n: int = 1) -> None:
92
+ def _populate_total_interviews(
93
+ self, n: int = 1
94
+ ) -> Generator["Interview", None, None]:
106
95
  """Populates self.total_interviews with n copies of each interview.
107
96
 
108
97
  :param n: how many times to run each interview.
109
98
  """
110
- # TODO: Why not return a list of interviews instead of modifying the object?
111
-
112
- self.total_interviews = []
113
99
  for interview in self.interviews:
114
100
  for iteration in range(n):
115
101
  if iteration > 0:
116
- new_interview = interview.duplicate(
117
- iteration=iteration, cache=self.cache
118
- )
119
- self.total_interviews.append(new_interview)
102
+ yield interview.duplicate(iteration=iteration, cache=self.cache)
120
103
  else:
121
- interview.cache = (
122
- self.cache
123
- ) # set the cache for the first interview
124
- self.total_interviews.append(interview)
125
-
126
- async def run_async(self, cache=None, n=1) -> Results:
127
- from edsl.results.Results import Results
104
+ interview.cache = self.cache
105
+ yield interview
128
106
 
129
- # breakpoint()
130
- # tracker = StatusTracker(total_tasks=len(self.interviews))
131
-
132
- if cache is None:
133
- self.cache = Cache()
134
- else:
135
- self.cache = cache
107
+ async def run_async(self, cache: Optional["Cache"] = None, n: int = 1) -> Results:
108
+ """Used for some other modules that have a non-standard way of running interviews."""
109
+ self.jobs_runner_status = JobsRunnerStatus(self, n=n)
110
+ self.cache = Cache() if cache is None else cache
136
111
  data = []
137
112
  async for result in self.run_async_generator(cache=self.cache, n=n):
138
113
  data.append(result)
139
114
  return Results(survey=self.jobs.survey, data=data)
140
115
 
141
116
  def simple_run(self):
142
- from edsl.results.Results import Results
143
-
144
117
  data = asyncio.run(self.run_async())
145
118
  return Results(survey=self.jobs.survey, data=data)
146
119
 
@@ -148,14 +121,13 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
148
121
  self,
149
122
  *,
150
123
  interview: Interview,
151
- debug: bool,
152
124
  stop_on_exception: bool = False,
153
- sidecar_model: Optional[LanguageModel] = None,
154
- ) -> Result:
125
+ sidecar_model: Optional["LanguageModel"] = None,
126
+ raise_validation_errors: bool = False,
127
+ ) -> "Result":
155
128
  """Conducts an interview and returns the result.
156
129
 
157
130
  :param interview: the interview to conduct
158
- :param debug: prints debug messages
159
131
  :param stop_on_exception: stops the interview if an exception is raised
160
132
  :param sidecar_model: a language model to use in addition to the interview's model
161
133
  """
@@ -164,24 +136,37 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
164
136
 
165
137
  # get the results of the interview
166
138
  answer, valid_results = await interview.async_conduct_interview(
167
- debug=debug,
168
139
  model_buckets=model_buckets,
169
140
  stop_on_exception=stop_on_exception,
170
141
  sidecar_model=sidecar_model,
142
+ raise_validation_errors=raise_validation_errors,
171
143
  )
172
144
 
173
- # we should have a valid result for each question
174
- answer_key_names = {k for k in set(answer.keys()) if not k.endswith("_comment")}
145
+ question_results = {}
146
+ for result in valid_results:
147
+ question_results[result.question_name] = result
175
148
 
149
+ answer_key_names = list(question_results.keys())
150
+
151
+ generated_tokens_dict = {
152
+ k + "_generated_tokens": question_results[k].generated_tokens
153
+ for k in answer_key_names
154
+ }
155
+ comments_dict = {
156
+ k + "_comment": question_results[k].comment for k in answer_key_names
157
+ }
158
+
159
+ # we should have a valid result for each question
160
+ answer_dict = {k: answer[k] for k in answer_key_names}
176
161
  assert len(valid_results) == len(answer_key_names)
177
162
 
178
163
  # TODO: move this down into Interview
179
164
  question_name_to_prompts = dict({})
180
165
  for result in valid_results:
181
- question_name = result["question_name"]
166
+ question_name = result.question_name
182
167
  question_name_to_prompts[question_name] = {
183
- "user_prompt": result["prompts"]["user_prompt"],
184
- "system_prompt": result["prompts"]["system_prompt"],
168
+ "user_prompt": result.prompts["user_prompt"],
169
+ "system_prompt": result.prompts["system_prompt"],
185
170
  }
186
171
 
187
172
  prompt_dictionary = {}
@@ -195,22 +180,31 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
195
180
 
196
181
  raw_model_results_dictionary = {}
197
182
  for result in valid_results:
198
- question_name = result["question_name"]
183
+ question_name = result.question_name
199
184
  raw_model_results_dictionary[
200
185
  question_name + "_raw_model_response"
201
- ] = result["raw_model_response"]
202
-
203
- from edsl.results.Result import Result
186
+ ] = result.raw_model_response
187
+ raw_model_results_dictionary[question_name + "_cost"] = result.cost
188
+ one_use_buys = (
189
+ "NA"
190
+ if isinstance(result.cost, str)
191
+ or result.cost == 0
192
+ or result.cost is None
193
+ else 1.0 / result.cost
194
+ )
195
+ raw_model_results_dictionary[question_name + "_one_usd_buys"] = one_use_buys
204
196
 
205
197
  result = Result(
206
198
  agent=interview.agent,
207
199
  scenario=interview.scenario,
208
200
  model=interview.model,
209
201
  iteration=interview.iteration,
210
- answer=answer,
202
+ answer=answer_dict,
211
203
  prompt=prompt_dictionary,
212
204
  raw_model_response=raw_model_results_dictionary,
213
205
  survey=interview.survey,
206
+ generated_tokens=generated_tokens_dict,
207
+ comments_dict=comments_dict,
214
208
  )
215
209
  result.interview_hash = hash(interview)
216
210
 
@@ -220,132 +214,109 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
220
214
  def elapsed_time(self):
221
215
  return time.monotonic() - self.start_time
222
216
 
217
+ def process_results(
218
+ self, raw_results: Results, cache: Cache, print_exceptions: bool
219
+ ):
220
+ interview_lookup = {
221
+ hash(interview): index
222
+ for index, interview in enumerate(self.total_interviews)
223
+ }
224
+ interview_hashes = list(interview_lookup.keys())
225
+
226
+ results = Results(
227
+ survey=self.jobs.survey,
228
+ data=sorted(
229
+ raw_results, key=lambda x: interview_hashes.index(x.interview_hash)
230
+ ),
231
+ )
232
+ results.cache = cache
233
+ results.task_history = TaskHistory(
234
+ self.total_interviews, include_traceback=False
235
+ )
236
+ results.has_unfixed_exceptions = results.task_history.has_unfixed_exceptions
237
+ results.bucket_collection = self.bucket_collection
238
+
239
+ if results.has_unfixed_exceptions and print_exceptions:
240
+ from edsl.scenarios.FileStore import HTMLFileStore
241
+ from edsl.config import CONFIG
242
+ from edsl.coop.coop import Coop
243
+
244
+ msg = f"Exceptions were raised in {len(results.task_history.indices)} out of {len(self.total_interviews)} interviews.\n"
245
+
246
+ if len(results.task_history.indices) > 5:
247
+ msg += f"Exceptions were raised in the following interviews: {results.task_history.indices}.\n"
248
+
249
+ print(msg)
250
+ # this is where exceptions are opening up
251
+ filepath = results.task_history.html(
252
+ cta="Open report to see details.",
253
+ open_in_browser=True,
254
+ return_link=True,
255
+ )
256
+
257
+ try:
258
+ coop = Coop()
259
+ user_edsl_settings = coop.edsl_settings
260
+ remote_logging = user_edsl_settings["remote_logging"]
261
+ except Exception as e:
262
+ print(e)
263
+ remote_logging = False
264
+ if remote_logging:
265
+ filestore = HTMLFileStore(filepath)
266
+ coop_details = filestore.push(description="Error report")
267
+ print(coop_details)
268
+
269
+ print("Also see: https://docs.expectedparrot.com/en/latest/exceptions.html")
270
+
271
+ return results
272
+
223
273
  @jupyter_nb_handler
224
274
  async def run(
225
275
  self,
226
276
  cache: Union[Cache, False, None],
227
277
  n: int = 1,
228
- debug: bool = False,
229
278
  stop_on_exception: bool = False,
230
279
  progress_bar: bool = False,
231
280
  sidecar_model: Optional[LanguageModel] = None,
232
281
  print_exceptions: bool = True,
282
+ raise_validation_errors: bool = False,
233
283
  ) -> "Coroutine":
234
284
  """Runs a collection of interviews, handling both async and sync contexts."""
235
- from rich.console import Console
236
285
 
237
- console = Console()
238
286
  self.results = []
239
287
  self.start_time = time.monotonic()
240
288
  self.completed = False
241
289
  self.cache = cache
242
290
  self.sidecar_model = sidecar_model
243
291
 
244
- from edsl.results.Results import Results
245
- from rich.live import Live
246
- from rich.console import Console
247
-
248
- @cache_with_timeout(1)
249
- def generate_table():
250
- return self.status_table(self.results, self.elapsed_time)
292
+ self.jobs_runner_status = JobsRunnerStatus(self, n=n)
251
293
 
252
- async def process_results(cache, progress_bar_context=None):
294
+ async def process_results(cache):
253
295
  """Processes results from interviews."""
254
296
  async for result in self.run_async_generator(
255
297
  n=n,
256
- debug=debug,
257
298
  stop_on_exception=stop_on_exception,
258
299
  cache=cache,
259
300
  sidecar_model=sidecar_model,
301
+ raise_validation_errors=raise_validation_errors,
260
302
  ):
261
303
  self.results.append(result)
262
- if progress_bar_context:
263
- progress_bar_context.update(generate_table())
264
- self.completed = True
265
-
266
- async def update_progress_bar(progress_bar_context):
267
- """Updates the progress bar at fixed intervals."""
268
- if progress_bar_context is None:
269
- return
270
-
271
- while True:
272
- progress_bar_context.update(generate_table())
273
- await asyncio.sleep(0.1) # Update interval
274
- if self.completed:
275
- break
276
-
277
- @contextmanager
278
- def conditional_context(condition, context_manager):
279
- if condition:
280
- with context_manager as cm:
281
- yield cm
282
- else:
283
- yield
284
-
285
- with conditional_context(
286
- progress_bar, Live(generate_table(), console=console, refresh_per_second=1)
287
- ) as progress_bar_context:
288
- with cache as c:
289
- progress_task = asyncio.create_task(
290
- update_progress_bar(progress_bar_context)
291
- )
292
-
293
- try:
294
- await asyncio.gather(
295
- progress_task,
296
- process_results(
297
- cache=c, progress_bar_context=progress_bar_context
298
- ),
299
- )
300
- except asyncio.CancelledError:
301
- pass
302
- finally:
303
- progress_task.cancel() # Cancel the progress_task when process_results is done
304
- await progress_task
305
-
306
- await asyncio.sleep(1) # short delay to show the final status
307
-
308
- if progress_bar_context:
309
- progress_bar_context.update(generate_table())
310
-
311
- # puts results in the same order as the total interviews
312
- interview_hashes = [hash(interview) for interview in self.total_interviews]
313
- self.results = sorted(
314
- self.results, key=lambda x: interview_hashes.index(x.interview_hash)
315
- )
304
+ self.completed = True
316
305
 
317
- results = Results(survey=self.jobs.survey, data=self.results)
318
- task_history = TaskHistory(self.total_interviews, include_traceback=False)
319
- results.task_history = task_history
320
-
321
- results.has_exceptions = task_history.has_exceptions
322
-
323
- if results.has_exceptions:
324
- # put the failed interviews in the results object as a list
325
- failed_interviews = [
326
- interview.duplicate(
327
- iteration=interview.iteration, cache=interview.cache
328
- )
329
- for interview in self.total_interviews
330
- if interview.has_exceptions
331
- ]
332
- from edsl.jobs.Jobs import Jobs
333
-
334
- results.failed_jobs = Jobs.from_interviews(
335
- [interview for interview in failed_interviews]
336
- )
337
- if print_exceptions:
338
- msg = f"Exceptions were raised in {len(results.task_history.indices)} out of {len(self.total_interviews)} interviews.\n"
306
+ def run_progress_bar():
307
+ """Runs the progress bar in a separate thread."""
308
+ self.jobs_runner_status.update_progress()
339
309
 
340
- if len(results.task_history.indices) > 5:
341
- msg += f"Exceptions were raised in the following interviews: {results.task_history.indices}.\n"
310
+ if progress_bar:
311
+ progress_thread = threading.Thread(target=run_progress_bar)
312
+ progress_thread.start()
342
313
 
343
- shared_globals["edsl_runner_exceptions"] = task_history
344
- print(msg)
345
- # this is where exceptions are opening up
346
- task_history.html(cta="Open report to see details.")
347
- print(
348
- "Also see: https://docs.expectedparrot.com/en/latest/exceptions.html"
349
- )
314
+ with cache as c:
315
+ await process_results(cache=c)
350
316
 
351
- return results
317
+ if progress_bar:
318
+ progress_thread.join()
319
+
320
+ return self.process_results(
321
+ raw_results=self.results, cache=cache, print_exceptions=print_exceptions
322
+ )