edsl 0.1.37.dev4__py3-none-any.whl → 0.1.37.dev6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. edsl/__version__.py +1 -1
  2. edsl/agents/Agent.py +86 -35
  3. edsl/agents/AgentList.py +5 -0
  4. edsl/agents/InvigilatorBase.py +2 -23
  5. edsl/agents/PromptConstructor.py +147 -106
  6. edsl/agents/descriptors.py +17 -4
  7. edsl/config.py +1 -1
  8. edsl/conjure/AgentConstructionMixin.py +11 -3
  9. edsl/conversation/Conversation.py +66 -14
  10. edsl/conversation/chips.py +95 -0
  11. edsl/coop/coop.py +134 -3
  12. edsl/data/Cache.py +1 -1
  13. edsl/exceptions/BaseException.py +21 -0
  14. edsl/exceptions/__init__.py +7 -3
  15. edsl/exceptions/agents.py +17 -19
  16. edsl/exceptions/results.py +11 -8
  17. edsl/exceptions/scenarios.py +22 -0
  18. edsl/exceptions/surveys.py +13 -10
  19. edsl/inference_services/InferenceServicesCollection.py +32 -9
  20. edsl/jobs/Jobs.py +265 -53
  21. edsl/jobs/interviews/InterviewExceptionEntry.py +5 -1
  22. edsl/jobs/tasks/TaskHistory.py +1 -0
  23. edsl/language_models/KeyLookup.py +30 -0
  24. edsl/language_models/LanguageModel.py +47 -59
  25. edsl/language_models/__init__.py +1 -0
  26. edsl/prompts/Prompt.py +8 -4
  27. edsl/questions/QuestionBase.py +53 -13
  28. edsl/questions/QuestionBasePromptsMixin.py +1 -33
  29. edsl/questions/QuestionFunctional.py +2 -2
  30. edsl/questions/descriptors.py +23 -28
  31. edsl/results/DatasetExportMixin.py +25 -1
  32. edsl/results/Result.py +16 -1
  33. edsl/results/Results.py +31 -120
  34. edsl/results/ResultsDBMixin.py +1 -1
  35. edsl/results/Selector.py +18 -1
  36. edsl/scenarios/Scenario.py +48 -12
  37. edsl/scenarios/ScenarioHtmlMixin.py +7 -2
  38. edsl/scenarios/ScenarioList.py +12 -1
  39. edsl/surveys/Rule.py +10 -4
  40. edsl/surveys/Survey.py +100 -77
  41. edsl/utilities/utilities.py +18 -0
  42. {edsl-0.1.37.dev4.dist-info → edsl-0.1.37.dev6.dist-info}/METADATA +1 -1
  43. {edsl-0.1.37.dev4.dist-info → edsl-0.1.37.dev6.dist-info}/RECORD +45 -41
  44. {edsl-0.1.37.dev4.dist-info → edsl-0.1.37.dev6.dist-info}/LICENSE +0 -0
  45. {edsl-0.1.37.dev4.dist-info → edsl-0.1.37.dev6.dist-info}/WHEEL +0 -0
edsl/surveys/Survey.py CHANGED
@@ -9,6 +9,8 @@ from typing import Any, Generator, Optional, Union, List, Literal, Callable
9
9
  from uuid import uuid4
10
10
  from edsl.Base import Base
11
11
  from edsl.exceptions import SurveyCreationError, SurveyHasNoRulesError
12
+ from edsl.exceptions.surveys import SurveyError
13
+
12
14
  from edsl.questions.QuestionBase import QuestionBase
13
15
  from edsl.surveys.base import RulePriority, EndOfSurvey
14
16
  from edsl.surveys.DAG import DAG
@@ -30,7 +32,7 @@ from edsl.surveys.instructions.ChangeInstruction import ChangeInstruction
30
32
  class ValidatedString(str):
31
33
  def __new__(cls, content):
32
34
  if "<>" in content:
33
- raise ValueError(
35
+ raise SurveyCreationError(
34
36
  "The expression contains '<>', which is not allowed. You probably mean '!='."
35
37
  )
36
38
  return super().__new__(cls, content)
@@ -374,14 +376,15 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
374
376
  >>> s._get_question_index("poop")
375
377
  Traceback (most recent call last):
376
378
  ...
377
- ValueError: Question name poop not found in survey. The current question names are {'q0': 0, 'q1': 1, 'q2': 2}.
379
+ edsl.exceptions.surveys.SurveyError: Question name poop not found in survey. The current question names are {'q0': 0, 'q1': 1, 'q2': 2}.
380
+ ...
378
381
  """
379
382
  if q == EndOfSurvey:
380
383
  return EndOfSurvey
381
384
  else:
382
385
  question_name = q if isinstance(q, str) else q.question_name
383
386
  if question_name not in self.question_name_to_index:
384
- raise ValueError(
387
+ raise SurveyError(
385
388
  f"""Question name {question_name} not found in survey. The current question names are {self.question_name_to_index}."""
386
389
  )
387
390
  return self.question_name_to_index[question_name]
@@ -397,7 +400,7 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
397
400
  Question('multiple_choice', question_name = \"""q0\""", question_text = \"""Do you like school?\""", question_options = ['yes', 'no'])
398
401
  """
399
402
  if question_name not in self.question_name_to_index:
400
- raise KeyError(f"Question name {question_name} not found in survey.")
403
+ raise SurveyError(f"Question name {question_name} not found in survey.")
401
404
  index = self.question_name_to_index[question_name]
402
405
  return self._questions[index]
403
406
 
@@ -421,7 +424,6 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
421
424
  >>> s.question_names
422
425
  ['q0', 'q1', 'q2']
423
426
  """
424
- # return list(self.question_name_to_index.keys())
425
427
  return [q.question_name for q in self.questions]
426
428
 
427
429
  @property
@@ -506,9 +508,7 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
506
508
 
507
509
  return ChangeInstruction
508
510
  else:
509
- # some data might not have the edsl_class_name
510
511
  return QuestionBase
511
- # raise ValueError(f"Class {pass_dict['edsl_class_name']} not found")
512
512
 
513
513
  questions = [
514
514
  get_class(q_dict).from_dict(q_dict) for q_dict in data["questions"]
@@ -589,7 +589,8 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
589
589
  >>> s3 = s1 + s2
590
590
  Traceback (most recent call last):
591
591
  ...
592
- ValueError: ('Cannot combine two surveys with non-default rules.', "Please use the 'clear_non_default_rules' method to remove non-default rules from the survey.")
592
+ edsl.exceptions.surveys.SurveyCreationError: ...
593
+ ...
593
594
  >>> s3 = s1.clear_non_default_rules() + s2
594
595
  >>> len(s3.questions)
595
596
  4
@@ -599,9 +600,8 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
599
600
  len(self.rule_collection.non_default_rules) > 0
600
601
  or len(other.rule_collection.non_default_rules) > 0
601
602
  ):
602
- raise ValueError(
603
- "Cannot combine two surveys with non-default rules.",
604
- "Please use the 'clear_non_default_rules' method to remove non-default rules from the survey.",
603
+ raise SurveyCreationError(
604
+ "Cannot combine two surveys with non-default rules. Please use the 'clear_non_default_rules' method to remove non-default rules from the survey.",
605
605
  )
606
606
 
607
607
  return Survey(questions=self.questions + other.questions)
@@ -609,16 +609,16 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
609
609
  def move_question(self, identifier: Union[str, int], new_index: int):
610
610
  if isinstance(identifier, str):
611
611
  if identifier not in self.question_names:
612
- raise ValueError(
612
+ raise SurveyError(
613
613
  f"Question name '{identifier}' does not exist in the survey."
614
614
  )
615
615
  index = self.question_name_to_index[identifier]
616
616
  elif isinstance(identifier, int):
617
617
  if identifier < 0 or identifier >= len(self.questions):
618
- raise ValueError(f"Index {identifier} is out of range.")
618
+ raise SurveyError(f"Index {identifier} is out of range.")
619
619
  index = identifier
620
620
  else:
621
- raise TypeError(
621
+ raise SurveyError(
622
622
  "Identifier must be either a string (question name) or an integer (question index)."
623
623
  )
624
624
 
@@ -648,33 +648,28 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
648
648
  """
649
649
  if isinstance(identifier, str):
650
650
  if identifier not in self.question_names:
651
- raise ValueError(
651
+ raise SurveyError(
652
652
  f"Question name '{identifier}' does not exist in the survey."
653
653
  )
654
654
  index = self.question_name_to_index[identifier]
655
655
  elif isinstance(identifier, int):
656
656
  if identifier < 0 or identifier >= len(self.questions):
657
- raise ValueError(f"Index {identifier} is out of range.")
657
+ raise SurveyError(f"Index {identifier} is out of range.")
658
658
  index = identifier
659
659
  else:
660
- raise TypeError(
660
+ raise SurveyError(
661
661
  "Identifier must be either a string (question name) or an integer (question index)."
662
662
  )
663
663
 
664
664
  # Remove the question
665
665
  deleted_question = self._questions.pop(index)
666
666
  del self.pseudo_indices[deleted_question.question_name]
667
- # del self.question_name_to_index[deleted_question.question_name]
668
667
 
669
668
  # Update indices
670
669
  for question_name, old_index in self.pseudo_indices.items():
671
670
  if old_index > index:
672
671
  self.pseudo_indices[question_name] = old_index - 1
673
672
 
674
- # for question_name, old_index in self.question_name_to_index.items():
675
- # if old_index > index:
676
- # self.question_name_to_index[question_name] = old_index - 1
677
-
678
673
  # Update rules
679
674
  new_rule_collection = RuleCollection()
680
675
  for rule in self.rule_collection:
@@ -690,13 +685,6 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
690
685
  rule.next_q = EndOfSurvey
691
686
  else:
692
687
  rule.next_q = index
693
- # rule.next_q = min(index, len(self.questions) - 1)
694
- # continue
695
-
696
- # if rule.next_q == index:
697
- # rule.next_q = min(
698
- # rule.next_q, len(self.questions) - 1
699
- # ) # Adjust to last question if necessary
700
688
 
701
689
  new_rule_collection.add_rule(rule)
702
690
  self.rule_collection = new_rule_collection
@@ -727,6 +715,7 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
727
715
  Traceback (most recent call last):
728
716
  ...
729
717
  edsl.exceptions.surveys.SurveyCreationError: Question name 'q0' already exists in survey. Existing names are ['q0'].
718
+ ...
730
719
  """
731
720
  if question.question_name in self.question_names:
732
721
  raise SurveyCreationError(
@@ -736,11 +725,11 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
736
725
  index = len(self.questions)
737
726
 
738
727
  if index > len(self.questions):
739
- raise ValueError(
728
+ raise SurveyCreationError(
740
729
  f"Index {index} is greater than the number of questions in the survey."
741
730
  )
742
731
  if index < 0:
743
- raise ValueError(f"Index {index} is less than 0.")
732
+ raise SurveyCreationError(f"Index {index} is less than 0.")
744
733
 
745
734
  interior_insertion = index != len(self.questions)
746
735
 
@@ -932,31 +921,32 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
932
921
  >>> s = Survey.example().add_question_group("q0", "q2", "1group1")
933
922
  Traceback (most recent call last):
934
923
  ...
935
- ValueError: Group name 1group1 is not a valid identifier.
936
-
937
- The name of the group cannot be the same as an existing question name:
938
-
924
+ edsl.exceptions.surveys.SurveyCreationError: Group name 1group1 is not a valid identifier.
925
+ ...
939
926
  >>> s = Survey.example().add_question_group("q0", "q1", "q0")
940
927
  Traceback (most recent call last):
941
928
  ...
942
- ValueError: Group name q0 already exists as a question name in the survey.
943
-
944
- The start index must be less than the end index:
945
-
929
+ edsl.exceptions.surveys.SurveyCreationError: ...
930
+ ...
946
931
  >>> s = Survey.example().add_question_group("q1", "q0", "group1")
947
932
  Traceback (most recent call last):
948
933
  ...
949
- ValueError: Start index 1 is greater than end index 0.
934
+ edsl.exceptions.surveys.SurveyCreationError: ...
935
+ ...
950
936
  """
951
937
 
952
938
  if not group_name.isidentifier():
953
- raise ValueError(f"Group name {group_name} is not a valid identifier.")
939
+ raise SurveyCreationError(
940
+ f"Group name {group_name} is not a valid identifier."
941
+ )
954
942
 
955
943
  if group_name in self.question_groups:
956
- raise ValueError(f"Group name {group_name} already exists in the survey.")
944
+ raise SurveyCreationError(
945
+ f"Group name {group_name} already exists in the survey."
946
+ )
957
947
 
958
948
  if group_name in self.question_name_to_index:
959
- raise ValueError(
949
+ raise SurveyCreationError(
960
950
  f"Group name {group_name} already exists as a question name in the survey."
961
951
  )
962
952
 
@@ -964,7 +954,7 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
964
954
  end_index = self._get_question_index(end_question)
965
955
 
966
956
  if start_index > end_index:
967
- raise ValueError(
957
+ raise SurveyCreationError(
968
958
  f"Start index {start_index} is greater than end index {end_index}."
969
959
  )
970
960
 
@@ -973,15 +963,21 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
973
963
  existing_end_index,
974
964
  ) in self.question_groups.items():
975
965
  if start_index < existing_start_index and end_index > existing_end_index:
976
- raise ValueError(
966
+ raise SurveyCreationError(
977
967
  f"Group {group_name} contains the questions in the new group."
978
968
  )
979
969
  if start_index > existing_start_index and end_index < existing_end_index:
980
- raise ValueError(f"Group {group_name} is contained in the new group.")
970
+ raise SurveyCreationError(
971
+ f"Group {group_name} is contained in the new group."
972
+ )
981
973
  if start_index < existing_start_index and end_index > existing_start_index:
982
- raise ValueError(f"Group {group_name} overlaps with the new group.")
974
+ raise SurveyCreationError(
975
+ f"Group {group_name} overlaps with the new group."
976
+ )
983
977
  if start_index < existing_end_index and end_index > existing_end_index:
984
- raise ValueError(f"Group {group_name} overlaps with the new group.")
978
+ raise SurveyCreationError(
979
+ f"Group {group_name} overlaps with the new group."
980
+ )
985
981
 
986
982
  self.question_groups[group_name] = (start_index, end_index)
987
983
  return self
@@ -1009,12 +1005,12 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
1009
1005
  self, question: Union[QuestionBase, str], expression: str
1010
1006
  ) -> Survey:
1011
1007
  """Add a rule that stops the survey.
1008
+ The rule is evaluated *after* the question is answered. If the rule is true, the survey ends.
1012
1009
 
1013
1010
  :param question: The question to add the stop rule to.
1014
1011
  :param expression: The expression to evaluate.
1015
1012
 
1016
1013
  If this rule is true, the survey ends.
1017
- The rule is evaluated *after* the question is answered. If the rule is true, the survey ends.
1018
1014
 
1019
1015
  Here, answering "yes" to q0 ends the survey:
1020
1016
 
@@ -1030,9 +1026,21 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
1030
1026
  >>> s.add_stop_rule("q0", "q1 <> 'yes'")
1031
1027
  Traceback (most recent call last):
1032
1028
  ...
1033
- ValueError: The expression contains '<>', which is not allowed. You probably mean '!='.
1029
+ edsl.exceptions.surveys.SurveyCreationError: The expression contains '<>', which is not allowed. You probably mean '!='.
1030
+ ...
1034
1031
  """
1035
1032
  expression = ValidatedString(expression)
1033
+ prior_question_appears = False
1034
+ for prior_question in self.questions:
1035
+ if prior_question.question_name in expression:
1036
+ prior_question_appears = True
1037
+
1038
+ if not prior_question_appears:
1039
+ import warnings
1040
+
1041
+ warnings.warn(
1042
+ f"The expression {expression} does not contain any prior question names. This is probably a mistake."
1043
+ )
1036
1044
  self.add_rule(question, expression, EndOfSurvey)
1037
1045
  return self
1038
1046
 
@@ -1219,32 +1227,59 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
1219
1227
 
1220
1228
  # region: Running the survey
1221
1229
 
1222
- def __call__(self, model=None, agent=None, cache=None, **kwargs):
1230
+ def __call__(
1231
+ self,
1232
+ model=None,
1233
+ agent=None,
1234
+ cache=None,
1235
+ disable_remote_cache: bool = False,
1236
+ disable_remote_inference: bool = False,
1237
+ **kwargs,
1238
+ ):
1223
1239
  """Run the survey with default model, taking the required survey as arguments.
1224
1240
 
1225
1241
  >>> from edsl.questions import QuestionFunctional
1226
1242
  >>> def f(scenario, agent_traits): return "yes" if scenario["period"] == "morning" else "no"
1227
1243
  >>> q = QuestionFunctional(question_name = "q0", func = f)
1228
1244
  >>> s = Survey([q])
1229
- >>> s(period = "morning", cache = False).select("answer.q0").first()
1245
+ >>> s(period = "morning", cache = False, disable_remote_cache = True, disable_remote_inference = True).select("answer.q0").first()
1230
1246
  'yes'
1231
- >>> s(period = "evening", cache = False).select("answer.q0").first()
1247
+ >>> s(period = "evening", cache = False, disable_remote_cache = True, disable_remote_inference = True).select("answer.q0").first()
1232
1248
  'no'
1233
1249
  """
1234
1250
  job = self.get_job(model, agent, **kwargs)
1235
- return job.run(cache=cache)
1251
+ return job.run(
1252
+ cache=cache,
1253
+ disable_remote_cache=disable_remote_cache,
1254
+ disable_remote_inference=disable_remote_inference,
1255
+ )
1236
1256
 
1237
- async def run_async(self, model=None, agent=None, cache=None, **kwargs):
1257
+ async def run_async(
1258
+ self,
1259
+ model: Optional["Model"] = None,
1260
+ agent: Optional["Agent"] = None,
1261
+ cache: Optional["Cache"] = None,
1262
+ disable_remote_inference: bool = False,
1263
+ **kwargs,
1264
+ ):
1238
1265
  """Run the survey with default model, taking the required survey as arguments.
1239
1266
 
1267
+ >>> import asyncio
1240
1268
  >>> from edsl.questions import QuestionFunctional
1241
1269
  >>> def f(scenario, agent_traits): return "yes" if scenario["period"] == "morning" else "no"
1242
1270
  >>> q = QuestionFunctional(question_name = "q0", func = f)
1243
1271
  >>> s = Survey([q])
1244
- >>> s(period = "morning").select("answer.q0").first()
1245
- 'yes'
1246
- >>> s(period = "evening").select("answer.q0").first()
1247
- 'no'
1272
+ >>> async def test_run_async(): result = await s.run_async(period="morning", disable_remote_inference = True); print(result.select("answer.q0").first())
1273
+ >>> asyncio.run(test_run_async())
1274
+ yes
1275
+ >>> import asyncio
1276
+ >>> from edsl.questions import QuestionFunctional
1277
+ >>> def f(scenario, agent_traits): return "yes" if scenario["period"] == "morning" else "no"
1278
+ >>> q = QuestionFunctional(question_name = "q0", func = f)
1279
+ >>> s = Survey([q])
1280
+ >>> async def test_run_async(): result = await s.run_async(period="evening", disable_remote_inference = True); print(result.select("answer.q0").first())
1281
+ >>> asyncio.run(test_run_async())
1282
+ no
1248
1283
  """
1249
1284
  # TODO: temp fix by creating a cache
1250
1285
  if cache is None:
@@ -1253,8 +1288,10 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
1253
1288
  c = Cache()
1254
1289
  else:
1255
1290
  c = cache
1256
- jobs: "Jobs" = self.get_job(model, agent, **kwargs)
1257
- return await jobs.run_async(cache=c)
1291
+ jobs: "Jobs" = self.get_job(model=model, agent=agent, **kwargs)
1292
+ return await jobs.run_async(
1293
+ cache=c, disable_remote_inference=disable_remote_inference
1294
+ )
1258
1295
 
1259
1296
  def run(self, *args, **kwargs) -> "Results":
1260
1297
  """Turn the survey into a Job and runs it.
@@ -1263,7 +1300,7 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
1263
1300
  >>> s = Survey([QuestionFreeText.example()])
1264
1301
  >>> from edsl.language_models import LanguageModel
1265
1302
  >>> m = LanguageModel.example(test_model = True, canned_response = "Great!")
1266
- >>> results = s.by(m).run(cache = False)
1303
+ >>> results = s.by(m).run(cache = False, disable_remote_cache = True, disable_remote_inference = True)
1267
1304
  >>> results.select('answer.*')
1268
1305
  Dataset([{'answer.how_are_you': ['Great!']}])
1269
1306
  """
@@ -1394,7 +1431,7 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
1394
1431
  print(
1395
1432
  f"The index is {index} but the length of the questions is {len(self.questions)}"
1396
1433
  )
1397
- raise
1434
+ raise SurveyError
1398
1435
 
1399
1436
  try:
1400
1437
  text_dag = {}
@@ -1641,20 +1678,6 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
1641
1678
  else:
1642
1679
  return df
1643
1680
 
1644
- def web(
1645
- self,
1646
- platform: Literal[
1647
- "google_forms", "lime_survey", "survey_monkey"
1648
- ] = "google_forms",
1649
- email=None,
1650
- ):
1651
- from edsl.coop import Coop
1652
-
1653
- c = Coop()
1654
-
1655
- res = c.web(self.to_dict(), platform, email)
1656
- return res
1657
-
1658
1681
  # endregion
1659
1682
 
1660
1683
  @classmethod
edsl/utilities/utilities.py CHANGED
@@ -389,3 +389,21 @@ def shorten_string(s, max_length, placeholder="..."):
389
389
  end_remove = end_space
390
390
 
391
391
  return s[:start_remove] + placeholder + s[end_remove:]
392
+
393
+
394
+ def write_api_key_to_env(api_key: str) -> None:
395
+ """
396
+ Write the user's Expected Parrot key to their .env file.
397
+
398
+ If a .env file doesn't exist in the current directory, one will be created.
399
+ """
400
+ from pathlib import Path
401
+ from dotenv import set_key
402
+
403
+ # Create .env file if it doesn't exist
404
+ env_path = ".env"
405
+ env_file = Path(env_path)
406
+ env_file.touch(exist_ok=True)
407
+
408
+ # Write API key to file
409
+ set_key(env_path, "EXPECTED_PARROT_API_KEY", str(api_key))
{edsl-0.1.37.dev4.dist-info → edsl-0.1.37.dev6.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: edsl
3
- Version: 0.1.37.dev4
3
+ Version: 0.1.37.dev6
4
4
  Summary: Create and analyze LLM-based surveys
5
5
  Home-page: https://www.expectedparrot.com/
6
6
  License: MIT