edsl 0.1.37__py3-none-any.whl → 0.1.37.dev1__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (46):
  1. edsl/__version__.py +1 -1
  2. edsl/agents/Agent.py +35 -86
  3. edsl/agents/AgentList.py +0 -5
  4. edsl/agents/InvigilatorBase.py +23 -2
  5. edsl/agents/PromptConstructor.py +105 -148
  6. edsl/agents/descriptors.py +4 -17
  7. edsl/conjure/AgentConstructionMixin.py +3 -11
  8. edsl/conversation/Conversation.py +14 -66
  9. edsl/coop/coop.py +14 -148
  10. edsl/data/Cache.py +1 -1
  11. edsl/exceptions/__init__.py +3 -7
  12. edsl/exceptions/agents.py +19 -17
  13. edsl/exceptions/results.py +8 -11
  14. edsl/exceptions/surveys.py +10 -13
  15. edsl/inference_services/AwsBedrock.py +2 -7
  16. edsl/inference_services/InferenceServicesCollection.py +9 -32
  17. edsl/jobs/Jobs.py +71 -306
  18. edsl/jobs/interviews/InterviewExceptionEntry.py +1 -5
  19. edsl/jobs/tasks/TaskHistory.py +0 -1
  20. edsl/language_models/LanguageModel.py +59 -47
  21. edsl/language_models/__init__.py +0 -1
  22. edsl/prompts/Prompt.py +4 -11
  23. edsl/questions/QuestionBase.py +13 -53
  24. edsl/questions/QuestionBasePromptsMixin.py +33 -1
  25. edsl/questions/QuestionFreeText.py +0 -1
  26. edsl/questions/QuestionFunctional.py +2 -2
  27. edsl/questions/descriptors.py +28 -23
  28. edsl/results/DatasetExportMixin.py +1 -25
  29. edsl/results/Result.py +1 -16
  30. edsl/results/Results.py +120 -31
  31. edsl/results/ResultsDBMixin.py +1 -1
  32. edsl/results/Selector.py +1 -18
  33. edsl/scenarios/Scenario.py +12 -48
  34. edsl/scenarios/ScenarioHtmlMixin.py +2 -7
  35. edsl/scenarios/ScenarioList.py +1 -12
  36. edsl/surveys/Rule.py +4 -10
  37. edsl/surveys/Survey.py +77 -100
  38. edsl/utilities/utilities.py +0 -18
  39. {edsl-0.1.37.dist-info → edsl-0.1.37.dev1.dist-info}/METADATA +1 -1
  40. {edsl-0.1.37.dist-info → edsl-0.1.37.dev1.dist-info}/RECORD +42 -46
  41. edsl/conversation/chips.py +0 -95
  42. edsl/exceptions/BaseException.py +0 -21
  43. edsl/exceptions/scenarios.py +0 -22
  44. edsl/language_models/KeyLookup.py +0 -30
  45. {edsl-0.1.37.dist-info → edsl-0.1.37.dev1.dist-info}/LICENSE +0 -0
  46. {edsl-0.1.37.dist-info → edsl-0.1.37.dev1.dist-info}/WHEEL +0 -0
edsl/surveys/Survey.py CHANGED
@@ -9,8 +9,6 @@ from typing import Any, Generator, Optional, Union, List, Literal, Callable
9
9
  from uuid import uuid4
10
10
  from edsl.Base import Base
11
11
  from edsl.exceptions import SurveyCreationError, SurveyHasNoRulesError
12
- from edsl.exceptions.surveys import SurveyError
13
-
14
12
  from edsl.questions.QuestionBase import QuestionBase
15
13
  from edsl.surveys.base import RulePriority, EndOfSurvey
16
14
  from edsl.surveys.DAG import DAG
@@ -32,7 +30,7 @@ from edsl.surveys.instructions.ChangeInstruction import ChangeInstruction
32
30
  class ValidatedString(str):
33
31
  def __new__(cls, content):
34
32
  if "<>" in content:
35
- raise SurveyCreationError(
33
+ raise ValueError(
36
34
  "The expression contains '<>', which is not allowed. You probably mean '!='."
37
35
  )
38
36
  return super().__new__(cls, content)
@@ -376,15 +374,14 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
376
374
  >>> s._get_question_index("poop")
377
375
  Traceback (most recent call last):
378
376
  ...
379
- edsl.exceptions.surveys.SurveyError: Question name poop not found in survey. The current question names are {'q0': 0, 'q1': 1, 'q2': 2}.
380
- ...
377
+ ValueError: Question name poop not found in survey. The current question names are {'q0': 0, 'q1': 1, 'q2': 2}.
381
378
  """
382
379
  if q == EndOfSurvey:
383
380
  return EndOfSurvey
384
381
  else:
385
382
  question_name = q if isinstance(q, str) else q.question_name
386
383
  if question_name not in self.question_name_to_index:
387
- raise SurveyError(
384
+ raise ValueError(
388
385
  f"""Question name {question_name} not found in survey. The current question names are {self.question_name_to_index}."""
389
386
  )
390
387
  return self.question_name_to_index[question_name]
@@ -400,7 +397,7 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
400
397
  Question('multiple_choice', question_name = \"""q0\""", question_text = \"""Do you like school?\""", question_options = ['yes', 'no'])
401
398
  """
402
399
  if question_name not in self.question_name_to_index:
403
- raise SurveyError(f"Question name {question_name} not found in survey.")
400
+ raise KeyError(f"Question name {question_name} not found in survey.")
404
401
  index = self.question_name_to_index[question_name]
405
402
  return self._questions[index]
406
403
 
@@ -424,6 +421,7 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
424
421
  >>> s.question_names
425
422
  ['q0', 'q1', 'q2']
426
423
  """
424
+ # return list(self.question_name_to_index.keys())
427
425
  return [q.question_name for q in self.questions]
428
426
 
429
427
  @property
@@ -508,7 +506,9 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
508
506
 
509
507
  return ChangeInstruction
510
508
  else:
509
+ # some data might not have the edsl_class_name
511
510
  return QuestionBase
511
+ # raise ValueError(f"Class {pass_dict['edsl_class_name']} not found")
512
512
 
513
513
  questions = [
514
514
  get_class(q_dict).from_dict(q_dict) for q_dict in data["questions"]
@@ -589,8 +589,7 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
589
589
  >>> s3 = s1 + s2
590
590
  Traceback (most recent call last):
591
591
  ...
592
- edsl.exceptions.surveys.SurveyCreationError: ...
593
- ...
592
+ ValueError: ('Cannot combine two surveys with non-default rules.', "Please use the 'clear_non_default_rules' method to remove non-default rules from the survey.")
594
593
  >>> s3 = s1.clear_non_default_rules() + s2
595
594
  >>> len(s3.questions)
596
595
  4
@@ -600,8 +599,9 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
600
599
  len(self.rule_collection.non_default_rules) > 0
601
600
  or len(other.rule_collection.non_default_rules) > 0
602
601
  ):
603
- raise SurveyCreationError(
604
- "Cannot combine two surveys with non-default rules. Please use the 'clear_non_default_rules' method to remove non-default rules from the survey.",
602
+ raise ValueError(
603
+ "Cannot combine two surveys with non-default rules.",
604
+ "Please use the 'clear_non_default_rules' method to remove non-default rules from the survey.",
605
605
  )
606
606
 
607
607
  return Survey(questions=self.questions + other.questions)
@@ -609,16 +609,16 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
609
609
  def move_question(self, identifier: Union[str, int], new_index: int):
610
610
  if isinstance(identifier, str):
611
611
  if identifier not in self.question_names:
612
- raise SurveyError(
612
+ raise ValueError(
613
613
  f"Question name '{identifier}' does not exist in the survey."
614
614
  )
615
615
  index = self.question_name_to_index[identifier]
616
616
  elif isinstance(identifier, int):
617
617
  if identifier < 0 or identifier >= len(self.questions):
618
- raise SurveyError(f"Index {identifier} is out of range.")
618
+ raise ValueError(f"Index {identifier} is out of range.")
619
619
  index = identifier
620
620
  else:
621
- raise SurveyError(
621
+ raise TypeError(
622
622
  "Identifier must be either a string (question name) or an integer (question index)."
623
623
  )
624
624
 
@@ -648,28 +648,33 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
648
648
  """
649
649
  if isinstance(identifier, str):
650
650
  if identifier not in self.question_names:
651
- raise SurveyError(
651
+ raise ValueError(
652
652
  f"Question name '{identifier}' does not exist in the survey."
653
653
  )
654
654
  index = self.question_name_to_index[identifier]
655
655
  elif isinstance(identifier, int):
656
656
  if identifier < 0 or identifier >= len(self.questions):
657
- raise SurveyError(f"Index {identifier} is out of range.")
657
+ raise ValueError(f"Index {identifier} is out of range.")
658
658
  index = identifier
659
659
  else:
660
- raise SurveyError(
660
+ raise TypeError(
661
661
  "Identifier must be either a string (question name) or an integer (question index)."
662
662
  )
663
663
 
664
664
  # Remove the question
665
665
  deleted_question = self._questions.pop(index)
666
666
  del self.pseudo_indices[deleted_question.question_name]
667
+ # del self.question_name_to_index[deleted_question.question_name]
667
668
 
668
669
  # Update indices
669
670
  for question_name, old_index in self.pseudo_indices.items():
670
671
  if old_index > index:
671
672
  self.pseudo_indices[question_name] = old_index - 1
672
673
 
674
+ # for question_name, old_index in self.question_name_to_index.items():
675
+ # if old_index > index:
676
+ # self.question_name_to_index[question_name] = old_index - 1
677
+
673
678
  # Update rules
674
679
  new_rule_collection = RuleCollection()
675
680
  for rule in self.rule_collection:
@@ -685,6 +690,13 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
685
690
  rule.next_q = EndOfSurvey
686
691
  else:
687
692
  rule.next_q = index
693
+ # rule.next_q = min(index, len(self.questions) - 1)
694
+ # continue
695
+
696
+ # if rule.next_q == index:
697
+ # rule.next_q = min(
698
+ # rule.next_q, len(self.questions) - 1
699
+ # ) # Adjust to last question if necessary
688
700
 
689
701
  new_rule_collection.add_rule(rule)
690
702
  self.rule_collection = new_rule_collection
@@ -715,7 +727,6 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
715
727
  Traceback (most recent call last):
716
728
  ...
717
729
  edsl.exceptions.surveys.SurveyCreationError: Question name 'q0' already exists in survey. Existing names are ['q0'].
718
- ...
719
730
  """
720
731
  if question.question_name in self.question_names:
721
732
  raise SurveyCreationError(
@@ -725,11 +736,11 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
725
736
  index = len(self.questions)
726
737
 
727
738
  if index > len(self.questions):
728
- raise SurveyCreationError(
739
+ raise ValueError(
729
740
  f"Index {index} is greater than the number of questions in the survey."
730
741
  )
731
742
  if index < 0:
732
- raise SurveyCreationError(f"Index {index} is less than 0.")
743
+ raise ValueError(f"Index {index} is less than 0.")
733
744
 
734
745
  interior_insertion = index != len(self.questions)
735
746
 
@@ -921,32 +932,31 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
921
932
  >>> s = Survey.example().add_question_group("q0", "q2", "1group1")
922
933
  Traceback (most recent call last):
923
934
  ...
924
- edsl.exceptions.surveys.SurveyCreationError: Group name 1group1 is not a valid identifier.
925
- ...
935
+ ValueError: Group name 1group1 is not a valid identifier.
936
+
937
+ The name of the group cannot be the same as an existing question name:
938
+
926
939
  >>> s = Survey.example().add_question_group("q0", "q1", "q0")
927
940
  Traceback (most recent call last):
928
941
  ...
929
- edsl.exceptions.surveys.SurveyCreationError: ...
930
- ...
942
+ ValueError: Group name q0 already exists as a question name in the survey.
943
+
944
+ The start index must be less than the end index:
945
+
931
946
  >>> s = Survey.example().add_question_group("q1", "q0", "group1")
932
947
  Traceback (most recent call last):
933
948
  ...
934
- edsl.exceptions.surveys.SurveyCreationError: ...
935
- ...
949
+ ValueError: Start index 1 is greater than end index 0.
936
950
  """
937
951
 
938
952
  if not group_name.isidentifier():
939
- raise SurveyCreationError(
940
- f"Group name {group_name} is not a valid identifier."
941
- )
953
+ raise ValueError(f"Group name {group_name} is not a valid identifier.")
942
954
 
943
955
  if group_name in self.question_groups:
944
- raise SurveyCreationError(
945
- f"Group name {group_name} already exists in the survey."
946
- )
956
+ raise ValueError(f"Group name {group_name} already exists in the survey.")
947
957
 
948
958
  if group_name in self.question_name_to_index:
949
- raise SurveyCreationError(
959
+ raise ValueError(
950
960
  f"Group name {group_name} already exists as a question name in the survey."
951
961
  )
952
962
 
@@ -954,7 +964,7 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
954
964
  end_index = self._get_question_index(end_question)
955
965
 
956
966
  if start_index > end_index:
957
- raise SurveyCreationError(
967
+ raise ValueError(
958
968
  f"Start index {start_index} is greater than end index {end_index}."
959
969
  )
960
970
 
@@ -963,21 +973,15 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
963
973
  existing_end_index,
964
974
  ) in self.question_groups.items():
965
975
  if start_index < existing_start_index and end_index > existing_end_index:
966
- raise SurveyCreationError(
976
+ raise ValueError(
967
977
  f"Group {group_name} contains the questions in the new group."
968
978
  )
969
979
  if start_index > existing_start_index and end_index < existing_end_index:
970
- raise SurveyCreationError(
971
- f"Group {group_name} is contained in the new group."
972
- )
980
+ raise ValueError(f"Group {group_name} is contained in the new group.")
973
981
  if start_index < existing_start_index and end_index > existing_start_index:
974
- raise SurveyCreationError(
975
- f"Group {group_name} overlaps with the new group."
976
- )
982
+ raise ValueError(f"Group {group_name} overlaps with the new group.")
977
983
  if start_index < existing_end_index and end_index > existing_end_index:
978
- raise SurveyCreationError(
979
- f"Group {group_name} overlaps with the new group."
980
- )
984
+ raise ValueError(f"Group {group_name} overlaps with the new group.")
981
985
 
982
986
  self.question_groups[group_name] = (start_index, end_index)
983
987
  return self
@@ -1005,12 +1009,12 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
1005
1009
  self, question: Union[QuestionBase, str], expression: str
1006
1010
  ) -> Survey:
1007
1011
  """Add a rule that stops the survey.
1008
- The rule is evaluated *after* the question is answered. If the rule is true, the survey ends.
1009
1012
 
1010
1013
  :param question: The question to add the stop rule to.
1011
1014
  :param expression: The expression to evaluate.
1012
1015
 
1013
1016
  If this rule is true, the survey ends.
1017
+ The rule is evaluated *after* the question is answered. If the rule is true, the survey ends.
1014
1018
 
1015
1019
  Here, answering "yes" to q0 ends the survey:
1016
1020
 
@@ -1026,21 +1030,9 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
1026
1030
  >>> s.add_stop_rule("q0", "q1 <> 'yes'")
1027
1031
  Traceback (most recent call last):
1028
1032
  ...
1029
- edsl.exceptions.surveys.SurveyCreationError: The expression contains '<>', which is not allowed. You probably mean '!='.
1030
- ...
1033
+ ValueError: The expression contains '<>', which is not allowed. You probably mean '!='.
1031
1034
  """
1032
1035
  expression = ValidatedString(expression)
1033
- prior_question_appears = False
1034
- for prior_question in self.questions:
1035
- if prior_question.question_name in expression:
1036
- prior_question_appears = True
1037
-
1038
- if not prior_question_appears:
1039
- import warnings
1040
-
1041
- warnings.warn(
1042
- f"The expression {expression} does not contain any prior question names. This is probably a mistake."
1043
- )
1044
1036
  self.add_rule(question, expression, EndOfSurvey)
1045
1037
  return self
1046
1038
 
@@ -1227,59 +1219,32 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
1227
1219
 
1228
1220
  # region: Running the survey
1229
1221
 
1230
- def __call__(
1231
- self,
1232
- model=None,
1233
- agent=None,
1234
- cache=None,
1235
- disable_remote_cache: bool = False,
1236
- disable_remote_inference: bool = False,
1237
- **kwargs,
1238
- ):
1222
+ def __call__(self, model=None, agent=None, cache=None, **kwargs):
1239
1223
  """Run the survey with default model, taking the required survey as arguments.
1240
1224
 
1241
1225
  >>> from edsl.questions import QuestionFunctional
1242
1226
  >>> def f(scenario, agent_traits): return "yes" if scenario["period"] == "morning" else "no"
1243
1227
  >>> q = QuestionFunctional(question_name = "q0", func = f)
1244
1228
  >>> s = Survey([q])
1245
- >>> s(period = "morning", cache = False, disable_remote_cache = True, disable_remote_inference = True).select("answer.q0").first()
1229
+ >>> s(period = "morning", cache = False).select("answer.q0").first()
1246
1230
  'yes'
1247
- >>> s(period = "evening", cache = False, disable_remote_cache = True, disable_remote_inference = True).select("answer.q0").first()
1231
+ >>> s(period = "evening", cache = False).select("answer.q0").first()
1248
1232
  'no'
1249
1233
  """
1250
1234
  job = self.get_job(model, agent, **kwargs)
1251
- return job.run(
1252
- cache=cache,
1253
- disable_remote_cache=disable_remote_cache,
1254
- disable_remote_inference=disable_remote_inference,
1255
- )
1235
+ return job.run(cache=cache)
1256
1236
 
1257
- async def run_async(
1258
- self,
1259
- model: Optional["Model"] = None,
1260
- agent: Optional["Agent"] = None,
1261
- cache: Optional["Cache"] = None,
1262
- disable_remote_inference: bool = False,
1263
- **kwargs,
1264
- ):
1237
+ async def run_async(self, model=None, agent=None, cache=None, **kwargs):
1265
1238
  """Run the survey with default model, taking the required survey as arguments.
1266
1239
 
1267
- >>> import asyncio
1268
1240
  >>> from edsl.questions import QuestionFunctional
1269
1241
  >>> def f(scenario, agent_traits): return "yes" if scenario["period"] == "morning" else "no"
1270
1242
  >>> q = QuestionFunctional(question_name = "q0", func = f)
1271
1243
  >>> s = Survey([q])
1272
- >>> async def test_run_async(): result = await s.run_async(period="morning", disable_remote_inference = True); print(result.select("answer.q0").first())
1273
- >>> asyncio.run(test_run_async())
1274
- yes
1275
- >>> import asyncio
1276
- >>> from edsl.questions import QuestionFunctional
1277
- >>> def f(scenario, agent_traits): return "yes" if scenario["period"] == "morning" else "no"
1278
- >>> q = QuestionFunctional(question_name = "q0", func = f)
1279
- >>> s = Survey([q])
1280
- >>> async def test_run_async(): result = await s.run_async(period="evening", disable_remote_inference = True); print(result.select("answer.q0").first())
1281
- >>> asyncio.run(test_run_async())
1282
- no
1244
+ >>> s(period = "morning").select("answer.q0").first()
1245
+ 'yes'
1246
+ >>> s(period = "evening").select("answer.q0").first()
1247
+ 'no'
1283
1248
  """
1284
1249
  # TODO: temp fix by creating a cache
1285
1250
  if cache is None:
@@ -1288,10 +1253,8 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
1288
1253
  c = Cache()
1289
1254
  else:
1290
1255
  c = cache
1291
- jobs: "Jobs" = self.get_job(model=model, agent=agent, **kwargs)
1292
- return await jobs.run_async(
1293
- cache=c, disable_remote_inference=disable_remote_inference
1294
- )
1256
+ jobs: "Jobs" = self.get_job(model, agent, **kwargs)
1257
+ return await jobs.run_async(cache=c)
1295
1258
 
1296
1259
  def run(self, *args, **kwargs) -> "Results":
1297
1260
  """Turn the survey into a Job and runs it.
@@ -1300,7 +1263,7 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
1300
1263
  >>> s = Survey([QuestionFreeText.example()])
1301
1264
  >>> from edsl.language_models import LanguageModel
1302
1265
  >>> m = LanguageModel.example(test_model = True, canned_response = "Great!")
1303
- >>> results = s.by(m).run(cache = False, disable_remote_cache = True, disable_remote_inference = True)
1266
+ >>> results = s.by(m).run(cache = False)
1304
1267
  >>> results.select('answer.*')
1305
1268
  Dataset([{'answer.how_are_you': ['Great!']}])
1306
1269
  """
@@ -1431,7 +1394,7 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
1431
1394
  print(
1432
1395
  f"The index is {index} but the length of the questions is {len(self.questions)}"
1433
1396
  )
1434
- raise SurveyError
1397
+ raise
1435
1398
 
1436
1399
  try:
1437
1400
  text_dag = {}
@@ -1678,6 +1641,20 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
1678
1641
  else:
1679
1642
  return df
1680
1643
 
1644
+ def web(
1645
+ self,
1646
+ platform: Literal[
1647
+ "google_forms", "lime_survey", "survey_monkey"
1648
+ ] = "google_forms",
1649
+ email=None,
1650
+ ):
1651
+ from edsl.coop import Coop
1652
+
1653
+ c = Coop()
1654
+
1655
+ res = c.web(self.to_dict(), platform, email)
1656
+ return res
1657
+
1681
1658
  # endregion
1682
1659
 
1683
1660
  @classmethod
@@ -389,21 +389,3 @@ def shorten_string(s, max_length, placeholder="..."):
389
389
  end_remove = end_space
390
390
 
391
391
  return s[:start_remove] + placeholder + s[end_remove:]
392
-
393
-
394
- def write_api_key_to_env(api_key: str) -> None:
395
- """
396
- Write the user's Expected Parrot key to their .env file.
397
-
398
- If a .env file doesn't exist in the current directory, one will be created.
399
- """
400
- from pathlib import Path
401
- from dotenv import set_key
402
-
403
- # Create .env file if it doesn't exist
404
- env_path = ".env"
405
- env_file = Path(env_path)
406
- env_file.touch(exist_ok=True)
407
-
408
- # Write API key to file
409
- set_key(env_path, "EXPECTED_PARROT_API_KEY", str(api_key))
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: edsl
3
- Version: 0.1.37
3
+ Version: 0.1.37.dev1
4
4
  Summary: Create and analyze LLM-based surveys
5
5
  Home-page: https://www.expectedparrot.com/
6
6
  License: MIT