edsl 0.1.27.dev2__py3-none-any.whl → 0.1.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. edsl/Base.py +107 -30
  2. edsl/BaseDiff.py +260 -0
  3. edsl/__init__.py +25 -21
  4. edsl/__version__.py +1 -1
  5. edsl/agents/Agent.py +103 -46
  6. edsl/agents/AgentList.py +97 -13
  7. edsl/agents/Invigilator.py +23 -10
  8. edsl/agents/InvigilatorBase.py +19 -14
  9. edsl/agents/PromptConstructionMixin.py +342 -100
  10. edsl/agents/descriptors.py +5 -2
  11. edsl/base/Base.py +289 -0
  12. edsl/config.py +2 -1
  13. edsl/conjure/AgentConstructionMixin.py +152 -0
  14. edsl/conjure/Conjure.py +56 -0
  15. edsl/conjure/InputData.py +659 -0
  16. edsl/conjure/InputDataCSV.py +48 -0
  17. edsl/conjure/InputDataMixinQuestionStats.py +182 -0
  18. edsl/conjure/InputDataPyRead.py +91 -0
  19. edsl/conjure/InputDataSPSS.py +8 -0
  20. edsl/conjure/InputDataStata.py +8 -0
  21. edsl/conjure/QuestionOptionMixin.py +76 -0
  22. edsl/conjure/QuestionTypeMixin.py +23 -0
  23. edsl/conjure/RawQuestion.py +65 -0
  24. edsl/conjure/SurveyResponses.py +7 -0
  25. edsl/conjure/__init__.py +9 -4
  26. edsl/conjure/examples/placeholder.txt +0 -0
  27. edsl/conjure/naming_utilities.py +263 -0
  28. edsl/conjure/utilities.py +165 -28
  29. edsl/conversation/Conversation.py +238 -0
  30. edsl/conversation/car_buying.py +58 -0
  31. edsl/conversation/mug_negotiation.py +81 -0
  32. edsl/conversation/next_speaker_utilities.py +93 -0
  33. edsl/coop/coop.py +337 -121
  34. edsl/coop/utils.py +56 -70
  35. edsl/data/Cache.py +74 -22
  36. edsl/data/CacheHandler.py +10 -9
  37. edsl/data/SQLiteDict.py +11 -3
  38. edsl/inference_services/AnthropicService.py +1 -0
  39. edsl/inference_services/DeepInfraService.py +20 -13
  40. edsl/inference_services/GoogleService.py +7 -1
  41. edsl/inference_services/InferenceServicesCollection.py +33 -7
  42. edsl/inference_services/OpenAIService.py +17 -10
  43. edsl/inference_services/models_available_cache.py +69 -0
  44. edsl/inference_services/rate_limits_cache.py +25 -0
  45. edsl/inference_services/write_available.py +10 -0
  46. edsl/jobs/Answers.py +15 -1
  47. edsl/jobs/Jobs.py +322 -73
  48. edsl/jobs/buckets/BucketCollection.py +9 -3
  49. edsl/jobs/buckets/ModelBuckets.py +4 -2
  50. edsl/jobs/buckets/TokenBucket.py +1 -2
  51. edsl/jobs/interviews/Interview.py +7 -10
  52. edsl/jobs/interviews/InterviewStatusMixin.py +3 -3
  53. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +39 -20
  54. edsl/jobs/interviews/retry_management.py +4 -4
  55. edsl/jobs/runners/JobsRunnerAsyncio.py +103 -65
  56. edsl/jobs/runners/JobsRunnerStatusData.py +3 -3
  57. edsl/jobs/tasks/QuestionTaskCreator.py +4 -2
  58. edsl/jobs/tasks/TaskHistory.py +4 -3
  59. edsl/language_models/LanguageModel.py +42 -55
  60. edsl/language_models/ModelList.py +96 -0
  61. edsl/language_models/registry.py +14 -0
  62. edsl/language_models/repair.py +97 -25
  63. edsl/notebooks/Notebook.py +157 -32
  64. edsl/prompts/Prompt.py +31 -19
  65. edsl/questions/QuestionBase.py +145 -23
  66. edsl/questions/QuestionBudget.py +5 -6
  67. edsl/questions/QuestionCheckBox.py +7 -3
  68. edsl/questions/QuestionExtract.py +5 -3
  69. edsl/questions/QuestionFreeText.py +3 -3
  70. edsl/questions/QuestionFunctional.py +0 -3
  71. edsl/questions/QuestionList.py +3 -4
  72. edsl/questions/QuestionMultipleChoice.py +16 -8
  73. edsl/questions/QuestionNumerical.py +4 -3
  74. edsl/questions/QuestionRank.py +5 -3
  75. edsl/questions/__init__.py +4 -3
  76. edsl/questions/descriptors.py +9 -4
  77. edsl/questions/question_registry.py +27 -31
  78. edsl/questions/settings.py +1 -1
  79. edsl/results/Dataset.py +31 -0
  80. edsl/results/DatasetExportMixin.py +493 -0
  81. edsl/results/Result.py +42 -82
  82. edsl/results/Results.py +178 -66
  83. edsl/results/ResultsDBMixin.py +10 -9
  84. edsl/results/ResultsExportMixin.py +23 -507
  85. edsl/results/ResultsGGMixin.py +3 -3
  86. edsl/results/ResultsToolsMixin.py +9 -9
  87. edsl/scenarios/FileStore.py +140 -0
  88. edsl/scenarios/Scenario.py +59 -6
  89. edsl/scenarios/ScenarioList.py +138 -52
  90. edsl/scenarios/ScenarioListExportMixin.py +32 -0
  91. edsl/scenarios/ScenarioListPdfMixin.py +2 -1
  92. edsl/scenarios/__init__.py +1 -0
  93. edsl/study/ObjectEntry.py +173 -0
  94. edsl/study/ProofOfWork.py +113 -0
  95. edsl/study/SnapShot.py +73 -0
  96. edsl/study/Study.py +498 -0
  97. edsl/study/__init__.py +4 -0
  98. edsl/surveys/MemoryPlan.py +11 -4
  99. edsl/surveys/Survey.py +124 -37
  100. edsl/surveys/SurveyExportMixin.py +25 -5
  101. edsl/surveys/SurveyFlowVisualizationMixin.py +6 -4
  102. edsl/tools/plotting.py +4 -2
  103. edsl/utilities/__init__.py +21 -20
  104. edsl/utilities/gcp_bucket/__init__.py +0 -0
  105. edsl/utilities/gcp_bucket/cloud_storage.py +96 -0
  106. edsl/utilities/gcp_bucket/simple_example.py +9 -0
  107. edsl/utilities/interface.py +90 -73
  108. edsl/utilities/repair_functions.py +28 -0
  109. edsl/utilities/utilities.py +59 -6
  110. {edsl-0.1.27.dev2.dist-info → edsl-0.1.29.dist-info}/METADATA +42 -15
  111. edsl-0.1.29.dist-info/RECORD +203 -0
  112. edsl/conjure/RawResponseColumn.py +0 -327
  113. edsl/conjure/SurveyBuilder.py +0 -308
  114. edsl/conjure/SurveyBuilderCSV.py +0 -78
  115. edsl/conjure/SurveyBuilderSPSS.py +0 -118
  116. edsl/data/RemoteDict.py +0 -103
  117. edsl-0.1.27.dev2.dist-info/RECORD +0 -172
  118. {edsl-0.1.27.dev2.dist-info → edsl-0.1.29.dist-info}/LICENSE +0 -0
  119. {edsl-0.1.27.dev2.dist-info → edsl-0.1.29.dist-info}/WHEEL +0 -0
edsl/coop/coop.py CHANGED
@@ -5,8 +5,14 @@ import requests
5
5
  from typing import Any, Optional, Union, Literal
6
6
  from uuid import UUID
7
7
  import edsl
8
- from edsl import CONFIG, CacheEntry
9
- from edsl.coop.utils import EDSLObject, ObjectRegistry, ObjectType, VisibilityType
8
+ from edsl import CONFIG, CacheEntry, Jobs, Survey
9
+ from edsl.coop.utils import (
10
+ EDSLObject,
11
+ ObjectRegistry,
12
+ ObjectType,
13
+ RemoteJobStatus,
14
+ VisibilityType,
15
+ )
10
16
 
11
17
 
12
18
  class Coop:
@@ -88,6 +94,18 @@ class Coop:
88
94
  if value is None:
89
95
  return "null"
90
96
 
97
+ def _resolve_uuid(
98
+ self, uuid: Union[str, UUID] = None, url: str = None
99
+ ) -> Union[str, UUID]:
100
+ """
101
+ Resolve the uuid from a uuid or a url.
102
+ """
103
+ if not url and not uuid:
104
+ raise Exception("No uuid or url provided for the object.")
105
+ if not uuid and url:
106
+ uuid = url.split("/")[-1]
107
+ return uuid
108
+
91
109
  @property
92
110
  def edsl_settings(self) -> dict:
93
111
  """
@@ -100,28 +118,26 @@ class Coop:
100
118
  ################
101
119
  # Objects
102
120
  ################
103
-
104
- # TODO: add URL to get and get_all methods
105
-
106
121
  def create(
107
122
  self,
108
123
  object: EDSLObject,
109
- visibility: VisibilityType = "unlisted",
124
+ description: Optional[str] = None,
125
+ visibility: Optional[VisibilityType] = "unlisted",
110
126
  ) -> dict:
111
127
  """
112
128
  Create an EDSL object in the Coop server.
113
129
  """
114
130
  object_type = ObjectRegistry.get_object_type_by_edsl_class(object)
115
- object_page = ObjectRegistry.get_object_page_by_object_type(object_type)
116
131
  response = self._send_server_request(
117
132
  uri=f"api/v0/object",
118
133
  method="POST",
119
134
  payload={
120
- "object_type": object_type,
135
+ "description": description,
121
136
  "json_string": json.dumps(
122
137
  object.to_dict(),
123
138
  default=self._json_handle_none,
124
139
  ),
140
+ "object_type": object_type,
125
141
  "visibility": visibility,
126
142
  "version": self._edsl_version,
127
143
  },
@@ -129,59 +145,51 @@ class Coop:
129
145
  self._resolve_server_response(response)
130
146
  response_json = response.json()
131
147
  return {
148
+ "description": response_json.get("description"),
149
+ "object_type": object_type,
150
+ "url": f"{self.url}/content/{response_json.get('uuid')}",
132
151
  "uuid": response_json.get("uuid"),
133
152
  "version": self._edsl_version,
134
153
  "visibility": response_json.get("visibility"),
135
- "url": f"{self.url}/explore/{object_page}/{response_json.get('uuid')}",
136
154
  }
137
155
 
138
156
  def get(
139
157
  self,
140
- object_type: ObjectType = None,
141
158
  uuid: Union[str, UUID] = None,
142
159
  url: str = None,
160
+ expected_object_type: Optional[ObjectType] = None,
143
161
  ) -> EDSLObject:
144
162
  """
145
- Retrieve an EDSL object either by object type & UUID, or by its url.
146
- - The object has to belong to the user or not be private.
147
- - Returns the initialized object class instance.
163
+ Retrieve an EDSL object by its uuid or its url.
164
+ - If the object's visibility is private, the user must be the owner.
165
+ - Optionally, check if the retrieved object is of a certain type.
148
166
 
149
- :param object_type: the type of object to retrieve.
150
167
  :param uuid: the uuid of the object either in str or UUID format.
151
168
  :param url: the url of the object.
169
+ :param expected_object_type: the expected type of the object.
170
+
171
+ :return: the object instance.
152
172
  """
153
- if url:
154
- object_type = url.split("/")[-2]
155
- uuid = url.split("/")[-1]
156
- elif not object_type and not uuid:
157
- raise Exception("Provide either object_type & UUID, or a url.")
158
- edsl_class = ObjectRegistry.object_type_to_edsl_class.get(object_type)
173
+ uuid = self._resolve_uuid(uuid, url)
159
174
  response = self._send_server_request(
160
175
  uri=f"api/v0/object",
161
176
  method="GET",
162
- params={"type": object_type, "uuid": uuid},
177
+ params={"uuid": uuid},
163
178
  )
164
179
  self._resolve_server_response(response)
165
180
  json_string = response.json().get("json_string")
166
- return edsl_class.from_dict(json.loads(json_string))
167
-
168
- def _get_base(
169
- self,
170
- cls: EDSLObject,
171
- uuid: Union[str, UUID],
172
- ) -> EDSLObject:
173
- """
174
- Used by the Base class to offer a get functionality.
175
- """
176
- object_type = ObjectRegistry.get_object_type_by_edsl_class(cls)
177
- return self.get(object_type, uuid)
181
+ object_type = response.json().get("object_type")
182
+ if expected_object_type and object_type != expected_object_type:
183
+ raise Exception(f"Expected {expected_object_type=} but got {object_type=}")
184
+ edsl_class = ObjectRegistry.object_type_to_edsl_class.get(object_type)
185
+ object = edsl_class.from_dict(json.loads(json_string))
186
+ return object
178
187
 
179
- def get_all(self, object_type: ObjectType) -> list[EDSLObject]:
188
+ def get_all(self, object_type: ObjectType) -> list[dict[str, Any]]:
180
189
  """
181
190
  Retrieve all objects of a certain type associated with the user.
182
191
  """
183
192
  edsl_class = ObjectRegistry.object_type_to_edsl_class.get(object_type)
184
- object_page = ObjectRegistry.get_object_page_by_object_type(object_type)
185
193
  response = self._send_server_request(
186
194
  uri=f"api/v0/objects",
187
195
  method="GET",
@@ -193,67 +201,62 @@ class Coop:
193
201
  "object": edsl_class.from_dict(json.loads(o.get("json_string"))),
194
202
  "uuid": o.get("uuid"),
195
203
  "version": o.get("version"),
204
+ "description": o.get("description"),
196
205
  "visibility": o.get("visibility"),
197
- "url": f"{self.url}/explore/{object_page}/{o.get('uuid')}",
206
+ "url": f"{self.url}/content/{o.get('uuid')}",
198
207
  }
199
208
  for o in response.json()
200
209
  ]
201
210
  return objects
202
211
 
203
- def delete(self, object_type: ObjectType, uuid: Union[str, UUID]) -> dict:
212
+ def delete(self, uuid: Union[str, UUID] = None, url: str = None) -> dict:
204
213
  """
205
214
  Delete an object from the server.
206
215
  """
216
+ uuid = self._resolve_uuid(uuid, url)
207
217
  response = self._send_server_request(
208
218
  uri=f"api/v0/object",
209
219
  method="DELETE",
210
- params={"type": object_type, "uuid": uuid},
220
+ params={"uuid": uuid},
211
221
  )
212
222
  self._resolve_server_response(response)
213
223
  return response.json()
214
224
 
215
- def _delete_base(
216
- self,
217
- cls: EDSLObject,
218
- uuid: Union[str, UUID],
219
- ) -> dict:
220
- """
221
- Used by the Base class to offer a delete functionality.
222
- """
223
- object_type = ObjectRegistry.get_object_type_by_edsl_class(cls)
224
- return self.delete(object_type, uuid)
225
-
226
225
  def patch(
227
226
  self,
228
- object_type: ObjectType,
229
- uuid: Union[str, UUID],
230
- visibility: VisibilityType,
227
+ uuid: Union[str, UUID] = None,
228
+ url: str = None,
229
+ description: Optional[str] = None,
230
+ value: Optional[EDSLObject] = None,
231
+ visibility: Optional[VisibilityType] = None,
231
232
  ) -> dict:
232
233
  """
233
234
  Change the attributes of an uploaded object
234
235
  - Only supports visibility for now
235
236
  """
237
+ if description is None and visibility is None and value is None:
238
+ raise Exception("Nothing to patch.")
239
+ uuid = self._resolve_uuid(uuid, url)
236
240
  response = self._send_server_request(
237
241
  uri=f"api/v0/object",
238
242
  method="PATCH",
239
- params={"type": object_type, "uuid": uuid},
240
- payload={"visibility": visibility},
243
+ params={"uuid": uuid},
244
+ payload={
245
+ "description": description,
246
+ "json_string": (
247
+ json.dumps(
248
+ value.to_dict(),
249
+ default=self._json_handle_none,
250
+ )
251
+ if value
252
+ else None
253
+ ),
254
+ "visibility": visibility,
255
+ },
241
256
  )
242
257
  self._resolve_server_response(response)
243
258
  return response.json()
244
259
 
245
- def _patch_base(
246
- self,
247
- cls: EDSLObject,
248
- uuid: Union[str, UUID],
249
- visibility: VisibilityType,
250
- ) -> dict:
251
- """
252
- Used by the Base class to offer a patch functionality.
253
- """
254
- object_type = ObjectRegistry.get_object_type_by_edsl_class(cls)
255
- return self.patch(object_type, uuid, visibility)
256
-
257
260
  ################
258
261
  # Remote Cache
259
262
  ################
@@ -261,9 +264,19 @@ class Coop:
261
264
  self,
262
265
  cache_entry: CacheEntry,
263
266
  visibility: VisibilityType = "private",
267
+ description: Optional[str] = None,
264
268
  ) -> dict:
265
269
  """
266
270
  Create a single remote cache entry.
271
+ If an entry with the same key already exists in the database, update it instead.
272
+
273
+ :param cache_entry: The cache entry to send to the server.
274
+ :param visibility: The visibility of the cache entry.
275
+ :param optional description: A description for this entry in the remote cache.
276
+
277
+ >>> entry = CacheEntry.example()
278
+ >>> coop.remote_cache_create(cache_entry=entry)
279
+ {'status': 'success', 'created_entry_count': 1, 'updated_entry_count': 0}
267
280
  """
268
281
  response = self._send_server_request(
269
282
  uri="api/v0/remote-cache",
@@ -272,24 +285,44 @@ class Coop:
272
285
  "json_string": json.dumps(cache_entry.to_dict()),
273
286
  "version": self._edsl_version,
274
287
  "visibility": visibility,
288
+ "description": description,
275
289
  },
276
290
  )
277
291
  self._resolve_server_response(response)
292
+ response_json = response.json()
293
+ created_entry_count = response_json.get("created_entry_count", 0)
294
+ if created_entry_count > 0:
295
+ self.remote_cache_create_log(
296
+ response,
297
+ description="Upload new cache entries to server",
298
+ cache_entry_count=created_entry_count,
299
+ )
278
300
  return response.json()
279
301
 
280
302
  def remote_cache_create_many(
281
303
  self,
282
304
  cache_entries: list[CacheEntry],
283
305
  visibility: VisibilityType = "private",
306
+ description: Optional[str] = None,
284
307
  ) -> dict:
285
308
  """
286
309
  Create many remote cache entries.
310
+ If an entry with the same key already exists in the database, update it instead.
311
+
312
+ :param cache_entries: The list of cache entries to send to the server.
313
+ :param visibility: The visibility of the cache entries.
314
+ :param optional description: A description for these entries in the remote cache.
315
+
316
+ >>> entries = [CacheEntry.example(randomize=True) for _ in range(10)]
317
+ >>> coop.remote_cache_create_many(cache_entries=entries)
318
+ {'status': 'success', 'created_entry_count': 10, 'updated_entry_count': 0}
287
319
  """
288
320
  payload = [
289
321
  {
290
322
  "json_string": json.dumps(c.to_dict()),
291
323
  "version": self._edsl_version,
292
324
  "visibility": visibility,
325
+ "description": description,
293
326
  }
294
327
  for c in cache_entries
295
328
  ]
@@ -299,6 +332,14 @@ class Coop:
299
332
  payload=payload,
300
333
  )
301
334
  self._resolve_server_response(response)
335
+ response_json = response.json()
336
+ created_entry_count = response_json.get("created_entry_count", 0)
337
+ if created_entry_count > 0:
338
+ self.remote_cache_create_log(
339
+ response,
340
+ description="Upload new cache entries to server",
341
+ cache_entry_count=created_entry_count,
342
+ )
302
343
  return response.json()
303
344
 
304
345
  def remote_cache_get(
@@ -307,7 +348,11 @@ class Coop:
307
348
  ) -> list[CacheEntry]:
308
349
  """
309
350
  Get all remote cache entries.
310
- - optional exclude_keys: exclude CacheEntry objects with these keys.
351
+
352
+ :param optional exclude_keys: Exclude CacheEntry objects with these keys.
353
+
354
+ >>> coop.remote_cache_get()
355
+ [CacheEntry(...), CacheEntry(...), ...]
311
356
  """
312
357
  if exclude_keys is None:
313
358
  exclude_keys = []
@@ -322,23 +367,149 @@ class Coop:
322
367
  for v in response.json()
323
368
  ]
324
369
 
370
+ def remote_cache_get_diff(
371
+ self,
372
+ client_cacheentry_keys: list[str],
373
+ ) -> dict:
374
+ """
375
+ Get the difference between local and remote cache entries for a user.
376
+ """
377
+ response = self._send_server_request(
378
+ uri="api/v0/remote-cache/get-diff",
379
+ method="POST",
380
+ payload={"keys": client_cacheentry_keys},
381
+ )
382
+ self._resolve_server_response(response)
383
+ response_json = response.json()
384
+ response_dict = {
385
+ "client_missing_cacheentries": [
386
+ CacheEntry.from_dict(json.loads(c.get("json_string")))
387
+ for c in response_json.get("client_missing_cacheentries", [])
388
+ ],
389
+ "server_missing_cacheentry_keys": response_json.get(
390
+ "server_missing_cacheentry_keys", []
391
+ ),
392
+ }
393
+ downloaded_entry_count = len(response_dict["client_missing_cacheentries"])
394
+ if downloaded_entry_count > 0:
395
+ self.remote_cache_create_log(
396
+ response,
397
+ description="Download missing cache entries to client",
398
+ cache_entry_count=downloaded_entry_count,
399
+ )
400
+ return response_dict
401
+
325
402
  def remote_cache_clear(self) -> dict:
326
403
  """
327
404
  Clear all remote cache entries.
405
+
406
+ >>> entries = [CacheEntry.example(randomize=True) for _ in range(10)]
407
+ >>> coop.remote_cache_create_many(cache_entries=entries)
408
+ >>> coop.remote_cache_clear()
409
+ {'status': 'success', 'deleted_entry_count': 10}
328
410
  """
329
411
  response = self._send_server_request(
330
412
  uri="api/v0/remote-cache/delete-all",
331
413
  method="DELETE",
332
414
  )
333
415
  self._resolve_server_response(response)
416
+ response_json = response.json()
417
+ deleted_entry_count = response_json.get("deleted_entry_count", 0)
418
+ if deleted_entry_count > 0:
419
+ self.remote_cache_create_log(
420
+ response,
421
+ description="Clear cache entries",
422
+ cache_entry_count=deleted_entry_count,
423
+ )
424
+ return response.json()
425
+
426
+ def remote_cache_create_log(
427
+ self, response: requests.Response, description: str, cache_entry_count: int
428
+ ) -> Union[dict, None]:
429
+ """
430
+ If a remote cache action has been completed successfully,
431
+ log the action.
432
+ """
433
+ if 200 <= response.status_code < 300:
434
+ log_response = self._send_server_request(
435
+ uri="api/v0/remote-cache-log",
436
+ method="POST",
437
+ payload={
438
+ "description": description,
439
+ "cache_entry_count": cache_entry_count,
440
+ },
441
+ )
442
+ self._resolve_server_response(log_response)
443
+ return response.json()
444
+
445
+ def remote_cache_clear_log(self) -> dict:
446
+ """
447
+ Clear all remote cache log entries.
448
+
449
+ >>> coop.remote_cache_clear_log()
450
+ {'status': 'success'}
451
+ """
452
+ response = self._send_server_request(
453
+ uri="api/v0/remote-cache-log/delete-all",
454
+ method="DELETE",
455
+ )
456
+ self._resolve_server_response(response)
334
457
  return response.json()
335
458
 
336
459
  ################
337
460
  # Remote Inference
338
461
  ################
462
+ def remote_inference_create(
463
+ self,
464
+ job: Jobs,
465
+ description: Optional[str] = None,
466
+ status: RemoteJobStatus = "queued",
467
+ visibility: Optional[VisibilityType] = "unlisted",
468
+ ) -> dict:
469
+ """
470
+ Send a remote inference job to the server.
471
+
472
+ :param job: The EDSL job to send to the server.
473
+ :param optional description: A description for this entry in the remote cache.
474
+ :param status: The status of the job. Should be 'queued', unless you are debugging.
475
+ :param visibility: The visibility of the cache entry.
476
+
477
+ >>> job = Jobs.example()
478
+ >>> coop.remote_inference_create(job=job, description="My job")
479
+ {'uuid': '9f8484ee-b407-40e4-9652-4133a7236c9c', 'description': 'My job', 'status': 'queued', 'visibility': 'unlisted', 'version': '0.1.29.dev4'}
480
+ """
481
+ response = self._send_server_request(
482
+ uri="api/v0/remote-inference",
483
+ method="POST",
484
+ payload={
485
+ "json_string": json.dumps(
486
+ job.to_dict(),
487
+ default=self._json_handle_none,
488
+ ),
489
+ "description": description,
490
+ "status": status,
491
+ "visibility": visibility,
492
+ "version": self._edsl_version,
493
+ },
494
+ )
495
+ self._resolve_server_response(response)
496
+ response_json = response.json()
497
+ return {
498
+ "uuid": response_json.get("jobs_uuid"),
499
+ "description": response_json.get("description"),
500
+ "status": response_json.get("status"),
501
+ "visibility": response_json.get("visibility"),
502
+ "version": self._edsl_version,
503
+ }
504
+
339
505
  def remote_inference_get(self, job_uuid: str) -> dict:
340
506
  """
341
- Get the results of a remote inference job.
507
+ Get the details of a remote inference job.
508
+
509
+ :param job_uuid: The UUID of the EDSL job.
510
+
511
+ >>> coop.remote_inference_get("9f8484ee-b407-40e4-9652-4133a7236c9c")
512
+ {'jobs_uuid': '9f8484ee-b407-40e4-9652-4133a7236c9c', 'results_uuid': 'dd708234-31bf-4fe1-8747-6e232625e026', 'results_url': 'https://www.expectedparrot.com/content/dd708234-31bf-4fe1-8747-6e232625e026', 'status': 'completed', 'reason': None, 'price': 16, 'version': '0.1.29.dev4'}
342
513
  """
343
514
  response = self._send_server_request(
344
515
  uri="api/v0/remote-inference",
@@ -350,13 +521,44 @@ class Coop:
350
521
  return {
351
522
  "jobs_uuid": data.get("jobs_uuid"),
352
523
  "results_uuid": data.get("results_uuid"),
353
- "results_url": "TO BE ADDED",
524
+ "results_url": f"{self.url}/content/{data.get('results_uuid')}",
354
525
  "status": data.get("status"),
355
526
  "reason": data.get("reason"),
356
527
  "price": data.get("price"),
357
528
  "version": data.get("version"),
358
529
  }
359
530
 
531
+ def remote_inference_cost(self, input: Union[Jobs, Survey]) -> int:
532
+ """
533
+ Get the cost of a remote inference job.
534
+
535
+ :param input: The EDSL job to send to the server.
536
+
537
+ >>> job = Jobs.example()
538
+ >>> coop.remote_inference_cost(input=job)
539
+ 16
540
+ """
541
+ if isinstance(input, Jobs):
542
+ job = input
543
+ elif isinstance(input, Survey):
544
+ job = Jobs(survey=input)
545
+ else:
546
+ raise TypeError("Input must be either a Job or a Survey.")
547
+
548
+ response = self._send_server_request(
549
+ uri="api/v0/remote-inference/cost",
550
+ method="POST",
551
+ payload={
552
+ "json_string": json.dumps(
553
+ job.to_dict(),
554
+ default=self._json_handle_none,
555
+ ),
556
+ },
557
+ )
558
+ self._resolve_server_response(response)
559
+ response_json = response.json()
560
+ return response_json.get("cost")
561
+
360
562
  ################
361
563
  # Remote Errors
362
564
  ################
@@ -420,87 +622,115 @@ class Coop:
420
622
  return response_json
421
623
 
422
624
 
423
- if __name__ == "__main__":
424
- from edsl.coop import Coop
425
-
426
- # init
427
- API_KEY = "b"
428
- coop = Coop(api_key=API_KEY)
429
- # basics
430
- coop
431
- coop.edsl_settings
432
-
433
- ##############
434
- # A. Objects
435
- ##############
625
+ def main():
626
+ """
627
+ A simple example for the coop client
628
+ """
436
629
  from uuid import uuid4
437
630
  from edsl import (
438
631
  Agent,
439
632
  AgentList,
440
633
  Cache,
441
- Jobs,
634
+ Notebook,
635
+ QuestionFreeText,
442
636
  QuestionMultipleChoice,
443
637
  Results,
444
638
  Scenario,
445
639
  ScenarioList,
446
640
  Survey,
447
641
  )
642
+ from edsl.coop import Coop
643
+ from edsl.data.CacheEntry import CacheEntry
644
+ from edsl.jobs import Jobs
645
+
646
+ # init & basics
647
+ API_KEY = "b"
648
+ coop = Coop(api_key=API_KEY)
649
+ coop
650
+ coop.edsl_settings
651
+
652
+ ##############
653
+ # A. A simple example
654
+ ##############
655
+ # .. create and manipulate an object through the Coop client
656
+ response = coop.create(QuestionMultipleChoice.example())
657
+ coop.get(uuid=response.get("uuid"))
658
+ coop.get(uuid=response.get("uuid"), expected_object_type="question")
659
+ coop.get(url=response.get("url"))
660
+ coop.create(QuestionMultipleChoice.example())
661
+ coop.get_all("question")
662
+ coop.patch(uuid=response.get("uuid"), visibility="private")
663
+ coop.patch(uuid=response.get("uuid"), description="hey")
664
+ coop.patch(uuid=response.get("uuid"), value=QuestionFreeText.example())
665
+ # coop.patch(uuid=response.get("uuid"), value=Survey.example()) - should throw error
666
+ coop.get(uuid=response.get("uuid"))
667
+ coop.delete(uuid=response.get("uuid"))
668
+
669
+ # .. create and manipulate an object through the class
670
+ response = QuestionMultipleChoice.example().push()
671
+ QuestionMultipleChoice.pull(uuid=response.get("uuid"))
672
+ QuestionMultipleChoice.pull(url=response.get("url"))
673
+ QuestionMultipleChoice.patch(uuid=response.get("uuid"), visibility="private")
674
+ QuestionMultipleChoice.patch(uuid=response.get("uuid"), description="hey")
675
+ QuestionMultipleChoice.patch(
676
+ uuid=response.get("uuid"), value=QuestionFreeText.example()
677
+ )
678
+ QuestionMultipleChoice.pull(response.get("uuid"))
679
+ QuestionMultipleChoice.delete(response.get("uuid"))
448
680
 
681
+ ##############
682
+ # B. Examples with all objects
683
+ ##############
449
684
  OBJECTS = [
450
685
  ("agent", Agent),
451
686
  ("agent_list", AgentList),
452
687
  ("cache", Cache),
453
- ("job", Jobs),
688
+ ("notebook", Notebook),
454
689
  ("question", QuestionMultipleChoice),
455
690
  ("results", Results),
456
691
  ("scenario", Scenario),
457
692
  ("scenario_list", ScenarioList),
458
693
  ("survey", Survey),
459
694
  ]
460
-
461
695
  for object_type, cls in OBJECTS:
462
696
  print(f"Testing {object_type} objects")
463
697
  # 1. Delete existing objects
464
698
  existing_objects = coop.get_all(object_type)
465
699
  for item in existing_objects:
466
- coop.delete(object_type=object_type, uuid=item.get("uuid"))
700
+ coop.delete(uuid=item.get("uuid"))
467
701
  # 2. Create new objects
468
702
  example = cls.example()
469
703
  response_1 = coop.create(example)
470
704
  response_2 = coop.create(cls.example(), visibility="private")
471
705
  response_3 = coop.create(cls.example(), visibility="public")
472
- response_4 = coop.create(cls.example(), visibility="unlisted")
706
+ response_4 = coop.create(
707
+ cls.example(), visibility="unlisted", description="hey"
708
+ )
473
709
  # 3. Retrieve all objects
474
710
  objects = coop.get_all(object_type)
475
711
  assert len(objects) == 4
476
712
  # 4. Try to retrieve an item that does not exist
477
713
  try:
478
- coop.get(object_type=object_type, uuid=uuid4())
714
+ coop.get(uuid=uuid4())
479
715
  except Exception as e:
480
716
  print(e)
481
717
  # 5. Try to retrieve all test objects by their uuids
482
718
  for response in [response_1, response_2, response_3, response_4]:
483
- coop.get(object_type=object_type, uuid=response.get("uuid"))
719
+ coop.get(uuid=response.get("uuid"))
484
720
  # 6. Change visibility of all objects
485
721
  for item in objects:
486
- coop.patch(
487
- object_type=object_type, uuid=item.get("uuid"), visibility="private"
488
- )
722
+ coop.patch(uuid=item.get("uuid"), visibility="private")
723
+ # 6. Change description of all objects
724
+ for item in objects:
725
+ coop.patch(uuid=item.get("uuid"), description="hey")
489
726
  # 7. Delete all objects
490
727
  for item in objects:
491
- coop.delete(object_type=object_type, uuid=item.get("uuid"))
728
+ coop.delete(uuid=item.get("uuid"))
492
729
  assert len(coop.get_all(object_type)) == 0
493
730
 
494
- response = QuestionMultipleChoice.example().push()
495
- QuestionMultipleChoice.pull(response.get("uuid"))
496
- coop.patch(object_type="question", uuid=response.get("uuid"), visibility="public")
497
- coop.delete(object_type="question", uuid=response.get("uuid"))
498
-
499
731
  ##############
500
- # B. Remote Cache
732
+ # C. Remote Cache
501
733
  ##############
502
- from edsl.data.CacheEntry import CacheEntry
503
-
504
734
  # clear
505
735
  coop.remote_cache_clear()
506
736
  assert coop.remote_cache_get() == []
@@ -522,30 +752,16 @@ if __name__ == "__main__":
522
752
  coop.remote_cache_get()
523
753
 
524
754
  ##############
525
- # C. Remote Inference
755
+ # D. Remote Inference
526
756
  ##############
527
- from edsl.jobs import Jobs
528
-
529
- # check jobs on server (should be an empty list)
530
- coop.get_all("job")
531
- for job in coop.get_all("job"):
532
- coop.delete(object_type="job", uuid=job.get("uuid"))
533
- # post a job
534
- response = coop.create(Jobs.example())
535
- # get job and results
536
- coop.remote_inference_get(response.get("uuid"))
537
- coop.get(
538
- object_type="results",
539
- uuid=coop.remote_inference_get(response.get("uuid")).get("results_uuid"),
540
- )
757
+ job = Jobs.example()
758
+ coop.remote_inference_cost(job)
759
+ results = coop.remote_inference_create(job)
760
+ coop.remote_inference_get(results.get("uuid"))
541
761
 
542
762
  ##############
543
- # D. Errors
763
+ # E. Errors
544
764
  ##############
545
- from edsl import Coop
546
-
547
- coop = Coop()
548
- coop.api_key = "a"
549
765
  coop.error_create({"something": "This is an error message"})
550
766
  coop.api_key = None
551
767
  coop.error_create({"something": "This is an error message"})