edsl 0.1.27.dev2__py3-none-any.whl → 0.1.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +99 -22
- edsl/BaseDiff.py +260 -0
- edsl/__init__.py +4 -0
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +26 -5
- edsl/agents/AgentList.py +62 -7
- edsl/agents/Invigilator.py +4 -9
- edsl/agents/InvigilatorBase.py +5 -5
- edsl/agents/descriptors.py +3 -1
- edsl/conjure/AgentConstructionMixin.py +152 -0
- edsl/conjure/Conjure.py +56 -0
- edsl/conjure/InputData.py +628 -0
- edsl/conjure/InputDataCSV.py +48 -0
- edsl/conjure/InputDataMixinQuestionStats.py +182 -0
- edsl/conjure/InputDataPyRead.py +91 -0
- edsl/conjure/InputDataSPSS.py +8 -0
- edsl/conjure/InputDataStata.py +8 -0
- edsl/conjure/QuestionOptionMixin.py +76 -0
- edsl/conjure/QuestionTypeMixin.py +23 -0
- edsl/conjure/RawQuestion.py +65 -0
- edsl/conjure/SurveyResponses.py +7 -0
- edsl/conjure/__init__.py +9 -4
- edsl/conjure/examples/placeholder.txt +0 -0
- edsl/conjure/naming_utilities.py +263 -0
- edsl/conjure/utilities.py +165 -28
- edsl/conversation/Conversation.py +238 -0
- edsl/conversation/car_buying.py +58 -0
- edsl/conversation/mug_negotiation.py +81 -0
- edsl/conversation/next_speaker_utilities.py +93 -0
- edsl/coop/coop.py +191 -12
- edsl/coop/utils.py +20 -2
- edsl/data/Cache.py +55 -17
- edsl/data/CacheHandler.py +10 -9
- edsl/inference_services/AnthropicService.py +1 -0
- edsl/inference_services/DeepInfraService.py +20 -13
- edsl/inference_services/GoogleService.py +7 -1
- edsl/inference_services/InferenceServicesCollection.py +33 -7
- edsl/inference_services/OpenAIService.py +17 -10
- edsl/inference_services/models_available_cache.py +69 -0
- edsl/inference_services/rate_limits_cache.py +25 -0
- edsl/inference_services/write_available.py +10 -0
- edsl/jobs/Jobs.py +240 -36
- edsl/jobs/buckets/BucketCollection.py +9 -3
- edsl/jobs/interviews/Interview.py +4 -1
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +24 -10
- edsl/jobs/interviews/retry_management.py +4 -4
- edsl/jobs/runners/JobsRunnerAsyncio.py +87 -45
- edsl/jobs/runners/JobsRunnerStatusData.py +3 -3
- edsl/jobs/tasks/QuestionTaskCreator.py +4 -2
- edsl/language_models/LanguageModel.py +37 -44
- edsl/language_models/ModelList.py +96 -0
- edsl/language_models/registry.py +14 -0
- edsl/language_models/repair.py +95 -24
- edsl/notebooks/Notebook.py +119 -31
- edsl/questions/QuestionBase.py +109 -12
- edsl/questions/descriptors.py +5 -2
- edsl/questions/question_registry.py +7 -0
- edsl/results/Result.py +20 -8
- edsl/results/Results.py +85 -11
- edsl/results/ResultsDBMixin.py +3 -6
- edsl/results/ResultsExportMixin.py +47 -16
- edsl/results/ResultsToolsMixin.py +5 -5
- edsl/scenarios/Scenario.py +59 -5
- edsl/scenarios/ScenarioList.py +97 -40
- edsl/study/ObjectEntry.py +97 -0
- edsl/study/ProofOfWork.py +110 -0
- edsl/study/SnapShot.py +77 -0
- edsl/study/Study.py +491 -0
- edsl/study/__init__.py +2 -0
- edsl/surveys/Survey.py +79 -31
- edsl/surveys/SurveyExportMixin.py +21 -3
- edsl/utilities/__init__.py +1 -0
- edsl/utilities/gcp_bucket/__init__.py +0 -0
- edsl/utilities/gcp_bucket/cloud_storage.py +96 -0
- edsl/utilities/gcp_bucket/simple_example.py +9 -0
- edsl/utilities/interface.py +24 -28
- edsl/utilities/repair_functions.py +28 -0
- edsl/utilities/utilities.py +57 -2
- {edsl-0.1.27.dev2.dist-info → edsl-0.1.28.dist-info}/METADATA +43 -17
- {edsl-0.1.27.dev2.dist-info → edsl-0.1.28.dist-info}/RECORD +83 -55
- edsl-0.1.28.dist-info/entry_points.txt +3 -0
- edsl/conjure/RawResponseColumn.py +0 -327
- edsl/conjure/SurveyBuilder.py +0 -308
- edsl/conjure/SurveyBuilderCSV.py +0 -78
- edsl/conjure/SurveyBuilderSPSS.py +0 -118
- edsl/data/RemoteDict.py +0 -103
- {edsl-0.1.27.dev2.dist-info → edsl-0.1.28.dist-info}/LICENSE +0 -0
- {edsl-0.1.27.dev2.dist-info → edsl-0.1.28.dist-info}/WHEEL +0 -0
edsl/coop/coop.py
CHANGED
@@ -106,7 +106,8 @@ class Coop:
|
|
106
106
|
def create(
|
107
107
|
self,
|
108
108
|
object: EDSLObject,
|
109
|
-
|
109
|
+
description: Optional[str] = None,
|
110
|
+
visibility: Optional[VisibilityType] = "unlisted",
|
110
111
|
) -> dict:
|
111
112
|
"""
|
112
113
|
Create an EDSL object in the Coop server.
|
@@ -117,6 +118,7 @@ class Coop:
|
|
117
118
|
uri=f"api/v0/object",
|
118
119
|
method="POST",
|
119
120
|
payload={
|
121
|
+
"description": description,
|
120
122
|
"object_type": object_type,
|
121
123
|
"json_string": json.dumps(
|
122
124
|
object.to_dict(),
|
@@ -131,8 +133,9 @@ class Coop:
|
|
131
133
|
return {
|
132
134
|
"uuid": response_json.get("uuid"),
|
133
135
|
"version": self._edsl_version,
|
136
|
+
"description": response_json.get("description"),
|
134
137
|
"visibility": response_json.get("visibility"),
|
135
|
-
"url": f"{self.url}/
|
138
|
+
"url": f"{self.url}/content/{response_json.get('uuid')}",
|
136
139
|
}
|
137
140
|
|
138
141
|
def get(
|
@@ -140,6 +143,7 @@ class Coop:
|
|
140
143
|
object_type: ObjectType = None,
|
141
144
|
uuid: Union[str, UUID] = None,
|
142
145
|
url: str = None,
|
146
|
+
exec_profile=None,
|
143
147
|
) -> EDSLObject:
|
144
148
|
"""
|
145
149
|
Retrieve an EDSL object either by object type & UUID, or by its url.
|
@@ -156,25 +160,37 @@ class Coop:
|
|
156
160
|
elif not object_type and not uuid:
|
157
161
|
raise Exception("Provide either object_type & UUID, or a url.")
|
158
162
|
edsl_class = ObjectRegistry.object_type_to_edsl_class.get(object_type)
|
163
|
+
import time
|
164
|
+
|
165
|
+
start = time.time()
|
159
166
|
response = self._send_server_request(
|
160
167
|
uri=f"api/v0/object",
|
161
168
|
method="GET",
|
162
169
|
params={"type": object_type, "uuid": uuid},
|
163
170
|
)
|
171
|
+
end = time.time()
|
172
|
+
if exec_profile:
|
173
|
+
print("Download exec time = ", end - start)
|
164
174
|
self._resolve_server_response(response)
|
165
175
|
json_string = response.json().get("json_string")
|
166
|
-
|
176
|
+
start = time.time()
|
177
|
+
res_object = edsl_class.from_dict(json.loads(json_string))
|
178
|
+
end = time.time()
|
179
|
+
if exec_profile:
|
180
|
+
print("Creating object exec time = ", end - start)
|
181
|
+
return res_object
|
167
182
|
|
168
183
|
def _get_base(
|
169
184
|
self,
|
170
185
|
cls: EDSLObject,
|
171
186
|
uuid: Union[str, UUID],
|
187
|
+
exec_profile=None,
|
172
188
|
) -> EDSLObject:
|
173
189
|
"""
|
174
190
|
Used by the Base class to offer a get functionality.
|
175
191
|
"""
|
176
192
|
object_type = ObjectRegistry.get_object_type_by_edsl_class(cls)
|
177
|
-
return self.get(object_type, uuid)
|
193
|
+
return self.get(object_type, uuid, exec_profile=exec_profile)
|
178
194
|
|
179
195
|
def get_all(self, object_type: ObjectType) -> list[EDSLObject]:
|
180
196
|
"""
|
@@ -193,6 +209,7 @@ class Coop:
|
|
193
209
|
"object": edsl_class.from_dict(json.loads(o.get("json_string"))),
|
194
210
|
"uuid": o.get("uuid"),
|
195
211
|
"version": o.get("version"),
|
212
|
+
"description": o.get("description"),
|
196
213
|
"visibility": o.get("visibility"),
|
197
214
|
"url": f"{self.url}/explore/{object_page}/{o.get('uuid')}",
|
198
215
|
}
|
@@ -227,17 +244,36 @@ class Coop:
|
|
227
244
|
self,
|
228
245
|
object_type: ObjectType,
|
229
246
|
uuid: Union[str, UUID],
|
230
|
-
|
247
|
+
description: Optional[str] = None,
|
248
|
+
value: Optional[EDSLObject] = None,
|
249
|
+
visibility: Optional[VisibilityType] = None,
|
231
250
|
) -> dict:
|
232
251
|
"""
|
233
252
|
Change the attributes of an uploaded object
|
234
253
|
- Only supports visibility for now
|
235
254
|
"""
|
255
|
+
if description is None and visibility is None and value is None:
|
256
|
+
raise Exception("Nothing to patch.")
|
257
|
+
if value is not None:
|
258
|
+
value_type = ObjectRegistry.get_object_type_by_edsl_class(value)
|
259
|
+
if value_type != object_type:
|
260
|
+
raise Exception(f"Object type mismatch: {object_type=} {value_type=}")
|
236
261
|
response = self._send_server_request(
|
237
262
|
uri=f"api/v0/object",
|
238
263
|
method="PATCH",
|
239
264
|
params={"type": object_type, "uuid": uuid},
|
240
|
-
payload={
|
265
|
+
payload={
|
266
|
+
"description": description,
|
267
|
+
"json_string": (
|
268
|
+
json.dumps(
|
269
|
+
value.to_dict(),
|
270
|
+
default=self._json_handle_none,
|
271
|
+
)
|
272
|
+
if value
|
273
|
+
else None
|
274
|
+
),
|
275
|
+
"visibility": visibility,
|
276
|
+
},
|
241
277
|
)
|
242
278
|
self._resolve_server_response(response)
|
243
279
|
return response.json()
|
@@ -246,13 +282,15 @@ class Coop:
|
|
246
282
|
self,
|
247
283
|
cls: EDSLObject,
|
248
284
|
uuid: Union[str, UUID],
|
249
|
-
|
285
|
+
description: Optional[str] = None,
|
286
|
+
value: Optional[EDSLObject] = None,
|
287
|
+
visibility: Optional[VisibilityType] = None,
|
250
288
|
) -> dict:
|
251
289
|
"""
|
252
290
|
Used by the Base class to offer a patch functionality.
|
253
291
|
"""
|
254
292
|
object_type = ObjectRegistry.get_object_type_by_edsl_class(cls)
|
255
|
-
return self.patch(object_type, uuid, visibility)
|
293
|
+
return self.patch(object_type, uuid, description, value, visibility)
|
256
294
|
|
257
295
|
################
|
258
296
|
# Remote Cache
|
@@ -261,9 +299,19 @@ class Coop:
|
|
261
299
|
self,
|
262
300
|
cache_entry: CacheEntry,
|
263
301
|
visibility: VisibilityType = "private",
|
302
|
+
description: Optional[str] = None,
|
264
303
|
) -> dict:
|
265
304
|
"""
|
266
305
|
Create a single remote cache entry.
|
306
|
+
If an entry with the same key already exists in the database, update it instead.
|
307
|
+
|
308
|
+
:param cache_entry: The cache entry to send to the server.
|
309
|
+
:param visibility: The visibility of the cache entry.
|
310
|
+
:param optional description: A description for this entry in the remote cache.
|
311
|
+
|
312
|
+
>>> entry = CacheEntry.example()
|
313
|
+
>>> coop.remote_cache_create(cache_entry=entry)
|
314
|
+
{'status': 'success', 'created_entry_count': 1, 'updated_entry_count': 0}
|
267
315
|
"""
|
268
316
|
response = self._send_server_request(
|
269
317
|
uri="api/v0/remote-cache",
|
@@ -272,24 +320,44 @@ class Coop:
|
|
272
320
|
"json_string": json.dumps(cache_entry.to_dict()),
|
273
321
|
"version": self._edsl_version,
|
274
322
|
"visibility": visibility,
|
323
|
+
"description": description,
|
275
324
|
},
|
276
325
|
)
|
277
326
|
self._resolve_server_response(response)
|
327
|
+
response_json = response.json()
|
328
|
+
created_entry_count = response_json.get("created_entry_count", 0)
|
329
|
+
if created_entry_count > 0:
|
330
|
+
self.remote_cache_create_log(
|
331
|
+
response,
|
332
|
+
description="Upload new cache entries to server",
|
333
|
+
cache_entry_count=created_entry_count,
|
334
|
+
)
|
278
335
|
return response.json()
|
279
336
|
|
280
337
|
def remote_cache_create_many(
|
281
338
|
self,
|
282
339
|
cache_entries: list[CacheEntry],
|
283
340
|
visibility: VisibilityType = "private",
|
341
|
+
description: Optional[str] = None,
|
284
342
|
) -> dict:
|
285
343
|
"""
|
286
344
|
Create many remote cache entries.
|
345
|
+
If an entry with the same key already exists in the database, update it instead.
|
346
|
+
|
347
|
+
:param cache_entries: The list of cache entries to send to the server.
|
348
|
+
:param visibility: The visibility of the cache entries.
|
349
|
+
:param optional description: A description for these entries in the remote cache.
|
350
|
+
|
351
|
+
>>> entries = [CacheEntry.example(randomize=True) for _ in range(10)]
|
352
|
+
>>> coop.remote_cache_create_many(cache_entries=entries)
|
353
|
+
{'status': 'success', 'created_entry_count': 10, 'updated_entry_count': 0}
|
287
354
|
"""
|
288
355
|
payload = [
|
289
356
|
{
|
290
357
|
"json_string": json.dumps(c.to_dict()),
|
291
358
|
"version": self._edsl_version,
|
292
359
|
"visibility": visibility,
|
360
|
+
"description": description,
|
293
361
|
}
|
294
362
|
for c in cache_entries
|
295
363
|
]
|
@@ -299,6 +367,14 @@ class Coop:
|
|
299
367
|
payload=payload,
|
300
368
|
)
|
301
369
|
self._resolve_server_response(response)
|
370
|
+
response_json = response.json()
|
371
|
+
created_entry_count = response_json.get("created_entry_count", 0)
|
372
|
+
if created_entry_count > 0:
|
373
|
+
self.remote_cache_create_log(
|
374
|
+
response,
|
375
|
+
description="Upload new cache entries to server",
|
376
|
+
cache_entry_count=created_entry_count,
|
377
|
+
)
|
302
378
|
return response.json()
|
303
379
|
|
304
380
|
def remote_cache_get(
|
@@ -307,7 +383,11 @@ class Coop:
|
|
307
383
|
) -> list[CacheEntry]:
|
308
384
|
"""
|
309
385
|
Get all remote cache entries.
|
310
|
-
|
386
|
+
|
387
|
+
:param optional exclude_keys: Exclude CacheEntry objects with these keys.
|
388
|
+
|
389
|
+
>>> coop.remote_cache_get()
|
390
|
+
[CacheEntry(...), CacheEntry(...), ...]
|
311
391
|
"""
|
312
392
|
if exclude_keys is None:
|
313
393
|
exclude_keys = []
|
@@ -322,15 +402,93 @@ class Coop:
|
|
322
402
|
for v in response.json()
|
323
403
|
]
|
324
404
|
|
405
|
+
def remote_cache_get_diff(
|
406
|
+
self,
|
407
|
+
client_cacheentry_keys: list[str],
|
408
|
+
) -> dict:
|
409
|
+
"""
|
410
|
+
Get the difference between local and remote cache entries for a user.
|
411
|
+
"""
|
412
|
+
response = self._send_server_request(
|
413
|
+
uri="api/v0/remote-cache/get-diff",
|
414
|
+
method="POST",
|
415
|
+
payload={"keys": client_cacheentry_keys},
|
416
|
+
)
|
417
|
+
self._resolve_server_response(response)
|
418
|
+
response_json = response.json()
|
419
|
+
response_dict = {
|
420
|
+
"client_missing_cacheentries": [
|
421
|
+
CacheEntry.from_dict(json.loads(c.get("json_string")))
|
422
|
+
for c in response_json.get("client_missing_cacheentries", [])
|
423
|
+
],
|
424
|
+
"server_missing_cacheentry_keys": response_json.get(
|
425
|
+
"server_missing_cacheentry_keys", []
|
426
|
+
),
|
427
|
+
}
|
428
|
+
downloaded_entry_count = len(response_dict["client_missing_cacheentries"])
|
429
|
+
if downloaded_entry_count > 0:
|
430
|
+
self.remote_cache_create_log(
|
431
|
+
response,
|
432
|
+
description="Download missing cache entries to client",
|
433
|
+
cache_entry_count=downloaded_entry_count,
|
434
|
+
)
|
435
|
+
return response_dict
|
436
|
+
|
325
437
|
def remote_cache_clear(self) -> dict:
|
326
438
|
"""
|
327
439
|
Clear all remote cache entries.
|
440
|
+
|
441
|
+
>>> entries = [CacheEntry.example(randomize=True) for _ in range(10)]
|
442
|
+
>>> coop.remote_cache_create_many(cache_entries=entries)
|
443
|
+
>>> coop.remote_cache_clear()
|
444
|
+
{'status': 'success', 'deleted_entry_count': 10}
|
328
445
|
"""
|
329
446
|
response = self._send_server_request(
|
330
447
|
uri="api/v0/remote-cache/delete-all",
|
331
448
|
method="DELETE",
|
332
449
|
)
|
333
450
|
self._resolve_server_response(response)
|
451
|
+
response_json = response.json()
|
452
|
+
deleted_entry_count = response_json.get("deleted_entry_count", 0)
|
453
|
+
if deleted_entry_count > 0:
|
454
|
+
self.remote_cache_create_log(
|
455
|
+
response,
|
456
|
+
description="Clear cache entries",
|
457
|
+
cache_entry_count=deleted_entry_count,
|
458
|
+
)
|
459
|
+
return response.json()
|
460
|
+
|
461
|
+
def remote_cache_create_log(
|
462
|
+
self, response: requests.Response, description: str, cache_entry_count: int
|
463
|
+
) -> Union[dict, None]:
|
464
|
+
"""
|
465
|
+
If a remote cache action has been completed successfully,
|
466
|
+
log the action.
|
467
|
+
"""
|
468
|
+
if 200 <= response.status_code < 300:
|
469
|
+
log_response = self._send_server_request(
|
470
|
+
uri="api/v0/remote-cache-log",
|
471
|
+
method="POST",
|
472
|
+
payload={
|
473
|
+
"description": description,
|
474
|
+
"cache_entry_count": cache_entry_count,
|
475
|
+
},
|
476
|
+
)
|
477
|
+
self._resolve_server_response(log_response)
|
478
|
+
return response.json()
|
479
|
+
|
480
|
+
def remote_cache_clear_log(self) -> dict:
|
481
|
+
"""
|
482
|
+
Clear all remote cache log entries.
|
483
|
+
|
484
|
+
>>> coop.remote_cache_clear_log()
|
485
|
+
{'status': 'success'}
|
486
|
+
"""
|
487
|
+
response = self._send_server_request(
|
488
|
+
uri="api/v0/remote-cache-log/delete-all",
|
489
|
+
method="DELETE",
|
490
|
+
)
|
491
|
+
self._resolve_server_response(response)
|
334
492
|
return response.json()
|
335
493
|
|
336
494
|
################
|
@@ -438,7 +596,7 @@ if __name__ == "__main__":
|
|
438
596
|
Agent,
|
439
597
|
AgentList,
|
440
598
|
Cache,
|
441
|
-
|
599
|
+
Notebook,
|
442
600
|
QuestionMultipleChoice,
|
443
601
|
Results,
|
444
602
|
Scenario,
|
@@ -450,7 +608,7 @@ if __name__ == "__main__":
|
|
450
608
|
("agent", Agent),
|
451
609
|
("agent_list", AgentList),
|
452
610
|
("cache", Cache),
|
453
|
-
("
|
611
|
+
("notebook", Notebook),
|
454
612
|
("question", QuestionMultipleChoice),
|
455
613
|
("results", Results),
|
456
614
|
("scenario", Scenario),
|
@@ -469,7 +627,9 @@ if __name__ == "__main__":
|
|
469
627
|
response_1 = coop.create(example)
|
470
628
|
response_2 = coop.create(cls.example(), visibility="private")
|
471
629
|
response_3 = coop.create(cls.example(), visibility="public")
|
472
|
-
response_4 = coop.create(
|
630
|
+
response_4 = coop.create(
|
631
|
+
cls.example(), visibility="unlisted", description="hey"
|
632
|
+
)
|
473
633
|
# 3. Retrieve all objects
|
474
634
|
objects = coop.get_all(object_type)
|
475
635
|
assert len(objects) == 4
|
@@ -486,14 +646,33 @@ if __name__ == "__main__":
|
|
486
646
|
coop.patch(
|
487
647
|
object_type=object_type, uuid=item.get("uuid"), visibility="private"
|
488
648
|
)
|
649
|
+
# 6. Change description of all objects
|
650
|
+
for item in objects:
|
651
|
+
coop.patch(
|
652
|
+
object_type=object_type, uuid=item.get("uuid"), description="hey"
|
653
|
+
)
|
489
654
|
# 7. Delete all objects
|
490
655
|
for item in objects:
|
491
656
|
coop.delete(object_type=object_type, uuid=item.get("uuid"))
|
492
657
|
assert len(coop.get_all(object_type)) == 0
|
493
658
|
|
659
|
+
# a simple example
|
660
|
+
from edsl import Coop, QuestionMultipleChoice, QuestionFreeText
|
661
|
+
|
662
|
+
coop = Coop(api_key="b")
|
494
663
|
response = QuestionMultipleChoice.example().push()
|
495
664
|
QuestionMultipleChoice.pull(response.get("uuid"))
|
496
665
|
coop.patch(object_type="question", uuid=response.get("uuid"), visibility="public")
|
666
|
+
coop.patch(
|
667
|
+
object_type="question",
|
668
|
+
uuid=response.get("uuid"),
|
669
|
+
description="crazy new description",
|
670
|
+
)
|
671
|
+
coop.patch(
|
672
|
+
object_type="question",
|
673
|
+
uuid=response.get("uuid"),
|
674
|
+
value=QuestionFreeText.example(),
|
675
|
+
)
|
497
676
|
coop.delete(object_type="question", uuid=response.get("uuid"))
|
498
677
|
|
499
678
|
##############
|
edsl/coop/utils.py
CHANGED
@@ -1,5 +1,15 @@
|
|
1
|
-
from edsl import
|
2
|
-
|
1
|
+
from edsl import (
|
2
|
+
Agent,
|
3
|
+
AgentList,
|
4
|
+
Cache,
|
5
|
+
Jobs,
|
6
|
+
Notebook,
|
7
|
+
Results,
|
8
|
+
Scenario,
|
9
|
+
ScenarioList,
|
10
|
+
Survey,
|
11
|
+
Study,
|
12
|
+
)
|
3
13
|
from edsl.questions import QuestionBase
|
4
14
|
from typing import Literal, Type, Union
|
5
15
|
|
@@ -14,6 +24,7 @@ EDSLObject = Union[
|
|
14
24
|
Scenario,
|
15
25
|
ScenarioList,
|
16
26
|
Survey,
|
27
|
+
Study,
|
17
28
|
]
|
18
29
|
|
19
30
|
ObjectType = Literal[
|
@@ -27,6 +38,7 @@ ObjectType = Literal[
|
|
27
38
|
"scenario",
|
28
39
|
"scenario_list",
|
29
40
|
"survey",
|
41
|
+
"study",
|
30
42
|
]
|
31
43
|
|
32
44
|
ObjectPage = Literal[
|
@@ -40,6 +52,7 @@ ObjectPage = Literal[
|
|
40
52
|
"scenarios",
|
41
53
|
"scenariolists",
|
42
54
|
"surveys",
|
55
|
+
"studies",
|
43
56
|
]
|
44
57
|
|
45
58
|
VisibilityType = Literal[
|
@@ -105,6 +118,11 @@ class ObjectRegistry:
|
|
105
118
|
"edsl_class": Survey,
|
106
119
|
"object_page": "surveys",
|
107
120
|
},
|
121
|
+
{
|
122
|
+
"object_type": "study",
|
123
|
+
"edsl_class": Study,
|
124
|
+
"object_page": "studies",
|
125
|
+
},
|
108
126
|
]
|
109
127
|
object_type_to_edsl_class = {o["object_type"]: o["edsl_class"] for o in objects}
|
110
128
|
edsl_class_to_object_type = {
|
edsl/data/Cache.py
CHANGED
@@ -12,6 +12,7 @@ from edsl.config import CONFIG
|
|
12
12
|
from edsl.data.CacheEntry import CacheEntry
|
13
13
|
from edsl.data.SQLiteDict import SQLiteDict
|
14
14
|
from edsl.Base import Base
|
15
|
+
from edsl.utilities.utilities import dict_hash
|
15
16
|
|
16
17
|
from edsl.utilities.decorators import (
|
17
18
|
add_edsl_version,
|
@@ -24,7 +25,6 @@ class Cache(Base):
|
|
24
25
|
A class that represents a cache of responses from a language model.
|
25
26
|
|
26
27
|
:param data: The data to initialize the cache with.
|
27
|
-
:param remote: Whether to sync the Cache with the server.
|
28
28
|
:param immediate_write: Whether to write to the cache immediately after storing a new entry.
|
29
29
|
|
30
30
|
Deprecated:
|
@@ -37,24 +37,51 @@ class Cache(Base):
|
|
37
37
|
def __init__(
|
38
38
|
self,
|
39
39
|
*,
|
40
|
+
filename: Optional[str] = None,
|
40
41
|
data: Optional[Union[SQLiteDict, dict]] = None,
|
41
|
-
remote: bool = False,
|
42
42
|
immediate_write: bool = True,
|
43
43
|
method=None,
|
44
44
|
):
|
45
45
|
"""
|
46
46
|
Create two dictionaries to store the cache data.
|
47
47
|
|
48
|
+
:param filename: The name of the file to read/write the cache from/to.
|
49
|
+
:param data: The data to initialize the cache with.
|
50
|
+
:param immediate_write: Whether to write to the cache immediately after storing a new entry.
|
51
|
+
:param method: The method of storage to use for the cache.
|
52
|
+
|
48
53
|
"""
|
49
|
-
|
54
|
+
|
50
55
|
# self.data_at_init = data or {}
|
51
56
|
self.fetched_data = {}
|
52
|
-
self.remote = remote
|
53
57
|
self.immediate_write = immediate_write
|
54
58
|
self.method = method
|
55
59
|
self.new_entries = {}
|
56
60
|
self.new_entries_to_write_later = {}
|
57
61
|
self.coop = None
|
62
|
+
|
63
|
+
self.filename = filename
|
64
|
+
if filename and data:
|
65
|
+
raise ValueError("Cannot provide both filename and data")
|
66
|
+
if filename is None and data is None:
|
67
|
+
data = {}
|
68
|
+
if data is not None:
|
69
|
+
self.data = data
|
70
|
+
if filename is not None:
|
71
|
+
self.data = {}
|
72
|
+
if filename.endswith(".jsonl"):
|
73
|
+
if os.path.exists(filename):
|
74
|
+
self.add_from_jsonl(filename)
|
75
|
+
else:
|
76
|
+
print(
|
77
|
+
f"File {filename} not found, but will write to this location."
|
78
|
+
)
|
79
|
+
elif filename.endswith(".db"):
|
80
|
+
if os.path.exists(filename):
|
81
|
+
self.add_from_sqlite(filename)
|
82
|
+
else:
|
83
|
+
raise ValueError("Invalid file extension. Must be .jsonl or .db")
|
84
|
+
|
58
85
|
self._perform_checks()
|
59
86
|
|
60
87
|
def rich_print(sefl):
|
@@ -81,10 +108,6 @@ class Cache(Base):
|
|
81
108
|
raise Exception("Not all values are CacheEntry instances")
|
82
109
|
if self.method is not None:
|
83
110
|
warnings.warn("Argument `method` is deprecated", DeprecationWarning)
|
84
|
-
if self.remote:
|
85
|
-
from edsl.coop import Coop
|
86
|
-
|
87
|
-
self.coop = Coop()
|
88
111
|
|
89
112
|
####################
|
90
113
|
# READ/WRITE
|
@@ -267,6 +290,19 @@ class Cache(Base):
|
|
267
290
|
for key, value in self.data.items():
|
268
291
|
new_data[key] = value
|
269
292
|
|
293
|
+
def write(self, filename: Optional[str] = None) -> None:
|
294
|
+
"""
|
295
|
+
Write the cache to a file at the specified location.
|
296
|
+
"""
|
297
|
+
if filename is None:
|
298
|
+
filename = self.filename
|
299
|
+
if filename.endswith(".jsonl"):
|
300
|
+
self.write_jsonl(filename)
|
301
|
+
elif filename.endswith(".db"):
|
302
|
+
self.write_sqlite_db(filename)
|
303
|
+
else:
|
304
|
+
raise ValueError("Invalid file extension. Must be .jsonl or .db")
|
305
|
+
|
270
306
|
def write_jsonl(self, filename: str) -> None:
|
271
307
|
"""
|
272
308
|
Write the cache to a JSONL file.
|
@@ -295,11 +331,6 @@ class Cache(Base):
|
|
295
331
|
"""
|
296
332
|
Run when a context is entered.
|
297
333
|
"""
|
298
|
-
if self.remote:
|
299
|
-
print("Syncing local and remote caches")
|
300
|
-
exclude_keys = list(self.data.keys())
|
301
|
-
cache_entries = self.coop.get_cache_entries(exclude_keys)
|
302
|
-
self.add_from_dict({c.key: c for c in cache_entries}, write_now=True)
|
303
334
|
return self
|
304
335
|
|
305
336
|
def __exit__(self, exc_type, exc_value, traceback):
|
@@ -308,16 +339,21 @@ class Cache(Base):
|
|
308
339
|
"""
|
309
340
|
for key, entry in self.new_entries_to_write_later.items():
|
310
341
|
self.data[key] = entry
|
311
|
-
if self.remote:
|
312
|
-
_ = self.coop.create_cache_entries(cache_dict=self.new_entries)
|
313
342
|
|
314
343
|
####################
|
315
344
|
# DUNDER / USEFUL
|
316
345
|
####################
|
346
|
+
def __hash__(self):
|
347
|
+
"""Return the hash of the Cache."""
|
348
|
+
return dict_hash(self._to_dict())
|
349
|
+
|
350
|
+
def _to_dict(self) -> dict:
|
351
|
+
return {k: v.to_dict() for k, v in self.data.items()}
|
352
|
+
|
317
353
|
@add_edsl_version
|
318
354
|
def to_dict(self) -> dict:
|
319
355
|
"""Return the Cache as a dictionary."""
|
320
|
-
return
|
356
|
+
return self._to_dict()
|
321
357
|
|
322
358
|
def _repr_html_(self):
|
323
359
|
from edsl.utilities.utilities import data_to_html
|
@@ -359,7 +395,9 @@ class Cache(Base):
|
|
359
395
|
"""
|
360
396
|
Return a string representation of the Cache object.
|
361
397
|
"""
|
362
|
-
return
|
398
|
+
return (
|
399
|
+
f"Cache(data = {repr(self.data)}, immediate_write={self.immediate_write})"
|
400
|
+
)
|
363
401
|
|
364
402
|
####################
|
365
403
|
# EXAMPLES
|
edsl/data/CacheHandler.py
CHANGED
@@ -9,22 +9,22 @@ from edsl.data.Cache import Cache
|
|
9
9
|
from edsl.data.CacheEntry import CacheEntry
|
10
10
|
from edsl.data.SQLiteDict import SQLiteDict
|
11
11
|
|
12
|
+
from edsl.config import CONFIG
|
13
|
+
|
12
14
|
|
13
15
|
def set_session_cache(cache: Cache) -> None:
|
14
16
|
"""
|
15
17
|
Set the session cache.
|
16
18
|
"""
|
17
|
-
|
18
|
-
global _CACHE
|
19
|
-
_CACHE = cache
|
19
|
+
CONFIG.EDSL_SESSION_CACHE = cache
|
20
20
|
|
21
21
|
|
22
22
|
def unset_session_cache() -> None:
|
23
23
|
"""
|
24
24
|
Unset the session cache.
|
25
25
|
"""
|
26
|
-
|
27
|
-
|
26
|
+
if hasattr(CONFIG, "EDSL_SESSION_CACHE"):
|
27
|
+
del CONFIG.EDSL_SESSION_CACHE
|
28
28
|
|
29
29
|
|
30
30
|
class CacheHandler:
|
@@ -49,7 +49,9 @@ class CacheHandler:
|
|
49
49
|
dir_path = os.path.dirname(path)
|
50
50
|
if dir_path and not os.path.exists(dir_path):
|
51
51
|
os.makedirs(dir_path)
|
52
|
-
|
52
|
+
import warnings
|
53
|
+
|
54
|
+
warnings.warn(f"Created cache directory: {dir_path}")
|
53
55
|
|
54
56
|
def gen_cache(self) -> Cache:
|
55
57
|
"""
|
@@ -58,9 +60,8 @@ class CacheHandler:
|
|
58
60
|
if self.test:
|
59
61
|
return Cache(data={})
|
60
62
|
|
61
|
-
if "
|
62
|
-
|
63
|
-
return _CACHE
|
63
|
+
if hasattr(CONFIG, "EDSL_SESSION_CACHE"):
|
64
|
+
return CONFIG.EDSL_SESSION_CACHE
|
64
65
|
|
65
66
|
cache = Cache(data=SQLiteDict(self.CACHE_PATH))
|
66
67
|
return cache
|