neighbor 0.3.2 → 0.5.0

Sign up to get free protection for your applications and to get access to all the features.
@@ -0,0 +1,33 @@
1
module Neighbor
  module Type
    # Active Record type for MySQL vector columns.
    #
    # Array-like values are written as binary strings of little-endian
    # single-precision floats (Array#pack "e*") and read back as Arrays of
    # Floats (String#unpack "e*").
    class MysqlVector < ActiveRecord::Type::Binary
      def type
        :vector
      end

      # Packs array-like values into bytes; anything else (String, nil, ...)
      # is handed to ActiveRecord::Type::Binary unchanged.
      def serialize(value)
        payload = Utils.array?(value) ? value.to_a.pack("e*") : value
        super(payload)
      end

      # Runs the superclass deserialization, then normalizes the result
      # through cast_value. nil stays nil.
      def deserialize(value)
        raw = super
        return nil if raw.nil?

        cast_value(raw)
      end

      private

      # Binary strings are unpacked as little-endian floats; array-like values
      # are converted with #to_a. Anything else raises a RuntimeError.
      def cast_value(value)
        return value.unpack("e*") if value.is_a?(String)
        return value.to_a if Utils.array?(value)

        raise "can't cast #{value.class.name} to vector"
      end
    end
  end
end
@@ -0,0 +1,30 @@
1
module Neighbor
  module Type
    # Active Record type for pgvector `sparsevec` columns.
    #
    # Values are exposed as Neighbor::SparseVector instances. They are written
    # in the text form "{index:value,...}/dimensions". The +1 applied to each
    # index assumes SparseVector indices are 0-based while the text form is
    # 1-based (TODO confirm against SparseVector).
    class Sparsevec < ActiveRecord::Type::Value
      def type
        :sparsevec
      end

      # Converts SparseVector and array-like values to the sparsevec text form.
      #
      # Array-like input is first turned into a SparseVector (the same way
      # #cast does). Previously a plain Array fell through to super and was
      # handed to the database as a Ruby Array. The other vector types in this
      # gem all serialize array input. Strings and nil are passed to super
      # unchanged.
      def serialize(value)
        value = cast_value(value) if Utils.array?(value)
        value = format_sparsevec(value) if value.is_a?(SparseVector)
        super(value)
      end

      private

      # Casts a SparseVector (returned as-is), a text-form String, or an
      # array-like value to a SparseVector. Anything else raises a
      # RuntimeError.
      def cast_value(value)
        if value.is_a?(SparseVector)
          value
        elsif value.is_a?(String)
          SparseVector.from_text(value)
        elsif Utils.array?(value)
          SparseVector.new(value.to_a)
        else
          raise "can't cast #{value.class.name} to sparsevec"
        end
      end

      # Formats a SparseVector as "{i1:v1,i2:v2,...}/dimensions", shifting
      # each index by one for the text form.
      def format_sparsevec(vector)
        entries = vector.indices.zip(vector.values).map { |i, v| "#{i.to_i + 1}:#{v.to_f}" }
        "{#{entries.join(",")}}/#{vector.dimensions.to_i}"
      end
    end
  end
end
@@ -0,0 +1,29 @@
1
module Neighbor
  module Type
    # Binary column type for int8 vectors stored in SQLite.
    #
    # Array-like values are stored as one signed 8-bit integer per element
    # (Array#pack "c*") and read back as an Array of Integers.
    class SqliteInt8Vector < ActiveRecord::Type::Binary
      # Packs array-like values; other values go to Binary untouched.
      def serialize(value)
        super(Utils.array?(value) ? value.to_a.pack("c*") : value)
      end

      # Deserializes via the superclass, then converts the result with
      # cast_value. nil stays nil.
      def deserialize(value)
        stored = super
        stored.nil? ? nil : cast_value(stored)
      end

      private

      # Byte strings are unpacked as signed chars; array-like values are
      # converted with #to_a. Anything else raises a RuntimeError.
      def cast_value(value)
        case value
        when String
          value.unpack("c*")
        else
          raise "can't cast #{value.class.name} to vector" unless Utils.array?(value)

          value.to_a
        end
      end
    end
  end
end
@@ -0,0 +1,29 @@
1
module Neighbor
  module Type
    # Binary column type for float vectors stored in SQLite.
    #
    # Array-like values are stored as native-endian single-precision floats
    # (Array#pack "f*") and read back as an Array of Floats.
    class SqliteVector < ActiveRecord::Type::Binary
      # Packs array-like values; other values go to Binary untouched.
      def serialize(value)
        return super(value) unless Utils.array?(value)

        super(value.to_a.pack("f*"))
      end

      # Deserializes via the superclass, then converts the result with
      # cast_value. nil stays nil.
      def deserialize(value)
        result = super
        result.nil? ? nil : cast_value(result)
      end

      private

      # Byte strings are unpacked as floats; array-like values are converted
      # with #to_a. Anything else raises a RuntimeError.
      def cast_value(value)
        case value
        when String then value.unpack("f*")
        else
          if Utils.array?(value)
            value.to_a
          else
            raise "can't cast #{value.class.name} to vector"
          end
        end
      end
    end
  end
end
@@ -1,14 +1,28 @@
1
1
  module Neighbor
2
2
  module Type
3
- class Vector < ActiveRecord::Type::String
3
+ class Vector < ActiveRecord::Type::Value
4
4
  def type
5
5
  :vector
6
6
  end
7
7
 
8
- # TODO uncomment in 0.4.0
9
- # def deserialize(value)
10
- # value[1..-1].split(",").map(&:to_f) unless value.nil?
11
- # end
8
+ def serialize(value)
9
+ if Utils.array?(value)
10
+ value = "[#{value.to_a.map(&:to_f).join(",")}]"
11
+ end
12
+ super(value)
13
+ end
14
+
15
+ private
16
+
17
+ def cast_value(value)
18
+ if value.is_a?(String)
19
+ value[1..-1].split(",").map(&:to_f)
20
+ elsif Utils.array?(value)
21
+ value.to_a
22
+ else
23
+ raise "can't cast #{value.class.name} to vector"
24
+ end
25
+ end
12
26
  end
13
27
  end
14
28
  end
@@ -0,0 +1,201 @@
1
module Neighbor
  # Adapter-aware helpers covering:
  # - dimension and finiteness validation;
  # - L2 normalization;
  # - adapter and column-type detection;
  # - mapping distance names to SQL operators and expressions.
  module Utils
    # Returns an error message when value's dimension count differs from
    # expected, or nil when it matches or when no expectation was given.
    #
    # Sparse vectors report #dimensions; everything else uses #size. For
    # :bit columns on SQLite and MySQL the size is multiplied by 8
    # (presumably because bits are packed into bytes there).
    def self.validate_dimensions(value, type, expected, adapter)
      actual = type == :sparsevec ? value.dimensions : value.size
      actual *= 8 if type == :bit && %i[sqlite mysql].include?(adapter)
      return nil unless expected && actual != expected

      "Expected #{expected} dimensions, not #{actual}"
    end

    # True when every element is finite. :bit and :integer values are always
    # accepted; sparse vectors check only their stored #values.
    def self.validate_finite(value, type)
      return true if type == :bit || type == :integer

      elements = type == :sparsevec ? value.values : value
      elements.all?(&:finite?)
    end

    # Raises Neighbor::Error on a dimension mismatch or on non-finite values.
    def self.validate(value, dimensions:, type:, adapter:)
      message = validate_dimensions(value, type, dimensions, adapter)
      raise Error, message if message
      raise Error, "Values must be finite" unless validate_finite(value, type)
    end

    # Scales value to unit L2 length.
    #
    # Returns nil for nil input. Raises Neighbor::Error unless the column
    # type is :cube, :vector, :halfvec or :binary. A zero vector is returned
    # unchanged, because dividing by a zero norm would produce NaN.
    def self.normalize(value, column_info:)
      return nil if value.nil?

      unless %i[cube vector halfvec binary].include?(column_info&.type)
        raise Error, "Normalize not supported for type"
      end

      magnitude = Math.sqrt(value.sum { |x| x * x })
      return value unless magnitude > 0

      value.map { |x| x / magnitude }
    end

    # True for any non-nil value that responds to #to_a.
    def self.array?(value)
      value.nil? ? false : value.respond_to?(:to_a)
    end

    # Maps the model's adapter name to :sqlite, :mysql, :mariadb or
    # :postgresql (the fallback). MySQL-family adapters are distinguished by
    # asking a pooled connection for #mariadb?.
    def self.adapter(model)
      case model.connection_db_config.adapter
      when /sqlite/i then :sqlite
      when /mysql|trilogy/i
        mariadb = model.connection_pool.with_connection { |conn| conn.try(:mariadb?) }
        mariadb ? :mariadb : :mysql
      else :postgresql
      end
    end

    # On MySQL, :binary columns are treated as :bit; otherwise the column
    # type is returned as-is.
    def self.type(adapter, column_type)
      adapter == :mysql && column_type == :binary ? :bit : column_type
    end

    # Returns the SQL operator or function name for the distance, or nil when
    # the distance is not supported. Raises ArgumentError for unsupported
    # column types. SQLite ignores the column type.
    def self.operator(adapter, column_type, distance)
      case adapter
      when :sqlite then sqlite_operator(distance)
      when :mariadb then mariadb_operator(column_type, distance)
      when :mysql then mysql_operator(column_type, distance)
      else postgresql_operator(column_type, distance)
      end
    end

    # sqlite-vec distance functions.
    def self.sqlite_operator(distance)
      case distance
      when "euclidean" then "vec_distance_L2"
      when "cosine" then "vec_distance_cosine"
      when "taxicab" then "vec_distance_L1"
      when "hamming" then "vec_distance_hamming"
      end
    end

    # MariaDB: VEC_DISTANCE for binary vectors, BIT_COUNT for integers.
    def self.mariadb_operator(column_type, distance)
      case column_type
      when :binary
        case distance
        when "euclidean", "cosine" then "VEC_DISTANCE"
        end
      when :integer
        case distance
        when "hamming" then "BIT_COUNT"
        end
      else
        raise ArgumentError, "Unsupported type: #{column_type}"
      end
    end

    # MySQL: DISTANCE metric names for vectors, BIT_COUNT for binary.
    def self.mysql_operator(column_type, distance)
      case column_type
      when :vector
        case distance
        when "cosine" then "COSINE"
        when "euclidean" then "EUCLIDEAN"
        end
      when :binary
        case distance
        when "hamming" then "BIT_COUNT"
        end
      else
        raise ArgumentError, "Unsupported type: #{column_type}"
      end
    end

    # Fallback for every other adapter: pgvector and cube operators.
    def self.postgresql_operator(column_type, distance)
      case column_type
      when :bit
        case distance
        when "hamming" then "<~>"
        when "jaccard" then "<%>"
        when "hamming2" then "#"
        end
      when :vector, :halfvec, :sparsevec
        case distance
        when "inner_product" then "<#>"
        when "cosine" then "<=>"
        when "euclidean" then "<->"
        when "taxicab" then "<+>"
        end
      when :cube
        case distance
        when "taxicab" then "<#>"
        when "chebyshev" then "<=>"
        when "euclidean", "cosine" then "<->"
        end
      else
        raise ArgumentError, "Unsupported type: #{column_type}"
      end
    end

    private_class_method :sqlite_operator, :mariadb_operator, :mysql_operator, :postgresql_operator

    # Builds the SQL distance expression between the quoted column and the
    # query operand, using the operator returned by .operator.
    def self.order(adapter, type, operator, quoted_attribute, query)
      case adapter
      when :sqlite
        wrapper = {int8: "vec_int8", bit: "vec_bit"}[type]
        if wrapper
          "#{operator}(#{wrapper}(#{quoted_attribute}), #{wrapper}(#{query}))"
        else
          "#{operator}(#{quoted_attribute}, #{query})"
        end
      when :mariadb, :mysql
        if operator == "BIT_COUNT"
          "BIT_COUNT(#{quoted_attribute} ^ #{query})"
        elsif adapter == :mariadb
          "VEC_DISTANCE(#{quoted_attribute}, #{query})"
        else
          metric = operator == "COSINE" ? "COSINE" : "EUCLIDEAN"
          "DISTANCE(#{quoted_attribute}, #{query}, '#{metric}')"
        end
      else
        if operator == "#"
          "bit_count(#{quoted_attribute} # #{query})"
        else
          "#{quoted_attribute} #{operator} #{query}"
        end
      end
    end

    # True for cube columns on PostgreSQL and for every MariaDB column.
    # In .operator both of these map "cosine" to the same operator as
    # "euclidean". Presumably vectors are normalized so that this yields
    # cosine ordering.
    def self.normalize_required?(adapter, column_type)
      return column_type == :cube if adapter == :postgresql

      adapter == :mariadb
    end
  end
end
@@ -1,3 +1,3 @@
1
1
  module Neighbor
2
- VERSION = "0.3.2"
2
+ VERSION = "0.5.0"
3
3
  end
data/lib/neighbor.rb CHANGED
@@ -1,47 +1,35 @@
1
1
  # dependencies
  require "active_support"

  # adapter hooks
  # (each defines an initialize! entry point that is invoked once Active Record loads, below)
  require_relative "neighbor/mysql"
  require_relative "neighbor/postgresql"
  require_relative "neighbor/sqlite"

  # modules
  require_relative "neighbor/reranking"
  require_relative "neighbor/sparse_vector"
  require_relative "neighbor/utils"
  require_relative "neighbor/version"

  module Neighbor
    # Base error class for the gem; Neighbor::Utils raises it for
    # dimension/finiteness validation and unsupported normalization.
    class Error < StandardError; end
  end

  # Deferred until Active Record is loaded. The block runs in the context of
  # ActiveRecord::Base, so `extend` below applies there. These files presumably
  # reference Active Record constants at load time (TODO confirm).
  ActiveSupport.on_load(:active_record) do
    require_relative "neighbor/attribute"
    require_relative "neighbor/model"
    require_relative "neighbor/normalized_attribute"

    # makes Neighbor::Model's methods available at the model-class level
    extend Neighbor::Model

    begin
      Neighbor::PostgreSQL.initialize!
    rescue Gem::LoadError
      # tries to load pg gem, which may not be available
      # (PostgreSQL support is then skipped rather than failing the app boot)
    end

    Neighbor::MySQL.initialize!
  end

  # only hook into Rails when it is present
  require_relative "neighbor/railtie" if defined?(Rails::Railtie)
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: neighbor
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.3.2
4
+ version: 0.5.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2023-12-12 00:00:00.000000000 Z
11
+ date: 2024-10-08 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: activerecord
@@ -16,14 +16,14 @@ dependencies:
16
16
  requirements:
17
17
  - - ">="
18
18
  - !ruby/object:Gem::Version
19
- version: '6.1'
19
+ version: '7'
20
20
  type: :runtime
21
21
  prerelease: false
22
22
  version_requirements: !ruby/object:Gem::Requirement
23
23
  requirements:
24
24
  - - ">="
25
25
  - !ruby/object:Gem::Version
26
- version: '6.1'
26
+ version: '7'
27
27
  description:
28
28
  email: andrew@ankane.org
29
29
  executables: []
@@ -34,15 +34,29 @@ files:
34
34
  - LICENSE.txt
35
35
  - README.md
36
36
  - lib/generators/neighbor/cube_generator.rb
37
+ - lib/generators/neighbor/sqlite_generator.rb
37
38
  - lib/generators/neighbor/templates/cube.rb.tt
39
+ - lib/generators/neighbor/templates/sqlite.rb.tt
38
40
  - lib/generators/neighbor/templates/vector.rb.tt
39
41
  - lib/generators/neighbor/vector_generator.rb
40
42
  - lib/neighbor.rb
43
+ - lib/neighbor/attribute.rb
41
44
  - lib/neighbor/model.rb
45
+ - lib/neighbor/mysql.rb
46
+ - lib/neighbor/normalized_attribute.rb
47
+ - lib/neighbor/postgresql.rb
42
48
  - lib/neighbor/railtie.rb
49
+ - lib/neighbor/reranking.rb
50
+ - lib/neighbor/sparse_vector.rb
51
+ - lib/neighbor/sqlite.rb
43
52
  - lib/neighbor/type/cube.rb
53
+ - lib/neighbor/type/halfvec.rb
54
+ - lib/neighbor/type/mysql_vector.rb
55
+ - lib/neighbor/type/sparsevec.rb
56
+ - lib/neighbor/type/sqlite_int8_vector.rb
57
+ - lib/neighbor/type/sqlite_vector.rb
44
58
  - lib/neighbor/type/vector.rb
45
- - lib/neighbor/vector.rb
59
+ - lib/neighbor/utils.rb
46
60
  - lib/neighbor/version.rb
47
61
  homepage: https://github.com/ankane/neighbor
48
62
  licenses:
@@ -56,15 +70,15 @@ required_ruby_version: !ruby/object:Gem::Requirement
56
70
  requirements:
57
71
  - - ">="
58
72
  - !ruby/object:Gem::Version
59
- version: '3'
73
+ version: '3.1'
60
74
  required_rubygems_version: !ruby/object:Gem::Requirement
61
75
  requirements:
62
76
  - - ">="
63
77
  - !ruby/object:Gem::Version
64
78
  version: '0'
65
79
  requirements: []
66
- rubygems_version: 3.4.10
80
+ rubygems_version: 3.5.16
67
81
  signing_key:
68
82
  specification_version: 4
69
- summary: Nearest neighbor search for Rails and Postgres
83
+ summary: Nearest neighbor search for Rails
70
84
  test_files: []
@@ -1,65 +0,0 @@
1
module Neighbor
  # Legacy (0.3.x) attribute type. It:
  # - casts assigned values to Arrays of Floats;
  # - validates their dimensions and finiteness, optionally normalizing them;
  # - serializes them as "[...]" for `vector` columns and "(...)" for any
  #   other column type (presumably `cube`).
  class Vector < ActiveRecord::Type::Value
    def initialize(dimensions:, normalize:, model:, attribute_name:)
      super()
      @dimensions = dimensions
      @normalize = normalize
      @model = model
      @attribute_name = attribute_name
    end

    # Converts value to Floats, enforces the expected dimension count and
    # finiteness, and scales to unit length when normalize is set.
    # The column's limit is used when dimensions is nil.
    # Zero vectors are left as all zeros, since a NaN result would make every
    # distance 0.
    def self.cast(value, dimensions:, normalize:, column_info:)
      floats = value.to_a.map(&:to_f)

      expected = dimensions || column_info[:dimensions]
      if expected && floats.size != expected
        raise Error, "Expected #{expected} dimensions, not #{floats.size}"
      end

      raise Error, "Values must be finite" unless floats.all?(&:finite?)

      return floats unless normalize

      magnitude = Math.sqrt(floats.sum { |f| f * f })
      # in-place is safe: `floats` is already a fresh copy from map above
      magnitude > 0 ? floats.map! { |f| f / magnitude } : floats
    end

    # Looks up the attribute's column and reports its type and limit.
    def self.column_info(model, attribute_name)
      name = attribute_name.to_s
      column = model.columns.find { |c| c.name == name }
      {type: column.try(:type), dimensions: column.try(:limit)}
    end

    # Memoized and resolved lazily so column info isn't loaded before needed.
    def column_info
      @column_info ||= self.class.column_info(@model, @attribute_name)
    end

    def cast(value)
      return nil if value.nil?

      self.class.cast(value, dimensions: @dimensions, normalize: @normalize, column_info: column_info)
    end

    def serialize(value)
      return nil if value.nil?

      opening, closing = column_info[:type] == :vector ? ["[", "]"] : ["(", ")"]
      "#{opening}#{cast(value).join(", ")}#{closing}"
    end

    def deserialize(value)
      return nil if value.nil?

      value[1..].split(",").map(&:to_f)
    end
  end
end