vedana-core 0.1.0.dev3__py3-none-any.whl

@@ -0,0 +1,513 @@
+ import abc
+ import io
+ import sqlite3
+ import time
+ from ast import literal_eval
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Any, NamedTuple, Sequence, Union
+
+ import grist_api
+ import pandas as pd
+ import requests
+
+ from vedana_core.data_model import Attribute, Link
+ from vedana_core.utils import cast_dtype
+
+
+ @dataclass
+ class Table:
+     columns: list[str]
+     rows: Sequence[Union[tuple, NamedTuple]]
+
+
+ @dataclass
+ class AnchorRecord:
+     id: str
+     type: str
+     data: dict[str, Any]
+     dp_id: int | None = None
+
+
+ @dataclass
+ class LinkRecord:
+     id_from: str
+     id_to: str
+     anchor_from: str
+     anchor_to: str
+     type: str
+     data: dict[str, Any]
+
+
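For illustration, records as the providers below produce them (all values hypothetical):

    rec = AnchorRecord(id="Person:42", type="Person", data={"name": "Ada"}, dp_id=42)
    edge = LinkRecord(
        id_from="Person:42",
        id_to="Company:7",
        anchor_from="Person",
        anchor_to="Company",
        type="works_at",
        data={"since": 2021},
    )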
+ # Subclassing abc.ABC makes the @abc.abstractmethod decorators in subclasses effective.
+ class DataProvider(abc.ABC):
+     def get_anchors(self, type_: str, dm_attrs: list[Attribute], dm_anchor_links: list[Link]) -> list[AnchorRecord]:
+         raise NotImplementedError("get_anchors must be implemented in subclass")
+
+     def get_links(self, table_name: str, link: Link) -> list[LinkRecord]:
+         raise NotImplementedError("get_links must be implemented in subclass")
+
+     def get_anchor_tables(self) -> list[str]:
+         raise NotImplementedError("get_anchor_tables must be implemented in subclass")
+
+     def list_anchor_tables(self) -> list[str]:
+         raise NotImplementedError("list_anchor_tables must be implemented in subclass")
+
+     def get_link_tables(self) -> list[str]:
+         raise NotImplementedError("get_link_tables must be implemented in subclass")
+
+     def close(self) -> None:
+         pass
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, *args, **kwargs):
+         self.close()
+
+
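All providers share the context-manager protocol above, so callers can rely on close() running even on errors. A minimal sketch (make_provider is a hypothetical factory):

    with make_provider() as dp:
        for anchor_type in dp.get_anchor_tables():
            print(anchor_type)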
+ class GristDataProvider(DataProvider):
+     anchor_table_prefix = "Anchor_"
+     link_table_prefix = "Link_"
+
+     def get_anchor_tables(self) -> list[str]:
+         prefix_len = len(self.anchor_table_prefix)
+         return [t[prefix_len:] for t in self.list_anchor_tables()]
+
+     @abc.abstractmethod
+     def list_anchor_tables(self) -> list[str]:
+         ...
+
+     @abc.abstractmethod
+     def list_link_tables(self) -> list[str]:
+         ...
+
+     @abc.abstractmethod
+     def get_table(self, table_name: str) -> pd.DataFrame:
+         ...
+
+     def get_anchor_types(self) -> list[str]:
+         # same mapping as get_anchor_tables; kept as an alias
+         return self.get_anchor_tables()
+
+     def get_link_tables(self) -> list[str]:
+         prefix_len = len(self.link_table_prefix)
+         return [t[prefix_len:] for t in self.list_link_tables()]
+
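The prefix convention means a Grist table name round-trips to a type name by slicing off the prefix, e.g. (table names hypothetical):

    # list_anchor_tables() -> ["Anchor_Person", "Anchor_Company"]
    # get_anchor_tables()  -> ["Person", "Company"]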
+     def get_anchors(self, type_: str, dm_attrs: list[Attribute], dm_anchor_links: list[Link]) -> list[AnchorRecord]:
+         table_name = f"{self.anchor_table_prefix}{type_}"
+         table = self.get_table(table_name)
+         if "id" not in table.columns:
+             table["id"] = [None] * table.shape[0]
+         table_records: list[dict] = table.to_dict(orient="records")
+
+         id_key = f"{type_}_id"
+         anchor_ids = set()
+         anchors: list[AnchorRecord] = []
+
+         def flatten_list_cells(el):
+             if isinstance(el, list):  # flatten Grist List-type fields, encoded as ["L", v1, v2, ...]
+                 if el and el[0] == "L":
+                     return el[1] if len(el) == 2 else el[1:]
+                 return el
+             if pd.isna(el):
+                 return None
+             return el
+
+         for row_dict in table_records:
+             db_id = flatten_list_cells(row_dict.pop("id"))
+             id_ = row_dict.pop(id_key, None)
+             if id_ is None:
+                 id_ = f"{type_}:{db_id}"
+             if pd.isna(id_):
+                 # print(f"{type_}:{db_id} has {id_key}=nan, skipping")
+                 continue
+
+             row_dict = {k: flatten_list_cells(v) for k, v in row_dict.items()}
+             row_dict = {k: v for k, v in row_dict.items() if not isinstance(v, (bytes, type(None)))}
+
+             if id_ not in anchor_ids:
+                 anchor_ids.add(id_)
+             else:
+                 print(f'Duplicate anchor id "{id_}" in "{table_name}"\nduplicate data: {row_dict}\nrecord skipped')
+                 continue
+
+             anchors.append(AnchorRecord(id_, type_, row_dict, dp_id=db_id))  # type: ignore
+
+         return anchors
+
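For reference, how flatten_list_cells treats the Grist cell encodings (values hypothetical):

    # ["L", "a"]      -> "a"            single-element List cell is unwrapped
    # ["L", "a", "b"] -> ["a", "b"]
    # float("nan")    -> None
    # "plain"         -> "plain"        scalars pass through unchanged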
+     def get_links(self, table_name: str, link: Link) -> list[LinkRecord]:
+         table_name = f"{self.link_table_prefix}{table_name}"
+         table = self.get_table(table_name)
+         if "id" not in table.columns:
+             table["id"] = [None] * table.shape[0]
+         table_records: list[dict] = table.to_dict(orient="records")
+
+         def flatten_list_cells(el):
+             if isinstance(el, list) and el and el[0] == "L" and len(el) == 2:  # flatten List-type fields
+                 el = el[1]
+             return el
+
+         links: list[LinkRecord] = []
+         for row_dict in table_records:
+             id_from = row_dict.pop("from_node_id")
+             id_to = row_dict.pop("to_node_id")
+             type_ = row_dict.pop("edge_label")
+             row_dict.pop("id")
+
+             # validate first: the split() below assumes string ids
+             if not id_from or not id_to or not type_ or pd.isna(type_) or pd.isna(id_from) or pd.isna(id_to):
+                 continue
+
+             # node types - take the type encoded in the id, else fall back to the link's definition
+             node_from = link.anchor_to.noun if id_from.split(":")[0] == link.anchor_to.noun else link.anchor_from.noun
+             node_to = link.anchor_to.noun if node_from != link.anchor_to.noun else link.anchor_from.noun
+
+             row_dict = {str(k): flatten_list_cells(v) for k, v in row_dict.items()}
+             row_dict = {
+                 k: v for k, v in row_dict.items() if not isinstance(v, (bytes, list, type(None))) and not pd.isna(v)
+             }
+
+             links.append(LinkRecord(id_from, id_to, node_from, node_to, type_, row_dict))
+
+         return links
+
+
+ class GristOfflineDataProvider(GristDataProvider):
+     """
+     Load from a Grist backup file (.grist), which is a SQLite database.
+     """
+
+     def __init__(self, sqlite_path: Path) -> None:
+         super().__init__()
+         self._conn = sqlite3.connect(sqlite_path)
+
+     def close(self):
+         self._conn.close()
+
+     def list_anchor_tables(self) -> list[str]:
+         rows = self._conn.execute(
+             f"SELECT name FROM sqlite_master WHERE type='table' AND name LIKE '{self.anchor_table_prefix}%'"
+         ).fetchall()
+         return [row[0] for row in rows]
+
+     def list_link_tables(self) -> list[str]:
+         rows = self._conn.execute(
+             f"SELECT name FROM sqlite_master WHERE type='table' AND name LIKE '{self.link_table_prefix}%'"
+         ).fetchall()
+         return [row[0] for row in rows]
+
+     def get_table(self, table_name: str) -> pd.DataFrame:
+         cur = self._conn.execute(f'SELECT * FROM "{table_name}"')  # quoted: names come from sqlite_master
+         columns = [t[0] for t in cur.description]
+         rows = cur.fetchall()
+         return pd.DataFrame.from_records(rows, columns=columns)
+
+
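A minimal usage sketch (the backup path is hypothetical):

    with GristOfflineDataProvider(Path("crm_backup.grist")) as dp:
        for name in dp.get_anchor_tables():
            print(name, dp.get_table(f"Anchor_{name}").shape)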
+ class GristCsvDataProvider(GristDataProvider):
+     def __init__(self, doc_id: str, grist_server: str, api_key: str | None = None) -> None:
+         super().__init__()
+         self.doc_id = doc_id
+         self.grist_server = grist_server
+         self.api_key = api_key
+         self._table_list: list[str] = self._fetch_table_list()
+
+     def _fetch_table_list(self) -> list[str]:
+         url = f"{self.grist_server}/api/docs/{self.doc_id}/tables"
+         resp = requests.get(url, headers={"Authorization": f"Bearer {self.api_key}"}, timeout=60)
+         resp.raise_for_status()
+         return [t["id"] for t in resp.json()["tables"]]
+
+     def get_table(self, table_name: str) -> pd.DataFrame:
+         url = f"{self.grist_server}/api/docs/{self.doc_id}/download/csv?tableId={table_name}"
+         resp = requests.get(url, headers={"Authorization": f"Bearer {self.api_key}"}, timeout=600)
+         resp.raise_for_status()
+         df = pd.read_csv(io.StringIO(resp.text))
+         if "id" not in df.columns:
+             df = df.reset_index(drop=False, names=["id"])
+         return df
+
+     def _list_tables_with_prefix(self, prefix: str) -> list[str]:
+         return [t for t in self._table_list if t.startswith(prefix)]
+
+     def list_anchor_tables(self) -> list[str]:
+         return self._list_tables_with_prefix(self.anchor_table_prefix)
+
+     def list_link_tables(self) -> list[str]:
+         return self._list_tables_with_prefix(self.link_table_prefix)
+
+
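Instantiation sketch (server URL, document id and key are hypothetical):

    dp = GristCsvDataProvider(
        doc_id="abc123XYZ",
        grist_server="https://grist.example.com",
        api_key="...",
    )
    df = dp.get_table("Anchor_Person")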
+ class GristAPIDataProvider(GristDataProvider):
+     def __init__(self, doc_id: str, grist_server: str, api_key: str | None = None) -> None:
+         super().__init__()
+         self.grist_server = grist_server
+         self.doc_id = doc_id
+         self.api_key = api_key
+         self.headers = {"Authorization": f"Bearer {self.api_key}"}
+         self._client = grist_api.GristDocAPI(doc_id, api_key=api_key, server=grist_server)
+
+     def _list_tables_with_prefix(self, prefix: str) -> list[str]:
+         resp = self._client.tables()
+         if not resp:
+             return []
+         table_ids: list[str] = [table["id"] for table in resp.json()["tables"]]
+         return [t_id for t_id in table_ids if t_id.startswith(prefix)]
+
+     def list_anchor_tables(self) -> list[str]:
+         return self._list_tables_with_prefix(self.anchor_table_prefix)
+
+     def list_link_tables(self) -> list[str]:
+         return self._list_tables_with_prefix(self.link_table_prefix)
+
+     def _list_table_columns(self, table_name: str) -> dict[str, str]:
+         """Returns the mapping {internal_grist_id: id_in_UI}."""
+         # Columns carry no hidden=True/False label, so two API calls are needed.
+         view_cols = self._client.columns(table_name)  # does not include hidden columns, i.e. "id"
+         all_cols = requests.get(  # includes all columns, internal IDs and helper columns too
+             f"{self.grist_server}/api/docs/{self.doc_id}/tables/{table_name}/columns?hidden=True", headers=self.headers
+         )
+         if not view_cols or not all_cols:
+             return {}
+
+         view_cols = view_cols.json()["columns"]
+         all_cols = all_cols.json()["columns"]
+
+         # resolve display columns: a reference column shows the label of its displayCol
+         col_ref_label_map = {c["fields"]["colRef"]: c["fields"]["label"] for c in all_cols}
+
+         parsed_cols = {
+             c["id"]: (
+                 c["fields"]["label"]
+                 if not c["fields"]["displayCol"]
+                 else col_ref_label_map.get(c["fields"]["displayCol"], c["fields"]["label"])
+             )
+             for c in view_cols
+         }
+         return parsed_cols
+
+     def get_table(self, table_name: str) -> pd.DataFrame:
+         columns = self._list_table_columns(table_name)
+         if "id" not in columns:  # add internal id
+             columns["id"] = "id"
+         rows = self._client.fetch_table(table_name)
+         rows = [{c_id: getattr(r, c_label) for c_id, c_label in columns.items()} for r in rows]  # keep only mapped columns
+         return pd.DataFrame(rows, columns=list(columns.keys()))
+
+
+ class GristSQLDataProvider(GristDataProvider):
+     """Fetches rows via the /sql endpoint in chunks."""
+
+     def __init__(
+         self,
+         doc_id: str,
+         grist_server: str,
+         api_key: str | None = None,
+         *,
+         batch_size: int = 800,
+     ) -> None:
+         super().__init__()
+         self.doc_id = doc_id
+         self.grist_server = grist_server
+         self.api_key = api_key
+         self.batch_size = max(1, batch_size)
+         self.headers = {"Authorization": f"Bearer {self.api_key}"}
+         self._client = grist_api.GristDocAPI(doc_id, server=grist_server, api_key=api_key)
+
+     def reset_doc(self, doc_id: str):
+         print("reopening doc")
+         resp = requests.post(
+             f"{self.grist_server}/api/docs/{doc_id}/force-reload",
+             headers=self.headers,
+             timeout=120,
+         )
+         resp.raise_for_status()
+
+     def clean_doc_history(self, doc_id: str):
+         print("cleaning memory")
+         resp = requests.post(
+             f"{self.grist_server}/api/docs/{doc_id}/states/remove",
+             headers=self.headers,
+             json={"keep": 1},
+             timeout=120,
+         )
+         resp.raise_for_status()
+
+     def _sql_endpoint(self) -> str:
+         """Fully qualified URL of the /sql POST endpoint for this document."""
+         return f"{self.grist_server}/api/docs/{self.doc_id}/sql"
+
+     def _run_sql(self, sql: str, att=1) -> list[dict]:
+         try:
+             resp = requests.post(
+                 self._sql_endpoint(),
+                 json={"sql": sql},
+                 headers=self.headers,
+                 timeout=180,
+             )
+             resp.raise_for_status()
+             data = resp.json()  # { "records": [ { "fields": {...} }, ... ] }
+             data = data.get("records", [])
+         except Exception as e:
+             print(f"Failed to fetch {sql}: {e}\n Retrying in 120 seconds...")  # todo: logging
+             time.sleep(120)
+             self.clean_doc_history(self.doc_id)
+             self.reset_doc(self.doc_id)
+             if att < 3:
+                 att += 1
+                 data = self._run_sql(sql, att)
+             else:
+                 raise
+         return data
+
+     def _list_tables_with_prefix(self, prefix: str) -> list[str]:
+         resp = self._client.tables()
+         if not resp:
+             return []
+         table_ids: list[str] = [table["id"] for table in resp.json()["tables"]]
+         return [t_id for t_id in table_ids if t_id.startswith(prefix)]
+
+     def list_anchor_tables(self) -> list[str]:
+         return self._list_tables_with_prefix(self.anchor_table_prefix)
+
+     def list_link_tables(self) -> list[str]:
+         return self._list_tables_with_prefix(self.link_table_prefix)
+
+     def _iter_table_rows(self, table_name: str):
+         offset = 0
+         while True:
+             sql = f'SELECT * FROM "{table_name}" ORDER BY id LIMIT {self.batch_size} OFFSET {offset}'
+             batch = self._run_sql(sql)
+             if not batch:
+                 break
+             for record in batch:
+                 yield record.get("fields", {})
+             offset += self.batch_size
+
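LIMIT/OFFSET pagination rescans the skipped rows on every batch, which degrades on large tables. A keyset variant over the same ordered id column avoids that; a sketch of what such a helper could look like (not part of the package):

    def _iter_table_rows_keyset(self, table_name: str):
        last_id = 0
        while True:
            sql = (
                f'SELECT * FROM "{table_name}" WHERE id > {last_id} '
                f"ORDER BY id LIMIT {self.batch_size}"
            )
            batch = self._run_sql(sql)
            if not batch:
                break
            for record in batch:
                fields = record.get("fields", {})
                yield fields
                last_id = fields.get("id", last_id)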
+     def get_anchors(self, type_: str, dm_attrs: list[Attribute], dm_anchor_links: list[Link]) -> list[AnchorRecord]:
+         dtypes = {e.name: e.dtype for e in dm_attrs}
+         table_name = f"{self.anchor_table_prefix}{type_}"
+         default_id_key = "node_id"
+         fk_links_cols = [c.anchor_from_link_attr_name for c in dm_anchor_links]
+
+         anchors: dict[str, AnchorRecord] = {}
+
+         def flatten(el):
+             if isinstance(el, list):  # Grist List-type cells are encoded as ["L", v1, v2, ...]
+                 if el and el[0] == "L":
+                     return el[1] if len(el) == 2 else el[1:]
+                 return el
+             if pd.isna(el):
+                 return None
+             return el
+
+         def safe_fk_list_convert(el):
+             try:
+                 if isinstance(el, str):
+                     el = literal_eval(el)
+             except (SyntaxError, ValueError):
+                 print(f"str el not list: {el}")
+             return el
+
+         for row in self._iter_table_rows(table_name):
+             db_id = flatten(row.get("id"))
+             _id = row.get(f"{type_}_id") or row.get(default_id_key) or f"{type_}:{db_id}"
+             if pd.isna(_id) or (isinstance(_id, dict) and _id.get("type") == "Buffer"):
+                 continue
+
+             row.pop("id", None)
+             row.pop("node_type", None)
+             row.pop(default_id_key, None)
+             # row.pop(f"{type_}_id", None)
+
+             # Grist meta columns
+             row.pop("manualSort", None)
+
+             row = {
+                 k: v
+                 for k, v in row.items()
+                 if not (isinstance(v, dict) and v.get("type") == "Buffer")
+                 and not k.startswith("gristHelper_")
+                 and (isinstance(v, (list, dict)) or not pd.isna(v))  # pd.isna on a list is ambiguous; lists pass through
+             }
+
+             # foreign-key link columns: either an int id, or a list of int ids cast to a string
+             for col in fk_links_cols:
+                 if col in row:
+                     row[col] = safe_fk_list_convert(row[col])
+
+             clean_data = {
+                 k: cast_dtype(flatten(v), k, dtypes.get(k))
+                 for k, v in row.items()
+                 if not isinstance(v, (bytes, type(None))) and v != ""
+             }
+
+             if _id in anchors:  # duplicate id - keep the first record (attribute merge disabled)
+                 # anchors[_id].data.update(clean_data)
+                 print(f"duplicate {type_} id {_id}, skipping...")
+             else:
+                 anchors[_id] = AnchorRecord(str(_id), type_, clean_data, dp_id=db_id)  # type: ignore
+
+         return list(anchors.values())
+
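Why safe_fk_list_convert catches both exceptions: ast.literal_eval raises ValueError for a bare name and SyntaxError for unparseable text, and foreign-key cells can hold either form (values hypothetical):

    from ast import literal_eval

    literal_eval("[1, 2, 3]")  # -> [1, 2, 3]
    literal_eval("42")         # -> 42
    literal_eval("abc")        # raises ValueError (malformed node)
    literal_eval("abc def")    # raises SyntaxError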
+     def get_links(self, table_name: str, link: Link) -> list[LinkRecord]:
+         table_name = f"{self.link_table_prefix}{table_name}"
+         columns_resp = self._client.columns(table_name)
+         if not columns_resp:  # table missing or not reachable
+             return []
+
+         def flatten(el):
+             if isinstance(el, list) and el and el[0] == "L" and len(el) == 2:
+                 el = el[1]
+             return el
+
+         links: list[LinkRecord] = []
+         for row in self._iter_table_rows(table_name):
+             id_from = row.get("id_from") or row.get("from_node_id")
+             id_to = row.get("id_to") or row.get("to_node_id")
+             edge_label = row.get("type") or row.get("edge_label")
+             if not id_from or not id_to or not edge_label or pd.isna(edge_label) or pd.isna(id_from) or pd.isna(id_to):
+                 continue
+
+             row.pop("manualSort", None)
+             row.pop("from_node_id", None)
+             row.pop("to_node_id", None)
+             row.pop("edge_label", None)
+             row.pop("id_from", None)
+             row.pop("id_to", None)
+             row.pop("type", None)
+             row.pop("id", None)
+
+             clean_data = {
+                 k: flatten(v)
+                 for k, v in row.items()
+                 if not isinstance(v, (bytes, list))
+                 and not pd.isna(v)
+                 and not (isinstance(v, str) and v == "")
+                 and not k.startswith("gristHelper_")
+             }
+
+             links.append(
+                 LinkRecord(
+                     str(id_from),
+                     str(id_to),
+                     link.anchor_from.noun,
+                     link.anchor_to.noun,
+                     str(edge_label),
+                     clean_data,
+                 )
+             )
+
+         return links
+
+     def get_table(self, table_name: str) -> pd.DataFrame:
+         """
+         Return the first batch_size rows of a table.
+         To iterate over the remaining rows use _iter_table_rows.
+         """
+         cols_resp = self._client.columns(table_name)
+         columns: list[str] = []
+         if cols_resp:
+             columns = [c["id"] for c in cols_resp.json()["columns"]]
+
+         rows_data = []
+         for i, row in enumerate(self._iter_table_rows(table_name)):
+             if i >= self.batch_size:
+                 break
+             # row dict to tuple in deterministic column order
+             rows_data.append(tuple(row.get(col) for col in columns))
+
+         return pd.DataFrame(rows_data, columns=columns)
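End-to-end sketch of the SQL provider (server URL, document id, key and data-model arguments are hypothetical):

    provider = GristSQLDataProvider(
        doc_id="abc123XYZ",
        grist_server="https://grist.example.com",
        api_key="...",
        batch_size=500,
    )
    with provider as dp:
        for anchor_type in dp.get_anchor_tables():
            records = dp.get_anchors(anchor_type, dm_attrs=[], dm_anchor_links=[])
            print(anchor_type, len(records))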
vedana_core/db.py ADDED
@@ -0,0 +1,41 @@
+ import asyncio
+ from functools import cache
+
+ import sqlalchemy as sa
+ import sqlalchemy.ext.asyncio as sa_aio
+ from pydantic_settings import BaseSettings, SettingsConfigDict
+
+
+ class DbSettings(BaseSettings):
+     model_config = SettingsConfigDict(env_prefix="JIMS_", env_file=".env", env_file_encoding="utf-8", extra="ignore")
+
+     db_conn_uri: str = "postgresql://postgres:postgres@localhost:5432"
+
+
+ db_settings = DbSettings()  # type: ignore
+
+
+ # Each async event loop needs its own engine, so the loop object serves as the cache key.
+ @cache
+ def _create_async_engine(loop):
+     return sa_aio.create_async_engine(
+         db_settings.db_conn_uri.replace("postgresql://", "postgresql+asyncpg://").replace(
+             "sqlite://", "sqlite+aiosqlite://"
+         )
+     )
+
+
+ def get_async_db_engine() -> sa_aio.AsyncEngine:
+     return _create_async_engine(asyncio.get_event_loop())
+
+
+ def get_db_engine() -> sa.Engine:
+     return sa.create_engine(db_settings.db_conn_uri)
+
+
+ def get_sessionmaker() -> sa_aio.async_sessionmaker[sa_aio.AsyncSession]:
+     return sa_aio.async_sessionmaker(
+         bind=get_async_db_engine(),
+         expire_on_commit=False,
+         future=True,
+     )
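A minimal usage sketch of the session factory (the query is illustrative; assumes JIMS_DB_CONN_URI or the default local Postgres is reachable):

    import asyncio
    import sqlalchemy as sa

    async def main() -> None:
        Session = get_sessionmaker()
        async with Session() as session:
            result = await session.execute(sa.text("SELECT 1"))
            print(result.scalar_one())

    asyncio.run(main())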