edsl 0.1.39__py3-none-any.whl → 0.1.39.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +0 -28
- edsl/__init__.py +1 -1
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +17 -9
- edsl/agents/Invigilator.py +14 -13
- edsl/agents/InvigilatorBase.py +1 -4
- edsl/agents/PromptConstructor.py +22 -42
- edsl/agents/QuestionInstructionPromptBuilder.py +1 -1
- edsl/auto/AutoStudy.py +5 -18
- edsl/auto/StageBase.py +40 -53
- edsl/auto/StageQuestions.py +1 -2
- edsl/auto/utilities.py +6 -0
- edsl/coop/coop.py +5 -21
- edsl/data/Cache.py +18 -29
- edsl/data/CacheHandler.py +2 -0
- edsl/data/RemoteCacheSync.py +46 -154
- edsl/enums.py +0 -7
- edsl/inference_services/AnthropicService.py +16 -38
- edsl/inference_services/AvailableModelFetcher.py +1 -7
- edsl/inference_services/GoogleService.py +1 -5
- edsl/inference_services/InferenceServicesCollection.py +2 -18
- edsl/inference_services/OpenAIService.py +31 -46
- edsl/inference_services/TestService.py +3 -1
- edsl/inference_services/TogetherAIService.py +3 -5
- edsl/inference_services/data_structures.py +2 -74
- edsl/jobs/AnswerQuestionFunctionConstructor.py +113 -148
- edsl/jobs/FetchInvigilator.py +3 -10
- edsl/jobs/InterviewsConstructor.py +4 -6
- edsl/jobs/Jobs.py +233 -299
- edsl/jobs/JobsChecks.py +2 -2
- edsl/jobs/JobsPrompts.py +1 -1
- edsl/jobs/JobsRemoteInferenceHandler.py +136 -160
- edsl/jobs/interviews/Interview.py +42 -80
- edsl/jobs/runners/JobsRunnerAsyncio.py +358 -88
- edsl/jobs/runners/JobsRunnerStatus.py +165 -133
- edsl/jobs/tasks/TaskHistory.py +3 -24
- edsl/language_models/LanguageModel.py +4 -59
- edsl/language_models/ModelList.py +8 -19
- edsl/language_models/__init__.py +1 -1
- edsl/language_models/registry.py +180 -0
- edsl/language_models/repair.py +1 -1
- edsl/questions/QuestionBase.py +26 -35
- edsl/questions/{question_base_gen_mixin.py → QuestionBaseGenMixin.py} +49 -52
- edsl/questions/QuestionBasePromptsMixin.py +1 -1
- edsl/questions/QuestionBudget.py +1 -1
- edsl/questions/QuestionCheckBox.py +2 -2
- edsl/questions/QuestionExtract.py +7 -5
- edsl/questions/QuestionFreeText.py +1 -1
- edsl/questions/QuestionList.py +15 -9
- edsl/questions/QuestionMatrix.py +1 -1
- edsl/questions/QuestionMultipleChoice.py +1 -1
- edsl/questions/QuestionNumerical.py +1 -1
- edsl/questions/QuestionRank.py +1 -1
- edsl/questions/{response_validator_abc.py → ResponseValidatorABC.py} +18 -6
- edsl/questions/{response_validator_factory.py → ResponseValidatorFactory.py} +1 -7
- edsl/questions/SimpleAskMixin.py +1 -1
- edsl/questions/__init__.py +1 -1
- edsl/results/DatasetExportMixin.py +119 -60
- edsl/results/Result.py +3 -109
- edsl/results/Results.py +39 -50
- edsl/scenarios/FileStore.py +0 -32
- edsl/scenarios/ScenarioList.py +7 -35
- edsl/scenarios/handlers/csv.py +0 -11
- edsl/surveys/Survey.py +20 -71
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev2.dist-info}/METADATA +1 -1
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev2.dist-info}/RECORD +78 -84
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev2.dist-info}/WHEEL +1 -1
- edsl/jobs/async_interview_runner.py +0 -138
- edsl/jobs/check_survey_scenario_compatibility.py +0 -85
- edsl/jobs/data_structures.py +0 -120
- edsl/jobs/results_exceptions_handler.py +0 -98
- edsl/language_models/model.py +0 -256
- edsl/questions/data_structures.py +0 -20
- edsl/results/file_exports.py +0 -252
- /edsl/agents/{question_option_processor.py → QuestionOptionProcessor.py} +0 -0
- /edsl/questions/{answer_validator_mixin.py → AnswerValidatorMixin.py} +0 -0
- /edsl/questions/{loop_processor.py → LoopProcessor.py} +0 -0
- /edsl/questions/{register_questions_meta.py → RegisterQuestionsMeta.py} +0 -0
- /edsl/results/{results_fetch_mixin.py → ResultsFetchMixin.py} +0 -0
- /edsl/results/{results_tools_mixin.py → ResultsToolsMixin.py} +0 -0
- /edsl/results/{results_selector.py → Selector.py} +0 -0
- /edsl/scenarios/{directory_scanner.py → DirectoryScanner.py} +0 -0
- /edsl/scenarios/{scenario_join.py → ScenarioJoin.py} +0 -0
- /edsl/scenarios/{scenario_selector.py → ScenarioSelector.py} +0 -0
- {edsl-0.1.39.dist-info → edsl-0.1.39.dev2.dist-info}/LICENSE +0 -0
edsl/Base.py
CHANGED
@@ -10,10 +10,6 @@ from uuid import UUID
|
|
10
10
|
class PersistenceMixin:
|
11
11
|
"""Mixin for saving and loading objects to and from files."""
|
12
12
|
|
13
|
-
def duplicate(self, add_edsl_version=False):
|
14
|
-
"""Return a duplicate of the object."""
|
15
|
-
return self.from_dict(self.to_dict(add_edsl_version=False))
|
16
|
-
|
17
13
|
def push(
|
18
14
|
self,
|
19
15
|
description: Optional[str] = None,
|
@@ -27,30 +23,6 @@ class PersistenceMixin:
|
|
27
23
|
c = Coop(url=expected_parrot_url)
|
28
24
|
return c.create(self, description, alias, visibility)
|
29
25
|
|
30
|
-
def to_yaml(self, add_edsl_version=False, filename: str = None) -> Union[str, None]:
|
31
|
-
import yaml
|
32
|
-
|
33
|
-
output = yaml.dump(self.to_dict(add_edsl_version=add_edsl_version))
|
34
|
-
if not filename:
|
35
|
-
return output
|
36
|
-
|
37
|
-
with open(filename, "w") as f:
|
38
|
-
f.write(output)
|
39
|
-
|
40
|
-
@classmethod
|
41
|
-
def from_yaml(cls, yaml_str: Optional[str] = None, filename: Optional[str] = None):
|
42
|
-
if yaml_str is None and filename is not None:
|
43
|
-
with open(filename, "r") as f:
|
44
|
-
yaml_str = f.read()
|
45
|
-
return cls.from_yaml(yaml_str=yaml_str)
|
46
|
-
elif yaml_str and filename is None:
|
47
|
-
import yaml
|
48
|
-
|
49
|
-
d = yaml.load(yaml_str, Loader=yaml.FullLoader)
|
50
|
-
return cls.from_dict(d)
|
51
|
-
else:
|
52
|
-
raise ValueError("Either yaml_str or filename must be provided.")
|
53
|
-
|
54
26
|
def create_download_link(self):
|
55
27
|
from tempfile import NamedTemporaryFile
|
56
28
|
from edsl.scenarios.FileStore import FileStore
|
edsl/__init__.py
CHANGED
@@ -34,7 +34,7 @@ from edsl.scenarios.FileStore import FileStore
|
|
34
34
|
|
35
35
|
# from edsl.utilities.interface import print_dict_with_rich
|
36
36
|
from edsl.surveys.Survey import Survey
|
37
|
-
from edsl.language_models.
|
37
|
+
from edsl.language_models.registry import Model
|
38
38
|
from edsl.language_models.ModelList import ModelList
|
39
39
|
|
40
40
|
from edsl.results.Results import Results
|
edsl/__version__.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.1.39"
|
1
|
+
__version__ = "0.1.39.dev2"
|
edsl/agents/Agent.py
CHANGED
@@ -549,11 +549,12 @@ class Agent(Base):
|
|
549
549
|
survey: Optional["Survey"] = None,
|
550
550
|
scenario: Optional["Scenario"] = None,
|
551
551
|
model: Optional["LanguageModel"] = None,
|
552
|
+
# debug: bool = False,
|
552
553
|
memory_plan: Optional["MemoryPlan"] = None,
|
553
554
|
current_answers: Optional[dict] = None,
|
554
555
|
iteration: int = 1,
|
556
|
+
# sidecar_model=None,
|
555
557
|
raise_validation_errors: bool = True,
|
556
|
-
key_lookup: Optional["KeyLookup"] = None,
|
557
558
|
) -> "InvigilatorBase":
|
558
559
|
"""Create an Invigilator.
|
559
560
|
|
@@ -568,7 +569,7 @@ class Agent(Base):
|
|
568
569
|
An invigator is an object that is responsible for administering a question to an agent and
|
569
570
|
recording the responses.
|
570
571
|
"""
|
571
|
-
from edsl.language_models.
|
572
|
+
from edsl.language_models.registry import Model
|
572
573
|
|
573
574
|
from edsl.scenarios.Scenario import Scenario
|
574
575
|
|
@@ -581,12 +582,13 @@ class Agent(Base):
|
|
581
582
|
scenario=scenario,
|
582
583
|
survey=survey,
|
583
584
|
model=model,
|
585
|
+
# debug=debug,
|
584
586
|
memory_plan=memory_plan,
|
585
587
|
current_answers=current_answers,
|
586
588
|
iteration=iteration,
|
587
589
|
cache=cache,
|
590
|
+
# sidecar_model=sidecar_model,
|
588
591
|
raise_validation_errors=raise_validation_errors,
|
589
|
-
key_lookup=key_lookup,
|
590
592
|
)
|
591
593
|
if hasattr(self, "validate_response"):
|
592
594
|
invigilator.validate_response = self.validate_response
|
@@ -606,7 +608,6 @@ class Agent(Base):
|
|
606
608
|
memory_plan: Optional[MemoryPlan] = None,
|
607
609
|
current_answers: Optional[dict] = None,
|
608
610
|
iteration: int = 0,
|
609
|
-
key_lookup: Optional["KeyLookup"] = None,
|
610
611
|
) -> AgentResponseDict:
|
611
612
|
"""
|
612
613
|
Answer a posed question.
|
@@ -636,10 +637,10 @@ class Agent(Base):
|
|
636
637
|
scenario=scenario,
|
637
638
|
survey=survey,
|
638
639
|
model=model,
|
640
|
+
# debug=debug,
|
639
641
|
memory_plan=memory_plan,
|
640
642
|
current_answers=current_answers,
|
641
643
|
iteration=iteration,
|
642
|
-
key_lookup=key_lookup,
|
643
644
|
)
|
644
645
|
response: AgentResponseDict = await invigilator.async_answer_question()
|
645
646
|
return response
|
@@ -672,14 +673,15 @@ class Agent(Base):
|
|
672
673
|
scenario: Optional[Scenario] = None,
|
673
674
|
model: Optional[LanguageModel] = None,
|
674
675
|
survey: Optional[Survey] = None,
|
676
|
+
# debug: bool = False,
|
675
677
|
memory_plan: Optional[MemoryPlan] = None,
|
676
678
|
current_answers: Optional[dict] = None,
|
677
679
|
iteration: int = 0,
|
680
|
+
# sidecar_model=None,
|
678
681
|
raise_validation_errors: bool = True,
|
679
|
-
key_lookup: Optional["KeyLookup"] = None,
|
680
682
|
) -> "InvigilatorBase":
|
681
683
|
"""Create an Invigilator."""
|
682
|
-
from edsl.language_models.
|
684
|
+
from edsl.language_models.registry import Model
|
683
685
|
from edsl.scenarios.Scenario import Scenario
|
684
686
|
|
685
687
|
model = model or Model()
|
@@ -692,6 +694,12 @@ class Agent(Base):
|
|
692
694
|
|
693
695
|
invigilator_class = self._get_invigilator_class(question)
|
694
696
|
|
697
|
+
# if sidecar_model is not None:
|
698
|
+
# # this is the case when a 'simple' model is being used
|
699
|
+
# # from edsl.agents.Invigilator import InvigilatorSidecar
|
700
|
+
# # invigilator_class = InvigilatorSidecar
|
701
|
+
# raise DeprecationWarning("Sidecar models are deprecated.")
|
702
|
+
|
695
703
|
invigilator = invigilator_class(
|
696
704
|
self,
|
697
705
|
question=question,
|
@@ -702,8 +710,8 @@ class Agent(Base):
|
|
702
710
|
current_answers=current_answers,
|
703
711
|
iteration=iteration,
|
704
712
|
cache=cache,
|
713
|
+
# sidecar_model=sidecar_model,
|
705
714
|
raise_validation_errors=raise_validation_errors,
|
706
|
-
key_lookup=key_lookup,
|
707
715
|
)
|
708
716
|
return invigilator
|
709
717
|
|
@@ -906,7 +914,7 @@ class Agent(Base):
|
|
906
914
|
{'traits': {'age': 10, 'hair': 'brown', 'height': 5.5}, 'instruction': 'Have fun.', 'edsl_version': '...', 'edsl_class_name': 'Agent'}
|
907
915
|
"""
|
908
916
|
d = {}
|
909
|
-
d["traits"] = copy.deepcopy(
|
917
|
+
d["traits"] = copy.deepcopy(self.traits)
|
910
918
|
if self.name:
|
911
919
|
d["name"] = self.name
|
912
920
|
if self.set_instructions:
|
edsl/agents/Invigilator.py
CHANGED
@@ -13,7 +13,11 @@ if TYPE_CHECKING:
|
|
13
13
|
from edsl.surveys.Survey import Survey
|
14
14
|
|
15
15
|
|
16
|
-
|
16
|
+
class NotApplicable(str):
|
17
|
+
def __new__(cls):
|
18
|
+
instance = super().__new__(cls, "Not Applicable")
|
19
|
+
instance.literal = "Not Applicable"
|
20
|
+
return instance
|
17
21
|
|
18
22
|
|
19
23
|
class InvigilatorAI(InvigilatorBase):
|
@@ -39,9 +43,6 @@ class InvigilatorAI(InvigilatorBase):
|
|
39
43
|
params.update({"iteration": self.iteration, "cache": self.cache})
|
40
44
|
params.update({"invigilator": self})
|
41
45
|
|
42
|
-
if self.key_lookup:
|
43
|
-
self.model.set_key_lookup(self.key_lookup)
|
44
|
-
|
45
46
|
return await self.model.async_get_response(**params)
|
46
47
|
|
47
48
|
def store_response(self, agent_response_dict: AgentResponseDict) -> None:
|
@@ -231,13 +232,13 @@ class InvigilatorHuman(InvigilatorBase):
|
|
231
232
|
exception_occurred = e
|
232
233
|
finally:
|
233
234
|
data = {
|
234
|
-
"generated_tokens":
|
235
|
+
"generated_tokens": NotApplicable(),
|
235
236
|
"question_name": self.question.question_name,
|
236
237
|
"prompts": self.get_prompts(),
|
237
|
-
"cached_response":
|
238
|
-
"raw_model_response":
|
239
|
-
"cache_used":
|
240
|
-
"cache_key":
|
238
|
+
"cached_response": NotApplicable(),
|
239
|
+
"raw_model_response": NotApplicable(),
|
240
|
+
"cache_used": NotApplicable(),
|
241
|
+
"cache_key": NotApplicable(),
|
241
242
|
"answer": answer,
|
242
243
|
"comment": comment,
|
243
244
|
"validated": validated,
|
@@ -258,10 +259,10 @@ class InvigilatorFunctional(InvigilatorBase):
|
|
258
259
|
generated_tokens=str(answer),
|
259
260
|
question_name=self.question.question_name,
|
260
261
|
prompts=self.get_prompts(),
|
261
|
-
cached_response=
|
262
|
-
raw_model_response=
|
263
|
-
cache_used=
|
264
|
-
cache_key=
|
262
|
+
cached_response=NotApplicable(),
|
263
|
+
raw_model_response=NotApplicable(),
|
264
|
+
cache_used=NotApplicable(),
|
265
|
+
cache_key=NotApplicable(),
|
265
266
|
answer=answer["answer"],
|
266
267
|
comment="This is the result of a functional question.",
|
267
268
|
validated=True,
|
edsl/agents/InvigilatorBase.py
CHANGED
@@ -14,7 +14,6 @@ if TYPE_CHECKING:
|
|
14
14
|
from edsl.language_models.LanguageModel import LanguageModel
|
15
15
|
from edsl.surveys.Survey import Survey
|
16
16
|
from edsl.agents.Agent import Agent
|
17
|
-
from edsl.language_models.key_management.KeyLookup import KeyLookup
|
18
17
|
|
19
18
|
from edsl.data_transfer_models import EDSLResultObjectInput
|
20
19
|
from edsl.agents.PromptConstructor import PromptConstructor
|
@@ -47,7 +46,6 @@ class InvigilatorBase(ABC):
|
|
47
46
|
additional_prompt_data: Optional[dict] = None,
|
48
47
|
raise_validation_errors: Optional[bool] = True,
|
49
48
|
prompt_plan: Optional["PromptPlan"] = None,
|
50
|
-
key_lookup: Optional["KeyLookup"] = None,
|
51
49
|
):
|
52
50
|
"""Initialize a new Invigilator."""
|
53
51
|
self.agent = agent
|
@@ -61,7 +59,6 @@ class InvigilatorBase(ABC):
|
|
61
59
|
self.cache = cache
|
62
60
|
self.survey = survey
|
63
61
|
self.raise_validation_errors = raise_validation_errors
|
64
|
-
self.key_lookup = key_lookup
|
65
62
|
|
66
63
|
if prompt_plan is None:
|
67
64
|
self.prompt_plan = PromptPlan()
|
@@ -211,7 +208,7 @@ class InvigilatorBase(ABC):
|
|
211
208
|
from edsl.agents.Agent import Agent
|
212
209
|
from edsl.scenarios.Scenario import Scenario
|
213
210
|
from edsl.surveys.MemoryPlan import MemoryPlan
|
214
|
-
from edsl.language_models.
|
211
|
+
from edsl.language_models.registry import Model
|
215
212
|
from edsl.surveys.Survey import Survey
|
216
213
|
|
217
214
|
model = Model("test", canned_response="SPAM!")
|
edsl/agents/PromptConstructor.py
CHANGED
@@ -4,56 +4,40 @@ from functools import cached_property
|
|
4
4
|
|
5
5
|
from edsl.prompts.Prompt import Prompt
|
6
6
|
|
7
|
-
from dataclasses import dataclass
|
8
|
-
|
9
7
|
from .prompt_helpers import PromptPlan
|
10
8
|
from .QuestionTemplateReplacementsBuilder import (
|
11
9
|
QuestionTemplateReplacementsBuilder,
|
12
10
|
)
|
13
|
-
from .
|
11
|
+
from .QuestionOptionProcessor import QuestionOptionProcessor
|
14
12
|
|
15
13
|
if TYPE_CHECKING:
|
16
14
|
from edsl.agents.InvigilatorBase import InvigilatorBase
|
17
15
|
from edsl.questions.QuestionBase import QuestionBase
|
18
|
-
from edsl.agents.Agent import Agent
|
19
|
-
from edsl.surveys.Survey import Survey
|
20
|
-
from edsl.language_models.LanguageModel import LanguageModel
|
21
|
-
from edsl.surveys.MemoryPlan import MemoryPlan
|
22
|
-
from edsl.questions.QuestionBase import QuestionBase
|
23
|
-
from edsl.scenarios.Scenario import Scenario
|
24
16
|
|
25
17
|
|
26
|
-
class
|
27
|
-
"""
|
18
|
+
class PlaceholderAnswer:
|
19
|
+
"""A placeholder answer for when a question is not yet answered."""
|
28
20
|
|
29
|
-
def __init__(self
|
30
|
-
self.
|
21
|
+
def __init__(self):
|
22
|
+
self.answer = "N/A"
|
31
23
|
self.comment = "Will be populated by prior answer"
|
32
|
-
self._type = placeholder_type
|
33
24
|
|
34
25
|
def __getitem__(self, index):
|
35
26
|
return ""
|
36
27
|
|
37
28
|
def __str__(self):
|
38
|
-
return f"<<{self.__class__.__name__}
|
29
|
+
return f"<<{self.__class__.__name__}>>"
|
39
30
|
|
40
31
|
def __repr__(self):
|
41
|
-
return self.
|
42
|
-
|
43
|
-
|
44
|
-
class PlaceholderAnswer(BasePlaceholder):
|
45
|
-
def __init__(self):
|
46
|
-
super().__init__("answer")
|
32
|
+
return f"<<{self.__class__.__name__}>>"
|
47
33
|
|
48
34
|
|
49
|
-
class PlaceholderComment(
|
50
|
-
|
51
|
-
super().__init__("comment")
|
35
|
+
class PlaceholderComment(PlaceholderAnswer):
|
36
|
+
pass
|
52
37
|
|
53
38
|
|
54
|
-
class PlaceholderGeneratedTokens(
|
55
|
-
|
56
|
-
super().__init__("generated_tokens")
|
39
|
+
class PlaceholderGeneratedTokens(PlaceholderAnswer):
|
40
|
+
pass
|
57
41
|
|
58
42
|
|
59
43
|
class PromptConstructor:
|
@@ -71,8 +55,6 @@ class PromptConstructor:
|
|
71
55
|
self, invigilator: "InvigilatorBase", prompt_plan: Optional["PromptPlan"] = None
|
72
56
|
):
|
73
57
|
self.invigilator = invigilator
|
74
|
-
self.prompt_plan = prompt_plan or PromptPlan()
|
75
|
-
|
76
58
|
self.agent = invigilator.agent
|
77
59
|
self.question = invigilator.question
|
78
60
|
self.scenario = invigilator.scenario
|
@@ -80,6 +62,7 @@ class PromptConstructor:
|
|
80
62
|
self.model = invigilator.model
|
81
63
|
self.current_answers = invigilator.current_answers
|
82
64
|
self.memory_plan = invigilator.memory_plan
|
65
|
+
self.prompt_plan = prompt_plan or PromptPlan()
|
83
66
|
|
84
67
|
def get_question_options(self, question_data):
|
85
68
|
"""Get the question options."""
|
@@ -132,8 +115,6 @@ class PromptConstructor:
|
|
132
115
|
('q0', 'comment')
|
133
116
|
>>> PromptConstructor._extract_quetion_and_entry_type("q0_alternate_generated_tokens")
|
134
117
|
('q0_alternate', 'generated_tokens')
|
135
|
-
>>> PromptConstructor._extract_quetion_and_entry_type("q0_alt_comment")
|
136
|
-
('q0_alt', 'comment')
|
137
118
|
"""
|
138
119
|
split_list = key_entry.rsplit("_", maxsplit=1)
|
139
120
|
if len(split_list) == 1:
|
@@ -152,25 +133,24 @@ class PromptConstructor:
|
|
152
133
|
return question_name, entry_type
|
153
134
|
|
154
135
|
@staticmethod
|
155
|
-
def _augmented_answers_dict(current_answers: dict)
|
136
|
+
def _augmented_answers_dict(current_answers: dict):
|
156
137
|
"""
|
157
138
|
>>> PromptConstructor._augmented_answers_dict({"q0": "LOVE IT!", "q0_comment": "I love school!"})
|
158
139
|
{'q0': {'answer': 'LOVE IT!', 'comment': 'I love school!'}}
|
159
140
|
"""
|
160
|
-
|
161
|
-
|
162
|
-
d = defaultdict(dict)
|
141
|
+
d = {}
|
163
142
|
for key, value in current_answers.items():
|
164
|
-
|
165
|
-
|
166
|
-
|
143
|
+
(
|
144
|
+
question_name,
|
145
|
+
entry_type,
|
146
|
+
) = PromptConstructor._extract_quetion_and_entry_type(key)
|
147
|
+
if question_name not in d:
|
148
|
+
d[question_name] = {}
|
167
149
|
d[question_name][entry_type] = value
|
168
|
-
return
|
150
|
+
return d
|
169
151
|
|
170
152
|
@staticmethod
|
171
|
-
def _add_answers(
|
172
|
-
answer_dict: dict, current_answers: dict
|
173
|
-
) -> dict[str, "QuestionBase"]:
|
153
|
+
def _add_answers(answer_dict: dict, current_answers) -> dict[str, "QuestionBase"]:
|
174
154
|
"""
|
175
155
|
>>> from edsl import QuestionFreeText
|
176
156
|
>>> d = {"q0": QuestionFreeText(question_text="Do you like school?", question_name = "q0")}
|
@@ -51,7 +51,7 @@ class QuestionInstructionPromptBuilder:
|
|
51
51
|
Dict: Enriched prompt data
|
52
52
|
"""
|
53
53
|
if "question_options" in prompt_data["data"]:
|
54
|
-
from edsl.agents.
|
54
|
+
from edsl.agents.QuestionOptionProcessor import QuestionOptionProcessor
|
55
55
|
|
56
56
|
question_options = QuestionOptionProcessor(
|
57
57
|
self.prompt_constructor
|
edsl/auto/AutoStudy.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import Optional
|
1
|
+
from typing import Optional
|
2
2
|
|
3
3
|
from edsl import Model
|
4
4
|
from edsl.auto.StageQuestions import StageQuestions
|
@@ -11,11 +11,9 @@ from edsl.auto.StagePersonaDimensionValueRanges import (
|
|
11
11
|
from edsl.auto.StageLabelQuestions import StageLabelQuestions
|
12
12
|
from edsl.auto.StageGenerateSurvey import StageGenerateSurvey
|
13
13
|
|
14
|
-
from edsl.auto.
|
14
|
+
# from edsl.auto.StageBase import gen_pipeline
|
15
15
|
|
16
|
-
|
17
|
-
from edsl.surveys.Survey import Survey
|
18
|
-
from edsl.agents.AgentList import AgentList
|
16
|
+
from edsl.auto.utilities import agent_generator, create_agents, gen_pipeline
|
19
17
|
|
20
18
|
|
21
19
|
class AutoStudy:
|
@@ -26,10 +24,8 @@ class AutoStudy:
|
|
26
24
|
model: Optional["Model"] = None,
|
27
25
|
survey: Optional["Survey"] = None,
|
28
26
|
agent_list: Optional["AgentList"] = None,
|
29
|
-
default_num_agents
|
27
|
+
default_num_agents=11,
|
30
28
|
):
|
31
|
-
"""AutoStudy class for generating surveys and agents."""
|
32
|
-
|
33
29
|
self.overall_question = overall_question
|
34
30
|
self.population = population
|
35
31
|
self._survey = survey
|
@@ -40,15 +36,6 @@ class AutoStudy:
|
|
40
36
|
self.default_num_agents = default_num_agents
|
41
37
|
self.model = model or Model()
|
42
38
|
|
43
|
-
def to_dict(self):
|
44
|
-
return {
|
45
|
-
"overall_question": self.overall_question,
|
46
|
-
"population": self.population,
|
47
|
-
"survey": self.survey.to_dict(),
|
48
|
-
"persona_mapping": self.persona_mapping.to_dict(),
|
49
|
-
"results": self.results.to_dict(),
|
50
|
-
}
|
51
|
-
|
52
39
|
@property
|
53
40
|
def survey(self):
|
54
41
|
if self._survey is None:
|
@@ -124,7 +111,7 @@ class AutoStudy:
|
|
124
111
|
|
125
112
|
|
126
113
|
if __name__ == "__main__":
|
127
|
-
overall_question = "
|
114
|
+
overall_question = "Should online platforms be regulated with respect to selling electric scooters?"
|
128
115
|
auto_study = AutoStudy(overall_question, population="US Adults")
|
129
116
|
|
130
117
|
results = auto_study.results
|
edsl/auto/StageBase.py
CHANGED
@@ -1,5 +1,4 @@
|
|
1
1
|
from abc import ABC, abstractmethod
|
2
|
-
import json
|
3
2
|
from typing import Dict, List, Any, TypeVar, Generator, Dict, Callable
|
4
3
|
from dataclasses import dataclass, field, KW_ONLY, fields, asdict
|
5
4
|
import textwrap
|
@@ -36,13 +35,6 @@ class FlowDataBase:
|
|
36
35
|
sent_to_stage_name: str = field(default_factory=str)
|
37
36
|
came_from_stage_name: str = field(default_factory=str)
|
38
37
|
|
39
|
-
def to_dict(self):
|
40
|
-
return asdict(self)
|
41
|
-
|
42
|
-
@classmethod
|
43
|
-
def from_dict(cls, data: dict):
|
44
|
-
return cls(**data)
|
45
|
-
|
46
38
|
def __getitem__(self, key):
|
47
39
|
"""Allows dictionary-style getting."""
|
48
40
|
return getattr(self, key)
|
@@ -134,10 +126,6 @@ class StageBase(ABC):
|
|
134
126
|
else:
|
135
127
|
self.next_stage = None
|
136
128
|
|
137
|
-
@classmethod
|
138
|
-
def function_parameters(self):
|
139
|
-
return fields(self.input)
|
140
|
-
|
141
129
|
@classmethod
|
142
130
|
def func(cls, **kwargs):
|
143
131
|
"This provides a shortcut for running a stage by passing keyword arguments to the input function."
|
@@ -185,59 +173,58 @@ class StageBase(ABC):
|
|
185
173
|
|
186
174
|
|
187
175
|
if __name__ == "__main__":
|
188
|
-
|
189
|
-
# try:
|
176
|
+
try:
|
190
177
|
|
191
|
-
|
192
|
-
|
193
|
-
|
178
|
+
class StageMissing(StageBase):
|
179
|
+
def handle_data(self, data):
|
180
|
+
return data
|
194
181
|
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
182
|
+
except NotImplementedError as e:
|
183
|
+
print(e)
|
184
|
+
else:
|
185
|
+
raise Exception("Should have raised NotImplementedError")
|
199
186
|
|
200
|
-
|
187
|
+
try:
|
201
188
|
|
202
|
-
|
203
|
-
|
189
|
+
class StageMissingInput(StageBase):
|
190
|
+
output = FlowDataBase
|
204
191
|
|
205
|
-
|
206
|
-
|
192
|
+
except NotImplementedError as e:
|
193
|
+
print(e)
|
207
194
|
|
208
|
-
|
209
|
-
|
195
|
+
else:
|
196
|
+
raise Exception("Should have raised NotImplementedError")
|
210
197
|
|
211
|
-
|
212
|
-
|
213
|
-
|
198
|
+
@dataclass
|
199
|
+
class MockInputOutput(FlowDataBase):
|
200
|
+
text: str
|
214
201
|
|
215
|
-
|
216
|
-
|
217
|
-
|
202
|
+
class StageTest(StageBase):
|
203
|
+
input = MockInputOutput
|
204
|
+
output = MockInputOutput
|
218
205
|
|
219
|
-
|
220
|
-
|
206
|
+
def handle_data(self, data):
|
207
|
+
return self.output(text=data["text"] + "processed")
|
221
208
|
|
222
|
-
|
223
|
-
|
209
|
+
result = StageTest().process(MockInputOutput(text="Hello world!"))
|
210
|
+
print(result.text)
|
224
211
|
|
225
|
-
|
226
|
-
|
227
|
-
|
212
|
+
pipeline = StageTest(next_stage=StageTest(next_stage=StageTest()))
|
213
|
+
result = pipeline.process(MockInputOutput(text="Hello world!"))
|
214
|
+
print(result.text)
|
228
215
|
|
229
|
-
|
230
|
-
|
231
|
-
|
216
|
+
class BadMockInput(FlowDataBase):
|
217
|
+
text: str
|
218
|
+
other: str
|
232
219
|
|
233
|
-
|
234
|
-
|
235
|
-
|
220
|
+
class StageBad(StageBase):
|
221
|
+
input = BadMockInput
|
222
|
+
output = BadMockInput
|
236
223
|
|
237
|
-
|
238
|
-
|
224
|
+
def handle_data(self, data):
|
225
|
+
return self.output(text=data["text"] + "processed")
|
239
226
|
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
227
|
+
try:
|
228
|
+
pipeline = StageTest(next_stage=StageBad(next_stage=StageTest()))
|
229
|
+
except ExceptionPipesDoNotFit as e:
|
230
|
+
print(e)
|
edsl/auto/StageQuestions.py
CHANGED
edsl/auto/utilities.py
CHANGED
@@ -88,6 +88,12 @@ def agent_eligibility(
|
|
88
88
|
q_eligibility(model=model, questions=questions, persona=persona, cache=cache)
|
89
89
|
== "Yes"
|
90
90
|
)
|
91
|
+
# results = (
|
92
|
+
# q.by(model)
|
93
|
+
# .by(Scenario({"questions": questions, "persona": persona}))
|
94
|
+
# .run(cache=cache)
|
95
|
+
# )
|
96
|
+
# return results.select("eligibility").first() == "Yes"
|
91
97
|
|
92
98
|
|
93
99
|
def gen_agent_traits(dimension_dict: dict, seed_value: Optional[str] = None):
|
edsl/coop/coop.py
CHANGED
@@ -179,7 +179,6 @@ class Coop(CoopFunctionsMixin):
|
|
179
179
|
Check the response from the server and raise errors as appropriate.
|
180
180
|
"""
|
181
181
|
# Get EDSL version from header
|
182
|
-
# breakpoint()
|
183
182
|
server_edsl_version = response.headers.get("X-EDSL-Version")
|
184
183
|
|
185
184
|
if server_edsl_version:
|
@@ -188,18 +187,11 @@ class Coop(CoopFunctionsMixin):
|
|
188
187
|
server_version_str=server_edsl_version,
|
189
188
|
):
|
190
189
|
print(
|
191
|
-
"Please upgrade your EDSL version to access our latest features. To upgrade, open your terminal and run `pip
|
190
|
+
"Please upgrade your EDSL version to access our latest features. To upgrade, open your terminal and run `pip upgrade edsl`"
|
192
191
|
)
|
193
192
|
|
194
193
|
if response.status_code >= 400:
|
195
|
-
|
196
|
-
message = response.json().get("detail")
|
197
|
-
except json.JSONDecodeError:
|
198
|
-
raise CoopServerResponseError(
|
199
|
-
f"Server returned status code {response.status_code}."
|
200
|
-
"JSON response could not be decoded.",
|
201
|
-
"The server response was: " + response.text,
|
202
|
-
)
|
194
|
+
message = response.json().get("detail")
|
203
195
|
# print(response.text)
|
204
196
|
if "The API key you provided is invalid" in message and check_api_key:
|
205
197
|
import secrets
|
@@ -659,6 +651,9 @@ class Coop(CoopFunctionsMixin):
|
|
659
651
|
self._resolve_server_response(response)
|
660
652
|
return response.json()
|
661
653
|
|
654
|
+
################
|
655
|
+
# Remote Inference
|
656
|
+
################
|
662
657
|
def remote_inference_create(
|
663
658
|
self,
|
664
659
|
job: Jobs,
|
@@ -769,17 +764,6 @@ class Coop(CoopFunctionsMixin):
|
|
769
764
|
}
|
770
765
|
)
|
771
766
|
|
772
|
-
def get_running_jobs(self) -> list[str]:
|
773
|
-
"""
|
774
|
-
Get a list of currently running job IDs.
|
775
|
-
|
776
|
-
Returns:
|
777
|
-
list[str]: List of running job UUIDs
|
778
|
-
"""
|
779
|
-
response = self._send_server_request(uri="jobs/status", method="GET")
|
780
|
-
self._resolve_server_response(response)
|
781
|
-
return response.json().get("running_jobs", [])
|
782
|
-
|
783
767
|
def remote_inference_cost(
|
784
768
|
self, input: Union[Jobs, Survey], iterations: int = 1
|
785
769
|
) -> int:
|