edsl 0.1.36.dev6__py3-none-any.whl → 0.1.37.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +303 -303
- edsl/BaseDiff.py +260 -260
- edsl/TemplateLoader.py +24 -24
- edsl/__init__.py +48 -47
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +804 -804
- edsl/agents/AgentList.py +345 -337
- edsl/agents/Invigilator.py +222 -222
- edsl/agents/InvigilatorBase.py +305 -294
- edsl/agents/PromptConstructor.py +312 -312
- edsl/agents/__init__.py +3 -3
- edsl/agents/descriptors.py +86 -86
- edsl/agents/prompt_helpers.py +129 -129
- edsl/auto/AutoStudy.py +117 -117
- edsl/auto/StageBase.py +230 -230
- edsl/auto/StageGenerateSurvey.py +178 -178
- edsl/auto/StageLabelQuestions.py +125 -125
- edsl/auto/StagePersona.py +61 -61
- edsl/auto/StagePersonaDimensionValueRanges.py +88 -88
- edsl/auto/StagePersonaDimensionValues.py +74 -74
- edsl/auto/StagePersonaDimensions.py +69 -69
- edsl/auto/StageQuestions.py +73 -73
- edsl/auto/SurveyCreatorPipeline.py +21 -21
- edsl/auto/utilities.py +224 -224
- edsl/base/Base.py +289 -289
- edsl/config.py +149 -149
- edsl/conjure/AgentConstructionMixin.py +152 -152
- edsl/conjure/Conjure.py +62 -62
- edsl/conjure/InputData.py +659 -659
- edsl/conjure/InputDataCSV.py +48 -48
- edsl/conjure/InputDataMixinQuestionStats.py +182 -182
- edsl/conjure/InputDataPyRead.py +91 -91
- edsl/conjure/InputDataSPSS.py +8 -8
- edsl/conjure/InputDataStata.py +8 -8
- edsl/conjure/QuestionOptionMixin.py +76 -76
- edsl/conjure/QuestionTypeMixin.py +23 -23
- edsl/conjure/RawQuestion.py +65 -65
- edsl/conjure/SurveyResponses.py +7 -7
- edsl/conjure/__init__.py +9 -9
- edsl/conjure/naming_utilities.py +263 -263
- edsl/conjure/utilities.py +201 -201
- edsl/conversation/Conversation.py +238 -238
- edsl/conversation/car_buying.py +58 -58
- edsl/conversation/mug_negotiation.py +81 -81
- edsl/conversation/next_speaker_utilities.py +93 -93
- edsl/coop/PriceFetcher.py +54 -54
- edsl/coop/__init__.py +2 -2
- edsl/coop/coop.py +824 -849
- edsl/coop/utils.py +131 -131
- edsl/data/Cache.py +527 -527
- edsl/data/CacheEntry.py +228 -228
- edsl/data/CacheHandler.py +149 -149
- edsl/data/RemoteCacheSync.py +97 -84
- edsl/data/SQLiteDict.py +292 -292
- edsl/data/__init__.py +4 -4
- edsl/data/orm.py +10 -10
- edsl/data_transfer_models.py +73 -73
- edsl/enums.py +173 -173
- edsl/exceptions/__init__.py +50 -50
- edsl/exceptions/agents.py +40 -40
- edsl/exceptions/configuration.py +16 -16
- edsl/exceptions/coop.py +10 -10
- edsl/exceptions/data.py +14 -14
- edsl/exceptions/general.py +34 -34
- edsl/exceptions/jobs.py +33 -33
- edsl/exceptions/language_models.py +63 -63
- edsl/exceptions/prompts.py +15 -15
- edsl/exceptions/questions.py +91 -91
- edsl/exceptions/results.py +26 -26
- edsl/exceptions/surveys.py +34 -34
- edsl/inference_services/AnthropicService.py +87 -87
- edsl/inference_services/AwsBedrock.py +115 -115
- edsl/inference_services/AzureAI.py +217 -217
- edsl/inference_services/DeepInfraService.py +18 -18
- edsl/inference_services/GoogleService.py +156 -156
- edsl/inference_services/GroqService.py +20 -20
- edsl/inference_services/InferenceServiceABC.py +147 -147
- edsl/inference_services/InferenceServicesCollection.py +74 -72
- edsl/inference_services/MistralAIService.py +123 -123
- edsl/inference_services/OllamaService.py +18 -18
- edsl/inference_services/OpenAIService.py +224 -224
- edsl/inference_services/TestService.py +89 -89
- edsl/inference_services/TogetherAIService.py +170 -170
- edsl/inference_services/models_available_cache.py +118 -118
- edsl/inference_services/rate_limits_cache.py +25 -25
- edsl/inference_services/registry.py +39 -39
- edsl/inference_services/write_available.py +10 -10
- edsl/jobs/Answers.py +56 -56
- edsl/jobs/Jobs.py +1112 -1112
- edsl/jobs/__init__.py +1 -1
- edsl/jobs/buckets/BucketCollection.py +63 -63
- edsl/jobs/buckets/ModelBuckets.py +65 -65
- edsl/jobs/buckets/TokenBucket.py +248 -248
- edsl/jobs/interviews/Interview.py +661 -651
- edsl/jobs/interviews/InterviewExceptionCollection.py +99 -99
- edsl/jobs/interviews/InterviewExceptionEntry.py +182 -182
- edsl/jobs/interviews/InterviewStatistic.py +63 -63
- edsl/jobs/interviews/InterviewStatisticsCollection.py +25 -25
- edsl/jobs/interviews/InterviewStatusDictionary.py +78 -78
- edsl/jobs/interviews/InterviewStatusLog.py +92 -92
- edsl/jobs/interviews/ReportErrors.py +66 -66
- edsl/jobs/interviews/interview_status_enum.py +9 -9
- edsl/jobs/runners/JobsRunnerAsyncio.py +338 -337
- edsl/jobs/runners/JobsRunnerStatus.py +332 -332
- edsl/jobs/tasks/QuestionTaskCreator.py +242 -242
- edsl/jobs/tasks/TaskCreators.py +64 -64
- edsl/jobs/tasks/TaskHistory.py +441 -441
- edsl/jobs/tasks/TaskStatusLog.py +23 -23
- edsl/jobs/tasks/task_status_enum.py +163 -163
- edsl/jobs/tokens/InterviewTokenUsage.py +27 -27
- edsl/jobs/tokens/TokenUsage.py +34 -34
- edsl/language_models/LanguageModel.py +718 -718
- edsl/language_models/ModelList.py +102 -102
- edsl/language_models/RegisterLanguageModelsMeta.py +184 -184
- edsl/language_models/__init__.py +2 -2
- edsl/language_models/fake_openai_call.py +15 -15
- edsl/language_models/fake_openai_service.py +61 -61
- edsl/language_models/registry.py +137 -137
- edsl/language_models/repair.py +156 -156
- edsl/language_models/unused/ReplicateBase.py +83 -83
- edsl/language_models/utilities.py +64 -64
- edsl/notebooks/Notebook.py +259 -259
- edsl/notebooks/__init__.py +1 -1
- edsl/prompts/Prompt.py +353 -358
- edsl/prompts/__init__.py +2 -2
- edsl/questions/AnswerValidatorMixin.py +289 -289
- edsl/questions/QuestionBase.py +616 -616
- edsl/questions/QuestionBaseGenMixin.py +161 -161
- edsl/questions/QuestionBasePromptsMixin.py +266 -266
- edsl/questions/QuestionBudget.py +227 -227
- edsl/questions/QuestionCheckBox.py +359 -359
- edsl/questions/QuestionExtract.py +183 -183
- edsl/questions/QuestionFreeText.py +114 -113
- edsl/questions/QuestionFunctional.py +159 -159
- edsl/questions/QuestionList.py +231 -231
- edsl/questions/QuestionMultipleChoice.py +286 -286
- edsl/questions/QuestionNumerical.py +153 -153
- edsl/questions/QuestionRank.py +324 -324
- edsl/questions/Quick.py +41 -41
- edsl/questions/RegisterQuestionsMeta.py +71 -71
- edsl/questions/ResponseValidatorABC.py +174 -174
- edsl/questions/SimpleAskMixin.py +73 -73
- edsl/questions/__init__.py +26 -26
- edsl/questions/compose_questions.py +98 -98
- edsl/questions/decorators.py +21 -21
- edsl/questions/derived/QuestionLikertFive.py +76 -76
- edsl/questions/derived/QuestionLinearScale.py +87 -87
- edsl/questions/derived/QuestionTopK.py +91 -91
- edsl/questions/derived/QuestionYesNo.py +82 -82
- edsl/questions/descriptors.py +418 -418
- edsl/questions/prompt_templates/question_budget.jinja +13 -13
- edsl/questions/prompt_templates/question_checkbox.jinja +32 -32
- edsl/questions/prompt_templates/question_extract.jinja +11 -11
- edsl/questions/prompt_templates/question_free_text.jinja +3 -3
- edsl/questions/prompt_templates/question_linear_scale.jinja +11 -11
- edsl/questions/prompt_templates/question_list.jinja +17 -17
- edsl/questions/prompt_templates/question_multiple_choice.jinja +33 -33
- edsl/questions/prompt_templates/question_numerical.jinja +36 -36
- edsl/questions/question_registry.py +147 -147
- edsl/questions/settings.py +12 -12
- edsl/questions/templates/budget/answering_instructions.jinja +7 -7
- edsl/questions/templates/budget/question_presentation.jinja +7 -7
- edsl/questions/templates/checkbox/answering_instructions.jinja +10 -10
- edsl/questions/templates/checkbox/question_presentation.jinja +22 -22
- edsl/questions/templates/extract/answering_instructions.jinja +7 -7
- edsl/questions/templates/likert_five/answering_instructions.jinja +10 -10
- edsl/questions/templates/likert_five/question_presentation.jinja +11 -11
- edsl/questions/templates/linear_scale/answering_instructions.jinja +5 -5
- edsl/questions/templates/linear_scale/question_presentation.jinja +5 -5
- edsl/questions/templates/list/answering_instructions.jinja +3 -3
- edsl/questions/templates/list/question_presentation.jinja +5 -5
- edsl/questions/templates/multiple_choice/answering_instructions.jinja +9 -9
- edsl/questions/templates/multiple_choice/question_presentation.jinja +11 -11
- edsl/questions/templates/numerical/answering_instructions.jinja +6 -6
- edsl/questions/templates/numerical/question_presentation.jinja +6 -6
- edsl/questions/templates/rank/answering_instructions.jinja +11 -11
- edsl/questions/templates/rank/question_presentation.jinja +15 -15
- edsl/questions/templates/top_k/answering_instructions.jinja +8 -8
- edsl/questions/templates/top_k/question_presentation.jinja +22 -22
- edsl/questions/templates/yes_no/answering_instructions.jinja +6 -6
- edsl/questions/templates/yes_no/question_presentation.jinja +11 -11
- edsl/results/Dataset.py +293 -293
- edsl/results/DatasetExportMixin.py +693 -693
- edsl/results/DatasetTree.py +145 -145
- edsl/results/Result.py +435 -433
- edsl/results/Results.py +1160 -1158
- edsl/results/ResultsDBMixin.py +238 -238
- edsl/results/ResultsExportMixin.py +43 -43
- edsl/results/ResultsFetchMixin.py +33 -33
- edsl/results/ResultsGGMixin.py +121 -121
- edsl/results/ResultsToolsMixin.py +98 -98
- edsl/results/Selector.py +118 -118
- edsl/results/__init__.py +2 -2
- edsl/results/tree_explore.py +115 -115
- edsl/scenarios/FileStore.py +458 -443
- edsl/scenarios/Scenario.py +510 -507
- edsl/scenarios/ScenarioHtmlMixin.py +59 -59
- edsl/scenarios/ScenarioList.py +1101 -1101
- edsl/scenarios/ScenarioListExportMixin.py +52 -52
- edsl/scenarios/ScenarioListPdfMixin.py +261 -261
- edsl/scenarios/__init__.py +4 -2
- edsl/shared.py +1 -1
- edsl/study/ObjectEntry.py +173 -173
- edsl/study/ProofOfWork.py +113 -113
- edsl/study/SnapShot.py +80 -80
- edsl/study/Study.py +528 -528
- edsl/study/__init__.py +4 -4
- edsl/surveys/DAG.py +148 -148
- edsl/surveys/Memory.py +31 -31
- edsl/surveys/MemoryPlan.py +244 -244
- edsl/surveys/Rule.py +324 -324
- edsl/surveys/RuleCollection.py +387 -387
- edsl/surveys/Survey.py +1772 -1772
- edsl/surveys/SurveyCSS.py +261 -261
- edsl/surveys/SurveyExportMixin.py +259 -259
- edsl/surveys/SurveyFlowVisualizationMixin.py +121 -121
- edsl/surveys/SurveyQualtricsImport.py +284 -284
- edsl/surveys/__init__.py +3 -3
- edsl/surveys/base.py +53 -53
- edsl/surveys/descriptors.py +56 -56
- edsl/surveys/instructions/ChangeInstruction.py +47 -47
- edsl/surveys/instructions/Instruction.py +51 -51
- edsl/surveys/instructions/InstructionCollection.py +77 -77
- edsl/templates/error_reporting/base.html +23 -23
- edsl/templates/error_reporting/exceptions_by_model.html +34 -34
- edsl/templates/error_reporting/exceptions_by_question_name.html +16 -16
- edsl/templates/error_reporting/exceptions_by_type.html +16 -16
- edsl/templates/error_reporting/interview_details.html +115 -115
- edsl/templates/error_reporting/interviews.html +9 -9
- edsl/templates/error_reporting/overview.html +4 -4
- edsl/templates/error_reporting/performance_plot.html +1 -1
- edsl/templates/error_reporting/report.css +73 -73
- edsl/templates/error_reporting/report.html +117 -117
- edsl/templates/error_reporting/report.js +25 -25
- edsl/tools/__init__.py +1 -1
- edsl/tools/clusters.py +192 -192
- edsl/tools/embeddings.py +27 -27
- edsl/tools/embeddings_plotting.py +118 -118
- edsl/tools/plotting.py +112 -112
- edsl/tools/summarize.py +18 -18
- edsl/utilities/SystemInfo.py +28 -28
- edsl/utilities/__init__.py +22 -22
- edsl/utilities/ast_utilities.py +25 -25
- edsl/utilities/data/Registry.py +6 -6
- edsl/utilities/data/__init__.py +1 -1
- edsl/utilities/data/scooter_results.json +1 -1
- edsl/utilities/decorators.py +77 -77
- edsl/utilities/gcp_bucket/cloud_storage.py +96 -96
- edsl/utilities/interface.py +627 -627
- edsl/utilities/repair_functions.py +28 -28
- edsl/utilities/restricted_python.py +70 -70
- edsl/utilities/utilities.py +391 -391
- {edsl-0.1.36.dev6.dist-info → edsl-0.1.37.dev2.dist-info}/LICENSE +21 -21
- {edsl-0.1.36.dev6.dist-info → edsl-0.1.37.dev2.dist-info}/METADATA +1 -1
- edsl-0.1.37.dev2.dist-info/RECORD +279 -0
- edsl-0.1.36.dev6.dist-info/RECORD +0 -279
- {edsl-0.1.36.dev6.dist-info → edsl-0.1.37.dev2.dist-info}/WHEEL +0 -0
edsl/data/Cache.py
CHANGED
@@ -1,527 +1,527 @@
|
|
1
|
-
"""
|
2
|
-
The `Cache` class is used to store responses from a language model.
|
3
|
-
"""
|
4
|
-
|
5
|
-
from __future__ import annotations
|
6
|
-
import json
|
7
|
-
import os
|
8
|
-
import warnings
|
9
|
-
import copy
|
10
|
-
from typing import Optional, Union
|
11
|
-
from edsl.Base import Base
|
12
|
-
from edsl.data.CacheEntry import CacheEntry
|
13
|
-
from edsl.utilities.utilities import dict_hash
|
14
|
-
from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
|
15
|
-
|
16
|
-
|
17
|
-
class Cache(Base):
|
18
|
-
"""
|
19
|
-
A class that represents a cache of responses from a language model.
|
20
|
-
|
21
|
-
:param data: The data to initialize the cache with.
|
22
|
-
:param immediate_write: Whether to write to the cache immediately after storing a new entry.
|
23
|
-
|
24
|
-
Deprecated:
|
25
|
-
|
26
|
-
:param method: The method of storage to use for the cache.
|
27
|
-
"""
|
28
|
-
|
29
|
-
data = {}
|
30
|
-
|
31
|
-
def __init__(
|
32
|
-
self,
|
33
|
-
*,
|
34
|
-
filename: Optional[str] = None,
|
35
|
-
data: Optional[Union["SQLiteDict", dict]] = None,
|
36
|
-
immediate_write: bool = True,
|
37
|
-
method=None,
|
38
|
-
verbose=False,
|
39
|
-
):
|
40
|
-
"""
|
41
|
-
Create two dictionaries to store the cache data.
|
42
|
-
|
43
|
-
:param filename: The name of the file to read/write the cache from/to.
|
44
|
-
:param data: The data to initialize the cache with.
|
45
|
-
:param immediate_write: Whether to write to the cache immediately after storing a new entry.
|
46
|
-
:param method: The method of storage to use for the cache.
|
47
|
-
|
48
|
-
"""
|
49
|
-
|
50
|
-
# self.data_at_init = data or {}
|
51
|
-
self.fetched_data = {}
|
52
|
-
self.immediate_write = immediate_write
|
53
|
-
self.method = method
|
54
|
-
self.new_entries = {}
|
55
|
-
self.new_entries_to_write_later = {}
|
56
|
-
self.coop = None
|
57
|
-
self.verbose = verbose
|
58
|
-
|
59
|
-
self.filename = filename
|
60
|
-
if filename and data:
|
61
|
-
raise ValueError("Cannot provide both filename and data")
|
62
|
-
if filename is None and data is None:
|
63
|
-
data = {}
|
64
|
-
if data is not None:
|
65
|
-
self.data = data
|
66
|
-
if filename is not None:
|
67
|
-
self.data = {}
|
68
|
-
if filename.endswith(".jsonl"):
|
69
|
-
if os.path.exists(filename):
|
70
|
-
self.add_from_jsonl(filename)
|
71
|
-
else:
|
72
|
-
print(
|
73
|
-
f"File {filename} not found, but will write to this location."
|
74
|
-
)
|
75
|
-
elif filename.endswith(".db"):
|
76
|
-
if os.path.exists(filename):
|
77
|
-
self.add_from_sqlite(filename)
|
78
|
-
else:
|
79
|
-
raise ValueError("Invalid file extension. Must be .jsonl or .db")
|
80
|
-
|
81
|
-
self._perform_checks()
|
82
|
-
|
83
|
-
def rich_print(sefl):
|
84
|
-
pass
|
85
|
-
# raise NotImplementedError("This method is not implemented yet.")
|
86
|
-
|
87
|
-
def code(sefl):
|
88
|
-
pass
|
89
|
-
# raise NotImplementedError("This method is not implemented yet.")
|
90
|
-
|
91
|
-
def keys(self):
|
92
|
-
"""
|
93
|
-
>>> from edsl import Cache
|
94
|
-
>>> Cache.example().keys()
|
95
|
-
['5513286eb6967abc0511211f0402587d']
|
96
|
-
"""
|
97
|
-
return list(self.data.keys())
|
98
|
-
|
99
|
-
def values(self):
|
100
|
-
"""
|
101
|
-
>>> from edsl import Cache
|
102
|
-
>>> Cache.example().values()
|
103
|
-
[CacheEntry(...)]
|
104
|
-
"""
|
105
|
-
return list(self.data.values())
|
106
|
-
|
107
|
-
def items(self):
|
108
|
-
return zip(self.keys(), self.values())
|
109
|
-
|
110
|
-
def new_entries_cache(self) -> Cache:
|
111
|
-
"""Return a new Cache object with the new entries."""
|
112
|
-
return Cache(data={**self.new_entries, **self.fetched_data})
|
113
|
-
|
114
|
-
def _perform_checks(self):
|
115
|
-
"""Perform checks on the cache."""
|
116
|
-
from edsl.data.CacheEntry import CacheEntry
|
117
|
-
|
118
|
-
if any(not isinstance(value, CacheEntry) for value in self.data.values()):
|
119
|
-
raise Exception("Not all values are CacheEntry instances")
|
120
|
-
if self.method is not None:
|
121
|
-
warnings.warn("Argument `method` is deprecated", DeprecationWarning)
|
122
|
-
|
123
|
-
####################
|
124
|
-
# READ/WRITE
|
125
|
-
####################
|
126
|
-
def fetch(
|
127
|
-
self,
|
128
|
-
*,
|
129
|
-
model: str,
|
130
|
-
parameters: dict,
|
131
|
-
system_prompt: str,
|
132
|
-
user_prompt: str,
|
133
|
-
iteration: int,
|
134
|
-
) -> tuple(Union[None, str], str):
|
135
|
-
"""
|
136
|
-
Fetch a value (LLM output) from the cache.
|
137
|
-
|
138
|
-
:param model: The name of the language model.
|
139
|
-
:param parameters: The model parameters.
|
140
|
-
:param system_prompt: The system prompt.
|
141
|
-
:param user_prompt: The user prompt.
|
142
|
-
:param iteration: The iteration number.
|
143
|
-
|
144
|
-
Return None if the response is not found.
|
145
|
-
|
146
|
-
>>> c = Cache()
|
147
|
-
>>> c.fetch(model="gpt-3", parameters="default", system_prompt="Hello", user_prompt="Hi", iteration=1)[0] is None
|
148
|
-
True
|
149
|
-
|
150
|
-
|
151
|
-
"""
|
152
|
-
from edsl.data.CacheEntry import CacheEntry
|
153
|
-
|
154
|
-
key = CacheEntry.gen_key(
|
155
|
-
model=model,
|
156
|
-
parameters=parameters,
|
157
|
-
system_prompt=system_prompt,
|
158
|
-
user_prompt=user_prompt,
|
159
|
-
iteration=iteration,
|
160
|
-
)
|
161
|
-
entry = self.data.get(key, None)
|
162
|
-
if entry is not None:
|
163
|
-
if self.verbose:
|
164
|
-
print(f"Cache hit for key: {key}")
|
165
|
-
self.fetched_data[key] = entry
|
166
|
-
else:
|
167
|
-
if self.verbose:
|
168
|
-
print(f"Cache miss for key: {key}")
|
169
|
-
return None if entry is None else entry.output, key
|
170
|
-
|
171
|
-
def store(
|
172
|
-
self,
|
173
|
-
model: str,
|
174
|
-
parameters: str,
|
175
|
-
system_prompt: str,
|
176
|
-
user_prompt: str,
|
177
|
-
response: dict,
|
178
|
-
iteration: int,
|
179
|
-
) -> str:
|
180
|
-
"""
|
181
|
-
Add a new key-value pair to the cache.
|
182
|
-
|
183
|
-
* Key is a hash of the input parameters.
|
184
|
-
* Output is the response from the language model.
|
185
|
-
|
186
|
-
How it works:
|
187
|
-
|
188
|
-
* The key-value pair is added to `self.new_entries`
|
189
|
-
* If `immediate_write` is True , the key-value pair is added to `self.data`
|
190
|
-
* If `immediate_write` is False, the key-value pair is added to `self.new_entries_to_write_later`
|
191
|
-
|
192
|
-
>>> from edsl import Cache, Model, Question
|
193
|
-
>>> m = Model("test")
|
194
|
-
>>> c = Cache()
|
195
|
-
>>> len(c)
|
196
|
-
0
|
197
|
-
>>> results = Question.example("free_text").by(m).run(cache = c)
|
198
|
-
>>> len(c)
|
199
|
-
1
|
200
|
-
"""
|
201
|
-
|
202
|
-
entry = CacheEntry(
|
203
|
-
model=model,
|
204
|
-
parameters=parameters,
|
205
|
-
system_prompt=system_prompt,
|
206
|
-
user_prompt=user_prompt,
|
207
|
-
output=json.dumps(response),
|
208
|
-
iteration=iteration,
|
209
|
-
)
|
210
|
-
key = entry.key
|
211
|
-
self.new_entries[key] = entry
|
212
|
-
if self.immediate_write:
|
213
|
-
self.data[key] = entry
|
214
|
-
else:
|
215
|
-
self.new_entries_to_write_later[key] = entry
|
216
|
-
return key
|
217
|
-
|
218
|
-
def add_from_dict(
|
219
|
-
self, new_data: dict[str, "CacheEntry"], write_now: Optional[bool] = True
|
220
|
-
) -> None:
|
221
|
-
"""
|
222
|
-
Add entries to the cache from a dictionary.
|
223
|
-
|
224
|
-
:param write_now: Whether to write to the cache immediately (similar to `immediate_write`).
|
225
|
-
"""
|
226
|
-
|
227
|
-
for key, value in new_data.items():
|
228
|
-
if key in self.data:
|
229
|
-
if value != self.data[key]:
|
230
|
-
raise Exception("Mismatch in values")
|
231
|
-
if not isinstance(value, CacheEntry):
|
232
|
-
raise Exception(f"Wrong type - the observed type is {type(value)}")
|
233
|
-
|
234
|
-
self.new_entries.update(new_data)
|
235
|
-
if write_now:
|
236
|
-
self.data.update(new_data)
|
237
|
-
else:
|
238
|
-
self.new_entries_to_write_later.update(new_data)
|
239
|
-
|
240
|
-
def add_from_jsonl(self, filename: str, write_now: Optional[bool] = True) -> None:
|
241
|
-
"""
|
242
|
-
Add entries to the cache from a JSONL.
|
243
|
-
|
244
|
-
:param write_now: Whether to write to the cache immediately (similar to `immediate_write`).
|
245
|
-
"""
|
246
|
-
with open(filename, "a+") as f:
|
247
|
-
f.seek(0)
|
248
|
-
lines = f.readlines()
|
249
|
-
new_data = {}
|
250
|
-
for line in lines:
|
251
|
-
d = json.loads(line)
|
252
|
-
key = list(d.keys())[0]
|
253
|
-
value = list(d.values())[0]
|
254
|
-
new_data[key] = CacheEntry(**value)
|
255
|
-
self.add_from_dict(new_data=new_data, write_now=write_now)
|
256
|
-
|
257
|
-
def add_from_sqlite(self, db_path: str, write_now: Optional[bool] = True):
|
258
|
-
"""
|
259
|
-
Add entries to the cache from an SQLite database.
|
260
|
-
|
261
|
-
:param write_now: Whether to write to the cache immediately (similar to `immediate_write`).
|
262
|
-
"""
|
263
|
-
from edsl.data.SQLiteDict import SQLiteDict
|
264
|
-
|
265
|
-
db = SQLiteDict(db_path)
|
266
|
-
new_data = {}
|
267
|
-
for key, value in db.items():
|
268
|
-
new_data[key] = CacheEntry(**value)
|
269
|
-
self.add_from_dict(new_data=new_data, write_now=write_now)
|
270
|
-
|
271
|
-
@classmethod
|
272
|
-
def from_sqlite_db(cls, db_path: str) -> Cache:
|
273
|
-
"""
|
274
|
-
Construct a Cache from a SQLite database.
|
275
|
-
"""
|
276
|
-
from edsl.data.SQLiteDict import SQLiteDict
|
277
|
-
|
278
|
-
return cls(data=SQLiteDict(db_path))
|
279
|
-
|
280
|
-
@classmethod
|
281
|
-
def from_local_cache(cls) -> Cache:
|
282
|
-
"""
|
283
|
-
Construct a Cache from a local cache file.
|
284
|
-
"""
|
285
|
-
from edsl.config import CONFIG
|
286
|
-
|
287
|
-
CACHE_PATH = CONFIG.get("EDSL_DATABASE_PATH")
|
288
|
-
path = CACHE_PATH.replace("sqlite:///", "")
|
289
|
-
db_path = os.path.join(os.path.dirname(path), "data.db")
|
290
|
-
return cls.from_sqlite_db(db_path=db_path)
|
291
|
-
|
292
|
-
@classmethod
|
293
|
-
def from_jsonl(cls, jsonlfile: str, db_path: Optional[str] = None) -> Cache:
|
294
|
-
"""
|
295
|
-
Construct a Cache from a JSONL file.
|
296
|
-
|
297
|
-
:param jsonlfile: The path to the JSONL file of cache entries.
|
298
|
-
:param db_path: The path to the SQLite database used to store the cache.
|
299
|
-
|
300
|
-
* If `db_path` is None, the cache will be stored in memory, as a dictionary.
|
301
|
-
* If `db_path` is provided, the cache will be stored in an SQLite database.
|
302
|
-
"""
|
303
|
-
# if a file doesn't exist at jsonfile, throw an error
|
304
|
-
from edsl.data.SQLiteDict import SQLiteDict
|
305
|
-
|
306
|
-
if not os.path.exists(jsonlfile):
|
307
|
-
raise FileNotFoundError(f"File {jsonlfile} not found")
|
308
|
-
|
309
|
-
if db_path is None:
|
310
|
-
data = {}
|
311
|
-
else:
|
312
|
-
data = SQLiteDict(db_path)
|
313
|
-
|
314
|
-
cache = Cache(data=data)
|
315
|
-
cache.add_from_jsonl(jsonlfile)
|
316
|
-
return cache
|
317
|
-
|
318
|
-
def write_sqlite_db(self, db_path: str) -> None:
|
319
|
-
"""
|
320
|
-
Write the cache to an SQLite database.
|
321
|
-
"""
|
322
|
-
## TODO: Check to make sure not over-writing (?)
|
323
|
-
## Should be added to SQLiteDict constructor (?)
|
324
|
-
from edsl.data.SQLiteDict import SQLiteDict
|
325
|
-
|
326
|
-
new_data = SQLiteDict(db_path)
|
327
|
-
for key, value in self.data.items():
|
328
|
-
new_data[key] = value
|
329
|
-
|
330
|
-
def write(self, filename: Optional[str] = None) -> None:
|
331
|
-
"""
|
332
|
-
Write the cache to a file at the specified location.
|
333
|
-
"""
|
334
|
-
if filename is None:
|
335
|
-
filename = self.filename
|
336
|
-
if filename.endswith(".jsonl"):
|
337
|
-
self.write_jsonl(filename)
|
338
|
-
elif filename.endswith(".db"):
|
339
|
-
self.write_sqlite_db(filename)
|
340
|
-
else:
|
341
|
-
raise ValueError("Invalid file extension. Must be .jsonl or .db")
|
342
|
-
|
343
|
-
def write_jsonl(self, filename: str) -> None:
|
344
|
-
"""
|
345
|
-
Write the cache to a JSONL file.
|
346
|
-
"""
|
347
|
-
path = os.path.join(os.getcwd(), filename)
|
348
|
-
with open(path, "w") as f:
|
349
|
-
for key, value in self.data.items():
|
350
|
-
f.write(json.dumps({key: value.to_dict()}) + "\n")
|
351
|
-
|
352
|
-
def to_scenario_list(self):
|
353
|
-
from edsl import ScenarioList, Scenario
|
354
|
-
|
355
|
-
scenarios = []
|
356
|
-
for key, value in self.data.items():
|
357
|
-
new_d = value.to_dict()
|
358
|
-
new_d["cache_key"] = key
|
359
|
-
s = Scenario(new_d)
|
360
|
-
scenarios.append(s)
|
361
|
-
return ScenarioList(scenarios)
|
362
|
-
|
363
|
-
####################
|
364
|
-
# REMOTE
|
365
|
-
####################
|
366
|
-
# TODO: Make this work
|
367
|
-
# - Need to decide whether the cache belongs to a user and what can be shared
|
368
|
-
# - I.e., some cache entries? all or nothing?
|
369
|
-
@classmethod
|
370
|
-
def from_url(cls, db_path=None) -> Cache:
|
371
|
-
"""
|
372
|
-
Construct a Cache object from a remote.
|
373
|
-
"""
|
374
|
-
# ...do something here
|
375
|
-
# return Cache(data=db)
|
376
|
-
pass
|
377
|
-
|
378
|
-
def __enter__(self):
|
379
|
-
"""
|
380
|
-
Run when a context is entered.
|
381
|
-
"""
|
382
|
-
return self
|
383
|
-
|
384
|
-
def __exit__(self, exc_type, exc_value, traceback):
|
385
|
-
"""
|
386
|
-
Run when a context is exited.
|
387
|
-
"""
|
388
|
-
for key, entry in self.new_entries_to_write_later.items():
|
389
|
-
self.data[key] = entry
|
390
|
-
|
391
|
-
if self.filename:
|
392
|
-
self.write(self.filename)
|
393
|
-
|
394
|
-
####################
|
395
|
-
# DUNDER / USEFUL
|
396
|
-
####################
|
397
|
-
def __hash__(self):
|
398
|
-
"""Return the hash of the Cache."""
|
399
|
-
return dict_hash(self._to_dict())
|
400
|
-
|
401
|
-
def _to_dict(self) -> dict:
|
402
|
-
return {k: v.to_dict() for k, v in self.data.items()}
|
403
|
-
|
404
|
-
@add_edsl_version
|
405
|
-
def to_dict(self) -> dict:
|
406
|
-
"""Return the Cache as a dictionary."""
|
407
|
-
return self._to_dict()
|
408
|
-
|
409
|
-
def _repr_html_(self):
|
410
|
-
from edsl.utilities.utilities import data_to_html
|
411
|
-
|
412
|
-
return data_to_html(self.to_dict())
|
413
|
-
|
414
|
-
@classmethod
|
415
|
-
@remove_edsl_version
|
416
|
-
def from_dict(cls, data) -> Cache:
|
417
|
-
"""Construct a Cache from a dictionary."""
|
418
|
-
newdata = {k: CacheEntry.from_dict(v) for k, v in data.items()}
|
419
|
-
return cls(data=newdata)
|
420
|
-
|
421
|
-
def __len__(self):
|
422
|
-
"""Return the number of CacheEntry objects in the Cache."""
|
423
|
-
return len(self.data)
|
424
|
-
|
425
|
-
# TODO: Same inputs could give different results and this could be useful
|
426
|
-
# can't distinguish unless we do the ε trick or vary iterations
|
427
|
-
def __eq__(self, other_cache: "Cache") -> bool:
|
428
|
-
"""
|
429
|
-
Check if two Cache objects are equal.
|
430
|
-
Does not verify their values are equal, only that they have the same keys.
|
431
|
-
"""
|
432
|
-
if not isinstance(other_cache, Cache):
|
433
|
-
return False
|
434
|
-
return set(self.data.keys()) == set(other_cache.data.keys())
|
435
|
-
|
436
|
-
def __add__(self, other: "Cache"):
|
437
|
-
"""
|
438
|
-
Combine two caches.
|
439
|
-
"""
|
440
|
-
if not isinstance(other, Cache):
|
441
|
-
raise ValueError("Can only add two caches together")
|
442
|
-
self.data.update(other.data)
|
443
|
-
return self
|
444
|
-
|
445
|
-
def __repr__(self):
|
446
|
-
"""
|
447
|
-
Return a string representation of the Cache object.
|
448
|
-
"""
|
449
|
-
return (
|
450
|
-
f"Cache(data = {repr(self.data)}, immediate_write={self.immediate_write})"
|
451
|
-
)
|
452
|
-
|
453
|
-
####################
|
454
|
-
# EXAMPLES
|
455
|
-
####################
|
456
|
-
def fetch_input_example(self) -> dict:
|
457
|
-
"""
|
458
|
-
Create an example input for a 'fetch' operation.
|
459
|
-
"""
|
460
|
-
return CacheEntry.fetch_input_example()
|
461
|
-
|
462
|
-
def to_html(self):
|
463
|
-
# json_str = json.dumps(self.data, indent=4)
|
464
|
-
d = {k: v.to_dict() for k, v in self.data.items()}
|
465
|
-
for key, value in d.items():
|
466
|
-
for k, v in value.items():
|
467
|
-
if isinstance(v, dict):
|
468
|
-
d[key][k] = {kk: str(vv) for kk, vv in v.items()}
|
469
|
-
else:
|
470
|
-
d[key][k] = str(v)
|
471
|
-
|
472
|
-
json_str = json.dumps(d, indent=4)
|
473
|
-
|
474
|
-
# HTML template with the JSON string embedded
|
475
|
-
html = f"""
|
476
|
-
<!DOCTYPE html>
|
477
|
-
<html>
|
478
|
-
<head>
|
479
|
-
<title>Display JSON</title>
|
480
|
-
</head>
|
481
|
-
<body>
|
482
|
-
<pre id="jsonData"></pre>
|
483
|
-
<script>
|
484
|
-
var json = {json_str};
|
485
|
-
|
486
|
-
// JSON.stringify with spacing to format
|
487
|
-
document.getElementById('jsonData').textContent = JSON.stringify(json, null, 4);
|
488
|
-
</script>
|
489
|
-
</body>
|
490
|
-
</html>
|
491
|
-
"""
|
492
|
-
return html
|
493
|
-
|
494
|
-
def view(self) -> None:
|
495
|
-
"""View the Cache in a new browser tab."""
|
496
|
-
import tempfile
|
497
|
-
import webbrowser
|
498
|
-
|
499
|
-
html_content = self.to_html()
|
500
|
-
# Create a temporary file to hold the HTML
|
501
|
-
with tempfile.NamedTemporaryFile("w", delete=False, suffix=".html") as tmpfile:
|
502
|
-
tmpfile.write(html_content)
|
503
|
-
# Get the path to the temporary file
|
504
|
-
filepath = tmpfile.name
|
505
|
-
|
506
|
-
# Open the HTML file in a new browser tab
|
507
|
-
webbrowser.open("file://" + filepath)
|
508
|
-
|
509
|
-
@classmethod
|
510
|
-
def example(cls, randomize: bool = False) -> Cache:
|
511
|
-
"""
|
512
|
-
Returns an example Cache instance.
|
513
|
-
|
514
|
-
:param randomize: If True, uses CacheEntry's randomize method.
|
515
|
-
"""
|
516
|
-
return cls(
|
517
|
-
data={
|
518
|
-
CacheEntry.example(randomize).key: CacheEntry.example(),
|
519
|
-
CacheEntry.example(randomize).key: CacheEntry.example(),
|
520
|
-
}
|
521
|
-
)
|
522
|
-
|
523
|
-
|
524
|
-
if __name__ == "__main__":
|
525
|
-
import doctest
|
526
|
-
|
527
|
-
doctest.testmod(optionflags=doctest.ELLIPSIS)
|
1
|
+
"""
|
2
|
+
The `Cache` class is used to store responses from a language model.
|
3
|
+
"""
|
4
|
+
|
5
|
+
from __future__ import annotations
|
6
|
+
import json
|
7
|
+
import os
|
8
|
+
import warnings
|
9
|
+
import copy
|
10
|
+
from typing import Optional, Union
|
11
|
+
from edsl.Base import Base
|
12
|
+
from edsl.data.CacheEntry import CacheEntry
|
13
|
+
from edsl.utilities.utilities import dict_hash
|
14
|
+
from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
|
15
|
+
|
16
|
+
|
17
|
+
class Cache(Base):
|
18
|
+
"""
|
19
|
+
A class that represents a cache of responses from a language model.
|
20
|
+
|
21
|
+
:param data: The data to initialize the cache with.
|
22
|
+
:param immediate_write: Whether to write to the cache immediately after storing a new entry.
|
23
|
+
|
24
|
+
Deprecated:
|
25
|
+
|
26
|
+
:param method: The method of storage to use for the cache.
|
27
|
+
"""
|
28
|
+
|
29
|
+
data = {}
|
30
|
+
|
31
|
+
def __init__(
    self,
    *,
    filename: Optional[str] = None,
    data: Optional[Union["SQLiteDict", dict]] = None,
    immediate_write: bool = True,
    method=None,
    verbose=False,
):
    """
    Set up the cache state, optionally loading entries from a file.

    :param filename: The name of the file to read/write the cache from/to (must end in .jsonl or .db).
    :param data: The data to initialize the cache with. Cannot be combined with `filename`.
    :param immediate_write: Whether to write to the cache immediately after storing a new entry.
    :param method: The method of storage to use for the cache (deprecated).
    :param verbose: Stored on the instance as `self.verbose`.

    :raises ValueError: If both `filename` and `data` are given, or the file extension is not .jsonl or .db.
    """
    # Plain configuration values.
    self.immediate_write = immediate_write
    self.method = method
    self.verbose = verbose
    self.filename = filename
    self.coop = None

    # Bookkeeping dictionaries, all starting out empty.
    self.fetched_data = {}
    self.new_entries = {}
    self.new_entries_to_write_later = {}

    if filename and data:
        raise ValueError("Cannot provide both filename and data")

    if filename is None:
        # Purely in-memory cache: use the supplied mapping or start empty.
        self.data = {} if data is None else data
    else:
        # File-backed cache: always start empty, then load what is on disk.
        self.data = {}
        is_jsonl = filename.endswith(".jsonl")
        if not (is_jsonl or filename.endswith(".db")):
            raise ValueError("Invalid file extension. Must be .jsonl or .db")
        if os.path.exists(filename):
            loader = self.add_from_jsonl if is_jsonl else self.add_from_sqlite
            loader(filename)
        elif is_jsonl:
            # A missing .db file is skipped silently; only .jsonl reports it.
            print(f"File {filename} not found, but will write to this location.")

    self._perform_checks()
|
82
|
+
|
83
|
+
def rich_print(self):
    """
    Placeholder for rich printing of a Cache.

    Not implemented: this is deliberately a no-op that returns None
    instead of raising NotImplementedError.
    """
    # NOTE(review): presumably required by the `Base` interface -- confirm.
    pass
|
86
|
+
|
87
|
+
def code(self):
    """
    Placeholder for producing code that recreates a Cache.

    Not implemented: this is deliberately a no-op that returns None
    instead of raising NotImplementedError.
    """
    # NOTE(review): presumably required by the `Base` interface -- confirm.
    pass
|
90
|
+
|
91
|
+
def keys(self):
    """
    Return the keys of the cache as a list.

    >>> from edsl import Cache
    >>> Cache.example().keys()
    ['5513286eb6967abc0511211f0402587d']
    """
    return [*self.data.keys()]
|
98
|
+
|
99
|
+
def values(self):
    """
    Return the entries stored in the cache as a list.

    >>> from edsl import Cache
    >>> Cache.example().values()
    [CacheEntry(...)]
    """
    return [*self.data.values()]
|
106
|
+
|
107
|
+
def items(self):
    """Return an iterator of (key, entry) pairs built from `keys()` and `values()`."""
    key_list = self.keys()
    entry_list = self.values()
    return zip(key_list, entry_list)
|
109
|
+
|
110
|
+
def new_entries_cache(self) -> Cache:
    """Return a new Cache holding this cache's new entries together with its fetched entries."""
    combined = dict(self.new_entries)
    # Fetched entries win when a key appears in both dictionaries.
    combined.update(self.fetched_data)
    return Cache(data=combined)
|
113
|
+
|
114
|
+
def _perform_checks(self):
    """
    Validate the cache contents and warn about deprecated arguments.

    :raises Exception: If any stored value is not a CacheEntry.
    """
    from edsl.data.CacheEntry import CacheEntry

    for stored in self.data.values():
        if not isinstance(stored, CacheEntry):
            raise Exception("Not all values are CacheEntry instances")

    if self.method is not None:
        warnings.warn("Argument `method` is deprecated", DeprecationWarning)
|
122
|
+
|
123
|
+
####################
|
124
|
+
# READ/WRITE
|
125
|
+
####################
|
126
|
+
def fetch(
    self,
    *,
    model: str,
    parameters: dict,
    system_prompt: str,
    user_prompt: str,
    iteration: int,
) -> tuple[Union[None, str], str]:
    """
    Fetch a value (LLM output) from the cache.

    :param model: The name of the language model.
    :param parameters: The model parameters.
    :param system_prompt: The system prompt.
    :param user_prompt: The user prompt.
    :param iteration: The iteration number.
    :return: A tuple `(output, key)`; `output` is None if the response is not found.

    A hit is also recorded in `self.fetched_data`.

    >>> c = Cache()
    >>> c.fetch(model="gpt-3", parameters="default", system_prompt="Hello", user_prompt="Hi", iteration=1)[0] is None
    True


    """
    from edsl.data.CacheEntry import CacheEntry

    key = CacheEntry.gen_key(
        model=model,
        parameters=parameters,
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        iteration=iteration,
    )
    entry = self.data.get(key, None)

    if entry is None:
        if self.verbose:
            print(f"Cache miss for key: {key}")
        return None, key

    if self.verbose:
        print(f"Cache hit for key: {key}")
    # Remember which entries were served from the cache.
    self.fetched_data[key] = entry
    return entry.output, key
|
170
|
+
|
171
|
+
def store(
    self,
    model: str,
    parameters: str,
    system_prompt: str,
    user_prompt: str,
    response: dict,
    iteration: int,
) -> str:
    """
    Add a new key-value pair to the cache and return its key.

    * Key is a hash of the input parameters.
    * Output is the response from the language model, serialized as JSON.

    How it works:

    * The key-value pair is added to `self.new_entries`
    * If `immediate_write` is True , the key-value pair is added to `self.data`
    * If `immediate_write` is False, the key-value pair is added to `self.new_entries_to_write_later`

    >>> from edsl import Cache, Model, Question
    >>> m = Model("test")
    >>> c = Cache()
    >>> len(c)
    0
    >>> results = Question.example("free_text").by(m).run(cache = c)
    >>> len(c)
    1
    """
    serialized_output = json.dumps(response)
    entry = CacheEntry(
        model=model,
        parameters=parameters,
        system_prompt=system_prompt,
        user_prompt=user_prompt,
        output=serialized_output,
        iteration=iteration,
    )
    key = entry.key

    self.new_entries[key] = entry
    # Either commit right away or park the entry until it is written later.
    destination = self.data if self.immediate_write else self.new_entries_to_write_later
    destination[key] = entry
    return key
|
217
|
+
|
218
|
+
def add_from_dict(
    self, new_data: dict[str, "CacheEntry"], write_now: Optional[bool] = True
) -> None:
    """
    Add entries to the cache from a dictionary.

    All entries are validated before any of them is added.

    :param new_data: Mapping of cache key to CacheEntry.
    :param write_now: Whether to write to the cache immediately (similar to `immediate_write`).
    :raises Exception: If a key already exists with a different value, or a value is not a CacheEntry.
    """
    for cache_key, candidate in new_data.items():
        conflicts = cache_key in self.data and candidate != self.data[cache_key]
        if conflicts:
            raise Exception("Mismatch in values")
        if not isinstance(candidate, CacheEntry):
            raise Exception(f"Wrong type - the observed type is {type(candidate)}")

    self.new_entries.update(new_data)
    destination = self.data if write_now else self.new_entries_to_write_later
    destination.update(new_data)
|
239
|
+
|
240
|
+
def add_from_jsonl(self, filename: str, write_now: Optional[bool] = True) -> None:
    """
    Add entries to the cache from a JSONL.

    Each non-blank line must be a JSON object whose first item maps a
    cache key to the keyword arguments of a CacheEntry. Blank lines
    (for example a trailing newline) are skipped.

    :param filename: Path to the JSONL file.
    :param write_now: Whether to write to the cache immediately (similar to `immediate_write`).
    """
    new_data = {}
    # "a+" is kept on purpose: unlike "r" it creates the file when it does
    # not exist yet, so the position must be rewound before reading.
    with open(filename, "a+") as f:
        f.seek(0)
        for line in f:
            if not line.strip():
                continue
            record = json.loads(line)
            key, value = list(record.items())[0]
            new_data[key] = CacheEntry(**value)
    self.add_from_dict(new_data=new_data, write_now=write_now)
|
256
|
+
|
257
|
+
def add_from_sqlite(self, db_path: str, write_now: Optional[bool] = True):
    """
    Add entries to the cache from an SQLite database.

    :param db_path: Path passed to `SQLiteDict`.
    :param write_now: Whether to write to the cache immediately (similar to `immediate_write`).
    """
    from edsl.data.SQLiteDict import SQLiteDict

    stored = SQLiteDict(db_path)
    # Each stored value is expanded into CacheEntry keyword arguments.
    new_data = {key: CacheEntry(**value) for key, value in stored.items()}
    self.add_from_dict(new_data=new_data, write_now=write_now)
|
270
|
+
|
271
|
+
@classmethod
def from_sqlite_db(cls, db_path: str) -> Cache:
    """
    Construct a Cache whose data is an `SQLiteDict` opened at `db_path`.

    :param db_path: Path passed to `SQLiteDict`.
    """
    from edsl.data.SQLiteDict import SQLiteDict

    backing_store = SQLiteDict(db_path)
    return cls(data=backing_store)
|
279
|
+
|
280
|
+
@classmethod
def from_local_cache(cls) -> Cache:
    """
    Construct a Cache from the local cache file.

    The file is `data.db` in the directory of the configured
    `EDSL_DATABASE_PATH` (with any `sqlite:///` prefix removed).
    """
    from edsl.config import CONFIG

    database_url = CONFIG.get("EDSL_DATABASE_PATH")
    cache_dir = os.path.dirname(database_url.replace("sqlite:///", ""))
    return cls.from_sqlite_db(db_path=os.path.join(cache_dir, "data.db"))
|
291
|
+
|
292
|
+
@classmethod
def from_jsonl(cls, jsonlfile: str, db_path: Optional[str] = None) -> Cache:
    """
    Construct a Cache from a JSONL file.

    :param jsonlfile: The path to the JSONL file of cache entries.
    :param db_path: The path to the SQLite database used to store the cache.
    :raises FileNotFoundError: If no file exists at `jsonlfile`.

    * If `db_path` is None, the cache will be stored in memory, as a dictionary.
    * If `db_path` is provided, the cache will be stored in an SQLite database.
    """
    if not os.path.exists(jsonlfile):
        raise FileNotFoundError(f"File {jsonlfile} not found")

    if db_path is None:
        data = {}
    else:
        # Imported only when a database-backed store is actually requested.
        from edsl.data.SQLiteDict import SQLiteDict

        data = SQLiteDict(db_path)

    # Build through `cls` so subclasses get an instance of their own type.
    cache = cls(data=data)
    cache.add_from_jsonl(jsonlfile)
    return cache
|
317
|
+
|
318
|
+
def write_sqlite_db(self, db_path: str) -> None:
    """
    Write the cache to an SQLite database.

    Every entry of `self.data` is assigned, key by key, into an
    `SQLiteDict` opened at `db_path`.
    """
    ## TODO: Check to make sure not over-writing (?)
    ## Should be added to SQLiteDict constructor (?)
    from edsl.data.SQLiteDict import SQLiteDict

    target = SQLiteDict(db_path)
    for cache_key, entry in self.data.items():
        target[cache_key] = entry
|
329
|
+
|
330
|
+
def write(self, filename: Optional[str] = None) -> None:
    """
    Write the cache to a file at the specified location.

    :param filename: Target path ending in .jsonl or .db; defaults to `self.filename`.
    :raises ValueError: If no filename is available, or the extension is not .jsonl or .db.
    """
    if filename is None:
        filename = self.filename
    if filename is None:
        # Fail with a clear message instead of an AttributeError on None.
        raise ValueError("No filename provided and the cache has no filename set")

    if filename.endswith(".jsonl"):
        self.write_jsonl(filename)
    elif filename.endswith(".db"):
        self.write_sqlite_db(filename)
    else:
        raise ValueError("Invalid file extension. Must be .jsonl or .db")
|
342
|
+
|
343
|
+
def write_jsonl(self, filename: str) -> None:
    """
    Write the cache to a JSONL file.

    One line is written per entry: a JSON object mapping the cache key to
    the entry's `to_dict()` output. The path is resolved against the
    current working directory.
    """
    destination = os.path.join(os.getcwd(), filename)
    with open(destination, "w") as handle:
        for cache_key, entry in self.data.items():
            record = {cache_key: entry.to_dict()}
            handle.write(json.dumps(record) + "\n")
|
351
|
+
|
352
|
+
def to_scenario_list(self):
    """
    Return the cache as a ScenarioList.

    Each entry becomes a Scenario built from its `to_dict()` output plus a
    `cache_key` field holding the entry's key.
    """
    from edsl import ScenarioList, Scenario

    def as_scenario(cache_key, entry):
        fields = entry.to_dict()
        fields["cache_key"] = cache_key
        return Scenario(fields)

    return ScenarioList([as_scenario(key, entry) for key, entry in self.data.items()])
|
362
|
+
|
363
|
+
####################
|
364
|
+
# REMOTE
|
365
|
+
####################
|
366
|
+
# TODO: Make this work
|
367
|
+
# - Need to decide whether the cache belongs to a user and what can be shared
|
368
|
+
# - I.e., some cache entries? all or nothing?
|
369
|
+
@classmethod
def from_url(cls, db_path=None) -> Cache:
    """
    Construct a Cache object from a remote.

    Not implemented yet: the method currently does nothing and returns None.
    """
    # ...do something here
    # return Cache(data=db)
    return None
|
377
|
+
|
378
|
+
def __enter__(self):
    """
    Enter a `with` block; the cache itself is the context value.
    """
    context_value = self
    return context_value
|
383
|
+
|
384
|
+
def __exit__(self, exc_type, exc_value, traceback):
    """
    Leave a `with` block.

    Entries held in `self.new_entries_to_write_later` are copied into
    `self.data`, and the cache is written out if it has a filename.
    """
    deferred = self.new_entries_to_write_later
    for cache_key in deferred:
        self.data[cache_key] = deferred[cache_key]

    if self.filename:
        self.write(self.filename)
|
393
|
+
|
394
|
+
####################
|
395
|
+
# DUNDER / USEFUL
|
396
|
+
####################
|
397
|
+
def __hash__(self):
    """Return the hash of the Cache, computed by `dict_hash` over its plain-dict form."""
    plain = self._to_dict()
    return dict_hash(plain)
|
400
|
+
|
401
|
+
def _to_dict(self) -> dict:
    """Return a dict mapping each cache key to its entry's `to_dict()` output."""
    serialized = {}
    for cache_key, entry in self.data.items():
        serialized[cache_key] = entry.to_dict()
    return serialized
|
403
|
+
|
404
|
+
@add_edsl_version
def to_dict(self) -> dict:
    """Return the Cache as a dictionary (the `_to_dict()` form, passed through `add_edsl_version`)."""
    plain = self._to_dict()
    return plain
|
408
|
+
|
409
|
+
def _repr_html_(self):
    """Return an HTML rendering of the cache, produced by `data_to_html` from `to_dict()`."""
    from edsl.utilities.utilities import data_to_html

    as_dict = self.to_dict()
    return data_to_html(as_dict)
|
413
|
+
|
414
|
+
@classmethod
@remove_edsl_version
def from_dict(cls, data) -> Cache:
    """
    Construct a Cache from a dictionary.

    :param data: Mapping of cache key to a dict accepted by `CacheEntry.from_dict`.
    """
    entries = {}
    for cache_key, raw_entry in data.items():
        entries[cache_key] = CacheEntry.from_dict(raw_entry)
    return cls(data=entries)
|
420
|
+
|
421
|
+
def __len__(self):
    """Return how many entries `self.data` holds."""
    entries = self.data
    return len(entries)
|
424
|
+
|
425
|
+
# TODO: Same inputs could give different results and this could be useful
|
426
|
+
# can't distinguish unless we do the ε trick or vary iterations
|
427
|
+
def __eq__(self, other_cache: "Cache") -> bool:
    """
    Check if two Cache objects are equal.

    Only the sets of keys are compared; the stored values are not.
    Anything that is not a Cache compares unequal.
    """
    if isinstance(other_cache, Cache):
        own_keys = set(self.data.keys())
        other_keys = set(other_cache.data.keys())
        return own_keys == other_keys
    return False
|
435
|
+
|
436
|
+
def __add__(self, other: "Cache"):
    """
    Combine two caches.

    The entries of `other` are merged into this cache in place, and this
    same cache object is returned.

    :raises ValueError: If `other` is not a Cache.
    """
    if isinstance(other, Cache):
        self.data.update(other.data)
        return self
    raise ValueError("Can only add two caches together")
|
444
|
+
|
445
|
+
def __repr__(self):
    """
    Return a string showing the cache's data and its `immediate_write` setting.
    """
    data_repr = repr(self.data)
    return f"Cache(data = {data_repr}, immediate_write={self.immediate_write})"
|
452
|
+
|
453
|
+
####################
|
454
|
+
# EXAMPLES
|
455
|
+
####################
|
456
|
+
def fetch_input_example(self) -> dict:
    """
    Create an example input for a 'fetch' operation, as provided by `CacheEntry.fetch_input_example`.
    """
    example_input = CacheEntry.fetch_input_example()
    return example_input
|
461
|
+
|
462
|
+
def to_html(self):
    """
    Return a standalone HTML page that displays the cache as formatted JSON.

    Every field of every entry is converted to a string (dict-valued fields
    are converted value by value). The JSON is embedded in a `<script>`
    element, so `<`, `>` and `&` are emitted as unicode escapes; cached
    text such as `</script>` therefore cannot end the script early.
    """
    d = {}
    for key, entry in self.data.items():
        # Build fresh dicts rather than editing what `to_dict()` returned.
        d[key] = {
            field: (
                {kk: str(vv) for kk, vv in value.items()}
                if isinstance(value, dict)
                else str(value)
            )
            for field, value in entry.to_dict().items()
        }

    json_str = json.dumps(d, indent=4)
    # These characters only occur inside JSON strings, where the \uXXXX
    # form decodes to the same character when the script runs.
    json_str = (
        json_str.replace("<", "\\u003c")
        .replace(">", "\\u003e")
        .replace("&", "\\u0026")
    )

    # HTML template with the JSON string embedded
    html = f"""
    <!DOCTYPE html>
    <html>
    <head>
        <title>Display JSON</title>
    </head>
    <body>
        <pre id="jsonData"></pre>
        <script>
            var json = {json_str};

            // JSON.stringify with spacing to format
            document.getElementById('jsonData').textContent = JSON.stringify(json, null, 4);
        </script>
    </body>
    </html>
    """
    return html
|
493
|
+
|
494
|
+
def view(self) -> None:
    """View the Cache in a new browser tab."""
    import tempfile
    import webbrowser
    from pathlib import Path

    html_content = self.to_html()
    # Create a temporary file to hold the HTML; delete=False keeps it on
    # disk after the `with` block so the browser can still read it.
    with tempfile.NamedTemporaryFile("w", delete=False, suffix=".html") as tmpfile:
        tmpfile.write(html_content)
        # Get the path to the temporary file
        filepath = tmpfile.name

    # Open the HTML file in a new browser tab. `as_uri()` builds a valid
    # file URI (percent-encoded, and correct for Windows drive paths).
    webbrowser.open(Path(filepath).as_uri())
|
508
|
+
|
509
|
+
@classmethod
def example(cls, randomize: bool = False) -> Cache:
    """
    Return an example Cache instance.

    Every example entry is stored under its own key, so a key in the
    returned cache always agrees with the entry it maps to.

    :param randomize: If True, uses CacheEntry's randomize method.
    """
    # Build each entry once and reuse it for both the key and the value;
    # entries that share a key collapse into a single item.
    entries = [CacheEntry.example(randomize) for _ in range(2)]
    return cls(data={entry.key: entry for entry in entries})
|
522
|
+
|
523
|
+
|
524
|
+
if __name__ == "__main__":
    # Run this module's doctests when it is executed as a script.
    # ELLIPSIS lets `...` in expected output match arbitrary text.
    from doctest import ELLIPSIS, testmod

    testmod(optionflags=ELLIPSIS)
|