@lancedb/lancedb 0.5.0 → 0.5.1

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (48)
  1. package/biome.json +8 -2
  2. package/dist/arrow.d.ts +34 -9
  3. package/dist/arrow.js +220 -23
  4. package/dist/connection.d.ts +4 -1
  5. package/dist/connection.js +11 -5
  6. package/dist/embedding/embedding_function.d.ts +54 -28
  7. package/dist/embedding/embedding_function.js +71 -10
  8. package/dist/embedding/index.d.ts +28 -2
  9. package/dist/embedding/index.js +111 -4
  10. package/dist/embedding/openai.d.ts +16 -7
  11. package/dist/embedding/openai.js +62 -12
  12. package/dist/embedding/registry.d.ts +54 -0
  13. package/dist/embedding/registry.js +123 -0
  14. package/dist/query.d.ts +1 -1
  15. package/dist/query.js +3 -3
  16. package/dist/sanitize.d.ts +22 -1
  17. package/dist/sanitize.js +123 -110
  18. package/dist/table.d.ts +1 -2
  19. package/dist/table.js +6 -3
  20. package/lancedb/arrow.ts +234 -38
  21. package/lancedb/connection.ts +27 -6
  22. package/lancedb/embedding/embedding_function.ts +126 -42
  23. package/lancedb/embedding/index.ts +113 -2
  24. package/lancedb/embedding/openai.ts +62 -16
  25. package/lancedb/embedding/registry.ts +172 -0
  26. package/lancedb/query.ts +2 -1
  27. package/lancedb/sanitize.ts +22 -22
  28. package/lancedb/table.ts +10 -3
  29. package/nodejs-artifacts/arrow.d.ts +34 -9
  30. package/nodejs-artifacts/arrow.js +220 -23
  31. package/nodejs-artifacts/connection.d.ts +4 -1
  32. package/nodejs-artifacts/connection.js +11 -5
  33. package/nodejs-artifacts/embedding/embedding_function.d.ts +54 -28
  34. package/nodejs-artifacts/embedding/embedding_function.js +71 -10
  35. package/nodejs-artifacts/embedding/index.d.ts +28 -2
  36. package/nodejs-artifacts/embedding/index.js +111 -4
  37. package/nodejs-artifacts/embedding/openai.d.ts +16 -7
  38. package/nodejs-artifacts/embedding/openai.js +62 -12
  39. package/nodejs-artifacts/embedding/registry.d.ts +54 -0
  40. package/nodejs-artifacts/embedding/registry.js +123 -0
  41. package/nodejs-artifacts/query.d.ts +1 -1
  42. package/nodejs-artifacts/query.js +3 -3
  43. package/nodejs-artifacts/sanitize.d.ts +22 -1
  44. package/nodejs-artifacts/sanitize.js +123 -110
  45. package/nodejs-artifacts/table.d.ts +1 -2
  46. package/nodejs-artifacts/table.js +6 -3
  47. package/package.json +14 -9
  48. package/tsconfig.json +3 -1
package/lancedb/arrow.ts CHANGED
@@ -17,10 +17,14 @@ import {
17
17
  Binary,
18
18
  DataType,
19
19
  Field,
20
+ FixedSizeBinary,
20
21
  FixedSizeList,
21
- type Float,
22
+ Float,
22
23
  Float32,
24
+ Int,
25
+ LargeBinary,
23
26
  List,
27
+ Null,
24
28
  RecordBatch,
25
29
  RecordBatchFileWriter,
26
30
  RecordBatchStreamWriter,
@@ -34,7 +38,99 @@ import {
34
38
  vectorFromArray,
35
39
  } from "apache-arrow";
36
40
  import { type EmbeddingFunction } from "./embedding/embedding_function";
37
- import { sanitizeSchema } from "./sanitize";
41
+ import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry";
42
+ import { sanitizeField, sanitizeSchema, sanitizeType } from "./sanitize";
43
+ export * from "apache-arrow";
44
+
45
+ export function isArrowTable(value: object): value is ArrowTable {
46
+ if (value instanceof ArrowTable) return true;
47
+ return "schema" in value && "batches" in value;
48
+ }
49
+
50
+ export function isDataType(value: unknown): value is DataType {
51
+ return (
52
+ value instanceof DataType ||
53
+ DataType.isNull(value) ||
54
+ DataType.isInt(value) ||
55
+ DataType.isFloat(value) ||
56
+ DataType.isBinary(value) ||
57
+ DataType.isLargeBinary(value) ||
58
+ DataType.isUtf8(value) ||
59
+ DataType.isLargeUtf8(value) ||
60
+ DataType.isBool(value) ||
61
+ DataType.isDecimal(value) ||
62
+ DataType.isDate(value) ||
63
+ DataType.isTime(value) ||
64
+ DataType.isTimestamp(value) ||
65
+ DataType.isInterval(value) ||
66
+ DataType.isDuration(value) ||
67
+ DataType.isList(value) ||
68
+ DataType.isStruct(value) ||
69
+ DataType.isUnion(value) ||
70
+ DataType.isFixedSizeBinary(value) ||
71
+ DataType.isFixedSizeList(value) ||
72
+ DataType.isMap(value) ||
73
+ DataType.isDictionary(value)
74
+ );
75
+ }
76
+ export function isNull(value: unknown): value is Null {
77
+ return value instanceof Null || DataType.isNull(value);
78
+ }
79
+ export function isInt(value: unknown): value is Int {
80
+ return value instanceof Int || DataType.isInt(value);
81
+ }
82
+ export function isFloat(value: unknown): value is Float {
83
+ return value instanceof Float || DataType.isFloat(value);
84
+ }
85
+ export function isBinary(value: unknown): value is Binary {
86
+ return value instanceof Binary || DataType.isBinary(value);
87
+ }
88
+ export function isLargeBinary(value: unknown): value is LargeBinary {
89
+ return value instanceof LargeBinary || DataType.isLargeBinary(value);
90
+ }
91
+ export function isUtf8(value: unknown): value is Utf8 {
92
+ return value instanceof Utf8 || DataType.isUtf8(value);
93
+ }
94
+ export function isLargeUtf8(value: unknown): value is Utf8 {
95
+ return value instanceof Utf8 || DataType.isLargeUtf8(value);
96
+ }
97
+ export function isBool(value: unknown): value is Utf8 {
98
+ return value instanceof Utf8 || DataType.isBool(value);
99
+ }
100
+ export function isDecimal(value: unknown): value is Utf8 {
101
+ return value instanceof Utf8 || DataType.isDecimal(value);
102
+ }
103
+ export function isDate(value: unknown): value is Utf8 {
104
+ return value instanceof Utf8 || DataType.isDate(value);
105
+ }
106
+ export function isTime(value: unknown): value is Utf8 {
107
+ return value instanceof Utf8 || DataType.isTime(value);
108
+ }
109
+ export function isTimestamp(value: unknown): value is Utf8 {
110
+ return value instanceof Utf8 || DataType.isTimestamp(value);
111
+ }
112
+ export function isInterval(value: unknown): value is Utf8 {
113
+ return value instanceof Utf8 || DataType.isInterval(value);
114
+ }
115
+ export function isDuration(value: unknown): value is Utf8 {
116
+ return value instanceof Utf8 || DataType.isDuration(value);
117
+ }
118
+ export function isList(value: unknown): value is List {
119
+ return value instanceof List || DataType.isList(value);
120
+ }
121
+ export function isStruct(value: unknown): value is Struct {
122
+ return value instanceof Struct || DataType.isStruct(value);
123
+ }
124
+ export function isUnion(value: unknown): value is Struct {
125
+ return value instanceof Struct || DataType.isUnion(value);
126
+ }
127
+ export function isFixedSizeBinary(value: unknown): value is FixedSizeBinary {
128
+ return value instanceof FixedSizeBinary || DataType.isFixedSizeBinary(value);
129
+ }
130
+
131
+ export function isFixedSizeList(value: unknown): value is FixedSizeList {
132
+ return value instanceof FixedSizeList || DataType.isFixedSizeList(value);
133
+ }
38
134
 
39
135
  /** Data type accepted by NodeJS SDK */
40
136
  export type Data = Record<string, unknown>[] | ArrowTable;
@@ -198,6 +294,7 @@ export class MakeArrowTableOptions {
198
294
  export function makeArrowTable(
199
295
  data: Array<Record<string, unknown>>,
200
296
  options?: Partial<MakeArrowTableOptions>,
297
+ metadata?: Map<string, string>,
201
298
  ): ArrowTable {
202
299
  if (
203
300
  data.length === 0 &&
@@ -290,20 +387,41 @@ export function makeArrowTable(
290
387
  // `new ArrowTable(schema, batches)` which does not do any schema inference
291
388
  const firstTable = new ArrowTable(columns);
292
389
  const batchesFixed = firstTable.batches.map(
293
- // eslint-disable-next-line @typescript-eslint/no-non-null-assertion
294
390
  (batch) => new RecordBatch(opt.schema!, batch.data),
295
391
  );
296
- return new ArrowTable(opt.schema, batchesFixed);
297
- } else {
298
- return new ArrowTable(columns);
392
+ let schema: Schema;
393
+ if (metadata !== undefined) {
394
+ let schemaMetadata = opt.schema.metadata;
395
+ if (schemaMetadata.size === 0) {
396
+ schemaMetadata = metadata;
397
+ } else {
398
+ for (const [key, entry] of schemaMetadata.entries()) {
399
+ schemaMetadata.set(key, entry);
400
+ }
401
+ }
402
+
403
+ schema = new Schema(opt.schema.fields, schemaMetadata);
404
+ } else {
405
+ schema = opt.schema;
406
+ }
407
+ return new ArrowTable(schema, batchesFixed);
299
408
  }
409
+ const tbl = new ArrowTable(columns);
410
+ if (metadata !== undefined) {
411
+ // biome-ignore lint/suspicious/noExplicitAny: <explanation>
412
+ (<any>tbl.schema).metadata = metadata;
413
+ }
414
+ return tbl;
300
415
  }
301
416
 
302
417
  /**
303
418
  * Create an empty Arrow table with the provided schema
304
419
  */
305
- export function makeEmptyTable(schema: Schema): ArrowTable {
306
- return makeArrowTable([], { schema });
420
+ export function makeEmptyTable(
421
+ schema: Schema,
422
+ metadata?: Map<string, string>,
423
+ ): ArrowTable {
424
+ return makeArrowTable([], { schema }, metadata);
307
425
  }
308
426
 
309
427
  /**
@@ -375,13 +493,75 @@ function makeVector(
375
493
  }
376
494
  }
377
495
 
496
+ /** Helper function to apply embeddings from metadata to an input table */
497
+ async function applyEmbeddingsFromMetadata(
498
+ table: ArrowTable,
499
+ schema: Schema,
500
+ ): Promise<ArrowTable> {
501
+ const registry = getRegistry();
502
+ const functions = registry.parseFunctions(schema.metadata);
503
+
504
+ const columns = Object.fromEntries(
505
+ table.schema.fields.map((field) => [
506
+ field.name,
507
+ table.getChild(field.name)!,
508
+ ]),
509
+ );
510
+
511
+ for (const functionEntry of functions.values()) {
512
+ const sourceColumn = columns[functionEntry.sourceColumn];
513
+ const destColumn = functionEntry.vectorColumn ?? "vector";
514
+ if (sourceColumn === undefined) {
515
+ throw new Error(
516
+ `Cannot apply embedding function because the source column '${functionEntry.sourceColumn}' was not present in the data`,
517
+ );
518
+ }
519
+ if (columns[destColumn] !== undefined) {
520
+ throw new Error(
521
+ `Attempt to apply embeddings to table failed because column ${destColumn} already existed`,
522
+ );
523
+ }
524
+ if (table.batches.length > 1) {
525
+ throw new Error(
526
+ "Internal error: `makeArrowTable` unexpectedly created a table with more than one batch",
527
+ );
528
+ }
529
+ const values = sourceColumn.toArray();
530
+
531
+ const vectors =
532
+ await functionEntry.function.computeSourceEmbeddings(values);
533
+ if (vectors.length !== values.length) {
534
+ throw new Error(
535
+ "Embedding function did not return an embedding for each input element",
536
+ );
537
+ }
538
+ let destType: DataType;
539
+ const dtype = schema.fields.find((f) => f.name === destColumn)!.type;
540
+ if (isFixedSizeList(dtype)) {
541
+ destType = sanitizeType(dtype);
542
+ } else {
543
+ throw new Error(
544
+ "Expected FixedSizeList as datatype for vector field, instead got: " +
545
+ dtype,
546
+ );
547
+ }
548
+
549
+ const vector = makeVector(vectors, destType);
550
+ columns[destColumn] = vector;
551
+ }
552
+ const newTable = new ArrowTable(columns);
553
+ return alignTable(newTable, schema);
554
+ }
555
+
378
556
  /** Helper function to apply embeddings to an input table */
379
557
  async function applyEmbeddings<T>(
380
558
  table: ArrowTable,
381
- embeddings?: EmbeddingFunction<T>,
559
+ embeddings?: EmbeddingFunctionConfig,
382
560
  schema?: Schema,
383
561
  ): Promise<ArrowTable> {
384
- if (embeddings == null) {
562
+ if (schema?.metadata.has("embedding_functions")) {
563
+ return applyEmbeddingsFromMetadata(table, schema!);
564
+ } else if (embeddings == null || embeddings === undefined) {
385
565
  return table;
386
566
  }
387
567
 
@@ -399,8 +579,9 @@ async function applyEmbeddings<T>(
399
579
  const newColumns = Object.fromEntries(colEntries);
400
580
 
401
581
  const sourceColumn = newColumns[embeddings.sourceColumn];
402
- const destColumn = embeddings.destColumn ?? "vector";
403
- const innerDestType = embeddings.embeddingDataType ?? new Float32();
582
+ const destColumn = embeddings.vectorColumn ?? "vector";
583
+ const innerDestType =
584
+ embeddings.function.embeddingDataType() ?? new Float32();
404
585
  if (sourceColumn === undefined) {
405
586
  throw new Error(
406
587
  `Cannot apply embedding function because the source column '${embeddings.sourceColumn}' was not present in the data`,
@@ -414,11 +595,9 @@ async function applyEmbeddings<T>(
414
595
  // if we call convertToTable with 0 records and a schema that includes the embedding
415
596
  return table;
416
597
  }
417
- if (embeddings.embeddingDimension !== undefined) {
418
- const destType = newVectorType(
419
- embeddings.embeddingDimension,
420
- innerDestType,
421
- );
598
+ const dimensions = embeddings.function.ndims();
599
+ if (dimensions !== undefined) {
600
+ const destType = newVectorType(dimensions, innerDestType);
422
601
  newColumns[destColumn] = makeVector([], destType);
423
602
  } else if (schema != null) {
424
603
  const destField = schema.fields.find((f) => f.name === destColumn);
@@ -446,7 +625,9 @@ async function applyEmbeddings<T>(
446
625
  );
447
626
  }
448
627
  const values = sourceColumn.toArray();
449
- const vectors = await embeddings.embed(values as T[]);
628
+ const vectors = await embeddings.function.computeSourceEmbeddings(
629
+ values as T[],
630
+ );
450
631
  if (vectors.length !== values.length) {
451
632
  throw new Error(
452
633
  "Embedding function did not return an embedding for each input element",
@@ -486,9 +667,9 @@ async function applyEmbeddings<T>(
486
667
  * embedding columns. If no schema is provded then embedding columns will
487
668
  * be placed at the end of the table, after all of the input columns.
488
669
  */
489
- export async function convertToTable<T>(
670
+ export async function convertToTable(
490
671
  data: Array<Record<string, unknown>>,
491
- embeddings?: EmbeddingFunction<T>,
672
+ embeddings?: EmbeddingFunctionConfig,
492
673
  makeTableOptions?: Partial<MakeArrowTableOptions>,
493
674
  ): Promise<ArrowTable> {
494
675
  const table = makeArrowTable(data, makeTableOptions);
@@ -496,13 +677,13 @@ export async function convertToTable<T>(
496
677
  }
497
678
 
498
679
  /** Creates the Arrow Type for a Vector column with dimension `dim` */
499
- function newVectorType<T extends Float>(
680
+ export function newVectorType<T extends Float>(
500
681
  dim: number,
501
682
  innerType: T,
502
683
  ): FixedSizeList<T> {
503
684
  // in Lance we always default to have the elements nullable, so we need to set it to true
504
685
  // otherwise we often get schema mismatches because the stored data always has schema with nullable elements
505
- const children = new Field<T>("item", innerType, true);
686
+ const children = new Field("item", <T>sanitizeType(innerType), true);
506
687
  return new FixedSizeList(dim, children);
507
688
  }
508
689
 
@@ -513,9 +694,9 @@ function newVectorType<T extends Float>(
513
694
  *
514
695
  * `schema` is required if data is empty
515
696
  */
516
- export async function fromRecordsToBuffer<T>(
697
+ export async function fromRecordsToBuffer(
517
698
  data: Array<Record<string, unknown>>,
518
- embeddings?: EmbeddingFunction<T>,
699
+ embeddings?: EmbeddingFunctionConfig,
519
700
  schema?: Schema,
520
701
  ): Promise<Buffer> {
521
702
  if (schema !== undefined && schema !== null) {
@@ -533,9 +714,9 @@ export async function fromRecordsToBuffer<T>(
533
714
  *
534
715
  * `schema` is required if data is empty
535
716
  */
536
- export async function fromRecordsToStreamBuffer<T>(
717
+ export async function fromRecordsToStreamBuffer(
537
718
  data: Array<Record<string, unknown>>,
538
- embeddings?: EmbeddingFunction<T>,
719
+ embeddings?: EmbeddingFunctionConfig,
539
720
  schema?: Schema,
540
721
  ): Promise<Buffer> {
541
722
  if (schema !== undefined && schema !== null) {
@@ -554,9 +735,9 @@ export async function fromRecordsToStreamBuffer<T>(
554
735
  *
555
736
  * `schema` is required if the table is empty
556
737
  */
557
- export async function fromTableToBuffer<T>(
738
+ export async function fromTableToBuffer(
558
739
  table: ArrowTable,
559
- embeddings?: EmbeddingFunction<T>,
740
+ embeddings?: EmbeddingFunctionConfig,
560
741
  schema?: Schema,
561
742
  ): Promise<Buffer> {
562
743
  if (schema !== undefined && schema !== null) {
@@ -575,19 +756,19 @@ export async function fromTableToBuffer<T>(
575
756
  *
576
757
  * `schema` is required if the table is empty
577
758
  */
578
- export async function fromDataToBuffer<T>(
759
+ export async function fromDataToBuffer(
579
760
  data: Data,
580
- embeddings?: EmbeddingFunction<T>,
761
+ embeddings?: EmbeddingFunctionConfig,
581
762
  schema?: Schema,
582
763
  ): Promise<Buffer> {
583
764
  if (schema !== undefined && schema !== null) {
584
765
  schema = sanitizeSchema(schema);
585
766
  }
586
- if (data instanceof ArrowTable) {
767
+ if (isArrowTable(data)) {
587
768
  return fromTableToBuffer(data, embeddings, schema);
588
769
  } else {
589
- const table = await convertToTable(data);
590
- return fromTableToBuffer(table, embeddings, schema);
770
+ const table = await convertToTable(data, embeddings, { schema });
771
+ return fromTableToBuffer(table);
591
772
  }
592
773
  }
593
774
 
@@ -599,9 +780,9 @@ export async function fromDataToBuffer<T>(
599
780
  *
600
781
  * `schema` is required if the table is empty
601
782
  */
602
- export async function fromTableToStreamBuffer<T>(
783
+ export async function fromTableToStreamBuffer(
603
784
  table: ArrowTable,
604
- embeddings?: EmbeddingFunction<T>,
785
+ embeddings?: EmbeddingFunctionConfig,
605
786
  schema?: Schema,
606
787
  ): Promise<Buffer> {
607
788
  const tableWithEmbeddings = await applyEmbeddings(table, embeddings, schema);
@@ -664,10 +845,25 @@ function validateSchemaEmbeddings(
664
845
  // if it does not, we add it to the list of missing embedding fields
665
846
  // Finally, we check if those missing embedding fields are `this._embeddings`
666
847
  // if they are not, we throw an error
667
- for (const field of schema.fields) {
668
- if (field.type instanceof FixedSizeList) {
848
+ for (let field of schema.fields) {
849
+ if (isFixedSizeList(field.type)) {
850
+ field = sanitizeField(field);
851
+
669
852
  if (data.length !== 0 && data?.[0]?.[field.name] === undefined) {
670
- missingEmbeddingFields.push(field);
853
+ if (schema.metadata.has("embedding_functions")) {
854
+ const embeddings = JSON.parse(
855
+ schema.metadata.get("embedding_functions")!,
856
+ );
857
+ if (
858
+ // biome-ignore lint/suspicious/noExplicitAny: we don't know the type of `f`
859
+ embeddings.find((f: any) => f["vectorColumn"] === field.name) ===
860
+ undefined
861
+ ) {
862
+ missingEmbeddingFields.push(field);
863
+ }
864
+ } else {
865
+ missingEmbeddingFields.push(field);
866
+ }
671
867
  } else {
672
868
  fields.push(field);
673
869
  }
@@ -12,8 +12,14 @@
12
12
  // See the License for the specific language governing permissions and
13
13
  // limitations under the License.
14
14
 
15
- import { Table as ArrowTable, Schema } from "apache-arrow";
16
- import { fromTableToBuffer, makeArrowTable, makeEmptyTable } from "./arrow";
15
+ import { Table as ArrowTable, Schema } from "./arrow";
16
+ import {
17
+ fromTableToBuffer,
18
+ isArrowTable,
19
+ makeArrowTable,
20
+ makeEmptyTable,
21
+ } from "./arrow";
22
+ import { EmbeddingFunctionConfig, getRegistry } from "./embedding/registry";
17
23
  import { ConnectionOptions, Connection as LanceDbConnection } from "./native";
18
24
  import { Table } from "./table";
19
25
 
@@ -65,6 +71,8 @@ export interface CreateTableOptions {
65
71
  * The available options are described at https://lancedb.github.io/lancedb/guides/storage/
66
72
  */
67
73
  storageOptions?: Record<string, string>;
74
+ schema?: Schema;
75
+ embeddingFunction?: EmbeddingFunctionConfig;
68
76
  }
69
77
 
70
78
  export interface OpenTableOptions {
@@ -174,6 +182,7 @@ export class Connection {
174
182
  cleanseStorageOptions(options?.storageOptions),
175
183
  options?.indexCacheSize,
176
184
  );
185
+
177
186
  return new Table(innerTable);
178
187
  }
179
188
 
@@ -196,18 +205,24 @@ export class Connection {
196
205
  }
197
206
 
198
207
  let table: ArrowTable;
199
- if (data instanceof ArrowTable) {
208
+ if (isArrowTable(data)) {
200
209
  table = data;
201
210
  } else {
202
- table = makeArrowTable(data);
211
+ table = makeArrowTable(data, options);
203
212
  }
204
- const buf = await fromTableToBuffer(table);
213
+
214
+ const buf = await fromTableToBuffer(
215
+ table,
216
+ options?.embeddingFunction,
217
+ options?.schema,
218
+ );
205
219
  const innerTable = await this.inner.createTable(
206
220
  name,
207
221
  buf,
208
222
  mode,
209
223
  cleanseStorageOptions(options?.storageOptions),
210
224
  );
225
+
211
226
  return new Table(innerTable);
212
227
  }
213
228
 
@@ -227,8 +242,14 @@ export class Connection {
227
242
  if (mode === "create" && existOk) {
228
243
  mode = "exist_ok";
229
244
  }
245
+ let metadata: Map<string, string> | undefined = undefined;
246
+ if (options?.embeddingFunction !== undefined) {
247
+ const embeddingFunction = options.embeddingFunction;
248
+ const registry = getRegistry();
249
+ metadata = registry.getTableMetadata([embeddingFunction]);
250
+ }
230
251
 
231
- const table = makeEmptyTable(schema);
252
+ const table = makeEmptyTable(schema, metadata);
232
253
  const buf = await fromTableToBuffer(table);
233
254
  const innerTable = await this.inner.createEmptyTable(
234
255
  name,
@@ -1,4 +1,4 @@
1
- // Copyright 2023 Lance Developers.
1
+ // Copyright 2024 Lance Developers.
2
2
  //
3
3
  // Licensed under the Apache License, Version 2.0 (the "License");
4
4
  // you may not use this file except in compliance with the License.
@@ -12,67 +12,151 @@
12
12
  // See the License for the specific language governing permissions and
13
13
  // limitations under the License.
14
14
 
15
- import { type Float } from "apache-arrow";
15
+ import "reflect-metadata";
16
+ import {
17
+ DataType,
18
+ Field,
19
+ FixedSizeList,
20
+ Float,
21
+ Float32,
22
+ isDataType,
23
+ isFixedSizeList,
24
+ isFloat,
25
+ newVectorType,
26
+ } from "../arrow";
27
+ import { sanitizeType } from "../sanitize";
16
28
 
17
29
  /**
18
- * An embedding function that automatically creates vector representation for a given column.
30
+ * Options for a given embedding function
19
31
  */
20
- export interface EmbeddingFunction<T> {
21
- /**
22
- * The name of the column that will be used as input for the Embedding Function.
23
- */
24
- sourceColumn: string;
32
+ export interface FunctionOptions {
33
+ // biome-ignore lint/suspicious/noExplicitAny: options can be anything
34
+ [key: string]: any;
35
+ }
25
36
 
37
+ /**
38
+ * An embedding function that automatically creates vector representation for a given column.
39
+ */
40
+ export abstract class EmbeddingFunction<
41
+ // biome-ignore lint/suspicious/noExplicitAny: we don't know what the implementor will do
42
+ T = any,
43
+ M extends FunctionOptions = FunctionOptions,
44
+ > {
26
45
  /**
27
- * The data type of the embedding
46
+ * Convert the embedding function to a JSON object
47
+ * It is used to serialize the embedding function to the schema
48
+ * It's important that any object returned by this method contains all the necessary
49
+ * information to recreate the embedding function
28
50
  *
29
- * The embedding function should return `number`. This will be converted into
30
- * an Arrow float array. By default this will be Float32 but this property can
31
- * be used to control the conversion.
32
- */
33
- embeddingDataType?: Float;
34
-
35
- /**
36
- * The dimension of the embedding
51
+ * It should return the same object that was passed to the constructor
52
+ * If it does not, the embedding function will not be able to be recreated, or could be recreated incorrectly
37
53
  *
38
- * This is optional, normally this can be determined by looking at the results of
39
- * `embed`. If this is not specified, and there is an attempt to apply the embedding
40
- * to an empty table, then that process will fail.
54
+ * @example
55
+ * ```ts
56
+ * class MyEmbeddingFunction extends EmbeddingFunction {
57
+ * constructor(options: {model: string, timeout: number}) {
58
+ * super();
59
+ * this.model = options.model;
60
+ * this.timeout = options.timeout;
61
+ * }
62
+ * toJSON() {
63
+ * return {
64
+ * model: this.model,
65
+ * timeout: this.timeout,
66
+ * };
67
+ * }
68
+ * ```
41
69
  */
42
- embeddingDimension?: number;
70
+ abstract toJSON(): Partial<M>;
43
71
 
44
72
  /**
45
- * The name of the column that will contain the embedding
73
+ * sourceField is used in combination with `LanceSchema` to provide a declarative data model
46
74
  *
47
- * By default this is "vector"
75
+ * @param optionsOrDatatype - The options for the field or the datatype
76
+ *
77
+ * @see {@link lancedb.LanceSchema}
48
78
  */
49
- destColumn?: string;
79
+ sourceField(
80
+ optionsOrDatatype: Partial<FieldOptions> | DataType,
81
+ ): [DataType, Map<string, EmbeddingFunction>] {
82
+ let datatype = isDataType(optionsOrDatatype)
83
+ ? optionsOrDatatype
84
+ : optionsOrDatatype?.datatype;
85
+ if (!datatype) {
86
+ throw new Error("Datatype is required");
87
+ }
88
+ datatype = sanitizeType(datatype);
89
+ const metadata = new Map<string, EmbeddingFunction>();
90
+ metadata.set("source_column_for", this);
91
+
92
+ return [datatype, metadata];
93
+ }
50
94
 
51
95
  /**
52
- * Should the source column be excluded from the resulting table
96
+ * vectorField is used in combination with `LanceSchema` to provide a declarative data model
53
97
  *
54
- * By default the source column is included. Set this to true and
55
- * only the embedding will be stored.
98
+ * @param options - The options for the field
99
+ *
100
+ * @see {@link lancedb.LanceSchema}
56
101
  */
57
- excludeSource?: boolean;
102
+ vectorField(
103
+ options?: Partial<FieldOptions>,
104
+ ): [DataType, Map<string, EmbeddingFunction>] {
105
+ let dtype: DataType;
106
+ const dims = this.ndims() ?? options?.dims;
107
+ if (!options?.datatype) {
108
+ if (dims === undefined) {
109
+ throw new Error("ndims is required for vector field");
110
+ }
111
+ dtype = new FixedSizeList(dims, new Field("item", new Float32(), true));
112
+ } else {
113
+ if (isFixedSizeList(options.datatype)) {
114
+ dtype = options.datatype;
115
+ } else if (isFloat(options.datatype)) {
116
+ if (dims === undefined) {
117
+ throw new Error("ndims is required for vector field");
118
+ }
119
+ dtype = newVectorType(dims, options.datatype);
120
+ } else {
121
+ throw new Error(
122
+ "Expected FixedSizeList or Float as datatype for vector field",
123
+ );
124
+ }
125
+ }
126
+ const metadata = new Map<string, EmbeddingFunction>();
127
+ metadata.set("vector_column_for", this);
128
+
129
+ return [dtype, metadata];
130
+ }
131
+
132
+ /** The number of dimensions of the embeddings */
133
+ ndims(): number | undefined {
134
+ return undefined;
135
+ }
136
+
137
+ /** The datatype of the embeddings */
138
+ abstract embeddingDataType(): Float;
58
139
 
59
140
  /**
60
141
  * Creates a vector representation for the given values.
61
142
  */
62
- embed: (data: T[]) => Promise<number[][]>;
63
- }
143
+ abstract computeSourceEmbeddings(
144
+ data: T[],
145
+ ): Promise<number[][] | Float32Array[] | Float64Array[]>;
64
146
 
65
- /** Test if the input seems to be an embedding function */
66
- export function isEmbeddingFunction<T>(
67
- value: unknown,
68
- ): value is EmbeddingFunction<T> {
69
- if (typeof value !== "object" || value === null) {
70
- return false;
71
- }
72
- if (!("sourceColumn" in value) || !("embed" in value)) {
73
- return false;
147
+ /**
148
+ Compute the embeddings for a single query
149
+ */
150
+ async computeQueryEmbeddings(
151
+ data: T,
152
+ ): Promise<number[] | Float32Array | Float64Array> {
153
+ return this.computeSourceEmbeddings([data]).then(
154
+ (embeddings) => embeddings[0],
155
+ );
74
156
  }
75
- return (
76
- typeof value.sourceColumn === "string" && typeof value.embed === "function"
77
- );
157
+ }
158
+
159
+ export interface FieldOptions<T extends DataType = DataType> {
160
+ datatype: T;
161
+ dims?: number;
78
162
  }