edsl 0.1.27.dev2__py3-none-any.whl → 0.1.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- edsl/Base.py +107 -30
- edsl/BaseDiff.py +260 -0
- edsl/__init__.py +25 -21
- edsl/__version__.py +1 -1
- edsl/agents/Agent.py +103 -46
- edsl/agents/AgentList.py +97 -13
- edsl/agents/Invigilator.py +23 -10
- edsl/agents/InvigilatorBase.py +19 -14
- edsl/agents/PromptConstructionMixin.py +342 -100
- edsl/agents/descriptors.py +5 -2
- edsl/base/Base.py +289 -0
- edsl/config.py +2 -1
- edsl/conjure/AgentConstructionMixin.py +152 -0
- edsl/conjure/Conjure.py +56 -0
- edsl/conjure/InputData.py +659 -0
- edsl/conjure/InputDataCSV.py +48 -0
- edsl/conjure/InputDataMixinQuestionStats.py +182 -0
- edsl/conjure/InputDataPyRead.py +91 -0
- edsl/conjure/InputDataSPSS.py +8 -0
- edsl/conjure/InputDataStata.py +8 -0
- edsl/conjure/QuestionOptionMixin.py +76 -0
- edsl/conjure/QuestionTypeMixin.py +23 -0
- edsl/conjure/RawQuestion.py +65 -0
- edsl/conjure/SurveyResponses.py +7 -0
- edsl/conjure/__init__.py +9 -4
- edsl/conjure/examples/placeholder.txt +0 -0
- edsl/conjure/naming_utilities.py +263 -0
- edsl/conjure/utilities.py +165 -28
- edsl/conversation/Conversation.py +238 -0
- edsl/conversation/car_buying.py +58 -0
- edsl/conversation/mug_negotiation.py +81 -0
- edsl/conversation/next_speaker_utilities.py +93 -0
- edsl/coop/coop.py +337 -121
- edsl/coop/utils.py +56 -70
- edsl/data/Cache.py +74 -22
- edsl/data/CacheHandler.py +10 -9
- edsl/data/SQLiteDict.py +11 -3
- edsl/inference_services/AnthropicService.py +1 -0
- edsl/inference_services/DeepInfraService.py +20 -13
- edsl/inference_services/GoogleService.py +7 -1
- edsl/inference_services/InferenceServicesCollection.py +33 -7
- edsl/inference_services/OpenAIService.py +17 -10
- edsl/inference_services/models_available_cache.py +69 -0
- edsl/inference_services/rate_limits_cache.py +25 -0
- edsl/inference_services/write_available.py +10 -0
- edsl/jobs/Answers.py +15 -1
- edsl/jobs/Jobs.py +322 -73
- edsl/jobs/buckets/BucketCollection.py +9 -3
- edsl/jobs/buckets/ModelBuckets.py +4 -2
- edsl/jobs/buckets/TokenBucket.py +1 -2
- edsl/jobs/interviews/Interview.py +7 -10
- edsl/jobs/interviews/InterviewStatusMixin.py +3 -3
- edsl/jobs/interviews/InterviewTaskBuildingMixin.py +39 -20
- edsl/jobs/interviews/retry_management.py +4 -4
- edsl/jobs/runners/JobsRunnerAsyncio.py +103 -65
- edsl/jobs/runners/JobsRunnerStatusData.py +3 -3
- edsl/jobs/tasks/QuestionTaskCreator.py +4 -2
- edsl/jobs/tasks/TaskHistory.py +4 -3
- edsl/language_models/LanguageModel.py +42 -55
- edsl/language_models/ModelList.py +96 -0
- edsl/language_models/registry.py +14 -0
- edsl/language_models/repair.py +97 -25
- edsl/notebooks/Notebook.py +157 -32
- edsl/prompts/Prompt.py +31 -19
- edsl/questions/QuestionBase.py +145 -23
- edsl/questions/QuestionBudget.py +5 -6
- edsl/questions/QuestionCheckBox.py +7 -3
- edsl/questions/QuestionExtract.py +5 -3
- edsl/questions/QuestionFreeText.py +3 -3
- edsl/questions/QuestionFunctional.py +0 -3
- edsl/questions/QuestionList.py +3 -4
- edsl/questions/QuestionMultipleChoice.py +16 -8
- edsl/questions/QuestionNumerical.py +4 -3
- edsl/questions/QuestionRank.py +5 -3
- edsl/questions/__init__.py +4 -3
- edsl/questions/descriptors.py +9 -4
- edsl/questions/question_registry.py +27 -31
- edsl/questions/settings.py +1 -1
- edsl/results/Dataset.py +31 -0
- edsl/results/DatasetExportMixin.py +493 -0
- edsl/results/Result.py +42 -82
- edsl/results/Results.py +178 -66
- edsl/results/ResultsDBMixin.py +10 -9
- edsl/results/ResultsExportMixin.py +23 -507
- edsl/results/ResultsGGMixin.py +3 -3
- edsl/results/ResultsToolsMixin.py +9 -9
- edsl/scenarios/FileStore.py +140 -0
- edsl/scenarios/Scenario.py +59 -6
- edsl/scenarios/ScenarioList.py +138 -52
- edsl/scenarios/ScenarioListExportMixin.py +32 -0
- edsl/scenarios/ScenarioListPdfMixin.py +2 -1
- edsl/scenarios/__init__.py +1 -0
- edsl/study/ObjectEntry.py +173 -0
- edsl/study/ProofOfWork.py +113 -0
- edsl/study/SnapShot.py +73 -0
- edsl/study/Study.py +498 -0
- edsl/study/__init__.py +4 -0
- edsl/surveys/MemoryPlan.py +11 -4
- edsl/surveys/Survey.py +124 -37
- edsl/surveys/SurveyExportMixin.py +25 -5
- edsl/surveys/SurveyFlowVisualizationMixin.py +6 -4
- edsl/tools/plotting.py +4 -2
- edsl/utilities/__init__.py +21 -20
- edsl/utilities/gcp_bucket/__init__.py +0 -0
- edsl/utilities/gcp_bucket/cloud_storage.py +96 -0
- edsl/utilities/gcp_bucket/simple_example.py +9 -0
- edsl/utilities/interface.py +90 -73
- edsl/utilities/repair_functions.py +28 -0
- edsl/utilities/utilities.py +59 -6
- {edsl-0.1.27.dev2.dist-info → edsl-0.1.29.dist-info}/METADATA +42 -15
- edsl-0.1.29.dist-info/RECORD +203 -0
- edsl/conjure/RawResponseColumn.py +0 -327
- edsl/conjure/SurveyBuilder.py +0 -308
- edsl/conjure/SurveyBuilderCSV.py +0 -78
- edsl/conjure/SurveyBuilderSPSS.py +0 -118
- edsl/data/RemoteDict.py +0 -103
- edsl-0.1.27.dev2.dist-info/RECORD +0 -172
- {edsl-0.1.27.dev2.dist-info → edsl-0.1.29.dist-info}/LICENSE +0 -0
- {edsl-0.1.27.dev2.dist-info → edsl-0.1.29.dist-info}/WHEEL +0 -0
@@ -7,26 +7,18 @@ import asyncio
|
|
7
7
|
import json
|
8
8
|
import time
|
9
9
|
import os
|
10
|
-
|
11
10
|
from typing import Coroutine, Any, Callable, Type, List, get_type_hints
|
12
|
-
|
13
|
-
from abc import ABC, abstractmethod, ABCMeta
|
14
|
-
|
15
|
-
from rich.table import Table
|
11
|
+
from abc import ABC, abstractmethod
|
16
12
|
|
17
13
|
from edsl.config import CONFIG
|
18
14
|
|
19
|
-
from edsl.utilities.utilities import clean_json
|
20
15
|
from edsl.utilities.decorators import sync_wrapper, jupyter_nb_handler
|
21
16
|
from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
|
17
|
+
|
22
18
|
from edsl.language_models.repair import repair
|
23
|
-
from edsl.exceptions.language_models import LanguageModelAttributeTypeError
|
24
19
|
from edsl.enums import InferenceServiceType
|
25
20
|
from edsl.Base import RichPrintingMixin, PersistenceMixin
|
26
|
-
from edsl.data.Cache import Cache
|
27
21
|
from edsl.enums import service_to_api_keyname
|
28
|
-
|
29
|
-
|
30
22
|
from edsl.exceptions import MissingAPIKeyError
|
31
23
|
from edsl.language_models.RegisterLanguageModelsMeta import RegisterLanguageModelsMeta
|
32
24
|
|
@@ -142,7 +134,7 @@ class LanguageModel(
|
|
142
134
|
def has_valid_api_key(self) -> bool:
|
143
135
|
"""Check if the model has a valid API key.
|
144
136
|
|
145
|
-
>>> LanguageModel.example().has_valid_api_key()
|
137
|
+
>>> LanguageModel.example().has_valid_api_key() # doctest: +SKIP
|
146
138
|
True
|
147
139
|
|
148
140
|
This method is used to check if the model has a valid API key.
|
@@ -159,7 +151,9 @@ class LanguageModel(
|
|
159
151
|
|
160
152
|
def __hash__(self):
    """Return a hash computed from the model's dictionary form.

    Defining this allows a model to be used as a key in a dictionary.
    """
    # Imported when the method runs rather than when the module is imported.
    from edsl.utilities.utilities import dict_hash

    serialized = self.to_dict()
    return dict_hash(serialized)
|
163
157
|
|
164
158
|
def __eq__(self, other):
|
165
159
|
"""Check if two models are the same.
|
@@ -207,8 +201,8 @@ class LanguageModel(
|
|
207
201
|
"""Model's tokens-per-minute limit.
|
208
202
|
|
209
203
|
>>> m = LanguageModel.example()
|
210
|
-
>>> m.TPM
|
211
|
-
|
204
|
+
>>> m.TPM > 0
|
205
|
+
True
|
212
206
|
"""
|
213
207
|
self._set_rate_limits()
|
214
208
|
return self._safety_factor * self.__rate_limits["tpm"]
|
@@ -285,36 +279,14 @@ class LanguageModel(
|
|
285
279
|
"""
|
286
280
|
raise NotImplementedError
|
287
281
|
|
288
|
-
def _update_response_with_tracking(
|
289
|
-
self, response: dict, start_time: int, cached_response=False, cache_key=None
|
290
|
-
):
|
291
|
-
"""Update the response with tracking information.
|
292
|
-
|
293
|
-
>>> m = LanguageModel.example()
|
294
|
-
>>> m._update_response_with_tracking(response={"response": "Hello"}, start_time=0, cached_response=False, cache_key=None)
|
295
|
-
{'response': 'Hello', 'elapsed_time': ..., 'timestamp': ..., 'cached_response': False, 'cache_key': None}
|
296
|
-
|
297
|
-
|
298
|
-
"""
|
299
|
-
end_time = time.time()
|
300
|
-
response.update(
|
301
|
-
{
|
302
|
-
"elapsed_time": end_time - start_time,
|
303
|
-
"timestamp": end_time,
|
304
|
-
"cached_response": cached_response,
|
305
|
-
"cache_key": cache_key,
|
306
|
-
}
|
307
|
-
)
|
308
|
-
return response
|
309
|
-
|
310
282
|
async def async_get_raw_response(
|
311
283
|
self,
|
312
284
|
user_prompt: str,
|
313
285
|
system_prompt: str,
|
314
|
-
cache,
|
286
|
+
cache: "Cache",
|
315
287
|
iteration: int = 0,
|
316
288
|
encoded_image=None,
|
317
|
-
) -> dict
|
289
|
+
) -> tuple[dict, bool, str]:
|
318
290
|
"""Handle caching of responses.
|
319
291
|
|
320
292
|
:param user_prompt: The user's prompt.
|
@@ -322,8 +294,7 @@ class LanguageModel(
|
|
322
294
|
:param iteration: The iteration number.
|
323
295
|
:param cache: The cache to use.
|
324
296
|
|
325
|
-
If the cache isn't being used, it just returns a 'fresh' call to the LLM
|
326
|
-
but appends some tracking information to the response (using the _update_response_with_tracking method).
|
297
|
+
If the cache isn't being used, it just returns a 'fresh' call to the LLM.
|
327
298
|
But if cache is being used, it first checks the database to see if the response is already there.
|
328
299
|
If it is, it returns the cached response, but again appends some tracking information.
|
329
300
|
If it isn't, it calls the LLM, saves the response to the database, and returns the response with tracking information.
|
@@ -334,7 +305,7 @@ class LanguageModel(
|
|
334
305
|
>>> from edsl import Cache
|
335
306
|
>>> m = LanguageModel.example(test_model = True)
|
336
307
|
>>> m.get_raw_response(user_prompt = "Hello", system_prompt = "hello", cache = Cache())
|
337
|
-
{'message': '{"answer": "Hello world"}',
|
308
|
+
({'message': '{"answer": "Hello world"}'}, False, '24ff6ac2bc2f1729f817f261e0792577')
|
338
309
|
"""
|
339
310
|
start_time = time.time()
|
340
311
|
|
@@ -379,12 +350,7 @@ class LanguageModel(
|
|
379
350
|
)
|
380
351
|
cache_used = False
|
381
352
|
|
382
|
-
return
|
383
|
-
response=response,
|
384
|
-
start_time=start_time,
|
385
|
-
cached_response=cache_used,
|
386
|
-
cache_key=cache_key,
|
387
|
-
)
|
353
|
+
return response, cache_used, cache_key
|
388
354
|
|
389
355
|
get_raw_response = sync_wrapper(async_get_raw_response)
|
390
356
|
|
@@ -427,14 +393,18 @@ class LanguageModel(
|
|
427
393
|
if encoded_image:
|
428
394
|
params["encoded_image"] = encoded_image
|
429
395
|
|
430
|
-
raw_response = await self.async_get_raw_response(
|
396
|
+
raw_response, cache_used, cache_key = await self.async_get_raw_response(
|
397
|
+
**params
|
398
|
+
)
|
431
399
|
response = self.parse_response(raw_response)
|
432
400
|
|
433
401
|
try:
|
434
402
|
dict_response = json.loads(response)
|
435
403
|
except json.JSONDecodeError as e:
|
436
404
|
# TODO: Turn into logs to generate issues
|
437
|
-
dict_response, success = await repair(
|
405
|
+
dict_response, success = await repair(
|
406
|
+
bad_json=response, error_message=str(e), cache=cache
|
407
|
+
)
|
438
408
|
if not success:
|
439
409
|
raise Exception(
|
440
410
|
f"""Even the repair failed. The error was: {e}. The response was: {response}."""
|
@@ -442,7 +412,8 @@ class LanguageModel(
|
|
442
412
|
|
443
413
|
dict_response.update(
|
444
414
|
{
|
445
|
-
"
|
415
|
+
"cached_used": cache_used,
|
416
|
+
"cache_key": cache_key,
|
446
417
|
"usage": raw_response.get("usage", {}),
|
447
418
|
"raw_model_response": raw_response,
|
448
419
|
}
|
@@ -458,15 +429,18 @@ class LanguageModel(
|
|
458
429
|
#######################
|
459
430
|
# SERIALIZATION METHODS
|
460
431
|
#######################
|
432
|
+
def _to_dict(self) -> dict[str, Any]:
|
433
|
+
return {"model": self.model, "parameters": self.parameters}
|
434
|
+
|
461
435
|
@add_edsl_version
def to_dict(self) -> dict[str, Any]:
    """Serialize the model to a dictionary.

    The payload is produced by ``_to_dict``; the ``add_edsl_version`` decorator
    wraps this method (the example below shows the version and class-name keys
    that appear in the final result).

    >>> m = LanguageModel.example()
    >>> m.to_dict()
    {'model': 'gpt-4-1106-preview', 'parameters': {'temperature': 0.5, 'max_tokens': 1000, 'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0, 'logprobs': False, 'top_logprobs': 3}, 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
    """
    payload = self._to_dict()
    return payload
|
470
444
|
|
471
445
|
@classmethod
|
472
446
|
@remove_edsl_version
|
@@ -508,6 +482,8 @@ class LanguageModel(
|
|
508
482
|
|
509
483
|
def rich_print(self):
|
510
484
|
"""Display an object as a table."""
|
485
|
+
from rich.table import Table
|
486
|
+
|
511
487
|
table = Table(title="Language Model")
|
512
488
|
table.add_column("Attribute", style="bold")
|
513
489
|
table.add_column("Value")
|
@@ -519,8 +495,18 @@ class LanguageModel(
|
|
519
495
|
return table
|
520
496
|
|
521
497
|
@classmethod
|
522
|
-
def example(cls, test_model=False):
|
523
|
-
"""Return a default instance of the class.
|
498
|
+
def example(cls, test_model: bool = False, canned_response: str = "Hello world"):
|
499
|
+
"""Return a default instance of the class.
|
500
|
+
|
501
|
+
>>> from edsl.language_models import LanguageModel
|
502
|
+
>>> m = LanguageModel.example(test_model = True, canned_response = "WOWZA!")
|
503
|
+
>>> isinstance(m, LanguageModel)
|
504
|
+
True
|
505
|
+
>>> from edsl import QuestionFreeText
|
506
|
+
>>> q = QuestionFreeText(question_text = "What is your name?", question_name = 'example')
|
507
|
+
>>> q.by(m).run(cache = False).select('example').first()
|
508
|
+
'WOWZA!'
|
509
|
+
"""
|
524
510
|
from edsl import Model
|
525
511
|
|
526
512
|
class TestLanguageModelGood(LanguageModel):
|
@@ -533,7 +519,8 @@ class LanguageModel(
|
|
533
519
|
self, user_prompt: str, system_prompt: str
|
534
520
|
) -> dict[str, Any]:
|
535
521
|
await asyncio.sleep(0.1)
|
536
|
-
return {"message": """{"answer": "Hello world"}"""}
|
522
|
+
# return {"message": """{"answer": "Hello, world"}"""}
|
523
|
+
return {"message": f'{{"answer": "{canned_response}"}}'}
|
537
524
|
|
538
525
|
def parse_response(self, raw_response: dict[str, Any]) -> str:
|
539
526
|
return raw_response["message"]
|
@@ -0,0 +1,96 @@
|
|
1
|
+
from typing import Optional
|
2
|
+
from collections import UserList
|
3
|
+
from edsl import Model
|
4
|
+
|
5
|
+
from edsl.language_models import LanguageModel
|
6
|
+
from edsl.Base import Base
|
7
|
+
from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
|
8
|
+
from edsl.utilities.utilities import is_valid_variable_name
|
9
|
+
from edsl.utilities.utilities import dict_hash
|
10
|
+
|
11
|
+
|
12
|
+
class ModelList(Base, UserList):
    """A list of language models with EDSL serialization helpers."""

    def __init__(self, data: Optional[list] = None):
        """Initialize the ModelList class.

        :param data: The items to hold; ``None`` produces an empty list.

        >>> from edsl import Model
        >>> m = ModelList(Model.available())

        """
        super().__init__([] if data is None else data)

    @property
    def names(self):
        """Return the set of ``model`` names of the models in the list.

        >>> ModelList.example().names
        {'...'}
        """
        return {model.model for model in self}

    def rich_print(self):
        """Placeholder: not implemented, returns ``None``."""
        pass

    def __repr__(self):
        """Return ``ModelList(<repr of the underlying list>)``."""
        return f"ModelList({super().__repr__()})"

    def __hash__(self):
        """Return a hash of the ModelList. This is used for comparison of ModelLists.

        The models are sorted (by their own hash) before hashing, so the result
        does not depend on the order of the models in the list.

        >>> isinstance(hash(ModelList.example()), int)
        True

        """
        return dict_hash(self._to_dict(sort=True))

    def _to_dict(self, sort=False):
        """Return ``{"models": [...]}`` built from each model's ``_to_dict()``.

        :param sort: When true, order the models by ``hash(model)`` first so
            the output is independent of list order.
        """
        models = sorted(self, key=hash) if sort else self
        return {"models": [model._to_dict() for model in models]}

    @classmethod
    def from_names(cls, *args, **kwargs):
        """Create a model list from model names.

        The names may be passed as separate positional arguments or as a single
        list. Keyword arguments are forwarded to every ``Model`` constructed.
        """
        if len(args) == 1 and isinstance(args[0], list):
            args = args[0]
        # Use cls (not a hard-coded ModelList), consistent with from_dict.
        return cls([Model(model_name, **kwargs) for model_name in args])

    @add_edsl_version
    def to_dict(self):
        """
        Convert the ModelList to a dictionary.

        >>> ModelList.example().to_dict()
        {'models': [...], 'edsl_version': '...', 'edsl_class_name': 'ModelList'}
        """
        return self._to_dict()

    @classmethod
    @remove_edsl_version
    def from_dict(cls, data):
        """
        Create a ModelList from a dictionary.

        :param data: A dict whose ``"models"`` entry is a list of serialized models.

        >>> newm = ModelList.from_dict(ModelList.example().to_dict())
        >>> assert ModelList.example() == newm
        """
        return cls(data=[LanguageModel.from_dict(model) for model in data["models"]])

    def code(self):
        """Placeholder: not implemented, returns ``None``."""
        pass

    @classmethod
    def example(cls):
        """Return a list holding three ``LanguageModel.example()`` instances."""
        return cls([LanguageModel.example() for _ in range(3)])
|
91
|
+
|
92
|
+
|
93
|
+
if __name__ == "__main__":
|
94
|
+
import doctest
|
95
|
+
|
96
|
+
doctest.testmod(optionflags=doctest.ELLIPSIS)
|
edsl/language_models/registry.py
CHANGED
@@ -38,6 +38,20 @@ class Model(metaclass=Meta):
|
|
38
38
|
factory = registry.create_model_factory(model_name)
|
39
39
|
return factory(*args, **kwargs)
|
40
40
|
|
41
|
+
@classmethod
def add_model(cls, service_name, model_name):
    """Add ``model_name`` under ``service_name`` to the default registry.

    :param service_name: Forwarded unchanged to the registry's ``add_model``.
    :param model_name: Forwarded unchanged to the registry's ``add_model``.
    """
    # Imported inside the method, as the neighbouring classmethods do.
    from edsl.inference_services.registry import default as default_registry

    default_registry.add_model(service_name, model_name)
|
47
|
+
|
48
|
+
@classmethod
def services(cls, registry=None):
    """List the ``_inference_service_`` value of every service in a registry.

    :param registry: The registry to inspect; a falsy value selects the
        default registry.
    :return: A list with one entry per item in ``registry.services``.
    """
    from edsl.inference_services.registry import default

    active_registry = registry if registry else default
    service_names = []
    for service in active_registry.services:
        service_names.append(service._inference_service_)
    return service_names
|
54
|
+
|
41
55
|
@classmethod
|
42
56
|
def available(cls, search_term=None, name_only=False, registry=None):
|
43
57
|
from edsl.inference_services.registry import default
|
edsl/language_models/repair.py
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
import json
|
2
2
|
import asyncio
|
3
|
+
import warnings
|
3
4
|
|
4
|
-
from edsl.utilities.utilities import clean_json
|
5
5
|
|
6
|
+
async def async_repair(
|
7
|
+
bad_json, error_message="", user_prompt=None, system_prompt=None, cache=None
|
8
|
+
):
|
9
|
+
from edsl.utilities.utilities import clean_json
|
6
10
|
|
7
|
-
async def async_repair(bad_json, error_message=""):
|
8
11
|
s = clean_json(bad_json)
|
9
|
-
from edsl import Model
|
10
|
-
|
11
|
-
m = Model()
|
12
12
|
|
13
13
|
try:
|
14
14
|
# this is the OpenAI version, but that's fine
|
@@ -17,56 +17,128 @@ async def async_repair(bad_json, error_message=""):
|
|
17
17
|
except json.JSONDecodeError:
|
18
18
|
valid_dict = {}
|
19
19
|
success = False
|
20
|
-
# print("Replacing control characters didn't work. Trying
|
20
|
+
# print("Replacing control characters didn't work. Trying extracting the sub-string.")
|
21
|
+
else:
|
22
|
+
return valid_dict, success
|
23
|
+
|
24
|
+
try:
|
25
|
+
from edsl.utilities.repair_functions import extract_json_from_string
|
26
|
+
|
27
|
+
valid_dict = extract_json_from_string(s)
|
28
|
+
success = True
|
29
|
+
except ValueError:
|
30
|
+
valid_dict = {}
|
31
|
+
success = False
|
21
32
|
else:
|
22
33
|
return valid_dict, success
|
23
34
|
|
24
|
-
|
25
|
-
It was supposed to respond with just a JSON object with an answer to a question and some commentary,
|
26
|
-
in a field called "comment" next to "answer".
|
27
|
-
Please repair this bad JSON: {bad_json}."""
|
35
|
+
from edsl import Model
|
28
36
|
|
29
|
-
|
30
|
-
prompt += f" Parsing error message: {error_message}"
|
37
|
+
m = Model()
|
31
38
|
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
39
|
+
from edsl import QuestionExtract
|
40
|
+
|
41
|
+
with warnings.catch_warnings():
|
42
|
+
warnings.simplefilter("ignore", UserWarning)
|
43
|
+
|
44
|
+
q = QuestionExtract(
|
45
|
+
question_text="""
|
46
|
+
A language model was supposed to respond to a question.
|
47
|
+
The response should have been JSON object with an answer to a question and some commentary.
|
48
|
+
|
49
|
+
It should have retured a string like this:
|
50
|
+
|
51
|
+
'{'answer': 'The answer to the question.', 'comment': 'Some commentary.'}'
|
52
|
+
|
53
|
+
or:
|
54
|
+
|
55
|
+
'{'answer': 'The answer to the question.'}'
|
56
|
+
|
57
|
+
The answer field is very like an integer number. The comment field is always string.
|
58
|
+
|
59
|
+
You job is to return just the repaired JSON object that the model should have returned, properly formatted.
|
60
|
+
|
61
|
+
- It might have included some preliminary comments.
|
62
|
+
- It might have included some control characters.
|
63
|
+
- It might have included some extraneous text.
|
64
|
+
|
65
|
+
DO NOT include any extraneous text in your response. Just return the repaired JSON object.
|
66
|
+
Do not preface the JSON object with any text. Just return the JSON object.
|
67
|
+
|
68
|
+
Bad answer: """
|
69
|
+
+ str(bad_json)
|
70
|
+
+ "The model received a user prompt of: '"
|
71
|
+
+ str(user_prompt)
|
72
|
+
+ """'
|
73
|
+
The model received a system prompt of: ' """
|
74
|
+
+ str(system_prompt)
|
75
|
+
+ """
|
76
|
+
'
|
77
|
+
Please return the repaired JSON object, following the instructions the original model should have followed, though
|
78
|
+
using 'new_answer' a nd 'new_comment' as the keys.""",
|
79
|
+
answer_template={
|
80
|
+
"new_answer": "<number, string, list, etc.>",
|
81
|
+
"new_comment": "Model's comments",
|
82
|
+
},
|
83
|
+
question_name="model_repair",
|
36
84
|
)
|
37
|
-
|
38
|
-
|
85
|
+
|
86
|
+
results = await q.run_async(cache=cache)
|
39
87
|
|
40
88
|
try:
|
41
89
|
# this is the OpenAI version, but that's fine
|
42
|
-
valid_dict = json.loads(results
|
90
|
+
valid_dict = json.loads(json.dumps(results))
|
43
91
|
success = True
|
92
|
+
# this is to deal with the fact that the model returns the answer and comment as new_answer and new_comment
|
93
|
+
valid_dict["answer"] = valid_dict.pop("new_answer")
|
94
|
+
valid_dict["comment"] = valid_dict.pop("new_comment")
|
44
95
|
except json.JSONDecodeError:
|
45
96
|
valid_dict = {}
|
46
97
|
success = False
|
98
|
+
from rich import print
|
99
|
+
from rich.console import Console
|
100
|
+
from rich.syntax import Syntax
|
101
|
+
|
102
|
+
console = Console()
|
103
|
+
error_message = (
|
104
|
+
f"All repairs. failed. LLM Model given [red]{str(bad_json)}[/red]"
|
105
|
+
)
|
106
|
+
console.print(" " + error_message)
|
107
|
+
model_returned = results["choices"][0]["message"]["content"]
|
108
|
+
console.print(f"LLM Model returned: [blue]{model_returned}[/blue]")
|
47
109
|
|
48
110
|
return valid_dict, success
|
49
111
|
|
50
112
|
|
51
|
-
def repair_wrapper(
|
113
|
+
def repair_wrapper(
    bad_json, error_message="", user_prompt=None, system_prompt=None, cache=None
):
    """Start ``async_repair`` on whatever event loop is available.

    If the current event loop is already running, the repair is scheduled on it
    and the resulting task is returned for the caller to await. Otherwise the
    repair is run to completion and its result is returned. If a ``RuntimeError``
    is raised along the way, a new event loop is created, installed as the
    current loop, and used to run the repair.
    """

    def _pending_repair():
        # Build a fresh coroutine each time one is needed.
        return async_repair(bad_json, error_message, user_prompt, system_prompt, cache)

    try:
        current_loop = asyncio.get_event_loop()
        if not current_loop.is_running():
            # No loop is running: drive the repair to completion here.
            return current_loop.run_until_complete(_pending_repair())
        # A loop is running: add the repair as a task on it.
        return current_loop.create_task(_pending_repair())
    except RuntimeError:
        # Create a new event loop if one is not already available
        fresh_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(fresh_loop)
        return fresh_loop.run_until_complete(_pending_repair())
|
66
136
|
|
67
137
|
|
68
|
-
def repair(
|
69
|
-
|
138
|
+
def repair(
    bad_json, error_message="", user_prompt=None, system_prompt=None, cache=None
):
    """Repair malformed JSON by delegating to ``repair_wrapper``.

    All arguments are forwarded unchanged, and whatever ``repair_wrapper``
    returns is returned as-is.
    """
    return repair_wrapper(
        bad_json=bad_json,
        error_message=error_message,
        user_prompt=user_prompt,
        system_prompt=system_prompt,
        cache=cache,
    )
|
70
142
|
|
71
143
|
|
72
144
|
# Example usage:
|