vedana-etl 0.1.0.dev3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vedana_etl/__init__.py +0 -0
- vedana_etl/app.py +10 -0
- vedana_etl/catalog.py +266 -0
- vedana_etl/config.py +22 -0
- vedana_etl/pipeline.py +142 -0
- vedana_etl/py.typed +0 -0
- vedana_etl/schemas.py +31 -0
- vedana_etl/settings.py +23 -0
- vedana_etl/steps.py +685 -0
- vedana_etl/store.py +208 -0
- vedana_etl-0.1.0.dev3.dist-info/METADATA +51 -0
- vedana_etl-0.1.0.dev3.dist-info/RECORD +13 -0
- vedana_etl-0.1.0.dev3.dist-info/WHEEL +4 -0
vedana_etl/steps.py
ADDED
@@ -0,0 +1,685 @@
import logging
import re
from typing import Any, Iterator, cast
from unicodedata import normalize
from uuid import UUID

import pandas as pd
from jims_core.llms.llm_provider import LLMProvider
from neo4j import GraphDatabase
from vedana_core.data_model import Anchor, Attribute, Link
from vedana_core.data_provider import GristAPIDataProvider, GristCsvDataProvider
from vedana_core.settings import settings as core_settings

from vedana_etl.settings import settings as etl_settings

# pd.replace() throws warnings due to type downcasting. Behavior will change only in pandas 3.0
# https://github.com/pandas-dev/pandas/issues/57734
pd.set_option("future.no_silent_downcasting", True)

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

def is_uuid(val: str):
    try:
        UUID(str(val))
        return True
    except (ValueError, TypeError):
        return False

def clean_str(text: str) -> str:
    if not isinstance(text, str):
        return text
    text = normalize("NFC", text)

    # Replace non-breaking spaces and other space-like Unicode chars with regular space
    text = re.sub(r"[\u00A0\u2000-\u200B\u202F\u205F\u3000]", " ", text)

    # Remove zero-width spaces and BOMs
    text = re.sub(r"[\u200B\u200C\u200D\uFEFF]", "", text)

    # Collapse multiple spaces
    text = re.sub(r"\s+", " ", text)
    return text.strip()

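# Editor's illustrative sketch (not part of the published module): expected behavior of
# clean_str(). Non-breaking and ideographic spaces become regular spaces and runs of
# whitespace collapse; note the first substitution already turns zero-width spaces
# (U+200B) into plain spaces before the second one runs.
def _clean_str_example() -> None:
    assert clean_str("Foo\u00a0\u00a0Bar") == "Foo Bar"
    assert clean_str("  a \u3000 b  ") == "a b"
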
def get_data_model() -> Iterator[
    tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]
]:
    loader = GristCsvDataProvider(
        doc_id=core_settings.grist_data_model_doc_id,
        grist_server=core_settings.grist_server_url,
        api_key=core_settings.grist_api_key,
    )

    _links_df = loader.get_table("Links")
    links_df = cast(
        pd.DataFrame,
        _links_df[
            [
                "anchor1",
                "anchor2",
                "sentence",
                "description",
                "query",
                "anchor1_link_column_name",
                "anchor2_link_column_name",
                "has_direction",
            ]
        ],
    )
    assert links_df is not None

    links_df["has_direction"] = _links_df["has_direction"].astype(bool)
    links_df = links_df.dropna(subset=["anchor1", "anchor2", "sentence"], inplace=False)

    anchor_attrs_df = loader.get_table("Anchor_attributes")
    anchor_attrs_df = cast(
        pd.DataFrame,
        anchor_attrs_df[
            [
                "anchor",
                "attribute_name",
                "description",
                "data_example",
                "embeddable",
                "query",
                "dtype",
                "embed_threshold",
            ]
        ],
    )
    anchor_attrs_df["embeddable"] = anchor_attrs_df["embeddable"].astype(bool)
    anchor_attrs_df["embed_threshold"] = anchor_attrs_df["embed_threshold"].astype(float)
    anchor_attrs_df = anchor_attrs_df.dropna(subset=["anchor", "attribute_name"], how="any")

    link_attrs_df = loader.get_table("Link_attributes")
    link_attrs_df = cast(
        pd.DataFrame,
        link_attrs_df[
            [
                "link",
                "attribute_name",
                "description",
                "data_example",
                "embeddable",
                "query",
                "dtype",
                "embed_threshold",
            ]
        ],
    )
    link_attrs_df["embeddable"] = link_attrs_df["embeddable"].astype(bool)
    link_attrs_df["embed_threshold"] = link_attrs_df["embed_threshold"].astype(float)
    link_attrs_df = link_attrs_df.dropna(subset=["link", "attribute_name"], how="any")

    anchors_df = loader.get_table("Anchors")
    anchors_df = cast(
        pd.DataFrame,
        anchors_df[
            [
                "noun",
                "description",
                "id_example",
                "query",
            ]
        ],
    )
    anchors_df = anchors_df.dropna(subset=["noun"], inplace=False)
    anchors_df = anchors_df.astype(str)

    queries_df = loader.get_table("Queries")
    queries_df = cast(pd.DataFrame, queries_df[["query_name", "query_example"]])
    queries_df = queries_df.dropna()
    queries_df = queries_df.astype(str)

    prompts_df = loader.get_table("Prompts")
    prompts_df = cast(pd.DataFrame, prompts_df[["name", "text"]])
    prompts_df = prompts_df.dropna()
    prompts_df = prompts_df.astype(str)

    conversation_lifecycle_df = loader.get_table("ConversationLifecycle")
    conversation_lifecycle_df = cast(pd.DataFrame, conversation_lifecycle_df[["event", "text"]])
    conversation_lifecycle_df = conversation_lifecycle_df.dropna()
    conversation_lifecycle_df = conversation_lifecycle_df.astype(str)

    yield anchors_df, anchor_attrs_df, link_attrs_df, links_df, queries_df, prompts_df, conversation_lifecycle_df

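# Editor's illustrative sketch (not part of the published module): get_data_model() is a
# single-shot generator, so callers typically unpack its only yield with next(); the seven
# frames come back in the order listed in the yield above.
def _data_model_example() -> None:
    (
        anchors_df,
        anchor_attrs_df,
        link_attrs_df,
        links_df,
        queries_df,
        prompts_df,
        conversation_lifecycle_df,
    ) = next(get_data_model())
    logger.info("Anchors in data model: %s", anchors_df["noun"].tolist())
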
def get_grist_data() -> Iterator[tuple[pd.DataFrame, pd.DataFrame]]:
    """
    Fetch all anchors and links from Grist into node/edge tables
    """

    # Build necessary DataModel elements from input tables
    dm_anchors_df, dm_anchor_attrs_df, dm_link_attrs_df, dm_links_df, _q, _p, _cl = next(get_data_model())

    # Anchors
    dm_anchors: dict[str, Anchor] = {}
    for _, a_row in dm_anchors_df.iterrows():
        noun = str(a_row.get("noun")).strip()
        if not noun:
            continue
        dm_anchors[noun] = Anchor(
            noun=noun,
            description=a_row.get("description", ""),
            id_example=a_row.get("id_example", ""),
            query=a_row.get("query", ""),
            attributes=[],
        )

    # Anchor attributes
    for _, attr_row in dm_anchor_attrs_df.iterrows():
        noun = str(attr_row.get("anchor")).strip()
        if not noun or noun not in dm_anchors:
            continue
        dm_anchors[noun].attributes.append(
            Attribute(
                name=attr_row.get("attribute_name", ""),
                description=attr_row.get("description", ""),
                example=attr_row.get("data_example", ""),
                dtype=attr_row.get("dtype", ""),
                query=attr_row.get("query", ""),
                embeddable=bool(attr_row.get("embeddable", False)),
                embed_threshold=float(attr_row.get("embed_threshold", 1.0)),
            )
        )

    # Links
    dm_links: dict[str, Link] = {}
    for _, l_row in dm_links_df.iterrows():
        a1 = str(l_row.get("anchor1")).strip()
        a2 = str(l_row.get("anchor2")).strip()
        if not a1 or not a2 or a1 not in dm_anchors or a2 not in dm_anchors:
            logger.error(f'Link type has invalid anchors "{a1} - {a2}", skipping')
            continue
        dm_links[l_row.get("sentence")] = Link(
            anchor_from=dm_anchors[a1],
            anchor_to=dm_anchors[a2],
            sentence=l_row.get("sentence"),
            description=l_row.get("description", ""),
            query=l_row.get("query", ""),
            attributes=[],
            has_direction=bool(l_row.get("has_direction", False)),
            anchor_from_link_attr_name=l_row.get("anchor1_link_column_name", ""),
            anchor_to_link_attr_name=l_row.get("anchor2_link_column_name", ""),
        )

    # Link attributes
    for _, lattr_row in dm_link_attrs_df.iterrows():
        sent = str(lattr_row.get("link")).strip()
        if sent not in dm_links:
            continue
        dm_links[sent].attributes.append(
            Attribute(
                name=str(lattr_row.get("attribute_name")),
                description=str(lattr_row.get("description", "")),
                example=str(lattr_row.get("data_example", "")),
                dtype=str(lattr_row.get("dtype", "")),
                query=str(lattr_row.get("query", "")),
                embeddable=bool(lattr_row.get("embeddable", False)),
                embed_threshold=float(lattr_row.get("embed_threshold", 1.0)),
            )
        )

    # Get data from Grist

    dp = GristAPIDataProvider(
        doc_id=core_settings.grist_data_doc_id,
        grist_server=core_settings.grist_server_url,
        api_key=core_settings.grist_api_key,
    )

    # Foreign key type links
    fk_link_records_from = []
    fk_link_records_to = []

    # Nodes
    node_records: dict[str, Any] = {}
    anchor_types = dp.get_anchor_tables()  # does not check data model! only lists tables that are named anchor_...
    logger.debug(f"Fetching {len(anchor_types)} anchor tables from Grist: {anchor_types}")

    for anchor_type in anchor_types:
        # check anchor's existence in data model
        dm_anchor = dm_anchors.get(anchor_type)
        if not dm_anchor:
            logger.error(f'Anchor "{anchor_type}" not described in data model, skipping')
            continue
        dm_anchor_attrs = [attr.name for attr in dm_anchor.attributes]

        # get anchor's links
        # todo check link column directions
        anchor_from_link_cols = [
            link
            for link in dm_links.values()
            if link.anchor_from.noun == anchor_type and link.anchor_from_link_attr_name
        ]
        anchor_to_link_cols = [
            link for link in dm_links.values() if link.anchor_to.noun == anchor_type and link.anchor_to_link_attr_name
        ]

        try:
            anchors = dp.get_anchors(anchor_type, dm_attrs=dm_anchor.attributes, dm_anchor_links=anchor_from_link_cols)
        except Exception as exc:
            logger.exception(f"Failed to fetch anchors for type {anchor_type}: {exc}")
            continue

        node_records[anchor_type] = {}

        for a in anchors:
            for link in anchor_from_link_cols:
                # get link other end id(s)
                link_ids = a.data.get(link.anchor_from_link_attr_name, [])
                if isinstance(link_ids, (int, str)):
                    link_ids = [link_ids]
                elif link_ids is None:
                    continue

                for to_dp_id in link_ids:
                    fk_link_records_from.append(
                        {
                            "from_node_id": a.id,
                            "from_node_dp_id": a.dp_id,
                            # "to_node_id": link.id_to, <-- not provided here
                            "to_node_dp_id": to_dp_id,
                            "from_node_type": anchor_type,
                            "to_node_type": link.anchor_to.noun,
                            "edge_label": link.sentence,
                            # "attributes": {},  # not present in data (format not specified) yet
                        }
                    )

            for link in anchor_to_link_cols:
                # get link other end id(s)
                link_ids = a.data.get(link.anchor_to_link_attr_name, [])
                if isinstance(link_ids, (int, str)):
                    link_ids = [link_ids]
                elif link_ids is None:
                    continue

                for from_dp_id in link_ids:
                    fk_link_records_to.append(
                        {
                            # "from_node_id": link.id_from, <-- not provided here
                            "from_node_dp_id": from_dp_id,
                            "to_node_id": a.id,
                            "to_node_dp_id": a.dp_id,
                            "from_node_type": link.anchor_from.noun,
                            "to_node_type": anchor_type,
                            "edge_label": link.sentence,
                            # "attributes": {},  # not present in data (format not specified) yet
                        }
                    )

            node_records[anchor_type][a.dp_id] = {
                "node_id": a.id,
                "node_type": a.type,
                "attributes": {k: v for k, v in a.data.items() if k in dm_anchor_attrs} or {},
            }

    nodes_df = pd.DataFrame(
        [
            {"node_id": rec.get("node_id"), "node_type": rec.get("node_type"), "attributes": rec.get("attributes", {})}
            for a in node_records.values()
            for rec in a.values()
        ],
        columns=["node_id", "node_type", "attributes"],
    )

    # Resolve links (database id <-> our id), if necessary
    for lk in fk_link_records_to:
        if isinstance(lk["from_node_dp_id"], int):
            lk["from_node_id"] = node_records[lk["from_node_type"]].get(lk["from_node_dp_id"], {}).get("node_id")
        else:
            lk["from_node_id"] = lk["from_node_dp_id"]  # <-- str dp_id is an already correct id
    for lk in fk_link_records_from:
        if isinstance(lk["to_node_dp_id"], int):
            lk["to_node_id"] = node_records[lk["to_node_type"]].get(lk["to_node_dp_id"], {}).get("node_id")
        else:
            lk["to_node_id"] = lk["to_node_dp_id"]

    if fk_link_records_to:
        fk_links_to_df = pd.DataFrame(fk_link_records_to).dropna(subset=["from_node_id", "to_node_id"])
    else:
        fk_links_to_df = pd.DataFrame(
            columns=["from_node_id", "to_node_id", "from_node_type", "to_node_type", "edge_label"]
        )

    if fk_link_records_from:
        fk_links_from_df = pd.DataFrame(fk_link_records_from).dropna(subset=["from_node_id", "to_node_id"])
    else:
        fk_links_from_df = pd.DataFrame(
            columns=["from_node_id", "to_node_id", "from_node_type", "to_node_type", "edge_label"]
        )

    fk_df = pd.concat([fk_links_from_df, fk_links_to_df], axis=0, ignore_index=True)
    fk_df["attributes"] = [dict()] * fk_df.shape[0]
    fk_df = fk_df[["from_node_id", "to_node_id", "from_node_type", "to_node_type", "edge_label", "attributes"]]

    # keep only links with both nodes present (+done in the end on edges_df); todo add test for this case
    fk_df = fk_df.loc[(fk_df["from_node_id"].isin(nodes_df["node_id"]) & fk_df["to_node_id"].isin(nodes_df["node_id"]))]

    # Edges
    edge_records = []
    link_types = dp.get_link_tables()
    logger.debug(f"Fetching {len(link_types)} link types from Grist: {link_types}")

    for link_type in link_types:
        # check link's existence in data model (dm_link is used from anchor_from / to references only)
        dm_link_list = [
            link
            for link in dm_links.values()
            if link.sentence.lower() == link_type.lower()
            or link_type.lower() == f"{link.anchor_from.noun}_{link.anchor_to.noun}".lower()
        ]
        if not dm_link_list:
            logger.error(f'Link type "{link_type}" not described in data model, skipping')
            continue
        dm_link = dm_link_list[0]
        dm_link_attrs = [a.name for a in dm_link.attributes]

        try:
            links = dp.get_links(link_type, dm_link)
        except Exception as exc:
            logger.error(f"Failed to fetch links for type {link_type}: {exc}")
            continue

        for link_record in links:
            id_from = link_record.id_from
            id_to = link_record.id_to

            # resolve foreign key link_record id's
            if isinstance(id_from, int):
                id_from = node_records[link_record.anchor_from].get(id_from, {}).get("node_id")
            if isinstance(id_to, int):
                id_to = node_records[link_record.anchor_to].get(id_to, {}).get("node_id")

            edge_records.append(
                {
                    "from_node_id": id_from,
                    "to_node_id": id_to,
                    "from_node_type": link_record.anchor_from,
                    "to_node_type": link_record.anchor_to,
                    "edge_label": link_record.type,
                    "attributes": {k: v for k, v in link_record.data.items() if k in dm_link_attrs} or {},
                }
            )

    edges_df = pd.DataFrame(edge_records)
    edges_df = edges_df.loc[
        (edges_df["from_node_id"].isin(nodes_df["node_id"]) & edges_df["to_node_id"].isin(nodes_df["node_id"]))
    ]

    edges_df = pd.concat([edges_df, fk_df], ignore_index=True)

    # add reverse links (if already provided in data, duplicates will be removed later)
    for link in dm_links.values():
        if not link.has_direction:
            rev_edges = cast(
                pd.DataFrame,
                edges_df.loc[
                    (
                        (
                            (edges_df["from_node_type"] == link.anchor_from.noun)
                            & (edges_df["to_node_type"] == link.anchor_to.noun)
                        )
                        | (  # edges with anchors written in reverse are also valid
                            (edges_df["from_node_type"] == link.anchor_to.noun)
                            & (edges_df["to_node_type"] == link.anchor_from.noun)
                        )
                    )
                    & (edges_df["edge_label"] == link.sentence)
                ].copy(),
            )
            if not rev_edges.empty:
                rev_edges = rev_edges.rename(
                    columns={
                        "from_node_id": "to_node_id",
                        "to_node_id": "from_node_id",
                        "from_node_type": "to_node_type",
                        "to_node_type": "from_node_type",
                    }
                )
                edges_df = pd.concat([edges_df, rev_edges], ignore_index=True)

    # preventive drop_duplicates / na records
    if not nodes_df.empty:
        nodes_df = nodes_df.dropna(subset=["node_id", "node_type"]).drop_duplicates(subset=["node_id"])
    if not edges_df.empty:
        edges_df = edges_df.dropna(subset=["from_node_id", "to_node_id", "edge_label"]).drop_duplicates(
            subset=["from_node_id", "to_node_id", "edge_label"]
        )
    yield nodes_df, edges_df

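# Editor's illustrative sketch (not part of the published module): the generator above yields a
# single (nodes_df, edges_df) pair; nodes_df carries ["node_id", "node_type", "attributes"] and
# edges_df carries ["from_node_id", "to_node_id", "from_node_type", "to_node_type",
# "edge_label", "attributes"].
def _grist_data_example() -> None:
    nodes_df, edges_df = next(get_grist_data())
    logger.info("Fetched %d nodes and %d edges", len(nodes_df), len(edges_df))
    logger.info("Node types: %s", sorted(nodes_df["node_type"].unique()))
    logger.info("Edge labels: %s", sorted(edges_df["edge_label"].unique()))
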
def ensure_memgraph_node_indexes(dm_anchor_attrs: pd.DataFrame) -> pd.DataFrame:
    """
    Create label / vector indices
    https://memgraph.com/docs/querying/vector-search
    """

    anchor_types: set[str] = set(dm_anchor_attrs["anchor"].dropna().unique())

    # embeddable attrs for vector indices
    # vec_a_attr_rows = dm_anchor_attrs[dm_anchor_attrs["embeddable"]]  # & (dm_anchor_attrs["dtype"].str.lower() == "str")]

    driver = GraphDatabase.driver(
        uri=core_settings.memgraph_uri,
        auth=(
            core_settings.memgraph_user,
            core_settings.memgraph_pwd,
        ),
    )

    with driver.session() as session:
        # Indices on Anchors
        for label in anchor_types:
            try:
                session.run(f"CREATE INDEX ON :`{label}`(id)")  # type: ignore
            except Exception as exc:
                logger.debug(f"CREATE INDEX failed for label {label}: {exc}")  # probably index exists

            try:
                session.run(f"CREATE CONSTRAINT ON (n:`{label}`) ASSERT n.id IS UNIQUE")  # type: ignore
            except Exception as exc:
                logger.debug(f"CREATE CONSTRAINT failed for label {label}: {exc}")  # probably index exists

        # Vector indices - Anchors. Deprecated due to move to pgvectorstore
        # for _, row in vec_a_attr_rows.iterrows():
        #     attr: str = row["attribute_name"]
        #     embeddings_dim = core_settings.embeddings_dim
        #     label = row["anchor"]
        #
        #     idx_name = f"{label}_{attr}_embed_idx".replace(" ", "_")
        #     prop_name = f"{attr}_embedding"
        #
        #     cypher = (
        #         f"CREATE VECTOR INDEX `{idx_name}` ON :`{label}`(`{prop_name}`) "
        #         f'WITH CONFIG {{"dimension": {embeddings_dim}, "capacity": 1024, "metric": "cos"}}'
        #     )
        #     try:
        #         session.run(cypher)  # type: ignore
        #     except Exception as exc:
        #         logger.debug(f"CREATE VECTOR INDEX failed for {idx_name}: {exc}")  # probably index exists
        #         continue

    driver.close()

    # nominal outputs
    mg_anchor_indexes = pd.DataFrame({"anchor": list(anchor_types)})
    # mg_anchor_vector_indexes = vec_a_attr_rows[["anchor", "attribute_name"]].copy()
    return mg_anchor_indexes

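# Editor's illustrative sketch (not part of the published module): the helper above only needs
# the "anchor" column of the Anchor_attributes frame, so wiring it to get_data_model() can look
# like the following (requires a reachable Memgraph instance configured in core_settings).
def _node_indexes_example() -> None:
    _, dm_anchor_attrs_df, _, _, _, _, _ = next(get_data_model())
    created = ensure_memgraph_node_indexes(dm_anchor_attrs_df)
    logger.info("Ensured id indexes for anchors: %s", created["anchor"].tolist())
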
def ensure_memgraph_edge_indexes(dm_link_attrs: pd.DataFrame) -> pd.DataFrame:
    """
    Create label / vector indices
    https://memgraph.com/docs/querying/vector-search
    """

    link_types: set[str] = set(dm_link_attrs["link"].dropna().unique())

    # embeddable attrs for vector indices
    # vec_l_attr_rows = dm_link_attrs[dm_link_attrs["embeddable"]]  # & (dm_link_attrs["dtype"].str.lower() == "str")]

    driver = GraphDatabase.driver(
        uri=core_settings.memgraph_uri,
        auth=(
            core_settings.memgraph_user,
            core_settings.memgraph_pwd,
        ),
    )

    with driver.session() as session:
        # Indices on Edges (optimizes queries such as MATCH ()-[r:EDGE_TYPE]->() RETURN r;)
        # If queried by edge property, will need to add property index (similar to above for Anchor)
        for label in link_types:
            try:
                session.run(f"CREATE EDGE INDEX ON :`{label}`")  # type: ignore
            except Exception as exc:
                logger.debug(f"CREATE EDGE INDEX failed for label {label}: {exc}")  # probably index exists

            # todo edge constraints?
            # try:
            #     session.run(f"CREATE CONSTRAINT ON (n:`{label}`) ASSERT n.id IS UNIQUE")  # type: ignore
            # except Exception as exc:
            #     logger.debug(f"CREATE CONSTRAINT failed for label {label}: {exc}")  # probably index exists

        # Vector indices - Edges. Deprecated due to move to pgvectorstore
        # for _, row in vec_l_attr_rows.iterrows():
        #     attr: str = row["attribute_name"]
        #     embeddings_dim = core_settings.embeddings_dim
        #     label = row["link"]
        #
        #     idx_name = f"{label}_{attr}_embed_idx".replace(" ", "_")
        #     prop_name = f"{attr}_embedding"
        #
        #     cypher = (
        #         f"CREATE VECTOR EDGE INDEX `{idx_name}` ON :`{label}`(`{prop_name}`) "
        #         f'WITH CONFIG {{"dimension": {embeddings_dim}, "capacity": 1024, "metric": "cos"}}'
        #     )
        #     try:
        #         session.run(cypher)  # type: ignore
        #     except Exception as exc:
        #         logger.debug(f"CREATE VECTOR EDGE INDEX failed for {idx_name}: {exc}")  # probably index exists
        #         continue

    driver.close()

    # nominal outputs
    mg_link_indexes = pd.DataFrame({"link": list(link_types)})
    # mg_link_vector_indexes = vec_l_attr_rows[["link", "attribute_name"]].copy()
    return mg_link_indexes

def generate_embeddings(
    df: pd.DataFrame,
    dm_attributes: pd.DataFrame,
) -> pd.DataFrame:
    """Generate embeddings for embeddable text attributes"""
    type_col = "node_type" if "node_id" in df.columns else "edge_label"
    pkeys = ["node_id", "node_type"] if type_col == "node_type" else ["from_node_id", "to_node_id", "edge_label"]
    dm_attributes = dm_attributes[dm_attributes["embeddable"]]  # & (dm_attributes["dtype"].str.lower() == "str")]

    # Build mapping type -> list[attribute_name] that need embedding
    mapping: dict[str, list[str]] = {}
    for _, row in dm_attributes.iterrows():
        record_type = row["anchor"] if type_col == "node_type" else row["link"]
        if pd.isna(record_type):
            continue
        mapping.setdefault(record_type, []).append(row["attribute_name"])

    tasks: list[tuple[int, str, str]] = []  # (row_pos, attr_name, text)
    for pos, (_, row) in enumerate(df.iterrows()):
        typ_val = row[type_col]
        attrs_needed = mapping.get(typ_val)
        if not attrs_needed:
            continue
        row_attrs = row.get("attributes", {})
        for attr_name in attrs_needed:
            text_val = row_attrs.get(attr_name)
            if text_val and isinstance(text_val, str) and not is_uuid(text_val):
                tasks.append((pos, attr_name, text_val))

    if not tasks:
        return df

    provider = LLMProvider()

    unique_texts = list(dict.fromkeys([t[2] for t in tasks]))
    vectors = provider.create_embeddings_sync(unique_texts)
    vector_by_text = dict(zip(unique_texts, vectors))

    # Apply embeddings
    rows: list[dict[str, object]] = []
    for row_pos, attr_name, text in tasks:
        vec = vector_by_text[text]
        rows.append(
            {
                **df.loc[row_pos, pkeys],
                "attribute_name": attr_name,
                "attribute_value": text,
                "embedding": vec,
            }
        )

    return pd.DataFrame.from_records(rows)

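# Editor's illustrative sketch (not part of the published module): for a node frame the output
# is one row per embedded attribute value, keyed by the node primary keys. The frames below are
# hypothetical, and running this requires a configured LLMProvider embeddings backend.
def _generate_embeddings_example() -> None:
    nodes_df = pd.DataFrame(
        [{"node_id": "p1", "node_type": "person", "attributes": {"description": "Enjoys hiking"}}]
    )
    dm_attrs = pd.DataFrame(
        [
            {
                "anchor": "person",
                "attribute_name": "description",
                "embeddable": True,
                "dtype": "str",
                "embed_threshold": 0.8,
            }
        ]
    )
    emb_df = generate_embeddings(nodes_df, dm_attrs)
    # expected columns: node_id, node_type, attribute_name, attribute_value, embedding
    logger.info("Embedding rows: %d", len(emb_df))
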
def pass_df_to_memgraph(
    df: pd.DataFrame,
) -> pd.DataFrame:
    """dummy method for BatchTransform to MemgraphStore"""
    return df.copy()


# ---
# Custom parts


def prepare_nodes(
    grist_nodes_df: pd.DataFrame,
) -> pd.DataFrame:
    return grist_nodes_df.copy()


def prepare_edges(
    grist_edges_df: pd.DataFrame,
) -> pd.DataFrame:
    return grist_edges_df.copy()

def get_eval_gds_from_grist() -> Iterator[pd.DataFrame]:
    """
    Loads evaluation dataset and config rows

    Output:
    - eval_gds(gds_question, gds_answer, question_context)
    """
    dp = GristAPIDataProvider(
        doc_id=etl_settings.grist_test_set_doc_id,
        grist_server=core_settings.grist_server_url,
        api_key=core_settings.grist_api_key,
    )

    try:
        gds_df = dp.get_table(etl_settings.gds_table_name)
    except Exception as e:
        logger.exception(f"Failed to get golden dataset {etl_settings.gds_table_name}: {e}")
        raise e

    gds_df = gds_df.dropna(subset=["gds_question", "gds_answer"]).copy()
    gds_df = gds_df.loc[(gds_df["gds_question"] != "") & (gds_df["gds_answer"] != "")]
    gds_df = gds_df[["gds_question", "gds_answer", "question_scenario", "question_comment", "question_context"]].astype(
        {"gds_question": str, "gds_answer": str, "question_scenario": str, "question_comment": str}
    )
    yield gds_df

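# Editor's illustrative sketch (not part of the published module): consuming the golden-dataset
# generator above; the yielded frame carries question/answer pairs plus scenario, comment and
# context columns.
def _eval_gds_example() -> None:
    gds_df = next(get_eval_gds_from_grist())
    logger.info("Golden dataset rows: %d", len(gds_df))
    for _, row in gds_df.head(3).iterrows():
        logger.info("Q: %s | A: %s", row["gds_question"], row["gds_answer"])
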
if __name__ == "__main__":
    # import dotenv
    # dotenv.load_dotenv()
    # n, _l = next(get_grist_data())
    ...
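
# Editor's illustrative sketch (not part of the published module): a by-hand run of the step
# functions defined above, in the order the package's pipeline presumably wires them
# (vedana_etl/pipeline.py is not shown here, so the ordering is an assumption). Requires Grist,
# Memgraph and embedding credentials in core_settings / etl_settings.
def _run_steps_by_hand() -> None:
    anchors, anchor_attrs, link_attrs, links, _queries, _prompts, _lifecycle = next(get_data_model())
    nodes_df, edges_df = next(get_grist_data())
    ensure_memgraph_node_indexes(anchor_attrs)
    ensure_memgraph_edge_indexes(link_attrs)
    node_embeddings = generate_embeddings(prepare_nodes(nodes_df), anchor_attrs)
    edge_embeddings = generate_embeddings(prepare_edges(edges_df), link_attrs)
    logger.info(
        "nodes=%d edges=%d node_emb=%d edge_emb=%d",
        len(nodes_df),
        len(edges_df),
        len(node_embeddings),
        len(edge_embeddings),
    )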