edsl 0.1.29.dev6__py3-none-any.whl → 0.1.30.dev1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (60)
  1. edsl/Base.py +6 -3
  2. edsl/__init__.py +23 -23
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +35 -34
  5. edsl/agents/AgentList.py +16 -5
  6. edsl/agents/Invigilator.py +19 -1
  7. edsl/agents/descriptors.py +2 -1
  8. edsl/base/Base.py +289 -0
  9. edsl/config.py +2 -1
  10. edsl/coop/utils.py +28 -1
  11. edsl/data/Cache.py +19 -5
  12. edsl/data/SQLiteDict.py +11 -3
  13. edsl/jobs/Answers.py +15 -1
  14. edsl/jobs/Jobs.py +69 -31
  15. edsl/jobs/buckets/ModelBuckets.py +4 -2
  16. edsl/jobs/buckets/TokenBucket.py +1 -2
  17. edsl/jobs/interviews/Interview.py +0 -6
  18. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +9 -5
  19. edsl/jobs/runners/JobsRunnerAsyncio.py +12 -16
  20. edsl/jobs/tasks/TaskHistory.py +4 -3
  21. edsl/language_models/LanguageModel.py +5 -11
  22. edsl/language_models/ModelList.py +1 -1
  23. edsl/language_models/repair.py +8 -7
  24. edsl/notebooks/Notebook.py +9 -3
  25. edsl/questions/QuestionBase.py +6 -2
  26. edsl/questions/QuestionBudget.py +5 -6
  27. edsl/questions/QuestionCheckBox.py +7 -3
  28. edsl/questions/QuestionExtract.py +5 -3
  29. edsl/questions/QuestionFreeText.py +3 -3
  30. edsl/questions/QuestionFunctional.py +0 -3
  31. edsl/questions/QuestionList.py +3 -4
  32. edsl/questions/QuestionMultipleChoice.py +12 -5
  33. edsl/questions/QuestionNumerical.py +4 -3
  34. edsl/questions/QuestionRank.py +5 -3
  35. edsl/questions/__init__.py +4 -3
  36. edsl/questions/descriptors.py +4 -2
  37. edsl/results/DatasetExportMixin.py +491 -0
  38. edsl/results/Result.py +13 -65
  39. edsl/results/Results.py +91 -39
  40. edsl/results/ResultsDBMixin.py +7 -3
  41. edsl/results/ResultsExportMixin.py +22 -537
  42. edsl/results/ResultsGGMixin.py +3 -3
  43. edsl/results/ResultsToolsMixin.py +1 -4
  44. edsl/scenarios/FileStore.py +140 -0
  45. edsl/scenarios/Scenario.py +5 -6
  46. edsl/scenarios/ScenarioList.py +17 -8
  47. edsl/scenarios/ScenarioListExportMixin.py +32 -0
  48. edsl/scenarios/ScenarioListPdfMixin.py +2 -1
  49. edsl/scenarios/__init__.py +1 -0
  50. edsl/surveys/MemoryPlan.py +11 -4
  51. edsl/surveys/Survey.py +9 -4
  52. edsl/surveys/SurveyExportMixin.py +4 -2
  53. edsl/surveys/SurveyFlowVisualizationMixin.py +6 -4
  54. edsl/utilities/__init__.py +21 -21
  55. edsl/utilities/interface.py +66 -45
  56. edsl/utilities/utilities.py +11 -13
  57. {edsl-0.1.29.dev6.dist-info → edsl-0.1.30.dev1.dist-info}/METADATA +1 -1
  58. {edsl-0.1.29.dev6.dist-info → edsl-0.1.30.dev1.dist-info}/RECORD +60 -56
  59. {edsl-0.1.29.dev6.dist-info → edsl-0.1.30.dev1.dist-info}/LICENSE +0 -0
  60. {edsl-0.1.29.dev6.dist-info → edsl-0.1.30.dev1.dist-info}/WHEEL +0 -0
edsl/Base.py CHANGED
@@ -6,9 +6,6 @@ import io
6
6
  import json
7
7
  from typing import Any, Optional, Union
8
8
  from uuid import UUID
9
- from IPython.display import display
10
- from rich.console import Console
11
- from edsl.utilities import is_notebook
12
9
 
13
10
 
14
11
  class RichPrintingMixin:
@@ -16,6 +13,8 @@ class RichPrintingMixin:
16
13
 
17
14
  def _for_console(self):
18
15
  """Return a string representation of the object for console printing."""
16
+ from rich.console import Console
17
+
19
18
  with io.StringIO() as buf:
20
19
  console = Console(file=buf, record=True)
21
20
  table = self.rich_print()
@@ -28,7 +27,11 @@ class RichPrintingMixin:
28
27
 
29
28
  def print(self):
30
29
  """Print the object to the console."""
30
+ from edsl.utilities.utilities import is_notebook
31
+
31
32
  if is_notebook():
33
+ from IPython.display import display
34
+
32
35
  display(self.rich_print())
33
36
  else:
34
37
  from rich.console import Console
edsl/__init__.py CHANGED
@@ -1,4 +1,5 @@
1
1
  import os
2
+ import time
2
3
 
3
4
  BASE_DIR = os.path.dirname(os.path.abspath(__file__))
4
5
  ROOT_DIR = os.path.dirname(BASE_DIR)
@@ -7,36 +8,35 @@ from edsl.__version__ import __version__
7
8
  from edsl.config import Config, CONFIG
8
9
  from edsl.agents.Agent import Agent
9
10
  from edsl.agents.AgentList import AgentList
10
- from edsl.questions import (
11
- QuestionBase,
12
- QuestionBudget,
13
- QuestionCheckBox,
14
- QuestionExtract,
15
- QuestionFreeText,
16
- QuestionFunctional,
17
- QuestionLikertFive,
18
- QuestionList,
19
- QuestionLinearScale,
20
- QuestionMultipleChoice,
21
- QuestionNumerical,
22
- QuestionRank,
23
- QuestionTopK,
24
- QuestionYesNo,
25
- )
26
- from edsl.scenarios.Scenario import Scenario
27
- from edsl.scenarios.ScenarioList import ScenarioList
28
- from edsl.utilities.interface import print_dict_with_rich
11
+ from edsl.questions import QuestionBase
12
+ from edsl.questions import QuestionMultipleChoice
13
+ from edsl.questions import QuestionBudget
14
+ from edsl.questions import QuestionCheckBox
15
+ from edsl.questions import QuestionExtract
16
+ from edsl.questions import QuestionFreeText
17
+ from edsl.questions import QuestionFunctional
18
+ from edsl.questions import QuestionLikertFive
19
+ from edsl.questions import QuestionList
20
+ from edsl.questions import QuestionLinearScale
21
+ from edsl.questions import QuestionNumerical
22
+ from edsl.questions import QuestionRank
23
+ from edsl.questions import QuestionTopK
24
+ from edsl.questions import QuestionYesNo
25
+ from edsl.questions.question_registry import Question
26
+ from edsl.scenarios import Scenario
27
+ from edsl.scenarios import ScenarioList
28
+
29
+ # from edsl.utilities.interface import print_dict_with_rich
29
30
  from edsl.surveys.Survey import Survey
30
31
  from edsl.language_models.registry import Model
31
- from edsl.questions.question_registry import Question
32
+ from edsl.language_models.ModelList import ModelList
32
33
  from edsl.results.Results import Results
33
34
  from edsl.data.Cache import Cache
34
35
  from edsl.data.CacheEntry import CacheEntry
35
36
  from edsl.data.CacheHandler import set_session_cache, unset_session_cache
36
37
  from edsl.shared import shared_globals
37
- from edsl.jobs import Jobs
38
- from edsl.notebooks import Notebook
38
+ from edsl.jobs.Jobs import Jobs
39
+ from edsl.notebooks.Notebook import Notebook
39
40
  from edsl.study.Study import Study
40
41
  from edsl.conjure.Conjure import Conjure
41
- from edsl.language_models.ModelList import ModelList
42
42
  from edsl.coop.coop import Coop
edsl/__version__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.1.29.dev6"
1
+ __version__ = "0.1.30.dev1"
edsl/agents/Agent.py CHANGED
@@ -5,27 +5,14 @@ import copy
5
5
  import inspect
6
6
  import types
7
7
  from typing import Any, Callable, Optional, Union, Dict, Sequence
8
-
9
- from rich.table import Table
10
-
11
8
  from edsl.Base import Base
12
- from edsl.questions.QuestionBase import QuestionBase
13
- from edsl.language_models import LanguageModel
14
- from edsl.surveys.MemoryPlan import MemoryPlan
9
+
15
10
  from edsl.exceptions.agents import (
16
11
  AgentCombinationError,
17
12
  AgentDirectAnswerFunctionError,
18
13
  AgentDynamicTraitsFunctionError,
19
14
  )
20
- from edsl.agents.Invigilator import (
21
- InvigilatorDebug,
22
- InvigilatorHuman,
23
- InvigilatorFunctional,
24
- InvigilatorAI,
25
- InvigilatorBase,
26
- )
27
- from edsl.language_models.registry import Model
28
- from edsl.scenarios import Scenario
15
+
29
16
  from edsl.agents.descriptors import (
30
17
  TraitsDescriptor,
31
18
  CodebookDescriptor,
@@ -38,10 +25,6 @@ from edsl.utilities.decorators import (
38
25
  remove_edsl_version,
39
26
  )
40
27
  from edsl.data_transfer_models import AgentResponseDict
41
- from edsl.prompts.library.agent_persona import AgentPersona
42
- from edsl.data.Cache import Cache
43
-
44
-
45
28
  from edsl.utilities.restricted_python import create_restricted_function
46
29
 
47
30
 
@@ -156,6 +139,8 @@ class Agent(Base):
156
139
  self.current_question = None
157
140
 
158
141
  if traits_presentation_template is not None:
142
+ from edsl.prompts.library.agent_persona import AgentPersona
143
+
159
144
  self.traits_presentation_template = traits_presentation_template
160
145
  self.agent_persona = AgentPersona(text=self.traits_presentation_template)
161
146
 
@@ -276,8 +261,8 @@ class Agent(Base):
276
261
  def create_invigilator(
277
262
  self,
278
263
  *,
279
- question: QuestionBase,
280
- cache,
264
+ question: "QuestionBase",
265
+ cache: "Cache",
281
266
  survey: Optional["Survey"] = None,
282
267
  scenario: Optional[Scenario] = None,
283
268
  model: Optional[LanguageModel] = None,
@@ -286,7 +271,7 @@ class Agent(Base):
286
271
  current_answers: Optional[dict] = None,
287
272
  iteration: int = 1,
288
273
  sidecar_model=None,
289
- ) -> InvigilatorBase:
274
+ ) -> "InvigilatorBase":
290
275
  """Create an Invigilator.
291
276
 
292
277
  An invigilator is an object that is responsible for administering a question to an agent.
@@ -300,6 +285,8 @@ class Agent(Base):
300
285
  An invigator is an object that is responsible for administering a question to an agent and
301
286
  recording the responses.
302
287
  """
288
+ from edsl import Model, Scenario
289
+
303
290
  cache = cache
304
291
  self.current_question = question
305
292
  model = model or Model()
@@ -321,13 +308,13 @@ class Agent(Base):
321
308
  async def async_answer_question(
322
309
  self,
323
310
  *,
324
- question: QuestionBase,
325
- cache: Cache,
326
- scenario: Optional[Scenario] = None,
311
+ question: "QuestionBase",
312
+ cache: "Cache",
313
+ scenario: Optional["Scenario"] = None,
327
314
  survey: Optional["Survey"] = None,
328
- model: Optional[LanguageModel] = None,
315
+ model: Optional["LanguageModel"] = None,
329
316
  debug: bool = False,
330
- memory_plan: Optional[MemoryPlan] = None,
317
+ memory_plan: Optional["MemoryPlan"] = None,
331
318
  current_answers: Optional[dict] = None,
332
319
  iteration: int = 0,
333
320
  ) -> AgentResponseDict:
@@ -371,22 +358,35 @@ class Agent(Base):
371
358
 
372
359
  def _create_invigilator(
373
360
  self,
374
- question: QuestionBase,
375
- cache: Optional[Cache] = None,
376
- scenario: Optional[Scenario] = None,
377
- model: Optional[LanguageModel] = None,
361
+ question: "QuestionBase",
362
+ cache: Optional["Cache"] = None,
363
+ scenario: Optional["Scenario"] = None,
364
+ model: Optional["LanguageModel"] = None,
378
365
  survey: Optional["Survey"] = None,
379
366
  debug: bool = False,
380
- memory_plan: Optional[MemoryPlan] = None,
367
+ memory_plan: Optional["MemoryPlan"] = None,
381
368
  current_answers: Optional[dict] = None,
382
369
  iteration: int = 0,
383
370
  sidecar_model=None,
384
- ) -> InvigilatorBase:
371
+ ) -> "InvigilatorBase":
385
372
  """Create an Invigilator."""
373
+ from edsl import Model
374
+ from edsl import Scenario
375
+
386
376
  model = model or Model()
387
377
  scenario = scenario or Scenario()
388
378
 
379
+ from edsl.agents.Invigilator import (
380
+ InvigilatorDebug,
381
+ InvigilatorHuman,
382
+ InvigilatorFunctional,
383
+ InvigilatorAI,
384
+ InvigilatorBase,
385
+ )
386
+
389
387
  if cache is None:
388
+ from edsl.data.Cache import Cache
389
+
390
390
  cache = Cache()
391
391
 
392
392
  if debug:
@@ -502,7 +502,6 @@ class Agent(Base):
502
502
  f"'{type(self).__name__}' object has no attribute '{name}'"
503
503
  )
504
504
 
505
-
506
505
  def __getstate__(self):
507
506
  state = self.__dict__.copy()
508
507
  # Include any additional state that needs to be serialized
@@ -675,6 +674,8 @@ class Agent(Base):
675
674
  >>> a.rich_print()
676
675
  <rich.table.Table object at ...>
677
676
  """
677
+ from rich.table import Table
678
+
678
679
  table_data, column_names = self._table()
679
680
  table = Table(title=f"{self.__class__.__name__} Attributes")
680
681
  for column in column_names:
edsl/agents/AgentList.py CHANGED
@@ -22,7 +22,8 @@ import csv
22
22
  from simpleeval import EvalWithCompoundTypes
23
23
 
24
24
  from edsl.Base import Base
25
- from edsl.agents import Agent
25
+
26
+ # from edsl.agents import Agent
26
27
  from edsl.utilities.decorators import (
27
28
  add_edsl_version,
28
29
  remove_edsl_version,
@@ -32,7 +33,7 @@ from edsl.utilities.decorators import (
32
33
  class AgentList(UserList, Base):
33
34
  """A list of Agents."""
34
35
 
35
- def __init__(self, data: Optional[list[Agent]] = None):
36
+ def __init__(self, data: Optional[list["Agent"]] = None):
36
37
  """Initialize a new AgentList.
37
38
 
38
39
  :param data: A list of Agents.
@@ -77,6 +78,7 @@ class AgentList(UserList, Base):
77
78
  def select(self, *traits) -> AgentList:
78
79
  """Selects agents with only the references traits.
79
80
 
81
+ >>> from edsl.agents.Agent import Agent
80
82
  >>> al = AgentList([Agent(traits = {'a': 1, 'b': 1}), Agent(traits = {'a': 1, 'b': 2})])
81
83
  >>> al.select('a')
82
84
  AgentList([Agent(traits = {'a': 1}), Agent(traits = {'a': 1})])
@@ -94,12 +96,13 @@ class AgentList(UserList, Base):
94
96
  """
95
97
  Filter a list of agents based on an expression.
96
98
 
99
+ >>> from edsl.agents.Agent import Agent
97
100
  >>> al = AgentList([Agent(traits = {'a': 1, 'b': 1}), Agent(traits = {'a': 1, 'b': 2})])
98
101
  >>> al.filter("b == 2")
99
102
  AgentList([Agent(traits = {'a': 1, 'b': 2})])
100
103
  """
101
104
 
102
- def create_evaluator(agent: Agent):
105
+ def create_evaluator(agent: "Agent"):
103
106
  """Create an evaluator for the given result.
104
107
  The 'combined_dict' is a mapping of all values for that Result object.
105
108
  """
@@ -133,6 +136,8 @@ class AgentList(UserList, Base):
133
136
 
134
137
  :param file_path: The path to the CSV file.
135
138
  """
139
+ from edsl.agents.Agent import Agent
140
+
136
141
  agent_list = []
137
142
  with open(file_path, "r") as f:
138
143
  reader = csv.DictReader(f)
@@ -153,7 +158,7 @@ class AgentList(UserList, Base):
153
158
  """Remove traits from the AgentList.
154
159
 
155
160
  :param traits: The traits to remove.
156
-
161
+ >>> from edsl.agents.Agent import Agent
157
162
  >>> al = AgentList([Agent({'age': 22, 'hair': 'brown', 'height': 5.5}), Agent({'age': 22, 'hair': 'brown', 'height': 5.5})])
158
163
  >>> al.remove_trait('age')
159
164
  AgentList([Agent(traits = {'hair': 'brown', 'height': 5.5}), Agent(traits = {'hair': 'brown', 'height': 5.5})])
@@ -222,12 +227,14 @@ class AgentList(UserList, Base):
222
227
  """Deserialize the dictionary back to an AgentList object.
223
228
 
224
229
  :param: data: A dictionary representing an AgentList.
225
-
230
+ >>> from edsl.agents.Agent import Agent
226
231
  >>> al = AgentList([Agent.example(), Agent.example()])
227
232
  >>> al2 = AgentList.from_dict(al.to_dict())
228
233
  >>> al2 == al
229
234
  True
230
235
  """
236
+ from edsl.agents.Agent import Agent
237
+
231
238
  agents = [Agent.from_dict(agent_dict) for agent_dict in data["agent_list"]]
232
239
  return cls(agents)
233
240
 
@@ -240,6 +247,8 @@ class AgentList(UserList, Base):
240
247
  2
241
248
 
242
249
  """
250
+ from edsl.agents.Agent import Agent
251
+
243
252
  return cls([Agent.example(), Agent.example()])
244
253
 
245
254
  @classmethod
@@ -249,6 +258,8 @@ class AgentList(UserList, Base):
249
258
  :param trait_name: The name of the trait.
250
259
  :param values: A list of values.
251
260
  """
261
+ from edsl.agents.Agent import Agent
262
+
252
263
  return AgentList([Agent({trait_name: value}) for value in values])
253
264
 
254
265
  def __mul__(self, other: AgentList) -> AgentList:
@@ -82,7 +82,25 @@ class InvigilatorAI(PromptConstructorMixin, InvigilatorBase):
82
82
  self._remove_from_cache(raw_response)
83
83
  raise e
84
84
 
85
- answer = question._translate_answer_code_to_answer(response["answer"], scenario)
85
+ # breakpoint()
86
+ question_dict = self.survey.question_names_to_questions()
87
+ for other_question, answer in self.current_answers.items():
88
+ if other_question in question_dict:
89
+ question_dict[other_question].answer = answer
90
+ else:
91
+ # adds a comment to the question
92
+ if (
93
+ new_question := other_question.split("_comment")[0]
94
+ ) in question_dict:
95
+ question_dict[new_question].comment = answer
96
+
97
+ combined_dict = {**question_dict, **scenario}
98
+ # print("combined_dict: ", combined_dict)
99
+ # print("response: ", response)
100
+ # breakpoint()
101
+ answer = question._translate_answer_code_to_answer(
102
+ response["answer"], combined_dict
103
+ )
86
104
  data = {
87
105
  "answer": answer,
88
106
  "comment": response.get(
@@ -1,7 +1,6 @@
1
1
  """This module contains the descriptors used to set the attributes of the Agent class."""
2
2
 
3
3
  from typing import Dict
4
- from edsl.utilities.utilities import is_valid_variable_name
5
4
  from edsl.exceptions.agents import AgentNameError, AgentTraitKeyError
6
5
 
7
6
 
@@ -30,6 +29,8 @@ class TraitsDescriptor:
30
29
 
31
30
  def __set__(self, instance, traits_dict: Dict[str, str]) -> None:
32
31
  """Set the value of the attribute."""
32
+ from edsl.utilities.utilities import is_valid_variable_name
33
+
33
34
  for key, value in traits_dict.items():
34
35
  if key == "name":
35
36
  raise AgentNameError(
edsl/base/Base.py ADDED
@@ -0,0 +1,289 @@
1
+ """Base class for all classes in the package. It provides rich printing and persistence of objects."""
2
+
3
+ from abc import ABC, abstractmethod, ABCMeta
4
+ import gzip
5
+ import io
6
+ import json
7
+ from typing import Any, Optional, Union
8
+ from uuid import UUID
9
+ from IPython.display import display
10
+ from rich.console import Console
11
+
12
+
13
class RichPrintingMixin:
    """Mixin that renders objects through ``rich`` for terminals and notebooks.

    Classes using this mixin must provide ``rich_print()``, returning an object
    that ``rich.console.Console.print`` (or IPython's ``display``) can render.
    """

    def _for_console(self):
        """Render ``self.rich_print()`` into an in-memory console and return its text."""
        with io.StringIO() as sink:
            recorder = Console(file=sink, record=True)
            renderable = self.rich_print()
            recorder.print(renderable)
            return recorder.export_text()

    def __str__(self):
        """Return the plain-text console rendering of the object."""
        return self._for_console()

    def print(self):
        """Show the object: via IPython ``display`` in a notebook, else on a rich Console."""
        from edsl.utilities.utilities import is_notebook

        if not is_notebook():
            terminal = Console()
            terminal.print(self.rich_print())
            return
        display(self.rich_print())
39
+
40
+
41
class PersistenceMixin:
    """Mixin for saving/loading objects to/from JSON files and syncing them with Coop.

    Classes using this mixin are expected to provide ``to_dict()`` and a
    ``from_dict()`` classmethod; ``save`` and ``load`` rely on both.
    """

    def push(
        self,
        description: Optional[str] = None,
        visibility: Optional[str] = "unlisted",
    ):
        """Post the object to Coop and return the result of ``Coop.create``."""
        from edsl.coop import Coop

        c = Coop()
        return c.create(self, description, visibility)

    @classmethod
    def pull(cls, id_or_url: Union[str, UUID], exec_profile=None):
        """Pull an object of this class from Coop.

        :param id_or_url: A UUID (``str`` or ``uuid.UUID``) or a URL whose last
            path segment is the UUID.
        :param exec_profile: Forwarded unchanged to ``Coop._get_base``.
        """
        from edsl.coop import Coop

        # The annotation allows uuid.UUID, which has no .startswith(); normalise first.
        id_or_url = str(id_or_url)
        if id_or_url.startswith("http"):
            uuid_value = id_or_url.split("/")[-1]
        else:
            uuid_value = id_or_url

        c = Coop()
        return c._get_base(cls, uuid_value, exec_profile=exec_profile)

    @classmethod
    def delete(cls, id_or_url: Union[str, UUID]):
        """Delete the object from Coop."""
        from edsl.coop import Coop

        c = Coop()
        return c._delete_base(cls, id_or_url)

    @classmethod
    def patch(
        cls,
        id_or_url: Union[str, UUID],
        description: Optional[str] = None,
        value: Optional[Any] = None,
        visibility: Optional[str] = None,
    ):
        """
        Patch an uploaded object's attributes.
        - `description` changes the description of the object on Coop
        - `value` changes the value of the object on Coop. **has to be an EDSL object**
        - `visibility` changes the visibility of the object on Coop
        """
        from edsl.coop import Coop

        c = Coop()
        return c._patch_base(cls, id_or_url, description, value, visibility)

    @classmethod
    def search(cls, query):
        """Search for objects of this class on Coop."""
        from edsl.coop import Coop

        c = Coop()
        return c.search(cls, query)

    def save(self, filename, compress=True):
        """Save ``self.to_dict()`` as JSON, gzip-compressed by default.

        The extension is appended automatically: ``filename + ".json.gz"`` when
        ``compress`` is true, otherwise ``filename + ".json"``. If ``filename``
        already ends in ``json.gz`` or ``json``, that suffix (and the dot before
        it, if present) is removed with a warning so it is not doubled.

        Example (not a doctest)::

            obj.save("obj")  # writes obj.json.gz
        """
        # Imported unconditionally: both suffix branches below may warn.
        import warnings

        def _strip_extension(name, ext):
            # Remove the extension and a preceding dot: "obj.json" -> "obj", not "obj.".
            name = name[: -len(ext)]
            if name.endswith("."):
                name = name[:-1]
            return name

        if filename.endswith("json.gz"):
            warnings.warn(
                "Do not apply the file extensions. The filename should not end with 'json.gz'."
            )
            filename = _strip_extension(filename, "json.gz")
        if filename.endswith("json"):
            filename = _strip_extension(filename, "json")
            warnings.warn(
                "Do not apply the file extensions. The filename should not end with 'json'."
            )

        if compress:
            with gzip.open(filename + ".json.gz", "wb") as f:
                f.write(json.dumps(self.to_dict()).encode("utf-8"))
        else:
            with open(filename + ".json", "w", encoding="utf-8") as f:
                f.write(json.dumps(self.to_dict()))

    @staticmethod
    def open_compressed_file(filename):
        """Read a gzip-compressed UTF-8 JSON file and return the decoded object."""
        with gzip.open(filename, "rb") as f:
            file_contents = f.read()
        file_contents_decoded = file_contents.decode("utf-8")
        d = json.loads(file_contents_decoded)
        return d

    @staticmethod
    def open_regular_file(filename):
        """Read a plain UTF-8 JSON file and return the decoded object."""
        with open(filename, "r", encoding="utf-8") as f:
            d = json.loads(f.read())
        return d

    @classmethod
    def load(cls, filename):
        """Load an object previously written by ``save``.

        Names ending in ``json.gz`` are read as gzip-compressed JSON and names
        ending in ``json`` as plain JSON. For any other name, gzip-compressed
        JSON is tried first, falling back to plain JSON.

        :raises ValueError: if a file without a recognised extension cannot be
            read as either gzip-compressed JSON or plain JSON.

        Example (not a doctest)::

            obj = SomeClass.load("obj.json.gz")
        """
        if filename.endswith("json.gz"):
            d = cls.open_compressed_file(filename)
        elif filename.endswith("json"):
            d = cls.open_regular_file(filename)
        else:
            try:
                d = cls.open_compressed_file(filename)
            except (OSError, EOFError, ValueError):
                # gzip.BadGzipFile is an OSError; JSON/Unicode decode errors are
                # ValueErrors. Either way, retry as a plain JSON file.
                try:
                    d = cls.open_regular_file(filename)
                except (OSError, ValueError) as err:
                    raise ValueError("File must be a json or json.gz file") from err

        return cls.from_dict(d)
165
+
166
+
167
class RegisterSubclassesMeta(ABCMeta):
    """Metaclass that records each class it creates in a shared name -> class map.

    The class literally named ``Base`` is not recorded. Entries are keyed by
    ``__name__`` alone, so a later class with the same name replaces an earlier one.
    """

    _registry = {}

    def __init__(cls, name, bases, nmspc):
        """Finish creating the class, then record it unless it is named ``Base``."""
        super().__init__(name, bases, nmspc)
        class_name = cls.__name__
        if class_name == "Base":
            return
        RegisterSubclassesMeta._registry[class_name] = cls

    @staticmethod
    def get_registry():
        """Return a shallow copy of the registry so callers cannot mutate the original."""
        return RegisterSubclassesMeta._registry.copy()
182
+
183
+
184
class DiffMethodsMixin:
    """Mixin that makes ``a - b`` build an ``edsl.BaseDiff.BaseDiff`` of two objects."""

    def __sub__(self, other):
        """Return ``BaseDiff(self, other)``; the import is deferred until first use."""
        from edsl.BaseDiff import BaseDiff as _BaseDiff

        diff = _BaseDiff(self, other)
        return diff
190
+
191
+
192
class Base(
    RichPrintingMixin,
    PersistenceMixin,
    DiffMethodsMixin,
    ABC,
    metaclass=RegisterSubclassesMeta,
):
    """Base class for all classes in the package.

    Combines rich printing, file/Coop persistence and ``-`` diffing, and is
    created through ``RegisterSubclassesMeta`` so every subclass is recorded in
    its registry. Subclasses must implement ``example``, ``rich_print``,
    ``to_dict``, ``from_dict`` and ``code``.
    """

    def keys(self):
        """Return the keys of ``to_dict()``, minus the ``edsl_version`` and ``edsl_class_name`` metadata keys."""
        _keys = list(self.to_dict().keys())
        for meta_key in ("edsl_version", "edsl_class_name"):
            if meta_key in _keys:
                _keys.remove(meta_key)
        return _keys

    def values(self):
        """Return the ``to_dict()`` values for ``keys()``, in the same order.

        A list is returned (previously a set) so values stay aligned with
        ``keys()``, duplicates are kept, and unhashable values such as dicts or
        lists do not raise ``TypeError``.
        """
        data = self.to_dict()
        return [data[key] for key in self.keys()]

    def _repr_html_(self):
        """Return an HTML rendering of ``to_dict()`` (used by Jupyter display)."""
        from edsl.utilities.utilities import data_to_html

        return data_to_html(self.to_dict())

    # NOTE: defining __eq__ without __hash__ makes instances unhashable
    # (Python sets __hash__ to None) unless a subclass defines __hash__.
    def __eq__(self, other):
        """Return whether ``other`` is an instance of this class with an equal ``_to_dict()``.

        If ``_to_dict`` accepts a ``sort`` parameter, both sides are serialized
        with ``sort=True`` before comparing.
        """
        import inspect

        if not isinstance(other, self.__class__):
            return False
        if "sort" in inspect.signature(self._to_dict).parameters:
            return self._to_dict(sort=True) == other._to_dict(sort=True)
        else:
            return self._to_dict() == other._to_dict()

    @classmethod
    @abstractmethod
    def example(cls):
        """Return an example object; must be implemented by subclasses."""
        raise NotImplementedError("This method is not implemented yet.")

    @abstractmethod
    def rich_print(self):
        """Return a renderable for ``RichPrintingMixin``; must be implemented by subclasses."""
        raise NotImplementedError("This method is not implemented yet.")

    @abstractmethod
    def to_dict(self):
        """Return a dict serialization of the object; must be implemented by subclasses."""
        raise NotImplementedError("This method is not implemented yet.")

    @classmethod
    @abstractmethod
    def from_dict(cls, data):
        """Rebuild an object from ``to_dict()`` output; must be implemented by subclasses."""
        raise NotImplementedError("This method is not implemented yet.")

    @abstractmethod
    def code(self):
        """This method should be implemented by subclasses."""
        raise NotImplementedError("This method is not implemented yet.")

    def show_methods(self, show_docstrings=True):
        """Show the public callable attributes of the object.

        :param show_docstrings: If true, print ``name: docstring`` for each
            public method and return ``None``; otherwise return the list of names.
        """
        public_methods_with_docstrings = []
        for name in dir(self):
            # Skip private names before touching them, and look each attribute up once.
            if name.startswith("_"):
                continue
            attr = getattr(self, name)
            if callable(attr):
                public_methods_with_docstrings.append((name, attr.__doc__))
        if show_docstrings:
            for method, documentation in public_methods_with_docstrings:
                print(f"{method}: {documentation}")
        else:
            return [x[0] for x in public_methods_with_docstrings]