vedana_etl-0.1.0.dev3-py3-none-any.whl

vedana_etl/__init__.py ADDED
File without changes
vedana_etl/app.py ADDED
@@ -0,0 +1,10 @@
+ from datapipe.compute import Catalog
+ from datapipe_app import DatapipeAPI
+
+ from vedana_etl.config import ds
+ from vedana_etl.pipeline import default_custom_steps, get_pipeline
+
+ # base app - no extra tables / steps
+ pipeline = get_pipeline(custom_steps=default_custom_steps)
+
+ app = DatapipeAPI(ds, Catalog({}), pipeline)
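The module builds the stock pipeline and exposes it through DatapipeAPI with an empty extra catalog. A deployment that adds its own steps would presumably assemble the app the same way; a minimal sketch, where my_steps is a hypothetical list of extra datapipe steps, not part of this package:

    # Hypothetical variant of app.py that appends project-specific steps.
    from datapipe.compute import Catalog
    from datapipe_app import DatapipeAPI

    from vedana_etl.config import ds
    from vedana_etl.pipeline import default_custom_steps, get_pipeline

    my_steps = [...]  # extra BatchTransform / BatchGenerate steps (placeholder)
    pipeline = get_pipeline(custom_steps=default_custom_steps + my_steps)
    app = DatapipeAPI(ds, Catalog({}), pipeline)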
vedana_etl/catalog.py ADDED
@@ -0,0 +1,266 @@
+ from datapipe.compute import Table
+ from datapipe.store.database import TableStoreDB
+ from datapipe.store.neo4j import Neo4JStore
+ from pgvector.sqlalchemy import Vector
+ from sqlalchemy import Boolean, Column, Float, String
+ from vedana_core.settings import settings as core_settings
+
+ import vedana_etl.schemas as schemas
+ from vedana_etl.config import DBCONN_DATAPIPE, MEMGRAPH_CONN_ARGS
+
+ dm_links = Table(
+     name="dm_links",
+     store=TableStoreDB(
+         dbconn=DBCONN_DATAPIPE,
+         name="dm_links",
+         data_sql_schema=[
+             Column("anchor1", String, primary_key=True),
+             Column("anchor2", String, primary_key=True),
+             Column("sentence", String, primary_key=True),
+             Column("description", String),
+             Column("query", String),
+             Column("anchor1_link_column_name", String),
+             Column("anchor2_link_column_name", String),
+             Column("has_direction", Boolean, default=False),
+         ],
+     ),
+ )
+
+ dm_anchor_attributes = Table(
+     name="dm_anchor_attributes",
+     store=TableStoreDB(
+         dbconn=DBCONN_DATAPIPE,
+         name="dm_anchor_attributes",
+         data_sql_schema=[
+             Column("anchor", String, primary_key=True),
+             Column("attribute_name", String, primary_key=True),
+             Column("description", String),
+             Column("data_example", String),
+             Column("embeddable", Boolean),
+             Column("query", String),
+             Column("dtype", String),
+             Column("embed_threshold", Float),
+         ],
+     ),
+ )
+
+ dm_link_attributes = Table(
+     name="dm_link_attributes",
+     store=TableStoreDB(
+         dbconn=DBCONN_DATAPIPE,
+         name="dm_link_attributes",
+         data_sql_schema=[
+             Column("link", String, primary_key=True),
+             Column("attribute_name", String, primary_key=True),
+             Column("description", String),
+             Column("data_example", String),
+             Column("embeddable", Boolean),
+             Column("query", String),
+             Column("dtype", String),
+             Column("embed_threshold", Float),
+         ],
+     ),
+ )
+
+ dm_anchors = Table(
+     name="dm_anchors",
+     store=TableStoreDB(
+         dbconn=DBCONN_DATAPIPE,
+         name="dm_anchors",
+         data_sql_schema=[
+             Column("noun", String, primary_key=True),
+             Column("description", String),
+             Column("id_example", String),
+             Column("query", String),
+         ],
+     ),
+ )
+
+ dm_queries = Table(
+     name="dm_queries",
+     store=TableStoreDB(
+         dbconn=DBCONN_DATAPIPE,
+         name="dm_queries",
+         data_sql_schema=[
+             Column("query_name", String, primary_key=True),
+             Column("query_example", String),
+         ],
+     ),
+ )
+
+ dm_prompts = Table(
+     name="dm_prompts",
+     store=TableStoreDB(
+         dbconn=DBCONN_DATAPIPE,
+         name="dm_prompts",
+         data_sql_schema=[
+             Column("name", String, primary_key=True),
+             Column("text", String),
+         ],
+     ),
+ )
+
+ dm_conversation_lifecycle = Table(
+     name="dm_conversation_lifecycle",
+     store=TableStoreDB(
+         dbconn=DBCONN_DATAPIPE,
+         name="dm_conversation_lifecycle",
+         data_sql_schema=[
+             Column("event", String, primary_key=True),
+             Column("text", String),
+         ],
+     ),
+ )
+
+ grist_nodes = Table(
+     name="grist_nodes",
+     store=TableStoreDB(
+         dbconn=DBCONN_DATAPIPE,
+         name="grist_nodes",
+         data_sql_schema=schemas.GENERIC_NODE_DATA_SCHEMA,
+     ),
+ )
+
+ grist_edges = Table(
+     name="grist_edges",
+     store=TableStoreDB(
+         dbconn=DBCONN_DATAPIPE,
+         name="grist_edges",
+         data_sql_schema=schemas.GENERIC_EDGE_DATA_SCHEMA,
+     ),
+ )
+
+ # --- Tables used as input for memgraph ---
+
+ nodes = Table(
+     name="nodes",
+     store=TableStoreDB(
+         dbconn=DBCONN_DATAPIPE,
+         name="nodes",
+         data_sql_schema=schemas.GENERIC_NODE_DATA_SCHEMA,
+     ),
+ )
+
+ edges = Table(
+     name="edges",
+     store=TableStoreDB(
+         dbconn=DBCONN_DATAPIPE,
+         name="edges",
+         data_sql_schema=schemas.GENERIC_EDGE_DATA_SCHEMA,
+     ),
+ )
+
+ # --- Memgraph-related tables ---
+
+ memgraph_anchor_indexes = Table(
+     name="memgraph_anchor_indexes",
+     store=TableStoreDB(
+         dbconn=DBCONN_DATAPIPE,
+         name="memgraph_anchor_indexes",
+         data_sql_schema=[
+             Column("anchor", String, primary_key=True),
+         ],
+     ),
+ )
+
+ memgraph_link_indexes = Table(
+     name="memgraph_link_indexes",
+     store=TableStoreDB(
+         dbconn=DBCONN_DATAPIPE,
+         name="memgraph_link_indexes",
+         data_sql_schema=[
+             Column("link", String, primary_key=True),
+         ],
+     ),
+ )
+
+ memgraph_anchor_vector_indexes = Table(
+     name="memgraph_anchor_vector_indexes",
+     store=TableStoreDB(
+         dbconn=DBCONN_DATAPIPE,
+         name="memgraph_anchor_vector_indexes",
+         data_sql_schema=[
+             Column("anchor", String, primary_key=True),
+             Column("attribute_name", String, primary_key=True),
+         ],
+     ),
+ )
+
+ memgraph_link_vector_indexes = Table(
+     name="memgraph_link_vector_indexes",
+     store=TableStoreDB(
+         dbconn=DBCONN_DATAPIPE,
+         name="memgraph_link_vector_indexes",
+         data_sql_schema=[
+             Column("link", String, primary_key=True),
+             Column("attribute_name", String, primary_key=True),
+         ],
+     ),
+ )
+
+ memgraph_nodes = Table(
+     name="memgraph_nodes",
+     store=Neo4JStore(
+         connection_kwargs=MEMGRAPH_CONN_ARGS,
+         data_sql_schema=schemas.GENERIC_NODE_DATA_SCHEMA,
+     ),
+ )
+
+ memgraph_edges = Table(
+     name="memgraph_edges",
+     store=Neo4JStore(
+         connection_kwargs=MEMGRAPH_CONN_ARGS,
+         data_sql_schema=schemas.GENERIC_EDGE_DATA_SCHEMA,
+     ),
+ )
+
+ # --- VTS (pgvector) ---
+ # The embedding column size is fixed for indexing and is defined through settings; the definition is then frozen in migrations.
+
+ rag_anchor_embeddings = Table(
+     name="rag_anchor_embeddings",
+     store=TableStoreDB(
+         dbconn=DBCONN_DATAPIPE,
+         name="rag_anchor_embeddings",
+         data_sql_schema=[
+             Column("node_id", String, primary_key=True),
+             Column("attribute_name", String, primary_key=True),
+             Column("label", String, nullable=False),
+             Column("attribute_value", String),
+             Column("embedding", Vector(dim=core_settings.embeddings_dim), nullable=False),
+         ],
+     ),
+ )
+
+ rag_edge_embeddings = Table(
+     name="rag_edge_embeddings",
+     store=TableStoreDB(
+         dbconn=DBCONN_DATAPIPE,
+         name="rag_edge_embeddings",
+         data_sql_schema=[
+             Column("from_node_id", String, primary_key=True),
+             Column("to_node_id", String, primary_key=True),
+             Column("edge_label", String, primary_key=True),
+             Column("attribute_name", String, primary_key=True),
+             Column("attribute_value", String),
+             Column("embedding", Vector(dim=core_settings.embeddings_dim), nullable=False),
+         ],
+     ),
+ )
+
+ # --- Eval pipeline ---
+
+ eval_gds = Table(
+     name="eval_gds",
+     store=TableStoreDB(
+         dbconn=DBCONN_DATAPIPE,
+         name="eval_gds",
+         data_sql_schema=[
+             Column("gds_question", String, primary_key=True),
+             Column("gds_answer", String),
+             Column("question_scenario", String),
+             Column("question_comment", String),
+             Column("question_context", String),
+         ],
+     ),
+ )
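Every Postgres-backed table above follows the same pattern: a datapipe Table whose store is a TableStoreDB bound to DBCONN_DATAPIPE, with the SQLAlchemy columns doubling as schema and primary key. A hedged sketch of what a downstream extension could look like (the my_documents table and its columns are invented for illustration, not shipped with vedana-etl):

    from datapipe.compute import Table
    from datapipe.store.database import TableStoreDB
    from sqlalchemy import Column, String

    from vedana_etl.config import DBCONN_DATAPIPE

    # Hypothetical extra table following the catalog's Table/TableStoreDB pattern.
    my_documents = Table(
        name="my_documents",
        store=TableStoreDB(
            dbconn=DBCONN_DATAPIPE,
            name="my_documents",
            data_sql_schema=[
                Column("doc_id", String, primary_key=True),
                Column("text", String),
            ],
        ),
    )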
vedana_etl/config.py ADDED
@@ -0,0 +1,22 @@
+ import json
+ from functools import partial
+
+ from datapipe.compute import Catalog
+ from datapipe.datatable import DataStore
+ from datapipe.store.database import DBConn
+ from vedana_core.settings import settings as core_settings
+
+ from vedana_etl.settings import settings
+
+ MEMGRAPH_CONN_ARGS = {
+     "uri": core_settings.memgraph_uri,
+     "auth": (core_settings.memgraph_user, core_settings.memgraph_pwd),
+ }
+
+ DBCONN_DATAPIPE = DBConn(
+     connstr=settings.db_conn_uri,
+     create_engine_kwargs=dict(json_serializer=partial(json.dumps, ensure_ascii=False)),
+ )
+
+
+ ds = DataStore(DBCONN_DATAPIPE)
+ catalog = Catalog({})
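DBCONN_DATAPIPE is built from settings.db_conn_uri, and json_serializer=partial(json.dumps, ensure_ascii=False) keeps non-ASCII attribute values human-readable in JSONB columns. Since Settings uses an empty env_prefix (see vedana_etl/settings.py below), the URI would normally arrive via a DB_CONN_URI environment variable or a .env entry; a sketch, with placeholder host and credentials:

    # Exported or placed in .env before importing the package (placeholder values):
    # DB_CONN_URI=postgresql://vedana:secret@localhost:5432/vedana_etl
    from vedana_etl.config import ds, catalog  # importing config binds the engine to that URI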
vedana_etl/pipeline.py ADDED
@@ -0,0 +1,142 @@
+ from datapipe.compute import Pipeline
+ from datapipe.step.batch_generate import BatchGenerate
+ from datapipe.step.batch_transform import BatchTransform
+
+ import vedana_etl.steps as steps
+ from vedana_etl.catalog import (
+     dm_anchor_attributes,
+     dm_anchors,
+     dm_conversation_lifecycle,
+     dm_link_attributes,
+     dm_links,
+     dm_prompts,
+     dm_queries,
+     edges,
+     eval_gds,
+     grist_edges,
+     grist_nodes,
+     memgraph_anchor_indexes,
+     memgraph_edges,
+     memgraph_link_indexes,
+     memgraph_nodes,
+     nodes,
+     rag_anchor_embeddings,
+     rag_edge_embeddings,
+ )
+
+ data_model_steps = [
+     BatchGenerate(
+         func=steps.get_data_model,  # Generator with main graph data
+         outputs=[
+             dm_anchors,
+             dm_anchor_attributes,
+             dm_link_attributes,
+             dm_links,
+             dm_queries,
+             dm_prompts,
+             dm_conversation_lifecycle,
+         ],
+         labels=[("flow", "regular"), ("flow", "on-demand"), ("stage", "extract"), ("stage", "data-model")],
+     ),
+ ]
+
+ grist_steps = [
+     BatchGenerate(
+         func=steps.get_grist_data,
+         outputs=[grist_nodes, grist_edges],
+         labels=[("flow", "on-demand"), ("stage", "extract"), ("stage", "grist")],
+     ),
+ ]
+
+ # ---
+ # This part is customisable: it can be replaced to wire in other pipeline branches.
+
+ default_custom_steps = [
+     BatchTransform(
+         func=steps.prepare_nodes,
+         inputs=[grist_nodes],
+         outputs=[nodes],
+         labels=[("flow", "on-demand"), ("stage", "transform"), ("stage", "grist")],
+         transform_keys=["node_id"],
+     ),
+     BatchTransform(
+         func=steps.prepare_edges,
+         inputs=[grist_edges],
+         outputs=[edges],
+         labels=[("flow", "on-demand"), ("stage", "transform"), ("stage", "grist")],
+         transform_keys=["from_node_id", "to_node_id", "edge_label"],
+     ),
+ ]
+
+ # --- Loading data to Memgraph and Vector Store ---
+
+ memgraph_steps = [
+     BatchTransform(
+         func=steps.ensure_memgraph_node_indexes,
+         inputs=[dm_anchor_attributes],
+         outputs=[memgraph_anchor_indexes],
+         labels=[("flow", "regular"), ("flow", "on-demand"), ("stage", "load")],
+         transform_keys=["attribute_name"],
+     ),
+     BatchTransform(
+         func=steps.ensure_memgraph_edge_indexes,
+         inputs=[dm_link_attributes],
+         outputs=[memgraph_link_indexes],
+         labels=[("flow", "regular"), ("flow", "on-demand"), ("stage", "load")],
+         transform_keys=["attribute_name"],
+     ),
+     BatchTransform(
+         func=steps.pass_df_to_memgraph,
+         inputs=[nodes],
+         outputs=[memgraph_nodes],
+         labels=[("flow", "regular"), ("flow", "on-demand"), ("stage", "load")],
+         transform_keys=["node_id", "node_type"],
+     ),
+     BatchTransform(
+         func=steps.pass_df_to_memgraph,
+         inputs=[edges],
+         outputs=[memgraph_edges],
+         labels=[("flow", "regular"), ("flow", "on-demand"), ("stage", "load")],
+         transform_keys=["from_node_id", "to_node_id", "edge_label"],
+     ),
+     BatchTransform(
+         func=steps.generate_embeddings,
+         inputs=[nodes, dm_anchor_attributes],
+         outputs=[rag_anchor_embeddings],
+         labels=[("flow", "regular"), ("flow", "on-demand"), ("stage", "load")],
+         transform_keys=["node_id", "node_type"],
+     ),
+     BatchTransform(
+         func=steps.generate_embeddings,
+         inputs=[edges, dm_link_attributes],
+         outputs=[rag_edge_embeddings],
+         labels=[("flow", "regular"), ("flow", "on-demand"), ("stage", "load")],
+         transform_keys=["from_node_id", "to_node_id", "edge_label"],
+     ),
+ ]
+
+ eval_steps = [
+     BatchGenerate(
+         func=steps.get_eval_gds_from_grist,
+         outputs=[eval_gds],
+         labels=[("pipeline", "eval"), ("flow", "eval"), ("stage", "extract")],
+     ),
+ ]
+
+
+ def get_data_model_pipeline() -> Pipeline:
+     return Pipeline(data_model_steps)
+
+
+ def get_pipeline(custom_steps: list) -> Pipeline:
+     pipeline = Pipeline(
+         [
+             *data_model_steps,
+             *grist_steps,
+             *custom_steps,
+             *memgraph_steps,
+             *eval_steps,
+         ]
+     )
+
+     return pipeline
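get_pipeline splices the caller's steps between the Grist extract and the Memgraph/vector load, so custom_steps is the intended extension point. A minimal sketch of a replacement for default_custom_steps, assuming a hypothetical my_prepare_nodes function that preserves the generic node schema and the node_id key:

    import pandas as pd
    from datapipe.step.batch_transform import BatchTransform

    from vedana_etl.catalog import grist_nodes, nodes
    from vedana_etl.pipeline import get_pipeline

    def my_prepare_nodes(df: pd.DataFrame) -> pd.DataFrame:
        # Project-specific cleanup; must return GENERIC_NODE_DATA_SCHEMA columns.
        return df

    custom_steps = [
        BatchTransform(
            func=my_prepare_nodes,
            inputs=[grist_nodes],
            outputs=[nodes],
            transform_keys=["node_id"],
        ),
    ]
    pipeline = get_pipeline(custom_steps=custom_steps)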
vedana_etl/py.typed ADDED
File without changes
vedana_etl/schemas.py ADDED
@@ -0,0 +1,31 @@
+ from sqlalchemy import Column, String
+ from sqlalchemy.dialects.postgresql import JSONB
+
+ GENERIC_NODE_DATA_SCHEMA: list[Column] = [
+     Column("node_id", String, primary_key=True),
+     Column("node_type", String, primary_key=True),
+     Column("attributes", JSONB),
+ ]
+
+ GENERIC_EDGE_DATA_SCHEMA: list[Column] = [
+     Column("from_node_id", String, primary_key=True),
+     Column("to_node_id", String, primary_key=True),
+     Column("from_node_type", String, primary_key=True),
+     Column("to_node_type", String, primary_key=True),
+     Column("edge_label", String, primary_key=True),
+     Column("attributes", JSONB),
+ ]
+
+ # ---
+ # Evaluation pipeline schemas
+
+ DM_VERSIONING_TABLE_SCHEMA: list[Column] = [
+     Column("dm_id", String, primary_key=True),
+     Column("dm_description", String),
+ ]
+
+ EVAL_GDS_SCHEMA: list[Column] = [
+     Column("gds_question", String, primary_key=True),
+     Column("gds_answer", String),
+     Column("question_context", String),
+ ]
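The generic schemas make graph rows self-describing: identity and type live in the keyed String columns, while everything else travels in the JSONB attributes blob. Example rows (all values invented for illustration):

    # One node row and one edge row matching the generic schemas:
    node_row = {"node_id": "p-1", "node_type": "person", "attributes": {"name": "Ada"}}
    edge_row = {
        "from_node_id": "p-1",
        "to_node_id": "c-1",
        "from_node_type": "person",
        "to_node_type": "company",
        "edge_label": "works_at",
        "attributes": {"since": "2021"},
    }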
vedana_etl/settings.py ADDED
@@ -0,0 +1,23 @@
+ from pydantic_settings import BaseSettings, SettingsConfigDict
+
+
+ class Settings(BaseSettings):
+     model_config = SettingsConfigDict(
+         env_prefix="",
+         env_file=".env",
+         env_file_encoding="utf-8",
+         extra="ignore",
+     )
+
+     # Datapipe connection URI
+     # db_conn_uri: str = "sqlite+pysqlite3:///db.sqlite"
+     db_conn_uri: str
+
+     # Tests pipeline (vedana-eval) settings.
+     grist_test_set_doc_id: str = ""
+     gds_table_name: str = "Gds"  # Table names in the test set doc
+     tests_table_name: str = "Tests"
+     test_environment: str = ""
+
+
+ settings = Settings()  # type: ignore
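With an empty env_prefix, each field maps to a same-named environment variable (matched case-insensitively by pydantic-settings), e.g. DB_CONN_URI or GRIST_TEST_SET_DOC_ID; .env values are merged in and unknown keys are ignored. A quick sanity check, assuming the variables are set:

    from vedana_etl.settings import settings

    # The import itself raises a validation error if DB_CONN_URI is missing,
    # since db_conn_uri has no default and settings is instantiated at module load.
    print(settings.db_conn_uri)
    print(settings.gds_table_name)  # "Gds" unless overridden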