neighbor 0.4.3 → 0.5.1
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +14 -0
- data/README.md +242 -21
- data/lib/generators/neighbor/sqlite_generator.rb +13 -0
- data/lib/generators/neighbor/templates/sqlite.rb.tt +2 -0
- data/lib/neighbor/attribute.rb +48 -0
- data/lib/neighbor/model.rb +59 -68
- data/lib/neighbor/mysql.rb +37 -0
- data/lib/neighbor/normalized_attribute.rb +21 -0
- data/lib/neighbor/postgresql.rb +43 -0
- data/lib/neighbor/sqlite.rb +28 -0
- data/lib/neighbor/type/mysql_vector.rb +33 -0
- data/lib/neighbor/type/sqlite_int8_vector.rb +29 -0
- data/lib/neighbor/type/sqlite_vector.rb +29 -0
- data/lib/neighbor/utils.rb +160 -5
- data/lib/neighbor/version.rb +1 -1
- data/lib/neighbor.rb +13 -39
- metadata +16 -6
checksums.yaml
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
---
|
2
2
|
SHA256:
|
3
|
-
metadata.gz:
|
4
|
-
data.tar.gz:
|
3
|
+
metadata.gz: ee3864aa1511aa273c2d619e408bff8bcb491adc6dd111c5c88f7b79ef0baafc
|
4
|
+
data.tar.gz: 7d79f79814a0041e77d18edf83cb87b5fba5f59d1a25a71370d8b28c6a5cd289
|
5
5
|
SHA512:
|
6
|
-
metadata.gz:
|
7
|
-
data.tar.gz:
|
6
|
+
metadata.gz: c8d63b19a366d670212d6d5296f3b0f0de45e007635b7026bc9474bbd9248c9fc9a7ba4ee8b7c1b796ecdd7eef93460fe4e4e746eb49a8c19d471f4827530ecf
|
7
|
+
data.tar.gz: c371a55e3f1579b6a62fcc5cc55c5ac54cb318faa12278c736089b97de1990ecf56a17863bcc7043b4a2cde9be50cb02b8b7dcecaf5032d371e1675a062723f2
|
data/CHANGELOG.md
CHANGED
@@ -1,3 +1,17 @@
|
|
1
|
+
## 0.5.1 (2024-12-03)
|
2
|
+
|
3
|
+
- Added experimental support for MariaDB 11.7
|
4
|
+
- Dropped experimental support for MariaDB 11.6 Vector
|
5
|
+
|
6
|
+
## 0.5.0 (2024-10-07)
|
7
|
+
|
8
|
+
- Added experimental support for SQLite (sqlite-vec)
|
9
|
+
- Added experimental support for MariaDB 11.6 Vector
|
10
|
+
- Added experimental support for MySQL 9
|
11
|
+
- Changed `normalize` option to use Active Record normalization
|
12
|
+
- Fixed connection leasing for Active Record 7.2
|
13
|
+
- Dropped support for Active Record < 7
|
14
|
+
|
1
15
|
## 0.4.3 (2024-09-02)
|
2
16
|
|
3
17
|
- Added `rrf` method
|
data/README.md
CHANGED
@@ -1,6 +1,13 @@
|
|
1
1
|
# Neighbor
|
2
2
|
|
3
|
-
Nearest neighbor search for Rails
|
3
|
+
Nearest neighbor search for Rails
|
4
|
+
|
5
|
+
Supports:
|
6
|
+
|
7
|
+
- Postgres (cube and pgvector)
|
8
|
+
- SQLite (sqlite-vec) - experimental
|
9
|
+
- MariaDB 11.7 - experimental
|
10
|
+
- MySQL 9 (searching requires HeatWave) - experimental
|
4
11
|
|
5
12
|
[![Build Status](https://github.com/ankane/neighbor/actions/workflows/build.yml/badge.svg)](https://github.com/ankane/neighbor/actions)
|
6
13
|
|
@@ -12,7 +19,7 @@ Add this line to your application’s Gemfile:
|
|
12
19
|
gem "neighbor"
|
13
20
|
```
|
14
21
|
|
15
|
-
|
22
|
+
### For Postgres
|
16
23
|
|
17
24
|
Neighbor supports two extensions: [cube](https://www.postgresql.org/docs/current/cube.html) and [pgvector](https://github.com/pgvector/pgvector). cube ships with Postgres, while pgvector supports more dimensions and approximate nearest neighbor search.
|
18
25
|
|
@@ -30,16 +37,35 @@ rails generate neighbor:vector
|
|
30
37
|
rails db:migrate
|
31
38
|
```
|
32
39
|
|
40
|
+
### For SQLite
|
41
|
+
|
42
|
+
Add this line to your application’s Gemfile:
|
43
|
+
|
44
|
+
```ruby
|
45
|
+
gem "sqlite-vec"
|
46
|
+
```
|
47
|
+
|
48
|
+
And run:
|
49
|
+
|
50
|
+
```sh
|
51
|
+
rails generate neighbor:sqlite
|
52
|
+
```
|
53
|
+
|
33
54
|
## Getting Started
|
34
55
|
|
35
56
|
Create a migration
|
36
57
|
|
37
58
|
```ruby
|
38
|
-
class AddEmbeddingToItems < ActiveRecord::Migration[
|
59
|
+
class AddEmbeddingToItems < ActiveRecord::Migration[8.0]
|
39
60
|
def change
|
61
|
+
# cube
|
40
62
|
add_column :items, :embedding, :cube
|
41
|
-
|
63
|
+
|
64
|
+
# pgvector, MariaDB, and MySQL
|
42
65
|
add_column :items, :embedding, :vector, limit: 3 # dimensions
|
66
|
+
|
67
|
+
# sqlite-vec
|
68
|
+
add_column :items, :embedding, :binary
|
43
69
|
end
|
44
70
|
end
|
45
71
|
```
|
@@ -81,6 +107,9 @@ See the additional docs for:
|
|
81
107
|
|
82
108
|
- [cube](#cube)
|
83
109
|
- [pgvector](#pgvector)
|
110
|
+
- [sqlite-vec](#sqlite-vec)
|
111
|
+
- [MariaDB](#mariadb)
|
112
|
+
- [MySQL](#mysql)
|
84
113
|
|
85
114
|
Or check out some [examples](#examples)
|
86
115
|
|
@@ -134,12 +163,18 @@ Supported values are:
|
|
134
163
|
|
135
164
|
The `vector` type can have up to 16,000 dimensions, and vectors with up to 2,000 dimensions can be indexed.
|
136
165
|
|
166
|
+
The `halfvec` type can have up to 16,000 dimensions, and half vectors with up to 4,000 dimensions can be indexed.
|
167
|
+
|
168
|
+
The `bit` type can have up to 83 million dimensions, and bit vectors with up to 64,000 dimensions can be indexed.
|
169
|
+
|
170
|
+
The `sparsevec` type can have up to 16,000 non-zero elements, and sparse vectors with up to 1,000 non-zero elements can be indexed.
|
171
|
+
|
137
172
|
### Indexing
|
138
173
|
|
139
174
|
Add an approximate index to speed up queries. Create a migration with:
|
140
175
|
|
141
176
|
```ruby
|
142
|
-
class AddIndexToItemsEmbedding < ActiveRecord::Migration[
|
177
|
+
class AddIndexToItemsEmbedding < ActiveRecord::Migration[8.0]
|
143
178
|
def change
|
144
179
|
add_index :items, :embedding, using: :hnsw, opclass: :vector_l2_ops
|
145
180
|
# or
|
@@ -167,7 +202,7 @@ Item.connection.execute("SET ivfflat.probes = 3")
|
|
167
202
|
Use the `halfvec` type to store half-precision vectors
|
168
203
|
|
169
204
|
```ruby
|
170
|
-
class AddEmbeddingToItems < ActiveRecord::Migration[
|
205
|
+
class AddEmbeddingToItems < ActiveRecord::Migration[8.0]
|
171
206
|
def change
|
172
207
|
add_column :items, :embedding, :halfvec, limit: 3 # dimensions
|
173
208
|
end
|
@@ -179,7 +214,7 @@ end
|
|
179
214
|
Index vectors at half precision for smaller indexes
|
180
215
|
|
181
216
|
```ruby
|
182
|
-
class AddIndexToItemsEmbedding < ActiveRecord::Migration[
|
217
|
+
class AddIndexToItemsEmbedding < ActiveRecord::Migration[8.0]
|
183
218
|
def change
|
184
219
|
add_index :items, "(embedding::halfvec(3)) vector_l2_ops", using: :hnsw
|
185
220
|
end
|
@@ -197,7 +232,7 @@ Item.nearest_neighbors(:embedding, [0.9, 1.3, 1.1], distance: "euclidean", preci
|
|
197
232
|
Use the `bit` type to store binary vectors
|
198
233
|
|
199
234
|
```ruby
|
200
|
-
class AddEmbeddingToItems < ActiveRecord::Migration[
|
235
|
+
class AddEmbeddingToItems < ActiveRecord::Migration[8.0]
|
201
236
|
def change
|
202
237
|
add_column :items, :embedding, :bit, limit: 3 # dimensions
|
203
238
|
end
|
@@ -215,7 +250,7 @@ Item.nearest_neighbors(:embedding, "101", distance: "hamming").first(5)
|
|
215
250
|
Use expression indexing for binary quantization
|
216
251
|
|
217
252
|
```ruby
|
218
|
-
class AddIndexToItemsEmbedding < ActiveRecord::Migration[
|
253
|
+
class AddIndexToItemsEmbedding < ActiveRecord::Migration[8.0]
|
219
254
|
def change
|
220
255
|
add_index :items, "(binary_quantize(embedding)::bit(3)) bit_hamming_ops", using: :hnsw
|
221
256
|
end
|
@@ -227,7 +262,7 @@ end
|
|
227
262
|
Use the `sparsevec` type to store sparse vectors
|
228
263
|
|
229
264
|
```ruby
|
230
|
-
class AddEmbeddingToItems < ActiveRecord::Migration[
|
265
|
+
class AddEmbeddingToItems < ActiveRecord::Migration[8.0]
|
231
266
|
def change
|
232
267
|
add_column :items, :embedding, :sparsevec, limit: 3 # dimensions
|
233
268
|
end
|
@@ -241,6 +276,182 @@ embedding = Neighbor::SparseVector.new({0 => 0.9, 1 => 1.3, 2 => 1.1}, 3)
|
|
241
276
|
Item.nearest_neighbors(:embedding, embedding, distance: "euclidean").first(5)
|
242
277
|
```
|
243
278
|
|
279
|
+
## sqlite-vec
|
280
|
+
|
281
|
+
### Distance
|
282
|
+
|
283
|
+
Supported values are:
|
284
|
+
|
285
|
+
- `euclidean`
|
286
|
+
- `cosine`
|
287
|
+
- `taxicab`
|
288
|
+
- `hamming`
|
289
|
+
|
290
|
+
### Dimensions
|
291
|
+
|
292
|
+
For sqlite-vec, it’s a good idea to specify the number of dimensions so Neighbor can check that all records have the same number of dimensions.
|
293
|
+
|
294
|
+
```ruby
|
295
|
+
class Item < ApplicationRecord
|
296
|
+
has_neighbors :embedding, dimensions: 3
|
297
|
+
end
|
298
|
+
```
|
299
|
+
|
300
|
+
### Virtual Tables
|
301
|
+
|
302
|
+
You can also use [virtual tables](https://alexgarcia.xyz/sqlite-vec/features/knn.html)
|
303
|
+
|
304
|
+
```ruby
|
305
|
+
class AddEmbeddingToItems < ActiveRecord::Migration[8.0]
|
306
|
+
def change
|
307
|
+
# Rails 8+
|
308
|
+
create_virtual_table :items, :vec0, [
|
309
|
+
"embedding float[3] distance_metric=L2"
|
310
|
+
]
|
311
|
+
|
312
|
+
# Rails < 8
|
313
|
+
execute <<~SQL
|
314
|
+
CREATE VIRTUAL TABLE items USING vec0(
|
315
|
+
embedding float[3] distance_metric=L2
|
316
|
+
)
|
317
|
+
SQL
|
318
|
+
end
|
319
|
+
end
|
320
|
+
```
|
321
|
+
|
322
|
+
Use `distance_metric=cosine` for cosine distance
|
323
|
+
|
324
|
+
You can optionally ignore any shadow tables that are created
|
325
|
+
|
326
|
+
```ruby
|
327
|
+
ActiveRecord::SchemaDumper.ignore_tables += [
|
328
|
+
"items_chunks", "items_rowids", "items_vector_chunks00"
|
329
|
+
]
|
330
|
+
```
|
331
|
+
|
332
|
+
Create a model with `rowid` as the primary key
|
333
|
+
|
334
|
+
```ruby
|
335
|
+
class Item < ApplicationRecord
|
336
|
+
self.primary_key = "rowid"
|
337
|
+
|
338
|
+
has_neighbors :embedding, dimensions: 3
|
339
|
+
end
|
340
|
+
```
|
341
|
+
|
342
|
+
Get the `k` nearest neighbors
|
343
|
+
|
344
|
+
```ruby
|
345
|
+
Item.where("embedding MATCH ?", [1, 2, 3].to_s).where(k: 5).order(:distance)
|
346
|
+
```
|
347
|
+
|
348
|
+
Filter by primary key
|
349
|
+
|
350
|
+
```ruby
|
351
|
+
Item.where(rowid: [2, 3]).where("embedding MATCH ?", [1, 2, 3].to_s).where(k: 5).order(:distance)
|
352
|
+
```
|
353
|
+
|
354
|
+
### Int8 Vectors
|
355
|
+
|
356
|
+
Use the `type` option for int8 vectors
|
357
|
+
|
358
|
+
```ruby
|
359
|
+
class Item < ApplicationRecord
|
360
|
+
has_neighbors :embedding, dimensions: 3, type: :int8
|
361
|
+
end
|
362
|
+
```
|
363
|
+
|
364
|
+
### Binary Vectors
|
365
|
+
|
366
|
+
Use the `type` option for binary vectors
|
367
|
+
|
368
|
+
```ruby
|
369
|
+
class Item < ApplicationRecord
|
370
|
+
has_neighbors :embedding, dimensions: 8, type: :bit
|
371
|
+
end
|
372
|
+
```
|
373
|
+
|
374
|
+
Get the nearest neighbors by Hamming distance
|
375
|
+
|
376
|
+
```ruby
|
377
|
+
Item.nearest_neighbors(:embedding, "\x05", distance: "hamming").first(5)
|
378
|
+
```
|
379
|
+
|
380
|
+
## MariaDB
|
381
|
+
|
382
|
+
### Distance
|
383
|
+
|
384
|
+
Supported values are:
|
385
|
+
|
386
|
+
- `euclidean`
|
387
|
+
- `cosine`
|
388
|
+
- `hamming`
|
389
|
+
|
390
|
+
### Indexing
|
391
|
+
|
392
|
+
Vector columns must use `null: false` to add a vector index
|
393
|
+
|
394
|
+
```ruby
|
395
|
+
class CreateItems < ActiveRecord::Migration[8.0]
|
396
|
+
def change
|
397
|
+
create_table :items do |t|
|
398
|
+
t.vector :embedding, limit: 3, null: false
|
399
|
+
t.index :embedding, type: :vector
|
400
|
+
end
|
401
|
+
end
|
402
|
+
end
|
403
|
+
```
|
404
|
+
|
405
|
+
### Binary Vectors
|
406
|
+
|
407
|
+
Use the `bigint` type to store binary vectors
|
408
|
+
|
409
|
+
```ruby
|
410
|
+
class AddEmbeddingToItems < ActiveRecord::Migration[8.0]
|
411
|
+
def change
|
412
|
+
add_column :items, :embedding, :bigint
|
413
|
+
end
|
414
|
+
end
|
415
|
+
```
|
416
|
+
|
417
|
+
Note: Binary vectors can have up to 64 dimensions
|
418
|
+
|
419
|
+
Get the nearest neighbors by Hamming distance
|
420
|
+
|
421
|
+
```ruby
|
422
|
+
Item.nearest_neighbors(:embedding, 5, distance: "hamming").first(5)
|
423
|
+
```
|
424
|
+
|
425
|
+
## MySQL
|
426
|
+
|
427
|
+
### Distance
|
428
|
+
|
429
|
+
Supported values are:
|
430
|
+
|
431
|
+
- `euclidean`
|
432
|
+
- `cosine`
|
433
|
+
- `hamming`
|
434
|
+
|
435
|
+
Note: The `DISTANCE()` function is [only available on HeatWave](https://dev.mysql.com/doc/refman/9.0/en/vector-functions.html)
|
436
|
+
|
437
|
+
### Binary Vectors
|
438
|
+
|
439
|
+
Use the `binary` type to store binary vectors
|
440
|
+
|
441
|
+
```ruby
|
442
|
+
class AddEmbeddingToItems < ActiveRecord::Migration[8.0]
|
443
|
+
def change
|
444
|
+
add_column :items, :embedding, :binary
|
445
|
+
end
|
446
|
+
end
|
447
|
+
```
|
448
|
+
|
449
|
+
Get the nearest neighbors by Hamming distance
|
450
|
+
|
451
|
+
```ruby
|
452
|
+
Item.nearest_neighbors(:embedding, "\x05", distance: "hamming").first(5)
|
453
|
+
```
|
454
|
+
|
244
455
|
## Examples
|
245
456
|
|
246
457
|
- [Embeddings](#openai-embeddings) with OpenAI
|
@@ -472,12 +683,9 @@ end
|
|
472
683
|
Create some documents
|
473
684
|
|
474
685
|
```ruby
|
475
|
-
|
476
|
-
|
477
|
-
|
478
|
-
"The bear is growling"
|
479
|
-
]
|
480
|
-
documents = Document.create!(texts.map { |v| {content: v} })
|
686
|
+
Document.create!(content: "The dog is barking")
|
687
|
+
Document.create!(content: "The cat is purring")
|
688
|
+
Document.create!(content: "The bear is growling")
|
481
689
|
```
|
482
690
|
|
483
691
|
Generate an embedding for each document
|
@@ -485,9 +693,9 @@ Generate an embedding for each document
|
|
485
693
|
```ruby
|
486
694
|
embed = Informers.pipeline("embedding", "Snowflake/snowflake-arctic-embed-m-v1.5")
|
487
695
|
embed_options = {model_output: "sentence_embedding", pooling: "none"} # specific to embedding model
|
488
|
-
embeddings = embed.(documents.map(&:content), **embed_options)
|
489
696
|
|
490
|
-
|
697
|
+
Document.find_each do |document|
|
698
|
+
embedding = embed.(document.content, **embed_options)
|
491
699
|
document.update!(embedding: embedding)
|
492
700
|
end
|
493
701
|
```
|
@@ -511,7 +719,7 @@ semantic_results =
|
|
511
719
|
To combine the results, use Reciprocal Rank Fusion (RRF)
|
512
720
|
|
513
721
|
```ruby
|
514
|
-
Neighbor::Reranking.rrf(keyword_results, semantic_results)
|
722
|
+
Neighbor::Reranking.rrf(keyword_results, semantic_results).first(5)
|
515
723
|
```
|
516
724
|
|
517
725
|
Or a reranking model
|
@@ -519,7 +727,7 @@ Or a reranking model
|
|
519
727
|
```ruby
|
520
728
|
rerank = Informers.pipeline("reranking", "mixedbread-ai/mxbai-rerank-xsmall-v1")
|
521
729
|
results = (keyword_results + semantic_results).uniq
|
522
|
-
rerank.(query, results.map(&:content)
|
730
|
+
rerank.(query, results.map(&:content)).first(5).map { |v| results[v[:doc_id]] }
|
523
731
|
```
|
524
732
|
|
525
733
|
See the [complete code](examples/hybrid/example.rb)
|
@@ -667,6 +875,19 @@ To get started with development:
|
|
667
875
|
git clone https://github.com/ankane/neighbor.git
|
668
876
|
cd neighbor
|
669
877
|
bundle install
|
878
|
+
|
879
|
+
# Postgres
|
670
880
|
createdb neighbor_test
|
671
|
-
bundle exec rake test
|
881
|
+
bundle exec rake test:postgresql
|
882
|
+
|
883
|
+
# SQLite
|
884
|
+
bundle exec rake test:sqlite
|
885
|
+
|
886
|
+
# MariaDB
|
887
|
+
docker run -e MARIADB_ALLOW_EMPTY_ROOT_PASSWORD=1 -e MARIADB_DATABASE=neighbor_test -p 3307:3306 mariadb:11.7-rc
|
888
|
+
bundle exec rake test:mariadb
|
889
|
+
|
890
|
+
# MySQL
|
891
|
+
docker run -e MYSQL_ALLOW_EMPTY_PASSWORD=1 -e MYSQL_DATABASE=neighbor_test -p 3306:3306 mysql:9
|
892
|
+
bundle exec rake test:mysql
|
672
893
|
```
|
@@ -0,0 +1,13 @@
|
|
1
|
+
require "rails/generators"
|
2
|
+
|
3
|
+
module Neighbor
  module Generators
    # Generator behind `rails generate neighbor:sqlite`.
    #
    # Renders the bundled `sqlite.rb` template (templates/sqlite.rb.tt) into the
    # application's initializers directory.
    class SqliteGenerator < Rails::Generators::Base
      source_root File.expand_path("templates", __dir__)

      # Writes config/initializers/neighbor.rb from the sqlite template.
      def copy_templates
        destination = "config/initializers/neighbor.rb"
        template("sqlite.rb", destination)
      end
    end
  end
end
|
@@ -0,0 +1,48 @@
|
|
1
|
+
module Neighbor
  # Attribute type decorator installed by `has_neighbors`.
  #
  # Wraps the type Active Record would otherwise use for the attribute and,
  # depending on the database adapter and the `type:` option, substitutes a
  # Neighbor vector type. `type`, `serialize`, `deserialize`, and `cast` are
  # forwarded to whichever type is resolved.
  class Attribute < ActiveRecord::Type::Value
    delegate :type, :serialize, :deserialize, :cast, to: :new_cast_type

    # cast_type      - the attribute's original Active Record type
    # model          - the model class that called has_neighbors
    # type           - the `type:` option given to has_neighbors (may be nil)
    # attribute_name - name of the decorated attribute
    def initialize(cast_type:, model:, type:, attribute_name:)
      @cast_type = cast_type
      @model = model
      @type = type
      @attribute_name = attribute_name
    end

    private

    def cast_value(...)
      new_cast_type.send(:cast_value, ...)
    end

    # Resolved on first use and memoized. Resolution calls Utils.adapter,
    # which reads the model's connection configuration.
    def new_cast_type
      @new_cast_type ||= resolve_cast_type
    end

    # Chooses the type to delegate to. Falls back to the original cast type
    # whenever no Neighbor-specific type applies.
    def resolve_cast_type
      return @cast_type unless @cast_type.is_a?(ActiveModel::Type::Value)

      adapter = Utils.adapter(@model)
      return sqlite_cast_type if adapter == :sqlite
      return mariadb_cast_type if adapter == :mariadb

      @cast_type
    end

    # SQLite: the `type:` option selects how the vector blob is encoded.
    # :bit keeps the original type; an unknown option raises ArgumentError.
    def sqlite_cast_type
      requested = @type&.to_sym

      if requested == :int8
        Type::SqliteInt8Vector.new
      elsif requested == :bit
        @cast_type
      elsif requested.nil? || requested == :float32
        Type::SqliteVector.new
      else
        raise ArgumentError, "Unsupported type"
      end
    end

    # MariaDB: integer columns keep their original type; every other column
    # uses MysqlVector.
    def mariadb_cast_type
      column = @model.columns_hash[@attribute_name.to_s]
      column&.type == :integer ? @cast_type : Type::MysqlVector.new
    end
  end
end
|
data/lib/neighbor/model.rb
CHANGED
@@ -1,6 +1,6 @@
|
|
1
1
|
module Neighbor
|
2
2
|
module Model
|
3
|
-
def has_neighbors(*attribute_names, dimensions: nil, normalize: nil)
|
3
|
+
def has_neighbors(*attribute_names, dimensions: nil, normalize: nil, type: nil)
|
4
4
|
if attribute_names.empty?
|
5
5
|
raise ArgumentError, "has_neighbors requires an attribute name"
|
6
6
|
end
|
@@ -24,125 +24,116 @@ module Neighbor
|
|
24
24
|
|
25
25
|
attribute_names.each do |attribute_name|
|
26
26
|
raise Error, "has_neighbors already called for #{attribute_name.inspect}" if neighbor_attributes[attribute_name]
|
27
|
-
@neighbor_attributes[attribute_name] = {dimensions: dimensions, normalize: normalize}
|
27
|
+
@neighbor_attributes[attribute_name] = {dimensions: dimensions, normalize: normalize, type: type&.to_sym}
|
28
|
+
end
|
29
|
+
|
30
|
+
if ActiveRecord::VERSION::STRING.to_f >= 7.2
|
31
|
+
decorate_attributes(attribute_names) do |name, cast_type|
|
32
|
+
Neighbor::Attribute.new(cast_type: cast_type, model: self, type: type, attribute_name: name)
|
33
|
+
end
|
34
|
+
else
|
35
|
+
attribute_names.each do |attribute_name|
|
36
|
+
attribute attribute_name do |cast_type|
|
37
|
+
Neighbor::Attribute.new(cast_type: cast_type, model: self, type: type, attribute_name: attribute_name)
|
38
|
+
end
|
39
|
+
end
|
40
|
+
end
|
41
|
+
|
42
|
+
if normalize
|
43
|
+
if ActiveRecord::VERSION::STRING.to_f >= 7.1
|
44
|
+
attribute_names.each do |attribute_name|
|
45
|
+
normalizes attribute_name, with: ->(v) { Neighbor::Utils.normalize(v, column_info: columns_hash[attribute_name.to_s]) }
|
46
|
+
end
|
47
|
+
else
|
48
|
+
attribute_names.each do |attribute_name|
|
49
|
+
attribute attribute_name do |cast_type|
|
50
|
+
Neighbor::NormalizedAttribute.new(cast_type: cast_type, model: self, attribute_name: attribute_name)
|
51
|
+
end
|
52
|
+
end
|
53
|
+
end
|
28
54
|
end
|
29
55
|
|
30
56
|
return if @neighbor_attributes.size != attribute_names.size
|
31
57
|
|
32
58
|
validate do
|
59
|
+
adapter = Utils.adapter(self.class)
|
60
|
+
|
33
61
|
self.class.neighbor_attributes.each do |k, v|
|
34
62
|
value = read_attribute(k)
|
35
63
|
next if value.nil?
|
36
64
|
|
37
65
|
column_info = self.class.columns_hash[k.to_s]
|
38
|
-
dimensions = v[:dimensions]
|
66
|
+
dimensions = v[:dimensions]
|
67
|
+
dimensions ||= column_info&.limit unless column_info&.type == :binary
|
68
|
+
type = v[:type] || Utils.type(adapter, column_info&.type)
|
39
69
|
|
40
|
-
if !Neighbor::Utils.validate_dimensions(value,
|
70
|
+
if !Neighbor::Utils.validate_dimensions(value, type, dimensions, adapter).nil?
|
41
71
|
errors.add(k, "must have #{dimensions} dimensions")
|
42
72
|
end
|
43
|
-
if !Neighbor::Utils.validate_finite(value,
|
73
|
+
if !Neighbor::Utils.validate_finite(value, type)
|
44
74
|
errors.add(k, "must have finite values")
|
45
75
|
end
|
46
76
|
end
|
47
77
|
end
|
48
78
|
|
49
|
-
|
50
|
-
before_save do
|
51
|
-
self.class.neighbor_attributes.each do |k, v|
|
52
|
-
next unless v[:normalize] && attribute_changed?(k)
|
53
|
-
value = read_attribute(k)
|
54
|
-
next if value.nil?
|
55
|
-
self[k] = Neighbor::Utils.normalize(value, column_info: self.class.columns_hash[k.to_s])
|
56
|
-
end
|
57
|
-
end
|
58
|
-
|
59
|
-
# cannot use keyword arguments with scope with Ruby 3.2 and Active Record 6.1
|
60
|
-
# https://github.com/rails/rails/issues/46934
|
61
|
-
scope :nearest_neighbors, ->(attribute_name, vector, options = nil) {
|
62
|
-
raise ArgumentError, "missing keyword: :distance" unless options.is_a?(Hash) && options.key?(:distance)
|
63
|
-
distance = options.delete(:distance)
|
64
|
-
precision = options.delete(:precision)
|
65
|
-
raise ArgumentError, "unknown keywords: #{options.keys.map(&:inspect).join(", ")}" if options.any?
|
66
|
-
|
79
|
+
scope :nearest_neighbors, ->(attribute_name, vector, distance:, precision: nil) {
|
67
80
|
attribute_name = attribute_name.to_sym
|
68
81
|
options = neighbor_attributes[attribute_name]
|
69
82
|
raise ArgumentError, "Invalid attribute" unless options
|
70
83
|
normalize = options[:normalize]
|
71
84
|
dimensions = options[:dimensions]
|
85
|
+
type = options[:type]
|
72
86
|
|
73
87
|
return none if vector.nil?
|
74
88
|
|
75
89
|
distance = distance.to_s
|
76
90
|
|
77
|
-
quoted_attribute = "#{connection.quote_table_name(table_name)}.#{connection.quote_column_name(attribute_name)}"
|
78
|
-
|
79
91
|
column_info = columns_hash[attribute_name.to_s]
|
80
92
|
column_type = column_info&.type
|
81
93
|
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
when "hamming"
|
87
|
-
"<~>"
|
88
|
-
when "jaccard"
|
89
|
-
"<%>"
|
90
|
-
when "hamming2"
|
91
|
-
"#"
|
92
|
-
end
|
93
|
-
when :vector, :halfvec, :sparsevec
|
94
|
-
case distance
|
95
|
-
when "inner_product"
|
96
|
-
"<#>"
|
97
|
-
when "cosine"
|
98
|
-
"<=>"
|
99
|
-
when "euclidean"
|
100
|
-
"<->"
|
101
|
-
when "taxicab"
|
102
|
-
"<+>"
|
103
|
-
end
|
104
|
-
when :cube
|
105
|
-
case distance
|
106
|
-
when "taxicab"
|
107
|
-
"<#>"
|
108
|
-
when "chebyshev"
|
109
|
-
"<=>"
|
110
|
-
when "euclidean", "cosine"
|
111
|
-
"<->"
|
112
|
-
end
|
113
|
-
else
|
114
|
-
raise ArgumentError, "Unsupported type: #{column_type}"
|
115
|
-
end
|
94
|
+
adapter = Neighbor::Utils.adapter(klass)
|
95
|
+
if type && adapter != :sqlite
|
96
|
+
raise ArgumentError, "type only works with SQLite"
|
97
|
+
end
|
116
98
|
|
99
|
+
operator = Neighbor::Utils.operator(adapter, column_type, distance)
|
117
100
|
raise ArgumentError, "Invalid distance: #{distance}" unless operator
|
118
101
|
|
119
102
|
# ensure normalize set (can be true or false)
|
120
|
-
|
103
|
+
normalize_required = Utils.normalize_required?(adapter, column_type)
|
104
|
+
if distance == "cosine" && normalize_required && normalize.nil?
|
121
105
|
raise Neighbor::Error, "Set normalize for cosine distance with cube"
|
122
106
|
end
|
123
107
|
|
124
108
|
column_attribute = klass.type_for_attribute(attribute_name)
|
125
109
|
vector = column_attribute.cast(vector)
|
126
|
-
|
110
|
+
dimensions ||= column_info&.limit unless column_info&.type == :binary
|
111
|
+
Neighbor::Utils.validate(vector, dimensions: dimensions, type: type || Utils.type(adapter, column_info&.type), adapter: adapter)
|
127
112
|
vector = Neighbor::Utils.normalize(vector, column_info: column_info) if normalize
|
128
113
|
|
129
|
-
|
114
|
+
quoted_attribute = nil
|
115
|
+
query = nil
|
116
|
+
connection_pool.with_connection do |c|
|
117
|
+
quoted_attribute = "#{c.quote_table_name(table_name)}.#{c.quote_column_name(attribute_name)}"
|
118
|
+
query = c.quote(column_attribute.serialize(vector))
|
119
|
+
end
|
130
120
|
|
131
121
|
if !precision.nil?
|
122
|
+
if adapter != :postgresql || column_type != :vector
|
123
|
+
raise ArgumentError, "Precision not supported for this type"
|
124
|
+
end
|
125
|
+
|
132
126
|
case precision.to_s
|
133
127
|
when "half"
|
134
128
|
cast_dimensions = dimensions || column_info&.limit
|
135
129
|
raise ArgumentError, "Unknown dimensions" unless cast_dimensions
|
136
|
-
quoted_attribute += "::halfvec(#{
|
130
|
+
quoted_attribute += "::halfvec(#{connection_pool.with_connection { |c| c.quote(cast_dimensions.to_i) }})"
|
137
131
|
else
|
138
132
|
raise ArgumentError, "Invalid precision"
|
139
133
|
end
|
140
134
|
end
|
141
135
|
|
142
|
-
order =
|
143
|
-
if operator == "#"
|
144
|
-
order = "bit_count(#{order})"
|
145
|
-
end
|
136
|
+
order = Utils.order(adapter, type, operator, quoted_attribute, query)
|
146
137
|
|
147
138
|
# https://stats.stackexchange.com/questions/146221/is-cosine-similarity-identical-to-l2-normalized-euclidean-distance
|
148
139
|
# with normalized vectors:
|
@@ -150,7 +141,7 @@ module Neighbor
|
|
150
141
|
# cosine distance = 1 - cosine similarity
|
151
142
|
# this transformation doesn't change the order, so only needed for select
|
152
143
|
neighbor_distance =
|
153
|
-
if
|
144
|
+
if distance == "cosine" && normalize_required
|
154
145
|
"POWER(#{order}, 2) / 2.0"
|
155
146
|
elsif [:vector, :halfvec, :sparsevec].include?(column_type) && distance == "inner_product"
|
156
147
|
"(#{order}) * -1"
|
@@ -0,0 +1,37 @@
|
|
1
|
+
module Neighbor
  # MySQL / MariaDB integration for the `vector` column type.
  module MySQL
    # Patches Active Record's MySQL adapter so that:
    # - `vector` is a native database type (schema can be dumped),
    # - table definitions respond to `vector` (schema can be loaded),
    # - SQL types matching /^vector/i map to Neighbor::Type::MysqlVector.
    def self.initialize!
      require_relative "type/mysql_vector"

      require "active_record/connection_adapters/abstract_mysql_adapter"

      adapter_class = ActiveRecord::ConnectionAdapters::AbstractMysqlAdapter

      # ensure schema can be dumped
      adapter_class::NATIVE_DATABASE_TYPES[:vector] = {name: "vector"}

      # ensure schema can be loaded (skip when `vector` is already defined)
      table_definition = ActiveRecord::ConnectionAdapters::TableDefinition
      table_definition.send(:define_column_methods, :vector) unless table_definition.method_defined?(:vector)

      # prevent unknown OID warning
      adapter_class.singleton_class.prepend(RegisterTypes)
      # Active Record < 7.1: also register directly on the class-level TYPE_MAP
      # (presumably that map is already built when this runs -- TODO confirm)
      adapter_class.register_vector_type(adapter_class::TYPE_MAP) if ActiveRecord::VERSION::STRING.to_f < 7.1
    end

    # Prepended to AbstractMysqlAdapter's singleton class.
    module RegisterTypes
      # Builds the adapter's normal type map, then adds the vector mapping.
      def initialize_type_map(mapping)
        super
        register_vector_type(mapping)
      end

      # Maps SQL types matching /^vector/i to MysqlVector, carrying over the
      # declared limit (the number of dimensions).
      def register_vector_type(mapping)
        mapping.register_type(%r(^vector)i) { |sql_type| Type::MysqlVector.new(limit: extract_limit(sql_type)) }
      end
    end
  end
end
|
@@ -0,0 +1,21 @@
|
|
1
|
+
module Neighbor
  # Attribute type that normalizes vectors whenever they are cast.
  #
  # has_neighbors installs it for `normalize: true` on Active Record < 7.1;
  # newer versions use `normalizes` instead. Type name and database
  # serialization are delegated unchanged to the wrapped type.
  class NormalizedAttribute < ActiveRecord::Type::Value
    delegate :type, :serialize, :deserialize, to: :@cast_type

    # cast_type      - the attribute's original Active Record type
    # model          - the model class (used to look up column info)
    # attribute_name - attribute name; stored as a String for columns_hash lookups
    def initialize(cast_type:, model:, attribute_name:)
      @cast_type = cast_type
      @model = model
      @attribute_name = attribute_name.to_s
    end

    # Casts with the wrapped type, then normalizes the result using the
    # column's metadata.
    def cast(...)
      casted = @cast_type.cast(...)
      column_info = @model.columns_hash[@attribute_name]
      Neighbor::Utils.normalize(casted, column_info: column_info)
    end

    private

    def cast_value(...)
      @cast_type.__send__(:cast_value, ...)
    end
  end
end
|
@@ -0,0 +1,43 @@
|
|
1
|
+
module Neighbor
  # PostgreSQL integration for the cube and pgvector column types.
  module PostgreSQL
    # Loads the Neighbor PostgreSQL types and patches PostgreSQLAdapter so
    # cube / halfvec / sparsevec / vector columns can be dumped to the schema,
    # loaded from it, and mapped to Neighbor types.
    def self.initialize!
      require_relative "type/cube"
      require_relative "type/halfvec"
      require_relative "type/sparsevec"
      require_relative "type/vector"

      require "active_record/connection_adapters/postgresql_adapter"

      adapter_class = ActiveRecord::ConnectionAdapters::PostgreSQLAdapter

      # ensure schema can be dumped
      {cube: "cube", halfvec: "halfvec", sparsevec: "sparsevec", vector: "vector"}.each do |key, sql_name|
        adapter_class::NATIVE_DATABASE_TYPES[key] = {name: sql_name}
      end

      # ensure schema can be loaded
      ActiveRecord::ConnectionAdapters::TableDefinition.send(:define_column_methods, :cube, :halfvec, :sparsevec, :vector)

      # prevent unknown OID warning
      adapter_class.singleton_class.prepend(RegisterTypes)
    end

    # Prepended to PostgreSQLAdapter's singleton class.
    module RegisterTypes
      # Builds the normal type map, then registers cube plus the three
      # pgvector types. The pgvector types carry their declared limit.
      def initialize_type_map(m = type_map)
        super
        m.register_type "cube", Type::Cube.new
        {"halfvec" => Type::Halfvec, "sparsevec" => Type::Sparsevec, "vector" => Type::Vector}.each do |sql_name, type_class|
          m.register_type(sql_name) { |_, _, sql_type| type_class.new(limit: extract_limit(sql_type)) }
        end
      end
    end
  end
end
|
@@ -0,0 +1,28 @@
|
|
1
|
+
module Neighbor
  # SQLite integration via the sqlite-vec extension.
  module SQLite
    # note: this is a public API (unlike PostgreSQL and MySQL)
    #
    # Loads the SQLite vector types and patches SQLite3Adapter so each
    # connection loads sqlite-vec when it is configured. Subsequent calls
    # return immediately.
    def self.initialize!
      return if defined?(@initialized)

      require_relative "type/sqlite_vector"
      require_relative "type/sqlite_int8_vector"

      require "sqlite_vec"
      require "active_record/connection_adapters/sqlite3_adapter"

      ActiveRecord::ConnectionAdapters::SQLite3Adapter.prepend(InstanceMethods)

      @initialized = true
    end

    # Prepended to SQLite3Adapter.
    module InstanceMethods
      # Runs the adapter's normal connection setup, then loads sqlite-vec into
      # the raw SQLite handle.
      #
      # Extension loading is enabled only for the duration of SqliteVec.load.
      # It is switched off again in an `ensure`, so a failed load cannot leave
      # the connection with extension loading enabled. The load error still
      # propagates to the caller.
      def configure_connection
        super
        # the raw handle is @raw_connection on Active Record >= 7.1 and
        # @connection on earlier versions
        db = ActiveRecord::VERSION::STRING.to_f >= 7.1 ? @raw_connection : @connection
        db.enable_load_extension(1)
        begin
          SqliteVec.load(db)
        ensure
          db.enable_load_extension(0)
        end
      end
    end
  end
end
|
@@ -0,0 +1,33 @@
|
|
1
|
+
module Neighbor
  module Type
    # Active Record type for MySQL / MariaDB `vector` columns.
    #
    # Vectors are exchanged with the database as binary strings of packed
    # little-endian 32-bit floats (`pack("e*")` / `unpack("e*")`).
    class MysqlVector < ActiveRecord::Type::Binary
      # Column type reported to Active Record.
      def type
        :vector
      end

      # Packs array-like values into a binary string. Anything else is passed
      # through to Binary#serialize unchanged.
      def serialize(value)
        packed = Utils.array?(value) ? value.to_a.pack("e*") : value
        super(packed)
      end

      # Converts the database value to an Array of Floats; nil stays nil.
      def deserialize(value)
        raw = super
        raw.nil? ? nil : cast_value(raw)
      end

      private

      # Strings are unpacked as little-endian float32; array-likes become
      # Arrays; anything else raises.
      def cast_value(value)
        case value
        when String
          value.unpack("e*")
        else
          raise "can't cast #{value.class.name} to vector" unless Utils.array?(value)

          value.to_a
        end
      end
    end
  end
end
|
@@ -0,0 +1,29 @@
|
|
1
|
+
module Neighbor
  module Type
    # Active Record type for sqlite-vec int8 vectors
    # (`has_neighbors ..., type: :int8`).
    #
    # Vectors are stored as blobs of packed signed 8-bit integers
    # (`pack("c*")` / `unpack("c*")`).
    class SqliteInt8Vector < ActiveRecord::Type::Binary
      # Packs array-like values into a binary string. Anything else is passed
      # through to Binary#serialize unchanged.
      #
      # Real numeric elements outside -128..127 raise ArgumentError.
      # Array#pack("c") keeps only the low 8 bits, so without this check
      # 200 would be stored as -56 with no error.
      def serialize(value)
        if Utils.array?(value)
          elements = value.to_a
          invalid = elements.find { |v| v.is_a?(Numeric) && v.real? && !v.between?(-128, 127) }
          raise ArgumentError, "int8 values must be between -128 and 127, not #{invalid}" if invalid

          value = elements.pack("c*")
        end
        super(value)
      end

      # Converts the database value to an Array of Integers; nil stays nil.
      def deserialize(value)
        value = super
        cast_value(value) unless value.nil?
      end

      private

      # Strings are unpacked as signed 8-bit integers; array-likes become
      # Arrays; anything else raises.
      def cast_value(value)
        if value.is_a?(String)
          value.unpack("c*")
        elsif Utils.array?(value)
          value.to_a
        else
          raise "can't cast #{value.class.name} to vector"
        end
      end
    end
  end
end
|
@@ -0,0 +1,29 @@
|
|
1
|
+
module Neighbor
  module Type
    # Active Record type for sqlite-vec float32 vectors (the default when
    # has_neighbors is used on SQLite without a `type:` option).
    #
    # Vectors are stored as blobs of packed native-endian single-precision
    # floats (`pack("f*")` / `unpack("f*")`).
    class SqliteVector < ActiveRecord::Type::Binary
      # Packs array-like values into a binary string. Anything else is passed
      # through to Binary#serialize unchanged.
      def serialize(value)
        blob = Utils.array?(value) ? value.to_a.pack("f*") : value
        super(blob)
      end

      # Converts the database value to an Array of Floats; nil stays nil.
      def deserialize(value)
        raw = super
        return nil if raw.nil?

        cast_value(raw)
      end

      private

      # Strings are unpacked as float32; array-likes become Arrays; anything
      # else raises.
      def cast_value(value)
        return value.unpack("f*") if value.is_a?(String)
        return value.to_a if Utils.array?(value)

        raise "can't cast #{value.class.name} to vector"
      end
    end
  end
end
|
data/lib/neighbor/utils.rb
CHANGED
@@ -1,7 +1,9 @@
|
|
1
1
|
module Neighbor
|
2
2
|
module Utils
|
3
|
-
def self.validate_dimensions(value, type, expected)
|
3
|
+
def self.validate_dimensions(value, type, expected, adapter)
|
4
4
|
dimensions = type == :sparsevec ? value.dimensions : value.size
|
5
|
+
dimensions *= 8 if type == :bit && [:sqlite, :mysql].include?(adapter)
|
6
|
+
|
5
7
|
if expected && dimensions != expected
|
6
8
|
"Expected #{expected} dimensions, not #{dimensions}"
|
7
9
|
end
|
@@ -9,7 +11,7 @@ module Neighbor
|
|
9
11
|
|
10
12
|
def self.validate_finite(value, type)
|
11
13
|
case type
|
12
|
-
when :bit
|
14
|
+
when :bit, :integer
|
13
15
|
true
|
14
16
|
when :sparsevec
|
15
17
|
value.values.all?(&:finite?)
|
@@ -18,17 +20,19 @@ module Neighbor
|
|
18
20
|
end
|
19
21
|
end
|
20
22
|
|
21
|
-
def self.validate(value, dimensions:,
|
22
|
-
if (message = validate_dimensions(value,
|
23
|
+
def self.validate(value, dimensions:, type:, adapter:)
|
24
|
+
if (message = validate_dimensions(value, type, dimensions, adapter))
|
23
25
|
raise Error, message
|
24
26
|
end
|
25
27
|
|
26
|
-
if !validate_finite(value,
|
28
|
+
if !validate_finite(value, type)
|
27
29
|
raise Error, "Values must be finite"
|
28
30
|
end
|
29
31
|
end
|
30
32
|
|
31
33
|
def self.normalize(value, column_info:)
|
34
|
+
return nil if value.nil?
|
35
|
+
|
32
36
|
raise Error, "Normalize not supported for type" unless [:cube, :vector, :halfvec].include?(column_info&.type)
|
33
37
|
|
34
38
|
norm = Math.sqrt(value.sum { |v| v * v })
|
@@ -42,5 +46,156 @@ module Neighbor
|
|
42
46
|
def self.array?(value)
|
43
47
|
!value.nil? && value.respond_to?(:to_a)
|
44
48
|
end
|
49
|
+
|
50
|
+
def self.adapter(model)
|
51
|
+
case model.connection_db_config.adapter
|
52
|
+
when /sqlite/i
|
53
|
+
:sqlite
|
54
|
+
when /mysql|trilogy/i
|
55
|
+
model.connection_pool.with_connection { |c| c.try(:mariadb?) } ? :mariadb : :mysql
|
56
|
+
else
|
57
|
+
:postgresql
|
58
|
+
end
|
59
|
+
end
|
60
|
+
|
61
|
+
def self.type(adapter, column_type)
|
62
|
+
case adapter
|
63
|
+
when :mysql
|
64
|
+
if column_type == :binary
|
65
|
+
:bit
|
66
|
+
else
|
67
|
+
column_type
|
68
|
+
end
|
69
|
+
else
|
70
|
+
column_type
|
71
|
+
end
|
72
|
+
end
|
73
|
+
|
74
|
+
def self.operator(adapter, column_type, distance)
|
75
|
+
case adapter
|
76
|
+
when :sqlite
|
77
|
+
case distance
|
78
|
+
when "euclidean"
|
79
|
+
"vec_distance_L2"
|
80
|
+
when "cosine"
|
81
|
+
"vec_distance_cosine"
|
82
|
+
when "taxicab"
|
83
|
+
"vec_distance_L1"
|
84
|
+
when "hamming"
|
85
|
+
"vec_distance_hamming"
|
86
|
+
end
|
87
|
+
when :mariadb
|
88
|
+
case column_type
|
89
|
+
when :vector
|
90
|
+
case distance
|
91
|
+
when "euclidean"
|
92
|
+
"VEC_DISTANCE_EUCLIDEAN"
|
93
|
+
when "cosine"
|
94
|
+
"VEC_DISTANCE_COSINE"
|
95
|
+
end
|
96
|
+
when :integer
|
97
|
+
case distance
|
98
|
+
when "hamming"
|
99
|
+
"BIT_COUNT"
|
100
|
+
end
|
101
|
+
else
|
102
|
+
raise ArgumentError, "Unsupported type: #{column_type}"
|
103
|
+
end
|
104
|
+
when :mysql
|
105
|
+
case column_type
|
106
|
+
when :vector
|
107
|
+
case distance
|
108
|
+
when "cosine"
|
109
|
+
"COSINE"
|
110
|
+
when "euclidean"
|
111
|
+
"EUCLIDEAN"
|
112
|
+
end
|
113
|
+
when :binary
|
114
|
+
case distance
|
115
|
+
when "hamming"
|
116
|
+
"BIT_COUNT"
|
117
|
+
end
|
118
|
+
else
|
119
|
+
raise ArgumentError, "Unsupported type: #{column_type}"
|
120
|
+
end
|
121
|
+
else
|
122
|
+
case column_type
|
123
|
+
when :bit
|
124
|
+
case distance
|
125
|
+
when "hamming"
|
126
|
+
"<~>"
|
127
|
+
when "jaccard"
|
128
|
+
"<%>"
|
129
|
+
when "hamming2"
|
130
|
+
"#"
|
131
|
+
end
|
132
|
+
when :vector, :halfvec, :sparsevec
|
133
|
+
case distance
|
134
|
+
when "inner_product"
|
135
|
+
"<#>"
|
136
|
+
when "cosine"
|
137
|
+
"<=>"
|
138
|
+
when "euclidean"
|
139
|
+
"<->"
|
140
|
+
when "taxicab"
|
141
|
+
"<+>"
|
142
|
+
end
|
143
|
+
when :cube
|
144
|
+
case distance
|
145
|
+
when "taxicab"
|
146
|
+
"<#>"
|
147
|
+
when "chebyshev"
|
148
|
+
"<=>"
|
149
|
+
when "euclidean", "cosine"
|
150
|
+
"<->"
|
151
|
+
end
|
152
|
+
else
|
153
|
+
raise ArgumentError, "Unsupported type: #{column_type}"
|
154
|
+
end
|
155
|
+
end
|
156
|
+
end
|
157
|
+
|
158
|
+
def self.order(adapter, type, operator, quoted_attribute, query)
|
159
|
+
case adapter
|
160
|
+
when :sqlite
|
161
|
+
case type
|
162
|
+
when :int8
|
163
|
+
"#{operator}(vec_int8(#{quoted_attribute}), vec_int8(#{query}))"
|
164
|
+
when :bit
|
165
|
+
"#{operator}(vec_bit(#{quoted_attribute}), vec_bit(#{query}))"
|
166
|
+
else
|
167
|
+
"#{operator}(#{quoted_attribute}, #{query})"
|
168
|
+
end
|
169
|
+
when :mariadb
|
170
|
+
if operator == "BIT_COUNT"
|
171
|
+
"BIT_COUNT(#{quoted_attribute} ^ #{query})"
|
172
|
+
else
|
173
|
+
"#{operator}(#{quoted_attribute}, #{query})"
|
174
|
+
end
|
175
|
+
when :mysql
|
176
|
+
if operator == "BIT_COUNT"
|
177
|
+
"BIT_COUNT(#{quoted_attribute} ^ #{query})"
|
178
|
+
elsif operator == "COSINE"
|
179
|
+
"DISTANCE(#{quoted_attribute}, #{query}, 'COSINE')"
|
180
|
+
else
|
181
|
+
"DISTANCE(#{quoted_attribute}, #{query}, 'EUCLIDEAN')"
|
182
|
+
end
|
183
|
+
else
|
184
|
+
if operator == "#"
|
185
|
+
"bit_count(#{quoted_attribute} # #{query})"
|
186
|
+
else
|
187
|
+
"#{quoted_attribute} #{operator} #{query}"
|
188
|
+
end
|
189
|
+
end
|
190
|
+
end
|
191
|
+
|
192
|
+
def self.normalize_required?(adapter, column_type)
|
193
|
+
case adapter
|
194
|
+
when :postgresql
|
195
|
+
column_type == :cube
|
196
|
+
else
|
197
|
+
false
|
198
|
+
end
|
199
|
+
end
|
45
200
|
end
|
46
201
|
end
|
data/lib/neighbor/version.rb
CHANGED
data/lib/neighbor.rb
CHANGED
@@ -1,6 +1,11 @@
|
|
1
1
|
# dependencies
|
2
2
|
require "active_support"
|
3
3
|
|
4
|
+
# adapter hooks
|
5
|
+
require_relative "neighbor/mysql"
|
6
|
+
require_relative "neighbor/postgresql"
|
7
|
+
require_relative "neighbor/sqlite"
|
8
|
+
|
4
9
|
# modules
|
5
10
|
require_relative "neighbor/reranking"
|
6
11
|
require_relative "neighbor/sparse_vector"
|
@@ -9,53 +14,22 @@ require_relative "neighbor/version"
|
|
9
14
|
|
10
15
|
module Neighbor
|
11
16
|
class Error < StandardError; end
|
12
|
-
|
13
|
-
module RegisterTypes
|
14
|
-
def initialize_type_map(m = type_map)
|
15
|
-
super
|
16
|
-
m.register_type "cube", Type::Cube.new
|
17
|
-
m.register_type "halfvec" do |_, _, sql_type|
|
18
|
-
limit = extract_limit(sql_type)
|
19
|
-
Type::Halfvec.new(limit: limit)
|
20
|
-
end
|
21
|
-
m.register_type "sparsevec" do |_, _, sql_type|
|
22
|
-
limit = extract_limit(sql_type)
|
23
|
-
Type::Sparsevec.new(limit: limit)
|
24
|
-
end
|
25
|
-
m.register_type "vector" do |_, _, sql_type|
|
26
|
-
limit = extract_limit(sql_type)
|
27
|
-
Type::Vector.new(limit: limit)
|
28
|
-
end
|
29
|
-
end
|
30
|
-
end
|
31
17
|
end
|
32
18
|
|
33
19
|
ActiveSupport.on_load(:active_record) do
|
20
|
+
require_relative "neighbor/attribute"
|
34
21
|
require_relative "neighbor/model"
|
35
|
-
require_relative "neighbor/
|
36
|
-
require_relative "neighbor/type/halfvec"
|
37
|
-
require_relative "neighbor/type/sparsevec"
|
38
|
-
require_relative "neighbor/type/vector"
|
22
|
+
require_relative "neighbor/normalized_attribute"
|
39
23
|
|
40
24
|
extend Neighbor::Model
|
41
25
|
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
ActiveRecord::ConnectionAdapters::PostgreSQLAdapter::NATIVE_DATABASE_TYPES[:halfvec] = {name: "halfvec"}
|
47
|
-
ActiveRecord::ConnectionAdapters::PostgreSQLAdapter::NATIVE_DATABASE_TYPES[:sparsevec] = {name: "sparsevec"}
|
48
|
-
ActiveRecord::ConnectionAdapters::PostgreSQLAdapter::NATIVE_DATABASE_TYPES[:vector] = {name: "vector"}
|
49
|
-
|
50
|
-
# ensure schema can be loaded
|
51
|
-
ActiveRecord::ConnectionAdapters::TableDefinition.send(:define_column_methods, :cube, :halfvec, :sparsevec, :vector)
|
52
|
-
|
53
|
-
# prevent unknown OID warning
|
54
|
-
if ActiveRecord::VERSION::MAJOR >= 7
|
55
|
-
ActiveRecord::ConnectionAdapters::PostgreSQLAdapter.singleton_class.prepend(Neighbor::RegisterTypes)
|
56
|
-
else
|
57
|
-
ActiveRecord::ConnectionAdapters::PostgreSQLAdapter.prepend(Neighbor::RegisterTypes)
|
26
|
+
begin
|
27
|
+
Neighbor::PostgreSQL.initialize!
|
28
|
+
rescue Gem::LoadError
|
29
|
+
# tries to load pg gem, which may not be available
|
58
30
|
end
|
31
|
+
|
32
|
+
Neighbor::MySQL.initialize!
|
59
33
|
end
|
60
34
|
|
61
35
|
require_relative "neighbor/railtie" if defined?(Rails::Railtie)
|
metadata
CHANGED
@@ -1,14 +1,14 @@
|
|
1
1
|
--- !ruby/object:Gem::Specification
|
2
2
|
name: neighbor
|
3
3
|
version: !ruby/object:Gem::Version
|
4
|
-
version: 0.
|
4
|
+
version: 0.5.1
|
5
5
|
platform: ruby
|
6
6
|
authors:
|
7
7
|
- Andrew Kane
|
8
8
|
autorequire:
|
9
9
|
bindir: bin
|
10
10
|
cert_chain: []
|
11
|
-
date: 2024-
|
11
|
+
date: 2024-12-03 00:00:00.000000000 Z
|
12
12
|
dependencies:
|
13
13
|
- !ruby/object:Gem::Dependency
|
14
14
|
name: activerecord
|
@@ -16,14 +16,14 @@ dependencies:
|
|
16
16
|
requirements:
|
17
17
|
- - ">="
|
18
18
|
- !ruby/object:Gem::Version
|
19
|
-
version: '
|
19
|
+
version: '7'
|
20
20
|
type: :runtime
|
21
21
|
prerelease: false
|
22
22
|
version_requirements: !ruby/object:Gem::Requirement
|
23
23
|
requirements:
|
24
24
|
- - ">="
|
25
25
|
- !ruby/object:Gem::Version
|
26
|
-
version: '
|
26
|
+
version: '7'
|
27
27
|
description:
|
28
28
|
email: andrew@ankane.org
|
29
29
|
executables: []
|
@@ -34,17 +34,27 @@ files:
|
|
34
34
|
- LICENSE.txt
|
35
35
|
- README.md
|
36
36
|
- lib/generators/neighbor/cube_generator.rb
|
37
|
+
- lib/generators/neighbor/sqlite_generator.rb
|
37
38
|
- lib/generators/neighbor/templates/cube.rb.tt
|
39
|
+
- lib/generators/neighbor/templates/sqlite.rb.tt
|
38
40
|
- lib/generators/neighbor/templates/vector.rb.tt
|
39
41
|
- lib/generators/neighbor/vector_generator.rb
|
40
42
|
- lib/neighbor.rb
|
43
|
+
- lib/neighbor/attribute.rb
|
41
44
|
- lib/neighbor/model.rb
|
45
|
+
- lib/neighbor/mysql.rb
|
46
|
+
- lib/neighbor/normalized_attribute.rb
|
47
|
+
- lib/neighbor/postgresql.rb
|
42
48
|
- lib/neighbor/railtie.rb
|
43
49
|
- lib/neighbor/reranking.rb
|
44
50
|
- lib/neighbor/sparse_vector.rb
|
51
|
+
- lib/neighbor/sqlite.rb
|
45
52
|
- lib/neighbor/type/cube.rb
|
46
53
|
- lib/neighbor/type/halfvec.rb
|
54
|
+
- lib/neighbor/type/mysql_vector.rb
|
47
55
|
- lib/neighbor/type/sparsevec.rb
|
56
|
+
- lib/neighbor/type/sqlite_int8_vector.rb
|
57
|
+
- lib/neighbor/type/sqlite_vector.rb
|
48
58
|
- lib/neighbor/type/vector.rb
|
49
59
|
- lib/neighbor/utils.rb
|
50
60
|
- lib/neighbor/version.rb
|
@@ -67,8 +77,8 @@ required_rubygems_version: !ruby/object:Gem::Requirement
|
|
67
77
|
- !ruby/object:Gem::Version
|
68
78
|
version: '0'
|
69
79
|
requirements: []
|
70
|
-
rubygems_version: 3.5.
|
80
|
+
rubygems_version: 3.5.22
|
71
81
|
signing_key:
|
72
82
|
specification_version: 4
|
73
|
-
summary: Nearest neighbor search for Rails
|
83
|
+
summary: Nearest neighbor search for Rails
|
74
84
|
test_files: []
|