letta-nightly 0.7.15.dev20250515104317__py3-none-any.whl → 0.7.17.dev20250516090339__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- letta/__init__.py +1 -1
- letta/agent.py +12 -0
- letta/agents/helpers.py +48 -5
- letta/agents/letta_agent.py +64 -28
- letta/agents/letta_agent_batch.py +44 -26
- letta/agents/voice_sleeptime_agent.py +6 -4
- letta/client/client.py +16 -1
- letta/constants.py +3 -0
- letta/functions/async_composio_toolset.py +1 -1
- letta/interfaces/anthropic_streaming_interface.py +40 -6
- letta/interfaces/openai_streaming_interface.py +303 -0
- letta/jobs/llm_batch_job_polling.py +6 -2
- letta/orm/agent.py +102 -1
- letta/orm/block.py +3 -0
- letta/orm/sqlalchemy_base.py +459 -158
- letta/schemas/agent.py +10 -2
- letta/schemas/block.py +3 -0
- letta/schemas/memory.py +7 -2
- letta/server/rest_api/routers/v1/agents.py +29 -27
- letta/server/rest_api/routers/v1/blocks.py +1 -1
- letta/server/rest_api/routers/v1/groups.py +2 -2
- letta/server/rest_api/routers/v1/messages.py +11 -11
- letta/server/rest_api/routers/v1/runs.py +2 -2
- letta/server/rest_api/routers/v1/tools.py +4 -4
- letta/server/rest_api/routers/v1/users.py +9 -9
- letta/server/rest_api/routers/v1/voice.py +1 -1
- letta/server/server.py +74 -0
- letta/services/agent_manager.py +417 -7
- letta/services/block_manager.py +12 -8
- letta/services/helpers/agent_manager_helper.py +19 -0
- letta/services/job_manager.py +99 -0
- letta/services/llm_batch_manager.py +28 -27
- letta/services/message_manager.py +66 -19
- letta/services/passage_manager.py +14 -0
- letta/services/tool_executor/tool_executor.py +19 -1
- letta/services/tool_manager.py +13 -3
- letta/services/user_manager.py +70 -0
- letta/types/__init__.py +0 -0
- {letta_nightly-0.7.15.dev20250515104317.dist-info → letta_nightly-0.7.17.dev20250516090339.dist-info}/METADATA +3 -3
- {letta_nightly-0.7.15.dev20250515104317.dist-info → letta_nightly-0.7.17.dev20250516090339.dist-info}/RECORD +43 -41
- {letta_nightly-0.7.15.dev20250515104317.dist-info → letta_nightly-0.7.17.dev20250516090339.dist-info}/LICENSE +0 -0
- {letta_nightly-0.7.15.dev20250515104317.dist-info → letta_nightly-0.7.17.dev20250516090339.dist-info}/WHEEL +0 -0
- {letta_nightly-0.7.15.dev20250515104317.dist-info → letta_nightly-0.7.17.dev20250516090339.dist-info}/entry_points.txt +0 -0
letta/orm/sqlalchemy_base.py
CHANGED
@@ -114,154 +114,323 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
114
114
|
if before_obj and after_obj and before_obj.created_at < after_obj.created_at:
|
115
115
|
raise ValueError("'before' reference must be later than 'after' reference")
|
116
116
|
|
117
|
+
query = cls._list_preprocess(
|
118
|
+
before_obj=before_obj,
|
119
|
+
after_obj=after_obj,
|
120
|
+
start_date=start_date,
|
121
|
+
end_date=end_date,
|
122
|
+
limit=limit,
|
123
|
+
query_text=query_text,
|
124
|
+
query_embedding=query_embedding,
|
125
|
+
ascending=ascending,
|
126
|
+
tags=tags,
|
127
|
+
match_all_tags=match_all_tags,
|
128
|
+
actor=actor,
|
129
|
+
access=access,
|
130
|
+
access_type=access_type,
|
131
|
+
join_model=join_model,
|
132
|
+
join_conditions=join_conditions,
|
133
|
+
identifier_keys=identifier_keys,
|
134
|
+
identity_id=identity_id,
|
135
|
+
**kwargs,
|
136
|
+
)
|
137
|
+
|
138
|
+
# Execute the query
|
139
|
+
results = session.execute(query)
|
140
|
+
|
141
|
+
results = list(results.scalars())
|
142
|
+
results = cls._list_postprocess(
|
143
|
+
before=before,
|
144
|
+
after=after,
|
145
|
+
limit=limit,
|
146
|
+
results=results,
|
147
|
+
)
|
148
|
+
|
149
|
+
return results
|
150
|
+
|
151
|
+
@classmethod
|
152
|
+
@handle_db_timeout
|
153
|
+
async def list_async(
|
154
|
+
cls,
|
155
|
+
*,
|
156
|
+
db_session: "AsyncSession",
|
157
|
+
before: Optional[str] = None,
|
158
|
+
after: Optional[str] = None,
|
159
|
+
start_date: Optional[datetime] = None,
|
160
|
+
end_date: Optional[datetime] = None,
|
161
|
+
limit: Optional[int] = 50,
|
162
|
+
query_text: Optional[str] = None,
|
163
|
+
query_embedding: Optional[List[float]] = None,
|
164
|
+
ascending: bool = True,
|
165
|
+
tags: Optional[List[str]] = None,
|
166
|
+
match_all_tags: bool = False,
|
167
|
+
actor: Optional["User"] = None,
|
168
|
+
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
169
|
+
access_type: AccessType = AccessType.ORGANIZATION,
|
170
|
+
join_model: Optional[Base] = None,
|
171
|
+
join_conditions: Optional[Union[Tuple, List]] = None,
|
172
|
+
identifier_keys: Optional[List[str]] = None,
|
173
|
+
identity_id: Optional[str] = None,
|
174
|
+
**kwargs,
|
175
|
+
) -> List["SqlalchemyBase"]:
|
176
|
+
"""
|
177
|
+
Async version of list method above.
|
178
|
+
NOTE: Keep in sync.
|
179
|
+
List records with before/after pagination, ordering by created_at.
|
180
|
+
Can use both before and after to fetch a window of records.
|
181
|
+
|
182
|
+
Args:
|
183
|
+
db_session: SQLAlchemy session
|
184
|
+
before: ID of item to paginate before (upper bound)
|
185
|
+
after: ID of item to paginate after (lower bound)
|
186
|
+
start_date: Filter items after this date
|
187
|
+
end_date: Filter items before this date
|
188
|
+
limit: Maximum number of items to return
|
189
|
+
query_text: Text to search for
|
190
|
+
query_embedding: Vector to search for similar embeddings
|
191
|
+
ascending: Sort direction
|
192
|
+
tags: List of tags to filter by
|
193
|
+
match_all_tags: If True, return items matching all tags. If False, match any tag.
|
194
|
+
**kwargs: Additional filters to apply
|
195
|
+
"""
|
196
|
+
if start_date and end_date and start_date > end_date:
|
197
|
+
raise ValueError("start_date must be earlier than or equal to end_date")
|
198
|
+
|
199
|
+
logger.debug(f"Listing {cls.__name__} with kwarg filters {kwargs}")
|
200
|
+
|
201
|
+
async with db_session as session:
|
202
|
+
# Get the reference objects for pagination
|
203
|
+
before_obj = None
|
204
|
+
after_obj = None
|
205
|
+
|
206
|
+
if before:
|
207
|
+
before_obj = await session.get(cls, before)
|
208
|
+
if not before_obj:
|
209
|
+
raise NoResultFound(f"No {cls.__name__} found with id {before}")
|
210
|
+
|
211
|
+
if after:
|
212
|
+
after_obj = await session.get(cls, after)
|
213
|
+
if not after_obj:
|
214
|
+
raise NoResultFound(f"No {cls.__name__} found with id {after}")
|
215
|
+
|
216
|
+
# Validate that before comes after the after object if both are provided
|
217
|
+
if before_obj and after_obj and before_obj.created_at < after_obj.created_at:
|
218
|
+
raise ValueError("'before' reference must be later than 'after' reference")
|
219
|
+
|
220
|
+
query = cls._list_preprocess(
|
221
|
+
before_obj=before_obj,
|
222
|
+
after_obj=after_obj,
|
223
|
+
start_date=start_date,
|
224
|
+
end_date=end_date,
|
225
|
+
limit=limit,
|
226
|
+
query_text=query_text,
|
227
|
+
query_embedding=query_embedding,
|
228
|
+
ascending=ascending,
|
229
|
+
tags=tags,
|
230
|
+
match_all_tags=match_all_tags,
|
231
|
+
actor=actor,
|
232
|
+
access=access,
|
233
|
+
access_type=access_type,
|
234
|
+
join_model=join_model,
|
235
|
+
join_conditions=join_conditions,
|
236
|
+
identifier_keys=identifier_keys,
|
237
|
+
identity_id=identity_id,
|
238
|
+
**kwargs,
|
239
|
+
)
|
240
|
+
|
241
|
+
# Execute the query
|
242
|
+
results = await session.execute(query)
|
243
|
+
|
244
|
+
results = list(results.scalars())
|
245
|
+
results = cls._list_postprocess(
|
246
|
+
before=before,
|
247
|
+
after=after,
|
248
|
+
limit=limit,
|
249
|
+
results=results,
|
250
|
+
)
|
251
|
+
|
252
|
+
return results
|
253
|
+
|
254
|
+
@classmethod
|
255
|
+
def _list_preprocess(
|
256
|
+
cls,
|
257
|
+
*,
|
258
|
+
before_obj,
|
259
|
+
after_obj,
|
260
|
+
start_date: Optional[datetime] = None,
|
261
|
+
end_date: Optional[datetime] = None,
|
262
|
+
limit: Optional[int] = 50,
|
263
|
+
query_text: Optional[str] = None,
|
264
|
+
query_embedding: Optional[List[float]] = None,
|
265
|
+
ascending: bool = True,
|
266
|
+
tags: Optional[List[str]] = None,
|
267
|
+
match_all_tags: bool = False,
|
268
|
+
actor: Optional["User"] = None,
|
269
|
+
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
270
|
+
access_type: AccessType = AccessType.ORGANIZATION,
|
271
|
+
join_model: Optional[Base] = None,
|
272
|
+
join_conditions: Optional[Union[Tuple, List]] = None,
|
273
|
+
identifier_keys: Optional[List[str]] = None,
|
274
|
+
identity_id: Optional[str] = None,
|
275
|
+
**kwargs,
|
276
|
+
):
|
277
|
+
"""
|
278
|
+
Constructs the query for listing records.
|
279
|
+
"""
|
280
|
+
query = select(cls)
|
281
|
+
|
282
|
+
if join_model and join_conditions:
|
283
|
+
query = query.join(join_model, and_(*join_conditions))
|
284
|
+
|
285
|
+
# Apply access predicate if actor is provided
|
286
|
+
if actor:
|
287
|
+
query = cls.apply_access_predicate(query, actor, access, access_type)
|
288
|
+
|
289
|
+
# Handle tag filtering if the model has tags
|
290
|
+
if tags and hasattr(cls, "tags"):
|
117
291
|
query = select(cls)
|
118
292
|
|
119
|
-
if
|
120
|
-
|
293
|
+
if match_all_tags:
|
294
|
+
# Match ALL tags - use subqueries
|
295
|
+
subquery = (
|
296
|
+
select(cls.tags.property.mapper.class_.agent_id)
|
297
|
+
.where(cls.tags.property.mapper.class_.tag.in_(tags))
|
298
|
+
.group_by(cls.tags.property.mapper.class_.agent_id)
|
299
|
+
.having(func.count() == len(tags))
|
300
|
+
)
|
301
|
+
query = query.filter(cls.id.in_(subquery))
|
302
|
+
else:
|
303
|
+
# Match ANY tag - use join and filter
|
304
|
+
query = (
|
305
|
+
query.join(cls.tags).filter(cls.tags.property.mapper.class_.tag.in_(tags)).distinct(cls.id).order_by(cls.id)
|
306
|
+
) # Deduplicate results
|
307
|
+
|
308
|
+
# select distinct primary key
|
309
|
+
query = query.distinct(cls.id).order_by(cls.id)
|
310
|
+
|
311
|
+
if identifier_keys and hasattr(cls, "identities"):
|
312
|
+
query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.identifier_key.in_(identifier_keys))
|
313
|
+
|
314
|
+
# given the identity_id, we can find within the agents table any agents that have the identity_id in their identity_ids
|
315
|
+
if identity_id and hasattr(cls, "identities"):
|
316
|
+
query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.id == identity_id)
|
317
|
+
|
318
|
+
# Apply filtering logic from kwargs
|
319
|
+
for key, value in kwargs.items():
|
320
|
+
if "." in key:
|
321
|
+
# Handle joined table columns
|
322
|
+
table_name, column_name = key.split(".")
|
323
|
+
joined_table = locals().get(table_name) or globals().get(table_name)
|
324
|
+
column = getattr(joined_table, column_name)
|
325
|
+
else:
|
326
|
+
# Handle columns from main table
|
327
|
+
column = getattr(cls, key)
|
121
328
|
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
131
|
-
|
132
|
-
subquery = (
|
133
|
-
select(cls.tags.property.mapper.class_.agent_id)
|
134
|
-
.where(cls.tags.property.mapper.class_.tag.in_(tags))
|
135
|
-
.group_by(cls.tags.property.mapper.class_.agent_id)
|
136
|
-
.having(func.count() == len(tags))
|
137
|
-
)
|
138
|
-
query = query.filter(cls.id.in_(subquery))
|
139
|
-
else:
|
140
|
-
# Match ANY tag - use join and filter
|
141
|
-
query = (
|
142
|
-
query.join(cls.tags).filter(cls.tags.property.mapper.class_.tag.in_(tags)).distinct(cls.id).order_by(cls.id)
|
143
|
-
) # Deduplicate results
|
144
|
-
|
145
|
-
# select distinct primary key
|
146
|
-
query = query.distinct(cls.id).order_by(cls.id)
|
147
|
-
|
148
|
-
if identifier_keys and hasattr(cls, "identities"):
|
149
|
-
query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.identifier_key.in_(identifier_keys))
|
150
|
-
|
151
|
-
# given the identity_id, we can find within the agents table any agents that have the identity_id in their identity_ids
|
152
|
-
if identity_id and hasattr(cls, "identities"):
|
153
|
-
query = query.join(cls.identities).filter(cls.identities.property.mapper.class_.id == identity_id)
|
154
|
-
|
155
|
-
# Apply filtering logic from kwargs
|
156
|
-
for key, value in kwargs.items():
|
157
|
-
if "." in key:
|
158
|
-
# Handle joined table columns
|
159
|
-
table_name, column_name = key.split(".")
|
160
|
-
joined_table = locals().get(table_name) or globals().get(table_name)
|
161
|
-
column = getattr(joined_table, column_name)
|
162
|
-
else:
|
163
|
-
# Handle columns from main table
|
164
|
-
column = getattr(cls, key)
|
165
|
-
|
166
|
-
if isinstance(value, (list, tuple, set)):
|
167
|
-
query = query.where(column.in_(value))
|
168
|
-
else:
|
169
|
-
query = query.where(column == value)
|
329
|
+
if isinstance(value, (list, tuple, set)):
|
330
|
+
query = query.where(column.in_(value))
|
331
|
+
else:
|
332
|
+
query = query.where(column == value)
|
333
|
+
|
334
|
+
# Date range filtering
|
335
|
+
if start_date:
|
336
|
+
query = query.filter(cls.created_at > start_date)
|
337
|
+
if end_date:
|
338
|
+
query = query.filter(cls.created_at < end_date)
|
170
339
|
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
or_(
|
186
|
-
|
187
|
-
|
188
|
-
# Pure pagination query
|
189
|
-
if before:
|
190
|
-
conditions.append(
|
191
|
-
or_(
|
192
|
-
cls.created_at < before_obj.created_at,
|
193
|
-
and_(cls.created_at == before_obj.created_at, cls.id < before_obj.id),
|
194
|
-
)
|
340
|
+
# Handle pagination based on before/after
|
341
|
+
if before_obj or after_obj:
|
342
|
+
conditions = []
|
343
|
+
|
344
|
+
if before_obj and after_obj:
|
345
|
+
# Window-based query - get records between before and after
|
346
|
+
conditions = [
|
347
|
+
or_(cls.created_at < before_obj.created_at, and_(cls.created_at == before_obj.created_at, cls.id < before_obj.id)),
|
348
|
+
or_(cls.created_at > after_obj.created_at, and_(cls.created_at == after_obj.created_at, cls.id > after_obj.id)),
|
349
|
+
]
|
350
|
+
else:
|
351
|
+
# Pure pagination query
|
352
|
+
if before_obj:
|
353
|
+
conditions.append(
|
354
|
+
or_(
|
355
|
+
cls.created_at < before_obj.created_at,
|
356
|
+
and_(cls.created_at == before_obj.created_at, cls.id < before_obj.id),
|
195
357
|
)
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
)
|
358
|
+
)
|
359
|
+
if after_obj:
|
360
|
+
conditions.append(
|
361
|
+
or_(
|
362
|
+
cls.created_at > after_obj.created_at,
|
363
|
+
and_(cls.created_at == after_obj.created_at, cls.id > after_obj.id),
|
202
364
|
)
|
203
|
-
|
204
|
-
if conditions:
|
205
|
-
query = query.where(and_(*conditions))
|
206
|
-
|
207
|
-
# Text search
|
208
|
-
if query_text:
|
209
|
-
if hasattr(cls, "text"):
|
210
|
-
query = query.filter(func.lower(cls.text).contains(func.lower(query_text)))
|
211
|
-
elif hasattr(cls, "name"):
|
212
|
-
# Special case for Agent model - search across name
|
213
|
-
query = query.filter(func.lower(cls.name).contains(func.lower(query_text)))
|
214
|
-
|
215
|
-
# Embedding search (for Passages)
|
216
|
-
is_ordered = False
|
217
|
-
if query_embedding:
|
218
|
-
if not hasattr(cls, "embedding"):
|
219
|
-
raise ValueError(f"Class {cls.__name__} does not have an embedding column")
|
220
|
-
|
221
|
-
from letta.settings import settings
|
222
|
-
|
223
|
-
if settings.letta_pg_uri_no_default:
|
224
|
-
# PostgreSQL with pgvector
|
225
|
-
query = query.order_by(cls.embedding.cosine_distance(query_embedding).asc())
|
226
|
-
else:
|
227
|
-
# SQLite with custom vector type
|
228
|
-
query_embedding_binary = adapt_array(query_embedding)
|
229
|
-
query = query.order_by(
|
230
|
-
func.cosine_distance(cls.embedding, query_embedding_binary).asc(),
|
231
|
-
cls.created_at.asc() if ascending else cls.created_at.desc(),
|
232
|
-
cls.id.asc(),
|
233
365
|
)
|
234
|
-
is_ordered = True
|
235
366
|
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
if
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
367
|
+
if conditions:
|
368
|
+
query = query.where(and_(*conditions))
|
369
|
+
|
370
|
+
# Text search
|
371
|
+
if query_text:
|
372
|
+
if hasattr(cls, "text"):
|
373
|
+
query = query.filter(func.lower(cls.text).contains(func.lower(query_text)))
|
374
|
+
elif hasattr(cls, "name"):
|
375
|
+
# Special case for Agent model - search across name
|
376
|
+
query = query.filter(func.lower(cls.name).contains(func.lower(query_text)))
|
377
|
+
|
378
|
+
# Embedding search (for Passages)
|
379
|
+
is_ordered = False
|
380
|
+
if query_embedding:
|
381
|
+
if not hasattr(cls, "embedding"):
|
382
|
+
raise ValueError(f"Class {cls.__name__} does not have an embedding column")
|
383
|
+
|
384
|
+
from letta.settings import settings
|
385
|
+
|
386
|
+
if settings.letta_pg_uri_no_default:
|
387
|
+
# PostgreSQL with pgvector
|
388
|
+
query = query.order_by(cls.embedding.cosine_distance(query_embedding).asc())
|
252
389
|
else:
|
253
|
-
|
390
|
+
# SQLite with custom vector type
|
391
|
+
query_embedding_binary = adapt_array(query_embedding)
|
392
|
+
query = query.order_by(
|
393
|
+
func.cosine_distance(cls.embedding, query_embedding_binary).asc(),
|
394
|
+
cls.created_at.asc() if ascending else cls.created_at.desc(),
|
395
|
+
cls.id.asc(),
|
396
|
+
)
|
397
|
+
is_ordered = True
|
398
|
+
|
399
|
+
# Handle soft deletes
|
400
|
+
if hasattr(cls, "is_deleted"):
|
401
|
+
query = query.where(cls.is_deleted == False)
|
254
402
|
|
255
|
-
|
403
|
+
# Apply ordering
|
404
|
+
if not is_ordered:
|
405
|
+
if ascending:
|
406
|
+
query = query.order_by(cls.created_at.asc(), cls.id.asc())
|
407
|
+
else:
|
408
|
+
query = query.order_by(cls.created_at.desc(), cls.id.desc())
|
256
409
|
|
257
|
-
|
258
|
-
|
259
|
-
|
260
|
-
|
261
|
-
|
262
|
-
|
410
|
+
# Apply limit, adjusting for both bounds if necessary
|
411
|
+
if before_obj and after_obj:
|
412
|
+
# When both bounds are provided, we need to fetch enough records to satisfy
|
413
|
+
# the limit while respecting both bounds. We'll fetch more and then trim.
|
414
|
+
query = query.limit(limit * 2)
|
415
|
+
else:
|
416
|
+
query = query.limit(limit)
|
417
|
+
return query
|
263
418
|
|
264
|
-
|
419
|
+
@classmethod
|
420
|
+
def _list_postprocess(
|
421
|
+
cls,
|
422
|
+
before: str | None,
|
423
|
+
after: str | None,
|
424
|
+
limit: int | None,
|
425
|
+
results: list,
|
426
|
+
):
|
427
|
+
# If we have both bounds, take the middle portion
|
428
|
+
if before and after and len(results) > limit:
|
429
|
+
middle = len(results) // 2
|
430
|
+
start = max(0, middle - limit // 2)
|
431
|
+
end = min(len(results), start + limit)
|
432
|
+
results = results[start:end]
|
433
|
+
return results
|
265
434
|
|
266
435
|
@classmethod
|
267
436
|
@handle_db_timeout
|
@@ -305,7 +474,7 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
305
474
|
@handle_db_timeout
|
306
475
|
async def read_async(
|
307
476
|
cls,
|
308
|
-
db_session: "
|
477
|
+
db_session: "AsyncSession",
|
309
478
|
identifier: Optional[str] = None,
|
310
479
|
actor: Optional["User"] = None,
|
311
480
|
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
@@ -462,6 +631,24 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
462
631
|
except (DBAPIError, IntegrityError) as e:
|
463
632
|
self._handle_dbapi_error(e)
|
464
633
|
|
634
|
+
@handle_db_timeout
|
635
|
+
async def create_async(self, db_session: "AsyncSession", actor: Optional["User"] = None, no_commit: bool = False) -> "SqlalchemyBase":
|
636
|
+
"""Async version of create function"""
|
637
|
+
logger.debug(f"Creating {self.__class__.__name__} with ID: {self.id} with actor={actor}")
|
638
|
+
|
639
|
+
if actor:
|
640
|
+
self._set_created_and_updated_by_fields(actor.id)
|
641
|
+
try:
|
642
|
+
db_session.add(self)
|
643
|
+
if no_commit:
|
644
|
+
await db_session.flush() # no commit, just flush to get PK
|
645
|
+
else:
|
646
|
+
await db_session.commit()
|
647
|
+
await db_session.refresh(self)
|
648
|
+
return self
|
649
|
+
except (DBAPIError, IntegrityError) as e:
|
650
|
+
self._handle_dbapi_error(e)
|
651
|
+
|
465
652
|
@classmethod
|
466
653
|
@handle_db_timeout
|
467
654
|
def batch_create(cls, items: List["SqlalchemyBase"], db_session: "Session", actor: Optional["User"] = None) -> List["SqlalchemyBase"]:
|
@@ -503,6 +690,51 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
503
690
|
except (DBAPIError, IntegrityError) as e:
|
504
691
|
cls._handle_dbapi_error(e)
|
505
692
|
|
693
|
+
@classmethod
|
694
|
+
@handle_db_timeout
|
695
|
+
async def batch_create_async(
|
696
|
+
cls, items: List["SqlalchemyBase"], db_session: "AsyncSession", actor: Optional["User"] = None
|
697
|
+
) -> List["SqlalchemyBase"]:
|
698
|
+
"""
|
699
|
+
Async version of batch_create method.
|
700
|
+
Create multiple records in a single transaction for better performance.
|
701
|
+
Args:
|
702
|
+
items: List of model instances to create
|
703
|
+
db_session: AsyncSession session
|
704
|
+
actor: Optional user performing the action
|
705
|
+
Returns:
|
706
|
+
List of created model instances
|
707
|
+
"""
|
708
|
+
logger.debug(f"Async batch creating {len(items)} {cls.__name__} items with actor={actor}")
|
709
|
+
if not items:
|
710
|
+
return []
|
711
|
+
|
712
|
+
# Set created/updated by fields if actor is provided
|
713
|
+
if actor:
|
714
|
+
for item in items:
|
715
|
+
item._set_created_and_updated_by_fields(actor.id)
|
716
|
+
|
717
|
+
try:
|
718
|
+
async with db_session as session:
|
719
|
+
session.add_all(items)
|
720
|
+
await session.flush() # Flush to generate IDs but don't commit yet
|
721
|
+
|
722
|
+
# Collect IDs to fetch the complete objects after commit
|
723
|
+
item_ids = [item.id for item in items]
|
724
|
+
|
725
|
+
await session.commit()
|
726
|
+
|
727
|
+
# Re-query the objects to get them with relationships loaded
|
728
|
+
query = select(cls).where(cls.id.in_(item_ids))
|
729
|
+
if hasattr(cls, "created_at"):
|
730
|
+
query = query.order_by(cls.created_at)
|
731
|
+
|
732
|
+
result = await session.execute(query)
|
733
|
+
return list(result.scalars())
|
734
|
+
|
735
|
+
except (DBAPIError, IntegrityError) as e:
|
736
|
+
cls._handle_dbapi_error(e)
|
737
|
+
|
506
738
|
@handle_db_timeout
|
507
739
|
def delete(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase":
|
508
740
|
logger.debug(f"Soft deleting {self.__class__.__name__} with ID: {self.id} with actor={actor}")
|
@@ -513,6 +745,17 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
513
745
|
self.is_deleted = True
|
514
746
|
return self.update(db_session)
|
515
747
|
|
748
|
+
@handle_db_timeout
|
749
|
+
async def delete_async(self, db_session: "AsyncSession", actor: Optional["User"] = None) -> "SqlalchemyBase":
|
750
|
+
"""Soft delete a record asynchronously (mark as deleted)."""
|
751
|
+
logger.debug(f"Soft deleting {self.__class__.__name__} with ID: {self.id} with actor={actor} (async)")
|
752
|
+
|
753
|
+
if actor:
|
754
|
+
self._set_created_and_updated_by_fields(actor.id)
|
755
|
+
|
756
|
+
self.is_deleted = True
|
757
|
+
return await self.update_async(db_session)
|
758
|
+
|
516
759
|
@handle_db_timeout
|
517
760
|
def hard_delete(self, db_session: "Session", actor: Optional["User"] = None) -> None:
|
518
761
|
"""Permanently removes the record from the database."""
|
@@ -529,6 +772,20 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
529
772
|
else:
|
530
773
|
logger.debug(f"{self.__class__.__name__} with ID {self.id} successfully hard deleted")
|
531
774
|
|
775
|
+
@handle_db_timeout
|
776
|
+
async def hard_delete_async(self, db_session: "AsyncSession", actor: Optional["User"] = None) -> None:
|
777
|
+
"""Permanently removes the record from the database asynchronously."""
|
778
|
+
logger.debug(f"Hard deleting {self.__class__.__name__} with ID: {self.id} with actor={actor} (async)")
|
779
|
+
|
780
|
+
async with db_session as session:
|
781
|
+
try:
|
782
|
+
await session.delete(self)
|
783
|
+
await session.commit()
|
784
|
+
except Exception as e:
|
785
|
+
await session.rollback()
|
786
|
+
logger.exception(f"Failed to hard delete {self.__class__.__name__} with ID {self.id}")
|
787
|
+
raise ValueError(f"Failed to hard delete {self.__class__.__name__} with ID {self.id}: {e}")
|
788
|
+
|
532
789
|
@handle_db_timeout
|
533
790
|
def update(self, db_session: Session, actor: Optional["User"] = None, no_commit: bool = False) -> "SqlalchemyBase":
|
534
791
|
logger.debug(...)
|
@@ -561,6 +818,39 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
561
818
|
await db_session.refresh(self)
|
562
819
|
return self
|
563
820
|
|
821
|
+
@classmethod
|
822
|
+
def _size_preprocess(
|
823
|
+
cls,
|
824
|
+
*,
|
825
|
+
db_session: "Session",
|
826
|
+
actor: Optional["User"] = None,
|
827
|
+
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
828
|
+
access_type: AccessType = AccessType.ORGANIZATION,
|
829
|
+
**kwargs,
|
830
|
+
):
|
831
|
+
logger.debug(f"Calculating size for {cls.__name__} with filters {kwargs}")
|
832
|
+
query = select(func.count()).select_from(cls)
|
833
|
+
|
834
|
+
if actor:
|
835
|
+
query = cls.apply_access_predicate(query, actor, access, access_type)
|
836
|
+
|
837
|
+
# Apply filtering logic based on kwargs
|
838
|
+
for key, value in kwargs.items():
|
839
|
+
if value:
|
840
|
+
column = getattr(cls, key, None)
|
841
|
+
if not column:
|
842
|
+
raise AttributeError(f"{cls.__name__} has no attribute '{key}'")
|
843
|
+
if isinstance(value, (list, tuple, set)): # Check for iterables
|
844
|
+
query = query.where(column.in_(value))
|
845
|
+
else: # Single value for equality filtering
|
846
|
+
query = query.where(column == value)
|
847
|
+
|
848
|
+
# Handle soft deletes if the class has the 'is_deleted' attribute
|
849
|
+
if hasattr(cls, "is_deleted"):
|
850
|
+
query = query.where(cls.is_deleted == False)
|
851
|
+
|
852
|
+
return query
|
853
|
+
|
564
854
|
@classmethod
|
565
855
|
@handle_db_timeout
|
566
856
|
def size(
|
@@ -585,28 +875,8 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
585
875
|
Raises:
|
586
876
|
DBAPIError: If a database error occurs
|
587
877
|
"""
|
588
|
-
logger.debug(f"Calculating size for {cls.__name__} with filters {kwargs}")
|
589
|
-
|
590
878
|
with db_session as session:
|
591
|
-
query =
|
592
|
-
|
593
|
-
if actor:
|
594
|
-
query = cls.apply_access_predicate(query, actor, access, access_type)
|
595
|
-
|
596
|
-
# Apply filtering logic based on kwargs
|
597
|
-
for key, value in kwargs.items():
|
598
|
-
if value:
|
599
|
-
column = getattr(cls, key, None)
|
600
|
-
if not column:
|
601
|
-
raise AttributeError(f"{cls.__name__} has no attribute '{key}'")
|
602
|
-
if isinstance(value, (list, tuple, set)): # Check for iterables
|
603
|
-
query = query.where(column.in_(value))
|
604
|
-
else: # Single value for equality filtering
|
605
|
-
query = query.where(column == value)
|
606
|
-
|
607
|
-
# Handle soft deletes if the class has the 'is_deleted' attribute
|
608
|
-
if hasattr(cls, "is_deleted"):
|
609
|
-
query = query.where(cls.is_deleted == False)
|
879
|
+
query = cls._size_preprocess(db_session=session, actor=actor, access=access, access_type=access_type, **kwargs)
|
610
880
|
|
611
881
|
try:
|
612
882
|
count = session.execute(query).scalar()
|
@@ -615,6 +885,37 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
|
615
885
|
logger.exception(f"Failed to calculate size for {cls.__name__}")
|
616
886
|
raise e
|
617
887
|
|
888
|
+
@classmethod
|
889
|
+
@handle_db_timeout
|
890
|
+
async def size_async(
|
891
|
+
cls,
|
892
|
+
*,
|
893
|
+
db_session: "AsyncSession",
|
894
|
+
actor: Optional["User"] = None,
|
895
|
+
access: Optional[List[Literal["read", "write", "admin"]]] = ["read"],
|
896
|
+
access_type: AccessType = AccessType.ORGANIZATION,
|
897
|
+
**kwargs,
|
898
|
+
) -> int:
|
899
|
+
"""
|
900
|
+
Get the count of rows that match the provided filters.
|
901
|
+
Args:
|
902
|
+
db_session: SQLAlchemy session
|
903
|
+
**kwargs: Filters to apply to the query (e.g., column_name=value)
|
904
|
+
Returns:
|
905
|
+
int: The count of rows that match the filters
|
906
|
+
Raises:
|
907
|
+
DBAPIError: If a database error occurs
|
908
|
+
"""
|
909
|
+
async with db_session as session:
|
910
|
+
query = cls._size_preprocess(db_session=session, actor=actor, access=access, access_type=access_type, **kwargs)
|
911
|
+
|
912
|
+
try:
|
913
|
+
count = await session.execute(query).scalar()
|
914
|
+
return count if count else 0
|
915
|
+
except DBAPIError as e:
|
916
|
+
logger.exception(f"Failed to calculate size for {cls.__name__}")
|
917
|
+
raise e
|
918
|
+
|
618
919
|
@classmethod
|
619
920
|
def apply_access_predicate(
|
620
921
|
cls,
|