tigrbl_engine_numpy 0.1.1.dev1__tar.gz → 0.1.1.dev3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: tigrbl_engine_numpy
3
- Version: 0.1.1.dev1
3
+ Version: 0.1.1.dev3
4
4
  Summary: NumPy engine plugin for tigrbl with array-to-table helpers.
5
5
  Project-URL: Homepage, https://github.com/swarmauri/swarmauri-sdk
6
6
  Author-email: Jacob Stewart <jacob@swarmauri.com>
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
4
4
 
5
5
  [project]
6
6
  name = "tigrbl_engine_numpy"
7
- version = "0.1.1.dev1"
7
+ version = "0.1.1.dev3"
8
8
  description = "NumPy engine plugin for tigrbl with array-to-table helpers."
9
9
  readme = "README.md"
10
10
  license = { file = "LICENSE" }
@@ -68,9 +68,16 @@ class NumpySession(TigrblSessionBase):
68
68
  self._snap_ver: Optional[int] = None
69
69
  self._puts: dict[tuple[type, Any], dict[str, Any]] = {}
70
70
  self._dels: set[tuple[type, Any]] = set()
71
+ self._tracked: dict[tuple[type, Any], Any] = {}
71
72
 
72
73
  def to_records(self) -> list[dict[str, Any]]:
73
- return [dict(row) for row in self._engine.catalog.rows]
74
+ pk = self._engine.catalog.pk
75
+ rows = [dict(row) for row in self._engine.catalog.rows]
76
+ deleted = {ident for (_model, ident) in self._dels}
77
+ by_pk = {row.get(pk): row for row in rows if row.get(pk) not in deleted}
78
+ for (_model, ident), row in self._puts.items():
79
+ by_pk[ident] = dict(row)
80
+ return list(by_pk.values())
74
81
 
75
82
  def array(self) -> np.ndarray:
76
83
  rows = self.to_records()
@@ -155,6 +162,7 @@ class NumpySession(TigrblSessionBase):
155
162
  self._snap_ver = self._engine.catalog.table_ver
156
163
  self._puts.clear()
157
164
  self._dels.clear()
165
+ self._tracked.clear()
158
166
 
159
167
  async def _tx_commit_impl(self) -> None:
160
168
  iso = (self._spec.isolation if self._spec else None) or "read_committed"
@@ -184,15 +192,43 @@ class NumpySession(TigrblSessionBase):
184
192
  )
185
193
  self._puts.clear()
186
194
  self._dels.clear()
195
+ self._tracked.clear()
187
196
 
188
197
  async def _tx_rollback_impl(self) -> None:
189
198
  self._puts.clear()
190
199
  self._dels.clear()
200
+ self._tracked.clear()
201
+
202
+ @staticmethod
203
+ def _pk_default(model: type, pk: str) -> Any:
204
+ table = getattr(model, "__table__", None)
205
+ if table is None:
206
+ return None
207
+ try:
208
+ column = table.columns.get(pk)
209
+ except Exception:
210
+ return None
211
+ if column is None:
212
+ return None
213
+ default = getattr(column, "default", None)
214
+ if default is None:
215
+ return None
216
+ arg = getattr(default, "arg", None)
217
+ if callable(arg):
218
+ try:
219
+ return arg()
220
+ except TypeError:
221
+ return arg(None)
222
+ return arg
191
223
 
192
224
  def _add_impl(self, obj: Any) -> Any:
193
225
  model = obj.__class__
194
226
  pk = _single_pk_name(model)
195
227
  ident = getattr(obj, pk)
228
+ if ident is None:
229
+ ident = self._pk_default(model, pk)
230
+ if ident is not None:
231
+ setattr(obj, pk, ident)
196
232
  if ident is None:
197
233
  raise ValueError(f"primary key {pk!r} must be set")
198
234
  row = {c: getattr(obj, c, None) for c in _model_columns(model)}
@@ -206,8 +242,15 @@ class NumpySession(TigrblSessionBase):
206
242
  ident = getattr(obj, pk)
207
243
  self._puts.pop((model, ident), None)
208
244
  self._dels.add((model, ident))
245
+ self._tracked.pop((model, ident), None)
209
246
 
210
247
  async def _flush_impl(self) -> None:
248
+ for (model, ident), obj in self._tracked.items():
249
+ if (model, ident) in self._dels:
250
+ continue
251
+ self._puts[(model, ident)] = {
252
+ column: getattr(obj, column, None) for column in _model_columns(model)
253
+ }
211
254
  return
212
255
 
213
256
  async def _refresh_impl(self, obj: Any) -> None:
@@ -222,13 +265,17 @@ class NumpySession(TigrblSessionBase):
222
265
  async def _get_impl(self, model: type, ident: Any) -> Any | None:
223
266
  row = self._puts.get((model, ident))
224
267
  if row is not None:
225
- return self._inflate(model, row)
268
+ obj = self._inflate(model, row)
269
+ self._tracked[(model, ident)] = obj
270
+ return obj
226
271
  if (model, ident) in self._dels:
227
272
  return None
228
273
  pk = _single_pk_name(model)
229
274
  for record in self._engine.catalog.rows:
230
275
  if record.get(pk) == ident:
231
- return self._inflate(model, record)
276
+ obj = self._inflate(model, record)
277
+ self._tracked[(model, ident)] = obj
278
+ return obj
232
279
  return None
233
280
 
234
281
  async def _execute_impl(self, stmt: Any) -> Any:
@@ -306,7 +353,10 @@ class NumpySession(TigrblSessionBase):
306
353
  while stack:
307
354
  cls = stack.pop()
308
355
  out.append(cls)
309
- stack.extend(cls.__subclasses__())
356
+ try:
357
+ stack.extend(cls.__subclasses__())
358
+ except TypeError:
359
+ stack.extend(type.__subclasses__(cls))
310
360
  return out
311
361
 
312
362
  def _find_by_table(name: str) -> type | None:
@@ -326,6 +376,27 @@ class NumpySession(TigrblSessionBase):
326
376
  found = _find_by_table(name)
327
377
  if found is not None:
328
378
  return found
379
+
380
+ table = getattr(stmt, "table", None)
381
+ name = getattr(table, "name", None)
382
+ if isinstance(name, str):
383
+ found = _find_by_table(name)
384
+ if found is not None:
385
+ return found
386
+
387
+ raw_columns = getattr(stmt, "_raw_columns", None) or getattr(
388
+ stmt, "columns", None
389
+ )
390
+ if raw_columns is not None:
391
+ if isinstance(raw_columns, (list, tuple)) and not raw_columns:
392
+ raise RuntimeError("Cannot resolve model from statement")
393
+ entity = raw_columns[0]
394
+ table = getattr(entity, "table", None)
395
+ name = getattr(table, "name", None)
396
+ if isinstance(name, str):
397
+ found = _find_by_table(name)
398
+ if found is not None:
399
+ return found
329
400
  raise RuntimeError("Cannot resolve model from statement")
330
401
 
331
402
  def _extract_predicates(self, stmt: Any) -> list[Tuple[str, str, Any]]:
@@ -44,7 +44,7 @@ def app_and_db() -> tuple[TigrblApp, object]:
44
44
  db = session_factory()
45
45
 
46
46
  api = TigrblApp(engine=spec)
47
- api.include_model(NumpyWidget, mount_router=False)
47
+ api.include_table(NumpyWidget, mount_router=False)
48
48
  api.initialize()
49
49
  return api, db
50
50