langchain 0.0.163 → 0.0.165

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (43) hide show
  1. package/chat_models/portkey.cjs +1 -0
  2. package/chat_models/portkey.d.ts +1 -0
  3. package/chat_models/portkey.js +1 -0
  4. package/dist/chat_models/bedrock.cjs +3 -0
  5. package/dist/chat_models/bedrock.js +3 -0
  6. package/dist/chat_models/portkey.cjs +159 -0
  7. package/dist/chat_models/portkey.d.ts +17 -0
  8. package/dist/chat_models/portkey.js +155 -0
  9. package/dist/document_loaders/web/notionapi.cjs +28 -5
  10. package/dist/document_loaders/web/notionapi.d.ts +2 -0
  11. package/dist/document_loaders/web/notionapi.js +25 -5
  12. package/dist/embeddings/minimax.cjs +1 -1
  13. package/dist/embeddings/minimax.js +1 -1
  14. package/dist/graphs/neo4j_graph.cjs +86 -10
  15. package/dist/graphs/neo4j_graph.d.ts +2 -1
  16. package/dist/graphs/neo4j_graph.js +86 -10
  17. package/dist/llms/bedrock.cjs +3 -0
  18. package/dist/llms/bedrock.js +3 -0
  19. package/dist/llms/portkey.cjs +147 -0
  20. package/dist/llms/portkey.d.ts +33 -0
  21. package/dist/llms/portkey.js +138 -0
  22. package/dist/llms/sagemaker_endpoint.cjs +76 -14
  23. package/dist/llms/sagemaker_endpoint.d.ts +39 -20
  24. package/dist/llms/sagemaker_endpoint.js +77 -15
  25. package/dist/load/import_constants.cjs +3 -0
  26. package/dist/load/import_constants.js +3 -0
  27. package/dist/output_parsers/list.cjs +1 -1
  28. package/dist/output_parsers/list.js +1 -1
  29. package/dist/util/stream.cjs +4 -4
  30. package/dist/util/stream.js +4 -4
  31. package/dist/vectorstores/cassandra.cjs +212 -0
  32. package/dist/vectorstores/cassandra.d.ts +98 -0
  33. package/dist/vectorstores/cassandra.js +208 -0
  34. package/dist/vectorstores/mongodb_atlas.cjs +29 -39
  35. package/dist/vectorstores/mongodb_atlas.js +29 -39
  36. package/dist/vectorstores/prisma.d.ts +1 -1
  37. package/llms/portkey.cjs +1 -0
  38. package/llms/portkey.d.ts +1 -0
  39. package/llms/portkey.js +1 -0
  40. package/package.json +42 -2
  41. package/vectorstores/cassandra.cjs +1 -0
  42. package/vectorstores/cassandra.d.ts +1 -0
  43. package/vectorstores/cassandra.js +1 -0
@@ -1,5 +1,7 @@
1
1
  import { SageMakerRuntimeClient, SageMakerRuntimeClientConfig } from "@aws-sdk/client-sagemaker-runtime";
2
- import { LLM, BaseLLMParams } from "./base.js";
2
+ import { CallbackManagerForLLMRun } from "../callbacks/manager.js";
3
+ import { GenerationChunk } from "../schema/index.js";
4
+ import { BaseLLMCallOptions, BaseLLMParams, LLM } from "./base.js";
3
5
  /**
4
6
  * A handler class to transform input from LLM to a format that SageMaker
5
7
  * endpoint expects. Similarily, the class also handles transforming output from
@@ -28,22 +30,22 @@ import { LLM, BaseLLMParams } from "./base.js";
28
30
  * ```
29
31
  */
30
32
  export declare abstract class BaseSageMakerContentHandler<InputType, OutputType> {
31
- /** The MIME type of the input data passed to endpoint */
32
33
  contentType: string;
33
- /** The MIME type of the response data returned from endpoint */
34
34
  accepts: string;
35
35
  /**
36
- * Transforms the input to a format that model can accept as the request Body.
37
- * Should return bytes or seekable file like object in the format specified in
38
- * the contentType request header.
36
+ * Transforms the prompt and model arguments into a specific format for sending to SageMaker.
37
+ * @param {InputType} prompt The prompt to be transformed.
38
+ * @param {Record<string, unknown>} modelKwargs Additional arguments.
39
+ * @returns {Promise<Uint8Array>} A promise that resolves to the formatted data for sending.
39
40
  */
40
41
  abstract transformInput(prompt: InputType, modelKwargs: Record<string, unknown>): Promise<Uint8Array>;
41
42
  /**
42
- * Transforms the output from the model to string that the LLM class expects.
43
+ * Transforms SageMaker output into a desired format.
44
+ * @param {Uint8Array} output The raw output from SageMaker.
45
+ * @returns {Promise<OutputType>} A promise that resolves to the transformed data.
43
46
  */
44
47
  abstract transformOutput(output: Uint8Array): Promise<OutputType>;
45
48
  }
46
- /** Content handler for LLM class. */
47
49
  export type SageMakerLLMContentHandler = BaseSageMakerContentHandler<string, string>;
48
50
  /**
49
51
  * The SageMakerEndpointInput interface defines the input parameters for
@@ -61,11 +63,6 @@ export interface SageMakerEndpointInput extends BaseLLMParams {
61
63
  * Options passed to the SageMaker client.
62
64
  */
63
65
  clientOptions: SageMakerRuntimeClientConfig;
64
- /**
65
- * The content handler class that provides an input and output transform
66
- * functions to handle formats between LLM and the endpoint.
67
- */
68
- contentHandler: SageMakerLLMContentHandler;
69
66
  /**
70
67
  * Key word arguments to pass to the model.
71
68
  */
@@ -74,29 +71,51 @@ export interface SageMakerEndpointInput extends BaseLLMParams {
74
71
  * Optional attributes passed to the InvokeEndpointCommand
75
72
  */
76
73
  endpointKwargs?: Record<string, unknown>;
74
+ /**
75
+ * The content handler class that provides an input and output transform
76
+ * functions to handle formats between LLM and the endpoint.
77
+ */
78
+ contentHandler: SageMakerLLMContentHandler;
79
+ streaming?: boolean;
77
80
  }
78
81
  /**
79
82
  * The SageMakerEndpoint class is used to interact with SageMaker
80
- * Inference Endpoint models. It extends the LLM class and overrides the
81
- * _call method to transform the input and output between the LLM and the
82
- * SageMaker endpoint using the provided content handler. The class uses
83
- * AWS client for authentication, which automatically loads credentials.
83
+ * Inference Endpoint models. It uses the AWS client for authentication,
84
+ * which automatically loads credentials.
84
85
  * If a specific credential profile is to be used, the name of the profile
85
86
  * from the ~/.aws/credentials file must be passed. The credentials or
86
87
  * roles used should have the required policies to access the SageMaker
87
88
  * endpoint.
88
89
  */
89
- export declare class SageMakerEndpoint extends LLM {
90
+ export declare class SageMakerEndpoint extends LLM<BaseLLMCallOptions> {
91
+ static lc_name(): string;
90
92
  get lc_secrets(): {
91
93
  [key: string]: string;
92
94
  } | undefined;
93
95
  endpointName: string;
94
- contentHandler: SageMakerLLMContentHandler;
95
96
  modelKwargs?: Record<string, unknown>;
96
97
  endpointKwargs?: Record<string, unknown>;
97
98
  client: SageMakerRuntimeClient;
99
+ contentHandler: SageMakerLLMContentHandler;
100
+ streaming: boolean;
98
101
  constructor(fields: SageMakerEndpointInput);
99
102
  _llmType(): string;
103
+ /**
104
+ * Calls the SageMaker endpoint and retrieves the result.
105
+ * @param {string} prompt The input prompt.
106
+ * @param {this["ParsedCallOptions"]} options Parsed call options.
107
+ * @param {CallbackManagerForLLMRun} _runManager Optional run manager.
108
+ * @returns {Promise<string>} A promise that resolves to the generated string.
109
+ */
100
110
  /** @ignore */
101
- _call(prompt: string, options: this["ParsedCallOptions"]): Promise<string>;
111
+ _call(prompt: string, options: this["ParsedCallOptions"], _runManager?: CallbackManagerForLLMRun): Promise<string>;
112
+ private streamingCall;
113
+ private noStreamingCall;
114
+ /**
115
+ * Streams response chunks from the SageMaker endpoint.
116
+ * @param {string} prompt The input prompt.
117
+ * @param {this["ParsedCallOptions"]} options Parsed call options.
118
+ * @returns {AsyncGenerator<GenerationChunk>} An asynchronous generator yielding generation chunks.
119
+ */
120
+ _streamResponseChunks(prompt: string, options: this["ParsedCallOptions"]): AsyncGenerator<GenerationChunk>;
102
121
  }
@@ -1,4 +1,5 @@
1
- import { SageMakerRuntimeClient, InvokeEndpointCommand, } from "@aws-sdk/client-sagemaker-runtime";
1
+ import { InvokeEndpointCommand, InvokeEndpointWithResponseStreamCommand, SageMakerRuntimeClient, } from "@aws-sdk/client-sagemaker-runtime";
2
+ import { GenerationChunk } from "../schema/index.js";
2
3
  import { LLM } from "./base.js";
3
4
  /**
4
5
  * A handler class to transform input from LLM to a format that SageMaker
@@ -29,14 +30,12 @@ import { LLM } from "./base.js";
29
30
  */
30
31
  export class BaseSageMakerContentHandler {
31
32
  constructor() {
32
- /** The MIME type of the input data passed to endpoint */
33
33
  Object.defineProperty(this, "contentType", {
34
34
  enumerable: true,
35
35
  configurable: true,
36
36
  writable: true,
37
37
  value: "text/plain"
38
38
  });
39
- /** The MIME type of the response data returned from endpoint */
40
39
  Object.defineProperty(this, "accepts", {
41
40
  enumerable: true,
42
41
  configurable: true,
@@ -47,16 +46,17 @@ export class BaseSageMakerContentHandler {
47
46
  }
48
47
  /**
49
48
  * The SageMakerEndpoint class is used to interact with SageMaker
50
- * Inference Endpoint models. It extends the LLM class and overrides the
51
- * _call method to transform the input and output between the LLM and the
52
- * SageMaker endpoint using the provided content handler. The class uses
53
- * AWS client for authentication, which automatically loads credentials.
49
+ * Inference Endpoint models. It uses the AWS client for authentication,
50
+ * which automatically loads credentials.
54
51
  * If a specific credential profile is to be used, the name of the profile
55
52
  * from the ~/.aws/credentials file must be passed. The credentials or
56
53
  * roles used should have the required policies to access the SageMaker
57
54
  * endpoint.
58
55
  */
59
56
  export class SageMakerEndpoint extends LLM {
57
+ static lc_name() {
58
+ return "SageMakerEndpoint";
59
+ }
60
60
  get lc_secrets() {
61
61
  return {
62
62
  "clientOptions.credentials.accessKeyId": "AWS_ACCESS_KEY_ID",
@@ -65,39 +65,44 @@ export class SageMakerEndpoint extends LLM {
65
65
  };
66
66
  }
67
67
  constructor(fields) {
68
- super(fields ?? {});
68
+ super(fields);
69
69
  Object.defineProperty(this, "endpointName", {
70
70
  enumerable: true,
71
71
  configurable: true,
72
72
  writable: true,
73
73
  value: void 0
74
74
  });
75
- Object.defineProperty(this, "contentHandler", {
75
+ Object.defineProperty(this, "modelKwargs", {
76
76
  enumerable: true,
77
77
  configurable: true,
78
78
  writable: true,
79
79
  value: void 0
80
80
  });
81
- Object.defineProperty(this, "modelKwargs", {
81
+ Object.defineProperty(this, "endpointKwargs", {
82
82
  enumerable: true,
83
83
  configurable: true,
84
84
  writable: true,
85
85
  value: void 0
86
86
  });
87
- Object.defineProperty(this, "endpointKwargs", {
87
+ Object.defineProperty(this, "client", {
88
88
  enumerable: true,
89
89
  configurable: true,
90
90
  writable: true,
91
91
  value: void 0
92
92
  });
93
- Object.defineProperty(this, "client", {
93
+ Object.defineProperty(this, "contentHandler", {
94
94
  enumerable: true,
95
95
  configurable: true,
96
96
  writable: true,
97
97
  value: void 0
98
98
  });
99
- const regionName = fields.clientOptions.region;
100
- if (!regionName) {
99
+ Object.defineProperty(this, "streaming", {
100
+ enumerable: true,
101
+ configurable: true,
102
+ writable: true,
103
+ value: void 0
104
+ });
105
+ if (!fields.clientOptions.region) {
101
106
  throw new Error(`Please pass a "clientOptions" object with a "region" field to the constructor`);
102
107
  }
103
108
  const endpointName = fields?.endpointName;
@@ -112,13 +117,33 @@ export class SageMakerEndpoint extends LLM {
112
117
  this.contentHandler = fields.contentHandler;
113
118
  this.endpointKwargs = fields.endpointKwargs;
114
119
  this.modelKwargs = fields.modelKwargs;
120
+ this.streaming = fields.streaming ?? false;
115
121
  this.client = new SageMakerRuntimeClient(fields.clientOptions);
116
122
  }
117
123
  _llmType() {
118
124
  return "sagemaker_endpoint";
119
125
  }
126
+ /**
127
+ * Calls the SageMaker endpoint and retrieves the result.
128
+ * @param {string} prompt The input prompt.
129
+ * @param {this["ParsedCallOptions"]} options Parsed call options.
130
+ * @param {CallbackManagerForLLMRun} _runManager Optional run manager.
131
+ * @returns {Promise<string>} A promise that resolves to the generated string.
132
+ */
120
133
  /** @ignore */
121
- async _call(prompt, options) {
134
+ async _call(prompt, options, _runManager) {
135
+ return this.streaming
136
+ ? await this.streamingCall(prompt, options)
137
+ : await this.noStreamingCall(prompt, options);
138
+ }
139
+ async streamingCall(prompt, options) {
140
+ const chunks = [];
141
+ for await (const chunk of this._streamResponseChunks(prompt, options)) {
142
+ chunks.push(chunk.text);
143
+ }
144
+ return chunks.join("");
145
+ }
146
+ async noStreamingCall(prompt, options) {
122
147
  const body = await this.contentHandler.transformInput(prompt, this.modelKwargs ?? {});
123
148
  const { contentType, accepts } = this.contentHandler;
124
149
  const response = await this.caller.call(() => this.client.send(new InvokeEndpointCommand({
@@ -133,4 +158,41 @@ export class SageMakerEndpoint extends LLM {
133
158
  }
134
159
  return this.contentHandler.transformOutput(response.Body);
135
160
  }
161
+ /**
162
+ * Streams response chunks from the SageMaker endpoint.
163
+ * @param {string} prompt The input prompt.
164
+ * @param {this["ParsedCallOptions"]} options Parsed call options.
165
+ * @returns {AsyncGenerator<GenerationChunk>} An asynchronous generator yielding generation chunks.
166
+ */
167
+ async *_streamResponseChunks(prompt, options) {
168
+ const body = await this.contentHandler.transformInput(prompt, this.modelKwargs ?? {});
169
+ const { contentType, accepts } = this.contentHandler;
170
+ const stream = await this.caller.call(() => this.client.send(new InvokeEndpointWithResponseStreamCommand({
171
+ EndpointName: this.endpointName,
172
+ Body: body,
173
+ ContentType: contentType,
174
+ Accept: accepts,
175
+ ...this.endpointKwargs,
176
+ }), { abortSignal: options.signal }));
177
+ if (!stream.Body) {
178
+ throw new Error("Inference result missing Body");
179
+ }
180
+ for await (const chunk of stream.Body) {
181
+ if (chunk.PayloadPart && chunk.PayloadPart.Bytes) {
182
+ yield new GenerationChunk({
183
+ text: await this.contentHandler.transformOutput(chunk.PayloadPart.Bytes),
184
+ generationInfo: {
185
+ ...chunk,
186
+ response: undefined,
187
+ },
188
+ });
189
+ }
190
+ else if (chunk.InternalStreamFailure) {
191
+ throw new Error(chunk.InternalStreamFailure.message);
192
+ }
193
+ else if (chunk.ModelStreamError) {
194
+ throw new Error(chunk.ModelStreamError.message);
195
+ }
196
+ }
197
+ }
136
198
  }
@@ -37,8 +37,10 @@ exports.optionalImportEntrypoints = [
37
37
  "langchain/llms/bedrock",
38
38
  "langchain/llms/llama_cpp",
39
39
  "langchain/llms/writer",
40
+ "langchain/llms/portkey",
40
41
  "langchain/prompts/load",
41
42
  "langchain/vectorstores/analyticdb",
43
+ "langchain/vectorstores/cassandra",
42
44
  "langchain/vectorstores/elasticsearch",
43
45
  "langchain/vectorstores/cloudflare_vectorize",
44
46
  "langchain/vectorstores/chroma",
@@ -101,6 +103,7 @@ exports.optionalImportEntrypoints = [
101
103
  "langchain/document_loaders/fs/openai_whisper_audio",
102
104
  "langchain/document_transformers/html_to_text",
103
105
  "langchain/document_transformers/mozilla_readability",
106
+ "langchain/chat_models/portkey",
104
107
  "langchain/chat_models/bedrock",
105
108
  "langchain/chat_models/googlevertexai",
106
109
  "langchain/chat_models/googlevertexai/web",
@@ -34,8 +34,10 @@ export const optionalImportEntrypoints = [
34
34
  "langchain/llms/bedrock",
35
35
  "langchain/llms/llama_cpp",
36
36
  "langchain/llms/writer",
37
+ "langchain/llms/portkey",
37
38
  "langchain/prompts/load",
38
39
  "langchain/vectorstores/analyticdb",
40
+ "langchain/vectorstores/cassandra",
39
41
  "langchain/vectorstores/elasticsearch",
40
42
  "langchain/vectorstores/cloudflare_vectorize",
41
43
  "langchain/vectorstores/chroma",
@@ -98,6 +100,7 @@ export const optionalImportEntrypoints = [
98
100
  "langchain/document_loaders/fs/openai_whisper_audio",
99
101
  "langchain/document_transformers/html_to_text",
100
102
  "langchain/document_transformers/mozilla_readability",
103
+ "langchain/chat_models/portkey",
101
104
  "langchain/chat_models/bedrock",
102
105
  "langchain/chat_models/googlevertexai",
103
106
  "langchain/chat_models/googlevertexai/web",
@@ -119,7 +119,7 @@ class CustomListOutputParser extends ListOutputParser {
119
119
  * @returns A string containing instructions on the expected format of the response.
120
120
  */
121
121
  getFormatInstructions() {
122
- return `Your response should be a list of ${this.length} items separated by "${this.separator}" (eg: \`foo${this.separator} bar${this.separator} baz\`)`;
122
+ return `Your response should be a list of ${this.length === undefined ? "" : `${this.length} `}items separated by "${this.separator}" (eg: \`foo${this.separator} bar${this.separator} baz\`)`;
123
123
  }
124
124
  }
125
125
  exports.CustomListOutputParser = CustomListOutputParser;
@@ -114,6 +114,6 @@ export class CustomListOutputParser extends ListOutputParser {
114
114
  * @returns A string containing instructions on the expected format of the response.
115
115
  */
116
116
  getFormatInstructions() {
117
- return `Your response should be a list of ${this.length} items separated by "${this.separator}" (eg: \`foo${this.separator} bar${this.separator} baz\`)`;
117
+ return `Your response should be a list of ${this.length === undefined ? "" : `${this.length} `}items separated by "${this.separator}" (eg: \`foo${this.separator} bar${this.separator} baz\`)`;
118
118
  }
119
119
  }
@@ -38,7 +38,7 @@ class IterableReadableStream extends ReadableStream {
38
38
  const cancelPromise = this.reader.cancel(); // cancel first, but don't await yet
39
39
  this.reader.releaseLock(); // release lock first
40
40
  await cancelPromise; // now await it
41
- return { done: true, value: undefined }; // This cast fixes TS typing, and convention is to ignore chunk value anyway
41
+ return { done: true, value: undefined }; // This cast fixes TS typing, and convention is to ignore final chunk value anyway
42
42
  }
43
43
  [Symbol.asyncIterator]() {
44
44
  return this;
@@ -68,12 +68,12 @@ class IterableReadableStream extends ReadableStream {
68
68
  return new IterableReadableStream({
69
69
  async pull(controller) {
70
70
  const { value, done } = await generator.next();
71
+ // When no more data needs to be consumed, close the stream
71
72
  if (done) {
72
73
  controller.close();
73
74
  }
74
- else if (value) {
75
- controller.enqueue(value);
76
- }
75
+ // Fix: `else if (value)` will hang the streaming when nullish value (e.g. empty string) is pulled
76
+ controller.enqueue(value);
77
77
  },
78
78
  });
79
79
  }
@@ -35,7 +35,7 @@ export class IterableReadableStream extends ReadableStream {
35
35
  const cancelPromise = this.reader.cancel(); // cancel first, but don't await yet
36
36
  this.reader.releaseLock(); // release lock first
37
37
  await cancelPromise; // now await it
38
- return { done: true, value: undefined }; // This cast fixes TS typing, and convention is to ignore chunk value anyway
38
+ return { done: true, value: undefined }; // This cast fixes TS typing, and convention is to ignore final chunk value anyway
39
39
  }
40
40
  [Symbol.asyncIterator]() {
41
41
  return this;
@@ -65,12 +65,12 @@ export class IterableReadableStream extends ReadableStream {
65
65
  return new IterableReadableStream({
66
66
  async pull(controller) {
67
67
  const { value, done } = await generator.next();
68
+ // When no more data needs to be consumed, close the stream
68
69
  if (done) {
69
70
  controller.close();
70
71
  }
71
- else if (value) {
72
- controller.enqueue(value);
73
- }
72
+ // Fix: `else if (value)` will hang the streaming when nullish value (e.g. empty string) is pulled
73
+ controller.enqueue(value);
74
74
  },
75
75
  });
76
76
  }
@@ -0,0 +1,212 @@
1
+ "use strict";
2
+ Object.defineProperty(exports, "__esModule", { value: true });
3
+ exports.CassandraStore = void 0;
4
+ /* eslint-disable prefer-template */
5
+ const cassandra_driver_1 = require("cassandra-driver");
6
+ const base_js_1 = require("./base.cjs");
7
+ const document_js_1 = require("../document.cjs");
8
+ /**
9
+ * Class for interacting with the Cassandra database. It extends the
10
+ * VectorStore class and provides methods for adding vectors and
11
+ * documents, searching for similar vectors, and creating instances from
12
+ * texts or documents.
13
+ */
14
+ class CassandraStore extends base_js_1.VectorStore {
15
+ _vectorstoreType() {
16
+ return "cassandra";
17
+ }
18
+ constructor(embeddings, args) {
19
+ super(embeddings, args);
20
+ Object.defineProperty(this, "client", {
21
+ enumerable: true,
22
+ configurable: true,
23
+ writable: true,
24
+ value: void 0
25
+ });
26
+ Object.defineProperty(this, "dimensions", {
27
+ enumerable: true,
28
+ configurable: true,
29
+ writable: true,
30
+ value: void 0
31
+ });
32
+ Object.defineProperty(this, "keyspace", {
33
+ enumerable: true,
34
+ configurable: true,
35
+ writable: true,
36
+ value: void 0
37
+ });
38
+ Object.defineProperty(this, "primaryKey", {
39
+ enumerable: true,
40
+ configurable: true,
41
+ writable: true,
42
+ value: void 0
43
+ });
44
+ Object.defineProperty(this, "metadataColumns", {
45
+ enumerable: true,
46
+ configurable: true,
47
+ writable: true,
48
+ value: void 0
49
+ });
50
+ Object.defineProperty(this, "table", {
51
+ enumerable: true,
52
+ configurable: true,
53
+ writable: true,
54
+ value: void 0
55
+ });
56
+ Object.defineProperty(this, "isInitialized", {
57
+ enumerable: true,
58
+ configurable: true,
59
+ writable: true,
60
+ value: false
61
+ });
62
+ this.client = new cassandra_driver_1.Client(args);
63
+ this.dimensions = args.dimensions;
64
+ this.keyspace = args.keyspace;
65
+ this.table = args.table;
66
+ this.primaryKey = args.primaryKey;
67
+ this.metadataColumns = args.metadataColumns;
68
+ }
69
+ /**
70
+ * Method to save vectors to the Cassandra database.
71
+ * @param vectors Vectors to save.
72
+ * @param documents The documents associated with the vectors.
73
+ * @returns Promise that resolves when the vectors have been added.
74
+ */
75
+ async addVectors(vectors, documents) {
76
+ if (vectors.length === 0) {
77
+ return;
78
+ }
79
+ if (!this.isInitialized) {
80
+ await this.initialize();
81
+ }
82
+ const queries = this.buildInsertQuery(vectors, documents);
83
+ await this.client.batch(queries);
84
+ }
85
+ /**
86
+ * Method to add documents to the Cassandra database.
87
+ * @param documents The documents to add.
88
+ * @returns Promise that resolves when the documents have been added.
89
+ */
90
+ async addDocuments(documents) {
91
+ return this.addVectors(await this.embeddings.embedDocuments(documents.map((d) => d.pageContent)), documents);
92
+ }
93
+ /**
94
+ * Method to search for vectors that are similar to a given query vector.
95
+ * @param query The query vector.
96
+ * @param k The number of similar vectors to return.
97
+ * @returns Promise that resolves with an array of tuples, each containing a Document and a score.
98
+ */
99
+ async similaritySearchVectorWithScore(query, k) {
100
+ if (!this.isInitialized) {
101
+ await this.initialize();
102
+ }
103
+ const queryStr = this.buildSearchQuery(query, k);
104
+ const queryResultSet = await this.client.execute(queryStr);
105
+ return queryResultSet?.rows.map((row, index) => {
106
+ const textContent = row.text;
107
+ const sanitizedRow = Object.assign(row, {});
108
+ delete sanitizedRow.vector;
109
+ delete sanitizedRow.text;
110
+ return [
111
+ new document_js_1.Document({ pageContent: textContent, metadata: sanitizedRow }),
112
+ index,
113
+ ];
114
+ });
115
+ }
116
+ /**
117
+ * Static method to create an instance of CassandraStore from texts.
118
+ * @param texts The texts to use.
119
+ * @param metadatas The metadata associated with the texts.
120
+ * @param embeddings The embeddings to use.
121
+ * @param args The arguments for the CassandraStore.
122
+ * @returns Promise that resolves with a new instance of CassandraStore.
123
+ */
124
+ static async fromTexts(texts, metadatas, embeddings, args) {
125
+ const docs = [];
126
+ for (let index = 0; index < texts.length; index += 1) {
127
+ const metadata = Array.isArray(metadatas) ? metadatas[index] : metadatas;
128
+ const doc = new document_js_1.Document({
129
+ pageContent: texts[index],
130
+ metadata,
131
+ });
132
+ docs.push(doc);
133
+ }
134
+ return CassandraStore.fromDocuments(docs, embeddings, args);
135
+ }
136
+ /**
137
+ * Static method to create an instance of CassandraStore from documents.
138
+ * @param docs The documents to use.
139
+ * @param embeddings The embeddings to use.
140
+ * @param args The arguments for the CassandraStore.
141
+ * @returns Promise that resolves with a new instance of CassandraStore.
142
+ */
143
+ static async fromDocuments(docs, embeddings, args) {
144
+ const instance = new this(embeddings, args);
145
+ await instance.addDocuments(docs);
146
+ return instance;
147
+ }
148
+ /**
149
+ * Static method to create an instance of CassandraStore from an existing
150
+ * index.
151
+ * @param embeddings The embeddings to use.
152
+ * @param args The arguments for the CassandraStore.
153
+ * @returns Promise that resolves with a new instance of CassandraStore.
154
+ */
155
+ static async fromExistingIndex(embeddings, args) {
156
+ const instance = new this(embeddings, args);
157
+ await instance.initialize();
158
+ return instance;
159
+ }
160
+ /**
161
+ * Method to initialize the Cassandra database.
162
+ * @returns Promise that resolves when the database has been initialized.
163
+ */
164
+ async initialize() {
165
+ await this.client.execute(`CREATE TABLE IF NOT EXISTS ${this.keyspace}.${this.table} (
166
+ ${this.primaryKey.name} ${this.primaryKey.type} PRIMARY KEY,
167
+ text TEXT,
168
+ ${this.metadataColumns.length > 0
169
+ ? this.metadataColumns.map((col) => `${col.name} ${col.type},`)
170
+ : ""}
171
+ vector VECTOR<FLOAT, ${this.dimensions}>
172
+ );`);
173
+ await this.client.execute(`CREATE CUSTOM INDEX IF NOT EXISTS ann_index
174
+ ON ${this.keyspace}.${this.table}(vector) USING 'StorageAttachedIndex';`);
175
+ this.isInitialized = true;
176
+ }
177
+ /**
178
+ * Method to build an CQL query for inserting vectors and documents into
179
+ * the Cassandra database.
180
+ * @param vectors The vectors to insert.
181
+ * @param documents The documents to insert.
182
+ * @returns The CQL query string.
183
+ */
184
+ buildInsertQuery(vectors, documents) {
185
+ const queries = [];
186
+ for (let index = 0; index < vectors.length; index += 1) {
187
+ const vector = vectors[index];
188
+ const document = documents[index];
189
+ const metadataColNames = Object.keys(document.metadata);
190
+ const metadataVals = Object.values(document.metadata);
191
+ const query = `INSERT INTO ${this.keyspace}.${this.table} (vector, text${metadataColNames.length > 0 ? ", " + metadataColNames.join(", ") : ""}) VALUES ([${vector}], '${document.pageContent}'${metadataVals.length > 0
192
+ ? ", " +
193
+ metadataVals
194
+ .map((val) => (typeof val === "number" ? val : `'${val}'`))
195
+ .join(", ")
196
+ : ""});`;
197
+ queries.push(query);
198
+ }
199
+ return queries;
200
+ }
201
+ /**
202
+ * Method to build an CQL query for searching for similar vectors in the
203
+ * Cassandra database.
204
+ * @param query The query vector.
205
+ * @param k The number of similar vectors to return.
206
+ * @returns The CQL query string.
207
+ */
208
+ buildSearchQuery(query, k) {
209
+ return `SELECT * FROM ${this.keyspace}.${this.table} ORDER BY vector ANN OF [${query}] LIMIT ${k || 1};`;
210
+ }
211
+ }
212
+ exports.CassandraStore = CassandraStore;