edsl 0.1.38__py3-none-any.whl → 0.1.38.dev1__py3-none-any.whl

This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
Files changed (86) hide show
  1. edsl/Base.py +34 -63
  2. edsl/BaseDiff.py +7 -7
  3. edsl/__init__.py +1 -2
  4. edsl/__version__.py +1 -1
  5. edsl/agents/Agent.py +11 -23
  6. edsl/agents/AgentList.py +23 -86
  7. edsl/agents/Invigilator.py +7 -18
  8. edsl/agents/InvigilatorBase.py +19 -0
  9. edsl/agents/PromptConstructor.py +4 -5
  10. edsl/auto/SurveyCreatorPipeline.py +1 -1
  11. edsl/auto/utilities.py +1 -1
  12. edsl/base/Base.py +13 -3
  13. edsl/config.py +0 -8
  14. edsl/conjure/AgentConstructionMixin.py +160 -0
  15. edsl/conjure/Conjure.py +62 -0
  16. edsl/conjure/InputData.py +659 -0
  17. edsl/conjure/InputDataCSV.py +48 -0
  18. edsl/conjure/InputDataMixinQuestionStats.py +182 -0
  19. edsl/conjure/InputDataPyRead.py +91 -0
  20. edsl/conjure/InputDataSPSS.py +8 -0
  21. edsl/conjure/InputDataStata.py +8 -0
  22. edsl/conjure/QuestionOptionMixin.py +76 -0
  23. edsl/conjure/QuestionTypeMixin.py +23 -0
  24. edsl/conjure/RawQuestion.py +65 -0
  25. edsl/conjure/SurveyResponses.py +7 -0
  26. edsl/conjure/__init__.py +9 -0
  27. edsl/conjure/examples/placeholder.txt +0 -0
  28. edsl/{utilities → conjure}/naming_utilities.py +1 -1
  29. edsl/conjure/utilities.py +201 -0
  30. edsl/coop/coop.py +7 -77
  31. edsl/data/Cache.py +17 -45
  32. edsl/data/CacheEntry.py +3 -8
  33. edsl/data/RemoteCacheSync.py +19 -0
  34. edsl/enums.py +0 -2
  35. edsl/exceptions/agents.py +0 -4
  36. edsl/inference_services/GoogleService.py +15 -7
  37. edsl/inference_services/registry.py +0 -2
  38. edsl/jobs/Jobs.py +559 -110
  39. edsl/jobs/buckets/TokenBucket.py +0 -3
  40. edsl/jobs/interviews/Interview.py +7 -7
  41. edsl/jobs/runners/JobsRunnerAsyncio.py +28 -156
  42. edsl/jobs/runners/JobsRunnerStatus.py +196 -194
  43. edsl/jobs/tasks/TaskHistory.py +19 -27
  44. edsl/language_models/LanguageModel.py +90 -52
  45. edsl/language_models/ModelList.py +14 -67
  46. edsl/language_models/registry.py +4 -57
  47. edsl/notebooks/Notebook.py +8 -7
  48. edsl/prompts/Prompt.py +3 -8
  49. edsl/questions/QuestionBase.py +30 -38
  50. edsl/questions/QuestionBaseGenMixin.py +1 -1
  51. edsl/questions/QuestionBasePromptsMixin.py +17 -0
  52. edsl/questions/QuestionExtract.py +4 -3
  53. edsl/questions/QuestionFunctional.py +3 -10
  54. edsl/questions/derived/QuestionTopK.py +0 -2
  55. edsl/questions/question_registry.py +6 -36
  56. edsl/results/Dataset.py +15 -146
  57. edsl/results/DatasetExportMixin.py +217 -231
  58. edsl/results/DatasetTree.py +4 -134
  59. edsl/results/Result.py +16 -31
  60. edsl/results/Results.py +65 -159
  61. edsl/scenarios/FileStore.py +13 -187
  62. edsl/scenarios/Scenario.py +18 -73
  63. edsl/scenarios/ScenarioList.py +76 -251
  64. edsl/surveys/MemoryPlan.py +1 -1
  65. edsl/surveys/Rule.py +5 -1
  66. edsl/surveys/RuleCollection.py +1 -1
  67. edsl/surveys/Survey.py +19 -25
  68. edsl/surveys/SurveyFlowVisualizationMixin.py +9 -67
  69. edsl/surveys/instructions/ChangeInstruction.py +7 -9
  70. edsl/surveys/instructions/Instruction.py +7 -21
  71. edsl/templates/error_reporting/interview_details.html +3 -3
  72. edsl/templates/error_reporting/interviews.html +9 -18
  73. edsl/utilities/utilities.py +0 -15
  74. {edsl-0.1.38.dist-info → edsl-0.1.38.dev1.dist-info}/METADATA +1 -2
  75. {edsl-0.1.38.dist-info → edsl-0.1.38.dev1.dist-info}/RECORD +77 -71
  76. edsl/exceptions/cache.py +0 -5
  77. edsl/inference_services/PerplexityService.py +0 -163
  78. edsl/jobs/JobsChecks.py +0 -147
  79. edsl/jobs/JobsPrompts.py +0 -268
  80. edsl/jobs/JobsRemoteInferenceHandler.py +0 -239
  81. edsl/results/CSSParameterizer.py +0 -108
  82. edsl/results/TableDisplay.py +0 -198
  83. edsl/results/table_display.css +0 -78
  84. edsl/scenarios/ScenarioJoin.py +0 -127
  85. {edsl-0.1.38.dist-info → edsl-0.1.38.dev1.dist-info}/LICENSE +0 -0
  86. {edsl-0.1.38.dist-info → edsl-0.1.38.dev1.dist-info}/WHEEL +0 -0
edsl/coop/coop.py CHANGED
@@ -42,9 +42,6 @@ class Coop:
42
42
  self.api_url = self.url
43
43
  self._edsl_version = edsl.__version__
44
44
 
45
- def get_progress_bar_url(self):
46
- return f"{CONFIG.EXPECTED_PARROT_URL}"
47
-
48
45
  ################
49
46
  # BASIC METHODS
50
47
  ################
@@ -102,57 +99,12 @@ class Coop:
102
99
 
103
100
  return response
104
101
 
105
- def _get_latest_stable_version(self, version: str) -> str:
106
- """
107
- Extract the latest stable PyPI version from a version string.
108
-
109
- Examples:
110
- - Decrement the patch number of a dev version: "0.1.38.dev1" -> "0.1.37"
111
- - Return a stable version as is: "0.1.37" -> "0.1.37"
112
- """
113
- if "dev" not in version:
114
- return version
115
- else:
116
- # For 0.1.38.dev1, split into ["0", "1", "38", "dev1"]
117
- major, minor, patch = version.split(".")[:3]
118
-
119
- current_patch = int(patch)
120
- latest_patch = current_patch - 1
121
- return f"{major}.{minor}.{latest_patch}"
122
-
123
- def _user_version_is_outdated(
124
- self, user_version_str: str, server_version_str: str
125
- ) -> bool:
126
- """
127
- Check if the user's EDSL version is outdated compared to the server's.
128
- """
129
- server_stable_version_str = self._get_latest_stable_version(server_version_str)
130
- user_stable_version_str = self._get_latest_stable_version(user_version_str)
131
-
132
- # Turn the version strings into tuples of ints for comparison
133
- user_stable_version = tuple(map(int, user_stable_version_str.split(".")))
134
- server_stable_version = tuple(map(int, server_stable_version_str.split(".")))
135
-
136
- return user_stable_version < server_stable_version
137
-
138
102
  def _resolve_server_response(
139
103
  self, response: requests.Response, check_api_key: bool = True
140
104
  ) -> None:
141
105
  """
142
106
  Check the response from the server and raise errors as appropriate.
143
107
  """
144
- # Get EDSL version from header
145
- server_edsl_version = response.headers.get("X-EDSL-Version")
146
-
147
- if server_edsl_version:
148
- if self._user_version_is_outdated(
149
- user_version_str=self._edsl_version,
150
- server_version_str=server_edsl_version,
151
- ):
152
- print(
153
- "Please upgrade your EDSL version to access our latest features. To upgrade, open your terminal and run `pip upgrade edsl`"
154
- )
155
-
156
108
  if response.status_code >= 400:
157
109
  message = response.json().get("detail")
158
110
  # print(response.text)
@@ -625,7 +577,7 @@ class Coop:
625
577
 
626
578
  >>> job = Jobs.example()
627
579
  >>> coop.remote_inference_create(job=job, description="My job")
628
- {'uuid': '9f8484ee-b407-40e4-9652-4133a7236c9c', 'description': 'My job', 'status': 'queued', 'iterations': None, 'visibility': 'unlisted', 'version': '0.1.38.dev1'}
580
+ {'uuid': '9f8484ee-b407-40e4-9652-4133a7236c9c', 'description': 'My job', 'status': 'queued', 'visibility': 'unlisted', 'version': '0.1.29.dev4'}
629
581
  """
630
582
  response = self._send_server_request(
631
583
  uri="api/v0/remote-inference",
@@ -666,7 +618,7 @@ class Coop:
666
618
  :param results_uuid: The UUID of the results associated with the EDSL job.
667
619
 
668
620
  >>> coop.remote_inference_get("9f8484ee-b407-40e4-9652-4133a7236c9c")
669
- {'job_uuid': '9f8484ee-b407-40e4-9652-4133a7236c9c', 'results_uuid': 'dd708234-31bf-4fe1-8747-6e232625e026', 'results_url': 'https://www.expectedparrot.com/content/dd708234-31bf-4fe1-8747-6e232625e026', 'latest_error_report_uuid': None, 'latest_error_report_url': None, 'status': 'completed', 'reason': None, 'credits_consumed': 0.35, 'version': '0.1.38.dev1'}
621
+ {'jobs_uuid': '9f8484ee-b407-40e4-9652-4133a7236c9c', 'results_uuid': 'dd708234-31bf-4fe1-8747-6e232625e026', 'results_url': 'https://www.expectedparrot.com/content/dd708234-31bf-4fe1-8747-6e232625e026', 'status': 'completed', 'reason': None, 'price': 16, 'version': '0.1.29.dev4'}
670
622
  """
671
623
  if job_uuid is None and results_uuid is None:
672
624
  raise ValueError("Either job_uuid or results_uuid must be provided.")
@@ -682,28 +634,10 @@ class Coop:
682
634
  )
683
635
  self._resolve_server_response(response)
684
636
  data = response.json()
685
-
686
- results_uuid = data.get("results_uuid")
687
- latest_error_report_uuid = data.get("latest_error_report_uuid")
688
-
689
- if results_uuid is None:
690
- results_url = None
691
- else:
692
- results_url = f"{self.url}/content/{results_uuid}"
693
-
694
- if latest_error_report_uuid is None:
695
- latest_error_report_url = None
696
- else:
697
- latest_error_report_url = (
698
- f"{self.url}/home/remote-inference/error/{latest_error_report_uuid}"
699
- )
700
-
701
637
  return {
702
638
  "job_uuid": data.get("job_uuid"),
703
- "results_uuid": results_uuid,
704
- "results_url": results_url,
705
- "latest_error_report_uuid": latest_error_report_uuid,
706
- "latest_error_report_url": latest_error_report_url,
639
+ "results_uuid": data.get("results_uuid"),
640
+ "results_url": f"{self.url}/content/{data.get('results_uuid')}",
707
641
  "status": data.get("status"),
708
642
  "reason": data.get("reason"),
709
643
  "credits_consumed": data.get("price"),
@@ -720,7 +654,7 @@ class Coop:
720
654
 
721
655
  >>> job = Jobs.example()
722
656
  >>> coop.remote_inference_cost(input=job)
723
- {'credits': 0.77, 'usd': 0.0076950000000000005}
657
+ 16
724
658
  """
725
659
  if isinstance(input, Jobs):
726
660
  job = input
@@ -800,15 +734,11 @@ class Coop:
800
734
 
801
735
  from edsl.config import CONFIG
802
736
 
803
- if CONFIG.get("EDSL_FETCH_TOKEN_PRICES") == "True":
737
+ if bool(CONFIG.get("EDSL_FETCH_TOKEN_PRICES")):
804
738
  price_fetcher = PriceFetcher()
805
739
  return price_fetcher.fetch_prices()
806
- elif CONFIG.get("EDSL_FETCH_TOKEN_PRICES") == "False":
807
- return {}
808
740
  else:
809
- raise ValueError(
810
- "Invalid EDSL_FETCH_TOKEN_PRICES value---should be 'True' or 'False'."
811
- )
741
+ return {}
812
742
 
813
743
  def fetch_models(self) -> dict:
814
744
  """
edsl/data/Cache.py CHANGED
@@ -11,8 +11,7 @@ from typing import Optional, Union
11
11
  from edsl.Base import Base
12
12
  from edsl.data.CacheEntry import CacheEntry
13
13
  from edsl.utilities.utilities import dict_hash
14
- from edsl.utilities.decorators import remove_edsl_version
15
- from edsl.exceptions.cache import CacheError
14
+ from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
16
15
 
17
16
 
18
17
  class Cache(Base):
@@ -27,8 +26,6 @@ class Cache(Base):
27
26
  :param method: The method of storage to use for the cache.
28
27
  """
29
28
 
30
- __documentation__ = "https://docs.expectedparrot.com/en/latest/data.html"
31
-
32
29
  data = {}
33
30
 
34
31
  def __init__(
@@ -61,7 +58,7 @@ class Cache(Base):
61
58
 
62
59
  self.filename = filename
63
60
  if filename and data:
64
- raise CacheError("Cannot provide both filename and data")
61
+ raise ValueError("Cannot provide both filename and data")
65
62
  if filename is None and data is None:
66
63
  data = {}
67
64
  if data is not None:
@@ -79,7 +76,7 @@ class Cache(Base):
79
76
  if os.path.exists(filename):
80
77
  self.add_from_sqlite(filename)
81
78
  else:
82
- raise CacheError("Invalid file extension. Must be .jsonl or .db")
79
+ raise ValueError("Invalid file extension. Must be .jsonl or .db")
83
80
 
84
81
  self._perform_checks()
85
82
 
@@ -119,7 +116,7 @@ class Cache(Base):
119
116
  from edsl.data.CacheEntry import CacheEntry
120
117
 
121
118
  if any(not isinstance(value, CacheEntry) for value in self.data.values()):
122
- raise CacheError("Not all values are CacheEntry instances")
119
+ raise Exception("Not all values are CacheEntry instances")
123
120
  if self.method is not None:
124
121
  warnings.warn("Argument `method` is deprecated", DeprecationWarning)
125
122
 
@@ -230,9 +227,9 @@ class Cache(Base):
230
227
  for key, value in new_data.items():
231
228
  if key in self.data:
232
229
  if value != self.data[key]:
233
- raise CacheError("Mismatch in values")
230
+ raise Exception("Mismatch in values")
234
231
  if not isinstance(value, CacheEntry):
235
- raise CacheError(f"Wrong type - the observed type is {type(value)}")
232
+ raise Exception(f"Wrong type - the observed type is {type(value)}")
236
233
 
237
234
  self.new_entries.update(new_data)
238
235
  if write_now:
@@ -341,7 +338,7 @@ class Cache(Base):
341
338
  elif filename.endswith(".db"):
342
339
  self.write_sqlite_db(filename)
343
340
  else:
344
- raise CacheError("Invalid file extension. Must be .jsonl or .db")
341
+ raise ValueError("Invalid file extension. Must be .jsonl or .db")
345
342
 
346
343
  def write_jsonl(self, filename: str) -> None:
347
344
  """
@@ -399,45 +396,20 @@ class Cache(Base):
399
396
  ####################
400
397
  def __hash__(self):
401
398
  """Return the hash of the Cache."""
402
- return dict_hash(self.to_dict(add_edsl_version=False))
403
-
404
- def to_dict(self, add_edsl_version=True) -> dict:
405
- d = {k: v.to_dict() for k, v in self.data.items()}
406
- if add_edsl_version:
407
- from edsl import __version__
399
+ return dict_hash(self._to_dict())
408
400
 
409
- d["edsl_version"] = __version__
410
- d["edsl_class_name"] = "Cache"
401
+ def _to_dict(self) -> dict:
402
+ return {k: v.to_dict() for k, v in self.data.items()}
411
403
 
412
- return d
413
-
414
- def _summary(self):
415
- return {"EDSL Class": "Cache", "Number of entries": len(self.data)}
404
+ @add_edsl_version
405
+ def to_dict(self) -> dict:
406
+ """Return the Cache as a dictionary."""
407
+ return self._to_dict()
416
408
 
417
409
  def _repr_html_(self):
418
- # from edsl.utilities.utilities import data_to_html
419
- # return data_to_html(self.to_dict())
420
- footer = f"<a href={self.__documentation__}>(docs)</a>"
421
- return str(self.summary(format="html")) + footer
422
-
423
- def table(
424
- self,
425
- *fields,
426
- tablefmt: Optional[str] = None,
427
- pretty_labels: Optional[dict] = None,
428
- ) -> str:
429
- return self.to_dataset().table(
430
- *fields, tablefmt=tablefmt, pretty_labels=pretty_labels
431
- )
432
-
433
- def select(self, *fields):
434
- return self.to_dataset().select(*fields)
435
-
436
- def tree(self, node_list: Optional[list[str]] = None):
437
- return self.to_scenario_list().tree(node_list)
410
+ from edsl.utilities.utilities import data_to_html
438
411
 
439
- def to_dataset(self):
440
- return self.to_scenario_list().to_dataset()
412
+ return data_to_html(self.to_dict())
441
413
 
442
414
  @classmethod
443
415
  @remove_edsl_version
@@ -466,7 +438,7 @@ class Cache(Base):
466
438
  Combine two caches.
467
439
  """
468
440
  if not isinstance(other, Cache):
469
- raise CacheError("Can only add two caches together")
441
+ raise ValueError("Can only add two caches together")
470
442
  self.data.update(other.data)
471
443
  return self
472
444
 
edsl/data/CacheEntry.py CHANGED
@@ -96,14 +96,9 @@ class CacheEntry:
96
96
  """
97
97
  Returns an HTML representation of a CacheEntry.
98
98
  """
99
- # from edsl.utilities.utilities import data_to_html
100
- # return data_to_html(self.to_dict())
101
- d = self.to_dict()
102
- data = [[k, v] for k, v in d.items()]
103
- from tabulate import tabulate
104
-
105
- table = str(tabulate(data, headers=["keys", "values"], tablefmt="html"))
106
- return f"<pre>{table}</pre>"
99
+ from edsl.utilities.utilities import data_to_html
100
+
101
+ return data_to_html(self.to_dict())
107
102
 
108
103
  def keys(self):
109
104
  return list(self.to_dict().keys())
@@ -76,3 +76,22 @@ class RemoteCacheSync:
76
76
  self._output(
77
77
  f"There are {len(self.cache.keys()):,} entries in the local cache."
78
78
  )
79
+
80
+
81
+ # # Usage example
82
+ # def run_job(self, n, progress_bar, cache, stop_on_exception, sidecar_model, print_exceptions, raise_validation_errors, use_remote_cache=True):
83
+ # with RemoteCacheSync(self.coop, cache, self._output, remote_cache=use_remote_cache):
84
+ # self._output("Running job...")
85
+ # results = self._run_local(
86
+ # n=n,
87
+ # progress_bar=progress_bar,
88
+ # cache=cache,
89
+ # stop_on_exception=stop_on_exception,
90
+ # sidecar_model=sidecar_model,
91
+ # print_exceptions=print_exceptions,
92
+ # raise_validation_errors=raise_validation_errors,
93
+ # )
94
+ # self._output("Job completed!")
95
+
96
+ # results.cache = cache.new_entries_cache()
97
+ # return results
edsl/enums.py CHANGED
@@ -64,7 +64,6 @@ class InferenceServiceType(EnumWithChecks):
64
64
  OLLAMA = "ollama"
65
65
  MISTRAL = "mistral"
66
66
  TOGETHER = "together"
67
- PERPLEXITY = "perplexity"
68
67
 
69
68
 
70
69
  service_to_api_keyname = {
@@ -79,7 +78,6 @@ service_to_api_keyname = {
79
78
  InferenceServiceType.BEDROCK.value: ["AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY"],
80
79
  InferenceServiceType.MISTRAL.value: "MISTRAL_API_KEY",
81
80
  InferenceServiceType.TOGETHER.value: "TOGETHER_API_KEY",
82
- InferenceServiceType.PERPLEXITY.value: "PERPLEXITY_API_KEY",
83
81
  }
84
82
 
85
83
 
edsl/exceptions/agents.py CHANGED
@@ -1,10 +1,6 @@
1
1
  from edsl.exceptions.BaseException import BaseException
2
2
 
3
3
 
4
- class AgentListError(BaseException):
5
- relevant_doc = "https://docs.expectedparrot.com/en/latest/agents.html#agent-lists"
6
-
7
-
8
4
  class AgentErrors(BaseException):
9
5
  relevant_doc = "https://docs.expectedparrot.com/en/latest/agents.html"
10
6
 
@@ -8,7 +8,6 @@ from google.api_core.exceptions import InvalidArgument
8
8
  from edsl.exceptions import MissingAPIKeyError
9
9
  from edsl.language_models.LanguageModel import LanguageModel
10
10
  from edsl.inference_services.InferenceServiceABC import InferenceServiceABC
11
- from edsl.coop import Coop
12
11
 
13
12
  safety_settings = [
14
13
  {
@@ -80,8 +79,22 @@ class GoogleService(InferenceServiceABC):
80
79
  api_token = None
81
80
  model = None
82
81
 
82
+ @classmethod
83
+ def initialize(cls):
84
+ if cls.api_token is None:
85
+ cls.api_token = os.getenv("GOOGLE_API_KEY")
86
+ if not cls.api_token:
87
+ raise MissingAPIKeyError(
88
+ "GOOGLE_API_KEY environment variable is not set"
89
+ )
90
+ genai.configure(api_key=cls.api_token)
91
+ cls.generative_model = genai.GenerativeModel(
92
+ cls._model_, safety_settings=safety_settings
93
+ )
94
+
83
95
  def __init__(self, *args, **kwargs):
84
96
  super().__init__(*args, **kwargs)
97
+ self.initialize()
85
98
 
86
99
  def get_generation_config(self) -> GenerationConfig:
87
100
  return GenerationConfig(
@@ -103,7 +116,6 @@ class GoogleService(InferenceServiceABC):
103
116
  if files_list is None:
104
117
  files_list = []
105
118
 
106
- genai.configure(api_key=self.api_token)
107
119
  if (
108
120
  system_prompt is not None
109
121
  and system_prompt != ""
@@ -121,11 +133,7 @@ class GoogleService(InferenceServiceABC):
121
133
  )
122
134
  print("Will add system_prompt to user_prompt")
123
135
  user_prompt = f"{system_prompt}\n{user_prompt}"
124
- else:
125
- self.generative_model = genai.GenerativeModel(
126
- self._model_,
127
- safety_settings=safety_settings,
128
- )
136
+
129
137
  combined_prompt = [user_prompt]
130
138
  for file in files_list:
131
139
  if "google" not in file.external_locations:
@@ -12,7 +12,6 @@ from edsl.inference_services.AzureAI import AzureAIService
12
12
  from edsl.inference_services.OllamaService import OllamaService
13
13
  from edsl.inference_services.TestService import TestService
14
14
  from edsl.inference_services.TogetherAIService import TogetherAIService
15
- from edsl.inference_services.PerplexityService import PerplexityService
16
15
 
17
16
  try:
18
17
  from edsl.inference_services.MistralAIService import MistralAIService
@@ -32,7 +31,6 @@ services = [
32
31
  OllamaService,
33
32
  TestService,
34
33
  TogetherAIService,
35
- PerplexityService,
36
34
  ]
37
35
 
38
36
  if mistral_available: