lionagi 0.3.6__py3-none-any.whl → 0.3.8__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
lionagi/libs/sys_util.py CHANGED
@@ -7,11 +7,17 @@ import re
7
7
  import subprocess
8
8
  import sys
9
9
  import time
10
+ from collections.abc import Sequence
10
11
  from datetime import datetime, timezone
11
12
  from hashlib import sha256
12
13
  from pathlib import Path
13
14
  from typing import Any
14
15
 
16
+ from lion_core.setting import DEFAULT_LION_ID_CONFIG, LionIDConfig
17
+ from lion_core.sys_utils import SysUtil as _u
18
+ from lionabc import Observable
19
+ from typing_extensions import deprecated
20
+
15
21
  _timestamp_syms = ["-", ":", "."]
16
22
 
17
23
  PATH_TYPE = str | Path
@@ -20,24 +26,74 @@ PATH_TYPE = str | Path
20
26
  class SysUtil:
21
27
 
22
28
  @staticmethod
29
+ def id(
30
+ config: LionIDConfig = DEFAULT_LION_ID_CONFIG,
31
+ n: int = None,
32
+ prefix: str = None,
33
+ postfix: str = None,
34
+ random_hyphen: bool = None,
35
+ num_hyphens: int = None,
36
+ hyphen_start_index: int = None,
37
+ hyphen_end_index: int = None,
38
+ ) -> str:
39
+ return _u.id(
40
+ config=config,
41
+ n=n,
42
+ prefix=prefix,
43
+ postfix=postfix,
44
+ random_hyphen=random_hyphen,
45
+ num_hyphens=num_hyphens,
46
+ hyphen_start_index=hyphen_start_index,
47
+ hyphen_end_index=hyphen_end_index,
48
+ )
49
+
50
+ @staticmethod
51
+ def get_id(
52
+ item: Sequence[Observable] | Observable | str,
53
+ config: LionIDConfig = DEFAULT_LION_ID_CONFIG,
54
+ /,
55
+ ) -> str:
56
+ return _u.get_id(item, config)
57
+
58
+ @staticmethod
59
+ def is_id(
60
+ item: Sequence[Observable] | Observable | str,
61
+ config: LionIDConfig = DEFAULT_LION_ID_CONFIG,
62
+ /,
63
+ ) -> bool:
64
+ return _u.is_id(item, config)
65
+
66
+ # legacy methods, kept for backward compatibility
67
+
68
+ @staticmethod
69
+ @deprecated(
70
+ "Deprecated since v0.3, will be removed in v1.0. Use time.sleep instead.",
71
+ category=DeprecationWarning,
72
+ stacklevel=2,
73
+ )
23
74
  def sleep(delay: float) -> None:
24
75
  """
25
76
  Pauses execution for a specified duration.
26
77
 
27
78
  Args:
28
- delay (float): The amount of time, in seconds, to pause execution.
79
+ delay (float): The amount of time, in seconds, to pause execution.
29
80
  """
30
81
  time.sleep(delay)
31
82
 
32
83
  @staticmethod
84
+ @deprecated(
85
+ "Deprecated since v0.3, will be removed in v1.0. Use lionfuncs.time instead",
86
+ category=DeprecationWarning,
87
+ stacklevel=2,
88
+ )
33
89
  def get_now(datetime_: bool = False, tz=None) -> float | datetime:
34
90
  """Returns the current time either as a Unix timestamp or a datetime object.
35
91
 
36
92
  Args:
37
- datetime_ (bool): If True, returns a datetime object; otherwise, returns a Unix timestamp.
93
+ datetime_ (bool): If True, returns a datetime object; otherwise, returns a Unix timestamp.
38
94
 
39
95
  Returns:
40
- Union[float, datetime.datetime]: The current time as a Unix timestamp or a datetime object.
96
+ Union[float, datetime.datetime]: The current time as a Unix timestamp or a datetime object.
41
97
  """
42
98
 
43
99
  if not datetime_:
@@ -48,6 +104,11 @@ class SysUtil:
48
104
  return datetime.now(**config_)
49
105
 
50
106
  @staticmethod
107
+ @deprecated(
108
+ "Deprecated since v0.3, will be removed in v1.0. Use d_[k2] = d_.pop(k1) instead",
109
+ category=DeprecationWarning,
110
+ stacklevel=2,
111
+ )
51
112
  def change_dict_key(
52
113
  dict_: dict[Any, Any], old_key: str, new_key: str
53
114
  ) -> None:
@@ -65,15 +126,20 @@ class SysUtil:
65
126
  dict_[new_key] = dict_.pop(old_key)
66
127
 
67
128
  @staticmethod
129
+ @deprecated(
130
+ "Deprecated since v0.3, will be removed in v1.0. Use lionfuncs.time instead",
131
+ category=DeprecationWarning,
132
+ stacklevel=2,
133
+ )
68
134
  def get_timestamp(tz: timezone = timezone.utc, sep: str = "_") -> str:
69
135
  """Returns a timestamp string with optional custom separators and timezone.
70
136
 
71
137
  Args:
72
- tz (timezone): The timezone for the timestamp.
73
- sep (str): The separator to use in the timestamp string, replacing '-', ':', and '.'.
138
+ tz (timezone): The timezone for the timestamp.
139
+ sep (str): The separator to use in the timestamp string, replacing '-', ':', and '.'.
74
140
 
75
141
  Returns:
76
- str: A string representation of the current timestamp.
142
+ str: A string representation of the current timestamp.
77
143
  """
78
144
  str_ = datetime.now(tz=tz).isoformat()
79
145
  if sep is not None:
@@ -82,6 +148,11 @@ class SysUtil:
82
148
  return str_
83
149
 
84
150
  @staticmethod
151
+ @deprecated(
152
+ "Deprecated since v0.3, will be removed in v1.0. Deprecated without replacement",
153
+ category=DeprecationWarning,
154
+ stacklevel=2,
155
+ )
85
156
  def is_schema(dict_: dict[Any, Any], schema: dict[Any, type]) -> bool:
86
157
  """Validates if the given dictionary matches the expected schema types."""
87
158
  return all(
@@ -90,6 +161,11 @@ class SysUtil:
90
161
  )
91
162
 
92
163
  @staticmethod
164
+ @deprecated(
165
+ "Deprecated since v0.3, will be removed in v1.0. Use lionfuncs.copy instead",
166
+ category=DeprecationWarning,
167
+ stacklevel=2,
168
+ )
93
169
  def create_copy(input_: Any, num: int = 1) -> Any | list[Any]:
94
170
  """Creates deep copies of the input, either as a single copy or a list of copies.
95
171
 
@@ -109,6 +185,11 @@ class SysUtil:
109
185
  )
110
186
 
111
187
  @staticmethod
188
+ @deprecated(
189
+ "Deprecated since v0.3, will be removed in v1.0. Use SysUtil.id instead",
190
+ category=DeprecationWarning,
191
+ stacklevel=2,
192
+ )
112
193
  def create_id(n: int = 32) -> str:
113
194
  """
114
195
  Generates a unique identifier based on the current time and random bytes.
@@ -124,17 +205,22 @@ class SysUtil:
124
205
  return sha256(current_time + random_bytes).hexdigest()[:n]
125
206
 
126
207
  @staticmethod
208
+ @deprecated(
209
+ "Deprecated since v0.3, will be removed in v1.0. Use lionfuncs.get_bins instead",
210
+ category=DeprecationWarning,
211
+ stacklevel=2,
212
+ )
127
213
  def get_bins(
128
214
  input_: list[str], upper: int | None = 2000
129
215
  ) -> list[list[int]]:
130
216
  """Organizes indices of strings into bins based on a cumulative upper limit.
131
217
 
132
218
  Args:
133
- input_ (List[str]): The list of strings to be binned.
134
- upper (int): The cumulative length upper limit for each bin.
219
+ input_ (List[str]): The list of strings to be binned.
220
+ upper (int): The cumulative length upper limit for each bin.
135
221
 
136
222
  Returns:
137
- List[List[int]]: A list of bins, each bin is a list of indices from the input list.
223
+ List[List[int]]: A list of bins, each bin is a list of indices from the input list.
138
224
  """
139
225
  current = 0
140
226
  bins = []
@@ -152,13 +238,18 @@ class SysUtil:
152
238
  return bins
153
239
 
154
240
  @staticmethod
241
+ @deprecated(
242
+ "Deprecated since v0.3, will be removed in v1.0. Use lionfuncs.get_cpu_architecture instead",
243
+ category=DeprecationWarning,
244
+ stacklevel=2,
245
+ )
155
246
  def get_cpu_architecture() -> str:
156
247
  """Returns a string identifying the CPU architecture.
157
248
 
158
249
  This method categorizes some architectures as 'apple_silicon'.
159
250
 
160
251
  Returns:
161
- str: A string identifying the CPU architecture ('apple_silicon' or 'other_cpu').
252
+ str: A string identifying the CPU architecture ('apple_silicon' or 'other_cpu').
162
253
  """
163
254
  arch: str = platform.machine().lower()
164
255
  return (
@@ -168,6 +259,11 @@ class SysUtil:
168
259
  )
169
260
 
170
261
  @staticmethod
262
+ @deprecated(
263
+ "Deprecated since v0.3, will be removed in v1.0. Use lionfuncs.install_import instead",
264
+ category=DeprecationWarning,
265
+ stacklevel=2,
266
+ )
171
267
  def install_import(
172
268
  package_name: str,
173
269
  module_name: str = None,
@@ -180,10 +276,10 @@ class SysUtil:
180
276
  to install the package using pip and then retries the import.
181
277
 
182
278
  Args:
183
- package_name: The base name of the package to import.
184
- module_name: The submodule name to import from the package, if applicable. Defaults to None.
185
- import_name: The specific name to import from the module or package. Defaults to None.
186
- pip_name: The pip package name if different from `package_name`. Defaults to None.
279
+ package_name: The base name of the package to import.
280
+ module_name: The submodule name to import from the package, if applicable. Defaults to None.
281
+ import_name: The specific name to import from the module or package. Defaults to None.
282
+ pip_name: The pip package name if different from `package_name`. Defaults to None.
187
283
 
188
284
  Prints a message indicating success or attempts installation if the import fails.
189
285
  """
@@ -215,23 +311,38 @@ class SysUtil:
215
311
  __import__(full_import_path)
216
312
 
217
313
  @staticmethod
314
+ @deprecated(
315
+ "Deprecated since v0.3, will be removed in v1.0. Use lionfuncs.import_module instead",
316
+ category=DeprecationWarning,
317
+ stacklevel=2,
318
+ )
218
319
  def import_module(module_path: str):
219
320
  return importlib.import_module(module_path)
220
321
 
221
322
  @staticmethod
323
+ @deprecated(
324
+ "Deprecated since v0.3, will be removed in v1.0. Use lionfuncs.is_package_installed instead",
325
+ category=DeprecationWarning,
326
+ stacklevel=2,
327
+ )
222
328
  def is_package_installed(package_name: str) -> bool:
223
329
  """Checks if a package is currently installed.
224
330
 
225
331
  Args:
226
- package_name: The name of the package to check.
332
+ package_name: The name of the package to check.
227
333
 
228
334
  Returns:
229
- A boolean indicating whether the package is installed.
335
+ A boolean indicating whether the package is installed.
230
336
  """
231
337
  package_spec = importlib.util.find_spec(package_name)
232
338
  return package_spec is not None
233
339
 
234
340
  @staticmethod
341
+ @deprecated(
342
+ "Deprecated since v0.3, will be removed in v1.0. Use lionfuncs.check_import instead",
343
+ category=DeprecationWarning,
344
+ stacklevel=2,
345
+ )
235
346
  def check_import(
236
347
  package_name: str,
237
348
  module_name: str | None = None,
@@ -246,12 +357,12 @@ class SysUtil:
246
357
  it attempts to install the package using `install_import` and then retries the import.
247
358
 
248
359
  Args:
249
- package_name: The name of the package to check and potentially install.
250
- module_name: The submodule name to import from the package, if applicable. Defaults to None.
251
- import_name: The specific name to import from the module or package. Defaults to None.
252
- pip_name: The pip package name if different from `package_name`. Defaults to None.
253
- attempt_install: If attempt to install the package if uninstalled. Defaults to True.
254
- error_message: Error message when the package is not installed and not attempt to install.
360
+ package_name: The name of the package to check and potentially install.
361
+ module_name: The submodule name to import from the package, if applicable. Defaults to None.
362
+ import_name: The specific name to import from the module or package. Defaults to None.
363
+ pip_name: The pip package name if different from `package_name`. Defaults to None.
364
+ attempt_install: If attempt to install the package if uninstalled. Defaults to True.
365
+ error_message: Error message when the package is not installed and not attempt to install.
255
366
  """
256
367
  try:
257
368
  if not SysUtil.is_package_installed(package_name):
@@ -277,6 +388,11 @@ class SysUtil:
277
388
  ) from e
278
389
 
279
390
  @staticmethod
391
+ @deprecated(
392
+ "Deprecated since v0.3, will be removed in v1.0. Use lionfuncs.list_installed_packages instead",
393
+ category=DeprecationWarning,
394
+ stacklevel=2,
395
+ )
280
396
  def list_installed_packages() -> list:
281
397
  """list all installed packages using importlib.metadata."""
282
398
  return [
@@ -285,6 +401,11 @@ class SysUtil:
285
401
  ]
286
402
 
287
403
  @staticmethod
404
+ @deprecated(
405
+ "Deprecated since v0.3, will be removed in v1.0. Use lionfuncs.uninstall_package instead",
406
+ category=DeprecationWarning,
407
+ stacklevel=2,
408
+ )
288
409
  def uninstall_package(package_name: str) -> None:
289
410
  """Uninstall a specified package."""
290
411
  try:
@@ -296,6 +417,11 @@ class SysUtil:
296
417
  print(f"Failed to uninstall {package_name}. Error: {e}")
297
418
 
298
419
  @staticmethod
420
+ @deprecated(
421
+ "Deprecated since v0.3, will be removed in v1.0. Use lionfuncs.update_package instead",
422
+ category=DeprecationWarning,
423
+ stacklevel=2,
424
+ )
299
425
  def update_package(package_name: str) -> None:
300
426
  """Update a specified package."""
301
427
  try:
@@ -314,6 +440,11 @@ class SysUtil:
314
440
  print(f"Failed to update {package_name}. Error: {e}")
315
441
 
316
442
  @staticmethod
443
+ @deprecated(
444
+ "Deprecated since v0.3, will be removed in v1.0. Use lionfuncs.clear_path instead",
445
+ category=DeprecationWarning,
446
+ stacklevel=2,
447
+ )
317
448
  def clear_dir(
318
449
  dir_path: Path | str,
319
450
  recursive: bool = False,
@@ -324,12 +455,12 @@ class SysUtil:
324
455
  excluding files that match any pattern in the exclude list.
325
456
 
326
457
  Args:
327
- dir_path (Union[Path, str]): The path to the directory to clear.
328
- recursive (bool): If True, clears directories recursively. Defaults to False.
329
- exclude (List[str]): A list of string patterns to exclude from deletion. Defaults to None.
458
+ dir_path (Union[Path, str]): The path to the directory to clear.
459
+ recursive (bool): If True, clears directories recursively. Defaults to False.
460
+ exclude (List[str]): A list of string patterns to exclude from deletion. Defaults to None.
330
461
 
331
462
  Raises:
332
- FileNotFoundError: If the specified directory does not exist.
463
+ FileNotFoundError: If the specified directory does not exist.
333
464
  """
334
465
  dir_path = Path(dir_path)
335
466
  if not dir_path.exists():
@@ -356,6 +487,11 @@ class SysUtil:
356
487
  raise
357
488
 
358
489
  @staticmethod
490
+ @deprecated(
491
+ "Deprecated since v0.3, will be removed in v1.0. Use lionfuncs.split_path instead",
492
+ category=DeprecationWarning,
493
+ stacklevel=2,
494
+ )
359
495
  def split_path(path: Path | str) -> tuple[Path, str]:
360
496
  """
361
497
  Splits a path into its directory and filename components.
@@ -370,6 +506,11 @@ class SysUtil:
370
506
  return path.parent, path.name
371
507
 
372
508
  @staticmethod
509
+ @deprecated(
510
+ "Deprecated since v0.3, will be removed in v1.0. Use lionfuncs.create_path instead",
511
+ category=DeprecationWarning,
512
+ stacklevel=2,
513
+ )
373
514
  def create_path(
374
515
  directory: Path | str,
375
516
  filename: str,
@@ -383,12 +524,12 @@ class SysUtil:
383
524
  Creates a path with an optional timestamp in the specified directory.
384
525
 
385
526
  Args:
386
- directory (Union[Path, str]): The directory where the file will be located.
387
- filename (str): The filename. Must include a valid extension.
388
- timestamp (bool): If True, adds a timestamp to the filename. Defaults to True.
389
- dir_exist_ok (bool): If True, does not raise an error if the directory exists. Defaults to True.
390
- time_prefix (bool): If True, adds the timestamp as a prefix; otherwise, as a suffix. Defaults to False.
391
- custom_timestamp_format (str): A custom format for the timestamp. Defaults to "%Y%m%d%H%M%S".
527
+ directory (Union[Path, str]): The directory where the file will be located.
528
+ filename (str): The filename. Must include a valid extension.
529
+ timestamp (bool): If True, adds a timestamp to the filename. Defaults to True.
530
+ dir_exist_ok (bool): If True, does not raise an error if the directory exists. Defaults to True.
531
+ time_prefix (bool): If True, adds the timestamp as a prefix; otherwise, as a suffix. Defaults to False.
532
+ custom_timestamp_format (str): A custom format for the timestamp. Defaults to "%Y%m%d%H%M%S".
392
533
 
393
534
  Returns:
394
535
  Path: The full path to the file.
@@ -432,6 +573,11 @@ class SysUtil:
432
573
  return full_path
433
574
 
434
575
  @staticmethod
576
+ @deprecated(
577
+ "Deprecated since v0.3, will be removed in v1.0. Use lionfuncs.list_files instead",
578
+ category=DeprecationWarning,
579
+ stacklevel=2,
580
+ )
435
581
  def list_files(dir_path: Path | str, extension: str = None) -> list[Path]:
436
582
  """
437
583
  Lists all files in a specified directory with an optional filter for file extensions.
@@ -455,6 +601,11 @@ class SysUtil:
455
601
  return list(dir_path.glob("*"))
456
602
 
457
603
  @staticmethod
604
+ @deprecated(
605
+ "Deprecated since v0.3, will be removed in v1.0. Use lionfuncs.copy_file instead",
606
+ category=DeprecationWarning,
607
+ stacklevel=2,
608
+ )
458
609
  def copy_file(src: Path | str, dest: Path | str) -> None:
459
610
  """
460
611
  Copies a file from a source path to a destination path.
@@ -475,6 +626,11 @@ class SysUtil:
475
626
  copy2(src, dest)
476
627
 
477
628
  @staticmethod
629
+ @deprecated(
630
+ "Deprecated since v0.3, will be removed in v1.0. Use lionfuncs.get_file_size instead",
631
+ category=DeprecationWarning,
632
+ stacklevel=2,
633
+ )
478
634
  def get_size(path: Path | str) -> int:
479
635
  """
480
636
  Gets the size of a file or total size of files in a directory.
@@ -499,6 +655,11 @@ class SysUtil:
499
655
  raise FileNotFoundError(f"{path} does not exist.")
500
656
 
501
657
  @staticmethod
658
+ @deprecated(
659
+ "Deprecated since v0.3, will be removed in v1.0. Use lionfuncs.save_to_file instead",
660
+ category=DeprecationWarning,
661
+ stacklevel=2,
662
+ )
502
663
  def save_to_file(
503
664
  text,
504
665
  directory: Path | str,
@@ -544,3 +705,6 @@ class SysUtil:
544
705
  print(f"Text saved to: {file_path}")
545
706
 
546
707
  return True
708
+
709
+
710
# Names exported by ``from ... import *``: only the SysUtil facade.
__all__ = ["SysUtil"]
File without changes
@@ -0,0 +1,6 @@
1
+ from .brainstorm import brainstorm
2
+ from .rank import rank
3
+ from .score import score
4
+ from .select import select
5
+
6
# Operations re-exported by this package (imported from sibling modules above).
__all__ = ["brainstorm", "rank", "score", "select"]
@@ -0,0 +1,87 @@
1
+ from typing import Any
2
+
3
+ from lion_core.operative.step_model import StepModel
4
+ from lion_core.session.branch import Branch
5
+ from lion_service import iModel
6
+ from pydantic import BaseModel, Field
7
+
8
+ from .config import DEFAULT_CHAT_CONFIG
9
+
10
+
11
class BrainstormModel(BaseModel):
    """Structured output the model is asked to fill in during a brainstorm.

    ``brainstorm`` passes this class to ``Branch.operate`` as the
    ``operative_model``. The field descriptions are part of the schema the
    model sees, so they must name types that actually exist.
    """

    # Topic or theme the ideas are about.
    topic: str = Field(
        default_factory=str,
        description="**Specify the topic or theme for the brainstorming session.**",
    )
    # Candidate ideas; each entry is a StepModel. The description previously
    # pointed at a nonexistent `PlanStepModel`.
    ideas: list[StepModel] = Field(
        default_factory=list,
        description="**Provide a list of ideas needed to accomplish the objective. Each step should be as described in a `StepModel`.**",
    )
21
+
22
+
23
# Operation hint that ``brainstorm`` places first in the context list;
# ``{num_steps}`` is filled with the number of ideas requested.
PROMPT = (
    "Please follow prompt and provide {num_steps} "
    "different ideas for the next step"
)
24
+
25
+
26
async def brainstorm(
    num_steps: int = 3,
    instruction=None,
    guidance=None,
    context=None,
    system=None,
    reason: bool = False,
    actions: bool = False,
    tools: Any = None,
    imodel: iModel = None,
    branch: Branch = None,
    sender=None,
    recipient=None,
    clear_messages: bool = False,
    system_sender=None,
    system_datetime=None,
    return_branch=False,
    num_parse_retries: int = 3,
    retry_imodel: iModel = None,
    branch_user=None,
    **kwargs,  # additional operate arguments
):
    """Ask the model for ``num_steps`` ideas, parsed into ``BrainstormModel``.

    Model selection: a truthy ``imodel`` argument wins. Otherwise the
    branch's own ``imodel`` is used when ``branch`` has one. Failing both, a
    new ``iModel`` is built from ``DEFAULT_CHAT_CONFIG``.

    Args:
        num_steps: Number of ideas requested; substituted into ``PROMPT``.
        context: Optional extra context, appended after the operation prompt.
        system: If truthy, added to the branch as a system message together
            with ``system_datetime`` and ``system_sender``.
        branch: Branch to run on; a new ``Branch`` is created when omitted.
        branch_user: If truthy, assigned to ``branch.user``.
        return_branch: When True, also return the branch that was used.
        **kwargs: Forwarded to ``branch.operate``, as are ``instruction``,
            ``guidance``, ``sender``, ``recipient``, ``reason``, ``actions``,
            ``tools``, ``clear_messages``, ``retry_imodel`` and
            ``num_parse_retries``.

    Returns:
        The result of ``branch.operate``, or ``(response, branch)`` when
        ``return_branch`` is True.
    """
    if not imodel:
        branch_model = branch.imodel if branch else None
        imodel = branch_model if branch_model else iModel(**DEFAULT_CHAT_CONFIG)

    branch = branch or Branch(imodel=imodel)
    if branch_user:
        branch.user = branch_user

    if system:
        branch.add_message(
            system=system,
            system_datetime=system_datetime,
            sender=system_sender,
        )

    # The formatted operation prompt always leads; caller context follows it.
    operation_context = [{"operation": PROMPT.format(num_steps=num_steps)}]
    if context:
        operation_context.append(context)

    response = await branch.operate(
        instruction=instruction,
        guidance=guidance,
        context=operation_context,
        sender=sender,
        recipient=recipient,
        reason=reason,
        actions=actions,
        tools=tools,
        clear_messages=clear_messages,
        operative_model=BrainstormModel,
        retry_imodel=retry_imodel,
        num_parse_retries=num_parse_retries,
        imodel=imodel,
        **kwargs,
    )
    return (response, branch) if return_branch else response
@@ -0,0 +1,6 @@
1
# Keyword arguments that ``brainstorm`` and ``rank`` pass to ``iModel`` when
# neither an explicit ``imodel`` nor a branch with a model is supplied.
# NOTE(review): "api_key" looks like the *name* of an environment variable
# rather than a literal key — confirm how ``iModel`` resolves it.
DEFAULT_CHAT_CONFIG = dict(
    provider="openai",
    task="chat",
    model="gpt-4o-mini",
    api_key="OPENAI_API_KEY",
)
@@ -0,0 +1,102 @@
1
+ import asyncio
2
+ from typing import Any
3
+
4
+ import numpy as np
5
+ from lion_core.session.branch import Branch
6
+ from lion_core.session.session import Session
7
+ from lion_service import iModel
8
+ from lionfuncs import alcall, to_list
9
+
10
+ from .config import DEFAULT_CHAT_CONFIG
11
+ from .score import score
12
+
13
# Per-item scoring instruction used by ``rank``: ``{choices}`` receives the
# whole candidate list and ``{item}`` the candidate currently being scored.
PROMPT = (
    "Given all items: \n {choices} \n\n"
    " Please follow prompt and give score to the item of interest: \n {item} \n\n"
)
17
+
18
+
19
async def rank(
    choices: list[Any],
    num_scorers: int = 5,
    instruction=None,
    guidance=None,
    context=None,
    system=None,
    reason: bool = False,
    actions: bool = False,
    tools: Any = None,
    imodel: iModel = None,
    branch: Branch = None,  # branch won't be used for the vote, it is for configuration
    clear_messages: bool = False,
    system_sender=None,
    system_datetime=None,
    num_parse_retries: int = 0,
    retry_imodel: iModel = None,
    return_session: bool = False,
    **kwargs,  # additional kwargs for score function
) -> list[dict] | tuple[list[dict], Session]:
    """Rank ``choices`` by averaging several independent scores per item.

    Each item is scored ``num_scorers`` times concurrently via ``score``.
    Every scorer runs on a fresh branch created from the session's default
    branch. Items are returned sorted by their average score, highest first.

    Args:
        choices: Items to rank; the full list is shown to the model while
            each individual item is scored.
        num_scorers: Number of independent ``score`` calls per item.
        instruction: Optional text placed before the per-item ``PROMPT``.
        branch: Used only as the session's default branch. Its messages seed
            every scorer branch, and its ``imodel`` is the fallback model. It
            does not take part in the vote itself.
        return_session: When True, also return the ``Session`` that was used.
        **kwargs: Extra arguments for ``score``. ``score_range``,
            ``num_scores`` and ``precision`` default to ``(1, 10)``, ``1`` and
            ``1`` when not supplied.

    Returns:
        A list of ``{"item", "scores", "average"}`` dicts sorted by
        ``average`` in descending order, or ``(results, session)`` when
        ``return_session`` is True. Responses whose score equals ``-1`` (the
        ``default_score`` passed to ``score``) are discarded. ``average`` is
        ``-1`` for an item with no usable scores.
    """
    if branch and branch.imodel:
        imodel = imodel or branch.imodel
    else:
        imodel = imodel or iModel(**DEFAULT_CHAT_CONFIG)
    # NOTE(review): ``imodel`` is only used to build the default branch below;
    # it is not forwarded to ``score`` — confirm this is intended.

    branch = branch or Branch(imodel=imodel)
    session = Session(default_branch=branch)

    # Resolve score() defaults once into a private dict, instead of mutating
    # the shared ``kwargs`` from inside every concurrently scheduled scorer.
    score_kwargs = {
        "score_range": (1, 10),
        "num_scores": 1,
        "precision": 1,
        **kwargs,
    }

    async def _score(item):
        # Give each scorer its own branch seeded with the default branch's
        # messages; creation happens under the session's branch lock.
        async with session.branches.async_lock:
            scorer_branch = session.new_branch(
                messages=session.default_branch.messages
            )

        prompt = PROMPT.format(choices=choices, item=item)
        if instruction:
            prompt = f"{instruction}\n\n{prompt} \n\n "

        response = await score(
            instruction=prompt,
            guidance=guidance,
            context=context,
            system=system,
            system_datetime=system_datetime,
            system_sender=system_sender,
            sender=session.ln_id,
            recipient=scorer_branch.ln_id,
            default_score=-1,
            reason=reason,
            actions=actions,
            tools=tools,
            clear_messages=clear_messages,
            num_parse_retries=num_parse_retries,
            retry_imodel=retry_imodel,
            branch=scorer_branch,
            **score_kwargs,
        )

        # -1 is the default_score passed above; presumably score() reports it
        # when no score could be obtained, so treat it as "no vote".
        if response.score == -1:
            return None
        return response

    async def _group_score(item):
        tasks = [asyncio.create_task(_score(item)) for _ in range(num_scorers)]
        responses = await asyncio.gather(*tasks)
        scores = to_list(
            [r.score for r in responses if r is not None],
            dropna=True,
            flatten=True,
        )
        return {
            "item": item,
            "scores": scores,
            "average": np.mean(scores) if scores else -1,
        }

    results = await alcall(choices, _group_score)
    results = sorted(results, key=lambda x: x["average"], reverse=True)
    if return_session:
        return results, session
    return results