edsl 0.1.62__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
edsl/__init__.py CHANGED
@@ -150,3 +150,69 @@ BaseException.install_exception_hook()
150
150
 
151
151
  # Log the total number of items in __all__ for debugging
152
152
  logger.debug(f"EDSL initialization complete with {len(__all__)} items in __all__")
153
+
154
+
155
def check_for_updates(silent: bool = False) -> "dict | None":
    """
    Check if there's a newer version of EDSL available.

    Thin wrapper that delegates to ``edsl.coop.Coop().check_for_updates``.

    Args:
        silent: If True, don't print any messages to console and don't
            prompt the user.

    Returns:
        dict with version info if an update is available, None otherwise
        (including when the check itself could not be completed).
    """
    from edsl.coop import Coop

    return Coop().check_for_updates(silent=silent)


# Add check_for_updates to exports
__all__.append("check_for_updates")
173
+
174
+
175
# Perform version check on import (non-blocking)
def _check_version_on_import():
    """Start a throttled, non-blocking update check when edsl is imported.

    The check is skipped when the EDSL_DISABLE_VERSION_CHECK environment
    variable is "1", "true" or "yes" (case-insensitive), or when the timestamp
    stored in ``~/.edsl_version_check`` is less than 24 hours old.

    Otherwise the check runs in a daemon thread, so ``import edsl`` never waits
    on the network. The import-time check is deliberately non-interactive:
    - it never calls ``input()``;
    - it never installs anything;
    - it only prints a one-line notice when a newer version is reported.

    Call ``edsl.check_for_updates()`` explicitly for the interactive prompt.
    """
    import os
    import threading
    import time

    # Check if version check is disabled
    if os.getenv("EDSL_DISABLE_VERSION_CHECK", "").lower() in ["1", "true", "yes"]:
        return

    # Check if we've already checked recently (within 24 hours)
    cache_file = os.path.join(os.path.expanduser("~"), ".edsl_version_check")
    try:
        with open(cache_file, "r") as f:
            last_check = float(f.read().strip())
        if time.time() - last_check < 86400:  # 24 hours
            return
    except (OSError, ValueError):
        # Missing, unreadable or corrupt cache file: just run the check.
        pass

    def check_in_background():
        try:
            # Record the attempt first so a failing check is not retried on
            # every import within the next 24 hours.
            with open(cache_file, "w") as f:
                f.write(str(time.time()))

            from edsl.coop import Coop

            # silent=True: never prompt or pip-install as a side effect of import.
            update_info = Coop().check_for_updates(silent=True)
            if update_info:
                print(
                    f"EDSL {update_info['latest_version']} is available "
                    f"(you have {update_info['current_version']}). "
                    f"Update with: {update_info['update_command']}"
                )
        except Exception:
            # Best effort: a failed update check must never break the import.
            pass

    # Daemon thread: does not block the import and does not keep the
    # interpreter alive at exit.
    threading.Thread(target=check_in_background, daemon=True).start()
215
+
216
+
217
# Run the update check once as a side effect of `import edsl`. It is throttled
# to once per 24 hours via ~/.edsl_version_check and can be disabled by setting
# EDSL_DISABLE_VERSION_CHECK to "1", "true" or "yes".
_check_version_on_import()
edsl/__version__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.1.62"
1
+ __version__ = "1.0.0"
edsl/base/base_class.py CHANGED
@@ -516,6 +516,51 @@ class PersistenceMixin:
516
516
  c = Coop()
517
517
  return c.search(cls, query)
518
518
 
519
def clipboard(self):
    """Copy this object's representation to the system clipboard.

    If the object defines a callable ``clipboard_data()``, its return value is
    copied. Otherwise the object is serialized with
    ``to_dict(add_edsl_version=False)`` and copied as indented JSON text.

    The copy uses the platform's clipboard command:
    - ``pbcopy`` on macOS;
    - ``xclip -selection clipboard`` on Linux;
    - ``clip`` on Windows.

    Returns:
        None. A confirmation or an error message is printed instead; clipboard
        failures are reported, not raised.
    """
    import json
    import platform
    import subprocess

    # Prefer an object-specific representation when one is provided.
    clipboard_data = getattr(self, "clipboard_data", None)
    if callable(clipboard_data):
        clipboard_text = clipboard_data()
    else:
        obj_dict = self.to_dict(add_edsl_version=False)
        clipboard_text = json.dumps(obj_dict, indent=2)

    system = platform.system()
    commands = {
        "Darwin": ["pbcopy"],
        "Linux": ["xclip", "-selection", "clipboard"],
        # clip.exe is a regular executable on PATH, so no shell is required.
        "Windows": ["clip"],
    }
    command = commands.get(system)
    if command is None:
        print(f"Clipboard not supported on {system}")
        return

    try:
        # check=True turns a non-zero exit status (e.g. xclip with no X
        # display) into CalledProcessError instead of a false success message.
        subprocess.run(command, input=clipboard_text.encode("utf-8"), check=True)
    except FileNotFoundError:
        print("Clipboard command not found. Please install pbcopy (macOS), xclip (Linux), or use Windows.")
    except Exception as e:
        print(f"Failed to copy to clipboard: {e}")
    else:
        print("Object data copied to clipboard")
563
+
519
564
  def store(self, d: dict, key_name: Optional[str] = None):
520
565
  if key_name is None:
521
566
  index = len(d)
edsl/cli.py CHANGED
@@ -26,91 +26,113 @@ app.add_typer(plugins_app, name="plugins")
26
26
  validation_app = typer.Typer(help="Manage EDSL validation failures")
27
27
  app.add_typer(validation_app, name="validation")
28
28
 
29
+
29
30
@validation_app.command("logs")
def list_validation_logs(
    count: int = typer.Option(10, "--count", "-n", help="Number of logs to show"),
    question_type: Optional[str] = typer.Option(
        None, "--type", "-t", help="Filter by question type"
    ),
    output: Optional[Path] = typer.Option(
        None, "--output", "-o", help="Output file path"
    ),
):
    """List validation failure logs."""
    # NOTE: the docstring above is the command's --help text; keep it stable.
    from .questions.validation_logger import get_validation_failure_logs

    entries = get_validation_failure_logs(n=count)
    if question_type:
        # Keep only entries recorded for the requested question type.
        entries = [e for e in entries if e.get("question_type") == question_type]

    if not output:
        console.print_json(json.dumps(entries, indent=2))
        return

    with open(output, "w") as fh:
        json.dump(entries, fh, indent=2)
    console.print(f"[green]Logs written to {output}[/green]")
50
-
55
+
56
+
51
57
@validation_app.command("clear")
def clear_validation_logs():
    """Clear validation failure logs."""
    # Alias the helper so it does not shadow this command's own name locally.
    from .questions.validation_logger import (
        clear_validation_logs as _clear_stored_logs,
    )

    _clear_stored_logs()
    console.print("[green]Validation logs cleared.[/green]")
58
-
64
+
65
+
59
66
@validation_app.command("stats")
def validation_stats(
    output: Optional[Path] = typer.Option(
        None, "--output", "-o", help="Output file path"
    ),
):
    """Show validation failure statistics."""
    from .questions.validation_analysis import get_validation_failure_stats

    payload = get_validation_failure_stats()

    # Without --output the statistics are pretty-printed to the console.
    if not output:
        console.print_json(json.dumps(payload, indent=2))
        return

    with open(output, "w") as fh:
        json.dump(payload, fh, indent=2)
    console.print(f"[green]Stats written to {output}[/green]")
74
-
83
+
84
+
75
85
@validation_app.command("suggest")
def suggest_improvements(
    question_type: Optional[str] = typer.Option(
        None, "--type", "-t", help="Filter by question type"
    ),
    output: Optional[Path] = typer.Option(
        None, "--output", "-o", help="Output file path"
    ),
):
    """Suggest improvements for fix methods."""
    from .questions.validation_analysis import suggest_fix_improvements

    result = suggest_fix_improvements(question_type=question_type)

    # Without --output the suggestions are pretty-printed to the console.
    if not output:
        console.print_json(json.dumps(result, indent=2))
        return

    with open(output, "w") as fh:
        json.dump(result, fh, indent=2)
    console.print(f"[green]Suggestions written to {output}[/green]")
91
-
105
+
106
+
92
107
@validation_app.command("report")
def generate_report(
    output: Optional[Path] = typer.Option(
        None, "--output", "-o", help="Output file path"
    ),
):
    """Generate a comprehensive validation report."""
    from .questions.validation_analysis import (
        export_improvements_report as _export_report,
    )

    # The exporter returns the path it actually wrote to.
    location = _export_report(output_path=output)
    console.print(f"[green]Report generated at: {location}[/green]")
101
-
118
+
119
+
102
120
  @validation_app.command("html-report")
103
121
  def generate_html_report(
104
- output: Optional[Path] = typer.Option(None, "--output", "-o", help="Output file path"),
105
- open_browser: bool = typer.Option(True, "--open/--no-open", help="Open the report in a browser"),
122
+ output: Optional[Path] = typer.Option(
123
+ None, "--output", "-o", help="Output file path"
124
+ ),
125
+ open_browser: bool = typer.Option(
126
+ True, "--open/--no-open", help="Open the report in a browser"
127
+ ),
106
128
  ):
107
129
  """Generate an HTML validation report and optionally open it in a browser."""
108
130
  from .questions.validation_html_report import generate_html_report
109
131
  import webbrowser
110
-
132
+
111
133
  report_path = generate_html_report(output_path=output)
112
134
  console.print(f"[green]HTML report generated at: {report_path}[/green]")
113
-
135
+
114
136
  if open_browser:
115
137
  try:
116
138
  webbrowser.open(f"file://{report_path}")
@@ -119,15 +141,17 @@ def generate_html_report(
119
141
  console.print(f"[yellow]Could not open browser: {e}[/yellow]")
120
142
  console.print(f"[yellow]Report is available at: {report_path}[/yellow]")
121
143
 
144
+
122
145
@app.callback()
def callback():
    """
    Expected Parrot EDSL Command Line Interface.

    A toolkit for creating, managing, and running surveys with language models.
    """
    # Typer uses the docstring above as the help text of the top-level
    # command; the callback itself has nothing to do.
    return None
130
153
 
154
+
131
155
  @app.command()
132
156
  def version():
133
157
  """Show the EDSL version."""
@@ -135,8 +159,50 @@ def version():
135
159
  version = metadata.version("edsl")
136
160
  console.print(f"[bold cyan]EDSL version:[/bold cyan] {version}")
137
161
  except metadata.PackageNotFoundError:
138
- console.print("[yellow]EDSL package not installed or version not available.[/yellow]")
162
+ console.print(
163
+ "[yellow]EDSL package not installed or version not available.[/yellow]"
164
+ )
165
+
166
+
167
@app.command()
def check_updates():
    """Check for available EDSL updates."""
    # NOTE: the docstring above is the command's --help text; keep it stable.
    # Uses the non-interactive check (silent=True) and prints the result.
    # Any failure is reported on the console instead of raised.
    try:
        from edsl import check_for_updates

        console.print("[cyan]Checking for updates...[/cyan]")
        update_info = check_for_updates(silent=True)

        if update_info:
            console.print("\n[bold yellow]📦 Update Available![/bold yellow]")
            console.print(
                f"[cyan]Current version:[/cyan] {update_info['current_version']}"
            )
            console.print(
                f"[green]Latest version:[/green] {update_info['latest_version']}"
            )
            # Coop.check_for_updates reports the server notice under
            # "guid_message"; "update_info" (the key this command used to
            # read) is still accepted as a fallback.
            message = update_info.get("guid_message") or update_info.get(
                "update_info"
            )
            if message:
                console.print(f"[cyan]Update info:[/cyan] {message}")
            if update_info.get("force_update"):
                console.print("[bold red]⚠️ This update is required.[/bold red]")
            console.print(f"\n[bold]To update:[/bold] {update_info['update_command']}")
        else:
            console.print(
                "[green]✓ You are running the latest version of EDSL![/green]"
            )
    except Exception as e:
        console.print(f"[red]Error checking for updates: {str(e)}[/red]")
193
+
139
194
 
140
195
def main():
    """Run the EDSL command line interface.

    When EDSL_CHECK_UPDATES_ON_STARTUP is "1", "true" or "yes" (any case), an
    interactive update check runs first; any error from it is ignored so the
    CLI always starts.
    """
    import os

    startup_check = os.getenv("EDSL_CHECK_UPDATES_ON_STARTUP", "").lower()
    if startup_check in {"1", "true", "yes"}:
        try:
            from edsl import check_for_updates

            check_for_updates(silent=False)
        except Exception:
            pass  # never let an update check prevent the CLI from running

    app()
@@ -111,6 +111,10 @@ CONFIG_MAP = {
111
111
  "default": "10", # Change to a very low threshold (10 bytes) to test SQLite offloading
112
112
  "info": "This config var determines the memory threshold in bytes before Results' SQLList offloads data to SQLite.",
113
113
  },
114
+ "EDSL_USE_SQLITE_FOR_SCENARIO_LIST": {
115
+ "default": "False",
116
+ "info": "This config var determines whether to use SQLite for ScenarioList instances.",
117
+ },
114
118
  }
115
119
 
116
120
 
edsl/coop/coop.py CHANGED
@@ -273,6 +273,118 @@ class Coop(CoopFunctionsMixin):
273
273
 
274
274
  return user_stable_version < server_stable_version
275
275
 
276
def check_for_updates(self, silent: bool = False) -> Optional[dict]:
    """
    Check if there's a newer version of EDSL available.

    Sends a GET request to the server's ``version/updates`` endpoint (5 second
    timeout) and compares the version it reports with the installed one.

    Args:
        silent: If True, don't print any messages to console and don't prompt.

    Returns:
        None if the installed version is up to date or the check could not be
        completed (network, server and parsing errors are swallowed).

        If an update is available, a dict with these keys:
        - ``current_version``: the installed version;
        - ``latest_version``;
        - ``guid_message``;
        - ``force_update``;
        - ``update_command``.

        The dict is returned even if printing or prompting fails afterwards.
    """
    try:
        response = self._send_server_request(
            uri="version/updates", method="GET", timeout=5
        )
        data = response.json()

        # "current" is the newest version according to the server.
        latest_version = data.get("current")
        # Optional free-text notice from the server; may be missing or null.
        guid_message = data.get("guid_message") or ""
        force_update = "force update" in str(guid_message).lower()

        if not latest_version or not self._user_version_is_outdated(
            user_version_str=self._edsl_version,
            server_version_str=latest_version,
        ):
            return None
    except Exception:
        # Best effort: an update check must never break the caller.
        return None

    update_data = {
        "current_version": self._edsl_version,
        "latest_version": latest_version,
        "guid_message": guid_message,
        "force_update": force_update,
        "update_command": "pip install --upgrade edsl",
    }

    if not silent:
        try:
            self._prompt_for_update(update_data)
        except Exception:
            # Printing/prompting is secondary (e.g. a console that cannot
            # encode the emoji); the caller still gets the update information.
            pass

    return update_data


def _prompt_for_update(self, update_data: dict) -> None:
    """Print the update banner and offer to upgrade EDSL in place.

    Answering "", "y" or "yes" runs ``-m pip install --upgrade edsl`` with the
    current interpreter (``sys.executable``); any other answer skips. EOF or
    Ctrl-C while prompting or installing also skips.
    """
    import subprocess
    import sys

    print("\n" + "=" * 60)
    print("📦 EDSL Update Available!")
    print(f"Your version: {update_data['current_version']}")
    print(f"Latest version: {update_data['latest_version']}")

    # Display the guid message if present
    if update_data["guid_message"]:
        print(f"\n{update_data['guid_message']}")

    if update_data["force_update"]:
        prompt_message = "\n⚠️ FORCE UPDATE REQUIRED - Do you want to update now? [Y/n] "
    else:
        prompt_message = "\nDo you want to update now? [Y/n] "
    print(prompt_message, end="")

    try:
        user_input = input().strip().lower()
        if user_input in ["", "y", "yes"]:
            print("\nUpdating EDSL...")
            try:
                result = subprocess.run(
                    [sys.executable, "-m", "pip", "install", "--upgrade", "edsl"],
                    capture_output=True,
                    text=True,
                )
                if result.returncode == 0:
                    print("✅ Update successful! Please restart your application.")
                else:
                    print(f"❌ Update failed: {result.stderr}")
                    print("You can try updating manually with: pip install --upgrade edsl")
            except Exception as e:
                print(f"❌ Update failed: {str(e)}")
                print("You can try updating manually with: pip install --upgrade edsl")
        else:
            print("\nUpdate skipped. You can update later with: pip install --upgrade edsl")

        print("=" * 60 + "\n")

    except (EOFError, KeyboardInterrupt):
        # Non-interactive stdin or user interrupt: treat as "skip".
        print("\nUpdate skipped. You can update later with: pip install --upgrade edsl")
        print("=" * 60 + "\n")
387
+
276
388
  def _resolve_server_response(
277
389
  self, response: requests.Response, check_api_key: bool = True
278
390
  ) -> None:
@@ -280,18 +392,35 @@ class Coop(CoopFunctionsMixin):
280
392
  Check the response from the server and raise errors as appropriate.
281
393
  """
282
394
  # Get EDSL version from header
283
- # breakpoint()
284
- # Commented out as currently unused
285
- # server_edsl_version = response.headers.get("X-EDSL-Version")
286
-
287
- # if server_edsl_version:
288
- # if self._user_version_is_outdated(
289
- # user_version_str=self._edsl_version,
290
- # server_version_str=server_edsl_version,
291
- # ):
292
- # print(
293
- # "Please upgrade your EDSL version to access our latest features. Open your terminal and run `pip install --upgrade edsl`"
294
- # )
395
+ server_edsl_version = response.headers.get("X-EDSL-Version")
396
+
397
+ if server_edsl_version:
398
+ if self._user_version_is_outdated(
399
+ user_version_str=self._edsl_version,
400
+ server_version_str=server_edsl_version,
401
+ ):
402
+ # Get additional info from server if available
403
+ update_info = response.headers.get("X-EDSL-Update-Info", "")
404
+
405
+ print("\n" + "=" * 60)
406
+ print("📦 EDSL Update Available!")
407
+ print(f"Your version: {self._edsl_version}")
408
+ print(f"Latest version: {server_edsl_version}")
409
+ if update_info:
410
+ print(f"Update info: {update_info}")
411
+ print(
412
+ "\nYour version is out of date - can we update to latest version? [Y/n]"
413
+ )
414
+
415
+ try:
416
+ user_input = input().strip().lower()
417
+ if user_input in ["", "y", "yes"]:
418
+ print("To update, run: pip install --upgrade edsl")
419
+ print("=" * 60 + "\n")
420
+ except (EOFError, KeyboardInterrupt):
421
+ # Handle non-interactive environments
422
+ print("To update, run: pip install --upgrade edsl")
423
+ print("=" * 60 + "\n")
295
424
  if response.status_code >= 400:
296
425
  try:
297
426
  message = str(response.json().get("detail"))
@@ -1562,7 +1691,6 @@ class Coop(CoopFunctionsMixin):
1562
1691
 
1563
1692
  # The job has been offloaded to GCS
1564
1693
  if include_json_string and json_string == "offloaded":
1565
-
1566
1694
  # Attempt to fetch JSON string from GCS
1567
1695
  response = self._send_server_request(
1568
1696
  uri="api/v0/remote-inference/pull",
edsl/dataset/dataset.py CHANGED
@@ -1017,6 +1017,53 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
1017
1017
  # Save the document
1018
1018
  doc.save(output_file)
1019
1019
 
1020
def unique(self) -> "Dataset":
    """
    Remove duplicate rows from the dataset.

    Rows are compared across all columns. The first occurrence of each
    distinct row is kept, and the original row order is preserved. Unhashable
    cell values (e.g. lists or dicts) are compared by type name and ``repr``
    instead of raising ``TypeError``.

    Returns:
        A new Dataset with duplicate rows removed.

    Examples:
        >>> d = Dataset([{'a': [1, 2, 3, 1]}, {'b': [4, 5, 6, 4]}])
        >>> d.unique().data
        [{'a': [1, 2, 3]}, {'b': [4, 5, 6]}]

        >>> d = Dataset([{'x': ['a', 'b', 'a']}, {'y': [1, 2, 1]}])
        >>> d.unique().data
        [{'x': ['a', 'b']}, {'y': [1, 2]}]

        >>> # Dataset with a single column
        >>> Dataset([{'value': [1, 2, 3, 2, 1, 3]}]).unique().data
        [{'value': [1, 2, 3]}]

        >>> # Unhashable cell values are supported
        >>> Dataset([{'tags': [['x'], ['y'], ['x']]}]).unique().data
        [{'tags': [['x'], ['y']]}]
    """

    def _hashable(value):
        # Set members must be hashable; fall back to a type-tagged repr so that
        # e.g. the list [1] and the string "[1]" remain distinct.
        try:
            hash(value)
        except TypeError:
            return (type(value).__name__, repr(value))
        return value

    # Each entry is a single-key dict mapping a column name to its values;
    # extract (name, values) once instead of per cell.
    columns = [next(iter(entry.items())) for entry in self.data]

    seen = set()
    keep = []
    for index, row in enumerate(zip(*(values for _, values in columns))):
        key = tuple(_hashable(value) for value in row)
        if key not in seen:
            seen.add(key)
            keep.append(index)

    return Dataset([{name: [values[i] for i in keep]} for name, values in columns])
1066
+
1020
1067
  def expand(self, field: str, number_field: bool = False) -> "Dataset":
1021
1068
  """
1022
1069
  Expand a field containing lists into multiple rows.
@@ -1086,47 +1133,6 @@ class Dataset(UserList, DatasetOperationsMixin, PersistenceMixin, HashingMixin):
1086
1133
 
1087
1134
  return Dataset(new_data)
1088
1135
 
1089
- def unique(self) -> "Dataset":
1090
- """Return a new dataset with only unique observations.
1091
-
1092
- Examples:
1093
- >>> d = Dataset([{'a': [1, 2, 2, 3]}, {'b': [4, 5, 5, 6]}])
1094
- >>> d.unique().data
1095
- [{'a': [1, 2, 3]}, {'b': [4, 5, 6]}]
1096
-
1097
- >>> d = Dataset([{'x': ['a', 'a', 'b']}, {'y': [1, 1, 2]}])
1098
- >>> d.unique().data
1099
- [{'x': ['a', 'b']}, {'y': [1, 2]}]
1100
- """
1101
- # Get all column names and values
1102
- headers, data = self._tabular()
1103
-
1104
- # Create a list of unique rows
1105
- unique_rows = []
1106
- seen = set()
1107
-
1108
- for row in data:
1109
- # Convert the row to a hashable representation for comparison
1110
- # We need to handle potential unhashable types
1111
- try:
1112
- row_key = tuple(map(lambda x: str(x) if isinstance(x, (list, dict)) else x, row))
1113
- if row_key not in seen:
1114
- seen.add(row_key)
1115
- unique_rows.append(row)
1116
- except:
1117
- # Fallback for complex objects: compare based on string representation
1118
- row_str = str(row)
1119
- if row_str not in seen:
1120
- seen.add(row_str)
1121
- unique_rows.append(row)
1122
-
1123
- # Create a new dataset with unique combinations
1124
- new_data = []
1125
- for i, header in enumerate(headers):
1126
- values = [row[i] for row in unique_rows]
1127
- new_data.append({header: values})
1128
-
1129
- return Dataset(new_data)
1130
1136
 
1131
1137
 
1132
1138
  if __name__ == "__main__":