edsl 0.1.30.dev4__py3-none-any.whl → 0.1.31__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. edsl/__version__.py +1 -1
  2. edsl/agents/Invigilator.py +7 -2
  3. edsl/agents/PromptConstructionMixin.py +18 -1
  4. edsl/config.py +4 -0
  5. edsl/conjure/Conjure.py +6 -0
  6. edsl/coop/coop.py +4 -0
  7. edsl/coop/utils.py +9 -1
  8. edsl/data/CacheHandler.py +3 -4
  9. edsl/enums.py +2 -0
  10. edsl/inference_services/DeepInfraService.py +6 -91
  11. edsl/inference_services/GroqService.py +18 -0
  12. edsl/inference_services/InferenceServicesCollection.py +13 -5
  13. edsl/inference_services/OpenAIService.py +64 -21
  14. edsl/inference_services/registry.py +2 -1
  15. edsl/jobs/Jobs.py +80 -33
  16. edsl/jobs/buckets/TokenBucket.py +24 -5
  17. edsl/jobs/interviews/Interview.py +122 -75
  18. edsl/jobs/interviews/InterviewExceptionEntry.py +101 -0
  19. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +58 -52
  20. edsl/jobs/interviews/interview_exception_tracking.py +68 -10
  21. edsl/jobs/runners/JobsRunnerAsyncio.py +112 -81
  22. edsl/jobs/runners/JobsRunnerStatusData.py +0 -237
  23. edsl/jobs/runners/JobsRunnerStatusMixin.py +291 -35
  24. edsl/jobs/tasks/QuestionTaskCreator.py +1 -5
  25. edsl/jobs/tasks/TaskCreators.py +8 -2
  26. edsl/jobs/tasks/TaskHistory.py +145 -1
  27. edsl/language_models/LanguageModel.py +135 -75
  28. edsl/language_models/ModelList.py +8 -2
  29. edsl/language_models/registry.py +16 -0
  30. edsl/questions/QuestionFunctional.py +34 -2
  31. edsl/questions/QuestionMultipleChoice.py +58 -8
  32. edsl/questions/QuestionNumerical.py +0 -1
  33. edsl/questions/descriptors.py +42 -2
  34. edsl/results/DatasetExportMixin.py +258 -75
  35. edsl/results/Result.py +53 -5
  36. edsl/results/Results.py +66 -27
  37. edsl/results/ResultsToolsMixin.py +1 -1
  38. edsl/scenarios/Scenario.py +14 -0
  39. edsl/scenarios/ScenarioList.py +59 -21
  40. edsl/scenarios/ScenarioListExportMixin.py +16 -5
  41. edsl/scenarios/ScenarioListPdfMixin.py +3 -0
  42. edsl/study/Study.py +2 -2
  43. edsl/surveys/Survey.py +35 -1
  44. {edsl-0.1.30.dev4.dist-info → edsl-0.1.31.dist-info}/METADATA +4 -2
  45. {edsl-0.1.30.dev4.dist-info → edsl-0.1.31.dist-info}/RECORD +47 -45
  46. {edsl-0.1.30.dev4.dist-info → edsl-0.1.31.dist-info}/WHEEL +1 -1
  47. {edsl-0.1.30.dev4.dist-info → edsl-0.1.31.dist-info}/LICENSE +0 -0
@@ -12,20 +12,38 @@ from edsl.exceptions import InterviewTimeoutError
12
12
  # from edsl.questions.QuestionBase import QuestionBase
13
13
  from edsl.surveys.base import EndOfSurvey
14
14
  from edsl.jobs.buckets.ModelBuckets import ModelBuckets
15
- from edsl.jobs.interviews.interview_exception_tracking import InterviewExceptionEntry
15
+ from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
16
16
  from edsl.jobs.interviews.retry_management import retry_strategy
17
17
  from edsl.jobs.tasks.task_status_enum import TaskStatus
18
18
  from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
19
19
 
20
20
  # from edsl.agents.InvigilatorBase import InvigilatorBase
21
21
 
22
+ from rich.console import Console
23
+ from rich.traceback import Traceback
24
+
22
25
  TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
23
26
 
24
27
 
28
+ def frame_summary_to_dict(frame):
29
+ """
30
+ Convert a FrameSummary object to a dictionary.
31
+
32
+ :param frame: A traceback FrameSummary object
33
+ :return: A dictionary containing the frame's details
34
+ """
35
+ return {
36
+ "filename": frame.filename,
37
+ "lineno": frame.lineno,
38
+ "name": frame.name,
39
+ "line": frame.line,
40
+ }
41
+
42
+
25
43
  class InterviewTaskBuildingMixin:
26
44
  def _build_invigilators(
27
45
  self, debug: bool
28
- ) -> Generator[InvigilatorBase, None, None]:
46
+ ) -> Generator["InvigilatorBase", None, None]:
29
47
  """Create an invigilator for each question.
30
48
 
31
49
  :param debug: whether to use debug mode, in which case `InvigilatorDebug` is used.
@@ -35,7 +53,7 @@ class InterviewTaskBuildingMixin:
35
53
  for question in self.survey.questions:
36
54
  yield self._get_invigilator(question=question, debug=debug)
37
55
 
38
- def _get_invigilator(self, question: QuestionBase, debug: bool) -> "Invigilator":
56
+ def _get_invigilator(self, question: "QuestionBase", debug: bool) -> "Invigilator":
39
57
  """Return an invigilator for the given question.
40
58
 
41
59
  :param question: the question to be answered
@@ -84,7 +102,7 @@ class InterviewTaskBuildingMixin:
84
102
  return tuple(tasks) # , invigilators
85
103
 
86
104
  def _get_tasks_that_must_be_completed_before(
87
- self, *, tasks: list[asyncio.Task], question: QuestionBase
105
+ self, *, tasks: list[asyncio.Task], question: "QuestionBase"
88
106
  ) -> Generator[asyncio.Task, None, None]:
89
107
  """Return the tasks that must be completed before the given question can be answered.
90
108
 
@@ -100,7 +118,7 @@ class InterviewTaskBuildingMixin:
100
118
  def _create_question_task(
101
119
  self,
102
120
  *,
103
- question: QuestionBase,
121
+ question: "QuestionBase",
104
122
  tasks_that_must_be_completed_before: list[asyncio.Task],
105
123
  model_buckets: ModelBuckets,
106
124
  debug: bool,
@@ -148,7 +166,6 @@ class InterviewTaskBuildingMixin:
148
166
  raise ValueError(f"Prompt is of type {type(prompt)}")
149
167
  return len(combined_text) / 4.0
150
168
 
151
- @retry_strategy
152
169
  async def _answer_question_and_record_task(
153
170
  self,
154
171
  *,
@@ -163,36 +180,33 @@ class InterviewTaskBuildingMixin:
163
180
  """
164
181
  from edsl.data_transfer_models import AgentResponseDict
165
182
 
166
- try:
167
- invigilator = self._get_invigilator(question, debug=debug)
183
+ async def _inner():
184
+ try:
185
+ invigilator = self._get_invigilator(question, debug=debug)
168
186
 
169
- if self._skip_this_question(question):
170
- return invigilator.get_failed_task_result()
187
+ if self._skip_this_question(question):
188
+ return invigilator.get_failed_task_result()
171
189
 
172
- response: AgentResponseDict = await self._attempt_to_answer_question(
173
- invigilator, task
174
- )
190
+ response: AgentResponseDict = await self._attempt_to_answer_question(
191
+ invigilator, task
192
+ )
175
193
 
176
- self._add_answer(response=response, question=question)
194
+ self._add_answer(response=response, question=question)
177
195
 
178
- # With the answer to the question, we can now cancel any skipped questions
179
- self._cancel_skipped_questions(question)
180
- return AgentResponseDict(**response)
181
- except Exception as e:
182
- raise e
183
- # import traceback
184
- # print("Exception caught:")
185
- # traceback.print_exc()
186
-
187
- # # Extract and print the traceback info
188
- # tb = e.__traceback__
189
- # while tb is not None:
190
- # print(f"File {tb.tb_frame.f_code.co_filename}, line {tb.tb_lineno}, in {tb.tb_frame.f_code.co_name}")
191
- # tb = tb.tb_next
192
- # breakpoint()
193
- # raise e
194
-
195
- def _add_answer(self, response: AgentResponseDict, question: QuestionBase) -> None:
196
+ self._cancel_skipped_questions(question)
197
+ return AgentResponseDict(**response)
198
+ except Exception as e:
199
+ raise e
200
+
201
+ skip_rety = getattr(self, "skip_retry", False)
202
+ if not skip_rety:
203
+ _inner = retry_strategy(_inner)
204
+
205
+ return await _inner()
206
+
207
+ def _add_answer(
208
+ self, response: "AgentResponseDict", question: "QuestionBase"
209
+ ) -> None:
196
210
  """Add the answer to the answers dictionary.
197
211
 
198
212
  :param response: the response to the question.
@@ -200,7 +214,7 @@ class InterviewTaskBuildingMixin:
200
214
  """
201
215
  self.answers.add_answer(response=response, question=question)
202
216
 
203
- def _skip_this_question(self, current_question: QuestionBase) -> bool:
217
+ def _skip_this_question(self, current_question: "QuestionBase") -> bool:
204
218
  """Determine if the current question should be skipped.
205
219
 
206
220
  :param current_question: the question to be answered.
@@ -213,38 +227,30 @@ class InterviewTaskBuildingMixin:
213
227
  )
214
228
  return skip
215
229
 
230
+ def _handle_exception(self, e, question_name: str, task=None):
231
+ exception_entry = InterviewExceptionEntry(e)
232
+ if task:
233
+ task.task_status = TaskStatus.FAILED
234
+ self.exceptions.add(question_name, exception_entry)
235
+
216
236
  async def _attempt_to_answer_question(
217
- self, invigilator: InvigilatorBase, task: asyncio.Task
218
- ) -> AgentResponseDict:
237
+ self, invigilator: "InvigilatorBase", task: asyncio.Task
238
+ ) -> "AgentResponseDict":
219
239
  """Attempt to answer the question, and handle exceptions.
220
240
 
221
241
  :param invigilator: the invigilator that will answer the question.
222
242
  :param task: the task that is being run.
243
+
223
244
  """
224
245
  try:
225
246
  return await asyncio.wait_for(
226
247
  invigilator.async_answer_question(), timeout=TIMEOUT
227
248
  )
228
249
  except asyncio.TimeoutError as e:
229
- exception_entry = InterviewExceptionEntry(
230
- exception=repr(e),
231
- time=time.time(),
232
- traceback=traceback.format_exc(),
233
- )
234
- if task:
235
- task.task_status = TaskStatus.FAILED
236
- self.exceptions.add(invigilator.question.question_name, exception_entry)
237
-
250
+ self._handle_exception(e, invigilator.question.question_name, task)
238
251
  raise InterviewTimeoutError(f"Task timed out after {TIMEOUT} seconds.")
239
252
  except Exception as e:
240
- exception_entry = InterviewExceptionEntry(
241
- exception=repr(e),
242
- time=time.time(),
243
- traceback=traceback.format_exc(),
244
- )
245
- if task:
246
- task.task_status = TaskStatus.FAILED
247
- self.exceptions.add(invigilator.question.question_name, exception_entry)
253
+ self._handle_exception(e, invigilator.question.question_name, task)
248
254
  raise e
249
255
 
250
256
  def _cancel_skipped_questions(self, current_question: QuestionBase) -> None:
@@ -1,18 +1,70 @@
1
- from rich.console import Console
2
- from rich.table import Table
1
+ import traceback
2
+ import datetime
3
+ import time
3
4
  from collections import UserDict
4
5
 
6
+ from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
5
7
 
6
- class InterviewExceptionEntry(UserDict):
7
- """Class to record an exception that occurred during the interview."""
8
+ # #traceback=traceback.format_exc(),
9
+ # #traceback = frame_summary_to_dict(traceback.extract_tb(e.__traceback__))
10
+ # #traceback = [frame_summary_to_dict(f) for f in traceback.extract_tb(e.__traceback__)]
8
11
 
9
- def __init__(self, exception, time, traceback):
10
- data = {"exception": exception, "time": time, "traceback": traceback}
11
- super().__init__(data)
12
+ # class InterviewExceptionEntry:
13
+ # """Class to record an exception that occurred during the interview.
12
14
 
13
- def to_dict(self) -> dict:
14
- """Return the exception as a dictionary."""
15
- return self.data
15
+ # >>> entry = InterviewExceptionEntry.example()
16
+ # >>> entry.to_dict()['exception']
17
+ # "ValueError('An error occurred.')"
18
+ # """
19
+
20
+ # def __init__(self, exception: Exception):
21
+ # self.time = datetime.datetime.now().isoformat()
22
+ # self.exception = exception
23
+
24
+ # def __getitem__(self, key):
25
+ # # Support dict-like access obj['a']
26
+ # return str(getattr(self, key))
27
+
28
+ # @classmethod
29
+ # def example(cls):
30
+ # try:
31
+ # raise ValueError("An error occurred.")
32
+ # except Exception as e:
33
+ # entry = InterviewExceptionEntry(e)
34
+ # return entry
35
+
36
+ # @property
37
+ # def traceback(self):
38
+ # """Return the exception as HTML."""
39
+ # e = self.exception
40
+ # tb_str = ''.join(traceback.format_exception(type(e), e, e.__traceback__))
41
+ # return tb_str
42
+
43
+
44
+ # @property
45
+ # def html(self):
46
+ # from rich.console import Console
47
+ # from rich.table import Table
48
+ # from rich.traceback import Traceback
49
+
50
+ # from io import StringIO
51
+ # html_output = StringIO()
52
+
53
+ # console = Console(file=html_output, record=True)
54
+ # tb = Traceback(show_locals=True)
55
+ # console.print(tb)
56
+
57
+ # tb = Traceback.from_exception(type(self.exception), self.exception, self.exception.__traceback__, show_locals=True)
58
+ # console.print(tb)
59
+ # return html_output.getvalue()
60
+
61
+ # def to_dict(self) -> dict:
62
+ # """Return the exception as a dictionary."""
63
+ # return {
64
+ # 'exception': repr(self.exception),
65
+ # 'time': self.time,
66
+ # 'traceback': self.traceback
67
+ # }
16
68
 
17
69
 
18
70
  class InterviewExceptionCollection(UserDict):
@@ -84,3 +136,9 @@ class InterviewExceptionCollection(UserDict):
84
136
  )
85
137
 
86
138
  console.print(table)
139
+
140
+
141
+ if __name__ == "__main__":
142
+ import doctest
143
+
144
+ doctest.testmod(optionflags=doctest.ELLIPSIS)
@@ -13,6 +13,40 @@ from edsl.jobs.tasks.TaskHistory import TaskHistory
13
13
  from edsl.jobs.buckets.BucketCollection import BucketCollection
14
14
  from edsl.utilities.decorators import jupyter_nb_handler
15
15
 
16
+ import time
17
+ import functools
18
+
19
+
20
+ def cache_with_timeout(timeout):
21
+ def decorator(func):
22
+ cached_result = {}
23
+ last_computation_time = [0] # Using list to store mutable value
24
+
25
+ @functools.wraps(func)
26
+ def wrapper(*args, **kwargs):
27
+ current_time = time.time()
28
+ if (current_time - last_computation_time[0]) >= timeout:
29
+ cached_result["value"] = func(*args, **kwargs)
30
+ last_computation_time[0] = current_time
31
+ return cached_result["value"]
32
+
33
+ return wrapper
34
+
35
+ return decorator
36
+
37
+
38
+ # from queue import Queue
39
+ from collections import UserList
40
+
41
+
42
+ class StatusTracker(UserList):
43
+ def __init__(self, total_tasks: int):
44
+ self.total_tasks = total_tasks
45
+ super().__init__()
46
+
47
+ def current_status(self):
48
+ return print(f"Completed: {len(self.data)} of {self.total_tasks}", end="\r")
49
+
16
50
 
17
51
  class JobsRunnerAsyncio(JobsRunnerStatusMixin):
18
52
  """A class for running a collection of interviews asynchronously.
@@ -43,7 +77,9 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
43
77
 
44
78
  :param n: how many times to run each interview
45
79
  :param debug:
46
- :param stop_on_exception:
80
+ :param stop_on_exception: Whether to stop the interview if an exception is raised
81
+ :param sidecar_model: a language model to use in addition to the interview's model
82
+ :param total_interviews: A list of interviews to run can be provided instead.
47
83
  """
48
84
  tasks = []
49
85
  if total_interviews:
@@ -87,15 +123,18 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
87
123
  ) # set the cache for the first interview
88
124
  self.total_interviews.append(interview)
89
125
 
90
- async def run_async(self, cache=None) -> Results:
126
+ async def run_async(self, cache=None, n=1) -> Results:
91
127
  from edsl.results.Results import Results
92
128
 
129
+ # breakpoint()
130
+ # tracker = StatusTracker(total_tasks=len(self.interviews))
131
+
93
132
  if cache is None:
94
133
  self.cache = Cache()
95
134
  else:
96
135
  self.cache = cache
97
136
  data = []
98
- async for result in self.run_async_generator(cache=self.cache):
137
+ async for result in self.run_async_generator(cache=self.cache, n=n):
99
138
  data.append(result)
100
139
  return Results(survey=self.jobs.survey, data=data)
101
140
 
@@ -173,6 +212,8 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
173
212
  raw_model_response=raw_model_results_dictionary,
174
213
  survey=interview.survey,
175
214
  )
215
+ result.interview_hash = hash(interview)
216
+
176
217
  return result
177
218
 
178
219
  @property
@@ -201,97 +242,86 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
201
242
  self.sidecar_model = sidecar_model
202
243
 
203
244
  from edsl.results.Results import Results
245
+ from rich.live import Live
246
+ from rich.console import Console
204
247
 
205
- if not progress_bar:
206
- # print("Running without progress bar")
207
- with cache as c:
208
-
209
- async def process_results():
210
- """Processes results from interviews."""
211
- async for result in self.run_async_generator(
212
- n=n,
213
- debug=debug,
214
- stop_on_exception=stop_on_exception,
215
- cache=c,
216
- sidecar_model=sidecar_model,
217
- ):
218
- self.results.append(result)
219
- self.completed = True
220
-
221
- await asyncio.gather(process_results())
222
-
223
- results = Results(survey=self.jobs.survey, data=self.results)
224
- else:
225
- # print("Running with progress bar")
226
- from rich.live import Live
227
- from rich.console import Console
228
-
229
- def generate_table():
230
- return self.status_table(self.results, self.elapsed_time)
248
+ @cache_with_timeout(1)
249
+ def generate_table():
250
+ return self.status_table(self.results, self.elapsed_time)
231
251
 
232
- @contextmanager
233
- def no_op_cm():
234
- """A no-op context manager with a dummy update method."""
235
- yield DummyLive()
252
+ async def process_results(cache, progress_bar_context=None):
253
+ """Processes results from interviews."""
254
+ async for result in self.run_async_generator(
255
+ n=n,
256
+ debug=debug,
257
+ stop_on_exception=stop_on_exception,
258
+ cache=cache,
259
+ sidecar_model=sidecar_model,
260
+ ):
261
+ self.results.append(result)
262
+ if progress_bar_context:
263
+ progress_bar_context.update(generate_table())
264
+ self.completed = True
265
+
266
+ async def update_progress_bar(progress_bar_context):
267
+ """Updates the progress bar at fixed intervals."""
268
+ if progress_bar_context is None:
269
+ return
270
+
271
+ while True:
272
+ progress_bar_context.update(generate_table())
273
+ await asyncio.sleep(0.1) # Update interval
274
+ if self.completed:
275
+ break
276
+
277
+ @contextmanager
278
+ def conditional_context(condition, context_manager):
279
+ if condition:
280
+ with context_manager as cm:
281
+ yield cm
282
+ else:
283
+ yield
284
+
285
+ with conditional_context(
286
+ progress_bar, Live(generate_table(), console=console, refresh_per_second=1)
287
+ ) as progress_bar_context:
288
+ with cache as c:
289
+ progress_task = asyncio.create_task(
290
+ update_progress_bar(progress_bar_context)
291
+ )
236
292
 
237
- class DummyLive:
238
- def update(self, *args, **kwargs):
239
- """A dummy update method that does nothing."""
293
+ try:
294
+ await asyncio.gather(
295
+ progress_task,
296
+ process_results(
297
+ cache=c, progress_bar_context=progress_bar_context
298
+ ),
299
+ )
300
+ except asyncio.CancelledError:
240
301
  pass
302
+ finally:
303
+ progress_task.cancel() # Cancel the progress_task when process_results is done
304
+ await progress_task
241
305
 
242
- progress_bar_context = (
243
- Live(generate_table(), console=console, refresh_per_second=5)
244
- if progress_bar
245
- else no_op_cm()
246
- )
306
+ await asyncio.sleep(1) # short delay to show the final status
247
307
 
248
- with cache as c:
249
- with progress_bar_context as live:
250
-
251
- async def update_progress_bar():
252
- """Updates the progress bar at fixed intervals."""
253
- while True:
254
- live.update(generate_table())
255
- await asyncio.sleep(0.00001) # Update interval
256
- if self.completed:
257
- break
258
-
259
- async def process_results():
260
- """Processes results from interviews."""
261
- async for result in self.run_async_generator(
262
- n=n,
263
- debug=debug,
264
- stop_on_exception=stop_on_exception,
265
- cache=c,
266
- sidecar_model=sidecar_model,
267
- ):
268
- self.results.append(result)
269
- live.update(generate_table())
270
- self.completed = True
271
-
272
- progress_task = asyncio.create_task(update_progress_bar())
273
-
274
- try:
275
- await asyncio.gather(process_results(), progress_task)
276
- except asyncio.CancelledError:
277
- pass
278
- finally:
279
- progress_task.cancel() # Cancel the progress_task when process_results is done
280
- await progress_task
281
-
282
- await asyncio.sleep(1) # short delay to show the final status
283
-
284
- # one more update
285
- live.update(generate_table())
286
-
287
- results = Results(survey=self.jobs.survey, data=self.results)
308
+ if progress_bar_context:
309
+ progress_bar_context.update(generate_table())
310
+
311
+ # puts results in the same order as the total interviews
312
+ interview_hashes = [hash(interview) for interview in self.total_interviews]
313
+ self.results = sorted(
314
+ self.results, key=lambda x: interview_hashes.index(x.interview_hash)
315
+ )
288
316
 
317
+ results = Results(survey=self.jobs.survey, data=self.results)
289
318
  task_history = TaskHistory(self.total_interviews, include_traceback=False)
290
319
  results.task_history = task_history
291
320
 
292
321
  results.has_exceptions = task_history.has_exceptions
293
322
 
294
323
  if results.has_exceptions:
324
+ # put the failed interviews in the results object as a list
295
325
  failed_interviews = [
296
326
  interview.duplicate(
297
327
  iteration=interview.iteration, cache=interview.cache
@@ -312,6 +342,7 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
312
342
 
313
343
  shared_globals["edsl_runner_exceptions"] = task_history
314
344
  print(msg)
345
+ # this is where exceptions are opening up
315
346
  task_history.html(cta="Open report to see details.")
316
347
  print(
317
348
  "Also see: https://docs.expectedparrot.com/en/latest/exceptions.html"