edsl 0.1.36.dev5__py3-none-any.whl → 0.1.37__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/__init__.py +1 -0
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +92 -41
- edsl/agents/AgentList.py +15 -2
- edsl/agents/InvigilatorBase.py +15 -25
- edsl/agents/PromptConstructor.py +149 -108
- edsl/agents/descriptors.py +17 -4
- edsl/conjure/AgentConstructionMixin.py +11 -3
- edsl/conversation/Conversation.py +66 -14
- edsl/conversation/chips.py +95 -0
- edsl/coop/coop.py +148 -39
- edsl/data/Cache.py +1 -1
- edsl/data/RemoteCacheSync.py +25 -12
- edsl/exceptions/BaseException.py +21 -0
- edsl/exceptions/__init__.py +7 -3
- edsl/exceptions/agents.py +17 -19
- edsl/exceptions/results.py +11 -8
- edsl/exceptions/scenarios.py +22 -0
- edsl/exceptions/surveys.py +13 -10
- edsl/inference_services/AwsBedrock.py +7 -2
- edsl/inference_services/InferenceServicesCollection.py +42 -13
- edsl/inference_services/models_available_cache.py +25 -1
- edsl/jobs/Jobs.py +306 -71
- edsl/jobs/interviews/Interview.py +24 -14
- edsl/jobs/interviews/InterviewExceptionCollection.py +1 -1
- edsl/jobs/interviews/InterviewExceptionEntry.py +17 -13
- edsl/jobs/interviews/ReportErrors.py +2 -2
- edsl/jobs/runners/JobsRunnerAsyncio.py +10 -9
- edsl/jobs/tasks/TaskHistory.py +1 -0
- edsl/language_models/KeyLookup.py +30 -0
- edsl/language_models/LanguageModel.py +47 -59
- edsl/language_models/__init__.py +1 -0
- edsl/prompts/Prompt.py +11 -12
- edsl/questions/QuestionBase.py +53 -13
- edsl/questions/QuestionBasePromptsMixin.py +1 -33
- edsl/questions/QuestionFreeText.py +1 -0
- edsl/questions/QuestionFunctional.py +2 -2
- edsl/questions/descriptors.py +23 -28
- edsl/results/DatasetExportMixin.py +25 -1
- edsl/results/Result.py +27 -10
- edsl/results/Results.py +34 -121
- edsl/results/ResultsDBMixin.py +1 -1
- edsl/results/Selector.py +18 -1
- edsl/scenarios/FileStore.py +20 -5
- edsl/scenarios/Scenario.py +52 -13
- edsl/scenarios/ScenarioHtmlMixin.py +7 -2
- edsl/scenarios/ScenarioList.py +12 -1
- edsl/scenarios/__init__.py +2 -0
- edsl/surveys/Rule.py +10 -4
- edsl/surveys/Survey.py +100 -77
- edsl/utilities/utilities.py +18 -0
- {edsl-0.1.36.dev5.dist-info → edsl-0.1.37.dist-info}/METADATA +1 -1
- {edsl-0.1.36.dev5.dist-info → edsl-0.1.37.dist-info}/RECORD +55 -51
- {edsl-0.1.36.dev5.dist-info → edsl-0.1.37.dist-info}/LICENSE +0 -0
- {edsl-0.1.36.dev5.dist-info → edsl-0.1.37.dist-info}/WHEEL +0 -0
edsl/coop/coop.py
CHANGED
@@ -28,9 +28,18 @@ class Coop:
|
|
28
28
|
- Provide a URL directly, or use the default one.
|
29
29
|
"""
|
30
30
|
self.api_key = api_key or os.getenv("EXPECTED_PARROT_API_KEY")
|
31
|
+
|
31
32
|
self.url = url or CONFIG.EXPECTED_PARROT_URL
|
32
33
|
if self.url.endswith("/"):
|
33
34
|
self.url = self.url[:-1]
|
35
|
+
if "chick.expectedparrot" in self.url:
|
36
|
+
self.api_url = "https://chickapi.expectedparrot.com"
|
37
|
+
elif "expectedparrot" in self.url:
|
38
|
+
self.api_url = "https://api.expectedparrot.com"
|
39
|
+
elif "localhost:1234" in self.url:
|
40
|
+
self.api_url = "http://localhost:8000"
|
41
|
+
else:
|
42
|
+
self.api_url = self.url
|
34
43
|
self._edsl_version = edsl.__version__
|
35
44
|
|
36
45
|
################
|
@@ -59,7 +68,7 @@ class Coop:
|
|
59
68
|
"""
|
60
69
|
Send a request to the server and return the response.
|
61
70
|
"""
|
62
|
-
url = f"{self.
|
71
|
+
url = f"{self.api_url}/{uri}"
|
63
72
|
method = method.upper()
|
64
73
|
if payload is None:
|
65
74
|
timeout = 20
|
@@ -90,18 +99,83 @@ class Coop:
|
|
90
99
|
|
91
100
|
return response
|
92
101
|
|
93
|
-
def _resolve_server_response(
|
102
|
+
def _resolve_server_response(
|
103
|
+
self, response: requests.Response, check_api_key: bool = True
|
104
|
+
) -> None:
|
94
105
|
"""
|
95
106
|
Check the response from the server and raise errors as appropriate.
|
96
107
|
"""
|
97
108
|
if response.status_code >= 400:
|
98
109
|
message = response.json().get("detail")
|
99
110
|
# print(response.text)
|
100
|
-
if "
|
111
|
+
if "The API key you provided is invalid" in message and check_api_key:
|
112
|
+
import secrets
|
113
|
+
from edsl.utilities.utilities import write_api_key_to_env
|
114
|
+
|
115
|
+
edsl_auth_token = secrets.token_urlsafe(16)
|
116
|
+
|
117
|
+
print("Your Expected Parrot API key is invalid.")
|
118
|
+
print(
|
119
|
+
"\nUse the link below to log in to Expected Parrot so we can automatically update your API key."
|
120
|
+
)
|
121
|
+
self._display_login_url(edsl_auth_token=edsl_auth_token)
|
122
|
+
api_key = self._poll_for_api_key(edsl_auth_token)
|
123
|
+
|
124
|
+
if api_key is None:
|
125
|
+
print("\nTimed out waiting for login. Please try again.")
|
126
|
+
return
|
127
|
+
|
128
|
+
write_api_key_to_env(api_key)
|
129
|
+
print("\n✨ API key retrieved and written to .env file.")
|
130
|
+
print("Rerun your code to try again with a valid API key.")
|
131
|
+
return
|
132
|
+
|
133
|
+
elif "Authorization" in message:
|
101
134
|
print(message)
|
102
135
|
message = "Please provide an Expected Parrot API key."
|
136
|
+
|
103
137
|
raise CoopServerResponseError(message)
|
104
138
|
|
139
|
+
def _poll_for_api_key(
|
140
|
+
self, edsl_auth_token: str, timeout: int = 120
|
141
|
+
) -> Union[str, None]:
|
142
|
+
"""
|
143
|
+
Allows the user to retrieve their Expected Parrot API key by logging in with an EDSL auth token.
|
144
|
+
|
145
|
+
:param edsl_auth_token: The EDSL auth token to use for login
|
146
|
+
:param timeout: Maximum time to wait for login, in seconds (default: 120)
|
147
|
+
"""
|
148
|
+
import time
|
149
|
+
from datetime import datetime
|
150
|
+
|
151
|
+
start_poll_time = time.time()
|
152
|
+
waiting_for_login = True
|
153
|
+
while waiting_for_login:
|
154
|
+
elapsed_time = time.time() - start_poll_time
|
155
|
+
if elapsed_time > timeout:
|
156
|
+
# Timed out waiting for the user to log in
|
157
|
+
print("\r" + " " * 80 + "\r", end="")
|
158
|
+
return None
|
159
|
+
|
160
|
+
api_key = self._get_api_key(edsl_auth_token)
|
161
|
+
if api_key is not None:
|
162
|
+
print("\r" + " " * 80 + "\r", end="")
|
163
|
+
return api_key
|
164
|
+
else:
|
165
|
+
duration = 5
|
166
|
+
time_checked = datetime.now().strftime("%Y-%m-%d %I:%M:%S %p")
|
167
|
+
frames = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
|
168
|
+
start_time = time.time()
|
169
|
+
i = 0
|
170
|
+
while time.time() - start_time < duration:
|
171
|
+
print(
|
172
|
+
f"\r{frames[i % len(frames)]} Waiting for login. Last checked: {time_checked}",
|
173
|
+
end="",
|
174
|
+
flush=True,
|
175
|
+
)
|
176
|
+
time.sleep(0.1)
|
177
|
+
i += 1
|
178
|
+
|
105
179
|
def _json_handle_none(self, value: Any) -> Any:
|
106
180
|
"""
|
107
181
|
Handle None values during JSON serialization.
|
@@ -134,7 +208,7 @@ class Coop:
|
|
134
208
|
response = self._send_server_request(
|
135
209
|
uri="api/v0/edsl-settings", method="GET", timeout=5
|
136
210
|
)
|
137
|
-
self._resolve_server_response(response)
|
211
|
+
self._resolve_server_response(response, check_api_key=False)
|
138
212
|
return response.json()
|
139
213
|
except Timeout:
|
140
214
|
return {}
|
@@ -489,6 +563,7 @@ class Coop:
|
|
489
563
|
description: Optional[str] = None,
|
490
564
|
status: RemoteJobStatus = "queued",
|
491
565
|
visibility: Optional[VisibilityType] = "unlisted",
|
566
|
+
initial_results_visibility: Optional[VisibilityType] = "unlisted",
|
492
567
|
iterations: Optional[int] = 1,
|
493
568
|
) -> dict:
|
494
569
|
"""
|
@@ -517,6 +592,7 @@ class Coop:
|
|
517
592
|
"iterations": iterations,
|
518
593
|
"visibility": visibility,
|
519
594
|
"version": self._edsl_version,
|
595
|
+
"initial_results_visibility": initial_results_visibility,
|
520
596
|
},
|
521
597
|
)
|
522
598
|
self._resolve_server_response(response)
|
@@ -568,7 +644,9 @@ class Coop:
|
|
568
644
|
"version": data.get("version"),
|
569
645
|
}
|
570
646
|
|
571
|
-
def remote_inference_cost(
|
647
|
+
def remote_inference_cost(
|
648
|
+
self, input: Union[Jobs, Survey], iterations: int = 1
|
649
|
+
) -> int:
|
572
650
|
"""
|
573
651
|
Get the cost of a remote inference job.
|
574
652
|
|
@@ -593,6 +671,7 @@ class Coop:
|
|
593
671
|
job.to_dict(),
|
594
672
|
default=self._json_handle_none,
|
595
673
|
),
|
674
|
+
"iterations": iterations,
|
596
675
|
},
|
597
676
|
)
|
598
677
|
self._resolve_server_response(response)
|
@@ -602,24 +681,6 @@ class Coop:
|
|
602
681
|
"usd": response_json.get("cost_in_usd"),
|
603
682
|
}
|
604
683
|
|
605
|
-
################
|
606
|
-
# Remote Errors
|
607
|
-
################
|
608
|
-
def error_create(self, error_data: str) -> dict:
|
609
|
-
"""
|
610
|
-
Send an error message to the server.
|
611
|
-
"""
|
612
|
-
response = self._send_server_request(
|
613
|
-
uri="api/v0/errors",
|
614
|
-
method="POST",
|
615
|
-
payload={
|
616
|
-
"json_string": json.dumps(error_data),
|
617
|
-
"version": self._edsl_version,
|
618
|
-
},
|
619
|
-
)
|
620
|
-
self._resolve_server_response(response)
|
621
|
-
return response.json()
|
622
|
-
|
623
684
|
################
|
624
685
|
# DUNDER METHODS
|
625
686
|
################
|
@@ -633,7 +694,7 @@ class Coop:
|
|
633
694
|
async def remote_async_execute_model_call(
|
634
695
|
self, model_dict: dict, user_prompt: str, system_prompt: str
|
635
696
|
) -> dict:
|
636
|
-
url = self.
|
697
|
+
url = self.api_url + "/inference/"
|
637
698
|
# print("Now using url: ", url)
|
638
699
|
data = {
|
639
700
|
"model_dict": model_dict,
|
@@ -654,7 +715,7 @@ class Coop:
|
|
654
715
|
] = "lime_survey",
|
655
716
|
email=None,
|
656
717
|
):
|
657
|
-
url = f"{self.
|
718
|
+
url = f"{self.api_url}/api/v0/export_to_{platform}"
|
658
719
|
if email:
|
659
720
|
data = {"json_string": json.dumps({"survey": survey, "email": email})}
|
660
721
|
else:
|
@@ -679,6 +740,17 @@ class Coop:
|
|
679
740
|
else:
|
680
741
|
return {}
|
681
742
|
|
743
|
+
def fetch_models(self) -> dict:
|
744
|
+
"""
|
745
|
+
Fetch a dict of available models from Coop.
|
746
|
+
|
747
|
+
Each key in the dict is an inference service, and each value is a list of models from that service.
|
748
|
+
"""
|
749
|
+
response = self._send_server_request(uri="api/v0/models", method="GET")
|
750
|
+
self._resolve_server_response(response)
|
751
|
+
data = response.json()
|
752
|
+
return data
|
753
|
+
|
682
754
|
def fetch_rate_limit_config_vars(self) -> dict:
|
683
755
|
"""
|
684
756
|
Fetch a dict of rate limit config vars from Coop.
|
@@ -693,14 +765,58 @@ class Coop:
|
|
693
765
|
data = response.json()
|
694
766
|
return data
|
695
767
|
|
768
|
+
def _display_login_url(self, edsl_auth_token: str):
|
769
|
+
"""
|
770
|
+
Uses rich.print to display a login URL.
|
771
|
+
|
772
|
+
- We need this function because URL detection with print() does not work alongside animations in VSCode.
|
773
|
+
"""
|
774
|
+
from rich import print as rich_print
|
775
|
+
|
776
|
+
url = f"{CONFIG.EXPECTED_PARROT_URL}/login?edsl_auth_token={edsl_auth_token}"
|
777
|
+
|
778
|
+
rich_print(f"[#38bdf8][link={url}]{url}[/link][/#38bdf8]")
|
779
|
+
|
780
|
+
def _get_api_key(self, edsl_auth_token: str):
|
781
|
+
"""
|
782
|
+
Given an EDSL auth token, find the corresponding user's API key.
|
783
|
+
"""
|
784
|
+
|
785
|
+
response = self._send_server_request(
|
786
|
+
uri="api/v0/get-api-key",
|
787
|
+
method="POST",
|
788
|
+
payload={
|
789
|
+
"edsl_auth_token": edsl_auth_token,
|
790
|
+
},
|
791
|
+
)
|
792
|
+
data = response.json()
|
793
|
+
api_key = data.get("api_key")
|
794
|
+
return api_key
|
795
|
+
|
796
|
+
def login(self):
|
797
|
+
"""
|
798
|
+
Starts the EDSL auth token login flow.
|
799
|
+
"""
|
800
|
+
import secrets
|
801
|
+
from dotenv import load_dotenv
|
802
|
+
from edsl.utilities.utilities import write_api_key_to_env
|
803
|
+
|
804
|
+
edsl_auth_token = secrets.token_urlsafe(16)
|
805
|
+
|
806
|
+
print(
|
807
|
+
"\nUse the link below to log in to Expected Parrot so we can automatically update your API key."
|
808
|
+
)
|
809
|
+
self._display_login_url(edsl_auth_token=edsl_auth_token)
|
810
|
+
api_key = self._poll_for_api_key(edsl_auth_token)
|
811
|
+
|
812
|
+
if api_key is None:
|
813
|
+
raise Exception("Timed out waiting for login. Please try again.")
|
696
814
|
|
697
|
-
|
698
|
-
|
699
|
-
|
700
|
-
|
701
|
-
|
702
|
-
else:
|
703
|
-
print("Failed to fetch sheet data.")
|
815
|
+
write_api_key_to_env(api_key)
|
816
|
+
print("\n✨ API key retrieved and written to .env file.")
|
817
|
+
|
818
|
+
# Add API key to environment
|
819
|
+
load_dotenv()
|
704
820
|
|
705
821
|
|
706
822
|
def main():
|
@@ -840,10 +956,3 @@ def main():
|
|
840
956
|
job_coop_object = coop.remote_inference_create(job)
|
841
957
|
job_coop_results = coop.remote_inference_get(job_coop_object.get("uuid"))
|
842
958
|
coop.get(uuid=job_coop_results.get("results_uuid"))
|
843
|
-
|
844
|
-
##############
|
845
|
-
# E. Errors
|
846
|
-
##############
|
847
|
-
coop.error_create({"something": "This is an error message"})
|
848
|
-
coop.api_key = None
|
849
|
-
coop.error_create({"something": "This is an error message"})
|
edsl/data/Cache.py
CHANGED
@@ -194,7 +194,7 @@ class Cache(Base):
|
|
194
194
|
>>> c = Cache()
|
195
195
|
>>> len(c)
|
196
196
|
0
|
197
|
-
>>> results = Question.example("free_text").by(m).run(cache = c)
|
197
|
+
>>> results = Question.example("free_text").by(m).run(cache = c, disable_remote_cache = True, disable_remote_inference = True)
|
198
198
|
>>> len(c)
|
199
199
|
1
|
200
200
|
"""
|
edsl/data/RemoteCacheSync.py
CHANGED
@@ -1,5 +1,7 @@
|
|
1
1
|
class RemoteCacheSync:
|
2
|
-
def __init__(
|
2
|
+
def __init__(
|
3
|
+
self, coop, cache, output_func, remote_cache=True, remote_cache_description=""
|
4
|
+
):
|
3
5
|
self.coop = coop
|
4
6
|
self.cache = cache
|
5
7
|
self._output = output_func
|
@@ -21,34 +23,42 @@ class RemoteCacheSync:
|
|
21
23
|
|
22
24
|
def _sync_from_remote(self):
|
23
25
|
cache_difference = self.coop.remote_cache_get_diff(self.cache.keys())
|
24
|
-
client_missing_cacheentries = cache_difference.get(
|
26
|
+
client_missing_cacheentries = cache_difference.get(
|
27
|
+
"client_missing_cacheentries", []
|
28
|
+
)
|
25
29
|
missing_entry_count = len(client_missing_cacheentries)
|
26
|
-
|
30
|
+
|
27
31
|
if missing_entry_count > 0:
|
28
32
|
self._output(
|
29
33
|
f"Updating local cache with {missing_entry_count:,} new "
|
30
34
|
f"{'entry' if missing_entry_count == 1 else 'entries'} from remote..."
|
31
35
|
)
|
32
|
-
self.cache.add_from_dict(
|
36
|
+
self.cache.add_from_dict(
|
37
|
+
{entry.key: entry for entry in client_missing_cacheentries}
|
38
|
+
)
|
33
39
|
self._output("Local cache updated!")
|
34
40
|
else:
|
35
41
|
self._output("No new entries to add to local cache.")
|
36
42
|
|
37
43
|
def _sync_to_remote(self):
|
38
44
|
cache_difference = self.coop.remote_cache_get_diff(self.cache.keys())
|
39
|
-
server_missing_cacheentry_keys = cache_difference.get(
|
45
|
+
server_missing_cacheentry_keys = cache_difference.get(
|
46
|
+
"server_missing_cacheentry_keys", []
|
47
|
+
)
|
40
48
|
server_missing_cacheentries = [
|
41
49
|
entry
|
42
50
|
for key in server_missing_cacheentry_keys
|
43
51
|
if (entry := self.cache.data.get(key)) is not None
|
44
52
|
]
|
45
|
-
|
53
|
+
|
46
54
|
new_cache_entries = [
|
47
|
-
entry
|
55
|
+
entry
|
56
|
+
for entry in self.cache.values()
|
57
|
+
if entry.key not in self.old_entry_keys
|
48
58
|
]
|
49
59
|
server_missing_cacheentries.extend(new_cache_entries)
|
50
60
|
new_entry_count = len(server_missing_cacheentries)
|
51
|
-
|
61
|
+
|
52
62
|
if new_entry_count > 0:
|
53
63
|
self._output(
|
54
64
|
f"Updating remote cache with {new_entry_count:,} new "
|
@@ -62,8 +72,11 @@ class RemoteCacheSync:
|
|
62
72
|
self._output("Remote cache updated!")
|
63
73
|
else:
|
64
74
|
self._output("No new entries to add to remote cache.")
|
65
|
-
|
66
|
-
self._output(
|
75
|
+
|
76
|
+
self._output(
|
77
|
+
f"There are {len(self.cache.keys()):,} entries in the local cache."
|
78
|
+
)
|
79
|
+
|
67
80
|
|
68
81
|
# # Usage example
|
69
82
|
# def run_job(self, n, progress_bar, cache, stop_on_exception, sidecar_model, print_exceptions, raise_validation_errors, use_remote_cache=True):
|
@@ -79,6 +92,6 @@ class RemoteCacheSync:
|
|
79
92
|
# raise_validation_errors=raise_validation_errors,
|
80
93
|
# )
|
81
94
|
# self._output("Job completed!")
|
82
|
-
|
95
|
+
|
83
96
|
# results.cache = cache.new_entries_cache()
|
84
|
-
# return results
|
97
|
+
# return results
|
@@ -0,0 +1,21 @@
|
|
1
|
+
class BaseException(Exception):
|
2
|
+
relevant_doc = "https://docs.expectedparrot.com/"
|
3
|
+
|
4
|
+
def __init__(self, message, *, show_docs=True):
|
5
|
+
# Format main error message
|
6
|
+
formatted_message = [message.strip()]
|
7
|
+
|
8
|
+
# Add documentation links if requested
|
9
|
+
if show_docs:
|
10
|
+
if hasattr(self, "relevant_doc"):
|
11
|
+
formatted_message.append(
|
12
|
+
f"\nFor more information, see:\n{self.relevant_doc}"
|
13
|
+
)
|
14
|
+
if hasattr(self, "relevant_notebook"):
|
15
|
+
formatted_message.append(
|
16
|
+
f"\nFor a usage example, see:\n{self.relevant_notebook}"
|
17
|
+
)
|
18
|
+
|
19
|
+
# Join with double newlines for clear separation
|
20
|
+
final_message = "\n\n".join(formatted_message)
|
21
|
+
super().__init__(final_message)
|
edsl/exceptions/__init__.py
CHANGED
@@ -1,8 +1,8 @@
|
|
1
1
|
from .agents import (
|
2
|
-
AgentAttributeLookupCallbackError,
|
2
|
+
# AgentAttributeLookupCallbackError,
|
3
3
|
AgentCombinationError,
|
4
|
-
AgentLacksLLMError,
|
5
|
-
AgentRespondedWithBadJSONError,
|
4
|
+
# AgentLacksLLMError,
|
5
|
+
# AgentRespondedWithBadJSONError,
|
6
6
|
)
|
7
7
|
from .configuration import (
|
8
8
|
InvalidEnvironmentVariableError,
|
@@ -14,6 +14,10 @@ from .data import (
|
|
14
14
|
DatabaseIntegrityError,
|
15
15
|
)
|
16
16
|
|
17
|
+
from .scenarios import (
|
18
|
+
ScenarioError,
|
19
|
+
)
|
20
|
+
|
17
21
|
from .general import MissingAPIKeyError
|
18
22
|
|
19
23
|
from .jobs import JobsRunError, InterviewErrorPriorTaskCanceled, InterviewTimeoutError
|
edsl/exceptions/agents.py
CHANGED
@@ -1,37 +1,35 @@
|
|
1
|
-
|
2
|
-
pass
|
1
|
+
from edsl.exceptions.BaseException import BaseException
|
3
2
|
|
4
3
|
|
5
|
-
class
|
6
|
-
|
4
|
+
class AgentErrors(BaseException):
|
5
|
+
relevant_doc = "https://docs.expectedparrot.com/en/latest/agents.html"
|
7
6
|
|
8
7
|
|
9
|
-
class
|
10
|
-
|
8
|
+
class AgentDynamicTraitsFunctionError(AgentErrors):
|
9
|
+
relevant_doc = (
|
10
|
+
"https://docs.expectedparrot.com/en/latest/agents.html#dynamic-traits-function"
|
11
|
+
)
|
12
|
+
relevant_notebook = "https://docs.expectedparrot.com/en/latest/notebooks/example_agent_dynamic_traits.html"
|
11
13
|
|
12
14
|
|
13
|
-
class
|
14
|
-
|
15
|
+
class AgentDirectAnswerFunctionError(AgentErrors):
|
16
|
+
relevant_doc = "https://docs.expectedparrot.com/en/latest/agents.html#agent-direct-answering-methods"
|
15
17
|
|
16
18
|
|
17
19
|
class AgentCombinationError(AgentErrors):
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
class AgentLacksLLMError(AgentErrors):
|
22
|
-
pass
|
23
|
-
|
24
|
-
|
25
|
-
class AgentRespondedWithBadJSONError(AgentErrors):
|
26
|
-
pass
|
20
|
+
relevant_doc = (
|
21
|
+
"https://docs.expectedparrot.com/en/latest/agents.html#combining-agents"
|
22
|
+
)
|
27
23
|
|
28
24
|
|
29
25
|
class AgentNameError(AgentErrors):
|
30
|
-
|
26
|
+
relevant_doc = "https://docs.expectedparrot.com/en/latest/agents.html#agent-names"
|
31
27
|
|
32
28
|
|
33
29
|
class AgentTraitKeyError(AgentErrors):
|
34
|
-
|
30
|
+
relevant_doc = (
|
31
|
+
"https://docs.expectedparrot.com/en/latest/agents.html#constructing-an-agent"
|
32
|
+
)
|
35
33
|
|
36
34
|
|
37
35
|
class FailedTaskException(Exception):
|
edsl/exceptions/results.py
CHANGED
@@ -1,26 +1,29 @@
|
|
1
|
-
|
2
|
-
|
1
|
+
from edsl.exceptions.BaseException import BaseException
|
2
|
+
|
3
|
+
|
4
|
+
class ResultsError(BaseException):
|
5
|
+
relevant_docs = "https://docs.expectedparrot.com/en/latest/results.html"
|
3
6
|
|
4
7
|
|
5
|
-
class ResultsDeserializationError(
|
8
|
+
class ResultsDeserializationError(ResultsError):
|
6
9
|
pass
|
7
10
|
|
8
11
|
|
9
|
-
class ResultsBadMutationstringError(
|
12
|
+
class ResultsBadMutationstringError(ResultsError):
|
10
13
|
pass
|
11
14
|
|
12
15
|
|
13
|
-
class ResultsColumnNotFoundError(
|
16
|
+
class ResultsColumnNotFoundError(ResultsError):
|
14
17
|
pass
|
15
18
|
|
16
19
|
|
17
|
-
class ResultsInvalidNameError(
|
20
|
+
class ResultsInvalidNameError(ResultsError):
|
18
21
|
pass
|
19
22
|
|
20
23
|
|
21
|
-
class ResultsMutateError(
|
24
|
+
class ResultsMutateError(ResultsError):
|
22
25
|
pass
|
23
26
|
|
24
27
|
|
25
|
-
class ResultsFilterError(
|
28
|
+
class ResultsFilterError(ResultsError):
|
26
29
|
pass
|
@@ -0,0 +1,22 @@
|
|
1
|
+
import re
|
2
|
+
import textwrap
|
3
|
+
|
4
|
+
|
5
|
+
class ScenarioError(Exception):
|
6
|
+
documentation = "https://docs.expectedparrot.com/en/latest/scenarios.html#module-edsl.scenarios.Scenario"
|
7
|
+
|
8
|
+
def __init__(self, message: str):
|
9
|
+
self.message = message + "\n" + "Documentation: " + self.documentation
|
10
|
+
super().__init__(self.message)
|
11
|
+
|
12
|
+
def __str__(self):
|
13
|
+
return self.make_urls_clickable(self.message)
|
14
|
+
|
15
|
+
@staticmethod
|
16
|
+
def make_urls_clickable(text):
|
17
|
+
url_pattern = r"https?://[^\s]+"
|
18
|
+
urls = re.findall(url_pattern, text)
|
19
|
+
for url in urls:
|
20
|
+
clickable_url = f"\033]8;;{url}\007{url}\033]8;;\007"
|
21
|
+
text = text.replace(url, clickable_url)
|
22
|
+
return text
|
edsl/exceptions/surveys.py
CHANGED
@@ -1,34 +1,37 @@
|
|
1
|
-
|
2
|
-
|
1
|
+
from edsl.exceptions.BaseException import BaseException
|
2
|
+
|
3
|
+
|
4
|
+
class SurveyError(BaseException):
|
5
|
+
relevant_doc = "https://docs.expectedparrot.com/en/latest/surveys.html"
|
3
6
|
|
4
7
|
|
5
|
-
class SurveyCreationError(
|
8
|
+
class SurveyCreationError(SurveyError):
|
6
9
|
pass
|
7
10
|
|
8
11
|
|
9
|
-
class SurveyHasNoRulesError(
|
12
|
+
class SurveyHasNoRulesError(SurveyError):
|
10
13
|
pass
|
11
14
|
|
12
15
|
|
13
|
-
class SurveyRuleSendsYouBackwardsError(
|
16
|
+
class SurveyRuleSendsYouBackwardsError(SurveyError):
|
14
17
|
pass
|
15
18
|
|
16
19
|
|
17
|
-
class SurveyRuleSkipLogicSyntaxError(
|
20
|
+
class SurveyRuleSkipLogicSyntaxError(SurveyError):
|
18
21
|
pass
|
19
22
|
|
20
23
|
|
21
|
-
class SurveyRuleReferenceInRuleToUnknownQuestionError(
|
24
|
+
class SurveyRuleReferenceInRuleToUnknownQuestionError(SurveyError):
|
22
25
|
pass
|
23
26
|
|
24
27
|
|
25
|
-
class SurveyRuleRefersToFutureStateError(
|
28
|
+
class SurveyRuleRefersToFutureStateError(SurveyError):
|
26
29
|
pass
|
27
30
|
|
28
31
|
|
29
|
-
class SurveyRuleCollectionHasNoRulesAtNodeError(
|
32
|
+
class SurveyRuleCollectionHasNoRulesAtNodeError(SurveyError):
|
30
33
|
pass
|
31
34
|
|
32
35
|
|
33
|
-
class SurveyRuleCannotEvaluateError(
|
36
|
+
class SurveyRuleCannotEvaluateError(SurveyError):
|
34
37
|
pass
|
@@ -28,12 +28,16 @@ class AwsBedrockService(InferenceServiceABC):
|
|
28
28
|
"ai21.j2-ultra",
|
29
29
|
"ai21.j2-ultra-v1",
|
30
30
|
]
|
31
|
+
_models_list_cache: List[str] = []
|
31
32
|
|
32
33
|
@classmethod
|
33
34
|
def available(cls):
|
34
35
|
"""Fetch available models from AWS Bedrock."""
|
36
|
+
|
37
|
+
region = os.getenv("AWS_REGION", "us-east-1")
|
38
|
+
|
35
39
|
if not cls._models_list_cache:
|
36
|
-
client = boto3.client("bedrock", region_name=
|
40
|
+
client = boto3.client("bedrock", region_name=region)
|
37
41
|
all_models_ids = [
|
38
42
|
x["modelId"] for x in client.list_foundation_models()["modelSummaries"]
|
39
43
|
]
|
@@ -80,7 +84,8 @@ class AwsBedrockService(InferenceServiceABC):
|
|
80
84
|
self.api_token
|
81
85
|
) # call to check the if env variables are set.
|
82
86
|
|
83
|
-
|
87
|
+
region = os.getenv("AWS_REGION", "us-east-1")
|
88
|
+
client = boto3.client("bedrock-runtime", region_name=region)
|
84
89
|
|
85
90
|
conversation = [
|
86
91
|
{
|