tigrbl_engine_numpy 0.1.1.dev2__tar.gz → 0.1.1.dev4__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -30,9 +30,13 @@ pkgs/standards/peagen/.pymon
30
30
  gateway.db
31
31
  kms.db
32
32
  *.asc
33
+ *.kid
33
34
  *.so
34
35
  *.db
35
36
  target/
36
37
  # keep the empty dir in repo
  !.gitkeep
37
38
  pkgs/experimental/swarmakit/libs/svelte/.vscode/extensions.json
38
39
  node_modules/
40
+ *.zip
41
+ .pymon
42
+ /.tmp_pydeps
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tigrbl_engine_numpy
3
- Version: 0.1.1.dev2
3
+ Version: 0.1.1.dev4
4
4
  Summary: NumPy engine plugin for tigrbl with array-to-table helpers.
5
5
  Project-URL: Homepage, https://github.com/swarmauri/swarmauri-sdk
6
6
  Author-email: Jacob Stewart <jacob@swarmauri.com>
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "tigrbl_engine_numpy"
7
- version = "0.1.1.dev2"
7
+ version = "0.1.1.dev4"
8
8
  description = "NumPy engine plugin for tigrbl with array-to-table helpers."
9
9
  readme = "README.md"
10
10
  license = { file = "LICENSE" }
@@ -16,7 +16,49 @@ from typing import (
16
16
  )
17
17
 
18
18
  import numpy as np
19
- from tigrbl.session.base import TigrblSessionBase
19
+
20
try:
    from tigrbl.session.base import TigrblSessionBase
except ImportError:
    # tigrbl (or its ``tigrbl.session.base`` module) is not importable.
    # Catch only ImportError so that a genuine bug raised while importing
    # tigrbl surfaces instead of silently switching to this stub.
    from abc import ABC, abstractmethod

    class TigrblSessionBase(ABC):
        """Minimal stand-in for ``tigrbl.session.base.TigrblSessionBase``.

        Stores an optional session spec and declares the implementation
        hooks a concrete session must provide. ``_add_impl`` is synchronous;
        every other hook is declared as a coroutine. No public methods other
        than ``apply_spec`` are defined here.
        """

        def __init__(self, spec=None):
            self._spec = spec

        def apply_spec(self, spec):
            """Replace the stored session spec."""
            self._spec = spec

        @abstractmethod
        def _add_impl(self, obj): ...

        @abstractmethod
        async def _delete_impl(self, obj): ...

        @abstractmethod
        async def _flush_impl(self): ...

        @abstractmethod
        async def _refresh_impl(self, obj): ...

        @abstractmethod
        async def _get_impl(self, model, ident): ...

        @abstractmethod
        async def _execute_impl(self, stmt): ...

        @abstractmethod
        async def _tx_begin_impl(self): ...

        @abstractmethod
        async def _tx_commit_impl(self): ...

        @abstractmethod
        async def _tx_rollback_impl(self): ...

        @abstractmethod
        async def _close_impl(self): ...
61
+
20
62
 
21
63
  try:
22
64
  from tigrbl.session.spec import SessionSpec
@@ -68,9 +110,22 @@ class NumpySession(TigrblSessionBase):
68
110
  self._snap_ver: Optional[int] = None
69
111
  self._puts: dict[tuple[type, Any], dict[str, Any]] = {}
70
112
  self._dels: set[tuple[type, Any]] = set()
113
+ self._tracked: dict[tuple[type, Any], Any] = {}
114
+ self._tracked: dict[tuple[type, Any], Any] = {}
71
115
 
72
116
  def to_records(self) -> list[dict[str, Any]]:
73
- return [dict(row) for row in self._engine.catalog.rows]
117
+ pk = self._engine.catalog.pk
118
+ by_id: dict[Any, dict[str, Any]] = {}
119
+ for row in self._engine.catalog.rows:
120
+ ident = row.get(pk)
121
+ if ident is None:
122
+ continue
123
+ by_id[ident] = dict(row)
124
+ for (_, ident), row in self._puts.items():
125
+ by_id[ident] = dict(row)
126
+ for _, ident in self._dels:
127
+ by_id.pop(ident, None)
128
+ return list(by_id.values())
74
129
 
75
130
  def array(self) -> np.ndarray:
76
131
  rows = self.to_records()
@@ -155,6 +210,8 @@ class NumpySession(TigrblSessionBase):
155
210
  self._snap_ver = self._engine.catalog.table_ver
156
211
  self._puts.clear()
157
212
  self._dels.clear()
213
+ self._tracked.clear()
214
+ self._tracked.clear()
158
215
 
159
216
  async def _tx_commit_impl(self) -> None:
160
217
  iso = (self._spec.isolation if self._spec else None) or "read_committed"
@@ -184,17 +241,34 @@ class NumpySession(TigrblSessionBase):
184
241
  )
185
242
  self._puts.clear()
186
243
  self._dels.clear()
244
+ self._tracked.clear()
245
+ self._tracked.clear()
187
246
 
188
247
  async def _tx_rollback_impl(self) -> None:
189
248
  self._puts.clear()
190
249
  self._dels.clear()
250
+ self._tracked.clear()
191
251
 
192
252
  def _add_impl(self, obj: Any) -> Any:
193
253
  model = obj.__class__
194
254
  pk = _single_pk_name(model)
195
255
  ident = getattr(obj, pk)
256
+ if ident is None:
257
+ default = getattr(getattr(model, "__table__", None), "columns", {}).get(pk)
258
+ default = getattr(default, "default", None)
259
+ if default is not None:
260
+ arg = getattr(default, "arg", default)
261
+ if callable(arg):
262
+ try:
263
+ ident = arg()
264
+ except TypeError:
265
+ ident = arg(None)
266
+ else:
267
+ ident = arg
268
+ setattr(obj, pk, ident)
196
269
  if ident is None:
197
270
  raise ValueError(f"primary key {pk!r} must be set")
271
+ self._tracked[(model, ident)] = obj
198
272
  row = {c: getattr(obj, c, None) for c in _model_columns(model)}
199
273
  self._puts[(model, ident)] = row
200
274
  self._dels.discard((model, ident))
@@ -206,30 +280,43 @@ class NumpySession(TigrblSessionBase):
206
280
  ident = getattr(obj, pk)
207
281
  self._puts.pop((model, ident), None)
208
282
  self._dels.add((model, ident))
283
+ self._tracked.pop((model, ident), None)
284
+ self._tracked.pop((model, ident), None)
209
285
 
210
286
  async def _flush_impl(self) -> None:
211
- return
287
+ for (model, ident), obj in list(self._tracked.items()):
288
+ if (model, ident) in self._dels:
289
+ continue
290
+ row = {c: getattr(obj, c, None) for c in _model_columns(model)}
291
+ baseline = self._resolve_row(model, ident)
292
+ if baseline is None:
293
+ continue
294
+ if row == dict(baseline):
295
+ continue
296
+ if self._spec and self._spec.read_only:
297
+ raise RuntimeError("read-only session: writes detected during flush")
298
+ self._puts[(model, ident)] = row
212
299
 
213
300
  async def _refresh_impl(self, obj: Any) -> None:
214
301
  pk = _single_pk_name(obj.__class__)
215
302
  ident = getattr(obj, pk)
216
- fresh = await self._get_impl(obj.__class__, ident)
217
- if fresh is None:
303
+ row = self._resolve_row(obj.__class__, ident)
304
+ if row is None:
218
305
  return
219
306
  for c in _model_columns(obj.__class__):
220
- setattr(obj, c, getattr(fresh, c, None))
307
+ if c in row:
308
+ setattr(obj, c, row[c])
221
309
 
222
310
  async def _get_impl(self, model: type, ident: Any) -> Any | None:
223
- row = self._puts.get((model, ident))
224
- if row is not None:
225
- return self._inflate(model, row)
226
311
  if (model, ident) in self._dels:
227
312
  return None
228
- pk = _single_pk_name(model)
229
- for record in self._engine.catalog.rows:
230
- if record.get(pk) == ident:
231
- return self._inflate(model, record)
232
- return None
313
+ tracked = self._tracked.get((model, ident))
314
+ if tracked is not None:
315
+ return tracked
316
+ row = self._resolve_row(model, ident)
317
+ if row is None:
318
+ return None
319
+ return self._hydrate_tracked(model, ident, row)
233
320
 
234
321
  async def _execute_impl(self, stmt: Any) -> Any:
235
322
  kind = type(stmt).__name__.lower()
@@ -252,6 +339,7 @@ class NumpySession(TigrblSessionBase):
252
339
  raise NotImplementedError(f"Unsupported statement: {type(stmt)}")
253
340
 
254
341
  async def _close_impl(self) -> None:
342
+ self._tracked.clear()
255
343
  return
256
344
 
257
345
  def _inflate(self, model: type, data: Mapping[str, Any]) -> Any:
@@ -262,17 +350,43 @@ class NumpySession(TigrblSessionBase):
262
350
  return obj
263
351
 
264
352
  def _scan_model(self, model: type) -> List[Any]:
265
- out = [self._inflate(model, row) for row in self._engine.catalog.rows]
266
353
  pk = _single_pk_name(model)
267
- by_id = {getattr(obj, pk): obj for obj in out}
354
+ by_id: dict[Any, Any] = {}
355
+ for row in self._engine.catalog.rows:
356
+ ident = row.get(pk)
357
+ if ident is None:
358
+ continue
359
+ by_id[ident] = self._hydrate_tracked(model, ident, row)
268
360
  for (m, ident), row in self._puts.items():
269
361
  if m is model:
270
- by_id[ident] = self._inflate(model, row)
362
+ by_id[ident] = self._hydrate_tracked(model, ident, row)
271
363
  for m, ident in self._dels:
272
364
  if m is model:
273
365
  by_id.pop(ident, None)
366
+ self._tracked.pop((model, ident), None)
274
367
  return list(by_id.values())
275
368
 
369
def _resolve_row(self, model: type, ident: Any) -> Mapping[str, Any] | None:
    """Return the row the session currently sees for ``(model, ident)``.

    A pending put wins. Otherwise the first catalog row whose primary-key
    column equals ``ident`` is returned; this is the catalog mapping itself,
    not a copy. Pending deletes are not consulted here. Returns ``None`` when
    neither source has the identity.
    """
    pending = self._puts.get((model, ident))
    if pending is not None:
        return pending
    pk = _single_pk_name(model)
    return next(
        (record for record in self._engine.catalog.rows if record.get(pk) == ident),
        None,
    )
378
+
379
+ def _hydrate_tracked(self, model: type, ident: Any, data: Mapping[str, Any]) -> Any:
380
+ obj = self._tracked.get((model, ident))
381
+ if obj is None:
382
+ obj = self._inflate(model, data)
383
+ self._tracked[(model, ident)] = obj
384
+ return obj
385
+ for c in _model_columns(model):
386
+ if c in data:
387
+ setattr(obj, c, data[c])
388
+ return obj
389
+
276
390
  def _decompose_select(
277
391
  self, stmt: Any
278
392
  ) -> Tuple[
@@ -302,11 +416,20 @@ class NumpySession(TigrblSessionBase):
302
416
 
303
417
  def _all_subclasses(base: type) -> list[type]:
304
418
  out: list[type] = []
305
- stack = list(base.__subclasses__())
419
+ stack = [base]
420
+ seen: set[type] = set()
306
421
  while stack:
307
422
  cls = stack.pop()
308
- out.append(cls)
309
- stack.extend(cls.__subclasses__())
423
+ try:
424
+ children = cls.__subclasses__()
425
+ except TypeError:
426
+ continue
427
+ for child in children:
428
+ if child in seen:
429
+ continue
430
+ seen.add(child)
431
+ out.append(child)
432
+ stack.append(child)
310
433
  return out
311
434
 
312
435
  def _find_by_table(name: str) -> type | None:
@@ -315,6 +438,13 @@ class NumpySession(TigrblSessionBase):
315
438
  return cls
316
439
  return None
317
440
 
441
+ table = getattr(stmt, "table", None)
442
+ name = getattr(table, "name", None)
443
+ if isinstance(name, str):
444
+ found = _find_by_table(name)
445
+ if found is not None:
446
+ return found
447
+
318
448
  for attr_name in ("_from_objects", "_froms", "froms"):
319
449
  value = getattr(stmt, attr_name, None)
320
450
  if value is not None:
@@ -326,6 +456,27 @@ class NumpySession(TigrblSessionBase):
326
456
  found = _find_by_table(name)
327
457
  if found is not None:
328
458
  return found
459
+
460
+ table = getattr(stmt, "table", None)
461
+ name = getattr(table, "name", None)
462
+ if isinstance(name, str):
463
+ found = _find_by_table(name)
464
+ if found is not None:
465
+ return found
466
+
467
+ raw_columns = getattr(stmt, "_raw_columns", None) or getattr(
468
+ stmt, "columns", None
469
+ )
470
+ if raw_columns is not None:
471
+ if isinstance(raw_columns, (list, tuple)) and not raw_columns:
472
+ raise RuntimeError("Cannot resolve model from statement")
473
+ entity = raw_columns[0]
474
+ table = getattr(entity, "table", None)
475
+ name = getattr(table, "name", None)
476
+ if isinstance(name, str):
477
+ found = _find_by_table(name)
478
+ if found is not None:
479
+ return found
329
480
  raise RuntimeError("Cannot resolve model from statement")
330
481
 
331
482
  def _extract_predicates(self, stmt: Any) -> list[Tuple[str, str, Any]]:
@@ -4,10 +4,34 @@ from pathlib import Path
4
4
 
5
5
  import numpy as np
6
6
  import pytest
7
+ from tigrbl.specs import F, IO, S
8
+ from tigrbl.shortcuts import acol
9
+ from tigrbl.table import Table
10
+ from tigrbl.types import Mapped, String
7
11
 
8
12
  from tigrbl_engine_numpy import numpy_engine
9
13
 
10
14
 
15
class _Widget(Table):
    """Minimal string-keyed table used by the session identity-map tests."""

    __tablename__ = "session_widgets"

    # String primary key; IO spec lists it for the "read" and "list" verbs.
    id: Mapped[str] = acol(
        storage=S(type_=String(64), primary_key=True, nullable=False),
        field=F(py_type=str),
        io=IO(out_verbs=("read", "list")),
    )

    # Non-null name; IO spec lists it as input and mutable on
    # create/update/replace and as output on read/list.
    name: Mapped[str] = acol(
        storage=S(type_=String(50), nullable=False),
        field=F(py_type=str),
        io=IO(
            in_verbs=("create", "update", "replace"),
            out_verbs=("read", "list"),
            mutable_verbs=("create", "update", "replace"),
        ),
    )
+
34
+
11
35
  def test_numpy_session_save_and_load_npy(tmp_path: Path) -> None:
12
36
  target = tmp_path / "records.npy"
13
37
  _, session_factory = numpy_engine(
@@ -134,3 +158,48 @@ def test_numpy_session_save_uses_atomic_replace(
134
158
  assert len(calls) == 1
135
159
  assert calls[0][1] == str(target)
136
160
  assert Path(calls[0][0]).name.startswith(".tmp_")
161
+
162
+
163
@pytest.mark.asyncio
async def test_numpy_session_get_reuses_tracked_instance() -> None:
    """Two ``get`` calls for the same key return one, already-mutated object."""
    key = "fixed-id"
    table = np.array([[key, "a"]], dtype=object)
    _, make_session = numpy_engine(
        mapping={"array": table, "columns": ["id", "name"], "pk": "id"}
    )
    session = make_session()

    widget = await session.get(_Widget, key)
    assert widget is not None
    widget.name = "mutated"

    # The second lookup must hit the tracked instance, not re-inflate a row.
    again = await session.get(_Widget, key)
    assert again is widget
    assert again.name == "mutated"
182
+
183
+
184
@pytest.mark.asyncio
async def test_numpy_session_refresh_updates_tracked_instance() -> None:
    """``refresh`` copies catalog values onto the tracked object and keeps its identity."""
    key = "fixed-id"
    table = np.array([[key, "a"]], dtype=object)
    engine, make_session = numpy_engine(
        mapping={"array": table, "columns": ["id", "name"], "pk": "id"}
    )
    session = make_session()

    widget = await session.get(_Widget, key)
    assert widget is not None
    widget.name = "mutated"

    # Simulate an out-of-band change to the stored row, then reload it.
    engine.catalog.rows[0]["name"] = "server"
    await session.refresh(widget)

    reloaded = await session.get(_Widget, key)
    assert reloaded is widget
    assert reloaded.name == "server"
@@ -4,10 +4,11 @@ import numpy as np
4
4
  import pytest
5
5
 
6
6
  from tigrbl import TigrblApp
7
- from tigrbl.bindings import rpc_call
7
+ from tigrbl import rpc_call
8
8
  from tigrbl.engine import EngineSpec
9
9
  from tigrbl.orm.mixins import GUIDPk
10
- from tigrbl.specs import F, IO, S, acol
10
+ from tigrbl.specs import F, IO, S
11
+ from tigrbl.shortcuts import acol
11
12
  from tigrbl.table import Table
12
13
  from tigrbl.types import Mapped, String
13
14
 
@@ -44,7 +45,7 @@ def app_and_db() -> tuple[TigrblApp, object]:
44
45
  db = session_factory()
45
46
 
46
47
  api = TigrblApp(engine=spec)
47
- api.include_model(NumpyWidget, mount_router=False)
48
+ api.include_table(NumpyWidget, mount_router=False)
48
49
  api.initialize()
49
50
  return api, db
50
51