edsl 0.1.33__py3-none-any.whl → 0.1.33.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. edsl/Base.py +3 -9
  2. edsl/__init__.py +0 -1
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +6 -6
  5. edsl/agents/Invigilator.py +3 -6
  6. edsl/agents/InvigilatorBase.py +27 -8
  7. edsl/agents/{PromptConstructor.py → PromptConstructionMixin.py} +29 -101
  8. edsl/config.py +34 -26
  9. edsl/coop/coop.py +2 -11
  10. edsl/data_transfer_models.py +73 -26
  11. edsl/enums.py +0 -2
  12. edsl/inference_services/GoogleService.py +1 -1
  13. edsl/inference_services/InferenceServiceABC.py +13 -44
  14. edsl/inference_services/OpenAIService.py +4 -7
  15. edsl/inference_services/TestService.py +15 -24
  16. edsl/inference_services/registry.py +0 -2
  17. edsl/jobs/Jobs.py +8 -18
  18. edsl/jobs/buckets/BucketCollection.py +15 -24
  19. edsl/jobs/buckets/TokenBucket.py +10 -64
  20. edsl/jobs/interviews/Interview.py +47 -115
  21. edsl/jobs/interviews/InterviewExceptionEntry.py +0 -2
  22. edsl/jobs/interviews/{InterviewExceptionCollection.py → interview_exception_tracking.py} +0 -16
  23. edsl/jobs/interviews/retry_management.py +39 -0
  24. edsl/jobs/runners/JobsRunnerAsyncio.py +170 -95
  25. edsl/jobs/runners/JobsRunnerStatusMixin.py +333 -0
  26. edsl/jobs/tasks/TaskHistory.py +0 -17
  27. edsl/language_models/LanguageModel.py +31 -26
  28. edsl/language_models/registry.py +9 -13
  29. edsl/questions/QuestionBase.py +14 -63
  30. edsl/questions/QuestionBudget.py +41 -93
  31. edsl/questions/QuestionFreeText.py +0 -6
  32. edsl/questions/QuestionMultipleChoice.py +23 -8
  33. edsl/questions/QuestionNumerical.py +4 -5
  34. edsl/questions/ResponseValidatorABC.py +5 -6
  35. edsl/questions/derived/QuestionLinearScale.py +1 -4
  36. edsl/questions/derived/QuestionTopK.py +1 -4
  37. edsl/questions/derived/QuestionYesNo.py +2 -8
  38. edsl/results/DatasetExportMixin.py +1 -5
  39. edsl/results/Result.py +1 -1
  40. edsl/results/Results.py +1 -4
  41. edsl/scenarios/FileStore.py +10 -71
  42. edsl/scenarios/Scenario.py +21 -86
  43. edsl/scenarios/ScenarioImageMixin.py +2 -2
  44. edsl/scenarios/ScenarioList.py +0 -13
  45. edsl/scenarios/ScenarioListPdfMixin.py +4 -150
  46. edsl/study/Study.py +0 -32
  47. edsl/surveys/Rule.py +1 -10
  48. edsl/surveys/RuleCollection.py +3 -19
  49. edsl/surveys/Survey.py +0 -7
  50. edsl/templates/error_reporting/interview_details.html +1 -6
  51. edsl/utilities/utilities.py +1 -9
  52. {edsl-0.1.33.dist-info → edsl-0.1.33.dev2.dist-info}/METADATA +1 -2
  53. {edsl-0.1.33.dist-info → edsl-0.1.33.dev2.dist-info}/RECORD +55 -61
  54. edsl/inference_services/TogetherAIService.py +0 -170
  55. edsl/jobs/runners/JobsRunnerStatus.py +0 -331
  56. edsl/questions/Quick.py +0 -41
  57. edsl/questions/templates/budget/__init__.py +0 -0
  58. edsl/questions/templates/budget/answering_instructions.jinja +0 -7
  59. edsl/questions/templates/budget/question_presentation.jinja +0 -7
  60. edsl/questions/templates/extract/__init__.py +0 -0
  61. edsl/questions/templates/rank/__init__.py +0 -0
  62. {edsl-0.1.33.dist-info → edsl-0.1.33.dev2.dist-info}/LICENSE +0 -0
  63. {edsl-0.1.33.dist-info → edsl-0.1.33.dev2.dist-info}/WHEEL +0 -0
@@ -3,25 +3,40 @@ import time
3
3
  import math
4
4
  import asyncio
5
5
  import functools
6
- import threading
7
6
  from typing import Coroutine, List, AsyncGenerator, Optional, Union, Generator
8
7
  from contextlib import contextmanager
9
8
  from collections import UserList
10
9
 
11
- from edsl.results.Results import Results
12
- from rich.live import Live
13
- from rich.console import Console
14
-
15
10
  from edsl import shared_globals
16
11
  from edsl.jobs.interviews.Interview import Interview
17
- from edsl.jobs.runners.JobsRunnerStatus import JobsRunnerStatus
18
-
12
+ from edsl.jobs.runners.JobsRunnerStatusMixin import JobsRunnerStatusMixin
19
13
  from edsl.jobs.tasks.TaskHistory import TaskHistory
20
14
  from edsl.jobs.buckets.BucketCollection import BucketCollection
21
15
  from edsl.utilities.decorators import jupyter_nb_handler
22
16
  from edsl.data.Cache import Cache
23
17
  from edsl.results.Result import Result
24
18
  from edsl.results.Results import Results
19
+ from edsl.jobs.FailedQuestion import FailedQuestion
20
+
21
+
22
+ def cache_with_timeout(timeout):
23
+ """Used to keep the generate_table from being run too frequently."""
24
+
25
+ def decorator(func):
26
+ cached_result = {}
27
+ last_computation_time = [0] # Using list to store mutable value
28
+
29
+ @functools.wraps(func)
30
+ def wrapper(*args, **kwargs):
31
+ current_time = time.time()
32
+ if (current_time - last_computation_time[0]) >= timeout:
33
+ cached_result["value"] = func(*args, **kwargs)
34
+ last_computation_time[0] = current_time
35
+ return cached_result["value"]
36
+
37
+ return wrapper
38
+
39
+ return decorator
25
40
 
26
41
 
27
42
  class StatusTracker(UserList):
@@ -33,7 +48,7 @@ class StatusTracker(UserList):
33
48
  return print(f"Completed: {len(self.data)} of {self.total_tasks}", end="\r")
34
49
 
35
50
 
36
- class JobsRunnerAsyncio:
51
+ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
37
52
  """A class for running a collection of interviews asynchronously.
38
53
 
39
54
  It gets instantiated from a Jobs object.
@@ -42,12 +57,11 @@ class JobsRunnerAsyncio:
42
57
 
43
58
  def __init__(self, jobs: "Jobs"):
44
59
  self.jobs = jobs
60
+ # this creates the interviews, which can take a while
45
61
  self.interviews: List["Interview"] = jobs.interviews()
46
62
  self.bucket_collection: "BucketCollection" = jobs.bucket_collection
47
63
  self.total_interviews: List["Interview"] = []
48
64
 
49
- # self.jobs_runner_status = JobsRunnerStatus(self, n=1)
50
-
51
65
  async def run_async_generator(
52
66
  self,
53
67
  cache: "Cache",
@@ -65,7 +79,6 @@ class JobsRunnerAsyncio:
65
79
  :param stop_on_exception: Whether to stop the interview if an exception is raised
66
80
  :param sidecar_model: a language model to use in addition to the interview's model
67
81
  :param total_interviews: A list of interviews to run can be provided instead.
68
- :param raise_validation_errors: Whether to raise validation errors
69
82
  """
70
83
  tasks = []
71
84
  if total_interviews: # was already passed in total interviews
@@ -75,6 +88,8 @@ class JobsRunnerAsyncio:
75
88
  self._populate_total_interviews(n=n)
76
89
  ) # Populate self.total_interviews before creating tasks
77
90
 
91
+ # print("Interviews created")
92
+
78
93
  for interview in self.total_interviews:
79
94
  interviewing_task = self._build_interview_task(
80
95
  interview=interview,
@@ -84,9 +99,11 @@ class JobsRunnerAsyncio:
84
99
  )
85
100
  tasks.append(asyncio.create_task(interviewing_task))
86
101
 
102
+ # print("Tasks created")
103
+
87
104
  for task in asyncio.as_completed(tasks):
105
+ # print(f"Task {task} completed")
88
106
  result = await task
89
- self.jobs_runner_status.add_completed_interview(result)
90
107
  yield result
91
108
 
92
109
  def _populate_total_interviews(
@@ -105,8 +122,6 @@ class JobsRunnerAsyncio:
105
122
  yield interview
106
123
 
107
124
  async def run_async(self, cache: Optional["Cache"] = None, n: int = 1) -> Results:
108
- """Used for some other modules that have a non-standard way of running interviews."""
109
- self.jobs_runner_status = JobsRunnerStatus(self, n=n)
110
125
  self.cache = Cache() if cache is None else cache
111
126
  data = []
112
127
  async for result in self.run_async_generator(cache=self.cache, n=n):
@@ -142,6 +157,12 @@ class JobsRunnerAsyncio:
142
157
  raise_validation_errors=raise_validation_errors,
143
158
  )
144
159
 
160
+ # answer_key_names = {
161
+ # k
162
+ # for k in set(answer.keys())
163
+ # if not k.endswith("_comment") and not k.endswith("_generated_tokens")
164
+ # }
165
+
145
166
  question_results = {}
146
167
  for result in valid_results:
147
168
  question_results[result.question_name] = result
@@ -153,13 +174,24 @@ class JobsRunnerAsyncio:
153
174
  for k in answer_key_names
154
175
  }
155
176
  comments_dict = {
156
- k + "_comment": question_results[k].comment for k in answer_key_names
177
+ "k" + "_comment": question_results[k].comment for k in answer_key_names
157
178
  }
158
179
 
159
180
  # we should have a valid result for each question
160
181
  answer_dict = {k: answer[k] for k in answer_key_names}
161
182
  assert len(valid_results) == len(answer_key_names)
162
183
 
184
+ # breakpoint()
185
+ # generated_tokens_dict = {
186
+ # k + "_generated_tokens": v.generated_tokens
187
+ # for k, v in zip(answer_key_names, valid_results)
188
+ # }
189
+
190
+ # comments_dict = {
191
+ # k + "_comment": v.comment for k, v in zip(answer_key_names, valid_results)
192
+ # }
193
+ # breakpoint()
194
+
163
195
  # TODO: move this down into Interview
164
196
  question_name_to_prompts = dict({})
165
197
  for result in valid_results:
@@ -171,19 +203,19 @@ class JobsRunnerAsyncio:
171
203
 
172
204
  prompt_dictionary = {}
173
205
  for answer_key_name in answer_key_names:
174
- prompt_dictionary[
175
- answer_key_name + "_user_prompt"
176
- ] = question_name_to_prompts[answer_key_name]["user_prompt"]
177
- prompt_dictionary[
178
- answer_key_name + "_system_prompt"
179
- ] = question_name_to_prompts[answer_key_name]["system_prompt"]
206
+ prompt_dictionary[answer_key_name + "_user_prompt"] = (
207
+ question_name_to_prompts[answer_key_name]["user_prompt"]
208
+ )
209
+ prompt_dictionary[answer_key_name + "_system_prompt"] = (
210
+ question_name_to_prompts[answer_key_name]["system_prompt"]
211
+ )
180
212
 
181
213
  raw_model_results_dictionary = {}
182
214
  for result in valid_results:
183
215
  question_name = result.question_name
184
- raw_model_results_dictionary[
185
- question_name + "_raw_model_response"
186
- ] = result.raw_model_response
216
+ raw_model_results_dictionary[question_name + "_raw_model_response"] = (
217
+ result.raw_model_response
218
+ )
187
219
  raw_model_results_dictionary[question_name + "_cost"] = result.cost
188
220
  one_use_buys = (
189
221
  "NA"
@@ -194,6 +226,7 @@ class JobsRunnerAsyncio:
194
226
  )
195
227
  raw_model_results_dictionary[question_name + "_one_usd_buys"] = one_use_buys
196
228
 
229
+ # breakpoint()
197
230
  result = Result(
198
231
  agent=interview.agent,
199
232
  scenario=interview.scenario,
@@ -214,62 +247,6 @@ class JobsRunnerAsyncio:
214
247
  def elapsed_time(self):
215
248
  return time.monotonic() - self.start_time
216
249
 
217
- def process_results(
218
- self, raw_results: Results, cache: Cache, print_exceptions: bool
219
- ):
220
- interview_lookup = {
221
- hash(interview): index
222
- for index, interview in enumerate(self.total_interviews)
223
- }
224
- interview_hashes = list(interview_lookup.keys())
225
-
226
- results = Results(
227
- survey=self.jobs.survey,
228
- data=sorted(
229
- raw_results, key=lambda x: interview_hashes.index(x.interview_hash)
230
- ),
231
- )
232
- results.cache = cache
233
- results.task_history = TaskHistory(
234
- self.total_interviews, include_traceback=False
235
- )
236
- results.has_unfixed_exceptions = results.task_history.has_unfixed_exceptions
237
- results.bucket_collection = self.bucket_collection
238
-
239
- if results.has_unfixed_exceptions and print_exceptions:
240
- from edsl.scenarios.FileStore import HTMLFileStore
241
- from edsl.config import CONFIG
242
- from edsl.coop.coop import Coop
243
-
244
- msg = f"Exceptions were raised in {len(results.task_history.indices)} out of {len(self.total_interviews)} interviews.\n"
245
-
246
- if len(results.task_history.indices) > 5:
247
- msg += f"Exceptions were raised in the following interviews: {results.task_history.indices}.\n"
248
-
249
- print(msg)
250
- # this is where exceptions are opening up
251
- filepath = results.task_history.html(
252
- cta="Open report to see details.",
253
- open_in_browser=True,
254
- return_link=True,
255
- )
256
-
257
- try:
258
- coop = Coop()
259
- user_edsl_settings = coop.edsl_settings
260
- remote_logging = user_edsl_settings["remote_logging"]
261
- except Exception as e:
262
- print(e)
263
- remote_logging = False
264
- if remote_logging:
265
- filestore = HTMLFileStore(filepath)
266
- coop_details = filestore.push(description="Error report")
267
- print(coop_details)
268
-
269
- print("Also see: https://docs.expectedparrot.com/en/latest/exceptions.html")
270
-
271
- return results
272
-
273
250
  @jupyter_nb_handler
274
251
  async def run(
275
252
  self,
@@ -282,16 +259,24 @@ class JobsRunnerAsyncio:
282
259
  raise_validation_errors: bool = False,
283
260
  ) -> "Coroutine":
284
261
  """Runs a collection of interviews, handling both async and sync contexts."""
262
+ from rich.console import Console
285
263
 
264
+ console = Console()
286
265
  self.results = []
287
266
  self.start_time = time.monotonic()
288
267
  self.completed = False
289
268
  self.cache = cache
290
269
  self.sidecar_model = sidecar_model
291
270
 
292
- self.jobs_runner_status = JobsRunnerStatus(self, n=n)
271
+ from edsl.results.Results import Results
272
+ from rich.live import Live
273
+ from rich.console import Console
293
274
 
294
- async def process_results(cache):
275
+ @cache_with_timeout(1)
276
+ def generate_table():
277
+ return self.status_table(self.results, self.elapsed_time)
278
+
279
+ async def process_results(cache, progress_bar_context=None):
295
280
  """Processes results from interviews."""
296
281
  async for result in self.run_async_generator(
297
282
  n=n,
@@ -301,22 +286,112 @@ class JobsRunnerAsyncio:
301
286
  raise_validation_errors=raise_validation_errors,
302
287
  ):
303
288
  self.results.append(result)
289
+ if progress_bar_context:
290
+ progress_bar_context.update(generate_table())
304
291
  self.completed = True
305
292
 
306
- def run_progress_bar():
307
- """Runs the progress bar in a separate thread."""
308
- self.jobs_runner_status.update_progress()
293
+ async def update_progress_bar(progress_bar_context):
294
+ """Updates the progress bar at fixed intervals."""
295
+ if progress_bar_context is None:
296
+ return
297
+
298
+ while True:
299
+ progress_bar_context.update(generate_table())
300
+ await asyncio.sleep(0.1) # Update interval
301
+ if self.completed:
302
+ break
303
+
304
+ @contextmanager
305
+ def conditional_context(condition, context_manager):
306
+ if condition:
307
+ with context_manager as cm:
308
+ yield cm
309
+ else:
310
+ yield
311
+
312
+ with conditional_context(
313
+ progress_bar, Live(generate_table(), console=console, refresh_per_second=1)
314
+ ) as progress_bar_context:
315
+ with cache as c:
316
+ progress_task = asyncio.create_task(
317
+ update_progress_bar(progress_bar_context)
318
+ )
319
+
320
+ try:
321
+ await asyncio.gather(
322
+ progress_task,
323
+ process_results(
324
+ cache=c, progress_bar_context=progress_bar_context
325
+ ),
326
+ )
327
+ except asyncio.CancelledError:
328
+ pass
329
+ finally:
330
+ progress_task.cancel() # Cancel the progress_task when process_results is done
331
+ await progress_task
332
+
333
+ await asyncio.sleep(1) # short delay to show the final status
334
+
335
+ if progress_bar_context:
336
+ progress_bar_context.update(generate_table())
337
+
338
+ # puts results in the same order as the total interviews
339
+ interview_lookup = {
340
+ hash(interview): index
341
+ for index, interview in enumerate(self.total_interviews)
342
+ }
343
+ interview_hashes = list(interview_lookup.keys())
344
+ self.results = sorted(
345
+ self.results, key=lambda x: interview_hashes.index(x.interview_hash)
346
+ )
347
+
348
+ results = Results(survey=self.jobs.survey, data=self.results)
349
+ task_history = TaskHistory(self.total_interviews, include_traceback=False)
350
+ results.task_history = task_history
309
351
 
310
- if progress_bar:
311
- progress_thread = threading.Thread(target=run_progress_bar)
312
- progress_thread.start()
352
+ results.failed_questions = {}
353
+ results.has_exceptions = task_history.has_exceptions
313
354
 
314
- with cache as c:
315
- await process_results(cache=c)
355
+ # breakpoint()
356
+ results.bucket_collection = self.bucket_collection
316
357
 
317
- if progress_bar:
318
- progress_thread.join()
358
+ if results.has_exceptions:
359
+ # put the failed interviews in the results object as a list
360
+ failed_interviews = [
361
+ interview.duplicate(
362
+ iteration=interview.iteration, cache=interview.cache
363
+ )
364
+ for interview in self.total_interviews
365
+ if interview.has_exceptions
366
+ ]
319
367
 
320
- return self.process_results(
321
- raw_results=self.results, cache=cache, print_exceptions=print_exceptions
322
- )
368
+ failed_questions = {}
369
+ for interview in self.total_interviews:
370
+ if interview.has_exceptions:
371
+ index = interview_lookup[hash(interview)]
372
+ failed_questions[index] = interview.failed_questions
373
+
374
+ results.failed_questions = failed_questions
375
+
376
+ from edsl.jobs.Jobs import Jobs
377
+
378
+ results.failed_jobs = Jobs.from_interviews(
379
+ [interview for interview in failed_interviews]
380
+ )
381
+ if print_exceptions:
382
+ msg = f"Exceptions were raised in {len(results.task_history.indices)} out of {len(self.total_interviews)} interviews.\n"
383
+
384
+ if len(results.task_history.indices) > 5:
385
+ msg += f"Exceptions were raised in the following interviews: {results.task_history.indices}.\n"
386
+
387
+ shared_globals["edsl_runner_exceptions"] = task_history
388
+ print(msg)
389
+ # this is where exceptions are opening up
390
+ task_history.html(
391
+ cta="Open report to see details.", open_in_browser=True
392
+ )
393
+ print(
394
+ "Also see: https://docs.expectedparrot.com/en/latest/exceptions.html"
395
+ )
396
+
397
+ return results