neighbor 0.3.2 → 0.5.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,33 @@
1
+ module Neighbor
2
+ module Type
3
+ class MysqlVector < ActiveRecord::Type::Binary
4
+ def type
5
+ :vector
6
+ end
7
+
8
+ def serialize(value)
9
+ if Utils.array?(value)
10
+ value = value.to_a.pack("e*")
11
+ end
12
+ super(value)
13
+ end
14
+
15
+ def deserialize(value)
16
+ value = super
17
+ cast_value(value) unless value.nil?
18
+ end
19
+
20
+ private
21
+
22
+ def cast_value(value)
23
+ if value.is_a?(String)
24
+ value.unpack("e*")
25
+ elsif Utils.array?(value)
26
+ value.to_a
27
+ else
28
+ raise "can't cast #{value.class.name} to vector"
29
+ end
30
+ end
31
+ end
32
+ end
33
+ end
@@ -0,0 +1,30 @@
1
+ module Neighbor
2
+ module Type
3
+ class Sparsevec < ActiveRecord::Type::Value
4
+ def type
5
+ :sparsevec
6
+ end
7
+
8
+ def serialize(value)
9
+ if value.is_a?(SparseVector)
10
+ value = "{#{value.indices.zip(value.values).map { |i, v| "#{i.to_i + 1}:#{v.to_f}" }.join(",")}}/#{value.dimensions.to_i}"
11
+ end
12
+ super(value)
13
+ end
14
+
15
+ private
16
+
17
+ def cast_value(value)
18
+ if value.is_a?(SparseVector)
19
+ value
20
+ elsif value.is_a?(String)
21
+ SparseVector.from_text(value)
22
+ elsif Utils.array?(value)
23
+ value = SparseVector.new(value.to_a)
24
+ else
25
+ raise "can't cast #{value.class.name} to sparsevec"
26
+ end
27
+ end
28
+ end
29
+ end
30
+ end
@@ -0,0 +1,29 @@
1
+ module Neighbor
2
+ module Type
3
+ class SqliteInt8Vector < ActiveRecord::Type::Binary
4
+ def serialize(value)
5
+ if Utils.array?(value)
6
+ value = value.to_a.pack("c*")
7
+ end
8
+ super(value)
9
+ end
10
+
11
+ def deserialize(value)
12
+ value = super
13
+ cast_value(value) unless value.nil?
14
+ end
15
+
16
+ private
17
+
18
+ def cast_value(value)
19
+ if value.is_a?(String)
20
+ value.unpack("c*")
21
+ elsif Utils.array?(value)
22
+ value.to_a
23
+ else
24
+ raise "can't cast #{value.class.name} to vector"
25
+ end
26
+ end
27
+ end
28
+ end
29
+ end
@@ -0,0 +1,29 @@
1
+ module Neighbor
2
+ module Type
3
+ class SqliteVector < ActiveRecord::Type::Binary
4
+ def serialize(value)
5
+ if Utils.array?(value)
6
+ value = value.to_a.pack("f*")
7
+ end
8
+ super(value)
9
+ end
10
+
11
+ def deserialize(value)
12
+ value = super
13
+ cast_value(value) unless value.nil?
14
+ end
15
+
16
+ private
17
+
18
+ def cast_value(value)
19
+ if value.is_a?(String)
20
+ value.unpack("f*")
21
+ elsif Utils.array?(value)
22
+ value.to_a
23
+ else
24
+ raise "can't cast #{value.class.name} to vector"
25
+ end
26
+ end
27
+ end
28
+ end
29
+ end
@@ -1,14 +1,28 @@
1
1
  module Neighbor
2
2
  module Type
3
- class Vector < ActiveRecord::Type::String
3
+ class Vector < ActiveRecord::Type::Value
4
4
  def type
5
5
  :vector
6
6
  end
7
7
 
8
- # TODO uncomment in 0.4.0
9
- # def deserialize(value)
10
- # value[1..-1].split(",").map(&:to_f) unless value.nil?
11
- # end
8
+ def serialize(value)
9
+ if Utils.array?(value)
10
+ value = "[#{value.to_a.map(&:to_f).join(",")}]"
11
+ end
12
+ super(value)
13
+ end
14
+
15
+ private
16
+
17
+ def cast_value(value)
18
+ if value.is_a?(String)
19
+ value[1..-1].split(",").map(&:to_f)
20
+ elsif Utils.array?(value)
21
+ value.to_a
22
+ else
23
+ raise "can't cast #{value.class.name} to vector"
24
+ end
25
+ end
12
26
  end
13
27
  end
14
28
  end
@@ -0,0 +1,201 @@
1
+ module Neighbor
2
+ module Utils
3
+ def self.validate_dimensions(value, type, expected, adapter)
4
+ dimensions = type == :sparsevec ? value.dimensions : value.size
5
+ dimensions *= 8 if type == :bit && [:sqlite, :mysql].include?(adapter)
6
+
7
+ if expected && dimensions != expected
8
+ "Expected #{expected} dimensions, not #{dimensions}"
9
+ end
10
+ end
11
+
12
+ def self.validate_finite(value, type)
13
+ case type
14
+ when :bit, :integer
15
+ true
16
+ when :sparsevec
17
+ value.values.all?(&:finite?)
18
+ else
19
+ value.all?(&:finite?)
20
+ end
21
+ end
22
+
23
+ def self.validate(value, dimensions:, type:, adapter:)
24
+ if (message = validate_dimensions(value, type, dimensions, adapter))
25
+ raise Error, message
26
+ end
27
+
28
+ if !validate_finite(value, type)
29
+ raise Error, "Values must be finite"
30
+ end
31
+ end
32
+
33
+ def self.normalize(value, column_info:)
34
+ return nil if value.nil?
35
+
36
+ raise Error, "Normalize not supported for type" unless [:cube, :vector, :halfvec, :binary].include?(column_info&.type)
37
+
38
+ norm = Math.sqrt(value.sum { |v| v * v })
39
+
40
+ # store zero vector as all zeros
41
+ # since NaN makes the distance always 0
42
+ # could also throw error
43
+ norm > 0 ? value.map { |v| v / norm } : value
44
+ end
45
+
46
+ def self.array?(value)
47
+ !value.nil? && value.respond_to?(:to_a)
48
+ end
49
+
50
+ def self.adapter(model)
51
+ case model.connection_db_config.adapter
52
+ when /sqlite/i
53
+ :sqlite
54
+ when /mysql|trilogy/i
55
+ model.connection_pool.with_connection { |c| c.try(:mariadb?) } ? :mariadb : :mysql
56
+ else
57
+ :postgresql
58
+ end
59
+ end
60
+
61
+ def self.type(adapter, column_type)
62
+ case adapter
63
+ when :mysql
64
+ if column_type == :binary
65
+ :bit
66
+ else
67
+ column_type
68
+ end
69
+ else
70
+ column_type
71
+ end
72
+ end
73
+
74
+ def self.operator(adapter, column_type, distance)
75
+ case adapter
76
+ when :sqlite
77
+ case distance
78
+ when "euclidean"
79
+ "vec_distance_L2"
80
+ when "cosine"
81
+ "vec_distance_cosine"
82
+ when "taxicab"
83
+ "vec_distance_L1"
84
+ when "hamming"
85
+ "vec_distance_hamming"
86
+ end
87
+ when :mariadb
88
+ case column_type
89
+ when :binary
90
+ case distance
91
+ when "euclidean", "cosine"
92
+ "VEC_DISTANCE"
93
+ end
94
+ when :integer
95
+ case distance
96
+ when "hamming"
97
+ "BIT_COUNT"
98
+ end
99
+ else
100
+ raise ArgumentError, "Unsupported type: #{column_type}"
101
+ end
102
+ when :mysql
103
+ case column_type
104
+ when :vector
105
+ case distance
106
+ when "cosine"
107
+ "COSINE"
108
+ when "euclidean"
109
+ "EUCLIDEAN"
110
+ end
111
+ when :binary
112
+ case distance
113
+ when "hamming"
114
+ "BIT_COUNT"
115
+ end
116
+ else
117
+ raise ArgumentError, "Unsupported type: #{column_type}"
118
+ end
119
+ else
120
+ case column_type
121
+ when :bit
122
+ case distance
123
+ when "hamming"
124
+ "<~>"
125
+ when "jaccard"
126
+ "<%>"
127
+ when "hamming2"
128
+ "#"
129
+ end
130
+ when :vector, :halfvec, :sparsevec
131
+ case distance
132
+ when "inner_product"
133
+ "<#>"
134
+ when "cosine"
135
+ "<=>"
136
+ when "euclidean"
137
+ "<->"
138
+ when "taxicab"
139
+ "<+>"
140
+ end
141
+ when :cube
142
+ case distance
143
+ when "taxicab"
144
+ "<#>"
145
+ when "chebyshev"
146
+ "<=>"
147
+ when "euclidean", "cosine"
148
+ "<->"
149
+ end
150
+ else
151
+ raise ArgumentError, "Unsupported type: #{column_type}"
152
+ end
153
+ end
154
+ end
155
+
156
+ def self.order(adapter, type, operator, quoted_attribute, query)
157
+ case adapter
158
+ when :sqlite
159
+ case type
160
+ when :int8
161
+ "#{operator}(vec_int8(#{quoted_attribute}), vec_int8(#{query}))"
162
+ when :bit
163
+ "#{operator}(vec_bit(#{quoted_attribute}), vec_bit(#{query}))"
164
+ else
165
+ "#{operator}(#{quoted_attribute}, #{query})"
166
+ end
167
+ when :mariadb
168
+ if operator == "BIT_COUNT"
169
+ "BIT_COUNT(#{quoted_attribute} ^ #{query})"
170
+ else
171
+ "VEC_DISTANCE(#{quoted_attribute}, #{query})"
172
+ end
173
+ when :mysql
174
+ if operator == "BIT_COUNT"
175
+ "BIT_COUNT(#{quoted_attribute} ^ #{query})"
176
+ elsif operator == "COSINE"
177
+ "DISTANCE(#{quoted_attribute}, #{query}, 'COSINE')"
178
+ else
179
+ "DISTANCE(#{quoted_attribute}, #{query}, 'EUCLIDEAN')"
180
+ end
181
+ else
182
+ if operator == "#"
183
+ "bit_count(#{quoted_attribute} # #{query})"
184
+ else
185
+ "#{quoted_attribute} #{operator} #{query}"
186
+ end
187
+ end
188
+ end
189
+
190
+ def self.normalize_required?(adapter, column_type)
191
+ case adapter
192
+ when :postgresql
193
+ column_type == :cube
194
+ when :mariadb
195
+ true
196
+ else
197
+ false
198
+ end
199
+ end
200
+ end
201
+ end
@@ -1,3 +1,3 @@
1
1
  module Neighbor
2
- VERSION = "0.3.2"
2
+ VERSION = "0.5.0"
3
3
  end
data/lib/neighbor.rb CHANGED
@@ -1,47 +1,35 @@
1
1
  # dependencies
2
2
  require "active_support"
3
3
 
4
+ # adapter hooks
5
+ require_relative "neighbor/mysql"
6
+ require_relative "neighbor/postgresql"
7
+ require_relative "neighbor/sqlite"
8
+
4
9
  # modules
10
+ require_relative "neighbor/reranking"
11
+ require_relative "neighbor/sparse_vector"
12
+ require_relative "neighbor/utils"
5
13
  require_relative "neighbor/version"
6
14
 
7
15
  module Neighbor
8
16
  class Error < StandardError; end
9
-
10
- module RegisterTypes
11
- def initialize_type_map(m = type_map)
12
- super
13
- m.register_type "cube", Type::Cube.new
14
- m.register_type "vector" do |_, _, sql_type|
15
- limit = extract_limit(sql_type)
16
- Type::Vector.new(limit: limit)
17
- end
18
- end
19
- end
20
17
  end
21
18
 
22
19
  ActiveSupport.on_load(:active_record) do
20
+ require_relative "neighbor/attribute"
23
21
  require_relative "neighbor/model"
24
- require_relative "neighbor/vector"
25
- require_relative "neighbor/type/cube"
26
- require_relative "neighbor/type/vector"
22
+ require_relative "neighbor/normalized_attribute"
27
23
 
28
24
  extend Neighbor::Model
29
25
 
30
- require "active_record/connection_adapters/postgresql_adapter"
31
-
32
- # ensure schema can be dumped
33
- ActiveRecord::ConnectionAdapters::PostgreSQLAdapter::NATIVE_DATABASE_TYPES[:cube] = {name: "cube"}
34
- ActiveRecord::ConnectionAdapters::PostgreSQLAdapter::NATIVE_DATABASE_TYPES[:vector] = {name: "vector"}
35
-
36
- # ensure schema can be loaded
37
- ActiveRecord::ConnectionAdapters::TableDefinition.send(:define_column_methods, :cube, :vector)
38
-
39
- # prevent unknown OID warning
40
- if ActiveRecord::VERSION::MAJOR >= 7
41
- ActiveRecord::ConnectionAdapters::PostgreSQLAdapter.singleton_class.prepend(Neighbor::RegisterTypes)
42
- else
43
- ActiveRecord::ConnectionAdapters::PostgreSQLAdapter.prepend(Neighbor::RegisterTypes)
26
+ begin
27
+ Neighbor::PostgreSQL.initialize!
28
+ rescue Gem::LoadError
29
+ # tries to load pg gem, which may not be available
44
30
  end
31
+
32
+ Neighbor::MySQL.initialize!
45
33
  end
46
34
 
47
35
  require_relative "neighbor/railtie" if defined?(Rails::Railtie)
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: neighbor
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.3.2
4
+ version: 0.5.0
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2023-12-12 00:00:00.000000000 Z
11
+ date: 2024-10-08 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: activerecord
@@ -16,14 +16,14 @@ dependencies:
16
16
  requirements:
17
17
  - - ">="
18
18
  - !ruby/object:Gem::Version
19
- version: '6.1'
19
+ version: '7'
20
20
  type: :runtime
21
21
  prerelease: false
22
22
  version_requirements: !ruby/object:Gem::Requirement
23
23
  requirements:
24
24
  - - ">="
25
25
  - !ruby/object:Gem::Version
26
- version: '6.1'
26
+ version: '7'
27
27
  description:
28
28
  email: andrew@ankane.org
29
29
  executables: []
@@ -34,15 +34,29 @@ files:
34
34
  - LICENSE.txt
35
35
  - README.md
36
36
  - lib/generators/neighbor/cube_generator.rb
37
+ - lib/generators/neighbor/sqlite_generator.rb
37
38
  - lib/generators/neighbor/templates/cube.rb.tt
39
+ - lib/generators/neighbor/templates/sqlite.rb.tt
38
40
  - lib/generators/neighbor/templates/vector.rb.tt
39
41
  - lib/generators/neighbor/vector_generator.rb
40
42
  - lib/neighbor.rb
43
+ - lib/neighbor/attribute.rb
41
44
  - lib/neighbor/model.rb
45
+ - lib/neighbor/mysql.rb
46
+ - lib/neighbor/normalized_attribute.rb
47
+ - lib/neighbor/postgresql.rb
42
48
  - lib/neighbor/railtie.rb
49
+ - lib/neighbor/reranking.rb
50
+ - lib/neighbor/sparse_vector.rb
51
+ - lib/neighbor/sqlite.rb
43
52
  - lib/neighbor/type/cube.rb
53
+ - lib/neighbor/type/halfvec.rb
54
+ - lib/neighbor/type/mysql_vector.rb
55
+ - lib/neighbor/type/sparsevec.rb
56
+ - lib/neighbor/type/sqlite_int8_vector.rb
57
+ - lib/neighbor/type/sqlite_vector.rb
44
58
  - lib/neighbor/type/vector.rb
45
- - lib/neighbor/vector.rb
59
+ - lib/neighbor/utils.rb
46
60
  - lib/neighbor/version.rb
47
61
  homepage: https://github.com/ankane/neighbor
48
62
  licenses:
@@ -56,15 +70,15 @@ required_ruby_version: !ruby/object:Gem::Requirement
56
70
  requirements:
57
71
  - - ">="
58
72
  - !ruby/object:Gem::Version
59
- version: '3'
73
+ version: '3.1'
60
74
  required_rubygems_version: !ruby/object:Gem::Requirement
61
75
  requirements:
62
76
  - - ">="
63
77
  - !ruby/object:Gem::Version
64
78
  version: '0'
65
79
  requirements: []
66
- rubygems_version: 3.4.10
80
+ rubygems_version: 3.5.16
67
81
  signing_key:
68
82
  specification_version: 4
69
- summary: Nearest neighbor search for Rails and Postgres
83
+ summary: Nearest neighbor search for Rails
70
84
  test_files: []
@@ -1,65 +0,0 @@
1
- module Neighbor
2
- class Vector < ActiveRecord::Type::Value
3
- def initialize(dimensions:, normalize:, model:, attribute_name:)
4
- super()
5
- @dimensions = dimensions
6
- @normalize = normalize
7
- @model = model
8
- @attribute_name = attribute_name
9
- end
10
-
11
- def self.cast(value, dimensions:, normalize:, column_info:)
12
- value = value.to_a.map(&:to_f)
13
-
14
- dimensions ||= column_info[:dimensions]
15
- raise Error, "Expected #{dimensions} dimensions, not #{value.size}" if dimensions && value.size != dimensions
16
-
17
- raise Error, "Values must be finite" unless value.all?(&:finite?)
18
-
19
- if normalize
20
- norm = Math.sqrt(value.sum { |v| v * v })
21
-
22
- # store zero vector as all zeros
23
- # since NaN makes the distance always 0
24
- # could also throw error
25
-
26
- # safe to update in-place since earlier map dups
27
- value.map! { |v| v / norm } if norm > 0
28
- end
29
-
30
- value
31
- end
32
-
33
- def self.column_info(model, attribute_name)
34
- attribute_name = attribute_name.to_s
35
- column = model.columns.detect { |c| c.name == attribute_name }
36
- {
37
- type: column.try(:type),
38
- dimensions: column.try(:limit)
39
- }
40
- end
41
-
42
- # need to be careful to avoid loading column info before needed
43
- def column_info
44
- @column_info ||= self.class.column_info(@model, @attribute_name)
45
- end
46
-
47
- def cast(value)
48
- self.class.cast(value, dimensions: @dimensions, normalize: @normalize, column_info: column_info) unless value.nil?
49
- end
50
-
51
- def serialize(value)
52
- unless value.nil?
53
- if column_info[:type] == :vector
54
- "[#{cast(value).join(", ")}]"
55
- else
56
- "(#{cast(value).join(", ")})"
57
- end
58
- end
59
- end
60
-
61
- def deserialize(value)
62
- value[1..-1].split(",").map(&:to_f) unless value.nil?
63
- end
64
- end
65
- end