edsl 0.1.33.dev3__py3-none-any.whl → 0.1.34__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +15 -11
- edsl/__version__.py +1 -1
- edsl/agents/Invigilator.py +22 -3
- edsl/agents/PromptConstructor.py +80 -184
- edsl/agents/prompt_helpers.py +129 -0
- edsl/coop/coop.py +3 -2
- edsl/data_transfer_models.py +0 -1
- edsl/inference_services/AnthropicService.py +5 -2
- edsl/inference_services/AwsBedrock.py +5 -2
- edsl/inference_services/AzureAI.py +5 -2
- edsl/inference_services/GoogleService.py +108 -33
- edsl/inference_services/MistralAIService.py +5 -2
- edsl/inference_services/OpenAIService.py +3 -2
- edsl/inference_services/TestService.py +11 -2
- edsl/inference_services/TogetherAIService.py +1 -1
- edsl/jobs/Jobs.py +91 -10
- edsl/jobs/interviews/Interview.py +15 -2
- edsl/jobs/runners/JobsRunnerAsyncio.py +46 -25
- edsl/jobs/runners/JobsRunnerStatus.py +4 -3
- edsl/jobs/tasks/QuestionTaskCreator.py +1 -13
- edsl/language_models/LanguageModel.py +12 -9
- edsl/language_models/utilities.py +5 -2
- edsl/questions/QuestionBase.py +13 -3
- edsl/questions/QuestionBaseGenMixin.py +28 -0
- edsl/questions/QuestionCheckBox.py +1 -1
- edsl/questions/QuestionMultipleChoice.py +8 -4
- edsl/questions/ResponseValidatorABC.py +5 -1
- edsl/questions/descriptors.py +12 -11
- edsl/questions/templates/numerical/answering_instructions.jinja +0 -1
- edsl/questions/templates/yes_no/answering_instructions.jinja +2 -2
- edsl/scenarios/FileStore.py +159 -76
- edsl/scenarios/Scenario.py +23 -49
- edsl/scenarios/ScenarioList.py +6 -2
- edsl/surveys/DAG.py +62 -0
- edsl/surveys/MemoryPlan.py +26 -0
- edsl/surveys/Rule.py +24 -0
- edsl/surveys/RuleCollection.py +36 -2
- edsl/surveys/Survey.py +182 -10
- edsl/surveys/base.py +4 -0
- {edsl-0.1.33.dev3.dist-info → edsl-0.1.34.dist-info}/METADATA +2 -1
- {edsl-0.1.33.dev3.dist-info → edsl-0.1.34.dist-info}/RECORD +43 -43
- edsl/scenarios/ScenarioImageMixin.py +0 -100
- {edsl-0.1.33.dev3.dist-info → edsl-0.1.34.dist-info}/LICENSE +0 -0
- {edsl-0.1.33.dev3.dist-info → edsl-0.1.34.dist-info}/WHEEL +0 -0
edsl/inference_services/AzureAI.py
CHANGED
@@ -1,5 +1,5 @@
 import os
-from typing import Any
+from typing import Any, Optional, List
 import re
 from openai import AsyncAzureOpenAI
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
@@ -122,7 +122,10 @@ class AzureAIService(InferenceServiceABC):
         _tpm = cls.get_tpm(cls)

     async def async_execute_model_call(
-        self,
+        self,
+        user_prompt: str,
+        system_prompt: str = "",
+        files_list: Optional[List["FileStore"]] = None,
     ) -> dict[str, Any]:
         """Calls the Azure OpenAI API and returns the API response."""

edsl/inference_services/GoogleService.py
CHANGED
@@ -1,25 +1,54 @@
 import os
-import
-import
-
+from typing import Any, Dict, List, Optional
+import google
+import google.generativeai as genai
+from google.generativeai.types import GenerationConfig
+from google.api_core.exceptions import InvalidArgument
+
 from edsl.exceptions import MissingAPIKeyError
 from edsl.language_models.LanguageModel import LanguageModel
-
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC

+safety_settings = [
+    {
+        "category": "HARM_CATEGORY_HARASSMENT",
+        "threshold": "BLOCK_NONE",
+    },
+    {
+        "category": "HARM_CATEGORY_HATE_SPEECH",
+        "threshold": "BLOCK_NONE",
+    },
+    {
+        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+        "threshold": "BLOCK_NONE",
+    },
+    {
+        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+        "threshold": "BLOCK_NONE",
+    },
+]
+

 class GoogleService(InferenceServiceABC):
     _inference_service_ = "google"
     key_sequence = ["candidates", 0, "content", "parts", 0, "text"]
-    usage_sequence = ["
-    input_token_name = "
-    output_token_name = "
+    usage_sequence = ["usage_metadata"]
+    input_token_name = "prompt_token_count"
+    output_token_name = "candidates_token_count"

     model_exclude_list = []

+    # @classmethod
+    # def available(cls) -> List[str]:
+    #     return ["gemini-pro", "gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.0-pro"]
+
     @classmethod
-    def available(cls):
-
+    def available(cls) -> List[str]:
+        model_list = []
+        for m in genai.list_models():
+            if "generateContent" in m.supported_generation_methods:
+                model_list.append(m.name.split("/")[-1])
+        return model_list

     @classmethod
     def create_model(
@@ -47,33 +76,79 @@ class GoogleService(InferenceServiceABC):
             "stopSequences": [],
         }

+        api_token = None
+        model = None
+
+        @classmethod
+        def initialize(cls):
+            if cls.api_token is None:
+                cls.api_token = os.getenv("GOOGLE_API_KEY")
+                if not cls.api_token:
+                    raise MissingAPIKeyError(
+                        "GOOGLE_API_KEY environment variable is not set"
+                    )
+                genai.configure(api_key=cls.api_token)
+                cls.generative_model = genai.GenerativeModel(
+                    cls._model_, safety_settings=safety_settings
+                )
+
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+            self.initialize()
+
+        def get_generation_config(self) -> GenerationConfig:
+            return GenerationConfig(
+                temperature=self.temperature,
+                top_p=self.topP,
+                top_k=self.topK,
+                max_output_tokens=self.maxOutputTokens,
+                stop_sequences=self.stopSequences,
+            )
+
         async def async_execute_model_call(
-            self,
-
-
-
-
-
-            data = {
-                "contents": [{"parts": [{"text": combined_prompt}]}],
-                "generationConfig": {
-                    "temperature": self.temperature,
-                    "topK": self.topK,
-                    "topP": self.topP,
-                    "maxOutputTokens": self.maxOutputTokens,
-                    "stopSequences": self.stopSequences,
-                },
-            }
-            # print(combined_prompt)
-            async with aiohttp.ClientSession() as session:
-                async with session.post(
-                    url, headers=headers, data=json.dumps(data)
-                ) as response:
-                    raw_response_text = await response.text()
-                    return json.loads(raw_response_text)
+            self,
+            user_prompt: str,
+            system_prompt: str = "",
+            files_list: Optional["Files"] = None,
+        ) -> Dict[str, Any]:
+            generation_config = self.get_generation_config()

-
+            if files_list is None:
+                files_list = []
+
+            if (
+                system_prompt is not None
+                and system_prompt != ""
+                and self._model_ != "gemini-pro"
+            ):
+                try:
+                    self.generative_model = genai.GenerativeModel(
+                        self._model_,
+                        safety_settings=safety_settings,
+                        system_instruction=system_prompt,
+                    )
+                except InvalidArgument as e:
+                    print(
+                        f"This model, {self._model_}, does not support system_instruction"
+                    )
+                    print("Will add system_prompt to user_prompt")
+                    user_prompt = f"{system_prompt}\n{user_prompt}"

+            combined_prompt = [user_prompt]
+            for file in files_list:
+                if "google" not in file.external_locations:
+                    _ = file.upload_google()
+                gen_ai_file = google.generativeai.types.file_types.File(
+                    file.external_locations["google"]
+                )
+                combined_prompt.append(gen_ai_file)
+
+            response = await self.generative_model.generate_content_async(
+                combined_prompt, generation_config=generation_config
+            )
+            return response.to_dict()
+
+        LLM.__name__ = model_name
         return LLM

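
For orientation, here is a minimal standalone sketch (not code from the package) of the google-generativeai calls the rewritten GoogleService relies on. It assumes GOOGLE_API_KEY is set in the environment, and the model name is only illustrative.

import asyncio
import os
import google.generativeai as genai
from google.generativeai.types import GenerationConfig

# Configure the SDK and build a model with a system instruction, mirroring the diff above.
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
model = genai.GenerativeModel("gemini-1.5-flash", system_instruction="Answer briefly.")

async def main() -> None:
    # generate_content_async takes a list of prompt parts (text, and optionally uploaded files).
    response = await model.generate_content_async(
        ["Say hello."], generation_config=GenerationConfig(temperature=0.5)
    )
    print(response.to_dict())

asyncio.run(main())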
edsl/inference_services/MistralAIService.py
CHANGED
@@ -1,5 +1,5 @@
 import os
-from typing import Any, List
+from typing import Any, List, Optional
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
 from edsl.language_models.LanguageModel import LanguageModel
 import asyncio
@@ -95,7 +95,10 @@ class MistralAIService(InferenceServiceABC):
         return cls.async_client()

     async def async_execute_model_call(
-        self,
+        self,
+        user_prompt: str,
+        system_prompt: str = "",
+        files_list: Optional[List["FileStore"]] = None,
     ) -> dict[str, Any]:
         """Calls the Mistral API and returns the API response."""
         s = self.async_client()
edsl/inference_services/OpenAIService.py
CHANGED
@@ -168,13 +168,14 @@ class OpenAIService(InferenceServiceABC):
         self,
         user_prompt: str,
         system_prompt: str = "",
-
+        files_list: Optional[List["Files"]] = None,
         invigilator: Optional[
             "InvigilatorAI"
         ] = None, # TBD - can eventually be used for function-calling
     ) -> dict[str, Any]:
         """Calls the OpenAI API and returns the API response."""
-        if
+        if files_list:
+            encoded_image = files_list[0].base64_string
         content = [{"type": "text", "text": user_prompt}]
         content.append(
             {
edsl/inference_services/TestService.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Any, List
+from typing import Any, List, Optional
 import os
 import asyncio
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
@@ -59,11 +59,20 @@ class TestService(InferenceServiceABC):
         self,
         user_prompt: str,
         system_prompt: str,
-
+        # func: Optional[callable] = None,
+        files_list: Optional[List["File"]] = None,
     ) -> dict[str, Any]:
         await asyncio.sleep(0.1)
         # return {"message": """{"answer": "Hello, world"}"""}

+        if hasattr(self, "func"):
+            return {
+                "message": [
+                    {"text": self.func(user_prompt, system_prompt, files_list)}
+                ],
+                "usage": {"prompt_tokens": 1, "completion_tokens": 1},
+            }
+
         if hasattr(self, "throw_exception") and self.throw_exception:
             if hasattr(self, "exception_probability"):
                 p = self.exception_probability
edsl/jobs/Jobs.py
CHANGED
@@ -145,14 +145,21 @@ class Jobs(Base):
         >>> Jobs.example().prompts()
         Dataset(...)
         """
+        from edsl import Coop
+
+        c = Coop()
+        price_lookup = c.fetch_prices()

         interviews = self.interviews()
         # data = []
         interview_indices = []
-
+        question_names = []
         user_prompts = []
         system_prompts = []
         scenario_indices = []
+        agent_indices = []
+        models = []
+        costs = []
         from edsl.results.Dataset import Dataset

         for interview_index, interview in enumerate(interviews):
@@ -160,23 +167,97 @@ class Jobs(Base):
                 interview._get_invigilator(question)
                 for question in self.survey.questions
             ]
-            # list(interview._build_invigilators(debug=False))
             for _, invigilator in enumerate(invigilators):
                 prompts = invigilator.get_prompts()
-
-
+                user_prompt = prompts["user_prompt"]
+                system_prompt = prompts["system_prompt"]
+                user_prompts.append(user_prompt)
+                system_prompts.append(system_prompt)
+                agent_index = self.agents.index(invigilator.agent)
+                agent_indices.append(agent_index)
                 interview_indices.append(interview_index)
-
-
-
+                scenario_index = self.scenarios.index(invigilator.scenario)
+                scenario_indices.append(scenario_index)
+                models.append(invigilator.model.model)
+                question_names.append(invigilator.question.question_name)
+                # cost calculation
+                key = (invigilator.model._inference_service_, invigilator.model.model)
+                relevant_prices = price_lookup[key]
+                inverse_output_price = relevant_prices["output"]["one_usd_buys"]
+                inverse_input_price = relevant_prices["input"]["one_usd_buys"]
+                input_tokens = len(str(user_prompt) + str(system_prompt)) // 4
+                output_tokens = len(str(user_prompt) + str(system_prompt)) // 4
+                cost = input_tokens / float(
+                    inverse_input_price
+                ) + output_tokens / float(inverse_output_price)
+                costs.append(cost)
+
+        d = Dataset(
             [
-                {"interview_index": interview_indices},
-                {"question_index": question_indices},
                 {"user_prompt": user_prompts},
-                {"scenario_index": scenario_indices},
                 {"system_prompt": system_prompts},
+                {"interview_index": interview_indices},
+                {"question_name": question_names},
+                {"scenario_index": scenario_indices},
+                {"agent_index": agent_indices},
+                {"model": models},
+                {"estimated_cost": costs},
             ]
         )
+        return d
+        # if table:
+        #     d.to_scenario_list().print(format="rich")
+        # else:
+        #     return d
+
+    def show_prompts(self) -> None:
+        """Print the prompts."""
+        self.prompts().to_scenario_list().print(format="rich")
+
+    def estimate_job_cost(self):
+        from edsl import Coop
+
+        c = Coop()
+        price_lookup = c.fetch_prices()
+
+        prompts = self.prompts()
+
+        text_len = 0
+        for prompt in prompts:
+            text_len += len(str(prompt))
+
+        input_token_aproximations = text_len // 4
+
+        aproximation_cost = {}
+        total_cost = 0
+        for model in self.models:
+            key = (model._inference_service_, model.model)
+            relevant_prices = price_lookup[key]
+            inverse_output_price = relevant_prices["output"]["one_usd_buys"]
+            inverse_input_price = relevant_prices["input"]["one_usd_buys"]
+
+            aproximation_cost[key] = {
+                "input": input_token_aproximations / float(inverse_input_price),
+                "output": input_token_aproximations / float(inverse_output_price),
+            }
+            ##TODO curenlty we approximate the number of output tokens with the number
+            # of input tokens. A better solution will be to compute the quesiton answer options length and sum them
+            # to compute the output tokens
+
+            total_cost += input_token_aproximations / float(inverse_input_price)
+            total_cost += input_token_aproximations / float(inverse_output_price)

+        # multiply_factor = len(self.agents or [1]) * len(self.scenarios or [1])
+        multiply_factor = 1
+        out = {
+            "input_token_aproximations": input_token_aproximations,
+            "models_costs": aproximation_cost,
+            "estimated_total_cost": total_cost * multiply_factor,
+            "multiply_factor": multiply_factor,
+            "single_config_cost": total_cost,
+        }
+
+        return out

     @staticmethod
     def _get_container_class(object):
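
The cost arithmetic introduced above is simple: token counts are approximated as character counts divided by four, output tokens are assumed equal to input tokens, and the Coop price table stores "one_usd_buys" (tokens per US dollar), so the dollar cost is tokens divided by that figure. A hedged sketch of the same calculation (not package code; the price entry is made up for illustration):

def estimate_prompt_cost(user_prompt: str, system_prompt: str, prices: dict) -> float:
    # ~4 characters per token; output tokens approximated by the input count (see TODO above).
    input_tokens = len(str(user_prompt) + str(system_prompt)) // 4
    output_tokens = input_tokens
    # "one_usd_buys" is tokens per USD, so tokens / (tokens per USD) gives dollars.
    return input_tokens / float(prices["input"]["one_usd_buys"]) + output_tokens / float(
        prices["output"]["one_usd_buys"]
    )

# Hypothetical prices, for illustration only.
example_prices = {"input": {"one_usd_buys": 1_000_000}, "output": {"one_usd_buys": 500_000}}
print(estimate_prompt_cost("What is 2 + 2?", "You are a helpful agent.", example_prices))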
edsl/jobs/interviews/Interview.py
CHANGED
@@ -3,6 +3,7 @@
 from __future__ import annotations
 import asyncio
 from typing import Any, Type, List, Generator, Optional, Union
+import copy

 from tenacity import (
     retry,
@@ -99,7 +100,9 @@ class Interview(InterviewStatusMixin):

         """
         self.agent = agent
-
+        # what I would like to do
+        self.survey = copy.deepcopy(survey) # survey copy.deepcopy(survey)
+        # self.survey = survey
         self.scenario = scenario
         self.model = model
         self.debug = debug
@@ -248,17 +251,24 @@ class Interview(InterviewStatusMixin):

     def _get_estimated_request_tokens(self, question) -> float:
         """Estimate the number of tokens that will be required to run the focal task."""
+        from edsl.scenarios.FileStore import FileStore
+
         invigilator = self._get_invigilator(question=question)
         # TODO: There should be a way to get a more accurate estimate.
         combined_text = ""
+        file_tokens = 0
         for prompt in invigilator.get_prompts().values():
             if hasattr(prompt, "text"):
                 combined_text += prompt.text
             elif isinstance(prompt, str):
                 combined_text += prompt
+            elif isinstance(prompt, list):
+                for file in prompt:
+                    if isinstance(file, FileStore):
+                        file_tokens += file.size * 0.25
             else:
                 raise ValueError(f"Prompt is of type {type(prompt)}")
-        return len(combined_text) / 4.0
+        return len(combined_text) / 4.0 + file_tokens

     async def _answer_question_and_record_task(
         self,
@@ -296,6 +306,9 @@ class Interview(InterviewStatusMixin):
             self.answers.add_answer(response=response, question=question)
             self._cancel_skipped_questions(question)
         else:
+            # When a question is not validated, it is not added to the answers.
+            # this should also cancel and dependent children questions.
+            # Is that happening now?
             if (
                 hasattr(response, "exception_occurred")
                 and response.exception_occurred
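
The new request-token estimate combines two heuristics: roughly four characters per token for prompt text, and a quarter of a token per byte for attached files. A small illustrative sketch of that arithmetic (not the class itself):

def estimate_request_tokens(text_prompts: list, file_sizes_bytes: list) -> float:
    combined_text = "".join(text_prompts)
    # len(text) / 4 approximates text tokens; size * 0.25 approximates file tokens.
    return len(combined_text) / 4.0 + sum(size * 0.25 for size in file_sizes_bytes)

# 2,000 characters of prompt text plus a 4,000-byte attachment -> about 1,500 tokens.
print(estimate_request_tokens(["x" * 2000], [4000]))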
edsl/jobs/runners/JobsRunnerAsyncio.py
CHANGED
@@ -8,10 +8,10 @@ from typing import Coroutine, List, AsyncGenerator, Optional, Union, Generator
 from contextlib import contextmanager
 from collections import UserList

-from edsl.results.Results import Results
 from rich.live import Live
 from rich.console import Console

+from edsl.results.Results import Results
 from edsl import shared_globals
 from edsl.jobs.interviews.Interview import Interview
 from edsl.jobs.runners.JobsRunnerStatus import JobsRunnerStatus
@@ -22,6 +22,8 @@ from edsl.utilities.decorators import jupyter_nb_handler
 from edsl.data.Cache import Cache
 from edsl.results.Result import Result
 from edsl.results.Results import Results
+from edsl.language_models.LanguageModel import LanguageModel
+from edsl.data.Cache import Cache


 class StatusTracker(UserList):
@@ -50,10 +52,10 @@ class JobsRunnerAsyncio:

     async def run_async_generator(
         self,
-        cache:
+        cache: Cache,
         n: int = 1,
         stop_on_exception: bool = False,
-        sidecar_model: Optional[
+        sidecar_model: Optional[LanguageModel] = None,
         total_interviews: Optional[List["Interview"]] = None,
         raise_validation_errors: bool = False,
     ) -> AsyncGenerator["Result", None]:
@@ -104,7 +106,7 @@ class JobsRunnerAsyncio:
             interview.cache = self.cache
             yield interview

-    async def run_async(self, cache: Optional[
+    async def run_async(self, cache: Optional[Cache] = None, n: int = 1) -> Results:
         """Used for some other modules that have a non-standard way of running interviews."""
         self.jobs_runner_status = JobsRunnerStatus(self, n=n)
         self.cache = Cache() if cache is None else cache
@@ -171,19 +173,19 @@ class JobsRunnerAsyncio:

         prompt_dictionary = {}
         for answer_key_name in answer_key_names:
-            prompt_dictionary[
-
-
-            prompt_dictionary[
-
-
+            prompt_dictionary[
+                answer_key_name + "_user_prompt"
+            ] = question_name_to_prompts[answer_key_name]["user_prompt"]
+            prompt_dictionary[
+                answer_key_name + "_system_prompt"
+            ] = question_name_to_prompts[answer_key_name]["system_prompt"]

         raw_model_results_dictionary = {}
         for result in valid_results:
             question_name = result.question_name
-            raw_model_results_dictionary[
-
-
+            raw_model_results_dictionary[
+                question_name + "_raw_model_response"
+            ] = result.raw_model_response
             raw_model_results_dictionary[question_name + "_cost"] = result.cost
             one_use_buys = (
                 "NA"
@@ -291,6 +293,8 @@ class JobsRunnerAsyncio:

         self.jobs_runner_status = JobsRunnerStatus(self, n=n)

+        stop_event = threading.Event()
+
         async def process_results(cache):
             """Processes results from interviews."""
             async for result in self.run_async_generator(
@@ -303,20 +307,37 @@ class JobsRunnerAsyncio:
                 self.results.append(result)
             self.completed = True

-        def run_progress_bar():
+        def run_progress_bar(stop_event):
             """Runs the progress bar in a separate thread."""
-            self.jobs_runner_status.update_progress()
+            self.jobs_runner_status.update_progress(stop_event)

         if progress_bar:
-            progress_thread = threading.Thread(
+            progress_thread = threading.Thread(
+                target=run_progress_bar, args=(stop_event,)
+            )
             progress_thread.start()

-
-
-
-
-
-
-
-
-
+        exception_to_raise = None
+        try:
+            with cache as c:
+                await process_results(cache=c)
+        except KeyboardInterrupt:
+            print("Keyboard interrupt received. Stopping gracefully...")
+            stop_event.set()
+        except Exception as e:
+            if stop_on_exception:
+                exception_to_raise = e
+            stop_event.set()
+        finally:
+            stop_event.set()
+            if progress_bar:
+                # self.jobs_runner_status.stop_event.set()
+                if progress_thread:
+                    progress_thread.join()
+
+            if exception_to_raise:
+                raise exception_to_raise
+
+        return self.process_results(
+            raw_results=self.results, cache=cache, print_exceptions=print_exceptions
+        )
edsl/jobs/runners/JobsRunnerStatus.py
CHANGED
@@ -265,14 +265,15 @@ class JobsRunnerStatus:
             table.add_row(pretty_name, value)
         return table

-    def update_progress(self):
+    def update_progress(self, stop_event):
         layout, progress, task_ids = self.generate_layout()

         with Live(
             layout, refresh_per_second=int(1 / self.refresh_rate), transient=True
         ) as live:
-            while
-                self.jobs_runner.total_interviews
+            while (
+                len(self.completed_interviews) < len(self.jobs_runner.total_interviews)
+                and not stop_event.is_set()
             ):
                 completed_tasks = len(self.completed_interviews)
                 total_tasks = len(self.jobs_runner.total_interviews)
edsl/jobs/tasks/QuestionTaskCreator.py
CHANGED
@@ -156,19 +156,6 @@ class QuestionTaskCreator(UserList):
         self.tokens_bucket.turbo_mode_off()
         self.requests_bucket.turbo_mode_off()

-        # breakpoint()
-        # _ = results.pop("cached_response", None)
-
-        # tracker = self.cached_token_usage if self.from_cache else self.new_token_usage
-
-        # TODO: This is hacky. The 'func' call should return an object that definitely has a 'usage' key.
-        # usage = results.get("usage", {"prompt_tokens": 0, "completion_tokens": 0})
-        # prompt_tokens = usage.get("prompt_tokens", 0)
-        # completion_tokens = usage.get("completion_tokens", 0)
-        # tracker.add_tokens(
-        #     prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
-        # )
-
         return results

     @classmethod
@@ -249,6 +236,7 @@ class QuestionTaskCreator(UserList):
                 f"Required tasks failed for {self.question.question_name}"
             ) from e

+        # this only runs if all the dependencies are successful
         return await self._run_focal_task()

