edsl 0.1.42__py3-none-any.whl → 0.1.44__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (42)
  1. edsl/Base.py +15 -6
  2. edsl/__version__.py +1 -1
  3. edsl/agents/Invigilator.py +1 -1
  4. edsl/agents/PromptConstructor.py +92 -21
  5. edsl/agents/QuestionInstructionPromptBuilder.py +68 -9
  6. edsl/agents/prompt_helpers.py +2 -2
  7. edsl/coop/coop.py +100 -22
  8. edsl/enums.py +3 -1
  9. edsl/exceptions/coop.py +4 -0
  10. edsl/inference_services/AnthropicService.py +2 -0
  11. edsl/inference_services/AvailableModelFetcher.py +4 -1
  12. edsl/inference_services/GoogleService.py +2 -0
  13. edsl/inference_services/GrokService.py +11 -0
  14. edsl/inference_services/InferenceServiceABC.py +1 -0
  15. edsl/inference_services/OpenAIService.py +1 -0
  16. edsl/inference_services/TestService.py +1 -0
  17. edsl/inference_services/registry.py +2 -0
  18. edsl/jobs/Jobs.py +54 -35
  19. edsl/jobs/JobsChecks.py +7 -7
  20. edsl/jobs/JobsPrompts.py +57 -6
  21. edsl/jobs/JobsRemoteInferenceHandler.py +41 -25
  22. edsl/jobs/buckets/BucketCollection.py +30 -0
  23. edsl/jobs/data_structures.py +1 -0
  24. edsl/language_models/LanguageModel.py +5 -2
  25. edsl/language_models/key_management/KeyLookupBuilder.py +47 -20
  26. edsl/language_models/key_management/models.py +10 -4
  27. edsl/language_models/model.py +43 -11
  28. edsl/prompts/Prompt.py +124 -61
  29. edsl/questions/descriptors.py +32 -18
  30. edsl/questions/question_base_gen_mixin.py +1 -0
  31. edsl/results/DatasetExportMixin.py +35 -6
  32. edsl/results/Results.py +180 -1
  33. edsl/results/ResultsGGMixin.py +117 -60
  34. edsl/scenarios/FileStore.py +19 -8
  35. edsl/scenarios/Scenario.py +33 -0
  36. edsl/scenarios/ScenarioList.py +22 -3
  37. edsl/scenarios/ScenarioListPdfMixin.py +9 -3
  38. edsl/surveys/Survey.py +27 -6
  39. {edsl-0.1.42.dist-info → edsl-0.1.44.dist-info}/METADATA +3 -4
  40. {edsl-0.1.42.dist-info → edsl-0.1.44.dist-info}/RECORD +42 -41
  41. {edsl-0.1.42.dist-info → edsl-0.1.44.dist-info}/LICENSE +0 -0
  42. {edsl-0.1.42.dist-info → edsl-0.1.44.dist-info}/WHEEL +0 -0
edsl/Base.py CHANGED
@@ -64,17 +64,26 @@ class PersistenceMixin:
     @classmethod
     def pull(
         cls,
-        uuid: Optional[Union[str, UUID]] = None,
-        url: Optional[str] = None,
-        expected_parrot_url: Optional[str] = None,
+        url_or_uuid: Optional[Union[str, UUID]] = None,
+        #expected_parrot_url: Optional[str] = None,
     ):
-        """Pull the object from coop."""
+        """Pull the object from coop.
+
+        Args:
+            url_or_uuid: Either a UUID string or a URL pointing to the object
+            expected_parrot_url: Optional URL for the Parrot server
+        """
         from edsl.coop import Coop
         from edsl.coop.utils import ObjectRegistry
 
         object_type = ObjectRegistry.get_object_type_by_edsl_class(cls)
-        coop = Coop(url=expected_parrot_url)
-        return coop.get(uuid, url, object_type)
+        coop = Coop()
+
+        # Determine if input is URL or UUID
+        if url_or_uuid and ("http://" in str(url_or_uuid) or "https://" in str(url_or_uuid)):
+            return coop.get(url=url_or_uuid, expected_object_type=object_type)
+        else:
+            return coop.get(uuid=url_or_uuid, expected_object_type=object_type)
 
     @classmethod
     def delete(cls, uuid: Optional[Union[str, UUID]] = None, url: Optional[str] = None):
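The consolidated signature routes anything containing http(s):// to a URL lookup and treats everything else as a UUID. A minimal usage sketch (the UUID and URL below are illustrative placeholders, not real objects):

    from edsl import Survey

    # Pull by UUID (placeholder value)
    survey = Survey.pull("123e4567-e89b-12d3-a456-426614174000")

    # Pull by URL; the http(s):// prefix triggers the url= branch of coop.get
    survey = Survey.pull("https://www.expectedparrot.com/content/123e4567-e89b-12d3-a456-426614174000")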
edsl/__version__.py CHANGED
@@ -1 +1 @@
-__version__ = "0.1.42"
+__version__ = "0.1.44"
edsl/agents/Invigilator.py CHANGED
@@ -156,7 +156,7 @@ class InvigilatorAI(InvigilatorBase):
             self.question.question_options = new_question_options
 
         question_with_validators = self.question.render(
-            self.scenario | prior_answers_dict
+            self.scenario | prior_answers_dict | {'agent':self.agent.traits}
         )
         question_with_validators.use_code = self.question.use_code
     else:
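Merging {'agent': self.agent.traits} into the render context lets question templates reference agent traits directly. A hedged sketch (the mood trait is a hypothetical example, assuming the usual Jinja dotted access into the traits dict):

    from edsl import Agent, QuestionFreeText

    agent = Agent(traits={"mood": "cheerful"})  # hypothetical trait
    q = QuestionFreeText(
        question_name="greeting",
        question_text="In a {{ agent.mood }} tone, say hello.",
    )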
edsl/agents/PromptConstructor.py CHANGED
@@ -1,6 +1,10 @@
 from __future__ import annotations
-from typing import Dict, Any, Optional, Set, Union, TYPE_CHECKING
+from typing import Dict, Any, Optional, Set, Union, TYPE_CHECKING, Literal
 from functools import cached_property
+from multiprocessing import Pool, freeze_support, get_context
+from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
+import time
+import logging
 
 from edsl.prompts.Prompt import Prompt
 
@@ -22,6 +26,7 @@ if TYPE_CHECKING:
     from edsl.questions.QuestionBase import QuestionBase
     from edsl.scenarios.Scenario import Scenario
 
+logger = logging.getLogger(__name__)
 
 class BasePlaceholder:
     """Base class for placeholder values when a question is not yet answered."""
@@ -242,31 +247,97 @@ class PromptConstructor:
             question_name, self.current_answers
         )
 
-    def get_prompts(self) -> Dict[str, Prompt]:
-        """Get both prompts for the LLM call.
-
-        >>> from edsl import QuestionFreeText
-        >>> from edsl.agents.InvigilatorBase import InvigilatorBase
-        >>> q = QuestionFreeText(question_text="How are you today?", question_name="q_new")
-        >>> i = InvigilatorBase.example(question = q)
-        >>> i.get_prompts()
-        {'user_prompt': ..., 'system_prompt': ...}
-        """
-        prompts = self.prompt_plan.get_prompts(
-            agent_instructions=self.agent_instructions_prompt,
-            agent_persona=self.agent_persona_prompt,
-            question_instructions=Prompt(self.question_instructions_prompt),
-            prior_question_memory=self.prior_question_memory_prompt,
-        )
-        if self.question_file_keys:
+    def get_prompts(self, parallel: Literal["thread", "process", None] = None) -> Dict[str, Any]:
+        """Get the prompts for the question."""
+        start = time.time()
+
+        # Build all the components
+        instr_start = time.time()
+        agent_instructions = self.agent_instructions_prompt
+        instr_end = time.time()
+        logger.debug(f"Time taken for agent instructions: {instr_end - instr_start:.4f}s")
+
+        persona_start = time.time()
+        agent_persona = self.agent_persona_prompt
+        persona_end = time.time()
+        logger.debug(f"Time taken for agent persona: {persona_end - persona_start:.4f}s")
+
+        q_instr_start = time.time()
+        question_instructions = self.question_instructions_prompt
+        q_instr_end = time.time()
+        logger.debug(f"Time taken for question instructions: {q_instr_end - q_instr_start:.4f}s")
+
+        memory_start = time.time()
+        prior_question_memory = self.prior_question_memory_prompt
+        memory_end = time.time()
+        logger.debug(f"Time taken for prior question memory: {memory_end - memory_start:.4f}s")
+
+        # Get components dict
+        components = {
+            "agent_instructions": agent_instructions.text,
+            "agent_persona": agent_persona.text,
+            "question_instructions": question_instructions.text,
+            "prior_question_memory": prior_question_memory.text,
+        }
+
+        # Use PromptPlan's get_prompts method
+        plan_start = time.time()
+
+        # Get arranged components first
+        arranged = self.prompt_plan.arrange_components(**components)
+
+        if parallel == "process":
+            ctx = get_context('fork')
+            with ctx.Pool() as pool:
+                results = pool.map(_process_prompt, [
+                    (arranged["user_prompt"], {}),
+                    (arranged["system_prompt"], {})
+                ])
+            prompts = {
+                "user_prompt": results[0],
+                "system_prompt": results[1]
+            }
+
+        elif parallel == "thread":
+            with ThreadPoolExecutor() as executor:
+                user_prompt_list = arranged["user_prompt"]
+                system_prompt_list = arranged["system_prompt"]
+
+                # Process both prompt lists in parallel
+                rendered_user = executor.submit(_process_prompt, (user_prompt_list, {}))
+                rendered_system = executor.submit(_process_prompt, (system_prompt_list, {}))
+
+                prompts = {
+                    "user_prompt": rendered_user.result(),
+                    "system_prompt": rendered_system.result()
+                }
+
+        else:  # sequential processing
+            prompts = self.prompt_plan.get_prompts(**components)
+
+        plan_end = time.time()
+        logger.debug(f"Time taken for prompt processing: {plan_end - plan_start:.4f}s")
+
+        # Handle file keys if present
+        if hasattr(self, 'question_file_keys') and self.question_file_keys:
+            files_start = time.time()
             files_list = []
             for key in self.question_file_keys:
                 files_list.append(self.scenario[key])
             prompts["files_list"] = files_list
+            files_end = time.time()
+            logger.debug(f"Time taken for file key processing: {files_end - files_start:.4f}s")
+
+        end = time.time()
+        logger.debug(f"Total time in get_prompts: {end - start:.4f}s")
         return prompts
 
 
-if __name__ == "__main__":
-    import doctest
+def _process_prompt(args):
+    """Helper function to process a single prompt list with its replacements."""
+    prompt_list, replacements = args
+    return prompt_list.reduce()
+
 
-    doctest.testmod(optionflags=doctest.ELLIPSIS)
+if __name__ == '__main__':
+    freeze_support()
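The new parallel parameter selects how the arranged prompt components are reduced. A usage sketch, assuming an already-constructed PromptConstructor instance named pc:

    prompts = pc.get_prompts()                    # sequential (default)
    prompts = pc.get_prompts(parallel="thread")   # one ThreadPoolExecutor task per prompt
    prompts = pc.get_prompts(parallel="process")  # fork-based Pool

Note that get_context('fork') restricts the process path to fork-capable platforms (i.e. not Windows), and the replacements element of each _process_prompt argument tuple is currently unused.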
edsl/agents/QuestionInstructionPromptBuilder.py CHANGED
@@ -1,5 +1,6 @@
 from typing import Dict, List, Set
 from warnings import warn
+import logging
 from edsl.prompts.Prompt import Prompt
 from edsl.agents.QuestionTemplateReplacementsBuilder import (
@@ -23,12 +24,44 @@ class QuestionInstructionPromptBuilder:
         Returns:
             Prompt: The fully rendered question instructions
         """
+        import time
+
+        start = time.time()
+
+        # Create base prompt
+        base_start = time.time()
         base_prompt = self._create_base_prompt()
+        base_end = time.time()
+        logging.debug(f"Time for base prompt: {base_end - base_start}")
+
+        # Enrich with options
+        enrich_start = time.time()
         enriched_prompt = self._enrich_with_question_options(base_prompt)
+        enrich_end = time.time()
+        logging.debug(f"Time for enriching with options: {enrich_end - enrich_start}")
+
+        # Render prompt
+        render_start = time.time()
         rendered_prompt = self._render_prompt(enriched_prompt)
+        render_end = time.time()
+        logging.debug(f"Time for rendering prompt: {render_end - render_start}")
+
+        # Validate template variables
+        validate_start = time.time()
         self._validate_template_variables(rendered_prompt)
-
-        return self._append_survey_instructions(rendered_prompt)
+        validate_end = time.time()
+        logging.debug(f"Time for template validation: {validate_end - validate_start}")
+
+        # Append survey instructions
+        append_start = time.time()
+        final_prompt = self._append_survey_instructions(rendered_prompt)
+        append_end = time.time()
+        logging.debug(f"Time for appending survey instructions: {append_end - append_start}")
+
+        end = time.time()
+        logging.debug(f"Total time in build_question_instructions: {end - start}")
+
+        return final_prompt
 
     def _create_base_prompt(self) -> Dict:
         """Creates the initial prompt with basic question data.
@@ -50,14 +83,25 @@ class QuestionInstructionPromptBuilder:
         Returns:
             Dict: Enriched prompt data
         """
+        import time
+
+        start = time.time()
+
         if "question_options" in prompt_data["data"]:
             from edsl.agents.question_option_processor import QuestionOptionProcessor
-
+
+            processor_start = time.time()
             question_options = QuestionOptionProcessor(
                 self.prompt_constructor
             ).get_question_options(question_data=prompt_data["data"])
-
+            processor_end = time.time()
+            logging.debug(f"Time to process question options: {processor_end - processor_start}")
+
             prompt_data["data"]["question_options"] = question_options
+
+        end = time.time()
+        logging.debug(f"Total time in _enrich_with_question_options: {end - start}")
+
         return prompt_data
 
     def _render_prompt(self, prompt_data: Dict) -> Prompt:
@@ -69,11 +113,28 @@
         Returns:
             Prompt: Rendered instructions
         """
-
+        import time
+
+        start = time.time()
+
+        # Build replacement dict
+        dict_start = time.time()
         replacement_dict = QTRB(self.prompt_constructor).build_replacement_dict(
             prompt_data["data"]
         )
-        return prompt_data["prompt"].render(replacement_dict)
+        dict_end = time.time()
+        logging.debug(f"Time to build replacement dict: {dict_end - dict_start}")
+
+        # Render with dict
+        render_start = time.time()
+        result = prompt_data["prompt"].render(replacement_dict)
+        render_end = time.time()
+        logging.debug(f"Time to render with dict: {render_end - render_start}")
+
+        end = time.time()
+        logging.debug(f"Total time in _render_prompt: {end - start}")
+
+        return result
 
     def _validate_template_variables(self, rendered_prompt: Prompt) -> None:
         """Validates that all template variables have been properly replaced.
@@ -101,9 +162,7 @@
         """
         for question_name in self.survey.question_names:
             if question_name in undefined_vars:
-                print(
-                    f"Question name found in undefined_template_variables: {question_name}"
-                )
+                logging.warning(f"Question name found in undefined_template_variables: {question_name}")
 
     def _append_survey_instructions(self, rendered_prompt: Prompt) -> Prompt:
         """Appends any relevant survey instructions to the rendered prompt.
edsl/agents/prompt_helpers.py CHANGED
@@ -124,6 +124,6 @@ class PromptPlan:
         """Get both prompts for the LLM call."""
         prompts = self.arrange_components(**kwargs)
         return {
-            "user_prompt": prompts["user_prompt"].reduce(),
-            "system_prompt": prompts["system_prompt"].reduce(),
+            "user_prompt": Prompt("".join(str(p) for p in prompts["user_prompt"])),
+            "system_prompt": Prompt("".join(str(p) for p in prompts["system_prompt"])),
         }
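Design note: the reduce() calls move out of PromptPlan.get_prompts (the parallel paths in PromptConstructor now reduce via _process_prompt), and the sequential path instead joins the arranged components into fresh Prompt objects by string concatenation. A rough sketch of the new behavior, with illustrative components:

    parts = [Prompt("You are an agent. "), Prompt("Answer the question.")]  # illustrative
    user_prompt = Prompt("".join(str(p) for p in parts))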
edsl/coop/coop.py CHANGED
@@ -4,17 +4,19 @@ import requests
 
 from typing import Any, Optional, Union, Literal, TypedDict
 from uuid import UUID
-from collections import UserDict, defaultdict
 
 import edsl
-from pathlib import Path
 
 from edsl.config import CONFIG
 from edsl.data.CacheEntry import CacheEntry
 from edsl.jobs.Jobs import Jobs
 from edsl.surveys.Survey import Survey
 
-from edsl.exceptions.coop import CoopNoUUIDError, CoopServerResponseError
+from edsl.exceptions.coop import (
+    CoopInvalidURLError,
+    CoopNoUUIDError,
+    CoopServerResponseError,
+)
 from edsl.coop.utils import (
     EDSLObject,
     ObjectRegistry,
@@ -285,17 +287,46 @@ class Coop(CoopFunctionsMixin):
         if value is None:
             return "null"
 
-    def _resolve_uuid(
+    def _resolve_uuid_or_alias(
         self, uuid: Union[str, UUID] = None, url: str = None
-    ) -> Union[str, UUID]:
+    ) -> tuple[Optional[str], Optional[str], Optional[str]]:
         """
-        Resolve the uuid from a uuid or a url.
+        Resolve the uuid or alias information from a uuid or a url.
+        Returns a tuple of (uuid, owner_username, alias)
+        - For content/<uuid> URLs: returns (uuid, None, None)
+        - For content/<username>/<alias> URLs: returns (None, username, alias)
         """
         if not url and not uuid:
             raise CoopNoUUIDError("No uuid or url provided for the object.")
+
         if not uuid and url:
-            uuid = url.split("/")[-1]
-        return uuid
+            parts = (
+                url.replace("http://", "")
+                .replace("https://", "")
+                .rstrip("/")
+                .split("/")
+            )
+
+            # Remove domain
+            parts = parts[1:]
+
+            if len(parts) < 2 or parts[0] != "content":
+                raise CoopInvalidURLError(
+                    f"Invalid URL format. The URL must end with /content/<uuid> or /content/<username>/<alias>: {url}"
+                )
+
+            if len(parts) == 2:
+                obj_uuid = parts[1]
+                return obj_uuid, None, None
+            elif len(parts) == 3:
+                username, alias = parts[1], parts[2]
+                return None, username, alias
+            else:
+                raise CoopInvalidURLError(
+                    f"Invalid URL format. The URL must end with /content/<uuid> or /content/<username>/<alias>: {url}"
+                )
+
+        return str(uuid), None, None
 
     @property
     def edsl_settings(self) -> dict:
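A sketch of how the two supported URL shapes resolve (the domain, username, and alias are placeholders):

    coop = Coop()
    coop._resolve_uuid_or_alias(url="https://www.expectedparrot.com/content/123e4567-e89b-12d3-a456-426614174000")
    # -> ("123e4567-e89b-12d3-a456-426614174000", None, None)
    coop._resolve_uuid_or_alias(url="https://www.expectedparrot.com/content/john_smith/my-survey")
    # -> (None, "john_smith", "my-survey")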
@@ -307,7 +338,7 @@
 
         try:
             response = self._send_server_request(
-                uri="api/v0/edsl-settings", method="GET", timeout=5
+                uri="api/v0/edsl-settings", method="GET", timeout=20
             )
             self._resolve_server_response(response, check_api_key=False)
             return response.json()
@@ -361,22 +392,31 @@ class Coop(CoopFunctionsMixin):
         expected_object_type: Optional[ObjectType] = None,
     ) -> EDSLObject:
         """
-        Retrieve an EDSL object by its uuid or its url.
+        Retrieve an EDSL object by its uuid/url or by owner username and alias.
         - If the object's visibility is private, the user must be the owner.
         - Optionally, check if the retrieved object is of a certain type.
 
         :param uuid: the uuid of the object either in str or UUID format.
-        :param url: the url of the object.
+        :param url: the url of the object (can be content/uuid or content/username/alias format).
         :param expected_object_type: the expected type of the object.
 
         :return: the object instance.
         """
-        uuid = self._resolve_uuid(uuid, url)
-        response = self._send_server_request(
-            uri=f"api/v0/object",
-            method="GET",
-            params={"uuid": uuid},
-        )
+        obj_uuid, owner_username, alias = self._resolve_uuid_or_alias(uuid, url)
+
+        if obj_uuid:
+            response = self._send_server_request(
+                uri=f"api/v0/object",
+                method="GET",
+                params={"uuid": obj_uuid},
+            )
+        else:
+            response = self._send_server_request(
+                uri=f"api/v0/object/alias",
+                method="GET",
+                params={"owner_username": owner_username, "alias": alias},
+            )
+
         self._resolve_server_response(response)
         json_string = response.json().get("json_string")
         object_type = response.json().get("object_type")
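With alias resolution wired in, get can fetch the same object either way; a hedged sketch (username and alias are placeholders):

    coop = Coop()
    obj = coop.get(uuid="123e4567-e89b-12d3-a456-426614174000")
    obj = coop.get(url="https://www.expectedparrot.com/content/john_smith/my-survey")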
@@ -414,12 +454,13 @@ class Coop(CoopFunctionsMixin):
         """
         Delete an object from the server.
         """
-        uuid = self._resolve_uuid(uuid, url)
+        obj_uuid, _, _ = self._resolve_uuid_or_alias(uuid, url)
         response = self._send_server_request(
             uri=f"api/v0/object",
             method="DELETE",
-            params={"uuid": uuid},
+            params={"uuid": obj_uuid},
         )
+
         self._resolve_server_response(response)
         return response.json()
 
@@ -438,11 +479,11 @@ class Coop(CoopFunctionsMixin):
         """
         if description is None and visibility is None and value is None:
             raise Exception("Nothing to patch.")
-        uuid = self._resolve_uuid(uuid, url)
+        obj_uuid, _, _ = self._resolve_uuid_or_alias(uuid, url)
         response = self._send_server_request(
             uri=f"api/v0/object",
             method="PATCH",
-            params={"uuid": uuid},
+            params={"uuid": obj_uuid},
             payload={
                 "description": description,
                 "alias": alias,
@@ -549,6 +590,7 @@ class Coop(CoopFunctionsMixin):
     def remote_cache_get(
         self,
         exclude_keys: Optional[list[str]] = None,
+        select_keys: Optional[list[str]] = None,
     ) -> list[CacheEntry]:
         """
         Get all remote cache entries.
@@ -560,10 +602,12 @@ class Coop(CoopFunctionsMixin):
         """
         if exclude_keys is None:
             exclude_keys = []
+        if select_keys is None:
+            select_keys = []
         response = self._send_server_request(
             uri="api/v0/remote-cache/get-many",
             method="POST",
-            payload={"keys": exclude_keys},
+            payload={"keys": exclude_keys, "selected_keys": select_keys},
             timeout=40,
         )
         self._resolve_server_response(response)
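A usage sketch of the new select_keys filter (the cache keys are placeholders); note the server-side payload field is named selected_keys:

    coop = Coop()
    entries = coop.remote_cache_get(select_keys=["abc123", "def456"])  # fetch only these entries
    entries = coop.remote_cache_get(exclude_keys=["abc123"])           # fetch everything except these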
@@ -820,6 +864,40 @@ class Coop(CoopFunctionsMixin):
             "usd": response_json.get("cost_in_usd"),
         }
 
+    ################
+    # PROJECTS
+    ################
+    def create_project(
+        self,
+        survey: Survey,
+        project_name: str,
+        survey_description: Optional[str] = None,
+        survey_alias: Optional[str] = None,
+        survey_visibility: Optional[VisibilityType] = "unlisted",
+    ):
+        """
+        Create a survey object on Coop, then create a project from the survey.
+        """
+        survey_details = self.create(
+            object=survey,
+            description=survey_description,
+            alias=survey_alias,
+            visibility=survey_visibility,
+        )
+        survey_uuid = survey_details.get("uuid")
+        response = self._send_server_request(
+            uri=f"api/v0/projects/create-from-survey",
+            method="POST",
+            payload={"project_name": project_name, "survey_uuid": str(survey_uuid)},
+        )
+        self._resolve_server_response(response)
+        response_json = response.json()
+        return {
+            "name": response_json.get("project_name"),
+            "uuid": response_json.get("uuid"),
+            "url": f"{self.url}/home/projects/{response_json.get('uuid')}",
+        }
+
     ################
     # DUNDER METHODS
     ################
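A usage sketch for the new projects endpoint (the project name and description are placeholders, and Survey.example() is assumed to be available as elsewhere in edsl):

    from edsl import Survey

    coop = Coop()
    project = coop.create_project(
        survey=Survey.example(),
        project_name="My project",
        survey_description="An example survey",
    )
    print(project["url"])  # f"{coop.url}/home/projects/<uuid>"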
edsl/enums.py CHANGED
@@ -67,6 +67,7 @@ class InferenceServiceType(EnumWithChecks):
     TOGETHER = "together"
     PERPLEXITY = "perplexity"
     DEEPSEEK = "deepseek"
+    GROK = "grok"
 
 
 # unavoidable violation of the DRY principle but it is necessary
@@ -86,6 +87,7 @@ InferenceServiceLiteral = Literal[
     "together",
     "perplexity",
     "deepseek",
+    "grok",
 ]
 
 available_models_urls = {
@@ -97,7 +99,6 @@ available_models_urls = {
 
 
 service_to_api_keyname = {
-    InferenceServiceType.BEDROCK.value: "TBD",
     InferenceServiceType.DEEP_INFRA.value: "DEEP_INFRA_API_KEY",
     InferenceServiceType.REPLICATE.value: "TBD",
     InferenceServiceType.OPENAI.value: "OPENAI_API_KEY",
@@ -110,6 +111,7 @@ service_to_api_keyname = {
     InferenceServiceType.TOGETHER.value: "TOGETHER_API_KEY",
     InferenceServiceType.PERPLEXITY.value: "PERPLEXITY_API_KEY",
     InferenceServiceType.DEEPSEEK.value: "DEEPSEEK_API_KEY",
+    InferenceServiceType.GROK.value: "XAI_API_KEY",
 }
 
edsl/exceptions/coop.py CHANGED
@@ -2,6 +2,10 @@ class CoopErrors(Exception):
     pass
 
 
+class CoopInvalidURLError(CoopErrors):
+    pass
+
+
 class CoopNoUUIDError(CoopErrors):
     pass
 
edsl/inference_services/AnthropicService.py CHANGED
@@ -17,6 +17,8 @@ class AnthropicService(InferenceServiceABC):
     output_token_name = "output_tokens"
     model_exclude_list = []
 
+    available_models_url = 'https://docs.anthropic.com/en/docs/about-claude/models'
+
     @classmethod
     def get_model_list(cls, api_key: str = None):
 
edsl/inference_services/AvailableModelFetcher.py CHANGED
@@ -136,7 +136,10 @@ class AvailableModelFetcher:
         if not service_models:
             import warnings
 
-            warnings.warn(f"No models found for service {service_name}")
+            with warnings.catch_warnings():
+                warnings.simplefilter("ignore")  # Ignores the warning
+                warnings.warn(f"No models found for service {service_name}")
+
             return [], service_name
 
         models_list = AvailableModels(
edsl/inference_services/GoogleService.py CHANGED
@@ -39,6 +39,8 @@ class GoogleService(InferenceServiceABC):
 
     model_exclude_list = []
 
+    available_models_url = 'https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models'
+
     @classmethod
     def get_model_list(cls):
         model_list = []
edsl/inference_services/GrokService.py ADDED
@@ -0,0 +1,11 @@
+from typing import Any, List
+from edsl.inference_services.OpenAIService import OpenAIService
+
+
+class GrokService(OpenAIService):
+    """Openai service class."""
+
+    _inference_service_ = "grok"
+    _env_key_name_ = "XAI_API_KEY"
+    _base_url_ = "https://api.x.ai/v1"
+    _models_list_cache: List[str] = []
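Because GrokService subclasses OpenAIService, it reuses the OpenAI-compatible client pointed at the x.ai base URL. A hedged usage sketch (the model name is illustrative, and assumes XAI_API_KEY is set in the environment):

    import os
    from edsl import Model

    os.environ["XAI_API_KEY"] = "your-xai-key"  # placeholder key
    m = Model("grok-beta")  # hypothetical model name, resolved to the grok service via the registry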
edsl/inference_services/InferenceServiceABC.py CHANGED
@@ -23,6 +23,7 @@ class InferenceServiceABC(ABC):
         "usage_sequence",
         "input_token_name",
         "output_token_name",
+        #"available_models_url",
     ]
     for attr in must_have_attributes:
         if not hasattr(cls, attr):
edsl/inference_services/OpenAIService.py CHANGED
@@ -84,6 +84,7 @@ class OpenAIService(InferenceServiceABC):
 
     @classmethod
     def get_model_list(cls, api_key=None):
+        # breakpoint()
         if api_key is None:
             api_key = os.getenv(cls._env_key_name_)
         raw_list = cls.sync_client(api_key).models.list()
edsl/inference_services/TestService.py CHANGED
@@ -28,6 +28,7 @@ class TestService(InferenceServiceABC):
     model_exclude_list = []
     input_token_name = "prompt_tokens"
     output_token_name = "completion_tokens"
+    available_models_url = None
 
     @classmethod
     def available(cls) -> list[str]:
edsl/inference_services/registry.py CHANGED
@@ -14,6 +14,7 @@ from edsl.inference_services.TestService import TestService
 from edsl.inference_services.TogetherAIService import TogetherAIService
 from edsl.inference_services.PerplexityService import PerplexityService
 from edsl.inference_services.DeepSeekService import DeepSeekService
+from edsl.inference_services.GrokService import GrokService
 
 try:
     from edsl.inference_services.MistralAIService import MistralAIService
@@ -35,6 +36,7 @@ services = [
     TogetherAIService,
     PerplexityService,
     DeepSeekService,
+    GrokService,
 ]
 
 if mistral_available: