neighbor 0.4.3 → 0.5.1
- checksums.yaml +4 -4
- data/CHANGELOG.md +14 -0
- data/README.md +242 -21
- data/lib/generators/neighbor/sqlite_generator.rb +13 -0
- data/lib/generators/neighbor/templates/sqlite.rb.tt +2 -0
- data/lib/neighbor/attribute.rb +48 -0
- data/lib/neighbor/model.rb +59 -68
- data/lib/neighbor/mysql.rb +37 -0
- data/lib/neighbor/normalized_attribute.rb +21 -0
- data/lib/neighbor/postgresql.rb +43 -0
- data/lib/neighbor/sqlite.rb +28 -0
- data/lib/neighbor/type/mysql_vector.rb +33 -0
- data/lib/neighbor/type/sqlite_int8_vector.rb +29 -0
- data/lib/neighbor/type/sqlite_vector.rb +29 -0
- data/lib/neighbor/utils.rb +160 -5
- data/lib/neighbor/version.rb +1 -1
- data/lib/neighbor.rb +13 -39
- metadata +16 -6
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: ee3864aa1511aa273c2d619e408bff8bcb491adc6dd111c5c88f7b79ef0baafc
|
4
|
+
data.tar.gz: 7d79f79814a0041e77d18edf83cb87b5fba5f59d1a25a71370d8b28c6a5cd289
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: c8d63b19a366d670212d6d5296f3b0f0de45e007635b7026bc9474bbd9248c9fc9a7ba4ee8b7c1b796ecdd7eef93460fe4e4e746eb49a8c19d471f4827530ecf
|
7
|
+
data.tar.gz: c371a55e3f1579b6a62fcc5cc55c5ac54cb318faa12278c736089b97de1990ecf56a17863bcc7043b4a2cde9be50cb02b8b7dcecaf5032d371e1675a062723f2
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,17 @@
|
|
1
|
+
## 0.5.1 (2024-12-03)
|
2
|
+
|
3
|
+
- Added experimental support for MariaDB 11.7
|
4
|
+
- Dropped experimental support for MariaDB 11.6 Vector
|
5
|
+
|
6
|
+
## 0.5.0 (2024-10-07)
|
7
|
+
|
8
|
+
- Added experimental support for SQLite (sqlite-vec)
|
9
|
+
- Added experimental support for MariaDB 11.6 Vector
|
10
|
+
- Added experimental support for MySQL 9
|
11
|
+
- Changed `normalize` option to use Active Record normalization
|
12
|
+
- Fixed connection leasing for Active Record 7.2
|
13
|
+
- Dropped support for Active Record < 7
|
14
|
+
|
1
15
|
## 0.4.3 (2024-09-02)
|
2
16
|
|
3
17
|
- Added `rrf` method
|
data/README.md
CHANGED
@@ -1,6 +1,13 @@
|
|
1
1
|
# Neighbor
|
2
2
|
|
3
|
-
Nearest neighbor search for Rails
|
3
|
+
Nearest neighbor search for Rails
|
4
|
+
|
5
|
+
Supports:
|
6
|
+
|
7
|
+
- Postgres (cube and pgvector)
|
8
|
+
- SQLite (sqlite-vec) - experimental
|
9
|
+
- MariaDB 11.7 - experimental
|
10
|
+
- MySQL 9 (searching requires HeatWave) - experimental
|
4
11
|
|
5
12
|
[](https://github.com/ankane/neighbor/actions)
|
6
13
|
|
@@ -12,7 +19,7 @@ Add this line to your application’s Gemfile:
|
|
12
19
|
gem "neighbor"
|
13
20
|
```
|
14
21
|
|
15
|
-
|
22
|
+
### For Postgres
|
16
23
|
|
17
24
|
Neighbor supports two extensions: [cube](https://www.postgresql.org/docs/current/cube.html) and [pgvector](https://github.com/pgvector/pgvector). cube ships with Postgres, while pgvector supports more dimensions and approximate nearest neighbor search.
|
18
25
|
|
@@ -30,16 +37,35 @@ rails generate neighbor:vector
|
|
30
37
|
rails db:migrate
|
31
38
|
```
|
32
39
|
|
40
|
+
### For SQLite
|
41
|
+
|
42
|
+
Add this line to your application’s Gemfile:
|
43
|
+
|
44
|
+
```ruby
|
45
|
+
gem "sqlite-vec"
|
46
|
+
```
|
47
|
+
|
48
|
+
And run:
|
49
|
+
|
50
|
+
```sh
|
51
|
+
rails generate neighbor:sqlite
|
52
|
+
```
|
53
|
+
|
33
54
|
## Getting Started
|
34
55
|
|
35
56
|
Create a migration
|
36
57
|
|
37
58
|
```ruby
|
38
|
-
class AddEmbeddingToItems < ActiveRecord::Migration[
|
59
|
+
class AddEmbeddingToItems < ActiveRecord::Migration[8.0]
|
39
60
|
def change
|
61
|
+
# cube
|
40
62
|
add_column :items, :embedding, :cube
|
41
|
-
|
63
|
+
|
64
|
+
# pgvector, MariaDB, and MySQL
|
42
65
|
add_column :items, :embedding, :vector, limit: 3 # dimensions
|
66
|
+
|
67
|
+
# sqlite-vec
|
68
|
+
add_column :items, :embedding, :binary
|
43
69
|
end
|
44
70
|
end
|
45
71
|
```
|
@@ -81,6 +107,9 @@ See the additional docs for:
|
|
81
107
|
|
82
108
|
- [cube](#cube)
|
83
109
|
- [pgvector](#pgvector)
|
110
|
+
- [sqlite-vec](#sqlite-vec)
|
111
|
+
- [MariaDB](#mariadb)
|
112
|
+
- [MySQL](#mysql)
|
84
113
|
|
85
114
|
Or check out some [examples](#examples)
|
86
115
|
|
@@ -134,12 +163,18 @@ Supported values are:
|
|
134
163
|
|
135
164
|
The `vector` type can have up to 16,000 dimensions, and vectors with up to 2,000 dimensions can be indexed.
|
136
165
|
|
166
|
+
The `halfvec` type can have up to 16,000 dimensions, and half vectors with up to 4,000 dimensions can be indexed.
|
167
|
+
|
168
|
+
The `bit` type can have up to 83 million dimensions, and bit vectors with up to 64,000 dimensions can be indexed.
|
169
|
+
|
170
|
+
The `sparsevec` type can have up to 16,000 non-zero elements, and sparse vectors with up to 1,000 non-zero elements can be indexed.
|
171
|
+
|
137
172
|
### Indexing
|
138
173
|
|
139
174
|
Add an approximate index to speed up queries. Create a migration with:
|
140
175
|
|
141
176
|
```ruby
|
142
|
-
class AddIndexToItemsEmbedding < ActiveRecord::Migration[
|
177
|
+
class AddIndexToItemsEmbedding < ActiveRecord::Migration[8.0]
|
143
178
|
def change
|
144
179
|
add_index :items, :embedding, using: :hnsw, opclass: :vector_l2_ops
|
145
180
|
# or
|
@@ -167,7 +202,7 @@ Item.connection.execute("SET ivfflat.probes = 3")
|
|
167
202
|
Use the `halfvec` type to store half-precision vectors
|
168
203
|
|
169
204
|
```ruby
|
170
|
-
class AddEmbeddingToItems < ActiveRecord::Migration[
|
205
|
+
class AddEmbeddingToItems < ActiveRecord::Migration[8.0]
|
171
206
|
def change
|
172
207
|
add_column :items, :embedding, :halfvec, limit: 3 # dimensions
|
173
208
|
end
|
@@ -179,7 +214,7 @@ end
|
|
179
214
|
Index vectors at half precision for smaller indexes
|
180
215
|
|
181
216
|
```ruby
|
182
|
-
class AddIndexToItemsEmbedding < ActiveRecord::Migration[
|
217
|
+
class AddIndexToItemsEmbedding < ActiveRecord::Migration[8.0]
|
183
218
|
def change
|
184
219
|
add_index :items, "(embedding::halfvec(3)) vector_l2_ops", using: :hnsw
|
185
220
|
end
|
@@ -197,7 +232,7 @@ Item.nearest_neighbors(:embedding, [0.9, 1.3, 1.1], distance: "euclidean", preci
|
|
197
232
|
Use the `bit` type to store binary vectors
|
198
233
|
|
199
234
|
```ruby
|
200
|
-
class AddEmbeddingToItems < ActiveRecord::Migration[
|
235
|
+
class AddEmbeddingToItems < ActiveRecord::Migration[8.0]
|
201
236
|
def change
|
202
237
|
add_column :items, :embedding, :bit, limit: 3 # dimensions
|
203
238
|
end
|
@@ -215,7 +250,7 @@ Item.nearest_neighbors(:embedding, "101", distance: "hamming").first(5)
|
|
215
250
|
Use expression indexing for binary quantization
|
216
251
|
|
217
252
|
```ruby
|
218
|
-
class AddIndexToItemsEmbedding < ActiveRecord::Migration[
|
253
|
+
class AddIndexToItemsEmbedding < ActiveRecord::Migration[8.0]
|
219
254
|
def change
|
220
255
|
add_index :items, "(binary_quantize(embedding)::bit(3)) bit_hamming_ops", using: :hnsw
|
221
256
|
end
|
@@ -227,7 +262,7 @@ end
|
|
227
262
|
Use the `sparsevec` type to store sparse vectors
|
228
263
|
|
229
264
|
```ruby
|
230
|
-
class AddEmbeddingToItems < ActiveRecord::Migration[
|
265
|
+
class AddEmbeddingToItems < ActiveRecord::Migration[8.0]
|
231
266
|
def change
|
232
267
|
add_column :items, :embedding, :sparsevec, limit: 3 # dimensions
|
233
268
|
end
|
@@ -241,6 +276,182 @@ embedding = Neighbor::SparseVector.new({0 => 0.9, 1 => 1.3, 2 => 1.1}, 3)
|
|
241
276
|
Item.nearest_neighbors(:embedding, embedding, distance: "euclidean").first(5)
|
242
277
|
```
|
243
278
|
|
279
|
+
## sqlite-vec
|
280
|
+
|
281
|
+
### Distance
|
282
|
+
|
283
|
+
Supported values are:
|
284
|
+
|
285
|
+
- `euclidean`
|
286
|
+
- `cosine`
|
287
|
+
- `taxicab`
|
288
|
+
- `hamming`
|
289
|
+
|
290
|
+
### Dimensions
|
291
|
+
|
292
|
+
For sqlite-vec, it’s a good idea to specify the number of dimensions to ensure all records have the same number.
|
293
|
+
|
294
|
+
```ruby
|
295
|
+
class Item < ApplicationRecord
|
296
|
+
has_neighbors :embedding, dimensions: 3
|
297
|
+
end
|
298
|
+
```
|
299
|
+
|
300
|
+
### Virtual Tables
|
301
|
+
|
302
|
+
You can also use [virtual tables](https://alexgarcia.xyz/sqlite-vec/features/knn.html)
|
303
|
+
|
304
|
+
```ruby
|
305
|
+
class AddEmbeddingToItems < ActiveRecord::Migration[8.0]
|
306
|
+
def change
|
307
|
+
# Rails 8+
|
308
|
+
create_virtual_table :items, :vec0, [
|
309
|
+
"embedding float[3] distance_metric=L2"
|
310
|
+
]
|
311
|
+
|
312
|
+
# Rails < 8
|
313
|
+
execute <<~SQL
|
314
|
+
CREATE VIRTUAL TABLE items USING vec0(
|
315
|
+
embedding float[3] distance_metric=L2
|
316
|
+
)
|
317
|
+
SQL
|
318
|
+
end
|
319
|
+
end
|
320
|
+
```
|
321
|
+
|
322
|
+
Use `distance_metric=cosine` for cosine distance
|
323
|
+
|
324
|
+
You can optionally ignore any shadow tables that are created
|
325
|
+
|
326
|
+
```ruby
|
327
|
+
ActiveRecord::SchemaDumper.ignore_tables += [
|
328
|
+
"items_chunks", "items_rowids", "items_vector_chunks00"
|
329
|
+
]
|
330
|
+
```
|
331
|
+
|
332
|
+
Create a model with `rowid` as the primary key
|
333
|
+
|
334
|
+
```ruby
|
335
|
+
class Item < ApplicationRecord
|
336
|
+
self.primary_key = "rowid"
|
337
|
+
|
338
|
+
has_neighbors :embedding, dimensions: 3
|
339
|
+
end
|
340
|
+
```
|
341
|
+
|
342
|
+
Get the `k` nearest neighbors
|
343
|
+
|
344
|
+
```ruby
|
345
|
+
Item.where("embedding MATCH ?", [1, 2, 3].to_s).where(k: 5).order(:distance)
|
346
|
+
```
|
347
|
+
|
348
|
+
Filter by primary key
|
349
|
+
|
350
|
+
```ruby
|
351
|
+
Item.where(rowid: [2, 3]).where("embedding MATCH ?", [1, 2, 3].to_s).where(k: 5).order(:distance)
|
352
|
+
```
|
353
|
+
|
354
|
+
### Int8 Vectors
|
355
|
+
|
356
|
+
Use the `type` option for int8 vectors
|
357
|
+
|
358
|
+
```ruby
|
359
|
+
class Item < ApplicationRecord
|
360
|
+
has_neighbors :embedding, dimensions: 3, type: :int8
|
361
|
+
end
|
362
|
+
```
|
363
|
+
|
364
|
+
### Binary Vectors
|
365
|
+
|
366
|
+
Use the `type` option for binary vectors
|
367
|
+
|
368
|
+
```ruby
|
369
|
+
class Item < ApplicationRecord
|
370
|
+
has_neighbors :embedding, dimensions: 8, type: :bit
|
371
|
+
end
|
372
|
+
```
|
373
|
+
|
374
|
+
Get the nearest neighbors by Hamming distance
|
375
|
+
|
376
|
+
```ruby
|
377
|
+
Item.nearest_neighbors(:embedding, "\x05", distance: "hamming").first(5)
|
378
|
+
```
|
379
|
+
|
380
|
+
## MariaDB
|
381
|
+
|
382
|
+
### Distance
|
383
|
+
|
384
|
+
Supported values are:
|
385
|
+
|
386
|
+
- `euclidean`
|
387
|
+
- `cosine`
|
388
|
+
- `hamming`
|
389
|
+
|
390
|
+
### Indexing
|
391
|
+
|
392
|
+
Vector columns must use `null: false` to add a vector index
|
393
|
+
|
394
|
+
```ruby
|
395
|
+
class CreateItems < ActiveRecord::Migration[8.0]
|
396
|
+
def change
|
397
|
+
create_table :items do |t|
|
398
|
+
t.vector :embedding, limit: 3, null: false
|
399
|
+
t.index :embedding, type: :vector
|
400
|
+
end
|
401
|
+
end
|
402
|
+
end
|
403
|
+
```
|
404
|
+
|
405
|
+
### Binary Vectors
|
406
|
+
|
407
|
+
Use the `bigint` type to store binary vectors
|
408
|
+
|
409
|
+
```ruby
|
410
|
+
class AddEmbeddingToItems < ActiveRecord::Migration[8.0]
|
411
|
+
def change
|
412
|
+
add_column :items, :embedding, :bigint
|
413
|
+
end
|
414
|
+
end
|
415
|
+
```
|
416
|
+
|
417
|
+
Note: Binary vectors can have up to 64 dimensions
|
418
|
+
|
419
|
+
Get the nearest neighbors by Hamming distance
|
420
|
+
|
421
|
+
```ruby
|
422
|
+
Item.nearest_neighbors(:embedding, 5, distance: "hamming").first(5)
|
423
|
+
```
|
424
|
+
|
425
|
+
## MySQL
|
426
|
+
|
427
|
+
### Distance
|
428
|
+
|
429
|
+
Supported values are:
|
430
|
+
|
431
|
+
- `euclidean`
|
432
|
+
- `cosine`
|
433
|
+
- `hamming`
|
434
|
+
|
435
|
+
Note: The `DISTANCE()` function is [only available on HeatWave](https://dev.mysql.com/doc/refman/9.0/en/vector-functions.html)
|
436
|
+
|
437
|
+
### Binary Vectors
|
438
|
+
|
439
|
+
Use the `binary` type to store binary vectors
|
440
|
+
|
441
|
+
```ruby
|
442
|
+
class AddEmbeddingToItems < ActiveRecord::Migration[8.0]
|
443
|
+
def change
|
444
|
+
add_column :items, :embedding, :binary
|
445
|
+
end
|
446
|
+
end
|
447
|
+
```
|
448
|
+
|
449
|
+
Get the nearest neighbors by Hamming distance
|
450
|
+
|
451
|
+
```ruby
|
452
|
+
Item.nearest_neighbors(:embedding, "\x05", distance: "hamming").first(5)
|
453
|
+
```
|
454
|
+
|
244
455
|
## Examples
|
245
456
|
|
246
457
|
- [Embeddings](#openai-embeddings) with OpenAI
|
@@ -472,12 +683,9 @@ end
|
|
472
683
|
Create some documents
|
473
684
|
|
474
685
|
```ruby
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
"The bear is growling"
|
479
|
-
]
|
480
|
-
documents = Document.create!(texts.map { |v| {content: v} })
|
686
|
+
Document.create!(content: "The dog is barking")
|
687
|
+
Document.create!(content: "The cat is purring")
|
688
|
+
Document.create!(content: "The bear is growling")
|
481
689
|
```
|
482
690
|
|
483
691
|
Generate an embedding for each document
|
@@ -485,9 +693,9 @@ Generate an embedding for each document
|
|
485
693
|
```ruby
|
486
694
|
embed = Informers.pipeline("embedding", "Snowflake/snowflake-arctic-embed-m-v1.5")
|
487
695
|
embed_options = {model_output: "sentence_embedding", pooling: "none"} # specific to embedding model
|
488
|
-
embeddings = embed.(documents.map(&:content), **embed_options)
|
489
696
|
|
490
|
-
|
697
|
+
Document.find_each do |document|
|
698
|
+
embedding = embed.(document.content, **embed_options)
|
491
699
|
document.update!(embedding: embedding)
|
492
700
|
end
|
493
701
|
```
|
@@ -511,7 +719,7 @@ semantic_results =
|
|
511
719
|
To combine the results, use Reciprocal Rank Fusion (RRF)
|
512
720
|
|
513
721
|
```ruby
|
514
|
-
Neighbor::Reranking.rrf(keyword_results, semantic_results)
|
722
|
+
Neighbor::Reranking.rrf(keyword_results, semantic_results).first(5)
|
515
723
|
```
|
516
724
|
|
517
725
|
Or a reranking model
|
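Note on the hunk above: Reciprocal Rank Fusion combines several ranked lists by summing `1 / (k + rank)` for each item across the lists. The following is a minimal plain-Ruby sketch of that idea, not the gem's implementation of `Neighbor::Reranking.rrf`. The helper name `rrf_sketch`, the `k: 60` default, and the sample data are assumptions made for this illustration.

```ruby
# Hypothetical helper for illustration only; not part of the neighbor gem.
# Scores each item by summing 1 / (k + rank) over every ranked list it appears in.
def rrf_sketch(*ranked_lists, k: 60)
  scores = Hash.new(0.0)
  ranked_lists.each do |list|
    list.each_with_index do |item, index|
      scores[item] += 1.0 / (k + index + 1) # ranks start at 1
    end
  end
  # highest fused score first
  scores.sort_by { |_, score| -score }.map(&:first)
end

# made-up result lists standing in for keyword and semantic search results
keyword_results = ["dog", "cat", "bear"]
semantic_results = ["cat", "bear", "dog"]
fused = rrf_sketch(keyword_results, semantic_results)
```

Items that rank well in both lists rise to the top, which is why the README can take `.first(5)` of the fused result.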
@@ -519,7 +727,7 @@ Or a reranking model
|
|
519
727
|
```ruby
|
520
728
|
rerank = Informers.pipeline("reranking", "mixedbread-ai/mxbai-rerank-xsmall-v1")
|
521
729
|
results = (keyword_results + semantic_results).uniq
|
522
|
-
rerank.(query, results.map(&:content)
|
730
|
+
rerank.(query, results.map(&:content)).first(5).map { |v| results[v[:doc_id]] }
|
523
731
|
```
|
524
732
|
|
525
733
|
See the [complete code](examples/hybrid/example.rb)
|
@@ -667,6 +875,19 @@ To get started with development:
|
|
667
875
|
git clone https://github.com/ankane/neighbor.git
|
668
876
|
cd neighbor
|
669
877
|
bundle install
|
878
|
+
|
879
|
+
# Postgres
|
670
880
|
createdb neighbor_test
|
671
|
-
bundle exec rake test
|
881
|
+
bundle exec rake test:postgresql
|
882
|
+
|
883
|
+
# SQLite
|
884
|
+
bundle exec rake test:sqlite
|
885
|
+
|
886
|
+
# MariaDB
|
887
|
+
docker run -e MARIADB_ALLOW_EMPTY_ROOT_PASSWORD=1 -e MARIADB_DATABASE=neighbor_test -p 3307:3306 mariadb:11.7-rc
|
888
|
+
bundle exec rake test:mariadb
|
889
|
+
|
890
|
+
# MySQL
|
891
|
+
docker run -e MYSQL_ALLOW_EMPTY_PASSWORD=1 -e MYSQL_DATABASE=neighbor_test -p 3306:3306 mysql:9
|
892
|
+
bundle exec rake test:mysql
|
672
893
|
```
|
data/lib/generators/neighbor/sqlite_generator.rb
ADDED
@@ -0,0 +1,13 @@
|
|
1
|
+
require "rails/generators"
|
2
|
+
|
3
|
+
module Neighbor
|
4
|
+
module Generators
|
5
|
+
class SqliteGenerator < Rails::Generators::Base
|
6
|
+
source_root File.join(__dir__, "templates")
|
7
|
+
|
8
|
+
def copy_templates
|
9
|
+
template "sqlite.rb", "config/initializers/neighbor.rb"
|
10
|
+
end
|
11
|
+
end
|
12
|
+
end
|
13
|
+
end
|
data/lib/neighbor/attribute.rb
ADDED
@@ -0,0 +1,48 @@
|
|
1
|
+
module Neighbor
|
2
|
+
class Attribute < ActiveRecord::Type::Value
|
3
|
+
delegate :type, :serialize, :deserialize, :cast, to: :new_cast_type
|
4
|
+
|
5
|
+
def initialize(cast_type:, model:, type:, attribute_name:)
|
6
|
+
@cast_type = cast_type
|
7
|
+
@model = model
|
8
|
+
@type = type
|
9
|
+
@attribute_name = attribute_name
|
10
|
+
end
|
11
|
+
|
12
|
+
private
|
13
|
+
|
14
|
+
def cast_value(...)
|
15
|
+
new_cast_type.send(:cast_value, ...)
|
16
|
+
end
|
17
|
+
|
18
|
+
def new_cast_type
|
19
|
+
@new_cast_type ||= begin
|
20
|
+
if @cast_type.is_a?(ActiveModel::Type::Value)
|
21
|
+
case Utils.adapter(@model)
|
22
|
+
when :sqlite
|
23
|
+
case @type&.to_sym
|
24
|
+
when :int8
|
25
|
+
Type::SqliteInt8Vector.new
|
26
|
+
when :bit
|
27
|
+
@cast_type
|
28
|
+
when :float32, nil
|
29
|
+
Type::SqliteVector.new
|
30
|
+
else
|
31
|
+
raise ArgumentError, "Unsupported type"
|
32
|
+
end
|
33
|
+
when :mariadb
|
34
|
+
if @model.columns_hash[@attribute_name.to_s]&.type == :integer
|
35
|
+
@cast_type
|
36
|
+
else
|
37
|
+
Type::MysqlVector.new
|
38
|
+
end
|
39
|
+
else
|
40
|
+
@cast_type
|
41
|
+
end
|
42
|
+
else
|
43
|
+
@cast_type
|
44
|
+
end
|
45
|
+
end
|
46
|
+
end
|
47
|
+
end
|
48
|
+
end
|
data/lib/neighbor/model.rb
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
module Neighbor
|
2
2
|
module Model
|
3
|
-
def has_neighbors(*attribute_names, dimensions: nil, normalize: nil)
|
3
|
+
def has_neighbors(*attribute_names, dimensions: nil, normalize: nil, type: nil)
|
4
4
|
if attribute_names.empty?
|
5
5
|
raise ArgumentError, "has_neighbors requires an attribute name"
|
6
6
|
end
|
@@ -24,125 +24,116 @@ module Neighbor
|
|
24
24
|
|
25
25
|
attribute_names.each do |attribute_name|
|
26
26
|
raise Error, "has_neighbors already called for #{attribute_name.inspect}" if neighbor_attributes[attribute_name]
|
27
|
-
@neighbor_attributes[attribute_name] = {dimensions: dimensions, normalize: normalize}
|
27
|
+
@neighbor_attributes[attribute_name] = {dimensions: dimensions, normalize: normalize, type: type&.to_sym}
|
28
|
+
end
|
29
|
+
|
30
|
+
if ActiveRecord::VERSION::STRING.to_f >= 7.2
|
31
|
+
decorate_attributes(attribute_names) do |name, cast_type|
|
32
|
+
Neighbor::Attribute.new(cast_type: cast_type, model: self, type: type, attribute_name: name)
|
33
|
+
end
|
34
|
+
else
|
35
|
+
attribute_names.each do |attribute_name|
|
36
|
+
attribute attribute_name do |cast_type|
|
37
|
+
Neighbor::Attribute.new(cast_type: cast_type, model: self, type: type, attribute_name: attribute_name)
|
38
|
+
end
|
39
|
+
end
|
40
|
+
end
|
41
|
+
|
42
|
+
if normalize
|
43
|
+
if ActiveRecord::VERSION::STRING.to_f >= 7.1
|
44
|
+
attribute_names.each do |attribute_name|
|
45
|
+
normalizes attribute_name, with: ->(v) { Neighbor::Utils.normalize(v, column_info: columns_hash[attribute_name.to_s]) }
|
46
|
+
end
|
47
|
+
else
|
48
|
+
attribute_names.each do |attribute_name|
|
49
|
+
attribute attribute_name do |cast_type|
|
50
|
+
Neighbor::NormalizedAttribute.new(cast_type: cast_type, model: self, attribute_name: attribute_name)
|
51
|
+
end
|
52
|
+
end
|
53
|
+
end
|
28
54
|
end
|
29
55
|
|
30
56
|
return if @neighbor_attributes.size != attribute_names.size
|
31
57
|
|
32
58
|
validate do
|
59
|
+
adapter = Utils.adapter(self.class)
|
60
|
+
|
33
61
|
self.class.neighbor_attributes.each do |k, v|
|
34
62
|
value = read_attribute(k)
|
35
63
|
next if value.nil?
|
36
64
|
|
37
65
|
column_info = self.class.columns_hash[k.to_s]
|
38
|
-
dimensions = v[:dimensions]
|
66
|
+
dimensions = v[:dimensions]
|
67
|
+
dimensions ||= column_info&.limit unless column_info&.type == :binary
|
68
|
+
type = v[:type] || Utils.type(adapter, column_info&.type)
|
39
69
|
|
40
|
-
if !Neighbor::Utils.validate_dimensions(value,
|
70
|
+
if !Neighbor::Utils.validate_dimensions(value, type, dimensions, adapter).nil?
|
41
71
|
errors.add(k, "must have #{dimensions} dimensions")
|
42
72
|
end
|
43
|
-
if !Neighbor::Utils.validate_finite(value,
|
73
|
+
if !Neighbor::Utils.validate_finite(value, type)
|
44
74
|
errors.add(k, "must have finite values")
|
45
75
|
end
|
46
76
|
end
|
47
77
|
end
|
48
78
|
|
49
|
-
|
50
|
-
before_save do
|
51
|
-
self.class.neighbor_attributes.each do |k, v|
|
52
|
-
next unless v[:normalize] && attribute_changed?(k)
|
53
|
-
value = read_attribute(k)
|
54
|
-
next if value.nil?
|
55
|
-
self[k] = Neighbor::Utils.normalize(value, column_info: self.class.columns_hash[k.to_s])
|
56
|
-
end
|
57
|
-
end
|
58
|
-
|
59
|
-
# cannot use keyword arguments with scope with Ruby 3.2 and Active Record 6.1
|
60
|
-
# https://github.com/rails/rails/issues/46934
|
61
|
-
scope :nearest_neighbors, ->(attribute_name, vector, options = nil) {
|
62
|
-
raise ArgumentError, "missing keyword: :distance" unless options.is_a?(Hash) && options.key?(:distance)
|
63
|
-
distance = options.delete(:distance)
|
64
|
-
precision = options.delete(:precision)
|
65
|
-
raise ArgumentError, "unknown keywords: #{options.keys.map(&:inspect).join(", ")}" if options.any?
|
66
|
-
|
79
|
+
scope :nearest_neighbors, ->(attribute_name, vector, distance:, precision: nil) {
|
67
80
|
attribute_name = attribute_name.to_sym
|
68
81
|
options = neighbor_attributes[attribute_name]
|
69
82
|
raise ArgumentError, "Invalid attribute" unless options
|
70
83
|
normalize = options[:normalize]
|
71
84
|
dimensions = options[:dimensions]
|
85
|
+
type = options[:type]
|
72
86
|
|
73
87
|
return none if vector.nil?
|
74
88
|
|
75
89
|
distance = distance.to_s
|
76
90
|
|
77
|
-
quoted_attribute = "#{connection.quote_table_name(table_name)}.#{connection.quote_column_name(attribute_name)}"
|
78
|
-
|
79
91
|
column_info = columns_hash[attribute_name.to_s]
|
80
92
|
column_type = column_info&.type
|
81
93
|
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
when "hamming"
|
87
|
-
"<~>"
|
88
|
-
when "jaccard"
|
89
|
-
"<%>"
|
90
|
-
when "hamming2"
|
91
|
-
"#"
|
92
|
-
end
|
93
|
-
when :vector, :halfvec, :sparsevec
|
94
|
-
case distance
|
95
|
-
when "inner_product"
|
96
|
-
"<#>"
|
97
|
-
when "cosine"
|
98
|
-
"<=>"
|
99
|
-
when "euclidean"
|
100
|
-
"<->"
|
101
|
-
when "taxicab"
|
102
|
-
"<+>"
|
103
|
-
end
|
104
|
-
when :cube
|
105
|
-
case distance
|
106
|
-
when "taxicab"
|
107
|
-
"<#>"
|
108
|
-
when "chebyshev"
|
109
|
-
"<=>"
|
110
|
-
when "euclidean", "cosine"
|
111
|
-
"<->"
|
112
|
-
end
|
113
|
-
else
|
114
|
-
raise ArgumentError, "Unsupported type: #{column_type}"
|
115
|
-
end
|
94
|
+
adapter = Neighbor::Utils.adapter(klass)
|
95
|
+
if type && adapter != :sqlite
|
96
|
+
raise ArgumentError, "type only works with SQLite"
|
97
|
+
end
|
116
98
|
|
99
|
+
operator = Neighbor::Utils.operator(adapter, column_type, distance)
|
117
100
|
raise ArgumentError, "Invalid distance: #{distance}" unless operator
|
118
101
|
|
119
102
|
# ensure normalize set (can be true or false)
|
120
|
-
|
103
|
+
normalize_required = Utils.normalize_required?(adapter, column_type)
|
104
|
+
if distance == "cosine" && normalize_required && normalize.nil?
|
121
105
|
raise Neighbor::Error, "Set normalize for cosine distance with cube"
|
122
106
|
end
|
123
107
|
|
124
108
|
column_attribute = klass.type_for_attribute(attribute_name)
|
125
109
|
vector = column_attribute.cast(vector)
|
126
|
-
|
110
|
+
dimensions ||= column_info&.limit unless column_info&.type == :binary
|
111
|
+
Neighbor::Utils.validate(vector, dimensions: dimensions, type: type || Utils.type(adapter, column_info&.type), adapter: adapter)
|
127
112
|
vector = Neighbor::Utils.normalize(vector, column_info: column_info) if normalize
|
128
113
|
|
129
|
-
|
114
|
+
quoted_attribute = nil
|
115
|
+
query = nil
|
116
|
+
connection_pool.with_connection do |c|
|
117
|
+
quoted_attribute = "#{c.quote_table_name(table_name)}.#{c.quote_column_name(attribute_name)}"
|
118
|
+
query = c.quote(column_attribute.serialize(vector))
|
119
|
+
end
|
130
120
|
|
131
121
|
if !precision.nil?
|
122
|
+
if adapter != :postgresql || column_type != :vector
|
123
|
+
raise ArgumentError, "Precision not supported for this type"
|
124
|
+
end
|
125
|
+
|
132
126
|
case precision.to_s
|
133
127
|
when "half"
|
134
128
|
cast_dimensions = dimensions || column_info&.limit
|
135
129
|
raise ArgumentError, "Unknown dimensions" unless cast_dimensions
|
136
|
-
quoted_attribute += "::halfvec(#{
|
130
|
+
quoted_attribute += "::halfvec(#{connection_pool.with_connection { |c| c.quote(cast_dimensions.to_i) }})"
|
137
131
|
else
|
138
132
|
raise ArgumentError, "Invalid precision"
|
139
133
|
end
|
140
134
|
end
|
141
135
|
|
142
|
-
order =
|
143
|
-
if operator == "#"
|
144
|
-
order = "bit_count(#{order})"
|
145
|
-
end
|
136
|
+
order = Utils.order(adapter, type, operator, quoted_attribute, query)
|
146
137
|
|
147
138
|
# https://stats.stackexchange.com/questions/146221/is-cosine-similarity-identical-to-l2-normalized-euclidean-distance
|
148
139
|
# with normalized vectors:
|
@@ -150,7 +141,7 @@ module Neighbor
|
|
150
141
|
# cosine distance = 1 - cosine similarity
|
151
142
|
# this transformation doesn't change the order, so only needed for select
|
152
143
|
neighbor_distance =
|
153
|
-
if
|
144
|
+
if distance == "cosine" && normalize_required
|
154
145
|
"POWER(#{order}, 2) / 2.0"
|
155
146
|
elsif [:vector, :halfvec, :sparsevec].include?(column_type) && distance == "inner_product"
|
156
147
|
"(#{order}) * -1"
|
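Note on the hunk above: the `POWER(order, 2) / 2.0` expression relies on the identity quoted in the code comments, namely that for normalized vectors the cosine distance equals the squared Euclidean distance divided by two. The plain-Ruby check below is a sketch of that identity only. The helper names and sample vectors are assumptions for illustration and are not part of the gem.

```ruby
# Hypothetical helpers for illustration only; not part of the neighbor gem.
def unit_vector(v)
  norm = Math.sqrt(v.sum { |x| x * x })
  v.map { |x| x / norm }
end

def euclidean_distance(a, b)
  Math.sqrt(a.zip(b).sum { |x, y| (x - y)**2 })
end

def cosine_distance(a, b)
  dot = a.zip(b).sum { |x, y| x * y }
  norm_a = Math.sqrt(a.sum { |x| x * x })
  norm_b = Math.sqrt(b.sum { |x| x * x })
  1 - dot / (norm_a * norm_b)
end

a = unit_vector([0.9, 1.3, 1.1])
b = unit_vector([1.0, 2.0, 3.0])

# same transformation as the SQL expression: distance squared, divided by two
from_euclidean = euclidean_distance(a, b)**2 / 2.0
direct = cosine_distance(a, b)
```

Both values agree up to floating-point error, which is why ordering by Euclidean distance is enough and the conversion is only needed for the selected distance value.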
data/lib/neighbor/mysql.rb
ADDED
@@ -0,0 +1,37 @@
|
|
1
|
+
module Neighbor
|
2
|
+
module MySQL
|
3
|
+
def self.initialize!
|
4
|
+
require_relative "type/mysql_vector"
|
5
|
+
|
6
|
+
require "active_record/connection_adapters/abstract_mysql_adapter"
|
7
|
+
|
8
|
+
# ensure schema can be dumped
|
9
|
+
ActiveRecord::ConnectionAdapters::AbstractMysqlAdapter::NATIVE_DATABASE_TYPES[:vector] = {name: "vector"}
|
10
|
+
|
11
|
+
# ensure schema can be loaded
|
12
|
+
unless ActiveRecord::ConnectionAdapters::TableDefinition.method_defined?(:vector)
|
13
|
+
ActiveRecord::ConnectionAdapters::TableDefinition.send(:define_column_methods, :vector)
|
14
|
+
end
|
15
|
+
|
16
|
+
# prevent unknown OID warning
|
17
|
+
ActiveRecord::ConnectionAdapters::AbstractMysqlAdapter.singleton_class.prepend(RegisterTypes)
|
18
|
+
if ActiveRecord::VERSION::STRING.to_f < 7.1
|
19
|
+
ActiveRecord::ConnectionAdapters::AbstractMysqlAdapter.register_vector_type(ActiveRecord::ConnectionAdapters::AbstractMysqlAdapter::TYPE_MAP)
|
20
|
+
end
|
21
|
+
end
|
22
|
+
|
23
|
+
module RegisterTypes
|
24
|
+
def initialize_type_map(m)
|
25
|
+
super
|
26
|
+
register_vector_type(m)
|
27
|
+
end
|
28
|
+
|
29
|
+
def register_vector_type(m)
|
30
|
+
m.register_type %r(^vector)i do |sql_type|
|
31
|
+
limit = extract_limit(sql_type)
|
32
|
+
Type::MysqlVector.new(limit: limit)
|
33
|
+
end
|
34
|
+
end
|
35
|
+
end
|
36
|
+
end
|
37
|
+
end
|
data/lib/neighbor/normalized_attribute.rb
ADDED
@@ -0,0 +1,21 @@
|
|
1
|
+
module Neighbor
|
2
|
+
class NormalizedAttribute < ActiveRecord::Type::Value
|
3
|
+
delegate :type, :serialize, :deserialize, to: :@cast_type
|
4
|
+
|
5
|
+
def initialize(cast_type:, model:, attribute_name:)
|
6
|
+
@cast_type = cast_type
|
7
|
+
@model = model
|
8
|
+
@attribute_name = attribute_name.to_s
|
9
|
+
end
|
10
|
+
|
11
|
+
def cast(...)
|
12
|
+
Neighbor::Utils.normalize(@cast_type.cast(...), column_info: @model.columns_hash[@attribute_name])
|
13
|
+
end
|
14
|
+
|
15
|
+
private
|
16
|
+
|
17
|
+
def cast_value(...)
|
18
|
+
@cast_type.send(:cast_value, ...)
|
19
|
+
end
|
20
|
+
end
|
21
|
+
end
|
data/lib/neighbor/postgresql.rb
ADDED
@@ -0,0 +1,43 @@
|
|
1
|
+
module Neighbor
|
2
|
+
module PostgreSQL
|
3
|
+
def self.initialize!
|
4
|
+
require_relative "type/cube"
|
5
|
+
require_relative "type/halfvec"
|
6
|
+
require_relative "type/sparsevec"
|
7
|
+
require_relative "type/vector"
|
8
|
+
|
9
|
+
require "active_record/connection_adapters/postgresql_adapter"
|
10
|
+
|
11
|
+
# ensure schema can be dumped
|
12
|
+
ActiveRecord::ConnectionAdapters::PostgreSQLAdapter::NATIVE_DATABASE_TYPES[:cube] = {name: "cube"}
|
13
|
+
ActiveRecord::ConnectionAdapters::PostgreSQLAdapter::NATIVE_DATABASE_TYPES[:halfvec] = {name: "halfvec"}
|
14
|
+
ActiveRecord::ConnectionAdapters::PostgreSQLAdapter::NATIVE_DATABASE_TYPES[:sparsevec] = {name: "sparsevec"}
|
15
|
+
ActiveRecord::ConnectionAdapters::PostgreSQLAdapter::NATIVE_DATABASE_TYPES[:vector] = {name: "vector"}
|
16
|
+
|
17
|
+
# ensure schema can be loaded
|
18
|
+
ActiveRecord::ConnectionAdapters::TableDefinition.send(:define_column_methods, :cube, :halfvec, :sparsevec, :vector)
|
19
|
+
|
20
|
+
# prevent unknown OID warning
|
21
|
+
ActiveRecord::ConnectionAdapters::PostgreSQLAdapter.singleton_class.prepend(RegisterTypes)
|
22
|
+
end
|
23
|
+
|
24
|
+
module RegisterTypes
|
25
|
+
def initialize_type_map(m = type_map)
|
26
|
+
super
|
27
|
+
m.register_type "cube", Type::Cube.new
|
28
|
+
m.register_type "halfvec" do |_, _, sql_type|
|
29
|
+
limit = extract_limit(sql_type)
|
30
|
+
Type::Halfvec.new(limit: limit)
|
31
|
+
end
|
32
|
+
m.register_type "sparsevec" do |_, _, sql_type|
|
33
|
+
limit = extract_limit(sql_type)
|
34
|
+
Type::Sparsevec.new(limit: limit)
|
35
|
+
end
|
36
|
+
m.register_type "vector" do |_, _, sql_type|
|
37
|
+
limit = extract_limit(sql_type)
|
38
|
+
Type::Vector.new(limit: limit)
|
39
|
+
end
|
40
|
+
end
|
41
|
+
end
|
42
|
+
end
|
43
|
+
end
|
data/lib/neighbor/sqlite.rb
ADDED
@@ -0,0 +1,28 @@
|
|
1
|
+
module Neighbor
|
2
|
+
module SQLite
|
3
|
+
# note: this is a public API (unlike PostgreSQL and MySQL)
|
4
|
+
def self.initialize!
|
5
|
+
return if defined?(@initialized)
|
6
|
+
|
7
|
+
require_relative "type/sqlite_vector"
|
8
|
+
require_relative "type/sqlite_int8_vector"
|
9
|
+
|
10
|
+
require "sqlite_vec"
|
11
|
+
require "active_record/connection_adapters/sqlite3_adapter"
|
12
|
+
|
13
|
+
ActiveRecord::ConnectionAdapters::SQLite3Adapter.prepend(InstanceMethods)
|
14
|
+
|
15
|
+
@initialized = true
|
16
|
+
end
|
17
|
+
|
18
|
+
module InstanceMethods
|
19
|
+
def configure_connection
|
20
|
+
super
|
21
|
+
db = ActiveRecord::VERSION::STRING.to_f >= 7.1 ? @raw_connection : @connection
|
22
|
+
db.enable_load_extension(1)
|
23
|
+
SqliteVec.load(db)
|
24
|
+
db.enable_load_extension(0)
|
25
|
+
end
|
26
|
+
end
|
27
|
+
end
|
28
|
+
end
|
data/lib/neighbor/type/mysql_vector.rb
ADDED
@@ -0,0 +1,33 @@
|
|
1
|
+
module Neighbor
|
2
|
+
module Type
|
3
|
+
class MysqlVector < ActiveRecord::Type::Binary
|
4
|
+
def type
|
5
|
+
:vector
|
6
|
+
end
|
7
|
+
|
8
|
+
def serialize(value)
|
9
|
+
if Utils.array?(value)
|
10
|
+
value = value.to_a.pack("e*")
|
11
|
+
end
|
12
|
+
super(value)
|
13
|
+
end
|
14
|
+
|
15
|
+
def deserialize(value)
|
16
|
+
value = super
|
17
|
+
cast_value(value) unless value.nil?
|
18
|
+
end
|
19
|
+
|
20
|
+
private
|
21
|
+
|
22
|
+
def cast_value(value)
|
23
|
+
if value.is_a?(String)
|
24
|
+
value.unpack("e*")
|
25
|
+
elsif Utils.array?(value)
|
26
|
+
value.to_a
|
27
|
+
else
|
28
|
+
raise "can't cast #{value.class.name} to vector"
|
29
|
+
end
|
30
|
+
end
|
31
|
+
end
|
32
|
+
end
|
33
|
+
end
|
data/lib/neighbor/type/sqlite_int8_vector.rb
ADDED
@@ -0,0 +1,29 @@
|
|
1
|
+
module Neighbor
|
2
|
+
module Type
|
3
|
+
class SqliteInt8Vector < ActiveRecord::Type::Binary
|
4
|
+
def serialize(value)
|
5
|
+
if Utils.array?(value)
|
6
|
+
value = value.to_a.pack("c*")
|
7
|
+
end
|
8
|
+
super(value)
|
9
|
+
end
|
10
|
+
|
11
|
+
def deserialize(value)
|
12
|
+
value = super
|
13
|
+
cast_value(value) unless value.nil?
|
14
|
+
end
|
15
|
+
|
16
|
+
private
|
17
|
+
|
18
|
+
def cast_value(value)
|
19
|
+
if value.is_a?(String)
|
20
|
+
value.unpack("c*")
|
21
|
+
elsif Utils.array?(value)
|
22
|
+
value.to_a
|
23
|
+
else
|
24
|
+
raise "can't cast #{value.class.name} to vector"
|
25
|
+
end
|
26
|
+
end
|
27
|
+
end
|
28
|
+
end
|
29
|
+
end
|
data/lib/neighbor/type/sqlite_vector.rb
ADDED
@@ -0,0 +1,29 @@
|
|
1
|
+
module Neighbor
|
2
|
+
module Type
|
3
|
+
class SqliteVector < ActiveRecord::Type::Binary
|
4
|
+
def serialize(value)
|
5
|
+
if Utils.array?(value)
|
6
|
+
value = value.to_a.pack("f*")
|
7
|
+
end
|
8
|
+
super(value)
|
9
|
+
end
|
10
|
+
|
11
|
+
def deserialize(value)
|
12
|
+
value = super
|
13
|
+
cast_value(value) unless value.nil?
|
14
|
+
end
|
15
|
+
|
16
|
+
private
|
17
|
+
|
18
|
+
def cast_value(value)
|
19
|
+
if value.is_a?(String)
|
20
|
+
value.unpack("f*")
|
21
|
+
elsif Utils.array?(value)
|
22
|
+
value.to_a
|
23
|
+
else
|
24
|
+
raise "can't cast #{value.class.name} to vector"
|
25
|
+
end
|
26
|
+
end
|
27
|
+
end
|
28
|
+
end
|
29
|
+
end
|
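The three binary type classes above differ only in the `pack`/`unpack` directive they use: `"e*"` in the 33-line type file (presumably `mysql_vector.rb`, going by the file list), `"c*"` in `SqliteInt8Vector`, and `"f*"` in `SqliteVector`. The following is a minimal standalone sketch of those round trips using only core Ruby; the variable names are illustrative and not part of the gem.

```ruby
# Values chosen so they are exactly representable in single precision.
floats = [1.0, -2.5, 0.25]
int8s  = [1, -2, 127]

# "e*": single-precision floats, little-endian byte order.
le_blob = floats.pack("e*")

# "f*": single-precision floats, native byte order of the host.
native_blob = floats.pack("f*")

# "c*": signed 8-bit integers, one byte per element.
int8_blob = int8s.pack("c*")

puts le_blob.bytesize      # 4 bytes per float
puts int8_blob.bytesize    # 1 byte per integer

# Unpacking with the same directive restores the original arrays.
p le_blob.unpack("e*")
p native_blob.unpack("f*")
p int8_blob.unpack("c*")
```

Values that are not exactly representable as 32-bit floats (for example `0.1`) come back slightly changed after a round trip, so comparisons on deserialized vectors may need a tolerance.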
data/lib/neighbor/utils.rb
CHANGED
@@ -1,7 +1,9 @@
|
|
1
1
|
module Neighbor
|
2
2
|
module Utils
|
3
|
-
def self.validate_dimensions(value, type, expected)
|
3
|
+
def self.validate_dimensions(value, type, expected, adapter)
|
4
4
|
dimensions = type == :sparsevec ? value.dimensions : value.size
|
5
|
+
dimensions *= 8 if type == :bit && [:sqlite, :mysql].include?(adapter)
|
6
|
+
|
5
7
|
if expected && dimensions != expected
|
6
8
|
"Expected #{expected} dimensions, not #{dimensions}"
|
7
9
|
end
|
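The hunk above multiplies the measured size by 8 for `:bit` values on the `:sqlite` and `:mysql` adapters, presumably because those values are stored as packed bytes rather than as one element per bit. A rough standalone sketch of that check follows; `dimension_error` is a hypothetical name, and the `:sparsevec` branch is left out for brevity.

```ruby
# Hypothetical restatement of the dimension check; not the gem's API.
# Returns an error message string, or nil when the dimensions match.
def dimension_error(value, type, expected, adapter)
  dimensions = value.size
  # Assumption: on these adapters a :bit value is a byte string,
  # so each element of the value carries 8 dimensions.
  dimensions *= 8 if type == :bit && [:sqlite, :mysql].include?(adapter)

  if expected && dimensions != expected
    "Expected #{expected} dimensions, not #{dimensions}"
  end
end

# 16 bits packed into a 2-byte binary string.
packed = ["1010101011110000"].pack("B*")

p dimension_error(packed, :bit, 16, :sqlite)      # nil (2 bytes * 8 = 16)
p dimension_error(packed, :bit, 16, :postgresql)  # message reporting 2 dimensions
```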
@@ -9,7 +11,7 @@ module Neighbor
|
|
9
11
|
|
10
12
|
def self.validate_finite(value, type)
|
11
13
|
case type
|
12
|
-
when :bit
|
14
|
+
when :bit, :integer
|
13
15
|
true
|
14
16
|
when :sparsevec
|
15
17
|
value.values.all?(&:finite?)
|
@@ -18,17 +20,19 @@ module Neighbor
|
|
18
20
|
end
|
19
21
|
end
|
20
22
|
|
21
|
-
def self.validate(value, dimensions:,
|
22
|
-
if (message = validate_dimensions(value,
|
23
|
+
def self.validate(value, dimensions:, type:, adapter:)
|
24
|
+
if (message = validate_dimensions(value, type, dimensions, adapter))
|
23
25
|
raise Error, message
|
24
26
|
end
|
25
27
|
|
26
|
-
if !validate_finite(value,
|
28
|
+
if !validate_finite(value, type)
|
27
29
|
raise Error, "Values must be finite"
|
28
30
|
end
|
29
31
|
end
|
30
32
|
|
31
33
|
def self.normalize(value, column_info:)
|
34
|
+
return nil if value.nil?
|
35
|
+
|
32
36
|
raise Error, "Normalize not supported for type" unless [:cube, :vector, :halfvec].include?(column_info&.type)
|
33
37
|
|
34
38
|
norm = Math.sqrt(value.sum { |v| v * v })
|
@@ -42,5 +46,156 @@ module Neighbor
|
|
42
46
|
def self.array?(value)
|
43
47
|
!value.nil? && value.respond_to?(:to_a)
|
44
48
|
end
|
49
|
+
|
50
|
+
def self.adapter(model)
|
51
|
+
case model.connection_db_config.adapter
|
52
|
+
when /sqlite/i
|
53
|
+
:sqlite
|
54
|
+
when /mysql|trilogy/i
|
55
|
+
model.connection_pool.with_connection { |c| c.try(:mariadb?) } ? :mariadb : :mysql
|
56
|
+
else
|
57
|
+
:postgresql
|
58
|
+
end
|
59
|
+
end
|
60
|
+
|
61
|
+
def self.type(adapter, column_type)
|
62
|
+
case adapter
|
63
|
+
when :mysql
|
64
|
+
if column_type == :binary
|
65
|
+
:bit
|
66
|
+
else
|
67
|
+
column_type
|
68
|
+
end
|
69
|
+
else
|
70
|
+
column_type
|
71
|
+
end
|
72
|
+
end
|
73
|
+
|
74
|
+
def self.operator(adapter, column_type, distance)
|
75
|
+
case adapter
|
76
|
+
when :sqlite
|
77
|
+
case distance
|
78
|
+
when "euclidean"
|
79
|
+
"vec_distance_L2"
|
80
|
+
when "cosine"
|
81
|
+
"vec_distance_cosine"
|
82
|
+
when "taxicab"
|
83
|
+
"vec_distance_L1"
|
84
|
+
when "hamming"
|
85
|
+
"vec_distance_hamming"
|
86
|
+
end
|
87
|
+
when :mariadb
|
88
|
+
case column_type
|
89
|
+
when :vector
|
90
|
+
case distance
|
91
|
+
when "euclidean"
|
92
|
+
"VEC_DISTANCE_EUCLIDEAN"
|
93
|
+
when "cosine"
|
94
|
+
"VEC_DISTANCE_COSINE"
|
95
|
+
end
|
96
|
+
when :integer
|
97
|
+
case distance
|
98
|
+
when "hamming"
|
99
|
+
"BIT_COUNT"
|
100
|
+
end
|
101
|
+
else
|
102
|
+
raise ArgumentError, "Unsupported type: #{column_type}"
|
103
|
+
end
|
104
|
+
when :mysql
|
105
|
+
case column_type
|
106
|
+
when :vector
|
107
|
+
case distance
|
108
|
+
when "cosine"
|
109
|
+
"COSINE"
|
110
|
+
when "euclidean"
|
111
|
+
"EUCLIDEAN"
|
112
|
+
end
|
113
|
+
when :binary
|
114
|
+
case distance
|
115
|
+
when "hamming"
|
116
|
+
"BIT_COUNT"
|
117
|
+
end
|
118
|
+
else
|
119
|
+
raise ArgumentError, "Unsupported type: #{column_type}"
|
120
|
+
end
|
121
|
+
else
|
122
|
+
case column_type
|
123
|
+
when :bit
|
124
|
+
case distance
|
125
|
+
when "hamming"
|
126
|
+
"<~>"
|
127
|
+
when "jaccard"
|
128
|
+
"<%>"
|
129
|
+
when "hamming2"
|
130
|
+
"#"
|
131
|
+
end
|
132
|
+
when :vector, :halfvec, :sparsevec
|
133
|
+
case distance
|
134
|
+
when "inner_product"
|
135
|
+
"<#>"
|
136
|
+
when "cosine"
|
137
|
+
"<=>"
|
138
|
+
when "euclidean"
|
139
|
+
"<->"
|
140
|
+
when "taxicab"
|
141
|
+
"<+>"
|
142
|
+
end
|
143
|
+
when :cube
|
144
|
+
case distance
|
145
|
+
when "taxicab"
|
146
|
+
"<#>"
|
147
|
+
when "chebyshev"
|
148
|
+
"<=>"
|
149
|
+
when "euclidean", "cosine"
|
150
|
+
"<->"
|
151
|
+
end
|
152
|
+
else
|
153
|
+
raise ArgumentError, "Unsupported type: #{column_type}"
|
154
|
+
end
|
155
|
+
end
|
156
|
+
end
|
157
|
+
|
158
|
+
def self.order(adapter, type, operator, quoted_attribute, query)
|
159
|
+
case adapter
|
160
|
+
when :sqlite
|
161
|
+
case type
|
162
|
+
when :int8
|
163
|
+
"#{operator}(vec_int8(#{quoted_attribute}), vec_int8(#{query}))"
|
164
|
+
when :bit
|
165
|
+
"#{operator}(vec_bit(#{quoted_attribute}), vec_bit(#{query}))"
|
166
|
+
else
|
167
|
+
"#{operator}(#{quoted_attribute}, #{query})"
|
168
|
+
end
|
169
|
+
when :mariadb
|
170
|
+
if operator == "BIT_COUNT"
|
171
|
+
"BIT_COUNT(#{quoted_attribute} ^ #{query})"
|
172
|
+
else
|
173
|
+
"#{operator}(#{quoted_attribute}, #{query})"
|
174
|
+
end
|
175
|
+
when :mysql
|
176
|
+
if operator == "BIT_COUNT"
|
177
|
+
"BIT_COUNT(#{quoted_attribute} ^ #{query})"
|
178
|
+
elsif operator == "COSINE"
|
179
|
+
"DISTANCE(#{quoted_attribute}, #{query}, 'COSINE')"
|
180
|
+
else
|
181
|
+
"DISTANCE(#{quoted_attribute}, #{query}, 'EUCLIDEAN')"
|
182
|
+
end
|
183
|
+
else
|
184
|
+
if operator == "#"
|
185
|
+
"bit_count(#{quoted_attribute} # #{query})"
|
186
|
+
else
|
187
|
+
"#{quoted_attribute} #{operator} #{query}"
|
188
|
+
end
|
189
|
+
end
|
190
|
+
end
|
191
|
+
|
192
|
+
def self.normalize_required?(adapter, column_type)
|
193
|
+
case adapter
|
194
|
+
when :postgresql
|
195
|
+
column_type == :cube
|
196
|
+
else
|
197
|
+
false
|
198
|
+
end
|
199
|
+
end
|
45
200
|
end
|
46
201
|
end
|
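To make the new `operator` and `order` helpers easier to follow, here is a trimmed-down sketch of how the SQLite branches combine into an ordering expression. `SQLITE_FUNCTIONS` and `sqlite_order` are hypothetical names used only for illustration. Unlike the code in the diff, which returns `nil` for an unrecognized distance, this sketch raises `KeyError`.

```ruby
# Distance name to sqlite-vec function name, copied from the SQLite branch above.
SQLITE_FUNCTIONS = {
  "euclidean" => "vec_distance_L2",
  "cosine"    => "vec_distance_cosine",
  "taxicab"   => "vec_distance_L1",
  "hamming"   => "vec_distance_hamming"
}.freeze

# Hypothetical helper: builds the SQL ordering expression for one column.
def sqlite_order(type, distance, quoted_attribute, query)
  function = SQLITE_FUNCTIONS.fetch(distance)
  case type
  when :int8
    # int8 vectors are wrapped in vec_int8 on both sides
    "#{function}(vec_int8(#{quoted_attribute}), vec_int8(#{query}))"
  when :bit
    # bit vectors are wrapped in vec_bit on both sides
    "#{function}(vec_bit(#{quoted_attribute}), vec_bit(#{query}))"
  else
    "#{function}(#{quoted_attribute}, #{query})"
  end
end

puts sqlite_order(:float32, "cosine", '"embedding"', "?")
puts sqlite_order(:int8, "euclidean", '"embedding"', "?")
puts sqlite_order(:bit, "hamming", '"embedding"', "?")
```

The same pattern applies to the other adapters in the diff: MariaDB and SQLite use function-call syntax, MySQL passes the metric as a third argument to `DISTANCE`, and PostgreSQL uses infix operators.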
data/lib/neighbor/version.rb
CHANGED
data/lib/neighbor.rb
CHANGED
@@ -1,6 +1,11 @@
|
|
1
1
|
# dependencies
|
2
2
|
require "active_support"
|
3
3
|
|
4
|
+
# adapter hooks
|
5
|
+
require_relative "neighbor/mysql"
|
6
|
+
require_relative "neighbor/postgresql"
|
7
|
+
require_relative "neighbor/sqlite"
|
8
|
+
|
4
9
|
# modules
|
5
10
|
require_relative "neighbor/reranking"
|
6
11
|
require_relative "neighbor/sparse_vector"
|
@@ -9,53 +14,22 @@ require_relative "neighbor/version"
|
|
9
14
|
|
10
15
|
module Neighbor
|
11
16
|
class Error < StandardError; end
|
12
|
-
|
13
|
-
module RegisterTypes
|
14
|
-
def initialize_type_map(m = type_map)
|
15
|
-
super
|
16
|
-
m.register_type "cube", Type::Cube.new
|
17
|
-
m.register_type "halfvec" do |_, _, sql_type|
|
18
|
-
limit = extract_limit(sql_type)
|
19
|
-
Type::Halfvec.new(limit: limit)
|
20
|
-
end
|
21
|
-
m.register_type "sparsevec" do |_, _, sql_type|
|
22
|
-
limit = extract_limit(sql_type)
|
23
|
-
Type::Sparsevec.new(limit: limit)
|
24
|
-
end
|
25
|
-
m.register_type "vector" do |_, _, sql_type|
|
26
|
-
limit = extract_limit(sql_type)
|
27
|
-
Type::Vector.new(limit: limit)
|
28
|
-
end
|
29
|
-
end
|
30
|
-
end
|
31
17
|
end
|
32
18
|
|
33
19
|
ActiveSupport.on_load(:active_record) do
|
20
|
+
require_relative "neighbor/attribute"
|
34
21
|
require_relative "neighbor/model"
|
35
|
-
require_relative "neighbor/
|
36
|
-
require_relative "neighbor/type/halfvec"
|
37
|
-
require_relative "neighbor/type/sparsevec"
|
38
|
-
require_relative "neighbor/type/vector"
|
22
|
+
require_relative "neighbor/normalized_attribute"
|
39
23
|
|
40
24
|
extend Neighbor::Model
|
41
25
|
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
ActiveRecord::ConnectionAdapters::PostgreSQLAdapter::NATIVE_DATABASE_TYPES[:halfvec] = {name: "halfvec"}
|
47
|
-
ActiveRecord::ConnectionAdapters::PostgreSQLAdapter::NATIVE_DATABASE_TYPES[:sparsevec] = {name: "sparsevec"}
|
48
|
-
ActiveRecord::ConnectionAdapters::PostgreSQLAdapter::NATIVE_DATABASE_TYPES[:vector] = {name: "vector"}
|
49
|
-
|
50
|
-
# ensure schema can be loaded
|
51
|
-
ActiveRecord::ConnectionAdapters::TableDefinition.send(:define_column_methods, :cube, :halfvec, :sparsevec, :vector)
|
52
|
-
|
53
|
-
# prevent unknown OID warning
|
54
|
-
if ActiveRecord::VERSION::MAJOR >= 7
|
55
|
-
ActiveRecord::ConnectionAdapters::PostgreSQLAdapter.singleton_class.prepend(Neighbor::RegisterTypes)
|
56
|
-
else
|
57
|
-
ActiveRecord::ConnectionAdapters::PostgreSQLAdapter.prepend(Neighbor::RegisterTypes)
|
26
|
+
begin
|
27
|
+
Neighbor::PostgreSQL.initialize!
|
28
|
+
rescue Gem::LoadError
|
29
|
+
# tries to load pg gem, which may not be available
|
58
30
|
end
|
31
|
+
|
32
|
+
Neighbor::MySQL.initialize!
|
59
33
|
end
|
60
34
|
|
61
35
|
require_relative "neighbor/railtie" if defined?(Rails::Railtie)
|
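The `begin`/`rescue Gem::LoadError` block in the new `lib/neighbor.rb` lets the PostgreSQL setup be skipped when the `pg` gem is not available. A minimal sketch of that optional-dependency pattern is shown below; `try_optional_setup` is a hypothetical helper, and the missing gem is simulated by raising `Gem::LoadError` directly.

```ruby
require "rubygems"

# Hypothetical helper: runs an optional setup step and reports whether it ran.
def try_optional_setup
  yield
  true
rescue Gem::LoadError
  # The gem backing this adapter is not available; skip its setup.
  false
end

# Simulated failure, standing in for an adapter whose gem is not installed.
skipped = try_optional_setup { raise Gem::LoadError, "pg is not available" }

# A setup step that loads fine.
loaded = try_optional_setup { :initialized }

p skipped  # false
p loaded   # true
```

Rescuing only `Gem::LoadError` keeps other failures during setup visible instead of silently swallowing them.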
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: neighbor
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.5.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2024-
|
11
|
+
date: 2024-12-03 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: activerecord
|
@@ -16,14 +16,14 @@ dependencies:
|
|
16
16
|
requirements:
|
17
17
|
- - ">="
|
18
18
|
- !ruby/object:Gem::Version
|
19
|
-
version: '
|
19
|
+
version: '7'
|
20
20
|
type: :runtime
|
21
21
|
prerelease: false
|
22
22
|
version_requirements: !ruby/object:Gem::Requirement
|
23
23
|
requirements:
|
24
24
|
- - ">="
|
25
25
|
- !ruby/object:Gem::Version
|
26
|
-
version: '
|
26
|
+
version: '7'
|
27
27
|
description:
|
28
28
|
email: andrew@ankane.org
|
29
29
|
executables: []
|
@@ -34,17 +34,27 @@ files:
|
|
34
34
|
- LICENSE.txt
|
35
35
|
- README.md
|
36
36
|
- lib/generators/neighbor/cube_generator.rb
|
37
|
+
- lib/generators/neighbor/sqlite_generator.rb
|
37
38
|
- lib/generators/neighbor/templates/cube.rb.tt
|
39
|
+
- lib/generators/neighbor/templates/sqlite.rb.tt
|
38
40
|
- lib/generators/neighbor/templates/vector.rb.tt
|
39
41
|
- lib/generators/neighbor/vector_generator.rb
|
40
42
|
- lib/neighbor.rb
|
43
|
+
- lib/neighbor/attribute.rb
|
41
44
|
- lib/neighbor/model.rb
|
45
|
+
- lib/neighbor/mysql.rb
|
46
|
+
- lib/neighbor/normalized_attribute.rb
|
47
|
+
- lib/neighbor/postgresql.rb
|
42
48
|
- lib/neighbor/railtie.rb
|
43
49
|
- lib/neighbor/reranking.rb
|
44
50
|
- lib/neighbor/sparse_vector.rb
|
51
|
+
- lib/neighbor/sqlite.rb
|
45
52
|
- lib/neighbor/type/cube.rb
|
46
53
|
- lib/neighbor/type/halfvec.rb
|
54
|
+
- lib/neighbor/type/mysql_vector.rb
|
47
55
|
- lib/neighbor/type/sparsevec.rb
|
56
|
+
- lib/neighbor/type/sqlite_int8_vector.rb
|
57
|
+
- lib/neighbor/type/sqlite_vector.rb
|
48
58
|
- lib/neighbor/type/vector.rb
|
49
59
|
- lib/neighbor/utils.rb
|
50
60
|
- lib/neighbor/version.rb
|
@@ -67,8 +77,8 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
67
77
|
- !ruby/object:Gem::Version
|
68
78
|
version: '0'
|
69
79
|
requirements: []
|
70
|
-
rubygems_version: 3.5.
|
80
|
+
rubygems_version: 3.5.22
|
71
81
|
signing_key:
|
72
82
|
specification_version: 4
|
73
|
-
summary: Nearest neighbor search for Rails
|
83
|
+
summary: Nearest neighbor search for Rails
|
74
84
|
test_files: []
|