edsl 0.1.54__py3-none-any.whl → 0.1.55__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- edsl/__init__.py +8 -1
- edsl/__init__original.py +134 -0
- edsl/__version__.py +1 -1
- edsl/agents/agent.py +29 -0
- edsl/agents/agent_list.py +36 -1
- edsl/base/base_class.py +281 -151
- edsl/buckets/__init__.py +8 -3
- edsl/buckets/bucket_collection.py +9 -3
- edsl/buckets/model_buckets.py +4 -2
- edsl/buckets/token_bucket.py +2 -2
- edsl/buckets/token_bucket_client.py +5 -3
- edsl/caching/cache.py +131 -62
- edsl/caching/cache_entry.py +70 -58
- edsl/caching/sql_dict.py +17 -0
- edsl/cli.py +99 -0
- edsl/config/config_class.py +16 -0
- edsl/conversation/__init__.py +31 -0
- edsl/coop/coop.py +276 -242
- edsl/coop/coop_jobs_objects.py +59 -0
- edsl/coop/coop_objects.py +29 -0
- edsl/coop/coop_regular_objects.py +26 -0
- edsl/coop/utils.py +24 -19
- edsl/dataset/dataset.py +338 -101
- edsl/db_list/sqlite_list.py +349 -0
- edsl/inference_services/__init__.py +40 -5
- edsl/inference_services/exceptions.py +11 -0
- edsl/inference_services/services/anthropic_service.py +5 -2
- edsl/inference_services/services/aws_bedrock.py +6 -2
- edsl/inference_services/services/azure_ai.py +6 -2
- edsl/inference_services/services/google_service.py +3 -2
- edsl/inference_services/services/mistral_ai_service.py +6 -2
- edsl/inference_services/services/open_ai_service.py +6 -2
- edsl/inference_services/services/perplexity_service.py +6 -2
- edsl/inference_services/services/test_service.py +94 -5
- edsl/interviews/answering_function.py +167 -59
- edsl/interviews/interview.py +124 -72
- edsl/interviews/interview_task_manager.py +10 -0
- edsl/invigilators/invigilators.py +9 -0
- edsl/jobs/async_interview_runner.py +146 -104
- edsl/jobs/data_structures.py +6 -4
- edsl/jobs/decorators.py +61 -0
- edsl/jobs/fetch_invigilator.py +61 -18
- edsl/jobs/html_table_job_logger.py +14 -2
- edsl/jobs/jobs.py +180 -104
- edsl/jobs/jobs_component_constructor.py +2 -2
- edsl/jobs/jobs_interview_constructor.py +2 -0
- edsl/jobs/jobs_remote_inference_logger.py +4 -0
- edsl/jobs/jobs_runner_status.py +30 -25
- edsl/jobs/progress_bar_manager.py +79 -0
- edsl/jobs/remote_inference.py +35 -1
- edsl/key_management/key_lookup_builder.py +6 -1
- edsl/language_models/language_model.py +86 -6
- edsl/language_models/model.py +10 -3
- edsl/language_models/price_manager.py +45 -75
- edsl/language_models/registry.py +5 -0
- edsl/notebooks/notebook.py +77 -10
- edsl/questions/VALIDATION_README.md +134 -0
- edsl/questions/__init__.py +24 -1
- edsl/questions/exceptions.py +21 -0
- edsl/questions/question_dict.py +201 -16
- edsl/questions/question_multiple_choice_with_other.py +624 -0
- edsl/questions/question_registry.py +2 -1
- edsl/questions/templates/multiple_choice_with_other/__init__.py +0 -0
- edsl/questions/templates/multiple_choice_with_other/answering_instructions.jinja +15 -0
- edsl/questions/templates/multiple_choice_with_other/question_presentation.jinja +17 -0
- edsl/questions/validation_analysis.py +185 -0
- edsl/questions/validation_cli.py +131 -0
- edsl/questions/validation_html_report.py +404 -0
- edsl/questions/validation_logger.py +136 -0
- edsl/results/result.py +63 -16
- edsl/results/results.py +702 -171
- edsl/scenarios/construct_download_link.py +16 -3
- edsl/scenarios/directory_scanner.py +226 -226
- edsl/scenarios/file_methods.py +5 -0
- edsl/scenarios/file_store.py +117 -6
- edsl/scenarios/handlers/__init__.py +5 -1
- edsl/scenarios/handlers/mp4_file_store.py +104 -0
- edsl/scenarios/handlers/webm_file_store.py +104 -0
- edsl/scenarios/scenario.py +120 -101
- edsl/scenarios/scenario_list.py +800 -727
- edsl/scenarios/scenario_list_gc_test.py +146 -0
- edsl/scenarios/scenario_list_memory_test.py +214 -0
- edsl/scenarios/scenario_list_source_refactor.md +35 -0
- edsl/scenarios/scenario_selector.py +5 -4
- edsl/scenarios/scenario_source.py +1990 -0
- edsl/scenarios/tests/test_scenario_list_sources.py +52 -0
- edsl/surveys/survey.py +22 -0
- edsl/tasks/__init__.py +4 -2
- edsl/tasks/task_history.py +198 -36
- edsl/tests/scenarios/test_ScenarioSource.py +51 -0
- edsl/tests/scenarios/test_scenario_list_sources.py +51 -0
- edsl/utilities/__init__.py +2 -1
- edsl/utilities/decorators.py +121 -0
- edsl/utilities/memory_debugger.py +1010 -0
- {edsl-0.1.54.dist-info → edsl-0.1.55.dist-info}/METADATA +51 -76
- {edsl-0.1.54.dist-info → edsl-0.1.55.dist-info}/RECORD +99 -75
- edsl/jobs/jobs_runner_asyncio.py +0 -281
- edsl/language_models/unused/fake_openai_service.py +0 -60
- {edsl-0.1.54.dist-info → edsl-0.1.55.dist-info}/LICENSE +0 -0
- {edsl-0.1.54.dist-info → edsl-0.1.55.dist-info}/WHEEL +0 -0
- {edsl-0.1.54.dist-info → edsl-0.1.55.dist-info}/entry_points.txt +0 -0
edsl/jobs/remote_inference.py
CHANGED
@@ -1,3 +1,4 @@
+import re
 from typing import Optional, Union, Literal, TYPE_CHECKING, NewType, Callable, Any
 from dataclasses import dataclass
 from ..coop import CoopServerResponseError
@@ -112,13 +113,18 @@ class JobsRemoteInferenceHandler:
         )
         logger.add_info("job_uuid", job_uuid)
 
+        remote_inference_url = self.remote_inference_url
+        if "localhost" in remote_inference_url:
+            remote_inference_url = remote_inference_url.replace("8000", "1234")
         logger.update(
-            f"Job details are available at your Coop account. [Go to Remote Inference page]({self.remote_inference_url})",
+            f"Job details are available at your Coop account. [Go to Remote Inference page]({remote_inference_url})",
             status=JobsStatus.RUNNING,
         )
         progress_bar_url = (
             f"{self.expected_parrot_url}/home/remote-job-progress/{job_uuid}"
         )
+        if "localhost" in progress_bar_url:
+            progress_bar_url = progress_bar_url.replace("8000", "1234")
         logger.add_info("progress_bar_url", progress_bar_url)
         logger.update(
             f"View job progress [here]({progress_bar_url})", status=JobsStatus.RUNNING
@@ -200,10 +206,35 @@ class JobsRemoteInferenceHandler:
             status=JobsStatus.FAILED,
         )
 
+    def _handle_partially_failed_job_interview_details(
+        self, job_info: RemoteJobInfo, remote_job_data: RemoteInferenceResponse
+    ) -> None:
+        "Extracts the interview details from the remote job data."
+        try:
+            # Job details is a string of the form "64 out of 1,758 interviews failed"
+            job_details = remote_job_data.get("latest_failure_description")
+
+            text_without_commas = job_details.replace(",", "")
+
+            # Find all numbers in the text
+            numbers = [int(num) for num in re.findall(r"\d+", text_without_commas)]
+
+            failed = numbers[0]
+            total = numbers[1]
+            completed = total - failed
+
+            job_info.logger.add_info("completed_interviews", completed)
+            job_info.logger.add_info("failed_interviews", failed)
+        # This is mainly helpful metadata, and any errors here should not stop the code
+        except:
+            pass
+
     def _handle_partially_failed_job(
         self, job_info: RemoteJobInfo, remote_job_data: RemoteInferenceResponse
    ) -> None:
         "Handles a partially failed job by logging the error and updating the job status."
+        self._handle_partially_failed_job_interview_details(job_info, remote_job_data)
+
         latest_error_report_url = remote_job_data.get("latest_error_report_url")
 
         if latest_error_report_url:
@@ -244,6 +275,8 @@ class JobsRemoteInferenceHandler:
         job_info.logger.add_info("results_uuid", results_uuid)
         results = object_fetcher(results_uuid, expected_object_type="results")
         results_url = remote_job_data.get("results_url")
+        if "localhost" in results_url:
+            results_url = results_url.replace("8000", "1234")
         job_info.logger.add_info("results_url", results_url)
 
         if job_status == "completed":
@@ -256,6 +289,7 @@ class JobsRemoteInferenceHandler:
             f"View partial results [here]({results_url})",
             status=JobsStatus.PARTIALLY_FAILED,
         )
+
         results.job_uuid = job_info.job_uuid
         results.results_uuid = results_uuid
         return results
edsl/key_management/key_lookup_builder.py
CHANGED
@@ -2,6 +2,7 @@ from typing import Optional, TYPE_CHECKING
 import os
 from functools import lru_cache
 import textwrap
+import requests
 
 if TYPE_CHECKING:
     from ..coop import Coop
@@ -255,7 +256,11 @@ class KeyLookupBuilder:
         return dict(list(os.environ.items()))
 
     def _coop_key_value_pairs(self):
-        return dict(list(self.coop.fetch_rate_limit_config_vars().items()))
+        try:
+            return dict(list(self.coop.fetch_rate_limit_config_vars().items()))
+        except requests.ConnectionError:
+            # If connection fails, return empty dict instead of raising error
+            return {}
 
     def _config_key_value_pairs(self):
         from ..config import CONFIG
edsl/language_models/language_model.py
CHANGED
@@ -365,6 +365,59 @@ class LanguageModel(
         self._api_token = info.api_token
         return self._api_token
 
+    def copy(self) -> "LanguageModel":
+        """Create a deep copy of this language model instance.
+
+        This method creates a completely independent copy of the language model
+        by creating a new instance with the same parameters and copying relevant attributes.
+
+        Returns:
+            LanguageModel: A new language model instance that is functionally identical to this one
+
+        Examples:
+            >>> m1 = LanguageModel.example()
+            >>> m2 = m1.copy()
+            >>> m1 == m2  # Functionally equivalent
+            True
+            >>> id(m1) == id(m2)  # But different objects
+            False
+        """
+        # Create a new instance of the same class with the same parameters
+        try:
+            # For most models, we can instantiate with the saved parameters
+            new_model = self.__class__(**self.parameters)
+
+            # Copy all important instance attributes
+            for key, value in self.__dict__.items():
+                if key not in ("_api_token",) and not key.startswith("__"):
+                    setattr(new_model, key, value)
+
+            return new_model
+        except Exception:
+            # Fallback for dynamically created classes like TestServiceLanguageModel
+            from ..inference_services import default
+
+            # If this is a test model, create a new test model instance
+            if getattr(self, "_inference_service_", "") == "test":
+                service = default.get_service("test")
+                model_class = service.create_model("test")
+                new_model = model_class(**self.parameters)
+
+                # Copy attributes
+                for key, value in self.__dict__.items():
+                    if key not in ("_api_token",) and not key.startswith("__"):
+                        setattr(new_model, key, value)
+
+                return new_model
+
+            # If we can't create the model directly, just return a simple test model
+            # This is a last resort fallback
+            from ..inference_services import get_service
+
+            service = get_service("test")
+            model_class = service.create_model("test")
+            return model_class()
+
     def __getitem__(self, key):
         """Allow dictionary-style access to model attributes.
 
@@ -679,9 +732,12 @@ class LanguageModel(
         user_prompt_with_hashes = user_prompt
 
         # Prepare parameters for cache lookup
+        cache_parameters = self.parameters.copy()
+        if self.model == "test":
+            cache_parameters.pop("canned_response", None)
         cache_call_params = {
             "model": str(self.model),
-            "parameters": self.parameters,
+            "parameters": cache_parameters,
             "system_prompt": system_prompt,
             "user_prompt": user_prompt_with_hashes,
             "iteration": iteration,
@@ -844,7 +900,7 @@ class LanguageModel(
         # Use the price manager to calculate the actual cost
         from .price_manager import PriceManager
 
-        price_manager = PriceManager()
+        price_manager = PriceManager.get_instance()
 
         return price_manager.calculate_cost(
             inference_service=self._inference_service_,
@@ -873,9 +929,15 @@ class LanguageModel(
        {'model': '...', 'parameters': {'temperature': ..., 'max_tokens': ..., 'top_p': ..., 'frequency_penalty': ..., 'presence_penalty': ..., 'logprobs': False, 'top_logprobs': ...}, 'inference_service': 'openai', 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
         """
         # Build the base dictionary with essential model information
+        parameters = self.parameters.copy()
+
+        # For test models, ensure canned_response is included in serialization
+        if self.model == "test" and hasattr(self, "canned_response"):
+            parameters["canned_response"] = self.canned_response
+
         d = {
             "model": self.model,
-            "parameters": self.parameters,
+            "parameters": parameters,
             "inference_service": self._inference_service_,
         }
 
@@ -913,7 +975,25 @@ class LanguageModel(
             data["model"], service_name=data.get("inference_service", None)
         )
 
-        #
+        # Handle canned_response in parameters for test models
+        if (
+            data["model"] == "test"
+            and "parameters" in data
+            and "canned_response" in data["parameters"]
+        ):
+            # Extract canned_response from parameters to set as a direct attribute
+            canned_response = data["parameters"]["canned_response"]
+            params_copy = data.copy()
+
+            # Direct attribute will be set during initialization
+            # Add it as a top-level parameter for model initialization
+            if isinstance(params_copy, dict) and "parameters" in params_copy:
+                params_copy["canned_response"] = canned_response
+
+            # Create the instance with canned_response as a direct parameter
+            return model_class(**params_copy)
+
+        # For non-test models or test models without canned_response
         return model_class(**data)
 
     def __repr__(self) -> str:
@@ -999,8 +1079,8 @@ class LanguageModel(
 
         Create a test model that throws exceptions:
 
-        >>> m = LanguageModel.example(test_model=True, canned_response="WOWZA!", throw_exception=True)
-        >>> r = q.by(m).run(cache=False, disable_remote_cache=True, disable_remote_inference=True, print_exceptions=True)
+        >>> m = LanguageModel.example(test_model=True, canned_response="WOWZA!", throw_exception=True)  # doctest: +SKIP
+        >>> r = q.by(m).run(cache=False, disable_remote_cache=True, disable_remote_inference=True, print_exceptions=True)  # doctest: +SKIP
         Exception report saved to ...
         """
         from ..language_models import Model
edsl/language_models/model.py
CHANGED
@@ -6,8 +6,10 @@ from ..utilities import PrettyList
 from ..config import CONFIG
 from .exceptions import LanguageModelValueError
 
+# Import only what's needed initially to avoid circular imports
 from ..inference_services import (InferenceServicesCollection,
-                                  AvailableModels, InferenceServiceABC, InferenceServiceError, default)
+                                  AvailableModels, InferenceServiceABC, InferenceServiceError)
+# The 'default' import will be imported lazily when needed
 
 from ..enums import InferenceServiceLiteral
 
@@ -20,7 +22,10 @@ def get_model_class(
     registry: Optional[InferenceServicesCollection] = None,
     service_name: Optional[InferenceServiceLiteral] = None,
 ):
-    registry = registry or default
+    if registry is None:
+        # Import default lazily only when needed
+        from ..inference_services import default as inference_default
+        registry = inference_default
     try:
         factory = registry.create_model_factory(model_name, service_name=service_name)
         return factory
@@ -54,7 +59,9 @@ class Model(metaclass=Meta):
     def get_registry(cls) -> InferenceServicesCollection:
         """Get the current registry or initialize with default if None"""
         if cls._registry is None:
-            cls._registry = default
+            # Import default lazily only when needed
+            from ..inference_services import default as inference_default
+            cls._registry = inference_default
         return cls._registry
 
     @classmethod
edsl/language_models/price_manager.py
CHANGED
@@ -8,20 +8,42 @@ class PriceManager:
 
     def __new__(cls):
         if cls._instance is None:
-            cls._instance = super(PriceManager, cls).__new__(cls)
+            instance = super(PriceManager, cls).__new__(cls)
+            instance._price_lookup = {}  # Instance-specific attribute
+            instance._is_initialized = False
+            cls._instance = instance  # Store the instance directly
+            return instance
         return cls._instance
 
     def __init__(self):
-        """Initialize the singleton instance."""
+        """Initialize the singleton instance only once."""
         if not self._is_initialized:
             self._is_initialized = True
             self.refresh_prices()
 
-    def refresh_prices(self) -> None:
-        """
-        Fetch fresh prices and update the internal price lookup.
+    @classmethod
+    def get_instance(cls):
+        """Get the singleton instance, creating it if necessary."""
+        if cls._instance is None:
+            cls()  # Create the instance if it doesn't exist
+        return cls._instance
+
+    @classmethod
+    def reset(cls):
+        """Reset the singleton instance to clean up resources."""
+        cls._instance = None
+        cls._is_initialized = False
+        cls._price_lookup = {}
 
-        """
+    def __del__(self):
+        """Ensure proper cleanup when the instance is garbage collected."""
+        try:
+            self._price_lookup = {}  # Clean up resources
+        except:
+            pass  # Ignore any cleanup errors
+
+    def refresh_prices(self) -> None:
+        """Fetch fresh prices and update the internal price lookup."""
         from edsl.coop import Coop
 
         c = Coop()
@@ -31,43 +53,18 @@ class PriceManager:
             print(f"Error fetching prices: {str(e)}")
 
     def get_price(self, inference_service: str, model: str) -> Dict:
-        """
-        Get the price information for a specific service and model combination.
-        If no specific price is found, returns a fallback price.
-
-        Args:
-            inference_service (str): The name of the inference service
-            model (str): The model identifier
-
-        Returns:
-            Dict: Price information (either actual or fallback prices)
-        """
+        """Get the price information for a specific service and model."""
         key = (inference_service, model)
         return self._price_lookup.get(key) or self._get_fallback_price(
             inference_service
         )
 
     def get_all_prices(self) -> Dict[Tuple[str, str], Dict]:
-        """
-        Get the complete price lookup dictionary.
-
-        Returns:
-            Dict[Tuple[str, str], Dict]: The complete price lookup dictionary
-        """
+        """Get the complete price lookup dictionary."""
         return self._price_lookup.copy()
 
     def _get_fallback_price(self, inference_service: str) -> Dict:
-        """
-        Get fallback prices for a service.
-        - First fallback: The highest input and output prices for that service from the price lookup.
-        - Second fallback: $1.00 per million tokens (for both input and output).
-
-        Args:
-            inference_service (str): The inference service name
-
-        Returns:
-            Dict: Price information
-        """
+        """Get fallback prices for a service."""
         service_prices = [
             prices
             for (service, _), prices in self._price_lookup.items()
@@ -77,18 +74,12 @@ class PriceManager:
         input_tokens_per_usd = [
             float(p["input"]["one_usd_buys"]) for p in service_prices if "input" in p
         ]
-        if input_tokens_per_usd:
-            min_input_tokens = min(input_tokens_per_usd)
-        else:
-            min_input_tokens = 1_000_000
+        min_input_tokens = min(input_tokens_per_usd, default=1_000_000)
 
         output_tokens_per_usd = [
             float(p["output"]["one_usd_buys"]) for p in service_prices if "output" in p
         ]
-        if output_tokens_per_usd:
-            min_output_tokens = min(output_tokens_per_usd)
-        else:
-            min_output_tokens = 1_000_000
+        min_output_tokens = min(output_tokens_per_usd, default=1_000_000)
 
         return {
             "input": {"one_usd_buys": min_input_tokens},
@@ -103,19 +94,7 @@ class PriceManager:
         input_token_name: str,
         output_token_name: str,
     ) -> Union[float, str]:
-        """
-        Calculate the total cost for a model usage based on input and output tokens.
-
-        Args:
-            inference_service (str): The inference service identifier
-            model (str): The model identifier
-            usage (Dict[str, Union[str, int]]): Dictionary containing token usage information
-            input_token_name (str): Key name for input tokens in the usage dict
-            output_token_name (str): Key name for output tokens in the usage dict
-
-        Returns:
-            Union[float, str]: Total cost if calculation successful, error message string if not
-        """
+        """Calculate the total cost for a model usage."""
         relevant_prices = self.get_price(inference_service, model)
 
         # Extract token counts
@@ -137,31 +116,22 @@ class PriceManager:
             return f"Could not fetch prices from {relevant_prices} - {e}"
 
         # Calculate input cost
-        if inverse_input_price == "infinity":
-            input_cost = 0
-        else:
-            try:
-                input_cost = input_tokens / float(inverse_input_price)
-            except Exception as e:
-                return f"Could not compute input price - {e}."
+        input_cost = (
+            0
+            if inverse_input_price == "infinity"
+            else input_tokens / float(inverse_input_price)
+        )
 
         # Calculate output cost
-        if inverse_output_price == "infinity":
-            output_cost = 0
-        else:
-            try:
-                output_cost = output_tokens / float(inverse_output_price)
-            except Exception as e:
-                return f"Could not compute output price - {e}"
+        output_cost = (
+            0
+            if inverse_output_price == "infinity"
+            else output_tokens / float(inverse_output_price)
+        )
 
         return input_cost + output_cost
 
     @property
     def is_initialized(self) -> bool:
-        """
-        Check if the PriceManager has been initialized.
-
-        Returns:
-            bool: True if initialized, False otherwise
-        """
+        """Check if the PriceManager has been initialized."""
         return self._is_initialized
edsl/language_models/registry.py
CHANGED
@@ -18,6 +18,11 @@ class RegisterLanguageModelsMeta(ABCMeta):
 
     _registry = {}  # Initialize the registry as a dictionary
     REQUIRED_CLASS_ATTRIBUTES = ["_model_", "_parameters_", "_inference_service_"]
+
+    @classmethod
+    def clear_registry(cls):
+        """Clear the registry to prevent memory leaks."""
+        cls._registry = {}
 
     def __init__(cls, name, bases, dct):
         """Register the class in the registry if it has a _model_ attribute."""
edsl/notebooks/notebook.py
CHANGED
@@ -2,6 +2,10 @@
 
 from __future__ import annotations
 import json
+import subprocess
+import tempfile
+import os
+import shutil
 from typing import Dict, List, Optional, TYPE_CHECKING
 
 if TYPE_CHECKING:
@@ -17,12 +21,56 @@ class Notebook(Base):
     """
 
     default_name = "notebook"
+
+    @staticmethod
+    def _lint_code(code: str) -> str:
+        """
+        Lint Python code using ruff.
+
+        :param code: The Python code to lint
+        :return: The linted code
+        """
+        try:
+            # Check if ruff is installed
+            if shutil.which("ruff") is None:
+                # If ruff is not installed, return original code
+                return code
+
+            with tempfile.NamedTemporaryFile(mode='w+', suffix='.py', delete=False) as temp_file:
+                temp_file.write(code)
+                temp_file_path = temp_file.name
+
+            # Run ruff to format the code
+            try:
+                result = subprocess.run(
+                    ["ruff", "format", temp_file_path],
+                    check=True,
+                    stdout=subprocess.PIPE,
+                    stderr=subprocess.PIPE
+                )
+
+                # Read the formatted code
+                with open(temp_file_path, 'r') as f:
+                    linted_code = f.read()
+
+                return linted_code
+            except subprocess.CalledProcessError:
+                # If ruff fails, return the original code
+                return code
+        except FileNotFoundError:
+            # If ruff is not installed, return the original code
+            return code
+        finally:
+            # Clean up temporary file
+            if 'temp_file_path' in locals() and os.path.exists(temp_file_path):
+                os.unlink(temp_file_path)
 
     def __init__(
         self,
         path: Optional[str] = None,
         data: Optional[Dict] = None,
         name: Optional[str] = None,
+        lint: bool = True,
     ):
         """
         Initialize a new Notebook.
@@ -32,6 +80,7 @@ class Notebook(Base):
         :param path: A filepath from which to load the notebook.
             If no path is provided, assume this code is run in a notebook and try to load the current notebook from file.
         :param name: A name for the Notebook.
+        :param lint: Whether to lint Python code cells using ruff. Defaults to True.
         """
         import nbformat
 
@@ -54,6 +103,16 @@ class Notebook(Base):
             raise NotebookEnvironmentError(
                 "Cannot create a notebook from within itself in this development environment"
             )
+
+        # Store the lint parameter
+        self.lint = lint
+
+        # Apply linting to code cells if enabled
+        if self.lint and self.data and "cells" in self.data:
+            for cell in self.data["cells"]:
+                if cell.get("cell_type") == "code" and "source" in cell:
+                    # Only lint Python code cells
+                    cell["source"] = self._lint_code(cell["source"])
 
         # TODO: perhaps add sanity check function
         # 1. could check if the notebook is a valid notebook
@@ -63,7 +122,7 @@ class Notebook(Base):
         self.name = name or self.default_name
 
     @classmethod
-    def from_script(cls, path: str, name: Optional[str] = None) -> "Notebook":
+    def from_script(cls, path: str, name: Optional[str] = None, lint: bool = True) -> "Notebook":
         import nbformat
 
         # Read the script file
@@ -78,12 +137,12 @@ class Notebook(Base):
         nb.cells.append(first_cell)
 
         # Create a Notebook instance with the notebook data
-        notebook_instance = cls(nb)
+        notebook_instance = cls(data=nb, name=name, lint=lint)
 
         return notebook_instance
 
     @classmethod
-    def from_current_script(cls) -> "Notebook":
+    def from_current_script(cls, lint: bool = True) -> "Notebook":
         import inspect
         import os
 
@@ -93,7 +152,7 @@ class Notebook(Base):
         current_file_path = os.path.abspath(caller_frame[1].filename)
 
         # Use from_script to create the notebook
-        return cls.from_script(current_file_path)
+        return cls.from_script(current_file_path, lint=lint)
 
     def __eq__(self, other):
         """
@@ -114,7 +173,7 @@ class Notebook(Base):
         """
         Serialize to a dictionary.
         """
-        d = {"name": self.name, "data": self.data}
+        d = {"name": self.name, "data": self.data, "lint": self.lint}
         if add_edsl_version:
             from .. import __version__
 
@@ -124,11 +183,17 @@ class Notebook(Base):
 
     @classmethod
     @remove_edsl_version
-    def from_dict(cls, d: Dict) -> "Notebook":
+    def from_dict(cls, d: Dict, lint: bool = None) -> "Notebook":
         """
         Convert a dictionary representation of a Notebook to a Notebook object.
+
+        :param d: Dictionary containing notebook data and name
+        :param lint: Whether to lint Python code cells. If None, uses the value from the dictionary or defaults to True.
+        :return: A new Notebook instance
         """
-        return cls(data=d["data"], name=d["name"])
+        # Use the lint parameter from the dictionary if none is provided, otherwise default to True
+        notebook_lint = lint if lint is not None else d.get("lint", True)
+        return cls(data=d["data"], name=d["name"], lint=notebook_lint)
 
     def to_file(self, path: str):
         """
@@ -205,11 +270,13 @@ class Notebook(Base):
         return table
 
     @classmethod
-    def example(cls, randomize: bool = False) -> Notebook:
+    def example(cls, randomize: bool = False, lint: bool = True) -> Notebook:
         """
         Returns an example Notebook instance.
 
         :param randomize: If True, adds a random string one of the cells' output.
+        :param lint: Whether to lint Python code cells. Defaults to True.
+        :return: An example Notebook instance
         """
         addition = "" if not randomize else str(uuid4())
         cells = [
@@ -238,7 +305,7 @@ class Notebook(Base):
             "nbformat_minor": 4,
             "cells": cells,
         }
-        return cls(data=data)
+        return cls(data=data, lint=lint)
 
     def code(self) -> List[str]:
         """
@@ -246,7 +313,7 @@ class Notebook(Base):
         """
         lines = []
         lines.append("from edsl import Notebook")  # Keep as absolute for code generation
-        lines.append(f'nb = Notebook(data={self.data}, name="""{self.name}""")')
+        lines.append(f'nb = Notebook(data={self.data}, name="""{self.name}""", lint={self.lint})')
         return lines
 
     def to_latex(self, filename: str):
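A usage sketch of the new lint flag (assuming an installed edsl environment; ruff is only invoked if it is on PATH):

from edsl import Notebook

nb = Notebook.example()             # code cells pass through `ruff format` when available
raw = Notebook.example(lint=False)  # opt out to keep cell sources byte-for-byte

d = nb.to_dict()                    # "lint" is now part of the serialized form
nb2 = Notebook.from_dict(d)         # restores the stored lint setting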
|