edsl 0.1.38.dev2__py3-none-any.whl → 0.1.38.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +60 -31
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +18 -9
- edsl/agents/AgentList.py +59 -8
- edsl/agents/Invigilator.py +18 -7
- edsl/agents/InvigilatorBase.py +0 -19
- edsl/agents/PromptConstructor.py +5 -4
- edsl/config.py +8 -0
- edsl/coop/coop.py +74 -7
- edsl/data/Cache.py +27 -2
- edsl/data/CacheEntry.py +8 -3
- edsl/data/RemoteCacheSync.py +0 -19
- edsl/enums.py +2 -0
- edsl/inference_services/GoogleService.py +7 -15
- edsl/inference_services/PerplexityService.py +163 -0
- edsl/inference_services/registry.py +2 -0
- edsl/jobs/Jobs.py +88 -548
- edsl/jobs/JobsChecks.py +147 -0
- edsl/jobs/JobsPrompts.py +268 -0
- edsl/jobs/JobsRemoteInferenceHandler.py +239 -0
- edsl/jobs/interviews/Interview.py +11 -11
- edsl/jobs/runners/JobsRunnerAsyncio.py +140 -35
- edsl/jobs/runners/JobsRunnerStatus.py +0 -2
- edsl/jobs/tasks/TaskHistory.py +15 -16
- edsl/language_models/LanguageModel.py +44 -84
- edsl/language_models/ModelList.py +47 -1
- edsl/language_models/registry.py +57 -4
- edsl/prompts/Prompt.py +8 -3
- edsl/questions/QuestionBase.py +20 -16
- edsl/questions/QuestionExtract.py +3 -4
- edsl/questions/question_registry.py +36 -6
- edsl/results/CSSParameterizer.py +108 -0
- edsl/results/Dataset.py +146 -15
- edsl/results/DatasetExportMixin.py +231 -217
- edsl/results/DatasetTree.py +134 -4
- edsl/results/Result.py +18 -9
- edsl/results/Results.py +145 -51
- edsl/results/TableDisplay.py +198 -0
- edsl/results/table_display.css +78 -0
- edsl/scenarios/FileStore.py +187 -13
- edsl/scenarios/Scenario.py +61 -4
- edsl/scenarios/ScenarioJoin.py +127 -0
- edsl/scenarios/ScenarioList.py +237 -62
- edsl/surveys/Survey.py +16 -2
- edsl/surveys/SurveyFlowVisualizationMixin.py +67 -9
- edsl/surveys/instructions/Instruction.py +12 -0
- edsl/templates/error_reporting/interview_details.html +3 -3
- edsl/templates/error_reporting/interviews.html +18 -9
- edsl/utilities/utilities.py +15 -0
- {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev4.dist-info}/METADATA +2 -1
- {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev4.dist-info}/RECORD +53 -45
- {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev4.dist-info}/LICENSE +0 -0
- {edsl-0.1.38.dev2.dist-info → edsl-0.1.38.dev4.dist-info}/WHEEL +0 -0
edsl/Base.py
CHANGED
@@ -9,37 +9,46 @@ from uuid import UUID
|
|
9
9
|
|
10
10
|
# from edsl.utilities.MethodSuggesterMixin import MethodSuggesterMixin
|
11
11
|
|
12
|
+
from edsl.utilities.utilities import is_notebook
|
13
|
+
|
12
14
|
|
13
15
|
class RichPrintingMixin:
|
14
|
-
|
16
|
+
pass
|
15
17
|
|
16
|
-
def
|
17
|
-
|
18
|
-
from rich.console import Console
|
18
|
+
# def print(self):
|
19
|
+
# print(self)
|
19
20
|
|
20
|
-
with io.StringIO() as buf:
|
21
|
-
console = Console(file=buf, record=True)
|
22
|
-
table = self.rich_print()
|
23
|
-
console.print(table)
|
24
|
-
return console.export_text()
|
25
21
|
|
26
|
-
|
27
|
-
"""Return a string representation of the object for console printing."""
|
28
|
-
return self._for_console()
|
22
|
+
# """Mixin for rich printing and persistence of objects."""
|
29
23
|
|
30
|
-
|
31
|
-
|
32
|
-
|
24
|
+
# def _for_console(self):
|
25
|
+
# """Return a string representation of the object for console printing."""
|
26
|
+
# from rich.console import Console
|
33
27
|
|
34
|
-
|
35
|
-
|
28
|
+
# with io.StringIO() as buf:
|
29
|
+
# console = Console(file=buf, record=True)
|
30
|
+
# table = self.rich_print()
|
31
|
+
# console.print(table)
|
32
|
+
# return console.export_text()
|
36
33
|
|
37
|
-
|
38
|
-
|
39
|
-
|
34
|
+
# def __str__(self):
|
35
|
+
# """Return a string representation of the object for console printing."""
|
36
|
+
# # return self._for_console()
|
37
|
+
# return self.__repr__()
|
38
|
+
|
39
|
+
# def print(self):
|
40
|
+
# """Print the object to the console."""
|
41
|
+
# from edsl.utilities.utilities import is_notebook
|
42
|
+
|
43
|
+
# if is_notebook():
|
44
|
+
# from IPython.display import display
|
40
45
|
|
41
|
-
|
42
|
-
|
46
|
+
# display(self.rich_print())
|
47
|
+
# else:
|
48
|
+
# from rich.console import Console
|
49
|
+
|
50
|
+
# console = Console()
|
51
|
+
# console.print(self.rich_print())
|
43
52
|
|
44
53
|
|
45
54
|
class PersistenceMixin:
|
@@ -201,7 +210,7 @@ class DiffMethodsMixin:
|
|
201
210
|
|
202
211
|
|
203
212
|
class Base(
|
204
|
-
RichPrintingMixin,
|
213
|
+
# RichPrintingMixin,
|
205
214
|
PersistenceMixin,
|
206
215
|
DiffMethodsMixin,
|
207
216
|
ABC,
|
@@ -209,16 +218,36 @@ class Base(
|
|
209
218
|
):
|
210
219
|
"""Base class for all classes in the package."""
|
211
220
|
|
212
|
-
|
213
|
-
|
221
|
+
def json(self):
|
222
|
+
return json.loads(json.dumps(self.to_dict(add_edsl_version=False)))
|
223
|
+
|
224
|
+
def print(self, **kwargs):
|
225
|
+
if "format" in kwargs:
|
226
|
+
if kwargs["format"] not in ["html", "markdown", "rich", "latex"]:
|
227
|
+
raise ValueError(f"Format '{kwargs['format']}' not supported.")
|
214
228
|
|
215
|
-
|
216
|
-
|
217
|
-
|
229
|
+
if hasattr(self, "table"):
|
230
|
+
return self.table()
|
231
|
+
else:
|
232
|
+
return self
|
218
233
|
|
219
|
-
|
220
|
-
|
221
|
-
|
234
|
+
def __str__(self):
|
235
|
+
return self.__repr__()
|
236
|
+
|
237
|
+
def summary(self, format="table"):
|
238
|
+
from edsl import Scenario
|
239
|
+
|
240
|
+
d = self._summary()
|
241
|
+
if format == "table":
|
242
|
+
return Scenario(d).table()
|
243
|
+
if format == "dict":
|
244
|
+
return d
|
245
|
+
if format == "json":
|
246
|
+
return Scenario(d).json()
|
247
|
+
if format == "yaml":
|
248
|
+
return Scenario(d).yaml()
|
249
|
+
if format == "html":
|
250
|
+
return Scenario(d).table(tablefmt="html")
|
222
251
|
|
223
252
|
def keys(self):
|
224
253
|
"""Return the keys of the object."""
|
edsl/__version__.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.1.38.dev2"
|
1
|
+
__version__ = "0.1.38.dev4"
|
edsl/agents/Agent.py
CHANGED
@@ -243,6 +243,15 @@ class Agent(Base):
|
|
243
243
|
else:
|
244
244
|
return self._traits
|
245
245
|
|
246
|
+
def _repr_html_(self):
|
247
|
+
# d = self.to_dict(add_edsl_version=False)
|
248
|
+
d = self.traits
|
249
|
+
data = [[k, v] for k, v in d.items()]
|
250
|
+
from tabulate import tabulate
|
251
|
+
|
252
|
+
table = str(tabulate(data, headers=["keys", "values"], tablefmt="html"))
|
253
|
+
return f"<pre>{table}</pre>"
|
254
|
+
|
246
255
|
def rename(
|
247
256
|
self, old_name_or_dict: Union[str, dict], new_name: Optional[str] = None
|
248
257
|
) -> Agent:
|
@@ -631,10 +640,10 @@ class Agent(Base):
|
|
631
640
|
]
|
632
641
|
return f"{class_name}({', '.join(items)})"
|
633
642
|
|
634
|
-
def _repr_html_(self):
|
635
|
-
|
643
|
+
# def _repr_html_(self):
|
644
|
+
# from edsl.utilities.utilities import data_to_html
|
636
645
|
|
637
|
-
|
646
|
+
# return data_to_html(self.to_dict())
|
638
647
|
|
639
648
|
#######################
|
640
649
|
# SERIALIZATION METHODS
|
@@ -669,9 +678,9 @@ class Agent(Base):
|
|
669
678
|
if dynamic_traits_func:
|
670
679
|
func = inspect.getsource(dynamic_traits_func)
|
671
680
|
raw_data["dynamic_traits_function_source_code"] = func
|
672
|
-
raw_data[
|
673
|
-
|
674
|
-
|
681
|
+
raw_data[
|
682
|
+
"dynamic_traits_function_name"
|
683
|
+
] = self.dynamic_traits_function_name
|
675
684
|
if hasattr(self, "answer_question_directly"):
|
676
685
|
raw_data.pop(
|
677
686
|
"answer_question_directly", None
|
@@ -685,9 +694,9 @@ class Agent(Base):
|
|
685
694
|
raw_data["answer_question_directly_source_code"] = inspect.getsource(
|
686
695
|
answer_question_directly_func
|
687
696
|
)
|
688
|
-
raw_data[
|
689
|
-
|
690
|
-
|
697
|
+
raw_data[
|
698
|
+
"answer_question_directly_function_name"
|
699
|
+
] = self.answer_question_directly_function_name
|
691
700
|
|
692
701
|
return raw_data
|
693
702
|
|
edsl/agents/AgentList.py
CHANGED
@@ -36,6 +36,10 @@ def is_iterable(obj):
|
|
36
36
|
class AgentList(UserList, Base):
|
37
37
|
"""A list of Agents."""
|
38
38
|
|
39
|
+
__documentation__ = (
|
40
|
+
"https://docs.expectedparrot.com/en/latest/agents.html#agentlist-class"
|
41
|
+
)
|
42
|
+
|
39
43
|
def __init__(self, data: Optional[list["Agent"]] = None):
|
40
44
|
"""Initialize a new AgentList.
|
41
45
|
|
@@ -57,7 +61,7 @@ class AgentList(UserList, Base):
|
|
57
61
|
random.shuffle(self.data)
|
58
62
|
return self
|
59
63
|
|
60
|
-
def sample(self, n: int, seed=
|
64
|
+
def sample(self, n: int, seed: Optional[str] = None) -> AgentList:
|
61
65
|
"""Return a random sample of agents.
|
62
66
|
|
63
67
|
:param n: The number of agents to sample.
|
@@ -65,9 +69,17 @@ class AgentList(UserList, Base):
|
|
65
69
|
"""
|
66
70
|
import random
|
67
71
|
|
68
|
-
|
72
|
+
if seed:
|
73
|
+
random.seed(seed)
|
69
74
|
return AgentList(random.sample(self.data, n))
|
70
75
|
|
76
|
+
def to_pandas(self):
|
77
|
+
"""Return a pandas DataFrame."""
|
78
|
+
return self.to_scenario_list().to_pandas()
|
79
|
+
|
80
|
+
def tally(self):
|
81
|
+
return self.to_scenario_list().tally()
|
82
|
+
|
71
83
|
def rename(self, old_name, new_name):
|
72
84
|
"""Rename a trait in the AgentList.
|
73
85
|
|
@@ -237,6 +249,7 @@ class AgentList(UserList, Base):
|
|
237
249
|
return dict_hash(self.to_dict(add_edsl_version=False, sorted=True))
|
238
250
|
|
239
251
|
def to_dict(self, sorted=False, add_edsl_version=True):
|
252
|
+
"""Serialize the AgentList to a dictionary."""
|
240
253
|
if sorted:
|
241
254
|
data = self.data[:]
|
242
255
|
data.sort(key=lambda x: hash(x))
|
@@ -264,23 +277,58 @@ class AgentList(UserList, Base):
|
|
264
277
|
def __repr__(self):
|
265
278
|
return f"AgentList({self.data})"
|
266
279
|
|
267
|
-
def
|
268
|
-
|
269
|
-
|
280
|
+
def _summary(self):
|
281
|
+
return {
|
282
|
+
"EDSL Class": "AgentList",
|
283
|
+
"Number of agents": len(self),
|
284
|
+
"Agent trait fields": self.all_traits,
|
285
|
+
}
|
270
286
|
|
271
287
|
def _repr_html_(self):
|
272
288
|
"""Return an HTML representation of the AgentList."""
|
273
|
-
|
289
|
+
footer = f"<a href={self.__documentation__}>(docs)</a>"
|
290
|
+
return str(self.summary(format="html")) + footer
|
291
|
+
|
292
|
+
def to_csv(self, file_path: str):
|
293
|
+
"""Save the AgentList to a CSV file.
|
294
|
+
|
295
|
+
:param file_path: The path to the CSV file.
|
296
|
+
"""
|
297
|
+
self.to_scenario_list().to_csv(file_path)
|
274
298
|
|
275
|
-
|
299
|
+
def to_list(self, include_agent_name=False) -> list[tuple]:
|
300
|
+
"""Return a list of tuples."""
|
301
|
+
return self.to_scenario_list(include_agent_name).to_list()
|
276
302
|
|
277
|
-
def to_scenario_list(self) -> ScenarioList:
|
303
|
+
def to_scenario_list(self, include_agent_name=False) -> ScenarioList:
|
278
304
|
"""Return a list of scenarios."""
|
279
305
|
from edsl.scenarios.ScenarioList import ScenarioList
|
280
306
|
from edsl.scenarios.Scenario import Scenario
|
281
307
|
|
308
|
+
if include_agent_name:
|
309
|
+
return ScenarioList(
|
310
|
+
[
|
311
|
+
Scenario(agent.traits | {"agent_name": agent.name})
|
312
|
+
for agent in self.data
|
313
|
+
]
|
314
|
+
)
|
282
315
|
return ScenarioList([Scenario(agent.traits) for agent in self.data])
|
283
316
|
|
317
|
+
def table(
|
318
|
+
self,
|
319
|
+
*fields,
|
320
|
+
tablefmt: Optional[str] = None,
|
321
|
+
pretty_labels: Optional[dict] = None,
|
322
|
+
) -> Table:
|
323
|
+
return (
|
324
|
+
self.to_scenario_list()
|
325
|
+
.to_dataset()
|
326
|
+
.table(*fields, tablefmt=tablefmt, pretty_labels=pretty_labels)
|
327
|
+
)
|
328
|
+
|
329
|
+
def tree(self, node_order: Optional[List[str]] = None):
|
330
|
+
return self.to_scenario_list().tree(node_order)
|
331
|
+
|
284
332
|
@classmethod
|
285
333
|
@remove_edsl_version
|
286
334
|
def from_dict(cls, data: dict) -> "AgentList":
|
@@ -315,6 +363,9 @@ class AgentList(UserList, Base):
|
|
315
363
|
|
316
364
|
:param trait_name: The name of the trait.
|
317
365
|
:param values: A list of values.
|
366
|
+
|
367
|
+
>>> AgentList.from_list('age', [22, 23])
|
368
|
+
AgentList([Agent(traits = {'age': 22}), Agent(traits = {'age': 23})])
|
318
369
|
"""
|
319
370
|
from edsl.agents.Agent import Agent
|
320
371
|
|
edsl/agents/Invigilator.py
CHANGED
@@ -45,6 +45,10 @@ class InvigilatorAI(InvigilatorBase):
|
|
45
45
|
|
46
46
|
params.update({"iteration": self.iteration, "cache": self.cache})
|
47
47
|
|
48
|
+
params.update({"invigilator": self})
|
49
|
+
# if hasattr(self.question, "answer_template"):
|
50
|
+
# breakpoint()
|
51
|
+
|
48
52
|
agent_response_dict: AgentResponseDict = await self.model.async_get_response(
|
49
53
|
**params
|
50
54
|
)
|
@@ -83,19 +87,26 @@ class InvigilatorAI(InvigilatorBase):
|
|
83
87
|
exception_occurred = None
|
84
88
|
validated = False
|
85
89
|
try:
|
86
|
-
# if the question has jinja parameters, it
|
87
|
-
# with those all filled in & then validate that
|
88
|
-
# breakpoint()
|
90
|
+
# if the question has jinja parameters, it is easier to make a new question with the parameters
|
89
91
|
if self.question.parameters:
|
90
92
|
prior_answers_dict = self.prompt_constructor.prior_answers_dict()
|
93
|
+
|
94
|
+
# question options have be treated differently because of dynamic question
|
95
|
+
# this logic is all in the prompt constructor
|
96
|
+
if "question_options" in self.question.data:
|
97
|
+
new_question_options = (
|
98
|
+
self.prompt_constructor._get_question_options(
|
99
|
+
self.question.data
|
100
|
+
)
|
101
|
+
)
|
102
|
+
if new_question_options != self.question.data["question_options"]:
|
103
|
+
# I don't love this direct writing but it seems to work
|
104
|
+
self.question.question_options = new_question_options
|
105
|
+
|
91
106
|
question_with_validators = self.question.render(
|
92
107
|
self.scenario | prior_answers_dict
|
93
108
|
)
|
94
109
|
question_with_validators.use_code = self.question.use_code
|
95
|
-
# if question_with_validators.parameters:
|
96
|
-
# raise ValueError(
|
97
|
-
# f"The question still has parameters after rendering: {question_with_validators}"
|
98
|
-
# )
|
99
110
|
else:
|
100
111
|
question_with_validators = self.question
|
101
112
|
|
edsl/agents/InvigilatorBase.py
CHANGED
@@ -172,25 +172,6 @@ class InvigilatorBase(ABC):
|
|
172
172
|
}
|
173
173
|
return EDSLResultObjectInput(**data)
|
174
174
|
|
175
|
-
# breakpoint()
|
176
|
-
# if hasattr(self, "augmented_model_response"):
|
177
|
-
# import json
|
178
|
-
|
179
|
-
# generated_tokens = json.loads(self.augmented_model_response)["answer"][
|
180
|
-
# "generated_tokens"
|
181
|
-
# ]
|
182
|
-
# else:
|
183
|
-
# generated_tokens = "Filled in by InvigilatorBase.get_failed_task_result"
|
184
|
-
# agent_response_dict = AgentResponseDict(
|
185
|
-
# answer=None,
|
186
|
-
# comment="Failed to get usable response",
|
187
|
-
# generated_tokens=generated_tokens,
|
188
|
-
# question_name=self.question.question_name,
|
189
|
-
# prompts=self.get_prompts(),
|
190
|
-
# )
|
191
|
-
# # breakpoint()
|
192
|
-
# return agent_response_dict
|
193
|
-
|
194
175
|
def get_prompts(self) -> Dict[str, Prompt]:
|
195
176
|
"""Return the prompt used."""
|
196
177
|
|
edsl/agents/PromptConstructor.py
CHANGED
@@ -169,6 +169,8 @@ class PromptConstructor:
|
|
169
169
|
|
170
170
|
placeholder = ["<< Option 1 - Placholder >>", "<< Option 2 - Placholder >>"]
|
171
171
|
|
172
|
+
# print("Question options entry: ", question_options_entry)
|
173
|
+
|
172
174
|
if isinstance(question_options_entry, str):
|
173
175
|
env = Environment()
|
174
176
|
parsed_content = env.parse(question_options_entry)
|
@@ -200,13 +202,12 @@ class PromptConstructor:
|
|
200
202
|
# e.g., {'question_text': 'Do you like school?', 'question_name': 'q0', 'question_options': ['yes', 'no']}
|
201
203
|
question_data = self.question.data.copy()
|
202
204
|
|
203
|
-
if
|
205
|
+
if (
|
206
|
+
"question_options" in question_data
|
207
|
+
): # is this a question with question options?
|
204
208
|
question_options = self._get_question_options(question_data)
|
205
209
|
question_data["question_options"] = question_options
|
206
210
|
|
207
|
-
# check to see if the question_options is actually a string
|
208
|
-
# This is used when the user is using the question_options as a variable from a scenario
|
209
|
-
# if "question_options" in question_data:
|
210
211
|
replacement_dict = self.build_replacement_dict(question_data)
|
211
212
|
rendered_instructions = question_prompt.render(replacement_dict)
|
212
213
|
|
edsl/config.py
CHANGED
@@ -61,6 +61,14 @@ CONFIG_MAP = {
|
|
61
61
|
"default": "https://www.expectedparrot.com",
|
62
62
|
"info": "This config var holds the URL of the Expected Parrot API.",
|
63
63
|
},
|
64
|
+
"EDSL_MAX_CONCURRENT_TASKS": {
|
65
|
+
"default": "500",
|
66
|
+
"info": "This config var determines the maximum number of concurrent tasks that can be run by the async job-runner",
|
67
|
+
},
|
68
|
+
"EDSL_OPEN_EXCEPTION_REPORT_URL": {
|
69
|
+
"default": "False",
|
70
|
+
"info": "This config var determines whether to open the exception report URL in the browser",
|
71
|
+
},
|
64
72
|
}
|
65
73
|
|
66
74
|
|
edsl/coop/coop.py
CHANGED
@@ -102,12 +102,57 @@ class Coop:
|
|
102
102
|
|
103
103
|
return response
|
104
104
|
|
105
|
+
def _get_latest_stable_version(self, version: str) -> str:
|
106
|
+
"""
|
107
|
+
Extract the latest stable PyPI version from a version string.
|
108
|
+
|
109
|
+
Examples:
|
110
|
+
- Decrement the patch number of a dev version: "0.1.38.dev1" -> "0.1.37"
|
111
|
+
- Return a stable version as is: "0.1.37" -> "0.1.37"
|
112
|
+
"""
|
113
|
+
if "dev" not in version:
|
114
|
+
return version
|
115
|
+
else:
|
116
|
+
# For 0.1.38.dev1, split into ["0", "1", "38", "dev1"]
|
117
|
+
major, minor, patch = version.split(".")[:3]
|
118
|
+
|
119
|
+
current_patch = int(patch)
|
120
|
+
latest_patch = current_patch - 1
|
121
|
+
return f"{major}.{minor}.{latest_patch}"
|
122
|
+
|
123
|
+
def _user_version_is_outdated(
|
124
|
+
self, user_version_str: str, server_version_str: str
|
125
|
+
) -> bool:
|
126
|
+
"""
|
127
|
+
Check if the user's EDSL version is outdated compared to the server's.
|
128
|
+
"""
|
129
|
+
server_stable_version_str = self._get_latest_stable_version(server_version_str)
|
130
|
+
user_stable_version_str = self._get_latest_stable_version(user_version_str)
|
131
|
+
|
132
|
+
# Turn the version strings into tuples of ints for comparison
|
133
|
+
user_stable_version = tuple(map(int, user_stable_version_str.split(".")))
|
134
|
+
server_stable_version = tuple(map(int, server_stable_version_str.split(".")))
|
135
|
+
|
136
|
+
return user_stable_version < server_stable_version
|
137
|
+
|
105
138
|
def _resolve_server_response(
|
106
139
|
self, response: requests.Response, check_api_key: bool = True
|
107
140
|
) -> None:
|
108
141
|
"""
|
109
142
|
Check the response from the server and raise errors as appropriate.
|
110
143
|
"""
|
144
|
+
# Get EDSL version from header
|
145
|
+
server_edsl_version = response.headers.get("X-EDSL-Version")
|
146
|
+
|
147
|
+
if server_edsl_version:
|
148
|
+
if self._user_version_is_outdated(
|
149
|
+
user_version_str=self._edsl_version,
|
150
|
+
server_version_str=server_edsl_version,
|
151
|
+
):
|
152
|
+
print(
|
153
|
+
"Please upgrade your EDSL version to access our latest features. To upgrade, open your terminal and run `pip upgrade edsl`"
|
154
|
+
)
|
155
|
+
|
111
156
|
if response.status_code >= 400:
|
112
157
|
message = response.json().get("detail")
|
113
158
|
# print(response.text)
|
@@ -580,7 +625,7 @@ class Coop:
|
|
580
625
|
|
581
626
|
>>> job = Jobs.example()
|
582
627
|
>>> coop.remote_inference_create(job=job, description="My job")
|
583
|
-
{'uuid': '9f8484ee-b407-40e4-9652-4133a7236c9c', 'description': 'My job', 'status': 'queued', 'visibility': 'unlisted', 'version': '0.1.
|
628
|
+
{'uuid': '9f8484ee-b407-40e4-9652-4133a7236c9c', 'description': 'My job', 'status': 'queued', 'iterations': None, 'visibility': 'unlisted', 'version': '0.1.38.dev1'}
|
584
629
|
"""
|
585
630
|
response = self._send_server_request(
|
586
631
|
uri="api/v0/remote-inference",
|
@@ -621,7 +666,7 @@ class Coop:
|
|
621
666
|
:param results_uuid: The UUID of the results associated with the EDSL job.
|
622
667
|
|
623
668
|
>>> coop.remote_inference_get("9f8484ee-b407-40e4-9652-4133a7236c9c")
|
624
|
-
{'
|
669
|
+
{'job_uuid': '9f8484ee-b407-40e4-9652-4133a7236c9c', 'results_uuid': 'dd708234-31bf-4fe1-8747-6e232625e026', 'results_url': 'https://www.expectedparrot.com/content/dd708234-31bf-4fe1-8747-6e232625e026', 'latest_error_report_uuid': None, 'latest_error_report_url': None, 'status': 'completed', 'reason': None, 'credits_consumed': 0.35, 'version': '0.1.38.dev1'}
|
625
670
|
"""
|
626
671
|
if job_uuid is None and results_uuid is None:
|
627
672
|
raise ValueError("Either job_uuid or results_uuid must be provided.")
|
@@ -637,10 +682,28 @@ class Coop:
|
|
637
682
|
)
|
638
683
|
self._resolve_server_response(response)
|
639
684
|
data = response.json()
|
685
|
+
|
686
|
+
results_uuid = data.get("results_uuid")
|
687
|
+
latest_error_report_uuid = data.get("latest_error_report_uuid")
|
688
|
+
|
689
|
+
if results_uuid is None:
|
690
|
+
results_url = None
|
691
|
+
else:
|
692
|
+
results_url = f"{self.url}/content/{results_uuid}"
|
693
|
+
|
694
|
+
if latest_error_report_uuid is None:
|
695
|
+
latest_error_report_url = None
|
696
|
+
else:
|
697
|
+
latest_error_report_url = (
|
698
|
+
f"{self.url}/home/remote-inference/error/{latest_error_report_uuid}"
|
699
|
+
)
|
700
|
+
|
640
701
|
return {
|
641
702
|
"job_uuid": data.get("job_uuid"),
|
642
|
-
"results_uuid":
|
643
|
-
"results_url":
|
703
|
+
"results_uuid": results_uuid,
|
704
|
+
"results_url": results_url,
|
705
|
+
"latest_error_report_uuid": latest_error_report_uuid,
|
706
|
+
"latest_error_report_url": latest_error_report_url,
|
644
707
|
"status": data.get("status"),
|
645
708
|
"reason": data.get("reason"),
|
646
709
|
"credits_consumed": data.get("price"),
|
@@ -657,7 +720,7 @@ class Coop:
|
|
657
720
|
|
658
721
|
>>> job = Jobs.example()
|
659
722
|
>>> coop.remote_inference_cost(input=job)
|
660
|
-
|
723
|
+
{'credits': 0.77, 'usd': 0.0076950000000000005}
|
661
724
|
"""
|
662
725
|
if isinstance(input, Jobs):
|
663
726
|
job = input
|
@@ -737,11 +800,15 @@ class Coop:
|
|
737
800
|
|
738
801
|
from edsl.config import CONFIG
|
739
802
|
|
740
|
-
if
|
803
|
+
if CONFIG.get("EDSL_FETCH_TOKEN_PRICES") == "True":
|
741
804
|
price_fetcher = PriceFetcher()
|
742
805
|
return price_fetcher.fetch_prices()
|
743
|
-
|
806
|
+
elif CONFIG.get("EDSL_FETCH_TOKEN_PRICES") == "False":
|
744
807
|
return {}
|
808
|
+
else:
|
809
|
+
raise ValueError(
|
810
|
+
"Invalid EDSL_FETCH_TOKEN_PRICES value---should be 'True' or 'False'."
|
811
|
+
)
|
745
812
|
|
746
813
|
def fetch_models(self) -> dict:
|
747
814
|
"""
|
edsl/data/Cache.py
CHANGED
@@ -27,6 +27,8 @@ class Cache(Base):
|
|
27
27
|
:param method: The method of storage to use for the cache.
|
28
28
|
"""
|
29
29
|
|
30
|
+
__documentation__ = "https://docs.expectedparrot.com/en/latest/data.html"
|
31
|
+
|
30
32
|
data = {}
|
31
33
|
|
32
34
|
def __init__(
|
@@ -409,10 +411,33 @@ class Cache(Base):
|
|
409
411
|
|
410
412
|
return d
|
411
413
|
|
414
|
+
def _summary(self):
|
415
|
+
return {"EDSL Class": "Cache", "Number of entries": len(self.data)}
|
416
|
+
|
412
417
|
def _repr_html_(self):
|
413
|
-
from edsl.utilities.utilities import data_to_html
|
418
|
+
# from edsl.utilities.utilities import data_to_html
|
419
|
+
# return data_to_html(self.to_dict())
|
420
|
+
footer = f"<a href={self.__documentation__}>(docs)</a>"
|
421
|
+
return str(self.summary(format="html")) + footer
|
422
|
+
|
423
|
+
def table(
|
424
|
+
self,
|
425
|
+
*fields,
|
426
|
+
tablefmt: Optional[str] = None,
|
427
|
+
pretty_labels: Optional[dict] = None,
|
428
|
+
) -> str:
|
429
|
+
return self.to_dataset().table(
|
430
|
+
*fields, tablefmt=tablefmt, pretty_labels=pretty_labels
|
431
|
+
)
|
432
|
+
|
433
|
+
def select(self, *fields):
|
434
|
+
return self.to_dataset().select(*fields)
|
435
|
+
|
436
|
+
def tree(self, node_list: Optional[list[str]] = None):
|
437
|
+
return self.to_scenario_list().tree(node_list)
|
414
438
|
|
415
|
-
|
439
|
+
def to_dataset(self):
|
440
|
+
return self.to_scenario_list().to_dataset()
|
416
441
|
|
417
442
|
@classmethod
|
418
443
|
@remove_edsl_version
|
edsl/data/CacheEntry.py
CHANGED
@@ -96,9 +96,14 @@ class CacheEntry:
|
|
96
96
|
"""
|
97
97
|
Returns an HTML representation of a CacheEntry.
|
98
98
|
"""
|
99
|
-
from edsl.utilities.utilities import data_to_html
|
100
|
-
|
101
|
-
|
99
|
+
# from edsl.utilities.utilities import data_to_html
|
100
|
+
# return data_to_html(self.to_dict())
|
101
|
+
d = self.to_dict()
|
102
|
+
data = [[k, v] for k, v in d.items()]
|
103
|
+
from tabulate import tabulate
|
104
|
+
|
105
|
+
table = str(tabulate(data, headers=["keys", "values"], tablefmt="html"))
|
106
|
+
return f"<pre>{table}</pre>"
|
102
107
|
|
103
108
|
def keys(self):
|
104
109
|
return list(self.to_dict().keys())
|
edsl/data/RemoteCacheSync.py
CHANGED
@@ -76,22 +76,3 @@ class RemoteCacheSync:
|
|
76
76
|
self._output(
|
77
77
|
f"There are {len(self.cache.keys()):,} entries in the local cache."
|
78
78
|
)
|
79
|
-
|
80
|
-
|
81
|
-
# # Usage example
|
82
|
-
# def run_job(self, n, progress_bar, cache, stop_on_exception, sidecar_model, print_exceptions, raise_validation_errors, use_remote_cache=True):
|
83
|
-
# with RemoteCacheSync(self.coop, cache, self._output, remote_cache=use_remote_cache):
|
84
|
-
# self._output("Running job...")
|
85
|
-
# results = self._run_local(
|
86
|
-
# n=n,
|
87
|
-
# progress_bar=progress_bar,
|
88
|
-
# cache=cache,
|
89
|
-
# stop_on_exception=stop_on_exception,
|
90
|
-
# sidecar_model=sidecar_model,
|
91
|
-
# print_exceptions=print_exceptions,
|
92
|
-
# raise_validation_errors=raise_validation_errors,
|
93
|
-
# )
|
94
|
-
# self._output("Job completed!")
|
95
|
-
|
96
|
-
# results.cache = cache.new_entries_cache()
|
97
|
-
# return results
|
edsl/enums.py
CHANGED
@@ -64,6 +64,7 @@ class InferenceServiceType(EnumWithChecks):
|
|
64
64
|
OLLAMA = "ollama"
|
65
65
|
MISTRAL = "mistral"
|
66
66
|
TOGETHER = "together"
|
67
|
+
PERPLEXITY = "perplexity"
|
67
68
|
|
68
69
|
|
69
70
|
service_to_api_keyname = {
|
@@ -78,6 +79,7 @@ service_to_api_keyname = {
|
|
78
79
|
InferenceServiceType.BEDROCK.value: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"],
|
79
80
|
InferenceServiceType.MISTRAL.value: "MISTRAL_API_KEY",
|
80
81
|
InferenceServiceType.TOGETHER.value: "TOGETHER_API_KEY",
|
82
|
+
InferenceServiceType.PERPLEXITY.value: "PERPLEXITY_API_KEY",
|
81
83
|
}
|
82
84
|
|
83
85
|
|