edsl 0.1.29.dev6__py3-none-any.whl → 0.1.30.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +6 -3
- edsl/__init__.py +23 -23
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +35 -34
- edsl/agents/AgentList.py +16 -5
- edsl/agents/Invigilator.py +19 -1
- edsl/agents/descriptors.py +2 -1
- edsl/base/Base.py +289 -0
- edsl/config.py +2 -1
- edsl/coop/utils.py +28 -1
- edsl/data/Cache.py +19 -5
- edsl/data/SQLiteDict.py +11 -3
- edsl/jobs/Answers.py +15 -1
- edsl/jobs/Jobs.py +69 -31
- edsl/jobs/buckets/ModelBuckets.py +4 -2
- edsl/jobs/buckets/TokenBucket.py +1 -2
- edsl/jobs/interviews/Interview.py +0 -6
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +9 -5
- edsl/jobs/runners/JobsRunnerAsyncio.py +12 -16
- edsl/jobs/tasks/TaskHistory.py +4 -3
- edsl/language_models/LanguageModel.py +5 -11
- edsl/language_models/ModelList.py +1 -1
- edsl/language_models/repair.py +8 -7
- edsl/notebooks/Notebook.py +9 -3
- edsl/questions/QuestionBase.py +6 -2
- edsl/questions/QuestionBudget.py +5 -6
- edsl/questions/QuestionCheckBox.py +7 -3
- edsl/questions/QuestionExtract.py +5 -3
- edsl/questions/QuestionFreeText.py +3 -3
- edsl/questions/QuestionFunctional.py +0 -3
- edsl/questions/QuestionList.py +3 -4
- edsl/questions/QuestionMultipleChoice.py +12 -5
- edsl/questions/QuestionNumerical.py +4 -3
- edsl/questions/QuestionRank.py +5 -3
- edsl/questions/__init__.py +4 -3
- edsl/questions/descriptors.py +4 -2
- edsl/results/DatasetExportMixin.py +491 -0
- edsl/results/Result.py +13 -65
- edsl/results/Results.py +91 -39
- edsl/results/ResultsDBMixin.py +7 -3
- edsl/results/ResultsExportMixin.py +22 -537
- edsl/results/ResultsGGMixin.py +3 -3
- edsl/results/ResultsToolsMixin.py +1 -4
- edsl/scenarios/FileStore.py +140 -0
- edsl/scenarios/Scenario.py +5 -6
- edsl/scenarios/ScenarioList.py +17 -8
- edsl/scenarios/ScenarioListExportMixin.py +32 -0
- edsl/scenarios/ScenarioListPdfMixin.py +2 -1
- edsl/scenarios/__init__.py +1 -0
- edsl/surveys/MemoryPlan.py +11 -4
- edsl/surveys/Survey.py +9 -4
- edsl/surveys/SurveyExportMixin.py +4 -2
- edsl/surveys/SurveyFlowVisualizationMixin.py +6 -4
- edsl/utilities/__init__.py +21 -21
- edsl/utilities/interface.py +66 -45
- edsl/utilities/utilities.py +11 -13
- {edsl-0.1.29.dev6.dist-info → edsl-0.1.30.dev1.dist-info}/METADATA +1 -1
- {edsl-0.1.29.dev6.dist-info → edsl-0.1.30.dev1.dist-info}/RECORD +60 -56
- {edsl-0.1.29.dev6.dist-info → edsl-0.1.30.dev1.dist-info}/LICENSE +0 -0
- {edsl-0.1.29.dev6.dist-info → edsl-0.1.30.dev1.dist-info}/WHEEL +0 -0
edsl/Base.py
CHANGED
@@ -6,9 +6,6 @@ import io
|
|
6
6
|
import json
|
7
7
|
from typing import Any, Optional, Union
|
8
8
|
from uuid import UUID
|
9
|
-
from IPython.display import display
|
10
|
-
from rich.console import Console
|
11
|
-
from edsl.utilities import is_notebook
|
12
9
|
|
13
10
|
|
14
11
|
class RichPrintingMixin:
|
@@ -16,6 +13,8 @@ class RichPrintingMixin:
|
|
16
13
|
|
17
14
|
def _for_console(self):
|
18
15
|
"""Return a string representation of the object for console printing."""
|
16
|
+
from rich.console import Console
|
17
|
+
|
19
18
|
with io.StringIO() as buf:
|
20
19
|
console = Console(file=buf, record=True)
|
21
20
|
table = self.rich_print()
|
@@ -28,7 +27,11 @@ class RichPrintingMixin:
|
|
28
27
|
|
29
28
|
def print(self):
|
30
29
|
"""Print the object to the console."""
|
30
|
+
from edsl.utilities.utilities import is_notebook
|
31
|
+
|
31
32
|
if is_notebook():
|
33
|
+
from IPython.display import display
|
34
|
+
|
32
35
|
display(self.rich_print())
|
33
36
|
else:
|
34
37
|
from rich.console import Console
|
edsl/__init__.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1
1
|
import os
|
2
|
+
import time
|
2
3
|
|
3
4
|
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
4
5
|
ROOT_DIR = os.path.dirname(BASE_DIR)
|
@@ -7,36 +8,35 @@ from edsl.__version__ import __version__
|
|
7
8
|
from edsl.config import Config, CONFIG
|
8
9
|
from edsl.agents.Agent import Agent
|
9
10
|
from edsl.agents.AgentList import AgentList
|
10
|
-
from edsl.questions import
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
15
|
-
|
16
|
-
|
17
|
-
|
18
|
-
|
19
|
-
|
20
|
-
|
21
|
-
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
from edsl.scenarios
|
27
|
-
|
28
|
-
from edsl.utilities.interface import print_dict_with_rich
|
11
|
+
from edsl.questions import QuestionBase
|
12
|
+
from edsl.questions import QuestionMultipleChoice
|
13
|
+
from edsl.questions import QuestionBudget
|
14
|
+
from edsl.questions import QuestionCheckBox
|
15
|
+
from edsl.questions import QuestionExtract
|
16
|
+
from edsl.questions import QuestionFreeText
|
17
|
+
from edsl.questions import QuestionFunctional
|
18
|
+
from edsl.questions import QuestionLikertFive
|
19
|
+
from edsl.questions import QuestionList
|
20
|
+
from edsl.questions import QuestionLinearScale
|
21
|
+
from edsl.questions import QuestionNumerical
|
22
|
+
from edsl.questions import QuestionRank
|
23
|
+
from edsl.questions import QuestionTopK
|
24
|
+
from edsl.questions import QuestionYesNo
|
25
|
+
from edsl.questions.question_registry import Question
|
26
|
+
from edsl.scenarios import Scenario
|
27
|
+
from edsl.scenarios import ScenarioList
|
28
|
+
|
29
|
+
# from edsl.utilities.interface import print_dict_with_rich
|
29
30
|
from edsl.surveys.Survey import Survey
|
30
31
|
from edsl.language_models.registry import Model
|
31
|
-
from edsl.
|
32
|
+
from edsl.language_models.ModelList import ModelList
|
32
33
|
from edsl.results.Results import Results
|
33
34
|
from edsl.data.Cache import Cache
|
34
35
|
from edsl.data.CacheEntry import CacheEntry
|
35
36
|
from edsl.data.CacheHandler import set_session_cache, unset_session_cache
|
36
37
|
from edsl.shared import shared_globals
|
37
|
-
from edsl.jobs import Jobs
|
38
|
-
from edsl.notebooks import Notebook
|
38
|
+
from edsl.jobs.Jobs import Jobs
|
39
|
+
from edsl.notebooks.Notebook import Notebook
|
39
40
|
from edsl.study.Study import Study
|
40
41
|
from edsl.conjure.Conjure import Conjure
|
41
|
-
from edsl.language_models.ModelList import ModelList
|
42
42
|
from edsl.coop.coop import Coop
|
edsl/__version__.py
CHANGED
@@ -1 +1 @@
|
|
1
|
-
__version__ = "0.1.
|
1
|
+
__version__ = "0.1.30.dev1"
|
edsl/agents/Agent.py
CHANGED
@@ -5,27 +5,14 @@ import copy
|
|
5
5
|
import inspect
|
6
6
|
import types
|
7
7
|
from typing import Any, Callable, Optional, Union, Dict, Sequence
|
8
|
-
|
9
|
-
from rich.table import Table
|
10
|
-
|
11
8
|
from edsl.Base import Base
|
12
|
-
|
13
|
-
from edsl.language_models import LanguageModel
|
14
|
-
from edsl.surveys.MemoryPlan import MemoryPlan
|
9
|
+
|
15
10
|
from edsl.exceptions.agents import (
|
16
11
|
AgentCombinationError,
|
17
12
|
AgentDirectAnswerFunctionError,
|
18
13
|
AgentDynamicTraitsFunctionError,
|
19
14
|
)
|
20
|
-
|
21
|
-
InvigilatorDebug,
|
22
|
-
InvigilatorHuman,
|
23
|
-
InvigilatorFunctional,
|
24
|
-
InvigilatorAI,
|
25
|
-
InvigilatorBase,
|
26
|
-
)
|
27
|
-
from edsl.language_models.registry import Model
|
28
|
-
from edsl.scenarios import Scenario
|
15
|
+
|
29
16
|
from edsl.agents.descriptors import (
|
30
17
|
TraitsDescriptor,
|
31
18
|
CodebookDescriptor,
|
@@ -38,10 +25,6 @@ from edsl.utilities.decorators import (
|
|
38
25
|
remove_edsl_version,
|
39
26
|
)
|
40
27
|
from edsl.data_transfer_models import AgentResponseDict
|
41
|
-
from edsl.prompts.library.agent_persona import AgentPersona
|
42
|
-
from edsl.data.Cache import Cache
|
43
|
-
|
44
|
-
|
45
28
|
from edsl.utilities.restricted_python import create_restricted_function
|
46
29
|
|
47
30
|
|
@@ -156,6 +139,8 @@ class Agent(Base):
|
|
156
139
|
self.current_question = None
|
157
140
|
|
158
141
|
if traits_presentation_template is not None:
|
142
|
+
from edsl.prompts.library.agent_persona import AgentPersona
|
143
|
+
|
159
144
|
self.traits_presentation_template = traits_presentation_template
|
160
145
|
self.agent_persona = AgentPersona(text=self.traits_presentation_template)
|
161
146
|
|
@@ -276,8 +261,8 @@ class Agent(Base):
|
|
276
261
|
def create_invigilator(
|
277
262
|
self,
|
278
263
|
*,
|
279
|
-
question: QuestionBase,
|
280
|
-
cache,
|
264
|
+
question: "QuestionBase",
|
265
|
+
cache: "Cache",
|
281
266
|
survey: Optional["Survey"] = None,
|
282
267
|
scenario: Optional[Scenario] = None,
|
283
268
|
model: Optional[LanguageModel] = None,
|
@@ -286,7 +271,7 @@ class Agent(Base):
|
|
286
271
|
current_answers: Optional[dict] = None,
|
287
272
|
iteration: int = 1,
|
288
273
|
sidecar_model=None,
|
289
|
-
) -> InvigilatorBase:
|
274
|
+
) -> "InvigilatorBase":
|
290
275
|
"""Create an Invigilator.
|
291
276
|
|
292
277
|
An invigilator is an object that is responsible for administering a question to an agent.
|
@@ -300,6 +285,8 @@ class Agent(Base):
|
|
300
285
|
An invigator is an object that is responsible for administering a question to an agent and
|
301
286
|
recording the responses.
|
302
287
|
"""
|
288
|
+
from edsl import Model, Scenario
|
289
|
+
|
303
290
|
cache = cache
|
304
291
|
self.current_question = question
|
305
292
|
model = model or Model()
|
@@ -321,13 +308,13 @@ class Agent(Base):
|
|
321
308
|
async def async_answer_question(
|
322
309
|
self,
|
323
310
|
*,
|
324
|
-
question: QuestionBase,
|
325
|
-
cache: Cache,
|
326
|
-
scenario: Optional[Scenario] = None,
|
311
|
+
question: "QuestionBase",
|
312
|
+
cache: "Cache",
|
313
|
+
scenario: Optional["Scenario"] = None,
|
327
314
|
survey: Optional["Survey"] = None,
|
328
|
-
model: Optional[LanguageModel] = None,
|
315
|
+
model: Optional["LanguageModel"] = None,
|
329
316
|
debug: bool = False,
|
330
|
-
memory_plan: Optional[MemoryPlan] = None,
|
317
|
+
memory_plan: Optional["MemoryPlan"] = None,
|
331
318
|
current_answers: Optional[dict] = None,
|
332
319
|
iteration: int = 0,
|
333
320
|
) -> AgentResponseDict:
|
@@ -371,22 +358,35 @@ class Agent(Base):
|
|
371
358
|
|
372
359
|
def _create_invigilator(
|
373
360
|
self,
|
374
|
-
question: QuestionBase,
|
375
|
-
cache: Optional[Cache] = None,
|
376
|
-
scenario: Optional[Scenario] = None,
|
377
|
-
model: Optional[LanguageModel] = None,
|
361
|
+
question: "QuestionBase",
|
362
|
+
cache: Optional["Cache"] = None,
|
363
|
+
scenario: Optional["Scenario"] = None,
|
364
|
+
model: Optional["LanguageModel"] = None,
|
378
365
|
survey: Optional["Survey"] = None,
|
379
366
|
debug: bool = False,
|
380
|
-
memory_plan: Optional[MemoryPlan] = None,
|
367
|
+
memory_plan: Optional["MemoryPlan"] = None,
|
381
368
|
current_answers: Optional[dict] = None,
|
382
369
|
iteration: int = 0,
|
383
370
|
sidecar_model=None,
|
384
|
-
) -> InvigilatorBase:
|
371
|
+
) -> "InvigilatorBase":
|
385
372
|
"""Create an Invigilator."""
|
373
|
+
from edsl import Model
|
374
|
+
from edsl import Scenario
|
375
|
+
|
386
376
|
model = model or Model()
|
387
377
|
scenario = scenario or Scenario()
|
388
378
|
|
379
|
+
from edsl.agents.Invigilator import (
|
380
|
+
InvigilatorDebug,
|
381
|
+
InvigilatorHuman,
|
382
|
+
InvigilatorFunctional,
|
383
|
+
InvigilatorAI,
|
384
|
+
InvigilatorBase,
|
385
|
+
)
|
386
|
+
|
389
387
|
if cache is None:
|
388
|
+
from edsl.data.Cache import Cache
|
389
|
+
|
390
390
|
cache = Cache()
|
391
391
|
|
392
392
|
if debug:
|
@@ -502,7 +502,6 @@ class Agent(Base):
|
|
502
502
|
f"'{type(self).__name__}' object has no attribute '{name}'"
|
503
503
|
)
|
504
504
|
|
505
|
-
|
506
505
|
def __getstate__(self):
|
507
506
|
state = self.__dict__.copy()
|
508
507
|
# Include any additional state that needs to be serialized
|
@@ -675,6 +674,8 @@ class Agent(Base):
|
|
675
674
|
>>> a.rich_print()
|
676
675
|
<rich.table.Table object at ...>
|
677
676
|
"""
|
677
|
+
from rich.table import Table
|
678
|
+
|
678
679
|
table_data, column_names = self._table()
|
679
680
|
table = Table(title=f"{self.__class__.__name__} Attributes")
|
680
681
|
for column in column_names:
|
edsl/agents/AgentList.py
CHANGED
@@ -22,7 +22,8 @@ import csv
|
|
22
22
|
from simpleeval import EvalWithCompoundTypes
|
23
23
|
|
24
24
|
from edsl.Base import Base
|
25
|
-
|
25
|
+
|
26
|
+
# from edsl.agents import Agent
|
26
27
|
from edsl.utilities.decorators import (
|
27
28
|
add_edsl_version,
|
28
29
|
remove_edsl_version,
|
@@ -32,7 +33,7 @@ from edsl.utilities.decorators import (
|
|
32
33
|
class AgentList(UserList, Base):
|
33
34
|
"""A list of Agents."""
|
34
35
|
|
35
|
-
def __init__(self, data: Optional[list[Agent]] = None):
|
36
|
+
def __init__(self, data: Optional[list["Agent"]] = None):
|
36
37
|
"""Initialize a new AgentList.
|
37
38
|
|
38
39
|
:param data: A list of Agents.
|
@@ -77,6 +78,7 @@ class AgentList(UserList, Base):
|
|
77
78
|
def select(self, *traits) -> AgentList:
|
78
79
|
"""Selects agents with only the references traits.
|
79
80
|
|
81
|
+
>>> from edsl.agents.Agent import Agent
|
80
82
|
>>> al = AgentList([Agent(traits = {'a': 1, 'b': 1}), Agent(traits = {'a': 1, 'b': 2})])
|
81
83
|
>>> al.select('a')
|
82
84
|
AgentList([Agent(traits = {'a': 1}), Agent(traits = {'a': 1})])
|
@@ -94,12 +96,13 @@ class AgentList(UserList, Base):
|
|
94
96
|
"""
|
95
97
|
Filter a list of agents based on an expression.
|
96
98
|
|
99
|
+
>>> from edsl.agents.Agent import Agent
|
97
100
|
>>> al = AgentList([Agent(traits = {'a': 1, 'b': 1}), Agent(traits = {'a': 1, 'b': 2})])
|
98
101
|
>>> al.filter("b == 2")
|
99
102
|
AgentList([Agent(traits = {'a': 1, 'b': 2})])
|
100
103
|
"""
|
101
104
|
|
102
|
-
def create_evaluator(agent: Agent):
|
105
|
+
def create_evaluator(agent: "Agent"):
|
103
106
|
"""Create an evaluator for the given result.
|
104
107
|
The 'combined_dict' is a mapping of all values for that Result object.
|
105
108
|
"""
|
@@ -133,6 +136,8 @@ class AgentList(UserList, Base):
|
|
133
136
|
|
134
137
|
:param file_path: The path to the CSV file.
|
135
138
|
"""
|
139
|
+
from edsl.agents.Agent import Agent
|
140
|
+
|
136
141
|
agent_list = []
|
137
142
|
with open(file_path, "r") as f:
|
138
143
|
reader = csv.DictReader(f)
|
@@ -153,7 +158,7 @@ class AgentList(UserList, Base):
|
|
153
158
|
"""Remove traits from the AgentList.
|
154
159
|
|
155
160
|
:param traits: The traits to remove.
|
156
|
-
|
161
|
+
>>> from edsl.agents.Agent import Agent
|
157
162
|
>>> al = AgentList([Agent({'age': 22, 'hair': 'brown', 'height': 5.5}), Agent({'age': 22, 'hair': 'brown', 'height': 5.5})])
|
158
163
|
>>> al.remove_trait('age')
|
159
164
|
AgentList([Agent(traits = {'hair': 'brown', 'height': 5.5}), Agent(traits = {'hair': 'brown', 'height': 5.5})])
|
@@ -222,12 +227,14 @@ class AgentList(UserList, Base):
|
|
222
227
|
"""Deserialize the dictionary back to an AgentList object.
|
223
228
|
|
224
229
|
:param: data: A dictionary representing an AgentList.
|
225
|
-
|
230
|
+
>>> from edsl.agents.Agent import Agent
|
226
231
|
>>> al = AgentList([Agent.example(), Agent.example()])
|
227
232
|
>>> al2 = AgentList.from_dict(al.to_dict())
|
228
233
|
>>> al2 == al
|
229
234
|
True
|
230
235
|
"""
|
236
|
+
from edsl.agents.Agent import Agent
|
237
|
+
|
231
238
|
agents = [Agent.from_dict(agent_dict) for agent_dict in data["agent_list"]]
|
232
239
|
return cls(agents)
|
233
240
|
|
@@ -240,6 +247,8 @@ class AgentList(UserList, Base):
|
|
240
247
|
2
|
241
248
|
|
242
249
|
"""
|
250
|
+
from edsl.agents.Agent import Agent
|
251
|
+
|
243
252
|
return cls([Agent.example(), Agent.example()])
|
244
253
|
|
245
254
|
@classmethod
|
@@ -249,6 +258,8 @@ class AgentList(UserList, Base):
|
|
249
258
|
:param trait_name: The name of the trait.
|
250
259
|
:param values: A list of values.
|
251
260
|
"""
|
261
|
+
from edsl.agents.Agent import Agent
|
262
|
+
|
252
263
|
return AgentList([Agent({trait_name: value}) for value in values])
|
253
264
|
|
254
265
|
def __mul__(self, other: AgentList) -> AgentList:
|
edsl/agents/Invigilator.py
CHANGED
@@ -82,7 +82,25 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
|
|
82
82
|
self._remove_from_cache(raw_response)
|
83
83
|
raise e
|
84
84
|
|
85
|
-
|
85
|
+
# breakpoint()
|
86
|
+
question_dict = self.survey.question_names_to_questions()
|
87
|
+
for other_question, answer in self.current_answers.items():
|
88
|
+
if other_question in question_dict:
|
89
|
+
question_dict[other_question].answer = answer
|
90
|
+
else:
|
91
|
+
# adds a comment to the question
|
92
|
+
if (
|
93
|
+
new_question := other_question.split("_comment")[0]
|
94
|
+
) in question_dict:
|
95
|
+
question_dict[new_question].comment = answer
|
96
|
+
|
97
|
+
combined_dict = {**question_dict, **scenario}
|
98
|
+
# print("combined_dict: ", combined_dict)
|
99
|
+
# print("response: ", response)
|
100
|
+
# breakpoint()
|
101
|
+
answer = question._translate_answer_code_to_answer(
|
102
|
+
response["answer"], combined_dict
|
103
|
+
)
|
86
104
|
data = {
|
87
105
|
"answer": answer,
|
88
106
|
"comment": response.get(
|
edsl/agents/descriptors.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1
1
|
"""This module contains the descriptors used to set the attributes of the Agent class."""
|
2
2
|
|
3
3
|
from typing import Dict
|
4
|
-
from edsl.utilities.utilities import is_valid_variable_name
|
5
4
|
from edsl.exceptions.agents import AgentNameError, AgentTraitKeyError
|
6
5
|
|
7
6
|
|
@@ -30,6 +29,8 @@ class TraitsDescriptor:
|
|
30
29
|
|
31
30
|
def __set__(self, instance, traits_dict: Dict[str, str]) -> None:
|
32
31
|
"""Set the value of the attribute."""
|
32
|
+
from edsl.utilities.utilities import is_valid_variable_name
|
33
|
+
|
33
34
|
for key, value in traits_dict.items():
|
34
35
|
if key == "name":
|
35
36
|
raise AgentNameError(
|
edsl/base/Base.py
ADDED
@@ -0,0 +1,289 @@
|
|
1
|
+
"""Base class for all classes in the package. It provides rich printing and persistence of objects."""
|
2
|
+
|
3
|
+
from abc import ABC, abstractmethod, ABCMeta
|
4
|
+
import gzip
|
5
|
+
import io
|
6
|
+
import json
|
7
|
+
from typing import Any, Optional, Union
|
8
|
+
from uuid import UUID
|
9
|
+
from IPython.display import display
|
10
|
+
from rich.console import Console
|
11
|
+
|
12
|
+
|
13
|
+
class RichPrintingMixin:
|
14
|
+
"""Mixin for rich printing and persistence of objects."""
|
15
|
+
|
16
|
+
def _for_console(self):
|
17
|
+
"""Return a string representation of the object for console printing."""
|
18
|
+
with io.StringIO() as buf:
|
19
|
+
console = Console(file=buf, record=True)
|
20
|
+
table = self.rich_print()
|
21
|
+
console.print(table)
|
22
|
+
return console.export_text()
|
23
|
+
|
24
|
+
def __str__(self):
|
25
|
+
"""Return a string representation of the object for console printing."""
|
26
|
+
return self._for_console()
|
27
|
+
|
28
|
+
def print(self):
|
29
|
+
"""Print the object to the console."""
|
30
|
+
from edsl.utilities.utilities import is_notebook
|
31
|
+
|
32
|
+
if is_notebook():
|
33
|
+
display(self.rich_print())
|
34
|
+
else:
|
35
|
+
from rich.console import Console
|
36
|
+
|
37
|
+
console = Console()
|
38
|
+
console.print(self.rich_print())
|
39
|
+
|
40
|
+
|
41
|
+
class PersistenceMixin:
|
42
|
+
"""Mixin for saving and loading objects to and from files."""
|
43
|
+
|
44
|
+
def push(
|
45
|
+
self,
|
46
|
+
description: Optional[str] = None,
|
47
|
+
visibility: Optional[str] = "unlisted",
|
48
|
+
):
|
49
|
+
"""Post the object to coop."""
|
50
|
+
from edsl.coop import Coop
|
51
|
+
|
52
|
+
c = Coop()
|
53
|
+
return c.create(self, description, visibility)
|
54
|
+
|
55
|
+
@classmethod
|
56
|
+
def pull(cls, id_or_url: Union[str, UUID], exec_profile=None):
|
57
|
+
"""Pull the object from coop."""
|
58
|
+
from edsl.coop import Coop
|
59
|
+
|
60
|
+
if id_or_url.startswith("http"):
|
61
|
+
uuid_value = id_or_url.split("/")[-1]
|
62
|
+
else:
|
63
|
+
uuid_value = id_or_url
|
64
|
+
|
65
|
+
c = Coop()
|
66
|
+
|
67
|
+
return c._get_base(cls, uuid_value, exec_profile=exec_profile)
|
68
|
+
|
69
|
+
@classmethod
|
70
|
+
def delete(cls, id_or_url: Union[str, UUID]):
|
71
|
+
"""Delete the object from coop."""
|
72
|
+
from edsl.coop import Coop
|
73
|
+
|
74
|
+
c = Coop()
|
75
|
+
return c._delete_base(cls, id_or_url)
|
76
|
+
|
77
|
+
@classmethod
|
78
|
+
def patch(
|
79
|
+
cls,
|
80
|
+
id_or_url: Union[str, UUID],
|
81
|
+
description: Optional[str] = None,
|
82
|
+
value: Optional[Any] = None,
|
83
|
+
visibility: Optional[str] = None,
|
84
|
+
):
|
85
|
+
"""
|
86
|
+
Patch an uploaded objects attributes.
|
87
|
+
- `description` changes the description of the object on Coop
|
88
|
+
- `value` changes the value of the object on Coop. **has to be an EDSL object**
|
89
|
+
- `visibility` changes the visibility of the object on Coop
|
90
|
+
"""
|
91
|
+
from edsl.coop import Coop
|
92
|
+
|
93
|
+
c = Coop()
|
94
|
+
return c._patch_base(cls, id_or_url, description, value, visibility)
|
95
|
+
|
96
|
+
@classmethod
|
97
|
+
def search(cls, query):
|
98
|
+
"""Search for objects on coop."""
|
99
|
+
from edsl.coop import Coop
|
100
|
+
|
101
|
+
c = Coop()
|
102
|
+
return c.search(cls, query)
|
103
|
+
|
104
|
+
def save(self, filename, compress=True):
|
105
|
+
"""Save the object to a file as zippped JSON.
|
106
|
+
|
107
|
+
>>> obj.save("obj.json.gz")
|
108
|
+
|
109
|
+
"""
|
110
|
+
if filename.endswith("json.gz"):
|
111
|
+
import warnings
|
112
|
+
|
113
|
+
warnings.warn(
|
114
|
+
"Do not apply the file extensions. The filename should not end with 'json.gz'."
|
115
|
+
)
|
116
|
+
filename = filename[:-7]
|
117
|
+
if filename.endswith("json"):
|
118
|
+
filename = filename[:-4]
|
119
|
+
warnings.warn(
|
120
|
+
"Do not apply the file extensions. The filename should not end with 'json'."
|
121
|
+
)
|
122
|
+
|
123
|
+
if compress:
|
124
|
+
with gzip.open(filename + ".json.gz", "wb") as f:
|
125
|
+
f.write(json.dumps(self.to_dict()).encode("utf-8"))
|
126
|
+
else:
|
127
|
+
with open(filename + ".json", "w") as f:
|
128
|
+
f.write(json.dumps(self.to_dict()))
|
129
|
+
|
130
|
+
@staticmethod
|
131
|
+
def open_compressed_file(filename):
|
132
|
+
with gzip.open(filename, "rb") as f:
|
133
|
+
file_contents = f.read()
|
134
|
+
file_contents_decoded = file_contents.decode("utf-8")
|
135
|
+
d = json.loads(file_contents_decoded)
|
136
|
+
return d
|
137
|
+
|
138
|
+
@staticmethod
|
139
|
+
def open_regular_file(filename):
|
140
|
+
with open(filename, "r") as f:
|
141
|
+
d = json.loads(f.read())
|
142
|
+
return d
|
143
|
+
|
144
|
+
@classmethod
|
145
|
+
def load(cls, filename):
|
146
|
+
"""Load the object from a file.
|
147
|
+
|
148
|
+
>>> obj = cls.load("obj.json.gz")
|
149
|
+
|
150
|
+
"""
|
151
|
+
|
152
|
+
if filename.endswith("json.gz"):
|
153
|
+
d = cls.open_compressed_file(filename)
|
154
|
+
elif filename.endswith("json"):
|
155
|
+
d = cls.open_regular_file(filename)
|
156
|
+
else:
|
157
|
+
try:
|
158
|
+
d = cls.open_compressed_file(filename)
|
159
|
+
except:
|
160
|
+
d = cls.open_regular_file(filename)
|
161
|
+
finally:
|
162
|
+
raise ValueError("File must be a json or json.gz file")
|
163
|
+
|
164
|
+
return cls.from_dict(d)
|
165
|
+
|
166
|
+
|
167
|
+
class RegisterSubclassesMeta(ABCMeta):
|
168
|
+
"""Metaclass for registering subclasses."""
|
169
|
+
|
170
|
+
_registry = {}
|
171
|
+
|
172
|
+
def __init__(cls, name, bases, nmspc):
|
173
|
+
"""Register the class in the registry upon creation."""
|
174
|
+
super(RegisterSubclassesMeta, cls).__init__(name, bases, nmspc)
|
175
|
+
if cls.__name__ != "Base":
|
176
|
+
RegisterSubclassesMeta._registry[cls.__name__] = cls
|
177
|
+
|
178
|
+
@staticmethod
|
179
|
+
def get_registry():
|
180
|
+
"""Return the registry of subclasses."""
|
181
|
+
return dict(RegisterSubclassesMeta._registry)
|
182
|
+
|
183
|
+
|
184
|
+
class DiffMethodsMixin:
|
185
|
+
def __sub__(self, other):
|
186
|
+
"""Return the difference between two objects."""
|
187
|
+
from edsl.BaseDiff import BaseDiff
|
188
|
+
|
189
|
+
return BaseDiff(self, other)
|
190
|
+
|
191
|
+
|
192
|
+
class Base(
|
193
|
+
RichPrintingMixin,
|
194
|
+
PersistenceMixin,
|
195
|
+
DiffMethodsMixin,
|
196
|
+
ABC,
|
197
|
+
metaclass=RegisterSubclassesMeta,
|
198
|
+
):
|
199
|
+
"""Base class for all classes in the package."""
|
200
|
+
|
201
|
+
# def __getitem__(self, key):
|
202
|
+
# return getattr(self, key)
|
203
|
+
|
204
|
+
# @abstractmethod
|
205
|
+
# def _repr_html_(self) -> str:
|
206
|
+
# raise NotImplementedError("This method is not implemented yet.")
|
207
|
+
|
208
|
+
# @abstractmethod
|
209
|
+
# def _repr_(self) -> str:
|
210
|
+
# raise NotImplementedError("This method is not implemented yet.")
|
211
|
+
|
212
|
+
def keys(self):
|
213
|
+
"""Return the keys of the object."""
|
214
|
+
_keys = list(self.to_dict().keys())
|
215
|
+
if "edsl_version" in _keys:
|
216
|
+
_keys.remove("edsl_version")
|
217
|
+
if "edsl_class_name" in _keys:
|
218
|
+
_keys.remove("edsl_class_name")
|
219
|
+
return _keys
|
220
|
+
|
221
|
+
def values(self):
|
222
|
+
"""Return the values of the object."""
|
223
|
+
data = self.to_dict()
|
224
|
+
keys = self.keys()
|
225
|
+
return {data[key] for key in keys}
|
226
|
+
|
227
|
+
def _repr_html_(self):
|
228
|
+
from edsl.utilities.utilities import data_to_html
|
229
|
+
|
230
|
+
return data_to_html(self.to_dict())
|
231
|
+
|
232
|
+
# def html(self):
|
233
|
+
# html_string = self._repr_html_()
|
234
|
+
# import tempfile
|
235
|
+
# import webbrowser
|
236
|
+
|
237
|
+
# with tempfile.NamedTemporaryFile("w", delete=False, suffix=".html") as f:
|
238
|
+
# # print("Writing HTML to", f.name)
|
239
|
+
# f.write(html_string)
|
240
|
+
# webbrowser.open(f.name)
|
241
|
+
|
242
|
+
def __eq__(self, other):
|
243
|
+
"""Return whether two objects are equal."""
|
244
|
+
import inspect
|
245
|
+
|
246
|
+
if not isinstance(other, self.__class__):
|
247
|
+
return False
|
248
|
+
if "sort" in inspect.signature(self._to_dict).parameters:
|
249
|
+
return self._to_dict(sort=True) == other._to_dict(sort=True)
|
250
|
+
else:
|
251
|
+
return self._to_dict() == other._to_dict()
|
252
|
+
|
253
|
+
@abstractmethod
|
254
|
+
def example():
|
255
|
+
"""This method should be implemented by subclasses."""
|
256
|
+
raise NotImplementedError("This method is not implemented yet.")
|
257
|
+
|
258
|
+
@abstractmethod
|
259
|
+
def rich_print():
|
260
|
+
"""This method should be implemented by subclasses."""
|
261
|
+
raise NotImplementedError("This method is not implemented yet.")
|
262
|
+
|
263
|
+
@abstractmethod
|
264
|
+
def to_dict():
|
265
|
+
"""This method should be implemented by subclasses."""
|
266
|
+
raise NotImplementedError("This method is not implemented yet.")
|
267
|
+
|
268
|
+
@abstractmethod
|
269
|
+
def from_dict():
|
270
|
+
"""This method should be implemented by subclasses."""
|
271
|
+
raise NotImplementedError("This method is not implemented yet.")
|
272
|
+
|
273
|
+
@abstractmethod
|
274
|
+
def code():
|
275
|
+
"""This method should be implemented by subclasses."""
|
276
|
+
raise NotImplementedError("This method is not implemented yet.")
|
277
|
+
|
278
|
+
def show_methods(self, show_docstrings=True):
|
279
|
+
"""Show the methods of the object."""
|
280
|
+
public_methods_with_docstrings = [
|
281
|
+
(method, getattr(self, method).__doc__)
|
282
|
+
for method in dir(self)
|
283
|
+
if callable(getattr(self, method)) and not method.startswith("_")
|
284
|
+
]
|
285
|
+
if show_docstrings:
|
286
|
+
for method, documentation in public_methods_with_docstrings:
|
287
|
+
print(f"{method}: {documentation}")
|
288
|
+
else:
|
289
|
+
return [x[0] for x in public_methods_with_docstrings]
|