edsl 0.1.33.dev2__py3-none-any.whl → 0.1.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. edsl/Base.py +24 -14
  2. edsl/__init__.py +1 -0
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +6 -6
  5. edsl/agents/Invigilator.py +28 -6
  6. edsl/agents/InvigilatorBase.py +8 -27
  7. edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +150 -182
  8. edsl/agents/prompt_helpers.py +129 -0
  9. edsl/config.py +26 -34
  10. edsl/coop/coop.py +14 -4
  11. edsl/data_transfer_models.py +26 -73
  12. edsl/enums.py +2 -0
  13. edsl/inference_services/AnthropicService.py +5 -2
  14. edsl/inference_services/AwsBedrock.py +5 -2
  15. edsl/inference_services/AzureAI.py +5 -2
  16. edsl/inference_services/GoogleService.py +108 -33
  17. edsl/inference_services/InferenceServiceABC.py +44 -13
  18. edsl/inference_services/MistralAIService.py +5 -2
  19. edsl/inference_services/OpenAIService.py +10 -6
  20. edsl/inference_services/TestService.py +34 -16
  21. edsl/inference_services/TogetherAIService.py +170 -0
  22. edsl/inference_services/registry.py +2 -0
  23. edsl/jobs/Jobs.py +109 -18
  24. edsl/jobs/buckets/BucketCollection.py +24 -15
  25. edsl/jobs/buckets/TokenBucket.py +64 -10
  26. edsl/jobs/interviews/Interview.py +130 -49
  27. edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +16 -0
  28. edsl/jobs/interviews/InterviewExceptionEntry.py +2 -0
  29. edsl/jobs/runners/JobsRunnerAsyncio.py +119 -173
  30. edsl/jobs/runners/JobsRunnerStatus.py +332 -0
  31. edsl/jobs/tasks/QuestionTaskCreator.py +1 -13
  32. edsl/jobs/tasks/TaskHistory.py +17 -0
  33. edsl/language_models/LanguageModel.py +36 -38
  34. edsl/language_models/registry.py +13 -9
  35. edsl/language_models/utilities.py +5 -2
  36. edsl/questions/QuestionBase.py +74 -16
  37. edsl/questions/QuestionBaseGenMixin.py +28 -0
  38. edsl/questions/QuestionBudget.py +93 -41
  39. edsl/questions/QuestionCheckBox.py +1 -1
  40. edsl/questions/QuestionFreeText.py +6 -0
  41. edsl/questions/QuestionMultipleChoice.py +13 -24
  42. edsl/questions/QuestionNumerical.py +5 -4
  43. edsl/questions/Quick.py +41 -0
  44. edsl/questions/ResponseValidatorABC.py +11 -6
  45. edsl/questions/derived/QuestionLinearScale.py +4 -1
  46. edsl/questions/derived/QuestionTopK.py +4 -1
  47. edsl/questions/derived/QuestionYesNo.py +8 -2
  48. edsl/questions/descriptors.py +12 -11
  49. edsl/questions/templates/budget/__init__.py +0 -0
  50. edsl/questions/templates/budget/answering_instructions.jinja +7 -0
  51. edsl/questions/templates/budget/question_presentation.jinja +7 -0
  52. edsl/questions/templates/extract/__init__.py +0 -0
  53. edsl/questions/templates/numerical/answering_instructions.jinja +0 -1
  54. edsl/questions/templates/rank/__init__.py +0 -0
  55. edsl/questions/templates/yes_no/answering_instructions.jinja +2 -2
  56. edsl/results/DatasetExportMixin.py +5 -1
  57. edsl/results/Result.py +1 -1
  58. edsl/results/Results.py +4 -1
  59. edsl/scenarios/FileStore.py +178 -34
  60. edsl/scenarios/Scenario.py +76 -37
  61. edsl/scenarios/ScenarioList.py +19 -2
  62. edsl/scenarios/ScenarioListPdfMixin.py +150 -4
  63. edsl/study/Study.py +32 -0
  64. edsl/surveys/DAG.py +62 -0
  65. edsl/surveys/MemoryPlan.py +26 -0
  66. edsl/surveys/Rule.py +34 -1
  67. edsl/surveys/RuleCollection.py +55 -5
  68. edsl/surveys/Survey.py +189 -10
  69. edsl/surveys/base.py +4 -0
  70. edsl/templates/error_reporting/interview_details.html +6 -1
  71. edsl/utilities/utilities.py +9 -1
  72. {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/METADATA +3 -1
  73. {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/RECORD +75 -69
  74. edsl/jobs/interviews/retry_management.py +0 -39
  75. edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -333
  76. edsl/scenarios/ScenarioImageMixin.py +0 -100
  77. {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/LICENSE +0 -0
  78. {edsl-0.1.33.dev2.dist-info → edsl-0.1.34.dist-info}/WHEEL +0 -0
@@ -3,40 +3,27 @@ import time
3
3
  import math
4
4
  import asyncio
5
5
  import functools
6
+ import threading
6
7
  from typing import Coroutine, List, AsyncGenerator, Optional, Union, Generator
7
8
  from contextlib import contextmanager
8
9
  from collections import UserList
9
10
 
11
+ from rich.live import Live
12
+ from rich.console import Console
13
+
14
+ from edsl.results.Results import Results
10
15
  from edsl import shared_globals
11
16
  from edsl.jobs.interviews.Interview import Interview
12
- from edsl.jobs.runners.JobsRunnerStatusMixin import JobsRunnerStatusMixin
17
+ from edsl.jobs.runners.JobsRunnerStatus import JobsRunnerStatus
18
+
13
19
  from edsl.jobs.tasks.TaskHistory import TaskHistory
14
20
  from edsl.jobs.buckets.BucketCollection import BucketCollection
15
21
  from edsl.utilities.decorators import jupyter_nb_handler
16
22
  from edsl.data.Cache import Cache
17
23
  from edsl.results.Result import Result
18
24
  from edsl.results.Results import Results
19
- from edsl.jobs.FailedQuestion import FailedQuestion
20
-
21
-
22
- def cache_with_timeout(timeout):
23
- """ "Used to keep the generate table from being run too frequetly."""
24
-
25
- def decorator(func):
26
- cached_result = {}
27
- last_computation_time = [0] # Using list to store mutable value
28
-
29
- @functools.wraps(func)
30
- def wrapper(*args, **kwargs):
31
- current_time = time.time()
32
- if (current_time - last_computation_time[0]) >= timeout:
33
- cached_result["value"] = func(*args, **kwargs)
34
- last_computation_time[0] = current_time
35
- return cached_result["value"]
36
-
37
- return wrapper
38
-
39
- return decorator
25
+ from edsl.language_models.LanguageModel import LanguageModel
26
+ from edsl.data.Cache import Cache
40
27
 
41
28
 
42
29
  class StatusTracker(UserList):
@@ -48,7 +35,7 @@ class StatusTracker(UserList):
48
35
  return print(f"Completed: {len(self.data)} of {self.total_tasks}", end="\r")
49
36
 
50
37
 
51
- class JobsRunnerAsyncio(JobsRunnerStatusMixin):
38
+ class JobsRunnerAsyncio:
52
39
  """A class for running a collection of interviews asynchronously.
53
40
 
54
41
  It gets instaniated from a Jobs object.
@@ -57,17 +44,18 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
57
44
 
58
45
  def __init__(self, jobs: "Jobs"):
59
46
  self.jobs = jobs
60
- # this creates the interviews, which can take a while
61
47
  self.interviews: List["Interview"] = jobs.interviews()
62
48
  self.bucket_collection: "BucketCollection" = jobs.bucket_collection
63
49
  self.total_interviews: List["Interview"] = []
64
50
 
51
+ # self.jobs_runner_status = JobsRunnerStatus(self, n=1)
52
+
65
53
  async def run_async_generator(
66
54
  self,
67
- cache: "Cache",
55
+ cache: Cache,
68
56
  n: int = 1,
69
57
  stop_on_exception: bool = False,
70
- sidecar_model: Optional["LanguageModel"] = None,
58
+ sidecar_model: Optional[LanguageModel] = None,
71
59
  total_interviews: Optional[List["Interview"]] = None,
72
60
  raise_validation_errors: bool = False,
73
61
  ) -> AsyncGenerator["Result", None]:
@@ -79,6 +67,7 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
79
67
  :param stop_on_exception: Whether to stop the interview if an exception is raised
80
68
  :param sidecar_model: a language model to use in addition to the interview's model
81
69
  :param total_interviews: A list of interviews to run can be provided instead.
70
+ :param raise_validation_errors: Whether to raise validation errors
82
71
  """
83
72
  tasks = []
84
73
  if total_interviews: # was already passed in total interviews
@@ -88,8 +77,6 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
88
77
  self._populate_total_interviews(n=n)
89
78
  ) # Populate self.total_interviews before creating tasks
90
79
 
91
- # print("Interviews created")
92
-
93
80
  for interview in self.total_interviews:
94
81
  interviewing_task = self._build_interview_task(
95
82
  interview=interview,
@@ -99,11 +86,9 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
99
86
  )
100
87
  tasks.append(asyncio.create_task(interviewing_task))
101
88
 
102
- # print("Tasks created")
103
-
104
89
  for task in asyncio.as_completed(tasks):
105
- # print(f"Task {task} completed")
106
90
  result = await task
91
+ self.jobs_runner_status.add_completed_interview(result)
107
92
  yield result
108
93
 
109
94
  def _populate_total_interviews(
@@ -121,7 +106,9 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
121
106
  interview.cache = self.cache
122
107
  yield interview
123
108
 
124
- async def run_async(self, cache: Optional["Cache"] = None, n: int = 1) -> Results:
109
+ async def run_async(self, cache: Optional[Cache] = None, n: int = 1) -> Results:
110
+ """Used for some other modules that have a non-standard way of running interviews."""
111
+ self.jobs_runner_status = JobsRunnerStatus(self, n=n)
125
112
  self.cache = Cache() if cache is None else cache
126
113
  data = []
127
114
  async for result in self.run_async_generator(cache=self.cache, n=n):
@@ -157,12 +144,6 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
157
144
  raise_validation_errors=raise_validation_errors,
158
145
  )
159
146
 
160
- # answer_key_names = {
161
- # k
162
- # for k in set(answer.keys())
163
- # if not k.endswith("_comment") and not k.endswith("_generated_tokens")
164
- # }
165
-
166
147
  question_results = {}
167
148
  for result in valid_results:
168
149
  question_results[result.question_name] = result
@@ -174,24 +155,13 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
174
155
  for k in answer_key_names
175
156
  }
176
157
  comments_dict = {
177
- "k" + "_comment": question_results[k].comment for k in answer_key_names
158
+ k + "_comment": question_results[k].comment for k in answer_key_names
178
159
  }
179
160
 
180
161
  # we should have a valid result for each question
181
162
  answer_dict = {k: answer[k] for k in answer_key_names}
182
163
  assert len(valid_results) == len(answer_key_names)
183
164
 
184
- # breakpoint()
185
- # generated_tokens_dict = {
186
- # k + "_generated_tokens": v.generated_tokens
187
- # for k, v in zip(answer_key_names, valid_results)
188
- # }
189
-
190
- # comments_dict = {
191
- # k + "_comment": v.comment for k, v in zip(answer_key_names, valid_results)
192
- # }
193
- # breakpoint()
194
-
195
165
  # TODO: move this down into Interview
196
166
  question_name_to_prompts = dict({})
197
167
  for result in valid_results:
@@ -203,19 +173,19 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
203
173
 
204
174
  prompt_dictionary = {}
205
175
  for answer_key_name in answer_key_names:
206
- prompt_dictionary[answer_key_name + "_user_prompt"] = (
207
- question_name_to_prompts[answer_key_name]["user_prompt"]
208
- )
209
- prompt_dictionary[answer_key_name + "_system_prompt"] = (
210
- question_name_to_prompts[answer_key_name]["system_prompt"]
211
- )
176
+ prompt_dictionary[
177
+ answer_key_name + "_user_prompt"
178
+ ] = question_name_to_prompts[answer_key_name]["user_prompt"]
179
+ prompt_dictionary[
180
+ answer_key_name + "_system_prompt"
181
+ ] = question_name_to_prompts[answer_key_name]["system_prompt"]
212
182
 
213
183
  raw_model_results_dictionary = {}
214
184
  for result in valid_results:
215
185
  question_name = result.question_name
216
- raw_model_results_dictionary[question_name + "_raw_model_response"] = (
217
- result.raw_model_response
218
- )
186
+ raw_model_results_dictionary[
187
+ question_name + "_raw_model_response"
188
+ ] = result.raw_model_response
219
189
  raw_model_results_dictionary[question_name + "_cost"] = result.cost
220
190
  one_use_buys = (
221
191
  "NA"
@@ -226,7 +196,6 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
226
196
  )
227
197
  raw_model_results_dictionary[question_name + "_one_usd_buys"] = one_use_buys
228
198
 
229
- # breakpoint()
230
199
  result = Result(
231
200
  agent=interview.agent,
232
201
  scenario=interview.scenario,
@@ -247,6 +216,62 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
247
216
  def elapsed_time(self):
248
217
  return time.monotonic() - self.start_time
249
218
 
219
+ def process_results(
220
+ self, raw_results: Results, cache: Cache, print_exceptions: bool
221
+ ):
222
+ interview_lookup = {
223
+ hash(interview): index
224
+ for index, interview in enumerate(self.total_interviews)
225
+ }
226
+ interview_hashes = list(interview_lookup.keys())
227
+
228
+ results = Results(
229
+ survey=self.jobs.survey,
230
+ data=sorted(
231
+ raw_results, key=lambda x: interview_hashes.index(x.interview_hash)
232
+ ),
233
+ )
234
+ results.cache = cache
235
+ results.task_history = TaskHistory(
236
+ self.total_interviews, include_traceback=False
237
+ )
238
+ results.has_unfixed_exceptions = results.task_history.has_unfixed_exceptions
239
+ results.bucket_collection = self.bucket_collection
240
+
241
+ if results.has_unfixed_exceptions and print_exceptions:
242
+ from edsl.scenarios.FileStore import HTMLFileStore
243
+ from edsl.config import CONFIG
244
+ from edsl.coop.coop import Coop
245
+
246
+ msg = f"Exceptions were raised in {len(results.task_history.indices)} out of {len(self.total_interviews)} interviews.\n"
247
+
248
+ if len(results.task_history.indices) > 5:
249
+ msg += f"Exceptions were raised in the following interviews: {results.task_history.indices}.\n"
250
+
251
+ print(msg)
252
+ # this is where exceptions are opening up
253
+ filepath = results.task_history.html(
254
+ cta="Open report to see details.",
255
+ open_in_browser=True,
256
+ return_link=True,
257
+ )
258
+
259
+ try:
260
+ coop = Coop()
261
+ user_edsl_settings = coop.edsl_settings
262
+ remote_logging = user_edsl_settings["remote_logging"]
263
+ except Exception as e:
264
+ print(e)
265
+ remote_logging = False
266
+ if remote_logging:
267
+ filestore = HTMLFileStore(filepath)
268
+ coop_details = filestore.push(description="Error report")
269
+ print(coop_details)
270
+
271
+ print("Also see: https://docs.expectedparrot.com/en/latest/exceptions.html")
272
+
273
+ return results
274
+
250
275
  @jupyter_nb_handler
251
276
  async def run(
252
277
  self,
@@ -259,24 +284,18 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
259
284
  raise_validation_errors: bool = False,
260
285
  ) -> "Coroutine":
261
286
  """Runs a collection of interviews, handling both async and sync contexts."""
262
- from rich.console import Console
263
287
 
264
- console = Console()
265
288
  self.results = []
266
289
  self.start_time = time.monotonic()
267
290
  self.completed = False
268
291
  self.cache = cache
269
292
  self.sidecar_model = sidecar_model
270
293
 
271
- from edsl.results.Results import Results
272
- from rich.live import Live
273
- from rich.console import Console
294
+ self.jobs_runner_status = JobsRunnerStatus(self, n=n)
274
295
 
275
- @cache_with_timeout(1)
276
- def generate_table():
277
- return self.status_table(self.results, self.elapsed_time)
296
+ stop_event = threading.Event()
278
297
 
279
- async def process_results(cache, progress_bar_context=None):
298
+ async def process_results(cache):
280
299
  """Processes results from interviews."""
281
300
  async for result in self.run_async_generator(
282
301
  n=n,
@@ -286,112 +305,39 @@ class JobsRunnerAsyncio(JobsRunnerStatusMixin):
286
305
  raise_validation_errors=raise_validation_errors,
287
306
  ):
288
307
  self.results.append(result)
289
- if progress_bar_context:
290
- progress_bar_context.update(generate_table())
291
308
  self.completed = True
292
309
 
293
- async def update_progress_bar(progress_bar_context):
294
- """Updates the progress bar at fixed intervals."""
295
- if progress_bar_context is None:
296
- return
297
-
298
- while True:
299
- progress_bar_context.update(generate_table())
300
- await asyncio.sleep(0.1) # Update interval
301
- if self.completed:
302
- break
303
-
304
- @contextmanager
305
- def conditional_context(condition, context_manager):
306
- if condition:
307
- with context_manager as cm:
308
- yield cm
309
- else:
310
- yield
311
-
312
- with conditional_context(
313
- progress_bar, Live(generate_table(), console=console, refresh_per_second=1)
314
- ) as progress_bar_context:
315
- with cache as c:
316
- progress_task = asyncio.create_task(
317
- update_progress_bar(progress_bar_context)
318
- )
319
-
320
- try:
321
- await asyncio.gather(
322
- progress_task,
323
- process_results(
324
- cache=c, progress_bar_context=progress_bar_context
325
- ),
326
- )
327
- except asyncio.CancelledError:
328
- pass
329
- finally:
330
- progress_task.cancel() # Cancel the progress_task when process_results is done
331
- await progress_task
332
-
333
- await asyncio.sleep(1) # short delay to show the final status
334
-
335
- if progress_bar_context:
336
- progress_bar_context.update(generate_table())
337
-
338
- # puts results in the same order as the total interviews
339
- interview_lookup = {
340
- hash(interview): index
341
- for index, interview in enumerate(self.total_interviews)
342
- }
343
- interview_hashes = list(interview_lookup.keys())
344
- self.results = sorted(
345
- self.results, key=lambda x: interview_hashes.index(x.interview_hash)
346
- )
347
-
348
- results = Results(survey=self.jobs.survey, data=self.results)
349
- task_history = TaskHistory(self.total_interviews, include_traceback=False)
350
- results.task_history = task_history
351
-
352
- results.failed_questions = {}
353
- results.has_exceptions = task_history.has_exceptions
354
-
355
- # breakpoint()
356
- results.bucket_collection = self.bucket_collection
310
+ def run_progress_bar(stop_event):
311
+ """Runs the progress bar in a separate thread."""
312
+ self.jobs_runner_status.update_progress(stop_event)
357
313
 
358
- if results.has_exceptions:
359
- # put the failed interviews in the results object as a list
360
- failed_interviews = [
361
- interview.duplicate(
362
- iteration=interview.iteration, cache=interview.cache
363
- )
364
- for interview in self.total_interviews
365
- if interview.has_exceptions
366
- ]
367
-
368
- failed_questions = {}
369
- for interview in self.total_interviews:
370
- if interview.has_exceptions:
371
- index = interview_lookup[hash(interview)]
372
- failed_questions[index] = interview.failed_questions
373
-
374
- results.failed_questions = failed_questions
375
-
376
- from edsl.jobs.Jobs import Jobs
377
-
378
- results.failed_jobs = Jobs.from_interviews(
379
- [interview for interview in failed_interviews]
314
+ if progress_bar:
315
+ progress_thread = threading.Thread(
316
+ target=run_progress_bar, args=(stop_event,)
380
317
  )
381
- if print_exceptions:
382
- msg = f"Exceptions were raised in {len(results.task_history.indices)} out of {len(self.total_interviews)} interviews.\n"
383
-
384
- if len(results.task_history.indices) > 5:
385
- msg += f"Exceptions were raised in the following interviews: {results.task_history.indices}.\n"
386
-
387
- shared_globals["edsl_runner_exceptions"] = task_history
388
- print(msg)
389
- # this is where exceptions are opening up
390
- task_history.html(
391
- cta="Open report to see details.", open_in_browser=True
392
- )
393
- print(
394
- "Also see: https://docs.expectedparrot.com/en/latest/exceptions.html"
395
- )
318
+ progress_thread.start()
396
319
 
397
- return results
320
+ exception_to_raise = None
321
+ try:
322
+ with cache as c:
323
+ await process_results(cache=c)
324
+ except KeyboardInterrupt:
325
+ print("Keyboard interrupt received. Stopping gracefully...")
326
+ stop_event.set()
327
+ except Exception as e:
328
+ if stop_on_exception:
329
+ exception_to_raise = e
330
+ stop_event.set()
331
+ finally:
332
+ stop_event.set()
333
+ if progress_bar:
334
+ # self.jobs_runner_status.stop_event.set()
335
+ if progress_thread:
336
+ progress_thread.join()
337
+
338
+ if exception_to_raise:
339
+ raise exception_to_raise
340
+
341
+ return self.process_results(
342
+ raw_results=self.results, cache=cache, print_exceptions=print_exceptions
343
+ )