lionagi 0.2.11__py3-none-any.whl → 0.3.0__py3-none-any.whl

Files changed (152)
  1. lionagi/core/action/function_calling.py +13 -6
  2. lionagi/core/action/tool.py +10 -9
  3. lionagi/core/action/tool_manager.py +18 -9
  4. lionagi/core/agent/README.md +1 -1
  5. lionagi/core/agent/base_agent.py +5 -2
  6. lionagi/core/agent/eval/README.md +1 -1
  7. lionagi/core/collections/README.md +1 -1
  8. lionagi/core/collections/_logger.py +16 -6
  9. lionagi/core/collections/abc/README.md +1 -1
  10. lionagi/core/collections/abc/component.py +35 -11
  11. lionagi/core/collections/abc/concepts.py +5 -3
  12. lionagi/core/collections/abc/exceptions.py +3 -1
  13. lionagi/core/collections/flow.py +16 -5
  14. lionagi/core/collections/model.py +34 -8
  15. lionagi/core/collections/pile.py +65 -28
  16. lionagi/core/collections/progression.py +1 -2
  17. lionagi/core/collections/util.py +11 -2
  18. lionagi/core/director/README.md +1 -1
  19. lionagi/core/engine/branch_engine.py +35 -10
  20. lionagi/core/engine/instruction_map_engine.py +14 -5
  21. lionagi/core/engine/sandbox_.py +3 -1
  22. lionagi/core/engine/script_engine.py +6 -2
  23. lionagi/core/executor/base_executor.py +10 -3
  24. lionagi/core/executor/graph_executor.py +12 -4
  25. lionagi/core/executor/neo4j_executor.py +18 -6
  26. lionagi/core/generic/edge.py +7 -2
  27. lionagi/core/generic/graph.py +23 -7
  28. lionagi/core/generic/node.py +14 -5
  29. lionagi/core/generic/tree_node.py +5 -1
  30. lionagi/core/mail/mail_manager.py +3 -1
  31. lionagi/core/mail/package.py +3 -1
  32. lionagi/core/message/action_request.py +9 -2
  33. lionagi/core/message/action_response.py +9 -3
  34. lionagi/core/message/instruction.py +8 -2
  35. lionagi/core/message/util.py +15 -5
  36. lionagi/core/report/base.py +12 -7
  37. lionagi/core/report/form.py +7 -4
  38. lionagi/core/report/report.py +10 -3
  39. lionagi/core/report/util.py +3 -1
  40. lionagi/core/rule/action.py +4 -1
  41. lionagi/core/rule/base.py +17 -6
  42. lionagi/core/rule/rulebook.py +8 -4
  43. lionagi/core/rule/string.py +3 -1
  44. lionagi/core/session/branch.py +15 -4
  45. lionagi/core/session/session.py +6 -2
  46. lionagi/core/unit/parallel_unit.py +9 -3
  47. lionagi/core/unit/template/action.py +1 -1
  48. lionagi/core/unit/template/predict.py +3 -1
  49. lionagi/core/unit/template/select.py +5 -3
  50. lionagi/core/unit/unit.py +4 -2
  51. lionagi/core/unit/unit_form.py +13 -15
  52. lionagi/core/unit/unit_mixin.py +45 -27
  53. lionagi/core/unit/util.py +7 -3
  54. lionagi/core/validator/validator.py +28 -15
  55. lionagi/core/work/work_edge.py +7 -3
  56. lionagi/core/work/work_task.py +11 -5
  57. lionagi/core/work/worker.py +20 -5
  58. lionagi/core/work/worker_engine.py +6 -2
  59. lionagi/core/work/worklog.py +3 -1
  60. lionagi/experimental/compressor/llm_compressor.py +20 -5
  61. lionagi/experimental/directive/README.md +1 -1
  62. lionagi/experimental/directive/parser/base_parser.py +41 -14
  63. lionagi/experimental/directive/parser/base_syntax.txt +23 -23
  64. lionagi/experimental/directive/template/base_template.py +14 -6
  65. lionagi/experimental/directive/tokenizer.py +3 -1
  66. lionagi/experimental/evaluator/README.md +1 -1
  67. lionagi/experimental/evaluator/ast_evaluator.py +6 -2
  68. lionagi/experimental/evaluator/base_evaluator.py +27 -16
  69. lionagi/integrations/bridge/autogen_/autogen_.py +7 -3
  70. lionagi/integrations/bridge/langchain_/documents.py +13 -10
  71. lionagi/integrations/bridge/llamaindex_/llama_pack.py +36 -12
  72. lionagi/integrations/bridge/llamaindex_/node_parser.py +8 -3
  73. lionagi/integrations/bridge/llamaindex_/reader.py +3 -1
  74. lionagi/integrations/bridge/llamaindex_/textnode.py +9 -3
  75. lionagi/integrations/bridge/pydantic_/pydantic_bridge.py +7 -1
  76. lionagi/integrations/bridge/transformers_/install_.py +3 -1
  77. lionagi/integrations/chunker/chunk.py +5 -2
  78. lionagi/integrations/loader/load.py +7 -3
  79. lionagi/integrations/loader/load_util.py +35 -16
  80. lionagi/integrations/provider/oai.py +13 -4
  81. lionagi/integrations/provider/openrouter.py +13 -4
  82. lionagi/integrations/provider/services.py +3 -1
  83. lionagi/integrations/provider/transformers.py +5 -3
  84. lionagi/integrations/storage/neo4j.py +23 -7
  85. lionagi/integrations/storage/storage_util.py +23 -7
  86. lionagi/integrations/storage/structure_excel.py +7 -2
  87. lionagi/integrations/storage/to_csv.py +8 -2
  88. lionagi/integrations/storage/to_excel.py +11 -3
  89. lionagi/libs/ln_api.py +41 -19
  90. lionagi/libs/ln_context.py +4 -4
  91. lionagi/libs/ln_convert.py +35 -14
  92. lionagi/libs/ln_dataframe.py +9 -3
  93. lionagi/libs/ln_func_call.py +53 -18
  94. lionagi/libs/ln_image.py +9 -5
  95. lionagi/libs/ln_knowledge_graph.py +21 -7
  96. lionagi/libs/ln_nested.py +57 -16
  97. lionagi/libs/ln_parse.py +45 -15
  98. lionagi/libs/ln_queue.py +8 -3
  99. lionagi/libs/ln_tokenize.py +19 -6
  100. lionagi/libs/ln_validate.py +14 -3
  101. lionagi/libs/sys_util.py +44 -12
  102. lionagi/lions/coder/coder.py +24 -8
  103. lionagi/lions/coder/util.py +6 -2
  104. lionagi/lions/researcher/data_source/google_.py +12 -4
  105. lionagi/lions/researcher/data_source/wiki_.py +3 -1
  106. lionagi/version.py +1 -1
  107. {lionagi-0.2.11.dist-info → lionagi-0.3.0.dist-info}/METADATA +6 -7
  108. lionagi-0.3.0.dist-info/RECORD +226 -0
  109. lionagi/tests/__init__.py +0 -0
  110. lionagi/tests/api/__init__.py +0 -0
  111. lionagi/tests/api/aws/__init__.py +0 -0
  112. lionagi/tests/api/aws/conftest.py +0 -25
  113. lionagi/tests/api/aws/test_aws_s3.py +0 -6
  114. lionagi/tests/integrations/__init__.py +0 -0
  115. lionagi/tests/libs/__init__.py +0 -0
  116. lionagi/tests/libs/test_api.py +0 -48
  117. lionagi/tests/libs/test_convert.py +0 -89
  118. lionagi/tests/libs/test_field_validators.py +0 -354
  119. lionagi/tests/libs/test_func_call.py +0 -701
  120. lionagi/tests/libs/test_nested.py +0 -382
  121. lionagi/tests/libs/test_parse.py +0 -171
  122. lionagi/tests/libs/test_queue.py +0 -68
  123. lionagi/tests/libs/test_sys_util.py +0 -222
  124. lionagi/tests/test_core/__init__.py +0 -0
  125. lionagi/tests/test_core/collections/__init__.py +0 -0
  126. lionagi/tests/test_core/collections/test_component.py +0 -208
  127. lionagi/tests/test_core/collections/test_exchange.py +0 -139
  128. lionagi/tests/test_core/collections/test_flow.py +0 -146
  129. lionagi/tests/test_core/collections/test_pile.py +0 -172
  130. lionagi/tests/test_core/collections/test_progression.py +0 -130
  131. lionagi/tests/test_core/generic/__init__.py +0 -0
  132. lionagi/tests/test_core/generic/test_edge.py +0 -69
  133. lionagi/tests/test_core/generic/test_graph.py +0 -97
  134. lionagi/tests/test_core/generic/test_node.py +0 -107
  135. lionagi/tests/test_core/generic/test_structure.py +0 -194
  136. lionagi/tests/test_core/generic/test_tree_node.py +0 -74
  137. lionagi/tests/test_core/graph/__init__.py +0 -0
  138. lionagi/tests/test_core/graph/test_graph.py +0 -71
  139. lionagi/tests/test_core/graph/test_tree.py +0 -76
  140. lionagi/tests/test_core/mail/__init__.py +0 -0
  141. lionagi/tests/test_core/mail/test_mail.py +0 -98
  142. lionagi/tests/test_core/test_branch.py +0 -116
  143. lionagi/tests/test_core/test_form.py +0 -47
  144. lionagi/tests/test_core/test_report.py +0 -106
  145. lionagi/tests/test_core/test_structure/__init__.py +0 -0
  146. lionagi/tests/test_core/test_structure/test_base_structure.py +0 -198
  147. lionagi/tests/test_core/test_structure/test_graph.py +0 -55
  148. lionagi/tests/test_core/test_structure/test_tree.py +0 -49
  149. lionagi/tests/test_core/test_validator.py +0 -112
  150. lionagi-0.2.11.dist-info/RECORD +0 -267
  151. {lionagi-0.2.11.dist-info → lionagi-0.3.0.dist-info}/LICENSE +0 -0
  152. {lionagi-0.2.11.dist-info → lionagi-0.3.0.dist-info}/WHEEL +0 -0
lionagi/libs/ln_api.py CHANGED
@@ -3,9 +3,9 @@ import contextlib
 import logging
 import re
 from abc import ABC
-from collections.abc import Mapping, Sequence
+from collections.abc import Callable, Mapping, Sequence
 from dataclasses import dataclass
-from typing import Any, Callable, NoReturn, Type
+from typing import Any, NoReturn, Type
 
 import aiohttp
 
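The import hunk above is representative of a pattern applied throughout this release: abstract classes such as Callable, Mapping, and Sequence are now imported from collections.abc rather than typing, and typing.Type[...] gives way to the builtin type[...]. A minimal sketch of the before/after pattern (illustrative only, not code from the package):

```python
# Illustrative sketch of the typing -> collections.abc / builtin-generics migration.
# Old style (works, but these typing aliases are deprecated since Python 3.9):
#     from typing import Callable, Type
# New style:
from collections.abc import Callable


def apply(func: Callable[[int], int], value: int) -> int:
    """Annotate callables with collections.abc.Callable."""
    return func(value)


def construct(cls: type[list]) -> list:
    """PEP 585: the builtin type[...] replaces typing.Type[...]."""
    return cls()


if __name__ == "__main__":
    print(apply(lambda x: x * 2, 21))  # 42
    print(construct(list))             # []
```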
@@ -70,7 +70,9 @@ class APIUtil:
         False
         """
         if "error" in response_json:
-            logging.warning(f"API call failed with error: {response_json['error']}")
+            logging.warning(
+                f"API call failed with error: {response_json['error']}"
+            )
             return True
         return False
 
@@ -93,7 +95,9 @@ class APIUtil:
         >>> api_rate_limit_error(response_json_without_rate_limit)
         False
         """
-        return "Rate limit" in response_json.get("error", {}).get("message", "")
+        return "Rate limit" in response_json.get("error", {}).get(
+            "message", ""
+        )
 
     @staticmethod
     @func_call.lru_cache(maxsize=128)
@@ -389,15 +393,15 @@ class APIUtil:
                         num_tokens += ImageUtil.calculate_image_token_usage_from_base64(
                             a, item.get("detail", "low")
                         )
-                        num_tokens += (
-                            20  # for every image we add 20 tokens buffer
-                        )
+                        num_tokens += 20  # for every image we add 20 tokens buffer
                     elif isinstance(item, str):
                         num_tokens += len(encoding.encode(item))
                     else:
                         num_tokens += len(encoding.encode(str(item)))
 
-            num_tokens += 2  # every reply is primed with <im_start>assistant
+            num_tokens += (
+                2  # every reply is primed with <im_start>assistant
+            )
             return num_tokens + completion_tokens
         else:
             prompt = payload["format_prompt"]
@@ -405,7 +409,9 @@ class APIUtil:
             prompt_tokens = len(encoding.encode(prompt))
             return prompt_tokens + completion_tokens
         elif isinstance(prompt, list):  # multiple prompts
-            prompt_tokens = sum(len(encoding.encode(p)) for p in prompt)
+            prompt_tokens = sum(
+                len(encoding.encode(p)) for p in prompt
+            )
             return prompt_tokens + completion_tokens * len(prompt)
         else:
             raise TypeError(
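The two hunks above sit inside APIUtil's token calculator: image items get a flat 20-token buffer on top of their base64-derived cost, every reply is primed with 2 extra tokens, and string prompts are encoded with the configured tiktoken encoding. A rough standalone sketch of the same accounting, assuming the cl100k_base encoding and simplified message handling (not lionagi's exact implementation):

```python
# Rough sketch of the token accounting in the hunks above; simplified, not lionagi's code.
import tiktoken

IMAGE_BUFFER = 20   # flat buffer added per image item in the diff above
REPLY_PRIMING = 2   # "<im_start>assistant" priming per reply


def estimate_chat_tokens(messages: list[dict], completion_tokens: int = 0) -> int:
    enc = tiktoken.get_encoding("cl100k_base")  # assumed encoding
    num_tokens = 0
    for message in messages:
        content = message.get("content", "")
        if isinstance(content, list):  # multimodal content parts
            for item in content:
                if isinstance(item, dict) and "image_url" in item:
                    num_tokens += IMAGE_BUFFER  # detail-based image cost omitted here
                else:
                    num_tokens += len(enc.encode(str(item)))
        else:
            num_tokens += len(enc.encode(str(content)))
    num_tokens += REPLY_PRIMING
    return num_tokens + completion_tokens


if __name__ == "__main__":
    print(estimate_chat_tokens([{"role": "user", "content": "hello there"}]))
```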
@@ -427,7 +433,9 @@
         )
 
     @staticmethod
-    def create_payload(input_, config, required_, optional_, input_key, **kwargs):
+    def create_payload(
+        input_, config, required_, optional_, input_key, **kwargs
+    ):
         config = {**config, **kwargs}
         payload = {input_key: input_}
 
@@ -514,7 +522,9 @@ class BaseRateLimiter(ABC):
         except asyncio.CancelledError:
             logging.info("Rate limit replenisher task cancelled.")
         except Exception as e:
-            logging.error(f"An error occurred in the rate limit replenisher: {e}")
+            logging.error(
+                f"An error occurred in the rate limit replenisher: {e}"
+            )
 
     async def stop_replenishing(self) -> None:
         """Stops the replenishment task."""
@@ -658,7 +668,9 @@ class SimpleRateLimiter(BaseRateLimiter):
         token_encoding_name=None,
     ) -> None:
         """Initializes the SimpleRateLimiter with the specified parameters."""
-        super().__init__(max_requests, max_tokens, interval, token_encoding_name)
+        super().__init__(
+            max_requests, max_tokens, interval, token_encoding_name
+        )
 
 
 class EndPoint:
@@ -695,7 +707,7 @@ class EndPoint:
         max_tokens: int = 100000,
         interval: int = 60,
         endpoint_: str | None = None,
-        rate_limiter_class: Type[BaseRateLimiter] = SimpleRateLimiter,
+        rate_limiter_class: type[BaseRateLimiter] = SimpleRateLimiter,
         token_encoding_name=None,
         config: Mapping = None,
     ) -> None:
@@ -712,7 +724,10 @@
     async def init_rate_limiter(self) -> None:
         """Initializes the rate limiter for the endpoint."""
         self.rate_limiter = await self.rate_limiter_class.create(
-            self.max_requests, self.max_tokens, self.interval, self.token_encoding_name
+            self.max_requests,
+            self.max_tokens,
+            self.interval,
+            self.token_encoding_name,
         )
         self._has_initialized = True
 
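As the EndPoint hunks above show, the rate limiter is not built in __init__; it is created later through an awaited factory (rate_limiter_class.create(...)) inside init_rate_limiter, since the limiter's replenisher task needs a running event loop. A minimal sketch of that deferred-async-initialization pattern, with simplified stand-in classes rather than the package's own:

```python
# Minimal sketch of deferring async setup out of __init__, as EndPoint does above.
# Class and attribute names are simplified stand-ins, not lionagi's classes.
import asyncio


class TokenBucket:
    def __init__(self, max_requests: int, interval: float) -> None:
        self.max_requests = max_requests
        self.interval = interval

    @classmethod
    async def create(cls, max_requests: int, interval: float) -> "TokenBucket":
        # async factory: the safe place to start replenisher tasks, open sessions, etc.
        limiter = cls(max_requests, interval)
        await asyncio.sleep(0)  # placeholder for real async setup
        return limiter


class Endpoint:
    def __init__(self, max_requests: int = 1000, interval: float = 60.0) -> None:
        self.max_requests = max_requests
        self.interval = interval
        self.rate_limiter: TokenBucket | None = None
        self._has_initialized = False

    async def init_rate_limiter(self) -> None:
        self.rate_limiter = await TokenBucket.create(self.max_requests, self.interval)
        self._has_initialized = True


async def main() -> None:
    ep = Endpoint()
    if not ep._has_initialized:  # mirrors the lazy-init check in the diff above
        await ep.init_rate_limiter()
    print(ep.rate_limiter.max_requests)


if __name__ == "__main__":
    asyncio.run(main())
```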
@@ -786,7 +801,9 @@ class BaseService:
                     max_tokens=self.chat_config_rate_limit.get(
                         "max_tokens", 100000
                     ),
-                    interval=self.chat_config_rate_limit.get("interval", 60),
+                    interval=self.chat_config_rate_limit.get(
+                        "interval", 60
+                    ),
                     endpoint_=ep,
                     token_encoding_name=self.token_encoding_name,
                     config=endpoint_config,
@@ -795,17 +812,20 @@
                 self.endpoints[ep] = EndPoint(
                     max_requests=(
                         endpoint_config.get("max_requests", 1000)
-                        if endpoint_config.get("max_requests", 1000) is not None
+                        if endpoint_config.get("max_requests", 1000)
+                        is not None
                         else 1000
                     ),
                     max_tokens=(
                         endpoint_config.get("max_tokens", 100000)
-                        if endpoint_config.get("max_tokens", 100000) is not None
+                        if endpoint_config.get("max_tokens", 100000)
+                        is not None
                         else 100000
                     ),
                     interval=(
                         endpoint_config.get("interval", 60)
-                        if endpoint_config.get("interval", 60) is not None
+                        if endpoint_config.get("interval", 60)
+                        is not None
                         else 60
                     ),
                     endpoint_=ep,
@@ -832,7 +852,9 @@ class BaseService:
         if not self.endpoints[ep]._has_initialized:
             await self.endpoints[ep].init_rate_limiter()
 
-    async def call_api(self, payload, endpoint, method, required_tokens=None, **kwargs):
+    async def call_api(
+        self, payload, endpoint, method, required_tokens=None, **kwargs
+    ):
         """
         Calls the specified API endpoint with the given payload and method.
 
lionagi/libs/ln_context.py CHANGED
@@ -26,12 +26,12 @@ async def async_suppress_print():
     """
    An asynchronous context manager that redirects stdout to /dev/null to suppress print output.
     """
-    original_stdout = sys.stdout  # Save the reference to the original standard output
+    original_stdout = (
+        sys.stdout
+    )  # Save the reference to the original standard output
     with open(os.devnull, "w") as devnull:
         sys.stdout = devnull
         try:
             yield
         finally:
-            sys.stdout = (
-                original_stdout  # Restore standard output to the original value
-            )
+            sys.stdout = original_stdout  # Restore standard output to the original value
lionagi/libs/ln_convert.py CHANGED
@@ -1,7 +1,8 @@
 import json
 import re
+from collections.abc import Generator, Iterable
 from functools import singledispatch
-from typing import Any, Generator, Iterable, Type
+from typing import Any, Type
 
 import pandas as pd
 from pydantic import BaseModel
@@ -11,7 +12,9 @@ number_regex = re.compile(r"-?\d+\.?\d*")
 
 # to_list functions with datatype overloads
 @singledispatch
-def to_list(input_, /, *, flatten: bool = True, dropna: bool = True) -> list[Any]:
+def to_list(
+    input_, /, *, flatten: bool = True, dropna: bool = True
+) -> list[Any]:
     """
     Converts the input object to a list. This function is capable of handling various input types,
     utilizing single dispatch to specialize for different types such as list, tuple, and set.
@@ -43,9 +46,13 @@ def to_list(input_, /, *, flatten: bool = True, dropna: bool = True) -> list[Any
         ):
             return [input_]
         iterable_list = list(input_)
-        return _flatten_list(iterable_list, dropna) if flatten else iterable_list
+        return (
+            _flatten_list(iterable_list, dropna) if flatten else iterable_list
+        )
     except Exception as e:
-        raise ValueError(f"Could not convert {type(input_)} object to list: {e}") from e
+        raise ValueError(
+            f"Could not convert {type(input_)} object to list: {e}"
+        ) from e
 
 
 @to_list.register(list)
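The to_list hunks above rely on functools.singledispatch: a generic fallback implementation plus per-type overloads registered with @to_list.register(...). A self-contained sketch of that dispatch pattern, simplified relative to the package's full to_list:

```python
# Sketch of the singledispatch pattern behind to_list above; simplified behavior.
from collections.abc import Iterable
from functools import singledispatch
from typing import Any


@singledispatch
def to_list(input_, /, *, flatten: bool = True) -> list[Any]:
    # generic fallback: expand non-string iterables, wrap everything else
    if isinstance(input_, Iterable) and not isinstance(input_, (str, bytes, dict)):
        return list(input_)
    return [input_]


@to_list.register(list)
def _(input_: list, /, *, flatten: bool = True) -> list[Any]:
    # list overload: optionally flatten one level of nesting
    if not flatten:
        return input_
    out: list[Any] = []
    for item in input_:
        if isinstance(item, list):
            out.extend(item)
        else:
            out.append(item)
    return out


@to_list.register(type(None))
def _(input_: None, /, *, flatten: bool = True) -> list[Any]:
    return []


if __name__ == "__main__":
    print(to_list(3))            # [3]
    print(to_list([1, [2, 3]]))  # [1, 2, 3]
    print(to_list((4, 5)))       # [4, 5]
    print(to_list(None))         # []
```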
@@ -362,7 +369,9 @@ def _(
 ) -> pd.DataFrame:
     if not input_:
         return pd.DataFrame()
-    if not isinstance(input_[0], (pd.DataFrame, pd.Series, pd.core.generic.NDFrame)):
+    if not isinstance(
+        input_[0], (pd.DataFrame, pd.Series, pd.core.generic.NDFrame)
+    ):
         if drop_kwargs is None:
             drop_kwargs = {}
         try:
@@ -370,7 +379,9 @@ def _(
             dfs = dfs.dropna(**(drop_kwargs | {"how": how}))
             return dfs.reset_index(drop=True) if reset_index else dfs
         except Exception as e:
-            raise ValueError(f"Error converting input_ to DataFrame: {e}") from e
+            raise ValueError(
+                f"Error converting input_ to DataFrame: {e}"
+            ) from e
 
     dfs = ""
     if drop_kwargs is None:
@@ -401,7 +412,7 @@ def to_num(
     *,
     upper_bound: int | float | None = None,
     lower_bound: int | float | None = None,
-    num_type: Type[int | float] = float,
+    num_type: type[int | float] = float,
     precision: int | None = None,
 ) -> int | float:
     """
@@ -431,13 +442,17 @@ def to_readable_dict(input_: Any) -> str:
 
     try:
         dict_ = to_dict(input_)
-        return json.dumps(dict_, indent=4) if isinstance(input_, dict) else input_
+        return (
+            json.dumps(dict_, indent=4) if isinstance(input_, dict) else input_
+        )
     except Exception as e:
-        raise ValueError(f"Could not convert given input to readable dict: {e}") from e
+        raise ValueError(
+            f"Could not convert given input to readable dict: {e}"
+        ) from e
 
 
 def is_same_dtype(
-    input_: list | dict, dtype: Type | None = None, return_dtype=False
+    input_: list | dict, dtype: type | None = None, return_dtype=False
 ) -> bool:
     """
     Checks if all elements in a list or dictionary values are of the same data type.
@@ -496,7 +511,9 @@ def strip_lower(input_: Any) -> str:
     try:
         return str(input_).strip().lower()
     except Exception as e:
-        raise ValueError(f"Could not convert input_ to string: {input_}, Error: {e}")
+        raise ValueError(
+            f"Could not convert input_ to string: {input_}, Error: {e}"
+        )
 
 
 def is_structure_homogeneous(
@@ -552,7 +569,9 @@ is_structure_homogeneous(
     return (is_, structure_type) if return_structure_type else is_
 
 
-def is_homogeneous(iterables: list[Any] | dict[Any, Any], type_check: type) -> bool:
+def is_homogeneous(
+    iterables: list[Any] | dict[Any, Any], type_check: type
+) -> bool:
     if isinstance(iterables, list):
         return all(isinstance(it, type_check) for it in iterables)
     return isinstance(iterables, type_check)
@@ -562,7 +581,7 @@ def _str_to_num(
     input_: str,
     upper_bound: float | None = None,
     lower_bound: float | None = None,
-    num_type: Type[int | float] = int,
+    num_type: type[int | float] = int,
     precision: int | None = None,
 ) -> int | float:
     number_str = _extract_first_number(input_)
@@ -590,7 +609,9 @@ def _extract_first_number(input_: str) -> str | None:
 
 
 def _convert_to_num(
-    number_str: str, num_type: Type[int | float] = int, precision: int | None = None
+    number_str: str,
+    num_type: type[int | float] = int,
+    precision: int | None = None,
 ) -> int | float:
     if num_type is int:
         return int(float(number_str))
lionagi/libs/ln_dataframe.py CHANGED
@@ -31,7 +31,9 @@ def extend_dataframe(
     try:
         if len(df2.dropna(how="all")) > 0 and len(df1.dropna(how="all")) > 0:
             df = convert.to_df([df1, df2])
-            df.drop_duplicates(inplace=True, subset=[unique_col], keep=keep, **kwargs)
+            df.drop_duplicates(
+                inplace=True, subset=[unique_col], keep=keep, **kwargs
+            )
             df_ = convert.to_df(df)
             if len(df_) > 1:
                 return df_
@@ -112,7 +114,9 @@ def replace_keyword(
             keyword, replacement, case=False, regex=False
         )
     else:
-        df_.loc[:, column] = df_[column].str.replace(keyword, replacement, regex=False)
+        df_.loc[:, column] = df_[column].str.replace(
+            keyword, replacement, regex=False
+        )
 
     return df_ if inplace else True
 
@@ -160,7 +164,9 @@ def remove_last_n_rows(df: pd.DataFrame, steps: int) -> pd.DataFrame:
     return convert.to_df(df[:-steps])
 
 
-def update_row(df: pd.DataFrame, row: str | int, column: str | int, value: Any) -> bool:
+def update_row(
+    df: pd.DataFrame, row: str | int, column: str | int, value: Any
+) -> bool:
     """
     Updates a row's value for a specified column in a DataFrame.
 
lionagi/libs/ln_func_call.py CHANGED
@@ -3,8 +3,9 @@ from __future__ import annotations
 import asyncio
 import functools
 import logging
+from collections.abc import Callable, Coroutine
 from concurrent.futures import ThreadPoolExecutor
-from typing import Any, Callable, Coroutine
+from typing import Any
 
 from lionagi.libs.ln_async import AsyncUtil
 from lionagi.libs.ln_convert import to_list
@@ -65,9 +66,13 @@ def lcall(
     """
     lst = to_list(input_, dropna=dropna)
     if len(to_list(func)) != 1:
-        raise ValueError("There must be one and only one function for list calling.")
+        raise ValueError(
+            "There must be one and only one function for list calling."
+        )
 
-    return to_list([func(i, **kwargs) for i in lst], flatten=flatten, dropna=dropna)
+    return to_list(
+        [func(i, **kwargs) for i in lst], flatten=flatten, dropna=dropna
+    )
 
 
 async def alcall(
@@ -122,7 +127,9 @@ async def alcall(
     outs = await asyncio.gather(*tasks)
     outs_ = []
     for i in outs:
-        outs_.append(await i if isinstance(i, (Coroutine, asyncio.Future)) else i)
+        outs_.append(
+            await i if isinstance(i, (Coroutine, asyncio.Future)) else i
+        )
 
     return to_list(outs_, flatten=flatten, dropna=dropna)
 
@@ -260,7 +267,9 @@ async def tcall(
     async def async_call() -> tuple[Any, float]:
         start_time = SysUtil.get_now(datetime_=False)
         if timeout is not None:
-            result = await AsyncUtil.execute_timeout(func(*args, **kwargs), timeout)
+            result = await AsyncUtil.execute_timeout(
+                func(*args, **kwargs), timeout
+            )
             duration = SysUtil.get_now(datetime_=False) - start_time
             return (result, duration) if timing else result
         try:
@@ -282,12 +291,18 @@ async def tcall(
             handle_error(e)
 
     def handle_error(e: Exception):
-        _msg = f"{err_msg} Error: {e}" if err_msg else f"An error occurred: {e}"
+        _msg = (
+            f"{err_msg} Error: {e}" if err_msg else f"An error occurred: {e}"
+        )
         print(_msg)
         if not ignore_err:
             raise
 
-    return await async_call() if AsyncUtil.is_coroutine_func(func) else sync_call()
+    return (
+        await async_call()
+        if AsyncUtil.is_coroutine_func(func)
+        else sync_call()
+    )
 
 
 async def rcall(
@@ -348,7 +363,9 @@ async def rcall(
     start = SysUtil.get_now(datetime_=False)
     for attempt in range(retries + 1) if retries == 0 else range(retries):
         try:
-            err_msg = f"Attempt {attempt + 1}/{retries}: " if retries > 0 else None
+            err_msg = (
+                f"Attempt {attempt + 1}/{retries}: " if retries > 0 else None
+            )
             if timing:
                 return (
                     await _tcall(
@@ -362,7 +379,9 @@ async def rcall(
             last_exception = e
             if attempt < retries:
                 if verbose:
-                    print(f"Attempt {attempt + 1}/{retries} failed: {e}, retrying...")
+                    print(
+                        f"Attempt {attempt + 1}/{retries} failed: {e}, retrying..."
+                    )
                 await asyncio.sleep(delay)
                 delay *= backoff_factor
             else:
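The rcall hunks above implement retry with exponential backoff: each failed attempt sleeps for the current delay, then multiplies it by backoff_factor before retrying, optionally printing progress when verbose. A compact standalone sketch of that loop with a simplified signature (not the package's rcall):

```python
# Compact sketch of the retry-with-exponential-backoff loop used by rcall above;
# simplified signature and error handling, not lionagi's implementation.
import asyncio
from collections.abc import Awaitable, Callable
from typing import Any


async def retry(
    func: Callable[..., Awaitable[Any]],
    *args: Any,
    retries: int = 3,
    delay: float = 0.5,
    backoff_factor: float = 2.0,
    verbose: bool = True,
    **kwargs: Any,
) -> Any:
    last_exception: Exception | None = None
    for attempt in range(retries):
        try:
            return await func(*args, **kwargs)
        except Exception as e:
            last_exception = e
            if attempt < retries - 1:
                if verbose:
                    print(f"Attempt {attempt + 1}/{retries} failed: {e}, retrying...")
                await asyncio.sleep(delay)
                delay *= backoff_factor  # exponential backoff
    raise RuntimeError("all retries failed") from last_exception


if __name__ == "__main__":
    state = {"calls": 0}

    async def flaky() -> str:
        state["calls"] += 1
        if state["calls"] < 3:
            raise ValueError("transient failure")
        return "ok"

    print(asyncio.run(retry(flaky)))  # fails twice, then prints "ok"
```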
@@ -485,9 +504,13 @@ async def _tcall(
                 else default
             )
         else:
-            raise asyncio.TimeoutError(err_msg)  # Re-raise the timeout exception
+            raise asyncio.TimeoutError(
+                err_msg
+            )  # Re-raise the timeout exception
     except Exception as e:
-        err_msg = f"{err_msg} Error: {e}" if err_msg else f"An error occurred: {e}"
+        err_msg = (
+            f"{err_msg} Error: {e}" if err_msg else f"An error occurred: {e}"
+        )
         if ignore_err:
             return (
                 (default, SysUtil.get_now(datetime_=False) - start_time)
@@ -645,7 +668,9 @@ class CallDecorator:
         def decorator(func: Callable[..., Any]) -> Callable:
             @functools.wraps(func)
             async def wrapper(*args, **kwargs) -> Any:
-                return await rcall(func, *args, default=default_value, **kwargs)
+                return await rcall(
+                    func, *args, default=default_value, **kwargs
+                )
 
             return wrapper
 
@@ -863,8 +888,12 @@ class CallDecorator:
                     k: preprocess(v, *preprocess_args, **preprocess_kwargs)
                     for k, v in kwargs.items()
                 }
-                result = await func(*preprocessed_args, **preprocessed_kwargs)
-                return postprocess(result, *postprocess_args, **postprocess_kwargs)
+                result = await func(
+                    *preprocessed_args, **preprocessed_kwargs
+                )
+                return postprocess(
+                    result, *postprocess_args, **postprocess_kwargs
+                )
 
             return async_wrapper
         else:
@@ -880,7 +909,9 @@ class CallDecorator:
                     for k, v in kwargs.items()
                 }
                 result = func(*preprocessed_args, **preprocessed_kwargs)
-                return postprocess(result, *postprocess_args, **postprocess_kwargs)
+                return postprocess(
+                    result, *postprocess_args, **postprocess_kwargs
+                )
 
             return sync_wrapper
 
@@ -1165,7 +1196,9 @@ class Throttle:
 
         return wrapper
 
-    async def __call_async__(self, func: Callable[..., Any]) -> Callable[..., Any]:
+    async def __call_async__(
+        self, func: Callable[..., Any]
+    ) -> Callable[..., Any]:
         """
         Decorates an asynchronous function with the throttling mechanism.
 
@@ -1187,7 +1220,9 @@ class Throttle:
         return wrapper
 
 
-def _custom_error_handler(error: Exception, error_map: dict[type, Callable]) -> None:
+def _custom_error_handler(
+    error: Exception, error_map: dict[type, Callable]
+) -> None:
     # noinspection PyUnresolvedReferences
     """
     handle errors based on a given error mapping.
@@ -1262,7 +1297,7 @@ async def call_handler(
         raise
 
 
-@functools.lru_cache(maxsize=None)
+@functools.cache
 def is_coroutine_func(func: Callable) -> bool:
     """
     checks if the specified function is an asyncio coroutine function.
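The final hunk replaces functools.lru_cache(maxsize=None) with functools.cache, added in Python 3.9 as a thinner alias for the same unbounded memoization. A quick illustration (the body here is a stand-in check; the package delegates to its own async utilities):

```python
# functools.cache (Python 3.9+) behaves exactly like functools.lru_cache(maxsize=None).
import asyncio
import functools
from collections.abc import Callable


@functools.cache
def is_coroutine_func(func: Callable) -> bool:
    # stand-in body: memoizes the coroutine-function check per callable
    return asyncio.iscoroutinefunction(func)


async def greet() -> str:
    return "hi"


if __name__ == "__main__":
    print(is_coroutine_func(greet))        # True
    print(is_coroutine_func(print))        # False
    print(is_coroutine_func.cache_info())  # hits/misses, maxsize=None
```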
lionagi/libs/ln_image.py CHANGED
@@ -10,7 +10,7 @@ class ImageUtil:
 
     @staticmethod
     def preprocess_image(
-        image: np.ndarray, color_conversion_code: Optional[int] = None
+        image: np.ndarray, color_conversion_code: int | None = None
     ) -> np.ndarray:
         SysUtil.check_import("cv2", pip_name="opencv-python")
         import cv2
@@ -19,19 +19,23 @@ class ImageUtil:
         return cv2.cvtColor(image, color_conversion_code)
 
     @staticmethod
-    def encode_image_to_base64(image: np.ndarray, file_extension: str = ".jpg") -> str:
+    def encode_image_to_base64(
+        image: np.ndarray, file_extension: str = ".jpg"
+    ) -> str:
         SysUtil.check_import("cv2", pip_name="opencv-python")
         import cv2
 
         success, buffer = cv2.imencode(file_extension, image)
         if not success:
-            raise ValueError(f"Could not encode image to {file_extension} format.")
+            raise ValueError(
+                f"Could not encode image to {file_extension} format."
+            )
         encoded_image = base64.b64encode(buffer).decode("utf-8")
         return encoded_image
 
     @staticmethod
     def read_image_to_array(
-        image_path: str, color_flag: Optional[int] = None
+        image_path: str, color_flag: int | None = None
     ) -> np.ndarray:
         SysUtil.check_import("cv2", pip_name="opencv-python")
         import cv2
@@ -45,7 +49,7 @@ class ImageUtil:
     @staticmethod
     def read_image_to_base64(
         image_path: str,
-        color_flag: Optional[int] = None,
+        color_flag: int | None = None,
     ) -> str:
         image_path = str(image_path)
         image = ImageUtil.read_image_to_array(image_path, color_flag)
lionagi/libs/ln_knowledge_graph.py CHANGED
@@ -33,7 +33,9 @@ class KnowledgeBase:
         Initialize an empty Knowledge Base (KB) with empty dictionaries for entities, relations, and sources.
         """
         self.entities = {}  # { entity_title: {...} }
-        self.relations = []  # [ head: entity_title, type: ..., tail: entity_title,
+        self.relations = (
+            []
+        )  # [ head: entity_title, type: ..., tail: entity_title,
         # meta: { article_url: { spans: [...] } } ]
         self.sources = {}  # { article_url: {...} }
 
@@ -48,7 +50,9 @@ class KnowledgeBase:
             article_url = list(r["meta"].keys())[0]
             source_data = kb2.sources[article_url]
             self.add_relation(
-                r, source_data["article_title"], source_data["article_publish_date"]
+                r,
+                source_data["article_title"],
+                source_data["article_publish_date"],
             )
 
     def are_relations_equal(self, r1, r2):
@@ -137,7 +141,9 @@ class KnowledgeBase:
         Args:
             e (dict): A dictionary containing information about the entity (title and additional attributes).
         """
-        self.entities[e["title"]] = {k: v for k, v in e.items() if k != "title"}
+        self.entities[e["title"]] = {
+            k: v for k, v in e.items() if k != "title"
+        }
 
     def add_relation(self, r, article_title, article_publish_date):
         """
@@ -210,7 +216,9 @@ class KnowledgeBase:
         relation, subject, relation, object_ = "", "", "", ""
         text = text.strip()
         current = "x"
-        text_replaced = text.replace("<s>", "").replace("<pad>", "").replace("</s>", "")
+        text_replaced = (
+            text.replace("<s>", "").replace("<pad>", "").replace("</s>", "")
+        )
         for token in text_replaced.split():
             if token == "<triplet>":
                 current = "t"
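The hunk above belongs to the parser for REBEL's linearized output: special tokens <s>, <pad>, and </s> are stripped, and <triplet>, <subj>, <obj> markers switch the parser between collecting the head entity, tail entity, and relation label. A minimal independent sketch of that state machine, for illustration only (not the KnowledgeBase method itself):

```python
# Minimal sketch of parsing REBEL-style linearized output into triplets;
# an independent re-implementation for illustration, not lionagi's method.
def parse_rebel_output(text: str) -> list[dict[str, str]]:
    triplets: list[dict[str, str]] = []
    head: list[str] = []
    tail: list[str] = []
    relation: list[str] = []
    current = None
    cleaned = text.replace("<s>", "").replace("<pad>", "").replace("</s>", "")

    def flush() -> None:
        # emit a triplet once head, relation, and tail have all been collected
        if head and relation and tail:
            triplets.append(
                {"head": " ".join(head), "type": " ".join(relation), "tail": " ".join(tail)}
            )

    for token in cleaned.split():
        if token == "<triplet>":
            flush()
            head, tail, relation = [], [], []
            current = "head"
        elif token == "<subj>":
            current = "tail"      # text after <subj> is the tail entity
        elif token == "<obj>":
            current = "relation"  # text after <obj> is the relation label
        elif current == "head":
            head.append(token)
        elif current == "tail":
            tail.append(token)
        elif current == "relation":
            relation.append(token)
    flush()
    return triplets


if __name__ == "__main__":
    sample = "<s><triplet> Rome <subj> Italy <obj> capital of</s>"
    print(parse_rebel_output(sample))
    # [{'head': 'Rome', 'type': 'capital of', 'tail': 'Italy'}]
```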
@@ -311,7 +319,9 @@ class KGTripletExtractor:
 
         if not any([model, tokenizer]):
             tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
-            model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")
+            model = AutoModelForSeq2SeqLM.from_pretrained(
+                "Babelscape/rebel-large"
+            )
             model.to(device)
 
         inputs = tokenizer([text], return_tensors="pt")
@@ -373,10 +383,14 @@ class KGTripletExtractor:
         i = 0
         for sentence_pred in decoded_preds:
             current_span_index = i // num_return_sequences
-            relations = KnowledgeBase.extract_relations_from_model_output(sentence_pred)
+            relations = KnowledgeBase.extract_relations_from_model_output(
+                sentence_pred
+            )
             for relation in relations:
                 relation["meta"] = {
-                    "article_url": {"spans": [spans_boundaries[current_span_index]]}
+                    "article_url": {
+                        "spans": [spans_boundaries[current_span_index]]
+                    }
                 }
                 kb.add_relation(relation, article_title, article_publish_date)
                 i += 1