lionagi 0.0.312__py3-none-any.whl → 0.2.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lionagi/__init__.py +61 -3
- lionagi/core/__init__.py +0 -14
- lionagi/core/_setting/_setting.py +59 -0
- lionagi/core/action/__init__.py +14 -0
- lionagi/core/action/function_calling.py +136 -0
- lionagi/core/action/manual.py +1 -0
- lionagi/core/action/node.py +109 -0
- lionagi/core/action/tool.py +114 -0
- lionagi/core/action/tool_manager.py +356 -0
- lionagi/core/agent/__init__.py +0 -3
- lionagi/core/agent/base_agent.py +45 -36
- lionagi/core/agent/eval/evaluator.py +1 -0
- lionagi/core/agent/eval/vote.py +40 -0
- lionagi/core/agent/learn/learner.py +59 -0
- lionagi/core/agent/plan/unit_template.py +1 -0
- lionagi/core/collections/__init__.py +17 -0
- lionagi/core/collections/_logger.py +319 -0
- lionagi/core/collections/abc/__init__.py +53 -0
- lionagi/core/collections/abc/component.py +615 -0
- lionagi/core/collections/abc/concepts.py +297 -0
- lionagi/core/collections/abc/exceptions.py +150 -0
- lionagi/core/collections/abc/util.py +45 -0
- lionagi/core/collections/exchange.py +161 -0
- lionagi/core/collections/flow.py +426 -0
- lionagi/core/collections/model.py +419 -0
- lionagi/core/collections/pile.py +913 -0
- lionagi/core/collections/progression.py +236 -0
- lionagi/core/collections/util.py +64 -0
- lionagi/core/director/direct.py +314 -0
- lionagi/core/director/director.py +2 -0
- lionagi/core/engine/branch_engine.py +333 -0
- lionagi/core/engine/instruction_map_engine.py +204 -0
- lionagi/core/engine/sandbox_.py +14 -0
- lionagi/core/engine/script_engine.py +99 -0
- lionagi/core/executor/base_executor.py +90 -0
- lionagi/core/executor/graph_executor.py +330 -0
- lionagi/core/executor/neo4j_executor.py +384 -0
- lionagi/core/generic/__init__.py +7 -0
- lionagi/core/generic/edge.py +112 -0
- lionagi/core/generic/edge_condition.py +16 -0
- lionagi/core/generic/graph.py +236 -0
- lionagi/core/generic/hyperedge.py +1 -0
- lionagi/core/generic/node.py +220 -0
- lionagi/core/generic/tree.py +48 -0
- lionagi/core/generic/tree_node.py +79 -0
- lionagi/core/mail/__init__.py +7 -3
- lionagi/core/mail/mail.py +25 -0
- lionagi/core/mail/mail_manager.py +142 -58
- lionagi/core/mail/package.py +45 -0
- lionagi/core/mail/start_mail.py +36 -0
- lionagi/core/message/__init__.py +19 -0
- lionagi/core/message/action_request.py +133 -0
- lionagi/core/message/action_response.py +135 -0
- lionagi/core/message/assistant_response.py +95 -0
- lionagi/core/message/instruction.py +234 -0
- lionagi/core/message/message.py +101 -0
- lionagi/core/message/system.py +86 -0
- lionagi/core/message/util.py +283 -0
- lionagi/core/report/__init__.py +4 -0
- lionagi/core/report/base.py +217 -0
- lionagi/core/report/form.py +231 -0
- lionagi/core/report/report.py +166 -0
- lionagi/core/report/util.py +28 -0
- lionagi/core/rule/__init__.py +0 -0
- lionagi/core/rule/_default.py +16 -0
- lionagi/core/rule/action.py +99 -0
- lionagi/core/rule/base.py +238 -0
- lionagi/core/rule/boolean.py +56 -0
- lionagi/core/rule/choice.py +47 -0
- lionagi/core/rule/mapping.py +96 -0
- lionagi/core/rule/number.py +71 -0
- lionagi/core/rule/rulebook.py +109 -0
- lionagi/core/rule/string.py +52 -0
- lionagi/core/rule/util.py +35 -0
- lionagi/core/session/__init__.py +0 -3
- lionagi/core/session/branch.py +431 -0
- lionagi/core/session/directive_mixin.py +287 -0
- lionagi/core/session/session.py +230 -902
- lionagi/core/structure/__init__.py +1 -0
- lionagi/core/structure/chain.py +1 -0
- lionagi/core/structure/forest.py +1 -0
- lionagi/core/structure/graph.py +1 -0
- lionagi/core/structure/tree.py +1 -0
- lionagi/core/unit/__init__.py +5 -0
- lionagi/core/unit/parallel_unit.py +245 -0
- lionagi/core/unit/template/__init__.py +0 -0
- lionagi/core/unit/template/action.py +81 -0
- lionagi/core/unit/template/base.py +51 -0
- lionagi/core/unit/template/plan.py +84 -0
- lionagi/core/unit/template/predict.py +109 -0
- lionagi/core/unit/template/score.py +124 -0
- lionagi/core/unit/template/select.py +104 -0
- lionagi/core/unit/unit.py +362 -0
- lionagi/core/unit/unit_form.py +305 -0
- lionagi/core/unit/unit_mixin.py +1168 -0
- lionagi/core/unit/util.py +71 -0
- lionagi/core/validator/__init__.py +0 -0
- lionagi/core/validator/validator.py +364 -0
- lionagi/core/work/__init__.py +0 -0
- lionagi/core/work/work.py +76 -0
- lionagi/core/work/work_function.py +101 -0
- lionagi/core/work/work_queue.py +103 -0
- lionagi/core/work/worker.py +258 -0
- lionagi/core/work/worklog.py +120 -0
- lionagi/experimental/__init__.py +0 -0
- lionagi/experimental/compressor/__init__.py +0 -0
- lionagi/experimental/compressor/base.py +46 -0
- lionagi/experimental/compressor/llm_compressor.py +247 -0
- lionagi/experimental/compressor/llm_summarizer.py +61 -0
- lionagi/experimental/compressor/util.py +70 -0
- lionagi/experimental/directive/__init__.py +19 -0
- lionagi/experimental/directive/parser/__init__.py +0 -0
- lionagi/experimental/directive/parser/base_parser.py +282 -0
- lionagi/experimental/directive/template/__init__.py +0 -0
- lionagi/experimental/directive/template/base_template.py +79 -0
- lionagi/experimental/directive/template/schema.py +36 -0
- lionagi/experimental/directive/tokenizer.py +73 -0
- lionagi/experimental/evaluator/__init__.py +0 -0
- lionagi/experimental/evaluator/ast_evaluator.py +131 -0
- lionagi/experimental/evaluator/base_evaluator.py +218 -0
- lionagi/experimental/knowledge/__init__.py +0 -0
- lionagi/experimental/knowledge/base.py +10 -0
- lionagi/experimental/knowledge/graph.py +0 -0
- lionagi/experimental/memory/__init__.py +0 -0
- lionagi/experimental/strategies/__init__.py +0 -0
- lionagi/experimental/strategies/base.py +1 -0
- lionagi/integrations/bridge/autogen_/__init__.py +0 -0
- lionagi/integrations/bridge/autogen_/autogen_.py +124 -0
- lionagi/integrations/bridge/langchain_/documents.py +4 -0
- lionagi/integrations/bridge/llamaindex_/index.py +30 -0
- lionagi/integrations/bridge/llamaindex_/llama_index_bridge.py +6 -0
- lionagi/integrations/bridge/llamaindex_/llama_pack.py +227 -0
- lionagi/integrations/bridge/llamaindex_/node_parser.py +6 -9
- lionagi/integrations/bridge/pydantic_/pydantic_bridge.py +1 -0
- lionagi/integrations/bridge/transformers_/__init__.py +0 -0
- lionagi/integrations/bridge/transformers_/install_.py +36 -0
- lionagi/integrations/chunker/__init__.py +0 -0
- lionagi/integrations/chunker/chunk.py +312 -0
- lionagi/integrations/config/oai_configs.py +38 -7
- lionagi/integrations/config/ollama_configs.py +1 -1
- lionagi/integrations/config/openrouter_configs.py +14 -2
- lionagi/integrations/loader/__init__.py +0 -0
- lionagi/integrations/loader/load.py +253 -0
- lionagi/integrations/loader/load_util.py +195 -0
- lionagi/integrations/provider/_mapping.py +46 -0
- lionagi/integrations/provider/litellm.py +2 -1
- lionagi/integrations/provider/mlx_service.py +16 -9
- lionagi/integrations/provider/oai.py +91 -4
- lionagi/integrations/provider/ollama.py +7 -6
- lionagi/integrations/provider/openrouter.py +115 -8
- lionagi/integrations/provider/services.py +2 -2
- lionagi/integrations/provider/transformers.py +18 -22
- lionagi/integrations/storage/__init__.py +3 -0
- lionagi/integrations/storage/neo4j.py +665 -0
- lionagi/integrations/storage/storage_util.py +287 -0
- lionagi/integrations/storage/structure_excel.py +285 -0
- lionagi/integrations/storage/to_csv.py +63 -0
- lionagi/integrations/storage/to_excel.py +83 -0
- lionagi/libs/__init__.py +26 -1
- lionagi/libs/ln_api.py +78 -23
- lionagi/libs/ln_context.py +37 -0
- lionagi/libs/ln_convert.py +21 -9
- lionagi/libs/ln_func_call.py +69 -28
- lionagi/libs/ln_image.py +107 -0
- lionagi/libs/ln_knowledge_graph.py +405 -0
- lionagi/libs/ln_nested.py +26 -11
- lionagi/libs/ln_parse.py +110 -14
- lionagi/libs/ln_queue.py +117 -0
- lionagi/libs/ln_tokenize.py +164 -0
- lionagi/{core/prompt/field_validator.py → libs/ln_validate.py} +79 -14
- lionagi/libs/special_tokens.py +172 -0
- lionagi/libs/sys_util.py +107 -2
- lionagi/lions/__init__.py +0 -0
- lionagi/lions/coder/__init__.py +0 -0
- lionagi/lions/coder/add_feature.py +20 -0
- lionagi/lions/coder/base_prompts.py +22 -0
- lionagi/lions/coder/code_form.py +13 -0
- lionagi/lions/coder/coder.py +168 -0
- lionagi/lions/coder/util.py +96 -0
- lionagi/lions/researcher/__init__.py +0 -0
- lionagi/lions/researcher/data_source/__init__.py +0 -0
- lionagi/lions/researcher/data_source/finhub_.py +191 -0
- lionagi/lions/researcher/data_source/google_.py +199 -0
- lionagi/lions/researcher/data_source/wiki_.py +96 -0
- lionagi/lions/researcher/data_source/yfinance_.py +21 -0
- lionagi/tests/integrations/__init__.py +0 -0
- lionagi/tests/libs/__init__.py +0 -0
- lionagi/tests/libs/test_field_validators.py +353 -0
- lionagi/tests/{test_libs → libs}/test_func_call.py +23 -21
- lionagi/tests/{test_libs → libs}/test_nested.py +36 -21
- lionagi/tests/{test_libs → libs}/test_parse.py +1 -1
- lionagi/tests/libs/test_queue.py +67 -0
- lionagi/tests/test_core/collections/__init__.py +0 -0
- lionagi/tests/test_core/collections/test_component.py +206 -0
- lionagi/tests/test_core/collections/test_exchange.py +138 -0
- lionagi/tests/test_core/collections/test_flow.py +145 -0
- lionagi/tests/test_core/collections/test_pile.py +171 -0
- lionagi/tests/test_core/collections/test_progression.py +129 -0
- lionagi/tests/test_core/generic/__init__.py +0 -0
- lionagi/tests/test_core/generic/test_edge.py +67 -0
- lionagi/tests/test_core/generic/test_graph.py +96 -0
- lionagi/tests/test_core/generic/test_node.py +106 -0
- lionagi/tests/test_core/generic/test_tree_node.py +73 -0
- lionagi/tests/test_core/test_branch.py +115 -292
- lionagi/tests/test_core/test_form.py +46 -0
- lionagi/tests/test_core/test_report.py +105 -0
- lionagi/tests/test_core/test_validator.py +111 -0
- lionagi/version.py +1 -1
- {lionagi-0.0.312.dist-info → lionagi-0.2.1.dist-info}/LICENSE +12 -11
- {lionagi-0.0.312.dist-info → lionagi-0.2.1.dist-info}/METADATA +19 -118
- lionagi-0.2.1.dist-info/RECORD +240 -0
- lionagi/core/branch/__init__.py +0 -4
- lionagi/core/branch/base_branch.py +0 -654
- lionagi/core/branch/branch.py +0 -471
- lionagi/core/branch/branch_flow_mixin.py +0 -96
- lionagi/core/branch/executable_branch.py +0 -347
- lionagi/core/branch/util.py +0 -323
- lionagi/core/direct/__init__.py +0 -6
- lionagi/core/direct/predict.py +0 -161
- lionagi/core/direct/score.py +0 -278
- lionagi/core/direct/select.py +0 -169
- lionagi/core/direct/utils.py +0 -87
- lionagi/core/direct/vote.py +0 -64
- lionagi/core/flow/base/baseflow.py +0 -23
- lionagi/core/flow/monoflow/ReAct.py +0 -238
- lionagi/core/flow/monoflow/__init__.py +0 -9
- lionagi/core/flow/monoflow/chat.py +0 -95
- lionagi/core/flow/monoflow/chat_mixin.py +0 -263
- lionagi/core/flow/monoflow/followup.py +0 -214
- lionagi/core/flow/polyflow/__init__.py +0 -1
- lionagi/core/flow/polyflow/chat.py +0 -248
- lionagi/core/mail/schema.py +0 -56
- lionagi/core/messages/__init__.py +0 -3
- lionagi/core/messages/schema.py +0 -533
- lionagi/core/prompt/prompt_template.py +0 -316
- lionagi/core/schema/__init__.py +0 -22
- lionagi/core/schema/action_node.py +0 -29
- lionagi/core/schema/base_mixin.py +0 -296
- lionagi/core/schema/base_node.py +0 -199
- lionagi/core/schema/condition.py +0 -24
- lionagi/core/schema/data_logger.py +0 -354
- lionagi/core/schema/data_node.py +0 -93
- lionagi/core/schema/prompt_template.py +0 -67
- lionagi/core/schema/structure.py +0 -910
- lionagi/core/tool/__init__.py +0 -3
- lionagi/core/tool/tool_manager.py +0 -280
- lionagi/integrations/bridge/pydantic_/base_model.py +0 -7
- lionagi/tests/test_core/test_base_branch.py +0 -427
- lionagi/tests/test_core/test_chat_flow.py +0 -63
- lionagi/tests/test_core/test_mail_manager.py +0 -75
- lionagi/tests/test_core/test_prompts.py +0 -51
- lionagi/tests/test_core/test_session.py +0 -254
- lionagi/tests/test_core/test_session_base_util.py +0 -312
- lionagi/tests/test_core/test_tool_manager.py +0 -95
- lionagi-0.0.312.dist-info/RECORD +0 -111
- /lionagi/core/{branch/base → _setting}/__init__.py +0 -0
- /lionagi/core/{flow → agent/eval}/__init__.py +0 -0
- /lionagi/core/{flow/base → agent/learn}/__init__.py +0 -0
- /lionagi/core/{prompt → agent/plan}/__init__.py +0 -0
- /lionagi/core/{tool/manual.py → agent/plan/plan.py} +0 -0
- /lionagi/{tests/test_integrations → core/director}/__init__.py +0 -0
- /lionagi/{tests/test_libs → core/engine}/__init__.py +0 -0
- /lionagi/{tests/test_libs/test_async.py → core/executor/__init__.py} +0 -0
- /lionagi/tests/{test_libs → libs}/test_api.py +0 -0
- /lionagi/tests/{test_libs → libs}/test_convert.py +0 -0
- /lionagi/tests/{test_libs → libs}/test_sys_util.py +0 -0
- {lionagi-0.0.312.dist-info → lionagi-0.2.1.dist-info}/WHEEL +0 -0
- {lionagi-0.0.312.dist-info → lionagi-0.2.1.dist-info}/top_level.txt +0 -0
lionagi/libs/ln_image.py
ADDED
@@ -0,0 +1,107 @@
|
|
1
|
+
import base64
|
2
|
+
import numpy as np
|
3
|
+
from typing import Optional
|
4
|
+
from .sys_util import SysUtil
|
5
|
+
|
6
|
+
|
7
|
+
class ImageUtil:
    """Static helpers for reading, converting and base64-encoding images,
    plus an estimator for OpenAI vision token usage.

    The OpenCV-backed helpers import ``cv2`` lazily, after calling
    ``SysUtil.check_import``, so this module can be imported without OpenCV.
    """

    @staticmethod
    def preprocess_image(
        image: np.ndarray, color_conversion_code: Optional[int] = None
    ) -> np.ndarray:
        """Convert the colour space of ``image`` with ``cv2.cvtColor``.

        Args:
            image: Image array handed to ``cv2.cvtColor``.
            color_conversion_code: An OpenCV ``COLOR_*`` code. When ``None``
                the conversion defaults to ``cv2.COLOR_BGR2RGB``.

        Returns:
            The converted image array.
        """
        SysUtil.check_import("cv2", pip_name="opencv-python")
        import cv2

        # Compare against None explicitly: some valid conversion codes are 0
        # (e.g. cv2.COLOR_BGR2BGRA) and must not be replaced by the default.
        if color_conversion_code is None:
            color_conversion_code = cv2.COLOR_BGR2RGB
        return cv2.cvtColor(image, color_conversion_code)

    @staticmethod
    def encode_image_to_base64(image: np.ndarray, file_extension: str = ".jpg") -> str:
        """Encode ``image`` in the given file format and return it as base64 text.

        Args:
            image: Image array handed to ``cv2.imencode``.
            file_extension: Target format as a dotted extension (e.g. ``".png"``).

        Returns:
            The encoded image bytes as a UTF-8 base64 string.

        Raises:
            ValueError: If ``cv2.imencode`` reports failure.
        """
        SysUtil.check_import("cv2", pip_name="opencv-python")
        import cv2

        success, buffer = cv2.imencode(file_extension, image)
        if not success:
            raise ValueError(f"Could not encode image to {file_extension} format.")
        return base64.b64encode(buffer).decode("utf-8")

    @staticmethod
    def read_image_to_array(
        image_path: str, color_flag: Optional[int] = None
    ) -> np.ndarray:
        """Read an image file into an array with ``cv2.imread``.

        Args:
            image_path: Path of the image file.
            color_flag: An OpenCV ``IMREAD_*`` flag. When ``None`` the image
                is read with ``cv2.IMREAD_COLOR``.

        Returns:
            The decoded image array.

        Raises:
            ValueError: If ``cv2.imread`` returns ``None`` for the path.
        """
        SysUtil.check_import("cv2", pip_name="opencv-python")
        import cv2

        # Resolve the default *before* reading, and only for None:
        # cv2.IMREAD_GRAYSCALE is 0 and is a legitimate caller choice.
        if color_flag is None:
            color_flag = cv2.IMREAD_COLOR

        image = cv2.imread(image_path, color_flag)
        if image is None:
            raise ValueError(f"Could not read image from path: {image_path}")
        return image

    @staticmethod
    def read_image_to_base64(
        image_path: str,
        color_flag: Optional[int] = None,
    ) -> str:
        """Read an image file and return it re-encoded as a base64 string.

        The output format follows the file's own extension; a path without an
        extension falls back to ``".jpg"``.

        Args:
            image_path: Path of the image file (converted with ``str``).
            color_flag: Optional OpenCV ``IMREAD_*`` flag, forwarded to
                ``read_image_to_array``.

        Returns:
            The encoded image as a UTF-8 base64 string.
        """
        import os

        image_path = str(image_path)
        image = ImageUtil.read_image_to_array(image_path, color_flag)

        # splitext only looks at the final path component, so dots inside
        # directory names are not mistaken for a file extension.
        file_extension = os.path.splitext(image_path)[1] or ".jpg"
        return ImageUtil.encode_image_to_base64(image, file_extension)

    @staticmethod
    def _high_detail_token_cost(width: int, height: int) -> int:
        """Return the high-detail token cost for an image of the given size.

        The image is first shrunk to fit a 2048 x 2048 square, then shrunk so
        its shortest side is at most 768px, and finally covered with 512px
        tiles; each tile costs 170 tokens on top of a flat 85.
        """
        max_dimension = 2048
        longest = max(width, height)
        if longest > max_dimension:
            scale_factor = max_dimension / longest
            width = int(width * scale_factor)
            height = int(height * scale_factor)

        min_side = 768
        shortest = min(width, height)
        if shortest > min_side:
            scale_factor = min_side / shortest
            width = int(width * scale_factor)
            height = int(height * scale_factor)

        # Partially covered tiles are billed too, hence ceiling division
        # (-(-a // b)) rather than floor division.
        tiles_across = -(-width // 512)
        tiles_down = -(-height // 512)
        return 170 * tiles_across * tiles_down + 85

    @staticmethod
    def calculate_image_token_usage_from_base64(image_base64: str, detail):
        """
        Calculate the token usage for processing OpenAI images from a base64-encoded string.

        Parameters:
            image_base64 (str): The base64-encoded string of the image, with or
                without a leading ``data:image/<type>;base64,`` prefix.
            detail (str): The detail level of the image, either 'low' or 'high'.

        Returns:
            int: The total token cost for processing the image.
        """
        # Low detail has a flat cost, so the image does not need decoding
        # (and Pillow does not need to be installed) for that case.
        if detail == "low":
            return 85

        import re
        from io import BytesIO
        from PIL import Image

        # Drop a data-URI prefix of any image type, not just JPEG.
        pieces = re.split(r"data:image/[^;,]+;base64,", image_base64)
        if len(pieces) > 1:
            image_base64 = pieces[1]
        # str.strip returns a new string; the result has to be kept.
        image_base64 = image_base64.strip("{}")

        image_data = base64.b64decode(image_base64)
        image = Image.open(BytesIO(image_data))

        width, height = image.size
        return ImageUtil._high_detail_token_cost(width, height)
|
lionagi/libs/ln_knowledge_graph.py
ADDED
@@ -0,0 +1,405 @@
|
|
1
|
+
import math
|
2
|
+
from lionagi.libs import CallDecorator as cd
|
3
|
+
|
4
|
+
|
5
|
+
class KnowledgeBase:
    """
    A class to represent a Knowledge Base (KB) containing entities, relations, and sources.

    Attributes:
        entities (dict): A dictionary of entities in the KB, where the keys are entity titles, and the values are
            entity information (excluding the title).
        relations (list): A list of relations in the KB, where each relation is a dictionary containing information
            about the relation (head, type, tail) and metadata (article_url and spans).
        sources (dict): A dictionary of information about the sources of relations, where the keys are article URLs,
            and the values are source data (article_title and article_publish_date).

    Methods:
        merge_with_kb(kb2): Merge another Knowledge Base (kb2) into this KB.
        are_relations_equal(r1, r2): Check if two relations (r1 and r2) are equal.
        exists_relation(r1): Check if a relation (r1) already exists in the KB.
        merge_relations(r2): Merge the information from relation r2 into an existing relation in the KB.
        get_wikipedia_data(candidate_entity): Get data for a candidate entity from Wikipedia.
        add_entity(e): Add an entity to the KB.
        add_relation(r, article_title, article_publish_date): Add a relation to the KB.
        print(): Print the entities, relations, and sources in the KB.
        extract_relations_from_model_output(text): Extract relations from the model output text.

    """

    def __init__(self):
        """
        Initialize an empty Knowledge Base (KB) with empty containers for entities, relations, and sources.
        """
        self.entities = {}  # { entity_title: {...} }
        self.relations = []  # [ { head: entity_title, type: ..., tail: entity_title,
        #     meta: { article_url: { spans: [...] } } } ]
        self.sources = {}  # { article_url: {...} }

    def merge_with_kb(self, kb2):
        """
        Merge another Knowledge Base (KB) into this KB.

        Each relation of ``kb2`` is re-added through ``add_relation`` using the
        source data registered under the first article URL of its metadata.

        Args:
            kb2 (KnowledgeBase): The Knowledge Base (KB) to merge into this KB.
        """
        import copy

        for r in kb2.relations:
            article_url = list(r["meta"].keys())[0]
            source_data = kb2.sources[article_url]
            # add_relation mutates and stores the dict it is given; hand it a
            # copy so the two knowledge bases never share relation objects.
            self.add_relation(
                copy.deepcopy(r),
                source_data["article_title"],
                source_data["article_publish_date"],
            )

    def are_relations_equal(self, r1, r2):
        """
        Check if two relations (r1 and r2) are equal.

        Only ``head``, ``type`` and ``tail`` are compared; metadata is ignored.

        Args:
            r1 (dict): The first relation to compare.
            r2 (dict): The second relation to compare.

        Returns:
            bool: True if the relations are equal, False otherwise.
        """
        return all(r1[attr] == r2[attr] for attr in ["head", "type", "tail"])

    def exists_relation(self, r1):
        """
        Check if a relation (r1) already exists in the KB.

        Args:
            r1 (dict): The relation to check for existence in the KB.

        Returns:
            bool: True if the relation exists in the KB, False otherwise.
        """
        return any(self.are_relations_equal(r1, r2) for r2 in self.relations)

    def merge_relations(self, r2):
        """
        Merge the information from relation r2 into an existing relation in the KB.

        Only the first article URL in ``r2["meta"]`` is merged. An
        ``IndexError`` is raised if no equal relation is stored yet.

        Args:
            r2 (dict): The relation to merge into an existing relation in the KB.
        """
        r1 = [r for r in self.relations if self.are_relations_equal(r2, r)][0]

        article_url = list(r2["meta"].keys())[0]
        if article_url not in r1["meta"]:
            # relation seen in a different article: adopt its metadata
            r1["meta"][article_url] = r2["meta"][article_url]
        else:
            # same article: append only the spans not recorded yet
            known_spans = r1["meta"][article_url]["spans"]
            spans_to_add = [
                span
                for span in r2["meta"][article_url]["spans"]
                if span not in known_spans
            ]
            r1["meta"][article_url]["spans"] += spans_to_add

    # NOTE(review): the cache decorator wraps an instance method, so ``self``
    # is presumably part of the cache key -- confirm against CallDecorator.cache.
    @cd.cache(maxsize=10000)
    def get_wikipedia_data(self, candidate_entity):
        """
        Get data for a candidate entity from Wikipedia.

        Args:
            candidate_entity (str): The candidate entity title.

        Returns:
            dict: A dictionary containing information about the candidate entity (title, url, summary).
                None if the page lookup fails for any reason.

        Raises:
            ImportError: If the ``wikipedia`` package cannot be made available.
        """
        try:
            from lionagi.libs import SysUtil

            SysUtil.check_import("wikipedia")
            import wikipedia  # type: ignore
        except Exception as e:
            # The original message was missing the f-prefix and printed "{e}"
            # literally; chain the cause so the real failure stays visible.
            raise ImportError(f"wikipedia package is not installed {e}") from e

        try:
            page = wikipedia.page(candidate_entity, auto_suggest=False)
        except Exception:
            # Best-effort lookup: a failed lookup means "entity not found".
            # Exception (not a bare except) lets KeyboardInterrupt/SystemExit
            # propagate.
            return None

        try:
            return {
                "title": page.title,
                "url": page.url,
                "summary": page.summary,
            }
        except Exception:
            return None

    def add_entity(self, e):
        """
        Add an entity to the KB.

        Args:
            e (dict): A dictionary containing information about the entity (title and additional attributes).
        """
        self.entities[e["title"]] = {k: v for k, v in e.items() if k != "title"}

    def add_relation(self, r, article_title, article_publish_date):
        """
        Add a relation to the KB.

        The relation is dropped silently when either its head or its tail has
        no Wikipedia data. Otherwise ``r`` is updated in place (head and tail
        are replaced by the Wikipedia titles) and stored or merged.

        Args:
            r (dict): A dictionary containing information about the relation (head, type, tail, and metadata).
            article_title (str): The title of the article containing the relation.
            article_publish_date (str): The publish date of the article.
        """
        # check on wikipedia
        candidate_entities = [r["head"], r["tail"]]
        entities = [self.get_wikipedia_data(ent) for ent in candidate_entities]

        # if one entity does not exist, stop
        if any(ent is None for ent in entities):
            return

        # manage new entities
        for e in entities:
            self.add_entity(e)

        # rename relation entities with their wikipedia titles
        r["head"] = entities[0]["title"]
        r["tail"] = entities[1]["title"]

        # add source if not in kb
        article_url = list(r["meta"].keys())[0]
        if article_url not in self.sources:
            self.sources[article_url] = {
                "article_title": article_title,
                "article_publish_date": article_publish_date,
            }

        # manage new relation
        if not self.exists_relation(r):
            self.relations.append(r)
        else:
            self.merge_relations(r)

    def print(self):
        """
        Print the entities, relations, and sources in the KB.

        Returns:
            None
        """
        print("Entities:")
        for e in self.entities.items():
            print(f"  {e}")
        print("Relations:")
        for r in self.relations:
            print(f"  {r}")
        print("Sources:")
        for s in self.sources.items():
            print(f"  {s}")

    @staticmethod
    def extract_relations_from_model_output(text):
        """
        Extract relations from the model output text.

        The text is scanned token by token. ``<triplet>`` starts a new head,
        ``<subj>`` starts the tail and ``<obj>`` starts the relation type;
        ``<s>``, ``<pad>`` and ``</s>`` are removed before scanning.

        Args:
            text (str): The model output text containing relations.

        Returns:
            list: A list of dictionaries, where each dictionary represents a relation (head, type, tail).
        """
        relations = []
        subject, relation, object_ = "", "", ""
        current = "x"  # which field plain tokens are appended to

        def emit():
            # snapshot the triplet accumulated so far
            relations.append(
                {
                    "head": subject.strip(),
                    "type": relation.strip(),
                    "tail": object_.strip(),
                }
            )

        text = text.strip()
        text_replaced = text.replace("<s>", "").replace("<pad>", "").replace("</s>", "")
        for token in text_replaced.split():
            if token == "<triplet>":
                current = "t"
                if relation != "":
                    emit()
                    relation = ""
                subject = ""
            elif token == "<subj>":
                current = "s"
                if relation != "":
                    emit()
                object_ = ""
            elif token == "<obj>":
                current = "o"
                relation = ""
            elif current == "t":
                subject += " " + token
            elif current == "s":
                object_ += " " + token
            elif current == "o":
                relation += " " + token

        # flush the last, still-open triplet
        if subject != "" and relation != "" and object_ != "":
            emit()
        return relations
|
256
|
+
|
257
|
+
|
258
|
+
class KGTripletExtractor:
    """
    A class to perform knowledge graph triplet extraction from text using a pre-trained model.

    Methods:
        text_to_wiki_kb(text, model=None, tokenizer=None, device='cpu', span_length=512,
                        article_title=None, article_publish_date=None, verbose=False):
            Extract knowledge graph triplets from text and create a KnowledgeBase (KB) containing entities and relations.

    """

    @staticmethod
    def text_to_wiki_kb(
        text,
        model=None,
        tokenizer=None,
        device="cpu",
        span_length=512,
        article_title=None,
        article_publish_date=None,
        verbose=False,
    ):
        """
        Extract knowledge graph triplets from text and create a KnowledgeBase (KB) containing entities and relations.

        Args:
            text (str): The input text from which triplets will be extracted.
            model (AutoModelForSeq2SeqLM, optional): The pre-trained model for triplet extraction. When None,
                "Babelscape/rebel-large" is loaded and moved to ``device``.
            tokenizer (AutoTokenizer, optional): The tokenizer for the model. When None, the
                "Babelscape/rebel-large" tokenizer is loaded.
            device (str, optional): The device to run the model on (e.g., 'cpu', 'cuda'). Defaults to 'cpu'.
            span_length (int, optional): The maximum span length for input text segmentation. Defaults to 512.
            article_title (str, optional): The title of the article containing the input text. Defaults to None.
            article_publish_date (str, optional): The publish date of the article. Defaults to None.
            verbose (bool, optional): Whether to enable verbose mode for debugging. Defaults to False.

        Returns:
            KnowledgeBase: A KnowledgeBase (KB) containing extracted entities, relations, and sources.

        """
        from lionagi.integrations.bridge.transformers_.install_ import (
            install_transformers,
        )

        try:
            from transformers import AutoModelForSeq2SeqLM, AutoTokenizer  # type: ignore
        except ImportError:
            install_transformers()
            from transformers import AutoModelForSeq2SeqLM, AutoTokenizer  # type: ignore
        import torch  # type: ignore

        # Fill in whichever of the pair is missing. The previous
        # `not any([model, tokenizer])` check left the other one as None when
        # only one was supplied, which crashed further down.
        if tokenizer is None:
            tokenizer = AutoTokenizer.from_pretrained("Babelscape/rebel-large")
        if model is None:
            model = AutoModelForSeq2SeqLM.from_pretrained("Babelscape/rebel-large")
            # NOTE(review): only the default model is moved here; a
            # caller-supplied model is assumed to already be on `device`.
            model.to(device)

        inputs = tokenizer([text], return_tensors="pt")

        num_tokens = len(inputs["input_ids"][0])
        if verbose:
            print(f"Input has {num_tokens} tokens")
        num_spans = math.ceil(num_tokens / span_length)
        if verbose:
            print(f"Input has {num_spans} spans")

        # Spread the surplus of (spans * span_length) over the token count as
        # overlap between neighbouring spans.
        overlap = math.ceil(
            (num_spans * span_length - num_tokens) / max(num_spans - 1, 1)
        )
        stride = span_length - overlap
        spans_boundaries = [
            [index * stride, index * stride + span_length]
            for index in range(num_spans)
        ]
        if verbose:
            print(f"Span boundaries are {spans_boundaries}")

        # transform input with spans
        tensor_ids = [
            inputs["input_ids"][0][boundary[0] : boundary[1]]
            for boundary in spans_boundaries
        ]
        tensor_masks = [
            inputs["attention_mask"][0][boundary[0] : boundary[1]]
            for boundary in spans_boundaries
        ]

        inputs = {
            "input_ids": torch.stack(tensor_ids).to(device),
            "attention_mask": torch.stack(tensor_masks).to(device),
        }

        # generate relations
        num_return_sequences = 3
        gen_kwargs = {
            "max_length": 512,
            "length_penalty": 0,
            "num_beams": 3,
            "num_return_sequences": num_return_sequences,
        }
        generated_tokens = model.generate(
            **inputs,
            **gen_kwargs,
        )

        # decode relations; special tokens are kept because the relation
        # parser relies on the <triplet>/<subj>/<obj> markers
        decoded_preds = tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=False
        )

        # create kb
        kb = KnowledgeBase()
        for position, sentence_pred in enumerate(decoded_preds):
            # every span produced `num_return_sequences` consecutive predictions
            current_span_index = position // num_return_sequences
            relations = KnowledgeBase.extract_relations_from_model_output(sentence_pred)
            for relation in relations:
                # the literal key "article_url" is what the KB stores as the
                # source identifier for this text
                relation["meta"] = {
                    "article_url": {"spans": [spans_boundaries[current_span_index]]}
                }
                kb.add_relation(relation, article_title, article_publish_date)
        return kb
|
381
|
+
|
382
|
+
|
383
|
+
class KGraph:
    """
    Convenience entry point for building a Knowledge Base (KB) from text.

    Methods:
        text_to_wiki_kb(text, **kwargs):
            Forward the text and keyword arguments to
            KGTripletExtractor.text_to_wiki_kb and return its result.
    """

    @staticmethod
    def text_to_wiki_kb(text, **kwargs):
        """
        Build a Knowledge Base (KB) from the relations found in the input text.

        Args:
            text (str): The text to extract relations from.
            **kwargs: Keyword arguments forwarded unchanged to
                KGTripletExtractor.text_to_wiki_kb.

        Returns:
            KnowledgeBase: Whatever KGTripletExtractor.text_to_wiki_kb returns for the input.
        """
        knowledge_base = KGTripletExtractor.text_to_wiki_kb(text, **kwargs)
        return knowledge_base
|
lionagi/libs/ln_nested.py
CHANGED
@@ -1,3 +1,19 @@
|
|
1
|
+
"""
|
2
|
+
Copyright 2024 HaiyangLi
|
3
|
+
|
4
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
5
|
+
you may not use this file except in compliance with the License.
|
6
|
+
You may obtain a copy of the License at
|
7
|
+
|
8
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
9
|
+
|
10
|
+
Unless required by applicable law or agreed to in writing, software
|
11
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
12
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13
|
+
See the License for the specific language governing permissions and
|
14
|
+
limitations under the License.
|
15
|
+
"""
|
16
|
+
|
1
17
|
from collections import defaultdict
|
2
18
|
from itertools import chain
|
3
19
|
from typing import Any, Generator, Callable
|
@@ -52,7 +68,7 @@ def nset(nested_structure: dict | list, indices: list[int | str], value: Any) ->
|
|
52
68
|
def nget(
|
53
69
|
nested_structure: dict | list,
|
54
70
|
indices: list[int | str],
|
55
|
-
default
|
71
|
+
default=...,
|
56
72
|
) -> Any:
|
57
73
|
"""
|
58
74
|
retrieves a value from a nested list or dictionary structure, with an option to
|
@@ -98,12 +114,12 @@ def nget(
|
|
98
114
|
return target_container[last_index]
|
99
115
|
elif isinstance(target_container, dict) and last_index in target_container:
|
100
116
|
return target_container[last_index]
|
101
|
-
elif default is not
|
117
|
+
elif default is not ...:
|
102
118
|
return default
|
103
119
|
else:
|
104
120
|
raise LookupError("Target not found and no default value provided.")
|
105
121
|
except (IndexError, KeyError, TypeError):
|
106
|
-
if default is not
|
122
|
+
if default is not ...:
|
107
123
|
return default
|
108
124
|
else:
|
109
125
|
raise LookupError("Target not found and no default value provided.")
|
@@ -116,7 +132,7 @@ def nmerge(
|
|
116
132
|
*,
|
117
133
|
overwrite: bool = False,
|
118
134
|
dict_sequence: bool = False,
|
119
|
-
sequence_separator: str = "_",
|
135
|
+
sequence_separator: str = "[^_^]",
|
120
136
|
sort_list: bool = False,
|
121
137
|
custom_sort: Callable[[Any], Any] | None = None,
|
122
138
|
) -> dict | list:
|
@@ -176,7 +192,7 @@ def flatten(
|
|
176
192
|
/,
|
177
193
|
*,
|
178
194
|
parent_key: str = "",
|
179
|
-
sep: str = "_",
|
195
|
+
sep: str = "[^_^]",
|
180
196
|
max_depth: int | None = None,
|
181
197
|
inplace: bool = False,
|
182
198
|
dict_only: bool = False,
|
@@ -238,7 +254,7 @@ def unflatten(
|
|
238
254
|
flat_dict: dict[str, Any],
|
239
255
|
/,
|
240
256
|
*,
|
241
|
-
sep: str = "_",
|
257
|
+
sep: str = "[^_^]",
|
242
258
|
custom_logic: Callable[[str], Any] | None = None,
|
243
259
|
max_depth: int | None = None,
|
244
260
|
) -> dict | list:
|
@@ -330,7 +346,7 @@ def ninsert(
|
|
330
346
|
indices: list[str | int],
|
331
347
|
value: Any,
|
332
348
|
*,
|
333
|
-
sep: str = "_",
|
349
|
+
sep: str = "[^_^]",
|
334
350
|
max_depth: int | None = None,
|
335
351
|
current_depth: int = 0,
|
336
352
|
) -> None:
|
@@ -393,12 +409,11 @@ def ninsert(
|
|
393
409
|
nested_structure[last_part] = value
|
394
410
|
|
395
411
|
|
396
|
-
# noinspection PyDecorator
|
397
412
|
def get_flattened_keys(
|
398
413
|
nested_structure: Any,
|
399
414
|
/,
|
400
415
|
*,
|
401
|
-
sep: str = "_",
|
416
|
+
sep: str = "[^_^]",
|
402
417
|
max_depth: int | None = None,
|
403
418
|
dict_only: bool = False,
|
404
419
|
inplace: bool = False,
|
@@ -448,7 +463,7 @@ def _dynamic_flatten_in_place(
|
|
448
463
|
/,
|
449
464
|
*,
|
450
465
|
parent_key: str = "",
|
451
|
-
sep: str = "_",
|
466
|
+
sep: str = "[^_^]",
|
452
467
|
max_depth: int | None = None,
|
453
468
|
current_depth: int = 0,
|
454
469
|
dict_only: bool = False,
|
@@ -581,7 +596,7 @@ def _deep_update(original: dict, update: dict) -> dict:
|
|
581
596
|
def _dynamic_flatten_generator(
|
582
597
|
nested_structure: Any,
|
583
598
|
parent_key: tuple[str, ...],
|
584
|
-
sep: str = "_",
|
599
|
+
sep: str = "[^_^]",
|
585
600
|
max_depth: int | None = None,
|
586
601
|
current_depth: int = 0,
|
587
602
|
dict_only: bool = False,
|