edsl 0.1.29.dev2__py3-none-any.whl → 0.1.29.dev6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +12 -15
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +37 -2
- edsl/agents/AgentList.py +3 -4
- edsl/agents/InvigilatorBase.py +15 -10
- edsl/agents/PromptConstructionMixin.py +342 -100
- edsl/conjure/InputData.py +39 -8
- edsl/coop/coop.py +187 -150
- edsl/coop/utils.py +17 -76
- edsl/jobs/Jobs.py +23 -17
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +1 -0
- edsl/notebooks/Notebook.py +31 -0
- edsl/prompts/Prompt.py +31 -19
- edsl/questions/QuestionBase.py +32 -11
- edsl/questions/question_registry.py +20 -31
- edsl/questions/settings.py +1 -1
- edsl/results/Dataset.py +31 -0
- edsl/results/Results.py +6 -8
- edsl/results/ResultsToolsMixin.py +4 -1
- edsl/scenarios/ScenarioList.py +17 -3
- edsl/study/Study.py +3 -9
- edsl/surveys/Survey.py +37 -3
- edsl/tools/plotting.py +4 -2
- {edsl-0.1.29.dev2.dist-info → edsl-0.1.29.dev6.dist-info}/METADATA +11 -10
- {edsl-0.1.29.dev2.dist-info → edsl-0.1.29.dev6.dist-info}/RECORD +27 -28
- edsl-0.1.29.dev2.dist-info/entry_points.txt +0 -3
- {edsl-0.1.29.dev2.dist-info → edsl-0.1.29.dev6.dist-info}/LICENSE +0 -0
- {edsl-0.1.29.dev2.dist-info → edsl-0.1.29.dev6.dist-info}/WHEEL +0 -0
edsl/coop/coop.py
CHANGED
@@ -5,8 +5,14 @@ import requests
|
|
5
5
|
from typing import Any, Optional, Union, Literal
|
6
6
|
from uuid import UUID
|
7
7
|
import edsl
|
8
|
-
from edsl import CONFIG, CacheEntry
|
9
|
-
from edsl.coop.utils import
|
8
|
+
from edsl import CONFIG, CacheEntry, Jobs, Survey
|
9
|
+
from edsl.coop.utils import (
|
10
|
+
EDSLObject,
|
11
|
+
ObjectRegistry,
|
12
|
+
ObjectType,
|
13
|
+
RemoteJobStatus,
|
14
|
+
VisibilityType,
|
15
|
+
)
|
10
16
|
|
11
17
|
|
12
18
|
class Coop:
|
@@ -88,6 +94,18 @@ class Coop:
|
|
88
94
|
if value is None:
|
89
95
|
return "null"
|
90
96
|
|
97
|
+
def _resolve_uuid(
|
98
|
+
self, uuid: Union[str, UUID] = None, url: str = None
|
99
|
+
) -> Union[str, UUID]:
|
100
|
+
"""
|
101
|
+
Resolve the uuid from a uuid or a url.
|
102
|
+
"""
|
103
|
+
if not url and not uuid:
|
104
|
+
raise Exception("No uuid or url provided for the object.")
|
105
|
+
if not uuid and url:
|
106
|
+
uuid = url.split("/")[-1]
|
107
|
+
return uuid
|
108
|
+
|
91
109
|
@property
|
92
110
|
def edsl_settings(self) -> dict:
|
93
111
|
"""
|
@@ -100,9 +118,6 @@ class Coop:
|
|
100
118
|
################
|
101
119
|
# Objects
|
102
120
|
################
|
103
|
-
|
104
|
-
# TODO: add URL to get and get_all methods
|
105
|
-
|
106
121
|
def create(
|
107
122
|
self,
|
108
123
|
object: EDSLObject,
|
@@ -113,17 +128,16 @@ class Coop:
|
|
113
128
|
Create an EDSL object in the Coop server.
|
114
129
|
"""
|
115
130
|
object_type = ObjectRegistry.get_object_type_by_edsl_class(object)
|
116
|
-
object_page = ObjectRegistry.get_object_page_by_object_type(object_type)
|
117
131
|
response = self._send_server_request(
|
118
132
|
uri=f"api/v0/object",
|
119
133
|
method="POST",
|
120
134
|
payload={
|
121
135
|
"description": description,
|
122
|
-
"object_type": object_type,
|
123
136
|
"json_string": json.dumps(
|
124
137
|
object.to_dict(),
|
125
138
|
default=self._json_handle_none,
|
126
139
|
),
|
140
|
+
"object_type": object_type,
|
127
141
|
"visibility": visibility,
|
128
142
|
"version": self._edsl_version,
|
129
143
|
},
|
@@ -131,73 +145,51 @@ class Coop:
|
|
131
145
|
self._resolve_server_response(response)
|
132
146
|
response_json = response.json()
|
133
147
|
return {
|
148
|
+
"description": response_json.get("description"),
|
149
|
+
"object_type": object_type,
|
150
|
+
"url": f"{self.url}/content/{response_json.get('uuid')}",
|
134
151
|
"uuid": response_json.get("uuid"),
|
135
152
|
"version": self._edsl_version,
|
136
|
-
"description": response_json.get("description"),
|
137
153
|
"visibility": response_json.get("visibility"),
|
138
|
-
"url": f"{self.url}/content/{response_json.get('uuid')}",
|
139
154
|
}
|
140
155
|
|
141
156
|
def get(
|
142
157
|
self,
|
143
|
-
object_type: ObjectType = None,
|
144
158
|
uuid: Union[str, UUID] = None,
|
145
159
|
url: str = None,
|
146
|
-
|
160
|
+
expected_object_type: Optional[ObjectType] = None,
|
147
161
|
) -> EDSLObject:
|
148
162
|
"""
|
149
|
-
Retrieve an EDSL object
|
150
|
-
-
|
151
|
-
-
|
163
|
+
Retrieve an EDSL object by its uuid or its url.
|
164
|
+
- If the object's visibility is private, the user must be the owner.
|
165
|
+
- Optionally, check if the retrieved object is of a certain type.
|
152
166
|
|
153
|
-
:param object_type: the type of object to retrieve.
|
154
167
|
:param uuid: the uuid of the object either in str or UUID format.
|
155
168
|
:param url: the url of the object.
|
156
|
-
|
157
|
-
if url:
|
158
|
-
object_type = url.split("/")[-2]
|
159
|
-
uuid = url.split("/")[-1]
|
160
|
-
elif not object_type and not uuid:
|
161
|
-
raise Exception("Provide either object_type & UUID, or a url.")
|
162
|
-
edsl_class = ObjectRegistry.object_type_to_edsl_class.get(object_type)
|
163
|
-
import time
|
169
|
+
:param expected_object_type: the expected type of the object.
|
164
170
|
|
165
|
-
|
171
|
+
:return: the object instance.
|
172
|
+
"""
|
173
|
+
uuid = self._resolve_uuid(uuid, url)
|
166
174
|
response = self._send_server_request(
|
167
175
|
uri=f"api/v0/object",
|
168
176
|
method="GET",
|
169
|
-
params={"
|
177
|
+
params={"uuid": uuid},
|
170
178
|
)
|
171
|
-
end = time.time()
|
172
|
-
if exec_profile:
|
173
|
-
print("Download exec time = ", end - start)
|
174
179
|
self._resolve_server_response(response)
|
175
180
|
json_string = response.json().get("json_string")
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
return
|
182
|
-
|
183
|
-
def _get_base(
|
184
|
-
self,
|
185
|
-
cls: EDSLObject,
|
186
|
-
uuid: Union[str, UUID],
|
187
|
-
exec_profile=None,
|
188
|
-
) -> EDSLObject:
|
189
|
-
"""
|
190
|
-
Used by the Base class to offer a get functionality.
|
191
|
-
"""
|
192
|
-
object_type = ObjectRegistry.get_object_type_by_edsl_class(cls)
|
193
|
-
return self.get(object_type, uuid, exec_profile=exec_profile)
|
181
|
+
object_type = response.json().get("object_type")
|
182
|
+
if expected_object_type and object_type != expected_object_type:
|
183
|
+
raise Exception(f"Expected {expected_object_type=} but got {object_type=}")
|
184
|
+
edsl_class = ObjectRegistry.object_type_to_edsl_class.get(object_type)
|
185
|
+
object = edsl_class.from_dict(json.loads(json_string))
|
186
|
+
return object
|
194
187
|
|
195
|
-
def get_all(self, object_type: ObjectType) -> list[
|
188
|
+
def get_all(self, object_type: ObjectType) -> list[dict[str, Any]]:
|
196
189
|
"""
|
197
190
|
Retrieve all objects of a certain type associated with the user.
|
198
191
|
"""
|
199
192
|
edsl_class = ObjectRegistry.object_type_to_edsl_class.get(object_type)
|
200
|
-
object_page = ObjectRegistry.get_object_page_by_object_type(object_type)
|
201
193
|
response = self._send_server_request(
|
202
194
|
uri=f"api/v0/objects",
|
203
195
|
method="GET",
|
@@ -217,33 +209,23 @@ class Coop:
|
|
217
209
|
]
|
218
210
|
return objects
|
219
211
|
|
220
|
-
def delete(self,
|
212
|
+
def delete(self, uuid: Union[str, UUID] = None, url: str = None) -> dict:
|
221
213
|
"""
|
222
214
|
Delete an object from the server.
|
223
215
|
"""
|
216
|
+
uuid = self._resolve_uuid(uuid, url)
|
224
217
|
response = self._send_server_request(
|
225
218
|
uri=f"api/v0/object",
|
226
219
|
method="DELETE",
|
227
|
-
params={"
|
220
|
+
params={"uuid": uuid},
|
228
221
|
)
|
229
222
|
self._resolve_server_response(response)
|
230
223
|
return response.json()
|
231
224
|
|
232
|
-
def _delete_base(
|
233
|
-
self,
|
234
|
-
cls: EDSLObject,
|
235
|
-
uuid: Union[str, UUID],
|
236
|
-
) -> dict:
|
237
|
-
"""
|
238
|
-
Used by the Base class to offer a delete functionality.
|
239
|
-
"""
|
240
|
-
object_type = ObjectRegistry.get_object_type_by_edsl_class(cls)
|
241
|
-
return self.delete(object_type, uuid)
|
242
|
-
|
243
225
|
def patch(
|
244
226
|
self,
|
245
|
-
|
246
|
-
|
227
|
+
uuid: Union[str, UUID] = None,
|
228
|
+
url: str = None,
|
247
229
|
description: Optional[str] = None,
|
248
230
|
value: Optional[EDSLObject] = None,
|
249
231
|
visibility: Optional[VisibilityType] = None,
|
@@ -254,14 +236,11 @@ class Coop:
|
|
254
236
|
"""
|
255
237
|
if description is None and visibility is None and value is None:
|
256
238
|
raise Exception("Nothing to patch.")
|
257
|
-
|
258
|
-
value_type = ObjectRegistry.get_object_type_by_edsl_class(value)
|
259
|
-
if value_type != object_type:
|
260
|
-
raise Exception(f"Object type mismatch: {object_type=} {value_type=}")
|
239
|
+
uuid = self._resolve_uuid(uuid, url)
|
261
240
|
response = self._send_server_request(
|
262
241
|
uri=f"api/v0/object",
|
263
242
|
method="PATCH",
|
264
|
-
params={"
|
243
|
+
params={"uuid": uuid},
|
265
244
|
payload={
|
266
245
|
"description": description,
|
267
246
|
"json_string": (
|
@@ -278,20 +257,6 @@ class Coop:
|
|
278
257
|
self._resolve_server_response(response)
|
279
258
|
return response.json()
|
280
259
|
|
281
|
-
def _patch_base(
|
282
|
-
self,
|
283
|
-
cls: EDSLObject,
|
284
|
-
uuid: Union[str, UUID],
|
285
|
-
description: Optional[str] = None,
|
286
|
-
value: Optional[EDSLObject] = None,
|
287
|
-
visibility: Optional[VisibilityType] = None,
|
288
|
-
) -> dict:
|
289
|
-
"""
|
290
|
-
Used by the Base class to offer a patch functionality.
|
291
|
-
"""
|
292
|
-
object_type = ObjectRegistry.get_object_type_by_edsl_class(cls)
|
293
|
-
return self.patch(object_type, uuid, description, value, visibility)
|
294
|
-
|
295
260
|
################
|
296
261
|
# Remote Cache
|
297
262
|
################
|
@@ -494,9 +459,57 @@ class Coop:
|
|
494
459
|
################
|
495
460
|
# Remote Inference
|
496
461
|
################
|
462
|
+
def remote_inference_create(
|
463
|
+
self,
|
464
|
+
job: Jobs,
|
465
|
+
description: Optional[str] = None,
|
466
|
+
status: RemoteJobStatus = "queued",
|
467
|
+
visibility: Optional[VisibilityType] = "unlisted",
|
468
|
+
) -> dict:
|
469
|
+
"""
|
470
|
+
Send a remote inference job to the server.
|
471
|
+
|
472
|
+
:param job: The EDSL job to send to the server.
|
473
|
+
:param optional description: A description for this entry in the remote cache.
|
474
|
+
:param status: The status of the job. Should be 'queued', unless you are debugging.
|
475
|
+
:param visibility: The visibility of the cache entry.
|
476
|
+
|
477
|
+
>>> job = Jobs.example()
|
478
|
+
>>> coop.remote_inference_create(job=job, description="My job")
|
479
|
+
{'uuid': '9f8484ee-b407-40e4-9652-4133a7236c9c', 'description': 'My job', 'status': 'queued', 'visibility': 'unlisted', 'version': '0.1.29.dev4'}
|
480
|
+
"""
|
481
|
+
response = self._send_server_request(
|
482
|
+
uri="api/v0/remote-inference",
|
483
|
+
method="POST",
|
484
|
+
payload={
|
485
|
+
"json_string": json.dumps(
|
486
|
+
job.to_dict(),
|
487
|
+
default=self._json_handle_none,
|
488
|
+
),
|
489
|
+
"description": description,
|
490
|
+
"status": status,
|
491
|
+
"visibility": visibility,
|
492
|
+
"version": self._edsl_version,
|
493
|
+
},
|
494
|
+
)
|
495
|
+
self._resolve_server_response(response)
|
496
|
+
response_json = response.json()
|
497
|
+
return {
|
498
|
+
"uuid": response_json.get("jobs_uuid"),
|
499
|
+
"description": response_json.get("description"),
|
500
|
+
"status": response_json.get("status"),
|
501
|
+
"visibility": response_json.get("visibility"),
|
502
|
+
"version": self._edsl_version,
|
503
|
+
}
|
504
|
+
|
497
505
|
def remote_inference_get(self, job_uuid: str) -> dict:
|
498
506
|
"""
|
499
|
-
Get the
|
507
|
+
Get the details of a remote inference job.
|
508
|
+
|
509
|
+
:param job_uuid: The UUID of the EDSL job.
|
510
|
+
|
511
|
+
>>> coop.remote_inference_get("9f8484ee-b407-40e4-9652-4133a7236c9c")
|
512
|
+
{'jobs_uuid': '9f8484ee-b407-40e4-9652-4133a7236c9c', 'results_uuid': 'dd708234-31bf-4fe1-8747-6e232625e026', 'results_url': 'https://www.expectedparrot.com/content/dd708234-31bf-4fe1-8747-6e232625e026', 'status': 'completed', 'reason': None, 'price': 16, 'version': '0.1.29.dev4'}
|
500
513
|
"""
|
501
514
|
response = self._send_server_request(
|
502
515
|
uri="api/v0/remote-inference",
|
@@ -508,13 +521,44 @@ class Coop:
|
|
508
521
|
return {
|
509
522
|
"jobs_uuid": data.get("jobs_uuid"),
|
510
523
|
"results_uuid": data.get("results_uuid"),
|
511
|
-
"results_url": "
|
524
|
+
"results_url": f"{self.url}/content/{data.get('results_uuid')}",
|
512
525
|
"status": data.get("status"),
|
513
526
|
"reason": data.get("reason"),
|
514
527
|
"price": data.get("price"),
|
515
528
|
"version": data.get("version"),
|
516
529
|
}
|
517
530
|
|
531
|
+
def remote_inference_cost(self, input: Union[Jobs, Survey]) -> int:
|
532
|
+
"""
|
533
|
+
Get the cost of a remote inference job.
|
534
|
+
|
535
|
+
:param input: The EDSL job to send to the server.
|
536
|
+
|
537
|
+
>>> job = Jobs.example()
|
538
|
+
>>> coop.remote_inference_cost(input=job)
|
539
|
+
16
|
540
|
+
"""
|
541
|
+
if isinstance(input, Jobs):
|
542
|
+
job = input
|
543
|
+
elif isinstance(input, Survey):
|
544
|
+
job = Jobs(survey=input)
|
545
|
+
else:
|
546
|
+
raise TypeError("Input must be either a Job or a Survey.")
|
547
|
+
|
548
|
+
response = self._send_server_request(
|
549
|
+
uri="api/v0/remote-inference/cost",
|
550
|
+
method="POST",
|
551
|
+
payload={
|
552
|
+
"json_string": json.dumps(
|
553
|
+
job.to_dict(),
|
554
|
+
default=self._json_handle_none,
|
555
|
+
),
|
556
|
+
},
|
557
|
+
)
|
558
|
+
self._resolve_server_response(response)
|
559
|
+
response_json = response.json()
|
560
|
+
return response_json.get("cost")
|
561
|
+
|
518
562
|
################
|
519
563
|
# Remote Errors
|
520
564
|
################
|
@@ -578,32 +622,65 @@ class Coop:
|
|
578
622
|
return response_json
|
579
623
|
|
580
624
|
|
581
|
-
|
582
|
-
|
583
|
-
|
584
|
-
|
585
|
-
API_KEY = "b"
|
586
|
-
coop = Coop(api_key=API_KEY)
|
587
|
-
# basics
|
588
|
-
coop
|
589
|
-
coop.edsl_settings
|
590
|
-
|
591
|
-
##############
|
592
|
-
# A. Objects
|
593
|
-
##############
|
625
|
+
def main():
|
626
|
+
"""
|
627
|
+
A simple example for the coop client
|
628
|
+
"""
|
594
629
|
from uuid import uuid4
|
595
630
|
from edsl import (
|
596
631
|
Agent,
|
597
632
|
AgentList,
|
598
633
|
Cache,
|
599
634
|
Notebook,
|
635
|
+
QuestionFreeText,
|
600
636
|
QuestionMultipleChoice,
|
601
637
|
Results,
|
602
638
|
Scenario,
|
603
639
|
ScenarioList,
|
604
640
|
Survey,
|
605
641
|
)
|
642
|
+
from edsl.coop import Coop
|
643
|
+
from edsl.data.CacheEntry import CacheEntry
|
644
|
+
from edsl.jobs import Jobs
|
645
|
+
|
646
|
+
# init & basics
|
647
|
+
API_KEY = "b"
|
648
|
+
coop = Coop(api_key=API_KEY)
|
649
|
+
coop
|
650
|
+
coop.edsl_settings
|
606
651
|
|
652
|
+
##############
|
653
|
+
# A. A simple example
|
654
|
+
##############
|
655
|
+
# .. create and manipulate an object through the Coop client
|
656
|
+
response = coop.create(QuestionMultipleChoice.example())
|
657
|
+
coop.get(uuid=response.get("uuid"))
|
658
|
+
coop.get(uuid=response.get("uuid"), expected_object_type="question")
|
659
|
+
coop.get(url=response.get("url"))
|
660
|
+
coop.create(QuestionMultipleChoice.example())
|
661
|
+
coop.get_all("question")
|
662
|
+
coop.patch(uuid=response.get("uuid"), visibility="private")
|
663
|
+
coop.patch(uuid=response.get("uuid"), description="hey")
|
664
|
+
coop.patch(uuid=response.get("uuid"), value=QuestionFreeText.example())
|
665
|
+
# coop.patch(uuid=response.get("uuid"), value=Survey.example()) - should throw error
|
666
|
+
coop.get(uuid=response.get("uuid"))
|
667
|
+
coop.delete(uuid=response.get("uuid"))
|
668
|
+
|
669
|
+
# .. create and manipulate an object through the class
|
670
|
+
response = QuestionMultipleChoice.example().push()
|
671
|
+
QuestionMultipleChoice.pull(uuid=response.get("uuid"))
|
672
|
+
QuestionMultipleChoice.pull(url=response.get("url"))
|
673
|
+
QuestionMultipleChoice.patch(uuid=response.get("uuid"), visibility="private")
|
674
|
+
QuestionMultipleChoice.patch(uuid=response.get("uuid"), description="hey")
|
675
|
+
QuestionMultipleChoice.patch(
|
676
|
+
uuid=response.get("uuid"), value=QuestionFreeText.example()
|
677
|
+
)
|
678
|
+
QuestionMultipleChoice.pull(response.get("uuid"))
|
679
|
+
QuestionMultipleChoice.delete(response.get("uuid"))
|
680
|
+
|
681
|
+
##############
|
682
|
+
# B. Examples with all objects
|
683
|
+
##############
|
607
684
|
OBJECTS = [
|
608
685
|
("agent", Agent),
|
609
686
|
("agent_list", AgentList),
|
@@ -615,13 +692,12 @@ if __name__ == "__main__":
|
|
615
692
|
("scenario_list", ScenarioList),
|
616
693
|
("survey", Survey),
|
617
694
|
]
|
618
|
-
|
619
695
|
for object_type, cls in OBJECTS:
|
620
696
|
print(f"Testing {object_type} objects")
|
621
697
|
# 1. Delete existing objects
|
622
698
|
existing_objects = coop.get_all(object_type)
|
623
699
|
for item in existing_objects:
|
624
|
-
coop.delete(
|
700
|
+
coop.delete(uuid=item.get("uuid"))
|
625
701
|
# 2. Create new objects
|
626
702
|
example = cls.example()
|
627
703
|
response_1 = coop.create(example)
|
@@ -635,51 +711,26 @@ if __name__ == "__main__":
|
|
635
711
|
assert len(objects) == 4
|
636
712
|
# 4. Try to retrieve an item that does not exist
|
637
713
|
try:
|
638
|
-
coop.get(
|
714
|
+
coop.get(uuid=uuid4())
|
639
715
|
except Exception as e:
|
640
716
|
print(e)
|
641
717
|
# 5. Try to retrieve all test objects by their uuids
|
642
718
|
for response in [response_1, response_2, response_3, response_4]:
|
643
|
-
coop.get(
|
719
|
+
coop.get(uuid=response.get("uuid"))
|
644
720
|
# 6. Change visibility of all objects
|
645
721
|
for item in objects:
|
646
|
-
coop.patch(
|
647
|
-
object_type=object_type, uuid=item.get("uuid"), visibility="private"
|
648
|
-
)
|
722
|
+
coop.patch(uuid=item.get("uuid"), visibility="private")
|
649
723
|
# 6. Change description of all objects
|
650
724
|
for item in objects:
|
651
|
-
coop.patch(
|
652
|
-
object_type=object_type, uuid=item.get("uuid"), description="hey"
|
653
|
-
)
|
725
|
+
coop.patch(uuid=item.get("uuid"), description="hey")
|
654
726
|
# 7. Delete all objects
|
655
727
|
for item in objects:
|
656
|
-
coop.delete(
|
728
|
+
coop.delete(uuid=item.get("uuid"))
|
657
729
|
assert len(coop.get_all(object_type)) == 0
|
658
730
|
|
659
|
-
# a simple example
|
660
|
-
from edsl import Coop, QuestionMultipleChoice, QuestionFreeText
|
661
|
-
|
662
|
-
coop = Coop(api_key="b")
|
663
|
-
response = QuestionMultipleChoice.example().push()
|
664
|
-
QuestionMultipleChoice.pull(response.get("uuid"))
|
665
|
-
coop.patch(object_type="question", uuid=response.get("uuid"), visibility="public")
|
666
|
-
coop.patch(
|
667
|
-
object_type="question",
|
668
|
-
uuid=response.get("uuid"),
|
669
|
-
description="crazy new description",
|
670
|
-
)
|
671
|
-
coop.patch(
|
672
|
-
object_type="question",
|
673
|
-
uuid=response.get("uuid"),
|
674
|
-
value=QuestionFreeText.example(),
|
675
|
-
)
|
676
|
-
coop.delete(object_type="question", uuid=response.get("uuid"))
|
677
|
-
|
678
731
|
##############
|
679
|
-
#
|
732
|
+
# C. Remote Cache
|
680
733
|
##############
|
681
|
-
from edsl.data.CacheEntry import CacheEntry
|
682
|
-
|
683
734
|
# clear
|
684
735
|
coop.remote_cache_clear()
|
685
736
|
assert coop.remote_cache_get() == []
|
@@ -701,30 +752,16 @@ if __name__ == "__main__":
|
|
701
752
|
coop.remote_cache_get()
|
702
753
|
|
703
754
|
##############
|
704
|
-
#
|
755
|
+
# D. Remote Inference
|
705
756
|
##############
|
706
|
-
|
707
|
-
|
708
|
-
|
709
|
-
coop.
|
710
|
-
for job in coop.get_all("job"):
|
711
|
-
coop.delete(object_type="job", uuid=job.get("uuid"))
|
712
|
-
# post a job
|
713
|
-
response = coop.create(Jobs.example())
|
714
|
-
# get job and results
|
715
|
-
coop.remote_inference_get(response.get("uuid"))
|
716
|
-
coop.get(
|
717
|
-
object_type="results",
|
718
|
-
uuid=coop.remote_inference_get(response.get("uuid")).get("results_uuid"),
|
719
|
-
)
|
757
|
+
job = Jobs.example()
|
758
|
+
coop.remote_inference_cost(job)
|
759
|
+
results = coop.remote_inference_create(job)
|
760
|
+
coop.remote_inference_get(results.get("uuid"))
|
720
761
|
|
721
762
|
##############
|
722
|
-
#
|
763
|
+
# E. Errors
|
723
764
|
##############
|
724
|
-
from edsl import Coop
|
725
|
-
|
726
|
-
coop = Coop()
|
727
|
-
coop.api_key = "a"
|
728
765
|
coop.error_create({"something": "This is an error message"})
|
729
766
|
coop.api_key = None
|
730
767
|
coop.error_create({"something": "This is an error message"})
|
edsl/coop/utils.py
CHANGED
@@ -2,7 +2,6 @@ from edsl import (
|
|
2
2
|
Agent,
|
3
3
|
AgentList,
|
4
4
|
Cache,
|
5
|
-
Jobs,
|
6
5
|
Notebook,
|
7
6
|
Results,
|
8
7
|
Scenario,
|
@@ -17,7 +16,6 @@ EDSLObject = Union[
|
|
17
16
|
Agent,
|
18
17
|
AgentList,
|
19
18
|
Cache,
|
20
|
-
Jobs,
|
21
19
|
Notebook,
|
22
20
|
Type[QuestionBase],
|
23
21
|
Results,
|
@@ -31,9 +29,8 @@ ObjectType = Literal[
|
|
31
29
|
"agent",
|
32
30
|
"agent_list",
|
33
31
|
"cache",
|
34
|
-
"job",
|
35
|
-
"question",
|
36
32
|
"notebook",
|
33
|
+
"question",
|
37
34
|
"results",
|
38
35
|
"scenario",
|
39
36
|
"scenario_list",
|
@@ -41,18 +38,12 @@ ObjectType = Literal[
|
|
41
38
|
"study",
|
42
39
|
]
|
43
40
|
|
44
|
-
|
45
|
-
|
46
|
-
"
|
47
|
-
"
|
48
|
-
"
|
49
|
-
"
|
50
|
-
"questions",
|
51
|
-
"results",
|
52
|
-
"scenarios",
|
53
|
-
"scenariolists",
|
54
|
-
"surveys",
|
55
|
-
"studies",
|
41
|
+
|
42
|
+
RemoteJobStatus = Literal[
|
43
|
+
"queued",
|
44
|
+
"running",
|
45
|
+
"completed",
|
46
|
+
"failed",
|
56
47
|
]
|
57
48
|
|
58
49
|
VisibilityType = Literal[
|
@@ -68,67 +59,21 @@ class ObjectRegistry:
|
|
68
59
|
"""
|
69
60
|
|
70
61
|
objects = [
|
71
|
-
{
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
},
|
76
|
-
{
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
},
|
81
|
-
{
|
82
|
-
"object_type": "cache",
|
83
|
-
"edsl_class": Cache,
|
84
|
-
"object_page": "caches",
|
85
|
-
},
|
86
|
-
{
|
87
|
-
"object_type": "job",
|
88
|
-
"edsl_class": Jobs,
|
89
|
-
"object_page": "jobs",
|
90
|
-
},
|
91
|
-
{
|
92
|
-
"object_type": "question",
|
93
|
-
"edsl_class": QuestionBase,
|
94
|
-
"object_page": "questions",
|
95
|
-
},
|
96
|
-
{
|
97
|
-
"object_type": "notebook",
|
98
|
-
"edsl_class": Notebook,
|
99
|
-
"object_page": "notebooks",
|
100
|
-
},
|
101
|
-
{
|
102
|
-
"object_type": "results",
|
103
|
-
"edsl_class": Results,
|
104
|
-
"object_page": "results",
|
105
|
-
},
|
106
|
-
{
|
107
|
-
"object_type": "scenario",
|
108
|
-
"edsl_class": Scenario,
|
109
|
-
"object_page": "scenarios",
|
110
|
-
},
|
111
|
-
{
|
112
|
-
"object_type": "scenario_list",
|
113
|
-
"edsl_class": ScenarioList,
|
114
|
-
"object_page": "scenariolists",
|
115
|
-
},
|
116
|
-
{
|
117
|
-
"object_type": "survey",
|
118
|
-
"edsl_class": Survey,
|
119
|
-
"object_page": "surveys",
|
120
|
-
},
|
121
|
-
{
|
122
|
-
"object_type": "study",
|
123
|
-
"edsl_class": Study,
|
124
|
-
"object_page": "studies",
|
125
|
-
},
|
62
|
+
{"object_type": "agent", "edsl_class": Agent},
|
63
|
+
{"object_type": "agent_list", "edsl_class": AgentList},
|
64
|
+
{"object_type": "cache", "edsl_class": Cache},
|
65
|
+
{"object_type": "question", "edsl_class": QuestionBase},
|
66
|
+
{"object_type": "notebook", "edsl_class": Notebook},
|
67
|
+
{"object_type": "results", "edsl_class": Results},
|
68
|
+
{"object_type": "scenario", "edsl_class": Scenario},
|
69
|
+
{"object_type": "scenario_list", "edsl_class": ScenarioList},
|
70
|
+
{"object_type": "survey", "edsl_class": Survey},
|
71
|
+
{"object_type": "study", "edsl_class": Study},
|
126
72
|
]
|
127
73
|
object_type_to_edsl_class = {o["object_type"]: o["edsl_class"] for o in objects}
|
128
74
|
edsl_class_to_object_type = {
|
129
75
|
o["edsl_class"].__name__: o["object_type"] for o in objects
|
130
76
|
}
|
131
|
-
object_type_to_object_page = {o["object_type"]: o["object_page"] for o in objects}
|
132
77
|
|
133
78
|
@classmethod
|
134
79
|
def get_object_type_by_edsl_class(cls, edsl_object: EDSLObject) -> ObjectType:
|
@@ -149,7 +94,3 @@ class ObjectRegistry:
|
|
149
94
|
if EDSL_object is None:
|
150
95
|
raise ValueError(f"EDSL class not found for {object_type=}")
|
151
96
|
return EDSL_object
|
152
|
-
|
153
|
-
@classmethod
|
154
|
-
def get_object_page_by_object_type(cls, object_type: ObjectType) -> ObjectPage:
|
155
|
-
return cls.object_type_to_object_page.get(object_type)
|