lionagi 0.0.312__py3-none-any.whl → 0.2.1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (268) hide show
  1. lionagi/__init__.py +61 -3
  2. lionagi/core/__init__.py +0 -14
  3. lionagi/core/_setting/_setting.py +59 -0
  4. lionagi/core/action/__init__.py +14 -0
  5. lionagi/core/action/function_calling.py +136 -0
  6. lionagi/core/action/manual.py +1 -0
  7. lionagi/core/action/node.py +109 -0
  8. lionagi/core/action/tool.py +114 -0
  9. lionagi/core/action/tool_manager.py +356 -0
  10. lionagi/core/agent/__init__.py +0 -3
  11. lionagi/core/agent/base_agent.py +45 -36
  12. lionagi/core/agent/eval/evaluator.py +1 -0
  13. lionagi/core/agent/eval/vote.py +40 -0
  14. lionagi/core/agent/learn/learner.py +59 -0
  15. lionagi/core/agent/plan/unit_template.py +1 -0
  16. lionagi/core/collections/__init__.py +17 -0
  17. lionagi/core/collections/_logger.py +319 -0
  18. lionagi/core/collections/abc/__init__.py +53 -0
  19. lionagi/core/collections/abc/component.py +615 -0
  20. lionagi/core/collections/abc/concepts.py +297 -0
  21. lionagi/core/collections/abc/exceptions.py +150 -0
  22. lionagi/core/collections/abc/util.py +45 -0
  23. lionagi/core/collections/exchange.py +161 -0
  24. lionagi/core/collections/flow.py +426 -0
  25. lionagi/core/collections/model.py +419 -0
  26. lionagi/core/collections/pile.py +913 -0
  27. lionagi/core/collections/progression.py +236 -0
  28. lionagi/core/collections/util.py +64 -0
  29. lionagi/core/director/direct.py +314 -0
  30. lionagi/core/director/director.py +2 -0
  31. lionagi/core/engine/branch_engine.py +333 -0
  32. lionagi/core/engine/instruction_map_engine.py +204 -0
  33. lionagi/core/engine/sandbox_.py +14 -0
  34. lionagi/core/engine/script_engine.py +99 -0
  35. lionagi/core/executor/base_executor.py +90 -0
  36. lionagi/core/executor/graph_executor.py +330 -0
  37. lionagi/core/executor/neo4j_executor.py +384 -0
  38. lionagi/core/generic/__init__.py +7 -0
  39. lionagi/core/generic/edge.py +112 -0
  40. lionagi/core/generic/edge_condition.py +16 -0
  41. lionagi/core/generic/graph.py +236 -0
  42. lionagi/core/generic/hyperedge.py +1 -0
  43. lionagi/core/generic/node.py +220 -0
  44. lionagi/core/generic/tree.py +48 -0
  45. lionagi/core/generic/tree_node.py +79 -0
  46. lionagi/core/mail/__init__.py +7 -3
  47. lionagi/core/mail/mail.py +25 -0
  48. lionagi/core/mail/mail_manager.py +142 -58
  49. lionagi/core/mail/package.py +45 -0
  50. lionagi/core/mail/start_mail.py +36 -0
  51. lionagi/core/message/__init__.py +19 -0
  52. lionagi/core/message/action_request.py +133 -0
  53. lionagi/core/message/action_response.py +135 -0
  54. lionagi/core/message/assistant_response.py +95 -0
  55. lionagi/core/message/instruction.py +234 -0
  56. lionagi/core/message/message.py +101 -0
  57. lionagi/core/message/system.py +86 -0
  58. lionagi/core/message/util.py +283 -0
  59. lionagi/core/report/__init__.py +4 -0
  60. lionagi/core/report/base.py +217 -0
  61. lionagi/core/report/form.py +231 -0
  62. lionagi/core/report/report.py +166 -0
  63. lionagi/core/report/util.py +28 -0
  64. lionagi/core/rule/__init__.py +0 -0
  65. lionagi/core/rule/_default.py +16 -0
  66. lionagi/core/rule/action.py +99 -0
  67. lionagi/core/rule/base.py +238 -0
  68. lionagi/core/rule/boolean.py +56 -0
  69. lionagi/core/rule/choice.py +47 -0
  70. lionagi/core/rule/mapping.py +96 -0
  71. lionagi/core/rule/number.py +71 -0
  72. lionagi/core/rule/rulebook.py +109 -0
  73. lionagi/core/rule/string.py +52 -0
  74. lionagi/core/rule/util.py +35 -0
  75. lionagi/core/session/__init__.py +0 -3
  76. lionagi/core/session/branch.py +431 -0
  77. lionagi/core/session/directive_mixin.py +287 -0
  78. lionagi/core/session/session.py +230 -902
  79. lionagi/core/structure/__init__.py +1 -0
  80. lionagi/core/structure/chain.py +1 -0
  81. lionagi/core/structure/forest.py +1 -0
  82. lionagi/core/structure/graph.py +1 -0
  83. lionagi/core/structure/tree.py +1 -0
  84. lionagi/core/unit/__init__.py +5 -0
  85. lionagi/core/unit/parallel_unit.py +245 -0
  86. lionagi/core/unit/template/__init__.py +0 -0
  87. lionagi/core/unit/template/action.py +81 -0
  88. lionagi/core/unit/template/base.py +51 -0
  89. lionagi/core/unit/template/plan.py +84 -0
  90. lionagi/core/unit/template/predict.py +109 -0
  91. lionagi/core/unit/template/score.py +124 -0
  92. lionagi/core/unit/template/select.py +104 -0
  93. lionagi/core/unit/unit.py +362 -0
  94. lionagi/core/unit/unit_form.py +305 -0
  95. lionagi/core/unit/unit_mixin.py +1168 -0
  96. lionagi/core/unit/util.py +71 -0
  97. lionagi/core/validator/__init__.py +0 -0
  98. lionagi/core/validator/validator.py +364 -0
  99. lionagi/core/work/__init__.py +0 -0
  100. lionagi/core/work/work.py +76 -0
  101. lionagi/core/work/work_function.py +101 -0
  102. lionagi/core/work/work_queue.py +103 -0
  103. lionagi/core/work/worker.py +258 -0
  104. lionagi/core/work/worklog.py +120 -0
  105. lionagi/experimental/__init__.py +0 -0
  106. lionagi/experimental/compressor/__init__.py +0 -0
  107. lionagi/experimental/compressor/base.py +46 -0
  108. lionagi/experimental/compressor/llm_compressor.py +247 -0
  109. lionagi/experimental/compressor/llm_summarizer.py +61 -0
  110. lionagi/experimental/compressor/util.py +70 -0
  111. lionagi/experimental/directive/__init__.py +19 -0
  112. lionagi/experimental/directive/parser/__init__.py +0 -0
  113. lionagi/experimental/directive/parser/base_parser.py +282 -0
  114. lionagi/experimental/directive/template/__init__.py +0 -0
  115. lionagi/experimental/directive/template/base_template.py +79 -0
  116. lionagi/experimental/directive/template/schema.py +36 -0
  117. lionagi/experimental/directive/tokenizer.py +73 -0
  118. lionagi/experimental/evaluator/__init__.py +0 -0
  119. lionagi/experimental/evaluator/ast_evaluator.py +131 -0
  120. lionagi/experimental/evaluator/base_evaluator.py +218 -0
  121. lionagi/experimental/knowledge/__init__.py +0 -0
  122. lionagi/experimental/knowledge/base.py +10 -0
  123. lionagi/experimental/knowledge/graph.py +0 -0
  124. lionagi/experimental/memory/__init__.py +0 -0
  125. lionagi/experimental/strategies/__init__.py +0 -0
  126. lionagi/experimental/strategies/base.py +1 -0
  127. lionagi/integrations/bridge/autogen_/__init__.py +0 -0
  128. lionagi/integrations/bridge/autogen_/autogen_.py +124 -0
  129. lionagi/integrations/bridge/langchain_/documents.py +4 -0
  130. lionagi/integrations/bridge/llamaindex_/index.py +30 -0
  131. lionagi/integrations/bridge/llamaindex_/llama_index_bridge.py +6 -0
  132. lionagi/integrations/bridge/llamaindex_/llama_pack.py +227 -0
  133. lionagi/integrations/bridge/llamaindex_/node_parser.py +6 -9
  134. lionagi/integrations/bridge/pydantic_/pydantic_bridge.py +1 -0
  135. lionagi/integrations/bridge/transformers_/__init__.py +0 -0
  136. lionagi/integrations/bridge/transformers_/install_.py +36 -0
  137. lionagi/integrations/chunker/__init__.py +0 -0
  138. lionagi/integrations/chunker/chunk.py +312 -0
  139. lionagi/integrations/config/oai_configs.py +38 -7
  140. lionagi/integrations/config/ollama_configs.py +1 -1
  141. lionagi/integrations/config/openrouter_configs.py +14 -2
  142. lionagi/integrations/loader/__init__.py +0 -0
  143. lionagi/integrations/loader/load.py +253 -0
  144. lionagi/integrations/loader/load_util.py +195 -0
  145. lionagi/integrations/provider/_mapping.py +46 -0
  146. lionagi/integrations/provider/litellm.py +2 -1
  147. lionagi/integrations/provider/mlx_service.py +16 -9
  148. lionagi/integrations/provider/oai.py +91 -4
  149. lionagi/integrations/provider/ollama.py +7 -6
  150. lionagi/integrations/provider/openrouter.py +115 -8
  151. lionagi/integrations/provider/services.py +2 -2
  152. lionagi/integrations/provider/transformers.py +18 -22
  153. lionagi/integrations/storage/__init__.py +3 -0
  154. lionagi/integrations/storage/neo4j.py +665 -0
  155. lionagi/integrations/storage/storage_util.py +287 -0
  156. lionagi/integrations/storage/structure_excel.py +285 -0
  157. lionagi/integrations/storage/to_csv.py +63 -0
  158. lionagi/integrations/storage/to_excel.py +83 -0
  159. lionagi/libs/__init__.py +26 -1
  160. lionagi/libs/ln_api.py +78 -23
  161. lionagi/libs/ln_context.py +37 -0
  162. lionagi/libs/ln_convert.py +21 -9
  163. lionagi/libs/ln_func_call.py +69 -28
  164. lionagi/libs/ln_image.py +107 -0
  165. lionagi/libs/ln_knowledge_graph.py +405 -0
  166. lionagi/libs/ln_nested.py +26 -11
  167. lionagi/libs/ln_parse.py +110 -14
  168. lionagi/libs/ln_queue.py +117 -0
  169. lionagi/libs/ln_tokenize.py +164 -0
  170. lionagi/{core/prompt/field_validator.py → libs/ln_validate.py} +79 -14
  171. lionagi/libs/special_tokens.py +172 -0
  172. lionagi/libs/sys_util.py +107 -2
  173. lionagi/lions/__init__.py +0 -0
  174. lionagi/lions/coder/__init__.py +0 -0
  175. lionagi/lions/coder/add_feature.py +20 -0
  176. lionagi/lions/coder/base_prompts.py +22 -0
  177. lionagi/lions/coder/code_form.py +13 -0
  178. lionagi/lions/coder/coder.py +168 -0
  179. lionagi/lions/coder/util.py +96 -0
  180. lionagi/lions/researcher/__init__.py +0 -0
  181. lionagi/lions/researcher/data_source/__init__.py +0 -0
  182. lionagi/lions/researcher/data_source/finhub_.py +191 -0
  183. lionagi/lions/researcher/data_source/google_.py +199 -0
  184. lionagi/lions/researcher/data_source/wiki_.py +96 -0
  185. lionagi/lions/researcher/data_source/yfinance_.py +21 -0
  186. lionagi/tests/integrations/__init__.py +0 -0
  187. lionagi/tests/libs/__init__.py +0 -0
  188. lionagi/tests/libs/test_field_validators.py +353 -0
  189. lionagi/tests/{test_libs → libs}/test_func_call.py +23 -21
  190. lionagi/tests/{test_libs → libs}/test_nested.py +36 -21
  191. lionagi/tests/{test_libs → libs}/test_parse.py +1 -1
  192. lionagi/tests/libs/test_queue.py +67 -0
  193. lionagi/tests/test_core/collections/__init__.py +0 -0
  194. lionagi/tests/test_core/collections/test_component.py +206 -0
  195. lionagi/tests/test_core/collections/test_exchange.py +138 -0
  196. lionagi/tests/test_core/collections/test_flow.py +145 -0
  197. lionagi/tests/test_core/collections/test_pile.py +171 -0
  198. lionagi/tests/test_core/collections/test_progression.py +129 -0
  199. lionagi/tests/test_core/generic/__init__.py +0 -0
  200. lionagi/tests/test_core/generic/test_edge.py +67 -0
  201. lionagi/tests/test_core/generic/test_graph.py +96 -0
  202. lionagi/tests/test_core/generic/test_node.py +106 -0
  203. lionagi/tests/test_core/generic/test_tree_node.py +73 -0
  204. lionagi/tests/test_core/test_branch.py +115 -292
  205. lionagi/tests/test_core/test_form.py +46 -0
  206. lionagi/tests/test_core/test_report.py +105 -0
  207. lionagi/tests/test_core/test_validator.py +111 -0
  208. lionagi/version.py +1 -1
  209. {lionagi-0.0.312.dist-info → lionagi-0.2.1.dist-info}/LICENSE +12 -11
  210. {lionagi-0.0.312.dist-info → lionagi-0.2.1.dist-info}/METADATA +19 -118
  211. lionagi-0.2.1.dist-info/RECORD +240 -0
  212. lionagi/core/branch/__init__.py +0 -4
  213. lionagi/core/branch/base_branch.py +0 -654
  214. lionagi/core/branch/branch.py +0 -471
  215. lionagi/core/branch/branch_flow_mixin.py +0 -96
  216. lionagi/core/branch/executable_branch.py +0 -347
  217. lionagi/core/branch/util.py +0 -323
  218. lionagi/core/direct/__init__.py +0 -6
  219. lionagi/core/direct/predict.py +0 -161
  220. lionagi/core/direct/score.py +0 -278
  221. lionagi/core/direct/select.py +0 -169
  222. lionagi/core/direct/utils.py +0 -87
  223. lionagi/core/direct/vote.py +0 -64
  224. lionagi/core/flow/base/baseflow.py +0 -23
  225. lionagi/core/flow/monoflow/ReAct.py +0 -238
  226. lionagi/core/flow/monoflow/__init__.py +0 -9
  227. lionagi/core/flow/monoflow/chat.py +0 -95
  228. lionagi/core/flow/monoflow/chat_mixin.py +0 -263
  229. lionagi/core/flow/monoflow/followup.py +0 -214
  230. lionagi/core/flow/polyflow/__init__.py +0 -1
  231. lionagi/core/flow/polyflow/chat.py +0 -248
  232. lionagi/core/mail/schema.py +0 -56
  233. lionagi/core/messages/__init__.py +0 -3
  234. lionagi/core/messages/schema.py +0 -533
  235. lionagi/core/prompt/prompt_template.py +0 -316
  236. lionagi/core/schema/__init__.py +0 -22
  237. lionagi/core/schema/action_node.py +0 -29
  238. lionagi/core/schema/base_mixin.py +0 -296
  239. lionagi/core/schema/base_node.py +0 -199
  240. lionagi/core/schema/condition.py +0 -24
  241. lionagi/core/schema/data_logger.py +0 -354
  242. lionagi/core/schema/data_node.py +0 -93
  243. lionagi/core/schema/prompt_template.py +0 -67
  244. lionagi/core/schema/structure.py +0 -910
  245. lionagi/core/tool/__init__.py +0 -3
  246. lionagi/core/tool/tool_manager.py +0 -280
  247. lionagi/integrations/bridge/pydantic_/base_model.py +0 -7
  248. lionagi/tests/test_core/test_base_branch.py +0 -427
  249. lionagi/tests/test_core/test_chat_flow.py +0 -63
  250. lionagi/tests/test_core/test_mail_manager.py +0 -75
  251. lionagi/tests/test_core/test_prompts.py +0 -51
  252. lionagi/tests/test_core/test_session.py +0 -254
  253. lionagi/tests/test_core/test_session_base_util.py +0 -312
  254. lionagi/tests/test_core/test_tool_manager.py +0 -95
  255. lionagi-0.0.312.dist-info/RECORD +0 -111
  256. /lionagi/core/{branch/base → _setting}/__init__.py +0 -0
  257. /lionagi/core/{flow → agent/eval}/__init__.py +0 -0
  258. /lionagi/core/{flow/base → agent/learn}/__init__.py +0 -0
  259. /lionagi/core/{prompt → agent/plan}/__init__.py +0 -0
  260. /lionagi/core/{tool/manual.py → agent/plan/plan.py} +0 -0
  261. /lionagi/{tests/test_integrations → core/director}/__init__.py +0 -0
  262. /lionagi/{tests/test_libs → core/engine}/__init__.py +0 -0
  263. /lionagi/{tests/test_libs/test_async.py → core/executor/__init__.py} +0 -0
  264. /lionagi/tests/{test_libs → libs}/test_api.py +0 -0
  265. /lionagi/tests/{test_libs → libs}/test_convert.py +0 -0
  266. /lionagi/tests/{test_libs → libs}/test_sys_util.py +0 -0
  267. {lionagi-0.0.312.dist-info → lionagi-0.2.1.dist-info}/WHEEL +0 -0
  268. {lionagi-0.0.312.dist-info → lionagi-0.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,107 @@
1
+ import base64
2
+ import numpy as np
3
+ from typing import Optional
4
+ from .sys_util import SysUtil
5
+
6
+
7
class ImageUtil:
    """Helpers for image preprocessing, base64 encoding, and OpenAI vision token accounting."""

    @staticmethod
    def preprocess_image(
        image: np.ndarray, color_conversion_code: Optional[int] = None
    ) -> np.ndarray:
        """Convert an image's color space (default BGR -> RGB).

        Args:
            image: Source image array (e.g. as produced by cv2.imread).
            color_conversion_code: A cv2.COLOR_* constant; defaults to
                cv2.COLOR_BGR2RGB.

        Returns:
            The color-converted image array.
        """
        SysUtil.check_import("cv2", pip_name="opencv-python")
        import cv2

        color_conversion_code = color_conversion_code or cv2.COLOR_BGR2RGB
        return cv2.cvtColor(image, color_conversion_code)

    @staticmethod
    def encode_image_to_base64(image: np.ndarray, file_extension: str = ".jpg") -> str:
        """Encode an image array to a base64 string in the given format.

        Args:
            image: Image array to encode.
            file_extension: Target container format (e.g. ".jpg", ".png").

        Returns:
            The base64-encoded image as a UTF-8 string.

        Raises:
            ValueError: If cv2 cannot encode the image to ``file_extension``.
        """
        SysUtil.check_import("cv2", pip_name="opencv-python")
        import cv2

        success, buffer = cv2.imencode(file_extension, image)
        if not success:
            raise ValueError(f"Could not encode image to {file_extension} format.")
        return base64.b64encode(buffer).decode("utf-8")

    @staticmethod
    def read_image_to_array(
        image_path: str, color_flag: Optional[int] = None
    ) -> np.ndarray:
        """Read an image file into an array.

        Args:
            image_path: Path to the image file.
            color_flag: A cv2.IMREAD_* flag; defaults to cv2.IMREAD_COLOR.

        Returns:
            The decoded image array.

        Raises:
            ValueError: If the file cannot be read.
        """
        SysUtil.check_import("cv2", pip_name="opencv-python")
        import cv2

        # BUG FIX: the default must be resolved *before* cv2.imread is called
        # (previously a None flag was passed through and the default assigned
        # afterwards, too late to matter). An identity check is used because
        # `or` would clobber cv2.IMREAD_GRAYSCALE, which is 0.
        if color_flag is None:
            color_flag = cv2.IMREAD_COLOR
        image = cv2.imread(image_path, color_flag)
        if image is None:
            raise ValueError(f"Could not read image from path: {image_path}")
        return image

    @staticmethod
    def read_image_to_base64(
        image_path: str,
        color_flag: Optional[int] = None,
    ) -> str:
        """Read an image file and return it as a base64 string.

        The output container format follows the input file's extension.

        Args:
            image_path: Path to the image file.
            color_flag: A cv2.IMREAD_* flag forwarded to read_image_to_array.

        Returns:
            The base64-encoded image contents.
        """
        image_path = str(image_path)
        image = ImageUtil.read_image_to_array(image_path, color_flag)

        file_extension = "." + image_path.split(".")[-1]
        return ImageUtil.encode_image_to_base64(image, file_extension)

    @staticmethod
    def calculate_image_token_usage_from_base64(image_base64: str, detail) -> int:
        """
        Calculate the token usage for processing OpenAI images from a base64-encoded string.

        Parameters:
            image_base64 (str): The base64-encoded string of the image.
            detail (str): The detail level of the image, either 'low' or 'high'.

        Returns:
            int: The total token cost for processing the image.
        """
        # 'low' detail has a fixed 85-token cost regardless of image size, so
        # skip decoding the payload entirely.
        if detail == "low":
            return 85

        import math
        from io import BytesIO
        from PIL import Image

        # Strip a data-URI prefix and any stray braces before decoding.
        if "data:image/jpeg;base64," in image_base64:
            image_base64 = image_base64.split("data:image/jpeg;base64,")[1]
        # BUG FIX: str.strip returns a new string; the result was previously
        # discarded, leaving braces in the payload.
        image_base64 = image_base64.strip("{}")

        image_data = base64.b64decode(image_base64)
        image = Image.open(BytesIO(image_data))

        width, height = image.size

        # Scale to fit within a 2048 x 2048 square.
        max_dimension = 2048
        if width > max_dimension or height > max_dimension:
            scale_factor = max_dimension / max(width, height)
            width = int(width * scale_factor)
            height = int(height * scale_factor)

        # Scale such that the shortest side is 768px.
        min_side = 768
        if min(width, height) > min_side:
            scale_factor = min_side / min(width, height)
            width = int(width * scale_factor)
            height = int(height * scale_factor)

        # BUG FIX: OpenAI bills one 512px tile per *partial* tile covered
        # (ceil, not floor), plus a fixed 85-token base cost.
        num_squares = math.ceil(width / 512) * math.ceil(height / 512)
        token_cost = 170 * num_squares + 85

        return token_cost
@@ -0,0 +1,405 @@
1
+ import math
2
+ from lionagi.libs import CallDecorator as cd
3
+
4
+
5
class KnowledgeBase:
    """
    A Knowledge Base (KB) containing entities, relations, and sources.

    Attributes:
        entities (dict): Maps entity titles to entity information
            (excluding the title).
        relations (list): Relations, each a dict with head/type/tail plus
            metadata (article_url and spans).
        sources (dict): Maps article URLs to source data (article_title and
            article_publish_date).
    """

    def __init__(self):
        """Initialize an empty Knowledge Base."""
        self.entities = {}  # { entity_title: {...} }
        self.relations = []  # [ head: entity_title, type: ..., tail: entity_title,
        #   meta: { article_url: { spans: [...] } } ]
        self.sources = {}  # { article_url: {...} }

    def merge_with_kb(self, kb2):
        """
        Merge another Knowledge Base (kb2) into this KB.

        Args:
            kb2 (KnowledgeBase): The KB whose relations and sources are merged in.
        """
        for r in kb2.relations:
            article_url = list(r["meta"].keys())[0]
            source_data = kb2.sources[article_url]
            self.add_relation(
                r, source_data["article_title"], source_data["article_publish_date"]
            )

    def are_relations_equal(self, r1, r2):
        """
        Check if two relations are equal on head, type, and tail.

        Args:
            r1 (dict): The first relation to compare.
            r2 (dict): The second relation to compare.

        Returns:
            bool: True if the relations are equal, False otherwise.
        """
        return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])

    def exists_relation(self, r1):
        """
        Check if a relation already exists in the KB.

        Args:
            r1 (dict): The relation to check for existence in the KB.

        Returns:
            bool: True if the relation exists in the KB, False otherwise.
        """
        return any(self.are_relations_equal(r1, r2) for r2 in self.relations)

    def merge_relations(self, r2):
        """
        Fold the metadata of r2 into the matching existing relation.

        Callers must ensure the relation exists (see exists_relation).

        Args:
            r2 (dict): The relation to merge into an existing relation in the KB.
        """
        r1 = [r for r in self.relations if self.are_relations_equal(r2, r)][0]

        # if different article: adopt its metadata wholesale
        article_url = list(r2["meta"].keys())[0]
        if article_url not in r1["meta"]:
            r1["meta"][article_url] = r2["meta"][article_url]

        # if existing article: append only spans not already recorded
        else:
            spans_to_add = [
                span
                for span in r2["meta"][article_url]["spans"]
                if span not in r1["meta"][article_url]["spans"]
            ]
            r1["meta"][article_url]["spans"] += spans_to_add

    # NOTE(review): caching an instance method keys the cache on `self`, which
    # keeps KB instances alive for the cache's lifetime — confirm cd.cache
    # handles this as intended.
    @cd.cache(maxsize=10000)
    def get_wikipedia_data(self, candidate_entity):
        """
        Look up a candidate entity on Wikipedia.

        Args:
            candidate_entity (str): The candidate entity title.

        Returns:
            dict | None: {"title", "url", "summary"} for the entity, or None
            if the page lookup fails.

        Raises:
            Exception: If the `wikipedia` package is not installed/importable.
        """
        try:
            from lionagi.libs import SysUtil

            SysUtil.check_import("wikipedia")
            import wikipedia  # type: ignore
        except Exception as e:
            # BUG FIX: the message was missing its f-prefix, so "{e}" was
            # emitted literally; also chain the original cause.
            raise Exception(f"wikipedia package is not installed {e}") from e

        try:
            page = wikipedia.page(candidate_entity, auto_suggest=False)
            entity_data = {
                "title": page.title,
                "url": page.url,
                "summary": page.summary,
            }
            return entity_data
        except Exception:
            # BUG FIX: narrowed from a bare `except:` (which also swallowed
            # SystemExit/KeyboardInterrupt); a failed lookup yields None.
            return None

    def add_entity(self, e):
        """
        Add (or overwrite) an entity keyed by its title.

        Args:
            e (dict): Entity information including a "title" key.
        """
        self.entities[e["title"]] = {k: v for k, v in e.items() if k != "title"}

    def add_relation(self, r, article_title, article_publish_date):
        """
        Add a relation to the KB, resolving its endpoints against Wikipedia.

        Args:
            r (dict): Relation with head/type/tail and meta ({article_url: {spans}}).
            article_title (str): The title of the article containing the relation.
            article_publish_date (str): The publish date of the article.
        """
        # check on wikipedia
        candidate_entities = [r["head"], r["tail"]]
        entities = [self.get_wikipedia_data(ent) for ent in candidate_entities]

        # if one entity does not exist, stop
        if any(ent is None for ent in entities):
            return

        # manage new entities
        for e in entities:
            self.add_entity(e)

        # rename relation entities with their wikipedia titles
        r["head"] = entities[0]["title"]
        r["tail"] = entities[1]["title"]

        # add source if not in kb
        article_url = list(r["meta"].keys())[0]
        if article_url not in self.sources:
            self.sources[article_url] = {
                "article_title": article_title,
                "article_publish_date": article_publish_date,
            }

        # manage new relation
        if not self.exists_relation(r):
            self.relations.append(r)
        else:
            self.merge_relations(r)

    def print(self):
        """
        Print the entities, relations, and sources in the KB.

        Returns:
            None
        """
        print("Entities:")
        for e in self.entities.items():
            print(f"  {e}")
        print("Relations:")
        for r in self.relations:
            print(f"  {r}")
        print("Sources:")
        for s in self.sources.items():
            print(f"  {s}")

    @staticmethod
    def extract_relations_from_model_output(text):
        """
        Parse REBEL-style model output into relation dicts.

        The model emits "<triplet> head <subj> tail <obj> type" sequences;
        the special tokens <s>, <pad>, </s> are stripped first.

        Args:
            text (str): The model output text containing relations.

        Returns:
            list: A list of dictionaries, each a relation (head, type, tail).
        """
        relations = []
        # CLEANUP: the original tuple assignment bound `relation` twice
        # ("relation, subject, relation, object_"); net effect was identical.
        subject, relation, object_ = "", "", ""
        text = text.strip()
        current = "x"
        text_replaced = text.replace("<s>", "").replace("<pad>", "").replace("</s>", "")
        for token in text_replaced.split():
            if token == "<triplet>":
                current = "t"
                if relation != "":
                    relations.append(
                        {
                            "head": subject.strip(),
                            "type": relation.strip(),
                            "tail": object_.strip(),
                        }
                    )
                    relation = ""
                subject = ""
            elif token == "<subj>":
                current = "s"
                if relation != "":
                    relations.append(
                        {
                            "head": subject.strip(),
                            "type": relation.strip(),
                            "tail": object_.strip(),
                        }
                    )
                object_ = ""
            elif token == "<obj>":
                current = "o"
                relation = ""
            else:
                # Accumulate free tokens into whichever slot is active:
                # after <triplet> -> head, after <subj> -> tail, after <obj> -> type.
                if current == "t":
                    subject += " " + token
                elif current == "s":
                    object_ += " " + token
                elif current == "o":
                    relation += " " + token
        if subject != "" and relation != "" and object_ != "":
            relations.append(
                {
                    "head": subject.strip(),
                    "type": relation.strip(),
                    "tail": object_.strip(),
                }
            )
        return relations
256
+
257
+
258
class KGTripletExtractor:
    """
    A class to perform knowledge graph triplet extraction from text using a
    pre-trained model (Babelscape/rebel-large by default).

    Methods:
        text_to_wiki_kb(text, model=None, tokenizer=None, device='cpu', span_length=512,
            article_title=None, article_publish_date=None, verbose=False):
            Extract knowledge graph triplets from text and create a KnowledgeBase (KB)
            containing entities and relations.
    """

    @staticmethod
    def text_to_wiki_kb(
        text,
        model=None,
        tokenizer=None,
        device="cpu",
        span_length=512,
        article_title=None,
        article_publish_date=None,
        verbose=False,
    ):
        """
        Extract knowledge graph triplets from text and create a KnowledgeBase (KB)
        containing entities and relations.

        (This docstring was previously a dead string literal in the middle of
        the function body; it is now attached to the function.)

        Args:
            text (str): The input text from which triplets will be extracted.
            model (AutoModelForSeq2SeqLM, optional): The pre-trained model for triplet
                extraction. Defaults to Babelscape/rebel-large.
            tokenizer (AutoTokenizer, optional): The tokenizer for the model.
                Defaults to Babelscape/rebel-large's tokenizer.
            device (str, optional): The device to run the model on (e.g., 'cpu', 'cuda').
                Defaults to 'cpu'.
            span_length (int, optional): The maximum span length for input text
                segmentation. Defaults to 512.
            article_title (str, optional): The title of the article containing the
                input text. Defaults to None.
            article_publish_date (str, optional): The publish date of the article.
                Defaults to None.
            verbose (bool, optional): Whether to enable verbose mode for debugging.
                Defaults to False.

        Returns:
            KnowledgeBase: A KnowledgeBase (KB) containing extracted entities,
            relations, and sources.
        """
        from lionagi.integrations.bridge.transformers_.install_ import (
            install_transformers,
        )

        try:
            from transformers import AutoModelForSeq2SeqLM, AutoTokenizer  # type: ignore
        except ImportError:
            install_transformers()
            from transformers import AutoModelForSeq2SeqLM, AutoTokenizer  # type: ignore
        import torch  # type: ignore

        # BUG FIX: defaults were previously loaded only when *both* model and
        # tokenizer were None, so supplying exactly one of them left the other
        # as None and crashed below. Each is now defaulted independently; the
        # model is moved to `device` only when we loaded it ourselves,
        # matching the original behavior for caller-supplied models.
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
        if model is None:
            model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")
            model.to(device)

        inputs = tokenizer([text], return_tensors="pt")

        # Split the tokenized input into (possibly overlapping) spans that fit
        # the model's context window.
        num_tokens = len(inputs["input_ids"][0])
        if verbose:
            print(f"Input has {num_tokens} tokens")
        num_spans = math.ceil(num_tokens / span_length)
        if verbose:
            print(f"Input has {num_spans} spans")
        overlap = math.ceil(
            (num_spans * span_length - num_tokens) / max(num_spans - 1, 1)
        )
        spans_boundaries = []
        start = 0
        for i in range(num_spans):
            spans_boundaries.append(
                [start + span_length * i, start + span_length * (i + 1)]
            )
            start -= overlap
        if verbose:
            print(f"Span boundaries are {spans_boundaries}")

        # transform input with spans
        tensor_ids = [
            inputs["input_ids"][0][boundary[0] : boundary[1]]
            for boundary in spans_boundaries
        ]
        tensor_masks = [
            inputs["attention_mask"][0][boundary[0] : boundary[1]]
            for boundary in spans_boundaries
        ]

        inputs = {
            "input_ids": torch.stack(tensor_ids).to(device),
            "attention_mask": torch.stack(tensor_masks).to(device),
        }

        # generate several candidate relation sequences per span
        num_return_sequences = 3
        gen_kwargs = {
            "max_length": 512,
            "length_penalty": 0,
            "num_beams": 3,
            "num_return_sequences": num_return_sequences,
        }
        generated_tokens = model.generate(
            **inputs,
            **gen_kwargs,
        )

        # decode relations; special tokens are kept because the parser in
        # KnowledgeBase.extract_relations_from_model_output relies on them
        decoded_preds = tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=False
        )

        # create kb: parse each decoded sequence and tag every relation with
        # the span it came from
        kb = KnowledgeBase()
        i = 0
        for sentence_pred in decoded_preds:
            current_span_index = i // num_return_sequences
            relations = KnowledgeBase.extract_relations_from_model_output(sentence_pred)
            for relation in relations:
                relation["meta"] = {
                    "article_url": {"spans": [spans_boundaries[current_span_index]]}
                }
                kb.add_relation(relation, article_title, article_publish_date)
            i += 1
        return kb
381
+
382
+
383
class KGraph:
    """
    Facade for building a Knowledge Graph from raw text.

    Exposes a single entry point, ``text_to_wiki_kb``, which returns a
    KnowledgeBase populated with the entities and relations found in the
    input text.
    """

    @staticmethod
    def text_to_wiki_kb(text, **kwargs):
        """
        Build a Knowledge Base (KB) from *text*.

        Args:
            text (str): The input text to mine for relations.
            **kwargs: Forwarded verbatim to ``KGTripletExtractor.text_to_wiki_kb``.

        Returns:
            KnowledgeBase: Entities and relations extracted from the input text.
        """
        kb = KGTripletExtractor.text_to_wiki_kb(text, **kwargs)
        return kb
lionagi/libs/ln_nested.py CHANGED
@@ -1,3 +1,19 @@
1
+ """
2
+ Copyright 2024 HaiyangLi
3
+
4
+ Licensed under the Apache License, Version 2.0 (the "License");
5
+ you may not use this file except in compliance with the License.
6
+ You may obtain a copy of the License at
7
+
8
+ http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ Unless required by applicable law or agreed to in writing, software
11
+ distributed under the License is distributed on an "AS IS" BASIS,
12
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ See the License for the specific language governing permissions and
14
+ limitations under the License.
15
+ """
16
+
1
17
  from collections import defaultdict
2
18
  from itertools import chain
3
19
  from typing import Any, Generator, Callable
@@ -52,7 +68,7 @@ def nset(nested_structure: dict | list, indices: list[int | str], value: Any) ->
52
68
  def nget(
53
69
  nested_structure: dict | list,
54
70
  indices: list[int | str],
55
- default: Any | None = None,
71
+ default=...,
56
72
  ) -> Any:
57
73
  """
58
74
  retrieves a value from a nested list or dictionary structure, with an option to
@@ -98,12 +114,12 @@ def nget(
98
114
  return target_container[last_index]
99
115
  elif isinstance(target_container, dict) and last_index in target_container:
100
116
  return target_container[last_index]
101
- elif default is not None:
117
+ elif default is not ...:
102
118
  return default
103
119
  else:
104
120
  raise LookupError("Target not found and no default value provided.")
105
121
  except (IndexError, KeyError, TypeError):
106
- if default is not None:
122
+ if default is not ...:
107
123
  return default
108
124
  else:
109
125
  raise LookupError("Target not found and no default value provided.")
@@ -116,7 +132,7 @@ def nmerge(
116
132
  *,
117
133
  overwrite: bool = False,
118
134
  dict_sequence: bool = False,
119
- sequence_separator: str = "_",
135
+ sequence_separator: str = "[^_^]",
120
136
  sort_list: bool = False,
121
137
  custom_sort: Callable[[Any], Any] | None = None,
122
138
  ) -> dict | list:
@@ -176,7 +192,7 @@ def flatten(
176
192
  /,
177
193
  *,
178
194
  parent_key: str = "",
179
- sep: str = "_",
195
+ sep: str = "[^_^]",
180
196
  max_depth: int | None = None,
181
197
  inplace: bool = False,
182
198
  dict_only: bool = False,
@@ -238,7 +254,7 @@ def unflatten(
238
254
  flat_dict: dict[str, Any],
239
255
  /,
240
256
  *,
241
- sep: str = "_",
257
+ sep: str = "[^_^]",
242
258
  custom_logic: Callable[[str], Any] | None = None,
243
259
  max_depth: int | None = None,
244
260
  ) -> dict | list:
@@ -330,7 +346,7 @@ def ninsert(
330
346
  indices: list[str | int],
331
347
  value: Any,
332
348
  *,
333
- sep: str = "_",
349
+ sep: str = "[^_^]",
334
350
  max_depth: int | None = None,
335
351
  current_depth: int = 0,
336
352
  ) -> None:
@@ -393,12 +409,11 @@ def ninsert(
393
409
  nested_structure[last_part] = value
394
410
 
395
411
 
396
- # noinspection PyDecorator
397
412
  def get_flattened_keys(
398
413
  nested_structure: Any,
399
414
  /,
400
415
  *,
401
- sep: str = "_",
416
+ sep: str = "[^_^]",
402
417
  max_depth: int | None = None,
403
418
  dict_only: bool = False,
404
419
  inplace: bool = False,
@@ -448,7 +463,7 @@ def _dynamic_flatten_in_place(
448
463
  /,
449
464
  *,
450
465
  parent_key: str = "",
451
- sep: str = "_",
466
+ sep: str = "[^_^]",
452
467
  max_depth: int | None = None,
453
468
  current_depth: int = 0,
454
469
  dict_only: bool = False,
@@ -581,7 +596,7 @@ def _deep_update(original: dict, update: dict) -> dict:
581
596
  def _dynamic_flatten_generator(
582
597
  nested_structure: Any,
583
598
  parent_key: tuple[str, ...],
584
- sep: str = "_",
599
+ sep: str = "[^_^]",
585
600
  max_depth: int | None = None,
586
601
  current_depth: int = 0,
587
602
  dict_only: bool = False,