vedana-core 0.1.0.dev3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- vedana_core/__init__.py +0 -0
- vedana_core/app.py +78 -0
- vedana_core/data_model.py +465 -0
- vedana_core/data_provider.py +513 -0
- vedana_core/db.py +41 -0
- vedana_core/graph.py +300 -0
- vedana_core/llm.py +192 -0
- vedana_core/py.typed +0 -0
- vedana_core/rag_agent.py +234 -0
- vedana_core/rag_pipeline.py +326 -0
- vedana_core/settings.py +35 -0
- vedana_core/start_pipeline.py +17 -0
- vedana_core/utils.py +31 -0
- vedana_core/vts.py +167 -0
- vedana_core-0.1.0.dev3.dist-info/METADATA +29 -0
- vedana_core-0.1.0.dev3.dist-info/RECORD +17 -0
- vedana_core-0.1.0.dev3.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,513 @@
|
|
|
1
|
+
import abc
|
|
2
|
+
import io
|
|
3
|
+
import sqlite3
|
|
4
|
+
import time
|
|
5
|
+
from ast import literal_eval
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any, NamedTuple, Sequence, Union
|
|
9
|
+
|
|
10
|
+
import grist_api
|
|
11
|
+
import pandas as pd
|
|
12
|
+
import requests
|
|
13
|
+
|
|
14
|
+
from vedana_core.data_model import Attribute, Link
|
|
15
|
+
from vedana_core.utils import cast_dtype
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
@dataclass
class Table:
    """A tabular payload: a header of column names plus the rows under it.

    Rows may be plain tuples or named tuples; presumably each row's positions
    line up with ``columns`` — confirm with the producers of this type.
    """

    columns: list[str]  # header labels, in positional order
    rows: Sequence[Union[tuple, NamedTuple]]  # one entry per record
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
@dataclass
|
|
25
|
+
class AnchorRecord:
|
|
26
|
+
id: str
|
|
27
|
+
type: str
|
|
28
|
+
data: dict[str, Any]
|
|
29
|
+
dp_id: int | None = None
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
@dataclass
class LinkRecord:
    """One graph edge (link) between two anchors, as loaded from a data provider."""

    id_from: str  # id of the source anchor
    id_to: str  # id of the target anchor
    anchor_from: str  # type name of the source anchor
    anchor_to: str  # type name of the target anchor
    type: str  # edge label
    data: dict[str, Any]  # extra edge attributes
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class DataProvider:
    """Base interface for sources of graph anchors (nodes) and links (edges).

    Every data method raises ``NotImplementedError`` here and is meant to be
    overridden by a subclass. Instances are context managers: leaving a ``with``
    block calls :meth:`close` (a no-op by default) and lets any exception propagate.
    """

    def get_anchors(self, type_: str, dm_attrs: list[Attribute], dm_anchor_links: list[Link]) -> list[AnchorRecord]:
        """Return the anchor records of type ``type_``.

        ``dm_attrs`` and ``dm_anchor_links`` are data-model descriptions
        (``vedana_core.data_model`` types) for that anchor type; implementations
        may use or ignore them.
        """
        raise NotImplementedError("get_anchors must be implemented in subclass")

    def get_links(self, table_name: str, link: Link) -> list[LinkRecord]:
        """Return the link records stored under ``table_name`` for the data-model ``link``."""
        raise NotImplementedError("get_links must be implemented in subclass")

    def get_anchor_tables(self) -> list[str]:
        """Return the anchor table identifiers exposed by this provider."""
        raise NotImplementedError("get_anchor_tables must be implemented in subclass")

    def list_anchor_tables(self) -> list[str]:
        """Return the source-level names of the anchor tables."""
        raise NotImplementedError("list_anchor_tables must be implemented in subclass")

    def get_link_tables(self) -> list[str]:
        """Return the link table identifiers exposed by this provider."""
        raise NotImplementedError("get_link_tables must be implemented in subclass")

    def list_link_tables(self) -> list[str]:
        """Return the source-level names of the link tables.

        Counterpart of :meth:`list_anchor_tables`, so both table kinds share the
        same base interface.
        """
        raise NotImplementedError("list_link_tables must be implemented in subclass")

    def close(self) -> None:
        """Release any held resources. The base provider holds none."""

    def __enter__(self):
        return self

    def __exit__(self, *args, **kwargs) -> None:
        # Returning None (falsy) means exceptions raised inside the `with` propagate.
        self.close()
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class GristDataProvider(DataProvider, abc.ABC):
    """Shared anchor/link loading logic for Grist-backed providers.

    Tables are discovered by name prefix: ``Anchor_<type>`` holds anchors of
    ``<type>`` and ``Link_<name>`` holds links. Subclasses supply the transport
    through :meth:`list_anchor_tables`, :meth:`list_link_tables` and :meth:`get_table`.

    ``abc.ABC`` is mixed in so the ``@abc.abstractmethod`` markers below are
    actually enforced; without ``ABCMeta`` they are silently ignored.
    """

    anchor_table_prefix = "Anchor_"
    link_table_prefix = "Link_"

    @staticmethod
    def _is_missing(value: object) -> bool:
        """True for None/NaN/NA scalars; never raises on list-like values."""
        # pd.isna on a list returns an array, whose truth value is ambiguous.
        return pd.api.types.is_scalar(value) and bool(pd.isna(value))

    def get_anchor_tables(self) -> list[str]:
        """Anchor type names: anchor table names with the ``Anchor_`` prefix sliced off."""
        prefix_len = len(self.anchor_table_prefix)
        return [t[prefix_len:] for t in self.list_anchor_tables()]

    @abc.abstractmethod
    def list_anchor_tables(self) -> list[str]:
        """Full names (prefix included) of the anchor tables in the source."""
        ...

    @abc.abstractmethod
    def list_link_tables(self) -> list[str]:
        """Full names (prefix included) of the link tables in the source."""
        ...

    @abc.abstractmethod
    def get_table(self, table_name: str) -> pd.DataFrame:
        """Load the table ``table_name`` (full name, prefix included)."""
        ...

    def get_anchor_types(self) -> list[str]:
        """Same result as :meth:`get_anchor_tables`."""
        prefix_len = len(self.anchor_table_prefix)
        return [t[prefix_len:] for t in self.list_anchor_tables()]

    def get_link_tables(self) -> list[str]:
        """Link table names with the ``Link_`` prefix sliced off."""
        prefix_len = len(self.link_table_prefix)
        return [t[prefix_len:] for t in self.list_link_tables()]

    def get_anchors(self, type_: str, dm_attrs: list[Attribute], dm_anchor_links: list[Link]) -> list[AnchorRecord]:
        """Load anchors of ``type_`` from the ``Anchor_<type_>`` table.

        The anchor id comes from the ``<type_>_id`` column when present, otherwise
        it is ``"<type_>:<row id>"``. Rows whose id is missing (NaN) are skipped;
        for duplicate ids the first row wins and later ones are reported and
        skipped. Cells that are bytes or missing are dropped from ``data``.
        ``dm_attrs`` and ``dm_anchor_links`` are not used by this implementation.
        """
        table_name = f"{self.anchor_table_prefix}{type_}"
        table = self.get_table(table_name)
        if "id" not in table.columns:
            table["id"] = [None] * table.shape[0]
        table_records: list[dict] = table.to_dict(orient="records")

        id_key = f"{type_}_id"
        anchor_ids: set = set()
        anchors: list[AnchorRecord] = []

        def flatten_list_cells(el):
            # Grist encodes list cells as ["L", item, ...]: unwrap a single item,
            # otherwise return the items without the marker.
            if isinstance(el, list):
                if el and el[0] == "L":
                    return el[1] if len(el) == 2 else el[1:]
                return el
            if self._is_missing(el):
                return None
            return el

        for row_dict in table_records:
            db_id = flatten_list_cells(row_dict.pop("id"))
            id_ = row_dict.pop(id_key, None)
            if id_ is None:
                id_ = f"{type_}:{db_id}"
            if self._is_missing(id_):
                continue

            row_dict = {k: flatten_list_cells(v) for k, v in row_dict.items()}
            row_dict = {k: v for k, v in row_dict.items() if not isinstance(v, (bytes, type(None)))}

            if id_ in anchor_ids:
                print(f'Duplicate anchor id "{id_}" in "{table_name}"\nduplicate data: {row_dict}\nrecord skipped')
                continue
            anchor_ids.add(id_)

            anchors.append(AnchorRecord(id_, type_, row_dict, dp_id=db_id))  # type: ignore

        return anchors

    def get_links(self, table_name: str, link: Link) -> list[LinkRecord]:
        """Load links from the ``Link_<table_name>`` table.

        Each row must provide ``from_node_id``, ``to_node_id`` and ``edge_label``;
        rows where any of them is missing or empty are skipped. Endpoint ids are
        returned as strings. The anchor types are taken from ``link``; if the
        ``from`` id's ``"<type>:"`` prefix names ``link.anchor_to``, the direction
        is treated as reversed.
        """
        table_name = f"{self.link_table_prefix}{table_name}"
        table = self.get_table(table_name)
        if "id" not in table.columns:
            table["id"] = [None] * table.shape[0]
        table_records: list[dict] = table.to_dict(orient="records")

        def flatten_list_cells(el):
            # Unwrap a single-item Grist list ["L", item]; only the leading marker counts.
            if isinstance(el, list) and len(el) == 2 and el[0] == "L":
                return el[1]
            return el

        def is_blank(value) -> bool:
            # Check "missing" first: `not pd.NA` raises TypeError.
            return self._is_missing(value) or not value

        links: list[LinkRecord] = []
        for row_dict in table_records:
            id_from = row_dict.pop("from_node_id")
            id_to = row_dict.pop("to_node_id")
            type_ = row_dict.pop("edge_label")
            row_dict.pop("id")

            # Validate before touching the ids: the type detection below calls .split().
            if is_blank(id_from) or is_blank(id_to) or is_blank(type_):
                continue
            id_from, id_to = str(id_from), str(id_to)

            # node types - check node type in id's, else take link's definition
            if id_from.split(":")[0] == link.anchor_to.noun:
                node_from = link.anchor_to.noun
            else:
                node_from = link.anchor_from.noun
            node_to = link.anchor_to.noun if node_from != link.anchor_to.noun else link.anchor_from.noun

            row_dict = {str(k): flatten_list_cells(v) for k, v in row_dict.items()}
            row_dict = {
                k: v
                for k, v in row_dict.items()
                if not isinstance(v, (bytes, list, type(None))) and not self._is_missing(v)
            }

            links.append(LinkRecord(id_from, id_to, node_from, node_to, type_, row_dict))

        return links
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
class GristOfflineDataProvider(GristDataProvider):
    """
    Load from grist backup file (.grist)

    The file is opened directly as an SQLite database; call :meth:`close`
    (or use the provider as a context manager) to release the connection.
    """

    def __init__(self, sqlite_path: Path) -> None:
        super().__init__()
        self._conn = sqlite3.connect(sqlite_path)

    def close(self) -> None:
        """Close the underlying SQLite connection."""
        self._conn.close()

    def _list_tables_with_prefix(self, prefix: str) -> list[str]:
        """Names of SQLite tables starting with ``prefix`` (case-sensitive), in sqlite_master order.

        The filter runs in Python rather than with ``LIKE``. In SQLite, ``LIKE`` is
        case-insensitive for ASCII and ``_`` is a single-character wildcard, so
        ``'Anchor_%'`` would also match names such as ``anchorX...``.
        """
        rows = self._conn.execute("SELECT name FROM sqlite_master WHERE type='table'").fetchall()
        return [name for (name,) in rows if name.startswith(prefix)]

    def list_anchor_tables(self) -> list[str]:
        return self._list_tables_with_prefix(self.anchor_table_prefix)

    def list_link_tables(self) -> list[str]:
        return self._list_tables_with_prefix(self.link_table_prefix)

    def get_table(self, table_name: str) -> pd.DataFrame:
        """Read the whole table ``table_name`` into a DataFrame, keeping SQLite's column order."""
        # Quote the identifier (doubling embedded quotes) so unusual names cannot break the SQL.
        quoted_name = '"' + table_name.replace('"', '""') + '"'
        cur = self._conn.execute(f"SELECT * FROM {quoted_name}")
        columns = [t[0] for t in cur.description]
        rows = cur.fetchall()
        return pd.DataFrame.from_records(rows, columns=columns)
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
class GristCsvDataProvider(GristDataProvider):
    """Grist provider that downloads each table as CSV over the Grist REST API.

    The document's table list is fetched once, in ``__init__``, and reused for
    all ``list_*`` calls.
    """

    # Seconds to wait for the table-list request.
    table_list_timeout = 60
    # Seconds to wait for a full-table CSV download.
    csv_timeout = 600

    def __init__(self, doc_id: str, grist_server: str, api_key: str | None = None) -> None:
        super().__init__()
        self.doc_id = doc_id
        self.grist_server = grist_server
        self.api_key = api_key
        # Only authenticate when a key is configured; otherwise the header would
        # carry the literal string "Bearer None".
        self.headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
        self._table_list: list[str] = self._fetch_table_list()

    def _fetch_table_list(self) -> list[str]:
        """Return the ids of all tables in the document."""
        url = f"{self.grist_server}/api/docs/{self.doc_id}/tables"
        resp = requests.get(url, headers=self.headers, timeout=self.table_list_timeout)
        resp.raise_for_status()
        return [t["id"] for t in resp.json()["tables"]]

    def get_table(self, table_name: str) -> pd.DataFrame:
        """Download ``table_name`` as CSV; when it has no ``id`` column, the positional index becomes ``id``."""
        url = f"{self.grist_server}/api/docs/{self.doc_id}/download/csv"
        # `params` URL-encodes the table id.
        resp = requests.get(url, params={"tableId": table_name}, headers=self.headers, timeout=self.csv_timeout)
        resp.raise_for_status()
        df = pd.read_csv(io.StringIO(resp.text))
        if "id" not in df.columns:
            df = df.reset_index(drop=False, names=["id"])
        return df

    def _list_tables_with_prefix(self, prefix: str) -> list[str]:
        return [t for t in self._table_list if t.startswith(prefix)]

    def list_anchor_tables(self) -> list[str]:
        return self._list_tables_with_prefix(self.anchor_table_prefix)

    def list_link_tables(self) -> list[str]:
        return self._list_tables_with_prefix(self.link_table_prefix)
|
|
234
|
+
|
|
235
|
+
|
|
236
|
+
class GristAPIDataProvider(GristDataProvider):
    """Grist provider that reads records through the ``grist_api`` client (``fetch_table``)."""

    # Seconds to wait for the raw REST call in _list_table_columns.
    request_timeout = 120

    def __init__(self, doc_id: str, grist_server: str, api_key: str | None = None) -> None:
        super().__init__()
        self.grist_server = grist_server
        self.doc_id = doc_id
        self.api_key = api_key
        # Only authenticate raw requests when a key is configured; otherwise the
        # header would carry the literal string "Bearer None".
        self.headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {}
        self._client = grist_api.GristDocAPI(doc_id, api_key=api_key, server=grist_server)

    def _list_tables_with_prefix(self, prefix: str) -> list[str]:
        """Table ids starting with ``prefix``; empty if the tables request fails."""
        resp = self._client.tables()
        if not resp:
            return []
        table_ids: list[str] = [table["id"] for table in resp.json()["tables"]]
        return [t_id for t_id in table_ids if t_id.startswith(prefix)]

    def list_anchor_tables(self) -> list[str]:
        return self._list_tables_with_prefix(self.anchor_table_prefix)

    def list_link_tables(self) -> list[str]:
        return self._list_tables_with_prefix(self.link_table_prefix)

    def _list_table_columns(self, table_name: str) -> dict[str, str]:
        """returns mapping {internal_grist_id: id_in_UI}

        A column with a ``displayCol`` is mapped to that display column's label
        (falling back to its own label). An empty dict is returned when either
        request fails.
        """
        # There is no label (hidden=True/False) on columns so we have to do 2 api calls
        view_resp = self._client.columns(table_name)  # does not show hidden columns i.e. "id"
        all_resp = requests.get(  # shows all columns, including internal IDs and helper columns
            f"{self.grist_server}/api/docs/{self.doc_id}/tables/{table_name}/columns?hidden=True",
            headers=self.headers,
            timeout=self.request_timeout,
        )
        # requests.Response is falsy for 4xx/5xx statuses.
        if not view_resp or not all_resp:
            return {}

        view_cols = view_resp.json()["columns"]
        all_cols = all_resp.json()["columns"]

        # get display columns
        col_ref_label_map = {c["fields"]["colRef"]: c["fields"]["label"] for c in all_cols}

        return {
            c["id"]: (
                col_ref_label_map.get(c["fields"]["displayCol"], c["fields"]["label"])
                if c["fields"]["displayCol"]
                else c["fields"]["label"]
            )
            for c in view_cols
        }

    def get_table(self, table_name: str) -> pd.DataFrame:
        """Fetch all records of ``table_name``, keeping only the visible columns plus ``id``."""
        columns = self._list_table_columns(table_name)
        if "id" not in columns:  # add internal id
            columns["id"] = "id"
        fetched = self._client.fetch_table(table_name)
        # NOTE(review): values are read by the mapped label, not the column id — confirm
        # this matches the attribute names of fetch_table's records.
        records = [{c_id: getattr(r, c_label) for c_id, c_label in columns.items()} for r in fetched]
        return pd.DataFrame(records, columns=list(columns.keys()))
|
|
289
|
+
|
|
290
|
+
|
|
291
|
+
class GristSQLDataProvider(GristDataProvider):
    """fetches rows via the /sql endpoint in chunks

    Rows are paged with ``LIMIT``/``OFFSET`` (``batch_size`` rows per request,
    ordered by ``id``). When a /sql request fails, it is retried after pruning
    the document history and forcing a reload (see :meth:`_run_sql`).
    """

    # Total attempts _run_sql makes before re-raising the last error.
    max_sql_attempts = 3
    # Seconds to wait before each /sql retry.
    sql_retry_delay = 120

    def __init__(
        self,
        doc_id: str,
        grist_server: str,
        api_key: str | None = None,
        *,
        batch_size: int = 800,
    ) -> None:
        super().__init__()
        self.doc_id = doc_id
        self.grist_server = grist_server
        self.api_key = api_key
        self.batch_size = max(1, batch_size)  # guard against zero/negative page sizes
        # Only authenticate raw requests when a key is configured; otherwise the
        # header would carry the literal string "Bearer None".
        self.headers = {"Authorization": f"Bearer {self.api_key}"} if self.api_key else {}
        self._client = grist_api.GristDocAPI(doc_id, server=grist_server, api_key=api_key)

    @staticmethod
    def _is_missing(value: object) -> bool:
        """True for None/NaN/NA scalars; never raises on list-like values."""
        # pd.isna on a list returns an array, whose truth value is ambiguous.
        return pd.api.types.is_scalar(value) and bool(pd.isna(value))

    @staticmethod
    def _is_buffer(value: object) -> bool:
        """True for cells shaped like ``{"type": "Buffer", ...}`` (presumably serialised binary data)."""
        return isinstance(value, dict) and value.get("type") == "Buffer"

    def reset_doc(self, doc_id: str):
        """Force Grist to reload document ``doc_id`` (POST ``.../force-reload``)."""
        print("reopening doc")
        resp = requests.post(
            f"{self.grist_server}/api/docs/{doc_id}/force-reload",
            headers=self.headers,
            timeout=120,
        )
        resp.raise_for_status()

    def clean_doc_history(self, doc_id: str):
        """Ask Grist to drop all but the latest document state (POST ``.../states/remove`` with keep=1)."""
        print("cleaning memory")
        resp = requests.post(
            f"{self.grist_server}/api/docs/{doc_id}/states/remove",
            headers=self.headers,
            json={"keep": 1},
            timeout=120,
        )
        resp.raise_for_status()

    def _sql_endpoint(self) -> str:
        """Fully qualified URL of the /sql POST endpoint for this document."""
        return f"{self.grist_server}/api/docs/{self.doc_id}/sql"

    def _run_sql(self, sql: str, att=1) -> list[dict]:
        """Run ``sql`` on the /sql endpoint and return the response's ``records`` list.

        Network/HTTP errors and undecodable JSON are retried, up to
        ``max_sql_attempts`` attempts in total (``att`` is the 1-based number of
        the first attempt). Before each retry the method waits ``sql_retry_delay``
        seconds, prunes the document history and forces a reload. The final
        failure is re-raised immediately, without that recovery step. Errors
        raised by the recovery calls themselves propagate.
        """
        attempt = att
        while True:
            try:
                resp = requests.post(
                    self._sql_endpoint(),
                    json={"sql": sql},
                    headers=self.headers,
                    timeout=180,
                )
                resp.raise_for_status()
                data = resp.json()  # { "records": [ { "fields": {...} }, ... ] }.
            except (requests.RequestException, ValueError) as e:
                if attempt >= self.max_sql_attempts:
                    raise
                print(f"Failed to fetch {sql}: {e}\n Retrying in {self.sql_retry_delay} seconds...")  # todo logging
                time.sleep(self.sql_retry_delay)
                self.clean_doc_history(self.doc_id)
                self.reset_doc(self.doc_id)
                attempt += 1
                continue
            return data.get("records", [])

    def _list_tables_with_prefix(self, prefix: str) -> list[str]:
        """Table ids starting with ``prefix``; empty if the tables request fails."""
        resp = self._client.tables()
        if not resp:
            return []
        table_ids: list[str] = [table["id"] for table in resp.json()["tables"]]
        return [t_id for t_id in table_ids if t_id.startswith(prefix)]

    def list_anchor_tables(self) -> list[str]:
        return self._list_tables_with_prefix(self.anchor_table_prefix)

    def list_link_tables(self) -> list[str]:
        return self._list_tables_with_prefix(self.link_table_prefix)

    def _iter_table_rows(self, table_name: str):
        """Yield every row's ``fields`` dict, one /sql page of ``batch_size`` rows at a time."""
        # Quote the identifier, doubling embedded quotes, so the name cannot break the SQL.
        quoted_name = '"' + table_name.replace('"', '""') + '"'
        offset = 0
        while True:
            sql = f"SELECT * FROM {quoted_name} ORDER BY id LIMIT {self.batch_size} OFFSET {offset}"
            batch = self._run_sql(sql)
            if not batch:
                break
            for record in batch:
                yield record.get("fields", {})
            if len(batch) < self.batch_size:
                break  # a short page is the last one; skip the extra empty request
            offset += self.batch_size

    def get_anchors(self, type_: str, dm_attrs: list[Attribute], dm_anchor_links: list[Link]) -> list[AnchorRecord]:
        """Load anchors of ``type_`` from ``Anchor_<type_>``, casting values with the data-model dtypes.

        The id is taken from ``<type_>_id``, else ``node_id``, else
        ``"<type_>:<row id>"``. Missing or Buffer ids are skipped. Grist meta
        columns, ``gristHelper_*`` columns, Buffer cells, missing values and empty
        strings are dropped. Foreign-key columns named by ``dm_anchor_links`` are
        parsed from their string form when possible. Duplicate ids keep the first
        row.
        """
        dtypes = {e.name: e.dtype for e in dm_attrs}
        table_name = f"{self.anchor_table_prefix}{type_}"
        default_id_key = "node_id"
        fk_links_cols = [c.anchor_from_link_attr_name for c in dm_anchor_links]

        anchors: dict[str, AnchorRecord] = {}

        def flatten(el):
            # Grist list cells look like ["L", item, ...]; unwrap the items.
            if isinstance(el, list):
                if el and el[0] == "L":
                    return el[1] if len(el) == 2 else el[1:]
                return el
            if self._is_missing(el):
                return None
            return el

        def safe_fk_list_convert(el):
            # FK cells are an int id or a list of int ids cast to string; parse them
            # back, leaving anything that is not a Python literal untouched.
            if isinstance(el, str):
                try:
                    el = literal_eval(el)
                except (SyntaxError, ValueError):  # e.g. "" -> SyntaxError, "abc" -> ValueError
                    print(f"str el not list: {el}")
            return el

        for row in self._iter_table_rows(table_name):
            db_id = flatten(row.get("id"))
            _id = row.get(f"{type_}_id") or row.get(default_id_key) or f"{type_}:{db_id}"
            if self._is_missing(_id) or self._is_buffer(_id):
                continue

            # grist meta columns and already-consumed id columns
            for meta_key in ("id", "node_type", default_id_key, "manualSort"):
                row.pop(meta_key, None)

            row = {
                k: v
                for k, v in row.items()
                if not self._is_buffer(v) and not k.startswith("gristHelper_") and not self._is_missing(v)
            }

            for col in fk_links_cols:
                if col in row:
                    row[col] = safe_fk_list_convert(row[col])

            clean_data = {
                k: cast_dtype(flatten(v), k, dtypes.get(k))
                for k, v in row.items()
                if not isinstance(v, (bytes, type(None))) and v != ""
            }

            if _id in anchors:
                print(f"duplicate {type_} id {_id}, skipping...")
            else:
                anchors[_id] = AnchorRecord(str(_id), type_, clean_data, dp_id=db_id)  # type: ignore

        return list(anchors.values())

    def get_links(self, table_name: str, link: Link) -> list[LinkRecord]:
        """Load links from ``Link_<table_name>``; anchor types come from ``link``.

        Endpoints are read from ``id_from``/``from_node_id`` and ``id_to``/``to_node_id``,
        and the label from ``type``/``edge_label``. Rows missing any of them are
        skipped. Returns ``[]`` when the table's columns cannot be fetched.
        """
        table_name = f"{self.link_table_prefix}{table_name}"
        columns_resp = self._client.columns(table_name)
        if not columns_resp:
            return []

        def is_blank(value) -> bool:
            return self._is_missing(value) or not value

        consumed_keys = ("manualSort", "from_node_id", "to_node_id", "edge_label", "id_from", "id_to", "type", "id")

        links: list[LinkRecord] = []
        for row in self._iter_table_rows(table_name):
            id_from = row.get("id_from") or row.get("from_node_id")
            id_to = row.get("id_to") or row.get("to_node_id")
            edge_label = row.get("type") or row.get("edge_label")
            if is_blank(id_from) or is_blank(id_to) or is_blank(edge_label):
                continue

            for key in consumed_keys:
                row.pop(key, None)

            # Buffer cells are dropped here as well, matching get_anchors.
            clean_data = {
                k: v
                for k, v in row.items()
                if not isinstance(v, (bytes, list))
                and not self._is_buffer(v)
                and not self._is_missing(v)
                and not (isinstance(v, str) and v == "")
                and not k.startswith("gristHelper_")
            }

            links.append(
                LinkRecord(
                    str(id_from),
                    str(id_to),
                    link.anchor_from.noun,
                    link.anchor_to.noun,
                    str(edge_label),
                    clean_data,
                )
            )

        return links

    def get_table(self, table_name: str) -> pd.DataFrame:
        """
        Table with 1st batch_size rows.
        To iterate over the remaining rows use _iter_table_rows
        """
        cols_resp = self._client.columns(table_name)
        columns: list[str] = []
        if cols_resp:
            columns = [c["id"] for c in cols_resp.json()["columns"]]

        rows_data = []
        for i, row in enumerate(self._iter_table_rows(table_name)):
            if i >= self.batch_size:
                break
            # row dict to tuple in deterministic column order
            rows_data.append(tuple(row.get(col) for col in columns))

        return pd.DataFrame(rows_data, columns=columns)
|
vedana_core/db.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
import asyncio
|
|
2
|
+
from functools import cache
|
|
3
|
+
|
|
4
|
+
import sqlalchemy as sa
|
|
5
|
+
import sqlalchemy.ext.asyncio as sa_aio
|
|
6
|
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class DbSettings(BaseSettings):
    """Database connection settings, read from ``JIMS_``-prefixed environment variables or ``.env``."""

    model_config = SettingsConfigDict(
        env_prefix="JIMS_",
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",
    )

    # SQLAlchemy connection URI (presumably set via JIMS_DB_CONN_URI: env_prefix + field name).
    db_conn_uri: str = "postgresql://postgres:postgres@localhost:5432"


# Settings instance built once, at import time of this module.
db_settings = DbSettings()  # type: ignore
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
# This is needed because each async loop needs its own engine
|
|
19
|
+
@cache
|
|
20
|
+
def _create_async_engine(loop):
|
|
21
|
+
return sa_aio.create_async_engine(
|
|
22
|
+
db_settings.db_conn_uri.replace("postgresql://", "postgresql+asyncpg://").replace(
|
|
23
|
+
"sqlite://", "sqlite+aiosqlite://"
|
|
24
|
+
)
|
|
25
|
+
)
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def get_async_db_engine() -> sa_aio.AsyncEngine:
    """Return the async engine for the current event loop, via the per-loop cached factory."""
    current_loop = asyncio.get_event_loop()
    return _create_async_engine(current_loop)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def get_db_engine() -> sa.Engine:
    """Create a synchronous engine from ``db_settings.db_conn_uri``.

    Not cached: every call builds a new engine.
    """
    conn_uri = db_settings.db_conn_uri
    return sa.create_engine(conn_uri)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def get_sessionmaker() -> sa_aio.async_sessionmaker[sa_aio.AsyncSession]:
    """Build an async session factory bound to the current loop's engine.

    Sessions keep loaded attribute values after commit (``expire_on_commit=False``).
    """
    engine = get_async_db_engine()
    return sa_aio.async_sessionmaker(bind=engine, expire_on_commit=False, future=True)
|