edsl 0.1.54__py3-none-any.whl → 0.1.56__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105) hide show
  1. edsl/__init__.py +8 -1
  2. edsl/__init__original.py +134 -0
  3. edsl/__version__.py +1 -1
  4. edsl/agents/agent.py +29 -0
  5. edsl/agents/agent_list.py +36 -1
  6. edsl/base/base_class.py +281 -151
  7. edsl/base/data_transfer_models.py +15 -4
  8. edsl/buckets/__init__.py +8 -3
  9. edsl/buckets/bucket_collection.py +9 -3
  10. edsl/buckets/model_buckets.py +4 -2
  11. edsl/buckets/token_bucket.py +2 -2
  12. edsl/buckets/token_bucket_client.py +5 -3
  13. edsl/caching/cache.py +131 -62
  14. edsl/caching/cache_entry.py +70 -58
  15. edsl/caching/sql_dict.py +17 -0
  16. edsl/cli.py +99 -0
  17. edsl/config/config_class.py +16 -0
  18. edsl/conversation/__init__.py +31 -0
  19. edsl/coop/coop.py +276 -242
  20. edsl/coop/coop_jobs_objects.py +59 -0
  21. edsl/coop/coop_objects.py +29 -0
  22. edsl/coop/coop_regular_objects.py +26 -0
  23. edsl/coop/utils.py +24 -19
  24. edsl/dataset/dataset.py +338 -101
  25. edsl/dataset/dataset_operations_mixin.py +216 -180
  26. edsl/db_list/sqlite_list.py +349 -0
  27. edsl/inference_services/__init__.py +40 -5
  28. edsl/inference_services/exceptions.py +11 -0
  29. edsl/inference_services/services/anthropic_service.py +5 -2
  30. edsl/inference_services/services/aws_bedrock.py +6 -2
  31. edsl/inference_services/services/azure_ai.py +6 -2
  32. edsl/inference_services/services/google_service.py +7 -3
  33. edsl/inference_services/services/mistral_ai_service.py +6 -2
  34. edsl/inference_services/services/open_ai_service.py +6 -2
  35. edsl/inference_services/services/perplexity_service.py +6 -2
  36. edsl/inference_services/services/test_service.py +94 -5
  37. edsl/interviews/answering_function.py +167 -59
  38. edsl/interviews/interview.py +124 -72
  39. edsl/interviews/interview_task_manager.py +10 -0
  40. edsl/interviews/request_token_estimator.py +8 -0
  41. edsl/invigilators/invigilators.py +35 -13
  42. edsl/jobs/async_interview_runner.py +146 -104
  43. edsl/jobs/data_structures.py +6 -4
  44. edsl/jobs/decorators.py +61 -0
  45. edsl/jobs/fetch_invigilator.py +61 -18
  46. edsl/jobs/html_table_job_logger.py +14 -2
  47. edsl/jobs/jobs.py +180 -104
  48. edsl/jobs/jobs_component_constructor.py +2 -2
  49. edsl/jobs/jobs_interview_constructor.py +2 -0
  50. edsl/jobs/jobs_pricing_estimation.py +154 -113
  51. edsl/jobs/jobs_remote_inference_logger.py +4 -0
  52. edsl/jobs/jobs_runner_status.py +30 -25
  53. edsl/jobs/progress_bar_manager.py +79 -0
  54. edsl/jobs/remote_inference.py +35 -1
  55. edsl/key_management/key_lookup_builder.py +6 -1
  56. edsl/language_models/language_model.py +110 -12
  57. edsl/language_models/model.py +10 -3
  58. edsl/language_models/price_manager.py +176 -71
  59. edsl/language_models/registry.py +5 -0
  60. edsl/notebooks/notebook.py +77 -10
  61. edsl/questions/VALIDATION_README.md +134 -0
  62. edsl/questions/__init__.py +24 -1
  63. edsl/questions/exceptions.py +21 -0
  64. edsl/questions/question_dict.py +201 -16
  65. edsl/questions/question_multiple_choice_with_other.py +624 -0
  66. edsl/questions/question_registry.py +2 -1
  67. edsl/questions/templates/multiple_choice_with_other/__init__.py +0 -0
  68. edsl/questions/templates/multiple_choice_with_other/answering_instructions.jinja +15 -0
  69. edsl/questions/templates/multiple_choice_with_other/question_presentation.jinja +17 -0
  70. edsl/questions/validation_analysis.py +185 -0
  71. edsl/questions/validation_cli.py +131 -0
  72. edsl/questions/validation_html_report.py +404 -0
  73. edsl/questions/validation_logger.py +136 -0
  74. edsl/results/result.py +115 -46
  75. edsl/results/results.py +702 -171
  76. edsl/scenarios/construct_download_link.py +16 -3
  77. edsl/scenarios/directory_scanner.py +226 -226
  78. edsl/scenarios/file_methods.py +5 -0
  79. edsl/scenarios/file_store.py +150 -9
  80. edsl/scenarios/handlers/__init__.py +5 -1
  81. edsl/scenarios/handlers/mp4_file_store.py +104 -0
  82. edsl/scenarios/handlers/webm_file_store.py +104 -0
  83. edsl/scenarios/scenario.py +120 -101
  84. edsl/scenarios/scenario_list.py +800 -727
  85. edsl/scenarios/scenario_list_gc_test.py +146 -0
  86. edsl/scenarios/scenario_list_memory_test.py +214 -0
  87. edsl/scenarios/scenario_list_source_refactor.md +35 -0
  88. edsl/scenarios/scenario_selector.py +5 -4
  89. edsl/scenarios/scenario_source.py +1990 -0
  90. edsl/scenarios/tests/test_scenario_list_sources.py +52 -0
  91. edsl/surveys/survey.py +22 -0
  92. edsl/tasks/__init__.py +4 -2
  93. edsl/tasks/task_history.py +198 -36
  94. edsl/tests/scenarios/test_ScenarioSource.py +51 -0
  95. edsl/tests/scenarios/test_scenario_list_sources.py +51 -0
  96. edsl/utilities/__init__.py +2 -1
  97. edsl/utilities/decorators.py +121 -0
  98. edsl/utilities/memory_debugger.py +1010 -0
  99. {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/METADATA +51 -76
  100. {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/RECORD +103 -79
  101. edsl/jobs/jobs_runner_asyncio.py +0 -281
  102. edsl/language_models/unused/fake_openai_service.py +0 -60
  103. {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/LICENSE +0 -0
  104. {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/WHEEL +0 -0
  105. {edsl-0.1.54.dist-info → edsl-0.1.56.dist-info}/entry_points.txt +0 -0
@@ -24,10 +24,19 @@ if TYPE_CHECKING:
24
24
  from ..jobs.data_structures import RunConfig
25
25
  from .interview_status_log import InterviewStatusLog
26
26
 
27
- # from jobs module
28
- from ..buckets import ModelBuckets
27
+ # Import data structures
29
28
  from ..jobs.data_structures import Answers
30
29
  from ..jobs.fetch_invigilator import FetchInvigilator
30
+
31
+ # Use import_module to avoid circular import
32
+ from importlib import import_module
33
+
34
+
35
+ def get_model_buckets():
36
+ buckets_module = import_module("edsl.buckets.model_buckets")
37
+ return buckets_module.ModelBuckets
38
+
39
+
31
40
  from ..surveys import Survey
32
41
  from ..utilities.utilities import dict_hash
33
42
 
@@ -51,10 +60,10 @@ if TYPE_CHECKING:
51
60
  @dataclass
52
61
  class InterviewRunningConfig:
53
62
  """Configuration parameters for interview execution.
54
-
63
+
55
64
  This dataclass contains settings that control how an interview is conducted,
56
65
  including error handling, caching behavior, and validation options.
57
-
66
+
58
67
  Attributes:
59
68
  cache: Optional cache for storing and retrieving model responses
60
69
  skip_retry: Whether to skip retrying failed questions (default: False)
@@ -70,24 +79,24 @@ class InterviewRunningConfig:
70
79
 
71
80
  class Interview:
72
81
  """Manages the process of an agent answering a survey asynchronously.
73
-
82
+
74
83
  An Interview represents a single execution unit - one agent answering one survey with one
75
84
  language model and one scenario. It handles the complete workflow of navigating through
76
85
  the survey based on skip logic, creating tasks for each question, tracking execution status,
77
86
  and collecting results.
78
-
87
+
79
88
  The core functionality is implemented in the `async_conduct_interview` method, which
80
89
  orchestrates the asynchronous execution of all question-answering tasks while respecting
81
90
  dependencies and rate limits. The class maintains detailed state about the interview progress,
82
91
  including answers collected so far, task statuses, token usage, and any exceptions encountered.
83
-
92
+
84
93
  Key components:
85
94
  - Task management: Creating and scheduling tasks for each question
86
95
  - Memory management: Controlling what previous answers are visible for each question
87
96
  - Exception handling: Tracking and potentially retrying failed questions
88
97
  - Status tracking: Monitoring the state of each task and the overall interview
89
98
  - Token tracking: Measuring and limiting API token usage
90
-
99
+
91
100
  This class serves as the execution layer that translates a high-level survey definition
92
101
  into concrete API calls to language models, with support for caching and fault tolerance.
93
102
  """
@@ -116,13 +125,13 @@ class Interview:
116
125
  cache: Optional cache for storing and retrieving model responses
117
126
  skip_retry: Whether to skip retrying failed questions
118
127
  raise_validation_errors: Whether to raise exceptions for validation errors
119
-
128
+
120
129
  The initialization process sets up the interview state including:
121
130
  1. Creating the task manager for handling question execution
122
131
  2. Initializing empty containers for answers and exceptions
123
132
  3. Setting up configuration and tracking structures
124
133
  4. Computing question indices for quick lookups
125
-
134
+
126
135
  Examples:
127
136
  >>> i = Interview.example()
128
137
  >>> i.task_manager.task_creators
@@ -173,7 +182,7 @@ class Interview:
173
182
  @property
174
183
  def cache(self) -> "Cache":
175
184
  """Get the cache used for storing and retrieving model responses.
176
-
185
+
177
186
  Returns:
178
187
  Cache: The cache object associated with this interview
179
188
  """
@@ -182,7 +191,7 @@ class Interview:
182
191
  @cache.setter
183
192
  def cache(self, value: "Cache") -> None:
184
193
  """Set the cache used for storing and retrieving model responses.
185
-
194
+
186
195
  Args:
187
196
  value: The cache object to use
188
197
  """
@@ -191,7 +200,7 @@ class Interview:
191
200
  @property
192
201
  def skip_retry(self) -> bool:
193
202
  """Get whether the interview should skip retrying failed questions.
194
-
203
+
195
204
  Returns:
196
205
  bool: True if failed questions should not be retried
197
206
  """
@@ -200,7 +209,7 @@ class Interview:
200
209
  @property
201
210
  def raise_validation_errors(self) -> bool:
202
211
  """Get whether validation errors should raise exceptions.
203
-
212
+
204
213
  Returns:
205
214
  bool: True if validation errors should raise exceptions
206
215
  """
@@ -209,19 +218,19 @@ class Interview:
209
218
  @property
210
219
  def has_exceptions(self) -> bool:
211
220
  """Check if any exceptions have occurred during the interview.
212
-
221
+
213
222
  Returns:
214
223
  bool: True if any exceptions have been recorded
215
224
  """
216
225
  return len(self.exceptions) > 0
217
226
 
218
227
  @property
219
- def task_status_logs(self) -> 'InterviewStatusLog':
228
+ def task_status_logs(self) -> "InterviewStatusLog":
220
229
  """Get the complete status history for all tasks in the interview.
221
-
230
+
222
231
  This property provides access to the status logs for all questions,
223
232
  showing how each task progressed through various states during execution.
224
-
233
+
225
234
  Returns:
226
235
  InterviewStatusLog: Dictionary mapping question names to their status log histories
227
236
  """
@@ -230,10 +239,10 @@ class Interview:
230
239
  @property
231
240
  def token_usage(self) -> "InterviewTokenUsage":
232
241
  """Get the token usage statistics for the entire interview.
233
-
242
+
234
243
  This tracks how many tokens were used for prompts and completions
235
244
  across all questions in the interview.
236
-
245
+
237
246
  Returns:
238
247
  InterviewTokenUsage: Token usage statistics for the interview
239
248
  """
@@ -242,10 +251,10 @@ class Interview:
242
251
  @property
243
252
  def interview_status(self) -> InterviewStatusDictionary:
244
253
  """Get the current status summary for all tasks in the interview.
245
-
254
+
246
255
  This provides a count of tasks in each status category (not started,
247
256
  in progress, completed, failed, etc.).
248
-
257
+
249
258
  Returns:
250
259
  InterviewStatusDictionary: Dictionary mapping status codes to counts
251
260
  """
@@ -253,18 +262,18 @@ class Interview:
253
262
 
254
263
  def to_dict(self, include_exceptions=True, add_edsl_version=True) -> dict[str, Any]:
255
264
  """Serialize the interview to a dictionary representation.
256
-
265
+
257
266
  This method creates a dictionary containing all the essential components
258
267
  of the interview, which can be used for hashing, serialization, and
259
268
  creating duplicate interviews.
260
-
269
+
261
270
  Args:
262
271
  include_exceptions: Whether to include exception information (default: True)
263
272
  add_edsl_version: Whether to include EDSL version in component dicts (default: True)
264
-
273
+
265
274
  Returns:
266
275
  dict: Dictionary representation of the interview
267
-
276
+
268
277
  Examples:
269
278
  >>> i = Interview.example()
270
279
  >>> hash(i)
@@ -293,14 +302,14 @@ class Interview:
293
302
  @classmethod
294
303
  def from_dict(cls, d: dict[str, Any]) -> "Interview":
295
304
  """Create an Interview instance from a dictionary representation.
296
-
305
+
297
306
  This class method deserializes an interview from a dictionary created by
298
307
  the to_dict method, recreating all components including agent, survey,
299
308
  scenario, model, and any exceptions.
300
-
309
+
301
310
  Args:
302
311
  d: Dictionary representation of an interview
303
-
312
+
304
313
  Returns:
305
314
  Interview: A reconstructed Interview instance
306
315
  """
@@ -342,11 +351,11 @@ class Interview:
342
351
 
343
352
  def __hash__(self) -> int:
344
353
  """Generate a hash value for the interview.
345
-
354
+
346
355
  This hash is based on the essential components of the interview
347
356
  (agent, survey, scenario, model, and iteration) but excludes mutable
348
357
  state like exceptions to ensure consistent hashing.
349
-
358
+
350
359
  Returns:
351
360
  int: A hash value that uniquely identifies this interview configuration
352
361
  """
@@ -354,16 +363,16 @@ class Interview:
354
363
 
355
364
  def __eq__(self, other: "Interview") -> bool:
356
365
  """Check if two interviews are equivalent.
357
-
366
+
358
367
  Two interviews are considered equal if they have the same agent, survey,
359
368
  scenario, model, and iteration number.
360
-
369
+
361
370
  Args:
362
371
  other: Another interview to compare with
363
-
372
+
364
373
  Returns:
365
374
  bool: True if the interviews are equivalent, False otherwise
366
-
375
+
367
376
  Examples:
368
377
  >>> from . import Interview
369
378
  >>> i = Interview.example()
@@ -377,46 +386,46 @@ class Interview:
377
386
  async def async_conduct_interview(
378
387
  self,
379
388
  run_config: Optional["RunConfig"] = None,
380
- ) -> tuple["Answers", List[dict[str, Any]]]:
389
+ ) -> None:
381
390
  """Execute the interview process asynchronously.
382
-
391
+
383
392
  This is the core method that conducts the entire interview, creating tasks
384
393
  for each question, managing dependencies between them, handling rate limits,
385
394
  and collecting results. It orchestrates the asynchronous execution of all
386
395
  question-answering tasks in the correct order based on survey rules.
387
-
396
+
388
397
  Args:
389
398
  run_config: Optional configuration for the interview execution,
390
399
  including parameters like stop_on_exception and environment
391
400
  settings like bucket_collection and key_lookup
392
-
401
+
393
402
  Returns:
394
403
  tuple: A tuple containing:
395
404
  - Answers: Dictionary of all question answers
396
405
  - List[dict]: List of valid results with detailed information
397
-
406
+
398
407
  Examples:
399
408
  Basic usage:
400
-
409
+
401
410
  >>> i = Interview.example()
402
- >>> result, _ = asyncio.run(i.async_conduct_interview())
403
- >>> result['q0']
411
+ >>> asyncio.run(i.async_conduct_interview())
412
+ >>> i.answers['q2']
404
413
  'yes'
405
-
414
+
406
415
  Handling exceptions:
407
-
416
+
408
417
  >>> i = Interview.example(throw_exception=True)
409
- >>> result, _ = asyncio.run(i.async_conduct_interview())
418
+ >>> asyncio.run(i.async_conduct_interview())
410
419
  >>> i.exceptions
411
420
  {'q0': ...
412
-
421
+
413
422
  Using custom configuration:
414
-
423
+
415
424
  >>> i = Interview.example()
416
425
  >>> from edsl.jobs import RunConfig, RunParameters, RunEnvironment
417
426
  >>> run_config = RunConfig(parameters=RunParameters(), environment=RunEnvironment())
418
427
  >>> run_config.parameters.stop_on_exception = True
419
- >>> result, _ = asyncio.run(i.async_conduct_interview(run_config))
428
+ >>> asyncio.run(i.async_conduct_interview(run_config))
420
429
  """
421
430
  from ..jobs import RunConfig, RunEnvironment, RunParameters
422
431
 
@@ -436,6 +445,7 @@ class Interview:
436
445
  model_buckets = None
437
446
 
438
447
  if model_buckets is None or hasattr(self.agent, "answer_question_directly"):
448
+ ModelBuckets = get_model_buckets()
439
449
  model_buckets = ModelBuckets.infinity_bucket()
440
450
 
441
451
  self.skip_flags = {q.question_name: False for q in self.survey.questions}
@@ -465,7 +475,10 @@ class Interview:
465
475
  valid_results = list(
466
476
  self._extract_valid_results(self.tasks, self.invigilators, self.exceptions)
467
477
  )
468
- return self.answers, valid_results
478
+ self.valid_results = valid_results
479
+ return None
480
+ #
481
+ # return self.answers, valid_results
469
482
 
470
483
  @staticmethod
471
484
  def _extract_valid_results(
@@ -474,27 +487,27 @@ class Interview:
474
487
  exceptions: InterviewExceptionCollection,
475
488
  ) -> Generator["Answers", None, None]:
476
489
  """Extract valid results from completed tasks and handle exceptions.
477
-
490
+
478
491
  This method processes the completed asyncio tasks, extracting successful
479
492
  results and handling any exceptions that occurred. It maintains the
480
493
  relationship between tasks, invigilators, and the questions they represent.
481
-
494
+
482
495
  Args:
483
496
  tasks: List of asyncio tasks for each question
484
497
  invigilators: List of invigilators corresponding to each task
485
498
  exceptions: Collection for storing any exceptions that occurred
486
-
499
+
487
500
  Yields:
488
501
  Answers: Valid results from each successfully completed task
489
-
502
+
490
503
  Notes:
491
504
  - Tasks and invigilators must have the same length and be in the same order
492
505
  - Cancelled tasks are expected and don't trigger exception recording
493
506
  - Other exceptions are recorded in the exceptions collection
494
-
507
+
495
508
  Examples:
496
509
  >>> i = Interview.example()
497
- >>> result, _ = asyncio.run(i.async_conduct_interview())
510
+ >>> asyncio.run(i.async_conduct_interview())
498
511
  """
499
512
  assert len(tasks) == len(invigilators)
500
513
 
@@ -523,38 +536,77 @@ class Interview:
523
536
  for task, invigilator in zip(tasks, invigilators):
524
537
  if not task.done():
525
538
  from edsl.interviews.exceptions import InterviewTaskError
539
+
526
540
  raise InterviewTaskError(f"Task {task.get_name()} is not done.")
527
541
 
528
542
  yield handle_task(task, invigilator)
529
543
 
530
544
  def __repr__(self) -> str:
531
545
  """Generate a string representation of the interview.
532
-
546
+
533
547
  This representation includes the key components of the interview
534
548
  (agent, survey, scenario, and model) for debugging and display purposes.
535
-
549
+
536
550
  Returns:
537
551
  str: A string representation of the interview instance
538
552
  """
539
553
  return f"Interview(agent = {repr(self.agent)}, survey = {repr(self.survey)}, scenario = {repr(self.scenario)}, model = {repr(self.model)})"
540
554
 
555
+ def clear_references(self) -> None:
556
+ """Clear strong references to help garbage collection.
557
+
558
+ This method clears strong references to various objects that might
559
+ be creating reference cycles and preventing proper garbage collection.
560
+ Call this method when you're done with an interview and want to ensure
561
+ it gets properly garbage collected.
562
+
563
+ This is particularly important for large-scale operations where memory
564
+ usage needs to be minimized.
565
+ """
566
+ # Clear references to tasks
567
+ if hasattr(self, "tasks"):
568
+ self.tasks = None
569
+
570
+ # Clear references to invigilators
571
+ if hasattr(self, "invigilators"):
572
+ self.invigilators = None
573
+
574
+ # Clear validator references in questions
575
+ if hasattr(self, "survey") and self.survey:
576
+ for question in self.survey.questions:
577
+ if hasattr(question, "clear_references"):
578
+ question.clear_references()
579
+
580
+ # Clear valid_results which might contain circular references
581
+ if hasattr(self, "valid_results"):
582
+ self.valid_results = None
583
+
584
+ # Clear task manager references
585
+ if hasattr(self, "task_manager"):
586
+ if hasattr(self.task_manager, "clear_references"):
587
+ self.task_manager.clear_references()
588
+ else:
589
+ # Clear task creators which might hold references to the interview
590
+ if hasattr(self.task_manager, "task_creators"):
591
+ self.task_manager.task_creators = {}
592
+
541
593
  def duplicate(
542
594
  self, iteration: int, cache: "Cache", randomize_survey: Optional[bool] = True
543
595
  ) -> "Interview":
544
596
  """Create a duplicate of this interview with a new iteration number and cache.
545
-
597
+
546
598
  This method creates a new Interview instance with the same components but
547
599
  a different iteration number. It can optionally randomize the survey questions
548
600
  (for surveys that support randomization) and use a different cache.
549
-
601
+
550
602
  Args:
551
603
  iteration: The new iteration number for the duplicated interview
552
604
  cache: The cache to use for the new interview (can be None)
553
605
  randomize_survey: Whether to randomize the survey questions (default: True)
554
-
606
+
555
607
  Returns:
556
608
  Interview: A new interview instance with updated iteration and cache
557
-
609
+
558
610
  Examples:
559
611
  >>> i = Interview.example()
560
612
  >>> i2 = i.duplicate(1, None)
@@ -582,31 +634,31 @@ class Interview:
582
634
  @classmethod
583
635
  def example(self, throw_exception: bool = False) -> "Interview":
584
636
  """Create an example Interview instance for testing and demonstrations.
585
-
637
+
586
638
  This method provides a convenient way to create a fully configured
587
639
  Interview instance with default components. It can be configured to
588
640
  either work normally or deliberately throw exceptions for testing
589
641
  error handling scenarios.
590
-
642
+
591
643
  Args:
592
644
  throw_exception: If True, creates an interview that will throw
593
645
  exceptions when run (useful for testing error handling)
594
-
646
+
595
647
  Returns:
596
648
  Interview: A fully configured example interview instance
597
-
649
+
598
650
  Examples:
599
651
  Creating a normal interview:
600
-
652
+
601
653
  >>> i = Interview.example()
602
- >>> result, _ = asyncio.run(i.async_conduct_interview())
603
- >>> result['q0']
654
+ >>> asyncio.run(i.async_conduct_interview())
655
+ >>> i.answers['q0']
604
656
  'yes'
605
-
657
+
606
658
  Creating an interview that will throw exceptions:
607
-
659
+
608
660
  >>> i = Interview.example(throw_exception=True)
609
- >>> result, _ = asyncio.run(i.async_conduct_interview())
661
+ >>> asyncio.run(i.async_conduct_interview())
610
662
  >>> i.has_exceptions
611
663
  True
612
664
  """
@@ -99,6 +99,16 @@ class InterviewTaskManager:
99
99
  """Return a dictionary mapping task status codes to counts."""
100
100
  return self.task_creators.interview_status
101
101
 
102
+ def clear_references(self) -> None:
103
+ """Clear references to help with garbage collection."""
104
+ # Clear task creators which might hold references to the interview
105
+ if hasattr(self, "task_creators"):
106
+ self.task_creators = {}
107
+
108
+ # Clear the survey reference
109
+ if hasattr(self, "survey"):
110
+ self.survey = None
111
+
102
112
 
103
113
  if __name__ == "__main__":
104
114
  import doctest
@@ -124,6 +124,14 @@ class RequestTokenEstimator:
124
124
  width, height = file.get_image_dimensions()
125
125
  token_usage = estimate_tokens(model_name, width, height)
126
126
  file_tokens += token_usage
127
+ if file.is_video():
128
+ model_name = self.interview.model.model
129
+ duration = file.get_video_metadata()["simplified"][
130
+ "duration_seconds"
131
+ ]
132
+ file_tokens += (
133
+ duration * 295
134
+ ) # (295 tokens per second for video + audio)
127
135
  else:
128
136
  file_tokens += file.size * 0.25
129
137
  else:
@@ -1,4 +1,5 @@
1
1
  """Module for creating Invigilators, which are objects to administer a question to an Agent."""
2
+
2
3
  from abc import ABC, abstractmethod
3
4
  import asyncio
4
5
  from typing import Coroutine, Dict, Any, Optional, TYPE_CHECKING
@@ -395,17 +396,21 @@ class InvigilatorAI(InvigilatorBase):
395
396
 
396
397
  if agent_response_dict.model_outputs.cache_used and False:
397
398
  data = {
398
- "answer": agent_response_dict.edsl_dict.answer
399
- if type(agent_response_dict.edsl_dict.answer) is str
400
- or type(agent_response_dict.edsl_dict.answer) is dict
401
- or type(agent_response_dict.edsl_dict.answer) is list
402
- or type(agent_response_dict.edsl_dict.answer) is int
403
- or type(agent_response_dict.edsl_dict.answer) is float
404
- or type(agent_response_dict.edsl_dict.answer) is bool
405
- else "",
406
- "comment": agent_response_dict.edsl_dict.comment
407
- if agent_response_dict.edsl_dict.comment
408
- else "",
399
+ "answer": (
400
+ agent_response_dict.edsl_dict.answer
401
+ if type(agent_response_dict.edsl_dict.answer) is str
402
+ or type(agent_response_dict.edsl_dict.answer) is dict
403
+ or type(agent_response_dict.edsl_dict.answer) is list
404
+ or type(agent_response_dict.edsl_dict.answer) is int
405
+ or type(agent_response_dict.edsl_dict.answer) is float
406
+ or type(agent_response_dict.edsl_dict.answer) is bool
407
+ else ""
408
+ ),
409
+ "comment": (
410
+ agent_response_dict.edsl_dict.comment
411
+ if agent_response_dict.edsl_dict.comment
412
+ else ""
413
+ ),
409
414
  "generated_tokens": agent_response_dict.edsl_dict.generated_tokens,
410
415
  "question_name": self.question.question_name,
411
416
  "prompts": self.get_prompts(),
@@ -415,7 +420,11 @@ class InvigilatorAI(InvigilatorBase):
415
420
  "cache_key": agent_response_dict.model_outputs.cache_key,
416
421
  "validated": True,
417
422
  "exception_occurred": exception_occurred,
418
- "cost": agent_response_dict.model_outputs.cost,
423
+ "input_tokens": agent_response_dict.model_outputs.input_tokens,
424
+ "output_tokens": agent_response_dict.model_outputs.output_tokens,
425
+ "input_price_per_million_tokens": agent_response_dict.model_outputs.input_price_per_million_tokens,
426
+ "output_price_per_million_tokens": agent_response_dict.model_outputs.output_price_per_million_tokens,
427
+ "total_cost": agent_response_dict.model_outputs.total_cost,
419
428
  }
420
429
 
421
430
  result = EDSLResultObjectInput(**data)
@@ -447,6 +456,15 @@ class InvigilatorAI(InvigilatorBase):
447
456
  answer = self._determine_answer(validated_edsl_dict["answer"])
448
457
  comment = validated_edsl_dict.get("comment", "")
449
458
  validated = True
459
+
460
+ # Update the cache entry to mark it as validated if we have a cache and a key
461
+ if self.cache and agent_response_dict.model_outputs.cache_key:
462
+ cache_key = agent_response_dict.model_outputs.cache_key
463
+ if cache_key in self.cache.data:
464
+ # Get the entry from the cache
465
+ entry = self.cache.data[cache_key]
466
+ # Set the validated flag to True
467
+ entry.validated = True
450
468
  except QuestionAnswerValidationError as e:
451
469
  answer = None
452
470
  comment = "The response was not valid."
@@ -471,7 +489,11 @@ class InvigilatorAI(InvigilatorBase):
471
489
  "cache_key": agent_response_dict.model_outputs.cache_key,
472
490
  "validated": validated,
473
491
  "exception_occurred": exception_occurred,
474
- "cost": agent_response_dict.model_outputs.cost,
492
+ "input_tokens": agent_response_dict.model_outputs.input_tokens,
493
+ "output_tokens": agent_response_dict.model_outputs.output_tokens,
494
+ "input_price_per_million_tokens": agent_response_dict.model_outputs.input_price_per_million_tokens,
495
+ "output_price_per_million_tokens": agent_response_dict.model_outputs.output_price_per_million_tokens,
496
+ "total_cost": agent_response_dict.model_outputs.total_cost,
475
497
  }
476
498
  result = EDSLResultObjectInput(**data)
477
499
  return result