edsl 0.1.36.dev1__py3-none-any.whl → 0.1.36.dev5__py3-none-any.whl

This diff shows the differences between publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
edsl/Base.py CHANGED
@@ -7,6 +7,8 @@ import json
7
7
  from typing import Any, Optional, Union
8
8
  from uuid import UUID
9
9
 
10
+ # from edsl.utilities.MethodSuggesterMixin import MethodSuggesterMixin
11
+
10
12
 
11
13
  class RichPrintingMixin:
12
14
  """Mixin for rich printing and persistence of objects."""
@@ -274,6 +276,9 @@ class Base(
274
276
  """This method should be implemented by subclasses."""
275
277
  raise NotImplementedError("This method is not implemented yet.")
276
278
 
279
+ def to_json(self):
280
+ return json.dumps(self.to_dict())
281
+
277
282
  @abstractmethod
278
283
  def from_dict():
279
284
  """This method should be implemented by subclasses."""
edsl/__version__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.1.36.dev1"
1
+ __version__ = "0.1.36.dev5"
edsl/agents/Agent.py CHANGED
@@ -111,7 +111,11 @@ class Agent(Base):
111
111
  self.name = name
112
112
  self._traits = traits or dict()
113
113
  self.codebook = codebook or dict()
114
- self.instruction = instruction or self.default_instruction
114
+ if instruction is None:
115
+ self.instruction = self.default_instruction
116
+ else:
117
+ self.instruction = instruction
118
+ # self.instruction = instruction or self.default_instruction
115
119
  self.dynamic_traits_function = dynamic_traits_function
116
120
 
117
121
  # Deal with dynamic traits function
@@ -212,9 +212,10 @@ class PromptConstructor:
212
212
  )
213
213
 
214
214
  if relevant_instructions != []:
215
- preamble_text = Prompt(
216
- text="Before answer this question, you were given the following instructions: "
217
- )
215
+ # preamble_text = Prompt(
216
+ # text="You were given the following instructions: "
217
+ # )
218
+ preamble_text = Prompt(text="")
218
219
  for instruction in relevant_instructions:
219
220
  preamble_text += instruction.text
220
221
  rendered_instructions = preamble_text + rendered_instructions
edsl/coop/PriceFetcher.py CHANGED
@@ -16,30 +16,26 @@ class PriceFetcher:
16
16
  if self._cached_prices is not None:
17
17
  return self._cached_prices
18
18
 
19
+ import os
19
20
  import requests
20
- import csv
21
- from io import StringIO
22
-
23
- sheet_id = "1SAO3Bhntefl0XQHJv27rMxpvu6uzKDWNXFHRa7jrUDs"
24
-
25
- # Construct the URL to fetch the CSV
26
- url = f"https://docs.google.com/spreadsheets/d/{sheet_id}/export?format=csv"
21
+ from edsl import CONFIG
27
22
 
28
23
  try:
29
- # Fetch the CSV data
30
- response = requests.get(url)
24
+ # Fetch the pricing data
25
+ url = f"{CONFIG.EXPECTED_PARROT_URL}/api/v0/prices"
26
+ api_key = os.getenv("EXPECTED_PARROT_API_KEY")
27
+ headers = {}
28
+ if api_key:
29
+ headers["Authorization"] = f"Bearer {api_key}"
30
+ else:
31
+ headers["Authorization"] = f"Bearer None"
32
+
33
+ response = requests.get(url, headers=headers, timeout=20)
31
34
  response.raise_for_status() # Raise an exception for bad responses
32
35
 
33
- # Parse the CSV data
34
- csv_data = StringIO(response.text)
35
- reader = csv.reader(csv_data)
36
-
37
- # Convert to list of dictionaries
38
- headers = next(reader)
39
- data = [dict(zip(headers, row)) for row in reader]
36
+ # Parse the data
37
+ data = response.json()
40
38
 
41
- # self._cached_prices = data
42
- # return data
43
39
  price_lookup = {}
44
40
  for entry in data:
45
41
  service = entry.get("service", None)
edsl/coop/coop.py CHANGED
@@ -6,6 +6,7 @@ from typing import Any, Optional, Union, Literal
6
6
  from uuid import UUID
7
7
  import edsl
8
8
  from edsl import CONFIG, CacheEntry, Jobs, Survey
9
+ from edsl.exceptions.coop import CoopNoUUIDError, CoopServerResponseError
9
10
  from edsl.coop.utils import (
10
11
  EDSLObject,
11
12
  ObjectRegistry,
@@ -99,7 +100,7 @@ class Coop:
99
100
  if "Authorization" in message:
100
101
  print(message)
101
102
  message = "Please provide an Expected Parrot API key."
102
- raise Exception(message)
103
+ raise CoopServerResponseError(message)
103
104
 
104
105
  def _json_handle_none(self, value: Any) -> Any:
105
106
  """
@@ -116,7 +117,7 @@ class Coop:
116
117
  Resolve the uuid from a uuid or a url.
117
118
  """
118
119
  if not url and not uuid:
119
- raise Exception("No uuid or url provided for the object.")
120
+ raise CoopNoUUIDError("No uuid or url provided for the object.")
120
121
  if not uuid and url:
121
122
  uuid = url.split("/")[-1]
122
123
  return uuid
@@ -521,7 +522,7 @@ class Coop:
521
522
  self._resolve_server_response(response)
522
523
  response_json = response.json()
523
524
  return {
524
- "uuid": response_json.get("jobs_uuid"),
525
+ "uuid": response_json.get("job_uuid"),
525
526
  "description": response_json.get("description"),
526
527
  "status": response_json.get("status"),
527
528
  "iterations": response_json.get("iterations"),
@@ -529,29 +530,41 @@ class Coop:
529
530
  "version": self._edsl_version,
530
531
  }
531
532
 
532
- def remote_inference_get(self, job_uuid: str) -> dict:
533
+ def remote_inference_get(
534
+ self, job_uuid: Optional[str] = None, results_uuid: Optional[str] = None
535
+ ) -> dict:
533
536
  """
534
537
  Get the details of a remote inference job.
538
+ You can pass either the job uuid or the results uuid as a parameter.
539
+ If you pass both, the job uuid will be prioritized.
535
540
 
536
541
  :param job_uuid: The UUID of the EDSL job.
542
+ :param results_uuid: The UUID of the results associated with the EDSL job.
537
543
 
538
544
  >>> coop.remote_inference_get("9f8484ee-b407-40e4-9652-4133a7236c9c")
539
545
  {'jobs_uuid': '9f8484ee-b407-40e4-9652-4133a7236c9c', 'results_uuid': 'dd708234-31bf-4fe1-8747-6e232625e026', 'results_url': 'https://www.expectedparrot.com/content/dd708234-31bf-4fe1-8747-6e232625e026', 'status': 'completed', 'reason': None, 'price': 16, 'version': '0.1.29.dev4'}
540
546
  """
547
+ if job_uuid is None and results_uuid is None:
548
+ raise ValueError("Either job_uuid or results_uuid must be provided.")
549
+ elif job_uuid is not None:
550
+ params = {"job_uuid": job_uuid}
551
+ else:
552
+ params = {"results_uuid": results_uuid}
553
+
541
554
  response = self._send_server_request(
542
555
  uri="api/v0/remote-inference",
543
556
  method="GET",
544
- params={"uuid": job_uuid},
557
+ params=params,
545
558
  )
546
559
  self._resolve_server_response(response)
547
560
  data = response.json()
548
561
  return {
549
- "jobs_uuid": data.get("jobs_uuid"),
562
+ "job_uuid": data.get("job_uuid"),
550
563
  "results_uuid": data.get("results_uuid"),
551
564
  "results_url": f"{self.url}/content/{data.get('results_uuid')}",
552
565
  "status": data.get("status"),
553
566
  "reason": data.get("reason"),
554
- "price": data.get("price"),
567
+ "credits_consumed": data.get("price"),
555
568
  "version": data.get("version"),
556
569
  }
557
570
 
@@ -584,7 +597,10 @@ class Coop:
584
597
  )
585
598
  self._resolve_server_response(response)
586
599
  response_json = response.json()
587
- return response_json.get("cost")
600
+ return {
601
+ "credits": response_json.get("cost_in_credits"),
602
+ "usd": response_json.get("cost_in_usd"),
603
+ }
588
604
 
589
605
  ################
590
606
  # Remote Errors
@@ -649,6 +665,10 @@ class Coop:
649
665
  return response_json
650
666
 
651
667
  def fetch_prices(self) -> dict:
668
+ """
669
+ Fetch model prices from Coop. If the request fails, return an empty dict.
670
+ """
671
+
652
672
  from edsl.coop.PriceFetcher import PriceFetcher
653
673
 
654
674
  from edsl.config import CONFIG
@@ -659,6 +679,20 @@ class Coop:
659
679
  else:
660
680
  return {}
661
681
 
682
+ def fetch_rate_limit_config_vars(self) -> dict:
683
+ """
684
+ Fetch a dict of rate limit config vars from Coop.
685
+
686
+ The dict keys are RPM and TPM variables like EDSL_SERVICE_RPM_OPENAI.
687
+ """
688
+ response = self._send_server_request(
689
+ uri="api/v0/config-vars",
690
+ method="GET",
691
+ )
692
+ self._resolve_server_response(response)
693
+ data = response.json()
694
+ return data
695
+
662
696
 
663
697
  if __name__ == "__main__":
664
698
  sheet_data = fetch_sheet_data()
@@ -0,0 +1,84 @@
1
+ class RemoteCacheSync:
2
+ def __init__(self, coop, cache, output_func, remote_cache=True, remote_cache_description=""):
3
+ self.coop = coop
4
+ self.cache = cache
5
+ self._output = output_func
6
+ self.remote_cache = remote_cache
7
+ self.old_entry_keys = []
8
+ self.new_cache_entries = []
9
+ self.remote_cache_description = remote_cache_description
10
+
11
+ def __enter__(self):
12
+ if self.remote_cache:
13
+ self._sync_from_remote()
14
+ self.old_entry_keys = list(self.cache.keys())
15
+ return self
16
+
17
+ def __exit__(self, exc_type, exc_value, traceback):
18
+ if self.remote_cache:
19
+ self._sync_to_remote()
20
+ return False # Propagate exceptions
21
+
22
+ def _sync_from_remote(self):
23
+ cache_difference = self.coop.remote_cache_get_diff(self.cache.keys())
24
+ client_missing_cacheentries = cache_difference.get("client_missing_cacheentries", [])
25
+ missing_entry_count = len(client_missing_cacheentries)
26
+
27
+ if missing_entry_count > 0:
28
+ self._output(
29
+ f"Updating local cache with {missing_entry_count:,} new "
30
+ f"{'entry' if missing_entry_count == 1 else 'entries'} from remote..."
31
+ )
32
+ self.cache.add_from_dict({entry.key: entry for entry in client_missing_cacheentries})
33
+ self._output("Local cache updated!")
34
+ else:
35
+ self._output("No new entries to add to local cache.")
36
+
37
+ def _sync_to_remote(self):
38
+ cache_difference = self.coop.remote_cache_get_diff(self.cache.keys())
39
+ server_missing_cacheentry_keys = cache_difference.get("server_missing_cacheentry_keys", [])
40
+ server_missing_cacheentries = [
41
+ entry
42
+ for key in server_missing_cacheentry_keys
43
+ if (entry := self.cache.data.get(key)) is not None
44
+ ]
45
+
46
+ new_cache_entries = [
47
+ entry for entry in self.cache.values() if entry.key not in self.old_entry_keys
48
+ ]
49
+ server_missing_cacheentries.extend(new_cache_entries)
50
+ new_entry_count = len(server_missing_cacheentries)
51
+
52
+ if new_entry_count > 0:
53
+ self._output(
54
+ f"Updating remote cache with {new_entry_count:,} new "
55
+ f"{'entry' if new_entry_count == 1 else 'entries'}..."
56
+ )
57
+ self.coop.remote_cache_create_many(
58
+ server_missing_cacheentries,
59
+ visibility="private",
60
+ description=self.remote_cache_description,
61
+ )
62
+ self._output("Remote cache updated!")
63
+ else:
64
+ self._output("No new entries to add to remote cache.")
65
+
66
+ self._output(f"There are {len(self.cache.keys()):,} entries in the local cache.")
67
+
68
+ # # Usage example
69
+ # def run_job(self, n, progress_bar, cache, stop_on_exception, sidecar_model, print_exceptions, raise_validation_errors, use_remote_cache=True):
70
+ # with RemoteCacheSync(self.coop, cache, self._output, remote_cache=use_remote_cache):
71
+ # self._output("Running job...")
72
+ # results = self._run_local(
73
+ # n=n,
74
+ # progress_bar=progress_bar,
75
+ # cache=cache,
76
+ # stop_on_exception=stop_on_exception,
77
+ # sidecar_model=sidecar_model,
78
+ # print_exceptions=print_exceptions,
79
+ # raise_validation_errors=raise_validation_errors,
80
+ # )
81
+ # self._output("Job completed!")
82
+
83
+ # results.cache = cache.new_entries_cache()
84
+ # return results
edsl/exceptions/coop.py CHANGED
@@ -1,2 +1,10 @@
1
1
  class CoopErrors(Exception):
2
2
  pass
3
+
4
+
5
+ class CoopNoUUIDError(CoopErrors):
6
+ pass
7
+
8
+
9
+ class CoopServerResponseError(CoopErrors):
10
+ pass
@@ -1,6 +1,7 @@
1
1
  from abc import abstractmethod, ABC
2
2
  import os
3
3
  import re
4
+ from datetime import datetime, timedelta
4
5
  from edsl.config import CONFIG
5
6
 
6
7
 
@@ -10,6 +11,8 @@ class InferenceServiceABC(ABC):
10
11
  Anthropic: https://docs.anthropic.com/en/api/rate-limits
11
12
  """
12
13
 
14
+ _coop_config_vars = None
15
+
13
16
  default_levels = {
14
17
  "google": {"tpm": 2_000_000, "rpm": 15},
15
18
  "openai": {"tpm": 2_000_000, "rpm": 10_000},
@@ -31,12 +34,37 @@ class InferenceServiceABC(ABC):
31
34
  f"Class {cls.__name__} must have a 'model_exclude_list' attribute."
32
35
  )
33
36
 
37
+ @classmethod
38
+ def _should_refresh_coop_config_vars(cls):
39
+ """
40
+ Returns True if config vars have been fetched over 24 hours ago, and False otherwise.
41
+ """
42
+
43
+ if cls._last_config_fetch is None:
44
+ return True
45
+ return (datetime.now() - cls._last_config_fetch) > timedelta(hours=24)
46
+
34
47
  @classmethod
35
48
  def _get_limt(cls, limit_type: str) -> int:
36
49
  key = f"EDSL_SERVICE_{limit_type.upper()}_{cls._inference_service_.upper()}"
37
50
  if key in os.environ:
38
51
  return int(os.getenv(key))
39
52
 
53
+ if cls._coop_config_vars is None or cls._should_refresh_coop_config_vars():
54
+ try:
55
+ from edsl import Coop
56
+
57
+ c = Coop()
58
+ cls._coop_config_vars = c.fetch_rate_limit_config_vars()
59
+ cls._last_config_fetch = datetime.now()
60
+ if key in cls._coop_config_vars:
61
+ return cls._coop_config_vars[key]
62
+ except Exception:
63
+ cls._coop_config_vars = None
64
+ else:
65
+ if key in cls._coop_config_vars:
66
+ return cls._coop_config_vars[key]
67
+
40
68
  if cls._inference_service_ in cls.default_levels:
41
69
  return int(cls.default_levels[cls._inference_service_][limit_type])
42
70
 
@@ -11,21 +11,29 @@ from edsl.inference_services.AwsBedrock import AwsBedrockService
11
11
  from edsl.inference_services.AzureAI import AzureAIService
12
12
  from edsl.inference_services.OllamaService import OllamaService
13
13
  from edsl.inference_services.TestService import TestService
14
- from edsl.inference_services.MistralAIService import MistralAIService
15
14
  from edsl.inference_services.TogetherAIService import TogetherAIService
16
15
 
17
- default = InferenceServicesCollection(
18
- [
19
- OpenAIService,
20
- AnthropicService,
21
- DeepInfraService,
22
- GoogleService,
23
- GroqService,
24
- AwsBedrockService,
25
- AzureAIService,
26
- OllamaService,
27
- TestService,
28
- MistralAIService,
29
- TogetherAIService,
30
- ]
31
- )
16
+ try:
17
+ from edsl.inference_services.MistralAIService import MistralAIService
18
+
19
+ mistral_available = True
20
+ except Exception as e:
21
+ mistral_available = False
22
+
23
+ services = [
24
+ OpenAIService,
25
+ AnthropicService,
26
+ DeepInfraService,
27
+ GoogleService,
28
+ GroqService,
29
+ AwsBedrockService,
30
+ AzureAIService,
31
+ OllamaService,
32
+ TestService,
33
+ TogetherAIService,
34
+ ]
35
+
36
+ if mistral_available:
37
+ services.append(MistralAIService)
38
+
39
+ default = InferenceServicesCollection(services)