langchain-sqlserver 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 LangChain, Inc.
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
@@ -0,0 +1,31 @@
+ Metadata-Version: 2.1
+ Name: langchain-sqlserver
+ Version: 0.1.0
+ Summary: An integration package to support SQL Server in LangChain.
+ License: MIT
+ Requires-Python: >=3.9,<4.0
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.9
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Requires-Dist: SQLAlchemy (>=2.0.0,<3)
+ Requires-Dist: azure-identity (>=1.16.0,<2.0.0)
+ Requires-Dist: langchain-core (>=0.3.0,<0.4.0)
+ Requires-Dist: numpy (>=1,<2)
+ Requires-Dist: pyodbc (>=5.0.0,<6.0.0)
+ Project-URL: Release Notes, https://github.com/langchain-ai/langchain-azure/releases
+ Project-URL: Source Code, https://github.com/langchain-ai/langchain-azure/tree/main/libs/sqlserver
+ Description-Content-Type: text/markdown
+
+ # langchain-sqlserver
+
+ This package contains the LangChain integration for SQL Server. You can use this package to manage vectorstores in SQL Server.
+
+ ## Installation
+
+ ```bash
+ pip install -U langchain-sqlserver
+ ```
+
@@ -0,0 +1,9 @@
+ # langchain-sqlserver
+
+ This package contains the LangChain integration for SQL Server. You can use this package to manage vectorstores in SQL Server.
+
+ ## Installation
+
+ ```bash
+ pip install -U langchain-sqlserver
+ ```
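A minimal usage sketch (the connection string, table name, and `OpenAIEmbeddings` model below are placeholders; any `langchain_core.embeddings.Embeddings` implementation works):

```python
from langchain_sqlserver import SQLServer_VectorStore
from langchain_openai import OpenAIEmbeddings  # placeholder embedding model

store = SQLServer_VectorStore(
    # Placeholder connection string. Without a username/password or
    # Trusted_Connection=yes, Entra ID authentication is attempted.
    connection_string="mssql+pyodbc://user:pass@server/db?driver=ODBC+Driver+18+for+SQL+Server",
    embedding_function=OpenAIEmbeddings(),
    embedding_length=1536,  # dimension of the vectors the model produces
    table_name="example_vectors",
)

store.add_texts(["LangChain supports SQL Server as a vector store."])
docs = store.similarity_search("Which vector stores does LangChain support?", k=1)
```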
@@ -0,0 +1,7 @@
+ """LangChain integration for SQL Server."""
+
+ from langchain_sqlserver.vectorstores import SQLServer_VectorStore
+
+ __all__ = [
+     "SQLServer_VectorStore",
+ ]
File without changes
@@ -0,0 +1,1202 @@
+ """This is the SQL Server module.
+
+ This module provides the SQLServer_VectorStore class for managing
+ vectorstores in SQL Server.
+ """
+
+ from __future__ import annotations
+
+ import json
+ import logging
+ import struct
+ import uuid
+ from enum import Enum
+ from typing import (
+     Any,
+     Callable,
+     Dict,
+     Iterable,
+     List,
+     MutableMapping,
+     Optional,
+     Sequence,
+     Tuple,
+     Type,
+     Union,
+ )
+ from urllib.parse import urlparse
+
+ import numpy as np
+ import sqlalchemy
+ from azure.identity import DefaultAzureCredential
+ from langchain_core.documents import Document
+ from langchain_core.embeddings import Embeddings
+ from langchain_core.vectorstores import VectorStore
+ from langchain_core.vectorstores.utils import maximal_marginal_relevance
+ from sqlalchemy import (
+     Column,
+     ColumnElement,
+     Dialect,
+     Index,
+     Numeric,
+     PrimaryKeyConstraint,
+     SQLColumnExpression,
+     Uuid,
+     asc,
+     bindparam,
+     cast,
+     create_engine,
+     event,
+     func,
+     insert,
+     label,
+     select,
+     text,
+ )
+ from sqlalchemy.dialects.mssql import JSON, NVARCHAR, VARCHAR
+ from sqlalchemy.dialects.mssql.base import MSTypeCompiler
+ from sqlalchemy.engine import Connection, Engine
+ from sqlalchemy.exc import DBAPIError, ProgrammingError
+ from sqlalchemy.ext.compiler import compiles
+ from sqlalchemy.orm import Session, declarative_base
+ from sqlalchemy.pool import ConnectionPoolEntry
+ from sqlalchemy.sql import operators
+ from sqlalchemy.types import UserDefinedType
+
+ COMPARISONS_TO_NATIVE: Dict[str, Callable[[ColumnElement, object], ColumnElement]] = {
+     "$eq": operators.eq,
+     "$ne": operators.ne,
+ }
+
+ NUMERIC_OPERATORS: Dict[str, Callable[[ColumnElement, object], ColumnElement]] = {
+     "$lt": operators.lt,
+     "$lte": operators.le,
+     "$gt": operators.gt,
+     "$gte": operators.ge,
+ }
+
+ SPECIAL_CASED_OPERATORS = {
+     "$in",
+     "$nin",
+     "$like",
+ }
+
+ BETWEEN_OPERATOR = {"$between"}
+
+ LOGICAL_OPERATORS = {"$and", "$or"}
+
+ SUPPORTED_OPERATORS = (
+     set(COMPARISONS_TO_NATIVE)
+     .union(NUMERIC_OPERATORS)
+     .union(SPECIAL_CASED_OPERATORS)
+     .union(BETWEEN_OPERATOR)
+     .union(LOGICAL_OPERATORS)
+ )
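For orientation, a sketch of the filter grammar these operator tables accept (field names and values are illustrative only):

```python
# A field condition is {field: {operator: value}}; a bare value implies $eq.
# Logical operators ($and/$or) take a list of sub-filters.
example_filter = {
    "$and": [
        {"year": {"$gte": 2020}},             # numeric comparison
        {"status": {"$in": ["new", "hot"]}},  # special-cased membership test
        {"title": {"$like": "%vector%"}},     # LIKE pattern match
    ]
}
```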
+
+
+ class DistanceStrategy(str, Enum):
+     """Distance Strategy class for SQLServer_VectorStore.
+
+     Enumerator of the distance strategies for calculating distances
+     between vectors.
+     """
+
+     EUCLIDEAN = "euclidean"
+     COSINE = "cosine"
+     DOT = "dot"
+
+
+ class VectorType(UserDefinedType):
+     """VectorType - A custom type definition."""
+
+     cache_ok = True
+
+     def __init__(self, length: int) -> None:
+         """__init__ for VectorType class."""
+         self.length = length
+
+     def get_col_spec(self, **kw: Any) -> str:
+         """get_col_spec function for VectorType class."""
+         return "vector(%s)" % self.length
+
+     def bind_processor(self, dialect: Any) -> Any:
+         """bind_processor function for VectorType class."""
+
+         def process(value: Any) -> Any:
+             return value
+
+         return process
+
+     def result_processor(self, dialect: Any, coltype: Any) -> Any:
+         """result_processor function for VectorType class."""
+
+         def process(value: Any) -> Any:
+             return value
+
+         return process
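An illustrative sketch of the custom type above (assuming the module is imported as `langchain_sqlserver.vectorstores`): the type only controls DDL rendering, and values pass through unchanged in both directions.

```python
from langchain_sqlserver.vectorstores import VectorType

# The column spec is what CREATE TABLE emits for the embeddings column:
print(VectorType(1536).get_col_spec())  # -> vector(1536)
```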
+
+
+ # String Constants
+ #
+ AZURE_TOKEN_URL = "https://database.windows.net/.default"  # Token URL for Azure DBs.
+ DISTANCE = "distance"
+ DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE
+ DEFAULT_TABLE_NAME = "sqlserver_vectorstore"
+ DISTANCE_STRATEGY = "distancestrategy"
+ EMBEDDING = "embedding"
+ EMBEDDING_LENGTH = "embedding_length"
+ EMBEDDING_VALUES = "embeddingvalues"
+ EMPTY_IDS_ERROR_MESSAGE = "Empty list of ids provided"
+ EXTRA_PARAMS = ";Trusted_Connection=Yes"
+ INVALID_IDS_ERROR_MESSAGE = "Invalid list of ids provided"
+ INVALID_INPUT_ERROR_MESSAGE = "Input is not valid."
+ INVALID_FILTER_INPUT_EXPECTED_DICT = """Invalid filter condition. Expected a dictionary
+ but got an empty dictionary"""
+ INVALID_FILTER_INPUT_EXPECTED_AND_OR = """Invalid filter condition.
+ Expected $and or $or but got: {}"""
+
+ SQL_COPT_SS_ACCESS_TOKEN = 1256  # Connection option defined by Microsoft in msodbcsql.h
+
+ # Query Constants
+ #
+ JSON_TO_VECTOR_QUERY = f"cast (:{EMBEDDING_VALUES} as vector(:{EMBEDDING_LENGTH}))"
+ SERVER_JSON_CHECK_QUERY = "select name from sys.types where system_type_id = 244"
+ VECTOR_DISTANCE_QUERY = f"""
+ VECTOR_DISTANCE(:{DISTANCE_STRATEGY},
+ cast (:{EMBEDDING} as vector(:{EMBEDDING_LENGTH})), embeddings)"""
+
+
+ class SQLServer_VectorStore(VectorStore):
+     """SQL Server Vector Store.
+
+     This class provides a vector store interface for adding texts and performing
+     similarity searches on the texts in SQL Server.
+     """
+
+     def __init__(
+         self,
+         *,
+         connection: Optional[Connection] = None,
+         connection_string: str,
+         db_schema: Optional[str] = None,
+         distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+         embedding_function: Embeddings,
+         embedding_length: int,
+         relevance_score_fn: Optional[Callable[[float], float]] = None,
+         table_name: str = DEFAULT_TABLE_NAME,
+     ) -> None:
+         """Initialize the SQL Server vector store.
+
+         Args:
+             connection: Optional SQLServer connection.
+             connection_string: SQLServer connection string.
+                 If the connection string does not contain a username & password
+                 or `Trusted_Connection=yes`, Entra ID authentication is used.
+                 Sample connection string format:
+                 "mssql+pyodbc://username:password@servername/dbname?other_params"
+             db_schema: The schema in which the vector store will be created.
+                 This schema must exist and the user must have permissions to the schema.
+             distance_strategy: The distance strategy to use for comparing embeddings.
+                 Default value is COSINE. Available options are:
+                 - COSINE
+                 - DOT
+                 - EUCLIDEAN
+             embedding_function: Any embedding function implementing
+                 `langchain.embeddings.base.Embeddings` interface.
+             embedding_length: The length (dimension) of the vectors to be stored in the
+                 table.
+                 Note that only vectors of the same size can be added to the vector store.
+             relevance_score_fn: Relevance score function to be used.
+                 Optional param, defaults to None.
+             table_name: The name of the table to use for storing embeddings.
+                 Default value is `sqlserver_vectorstore`.
+         """
+         self.connection_string = connection_string
+         self._distance_strategy = distance_strategy
+         self.embedding_function = embedding_function
+         self._embedding_length = embedding_length
+         self.schema = db_schema
+         self.override_relevance_score_fn = relevance_score_fn
+         self.table_name = table_name
+         self._bind: Union[Connection, Engine] = (
+             connection if connection else self._create_engine()
+         )
+         self._prepare_json_data_type()
+         self._embedding_store = self._get_embedding_store(self.table_name, self.schema)
+         self._create_table_if_not_exists()
+
+     def _can_connect_with_entra_id(self) -> bool:
+         """Determine if Entra ID authentication can be used.
+
+         Check the components of the connection string to determine
+         if connection via Entra ID authentication is possible or not.
+
+         The connection string is expected to be of the form:
+         "mssql+pyodbc://username:password@servername/dbname?other_params"
+         which gets parsed into -> <scheme>://<netloc>/<path>?<query>
+         """
+         parsed_url = urlparse(self.connection_string)
+
+         if parsed_url is None:
+             logging.error("Unable to parse connection string.")
+             return False
+
+         if (parsed_url.username and parsed_url.password) or (
+             "trusted_connection=yes" in parsed_url.query.lower()
+         ):
+             return False
+
+         return True
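Illustrative connection strings for the rule above (server, database, and driver names are placeholders):

```python
# Username and password present -> SQL authentication, no Entra ID:
with_password = "mssql+pyodbc://user:pass@server/db?driver=ODBC+Driver+18+for+SQL+Server"

# Trusted_Connection=yes in the query string -> Windows auth, no Entra ID:
trusted = "mssql+pyodbc://server/db?driver=ODBC+Driver+18+for+SQL+Server&Trusted_Connection=yes"

# Neither credentials nor Trusted_Connection -> an Entra ID token is requested:
entra_id = "mssql+pyodbc://server/db?driver=ODBC+Driver+18+for+SQL+Server"
```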
+
+     def _create_engine(self) -> Engine:
+         if self._can_connect_with_entra_id():
+             # Use Entra ID auth. Listen for a connection event
+             # when `_create_engine` function from this class is called.
+             #
+             event.listen(Engine, "do_connect", self._provide_token, once=True)
+             logging.info("Using Entra ID Authentication.")
+
+         return create_engine(url=self.connection_string)
+
+     def _create_table_if_not_exists(self) -> None:
+         logging.info(f"Creating table {self.table_name}.")
+         try:
+             with Session(self._bind) as session:
+                 self._embedding_store.__table__.create(
+                     session.get_bind(), checkfirst=True
+                 )
+                 session.commit()
+         except ProgrammingError as e:
+             logging.error(f"Create table {self.table_name} failed.")
+             raise Exception(e.__cause__) from None
+
+     def _get_embedding_store(self, name: str, schema: Optional[str]) -> Any:
+         DynamicBase = declarative_base(class_registry=dict())  # type: Any
+         if self._embedding_length is None or self._embedding_length < 1:
+             raise ValueError("`embedding_length` value is not valid.")
+
+         class EmbeddingStore(DynamicBase):
+             """This is the base model for SQL vector store."""
+
+             __tablename__ = name
+             __table_args__ = (
+                 PrimaryKeyConstraint("id", mssql_clustered=False),
+                 Index("idx_custom_id", "custom_id", mssql_clustered=False, unique=True),
+                 {"schema": schema},
+             )
+             id = Column(Uuid, primary_key=True, default=uuid.uuid4)
+             custom_id = Column(
+                 VARCHAR(1000), nullable=True
+             )  # column for user defined ids.
+             content_metadata = Column(JSON, nullable=True)
+             content = Column(NVARCHAR, nullable=False)  # defaults to NVARCHAR(MAX)
+             embeddings = Column(VectorType(self._embedding_length), nullable=False)
+
+         return EmbeddingStore
+
+     def _prepare_json_data_type(self) -> None:
+         """Prepare for JSON data type usage.
+
+         Check if the server has the JSON data type available. If it does,
+         we compile the JSON data type as JSON instead of the NVARCHAR(max)
+         used by sqlalchemy. If it doesn't, this defaults to NVARCHAR(max)
+         as specified by sqlalchemy.
+         """
+         try:
+             with Session(self._bind) as session:
+                 result = session.scalar(text(SERVER_JSON_CHECK_QUERY))
+                 session.close()
+
+                 if result is not None:
+
+                     @compiles(JSON, "mssql")
+                     def compile_json(
+                         element: JSON, compiler: MSTypeCompiler, **kw: Any
+                     ) -> str:
+                         # return JSON when JSON data type is specified in this class.
+                         return result  # json data type name in sql server
+
+         except ProgrammingError as e:
+             logging.error(f"Unable to get data types.\n {e.__cause__}\n")
+
+     @property
+     def embeddings(self) -> Embeddings:
+         """`embeddings` property for SQLServer_VectorStore class."""
+         return self.embedding_function
+
+     @property
+     def distance_strategy(self) -> str:
+         """`distance_strategy` property for SQLServer_VectorStore class."""
+         # Value of distance strategy passed in should be one of the supported values.
+         if isinstance(self._distance_strategy, DistanceStrategy):
+             return self._distance_strategy.value
+
+         # Match string value with appropriate enum value, if supported.
+         distance_strategy_lower = str.lower(self._distance_strategy)
+
+         if distance_strategy_lower == DistanceStrategy.EUCLIDEAN.value:
+             return DistanceStrategy.EUCLIDEAN.value
+         elif distance_strategy_lower == DistanceStrategy.COSINE.value:
+             return DistanceStrategy.COSINE.value
+         elif distance_strategy_lower == DistanceStrategy.DOT.value:
+             return DistanceStrategy.DOT.value
+         else:
+             raise ValueError(f"{self._distance_strategy} is not supported.")
+
+     @distance_strategy.setter
+     def distance_strategy(self, value: DistanceStrategy) -> None:
+         self._distance_strategy = value
+
+     @classmethod
+     def from_texts(
+         cls: Type[SQLServer_VectorStore],
+         texts: List[str],
+         embedding: Embeddings,
+         metadatas: Optional[List[dict]] = None,
+         connection_string: str = str(),
+         embedding_length: int = 0,
+         table_name: str = DEFAULT_TABLE_NAME,
+         db_schema: Optional[str] = None,
+         distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+         ids: Optional[List[str]] = None,
+         **kwargs: Any,
+     ) -> SQLServer_VectorStore:
+         """Create a SQL Server vectorstore initialized from texts and embeddings.
+
+         Args:
+             texts: Iterable of strings to add into the vectorstore.
+             embedding: Any embedding function implementing
+                 `langchain.embeddings.base.Embeddings` interface.
+             metadatas: Optional list of metadatas (python dicts) associated
+                 with the input texts.
+             connection_string: SQLServer connection string.
+                 If the connection string does not contain a username & password
+                 or `Trusted_Connection=yes`, Entra ID authentication is used.
+                 Sample connection string format:
+                 "mssql+pyodbc://username:password@servername/dbname?other_params"
+             embedding_length: The length (dimension) of the vectors to be stored in the
+                 table.
+                 Note that only vectors of the same size can be added to the vector store.
+             table_name: The name of the table to use for storing embeddings.
+             db_schema: The schema in which the vector store will be created.
+                 This schema must exist and the user must have permissions to the schema.
+             distance_strategy: The distance strategy to use for comparing embeddings.
+                 Default value is COSINE. Available options are:
+                 - COSINE
+                 - DOT
+                 - EUCLIDEAN
+             ids: Optional list of IDs for the input texts.
+             **kwargs: vectorstore specific parameters.
+
+         Returns:
+             SQLServer_VectorStore: A SQL Server vectorstore.
+         """
+         store = cls(
+             connection_string=connection_string,
+             db_schema=db_schema,
+             distance_strategy=distance_strategy,
+             embedding_function=embedding,
+             embedding_length=embedding_length,
+             table_name=table_name,
+             **kwargs,
+         )
+
+         store.add_texts(texts, metadatas, ids, **kwargs)
+         return store
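A hedged usage sketch of `from_texts` (embedding model and connection string are placeholders, as in the README example):

```python
store = SQLServer_VectorStore.from_texts(
    texts=["doc one", "doc two", "doc three"],
    embedding=OpenAIEmbeddings(),  # any Embeddings implementation works
    metadatas=[{"source": "a"}, {"source": "b"}, {"source": "c"}],
    connection_string="mssql+pyodbc://user:pass@server/db?driver=ODBC+Driver+18+for+SQL+Server",
    embedding_length=1536,
    table_name="my_vectors",
)
```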
+
+     @classmethod
+     def from_documents(
+         cls: Type[SQLServer_VectorStore],
+         documents: List[Document],
+         embedding: Embeddings,
+         connection_string: str = str(),
+         embedding_length: int = 0,
+         table_name: str = DEFAULT_TABLE_NAME,
+         db_schema: Optional[str] = None,
+         distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+         ids: Optional[List[str]] = None,
+         **kwargs: Any,
+     ) -> SQLServer_VectorStore:
+         """Create a SQL Server vectorstore initialized from documents and embeddings.
+
+         Args:
+             documents: Documents to add to the vectorstore.
+             embedding: Any embedding function implementing
+                 `langchain.embeddings.base.Embeddings` interface.
+             connection_string: SQLServer connection string.
+                 If the connection string does not contain a username & password
+                 or `Trusted_Connection=yes`, Entra ID authentication is used.
+                 Sample connection string format:
+                 "mssql+pyodbc://username:password@servername/dbname?other_params"
+             embedding_length: The length (dimension) of the vectors to be stored in the
+                 table.
+                 Note that only vectors of the same size can be added to the vector store.
+             table_name: The name of the table to use for storing embeddings.
+                 Default value is `sqlserver_vectorstore`.
+             db_schema: The schema in which the vector store will be created.
+                 This schema must exist and the user must have permissions to the schema.
+             distance_strategy: The distance strategy to use for comparing embeddings.
+                 Default value is COSINE. Available options are:
+                 - COSINE
+                 - DOT
+                 - EUCLIDEAN
+             ids: Optional list of IDs for the input texts.
+             **kwargs: vectorstore specific parameters.
+
+         Returns:
+             SQLServer_VectorStore: A SQL Server vectorstore.
+         """
+         texts, metadatas = [], []
+
+         for doc in documents:
+             if not isinstance(doc, Document):
+                 raise ValueError(
+                     f"Expected an entry of type Document, but got {type(doc)}"
+                 )
+
+             texts.append(doc.page_content)
+             metadatas.append(doc.metadata)
+
+         store = cls(
+             connection_string=connection_string,
+             db_schema=db_schema,
+             distance_strategy=distance_strategy,
+             embedding_function=embedding,
+             embedding_length=embedding_length,
+             table_name=table_name,
+             **kwargs,
+         )
+
+         store.add_texts(texts, metadatas, ids, **kwargs)
+         return store
+
+     def get_by_ids(self, ids: Sequence[str], /) -> List[Document]:
+         """Get documents by their IDs from the vectorstore.
+
+         Args:
+             ids: List of IDs to retrieve.
+
+         Returns:
+             List of Documents
+         """
+         documents = []
+
+         if ids is None or len(ids) == 0:
+             logging.info(EMPTY_IDS_ERROR_MESSAGE)
+         else:
+             result = self._get_documents_by_ids(ids)
+             for item in result:
+                 if item is not None:
+                     documents.append(
+                         Document(
+                             id=item.custom_id,
+                             page_content=item.content,
+                             metadata=item.content_metadata,
+                         )
+                     )
+
+         return documents
+
+     def _get_documents_by_ids(self, ids: Sequence[str], /) -> Sequence[Any]:
+         result: Sequence[Any] = []
+         try:
+             with Session(bind=self._bind) as session:
+                 statement = select(
+                     self._embedding_store.custom_id,
+                     self._embedding_store.content,
+                     self._embedding_store.content_metadata,
+                 ).where(self._embedding_store.custom_id.in_(ids))
+                 result = session.execute(statement).fetchall()
+         except DBAPIError as e:
+             logging.error(e.__cause__)
+         return result
+
+     def _select_relevance_score_fn(self) -> Callable[[float], float]:
+         """Determine relevance score function.
+
+         The 'correct' relevance function
+         may differ depending on a few things, including:
+         - the distance / similarity metric used by the VectorStore
+         - the scale of your embeddings (OpenAI's are unit normed. Many others are not!)
+         - embedding dimensionality
+         - etc.
+         If no relevance function is provided in the class constructor,
+         selection is based on the distance strategy provided.
+         """
+         if self.override_relevance_score_fn is not None:
+             return self.override_relevance_score_fn
+
+         # If the relevance score function is not provided, we default to using
+         # the distance strategy specified by the user.
+         if self._distance_strategy == DistanceStrategy.COSINE:
+             return self._cosine_relevance_score_fn
+         elif self._distance_strategy == DistanceStrategy.DOT:
+             return self._max_inner_product_relevance_score_fn
+         elif self._distance_strategy == DistanceStrategy.EUCLIDEAN:
+             return self._euclidean_relevance_score_fn
+         else:
+             raise ValueError(
+                 "There is no supported normalization function for"
+                 f" {self._distance_strategy} distance strategy."
+                 " Consider providing relevance_score_fn to "
+                 "SQLServer_VectorStore construction."
+             )
+
+     def max_marginal_relevance_search(
+         self,
+         query: str,
+         k: int = 4,
+         fetch_k: int = 20,
+         lambda_mult: float = 0.5,
+         **kwargs: Any,
+     ) -> List[Document]:
+         """Return docs selected using the maximal marginal relevance.
+
+         Maximal marginal relevance optimizes for similarity to query AND diversity
+         among selected documents.
+
+         Args:
+             query: Text to look up documents similar to.
+             k: Number of Documents to return. Defaults to 4.
+             fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+                 Default is 20.
+             lambda_mult: Number between 0 and 1 that determines the degree
+                 of diversity among the results with 0 corresponding
+                 to maximum diversity and 1 to minimum diversity.
+                 Defaults to 0.5.
+             **kwargs: Arguments to pass to the search method.
+
+         Returns:
+             List of Documents selected by maximal marginal relevance.
+         """
+         embedded_query = self.embedding_function.embed_query(query)
+         return self.max_marginal_relevance_search_by_vector(
+             embedded_query, k=k, fetch_k=fetch_k, lambda_mult=lambda_mult, **kwargs
+         )
+
+     def max_marginal_relevance_search_by_vector(
+         self,
+         embedding: list[float],
+         k: int = 4,
+         fetch_k: int = 20,
+         lambda_mult: float = 0.5,
+         **kwargs: Any,
+     ) -> List[Document]:
+         """Return docs selected using the maximal marginal relevance.
+
+         Maximal marginal relevance optimizes for similarity to query AND diversity
+         among selected documents.
+
+         Args:
+             embedding: Embedding to look up documents similar to.
+             k: Number of Documents to return. Defaults to 4.
+             fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+                 Default is 20.
+             lambda_mult: Number between 0 and 1 that determines the degree
+                 of diversity among the results with 0 corresponding
+                 to maximum diversity and 1 to minimum diversity.
+                 Defaults to 0.5.
+             **kwargs: Arguments to pass to the search method.
+
+         Returns:
+             List of Documents selected by maximal marginal relevance.
+         """
+         results = self._search_store(
+             embedding, k=fetch_k, marginal_relevance=True, **kwargs
+         )
+         embedding_list = [json.loads(result[0]) for result in results]
+
+         mmr_selects = maximal_marginal_relevance(
+             np.array(embedding, dtype=np.float32),
+             embedding_list,
+             lambda_mult=lambda_mult,
+             k=k,
+         )
+
+         results_as_docs = self._docs_from_result(
+             self._docs_and_scores_from_result(results)
+         )
+
+         # Return list of Documents from results_as_docs whose position
+         # corresponds to the indices in mmr_selects.
+         return [
+             value for idx, value in enumerate(results_as_docs) if idx in mmr_selects
+         ]
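A usage sketch, reusing the `store` from the sketches above: `fetch_k` candidates are retrieved by vector distance, then re-ranked for diversity and trimmed to `k`.

```python
docs = store.max_marginal_relevance_search(
    "query text",
    k=4,              # documents returned after re-ranking
    fetch_k=20,       # candidates retrieved by distance first
    lambda_mult=0.5,  # 0 = maximum diversity, 1 = minimum diversity
)
```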
+
+     def similarity_search(
+         self, query: str, k: int = 4, **kwargs: Any
+     ) -> List[Document]:
+         """Return docs most similar to given query.
+
+         Args:
+             query: Text to look up the most similar embedding to.
+             k: Number of Documents to return. Defaults to 4.
+             **kwargs: Values for filtering on metadata during similarity search.
+
+         Returns:
+             List of Documents most similar to the query provided.
+         """
+         embedded_query = self.embedding_function.embed_query(query)
+         return self.similarity_search_by_vector(embedded_query, k, **kwargs)
+
+     def similarity_search_by_vector(
+         self, embedding: List[float], k: int = 4, **kwargs: Any
+     ) -> List[Document]:
+         """Return docs most similar to the embedding vector.
+
+         Args:
+             embedding: Embedding to look up documents similar to.
+             k: Number of Documents to return. Defaults to 4.
+             **kwargs: Values for filtering on metadata during similarity search.
+
+         Returns:
+             List of Documents most similar to the embedding provided.
+         """
+         similar_docs_with_scores = self.similarity_search_by_vector_with_score(
+             embedding, k, **kwargs
+         )
+         return self._docs_from_result(similar_docs_with_scores)
+
+     def similarity_search_with_score(
+         self, query: str, k: int = 4, **kwargs: Any
+     ) -> List[Tuple[Document, float]]:
+         """Similarity search with score.
+
+         Run similarity search with distance and
+         return docs most similar to the embedding vector.
+
+         Args:
+             query: Text to look up the most similar embedding to.
+             k: Number of Documents to return. Defaults to 4.
+             **kwargs: Values for filtering on metadata during similarity search.
+
+         Returns:
+             List of tuples of Document and an accompanying score, in order of
+             similarity to the query provided.
+             Note that a smaller score implies greater similarity.
+         """
+         embedded_query = self.embedding_function.embed_query(query)
+         return self.similarity_search_by_vector_with_score(embedded_query, k, **kwargs)
+
+     def similarity_search_by_vector_with_score(
+         self, embedding: List[float], k: int = 4, **kwargs: Any
+     ) -> List[Tuple[Document, float]]:
+         """Similarity search by vector with score.
+
+         Run similarity search with distance, given an embedding
+         and return docs most similar to the embedding vector.
+
+         Args:
+             embedding: Embedding to look up documents similar to.
+             k: Number of Documents to return. Defaults to 4.
+             **kwargs: Values for filtering on metadata during similarity search.
+
+         Returns:
+             List of tuples of Document and an accompanying score, in order of
+             similarity to the embedding provided.
+             Note that a smaller score implies greater similarity.
+         """
+         similar_docs = self._search_store(embedding, k, **kwargs)
+         docs_and_scores = self._docs_and_scores_from_result(similar_docs)
+         return docs_and_scores
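A sketch of the scored variants (again reusing `store`); the score is the raw distance under the configured strategy, so smaller means more similar:

```python
for doc, score in store.similarity_search_with_score("query text", k=4):
    print(f"{score:.4f}  {doc.page_content}")  # smaller score = more similar
```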
+
+     def add_texts(
+         self,
+         texts: Iterable[str],
+         metadatas: Optional[List[dict]] = None,
+         ids: Optional[List[str]] = None,
+         **kwargs: Any,
+     ) -> List[str]:
+         """`add_texts` function for SQLServer_VectorStore class.
+
+         Compute the embeddings for the input texts and store embeddings
+         in the vectorstore.
+
+         Args:
+             texts: Iterable of strings to add into the vectorstore.
+             metadatas: List of metadatas (python dicts) associated with the input texts.
+             ids: List of IDs for the input texts.
+             **kwargs: vectorstore specific parameters.
+
+         Returns:
+             List of IDs generated from adding the texts into the vectorstore.
+         """
+         # Embed the texts passed in.
+         embedded_texts = self.embedding_function.embed_documents(list(texts))
+
+         # Insert the embedded texts in the vector store table.
+         return self._insert_embeddings(texts, embedded_texts, metadatas, ids)
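A usage sketch; ids are optional and, when omitted, fall back to each metadata's "id" key or a generated uuid:

```python
ids = store.add_texts(
    texts=["first text", "second text"],
    metadatas=[{"source": "a"}, {"source": "b"}],
    ids=["id-1", "id-2"],  # optional; missing ids are generated
)
```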
+
+     def drop(self) -> None:
+         """Drops every table created during initialization of vector store."""
+         logging.info(f"Dropping vector store: {self.table_name}")
+         try:
+             with Session(bind=self._bind) as session:
+                 # Drop the table associated with the session bind.
+                 self._embedding_store.__table__.drop(session.get_bind())
+                 session.commit()
+
+                 logging.info(f"Vector store `{self.table_name}` dropped successfully.")
+
+         except ProgrammingError as e:
+             logging.error(f"Unable to drop vector store.\n {e.__cause__}.")
+
+     def _search_store(
+         self,
+         embedding: List[float],
+         k: int,
+         filter: Optional[dict] = None,
+         marginal_relevance: Optional[bool] = False,
+     ) -> List[Any]:
+         try:
+             with Session(self._bind) as session:
+                 filter_by = []
+                 filter_clauses = self._create_filter_clause(filter)
+                 if filter_clauses is not None:
+                     filter_by.append(filter_clauses)
+
+                 subquery = label(
+                     DISTANCE,
+                     text(VECTOR_DISTANCE_QUERY).bindparams(
+                         bindparam(
+                             DISTANCE_STRATEGY,
+                             self.distance_strategy,
+                             literal_execute=True,
+                         ),
+                         bindparam(
+                             EMBEDDING,
+                             json.dumps(embedding),
+                             literal_execute=True,
+                         ),
+                         bindparam(
+                             EMBEDDING_LENGTH,
+                             self._embedding_length,
+                             literal_execute=True,
+                         ),
+                     ),
+                 )
+
+                 # Results for marginal relevance include an additional
+                 # column for embeddings.
+                 if marginal_relevance:
+                     query = (
+                         select(
+                             text("cast (embeddings as NVARCHAR(MAX))"),
+                             subquery,
+                             self._embedding_store,
+                         )
+                         .filter(*filter_by)
+                         .order_by(asc(text(DISTANCE)))
+                         .limit(k)
+                     )
+                     results = list(session.execute(query).fetchall())
+                 else:
+                     results = (
+                         session.query(
+                             self._embedding_store,
+                             subquery,
+                         )
+                         .filter(*filter_by)
+                         .order_by(asc(text(DISTANCE)))
+                         .limit(k)
+                         .all()
+                     )
+         except ProgrammingError as e:
+             logging.error(f"An error has occurred during the search.\n {e.__cause__}")
+             raise Exception(e.__cause__) from None
+
+         return results
+
+     def _create_filter_clause(self, filters: Any) -> Any:
+         """Create a filter clause.
+
+         Convert LangChain Information Retrieval filter representation to matching
+         SQLAlchemy clauses.
+
+         At the top level, we still don't know if we're working with a field
+         or an operator for the keys. After we've determined that, we can
+         call the appropriate logic to handle filter creation.
+
+         Args:
+             filters: Dictionary of filters to apply to the query.
+
+         Returns:
+             SQLAlchemy clause to apply to the query.
+
+         Ex: For a filter, {"$or": [{"id": 1}, {"name": "bob"}]}, the result is
+         JSON_VALUE(langchain_vector_store_tests.content_metadata, :JSON_VALUE_1) =
+         :JSON_VALUE_2 OR JSON_VALUE(langchain_vector_store_tests.content_metadata,
+         :JSON_VALUE_3) = :JSON_VALUE_4
+         """
+         if filters is not None:
+             if not isinstance(filters, dict):
+                 raise ValueError(
+                     f"Expected a dict, but got {type(filters)} for value: {filters}"
+                 )
+             if len(filters) == 1:
+                 # The only operators allowed at the top level are $and and $or.
+                 # First check whether the key is an operator or a field.
+                 key, value = list(filters.items())[0]
+                 if key.startswith("$"):
+                     # Then it's an operator
+                     if key.lower() not in LOGICAL_OPERATORS:
+                         raise ValueError(
+                             INVALID_FILTER_INPUT_EXPECTED_AND_OR.format(key)
+                         )
+                 else:
+                     # Then it's a field
+                     return self._handle_field_filter(key, filters[key])
+
+                 # Here we handle the $and and $or operators
+                 if not isinstance(value, list):
+                     raise ValueError(
+                         f"Expected a list, but got {type(value)} for value: {value}"
+                     )
+                 if key.lower() == "$and":
+                     and_ = [self._create_filter_clause(el) for el in value]
+                     if len(and_) > 1:
+                         return sqlalchemy.and_(*and_)
+                     elif len(and_) == 1:
+                         return and_[0]
+                     else:
+                         raise ValueError(INVALID_FILTER_INPUT_EXPECTED_DICT)
+                 elif key.lower() == "$or":
+                     or_ = [self._create_filter_clause(el) for el in value]
+                     if len(or_) > 1:
+                         return sqlalchemy.or_(*or_)
+                     elif len(or_) == 1:
+                         return or_[0]
+                     else:
+                         raise ValueError(INVALID_FILTER_INPUT_EXPECTED_DICT)
+
+             elif len(filters) > 1:
+                 # Then all keys have to be fields (they cannot be operators)
+                 for key in filters.keys():
+                     if key.startswith("$"):
+                         raise ValueError(
+                             f"Invalid filter condition. Expected a field but got: {key}"
+                         )
+                 # These should all be fields and combined using an $and operator
+                 and_ = [self._handle_field_filter(k, v) for k, v in filters.items()]
+                 if len(and_) > 1:
+                     return sqlalchemy.and_(*and_)
+                 elif len(and_) == 1:
+                     return and_[0]
+                 else:
+                     raise ValueError(INVALID_FILTER_INPUT_EXPECTED_DICT)
+             else:
+                 raise ValueError("Got an empty dictionary for filters.")
+         else:
+             logging.info("No filters passed; returning None.")
+             return None
+
+     def _handle_field_filter(
+         self,
+         field: str,
+         value: Any,
+     ) -> SQLColumnExpression:
+         """Create a filter for a specific field.
+
+         Args:
+             field: name of field
+             value: value to filter
+                 If provided as is then this will be an equality filter
+                 If provided as a dictionary then this will be a filter, the key
+                 will be the operator and the value will be the value to filter by
+
+         Returns:
+             sqlalchemy expression
+
+         Ex: For a filter, {"id": 1}, the result is
+
+         JSON_VALUE(langchain_vector_store_tests.content_metadata, :JSON_VALUE_1) =
+         :JSON_VALUE_2
+         """
+         if field.startswith("$"):
+             raise ValueError(
+                 f"Invalid filter condition. Expected a field but got an operator: "
+                 f"{field}"
+             )
+
+         # Allow [a-zA-Z0-9_], disallow $ for now until we support escape characters
+         if not field.isidentifier():
+             raise ValueError(
+                 f"Invalid field name: {field}. Expected a valid identifier."
+             )
+
+         if isinstance(value, dict):
+             # A filter specification allows only one operator per field;
+             # multiple conditions are given as separate fields and are
+             # combined with an AND at the top level.
+             if len(value) != 1:
+                 raise ValueError(
+                     "Invalid filter condition. Expected a value which "
+                     "is a dictionary with a single key that corresponds to an operator "
+                     f"but got a dictionary with {len(value)} keys. The first few "
+                     f"keys are: {list(value.keys())[:3]}"
+                 )
+             operator, filter_value = list(value.items())[0]
+             # Verify that operator is an operator
+             if operator not in SUPPORTED_OPERATORS:
+                 raise ValueError(
+                     f"Invalid operator: {operator}. "
+                     f"Expected one of {SUPPORTED_OPERATORS}"
+                 )
+         else:  # Then we assume an equality operator
+             operator = "$eq"
+             filter_value = value
+
+         if operator in COMPARISONS_TO_NATIVE:
+             operation = COMPARISONS_TO_NATIVE[operator]
+             native_result = func.JSON_VALUE(
+                 self._embedding_store.content_metadata, f"$.{field}"
+             )
+             native_operation_result = operation(native_result, str(filter_value))
+             return native_operation_result
+
+         elif operator in NUMERIC_OPERATORS:
+             operation = NUMERIC_OPERATORS[str(operator)]
+             numeric_result = func.JSON_VALUE(
+                 self._embedding_store.content_metadata, f"$.{field}"
+             )
+             numeric_operation_result = operation(numeric_result, filter_value)
+
+             if not isinstance(filter_value, str):
+                 numeric_operation_result = operation(
+                     cast(numeric_result, Numeric(10, 2)), filter_value
+                 )
+
+             return numeric_operation_result
+
+         elif operator in BETWEEN_OPERATOR:
+             # Use AND with two comparisons
+             low, high = filter_value
+
+             # column_value is a ColumnElement extracted from the JSON metadata.
+             column_value = func.JSON_VALUE(
+                 self._embedding_store.content_metadata, f"$.{field}"
+             )
+
+             greater_operation = NUMERIC_OPERATORS["$gte"]
+             lesser_operation = NUMERIC_OPERATORS["$lte"]
+
+             lower_bound = greater_operation(column_value, low)
+             upper_bound = lesser_operation(column_value, high)
+
+             # Conditionally cast if filter_value is not a string
+             if not isinstance(filter_value, str):
+                 lower_bound = greater_operation(cast(column_value, Numeric(10, 2)), low)
+                 upper_bound = lesser_operation(cast(column_value, Numeric(10, 2)), high)
+
+             return sqlalchemy.and_(lower_bound, upper_bound)
+
+         elif operator in SPECIAL_CASED_OPERATORS:
+             # We force coercion to text
+             if operator in {"$in", "$nin"}:
+                 for val in filter_value:
+                     if not isinstance(val, (str, int, float)):
+                         raise NotImplementedError(
+                             f"Unsupported type: {type(val)} for value: {val}"
+                         )
+
+             queried_field = func.JSON_VALUE(
+                 self._embedding_store.content_metadata, f"$.{field}"
+             )
+
+             if operator in {"$in"}:
+                 return queried_field.in_([str(val) for val in filter_value])
+             elif operator in {"$nin"}:
+                 return queried_field.not_in([str(val) for val in filter_value])
+             elif operator in {"$like"}:
+                 return queried_field.like(str(filter_value))
+             else:
+                 raise NotImplementedError(f"Operator is not implemented: {operator}.")
+         else:
+             raise NotImplementedError()
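A sketch of how these clauses are reached in practice: the `filter` kwarg flows through the search methods' `**kwargs` into `_search_store`.

```python
docs = store.similarity_search(
    "query text",
    k=4,
    # Compiles to a JSON_VALUE(content_metadata, '$.source') = 'a' predicate:
    filter={"source": {"$eq": "a"}},
)
```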
+
+     def _docs_from_result(self, results: Any) -> List[Document]:
+         """Formats the input into a result of type List[Document]."""
+         docs = [doc for doc, _ in results if doc is not None]
+         return docs
+
+     def _docs_and_scores_from_result(
+         self, results: List[Any]
+     ) -> List[Tuple[Document, float]]:
+         """Formats the input into a result of type List[Tuple[Document, float]].
+
+         If an invalid input is given, it does not attempt to format the value
+         and instead logs an error.
+         """
+         docs_and_scores = []
+
+         for result in results:
+             if (
+                 result is not None
+                 and result.EmbeddingStore is not None
+                 and result.distance is not None
+             ):
+                 docs_and_scores.append(
+                     (
+                         Document(
+                             page_content=result.EmbeddingStore.content,
+                             metadata=result.EmbeddingStore.content_metadata,
+                         ),
+                         result.distance,
+                     )
+                 )
+             else:
+                 logging.error(INVALID_INPUT_ERROR_MESSAGE)
+
+         return docs_and_scores
+
+     def _insert_embeddings(
+         self,
+         texts: Iterable[str],
+         embeddings: List[List[float]],
+         metadatas: Optional[List[dict]] = None,
+         ids: Optional[List[str]] = None,
+         **kwargs: Any,
+     ) -> List[str]:
+         """Insert the embeddings and the texts in the vectorstore.
+
+         Args:
+             texts: Iterable of strings to add into the vectorstore.
+             embeddings: List of embedding vectors, one per text.
+             metadatas: List of metadatas (python dicts) associated with the input texts.
+             ids: List of IDs for the input texts.
+             **kwargs: vectorstore specific parameters.
+
+         Returns:
+             List of IDs generated from adding the texts into the vectorstore.
+         """
+         if metadatas is None:
+             metadatas = [{} for _ in texts]
+
+         try:
+             if ids is None:
+                 # Get IDs from metadata if available.
+                 ids = [metadata.get("id", uuid.uuid4()) for metadata in metadatas]
+
+             with Session(self._bind) as session:
+                 documents = []
+                 for idx, query in enumerate(texts):
+                     # For a query, if there is no corresponding ID,
+                     # we generate a uuid and add it to the list of IDs to be returned.
+                     if idx < len(ids):
+                         custom_id = ids[idx]
+                     else:
+                         ids.append(str(uuid.uuid4()))
+                         custom_id = ids[-1]
+                     embedding = embeddings[idx]
+                     metadata = metadatas[idx] if idx < len(metadatas) else {}
+
+                     # Construct text, embedding, metadata as EmbeddingStore model
+                     # to be inserted into the table.
+                     sqlquery = select(
+                         text(JSON_TO_VECTOR_QUERY).bindparams(
+                             bindparam(
+                                 EMBEDDING_VALUES,
+                                 json.dumps(embedding),
+                                 literal_execute=True,
+                                 # when unique is set to true, the name of the key
+                                 # for each bindparameter is made unique, to avoid
+                                 # using the wrong bound parameter during compile.
+                                 # This is especially needed since we're creating
+                                 # and storing multiple queries to be bulk inserted
+                                 # later on.
+                                 unique=True,
+                             ),
+                             bindparam(
+                                 EMBEDDING_LENGTH,
+                                 self._embedding_length,
+                                 literal_execute=True,
+                             ),
+                         )
+                     )
+                     # `embedding_store` is created in a dictionary format instead
+                     # of using the embedding_store object from this class.
+                     # This enables the use of `insert().values()` which can only
+                     # take a dict and not a custom object.
+                     embedding_store = {
+                         "custom_id": custom_id,
+                         "content_metadata": metadata,
+                         "content": query,
+                         "embeddings": sqlquery,
+                     }
+                     documents.append(embedding_store)
+                 session.execute(insert(self._embedding_store).values(documents))
+                 session.commit()
+         except DBAPIError as e:
+             logging.error(f"Add text failed:\n {e.__cause__}\n")
+             raise Exception(e.__cause__) from None
+         except AttributeError:
+             logging.error("Metadata must be a list of dictionaries.")
+             raise
+         return ids
+
+     def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
+         """Delete embeddings in the vectorstore by the ids.
+
+         Args:
+             ids: List of IDs to delete. If None, delete all. Default is None.
+                 No data is deleted if an empty list is provided.
+             kwargs: vectorstore specific parameters.
+
+         Returns:
+             Optional[bool]
+         """
+         if ids is not None and len(ids) == 0:
+             logging.info(EMPTY_IDS_ERROR_MESSAGE)
+             return False
+
+         result = self._delete_texts_by_ids(ids)
+         if result == 0:
+             logging.info(INVALID_IDS_ERROR_MESSAGE)
+             return False
+
+         logging.info(f"{result} rows affected.")
+         return True
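A sketch of the deletion semantics described above:

```python
store.delete(["id-1", "id-2"])  # delete specific rows by their custom ids
store.delete([])                # no-op: an empty list deletes nothing, returns False
store.delete(None)              # deletes every row in the vector store table
```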
+
+     def _delete_texts_by_ids(self, ids: Optional[List[str]] = None) -> int:
+         result = 0  # Rows affected; stays 0 if the delete fails.
+         try:
+             with Session(bind=self._bind) as session:
+                 if ids is None:
+                     logging.info("Deleting all data in the vectorstore.")
+                     result = session.query(self._embedding_store).delete()
+                 else:
+                     result = (
+                         session.query(self._embedding_store)
+                         .filter(self._embedding_store.custom_id.in_(ids))
+                         .delete()
+                     )
+                 session.commit()
+         except DBAPIError as e:
+             logging.error(e.__cause__)
+         return result
+
+     def _provide_token(
+         self,
+         dialect: Dialect,
+         conn_rec: Optional[ConnectionPoolEntry],
+         cargs: List[str],
+         cparams: MutableMapping[str, Any],
+     ) -> None:
+         """Function to retrieve access token for connection.
+
+         Get token for SQLServer connection from token URL,
+         and use the token to connect to the database.
+         """
+         credential = DefaultAzureCredential()
+
+         # Remove Trusted_Connection param that SQLAlchemy adds to
+         # the connection string by default.
+         cargs[0] = cargs[0].replace(EXTRA_PARAMS, str())
+
+         # Create credential token
+         token_bytes = credential.get_token(AZURE_TOKEN_URL).token.encode("utf-16-le")
+         token_struct = struct.pack(
+             f"<I{len(token_bytes)}s", len(token_bytes), token_bytes
+         )
+
+         # Apply credential token to keyword argument
+         cparams["attrs_before"] = {SQL_COPT_SS_ACCESS_TOKEN: token_struct}
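For reference, a standalone sketch of the token encoding performed above; the ODBC SQL_COPT_SS_ACCESS_TOKEN attribute expects a little-endian byte-length prefix followed by the UTF-16-LE token bytes:

```python
import struct

token = "<access token>"  # placeholder Entra ID access token
token_bytes = token.encode("utf-16-le")
# "<I{n}s" packs an unsigned 32-bit little-endian length, then n raw bytes.
token_struct = struct.pack(f"<I{len(token_bytes)}s", len(token_bytes), token_bytes)
```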
@@ -0,0 +1,89 @@
+ [tool.poetry]
+ name = "langchain-sqlserver"
+ version = "0.1.0"
+ description = "An integration package to support SQL Server in LangChain."
+ authors = []
+ license = "MIT"
+ readme = "README.md"
+
+ [tool.poetry.dependencies]
+ python = ">=3.9,<4.0"
+ SQLAlchemy = ">=2.0.0,<3"
+ azure-identity = "^1.16.0"
+ langchain-core = "^0.3.0"
+ pyodbc = ">=5.0.0,<6.0.0"
+ numpy = "^1"
+
+ [tool.poetry.group.codespell.dependencies]
+ codespell = "^2.2.0"
+
+ [tool.poetry.group.dev.dependencies]
+ langchain-core = {git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core"}
+
+ [tool.poetry.group.lint.dependencies]
+ ruff = "^0.5"
+ python-dotenv = "^1.0.1"
+ pytest = "^7.4.3"
+
+ [tool.poetry.group.test.dependencies]
+ pydantic = "^2.9.2"
+ pytest = "^7.4.3"
+ pytest-mock = "^3.10.0"
+ pytest-watcher = "^0.3.4"
+ pytest-asyncio = "^0.21.1"
+ python-dotenv = "^1.0.1"
+ syrupy = "^4.7.2"
+ langchain-core = {git = "https://github.com/langchain-ai/langchain.git", subdirectory = "libs/core"}
+
+ [tool.poetry.group.test_integration.dependencies]
+ pytest = "^7.3.0"
+ python-dotenv = "^1.0.1"
+
+ [tool.poetry.urls]
+ "Source Code" = "https://github.com/langchain-ai/langchain-azure/tree/main/libs/sqlserver"
+ "Release Notes" = "https://github.com/langchain-ai/langchain-azure/releases"
+
+ [tool.mypy]
+ disallow_untyped_defs = "True"
+
+ [tool.poetry.group.typing.dependencies]
+ mypy = "^1.10"
+
+ [tool.ruff.lint]
+ select = ["E", "F", "I", "D"]
+
+ [tool.coverage.run]
+ omit = ["tests/*"]
+
+ [tool.pytest.ini_options]
+ addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
+ markers = [
+     "requires: mark tests as requiring a specific library",
+     "compile: mark placeholder test used to compile integration tests without running them",
+ ]
+ asyncio_mode = "auto"
+
+ [tool.poetry.group.test]
+ optional = true
+
+ [tool.poetry.group.test_integration]
+ optional = true
+
+ [tool.poetry.group.codespell]
+ optional = true
+
+ [tool.poetry.group.lint]
+ optional = true
+
+ [tool.poetry.group.dev]
+ optional = true
+
+ [tool.ruff.lint.pydocstyle]
+ convention = "google"
+
+ [tool.ruff.lint.per-file-ignores]
+ "tests/**" = ["D"]
+
+ [build-system]
+ requires = ["poetry-core"]
+ build-backend = "poetry.core.masonry.api"