tinygent-graph 0.1.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tinygent_graph-0.1.1/.gitignore +210 -0
- tinygent_graph-0.1.1/PKG-INFO +5 -0
- tinygent_graph-0.1.1/pyproject.toml +17 -0
- tinygent_graph-0.1.1/src/tiny_graph/__init__.py +5 -0
- tinygent_graph-0.1.1/src/tiny_graph/driver/__init__.py +5 -0
- tinygent_graph-0.1.1/src/tiny_graph/driver/base.py +21 -0
- tinygent_graph-0.1.1/src/tiny_graph/driver/neo4j.py +60 -0
- tinygent_graph-0.1.1/src/tiny_graph/edge.py +40 -0
- tinygent_graph-0.1.1/src/tiny_graph/graph/__init__.py +5 -0
- tinygent_graph-0.1.1/src/tiny_graph/graph/base.py +37 -0
- tinygent_graph-0.1.1/src/tiny_graph/graph/multi_layer_graph/__init__.py +5 -0
- tinygent_graph-0.1.1/src/tiny_graph/graph/multi_layer_graph/core/edge.py +14 -0
- tinygent_graph-0.1.1/src/tiny_graph/graph/multi_layer_graph/core/node.py +27 -0
- tinygent_graph-0.1.1/src/tiny_graph/graph/multi_layer_graph/datamodels/clients.py +18 -0
- tinygent_graph-0.1.1/src/tiny_graph/graph/multi_layer_graph/datamodels/extract_nodes.py +23 -0
- tinygent_graph-0.1.1/src/tiny_graph/graph/multi_layer_graph/edges.py +184 -0
- tinygent_graph-0.1.1/src/tiny_graph/graph/multi_layer_graph/multi_layer_graph.py +527 -0
- tinygent_graph-0.1.1/src/tiny_graph/graph/multi_layer_graph/nodes.py +211 -0
- tinygent_graph-0.1.1/src/tiny_graph/graph/multi_layer_graph/ops/cluster_operations.py +165 -0
- tinygent_graph-0.1.1/src/tiny_graph/graph/multi_layer_graph/ops/edge_operations.py +995 -0
- tinygent_graph-0.1.1/src/tiny_graph/graph/multi_layer_graph/ops/graph_operations.py +31 -0
- tinygent_graph-0.1.1/src/tiny_graph/graph/multi_layer_graph/ops/node_operations.py +666 -0
- tinygent_graph-0.1.1/src/tiny_graph/graph/multi_layer_graph/prompts/clusters.py +92 -0
- tinygent_graph-0.1.1/src/tiny_graph/graph/multi_layer_graph/prompts/default_prompts.py +120 -0
- tinygent_graph-0.1.1/src/tiny_graph/graph/multi_layer_graph/prompts/edges.py +229 -0
- tinygent_graph-0.1.1/src/tiny_graph/graph/multi_layer_graph/prompts/nodes.py +108 -0
- tinygent_graph-0.1.1/src/tiny_graph/graph/multi_layer_graph/queries/cluster_queries.py +21 -0
- tinygent_graph-0.1.1/src/tiny_graph/graph/multi_layer_graph/queries/edge_queries.py +176 -0
- tinygent_graph-0.1.1/src/tiny_graph/graph/multi_layer_graph/queries/graph_queries.py +127 -0
- tinygent_graph-0.1.1/src/tiny_graph/graph/multi_layer_graph/queries/node_queries.py +179 -0
- tinygent_graph-0.1.1/src/tiny_graph/graph/multi_layer_graph/search/__init__.py +13 -0
- tinygent_graph-0.1.1/src/tiny_graph/graph/multi_layer_graph/search/search.py +276 -0
- tinygent_graph-0.1.1/src/tiny_graph/graph/multi_layer_graph/search/search_cfg.py +119 -0
- tinygent_graph-0.1.1/src/tiny_graph/graph/multi_layer_graph/search/search_presets.py +32 -0
- tinygent_graph-0.1.1/src/tiny_graph/graph/multi_layer_graph/search/search_ranker.py +29 -0
- tinygent_graph-0.1.1/src/tiny_graph/graph/multi_layer_graph/search/search_utils.py +420 -0
- tinygent_graph-0.1.1/src/tiny_graph/graph/multi_layer_graph/types.py +18 -0
- tinygent_graph-0.1.1/src/tiny_graph/graph/multi_layer_graph/utils/custom_types.py +20 -0
- tinygent_graph-0.1.1/src/tiny_graph/graph/multi_layer_graph/utils/model_repr.py +13 -0
- tinygent_graph-0.1.1/src/tiny_graph/graph/multi_layer_graph/utils/node_formatter.py +20 -0
- tinygent_graph-0.1.1/src/tiny_graph/graph/multi_layer_graph/utils/text_similarity.py +161 -0
- tinygent_graph-0.1.1/src/tiny_graph/helper.py +46 -0
- tinygent_graph-0.1.1/src/tiny_graph/node.py +42 -0
- tinygent_graph-0.1.1/src/tiny_graph/types/provider.py +5 -0
|
@@ -0,0 +1,210 @@
|
|
|
1
|
+
# Byte-compiled / optimized / DLL files
|
|
2
|
+
__pycache__/
|
|
3
|
+
*.py[codz]
|
|
4
|
+
*$py.class
|
|
5
|
+
|
|
6
|
+
# C extensions
|
|
7
|
+
*.so
|
|
8
|
+
|
|
9
|
+
# Distribution / packaging
|
|
10
|
+
.Python
|
|
11
|
+
build/
|
|
12
|
+
develop-eggs/
|
|
13
|
+
dist/
|
|
14
|
+
downloads/
|
|
15
|
+
eggs/
|
|
16
|
+
.eggs/
|
|
17
|
+
lib/
|
|
18
|
+
lib64/
|
|
19
|
+
parts/
|
|
20
|
+
sdist/
|
|
21
|
+
var/
|
|
22
|
+
wheels/
|
|
23
|
+
share/python-wheels/
|
|
24
|
+
*.egg-info/
|
|
25
|
+
.installed.cfg
|
|
26
|
+
*.egg
|
|
27
|
+
MANIFEST
|
|
28
|
+
|
|
29
|
+
# PyInstaller
|
|
30
|
+
# Usually these files are written by a python script from a template
|
|
31
|
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
|
32
|
+
*.manifest
|
|
33
|
+
*.spec
|
|
34
|
+
|
|
35
|
+
# Installer logs
|
|
36
|
+
pip-log.txt
|
|
37
|
+
pip-delete-this-directory.txt
|
|
38
|
+
|
|
39
|
+
# Unit test / coverage reports
|
|
40
|
+
htmlcov/
|
|
41
|
+
.tox/
|
|
42
|
+
.nox/
|
|
43
|
+
.coverage
|
|
44
|
+
.coverage.*
|
|
45
|
+
.cache
|
|
46
|
+
nosetests.xml
|
|
47
|
+
coverage.xml
|
|
48
|
+
*.cover
|
|
49
|
+
*.py.cover
|
|
50
|
+
.hypothesis/
|
|
51
|
+
.pytest_cache/
|
|
52
|
+
cover/
|
|
53
|
+
|
|
54
|
+
# Translations
|
|
55
|
+
*.mo
|
|
56
|
+
*.pot
|
|
57
|
+
|
|
58
|
+
# Django stuff:
|
|
59
|
+
*.log
|
|
60
|
+
local_settings.py
|
|
61
|
+
db.sqlite3
|
|
62
|
+
db.sqlite3-journal
|
|
63
|
+
|
|
64
|
+
# Flask stuff:
|
|
65
|
+
instance/
|
|
66
|
+
.webassets-cache
|
|
67
|
+
|
|
68
|
+
# Scrapy stuff:
|
|
69
|
+
.scrapy
|
|
70
|
+
|
|
71
|
+
# Sphinx documentation
|
|
72
|
+
docs/_build/
|
|
73
|
+
|
|
74
|
+
# PyBuilder
|
|
75
|
+
.pybuilder/
|
|
76
|
+
target/
|
|
77
|
+
|
|
78
|
+
# Jupyter Notebook
|
|
79
|
+
.ipynb_checkpoints
|
|
80
|
+
|
|
81
|
+
# IPython
|
|
82
|
+
profile_default/
|
|
83
|
+
ipython_config.py
|
|
84
|
+
|
|
85
|
+
# pyenv
|
|
86
|
+
# For a library or package, you might want to ignore these files since the code is
|
|
87
|
+
# intended to run in multiple environments; otherwise, check them in:
|
|
88
|
+
# .python-version
|
|
89
|
+
|
|
90
|
+
# pipenv
|
|
91
|
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
|
92
|
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
|
93
|
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
|
94
|
+
# install all needed dependencies.
|
|
95
|
+
#Pipfile.lock
|
|
96
|
+
|
|
97
|
+
# UV
|
|
98
|
+
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
|
99
|
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
|
100
|
+
# commonly ignored for libraries.
|
|
101
|
+
#uv.lock
|
|
102
|
+
|
|
103
|
+
# poetry
|
|
104
|
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
|
105
|
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
|
106
|
+
# commonly ignored for libraries.
|
|
107
|
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
|
108
|
+
#poetry.lock
|
|
109
|
+
#poetry.toml
|
|
110
|
+
|
|
111
|
+
# pdm
|
|
112
|
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
|
113
|
+
# pdm recommends including project-wide configuration in pdm.toml, but excluding .pdm-python.
|
|
114
|
+
# https://pdm-project.org/en/latest/usage/project/#working-with-version-control
|
|
115
|
+
#pdm.lock
|
|
116
|
+
#pdm.toml
|
|
117
|
+
.pdm-python
|
|
118
|
+
.pdm-build/
|
|
119
|
+
|
|
120
|
+
# pixi
|
|
121
|
+
# Similar to Pipfile.lock, it is generally recommended to include pixi.lock in version control.
|
|
122
|
+
#pixi.lock
|
|
123
|
+
# Pixi creates a virtual environment in the .pixi directory, just like venv module creates one
|
|
124
|
+
# in the .venv directory. It is recommended not to include this directory in version control.
|
|
125
|
+
.pixi
|
|
126
|
+
|
|
127
|
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
|
128
|
+
__pypackages__/
|
|
129
|
+
|
|
130
|
+
# Celery stuff
|
|
131
|
+
celerybeat-schedule
|
|
132
|
+
celerybeat.pid
|
|
133
|
+
|
|
134
|
+
# SageMath parsed files
|
|
135
|
+
*.sage.py
|
|
136
|
+
|
|
137
|
+
# Environments
|
|
138
|
+
.env
|
|
139
|
+
.envrc
|
|
140
|
+
.venv
|
|
141
|
+
env/
|
|
142
|
+
venv/
|
|
143
|
+
ENV/
|
|
144
|
+
env.bak/
|
|
145
|
+
venv.bak/
|
|
146
|
+
|
|
147
|
+
# Spyder project settings
|
|
148
|
+
.spyderproject
|
|
149
|
+
.spyproject
|
|
150
|
+
|
|
151
|
+
# Rope project settings
|
|
152
|
+
.ropeproject
|
|
153
|
+
|
|
154
|
+
# mkdocs documentation
|
|
155
|
+
/site
|
|
156
|
+
|
|
157
|
+
# mypy
|
|
158
|
+
.mypy_cache/
|
|
159
|
+
.dmypy.json
|
|
160
|
+
dmypy.json
|
|
161
|
+
|
|
162
|
+
# Pyre type checker
|
|
163
|
+
.pyre/
|
|
164
|
+
|
|
165
|
+
# pytype static type analyzer
|
|
166
|
+
.pytype/
|
|
167
|
+
|
|
168
|
+
# Cython debug symbols
|
|
169
|
+
cython_debug/
|
|
170
|
+
|
|
171
|
+
# PyCharm
|
|
172
|
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
|
173
|
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
|
174
|
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
|
175
|
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
|
176
|
+
#.idea/
|
|
177
|
+
|
|
178
|
+
# Abstra
|
|
179
|
+
# Abstra is an AI-powered process automation framework.
|
|
180
|
+
# Ignore directories containing user credentials, local state, and settings.
|
|
181
|
+
# Learn more at https://abstra.io/docs
|
|
182
|
+
.abstra/
|
|
183
|
+
|
|
184
|
+
# Visual Studio Code
|
|
185
|
+
# Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
|
|
186
|
+
# that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
|
|
187
|
+
# and can be added to the global gitignore or merged into this file. However, if you prefer,
|
|
188
|
+
# you could uncomment the following to ignore the entire vscode folder
|
|
189
|
+
# .vscode/
|
|
190
|
+
|
|
191
|
+
# Ruff stuff:
|
|
192
|
+
.ruff_cache/
|
|
193
|
+
|
|
194
|
+
# PyPI configuration file
|
|
195
|
+
.pypirc
|
|
196
|
+
|
|
197
|
+
# Cursor
|
|
198
|
+
# Cursor is an AI-powered code editor. `.cursorignore` specifies files/directories to
|
|
199
|
+
# exclude from AI features like autocomplete and code analysis. Recommended for sensitive data
|
|
200
|
+
# refer to https://docs.cursor.com/context/ignore-files
|
|
201
|
+
.cursorignore
|
|
202
|
+
.cursorindexingignore
|
|
203
|
+
|
|
204
|
+
# Marimo
|
|
205
|
+
marimo/_static/
|
|
206
|
+
marimo/_lsp/
|
|
207
|
+
__marimo__/
|
|
208
|
+
|
|
209
|
+
# examples/knowledge-graph
|
|
210
|
+
.neo4j/
|
|
@@ -0,0 +1,17 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "tinygent-graph"
|
|
7
|
+
version = "0.1.1"
|
|
8
|
+
dependencies = [
|
|
9
|
+
"neo4j>=6.0.3",
|
|
10
|
+
"tinygent",
|
|
11
|
+
]
|
|
12
|
+
|
|
13
|
+
[tool.hatch.build.targets.wheel]
|
|
14
|
+
packages = ["src/tiny_graph"]
|
|
15
|
+
|
|
16
|
+
[tool.uv.sources]
|
|
17
|
+
tinygent = { workspace = true }
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
from abc import ABC
|
|
2
|
+
from abc import abstractmethod
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from typing_extensions import LiteralString
|
|
6
|
+
|
|
7
|
+
from tiny_graph.types.provider import GraphProvider
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class BaseDriver(ABC):
    """Abstract base class for graph db drivers.

    A concrete driver names its backend through the ``provider`` class
    attribute and implements query execution and shutdown.
    """

    # Backend identifier; concrete subclasses assign it at class level.
    provider: GraphProvider

    @abstractmethod
    async def execute_query(self, query: str | LiteralString, **kwargs: Any) -> Any:
        """Run ``query`` on the backend and return the backend's result.

        How extra keyword arguments are interpreted is up to the concrete
        driver.
        """
        message = 'Subclasses must implement this method.'
        raise NotImplementedError(message)

    @abstractmethod
    async def close(self) -> None:
        """Shut the driver down; the exact cleanup is defined by subclasses."""
        message = 'Subclasses must implement this method.'
        raise NotImplementedError(message)
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import Any
|
|
3
|
+
from typing import cast
|
|
4
|
+
|
|
5
|
+
from neo4j import AsyncGraphDatabase
|
|
6
|
+
from neo4j import EagerResult
|
|
7
|
+
from typing_extensions import LiteralString
|
|
8
|
+
|
|
9
|
+
from tiny_graph.driver.base import BaseDriver
|
|
10
|
+
from tiny_graph.types.provider import GraphProvider
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class Neo4jDriver(BaseDriver):
    """Graph driver backed by the async Neo4j Python driver."""

    provider = GraphProvider.NEO4J

    def __init__(
        self,
        uri: str,
        user: str,
        password: str,
    ) -> None:
        """Create the underlying ``AsyncGraphDatabase`` driver.

        Args:
            uri: Connection URI of the Neo4j server.
            user: Username used for authentication.
            password: Password used for authentication.
        """
        self.uri = uri
        self.user = user
        self.password = password

        # Name-mangled (double underscore) so access goes through this wrapper.
        self.__client = AsyncGraphDatabase.driver(
            uri,
            auth=(user, password),
        )

    async def health_check(self) -> None:
        """Check that the server is reachable.

        Raises:
            Exception: whatever ``verify_connectivity`` raises; it is logged
                and then re-raised with its original traceback.
        """
        try:
            await self.__client.verify_connectivity()
        except Exception as e:
            logger.error('Neo4j health check failed: %s', e)
            # Bare ``raise`` re-raises the active exception unchanged.
            raise
        else:
            logger.debug('Neo4j connection is healthy')

    async def execute_query(
        self, query: str | LiteralString, **kwargs: Any
    ) -> EagerResult:
        """Execute ``query`` and return the eagerly-fetched result.

        Args:
            query: Cypher query text.
            **kwargs: ``params`` (if given) is passed as the driver's
                ``parameters_`` argument; every other keyword is forwarded
                unchanged to ``AsyncDriver.execute_query``. Per the neo4j docs,
                non-underscore keywords there are treated as extra query
                parameters.

        Returns:
            The ``EagerResult`` returned by the Neo4j driver.

        Raises:
            Exception: any driver error, logged together with the query and
                then re-raised unchanged.
        """
        params = kwargs.pop('params', {})

        # The neo4j API is typed to accept only LiteralString. Queries here are
        # built dynamically, so cast for the type checker; it is a runtime no-op.
        literal_query = cast(LiteralString, query)

        try:
            return await self.__client.execute_query(
                literal_query, parameters_=params, **kwargs
            )
        except Exception as e:
            logger.error('Neo4j failed to execute query: %s with error: %s', query, e)
            raise

    async def close(self) -> None:
        """Close the underlying Neo4j driver."""
        await self.__client.close()
        logger.debug('Neo4j connection closed')
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from abc import ABC
|
|
2
|
+
from abc import abstractmethod
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
from typing import Any
|
|
5
|
+
|
|
6
|
+
from pydantic import Field
|
|
7
|
+
|
|
8
|
+
from tiny_graph.driver.base import BaseDriver
|
|
9
|
+
from tiny_graph.helper import generate_uuid
|
|
10
|
+
from tiny_graph.helper import get_current_timestamp
|
|
11
|
+
from tinygent.core.types.base import TinyModel
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class TinyEdge(TinyModel, ABC):
    """Abstract base model for an edge from a source node to a target node.

    Concrete edge types implement ``from_record`` (build an instance from a
    database record) and ``save`` (persist the edge through a driver).
    """

    uuid: str = Field(
        description='unique edge identifier', default_factory=generate_uuid
    )

    name: str = Field(description='name of the edge, relation name')

    subgraph_id: str = Field(..., description='subgraph identifier')

    # UUID of the node the edge starts from.
    source_node_uuid: str

    # UUID of the node the edge points to.
    target_node_uuid: str

    # Defaults to the helper's current timestamp when not supplied.
    created_at: datetime = Field(default_factory=get_current_timestamp)

    attributes: dict[str, Any] = Field(
        default_factory=dict, description='Additional attributes of the edge.'
    )

    @classmethod
    @abstractmethod
    def from_record(cls, record: dict) -> Any:
        """Build an edge instance from a database record."""
        raise NotImplementedError('Subclasses must implement this method.')

    @abstractmethod
    async def save(self, driver: BaseDriver) -> Any:
        """Persist this edge using ``driver``."""
        raise NotImplementedError('Subclasses must implement this method.')
|
|
@@ -0,0 +1,37 @@
|
|
|
1
|
+
from abc import ABC
|
|
2
|
+
from abc import abstractmethod
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
from tiny_graph.driver.base import BaseDriver
|
|
6
|
+
from tinygent.core.datamodels.embedder import AbstractEmbedder
|
|
7
|
+
from tinygent.core.datamodels.llm import AbstractLLM
|
|
8
|
+
from tinygent.core.datamodels.messages import BaseMessage
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class BaseGraph(ABC):
    """Abstract base for graphs built on an LLM, an embedder and a storage driver."""

    def __init__(
        self,
        llm: AbstractLLM,
        embedder: AbstractEmbedder,
        driver: BaseDriver,
    ) -> None:
        """Keep references to the collaborators used by concrete graphs."""
        self.llm, self.embedder, self.driver = llm, embedder, driver

    @abstractmethod
    async def add_record(
        self,
        name: str,
        data: str | dict | BaseMessage,
        description: str,
        *,
        uuid: str | None = None,
        subgraph_id: str | None = None,
        **kwargs,
    ) -> Any:
        """Add one record (text, dict or message) to the graph.

        The concrete behaviour and return value are defined by subclasses.
        """
        ...

    @abstractmethod
    async def close(self) -> None:
        """Shut the graph down; the exact cleanup is defined by subclasses."""
        ...
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from tiny_graph.graph.multi_layer_graph.edges import TinyEntityEdge
|
|
2
|
+
from tinygent.core.datamodels.embedder import AbstractEmbedder
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
async def entity_edge_batch_embeddings(
    embedder: AbstractEmbedder, edges: list[TinyEntityEdge]
) -> list[TinyEntityEdge]:
    """Embed every edge's ``fact`` with a single batched embedder call.

    Each edge's ``fact_embedding`` is overwritten in place, and the same list
    object is returned. An empty input returns ``[]`` without calling the
    embedder.

    Raises:
        ValueError: from ``zip(..., strict=True)`` if the embedder returns a
            different number of vectors than there are edges.
    """
    if not edges:
        return []

    facts = [edge.fact for edge in edges]
    vectors = await embedder.aembed_batch(facts)
    for target, vector in zip(edges, vectors, strict=True):
        target.fact_embedding = vector
    return edges
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from tiny_graph.graph.multi_layer_graph.nodes import TinyClusterNode
|
|
2
|
+
from tiny_graph.graph.multi_layer_graph.nodes import TinyEntityNode
|
|
3
|
+
from tinygent.core.datamodels.embedder import AbstractEmbedder
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
async def cluster_node_batch_embeddings(
    embedder: AbstractEmbedder, clusters: list[TinyClusterNode]
) -> list[TinyClusterNode]:
    """Embed every cluster's ``name`` with a single batched embedder call.

    Each cluster's ``name_embedding`` is overwritten in place, and the same
    list object is returned. An empty input returns ``[]`` without calling
    the embedder.

    Raises:
        ValueError: from ``zip(..., strict=True)`` if the embedder returns a
            different number of vectors than there are clusters.
    """
    if not clusters:
        return []

    names = [node.name for node in clusters]
    vectors = await embedder.aembed_batch(names)
    for node, vector in zip(clusters, vectors, strict=True):
        node.name_embedding = vector
    return clusters
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
async def entity_node_batch_embeddings(
    embedder: AbstractEmbedder, entities: list[TinyEntityNode]
) -> list[TinyEntityNode]:
    """Embed every entity's ``name`` with a single batched embedder call.

    Each entity's ``name_embedding`` is overwritten in place, and the same
    list object is returned. An empty input returns ``[]`` without calling
    the embedder.

    Raises:
        ValueError: from ``zip(..., strict=True)`` if the embedder returns a
            different number of vectors than there are entities.
    """
    if not entities:
        return []

    names = [node.name for node in entities]
    vectors = await embedder.aembed_batch(names)
    for node, vector in zip(entities, vectors, strict=True):
        node.name_embedding = vector
    return entities
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
|
|
3
|
+
from tiny_graph.driver.base import BaseDriver
|
|
4
|
+
from tinygent.core.datamodels.cross_encoder import AbstractCrossEncoder
|
|
5
|
+
from tinygent.core.datamodels.embedder import AbstractEmbedder
|
|
6
|
+
from tinygent.core.datamodels.llm import AbstractLLM
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@dataclass
class TinyGraphClients:
    """Groups the driver, LLM, embedder and cross-encoder clients together."""

    driver: BaseDriver
    llm: AbstractLLM
    embedder: AbstractEmbedder
    cross_encoder: AbstractCrossEncoder

    @property
    def safe_embed_model(self) -> str:
        """Return the embedder's ``model`` name with ``-`` replaced by ``_``.

        NOTE(review): presumably used to build identifiers that cannot contain
        hyphens; confirm against callers.
        """
        model_name = self.embedder.model
        return model_name.replace('-', '_')
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
from pydantic import Field
|
|
2
|
+
|
|
3
|
+
from tinygent.core.types.base import TinyModel
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class ExtractedEntity(TinyModel):
    """One extracted entity together with its classified type id.

    Both fields are required.
    """

    name: str = Field(description='Name of the extracted entity')
    entity_type_id: int = Field(
        ...,
        description=(
            'ID of the classified entity type. '
            'Must be one of the provided entity_type_id integers.'
        ),
    )
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class ExtractedEntities(TinyModel):
    """Wrapper holding the (required) list of extracted entities."""

    extracted_entities: list[ExtractedEntity] = Field(
        description='List of extracted entities',
    )
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class MissedEntities(TinyModel):
    """Wrapper holding the (required) names of entities that were not extracted."""

    missed_entities: list[str] = Field(
        description="Names of entities that weren't extracted",
    )
|
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from datetime import datetime
|
|
4
|
+
|
|
5
|
+
from pydantic import Field
|
|
6
|
+
|
|
7
|
+
from tiny_graph.driver.base import BaseDriver
|
|
8
|
+
from tiny_graph.edge import TinyEdge
|
|
9
|
+
from tiny_graph.graph.multi_layer_graph.types import EdgeType
|
|
10
|
+
from tiny_graph.graph.multi_layer_graph.utils.model_repr import compact_model_repr
|
|
11
|
+
from tiny_graph.helper import parse_db_date
|
|
12
|
+
from tinygent.core.datamodels.embedder import AbstractEmbedder
|
|
13
|
+
from tinygent.utils.yaml import json
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class TinyEntityEdge(TinyEdge):
    """Edge carrying a natural-language fact about the nodes it connects.

    Besides the base edge fields it stores the fact text, its embedding, the
    ids of the events that reference it, and validity timestamps.
    """

    fact: str = Field(
        description='fact representing the edge and nodes that it connects'
    )

    fact_embedding: list[float] | None = Field(
        default=None, description='embedding of the fact'
    )

    events: list[str] = Field(
        default_factory=list,
        description='list of episode ids that reference these entity edges',
    )

    expired_at: datetime | None = Field(
        default=None, description='datetime of when the edge was invalidated'
    )

    valid_at: datetime | None = Field(
        default=None, description='datetime of when the fact became true'
    )

    invalid_at: datetime | None = Field(
        default=None, description='datetime of when the fact stopped being true'
    )

    async def save(self, driver: BaseDriver) -> str:
        """Persist this edge with the provider-specific ``create_entity_edge`` query.

        ``attributes`` is serialized to a JSON string before it is sent.

        Returns:
            Whatever ``driver.execute_query`` returns. For ``Neo4jDriver`` this
            is an ``EagerResult``, despite the ``str`` annotation.
        """
        # Imported lazily, as in the original module, presumably to avoid an
        # import cycle with the queries package.
        from tiny_graph.graph.multi_layer_graph.queries.edge_queries import (
            create_entity_edge,
        )

        args = {
            'edge_uuid': self.uuid,
            'subgraph_id': self.subgraph_id,
            'source_node_uuid': self.source_node_uuid,
            'target_node_uuid': self.target_node_uuid,
            'created_at': self.created_at,
            'name': self.name,
            'fact': self.fact,
            'fact_embedding': self.fact_embedding,
            'events': self.events,
            'expired_at': self.expired_at,
            'valid_at': self.valid_at,
            'invalid_at': self.invalid_at,
            'attributes': json.dumps(self.attributes),
        }

        return await driver.execute_query(
            query=create_entity_edge(driver.provider),
            **args,
        )

    async def embed(self, embedder: AbstractEmbedder) -> list[float]:
        """Embed ``fact``, store it on ``fact_embedding`` and return it."""
        self.fact_embedding = await embedder.aembed(self.fact)
        return self.fact_embedding

    @classmethod
    def from_record(cls, record: dict) -> TinyEntityEdge:
        """Build an edge from a database record.

        Missing *or null* ``events``/``attributes`` fall back to empty
        containers; previously a null value was passed through as ``None`` and
        failed validation. ``attributes`` stored as a JSON string is decoded.
        Optional timestamps that are absent or falsy become ``None``.
        """
        attributes = record.get('attributes')
        if isinstance(attributes, str):
            attributes = json.loads(attributes)

        return cls(
            uuid=record['uuid'],
            subgraph_id=record['subgraph_id'],
            source_node_uuid=record['source_node_uuid'],
            target_node_uuid=record['target_node_uuid'],
            created_at=parse_db_date(record['created_at']),
            name=record['name'],
            fact=record['fact'],
            fact_embedding=record.get('fact_embedding'),
            events=record.get('events') or [],
            expired_at=(
                parse_db_date(record['expired_at']) if record.get('expired_at') else None
            ),
            valid_at=(
                parse_db_date(record['valid_at']) if record.get('valid_at') else None
            ),
            invalid_at=(
                parse_db_date(record['invalid_at']) if record.get('invalid_at') else None
            ),
            attributes=attributes or {},
        )

    @classmethod
    async def find_by_targets(
        cls, driver: BaseDriver, source_uuid: str, target_uuid: str
    ) -> list[TinyEntityEdge]:
        """Return all entity edges between ``source_uuid`` and ``target_uuid``."""
        from tiny_graph.graph.multi_layer_graph.queries.edge_queries import (
            find_entity_edge_by_targets,
        )

        query = find_entity_edge_by_targets(driver.provider)

        # The driver result unpacks into three parts; the first holds the records.
        results, _, _ = await driver.execute_query(
            query,
            source_uuid=source_uuid,
            target_uuid=target_uuid,
        )
        return [cls.from_record(r) for r in results]

    def __repr__(self) -> str:
        return compact_model_repr(self)
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
class TinyClusterEdge(TinyEdge):
    """Membership edge from a cluster node (source) to an entity node (target).

    ``name`` defaults to ``EdgeType.HAS_MEMBER`` and is a frozen field.
    """

    name: str = Field(default=EdgeType.HAS_MEMBER.value, frozen=True)

    async def save(self, driver: BaseDriver) -> str:
        """Persist this edge with the provider-specific ``create_cluster_edge`` query.

        Returns:
            Whatever ``driver.execute_query`` returns, despite the ``str``
            annotation.
        """
        from tiny_graph.graph.multi_layer_graph.queries.edge_queries import (
            create_cluster_edge,
        )

        args = {
            'uuid': self.uuid,
            'subgraph_id': self.subgraph_id,
            'created_at': self.created_at,
            # Source/target are mapped to the cluster/entity roles of the query.
            'cluster_node_uuid': self.source_node_uuid,
            'entity_node_uuid': self.target_node_uuid,
        }

        return await driver.execute_query(
            query=create_cluster_edge(driver.provider),
            **args,
        )

    @classmethod
    def from_record(cls, record: dict) -> TinyClusterEdge:
        """Build a cluster edge from a database record.

        Uses ``cls`` so subclasses get instances of their own type.
        """
        return cls(
            uuid=record['uuid'],
            subgraph_id=record['subgraph_id'],
            source_node_uuid=record['source_node_uuid'],
            target_node_uuid=record['target_node_uuid'],
            created_at=parse_db_date(record['created_at']),
            name=record['name'],
        )
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
class TinyEventEdge(TinyEdge):
    """Edge from an event node (source) to an entity node (target) it mentions.

    ``name`` defaults to ``EdgeType.MENTIONS`` and is a frozen field.
    """

    name: str = Field(default=EdgeType.MENTIONS.value, frozen=True)

    async def save(self, driver: BaseDriver) -> str:
        """Persist this edge with the provider-specific ``create_event_edge`` query.

        Returns:
            Whatever ``driver.execute_query`` returns, despite the ``str``
            annotation.
        """
        from tiny_graph.graph.multi_layer_graph.queries.edge_queries import (
            create_event_edge,
        )

        args = {
            'uuid': self.uuid,
            'subgraph_id': self.subgraph_id,
            'created_at': self.created_at,
            # Source/target are mapped to the event/entity roles of the query.
            'event_node_uuid': self.source_node_uuid,
            'entity_node_uuid': self.target_node_uuid,
        }

        return await driver.execute_query(
            query=create_event_edge(driver.provider),
            **args,
        )

    @classmethod
    def from_record(cls, record: dict) -> TinyEventEdge:
        """Build an event edge from a database record.

        Previously this returned a ``TinyClusterEdge``, whose ``save`` would
        write the edge back with the cluster-edge query. It now constructs
        ``cls`` (a ``TinyEventEdge`` or a subclass).
        """
        return cls(
            uuid=record['uuid'],
            subgraph_id=record['subgraph_id'],
            source_node_uuid=record['source_node_uuid'],
            target_node_uuid=record['target_node_uuid'],
            created_at=parse_db_date(record['created_at']),
            name=record['name'],
        )
|