vedana-etl 0.1.0.dev3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
vedana_etl/steps.py ADDED
@@ -0,0 +1,685 @@
1
+ import logging
2
+ import re
3
+ from typing import Any, Iterator, cast
4
+ from unicodedata import normalize
5
+ from uuid import UUID
6
+
7
+ import pandas as pd
8
+ from jims_core.llms.llm_provider import LLMProvider
9
+ from neo4j import GraphDatabase
10
+ from vedana_core.data_model import Anchor, Attribute, Link
11
+ from vedana_core.data_provider import GristAPIDataProvider, GristCsvDataProvider
12
+ from vedana_core.settings import settings as core_settings
13
+
14
+ from vedana_etl.settings import settings as etl_settings
15
+
16
+ # pd.replace() emits FutureWarnings due to silent type downcasting; the behavior only changes in pandas 3.0
17
+ # https://github.com/pandas-dev/pandas/issues/57734
18
+ pd.set_option("future.no_silent_downcasting", True)
19
+
20
+ logging.basicConfig(level=logging.DEBUG)
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
+ def is_uuid(val: str) -> bool:
25
+ try:
26
+ UUID(str(val))
27
+ return True
28
+ except (ValueError, TypeError):
29
+ return False
30
+
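+ # Usage sketch (illustrative): is_uuid accepts anything uuid.UUID can parse.
+ #   is_uuid("00000000-0000-0000-0000-000000000000")  -> True
+ #   is_uuid("row-17")                                 -> False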
31
+
32
+ def clean_str(text: str) -> str:
33
+ if not isinstance(text, str):
34
+ return text
35
+ text = normalize("NFC", text)
36
+
37
+ # Replace non-breaking spaces and other space-like Unicode chars with regular space
38
+ text = re.sub(r"[\u00A0\u2000-\u200B\u202F\u205F\u3000]", " ", text)
39
+
40
+ # Remove zero-width spaces and BOMs
41
+ text = re.sub(r"[\u200B\u200C\u200D\uFEFF]", "", text)
42
+
43
+ # Collapse multiple spaces
44
+ text = re.sub(r"\s+", " ", text)
45
+ return text.strip()
46
+
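+ # Usage sketch (illustrative): non-breaking and zero-width characters are normalized away.
+ #   clean_str("Hello\u00a0 world\u200b")  -> "Hello world"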
47
+
48
+ def get_data_model() -> Iterator[
49
+ tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame, pd.DataFrame]
50
+ ]:
51
+ loader = GristCsvDataProvider(
52
+ doc_id=core_settings.grist_data_model_doc_id,
53
+ grist_server=core_settings.grist_server_url,
54
+ api_key=core_settings.grist_api_key,
55
+ )
56
+
57
+ _links_df = loader.get_table("Links")
58
+ links_df = cast(
59
+ pd.DataFrame,
60
+ _links_df[
61
+ [
62
+ "anchor1",
63
+ "anchor2",
64
+ "sentence",
65
+ "description",
66
+ "query",
67
+ "anchor1_link_column_name",
68
+ "anchor2_link_column_name",
69
+ "has_direction",
70
+ ]
71
+ ],
72
+ )
73
+ assert links_df is not None
74
+
75
+ links_df["has_direction"] = _links_df["has_direction"].astype(bool)
76
+ links_df = links_df.dropna(subset=["anchor1", "anchor2", "sentence"], inplace=False)
77
+
78
+ anchor_attrs_df = loader.get_table("Anchor_attributes")
79
+ anchor_attrs_df = cast(
80
+ pd.DataFrame,
81
+ anchor_attrs_df[
82
+ [
83
+ "anchor",
84
+ "attribute_name",
85
+ "description",
86
+ "data_example",
87
+ "embeddable",
88
+ "query",
89
+ "dtype",
90
+ "embed_threshold",
91
+ ]
92
+ ],
93
+ )
94
+ anchor_attrs_df["embeddable"] = anchor_attrs_df["embeddable"].astype(bool)
95
+ anchor_attrs_df["embed_threshold"] = anchor_attrs_df["embed_threshold"].astype(float)
96
+ anchor_attrs_df = anchor_attrs_df.dropna(subset=["anchor", "attribute_name"], how="any")
97
+
98
+ link_attrs_df = loader.get_table("Link_attributes")
99
+ link_attrs_df = cast(
100
+ pd.DataFrame,
101
+ link_attrs_df[
102
+ [
103
+ "link",
104
+ "attribute_name",
105
+ "description",
106
+ "data_example",
107
+ "embeddable",
108
+ "query",
109
+ "dtype",
110
+ "embed_threshold",
111
+ ]
112
+ ],
113
+ )
114
+ link_attrs_df["embeddable"] = link_attrs_df["embeddable"].astype(bool)
115
+ link_attrs_df["embed_threshold"] = link_attrs_df["embed_threshold"].astype(float)
116
+ link_attrs_df = link_attrs_df.dropna(subset=["link", "attribute_name"], how="any")
117
+
118
+ anchors_df = loader.get_table("Anchors")
119
+ anchors_df = cast(
120
+ pd.DataFrame,
121
+ anchors_df[
122
+ [
123
+ "noun",
124
+ "description",
125
+ "id_example",
126
+ "query",
127
+ ]
128
+ ],
129
+ )
130
+ anchors_df = anchors_df.dropna(subset=["noun"], inplace=False)
131
+ anchors_df = anchors_df.astype(str)
132
+
133
+ queries_df = loader.get_table("Queries")
134
+ queries_df = cast(pd.DataFrame, queries_df[["query_name", "query_example"]])
135
+ queries_df = queries_df.dropna()
136
+ queries_df = queries_df.astype(str)
137
+
138
+ prompts_df = loader.get_table("Prompts")
139
+ prompts_df = cast(pd.DataFrame, prompts_df[["name", "text"]])
140
+ prompts_df = prompts_df.dropna()
141
+ prompts_df = prompts_df.astype(str)
142
+
143
+ conversation_lifecycle_df = loader.get_table("ConversationLifecycle")
144
+ conversation_lifecycle_df = cast(pd.DataFrame, conversation_lifecycle_df[["event", "text"]])
145
+ conversation_lifecycle_df = conversation_lifecycle_df.dropna()
146
+ conversation_lifecycle_df = conversation_lifecycle_df.astype(str)
147
+
148
+ yield anchors_df, anchor_attrs_df, link_attrs_df, links_df, queries_df, prompts_df, conversation_lifecycle_df
149
+
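+ # Usage sketch (illustrative): the generator yields a single tuple of data-model tables.
+ #   anchors, anchor_attrs, link_attrs, links, queries, prompts, lifecycle = next(get_data_model())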
150
+
151
+ def get_grist_data() -> Iterator[tuple[pd.DataFrame, pd.DataFrame]]:
152
+ """
153
+ Fetch all anchors and links from Grist into node/edge tables
154
+ """
155
+
156
+ # Build necessary DataModel elements from input tables
157
+ dm_anchors_df, dm_anchor_attrs_df, dm_link_attrs_df, dm_links_df, _q, _p, _cl = next(get_data_model())
158
+
159
+ # Anchors
160
+ dm_anchors: dict[str, Anchor] = {}
161
+ for _, a_row in dm_anchors_df.iterrows():
162
+ noun = str(a_row.get("noun")).strip()
163
+ if not noun:
164
+ continue
165
+ dm_anchors[noun] = Anchor(
166
+ noun=noun,
167
+ description=a_row.get("description", ""),
168
+ id_example=a_row.get("id_example", ""),
169
+ query=a_row.get("query", ""),
170
+ attributes=[],
171
+ )
172
+
173
+ # Anchor attributes
174
+ for _, attr_row in dm_anchor_attrs_df.iterrows():
175
+ noun = str(attr_row.get("anchor")).strip()
176
+ if not noun or noun not in dm_anchors:
177
+ continue
178
+ dm_anchors[noun].attributes.append(
179
+ Attribute(
180
+ name=attr_row.get("attribute_name", ""),
181
+ description=attr_row.get("description", ""),
182
+ example=attr_row.get("data_example", ""),
183
+ dtype=attr_row.get("dtype", ""),
184
+ query=attr_row.get("query", ""),
185
+ embeddable=bool(attr_row.get("embeddable", False)),
186
+ embed_threshold=float(attr_row.get("embed_threshold", 1.0)),
187
+ )
188
+ )
189
+
190
+ # Links
191
+ dm_links: dict[str, Link] = {}
192
+ for _, l_row in dm_links_df.iterrows():
193
+ a1 = str(l_row.get("anchor1")).strip()
194
+ a2 = str(l_row.get("anchor2")).strip()
195
+ if not a1 or not a2 or a1 not in dm_anchors or a2 not in dm_anchors:
196
+ logger.error(f'Link type has invalid anchors "{a1} - {a2}", skipping')
197
+ continue
198
+ dm_links[l_row.get("sentence")] = Link(
199
+ anchor_from=dm_anchors[a1],
200
+ anchor_to=dm_anchors[a2],
201
+ sentence=l_row.get("sentence"),
202
+ description=l_row.get("description", ""),
203
+ query=l_row.get("query", ""),
204
+ attributes=[],
205
+ has_direction=bool(l_row.get("has_direction", False)),
206
+ anchor_from_link_attr_name=l_row.get("anchor1_link_column_name", ""),
207
+ anchor_to_link_attr_name=l_row.get("anchor2_link_column_name", ""),
208
+ )
209
+
210
+ # Link attributes
211
+ for _, lattr_row in dm_link_attrs_df.iterrows():
212
+ sent = str(lattr_row.get("link")).strip()
213
+ if sent not in dm_links:
214
+ continue
215
+ dm_links[sent].attributes.append(
216
+ Attribute(
217
+ name=str(lattr_row.get("attribute_name")),
218
+ description=str(lattr_row.get("description", "")),
219
+ example=str(lattr_row.get("data_example", "")),
220
+ dtype=str(lattr_row.get("dtype", "")),
221
+ query=str(lattr_row.get("query", "")),
222
+ embeddable=bool(lattr_row.get("embeddable", False)),
223
+ embed_threshold=float(lattr_row.get("embed_threshold", 1.0)),
224
+ )
225
+ )
226
+
227
+ # Get data from Grist
228
+
229
+ dp = GristAPIDataProvider(
230
+ doc_id=core_settings.grist_data_doc_id,
231
+ grist_server=core_settings.grist_server_url,
232
+ api_key=core_settings.grist_api_key,
233
+ )
234
+
235
+ # Foreign key type links
236
+ fk_link_records_from = []
237
+ fk_link_records_to = []
238
+
239
+ # Nodes
240
+ node_records: dict[str, Any] = {}
241
+ anchor_types = dp.get_anchor_tables() # does not check data model! only lists tables that are named anchor_...
242
+ logger.debug(f"Fetching {len(anchor_types)} anchor tables from Grist: {anchor_types}")
243
+
244
+ for anchor_type in anchor_types:
245
+ # check anchor's existence in data model
246
+ dm_anchor = dm_anchors.get(anchor_type)
247
+ if not dm_anchor:
248
+ logger.error(f'Anchor "{anchor_type}" not described in data model, skipping')
249
+ continue
250
+ dm_anchor_attrs = [attr.name for attr in dm_anchor.attributes]
251
+
252
+ # get anchor's links
253
+ # todo check link column directions
254
+ anchor_from_link_cols = [
255
+ link
256
+ for link in dm_links.values()
257
+ if link.anchor_from.noun == anchor_type and link.anchor_from_link_attr_name
258
+ ]
259
+ anchor_to_link_cols = [
260
+ link for link in dm_links.values() if link.anchor_to.noun == anchor_type and link.anchor_to_link_attr_name
261
+ ]
262
+
263
+ try:
264
+ anchors = dp.get_anchors(anchor_type, dm_attrs=dm_anchor.attributes, dm_anchor_links=anchor_from_link_cols)
265
+ except Exception as exc:
266
+ logger.exception(f"Failed to fetch anchors for type {anchor_type}: {exc}")
267
+ continue
268
+
269
+ node_records[anchor_type] = {}
270
+
271
+ for a in anchors:
272
+ for link in anchor_from_link_cols:
273
+ # get link other end id(s)
274
+ link_ids = a.data.get(link.anchor_from_link_attr_name, [])
275
+ if isinstance(link_ids, (int, str)):
276
+ link_ids = [link_ids]
277
+ elif link_ids is None:
278
+ continue
279
+
280
+ for to_dp_id in link_ids:
281
+ fk_link_records_from.append(
282
+ {
283
+ "from_node_id": a.id,
284
+ "from_node_dp_id": a.dp_id,
285
+ # "to_node_id": link.id_to, <-- not provided here
286
+ "to_node_dp_id": to_dp_id,
287
+ "from_node_type": anchor_type,
288
+ "to_node_type": link.anchor_to.noun,
289
+ "edge_label": link.sentence,
290
+ # "attributes": {}, # not present in data (format not specified) yet
291
+ }
292
+ )
293
+
294
+ for link in anchor_to_link_cols:
295
+ # get link other end id(s)
296
+ link_ids = a.data.get(link.anchor_to_link_attr_name, [])
297
+ if isinstance(link_ids, (int, str)):
298
+ link_ids = [link_ids]
299
+ elif link_ids is None:
300
+ continue
301
+
302
+ for from_dp_id in link_ids:
303
+ fk_link_records_to.append(
304
+ {
305
+ # "from_node_id": link.id_from, <-- not provided here
306
+ "from_node_dp_id": from_dp_id,
307
+ "to_node_id": a.id,
308
+ "to_node_dp_id": a.dp_id,
309
+ "from_node_type": link.anchor_from.noun,
310
+ "to_node_type": anchor_type,
311
+ "edge_label": link.sentence,
312
+ # "attributes": {}, # not present in data (format not specified) yet
313
+ }
314
+ )
315
+
316
+ node_records[anchor_type][a.dp_id] = {
317
+ "node_id": a.id,
318
+ "node_type": a.type,
319
+ "attributes": {k: v for k, v in a.data.items() if k in dm_anchor_attrs} or {},
320
+ }
321
+
322
+ nodes_df = pd.DataFrame(
323
+ [
324
+ {"node_id": rec.get("node_id"), "node_type": rec.get("node_type"), "attributes": rec.get("attributes", {})}
325
+ for a in node_records.values()
326
+ for rec in a.values()
327
+ ],
328
+ columns=["node_id", "node_type", "attributes"],
329
+ )
330
+
331
+ # Resolve links (database id <-> our id), if necessary
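+ # Illustrative example of the resolution below (hypothetical ids): a Grist reference column
+ # holds the numeric row id of the other table, e.g. 17, which is mapped to a node_id via
+ # node_records[<node_type>][17]["node_id"]; a string dp_id is assumed to already be the node_id.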
332
+ for lk in fk_link_records_to:
333
+ if isinstance(lk["from_node_dp_id"], int):
334
+ lk["from_node_id"] = node_records[lk["from_node_type"]].get(lk["from_node_dp_id"], {}).get("node_id")
335
+ else:
336
+ lk["from_node_id"] = lk["from_node_dp_id"] # <-- str dp_id is an already correct id
337
+ for lk in fk_link_records_from:
338
+ if isinstance(lk["to_node_dp_id"], int):
339
+ lk["to_node_id"] = node_records[lk["to_node_type"]].get(lk["to_node_dp_id"], {}).get("node_id")
340
+ else:
341
+ lk["to_node_id"] = lk["to_node_dp_id"]
342
+
343
+ if fk_link_records_to:
344
+ fk_links_to_df = pd.DataFrame(fk_link_records_to).dropna(subset=["from_node_id", "to_node_id"])
345
+ else:
346
+ fk_links_to_df = pd.DataFrame(
347
+ columns=["from_node_id", "to_node_id", "from_node_type", "to_node_type", "edge_label"]
348
+ )
349
+
350
+ if fk_link_records_from:
351
+ fk_links_from_df = pd.DataFrame(fk_link_records_from).dropna(subset=["from_node_id", "to_node_id"])
352
+ else:
353
+ fk_links_from_df = pd.DataFrame(
354
+ columns=["from_node_id", "to_node_id", "from_node_type", "to_node_type", "edge_label"]
355
+ )
356
+
357
+ fk_df = pd.concat([fk_links_from_df, fk_links_to_df], axis=0, ignore_index=True)
358
+ fk_df["attributes"] = [dict()] * fk_df.shape[0]
359
+ fk_df = fk_df[["from_node_id", "to_node_id", "from_node_type", "to_node_type", "edge_label", "attributes"]]
360
+
361
+ # keep only links where both endpoint nodes exist (the same filter is applied to edges_df below); todo add test for this case
362
+ fk_df = fk_df.loc[(fk_df["from_node_id"].isin(nodes_df["node_id"]) & fk_df["to_node_id"].isin(nodes_df["node_id"]))]
363
+
364
+ # Edges
365
+ edge_records = []
366
+ link_types = dp.get_link_tables()
367
+ logger.debug(f"Fetching {len(link_types)} link types from Grist: {link_types}")
368
+
369
+ for link_type in link_types:
370
+ # check the link's existence in the data model: matched by sentence or by the "<anchor1>_<anchor2>" table-name convention (dm_link supplies the anchor_from / anchor_to references)
371
+ dm_link_list = [
372
+ link
373
+ for link in dm_links.values()
374
+ if link.sentence.lower() == link_type.lower()
375
+ or link_type.lower() == f"{link.anchor_from.noun}_{link.anchor_to.noun}".lower()
376
+ ]
377
+ if not dm_link_list:
378
+ logger.error(f'Link type "{link_type}" not described in data model, skipping')
379
+ continue
380
+ dm_link = dm_link_list[0]
381
+ dm_link_attrs = [a.name for a in dm_link.attributes]
382
+
383
+ try:
384
+ links = dp.get_links(link_type, dm_link)
385
+ except Exception as exc:
386
+ logger.error(f"Failed to fetch links for type {link_type}: {exc}")
387
+ continue
388
+
389
+ for link_record in links:
390
+ id_from = link_record.id_from
391
+ id_to = link_record.id_to
392
+
393
+ # resolve foreign key link_record id's
394
+ if isinstance(id_from, int):
395
+ id_from = node_records[link_record.anchor_from].get(id_from, {}).get("node_id")
396
+ if isinstance(id_to, int):
397
+ id_to = node_records[link_record.anchor_to].get(id_to, {}).get("node_id")
398
+
399
+ edge_records.append(
400
+ {
401
+ "from_node_id": id_from,
402
+ "to_node_id": id_to,
403
+ "from_node_type": link_record.anchor_from,
404
+ "to_node_type": link_record.anchor_to,
405
+ "edge_label": link_record.type,
406
+ "attributes": {k: v for k, v in link_record.data.items() if k in dm_link_attrs} or {},
407
+ }
408
+ )
409
+
410
+ edge_cols = ["from_node_id", "to_node_id", "from_node_type", "to_node_type", "edge_label", "attributes"]
+ edges_df = pd.DataFrame(edge_records, columns=edge_cols)  # explicit columns keep the filter below safe when no link rows were fetched
411
+ edges_df = edges_df.loc[
412
+ (edges_df["from_node_id"].isin(nodes_df["node_id"]) & edges_df["to_node_id"].isin(nodes_df["node_id"]))
413
+ ]
414
+
415
+ edges_df = pd.concat([edges_df, fk_df], ignore_index=True)
416
+
417
+ # add reverse links (if already provided in data, duplicates will be removed later)
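+ # Illustrative example (hypothetical labels): for an undirected link "person works_at company",
+ # an edge (person:1 -> company:2) also gets a mirrored copy (company:2 -> person:1) below;
+ # duplicates are dropped by the drop_duplicates call at the end.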
418
+ for link in dm_links.values():
419
+ if not link.has_direction:
420
+ rev_edges = cast(
421
+ pd.DataFrame,
422
+ edges_df.loc[
423
+ (
424
+ (
425
+ (edges_df["from_node_type"] == link.anchor_from.noun)
426
+ & (edges_df["to_node_type"] == link.anchor_to.noun)
427
+ )
428
+ | ( # edges with anchors written in reverse are also valid
429
+ (edges_df["from_node_type"] == link.anchor_to.noun)
430
+ & (edges_df["to_node_type"] == link.anchor_from.noun)
431
+ )
432
+ )
433
+ & (edges_df["edge_label"] == link.sentence)
434
+ ].copy(),
435
+ )
436
+ if not rev_edges.empty:
437
+ rev_edges = rev_edges.rename(
438
+ columns={
439
+ "from_node_id": "to_node_id",
440
+ "to_node_id": "from_node_id",
441
+ "from_node_type": "to_node_type",
442
+ "to_node_type": "from_node_type",
443
+ }
444
+ )
445
+ edges_df = pd.concat([edges_df, rev_edges], ignore_index=True)
446
+
447
+ # preventive drop_duplicates / na records
448
+ if not nodes_df.empty:
449
+ nodes_df = nodes_df.dropna(subset=["node_id", "node_type"]).drop_duplicates(subset=["node_id"])
450
+ if not edges_df.empty:
451
+ edges_df = edges_df.dropna(subset=["from_node_id", "to_node_id", "edge_label"]).drop_duplicates(
452
+ subset=["from_node_id", "to_node_id", "edge_label"]
453
+ )
454
+ yield nodes_df, edges_df
455
+
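+ # Usage sketch (illustrative): the generator yields one (nodes_df, edges_df) pair, where
+ # nodes_df has columns [node_id, node_type, attributes] and edges_df has columns
+ # [from_node_id, to_node_id, from_node_type, to_node_type, edge_label, attributes].
+ #   nodes_df, edges_df = next(get_grist_data())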
456
+
457
+ def ensure_memgraph_node_indexes(dm_anchor_attrs: pd.DataFrame) -> pd.DataFrame:
458
+ """
459
+ Create a label index and uniqueness constraint on id for each anchor label (vector indices are deprecated here, moved to pgvectorstore)
460
+ https://memgraph.com/docs/querying/vector-search
461
+ """
462
+
463
+ anchor_types: set[str] = set(dm_anchor_attrs["anchor"].dropna().unique())
464
+
465
+ # embeddable attrs for vector indices
466
+ # vec_a_attr_rows = dm_anchor_attrs[dm_anchor_attrs["embeddable"]] # & (dm_anchor_attrs["dtype"].str.lower() == "str")]
467
+
468
+ driver = GraphDatabase.driver(
469
+ uri=core_settings.memgraph_uri,
470
+ auth=(
471
+ core_settings.memgraph_user,
472
+ core_settings.memgraph_pwd,
473
+ ),
474
+ )
475
+
476
+ with driver.session() as session:
477
+ # Indices on Anchors
478
+ for label in anchor_types:
479
+ try:
480
+ session.run(f"CREATE INDEX ON :`{label}`(id)") # type: ignore
481
+ except Exception as exc:
482
+ logger.debug(f"CREATE INDEX failed for label {label}: {exc}") # probably index exists
483
+
484
+ try:
485
+ session.run(f"CREATE CONSTRAINT ON (n:`{label}`) ASSERT n.id IS UNIQUE") # type: ignore
486
+ except Exception as exc:
487
+ logger.debug(f"CREATE CONSTRAINT failed for label {label}: {exc}") # probably index exists
488
+
489
+ # Vector indices - Anchors. Deprecated due to move to pgvectorstore
490
+ # for _, row in vec_a_attr_rows.iterrows():
491
+ # attr: str = row["attribute_name"]
492
+ # embeddings_dim = core_settings.embeddings_dim
493
+ # label = row["anchor"]
494
+ #
495
+ # idx_name = f"{label}_{attr}_embed_idx".replace(" ", "_")
496
+ # prop_name = f"{attr}_embedding"
497
+ #
498
+ # cypher = (
499
+ # f"CREATE VECTOR INDEX `{idx_name}` ON :`{label}`(`{prop_name}`) "
500
+ # f'WITH CONFIG {{"dimension": {embeddings_dim}, "capacity": 1024, "metric": "cos"}}'
501
+ # )
502
+ # try:
503
+ # session.run(cypher) # type: ignore
504
+ # except Exception as exc:
505
+ # logger.debug(f"CREATE VECTOR INDEX failed for {idx_name}: {exc}") # probably index exists
506
+ # continue
507
+
508
+ driver.close()
509
+
510
+ # nominal outputs
511
+ mg_anchor_indexes = pd.DataFrame({"anchor": list(anchor_types)})
512
+ # mg_anchor_vector_indexes = vec_a_attr_rows[["anchor", "attribute_name"]].copy()
513
+ return mg_anchor_indexes
514
+
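+ # Illustrative example of the statements issued above for a hypothetical anchor label "person":
+ #   CREATE INDEX ON :`person`(id)
+ #   CREATE CONSTRAINT ON (n:`person`) ASSERT n.id IS UNIQUE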
515
+
516
+ def ensure_memgraph_edge_indexes(dm_link_attrs: pd.DataFrame) -> pd.DataFrame:
517
+ """
518
+ Create label / vector indices
519
+ https://memgraph.com/docs/querying/vector-search
520
+ """
521
+
522
+ link_types: set[str] = set(dm_link_attrs["link"].dropna().unique())
523
+
524
+ # embeddable attrs for vector indices
525
+ # vec_l_attr_rows = dm_link_attrs[dm_link_attrs["embeddable"]] # & (dm_link_attrs["dtype"].str.lower() == "str")]
526
+
527
+ driver = GraphDatabase.driver(
528
+ uri=core_settings.memgraph_uri,
529
+ auth=(
530
+ core_settings.memgraph_user,
531
+ core_settings.memgraph_pwd,
532
+ ),
533
+ )
534
+
535
+ with driver.session() as session:
536
+ # Indices on Edges (optimizes queries such as MATCH ()-[r:EDGE_TYPE]->() RETURN r;)
537
+ # If queried by edge property, will need to add property index (similar to above for Anchor)
538
+ for label in link_types:
539
+ try:
540
+ session.run(f"CREATE EDGE INDEX ON :`{label}`") # type: ignore
541
+ except Exception as exc:
542
+ logger.debug(f"CREATE EDGE INDEX failed for label {label}: {exc}") # probably index exists
543
+
544
+ # todo edge constraints?
545
+ # try:
546
+ # session.run(f"CREATE CONSTRAINT ON (n:`{label}`) ASSERT n.id IS UNIQUE") # type: ignore
547
+ # except Exception as exc:
548
+ # logger.debug(f"CREATE CONSTRAINT failed for label {label}: {exc}") # probably index exists
549
+
550
+ # Vector indices - Edges. Deprecated due to move to pgvectorstore
551
+ # for _, row in vec_l_attr_rows.iterrows():
552
+ # attr: str = row["attribute_name"]
553
+ # embeddings_dim = core_settings.embeddings_dim
554
+ # label = row["link"]
555
+ #
556
+ # idx_name = f"{label}_{attr}_embed_idx".replace(" ", "_")
557
+ # prop_name = f"{attr}_embedding"
558
+ #
559
+ # cypher = (
560
+ # f"CREATE VECTOR EDGE INDEX `{idx_name}` ON :`{label}`(`{prop_name}`) "
561
+ # f'WITH CONFIG {{"dimension": {embeddings_dim}, "capacity": 1024, "metric": "cos"}}'
562
+ # )
563
+ # try:
564
+ # session.run(cypher) # type: ignore
565
+ # except Exception as exc:
566
+ # logger.debug(f"CREATE VECTOR EDGE INDEX failed for {idx_name}: {exc}") # probably index exists
567
+ # continue
568
+
569
+ driver.close()
570
+
571
+ # nominal outputs
572
+ mg_link_indexes = pd.DataFrame({"link": list(link_types)})
573
+ # mg_link_vector_indexes = vec_l_attr_rows[["link", "attribute_name"]].copy()
574
+ return mg_link_indexes
575
+
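+ # Illustrative example of the statement issued above for a hypothetical link type "works_at":
+ #   CREATE EDGE INDEX ON :`works_at`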
576
+
577
+ def generate_embeddings(
578
+ df: pd.DataFrame,
579
+ dm_attributes: pd.DataFrame,
580
+ ) -> pd.DataFrame:
581
+ """Generate embeddings for embeddable text attributes"""
582
+ type_col = "node_type" if "node_id" in df.columns else "edge_label"
583
+ pkeys = ["node_id", "node_type"] if type_col == "node_type" else ["from_node_id", "to_node_id", "edge_label"]
584
+ dm_attributes = dm_attributes[dm_attributes["embeddable"]] # & (dm_attributes["dtype"].str.lower() == "str")]
585
+
586
+ # Build mapping type -> list[attribute_name] that need embedding
587
+ mapping: dict[str, list[str]] = {}
588
+ for _, row in dm_attributes.iterrows():
589
+ record_type = row["anchor"] if type_col == "node_type" else row["link"]
590
+ if pd.isna(record_type):
591
+ continue
592
+ mapping.setdefault(record_type, []).append(row["attribute_name"])
593
+
594
+ tasks: list[tuple[int, str, str]] = [] # (row_pos, attr_name, text)
595
+ for pos, (_, row) in enumerate(df.iterrows()):
596
+ typ_val = row[type_col]
597
+ attrs_needed = mapping.get(typ_val)
598
+ if not attrs_needed:
599
+ continue
600
+ row_attrs = row.get("attributes", {})
601
+ for attr_name in attrs_needed:
602
+ text_val = row_attrs.get(attr_name)
603
+ if text_val and isinstance(text_val, str) and not is_uuid(text_val):
604
+ tasks.append((pos, attr_name, text_val))
605
+
606
+ if not tasks:
607
+ return df
608
+
609
+ provider = LLMProvider()
610
+
611
+ unique_texts = list(dict.fromkeys([t[2] for t in tasks]))
612
+ vectors = provider.create_embeddings_sync(unique_texts)
613
+ vector_by_text = dict(zip(unique_texts, vectors))
614
+
615
+ # Apply embeddings
616
+ rows: list[dict[str, object]] = []
617
+ for row_pos, attr_name, text in tasks:
618
+ vec = vector_by_text[text]
619
+ rows.append(
620
+ {
621
+ **df.iloc[row_pos][pkeys],  # positional lookup: df index may be non-contiguous after dropna/drop_duplicates
622
+ "attribute_name": attr_name,
623
+ "attribute_value": text,
624
+ "embedding": vec,
625
+ }
626
+ )
627
+
628
+ return pd.DataFrame.from_records(rows)
629
+
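+ # Usage sketch (illustrative): for nodes the result has one row per embeddable attribute value,
+ # with columns [node_id, node_type, attribute_name, attribute_value, embedding]
+ # (for edges: [from_node_id, to_node_id, edge_label, attribute_name, attribute_value, embedding]);
+ # if nothing is embeddable, the input frame is returned unchanged.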
630
+
631
+ def pass_df_to_memgraph(
632
+ df: pd.DataFrame,
633
+ ) -> pd.DataFrame:
634
+ """dummy method for BatchTransform to MemgraphStore"""
635
+ return df.copy()
636
+
637
+
638
+ # ---
639
+ # Custom parts
640
+
641
+
642
+ def prepare_nodes(
643
+ grist_nodes_df: pd.DataFrame,
644
+ ) -> pd.DataFrame:
645
+ return grist_nodes_df.copy()
646
+
647
+
648
+ def prepare_edges(
649
+ grist_edges_df: pd.DataFrame,
650
+ ) -> pd.DataFrame:
651
+ return grist_edges_df.copy()
652
+
653
+
654
+ def get_eval_gds_from_grist() -> Iterator[pd.DataFrame]:
655
+ """
656
+ Loads the evaluation (golden) dataset rows
657
+
658
+ Output:
659
+ - eval_gds(gds_question, gds_answer, question_scenario, question_comment, question_context)
660
+ """
661
+ dp = GristAPIDataProvider(
662
+ doc_id=etl_settings.grist_test_set_doc_id,
663
+ grist_server=core_settings.grist_server_url,
664
+ api_key=core_settings.grist_api_key,
665
+ )
666
+
667
+ try:
668
+ gds_df = dp.get_table(etl_settings.gds_table_name)
669
+ except Exception as e:
670
+ logger.exception(f"Failed to get golden dataset {etl_settings.gds_table_name}: {e}")
671
+ raise  # re-raise with the original traceback
672
+
673
+ gds_df = gds_df.dropna(subset=["gds_question", "gds_answer"]).copy()
674
+ gds_df = gds_df.loc[(gds_df["gds_question"] != "") & (gds_df["gds_answer"] != "")]
675
+ gds_df = gds_df[["gds_question", "gds_answer", "question_scenario", "question_comment", "question_context"]].astype(
676
+ {"gds_question": str, "gds_answer": str, "question_scenario": str, "question_comment": str}
677
+ )
678
+ yield gds_df
679
+
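+ # Usage sketch (illustrative; requires Grist credentials in settings):
+ #   eval_gds_df = next(get_eval_gds_from_grist())
+ #   print(eval_gds_df[["gds_question", "gds_answer"]].head())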
680
+
681
+ if __name__ == "__main__":
682
+ # import dotenv
683
+ # dotenv.load_dotenv()
684
+ # n, _l = next(get_grist_data())
685
+ ...