edsl 0.1.31__py3-none-any.whl → 0.1.31.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. edsl/__version__.py +1 -1
  2. edsl/agents/Invigilator.py +2 -7
  3. edsl/agents/PromptConstructionMixin.py +4 -9
  4. edsl/config.py +0 -4
  5. edsl/conjure/Conjure.py +0 -6
  6. edsl/coop/coop.py +0 -4
  7. edsl/data/CacheHandler.py +4 -3
  8. edsl/enums.py +0 -2
  9. edsl/inference_services/DeepInfraService.py +91 -6
  10. edsl/inference_services/InferenceServicesCollection.py +8 -13
  11. edsl/inference_services/OpenAIService.py +21 -64
  12. edsl/inference_services/registry.py +1 -2
  13. edsl/jobs/Jobs.py +5 -29
  14. edsl/jobs/buckets/TokenBucket.py +4 -12
  15. edsl/jobs/interviews/Interview.py +9 -31
  16. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +33 -49
  17. edsl/jobs/interviews/interview_exception_tracking.py +10 -68
  18. edsl/jobs/runners/JobsRunnerAsyncio.py +81 -112
  19. edsl/jobs/runners/JobsRunnerStatusData.py +237 -0
  20. edsl/jobs/runners/JobsRunnerStatusMixin.py +35 -291
  21. edsl/jobs/tasks/TaskCreators.py +2 -8
  22. edsl/jobs/tasks/TaskHistory.py +1 -145
  23. edsl/language_models/LanguageModel.py +32 -49
  24. edsl/language_models/registry.py +0 -4
  25. edsl/questions/QuestionMultipleChoice.py +1 -1
  26. edsl/questions/QuestionNumerical.py +1 -0
  27. edsl/results/DatasetExportMixin.py +3 -12
  28. edsl/scenarios/Scenario.py +0 -14
  29. edsl/scenarios/ScenarioList.py +2 -15
  30. edsl/scenarios/ScenarioListExportMixin.py +4 -15
  31. edsl/scenarios/ScenarioListPdfMixin.py +0 -3
  32. {edsl-0.1.31.dist-info → edsl-0.1.31.dev2.dist-info}/METADATA +1 -2
  33. {edsl-0.1.31.dist-info → edsl-0.1.31.dev2.dist-info}/RECORD +35 -37
  34. edsl/inference_services/GroqService.py +0 -18
  35. edsl/jobs/interviews/InterviewExceptionEntry.py +0 -101
  36. {edsl-0.1.31.dist-info → edsl-0.1.31.dev2.dist-info}/LICENSE +0 -0
  37. {edsl-0.1.31.dist-info → edsl-0.1.31.dev2.dist-info}/WHEEL +0 -0
@@ -12,34 +12,16 @@ from edsl.exceptions import InterviewTimeoutError
12
12
  # from edsl.questions.QuestionBase import QuestionBase
13
13
  from edsl.surveys.base import EndOfSurvey
14
14
  from edsl.jobs.buckets.ModelBuckets import ModelBuckets
15
- from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
15
+ from edsl.jobs.interviews.interview_exception_tracking import InterviewExceptionEntry
16
16
  from edsl.jobs.interviews.retry_management import retry_strategy
17
17
  from edsl.jobs.tasks.task_status_enum import TaskStatus
18
18
  from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
19
19
 
20
20
  # from edsl.agents.InvigilatorBase import InvigilatorBase
21
21
 
22
- from rich.console import Console
23
- from rich.traceback import Traceback
24
-
25
22
  TIMEOUT = float(CONFIG.get("EDSL_API_TIMEOUT"))
26
23
 
27
24
 
28
- def frame_summary_to_dict(frame):
29
- """
30
- Convert a FrameSummary object to a dictionary.
31
-
32
- :param frame: A traceback FrameSummary object
33
- :return: A dictionary containing the frame's details
34
- """
35
- return {
36
- "filename": frame.filename,
37
- "lineno": frame.lineno,
38
- "name": frame.name,
39
- "line": frame.line,
40
- }
41
-
42
-
43
25
  class InterviewTaskBuildingMixin:
44
26
  def _build_invigilators(
45
27
  self, debug: bool
@@ -166,6 +148,7 @@ class InterviewTaskBuildingMixin:
166
148
  raise ValueError(f"Prompt is of type {type(prompt)}")
167
149
  return len(combined_text) / 4.0
168
150
 
151
+ @retry_strategy
169
152
  async def _answer_question_and_record_task(
170
153
  self,
171
154
  *,
@@ -180,29 +163,22 @@ class InterviewTaskBuildingMixin:
180
163
  """
181
164
  from edsl.data_transfer_models import AgentResponseDict
182
165
 
183
- async def _inner():
184
- try:
185
- invigilator = self._get_invigilator(question, debug=debug)
186
-
187
- if self._skip_this_question(question):
188
- return invigilator.get_failed_task_result()
189
-
190
- response: AgentResponseDict = await self._attempt_to_answer_question(
191
- invigilator, task
192
- )
166
+ try:
167
+ invigilator = self._get_invigilator(question, debug=debug)
193
168
 
194
- self._add_answer(response=response, question=question)
169
+ if self._skip_this_question(question):
170
+ return invigilator.get_failed_task_result()
195
171
 
196
- self._cancel_skipped_questions(question)
197
- return AgentResponseDict(**response)
198
- except Exception as e:
199
- raise e
172
+ response: AgentResponseDict = await self._attempt_to_answer_question(
173
+ invigilator, task
174
+ )
200
175
 
201
- skip_rety = getattr(self, "skip_retry", False)
202
- if not skip_rety:
203
- _inner = retry_strategy(_inner)
176
+ self._add_answer(response=response, question=question)
204
177
 
205
- return await _inner()
178
+ self._cancel_skipped_questions(question)
179
+ return AgentResponseDict(**response)
180
+ except Exception as e:
181
+ raise e
206
182
 
207
183
  def _add_answer(
208
184
  self, response: "AgentResponseDict", question: "QuestionBase"
@@ -227,30 +203,38 @@ class InterviewTaskBuildingMixin:
227
203
  )
228
204
  return skip
229
205
 
230
- def _handle_exception(self, e, question_name: str, task=None):
231
- exception_entry = InterviewExceptionEntry(e)
232
- if task:
233
- task.task_status = TaskStatus.FAILED
234
- self.exceptions.add(question_name, exception_entry)
235
-
236
206
  async def _attempt_to_answer_question(
237
- self, invigilator: "InvigilatorBase", task: asyncio.Task
238
- ) -> "AgentResponseDict":
207
+ self, invigilator: InvigilatorBase, task: asyncio.Task
208
+ ) -> AgentResponseDict:
239
209
  """Attempt to answer the question, and handle exceptions.
240
210
 
241
211
  :param invigilator: the invigilator that will answer the question.
242
212
  :param task: the task that is being run.
243
-
244
213
  """
245
214
  try:
246
215
  return await asyncio.wait_for(
247
216
  invigilator.async_answer_question(), timeout=TIMEOUT
248
217
  )
249
218
  except asyncio.TimeoutError as e:
250
- self._handle_exception(e, invigilator.question.question_name, task)
219
+ exception_entry = InterviewExceptionEntry(
220
+ exception=repr(e),
221
+ time=time.time(),
222
+ traceback=traceback.format_exc(),
223
+ )
224
+ if task:
225
+ task.task_status = TaskStatus.FAILED
226
+ self.exceptions.add(invigilator.question.question_name, exception_entry)
227
+
251
228
  raise InterviewTimeoutError(f"Task timed out after {TIMEOUT} seconds.")
252
229
  except Exception as e:
253
- self._handle_exception(e, invigilator.question.question_name, task)
230
+ exception_entry = InterviewExceptionEntry(
231
+ exception=repr(e),
232
+ time=time.time(),
233
+ traceback=traceback.format_exc(),
234
+ )
235
+ if task:
236
+ task.task_status = TaskStatus.FAILED
237
+ self.exceptions.add(invigilator.question.question_name, exception_entry)
254
238
  raise e
255
239
 
256
240
  def _cancel_skipped_questions(self, current_question: QuestionBase) -> None:
@@ -1,70 +1,18 @@
1
- import traceback
2
- import datetime
3
- import time
1
+ from rich.console import Console
2
+ from rich.table import Table
4
3
  from collections import UserDict
5
4
 
6
- from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
7
5
 
8
- # #traceback=traceback.format_exc(),
9
- # #traceback = frame_summary_to_dict(traceback.extract_tb(e.__traceback__))
10
- # #traceback = [frame_summary_to_dict(f) for f in traceback.extract_tb(e.__traceback__)]
6
+ class InterviewExceptionEntry(UserDict):
7
+ """Class to record an exception that occurred during the interview."""
11
8
 
12
- # class InterviewExceptionEntry:
13
- # """Class to record an exception that occurred during the interview.
9
+ def __init__(self, exception, time, traceback):
10
+ data = {"exception": exception, "time": time, "traceback": traceback}
11
+ super().__init__(data)
14
12
 
15
- # >>> entry = InterviewExceptionEntry.example()
16
- # >>> entry.to_dict()['exception']
17
- # "ValueError('An error occurred.')"
18
- # """
19
-
20
- # def __init__(self, exception: Exception):
21
- # self.time = datetime.datetime.now().isoformat()
22
- # self.exception = exception
23
-
24
- # def __getitem__(self, key):
25
- # # Support dict-like access obj['a']
26
- # return str(getattr(self, key))
27
-
28
- # @classmethod
29
- # def example(cls):
30
- # try:
31
- # raise ValueError("An error occurred.")
32
- # except Exception as e:
33
- # entry = InterviewExceptionEntry(e)
34
- # return entry
35
-
36
- # @property
37
- # def traceback(self):
38
- # """Return the exception as HTML."""
39
- # e = self.exception
40
- # tb_str = ''.join(traceback.format_exception(type(e), e, e.__traceback__))
41
- # return tb_str
42
-
43
-
44
- # @property
45
- # def html(self):
46
- # from rich.console import Console
47
- # from rich.table import Table
48
- # from rich.traceback import Traceback
49
-
50
- # from io import StringIO
51
- # html_output = StringIO()
52
-
53
- # console = Console(file=html_output, record=True)
54
- # tb = Traceback(show_locals=True)
55
- # console.print(tb)
56
-
57
- # tb = Traceback.from_exception(type(self.exception), self.exception, self.exception.__traceback__, show_locals=True)
58
- # console.print(tb)
59
- # return html_output.getvalue()
60
-
61
- # def to_dict(self) -> dict:
62
- # """Return the exception as a dictionary."""
63
- # return {
64
- # 'exception': repr(self.exception),
65
- # 'time': self.time,
66
- # 'traceback': self.traceback
67
- # }
13
+ def to_dict(self) -> dict:
14
+ """Return the exception as a dictionary."""
15
+ return self.data
68
16
 
69
17
 
70
18
  class InterviewExceptionCollection(UserDict):
@@ -136,9 +84,3 @@ class InterviewExceptionCollection(UserDict):
136
84
  )
137
85
 
138
86
  console.print(table)
139
-
140
-
141
- if __name__ == "__main__":
142
- import doctest
143
-
144
- doctest.testmod(optionflags=doctest.ELLIPSIS)
@@ -13,40 +13,6 @@ from edsl.jobs.tasks.TaskHistory import TaskHistory
13
13
  from edsl.jobs.buckets.BucketCollection import BucketCollection
14
14
  from edsl.utilities.decorators import jupyter_nb_handler
15
15
 
16
- import time
17
- import functools
18
-
19
-
20
- def cache_with_timeout(timeout):
21
- def decorator(func):
22
- cached_result = {}
23
- last_computation_time = [0] # Using list to store mutable value
24
-
25
- @functools.wraps(func)
26
- def wrapper(*args, **kwargs):
27
- current_time = time.time()
28
- if (current_time - last_computation_time[0]) >= timeout:
29
- cached_result["value"] = func(*args, **kwargs)
30
- last_computation_time[0] = current_time
31
- return cached_result["value"]
32
-
33
- return wrapper
34
-
35
- return decorator
36
-
37
-
38
- # from queue import Queue
39
- from collections import UserList
40
-
41
-
42
- class StatusTracker(UserList):
43
- def __init__(self, total_tasks: int):
44
- self.total_tasks = total_tasks
45
- super().__init__()
46
-
47
- def current_status(self):
48
- return print(f"Completed: {len(self.data)} of {self.total_tasks}", end="\r")
49
-
50
16
 
51
17
  class JobsRunnerAsyncio(JobsRunnerStatusMixin):
52
18
  """A class for running a collection of interviews asynchronously.
@@ -77,9 +43,7 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
77
43
 
78
44
  :param n: how many times to run each interview
79
45
  :param debug:
80
- :param stop_on_exception: Whether to stop the interview if an exception is raised
81
- :param sidecar_model: a language model to use in addition to the interview's model
82
- :param total_interviews: A list of interviews to run can be provided instead.
46
+ :param stop_on_exception:
83
47
  """
84
48
  tasks = []
85
49
  if total_interviews:
@@ -123,18 +87,15 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
123
87
  ) # set the cache for the first interview
124
88
  self.total_interviews.append(interview)
125
89
 
126
- async def run_async(self, cache=None, n=1) -> Results:
90
+ async def run_async(self, cache=None) -> Results:
127
91
  from edsl.results.Results import Results
128
92
 
129
- # breakpoint()
130
- # tracker = StatusTracker(total_tasks=len(self.interviews))
131
-
132
93
  if cache is None:
133
94
  self.cache = Cache()
134
95
  else:
135
96
  self.cache = cache
136
97
  data = []
137
- async for result in self.run_async_generator(cache=self.cache, n=n):
98
+ async for result in self.run_async_generator(cache=self.cache):
138
99
  data.append(result)
139
100
  return Results(survey=self.jobs.survey, data=data)
140
101
 
@@ -212,8 +173,6 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
212
173
  raw_model_response=raw_model_results_dictionary,
213
174
  survey=interview.survey,
214
175
  )
215
- result.interview_hash = hash(interview)
216
-
217
176
  return result
218
177
 
219
178
  @property
@@ -242,86 +201,97 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
242
201
  self.sidecar_model = sidecar_model
243
202
 
244
203
  from edsl.results.Results import Results
245
- from rich.live import Live
246
- from rich.console import Console
247
-
248
- @cache_with_timeout(1)
249
- def generate_table():
250
- return self.status_table(self.results, self.elapsed_time)
251
204
 
252
- async def process_results(cache, progress_bar_context=None):
253
- """Processes results from interviews."""
254
- async for result in self.run_async_generator(
255
- n=n,
256
- debug=debug,
257
- stop_on_exception=stop_on_exception,
258
- cache=cache,
259
- sidecar_model=sidecar_model,
260
- ):
261
- self.results.append(result)
262
- if progress_bar_context:
263
- progress_bar_context.update(generate_table())
264
- self.completed = True
265
-
266
- async def update_progress_bar(progress_bar_context):
267
- """Updates the progress bar at fixed intervals."""
268
- if progress_bar_context is None:
269
- return
270
-
271
- while True:
272
- progress_bar_context.update(generate_table())
273
- await asyncio.sleep(0.1) # Update interval
274
- if self.completed:
275
- break
276
-
277
- @contextmanager
278
- def conditional_context(condition, context_manager):
279
- if condition:
280
- with context_manager as cm:
281
- yield cm
282
- else:
283
- yield
284
-
285
- with conditional_context(
286
- progress_bar, Live(generate_table(), console=console, refresh_per_second=1)
287
- ) as progress_bar_context:
205
+ if not progress_bar:
206
+ # print("Running without progress bar")
288
207
  with cache as c:
289
- progress_task = asyncio.create_task(
290
- update_progress_bar(progress_bar_context)
291
- )
292
208
 
293
- try:
294
- await asyncio.gather(
295
- progress_task,
296
- process_results(
297
- cache=c, progress_bar_context=progress_bar_context
298
- ),
299
- )
300
- except asyncio.CancelledError:
301
- pass
302
- finally:
303
- progress_task.cancel() # Cancel the progress_task when process_results is done
304
- await progress_task
209
+ async def process_results():
210
+ """Processes results from interviews."""
211
+ async for result in self.run_async_generator(
212
+ n=n,
213
+ debug=debug,
214
+ stop_on_exception=stop_on_exception,
215
+ cache=c,
216
+ sidecar_model=sidecar_model,
217
+ ):
218
+ self.results.append(result)
219
+ self.completed = True
220
+
221
+ await asyncio.gather(process_results())
222
+
223
+ results = Results(survey=self.jobs.survey, data=self.results)
224
+ else:
225
+ # print("Running with progress bar")
226
+ from rich.live import Live
227
+ from rich.console import Console
305
228
 
306
- await asyncio.sleep(1) # short delay to show the final status
229
+ def generate_table():
230
+ return self.status_table(self.results, self.elapsed_time)
307
231
 
308
- if progress_bar_context:
309
- progress_bar_context.update(generate_table())
232
+ @contextmanager
233
+ def no_op_cm():
234
+ """A no-op context manager with a dummy update method."""
235
+ yield DummyLive()
310
236
 
311
- # puts results in the same order as the total interviews
312
- interview_hashes = [hash(interview) for interview in self.total_interviews]
313
- self.results = sorted(
314
- self.results, key=lambda x: interview_hashes.index(x.interview_hash)
315
- )
237
+ class DummyLive:
238
+ def update(self, *args, **kwargs):
239
+ """A dummy update method that does nothing."""
240
+ pass
241
+
242
+ progress_bar_context = (
243
+ Live(generate_table(), console=console, refresh_per_second=5)
244
+ if progress_bar
245
+ else no_op_cm()
246
+ )
247
+
248
+ with cache as c:
249
+ with progress_bar_context as live:
250
+
251
+ async def update_progress_bar():
252
+ """Updates the progress bar at fixed intervals."""
253
+ while True:
254
+ live.update(generate_table())
255
+ await asyncio.sleep(0.00001) # Update interval
256
+ if self.completed:
257
+ break
258
+
259
+ async def process_results():
260
+ """Processes results from interviews."""
261
+ async for result in self.run_async_generator(
262
+ n=n,
263
+ debug=debug,
264
+ stop_on_exception=stop_on_exception,
265
+ cache=c,
266
+ sidecar_model=sidecar_model,
267
+ ):
268
+ self.results.append(result)
269
+ live.update(generate_table())
270
+ self.completed = True
271
+
272
+ progress_task = asyncio.create_task(update_progress_bar())
273
+
274
+ try:
275
+ await asyncio.gather(process_results(), progress_task)
276
+ except asyncio.CancelledError:
277
+ pass
278
+ finally:
279
+ progress_task.cancel() # Cancel the progress_task when process_results is done
280
+ await progress_task
281
+
282
+ await asyncio.sleep(1) # short delay to show the final status
283
+
284
+ # one more update
285
+ live.update(generate_table())
286
+
287
+ results = Results(survey=self.jobs.survey, data=self.results)
316
288
 
317
- results = Results(survey=self.jobs.survey, data=self.results)
318
289
  task_history = TaskHistory(self.total_interviews, include_traceback=False)
319
290
  results.task_history = task_history
320
291
 
321
292
  results.has_exceptions = task_history.has_exceptions
322
293
 
323
294
  if results.has_exceptions:
324
- # put the failed interviews in the results object as a list
325
295
  failed_interviews = [
326
296
  interview.duplicate(
327
297
  iteration=interview.iteration, cache=interview.cache
@@ -342,7 +312,6 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
342
312
 
343
313
  shared_globals["edsl_runner_exceptions"] = task_history
344
314
  print(msg)
345
- # this is where exceptions are opening up
346
315
  task_history.html(cta="Open report to see details.")
347
316
  print(
348
317
  "Also see: https://docs.expectedparrot.com/en/latest/exceptions.html"