neighbor 0.4.3 → 0.5.0

checksums.yaml CHANGED
@@ -1,7 +1,7 @@
  ---
  SHA256:
- metadata.gz: e8d611fd277cd48d309b2a087fdeb22f39f43d8ef81fcab57763bd5b4b2e48b3
- data.tar.gz: fe0a5f7e4aa1ebd81f8c5849be67dc0f3c948d53d313b80259def04b9b7e9e84
+ metadata.gz: 5d7036a69b1c57161eaeb11e38feee92c1d5082ddbe8907a83ac3126adf9ae56
+ data.tar.gz: c88e4400b75d2a87f766f7e0b7ff6c5311d9c7866d32e7c2b9aa7601f60c474f
  SHA512:
- metadata.gz: 0d9d9d0be9f2929f1eab7e5df52a0606ca8c220ea74e6ef349c2ab77e5da44bd4b98ed4d2271345b2d679882f1790ac54b8373c13a69f75607935322bdb68754
- data.tar.gz: 834c7d6e26be9b6d8fc262280048e5104cd022587fed30d030f6c0edd849966a49a08bc793f1b49764b1fc5afc014c1463dd8acacfb5739bc507f4a77d3281a1
+ metadata.gz: d3c4c25404fb64f324fbba70edcf06d827d3708905ed4a84404a6c9ce39f27b6890d449b285ce302e24495a666c34f0bf3050270b54ef8d283f53ebeb19e4e91
+ data.tar.gz: 63927a8801a88edd48f74ce85d056d7112fa526b37d473b232256b5f2d47e5254b34b25e0312afc01bfedcd6a9d7826496f208e0ab0d3f57c3307fa298b8984e
data/CHANGELOG.md CHANGED
@@ -1,3 +1,12 @@
+ ## 0.5.0 (2024-10-07)
+
+ - Added experimental support for SQLite (sqlite-vec)
+ - Added experimental support for MariaDB 11.6 Vector
+ - Added experimental support for MySQL 9
+ - Changed `normalize` option to use Active Record normalization
+ - Fixed connection leasing for Active Record 7.2
+ - Dropped support for Active Record < 7
+
  ## 0.4.3 (2024-09-02)
 
  - Added `rrf` method
data/README.md CHANGED
@@ -1,6 +1,13 @@
1
1
  # Neighbor
2
2
 
3
- Nearest neighbor search for Rails and Postgres
3
+ Nearest neighbor search for Rails
4
+
5
+ Supports:
6
+
7
+ - Postgres (cube and pgvector)
8
+ - SQLite (sqlite-vec) - experimental
9
+ - MariaDB 11.6 Vector - experimental
10
+ - MySQL 9 (searching requires HeatWave) - experimental
4
11
 
5
12
  [![Build Status](https://github.com/ankane/neighbor/actions/workflows/build.yml/badge.svg)](https://github.com/ankane/neighbor/actions)
6
13
 
@@ -12,7 +19,7 @@ Add this line to your application’s Gemfile:
12
19
  gem "neighbor"
13
20
  ```
14
21
 
15
- ## Choose An Extension
22
+ ### For Postgres
16
23
 
17
24
  Neighbor supports two extensions: [cube](https://www.postgresql.org/docs/current/cube.html) and [pgvector](https://github.com/pgvector/pgvector). cube ships with Postgres, while pgvector supports more dimensions and approximate nearest neighbor search.
18
25
 
@@ -30,6 +37,20 @@ rails generate neighbor:vector
30
37
  rails db:migrate
31
38
  ```
32
39
 
40
+ ### For SQLite
41
+
42
+ Add this line to your application’s Gemfile:
43
+
44
+ ```ruby
45
+ gem "sqlite-vec"
46
+ ```
47
+
48
+ And run:
49
+
50
+ ```sh
51
+ rails generate neighbor:sqlite
52
+ ```
53
+
33
54
  ## Getting Started
34
55
 
35
56
  Create a migration
@@ -37,9 +58,14 @@ Create a migration
37
58
  ```ruby
38
59
  class AddEmbeddingToItems < ActiveRecord::Migration[7.2]
39
60
  def change
61
+ # cube
40
62
  add_column :items, :embedding, :cube
41
- # or
63
+
64
+ # pgvector and MySQL
42
65
  add_column :items, :embedding, :vector, limit: 3 # dimensions
66
+
67
+ # sqlite-vec and MariaDB
68
+ add_column :items, :embedding, :binary
43
69
  end
44
70
  end
45
71
  ```
@@ -81,6 +107,9 @@ See the additional docs for:
81
107
 
82
108
  - [cube](#cube)
83
109
  - [pgvector](#pgvector)
110
+ - [sqlite-vec](#sqlite-vec)
111
+ - [MariaDB](#mariadb)
112
+ - [MySQL](#mysql)
84
113
 
85
114
  Or check out some [examples](#examples)
86
115
 
@@ -134,6 +163,12 @@ Supported values are:
134
163
 
135
164
  The `vector` type can have up to 16,000 dimensions, and vectors with up to 2,000 dimensions can be indexed.
136
165
 
166
+ The `halfvec` type can have up to 16,000 dimensions, and half vectors with up to 4,000 dimensions can be indexed.
167
+
168
+ The `bit` type can have up to 83 million dimensions, and bit vectors with up to 64,000 dimensions can be indexed.
169
+
170
+ The `sparsevec` type can have up to 16,000 non-zero elements, and sparse vectors with up to 1,000 non-zero elements can be indexed.
171
+
137
172
  ### Indexing
138
173
 
139
174
  Add an approximate index to speed up queries. Create a migration with:
@@ -241,6 +276,190 @@ embedding = Neighbor::SparseVector.new({0 => 0.9, 1 => 1.3, 2 => 1.1}, 3)
241
276
  Item.nearest_neighbors(:embedding, embedding, distance: "euclidean").first(5)
242
277
  ```
243
278
 
279
+ ## sqlite-vec
280
+
281
+ ### Distance
282
+
283
+ Supported values are:
284
+
285
+ - `euclidean`
286
+ - `cosine`
287
+ - `taxicab`
288
+ - `hamming`
289
+
290
+ ### Dimensions
291
+
292
+ For sqlite-vec, it’s a good idea to specify the number of dimensions to ensure all records have the same number.
293
+
294
+ ```ruby
295
+ class Item < ApplicationRecord
296
+ has_neighbors :embedding, dimensions: 3
297
+ end
298
+ ```
299
+
300
+ ### Virtual Tables
301
+
302
+ You can also use [virtual tables](https://alexgarcia.xyz/sqlite-vec/features/knn.html)
303
+
304
+ ```ruby
305
+ class AddEmbeddingToItems < ActiveRecord::Migration[7.2]
306
+ def change
307
+ # Rails < 8
308
+ execute <<~SQL
309
+ CREATE VIRTUAL TABLE items USING vec0(
310
+ embedding float[3] distance_metric=L2
311
+ )
312
+ SQL
313
+
314
+ # Rails 8+
315
+ create_virtual_table :items, :vec0, [
316
+ "embedding float[3] distance_metric=L2"
317
+ ]
318
+ end
319
+ end
320
+ ```
321
+
322
+ Use `distance_metric=cosine` for cosine distance
323
+
324
+ You can optionally ignore any shadow tables that are created
325
+
326
+ ```ruby
327
+ ActiveRecord::SchemaDumper.ignore_tables += [
328
+ "items_chunks", "items_rowids", "items_vector_chunks00"
329
+ ]
330
+ ```
331
+
332
+ Create a model with `rowid` as the primary key
333
+
334
+ ```ruby
335
+ class Item < ApplicationRecord
336
+ self.primary_key = "rowid"
337
+
338
+ has_neighbors :embedding, dimensions: 3
339
+ end
340
+ ```
341
+
342
+ Get the `k` nearest neighbors
343
+
344
+ ```ruby
345
+ Item.where("embedding MATCH ?", [1, 2, 3].to_s).where(k: 5).order(:distance)
346
+ ```
347
+
348
+ Filter by primary key
349
+
350
+ ```ruby
351
+ Item.where(rowid: [2, 3]).where("embedding MATCH ?", [1, 2, 3].to_s).where(k: 5).order(:distance)
352
+ ```
353
+
354
+ ### Int8 Vectors
355
+
356
+ Use the `type` option for int8 vectors
357
+
358
+ ```ruby
359
+ class Item < ApplicationRecord
360
+ has_neighbors :embedding, dimensions: 3, type: :int8
361
+ end
362
+ ```
363
+
364
+ ### Binary Vectors
365
+
366
+ Use the `type` option for binary vectors
367
+
368
+ ```ruby
369
+ class Item < ApplicationRecord
370
+ has_neighbors :embedding, dimensions: 8, type: :bit
371
+ end
372
+ ```
373
+
374
+ Get the nearest neighbors by Hamming distance
375
+
376
+ ```ruby
377
+ Item.nearest_neighbors(:embedding, "\x05", distance: "hamming").first(5)
378
+ ```
379
+
380
+ ## MariaDB
381
+
382
+ ### Distance
383
+
384
+ Supported values are:
385
+
386
+ - `euclidean`
387
+ - `cosine`
388
+ - `hamming`
389
+
390
+ For cosine distance with MariaDB, vectors must be normalized before being stored.
391
+
392
+ ```ruby
393
+ class Item < ApplicationRecord
394
+ has_neighbors :embedding, normalize: true
395
+ end
396
+ ```
397
+
398
+ ### Indexing
399
+
400
+ Vector columns must use `null: false` to add a vector index
401
+
402
+ ```ruby
403
+ class CreateItems < ActiveRecord::Migration[7.2]
404
+ def change
405
+ create_table :items do |t|
406
+ t.binary :embedding, null: false
407
+ t.index :embedding, type: :vector
408
+ end
409
+ end
410
+ end
411
+ ```
412
+
413
+ ### Binary Vectors
414
+
415
+ Use the `bigint` type to store binary vectors
416
+
417
+ ```ruby
418
+ class AddEmbeddingToItems < ActiveRecord::Migration[7.2]
419
+ def change
420
+ add_column :items, :embedding, :bigint
421
+ end
422
+ end
423
+ ```
424
+
425
+ Note: Binary vectors can have up to 64 dimensions
426
+
427
+ Get the nearest neighbors by Hamming distance
428
+
429
+ ```ruby
430
+ Item.nearest_neighbors(:embedding, 5, distance: "hamming").first(5)
431
+ ```
432
+
433
+ ## MySQL
434
+
435
+ ### Distance
436
+
437
+ Supported values are:
438
+
439
+ - `euclidean`
440
+ - `cosine`
441
+ - `hamming`
442
+
443
+ Note: The `DISTANCE()` function is [only available on HeatWave](https://dev.mysql.com/doc/refman/9.0/en/vector-functions.html)
444
+
445
+ ### Binary Vectors
446
+
447
+ Use the `binary` type to store binary vectors
448
+
449
+ ```ruby
450
+ class AddEmbeddingToItems < ActiveRecord::Migration[7.2]
451
+ def change
452
+ add_column :items, :embedding, :binary
453
+ end
454
+ end
455
+ ```
456
+
457
+ Get the nearest neighbors by Hamming distance
458
+
459
+ ```ruby
460
+ Item.nearest_neighbors(:embedding, "\x05", distance: "hamming").first(5)
461
+ ```
462
+
244
463
  ## Examples
245
464
 
246
465
  - [Embeddings](#openai-embeddings) with OpenAI
@@ -472,12 +691,9 @@ end
472
691
  Create some documents
473
692
 
474
693
  ```ruby
475
- texts = [
476
- "The dog is barking",
477
- "The cat is purring",
478
- "The bear is growling"
479
- ]
480
- documents = Document.create!(texts.map { |v| {content: v} })
694
+ Document.create!(content: "The dog is barking")
695
+ Document.create!(content: "The cat is purring")
696
+ Document.create!(content: "The bear is growling")
481
697
  ```
482
698
 
483
699
  Generate an embedding for each document
@@ -485,9 +701,9 @@ Generate an embedding for each document
485
701
  ```ruby
486
702
  embed = Informers.pipeline("embedding", "Snowflake/snowflake-arctic-embed-m-v1.5")
487
703
  embed_options = {model_output: "sentence_embedding", pooling: "none"} # specific to embedding model
488
- embeddings = embed.(documents.map(&:content), **embed_options)
489
704
 
490
- documents.zip(embeddings) do |document, embedding|
705
+ Document.find_each do |document|
706
+ embedding = embed.(document.content, **embed_options)
491
707
  document.update!(embedding: embedding)
492
708
  end
493
709
  ```
@@ -511,7 +727,7 @@ semantic_results =
511
727
  To combine the results, use Reciprocal Rank Fusion (RRF)
512
728
 
513
729
  ```ruby
514
- Neighbor::Reranking.rrf(keyword_results, semantic_results)
730
+ Neighbor::Reranking.rrf(keyword_results, semantic_results).first(5)
515
731
  ```
516
732
 
517
733
  Or a reranking model
@@ -519,7 +735,7 @@ Or a reranking model
519
735
  ```ruby
520
736
  rerank = Informers.pipeline("reranking", "mixedbread-ai/mxbai-rerank-xsmall-v1")
521
737
  results = (keyword_results + semantic_results).uniq
522
- rerank.(query, results.map(&:content), top_k: 5).map { |v| results[v[:doc_id]] }
738
+ rerank.(query, results.map(&:content)).first(5).map { |v| results[v[:doc_id]] }
523
739
  ```
524
740
 
525
741
  See the [complete code](examples/hybrid/example.rb)
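
For readers unfamiliar with Reciprocal Rank Fusion, the idea behind the `rrf` call in this hunk can be sketched in plain Ruby. This is a generic illustration, not the gem's implementation: the method name `rrf_sketch`, the constant `k: 60`, and the plain-array return value are all assumptions made for the example, and the return format of `Neighbor::Reranking.rrf` is not shown in this diff.

```ruby
# Generic Reciprocal Rank Fusion sketch (hypothetical helper, not the gem's code).
# Each item scores 1 / (k + rank) in every list it appears in; scores are summed.
def rrf_sketch(*lists, k: 60)
  scores = Hash.new(0.0)
  lists.each do |list|
    list.each_with_index do |item, index|
      scores[item] += 1.0 / (k + index + 1) # ranks start at 1
    end
  end
  # Highest combined score first
  scores.sort_by { |_, score| -score }.map(&:first)
end

keyword_ranking = ["a", "b", "c"]
semantic_ranking = ["c", "a", "d"]
fused = rrf_sketch(keyword_ranking, semantic_ranking)
puts fused.inspect
```

Items that rank well in both lists rise to the top, which is why the fused result can then be truncated with `first(5)` as in the updated README line.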
@@ -667,6 +883,19 @@ To get started with development:
667
883
  git clone https://github.com/ankane/neighbor.git
668
884
  cd neighbor
669
885
  bundle install
886
+
887
+ # Postgres
670
888
  createdb neighbor_test
671
- bundle exec rake test
889
+ bundle exec rake test:postgresql
890
+
891
+ # SQLite
892
+ bundle exec rake test:sqlite
893
+
894
+ # MariaDB
895
+ docker run -e MARIADB_ALLOW_EMPTY_ROOT_PASSWORD=1 -e MARIADB_DATABASE=neighbor_test -p 3307:3306 quay.io/mariadb-foundation/mariadb-devel:11.6-vector-preview
896
+ bundle exec rake test:mariadb
897
+
898
+ # MySQL
899
+ docker run -e MYSQL_ALLOW_EMPTY_PASSWORD=1 -e MYSQL_DATABASE=neighbor_test -p 3306:3306 mysql:9
900
+ bundle exec rake test:mysql
672
901
  ```
@@ -0,0 +1,13 @@
+ require "rails/generators"
+
+ module Neighbor
+   module Generators
+     class SqliteGenerator < Rails::Generators::Base
+       source_root File.join(__dir__, "templates")
+
+       def copy_templates
+         template "sqlite.rb", "config/initializers/neighbor.rb"
+       end
+     end
+   end
+ end
@@ -0,0 +1,2 @@
+ # Load the sqlite-vec extension
+ Neighbor::SQLite.initialize!
@@ -0,0 +1,48 @@
+ module Neighbor
+   class Attribute < ActiveRecord::Type::Value
+     delegate :type, :serialize, :deserialize, :cast, to: :new_cast_type
+
+     def initialize(cast_type:, model:, type:, attribute_name:)
+       @cast_type = cast_type
+       @model = model
+       @type = type
+       @attribute_name = attribute_name
+     end
+
+     private
+
+     def cast_value(...)
+       new_cast_type.send(:cast_value, ...)
+     end
+
+     def new_cast_type
+       @new_cast_type ||= begin
+         if @cast_type.is_a?(ActiveModel::Type::Value)
+           case Utils.adapter(@model)
+           when :sqlite
+             case @type&.to_sym
+             when :int8
+               Type::SqliteInt8Vector.new
+             when :bit
+               @cast_type
+             when :float32, nil
+               Type::SqliteVector.new
+             else
+               raise ArgumentError, "Unsupported type"
+             end
+           when :mariadb
+             if @model.columns_hash[@attribute_name.to_s]&.type == :integer
+               @cast_type
+             else
+               Type::MysqlVector.new
+             end
+           else
+             @cast_type
+           end
+         else
+           @cast_type
+         end
+       end
+     end
+   end
+ end
@@ -1,6 +1,6 @@
1
1
  module Neighbor
2
2
  module Model
3
- def has_neighbors(*attribute_names, dimensions: nil, normalize: nil)
3
+ def has_neighbors(*attribute_names, dimensions: nil, normalize: nil, type: nil)
4
4
  if attribute_names.empty?
5
5
  raise ArgumentError, "has_neighbors requires an attribute name"
6
6
  end
@@ -24,125 +24,116 @@ module Neighbor
24
24
 
25
25
  attribute_names.each do |attribute_name|
26
26
  raise Error, "has_neighbors already called for #{attribute_name.inspect}" if neighbor_attributes[attribute_name]
27
- @neighbor_attributes[attribute_name] = {dimensions: dimensions, normalize: normalize}
27
+ @neighbor_attributes[attribute_name] = {dimensions: dimensions, normalize: normalize, type: type&.to_sym}
28
+ end
29
+
30
+ if ActiveRecord::VERSION::STRING.to_f >= 7.2
31
+ decorate_attributes(attribute_names) do |name, cast_type|
32
+ Neighbor::Attribute.new(cast_type: cast_type, model: self, type: type, attribute_name: name)
33
+ end
34
+ else
35
+ attribute_names.each do |attribute_name|
36
+ attribute attribute_name do |cast_type|
37
+ Neighbor::Attribute.new(cast_type: cast_type, model: self, type: type, attribute_name: attribute_name)
38
+ end
39
+ end
40
+ end
41
+
42
+ if normalize
43
+ if ActiveRecord::VERSION::STRING.to_f >= 7.1
44
+ attribute_names.each do |attribute_name|
45
+ normalizes attribute_name, with: ->(v) { Neighbor::Utils.normalize(v, column_info: columns_hash[attribute_name.to_s]) }
46
+ end
47
+ else
48
+ attribute_names.each do |attribute_name|
49
+ attribute attribute_name do |cast_type|
50
+ Neighbor::NormalizedAttribute.new(cast_type: cast_type, model: self, attribute_name: attribute_name)
51
+ end
52
+ end
53
+ end
28
54
  end
29
55
 
30
56
  return if @neighbor_attributes.size != attribute_names.size
31
57
 
32
58
  validate do
59
+ adapter = Utils.adapter(self.class)
60
+
33
61
  self.class.neighbor_attributes.each do |k, v|
34
62
  value = read_attribute(k)
35
63
  next if value.nil?
36
64
 
37
65
  column_info = self.class.columns_hash[k.to_s]
38
- dimensions = v[:dimensions] || column_info&.limit
66
+ dimensions = v[:dimensions]
67
+ dimensions ||= column_info&.limit unless column_info&.type == :binary
68
+ type = v[:type] || Utils.type(adapter, column_info&.type)
39
69
 
40
- if !Neighbor::Utils.validate_dimensions(value, column_info&.type, dimensions).nil?
70
+ if !Neighbor::Utils.validate_dimensions(value, type, dimensions, adapter).nil?
41
71
  errors.add(k, "must have #{dimensions} dimensions")
42
72
  end
43
- if !Neighbor::Utils.validate_finite(value, column_info&.type)
73
+ if !Neighbor::Utils.validate_finite(value, type)
44
74
  errors.add(k, "must have finite values")
45
75
  end
46
76
  end
47
77
  end
48
78
 
49
- # TODO move to normalizes when Active Record < 7.1 no longer supported
50
- before_save do
51
- self.class.neighbor_attributes.each do |k, v|
52
- next unless v[:normalize] && attribute_changed?(k)
53
- value = read_attribute(k)
54
- next if value.nil?
55
- self[k] = Neighbor::Utils.normalize(value, column_info: self.class.columns_hash[k.to_s])
56
- end
57
- end
58
-
59
- # cannot use keyword arguments with scope with Ruby 3.2 and Active Record 6.1
60
- # https://github.com/rails/rails/issues/46934
61
- scope :nearest_neighbors, ->(attribute_name, vector, options = nil) {
62
- raise ArgumentError, "missing keyword: :distance" unless options.is_a?(Hash) && options.key?(:distance)
63
- distance = options.delete(:distance)
64
- precision = options.delete(:precision)
65
- raise ArgumentError, "unknown keywords: #{options.keys.map(&:inspect).join(", ")}" if options.any?
66
-
79
+ scope :nearest_neighbors, ->(attribute_name, vector, distance:, precision: nil) {
67
80
  attribute_name = attribute_name.to_sym
68
81
  options = neighbor_attributes[attribute_name]
69
82
  raise ArgumentError, "Invalid attribute" unless options
70
83
  normalize = options[:normalize]
71
84
  dimensions = options[:dimensions]
85
+ type = options[:type]
72
86
 
73
87
  return none if vector.nil?
74
88
 
75
89
  distance = distance.to_s
76
90
 
77
- quoted_attribute = "#{connection.quote_table_name(table_name)}.#{connection.quote_column_name(attribute_name)}"
78
-
79
91
  column_info = columns_hash[attribute_name.to_s]
80
92
  column_type = column_info&.type
81
93
 
82
- operator =
83
- case column_type
84
- when :bit
85
- case distance
86
- when "hamming"
87
- "<~>"
88
- when "jaccard"
89
- "<%>"
90
- when "hamming2"
91
- "#"
92
- end
93
- when :vector, :halfvec, :sparsevec
94
- case distance
95
- when "inner_product"
96
- "<#>"
97
- when "cosine"
98
- "<=>"
99
- when "euclidean"
100
- "<->"
101
- when "taxicab"
102
- "<+>"
103
- end
104
- when :cube
105
- case distance
106
- when "taxicab"
107
- "<#>"
108
- when "chebyshev"
109
- "<=>"
110
- when "euclidean", "cosine"
111
- "<->"
112
- end
113
- else
114
- raise ArgumentError, "Unsupported type: #{column_type}"
115
- end
94
+ adapter = Neighbor::Utils.adapter(klass)
95
+ if type && adapter != :sqlite
96
+ raise ArgumentError, "type only works with SQLite"
97
+ end
116
98
 
99
+ operator = Neighbor::Utils.operator(adapter, column_type, distance)
117
100
  raise ArgumentError, "Invalid distance: #{distance}" unless operator
118
101
 
119
102
  # ensure normalize set (can be true or false)
120
- if distance == "cosine" && column_type == :cube && normalize.nil?
103
+ normalize_required = Utils.normalize_required?(adapter, column_type)
104
+ if distance == "cosine" && normalize_required && normalize.nil?
121
105
  raise Neighbor::Error, "Set normalize for cosine distance with cube"
122
106
  end
123
107
 
124
108
  column_attribute = klass.type_for_attribute(attribute_name)
125
109
  vector = column_attribute.cast(vector)
126
- Neighbor::Utils.validate(vector, dimensions: dimensions, column_info: column_info)
110
+ dimensions ||= column_info&.limit unless column_info&.type == :binary
111
+ Neighbor::Utils.validate(vector, dimensions: dimensions, type: type || Utils.type(adapter, column_info&.type), adapter: adapter)
127
112
  vector = Neighbor::Utils.normalize(vector, column_info: column_info) if normalize
128
113
 
129
- query = connection.quote(column_attribute.serialize(vector))
114
+ quoted_attribute = nil
115
+ query = nil
116
+ connection_pool.with_connection do |c|
117
+ quoted_attribute = "#{c.quote_table_name(table_name)}.#{c.quote_column_name(attribute_name)}"
118
+ query = c.quote(column_attribute.serialize(vector))
119
+ end
130
120
 
131
121
  if !precision.nil?
122
+ if adapter != :postgresql || column_type != :vector
123
+ raise ArgumentError, "Precision not supported for this type"
124
+ end
125
+
132
126
  case precision.to_s
133
127
  when "half"
134
128
  cast_dimensions = dimensions || column_info&.limit
135
129
  raise ArgumentError, "Unknown dimensions" unless cast_dimensions
136
- quoted_attribute += "::halfvec(#{connection.quote(cast_dimensions.to_i)})"
130
+ quoted_attribute += "::halfvec(#{connection_pool.with_connection { |c| c.quote(cast_dimensions.to_i) }})"
137
131
  else
138
132
  raise ArgumentError, "Invalid precision"
139
133
  end
140
134
  end
141
135
 
142
- order = "#{quoted_attribute} #{operator} #{query}"
143
- if operator == "#"
144
- order = "bit_count(#{order})"
145
- end
136
+ order = Utils.order(adapter, type, operator, quoted_attribute, query)
146
137
 
147
138
  # https://stats.stackexchange.com/questions/146221/is-cosine-similarity-identical-to-l2-normalized-euclidean-distance
148
139
  # with normalized vectors:
@@ -150,7 +141,7 @@ module Neighbor
150
141
  # cosine distance = 1 - cosine similarity
151
142
  # this transformation doesn't change the order, so only needed for select
152
143
  neighbor_distance =
153
- if column_type == :cube && distance == "cosine"
144
+ if distance == "cosine" && normalize_required
154
145
  "POWER(#{order}, 2) / 2.0"
155
146
  elsif [:vector, :halfvec, :sparsevec].include?(column_type) && distance == "inner_product"
156
147
  "(#{order}) * -1"
@@ -0,0 +1,37 @@
+ module Neighbor
+   module MySQL
+     def self.initialize!
+       require_relative "type/mysql_vector"
+
+       require "active_record/connection_adapters/abstract_mysql_adapter"
+
+       # ensure schema can be dumped
+       ActiveRecord::ConnectionAdapters::AbstractMysqlAdapter::NATIVE_DATABASE_TYPES[:vector] = {name: "vector"}
+
+       # ensure schema can be loaded
+       unless ActiveRecord::ConnectionAdapters::TableDefinition.method_defined?(:vector)
+         ActiveRecord::ConnectionAdapters::TableDefinition.send(:define_column_methods, :vector)
+       end
+
+       # prevent unknown OID warning
+       ActiveRecord::ConnectionAdapters::AbstractMysqlAdapter.singleton_class.prepend(RegisterTypes)
+       if ActiveRecord::VERSION::STRING.to_f < 7.1
+         ActiveRecord::ConnectionAdapters::AbstractMysqlAdapter.register_vector_type(ActiveRecord::ConnectionAdapters::AbstractMysqlAdapter::TYPE_MAP)
+       end
+     end
+
+     module RegisterTypes
+       def initialize_type_map(m)
+         super
+         register_vector_type(m)
+       end
+
+       def register_vector_type(m)
+         m.register_type %r(^vector)i do |sql_type|
+           limit = extract_limit(sql_type)
+           Type::MysqlVector.new(limit: limit)
+         end
+       end
+     end
+   end
+ end
@@ -0,0 +1,21 @@
+ module Neighbor
+   class NormalizedAttribute < ActiveRecord::Type::Value
+     delegate :type, :serialize, :deserialize, to: :@cast_type
+
+     def initialize(cast_type:, model:, attribute_name:)
+       @cast_type = cast_type
+       @model = model
+       @attribute_name = attribute_name.to_s
+     end
+
+     def cast(...)
+       Neighbor::Utils.normalize(@cast_type.cast(...), column_info: @model.columns_hash[@attribute_name])
+     end
+
+     private
+
+     def cast_value(...)
+       @cast_type.send(:cast_value, ...)
+     end
+   end
+ end
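
The wrapper pattern used by `NormalizedAttribute` (normalize on `cast`, forward everything else to the wrapped type) can be tried outside Rails with a small stand-in. The sketch below is an illustration only: `FloatArrayType` and `NormalizingType` are hypothetical names, `Forwardable` replaces Active Support's `delegate`, and returning a zero vector unchanged is an assumption, since the zero-norm branch of `Neighbor::Utils.normalize` is not visible in this diff.

```ruby
require "forwardable"

# Hypothetical stand-in for an Active Record cast type (not part of the gem)
class FloatArrayType
  def cast(value)
    value&.map(&:to_f)
  end

  def serialize(value)
    value&.pack("e*")
  end
end

# Wrapper that L2-normalizes on cast and forwards serialization to the inner type
class NormalizingType
  extend Forwardable
  def_delegators :@inner, :serialize

  def initialize(inner)
    @inner = inner
  end

  def cast(value)
    casted = @inner.cast(value)
    return nil if casted.nil?

    norm = Math.sqrt(casted.sum { |v| v * v })
    # Assumption: leave zero vectors untouched rather than dividing by zero
    norm.zero? ? casted : casted.map { |v| v / norm }
  end
end

type = NormalizingType.new(FloatArrayType.new)
unit = type.cast([3, 4])
puts unit.inspect
```

Because normalization happens at cast time, the in-memory attribute is already a unit vector before it is saved, which is the behavioral difference from the removed `before_save` callback.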
@@ -0,0 +1,43 @@
+ module Neighbor
+   module PostgreSQL
+     def self.initialize!
+       require_relative "type/cube"
+       require_relative "type/halfvec"
+       require_relative "type/sparsevec"
+       require_relative "type/vector"
+
+       require "active_record/connection_adapters/postgresql_adapter"
+
+       # ensure schema can be dumped
+       ActiveRecord::ConnectionAdapters::PostgreSQLAdapter::NATIVE_DATABASE_TYPES[:cube] = {name: "cube"}
+       ActiveRecord::ConnectionAdapters::PostgreSQLAdapter::NATIVE_DATABASE_TYPES[:halfvec] = {name: "halfvec"}
+       ActiveRecord::ConnectionAdapters::PostgreSQLAdapter::NATIVE_DATABASE_TYPES[:sparsevec] = {name: "sparsevec"}
+       ActiveRecord::ConnectionAdapters::PostgreSQLAdapter::NATIVE_DATABASE_TYPES[:vector] = {name: "vector"}
+
+       # ensure schema can be loaded
+       ActiveRecord::ConnectionAdapters::TableDefinition.send(:define_column_methods, :cube, :halfvec, :sparsevec, :vector)
+
+       # prevent unknown OID warning
+       ActiveRecord::ConnectionAdapters::PostgreSQLAdapter.singleton_class.prepend(RegisterTypes)
+     end
+
+     module RegisterTypes
+       def initialize_type_map(m = type_map)
+         super
+         m.register_type "cube", Type::Cube.new
+         m.register_type "halfvec" do |_, _, sql_type|
+           limit = extract_limit(sql_type)
+           Type::Halfvec.new(limit: limit)
+         end
+         m.register_type "sparsevec" do |_, _, sql_type|
+           limit = extract_limit(sql_type)
+           Type::Sparsevec.new(limit: limit)
+         end
+         m.register_type "vector" do |_, _, sql_type|
+           limit = extract_limit(sql_type)
+           Type::Vector.new(limit: limit)
+         end
+       end
+     end
+   end
+ end
@@ -0,0 +1,28 @@
+ module Neighbor
+   module SQLite
+     # note: this is a public API (unlike PostgreSQL and MySQL)
+     def self.initialize!
+       return if defined?(@initialized)
+
+       require_relative "type/sqlite_vector"
+       require_relative "type/sqlite_int8_vector"
+
+       require "sqlite_vec"
+       require "active_record/connection_adapters/sqlite3_adapter"
+
+       ActiveRecord::ConnectionAdapters::SQLite3Adapter.prepend(InstanceMethods)
+
+       @initialized = true
+     end
+
+     module InstanceMethods
+       def configure_connection
+         super
+         db = ActiveRecord::VERSION::STRING.to_f >= 7.1 ? @raw_connection : @connection
+         db.enable_load_extension(1)
+         SqliteVec.load(db)
+         db.enable_load_extension(0)
+       end
+     end
+   end
+ end
@@ -0,0 +1,33 @@
+ module Neighbor
+   module Type
+     class MysqlVector < ActiveRecord::Type::Binary
+       def type
+         :vector
+       end
+
+       def serialize(value)
+         if Utils.array?(value)
+           value = value.to_a.pack("e*")
+         end
+         super(value)
+       end
+
+       def deserialize(value)
+         value = super
+         cast_value(value) unless value.nil?
+       end
+
+       private
+
+       def cast_value(value)
+         if value.is_a?(String)
+           value.unpack("e*")
+         elsif Utils.array?(value)
+           value.to_a
+         else
+           raise "can't cast #{value.class.name} to vector"
+         end
+       end
+     end
+   end
+ end
+ end
@@ -0,0 +1,29 @@
+ module Neighbor
+   module Type
+     class SqliteInt8Vector < ActiveRecord::Type::Binary
+       def serialize(value)
+         if Utils.array?(value)
+           value = value.to_a.pack("c*")
+         end
+         super(value)
+       end
+
+       def deserialize(value)
+         value = super
+         cast_value(value) unless value.nil?
+       end
+
+       private
+
+       def cast_value(value)
+         if value.is_a?(String)
+           value.unpack("c*")
+         elsif Utils.array?(value)
+           value.to_a
+         else
+           raise "can't cast #{value.class.name} to vector"
+         end
+       end
+     end
+   end
+ end
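
The `"c*"` directive used here packs each element as a signed 8-bit integer, one byte per dimension. A short plain-Ruby sketch of that encoding follows; it also shows that Ruby's `pack` keeps only the low 8 bits of out-of-range integers, so it seems prudent to clamp or quantize values to -128..127 before assigning them (whether the gem validates this elsewhere is not visible in this diff).

```ruby
# Round-trip an int8 vector: one signed byte per element
values = [1, -2, 127]
blob = values.pack("c*")
decoded = blob.unpack("c*")

puts blob.bytesize
puts decoded.inspect

# Out-of-range integers wrap around instead of raising
wrapped = [128, 200].pack("c*").unpack("c*")
puts wrapped.inspect
```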
@@ -0,0 +1,29 @@
+ module Neighbor
+   module Type
+     class SqliteVector < ActiveRecord::Type::Binary
+       def serialize(value)
+         if Utils.array?(value)
+           value = value.to_a.pack("f*")
+         end
+         super(value)
+       end
+
+       def deserialize(value)
+         value = super
+         cast_value(value) unless value.nil?
+       end
+
+       private
+
+       def cast_value(value)
+         if value.is_a?(String)
+           value.unpack("f*")
+         elsif Utils.array?(value)
+           value.to_a
+         else
+           raise "can't cast #{value.class.name} to vector"
+         end
+       end
+     end
+   end
+ end
@@ -1,7 +1,9 @@
1
1
  module Neighbor
2
2
  module Utils
3
- def self.validate_dimensions(value, type, expected)
3
+ def self.validate_dimensions(value, type, expected, adapter)
4
4
  dimensions = type == :sparsevec ? value.dimensions : value.size
5
+ dimensions *= 8 if type == :bit && [:sqlite, :mysql].include?(adapter)
6
+
5
7
  if expected && dimensions != expected
6
8
  "Expected #{expected} dimensions, not #{dimensions}"
7
9
  end
@@ -9,7 +11,7 @@ module Neighbor
9
11
 
10
12
  def self.validate_finite(value, type)
11
13
  case type
12
- when :bit
14
+ when :bit, :integer
13
15
  true
14
16
  when :sparsevec
15
17
  value.values.all?(&:finite?)
@@ -18,18 +20,20 @@ module Neighbor
18
20
  end
19
21
  end
20
22
 
21
- def self.validate(value, dimensions:, column_info:)
22
- if (message = validate_dimensions(value, column_info&.type, dimensions || column_info&.limit))
23
+ def self.validate(value, dimensions:, type:, adapter:)
24
+ if (message = validate_dimensions(value, type, dimensions, adapter))
23
25
  raise Error, message
24
26
  end
25
27
 
26
- if !validate_finite(value, column_info&.type)
28
+ if !validate_finite(value, type)
27
29
  raise Error, "Values must be finite"
28
30
  end
29
31
  end
30
32
 
31
33
  def self.normalize(value, column_info:)
32
- raise Error, "Normalize not supported for type" unless [:cube, :vector, :halfvec].include?(column_info&.type)
34
+ return nil if value.nil?
35
+
36
+ raise Error, "Normalize not supported for type" unless [:cube, :vector, :halfvec, :binary].include?(column_info&.type)
33
37
 
34
38
  norm = Math.sqrt(value.sum { |v| v * v })
35
39
 
@@ -42,5 +46,156 @@ module Neighbor
42
46
  def self.array?(value)
43
47
  !value.nil? && value.respond_to?(:to_a)
44
48
  end
49
+
50
+ def self.adapter(model)
51
+ case model.connection_db_config.adapter
52
+ when /sqlite/i
53
+ :sqlite
54
+ when /mysql|trilogy/i
55
+ model.connection_pool.with_connection { |c| c.try(:mariadb?) } ? :mariadb : :mysql
56
+ else
57
+ :postgresql
58
+ end
59
+ end
60
+
61
+ def self.type(adapter, column_type)
62
+ case adapter
63
+ when :mysql
64
+ if column_type == :binary
65
+ :bit
66
+ else
67
+ column_type
68
+ end
69
+ else
70
+ column_type
71
+ end
72
+ end
73
+
74
+ def self.operator(adapter, column_type, distance)
75
+ case adapter
76
+ when :sqlite
77
+ case distance
78
+ when "euclidean"
79
+ "vec_distance_L2"
80
+ when "cosine"
81
+ "vec_distance_cosine"
82
+ when "taxicab"
83
+ "vec_distance_L1"
84
+ when "hamming"
85
+ "vec_distance_hamming"
86
+ end
87
+ when :mariadb
88
+ case column_type
89
+ when :binary
90
+ case distance
91
+ when "euclidean", "cosine"
92
+ "VEC_DISTANCE"
93
+ end
94
+ when :integer
95
+ case distance
96
+ when "hamming"
97
+ "BIT_COUNT"
98
+ end
99
+ else
100
+ raise ArgumentError, "Unsupported type: #{column_type}"
101
+ end
102
+ when :mysql
103
+ case column_type
104
+ when :vector
105
+ case distance
106
+ when "cosine"
107
+ "COSINE"
108
+ when "euclidean"
109
+ "EUCLIDEAN"
110
+ end
111
+ when :binary
112
+ case distance
113
+ when "hamming"
114
+ "BIT_COUNT"
115
+ end
116
+ else
117
+ raise ArgumentError, "Unsupported type: #{column_type}"
118
+ end
119
+ else
120
+ case column_type
121
+ when :bit
122
+ case distance
123
+ when "hamming"
124
+ "<~>"
125
+ when "jaccard"
126
+ "<%>"
127
+ when "hamming2"
128
+ "#"
129
+ end
130
+ when :vector, :halfvec, :sparsevec
131
+ case distance
132
+ when "inner_product"
133
+ "<#>"
134
+ when "cosine"
135
+ "<=>"
136
+ when "euclidean"
137
+ "<->"
138
+ when "taxicab"
139
+ "<+>"
140
+ end
141
+ when :cube
142
+ case distance
143
+ when "taxicab"
144
+ "<#>"
145
+ when "chebyshev"
146
+ "<=>"
147
+ when "euclidean", "cosine"
148
+ "<->"
149
+ end
150
+ else
151
+ raise ArgumentError, "Unsupported type: #{column_type}"
152
+ end
153
+ end
154
+ end
155
+
156
+ def self.order(adapter, type, operator, quoted_attribute, query)
157
+ case adapter
158
+ when :sqlite
159
+ case type
160
+ when :int8
161
+ "#{operator}(vec_int8(#{quoted_attribute}), vec_int8(#{query}))"
162
+ when :bit
163
+ "#{operator}(vec_bit(#{quoted_attribute}), vec_bit(#{query}))"
164
+ else
165
+ "#{operator}(#{quoted_attribute}, #{query})"
166
+ end
167
+ when :mariadb
168
+ if operator == "BIT_COUNT"
169
+ "BIT_COUNT(#{quoted_attribute} ^ #{query})"
170
+ else
171
+ "VEC_DISTANCE(#{quoted_attribute}, #{query})"
172
+ end
173
+ when :mysql
174
+ if operator == "BIT_COUNT"
175
+ "BIT_COUNT(#{quoted_attribute} ^ #{query})"
176
+ elsif operator == "COSINE"
177
+ "DISTANCE(#{quoted_attribute}, #{query}, 'COSINE')"
178
+ else
179
+ "DISTANCE(#{quoted_attribute}, #{query}, 'EUCLIDEAN')"
180
+ end
181
+ else
182
+ if operator == "#"
183
+ "bit_count(#{quoted_attribute} # #{query})"
184
+ else
185
+ "#{quoted_attribute} #{operator} #{query}"
186
+ end
187
+ end
188
+ end
189
+
190
+ def self.normalize_required?(adapter, column_type)
191
+ case adapter
192
+ when :postgresql
193
+ column_type == :cube
194
+ when :mariadb
195
+ true
196
+ else
197
+ false
198
+ end
199
+ end
45
200
  end
46
201
  end
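
The `normalize_required?` check ties into the `POWER(order, 2) / 2.0` expression in the model: for unit-length vectors, squared Euclidean distance equals `2 - 2 * (a · b)`, so halving it gives `1 - (a · b)`, which is the cosine distance. A small numeric check in plain Ruby (the helper names are made up for this sketch, and the two sample vectors are already unit length):

```ruby
# Hypothetical helpers for the check; both inputs must be unit vectors
def euclidean(a, b)
  Math.sqrt(a.zip(b).sum { |x, y| (x - y)**2 })
end

def cosine_distance(a, b)
  1 - a.zip(b).sum { |x, y| x * y }
end

a = [0.6, 0.8] # unit length: 0.36 + 0.64 = 1
b = [1.0, 0.0] # unit length

from_euclidean = euclidean(a, b)**2 / 2.0
direct = cosine_distance(a, b)

puts((from_euclidean - direct).abs < 1e-12)
```

This is why adapters that only expose a Euclidean-style distance for the column type (cube on Postgres, and MariaDB per the code above) need normalized vectors for cosine queries: the ordering is already correct, and the transformation only matters for the reported distance value.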
@@ -1,3 +1,3 @@
  module Neighbor
- VERSION = "0.4.3"
+ VERSION = "0.5.0"
  end
data/lib/neighbor.rb CHANGED
@@ -1,6 +1,11 @@
  # dependencies
  require "active_support"

+ # adapter hooks
+ require_relative "neighbor/mysql"
+ require_relative "neighbor/postgresql"
+ require_relative "neighbor/sqlite"
+
  # modules
  require_relative "neighbor/reranking"
  require_relative "neighbor/sparse_vector"
@@ -9,53 +14,22 @@ require_relative "neighbor/version"

  module Neighbor
  class Error < StandardError; end
-
- module RegisterTypes
- def initialize_type_map(m = type_map)
- super
- m.register_type "cube", Type::Cube.new
- m.register_type "halfvec" do |_, _, sql_type|
- limit = extract_limit(sql_type)
- Type::Halfvec.new(limit: limit)
- end
- m.register_type "sparsevec" do |_, _, sql_type|
- limit = extract_limit(sql_type)
- Type::Sparsevec.new(limit: limit)
- end
- m.register_type "vector" do |_, _, sql_type|
- limit = extract_limit(sql_type)
- Type::Vector.new(limit: limit)
- end
- end
- end
  end

  ActiveSupport.on_load(:active_record) do
+ require_relative "neighbor/attribute"
  require_relative "neighbor/model"
- require_relative "neighbor/type/cube"
- require_relative "neighbor/type/halfvec"
- require_relative "neighbor/type/sparsevec"
- require_relative "neighbor/type/vector"
+ require_relative "neighbor/normalized_attribute"

  extend Neighbor::Model

- require "active_record/connection_adapters/postgresql_adapter"
-
- # ensure schema can be dumped
- ActiveRecord::ConnectionAdapters::PostgreSQLAdapter::NATIVE_DATABASE_TYPES[:cube] = {name: "cube"}
- ActiveRecord::ConnectionAdapters::PostgreSQLAdapter::NATIVE_DATABASE_TYPES[:halfvec] = {name: "halfvec"}
- ActiveRecord::ConnectionAdapters::PostgreSQLAdapter::NATIVE_DATABASE_TYPES[:sparsevec] = {name: "sparsevec"}
- ActiveRecord::ConnectionAdapters::PostgreSQLAdapter::NATIVE_DATABASE_TYPES[:vector] = {name: "vector"}
-
- # ensure schema can be loaded
- ActiveRecord::ConnectionAdapters::TableDefinition.send(:define_column_methods, :cube, :halfvec, :sparsevec, :vector)
-
- # prevent unknown OID warning
- if ActiveRecord::VERSION::MAJOR >= 7
- ActiveRecord::ConnectionAdapters::PostgreSQLAdapter.singleton_class.prepend(Neighbor::RegisterTypes)
- else
- ActiveRecord::ConnectionAdapters::PostgreSQLAdapter.prepend(Neighbor::RegisterTypes)
+ begin
+ Neighbor::PostgreSQL.initialize!
+ rescue Gem::LoadError
+ # tries to load pg gem, which may not be available
  end
+
+ Neighbor::MySQL.initialize!
  end

  require_relative "neighbor/railtie" if defined?(Rails::Railtie)
metadata CHANGED
@@ -1,14 +1,14 @@
  --- !ruby/object:Gem::Specification
  name: neighbor
  version: !ruby/object:Gem::Version
- version: 0.4.3
+ version: 0.5.0
  platform: ruby
  authors:
  - Andrew Kane
  autorequire:
  bindir: bin
  cert_chain: []
- date: 2024-09-02 00:00:00.000000000 Z
+ date: 2024-10-08 00:00:00.000000000 Z
  dependencies:
  - !ruby/object:Gem::Dependency
  name: activerecord
@@ -16,14 +16,14 @@ dependencies:
  requirements:
  - - ">="
  - !ruby/object:Gem::Version
- version: '6.1'
+ version: '7'
  type: :runtime
  prerelease: false
  version_requirements: !ruby/object:Gem::Requirement
  requirements:
  - - ">="
  - !ruby/object:Gem::Version
- version: '6.1'
+ version: '7'
  description:
  email: andrew@ankane.org
  executables: []
@@ -34,17 +34,27 @@ files:
  - LICENSE.txt
  - README.md
  - lib/generators/neighbor/cube_generator.rb
+ - lib/generators/neighbor/sqlite_generator.rb
  - lib/generators/neighbor/templates/cube.rb.tt
+ - lib/generators/neighbor/templates/sqlite.rb.tt
  - lib/generators/neighbor/templates/vector.rb.tt
  - lib/generators/neighbor/vector_generator.rb
  - lib/neighbor.rb
+ - lib/neighbor/attribute.rb
  - lib/neighbor/model.rb
+ - lib/neighbor/mysql.rb
+ - lib/neighbor/normalized_attribute.rb
+ - lib/neighbor/postgresql.rb
  - lib/neighbor/railtie.rb
  - lib/neighbor/reranking.rb
  - lib/neighbor/sparse_vector.rb
+ - lib/neighbor/sqlite.rb
  - lib/neighbor/type/cube.rb
  - lib/neighbor/type/halfvec.rb
+ - lib/neighbor/type/mysql_vector.rb
  - lib/neighbor/type/sparsevec.rb
+ - lib/neighbor/type/sqlite_int8_vector.rb
+ - lib/neighbor/type/sqlite_vector.rb
  - lib/neighbor/type/vector.rb
  - lib/neighbor/utils.rb
  - lib/neighbor/version.rb
@@ -67,8 +77,8 @@ required_rubygems_version: !ruby/object:Gem::Requirement
  - !ruby/object:Gem::Version
  version: '0'
  requirements: []
- rubygems_version: 3.5.11
+ rubygems_version: 3.5.16
  signing_key:
  specification_version: 4
- summary: Nearest neighbor search for Rails and Postgres
+ summary: Nearest neighbor search for Rails
  test_files: []