edsl 0.1.35__py3-none-any.whl → 0.1.36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +5 -0
- edsl/__init__.py +1 -0
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +37 -9
- edsl/agents/Invigilator.py +2 -1
- edsl/agents/InvigilatorBase.py +5 -1
- edsl/agents/PromptConstructor.py +31 -67
- edsl/conversation/Conversation.py +1 -1
- edsl/coop/PriceFetcher.py +14 -18
- edsl/coop/coop.py +42 -8
- edsl/data/RemoteCacheSync.py +97 -0
- edsl/exceptions/coop.py +8 -0
- edsl/inference_services/InferenceServiceABC.py +28 -0
- edsl/inference_services/InferenceServicesCollection.py +10 -4
- edsl/inference_services/models_available_cache.py +25 -1
- edsl/inference_services/registry.py +24 -16
- edsl/jobs/Jobs.py +327 -206
- edsl/jobs/interviews/Interview.py +65 -10
- edsl/jobs/interviews/InterviewExceptionCollection.py +9 -0
- edsl/jobs/interviews/InterviewExceptionEntry.py +31 -9
- edsl/jobs/runners/JobsRunnerAsyncio.py +8 -13
- edsl/jobs/tasks/QuestionTaskCreator.py +1 -5
- edsl/jobs/tasks/TaskHistory.py +23 -7
- edsl/language_models/LanguageModel.py +3 -0
- edsl/prompts/Prompt.py +24 -38
- edsl/prompts/__init__.py +1 -1
- edsl/questions/QuestionBasePromptsMixin.py +18 -18
- edsl/questions/QuestionFunctional.py +7 -3
- edsl/questions/descriptors.py +24 -24
- edsl/results/Dataset.py +12 -0
- edsl/results/Result.py +2 -0
- edsl/results/Results.py +13 -1
- edsl/scenarios/FileStore.py +20 -5
- edsl/scenarios/Scenario.py +15 -1
- edsl/scenarios/__init__.py +2 -0
- edsl/surveys/Survey.py +3 -0
- edsl/surveys/instructions/Instruction.py +20 -3
- {edsl-0.1.35.dist-info → edsl-0.1.36.dist-info}/METADATA +1 -1
- {edsl-0.1.35.dist-info → edsl-0.1.36.dist-info}/RECORD +41 -57
- edsl/jobs/FailedQuestion.py +0 -78
- edsl/jobs/interviews/InterviewStatusMixin.py +0 -33
- edsl/jobs/tasks/task_management.py +0 -13
- edsl/prompts/QuestionInstructionsBase.py +0 -10
- edsl/prompts/library/agent_instructions.py +0 -38
- edsl/prompts/library/agent_persona.py +0 -21
- edsl/prompts/library/question_budget.py +0 -30
- edsl/prompts/library/question_checkbox.py +0 -38
- edsl/prompts/library/question_extract.py +0 -23
- edsl/prompts/library/question_freetext.py +0 -18
- edsl/prompts/library/question_linear_scale.py +0 -24
- edsl/prompts/library/question_list.py +0 -26
- edsl/prompts/library/question_multiple_choice.py +0 -54
- edsl/prompts/library/question_numerical.py +0 -35
- edsl/prompts/library/question_rank.py +0 -25
- edsl/prompts/prompt_config.py +0 -37
- edsl/prompts/registry.py +0 -202
- {edsl-0.1.35.dist-info → edsl-0.1.36.dist-info}/LICENSE +0 -0
- {edsl-0.1.35.dist-info → edsl-0.1.36.dist-info}/WHEEL +0 -0
edsl/Base.py
CHANGED
@@ -7,6 +7,8 @@ import json
|
|
7
7
|
from typing import Any, Optional, Union
|
8
8
|
from uuid import UUID
|
9
9
|
|
10
|
+
# from edsl.utilities.MethodSuggesterMixin import MethodSuggesterMixin
|
11
|
+
|
10
12
|
|
11
13
|
class RichPrintingMixin:
|
12
14
|
"""Mixin for rich printing and persistence of objects."""
|
@@ -274,6 +276,9 @@ class Base(
|
|
274
276
|
"""This method should be implemented by subclasses."""
|
275
277
|
raise NotImplementedError("This method is not implemented yet.")
|
276
278
|
|
279
|
+
def to_json(self):
|
280
|
+
return json.dumps(self.to_dict())
|
281
|
+
|
277
282
|
@abstractmethod
|
278
283
|
def from_dict():
|
279
284
|
"""This method should be implemented by subclasses."""
|
edsl/__init__.py
CHANGED
@@ -27,6 +27,7 @@ from edsl.questions import QuestionTopK
|
|
27
27
|
|
28
28
|
from edsl.scenarios import Scenario
|
29
29
|
from edsl.scenarios import ScenarioList
|
30
|
+
from edsl.scenarios.FileStore import FileStore
|
30
31
|
|
31
32
|
# from edsl.utilities.interface import print_dict_with_rich
|
32
33
|
from edsl.surveys.Survey import Survey
|
edsl/__version__.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.1.
|
1
|
+
__version__ = "0.1.36"
|
edsl/agents/Agent.py
CHANGED
@@ -8,6 +8,9 @@ from typing import Callable, Optional, Union, Any
|
|
8
8
|
from uuid import uuid4
|
9
9
|
from edsl.Base import Base
|
10
10
|
|
11
|
+
from edsl.prompts import Prompt
|
12
|
+
from edsl.exceptions import QuestionScenarioRenderError
|
13
|
+
|
11
14
|
from edsl.exceptions.agents import (
|
12
15
|
AgentCombinationError,
|
13
16
|
AgentDirectAnswerFunctionError,
|
@@ -44,7 +47,6 @@ class Agent(Base):
|
|
44
47
|
|
45
48
|
def __init__(
|
46
49
|
self,
|
47
|
-
# *,
|
48
50
|
traits: Optional[dict] = None,
|
49
51
|
name: Optional[str] = None,
|
50
52
|
codebook: Optional[dict] = None,
|
@@ -109,9 +111,14 @@ class Agent(Base):
|
|
109
111
|
self.name = name
|
110
112
|
self._traits = traits or dict()
|
111
113
|
self.codebook = codebook or dict()
|
112
|
-
|
114
|
+
if instruction is None:
|
115
|
+
self.instruction = self.default_instruction
|
116
|
+
else:
|
117
|
+
self.instruction = instruction
|
118
|
+
# self.instruction = instruction or self.default_instruction
|
113
119
|
self.dynamic_traits_function = dynamic_traits_function
|
114
120
|
|
121
|
+
# Deal with dynamic traits function
|
115
122
|
if self.dynamic_traits_function:
|
116
123
|
self.dynamic_traits_function_name = self.dynamic_traits_function.__name__
|
117
124
|
self.has_dynamic_traits_function = True
|
@@ -124,6 +131,7 @@ class Agent(Base):
|
|
124
131
|
dynamic_traits_function_name, dynamic_traits_function
|
125
132
|
)
|
126
133
|
|
134
|
+
# Deal with direct answer function
|
127
135
|
if answer_question_directly_source_code:
|
128
136
|
self.answer_question_directly_function_name = (
|
129
137
|
answer_question_directly_function_name
|
@@ -140,10 +148,34 @@ class Agent(Base):
|
|
140
148
|
self.current_question = None
|
141
149
|
|
142
150
|
if traits_presentation_template is not None:
|
143
|
-
from edsl.prompts.library.agent_persona import AgentPersona
|
144
|
-
|
145
151
|
self.traits_presentation_template = traits_presentation_template
|
146
|
-
|
152
|
+
else:
|
153
|
+
self.traits_presentation_template = """Your traits: {{ traits }}"""
|
154
|
+
|
155
|
+
@property
|
156
|
+
def agent_persona(self) -> Prompt:
|
157
|
+
return Prompt(text=self.traits_presentation_template)
|
158
|
+
|
159
|
+
def prompt(self) -> str:
|
160
|
+
"""Return the prompt for the agent.
|
161
|
+
|
162
|
+
Example usage:
|
163
|
+
|
164
|
+
>>> a = Agent(traits = {"age": 10, "hair": "brown", "height": 5.5})
|
165
|
+
>>> a.prompt()
|
166
|
+
Prompt(text=\"""Your traits: {'age': 10, 'hair': 'brown', 'height': 5.5}\""")
|
167
|
+
"""
|
168
|
+
replacement_dict = (
|
169
|
+
self.traits | {"traits": self.traits} | {"codebook": self.codebook}
|
170
|
+
)
|
171
|
+
if undefined := self.agent_persona.undefined_template_variables(
|
172
|
+
replacement_dict
|
173
|
+
):
|
174
|
+
raise QuestionScenarioRenderError(
|
175
|
+
f"Agent persona still has variables that were not rendered: {undefined}"
|
176
|
+
)
|
177
|
+
else:
|
178
|
+
return self.agent_persona.render(replacement_dict)
|
147
179
|
|
148
180
|
def _check_dynamic_traits_function(self) -> None:
|
149
181
|
"""Check whether dynamic trait function is valid.
|
@@ -252,7 +284,6 @@ class Agent(Base):
|
|
252
284
|
warnings.warn(
|
253
285
|
"Warning: overwriting existing answer_question_directly method"
|
254
286
|
)
|
255
|
-
# print("Warning: overwriting existing answer_question_directly method")
|
256
287
|
|
257
288
|
self.validate_response = validate_response
|
258
289
|
self.translate_response = translate_response
|
@@ -575,9 +606,6 @@ class Agent(Base):
|
|
575
606
|
if self.name == None:
|
576
607
|
raw_data.pop("name")
|
577
608
|
|
578
|
-
import inspect
|
579
|
-
|
580
|
-
# print(raw_data)
|
581
609
|
if hasattr(self, "dynamic_traits_function"):
|
582
610
|
raw_data.pop(
|
583
611
|
"dynamic_traits_function", None
|
edsl/agents/Invigilator.py
CHANGED
@@ -4,7 +4,8 @@ from typing import Dict, Any, Optional
|
|
4
4
|
|
5
5
|
from edsl.prompts.Prompt import Prompt
|
6
6
|
from edsl.utilities.decorators import sync_wrapper, jupyter_nb_handler
|
7
|
-
|
7
|
+
|
8
|
+
# from edsl.prompts.registry import get_classes as prompt_lookup
|
8
9
|
from edsl.exceptions.questions import QuestionAnswerValidationError
|
9
10
|
from edsl.agents.InvigilatorBase import InvigilatorBase
|
10
11
|
from edsl.data_transfer_models import AgentResponseDict, EDSLResultObjectInput
|
edsl/agents/InvigilatorBase.py
CHANGED
@@ -115,7 +115,11 @@ class InvigilatorBase(ABC):
|
|
115
115
|
iteration = data["iteration"]
|
116
116
|
additional_prompt_data = data["additional_prompt_data"]
|
117
117
|
cache = Cache.from_dict(data["cache"])
|
118
|
-
|
118
|
+
|
119
|
+
if data["sidecar_model"] is None:
|
120
|
+
sidecar_model = None
|
121
|
+
else:
|
122
|
+
sidecar_model = LanguageModel.from_dict(data["sidecar_model"])
|
119
123
|
|
120
124
|
return cls(
|
121
125
|
agent=agent,
|
edsl/agents/PromptConstructor.py
CHANGED
@@ -1,16 +1,11 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
from typing import Dict, Any, Optional, Set
|
3
|
-
from collections import UserList
|
4
|
-
import pdb
|
5
3
|
|
6
4
|
from jinja2 import Environment, meta
|
7
5
|
|
8
6
|
from edsl.prompts.Prompt import Prompt
|
9
|
-
from edsl.data_transfer_models import ImageInfo
|
10
|
-
from edsl.prompts.registry import get_classes as prompt_lookup
|
11
|
-
from edsl.exceptions import QuestionScenarioRenderError
|
12
7
|
|
13
|
-
from edsl.agents.prompt_helpers import
|
8
|
+
from edsl.agents.prompt_helpers import PromptPlan
|
14
9
|
|
15
10
|
|
16
11
|
def get_jinja2_variables(template_str: str) -> Set[str]:
|
@@ -75,17 +70,8 @@ class PromptConstructor:
|
|
75
70
|
|
76
71
|
if self.agent == Agent(): # if agent is empty, then return an empty prompt
|
77
72
|
return Prompt(text="")
|
78
|
-
|
79
|
-
|
80
|
-
component_type="agent_instructions",
|
81
|
-
model=self.model.model,
|
82
|
-
)
|
83
|
-
if len(applicable_prompts) == 0:
|
84
|
-
raise Exception("No applicable prompts found")
|
85
|
-
self._agent_instructions_prompt = applicable_prompts[0](
|
86
|
-
text=self.agent.instruction
|
87
|
-
)
|
88
|
-
return self._agent_instructions_prompt
|
73
|
+
|
74
|
+
return Prompt(text=self.agent.instruction)
|
89
75
|
|
90
76
|
@property
|
91
77
|
def agent_persona_prompt(self) -> Prompt:
|
@@ -93,51 +79,14 @@ class PromptConstructor:
|
|
93
79
|
>>> from edsl.agents.InvigilatorBase import InvigilatorBase
|
94
80
|
>>> i = InvigilatorBase.example()
|
95
81
|
>>> i.prompt_constructor.agent_persona_prompt
|
96
|
-
Prompt(text=\"""
|
97
|
-
{'age': 22, 'hair': 'brown', 'height': 5.5}\""")
|
98
|
-
|
82
|
+
Prompt(text=\"""Your traits: {'age': 22, 'hair': 'brown', 'height': 5.5}\""")
|
99
83
|
"""
|
100
84
|
from edsl import Agent
|
101
85
|
|
102
|
-
if hasattr(self, "_agent_persona_prompt"):
|
103
|
-
return self._agent_persona_prompt
|
104
|
-
|
105
86
|
if self.agent == Agent(): # if agent is empty, then return an empty prompt
|
106
87
|
return Prompt(text="")
|
107
88
|
|
108
|
-
|
109
|
-
applicable_prompts = prompt_lookup(
|
110
|
-
component_type="agent_persona",
|
111
|
-
model=self.model.model,
|
112
|
-
)
|
113
|
-
persona_prompt_template = applicable_prompts[0]()
|
114
|
-
else:
|
115
|
-
persona_prompt_template = self.agent.agent_persona
|
116
|
-
|
117
|
-
# TODO: This multiple passing of agent traits - not sure if it is necessary. Not harmful.
|
118
|
-
template_parameter_dictionary = (
|
119
|
-
self.agent.traits
|
120
|
-
| {"traits": self.agent.traits}
|
121
|
-
| {"codebook": self.agent.codebook}
|
122
|
-
| {"traits": self.agent.traits}
|
123
|
-
)
|
124
|
-
|
125
|
-
if undefined := persona_prompt_template.undefined_template_variables(
|
126
|
-
template_parameter_dictionary
|
127
|
-
):
|
128
|
-
raise QuestionScenarioRenderError(
|
129
|
-
f"Agent persona still has variables that were not rendered: {undefined}"
|
130
|
-
)
|
131
|
-
|
132
|
-
persona_prompt = persona_prompt_template.render(template_parameter_dictionary)
|
133
|
-
if persona_prompt.has_variables:
|
134
|
-
raise QuestionScenarioRenderError(
|
135
|
-
"Agent persona still has variables that were not rendered."
|
136
|
-
)
|
137
|
-
|
138
|
-
self._agent_persona_prompt = persona_prompt
|
139
|
-
|
140
|
-
return self._agent_persona_prompt
|
89
|
+
return self.agent.prompt()
|
141
90
|
|
142
91
|
def prior_answers_dict(self) -> dict:
|
143
92
|
d = self.survey.question_names_to_questions()
|
@@ -198,14 +147,28 @@ class PromptConstructor:
|
|
198
147
|
|
199
148
|
# might be getting it from the prior answers
|
200
149
|
if self.prior_answers_dict().get(question_option_key) is not None:
|
201
|
-
|
202
|
-
|
203
|
-
.
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
150
|
+
prior_question = self.prior_answers_dict().get(question_option_key)
|
151
|
+
if hasattr(prior_question, "answer"):
|
152
|
+
if isinstance(prior_question.answer, list):
|
153
|
+
question_data["question_options"] = prior_question.answer
|
154
|
+
self.question.question_options = prior_question.answer
|
155
|
+
else:
|
156
|
+
placeholder_options = [
|
157
|
+
"N/A",
|
158
|
+
"Will be populated by prior answer",
|
159
|
+
"These are placeholder options",
|
160
|
+
]
|
161
|
+
question_data["question_options"] = placeholder_options
|
162
|
+
self.question.question_options = placeholder_options
|
163
|
+
|
164
|
+
# if isinstance(
|
165
|
+
# question_options := self.prior_answers_dict()
|
166
|
+
# .get(question_option_key)
|
167
|
+
# .answer,
|
168
|
+
# list,
|
169
|
+
# ):
|
170
|
+
# question_data["question_options"] = question_options
|
171
|
+
# self.question.question_options = question_options
|
209
172
|
|
210
173
|
replacement_dict = (
|
211
174
|
{key: f"<see file {key}>" for key in self.scenario_file_keys}
|
@@ -257,9 +220,10 @@ class PromptConstructor:
|
|
257
220
|
)
|
258
221
|
|
259
222
|
if relevant_instructions != []:
|
260
|
-
preamble_text = Prompt(
|
261
|
-
|
262
|
-
)
|
223
|
+
# preamble_text = Prompt(
|
224
|
+
# text="You were given the following instructions: "
|
225
|
+
# )
|
226
|
+
preamble_text = Prompt(text="")
|
263
227
|
for instruction in relevant_instructions:
|
264
228
|
preamble_text += instruction.text
|
265
229
|
rendered_instructions = preamble_text + rendered_instructions
|
edsl/conversation/Conversation.py
CHANGED
@@ -169,6 +169,7 @@ class Conversation:
|
|
169
169
|
agent=speaker,
|
170
170
|
just_answer=False,
|
171
171
|
cache=self.cache,
|
172
|
+
model=speaker.model,
|
172
173
|
)
|
173
174
|
return results[0]
|
174
175
|
|
@@ -179,7 +180,6 @@ class Conversation:
|
|
179
180
|
i = 0
|
180
181
|
while await self.continue_conversation():
|
181
182
|
speaker = self.next_speaker()
|
182
|
-
# breakpoint()
|
183
183
|
|
184
184
|
next_statement = AgentStatement(
|
185
185
|
statement=await self.get_next_statement(
|
edsl/coop/PriceFetcher.py
CHANGED
@@ -16,30 +16,26 @@ class PriceFetcher:
|
|
16
16
|
if self._cached_prices is not None:
|
17
17
|
return self._cached_prices
|
18
18
|
|
19
|
+
import os
|
19
20
|
import requests
|
20
|
-
import
|
21
|
-
from io import StringIO
|
22
|
-
|
23
|
-
sheet_id = "1SAO3Bhntefl0XQHJv27rMxpvu6uzKDWNXFHRa7jrUDs"
|
24
|
-
|
25
|
-
# Construct the URL to fetch the CSV
|
26
|
-
url = f"https://docs.google.com/spreadsheets/d/{sheet_id}/export?format=csv"
|
21
|
+
from edsl import CONFIG
|
27
22
|
|
28
23
|
try:
|
29
|
-
# Fetch the
|
30
|
-
|
24
|
+
# Fetch the pricing data
|
25
|
+
url = f"{CONFIG.EXPECTED_PARROT_URL}/api/v0/prices"
|
26
|
+
api_key = os.getenv("EXPECTED_PARROT_API_KEY")
|
27
|
+
headers = {}
|
28
|
+
if api_key:
|
29
|
+
headers["Authorization"] = f"Bearer {api_key}"
|
30
|
+
else:
|
31
|
+
headers["Authorization"] = f"Bearer None"
|
32
|
+
|
33
|
+
response = requests.get(url, headers=headers, timeout=20)
|
31
34
|
response.raise_for_status() # Raise an exception for bad responses
|
32
35
|
|
33
|
-
# Parse the
|
34
|
-
|
35
|
-
reader = csv.reader(csv_data)
|
36
|
-
|
37
|
-
# Convert to list of dictionaries
|
38
|
-
headers = next(reader)
|
39
|
-
data = [dict(zip(headers, row)) for row in reader]
|
36
|
+
# Parse the data
|
37
|
+
data = response.json()
|
40
38
|
|
41
|
-
# self._cached_prices = data
|
42
|
-
# return data
|
43
39
|
price_lookup = {}
|
44
40
|
for entry in data:
|
45
41
|
service = entry.get("service", None)
|
edsl/coop/coop.py
CHANGED
@@ -6,6 +6,7 @@ from typing import Any, Optional, Union, Literal
|
|
6
6
|
from uuid import UUID
|
7
7
|
import edsl
|
8
8
|
from edsl import CONFIG, CacheEntry, Jobs, Survey
|
9
|
+
from edsl.exceptions.coop import CoopNoUUIDError, CoopServerResponseError
|
9
10
|
from edsl.coop.utils import (
|
10
11
|
EDSLObject,
|
11
12
|
ObjectRegistry,
|
@@ -99,7 +100,7 @@ class Coop:
|
|
99
100
|
if "Authorization" in message:
|
100
101
|
print(message)
|
101
102
|
message = "Please provide an Expected Parrot API key."
|
102
|
-
raise
|
103
|
+
raise CoopServerResponseError(message)
|
103
104
|
|
104
105
|
def _json_handle_none(self, value: Any) -> Any:
|
105
106
|
"""
|
@@ -116,7 +117,7 @@ class Coop:
|
|
116
117
|
Resolve the uuid from a uuid or a url.
|
117
118
|
"""
|
118
119
|
if not url and not uuid:
|
119
|
-
raise
|
120
|
+
raise CoopNoUUIDError("No uuid or url provided for the object.")
|
120
121
|
if not uuid and url:
|
121
122
|
uuid = url.split("/")[-1]
|
122
123
|
return uuid
|
@@ -521,7 +522,7 @@ class Coop:
|
|
521
522
|
self._resolve_server_response(response)
|
522
523
|
response_json = response.json()
|
523
524
|
return {
|
524
|
-
"uuid": response_json.get("
|
525
|
+
"uuid": response_json.get("job_uuid"),
|
525
526
|
"description": response_json.get("description"),
|
526
527
|
"status": response_json.get("status"),
|
527
528
|
"iterations": response_json.get("iterations"),
|
@@ -529,29 +530,41 @@ class Coop:
|
|
529
530
|
"version": self._edsl_version,
|
530
531
|
}
|
531
532
|
|
532
|
-
def remote_inference_get(
|
533
|
+
def remote_inference_get(
|
534
|
+
self, job_uuid: Optional[str] = None, results_uuid: Optional[str] = None
|
535
|
+
) -> dict:
|
533
536
|
"""
|
534
537
|
Get the details of a remote inference job.
|
538
|
+
You can pass either the job uuid or the results uuid as a parameter.
|
539
|
+
If you pass both, the job uuid will be prioritized.
|
535
540
|
|
536
541
|
:param job_uuid: The UUID of the EDSL job.
|
542
|
+
:param results_uuid: The UUID of the results associated with the EDSL job.
|
537
543
|
|
538
544
|
>>> coop.remote_inference_get("9f8484ee-b407-40e4-9652-4133a7236c9c")
|
539
545
|
{'jobs_uuid': '9f8484ee-b407-40e4-9652-4133a7236c9c', 'results_uuid': 'dd708234-31bf-4fe1-8747-6e232625e026', 'results_url': 'https://www.expectedparrot.com/content/dd708234-31bf-4fe1-8747-6e232625e026', 'status': 'completed', 'reason': None, 'price': 16, 'version': '0.1.29.dev4'}
|
540
546
|
"""
|
547
|
+
if job_uuid is None and results_uuid is None:
|
548
|
+
raise ValueError("Either job_uuid or results_uuid must be provided.")
|
549
|
+
elif job_uuid is not None:
|
550
|
+
params = {"job_uuid": job_uuid}
|
551
|
+
else:
|
552
|
+
params = {"results_uuid": results_uuid}
|
553
|
+
|
541
554
|
response = self._send_server_request(
|
542
555
|
uri="api/v0/remote-inference",
|
543
556
|
method="GET",
|
544
|
-
params=
|
557
|
+
params=params,
|
545
558
|
)
|
546
559
|
self._resolve_server_response(response)
|
547
560
|
data = response.json()
|
548
561
|
return {
|
549
|
-
"
|
562
|
+
"job_uuid": data.get("job_uuid"),
|
550
563
|
"results_uuid": data.get("results_uuid"),
|
551
564
|
"results_url": f"{self.url}/content/{data.get('results_uuid')}",
|
552
565
|
"status": data.get("status"),
|
553
566
|
"reason": data.get("reason"),
|
554
|
-
"
|
567
|
+
"credits_consumed": data.get("price"),
|
555
568
|
"version": data.get("version"),
|
556
569
|
}
|
557
570
|
|
@@ -584,7 +597,10 @@ class Coop:
|
|
584
597
|
)
|
585
598
|
self._resolve_server_response(response)
|
586
599
|
response_json = response.json()
|
587
|
-
return
|
600
|
+
return {
|
601
|
+
"credits": response_json.get("cost_in_credits"),
|
602
|
+
"usd": response_json.get("cost_in_usd"),
|
603
|
+
}
|
588
604
|
|
589
605
|
################
|
590
606
|
# Remote Errors
|
@@ -649,6 +665,10 @@ class Coop:
|
|
649
665
|
return response_json
|
650
666
|
|
651
667
|
def fetch_prices(self) -> dict:
|
668
|
+
"""
|
669
|
+
Fetch model prices from Coop. If the request fails, return an empty dict.
|
670
|
+
"""
|
671
|
+
|
652
672
|
from edsl.coop.PriceFetcher import PriceFetcher
|
653
673
|
|
654
674
|
from edsl.config import CONFIG
|
@@ -659,6 +679,20 @@ class Coop:
|
|
659
679
|
else:
|
660
680
|
return {}
|
661
681
|
|
682
|
+
def fetch_rate_limit_config_vars(self) -> dict:
|
683
|
+
"""
|
684
|
+
Fetch a dict of rate limit config vars from Coop.
|
685
|
+
|
686
|
+
The dict keys are RPM and TPM variables like EDSL_SERVICE_RPM_OPENAI.
|
687
|
+
"""
|
688
|
+
response = self._send_server_request(
|
689
|
+
uri="api/v0/config-vars",
|
690
|
+
method="GET",
|
691
|
+
)
|
692
|
+
self._resolve_server_response(response)
|
693
|
+
data = response.json()
|
694
|
+
return data
|
695
|
+
|
662
696
|
|
663
697
|
if __name__ == "__main__":
|
664
698
|
sheet_data = fetch_sheet_data()
|
edsl/data/RemoteCacheSync.py
ADDED
@@ -0,0 +1,97 @@
|
|
1
|
+
class RemoteCacheSync:
|
2
|
+
def __init__(
|
3
|
+
self, coop, cache, output_func, remote_cache=True, remote_cache_description=""
|
4
|
+
):
|
5
|
+
self.coop = coop
|
6
|
+
self.cache = cache
|
7
|
+
self._output = output_func
|
8
|
+
self.remote_cache = remote_cache
|
9
|
+
self.old_entry_keys = []
|
10
|
+
self.new_cache_entries = []
|
11
|
+
self.remote_cache_description = remote_cache_description
|
12
|
+
|
13
|
+
def __enter__(self):
|
14
|
+
if self.remote_cache:
|
15
|
+
self._sync_from_remote()
|
16
|
+
self.old_entry_keys = list(self.cache.keys())
|
17
|
+
return self
|
18
|
+
|
19
|
+
def __exit__(self, exc_type, exc_value, traceback):
|
20
|
+
if self.remote_cache:
|
21
|
+
self._sync_to_remote()
|
22
|
+
return False # Propagate exceptions
|
23
|
+
|
24
|
+
def _sync_from_remote(self):
|
25
|
+
cache_difference = self.coop.remote_cache_get_diff(self.cache.keys())
|
26
|
+
client_missing_cacheentries = cache_difference.get(
|
27
|
+
"client_missing_cacheentries", []
|
28
|
+
)
|
29
|
+
missing_entry_count = len(client_missing_cacheentries)
|
30
|
+
|
31
|
+
if missing_entry_count > 0:
|
32
|
+
self._output(
|
33
|
+
f"Updating local cache with {missing_entry_count:,} new "
|
34
|
+
f"{'entry' if missing_entry_count == 1 else 'entries'} from remote..."
|
35
|
+
)
|
36
|
+
self.cache.add_from_dict(
|
37
|
+
{entry.key: entry for entry in client_missing_cacheentries}
|
38
|
+
)
|
39
|
+
self._output("Local cache updated!")
|
40
|
+
else:
|
41
|
+
self._output("No new entries to add to local cache.")
|
42
|
+
|
43
|
+
def _sync_to_remote(self):
|
44
|
+
cache_difference = self.coop.remote_cache_get_diff(self.cache.keys())
|
45
|
+
server_missing_cacheentry_keys = cache_difference.get(
|
46
|
+
"server_missing_cacheentry_keys", []
|
47
|
+
)
|
48
|
+
server_missing_cacheentries = [
|
49
|
+
entry
|
50
|
+
for key in server_missing_cacheentry_keys
|
51
|
+
if (entry := self.cache.data.get(key)) is not None
|
52
|
+
]
|
53
|
+
|
54
|
+
new_cache_entries = [
|
55
|
+
entry
|
56
|
+
for entry in self.cache.values()
|
57
|
+
if entry.key not in self.old_entry_keys
|
58
|
+
]
|
59
|
+
server_missing_cacheentries.extend(new_cache_entries)
|
60
|
+
new_entry_count = len(server_missing_cacheentries)
|
61
|
+
|
62
|
+
if new_entry_count > 0:
|
63
|
+
self._output(
|
64
|
+
f"Updating remote cache with {new_entry_count:,} new "
|
65
|
+
f"{'entry' if new_entry_count == 1 else 'entries'}..."
|
66
|
+
)
|
67
|
+
self.coop.remote_cache_create_many(
|
68
|
+
server_missing_cacheentries,
|
69
|
+
visibility="private",
|
70
|
+
description=self.remote_cache_description,
|
71
|
+
)
|
72
|
+
self._output("Remote cache updated!")
|
73
|
+
else:
|
74
|
+
self._output("No new entries to add to remote cache.")
|
75
|
+
|
76
|
+
self._output(
|
77
|
+
f"There are {len(self.cache.keys()):,} entries in the local cache."
|
78
|
+
)
|
79
|
+
|
80
|
+
|
81
|
+
# # Usage example
|
82
|
+
# def run_job(self, n, progress_bar, cache, stop_on_exception, sidecar_model, print_exceptions, raise_validation_errors, use_remote_cache=True):
|
83
|
+
# with RemoteCacheSync(self.coop, cache, self._output, remote_cache=use_remote_cache):
|
84
|
+
# self._output("Running job...")
|
85
|
+
# results = self._run_local(
|
86
|
+
# n=n,
|
87
|
+
# progress_bar=progress_bar,
|
88
|
+
# cache=cache,
|
89
|
+
# stop_on_exception=stop_on_exception,
|
90
|
+
# sidecar_model=sidecar_model,
|
91
|
+
# print_exceptions=print_exceptions,
|
92
|
+
# raise_validation_errors=raise_validation_errors,
|
93
|
+
# )
|
94
|
+
# self._output("Job completed!")
|
95
|
+
|
96
|
+
# results.cache = cache.new_entries_cache()
|
97
|
+
# return results
|
edsl/exceptions/coop.py
CHANGED
edsl/inference_services/InferenceServiceABC.py
CHANGED
@@ -1,6 +1,7 @@
|
|
1
1
|
from abc import abstractmethod, ABC
|
2
2
|
import os
|
3
3
|
import re
|
4
|
+
from datetime import datetime, timedelta
|
4
5
|
from edsl.config import CONFIG
|
5
6
|
|
6
7
|
|
@@ -10,6 +11,8 @@ class InferenceServiceABC(ABC):
|
|
10
11
|
Anthropic: https://docs.anthropic.com/en/api/rate-limits
|
11
12
|
"""
|
12
13
|
|
14
|
+
_coop_config_vars = None
|
15
|
+
|
13
16
|
default_levels = {
|
14
17
|
"google": {"tpm": 2_000_000, "rpm": 15},
|
15
18
|
"openai": {"tpm": 2_000_000, "rpm": 10_000},
|
@@ -31,12 +34,37 @@ class InferenceServiceABC(ABC):
|
|
31
34
|
f"Class {cls.__name__} must have a 'model_exclude_list' attribute."
|
32
35
|
)
|
33
36
|
|
37
|
+
@classmethod
|
38
|
+
def _should_refresh_coop_config_vars(cls):
|
39
|
+
"""
|
40
|
+
Returns True if config vars have been fetched over 24 hours ago, and False otherwise.
|
41
|
+
"""
|
42
|
+
|
43
|
+
if cls._last_config_fetch is None:
|
44
|
+
return True
|
45
|
+
return (datetime.now() - cls._last_config_fetch) > timedelta(hours=24)
|
46
|
+
|
34
47
|
@classmethod
|
35
48
|
def _get_limt(cls, limit_type: str) -> int:
|
36
49
|
key = f"EDSL_SERVICE_{limit_type.upper()}_{cls._inference_service_.upper()}"
|
37
50
|
if key in os.environ:
|
38
51
|
return int(os.getenv(key))
|
39
52
|
|
53
|
+
if cls._coop_config_vars is None or cls._should_refresh_coop_config_vars():
|
54
|
+
try:
|
55
|
+
from edsl import Coop
|
56
|
+
|
57
|
+
c = Coop()
|
58
|
+
cls._coop_config_vars = c.fetch_rate_limit_config_vars()
|
59
|
+
cls._last_config_fetch = datetime.now()
|
60
|
+
if key in cls._coop_config_vars:
|
61
|
+
return cls._coop_config_vars[key]
|
62
|
+
except Exception:
|
63
|
+
cls._coop_config_vars = None
|
64
|
+
else:
|
65
|
+
if key in cls._coop_config_vars:
|
66
|
+
return cls._coop_config_vars[key]
|
67
|
+
|
40
68
|
if cls._inference_service_ in cls.default_levels:
|
41
69
|
return int(cls.default_levels[cls._inference_service_][limit_type])
|
42
70
|
|
edsl/inference_services/InferenceServicesCollection.py
CHANGED
@@ -56,13 +56,19 @@ class InferenceServicesCollection:
|
|
56
56
|
self.services.append(service)
|
57
57
|
|
58
58
|
def create_model_factory(self, model_name: str, service_name=None, index=None):
|
59
|
+
from edsl.inference_services.TestService import TestService
|
60
|
+
|
61
|
+
if model_name == "test":
|
62
|
+
return TestService.create_model(model_name)
|
63
|
+
|
64
|
+
if service_name:
|
65
|
+
for service in self.services:
|
66
|
+
if service_name == service._inference_service_:
|
67
|
+
return service.create_model(model_name)
|
68
|
+
|
59
69
|
for service in self.services:
|
60
70
|
if model_name in self._get_service_available(service):
|
61
71
|
if service_name is None or service_name == service._inference_service_:
|
62
72
|
return service.create_model(model_name)
|
63
73
|
|
64
|
-
# if model_name == "test":
|
65
|
-
# from edsl.language_models import LanguageModel
|
66
|
-
# return LanguageModel(test = True)
|
67
|
-
|
68
74
|
raise Exception(f"Model {model_name} not found in any of the services")
|