edsl 0.1.54__py3-none-any.whl → 0.1.55__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (101)
  1. edsl/__init__.py +8 -1
  2. edsl/__init__original.py +134 -0
  3. edsl/__version__.py +1 -1
  4. edsl/agents/agent.py +29 -0
  5. edsl/agents/agent_list.py +36 -1
  6. edsl/base/base_class.py +281 -151
  7. edsl/buckets/__init__.py +8 -3
  8. edsl/buckets/bucket_collection.py +9 -3
  9. edsl/buckets/model_buckets.py +4 -2
  10. edsl/buckets/token_bucket.py +2 -2
  11. edsl/buckets/token_bucket_client.py +5 -3
  12. edsl/caching/cache.py +131 -62
  13. edsl/caching/cache_entry.py +70 -58
  14. edsl/caching/sql_dict.py +17 -0
  15. edsl/cli.py +99 -0
  16. edsl/config/config_class.py +16 -0
  17. edsl/conversation/__init__.py +31 -0
  18. edsl/coop/coop.py +276 -242
  19. edsl/coop/coop_jobs_objects.py +59 -0
  20. edsl/coop/coop_objects.py +29 -0
  21. edsl/coop/coop_regular_objects.py +26 -0
  22. edsl/coop/utils.py +24 -19
  23. edsl/dataset/dataset.py +338 -101
  24. edsl/db_list/sqlite_list.py +349 -0
  25. edsl/inference_services/__init__.py +40 -5
  26. edsl/inference_services/exceptions.py +11 -0
  27. edsl/inference_services/services/anthropic_service.py +5 -2
  28. edsl/inference_services/services/aws_bedrock.py +6 -2
  29. edsl/inference_services/services/azure_ai.py +6 -2
  30. edsl/inference_services/services/google_service.py +3 -2
  31. edsl/inference_services/services/mistral_ai_service.py +6 -2
  32. edsl/inference_services/services/open_ai_service.py +6 -2
  33. edsl/inference_services/services/perplexity_service.py +6 -2
  34. edsl/inference_services/services/test_service.py +94 -5
  35. edsl/interviews/answering_function.py +167 -59
  36. edsl/interviews/interview.py +124 -72
  37. edsl/interviews/interview_task_manager.py +10 -0
  38. edsl/invigilators/invigilators.py +9 -0
  39. edsl/jobs/async_interview_runner.py +146 -104
  40. edsl/jobs/data_structures.py +6 -4
  41. edsl/jobs/decorators.py +61 -0
  42. edsl/jobs/fetch_invigilator.py +61 -18
  43. edsl/jobs/html_table_job_logger.py +14 -2
  44. edsl/jobs/jobs.py +180 -104
  45. edsl/jobs/jobs_component_constructor.py +2 -2
  46. edsl/jobs/jobs_interview_constructor.py +2 -0
  47. edsl/jobs/jobs_remote_inference_logger.py +4 -0
  48. edsl/jobs/jobs_runner_status.py +30 -25
  49. edsl/jobs/progress_bar_manager.py +79 -0
  50. edsl/jobs/remote_inference.py +35 -1
  51. edsl/key_management/key_lookup_builder.py +6 -1
  52. edsl/language_models/language_model.py +86 -6
  53. edsl/language_models/model.py +10 -3
  54. edsl/language_models/price_manager.py +45 -75
  55. edsl/language_models/registry.py +5 -0
  56. edsl/notebooks/notebook.py +77 -10
  57. edsl/questions/VALIDATION_README.md +134 -0
  58. edsl/questions/__init__.py +24 -1
  59. edsl/questions/exceptions.py +21 -0
  60. edsl/questions/question_dict.py +201 -16
  61. edsl/questions/question_multiple_choice_with_other.py +624 -0
  62. edsl/questions/question_registry.py +2 -1
  63. edsl/questions/templates/multiple_choice_with_other/__init__.py +0 -0
  64. edsl/questions/templates/multiple_choice_with_other/answering_instructions.jinja +15 -0
  65. edsl/questions/templates/multiple_choice_with_other/question_presentation.jinja +17 -0
  66. edsl/questions/validation_analysis.py +185 -0
  67. edsl/questions/validation_cli.py +131 -0
  68. edsl/questions/validation_html_report.py +404 -0
  69. edsl/questions/validation_logger.py +136 -0
  70. edsl/results/result.py +63 -16
  71. edsl/results/results.py +702 -171
  72. edsl/scenarios/construct_download_link.py +16 -3
  73. edsl/scenarios/directory_scanner.py +226 -226
  74. edsl/scenarios/file_methods.py +5 -0
  75. edsl/scenarios/file_store.py +117 -6
  76. edsl/scenarios/handlers/__init__.py +5 -1
  77. edsl/scenarios/handlers/mp4_file_store.py +104 -0
  78. edsl/scenarios/handlers/webm_file_store.py +104 -0
  79. edsl/scenarios/scenario.py +120 -101
  80. edsl/scenarios/scenario_list.py +800 -727
  81. edsl/scenarios/scenario_list_gc_test.py +146 -0
  82. edsl/scenarios/scenario_list_memory_test.py +214 -0
  83. edsl/scenarios/scenario_list_source_refactor.md +35 -0
  84. edsl/scenarios/scenario_selector.py +5 -4
  85. edsl/scenarios/scenario_source.py +1990 -0
  86. edsl/scenarios/tests/test_scenario_list_sources.py +52 -0
  87. edsl/surveys/survey.py +22 -0
  88. edsl/tasks/__init__.py +4 -2
  89. edsl/tasks/task_history.py +198 -36
  90. edsl/tests/scenarios/test_ScenarioSource.py +51 -0
  91. edsl/tests/scenarios/test_scenario_list_sources.py +51 -0
  92. edsl/utilities/__init__.py +2 -1
  93. edsl/utilities/decorators.py +121 -0
  94. edsl/utilities/memory_debugger.py +1010 -0
  95. {edsl-0.1.54.dist-info → edsl-0.1.55.dist-info}/METADATA +51 -76
  96. {edsl-0.1.54.dist-info → edsl-0.1.55.dist-info}/RECORD +99 -75
  97. edsl/jobs/jobs_runner_asyncio.py +0 -281
  98. edsl/language_models/unused/fake_openai_service.py +0 -60
  99. {edsl-0.1.54.dist-info → edsl-0.1.55.dist-info}/LICENSE +0 -0
  100. {edsl-0.1.54.dist-info → edsl-0.1.55.dist-info}/WHEEL +0 -0
  101. {edsl-0.1.54.dist-info → edsl-0.1.55.dist-info}/entry_points.txt +0 -0
edsl/jobs/jobs.py CHANGED
@@ -18,7 +18,6 @@ who need to run complex simulations with language models.
18
18
  from __future__ import annotations
19
19
  import asyncio
20
20
  from typing import Optional, Union, TypeVar, Callable, cast
21
- from functools import wraps
22
21
 
23
22
  from typing import (
24
23
  Literal,
@@ -31,10 +30,27 @@ from ..base import Base
31
30
  from ..utilities import remove_edsl_version
32
31
  from ..coop import CoopServerResponseError
33
32
 
34
- from ..buckets import BucketCollection
33
+ # Import BucketCollection with an import_module to avoid early binding
34
+ from importlib import import_module
35
+
36
+
37
+ def get_bucket_collection():
38
+ buckets_module = import_module("edsl.buckets")
39
+ return buckets_module.BucketCollection
40
+
41
+
35
42
  from ..scenarios import Scenario, ScenarioList
36
43
  from ..surveys import Survey
37
- from ..interviews import Interview
44
+
45
+ # Use import_module to avoid circular import with interviews
46
+ from importlib import import_module
47
+
48
+
49
+ def get_interview():
50
+ interviews_module = import_module("edsl.interviews.interview")
51
+ return interviews_module.Interview
52
+
53
+
38
54
  from .exceptions import JobsValueError, JobsImplementationError
39
55
 
40
56
  from .jobs_pricing_estimation import JobsPrompts
@@ -42,6 +58,7 @@ from .remote_inference import JobsRemoteInferenceHandler
42
58
  from .jobs_checks import JobsChecks
43
59
  from .data_structures import RunEnvironment, RunParameters, RunConfig
44
60
  from .check_survey_scenario_compatibility import CheckSurveyScenarioCompatibility
61
+ from .decorators import with_config
45
62
 
46
63
 
47
64
  if TYPE_CHECKING:
@@ -59,65 +76,6 @@ if TYPE_CHECKING:
59
76
  VisibilityType = Literal["private", "public", "unlisted"]
60
77
 
61
78
 
62
- try:
63
- from typing import ParamSpec
64
- except ImportError:
65
- from typing_extensions import ParamSpec
66
-
67
-
68
- P = ParamSpec("P")
69
- T = TypeVar("T")
70
-
71
-
72
- def with_config(f: Callable[P, T]) -> Callable[P, T]:
73
- """
74
- Decorator that processes function parameters to match the RunConfig dataclass structure.
75
-
76
- This decorator is used primarily with the run() and run_async() methods to provide
77
- a consistent interface for job configuration while maintaining a clean API.
78
-
79
- The decorator:
80
- 1. Extracts environment-related parameters into a RunEnvironment instance
81
- 2. Extracts execution-related parameters into a RunParameters instance
82
- 3. Combines both into a single RunConfig object
83
- 4. Passes this RunConfig to the decorated function as a keyword argument
84
-
85
- Parameters:
86
- f (Callable): The function to decorate, typically run() or run_async()
87
-
88
- Returns:
89
- Callable: A wrapped function that accepts all RunConfig parameters directly
90
-
91
- Example:
92
- @with_config
93
- def run(self, *, config: RunConfig) -> Results:
94
- # Function can now access config.parameters and config.environment
95
- """
96
- parameter_fields = {
97
- name: field.default
98
- for name, field in RunParameters.__dataclass_fields__.items()
99
- }
100
- environment_fields = {
101
- name: field.default
102
- for name, field in RunEnvironment.__dataclass_fields__.items()
103
- }
104
- # Combined fields dict used for reference during development
105
- # combined = {**parameter_fields, **environment_fields}
106
-
107
- @wraps(f)
108
- def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
109
- environment = RunEnvironment(
110
- **{k: v for k, v in kwargs.items() if k in environment_fields}
111
- )
112
- parameters = RunParameters(
113
- **{k: v for k, v in kwargs.items() if k in parameter_fields}
114
- )
115
- config = RunConfig(environment=environment, parameters=parameters)
116
- return f(*args, config=config)
117
-
118
- return cast(Callable[P, T], wrapper)
119
-
120
-
121
79
  class Jobs(Base):
122
80
  """
123
81
  A collection of agents, scenarios, models, and a survey that orchestrates interviews.
@@ -220,7 +178,7 @@ class Jobs(Base):
220
178
  self.run_config.add_cache(cache)
221
179
  return self
222
180
 
223
- def using_bucket_collection(self, bucket_collection: "BucketCollection") -> Jobs:
181
+ def using_bucket_collection(self, bucket_collection) -> Jobs:
224
182
  """
225
183
  Add a BucketCollection to the job.
226
184
 
@@ -238,7 +196,7 @@ class Jobs(Base):
238
196
  self.run_config.add_key_lookup(key_lookup)
239
197
  return self
240
198
 
241
- def using(self, obj: Union[Cache, BucketCollection, KeyLookup]) -> Jobs:
199
+ def using(self, obj) -> Jobs:
242
200
  """
243
201
  Add a Cache, BucketCollection, or KeyLookup to the job.
244
202
 
@@ -247,6 +205,8 @@ class Jobs(Base):
247
205
  from ..caching import Cache
248
206
  from ..key_management import KeyLookup
249
207
 
208
+ BucketCollection = get_bucket_collection()
209
+
250
210
  if isinstance(obj, Cache):
251
211
  self.using_cache(obj)
252
212
  elif isinstance(obj, BucketCollection):
@@ -428,7 +388,7 @@ class Jobs(Base):
428
388
 
429
389
  :param iterations: the number of iterations to run
430
390
  """
431
- return JobsPrompts(self).estimate_job_cost(iterations)
391
+ return JobsPrompts.from_jobs(self).estimate_job_cost(iterations)
432
392
 
433
393
  def estimate_job_cost_from_external_prices(
434
394
  self, price_lookup: dict, iterations: int = 1
@@ -453,7 +413,7 @@ class Jobs(Base):
453
413
  self.models = self.models or [Model()]
454
414
  self.scenarios = self.scenarios or [Scenario()]
455
415
 
456
- def generate_interviews(self) -> Generator[Interview, None, None]:
416
+ def generate_interviews(self) -> Generator:
457
417
  """
458
418
  Generate interviews.
459
419
 
@@ -485,7 +445,7 @@ class Jobs(Base):
485
445
  filename=filename
486
446
  )
487
447
 
488
- def interviews(self) -> list[Interview]:
448
+ def interviews(self) -> list:
489
449
  """
490
450
  Return a list of :class:`edsl.jobs.interviews.Interview` objects.
491
451
 
@@ -508,6 +468,9 @@ class Jobs(Base):
508
468
  This is useful when you have, say, a list of failed interviews and you want to create
509
469
  a new job with only those interviews.
510
470
  """
471
+ if not interview_list:
472
+ raise JobsValueError("Cannot create Jobs from empty interview list")
473
+
511
474
  survey = interview_list[0].survey
512
475
  # get all the models
513
476
  models = list(set([interview.model for interview in interview_list]))
@@ -516,7 +479,7 @@ class Jobs(Base):
516
479
  jobs._interviews = interview_list
517
480
  return jobs
518
481
 
519
- def create_bucket_collection(self) -> BucketCollection:
482
+ def create_bucket_collection(self):
520
483
  """
521
484
  Create a collection of buckets for each model.
522
485
 
@@ -529,6 +492,7 @@ class Jobs(Base):
529
492
  >>> bc
530
493
  BucketCollection(...)
531
494
  """
495
+ BucketCollection = get_bucket_collection()
532
496
  bc = BucketCollection.from_models(self.models)
533
497
 
534
498
  if self.run_config.environment.key_lookup is not None:
@@ -645,19 +609,109 @@ class Jobs(Base):
645
609
  jc.check_api_keys()
646
610
 
647
611
  async def _execute_with_remote_cache(self, run_job_async: bool) -> Results:
648
- # Remote cache usage determination happens inside this method
649
- # use_remote_cache = self.use_remote_cache()
650
-
651
- from .jobs_runner_asyncio import JobsRunnerAsyncio
612
+ """Core interview execution logic for jobs execution."""
613
+ # Import needed modules inline to avoid early binding
614
+ import os
615
+ import time
616
+ import gc
617
+ import weakref
618
+ import asyncio
652
619
  from ..caching import Cache
620
+ from ..results import Results, Result
621
+ from ..tasks import TaskHistory
622
+ from ..utilities.decorators import jupyter_nb_handler
623
+ from ..utilities.memory_debugger import MemoryDebugger
624
+ from .jobs_runner_status import JobsRunnerStatus
625
+ from .async_interview_runner import AsyncInterviewRunner
626
+ from .progress_bar_manager import ProgressBarManager
627
+ from .results_exceptions_handler import ResultsExceptionsHandler
653
628
 
654
629
  assert isinstance(self.run_config.environment.cache, Cache)
655
630
 
656
- runner = JobsRunnerAsyncio(self, environment=self.run_config.environment)
631
+ # Create the RunConfig for the job
632
+ run_config = RunConfig(
633
+ parameters=self.run_config.parameters,
634
+ environment=self.run_config.environment,
635
+ )
636
+
637
+ # Setup JobsRunnerStatus if needed
638
+ if self.run_config.environment.jobs_runner_status is None:
639
+ self.run_config.environment.jobs_runner_status = JobsRunnerStatus(
640
+ self, n=self.run_config.parameters.n
641
+ )
642
+
643
+ # Create a shared function to process interview results
644
+ async def process_interviews(interview_runner, results_obj):
645
+ prev_interview_ref = None
646
+ async for result, interview, idx in interview_runner.run():
647
+ # Set the order attribute on the result for correct ordering
648
+ result.order = idx
649
+
650
+ # Collect results
651
+ # results_obj.append(result)
652
+ # key = results_obj.shelve_result(result)
653
+ results_obj.add_task_history_entry(interview)
654
+ results_obj.insert_sorted(result)
655
+
656
+ # Memory management: Set up reference for next iteration and clear old references
657
+ prev_interview_ref = weakref.ref(interview)
658
+ if hasattr(interview, "clear_references"):
659
+ interview.clear_references()
660
+
661
+ # Force garbage collection
662
+ del result
663
+ del interview
664
+
665
+ # Finalize results object with cache and bucket collection
666
+ # results_obj.insert_from_shelf()
667
+ results_obj.cache = results_obj.relevant_cache(
668
+ self.run_config.environment.cache
669
+ )
670
+ results_obj.bucket_collection = (
671
+ self.run_config.environment.bucket_collection
672
+ )
673
+ return results_obj
674
+
675
+ # Core execution logic
676
+ interview_runner = AsyncInterviewRunner(self, run_config)
677
+
678
+ # Create an initial Results object with appropriate traceback settings
679
+ results = Results(
680
+ survey=self.survey,
681
+ data=[],
682
+ task_history=TaskHistory(
683
+ include_traceback=not self.run_config.parameters.progress_bar
684
+ ),
685
+ )
686
+
657
687
  if run_job_async:
658
- results = await runner.run_async(self.run_config.parameters)
688
+ # For async execution mode (simplified path without progress bar)
689
+ await process_interviews(interview_runner, results)
659
690
  else:
660
- results = runner.run(self.run_config.parameters)
691
+ # For synchronous execution mode (with progress bar)
692
+ with ProgressBarManager(
693
+ self, run_config, self.run_config.parameters
694
+ ) as stop_event:
695
+ try:
696
+ await process_interviews(interview_runner, results)
697
+ except KeyboardInterrupt:
698
+ print("Keyboard interrupt received. Stopping gracefully...")
699
+ results = Results(
700
+ survey=self.survey, data=[], task_history=TaskHistory()
701
+ )
702
+ except Exception as e:
703
+ if self.run_config.parameters.stop_on_exception:
704
+ raise
705
+ results = Results(
706
+ survey=self.survey, data=[], task_history=TaskHistory()
707
+ )
708
+
709
+ # Process any exceptions in the results
710
+ if results:
711
+ ResultsExceptionsHandler(
712
+ results, self.run_config.parameters
713
+ ).handle_exceptions()
714
+
661
715
  return results
662
716
 
663
717
  @property
@@ -668,55 +722,72 @@ class Jobs(Base):
668
722
  return len(self) * self.run_config.parameters.n
669
723
 
670
724
  def _run(self, config: RunConfig) -> Union[None, "Results"]:
671
- "Shared code for run and run_async"
672
- if config.environment.cache is not None:
673
- self.run_config.environment.cache = config.environment.cache
674
- if config.environment.jobs_runner_status is not None:
675
- self.run_config.environment.jobs_runner_status = (
676
- config.environment.jobs_runner_status
677
- )
725
+ """
726
+ Shared code for run and run_async methods.
678
727
 
679
- if config.environment.bucket_collection is not None:
680
- self.run_config.environment.bucket_collection = (
681
- config.environment.bucket_collection
682
- )
728
+ This method handles all pre-execution setup including:
729
+ 1. Transferring configuration settings from the input config
730
+ 2. Ensuring all required objects (agents, models, scenarios) exist
731
+ 3. Checking API keys and remote execution availability
732
+ 4. Setting up caching and bucket collections
733
+ 5. Attempting remote execution if appropriate
683
734
 
684
- if config.environment.key_lookup is not None:
685
- self.run_config.environment.key_lookup = config.environment.key_lookup
735
+ Returns:
736
+ Tuple containing (Results, reason) if remote execution succeeds,
737
+ or (None, reason) if local execution should proceed
738
+ """
739
+ # Apply configuration from input config to self.run_config
740
+ for attr_name in [
741
+ "cache",
742
+ "jobs_runner_status",
743
+ "bucket_collection",
744
+ "key_lookup",
745
+ ]:
746
+ if getattr(config.environment, attr_name) is not None:
747
+ setattr(
748
+ self.run_config.environment,
749
+ attr_name,
750
+ getattr(config.environment, attr_name),
751
+ )
686
752
 
687
- # replace the parameters with the ones from the config
753
+ # Replace parameters with the ones from the config
688
754
  self.run_config.parameters = config.parameters
689
755
 
756
+ # Make sure all required objects exist
690
757
  self.replace_missing_objects()
691
-
692
758
  self._prepare_to_run()
693
759
  self._check_if_remote_keys_ok()
694
760
 
761
+ # Setup caching
762
+ from ..caching import CacheHandler, Cache
763
+
695
764
  if (
696
765
  self.run_config.environment.cache is None
697
766
  or self.run_config.environment.cache is True
698
767
  ):
699
- from ..caching import CacheHandler
700
-
701
768
  self.run_config.environment.cache = CacheHandler().get_cache()
702
-
703
- if self.run_config.environment.cache is False:
704
- from ..caching import Cache
705
-
769
+ elif self.run_config.environment.cache is False:
706
770
  self.run_config.environment.cache = Cache(immediate_write=False)
707
771
 
708
- # first try to run the job remotely
772
+ # Try to run the job remotely first
709
773
  results, reason = self._remote_results(config)
710
774
  if results is not None:
711
775
  return results, reason
712
776
 
777
+ # If we need to run locally, ensure keys and resources are ready
713
778
  self._check_if_local_keys_ok()
714
779
 
715
- if config.environment.bucket_collection is None:
780
+ # Create bucket collection if it doesn't exist
781
+ if self.run_config.environment.bucket_collection is None:
716
782
  self.run_config.environment.bucket_collection = (
717
783
  self.create_bucket_collection()
718
784
  )
785
+ else:
786
+ # Ensure models are properly added to the bucket collection
787
+ for model in self.models:
788
+ self.run_config.environment.bucket_collection.add_model(model)
719
789
 
790
+ # Update bucket collection from key lookup if both exist
720
791
  if (
721
792
  self.run_config.environment.key_lookup is not None
722
793
  and self.run_config.environment.bucket_collection is not None
@@ -756,6 +827,8 @@ class Jobs(Base):
756
827
  cache (Cache, optional): Cache object to store results
757
828
  bucket_collection (BucketCollection, optional): Object to track API calls
758
829
  key_lookup (KeyLookup, optional): Object to manage API keys
830
+ memory_threshold (int, optional): Memory threshold in bytes for the Results object's SQLList,
831
+ controlling when data is offloaded to SQLite storage
759
832
 
760
833
  Returns:
761
834
  Results: A Results object containing all responses and metadata
@@ -814,6 +887,8 @@ class Jobs(Base):
814
887
  cache (Cache, optional): Cache object to store results
815
888
  bucket_collection (BucketCollection, optional): Object to track API calls
816
889
  key_lookup (KeyLookup, optional): Object to manage API keys
890
+ memory_threshold (int, optional): Memory threshold in bytes for the Results object's SQLList,
891
+ controlling when data is offloaded to SQLite storage
817
892
 
818
893
  Returns:
819
894
  Results: A Results object containing all responses and metadata
@@ -991,17 +1066,18 @@ class Jobs(Base):
991
1066
 
992
1067
  base_survey = Survey(questions=[q1, q2])
993
1068
 
994
- scenario_list = ScenarioList(
995
- [
996
- Scenario({"period": f"morning{addition}"}),
997
- Scenario({"period": "afternoon"}),
998
- ]
999
- )
1069
+ scenarios = [
1070
+ Scenario({"period": f"morning{addition}"}),
1071
+ Scenario({"period": "afternoon"}),
1072
+ ]
1073
+ scenario_list = ScenarioList(data=scenarios)
1000
1074
  if test_model:
1001
1075
  job = base_survey.by(m).by(scenario_list).by(joy_agent, sad_agent)
1002
1076
  else:
1003
1077
  job = base_survey.by(scenario_list).by(joy_agent, sad_agent)
1004
1078
 
1079
+ assert len(scenario_list) == 2
1080
+
1005
1081
  return job
1006
1082
 
1007
1083
  def code(self):
@@ -1,6 +1,6 @@
1
1
  from typing import Union, Sequence, TYPE_CHECKING
2
2
  from .exceptions import JobsValueError
3
-
3
+ from ..scenarios import ScenarioList
4
4
  if TYPE_CHECKING:
5
5
  from ..agents import Agent
6
6
  from ..language_models import LanguageModel
@@ -96,7 +96,7 @@ class JobsComponentConstructor:
96
96
  >>> did_user_pass_a_sequence(1)
97
97
  False
98
98
  """
99
- return len(args) == 1 and isinstance(args[0], Sequence)
99
+ return len(args) == 1 and (isinstance(args[0], Sequence) or isinstance(args[0], ScenarioList))
100
100
 
101
101
  if did_user_pass_a_sequence(args):
102
102
  container_class = JobsComponentConstructor._get_container_class(args[0][0])
@@ -50,6 +50,8 @@ class InterviewsConstructor:
50
50
  },
51
51
  )
52
52
 
53
+
53
54
  if __name__ == "__main__":
55
+ #test_gc()
54
56
  import doctest
55
57
  doctest.testmod()
@@ -30,6 +30,8 @@ class JobsInfo:
30
30
  error_report_url: str = None
31
31
  results_uuid: str = None
32
32
  results_url: str = None
33
+ completed_interviews: int = None
34
+ failed_interviews: int = None
33
35
 
34
36
  pretty_names = {
35
37
  "job_uuid": "Job UUID",
@@ -53,6 +55,8 @@ class JobLogger(ABC):
53
55
  "error_report_url",
54
56
  "results_uuid",
55
57
  "results_url",
58
+ "completed_interviews",
59
+ "failed_interviews",
56
60
  ],
57
61
  value: str,
58
62
  ):
@@ -10,7 +10,7 @@ from typing import Any, Dict, Optional, TYPE_CHECKING
10
10
  from uuid import UUID
11
11
 
12
12
  if TYPE_CHECKING:
13
- from .jobs_runner_asyncio import JobsRunnerAsyncio
13
+ from .jobs import Jobs
14
14
 
15
15
 
16
16
  @dataclass
@@ -65,14 +65,14 @@ class StatisticsTracker:
65
65
  class JobsRunnerStatusBase(ABC):
66
66
  def __init__(
67
67
  self,
68
- jobs_runner: "JobsRunnerAsyncio",
68
+ jobs: "Jobs",
69
69
  n: int,
70
70
  refresh_rate: float = 1,
71
71
  endpoint_url: Optional[str] = "http://localhost:8000",
72
72
  job_uuid: Optional[UUID] = None,
73
73
  api_key: str = None,
74
74
  ):
75
- self.jobs_runner = jobs_runner
75
+ self.jobs = jobs
76
76
  self.job_uuid = job_uuid
77
77
  self.base_url = f"{endpoint_url}"
78
78
  self.refresh_rate = refresh_rate
@@ -86,10 +86,10 @@ class JobsRunnerStatusBase(ABC):
86
86
  "unfixed_exceptions",
87
87
  "throughput",
88
88
  ]
89
- self.num_total_interviews = n * len(self.jobs_runner)
89
+ self.num_total_interviews = n * len(self.jobs)
90
90
 
91
91
  self.distinct_models = list(
92
- set(model.model for model in self.jobs_runner.jobs.models)
92
+ set(model.model for model in self.jobs.models)
93
93
  )
94
94
 
95
95
  self.stats_tracker = StatisticsTracker(
@@ -151,26 +151,31 @@ class JobsRunnerStatusBase(ABC):
151
151
  }
152
152
 
153
153
  model_queues = {}
154
- # for model, bucket in self.jobs_runner.bucket_collection.items():
155
- for model, bucket in self.jobs_runner.environment.bucket_collection.items():
156
- model_name = model.model
157
- model_queues[model_name] = {
158
- "language_model_name": model_name,
159
- "requests_bucket": {
160
- "completed": bucket.requests_bucket.num_released,
161
- "requested": bucket.requests_bucket.num_requests,
162
- "tokens_returned": bucket.requests_bucket.tokens_returned,
163
- "target_rate": round(bucket.requests_bucket.target_rate, 1),
164
- "current_rate": round(bucket.requests_bucket.get_throughput(), 1),
165
- },
166
- "tokens_bucket": {
167
- "completed": bucket.tokens_bucket.num_released,
168
- "requested": bucket.tokens_bucket.num_requests,
169
- "tokens_returned": bucket.tokens_bucket.tokens_returned,
170
- "target_rate": round(bucket.tokens_bucket.target_rate, 1),
171
- "current_rate": round(bucket.tokens_bucket.get_throughput(), 1),
172
- },
173
- }
154
+ # Check if bucket collection exists and is not empty
155
+ if (hasattr(self.jobs, 'run_config') and
156
+ hasattr(self.jobs.run_config, 'environment') and
157
+ hasattr(self.jobs.run_config.environment, 'bucket_collection') and
158
+ self.jobs.run_config.environment.bucket_collection):
159
+
160
+ for model, bucket in self.jobs.run_config.environment.bucket_collection.items():
161
+ model_name = model.model
162
+ model_queues[model_name] = {
163
+ "language_model_name": model_name,
164
+ "requests_bucket": {
165
+ "completed": bucket.requests_bucket.num_released,
166
+ "requested": bucket.requests_bucket.num_requests,
167
+ "tokens_returned": bucket.requests_bucket.tokens_returned,
168
+ "target_rate": round(bucket.requests_bucket.target_rate, 1),
169
+ "current_rate": round(bucket.requests_bucket.get_throughput(), 1),
170
+ },
171
+ "tokens_bucket": {
172
+ "completed": bucket.tokens_bucket.num_released,
173
+ "requested": bucket.tokens_bucket.num_requests,
174
+ "tokens_returned": bucket.tokens_bucket.tokens_returned,
175
+ "target_rate": round(bucket.tokens_bucket.target_rate, 1),
176
+ "current_rate": round(bucket.tokens_bucket.get_throughput(), 1),
177
+ },
178
+ }
174
179
  status_dict["language_model_queues"] = model_queues
175
180
  return status_dict
176
181
 
@@ -0,0 +1,79 @@
1
+ """
2
+ Progress bar management for asynchronous job execution.
3
+
4
+ This module provides a context manager for handling progress bar setup and thread
5
+ management during job execution. It coordinates the display and updating of progress
6
+ bars, particularly for remote tracking via the Expected Parrot API.
7
+ """
8
+
9
+ import threading
10
+ import warnings
11
+
12
+ from ..coop import Coop
13
+ from .jobs_runner_status import JobsRunnerStatus
14
+
15
+
16
+ class ProgressBarManager:
17
+ """Context manager for handling progress bar setup and thread management.
18
+
19
+ This class manages the progress bar display and updating during job execution,
20
+ particularly for remote tracking via the Expected Parrot API.
21
+
22
+ It handles:
23
+ 1. Setting up a status tracking object
24
+ 2. Creating and managing a background thread for progress updates
25
+ 3. Properly cleaning up resources when execution completes
26
+ """
27
+
28
+ def __init__(self, jobs, run_config, parameters):
29
+ self.parameters = parameters
30
+ self.jobs = jobs
31
+
32
+ # Set up progress tracking
33
+ coop = Coop()
34
+ endpoint_url = coop.get_progress_bar_url()
35
+
36
+ # Set up jobs status object
37
+ params = {
38
+ "jobs": jobs,
39
+ "n": parameters.n,
40
+ "endpoint_url": endpoint_url,
41
+ "job_uuid": parameters.job_uuid,
42
+ }
43
+
44
+ # If the jobs_runner_status is already set, use it directly
45
+ if run_config.environment.jobs_runner_status is not None:
46
+ self.jobs_runner_status = run_config.environment.jobs_runner_status
47
+ else:
48
+ # Otherwise create a new one
49
+ self.jobs_runner_status = JobsRunnerStatus(**params)
50
+
51
+ # Store on run_config for use by other components
52
+ run_config.environment.jobs_runner_status = self.jobs_runner_status
53
+
54
+ self.progress_thread = None
55
+ self.stop_event = threading.Event()
56
+
57
+ def __enter__(self):
58
+ if self.parameters.progress_bar and self.jobs_runner_status.has_ep_api_key():
59
+ self.jobs_runner_status.setup()
60
+ self.progress_thread = threading.Thread(
61
+ target=self._run_progress_bar,
62
+ args=(self.stop_event, self.jobs_runner_status)
63
+ )
64
+ self.progress_thread.start()
65
+ elif self.parameters.progress_bar:
66
+ warnings.warn(
67
+ "You need an Expected Parrot API key to view job progress bars."
68
+ )
69
+ return self.stop_event
70
+
71
+ def __exit__(self, exc_type, exc_val, exc_tb):
72
+ self.stop_event.set()
73
+ if self.progress_thread is not None:
74
+ self.progress_thread.join()
75
+
76
+ @staticmethod
77
+ def _run_progress_bar(stop_event, jobs_runner_status):
78
+ """Runs the progress bar in a separate thread."""
79
+ jobs_runner_status.update_progress(stop_event)