edsl 0.1.38.dev4__py3-none-any.whl → 0.1.39__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (212)
  1. edsl/Base.py +197 -116
  2. edsl/__init__.py +15 -7
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +351 -147
  5. edsl/agents/AgentList.py +211 -73
  6. edsl/agents/Invigilator.py +101 -50
  7. edsl/agents/InvigilatorBase.py +62 -70
  8. edsl/agents/PromptConstructor.py +143 -225
  9. edsl/agents/QuestionInstructionPromptBuilder.py +128 -0
  10. edsl/agents/QuestionTemplateReplacementsBuilder.py +137 -0
  11. edsl/agents/__init__.py +0 -1
  12. edsl/agents/prompt_helpers.py +3 -3
  13. edsl/agents/question_option_processor.py +172 -0
  14. edsl/auto/AutoStudy.py +18 -5
  15. edsl/auto/StageBase.py +53 -40
  16. edsl/auto/StageQuestions.py +2 -1
  17. edsl/auto/utilities.py +0 -6
  18. edsl/config.py +22 -2
  19. edsl/conversation/car_buying.py +2 -1
  20. edsl/coop/CoopFunctionsMixin.py +15 -0
  21. edsl/coop/ExpectedParrotKeyHandler.py +125 -0
  22. edsl/coop/PriceFetcher.py +1 -1
  23. edsl/coop/coop.py +125 -47
  24. edsl/coop/utils.py +14 -14
  25. edsl/data/Cache.py +45 -27
  26. edsl/data/CacheEntry.py +12 -15
  27. edsl/data/CacheHandler.py +31 -12
  28. edsl/data/RemoteCacheSync.py +154 -46
  29. edsl/data/__init__.py +4 -3
  30. edsl/data_transfer_models.py +2 -1
  31. edsl/enums.py +27 -0
  32. edsl/exceptions/__init__.py +50 -50
  33. edsl/exceptions/agents.py +12 -0
  34. edsl/exceptions/inference_services.py +5 -0
  35. edsl/exceptions/questions.py +24 -6
  36. edsl/exceptions/scenarios.py +7 -0
  37. edsl/inference_services/AnthropicService.py +38 -19
  38. edsl/inference_services/AvailableModelCacheHandler.py +184 -0
  39. edsl/inference_services/AvailableModelFetcher.py +215 -0
  40. edsl/inference_services/AwsBedrock.py +0 -2
  41. edsl/inference_services/AzureAI.py +0 -2
  42. edsl/inference_services/GoogleService.py +7 -12
  43. edsl/inference_services/InferenceServiceABC.py +18 -85
  44. edsl/inference_services/InferenceServicesCollection.py +120 -79
  45. edsl/inference_services/MistralAIService.py +0 -3
  46. edsl/inference_services/OpenAIService.py +47 -35
  47. edsl/inference_services/PerplexityService.py +0 -3
  48. edsl/inference_services/ServiceAvailability.py +135 -0
  49. edsl/inference_services/TestService.py +11 -10
  50. edsl/inference_services/TogetherAIService.py +5 -3
  51. edsl/inference_services/data_structures.py +134 -0
  52. edsl/jobs/AnswerQuestionFunctionConstructor.py +223 -0
  53. edsl/jobs/Answers.py +1 -14
  54. edsl/jobs/FetchInvigilator.py +47 -0
  55. edsl/jobs/InterviewTaskManager.py +98 -0
  56. edsl/jobs/InterviewsConstructor.py +50 -0
  57. edsl/jobs/Jobs.py +356 -431
  58. edsl/jobs/JobsChecks.py +35 -10
  59. edsl/jobs/JobsComponentConstructor.py +189 -0
  60. edsl/jobs/JobsPrompts.py +6 -4
  61. edsl/jobs/JobsRemoteInferenceHandler.py +205 -133
  62. edsl/jobs/JobsRemoteInferenceLogger.py +239 -0
  63. edsl/jobs/RequestTokenEstimator.py +30 -0
  64. edsl/jobs/async_interview_runner.py +138 -0
  65. edsl/jobs/buckets/BucketCollection.py +44 -3
  66. edsl/jobs/buckets/TokenBucket.py +53 -21
  67. edsl/jobs/buckets/TokenBucketAPI.py +211 -0
  68. edsl/jobs/buckets/TokenBucketClient.py +191 -0
  69. edsl/jobs/check_survey_scenario_compatibility.py +85 -0
  70. edsl/jobs/data_structures.py +120 -0
  71. edsl/jobs/decorators.py +35 -0
  72. edsl/jobs/interviews/Interview.py +143 -408
  73. edsl/jobs/jobs_status_enums.py +9 -0
  74. edsl/jobs/loggers/HTMLTableJobLogger.py +304 -0
  75. edsl/jobs/results_exceptions_handler.py +98 -0
  76. edsl/jobs/runners/JobsRunnerAsyncio.py +88 -403
  77. edsl/jobs/runners/JobsRunnerStatus.py +133 -165
  78. edsl/jobs/tasks/QuestionTaskCreator.py +21 -19
  79. edsl/jobs/tasks/TaskHistory.py +38 -18
  80. edsl/jobs/tasks/task_status_enum.py +0 -2
  81. edsl/language_models/ComputeCost.py +63 -0
  82. edsl/language_models/LanguageModel.py +194 -236
  83. edsl/language_models/ModelList.py +28 -19
  84. edsl/language_models/PriceManager.py +127 -0
  85. edsl/language_models/RawResponseHandler.py +106 -0
  86. edsl/language_models/ServiceDataSources.py +0 -0
  87. edsl/language_models/__init__.py +1 -2
  88. edsl/language_models/key_management/KeyLookup.py +63 -0
  89. edsl/language_models/key_management/KeyLookupBuilder.py +273 -0
  90. edsl/language_models/key_management/KeyLookupCollection.py +38 -0
  91. edsl/language_models/key_management/__init__.py +0 -0
  92. edsl/language_models/key_management/models.py +131 -0
  93. edsl/language_models/model.py +256 -0
  94. edsl/language_models/repair.py +2 -2
  95. edsl/language_models/utilities.py +5 -4
  96. edsl/notebooks/Notebook.py +19 -14
  97. edsl/notebooks/NotebookToLaTeX.py +142 -0
  98. edsl/prompts/Prompt.py +29 -39
  99. edsl/questions/ExceptionExplainer.py +77 -0
  100. edsl/questions/HTMLQuestion.py +103 -0
  101. edsl/questions/QuestionBase.py +68 -214
  102. edsl/questions/QuestionBasePromptsMixin.py +7 -3
  103. edsl/questions/QuestionBudget.py +1 -1
  104. edsl/questions/QuestionCheckBox.py +3 -3
  105. edsl/questions/QuestionExtract.py +5 -7
  106. edsl/questions/QuestionFreeText.py +2 -3
  107. edsl/questions/QuestionList.py +10 -18
  108. edsl/questions/QuestionMatrix.py +265 -0
  109. edsl/questions/QuestionMultipleChoice.py +67 -23
  110. edsl/questions/QuestionNumerical.py +2 -4
  111. edsl/questions/QuestionRank.py +7 -17
  112. edsl/questions/SimpleAskMixin.py +4 -3
  113. edsl/questions/__init__.py +2 -1
  114. edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +47 -2
  115. edsl/questions/data_structures.py +20 -0
  116. edsl/questions/derived/QuestionLinearScale.py +6 -3
  117. edsl/questions/derived/QuestionTopK.py +1 -1
  118. edsl/questions/descriptors.py +17 -3
  119. edsl/questions/loop_processor.py +149 -0
  120. edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +57 -50
  121. edsl/questions/question_registry.py +1 -1
  122. edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +40 -26
  123. edsl/questions/response_validator_factory.py +34 -0
  124. edsl/questions/templates/matrix/__init__.py +1 -0
  125. edsl/questions/templates/matrix/answering_instructions.jinja +5 -0
  126. edsl/questions/templates/matrix/question_presentation.jinja +20 -0
  127. edsl/results/CSSParameterizer.py +1 -1
  128. edsl/results/Dataset.py +170 -7
  129. edsl/results/DatasetExportMixin.py +168 -305
  130. edsl/results/DatasetTree.py +28 -8
  131. edsl/results/MarkdownToDocx.py +122 -0
  132. edsl/results/MarkdownToPDF.py +111 -0
  133. edsl/results/Result.py +298 -206
  134. edsl/results/Results.py +149 -131
  135. edsl/results/ResultsExportMixin.py +2 -0
  136. edsl/results/TableDisplay.py +98 -171
  137. edsl/results/TextEditor.py +50 -0
  138. edsl/results/__init__.py +1 -1
  139. edsl/results/file_exports.py +252 -0
  140. edsl/results/{Selector.py → results_selector.py} +23 -13
  141. edsl/results/smart_objects.py +96 -0
  142. edsl/results/table_data_class.py +12 -0
  143. edsl/results/table_renderers.py +118 -0
  144. edsl/scenarios/ConstructDownloadLink.py +109 -0
  145. edsl/scenarios/DocumentChunker.py +102 -0
  146. edsl/scenarios/DocxScenario.py +16 -0
  147. edsl/scenarios/FileStore.py +150 -239
  148. edsl/scenarios/PdfExtractor.py +40 -0
  149. edsl/scenarios/Scenario.py +90 -193
  150. edsl/scenarios/ScenarioHtmlMixin.py +4 -3
  151. edsl/scenarios/ScenarioList.py +415 -244
  152. edsl/scenarios/ScenarioListExportMixin.py +0 -7
  153. edsl/scenarios/ScenarioListPdfMixin.py +15 -37
  154. edsl/scenarios/__init__.py +1 -2
  155. edsl/scenarios/directory_scanner.py +96 -0
  156. edsl/scenarios/file_methods.py +85 -0
  157. edsl/scenarios/handlers/__init__.py +13 -0
  158. edsl/scenarios/handlers/csv.py +49 -0
  159. edsl/scenarios/handlers/docx.py +76 -0
  160. edsl/scenarios/handlers/html.py +37 -0
  161. edsl/scenarios/handlers/json.py +111 -0
  162. edsl/scenarios/handlers/latex.py +5 -0
  163. edsl/scenarios/handlers/md.py +51 -0
  164. edsl/scenarios/handlers/pdf.py +68 -0
  165. edsl/scenarios/handlers/png.py +39 -0
  166. edsl/scenarios/handlers/pptx.py +105 -0
  167. edsl/scenarios/handlers/py.py +294 -0
  168. edsl/scenarios/handlers/sql.py +313 -0
  169. edsl/scenarios/handlers/sqlite.py +149 -0
  170. edsl/scenarios/handlers/txt.py +33 -0
  171. edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +10 -6
  172. edsl/scenarios/scenario_selector.py +156 -0
  173. edsl/study/ObjectEntry.py +1 -1
  174. edsl/study/SnapShot.py +1 -1
  175. edsl/study/Study.py +5 -12
  176. edsl/surveys/ConstructDAG.py +92 -0
  177. edsl/surveys/EditSurvey.py +221 -0
  178. edsl/surveys/InstructionHandler.py +100 -0
  179. edsl/surveys/MemoryManagement.py +72 -0
  180. edsl/surveys/Rule.py +5 -4
  181. edsl/surveys/RuleCollection.py +25 -27
  182. edsl/surveys/RuleManager.py +172 -0
  183. edsl/surveys/Simulator.py +75 -0
  184. edsl/surveys/Survey.py +270 -791
  185. edsl/surveys/SurveyCSS.py +20 -8
  186. edsl/surveys/{SurveyFlowVisualizationMixin.py → SurveyFlowVisualization.py} +11 -9
  187. edsl/surveys/SurveyToApp.py +141 -0
  188. edsl/surveys/__init__.py +4 -2
  189. edsl/surveys/descriptors.py +6 -2
  190. edsl/surveys/instructions/ChangeInstruction.py +1 -2
  191. edsl/surveys/instructions/Instruction.py +4 -13
  192. edsl/surveys/instructions/InstructionCollection.py +11 -6
  193. edsl/templates/error_reporting/interview_details.html +1 -1
  194. edsl/templates/error_reporting/report.html +1 -1
  195. edsl/tools/plotting.py +1 -1
  196. edsl/utilities/PrettyList.py +56 -0
  197. edsl/utilities/is_notebook.py +18 -0
  198. edsl/utilities/is_valid_variable_name.py +11 -0
  199. edsl/utilities/remove_edsl_version.py +24 -0
  200. edsl/utilities/utilities.py +35 -23
  201. {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/METADATA +12 -10
  202. edsl-0.1.39.dist-info/RECORD +358 -0
  203. {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/WHEEL +1 -1
  204. edsl/language_models/KeyLookup.py +0 -30
  205. edsl/language_models/registry.py +0 -190
  206. edsl/language_models/unused/ReplicateBase.py +0 -83
  207. edsl/results/ResultsDBMixin.py +0 -238
  208. edsl-0.1.38.dev4.dist-info/RECORD +0 -277
  209. /edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +0 -0
  210. /edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +0 -0
  211. /edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +0 -0
  212. {edsl-0.1.38.dev4.dist-info → edsl-0.1.39.dist-info}/LICENSE +0 -0
@@ -3,21 +3,12 @@ from __future__ import annotations
3
3
  import os
4
4
  import time
5
5
  import requests
6
- import warnings
7
6
  from abc import ABC, abstractmethod
8
7
  from dataclasses import dataclass
9
-
10
- from typing import Any, List, DefaultDict, Optional, Dict
11
8
  from collections import defaultdict
9
+ from typing import Any, Dict, Optional
12
10
  from uuid import UUID
13
11
 
14
- from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
15
-
16
- InterviewTokenUsageMapping = DefaultDict[str, InterviewTokenUsage]
17
-
18
- from edsl.jobs.interviews.InterviewStatistic import InterviewStatistic
19
- from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
20
-
21
12
 
22
13
  @dataclass
23
14
  class ModelInfo:
@@ -28,11 +19,44 @@ class ModelInfo:
28
19
  token_usage_info: dict
29
20
 
30
21
 
31
- @dataclass
32
- class ModelTokenUsageStats:
33
- token_usage_type: str
34
- details: List[dict]
35
- cost: str
22
+ class StatisticsTracker:
23
+ def __init__(self, total_interviews: int, distinct_models: list[str]):
24
+ self.start_time = time.time()
25
+ self.total_interviews = total_interviews
26
+ self.completed_count = 0
27
+ self.completed_by_model = defaultdict(int)
28
+ self.distinct_models = distinct_models
29
+ self.total_exceptions = 0
30
+ self.unfixed_exceptions = 0
31
+
32
+ def add_completed_interview(
33
+ self, model: str, num_exceptions: int = 0, num_unfixed: int = 0
34
+ ):
35
+ self.completed_count += 1
36
+ self.completed_by_model[model] += 1
37
+ self.total_exceptions += num_exceptions
38
+ self.unfixed_exceptions += num_unfixed
39
+
40
+ def get_elapsed_time(self) -> float:
41
+ return time.time() - self.start_time
42
+
43
+ def get_average_time_per_interview(self) -> float:
44
+ return (
45
+ self.get_elapsed_time() / self.completed_count
46
+ if self.completed_count > 0
47
+ else 0
48
+ )
49
+
50
+ def get_throughput(self) -> float:
51
+ elapsed = self.get_elapsed_time()
52
+ return self.completed_count / elapsed if elapsed > 0 else 0
53
+
54
+ def get_estimated_time_remaining(self) -> float:
55
+ if self.completed_count == 0:
56
+ return 0
57
+ avg_time = self.get_average_time_per_interview()
58
+ remaining = self.total_interviews - self.completed_count
59
+ return avg_time * remaining
36
60
 
37
61
 
38
62
  class JobsRunnerStatusBase(ABC):
@@ -46,48 +70,39 @@ class JobsRunnerStatusBase(ABC):
46
70
  api_key: str = None,
47
71
  ):
48
72
  self.jobs_runner = jobs_runner
49
-
50
- # The uuid of the job on Coop
51
73
  self.job_uuid = job_uuid
52
-
53
74
  self.base_url = f"{endpoint_url}"
54
-
55
- self.start_time = time.time()
56
- self.completed_interviews = []
57
75
  self.refresh_rate = refresh_rate
58
76
  self.statistics = [
59
77
  "elapsed_time",
60
78
  "total_interviews_requested",
61
79
  "completed_interviews",
62
- # "percent_complete",
63
80
  "average_time_per_interview",
64
- # "task_remaining",
65
81
  "estimated_time_remaining",
66
82
  "exceptions",
67
83
  "unfixed_exceptions",
68
84
  "throughput",
69
85
  ]
70
- self.num_total_interviews = n * len(self.jobs_runner.interviews)
86
+ self.num_total_interviews = n * len(self.jobs_runner)
71
87
 
72
88
  self.distinct_models = list(
73
- set(i.model.model for i in self.jobs_runner.interviews)
89
+ set(model.model for model in self.jobs_runner.jobs.models)
74
90
  )
75
91
 
76
- self.completed_interview_by_model = defaultdict(list)
92
+ self.stats_tracker = StatisticsTracker(
93
+ total_interviews=self.num_total_interviews,
94
+ distinct_models=self.distinct_models,
95
+ )
77
96
 
78
97
  self.api_key = api_key or os.getenv("EXPECTED_PARROT_API_KEY")
79
98
 
80
99
  @abstractmethod
81
100
  def has_ep_api_key(self):
82
- """
83
- Checks if the user has an Expected Parrot API key.
84
- """
101
+ """Checks if the user has an Expected Parrot API key."""
85
102
  pass
86
103
 
87
104
  def get_status_dict(self) -> Dict[str, Any]:
88
- """
89
- Converts current status into a JSON-serializable dictionary.
90
- """
105
+ """Converts current status into a JSON-serializable dictionary."""
91
106
  # Get all statistics
92
107
  stats = {}
93
108
  for stat_name in self.statistics:
@@ -95,42 +110,46 @@ class JobsRunnerStatusBase(ABC):
95
110
  name, value = list(stat.items())[0]
96
111
  stats[name] = value
97
112
 
98
- # Calculate overall progress
99
- total_interviews = len(self.jobs_runner.total_interviews)
100
- completed = len(self.completed_interviews)
101
-
102
113
  # Get model-specific progress
103
114
  model_progress = {}
115
+ target_per_model = int(self.num_total_interviews / len(self.distinct_models))
116
+
104
117
  for model in self.distinct_models:
105
- completed_for_model = len(self.completed_interview_by_model[model])
106
- target_for_model = int(
107
- self.num_total_interviews / len(self.distinct_models)
108
- )
118
+ completed = self.stats_tracker.completed_by_model[model]
109
119
  model_progress[model] = {
110
- "completed": completed_for_model,
111
- "total": target_for_model,
120
+ "completed": completed,
121
+ "total": target_per_model,
112
122
  "percent": (
113
- (completed_for_model / target_for_model * 100)
114
- if target_for_model > 0
115
- else 0
123
+ (completed / target_per_model * 100) if target_per_model > 0 else 0
116
124
  ),
117
125
  }
118
126
 
119
127
  status_dict = {
120
128
  "overall_progress": {
121
- "completed": completed,
122
- "total": total_interviews,
129
+ "completed": self.stats_tracker.completed_count,
130
+ "total": self.num_total_interviews,
123
131
  "percent": (
124
- (completed / total_interviews * 100) if total_interviews > 0 else 0
132
+ (
133
+ self.stats_tracker.completed_count
134
+ / self.num_total_interviews
135
+ * 100
136
+ )
137
+ if self.num_total_interviews > 0
138
+ else 0
125
139
  ),
126
140
  },
127
141
  "language_model_progress": model_progress,
128
142
  "statistics": stats,
129
- "status": "completed" if completed >= total_interviews else "running",
143
+ "status": (
144
+ "completed"
145
+ if self.stats_tracker.completed_count >= self.num_total_interviews
146
+ else "running"
147
+ ),
130
148
  }
131
149
 
132
150
  model_queues = {}
133
- for model, bucket in self.jobs_runner.bucket_collection.items():
151
+ # for model, bucket in self.jobs_runner.bucket_collection.items():
152
+ for model, bucket in self.jobs_runner.environment.bucket_collection.items():
134
153
  model_name = model.model
135
154
  model_queues[model_name] = {
136
155
  "language_model_name": model_name,
@@ -152,99 +171,68 @@ class JobsRunnerStatusBase(ABC):
152
171
  status_dict["language_model_queues"] = model_queues
153
172
  return status_dict
154
173
 
155
- @abstractmethod
156
- def setup(self):
157
- """
158
- Conducts any setup that needs to happen prior to sending status updates.
174
+ def add_completed_interview(self, result):
175
+ """Records a completed interview without storing the full interview data."""
176
+ self.stats_tracker.add_completed_interview(
177
+ model=result.model.model,
178
+ num_exceptions=(
179
+ len(result.exceptions) if hasattr(result, "exceptions") else 0
180
+ ),
181
+ num_unfixed=(
182
+ result.exceptions.num_unfixed() if hasattr(result, "exceptions") else 0
183
+ ),
184
+ )
159
185
 
160
- Ex. For a local job, creates a job in the Coop database.
161
- """
162
- pass
186
+ def _compute_statistic(self, stat_name: str):
187
+ """Computes individual statistics based on the stats tracker."""
188
+ if stat_name == "elapsed_time":
189
+ value = self.stats_tracker.get_elapsed_time()
190
+ return {"elapsed_time": (value, 1, "sec.")}
163
191
 
164
- @abstractmethod
165
- def send_status_update(self):
166
- """
167
- Updates the current status of the job.
168
- """
169
- pass
192
+ elif stat_name == "total_interviews_requested":
193
+ return {"total_interviews_requested": (self.num_total_interviews, None, "")}
170
194
 
171
- def add_completed_interview(self, result):
172
- self.completed_interviews.append(result.interview_hash)
195
+ elif stat_name == "completed_interviews":
196
+ return {
197
+ "completed_interviews": (self.stats_tracker.completed_count, None, "")
198
+ }
173
199
 
174
- relevant_model = result.model.model
175
- self.completed_interview_by_model[relevant_model].append(result.interview_hash)
200
+ elif stat_name == "average_time_per_interview":
201
+ value = self.stats_tracker.get_average_time_per_interview()
202
+ return {"average_time_per_interview": (value, 2, "sec.")}
176
203
 
177
- def _compute_statistic(self, stat_name: str):
178
- completed_tasks = self.completed_interviews
179
- elapsed_time = time.time() - self.start_time
180
- interviews = self.jobs_runner.total_interviews
204
+ elif stat_name == "estimated_time_remaining":
205
+ value = self.stats_tracker.get_estimated_time_remaining()
206
+ return {"estimated_time_remaining": (value, 1, "sec.")}
181
207
 
182
- stat_definitions = {
183
- "elapsed_time": lambda: InterviewStatistic(
184
- "elapsed_time", value=elapsed_time, digits=1, units="sec."
185
- ),
186
- "total_interviews_requested": lambda: InterviewStatistic(
187
- "total_interviews_requested", value=len(interviews), units=""
188
- ),
189
- "completed_interviews": lambda: InterviewStatistic(
190
- "completed_interviews", value=len(completed_tasks), units=""
191
- ),
192
- "percent_complete": lambda: InterviewStatistic(
193
- "percent_complete",
194
- value=(
195
- len(completed_tasks) / len(interviews) * 100
196
- if len(interviews) > 0
197
- else 0
198
- ),
199
- digits=1,
200
- units="%",
201
- ),
202
- "average_time_per_interview": lambda: InterviewStatistic(
203
- "average_time_per_interview",
204
- value=elapsed_time / len(completed_tasks) if completed_tasks else 0,
205
- digits=2,
206
- units="sec.",
207
- ),
208
- "task_remaining": lambda: InterviewStatistic(
209
- "task_remaining", value=len(interviews) - len(completed_tasks), units=""
210
- ),
211
- "estimated_time_remaining": lambda: InterviewStatistic(
212
- "estimated_time_remaining",
213
- value=(
214
- (len(interviews) - len(completed_tasks))
215
- * (elapsed_time / len(completed_tasks))
216
- if len(completed_tasks) > 0
217
- else 0
218
- ),
219
- digits=1,
220
- units="sec.",
221
- ),
222
- "exceptions": lambda: InterviewStatistic(
223
- "exceptions",
224
- value=sum(len(i.exceptions) for i in interviews),
225
- units="",
226
- ),
227
- "unfixed_exceptions": lambda: InterviewStatistic(
228
- "unfixed_exceptions",
229
- value=sum(i.exceptions.num_unfixed() for i in interviews),
230
- units="",
231
- ),
232
- "throughput": lambda: InterviewStatistic(
233
- "throughput",
234
- value=len(completed_tasks) / elapsed_time if elapsed_time > 0 else 0,
235
- digits=2,
236
- units="interviews/sec.",
237
- ),
238
- }
239
- return stat_definitions[stat_name]()
208
+ elif stat_name == "exceptions":
209
+ return {"exceptions": (self.stats_tracker.total_exceptions, None, "")}
210
+
211
+ elif stat_name == "unfixed_exceptions":
212
+ return {
213
+ "unfixed_exceptions": (self.stats_tracker.unfixed_exceptions, None, "")
214
+ }
215
+
216
+ elif stat_name == "throughput":
217
+ value = self.stats_tracker.get_throughput()
218
+ return {"throughput": (value, 2, "interviews/sec.")}
240
219
 
241
220
  def update_progress(self, stop_event):
242
221
  while not stop_event.is_set():
243
222
  self.send_status_update()
244
223
  time.sleep(self.refresh_rate)
245
-
246
224
  self.send_status_update()
247
225
 
226
+ @abstractmethod
227
+ def setup(self):
228
+ """Conducts any setup needed prior to sending status updates."""
229
+ pass
230
+
231
+ @abstractmethod
232
+ def send_status_update(self):
233
+ """Updates the current status of the job."""
234
+ pass
235
+
248
236
 
249
237
  class JobsRunnerStatus(JobsRunnerStatusBase):
250
238
  @property
@@ -260,49 +248,35 @@ class JobsRunnerStatus(JobsRunnerStatusBase):
260
248
  return f"{self.base_url}/api/v0/local-job/{str(self.job_uuid)}"
261
249
 
262
250
  def setup(self) -> None:
263
- """
264
- Creates a local job on Coop if one does not already exist.
265
- """
266
-
267
- headers = {"Content-Type": "application/json"}
268
-
269
- if self.api_key:
270
- headers["Authorization"] = f"Bearer {self.api_key}"
271
- else:
272
- headers["Authorization"] = f"Bearer None"
251
+ """Creates a local job on Coop if one does not already exist."""
252
+ headers = {
253
+ "Content-Type": "application/json",
254
+ "Authorization": f"Bearer {self.api_key or 'None'}",
255
+ }
273
256
 
274
257
  if self.job_uuid is None:
275
- # Create a new local job
276
258
  response = requests.post(
277
259
  self.create_url,
278
260
  headers=headers,
279
261
  timeout=1,
280
262
  )
281
- response.raise_for_status()
282
- data = response.json()
283
- self.job_uuid = data.get("job_uuid")
263
+ response.raise_for_status()
264
+ data = response.json()
265
+ self.job_uuid = data.get("job_uuid")
284
266
 
285
267
  print(f"Running with progress bar. View progress at {self.viewing_url}")
286
268
 
287
269
  def send_status_update(self) -> None:
288
- """
289
- Sends current status to the web endpoint using the instance's job_uuid.
290
- """
270
+ """Sends current status to the web endpoint using the instance's job_uuid."""
291
271
  try:
292
- # Get the status dictionary and add the job_id
293
272
  status_dict = self.get_status_dict()
294
-
295
- # Make the UUID JSON serializable
296
273
  status_dict["job_id"] = str(self.job_uuid)
297
274
 
298
- headers = {"Content-Type": "application/json"}
299
-
300
- if self.api_key:
301
- headers["Authorization"] = f"Bearer {self.api_key}"
302
- else:
303
- headers["Authorization"] = f"Bearer None"
275
+ headers = {
276
+ "Content-Type": "application/json",
277
+ "Authorization": f"Bearer {self.api_key or 'None'}",
278
+ }
304
279
 
305
- # Send the update
306
280
  response = requests.patch(
307
281
  self.update_url,
308
282
  json=status_dict,
@@ -314,14 +288,8 @@ class JobsRunnerStatus(JobsRunnerStatusBase):
314
288
  print(f"Failed to send status update for job {self.job_uuid}: {e}")
315
289
 
316
290
  def has_ep_api_key(self) -> bool:
317
- """
318
- Returns True if the user has an Expected Parrot API key. Otherwise, returns False.
319
- """
320
-
321
- if self.api_key is not None:
322
- return True
323
- else:
324
- return False
291
+ """Returns True if the user has an Expected Parrot API key."""
292
+ return self.api_key is not None
325
293
 
326
294
 
327
295
  if __name__ == "__main__":
@@ -1,17 +1,17 @@
1
1
  import asyncio
2
- from typing import Callable, Union, List
2
+ from typing import Callable, Union, List, TYPE_CHECKING
3
3
  from collections import UserList, UserDict
4
4
 
5
- from edsl.jobs.buckets import ModelBuckets
6
- from edsl.exceptions import InterviewErrorPriorTaskCanceled
5
+ from edsl.exceptions.jobs import InterviewErrorPriorTaskCanceled
7
6
 
8
- from edsl.jobs.interviews.InterviewStatusDictionary import InterviewStatusDictionary
9
7
  from edsl.jobs.tasks.task_status_enum import TaskStatus, TaskStatusDescriptor
10
8
  from edsl.jobs.tasks.TaskStatusLog import TaskStatusLog
11
- from edsl.jobs.tokens.InterviewTokenUsage import InterviewTokenUsage
12
9
  from edsl.jobs.tokens.TokenUsage import TokenUsage
13
10
  from edsl.jobs.Answers import Answers
14
- from edsl.questions.QuestionBase import QuestionBase
11
+
12
+ if TYPE_CHECKING:
13
+ from edsl.questions.QuestionBase import QuestionBase
14
+ from edsl.jobs.buckets import ModelBuckets
15
15
 
16
16
 
17
17
  class TokensUsed(UserDict):
@@ -24,7 +24,6 @@ class TokensUsed(UserDict):
24
24
 
25
25
  class QuestionTaskCreator(UserList):
26
26
  """Class to create and manage a single question and its dependencies.
27
- The class is an instance of a UserList of tasks that must be completed before the focal task can be run.
28
27
 
29
28
  It is a UserList with all the tasks that must be completed before the focal task can be run.
30
29
  The focal task is the question that we are interested in answering.
@@ -35,9 +34,9 @@ class QuestionTaskCreator(UserList):
35
34
  def __init__(
36
35
  self,
37
36
  *,
38
- question: QuestionBase,
37
+ question: "QuestionBase",
39
38
  answer_question_func: Callable,
40
- model_buckets: ModelBuckets,
39
+ model_buckets: "ModelBuckets",
41
40
  token_estimator: Union[Callable, None] = None,
42
41
  iteration: int = 0,
43
42
  ):
@@ -51,14 +50,15 @@ class QuestionTaskCreator(UserList):
51
50
 
52
51
  """
53
52
  super().__init__([])
54
- # answer_question_func is the 'interview.answer_question_and_record_task" method
55
53
  self.answer_question_func = answer_question_func
56
54
  self.question = question
57
55
  self.iteration = iteration
58
56
 
59
57
  self.model_buckets = model_buckets
58
+
60
59
  self.requests_bucket = self.model_buckets.requests_bucket
61
60
  self.tokens_bucket = self.model_buckets.tokens_bucket
61
+
62
62
  self.status_log = TaskStatusLog()
63
63
 
64
64
  def fake_token_estimator(question):
@@ -125,11 +125,13 @@ class QuestionTaskCreator(UserList):
125
125
 
126
126
  await self.tokens_bucket.get_tokens(requested_tokens)
127
127
 
128
- if (estimated_wait_time := self.requests_bucket.wait_time(1)) > 0:
128
+ if (estimated_wait_time := self.model_buckets.requests_bucket.wait_time(1)) > 0:
129
129
  self.waiting = True # do we need this?
130
130
  self.task_status = TaskStatus.WAITING_FOR_REQUEST_CAPACITY
131
131
 
132
- await self.requests_bucket.get_tokens(1, cheat_bucket_capacity=True)
132
+ await self.model_buckets.requests_bucket.get_tokens(
133
+ 1, cheat_bucket_capacity=True
134
+ )
133
135
 
134
136
  self.task_status = TaskStatus.API_CALL_IN_PROGRESS
135
137
  try:
@@ -142,22 +144,22 @@ class QuestionTaskCreator(UserList):
142
144
  raise e
143
145
 
144
146
  if results.cache_used:
145
- self.tokens_bucket.add_tokens(requested_tokens)
146
- self.requests_bucket.add_tokens(1)
147
+ self.model_buckets.tokens_bucket.add_tokens(requested_tokens)
148
+ self.model_buckets.requests_bucket.add_tokens(1)
147
149
  self.from_cache = True
148
150
  # Turbo mode means that we don't wait for tokens or requests.
149
- self.tokens_bucket.turbo_mode_on()
150
- self.requests_bucket.turbo_mode_on()
151
+ self.model_buckets.tokens_bucket.turbo_mode_on()
152
+ self.model_buckets.requests_bucket.turbo_mode_on()
151
153
  else:
152
- self.tokens_bucket.turbo_mode_off()
153
- self.requests_bucket.turbo_mode_off()
154
+ self.model_buckets.tokens_bucket.turbo_mode_off()
155
+ self.model_buckets.requests_bucket.turbo_mode_off()
154
156
 
155
157
  return results
156
158
 
157
159
  @classmethod
158
160
  def example(cls):
159
161
  """Return an example instance of the class."""
160
- from edsl import QuestionFreeText
162
+ from edsl.questions.QuestionFreeText import QuestionFreeText
161
163
  from edsl.jobs.buckets.ModelBuckets import ModelBuckets
162
164
 
163
165
  m = ModelBuckets.infinity_bucket()
@@ -1,18 +1,17 @@
1
1
  from typing import List, Optional
2
2
  from io import BytesIO
3
- import webbrowser
4
- import os
5
3
  import base64
6
- from importlib import resources
7
4
  from edsl.jobs.tasks.task_status_enum import TaskStatus
5
+ from edsl.Base import RepresentationMixin
8
6
 
9
7
 
10
- class TaskHistory:
8
+ class TaskHistory(RepresentationMixin):
11
9
  def __init__(
12
10
  self,
13
- interviews: List["Interview"],
11
+ interviews: List["Interview"] = None,
14
12
  include_traceback: bool = False,
15
13
  max_interviews: int = 10,
14
+ interviews_with_exceptions_only: bool = False,
16
15
  ):
17
16
  """
18
17
  The structure of a TaskHistory exception
@@ -22,13 +21,33 @@ class TaskHistory:
22
21
  >>> _ = TaskHistory.example()
23
22
  ...
24
23
  """
24
+ self.interviews_with_exceptions_only = interviews_with_exceptions_only
25
+ self._interviews = {}
26
+ self.total_interviews = []
27
+ if interviews is not None:
28
+ for interview in interviews:
29
+ self.add_interview(interview)
25
30
 
26
- self.total_interviews = interviews
31
+ self.include_traceback = include_traceback
32
+ self._interviews = {
33
+ index: interview for index, interview in enumerate(self.total_interviews)
34
+ }
35
+ self.max_interviews = max_interviews
36
+
37
+ # self.total_interviews = interviews
27
38
  self.include_traceback = include_traceback
28
39
 
29
- self._interviews = {index: i for index, i in enumerate(self.total_interviews)}
40
+ # self._interviews = {index: i for index, i in enumerate(self.total_interviews)}
30
41
  self.max_interviews = max_interviews
31
42
 
43
+ def add_interview(self, interview: "Interview"):
44
+ """Add a single interview to the history"""
45
+ if self.interviews_with_exceptions_only and interview.exceptions == {}:
46
+ return
47
+
48
+ self.total_interviews.append(interview)
49
+ self._interviews[len(self._interviews)] = interview
50
+
32
51
  @classmethod
33
52
  def example(cls):
34
53
  """ """
@@ -121,14 +140,6 @@ class TaskHistory:
121
140
  """Return True if there are any exceptions."""
122
141
  return len(self.unfixed_exceptions) > 0
123
142
 
124
- def _repr_html_(self):
125
- """Return an HTML representation of the TaskHistory."""
126
- d = self.to_dict(add_edsl_version=False)
127
- data = [[k, v] for k, v in d.items()]
128
- from tabulate import tabulate
129
-
130
- return tabulate(data, headers=["keys", "values"], tablefmt="html")
131
-
132
143
  def show_exceptions(self, tracebacks=False):
133
144
  """Print the exceptions."""
134
145
  for index in self.indices:
@@ -240,11 +251,15 @@ class TaskHistory:
240
251
  plt.show()
241
252
 
242
253
  def css(self):
254
+ from importlib import resources
255
+
243
256
  env = resources.files("edsl").joinpath("templates/error_reporting")
244
257
  css = env.joinpath("report.css").read_text()
245
258
  return css
246
259
 
247
260
  def javascript(self):
261
+ from importlib import resources
262
+
248
263
  env = resources.files("edsl").joinpath("templates/error_reporting")
249
264
  js = env.joinpath("report.js").read_text()
250
265
  return js
@@ -281,7 +296,7 @@ class TaskHistory:
281
296
  exceptions_by_question_name = {}
282
297
  for interview in self.total_interviews:
283
298
  for question_name, exceptions in interview.exceptions.items():
284
- question_type = interview.survey.get_question(
299
+ question_type = interview.survey._get_question_by_name(
285
300
  question_name
286
301
  ).question_type
287
302
  if (question_name, question_type) not in exceptions_by_question_name:
@@ -330,8 +345,11 @@ class TaskHistory:
330
345
  }
331
346
  return sorted_exceptions_by_model
332
347
 
333
- def generate_html_report(self, css: Optional[str]):
334
- performance_plot_html = self.plot(num_periods=100, get_embedded_html=True)
348
+ def generate_html_report(self, css: Optional[str], include_plot=False):
349
+ if include_plot:
350
+ performance_plot_html = self.plot(num_periods=100, get_embedded_html=True)
351
+ else:
352
+ performance_plot_html = ""
335
353
 
336
354
  if css is None:
337
355
  css = self.css()
@@ -409,6 +427,8 @@ class TaskHistory:
409
427
  print(f"Exception report saved to {filename}")
410
428
 
411
429
  if open_in_browser:
430
+ import webbrowser
431
+
412
432
  webbrowser.open(f"file://{os.path.abspath(filename)}")
413
433
 
414
434
  if return_link:
@@ -3,8 +3,6 @@ from collections import UserDict
3
3
  import enum
4
4
  import time
5
5
 
6
- # from edsl.jobs.tasks.TaskStatusLogEntry import TaskStatusLogEntry
7
-
8
6
 
9
7
  class TaskStatus(enum.Enum):
10
8
  "These are the possible states a task can be in."