edsl 0.1.38__py3-none-any.whl → 0.1.38.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +31 -60
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +9 -18
- edsl/agents/AgentList.py +8 -59
- edsl/agents/Invigilator.py +7 -18
- edsl/agents/InvigilatorBase.py +19 -0
- edsl/agents/PromptConstructor.py +4 -5
- edsl/config.py +0 -8
- edsl/coop/coop.py +7 -74
- edsl/data/Cache.py +2 -27
- edsl/data/CacheEntry.py +3 -8
- edsl/data/RemoteCacheSync.py +19 -0
- edsl/enums.py +0 -2
- edsl/inference_services/GoogleService.py +15 -7
- edsl/inference_services/registry.py +0 -2
- edsl/jobs/Jobs.py +548 -88
- edsl/jobs/interviews/Interview.py +11 -11
- edsl/jobs/runners/JobsRunnerAsyncio.py +35 -140
- edsl/jobs/runners/JobsRunnerStatus.py +2 -0
- edsl/jobs/tasks/TaskHistory.py +16 -15
- edsl/language_models/LanguageModel.py +84 -44
- edsl/language_models/ModelList.py +1 -47
- edsl/language_models/registry.py +4 -57
- edsl/prompts/Prompt.py +3 -8
- edsl/questions/QuestionBase.py +16 -20
- edsl/questions/QuestionExtract.py +4 -3
- edsl/questions/question_registry.py +6 -36
- edsl/results/Dataset.py +15 -146
- edsl/results/DatasetExportMixin.py +217 -231
- edsl/results/DatasetTree.py +4 -134
- edsl/results/Result.py +9 -18
- edsl/results/Results.py +51 -145
- edsl/scenarios/FileStore.py +13 -187
- edsl/scenarios/Scenario.py +4 -61
- edsl/scenarios/ScenarioList.py +62 -237
- edsl/surveys/Survey.py +2 -16
- edsl/surveys/SurveyFlowVisualizationMixin.py +9 -67
- edsl/surveys/instructions/Instruction.py +0 -12
- edsl/templates/error_reporting/interview_details.html +3 -3
- edsl/templates/error_reporting/interviews.html +9 -18
- edsl/utilities/utilities.py +0 -15
- {edsl-0.1.38.dist-info → edsl-0.1.38.dev2.dist-info}/METADATA +1 -2
- {edsl-0.1.38.dist-info → edsl-0.1.38.dev2.dist-info}/RECORD +45 -53
- edsl/inference_services/PerplexityService.py +0 -163
- edsl/jobs/JobsChecks.py +0 -147
- edsl/jobs/JobsPrompts.py +0 -268
- edsl/jobs/JobsRemoteInferenceHandler.py +0 -239
- edsl/results/CSSParameterizer.py +0 -108
- edsl/results/TableDisplay.py +0 -198
- edsl/results/table_display.css +0 -78
- edsl/scenarios/ScenarioJoin.py +0 -127
- {edsl-0.1.38.dist-info → edsl-0.1.38.dev2.dist-info}/LICENSE +0 -0
- {edsl-0.1.38.dist-info → edsl-0.1.38.dev2.dist-info}/WHEEL +0 -0
edsl/Base.py
CHANGED
@@ -9,46 +9,37 @@ from uuid import UUID
|
|
9
9
|
|
10
10
|
# from edsl.utilities.MethodSuggesterMixin import MethodSuggesterMixin
|
11
11
|
|
12
|
-
from edsl.utilities.utilities import is_notebook
|
13
|
-
|
14
12
|
|
15
13
|
class RichPrintingMixin:
|
16
|
-
|
17
|
-
|
18
|
-
# def print(self):
|
19
|
-
# print(self)
|
20
|
-
|
14
|
+
"""Mixin for rich printing and persistence of objects."""
|
21
15
|
|
22
|
-
|
16
|
+
def _for_console(self):
|
17
|
+
"""Return a string representation of the object for console printing."""
|
18
|
+
from rich.console import Console
|
23
19
|
|
24
|
-
|
25
|
-
|
26
|
-
|
20
|
+
with io.StringIO() as buf:
|
21
|
+
console = Console(file=buf, record=True)
|
22
|
+
table = self.rich_print()
|
23
|
+
console.print(table)
|
24
|
+
return console.export_text()
|
27
25
|
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
# console.print(table)
|
32
|
-
# return console.export_text()
|
33
|
-
|
34
|
-
# def __str__(self):
|
35
|
-
# """Return a string representation of the object for console printing."""
|
36
|
-
# # return self._for_console()
|
37
|
-
# return self.__repr__()
|
26
|
+
def __str__(self):
|
27
|
+
"""Return a string representation of the object for console printing."""
|
28
|
+
return self._for_console()
|
38
29
|
|
39
|
-
|
40
|
-
|
41
|
-
|
30
|
+
def print(self):
|
31
|
+
"""Print the object to the console."""
|
32
|
+
from edsl.utilities.utilities import is_notebook
|
42
33
|
|
43
|
-
|
44
|
-
|
34
|
+
if is_notebook():
|
35
|
+
from IPython.display import display
|
45
36
|
|
46
|
-
|
47
|
-
|
48
|
-
|
37
|
+
display(self.rich_print())
|
38
|
+
else:
|
39
|
+
from rich.console import Console
|
49
40
|
|
50
|
-
|
51
|
-
|
41
|
+
console = Console()
|
42
|
+
console.print(self.rich_print())
|
52
43
|
|
53
44
|
|
54
45
|
class PersistenceMixin:
|
@@ -210,7 +201,7 @@ class DiffMethodsMixin:
|
|
210
201
|
|
211
202
|
|
212
203
|
class Base(
|
213
|
-
|
204
|
+
RichPrintingMixin,
|
214
205
|
PersistenceMixin,
|
215
206
|
DiffMethodsMixin,
|
216
207
|
ABC,
|
@@ -218,36 +209,16 @@ class Base(
|
|
218
209
|
):
|
219
210
|
"""Base class for all classes in the package."""
|
220
211
|
|
221
|
-
def
|
222
|
-
|
223
|
-
|
224
|
-
def print(self, **kwargs):
|
225
|
-
if "format" in kwargs:
|
226
|
-
if kwargs["format"] not in ["html", "markdown", "rich", "latex"]:
|
227
|
-
raise ValueError(f"Format '{kwargs['format']}' not supported.")
|
212
|
+
# def __getitem__(self, key):
|
213
|
+
# return getattr(self, key)
|
228
214
|
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
return self
|
215
|
+
# @abstractmethod
|
216
|
+
# def _repr_html_(self) -> str:
|
217
|
+
# raise NotImplementedError("This method is not implemented yet.")
|
233
218
|
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
def summary(self, format="table"):
|
238
|
-
from edsl import Scenario
|
239
|
-
|
240
|
-
d = self._summary()
|
241
|
-
if format == "table":
|
242
|
-
return Scenario(d).table()
|
243
|
-
if format == "dict":
|
244
|
-
return d
|
245
|
-
if format == "json":
|
246
|
-
return Scenario(d).json()
|
247
|
-
if format == "yaml":
|
248
|
-
return Scenario(d).yaml()
|
249
|
-
if format == "html":
|
250
|
-
return Scenario(d).table(tablefmt="html")
|
219
|
+
# @abstractmethod
|
220
|
+
# def _repr_(self) -> str:
|
221
|
+
# raise NotImplementedError("This method is not implemented yet.")
|
251
222
|
|
252
223
|
def keys(self):
|
253
224
|
"""Return the keys of the object."""
|
edsl/__version__.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.1.38"
|
1
|
+
__version__ = "0.1.38.dev2"
|
edsl/agents/Agent.py
CHANGED
@@ -243,15 +243,6 @@ class Agent(Base):
|
|
243
243
|
else:
|
244
244
|
return self._traits
|
245
245
|
|
246
|
-
def _repr_html_(self):
|
247
|
-
# d = self.to_dict(add_edsl_version=False)
|
248
|
-
d = self.traits
|
249
|
-
data = [[k, v] for k, v in d.items()]
|
250
|
-
from tabulate import tabulate
|
251
|
-
|
252
|
-
table = str(tabulate(data, headers=["keys", "values"], tablefmt="html"))
|
253
|
-
return f"<pre>{table}</pre>"
|
254
|
-
|
255
246
|
def rename(
|
256
247
|
self, old_name_or_dict: Union[str, dict], new_name: Optional[str] = None
|
257
248
|
) -> Agent:
|
@@ -640,10 +631,10 @@ class Agent(Base):
|
|
640
631
|
]
|
641
632
|
return f"{class_name}({', '.join(items)})"
|
642
633
|
|
643
|
-
|
644
|
-
|
634
|
+
def _repr_html_(self):
|
635
|
+
from edsl.utilities.utilities import data_to_html
|
645
636
|
|
646
|
-
|
637
|
+
return data_to_html(self.to_dict())
|
647
638
|
|
648
639
|
#######################
|
649
640
|
# SERIALIZATION METHODS
|
@@ -678,9 +669,9 @@ class Agent(Base):
|
|
678
669
|
if dynamic_traits_func:
|
679
670
|
func = inspect.getsource(dynamic_traits_func)
|
680
671
|
raw_data["dynamic_traits_function_source_code"] = func
|
681
|
-
raw_data[
|
682
|
-
|
683
|
-
|
672
|
+
raw_data["dynamic_traits_function_name"] = (
|
673
|
+
self.dynamic_traits_function_name
|
674
|
+
)
|
684
675
|
if hasattr(self, "answer_question_directly"):
|
685
676
|
raw_data.pop(
|
686
677
|
"answer_question_directly", None
|
@@ -694,9 +685,9 @@ class Agent(Base):
|
|
694
685
|
raw_data["answer_question_directly_source_code"] = inspect.getsource(
|
695
686
|
answer_question_directly_func
|
696
687
|
)
|
697
|
-
raw_data[
|
698
|
-
|
699
|
-
|
688
|
+
raw_data["answer_question_directly_function_name"] = (
|
689
|
+
self.answer_question_directly_function_name
|
690
|
+
)
|
700
691
|
|
701
692
|
return raw_data
|
702
693
|
|
edsl/agents/AgentList.py
CHANGED
@@ -36,10 +36,6 @@ def is_iterable(obj):
|
|
36
36
|
class AgentList(UserList, Base):
|
37
37
|
"""A list of Agents."""
|
38
38
|
|
39
|
-
__documentation__ = (
|
40
|
-
"https://docs.expectedparrot.com/en/latest/agents.html#agentlist-class"
|
41
|
-
)
|
42
|
-
|
43
39
|
def __init__(self, data: Optional[list["Agent"]] = None):
|
44
40
|
"""Initialize a new AgentList.
|
45
41
|
|
@@ -61,7 +57,7 @@ class AgentList(UserList, Base):
|
|
61
57
|
random.shuffle(self.data)
|
62
58
|
return self
|
63
59
|
|
64
|
-
def sample(self, n: int, seed
|
60
|
+
def sample(self, n: int, seed="edsl") -> AgentList:
|
65
61
|
"""Return a random sample of agents.
|
66
62
|
|
67
63
|
:param n: The number of agents to sample.
|
@@ -69,17 +65,9 @@ class AgentList(UserList, Base):
|
|
69
65
|
"""
|
70
66
|
import random
|
71
67
|
|
72
|
-
|
73
|
-
random.seed(seed)
|
68
|
+
random.seed(seed)
|
74
69
|
return AgentList(random.sample(self.data, n))
|
75
70
|
|
76
|
-
def to_pandas(self):
|
77
|
-
"""Return a pandas DataFrame."""
|
78
|
-
return self.to_scenario_list().to_pandas()
|
79
|
-
|
80
|
-
def tally(self):
|
81
|
-
return self.to_scenario_list().tally()
|
82
|
-
|
83
71
|
def rename(self, old_name, new_name):
|
84
72
|
"""Rename a trait in the AgentList.
|
85
73
|
|
@@ -249,7 +237,6 @@ class AgentList(UserList, Base):
|
|
249
237
|
return dict_hash(self.to_dict(add_edsl_version=False, sorted=True))
|
250
238
|
|
251
239
|
def to_dict(self, sorted=False, add_edsl_version=True):
|
252
|
-
"""Serialize the AgentList to a dictionary."""
|
253
240
|
if sorted:
|
254
241
|
data = self.data[:]
|
255
242
|
data.sort(key=lambda x: hash(x))
|
@@ -277,58 +264,23 @@ class AgentList(UserList, Base):
|
|
277
264
|
def __repr__(self):
|
278
265
|
return f"AgentList({self.data})"
|
279
266
|
|
280
|
-
def
|
281
|
-
|
282
|
-
|
283
|
-
"Number of agents": len(self),
|
284
|
-
"Agent trait fields": self.all_traits,
|
285
|
-
}
|
267
|
+
def print(self, format: Optional[str] = None):
|
268
|
+
"""Print the AgentList."""
|
269
|
+
print_json(json.dumps(self.to_dict(add_edsl_version=False)))
|
286
270
|
|
287
271
|
def _repr_html_(self):
|
288
272
|
"""Return an HTML representation of the AgentList."""
|
289
|
-
|
290
|
-
return str(self.summary(format="html")) + footer
|
291
|
-
|
292
|
-
def to_csv(self, file_path: str):
|
293
|
-
"""Save the AgentList to a CSV file.
|
294
|
-
|
295
|
-
:param file_path: The path to the CSV file.
|
296
|
-
"""
|
297
|
-
self.to_scenario_list().to_csv(file_path)
|
273
|
+
from edsl.utilities.utilities import data_to_html
|
298
274
|
|
299
|
-
|
300
|
-
"""Return a list of tuples."""
|
301
|
-
return self.to_scenario_list(include_agent_name).to_list()
|
275
|
+
return data_to_html(self.to_dict()["agent_list"])
|
302
276
|
|
303
|
-
def to_scenario_list(self
|
277
|
+
def to_scenario_list(self) -> ScenarioList:
|
304
278
|
"""Return a list of scenarios."""
|
305
279
|
from edsl.scenarios.ScenarioList import ScenarioList
|
306
280
|
from edsl.scenarios.Scenario import Scenario
|
307
281
|
|
308
|
-
if include_agent_name:
|
309
|
-
return ScenarioList(
|
310
|
-
[
|
311
|
-
Scenario(agent.traits | {"agent_name": agent.name})
|
312
|
-
for agent in self.data
|
313
|
-
]
|
314
|
-
)
|
315
282
|
return ScenarioList([Scenario(agent.traits) for agent in self.data])
|
316
283
|
|
317
|
-
def table(
|
318
|
-
self,
|
319
|
-
*fields,
|
320
|
-
tablefmt: Optional[str] = None,
|
321
|
-
pretty_labels: Optional[dict] = None,
|
322
|
-
) -> Table:
|
323
|
-
return (
|
324
|
-
self.to_scenario_list()
|
325
|
-
.to_dataset()
|
326
|
-
.table(*fields, tablefmt=tablefmt, pretty_labels=pretty_labels)
|
327
|
-
)
|
328
|
-
|
329
|
-
def tree(self, node_order: Optional[List[str]] = None):
|
330
|
-
return self.to_scenario_list().tree(node_order)
|
331
|
-
|
332
284
|
@classmethod
|
333
285
|
@remove_edsl_version
|
334
286
|
def from_dict(cls, data: dict) -> "AgentList":
|
@@ -363,9 +315,6 @@ class AgentList(UserList, Base):
|
|
363
315
|
|
364
316
|
:param trait_name: The name of the trait.
|
365
317
|
:param values: A list of values.
|
366
|
-
|
367
|
-
>>> AgentList.from_list('age', [22, 23])
|
368
|
-
AgentList([Agent(traits = {'age': 22}), Agent(traits = {'age': 23})])
|
369
318
|
"""
|
370
319
|
from edsl.agents.Agent import Agent
|
371
320
|
|
edsl/agents/Invigilator.py
CHANGED
@@ -45,10 +45,6 @@ class InvigilatorAI(InvigilatorBase):
|
|
45
45
|
|
46
46
|
params.update({"iteration": self.iteration, "cache": self.cache})
|
47
47
|
|
48
|
-
params.update({"invigilator": self})
|
49
|
-
# if hasattr(self.question, "answer_template"):
|
50
|
-
# breakpoint()
|
51
|
-
|
52
48
|
agent_response_dict: AgentResponseDict = await self.model.async_get_response(
|
53
49
|
**params
|
54
50
|
)
|
@@ -87,26 +83,19 @@ class InvigilatorAI(InvigilatorBase):
|
|
87
83
|
exception_occurred = None
|
88
84
|
validated = False
|
89
85
|
try:
|
90
|
-
# if the question has jinja parameters, it
|
86
|
+
# if the question has jinja parameters, it might be easier to make a new question
|
87
|
+
# with those all filled in & then validate that
|
88
|
+
# breakpoint()
|
91
89
|
if self.question.parameters:
|
92
90
|
prior_answers_dict = self.prompt_constructor.prior_answers_dict()
|
93
|
-
|
94
|
-
# question options have be treated differently because of dynamic question
|
95
|
-
# this logic is all in the prompt constructor
|
96
|
-
if "question_options" in self.question.data:
|
97
|
-
new_question_options = (
|
98
|
-
self.prompt_constructor._get_question_options(
|
99
|
-
self.question.data
|
100
|
-
)
|
101
|
-
)
|
102
|
-
if new_question_options != self.question.data["question_options"]:
|
103
|
-
# I don't love this direct writing but it seems to work
|
104
|
-
self.question.question_options = new_question_options
|
105
|
-
|
106
91
|
question_with_validators = self.question.render(
|
107
92
|
self.scenario | prior_answers_dict
|
108
93
|
)
|
109
94
|
question_with_validators.use_code = self.question.use_code
|
95
|
+
# if question_with_validators.parameters:
|
96
|
+
# raise ValueError(
|
97
|
+
# f"The question still has parameters after rendering: {question_with_validators}"
|
98
|
+
# )
|
110
99
|
else:
|
111
100
|
question_with_validators = self.question
|
112
101
|
|
edsl/agents/InvigilatorBase.py
CHANGED
@@ -172,6 +172,25 @@ class InvigilatorBase(ABC):
|
|
172
172
|
}
|
173
173
|
return EDSLResultObjectInput(**data)
|
174
174
|
|
175
|
+
# breakpoint()
|
176
|
+
# if hasattr(self, "augmented_model_response"):
|
177
|
+
# import json
|
178
|
+
|
179
|
+
# generated_tokens = json.loads(self.augmented_model_response)["answer"][
|
180
|
+
# "generated_tokens"
|
181
|
+
# ]
|
182
|
+
# else:
|
183
|
+
# generated_tokens = "Filled in by InvigilatorBase.get_failed_task_result"
|
184
|
+
# agent_response_dict = AgentResponseDict(
|
185
|
+
# answer=None,
|
186
|
+
# comment="Failed to get usable response",
|
187
|
+
# generated_tokens=generated_tokens,
|
188
|
+
# question_name=self.question.question_name,
|
189
|
+
# prompts=self.get_prompts(),
|
190
|
+
# )
|
191
|
+
# # breakpoint()
|
192
|
+
# return agent_response_dict
|
193
|
+
|
175
194
|
def get_prompts(self) -> Dict[str, Prompt]:
|
176
195
|
"""Return the prompt used."""
|
177
196
|
|
edsl/agents/PromptConstructor.py
CHANGED
@@ -169,8 +169,6 @@ class PromptConstructor:
|
|
169
169
|
|
170
170
|
placeholder = ["<< Option 1 - Placholder >>", "<< Option 2 - Placholder >>"]
|
171
171
|
|
172
|
-
# print("Question options entry: ", question_options_entry)
|
173
|
-
|
174
172
|
if isinstance(question_options_entry, str):
|
175
173
|
env = Environment()
|
176
174
|
parsed_content = env.parse(question_options_entry)
|
@@ -202,12 +200,13 @@ class PromptConstructor:
|
|
202
200
|
# e.g., {'question_text': 'Do you like school?', 'question_name': 'q0', 'question_options': ['yes', 'no']}
|
203
201
|
question_data = self.question.data.copy()
|
204
202
|
|
205
|
-
if
|
206
|
-
"question_options" in question_data
|
207
|
-
): # is this a question with question options?
|
203
|
+
if "question_options" in question_data:
|
208
204
|
question_options = self._get_question_options(question_data)
|
209
205
|
question_data["question_options"] = question_options
|
210
206
|
|
207
|
+
# check to see if the question_options is actually a string
|
208
|
+
# This is used when the user is using the question_options as a variable from a scenario
|
209
|
+
# if "question_options" in question_data:
|
211
210
|
replacement_dict = self.build_replacement_dict(question_data)
|
212
211
|
rendered_instructions = question_prompt.render(replacement_dict)
|
213
212
|
|
edsl/config.py
CHANGED
@@ -61,14 +61,6 @@ CONFIG_MAP = {
|
|
61
61
|
"default": "https://www.expectedparrot.com",
|
62
62
|
"info": "This config var holds the URL of the Expected Parrot API.",
|
63
63
|
},
|
64
|
-
"EDSL_MAX_CONCURRENT_TASKS": {
|
65
|
-
"default": "500",
|
66
|
-
"info": "This config var determines the maximum number of concurrent tasks that can be run by the async job-runner",
|
67
|
-
},
|
68
|
-
"EDSL_OPEN_EXCEPTION_REPORT_URL": {
|
69
|
-
"default": "False",
|
70
|
-
"info": "This config var determines whether to open the exception report URL in the browser",
|
71
|
-
},
|
72
64
|
}
|
73
65
|
|
74
66
|
|
edsl/coop/coop.py
CHANGED
@@ -102,57 +102,12 @@ class Coop:
|
|
102
102
|
|
103
103
|
return response
|
104
104
|
|
105
|
-
def _get_latest_stable_version(self, version: str) -> str:
|
106
|
-
"""
|
107
|
-
Extract the latest stable PyPI version from a version string.
|
108
|
-
|
109
|
-
Examples:
|
110
|
-
- Decrement the patch number of a dev version: "0.1.38.dev1" -> "0.1.37"
|
111
|
-
- Return a stable version as is: "0.1.37" -> "0.1.37"
|
112
|
-
"""
|
113
|
-
if "dev" not in version:
|
114
|
-
return version
|
115
|
-
else:
|
116
|
-
# For 0.1.38.dev1, split into ["0", "1", "38", "dev1"]
|
117
|
-
major, minor, patch = version.split(".")[:3]
|
118
|
-
|
119
|
-
current_patch = int(patch)
|
120
|
-
latest_patch = current_patch - 1
|
121
|
-
return f"{major}.{minor}.{latest_patch}"
|
122
|
-
|
123
|
-
def _user_version_is_outdated(
|
124
|
-
self, user_version_str: str, server_version_str: str
|
125
|
-
) -> bool:
|
126
|
-
"""
|
127
|
-
Check if the user's EDSL version is outdated compared to the server's.
|
128
|
-
"""
|
129
|
-
server_stable_version_str = self._get_latest_stable_version(server_version_str)
|
130
|
-
user_stable_version_str = self._get_latest_stable_version(user_version_str)
|
131
|
-
|
132
|
-
# Turn the version strings into tuples of ints for comparison
|
133
|
-
user_stable_version = tuple(map(int, user_stable_version_str.split(".")))
|
134
|
-
server_stable_version = tuple(map(int, server_stable_version_str.split(".")))
|
135
|
-
|
136
|
-
return user_stable_version < server_stable_version
|
137
|
-
|
138
105
|
def _resolve_server_response(
|
139
106
|
self, response: requests.Response, check_api_key: bool = True
|
140
107
|
) -> None:
|
141
108
|
"""
|
142
109
|
Check the response from the server and raise errors as appropriate.
|
143
110
|
"""
|
144
|
-
# Get EDSL version from header
|
145
|
-
server_edsl_version = response.headers.get("X-EDSL-Version")
|
146
|
-
|
147
|
-
if server_edsl_version:
|
148
|
-
if self._user_version_is_outdated(
|
149
|
-
user_version_str=self._edsl_version,
|
150
|
-
server_version_str=server_edsl_version,
|
151
|
-
):
|
152
|
-
print(
|
153
|
-
"Please upgrade your EDSL version to access our latest features. To upgrade, open your terminal and run `pip upgrade edsl`"
|
154
|
-
)
|
155
|
-
|
156
111
|
if response.status_code >= 400:
|
157
112
|
message = response.json().get("detail")
|
158
113
|
# print(response.text)
|
@@ -625,7 +580,7 @@ class Coop:
|
|
625
580
|
|
626
581
|
>>> job = Jobs.example()
|
627
582
|
>>> coop.remote_inference_create(job=job, description="My job")
|
628
|
-
{'uuid': '9f8484ee-b407-40e4-9652-4133a7236c9c', 'description': 'My job', 'status': 'queued', '
|
583
|
+
{'uuid': '9f8484ee-b407-40e4-9652-4133a7236c9c', 'description': 'My job', 'status': 'queued', 'visibility': 'unlisted', 'version': '0.1.29.dev4'}
|
629
584
|
"""
|
630
585
|
response = self._send_server_request(
|
631
586
|
uri="api/v0/remote-inference",
|
@@ -666,7 +621,7 @@ class Coop:
|
|
666
621
|
:param results_uuid: The UUID of the results associated with the EDSL job.
|
667
622
|
|
668
623
|
>>> coop.remote_inference_get("9f8484ee-b407-40e4-9652-4133a7236c9c")
|
669
|
-
{'
|
624
|
+
{'jobs_uuid': '9f8484ee-b407-40e4-9652-4133a7236c9c', 'results_uuid': 'dd708234-31bf-4fe1-8747-6e232625e026', 'results_url': 'https://www.expectedparrot.com/content/dd708234-31bf-4fe1-8747-6e232625e026', 'status': 'completed', 'reason': None, 'price': 16, 'version': '0.1.29.dev4'}
|
670
625
|
"""
|
671
626
|
if job_uuid is None and results_uuid is None:
|
672
627
|
raise ValueError("Either job_uuid or results_uuid must be provided.")
|
@@ -682,28 +637,10 @@ class Coop:
|
|
682
637
|
)
|
683
638
|
self._resolve_server_response(response)
|
684
639
|
data = response.json()
|
685
|
-
|
686
|
-
results_uuid = data.get("results_uuid")
|
687
|
-
latest_error_report_uuid = data.get("latest_error_report_uuid")
|
688
|
-
|
689
|
-
if results_uuid is None:
|
690
|
-
results_url = None
|
691
|
-
else:
|
692
|
-
results_url = f"{self.url}/content/{results_uuid}"
|
693
|
-
|
694
|
-
if latest_error_report_uuid is None:
|
695
|
-
latest_error_report_url = None
|
696
|
-
else:
|
697
|
-
latest_error_report_url = (
|
698
|
-
f"{self.url}/home/remote-inference/error/{latest_error_report_uuid}"
|
699
|
-
)
|
700
|
-
|
701
640
|
return {
|
702
641
|
"job_uuid": data.get("job_uuid"),
|
703
|
-
"results_uuid": results_uuid,
|
704
|
-
"results_url":
|
705
|
-
"latest_error_report_uuid": latest_error_report_uuid,
|
706
|
-
"latest_error_report_url": latest_error_report_url,
|
642
|
+
"results_uuid": data.get("results_uuid"),
|
643
|
+
"results_url": f"{self.url}/content/{data.get('results_uuid')}",
|
707
644
|
"status": data.get("status"),
|
708
645
|
"reason": data.get("reason"),
|
709
646
|
"credits_consumed": data.get("price"),
|
@@ -720,7 +657,7 @@ class Coop:
|
|
720
657
|
|
721
658
|
>>> job = Jobs.example()
|
722
659
|
>>> coop.remote_inference_cost(input=job)
|
723
|
-
|
660
|
+
16
|
724
661
|
"""
|
725
662
|
if isinstance(input, Jobs):
|
726
663
|
job = input
|
@@ -800,15 +737,11 @@ class Coop:
|
|
800
737
|
|
801
738
|
from edsl.config import CONFIG
|
802
739
|
|
803
|
-
if CONFIG.get("EDSL_FETCH_TOKEN_PRICES")
|
740
|
+
if bool(CONFIG.get("EDSL_FETCH_TOKEN_PRICES")):
|
804
741
|
price_fetcher = PriceFetcher()
|
805
742
|
return price_fetcher.fetch_prices()
|
806
|
-
elif CONFIG.get("EDSL_FETCH_TOKEN_PRICES") == "False":
|
807
|
-
return {}
|
808
743
|
else:
|
809
|
-
|
810
|
-
"Invalid EDSL_FETCH_TOKEN_PRICES value---should be 'True' or 'False'."
|
811
|
-
)
|
744
|
+
return {}
|
812
745
|
|
813
746
|
def fetch_models(self) -> dict:
|
814
747
|
"""
|
edsl/data/Cache.py
CHANGED
@@ -27,8 +27,6 @@ class Cache(Base):
|
|
27
27
|
:param method: The method of storage to use for the cache.
|
28
28
|
"""
|
29
29
|
|
30
|
-
__documentation__ = "https://docs.expectedparrot.com/en/latest/data.html"
|
31
|
-
|
32
30
|
data = {}
|
33
31
|
|
34
32
|
def __init__(
|
@@ -411,33 +409,10 @@ class Cache(Base):
|
|
411
409
|
|
412
410
|
return d
|
413
411
|
|
414
|
-
def _summary(self):
|
415
|
-
return {"EDSL Class": "Cache", "Number of entries": len(self.data)}
|
416
|
-
|
417
412
|
def _repr_html_(self):
|
418
|
-
|
419
|
-
# return data_to_html(self.to_dict())
|
420
|
-
footer = f"<a href={self.__documentation__}>(docs)</a>"
|
421
|
-
return str(self.summary(format="html")) + footer
|
422
|
-
|
423
|
-
def table(
|
424
|
-
self,
|
425
|
-
*fields,
|
426
|
-
tablefmt: Optional[str] = None,
|
427
|
-
pretty_labels: Optional[dict] = None,
|
428
|
-
) -> str:
|
429
|
-
return self.to_dataset().table(
|
430
|
-
*fields, tablefmt=tablefmt, pretty_labels=pretty_labels
|
431
|
-
)
|
432
|
-
|
433
|
-
def select(self, *fields):
|
434
|
-
return self.to_dataset().select(*fields)
|
435
|
-
|
436
|
-
def tree(self, node_list: Optional[list[str]] = None):
|
437
|
-
return self.to_scenario_list().tree(node_list)
|
413
|
+
from edsl.utilities.utilities import data_to_html
|
438
414
|
|
439
|
-
|
440
|
-
return self.to_scenario_list().to_dataset()
|
415
|
+
return data_to_html(self.to_dict())
|
441
416
|
|
442
417
|
@classmethod
|
443
418
|
@remove_edsl_version
|
edsl/data/CacheEntry.py
CHANGED
@@ -96,14 +96,9 @@ class CacheEntry:
|
|
96
96
|
"""
|
97
97
|
Returns an HTML representation of a CacheEntry.
|
98
98
|
"""
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
data = [[k, v] for k, v in d.items()]
|
103
|
-
from tabulate import tabulate
|
104
|
-
|
105
|
-
table = str(tabulate(data, headers=["keys", "values"], tablefmt="html"))
|
106
|
-
return f"<pre>{table}</pre>"
|
99
|
+
from edsl.utilities.utilities import data_to_html
|
100
|
+
|
101
|
+
return data_to_html(self.to_dict())
|
107
102
|
|
108
103
|
def keys(self):
|
109
104
|
return list(self.to_dict().keys())
|
edsl/data/RemoteCacheSync.py
CHANGED
@@ -76,3 +76,22 @@ class RemoteCacheSync:
|
|
76
76
|
self._output(
|
77
77
|
f"There are {len(self.cache.keys()):,} entries in the local cache."
|
78
78
|
)
|
79
|
+
|
80
|
+
|
81
|
+
# # Usage example
|
82
|
+
# def run_job(self, n, progress_bar, cache, stop_on_exception, sidecar_model, print_exceptions, raise_validation_errors, use_remote_cache=True):
|
83
|
+
# with RemoteCacheSync(self.coop, cache, self._output, remote_cache=use_remote_cache):
|
84
|
+
# self._output("Running job...")
|
85
|
+
# results = self._run_local(
|
86
|
+
# n=n,
|
87
|
+
# progress_bar=progress_bar,
|
88
|
+
# cache=cache,
|
89
|
+
# stop_on_exception=stop_on_exception,
|
90
|
+
# sidecar_model=sidecar_model,
|
91
|
+
# print_exceptions=print_exceptions,
|
92
|
+
# raise_validation_errors=raise_validation_errors,
|
93
|
+
# )
|
94
|
+
# self._output("Job completed!")
|
95
|
+
|
96
|
+
# results.cache = cache.new_entries_cache()
|
97
|
+
# return results
|
edsl/enums.py
CHANGED
@@ -64,7 +64,6 @@ class InferenceServiceType(EnumWithChecks):
|
|
64
64
|
OLLAMA = "ollama"
|
65
65
|
MISTRAL = "mistral"
|
66
66
|
TOGETHER = "together"
|
67
|
-
PERPLEXITY = "perplexity"
|
68
67
|
|
69
68
|
|
70
69
|
service_to_api_keyname = {
|
@@ -79,7 +78,6 @@ service_to_api_keyname = {
|
|
79
78
|
InferenceServiceType.BEDROCK.value: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"],
|
80
79
|
InferenceServiceType.MISTRAL.value: "MISTRAL_API_KEY",
|
81
80
|
InferenceServiceType.TOGETHER.value: "TOGETHER_API_KEY",
|
82
|
-
InferenceServiceType.PERPLEXITY.value: "PERPLEXITY_API_KEY",
|
83
81
|
}
|
84
82
|
|
85
83
|
|