edsl 0.1.33.dev2__py3-none-any.whl → 0.1.33.dev3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63):
  1. edsl/Base.py +9 -3
  2. edsl/__init__.py +1 -0
  3. edsl/__version__.py +1 -1
  4. edsl/agents/Agent.py +6 -6
  5. edsl/agents/Invigilator.py +6 -3
  6. edsl/agents/InvigilatorBase.py +8 -27
  7. edsl/agents/{PromptConstructionMixin.py → PromptConstructor.py} +101 -29
  8. edsl/config.py +26 -34
  9. edsl/coop/coop.py +11 -2
  10. edsl/data_transfer_models.py +27 -73
  11. edsl/enums.py +2 -0
  12. edsl/inference_services/GoogleService.py +1 -1
  13. edsl/inference_services/InferenceServiceABC.py +44 -13
  14. edsl/inference_services/OpenAIService.py +7 -4
  15. edsl/inference_services/TestService.py +24 -15
  16. edsl/inference_services/TogetherAIService.py +170 -0
  17. edsl/inference_services/registry.py +2 -0
  18. edsl/jobs/Jobs.py +18 -8
  19. edsl/jobs/buckets/BucketCollection.py +24 -15
  20. edsl/jobs/buckets/TokenBucket.py +64 -10
  21. edsl/jobs/interviews/Interview.py +115 -47
  22. edsl/jobs/interviews/{interview_exception_tracking.py → InterviewExceptionCollection.py} +16 -0
  23. edsl/jobs/interviews/InterviewExceptionEntry.py +2 -0
  24. edsl/jobs/runners/JobsRunnerAsyncio.py +86 -161
  25. edsl/jobs/runners/JobsRunnerStatus.py +331 -0
  26. edsl/jobs/tasks/TaskHistory.py +17 -0
  27. edsl/language_models/LanguageModel.py +26 -31
  28. edsl/language_models/registry.py +13 -9
  29. edsl/questions/QuestionBase.py +64 -16
  30. edsl/questions/QuestionBudget.py +93 -41
  31. edsl/questions/QuestionFreeText.py +6 -0
  32. edsl/questions/QuestionMultipleChoice.py +11 -26
  33. edsl/questions/QuestionNumerical.py +5 -4
  34. edsl/questions/Quick.py +41 -0
  35. edsl/questions/ResponseValidatorABC.py +6 -5
  36. edsl/questions/derived/QuestionLinearScale.py +4 -1
  37. edsl/questions/derived/QuestionTopK.py +4 -1
  38. edsl/questions/derived/QuestionYesNo.py +8 -2
  39. edsl/questions/templates/budget/__init__.py +0 -0
  40. edsl/questions/templates/budget/answering_instructions.jinja +7 -0
  41. edsl/questions/templates/budget/question_presentation.jinja +7 -0
  42. edsl/questions/templates/extract/__init__.py +0 -0
  43. edsl/questions/templates/rank/__init__.py +0 -0
  44. edsl/results/DatasetExportMixin.py +5 -1
  45. edsl/results/Result.py +1 -1
  46. edsl/results/Results.py +4 -1
  47. edsl/scenarios/FileStore.py +71 -10
  48. edsl/scenarios/Scenario.py +86 -21
  49. edsl/scenarios/ScenarioImageMixin.py +2 -2
  50. edsl/scenarios/ScenarioList.py +13 -0
  51. edsl/scenarios/ScenarioListPdfMixin.py +150 -4
  52. edsl/study/Study.py +32 -0
  53. edsl/surveys/Rule.py +10 -1
  54. edsl/surveys/RuleCollection.py +19 -3
  55. edsl/surveys/Survey.py +7 -0
  56. edsl/templates/error_reporting/interview_details.html +6 -1
  57. edsl/utilities/utilities.py +9 -1
  58. {edsl-0.1.33.dev2.dist-info → edsl-0.1.33.dev3.dist-info}/METADATA +2 -1
  59. {edsl-0.1.33.dev2.dist-info → edsl-0.1.33.dev3.dist-info}/RECORD +61 -55
  60. edsl/jobs/interviews/retry_management.py +0 -39
  61. edsl/jobs/runners/JobsRunnerStatusMixin.py +0 -333
  62. {edsl-0.1.33.dev2.dist-info → edsl-0.1.33.dev3.dist-info}/LICENSE +0 -0
  63. {edsl-0.1.33.dev2.dist-info → edsl-0.1.33.dev3.dist-info}/WHEEL +0 -0
@@ -120,14 +120,22 @@ class FileStore(Scenario):
120
120
  return info
121
121
 
122
122
  @classmethod
123
- def pull(cls, uuid):
124
- scenario_version = Scenario.pull(uuid)
123
+ def pull(cls, uuid, expected_parrot_url: Optional[str] = None):
124
+ scenario_version = Scenario.pull(uuid, expected_parrot_url=expected_parrot_url)
125
125
  return cls.from_dict(scenario_version.to_dict())
126
126
 
127
127
 
128
128
  class CSVFileStore(FileStore):
129
- def __init__(self, filename):
130
- super().__init__(filename, suffix=".csv")
129
+ def __init__(
130
+ self,
131
+ filename,
132
+ binary: Optional[bool] = None,
133
+ suffix: Optional[str] = None,
134
+ base64_string: Optional[str] = None,
135
+ ):
136
+ super().__init__(
137
+ filename, binary=binary, base64_string=base64_string, suffix=".csv"
138
+ )
131
139
 
132
140
  @classmethod
133
141
  def example(cls):
@@ -147,8 +155,16 @@ class CSVFileStore(FileStore):
147
155
 
148
156
 
149
157
  class PDFFileStore(FileStore):
150
- def __init__(self, filename):
151
- super().__init__(filename, suffix=".pdf")
158
+ def __init__(
159
+ self,
160
+ filename,
161
+ binary: Optional[bool] = None,
162
+ suffix: Optional[str] = None,
163
+ base64_string: Optional[str] = None,
164
+ ):
165
+ super().__init__(
166
+ filename, binary=binary, base64_string=base64_string, suffix=".pdf"
167
+ )
152
168
 
153
169
  def view(self):
154
170
  pdf_path = self.to_tempfile()
@@ -225,8 +241,16 @@ class PDFFileStore(FileStore):
225
241
 
226
242
 
227
243
  class PNGFileStore(FileStore):
228
- def __init__(self, filename):
229
- super().__init__(filename, suffix=".png")
244
+ def __init__(
245
+ self,
246
+ filename,
247
+ binary: Optional[bool] = None,
248
+ suffix: Optional[str] = None,
249
+ base64_string: Optional[str] = None,
250
+ ):
251
+ super().__init__(
252
+ filename, binary=binary, base64_string=base64_string, suffix=".png"
253
+ )
230
254
 
231
255
  @classmethod
232
256
  def example(cls):
@@ -251,8 +275,16 @@ class PNGFileStore(FileStore):
251
275
 
252
276
 
253
277
  class SQLiteFileStore(FileStore):
254
- def __init__(self, filename):
255
- super().__init__(filename, suffix=".sqlite")
278
+ def __init__(
279
+ self,
280
+ filename,
281
+ binary: Optional[bool] = None,
282
+ suffix: Optional[str] = None,
283
+ base64_string: Optional[str] = None,
284
+ ):
285
+ super().__init__(
286
+ filename, binary=binary, base64_string=base64_string, suffix=".sqlite"
287
+ )
256
288
 
257
289
  @classmethod
258
290
  def example(cls):
@@ -265,6 +297,8 @@ class SQLiteFileStore(FileStore):
265
297
  c.execute("""CREATE TABLE stocks (date text)""")
266
298
  conn.commit()
267
299
 
300
+ return cls(f.name)
301
+
268
302
  def view(self):
269
303
  import subprocess
270
304
  import os
@@ -273,6 +307,33 @@ class SQLiteFileStore(FileStore):
273
307
  os.system(f"sqlite3 {sqlite_path}")
274
308
 
275
309
 
310
class HTMLFileStore(FileStore):
    """FileStore for HTML documents; the stored suffix is always ``.html``."""

    def __init__(
        self,
        filename,
        binary: Optional[bool] = None,
        suffix: Optional[str] = None,
        base64_string: Optional[str] = None,
    ):
        # `suffix` is accepted so the signature matches the other FileStore
        # subclasses, but it is ignored: this class always forces ".html".
        super().__init__(
            filename, binary=binary, base64_string=base64_string, suffix=".html"
        )

    @classmethod
    def example(cls):
        """Return an HTMLFileStore backed by a small temporary HTML file.

        The temp file is created with ``delete=False`` so it outlives this call.
        """
        import tempfile

        with tempfile.NamedTemporaryFile(suffix=".html", delete=False) as f:
            f.write("<html><body><h1>Test</h1></body></html>".encode())
        # Construct the store only after the `with` block has closed (and thus
        # flushed) the file; reading it while still open could see no content.
        return cls(f.name)

    def view(self):
        """Write the stored HTML to a temp file and open it in the web browser."""
        import os
        import pathlib
        import webbrowser

        html_path = self.to_tempfile()
        # Path.as_uri() produces a valid file:// URL on every platform; plain
        # "file://" + path concatenation is malformed for Windows drive paths.
        # as_uri() needs an absolute path, hence abspath().
        webbrowser.open(pathlib.Path(os.path.abspath(html_path)).as_uri())
335
+
336
+
276
337
  if __name__ == "__main__":
277
338
  # file_path = "../conjure/examples/Ex11-2.sav"
278
339
  # fs = FileStore(file_path)
@@ -5,6 +5,10 @@ import copy
5
5
  import base64
6
6
  import hashlib
7
7
  import os
8
+ import reprlib
9
+ import imghdr
10
+
11
+
8
12
  from collections import UserDict
9
13
  from typing import Union, List, Optional, Generator
10
14
  from uuid import uuid4
@@ -13,6 +17,8 @@ from edsl.scenarios.ScenarioImageMixin import ScenarioImageMixin
13
17
  from edsl.scenarios.ScenarioHtmlMixin import ScenarioHtmlMixin
14
18
  from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
15
19
 
20
+ from edsl.data_transfer_models import ImageInfo
21
+
16
22
 
17
23
  class Scenario(Base, UserDict, ScenarioImageMixin, ScenarioHtmlMixin):
18
24
  """A Scenario is a dictionary of keys/values.
@@ -49,6 +55,39 @@ class Scenario(Base, UserDict, ScenarioImageMixin, ScenarioHtmlMixin):
49
55
  self._has_image = False
50
56
  return self._has_image
51
57
 
58
@property
def has_jinja_braces(self) -> bool:
    """Return whether any scenario value contains Jinja braces. This matters for rendering.

    Each value is checked via its ``str()`` form, and both ``"{{"`` and
    ``"}}"`` are looked for in that same string.

    >>> s = Scenario({"food": "I love {{wood chips}}"})
    >>> s.has_jinja_braces
    True
    """
    for value in self.values():
        text = str(value)
        # Both checks must run on the same text: testing "}}" against the raw
        # value does container membership (or raises TypeError) for non-str values.
        if "{{" in text and "}}" in text:
            return True
    return False
70
+
71
def convert_jinja_braces(
    self, replacement_left="<<", replacement_right=">>"
) -> "Scenario":
    """Return a new Scenario whose string values have Jinja braces swapped out.

    In every ``str`` value, ``{{`` becomes *replacement_left* and then ``}}``
    becomes *replacement_right*. Values of any other type are copied over
    unchanged, and the original scenario is left untouched.

    >>> s = Scenario({"food": "I love {{wood chips}}"})
    >>> s.convert_jinja_braces()
    Scenario({'food': 'I love <<wood chips>>'})

    """
    converted = Scenario()
    for name, item in self.items():
        if not isinstance(item, str):
            converted[name] = item
            continue
        # Left braces are replaced first, then right braces.
        left_done = item.replace("{{", replacement_left)
        converted[name] = left_done.replace("}}", replacement_right)
    return converted
90
+
52
91
  @has_image.setter
53
92
  def has_image(self, value):
54
93
  self._has_image = value
@@ -142,6 +181,7 @@ class Scenario(Base, UserDict, ScenarioImageMixin, ScenarioHtmlMixin):
142
181
  print_json(json.dumps(self.to_dict()))
143
182
 
144
183
  def __repr__(self):
184
+ # return "Scenario(" + reprlib.repr(self.data) + ")"
145
185
  return "Scenario(" + repr(self.data) + ")"
146
186
 
147
187
  def _repr_html_(self):
@@ -196,26 +236,48 @@ class Scenario(Base, UserDict, ScenarioImageMixin, ScenarioHtmlMixin):
196
236
  return cls({"url": url, field_name: text})
197
237
 
198
238
  @classmethod
199
- def from_image(cls, image_path: str) -> str:
200
- """Creates a scenario with a base64 encoding of an image.
239
+ def from_image(
240
+ cls, image_path: str, image_name: Optional[str] = None
241
+ ) -> "Scenario":
242
+ """
243
+ Creates a scenario with a base64 encoding of an image.
201
244
 
202
- Example:
245
+ Args:
246
+ image_path (str): Path to the image file.
247
+
248
+ Returns:
249
+ Scenario: A new Scenario instance with image information.
203
250
 
251
+ Example:
204
252
  >>> s = Scenario.from_image(Scenario.example_image())
205
253
  >>> s
206
- Scenario({'file_path': '...', 'encoded_image': '...'})
254
+ Scenario({'logo': ...})
207
255
  """
256
+ if not os.path.exists(image_path):
257
+ raise FileNotFoundError(f"Image file not found: {image_path}")
258
+
208
259
  with open(image_path, "rb") as image_file:
209
- s = cls(
210
- {
211
- "file_path": image_path,
212
- "encoded_image": base64.b64encode(image_file.read()).decode(
213
- "utf-8"
214
- ),
215
- }
216
- )
217
- s.has_image = True
218
- return s
260
+ file_content = image_file.read()
261
+
262
+ file_name = os.path.basename(image_path)
263
+ file_size = os.path.getsize(image_path)
264
+ image_format = imghdr.what(image_path) or "unknown"
265
+
266
+ if image_name is None:
267
+ image_name = file_name.split(".")[0]
268
+
269
+ image_info = ImageInfo(
270
+ file_path=image_path,
271
+ file_name=file_name,
272
+ image_format=image_format,
273
+ file_size=file_size,
274
+ encoded_image=base64.b64encode(file_content).decode("utf-8"),
275
+ )
276
+
277
+ scenario_data = {image_name: image_info}
278
+ s = cls(scenario_data)
279
+ s.has_image = True
280
+ return s
219
281
 
220
282
  @classmethod
221
283
  def from_pdf(cls, pdf_path):
@@ -429,18 +491,21 @@ class Scenario(Base, UserDict, ScenarioImageMixin, ScenarioHtmlMixin):
429
491
  return table
430
492
 
431
493
  @classmethod
432
- def example(cls, randomize: bool = False) -> Scenario:
494
+ def example(cls, randomize: bool = False, has_image=False) -> Scenario:
433
495
  """
434
496
  Returns an example Scenario instance.
435
497
 
436
498
  :param randomize: If True, adds a random string to the value of the example key.
437
499
  """
438
- addition = "" if not randomize else str(uuid4())
439
- return cls(
440
- {
441
- "persona": f"A reseacher studying whether LLMs can be used to generate surveys.{addition}",
442
- }
443
- )
500
+ if not has_image:
501
+ addition = "" if not randomize else str(uuid4())
502
+ return cls(
503
+ {
504
+ "persona": f"A reseacher studying whether LLMs can be used to generate surveys.{addition}",
505
+ }
506
+ )
507
+ else:
508
+ return cls.from_image(cls.example_image())
444
509
 
445
510
  def code(self) -> List[str]:
446
511
  """Return the code for the scenario."""
@@ -13,7 +13,7 @@ class ScenarioImageMixin:
13
13
  >>> from edsl.scenarios.Scenario import Scenario
14
14
  >>> s = Scenario({"food": "wood chips"})
15
15
  >>> s.add_image(Scenario.example_image())
16
- Scenario({'food': 'wood chips', 'file_path': '...', 'encoded_image': '...'})
16
+ Scenario({'food': 'wood chips', 'logo': ...})
17
17
  """
18
18
  new_scenario = self.from_image(image_path)
19
19
  return self + new_scenario
@@ -33,7 +33,7 @@ class ScenarioImageMixin:
33
33
  >>> from edsl.scenarios.Scenario import Scenario
34
34
  >>> s = Scenario.from_image(Scenario.example_image())
35
35
  >>> s
36
- Scenario({'file_path': '...', 'encoded_image': '...'})
36
+ Scenario({'logo': ...})
37
37
  """
38
38
 
39
39
  if image_path.startswith("http://") or image_path.startswith("https://"):
@@ -39,6 +39,15 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
39
39
  super().__init__([])
40
40
  self.codebook = codebook or {}
41
41
 
42
@property
def has_jinja_braces(self) -> bool:
    """Return True if any scenario in the list contains Jinja braces.

    A generator is used so the scan stops at the first scenario that has
    braces rather than evaluating every scenario's check up front.
    """
    return any(scenario.has_jinja_braces for scenario in self)
46
+
47
def convert_jinja_braces(
    self, replacement_left: str = "<<", replacement_right: str = ">>"
) -> "ScenarioList":
    """Return a new ScenarioList with Jinja braces replaced in every scenario.

    Each scenario's ``convert_jinja_braces`` is called with the given markers
    (default ``<<`` / ``>>``), so ``{{`` / ``}}`` in string values become
    *replacement_left* / *replacement_right*. The list's codebook is carried
    over (as a copy) so the converted list keeps its variable labels.
    """
    return ScenarioList(
        [
            scenario.convert_jinja_braces(
                replacement_left=replacement_left,
                replacement_right=replacement_right,
            )
            for scenario in self
        ],
        codebook=dict(self.codebook),
    )
50
+
42
51
  def give_valid_names(self) -> ScenarioList:
43
52
  """Give valid names to the scenario keys.
44
53
 
@@ -273,6 +282,10 @@ class ScenarioList(Base, UserList, ScenarioListMixin):
273
282
  for s in data["scenarios"]:
274
283
  _ = s.pop("edsl_version")
275
284
  _ = s.pop("edsl_class_name")
285
+ for scenario in data["scenarios"]:
286
+ for key, value in scenario.items():
287
+ if hasattr(value, "to_dict"):
288
+ data[key] = value.to_dict()
276
289
  return data_to_html(data)
277
290
 
278
291
  def tally(self, field) -> dict:
@@ -1,15 +1,161 @@
1
1
  import fitz # PyMuPDF
2
2
  import os
3
+ import copy
3
4
  import subprocess
5
+ import requests
6
+ import tempfile
7
+ import os
8
+
9
+ # import urllib.parse as urlparse
10
+ from urllib.parse import urlparse
4
11
 
5
12
  # from edsl import Scenario
6
13
 
14
+ import requests
15
+ import re
16
+ import tempfile
17
+ import os
18
+ import atexit
19
+ from urllib.parse import urlparse, parse_qs
20
+
21
+
22
class GoogleDriveDownloader:
    """Download publicly shared Google Drive files into a shared temp directory.

    One ``tempfile.TemporaryDirectory`` is created lazily on the first download
    and cleaned up at interpreter exit (registered with ``atexit``). All
    downloads are written into that directory, so two downloads that use the
    same *filename* overwrite each other.
    """

    # Lazily created TemporaryDirectory shared by all downloads.
    _temp_dir = None
    # Path of the most recent successful download (None until the first one).
    _temp_file_path = None
    # Per-request timeout in seconds so a stalled connection cannot hang forever.
    _request_timeout = 60

    @classmethod
    def fetch_from_drive(cls, url, filename=None):
        """Download the Drive file referenced by *url* and return its local path.

        :param url: Google Drive share URL (``/file/d/<id>/...`` or ``?id=<id>`` form).
        :param filename: Name for the saved file; only its basename is used.
            Defaults to ``"downloaded_file"``.
        :raises ValueError: If no file ID can be extracted from *url*.
        :raises requests.HTTPError: If a download request returns an error status.
        """
        file_id = cls._extract_file_id(url)
        if not file_id:
            raise ValueError("Invalid Google Drive URL")

        download_url = f"https://drive.google.com/uc?export=download&id={file_id}"

        # basename() keeps an absolute or nested name from escaping the temp dir
        # (os.path.join drops the directory entirely when given an absolute path).
        safe_name = os.path.basename(filename or "") or "downloaded_file"

        with requests.Session() as session:
            response = session.get(
                download_url, stream=True, timeout=cls._request_timeout
            )
            response.raise_for_status()

            # Large files come back as a download-warning prompt; the cookie
            # value is sent back as `confirm` to request the actual content.
            for key, value in response.cookies.items():
                if key.startswith("download_warning"):
                    params = {"id": file_id, "confirm": value}
                    response = session.get(
                        download_url,
                        params=params,
                        stream=True,
                        timeout=cls._request_timeout,
                    )
                    response.raise_for_status()
                    break

            if cls._temp_dir is None:
                cls._temp_dir = tempfile.TemporaryDirectory()
                atexit.register(cls._cleanup)

            temp_file_path = os.path.join(cls._temp_dir.name, safe_name)
            with open(temp_file_path, "wb") as f:
                for chunk in response.iter_content(32768):
                    if chunk:  # skip empty chunks
                        f.write(chunk)

        # Record the path only once the file has been fully written.
        cls._temp_file_path = temp_file_path
        return temp_file_path

    @staticmethod
    def _extract_file_id(url):
        """Return the Drive file ID in *url*, or None if none can be found."""
        # '/file/d/<id>/' style links.
        file_id_match = re.search(r"/d/([A-Za-z0-9_-]+)", url)
        if file_id_match:
            return file_id_match.group(1)

        # 'open?id=<id>' / 'uc?id=<id>' style links.
        query_params = parse_qs(urlparse(url).query)
        if "id" in query_params:
            return query_params["id"][0]

        return None

    @classmethod
    def _cleanup(cls):
        """Remove the shared temp directory, if one was created."""
        if cls._temp_dir:
            cls._temp_dir.cleanup()

    @classmethod
    def get_temp_file_path(cls):
        """Return the path of the most recent download, or None."""
        return cls._temp_file_path
91
+
92
+
93
def fetch_and_save_pdf(url, filename):
    """Download *url* into a fresh temporary directory and return the file path.

    The directory comes from ``tempfile.mkdtemp`` rather than a
    ``TemporaryDirectory`` context manager. The caller reads the file after
    this function returns, and a context-managed directory would already be
    deleted by then. The directory is removed at interpreter exit instead.

    :param url: HTTP(S) URL of the document to download.
    :param filename: Name for the saved file; only its basename is used.
    :raises requests.HTTPError: If the server returns an error status.
    """
    import shutil

    response = requests.get(url, timeout=60)
    response.raise_for_status()

    temp_dir = tempfile.mkdtemp()
    # Delete the directory (and the downloaded file) when the interpreter exits.
    atexit.register(shutil.rmtree, temp_dir, ignore_errors=True)

    # basename() keeps an absolute or nested name from escaping temp_dir.
    safe_name = os.path.basename(filename or "") or "downloaded_file"
    temp_file_path = os.path.join(temp_dir, safe_name)

    with open(temp_file_path, "wb") as file:
        file.write(response.content)

    return temp_file_path
115
+
116
+
117
+ # Example usage:
118
+ # url = "https://example.com/sample.pdf"
119
+ # fetch_and_save_pdf(url, "sample.pdf")
120
+
7
121
 
8
122
  class ScenarioListPdfMixin:
9
123
  @classmethod
10
- def from_pdf(cls, filename):
11
- scenarios = list(cls.extract_text_from_pdf(filename))
12
- return cls(scenarios)
124
def from_pdf(cls, filename_or_url, collapse_pages=False):
    """Build scenarios from the text of a PDF given as a local path or a URL.

    URLs are downloaded first. Google Drive links go through
    ``GoogleDriveDownloader.fetch_from_drive``; any other URL goes through
    ``fetch_and_save_pdf``. Text is then extracted with
    ``cls.extract_text_from_pdf``.

    :param filename_or_url: Local PDF path, or an http(s) URL pointing to a PDF.
    :param collapse_pages: If False (default), return ``cls`` wrapping every
        item yielded by ``extract_text_from_pdf`` (presumably one per page).
        If True, return a single scenario: a shallow copy of the first item
        whose ``"text"`` is the concatenation of all items' ``"text"``.
    :raises ValueError: If ``collapse_pages`` is True but nothing was extracted.
    """
    # Resolve the input to a local file path, downloading it if necessary.
    if cls.is_url(filename_or_url):
        if "drive.google.com" in filename_or_url:
            pdf_path = GoogleDriveDownloader.fetch_from_drive(
                filename_or_url, "temp_pdf.pdf"
            )
        else:
            pdf_path = fetch_and_save_pdf(filename_or_url, "temp_pdf.pdf")
    else:
        pdf_path = filename_or_url

    scenarios = list(cls.extract_text_from_pdf(pdf_path))
    if not collapse_pages:
        return cls(scenarios)

    if not scenarios:
        raise ValueError(f"No text could be extracted from PDF: {filename_or_url}")

    # Keep the first scenario's other keys; only "text" is replaced.
    base_scenario = copy.copy(scenarios[0])
    base_scenario["text"] = "".join(scenario["text"] for scenario in scenarios)
    return base_scenario
151
+
152
@staticmethod
def is_url(string):
    """Return True when *string* parses as a URL with both a scheme and a host.

    Input that makes ``urlparse`` raise ``ValueError`` (for example an
    unterminated IPv6 literal such as ``"http://["``) is treated as not a URL.
    """
    try:
        parsed = urlparse(string)
    except ValueError:
        return False
    return bool(parsed.scheme) and bool(parsed.netloc)
13
159
 
14
160
  @classmethod
15
161
  def _from_pdf_to_image(cls, pdf_path, image_format="jpeg"):
@@ -38,7 +184,7 @@ class ScenarioListPdfMixin:
38
184
  scenario = Scenario._from_filepath_image(image_path)
39
185
  scenarios.append(scenario)
40
186
 
41
- print(f"Saved {len(images)} pages as images in {output_folder}")
187
+ # print(f"Saved {len(images)} pages as images in {output_folder}")
42
188
  return cls(scenarios)
43
189
 
44
190
  @staticmethod
edsl/study/Study.py CHANGED
@@ -469,6 +469,38 @@ class Study:
469
469
  coop = Coop()
470
470
  return coop.create(self, description=self.description)
471
471
 
472
def delete_object(self, identifier: Union[str, UUID]):
    """
    Remove one EDSL object from the study.

    A string identifier matches the first entry (in dict order) whose variable
    name equals it or whose hash key equals it. A UUID identifier matches only
    on the hash key, compared as ``str(identifier)``. After a successful
    deletion the internal mapping dicts are rebuilt.

    :param identifier: Variable name or hash (str), or the hash as a UUID.
    :raises ValueError: If nothing in the study matches *identifier*.
    :raises TypeError: If *identifier* is neither a str nor a UUID.
    """
    if isinstance(identifier, str):
        not_found = object()
        match = next(
            (
                obj_hash
                for obj_hash, entry in self.objects.items()
                if entry.variable_name == identifier or obj_hash == identifier
            ),
            not_found,
        )
        if match is not_found:
            raise ValueError(f"No object found with identifier: {identifier}")
        del self.objects[match]
        self._create_mapping_dicts()  # keep the lookup tables in sync
        if self.verbose:
            print(f"Deleted object with identifier: {identifier}")
        return

    if not isinstance(identifier, UUID):
        raise TypeError(
            "Identifier must be either a string (variable name or hash) or a UUID object"
        )

    hash_str = str(identifier)
    if hash_str not in self.objects:
        raise ValueError(f"No object found with hash: {hash_str}")
    del self.objects[hash_str]
    self._create_mapping_dicts()  # keep the lookup tables in sync
    if self.verbose:
        print(f"Deleted object with hash: {hash_str}")
503
+
472
504
  @classmethod
473
505
  def pull(cls, uuid: Optional[Union[str, UUID]] = None, url: Optional[str] = None):
474
506
  """Pull the object from coop."""
edsl/surveys/Rule.py CHANGED
@@ -18,6 +18,7 @@ with a low (-1) priority.
18
18
  """
19
19
 
20
20
  import ast
21
+ import random
21
22
  from typing import Any, Union, List
22
23
 
23
24
  from jinja2 import Template
@@ -254,8 +255,16 @@ class Rule:
254
255
  msg = f"""Exception in evaluation: {e}. The expression is: {self.expression}. The current info env trying to substitute in is: {current_info_env}. After the substition, the expression was: {to_evaluate}."""
255
256
  raise SurveyRuleCannotEvaluateError(msg)
256
257
 
258
+ random_functions = {
259
+ "randint": random.randint,
260
+ "choice": random.choice,
261
+ "random": random.random,
262
+ "uniform": random.uniform,
263
+ # Add any other random functions you want to allow
264
+ }
265
+
257
266
  try:
258
- return EvalWithCompoundTypes().eval(to_evaluate)
267
+ return EvalWithCompoundTypes(functions=random_functions).eval(to_evaluate)
259
268
  except Exception as e:
260
269
  msg = f"""Exception in evaluation: {e}. The expression is: {self.expression}. The current info env trying to substitute in is: {current_info_env}. After the substition, the expression was: {to_evaluate}."""
261
270
  raise SurveyRuleCannotEvaluateError(msg)
@@ -172,7 +172,8 @@ class RuleCollection(UserList):
172
172
 
173
173
  def next_question(self, q_now: int, answers: dict[str, Any]) -> NextQuestion:
174
174
  """Find the next question by index, given the rule collection.
175
- This rule is applied after the question is asked.
175
+
176
+ This rule is applied after the question is answered.
176
177
 
177
178
  :param q_now: The current question index.
178
179
  :param answers: The answers to the survey questions so far, including the current question.
@@ -182,8 +183,17 @@ class RuleCollection(UserList):
182
183
  NextQuestion(next_q=3, num_rules_found=2, expressions_evaluating_to_true=1, priority=1)
183
184
 
184
185
  """
185
- # What rules apply at the current node?
186
-
186
+ # # is this the first question? If it is, we need to check if it should be skipped.
187
+ # if q_now == 0:
188
+ # if self.skip_question_before_running(q_now, answers):
189
+ # return NextQuestion(
190
+ # next_q=q_now + 1,
191
+ # num_rules_found=0,
192
+ # expressions_evaluating_to_true=0,
193
+ # priority=-1,
194
+ # )
195
+
196
+ # breakpoint()
187
197
  expressions_evaluating_to_true = 0
188
198
  next_q = None
189
199
  highest_priority = -2 # start with -2 to 'pick up' the default rule added
@@ -205,6 +215,12 @@ class RuleCollection(UserList):
205
215
  f"No rules found for question {q_now}"
206
216
  )
207
217
 
218
+ # breakpoint()
219
+ ## Now we need to check if the *next question* has any 'before; rules that we should follow
220
+ for rule in self.applicable_rules(next_q, before_rule=True):
221
+ if rule.evaluate(answers): # rule evaluates to True
222
+ return self.next_question(next_q, answers)
223
+
208
224
  return NextQuestion(
209
225
  next_q, num_rules_found, expressions_evaluating_to_true, highest_priority
210
226
  )
edsl/surveys/Survey.py CHANGED
@@ -866,6 +866,7 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
866
866
 
867
867
  def clear_non_default_rules(self) -> Survey:
868
868
  """Remove all non-default rules from the survey.
869
+
869
870
  >>> Survey.example().show_rules()
870
871
  ┏━━━━━━━━━━━┳━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━┓
871
872
  ┃ current_q ┃ expression ┃ next_q ┃ priority ┃ before_rule ┃
@@ -1173,9 +1174,15 @@ class Survey(SurveyExportMixin, SurveyFlowVisualizationMixin, Base):
1173
1174
  Question('multiple_choice', question_name = \"""q0\""", question_text = \"""Do you like school?\""", question_options = ['yes', 'no'])
1174
1175
  >>> i2.send({"q0": "no"})
1175
1176
  Question('multiple_choice', question_name = \"""q1\""", question_text = \"""Why not?\""", question_options = ['killer bees in cafeteria', 'other'])
1177
+
1178
+
1176
1179
  """
1177
1180
  self.answers = {}
1178
1181
  question = self._questions[0]
1182
+ # should the first question be skipped?
1183
+ if self.rule_collection.skip_question_before_running(0, self.answers):
1184
+ question = self.next_question(question, self.answers)
1185
+
1179
1186
  while not question == EndOfSurvey:
1180
1187
  # breakpoint()
1181
1188
  answer = yield question
@@ -31,7 +31,12 @@
31
31
 
32
32
  <tr>
33
33
  <td>Human-readable question</td>
34
- <td>{{ interview.survey.get_question(question).html(scenario = interview.scenario, agent = interview.agent) }}</td>
34
+ <td>{{ interview.survey.get_question(question).html(
35
+ scenario = interview.scenario,
36
+ agent = interview.agent,
37
+ answers = exception_message.answers)
38
+
39
+ }}</td>
35
40
  </tr>
36
41
  <tr>
37
42
  <td>Scenario</td>
@@ -20,6 +20,14 @@ from html import escape
20
20
  from typing import Callable, Union
21
21
 
22
22
 
23
class CustomEncoder(json.JSONEncoder):
    """JSON encoder that falls back to ``str(obj)`` for unsupported types.

    ``json.JSONEncoder.default`` is only consulted for objects the encoder
    cannot serialize natively, and its base implementation always raises
    ``TypeError``. The string fallback is therefore applied directly.
    """

    def default(self, obj):
        # Render anything json cannot encode natively as its string form.
        return str(obj)
29
+
30
+
23
31
  def time_it(func):
24
32
  @wraps(func)
25
33
  def wrapper(*args, **kwargs):
@@ -124,7 +132,7 @@ def data_to_html(data, replace_new_lines=False):
124
132
  from pygments.formatters import HtmlFormatter
125
133
  from IPython.display import HTML
126
134
 
127
- json_str = json.dumps(data, indent=4)
135
+ json_str = json.dumps(data, indent=4, cls=CustomEncoder)
128
136
  formatted_json = highlight(
129
137
  json_str,
130
138
  JsonLexer(),