vedana-core 0.1.0.dev3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vedana_core/graph.py ADDED
@@ -0,0 +1,300 @@
1
import abc
import asyncio
import json
import logging
import re
from typing import Any, Dict, Iterable, Set, cast

import aioitertools as aioit
import neo4j
import numpy as np
import typing_extensions as te
from neo4j import AsyncGraphDatabase, EagerResult, RoutingControl
from opentelemetry import trace
13
+
14
# Module-level logger and OpenTelemetry tracer for all graph backends below.
logger = logging.getLogger(__name__)
tracer = trace.get_tracer(__name__)

# Re-exported result-row type; every query helper in this module yields these.
Record = neo4j.Record
18
+
19
+
20
class Graph(abc.ABC):
    """Abstract asynchronous interface over a property-graph store.

    Subclasses implement the query primitives; the base class supplies the
    context-manager plumbing and no-op lifecycle hooks.
    """

    async def add_node(
        self,
        node_id: str,
        labels: Set[str],
        properties: dict[str, Any] | None = None,
        embeddings: dict[str, np.ndarray] | None = None,
    ) -> None:
        """Create or update the node identified by ``node_id``."""
        raise NotImplementedError

    async def add_edge(self, from_id: str, to_id: str, type_: str, attrs: Dict[str, Any] | None) -> None:
        """Create an edge of ``type_`` between two existing nodes."""
        raise NotImplementedError

    async def number_of_nodes(self) -> int:
        """Return the total node count."""
        raise NotImplementedError

    async def number_of_edges(self) -> int:
        """Return the total edge count."""
        raise NotImplementedError

    async def run_cypher(
        self,
        query: str,
        parameters: dict[str, Any] | None = None,
        limit: int | None = None,
    ) -> Iterable[Record]:
        """Execute an arbitrary Cypher query and return its records."""
        raise NotImplementedError

    async def get_existing_node_types(self) -> Iterable[list[str]]:
        """Return the distinct label combinations present in the graph."""
        raise NotImplementedError

    async def llm_schema(self) -> str:
        """Return a textual schema description suitable for LLM prompting."""
        raise NotImplementedError

    async def text_search(self, label: str, query: str, limit: int = 10) -> Iterable[Record]:
        """Full-text search over nodes carrying ``label``."""
        raise NotImplementedError

    async def setup(self, *_, create_basic_indices: bool = True, **kwargs) -> None:
        """Prepare the backend.

        Pass ``create_basic_indices=False`` to skip index creation and
        speed up bulk imports.
        """
        if not create_basic_indices:
            return
        await self.create_basic_indices()

    async def create_basic_indices(self) -> None:
        """Create default indices; no-op unless a backend overrides it."""

    async def execute_ro_cypher_query(
        self,
        query: str,
        parameters: dict[str, Any] | None = None,
        limit: int | None = None,
    ) -> Iterable[Record]:
        """Read-only query entry point; defaults to plain ``run_cypher``."""
        return await self.run_cypher(query, parameters, limit=limit)

    async def clear(self) -> None:
        """Remove all data; no-op unless a backend overrides it."""

    def close(self) -> None:
        """Release backend resources; no-op unless a backend overrides it."""

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Best-effort cleanup; exceptions propagate (we return None).
        self.close()
83
+
84
+
85
class CypherGraph(Graph):
    """Partial :class:`Graph` implementation that generates Cypher for writes.

    Subclasses only need to provide ``run_cypher`` /
    ``execute_ro_cypher_query``.
    """

    # Property/attribute keys are spliced into the query text as `$key`
    # parameter references, so they must be plain identifiers — this both
    # keeps the generated Cypher valid and prevents Cypher injection via
    # attacker-controlled keys (the original code left this as a TODO).
    _key_re = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")

    @classmethod
    def _check_keys(cls, keys: Iterable[str]) -> None:
        """Raise ``ValueError`` for any key unsafe to splice into Cypher text."""
        for key in keys:
            if not cls._key_re.match(key):
                raise ValueError(f"Invalid property key: {key!r}")

    async def add_node(
        self,
        node_id: str,
        labels: Set[str],
        properties: dict[str, Any] | None = None,
        embeddings: dict[str, np.ndarray] | None = None,
    ) -> None:
        """MERGE a node by id and replace its full property map."""
        query, params = self._add_node_cypher(node_id, labels, properties or {})
        await self.run_cypher(query, params)

    def _add_node_cypher(
        self,
        node_id: str,
        labels: Set[str],
        properties: dict[str, Any],
    ) -> tuple[str, dict[str, Any]]:
        """Build the MERGE query and its parameter map for one node.

        ``SET n = {...}`` replaces all existing properties; ``id`` always
        wins over a same-named key in ``properties``.
        """
        labels_expr = escape_labels(labels)
        props = {
            **properties,
            "id": node_id,
        }
        self._check_keys(props.keys())
        props_expr = ", ".join(f"{k}: ${k}" for k in props)
        return (
            f"MERGE (n:{labels_expr} {{id: $id}}) SET n = {{{props_expr}}} RETURN n",
            props,
        )

    async def add_edge(self, from_id: str, to_id: str, type_: str, attrs: Dict[str, Any] | None) -> None:
        """CREATE a typed edge between the nodes with ids ``from_id``/``to_id``."""
        query, params = self._add_edge_cypher(from_id, to_id, type_, attrs)
        await self.run_cypher(query, params)

    def _add_edge_cypher(
        self, from_id: str, to_id: str, type_: str, attrs: Dict[str, Any] | None
    ) -> tuple[str, dict[str, Any]]:
        """Build the CREATE query and its parameter map for one edge.

        Attribute parameters are namespaced with an ``a_`` prefix so an
        attribute named ``from_id``/``to_id`` cannot clobber the endpoint-id
        parameters (the old code silently overwrote them).
        """
        attrs = attrs or {}
        # Empty-string keys are skipped, matching the original behavior.
        attr_items = {k: v for k, v in attrs.items() if k}
        self._check_keys(attr_items.keys())
        labels_expr = escape_labels({type_})
        attrs_expr = ", ".join(f"{k}: $a_{k}" for k in attr_items)
        params = {
            **{f"a_{k}": v for k, v in attr_items.items()},
            "from_id": from_id,
            "to_id": to_id,
        }
        return (
            "MATCH (nf {id: $from_id}), (nt {id: $to_id}) "
            f"CREATE (nf)-[r:{labels_expr} {{{attrs_expr}}}]->(nt) RETURN r",
            params,
        )

    async def add_edges(self, edges: Iterable[tuple[str, str, dict]], **common_attrs) -> None:
        """Add many edges; ``common_attrs`` merge under each edge's own attrs.

        The edge type is taken from the first entry of the ``__labels__``
        attribute, defaulting to ``"no_type"``.
        """
        for from_id, to_id, edge_attrs in edges:
            merged = {**common_attrs, **edge_attrs}
            labels: Iterable[str] = merged.pop("__labels__", [])
            type_ = next(iter(labels), "no_type")
            await self.add_edge(from_id, to_id, type_, merged)

    async def number_of_nodes(self) -> int:
        """Count all nodes via a read-only query."""
        res = await self.execute_ro_cypher_query("MATCH (n) RETURN count(*) as cnt")
        return next(iter(res))["cnt"]

    async def number_of_edges(self) -> int:
        """Count all (directed) relationships via a read-only query."""
        res = await self.execute_ro_cypher_query("MATCH (f)-[]->(t) RETURN count(*) as cnt")
        return next(iter(res))["cnt"]

    async def get_existing_node_types(self) -> Iterable[list[str]]:
        """Return the distinct label lists present among stored nodes."""
        res = await self.execute_ro_cypher_query("MATCH (n) RETURN DISTINCT labels(n) as l;")
        return [r["l"] for r in res]
155
+
156
+
157
+ # class NXGraph(Graph):
158
+ # def __init__(self, graph: nx.Graph) -> None:
159
+ # self.graph: nx.Graph = graph
160
+ # self.gcypher = GrandCypher(self.graph)
161
+
162
+ # def execute_ro_cypher_query(self, query: str) -> Iterable[Any]:
163
+ # return self.gcypher.run(query)
164
+
165
+ # def add_node(self, node_id: str, labels: Set[str], **attributes) -> None:
166
+ # self.graph.add_node(node_id, __labels__=labels, **attributes)
167
+
168
+ # def number_of_edges(self) -> int:
169
+ # return self.graph.number_of_edges()
170
+
171
+ # def clear(self) -> None:
172
+ # self.graph.clear()
173
+
174
+
175
class MemgraphGraph(CypherGraph):
    """Memgraph-backed graph using the async Neo4j/Bolt driver."""

    def __init__(self, uri: str, user: str, pwd: str, db_name: str = "") -> None:
        self.driver = AsyncGraphDatabase.driver(uri, auth=(user, pwd), database=db_name)
        # NOTE(review): connectivity is not verified here (verify_connectivity
        # is a coroutine); driver errors surface on the first query.
        self.driver_uri = uri
        self.auth = (user, pwd)

    async def execute_ro_cypher_query(
        self, query: str, parameters: dict[str, Any] | None = None, limit: int | None = None
    ) -> Iterable[Record]:
        """Run a read-routed query eagerly and return all records.

        ``limit`` is accepted for interface compatibility but not applied
        here; callers should put LIMIT in the query itself.
        """
        with tracer.start_as_current_span("memgraph.execute_ro_cypher_query") as span:
            span.set_attribute("memgraph.query", query)
            if parameters:
                span.set_attribute("memgraph.parameters", json.dumps(parameters))
            result: EagerResult = await self.driver.execute_query(query, parameters, routing_=RoutingControl.READ)

            return result.records

    async def run_cypher(
        self,
        query: str,
        parameters: dict[str, Any] | None = None,
        limit: int | None = None,
    ) -> Iterable[Record]:
        """Run an arbitrary query in a fresh session; ``limit`` truncates client-side."""
        with tracer.start_as_current_span("memgraph.run_cypher") as span:
            span.set_attribute("memgraph.query", query)
            if parameters:
                span.set_attribute("memgraph.parameters", json.dumps(parameters))
            if limit is not None:
                span.set_attribute("memgraph.limit", limit)

            async with self.driver.session() as session:
                # islice(…, None) yields everything, so limit=None means "no cap".
                result = await aioit.list(aioit.islice(await session.run(query, parameters), limit))

            return result

    async def add_node(
        self,
        node_id: str,
        labels: Set[str],
        properties: Dict[str, Any] | None = None,
        embeddings: Dict[str, np.ndarray] | None = None,
    ) -> None:
        """Add a node, storing each embedding as a ``<prop>_embedding`` property.

        Fix: embeddings are merged even when no plain properties were given —
        the old ``if properties and embeddings`` guard silently dropped the
        embeddings for property-less nodes.
        """
        if embeddings:
            # NOTE(review): vectors are passed through as-is; confirm the
            # driver accepts the value type (list vs np.ndarray).
            embed_props = {f"{prop_name}_embedding": v for prop_name, v in embeddings.items()}
            properties = {
                **(properties or {}),
                **embed_props,
            }
        await super().add_node(node_id, labels, properties, embeddings)

    async def llm_schema(self) -> str:
        """Return Memgraph's llm_util schema text; usable as a fallback data model."""
        res = await self.driver.execute_query("CALL llm_util.schema() YIELD schema RETURN schema")
        return res.records[0]["schema"]

    async def create_basic_indices(self, node_types=None) -> None:
        """Create an index + uniqueness constraint on ``id`` for each label set."""
        if not node_types:
            node_types = await self.get_existing_node_types()
        for label in node_types:
            await self.create_node_prop_index(set(label), "id", unique=True)

    async def clear(self) -> None:
        """Drop vector/text indices, constraints, and all graph data."""
        async with self.driver.session() as session:
            res = await session.run("CALL vector_search.show_index_info() YIELD index_name RETURN *")
            # Materialize index names before issuing DROPs: running new
            # queries on the session invalidates the open result stream.
            vector_idx_names = [name async for (name,) in res]
            for idx_name in vector_idx_names:
                await session.run(f"DROP VECTOR INDEX {escape_cypher(idx_name)}")

            # Text index names are embedded in the "index type" column,
            # e.g. "text (name: my_idx)".
            idx_name_re = re.compile(r"\(name:\s(.+?)\)")
            info_rows = [row async for row in await session.run(cast(te.LiteralString, "SHOW INDEX INFO"))]
            for row in info_rows:
                index_type = row["index type"]
                idx_name = next(iter(idx_name_re.findall(index_type)), None)
                if not idx_name:
                    continue
                await session.run(f"DROP TEXT INDEX {escape_cypher(idx_name)}")

            # Drop constraints and remaining label indices, then all data.
            await session.run("CALL schema.assert({}, {}, {}, true) YIELD action, key, keys, label, unique")
            await session.run("MATCH (n) DETACH DELETE n")
            # TODO more efficient:
            # USING PERIODIC COMMIT num_rows
            # MATCH (n)-[r]->(m)
            # DELETE r;
            # USING PERIODIC COMMIT num_rows
            # MATCH (n)
            # DETACH DELETE n;

    async def text_search(self, label: str, query: str, limit: int = 10) -> Iterable[Record]:
        """Full-text search via Memgraph's ``text_search`` module."""
        with tracer.start_as_current_span("memgraph.text_search") as span:
            span.set_attribute("memgraph.label", label)
            span.set_attribute("memgraph.fts_query", query)
            span.set_attribute("memgraph.limit", limit)

            # Keep the Cypher text in its own variable: the old code rebound
            # `query`, so the string sent to the index was the Cypher
            # statement itself instead of the caller's search query.
            cypher = "CALL text_search.search_all($idx_name, $query) YIELD node RETURN node LIMIT $limit"
            span.set_attribute("memgraph.query", cypher)

            res = await self.driver.execute_query(
                cypher,
                idx_name=self._fts_idx_name(label),
                query=query,
                limit=limit,
                routing_=RoutingControl.READ,
            )
            return res.records

    async def create_node_prop_index(self, labels: set[str], property: str, unique: bool = False) -> None:
        """Create a label/property index, plus a uniqueness constraint if asked."""
        escaped_label = escape_labels(labels)
        escaped_prop = escape_cypher(property)
        await self.run_cypher(f"CREATE INDEX ON :{escaped_label}({escaped_prop})")
        if not unique:
            return
        await self.run_cypher(f"CREATE CONSTRAINT ON (n:{escaped_label})\nASSERT n.{escaped_prop} IS UNIQUE")

    @staticmethod
    def _fts_idx_name(label: str) -> str:
        """Conventional full-text index name for a label."""
        return f"{label.lower()}_fts_idx"

    def close(self):
        """Close the driver.

        Fix: ``AsyncDriver.close()`` is a coroutine; the old code called it
        without awaiting, so the connection pool was never closed.
        """
        coro = self.driver.close()
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No running event loop: drive the coroutine to completion here.
            asyncio.run(coro)
        else:
            # Called from inside a loop: we cannot block, so schedule it.
            loop.create_task(coro)
292
+
293
+
294
def escape_cypher(identifier: str) -> str:
    """Escape an identifier for safe interpolation into Cypher text.

    The literal six-character escape sequence ``\\u0060`` is normalized to a
    real backtick first — otherwise it could smuggle an unescaped backtick
    past the doubling step — then every backtick is doubled and the result is
    wrapped in backticks.
    """
    # Fix: the original wrote .replace("\u0060", "`"), but "\u0060" in Python
    # source IS the backtick character, making that replace a no-op. The
    # intended target is the literal backslash-u-0060 sequence.
    identifier = identifier.replace("\\u0060", "`").replace("`", "``")
    return f"`{identifier}`"


def escape_labels(labels: set[str]) -> str:
    """Render a set of labels as a colon-joined, escaped Cypher label expression."""
    return ":".join(escape_cypher(label) for label in labels)
vedana_core/llm.py ADDED
@@ -0,0 +1,192 @@
1
+ import asyncio
2
+ import logging
3
+ from typing import Awaitable, Callable, Iterable
4
+
5
+ import openai
6
+ from jims_core.llms.llm_provider import LLMProvider
7
+ from jims_core.thread.schema import CommunicationEvent
8
+ from openai.types.chat import (
9
+ ChatCompletionMessageParam,
10
+ ChatCompletionToolMessageParam,
11
+ )
12
+ from pydantic import BaseModel
13
+
14
# Module-level logger; LLM instances may override it via their ``logger`` arg.
logger = logging.getLogger(__name__)
15
+
16
+
17
+ class Tool[T: BaseModel]:
18
+ def __init__(
19
+ self, name: str, description: str, args_cls: type[T], fn: Callable[[T], Awaitable[str]] | Callable[[T], str]
20
+ ) -> None:
21
+ self.name = name
22
+ self.description = description
23
+ self.args_cls = args_cls
24
+ self.fn = fn
25
+ self.openai_def = openai.pydantic_function_tool(args_cls, name=name, description=description)
26
+
27
+ async def call(self, args_json: str) -> str:
28
+ try:
29
+ fn_args = self.args_cls.model_validate_json(args_json)
30
+ except ValueError:
31
+ return f"Invalid tool args: {args_json}"
32
+
33
+ if asyncio.iscoroutinefunction(self.fn):
34
+ result = await self.fn(fn_args)
35
+ else:
36
+ result: str = await asyncio.to_thread(self.fn, fn_args) # type: ignore
37
+
38
+ return result
39
+
40
+
41
class LLM:
    """Orchestration layer over an ``LLMProvider``: prompt assembly plus an
    agentic tool-call loop that answers user questions over a graph."""

    def __init__(
        self,
        llm_provider: LLMProvider,
        prompt_templates: dict[str, str],
        logger: logging.Logger | None = None,
    ) -> None:
        # prompt_templates may override the module-level default templates
        # by key (e.g. "finalize_answer_tmplt").
        self.logger = logger or logging.getLogger(__name__)
        self.llm = llm_provider
        self.prompt_templates = prompt_templates

    # Current primary entry point for answering with tool access.
    async def generate_cypher_query_with_tools(
        self,
        data_descr: str,
        messages: Iterable,
        tools: list[Tool],
    ) -> tuple[list[ChatCompletionMessageParam], str]:
        """Build the Cypher-with-tools dialog and run the tool-call loop.

        Returns the full message history and the final assistant answer text.
        """
        tool_names = [t.name for t in tools]
        msgs = make_cypher_query_with_tools_dialog(data_descr, self.prompt_templates, messages, tool_names=tool_names)
        return await self.create_completion_with_tools(msgs, tools=tools)

    async def create_completion_with_tools(
        self,
        messages: list[ChatCompletionMessageParam],
        tools: Iterable[Tool],
    ) -> tuple[list[ChatCompletionMessageParam], str]:
        """Drive an LLM tool-call loop until the model stops requesting tools.

        Runs at most 5 iterations; on the last one a "finalize" system prompt
        is appended to force a plain-text answer. Returns the accumulated
        message history and the last non-empty assistant answer (or "").
        """
        # Copy so the caller's list is not mutated.
        messages = messages.copy()
        tool_defs = [tool.openai_def for tool in tools]
        tools_map = {tool.name: tool for tool in tools}

        async def _execute_tool_call(tool_call):
            # Returns (tool_call_id, result-or-error string); never raises,
            # so one failing tool does not abort the parallel gather.
            tool_name = tool_call.function.name
            tool = tools_map.get(tool_name)
            if not tool:
                self.logger.error(f"Tool {tool_name} not found!")
                return tool_call.id, f"Tool {tool_name} not found!"

            self.logger.debug(f"Calling tool {tool_name}")
            try:
                tool_res = await tool.call(tool_call.function.arguments)
            except Exception as e:
                self.logger.exception("Error executing tool %s: %s", tool_name, e)
                tool_res = f"Error executing tool {tool_name}: {e}"

            self.logger.debug("Tool %s (%s) result: %s", tool_name, tool.description, tool_res)
            return tool_call.id, tool_res

        max_iters = 5
        for i in range(max_iters):
            msg, tool_calls = await self.llm.chat_completion_with_tools(
                messages=messages,
                tools=tool_defs,
            )

            messages.append(msg.to_dict())  # type: ignore

            if not tool_calls:
                self.logger.debug("No tool calls found. Exiting tool call loop")
                break

            self.logger.debug(f"Tool call iter {i + 1}/{max_iters}")

            # Execute tool calls in parallel
            results = await asyncio.gather(*[_execute_tool_call(t) for t in tool_calls])

            # Feed each tool result back as a role="tool" message.
            for tool_call_id, tool_res in results:
                messages.append(
                    ChatCompletionToolMessageParam(role="tool", tool_call_id=tool_call_id, content=tool_res)
                )

            if i == max_iters - 1:
                self.logger.warning(f"Reached tool call iteration limit ({max_iters}). Exiting tool call loop")
                # Ask for a final plain answer built from the gathered tool results.
                finalize_prompt = self.prompt_templates.get("finalize_answer_tmplt", finalize_answer_tmplt)
                finalize_msg = {"role": "system", "content": finalize_prompt}
                final_msg = await self.llm.chat_completion_plain(messages + [finalize_msg])
                messages.append(final_msg.to_dict())  # type: ignore
                break

        for last_msg in reversed(messages):  # sometimes message with final answer is not the last one
            if last_msg.get("role", "") == "assistant" and last_msg.get("content"):
                return messages, str(last_msg.get("content"))
        return messages, ""

    async def generate_no_answer(
        self,
        dialog: list[CommunicationEvent] | None = None,
    ) -> str:
        """Generate a human-friendly "no answer found" reply for the dialog."""
        prompt = self.prompt_templates.get("generate_no_answer_tmplt", generate_no_answer_tmplt)
        messages = [
            {"role": "system", "content": prompt},
            *(dialog or []),
        ]
        response = await self.llm.chat_completion_plain(messages)
        human_answer = "" if response.content is None else response.content.strip()
        self.logger.debug(f"Generated 'no answer' response: {human_answer}")
        return human_answer
138
+
139
+
140
# Default prompt templates (Russian). Each can be overridden per-instance via
# the ``prompt_templates`` dict passed to ``LLM`` under the same key.

# Appended on the last tool-call iteration to force a final user-facing answer.
# Fix: "пользователя основе" -> "пользователя на основе" (missing preposition).
finalize_answer_tmplt = """\
Сформулируй ответ на запрос пользователя на основе информации, полученной в результате вызова результатов инструментов.
Если информации недостаточно для точного ответа, ясно опиши ограничения и предложи 1–2 уточняющих вопроса.
Важно! Не упоминай инструменты в явном виде, ссылайся только на данные.
"""

# Used when the knowledge base yielded no answer at all.
generate_no_answer_tmplt = """\
Ты - помощник, который преобразует технические ответы в понятный человеку текст.
Мы не смогли найти ответ на вопрос пользователя в базе знаний.
Сформулируй ответ, сообщающий кратко и информативно, что ответа не найдено.
Предложи пару вариантов уточняющих вопросов на основе информации в контексте. Предложи в casual стиле.
"""

# System prompt for the Cypher-generation agent. Formatted with
# {tools} (comma-separated tool names, used twice) and {graph_description};
# it must contain no other brace pairs or str.format would raise.
generate_answer_with_tools_tmplt = """\
Ты — помощник по работе с графовыми базами данных, в которых используется язык запросов Cypher

Цель: постараться найти ответ на вопрос пользователя используя инструменты для работы с БД на основе текстового описания графовой базы данных.

На вход ты получаешь graph_composition: – описание графа и примеры запросов по нему, и user_query – пользовательский запрос.

**Что нужно сделать:**
1. Сгенерировать `Cypher`-запросы, используя узлы, атрибуты и связи перечисленные в **graph_composition**.
2. Руководствуйся данными в **graph_composition** примерами запросов, чтобы составить итоговый запрос.
3. Используй инструменты {tools} для выполнения запросов и поиска

Если нужно, используй несколько `MATCH`-блоков, например:
MATCH (o:offer)-[:OFFER_belongs_to_CATEGORY]->(c:category)
MATCH (o)-[:OFFER_made_of_MATERIAL]->(m:material)
WHERE c.category_name = "Встраиваемый светильник" AND m.material_name IN ["Стекло", "Металл и Стекло", "Алюминий и стекло"]
RETURN o

Теперь проанализируй следующую структуру графа, и постарайся найти ответ на вопрос используя инструменты {tools}. (Лучше использовать несколько инструментов)

**graph_composition**
{graph_description}
"""
176
+
177
+
178
def make_cypher_query_with_tools_dialog(
    graph_description: str,
    prompt_templates: dict[str, str],
    messages: Iterable,
    tool_names: list[str],
) -> list[ChatCompletionMessageParam]:
    """Prepend the tool-enabled Cypher system prompt to ``messages``.

    The template can be overridden via the "generate_answer_with_tools_tmplt"
    key of ``prompt_templates``.
    """
    template = prompt_templates.get("generate_answer_with_tools_tmplt", generate_answer_with_tools_tmplt)
    system_prompt = template.format(graph_description=graph_description, tools=", ".join(tool_names))
    system_msg = {
        "role": "system",
        "content": system_prompt,
    }
    return [system_msg, *messages]
vedana_core/py.typed ADDED
File without changes