edsl 0.1.37.dev5__py3-none-any.whl → 0.1.38__py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as published.
Files changed (86)
  1. edsl/Base.py +63 -34
  2. edsl/BaseDiff.py +7 -7
  3. edsl/__init__.py +2 -1
  4. edsl/__version__.py +1 -1
  5. edsl/agents/Agent.py +23 -11
  6. edsl/agents/AgentList.py +86 -23
  7. edsl/agents/Invigilator.py +18 -7
  8. edsl/agents/InvigilatorBase.py +0 -19
  9. edsl/agents/PromptConstructor.py +5 -4
  10. edsl/auto/SurveyCreatorPipeline.py +1 -1
  11. edsl/auto/utilities.py +1 -1
  12. edsl/base/Base.py +3 -13
  13. edsl/config.py +8 -0
  14. edsl/coop/coop.py +89 -19
  15. edsl/data/Cache.py +45 -17
  16. edsl/data/CacheEntry.py +8 -3
  17. edsl/data/RemoteCacheSync.py +0 -19
  18. edsl/enums.py +2 -0
  19. edsl/exceptions/agents.py +4 -0
  20. edsl/exceptions/cache.py +5 -0
  21. edsl/inference_services/GoogleService.py +7 -15
  22. edsl/inference_services/PerplexityService.py +163 -0
  23. edsl/inference_services/registry.py +2 -0
  24. edsl/jobs/Jobs.py +110 -559
  25. edsl/jobs/JobsChecks.py +147 -0
  26. edsl/jobs/JobsPrompts.py +268 -0
  27. edsl/jobs/JobsRemoteInferenceHandler.py +239 -0
  28. edsl/jobs/buckets/TokenBucket.py +3 -0
  29. edsl/jobs/interviews/Interview.py +7 -7
  30. edsl/jobs/runners/JobsRunnerAsyncio.py +156 -28
  31. edsl/jobs/runners/JobsRunnerStatus.py +194 -196
  32. edsl/jobs/tasks/TaskHistory.py +27 -19
  33. edsl/language_models/LanguageModel.py +52 -90
  34. edsl/language_models/ModelList.py +67 -14
  35. edsl/language_models/registry.py +57 -4
  36. edsl/notebooks/Notebook.py +7 -8
  37. edsl/prompts/Prompt.py +8 -3
  38. edsl/questions/QuestionBase.py +38 -30
  39. edsl/questions/QuestionBaseGenMixin.py +1 -1
  40. edsl/questions/QuestionBasePromptsMixin.py +0 -17
  41. edsl/questions/QuestionExtract.py +3 -4
  42. edsl/questions/QuestionFunctional.py +10 -3
  43. edsl/questions/derived/QuestionTopK.py +2 -0
  44. edsl/questions/question_registry.py +36 -6
  45. edsl/results/CSSParameterizer.py +108 -0
  46. edsl/results/Dataset.py +146 -15
  47. edsl/results/DatasetExportMixin.py +231 -217
  48. edsl/results/DatasetTree.py +134 -4
  49. edsl/results/Result.py +31 -16
  50. edsl/results/Results.py +159 -65
  51. edsl/results/TableDisplay.py +198 -0
  52. edsl/results/table_display.css +78 -0
  53. edsl/scenarios/FileStore.py +187 -13
  54. edsl/scenarios/Scenario.py +73 -18
  55. edsl/scenarios/ScenarioJoin.py +127 -0
  56. edsl/scenarios/ScenarioList.py +251 -76
  57. edsl/surveys/MemoryPlan.py +1 -1
  58. edsl/surveys/Rule.py +1 -5
  59. edsl/surveys/RuleCollection.py +1 -1
  60. edsl/surveys/Survey.py +25 -19
  61. edsl/surveys/SurveyFlowVisualizationMixin.py +67 -9
  62. edsl/surveys/instructions/ChangeInstruction.py +9 -7
  63. edsl/surveys/instructions/Instruction.py +21 -7
  64. edsl/templates/error_reporting/interview_details.html +3 -3
  65. edsl/templates/error_reporting/interviews.html +18 -9
  66. edsl/{conjure → utilities}/naming_utilities.py +1 -1
  67. edsl/utilities/utilities.py +15 -0
  68. {edsl-0.1.37.dev5.dist-info → edsl-0.1.38.dist-info}/METADATA +2 -1
  69. {edsl-0.1.37.dev5.dist-info → edsl-0.1.38.dist-info}/RECORD +71 -77
  70. edsl/conjure/AgentConstructionMixin.py +0 -160
  71. edsl/conjure/Conjure.py +0 -62
  72. edsl/conjure/InputData.py +0 -659
  73. edsl/conjure/InputDataCSV.py +0 -48
  74. edsl/conjure/InputDataMixinQuestionStats.py +0 -182
  75. edsl/conjure/InputDataPyRead.py +0 -91
  76. edsl/conjure/InputDataSPSS.py +0 -8
  77. edsl/conjure/InputDataStata.py +0 -8
  78. edsl/conjure/QuestionOptionMixin.py +0 -76
  79. edsl/conjure/QuestionTypeMixin.py +0 -23
  80. edsl/conjure/RawQuestion.py +0 -65
  81. edsl/conjure/SurveyResponses.py +0 -7
  82. edsl/conjure/__init__.py +0 -9
  83. edsl/conjure/examples/placeholder.txt +0 -0
  84. edsl/conjure/utilities.py +0 -201
  85. {edsl-0.1.37.dev5.dist-info → edsl-0.1.38.dist-info}/LICENSE +0 -0
  86. {edsl-0.1.37.dev5.dist-info → edsl-0.1.38.dist-info}/WHEEL +0 -0
edsl/coop/coop.py CHANGED
@@ -28,11 +28,23 @@ class Coop:
         - Provide a URL directly, or use the default one.
         """
         self.api_key = api_key or os.getenv("EXPECTED_PARROT_API_KEY")
+
        self.url = url or CONFIG.EXPECTED_PARROT_URL
         if self.url.endswith("/"):
             self.url = self.url[:-1]
+        if "chick.expectedparrot" in self.url:
+            self.api_url = "https://chickapi.expectedparrot.com"
+        elif "expectedparrot" in self.url:
+            self.api_url = "https://api.expectedparrot.com"
+        elif "localhost:1234" in self.url:
+            self.api_url = "http://localhost:8000"
+        else:
+            self.api_url = self.url
         self._edsl_version = edsl.__version__

+    def get_progress_bar_url(self):
+        return f"{CONFIG.EXPECTED_PARROT_URL}"
+
     ################
     # BASIC METHODS
     ################
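Note on the new `api_url` routing: branch order matters, since `"chick.expectedparrot"` also contains `"expectedparrot"`, so the most specific match must come first. A minimal standalone sketch of the mapping (the helper name is hypothetical, not part of the package):

    def resolve_api_url(url: str) -> str:
        # Mirror of the constructor logic above: most specific match first.
        if "chick.expectedparrot" in url:
            return "https://chickapi.expectedparrot.com"
        elif "expectedparrot" in url:
            return "https://api.expectedparrot.com"
        elif "localhost:1234" in url:
            return "http://localhost:8000"
        return url  # self-hosted / custom servers fall through unchanged

    assert resolve_api_url("https://www.expectedparrot.com") == "https://api.expectedparrot.com"
    assert resolve_api_url("http://localhost:1234") == "http://localhost:8000"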
@@ -59,7 +71,7 @@ class Coop:
         """
         Send a request to the server and return the response.
         """
-        url = f"{self.url}/{uri}"
+        url = f"{self.api_url}/{uri}"
         method = method.upper()
         if payload is None:
             timeout = 20
@@ -90,12 +102,57 @@ class Coop:

         return response

+    def _get_latest_stable_version(self, version: str) -> str:
+        """
+        Extract the latest stable PyPI version from a version string.
+
+        Examples:
+        - Decrement the patch number of a dev version: "0.1.38.dev1" -> "0.1.37"
+        - Return a stable version as is: "0.1.37" -> "0.1.37"
+        """
+        if "dev" not in version:
+            return version
+        else:
+            # For 0.1.38.dev1, split into ["0", "1", "38", "dev1"]
+            major, minor, patch = version.split(".")[:3]
+
+            current_patch = int(patch)
+            latest_patch = current_patch - 1
+            return f"{major}.{minor}.{latest_patch}"
+
+    def _user_version_is_outdated(
+        self, user_version_str: str, server_version_str: str
+    ) -> bool:
+        """
+        Check if the user's EDSL version is outdated compared to the server's.
+        """
+        server_stable_version_str = self._get_latest_stable_version(server_version_str)
+        user_stable_version_str = self._get_latest_stable_version(user_version_str)
+
+        # Turn the version strings into tuples of ints for comparison
+        user_stable_version = tuple(map(int, user_stable_version_str.split(".")))
+        server_stable_version = tuple(map(int, server_stable_version_str.split(".")))
+
+        return user_stable_version < server_stable_version
+
     def _resolve_server_response(
         self, response: requests.Response, check_api_key: bool = True
     ) -> None:
         """
         Check the response from the server and raise errors as appropriate.
         """
+        # Get EDSL version from header
+        server_edsl_version = response.headers.get("X-EDSL-Version")
+
+        if server_edsl_version:
+            if self._user_version_is_outdated(
+                user_version_str=self._edsl_version,
+                server_version_str=server_edsl_version,
+            ):
+                print(
+                    "Please upgrade your EDSL version to access our latest features. To upgrade, open your terminal and run `pip install --upgrade edsl`"
+                )
+
         if response.status_code >= 400:
             message = response.json().get("detail")
             # print(response.text)
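The version gate above compares "latest stable" tuples, where a dev version is mapped down to the previous patch release. A standalone sketch of that comparison (helper names hypothetical):

    def stable(version: str) -> str:
        # "0.1.38.dev1" -> "0.1.37"; stable versions pass through unchanged.
        if "dev" not in version:
            return version
        major, minor, patch = version.split(".")[:3]
        return f"{major}.{minor}.{int(patch) - 1}"

    def as_tuple(version: str) -> tuple:
        return tuple(map(int, stable(version).split(".")))

    # A 0.1.37 user talking to a 0.1.38.dev1 server is NOT outdated,
    # because the server's dev version maps down to 0.1.37.
    print(as_tuple("0.1.37") < as_tuple("0.1.38.dev1"))  # False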
@@ -568,7 +625,7 @@ class Coop:

         >>> job = Jobs.example()
         >>> coop.remote_inference_create(job=job, description="My job")
-        {'uuid': '9f8484ee-b407-40e4-9652-4133a7236c9c', 'description': 'My job', 'status': 'queued', 'visibility': 'unlisted', 'version': '0.1.29.dev4'}
+        {'uuid': '9f8484ee-b407-40e4-9652-4133a7236c9c', 'description': 'My job', 'status': 'queued', 'iterations': None, 'visibility': 'unlisted', 'version': '0.1.38.dev1'}
         """
         response = self._send_server_request(
             uri="api/v0/remote-inference",
@@ -609,7 +666,7 @@ class Coop:
         :param results_uuid: The UUID of the results associated with the EDSL job.

         >>> coop.remote_inference_get("9f8484ee-b407-40e4-9652-4133a7236c9c")
-        {'jobs_uuid': '9f8484ee-b407-40e4-9652-4133a7236c9c', 'results_uuid': 'dd708234-31bf-4fe1-8747-6e232625e026', 'results_url': 'https://www.expectedparrot.com/content/dd708234-31bf-4fe1-8747-6e232625e026', 'status': 'completed', 'reason': None, 'price': 16, 'version': '0.1.29.dev4'}
+        {'job_uuid': '9f8484ee-b407-40e4-9652-4133a7236c9c', 'results_uuid': 'dd708234-31bf-4fe1-8747-6e232625e026', 'results_url': 'https://www.expectedparrot.com/content/dd708234-31bf-4fe1-8747-6e232625e026', 'latest_error_report_uuid': None, 'latest_error_report_url': None, 'status': 'completed', 'reason': None, 'credits_consumed': 0.35, 'version': '0.1.38.dev1'}
         """
         if job_uuid is None and results_uuid is None:
             raise ValueError("Either job_uuid or results_uuid must be provided.")
@@ -625,10 +682,28 @@ class Coop:
         )
         self._resolve_server_response(response)
         data = response.json()
+
+        results_uuid = data.get("results_uuid")
+        latest_error_report_uuid = data.get("latest_error_report_uuid")
+
+        if results_uuid is None:
+            results_url = None
+        else:
+            results_url = f"{self.url}/content/{results_uuid}"
+
+        if latest_error_report_uuid is None:
+            latest_error_report_url = None
+        else:
+            latest_error_report_url = (
+                f"{self.url}/home/remote-inference/error/{latest_error_report_uuid}"
+            )
+
         return {
             "job_uuid": data.get("job_uuid"),
-            "results_uuid": data.get("results_uuid"),
-            "results_url": f"{self.url}/content/{data.get('results_uuid')}",
+            "results_uuid": results_uuid,
+            "results_url": results_url,
+            "latest_error_report_uuid": latest_error_report_uuid,
+            "latest_error_report_url": latest_error_report_url,
             "status": data.get("status"),
             "reason": data.get("reason"),
             "credits_consumed": data.get("price"),
@@ -645,7 +720,7 @@ class Coop:

         >>> job = Jobs.example()
         >>> coop.remote_inference_cost(input=job)
-        16
+        {'credits': 0.77, 'usd': 0.0076950000000000005}
         """
         if isinstance(input, Jobs):
             job = input
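The cost estimate is now a dict with `credits` and `usd` keys rather than a bare credit count, so callers that did arithmetic on the old integer need updating. A sketch of consuming the new shape (values taken from the doctest above):

    estimate = {"credits": 0.77, "usd": 0.0076950000000000005}
    print(f"{estimate['credits']:.2f} credits (~${estimate['usd']:.4f} USD)")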
@@ -685,7 +760,7 @@ class Coop:
     async def remote_async_execute_model_call(
         self, model_dict: dict, user_prompt: str, system_prompt: str
     ) -> dict:
-        url = self.url + "/inference/"
+        url = self.api_url + "/inference/"
         # print("Now using url: ", url)
         data = {
             "model_dict": model_dict,
@@ -706,7 +781,7 @@ class Coop:
         ] = "lime_survey",
         email=None,
     ):
-        url = f"{self.url}/api/v0/export_to_{platform}"
+        url = f"{self.api_url}/api/v0/export_to_{platform}"
         if email:
             data = {"json_string": json.dumps({"survey": survey, "email": email})}
         else:
@@ -725,11 +800,15 @@ class Coop:

         from edsl.config import CONFIG

-        if bool(CONFIG.get("EDSL_FETCH_TOKEN_PRICES")):
+        if CONFIG.get("EDSL_FETCH_TOKEN_PRICES") == "True":
             price_fetcher = PriceFetcher()
             return price_fetcher.fetch_prices()
-        else:
+        elif CONFIG.get("EDSL_FETCH_TOKEN_PRICES") == "False":
             return {}
+        else:
+            raise ValueError(
+                "Invalid EDSL_FETCH_TOKEN_PRICES value---should be 'True' or 'False'."
+            )

     def fetch_models(self) -> dict:
         """
@@ -810,15 +889,6 @@ class Coop:
     load_dotenv()


-if __name__ == "__main__":
-    sheet_data = fetch_sheet_data()
-    if sheet_data:
-        print(f"Successfully fetched {len(sheet_data)} rows of data.")
-        print("First row:", sheet_data[0])
-    else:
-        print("Failed to fetch sheet data.")
-
-
 def main():
     """
     A simple example for the coop client
edsl/data/Cache.py CHANGED
@@ -11,7 +11,8 @@ from typing import Optional, Union
 from edsl.Base import Base
 from edsl.data.CacheEntry import CacheEntry
 from edsl.utilities.utilities import dict_hash
-from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
+from edsl.utilities.decorators import remove_edsl_version
+from edsl.exceptions.cache import CacheError


 class Cache(Base):
@@ -26,6 +27,8 @@ class Cache(Base):
     :param method: The method of storage to use for the cache.
     """

+    __documentation__ = "https://docs.expectedparrot.com/en/latest/data.html"
+
     data = {}

     def __init__(
@@ -58,7 +61,7 @@ class Cache(Base):

         self.filename = filename
         if filename and data:
-            raise ValueError("Cannot provide both filename and data")
+            raise CacheError("Cannot provide both filename and data")
         if filename is None and data is None:
             data = {}
         if data is not None:
@@ -76,7 +79,7 @@ class Cache(Base):
             if os.path.exists(filename):
                 self.add_from_sqlite(filename)
             else:
-                raise ValueError("Invalid file extension. Must be .jsonl or .db")
+                raise CacheError("Invalid file extension. Must be .jsonl or .db")

         self._perform_checks()

@@ -116,7 +119,7 @@ class Cache(Base):
         from edsl.data.CacheEntry import CacheEntry

         if any(not isinstance(value, CacheEntry) for value in self.data.values()):
-            raise Exception("Not all values are CacheEntry instances")
+            raise CacheError("Not all values are CacheEntry instances")
         if self.method is not None:
             warnings.warn("Argument `method` is deprecated", DeprecationWarning)

@@ -227,9 +230,9 @@ class Cache(Base):
         for key, value in new_data.items():
             if key in self.data:
                 if value != self.data[key]:
-                    raise Exception("Mismatch in values")
+                    raise CacheError("Mismatch in values")
             if not isinstance(value, CacheEntry):
-                raise Exception(f"Wrong type - the observed type is {type(value)}")
+                raise CacheError(f"Wrong type - the observed type is {type(value)}")

         self.new_entries.update(new_data)
         if write_now:
@@ -338,7 +341,7 @@ class Cache(Base):
         elif filename.endswith(".db"):
             self.write_sqlite_db(filename)
         else:
-            raise ValueError("Invalid file extension. Must be .jsonl or .db")
+            raise CacheError("Invalid file extension. Must be .jsonl or .db")

     def write_jsonl(self, filename: str) -> None:
         """
@@ -396,20 +399,45 @@ class Cache(Base):
     ####################
     def __hash__(self):
         """Return the hash of the Cache."""
-        return dict_hash(self._to_dict())
+        return dict_hash(self.to_dict(add_edsl_version=False))
+
+    def to_dict(self, add_edsl_version=True) -> dict:
+        d = {k: v.to_dict() for k, v in self.data.items()}
+        if add_edsl_version:
+            from edsl import __version__

-    def _to_dict(self) -> dict:
-        return {k: v.to_dict() for k, v in self.data.items()}
+            d["edsl_version"] = __version__
+            d["edsl_class_name"] = "Cache"

-    @add_edsl_version
-    def to_dict(self) -> dict:
-        """Return the Cache as a dictionary."""
-        return self._to_dict()
+        return d
+
+    def _summary(self):
+        return {"EDSL Class": "Cache", "Number of entries": len(self.data)}

     def _repr_html_(self):
-        from edsl.utilities.utilities import data_to_html
+        # from edsl.utilities.utilities import data_to_html
+        # return data_to_html(self.to_dict())
+        footer = f"<a href={self.__documentation__}>(docs)</a>"
+        return str(self.summary(format="html")) + footer
+
+    def table(
+        self,
+        *fields,
+        tablefmt: Optional[str] = None,
+        pretty_labels: Optional[dict] = None,
+    ) -> str:
+        return self.to_dataset().table(
+            *fields, tablefmt=tablefmt, pretty_labels=pretty_labels
+        )
+
+    def select(self, *fields):
+        return self.to_dataset().select(*fields)
+
+    def tree(self, node_list: Optional[list[str]] = None):
+        return self.to_scenario_list().tree(node_list)

-        return data_to_html(self.to_dict())
+    def to_dataset(self):
+        return self.to_scenario_list().to_dataset()

     @classmethod
     @remove_edsl_version
@@ -438,7 +466,7 @@ class Cache(Base):
         Combine two caches.
         """
         if not isinstance(other, Cache):
-            raise ValueError("Can only add two caches together")
+            raise CacheError("Can only add two caches together")
         self.data.update(other.data)
         return self

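The `to_dict`/`_to_dict` pair collapses into one method with an `add_edsl_version` flag; hashing uses the version-free form, so a cache's hash stays stable across EDSL releases. A minimal sketch of the new contract, assuming an empty cache:

    from edsl.data.Cache import Cache

    c = Cache()
    print(c.to_dict(add_edsl_version=False))  # {} -- entries only, used by __hash__
    print(sorted(c.to_dict().keys()))         # ['edsl_class_name', 'edsl_version']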
edsl/data/CacheEntry.py CHANGED
@@ -96,9 +96,14 @@ class CacheEntry:
         """
         Returns an HTML representation of a CacheEntry.
         """
-        from edsl.utilities.utilities import data_to_html
-
-        return data_to_html(self.to_dict())
+        # from edsl.utilities.utilities import data_to_html
+        # return data_to_html(self.to_dict())
+        d = self.to_dict()
+        data = [[k, v] for k, v in d.items()]
+        from tabulate import tabulate
+
+        table = str(tabulate(data, headers=["keys", "values"], tablefmt="html"))
+        return f"<pre>{table}</pre>"

     def keys(self):
         return list(self.to_dict().keys())
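The HTML repr now renders the entry as a two-column table via `tabulate`, imported at call time. The same pattern on a toy dict:

    from tabulate import tabulate

    d = {"model": "gpt-4o", "temperature": 0.5}
    rows = [[k, v] for k, v in d.items()]
    print(f"<pre>{tabulate(rows, headers=['keys', 'values'], tablefmt='html')}</pre>")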
edsl/data/RemoteCacheSync.py CHANGED
@@ -76,22 +76,3 @@ class RemoteCacheSync:
         self._output(
             f"There are {len(self.cache.keys()):,} entries in the local cache."
         )
-
-
-# # Usage example
-# def run_job(self, n, progress_bar, cache, stop_on_exception, sidecar_model, print_exceptions, raise_validation_errors, use_remote_cache=True):
-#     with RemoteCacheSync(self.coop, cache, self._output, remote_cache=use_remote_cache):
-#         self._output("Running job...")
-#         results = self._run_local(
-#             n=n,
-#             progress_bar=progress_bar,
-#             cache=cache,
-#             stop_on_exception=stop_on_exception,
-#             sidecar_model=sidecar_model,
-#             print_exceptions=print_exceptions,
-#             raise_validation_errors=raise_validation_errors,
-#         )
-#         self._output("Job completed!")
-
-#         results.cache = cache.new_entries_cache()
-#         return results
edsl/enums.py CHANGED
@@ -64,6 +64,7 @@ class InferenceServiceType(EnumWithChecks):
     OLLAMA = "ollama"
     MISTRAL = "mistral"
     TOGETHER = "together"
+    PERPLEXITY = "perplexity"


 service_to_api_keyname = {
@@ -78,6 +79,7 @@ service_to_api_keyname = {
     InferenceServiceType.BEDROCK.value: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"],
     InferenceServiceType.MISTRAL.value: "MISTRAL_API_KEY",
     InferenceServiceType.TOGETHER.value: "TOGETHER_API_KEY",
+    InferenceServiceType.PERPLEXITY.value: "PERPLEXITY_API_KEY",
 }

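With both halves of this change in place, the API-key lookup for the new service resolves as expected:

    from edsl.enums import InferenceServiceType, service_to_api_keyname

    print(service_to_api_keyname[InferenceServiceType.PERPLEXITY.value])
    # PERPLEXITY_API_KEY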
edsl/exceptions/agents.py CHANGED
@@ -1,6 +1,10 @@
 from edsl.exceptions.BaseException import BaseException


+class AgentListError(BaseException):
+    relevant_doc = "https://docs.expectedparrot.com/en/latest/agents.html#agent-lists"
+
+
 class AgentErrors(BaseException):
     relevant_doc = "https://docs.expectedparrot.com/en/latest/agents.html"

edsl/exceptions/cache.py ADDED
@@ -0,0 +1,5 @@
+from edsl.exceptions.BaseException import BaseException
+
+
+class CacheError(BaseException):
+    relevant_doc = "https://docs.expectedparrot.com/en/latest/data.html"
edsl/inference_services/GoogleService.py CHANGED
@@ -8,6 +8,7 @@ from google.api_core.exceptions import InvalidArgument
 from edsl.exceptions import MissingAPIKeyError
 from edsl.language_models.LanguageModel import LanguageModel
 from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+from edsl.coop import Coop

 safety_settings = [
     {
@@ -79,22 +80,8 @@ class GoogleService(InferenceServiceABC):
     api_token = None
     model = None

-    @classmethod
-    def initialize(cls):
-        if cls.api_token is None:
-            cls.api_token = os.getenv("GOOGLE_API_KEY")
-            if not cls.api_token:
-                raise MissingAPIKeyError(
-                    "GOOGLE_API_KEY environment variable is not set"
-                )
-            genai.configure(api_key=cls.api_token)
-            cls.generative_model = genai.GenerativeModel(
-                cls._model_, safety_settings=safety_settings
-            )
-
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.initialize()

     def get_generation_config(self) -> GenerationConfig:
         return GenerationConfig(
@@ -116,6 +103,7 @@ class GoogleService(InferenceServiceABC):
         if files_list is None:
             files_list = []

+        genai.configure(api_key=self.api_token)
         if (
             system_prompt is not None
             and system_prompt != ""
@@ -133,7 +121,11 @@ class GoogleService(InferenceServiceABC):
             )
             print("Will add system_prompt to user_prompt")
             user_prompt = f"{system_prompt}\n{user_prompt}"
-
+        else:
+            self.generative_model = genai.GenerativeModel(
+                self._model_,
+                safety_settings=safety_settings,
+            )
         combined_prompt = [user_prompt]
         for file in files_list:
             if "google" not in file.external_locations:
edsl/inference_services/PerplexityService.py ADDED
@@ -0,0 +1,163 @@
+import aiohttp
+import json
+import requests
+from typing import Any, List, Optional
+from edsl.inference_services.rate_limits_cache import rate_limits
+
+# from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
+from edsl.language_models import LanguageModel
+
+from edsl.inference_services.OpenAIService import OpenAIService
+
+
+class PerplexityService(OpenAIService):
+    """Perplexity service class."""
+
+    _inference_service_ = "perplexity"
+    _env_key_name_ = "PERPLEXITY_API_KEY"
+    _base_url_ = "https://api.perplexity.ai"
+    _models_list_cache: List[str] = []
+    # default perplexity parameters
+    _parameters_ = {
+        "temperature": 0.5,
+        "max_tokens": 1000,
+        "top_p": 1,
+        "logprobs": False,
+        "top_logprobs": 3,
+    }
+
+    @classmethod
+    def available(cls) -> List[str]:
+        return [
+            "llama-3.1-sonar-huge-128k-online",
+            "llama-3.1-sonar-large-128k-online",
+            "llama-3.1-sonar-small-128k-online",
+        ]
+
+    @classmethod
+    def create_model(
+        cls, model_name="llama-3.1-sonar-large-128k-online", model_class_name=None
+    ) -> LanguageModel:
+        if model_class_name is None:
+            model_class_name = cls.to_class_name(model_name)
+
+        class LLM(LanguageModel):
+            """
+            Child class of LanguageModel for interacting with Perplexity models
+            """
+
+            key_sequence = cls.key_sequence
+            usage_sequence = cls.usage_sequence
+            input_token_name = cls.input_token_name
+            output_token_name = cls.output_token_name
+
+            _rpm = cls.get_rpm(cls)
+            _tpm = cls.get_tpm(cls)
+
+            _inference_service_ = cls._inference_service_
+            _model_ = model_name
+
+            _parameters_ = {
+                "temperature": 0.5,
+                "max_tokens": 1000,
+                "top_p": 1,
+                "frequency_penalty": 1,
+                "presence_penalty": 0,
+                # "logprobs": False,  # Enabling this returns 'Neither or both of logprobs and top_logprobs must be set.'
+                # "top_logprobs": 3,
+            }
+
+            def sync_client(self):
+                return cls.sync_client()
+
+            def async_client(self):
+                return cls.async_client()
+
+            @classmethod
+            def available(cls) -> list[str]:
+                return cls.sync_client().models.list()
+
+            def get_headers(self) -> dict[str, Any]:
+                client = self.sync_client()
+                response = client.chat.completions.with_raw_response.create(
+                    messages=[
+                        {
+                            "role": "user",
+                            "content": "Say this is a test",
+                        }
+                    ],
+                    model=self.model,
+                )
+                return dict(response.headers)
+
+            def get_rate_limits(self) -> dict[str, Any]:
+                try:
+                    if "openai" in rate_limits:
+                        headers = rate_limits["openai"]
+
+                    else:
+                        headers = self.get_headers()
+
+                except Exception as e:
+                    return {
+                        "rpm": 10_000,
+                        "tpm": 2_000_000,
+                    }
+                else:
+                    return {
+                        "rpm": int(headers["x-ratelimit-limit-requests"]),
+                        "tpm": int(headers["x-ratelimit-limit-tokens"]),
+                    }
+
+            async def async_execute_model_call(
+                self,
+                user_prompt: str,
+                system_prompt: str = "",
+                files_list: Optional[List["Files"]] = None,
+                invigilator: Optional[
+                    "InvigilatorAI"
+                ] = None,  # TBD - can eventually be used for function-calling
+            ) -> dict[str, Any]:
+                """Calls the OpenAI API and returns the API response."""
+                if files_list:
+                    encoded_image = files_list[0].base64_string
+                    content = [{"type": "text", "text": user_prompt}]
+                    content.append(
+                        {
+                            "type": "image_url",
+                            "image_url": {
+                                "url": f"data:image/jpeg;base64,{encoded_image}"
+                            },
+                        }
+                    )
+                else:
+                    content = user_prompt
+                client = self.async_client()
+
+                messages = [
+                    {"role": "system", "content": system_prompt},
+                    {"role": "user", "content": content},
+                ]
+                if system_prompt == "" and self.omit_system_prompt_if_empty:
+                    messages = messages[1:]
+
+                params = {
+                    "model": self.model,
+                    "messages": messages,
+                    "temperature": self.temperature,
+                    "max_tokens": self.max_tokens,
+                    "top_p": self.top_p,
+                    "frequency_penalty": self.frequency_penalty,
+                    "presence_penalty": self.presence_penalty,
+                    # "logprobs": self.logprobs,
+                    # "top_logprobs": self.top_logprobs if self.logprobs else None,
+                }
+                try:
+                    response = await client.chat.completions.create(**params)
+                except Exception as e:
+                    print(e, flush=True)
+                return response.model_dump()
+
+        LLM.__name__ = "LanguageModel"
+
+        return LLM
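A hedged usage sketch for the new service, assuming `PERPLEXITY_API_KEY` is set in the environment (the dynamically built class behaves like any other EDSL model class):

    from edsl.inference_services.PerplexityService import PerplexityService

    LLM = PerplexityService.create_model("llama-3.1-sonar-small-128k-online")
    print(LLM._model_)  # llama-3.1-sonar-small-128k-online
    m = LLM()           # instantiation assumes the API key is available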
edsl/inference_services/registry.py CHANGED
@@ -12,6 +12,7 @@ from edsl.inference_services.AzureAI import AzureAIService
 from edsl.inference_services.OllamaService import OllamaService
 from edsl.inference_services.TestService import TestService
 from edsl.inference_services.TogetherAIService import TogetherAIService
+from edsl.inference_services.PerplexityService import PerplexityService

 try:
     from edsl.inference_services.MistralAIService import MistralAIService
@@ -31,6 +32,7 @@ services = [
     OllamaService,
     TestService,
     TogetherAIService,
+    PerplexityService,
 ]

 if mistral_available: