edsl 0.1.27.dev2__py3-none-any.whl → 0.1.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +107 -30
- edsl/BaseDiff.py +260 -0
- edsl/__init__.py +25 -21
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +103 -46
- edsl/agents/AgentList.py +97 -13
- edsl/agents/Invigilator.py +23 -10
- edsl/agents/InvigilatorBase.py +19 -14
- edsl/agents/PromptConstructionMixin.py +342 -100
- edsl/agents/descriptors.py +5 -2
- edsl/base/Base.py +289 -0
- edsl/config.py +2 -1
- edsl/conjure/AgentConstructionMixin.py +152 -0
- edsl/conjure/Conjure.py +56 -0
- edsl/conjure/InputData.py +659 -0
- edsl/conjure/InputDataCSV.py +48 -0
- edsl/conjure/InputDataMixinQuestionStats.py +182 -0
- edsl/conjure/InputDataPyRead.py +91 -0
- edsl/conjure/InputDataSPSS.py +8 -0
- edsl/conjure/InputDataStata.py +8 -0
- edsl/conjure/QuestionOptionMixin.py +76 -0
- edsl/conjure/QuestionTypeMixin.py +23 -0
- edsl/conjure/RawQuestion.py +65 -0
- edsl/conjure/SurveyResponses.py +7 -0
- edsl/conjure/__init__.py +9 -4
- edsl/conjure/examples/placeholder.txt +0 -0
- edsl/conjure/naming_utilities.py +263 -0
- edsl/conjure/utilities.py +165 -28
- edsl/conversation/Conversation.py +238 -0
- edsl/conversation/car_buying.py +58 -0
- edsl/conversation/mug_negotiation.py +81 -0
- edsl/conversation/next_speaker_utilities.py +93 -0
- edsl/coop/coop.py +337 -121
- edsl/coop/utils.py +56 -70
- edsl/data/Cache.py +74 -22
- edsl/data/CacheHandler.py +10 -9
- edsl/data/SQLiteDict.py +11 -3
- edsl/inference_services/AnthropicService.py +1 -0
- edsl/inference_services/DeepInfraService.py +20 -13
- edsl/inference_services/GoogleService.py +7 -1
- edsl/inference_services/InferenceServicesCollection.py +33 -7
- edsl/inference_services/OpenAIService.py +17 -10
- edsl/inference_services/models_available_cache.py +69 -0
- edsl/inference_services/rate_limits_cache.py +25 -0
- edsl/inference_services/write_available.py +10 -0
- edsl/jobs/Answers.py +15 -1
- edsl/jobs/Jobs.py +322 -73
- edsl/jobs/buckets/BucketCollection.py +9 -3
- edsl/jobs/buckets/ModelBuckets.py +4 -2
- edsl/jobs/buckets/TokenBucket.py +1 -2
- edsl/jobs/interviews/Interview.py +7 -10
- edsl/jobs/interviews/InterviewStatusMixin.py +3 -3
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +39 -20
- edsl/jobs/interviews/retry_management.py +4 -4
- edsl/jobs/runners/JobsRunnerAsyncio.py +103 -65
- edsl/jobs/runners/JobsRunnerStatusData.py +3 -3
- edsl/jobs/tasks/QuestionTaskCreator.py +4 -2
- edsl/jobs/tasks/TaskHistory.py +4 -3
- edsl/language_models/LanguageModel.py +42 -55
- edsl/language_models/ModelList.py +96 -0
- edsl/language_models/registry.py +14 -0
- edsl/language_models/repair.py +97 -25
- edsl/notebooks/Notebook.py +157 -32
- edsl/prompts/Prompt.py +31 -19
- edsl/questions/QuestionBase.py +145 -23
- edsl/questions/QuestionBudget.py +5 -6
- edsl/questions/QuestionCheckBox.py +7 -3
- edsl/questions/QuestionExtract.py +5 -3
- edsl/questions/QuestionFreeText.py +3 -3
- edsl/questions/QuestionFunctional.py +0 -3
- edsl/questions/QuestionList.py +3 -4
- edsl/questions/QuestionMultipleChoice.py +16 -8
- edsl/questions/QuestionNumerical.py +4 -3
- edsl/questions/QuestionRank.py +5 -3
- edsl/questions/__init__.py +4 -3
- edsl/questions/descriptors.py +9 -4
- edsl/questions/question_registry.py +27 -31
- edsl/questions/settings.py +1 -1
- edsl/results/Dataset.py +31 -0
- edsl/results/DatasetExportMixin.py +493 -0
- edsl/results/Result.py +42 -82
- edsl/results/Results.py +178 -66
- edsl/results/ResultsDBMixin.py +10 -9
- edsl/results/ResultsExportMixin.py +23 -507
- edsl/results/ResultsGGMixin.py +3 -3
- edsl/results/ResultsToolsMixin.py +9 -9
- edsl/scenarios/FileStore.py +140 -0
- edsl/scenarios/Scenario.py +59 -6
- edsl/scenarios/ScenarioList.py +138 -52
- edsl/scenarios/ScenarioListExportMixin.py +32 -0
- edsl/scenarios/ScenarioListPdfMixin.py +2 -1
- edsl/scenarios/__init__.py +1 -0
- edsl/study/ObjectEntry.py +173 -0
- edsl/study/ProofOfWork.py +113 -0
- edsl/study/SnapShot.py +73 -0
- edsl/study/Study.py +498 -0
- edsl/study/__init__.py +4 -0
- edsl/surveys/MemoryPlan.py +11 -4
- edsl/surveys/Survey.py +124 -37
- edsl/surveys/SurveyExportMixin.py +25 -5
- edsl/surveys/SurveyFlowVisualizationMixin.py +6 -4
- edsl/tools/plotting.py +4 -2
- edsl/utilities/__init__.py +21 -20
- edsl/utilities/gcp_bucket/__init__.py +0 -0
- edsl/utilities/gcp_bucket/cloud_storage.py +96 -0
- edsl/utilities/gcp_bucket/simple_example.py +9 -0
- edsl/utilities/interface.py +90 -73
- edsl/utilities/repair_functions.py +28 -0
- edsl/utilities/utilities.py +59 -6
- {edsl-0.1.27.dev2.dist-info → edsl-0.1.29.dist-info}/METADATA +42 -15
- edsl-0.1.29.dist-info/RECORD +203 -0
- edsl/conjure/RawResponseColumn.py +0 -327
- edsl/conjure/SurveyBuilder.py +0 -308
- edsl/conjure/SurveyBuilderCSV.py +0 -78
- edsl/conjure/SurveyBuilderSPSS.py +0 -118
- edsl/data/RemoteDict.py +0 -103
- edsl-0.1.27.dev2.dist-info/RECORD +0 -172
- {edsl-0.1.27.dev2.dist-info → edsl-0.1.29.dist-info}/LICENSE +0 -0
- {edsl-0.1.27.dev2.dist-info → edsl-0.1.29.dist-info}/WHEEL +0 -0
edsl/agents/Agent.py
CHANGED
@@ -5,27 +5,14 @@ import copy
|
|
5
5
|
import inspect
|
6
6
|
import types
|
7
7
|
from typing import Any, Callable, Optional, Union, Dict, Sequence
|
8
|
-
|
9
|
-
from rich.table import Table
|
10
|
-
|
11
8
|
from edsl.Base import Base
|
12
|
-
|
13
|
-
from edsl.language_models import LanguageModel
|
14
|
-
from edsl.surveys.MemoryPlan import MemoryPlan
|
9
|
+
|
15
10
|
from edsl.exceptions.agents import (
|
16
11
|
AgentCombinationError,
|
17
12
|
AgentDirectAnswerFunctionError,
|
18
13
|
AgentDynamicTraitsFunctionError,
|
19
14
|
)
|
20
|
-
|
21
|
-
InvigilatorDebug,
|
22
|
-
InvigilatorHuman,
|
23
|
-
InvigilatorFunctional,
|
24
|
-
InvigilatorAI,
|
25
|
-
InvigilatorBase,
|
26
|
-
)
|
27
|
-
from edsl.language_models.registry import Model
|
28
|
-
from edsl.scenarios import Scenario
|
15
|
+
|
29
16
|
from edsl.agents.descriptors import (
|
30
17
|
TraitsDescriptor,
|
31
18
|
CodebookDescriptor,
|
@@ -38,10 +25,6 @@ from edsl.utilities.decorators import (
|
|
38
25
|
remove_edsl_version,
|
39
26
|
)
|
40
27
|
from edsl.data_transfer_models import AgentResponseDict
|
41
|
-
from edsl.prompts.library.agent_persona import AgentPersona
|
42
|
-
from edsl.data.Cache import Cache
|
43
|
-
|
44
|
-
|
45
28
|
from edsl.utilities.restricted_python import create_restricted_function
|
46
29
|
|
47
30
|
|
@@ -56,6 +39,7 @@ class Agent(Base):
|
|
56
39
|
name = NameDescriptor()
|
57
40
|
dynamic_traits_function_name = ""
|
58
41
|
answer_question_directly_function_name = ""
|
42
|
+
has_dynamic_traits_function = False
|
59
43
|
|
60
44
|
def __init__(
|
61
45
|
self,
|
@@ -98,7 +82,7 @@ class Agent(Base):
|
|
98
82
|
|
99
83
|
>>> a = Agent(traits = {"age": 10}, traits_presentation_template = "I am a {{age}} year old.")
|
100
84
|
>>> repr(a.agent_persona)
|
101
|
-
|
85
|
+
'Prompt(text=\"""I am a {{age}} year old.\""")'
|
102
86
|
|
103
87
|
When this is rendered for presentation to the LLM, it will replace the `{{age}}` with the actual age.
|
104
88
|
it is also possible to use the `codebook` to provide a more human-readable description of the trait.
|
@@ -109,7 +93,7 @@ class Agent(Base):
|
|
109
93
|
>>> a = Agent(traits = traits, codebook = codebook, traits_presentation_template = "This agent is Dave. {{codebook['age']}} {{age}}")
|
110
94
|
>>> d = a.traits | {'codebook': a.codebook}
|
111
95
|
>>> a.agent_persona.render(d)
|
112
|
-
Prompt(text
|
96
|
+
Prompt(text=\"""This agent is Dave. Their age is 10\""")
|
113
97
|
|
114
98
|
Instructions
|
115
99
|
------------
|
@@ -129,12 +113,16 @@ class Agent(Base):
|
|
129
113
|
|
130
114
|
if self.dynamic_traits_function:
|
131
115
|
self.dynamic_traits_function_name = self.dynamic_traits_function.__name__
|
116
|
+
self.has_dynamic_traits_function = True
|
117
|
+
else:
|
118
|
+
self.has_dynamic_traits_function = False
|
132
119
|
|
133
120
|
if dynamic_traits_function_source_code:
|
134
121
|
self.dynamic_traits_function_name = dynamic_traits_function_name
|
135
122
|
self.dynamic_traits_function = create_restricted_function(
|
136
123
|
dynamic_traits_function_name, dynamic_traits_function
|
137
124
|
)
|
125
|
+
|
138
126
|
if answer_question_directly_source_code:
|
139
127
|
self.answer_question_directly_function_name = (
|
140
128
|
answer_question_directly_function_name
|
@@ -151,6 +139,8 @@ class Agent(Base):
|
|
151
139
|
self.current_question = None
|
152
140
|
|
153
141
|
if traits_presentation_template is not None:
|
142
|
+
from edsl.prompts.library.agent_persona import AgentPersona
|
143
|
+
|
154
144
|
self.traits_presentation_template = traits_presentation_template
|
155
145
|
self.agent_persona = AgentPersona(text=self.traits_presentation_template)
|
156
146
|
|
@@ -159,7 +149,7 @@ class Agent(Base):
|
|
159
149
|
|
160
150
|
This checks whether the dynamic traits function is valid.
|
161
151
|
"""
|
162
|
-
if self.
|
152
|
+
if self.has_dynamic_traits_function:
|
163
153
|
sig = inspect.signature(self.dynamic_traits_function)
|
164
154
|
if "question" in sig.parameters:
|
165
155
|
if len(sig.parameters) > 1:
|
@@ -189,7 +179,7 @@ class Agent(Base):
|
|
189
179
|
{'age': 10, 'hair': 'brown', 'height': 5.5}
|
190
180
|
|
191
181
|
"""
|
192
|
-
if self.
|
182
|
+
if self.has_dynamic_traits_function:
|
193
183
|
sig = inspect.signature(self.dynamic_traits_function)
|
194
184
|
if "question" in sig.parameters:
|
195
185
|
return self.dynamic_traits_function(question=self.current_question)
|
@@ -198,6 +188,18 @@ class Agent(Base):
|
|
198
188
|
else:
|
199
189
|
return self._traits
|
200
190
|
|
191
|
+
def rename(self, old_name: str, new_name: str) -> Agent:
|
192
|
+
"""Rename a trait.
|
193
|
+
|
194
|
+
Example usage:
|
195
|
+
|
196
|
+
>>> a = Agent(traits = {"age": 10, "hair": "brown", "height": 5.5})
|
197
|
+
>>> a.rename("age", "years") == Agent(traits = {'years': 10, 'hair': 'brown', 'height': 5.5})
|
198
|
+
True
|
199
|
+
"""
|
200
|
+
self.traits[new_name] = self.traits.pop(old_name)
|
201
|
+
return self
|
202
|
+
|
201
203
|
def __getitem__(self, key):
|
202
204
|
"""Allow for accessing traits using the bracket notation.
|
203
205
|
|
@@ -259,8 +261,9 @@ class Agent(Base):
|
|
259
261
|
def create_invigilator(
|
260
262
|
self,
|
261
263
|
*,
|
262
|
-
question: QuestionBase,
|
263
|
-
cache,
|
264
|
+
question: "QuestionBase",
|
265
|
+
cache: "Cache",
|
266
|
+
survey: Optional["Survey"] = None,
|
264
267
|
scenario: Optional[Scenario] = None,
|
265
268
|
model: Optional[LanguageModel] = None,
|
266
269
|
debug: bool = False,
|
@@ -268,7 +271,7 @@ class Agent(Base):
|
|
268
271
|
current_answers: Optional[dict] = None,
|
269
272
|
iteration: int = 1,
|
270
273
|
sidecar_model=None,
|
271
|
-
) -> InvigilatorBase:
|
274
|
+
) -> "InvigilatorBase":
|
272
275
|
"""Create an Invigilator.
|
273
276
|
|
274
277
|
An invigilator is an object that is responsible for administering a question to an agent.
|
@@ -282,6 +285,8 @@ class Agent(Base):
|
|
282
285
|
An invigator is an object that is responsible for administering a question to an agent and
|
283
286
|
recording the responses.
|
284
287
|
"""
|
288
|
+
from edsl import Model, Scenario
|
289
|
+
|
285
290
|
cache = cache
|
286
291
|
self.current_question = question
|
287
292
|
model = model or Model()
|
@@ -289,6 +294,7 @@ class Agent(Base):
|
|
289
294
|
invigilator = self._create_invigilator(
|
290
295
|
question=question,
|
291
296
|
scenario=scenario,
|
297
|
+
survey=survey,
|
292
298
|
model=model,
|
293
299
|
debug=debug,
|
294
300
|
memory_plan=memory_plan,
|
@@ -302,12 +308,13 @@ class Agent(Base):
|
|
302
308
|
async def async_answer_question(
|
303
309
|
self,
|
304
310
|
*,
|
305
|
-
question: QuestionBase,
|
306
|
-
cache: Cache,
|
307
|
-
scenario: Optional[Scenario] = None,
|
308
|
-
|
311
|
+
question: "QuestionBase",
|
312
|
+
cache: "Cache",
|
313
|
+
scenario: Optional["Scenario"] = None,
|
314
|
+
survey: Optional["Survey"] = None,
|
315
|
+
model: Optional["LanguageModel"] = None,
|
309
316
|
debug: bool = False,
|
310
|
-
memory_plan: Optional[MemoryPlan] = None,
|
317
|
+
memory_plan: Optional["MemoryPlan"] = None,
|
311
318
|
current_answers: Optional[dict] = None,
|
312
319
|
iteration: int = 0,
|
313
320
|
) -> AgentResponseDict:
|
@@ -327,7 +334,7 @@ class Agent(Base):
|
|
327
334
|
>>> from edsl import QuestionFreeText
|
328
335
|
>>> q = QuestionFreeText.example()
|
329
336
|
>>> a.answer_question(question = q, cache = False)
|
330
|
-
{'answer': 'I am a direct answer.', 'comment': 'This is a real survey response from a human.',
|
337
|
+
{'answer': 'I am a direct answer.', 'comment': 'This is a real survey response from a human.', ...}
|
331
338
|
|
332
339
|
This is a function where an agent returns an answer to a particular question.
|
333
340
|
However, there are several different ways an agent can answer a question, so the
|
@@ -337,6 +344,7 @@ class Agent(Base):
|
|
337
344
|
question=question,
|
338
345
|
cache=cache,
|
339
346
|
scenario=scenario,
|
347
|
+
survey=survey,
|
340
348
|
model=model,
|
341
349
|
debug=debug,
|
342
350
|
memory_plan=memory_plan,
|
@@ -350,21 +358,35 @@ class Agent(Base):
|
|
350
358
|
|
351
359
|
def _create_invigilator(
|
352
360
|
self,
|
353
|
-
question: QuestionBase,
|
354
|
-
cache: Optional[Cache] = None,
|
355
|
-
scenario: Optional[Scenario] = None,
|
356
|
-
model: Optional[LanguageModel] = None,
|
361
|
+
question: "QuestionBase",
|
362
|
+
cache: Optional["Cache"] = None,
|
363
|
+
scenario: Optional["Scenario"] = None,
|
364
|
+
model: Optional["LanguageModel"] = None,
|
365
|
+
survey: Optional["Survey"] = None,
|
357
366
|
debug: bool = False,
|
358
|
-
memory_plan: Optional[MemoryPlan] = None,
|
367
|
+
memory_plan: Optional["MemoryPlan"] = None,
|
359
368
|
current_answers: Optional[dict] = None,
|
360
369
|
iteration: int = 0,
|
361
370
|
sidecar_model=None,
|
362
|
-
) -> InvigilatorBase:
|
371
|
+
) -> "InvigilatorBase":
|
363
372
|
"""Create an Invigilator."""
|
373
|
+
from edsl import Model
|
374
|
+
from edsl import Scenario
|
375
|
+
|
364
376
|
model = model or Model()
|
365
377
|
scenario = scenario or Scenario()
|
366
378
|
|
379
|
+
from edsl.agents.Invigilator import (
|
380
|
+
InvigilatorDebug,
|
381
|
+
InvigilatorHuman,
|
382
|
+
InvigilatorFunctional,
|
383
|
+
InvigilatorAI,
|
384
|
+
InvigilatorBase,
|
385
|
+
)
|
386
|
+
|
367
387
|
if cache is None:
|
388
|
+
from edsl.data.Cache import Cache
|
389
|
+
|
368
390
|
cache = Cache()
|
369
391
|
|
370
392
|
if debug:
|
@@ -392,6 +414,7 @@ class Agent(Base):
|
|
392
414
|
self,
|
393
415
|
question=question,
|
394
416
|
scenario=scenario,
|
417
|
+
survey=survey,
|
395
418
|
model=model,
|
396
419
|
memory_plan=memory_plan,
|
397
420
|
current_answers=current_answers,
|
@@ -467,6 +490,29 @@ class Agent(Base):
|
|
467
490
|
"""
|
468
491
|
return self.data == other.data
|
469
492
|
|
493
|
+
def __getattr__(self, name):
|
494
|
+
# This will be called only if 'name' is not found in the usual places
|
495
|
+
# breakpoint()
|
496
|
+
if name == "has_dynamic_traits_function":
|
497
|
+
return self.has_dynamic_traits_function
|
498
|
+
|
499
|
+
if name in self.traits:
|
500
|
+
return self.traits[name]
|
501
|
+
raise AttributeError(
|
502
|
+
f"'{type(self).__name__}' object has no attribute '{name}'"
|
503
|
+
)
|
504
|
+
|
505
|
+
def __getstate__(self):
|
506
|
+
state = self.__dict__.copy()
|
507
|
+
# Include any additional state that needs to be serialized
|
508
|
+
return state
|
509
|
+
|
510
|
+
def __setstate__(self, state):
|
511
|
+
self.__dict__.update(state)
|
512
|
+
# Ensure _traits is initialized if it's missing
|
513
|
+
if "_traits" not in self.__dict__:
|
514
|
+
self._traits = {}
|
515
|
+
|
470
516
|
def print(self) -> None:
|
471
517
|
from rich import print_json
|
472
518
|
import json
|
@@ -523,9 +569,9 @@ class Agent(Base):
|
|
523
569
|
if dynamic_traits_func:
|
524
570
|
func = inspect.getsource(dynamic_traits_func)
|
525
571
|
raw_data["dynamic_traits_function_source_code"] = func
|
526
|
-
raw_data[
|
527
|
-
|
528
|
-
|
572
|
+
raw_data["dynamic_traits_function_name"] = (
|
573
|
+
self.dynamic_traits_function_name
|
574
|
+
)
|
529
575
|
if hasattr(self, "answer_question_directly"):
|
530
576
|
raw_data.pop(
|
531
577
|
"answer_question_directly", None
|
@@ -541,12 +587,21 @@ class Agent(Base):
|
|
541
587
|
raw_data["answer_question_directly_source_code"] = inspect.getsource(
|
542
588
|
answer_question_directly_func
|
543
589
|
)
|
544
|
-
raw_data[
|
545
|
-
|
546
|
-
|
590
|
+
raw_data["answer_question_directly_function_name"] = (
|
591
|
+
self.answer_question_directly_function_name
|
592
|
+
)
|
547
593
|
|
548
594
|
return raw_data
|
549
595
|
|
596
|
+
def __hash__(self) -> int:
|
597
|
+
from edsl.utilities.utilities import dict_hash
|
598
|
+
|
599
|
+
return dict_hash(self._to_dict())
|
600
|
+
|
601
|
+
def _to_dict(self) -> dict[str, Union[dict, bool]]:
|
602
|
+
"""Serialize to a dictionary."""
|
603
|
+
return self.data
|
604
|
+
|
550
605
|
@add_edsl_version
|
551
606
|
def to_dict(self) -> dict[str, Union[dict, bool]]:
|
552
607
|
"""Serialize to a dictionary.
|
@@ -557,7 +612,7 @@ class Agent(Base):
|
|
557
612
|
>>> a.to_dict()
|
558
613
|
{'name': 'Steve', 'traits': {'age': 10, 'hair': 'brown', 'height': 5.5}, 'edsl_version': '...', 'edsl_class_name': 'Agent'}
|
559
614
|
"""
|
560
|
-
return self.
|
615
|
+
return self._to_dict()
|
561
616
|
|
562
617
|
@classmethod
|
563
618
|
@remove_edsl_version
|
@@ -567,7 +622,7 @@ class Agent(Base):
|
|
567
622
|
Example usage:
|
568
623
|
|
569
624
|
>>> Agent.from_dict({'name': "Steve", 'traits': {'age': 10, 'hair': 'brown', 'height': 5.5}})
|
570
|
-
Agent(name =
|
625
|
+
Agent(name = \"""Steve\""", traits = {'age': 10, 'hair': 'brown', 'height': 5.5})
|
571
626
|
|
572
627
|
"""
|
573
628
|
return cls(**agent_dict)
|
@@ -619,6 +674,8 @@ class Agent(Base):
|
|
619
674
|
>>> a.rich_print()
|
620
675
|
<rich.table.Table object at ...>
|
621
676
|
"""
|
677
|
+
from rich.table import Table
|
678
|
+
|
622
679
|
table_data, column_names = self._table()
|
623
680
|
table = Table(title=f"{self.__class__.__name__} Attributes")
|
624
681
|
for column in column_names:
|
edsl/agents/AgentList.py
CHANGED
@@ -12,16 +12,18 @@ Example usage:
|
|
12
12
|
|
13
13
|
from __future__ import annotations
|
14
14
|
from collections import UserList
|
15
|
-
from typing import Optional, Union, Sequence
|
15
|
+
from typing import Optional, Union, Sequence, List, Any
|
16
16
|
from rich import print_json
|
17
17
|
from rich.table import Table
|
18
18
|
import json
|
19
19
|
import csv
|
20
20
|
|
21
|
+
|
21
22
|
from simpleeval import EvalWithCompoundTypes
|
22
23
|
|
23
24
|
from edsl.Base import Base
|
24
|
-
|
25
|
+
|
26
|
+
# from edsl.agents import Agent
|
25
27
|
from edsl.utilities.decorators import (
|
26
28
|
add_edsl_version,
|
27
29
|
remove_edsl_version,
|
@@ -31,7 +33,7 @@ from edsl.utilities.decorators import (
|
|
31
33
|
class AgentList(UserList, Base):
|
32
34
|
"""A list of Agents."""
|
33
35
|
|
34
|
-
def __init__(self, data: Optional[list[Agent]] = None):
|
36
|
+
def __init__(self, data: Optional[list["Agent"]] = None):
|
35
37
|
"""Initialize a new AgentList.
|
36
38
|
|
37
39
|
:param data: A list of Agents.
|
@@ -41,9 +43,42 @@ class AgentList(UserList, Base):
|
|
41
43
|
else:
|
42
44
|
super().__init__()
|
43
45
|
|
46
|
+
def shuffle(self, seed: Optional[str] = "edsl") -> AgentList:
|
47
|
+
"""Shuffle the AgentList.
|
48
|
+
|
49
|
+
:param seed: The seed for the random number generator.
|
50
|
+
"""
|
51
|
+
import random
|
52
|
+
|
53
|
+
random.seed(seed)
|
54
|
+
random.shuffle(self.data)
|
55
|
+
return self
|
56
|
+
|
57
|
+
def sample(self, n: int, seed="edsl") -> AgentList:
|
58
|
+
"""Return a random sample of agents.
|
59
|
+
|
60
|
+
:param n: The number of agents to sample.
|
61
|
+
:param seed: The seed for the random number generator.
|
62
|
+
"""
|
63
|
+
import random
|
64
|
+
|
65
|
+
random.seed(seed)
|
66
|
+
return AgentList(random.sample(self.data, n))
|
67
|
+
|
68
|
+
def rename(self, old_name, new_name):
|
69
|
+
"""Rename a trait in the AgentList.
|
70
|
+
|
71
|
+
:param old_name: The old name of the trait.
|
72
|
+
:param new_name: The new name of the trait.
|
73
|
+
"""
|
74
|
+
for agent in self.data:
|
75
|
+
agent.rename(old_name, new_name)
|
76
|
+
return self
|
77
|
+
|
44
78
|
def select(self, *traits) -> AgentList:
|
45
79
|
"""Selects agents with only the references traits.
|
46
80
|
|
81
|
+
>>> from edsl.agents.Agent import Agent
|
47
82
|
>>> al = AgentList([Agent(traits = {'a': 1, 'b': 1}), Agent(traits = {'a': 1, 'b': 2})])
|
48
83
|
>>> al.select('a')
|
49
84
|
AgentList([Agent(traits = {'a': 1}), Agent(traits = {'a': 1})])
|
@@ -61,12 +96,13 @@ class AgentList(UserList, Base):
|
|
61
96
|
"""
|
62
97
|
Filter a list of agents based on an expression.
|
63
98
|
|
99
|
+
>>> from edsl.agents.Agent import Agent
|
64
100
|
>>> al = AgentList([Agent(traits = {'a': 1, 'b': 1}), Agent(traits = {'a': 1, 'b': 2})])
|
65
101
|
>>> al.filter("b == 2")
|
66
102
|
AgentList([Agent(traits = {'a': 1, 'b': 2})])
|
67
103
|
"""
|
68
104
|
|
69
|
-
def create_evaluator(agent: Agent):
|
105
|
+
def create_evaluator(agent: "Agent"):
|
70
106
|
"""Create an evaluator for the given result.
|
71
107
|
The 'combined_dict' is a mapping of all values for that Result object.
|
72
108
|
"""
|
@@ -100,6 +136,8 @@ class AgentList(UserList, Base):
|
|
100
136
|
|
101
137
|
:param file_path: The path to the CSV file.
|
102
138
|
"""
|
139
|
+
from edsl.agents.Agent import Agent
|
140
|
+
|
103
141
|
agent_list = []
|
104
142
|
with open(file_path, "r") as f:
|
105
143
|
reader = csv.DictReader(f)
|
@@ -120,7 +158,7 @@ class AgentList(UserList, Base):
|
|
120
158
|
"""Remove traits from the AgentList.
|
121
159
|
|
122
160
|
:param traits: The traits to remove.
|
123
|
-
|
161
|
+
>>> from edsl.agents.Agent import Agent
|
124
162
|
>>> al = AgentList([Agent({'age': 22, 'hair': 'brown', 'height': 5.5}), Agent({'age': 22, 'hair': 'brown', 'height': 5.5})])
|
125
163
|
>>> al.remove_trait('age')
|
126
164
|
AgentList([Agent(traits = {'hair': 'brown', 'height': 5.5}), Agent(traits = {'hair': 'brown', 'height': 5.5})])
|
@@ -139,21 +177,36 @@ class AgentList(UserList, Base):
|
|
139
177
|
reader = csv.DictReader(f)
|
140
178
|
return {field: None for field in reader.fieldnames}
|
141
179
|
|
180
|
+
def __hash__(self) -> int:
|
181
|
+
from edsl.utilities.utilities import dict_hash
|
182
|
+
|
183
|
+
data = self.to_dict()
|
184
|
+
# data['agent_list'] = sorted(data['agent_list'], key=lambda x: dict_hash(x)
|
185
|
+
return dict_hash(self._to_dict(sorted=True))
|
186
|
+
|
187
|
+
def _to_dict(self, sorted=False):
|
188
|
+
if sorted:
|
189
|
+
data = self.data[:]
|
190
|
+
data.sort(key=lambda x: hash(x))
|
191
|
+
else:
|
192
|
+
data = self.data
|
193
|
+
|
194
|
+
return {"agent_list": [agent.to_dict() for agent in data]}
|
195
|
+
|
196
|
+
def __eq__(self, other: AgentList) -> bool:
|
197
|
+
return self._to_dict(sorted=True) == other._to_dict(sorted=True)
|
198
|
+
|
142
199
|
@add_edsl_version
|
143
200
|
def to_dict(self):
|
144
|
-
"""Return dictionary of AgentList to serialization.
|
145
|
-
|
146
|
-
>>> AgentList.example().to_dict()
|
147
|
-
{'agent_list': [{'traits': {'age': 22, 'hair': 'brown', 'height': 5.5}, 'edsl_version': '...', 'edsl_class_name': 'Agent'}, {'traits': {'age': 22, 'hair': 'brown', 'height': 5.5}, 'edsl_version': '...', 'edsl_class_name': 'Agent'}], 'edsl_version': '...', 'edsl_class_name': 'AgentList'}
|
148
|
-
"""
|
149
|
-
return {"agent_list": [agent.to_dict() for agent in self.data]}
|
201
|
+
"""Return dictionary of AgentList to serialization."""
|
202
|
+
return self._to_dict()
|
150
203
|
|
151
204
|
def __repr__(self):
|
152
205
|
return f"AgentList({self.data})"
|
153
206
|
|
154
207
|
def print(self, format: Optional[str] = None):
|
155
208
|
"""Print the AgentList."""
|
156
|
-
print_json(json.dumps(self.
|
209
|
+
print_json(json.dumps(self._to_dict()))
|
157
210
|
|
158
211
|
def _repr_html_(self):
|
159
212
|
"""Return an HTML representation of the AgentList."""
|
@@ -161,18 +214,27 @@ class AgentList(UserList, Base):
|
|
161
214
|
|
162
215
|
return data_to_html(self.to_dict()["agent_list"])
|
163
216
|
|
217
|
+
def to_scenario_list(self) -> "ScenarioList":
|
218
|
+
"""Return a list of scenarios."""
|
219
|
+
from edsl.scenarios.ScenarioList import ScenarioList
|
220
|
+
from edsl.scenarios.Scenario import Scenario
|
221
|
+
|
222
|
+
return ScenarioList([Scenario(agent.traits) for agent in self.data])
|
223
|
+
|
164
224
|
@classmethod
|
165
225
|
@remove_edsl_version
|
166
226
|
def from_dict(cls, data: dict) -> "AgentList":
|
167
227
|
"""Deserialize the dictionary back to an AgentList object.
|
168
228
|
|
169
229
|
:param: data: A dictionary representing an AgentList.
|
170
|
-
|
230
|
+
>>> from edsl.agents.Agent import Agent
|
171
231
|
>>> al = AgentList([Agent.example(), Agent.example()])
|
172
232
|
>>> al2 = AgentList.from_dict(al.to_dict())
|
173
233
|
>>> al2 == al
|
174
234
|
True
|
175
235
|
"""
|
236
|
+
from edsl.agents.Agent import Agent
|
237
|
+
|
176
238
|
agents = [Agent.from_dict(agent_dict) for agent_dict in data["agent_list"]]
|
177
239
|
return cls(agents)
|
178
240
|
|
@@ -185,8 +247,30 @@ class AgentList(UserList, Base):
|
|
185
247
|
2
|
186
248
|
|
187
249
|
"""
|
250
|
+
from edsl.agents.Agent import Agent
|
251
|
+
|
188
252
|
return cls([Agent.example(), Agent.example()])
|
189
253
|
|
254
|
+
@classmethod
|
255
|
+
def from_list(self, trait_name: str, values: List[Any]):
|
256
|
+
"""Create an AgentList from a list of values.
|
257
|
+
|
258
|
+
:param trait_name: The name of the trait.
|
259
|
+
:param values: A list of values.
|
260
|
+
"""
|
261
|
+
from edsl.agents.Agent import Agent
|
262
|
+
|
263
|
+
return AgentList([Agent({trait_name: value}) for value in values])
|
264
|
+
|
265
|
+
def __mul__(self, other: AgentList) -> AgentList:
|
266
|
+
"""Takes the cross product of two AgentLists."""
|
267
|
+
from itertools import product
|
268
|
+
|
269
|
+
new_sl = []
|
270
|
+
for s1, s2 in list(product(self, other)):
|
271
|
+
new_sl.append(s1 + s2)
|
272
|
+
return AgentList(new_sl)
|
273
|
+
|
190
274
|
def code(self, string=True) -> Union[str, list[str]]:
|
191
275
|
"""Return code to construct an AgentList.
|
192
276
|
|
edsl/agents/Invigilator.py
CHANGED
@@ -63,14 +63,9 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
|
|
63
63
|
|
64
64
|
def _remove_from_cache(self, raw_response) -> None:
|
65
65
|
"""Remove an entry from the cache."""
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
):
|
70
|
-
cache_key = raw_response["raw_model_response"]["cache_key"]
|
71
|
-
else:
|
72
|
-
cache_key = None
|
73
|
-
del self.cache.data[cache_key]
|
66
|
+
cache_key = raw_response.get("cache_key", None)
|
67
|
+
if cache_key:
|
68
|
+
del self.cache.data[cache_key]
|
74
69
|
|
75
70
|
def _format_raw_response(
|
76
71
|
self, *, agent, question, scenario, raw_response, raw_model_response
|
@@ -87,7 +82,25 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
|
|
87
82
|
self._remove_from_cache(raw_response)
|
88
83
|
raise e
|
89
84
|
|
90
|
-
|
85
|
+
# breakpoint()
|
86
|
+
question_dict = self.survey.question_names_to_questions()
|
87
|
+
for other_question, answer in self.current_answers.items():
|
88
|
+
if other_question in question_dict:
|
89
|
+
question_dict[other_question].answer = answer
|
90
|
+
else:
|
91
|
+
# adds a comment to the question
|
92
|
+
if (
|
93
|
+
new_question := other_question.split("_comment")[0]
|
94
|
+
) in question_dict:
|
95
|
+
question_dict[new_question].comment = answer
|
96
|
+
|
97
|
+
combined_dict = {**question_dict, **scenario}
|
98
|
+
# print("combined_dict: ", combined_dict)
|
99
|
+
# print("response: ", response)
|
100
|
+
# breakpoint()
|
101
|
+
answer = question._translate_answer_code_to_answer(
|
102
|
+
response["answer"], combined_dict
|
103
|
+
)
|
91
104
|
data = {
|
92
105
|
"answer": answer,
|
93
106
|
"comment": response.get(
|
@@ -95,7 +108,7 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
|
|
95
108
|
), # not all question have comment fields,
|
96
109
|
"question_name": question.question_name,
|
97
110
|
"prompts": self.get_prompts(),
|
98
|
-
"cached_response": raw_response
|
111
|
+
"cached_response": raw_response.get("cached_response", None),
|
99
112
|
"usage": raw_response.get("usage", {}),
|
100
113
|
"raw_model_response": raw_model_response,
|
101
114
|
}
|