edsl 0.1.36.dev7__py3-none-any.whl → 0.1.37.dev1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +303 -303
- edsl/BaseDiff.py +260 -260
- edsl/TemplateLoader.py +24 -24
- edsl/__init__.py +48 -48
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +804 -804
- edsl/agents/AgentList.py +345 -337
- edsl/agents/Invigilator.py +222 -222
- edsl/agents/InvigilatorBase.py +305 -298
- edsl/agents/PromptConstructor.py +310 -320
- edsl/agents/__init__.py +3 -3
- edsl/agents/descriptors.py +86 -86
- edsl/agents/prompt_helpers.py +129 -129
- edsl/auto/AutoStudy.py +117 -117
- edsl/auto/StageBase.py +230 -230
- edsl/auto/StageGenerateSurvey.py +178 -178
- edsl/auto/StageLabelQuestions.py +125 -125
- edsl/auto/StagePersona.py +61 -61
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
- edsl/auto/StagePersonaDimensionValues.py +74 -74
- edsl/auto/StagePersonaDimensions.py +69 -69
- edsl/auto/StageQuestions.py +73 -73
- edsl/auto/SurveyCreatorPipeline.py +21 -21
- edsl/auto/utilities.py +224 -224
- edsl/base/Base.py +289 -289
- edsl/config.py +149 -149
- edsl/conjure/AgentConstructionMixin.py +152 -152
- edsl/conjure/Conjure.py +62 -62
- edsl/conjure/InputData.py +659 -659
- edsl/conjure/InputDataCSV.py +48 -48
- edsl/conjure/InputDataMixinQuestionStats.py +182 -182
- edsl/conjure/InputDataPyRead.py +91 -91
- edsl/conjure/InputDataSPSS.py +8 -8
- edsl/conjure/InputDataStata.py +8 -8
- edsl/conjure/QuestionOptionMixin.py +76 -76
- edsl/conjure/QuestionTypeMixin.py +23 -23
- edsl/conjure/RawQuestion.py +65 -65
- edsl/conjure/SurveyResponses.py +7 -7
- edsl/conjure/__init__.py +9 -9
- edsl/conjure/naming_utilities.py +263 -263
- edsl/conjure/utilities.py +201 -201
- edsl/conversation/Conversation.py +238 -238
- edsl/conversation/car_buying.py +58 -58
- edsl/conversation/mug_negotiation.py +81 -81
- edsl/conversation/next_speaker_utilities.py +93 -93
- edsl/coop/PriceFetcher.py +54 -54
- edsl/coop/__init__.py +2 -2
- edsl/coop/coop.py +824 -849
- edsl/coop/utils.py +131 -131
- edsl/data/Cache.py +527 -527
- edsl/data/CacheEntry.py +228 -228
- edsl/data/CacheHandler.py +149 -149
- edsl/data/RemoteCacheSync.py +97 -84
- edsl/data/SQLiteDict.py +292 -292
- edsl/data/__init__.py +4 -4
- edsl/data/orm.py +10 -10
- edsl/data_transfer_models.py +73 -73
- edsl/enums.py +173 -173
- edsl/exceptions/__init__.py +50 -50
- edsl/exceptions/agents.py +40 -40
- edsl/exceptions/configuration.py +16 -16
- edsl/exceptions/coop.py +10 -10
- edsl/exceptions/data.py +14 -14
- edsl/exceptions/general.py +34 -34
- edsl/exceptions/jobs.py +33 -33
- edsl/exceptions/language_models.py +63 -63
- edsl/exceptions/prompts.py +15 -15
- edsl/exceptions/questions.py +91 -91
- edsl/exceptions/results.py +26 -26
- edsl/exceptions/surveys.py +34 -34
- edsl/inference_services/AnthropicService.py +87 -87
- edsl/inference_services/AwsBedrock.py +115 -115
- edsl/inference_services/AzureAI.py +217 -217
- edsl/inference_services/DeepInfraService.py +18 -18
- edsl/inference_services/GoogleService.py +156 -156
- edsl/inference_services/GroqService.py +20 -20
- edsl/inference_services/InferenceServiceABC.py +147 -147
- edsl/inference_services/InferenceServicesCollection.py +74 -74
- edsl/inference_services/MistralAIService.py +123 -123
- edsl/inference_services/OllamaService.py +18 -18
- edsl/inference_services/OpenAIService.py +224 -224
- edsl/inference_services/TestService.py +89 -89
- edsl/inference_services/TogetherAIService.py +170 -170
- edsl/inference_services/models_available_cache.py +118 -118
- edsl/inference_services/rate_limits_cache.py +25 -25
- edsl/inference_services/registry.py +39 -39
- edsl/inference_services/write_available.py +10 -10
- edsl/jobs/Answers.py +56 -56
- edsl/jobs/Jobs.py +1112 -1112
- edsl/jobs/__init__.py +1 -1
- edsl/jobs/buckets/BucketCollection.py +63 -63
- edsl/jobs/buckets/ModelBuckets.py +65 -65
- edsl/jobs/buckets/TokenBucket.py +248 -248
- edsl/jobs/interviews/Interview.py +661 -661
- edsl/jobs/interviews/InterviewExceptionCollection.py +99 -99
- edsl/jobs/interviews/InterviewExceptionEntry.py +182 -189
- edsl/jobs/interviews/InterviewStatistic.py +63 -63
- edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
- edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
- edsl/jobs/interviews/InterviewStatusLog.py +92 -92
- edsl/jobs/interviews/ReportErrors.py +66 -66
- edsl/jobs/interviews/interview_status_enum.py +9 -9
- edsl/jobs/runners/JobsRunnerAsyncio.py +338 -337
- edsl/jobs/runners/JobsRunnerStatus.py +332 -332
- edsl/jobs/tasks/QuestionTaskCreator.py +242 -242
- edsl/jobs/tasks/TaskCreators.py +64 -64
- edsl/jobs/tasks/TaskHistory.py +441 -441
- edsl/jobs/tasks/TaskStatusLog.py +23 -23
- edsl/jobs/tasks/task_status_enum.py +163 -163
- edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
- edsl/jobs/tokens/TokenUsage.py +34 -34
- edsl/language_models/LanguageModel.py +718 -718
- edsl/language_models/ModelList.py +102 -102
- edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
- edsl/language_models/__init__.py +2 -2
- edsl/language_models/fake_openai_call.py +15 -15
- edsl/language_models/fake_openai_service.py +61 -61
- edsl/language_models/registry.py +137 -137
- edsl/language_models/repair.py +156 -156
- edsl/language_models/unused/ReplicateBase.py +83 -83
- edsl/language_models/utilities.py +64 -64
- edsl/notebooks/Notebook.py +259 -259
- edsl/notebooks/__init__.py +1 -1
- edsl/prompts/Prompt.py +350 -358
- edsl/prompts/__init__.py +2 -2
- edsl/questions/AnswerValidatorMixin.py +289 -289
- edsl/questions/QuestionBase.py +616 -616
- edsl/questions/QuestionBaseGenMixin.py +161 -161
- edsl/questions/QuestionBasePromptsMixin.py +266 -266
- edsl/questions/QuestionBudget.py +227 -227
- edsl/questions/QuestionCheckBox.py +359 -359
- edsl/questions/QuestionExtract.py +183 -183
- edsl/questions/QuestionFreeText.py +113 -113
- edsl/questions/QuestionFunctional.py +159 -159
- edsl/questions/QuestionList.py +231 -231
- edsl/questions/QuestionMultipleChoice.py +286 -286
- edsl/questions/QuestionNumerical.py +153 -153
- edsl/questions/QuestionRank.py +324 -324
- edsl/questions/Quick.py +41 -41
- edsl/questions/RegisterQuestionsMeta.py +71 -71
- edsl/questions/ResponseValidatorABC.py +174 -174
- edsl/questions/SimpleAskMixin.py +73 -73
- edsl/questions/__init__.py +26 -26
- edsl/questions/compose_questions.py +98 -98
- edsl/questions/decorators.py +21 -21
- edsl/questions/derived/QuestionLikertFive.py +76 -76
- edsl/questions/derived/QuestionLinearScale.py +87 -87
- edsl/questions/derived/QuestionTopK.py +91 -91
- edsl/questions/derived/QuestionYesNo.py +82 -82
- edsl/questions/descriptors.py +418 -418
- edsl/questions/prompt_templates/question_budget.jinja +13 -13
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
- edsl/questions/prompt_templates/question_extract.jinja +11 -11
- edsl/questions/prompt_templates/question_free_text.jinja +3 -3
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
- edsl/questions/prompt_templates/question_list.jinja +17 -17
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
- edsl/questions/prompt_templates/question_numerical.jinja +36 -36
- edsl/questions/question_registry.py +147 -147
- edsl/questions/settings.py +12 -12
- edsl/questions/templates/budget/answering_instructions.jinja +7 -7
- edsl/questions/templates/budget/question_presentation.jinja +7 -7
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
- edsl/questions/templates/extract/answering_instructions.jinja +7 -7
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
- edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
- edsl/questions/templates/list/answering_instructions.jinja +3 -3
- edsl/questions/templates/list/question_presentation.jinja +5 -5
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
- edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
- edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
- edsl/questions/templates/numerical/question_presentation.jinja +6 -6
- edsl/questions/templates/rank/answering_instructions.jinja +11 -11
- edsl/questions/templates/rank/question_presentation.jinja +15 -15
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
- edsl/questions/templates/top_k/question_presentation.jinja +22 -22
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
- edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
- edsl/results/Dataset.py +293 -293
- edsl/results/DatasetExportMixin.py +693 -693
- edsl/results/DatasetTree.py +145 -145
- edsl/results/Result.py +435 -433
- edsl/results/Results.py +1160 -1158
- edsl/results/ResultsDBMixin.py +238 -238
- edsl/results/ResultsExportMixin.py +43 -43
- edsl/results/ResultsFetchMixin.py +33 -33
- edsl/results/ResultsGGMixin.py +121 -121
- edsl/results/ResultsToolsMixin.py +98 -98
- edsl/results/Selector.py +118 -118
- edsl/results/__init__.py +2 -2
- edsl/results/tree_explore.py +115 -115
- edsl/scenarios/FileStore.py +458 -458
- edsl/scenarios/Scenario.py +510 -510
- edsl/scenarios/ScenarioHtmlMixin.py +59 -59
- edsl/scenarios/ScenarioList.py +1101 -1101
- edsl/scenarios/ScenarioListExportMixin.py +52 -52
- edsl/scenarios/ScenarioListPdfMixin.py +261 -261
- edsl/scenarios/__init__.py +4 -4
- edsl/shared.py +1 -1
- edsl/study/ObjectEntry.py +173 -173
- edsl/study/ProofOfWork.py +113 -113
- edsl/study/SnapShot.py +80 -80
- edsl/study/Study.py +528 -528
- edsl/study/__init__.py +4 -4
- edsl/surveys/DAG.py +148 -148
- edsl/surveys/Memory.py +31 -31
- edsl/surveys/MemoryPlan.py +244 -244
- edsl/surveys/Rule.py +324 -324
- edsl/surveys/RuleCollection.py +387 -387
- edsl/surveys/Survey.py +1772 -1772
- edsl/surveys/SurveyCSS.py +261 -261
- edsl/surveys/SurveyExportMixin.py +259 -259
- edsl/surveys/SurveyFlowVisualizationMixin.py +121 -121
- edsl/surveys/SurveyQualtricsImport.py +284 -284
- edsl/surveys/__init__.py +3 -3
- edsl/surveys/base.py +53 -53
- edsl/surveys/descriptors.py +56 -56
- edsl/surveys/instructions/ChangeInstruction.py +47 -47
- edsl/surveys/instructions/Instruction.py +51 -51
- edsl/surveys/instructions/InstructionCollection.py +77 -77
- edsl/templates/error_reporting/base.html +23 -23
- edsl/templates/error_reporting/exceptions_by_model.html +34 -34
- edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
- edsl/templates/error_reporting/exceptions_by_type.html +16 -16
- edsl/templates/error_reporting/interview_details.html +115 -115
- edsl/templates/error_reporting/interviews.html +9 -9
- edsl/templates/error_reporting/overview.html +4 -4
- edsl/templates/error_reporting/performance_plot.html +1 -1
- edsl/templates/error_reporting/report.css +73 -73
- edsl/templates/error_reporting/report.html +117 -117
- edsl/templates/error_reporting/report.js +25 -25
- edsl/tools/__init__.py +1 -1
- edsl/tools/clusters.py +192 -192
- edsl/tools/embeddings.py +27 -27
- edsl/tools/embeddings_plotting.py +118 -118
- edsl/tools/plotting.py +112 -112
- edsl/tools/summarize.py +18 -18
- edsl/utilities/SystemInfo.py +28 -28
- edsl/utilities/__init__.py +22 -22
- edsl/utilities/ast_utilities.py +25 -25
- edsl/utilities/data/Registry.py +6 -6
- edsl/utilities/data/__init__.py +1 -1
- edsl/utilities/data/scooter_results.json +1 -1
- edsl/utilities/decorators.py +77 -77
- edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
- edsl/utilities/interface.py +627 -627
- edsl/utilities/repair_functions.py +28 -28
- edsl/utilities/restricted_python.py +70 -70
- edsl/utilities/utilities.py +391 -391
- {edsl-0.1.36.dev7.dist-info → edsl-0.1.37.dev1.dist-info}/LICENSE +21 -21
- {edsl-0.1.36.dev7.dist-info → edsl-0.1.37.dev1.dist-info}/METADATA +1 -1
- edsl-0.1.37.dev1.dist-info/RECORD +279 -0
- edsl-0.1.36.dev7.dist-info/RECORD +0 -279
- {edsl-0.1.36.dev7.dist-info → edsl-0.1.37.dev1.dist-info}/WHEEL +0 -0
edsl/study/__init__.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
from edsl.study.ObjectEntry import ObjectEntry
|
2
|
-
from edsl.study.ProofOfWork import ProofOfWork
|
3
|
-
from edsl.study.SnapShot import SnapShot
|
4
|
-
from edsl.study.Study import Study
|
1
|
+
from edsl.study.ObjectEntry import ObjectEntry
|
2
|
+
from edsl.study.ProofOfWork import ProofOfWork
|
3
|
+
from edsl.study.SnapShot import SnapShot
|
4
|
+
from edsl.study.Study import Study
|
edsl/surveys/DAG.py
CHANGED
@@ -1,148 +1,148 @@
|
|
1
|
-
"""Directed Acyclic Graph (DAG) class."""
|
2
|
-
|
3
|
-
from collections import UserDict
|
4
|
-
from graphlib import TopologicalSorter
|
5
|
-
|
6
|
-
|
7
|
-
class DAG(UserDict):
    """Directed acyclic graph stored as a mapping of node -> prerequisite nodes.

    Each key maps to an iterable of the nodes it depends on (its parents).
    This is the convention :class:`graphlib.TopologicalSorter` uses, so
    :meth:`topologically_sorted_nodes` lists every node after the nodes it
    depends on. Graphs containing a cycle are rejected at construction time.

    >>> dag = DAG.example()
    >>> dag.topologically_sorted_nodes()
    ['c', 'd', 'b', 'a']
    """

    def __init__(self, data: dict):
        """Initialize the DAG and reject it if it contains a cycle.

        :param data: Mapping of node -> iterable of the nodes it depends on.
        :raises ValueError: If the graph contains a cycle.
        """
        super().__init__(data)
        self.reverse_mapping = self._create_reverse_mapping()
        self.validate_no_cycles()

    def _create_reverse_mapping(self) -> dict:
        """Map each node to the set of nodes that depend on it (its children).

        Only nodes that appear as a dependency of some other node get an
        entry; nodes nothing depends on are absent from the result.

        >>> DAG.example()._create_reverse_mapping()
        {'b': {'a'}, 'c': {'a'}, 'd': {'b'}}
        """
        rev_map = {}
        for key, values in self.items():
            for value in values:
                rev_map.setdefault(value, set()).add(key)
        return rev_map

    def get_all_children(self, key) -> set:
        """Return every node that depends on ``key``, directly or transitively.

        An explicit stack is used instead of recursion so that very deep
        graphs cannot hit Python's recursion limit.

        >>> sorted(DAG.example().get_all_children("d"))
        ['a', 'b']
        """
        children = set()
        stack = [key]
        while stack:
            node = stack.pop()
            for child in self.reverse_mapping.get(node, ()):
                if child not in children:
                    children.add(child)
                    stack.append(child)
        return children

    def topologically_sorted_nodes(self) -> list:
        """Return all nodes ordered so that each node follows its dependencies.

        >>> DAG.example().topologically_sorted_nodes()
        ['c', 'd', 'b', 'a']
        """
        return list(TopologicalSorter(self).static_order())

    def __add__(self, other_dag):
        """Return a new DAG whose edges are the union of both graphs' edges.

        Dependency collections are converted to sets, so DAGs built from
        lists (such as :meth:`example`) can be combined as well. Nodes keep a
        deterministic order: those of ``self`` first, then nodes that only
        appear in ``other_dag``.

        :raises ValueError: If the combined graph contains a cycle.

        >>> combined = DAG.example() + DAG({"e": ["a"]})
        >>> combined["a"] == {"b", "c"}, combined["e"] == {"a"}
        (True, True)
        """
        combined = {}
        # dict.fromkeys de-duplicates while preserving first-seen order.
        for key in dict.fromkeys([*self.keys(), *other_dag.keys()]):
            combined[key] = set(self.get(key, ())) | set(other_dag.get(key, ()))
        return DAG(combined)

    def remove_node(self, node: int) -> None:
        """Remove ``node`` and every edge touching it, then renumber the rest.

        Assumes integer node labels whose dependency collections are sets
        (``discard`` is called on them). Every label greater than ``node`` is
        shifted down by one afterwards.
        """
        self.pop(node, None)
        for connections in self.values():
            connections.discard(node)
        # Renumber the remaining nodes; this also rebuilds reverse_mapping.
        self._adjust_nodes_after_removal(node)

    def _adjust_nodes_after_removal(self, removed_node: int) -> None:
        """Shift node labels that are not below ``removed_node`` down by one.

        Also rebuilds :attr:`reverse_mapping`, which would otherwise keep
        describing the graph as it was before the renumbering.
        """
        new_dag = {}
        for node, connections in self.items():
            new_node = node if node < removed_node else node - 1
            new_dag[new_node] = {c if c < removed_node else c - 1 for c in connections}
        self.clear()
        self.update(new_dag)
        self.reverse_mapping = self._create_reverse_mapping()

    @classmethod
    def example(cls):
        """Return an example of the `DAG`."""
        data = {"a": ["b", "c"], "b": ["d"], "c": [], "d": []}
        return cls(data)

    def detect_cycles(self) -> list:
        """Detect cycles in the graph using depth-first search.

        :return: A list of cycles, each the list of nodes along the cycle with
            the starting node repeated at the end; empty if the graph is acyclic.

        >>> DAG.example().detect_cycles()
        []
        """
        visited = set()
        path = []
        on_path = set()  # mirrors ``path`` for O(1) membership checks
        cycles = []

        def dfs(node):
            if node in on_path:
                cycles.append(path[path.index(node):] + [node])
                return
            if node in visited:
                return
            visited.add(node)
            path.append(node)
            on_path.add(node)
            for child in self.get(node, []):
                dfs(child)
            path.pop()
            on_path.discard(node)

        for node in self:
            if node not in visited:
                dfs(node)

        return cycles

    def validate_no_cycles(self) -> None:
        """
        Validate that the DAG does not contain any cycles.

        :raises ValueError: If cycles are detected in the DAG.
        """
        cycles = self.detect_cycles()
        if cycles:
            raise ValueError(f"Cycles detected in the DAG: {cycles}")
|
143
|
-
|
144
|
-
|
145
|
-
if __name__ == "__main__":
    # Running the module directly executes the docstring examples.
    from doctest import testmod

    testmod()
|
1
|
+
"""Directed Acyclic Graph (DAG) class."""
|
2
|
+
|
3
|
+
from collections import UserDict
|
4
|
+
from graphlib import TopologicalSorter
|
5
|
+
|
6
|
+
|
7
|
+
class DAG(UserDict):
    """Directed acyclic graph stored as a mapping of node -> prerequisite nodes.

    Each key maps to an iterable of the nodes it depends on (its parents).
    This is the convention :class:`graphlib.TopologicalSorter` uses, so
    :meth:`topologically_sorted_nodes` lists every node after the nodes it
    depends on. Graphs containing a cycle are rejected at construction time.

    >>> dag = DAG.example()
    >>> dag.topologically_sorted_nodes()
    ['c', 'd', 'b', 'a']
    """

    def __init__(self, data: dict):
        """Initialize the DAG and reject it if it contains a cycle.

        :param data: Mapping of node -> iterable of the nodes it depends on.
        :raises ValueError: If the graph contains a cycle.
        """
        super().__init__(data)
        self.reverse_mapping = self._create_reverse_mapping()
        self.validate_no_cycles()

    def _create_reverse_mapping(self) -> dict:
        """Map each node to the set of nodes that depend on it (its children).

        Only nodes that appear as a dependency of some other node get an
        entry; nodes nothing depends on are absent from the result.

        >>> DAG.example()._create_reverse_mapping()
        {'b': {'a'}, 'c': {'a'}, 'd': {'b'}}
        """
        rev_map = {}
        for key, values in self.items():
            for value in values:
                rev_map.setdefault(value, set()).add(key)
        return rev_map

    def get_all_children(self, key) -> set:
        """Return every node that depends on ``key``, directly or transitively.

        An explicit stack is used instead of recursion so that very deep
        graphs cannot hit Python's recursion limit.

        >>> sorted(DAG.example().get_all_children("d"))
        ['a', 'b']
        """
        children = set()
        stack = [key]
        while stack:
            node = stack.pop()
            for child in self.reverse_mapping.get(node, ()):
                if child not in children:
                    children.add(child)
                    stack.append(child)
        return children

    def topologically_sorted_nodes(self) -> list:
        """Return all nodes ordered so that each node follows its dependencies.

        >>> DAG.example().topologically_sorted_nodes()
        ['c', 'd', 'b', 'a']
        """
        return list(TopologicalSorter(self).static_order())

    def __add__(self, other_dag):
        """Return a new DAG whose edges are the union of both graphs' edges.

        Dependency collections are converted to sets, so DAGs built from
        lists (such as :meth:`example`) can be combined as well. Nodes keep a
        deterministic order: those of ``self`` first, then nodes that only
        appear in ``other_dag``.

        :raises ValueError: If the combined graph contains a cycle.

        >>> combined = DAG.example() + DAG({"e": ["a"]})
        >>> combined["a"] == {"b", "c"}, combined["e"] == {"a"}
        (True, True)
        """
        combined = {}
        # dict.fromkeys de-duplicates while preserving first-seen order.
        for key in dict.fromkeys([*self.keys(), *other_dag.keys()]):
            combined[key] = set(self.get(key, ())) | set(other_dag.get(key, ()))
        return DAG(combined)

    def remove_node(self, node: int) -> None:
        """Remove ``node`` and every edge touching it, then renumber the rest.

        Assumes integer node labels whose dependency collections are sets
        (``discard`` is called on them). Every label greater than ``node`` is
        shifted down by one afterwards.
        """
        self.pop(node, None)
        for connections in self.values():
            connections.discard(node)
        # Renumber the remaining nodes; this also rebuilds reverse_mapping.
        self._adjust_nodes_after_removal(node)

    def _adjust_nodes_after_removal(self, removed_node: int) -> None:
        """Shift node labels that are not below ``removed_node`` down by one.

        Also rebuilds :attr:`reverse_mapping`, which would otherwise keep
        describing the graph as it was before the renumbering.
        """
        new_dag = {}
        for node, connections in self.items():
            new_node = node if node < removed_node else node - 1
            new_dag[new_node] = {c if c < removed_node else c - 1 for c in connections}
        self.clear()
        self.update(new_dag)
        self.reverse_mapping = self._create_reverse_mapping()

    @classmethod
    def example(cls):
        """Return an example of the `DAG`."""
        data = {"a": ["b", "c"], "b": ["d"], "c": [], "d": []}
        return cls(data)

    def detect_cycles(self) -> list:
        """Detect cycles in the graph using depth-first search.

        :return: A list of cycles, each the list of nodes along the cycle with
            the starting node repeated at the end; empty if the graph is acyclic.

        >>> DAG.example().detect_cycles()
        []
        """
        visited = set()
        path = []
        on_path = set()  # mirrors ``path`` for O(1) membership checks
        cycles = []

        def dfs(node):
            if node in on_path:
                cycles.append(path[path.index(node):] + [node])
                return
            if node in visited:
                return
            visited.add(node)
            path.append(node)
            on_path.add(node)
            for child in self.get(node, []):
                dfs(child)
            path.pop()
            on_path.discard(node)

        for node in self:
            if node not in visited:
                dfs(node)

        return cycles

    def validate_no_cycles(self) -> None:
        """
        Validate that the DAG does not contain any cycles.

        :raises ValueError: If cycles are detected in the DAG.
        """
        cycles = self.detect_cycles()
        if cycles:
            raise ValueError(f"Cycles detected in the DAG: {cycles}")
|
143
|
+
|
144
|
+
|
145
|
+
if __name__ == "__main__":
    # Running the module directly executes the docstring examples.
    from doctest import testmod

    testmod()
|
edsl/surveys/Memory.py
CHANGED
@@ -1,31 +1,31 @@
|
|
1
|
-
"""This module contains the Memory class, which is a list of prior questions."""
|
2
|
-
|
3
|
-
from collections import UserList
|
4
|
-
|
5
|
-
|
6
|
-
class Memory(UserList):
    """List of the names of prior questions the agent should have available.

    The questions named here are the ones we want the agent to have
    available when answering a question. :meth:`add_prior_question` refuses
    to add a name that is already present.
    """

    def __init__(self, prior_questions: list[str] = None):
        """Initialize the Memory object.

        :param prior_questions: Initial question names. ``None`` (or any other
            falsy value) gives an empty memory. ``UserList`` copies the
            elements, so later changes to the caller's list are not reflected.
        """
        super().__init__(prior_questions or [])

    def add_prior_question(self, prior_question):
        """Append ``prior_question`` to the memory.

        :raises ValueError: If ``prior_question`` is already in the memory.
        """
        if prior_question in self:
            raise ValueError(f"{prior_question} is already in the memory.")
        self.append(prior_question)

    def __repr__(self):
        """Return a string representation of the Memory object."""
        return f"Memory(prior_questions={self.data})"

    def to_dict(self):
        """Create a dictionary representation of the Memory object.

        The list is copied so that mutating the returned dictionary cannot
        silently change this Memory.
        """
        return {"prior_questions": list(self.data)}

    @classmethod
    def from_dict(cls, data):
        """Create a Memory object from a dictionary produced by :meth:`to_dict`."""
        return cls(**data)
|
1
|
+
"""This module contains the Memory class, which is a list of prior questions."""
|
2
|
+
|
3
|
+
from collections import UserList
|
4
|
+
|
5
|
+
|
6
|
+
class Memory(UserList):
    """List of the names of prior questions the agent should have available.

    The questions named here are the ones we want the agent to have
    available when answering a question. :meth:`add_prior_question` refuses
    to add a name that is already present.
    """

    def __init__(self, prior_questions: list[str] = None):
        """Initialize the Memory object.

        :param prior_questions: Initial question names. ``None`` (or any other
            falsy value) gives an empty memory. ``UserList`` copies the
            elements, so later changes to the caller's list are not reflected.
        """
        super().__init__(prior_questions or [])

    def add_prior_question(self, prior_question):
        """Append ``prior_question`` to the memory.

        :raises ValueError: If ``prior_question`` is already in the memory.
        """
        if prior_question in self:
            raise ValueError(f"{prior_question} is already in the memory.")
        self.append(prior_question)

    def __repr__(self):
        """Return a string representation of the Memory object."""
        return f"Memory(prior_questions={self.data})"

    def to_dict(self):
        """Create a dictionary representation of the Memory object.

        The list is copied so that mutating the returned dictionary cannot
        silently change this Memory.
        """
        return {"prior_questions": list(self.data)}

    @classmethod
    def from_dict(cls, data):
        """Create a Memory object from a dictionary produced by :meth:`to_dict`."""
        return cls(**data)
|