edsl 0.1.33.dev3__py3-none-any.whl → 0.1.34__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. edsl/Base.py +15 -11
  2. edsl/__version__.py +1 -1
  3. edsl/agents/Invigilator.py +22 -3
  4. edsl/agents/PromptConstructor.py +80 -184
  5. edsl/agents/prompt_helpers.py +129 -0
  6. edsl/coop/coop.py +3 -2
  7. edsl/data_transfer_models.py +0 -1
  8. edsl/inference_services/AnthropicService.py +5 -2
  9. edsl/inference_services/AwsBedrock.py +5 -2
  10. edsl/inference_services/AzureAI.py +5 -2
  11. edsl/inference_services/GoogleService.py +108 -33
  12. edsl/inference_services/MistralAIService.py +5 -2
  13. edsl/inference_services/OpenAIService.py +3 -2
  14. edsl/inference_services/TestService.py +11 -2
  15. edsl/inference_services/TogetherAIService.py +1 -1
  16. edsl/jobs/Jobs.py +91 -10
  17. edsl/jobs/interviews/Interview.py +15 -2
  18. edsl/jobs/runners/JobsRunnerAsyncio.py +46 -25
  19. edsl/jobs/runners/JobsRunnerStatus.py +4 -3
  20. edsl/jobs/tasks/QuestionTaskCreator.py +1 -13
  21. edsl/language_models/LanguageModel.py +12 -9
  22. edsl/language_models/utilities.py +5 -2
  23. edsl/questions/QuestionBase.py +13 -3
  24. edsl/questions/QuestionBaseGenMixin.py +28 -0
  25. edsl/questions/QuestionCheckBox.py +1 -1
  26. edsl/questions/QuestionMultipleChoice.py +8 -4
  27. edsl/questions/ResponseValidatorABC.py +5 -1
  28. edsl/questions/descriptors.py +12 -11
  29. edsl/questions/templates/numerical/answering_instructions.jinja +0 -1
  30. edsl/questions/templates/yes_no/answering_instructions.jinja +2 -2
  31. edsl/scenarios/FileStore.py +159 -76
  32. edsl/scenarios/Scenario.py +23 -49
  33. edsl/scenarios/ScenarioList.py +6 -2
  34. edsl/surveys/DAG.py +62 -0
  35. edsl/surveys/MemoryPlan.py +26 -0
  36. edsl/surveys/Rule.py +24 -0
  37. edsl/surveys/RuleCollection.py +36 -2
  38. edsl/surveys/Survey.py +182 -10
  39. edsl/surveys/base.py +4 -0
  40. {edsl-0.1.33.dev3.dist-info → edsl-0.1.34.dist-info}/METADATA +2 -1
  41. {edsl-0.1.33.dev3.dist-info → edsl-0.1.34.dist-info}/RECORD +43 -43
  42. edsl/scenarios/ScenarioImageMixin.py +0 -100
  43. {edsl-0.1.33.dev3.dist-info → edsl-0.1.34.dist-info}/LICENSE +0 -0
  44. {edsl-0.1.33.dev3.dist-info → edsl-0.1.34.dist-info}/WHEEL +0 -0
edsl/inference_services/AzureAI.py CHANGED
@@ -1,5 +1,5 @@
  import os
- from typing import Any
+ from typing import Any, Optional, List
  import re
  from openai import AsyncAzureOpenAI
  from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
@@ -122,7 +122,10 @@ class AzureAIService(InferenceServiceABC):
  _tpm = cls.get_tpm(cls)

  async def async_execute_model_call(
- self, user_prompt: str, system_prompt: str = ""
+ self,
+ user_prompt: str,
+ system_prompt: str = "",
+ files_list: Optional[List["FileStore"]] = None,
  ) -> dict[str, Any]:
  """Calls the Azure OpenAI API and returns the API response."""

edsl/inference_services/GoogleService.py CHANGED
@@ -1,25 +1,54 @@
  import os
- import aiohttp
- import json
- from typing import Any
+ from typing import Any, Dict, List, Optional
+ import google
+ import google.generativeai as genai
+ from google.generativeai.types import GenerationConfig
+ from google.api_core.exceptions import InvalidArgument
+
  from edsl.exceptions import MissingAPIKeyError
  from edsl.language_models.LanguageModel import LanguageModel
-
  from edsl.inference_services.InferenceServiceABC import InferenceServiceABC

+ safety_settings = [
+ {
+ "category": "HARM_CATEGORY_HARASSMENT",
+ "threshold": "BLOCK_NONE",
+ },
+ {
+ "category": "HARM_CATEGORY_HATE_SPEECH",
+ "threshold": "BLOCK_NONE",
+ },
+ {
+ "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+ "threshold": "BLOCK_NONE",
+ },
+ {
+ "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+ "threshold": "BLOCK_NONE",
+ },
+ ]
+

  class GoogleService(InferenceServiceABC):
  _inference_service_ = "google"
  key_sequence = ["candidates", 0, "content", "parts", 0, "text"]
- usage_sequence = ["usageMetadata"]
- input_token_name = "promptTokenCount"
- output_token_name = "candidatesTokenCount"
+ usage_sequence = ["usage_metadata"]
+ input_token_name = "prompt_token_count"
+ output_token_name = "candidates_token_count"

  model_exclude_list = []

+ # @classmethod
+ # def available(cls) -> List[str]:
+ # return ["gemini-pro", "gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.0-pro"]
+
  @classmethod
- def available(cls):
- return ["gemini-pro", "gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.0-pro"]
+ def available(cls) -> List[str]:
+ model_list = []
+ for m in genai.list_models():
+ if "generateContent" in m.supported_generation_methods:
+ model_list.append(m.name.split("/")[-1])
+ return model_list

  @classmethod
  def create_model(
@@ -47,33 +76,79 @@ class GoogleService(InferenceServiceABC):
  "stopSequences": [],
  }

+ api_token = None
+ model = None
+
+ @classmethod
+ def initialize(cls):
+ if cls.api_token is None:
+ cls.api_token = os.getenv("GOOGLE_API_KEY")
+ if not cls.api_token:
+ raise MissingAPIKeyError(
+ "GOOGLE_API_KEY environment variable is not set"
+ )
+ genai.configure(api_key=cls.api_token)
+ cls.generative_model = genai.GenerativeModel(
+ cls._model_, safety_settings=safety_settings
+ )
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.initialize()
+
+ def get_generation_config(self) -> GenerationConfig:
+ return GenerationConfig(
+ temperature=self.temperature,
+ top_p=self.topP,
+ top_k=self.topK,
+ max_output_tokens=self.maxOutputTokens,
+ stop_sequences=self.stopSequences,
+ )
+
  async def async_execute_model_call(
- self, user_prompt: str, system_prompt: str = ""
- ) -> dict[str, Any]:
- # self.api_token = os.getenv("GOOGLE_API_KEY")
- combined_prompt = user_prompt + system_prompt
- url = f"https://generativelanguage.googleapis.com/v1beta/models/{model_name}:generateContent?key={self.api_token}"
- headers = {"Content-Type": "application/json"}
- data = {
- "contents": [{"parts": [{"text": combined_prompt}]}],
- "generationConfig": {
- "temperature": self.temperature,
- "topK": self.topK,
- "topP": self.topP,
- "maxOutputTokens": self.maxOutputTokens,
- "stopSequences": self.stopSequences,
- },
- }
- # print(combined_prompt)
- async with aiohttp.ClientSession() as session:
- async with session.post(
- url, headers=headers, data=json.dumps(data)
- ) as response:
- raw_response_text = await response.text()
- return json.loads(raw_response_text)
+ self,
+ user_prompt: str,
+ system_prompt: str = "",
+ files_list: Optional["Files"] = None,
+ ) -> Dict[str, Any]:
+ generation_config = self.get_generation_config()

- LLM.__name__ = model_name
+ if files_list is None:
+ files_list = []
+
+ if (
+ system_prompt is not None
+ and system_prompt != ""
+ and self._model_ != "gemini-pro"
+ ):
+ try:
+ self.generative_model = genai.GenerativeModel(
+ self._model_,
+ safety_settings=safety_settings,
+ system_instruction=system_prompt,
+ )
+ except InvalidArgument as e:
+ print(
+ f"This model, {self._model_}, does not support system_instruction"
+ )
+ print("Will add system_prompt to user_prompt")
+ user_prompt = f"{system_prompt}\n{user_prompt}"

+ combined_prompt = [user_prompt]
+ for file in files_list:
+ if "google" not in file.external_locations:
+ _ = file.upload_google()
+ gen_ai_file = google.generativeai.types.file_types.File(
+ file.external_locations["google"]
+ )
+ combined_prompt.append(gen_ai_file)
+
+ response = await self.generative_model.generate_content_async(
+ combined_prompt, generation_config=generation_config
+ )
+ return response.to_dict()
+
+ LLM.__name__ = model_name
  return LLM
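The rewritten GoogleService drops the hand-rolled aiohttp call to the generativelanguage REST endpoint and goes through the google.generativeai SDK instead: configure the API key, build a GenerativeModel with the permissive safety_settings, assemble a GenerationConfig, then await generate_content_async and serialize with to_dict(). A minimal standalone sketch of that same call pattern (model name, prompt, and parameter values are placeholders, not taken from the diff):

```python
import asyncio
import os

import google.generativeai as genai
from google.generativeai.types import GenerationConfig

# Same flow as the new async_execute_model_call, outside of edsl.
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))

model = genai.GenerativeModel(
    "gemini-1.5-flash",  # placeholder model name
    safety_settings=[
        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
    ],
)

config = GenerationConfig(temperature=0.5, top_p=1.0, top_k=1, max_output_tokens=512)


async def main() -> None:
    response = await model.generate_content_async(
        ["What is the capital of France?"], generation_config=config
    )
    # key_sequence in the class above indexes the same structure:
    # ["candidates", 0, "content", "parts", 0, "text"]
    data = response.to_dict()
    print(data["candidates"][0]["content"]["parts"][0]["text"])


asyncio.run(main())
```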
edsl/inference_services/MistralAIService.py CHANGED
@@ -1,5 +1,5 @@
  import os
- from typing import Any, List
+ from typing import Any, List, Optional
  from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
  from edsl.language_models.LanguageModel import LanguageModel
  import asyncio
@@ -95,7 +95,10 @@ class MistralAIService(InferenceServiceABC):
  return cls.async_client()

  async def async_execute_model_call(
- self, user_prompt: str, system_prompt: str = ""
+ self,
+ user_prompt: str,
+ system_prompt: str = "",
+ files_list: Optional[List["FileStore"]] = None,
  ) -> dict[str, Any]:
  """Calls the Mistral API and returns the API response."""
  s = self.async_client()
edsl/inference_services/OpenAIService.py CHANGED
@@ -168,13 +168,14 @@ class OpenAIService(InferenceServiceABC):
  self,
  user_prompt: str,
  system_prompt: str = "",
- encoded_image=None,
+ files_list: Optional[List["Files"]] = None,
  invigilator: Optional[
  "InvigilatorAI"
  ] = None, # TBD - can eventually be used for function-calling
  ) -> dict[str, Any]:
  """Calls the OpenAI API and returns the API response."""
- if encoded_image:
+ if files_list:
+ encoded_image = files_list[0].base64_string
  content = [{"type": "text", "text": user_prompt}]
  content.append(
  {
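Across the service backends, the old encoded_image keyword is replaced by a files_list parameter; for OpenAI the first attached file's base64_string is spliced into the multimodal message content. A rough sketch of the payload shape this produces (the data-URL wrapping follows OpenAI's standard vision format and is an assumption here, since the hunk is truncated):

```python
# Illustrative only: how the user message content is assembled once a
# FileStore-style object supplies a base64-encoded image.
user_prompt = "What is in this image?"
encoded_image = "<files_list[0].base64_string>"  # placeholder

content = [{"type": "text", "text": user_prompt}]
content.append(
    {
        "type": "image_url",
        "image_url": {"url": f"data:image/jpeg;base64,{encoded_image}"},
    }
)
messages = [{"role": "user", "content": content}]
```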
edsl/inference_services/TestService.py CHANGED
@@ -1,4 +1,4 @@
- from typing import Any, List
+ from typing import Any, List, Optional
  import os
  import asyncio
  from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
@@ -59,11 +59,20 @@ class TestService(InferenceServiceABC):
  self,
  user_prompt: str,
  system_prompt: str,
- encoded_image=None,
+ # func: Optional[callable] = None,
+ files_list: Optional[List["File"]] = None,
  ) -> dict[str, Any]:
  await asyncio.sleep(0.1)
  # return {"message": """{"answer": "Hello, world"}"""}

+ if hasattr(self, "func"):
+ return {
+ "message": [
+ {"text": self.func(user_prompt, system_prompt, files_list)}
+ ],
+ "usage": {"prompt_tokens": 1, "completion_tokens": 1},
+ }
+
  if hasattr(self, "throw_exception") and self.throw_exception:
  if hasattr(self, "exception_probability"):
  p = self.exception_probability
edsl/inference_services/TogetherAIService.py CHANGED
@@ -1,7 +1,7 @@
  import aiohttp
  import json
  import requests
- from typing import Any, List
+ from typing import Any, List, Optional

  # from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
  from edsl.language_models import LanguageModel
edsl/jobs/Jobs.py CHANGED
@@ -145,14 +145,21 @@ class Jobs(Base):
  >>> Jobs.example().prompts()
  Dataset(...)
  """
+ from edsl import Coop
+
+ c = Coop()
+ price_lookup = c.fetch_prices()

  interviews = self.interviews()
  # data = []
  interview_indices = []
- question_indices = []
+ question_names = []
  user_prompts = []
  system_prompts = []
  scenario_indices = []
+ agent_indices = []
+ models = []
+ costs = []
  from edsl.results.Dataset import Dataset

  for interview_index, interview in enumerate(interviews):
@@ -160,23 +167,97 @@
  interview._get_invigilator(question)
  for question in self.survey.questions
  ]
- # list(interview._build_invigilators(debug=False))
  for _, invigilator in enumerate(invigilators):
  prompts = invigilator.get_prompts()
- user_prompts.append(prompts["user_prompt"])
- system_prompts.append(prompts["system_prompt"])
+ user_prompt = prompts["user_prompt"]
+ system_prompt = prompts["system_prompt"]
+ user_prompts.append(user_prompt)
+ system_prompts.append(system_prompt)
+ agent_index = self.agents.index(invigilator.agent)
+ agent_indices.append(agent_index)
  interview_indices.append(interview_index)
- scenario_indices.append(invigilator.scenario)
- question_indices.append(invigilator.question.question_name)
- return Dataset(
+ scenario_index = self.scenarios.index(invigilator.scenario)
+ scenario_indices.append(scenario_index)
+ models.append(invigilator.model.model)
+ question_names.append(invigilator.question.question_name)
+ # cost calculation
+ key = (invigilator.model._inference_service_, invigilator.model.model)
+ relevant_prices = price_lookup[key]
+ inverse_output_price = relevant_prices["output"]["one_usd_buys"]
+ inverse_input_price = relevant_prices["input"]["one_usd_buys"]
+ input_tokens = len(str(user_prompt) + str(system_prompt)) // 4
+ output_tokens = len(str(user_prompt) + str(system_prompt)) // 4
+ cost = input_tokens / float(
+ inverse_input_price
+ ) + output_tokens / float(inverse_output_price)
+ costs.append(cost)
+
+ d = Dataset(
  [
- {"interview_index": interview_indices},
- {"question_index": question_indices},
  {"user_prompt": user_prompts},
- {"scenario_index": scenario_indices},
  {"system_prompt": system_prompts},
+ {"interview_index": interview_indices},
+ {"question_name": question_names},
+ {"scenario_index": scenario_indices},
+ {"agent_index": agent_indices},
+ {"model": models},
+ {"estimated_cost": costs},
  ]
  )
+ return d
+ # if table:
+ # d.to_scenario_list().print(format="rich")
+ # else:
+ # return d
+
+ def show_prompts(self) -> None:
+ """Print the prompts."""
+ self.prompts().to_scenario_list().print(format="rich")
+
+ def estimate_job_cost(self):
+ from edsl import Coop
+
+ c = Coop()
+ price_lookup = c.fetch_prices()
+
+ prompts = self.prompts()
+
+ text_len = 0
+ for prompt in prompts:
+ text_len += len(str(prompt))
+
+ input_token_aproximations = text_len // 4
+
+ aproximation_cost = {}
+ total_cost = 0
+ for model in self.models:
+ key = (model._inference_service_, model.model)
+ relevant_prices = price_lookup[key]
+ inverse_output_price = relevant_prices["output"]["one_usd_buys"]
+ inverse_input_price = relevant_prices["input"]["one_usd_buys"]
+
+ aproximation_cost[key] = {
+ "input": input_token_aproximations / float(inverse_input_price),
+ "output": input_token_aproximations / float(inverse_output_price),
+ }
+ ##TODO curenlty we approximate the number of output tokens with the number
+ # of input tokens. A better solution will be to compute the quesiton answer options length and sum them
+ # to compute the output tokens
+
+ total_cost += input_token_aproximations / float(inverse_input_price)
+ total_cost += input_token_aproximations / float(inverse_output_price)
+
+ # multiply_factor = len(self.agents or [1]) * len(self.scenarios or [1])
+ multiply_factor = 1
+ out = {
+ "input_token_aproximations": input_token_aproximations,
+ "models_costs": aproximation_cost,
+ "estimated_total_cost": total_cost * multiply_factor,
+ "multiply_factor": multiply_factor,
+ "single_config_cost": total_cost,
+ }
+
+ return out

  @staticmethod
  def _get_container_class(object):
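Both prompts() and the new estimate_job_cost() rely on the same rough heuristic: token counts are approximated as len(text) // 4, and cost is tokens divided by the one_usd_buys rate returned by Coop.fetch_prices() for each (inference service, model) pair; per the TODO above, output tokens are simply assumed equal to input tokens. A worked sketch of that arithmetic with made-up prices and prompts:

```python
# Hypothetical price entry in the shape used by the cost calculation above:
# one_usd_buys = how many tokens one US dollar purchases.
relevant_prices = {
    "input": {"one_usd_buys": 1_000_000},
    "output": {"one_usd_buys": 500_000},
}

user_prompt = "What is your favorite color?"
system_prompt = "You are answering questions as if you were a human agent."

# chars // 4 is the token approximation used in prompts()
input_tokens = len(user_prompt + system_prompt) // 4
output_tokens = input_tokens  # crude assumption noted in the TODO

cost = (
    input_tokens / float(relevant_prices["input"]["one_usd_buys"])
    + output_tokens / float(relevant_prices["output"]["one_usd_buys"])
)
print(f"estimated cost: ${cost:.6f}")
```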
edsl/jobs/interviews/Interview.py CHANGED
@@ -3,6 +3,7 @@
  from __future__ import annotations
  import asyncio
  from typing import Any, Type, List, Generator, Optional, Union
+ import copy

  from tenacity import (
  retry,
@@ -99,7 +100,9 @@ class Interview(InterviewStatusMixin):

  """
  self.agent = agent
- self.survey = survey
+ # what I would like to do
+ self.survey = copy.deepcopy(survey) # survey copy.deepcopy(survey)
+ # self.survey = survey
  self.scenario = scenario
  self.model = model
  self.debug = debug
@@ -248,17 +251,24 @@

  def _get_estimated_request_tokens(self, question) -> float:
  """Estimate the number of tokens that will be required to run the focal task."""
+ from edsl.scenarios.FileStore import FileStore
+
  invigilator = self._get_invigilator(question=question)
  # TODO: There should be a way to get a more accurate estimate.
  combined_text = ""
+ file_tokens = 0
  for prompt in invigilator.get_prompts().values():
  if hasattr(prompt, "text"):
  combined_text += prompt.text
  elif isinstance(prompt, str):
  combined_text += prompt
+ elif isinstance(prompt, list):
+ for file in prompt:
+ if isinstance(file, FileStore):
+ file_tokens += file.size * 0.25
  else:
  raise ValueError(f"Prompt is of type {type(prompt)}")
- return len(combined_text) / 4.0
+ return len(combined_text) / 4.0 + file_tokens

  async def _answer_question_and_record_task(
  self,
@@ -296,6 +306,9 @@
  self.answers.add_answer(response=response, question=question)
  self._cancel_skipped_questions(question)
  else:
+ # When a question is not validated, it is not added to the answers.
+ # this should also cancel and dependent children questions.
+ # Is that happening now?
  if (
  hasattr(response, "exception_occurred")
  and response.exception_occurred
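The revised _get_estimated_request_tokens keeps the chars-divided-by-4 heuristic for text prompts and adds file.size * 0.25 for any FileStore attachments found in a prompt list. A small illustration of the arithmetic (sizes are invented):

```python
# Invented inputs: a 2,000-character combined prompt plus one 100 kB attachment.
combined_text_chars = 2_000
file_size = 100_000  # stand-in for FileStore.size of the attached file

text_tokens = combined_text_chars / 4.0  # chars-per-token heuristic
file_tokens = file_size * 0.25           # per-unit-of-size heuristic from the diff

estimated_request_tokens = text_tokens + file_tokens  # 500.0 + 25000.0 = 25500.0
```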
edsl/jobs/runners/JobsRunnerAsyncio.py CHANGED
@@ -8,10 +8,10 @@ from typing import Coroutine, List, AsyncGenerator, Optional, Union, Generator
  from contextlib import contextmanager
  from collections import UserList

- from edsl.results.Results import Results
  from rich.live import Live
  from rich.console import Console

+ from edsl.results.Results import Results
  from edsl import shared_globals
  from edsl.jobs.interviews.Interview import Interview
  from edsl.jobs.runners.JobsRunnerStatus import JobsRunnerStatus
@@ -22,6 +22,8 @@ from edsl.utilities.decorators import jupyter_nb_handler
  from edsl.data.Cache import Cache
  from edsl.results.Result import Result
  from edsl.results.Results import Results
+ from edsl.language_models.LanguageModel import LanguageModel
+ from edsl.data.Cache import Cache


  class StatusTracker(UserList):
@@ -50,10 +52,10 @@ class JobsRunnerAsyncio:

  async def run_async_generator(
  self,
- cache: "Cache",
+ cache: Cache,
  n: int = 1,
  stop_on_exception: bool = False,
- sidecar_model: Optional["LanguageModel"] = None,
+ sidecar_model: Optional[LanguageModel] = None,
  total_interviews: Optional[List["Interview"]] = None,
  raise_validation_errors: bool = False,
  ) -> AsyncGenerator["Result", None]:
@@ -104,7 +106,7 @@
  interview.cache = self.cache
  yield interview

- async def run_async(self, cache: Optional["Cache"] = None, n: int = 1) -> Results:
+ async def run_async(self, cache: Optional[Cache] = None, n: int = 1) -> Results:
  """Used for some other modules that have a non-standard way of running interviews."""
  self.jobs_runner_status = JobsRunnerStatus(self, n=n)
  self.cache = Cache() if cache is None else cache
@@ -171,19 +173,19 @@

  prompt_dictionary = {}
  for answer_key_name in answer_key_names:
- prompt_dictionary[answer_key_name + "_user_prompt"] = (
- question_name_to_prompts[answer_key_name]["user_prompt"]
- )
- prompt_dictionary[answer_key_name + "_system_prompt"] = (
- question_name_to_prompts[answer_key_name]["system_prompt"]
- )
+ prompt_dictionary[
+ answer_key_name + "_user_prompt"
+ ] = question_name_to_prompts[answer_key_name]["user_prompt"]
+ prompt_dictionary[
+ answer_key_name + "_system_prompt"
+ ] = question_name_to_prompts[answer_key_name]["system_prompt"]

  raw_model_results_dictionary = {}
  for result in valid_results:
  question_name = result.question_name
- raw_model_results_dictionary[question_name + "_raw_model_response"] = (
- result.raw_model_response
- )
+ raw_model_results_dictionary[
+ question_name + "_raw_model_response"
+ ] = result.raw_model_response
  raw_model_results_dictionary[question_name + "_cost"] = result.cost
  one_use_buys = (
  "NA"
@@ -291,6 +293,8 @@

  self.jobs_runner_status = JobsRunnerStatus(self, n=n)

+ stop_event = threading.Event()
+
  async def process_results(cache):
  """Processes results from interviews."""
  async for result in self.run_async_generator(
@@ -303,20 +307,37 @@
  self.results.append(result)
  self.completed = True

- def run_progress_bar():
+ def run_progress_bar(stop_event):
  """Runs the progress bar in a separate thread."""
- self.jobs_runner_status.update_progress()
+ self.jobs_runner_status.update_progress(stop_event)

  if progress_bar:
- progress_thread = threading.Thread(target=run_progress_bar)
+ progress_thread = threading.Thread(
+ target=run_progress_bar, args=(stop_event,)
+ )
  progress_thread.start()

- with cache as c:
- await process_results(cache=c)
-
- if progress_bar:
- progress_thread.join()
-
- return self.process_results(
- raw_results=self.results, cache=cache, print_exceptions=print_exceptions
- )
+ exception_to_raise = None
+ try:
+ with cache as c:
+ await process_results(cache=c)
+ except KeyboardInterrupt:
+ print("Keyboard interrupt received. Stopping gracefully...")
+ stop_event.set()
+ except Exception as e:
+ if stop_on_exception:
+ exception_to_raise = e
+ stop_event.set()
+ finally:
+ stop_event.set()
+ if progress_bar:
+ # self.jobs_runner_status.stop_event.set()
+ if progress_thread:
+ progress_thread.join()
+
+ if exception_to_raise:
+ raise exception_to_raise
+
+ return self.process_results(
+ raw_results=self.results, cache=cache, print_exceptions=print_exceptions
+ )
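The run loop now shares a threading.Event between the asyncio driver and the progress-bar thread: the event is set on KeyboardInterrupt, on a fatal exception, and again in finally, and update_progress (next file) checks it so the thread's loop exits and join() returns. A stripped-down sketch of the same shutdown pattern, independent of edsl:

```python
import asyncio
import threading
import time


def progress_loop(stop_event: threading.Event) -> None:
    # Stand-in for JobsRunnerStatus.update_progress: keep refreshing until
    # the work finishes or the runner signals shutdown via the event.
    while not stop_event.is_set():
        time.sleep(0.1)  # refresh the display here


async def run_jobs() -> None:
    await asyncio.sleep(0.5)  # stand-in for processing interview results


def main() -> None:
    stop_event = threading.Event()
    progress_thread = threading.Thread(target=progress_loop, args=(stop_event,))
    progress_thread.start()

    exception_to_raise = None
    try:
        asyncio.run(run_jobs())
    except KeyboardInterrupt:
        print("Keyboard interrupt received. Stopping gracefully...")
    except Exception as e:
        exception_to_raise = e
    finally:
        stop_event.set()  # guarantees the progress thread's loop exits
        progress_thread.join()

    if exception_to_raise:
        raise exception_to_raise


if __name__ == "__main__":
    main()
```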
edsl/jobs/runners/JobsRunnerStatus.py CHANGED
@@ -265,14 +265,15 @@ class JobsRunnerStatus:
  table.add_row(pretty_name, value)
  return table

- def update_progress(self):
+ def update_progress(self, stop_event):
  layout, progress, task_ids = self.generate_layout()

  with Live(
  layout, refresh_per_second=int(1 / self.refresh_rate), transient=True
  ) as live:
- while len(self.completed_interviews) < len(
- self.jobs_runner.total_interviews
+ while (
+ len(self.completed_interviews) < len(self.jobs_runner.total_interviews)
+ and not stop_event.is_set()
  ):
  completed_tasks = len(self.completed_interviews)
  total_tasks = len(self.jobs_runner.total_interviews)
edsl/jobs/tasks/QuestionTaskCreator.py CHANGED
@@ -156,19 +156,6 @@ class QuestionTaskCreator(UserList):
  self.tokens_bucket.turbo_mode_off()
  self.requests_bucket.turbo_mode_off()

- # breakpoint()
- # _ = results.pop("cached_response", None)
-
- # tracker = self.cached_token_usage if self.from_cache else self.new_token_usage
-
- # TODO: This is hacky. The 'func' call should return an object that definitely has a 'usage' key.
- # usage = results.get("usage", {"prompt_tokens": 0, "completion_tokens": 0})
- # prompt_tokens = usage.get("prompt_tokens", 0)
- # completion_tokens = usage.get("completion_tokens", 0)
- # tracker.add_tokens(
- # prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
- # )
-
  return results

  @classmethod
@@ -249,6 +236,7 @@
  f"Required tasks failed for {self.question.question_name}"
  ) from e

+ # this only runs if all the dependencies are successful
  return await self._run_focal_task()
