edsl 0.1.39.dev2__py3-none-any.whl → 0.1.39.dev4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +28 -0
- edsl/__init__.py +1 -1
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +8 -16
- edsl/agents/Invigilator.py +13 -14
- edsl/agents/InvigilatorBase.py +4 -1
- edsl/agents/PromptConstructor.py +42 -22
- edsl/agents/QuestionInstructionPromptBuilder.py +1 -1
- edsl/auto/AutoStudy.py +18 -5
- edsl/auto/StageBase.py +53 -40
- edsl/auto/StageQuestions.py +2 -1
- edsl/auto/utilities.py +0 -6
- edsl/coop/coop.py +21 -5
- edsl/data/Cache.py +29 -18
- edsl/data/CacheHandler.py +0 -2
- edsl/data/RemoteCacheSync.py +154 -46
- edsl/data/hack.py +10 -0
- edsl/enums.py +7 -0
- edsl/inference_services/AnthropicService.py +38 -16
- edsl/inference_services/AvailableModelFetcher.py +7 -1
- edsl/inference_services/GoogleService.py +5 -1
- edsl/inference_services/InferenceServicesCollection.py +18 -2
- edsl/inference_services/OpenAIService.py +46 -31
- edsl/inference_services/TestService.py +1 -3
- edsl/inference_services/TogetherAIService.py +5 -3
- edsl/inference_services/data_structures.py +74 -2
- edsl/jobs/AnswerQuestionFunctionConstructor.py +148 -113
- edsl/jobs/FetchInvigilator.py +10 -3
- edsl/jobs/InterviewsConstructor.py +6 -4
- edsl/jobs/Jobs.py +299 -233
- edsl/jobs/JobsChecks.py +2 -2
- edsl/jobs/JobsPrompts.py +1 -1
- edsl/jobs/JobsRemoteInferenceHandler.py +160 -136
- edsl/jobs/async_interview_runner.py +138 -0
- edsl/jobs/check_survey_scenario_compatibility.py +85 -0
- edsl/jobs/data_structures.py +120 -0
- edsl/jobs/interviews/Interview.py +80 -42
- edsl/jobs/results_exceptions_handler.py +98 -0
- edsl/jobs/runners/JobsRunnerAsyncio.py +87 -357
- edsl/jobs/runners/JobsRunnerStatus.py +131 -164
- edsl/jobs/tasks/TaskHistory.py +24 -3
- edsl/language_models/LanguageModel.py +59 -4
- edsl/language_models/ModelList.py +19 -8
- edsl/language_models/__init__.py +1 -1
- edsl/language_models/model.py +256 -0
- edsl/language_models/repair.py +1 -1
- edsl/questions/QuestionBase.py +35 -26
- edsl/questions/QuestionBasePromptsMixin.py +1 -1
- edsl/questions/QuestionBudget.py +1 -1
- edsl/questions/QuestionCheckBox.py +2 -2
- edsl/questions/QuestionExtract.py +5 -7
- edsl/questions/QuestionFreeText.py +1 -1
- edsl/questions/QuestionList.py +9 -15
- edsl/questions/QuestionMatrix.py +1 -1
- edsl/questions/QuestionMultipleChoice.py +1 -1
- edsl/questions/QuestionNumerical.py +1 -1
- edsl/questions/QuestionRank.py +1 -1
- edsl/questions/SimpleAskMixin.py +1 -1
- edsl/questions/__init__.py +1 -1
- edsl/questions/data_structures.py +20 -0
- edsl/questions/{QuestionBaseGenMixin.py → question_base_gen_mixin.py} +52 -49
- edsl/questions/{ResponseValidatorABC.py → response_validator_abc.py} +6 -18
- edsl/questions/{ResponseValidatorFactory.py → response_validator_factory.py} +7 -1
- edsl/results/DatasetExportMixin.py +60 -119
- edsl/results/Result.py +109 -3
- edsl/results/Results.py +50 -39
- edsl/results/file_exports.py +252 -0
- edsl/scenarios/ScenarioList.py +35 -7
- edsl/surveys/Survey.py +71 -20
- edsl/test_h +1 -0
- edsl/utilities/gcp_bucket/example.py +50 -0
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/METADATA +2 -2
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/RECORD +85 -76
- edsl/language_models/registry.py +0 -180
- /edsl/agents/{QuestionOptionProcessor.py → question_option_processor.py} +0 -0
- /edsl/questions/{AnswerValidatorMixin.py → answer_validator_mixin.py} +0 -0
- /edsl/questions/{LoopProcessor.py → loop_processor.py} +0 -0
- /edsl/questions/{RegisterQuestionsMeta.py → register_questions_meta.py} +0 -0
- /edsl/results/{ResultsFetchMixin.py → results_fetch_mixin.py} +0 -0
- /edsl/results/{Selector.py → results_selector.py} +0 -0
- /edsl/results/{ResultsToolsMixin.py → results_tools_mixin.py} +0 -0
- /edsl/scenarios/{DirectoryScanner.py → directory_scanner.py} +0 -0
- /edsl/scenarios/{ScenarioJoin.py → scenario_join.py} +0 -0
- /edsl/scenarios/{ScenarioSelector.py → scenario_selector.py} +0 -0
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/LICENSE +0 -0
- {edsl-0.1.39.dev2.dist-info → edsl-0.1.39.dev4.dist-info}/WHEEL +0 -0
edsl/Base.py
CHANGED
@@ -10,6 +10,10 @@ from uuid import UUID
|
|
10
10
|
class PersistenceMixin:
|
11
11
|
"""Mixin for saving and loading objects to and from files."""
|
12
12
|
|
13
|
+
def duplicate(self, add_edsl_version=False):
|
14
|
+
"""Return a duplicate of the object."""
|
15
|
+
return self.from_dict(self.to_dict(add_edsl_version=False))
|
16
|
+
|
13
17
|
def push(
|
14
18
|
self,
|
15
19
|
description: Optional[str] = None,
|
@@ -23,6 +27,30 @@ class PersistenceMixin:
|
|
23
27
|
c = Coop(url=expected_parrot_url)
|
24
28
|
return c.create(self, description, alias, visibility)
|
25
29
|
|
30
|
+
def to_yaml(self, add_edsl_version=False, filename: str = None) -> Union[str, None]:
|
31
|
+
import yaml
|
32
|
+
|
33
|
+
output = yaml.dump(self.to_dict(add_edsl_version=add_edsl_version))
|
34
|
+
if not filename:
|
35
|
+
return output
|
36
|
+
|
37
|
+
with open(filename, "w") as f:
|
38
|
+
f.write(output)
|
39
|
+
|
40
|
+
@classmethod
|
41
|
+
def from_yaml(cls, yaml_str: Optional[str] = None, filename: Optional[str] = None):
|
42
|
+
if yaml_str is None and filename is not None:
|
43
|
+
with open(filename, "r") as f:
|
44
|
+
yaml_str = f.read()
|
45
|
+
return cls.from_yaml(yaml_str=yaml_str)
|
46
|
+
elif yaml_str and filename is None:
|
47
|
+
import yaml
|
48
|
+
|
49
|
+
d = yaml.load(yaml_str, Loader=yaml.FullLoader)
|
50
|
+
return cls.from_dict(d)
|
51
|
+
else:
|
52
|
+
raise ValueError("Either yaml_str or filename must be provided.")
|
53
|
+
|
26
54
|
def create_download_link(self):
|
27
55
|
from tempfile import NamedTemporaryFile
|
28
56
|
from edsl.scenarios.FileStore import FileStore
|
edsl/__init__.py
CHANGED
@@ -34,7 +34,7 @@ from edsl.scenarios.FileStore import FileStore
|
|
34
34
|
|
35
35
|
# from edsl.utilities.interface import print_dict_with_rich
|
36
36
|
from edsl.surveys.Survey import Survey
|
37
|
-
from edsl.language_models.
|
37
|
+
from edsl.language_models.model import Model
|
38
38
|
from edsl.language_models.ModelList import ModelList
|
39
39
|
|
40
40
|
from edsl.results.Results import Results
|
edsl/__version__.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.1.39.
|
1
|
+
__version__ = "0.1.39.dev4"
|
edsl/agents/Agent.py
CHANGED
@@ -549,12 +549,11 @@ class Agent(Base):
|
|
549
549
|
survey: Optional["Survey"] = None,
|
550
550
|
scenario: Optional["Scenario"] = None,
|
551
551
|
model: Optional["LanguageModel"] = None,
|
552
|
-
# debug: bool = False,
|
553
552
|
memory_plan: Optional["MemoryPlan"] = None,
|
554
553
|
current_answers: Optional[dict] = None,
|
555
554
|
iteration: int = 1,
|
556
|
-
# sidecar_model=None,
|
557
555
|
raise_validation_errors: bool = True,
|
556
|
+
key_lookup: Optional["KeyLookup"] = None,
|
558
557
|
) -> "InvigilatorBase":
|
559
558
|
"""Create an Invigilator.
|
560
559
|
|
@@ -569,7 +568,7 @@ class Agent(Base):
|
|
569
568
|
An invigator is an object that is responsible for administering a question to an agent and
|
570
569
|
recording the responses.
|
571
570
|
"""
|
572
|
-
from edsl.language_models.
|
571
|
+
from edsl.language_models.model import Model
|
573
572
|
|
574
573
|
from edsl.scenarios.Scenario import Scenario
|
575
574
|
|
@@ -582,13 +581,12 @@ class Agent(Base):
|
|
582
581
|
scenario=scenario,
|
583
582
|
survey=survey,
|
584
583
|
model=model,
|
585
|
-
# debug=debug,
|
586
584
|
memory_plan=memory_plan,
|
587
585
|
current_answers=current_answers,
|
588
586
|
iteration=iteration,
|
589
587
|
cache=cache,
|
590
|
-
# sidecar_model=sidecar_model,
|
591
588
|
raise_validation_errors=raise_validation_errors,
|
589
|
+
key_lookup=key_lookup,
|
592
590
|
)
|
593
591
|
if hasattr(self, "validate_response"):
|
594
592
|
invigilator.validate_response = self.validate_response
|
@@ -608,6 +606,7 @@ class Agent(Base):
|
|
608
606
|
memory_plan: Optional[MemoryPlan] = None,
|
609
607
|
current_answers: Optional[dict] = None,
|
610
608
|
iteration: int = 0,
|
609
|
+
key_lookup: Optional["KeyLookup"] = None,
|
611
610
|
) -> AgentResponseDict:
|
612
611
|
"""
|
613
612
|
Answer a posed question.
|
@@ -637,10 +636,10 @@ class Agent(Base):
|
|
637
636
|
scenario=scenario,
|
638
637
|
survey=survey,
|
639
638
|
model=model,
|
640
|
-
# debug=debug,
|
641
639
|
memory_plan=memory_plan,
|
642
640
|
current_answers=current_answers,
|
643
641
|
iteration=iteration,
|
642
|
+
key_lookup=key_lookup,
|
644
643
|
)
|
645
644
|
response: AgentResponseDict = await invigilator.async_answer_question()
|
646
645
|
return response
|
@@ -673,15 +672,14 @@ class Agent(Base):
|
|
673
672
|
scenario: Optional[Scenario] = None,
|
674
673
|
model: Optional[LanguageModel] = None,
|
675
674
|
survey: Optional[Survey] = None,
|
676
|
-
# debug: bool = False,
|
677
675
|
memory_plan: Optional[MemoryPlan] = None,
|
678
676
|
current_answers: Optional[dict] = None,
|
679
677
|
iteration: int = 0,
|
680
|
-
# sidecar_model=None,
|
681
678
|
raise_validation_errors: bool = True,
|
679
|
+
key_lookup: Optional["KeyLookup"] = None,
|
682
680
|
) -> "InvigilatorBase":
|
683
681
|
"""Create an Invigilator."""
|
684
|
-
from edsl.language_models.
|
682
|
+
from edsl.language_models.model import Model
|
685
683
|
from edsl.scenarios.Scenario import Scenario
|
686
684
|
|
687
685
|
model = model or Model()
|
@@ -694,12 +692,6 @@ class Agent(Base):
|
|
694
692
|
|
695
693
|
invigilator_class = self._get_invigilator_class(question)
|
696
694
|
|
697
|
-
# if sidecar_model is not None:
|
698
|
-
# # this is the case when a 'simple' model is being used
|
699
|
-
# # from edsl.agents.Invigilator import InvigilatorSidecar
|
700
|
-
# # invigilator_class = InvigilatorSidecar
|
701
|
-
# raise DeprecationWarning("Sidecar models are deprecated.")
|
702
|
-
|
703
695
|
invigilator = invigilator_class(
|
704
696
|
self,
|
705
697
|
question=question,
|
@@ -710,8 +702,8 @@ class Agent(Base):
|
|
710
702
|
current_answers=current_answers,
|
711
703
|
iteration=iteration,
|
712
704
|
cache=cache,
|
713
|
-
# sidecar_model=sidecar_model,
|
714
705
|
raise_validation_errors=raise_validation_errors,
|
706
|
+
key_lookup=key_lookup,
|
715
707
|
)
|
716
708
|
return invigilator
|
717
709
|
|
edsl/agents/Invigilator.py
CHANGED
@@ -13,11 +13,7 @@ if TYPE_CHECKING:
|
|
13
13
|
from edsl.surveys.Survey import Survey
|
14
14
|
|
15
15
|
|
16
|
-
|
17
|
-
def __new__(cls):
|
18
|
-
instance = super().__new__(cls, "Not Applicable")
|
19
|
-
instance.literal = "Not Applicable"
|
20
|
-
return instance
|
16
|
+
NA = "Not Applicable"
|
21
17
|
|
22
18
|
|
23
19
|
class InvigilatorAI(InvigilatorBase):
|
@@ -43,6 +39,9 @@ class InvigilatorAI(InvigilatorBase):
|
|
43
39
|
params.update({"iteration": self.iteration, "cache": self.cache})
|
44
40
|
params.update({"invigilator": self})
|
45
41
|
|
42
|
+
if self.key_lookup:
|
43
|
+
self.model.set_key_lookup(self.key_lookup)
|
44
|
+
|
46
45
|
return await self.model.async_get_response(**params)
|
47
46
|
|
48
47
|
def store_response(self, agent_response_dict: AgentResponseDict) -> None:
|
@@ -232,13 +231,13 @@ class InvigilatorHuman(InvigilatorBase):
|
|
232
231
|
exception_occurred = e
|
233
232
|
finally:
|
234
233
|
data = {
|
235
|
-
"generated_tokens": NotApplicable(),
|
234
|
+
"generated_tokens": NA, # NotApplicable(),
|
236
235
|
"question_name": self.question.question_name,
|
237
236
|
"prompts": self.get_prompts(),
|
238
|
-
"cached_response":
|
239
|
-
"raw_model_response":
|
240
|
-
"cache_used":
|
241
|
-
"cache_key":
|
237
|
+
"cached_response": NA,
|
238
|
+
"raw_model_response": NA,
|
239
|
+
"cache_used": NA,
|
240
|
+
"cache_key": NA,
|
242
241
|
"answer": answer,
|
243
242
|
"comment": comment,
|
244
243
|
"validated": validated,
|
@@ -259,10 +258,10 @@ class InvigilatorFunctional(InvigilatorBase):
|
|
259
258
|
generated_tokens=str(answer),
|
260
259
|
question_name=self.question.question_name,
|
261
260
|
prompts=self.get_prompts(),
|
262
|
-
cached_response=
|
263
|
-
raw_model_response=
|
264
|
-
cache_used=
|
265
|
-
cache_key=
|
261
|
+
cached_response=NA,
|
262
|
+
raw_model_response=NA,
|
263
|
+
cache_used=NA,
|
264
|
+
cache_key=NA,
|
266
265
|
answer=answer["answer"],
|
267
266
|
comment="This is the result of a functional question.",
|
268
267
|
validated=True,
|
edsl/agents/InvigilatorBase.py
CHANGED
@@ -14,6 +14,7 @@ if TYPE_CHECKING:
|
|
14
14
|
from edsl.language_models.LanguageModel import LanguageModel
|
15
15
|
from edsl.surveys.Survey import Survey
|
16
16
|
from edsl.agents.Agent import Agent
|
17
|
+
from edsl.language_models.key_management.KeyLookup import KeyLookup
|
17
18
|
|
18
19
|
from edsl.data_transfer_models import EDSLResultObjectInput
|
19
20
|
from edsl.agents.PromptConstructor import PromptConstructor
|
@@ -46,6 +47,7 @@ class InvigilatorBase(ABC):
|
|
46
47
|
additional_prompt_data: Optional[dict] = None,
|
47
48
|
raise_validation_errors: Optional[bool] = True,
|
48
49
|
prompt_plan: Optional["PromptPlan"] = None,
|
50
|
+
key_lookup: Optional["KeyLookup"] = None,
|
49
51
|
):
|
50
52
|
"""Initialize a new Invigilator."""
|
51
53
|
self.agent = agent
|
@@ -59,6 +61,7 @@ class InvigilatorBase(ABC):
|
|
59
61
|
self.cache = cache
|
60
62
|
self.survey = survey
|
61
63
|
self.raise_validation_errors = raise_validation_errors
|
64
|
+
self.key_lookup = key_lookup
|
62
65
|
|
63
66
|
if prompt_plan is None:
|
64
67
|
self.prompt_plan = PromptPlan()
|
@@ -208,7 +211,7 @@ class InvigilatorBase(ABC):
|
|
208
211
|
from edsl.agents.Agent import Agent
|
209
212
|
from edsl.scenarios.Scenario import Scenario
|
210
213
|
from edsl.surveys.MemoryPlan import MemoryPlan
|
211
|
-
from edsl.language_models.
|
214
|
+
from edsl.language_models.model import Model
|
212
215
|
from edsl.surveys.Survey import Survey
|
213
216
|
|
214
217
|
model = Model("test", canned_response="SPAM!")
|
edsl/agents/PromptConstructor.py
CHANGED
@@ -4,40 +4,56 @@ from functools import cached_property
|
|
4
4
|
|
5
5
|
from edsl.prompts.Prompt import Prompt
|
6
6
|
|
7
|
+
from dataclasses import dataclass
|
8
|
+
|
7
9
|
from .prompt_helpers import PromptPlan
|
8
10
|
from .QuestionTemplateReplacementsBuilder import (
|
9
11
|
QuestionTemplateReplacementsBuilder,
|
10
12
|
)
|
11
|
-
from .
|
13
|
+
from .question_option_processor import QuestionOptionProcessor
|
12
14
|
|
13
15
|
if TYPE_CHECKING:
|
14
16
|
from edsl.agents.InvigilatorBase import InvigilatorBase
|
15
17
|
from edsl.questions.QuestionBase import QuestionBase
|
18
|
+
from edsl.agents.Agent import Agent
|
19
|
+
from edsl.surveys.Survey import Survey
|
20
|
+
from edsl.language_models.LanguageModel import LanguageModel
|
21
|
+
from edsl.surveys.MemoryPlan import MemoryPlan
|
22
|
+
from edsl.questions.QuestionBase import QuestionBase
|
23
|
+
from edsl.scenarios.Scenario import Scenario
|
16
24
|
|
17
25
|
|
18
|
-
class
|
19
|
-
"""
|
26
|
+
class BasePlaceholder:
|
27
|
+
"""Base class for placeholder values when a question is not yet answered."""
|
20
28
|
|
21
|
-
def __init__(self):
|
22
|
-
self.
|
29
|
+
def __init__(self, placeholder_type: str = "answer"):
|
30
|
+
self.value = "N/A"
|
23
31
|
self.comment = "Will be populated by prior answer"
|
32
|
+
self._type = placeholder_type
|
24
33
|
|
25
34
|
def __getitem__(self, index):
|
26
35
|
return ""
|
27
36
|
|
28
37
|
def __str__(self):
|
29
|
-
return f"<<{self.__class__.__name__}>>"
|
38
|
+
return f"<<{self.__class__.__name__}:{self._type}>>"
|
30
39
|
|
31
40
|
def __repr__(self):
|
32
|
-
return
|
41
|
+
return self.__str__()
|
42
|
+
|
43
|
+
|
44
|
+
class PlaceholderAnswer(BasePlaceholder):
|
45
|
+
def __init__(self):
|
46
|
+
super().__init__("answer")
|
33
47
|
|
34
48
|
|
35
|
-
class PlaceholderComment(
|
36
|
-
|
49
|
+
class PlaceholderComment(BasePlaceholder):
|
50
|
+
def __init__(self):
|
51
|
+
super().__init__("comment")
|
37
52
|
|
38
53
|
|
39
|
-
class PlaceholderGeneratedTokens(
|
40
|
-
|
54
|
+
class PlaceholderGeneratedTokens(BasePlaceholder):
|
55
|
+
def __init__(self):
|
56
|
+
super().__init__("generated_tokens")
|
41
57
|
|
42
58
|
|
43
59
|
class PromptConstructor:
|
@@ -55,6 +71,8 @@ class PromptConstructor:
|
|
55
71
|
self, invigilator: "InvigilatorBase", prompt_plan: Optional["PromptPlan"] = None
|
56
72
|
):
|
57
73
|
self.invigilator = invigilator
|
74
|
+
self.prompt_plan = prompt_plan or PromptPlan()
|
75
|
+
|
58
76
|
self.agent = invigilator.agent
|
59
77
|
self.question = invigilator.question
|
60
78
|
self.scenario = invigilator.scenario
|
@@ -62,7 +80,6 @@ class PromptConstructor:
|
|
62
80
|
self.model = invigilator.model
|
63
81
|
self.current_answers = invigilator.current_answers
|
64
82
|
self.memory_plan = invigilator.memory_plan
|
65
|
-
self.prompt_plan = prompt_plan or PromptPlan()
|
66
83
|
|
67
84
|
def get_question_options(self, question_data):
|
68
85
|
"""Get the question options."""
|
@@ -115,6 +132,8 @@ class PromptConstructor:
|
|
115
132
|
('q0', 'comment')
|
116
133
|
>>> PromptConstructor._extract_quetion_and_entry_type("q0_alternate_generated_tokens")
|
117
134
|
('q0_alternate', 'generated_tokens')
|
135
|
+
>>> PromptConstructor._extract_quetion_and_entry_type("q0_alt_comment")
|
136
|
+
('q0_alt', 'comment')
|
118
137
|
"""
|
119
138
|
split_list = key_entry.rsplit("_", maxsplit=1)
|
120
139
|
if len(split_list) == 1:
|
@@ -133,24 +152,25 @@ class PromptConstructor:
|
|
133
152
|
return question_name, entry_type
|
134
153
|
|
135
154
|
@staticmethod
|
136
|
-
def _augmented_answers_dict(current_answers: dict):
|
155
|
+
def _augmented_answers_dict(current_answers: dict) -> dict:
|
137
156
|
"""
|
138
157
|
>>> PromptConstructor._augmented_answers_dict({"q0": "LOVE IT!", "q0_comment": "I love school!"})
|
139
158
|
{'q0': {'answer': 'LOVE IT!', 'comment': 'I love school!'}}
|
140
159
|
"""
|
141
|
-
|
160
|
+
from collections import defaultdict
|
161
|
+
|
162
|
+
d = defaultdict(dict)
|
142
163
|
for key, value in current_answers.items():
|
143
|
-
(
|
144
|
-
|
145
|
-
|
146
|
-
) = PromptConstructor._extract_quetion_and_entry_type(key)
|
147
|
-
if question_name not in d:
|
148
|
-
d[question_name] = {}
|
164
|
+
question_name, entry_type = (
|
165
|
+
PromptConstructor._extract_quetion_and_entry_type(key)
|
166
|
+
)
|
149
167
|
d[question_name][entry_type] = value
|
150
|
-
return d
|
168
|
+
return dict(d)
|
151
169
|
|
152
170
|
@staticmethod
|
153
|
-
def _add_answers(
|
171
|
+
def _add_answers(
|
172
|
+
answer_dict: dict, current_answers: dict
|
173
|
+
) -> dict[str, "QuestionBase"]:
|
154
174
|
"""
|
155
175
|
>>> from edsl import QuestionFreeText
|
156
176
|
>>> d = {"q0": QuestionFreeText(question_text="Do you like school?", question_name = "q0")}
|
@@ -51,7 +51,7 @@ class QuestionInstructionPromptBuilder:
|
|
51
51
|
Dict: Enriched prompt data
|
52
52
|
"""
|
53
53
|
if "question_options" in prompt_data["data"]:
|
54
|
-
from edsl.agents.
|
54
|
+
from edsl.agents.question_option_processor import QuestionOptionProcessor
|
55
55
|
|
56
56
|
question_options = QuestionOptionProcessor(
|
57
57
|
self.prompt_constructor
|
edsl/auto/AutoStudy.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import Optional
|
1
|
+
from typing import Optional, TYPE_CHECKING
|
2
2
|
|
3
3
|
from edsl import Model
|
4
4
|
from edsl.auto.StageQuestions import StageQuestions
|
@@ -11,10 +11,12 @@ from edsl.auto.StagePersonaDimensionValueRanges import (
|
|
11
11
|
from edsl.auto.StageLabelQuestions import StageLabelQuestions
|
12
12
|
from edsl.auto.StageGenerateSurvey import StageGenerateSurvey
|
13
13
|
|
14
|
-
# from edsl.auto.StageBase import gen_pipeline
|
15
|
-
|
16
14
|
from edsl.auto.utilities import agent_generator, create_agents, gen_pipeline
|
17
15
|
|
16
|
+
if TYPE_CHECKING:
|
17
|
+
from edsl.surveys.Survey import Survey
|
18
|
+
from edsl.agents.AgentList import AgentList
|
19
|
+
|
18
20
|
|
19
21
|
class AutoStudy:
|
20
22
|
def __init__(
|
@@ -24,8 +26,10 @@ class AutoStudy:
|
|
24
26
|
model: Optional["Model"] = None,
|
25
27
|
survey: Optional["Survey"] = None,
|
26
28
|
agent_list: Optional["AgentList"] = None,
|
27
|
-
default_num_agents=11,
|
29
|
+
default_num_agents: int = 11,
|
28
30
|
):
|
31
|
+
"""AutoStudy class for generating surveys and agents."""
|
32
|
+
|
29
33
|
self.overall_question = overall_question
|
30
34
|
self.population = population
|
31
35
|
self._survey = survey
|
@@ -36,6 +40,15 @@ class AutoStudy:
|
|
36
40
|
self.default_num_agents = default_num_agents
|
37
41
|
self.model = model or Model()
|
38
42
|
|
43
|
+
def to_dict(self):
|
44
|
+
return {
|
45
|
+
"overall_question": self.overall_question,
|
46
|
+
"population": self.population,
|
47
|
+
"survey": self.survey.to_dict(),
|
48
|
+
"persona_mapping": self.persona_mapping.to_dict(),
|
49
|
+
"results": self.results.to_dict(),
|
50
|
+
}
|
51
|
+
|
39
52
|
@property
|
40
53
|
def survey(self):
|
41
54
|
if self._survey is None:
|
@@ -111,7 +124,7 @@ class AutoStudy:
|
|
111
124
|
|
112
125
|
|
113
126
|
if __name__ == "__main__":
|
114
|
-
overall_question = "
|
127
|
+
overall_question = "I have an open source Python library for working with LLMs. What are some ways we can market this to others?"
|
115
128
|
auto_study = AutoStudy(overall_question, population="US Adults")
|
116
129
|
|
117
130
|
results = auto_study.results
|
edsl/auto/StageBase.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
1
|
from abc import ABC, abstractmethod
|
2
|
+
import json
|
2
3
|
from typing import Dict, List, Any, TypeVar, Generator, Dict, Callable
|
3
4
|
from dataclasses import dataclass, field, KW_ONLY, fields, asdict
|
4
5
|
import textwrap
|
@@ -35,6 +36,13 @@ class FlowDataBase:
|
|
35
36
|
sent_to_stage_name: str = field(default_factory=str)
|
36
37
|
came_from_stage_name: str = field(default_factory=str)
|
37
38
|
|
39
|
+
def to_dict(self):
|
40
|
+
return asdict(self)
|
41
|
+
|
42
|
+
@classmethod
|
43
|
+
def from_dict(cls, data: dict):
|
44
|
+
return cls(**data)
|
45
|
+
|
38
46
|
def __getitem__(self, key):
|
39
47
|
"""Allows dictionary-style getting."""
|
40
48
|
return getattr(self, key)
|
@@ -126,6 +134,10 @@ class StageBase(ABC):
|
|
126
134
|
else:
|
127
135
|
self.next_stage = None
|
128
136
|
|
137
|
+
@classmethod
|
138
|
+
def function_parameters(self):
|
139
|
+
return fields(self.input)
|
140
|
+
|
129
141
|
@classmethod
|
130
142
|
def func(cls, **kwargs):
|
131
143
|
"This provides a shortcut for running a stage by passing keyword arguments to the input function."
|
@@ -173,58 +185,59 @@ class StageBase(ABC):
|
|
173
185
|
|
174
186
|
|
175
187
|
if __name__ == "__main__":
|
176
|
-
|
188
|
+
pass
|
189
|
+
# try:
|
177
190
|
|
178
|
-
|
179
|
-
|
180
|
-
|
191
|
+
# class StageMissing(StageBase):
|
192
|
+
# def handle_data(self, data):
|
193
|
+
# return data
|
181
194
|
|
182
|
-
except NotImplementedError as e:
|
183
|
-
|
184
|
-
else:
|
185
|
-
|
195
|
+
# except NotImplementedError as e:
|
196
|
+
# print(e)
|
197
|
+
# else:
|
198
|
+
# raise Exception("Should have raised NotImplementedError")
|
186
199
|
|
187
|
-
try:
|
200
|
+
# try:
|
188
201
|
|
189
|
-
|
190
|
-
|
202
|
+
# class StageMissingInput(StageBase):
|
203
|
+
# output = FlowDataBase
|
191
204
|
|
192
|
-
except NotImplementedError as e:
|
193
|
-
|
205
|
+
# except NotImplementedError as e:
|
206
|
+
# print(e)
|
194
207
|
|
195
|
-
else:
|
196
|
-
|
208
|
+
# else:
|
209
|
+
# raise Exception("Should have raised NotImplementedError")
|
197
210
|
|
198
|
-
@dataclass
|
199
|
-
class MockInputOutput(FlowDataBase):
|
200
|
-
|
211
|
+
# @dataclass
|
212
|
+
# class MockInputOutput(FlowDataBase):
|
213
|
+
# text: str
|
201
214
|
|
202
|
-
class StageTest(StageBase):
|
203
|
-
|
204
|
-
|
215
|
+
# class StageTest(StageBase):
|
216
|
+
# input = MockInputOutput
|
217
|
+
# output = MockInputOutput
|
205
218
|
|
206
|
-
|
207
|
-
|
219
|
+
# def handle_data(self, data):
|
220
|
+
# return self.output(text=data["text"] + "processed")
|
208
221
|
|
209
|
-
result = StageTest().process(MockInputOutput(text="Hello world!"))
|
210
|
-
print(result.text)
|
222
|
+
# result = StageTest().process(MockInputOutput(text="Hello world!"))
|
223
|
+
# print(result.text)
|
211
224
|
|
212
|
-
pipeline = StageTest(next_stage=StageTest(next_stage=StageTest()))
|
213
|
-
result = pipeline.process(MockInputOutput(text="Hello world!"))
|
214
|
-
print(result.text)
|
225
|
+
# pipeline = StageTest(next_stage=StageTest(next_stage=StageTest()))
|
226
|
+
# result = pipeline.process(MockInputOutput(text="Hello world!"))
|
227
|
+
# print(result.text)
|
215
228
|
|
216
|
-
class BadMockInput(FlowDataBase):
|
217
|
-
|
218
|
-
|
229
|
+
# class BadMockInput(FlowDataBase):
|
230
|
+
# text: str
|
231
|
+
# other: str
|
219
232
|
|
220
|
-
class StageBad(StageBase):
|
221
|
-
|
222
|
-
|
233
|
+
# class StageBad(StageBase):
|
234
|
+
# input = BadMockInput
|
235
|
+
# output = BadMockInput
|
223
236
|
|
224
|
-
|
225
|
-
|
237
|
+
# def handle_data(self, data):
|
238
|
+
# return self.output(text=data["text"] + "processed")
|
226
239
|
|
227
|
-
try:
|
228
|
-
|
229
|
-
except ExceptionPipesDoNotFit as e:
|
230
|
-
|
240
|
+
# try:
|
241
|
+
# pipeline = StageTest(next_stage=StageBad(next_stage=StageTest()))
|
242
|
+
# except ExceptionPipesDoNotFit as e:
|
243
|
+
# print(e)
|
edsl/auto/StageQuestions.py
CHANGED
edsl/auto/utilities.py
CHANGED
@@ -88,12 +88,6 @@ def agent_eligibility(
|
|
88
88
|
q_eligibility(model=model, questions=questions, persona=persona, cache=cache)
|
89
89
|
== "Yes"
|
90
90
|
)
|
91
|
-
# results = (
|
92
|
-
# q.by(model)
|
93
|
-
# .by(Scenario({"questions": questions, "persona": persona}))
|
94
|
-
# .run(cache=cache)
|
95
|
-
# )
|
96
|
-
# return results.select("eligibility").first() == "Yes"
|
97
91
|
|
98
92
|
|
99
93
|
def gen_agent_traits(dimension_dict: dict, seed_value: Optional[str] = None):
|
edsl/coop/coop.py
CHANGED
@@ -179,6 +179,7 @@ class Coop(CoopFunctionsMixin):
|
|
179
179
|
Check the response from the server and raise errors as appropriate.
|
180
180
|
"""
|
181
181
|
# Get EDSL version from header
|
182
|
+
# breakpoint()
|
182
183
|
server_edsl_version = response.headers.get("X-EDSL-Version")
|
183
184
|
|
184
185
|
if server_edsl_version:
|
@@ -187,11 +188,18 @@ class Coop(CoopFunctionsMixin):
|
|
187
188
|
server_version_str=server_edsl_version,
|
188
189
|
):
|
189
190
|
print(
|
190
|
-
"Please upgrade your EDSL version to access our latest features. To upgrade, open your terminal and run `pip upgrade edsl`"
|
191
|
+
"Please upgrade your EDSL version to access our latest features. To upgrade, open your terminal and run `pip install --upgrade edsl`"
|
191
192
|
)
|
192
193
|
|
193
194
|
if response.status_code >= 400:
|
194
|
-
|
195
|
+
try:
|
196
|
+
message = response.json().get("detail")
|
197
|
+
except json.JSONDecodeError:
|
198
|
+
raise CoopServerResponseError(
|
199
|
+
f"Server returned status code {response.status_code}."
|
200
|
+
"JSON response could not be decoded.",
|
201
|
+
"The server response was: " + response.text,
|
202
|
+
)
|
195
203
|
# print(response.text)
|
196
204
|
if "The API key you provided is invalid" in message and check_api_key:
|
197
205
|
import secrets
|
@@ -651,9 +659,6 @@ class Coop(CoopFunctionsMixin):
|
|
651
659
|
self._resolve_server_response(response)
|
652
660
|
return response.json()
|
653
661
|
|
654
|
-
################
|
655
|
-
# Remote Inference
|
656
|
-
################
|
657
662
|
def remote_inference_create(
|
658
663
|
self,
|
659
664
|
job: Jobs,
|
@@ -764,6 +769,17 @@ class Coop(CoopFunctionsMixin):
|
|
764
769
|
}
|
765
770
|
)
|
766
771
|
|
772
|
+
def get_running_jobs(self) -> list[str]:
|
773
|
+
"""
|
774
|
+
Get a list of currently running job IDs.
|
775
|
+
|
776
|
+
Returns:
|
777
|
+
list[str]: List of running job UUIDs
|
778
|
+
"""
|
779
|
+
response = self._send_server_request(uri="jobs/status", method="GET")
|
780
|
+
self._resolve_server_response(response)
|
781
|
+
return response.json().get("running_jobs", [])
|
782
|
+
|
767
783
|
def remote_inference_cost(
|
768
784
|
self, input: Union[Jobs, Survey], iterations: int = 1
|
769
785
|
) -> int:
|