vedana-core 0.1.0.dev3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vedana_core/__init__.py +0 -0
- vedana_core/app.py +78 -0
- vedana_core/data_model.py +465 -0
- vedana_core/data_provider.py +513 -0
- vedana_core/db.py +41 -0
- vedana_core/graph.py +300 -0
- vedana_core/llm.py +192 -0
- vedana_core/py.typed +0 -0
- vedana_core/rag_agent.py +234 -0
- vedana_core/rag_pipeline.py +326 -0
- vedana_core/settings.py +35 -0
- vedana_core/start_pipeline.py +17 -0
- vedana_core/utils.py +31 -0
- vedana_core/vts.py +167 -0
- vedana_core-0.1.0.dev3.dist-info/METADATA +29 -0
- vedana_core-0.1.0.dev3.dist-info/RECORD +17 -0
- vedana_core-0.1.0.dev3.dist-info/WHEEL +4 -0
vedana_core/graph.py
ADDED
|
@@ -0,0 +1,300 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
import json
|
|
3
|
+
import logging
|
|
4
|
+
import re
|
|
5
|
+
from typing import Any, Dict, Iterable, Set, cast
|
|
6
|
+
|
|
7
|
+
import aioitertools as aioit
|
|
8
|
+
import neo4j
|
|
9
|
+
import numpy as np
|
|
10
|
+
import typing_extensions as te
|
|
11
|
+
from neo4j import AsyncGraphDatabase, EagerResult, RoutingControl
|
|
12
|
+
from opentelemetry import trace
|
|
13
|
+
|
|
14
|
+
# OpenTelemetry tracer used for the "memgraph.*" spans emitted by MemgraphGraph.
tracer = trace.get_tracer(__name__)
# Module logger.
logger = logging.getLogger(__name__)

# Alias of the neo4j driver's record type, used in the query-method annotations.
Record = neo4j.Record
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class Graph(abc.ABC):
    """Async interface to a property-graph store.

    Backends must override the coroutine methods that raise ``NotImplementedError``.
    ``create_basic_indices``, ``clear`` and ``close`` are optional hooks that do
    nothing by default. Using an instance in a ``with`` block calls ``close()``
    on exit.
    """

    async def add_node(
        self,
        node_id: str,
        labels: Set[str],
        properties: dict[str, Any] | None = None,
        embeddings: dict[str, np.ndarray] | None = None,
    ) -> None:
        """Insert or update the node ``node_id`` carrying ``labels``."""
        raise NotImplementedError

    async def add_edge(self, from_id: str, to_id: str, type_: str, attrs: Dict[str, Any] | None) -> None:
        """Create an edge of type ``type_`` from node ``from_id`` to node ``to_id``."""
        raise NotImplementedError

    async def number_of_nodes(self) -> int:
        """Return how many nodes the store holds."""
        raise NotImplementedError

    async def number_of_edges(self) -> int:
        """Return how many edges the store holds."""
        raise NotImplementedError

    async def run_cypher(
        self,
        query: str,
        parameters: dict[str, Any] | None = None,
        limit: int | None = None,
    ) -> Iterable[Record]:
        """Execute a Cypher ``query`` with ``parameters``; ``limit`` is backend-interpreted."""
        raise NotImplementedError

    async def get_existing_node_types(self) -> Iterable[list[str]]:
        """Return the distinct label lists present on nodes."""
        raise NotImplementedError

    async def llm_schema(self) -> str:
        """Return a textual schema description suitable for an LLM prompt."""
        raise NotImplementedError

    async def text_search(self, label: str, query: str, limit: int = 10) -> Iterable[Record]:
        """Full-text search ``query`` among nodes of ``label``, up to ``limit`` hits."""
        raise NotImplementedError

    async def setup(self, *_, create_basic_indices: bool = True, **kwargs) -> None:
        """Prepare the store for use.

        Pass ``create_basic_indices=False`` to skip index creation (faster bulk import).
        Extra positional and keyword arguments are accepted and ignored.
        """
        if not create_basic_indices:
            return
        await self.create_basic_indices()

    async def create_basic_indices(self) -> None:
        """Optional hook for creating backend indices; does nothing by default."""
        return None

    async def execute_ro_cypher_query(
        self,
        query: str,
        parameters: dict[str, Any] | None = None,
        limit: int | None = None,
    ) -> Iterable[Record]:
        """Run a query intended to be read-only.

        The default simply delegates to ``run_cypher``; backends may override it,
        e.g. to route the query to a read replica.
        """
        records = await self.run_cypher(query, parameters, limit=limit)
        return records

    async def clear(self) -> None:
        """Optional hook for wiping the store; does nothing by default."""
        return None

    def close(self) -> None:
        """Optional hook for releasing resources; does nothing by default."""
        return None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Exceptions are never suppressed (implicit None return).
        self.close()
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
class CypherGraph(Graph):
    """``Graph`` implemented with Cypher queries; subclasses provide ``run_cypher``.

    Mutations are built as parameterized queries. Labels, relationship types and
    property names are backtick-escaped, and every value is bound as a query
    parameter rather than spliced into the query text.
    """

    async def add_node(
        self,
        node_id: str,
        labels: Set[str],
        properties: dict[str, Any] | None = None,
        embeddings: dict[str, np.ndarray] | None = None,
    ) -> None:
        """Upsert node ``node_id`` (matched on its ``id`` property) and replace its properties.

        ``embeddings`` is not used by this implementation.
        """
        query, params = self._add_node_cypher(node_id, labels, properties or {})
        await self.run_cypher(query, params)

    @staticmethod
    def _props_map_cypher(props: dict[str, Any]) -> tuple[str, dict[str, Any]]:
        """Render ``props`` as a Cypher map literal plus the parameters it references.

        Keys are escaped with ``escape_cypher`` and values are referenced as
        ``$p0``, ``$p1``, ... so property names can neither inject Cypher nor clash
        with the structural parameters (``id``, ``from_id``, ``to_id``).
        """
        params = {f"p{i}": value for i, value in enumerate(props.values())}
        entries = ", ".join(f"{escape_cypher(str(key))}: $p{i}" for i, key in enumerate(props))
        return "{" + entries + "}", params

    def _add_node_cypher(
        self,
        node_id: str,
        labels: Set[str],
        properties: dict[str, Any],
    ) -> tuple[str, dict[str, Any]]:
        """Build the MERGE-by-``id`` query that overwrites the node's whole property map.

        Returns ``(query, parameters)``.
        """
        labels_expr = escape_labels(labels)
        # ``id`` always mirrors ``node_id`` and overrides any ``id`` in ``properties``.
        props_expr, params = self._props_map_cypher({**properties, "id": node_id})
        return (
            f"MERGE (n:{labels_expr} {{id: $id}}) SET n = {props_expr} RETURN n",
            {**params, "id": node_id},
        )

    async def add_edge(self, from_id: str, to_id: str, type_: str, attrs: Dict[str, Any] | None) -> None:
        """Create a ``type_`` edge from the node with id ``from_id`` to the one with id ``to_id``."""
        query, params = self._add_edge_cypher(from_id, to_id, type_, attrs)
        await self.run_cypher(query, params)

    def _add_edge_cypher(
        self, from_id: str, to_id: str, type_: str, attrs: Dict[str, Any] | None
    ) -> tuple[str, dict[str, Any]]:
        """Build the CREATE query for one edge; returns ``(query, parameters)``.

        Attributes with empty names are skipped.
        """
        labels_expr = escape_labels({type_})
        attrs_expr, params = self._props_map_cypher({k: v for k, v in (attrs or {}).items() if k})
        return (
            # Endpoints are matched on ``id`` without labels.
            # NOTE(review): this likely cannot use a label+property index.
            "MATCH (nf {id: $from_id}), (nt {id: $to_id}) "
            f"CREATE (nf)-[r:{labels_expr} {attrs_expr}]->(nt) RETURN r",
            {**params, "from_id": from_id, "to_id": to_id},
        )

    async def add_edges(self, edges: Iterable[tuple[str, str, dict]], **common_attrs) -> None:
        """Add ``(from_id, to_id, attrs)`` edges one by one.

        ``common_attrs`` are merged under each edge's own ``attrs``. The edge type is
        the first element of the ``"__labels__"`` attribute (removed from the stored
        attributes), or ``"no_type"`` when it is absent or empty.
        """
        for edge_tuple in edges:
            from_id, to_id, attrs = edge_tuple
            # A fresh dict, so popping "__labels__" never mutates the caller's data.
            attrs = {**common_attrs, **attrs}
            labels: Iterable[str] = attrs.pop("__labels__", [])
            type_ = next(iter(labels), "no_type")
            await self.add_edge(from_id, to_id, type_, attrs)

    async def number_of_nodes(self) -> int:
        """Return the total node count."""
        res = await self.execute_ro_cypher_query("MATCH (n) RETURN count(*) as cnt")
        return next(iter(res))["cnt"]

    async def number_of_edges(self) -> int:
        """Return the total count of directed edges."""
        res = await self.execute_ro_cypher_query("MATCH (f)-[]->(t) RETURN count(*) as cnt")
        return next(iter(res))["cnt"]

    async def get_existing_node_types(self) -> Iterable[list[str]]:
        """Return each distinct label list found on nodes."""
        res = await self.execute_ro_cypher_query("MATCH (n) RETURN DISTINCT labels(n) as l;")
        return [r["l"] for r in res]
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
# class NXGraph(Graph):
|
|
158
|
+
# def __init__(self, graph: nx.Graph) -> None:
|
|
159
|
+
# self.graph: nx.Graph = graph
|
|
160
|
+
# self.gcypher = GrandCypher(self.graph)
|
|
161
|
+
|
|
162
|
+
# def execute_ro_cypher_query(self, query: str) -> Iterable[Any]:
|
|
163
|
+
# return self.gcypher.run(query)
|
|
164
|
+
|
|
165
|
+
# def add_node(self, node_id: str, labels: Set[str], **attributes) -> None:
|
|
166
|
+
# self.graph.add_node(node_id, __labels__=labels, **attributes)
|
|
167
|
+
|
|
168
|
+
# def number_of_edges(self) -> int:
|
|
169
|
+
# return self.graph.number_of_edges()
|
|
170
|
+
|
|
171
|
+
# def clear(self) -> None:
|
|
172
|
+
# self.graph.clear()
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
class MemgraphGraph(CypherGraph):
    """``CypherGraph`` backed by a Memgraph server through the async neo4j Bolt driver.

    Queries and their parameters are recorded on OpenTelemetry spans named ``memgraph.*``.
    Use ``async with`` (or ``await graph.aclose()``) to release the driver from async code.
    """

    def __init__(self, uri: str, user: str, pwd: str, db_name: str = "") -> None:
        self.driver = AsyncGraphDatabase.driver(uri, auth=(user, pwd), database=db_name)
        # Connectivity is not verified here: ``verify_connectivity`` would need to be awaited.
        self.driver_uri = uri
        self.auth = (user, pwd)

    @staticmethod
    def _span_params(parameters: dict[str, Any]) -> str:
        """Serialize query parameters for a span attribute.

        Values JSON cannot encode (e.g. the numpy arrays ``add_node`` stores as
        ``*_embedding`` properties) are rendered with ``str``, so tracing never
        makes the query itself fail.
        """
        return json.dumps(parameters, default=str, ensure_ascii=False)

    async def execute_ro_cypher_query(
        self, query: str, parameters: dict[str, Any] | None = None, limit: int | None = None
    ) -> Iterable[Record]:
        """Run ``query`` with READ routing and return at most ``limit`` records (all if ``None``)."""
        with tracer.start_as_current_span("memgraph.execute_ro_cypher_query") as span:
            span.set_attribute("memgraph.query", query)
            if parameters:
                span.set_attribute("memgraph.parameters", self._span_params(parameters))
            if limit is not None:
                span.set_attribute("memgraph.limit", limit)
            result: EagerResult = await self.driver.execute_query(query, parameters, routing_=RoutingControl.READ)

        records = result.records
        # ``execute_query`` fetches eagerly, so ``limit`` is applied on the client side.
        return records if limit is None else records[: max(limit, 0)]

    async def run_cypher(
        self,
        query: str,
        parameters: dict[str, Any] | None = None,
        limit: int | None = None,
    ) -> Iterable[Record]:
        """Run ``query`` in a fresh session and return up to ``limit`` records (all if ``None``)."""
        with tracer.start_as_current_span("memgraph.run_cypher") as span:
            span.set_attribute("memgraph.query", query)
            if parameters:
                span.set_attribute("memgraph.parameters", self._span_params(parameters))
            if limit is not None:
                span.set_attribute("memgraph.limit", limit)

            async with self.driver.session() as session:
                result = await aioit.list(aioit.islice(await session.run(query, parameters), limit))

            return result

    async def add_node(
        self,
        node_id: str,
        labels: Set[str],
        properties: Dict[str, Any] | None = None,
        embeddings: Dict[str, np.ndarray] | None = None,
    ) -> None:
        """Upsert a node, storing each embedding as a ``<name>_embedding`` property."""
        if embeddings:
            # Embeddings are kept even when the node has no regular properties.
            embed_props = {f"{prop_name}_embedding": v for prop_name, v in embeddings.items()}
            properties = {**(properties or {}), **embed_props}
        await super().add_node(node_id, labels, properties, embeddings)

    async def llm_schema(self) -> str:
        """Return the schema text produced by Memgraph's ``llm_util.schema()`` procedure.

        Can be used as a fallback description of the data model.
        """
        res = await self.driver.execute_query("CALL llm_util.schema() YIELD schema RETURN schema")
        return res.records[0]["schema"]

    async def create_basic_indices(self, node_types=None) -> None:
        """Create a unique ``id`` index/constraint for each node label set.

        ``node_types`` is an iterable of label lists (as returned by
        ``get_existing_node_types``, which is used when it is empty). A plain
        string entry is treated as a single label.
        """
        if not node_types:
            node_types = await self.get_existing_node_types()
        for label in node_types:
            labels = {label} if isinstance(label, str) else set(label)
            if not labels:
                # Unlabelled nodes: there is no label to build an index on.
                continue
            # NOTE(review): multi-label sets render as ``:A:B``; confirm Memgraph accepts that here.
            await self.create_node_prop_index(labels, "id", unique=True)

    async def clear(self) -> None:
        """Drop all vector and text indices, all indices/constraints, and all nodes and edges."""
        async with self.driver.session() as session:
            # Collect index names first, so no DROP is issued while a listing result is still open.
            res = await session.run("CALL vector_search.show_index_info() YIELD index_name RETURN *")
            vector_idx_names = [idx_name async for (idx_name,) in res]
            for idx_name in vector_idx_names:
                await (await session.run(f"DROP VECTOR INDEX {escape_cypher(idx_name)}")).consume()

            idx_name_re = re.compile(r"\(name:\s(.+?)\)")
            text_idx_names = []
            async for row in await session.run(cast(te.LiteralString, "SHOW INDEX INFO")):
                # Text indices report their name inside "index type", e.g. "... (name: foo)".
                idx_name = next(iter(idx_name_re.findall(row["index type"])), None)
                if idx_name:
                    text_idx_names.append(idx_name)
            for idx_name in text_idx_names:
                await (await session.run(f"DROP TEXT INDEX {escape_cypher(idx_name)}")).consume()

            await (
                await session.run("CALL schema.assert({}, {}, {}, true) YIELD action, key, keys, label, unique")
            ).consume()
            await (await session.run("MATCH (n) DETACH DELETE n")).consume()
            # TODO more efficient:
            # USING PERIODIC COMMIT num_rows
            # MATCH (n)-[r]->(m)
            # DELETE r;
            # USING PERIODIC COMMIT num_rows
            # MATCH (n)
            # DETACH DELETE n;

    async def text_search(self, label: str, query: str, limit: int = 10) -> Iterable[Record]:
        """Full-text search the user string ``query`` in the FTS index of ``label``."""
        with tracer.start_as_current_span("memgraph.text_search") as span:
            span.set_attribute("memgraph.label", label)
            span.set_attribute("memgraph.fts_query", query)
            span.set_attribute("memgraph.limit", limit)

            # Kept distinct from ``query`` (the search string) so the latter is what gets searched.
            cypher = "CALL text_search.search_all($idx_name, $query) YIELD node RETURN node LIMIT $limit"
            span.set_attribute("memgraph.query", cypher)

            res = await self.driver.execute_query(
                cypher,
                idx_name=self._fts_idx_name(label),
                query=query,
                limit=limit,
                routing_=RoutingControl.READ,
            )
            return res.records

    async def create_node_prop_index(self, labels: set[str], property: str, unique: bool = False) -> None:
        """Create an index on ``property`` for ``labels``; with ``unique`` also a uniqueness constraint."""
        escaped_label = escape_labels(labels)
        escaped_prop = escape_cypher(property)
        await self.run_cypher(f"CREATE INDEX ON :{escaped_label}({escaped_prop})")
        if not unique:
            return
        await self.run_cypher(f"CREATE CONSTRAINT ON (n:{escaped_label})\nASSERT n.{escaped_prop} IS UNIQUE")

    @staticmethod
    def _fts_idx_name(label: str) -> str:
        """Name of the full-text index for ``label``."""
        return f"{label.lower()}_fts_idx"

    async def aclose(self) -> None:
        """Close the async driver and its connection pool (preferred from async code)."""
        await self.driver.close()

    async def __aenter__(self) -> te.Self:
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        await self.aclose()

    def close(self) -> None:
        """Synchronously close the driver (used by the sync ``with`` protocol).

        ``AsyncDriver.close()`` is a coroutine and must be awaited to release
        connections. If an event loop is running in this thread, the close is
        scheduled as a task. Otherwise it is run to completion. Failures are
        logged, not raised, because closing is best-effort.
        """
        import asyncio  # only needed on this sync/async bridging path

        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None

        if loop is not None:
            # Keep a reference so the task is not garbage-collected before it finishes.
            self._close_task = loop.create_task(self.aclose())
            return
        try:
            asyncio.run(self.aclose())
        except Exception:
            # e.g. connections bound to an event loop that has already been closed.
            logger.warning("Failed to close Memgraph driver", exc_info=True)
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def escape_cypher(identifier: str) -> str:
    """Quote ``identifier`` as a backtick-delimited Cypher name.

    Embedded backticks are doubled. The literal six-character text ``\\u0060``
    is first turned into a backtick too. Cypher parsers may decode that unicode
    escape inside quoted names, which would otherwise let an identifier end the
    quoting early (Cypher injection). Note that ``"\\u0060"`` must be written
    with an escaped backslash: the Python literal ``"\u0060"`` is already a
    plain backtick.
    """
    identifier = identifier.replace("\\u0060", "`").replace("`", "``")
    return f"`{identifier}`"
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def escape_labels(labels: set[str]) -> str:
    """Render ``labels`` as ``:``-joined escaped Cypher names, e.g. {"B", "A"} -> `A`:`B`.

    Labels are sorted, so the same set always yields the same query text. The
    iteration order of a set of strings varies between processes because of
    hash randomization. Returns an empty string for an empty set.
    """
    return ":".join(escape_cypher(label) for label in sorted(labels))
|
vedana_core/llm.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
1
|
+
import asyncio
import inspect
import logging
from typing import Awaitable, Callable, Iterable

import openai
from jims_core.llms.llm_provider import LLMProvider
from jims_core.thread.schema import CommunicationEvent
from openai.types.chat import (
    ChatCompletionMessageParam,
    ChatCompletionToolMessageParam,
)
from pydantic import BaseModel
|
|
13
|
+
|
|
14
|
+
# Module logger (LLM instances default to the same logger name).
logger = logging.getLogger(name=__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Tool[T: BaseModel]:
|
|
18
|
+
def __init__(
|
|
19
|
+
self, name: str, description: str, args_cls: type[T], fn: Callable[[T], Awaitable[str]] | Callable[[T], str]
|
|
20
|
+
) -> None:
|
|
21
|
+
self.name = name
|
|
22
|
+
self.description = description
|
|
23
|
+
self.args_cls = args_cls
|
|
24
|
+
self.fn = fn
|
|
25
|
+
self.openai_def = openai.pydantic_function_tool(args_cls, name=name, description=description)
|
|
26
|
+
|
|
27
|
+
async def call(self, args_json: str) -> str:
|
|
28
|
+
try:
|
|
29
|
+
fn_args = self.args_cls.model_validate_json(args_json)
|
|
30
|
+
except ValueError:
|
|
31
|
+
return f"Invalid tool args: {args_json}"
|
|
32
|
+
|
|
33
|
+
if asyncio.iscoroutinefunction(self.fn):
|
|
34
|
+
result = await self.fn(fn_args)
|
|
35
|
+
else:
|
|
36
|
+
result: str = await asyncio.to_thread(self.fn, fn_args) # type: ignore
|
|
37
|
+
|
|
38
|
+
return result
|
|
39
|
+
|
|
40
|
+
|
|
41
|
+
class LLM:
    """Runs tool-calling chat completions over an ``LLMProvider``.

    ``prompt_templates`` may override the module-level default prompts by key:
    ``"generate_answer_with_tools_tmplt"``, ``"finalize_answer_tmplt"`` and
    ``"generate_no_answer_tmplt"``.
    """

    def __init__(
        self,
        llm_provider: LLMProvider,
        prompt_templates: dict[str, str],
        logger: logging.Logger | None = None,
    ) -> None:
        self.logger = logger or logging.getLogger(__name__)
        self.llm = llm_provider
        self.prompt_templates = prompt_templates

    # Current
    async def generate_cypher_query_with_tools(
        self,
        data_descr: str,
        messages: Iterable,
        tools: list[Tool],
    ) -> tuple[list[ChatCompletionMessageParam], str]:
        """Answer the dialog in ``messages`` by letting the model call ``tools``.

        ``data_descr`` (the graph description) is embedded into the system prompt.
        Returns the full message history and the final assistant answer ("" if none).
        """
        tool_names = [t.name for t in tools]
        msgs = make_cypher_query_with_tools_dialog(data_descr, self.prompt_templates, messages, tool_names=tool_names)
        return await self.create_completion_with_tools(msgs, tools=tools)

    async def create_completion_with_tools(
        self,
        messages: list[ChatCompletionMessageParam],
        tools: Iterable[Tool],
        max_iters: int = 5,
    ) -> tuple[list[ChatCompletionMessageParam], str]:
        """Alternate model calls and tool executions until the model stops calling tools.

        Tool calls requested in one model turn run concurrently, and their results are
        appended as ``tool`` messages. If the model is still calling tools after
        ``max_iters`` rounds, a final plain completion is requested with the
        ``finalize_answer_tmplt`` prompt.

        ``messages`` itself is not modified; an extended copy is returned together
        with the content of the last non-empty assistant message. That content is
        "" if there is none, e.g. when ``max_iters`` < 1.
        """
        messages = messages.copy()
        # ``tools`` is traversed twice below. Materialize it so that a one-shot iterable
        # (e.g. a generator) is not exhausted before ``tools_map`` is built.
        tools = list(tools)
        tool_defs = [tool.openai_def for tool in tools]
        tools_map = {tool.name: tool for tool in tools}

        async def _execute_tool_call(tool_call):
            """Run one requested tool call; returns ``(tool_call_id, result_text)``."""
            tool_name = tool_call.function.name
            tool = tools_map.get(tool_name)
            if not tool:
                self.logger.error("Tool %s not found!", tool_name)
                return tool_call.id, f"Tool {tool_name} not found!"

            self.logger.debug("Calling tool %s", tool_name)
            try:
                tool_res = await tool.call(tool_call.function.arguments)
            except Exception as e:
                # Report the failure to the model instead of aborting the whole dialog.
                self.logger.exception("Error executing tool %s: %s", tool_name, e)
                tool_res = f"Error executing tool {tool_name}: {e}"

            self.logger.debug("Tool %s (%s) result: %s", tool_name, tool.description, tool_res)
            return tool_call.id, tool_res

        for i in range(max_iters):
            msg, tool_calls = await self.llm.chat_completion_with_tools(
                messages=messages,
                tools=tool_defs,
            )

            messages.append(msg.to_dict())  # type: ignore

            if not tool_calls:
                self.logger.debug("No tool calls found. Exiting tool call loop")
                break

            self.logger.debug("Tool call iter %d/%d", i + 1, max_iters)

            # Execute tool calls in parallel
            results = await asyncio.gather(*[_execute_tool_call(t) for t in tool_calls])

            for tool_call_id, tool_res in results:
                messages.append(
                    ChatCompletionToolMessageParam(role="tool", tool_call_id=tool_call_id, content=tool_res)
                )

            if i == max_iters - 1:
                self.logger.warning("Reached tool call iteration limit (%d). Exiting tool call loop", max_iters)
                finalize_prompt = self.prompt_templates.get("finalize_answer_tmplt", finalize_answer_tmplt)
                finalize_msg = {"role": "system", "content": finalize_prompt}
                final_msg = await self.llm.chat_completion_plain(messages + [finalize_msg])
                messages.append(final_msg.to_dict())  # type: ignore
                break

        for last_msg in reversed(messages):  # sometimes message with final answer is not the last one
            if last_msg.get("role", "") == "assistant" and last_msg.get("content"):
                return messages, str(last_msg.get("content"))
        return messages, ""

    async def generate_no_answer(
        self,
        dialog: list[CommunicationEvent] | None = None,
    ) -> str:
        """Ask the model for a short "no answer found" reply to the given ``dialog``.

        Returns the stripped reply text ("" if the model returned no content).
        """
        prompt = self.prompt_templates.get("generate_no_answer_tmplt", generate_no_answer_tmplt)
        # NOTE(review): dialog items are passed through as-is; confirm the provider
        # accepts CommunicationEvent objects alongside the dict system message.
        messages = [
            {"role": "system", "content": prompt},
            *(dialog or []),
        ]
        response = await self.llm.chat_completion_plain(messages)
        human_answer = "" if response.content is None else response.content.strip()
        self.logger.debug("Generated 'no answer' response: %s", human_answer)
        return human_answer
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
# Default system prompt for the final, tool-free completion that
# LLM.create_completion_with_tools requests when the tool-calling loop reaches its
# iteration limit. It can be overridden via prompt_templates["finalize_answer_tmplt"].
finalize_answer_tmplt = """\
Сформулируй ответ на запрос пользователя на основе информации, полученной в результате вызова инструментов.
Если информации недостаточно для точного ответа, ясно опиши ограничения и предложи 1–2 уточняющих вопроса.
Важно! Не упоминай инструменты в явном виде, ссылайся только на данные.
"""
|
|
145
|
+
|
|
146
|
+
# Default system prompt used by LLM.generate_no_answer when no answer was found;
# it can be overridden via prompt_templates["generate_no_answer_tmplt"].
generate_no_answer_tmplt = (
    "Ты - помощник, который преобразует технические ответы в понятный человеку текст.\n"
    "Мы не смогли найти ответ на вопрос пользователя в базе знаний.\n"
    "Сформулируй ответ, сообщающий кратко и информативно, что ответа не найдено.\n"
    "Предложи пару вариантов уточняющих вопросов на основе информации в контексте. Предложи в casual стиле.\n"
)
|
|
152
|
+
|
|
153
|
+
generate_answer_with_tools_tmplt = """\
|
|
154
|
+
Ты — помощник по работе с графовыми базами данных, в которых используется язык запросов Cypher
|
|
155
|
+
|
|
156
|
+
Цель: постараться найти ответ на вопрос пользователя используя инструменты для работы с БД на основе текстового описания графовой базы данных.
|
|
157
|
+
|
|
158
|
+
На вход ты получаешь graph_composition: – описание графа и примеры запросов по нему, и user_query – пользовательский запрос.
|
|
159
|
+
|
|
160
|
+
**Что нужно сделать:**
|
|
161
|
+
1. Сгенерировать `Cypher`-запросы, используя узлы, атрибуты и связи перечисленные в **graph_composition**.
|
|
162
|
+
2. Руководствуйся данными в **graph_composition** примерами запросов, чтобы составить итоговый запрос.
|
|
163
|
+
3. Используй инструменты {tools} для выполнения запросов и поиска
|
|
164
|
+
|
|
165
|
+
Если нужно, используй несколько `MATCH`-блоков, например:
|
|
166
|
+
MATCH (o:offer)-[:OFFER_belongs_to_CATEGORY]->(c:category)
|
|
167
|
+
MATCH (o)-[:OFFER_made_of_MATERIAL]->(m:material)
|
|
168
|
+
WHERE c.category_name = "Встраиваемый светильник" AND m.material_name IN ["Стекло", "Металл и Стекло", "Алюминий и стекло"]
|
|
169
|
+
RETURN o
|
|
170
|
+
|
|
171
|
+
Теперь проанализируй следующую структуру графа, и постарайся найти ответ на вопрос используя инструменты {tools}. (Лучше использовать несколько инструментов)
|
|
172
|
+
|
|
173
|
+
**graph_composition**
|
|
174
|
+
{graph_description}
|
|
175
|
+
"""
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def make_cypher_query_with_tools_dialog(
    graph_description: str,
    prompt_templates: dict[str, str],
    messages: Iterable,
    tool_names: list[str],
) -> list[ChatCompletionMessageParam]:
    """Build the chat history for the tool-calling Cypher agent.

    The system prompt is ``prompt_templates["generate_answer_with_tools_tmplt"]``
    (falling back to the module default), formatted with ``graph_description``
    and the comma-joined ``tool_names``. It is followed by all of ``messages``.
    """
    template = prompt_templates.get("generate_answer_with_tools_tmplt", generate_answer_with_tools_tmplt)
    system_prompt = template.format(tools=", ".join(tool_names), graph_description=graph_description)
    system_message = {"role": "system", "content": system_prompt}
    return [system_message, *messages]
|
vedana_core/py.typed
ADDED
|
File without changes
|