edsl 0.1.27.dev2__py3-none-any.whl → 0.1.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. edsl/Base.py +107 -30
  2. edsl/BaseDiff.py +260 -0
  3. edsl/__init__.py +25 -21
  4. edsl/__version__.py +1 -1
  5. edsl/agents/Agent.py +103 -46
  6. edsl/agents/AgentList.py +97 -13
  7. edsl/agents/Invigilator.py +23 -10
  8. edsl/agents/InvigilatorBase.py +19 -14
  9. edsl/agents/PromptConstructionMixin.py +342 -100
  10. edsl/agents/descriptors.py +5 -2
  11. edsl/base/Base.py +289 -0
  12. edsl/config.py +2 -1
  13. edsl/conjure/AgentConstructionMixin.py +152 -0
  14. edsl/conjure/Conjure.py +56 -0
  15. edsl/conjure/InputData.py +659 -0
  16. edsl/conjure/InputDataCSV.py +48 -0
  17. edsl/conjure/InputDataMixinQuestionStats.py +182 -0
  18. edsl/conjure/InputDataPyRead.py +91 -0
  19. edsl/conjure/InputDataSPSS.py +8 -0
  20. edsl/conjure/InputDataStata.py +8 -0
  21. edsl/conjure/QuestionOptionMixin.py +76 -0
  22. edsl/conjure/QuestionTypeMixin.py +23 -0
  23. edsl/conjure/RawQuestion.py +65 -0
  24. edsl/conjure/SurveyResponses.py +7 -0
  25. edsl/conjure/__init__.py +9 -4
  26. edsl/conjure/examples/placeholder.txt +0 -0
  27. edsl/conjure/naming_utilities.py +263 -0
  28. edsl/conjure/utilities.py +165 -28
  29. edsl/conversation/Conversation.py +238 -0
  30. edsl/conversation/car_buying.py +58 -0
  31. edsl/conversation/mug_negotiation.py +81 -0
  32. edsl/conversation/next_speaker_utilities.py +93 -0
  33. edsl/coop/coop.py +337 -121
  34. edsl/coop/utils.py +56 -70
  35. edsl/data/Cache.py +74 -22
  36. edsl/data/CacheHandler.py +10 -9
  37. edsl/data/SQLiteDict.py +11 -3
  38. edsl/inference_services/AnthropicService.py +1 -0
  39. edsl/inference_services/DeepInfraService.py +20 -13
  40. edsl/inference_services/GoogleService.py +7 -1
  41. edsl/inference_services/InferenceServicesCollection.py +33 -7
  42. edsl/inference_services/OpenAIService.py +17 -10
  43. edsl/inference_services/models_available_cache.py +69 -0
  44. edsl/inference_services/rate_limits_cache.py +25 -0
  45. edsl/inference_services/write_available.py +10 -0
  46. edsl/jobs/Answers.py +15 -1
  47. edsl/jobs/Jobs.py +322 -73
  48. edsl/jobs/buckets/BucketCollection.py +9 -3
  49. edsl/jobs/buckets/ModelBuckets.py +4 -2
  50. edsl/jobs/buckets/TokenBucket.py +1 -2
  51. edsl/jobs/interviews/Interview.py +7 -10
  52. edsl/jobs/interviews/InterviewStatusMixin.py +3 -3
  53. edsl/jobs/interviews/InterviewTaskBuildingMixin.py +39 -20
  54. edsl/jobs/interviews/retry_management.py +4 -4
  55. edsl/jobs/runners/JobsRunnerAsyncio.py +103 -65
  56. edsl/jobs/runners/JobsRunnerStatusData.py +3 -3
  57. edsl/jobs/tasks/QuestionTaskCreator.py +4 -2
  58. edsl/jobs/tasks/TaskHistory.py +4 -3
  59. edsl/language_models/LanguageModel.py +42 -55
  60. edsl/language_models/ModelList.py +96 -0
  61. edsl/language_models/registry.py +14 -0
  62. edsl/language_models/repair.py +97 -25
  63. edsl/notebooks/Notebook.py +157 -32
  64. edsl/prompts/Prompt.py +31 -19
  65. edsl/questions/QuestionBase.py +145 -23
  66. edsl/questions/QuestionBudget.py +5 -6
  67. edsl/questions/QuestionCheckBox.py +7 -3
  68. edsl/questions/QuestionExtract.py +5 -3
  69. edsl/questions/QuestionFreeText.py +3 -3
  70. edsl/questions/QuestionFunctional.py +0 -3
  71. edsl/questions/QuestionList.py +3 -4
  72. edsl/questions/QuestionMultipleChoice.py +16 -8
  73. edsl/questions/QuestionNumerical.py +4 -3
  74. edsl/questions/QuestionRank.py +5 -3
  75. edsl/questions/__init__.py +4 -3
  76. edsl/questions/descriptors.py +9 -4
  77. edsl/questions/question_registry.py +27 -31
  78. edsl/questions/settings.py +1 -1
  79. edsl/results/Dataset.py +31 -0
  80. edsl/results/DatasetExportMixin.py +493 -0
  81. edsl/results/Result.py +42 -82
  82. edsl/results/Results.py +178 -66
  83. edsl/results/ResultsDBMixin.py +10 -9
  84. edsl/results/ResultsExportMixin.py +23 -507
  85. edsl/results/ResultsGGMixin.py +3 -3
  86. edsl/results/ResultsToolsMixin.py +9 -9
  87. edsl/scenarios/FileStore.py +140 -0
  88. edsl/scenarios/Scenario.py +59 -6
  89. edsl/scenarios/ScenarioList.py +138 -52
  90. edsl/scenarios/ScenarioListExportMixin.py +32 -0
  91. edsl/scenarios/ScenarioListPdfMixin.py +2 -1
  92. edsl/scenarios/__init__.py +1 -0
  93. edsl/study/ObjectEntry.py +173 -0
  94. edsl/study/ProofOfWork.py +113 -0
  95. edsl/study/SnapShot.py +73 -0
  96. edsl/study/Study.py +498 -0
  97. edsl/study/__init__.py +4 -0
  98. edsl/surveys/MemoryPlan.py +11 -4
  99. edsl/surveys/Survey.py +124 -37
  100. edsl/surveys/SurveyExportMixin.py +25 -5
  101. edsl/surveys/SurveyFlowVisualizationMixin.py +6 -4
  102. edsl/tools/plotting.py +4 -2
  103. edsl/utilities/__init__.py +21 -20
  104. edsl/utilities/gcp_bucket/__init__.py +0 -0
  105. edsl/utilities/gcp_bucket/cloud_storage.py +96 -0
  106. edsl/utilities/gcp_bucket/simple_example.py +9 -0
  107. edsl/utilities/interface.py +90 -73
  108. edsl/utilities/repair_functions.py +28 -0
  109. edsl/utilities/utilities.py +59 -6
  110. {edsl-0.1.27.dev2.dist-info → edsl-0.1.29.dist-info}/METADATA +42 -15
  111. edsl-0.1.29.dist-info/RECORD +203 -0
  112. edsl/conjure/RawResponseColumn.py +0 -327
  113. edsl/conjure/SurveyBuilder.py +0 -308
  114. edsl/conjure/SurveyBuilderCSV.py +0 -78
  115. edsl/conjure/SurveyBuilderSPSS.py +0 -118
  116. edsl/data/RemoteDict.py +0 -103
  117. edsl-0.1.27.dev2.dist-info/RECORD +0 -172
  118. {edsl-0.1.27.dev2.dist-info → edsl-0.1.29.dist-info}/LICENSE +0 -0
  119. {edsl-0.1.27.dev2.dist-info → edsl-0.1.29.dist-info}/WHEEL +0 -0
edsl/language_models/LanguageModel.py

@@ -7,26 +7,18 @@ import asyncio
 import json
 import time
 import os
-
 from typing import Coroutine, Any, Callable, Type, List, get_type_hints
-
-from abc import ABC, abstractmethod, ABCMeta
-
-from rich.table import Table
+from abc import ABC, abstractmethod
 
 from edsl.config import CONFIG
 
-from edsl.utilities.utilities import clean_json
 from edsl.utilities.decorators import sync_wrapper, jupyter_nb_handler
 from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
+
 from edsl.language_models.repair import repair
-from edsl.exceptions.language_models import LanguageModelAttributeTypeError
 from edsl.enums import InferenceServiceType
 from edsl.Base import RichPrintingMixin, PersistenceMixin
-from edsl.data.Cache import Cache
 from edsl.enums import service_to_api_keyname
-
-
 from edsl.exceptions import MissingAPIKeyError
 from edsl.language_models.RegisterLanguageModelsMeta import RegisterLanguageModelsMeta
 
@@ -142,7 +134,7 @@ class LanguageModel(
     def has_valid_api_key(self) -> bool:
         """Check if the model has a valid API key.
 
-        >>> LanguageModel.example().has_valid_api_key()
+        >>> LanguageModel.example().has_valid_api_key() : # doctest: +SKIP
         True
 
         This method is used to check if the model has a valid API key.
@@ -159,7 +151,9 @@ class LanguageModel(
 
     def __hash__(self):
         """Allow the model to be used as a key in a dictionary."""
-        return hash(self.model + str(self.parameters))
+        from edsl.utilities.utilities import dict_hash
+
+        return dict_hash(self.to_dict())
 
     def __eq__(self, other):
         """Check is two models are the same.
@@ -207,8 +201,8 @@ class LanguageModel(
         """Model's tokens-per-minute limit.
 
         >>> m = LanguageModel.example()
-        >>> m.TPM
-        1600000.0
+        >>> m.TPM > 0
+        True
         """
         self._set_rate_limits()
         return self._safety_factor * self.__rate_limits["tpm"]
@@ -285,36 +279,14 @@ class LanguageModel(
         """
         raise NotImplementedError
 
-    def _update_response_with_tracking(
-        self, response: dict, start_time: int, cached_response=False, cache_key=None
-    ):
-        """Update the response with tracking information.
-
-        >>> m = LanguageModel.example()
-        >>> m._update_response_with_tracking(response={"response": "Hello"}, start_time=0, cached_response=False, cache_key=None)
-        {'response': 'Hello', 'elapsed_time': ..., 'timestamp': ..., 'cached_response': False, 'cache_key': None}
-
-
-        """
-        end_time = time.time()
-        response.update(
-            {
-                "elapsed_time": end_time - start_time,
-                "timestamp": end_time,
-                "cached_response": cached_response,
-                "cache_key": cache_key,
-            }
-        )
-        return response
-
     async def async_get_raw_response(
         self,
         user_prompt: str,
         system_prompt: str,
-        cache,
+        cache: "Cache",
         iteration: int = 0,
         encoded_image=None,
-    ) -> dict[str, Any]:
+    ) -> tuple[dict, bool, str]:
         """Handle caching of responses.
 
         :param user_prompt: The user's prompt.
@@ -322,8 +294,7 @@ class LanguageModel(
         :param iteration: The iteration number.
         :param cache: The cache to use.
 
-        If the cache isn't being used, it just returns a 'fresh' call to the LLM,
-        but appends some tracking information to the response (using the _update_response_with_tracking method).
+        If the cache isn't being used, it just returns a 'fresh' call to the LLM.
         But if cache is being used, it first checks the database to see if the response is already there.
         If it is, it returns the cached response, but again appends some tracking information.
         If it isn't, it calls the LLM, saves the response to the database, and returns the response with tracking information.
@@ -334,7 +305,7 @@ class LanguageModel(
         >>> from edsl import Cache
         >>> m = LanguageModel.example(test_model = True)
         >>> m.get_raw_response(user_prompt = "Hello", system_prompt = "hello", cache = Cache())
-        {'message': '{"answer": "Hello world"}', 'elapsed_time': ..., 'timestamp': ..., 'cached_response': False, 'cache_key': '24ff6ac2bc2f1729f817f261e0792577'}
+        ({'message': '{"answer": "Hello world"}'}, False, '24ff6ac2bc2f1729f817f261e0792577')
         """
         start_time = time.time()
 
@@ -379,12 +350,7 @@ class LanguageModel(
             )
             cache_used = False
 
-        return self._update_response_with_tracking(
-            response=response,
-            start_time=start_time,
-            cached_response=cache_used,
-            cache_key=cache_key,
-        )
+        return response, cache_used, cache_key
 
     get_raw_response = sync_wrapper(async_get_raw_response)
 
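Note: async_get_raw_response (and its sync wrapper get_raw_response) now returns a plain three-tuple instead of a dict decorated by the removed _update_response_with_tracking. A sketch of the new calling convention, using the built-in test model so no API key is needed (expected values taken from the doctest above):

    from edsl import Cache
    from edsl.language_models import LanguageModel

    m = LanguageModel.example(test_model=True)
    response, cache_used, cache_key = m.get_raw_response(
        user_prompt="Hello", system_prompt="hello", cache=Cache()
    )
    # response   -> {'message': '{"answer": "Hello world"}'}
    # cache_used -> False on a first call; True when the same call is replayed
    # cache_key  -> deterministic key, e.g. '24ff6ac2bc2f1729f817f261e0792577'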
@@ -427,14 +393,18 @@ class LanguageModel(
         if encoded_image:
             params["encoded_image"] = encoded_image
 
-        raw_response = await self.async_get_raw_response(**params)
+        raw_response, cache_used, cache_key = await self.async_get_raw_response(
+            **params
+        )
         response = self.parse_response(raw_response)
 
         try:
             dict_response = json.loads(response)
         except json.JSONDecodeError as e:
             # TODO: Turn into logs to generate issues
-            dict_response, success = await repair(response, str(e))
+            dict_response, success = await repair(
+                bad_json=response, error_message=str(e), cache=cache
+            )
             if not success:
                 raise Exception(
                     f"""Even the repair failed. The error was: {e}. The response was: {response}."""
@@ -442,7 +412,8 @@ class LanguageModel(
 
         dict_response.update(
             {
-                "cached_response": raw_response["cached_response"],
+                "cached_used": cache_used,
+                "cache_key": cache_key,
                 "usage": raw_response.get("usage", {}),
                 "raw_model_response": raw_response,
             }
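Note: the cache bookkeeping in the parsed response now comes from the returned tuple rather than from keys the old tracking method injected (and the new key is spelled "cached_used" in this release). A sketch, assuming the matching sync wrapper get_response still exists alongside get_raw_response:

    from edsl import Cache
    from edsl.language_models import LanguageModel

    m = LanguageModel.example(test_model=True)
    resp = m.get_response(user_prompt="Hello", system_prompt="hello", cache=Cache())
    resp["cached_used"]         # False on a cache miss (key spelling as shipped)
    resp["cache_key"]           # same key returned by async_get_raw_response
    resp["raw_model_response"]  # the unparsed model payload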
@@ -458,15 +429,18 @@ class LanguageModel(
     #######################
     # SERIALIZATION METHODS
     #######################
+    def _to_dict(self) -> dict[str, Any]:
+        return {"model": self.model, "parameters": self.parameters}
+
     @add_edsl_version
     def to_dict(self) -> dict[str, Any]:
         """Convert instance to a dictionary.
 
         >>> m = LanguageModel.example()
         >>> m.to_dict()
-        {'model': 'gpt-4-1106-preview', 'parameters': {'temperature': 0.5, 'max_tokens': 1000, 'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0, 'logprobs': False, 'top_logprobs': 3}}
+        {'model': 'gpt-4-1106-preview', 'parameters': {'temperature': 0.5, 'max_tokens': 1000, 'top_p': 1, 'frequency_penalty': 0, 'presence_penalty': 0, 'logprobs': False, 'top_logprobs': 3}, 'edsl_version': '...', 'edsl_class_name': 'LanguageModel'}
         """
-        return {"model": self.model, "parameters": self.parameters}
+        return self._to_dict()
 
     @classmethod
     @remove_edsl_version
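Note: serialization is split so that _to_dict() returns the bare payload while the decorated to_dict() appends the version metadata shown in the doctest. A round-trip sketch, assuming from_dict (decorated with @remove_edsl_version, as above) strips that metadata again:

    from edsl.language_models import LanguageModel

    m = LanguageModel.example()
    d = m.to_dict()                  # includes 'edsl_version' and 'edsl_class_name'
    m2 = LanguageModel.from_dict(d)  # metadata removed before reconstruction
    assert m == m2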
@@ -508,6 +482,8 @@ class LanguageModel(
 
     def rich_print(self):
         """Display an object as a table."""
+        from rich.table import Table
+
         table = Table(title="Language Model")
         table.add_column("Attribute", style="bold")
         table.add_column("Value")
@@ -519,8 +495,18 @@ class LanguageModel(
         return table
 
     @classmethod
-    def example(cls, test_model=False):
-        """Return a default instance of the class."""
+    def example(cls, test_model: bool = False, canned_response: str = "Hello world"):
+        """Return a default instance of the class.
+
+        >>> from edsl.language_models import LanguageModel
+        >>> m = LanguageModel.example(test_model = True, canned_response = "WOWZA!")
+        >>> isinstance(m, LanguageModel)
+        True
+        >>> from edsl import QuestionFreeText
+        >>> q = QuestionFreeText(question_text = "What is your name?", question_name = 'example')
+        >>> q.by(m).run(cache = False).select('example').first()
+        'WOWZA!'
+        """
         from edsl import Model
 
         class TestLanguageModelGood(LanguageModel):
@@ -533,7 +519,8 @@ class LanguageModel(
                 self, user_prompt: str, system_prompt: str
             ) -> dict[str, Any]:
                 await asyncio.sleep(0.1)
-                return {"message": """{"answer": "Hello world"}"""}
+                # return {"message": """{"answer": "Hello, world"}"""}
+                return {"message": f'{{"answer": "{canned_response}"}}'}
 
             def parse_response(self, raw_response: dict[str, Any]) -> str:
                 return raw_response["message"]
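Note: the new canned_response parameter makes the test model usable for offline end-to-end runs. A minimal sketch mirroring the new doctest:

    from edsl import QuestionFreeText
    from edsl.language_models import LanguageModel

    m = LanguageModel.example(test_model=True, canned_response="WOWZA!")
    q = QuestionFreeText(question_name="example", question_text="What is your name?")
    answer = q.by(m).run(cache=False).select("example").first()
    assert answer == "WOWZA!"  # no API key or network call required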
edsl/language_models/ModelList.py (new file)

@@ -0,0 +1,96 @@
+from typing import Optional
+from collections import UserList
+from edsl import Model
+
+from edsl.language_models import LanguageModel
+from edsl.Base import Base
+from edsl.utilities.decorators import add_edsl_version, remove_edsl_version
+from edsl.utilities.utilities import is_valid_variable_name
+from edsl.utilities.utilities import dict_hash
+
+
+class ModelList(Base, UserList):
+    def __init__(self, data: Optional[list] = None):
+        """Initialize the ScenarioList class.
+
+        >>> from edsl import Model
+        >>> m = ModelList(Model.available())
+
+        """
+        if data is not None:
+            super().__init__(data)
+        else:
+            super().__init__([])
+
+    @property
+    def names(self):
+        """
+
+        >>> ModelList.example().names
+        {'...'}
+        """
+        return set([model.model for model in self])
+
+    def rich_print(self):
+        pass
+
+    def __repr__(self):
+        return f"ModelList({super().__repr__()})"
+
+    def __hash__(self):
+        """Return a hash of the ModelList. This is used for comparison of ModelLists.
+
+        >>> hash(ModelList.example())
+        1423518243781418961
+
+        """
+        from edsl.utilities.utilities import dict_hash
+
+        return dict_hash(self._to_dict(sort=True))
+
+    def _to_dict(self, sort=False):
+        if sort:
+            model_list = sorted([model for model in self], key=lambda x: hash(x))
+            return {"models": [model._to_dict() for model in model_list]}
+        else:
+            return {"models": [model._to_dict() for model in self]}
+
+    @classmethod
+    def from_names(self, *args, **kwargs):
+        """A a model list from a list of names"""
+        if len(args) == 1 and isinstance(args[0], list):
+            args = args[0]
+        return ModelList([Model(model_name, **kwargs) for model_name in args])
+
+    @add_edsl_version
+    def to_dict(self):
+        """
+        Convert the ModelList to a dictionary.
+        >>> ModelList.example().to_dict()
+        {'models': [...], 'edsl_version': '...', 'edsl_class_name': 'ModelList'}
+        """
+        return self._to_dict()
+
+    @classmethod
+    @remove_edsl_version
+    def from_dict(cls, data):
+        """
+        Create a ModelList from a dictionary.
+
+        >>> newm = ModelList.from_dict(ModelList.example().to_dict())
+        >>> assert ModelList.example() == newm
+        """
+        return cls(data=[LanguageModel.from_dict(model) for model in data["models"]])
+
+    def code(self):
+        pass
+
+    @classmethod
+    def example(cl):
+        return ModelList([LanguageModel.example() for _ in range(3)])
+
+
+if __name__ == "__main__":
+    import doctest
+
+    doctest.testmod(optionflags=doctest.ELLIPSIS)
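Note: ModelList gives collections of models the same serialization and hashing surface as other EDSL containers. A usage sketch (the model names are illustrative and must exist in the model registry):

    from edsl.language_models.ModelList import ModelList

    ml = ModelList.from_names(["gpt-4-1106-preview", "gpt-3.5-turbo"])
    ml.names                                # {'gpt-4-1106-preview', 'gpt-3.5-turbo'}
    restored = ModelList.from_dict(ml.to_dict())
    assert ml == restored                   # round-trips through the dict form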
edsl/language_models/registry.py

@@ -38,6 +38,20 @@ class Model(metaclass=Meta):
         factory = registry.create_model_factory(model_name)
         return factory(*args, **kwargs)
 
+    @classmethod
+    def add_model(cls, service_name, model_name):
+        from edsl.inference_services.registry import default
+
+        registry = default
+        registry.add_model(service_name, model_name)
+
+    @classmethod
+    def services(cls, registry=None):
+        from edsl.inference_services.registry import default
+
+        registry = registry or default
+        return [r._inference_service_ for r in registry.services]
+
     @classmethod
     def available(cls, search_term=None, name_only=False, registry=None):
         from edsl.inference_services.registry import default
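Note: Model gains two registry helpers: services() lists the registered inference services, and add_model() registers an additional model name under one of them. A sketch (the service and model names are illustrative):

    from edsl import Model

    Model.services()                              # e.g. ['openai', 'anthropic', ...]
    Model.add_model("openai", "my-custom-model")  # hypothetical extra model name
    m = Model("my-custom-model")                  # should now resolve via the registry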
edsl/language_models/repair.py

@@ -1,14 +1,14 @@
 import json
 import asyncio
+import warnings
 
-from edsl.utilities.utilities import clean_json
 
+async def async_repair(
+    bad_json, error_message="", user_prompt=None, system_prompt=None, cache=None
+):
+    from edsl.utilities.utilities import clean_json
 
-async def async_repair(bad_json, error_message=""):
     s = clean_json(bad_json)
-    from edsl import Model
-
-    m = Model()
 
     try:
         # this is the OpenAI version, but that's fine
@@ -17,56 +17,128 @@ async def async_repair(bad_json, error_message=""):
     except json.JSONDecodeError:
         valid_dict = {}
         success = False
-        # print("Replacing control characters didn't work. Trying with the model.")
+        # print("Replacing control characters didn't work. Trying extracting the sub-string.")
+    else:
+        return valid_dict, success
+
+    try:
+        from edsl.utilities.repair_functions import extract_json_from_string
+
+        valid_dict = extract_json_from_string(s)
+        success = True
+    except ValueError:
+        valid_dict = {}
+        success = False
     else:
         return valid_dict, success
 
-    prompt = f"""This is the output from a less capable language model.
-    It was supposed to respond with just a JSON object with an answer to a question and some commentary,
-    in a field called "comment" next to "answer".
-    Please repair this bad JSON: {bad_json}."""
+    from edsl import Model
 
-    if error_message:
-        prompt += f" Parsing error message: {error_message}"
+    m = Model()
 
-    try:
-        results = await m.async_execute_model_call(
-            prompt,
-            system_prompt="You are a helpful agent. Only return the repaired JSON, nothing else.",
+    from edsl import QuestionExtract
+
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore", UserWarning)
+
+        q = QuestionExtract(
+            question_text="""
+            A language model was supposed to respond to a question.
+            The response should have been JSON object with an answer to a question and some commentary.
+
+            It should have retured a string like this:
+
+            '{'answer': 'The answer to the question.', 'comment': 'Some commentary.'}'
+
+            or:
+
+            '{'answer': 'The answer to the question.'}'
+
+            The answer field is very like an integer number. The comment field is always string.
+
+            You job is to return just the repaired JSON object that the model should have returned, properly formatted.
+
+            - It might have included some preliminary comments.
+            - It might have included some control characters.
+            - It might have included some extraneous text.
+
+            DO NOT include any extraneous text in your response. Just return the repaired JSON object.
+            Do not preface the JSON object with any text. Just return the JSON object.
+
+            Bad answer: """
+            + str(bad_json)
+            + "The model received a user prompt of: '"
+            + str(user_prompt)
+            + """'
+            The model received a system prompt of: ' """
+            + str(system_prompt)
+            + """
+            '
+            Please return the repaired JSON object, following the instructions the original model should have followed, though
+            using 'new_answer' a nd 'new_comment' as the keys.""",
+            answer_template={
+                "new_answer": "<number, string, list, etc.>",
+                "new_comment": "Model's comments",
+            },
+            question_name="model_repair",
         )
-    except Exception as e:
-        return {}, False
+
+    results = await q.run_async(cache=cache)
 
     try:
         # this is the OpenAI version, but that's fine
-        valid_dict = json.loads(results["choices"][0]["message"]["content"])
+        valid_dict = json.loads(json.dumps(results))
         success = True
+        # this is to deal with the fact that the model returns the answer and comment as new_answer and new_comment
+        valid_dict["answer"] = valid_dict.pop("new_answer")
+        valid_dict["comment"] = valid_dict.pop("new_comment")
     except json.JSONDecodeError:
         valid_dict = {}
         success = False
+        from rich import print
+        from rich.console import Console
+        from rich.syntax import Syntax
+
+        console = Console()
+        error_message = (
+            f"All repairs. failed. LLM Model given [red]{str(bad_json)}[/red]"
+        )
+        console.print(" " + error_message)
+        model_returned = results["choices"][0]["message"]["content"]
+        console.print(f"LLM Model returned: [blue]{model_returned}[/blue]")
 
     return valid_dict, success
 
 
-def repair_wrapper(bad_json, error_message=""):
+def repair_wrapper(
+    bad_json, error_message="", user_prompt=None, system_prompt=None, cache=None
+):
     try:
         loop = asyncio.get_event_loop()
         if loop.is_running():
             # Add repair as a task to the running loop
-            task = loop.create_task(async_repair(bad_json, error_message))
+            task = loop.create_task(
+                async_repair(bad_json, error_message, user_prompt, system_prompt, cache)
+            )
             return task
         else:
             # Run a new event loop for repair
-            return loop.run_until_complete(async_repair(bad_json, error_message))
+            return loop.run_until_complete(
+                async_repair(bad_json, error_message, user_prompt, system_prompt, cache)
+            )
     except RuntimeError:
         # Create a new event loop if one is not already available
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)
-        return loop.run_until_complete(async_repair(bad_json, error_message))
+        return loop.run_until_complete(
+            async_repair(bad_json, error_message, user_prompt, system_prompt, cache)
+        )
 
 
-def repair(bad_json, error_message=""):
-    return repair_wrapper(bad_json, error_message)
+def repair(
+    bad_json, error_message="", user_prompt=None, system_prompt=None, cache=None
+):
+    return repair_wrapper(bad_json, error_message, user_prompt, system_prompt, cache)
 
 
 # Example usage:
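Note: the repair entry points now accept the original prompts and a Cache so the LLM-backed repair can reuse cached calls; the string-extraction pass runs first, and the model is only consulted if local cleanup fails. A sketch of the widened signature in a plain (non-async) script, with illustrative malformed JSON:

    from edsl import Cache
    from edsl.language_models.repair import repair

    bad = '{"answer": "42", "comment": "missing a closing brace"'
    fixed, success = repair(
        bad_json=bad,
        error_message="Expecting ',' delimiter",
        cache=Cache(),
    )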