neighbor 0.3.2 → 0.5.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
- data/CHANGELOG.md +36 -0
- data/LICENSE.txt +1 -1
- data/README.md +659 -53
- data/lib/generators/neighbor/cube_generator.rb +1 -0
- data/lib/generators/neighbor/sqlite_generator.rb +13 -0
- data/lib/generators/neighbor/templates/sqlite.rb.tt +2 -0
- data/lib/generators/neighbor/vector_generator.rb +1 -0
- data/lib/neighbor/attribute.rb +48 -0
- data/lib/neighbor/model.rb +93 -66
- data/lib/neighbor/mysql.rb +37 -0
- data/lib/neighbor/normalized_attribute.rb +21 -0
- data/lib/neighbor/postgresql.rb +43 -0
- data/lib/neighbor/railtie.rb +4 -4
- data/lib/neighbor/reranking.rb +27 -0
- data/lib/neighbor/sparse_vector.rb +79 -0
- data/lib/neighbor/sqlite.rb +28 -0
- data/lib/neighbor/type/cube.rb +24 -19
- data/lib/neighbor/type/halfvec.rb +28 -0
- data/lib/neighbor/type/mysql_vector.rb +33 -0
- data/lib/neighbor/type/sparsevec.rb +30 -0
- data/lib/neighbor/type/sqlite_int8_vector.rb +29 -0
- data/lib/neighbor/type/sqlite_vector.rb +29 -0
- data/lib/neighbor/type/vector.rb +19 -5
- data/lib/neighbor/utils.rb +201 -0
- data/lib/neighbor/version.rb +1 -1
- data/lib/neighbor.rb +16 -28
- metadata +22 -8
- data/lib/neighbor/vector.rb +0 -65
@@ -0,0 +1,13 @@
|
|
1
|
+
require "rails/generators"

module Neighbor
  module Generators
    # Rails generator that writes config/initializers/neighbor.rb from the
    # sqlite.rb template found in the templates directory beside this file.
    class SqliteGenerator < Rails::Generators::Base
      source_root File.join(__dir__, "templates")

      # Renders templates/sqlite.rb(.tt) into the application's initializers.
      def copy_templates
        template("sqlite.rb", "config/initializers/neighbor.rb")
      end
    end
  end
end
|
@@ -0,0 +1,48 @@
|
|
1
|
+
module Neighbor
  # Attribute type installed by +has_neighbors+ around a column's original
  # cast type. Depending on the model's adapter and the +type:+ option, it
  # substitutes a Neighbor vector type; the choice is made lazily on first use
  # and memoized. type/serialize/deserialize/cast are forwarded to that choice.
  class Attribute < ActiveRecord::Type::Value
    delegate :type, :serialize, :deserialize, :cast, to: :new_cast_type

    # @param cast_type [Object] the type Active Record would otherwise use
    # @param model [Class] the model class that declared the attribute
    # @param type [String, Symbol, nil] the +type:+ option given to has_neighbors
    # @param attribute_name [String, Symbol] the decorated attribute
    def initialize(cast_type:, model:, type:, attribute_name:)
      @cast_type = cast_type
      @model = model
      @type = type
      @attribute_name = attribute_name
    end

    private

    def cast_value(...)
      new_cast_type.send(:cast_value, ...)
    end

    # Effective cast type, computed once and then reused.
    def new_cast_type
      @new_cast_type ||= resolve_cast_type
    end

    # Picks the effective type; only ActiveModel::Type::Value instances are
    # candidates for substitution.
    def resolve_cast_type
      return @cast_type unless @cast_type.is_a?(ActiveModel::Type::Value)

      case Utils.adapter(@model)
      when :sqlite then sqlite_cast_type
      when :mariadb then mariadb_cast_type
      else @cast_type
      end
    end

    # SQLite: the +type:+ option selects the storage format.
    # @raise [ArgumentError] for an unrecognized +type:+
    def sqlite_cast_type
      case @type&.to_sym
      when :int8 then Type::SqliteInt8Vector.new
      when :bit then @cast_type
      when :float32, nil then Type::SqliteVector.new
      else raise ArgumentError, "Unsupported type"
      end
    end

    # MariaDB: integer-typed columns keep their original type; anything else
    # is treated as a vector.
    def mariadb_cast_type
      column = @model.columns_hash[@attribute_name.to_s]
      column&.type == :integer ? @cast_type : Type::MysqlVector.new
    end
  end
end
|
data/lib/neighbor/model.rb
CHANGED
@@ -1,12 +1,10 @@
|
|
1
1
|
module Neighbor
|
2
2
|
module Model
|
3
|
-
def has_neighbors(*attribute_names, dimensions: nil, normalize: nil)
|
3
|
+
def has_neighbors(*attribute_names, dimensions: nil, normalize: nil, type: nil)
|
4
4
|
if attribute_names.empty?
|
5
|
-
|
6
|
-
attribute_names << :neighbor_vector
|
7
|
-
else
|
8
|
-
attribute_names.map!(&:to_sym)
|
5
|
+
raise ArgumentError, "has_neighbors requires an attribute name"
|
9
6
|
end
|
7
|
+
attribute_names.map!(&:to_sym)
|
10
8
|
|
11
9
|
class_eval do
|
12
10
|
@neighbor_attributes ||= {}
|
@@ -26,84 +24,116 @@ module Neighbor
|
|
26
24
|
|
27
25
|
attribute_names.each do |attribute_name|
|
28
26
|
raise Error, "has_neighbors already called for #{attribute_name.inspect}" if neighbor_attributes[attribute_name]
|
29
|
-
@neighbor_attributes[attribute_name] = {dimensions: dimensions, normalize: normalize}
|
27
|
+
@neighbor_attributes[attribute_name] = {dimensions: dimensions, normalize: normalize, type: type&.to_sym}
|
28
|
+
end
|
30
29
|
|
31
|
-
|
30
|
+
if ActiveRecord::VERSION::STRING.to_f >= 7.2
|
31
|
+
decorate_attributes(attribute_names) do |name, cast_type|
|
32
|
+
Neighbor::Attribute.new(cast_type: cast_type, model: self, type: type, attribute_name: name)
|
33
|
+
end
|
34
|
+
else
|
35
|
+
attribute_names.each do |attribute_name|
|
36
|
+
attribute attribute_name do |cast_type|
|
37
|
+
Neighbor::Attribute.new(cast_type: cast_type, model: self, type: type, attribute_name: attribute_name)
|
38
|
+
end
|
39
|
+
end
|
40
|
+
end
|
41
|
+
|
42
|
+
if normalize
|
43
|
+
if ActiveRecord::VERSION::STRING.to_f >= 7.1
|
44
|
+
attribute_names.each do |attribute_name|
|
45
|
+
normalizes attribute_name, with: ->(v) { Neighbor::Utils.normalize(v, column_info: columns_hash[attribute_name.to_s]) }
|
46
|
+
end
|
47
|
+
else
|
48
|
+
attribute_names.each do |attribute_name|
|
49
|
+
attribute attribute_name do |cast_type|
|
50
|
+
Neighbor::NormalizedAttribute.new(cast_type: cast_type, model: self, attribute_name: attribute_name)
|
51
|
+
end
|
52
|
+
end
|
53
|
+
end
|
32
54
|
end
|
33
55
|
|
34
56
|
return if @neighbor_attributes.size != attribute_names.size
|
35
57
|
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
58
|
+
validate do
|
59
|
+
adapter = Utils.adapter(self.class)
|
60
|
+
|
61
|
+
self.class.neighbor_attributes.each do |k, v|
|
62
|
+
value = read_attribute(k)
|
63
|
+
next if value.nil?
|
64
|
+
|
65
|
+
column_info = self.class.columns_hash[k.to_s]
|
66
|
+
dimensions = v[:dimensions]
|
67
|
+
dimensions ||= column_info&.limit unless column_info&.type == :binary
|
68
|
+
type = v[:type] || Utils.type(adapter, column_info&.type)
|
69
|
+
|
70
|
+
if !Neighbor::Utils.validate_dimensions(value, type, dimensions, adapter).nil?
|
71
|
+
errors.add(k, "must have #{dimensions} dimensions")
|
72
|
+
end
|
73
|
+
if !Neighbor::Utils.validate_finite(value, type)
|
74
|
+
errors.add(k, "must have finite values")
|
75
|
+
end
|
51
76
|
end
|
52
|
-
|
77
|
+
end
|
53
78
|
|
79
|
+
scope :nearest_neighbors, ->(attribute_name, vector, distance:, precision: nil) {
|
80
|
+
attribute_name = attribute_name.to_sym
|
54
81
|
options = neighbor_attributes[attribute_name]
|
55
82
|
raise ArgumentError, "Invalid attribute" unless options
|
56
83
|
normalize = options[:normalize]
|
57
84
|
dimensions = options[:dimensions]
|
85
|
+
type = options[:type]
|
58
86
|
|
59
87
|
return none if vector.nil?
|
60
88
|
|
61
89
|
distance = distance.to_s
|
62
90
|
|
63
|
-
|
91
|
+
column_info = columns_hash[attribute_name.to_s]
|
92
|
+
column_type = column_info&.type
|
64
93
|
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
case distance
|
70
|
-
when "inner_product"
|
71
|
-
"<#>"
|
72
|
-
when "cosine"
|
73
|
-
"<=>"
|
74
|
-
when "euclidean"
|
75
|
-
"<->"
|
76
|
-
end
|
77
|
-
else
|
78
|
-
case distance
|
79
|
-
when "taxicab"
|
80
|
-
"<#>"
|
81
|
-
when "chebyshev"
|
82
|
-
"<=>"
|
83
|
-
when "euclidean", "cosine"
|
84
|
-
"<->"
|
85
|
-
end
|
86
|
-
end
|
94
|
+
adapter = Neighbor::Utils.adapter(klass)
|
95
|
+
if type && adapter != :sqlite
|
96
|
+
raise ArgumentError, "type only works with SQLite"
|
97
|
+
end
|
87
98
|
|
99
|
+
operator = Neighbor::Utils.operator(adapter, column_type, distance)
|
88
100
|
raise ArgumentError, "Invalid distance: #{distance}" unless operator
|
89
101
|
|
90
102
|
# ensure normalize set (can be true or false)
|
91
|
-
|
103
|
+
normalize_required = Utils.normalize_required?(adapter, column_type)
|
104
|
+
if distance == "cosine" && normalize_required && normalize.nil?
|
92
105
|
raise Neighbor::Error, "Set normalize for cosine distance with cube"
|
93
106
|
end
|
94
107
|
|
95
|
-
|
108
|
+
column_attribute = klass.type_for_attribute(attribute_name)
|
109
|
+
vector = column_attribute.cast(vector)
|
110
|
+
dimensions ||= column_info&.limit unless column_info&.type == :binary
|
111
|
+
Neighbor::Utils.validate(vector, dimensions: dimensions, type: type || Utils.type(adapter, column_info&.type), adapter: adapter)
|
112
|
+
vector = Neighbor::Utils.normalize(vector, column_info: column_info) if normalize
|
113
|
+
|
114
|
+
quoted_attribute = nil
|
115
|
+
query = nil
|
116
|
+
connection_pool.with_connection do |c|
|
117
|
+
quoted_attribute = "#{c.quote_table_name(table_name)}.#{c.quote_column_name(attribute_name)}"
|
118
|
+
query = c.quote(column_attribute.serialize(vector))
|
119
|
+
end
|
96
120
|
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
121
|
+
if !precision.nil?
|
122
|
+
if adapter != :postgresql || column_type != :vector
|
123
|
+
raise ArgumentError, "Precision not supported for this type"
|
124
|
+
end
|
125
|
+
|
126
|
+
case precision.to_s
|
127
|
+
when "half"
|
128
|
+
cast_dimensions = dimensions || column_info&.limit
|
129
|
+
raise ArgumentError, "Unknown dimensions" unless cast_dimensions
|
130
|
+
quoted_attribute += "::halfvec(#{connection_pool.with_connection { |c| c.quote(cast_dimensions.to_i) }})"
|
102
131
|
else
|
103
|
-
|
132
|
+
raise ArgumentError, "Invalid precision"
|
104
133
|
end
|
134
|
+
end
|
105
135
|
|
106
|
-
order =
|
136
|
+
order = Utils.order(adapter, type, operator, quoted_attribute, query)
|
107
137
|
|
108
138
|
# https://stats.stackexchange.com/questions/146221/is-cosine-similarity-identical-to-l2-normalized-euclidean-distance
|
109
139
|
# with normalized vectors:
|
@@ -111,31 +141,28 @@ module Neighbor
|
|
111
141
|
# cosine distance = 1 - cosine similarity
|
112
142
|
# this transformation doesn't change the order, so only needed for select
|
113
143
|
neighbor_distance =
|
114
|
-
if
|
144
|
+
if distance == "cosine" && normalize_required
|
115
145
|
"POWER(#{order}, 2) / 2.0"
|
116
|
-
elsif
|
146
|
+
elsif [:vector, :halfvec, :sparsevec].include?(column_type) && distance == "inner_product"
|
117
147
|
"(#{order}) * -1"
|
118
148
|
else
|
119
149
|
order
|
120
150
|
end
|
121
151
|
|
122
152
|
# for select, use column_names instead of * to account for ignored columns
|
123
|
-
|
153
|
+
select_columns = select_values.any? ? [] : column_names
|
154
|
+
select(*select_columns, "#{neighbor_distance} AS neighbor_distance")
|
124
155
|
.where.not(attribute_name => nil)
|
125
|
-
.
|
156
|
+
.reorder(Arel.sql(order))
|
126
157
|
}
|
127
158
|
|
128
|
-
def nearest_neighbors(attribute_name
|
129
|
-
if attribute_name.nil?
|
130
|
-
warn "[neighbor] nearest_neighbors without an attribute name is deprecated"
|
131
|
-
attribute_name = :neighbor_vector
|
132
|
-
end
|
159
|
+
def nearest_neighbors(attribute_name, **options)
|
133
160
|
attribute_name = attribute_name.to_sym
|
134
|
-
# important! check if neighbor attribute before
|
161
|
+
# important! check if neighbor attribute before accessing
|
135
162
|
raise ArgumentError, "Invalid attribute" unless self.class.neighbor_attributes[attribute_name]
|
136
163
|
|
137
164
|
self.class
|
138
|
-
.where.not(self.class.primary_key
|
165
|
+
.where.not(Array(self.class.primary_key).to_h { |k| [k, self[k]] })
|
139
166
|
.nearest_neighbors(attribute_name, self[attribute_name], **options)
|
140
167
|
end
|
141
168
|
end
|
@@ -0,0 +1,37 @@
|
|
1
|
+
module Neighbor
  module MySQL
    # Teaches Active Record's MySQL/MariaDB adapter about the +vector+ column
    # type (schema dumping, schema loading, and type-map registration).
    def self.initialize!
      require_relative "type/mysql_vector"

      require "active_record/connection_adapters/abstract_mysql_adapter"

      adapter = ActiveRecord::ConnectionAdapters::AbstractMysqlAdapter
      table_definition = ActiveRecord::ConnectionAdapters::TableDefinition

      # ensure schema can be dumped
      adapter::NATIVE_DATABASE_TYPES[:vector] = {name: "vector"}

      # ensure schema can be loaded (skipped if t.vector already exists)
      table_definition.send(:define_column_methods, :vector) unless table_definition.method_defined?(:vector)

      # prevent unknown OID warning
      adapter.singleton_class.prepend(RegisterTypes)
      # on Active Record < 7.1 the existing TYPE_MAP is patched directly as well
      # (presumably it is already built by this point there -- TODO confirm)
      adapter.register_vector_type(adapter::TYPE_MAP) if ActiveRecord::VERSION::STRING.to_f < 7.1
    end

    # Prepended onto the adapter's singleton class.
    module RegisterTypes
      def initialize_type_map(m)
        super
        register_vector_type(m)
      end

      # Maps any sql_type starting with "vector" (case-insensitive) to
      # Type::MysqlVector, carrying over the declared limit.
      def register_vector_type(m)
        m.register_type(/^vector/i) { |sql_type| Type::MysqlVector.new(limit: extract_limit(sql_type)) }
      end
    end
  end
end
|
@@ -0,0 +1,21 @@
|
|
1
|
+
module Neighbor
  # Attribute type that normalizes values whenever they are cast. has_neighbors
  # installs it for +normalize: true+ on Active Record versions before 7.1.
  # type/serialize/deserialize are forwarded unchanged to the wrapped type.
  class NormalizedAttribute < ActiveRecord::Type::Value
    delegate :type, :serialize, :deserialize, to: :@cast_type

    # @param cast_type [Object] the wrapped Active Record type
    # @param model [Class] model whose columns_hash supplies column metadata
    # @param attribute_name [String, Symbol] attribute being normalized
    def initialize(cast_type:, model:, attribute_name:)
      @cast_type = cast_type
      @model = model
      @attribute_name = attribute_name.to_s
    end

    # Casts with the wrapped type, then normalizes using the column's metadata.
    def cast(...)
      casted = @cast_type.cast(...)
      column = @model.columns_hash[@attribute_name]
      Neighbor::Utils.normalize(casted, column_info: column)
    end

    private

    def cast_value(...)
      @cast_type.send(:cast_value, ...)
    end
  end
end
|
@@ -0,0 +1,43 @@
|
|
1
|
+
module Neighbor
  module PostgreSQL
    # Teaches Active Record's PostgreSQL adapter about the cube, halfvec,
    # sparsevec, and vector column types.
    def self.initialize!
      require_relative "type/cube"
      require_relative "type/halfvec"
      require_relative "type/sparsevec"
      require_relative "type/vector"

      require "active_record/connection_adapters/postgresql_adapter"

      adapter = ActiveRecord::ConnectionAdapters::PostgreSQLAdapter
      column_types = %i[cube halfvec sparsevec vector]

      # ensure schema can be dumped
      column_types.each { |name| adapter::NATIVE_DATABASE_TYPES[name] = {name: name.to_s} }

      # ensure schema can be loaded
      ActiveRecord::ConnectionAdapters::TableDefinition.send(:define_column_methods, *column_types)

      # prevent unknown OID warning
      adapter.singleton_class.prepend(RegisterTypes)
    end

    # Prepended onto the adapter's singleton class; adds Neighbor types after
    # the adapter's own registrations.
    module RegisterTypes
      def initialize_type_map(m = type_map)
        super
        m.register_type "cube", Type::Cube.new
        m.register_type("halfvec") { |_, _, sql_type| Type::Halfvec.new(limit: extract_limit(sql_type)) }
        m.register_type("sparsevec") { |_, _, sql_type| Type::Sparsevec.new(limit: extract_limit(sql_type)) }
        m.register_type("vector") { |_, _, sql_type| Type::Vector.new(limit: extract_limit(sql_type)) }
      end
    end
  end
end
|
data/lib/neighbor/railtie.rb
CHANGED
@@ -1,16 +1,16 @@
|
|
1
1
|
module Neighbor
|
2
2
|
class Railtie < Rails::Railtie
|
3
3
|
generators do
|
4
|
+
require "rails/generators/generated_attribute"
|
5
|
+
|
4
6
|
# rails generate model Item embedding:vector{3}
|
5
|
-
|
6
|
-
Rails::Generators::GeneratedAttribute.singleton_class.prepend(Neighbor::GeneratedAttribute)
|
7
|
-
end
|
7
|
+
Rails::Generators::GeneratedAttribute.singleton_class.prepend(Neighbor::GeneratedAttribute)
|
8
8
|
end
|
9
9
|
end
|
10
10
|
|
11
11
|
module GeneratedAttribute
|
12
12
|
def parse_type_and_options(type, *, **)
|
13
|
-
if type =~ /\A(vector)\{(\d+)\}\z/
|
13
|
+
if type =~ /\A(vector|halfvec|bit|sparsevec)\{(\d+)\}\z/
|
14
14
|
return $1, limit: $2.to_i
|
15
15
|
end
|
16
16
|
super
|
@@ -0,0 +1,27 @@
|
|
1
|
+
module Neighbor
  module Reranking
    # Combines several rankings with Reciprocal Rank Fusion (RRF).
    #
    # From every ranking an item appears in, it receives 1.0 / (k + rank),
    # where rank is its 1-based position; these contributions are summed.
    #
    # @param first_ranking [Enumerable] ordered results, best first
    # @param rankings [Array<Enumerable>] additional ordered rankings
    # @param k [Numeric] smoothing constant added to every rank (default 60)
    # @return [Array<Hash>] one {result:, score:} hash per distinct item,
    #   highest score first; equal scores keep first-appearance order
    def self.rrf(first_ranking, *rankings, k: 60)
      # materialize once so any Enumerable works and is iterated only once
      rankings = [first_ranking, *rankings].map(&:to_a)

      # item => 1-based rank per ranking; if an item is repeated within the
      # same ranking, its best (earliest) position wins
      ranks = rankings.map do |ranking|
        ranking.each_with_index.each_with_object({}) do |(item, index), positions|
          positions[item] ||= index + 1
        end
      end

      scored = rankings.flatten(1).uniq.map do |result|
        score = ranks.sum do |positions|
          rank = positions[result]
          rank ? 1.0 / (k + rank) : 0.0
        end
        {result: result, score: score}
      end

      # sort_by is not stable; the index tiebreaker makes tie order deterministic
      scored.sort_by.with_index { |entry, index| [-entry[:score], index] }
    end
  end
end
|
@@ -0,0 +1,79 @@
|
|
1
|
+
module Neighbor
  # Sparse vector value: the total dimension count plus parallel arrays of
  # 0-based indices and Float values for the non-zero elements. #to_s and
  # .from_text use the text form "{index:value,...}/dimensions" with 1-based
  # indices.
  class SparseVector
    attr_reader :dimensions, :indices, :values

    # Sentinel that distinguishes "dimensions not passed" from an explicit nil.
    NO_DEFAULT = Object.new

    # @param value [Hash, #to_a] either {index => value} with 0-based indices
    #   (requires +dimensions+) or a dense array-like of values
    # @param dimensions [Integer] total dimensions; required for a Hash and
    #   rejected for a dense value
    # @raise [ArgumentError] "missing dimensions" / "extra argument"
    def initialize(value, dimensions = NO_DEFAULT)
      dimensions_given = !NO_DEFAULT.equal?(dimensions)

      if value.is_a?(Hash)
        raise ArgumentError, "missing dimensions" unless dimensions_given
        from_hash(value, dimensions)
      else
        raise ArgumentError, "extra argument" if dimensions_given
        from_array(value)
      end
    end

    # Text form with 1-based indices, e.g. "{1:1.0,3:2.0}/5".
    def to_s
      "{#{@indices.zip(@values).map { |i, v| "#{i.to_i + 1}:#{v.to_f}" }.join(",")}}/#{@dimensions.to_i}"
    end

    # Dense Array of length +dimensions+ with 0.0 in the unset positions.
    def to_a
      arr = Array.new(dimensions, 0.0)
      @indices.zip(@values) do |i, v|
        arr[i] = v
      end
      arr
    end

    private

    def from_hash(data, dimensions)
      # Convert keys to Integers *before* sorting: sorting raw keys puts "10"
      # ahead of "2" and raises on mixed Integer/String keys.
      elements = data.filter_map { |index, value| [index.to_i, value.to_f] if value != 0 }.sort_by(&:first)
      @dimensions = dimensions.to_i
      @indices = elements.map(&:first)
      @values = elements.map(&:last)
    end

    def from_array(arr)
      arr = arr.to_a
      @dimensions = arr.size
      @indices = []
      @values = []
      arr.each_with_index do |v, i|
        if v != 0
          @indices << i
          @values << v.to_f
        end
      end
    end

    class << self
      # Parses the "{index:value,...}/dimensions" text form (1-based indices).
      def from_text(string)
        elements, dimensions = string.split("/", 2)
        indices = []
        values = []
        elements[1..-2].split(",").each do |e|
          index, value = e.split(":", 2)
          indices << index.to_i - 1
          values << value.to_f
        end
        from_parts(dimensions.to_i, indices, values)
      end

      private

      # Builds an instance directly from already-parsed parts, bypassing #initialize.
      def from_parts(dimensions, indices, values)
        vec = allocate
        vec.instance_variable_set(:@dimensions, dimensions)
        vec.instance_variable_set(:@indices, indices)
        vec.instance_variable_set(:@values, values)
        vec
      end
    end
  end
end
|
@@ -0,0 +1,28 @@
|
|
1
|
+
module Neighbor
  module SQLite
    # note: this is a public API (unlike PostgreSQL and MySQL)
    #
    # Hooks SQLite3Adapter#configure_connection so the sqlite-vec extension is
    # loaded when Active Record configures a connection. After one successful
    # call, further calls are no-ops.
    def self.initialize!
      return if defined?(@initialized)

      require_relative "type/sqlite_vector"
      require_relative "type/sqlite_int8_vector"

      require "sqlite_vec"
      require "active_record/connection_adapters/sqlite3_adapter"

      ActiveRecord::ConnectionAdapters::SQLite3Adapter.prepend(InstanceMethods)

      @initialized = true
    end

    module InstanceMethods
      # Runs the adapter's normal setup, then loads sqlite-vec into the raw handle.
      def configure_connection
        super
        # Active Record 7.1+ keeps the raw database handle in @raw_connection;
        # earlier versions use @connection
        db = ActiveRecord::VERSION::STRING.to_f >= 7.1 ? @raw_connection : @connection
        db.enable_load_extension(1)
        begin
          SqliteVec.load(db)
        ensure
          # never leave extension loading enabled, even when loading fails
          db.enable_load_extension(0)
        end
      end
    end
  end
end
|
data/lib/neighbor/type/cube.rb
CHANGED
@@ -1,36 +1,41 @@
|
|
1
1
|
module Neighbor
|
2
2
|
module Type
|
3
|
-
class Cube < ActiveRecord::Type::
|
3
|
+
class Cube < ActiveRecord::Type::Value
|
4
4
|
def type
|
5
5
|
:cube
|
6
6
|
end
|
7
7
|
|
8
|
-
def
|
9
|
-
if
|
8
|
+
def serialize(value)
|
9
|
+
if Utils.array?(value)
|
10
|
+
value = value.to_a
|
10
11
|
if value.first.is_a?(Array)
|
11
|
-
value.map { |v|
|
12
|
+
value = value.map { |v| serialize_point(v) }.join(", ")
|
12
13
|
else
|
13
|
-
|
14
|
+
value = serialize_point(value)
|
14
15
|
end
|
15
|
-
else
|
16
|
-
super
|
17
16
|
end
|
17
|
+
super(value)
|
18
18
|
end
|
19
19
|
|
20
|
-
# TODO uncomment in 0.4.0
|
21
|
-
# def deserialize(value)
|
22
|
-
# if value.nil?
|
23
|
-
# super
|
24
|
-
# elsif value.include?("),(")
|
25
|
-
# value[1..-1].split("),(").map { |v| v.split(",").map(&:to_f) }
|
26
|
-
# else
|
27
|
-
# value[1..-1].split(",").map(&:to_f)
|
28
|
-
# end
|
29
|
-
# end
|
30
|
-
|
31
20
|
private
|
32
21
|
|
33
|
-
def
|
22
|
+
def cast_value(value)
|
23
|
+
if Utils.array?(value)
|
24
|
+
value.to_a
|
25
|
+
elsif value.is_a?(Numeric)
|
26
|
+
[value]
|
27
|
+
elsif value.is_a?(String)
|
28
|
+
if value.include?("),(")
|
29
|
+
value[1..-1].split("),(").map { |v| v.split(",").map(&:to_f) }
|
30
|
+
else
|
31
|
+
value[1..-1].split(",").map(&:to_f)
|
32
|
+
end
|
33
|
+
else
|
34
|
+
raise "can't cast #{value.class.name} to cube"
|
35
|
+
end
|
36
|
+
end
|
37
|
+
|
38
|
+
def serialize_point(value)
|
34
39
|
"(#{value.map(&:to_f).join(", ")})"
|
35
40
|
end
|
36
41
|
end
|
@@ -0,0 +1,28 @@
|
|
1
|
+
module Neighbor
  module Type
    # Active Record type for PostgreSQL halfvec columns.
    class Halfvec < ActiveRecord::Type::Value
      def type
        :halfvec
      end

      # Array-like values become the "[x,y,...]" text literal; anything else
      # is passed to the superclass unchanged.
      def serialize(value)
        value = "[#{value.to_a.map(&:to_f).join(",")}]" if Utils.array?(value)
        super(value)
      end

      private

      # Strings like "[1,2,3]" are parsed into Floats: only the leading bracket
      # is dropped, and the trailing "]" is ignored by String#to_f.
      # NOTE(review): as a consequence "[]" casts to [0.0] rather than [].
      def cast_value(value)
        return value[1..-1].split(",").map(&:to_f) if value.is_a?(String)
        return value.to_a if Utils.array?(value)

        raise "can't cast #{value.class.name} to halfvec"
      end
    end
  end
end
|