edsl 0.1.57__py3-none-any.whl → 0.1.59__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
edsl/__version__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.1.57"
1
+ __version__ = "0.1.59"
edsl/agents/agent.py CHANGED
@@ -426,6 +426,25 @@ class Agent(Base):
426
426
  self.traits_presentation_template = "Your traits: {{traits}}"
427
427
  self.set_traits_presentation_template = False
428
428
 
429
+
430
def drop(self, field_name: str) -> Agent:
    """Return a new agent with one trait or serialized field removed.

    The agent is serialized with ``to_dict()`` (which deep-copies the traits),
    ``field_name`` is removed from that dictionary, and a fresh agent is built
    with ``Agent.from_dict``. The agent this is called on is left unchanged.

    Args:
        field_name: The name of the trait or top-level serialized field to drop.
            If a trait and a top-level field share the name, the trait is dropped.

    Returns:
        A new ``Agent`` without the given field.

    Raises:
        AgentErrors: If ``field_name`` is neither a trait nor a serialized field.
    """
    d = self.to_dict()
    if field_name in d["traits"]:
        # Traits are checked first, so they take precedence over top-level keys.
        d["traits"].pop(field_name)
    elif field_name in d:
        d.pop(field_name)
    else:
        # Explicit separators and plain lists keep the message readable.
        raise AgentErrors(
            f"Field '{field_name}' not found in agent. "
            f"Available fields: {list(d.keys())}. "
            f"Available traits: {list(d['traits'].keys())}"
        )
    return Agent.from_dict(d)
447
+
429
448
  def duplicate(self) -> Agent:
430
449
  """Create a deep copy of this agent with all its traits and capabilities.
431
450
 
@@ -1213,7 +1232,7 @@ class Agent(Base):
1213
1232
  """
1214
1233
  return dict_hash(self.to_dict(add_edsl_version=False))
1215
1234
 
1216
- def to_dict(self, add_edsl_version=True) -> dict[str, Union[dict, bool]]:
1235
+ def to_dict(self, add_edsl_version=True, full_dict=False) -> dict[str, Union[dict, bool]]:
1217
1236
  """Serialize to a dictionary with EDSL info.
1218
1237
 
1219
1238
  Example usage:
@@ -1230,11 +1249,11 @@ class Agent(Base):
1230
1249
  d["traits"] = copy.deepcopy(dict(self._traits))
1231
1250
  if self.name:
1232
1251
  d["name"] = self.name
1233
- if self.set_instructions:
1252
+ if self.set_instructions or full_dict:
1234
1253
  d["instruction"] = self.instruction
1235
- if self.set_traits_presentation_template:
1254
+ if self.set_traits_presentation_template or full_dict:
1236
1255
  d["traits_presentation_template"] = self.traits_presentation_template
1237
- if self.codebook:
1256
+ if self.codebook or full_dict:
1238
1257
  d["codebook"] = self.codebook
1239
1258
  if add_edsl_version:
1240
1259
  from edsl import __version__
edsl/agents/agent_list.py CHANGED
@@ -47,13 +47,13 @@ class AgentList(UserList, Base, AgentListOperationsMixin):
47
47
  with methods for filtering, transforming, and analyzing collections of agents.
48
48
 
49
49
 
50
- >>> AgentList.example().to_scenario_list()
51
- ScenarioList([Scenario({'age': 22, 'hair': 'brown', 'height': 5.5}), Scenario({'age': 22, 'hair': 'brown', 'height': 5.5})])
52
-
50
+ >>> AgentList.example().to_scenario_list().drop('age')
51
+ ScenarioList([Scenario({'hair': 'brown', 'height': 5.5}), Scenario({'hair': 'brown', 'height': 5.5})])
52
+
53
53
  >>> AgentList.example().to_dataset()
54
54
  Dataset([{'age': [22, 22]}, {'hair': ['brown', 'brown']}, {'height': [5.5, 5.5]}])
55
55
 
56
- >>> AgentList.example().to_pandas()
56
+ >>> AgentList.example().select('age', 'hair', 'height').to_pandas()
57
57
  age hair height
58
58
  0 22 brown 5.5
59
59
  1 22 brown 5.5
@@ -91,6 +91,28 @@ class AgentList(UserList, Base, AgentListOperationsMixin):
91
91
  if codebook is not None:
92
92
  self.set_codebook(codebook)
93
93
 
94
def set_instruction(self, instruction: str) -> None:
    """Set the instruction for all agents in the list.

    Each agent's ``instruction`` is replaced and its ``set_instructions`` flag
    is turned on. ``Agent.to_dict`` only serializes the instruction when that
    flag is set, so without it the new instruction would be lost whenever an
    agent is rebuilt from ``to_dict`` (for example by ``Agent.drop``).

    Args:
        instruction: The instruction to set.
    """
    for agent in self.data:
        agent.instruction = instruction
        # Mark the instruction as explicitly set so it survives serialization.
        agent.set_instructions = True

    return None
104
+
105
def set_traits_presentation_template(self, traits_presentation_template: str) -> None:
    """Set the traits presentation template for all agents in the list.

    Each agent's ``traits_presentation_template`` is replaced and its
    ``set_traits_presentation_template`` flag is turned on. ``Agent.to_dict``
    only serializes the template when that flag is set, so without it the new
    template would be lost whenever an agent is rebuilt from ``to_dict``
    (for example by ``Agent.drop``).

    Args:
        traits_presentation_template: The traits presentation template to set.
    """
    for agent in self.data:
        agent.traits_presentation_template = traits_presentation_template
        # Mark the template as explicitly set so it survives serialization.
        agent.set_traits_presentation_template = True

    return None
115
+
94
116
  def shuffle(self, seed: Optional[str] = None) -> AgentList:
95
117
  """Randomly shuffle the agents in place.
96
118
 
@@ -119,6 +141,14 @@ class AgentList(UserList, Base, AgentListOperationsMixin):
119
141
  if seed:
120
142
  random.seed(seed)
121
143
  return AgentList(random.sample(self.data, n))
144
+
145
def drop(self, field_name: str) -> AgentList:
    """Build a new AgentList in which ``drop(field_name)`` was applied to every agent.

    Args:
        field_name: The trait or field name passed to each agent's ``drop``.

    Returns:
        A new ``AgentList`` holding the results of the per-agent ``drop`` calls.
    """
    trimmed_agents = []
    for agent in self.data:
        trimmed_agents.append(agent.drop(field_name))
    return AgentList(trimmed_agents)
122
152
 
123
153
  def duplicate(self) -> AgentList:
124
154
  """Create a deep copy of the AgentList.
@@ -478,7 +508,7 @@ class AgentList(UserList, Base, AgentListOperationsMixin):
478
508
  >>> al.to_dataset()
479
509
  Dataset([{'age': [22, 22]}, {'hair': ['brown', 'brown']}, {'height': [5.5, 5.5]}])
480
510
  >>> al.to_dataset(traits_only=False) # doctest: +NORMALIZE_WHITESPACE
481
- Dataset([{'age': [22, 22]}, {'hair': ['brown', 'brown']}, {'height': [5.5, 5.5]}, {'agent_parameters': [{'instruction': 'You are answering questions as if you were a human. Do not break character.', 'agent_name': None}, {'instruction': 'You are answering questions as if you were a human. Do not break character.', 'agent_name': None}]}])
511
+ Dataset([{'age': [22, 22]}, {'hair': ['brown', 'brown']}, {'height': [5.5, 5.5]}, {'agent_parameters': [{'instruction': 'You are answering questions as if you were a human. Do not break character.', 'agent_name': None, 'traits_presentation_template': 'Your traits: {{traits}}'}, {'instruction': 'You are answering questions as if you were a human. Do not break character.', 'agent_name': None, 'traits_presentation_template': 'Your traits: {{traits}}'}]}])
482
512
  """
483
513
  from ..dataset import Dataset
484
514
 
@@ -495,7 +525,7 @@ class AgentList(UserList, Base, AgentListOperationsMixin):
495
525
  data[trait_key].append(agent.traits.get(trait_key, None))
496
526
  if not traits_only:
497
527
  data["agent_parameters"].append(
498
- {"instruction": agent.instruction, "agent_name": agent.name}
528
+ {"instruction": agent.instruction, "agent_name": agent.name, "traits_presentation_template": agent.traits_presentation_template}
499
529
  )
500
530
  return Dataset([{key: entry} for key, entry in data.items()])
501
531
 
edsl/coop/coop.py CHANGED
@@ -2,6 +2,7 @@ import aiohttp
2
2
  import base64
3
3
  import json
4
4
  import requests
5
+ import time
5
6
 
6
7
  from typing import Any, Optional, Union, Literal, List, TypedDict, TYPE_CHECKING
7
8
  from uuid import UUID
@@ -13,7 +14,9 @@ from ..caching import CacheEntry
13
14
 
14
15
  if TYPE_CHECKING:
15
16
  from ..jobs import Jobs
17
+ from ..scenarios import ScenarioList
16
18
  from ..surveys import Survey
19
+ from ..results import Results
17
20
 
18
21
  from .exceptions import (
19
22
  CoopInvalidURLError,
@@ -37,17 +40,59 @@ from .ep_key_handling import ExpectedParrotKeyHandler
37
40
  from ..inference_services.data_structures import ServiceToModelsMapping
38
41
 
39
42
 
43
class JobRunExpense(TypedDict):
    """Cost of one service / model / token-type combination in a job run."""

    service: str
    model: str
    token_type: Literal["input", "output"]
    price_per_million_tokens: float
    tokens_count: int
    cost_credits: float
    cost_usd: float


class JobRunExceptionCounter(TypedDict):
    """Number of exceptions of one type for a service, model and question."""

    exception_type: str
    inference_service: str
    model: str
    question_name: str
    exception_count: int


class JobRunInterviewDetails(TypedDict):
    """Interview progress for a job run.

    The list of per-exception counts lives under the key ``exception_summary``.
    """

    total_interviews: int
    completed_interviews: int
    interviews_with_exceptions: int
    exception_summary: List[JobRunExceptionCounter]


class LatestJobRunDetails(TypedDict, total=False):
    """Details of a job's most recent run.

    Declared with ``total=False`` because which keys are present depends on
    the job's status (see the comments on each group). TypedDict fields
    cannot carry default values; an absent key is simply missing from the dict.
    """

    # For running, completed, and partially failed jobs
    interview_details: Optional[JobRunInterviewDetails]

    # For failed jobs only
    failure_reason: Optional[Literal["error", "insufficient funds"]]
    failure_description: Optional[str]

    # For partially failed jobs only. The UUID arrives in a JSON payload, so it
    # is a string rather than a ``uuid.UUID`` instance.
    error_report_uuid: Optional[str]
    # Added client-side for partially failed jobs: a link built from
    # error_report_uuid, or None when no UUID was returned.
    error_report_url: Optional[str]

    # For completed and partially failed jobs
    cost_credits: Optional[float]
    cost_usd: Optional[float]
    expenses: Optional[List[JobRunExpense]]
84
+
85
+
40
86
  class RemoteInferenceResponse(TypedDict):
41
87
  job_uuid: str
42
88
  results_uuid: str
43
- results_url: str
44
- latest_error_report_uuid: str
45
- latest_error_report_url: str
46
- status: str
47
- reason: str
48
- credits_consumed: float
49
- version: str
50
89
  job_json_string: Optional[str]
90
+ status: RemoteJobStatus
91
+ latest_job_run_details: LatestJobRunDetails
92
+ description: Optional[str]
93
+ version: str
94
+ visibility: VisibilityType
95
+ results_url: str
51
96
 
52
97
 
53
98
  class RemoteInferenceCreationInfo(TypedDict):
@@ -168,7 +213,9 @@ class Coop(CoopFunctionsMixin):
168
213
  and "json_string" in payload
169
214
  and payload.get("json_string") is not None
170
215
  ):
171
- timeout = max(40, (len(payload.get("json_string", "")) // (1024 * 1024)))
216
+ timeout = max(
217
+ 60, 2 * (len(payload.get("json_string", "")) // (1024 * 1024))
218
+ )
172
219
  try:
173
220
  if method in ["GET", "DELETE"]:
174
221
  response = requests.request(
@@ -244,7 +291,6 @@ class Coop(CoopFunctionsMixin):
244
291
  # print(
245
292
  # "Please upgrade your EDSL version to access our latest features. Open your terminal and run `pip install --upgrade edsl`"
246
293
  # )
247
-
248
294
  if response.status_code >= 400:
249
295
  try:
250
296
  message = str(response.json().get("detail"))
@@ -524,6 +570,7 @@ class Coop(CoopFunctionsMixin):
524
570
  json.dumps(
525
571
  object_dict,
526
572
  default=self._json_handle_none,
573
+ allow_nan=False,
527
574
  )
528
575
  if object_type != "scenario"
529
576
  else ""
@@ -542,6 +589,7 @@ class Coop(CoopFunctionsMixin):
542
589
  json_data = json.dumps(
543
590
  object_dict,
544
591
  default=self._json_handle_none,
592
+ allow_nan=False,
545
593
  )
546
594
  headers = {"Content-Type": "application/json"}
547
595
  if response_json.get("upload_signed_url"):
@@ -885,6 +933,7 @@ class Coop(CoopFunctionsMixin):
885
933
  json.dumps(
886
934
  value.to_dict(),
887
935
  default=self._json_handle_none,
936
+ allow_nan=False,
888
937
  )
889
938
  if value
890
939
  else None
@@ -1063,16 +1112,36 @@ class Coop(CoopFunctionsMixin):
1063
1112
 
1064
1113
  Returns:
1065
1114
  RemoteInferenceResponse: Information about the job including:
1066
- - job_uuid: The unique identifier for the job
1067
- - results_uuid: The UUID of the results (if job is completed)
1068
- - results_url: URL to access the results (if available)
1069
- - latest_error_report_uuid: UUID of error report (if job failed)
1070
- - latest_error_report_url: URL to access error details (if available)
1071
- - status: Current status ("queued", "running", "completed", "failed")
1072
- - reason: Reason for failure (if applicable)
1073
- - credits_consumed: Credits used for the job execution
1074
- - version: EDSL version used for the job
1075
- - job_json_string: The json string for the job (if include_json_string is True)
1115
+ job_uuid: The unique identifier for the job
1116
+ results_uuid: The UUID of the results
1117
+ results_url: URL to access the results
1118
+ status: Current status ("queued", "running", "completed", "failed")
1119
+ version: EDSL version used for the job
1120
+ job_json_string: The json string for the job (if include_json_string is True)
1121
+ latest_job_run_details: Metadata about the job status
1122
+ interview_details: Metadata about the job interview status (for jobs that have reached running status)
1123
+ total_interviews: The total number of interviews in the job
1124
+ completed_interviews: The number of completed interviews
1125
+ interviews_with_exceptions: The number of completed interviews that have exceptions
1126
+ exception_counters: A list of exception counts for the job
1127
+ exception_type: The type of exception
1128
+ inference_service: The inference service
1129
+ model: The model
1130
+ question_name: The name of the question
1131
+ exception_count: The number of exceptions
1132
+ failure_reason: The reason the job failed (failed jobs only)
1133
+ failure_description: The description of the failure (failed jobs only)
1134
+ error_report_uuid: The UUID of the error report (partially failed jobs only)
1135
+ cost_credits: The cost of the job run in credits
1136
+ cost_usd: The cost of the job run in USD
1137
+ expenses: The expenses incurred by the job run
1138
+ service: The service
1139
+ model: The model
1140
+ token_type: The type of token (input or output)
1141
+ price_per_million_tokens: The price per million tokens
1142
+ tokens_count: The number of tokens consumed
1143
+ cost_credits: The cost of the service/model/token type combination in credits
1144
+ cost_usd: The cost of the service/model/token type combination in USD
1076
1145
 
1077
1146
  Raises:
1078
1147
  ValueError: If neither job_uuid nor results_uuid is provided
@@ -1098,6 +1167,8 @@ class Coop(CoopFunctionsMixin):
1098
1167
  params = {"job_uuid": job_uuid}
1099
1168
  else:
1100
1169
  params = {"results_uuid": results_uuid}
1170
+ if include_json_string:
1171
+ params["include_json_string"] = include_json_string
1101
1172
 
1102
1173
  response = self._send_server_request(
1103
1174
  uri="api/v0/remote-inference",
@@ -1108,35 +1179,32 @@ class Coop(CoopFunctionsMixin):
1108
1179
  data = response.json()
1109
1180
 
1110
1181
  results_uuid = data.get("results_uuid")
1111
- latest_error_report_uuid = data.get("latest_error_report_uuid")
1112
1182
 
1113
1183
  if results_uuid is None:
1114
1184
  results_url = None
1115
1185
  else:
1116
1186
  results_url = f"{self.url}/content/{results_uuid}"
1117
1187
 
1118
- if latest_error_report_uuid is None:
1119
- latest_error_report_url = None
1120
- else:
1121
- latest_error_report_url = (
1122
- f"{self.url}/home/remote-inference/error/{latest_error_report_uuid}"
1123
- )
1188
+ latest_job_run_details = data.get("latest_job_run_details", {})
1189
+ if data.get("status") == "partial_failed":
1190
+ latest_error_report_uuid = latest_job_run_details.get("error_report_uuid")
1191
+ if latest_error_report_uuid is None:
1192
+ latest_job_run_details["error_report_url"] = None
1193
+ else:
1194
+ latest_error_report_url = (
1195
+ f"{self.url}/home/remote-inference/error/{latest_error_report_uuid}"
1196
+ )
1197
+ latest_job_run_details["error_report_url"] = latest_error_report_url
1124
1198
 
1125
1199
  return RemoteInferenceResponse(
1126
1200
  **{
1127
1201
  "job_uuid": data.get("job_uuid"),
1128
1202
  "results_uuid": results_uuid,
1129
1203
  "results_url": results_url,
1130
- "latest_error_report_uuid": latest_error_report_uuid,
1131
- "latest_error_report_url": latest_error_report_url,
1132
- "latest_failure_description": data.get("latest_failure_details"),
1133
1204
  "status": data.get("status"),
1134
- "reason": data.get("latest_failure_reason"),
1135
- "credits_consumed": data.get("price"),
1136
1205
  "version": data.get("version"),
1137
- "job_json_string": (
1138
- data.get("job_json_string") if include_json_string else None
1139
- ),
1206
+ "job_json_string": data.get("job_json_string"),
1207
+ "latest_job_run_details": latest_job_run_details,
1140
1208
  }
1141
1209
  )
1142
1210
 
@@ -1323,12 +1391,108 @@ class Coop(CoopFunctionsMixin):
1323
1391
  self._resolve_server_response(response)
1324
1392
  response_json = response.json()
1325
1393
  return {
1326
- "name": response_json.get("project_name"),
1394
+ "project_name": response_json.get("project_name"),
1327
1395
  "uuid": response_json.get("uuid"),
1328
1396
  "admin_url": f"{self.url}/home/projects/{response_json.get('uuid')}",
1329
1397
  "respondent_url": f"{self.url}/respond/{response_json.get('uuid')}",
1330
1398
  }
1331
1399
 
1400
def get_project(
    self,
    project_uuid: str,
) -> dict:
    """
    Get a project from Coop.

    Sends ``GET api/v0/projects/<project_uuid>`` and returns a dictionary with
    ``project_name`` and ``project_job_uuids`` (the server's ``job_uuids``).
    """
    endpoint = f"api/v0/projects/{project_uuid}"
    server_response = self._send_server_request(uri=endpoint, method="GET")
    self._resolve_server_response(server_response)
    payload = server_response.json()

    project_info = {}
    project_info["project_name"] = payload.get("project_name")
    project_info["project_job_uuids"] = payload.get("job_uuids")
    return project_info
1417
+
1418
def get_project_human_responses(
    self,
    project_uuid: str,
) -> Union["Results", "ScenarioList"]:
    """
    Return a Results object with the human responses for a project.

    Each human response becomes an ``Agent`` named after its ``response_uuid``
    whose direct-answering method replays the stored answers. The project's
    survey is then run against those agents with ``Model("test")``, a fresh
    ``Cache``, and remote cache / remote inference disabled.

    If generating the Results object fails, a warning is emitted and a
    ScenarioList is returned instead (one Scenario per response, with an
    ``agent_name`` key holding the response UUID).

    Args:
        project_uuid: UUID of the Coop project.

    Returns:
        A ``Results`` object, or a ``ScenarioList`` as a fallback.

    Raises:
        RuntimeError: If any human response is missing its ``response_uuid``.
    """
    import warnings

    from ..agents import Agent, AgentList
    from ..caching import Cache
    from ..language_models import Model
    from ..scenarios import Scenario, ScenarioList
    from ..surveys import Survey

    response = self._send_server_request(
        uri=f"api/v0/projects/{project_uuid}/human-responses",
        method="GET",
    )
    self._resolve_server_response(response)
    response_json = response.json()
    human_responses = response_json.get("human_responses", [])

    # Validate and decode every response once; both the Results path and the
    # ScenarioList fallback consume the same (uuid, answers) pairs. Errors
    # raised here are the same ones the fallback would have re-raised.
    parsed_responses = []
    for human_response in human_responses:
        response_uuid = human_response.get("response_uuid")
        if response_uuid is None:
            raise RuntimeError(
                "One of your responses is missing a unique identifier."
            )
        answers = json.loads(human_response.get("response_json_string"))
        parsed_responses.append((response_uuid, answers))

    def create_answer_function(response_data):
        # Bind response_data through the argument so each agent replays its
        # own answers (avoids late binding of a loop variable).
        def f(self, question, scenario):
            return response_data.get(question.question_name, None)

        return f

    try:
        agent_list = AgentList()
        for response_uuid, answers in parsed_responses:
            agent = Agent(name=response_uuid, instruction="")
            agent.add_direct_question_answering_method(
                create_answer_function(answers)
            )
            agent_list.append(agent)

        survey_json_string = response_json.get("survey_json_string")
        survey = Survey.from_dict(json.loads(survey_json_string))

        model = Model("test")
        return (
            survey.by(agent_list)
            .by(model)
            .run(
                cache=Cache(),
                disable_remote_cache=True,
                disable_remote_inference=True,
                print_exceptions=False,
            )
        )
    except Exception as exc:
        # Deliberate best-effort fallback, but surface why Results failed.
        warnings.warn(
            f"Could not build Results for project {project_uuid} ({exc!r}); "
            "returning a ScenarioList of the raw human responses instead.",
            stacklevel=2,
        )
        return ScenarioList(
            [
                Scenario({**answers, "agent_name": response_uuid})
                for response_uuid, answers in parsed_responses
            ]
        )
1495
+
1332
1496
  def __repr__(self):
1333
1497
  """Return a string representation of the client."""
1334
1498
  return f"Client(api_key='{self.api_key}', url='{self.url}')"
@@ -1556,6 +1720,11 @@ class Coop(CoopFunctionsMixin):
1556
1720
  f"[#38bdf8][link={url}][underline]Log in and automatically store key[/underline][/link][/#38bdf8]"
1557
1721
  )
1558
1722
 
1723
+ print("Logging in will activate the following features:")
1724
+ print(" - Remote inference: Runs jobs remotely on the Expected Parrot server.")
1725
+ print(" - Remote logging: Sends error messages to the Expected Parrot server.")
1726
+ print("\n")
1727
+
1559
1728
  def _get_api_key(self, edsl_auth_token: str):
1560
1729
  """
1561
1730
  Given an EDSL auth token, find the corresponding user's API key.
@@ -1600,6 +1769,76 @@ class Coop(CoopFunctionsMixin):
1600
1769
  # Add API key to environment
1601
1770
  load_dotenv()
1602
1771
 
1772
def transfer_credits(
    self,
    credits_transferred: int,
    recipient_username: str,
    transfer_note: Optional[str] = None,
) -> dict:
    """
    Transfer credits to another user.

    Sends ``POST api/users/gift`` with the amount, the recipient and an
    optional note, and returns the decoded JSON body after
    ``_resolve_server_response`` has checked the response.

    Parameters:
        credits_transferred (int): The number of credits to transfer to the recipient
        recipient_username (str): The username of the recipient
        transfer_note (str, optional): A personal note to include with the transfer;
            sent as ``gift_note`` (``None`` when omitted)

    Returns:
        dict: The server's description of the transfer. It is expected to include
        ``success``, ``transaction_id`` and ``remaining_credits``; the exact shape
        is defined by the server.

    Raises:
        CoopServerResponseError: Presumably raised by ``_resolve_server_response``
            when the server reports an error (e.g. insufficient credits).

    Example:
        >>> result = coop.transfer_credits(  # doctest: +SKIP
        ...     credits_transferred=100,
        ...     recipient_username="friend_username",
        ...     transfer_note="Thanks for your help!"
        ... )
        >>> print(f"Transfer successful! You have {result['remaining_credits']} credits left.")  # doctest: +SKIP
    """
    response = self._send_server_request(
        uri="api/users/gift",
        method="POST",
        payload={
            "credits_gifted": credits_transferred,
            "recipient_username": recipient_username,
            "gift_note": transfer_note,
        },
    )
    self._resolve_server_response(response)
    return response.json()
1818
+
1819
def get_balance(self) -> dict:
    """
    Get the current credit balance for the authenticated user.

    Sends ``GET api/users/get_balance`` and returns the decoded JSON body
    after ``_resolve_server_response`` has checked the response.

    Returns:
        dict: The server's balance payload. It is expected to include
        ``credits`` (the current number of credits) and possibly
        ``usage_history``; the exact shape is defined by the server.

    Raises:
        CoopServerResponseError: Presumably raised by ``_resolve_server_response``
            when the server reports an error.

    Example:
        >>> balance = coop.get_balance()  # doctest: +SKIP
        >>> print(f"You have {balance['credits']} credits available.")  # doctest: +SKIP
    """
    response = self._send_server_request(uri="api/users/get_balance", method="GET")
    self._resolve_server_response(response)
    return response.json()
1841
+
1603
1842
 
1604
1843
  def main():
1605
1844
  """
edsl/coop/utils.py CHANGED
@@ -1,3 +1,4 @@
1
+ import math
1
2
  from typing import Literal, Optional, Type, Union
2
3
 
3
4
  from ..agents import Agent, AgentList
@@ -197,3 +198,65 @@ class ObjectRegistry:
197
198
  if (class_name := o["edsl_class"].__name__) not in subclass_registry
198
199
  and class_name not in exclude_classes
199
200
  }
201
+
202
+
203
class CostConverter:
    """Convert between US dollars, credits and minicredits.

    One credit is worth ``CENTS_PER_CREDIT`` US cents, and one credit equals
    100 minicredits. The public conversions round *up* to a whole number of
    minicredits.
    """

    CENTS_PER_CREDIT = 1

    # Decimal places kept before rounding up to whole minicredits. Binary
    # floating point turns e.g. 0.07 * 100 into 7.000000000000001; without this
    # step ``math.ceil`` would charge one extra minicredit for exact amounts.
    _CEIL_PRECISION = 9

    @staticmethod
    def _credits_to_minicredits(credits: float) -> float:
        """
        Converts credits to minicredits.

        Current conversion: minicredits = credits * 100
        """

        return credits * 100

    @staticmethod
    def _minicredits_to_credits(minicredits: float) -> float:
        """
        Converts minicredits to credits.

        Current conversion: credits = minicredits / 100
        """

        return minicredits / 100

    @classmethod
    def _ceil_minicredits(cls, minicredits: float) -> int:
        """Round up to a whole number of minicredits, ignoring float noise."""

        return math.ceil(round(minicredits, cls._CEIL_PRECISION))

    def _usd_to_minicredits(self, usd: float) -> float:
        """Converts USD to (possibly fractional) minicredits."""

        cents = usd * 100
        credits_per_cent = 1 / int(self.CENTS_PER_CREDIT)
        credit_amount = cents * credits_per_cent
        return self._credits_to_minicredits(credit_amount)

    def _minicredits_to_usd(self, minicredits: float) -> float:
        """Converts minicredits to USD."""

        credit_amount = self._minicredits_to_credits(minicredits)
        cents_per_credit = int(self.CENTS_PER_CREDIT)
        cents = credit_amount * cents_per_credit
        return cents / 100

    def usd_to_credits(self, usd: float) -> float:
        """Converts USD to credits, rounding up to a whole minicredit (0.01 credit)."""

        minicredits = self._ceil_minicredits(self._usd_to_minicredits(usd))
        credit_amount = self._minicredits_to_credits(minicredits)
        return round(credit_amount, 2)

    def credits_to_usd(self, credits: float) -> float:
        """Converts credits to USD, first rounding up to a whole minicredit."""

        minicredits = self._ceil_minicredits(self._credits_to_minicredits(credits))
        return self._minicredits_to_usd(minicredits)