edsl 0.1.37.dev5__py3-none-any.whl → 0.1.38__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- edsl/Base.py +63 -34
- edsl/BaseDiff.py +7 -7
- edsl/__init__.py +2 -1
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +23 -11
- edsl/agents/AgentList.py +86 -23
- edsl/agents/Invigilator.py +18 -7
- edsl/agents/InvigilatorBase.py +0 -19
- edsl/agents/PromptConstructor.py +5 -4
- edsl/auto/SurveyCreatorPipeline.py +1 -1
- edsl/auto/utilities.py +1 -1
- edsl/base/Base.py +3 -13
- edsl/config.py +8 -0
- edsl/coop/coop.py +89 -19
- edsl/data/Cache.py +45 -17
- edsl/data/CacheEntry.py +8 -3
- edsl/data/RemoteCacheSync.py +0 -19
- edsl/enums.py +2 -0
- edsl/exceptions/agents.py +4 -0
- edsl/exceptions/cache.py +5 -0
- edsl/inference_services/GoogleService.py +7 -15
- edsl/inference_services/PerplexityService.py +163 -0
- edsl/inference_services/registry.py +2 -0
- edsl/jobs/Jobs.py +110 -559
- edsl/jobs/JobsChecks.py +147 -0
- edsl/jobs/JobsPrompts.py +268 -0
- edsl/jobs/JobsRemoteInferenceHandler.py +239 -0
- edsl/jobs/buckets/TokenBucket.py +3 -0
- edsl/jobs/interviews/Interview.py +7 -7
- edsl/jobs/runners/JobsRunnerAsyncio.py +156 -28
- edsl/jobs/runners/JobsRunnerStatus.py +194 -196
- edsl/jobs/tasks/TaskHistory.py +27 -19
- edsl/language_models/LanguageModel.py +52 -90
- edsl/language_models/ModelList.py +67 -14
- edsl/language_models/registry.py +57 -4
- edsl/notebooks/Notebook.py +7 -8
- edsl/prompts/Prompt.py +8 -3
- edsl/questions/QuestionBase.py +38 -30
- edsl/questions/QuestionBaseGenMixin.py +1 -1
- edsl/questions/QuestionBasePromptsMixin.py +0 -17
- edsl/questions/QuestionExtract.py +3 -4
- edsl/questions/QuestionFunctional.py +10 -3
- edsl/questions/derived/QuestionTopK.py +2 -0
- edsl/questions/question_registry.py +36 -6
- edsl/results/CSSParameterizer.py +108 -0
- edsl/results/Dataset.py +146 -15
- edsl/results/DatasetExportMixin.py +231 -217
- edsl/results/DatasetTree.py +134 -4
- edsl/results/Result.py +31 -16
- edsl/results/Results.py +159 -65
- edsl/results/TableDisplay.py +198 -0
- edsl/results/table_display.css +78 -0
- edsl/scenarios/FileStore.py +187 -13
- edsl/scenarios/Scenario.py +73 -18
- edsl/scenarios/ScenarioJoin.py +127 -0
- edsl/scenarios/ScenarioList.py +251 -76
- edsl/surveys/MemoryPlan.py +1 -1
- edsl/surveys/Rule.py +1 -5
- edsl/surveys/RuleCollection.py +1 -1
- edsl/surveys/Survey.py +25 -19
- edsl/surveys/SurveyFlowVisualizationMixin.py +67 -9
- edsl/surveys/instructions/ChangeInstruction.py +9 -7
- edsl/surveys/instructions/Instruction.py +21 -7
- edsl/templates/error_reporting/interview_details.html +3 -3
- edsl/templates/error_reporting/interviews.html +18 -9
- edsl/{conjure → utilities}/naming_utilities.py +1 -1
- edsl/utilities/utilities.py +15 -0
- {edsl-0.1.37.dev5.dist-info → edsl-0.1.38.dist-info}/METADATA +2 -1
- {edsl-0.1.37.dev5.dist-info → edsl-0.1.38.dist-info}/RECORD +71 -77
- edsl/conjure/AgentConstructionMixin.py +0 -160
- edsl/conjure/Conjure.py +0 -62
- edsl/conjure/InputData.py +0 -659
- edsl/conjure/InputDataCSV.py +0 -48
- edsl/conjure/InputDataMixinQuestionStats.py +0 -182
- edsl/conjure/InputDataPyRead.py +0 -91
- edsl/conjure/InputDataSPSS.py +0 -8
- edsl/conjure/InputDataStata.py +0 -8
- edsl/conjure/QuestionOptionMixin.py +0 -76
- edsl/conjure/QuestionTypeMixin.py +0 -23
- edsl/conjure/RawQuestion.py +0 -65
- edsl/conjure/SurveyResponses.py +0 -7
- edsl/conjure/__init__.py +0 -9
- edsl/conjure/examples/placeholder.txt +0 -0
- edsl/conjure/utilities.py +0 -201
- {edsl-0.1.37.dev5.dist-info → edsl-0.1.38.dist-info}/LICENSE +0 -0
- {edsl-0.1.37.dev5.dist-info → edsl-0.1.38.dist-info}/WHEEL +0 -0
edsl/coop/coop.py
CHANGED
@@ -28,11 +28,23 @@ class Coop:
         - Provide a URL directly, or use the default one.
         """
         self.api_key = api_key or os.getenv("EXPECTED_PARROT_API_KEY")
+
         self.url = url or CONFIG.EXPECTED_PARROT_URL
         if self.url.endswith("/"):
             self.url = self.url[:-1]
+        if "chick.expectedparrot" in self.url:
+            self.api_url = "https://chickapi.expectedparrot.com"
+        elif "expectedparrot" in self.url:
+            self.api_url = "https://api.expectedparrot.com"
+        elif "localhost:1234" in self.url:
+            self.api_url = "http://localhost:8000"
+        else:
+            self.api_url = self.url
         self._edsl_version = edsl.__version__

+    def get_progress_bar_url(self):
+        return f"{CONFIG.EXPECTED_PARROT_URL}"
+
     ################
     # BASIC METHODS
     ################
@@ -59,7 +71,7 @@ class Coop:
         """
         Send a request to the server and return the response.
         """
-        url = f"{self.url}/{uri}"
+        url = f"{self.api_url}/{uri}"
         method = method.upper()
         if payload is None:
             timeout = 20
@@ -90,12 +102,57 @@ class Coop:

         return response

+    def _get_latest_stable_version(self, version: str) -> str:
+        """
+        Extract the latest stable PyPI version from a version string.
+
+        Examples:
+        - Decrement the patch number of a dev version: "0.1.38.dev1" -> "0.1.37"
+        - Return a stable version as is: "0.1.37" -> "0.1.37"
+        """
+        if "dev" not in version:
+            return version
+        else:
+            # For 0.1.38.dev1, split into ["0", "1", "38", "dev1"]
+            major, minor, patch = version.split(".")[:3]
+
+            current_patch = int(patch)
+            latest_patch = current_patch - 1
+            return f"{major}.{minor}.{latest_patch}"
+
+    def _user_version_is_outdated(
+        self, user_version_str: str, server_version_str: str
+    ) -> bool:
+        """
+        Check if the user's EDSL version is outdated compared to the server's.
+        """
+        server_stable_version_str = self._get_latest_stable_version(server_version_str)
+        user_stable_version_str = self._get_latest_stable_version(user_version_str)
+
+        # Turn the version strings into tuples of ints for comparison
+        user_stable_version = tuple(map(int, user_stable_version_str.split(".")))
+        server_stable_version = tuple(map(int, server_stable_version_str.split(".")))
+
+        return user_stable_version < server_stable_version
+
     def _resolve_server_response(
         self, response: requests.Response, check_api_key: bool = True
     ) -> None:
         """
         Check the response from the server and raise errors as appropriate.
         """
+        # Get EDSL version from header
+        server_edsl_version = response.headers.get("X-EDSL-Version")
+
+        if server_edsl_version:
+            if self._user_version_is_outdated(
+                user_version_str=self._edsl_version,
+                server_version_str=server_edsl_version,
+            ):
+                print(
+                    "Please upgrade your EDSL version to access our latest features. To upgrade, open your terminal and run `pip upgrade edsl`"
+                )
+
         if response.status_code >= 400:
             message = response.json().get("detail")
             # print(response.text)
@@ -568,7 +625,7 @@ class Coop:

         >>> job = Jobs.example()
         >>> coop.remote_inference_create(job=job, description="My job")
-        {'uuid': '9f8484ee-b407-40e4-9652-4133a7236c9c', 'description': 'My job', 'status': 'queued', 'visibility': 'unlisted', 'version': '0.1.
+        {'uuid': '9f8484ee-b407-40e4-9652-4133a7236c9c', 'description': 'My job', 'status': 'queued', 'iterations': None, 'visibility': 'unlisted', 'version': '0.1.38.dev1'}
         """
         response = self._send_server_request(
             uri="api/v0/remote-inference",
@@ -609,7 +666,7 @@ class Coop:
         :param results_uuid: The UUID of the results associated with the EDSL job.

         >>> coop.remote_inference_get("9f8484ee-b407-40e4-9652-4133a7236c9c")
-        {'
+        {'job_uuid': '9f8484ee-b407-40e4-9652-4133a7236c9c', 'results_uuid': 'dd708234-31bf-4fe1-8747-6e232625e026', 'results_url': 'https://www.expectedparrot.com/content/dd708234-31bf-4fe1-8747-6e232625e026', 'latest_error_report_uuid': None, 'latest_error_report_url': None, 'status': 'completed', 'reason': None, 'credits_consumed': 0.35, 'version': '0.1.38.dev1'}
         """
         if job_uuid is None and results_uuid is None:
             raise ValueError("Either job_uuid or results_uuid must be provided.")
@@ -625,10 +682,28 @@ class Coop:
         )
         self._resolve_server_response(response)
         data = response.json()
+
+        results_uuid = data.get("results_uuid")
+        latest_error_report_uuid = data.get("latest_error_report_uuid")
+
+        if results_uuid is None:
+            results_url = None
+        else:
+            results_url = f"{self.url}/content/{results_uuid}"
+
+        if latest_error_report_uuid is None:
+            latest_error_report_url = None
+        else:
+            latest_error_report_url = (
+                f"{self.url}/home/remote-inference/error/{latest_error_report_uuid}"
+            )
+
         return {
             "job_uuid": data.get("job_uuid"),
-            "results_uuid":
-            "results_url":
+            "results_uuid": results_uuid,
+            "results_url": results_url,
+            "latest_error_report_uuid": latest_error_report_uuid,
+            "latest_error_report_url": latest_error_report_url,
             "status": data.get("status"),
             "reason": data.get("reason"),
             "credits_consumed": data.get("price"),
@@ -645,7 +720,7 @@ class Coop:

         >>> job = Jobs.example()
         >>> coop.remote_inference_cost(input=job)
-
+        {'credits': 0.77, 'usd': 0.0076950000000000005}
         """
         if isinstance(input, Jobs):
             job = input
@@ -685,7 +760,7 @@ class Coop:
     async def remote_async_execute_model_call(
         self, model_dict: dict, user_prompt: str, system_prompt: str
     ) -> dict:
-        url = self.url + "/inference/"
+        url = self.api_url + "/inference/"
         # print("Now using url: ", url)
         data = {
             "model_dict": model_dict,
@@ -706,7 +781,7 @@ class Coop:
         ] = "lime_survey",
         email=None,
     ):
-        url = f"{self.url}/api/v0/export_to_{platform}"
+        url = f"{self.api_url}/api/v0/export_to_{platform}"
         if email:
             data = {"json_string": json.dumps({"survey": survey, "email": email})}
         else:
@@ -725,11 +800,15 @@ class Coop:

         from edsl.config import CONFIG

-        if
+        if CONFIG.get("EDSL_FETCH_TOKEN_PRICES") == "True":
             price_fetcher = PriceFetcher()
             return price_fetcher.fetch_prices()
-
+        elif CONFIG.get("EDSL_FETCH_TOKEN_PRICES") == "False":
             return {}
+        else:
+            raise ValueError(
+                "Invalid EDSL_FETCH_TOKEN_PRICES value---should be 'True' or 'False'."
+            )

     def fetch_models(self) -> dict:
         """
@@ -810,15 +889,6 @@ class Coop:
     load_dotenv()


-if __name__ == "__main__":
-    sheet_data = fetch_sheet_data()
-    if sheet_data:
-        print(f"Successfully fetched {len(sheet_data)} rows of data.")
-        print("First row:", sheet_data[0])
-    else:
-        print("Failed to fetch sheet data.")
-
-
 def main():
     """
     A simple example for the coop client
edsl/data/Cache.py
CHANGED
@@ -11,7 +11,8 @@ from typing import Optional, Union
 from edsl.Base import Base
 from edsl.data.CacheEntry import CacheEntry
 from edsl.utilities.utilities import dict_hash
-from edsl.utilities.decorators import
+from edsl.utilities.decorators import remove_edsl_version
+from edsl.exceptions.cache import CacheError


 class Cache(Base):
@@ -26,6 +27,8 @@ class Cache(Base):
     :param method: The method of storage to use for the cache.
     """

+    __documentation__ = "https://docs.expectedparrot.com/en/latest/data.html"
+
     data = {}

     def __init__(
@@ -58,7 +61,7 @@ class Cache(Base):

         self.filename = filename
         if filename and data:
-            raise
+            raise CacheError("Cannot provide both filename and data")
         if filename is None and data is None:
             data = {}
         if data is not None:
@@ -76,7 +79,7 @@ class Cache(Base):
             if os.path.exists(filename):
                 self.add_from_sqlite(filename)
             else:
-                raise
+                raise CacheError("Invalid file extension. Must be .jsonl or .db")

         self._perform_checks()

@@ -116,7 +119,7 @@ class Cache(Base):
         from edsl.data.CacheEntry import CacheEntry

         if any(not isinstance(value, CacheEntry) for value in self.data.values()):
-            raise
+            raise CacheError("Not all values are CacheEntry instances")
         if self.method is not None:
             warnings.warn("Argument `method` is deprecated", DeprecationWarning)

@@ -227,9 +230,9 @@ class Cache(Base):
         for key, value in new_data.items():
             if key in self.data:
                 if value != self.data[key]:
-                    raise
+                    raise CacheError("Mismatch in values")
             if not isinstance(value, CacheEntry):
-                raise
+                raise CacheError(f"Wrong type - the observed type is {type(value)}")

         self.new_entries.update(new_data)
         if write_now:
@@ -338,7 +341,7 @@ class Cache(Base):
         elif filename.endswith(".db"):
             self.write_sqlite_db(filename)
         else:
-            raise
+            raise CacheError("Invalid file extension. Must be .jsonl or .db")

     def write_jsonl(self, filename: str) -> None:
         """
@@ -396,20 +399,45 @@ class Cache(Base):
     ####################
     def __hash__(self):
         """Return the hash of the Cache."""
-        return dict_hash(self.
+        return dict_hash(self.to_dict(add_edsl_version=False))
+
+    def to_dict(self, add_edsl_version=True) -> dict:
+        d = {k: v.to_dict() for k, v in self.data.items()}
+        if add_edsl_version:
+            from edsl import __version__

-
-
+            d["edsl_version"] = __version__
+            d["edsl_class_name"] = "Cache"

-
-
-
-        return self.
+        return d
+
+    def _summary(self):
+        return {"EDSL Class": "Cache", "Number of entries": len(self.data)}

     def _repr_html_(self):
-        from edsl.utilities.utilities import data_to_html
+        # from edsl.utilities.utilities import data_to_html
+        # return data_to_html(self.to_dict())
+        footer = f"<a href={self.__documentation__}>(docs)</a>"
+        return str(self.summary(format="html")) + footer
+
+    def table(
+        self,
+        *fields,
+        tablefmt: Optional[str] = None,
+        pretty_labels: Optional[dict] = None,
+    ) -> str:
+        return self.to_dataset().table(
+            *fields, tablefmt=tablefmt, pretty_labels=pretty_labels
+        )
+
+    def select(self, *fields):
+        return self.to_dataset().select(*fields)
+
+    def tree(self, node_list: Optional[list[str]] = None):
+        return self.to_scenario_list().tree(node_list)

-
+    def to_dataset(self):
+        return self.to_scenario_list().to_dataset()

     @classmethod
     @remove_edsl_version
@@ -438,7 +466,7 @@ class Cache(Base):
         Combine two caches.
         """
         if not isinstance(other, Cache):
-            raise
+            raise CacheError("Can only add two caches together")
         self.data.update(other.data)
         return self

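Two consequences of the `Cache` changes are easy to miss in the raw diff: `to_dict` now takes an `add_edsl_version` flag, and `__hash__` deliberately uses the version-free form. A short sketch of the resulting behavior (assuming `CacheEntry.example()`, the example factory used elsewhere in the EDSL doctests):

```python
from edsl.data.Cache import Cache
from edsl.data.CacheEntry import CacheEntry

c = Cache(data={"example-key": CacheEntry.example()})

versioned = c.to_dict()                    # includes "edsl_version" and "edsl_class_name"
plain = c.to_dict(add_edsl_version=False)  # entry data only
assert "edsl_version" in versioned and "edsl_version" not in plain

# hash(c) is computed from the version-free dict, so caches with identical
# entries hash equally even when written by different EDSL versions.
print(hash(c))
```

The new `table`, `select`, `tree`, and `to_dataset` methods all route through `to_scenario_list()`, effectively giving the cache the same tabular inspection surface as other EDSL containers.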
edsl/data/CacheEntry.py
CHANGED
@@ -96,9 +96,14 @@ class CacheEntry:
         """
         Returns an HTML representation of a CacheEntry.
         """
-        from edsl.utilities.utilities import data_to_html
-
-
+        # from edsl.utilities.utilities import data_to_html
+        # return data_to_html(self.to_dict())
+        d = self.to_dict()
+        data = [[k, v] for k, v in d.items()]
+        from tabulate import tabulate
+
+        table = str(tabulate(data, headers=["keys", "values"], tablefmt="html"))
+        return f"<pre>{table}</pre>"

     def keys(self):
         return list(self.to_dict().keys())
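The rewritten `_repr_html_` replaces the old `data_to_html` helper with `tabulate`'s built-in HTML formatter. The same pattern works standalone (the dict contents here are illustrative stand-ins for `CacheEntry.to_dict()`):

```python
from tabulate import tabulate

d = {"model": "gpt-4o", "temperature": 0.5}  # stand-in for CacheEntry.to_dict()
table = tabulate([[k, v] for k, v in d.items()], headers=["keys", "values"], tablefmt="html")
print(f"<pre>{table}</pre>")
```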
edsl/data/RemoteCacheSync.py
CHANGED
@@ -76,22 +76,3 @@ class RemoteCacheSync:
         self._output(
             f"There are {len(self.cache.keys()):,} entries in the local cache."
         )
-
-
-# # Usage example
-# def run_job(self, n, progress_bar, cache, stop_on_exception, sidecar_model, print_exceptions, raise_validation_errors, use_remote_cache=True):
-#     with RemoteCacheSync(self.coop, cache, self._output, remote_cache=use_remote_cache):
-#         self._output("Running job...")
-#         results = self._run_local(
-#             n=n,
-#             progress_bar=progress_bar,
-#             cache=cache,
-#             stop_on_exception=stop_on_exception,
-#             sidecar_model=sidecar_model,
-#             print_exceptions=print_exceptions,
-#             raise_validation_errors=raise_validation_errors,
-#         )
-#         self._output("Job completed!")
-
-#         results.cache = cache.new_entries_cache()
-#         return results
edsl/enums.py
CHANGED
@@ -64,6 +64,7 @@ class InferenceServiceType(EnumWithChecks):
     OLLAMA = "ollama"
     MISTRAL = "mistral"
     TOGETHER = "together"
+    PERPLEXITY = "perplexity"


 service_to_api_keyname = {
@@ -78,6 +79,7 @@ service_to_api_keyname = {
     InferenceServiceType.BEDROCK.value: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"],
     InferenceServiceType.MISTRAL.value: "MISTRAL_API_KEY",
     InferenceServiceType.TOGETHER.value: "TOGETHER_API_KEY",
+    InferenceServiceType.PERPLEXITY.value: "PERPLEXITY_API_KEY",
 }

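With both the enum member and the key-name mapping in place, Perplexity credentials resolve through the same generic lookup as every other service:

```python
from edsl.enums import InferenceServiceType, service_to_api_keyname

keyname = service_to_api_keyname[InferenceServiceType.PERPLEXITY.value]
assert keyname == "PERPLEXITY_API_KEY"  # the env variable Perplexity models read
```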
edsl/exceptions/agents.py
CHANGED
@@ -1,6 +1,10 @@
 from edsl.exceptions.BaseException import BaseException


+class AgentListError(BaseException):
+    relevant_doc = "https://docs.expectedparrot.com/en/latest/agents.html#agent-lists"
+
+
 class AgentErrors(BaseException):
     relevant_doc = "https://docs.expectedparrot.com/en/latest/agents.html"

edsl/exceptions/cache.py
ADDED
edsl/inference_services/GoogleService.py
CHANGED
@@ -8,6 +8,7 @@ from google.api_core.exceptions import InvalidArgument
 from edsl.exceptions import MissingAPIKeyError
 from edsl.language_models.LanguageModel import LanguageModel
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+from edsl.coop import Coop

 safety_settings = [
     {
@@ -79,22 +80,8 @@ class GoogleService(InferenceServiceABC):
     api_token = None
     model = None

-    @classmethod
-    def initialize(cls):
-        if cls.api_token is None:
-            cls.api_token = os.getenv("GOOGLE_API_KEY")
-            if not cls.api_token:
-                raise MissingAPIKeyError(
-                    "GOOGLE_API_KEY environment variable is not set"
-                )
-            genai.configure(api_key=cls.api_token)
-            cls.generative_model = genai.GenerativeModel(
-                cls._model_, safety_settings=safety_settings
-            )
-
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.initialize()

     def get_generation_config(self) -> GenerationConfig:
         return GenerationConfig(
@@ -116,6 +103,7 @@ class GoogleService(InferenceServiceABC):
         if files_list is None:
             files_list = []

+        genai.configure(api_key=self.api_token)
         if (
             system_prompt is not None
             and system_prompt != ""
@@ -133,7 +121,11 @@ class GoogleService(InferenceServiceABC):
                 )
                 print("Will add system_prompt to user_prompt")
                 user_prompt = f"{system_prompt}\n{user_prompt}"
-
+        else:
+            self.generative_model = genai.GenerativeModel(
+                self._model_,
+                safety_settings=safety_settings,
+            )
         combined_prompt = [user_prompt]
         for file in files_list:
             if "google" not in file.external_locations:
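Net effect of the `GoogleService` changes: the eager class-level `initialize()` (which configured `genai` and cached a shared `GenerativeModel` at construction time) is gone. `genai.configure(...)` now runs inside each `async_execute_model_call`, and the `GenerativeModel` is rebuilt per call when no system instruction applies; presumably this defers the `GOOGLE_API_KEY` requirement from model construction to first use.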
edsl/inference_services/PerplexityService.py
ADDED
@@ -0,0 +1,163 @@
+import aiohttp
+import json
+import requests
+from typing import Any, List, Optional
+from edsl.inference_services.rate_limits_cache import rate_limits
+
+# from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+from edsl.language_models import LanguageModel
+
+from edsl.inference_services.OpenAIService import OpenAIService
+
+
+class PerplexityService(OpenAIService):
+    """Perplexity service class."""
+
+    _inference_service_ = "perplexity"
+    _env_key_name_ = "PERPLEXITY_API_KEY"
+    _base_url_ = "https://api.perplexity.ai"
+    _models_list_cache: List[str] = []
+    # default perplexity parameters
+    _parameters_ = {
+        "temperature": 0.5,
+        "max_tokens": 1000,
+        "top_p": 1,
+        "logprobs": False,
+        "top_logprobs": 3,
+    }
+
+    @classmethod
+    def available(cls) -> List[str]:
+        return [
+            "llama-3.1-sonar-huge-128k-online",
+            "llama-3.1-sonar-large-128k-online",
+            "llama-3.1-sonar-small-128k-online",
+        ]
+
+    @classmethod
+    def create_model(
+        cls, model_name="llama-3.1-sonar-large-128k-online", model_class_name=None
+    ) -> LanguageModel:
+        if model_class_name is None:
+            model_class_name = cls.to_class_name(model_name)
+
+        class LLM(LanguageModel):
+            """
+            Child class of LanguageModel for interacting with Perplexity models
+            """
+
+            key_sequence = cls.key_sequence
+            usage_sequence = cls.usage_sequence
+            input_token_name = cls.input_token_name
+            output_token_name = cls.output_token_name
+
+            _rpm = cls.get_rpm(cls)
+            _tpm = cls.get_tpm(cls)
+
+            _inference_service_ = cls._inference_service_
+            _model_ = model_name
+
+            _parameters_ = {
+                "temperature": 0.5,
+                "max_tokens": 1000,
+                "top_p": 1,
+                "frequency_penalty": 1,
+                "presence_penalty": 0,
+                # "logprobs": False, # Enable this returns 'Neither or both of logprobs and top_logprobs must be set.
+                # "top_logprobs": 3,
+            }
+
+            def sync_client(self):
+                return cls.sync_client()
+
+            def async_client(self):
+                return cls.async_client()
+
+            @classmethod
+            def available(cls) -> list[str]:
+                return cls.sync_client().models.list()
+
+            def get_headers(self) -> dict[str, Any]:
+                client = self.sync_client()
+                response = client.chat.completions.with_raw_response.create(
+                    messages=[
+                        {
+                            "role": "user",
+                            "content": "Say this is a test",
+                        }
+                    ],
+                    model=self.model,
+                )
+                return dict(response.headers)
+
+            def get_rate_limits(self) -> dict[str, Any]:
+                try:
+                    if "openai" in rate_limits:
+                        headers = rate_limits["openai"]
+
+                    else:
+                        headers = self.get_headers()
+
+                except Exception as e:
+                    return {
+                        "rpm": 10_000,
+                        "tpm": 2_000_000,
+                    }
+                else:
+                    return {
+                        "rpm": int(headers["x-ratelimit-limit-requests"]),
+                        "tpm": int(headers["x-ratelimit-limit-tokens"]),
+                    }
+
+            async def async_execute_model_call(
+                self,
+                user_prompt: str,
+                system_prompt: str = "",
+                files_list: Optional[List["Files"]] = None,
+                invigilator: Optional[
+                    "InvigilatorAI"
+                ] = None,  # TBD - can eventually be used for function-calling
+            ) -> dict[str, Any]:
+                """Calls the OpenAI API and returns the API response."""
+                if files_list:
+                    encoded_image = files_list[0].base64_string
+                    content = [{"type": "text", "text": user_prompt}]
+                    content.append(
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/jpeg;base64,{encoded_image}"
+                            },
+                        }
+                    )
+                else:
+                    content = user_prompt
+                client = self.async_client()
+
+                messages = [
+                    {"role": "system", "content": system_prompt},
+                    {"role": "user", "content": content},
+                ]
+                if system_prompt == "" and self.omit_system_prompt_if_empty:
+                    messages = messages[1:]
+
+                params = {
+                    "model": self.model,
+                    "messages": messages,
+                    "temperature": self.temperature,
+                    "max_tokens": self.max_tokens,
+                    "top_p": self.top_p,
+                    "frequency_penalty": self.frequency_penalty,
+                    "presence_penalty": self.presence_penalty,
+                    # "logprobs": self.logprobs,
+                    # "top_logprobs": self.top_logprobs if self.logprobs else None,
+                }
+                try:
+                    response = await client.chat.completions.create(**params)
+                except Exception as e:
+                    print(e, flush=True)
+                return response.model_dump()
+
+        LLM.__name__ = "LanguageModel"
+
+        return LLM
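Once the service is registered (see the registry diff below), Perplexity models should be reachable through the standard EDSL workflow. A hedged usage sketch; it assumes `PERPLEXITY_API_KEY` is set and that `Model` resolves these model names through the services registry, as it does for the other OpenAI-compatible providers:

```python
from edsl import Model, QuestionFreeText

# One of the three models advertised by PerplexityService.available()
m = Model("llama-3.1-sonar-large-128k-online")

q = QuestionFreeText(
    question_name="capital",
    question_text="What is the capital of France?",
)
results = q.by(m).run()  # runs against the Perplexity API
```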
edsl/inference_services/registry.py
CHANGED
@@ -12,6 +12,7 @@ from edsl.inference_services.AzureAI import AzureAIService
 from edsl.inference_services.OllamaService import OllamaService
 from edsl.inference_services.TestService import TestService
 from edsl.inference_services.TogetherAIService import TogetherAIService
+from edsl.inference_services.PerplexityService import PerplexityService

 try:
     from edsl.inference_services.MistralAIService import MistralAIService
@@ -31,6 +32,7 @@ services = [
     OllamaService,
     TestService,
     TogetherAIService,
+    PerplexityService,
 ]

 if mistral_available: