edsl 0.1.33.dev2__py3-none-any.whl → 0.1.33.dev3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. edsl/Base.py +9 -3
  2. edsl/__init__.py +1 -0
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +6 -6
  5. edsl/agents/Invigilator.py +6 -3
  6. edsl/agents/InvigilatorBase.py +8 -27
  7. edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +101 -29
  8. edsl/config.py +26 -34
  9. edsl/coop/coop.py +11 -2
  10. edsl/data_transfer_models.py +27 -73
  11. edsl/enums.py +2 -0
  12. edsl/inference_services/GoogleService.py +1 -1
  13. edsl/inference_services/InferenceServiceABC.py +44 -13
  14. edsl/inference_services/OpenAIService.py +7 -4
  15. edsl/inference_services/TestService.py +24 -15
  16. edsl/inference_services/TogetherAIService.py +170 -0
  17. edsl/inference_services/registry.py +2 -0
  18. edsl/jobs/Jobs.py +18 -8
  19. edsl/jobs/buckets/BucketCollection.py +24 -15
  20. edsl/jobs/buckets/TokenBucket.py +64 -10
  21. edsl/jobs/interviews/Interview.py +115 -47
  22. edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +16 -0
  23. edsl/jobs/interviews/InterviewExceptionEntry.py +2 -0
  24. edsl/jobs/runners/JobsRunnerAsyncio.py +86 -161
  25. edsl/jobs/runners/JobsRunnerStatus.py +331 -0
  26. edsl/jobs/tasks/TaskHistory.py +17 -0
  27. edsl/language_models/LanguageModel.py +26 -31
  28. edsl/language_models/registry.py +13 -9
  29. edsl/questions/QuestionBase.py +64 -16
  30. edsl/questions/QuestionBudget.py +93 -41
  31. edsl/questions/QuestionFreeText.py +6 -0
  32. edsl/questions/QuestionMultipleChoice.py +11 -26
  33. edsl/questions/QuestionNumerical.py +5 -4
  34. edsl/questions/Quick.py +41 -0
  35. edsl/questions/ResponseValidatorABC.py +6 -5
  36. edsl/questions/derived/QuestionLinearScale.py +4 -1
  37. edsl/questions/derived/QuestionTopK.py +4 -1
  38. edsl/questions/derived/QuestionYesNo.py +8 -2
  39. edsl/questions/templates/budget/__init__.py +0 -0
  40. edsl/questions/templates/budget/answering_instructions.jinja +7 -0
  41. edsl/questions/templates/budget/question_presentation.jinja +7 -0
  42. edsl/questions/templates/extract/__init__.py +0 -0
  43. edsl/questions/templates/rank/__init__.py +0 -0
  44. edsl/results/DatasetExportMixin.py +5 -1
  45. edsl/results/Result.py +1 -1
  46. edsl/results/Results.py +4 -1
  47. edsl/scenarios/FileStore.py +71 -10
  48. edsl/scenarios/Scenario.py +86 -21
  49. edsl/scenarios/ScenarioImageMixin.py +2 -2
  50. edsl/scenarios/ScenarioList.py +13 -0
  51. edsl/scenarios/ScenarioListPdfMixin.py +150 -4
  52. edsl/study/Study.py +32 -0
  53. edsl/surveys/Rule.py +10 -1
  54. edsl/surveys/RuleCollection.py +19 -3
  55. edsl/surveys/Survey.py +7 -0
  56. edsl/templates/error_reporting/interview_details.html +6 -1
  57. edsl/utilities/utilities.py +9 -1
  58. {edsl-0.1.33.dev2.dist-info → edsl-0.1.33.dev3.dist-info}/METADATA +2 -1
  59. {edsl-0.1.33.dev2.dist-info → edsl-0.1.33.dev3.dist-info}/RECORD +61 -55
  60. edsl/jobs/interviews/retry_management.py +0 -39
  61. edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -333
  62. {edsl-0.1.33.dev2.dist-info → edsl-0.1.33.dev3.dist-info}/LICENSE +0 -0
  63. {edsl-0.1.33.dev2.dist-info → edsl-0.1.33.dev3.dist-info}/WHEEL +0 -0
@@ -4,10 +4,18 @@ from __future__ import annotations
4
4
  import asyncio
5
5
  from typing import Any, Type, List, Generator, Optional, Union
6
6
 
7
+ from tenacity import (
8
+ retry,
9
+ stop_after_attempt,
10
+ wait_exponential,
11
+ retry_if_exception_type,
12
+ RetryError,
13
+ )
14
+
7
15
  from edsl import CONFIG
8
16
  from edsl.surveys.base import EndOfSurvey
9
17
  from edsl.exceptions import QuestionAnswerValidationError
10
- from edsl.exceptions import InterviewTimeoutError
18
+ from edsl.exceptions import QuestionAnswerValidationError
11
19
  from edsl.data_transfer_models import AgentResponseDict, EDSLResultObjectInput
12
20
 
13
21
  from edsl.jobs.buckets.ModelBuckets import ModelBuckets
@@ -15,21 +23,18 @@ from edsl.jobs.Answers import Answers
15
23
  from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
16
24
  from edsl.jobs.tasks.TaskCreators import TaskCreators
17
25
  from edsl.jobs.interviews.InterviewStatusLog import InterviewStatusLog
18
- from edsl.jobs.interviews.interview_exception_tracking import (
26
+ from edsl.jobs.interviews.InterviewExceptionCollection import (
19
27
  InterviewExceptionCollection,
20
28
  )
21
- from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
22
- from edsl.jobs.interviews.retry_management import retry_strategy
29
+
23
30
  from edsl.jobs.interviews.InterviewStatusMixin import InterviewStatusMixin
24
31
 
25
32
  from edsl.surveys.base import EndOfSurvey
26
33
  from edsl.jobs.buckets.ModelBuckets import ModelBuckets
27
34
  from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
28
- from edsl.jobs.interviews.retry_management import retry_strategy
29
35
  from edsl.jobs.tasks.task_status_enum import TaskStatus
30
36
  from edsl.jobs.tasks.QuestionTaskCreator import QuestionTaskCreator
31
37
 
32
- from edsl.exceptions import QuestionAnswerValidationError
33
38
 
34
39
  from edsl import Agent, Survey, Scenario, Cache
35
40
  from edsl.language_models import LanguageModel
@@ -39,8 +44,11 @@ from edsl.agents.InvigilatorBase import InvigilatorBase
39
44
  from edsl.exceptions.language_models import LanguageModelNoResponseError
40
45
 
41
46
 
42
- class RetryableLanguageModelNoResponseError(LanguageModelNoResponseError):
43
- pass
47
+ from edsl import CONFIG
48
+
49
+ EDSL_BACKOFF_START_SEC = float(CONFIG.get("EDSL_BACKOFF_START_SEC"))
50
+ EDSL_BACKOFF_MAX_SEC = float(CONFIG.get("EDSL_BACKOFF_MAX_SEC"))
51
+ EDSL_MAX_ATTEMPTS = int(CONFIG.get("EDSL_MAX_ATTEMPTS"))
44
52
 
45
53
 
46
54
  class Interview(InterviewStatusMixin):
@@ -97,14 +105,17 @@ class Interview(InterviewStatusMixin):
97
105
  self.debug = debug
98
106
  self.iteration = iteration
99
107
  self.cache = cache
100
- self.answers: dict[str, str] = (
101
- Answers()
102
- ) # will get filled in as interview progresses
108
+ self.answers: dict[
109
+ str, str
110
+ ] = Answers() # will get filled in as interview progresses
103
111
  self.sidecar_model = sidecar_model
104
112
 
113
+ # self.stop_on_exception = False
114
+
105
115
  # Trackers
106
116
  self.task_creators = TaskCreators() # tracks the task creators
107
117
  self.exceptions = InterviewExceptionCollection()
118
+
108
119
  self._task_status_log_dict = InterviewStatusLog()
109
120
  self.skip_retry = skip_retry
110
121
  self.raise_validation_errors = raise_validation_errors
@@ -257,44 +268,80 @@ class Interview(InterviewStatusMixin):
257
268
  ) -> "AgentResponseDict":
258
269
  """Answer a question and records the task."""
259
270
 
260
- invigilator = self._get_invigilator(question)
271
+ had_language_model_no_response_error = False
261
272
 
262
- if self._skip_this_question(question):
263
- response = invigilator.get_failed_task_result(
264
- failure_reason="Question skipped."
265
- )
273
+ @retry(
274
+ stop=stop_after_attempt(EDSL_MAX_ATTEMPTS),
275
+ wait=wait_exponential(
276
+ multiplier=EDSL_BACKOFF_START_SEC, max=EDSL_BACKOFF_MAX_SEC
277
+ ),
278
+ retry=retry_if_exception_type(LanguageModelNoResponseError),
279
+ reraise=True,
280
+ )
281
+ async def attempt_answer():
282
+ nonlocal had_language_model_no_response_error
266
283
 
267
- try:
268
- response: EDSLResultObjectInput = await invigilator.async_answer_question()
269
- if response.validated:
270
- self.answers.add_answer(response=response, question=question)
271
- self._cancel_skipped_questions(question)
272
- else:
273
- if (
274
- hasattr(response, "exception_occurred")
275
- and response.exception_occurred
276
- ):
277
- raise response.exception_occurred
284
+ invigilator = self._get_invigilator(question)
278
285
 
279
- except QuestionAnswerValidationError as e:
280
- # there's a response, but it couldn't be validated
281
- self._handle_exception(e, invigilator, task)
286
+ if self._skip_this_question(question):
287
+ return invigilator.get_failed_task_result(
288
+ failure_reason="Question skipped."
289
+ )
282
290
 
283
- except asyncio.TimeoutError as e:
284
- # the API timed-out - this is recorded but as a response isn't generated, the LanguageModelNoResponseError will also be raised
285
- self._handle_exception(e, invigilator, task)
291
+ try:
292
+ response: EDSLResultObjectInput = (
293
+ await invigilator.async_answer_question()
294
+ )
295
+ if response.validated:
296
+ self.answers.add_answer(response=response, question=question)
297
+ self._cancel_skipped_questions(question)
298
+ else:
299
+ if (
300
+ hasattr(response, "exception_occurred")
301
+ and response.exception_occurred
302
+ ):
303
+ raise response.exception_occurred
304
+
305
+ except QuestionAnswerValidationError as e:
306
+ self._handle_exception(e, invigilator, task)
307
+ return invigilator.get_failed_task_result(
308
+ failure_reason="Question answer validation failed."
309
+ )
286
310
 
287
- except Exception as e:
288
- # there was some other exception
289
- self._handle_exception(e, invigilator, task)
311
+ except asyncio.TimeoutError as e:
312
+ self._handle_exception(e, invigilator, task)
313
+ had_language_model_no_response_error = True
314
+ raise LanguageModelNoResponseError(
315
+ f"Language model timed out for question '{question.question_name}.'"
316
+ )
290
317
 
291
- if "response" not in locals():
318
+ except Exception as e:
319
+ self._handle_exception(e, invigilator, task)
292
320
 
293
- raise LanguageModelNoResponseError(
294
- f"Language model did not return a response for question '{question.question_name}.'"
295
- )
321
+ if "response" not in locals():
322
+ had_language_model_no_response_error = True
323
+ raise LanguageModelNoResponseError(
324
+ f"Language model did not return a response for question '{question.question_name}.'"
325
+ )
326
+
327
+ # if it gets here, it means the no response error was fixed
328
+ if (
329
+ question.question_name in self.exceptions
330
+ and had_language_model_no_response_error
331
+ ):
332
+ self.exceptions.record_fixed_question(question.question_name)
296
333
 
297
- return response
334
+ return response
335
+
336
+ try:
337
+ return await attempt_answer()
338
+ except RetryError as retry_error:
339
+ # All retries have failed for LanguageModelNoResponseError
340
+ original_error = retry_error.last_attempt.exception()
341
+ self._handle_exception(
342
+ original_error, self._get_invigilator(question), task
343
+ )
344
+ raise original_error # Re-raise the original error after handling
298
345
 
299
346
  def _get_invigilator(self, question: QuestionBase) -> InvigilatorBase:
300
347
  """Return an invigilator for the given question.
@@ -334,14 +381,32 @@ class Interview(InterviewStatusMixin):
334
381
  def _handle_exception(
335
382
  self, e: Exception, invigilator: "InvigilatorBase", task=None
336
383
  ):
384
+ import copy
385
+
386
+ # breakpoint()
387
+
388
+ answers = copy.copy(self.answers)
337
389
  exception_entry = InterviewExceptionEntry(
338
390
  exception=e,
339
391
  invigilator=invigilator,
392
+ answers=answers,
340
393
  )
341
394
  if task:
342
395
  task.task_status = TaskStatus.FAILED
343
396
  self.exceptions.add(invigilator.question.question_name, exception_entry)
344
397
 
398
+ if self.raise_validation_errors:
399
+ if isinstance(e, QuestionAnswerValidationError):
400
+ raise e
401
+
402
+ if hasattr(self, "stop_on_exception"):
403
+ stop_on_exception = self.stop_on_exception
404
+ else:
405
+ stop_on_exception = False
406
+
407
+ if stop_on_exception:
408
+ raise e
409
+
345
410
  def _cancel_skipped_questions(self, current_question: QuestionBase) -> None:
346
411
  """Cancel the tasks for questions that are skipped.
347
412
 
@@ -353,11 +418,11 @@ class Interview(InterviewStatusMixin):
353
418
  """
354
419
  current_question_index: int = self.to_index[current_question.question_name]
355
420
 
356
- next_question: Union[int, EndOfSurvey] = (
357
- self.survey.rule_collection.next_question(
358
- q_now=current_question_index,
359
- answers=self.answers | self.scenario | self.agent["traits"],
360
- )
421
+ next_question: Union[
422
+ int, EndOfSurvey
423
+ ] = self.survey.rule_collection.next_question(
424
+ q_now=current_question_index,
425
+ answers=self.answers | self.scenario | self.agent["traits"],
361
426
  )
362
427
 
363
428
  next_question_index = next_question.next_q
@@ -411,6 +476,7 @@ class Interview(InterviewStatusMixin):
411
476
  asyncio.exceptions.CancelledError
412
477
  """
413
478
  self.sidecar_model = sidecar_model
479
+ self.stop_on_exception = stop_on_exception
414
480
 
415
481
  # if no model bucket is passed, create an 'infinity' bucket with no rate limits
416
482
  if model_buckets is None or hasattr(self.agent, "answer_question_directly"):
@@ -424,7 +490,9 @@ class Interview(InterviewStatusMixin):
424
490
  self.invigilators = [
425
491
  self._get_invigilator(question) for question in self.survey.questions
426
492
  ]
427
- await asyncio.gather(*self.tasks, return_exceptions=not stop_on_exception)
493
+ await asyncio.gather(
494
+ *self.tasks, return_exceptions=not stop_on_exception
495
+ ) # not stop_on_exception)
428
496
  self.answers.replace_missing_answers_with_none(self.survey)
429
497
  valid_results = list(self._extract_valid_results())
430
498
  return self.answers, valid_results
@@ -6,6 +6,22 @@ from edsl.jobs.interviews.InterviewExceptionEntry import InterviewExceptionEntry
6
6
  class InterviewExceptionCollection(UserDict):
7
7
  """A collection of exceptions that occurred during the interview."""
8
8
 
9
+ def __init__(self):
10
+ super().__init__()
11
+ self.fixed = set()
12
+
13
+ def unfixed_exceptions(self) -> dict:
14
+ """Return a dict of unfixed exceptions, keyed by question name."""
15
+ return {k: v for k, v in self.data.items() if k not in self.fixed}
16
+
17
+ def num_unfixed(self) -> int:
18
+ """Return the number of unfixed exceptions."""
19
+ return len([k for k in self.data.keys() if k not in self.fixed])
20
+
21
+ def record_fixed_question(self, question_name: str) -> None:
22
+ """Record that a question has been fixed."""
23
+ self.fixed.add(question_name)
24
+
9
25
  def add(self, question_name: str, entry: InterviewExceptionEntry) -> None:
10
26
  """Add an exception entry to the collection."""
11
27
  question_name = question_name
@@ -15,12 +15,14 @@ class InterviewExceptionEntry:
15
15
  # failed_question: FailedQuestion,
16
16
  invigilator: "Invigilator",
17
17
  traceback_format="text",
18
+ answers=None,
18
19
  ):
19
20
  self.time = datetime.datetime.now().isoformat()
20
21
  self.exception = exception
21
22
  # self.failed_question = failed_question
22
23
  self.invigilator = invigilator
23
24
  self.traceback_format = traceback_format
25
+ self.answers = answers
24
26
 
25
27
  @property
26
28
  def question_type(self):
@@ -3,40 +3,25 @@ import time
3
3
  import math
4
4
  import asyncio
5
5
  import functools
6
+ import threading
6
7
  from typing import Coroutine, List, AsyncGenerator, Optional, Union, Generator
7
8
  from contextlib import contextmanager
8
9
  from collections import UserList
9
10
 
11
+ from edsl.results.Results import Results
12
+ from rich.live import Live
13
+ from rich.console import Console
14
+
10
15
  from edsl import shared_globals
11
16
  from edsl.jobs.interviews.Interview import Interview
12
- from edsl.jobs.runners.JobsRunnerStatusMixin import JobsRunnerStatusMixin
17
+ from edsl.jobs.runners.JobsRunnerStatus import JobsRunnerStatus
18
+
13
19
  from edsl.jobs.tasks.TaskHistory import TaskHistory
14
20
  from edsl.jobs.buckets.BucketCollection import BucketCollection
15
21
  from edsl.utilities.decorators import jupyter_nb_handler
16
22
  from edsl.data.Cache import Cache
17
23
  from edsl.results.Result import Result
18
24
  from edsl.results.Results import Results
19
- from edsl.jobs.FailedQuestion import FailedQuestion
20
-
21
-
22
- def cache_with_timeout(timeout):
23
- """ "Used to keep the generate table from being run too frequetly."""
24
-
25
- def decorator(func):
26
- cached_result = {}
27
- last_computation_time = [0] # Using list to store mutable value
28
-
29
- @functools.wraps(func)
30
- def wrapper(*args, **kwargs):
31
- current_time = time.time()
32
- if (current_time - last_computation_time[0]) >= timeout:
33
- cached_result["value"] = func(*args, **kwargs)
34
- last_computation_time[0] = current_time
35
- return cached_result["value"]
36
-
37
- return wrapper
38
-
39
- return decorator
40
25
 
41
26
 
42
27
  class StatusTracker(UserList):
@@ -48,7 +33,7 @@ class StatusTracker(UserList):
48
33
  return print(f"Completed: {len(self.data)} of {self.total_tasks}", end="\r")
49
34
 
50
35
 
51
- class JobsRunnerAsyncio(JobsRunnerStatusMixin):
36
+ class JobsRunnerAsyncio:
52
37
  """A class for running a collection of interviews asynchronously.
53
38
 
54
39
  It gets instantiated from a Jobs object.
@@ -57,11 +42,12 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
57
42
 
58
43
  def __init__(self, jobs: "Jobs"):
59
44
  self.jobs = jobs
60
- # this creates the interviews, which can take a while
61
45
  self.interviews: List["Interview"] = jobs.interviews()
62
46
  self.bucket_collection: "BucketCollection" = jobs.bucket_collection
63
47
  self.total_interviews: List["Interview"] = []
64
48
 
49
+ # self.jobs_runner_status = JobsRunnerStatus(self, n=1)
50
+
65
51
  async def run_async_generator(
66
52
  self,
67
53
  cache: "Cache",
@@ -79,6 +65,7 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
79
65
  :param stop_on_exception: Whether to stop the interview if an exception is raised
80
66
  :param sidecar_model: a language model to use in addition to the interview's model
81
67
  :param total_interviews: A list of interviews to run can be provided instead.
68
+ :param raise_validation_errors: Whether to raise validation errors
82
69
  """
83
70
  tasks = []
84
71
  if total_interviews: # was already passed in total interviews
@@ -88,8 +75,6 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
88
75
  self._populate_total_interviews(n=n)
89
76
  ) # Populate self.total_interviews before creating tasks
90
77
 
91
- # print("Interviews created")
92
-
93
78
  for interview in self.total_interviews:
94
79
  interviewing_task = self._build_interview_task(
95
80
  interview=interview,
@@ -99,11 +84,9 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
99
84
  )
100
85
  tasks.append(asyncio.create_task(interviewing_task))
101
86
 
102
- # print("Tasks created")
103
-
104
87
  for task in asyncio.as_completed(tasks):
105
- # print(f"Task {task} completed")
106
88
  result = await task
89
+ self.jobs_runner_status.add_completed_interview(result)
107
90
  yield result
108
91
 
109
92
  def _populate_total_interviews(
@@ -122,6 +105,8 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
122
105
  yield interview
123
106
 
124
107
  async def run_async(self, cache: Optional["Cache"] = None, n: int = 1) -> Results:
108
+ """Used for some other modules that have a non-standard way of running interviews."""
109
+ self.jobs_runner_status = JobsRunnerStatus(self, n=n)
125
110
  self.cache = Cache() if cache is None else cache
126
111
  data = []
127
112
  async for result in self.run_async_generator(cache=self.cache, n=n):
@@ -157,12 +142,6 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
157
142
  raise_validation_errors=raise_validation_errors,
158
143
  )
159
144
 
160
- # answer_key_names = {
161
- # k
162
- # for k in set(answer.keys())
163
- # if not k.endswith("_comment") and not k.endswith("_generated_tokens")
164
- # }
165
-
166
145
  question_results = {}
167
146
  for result in valid_results:
168
147
  question_results[result.question_name] = result
@@ -174,24 +153,13 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
174
153
  for k in answer_key_names
175
154
  }
176
155
  comments_dict = {
177
- "k" + "_comment": question_results[k].comment for k in answer_key_names
156
+ k + "_comment": question_results[k].comment for k in answer_key_names
178
157
  }
179
158
 
180
159
  # we should have a valid result for each question
181
160
  answer_dict = {k: answer[k] for k in answer_key_names}
182
161
  assert len(valid_results) == len(answer_key_names)
183
162
 
184
- # breakpoint()
185
- # generated_tokens_dict = {
186
- # k + "_generated_tokens": v.generated_tokens
187
- # for k, v in zip(answer_key_names, valid_results)
188
- # }
189
-
190
- # comments_dict = {
191
- # k + "_comment": v.comment for k, v in zip(answer_key_names, valid_results)
192
- # }
193
- # breakpoint()
194
-
195
163
  # TODO: move this down into Interview
196
164
  question_name_to_prompts = dict({})
197
165
  for result in valid_results:
@@ -226,7 +194,6 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
226
194
  )
227
195
  raw_model_results_dictionary[question_name + "_one_usd_buys"] = one_use_buys
228
196
 
229
- # breakpoint()
230
197
  result = Result(
231
198
  agent=interview.agent,
232
199
  scenario=interview.scenario,
@@ -247,6 +214,62 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
247
214
  def elapsed_time(self):
248
215
  return time.monotonic() - self.start_time
249
216
 
217
+ def process_results(
218
+ self, raw_results: Results, cache: Cache, print_exceptions: bool
219
+ ):
220
+ interview_lookup = {
221
+ hash(interview): index
222
+ for index, interview in enumerate(self.total_interviews)
223
+ }
224
+ interview_hashes = list(interview_lookup.keys())
225
+
226
+ results = Results(
227
+ survey=self.jobs.survey,
228
+ data=sorted(
229
+ raw_results, key=lambda x: interview_hashes.index(x.interview_hash)
230
+ ),
231
+ )
232
+ results.cache = cache
233
+ results.task_history = TaskHistory(
234
+ self.total_interviews, include_traceback=False
235
+ )
236
+ results.has_unfixed_exceptions = results.task_history.has_unfixed_exceptions
237
+ results.bucket_collection = self.bucket_collection
238
+
239
+ if results.has_unfixed_exceptions and print_exceptions:
240
+ from edsl.scenarios.FileStore import HTMLFileStore
241
+ from edsl.config import CONFIG
242
+ from edsl.coop.coop import Coop
243
+
244
+ msg = f"Exceptions were raised in {len(results.task_history.indices)} out of {len(self.total_interviews)} interviews.\n"
245
+
246
+ if len(results.task_history.indices) > 5:
247
+ msg += f"Exceptions were raised in the following interviews: {results.task_history.indices}.\n"
248
+
249
+ print(msg)
250
+ # this is where exceptions are opening up
251
+ filepath = results.task_history.html(
252
+ cta="Open report to see details.",
253
+ open_in_browser=True,
254
+ return_link=True,
255
+ )
256
+
257
+ try:
258
+ coop = Coop()
259
+ user_edsl_settings = coop.edsl_settings
260
+ remote_logging = user_edsl_settings["remote_logging"]
261
+ except Exception as e:
262
+ print(e)
263
+ remote_logging = False
264
+ if remote_logging:
265
+ filestore = HTMLFileStore(filepath)
266
+ coop_details = filestore.push(description="Error report")
267
+ print(coop_details)
268
+
269
+ print("Also see: https://docs.expectedparrot.com/en/latest/exceptions.html")
270
+
271
+ return results
272
+
250
273
  @jupyter_nb_handler
251
274
  async def run(
252
275
  self,
@@ -259,24 +282,16 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
259
282
  raise_validation_errors: bool = False,
260
283
  ) -> "Coroutine":
261
284
  """Runs a collection of interviews, handling both async and sync contexts."""
262
- from rich.console import Console
263
285
 
264
- console = Console()
265
286
  self.results = []
266
287
  self.start_time = time.monotonic()
267
288
  self.completed = False
268
289
  self.cache = cache
269
290
  self.sidecar_model = sidecar_model
270
291
 
271
- from edsl.results.Results import Results
272
- from rich.live import Live
273
- from rich.console import Console
274
-
275
- @cache_with_timeout(1)
276
- def generate_table():
277
- return self.status_table(self.results, self.elapsed_time)
292
+ self.jobs_runner_status = JobsRunnerStatus(self, n=n)
278
293
 
279
- async def process_results(cache, progress_bar_context=None):
294
+ async def process_results(cache):
280
295
  """Processes results from interviews."""
281
296
  async for result in self.run_async_generator(
282
297
  n=n,
@@ -286,112 +301,22 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
286
301
  raise_validation_errors=raise_validation_errors,
287
302
  ):
288
303
  self.results.append(result)
289
- if progress_bar_context:
290
- progress_bar_context.update(generate_table())
291
304
  self.completed = True
292
305
 
293
- async def update_progress_bar(progress_bar_context):
294
- """Updates the progress bar at fixed intervals."""
295
- if progress_bar_context is None:
296
- return
297
-
298
- while True:
299
- progress_bar_context.update(generate_table())
300
- await asyncio.sleep(0.1) # Update interval
301
- if self.completed:
302
- break
303
-
304
- @contextmanager
305
- def conditional_context(condition, context_manager):
306
- if condition:
307
- with context_manager as cm:
308
- yield cm
309
- else:
310
- yield
311
-
312
- with conditional_context(
313
- progress_bar, Live(generate_table(), console=console, refresh_per_second=1)
314
- ) as progress_bar_context:
315
- with cache as c:
316
- progress_task = asyncio.create_task(
317
- update_progress_bar(progress_bar_context)
318
- )
319
-
320
- try:
321
- await asyncio.gather(
322
- progress_task,
323
- process_results(
324
- cache=c, progress_bar_context=progress_bar_context
325
- ),
326
- )
327
- except asyncio.CancelledError:
328
- pass
329
- finally:
330
- progress_task.cancel() # Cancel the progress_task when process_results is done
331
- await progress_task
332
-
333
- await asyncio.sleep(1) # short delay to show the final status
334
-
335
- if progress_bar_context:
336
- progress_bar_context.update(generate_table())
337
-
338
- # puts results in the same order as the total interviews
339
- interview_lookup = {
340
- hash(interview): index
341
- for index, interview in enumerate(self.total_interviews)
342
- }
343
- interview_hashes = list(interview_lookup.keys())
344
- self.results = sorted(
345
- self.results, key=lambda x: interview_hashes.index(x.interview_hash)
346
- )
347
-
348
- results = Results(survey=self.jobs.survey, data=self.results)
349
- task_history = TaskHistory(self.total_interviews, include_traceback=False)
350
- results.task_history = task_history
351
-
352
- results.failed_questions = {}
353
- results.has_exceptions = task_history.has_exceptions
354
-
355
- # breakpoint()
356
- results.bucket_collection = self.bucket_collection
357
-
358
- if results.has_exceptions:
359
- # put the failed interviews in the results object as a list
360
- failed_interviews = [
361
- interview.duplicate(
362
- iteration=interview.iteration, cache=interview.cache
363
- )
364
- for interview in self.total_interviews
365
- if interview.has_exceptions
366
- ]
367
-
368
- failed_questions = {}
369
- for interview in self.total_interviews:
370
- if interview.has_exceptions:
371
- index = interview_lookup[hash(interview)]
372
- failed_questions[index] = interview.failed_questions
306
+ def run_progress_bar():
307
+ """Runs the progress bar in a separate thread."""
308
+ self.jobs_runner_status.update_progress()
373
309
 
374
- results.failed_questions = failed_questions
310
+ if progress_bar:
311
+ progress_thread = threading.Thread(target=run_progress_bar)
312
+ progress_thread.start()
375
313
 
376
- from edsl.jobs.Jobs import Jobs
314
+ with cache as c:
315
+ await process_results(cache=c)
377
316
 
378
- results.failed_jobs = Jobs.from_interviews(
379
- [interview for interview in failed_interviews]
380
- )
381
- if print_exceptions:
382
- msg = f"Exceptions were raised in {len(results.task_history.indices)} out of {len(self.total_interviews)} interviews.\n"
383
-
384
- if len(results.task_history.indices) > 5:
385
- msg += f"Exceptions were raised in the following interviews: {results.task_history.indices}.\n"
386
-
387
- shared_globals["edsl_runner_exceptions"] = task_history
388
- print(msg)
389
- # this is where exceptions are opening up
390
- task_history.html(
391
- cta="Open report to see details.", open_in_browser=True
392
- )
393
- print(
394
- "Also see: https://docs.expectedparrot.com/en/latest/exceptions.html"
395
- )
317
+ if progress_bar:
318
+ progress_thread.join()
396
319
 
397
- return results
320
+ return self.process_results(
321
+ raw_results=self.results, cache=cache, print_exceptions=print_exceptions
322
+ )