letta-nightly 0.6.3.dev20241211104238__py3-none-any.whl → 0.6.3.dev20241212104231__py3-none-any.whl

This diff shows the changes between publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.


This version of letta-nightly might be problematic.

letta/agent_store/db.py DELETED
@@ -1,467 +0,0 @@
- import base64
- import json
- import os
- from datetime import datetime
- from typing import Dict, List, Optional
-
- import numpy as np
- from sqlalchemy import (
-     BINARY,
-     Column,
-     DateTime,
-     Index,
-     String,
-     TypeDecorator,
-     and_,
-     asc,
-     desc,
-     or_,
-     select,
-     text,
- )
- from sqlalchemy.orm import mapped_column
- from sqlalchemy.orm.session import close_all_sessions
- from sqlalchemy.sql import func
- from sqlalchemy_json import MutableJson
- from tqdm import tqdm
-
- from letta.agent_store.storage import StorageConnector, TableType
- from letta.config import LettaConfig
- from letta.constants import MAX_EMBEDDING_DIM
- from letta.metadata import EmbeddingConfigColumn
- from letta.orm.base import Base
- from letta.orm.file import FileMetadata as FileMetadataModel
-
- # from letta.schemas.message import Message, Passage, Record, RecordType, ToolCall
- from letta.orm.passage import Passage as PassageModel
- from letta.settings import settings
-
- config = LettaConfig()
-
-
- class CommonVector(TypeDecorator):
-     """Common type for representing vectors in SQLite"""
-
-     impl = BINARY
-     cache_ok = True
-
-     def load_dialect_impl(self, dialect):
-         return dialect.type_descriptor(BINARY())
-
-     def process_bind_param(self, value, dialect):
-         if value is None:
-             return value
-         # Ensure value is a numpy array
-         if isinstance(value, list):
-             value = np.array(value, dtype=np.float32)
-         # Serialize numpy array to bytes, then encode to base64 for universal compatibility
-         return base64.b64encode(value.tobytes())
-
-     def process_result_value(self, value, dialect):
-         if not value:
-             return value
-         # Check database type and deserialize accordingly
-         if dialect.name == "sqlite":
-             # Decode from base64 and convert back to numpy array
-             value = base64.b64decode(value)
-         # For PostgreSQL, value is already in bytes
-         return np.frombuffer(value, dtype=np.float32)
-
- class SQLStorageConnector(StorageConnector):
-     def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None):
-         super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id)
-         self.config = config
-
-     def get_filters(self, filters: Optional[Dict] = {}):
-         if filters is not None:
-             filter_conditions = {**self.filters, **filters}
-         else:
-             filter_conditions = self.filters
-         all_filters = [getattr(self.db_model, key) == value for key, value in filter_conditions.items()]
-         return all_filters
-
-     def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000, offset=0):
-         filters = self.get_filters(filters)
-         while True:
-             # Retrieve a chunk of records with the given page_size
-             with self.session_maker() as session:
-                 db_record_chunk = session.query(self.db_model).filter(*filters).offset(offset).limit(page_size).all()
-
-             # If the chunk is empty, we've retrieved all records
-             if not db_record_chunk:
-                 break
-
-             # Yield a list of Record objects converted from the chunk
-             yield [record.to_record() for record in db_record_chunk]
-
-             # Increment the offset to get the next chunk in the next iteration
-             offset += page_size
-
-     def get_all_cursor(
-         self,
-         filters: Optional[Dict] = {},
-         after: str = None,
-         before: str = None,
-         limit: Optional[int] = 1000,
-         order_by: str = "created_at",
-         reverse: bool = False,
-     ):
-         """Get all that returns a cursor (record.id) and records"""
-         filters = self.get_filters(filters)
-
-         # generate query
-         with self.session_maker() as session:
-             query = session.query(self.db_model).filter(*filters)
-             # query = query.order_by(asc(self.db_model.id))
-
-             # records are sorted by the order_by field first, and then by the ID if two fields are the same
-             if reverse:
-                 query = query.order_by(desc(getattr(self.db_model, order_by)), asc(self.db_model.id))
-             else:
-                 query = query.order_by(asc(getattr(self.db_model, order_by)), asc(self.db_model.id))
-
-             # cursor logic: filter records based on before/after ID
-             if after:
-                 after_value = getattr(self.get(id=after), order_by)
-                 sort_exp = getattr(self.db_model, order_by) > after_value
-                 query = query.filter(
-                     or_(sort_exp, and_(getattr(self.db_model, order_by) == after_value, self.db_model.id > after))  # tiebreaker case
-                 )
-             if before:
-                 before_value = getattr(self.get(id=before), order_by)
-                 sort_exp = getattr(self.db_model, order_by) < before_value
-                 query = query.filter(or_(sort_exp, and_(getattr(self.db_model, order_by) == before_value, self.db_model.id < before)))
-
-             # get records
-             db_record_chunk = query.limit(limit).all()
-             if not db_record_chunk:
-                 return (None, [])
-             records = [record.to_record() for record in db_record_chunk]
-             next_cursor = db_record_chunk[-1].id
-             assert isinstance(next_cursor, str)
-
-             # return (cursor, list[records])
-             return (next_cursor, records)
-
-     def get_all(self, filters: Optional[Dict] = {}, limit=None):
-         filters = self.get_filters(filters)
-         with self.session_maker() as session:
-             if limit:
-                 db_records = session.query(self.db_model).filter(*filters).limit(limit).all()
-             else:
-                 db_records = session.query(self.db_model).filter(*filters).all()
-         return [record.to_record() for record in db_records]
-
-     def get(self, id: str):
-         with self.session_maker() as session:
-             db_record = session.get(self.db_model, id)
-             if db_record is None:
-                 return None
-             return db_record.to_record()
-
-     def size(self, filters: Optional[Dict] = {}) -> int:
-         # return size of table
-         filters = self.get_filters(filters)
-         with self.session_maker() as session:
-             return session.query(self.db_model).filter(*filters).count()
-
-     def insert(self, record):
-         raise NotImplementedError
-
-     def insert_many(self, records, show_progress=False):
-         raise NotImplementedError
-
-     def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}):
-         raise NotImplementedError("Vector query not implemented for SQLStorageConnector")
-
-     def save(self):
-         return
-
-     def list_data_sources(self):
-         assert self.table_type == TableType.ARCHIVAL_MEMORY, f"list_data_sources only implemented for ARCHIVAL_MEMORY"
-         with self.session_maker() as session:
-             unique_data_sources = session.query(self.db_model.data_source).filter(*self.filters).distinct().all()
-         return unique_data_sources
-
-     def query_date(self, start_date, end_date, limit=None, offset=0):
-         filters = self.get_filters({})
-         with self.session_maker() as session:
-             query = (
-                 session.query(self.db_model)
-                 .filter(*filters)
-                 .filter(self.db_model.created_at >= start_date)
-                 .filter(self.db_model.created_at <= end_date)
-                 .filter(self.db_model.role != "system")
-                 .filter(self.db_model.role != "tool")
-                 .offset(offset)
-             )
-             if limit:
-                 query = query.limit(limit)
-             results = query.all()
-         return [result.to_record() for result in results]
-
-     def query_text(self, query, limit=None, offset=0):
-         # todo: make fuzz https://stackoverflow.com/questions/42388956/create-a-full-text-search-index-with-sqlalchemy-on-postgresql/42390204#42390204
-         filters = self.get_filters({})
-         with self.session_maker() as session:
-             query = (
-                 session.query(self.db_model)
-                 .filter(*filters)
-                 .filter(func.lower(self.db_model.text).contains(func.lower(query)))
-                 .filter(self.db_model.role != "system")
-                 .filter(self.db_model.role != "tool")
-                 .offset(offset)
-             )
-             if limit:
-                 query = query.limit(limit)
-             results = query.all()
-         # return [self.type(**vars(result)) for result in results]
-         return [result.to_record() for result in results]
-
-     # Should be used only in tests!
-     def delete_table(self):
-         close_all_sessions()
-         with self.session_maker() as session:
-             self.db_model.__table__.drop(session.bind)
-             session.commit()
-
-     def delete(self, filters: Optional[Dict] = {}):
-         filters = self.get_filters(filters)
-         with self.session_maker() as session:
-             session.query(self.db_model).filter(*filters).delete()
-             session.commit()
-
-
- class PostgresStorageConnector(SQLStorageConnector):
-     """Storage via Postgres"""
-
-     # TODO: this should probably eventually be moved into a parent DB class
-
-     def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None):
-         from pgvector.sqlalchemy import Vector
-
-         super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id)
-
-         # construct URI from enviornment variables
-         if settings.pg_uri:
-             self.uri = settings.pg_uri
-
-         # use config URI
-         # TODO: remove this eventually (config should NOT contain URI)
-         if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
-             self.uri = self.config.archival_storage_uri
-             self.db_model = PassageModel
-             if self.config.archival_storage_uri is None:
-                 raise ValueError(f"Must specify archival_storage_uri in config {self.config.config_path}")
-         elif table_type == TableType.FILES:
-             self.uri = self.config.metadata_storage_uri
-             self.db_model = FileMetadataModel
-             if self.config.metadata_storage_uri is None:
-                 raise ValueError(f"Must specify metadata_storage_uri in config {self.config.config_path}")
-         else:
-             raise ValueError(f"Table type {table_type} not implemented")
-
-         if settings.pg_uri:
-             for c in self.db_model.__table__.columns:
-                 if c.name == "embedding":
-                     assert isinstance(c.type, Vector), f"Embedding column must be of type Vector, got {c.type}"
-
-         from letta.server.server import db_context
-
-         self.session_maker = db_context
-
-         # TODO: move to DB init
-         if settings.pg_uri:
-             with self.session_maker() as session:
-                 session.execute(text("CREATE EXTENSION IF NOT EXISTS vector"))  # Enables the vector extension
-
-     def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}):
-         filters = self.get_filters(filters)
-         with self.session_maker() as session:
-             results = session.scalars(
-                 select(self.db_model).filter(*filters).order_by(self.db_model.embedding.l2_distance(query_vec)).limit(top_k)
-             ).all()
-
-         # Convert the results into Passage objects
-         records = [result.to_record() for result in results]
-         return records
-
-     def insert_many(self, records, exists_ok=True, show_progress=False):
-         # TODO: this is terrible, should eventually be done the same way for all types (migrate to SQLModel)
-         if len(records) == 0:
-             return
-
-         added_ids = []  # avoid adding duplicates
-         # NOTE: this has not great performance due to the excessive commits
-         with self.session_maker() as session:
-             iterable = tqdm(records) if show_progress else records
-             for record in iterable:
-                 # db_record = self.db_model(**vars(record))
-
-                 if record.id in added_ids:
-                     continue
-
-                 existing_record = session.query(self.db_model).filter_by(id=record.id).first()
-                 if existing_record:
-                     if exists_ok:
-                         fields = record.model_dump()
-                         fields.pop("id")
-                         session.query(self.db_model).filter(self.db_model.id == record.id).update(fields)
-                         print(f"Updated record with id {record.id}")
-                         session.commit()
-                     else:
-                         raise ValueError(f"Record with id {record.id} already exists.")
-
-                 else:
-                     db_record = self.db_model(**record.dict())
-                     session.add(db_record)
-                     # print(f"Added record with id {record.id}")
-                     session.commit()
-
-                 added_ids.append(record.id)
-
-     def insert(self, record, exists_ok=True):
-         self.insert_many([record], exists_ok=exists_ok)
-
-     def update(self, record):
-         """
-         Updates a record in the database based on the provided Record object.
-         """
-         with self.session_maker() as session:
-             # Find the record by its ID
-             db_record = session.query(self.db_model).filter_by(id=record.id).first()
-             if not db_record:
-                 raise ValueError(f"Record with id {record.id} does not exist.")
-
-             # Update the record with new values from the provided Record object
-             for attr, value in vars(record).items():
-                 setattr(db_record, attr, value)
-
-             # Commit the changes to the database
-             session.commit()
-
-     def str_to_datetime(self, str_date: str) -> datetime:
-         val = str_date.split("-")
-         _datetime = datetime(int(val[0]), int(val[1]), int(val[2]))
-         return _datetime
-
-     def query_date(self, start_date, end_date, limit=None, offset=0):
-         filters = self.get_filters({})
-         _start_date = self.str_to_datetime(start_date) if isinstance(start_date, str) else start_date
-         _end_date = self.str_to_datetime(end_date) if isinstance(end_date, str) else end_date
-         with self.session_maker() as session:
-             query = (
-                 session.query(self.db_model)
-                 .filter(*filters)
-                 .filter(self.db_model.created_at >= _start_date)
-                 .filter(self.db_model.created_at <= _end_date)
-                 .filter(self.db_model.role != "system")
-                 .filter(self.db_model.role != "tool")
-                 .offset(offset)
-             )
-             if limit:
-                 query = query.limit(limit)
-             results = query.all()
-         return [result.to_record() for result in results]
-
-
- class SQLLiteStorageConnector(SQLStorageConnector):
-     def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None):
-         super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id)
-
-         # get storage URI
-         if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
-             self.db_model = PassageModel
-             if settings.letta_pg_uri_no_default:
-                 self.uri = settings.letta_pg_uri_no_default
-             else:
-                 # For SQLite, use the archival storage path
-                 self.path = config.archival_storage_path
-                 self.uri = f"sqlite:///{os.path.join(config.archival_storage_path, 'letta.db')}"
-         elif table_type == TableType.FILES:
-             self.path = self.config.metadata_storage_path
-             if self.path is None:
-                 raise ValueError(f"Must specify metadata_storage_path in config.")
-             self.db_model = FileMetadataModel
-
-         else:
-             raise ValueError(f"Table type {table_type} not implemented")
-
-         self.path = os.path.join(self.path, f"sqlite.db")
-
-         from letta.server.server import db_context
-
-         self.session_maker = db_context
-
-         # Need this in order to allow UUIDs to be stored successfully in the sqlite database
-         # import sqlite3
-         # import uuid
-         #
-         # sqlite3.register_adapter(uuid.UUID, lambda u: u.bytes_le)
-         # sqlite3.register_converter("UUID", lambda b: uuid.UUID(bytes_le=b))
-
-     def insert_many(self, records, exists_ok=True, show_progress=False):
-         # TODO: this is terrible, should eventually be done the same way for all types (migrate to SQLModel)
-         if len(records) == 0:
-             return
-
-         added_ids = []  # avoid adding duplicates
-         # NOTE: this has not great performance due to the excessive commits
-         with self.session_maker() as session:
-             iterable = tqdm(records) if show_progress else records
-             for record in iterable:
-                 # db_record = self.db_model(**vars(record))
-
-                 if record.id in added_ids:
-                     continue
-
-                 existing_record = session.query(self.db_model).filter_by(id=record.id).first()
-                 if existing_record:
-                     if exists_ok:
-                         fields = record.model_dump()
-                         fields.pop("id")
-                         session.query(self.db_model).filter(self.db_model.id == record.id).update(fields)
-                         session.commit()
-                     else:
-                         raise ValueError(f"Record with id {record.id} already exists.")
-
-                 else:
-                     db_record = self.db_model(**record.dict())
-                     session.add(db_record)
-                     session.commit()
-
-                 added_ids.append(record.id)
-
-     def insert(self, record, exists_ok=True):
-         self.insert_many([record], exists_ok=exists_ok)
-
-     def update(self, record):
-         """
-         Updates an existing record in the database with values from the provided record object.
-         """
-         if not record.id:
-             raise ValueError("Record must have an id.")
-
-         with self.session_maker() as session:
-             # Fetch the existing record from the database
-             db_record = session.query(self.db_model).filter_by(id=record.id).first()
-             if not db_record:
-                 raise ValueError(f"Record with id {record.id} does not exist.")
-
-             # Update the database record with values from the provided record object
-             for column in self.db_model.__table__.columns:
-                 column_name = column.name
-                 if hasattr(record, column_name):
-                     new_value = getattr(record, column_name)
-                     setattr(db_record, column_name, new_value)
-
-             # Commit the changes to the database
-             session.commit()
-
-
- def attach_base():
-     # This should be invoked in server.py to make sure Base gets initialized properly
-     # DO NOT REMOVE
-     from letta.utils import printd
-
-     printd("Initializing database...")
@@ -1,198 +0,0 @@
- import uuid
- from copy import deepcopy
- from typing import Dict, Iterator, List, Optional, cast
-
- from pymilvus import DataType, MilvusClient
- from pymilvus.client.constants import ConsistencyLevel
-
- from letta.agent_store.storage import StorageConnector, TableType
- from letta.config import LettaConfig
- from letta.constants import MAX_EMBEDDING_DIM
- from letta.data_types import Passage, Record, RecordType
- from letta.utils import datetime_to_timestamp, printd, timestamp_to_datetime
-
-
- class MilvusStorageConnector(StorageConnector):
-     """Storage via Milvus"""
-
-     def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None):
-         super().__init__(table_type=table_type, config=config, user_id=user_id, agent_id=agent_id)
-
-         assert table_type in [TableType.ARCHIVAL_MEMORY, TableType.PASSAGES], "Milvus only supports archival memory"
-         if config.archival_storage_uri:
-             self.client = MilvusClient(uri=config.archival_storage_uri)
-             self._create_collection()
-         else:
-             raise ValueError("Please set `archival_storage_uri` in the config file when using Milvus.")
-
-         # need to be converted to strings
-         self.uuid_fields = ["id", "user_id", "agent_id", "source_id", "file_id"]
-
-     def _create_collection(self):
-         schema = MilvusClient.create_schema(
-             auto_id=False,
-             enable_dynamic_field=True,
-         )
-         schema.add_field(field_name="id", datatype=DataType.VARCHAR, is_primary=True, max_length=65_535)
-         schema.add_field(field_name="text", datatype=DataType.VARCHAR, is_primary=False, max_length=65_535)
-         schema.add_field(field_name="embedding", datatype=DataType.FLOAT_VECTOR, dim=MAX_EMBEDDING_DIM)
-         index_params = self.client.prepare_index_params()
-         index_params.add_index(field_name="id")
-         index_params.add_index(field_name="embedding", index_type="AUTOINDEX", metric_type="IP")
-         self.client.create_collection(
-             collection_name=self.table_name, schema=schema, index_params=index_params, consistency_level=ConsistencyLevel.Strong
-         )
-
-     def get_milvus_filter(self, filters: Optional[Dict] = {}) -> str:
-         filter_conditions = {**self.filters, **filters} if filters is not None else self.filters
-         if not filter_conditions:
-             return ""
-         conditions = []
-         for key, value in filter_conditions.items():
-             if key in self.uuid_fields or isinstance(key, str):
-                 condition = f'({key} == "{value}")'
-             else:
-                 condition = f"({key} == {value})"
-             conditions.append(condition)
-         filter_expr = " and ".join(conditions)
-         if len(conditions) == 1:
-             filter_expr = filter_expr[1:-1]
-         return filter_expr
-
-     def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: int = 1000) -> Iterator[List[RecordType]]:
-         if not self.client.has_collection(collection_name=self.table_name):
-             yield []
-         filter_expr = self.get_milvus_filter(filters)
-         offset = 0
-         while True:
-             # Retrieve a chunk of records with the given page_size
-             query_res = self.client.query(
-                 collection_name=self.table_name,
-                 filter=filter_expr,
-                 offset=offset,
-                 limit=page_size,
-             )
-             if not query_res:
-                 break
-             # Yield a list of Record objects converted from the chunk
-             yield self._list_to_records(query_res)
-
-             # Increment the offset to get the next chunk in the next iteration
-             offset += page_size
-
-     def get_all(self, filters: Optional[Dict] = {}, limit=None) -> List[RecordType]:
-         if not self.client.has_collection(collection_name=self.table_name):
-             return []
-         filter_expr = self.get_milvus_filter(filters)
-         query_res = self.client.query(
-             collection_name=self.table_name,
-             filter=filter_expr,
-             limit=limit,
-         )
-         return self._list_to_records(query_res)
-
-     def get(self, id: str) -> Optional[RecordType]:
-         res = self.client.get(collection_name=self.table_name, ids=str(id))
-         return self._list_to_records(res)[0] if res else None
-
-     def size(self, filters: Optional[Dict] = {}) -> int:
-         if not self.client.has_collection(collection_name=self.table_name):
-             return 0
-         filter_expr = self.get_milvus_filter(filters)
-         count_expr = "count(*)"
-         query_res = self.client.query(
-             collection_name=self.table_name,
-             filter=filter_expr,
-             output_fields=[count_expr],
-         )
-         doc_num = query_res[0][count_expr]
-         return doc_num
-
-     def insert(self, record: RecordType):
-         self.insert_many([record])
-
-     def insert_many(self, records: List[RecordType], show_progress=False):
-         if not records:
-             return
-
-         # Milvus lite currently does not support upsert, so we delete and insert instead
-         # self.client.upsert(collection_name=self.table_name, data=self._records_to_list(records))
-         ids = [str(record.id) for record in records]
-         self.client.delete(collection_name=self.table_name, ids=ids)
-         data = self._records_to_list(records)
-         self.client.insert(collection_name=self.table_name, data=data)
-
-     def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[RecordType]:
-         if not self.client.has_collection(self.table_name):
-             return []
-         search_res = self.client.search(
-             collection_name=self.table_name, data=[query_vec], filter=self.get_milvus_filter(filters), limit=top_k, output_fields=["*"]
-         )[0]
-         entity_res = [res["entity"] for res in search_res]
-         return self._list_to_records(entity_res)
-
-     def delete_table(self):
-         self.client.drop_collection(collection_name=self.table_name)
-
-     def delete(self, filters: Optional[Dict] = {}):
-         if not self.client.has_collection(collection_name=self.table_name):
-             return
-         filter_expr = self.get_milvus_filter(filters)
-         self.client.delete(collection_name=self.table_name, filter=filter_expr)
-
-     def save(self):
-         # save to persistence file (nothing needs to be done)
-         printd("Saving milvus")
-
-     def _records_to_list(self, records: List[Record]) -> List[Dict]:
-         if records == []:
-             return []
-         assert all(isinstance(r, Passage) for r in records)
-         record_list = []
-         records = list(set(records))
-         for record in records:
-             record_vars = deepcopy(vars(record))
-             _id = record_vars.pop("id")
-             text = record_vars.pop("text", "")
-             embedding = record_vars.pop("embedding")
-             record_metadata = record_vars.pop("metadata_", None) or {}
-             if "created_at" in record_vars:
-                 record_vars["created_at"] = datetime_to_timestamp(record_vars["created_at"])
-             record_dict = {key: value for key, value in record_vars.items() if value is not None}
-             record_dict = {
-                 **record_dict,
-                 **record_metadata,
-                 "id": str(_id),
-                 "text": text,
-                 "embedding": embedding,
-             }
-             for key, value in record_dict.items():
-                 if key in self.uuid_fields:
-                     record_dict[key] = str(value)
-             record_list.append(record_dict)
-         return record_list
-
-     def _list_to_records(self, query_res: List[Dict]) -> List[RecordType]:
-         records = []
-         for res_dict in query_res:
-             _id = res_dict.pop("id")
-             embedding = res_dict.pop("embedding")
-             text = res_dict.pop("text")
-             metadata = deepcopy(res_dict)
-             for key, value in metadata.items():
-                 if key in self.uuid_fields:
-                     metadata[key] = uuid.UUID(value)
-                 elif key == "created_at":
-                     metadata[key] = timestamp_to_datetime(value)
-             records.append(
-                 cast(
-                     RecordType,
-                     self.type(
-                         text=text,
-                         embedding=embedding,
-                         id=uuid.UUID(_id),
-                         **metadata,
-                     ),
-                 )
-             )
-         return records