edsl 0.1.33.dev2__py3-none-any.whl → 0.1.33.dev3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +9 -3
- edsl/__init__.py +1 -0
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +6 -6
- edsl/agents/Invigilator.py +6 -3
- edsl/agents/InvigilatorBase.py +8 -27
- edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +101 -29
- edsl/config.py +26 -34
- edsl/coop/coop.py +11 -2
- edsl/data_transfer_models.py +27 -73
- edsl/enums.py +2 -0
- edsl/inference_services/GoogleService.py +1 -1
- edsl/inference_services/InferenceServiceABC.py +44 -13
- edsl/inference_services/OpenAIService.py +7 -4
- edsl/inference_services/TestService.py +24 -15
- edsl/inference_services/TogetherAIService.py +170 -0
- edsl/inference_services/registry.py +2 -0
- edsl/jobs/Jobs.py +18 -8
- edsl/jobs/buckets/BucketCollection.py +24 -15
- edsl/jobs/buckets/TokenBucket.py +64 -10
- edsl/jobs/interviews/Interview.py +115 -47
- edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +16 -0
- edsl/jobs/interviews/InterviewExceptionEntry.py +2 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +86 -161
- edsl/jobs/runners/JobsRunnerStatus.py +331 -0
- edsl/jobs/tasks/TaskHistory.py +17 -0
- edsl/language_models/LanguageModel.py +26 -31
- edsl/language_models/registry.py +13 -9
- edsl/questions/QuestionBase.py +64 -16
- edsl/questions/QuestionBudget.py +93 -41
- edsl/questions/QuestionFreeText.py +6 -0
- edsl/questions/QuestionMultipleChoice.py +11 -26
- edsl/questions/QuestionNumerical.py +5 -4
- edsl/questions/Quick.py +41 -0
- edsl/questions/ResponseValidatorABC.py +6 -5
- edsl/questions/derived/QuestionLinearScale.py +4 -1
- edsl/questions/derived/QuestionTopK.py +4 -1
- edsl/questions/derived/QuestionYesNo.py +8 -2
- edsl/questions/templates/budget/__init__.py +0 -0
- edsl/questions/templates/budget/answering_instructions.jinja +7 -0
- edsl/questions/templates/budget/question_presentation.jinja +7 -0
- edsl/questions/templates/extract/__init__.py +0 -0
- edsl/questions/templates/rank/__init__.py +0 -0
- edsl/results/DatasetExportMixin.py +5 -1
- edsl/results/Result.py +1 -1
- edsl/results/Results.py +4 -1
- edsl/scenarios/FileStore.py +71 -10
- edsl/scenarios/Scenario.py +86 -21
- edsl/scenarios/ScenarioImageMixin.py +2 -2
- edsl/scenarios/ScenarioList.py +13 -0
- edsl/scenarios/ScenarioListPdfMixin.py +150 -4
- edsl/study/Study.py +32 -0
- edsl/surveys/Rule.py +10 -1
- edsl/surveys/RuleCollection.py +19 -3
- edsl/surveys/Survey.py +7 -0
- edsl/templates/error_reporting/interview_details.html +6 -1
- edsl/utilities/utilities.py +9 -1
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.33.dev3.dist-info}/METADATA +2 -1
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.33.dev3.dist-info}/RECORD +61 -55
- edsl/jobs/interviews/retry_management.py +0 -39
- edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -333
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.33.dev3.dist-info}/LICENSE +0 -0
- {edsl-0.1.33.dev2.dist-info → edsl-0.1.33.dev3.dist-info}/WHEEL +0 -0
edsl/Base.py
CHANGED
@@ -47,21 +47,27 @@ class PersistenceMixin:
|
|
47
47
|
self,
|
48
48
|
description: Optional[str] = None,
|
49
49
|
visibility: Optional[str] = "unlisted",
|
50
|
+
expected_parrot_url: Optional[str] = None,
|
50
51
|
):
|
51
52
|
"""Post the object to coop."""
|
52
53
|
from edsl.coop import Coop
|
53
54
|
|
54
|
-
c = Coop()
|
55
|
+
c = Coop(url=expected_parrot_url)
|
55
56
|
return c.create(self, description, visibility)
|
56
57
|
|
57
58
|
@classmethod
|
58
|
-
def pull(
|
59
|
+
def pull(
|
60
|
+
cls,
|
61
|
+
uuid: Optional[Union[str, UUID]] = None,
|
62
|
+
url: Optional[str] = None,
|
63
|
+
expected_parrot_url: Optional[str] = None,
|
64
|
+
):
|
59
65
|
"""Pull the object from coop."""
|
60
66
|
from edsl.coop import Coop
|
61
67
|
from edsl.coop.utils import ObjectRegistry
|
62
68
|
|
63
69
|
object_type = ObjectRegistry.get_object_type_by_edsl_class(cls)
|
64
|
-
coop = Coop()
|
70
|
+
coop = Coop(url=expected_parrot_url)
|
65
71
|
return coop.get(uuid, url, object_type)
|
66
72
|
|
67
73
|
@classmethod
|
edsl/__init__.py
CHANGED
@@ -23,6 +23,7 @@ from edsl.questions import QuestionNumerical
|
|
23
23
|
from edsl.questions import QuestionYesNo
|
24
24
|
from edsl.questions import QuestionBudget
|
25
25
|
from edsl.questions import QuestionRank
|
26
|
+
from edsl.questions import QuestionTopK
|
26
27
|
|
27
28
|
from edsl.scenarios import Scenario
|
28
29
|
from edsl.scenarios import ScenarioList
|
edsl/__version__.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.1.33.
|
1
|
+
__version__ = "0.1.33.dev3"
|
edsl/agents/Agent.py
CHANGED
@@ -586,9 +586,9 @@ class Agent(Base):
|
|
586
586
|
if dynamic_traits_func:
|
587
587
|
func = inspect.getsource(dynamic_traits_func)
|
588
588
|
raw_data["dynamic_traits_function_source_code"] = func
|
589
|
-
raw_data[
|
590
|
-
|
591
|
-
|
589
|
+
raw_data[
|
590
|
+
"dynamic_traits_function_name"
|
591
|
+
] = self.dynamic_traits_function_name
|
592
592
|
if hasattr(self, "answer_question_directly"):
|
593
593
|
raw_data.pop(
|
594
594
|
"answer_question_directly", None
|
@@ -604,9 +604,9 @@ class Agent(Base):
|
|
604
604
|
raw_data["answer_question_directly_source_code"] = inspect.getsource(
|
605
605
|
answer_question_directly_func
|
606
606
|
)
|
607
|
-
raw_data[
|
608
|
-
|
609
|
-
|
607
|
+
raw_data[
|
608
|
+
"answer_question_directly_function_name"
|
609
|
+
] = self.answer_question_directly_function_name
|
610
610
|
|
611
611
|
return raw_data
|
612
612
|
|
edsl/agents/Invigilator.py
CHANGED
@@ -2,14 +2,13 @@
|
|
2
2
|
|
3
3
|
from typing import Dict, Any, Optional
|
4
4
|
|
5
|
-
from edsl.exceptions import AgentRespondedWithBadJSONError
|
6
5
|
from edsl.prompts.Prompt import Prompt
|
7
6
|
from edsl.utilities.decorators import sync_wrapper, jupyter_nb_handler
|
8
7
|
from edsl.prompts.registry import get_classes as prompt_lookup
|
9
8
|
from edsl.exceptions.questions import QuestionAnswerValidationError
|
10
|
-
from edsl.agents.PromptConstructionMixin import PromptConstructorMixin
|
11
9
|
from edsl.agents.InvigilatorBase import InvigilatorBase
|
12
10
|
from edsl.data_transfer_models import AgentResponseDict, EDSLResultObjectInput
|
11
|
+
from edsl.agents.PromptConstructor import PromptConstructor
|
13
12
|
|
14
13
|
|
15
14
|
class NotApplicable(str):
|
@@ -19,9 +18,13 @@ class NotApplicable(str):
|
|
19
18
|
return instance
|
20
19
|
|
21
20
|
|
22
|
-
class InvigilatorAI(
|
21
|
+
class InvigilatorAI(InvigilatorBase):
|
23
22
|
"""An invigilator that uses an AI model to answer questions."""
|
24
23
|
|
24
|
+
def get_prompts(self) -> Dict[str, Prompt]:
|
25
|
+
"""Return the prompts used."""
|
26
|
+
return self.prompt_constructor.get_prompts()
|
27
|
+
|
25
28
|
async def async_answer_question(self) -> AgentResponseDict:
|
26
29
|
"""Answer a question using the AI model.
|
27
30
|
|
edsl/agents/InvigilatorBase.py
CHANGED
@@ -14,6 +14,7 @@ from edsl.surveys.MemoryPlan import MemoryPlan
|
|
14
14
|
from edsl.language_models.LanguageModel import LanguageModel
|
15
15
|
|
16
16
|
from edsl.data_transfer_models import EDSLResultObjectInput
|
17
|
+
from edsl.agents.PromptConstructor import PromptConstructor
|
17
18
|
|
18
19
|
|
19
20
|
class InvigilatorBase(ABC):
|
@@ -27,16 +28,7 @@ class InvigilatorBase(ABC):
|
|
27
28
|
|
28
29
|
This returns an empty prompt because there is no memory the agent needs to have at q0.
|
29
30
|
|
30
|
-
>>> InvigilatorBase.example().create_memory_prompt("q0")
|
31
|
-
Prompt(text=\"""\""")
|
32
31
|
|
33
|
-
>>> i = InvigilatorBase.example()
|
34
|
-
>>> i.current_answers = {"q0": "Prior answer"}
|
35
|
-
>>> i.memory_plan.add_single_memory("q1", "q0")
|
36
|
-
>>> i.create_memory_prompt("q1")
|
37
|
-
Prompt(text=\"""
|
38
|
-
Before the question you are now answering, you already answered the following question(s):
|
39
|
-
...
|
40
32
|
"""
|
41
33
|
|
42
34
|
def __init__(
|
@@ -72,6 +64,11 @@ class InvigilatorBase(ABC):
|
|
72
64
|
None # placeholder for the raw response from the model
|
73
65
|
)
|
74
66
|
|
67
|
+
@property
|
68
|
+
def prompt_constructor(self) -> PromptConstructor:
|
69
|
+
"""Return the prompt constructor."""
|
70
|
+
return PromptConstructor(self)
|
71
|
+
|
75
72
|
def to_dict(self):
|
76
73
|
attributes = [
|
77
74
|
"agent",
|
@@ -207,22 +204,6 @@ class InvigilatorBase(ABC):
|
|
207
204
|
|
208
205
|
return main()
|
209
206
|
|
210
|
-
def create_memory_prompt(self, question_name: str) -> Prompt:
|
211
|
-
"""Create a memory for the agent.
|
212
|
-
|
213
|
-
The returns a memory prompt for the agent.
|
214
|
-
|
215
|
-
>>> i = InvigilatorBase.example()
|
216
|
-
>>> i.current_answers = {"q0": "Prior answer"}
|
217
|
-
>>> i.memory_plan.add_single_memory("q1", "q0")
|
218
|
-
>>> p = i.create_memory_prompt("q1")
|
219
|
-
>>> p.text.strip().replace("\\n", " ").replace("\\t", " ")
|
220
|
-
'Before the question you are now answering, you already answered the following question(s): Question: Do you like school? Answer: Prior answer'
|
221
|
-
"""
|
222
|
-
return self.memory_plan.get_memory_prompt_fragment(
|
223
|
-
question_name, self.current_answers
|
224
|
-
)
|
225
|
-
|
226
207
|
@classmethod
|
227
208
|
def example(
|
228
209
|
cls, throw_an_exception=False, question=None, scenario=None, survey=None
|
@@ -285,9 +266,9 @@ class InvigilatorBase(ABC):
|
|
285
266
|
|
286
267
|
memory_plan = MemoryPlan(survey=survey)
|
287
268
|
current_answers = None
|
288
|
-
from edsl.agents.
|
269
|
+
from edsl.agents.PromptConstructor import PromptConstructor
|
289
270
|
|
290
|
-
class InvigilatorExample(
|
271
|
+
class InvigilatorExample(InvigilatorBase):
|
291
272
|
"""An example invigilator."""
|
292
273
|
|
293
274
|
async def async_answer_question(self):
|
@@ -1,16 +1,15 @@
|
|
1
1
|
from __future__ import annotations
|
2
|
-
from typing import Dict, Any, Optional
|
2
|
+
from typing import Dict, Any, Optional, Set
|
3
3
|
from collections import UserList
|
4
|
+
import enum
|
4
5
|
|
5
|
-
|
6
|
-
from edsl.prompts.Prompt import Prompt
|
6
|
+
from jinja2 import Environment, meta
|
7
7
|
|
8
|
-
|
8
|
+
from edsl.prompts.Prompt import Prompt
|
9
|
+
from edsl.data_transfer_models import ImageInfo
|
9
10
|
from edsl.prompts.registry import get_classes as prompt_lookup
|
10
11
|
from edsl.exceptions import QuestionScenarioRenderError
|
11
12
|
|
12
|
-
import enum
|
13
|
-
|
14
13
|
|
15
14
|
class PromptComponent(enum.Enum):
|
16
15
|
AGENT_INSTRUCTIONS = "agent_instructions"
|
@@ -19,6 +18,21 @@ class PromptComponent(enum.Enum):
|
|
19
18
|
PRIOR_QUESTION_MEMORY = "prior_question_memory"
|
20
19
|
|
21
20
|
|
21
|
+
def get_jinja2_variables(template_str: str) -> Set[str]:
|
22
|
+
"""
|
23
|
+
Extracts all variable names from a Jinja2 template using Jinja2's built-in parsing.
|
24
|
+
|
25
|
+
Args:
|
26
|
+
template_str (str): The Jinja2 template string
|
27
|
+
|
28
|
+
Returns:
|
29
|
+
Set[str]: A set of variable names found in the template
|
30
|
+
"""
|
31
|
+
env = Environment()
|
32
|
+
ast = env.parse(template_str)
|
33
|
+
return meta.find_undeclared_variables(ast)
|
34
|
+
|
35
|
+
|
22
36
|
class PromptList(UserList):
|
23
37
|
separator = Prompt(" ")
|
24
38
|
|
@@ -137,7 +151,7 @@ class PromptPlan:
|
|
137
151
|
}
|
138
152
|
|
139
153
|
|
140
|
-
class
|
154
|
+
class PromptConstructor:
|
141
155
|
"""Mixin for constructing prompts for the LLM call.
|
142
156
|
|
143
157
|
The pieces of a prompt are:
|
@@ -149,16 +163,40 @@ class PromptConstructorMixin:
|
|
149
163
|
This is mixed into the Invigilator class.
|
150
164
|
"""
|
151
165
|
|
152
|
-
|
166
|
+
def __init__(self, invigilator):
|
167
|
+
self.invigilator = invigilator
|
168
|
+
self.agent = invigilator.agent
|
169
|
+
self.question = invigilator.question
|
170
|
+
self.scenario = invigilator.scenario
|
171
|
+
self.survey = invigilator.survey
|
172
|
+
self.model = invigilator.model
|
173
|
+
self.current_answers = invigilator.current_answers
|
174
|
+
self.memory_plan = invigilator.memory_plan
|
175
|
+
self.prompt_plan = PromptPlan() # Assuming PromptPlan is defined elsewhere
|
176
|
+
|
177
|
+
# prompt_plan = PromptPlan()
|
178
|
+
|
179
|
+
@property
|
180
|
+
def scenario_image_keys(self):
|
181
|
+
image_entries = []
|
182
|
+
|
183
|
+
for key, value in self.scenario.items():
|
184
|
+
if isinstance(value, ImageInfo):
|
185
|
+
image_entries.append(key)
|
186
|
+
return image_entries
|
153
187
|
|
154
188
|
@property
|
155
189
|
def agent_instructions_prompt(self) -> Prompt:
|
156
190
|
"""
|
157
191
|
>>> from edsl.agents.InvigilatorBase import InvigilatorBase
|
158
192
|
>>> i = InvigilatorBase.example()
|
159
|
-
>>> i.agent_instructions_prompt
|
193
|
+
>>> i.prompt_constructor.agent_instructions_prompt
|
160
194
|
Prompt(text=\"""You are answering questions as if you were a human. Do not break character.\""")
|
161
195
|
"""
|
196
|
+
from edsl import Agent
|
197
|
+
|
198
|
+
if self.agent == Agent(): # if agent is empty, then return an empty prompt
|
199
|
+
return Prompt(text="")
|
162
200
|
if not hasattr(self, "_agent_instructions_prompt"):
|
163
201
|
applicable_prompts = prompt_lookup(
|
164
202
|
component_type="agent_instructions",
|
@@ -176,12 +214,17 @@ class PromptConstructorMixin:
|
|
176
214
|
"""
|
177
215
|
>>> from edsl.agents.InvigilatorBase import InvigilatorBase
|
178
216
|
>>> i = InvigilatorBase.example()
|
179
|
-
>>> i.agent_persona_prompt
|
217
|
+
>>> i.prompt_constructor.agent_persona_prompt
|
180
218
|
Prompt(text=\"""You are an agent with the following persona:
|
181
219
|
{'age': 22, 'hair': 'brown', 'height': 5.5}\""")
|
182
220
|
|
183
221
|
"""
|
184
222
|
if not hasattr(self, "_agent_persona_prompt"):
|
223
|
+
from edsl import Agent
|
224
|
+
|
225
|
+
if self.agent == Agent(): # if agent is empty, then return an empty prompt
|
226
|
+
return Prompt(text="")
|
227
|
+
|
185
228
|
if not hasattr(self.agent, "agent_persona"):
|
186
229
|
applicable_prompts = prompt_lookup(
|
187
230
|
component_type="agent_persona",
|
@@ -226,27 +269,29 @@ class PromptConstructorMixin:
|
|
226
269
|
d[new_question].comment = answer
|
227
270
|
return d
|
228
271
|
|
272
|
+
@property
|
273
|
+
def question_image_keys(self):
|
274
|
+
raw_question_text = self.question.question_text
|
275
|
+
variables = get_jinja2_variables(raw_question_text)
|
276
|
+
question_image_keys = []
|
277
|
+
for var in variables:
|
278
|
+
if var in self.scenario_image_keys:
|
279
|
+
question_image_keys.append(var)
|
280
|
+
return question_image_keys
|
281
|
+
|
229
282
|
@property
|
230
283
|
def question_instructions_prompt(self) -> Prompt:
|
231
284
|
"""
|
232
285
|
>>> from edsl.agents.InvigilatorBase import InvigilatorBase
|
233
286
|
>>> i = InvigilatorBase.example()
|
234
|
-
>>> i.question_instructions_prompt
|
287
|
+
>>> i.prompt_constructor.question_instructions_prompt
|
235
288
|
Prompt(text=\"""...
|
236
289
|
...
|
237
290
|
"""
|
238
291
|
if not hasattr(self, "_question_instructions_prompt"):
|
239
292
|
question_prompt = self.question.get_instructions(model=self.model.model)
|
240
293
|
|
241
|
-
#
|
242
|
-
# d = self.survey.question_names_to_questions()
|
243
|
-
# for question, answer in self.current_answers.items():
|
244
|
-
# if question in d:
|
245
|
-
# d[question].answer = answer
|
246
|
-
# else:
|
247
|
-
# # adds a comment to the question
|
248
|
-
# if (new_question := question.split("_comment")[0]) in d:
|
249
|
-
# d[new_question].comment = answer
|
294
|
+
# Are any of the scenario values ImageInfo
|
250
295
|
|
251
296
|
question_data = self.question.data.copy()
|
252
297
|
|
@@ -254,8 +299,6 @@ class PromptConstructorMixin:
|
|
254
299
|
# This is used when the user is using the question_options as a variable from a sceario
|
255
300
|
# if "question_options" in question_data:
|
256
301
|
if isinstance(self.question.data.get("question_options", None), str):
|
257
|
-
from jinja2 import Environment, meta
|
258
|
-
|
259
302
|
env = Environment()
|
260
303
|
parsed_content = env.parse(self.question.data["question_options"])
|
261
304
|
question_option_key = list(
|
@@ -269,8 +312,13 @@ class PromptConstructorMixin:
|
|
269
312
|
self.question.question_options = question_options
|
270
313
|
|
271
314
|
replacement_dict = (
|
272
|
-
|
273
|
-
|
|
315
|
+
{key: "<see image>" for key in self.scenario_image_keys}
|
316
|
+
| question_data
|
317
|
+
| {
|
318
|
+
k: v
|
319
|
+
for k, v in self.scenario.items()
|
320
|
+
if k not in self.scenario_image_keys
|
321
|
+
} # don't include images in the replacement dict
|
274
322
|
| self.prior_answers_dict()
|
275
323
|
| {"agent": self.agent}
|
276
324
|
| {
|
@@ -280,9 +328,10 @@ class PromptConstructorMixin:
|
|
280
328
|
),
|
281
329
|
}
|
282
330
|
)
|
283
|
-
|
331
|
+
|
284
332
|
rendered_instructions = question_prompt.render(replacement_dict)
|
285
|
-
|
333
|
+
|
334
|
+
# is there anything left to render?
|
286
335
|
undefined_template_variables = (
|
287
336
|
rendered_instructions.undefined_template_variables({})
|
288
337
|
)
|
@@ -300,7 +349,9 @@ class PromptConstructorMixin:
|
|
300
349
|
f"Question instructions still has variables: {undefined_template_variables}."
|
301
350
|
)
|
302
351
|
|
303
|
-
|
352
|
+
####################################
|
353
|
+
# Check if question has instructions - these are instructions in a Survey that can apply to multiple follow-on questions
|
354
|
+
####################################
|
304
355
|
relevant_instructions = self.survey.relevant_instructions(
|
305
356
|
self.question.question_name
|
306
357
|
)
|
@@ -329,6 +380,23 @@ class PromptConstructorMixin:
|
|
329
380
|
self._prior_question_memory_prompt = memory_prompt
|
330
381
|
return self._prior_question_memory_prompt
|
331
382
|
|
383
|
+
def create_memory_prompt(self, question_name: str) -> Prompt:
|
384
|
+
"""Create a memory for the agent.
|
385
|
+
|
386
|
+
The returns a memory prompt for the agent.
|
387
|
+
|
388
|
+
>>> from edsl.agents.InvigilatorBase import InvigilatorBase
|
389
|
+
>>> i = InvigilatorBase.example()
|
390
|
+
>>> i.current_answers = {"q0": "Prior answer"}
|
391
|
+
>>> i.memory_plan.add_single_memory("q1", "q0")
|
392
|
+
>>> p = i.prompt_constructor.create_memory_prompt("q1")
|
393
|
+
>>> p.text.strip().replace("\\n", " ").replace("\\t", " ")
|
394
|
+
'Before the question you are now answering, you already answered the following question(s): Question: Do you like school? Answer: Prior answer'
|
395
|
+
"""
|
396
|
+
return self.memory_plan.get_memory_prompt_fragment(
|
397
|
+
question_name, self.current_answers
|
398
|
+
)
|
399
|
+
|
332
400
|
def construct_system_prompt(self) -> Prompt:
|
333
401
|
"""Construct the system prompt for the LLM call."""
|
334
402
|
import warnings
|
@@ -363,9 +431,13 @@ class PromptConstructorMixin:
|
|
363
431
|
question_instructions=self.question_instructions_prompt,
|
364
432
|
prior_question_memory=self.prior_question_memory_prompt,
|
365
433
|
)
|
434
|
+
if len(self.question_image_keys) > 1:
|
435
|
+
raise ValueError("We can only handle one image per question.")
|
436
|
+
elif len(self.question_image_keys) == 1:
|
437
|
+
prompts["encoded_image"] = self.scenario[
|
438
|
+
self.question_image_keys[0]
|
439
|
+
].encoded_image
|
366
440
|
|
367
|
-
if hasattr(self.scenario, "has_image") and self.scenario.has_image:
|
368
|
-
prompts["encoded_image"] = self.scenario["encoded_image"]
|
369
441
|
return prompts
|
370
442
|
|
371
443
|
def _get_scenario_with_image(self) -> Scenario:
|
edsl/config.py
CHANGED
@@ -1,73 +1,65 @@
|
|
1
1
|
"""This module provides a Config class that loads environment variables from a .env file and sets them as class attributes."""
|
2
2
|
|
3
3
|
import os
|
4
|
+
from dotenv import load_dotenv, find_dotenv
|
4
5
|
from edsl.exceptions import (
|
5
6
|
InvalidEnvironmentVariableError,
|
6
7
|
MissingEnvironmentVariableError,
|
7
8
|
)
|
8
|
-
from dotenv import load_dotenv, find_dotenv
|
9
9
|
|
10
10
|
# valid values for EDSL_RUN_MODE
|
11
|
-
EDSL_RUN_MODES = [
|
11
|
+
EDSL_RUN_MODES = [
|
12
|
+
"development",
|
13
|
+
"development-testrun",
|
14
|
+
"production",
|
15
|
+
]
|
12
16
|
|
13
17
|
# `default` is used to impute values only in "production" mode
|
14
18
|
# `info` gives a brief description of the env var
|
15
19
|
CONFIG_MAP = {
|
16
20
|
"EDSL_RUN_MODE": {
|
17
21
|
"default": "production",
|
18
|
-
"info": "This
|
19
|
-
},
|
20
|
-
"EDSL_DATABASE_PATH": {
|
21
|
-
"default": f"sqlite:///{os.path.join(os.getcwd(), '.edsl_cache/data.db')}",
|
22
|
-
"info": "This env var determines the path to the cache file.",
|
23
|
-
},
|
24
|
-
"EDSL_LOGGING_PATH": {
|
25
|
-
"default": f"{os.path.join(os.getcwd(), 'interview.log')}",
|
26
|
-
"info": "This env var determines the path to the log file.",
|
22
|
+
"info": "This config var determines the run mode of the application.",
|
27
23
|
},
|
28
24
|
"EDSL_API_TIMEOUT": {
|
29
25
|
"default": "60",
|
30
|
-
"info": "This
|
26
|
+
"info": "This config var determines the maximum number of seconds to wait for an API call to return.",
|
31
27
|
},
|
32
28
|
"EDSL_BACKOFF_START_SEC": {
|
33
29
|
"default": "1",
|
34
|
-
"info": "This
|
30
|
+
"info": "This config var determines the number of seconds to wait before retrying a failed API call.",
|
35
31
|
},
|
36
|
-
"
|
32
|
+
"EDSL_BACKOFF_MAX_SEC": {
|
37
33
|
"default": "60",
|
38
|
-
"info": "This
|
34
|
+
"info": "This config var determines the maximum number of seconds to wait before retrying a failed API call.",
|
39
35
|
},
|
40
|
-
"
|
41
|
-
"default": "
|
42
|
-
"info": "This
|
36
|
+
"EDSL_DATABASE_PATH": {
|
37
|
+
"default": f"sqlite:///{os.path.join(os.getcwd(), '.edsl_cache/data.db')}",
|
38
|
+
"info": "This config var determines the path to the cache file.",
|
43
39
|
},
|
44
40
|
"EDSL_DEFAULT_MODEL": {
|
45
41
|
"default": "gpt-4o",
|
46
|
-
"info": "This
|
42
|
+
"info": "This config var holds the default model that will be used if a model is not explicitly passed.",
|
47
43
|
},
|
48
|
-
"
|
49
|
-
"default": "
|
50
|
-
"info": "This
|
44
|
+
"EDSL_FETCH_TOKEN_PRICES": {
|
45
|
+
"default": "True",
|
46
|
+
"info": "This config var determines whether to fetch prices for tokens used in remote inference",
|
47
|
+
},
|
48
|
+
"EDSL_MAX_ATTEMPTS": {
|
49
|
+
"default": "5",
|
50
|
+
"info": "This config var determines the maximum number of times to retry a failed API call.",
|
51
51
|
},
|
52
52
|
"EDSL_SERVICE_RPM_BASELINE": {
|
53
53
|
"default": "100",
|
54
|
-
"info": "This
|
54
|
+
"info": "This config var holds the maximum number of requests per minute. Model-specific values provided in env vars such as EDSL_SERVICE_RPM_OPENAI will override this. value for the corresponding model",
|
55
55
|
},
|
56
|
-
"
|
56
|
+
"EDSL_SERVICE_TPM_BASELINE": {
|
57
57
|
"default": "2000000",
|
58
|
-
"info": "This
|
59
|
-
},
|
60
|
-
"EDSL_SERVICE_RPM_OPENAI": {
|
61
|
-
"default": "100",
|
62
|
-
"info": "This env var holds the maximum number of requests per minute for OpenAI.",
|
63
|
-
},
|
64
|
-
"EDSL_FETCH_TOKEN_PRICES": {
|
65
|
-
"default": "True",
|
66
|
-
"info": "Whether to fetch the prices for tokens",
|
58
|
+
"info": "This config var holds the maximum number of tokens per minute for all models. Model-specific values provided in env vars such as EDSL_SERVICE_TPM_OPENAI will override this value for the corresponding model.",
|
67
59
|
},
|
68
60
|
"EXPECTED_PARROT_URL": {
|
69
61
|
"default": "https://www.expectedparrot.com",
|
70
|
-
"info": "This
|
62
|
+
"info": "This config var holds the URL of the Expected Parrot API.",
|
71
63
|
},
|
72
64
|
}
|
73
65
|
|
edsl/coop/coop.py
CHANGED
@@ -59,8 +59,16 @@ class Coop:
|
|
59
59
|
Send a request to the server and return the response.
|
60
60
|
"""
|
61
61
|
url = f"{self.url}/{uri}"
|
62
|
+
method = method.upper()
|
63
|
+
if payload is None:
|
64
|
+
timeout = 20
|
65
|
+
elif (
|
66
|
+
method.upper() == "POST"
|
67
|
+
and "json_string" in payload
|
68
|
+
and payload.get("json_string") is not None
|
69
|
+
):
|
70
|
+
timeout = max(20, (len(payload.get("json_string", "")) // (1024 * 1024)))
|
62
71
|
try:
|
63
|
-
method = method.upper()
|
64
72
|
if method in ["GET", "DELETE"]:
|
65
73
|
response = requests.request(
|
66
74
|
method, url, params=params, headers=self.headers, timeout=timeout
|
@@ -77,7 +85,7 @@ class Coop:
|
|
77
85
|
else:
|
78
86
|
raise Exception(f"Invalid {method=}.")
|
79
87
|
except requests.ConnectionError:
|
80
|
-
raise requests.ConnectionError("Could not connect to the server.")
|
88
|
+
raise requests.ConnectionError(f"Could not connect to the server at {url}.")
|
81
89
|
|
82
90
|
return response
|
83
91
|
|
@@ -87,6 +95,7 @@ class Coop:
|
|
87
95
|
"""
|
88
96
|
if response.status_code >= 400:
|
89
97
|
message = response.json().get("detail")
|
98
|
+
# print(response.text)
|
90
99
|
if "Authorization" in message:
|
91
100
|
print(message)
|
92
101
|
message = "Please provide an Expected Parrot API key."
|
edsl/data_transfer_models.py
CHANGED
@@ -1,4 +1,7 @@
|
|
1
1
|
from typing import NamedTuple, Dict, List, Optional, Any
|
2
|
+
from dataclasses import dataclass, fields
|
3
|
+
import reprlib
|
4
|
+
|
2
5
|
|
3
6
|
|
4
7
|
class ModelInputs(NamedTuple):
|
@@ -45,76 +48,27 @@ class EDSLResultObjectInput(NamedTuple):
|
|
45
48
|
cost: Optional[float] = None
|
46
49
|
|
47
50
|
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
#
|
59
|
-
|
60
|
-
#
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
|
67
|
-
#
|
68
|
-
|
69
|
-
|
70
|
-
#
|
71
|
-
|
72
|
-
# raise ValueError("generated_tokens must be provided")
|
73
|
-
# self.data = {
|
74
|
-
# "answer": answer,
|
75
|
-
# "comment": comment,
|
76
|
-
# "question_name": question_name,
|
77
|
-
# "prompts": prompts,
|
78
|
-
# "usage": usage,
|
79
|
-
# "cached_response": cached_response,
|
80
|
-
# "raw_model_response": raw_model_response,
|
81
|
-
# "simple_model_raw_response": simple_model_raw_response,
|
82
|
-
# "cache_used": cache_used,
|
83
|
-
# "cache_key": cache_key,
|
84
|
-
# "generated_tokens": generated_tokens,
|
85
|
-
# }
|
86
|
-
|
87
|
-
# @property
|
88
|
-
# def data(self):
|
89
|
-
# return self._data
|
90
|
-
|
91
|
-
# @data.setter
|
92
|
-
# def data(self, value):
|
93
|
-
# self._data = value
|
94
|
-
|
95
|
-
# def __getitem__(self, key):
|
96
|
-
# return self.data.get(key, None)
|
97
|
-
|
98
|
-
# def __setitem__(self, key, value):
|
99
|
-
# self.data[key] = value
|
100
|
-
|
101
|
-
# def __delitem__(self, key):
|
102
|
-
# del self.data[key]
|
103
|
-
|
104
|
-
# def __iter__(self):
|
105
|
-
# return iter(self.data)
|
106
|
-
|
107
|
-
# def __len__(self):
|
108
|
-
# return len(self.data)
|
109
|
-
|
110
|
-
# def keys(self):
|
111
|
-
# return self.data.keys()
|
112
|
-
|
113
|
-
# def values(self):
|
114
|
-
# return self.data.values()
|
115
|
-
|
116
|
-
# def items(self):
|
117
|
-
# return self.data.items()
|
118
|
-
|
119
|
-
# def is_this_same_model(self):
|
120
|
-
# return True
|
51
|
+
@dataclass
|
52
|
+
class ImageInfo:
|
53
|
+
file_path: str
|
54
|
+
file_name: str
|
55
|
+
image_format: str
|
56
|
+
file_size: int
|
57
|
+
encoded_image: str
|
58
|
+
|
59
|
+
def __repr__(self):
|
60
|
+
reprlib_instance = reprlib.Repr()
|
61
|
+
reprlib_instance.maxstring = 30 # Limit the string length for the encoded image
|
62
|
+
|
63
|
+
# Get all fields except encoded_image
|
64
|
+
field_reprs = [
|
65
|
+
f"{f.name}={getattr(self, f.name)!r}"
|
66
|
+
for f in fields(self)
|
67
|
+
if f.name != "encoded_image"
|
68
|
+
]
|
69
|
+
|
70
|
+
# Add the reprlib-restricted encoded_image field
|
71
|
+
field_reprs.append(f"encoded_image={reprlib_instance.repr(self.encoded_image)}")
|
72
|
+
|
73
|
+
# Join everything to create the repr
|
74
|
+
return f"{self.__class__.__name__}({', '.join(field_reprs)})"
|
edsl/enums.py
CHANGED
@@ -63,6 +63,7 @@ class InferenceServiceType(EnumWithChecks):
|
|
63
63
|
AZURE = "azure"
|
64
64
|
OLLAMA = "ollama"
|
65
65
|
MISTRAL = "mistral"
|
66
|
+
TOGETHER = "together"
|
66
67
|
|
67
68
|
|
68
69
|
service_to_api_keyname = {
|
@@ -76,6 +77,7 @@ service_to_api_keyname = {
|
|
76
77
|
InferenceServiceType.GROQ.value: "GROQ_API_KEY",
|
77
78
|
InferenceServiceType.BEDROCK.value: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"],
|
78
79
|
InferenceServiceType.MISTRAL.value: "MISTRAL_API_KEY",
|
80
|
+
InferenceServiceType.TOGETHER.value: "TOGETHER_API_KEY",
|
79
81
|
}
|
80
82
|
|
81
83
|
|