neighbor 0.3.2 → 0.5.0
Sign up to get free protection for your applications and to get access to all the features.
- checksums.yaml +4 -4
- data/CHANGELOG.md +36 -0
- data/LICENSE.txt +1 -1
- data/README.md +659 -53
- data/lib/generators/neighbor/cube_generator.rb +1 -0
- data/lib/generators/neighbor/sqlite_generator.rb +13 -0
- data/lib/generators/neighbor/templates/sqlite.rb.tt +2 -0
- data/lib/generators/neighbor/vector_generator.rb +1 -0
- data/lib/neighbor/attribute.rb +48 -0
- data/lib/neighbor/model.rb +93 -66
- data/lib/neighbor/mysql.rb +37 -0
- data/lib/neighbor/normalized_attribute.rb +21 -0
- data/lib/neighbor/postgresql.rb +43 -0
- data/lib/neighbor/railtie.rb +4 -4
- data/lib/neighbor/reranking.rb +27 -0
- data/lib/neighbor/sparse_vector.rb +79 -0
- data/lib/neighbor/sqlite.rb +28 -0
- data/lib/neighbor/type/cube.rb +24 -19
- data/lib/neighbor/type/halfvec.rb +28 -0
- data/lib/neighbor/type/mysql_vector.rb +33 -0
- data/lib/neighbor/type/sparsevec.rb +30 -0
- data/lib/neighbor/type/sqlite_int8_vector.rb +29 -0
- data/lib/neighbor/type/sqlite_vector.rb +29 -0
- data/lib/neighbor/type/vector.rb +19 -5
- data/lib/neighbor/utils.rb +201 -0
- data/lib/neighbor/version.rb +1 -1
- data/lib/neighbor.rb +16 -28
- metadata +22 -8
- data/lib/neighbor/vector.rb +0 -65
@@ -0,0 +1,13 @@
|
|
1
|
+
require "rails/generators"
|
2
|
+
|
3
|
+
module Neighbor
  module Generators
    # Rails generator that writes the Neighbor initializer from the bundled
    # "sqlite.rb" template.
    class SqliteGenerator < Rails::Generators::Base
      # Templates live in the "templates" directory next to this file.
      source_root File.expand_path("templates", __dir__)

      # Renders the "sqlite.rb" template to config/initializers/neighbor.rb
      # in the target application.
      def copy_templates
        template("sqlite.rb", "config/initializers/neighbor.rb")
      end
    end
  end
end
|
@@ -0,0 +1,48 @@
|
|
1
|
+
module Neighbor
  # Attribute type that wraps the cast type of a neighbor attribute and
  # forwards type, serialize, deserialize and cast to a replacement cast type
  # chosen from the model's adapter (as reported by Utils.adapter).
  class Attribute < ActiveRecord::Type::Value
    delegate :type, :serialize, :deserialize, :cast, to: :new_cast_type

    # cast_type      - the cast type being wrapped
    # model          - the model class that owns the attribute
    # type           - optional element type (:int8, :bit, :float32 or nil), only consulted for SQLite
    # attribute_name - name of the attribute, used to look up its column
    def initialize(cast_type:, model:, type:, attribute_name:)
      @cast_type = cast_type
      @model = model
      @type = type
      @attribute_name = attribute_name
    end

    private

    def cast_value(...)
      new_cast_type.send(:cast_value, ...)
    end

    # Memoized cast type that all delegated calls go through.
    def new_cast_type
      @new_cast_type ||= resolve_cast_type
    end

    # Picks the replacement cast type; anything that is not an
    # ActiveModel::Type::Value is passed through untouched.
    def resolve_cast_type
      return @cast_type unless @cast_type.is_a?(ActiveModel::Type::Value)

      case Utils.adapter(@model)
      when :sqlite
        sqlite_cast_type
      when :mariadb
        mariadb_cast_type
      else
        @cast_type
      end
    end

    # SQLite: the configured element type decides which vector type is used.
    # Raises ArgumentError for an element type that is not recognised.
    def sqlite_cast_type
      case @type&.to_sym
      when :int8
        Type::SqliteInt8Vector.new
      when :bit
        @cast_type
      when :float32, nil
        Type::SqliteVector.new
      else
        raise ArgumentError, "Unsupported type"
      end
    end

    # MariaDB: integer columns keep their own cast type, everything else is
    # handled by Type::MysqlVector.
    def mariadb_cast_type
      integer_column = @model.columns_hash[@attribute_name.to_s]&.type == :integer
      integer_column ? @cast_type : Type::MysqlVector.new
    end
  end
end
|
data/lib/neighbor/model.rb
CHANGED
@@ -1,12 +1,10 @@
|
|
1
1
|
module Neighbor
|
2
2
|
module Model
|
3
|
-
def has_neighbors(*attribute_names, dimensions: nil, normalize: nil)
|
3
|
+
def has_neighbors(*attribute_names, dimensions: nil, normalize: nil, type: nil)
|
4
4
|
if attribute_names.empty?
|
5
|
-
|
6
|
-
attribute_names << :neighbor_vector
|
7
|
-
else
|
8
|
-
attribute_names.map!(&:to_sym)
|
5
|
+
raise ArgumentError, "has_neighbors requires an attribute name"
|
9
6
|
end
|
7
|
+
attribute_names.map!(&:to_sym)
|
10
8
|
|
11
9
|
class_eval do
|
12
10
|
@neighbor_attributes ||= {}
|
@@ -26,84 +24,116 @@ module Neighbor
|
|
26
24
|
|
27
25
|
attribute_names.each do |attribute_name|
|
28
26
|
raise Error, "has_neighbors already called for #{attribute_name.inspect}" if neighbor_attributes[attribute_name]
|
29
|
-
@neighbor_attributes[attribute_name] = {dimensions: dimensions, normalize: normalize}
|
27
|
+
@neighbor_attributes[attribute_name] = {dimensions: dimensions, normalize: normalize, type: type&.to_sym}
|
28
|
+
end
|
30
29
|
|
31
|
-
|
30
|
+
if ActiveRecord::VERSION::STRING.to_f >= 7.2
|
31
|
+
decorate_attributes(attribute_names) do |name, cast_type|
|
32
|
+
Neighbor::Attribute.new(cast_type: cast_type, model: self, type: type, attribute_name: name)
|
33
|
+
end
|
34
|
+
else
|
35
|
+
attribute_names.each do |attribute_name|
|
36
|
+
attribute attribute_name do |cast_type|
|
37
|
+
Neighbor::Attribute.new(cast_type: cast_type, model: self, type: type, attribute_name: attribute_name)
|
38
|
+
end
|
39
|
+
end
|
40
|
+
end
|
41
|
+
|
42
|
+
if normalize
|
43
|
+
if ActiveRecord::VERSION::STRING.to_f >= 7.1
|
44
|
+
attribute_names.each do |attribute_name|
|
45
|
+
normalizes attribute_name, with: ->(v) { Neighbor::Utils.normalize(v, column_info: columns_hash[attribute_name.to_s]) }
|
46
|
+
end
|
47
|
+
else
|
48
|
+
attribute_names.each do |attribute_name|
|
49
|
+
attribute attribute_name do |cast_type|
|
50
|
+
Neighbor::NormalizedAttribute.new(cast_type: cast_type, model: self, attribute_name: attribute_name)
|
51
|
+
end
|
52
|
+
end
|
53
|
+
end
|
32
54
|
end
|
33
55
|
|
34
56
|
return if @neighbor_attributes.size != attribute_names.size
|
35
57
|
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
58
|
+
validate do
|
59
|
+
adapter = Utils.adapter(self.class)
|
60
|
+
|
61
|
+
self.class.neighbor_attributes.each do |k, v|
|
62
|
+
value = read_attribute(k)
|
63
|
+
next if value.nil?
|
64
|
+
|
65
|
+
column_info = self.class.columns_hash[k.to_s]
|
66
|
+
dimensions = v[:dimensions]
|
67
|
+
dimensions ||= column_info&.limit unless column_info&.type == :binary
|
68
|
+
type = v[:type] || Utils.type(adapter, column_info&.type)
|
69
|
+
|
70
|
+
if !Neighbor::Utils.validate_dimensions(value, type, dimensions, adapter).nil?
|
71
|
+
errors.add(k, "must have #{dimensions} dimensions")
|
72
|
+
end
|
73
|
+
if !Neighbor::Utils.validate_finite(value, type)
|
74
|
+
errors.add(k, "must have finite values")
|
75
|
+
end
|
51
76
|
end
|
52
|
-
|
77
|
+
end
|
53
78
|
|
79
|
+
scope :nearest_neighbors, ->(attribute_name, vector, distance:, precision: nil) {
|
80
|
+
attribute_name = attribute_name.to_sym
|
54
81
|
options = neighbor_attributes[attribute_name]
|
55
82
|
raise ArgumentError, "Invalid attribute" unless options
|
56
83
|
normalize = options[:normalize]
|
57
84
|
dimensions = options[:dimensions]
|
85
|
+
type = options[:type]
|
58
86
|
|
59
87
|
return none if vector.nil?
|
60
88
|
|
61
89
|
distance = distance.to_s
|
62
90
|
|
63
|
-
|
91
|
+
column_info = columns_hash[attribute_name.to_s]
|
92
|
+
column_type = column_info&.type
|
64
93
|
|
65
|
-
|
66
|
-
|
67
|
-
|
68
|
-
|
69
|
-
case distance
|
70
|
-
when "inner_product"
|
71
|
-
"<#>"
|
72
|
-
when "cosine"
|
73
|
-
"<=>"
|
74
|
-
when "euclidean"
|
75
|
-
"<->"
|
76
|
-
end
|
77
|
-
else
|
78
|
-
case distance
|
79
|
-
when "taxicab"
|
80
|
-
"<#>"
|
81
|
-
when "chebyshev"
|
82
|
-
"<=>"
|
83
|
-
when "euclidean", "cosine"
|
84
|
-
"<->"
|
85
|
-
end
|
86
|
-
end
|
94
|
+
adapter = Neighbor::Utils.adapter(klass)
|
95
|
+
if type && adapter != :sqlite
|
96
|
+
raise ArgumentError, "type only works with SQLite"
|
97
|
+
end
|
87
98
|
|
99
|
+
operator = Neighbor::Utils.operator(adapter, column_type, distance)
|
88
100
|
raise ArgumentError, "Invalid distance: #{distance}" unless operator
|
89
101
|
|
90
102
|
# ensure normalize set (can be true or false)
|
91
|
-
|
103
|
+
normalize_required = Utils.normalize_required?(adapter, column_type)
|
104
|
+
if distance == "cosine" && normalize_required && normalize.nil?
|
92
105
|
raise Neighbor::Error, "Set normalize for cosine distance with cube"
|
93
106
|
end
|
94
107
|
|
95
|
-
|
108
|
+
column_attribute = klass.type_for_attribute(attribute_name)
|
109
|
+
vector = column_attribute.cast(vector)
|
110
|
+
dimensions ||= column_info&.limit unless column_info&.type == :binary
|
111
|
+
Neighbor::Utils.validate(vector, dimensions: dimensions, type: type || Utils.type(adapter, column_info&.type), adapter: adapter)
|
112
|
+
vector = Neighbor::Utils.normalize(vector, column_info: column_info) if normalize
|
113
|
+
|
114
|
+
quoted_attribute = nil
|
115
|
+
query = nil
|
116
|
+
connection_pool.with_connection do |c|
|
117
|
+
quoted_attribute = "#{c.quote_table_name(table_name)}.#{c.quote_column_name(attribute_name)}"
|
118
|
+
query = c.quote(column_attribute.serialize(vector))
|
119
|
+
end
|
96
120
|
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
121
|
+
if !precision.nil?
|
122
|
+
if adapter != :postgresql || column_type != :vector
|
123
|
+
raise ArgumentError, "Precision not supported for this type"
|
124
|
+
end
|
125
|
+
|
126
|
+
case precision.to_s
|
127
|
+
when "half"
|
128
|
+
cast_dimensions = dimensions || column_info&.limit
|
129
|
+
raise ArgumentError, "Unknown dimensions" unless cast_dimensions
|
130
|
+
quoted_attribute += "::halfvec(#{connection_pool.with_connection { |c| c.quote(cast_dimensions.to_i) }})"
|
102
131
|
else
|
103
|
-
|
132
|
+
raise ArgumentError, "Invalid precision"
|
104
133
|
end
|
134
|
+
end
|
105
135
|
|
106
|
-
order =
|
136
|
+
order = Utils.order(adapter, type, operator, quoted_attribute, query)
|
107
137
|
|
108
138
|
# https://stats.stackexchange.com/questions/146221/is-cosine-similarity-identical-to-l2-normalized-euclidean-distance
|
109
139
|
# with normalized vectors:
|
@@ -111,31 +141,28 @@ module Neighbor
|
|
111
141
|
# cosine distance = 1 - cosine similarity
|
112
142
|
# this transformation doesn't change the order, so only needed for select
|
113
143
|
neighbor_distance =
|
114
|
-
if
|
144
|
+
if distance == "cosine" && normalize_required
|
115
145
|
"POWER(#{order}, 2) / 2.0"
|
116
|
-
elsif
|
146
|
+
elsif [:vector, :halfvec, :sparsevec].include?(column_type) && distance == "inner_product"
|
117
147
|
"(#{order}) * -1"
|
118
148
|
else
|
119
149
|
order
|
120
150
|
end
|
121
151
|
|
122
152
|
# for select, use column_names instead of * to account for ignored columns
|
123
|
-
|
153
|
+
select_columns = select_values.any? ? [] : column_names
|
154
|
+
select(*select_columns, "#{neighbor_distance} AS neighbor_distance")
|
124
155
|
.where.not(attribute_name => nil)
|
125
|
-
.
|
156
|
+
.reorder(Arel.sql(order))
|
126
157
|
}
|
127
158
|
|
128
|
-
def nearest_neighbors(attribute_name
|
129
|
-
if attribute_name.nil?
|
130
|
-
warn "[neighbor] nearest_neighbors without an attribute name is deprecated"
|
131
|
-
attribute_name = :neighbor_vector
|
132
|
-
end
|
159
|
+
def nearest_neighbors(attribute_name, **options)
|
133
160
|
attribute_name = attribute_name.to_sym
|
134
|
-
# important! check if neighbor attribute before
|
161
|
+
# important! check if neighbor attribute before accessing
|
135
162
|
raise ArgumentError, "Invalid attribute" unless self.class.neighbor_attributes[attribute_name]
|
136
163
|
|
137
164
|
self.class
|
138
|
-
.where.not(self.class.primary_key
|
165
|
+
.where.not(Array(self.class.primary_key).to_h { |k| [k, self[k]] })
|
139
166
|
.nearest_neighbors(attribute_name, self[attribute_name], **options)
|
140
167
|
end
|
141
168
|
end
|
@@ -0,0 +1,37 @@
|
|
1
|
+
module Neighbor
  module MySQL
    # Registers the vector column type with the Active Record MySQL adapter:
    # native database type, table-definition column method and type map entry.
    def self.initialize!
      require_relative "type/mysql_vector"

      require "active_record/connection_adapters/abstract_mysql_adapter"

      adapter = ActiveRecord::ConnectionAdapters::AbstractMysqlAdapter
      table_definition = ActiveRecord::ConnectionAdapters::TableDefinition

      # ensure schema can be dumped
      adapter::NATIVE_DATABASE_TYPES[:vector] = {name: "vector"}

      # ensure schema can be loaded
      table_definition.send(:define_column_methods, :vector) unless table_definition.method_defined?(:vector)

      # prevent unknown OID warning
      adapter.singleton_class.prepend(RegisterTypes)
      # before Active Record 7.1 the vector type is also registered directly on the adapter's TYPE_MAP constant
      adapter.register_vector_type(adapter::TYPE_MAP) if ActiveRecord::VERSION::STRING.to_f < 7.1
    end

    # Prepended to the adapter's singleton class so that every type map that
    # gets initialized also knows about the vector type.
    module RegisterTypes
      def initialize_type_map(m)
        super
        register_vector_type(m)
      end

      # Maps SQL types starting with "vector" (case-insensitive) to
      # Type::MysqlVector, carrying over the limit extracted from the SQL type.
      def register_vector_type(m)
        m.register_type(%r(^vector)i) do |sql_type|
          Type::MysqlVector.new(limit: extract_limit(sql_type))
        end
      end
    end
  end
end
|
@@ -0,0 +1,21 @@
|
|
1
|
+
module Neighbor
  # Attribute type that wraps another cast type and runs every cast value
  # through Neighbor::Utils.normalize; type, serialize and deserialize are
  # forwarded to the wrapped cast type unchanged.
  class NormalizedAttribute < ActiveRecord::Type::Value
    delegate :type, :serialize, :deserialize, to: :@cast_type

    # cast_type      - the cast type being wrapped
    # model          - the model class that owns the attribute
    # attribute_name - name of the attribute (stored as a String for the columns_hash lookup)
    def initialize(cast_type:, model:, attribute_name:)
      @cast_type, @model, @attribute_name = cast_type, model, attribute_name.to_s
    end

    # Casts with the wrapped type first, then normalizes the result using the
    # column metadata of the attribute.
    def cast(...)
      casted = @cast_type.cast(...)
      column = @model.columns_hash[@attribute_name]
      Neighbor::Utils.normalize(casted, column_info: column)
    end

    private

    def cast_value(...)
      @cast_type.send(:cast_value, ...)
    end
  end
end
|
@@ -0,0 +1,43 @@
|
|
1
|
+
module Neighbor
  module PostgreSQL
    # Registers the cube, halfvec, sparsevec and vector column types with the
    # Active Record PostgreSQL adapter.
    def self.initialize!
      require_relative "type/cube"
      require_relative "type/halfvec"
      require_relative "type/sparsevec"
      require_relative "type/vector"

      require "active_record/connection_adapters/postgresql_adapter"

      adapter = ActiveRecord::ConnectionAdapters::PostgreSQLAdapter
      column_types = %i[cube halfvec sparsevec vector]

      # ensure schema can be dumped
      column_types.each do |column_type|
        adapter::NATIVE_DATABASE_TYPES[column_type] = {name: column_type.to_s}
      end

      # ensure schema can be loaded
      ActiveRecord::ConnectionAdapters::TableDefinition.send(:define_column_methods, *column_types)

      # prevent unknown OID warning
      adapter.singleton_class.prepend(RegisterTypes)
    end

    # Prepended to the adapter's singleton class so that type maps it
    # initializes also contain the Neighbor types.
    module RegisterTypes
      def initialize_type_map(m = type_map)
        super
        m.register_type "cube", Type::Cube.new

        # these types carry a limit that is extracted from the SQL type
        sized_types = {"halfvec" => Type::Halfvec, "sparsevec" => Type::Sparsevec, "vector" => Type::Vector}
        sized_types.each do |name, type_class|
          m.register_type(name) do |_, _, sql_type|
            type_class.new(limit: extract_limit(sql_type))
          end
        end
      end
    end
  end
end
|
data/lib/neighbor/railtie.rb
CHANGED
@@ -1,16 +1,16 @@
|
|
1
1
|
module Neighbor
|
2
2
|
class Railtie < Rails::Railtie
|
3
3
|
generators do
|
4
|
+
require "rails/generators/generated_attribute"
|
5
|
+
|
4
6
|
# rails generate model Item embedding:vector{3}
|
5
|
-
|
6
|
-
Rails::Generators::GeneratedAttribute.singleton_class.prepend(Neighbor::GeneratedAttribute)
|
7
|
-
end
|
7
|
+
Rails::Generators::GeneratedAttribute.singleton_class.prepend(Neighbor::GeneratedAttribute)
|
8
8
|
end
|
9
9
|
end
|
10
10
|
|
11
11
|
module GeneratedAttribute
|
12
12
|
def parse_type_and_options(type, *, **)
|
13
|
-
if type =~ /\A(vector)\{(\d+)\}\z/
|
13
|
+
if type =~ /\A(vector|halfvec|bit|sparsevec)\{(\d+)\}\z/
|
14
14
|
return $1, limit: $2.to_i
|
15
15
|
end
|
16
16
|
super
|
@@ -0,0 +1,27 @@
|
|
1
|
+
module Neighbor
  module Reranking
    # Combines one or more rankings into a single ranking using
    # Reciprocal Rank Fusion (RRF).
    #
    # Each item receives 1.0 / (k + rank) from every ranking that contains it
    # (rank is the 1-based position) and 0.0 from rankings that do not; the
    # contributions are summed into the item's score. If a ranking lists the
    # same item more than once, its last position in that ranking is used.
    #
    # first_ranking, *rankings - ordered collections of items (best first)
    # k                        - constant added to every rank (default 60)
    #
    # Returns an Array of {result:, score:} hashes, highest score first.
    # Items with equal scores keep the order in which they were first seen
    # across the rankings.
    def self.rrf(first_ranking, *rankings, k: 60)
      all_rankings = [first_ranking, *rankings]

      # one lookup table per ranking: item => 1-based position
      positions = all_rankings.map do |ranking|
        ranking.each_with_index.to_h { |item, index| [item, index + 1] }
      end

      scored = all_rankings.flat_map(&:to_a).uniq.map do |item|
        score = positions.sum do |position|
          rank = position[item]
          rank ? 1.0 / (k + rank) : 0.0
        end

        {result: item, score: score}
      end

      # sort_by is not stable, so ties are broken explicitly on first-seen order
      scored.each_with_index.sort_by { |entry, index| [-entry[:score], index] }.map(&:first)
    end
  end
end
|
@@ -0,0 +1,79 @@
|
|
1
|
+
module Neighbor
  # Sparse vector stored as two parallel arrays (zero-based indices and their
  # non-zero values) together with the total number of dimensions.
  class SparseVector
    attr_reader :dimensions, :indices, :values

    # Sentinel used to detect whether a dimensions argument was passed at all.
    NO_DEFAULT = Object.new

    # value      - a Hash of index => value (requires dimensions), or anything
    #              that responds to to_a (dimensions must then be omitted)
    # dimensions - total number of dimensions; only allowed with a Hash
    #
    # Raises ArgumentError when dimensions is missing for a Hash or given for
    # a non-Hash value.
    def initialize(value, dimensions = NO_DEFAULT)
      if value.is_a?(Hash)
        raise ArgumentError, "missing dimensions" if dimensions == NO_DEFAULT

        from_hash(value, dimensions)
      else
        raise ArgumentError, "extra argument" unless dimensions == NO_DEFAULT

        from_array(value)
      end
    end

    # Text form "{index:value,...}/dimensions"; indices are written 1-based.
    def to_s
      pairs = @indices.zip(@values).map { |i, v| "#{i.to_i + 1}:#{v.to_f}" }
      "{#{pairs.join(",")}}/#{@dimensions.to_i}"
    end

    # Dense Array with 0.0 in every position that has no stored value.
    def to_a
      arr = Array.new(dimensions, 0.0)
      @indices.zip(@values) do |i, v|
        arr[i] = v
      end
      arr
    end

    private

    # Builds the vector from index => value pairs, dropping zero values.
    def from_hash(data, dimensions)
      # convert keys before sorting so that string keys are ordered
      # numerically ("9" before "10") and mixed key types can be compared
      elements = data.select { |_, v| v != 0 }.map { |i, v| [i.to_i, v.to_f] }.sort_by(&:first)
      @dimensions = dimensions.to_i
      @indices = elements.map(&:first)
      @values = elements.map(&:last)
    end

    # Builds the vector from a dense list, keeping only the non-zero entries.
    def from_array(arr)
      arr = arr.to_a
      @dimensions = arr.size
      @indices = []
      @values = []
      arr.each_with_index do |v, i|
        next if v == 0

        @indices << i
        @values << v.to_f
      end
    end

    class << self
      # Parses the text form produced by to_s ("{1:1.0,3:2.0}/5"), converting
      # the 1-based indices in the text back to zero-based ones.
      def from_text(string)
        elements, dimensions = string.split("/", 2)
        indices = []
        values = []
        elements[1..-2].split(",").each do |e|
          index, value = e.split(":", 2)
          indices << index.to_i - 1
          values << value.to_f
        end
        from_parts(dimensions.to_i, indices, values)
      end

      private

      # Creates an instance directly from its parts, bypassing initialize.
      def from_parts(dimensions, indices, values)
        vec = allocate
        vec.instance_variable_set(:@dimensions, dimensions)
        vec.instance_variable_set(:@indices, indices)
        vec.instance_variable_set(:@values, values)
        vec
      end
    end
  end
end
|
@@ -0,0 +1,28 @@
|
|
1
|
+
module Neighbor
  module SQLite
    # note: this is a public API (unlike PostgreSQL and MySQL)
    #
    # Loads the SQLite vector types and prepends InstanceMethods to the
    # Active Record SQLite3 adapter. Calling it again after it has completed
    # is a no-op.
    def self.initialize!
      return if defined?(@initialized)

      require_relative "type/sqlite_vector"
      require_relative "type/sqlite_int8_vector"

      require "sqlite_vec"
      require "active_record/connection_adapters/sqlite3_adapter"

      ActiveRecord::ConnectionAdapters::SQLite3Adapter.prepend(InstanceMethods)

      @initialized = true
    end

    module InstanceMethods
      # Runs the adapter's own connection setup, then calls SqliteVec.load on
      # the underlying database handle.
      def configure_connection
        super
        # the raw handle is held in @raw_connection from Active Record 7.1 on, @connection before that
        db = ActiveRecord::VERSION::STRING.to_f >= 7.1 ? @raw_connection : @connection
        db.enable_load_extension(1)
        begin
          SqliteVec.load(db)
        ensure
          # always switch extension loading back off, even when the load fails
          db.enable_load_extension(0)
        end
      end
    end
  end
end
|
data/lib/neighbor/type/cube.rb
CHANGED
@@ -1,36 +1,41 @@
|
|
1
1
|
module Neighbor
|
2
2
|
module Type
|
3
|
-
class Cube < ActiveRecord::Type::
|
3
|
+
class Cube < ActiveRecord::Type::Value
|
4
4
|
def type
|
5
5
|
:cube
|
6
6
|
end
|
7
7
|
|
8
|
-
def
|
9
|
-
if
|
8
|
+
def serialize(value)
|
9
|
+
if Utils.array?(value)
|
10
|
+
value = value.to_a
|
10
11
|
if value.first.is_a?(Array)
|
11
|
-
value.map { |v|
|
12
|
+
value = value.map { |v| serialize_point(v) }.join(", ")
|
12
13
|
else
|
13
|
-
|
14
|
+
value = serialize_point(value)
|
14
15
|
end
|
15
|
-
else
|
16
|
-
super
|
17
16
|
end
|
17
|
+
super(value)
|
18
18
|
end
|
19
19
|
|
20
|
-
# TODO uncomment in 0.4.0
|
21
|
-
# def deserialize(value)
|
22
|
-
# if value.nil?
|
23
|
-
# super
|
24
|
-
# elsif value.include?("),(")
|
25
|
-
# value[1..-1].split("),(").map { |v| v.split(",").map(&:to_f) }
|
26
|
-
# else
|
27
|
-
# value[1..-1].split(",").map(&:to_f)
|
28
|
-
# end
|
29
|
-
# end
|
30
|
-
|
31
20
|
private
|
32
21
|
|
33
|
-
def
|
22
|
+
def cast_value(value)
|
23
|
+
if Utils.array?(value)
|
24
|
+
value.to_a
|
25
|
+
elsif value.is_a?(Numeric)
|
26
|
+
[value]
|
27
|
+
elsif value.is_a?(String)
|
28
|
+
if value.include?("),(")
|
29
|
+
value[1..-1].split("),(").map { |v| v.split(",").map(&:to_f) }
|
30
|
+
else
|
31
|
+
value[1..-1].split(",").map(&:to_f)
|
32
|
+
end
|
33
|
+
else
|
34
|
+
raise "can't cast #{value.class.name} to cube"
|
35
|
+
end
|
36
|
+
end
|
37
|
+
|
38
|
+
def serialize_point(value)
|
34
39
|
"(#{value.map(&:to_f).join(", ")})"
|
35
40
|
end
|
36
41
|
end
|
@@ -0,0 +1,28 @@
|
|
1
|
+
module Neighbor
  module Type
    # Active Record type reported as :halfvec.
    class Halfvec < ActiveRecord::Type::Value
      def type
        :halfvec
      end

      # Array-like values (per Utils.array?) are written as "[a,b,c]" with
      # every element converted to a float; anything else is passed on as is.
      def serialize(value)
        return super(value) unless Utils.array?(value)

        elements = value.to_a.map(&:to_f)
        super("[#{elements.join(",")}]")
      end

      private

      # Strings are parsed by dropping the first character and splitting on
      # commas; array-like values are converted with to_a. Raises for any
      # other kind of value.
      def cast_value(value)
        return value[1..-1].split(",").map(&:to_f) if value.is_a?(String)
        return value.to_a if Utils.array?(value)

        raise "can't cast #{value.class.name} to halfvec"
      end
    end
  end
end
|