edsl 0.1.33.dev3__py3-none-any.whl → 0.1.34.dev1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registries. It is provided for informational purposes only.
- edsl/Base.py +15 -11
- edsl/__version__.py +1 -1
- edsl/agents/Invigilator.py +22 -3
- edsl/agents/PromptConstructor.py +79 -183
- edsl/agents/prompt_helpers.py +129 -0
- edsl/coop/coop.py +3 -2
- edsl/data_transfer_models.py +0 -1
- edsl/inference_services/AnthropicService.py +5 -2
- edsl/inference_services/AwsBedrock.py +5 -2
- edsl/inference_services/AzureAI.py +5 -2
- edsl/inference_services/GoogleService.py +108 -33
- edsl/inference_services/MistralAIService.py +5 -2
- edsl/inference_services/OpenAIService.py +3 -2
- edsl/inference_services/TestService.py +11 -2
- edsl/inference_services/TogetherAIService.py +1 -1
- edsl/jobs/interviews/Interview.py +19 -9
- edsl/jobs/runners/JobsRunnerAsyncio.py +37 -16
- edsl/jobs/runners/JobsRunnerStatus.py +4 -3
- edsl/jobs/tasks/QuestionTaskCreator.py +1 -13
- edsl/language_models/LanguageModel.py +12 -9
- edsl/language_models/utilities.py +3 -2
- edsl/questions/QuestionBase.py +11 -2
- edsl/questions/QuestionBaseGenMixin.py +28 -0
- edsl/questions/QuestionCheckBox.py +1 -1
- edsl/questions/QuestionMultipleChoice.py +5 -1
- edsl/questions/ResponseValidatorABC.py +5 -1
- edsl/questions/descriptors.py +12 -11
- edsl/questions/templates/yes_no/answering_instructions.jinja +2 -2
- edsl/scenarios/FileStore.py +159 -71
- edsl/scenarios/Scenario.py +23 -49
- edsl/scenarios/ScenarioList.py +6 -2
- edsl/surveys/DAG.py +62 -0
- edsl/surveys/MemoryPlan.py +26 -0
- edsl/surveys/Rule.py +24 -0
- edsl/surveys/RuleCollection.py +36 -2
- edsl/surveys/Survey.py +182 -10
- {edsl-0.1.33.dev3.dist-info → edsl-0.1.34.dev1.dist-info}/METADATA +2 -1
- {edsl-0.1.33.dev3.dist-info → edsl-0.1.34.dev1.dist-info}/RECORD +40 -40
- edsl/scenarios/ScenarioImageMixin.py +0 -100
- {edsl-0.1.33.dev3.dist-info → edsl-0.1.34.dev1.dist-info}/LICENSE +0 -0
- {edsl-0.1.33.dev3.dist-info → edsl-0.1.34.dev1.dist-info}/WHEEL +0 -0
edsl/inference_services/GoogleService.py
CHANGED
@@ -1,25 +1,54 @@
 import os
-import …
-import …
-
+from typing import Any, Dict, List, Optional
+import google
+import google.generativeai as genai
+from google.generativeai.types import GenerationConfig
+from google.api_core.exceptions import InvalidArgument
+
 from edsl.exceptions import MissingAPIKeyError
 from edsl.language_models.LanguageModel import LanguageModel
-
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
 
+safety_settings = [
+    {
+        "category": "HARM_CATEGORY_HARASSMENT",
+        "threshold": "BLOCK_NONE",
+    },
+    {
+        "category": "HARM_CATEGORY_HATE_SPEECH",
+        "threshold": "BLOCK_NONE",
+    },
+    {
+        "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+        "threshold": "BLOCK_NONE",
+    },
+    {
+        "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+        "threshold": "BLOCK_NONE",
+    },
+]
+
 
 class GoogleService(InferenceServiceABC):
     _inference_service_ = "google"
     key_sequence = ["candidates", 0, "content", "parts", 0, "text"]
-    usage_sequence = ["…
-    input_token_name = "…
-    output_token_name = "…
+    usage_sequence = ["usage_metadata"]
+    input_token_name = "prompt_token_count"
+    output_token_name = "candidates_token_count"
 
     model_exclude_list = []
 
+    # @classmethod
+    # def available(cls) -> List[str]:
+    #     return ["gemini-pro", "gemini-1.5-pro", "gemini-1.5-flash", "gemini-1.0-pro"]
+
     @classmethod
-    def available(cls):
-        …
+    def available(cls) -> List[str]:
+        model_list = []
+        for m in genai.list_models():
+            if "generateContent" in m.supported_generation_methods:
+                model_list.append(m.name.split("/")[-1])
+        return model_list
 
     @classmethod
     def create_model(
@@ -47,33 +76,79 @@ class GoogleService(InferenceServiceABC):
                 "stopSequences": [],
            }
 
+            api_token = None
+            model = None
+
+            @classmethod
+            def initialize(cls):
+                if cls.api_token is None:
+                    cls.api_token = os.getenv("GOOGLE_API_KEY")
+                    if not cls.api_token:
+                        raise MissingAPIKeyError(
+                            "GOOGLE_API_KEY environment variable is not set"
+                        )
+                    genai.configure(api_key=cls.api_token)
+                    cls.generative_model = genai.GenerativeModel(
+                        cls._model_, safety_settings=safety_settings
+                    )
+
+            def __init__(self, *args, **kwargs):
+                super().__init__(*args, **kwargs)
+                self.initialize()
+
+            def get_generation_config(self) -> GenerationConfig:
+                return GenerationConfig(
+                    temperature=self.temperature,
+                    top_p=self.topP,
+                    top_k=self.topK,
+                    max_output_tokens=self.maxOutputTokens,
+                    stop_sequences=self.stopSequences,
+                )
+
             async def async_execute_model_call(
-                self,
-                …
-                …
-                …
-                …
-                …
-                data = {
-                    "contents": [{"parts": [{"text": combined_prompt}]}],
-                    "generationConfig": {
-                        "temperature": self.temperature,
-                        "topK": self.topK,
-                        "topP": self.topP,
-                        "maxOutputTokens": self.maxOutputTokens,
-                        "stopSequences": self.stopSequences,
-                    },
-                }
-                # print(combined_prompt)
-                async with aiohttp.ClientSession() as session:
-                    async with session.post(
-                        url, headers=headers, data=json.dumps(data)
-                    ) as response:
-                        raw_response_text = await response.text()
-                        return json.loads(raw_response_text)
+                self,
+                user_prompt: str,
+                system_prompt: str = "",
+                files_list: Optional["Files"] = None,
+            ) -> Dict[str, Any]:
+                generation_config = self.get_generation_config()
 
-
+                if files_list is None:
+                    files_list = []
+
+                if (
+                    system_prompt is not None
+                    and system_prompt != ""
+                    and self._model_ != "gemini-pro"
+                ):
+                    try:
+                        self.generative_model = genai.GenerativeModel(
+                            self._model_,
+                            safety_settings=safety_settings,
+                            system_instruction=system_prompt,
+                        )
+                    except InvalidArgument as e:
+                        print(
+                            f"This model, {self._model_}, does not support system_instruction"
+                        )
+                        print("Will add system_prompt to user_prompt")
+                        user_prompt = f"{system_prompt}\n{user_prompt}"
 
+                combined_prompt = [user_prompt]
+                for file in files_list:
+                    if "google" not in file.external_locations:
+                        _ = file.upload_google()
+                    gen_ai_file = google.generativeai.types.file_types.File(
+                        file.external_locations["google"]
+                    )
+                    combined_prompt.append(gen_ai_file)
+
+                response = await self.generative_model.generate_content_async(
+                    combined_prompt, generation_config=generation_config
+                )
+                return response.to_dict()
+
+        LLM.__name__ = model_name
         return LLM
 
 
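
For orientation, here is a minimal, standalone sketch of the google-generativeai call pattern the rewritten service relies on (SDK calls instead of the old raw aiohttp request). The model name and prompts are placeholders, not edsl defaults; edsl wires these values in through create_model and its nested LanguageModel subclass.

# Sketch of the SDK call pattern used by the new GoogleService.
# Assumes GOOGLE_API_KEY is set; model name and prompts are hypothetical.
import asyncio
import os

import google.generativeai as genai
from google.generativeai.types import GenerationConfig

genai.configure(api_key=os.environ["GOOGLE_API_KEY"])

model = genai.GenerativeModel(
    "gemini-1.5-flash",  # placeholder model choice
    safety_settings=[    # same shape as the module-level safety_settings list above
        {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
    ],
    system_instruction="Answer in one short sentence.",
)

config = GenerationConfig(temperature=0.5, max_output_tokens=128)

async def main() -> None:
    # generate_content_async takes a list of parts (text and uploaded files)
    response = await model.generate_content_async(["Say hello."], generation_config=config)
    print(response.to_dict())

asyncio.run(main())
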
edsl/inference_services/MistralAIService.py
CHANGED
@@ -1,5 +1,5 @@
 import os
-from typing import Any, List
+from typing import Any, List, Optional
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
 from edsl.language_models.LanguageModel import LanguageModel
 import asyncio
@@ -95,7 +95,10 @@ class MistralAIService(InferenceServiceABC):
             return cls.async_client()
 
         async def async_execute_model_call(
-            self, …
+            self,
+            user_prompt: str,
+            system_prompt: str = "",
+            files_list: Optional[List["FileStore"]] = None,
         ) -> dict[str, Any]:
             """Calls the Mistral API and returns the API response."""
             s = self.async_client()
edsl/inference_services/OpenAIService.py
CHANGED
@@ -168,13 +168,14 @@ class OpenAIService(InferenceServiceABC):
             self,
             user_prompt: str,
             system_prompt: str = "",
-            …
+            files_list: Optional[List["Files"]] = None,
             invigilator: Optional[
                 "InvigilatorAI"
             ] = None, # TBD - can eventually be used for function-calling
         ) -> dict[str, Any]:
             """Calls the OpenAI API and returns the API response."""
-            if …
+            if files_list:
+                encoded_image = files_list[0].base64_string
                 content = [{"type": "text", "text": user_prompt}]
                 content.append(
                     {
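
The OpenAI change swaps the old encoded-image argument for files_list, taking the base64 string from the first file. The snippet below is a sketch of the standard Chat Completions image payload that this kind of content list feeds into; the file name is a placeholder and this is not a verbatim copy of edsl's request builder.

# Sketch: a chat message whose content mixes text and a base64-encoded image,
# in the standard OpenAI "image_url" data-URI format.
import base64

with open("chart.png", "rb") as f:  # placeholder file
    encoded_image = base64.b64encode(f.read()).decode("utf-8")

content = [{"type": "text", "text": "What does this chart show?"}]
content.append(
    {
        "type": "image_url",
        "image_url": {"url": f"data:image/png;base64,{encoded_image}"},
    }
)

messages = [{"role": "user", "content": content}]
print(messages[0]["content"][0])
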
edsl/inference_services/TestService.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Any, List
+from typing import Any, List, Optional
 import os
 import asyncio
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
@@ -59,11 +59,20 @@ class TestService(InferenceServiceABC):
             self,
             user_prompt: str,
             system_prompt: str,
-            …
+            # func: Optional[callable] = None,
+            files_list: Optional[List["File"]] = None,
         ) -> dict[str, Any]:
             await asyncio.sleep(0.1)
             # return {"message": """{"answer": "Hello, world"}"""}
 
+            if hasattr(self, "func"):
+                return {
+                    "message": [
+                        {"text": self.func(user_prompt, system_prompt, files_list)}
+                    ],
+                    "usage": {"prompt_tokens": 1, "completion_tokens": 1},
+                }
+
             if hasattr(self, "throw_exception") and self.throw_exception:
                 if hasattr(self, "exception_probability"):
                     p = self.exception_probability
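
TestService now short-circuits through a func attribute when one is present, returning its output in the canned message/usage shape. A toy reproduction of that dispatch, outside of edsl (the FakeTestModel class is hypothetical):

# Toy reproduction of the TestService dispatch: if the instance has a `func`
# attribute, its return value becomes the canned "message" payload.
from typing import Any, Callable, List, Optional


class FakeTestModel:
    def __init__(self, func: Optional[Callable[..., str]] = None):
        if func is not None:
            self.func = func

    def execute(self, user_prompt: str, system_prompt: str,
                files_list: Optional[List[Any]] = None) -> dict:
        if hasattr(self, "func"):
            return {
                "message": [{"text": self.func(user_prompt, system_prompt, files_list)}],
                "usage": {"prompt_tokens": 1, "completion_tokens": 1},
            }
        # default canned answer when no func hook is attached
        return {
            "message": [{"text": '{"answer": "Hello, world"}'}],
            "usage": {"prompt_tokens": 1, "completion_tokens": 1},
        }


model = FakeTestModel(func=lambda u, s, f: u.upper())
print(model.execute("echo me", ""))  # {'message': [{'text': 'ECHO ME'}], ...}
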
edsl/jobs/interviews/Interview.py
CHANGED
@@ -105,9 +105,9 @@ class Interview(InterviewStatusMixin):
         self.debug = debug
         self.iteration = iteration
         self.cache = cache
-        self.answers: dict[…
-            …
-        …
+        self.answers: dict[str, str] = (
+            Answers()
+        ) # will get filled in as interview progresses
         self.sidecar_model = sidecar_model
 
         # self.stop_on_exception = False
@@ -248,17 +248,24 @@ class Interview(InterviewStatusMixin):
 
     def _get_estimated_request_tokens(self, question) -> float:
         """Estimate the number of tokens that will be required to run the focal task."""
+        from edsl.scenarios.FileStore import FileStore
+
         invigilator = self._get_invigilator(question=question)
         # TODO: There should be a way to get a more accurate estimate.
         combined_text = ""
+        file_tokens = 0
         for prompt in invigilator.get_prompts().values():
             if hasattr(prompt, "text"):
                 combined_text += prompt.text
             elif isinstance(prompt, str):
                 combined_text += prompt
+            elif isinstance(prompt, list):
+                for file in prompt:
+                    if isinstance(file, FileStore):
+                        file_tokens += file.size * 0.25
             else:
                 raise ValueError(f"Prompt is of type {type(prompt)}")
-        return len(combined_text) / 4.0
+        return len(combined_text) / 4.0 + file_tokens
 
     async def _answer_question_and_record_task(
         self,
@@ -296,6 +303,9 @@ class Interview(InterviewStatusMixin):
             self.answers.add_answer(response=response, question=question)
             self._cancel_skipped_questions(question)
         else:
+            # When a question is not validated, it is not added to the answers.
+            # this should also cancel and dependent children questions.
+            # Is that happening now?
             if (
                 hasattr(response, "exception_occurred")
                 and response.exception_occurred
@@ -418,11 +428,11 @@ class Interview(InterviewStatusMixin):
         """
         current_question_index: int = self.to_index[current_question.question_name]
 
-        next_question: Union[…
-            …
-            …
-            …
-            …
+        next_question: Union[int, EndOfSurvey] = (
+            self.survey.rule_collection.next_question(
+                q_now=current_question_index,
+                answers=self.answers | self.scenario | self.agent["traits"],
+            )
         )
 
         next_question_index = next_question.next_q
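
The token estimate in _get_estimated_request_tokens now adds a per-file term: roughly one token per four characters of prompt text plus 0.25 tokens per byte of attached file. A self-contained sketch of that heuristic (FakeFile is a hypothetical stand-in for edsl's FileStore; only its size matters here):

# Sketch of the request-token heuristic: len(text) / 4 plus file.size * 0.25.
from dataclasses import dataclass
from typing import List, Union


@dataclass
class FakeFile:   # stand-in for edsl's FileStore
    size: int     # file size in bytes


def estimate_request_tokens(prompts: List[Union[str, List[FakeFile]]]) -> float:
    combined_text = ""
    file_tokens = 0.0
    for prompt in prompts:
        if isinstance(prompt, str):
            combined_text += prompt
        elif isinstance(prompt, list):
            for file in prompt:
                file_tokens += file.size * 0.25
        else:
            raise ValueError(f"Prompt is of type {type(prompt)}")
    return len(combined_text) / 4.0 + file_tokens


print(estimate_request_tokens(["Hello " * 100, [FakeFile(size=4000)]]))  # 150.0 + 1000.0
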
edsl/jobs/runners/JobsRunnerAsyncio.py
CHANGED
@@ -8,10 +8,10 @@ from typing import Coroutine, List, AsyncGenerator, Optional, Union, Generator
 from contextlib import contextmanager
 from collections import UserList
 
-from edsl.results.Results import Results
 from rich.live import Live
 from rich.console import Console
 
+from edsl.results.Results import Results
 from edsl import shared_globals
 from edsl.jobs.interviews.Interview import Interview
 from edsl.jobs.runners.JobsRunnerStatus import JobsRunnerStatus
@@ -22,6 +22,8 @@ from edsl.utilities.decorators import jupyter_nb_handler
 from edsl.data.Cache import Cache
 from edsl.results.Result import Result
 from edsl.results.Results import Results
+from edsl.language_models.LanguageModel import LanguageModel
+from edsl.data.Cache import Cache
 
 
 class StatusTracker(UserList):
@@ -50,10 +52,10 @@ class JobsRunnerAsyncio:
 
     async def run_async_generator(
         self,
-        cache: …
+        cache: Cache,
         n: int = 1,
         stop_on_exception: bool = False,
-        sidecar_model: Optional[…
+        sidecar_model: Optional[LanguageModel] = None,
         total_interviews: Optional[List["Interview"]] = None,
         raise_validation_errors: bool = False,
     ) -> AsyncGenerator["Result", None]:
@@ -104,7 +106,7 @@ class JobsRunnerAsyncio:
             interview.cache = self.cache
             yield interview
 
-    async def run_async(self, cache: Optional[…
+    async def run_async(self, cache: Optional[Cache] = None, n: int = 1) -> Results:
         """Used for some other modules that have a non-standard way of running interviews."""
         self.jobs_runner_status = JobsRunnerStatus(self, n=n)
         self.cache = Cache() if cache is None else cache
@@ -291,6 +293,8 @@ class JobsRunnerAsyncio:
 
         self.jobs_runner_status = JobsRunnerStatus(self, n=n)
 
+        stop_event = threading.Event()
+
         async def process_results(cache):
             """Processes results from interviews."""
             async for result in self.run_async_generator(
@@ -303,20 +307,37 @@ class JobsRunnerAsyncio:
                 self.results.append(result)
             self.completed = True
 
-        def run_progress_bar():
+        def run_progress_bar(stop_event):
             """Runs the progress bar in a separate thread."""
-            self.jobs_runner_status.update_progress()
+            self.jobs_runner_status.update_progress(stop_event)
 
         if progress_bar:
-            progress_thread = threading.Thread(…
+            progress_thread = threading.Thread(
+                target=run_progress_bar, args=(stop_event,)
+            )
             progress_thread.start()
 
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
+        exception_to_raise = None
+        try:
+            with cache as c:
+                await process_results(cache=c)
+        except KeyboardInterrupt:
+            print("Keyboard interrupt received. Stopping gracefully...")
+            stop_event.set()
+        except Exception as e:
+            if stop_on_exception:
+                exception_to_raise = e
+            stop_event.set()
+        finally:
+            stop_event.set()
+            if progress_bar:
+                # self.jobs_runner_status.stop_event.set()
+                if progress_thread:
+                    progress_thread.join()
+
+            if exception_to_raise:
+                raise exception_to_raise
+
+            return self.process_results(
+                raw_results=self.results, cache=cache, print_exceptions=print_exceptions
+            )
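
The runner now threads a threading.Event through the progress-bar thread so a KeyboardInterrupt (or a stop-on-exception error) can shut the bar down cleanly before the thread is joined. A minimal sketch of that pattern, without the Rich layout or the interview machinery:

# Minimal sketch of the stop_event pattern: the worker loop polls the event,
# the main thread sets it on interrupt/exception and then joins the thread.
import threading
import time


def run_progress_bar(stop_event: threading.Event) -> None:
    while not stop_event.is_set():
        print("working...")
        time.sleep(0.2)


stop_event = threading.Event()
progress_thread = threading.Thread(target=run_progress_bar, args=(stop_event,))
progress_thread.start()

try:
    time.sleep(1)  # stand-in for awaiting the real result-processing coroutine
except KeyboardInterrupt:
    print("Keyboard interrupt received. Stopping gracefully...")
finally:
    stop_event.set()        # always signal the worker thread
    progress_thread.join()  # then wait for it to exit
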
edsl/jobs/runners/JobsRunnerStatus.py
CHANGED
@@ -265,14 +265,15 @@ class JobsRunnerStatus:
             table.add_row(pretty_name, value)
         return table
 
-    def update_progress(self):
+    def update_progress(self, stop_event):
         layout, progress, task_ids = self.generate_layout()
 
         with Live(
             layout, refresh_per_second=int(1 / self.refresh_rate), transient=True
         ) as live:
-            while …
-                self.jobs_runner.total_interviews …
+            while (
+                len(self.completed_interviews) < len(self.jobs_runner.total_interviews)
+                and not stop_event.is_set()
             ):
                 completed_tasks = len(self.completed_interviews)
                 total_tasks = len(self.jobs_runner.total_interviews)
edsl/jobs/tasks/QuestionTaskCreator.py
CHANGED
@@ -156,19 +156,6 @@ class QuestionTaskCreator(UserList):
         self.tokens_bucket.turbo_mode_off()
         self.requests_bucket.turbo_mode_off()
 
-        # breakpoint()
-        # _ = results.pop("cached_response", None)
-
-        # tracker = self.cached_token_usage if self.from_cache else self.new_token_usage
-
-        # TODO: This is hacky. The 'func' call should return an object that definitely has a 'usage' key.
-        # usage = results.get("usage", {"prompt_tokens": 0, "completion_tokens": 0})
-        # prompt_tokens = usage.get("prompt_tokens", 0)
-        # completion_tokens = usage.get("completion_tokens", 0)
-        # tracker.add_tokens(
-        #     prompt_tokens=prompt_tokens, completion_tokens=completion_tokens
-        # )
-
         return results
 
     @classmethod
@@ -249,6 +236,7 @@ class QuestionTaskCreator(UserList):
                 f"Required tasks failed for {self.question.question_name}"
             ) from e
 
+        # this only runs if all the dependencies are successful
         return await self._run_focal_task()
edsl/language_models/LanguageModel.py
CHANGED
@@ -440,7 +440,7 @@ class LanguageModel(
         system_prompt: str,
         cache: "Cache",
         iteration: int = 0,
-        …
+        files_list=None,
     ) -> ModelResponse:
         """Handle caching of responses.
 
@@ -462,16 +462,18 @@ class LanguageModel(
         >>> m._get_intended_model_call_outcome(user_prompt = "Hello", system_prompt = "hello", cache = Cache())
         ModelResponse(...)"""
 
-        if …
-            …
-            …
-        user_prompt …
+        if files_list:
+            files_hash = "+".join([str(hash(file)) for file in files_list])
+            # print(f"Files hash: {files_hash}")
+            user_prompt_with_hashes = user_prompt + f" {files_hash}"
+        else:
+            user_prompt_with_hashes = user_prompt
 
         cache_call_params = {
             "model": str(self.model),
             "parameters": self.parameters,
             "system_prompt": system_prompt,
-            "user_prompt": …
+            "user_prompt": user_prompt_with_hashes,
             "iteration": iteration,
         }
         cached_response, cache_key = cache.fetch(**cache_call_params)
@@ -487,7 +489,8 @@ class LanguageModel(
         params = {
             "user_prompt": user_prompt,
             "system_prompt": system_prompt,
-            …
+            "files_list": files_list
+            #**({"encoded_image": encoded_image} if encoded_image else {}),
         }
         # response = await f(**params)
         response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
@@ -531,7 +534,7 @@ class LanguageModel(
         system_prompt: str,
         cache: "Cache",
         iteration: int = 1,
-        …
+        files_list: Optional[List['File']] = None,
     ) -> dict:
         """Get response, parse, and return as string.
 
@@ -547,7 +550,7 @@ class LanguageModel(
             "system_prompt": system_prompt,
             "iteration": iteration,
             "cache": cache,
-            …
+            "files_list": files_list,
         }
         model_inputs = ModelInputs(user_prompt=user_prompt, system_prompt=system_prompt)
         model_outputs = await self._async_get_intended_model_call_outcome(**params)
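
Caching now folds a hash of every attached file into the key, so the same prompt with different files is not served a stale cached response. A sketch of that key construction (the real lookup goes through edsl's Cache.fetch; here the parameters are just returned as a dictionary):

# Sketch: appending per-file hashes to the user prompt before building the
# cache-lookup parameters, mirroring the change above.
from typing import Any, List, Optional


def build_cache_call_params(
    model: str,
    parameters: dict,
    system_prompt: str,
    user_prompt: str,
    iteration: int = 0,
    files_list: Optional[List[Any]] = None,
) -> dict:
    if files_list:
        files_hash = "+".join(str(hash(file)) for file in files_list)
        user_prompt_with_hashes = user_prompt + f" {files_hash}"
    else:
        user_prompt_with_hashes = user_prompt
    return {
        "model": model,
        "parameters": parameters,
        "system_prompt": system_prompt,
        "user_prompt": user_prompt_with_hashes,
        "iteration": iteration,
    }


# Hypothetical usage: the string stands in for a hashable file object.
print(build_cache_call_params("gpt-4o", {}, "sys", "Describe this file", files_list=["fake-file"]))
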
edsl/language_models/utilities.py
CHANGED
@@ -1,5 +1,5 @@
 import asyncio
-from typing import Any
+from typing import Any, Optional, List
 from edsl import Survey
 from edsl.config import CONFIG
 from edsl.enums import InferenceServiceType
@@ -40,7 +40,8 @@ def create_language_model(
         _tpm = 1000000000000
 
         async def async_execute_model_call(
-            self, user_prompt: str, system_prompt: str
+            self, user_prompt: str, system_prompt: str,
+            files_list: Optional[List[Any]] = None
         ) -> dict[str, Any]:
             question_number = int(
                 user_prompt.split("XX")[1]
edsl/questions/QuestionBase.py
CHANGED
@@ -44,6 +44,13 @@ class QuestionBase(
     _answering_instructions = None
     _question_presentation = None
 
+    @property
+    def response_model(self) -> type["BaseModel"]:
+        if self._response_model is not None:
+            return self._response_model
+        else:
+            return self.create_response_model()
+
     # region: Validation and simulation methods
     @property
     def response_validator(self) -> "ResponseValidatorBase":
@@ -98,7 +105,9 @@ class QuestionBase(
         comment: Optional[str]
         generated_tokens: Optional[str]
 
-    def _validate_answer(…
+    def _validate_answer(
+        self, answer: dict, replacement_dict: dict = None
+    ) -> ValidatedAnswer:
         """Validate the answer.
         >>> from edsl.exceptions import QuestionAnswerValidationError
         >>> from edsl import QuestionFreeText as Q
@@ -106,7 +115,7 @@ class QuestionBase(
         {'answer': 'Hello', 'generated_tokens': 'Hello'}
         """
 
-        return self.response_validator.validate(answer)
+        return self.response_validator.validate(answer, replacement_dict)
 
     # endregion
 
edsl/questions/QuestionBaseGenMixin.py
CHANGED
@@ -95,6 +95,34 @@ class QuestionBaseGenMixin:
             questions.append(QuestionBase.from_dict(new_data))
         return questions
 
+    def render(self, replacement_dict: dict) -> "QuestionBase":
+        """Render the question components as jinja2 templates with the replacement dictionary."""
+        from jinja2 import Environment
+        from edsl import Scenario
+
+        strings_only_replacement_dict = {
+            k: v for k, v in replacement_dict.items() if not isinstance(v, Scenario)
+        }
+
+        def render_string(value: str) -> str:
+            if value is None or not isinstance(value, str):
+                return value
+            else:
+                try:
+                    return (
+                        Environment()
+                        .from_string(value)
+                        .render(strings_only_replacement_dict)
+                    )
+                except Exception as e:
+                    import warnings
+
+                    warnings.warn("Failed to render string: " + value)
+                    # breakpoint()
+                    return value
+
+        return self.apply_function(render_string)
+
     def apply_function(self, func: Callable, exclude_components=None) -> QuestionBase:
         """Apply a function to the question parts
 
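
The new QuestionBaseGenMixin.render pushes each string component of a question through a jinja2 template with a replacement dictionary, skipping Scenario values and falling back to the original string (with a warning) if rendering fails. The core of that behavior, as a standalone snippet:

# Sketch of the per-string jinja2 rendering used by the new render() method:
# non-strings pass through; failures warn and return the original value.
import warnings
from jinja2 import Environment


def render_string(value, replacement_dict: dict):
    if value is None or not isinstance(value, str):
        return value
    try:
        return Environment().from_string(value).render(replacement_dict)
    except Exception:
        warnings.warn("Failed to render string: " + value)
        return value


print(render_string("What do you think of {{ topic }}?", {"topic": "pizza"}))
# -> What do you think of pizza?
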
edsl/questions/QuestionCheckBox.py
CHANGED
@@ -245,7 +245,7 @@ class QuestionCheckBox(QuestionBase):
 
         scenario = scenario or Scenario()
         translated_options = [
-            Template(option).render(scenario) for option in self.question_options
+            Template(str(option)).render(scenario) for option in self.question_options
         ]
         translated_codes = []
         for answer_code in answer_codes:
edsl/questions/QuestionMultipleChoice.py
CHANGED
@@ -163,7 +163,11 @@ class QuestionMultipleChoice(QuestionBase):
     # Answer methods
     ################
 
-    def create_response_model(self):
+    def create_response_model(self, replacement_dict: dict = None):
+        if replacement_dict is None:
+            replacement_dict = {}
+        # The replacement dict that could be from scenario, current answers, etc. to populate the response model
+
         if self.use_code:
             return create_response_model(
                 list(range(len(self.question_options))), self.permissive
edsl/questions/ResponseValidatorABC.py
CHANGED
@@ -92,7 +92,11 @@ class ResponseValidatorABC(ABC):
         generated_tokens: Optional[str]
 
     def validate(
-        self, …
+        self,
+        raw_edsl_answer_dict: RawEdslAnswerDict,
+        fix=False,
+        verbose=False,
+        replacement_dict: dict = None,
     ) -> EdslAnswerDict:
         """This is the main validation function.
 