vedana-core 0.1.0.dev3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vedana_core/__init__.py +0 -0
- vedana_core/app.py +78 -0
- vedana_core/data_model.py +465 -0
- vedana_core/data_provider.py +513 -0
- vedana_core/db.py +41 -0
- vedana_core/graph.py +300 -0
- vedana_core/llm.py +192 -0
- vedana_core/py.typed +0 -0
- vedana_core/rag_agent.py +234 -0
- vedana_core/rag_pipeline.py +326 -0
- vedana_core/settings.py +35 -0
- vedana_core/start_pipeline.py +17 -0
- vedana_core/utils.py +31 -0
- vedana_core/vts.py +167 -0
- vedana_core-0.1.0.dev3.dist-info/METADATA +29 -0
- vedana_core-0.1.0.dev3.dist-info/RECORD +17 -0
- vedana_core-0.1.0.dev3.dist-info/WHEEL +4 -0
vedana_core/rag_agent.py
ADDED
@@ -0,0 +1,234 @@
import enum
import json
import logging
import re
from dataclasses import dataclass
from itertools import islice
from typing import Any, Mapping, Type

import neo4j
import neo4j.graph
from jims_core.thread.thread_context import ThreadContext
from pydantic import BaseModel, Field, create_model

from vedana_core.graph import Graph, Record
from vedana_core.vts import VectorStore
from vedana_core.llm import LLM, Tool
from vedana_core.settings import settings

QueryResult = list[Record] | Exception


# TODO replace with VTSArgs and CypherArgs
class CypherQuery(str):
    ...


@dataclass
class VTSQuery:
    label: str
    param: str
    query: str

    def __str__(self) -> str:
        return f'vector_search("{self.label}","{self.param}","{self.query}")'


DBQuery = CypherQuery | VTSQuery | str


@dataclass
class RagResults:
    fts_res: list[Record] | None = None
    vts_res: list[Record] | None = None
    db_query_res: list[tuple[DBQuery, QueryResult]] | None = None


class VTSArgs(BaseModel):
    label: str = Field(description="node label")
    property: str = Field(description="node property to search in")
    text: str = Field(description="text to search similar")


class CypherArgs(BaseModel):
    query: str = Field(description="Cypher query")


class GetHistoryArgs(BaseModel):
    max_history: int = Field(20000, description="Maximum text length to retrieve from history. Cuts off on messages")


VTS_TOOL_NAME = "vector_text_search"
CYPHER_TOOL_NAME = "cypher"


class RagAgent:
    _vts_args: type[VTSArgs]

    def __init__(
        self,
        graph: Graph,
        vts: VectorStore,
        data_model_description: str,
        data_model_vts_indices,
        llm: LLM,
        ctx: ThreadContext,
        logger: logging.Logger | None = None,
    ) -> None:
        self.graph = graph
        self.vts = vts
        self.llm = llm
        self.logger = logger or logging.getLogger(__name__)
        self._vts_meta_args: dict[str, dict[str, tuple[str, float]]] = {}  # stuff not passed through the tool call
        self._vts_args = self._build_vts_arg_model(data_model_vts_indices)
        self.data_model_description = data_model_description
        self.ctx = ctx

    def _build_vts_arg_model(self, vts_indices) -> Type[VTSArgs]:
        """Create a Pydantic model with Enum-constrained fields for the VTS tool."""

        if not vts_indices:
            return VTSArgs

        # fill in lookup for resolving idx type and getting custom threshold
        for idx_type, i_name, i_attr, i_th in vts_indices:
            if not self._vts_meta_args.get(i_name):
                self._vts_meta_args[i_name] = {}
            self._vts_meta_args[i_name][i_attr] = (idx_type, i_th)

        # Label Enum – keys of `vts_indices`
        LabelEnum = enum.Enum("LabelEnum", {name: name for (_type, name, _attr, _th) in vts_indices})  # type: ignore

        # Property Enum – unique values of `vts_indices`
        unique_props = set(attr for (_type, _name, attr, _th) in vts_indices)
        prop_member_mapping: dict[str, str] = {}

        used_names: set[str] = set()
        for idx, prop in enumerate(sorted(unique_props)):
            sanitized = re.sub(r"\W|^(?=\d)", "_", prop)
            if sanitized in used_names:
                sanitized = f"{sanitized}_{idx}"
            used_names.add(sanitized)
            prop_member_mapping[sanitized] = prop

        PropertyEnum = enum.Enum("PropertyEnum", prop_member_mapping)  # type: ignore

        VTSArgsEnum = create_model(
            "VTSArgsEnum",
            label=(LabelEnum, Field(description="node or edge label")),
            property=(PropertyEnum, Field(description="property to search in")),
            text=(str, Field(description="text for semantic search")),
            __base__=VTSArgs,
        )

        return VTSArgsEnum

    async def search_vector_text(
        self,
        label: str,
        prop_type: str,
        prop_name: str,
        search_value: str,
        threshold: float,
        top_n: int = 5,
    ) -> list[Record]:
        embed = await self.llm.llm.create_embedding(search_value)
        return await self.vts.vector_search(label, prop_type, prop_name, embed, threshold=threshold, top_n=top_n)

    @staticmethod
    def result_to_text(query: str, result: list[Record] | Exception) -> str:
        if isinstance(result, Exception):
            return f"Query: {query}\nResult: 'Error executing query'"
        rows_str = "\n".join(row_to_text(row) for row in result)
        return f"Query: {query}\nRows:\n{rows_str}"

    async def execute_cypher_query(self, query, rows_limit: int = 30) -> QueryResult:
        try:
            return list(islice(await self.graph.execute_ro_cypher_query(query), rows_limit))
        except Exception as e:
            self.logger.exception(e)
            return e

    def rag_results_to_text(self, results: RagResults) -> str:
        all_results = list(results.db_query_res or [])  # copy so the appends below don't mutate the caller's list
        if results.vts_res:
            all_results.append(("Vector text search", results.vts_res))
        if results.fts_res:
            all_results.append(("Full text search", results.fts_res))
        return "\n\n".join(self.result_to_text(str(q), r) for q, r in all_results)

    async def text_to_answer_with_vts_and_cypher(
        self, text_query: str, threshold: float, top_n: int = 5
    ) -> tuple[str, list, list[VTSQuery], list[CypherQuery]]:
        vts_queries: list[VTSQuery] = []
        cypher_queries: list[CypherQuery] = []

        async def vts_fn(args: VTSArgs) -> str:
            label = args.label.value if isinstance(args.label, enum.Enum) else args.label
            prop = args.property.value if isinstance(args.property, enum.Enum) else args.property

            # default fallback treats the tool call as a node vector search
            prop_type, th = self._vts_meta_args.get(label, {}).get(prop, ("node", threshold))
            self.logger.debug(f"vts_fn(on={prop_type} label={label}, property={prop}, th={th}, n={top_n})")

            vts_queries.append(VTSQuery(label, prop, args.text))
            vts_res = await self.search_vector_text(label, prop_type, prop, args.text, threshold=th, top_n=top_n)
            return self.result_to_text(VTS_TOOL_NAME, vts_res)

        async def cypher_fn(args: CypherArgs) -> str:
            self.logger.debug(f"cypher_fn({args})")
            cypher_queries.append(CypherQuery(args.query))
            res = await self.execute_cypher_query(args.query)
            return self.result_to_text(CYPHER_TOOL_NAME, res)

        vts_tool = Tool(
            VTS_TOOL_NAME,
            "Vector search for similar text in node properties, use for semantic search",
            self._vts_args,
            vts_fn,
        )

        cypher_tool = Tool(
            CYPHER_TOOL_NAME,
            "Execute a Cypher query against the graph database. Use for structured data retrieval and graph traversal.",
            CypherArgs,
            cypher_fn,
        )

        tools: list[Tool] = [vts_tool, cypher_tool]

        all_query_events, answer = await self.llm.generate_cypher_query_with_tools(
            data_descr=self.data_model_description,
            messages=self.ctx.history[-settings.pipeline_history_length:],
            tools=tools,
        )

        if not answer:
            self.logger.warning(f"No answer found for {text_query}. Generating empty answer...")
            answer = await self.llm.generate_no_answer(self.ctx.history[-settings.pipeline_history_length:])

        return answer, all_query_events, vts_queries, cypher_queries


def _remove_embeddings(val: Any):
    if isinstance(val, Mapping):
        return {k: _remove_embeddings(v) for k, v in val.items() if not k.endswith("_embedding")}
    if isinstance(val, list):
        return [_remove_embeddings(v) for v in val]
    return val


def _clear_record_val(val: Any):  # todo check for edges
    params = _remove_embeddings(val)
    if isinstance(val, neo4j.graph.Node) and isinstance(params, dict):
        params["labels"] = list(val.labels)
    return params


def row_to_text(row: Any) -> str:
    if isinstance(row, neo4j.Record):
        row = {k: _clear_record_val(v) for k, v in row.items()}
    try:
        return json.dumps(row, ensure_ascii=False, indent=2)
    except TypeError:
        return str(row)
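The interesting move in rag_agent.py is `_build_vts_arg_model`: the VTS tool's argument schema is rebuilt at runtime so the LLM can only choose labels and properties that actually have vector indices. A minimal standalone sketch of the same technique, with hypothetical index tuples in place of what `DataModel.vector_indices()` would return:

import enum
from pydantic import BaseModel, Field, create_model

class VTSArgs(BaseModel):
    label: str = Field(description="node label")
    property: str = Field(description="node property to search in")
    text: str = Field(description="text to search similar")

# Hypothetical (index_type, label, attribute, threshold) tuples.
vts_indices = [
    ("node", "Person", "bio", 0.80),
    ("node", "Company", "description", 0.75),
]

LabelEnum = enum.Enum("LabelEnum", {name: name for (_t, name, _a, _th) in vts_indices})
PropertyEnum = enum.Enum("PropertyEnum", {attr: attr for (_t, _n, attr, _th) in vts_indices})

VTSArgsEnum = create_model(
    "VTSArgsEnum",
    label=(LabelEnum, Field(description="node or edge label")),
    property=(PropertyEnum, Field(description="property to search in")),
    text=(str, Field(description="text for semantic search")),
    __base__=VTSArgs,
)

# A tool call with a known label validates; an unknown one fails fast.
ok = VTSArgsEnum.model_validate({"label": "Person", "property": "bio", "text": "graph databases"})
print(ok.label.value, ok.property.value)  # Person bio
# A payload with e.g. label="City" would raise a ValidationError instead.

The embedding-stripping helper is easy to sanity-check as well; using the `_remove_embeddings` defined above, any mapping key ending in `_embedding` is pruned recursively:

row = {"name": "Ada", "bio_embedding": [0.1, 0.2], "tags": [{"t_embedding": [1.0]}]}
print(_remove_embeddings(row))  # {'name': 'Ada', 'tags': [{}]}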
vedana_core/rag_pipeline.py
ADDED
@@ -0,0 +1,326 @@
import logging
import traceback
from dataclasses import asdict
from typing import Any
from pydantic import BaseModel, Field

from jims_core.thread.thread_context import ThreadContext

from vedana_core.data_model import DataModel
from vedana_core.graph import Graph
from vedana_core.vts import VectorStore
from vedana_core.llm import LLM
from vedana_core.rag_agent import RagAgent
from vedana_core.settings import settings


class DataModelSelection(BaseModel):
    reasoning: str = Field(
        default="",
        description="Brief explanation of why these elements were selected for answering the user's question",
    )
    anchor_nouns: list[str] = Field(
        default_factory=list,
        description="List of anchor nouns (node types) needed to answer the question",
    )
    link_sentences: list[str] = Field(
        default_factory=list,
        description="List of link sentences (relationship types) needed to answer the question",
    )
    anchor_attribute_names: list[str] = Field(
        default_factory=list,
        description="List of anchor attribute names needed to answer the question (attributes belonging to nodes)",
    )
    link_attribute_names: list[str] = Field(
        default_factory=list,
        description="List of link attribute names needed to answer the question (attributes belonging to relationships)",
    )
    query_ids: list[str] = Field(
        default_factory=list,
        description="List of query scenario ID's that match the user's question type",
    )


dm_filter_base_system_prompt = """\
Ты — помощник по анализу структуры графовой базы данных.

Твоя задача: проанализировать вопрос пользователя и определить, какие элементы модели данных (узлы, связи, атрибуты, сценарии запросов) необходимы для формирования ответа.

**Правила выбора:**
1. Выбирай только те элементы, которые ДЕЙСТВИТЕЛЬНО нужны для ответа на вопрос
2. Если вопрос касается связи между сущностями — выбери соответствующие узлы И связь между ними
3. Выбирай атрибуты, которые могут содержать искомую информацию или использоваться для фильтрации
   - anchor_attribute_names: атрибуты узлов (находятся в разделе "Атрибуты узлов")
   - link_attribute_names: атрибуты связей (находятся в разделе "Атрибуты связей")
4. Выбирай сценарий запроса, который лучше всего соответствует типу вопроса пользователя
5. Лучше выбрать чуть больше, чем упустить важное — но не выбирай всё подряд

**Формат ответа:**
Верни JSON с выбранными элементами. Используй ТОЧНЫЕ имена из предоставленной модели данных.
"""

dm_filter_user_prompt_template = """\
**Вопрос пользователя:**
{user_query}

**Модель данных:**
{compact_data_model}

Проанализируй вопрос и выбери необходимые элементы модели данных для формирования ответа.
"""


class StartPipeline:
    """
    Response for /start command
    """

    def __init__(self, data_model: DataModel) -> None:
        self.data_model = data_model

    async def __call__(self, ctx: ThreadContext) -> None:
        lifecycle_events = await self.data_model.conversation_lifecycle_events()
        start_response = lifecycle_events.get("/start")

        message = start_response or "Bot online. No response for /start command in LifecycleEvents"
        ctx.send_message(message)


class RagPipeline:
    """RAG Pipeline with data model filtering for optimized query processing.

    This pipeline adds an initial step that filters the data model based on
    the user's query, selecting only relevant anchors, links, attributes,
    and query scenarios. This reduces token usage and improves LLM precision
    for large data models.
    """

    def __init__(
        self,
        graph: Graph,
        vts: VectorStore,
        data_model: DataModel,
        logger,
        threshold: float = 0.8,
        top_n: int = 5,
        model: str | None = None,
        filter_model: str | None = None,
        enable_filtering: bool | None = None,
    ):
        self.graph = graph
        self.vts = vts
        self.data_model = data_model
        self.logger = logger or logging.getLogger(__name__)
        self.threshold = threshold
        self.top_n = top_n
        self.model = model or settings.model
        self.filter_model = filter_model or settings.filter_model  # or self.model
        self.enable_filtering = enable_filtering or settings.enable_dm_filtering

    async def __call__(self, ctx: ThreadContext) -> None:
        """Main pipeline execution - implements JIMS Pipeline protocol."""

        # Get the last user message
        user_query = ctx.get_last_user_message()
        if not user_query:
            ctx.send_message("I didn't receive a question. Please ask me something!")
            return

        try:
            # Update status
            await ctx.update_agent_status("Processing your question...")

            # Process the query using RAG
            answer, agent_query_events, technical_info = await self.process_rag_query(user_query, ctx)

            # Send the answer
            ctx.send_message(answer)

            # Store technical information as an event
            ctx.send_event(
                "rag.query_processed",
                {
                    "query": user_query,
                    "answer": answer,
                    "technical_info": technical_info,
                    "threshold": self.threshold,
                },
            )

        except Exception as e:
            self.logger.exception(f"Error in RAG pipeline: {e}")
            error_msg = "An error occurred while processing the request"  # do not expose the raw error to the user
            ctx.send_message(error_msg)

            # Store error event
            ctx.send_event(
                "rag.error",
                {
                    "query": user_query,
                    "error": str(e),
                    "traceback": traceback.format_exc(),
                },
            )

    async def process_rag_query(self, query: str, ctx: ThreadContext) -> tuple[str, list, dict[str, Any]]:
        # 1. Filter data model
        if self.enable_filtering:
            await ctx.update_agent_status("Analyzing query structure...")
            data_model_description, filter_selection = await self.filter_data_model(query, ctx)
        else:
            data_model_description = await self.data_model.to_text_descr()
            filter_selection = DataModelSelection()

        # Read required DataModel properties
        prompt_templates = await self.data_model.prompt_templates()
        data_model_vector_search_indices = await self.data_model.vector_indices()

        # 2. Create LLM and agent with filtered data model; Step 1 LLM costs are counted since the same ctx.llm is used
        llm = LLM(ctx.llm, prompt_templates=prompt_templates, logger=self.logger)
        await ctx.update_agent_status("Searching knowledge base...")

        if self.model != llm.llm.model and settings.debug:
            llm.llm.set_model(self.model)

        agent = RagAgent(
            graph=self.graph,
            vts=self.vts,
            data_model_description=data_model_description,
            data_model_vts_indices=data_model_vector_search_indices,
            llm=llm,
            ctx=ctx,
            logger=self.logger,
        )

        (
            answer,
            agent_query_events,
            vts_queries,
            cypher_queries,
        ) = await agent.text_to_answer_with_vts_and_cypher(
            query,
            threshold=self.threshold,
            top_n=self.top_n,
        )

        technical_info: dict[str, Any] = {
            "vts_queries": [str(q) for q in vts_queries],
            "cypher_queries": [str(q) for q in cypher_queries],
            "num_vts_queries": len(vts_queries),
            "num_cypher_queries": len(cypher_queries),
            "model_used": self.model,
            "model_stats": {m: asdict(u) for m, u in ctx.llm.usage.items()},
        }

        # Add filtering info if applicable
        if self.enable_filtering:
            dm_anchors = await self.data_model.get_anchors()
            dm_links = await self.data_model.get_links()
            dm_queries = await self.data_model.get_queries()

            # Count total attributes for original_counts
            total_anchor_attrs = sum(len(anchor.attributes) for anchor in dm_anchors)
            total_link_attrs = sum(len(link.attributes) for link in dm_links)

            dm_filtering = {
                "filter_model": self.filter_model,
                "reasoning": filter_selection.reasoning,
                "selected_anchors": filter_selection.anchor_nouns,
                "selected_links": filter_selection.link_sentences,
                "selected_anchor_attributes": filter_selection.anchor_attribute_names,
                "selected_link_attributes": filter_selection.link_attribute_names,
                "selected_queries": [dm_queries[int(i)].name for i in filter_selection.query_ids],
                "original_counts": {
                    "anchors": len(dm_anchors),
                    "links": len(dm_links),
                    "anchor_attrs": total_anchor_attrs,
                    "link_attrs": total_link_attrs,
                    "queries": len(dm_queries),
                },
                "filtered_counts": {
                    "anchors": len(filter_selection.anchor_nouns),
                    "links": len(filter_selection.link_sentences),
                    "anchor_attrs": len(filter_selection.anchor_attribute_names),
                    "link_attrs": len(filter_selection.link_attribute_names),
                    "queries": len(filter_selection.query_ids),
                },
            }

            ctx.send_event("rag.data_model_filtered", dm_filtering)

        return answer, agent_query_events, technical_info

    async def filter_data_model(
        self,
        query: str,
        ctx: ThreadContext,
    ) -> tuple[str, DataModelSelection]:
        # Get description for filtering
        dm_json = await self.data_model.to_compact_json()
        dm_prompt_templates = await self.data_model.prompt_templates()

        # Build the prompt
        system_prompt = dm_prompt_templates.get("dm_filter_prompt", dm_filter_base_system_prompt)
        user_prompt = dm_prompt_templates.get("dm_filter_user_prompt", dm_filter_user_prompt_template).format(
            user_query=query,
            compact_data_model=dm_json,
        )

        messages: list[dict[str, Any]] = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]

        self.logger.debug(f"Filtering data model for query: {query}")

        try:
            filter_llm = ctx.llm

            base_model = ctx.llm.model
            if self.filter_model:  # if a different model is specified for filtering, use it
                filter_llm.set_model(self.filter_model)

            # Use structured output to get the selection
            selection = await filter_llm.chat_completion_structured(messages, DataModelSelection)

            if base_model:  # switch back to the base model
                ctx.llm.set_model(base_model)

            if selection is None:
                raise ValueError("LLM returned empty response")

            # resolve query id's to query names - the LLM often misspells arbitrary query_names
            query_names = [dm_json["queries"].get(int(i)) for i in selection.query_ids]

            self.logger.debug(
                f"Data model filter selection: "
                f"anchors={selection.anchor_nouns}, "
                f"links={selection.link_sentences}, "
                f"anchor_attrs={selection.anchor_attribute_names}, "
                f"link_attrs={selection.link_attribute_names}, "
                f"queries={query_names}",
            )
            self.logger.debug(f"Filter reasoning: {selection.reasoning}")

            # Create filtered data model
            filtered_dm_descr = await self.data_model.to_text_descr(
                anchor_nouns=selection.anchor_nouns,
                link_sentences=selection.link_sentences,
                anchor_attribute_names=selection.anchor_attribute_names,
                link_attribute_names=selection.link_attribute_names,
                query_names=query_names,
            )

            return filtered_dm_descr, selection

        except Exception as e:  # fall back to the full data model
            self.logger.exception(f"Data model filtering failed: {e}. Using full data model.")
            descr = await self.data_model.to_text_descr()
            return descr, DataModelSelection(
                reasoning=f"Filtering failed: {e}. Using full data model.",
                anchor_nouns=[],
                link_sentences=[],
                anchor_attribute_names=[],
                link_attribute_names=[],
                query_ids=[],
            )
vedana_core/settings.py
ADDED
@@ -0,0 +1,35 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
from dotenv import load_dotenv

load_dotenv()


class VedanaCoreSettings(BaseSettings):
    model_config = SettingsConfigDict(
        env_prefix="",
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",
    )

    grist_server_url: str  # default 'https://api.getgrist.com'
    grist_api_key: str
    grist_data_model_doc_id: str
    grist_data_doc_id: str

    debug: bool = False
    model: str = "gpt-4.1"
    enable_dm_filtering: bool = False  # pipeline selection (experimental)
    filter_model: str = "gpt-5-mini"
    judge_model: str = "gpt-5-mini"
    embeddings_model: str = "text-embedding-3-large"
    embeddings_dim: int = 1024

    pipeline_history_length: int = 20

    memgraph_uri: str
    memgraph_user: str
    memgraph_pwd: str


settings = VedanaCoreSettings()  # type: ignore
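With `env_prefix=""`, each field maps to an environment variable of the same name (matched case-insensitively by pydantic-settings), loaded from the process environment or a `.env` file. A sketch of a hypothetical local configuration; all values here are placeholders, and the fields without defaults are required, so constructing the settings without them raises a ValidationError:

import os

os.environ.update(
    {
        "GRIST_SERVER_URL": "https://api.getgrist.com",
        "GRIST_API_KEY": "dummy-key",
        "GRIST_DATA_MODEL_DOC_ID": "dm-doc-id",
        "GRIST_DATA_DOC_ID": "data-doc-id",
        "MEMGRAPH_URI": "bolt://localhost:7687",
        "MEMGRAPH_USER": "memgraph",
        "MEMGRAPH_PWD": "memgraph",
    }
)

from vedana_core.settings import settings  # the singleton is created at import time

print(settings.model, settings.embeddings_dim)  # gpt-4.1 1024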
vedana_core/start_pipeline.py
ADDED
@@ -0,0 +1,17 @@
from jims_core.thread.thread_context import ThreadContext
from vedana_core.data_model import DataModel


class StartPipeline:
    """
    Response for /start command
    """

    def __init__(self, data_model: DataModel) -> None:
        self.data_model = data_model

    async def __call__(self, ctx: ThreadContext) -> None:
        lifecycle_events = await self.data_model.conversation_lifecycle_events()
        start_response = lifecycle_events.get("/start")
        message = start_response or "Bot online. No response for /start command in LifecycleEvents"
        ctx.send_message(message)
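`StartPipeline` (also duplicated verbatim inside rag_pipeline.py above) shows the JIMS pipeline protocol in its smallest form: an async callable taking a thread context. A runnable sketch with stub stand-ins for the data model and context; every name here is hypothetical, not part of the package:

import asyncio

class StubDataModel:
    async def conversation_lifecycle_events(self) -> dict[str, str]:
        return {"/start": "Hello! Ask me anything about the knowledge graph."}

class StubCtx:
    def send_message(self, text: str) -> None:
        print("bot:", text)

async def start_pipeline(data_model, ctx) -> None:
    # mirrors StartPipeline.__call__ above
    events = await data_model.conversation_lifecycle_events()
    ctx.send_message(events.get("/start") or "Bot online. No response for /start command in LifecycleEvents")

asyncio.run(start_pipeline(StubDataModel(), StubCtx()))  # bot: Hello! Ask me anything...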
vedana_core/utils.py
ADDED
@@ -0,0 +1,31 @@
def cast_dtype(value, value_name: str, dtype: str | None):
    """
    Best-effort conversion of string values to the data type declared in the data model.
    """

    try:
        if dtype == "int" or dtype == "float":
            if isinstance(value, str):
                if len(value.strip(",").split(",")) == 2:  # single comma -> decimal separator
                    value = value.replace(",", ".")
                elif len(value.strip(",").split(",")) > 2:  # assume "," is a thousands delimiter
                    value = value.replace(",", "")
                elif len(value.strip(".").split(".")) > 2:  # assume "." is a thousands delimiter
                    value = value.replace(".", "")

            value = float(value)
            if dtype == "int":
                value = int(value)

        elif dtype == "bool":
            return str(value).strip().lower() in ["1", "true", "да", "есть"]  # incl. Russian "yes"/"present"
        elif dtype == "str":
            return str(value)

        else:
            return value  # Fallback

    except (ValueError, TypeError):
        print(f"cast_dtype error for value {value_name} {value} (expected dtype {dtype})")

    return value