edsl 0.1.36__py3-none-any.whl → 0.1.36.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +0 -5
- edsl/__init__.py +0 -1
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +7 -11
- edsl/agents/InvigilatorBase.py +1 -5
- edsl/agents/PromptConstructor.py +18 -27
- edsl/conversation/Conversation.py +1 -1
- edsl/coop/PriceFetcher.py +18 -14
- edsl/coop/coop.py +8 -42
- edsl/exceptions/coop.py +0 -8
- edsl/inference_services/InferenceServiceABC.py +0 -28
- edsl/inference_services/InferenceServicesCollection.py +4 -10
- edsl/inference_services/models_available_cache.py +1 -25
- edsl/jobs/Jobs.py +167 -190
- edsl/jobs/interviews/Interview.py +14 -42
- edsl/jobs/interviews/InterviewExceptionCollection.py +0 -9
- edsl/jobs/interviews/InterviewExceptionEntry.py +6 -31
- edsl/jobs/runners/JobsRunnerAsyncio.py +13 -8
- edsl/jobs/tasks/TaskHistory.py +7 -23
- edsl/questions/QuestionFunctional.py +3 -7
- edsl/results/Dataset.py +0 -12
- edsl/results/Result.py +0 -2
- edsl/results/Results.py +1 -13
- edsl/scenarios/FileStore.py +5 -20
- edsl/scenarios/Scenario.py +1 -15
- edsl/scenarios/__init__.py +0 -2
- edsl/surveys/Survey.py +0 -3
- edsl/surveys/instructions/Instruction.py +3 -20
- {edsl-0.1.36.dist-info → edsl-0.1.36.dev2.dist-info}/METADATA +1 -1
- {edsl-0.1.36.dist-info → edsl-0.1.36.dev2.dist-info}/RECORD +32 -33
- edsl/data/RemoteCacheSync.py +0 -97
- {edsl-0.1.36.dist-info → edsl-0.1.36.dev2.dist-info}/LICENSE +0 -0
- {edsl-0.1.36.dist-info → edsl-0.1.36.dev2.dist-info}/WHEEL +0 -0
edsl/Base.py
CHANGED
@@ -7,8 +7,6 @@ import json
|
|
7
7
|
from typing import Any, Optional, Union
|
8
8
|
from uuid import UUID
|
9
9
|
|
10
|
-
# from edsl.utilities.MethodSuggesterMixin import MethodSuggesterMixin
|
11
|
-
|
12
10
|
|
13
11
|
class RichPrintingMixin:
|
14
12
|
"""Mixin for rich printing and persistence of objects."""
|
@@ -276,9 +274,6 @@ class Base(
|
|
276
274
|
"""This method should be implemented by subclasses."""
|
277
275
|
raise NotImplementedError("This method is not implemented yet.")
|
278
276
|
|
279
|
-
def to_json(self):
|
280
|
-
return json.dumps(self.to_dict())
|
281
|
-
|
282
277
|
@abstractmethod
|
283
278
|
def from_dict():
|
284
279
|
"""This method should be implemented by subclasses."""
|
edsl/__init__.py
CHANGED
@@ -27,7 +27,6 @@ from edsl.questions import QuestionTopK
|
|
27
27
|
|
28
28
|
from edsl.scenarios import Scenario
|
29
29
|
from edsl.scenarios import ScenarioList
|
30
|
-
from edsl.scenarios.FileStore import FileStore
|
31
30
|
|
32
31
|
# from edsl.utilities.interface import print_dict_with_rich
|
33
32
|
from edsl.surveys.Survey import Survey
|
edsl/__version__.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.1.36"
|
1
|
+
__version__ = "0.1.36.dev2"
|
edsl/agents/Agent.py
CHANGED
@@ -111,11 +111,7 @@ class Agent(Base):
|
|
111
111
|
self.name = name
|
112
112
|
self._traits = traits or dict()
|
113
113
|
self.codebook = codebook or dict()
|
114
|
-
|
115
|
-
self.instruction = self.default_instruction
|
116
|
-
else:
|
117
|
-
self.instruction = instruction
|
118
|
-
# self.instruction = instruction or self.default_instruction
|
114
|
+
self.instruction = instruction or self.default_instruction
|
119
115
|
self.dynamic_traits_function = dynamic_traits_function
|
120
116
|
|
121
117
|
# Deal with dynamic traits function
|
@@ -614,9 +610,9 @@ class Agent(Base):
|
|
614
610
|
if dynamic_traits_func:
|
615
611
|
func = inspect.getsource(dynamic_traits_func)
|
616
612
|
raw_data["dynamic_traits_function_source_code"] = func
|
617
|
-
raw_data[
|
618
|
-
|
619
|
-
|
613
|
+
raw_data["dynamic_traits_function_name"] = (
|
614
|
+
self.dynamic_traits_function_name
|
615
|
+
)
|
620
616
|
if hasattr(self, "answer_question_directly"):
|
621
617
|
raw_data.pop(
|
622
618
|
"answer_question_directly", None
|
@@ -632,9 +628,9 @@ class Agent(Base):
|
|
632
628
|
raw_data["answer_question_directly_source_code"] = inspect.getsource(
|
633
629
|
answer_question_directly_func
|
634
630
|
)
|
635
|
-
raw_data[
|
636
|
-
|
637
|
-
|
631
|
+
raw_data["answer_question_directly_function_name"] = (
|
632
|
+
self.answer_question_directly_function_name
|
633
|
+
)
|
638
634
|
|
639
635
|
return raw_data
|
640
636
|
|
edsl/agents/InvigilatorBase.py
CHANGED
@@ -115,11 +115,7 @@ class InvigilatorBase(ABC):
|
|
115
115
|
iteration = data["iteration"]
|
116
116
|
additional_prompt_data = data["additional_prompt_data"]
|
117
117
|
cache = Cache.from_dict(data["cache"])
|
118
|
-
|
119
|
-
if data["sidecar_model"] is None:
|
120
|
-
sidecar_model = None
|
121
|
-
else:
|
122
|
-
sidecar_model = LanguageModel.from_dict(data["sidecar_model"])
|
118
|
+
sidecar_model = LanguageModel.from_dict(data["sidecar_model"])
|
123
119
|
|
124
120
|
return cls(
|
125
121
|
agent=agent,
|
edsl/agents/PromptConstructor.py
CHANGED
@@ -1,11 +1,17 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
from typing import Dict, Any, Optional, Set
|
3
|
+
from collections import UserList
|
4
|
+
import pdb
|
3
5
|
|
4
6
|
from jinja2 import Environment, meta
|
5
7
|
|
6
8
|
from edsl.prompts.Prompt import Prompt
|
9
|
+
from edsl.data_transfer_models import ImageInfo
|
7
10
|
|
8
|
-
from edsl.
|
11
|
+
# from edsl.prompts.registry import get_classes as prompt_lookup
|
12
|
+
from edsl.exceptions import QuestionScenarioRenderError
|
13
|
+
|
14
|
+
from edsl.agents.prompt_helpers import PromptComponent, PromptList, PromptPlan
|
9
15
|
|
10
16
|
|
11
17
|
def get_jinja2_variables(template_str: str) -> Set[str]:
|
@@ -147,28 +153,14 @@ class PromptConstructor:
|
|
147
153
|
|
148
154
|
# might be getting it from the prior answers
|
149
155
|
if self.prior_answers_dict().get(question_option_key) is not None:
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
"Will be populated by prior answer",
|
159
|
-
"These are placeholder options",
|
160
|
-
]
|
161
|
-
question_data["question_options"] = placeholder_options
|
162
|
-
self.question.question_options = placeholder_options
|
163
|
-
|
164
|
-
# if isinstance(
|
165
|
-
# question_options := self.prior_answers_dict()
|
166
|
-
# .get(question_option_key)
|
167
|
-
# .answer,
|
168
|
-
# list,
|
169
|
-
# ):
|
170
|
-
# question_data["question_options"] = question_options
|
171
|
-
# self.question.question_options = question_options
|
156
|
+
if isinstance(
|
157
|
+
question_options := self.prior_answers_dict()
|
158
|
+
.get(question_option_key)
|
159
|
+
.answer,
|
160
|
+
list,
|
161
|
+
):
|
162
|
+
question_data["question_options"] = question_options
|
163
|
+
self.question.question_options = question_options
|
172
164
|
|
173
165
|
replacement_dict = (
|
174
166
|
{key: f"<see file {key}>" for key in self.scenario_file_keys}
|
@@ -220,10 +212,9 @@ class PromptConstructor:
|
|
220
212
|
)
|
221
213
|
|
222
214
|
if relevant_instructions != []:
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
preamble_text = Prompt(text="")
|
215
|
+
preamble_text = Prompt(
|
216
|
+
text="Before answer this question, you were given the following instructions: "
|
217
|
+
)
|
227
218
|
for instruction in relevant_instructions:
|
228
219
|
preamble_text += instruction.text
|
229
220
|
rendered_instructions = preamble_text + rendered_instructions
|
@@ -169,7 +169,6 @@ class Conversation:
|
|
169
169
|
agent=speaker,
|
170
170
|
just_answer=False,
|
171
171
|
cache=self.cache,
|
172
|
-
model=speaker.model,
|
173
172
|
)
|
174
173
|
return results[0]
|
175
174
|
|
@@ -180,6 +179,7 @@ class Conversation:
|
|
180
179
|
i = 0
|
181
180
|
while await self.continue_conversation():
|
182
181
|
speaker = self.next_speaker()
|
182
|
+
# breakpoint()
|
183
183
|
|
184
184
|
next_statement = AgentStatement(
|
185
185
|
statement=await self.get_next_statement(
|
edsl/coop/PriceFetcher.py
CHANGED
@@ -16,26 +16,30 @@ class PriceFetcher:
|
|
16
16
|
if self._cached_prices is not None:
|
17
17
|
return self._cached_prices
|
18
18
|
|
19
|
-
import os
|
20
19
|
import requests
|
21
|
-
|
20
|
+
import csv
|
21
|
+
from io import StringIO
|
22
|
+
|
23
|
+
sheet_id = "1SAO3Bhntefl0XQHJv27rMxpvu6uzKDWNXFHRa7jrUDs"
|
24
|
+
|
25
|
+
# Construct the URL to fetch the CSV
|
26
|
+
url = f"https://docs.google.com/spreadsheets/d/{sheet_id}/export?format=csv"
|
22
27
|
|
23
28
|
try:
|
24
|
-
# Fetch the
|
25
|
-
|
26
|
-
api_key = os.getenv("EXPECTED_PARROT_API_KEY")
|
27
|
-
headers = {}
|
28
|
-
if api_key:
|
29
|
-
headers["Authorization"] = f"Bearer {api_key}"
|
30
|
-
else:
|
31
|
-
headers["Authorization"] = f"Bearer None"
|
32
|
-
|
33
|
-
response = requests.get(url, headers=headers, timeout=20)
|
29
|
+
# Fetch the CSV data
|
30
|
+
response = requests.get(url)
|
34
31
|
response.raise_for_status() # Raise an exception for bad responses
|
35
32
|
|
36
|
-
# Parse the data
|
37
|
-
|
33
|
+
# Parse the CSV data
|
34
|
+
csv_data = StringIO(response.text)
|
35
|
+
reader = csv.reader(csv_data)
|
36
|
+
|
37
|
+
# Convert to list of dictionaries
|
38
|
+
headers = next(reader)
|
39
|
+
data = [dict(zip(headers, row)) for row in reader]
|
38
40
|
|
41
|
+
# self._cached_prices = data
|
42
|
+
# return data
|
39
43
|
price_lookup = {}
|
40
44
|
for entry in data:
|
41
45
|
service = entry.get("service", None)
|
edsl/coop/coop.py
CHANGED
@@ -6,7 +6,6 @@ from typing import Any, Optional, Union, Literal
|
|
6
6
|
from uuid import UUID
|
7
7
|
import edsl
|
8
8
|
from edsl import CONFIG, CacheEntry, Jobs, Survey
|
9
|
-
from edsl.exceptions.coop import CoopNoUUIDError, CoopServerResponseError
|
10
9
|
from edsl.coop.utils import (
|
11
10
|
EDSLObject,
|
12
11
|
ObjectRegistry,
|
@@ -100,7 +99,7 @@ class Coop:
|
|
100
99
|
if "Authorization" in message:
|
101
100
|
print(message)
|
102
101
|
message = "Please provide an Expected Parrot API key."
|
103
|
-
raise
|
102
|
+
raise Exception(message)
|
104
103
|
|
105
104
|
def _json_handle_none(self, value: Any) -> Any:
|
106
105
|
"""
|
@@ -117,7 +116,7 @@ class Coop:
|
|
117
116
|
Resolve the uuid from a uuid or a url.
|
118
117
|
"""
|
119
118
|
if not url and not uuid:
|
120
|
-
raise
|
119
|
+
raise Exception("No uuid or url provided for the object.")
|
121
120
|
if not uuid and url:
|
122
121
|
uuid = url.split("/")[-1]
|
123
122
|
return uuid
|
@@ -522,7 +521,7 @@ class Coop:
|
|
522
521
|
self._resolve_server_response(response)
|
523
522
|
response_json = response.json()
|
524
523
|
return {
|
525
|
-
"uuid": response_json.get("
|
524
|
+
"uuid": response_json.get("jobs_uuid"),
|
526
525
|
"description": response_json.get("description"),
|
527
526
|
"status": response_json.get("status"),
|
528
527
|
"iterations": response_json.get("iterations"),
|
@@ -530,41 +529,29 @@ class Coop:
|
|
530
529
|
"version": self._edsl_version,
|
531
530
|
}
|
532
531
|
|
533
|
-
def remote_inference_get(
|
534
|
-
self, job_uuid: Optional[str] = None, results_uuid: Optional[str] = None
|
535
|
-
) -> dict:
|
532
|
+
def remote_inference_get(self, job_uuid: str) -> dict:
|
536
533
|
"""
|
537
534
|
Get the details of a remote inference job.
|
538
|
-
You can pass either the job uuid or the results uuid as a parameter.
|
539
|
-
If you pass both, the job uuid will be prioritized.
|
540
535
|
|
541
536
|
:param job_uuid: The UUID of the EDSL job.
|
542
|
-
:param results_uuid: The UUID of the results associated with the EDSL job.
|
543
537
|
|
544
538
|
>>> coop.remote_inference_get("9f8484ee-b407-40e4-9652-4133a7236c9c")
|
545
539
|
{'jobs_uuid': '9f8484ee-b407-40e4-9652-4133a7236c9c', 'results_uuid': 'dd708234-31bf-4fe1-8747-6e232625e026', 'results_url': 'https://www.expectedparrot.com/content/dd708234-31bf-4fe1-8747-6e232625e026', 'status': 'completed', 'reason': None, 'price': 16, 'version': '0.1.29.dev4'}
|
546
540
|
"""
|
547
|
-
if job_uuid is None and results_uuid is None:
|
548
|
-
raise ValueError("Either job_uuid or results_uuid must be provided.")
|
549
|
-
elif job_uuid is not None:
|
550
|
-
params = {"job_uuid": job_uuid}
|
551
|
-
else:
|
552
|
-
params = {"results_uuid": results_uuid}
|
553
|
-
|
554
541
|
response = self._send_server_request(
|
555
542
|
uri="api/v0/remote-inference",
|
556
543
|
method="GET",
|
557
|
-
params=
|
544
|
+
params={"uuid": job_uuid},
|
558
545
|
)
|
559
546
|
self._resolve_server_response(response)
|
560
547
|
data = response.json()
|
561
548
|
return {
|
562
|
-
"
|
549
|
+
"jobs_uuid": data.get("jobs_uuid"),
|
563
550
|
"results_uuid": data.get("results_uuid"),
|
564
551
|
"results_url": f"{self.url}/content/{data.get('results_uuid')}",
|
565
552
|
"status": data.get("status"),
|
566
553
|
"reason": data.get("reason"),
|
567
|
-
"
|
554
|
+
"price": data.get("price"),
|
568
555
|
"version": data.get("version"),
|
569
556
|
}
|
570
557
|
|
@@ -597,10 +584,7 @@ class Coop:
|
|
597
584
|
)
|
598
585
|
self._resolve_server_response(response)
|
599
586
|
response_json = response.json()
|
600
|
-
return
|
601
|
-
"credits": response_json.get("cost_in_credits"),
|
602
|
-
"usd": response_json.get("cost_in_usd"),
|
603
|
-
}
|
587
|
+
return response_json.get("cost")
|
604
588
|
|
605
589
|
################
|
606
590
|
# Remote Errors
|
@@ -665,10 +649,6 @@ class Coop:
|
|
665
649
|
return response_json
|
666
650
|
|
667
651
|
def fetch_prices(self) -> dict:
|
668
|
-
"""
|
669
|
-
Fetch model prices from Coop. If the request fails, return an empty dict.
|
670
|
-
"""
|
671
|
-
|
672
652
|
from edsl.coop.PriceFetcher import PriceFetcher
|
673
653
|
|
674
654
|
from edsl.config import CONFIG
|
@@ -679,20 +659,6 @@ class Coop:
|
|
679
659
|
else:
|
680
660
|
return {}
|
681
661
|
|
682
|
-
def fetch_rate_limit_config_vars(self) -> dict:
|
683
|
-
"""
|
684
|
-
Fetch a dict of rate limit config vars from Coop.
|
685
|
-
|
686
|
-
The dict keys are RPM and TPM variables like EDSL_SERVICE_RPM_OPENAI.
|
687
|
-
"""
|
688
|
-
response = self._send_server_request(
|
689
|
-
uri="api/v0/config-vars",
|
690
|
-
method="GET",
|
691
|
-
)
|
692
|
-
self._resolve_server_response(response)
|
693
|
-
data = response.json()
|
694
|
-
return data
|
695
|
-
|
696
662
|
|
697
663
|
if __name__ == "__main__":
|
698
664
|
sheet_data = fetch_sheet_data()
|
edsl/exceptions/coop.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1
1
|
from abc import abstractmethod, ABC
|
2
2
|
import os
|
3
3
|
import re
|
4
|
-
from datetime import datetime, timedelta
|
5
4
|
from edsl.config import CONFIG
|
6
5
|
|
7
6
|
|
@@ -11,8 +10,6 @@ class InferenceServiceABC(ABC):
|
|
11
10
|
Anthropic: https://docs.anthropic.com/en/api/rate-limits
|
12
11
|
"""
|
13
12
|
|
14
|
-
_coop_config_vars = None
|
15
|
-
|
16
13
|
default_levels = {
|
17
14
|
"google": {"tpm": 2_000_000, "rpm": 15},
|
18
15
|
"openai": {"tpm": 2_000_000, "rpm": 10_000},
|
@@ -34,37 +31,12 @@ class InferenceServiceABC(ABC):
|
|
34
31
|
f"Class {cls.__name__} must have a 'model_exclude_list' attribute."
|
35
32
|
)
|
36
33
|
|
37
|
-
@classmethod
|
38
|
-
def _should_refresh_coop_config_vars(cls):
|
39
|
-
"""
|
40
|
-
Returns True if config vars have been fetched over 24 hours ago, and False otherwise.
|
41
|
-
"""
|
42
|
-
|
43
|
-
if cls._last_config_fetch is None:
|
44
|
-
return True
|
45
|
-
return (datetime.now() - cls._last_config_fetch) > timedelta(hours=24)
|
46
|
-
|
47
34
|
@classmethod
|
48
35
|
def _get_limt(cls, limit_type: str) -> int:
|
49
36
|
key = f"EDSL_SERVICE_{limit_type.upper()}_{cls._inference_service_.upper()}"
|
50
37
|
if key in os.environ:
|
51
38
|
return int(os.getenv(key))
|
52
39
|
|
53
|
-
if cls._coop_config_vars is None or cls._should_refresh_coop_config_vars():
|
54
|
-
try:
|
55
|
-
from edsl import Coop
|
56
|
-
|
57
|
-
c = Coop()
|
58
|
-
cls._coop_config_vars = c.fetch_rate_limit_config_vars()
|
59
|
-
cls._last_config_fetch = datetime.now()
|
60
|
-
if key in cls._coop_config_vars:
|
61
|
-
return cls._coop_config_vars[key]
|
62
|
-
except Exception:
|
63
|
-
cls._coop_config_vars = None
|
64
|
-
else:
|
65
|
-
if key in cls._coop_config_vars:
|
66
|
-
return cls._coop_config_vars[key]
|
67
|
-
|
68
40
|
if cls._inference_service_ in cls.default_levels:
|
69
41
|
return int(cls.default_levels[cls._inference_service_][limit_type])
|
70
42
|
|
@@ -56,19 +56,13 @@ class InferenceServicesCollection:
|
|
56
56
|
self.services.append(service)
|
57
57
|
|
58
58
|
def create_model_factory(self, model_name: str, service_name=None, index=None):
|
59
|
-
from edsl.inference_services.TestService import TestService
|
60
|
-
|
61
|
-
if model_name == "test":
|
62
|
-
return TestService.create_model(model_name)
|
63
|
-
|
64
|
-
if service_name:
|
65
|
-
for service in self.services:
|
66
|
-
if service_name == service._inference_service_:
|
67
|
-
return service.create_model(model_name)
|
68
|
-
|
69
59
|
for service in self.services:
|
70
60
|
if model_name in self._get_service_available(service):
|
71
61
|
if service_name is None or service_name == service._inference_service_:
|
72
62
|
return service.create_model(model_name)
|
73
63
|
|
64
|
+
# if model_name == "test":
|
65
|
+
# from edsl.language_models import LanguageModel
|
66
|
+
# return LanguageModel(test = True)
|
67
|
+
|
74
68
|
raise Exception(f"Model {model_name} not found in any of the services")
|
@@ -65,31 +65,7 @@ models_available = {
|
|
65
65
|
"meta-llama/Meta-Llama-3-70B-Instruct",
|
66
66
|
"openchat/openchat_3.5",
|
67
67
|
],
|
68
|
-
"google": [
|
69
|
-
"gemini-1.0-pro",
|
70
|
-
"gemini-1.0-pro-001",
|
71
|
-
"gemini-1.0-pro-latest",
|
72
|
-
"gemini-1.0-pro-vision-latest",
|
73
|
-
"gemini-1.5-flash",
|
74
|
-
"gemini-1.5-flash-001",
|
75
|
-
"gemini-1.5-flash-001-tuning",
|
76
|
-
"gemini-1.5-flash-002",
|
77
|
-
"gemini-1.5-flash-8b",
|
78
|
-
"gemini-1.5-flash-8b-001",
|
79
|
-
"gemini-1.5-flash-8b-exp-0827",
|
80
|
-
"gemini-1.5-flash-8b-exp-0924",
|
81
|
-
"gemini-1.5-flash-8b-latest",
|
82
|
-
"gemini-1.5-flash-exp-0827",
|
83
|
-
"gemini-1.5-flash-latest",
|
84
|
-
"gemini-1.5-pro",
|
85
|
-
"gemini-1.5-pro-001",
|
86
|
-
"gemini-1.5-pro-002",
|
87
|
-
"gemini-1.5-pro-exp-0801",
|
88
|
-
"gemini-1.5-pro-exp-0827",
|
89
|
-
"gemini-1.5-pro-latest",
|
90
|
-
"gemini-pro",
|
91
|
-
"gemini-pro-vision",
|
92
|
-
],
|
68
|
+
"google": ["gemini-pro"],
|
93
69
|
"bedrock": [
|
94
70
|
"amazon.titan-tg1-large",
|
95
71
|
"amazon.titan-text-lite-v1",
|