edsl 0.1.27.dev2__py3-none-any.whl → 0.1.28__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. edsl/Base.py +99 -22
  2. edsl/BaseDiff.py +260 -0
  3. edsl/__init__.py +4 -0
  4. edsl/__version__.py +1 -1
  5. edsl/agents/Agent.py +26 -5
  6. edsl/agents/AgentList.py +62 -7
  7. edsl/agents/Invigilator.py +4 -9
  8. edsl/agents/InvigilatorBase.py +5 -5
  9. edsl/agents/descriptors.py +3 -1
  10. edsl/conjure/AgentConstructionMixin.py +152 -0
  11. edsl/conjure/Conjure.py +56 -0
  12. edsl/conjure/InputData.py +628 -0
  13. edsl/conjure/InputDataCSV.py +48 -0
  14. edsl/conjure/InputDataMixinQuestionStats.py +182 -0
  15. edsl/conjure/InputDataPyRead.py +91 -0
  16. edsl/conjure/InputDataSPSS.py +8 -0
  17. edsl/conjure/InputDataStata.py +8 -0
  18. edsl/conjure/QuestionOptionMixin.py +76 -0
  19. edsl/conjure/QuestionTypeMixin.py +23 -0
  20. edsl/conjure/RawQuestion.py +65 -0
  21. edsl/conjure/SurveyResponses.py +7 -0
  22. edsl/conjure/__init__.py +9 -4
  23. edsl/conjure/examples/placeholder.txt +0 -0
  24. edsl/conjure/naming_utilities.py +263 -0
  25. edsl/conjure/utilities.py +165 -28
  26. edsl/conversation/Conversation.py +238 -0
  27. edsl/conversation/car_buying.py +58 -0
  28. edsl/conversation/mug_negotiation.py +81 -0
  29. edsl/conversation/next_speaker_utilities.py +93 -0
  30. edsl/coop/coop.py +191 -12
  31. edsl/coop/utils.py +20 -2
  32. edsl/data/Cache.py +55 -17
  33. edsl/data/CacheHandler.py +10 -9
  34. edsl/inference_services/AnthropicService.py +1 -0
  35. edsl/inference_services/DeepInfraService.py +20 -13
  36. edsl/inference_services/GoogleService.py +7 -1
  37. edsl/inference_services/InferenceServicesCollection.py +33 -7
  38. edsl/inference_services/OpenAIService.py +17 -10
  39. edsl/inference_services/models_available_cache.py +69 -0
  40. edsl/inference_services/rate_limits_cache.py +25 -0
  41. edsl/inference_services/write_available.py +10 -0
  42. edsl/jobs/Jobs.py +240 -36
  43. edsl/jobs/buckets/BucketCollection.py +9 -3
  44. edsl/jobs/interviews/Interview.py +4 -1
  45. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +24 -10
  46. edsl/jobs/interviews/retry_management.py +4 -4
  47. edsl/jobs/runners/JobsRunnerAsyncio.py +87 -45
  48. edsl/jobs/runners/JobsRunnerStatusData.py +3 -3
  49. edsl/jobs/tasks/QuestionTaskCreator.py +4 -2
  50. edsl/language_models/LanguageModel.py +37 -44
  51. edsl/language_models/ModelList.py +96 -0
  52. edsl/language_models/registry.py +14 -0
  53. edsl/language_models/repair.py +95 -24
  54. edsl/notebooks/Notebook.py +119 -31
  55. edsl/questions/QuestionBase.py +109 -12
  56. edsl/questions/descriptors.py +5 -2
  57. edsl/questions/question_registry.py +7 -0
  58. edsl/results/Result.py +20 -8
  59. edsl/results/Results.py +85 -11
  60. edsl/results/ResultsDBMixin.py +3 -6
  61. edsl/results/ResultsExportMixin.py +47 -16
  62. edsl/results/ResultsToolsMixin.py +5 -5
  63. edsl/scenarios/Scenario.py +59 -5
  64. edsl/scenarios/ScenarioList.py +97 -40
  65. edsl/study/ObjectEntry.py +97 -0
  66. edsl/study/ProofOfWork.py +110 -0
  67. edsl/study/SnapShot.py +77 -0
  68. edsl/study/Study.py +491 -0
  69. edsl/study/__init__.py +2 -0
  70. edsl/surveys/Survey.py +79 -31
  71. edsl/surveys/SurveyExportMixin.py +21 -3
  72. edsl/utilities/__init__.py +1 -0
  73. edsl/utilities/gcp_bucket/__init__.py +0 -0
  74. edsl/utilities/gcp_bucket/cloud_storage.py +96 -0
  75. edsl/utilities/gcp_bucket/simple_example.py +9 -0
  76. edsl/utilities/interface.py +24 -28
  77. edsl/utilities/repair_functions.py +28 -0
  78. edsl/utilities/utilities.py +57 -2
  79. {edsl-0.1.27.dev2.dist-info → edsl-0.1.28.dist-info}/METADATA +43 -17
  80. {edsl-0.1.27.dev2.dist-info → edsl-0.1.28.dist-info}/RECORD +83 -55
  81. edsl-0.1.28.dist-info/entry_points.txt +3 -0
  82. edsl/conjure/RawResponseColumn.py +0 -327
  83. edsl/conjure/SurveyBuilder.py +0 -308
  84. edsl/conjure/SurveyBuilderCSV.py +0 -78
  85. edsl/conjure/SurveyBuilderSPSS.py +0 -118
  86. edsl/data/RemoteDict.py +0 -103
  87. {edsl-0.1.27.dev2.dist-info → edsl-0.1.28.dist-info}/LICENSE +0 -0
  88. {edsl-0.1.27.dev2.dist-info → edsl-0.1.28.dist-info}/WHEEL +0 -0
edsl/coop/coop.py CHANGED
@@ -106,7 +106,8 @@ class Coop:
106
106
  def create(
107
107
  self,
108
108
  object: EDSLObject,
109
- visibility: VisibilityType = "unlisted",
109
+ description: Optional[str] = None,
110
+ visibility: Optional[VisibilityType] = "unlisted",
110
111
  ) -> dict:
111
112
  """
112
113
  Create an EDSL object in the Coop server.
@@ -117,6 +118,7 @@ class Coop:
117
118
  uri=f"api/v0/object",
118
119
  method="POST",
119
120
  payload={
121
+ "description": description,
120
122
  "object_type": object_type,
121
123
  "json_string": json.dumps(
122
124
  object.to_dict(),
@@ -131,8 +133,9 @@ class Coop:
131
133
  return {
132
134
  "uuid": response_json.get("uuid"),
133
135
  "version": self._edsl_version,
136
+ "description": response_json.get("description"),
134
137
  "visibility": response_json.get("visibility"),
135
- "url": f"{self.url}/explore/{object_page}/{response_json.get('uuid')}",
138
+ "url": f"{self.url}/content/{response_json.get('uuid')}",
136
139
  }
137
140
 
138
141
  def get(
@@ -140,6 +143,7 @@ class Coop:
140
143
  object_type: ObjectType = None,
141
144
  uuid: Union[str, UUID] = None,
142
145
  url: str = None,
146
+ exec_profile=None,
143
147
  ) -> EDSLObject:
144
148
  """
145
149
  Retrieve an EDSL object either by object type & UUID, or by its url.
@@ -156,25 +160,37 @@ class Coop:
156
160
  elif not object_type and not uuid:
157
161
  raise Exception("Provide either object_type & UUID, or a url.")
158
162
  edsl_class = ObjectRegistry.object_type_to_edsl_class.get(object_type)
163
+ import time
164
+
165
+ start = time.time()
159
166
  response = self._send_server_request(
160
167
  uri=f"api/v0/object",
161
168
  method="GET",
162
169
  params={"type": object_type, "uuid": uuid},
163
170
  )
171
+ end = time.time()
172
+ if exec_profile:
173
+ print("Download exec time = ", end - start)
164
174
  self._resolve_server_response(response)
165
175
  json_string = response.json().get("json_string")
166
- return edsl_class.from_dict(json.loads(json_string))
176
+ start = time.time()
177
+ res_object = edsl_class.from_dict(json.loads(json_string))
178
+ end = time.time()
179
+ if exec_profile:
180
+ print("Creating object exec time = ", end - start)
181
+ return res_object
167
182
 
168
183
  def _get_base(
169
184
  self,
170
185
  cls: EDSLObject,
171
186
  uuid: Union[str, UUID],
187
+ exec_profile=None,
172
188
  ) -> EDSLObject:
173
189
  """
174
190
  Used by the Base class to offer a get functionality.
175
191
  """
176
192
  object_type = ObjectRegistry.get_object_type_by_edsl_class(cls)
177
- return self.get(object_type, uuid)
193
+ return self.get(object_type, uuid, exec_profile=exec_profile)
178
194
 
179
195
  def get_all(self, object_type: ObjectType) -> list[EDSLObject]:
180
196
  """
@@ -193,6 +209,7 @@ class Coop:
193
209
  "object": edsl_class.from_dict(json.loads(o.get("json_string"))),
194
210
  "uuid": o.get("uuid"),
195
211
  "version": o.get("version"),
212
+ "description": o.get("description"),
196
213
  "visibility": o.get("visibility"),
197
214
  "url": f"{self.url}/explore/{object_page}/{o.get('uuid')}",
198
215
  }
@@ -227,17 +244,36 @@ class Coop:
227
244
  self,
228
245
  object_type: ObjectType,
229
246
  uuid: Union[str, UUID],
230
- visibility: VisibilityType,
247
+ description: Optional[str] = None,
248
+ value: Optional[EDSLObject] = None,
249
+ visibility: Optional[VisibilityType] = None,
231
250
  ) -> dict:
232
251
  """
233
252
  Change the attributes of an uploaded object
234
253
  - Only supports visibility for now
235
254
  """
255
+ if description is None and visibility is None and value is None:
256
+ raise Exception("Nothing to patch.")
257
+ if value is not None:
258
+ value_type = ObjectRegistry.get_object_type_by_edsl_class(value)
259
+ if value_type != object_type:
260
+ raise Exception(f"Object type mismatch: {object_type=} {value_type=}")
236
261
  response = self._send_server_request(
237
262
  uri=f"api/v0/object",
238
263
  method="PATCH",
239
264
  params={"type": object_type, "uuid": uuid},
240
- payload={"visibility": visibility},
265
+ payload={
266
+ "description": description,
267
+ "json_string": (
268
+ json.dumps(
269
+ value.to_dict(),
270
+ default=self._json_handle_none,
271
+ )
272
+ if value
273
+ else None
274
+ ),
275
+ "visibility": visibility,
276
+ },
241
277
  )
242
278
  self._resolve_server_response(response)
243
279
  return response.json()
@@ -246,13 +282,15 @@ class Coop:
246
282
  self,
247
283
  cls: EDSLObject,
248
284
  uuid: Union[str, UUID],
249
- visibility: VisibilityType,
285
+ description: Optional[str] = None,
286
+ value: Optional[EDSLObject] = None,
287
+ visibility: Optional[VisibilityType] = None,
250
288
  ) -> dict:
251
289
  """
252
290
  Used by the Base class to offer a patch functionality.
253
291
  """
254
292
  object_type = ObjectRegistry.get_object_type_by_edsl_class(cls)
255
- return self.patch(object_type, uuid, visibility)
293
+ return self.patch(object_type, uuid, description, value, visibility)
256
294
 
257
295
  ################
258
296
  # Remote Cache
@@ -261,9 +299,19 @@ class Coop:
261
299
  self,
262
300
  cache_entry: CacheEntry,
263
301
  visibility: VisibilityType = "private",
302
+ description: Optional[str] = None,
264
303
  ) -> dict:
265
304
  """
266
305
  Create a single remote cache entry.
306
+ If an entry with the same key already exists in the database, update it instead.
307
+
308
+ :param cache_entry: The cache entry to send to the server.
309
+ :param visibility: The visibility of the cache entry.
310
+ :param optional description: A description for this entry in the remote cache.
311
+
312
+ >>> entry = CacheEntry.example()
313
+ >>> coop.remote_cache_create(cache_entry=entry)
314
+ {'status': 'success', 'created_entry_count': 1, 'updated_entry_count': 0}
267
315
  """
268
316
  response = self._send_server_request(
269
317
  uri="api/v0/remote-cache",
@@ -272,24 +320,44 @@ class Coop:
272
320
  "json_string": json.dumps(cache_entry.to_dict()),
273
321
  "version": self._edsl_version,
274
322
  "visibility": visibility,
323
+ "description": description,
275
324
  },
276
325
  )
277
326
  self._resolve_server_response(response)
327
+ response_json = response.json()
328
+ created_entry_count = response_json.get("created_entry_count", 0)
329
+ if created_entry_count > 0:
330
+ self.remote_cache_create_log(
331
+ response,
332
+ description="Upload new cache entries to server",
333
+ cache_entry_count=created_entry_count,
334
+ )
278
335
  return response.json()
279
336
 
280
337
  def remote_cache_create_many(
281
338
  self,
282
339
  cache_entries: list[CacheEntry],
283
340
  visibility: VisibilityType = "private",
341
+ description: Optional[str] = None,
284
342
  ) -> dict:
285
343
  """
286
344
  Create many remote cache entries.
345
+ If an entry with the same key already exists in the database, update it instead.
346
+
347
+ :param cache_entries: The list of cache entries to send to the server.
348
+ :param visibility: The visibility of the cache entries.
349
+ :param optional description: A description for these entries in the remote cache.
350
+
351
+ >>> entries = [CacheEntry.example(randomize=True) for _ in range(10)]
352
+ >>> coop.remote_cache_create_many(cache_entries=entries)
353
+ {'status': 'success', 'created_entry_count': 10, 'updated_entry_count': 0}
287
354
  """
288
355
  payload = [
289
356
  {
290
357
  "json_string": json.dumps(c.to_dict()),
291
358
  "version": self._edsl_version,
292
359
  "visibility": visibility,
360
+ "description": description,
293
361
  }
294
362
  for c in cache_entries
295
363
  ]
@@ -299,6 +367,14 @@ class Coop:
299
367
  payload=payload,
300
368
  )
301
369
  self._resolve_server_response(response)
370
+ response_json = response.json()
371
+ created_entry_count = response_json.get("created_entry_count", 0)
372
+ if created_entry_count > 0:
373
+ self.remote_cache_create_log(
374
+ response,
375
+ description="Upload new cache entries to server",
376
+ cache_entry_count=created_entry_count,
377
+ )
302
378
  return response.json()
303
379
 
304
380
  def remote_cache_get(
@@ -307,7 +383,11 @@ class Coop:
307
383
  ) -> list[CacheEntry]:
308
384
  """
309
385
  Get all remote cache entries.
310
- - optional exclude_keys: exclude CacheEntry objects with these keys.
386
+
387
+ :param optional exclude_keys: Exclude CacheEntry objects with these keys.
388
+
389
+ >>> coop.remote_cache_get()
390
+ [CacheEntry(...), CacheEntry(...), ...]
311
391
  """
312
392
  if exclude_keys is None:
313
393
  exclude_keys = []
@@ -322,15 +402,93 @@ class Coop:
322
402
  for v in response.json()
323
403
  ]
324
404
 
405
+ def remote_cache_get_diff(
406
+ self,
407
+ client_cacheentry_keys: list[str],
408
+ ) -> dict:
409
+ """
410
+ Get the difference between local and remote cache entries for a user.
411
+ """
412
+ response = self._send_server_request(
413
+ uri="api/v0/remote-cache/get-diff",
414
+ method="POST",
415
+ payload={"keys": client_cacheentry_keys},
416
+ )
417
+ self._resolve_server_response(response)
418
+ response_json = response.json()
419
+ response_dict = {
420
+ "client_missing_cacheentries": [
421
+ CacheEntry.from_dict(json.loads(c.get("json_string")))
422
+ for c in response_json.get("client_missing_cacheentries", [])
423
+ ],
424
+ "server_missing_cacheentry_keys": response_json.get(
425
+ "server_missing_cacheentry_keys", []
426
+ ),
427
+ }
428
+ downloaded_entry_count = len(response_dict["client_missing_cacheentries"])
429
+ if downloaded_entry_count > 0:
430
+ self.remote_cache_create_log(
431
+ response,
432
+ description="Download missing cache entries to client",
433
+ cache_entry_count=downloaded_entry_count,
434
+ )
435
+ return response_dict
436
+
325
437
  def remote_cache_clear(self) -> dict:
326
438
  """
327
439
  Clear all remote cache entries.
440
+
441
+ >>> entries = [CacheEntry.example(randomize=True) for _ in range(10)]
442
+ >>> coop.remote_cache_create_many(cache_entries=entries)
443
+ >>> coop.remote_cache_clear()
444
+ {'status': 'success', 'deleted_entry_count': 10}
328
445
  """
329
446
  response = self._send_server_request(
330
447
  uri="api/v0/remote-cache/delete-all",
331
448
  method="DELETE",
332
449
  )
333
450
  self._resolve_server_response(response)
451
+ response_json = response.json()
452
+ deleted_entry_count = response_json.get("deleted_entry_count", 0)
453
+ if deleted_entry_count > 0:
454
+ self.remote_cache_create_log(
455
+ response,
456
+ description="Clear cache entries",
457
+ cache_entry_count=deleted_entry_count,
458
+ )
459
+ return response.json()
460
+
461
+ def remote_cache_create_log(
462
+ self, response: requests.Response, description: str, cache_entry_count: int
463
+ ) -> Union[dict, None]:
464
+ """
465
+ If a remote cache action has been completed successfully,
466
+ log the action.
467
+ """
468
+ if 200 <= response.status_code < 300:
469
+ log_response = self._send_server_request(
470
+ uri="api/v0/remote-cache-log",
471
+ method="POST",
472
+ payload={
473
+ "description": description,
474
+ "cache_entry_count": cache_entry_count,
475
+ },
476
+ )
477
+ self._resolve_server_response(log_response)
478
+ return response.json()
479
+
480
+ def remote_cache_clear_log(self) -> dict:
481
+ """
482
+ Clear all remote cache log entries.
483
+
484
+ >>> coop.remote_cache_clear_log()
485
+ {'status': 'success'}
486
+ """
487
+ response = self._send_server_request(
488
+ uri="api/v0/remote-cache-log/delete-all",
489
+ method="DELETE",
490
+ )
491
+ self._resolve_server_response(response)
334
492
  return response.json()
335
493
 
336
494
  ################
@@ -438,7 +596,7 @@ if __name__ == "__main__":
438
596
  Agent,
439
597
  AgentList,
440
598
  Cache,
441
- Jobs,
599
+ Notebook,
442
600
  QuestionMultipleChoice,
443
601
  Results,
444
602
  Scenario,
@@ -450,7 +608,7 @@ if __name__ == "__main__":
450
608
  ("agent", Agent),
451
609
  ("agent_list", AgentList),
452
610
  ("cache", Cache),
453
- ("job", Jobs),
611
+ ("notebook", Notebook),
454
612
  ("question", QuestionMultipleChoice),
455
613
  ("results", Results),
456
614
  ("scenario", Scenario),
@@ -469,7 +627,9 @@ if __name__ == "__main__":
469
627
  response_1 = coop.create(example)
470
628
  response_2 = coop.create(cls.example(), visibility="private")
471
629
  response_3 = coop.create(cls.example(), visibility="public")
472
- response_4 = coop.create(cls.example(), visibility="unlisted")
630
+ response_4 = coop.create(
631
+ cls.example(), visibility="unlisted", description="hey"
632
+ )
473
633
  # 3. Retrieve all objects
474
634
  objects = coop.get_all(object_type)
475
635
  assert len(objects) == 4
@@ -486,14 +646,33 @@ if __name__ == "__main__":
486
646
  coop.patch(
487
647
  object_type=object_type, uuid=item.get("uuid"), visibility="private"
488
648
  )
649
+ # 6. Change description of all objects
650
+ for item in objects:
651
+ coop.patch(
652
+ object_type=object_type, uuid=item.get("uuid"), description="hey"
653
+ )
489
654
  # 7. Delete all objects
490
655
  for item in objects:
491
656
  coop.delete(object_type=object_type, uuid=item.get("uuid"))
492
657
  assert len(coop.get_all(object_type)) == 0
493
658
 
659
+ # a simple example
660
+ from edsl import Coop, QuestionMultipleChoice, QuestionFreeText
661
+
662
+ coop = Coop(api_key="b")
494
663
  response = QuestionMultipleChoice.example().push()
495
664
  QuestionMultipleChoice.pull(response.get("uuid"))
496
665
  coop.patch(object_type="question", uuid=response.get("uuid"), visibility="public")
666
+ coop.patch(
667
+ object_type="question",
668
+ uuid=response.get("uuid"),
669
+ description="crazy new description",
670
+ )
671
+ coop.patch(
672
+ object_type="question",
673
+ uuid=response.get("uuid"),
674
+ value=QuestionFreeText.example(),
675
+ )
497
676
  coop.delete(object_type="question", uuid=response.get("uuid"))
498
677
 
499
678
  ##############
edsl/coop/utils.py CHANGED
@@ -1,5 +1,15 @@
1
- from edsl import Agent, AgentList, Cache, Jobs, Results, Scenario, ScenarioList, Survey
2
- from edsl.notebooks import Notebook
1
+ from edsl import (
2
+ Agent,
3
+ AgentList,
4
+ Cache,
5
+ Jobs,
6
+ Notebook,
7
+ Results,
8
+ Scenario,
9
+ ScenarioList,
10
+ Survey,
11
+ Study,
12
+ )
3
13
  from edsl.questions import QuestionBase
4
14
  from typing import Literal, Type, Union
5
15
 
@@ -14,6 +24,7 @@ EDSLObject = Union[
14
24
  Scenario,
15
25
  ScenarioList,
16
26
  Survey,
27
+ Study,
17
28
  ]
18
29
 
19
30
  ObjectType = Literal[
@@ -27,6 +38,7 @@ ObjectType = Literal[
27
38
  "scenario",
28
39
  "scenario_list",
29
40
  "survey",
41
+ "study",
30
42
  ]
31
43
 
32
44
  ObjectPage = Literal[
@@ -40,6 +52,7 @@ ObjectPage = Literal[
40
52
  "scenarios",
41
53
  "scenariolists",
42
54
  "surveys",
55
+ "studies",
43
56
  ]
44
57
 
45
58
  VisibilityType = Literal[
@@ -105,6 +118,11 @@ class ObjectRegistry:
105
118
  "edsl_class": Survey,
106
119
  "object_page": "surveys",
107
120
  },
121
+ {
122
+ "object_type": "study",
123
+ "edsl_class": Study,
124
+ "object_page": "studies",
125
+ },
108
126
  ]
109
127
  object_type_to_edsl_class = {o["object_type"]: o["edsl_class"] for o in objects}
110
128
  edsl_class_to_object_type = {
edsl/data/Cache.py CHANGED
@@ -12,6 +12,7 @@ from edsl.config import CONFIG
12
12
  from edsl.data.CacheEntry import CacheEntry
13
13
  from edsl.data.SQLiteDict import SQLiteDict
14
14
  from edsl.Base import Base
15
+ from edsl.utilities.utilities import dict_hash
15
16
 
16
17
  from edsl.utilities.decorators import (
17
18
  add_edsl_version,
@@ -24,7 +25,6 @@ class Cache(Base):
24
25
  A class that represents a cache of responses from a language model.
25
26
 
26
27
  :param data: The data to initialize the cache with.
27
- :param remote: Whether to sync the Cache with the server.
28
28
  :param immediate_write: Whether to write to the cache immediately after storing a new entry.
29
29
 
30
30
  Deprecated:
@@ -37,24 +37,51 @@ class Cache(Base):
37
37
  def __init__(
38
38
  self,
39
39
  *,
40
+ filename: Optional[str] = None,
40
41
  data: Optional[Union[SQLiteDict, dict]] = None,
41
- remote: bool = False,
42
42
  immediate_write: bool = True,
43
43
  method=None,
44
44
  ):
45
45
  """
46
46
  Create two dictionaries to store the cache data.
47
47
 
48
+ :param filename: The name of the file to read/write the cache from/to.
49
+ :param data: The data to initialize the cache with.
50
+ :param immediate_write: Whether to write to the cache immediately after storing a new entry.
51
+ :param method: The method of storage to use for the cache.
52
+
48
53
  """
49
- self.data = data or {}
54
+
50
55
  # self.data_at_init = data or {}
51
56
  self.fetched_data = {}
52
- self.remote = remote
53
57
  self.immediate_write = immediate_write
54
58
  self.method = method
55
59
  self.new_entries = {}
56
60
  self.new_entries_to_write_later = {}
57
61
  self.coop = None
62
+
63
+ self.filename = filename
64
+ if filename and data:
65
+ raise ValueError("Cannot provide both filename and data")
66
+ if filename is None and data is None:
67
+ data = {}
68
+ if data is not None:
69
+ self.data = data
70
+ if filename is not None:
71
+ self.data = {}
72
+ if filename.endswith(".jsonl"):
73
+ if os.path.exists(filename):
74
+ self.add_from_jsonl(filename)
75
+ else:
76
+ print(
77
+ f"File {filename} not found, but will write to this location."
78
+ )
79
+ elif filename.endswith(".db"):
80
+ if os.path.exists(filename):
81
+ self.add_from_sqlite(filename)
82
+ else:
83
+ raise ValueError("Invalid file extension. Must be .jsonl or .db")
84
+
58
85
  self._perform_checks()
59
86
 
60
87
  def rich_print(sefl):
@@ -81,10 +108,6 @@ class Cache(Base):
81
108
  raise Exception("Not all values are CacheEntry instances")
82
109
  if self.method is not None:
83
110
  warnings.warn("Argument `method` is deprecated", DeprecationWarning)
84
- if self.remote:
85
- from edsl.coop import Coop
86
-
87
- self.coop = Coop()
88
111
 
89
112
  ####################
90
113
  # READ/WRITE
@@ -267,6 +290,19 @@ class Cache(Base):
267
290
  for key, value in self.data.items():
268
291
  new_data[key] = value
269
292
 
293
+ def write(self, filename: Optional[str] = None) -> None:
294
+ """
295
+ Write the cache to a file at the specified location.
296
+ """
297
+ if filename is None:
298
+ filename = self.filename
299
+ if filename.endswith(".jsonl"):
300
+ self.write_jsonl(filename)
301
+ elif filename.endswith(".db"):
302
+ self.write_sqlite_db(filename)
303
+ else:
304
+ raise ValueError("Invalid file extension. Must be .jsonl or .db")
305
+
270
306
  def write_jsonl(self, filename: str) -> None:
271
307
  """
272
308
  Write the cache to a JSONL file.
@@ -295,11 +331,6 @@ class Cache(Base):
295
331
  """
296
332
  Run when a context is entered.
297
333
  """
298
- if self.remote:
299
- print("Syncing local and remote caches")
300
- exclude_keys = list(self.data.keys())
301
- cache_entries = self.coop.get_cache_entries(exclude_keys)
302
- self.add_from_dict({c.key: c for c in cache_entries}, write_now=True)
303
334
  return self
304
335
 
305
336
  def __exit__(self, exc_type, exc_value, traceback):
@@ -308,16 +339,21 @@ class Cache(Base):
308
339
  """
309
340
  for key, entry in self.new_entries_to_write_later.items():
310
341
  self.data[key] = entry
311
- if self.remote:
312
- _ = self.coop.create_cache_entries(cache_dict=self.new_entries)
313
342
 
314
343
  ####################
315
344
  # DUNDER / USEFUL
316
345
  ####################
346
+ def __hash__(self):
347
+ """Return the hash of the Cache."""
348
+ return dict_hash(self._to_dict())
349
+
350
+ def _to_dict(self) -> dict:
351
+ return {k: v.to_dict() for k, v in self.data.items()}
352
+
317
353
  @add_edsl_version
318
354
  def to_dict(self) -> dict:
319
355
  """Return the Cache as a dictionary."""
320
- return {k: v.to_dict() for k, v in self.data.items()}
356
+ return self._to_dict()
321
357
 
322
358
  def _repr_html_(self):
323
359
  from edsl.utilities.utilities import data_to_html
@@ -359,7 +395,9 @@ class Cache(Base):
359
395
  """
360
396
  Return a string representation of the Cache object.
361
397
  """
362
- return f"Cache(data = {repr(self.data)}, immediate_write={self.immediate_write}, remote={self.remote})"
398
+ return (
399
+ f"Cache(data = {repr(self.data)}, immediate_write={self.immediate_write})"
400
+ )
363
401
 
364
402
  ####################
365
403
  # EXAMPLES
edsl/data/CacheHandler.py CHANGED
@@ -9,22 +9,22 @@ from edsl.data.Cache import Cache
9
9
  from edsl.data.CacheEntry import CacheEntry
10
10
  from edsl.data.SQLiteDict import SQLiteDict
11
11
 
12
+ from edsl.config import CONFIG
13
+
12
14
 
13
15
  def set_session_cache(cache: Cache) -> None:
14
16
  """
15
17
  Set the session cache.
16
18
  """
17
- print("All calls to 'run' will now use this cache by default.")
18
- global _CACHE
19
- _CACHE = cache
19
+ CONFIG.EDSL_SESSION_CACHE = cache
20
20
 
21
21
 
22
22
  def unset_session_cache() -> None:
23
23
  """
24
24
  Unset the session cache.
25
25
  """
26
- global _CACHE
27
- _CACHE = None
26
+ if hasattr(CONFIG, "EDSL_SESSION_CACHE"):
27
+ del CONFIG.EDSL_SESSION_CACHE
28
28
 
29
29
 
30
30
  class CacheHandler:
@@ -49,7 +49,9 @@ class CacheHandler:
49
49
  dir_path = os.path.dirname(path)
50
50
  if dir_path and not os.path.exists(dir_path):
51
51
  os.makedirs(dir_path)
52
- print(f"Created cache directory: {dir_path}")
52
+ import warnings
53
+
54
+ warnings.warn(f"Created cache directory: {dir_path}")
53
55
 
54
56
  def gen_cache(self) -> Cache:
55
57
  """
@@ -58,9 +60,8 @@ class CacheHandler:
58
60
  if self.test:
59
61
  return Cache(data={})
60
62
 
61
- if "_CACHE" in globals() and _CACHE is not None:
62
- # print("Using globally-set cache.")
63
- return _CACHE
63
+ if hasattr(CONFIG, "EDSL_SESSION_CACHE"):
64
+ return CONFIG.EDSL_SESSION_CACHE
64
65
 
65
66
  cache = Cache(data=SQLiteDict(self.CACHE_PATH))
66
67
  return cache
edsl/inference_services/AnthropicService.py CHANGED
@@ -16,6 +16,7 @@ class AnthropicService(InferenceServiceABC):
16
16
  def available(cls):
17
17
  # TODO - replace with an API call
18
18
  return [
19
+ "claude-3-5-sonnet-20240620",
19
20
  "claude-3-opus-20240229",
20
21
  "claude-3-sonnet-20240229",
21
22
  "claude-3-haiku-20240307",