edsl 0.1.53__py3-none-any.whl → 0.1.55__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions exactly as they appear in their public registries.
- edsl/__init__.py +8 -1
- edsl/__init__original.py +134 -0
- edsl/__version__.py +1 -1
- edsl/agents/agent.py +29 -0
- edsl/agents/agent_list.py +36 -1
- edsl/base/base_class.py +281 -151
- edsl/buckets/__init__.py +8 -3
- edsl/buckets/bucket_collection.py +9 -3
- edsl/buckets/model_buckets.py +4 -2
- edsl/buckets/token_bucket.py +2 -2
- edsl/buckets/token_bucket_client.py +5 -3
- edsl/caching/cache.py +131 -62
- edsl/caching/cache_entry.py +70 -58
- edsl/caching/sql_dict.py +17 -0
- edsl/cli.py +99 -0
- edsl/config/config_class.py +16 -0
- edsl/conversation/__init__.py +31 -0
- edsl/coop/coop.py +276 -242
- edsl/coop/coop_jobs_objects.py +59 -0
- edsl/coop/coop_objects.py +29 -0
- edsl/coop/coop_regular_objects.py +26 -0
- edsl/coop/utils.py +24 -19
- edsl/dataset/dataset.py +338 -101
- edsl/db_list/sqlite_list.py +349 -0
- edsl/inference_services/__init__.py +40 -5
- edsl/inference_services/exceptions.py +11 -0
- edsl/inference_services/services/anthropic_service.py +5 -2
- edsl/inference_services/services/aws_bedrock.py +6 -2
- edsl/inference_services/services/azure_ai.py +6 -2
- edsl/inference_services/services/google_service.py +3 -2
- edsl/inference_services/services/mistral_ai_service.py +6 -2
- edsl/inference_services/services/open_ai_service.py +6 -2
- edsl/inference_services/services/perplexity_service.py +6 -2
- edsl/inference_services/services/test_service.py +105 -7
- edsl/interviews/answering_function.py +167 -59
- edsl/interviews/interview.py +124 -72
- edsl/interviews/interview_task_manager.py +10 -0
- edsl/invigilators/invigilators.py +10 -1
- edsl/jobs/async_interview_runner.py +146 -104
- edsl/jobs/data_structures.py +6 -4
- edsl/jobs/decorators.py +61 -0
- edsl/jobs/fetch_invigilator.py +61 -18
- edsl/jobs/html_table_job_logger.py +14 -2
- edsl/jobs/jobs.py +180 -104
- edsl/jobs/jobs_component_constructor.py +2 -2
- edsl/jobs/jobs_interview_constructor.py +2 -0
- edsl/jobs/jobs_pricing_estimation.py +127 -46
- edsl/jobs/jobs_remote_inference_logger.py +4 -0
- edsl/jobs/jobs_runner_status.py +30 -25
- edsl/jobs/progress_bar_manager.py +79 -0
- edsl/jobs/remote_inference.py +35 -1
- edsl/key_management/key_lookup_builder.py +6 -1
- edsl/language_models/language_model.py +102 -12
- edsl/language_models/model.py +10 -3
- edsl/language_models/price_manager.py +45 -75
- edsl/language_models/registry.py +5 -0
- edsl/language_models/utilities.py +2 -1
- edsl/notebooks/notebook.py +77 -10
- edsl/questions/VALIDATION_README.md +134 -0
- edsl/questions/__init__.py +24 -1
- edsl/questions/exceptions.py +21 -0
- edsl/questions/question_check_box.py +171 -149
- edsl/questions/question_dict.py +243 -51
- edsl/questions/question_multiple_choice_with_other.py +624 -0
- edsl/questions/question_registry.py +2 -1
- edsl/questions/templates/multiple_choice_with_other/__init__.py +0 -0
- edsl/questions/templates/multiple_choice_with_other/answering_instructions.jinja +15 -0
- edsl/questions/templates/multiple_choice_with_other/question_presentation.jinja +17 -0
- edsl/questions/validation_analysis.py +185 -0
- edsl/questions/validation_cli.py +131 -0
- edsl/questions/validation_html_report.py +404 -0
- edsl/questions/validation_logger.py +136 -0
- edsl/results/result.py +63 -16
- edsl/results/results.py +702 -171
- edsl/scenarios/construct_download_link.py +16 -3
- edsl/scenarios/directory_scanner.py +226 -226
- edsl/scenarios/file_methods.py +5 -0
- edsl/scenarios/file_store.py +117 -6
- edsl/scenarios/handlers/__init__.py +5 -1
- edsl/scenarios/handlers/mp4_file_store.py +104 -0
- edsl/scenarios/handlers/webm_file_store.py +104 -0
- edsl/scenarios/scenario.py +120 -101
- edsl/scenarios/scenario_list.py +800 -727
- edsl/scenarios/scenario_list_gc_test.py +146 -0
- edsl/scenarios/scenario_list_memory_test.py +214 -0
- edsl/scenarios/scenario_list_source_refactor.md +35 -0
- edsl/scenarios/scenario_selector.py +5 -4
- edsl/scenarios/scenario_source.py +1990 -0
- edsl/scenarios/tests/test_scenario_list_sources.py +52 -0
- edsl/surveys/survey.py +22 -0
- edsl/tasks/__init__.py +4 -2
- edsl/tasks/task_history.py +198 -36
- edsl/tests/scenarios/test_ScenarioSource.py +51 -0
- edsl/tests/scenarios/test_scenario_list_sources.py +51 -0
- edsl/utilities/__init__.py +2 -1
- edsl/utilities/decorators.py +121 -0
- edsl/utilities/memory_debugger.py +1010 -0
- {edsl-0.1.53.dist-info → edsl-0.1.55.dist-info}/METADATA +52 -76
- {edsl-0.1.53.dist-info → edsl-0.1.55.dist-info}/RECORD +102 -78
- edsl/jobs/jobs_runner_asyncio.py +0 -281
- edsl/language_models/unused/fake_openai_service.py +0 -60
- {edsl-0.1.53.dist-info → edsl-0.1.55.dist-info}/LICENSE +0 -0
- {edsl-0.1.53.dist-info → edsl-0.1.55.dist-info}/WHEEL +0 -0
- {edsl-0.1.53.dist-info → edsl-0.1.55.dist-info}/entry_points.txt +0 -0
edsl/language_models/language_model.py
CHANGED
@@ -365,6 +365,59 @@ class LanguageModel(
         self._api_token = info.api_token
         return self._api_token
 
+    def copy(self) -> "LanguageModel":
+        """Create a deep copy of this language model instance.
+
+        This method creates a completely independent copy of the language model
+        by creating a new instance with the same parameters and copying relevant attributes.
+
+        Returns:
+            LanguageModel: A new language model instance that is functionally identical to this one
+
+        Examples:
+            >>> m1 = LanguageModel.example()
+            >>> m2 = m1.copy()
+            >>> m1 == m2  # Functionally equivalent
+            True
+            >>> id(m1) == id(m2)  # But different objects
+            False
+        """
+        # Create a new instance of the same class with the same parameters
+        try:
+            # For most models, we can instantiate with the saved parameters
+            new_model = self.__class__(**self.parameters)
+
+            # Copy all important instance attributes
+            for key, value in self.__dict__.items():
+                if key not in ("_api_token",) and not key.startswith("__"):
+                    setattr(new_model, key, value)
+
+            return new_model
+        except Exception:
+            # Fallback for dynamically created classes like TestServiceLanguageModel
+            from ..inference_services import default
+
+            # If this is a test model, create a new test model instance
+            if getattr(self, "_inference_service_", "") == "test":
+                service = default.get_service("test")
+                model_class = service.create_model("test")
+                new_model = model_class(**self.parameters)
+
+                # Copy attributes
+                for key, value in self.__dict__.items():
+                    if key not in ("_api_token",) and not key.startswith("__"):
+                        setattr(new_model, key, value)
+
+                return new_model
+
+            # If we can't create the model directly, just return a simple test model
+            # This is a last resort fallback
+            from ..inference_services import get_service
+
+            service = get_service("test")
+            model_class = service.create_model("test")
+            return model_class()
+
     def __getitem__(self, key):
         """Allow dictionary-style access to model attributes.
 
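The doctest above pins down the `copy()` contract. A minimal usage sketch, mirroring that doctest and assuming only edsl's usual package layout:

```python
# Sketch of the copy() contract, mirroring the doctest in the hunk above.
from edsl.language_models import LanguageModel  # assumes edsl is installed

m1 = LanguageModel.example()
m2 = m1.copy()

assert m1 == m2          # functionally equivalent
assert id(m1) != id(m2)  # but an independent object
```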
@@ -509,7 +562,9 @@ class LanguageModel(
         return self.execute_model_call(user_prompt, system_prompt)
 
     @abstractmethod
-    async def async_execute_model_call(
+    async def async_execute_model_call(
+        self, user_prompt: str, system_prompt: str, question_name: Optional[str] = None
+    ):
         """Execute the model call asynchronously.
 
         This abstract method must be implemented by all model subclasses
@@ -518,6 +573,7 @@ class LanguageModel(
         Args:
             user_prompt: The user message or input prompt
             system_prompt: The system message or context
+            question_name: Optional name of the question being asked (primarily used for test models)
 
         Returns:
             Coroutine that resolves to the model response
@@ -529,7 +585,7 @@ class LanguageModel(
         pass
 
     async def remote_async_execute_model_call(
-        self, user_prompt: str, system_prompt: str
+        self, user_prompt: str, system_prompt: str, question_name: Optional[str] = None
     ):
         """Execute the model call remotely through the EDSL Coop service.
 
@@ -540,6 +596,7 @@ class LanguageModel(
         Args:
             user_prompt: The user message or input prompt
             system_prompt: The system message or context
+            question_name: Optional name of the question being asked (primarily used for test models)
 
         Returns:
             Coroutine that resolves to the model response from the remote service
@@ -563,6 +620,7 @@ class LanguageModel(
         Args:
             *args: Positional arguments to pass to async_execute_model_call
             **kwargs: Keyword arguments to pass to async_execute_model_call
                Can include question_name for test models
 
         Returns:
             The model response
@@ -674,9 +732,12 @@ class LanguageModel(
         user_prompt_with_hashes = user_prompt
 
         # Prepare parameters for cache lookup
+        cache_parameters = self.parameters.copy()
+        if self.model == "test":
+            cache_parameters.pop("canned_response", None)
         cache_call_params = {
             "model": str(self.model),
-            "parameters":
+            "parameters": cache_parameters,
             "system_prompt": system_prompt,
             "user_prompt": user_prompt_with_hashes,
             "iteration": iteration,
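The `cache_parameters` copy above exists so that a test model's `canned_response` never changes the cache key. A self-contained sketch of that normalization idea; the `cache_key` helper here is hypothetical, not edsl's actual cache implementation:

```python
import hashlib
import json

def cache_key(model: str, parameters: dict, system_prompt: str, user_prompt: str) -> str:
    # Hypothetical helper: drop volatile test-only parameters before hashing,
    # as the hunk above does with cache_parameters.pop("canned_response", None).
    params = dict(parameters)
    if model == "test":
        params.pop("canned_response", None)
    payload = json.dumps(
        {"model": model, "parameters": params,
         "system_prompt": system_prompt, "user_prompt": user_prompt},
        sort_keys=True,
    )
    return hashlib.sha256(payload.encode()).hexdigest()

# Two test models that differ only in canned_response share one cache entry:
assert cache_key("test", {"canned_response": "A"}, "s", "u") == \
       cache_key("test", {"canned_response": "B"}, "s", "u")
```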
@@ -702,7 +763,9 @@ class LanguageModel(
             "system_prompt": system_prompt,
             "files_list": files_list,
         }
-
+        # Add question_name parameter for test models
+        if self.model == "test" and invigilator:
+            params["question_name"] = invigilator.question.question_name
         # Get timeout from configuration
         from ..config import CONFIG
 
@@ -710,7 +773,6 @@ class LanguageModel(
 
         # Execute the model call with timeout
         response = await asyncio.wait_for(f(**params), timeout=TIMEOUT)
-
         # Store the response in the cache
         new_cache_key = cache.store(
             **cache_call_params, response=response, service=self._inference_service_
@@ -801,7 +863,6 @@ class LanguageModel(
 
         # Create structured input record
         model_inputs = ModelInputs(user_prompt=user_prompt, system_prompt=system_prompt)
-
         # Get model response (using cache if available)
         model_outputs: ModelResponse = (
             await self._async_get_intended_model_call_outcome(**params)
@@ -839,7 +900,7 @@ class LanguageModel(
         # Use the price manager to calculate the actual cost
         from .price_manager import PriceManager
 
-        price_manager = PriceManager()
+        price_manager = PriceManager.get_instance()
 
         return price_manager.calculate_cost(
             inference_service=self._inference_service_,
@@ -868,9 +929,15 @@ class LanguageModel(
         {'model': '...', 'parameters': {'temperature': ..., 'max_tokens': ..., 'top_p': ..., 'frequency_penalty': ..., 'presence_penalty': ..., 'logprobs': False, 'top_logprobs': ...}, 'inference_service': 'openai', 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
         """
         # Build the base dictionary with essential model information
+        parameters = self.parameters.copy()
+
+        # For test models, ensure canned_response is included in serialization
+        if self.model == "test" and hasattr(self, "canned_response"):
+            parameters["canned_response"] = self.canned_response
+
         d = {
             "model": self.model,
-            "parameters":
+            "parameters": parameters,
             "inference_service": self._inference_service_,
         }
 
@@ -908,7 +975,25 @@ class LanguageModel(
             data["model"], service_name=data.get("inference_service", None)
         )
 
-        #
+        # Handle canned_response in parameters for test models
+        if (
+            data["model"] == "test"
+            and "parameters" in data
+            and "canned_response" in data["parameters"]
+        ):
+            # Extract canned_response from parameters to set as a direct attribute
+            canned_response = data["parameters"]["canned_response"]
+            params_copy = data.copy()
+
+            # Direct attribute will be set during initialization
+            # Add it as a top-level parameter for model initialization
+            if isinstance(params_copy, dict) and "parameters" in params_copy:
+                params_copy["canned_response"] = canned_response
+
+            # Create the instance with canned_response as a direct parameter
+            return model_class(**params_copy)
+
+        # For non-test models or test models without canned_response
         return model_class(**data)
 
     def __repr__(self) -> str:
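Together with the `to_dict` hunk further up, this gives test models a full serialization round-trip for `canned_response`. A hedged sketch of the intended flow, assuming edsl's usual `Model("test", canned_response=...)` factory call:

```python
from edsl import Model
from edsl.language_models import LanguageModel

m = Model("test", canned_response="Hello!")
d = m.to_dict()                  # canned_response is copied into d["parameters"]
m2 = LanguageModel.from_dict(d)  # and lifted back out as a constructor kwarg
```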
@@ -994,8 +1079,8 @@ class LanguageModel(
 
         Create a test model that throws exceptions:
 
-        >>> m = LanguageModel.example(test_model=True, canned_response="WOWZA!", throw_exception=True)
-        >>> r = q.by(m).run(cache=False, disable_remote_cache=True, disable_remote_inference=True, print_exceptions=True)
+        >>> m = LanguageModel.example(test_model=True, canned_response="WOWZA!", throw_exception=True)  # doctest: +SKIP
+        >>> r = q.by(m).run(cache=False, disable_remote_cache=True, disable_remote_inference=True, print_exceptions=True)  # doctest: +SKIP
         Exception report saved to ...
         """
         from ..language_models import Model
@@ -1046,7 +1131,12 @@ class LanguageModel(
         ]
 
         # Define a new async_execute_model_call that only reads from cache
-        async def async_execute_model_call(
+        async def async_execute_model_call(
+            self,
+            user_prompt: str,
+            system_prompt: str,
+            question_name: Optional[str] = None,
+        ):
             """Only use cached responses, never making new API calls."""
             cache_call_params = {
                 "model": str(self.model),
edsl/language_models/model.py
CHANGED
@@ -6,8 +6,10 @@ from ..utilities import PrettyList
 from ..config import CONFIG
 from .exceptions import LanguageModelValueError
 
+# Import only what's needed initially to avoid circular imports
 from ..inference_services import (InferenceServicesCollection,
-                                  AvailableModels, InferenceServiceABC, InferenceServiceError
+                                  AvailableModels, InferenceServiceABC, InferenceServiceError)
+# The 'default' import will be imported lazily when needed
 
 from ..enums import InferenceServiceLiteral
 
@@ -20,7 +22,10 @@ def get_model_class(
     registry: Optional[InferenceServicesCollection] = None,
     service_name: Optional[InferenceServiceLiteral] = None,
 ):
-
+    if registry is None:
+        # Import default lazily only when needed
+        from ..inference_services import default as inference_default
+        registry = inference_default
    try:
        factory = registry.create_model_factory(model_name, service_name=service_name)
        return factory
@@ -54,7 +59,9 @@ class Model(metaclass=Meta):
     def get_registry(cls) -> InferenceServicesCollection:
         """Get the current registry or initialize with default if None"""
         if cls._registry is None:
-
+            # Import default lazily only when needed
+            from ..inference_services import default as inference_default
+            cls._registry = inference_default
         return cls._registry
 
     @classmethod
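Both hunks apply the same fix: the `default` registry import moves from module-import time to first use, which breaks the circular-import chain. A standalone illustration of the pattern (module names are stand-ins, not edsl's):

```python
def get_registry(registry=None):
    """Return the active registry, resolving the default lazily."""
    if registry is None:
        # Deferred import: by the time a caller reaches this line, module
        # initialization has finished, so an import cycle cannot trigger.
        # (Stand-in module; edsl imports `..inference_services.default` here.)
        import types as default_registry
        registry = default_registry
    return registry

assert get_registry() is not None
```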
edsl/language_models/price_manager.py
CHANGED
@@ -8,20 +8,42 @@ class PriceManager:
 
     def __new__(cls):
         if cls._instance is None:
-
+            instance = super(PriceManager, cls).__new__(cls)
+            instance._price_lookup = {}  # Instance-specific attribute
+            instance._is_initialized = False
+            cls._instance = instance  # Store the instance directly
+            return instance
         return cls._instance
 
     def __init__(self):
-
+        """Initialize the singleton instance only once."""
         if not self._is_initialized:
             self._is_initialized = True
             self.refresh_prices()
 
-
-
-
+    @classmethod
+    def get_instance(cls):
+        """Get the singleton instance, creating it if necessary."""
+        if cls._instance is None:
+            cls()  # Create the instance if it doesn't exist
+        return cls._instance
+
+    @classmethod
+    def reset(cls):
+        """Reset the singleton instance to clean up resources."""
+        cls._instance = None
+        cls._is_initialized = False
+        cls._price_lookup = {}
 
-
+    def __del__(self):
+        """Ensure proper cleanup when the instance is garbage collected."""
+        try:
+            self._price_lookup = {}  # Clean up resources
+        except:
+            pass  # Ignore any cleanup errors
+
+    def refresh_prices(self) -> None:
+        """Fetch fresh prices and update the internal price lookup."""
         from edsl.coop import Coop
 
         c = Coop()
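Reduced to its essentials, the singleton shape `PriceManager` now implements looks like this (a generic sketch, not edsl code); `get_instance()` is the accessor that `LanguageModel` switches to in the hunk further up:

```python
class Singleton:
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            instance = super().__new__(cls)
            instance._is_initialized = False
            cls._instance = instance
        return cls._instance

    def __init__(self):
        if not self._is_initialized:
            self._is_initialized = True  # one-time setup (refresh_prices() in edsl)

    @classmethod
    def get_instance(cls):
        # Explicit accessor; creates the instance on first use.
        if cls._instance is None:
            cls()
        return cls._instance

    @classmethod
    def reset(cls):
        # Lets tests start from a clean slate.
        cls._instance = None


assert Singleton.get_instance() is Singleton()  # one shared instance
```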
@@ -31,43 +53,18 @@ class PriceManager:
             print(f"Error fetching prices: {str(e)}")
 
     def get_price(self, inference_service: str, model: str) -> Dict:
-        """
-        Get the price information for a specific service and model combination.
-        If no specific price is found, returns a fallback price.
-
-        Args:
-            inference_service (str): The name of the inference service
-            model (str): The model identifier
-
-        Returns:
-            Dict: Price information (either actual or fallback prices)
-        """
+        """Get the price information for a specific service and model."""
         key = (inference_service, model)
         return self._price_lookup.get(key) or self._get_fallback_price(
             inference_service
         )
 
     def get_all_prices(self) -> Dict[Tuple[str, str], Dict]:
-        """
-        Get the complete price lookup dictionary.
-
-        Returns:
-            Dict[Tuple[str, str], Dict]: The complete price lookup dictionary
-        """
+        """Get the complete price lookup dictionary."""
         return self._price_lookup.copy()
 
     def _get_fallback_price(self, inference_service: str) -> Dict:
-        """
-        Get fallback prices for a service.
-        - First fallback: The highest input and output prices for that service from the price lookup.
-        - Second fallback: $1.00 per million tokens (for both input and output).
-
-        Args:
-            inference_service (str): The inference service name
-
-        Returns:
-            Dict: Price information
-        """
+        """Get fallback prices for a service."""
         service_prices = [
             prices
             for (service, _), prices in self._price_lookup.items()
@@ -77,18 +74,12 @@ class PriceManager:
         input_tokens_per_usd = [
             float(p["input"]["one_usd_buys"]) for p in service_prices if "input" in p
         ]
-
-            min_input_tokens = min(input_tokens_per_usd)
-        else:
-            min_input_tokens = 1_000_000
+        min_input_tokens = min(input_tokens_per_usd, default=1_000_000)
 
         output_tokens_per_usd = [
             float(p["output"]["one_usd_buys"]) for p in service_prices if "output" in p
         ]
-
-            min_output_tokens = min(output_tokens_per_usd)
-        else:
-            min_output_tokens = 1_000_000
+        min_output_tokens = min(output_tokens_per_usd, default=1_000_000)
 
         return {
             "input": {"one_usd_buys": min_input_tokens},
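The rewrite leans on the `default=` keyword of the built-in `min()` (available since Python 3.4), which returns the fallback for an empty iterable and so collapses each four-line `if`/`else`:

```python
# Empty iterable: the fallback of 1,000,000 tokens per USD is returned.
prices = []
assert min(prices, default=1_000_000) == 1_000_000

# Non-empty: the smallest tokens-per-USD figure wins, i.e. the most
# expensive rate, matching the old docstring's "highest price" fallback.
prices = [250_000.0, 400_000.0]
assert min(prices, default=1_000_000) == 250_000.0
```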
@@ -103,19 +94,7 @@ class PriceManager:
         input_token_name: str,
         output_token_name: str,
     ) -> Union[float, str]:
-        """
-        Calculate the total cost for a model usage based on input and output tokens.
-
-        Args:
-            inference_service (str): The inference service identifier
-            model (str): The model identifier
-            usage (Dict[str, Union[str, int]]): Dictionary containing token usage information
-            input_token_name (str): Key name for input tokens in the usage dict
-            output_token_name (str): Key name for output tokens in the usage dict
-
-        Returns:
-            Union[float, str]: Total cost if calculation successful, error message string if not
-        """
+        """Calculate the total cost for a model usage."""
         relevant_prices = self.get_price(inference_service, model)
 
         # Extract token counts
@@ -137,31 +116,22 @@ class PriceManager:
             return f"Could not fetch prices from {relevant_prices} - {e}"
 
         # Calculate input cost
-
-
-
-
-
-        except Exception as e:
-            return f"Could not compute input price - {e}."
+        input_cost = (
+            0
+            if inverse_input_price == "infinity"
+            else input_tokens / float(inverse_input_price)
+        )
 
         # Calculate output cost
-
-
-
-
-
-        except Exception as e:
-            return f"Could not compute output price - {e}"
+        output_cost = (
+            0
+            if inverse_output_price == "infinity"
+            else output_tokens / float(inverse_output_price)
+        )
 
         return input_cost + output_cost
 
     @property
     def is_initialized(self) -> bool:
-        """
-        Check if the PriceManager has been initialized.
-
-        Returns:
-            bool: True if initialized, False otherwise
-        """
+        """Check if the PriceManager has been initialized."""
         return self._is_initialized
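Since `one_usd_buys` is a tokens-per-dollar figure, each conditional expression above amounts to `cost = tokens / one_usd_buys`, with `"infinity"` meaning free. A worked example with made-up prices:

```python
def token_cost(tokens: int, one_usd_buys) -> float:
    # Same shape as the conditional expressions in the hunk above.
    return 0 if one_usd_buys == "infinity" else tokens / float(one_usd_buys)

input_cost = token_cost(2_000, 1_000_000)  # 2,000 tokens at 1M/USD -> $0.002
output_cost = token_cost(500, "infinity")  # free tier -> $0.00
assert input_cost + output_cost == 0.002
```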
edsl/language_models/registry.py
CHANGED
@@ -18,6 +18,11 @@ class RegisterLanguageModelsMeta(ABCMeta):
 
     _registry = {}  # Initialize the registry as a dictionary
     REQUIRED_CLASS_ATTRIBUTES = ["_model_", "_parameters_", "_inference_service_"]
+
+    @classmethod
+    def clear_registry(cls):
+        """Clear the registry to prevent memory leaks."""
+        cls._registry = {}
 
     def __init__(cls, name, bases, dct):
         """Register the class in the registry if it has a _model_ attribute."""
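A reduced sketch of the register-at-definition metaclass this method belongs to, showing why `clear_registry()` matters: without it, every class ever defined stays reachable through `_registry` (generic illustration, not edsl code):

```python
class RegisteringMeta(type):
    _registry = {}

    def __init__(cls, name, bases, dct):
        super().__init__(name, bases, dct)
        if dct.get("_model_"):  # only concrete models register themselves
            RegisteringMeta._registry[name] = cls

    @classmethod
    def clear_registry(mcs):
        """Drop all references so test-defined classes can be collected."""
        mcs._registry = {}


class FakeModel(metaclass=RegisteringMeta):
    _model_ = "fake"


assert "FakeModel" in RegisteringMeta._registry
RegisteringMeta.clear_registry()
assert RegisteringMeta._registry == {}
```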
edsl/language_models/utilities.py
CHANGED
@@ -5,6 +5,7 @@ from ..surveys import Survey
 
 from .language_model import LanguageModel
 
+
 def create_survey(num_questions: int, chained: bool = True, take_scenario=False):
     from ..questions import QuestionFreeText
 
@@ -28,7 +29,6 @@ def create_survey(num_questions: int, chained: bool = True, take_scenario=False)
 def create_language_model(
     exception: Exception, fail_at_number: int, never_ending=False
 ):
-
     class LanguageModelFromUtilities(LanguageModel):
         _model_ = "test"
         _parameters_ = {"temperature": 0.5}
@@ -45,6 +45,7 @@ def create_language_model(
             user_prompt: str,
             system_prompt: str,
             files_list: Optional[List[Any]] = None,
+            question_name: Optional[str] = None,
         ) -> dict[str, Any]:
             question_number = int(
                 user_prompt.split("XX")[1]
edsl/notebooks/notebook.py
CHANGED
@@ -2,6 +2,10 @@
 
 from __future__ import annotations
 import json
+import subprocess
+import tempfile
+import os
+import shutil
 from typing import Dict, List, Optional, TYPE_CHECKING
 
 if TYPE_CHECKING:
@@ -17,12 +21,56 @@ class Notebook(Base):
     """
 
     default_name = "notebook"
+
+    @staticmethod
+    def _lint_code(code: str) -> str:
+        """
+        Lint Python code using ruff.
+
+        :param code: The Python code to lint
+        :return: The linted code
+        """
+        try:
+            # Check if ruff is installed
+            if shutil.which("ruff") is None:
+                # If ruff is not installed, return original code
+                return code
+
+            with tempfile.NamedTemporaryFile(mode='w+', suffix='.py', delete=False) as temp_file:
+                temp_file.write(code)
+                temp_file_path = temp_file.name
+
+            # Run ruff to format the code
+            try:
+                result = subprocess.run(
+                    ["ruff", "format", temp_file_path],
+                    check=True,
+                    stdout=subprocess.PIPE,
+                    stderr=subprocess.PIPE
+                )
+
+                # Read the formatted code
+                with open(temp_file_path, 'r') as f:
+                    linted_code = f.read()
+
+                return linted_code
+            except subprocess.CalledProcessError:
+                # If ruff fails, return the original code
+                return code
+        except FileNotFoundError:
+            # If ruff is not installed, return the original code
+            return code
+        finally:
+            # Clean up temporary file
+            if 'temp_file_path' in locals() and os.path.exists(temp_file_path):
+                os.unlink(temp_file_path)
 
     def __init__(
         self,
         path: Optional[str] = None,
         data: Optional[Dict] = None,
         name: Optional[str] = None,
+        lint: bool = True,
     ):
         """
         Initialize a new Notebook.
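A direct usage sketch of the helper above (requires `ruff` on `PATH`; with `ruff` missing or failing, the input is returned unchanged):

```python
from edsl import Notebook  # assumes edsl is installed

messy = "x=1\ny =  [ 1,2 ,3 ]\n"
clean = Notebook._lint_code(messy)
# "x = 1\ny = [1, 2, 3]\n" when ruff is installed; `messy` otherwise
print(clean)
```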
@@ -32,6 +80,7 @@ class Notebook(Base):
         :param path: A filepath from which to load the notebook.
             If no path is provided, assume this code is run in a notebook and try to load the current notebook from file.
         :param name: A name for the Notebook.
+        :param lint: Whether to lint Python code cells using ruff. Defaults to True.
         """
         import nbformat
 
@@ -54,6 +103,16 @@ class Notebook(Base):
             raise NotebookEnvironmentError(
                 "Cannot create a notebook from within itself in this development environment"
             )
+
+        # Store the lint parameter
+        self.lint = lint
+
+        # Apply linting to code cells if enabled
+        if self.lint and self.data and "cells" in self.data:
+            for cell in self.data["cells"]:
+                if cell.get("cell_type") == "code" and "source" in cell:
+                    # Only lint Python code cells
+                    cell["source"] = self._lint_code(cell["source"])
 
         # TODO: perhaps add sanity check function
         # 1. could check if the notebook is a valid notebook
@@ -63,7 +122,7 @@ class Notebook(Base):
         self.name = name or self.default_name
 
     @classmethod
-    def from_script(cls, path: str, name: Optional[str] = None) -> "Notebook":
+    def from_script(cls, path: str, name: Optional[str] = None, lint: bool = True) -> "Notebook":
         import nbformat
 
         # Read the script file
@@ -78,12 +137,12 @@ class Notebook(Base):
         nb.cells.append(first_cell)
 
         # Create a Notebook instance with the notebook data
-        notebook_instance = cls(nb)
+        notebook_instance = cls(data=nb, name=name, lint=lint)
 
         return notebook_instance
 
     @classmethod
-    def from_current_script(cls) -> "Notebook":
+    def from_current_script(cls, lint: bool = True) -> "Notebook":
         import inspect
         import os
 
@@ -93,7 +152,7 @@ class Notebook(Base):
         current_file_path = os.path.abspath(caller_frame[1].filename)
 
         # Use from_script to create the notebook
-        return cls.from_script(current_file_path)
+        return cls.from_script(current_file_path, lint=lint)
 
     def __eq__(self, other):
         """
@@ -114,7 +173,7 @@ class Notebook(Base):
         """
         Serialize to a dictionary.
         """
-        d = {"name": self.name, "data": self.data}
+        d = {"name": self.name, "data": self.data, "lint": self.lint}
         if add_edsl_version:
             from .. import __version__
 
@@ -124,11 +183,17 @@ class Notebook(Base):
 
     @classmethod
     @remove_edsl_version
-    def from_dict(cls, d: Dict) -> "Notebook":
+    def from_dict(cls, d: Dict, lint: bool = None) -> "Notebook":
         """
         Convert a dictionary representation of a Notebook to a Notebook object.
+
+        :param d: Dictionary containing notebook data and name
+        :param lint: Whether to lint Python code cells. If None, uses the value from the dictionary or defaults to True.
+        :return: A new Notebook instance
         """
-
+        # Use the lint parameter from the dictionary if none is provided, otherwise default to True
+        notebook_lint = lint if lint is not None else d.get("lint", True)
+        return cls(data=d["data"], name=d["name"], lint=notebook_lint)
 
     def to_file(self, path: str):
         """
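A round-trip sketch for the new field: `lint` now survives `to_dict()`/`from_dict()`, and an explicit `lint=` argument to `from_dict` overrides the stored value:

```python
from edsl import Notebook

nb = Notebook.example()
d = nb.to_dict()                  # now carries a "lint" key
restored = Notebook.from_dict(d)  # lint comes back from the dictionary
assert restored.lint == nb.lint
```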
@@ -205,11 +270,13 @@ class Notebook(Base):
         return table
 
     @classmethod
-    def example(cls, randomize: bool = False) -> Notebook:
+    def example(cls, randomize: bool = False, lint: bool = True) -> Notebook:
         """
         Returns an example Notebook instance.
 
         :param randomize: If True, adds a random string one of the cells' output.
+        :param lint: Whether to lint Python code cells. Defaults to True.
+        :return: An example Notebook instance
         """
         addition = "" if not randomize else str(uuid4())
         cells = [
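Usage sketch of the threaded-through flag: linting can now be switched off at every construction site (`example`, `from_script`, `from_current_script`, `from_dict`):

```python
from edsl import Notebook

nb = Notebook.example(lint=False)  # cells are kept byte-for-byte as authored
assert nb.lint is False
```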
@@ -238,7 +305,7 @@ class Notebook(Base):
             "nbformat_minor": 4,
             "cells": cells,
         }
-        return cls(data=data)
+        return cls(data=data, lint=lint)
 
     def code(self) -> List[str]:
         """
@@ -246,7 +313,7 @@ class Notebook(Base):
         """
         lines = []
         lines.append("from edsl import Notebook")  # Keep as absolute for code generation
-        lines.append(f'nb = Notebook(data={self.data}, name="""{self.name}""")')
+        lines.append(f'nb = Notebook(data={self.data}, name="""{self.name}""", lint={self.lint})')
         return lines
 
     def to_latex(self, filename: str):