neighbor 0.1.1 → 0.1.2

Sign up to get free protection for your applications and to get access to all the features.
checksums.yaml CHANGED
@@ -1,7 +1,7 @@
1
1
  ---
2
2
  SHA256:
3
- metadata.gz: 01df817b036ee9b4d0c54ddadfb3be7fb31b36bb81504321b8dfd3d7a1ba83a4
4
- data.tar.gz: f7a1757914f2e1226bc3b95fe5e0df9837a530315d6e627062df565e331b1708
3
+ metadata.gz: e64a3292445759187b7c08fd9cc54fe0da340ea3e3ab5de909620de6395b4984
4
+ data.tar.gz: 274d73ed464a0503973c5549022dca8820257fe470c9746b5599da9a153d46e3
5
5
  SHA512:
6
- metadata.gz: 8381834f92092cb13d2c8898b588f0758e99e48f84e3b506aab15acf54fb0b3f72449734ca68db10de466eecde4a7b277cceca72192240ce21567125ad0af52f
7
- data.tar.gz: 79ddd953bf0c134d73c92c61a8051524c024e2e1c95710e346a191a8a7c098f55562f81263752ea2e2ccaef9ced8e376dbf6431a4e745931d256ee2ca854d309
6
+ metadata.gz: 17df9f8a5848337fc570c6fa685f54f4aa9d49bea8e52e74fbec098e0bc9e450b655c7d96cb66c6e8b6aaae4824f2c831ea26a14f7f46b50ef64c955cfb9839b
7
+ data.tar.gz: a47165d174ec86ebfdcd128e62c50b85d097aac535b6b6d424ea31988428a57237cdade2829d8e22a15e566ac7597371c95fabafad42cdbc5d56f063abab55e4
data/CHANGELOG.md CHANGED
@@ -1,3 +1,7 @@
1
+ ## 0.1.2 (2021-02-21)
2
+
3
+ - Added `nearest_neighbors` scope
4
+
1
5
  ## 0.1.1 (2021-02-16)
2
6
 
3
7
  - Fixed `Could not dump table` error
data/README.md CHANGED
@@ -50,12 +50,18 @@ item.update(neighbor_vector: [1.0, 1.2, 0.5])
50
50
 
51
51
  > With cosine distance (the default), vectors are normalized before being stored
52
52
 
53
- Get the nearest neighbors
53
+ Get the nearest neighbors to a record
54
54
 
55
55
  ```ruby
56
56
  item.nearest_neighbors.first(5)
57
57
  ```
58
58
 
59
+ Get the nearest neighbors to a vector
60
+
61
+ ```ruby
62
+ Item.nearest_neighbors([1, 2, 3])
63
+ ```
64
+
59
65
  ## Distance
60
66
 
61
67
  Specify the distance metric
@@ -73,6 +79,8 @@ Supported values are:
73
79
  - `taxicab`
74
80
  - `chebyshev`
75
81
 
82
+ For inner product, see [this example](examples/disco_user_recs.rb)
83
+
76
84
  Records returned from `nearest_neighbors` will have a `neighbor_distance` attribute
77
85
 
78
86
  ```ruby
@@ -86,7 +94,7 @@ By default, Postgres limits the `cube` data type to 100 dimensions. See the [Pos
86
94
 
87
95
  ## Example
88
96
 
89
- You can use Neighbor for online item recommendations with [Disco](https://github.com/ankane/disco). We’ll use MovieLens data for this example.
97
+ You can use Neighbor for online item-based recommendations with [Disco](https://github.com/ankane/disco). We’ll use MovieLens data for this example.
90
98
 
91
99
  Generate a model
92
100
 
@@ -126,7 +134,7 @@ movie = Movie.find_by(name: "Star Wars (1977)")
126
134
  movie.nearest_neighbors.first(5).map(&:name)
127
135
  ```
128
136
 
129
- [Complete code](examples/disco.rb)
137
+ [Complete code](examples/disco_item_recs.rb)
130
138
 
131
139
  ## History
132
140
 
@@ -11,8 +11,10 @@ module Neighbor
11
11
  class_eval do
12
12
  attribute attribute_name, Neighbor::Vector.new(dimensions: dimensions, distance: distance)
13
13
 
14
- define_method :nearest_neighbors do
15
- return self.class.none if neighbor_vector.nil?
14
+ scope :nearest_neighbors, ->(vector) {
15
+ return none if vector.nil?
16
+
17
+ quoted_attribute = "#{connection.quote_table_name(table_name)}.#{connection.quote_column_name(attribute_name)}"
16
18
 
17
19
  operator =
18
20
  case distance
@@ -24,11 +26,10 @@ module Neighbor
24
26
  "<->"
25
27
  end
26
28
 
27
- quoted_attribute = "#{self.class.connection.quote_table_name(self.class.table_name)}.#{self.class.connection.quote_column_name(attribute_name)}"
28
-
29
29
  # important! neighbor_vector should already be typecast
30
30
  # but use to_f as extra safeguard against SQL injection
31
- order = "#{quoted_attribute} #{operator} cube(array[#{neighbor_vector.map(&:to_f).join(", ")}])"
31
+ vector = Neighbor::Vector.cast(vector, dimensions: dimensions, distance: distance)
32
+ order = "#{quoted_attribute} #{operator} cube(array[#{vector.map(&:to_f).join(", ")}])"
32
33
 
33
34
  # https://stats.stackexchange.com/questions/146221/is-cosine-similarity-identical-to-l2-normalized-euclidean-distance
34
35
  # with normalized vectors:
@@ -38,11 +39,15 @@ module Neighbor
38
39
  neighbor_distance = distance == "cosine" ? "POWER(#{order}, 2) / 2.0" : order
39
40
 
40
41
  # for select, use column_names instead of * to account for ignored columns
41
- self.class
42
- .select(*self.class.column_names, "#{neighbor_distance} AS neighbor_distance")
43
- .where.not(self.class.primary_key => send(self.class.primary_key))
42
+ select(*column_names, "#{neighbor_distance} AS neighbor_distance")
44
43
  .where.not(attribute_name => nil)
45
44
  .order(Arel.sql(order))
45
+ }
46
+
47
+ define_method :nearest_neighbors do
48
+ self.class
49
+ .where.not(self.class.primary_key => send(self.class.primary_key))
50
+ .nearest_neighbors(send(attribute_name))
46
51
  end
47
52
  end
48
53
  end
@@ -6,13 +6,11 @@ module Neighbor
6
6
  @distance = distance
7
7
  end
8
8
 
9
- def cast(value)
10
- return if value.nil?
11
-
9
+ def self.cast(value, dimensions:, distance:)
12
10
  value = value.to_a.map(&:to_f)
13
- raise Error, "Expected #{@dimensions} dimensions, not #{value.size}" unless value.size == @dimensions
11
+ raise Error, "Expected #{dimensions} dimensions, not #{value.size}" unless value.size == dimensions
14
12
 
15
- if @distance == "cosine"
13
+ if distance == "cosine"
16
14
  norm = Math.sqrt(value.sum { |v| v * v })
17
15
  value.map { |v| v / norm }
18
16
  else
@@ -20,6 +18,10 @@ module Neighbor
20
18
  end
21
19
  end
22
20
 
21
+ def cast(value)
22
+ self.class.cast(value, dimensions: @dimensions, distance: @distance) unless value.nil?
23
+ end
24
+
23
25
  def serialize(value)
24
26
  "(#{cast(value).join(", ")})" unless value.nil?
25
27
  end
@@ -1,3 +1,3 @@
1
1
  module Neighbor
2
- VERSION = "0.1.1"
2
+ VERSION = "0.1.2"
3
3
  end
metadata CHANGED
@@ -1,14 +1,14 @@
1
1
  --- !ruby/object:Gem::Specification
2
2
  name: neighbor
3
3
  version: !ruby/object:Gem::Version
4
- version: 0.1.1
4
+ version: 0.1.2
5
5
  platform: ruby
6
6
  authors:
7
7
  - Andrew Kane
8
8
  autorequire:
9
9
  bindir: bin
10
10
  cert_chain: []
11
- date: 2021-02-17 00:00:00.000000000 Z
11
+ date: 2021-02-22 00:00:00.000000000 Z
12
12
  dependencies:
13
13
  - !ruby/object:Gem::Dependency
14
14
  name: activerecord