cmfrec 0.1.3 → 0.1.4
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +4 -4
 - data/CHANGELOG.md +6 -0
 - data/LICENSE.txt +1 -1
 - data/README.md +67 -0
 - data/lib/cmfrec.rb +5 -3
 - data/lib/cmfrec/recommender.rb +299 -121
 - data/lib/cmfrec/version.rb +1 -1
 - metadata +3 -3
 
    
        checksums.yaml
    CHANGED
    
    | 
         @@ -1,7 +1,7 @@ 
     | 
|
| 
       1 
1 
     | 
    
         
             
            ---
         
     | 
| 
       2 
2 
     | 
    
         
             
            SHA256:
         
     | 
| 
       3 
     | 
    
         
            -
              metadata.gz:  
     | 
| 
       4 
     | 
    
         
            -
              data.tar.gz:  
     | 
| 
      
 3 
     | 
    
         
            +
              metadata.gz: bb7b07ae46500a545f1a130dfc5648aa3f925f9b5766a6c70a1652c7b5732182
         
     | 
| 
      
 4 
     | 
    
         
            +
              data.tar.gz: e89a6d1900cda651dc6b0aac2899050e28680cddfb6b39b6b5eacfe467b59aad
         
     | 
| 
       5 
5 
     | 
    
         
             
            SHA512:
         
     | 
| 
       6 
     | 
    
         
            -
              metadata.gz:  
     | 
| 
       7 
     | 
    
         
            -
              data.tar.gz:  
     | 
| 
      
 6 
     | 
    
         
            +
              metadata.gz: 117aa6952fe0ab8ddebfaece6655cf479a7adbab7d6f634e7d3428c72824a410812c037ae006366180a9691a6d160d8065b777a9c10a33a5ccfefedb28c99ec6
         
     | 
| 
      
 7 
     | 
    
         
            +
              data.tar.gz: 57985a055705b820226a2aa1451453383ee3509e43225f8fdb09e713c4530754b0b608f7d1b4814973b43e3d625f824f9f87939687d015b352cc8905f7b4f118
         
     | 
    
        data/CHANGELOG.md
    CHANGED
    
    
    
        data/LICENSE.txt
    CHANGED
    
    
    
        data/README.md
    CHANGED
    
    | 
         @@ -107,6 +107,26 @@ Get recommendations with only side information 
     | 
|
| 
       107 
107 
     | 
    
         
             
            recommender.new_user_recs([], user_info: {cats: 0, dogs: 2})
         
     | 
| 
       108 
108 
     | 
    
         
             
            ```
         
     | 
| 
       109 
109 
     | 
    
         | 
| 
      
 110 
     | 
    
         
            +
            ## Similarity
         
     | 
| 
      
 111 
     | 
    
         
            +
             
     | 
| 
      
 112 
     | 
    
         
            +
            Add this line to your application’s Gemfile:
         
     | 
| 
      
 113 
     | 
    
         
            +
             
     | 
| 
      
 114 
     | 
    
         
            +
            ```ruby
         
     | 
| 
      
 115 
     | 
    
         
            +
            gem 'ngt'
         
     | 
| 
      
 116 
     | 
    
         
            +
            ```
         
     | 
| 
      
 117 
     | 
    
         
            +
             
     | 
| 
      
 118 
     | 
    
         
            +
            Get similar users
         
     | 
| 
      
 119 
     | 
    
         
            +
             
     | 
| 
      
 120 
     | 
    
         
            +
            ```ruby
         
     | 
| 
      
 121 
     | 
    
         
            +
            recommender.similar_users(user_id)
         
     | 
| 
      
 122 
     | 
    
         
            +
            ```
         
     | 
| 
      
 123 
     | 
    
         
            +
             
     | 
| 
      
 124 
     | 
    
         
            +
            Get similar items - “users who liked this item also liked”
         
     | 
| 
      
 125 
     | 
    
         
            +
             
     | 
| 
      
 126 
     | 
    
         
            +
            ```ruby
         
     | 
| 
      
 127 
     | 
    
         
            +
            recommender.similar_items(item_id)
         
     | 
| 
      
 128 
     | 
    
         
            +
            ```
         
     | 
| 
      
 129 
     | 
    
         
            +
             
     | 
| 
       110 
130 
     | 
    
         
             
            ## Examples
         
     | 
| 
       111 
131 
     | 
    
         | 
| 
       112 
132 
     | 
    
         
             
            ### MovieLens
         
     | 
| 
         @@ -125,6 +145,35 @@ recommender.fit(ratings.first(80000), user_info: user_info, item_info: item_info 
     | 
|
| 
       125 
145 
     | 
    
         
             
            recommender.predict(ratings.last(20000))
         
     | 
| 
       126 
146 
     | 
    
         
             
            ```
         
     | 
| 
       127 
147 
     | 
    
         | 
| 
      
 148 
     | 
    
         
            +
            ### Ahoy
         
     | 
| 
      
 149 
     | 
    
         
            +
             
     | 
| 
      
 150 
     | 
    
         
            +
            [Ahoy](https://github.com/ankane/ahoy) is a great source for implicit feedback
         
     | 
| 
      
 151 
     | 
    
         
            +
             
     | 
| 
      
 152 
     | 
    
         
            +
            ```ruby
         
     | 
| 
      
 153 
     | 
    
         
            +
            views = Ahoy::Event.
         
     | 
| 
      
 154 
     | 
    
         
            +
              where(name: "Viewed post").
         
     | 
| 
      
 155 
     | 
    
         
            +
              group(:user_id).
         
     | 
| 
      
 156 
     | 
    
         
            +
              group("properties->>'post_id'"). # postgres syntax
         
     | 
| 
      
 157 
     | 
    
         
            +
              count
         
     | 
| 
      
 158 
     | 
    
         
            +
             
     | 
| 
      
 159 
     | 
    
         
            +
            data =
         
     | 
| 
      
 160 
     | 
    
         
            +
              views.map do |(user_id, post_id), count|
         
     | 
| 
      
 161 
     | 
    
         
            +
                {
         
     | 
| 
      
 162 
     | 
    
         
            +
                  user_id: user_id,
         
     | 
| 
      
 163 
     | 
    
         
            +
                  item_id: post_id,
         
     | 
| 
      
 164 
     | 
    
         
            +
                  value: count
         
     | 
| 
      
 165 
     | 
    
         
            +
                }
         
     | 
| 
      
 166 
     | 
    
         
            +
              end
         
     | 
| 
      
 167 
     | 
    
         
            +
            ```
         
     | 
| 
      
 168 
     | 
    
         
            +
             
     | 
| 
      
 169 
     | 
    
         
            +
            Create a recommender and get recommended posts for a user
         
     | 
| 
      
 170 
     | 
    
         
            +
             
     | 
| 
      
 171 
     | 
    
         
            +
            ```ruby
         
     | 
| 
      
 172 
     | 
    
         
            +
            recommender = Cmfrec::Recommender.new
         
     | 
| 
      
 173 
     | 
    
         
            +
            recommender.fit(data)
         
     | 
| 
      
 174 
     | 
    
         
            +
            recommender.user_recs(current_user.id)
         
     | 
| 
      
 175 
     | 
    
         
            +
            ```
         
     | 
| 
      
 176 
     | 
    
         
            +
             
     | 
| 
       128 
177 
     | 
    
         
             
            ## Options
         
     | 
| 
       129 
178 
     | 
    
         | 
| 
       130 
179 
     | 
    
         
             
            Specify the number of factors and epochs
         
     | 
| 
         @@ -163,6 +212,24 @@ Or a Rover data frame 
     | 
|
| 
       163 
212 
     | 
    
         
             
            Rover.read_csv("ratings.csv")
         
     | 
| 
       164 
213 
     | 
    
         
             
            ```
         
     | 
| 
       165 
214 
     | 
    
         | 
| 
      
 215 
     | 
    
         
            +
            ## Storing Recommenders
         
     | 
| 
      
 216 
     | 
    
         
            +
             
     | 
| 
      
 217 
     | 
    
         
            +
            Store the recommender
         
     | 
| 
      
 218 
     | 
    
         
            +
             
     | 
| 
      
 219 
     | 
    
         
            +
            ```ruby
         
     | 
| 
      
 220 
     | 
    
         
            +
            bin = Marshal.dump(recommender)
         
     | 
| 
      
 221 
     | 
    
         
            +
            File.binwrite("recommender.bin", bin)
         
     | 
| 
      
 222 
     | 
    
         
            +
            ```
         
     | 
| 
      
 223 
     | 
    
         
            +
             
     | 
| 
      
 224 
     | 
    
         
            +
            > You can save it to a file, database, or any other storage system
         
     | 
| 
      
 225 
     | 
    
         
            +
             
     | 
| 
      
 226 
     | 
    
         
            +
            Load a recommender
         
     | 
| 
      
 227 
     | 
    
         
            +
             
     | 
| 
      
 228 
     | 
    
         
            +
            ```ruby
         
     | 
| 
      
 229 
     | 
    
         
            +
            bin = File.binread("recommender.bin")
         
     | 
| 
      
 230 
     | 
    
         
            +
            recommender = Marshal.load(bin)
         
     | 
| 
      
 231 
     | 
    
         
            +
            ```
         
     | 
| 
      
 232 
     | 
    
         
            +
             
     | 
| 
       166 
233 
     | 
    
         
             
            ## Reference
         
     | 
| 
       167 
234 
     | 
    
         | 
| 
       168 
235 
     | 
    
         
             
            Get the global mean
         
     | 
    
        data/lib/cmfrec.rb
    CHANGED
    
    | 
         @@ -18,10 +18,12 @@ module Cmfrec 
     | 
|
| 
       18 
18 
     | 
    
         
             
              lib_name =
         
     | 
| 
       19 
19 
     | 
    
         
             
                if Gem.win_platform?
         
     | 
| 
       20 
20 
     | 
    
         
             
                  "cmfrec.dll"
         
     | 
| 
       21 
     | 
    
         
            -
                elsif RbConfig::CONFIG["arch"] =~ /arm64-darwin/i
         
     | 
| 
       22 
     | 
    
         
            -
                  "libcmfrec.arm64.dylib"
         
     | 
| 
       23 
21 
     | 
    
         
             
                elsif RbConfig::CONFIG["host_os"] =~ /darwin/i
         
     | 
| 
       24 
     | 
    
         
            -
                  "libcmfrec.dylib"
     | 
| 
      
 22 
     | 
    
         
            +
                  if RbConfig::CONFIG["host_cpu"] =~ /arm/i
         
     | 
| 
      
 23 
     | 
    
         
            +
                    "libcmfrec.arm64.dylib"
         
     | 
| 
      
 24 
     | 
    
         
            +
                  else
         
     | 
| 
      
 25 
     | 
    
         
            +
                    "libcmfrec.dylib"
         
     | 
| 
      
 26 
     | 
    
         
            +
                  end
         
     | 
| 
       25 
27 
     | 
    
         
             
                else
         
     | 
| 
       26 
28 
     | 
    
         
             
                  "libcmfrec.so"
         
     | 
| 
       27 
29 
     | 
    
         
             
                end
         
     | 
    
        data/lib/cmfrec/recommender.rb
    CHANGED
    
    | 
         @@ -11,19 +11,181 @@ module Cmfrec 
     | 
|
| 
       11 
11 
     | 
    
         
             
                    item_bias: item_bias,
         
     | 
| 
       12 
12 
     | 
    
         
             
                    add_implicit_features: add_implicit_features
         
     | 
| 
       13 
13 
     | 
    
         
             
                  )
         
     | 
| 
      
 14 
     | 
    
         
            +
             
     | 
| 
      
 15 
     | 
    
         
            +
                  @fit = false
         
     | 
| 
      
 16 
     | 
    
         
            +
                  @user_map = {}
         
     | 
| 
      
 17 
     | 
    
         
            +
                  @item_map = {}
         
     | 
| 
      
 18 
     | 
    
         
            +
                  @user_info_map = {}
         
     | 
| 
      
 19 
     | 
    
         
            +
                  @item_info_map = {}
         
     | 
| 
       14 
20 
     | 
    
         
             
                end
         
     | 
| 
       15 
21 
     | 
    
         | 
| 
       16 
22 
     | 
    
         
             
                def fit(train_set, user_info: nil, item_info: nil)
         
     | 
| 
      
 23 
     | 
    
         
            +
                  reset
         
     | 
| 
      
 24 
     | 
    
         
            +
                  partial_fit(train_set, user_info: user_info, item_info: item_info)
         
     | 
| 
      
 25 
     | 
    
         
            +
                end
         
     | 
| 
      
 26 
     | 
    
         
            +
             
     | 
| 
      
 27 
     | 
    
         
            +
                def predict(data)
         
     | 
| 
      
 28 
     | 
    
         
            +
                  check_fit
         
     | 
| 
      
 29 
     | 
    
         
            +
             
     | 
| 
      
 30 
     | 
    
         
            +
                  data = to_dataset(data)
         
     | 
| 
      
 31 
     | 
    
         
            +
             
     | 
| 
      
 32 
     | 
    
         
            +
                  u = data.map { |v| @user_map[v[:user_id]] || @user_map.size }
         
     | 
| 
      
 33 
     | 
    
         
            +
                  i = data.map { |v| @item_map[v[:item_id]] || @item_map.size }
         
     | 
| 
      
 34 
     | 
    
         
            +
             
     | 
| 
      
 35 
     | 
    
         
            +
                  row = int_ptr(u)
         
     | 
| 
      
 36 
     | 
    
         
            +
                  col = int_ptr(i)
         
     | 
| 
      
 37 
     | 
    
         
            +
                  n_predict = data.size
         
     | 
| 
      
 38 
     | 
    
         
            +
                  predicted = Fiddle::Pointer.malloc(n_predict * Fiddle::SIZEOF_DOUBLE)
         
     | 
| 
      
 39 
     | 
    
         
            +
             
     | 
| 
      
 40 
     | 
    
         
            +
                  if @implicit
         
     | 
| 
      
 41 
     | 
    
         
            +
                    check_status FFI.predict_X_old_collective_implicit(
         
     | 
| 
      
 42 
     | 
    
         
            +
                      row, col, predicted, n_predict,
         
     | 
| 
      
 43 
     | 
    
         
            +
                      @a, @b,
         
     | 
| 
      
 44 
     | 
    
         
            +
                      @k, @k_user, @k_item, @k_main,
         
     | 
| 
      
 45 
     | 
    
         
            +
                      @m, @n,
         
     | 
| 
      
 46 
     | 
    
         
            +
                      @nthreads
         
     | 
| 
      
 47 
     | 
    
         
            +
                    )
         
     | 
| 
      
 48 
     | 
    
         
            +
                  else
         
     | 
| 
      
 49 
     | 
    
         
            +
                    check_status FFI.predict_X_old_collective_explicit(
         
     | 
| 
      
 50 
     | 
    
         
            +
                      row, col, predicted, n_predict,
         
     | 
| 
      
 51 
     | 
    
         
            +
                      @a, @bias_a,
         
     | 
| 
      
 52 
     | 
    
         
            +
                      @b, @bias_b,
         
     | 
| 
      
 53 
     | 
    
         
            +
                      @global_mean,
         
     | 
| 
      
 54 
     | 
    
         
            +
                      @k, @k_user, @k_item, @k_main,
         
     | 
| 
      
 55 
     | 
    
         
            +
                      @m, @n,
         
     | 
| 
      
 56 
     | 
    
         
            +
                      @nthreads
         
     | 
| 
      
 57 
     | 
    
         
            +
                    )
         
     | 
| 
      
 58 
     | 
    
         
            +
                  end
         
     | 
| 
      
 59 
     | 
    
         
            +
             
     | 
| 
      
 60 
     | 
    
         
            +
                  predictions = real_array(predicted)
         
     | 
| 
      
 61 
     | 
    
         
            +
                  predictions.map! { |v| v.nan? ? @global_mean : v } if @implicit
         
     | 
| 
      
 62 
     | 
    
         
            +
                  predictions
         
     | 
| 
      
 63 
     | 
    
         
            +
                end
         
     | 
| 
      
 64 
     | 
    
         
            +
             
     | 
| 
      
 65 
     | 
    
         
            +
                def user_recs(user_id, count: 5, item_ids: nil)
         
     | 
| 
      
 66 
     | 
    
         
            +
                  check_fit
         
     | 
| 
      
 67 
     | 
    
         
            +
                  user = @user_map[user_id]
         
     | 
| 
      
 68 
     | 
    
         
            +
             
     | 
| 
      
 69 
     | 
    
         
            +
                  if user
         
     | 
| 
      
 70 
     | 
    
         
            +
                    if item_ids
         
     | 
| 
      
 71 
     | 
    
         
            +
                      # remove missing ids
         
     | 
| 
      
 72 
     | 
    
         
            +
                      item_ids = item_ids.select { |v| @item_map[v] }
         
     | 
| 
      
 73 
     | 
    
         
            +
             
     | 
| 
      
 74 
     | 
    
         
            +
                      data = item_ids.map { |v| {user_id: user_id, item_id: v} }
         
     | 
| 
      
 75 
     | 
    
         
            +
                      scores = predict(data)
         
     | 
| 
      
 76 
     | 
    
         
            +
             
     | 
| 
      
 77 
     | 
    
         
            +
                      item_ids.zip(scores).map do |item_id, score|
         
     | 
| 
      
 78 
     | 
    
         
            +
                        {item_id: item_id, score: score}
         
     | 
| 
      
 79 
     | 
    
         
            +
                      end
         
     | 
| 
      
 80 
     | 
    
         
            +
                    else
         
     | 
| 
      
 81 
     | 
    
         
            +
                      a_vec = @a[user * @k * Fiddle::SIZEOF_DOUBLE, @k * Fiddle::SIZEOF_DOUBLE]
         
     | 
| 
      
 82 
     | 
    
         
            +
                      a_bias = @bias_a ? @bias_a[user * Fiddle::SIZEOF_DOUBLE, Fiddle::SIZEOF_DOUBLE].unpack1("d") : 0
         
     | 
| 
      
 83 
     | 
    
         
            +
                      top_n(a_vec: a_vec, a_bias: a_bias, count: count)
         
     | 
| 
      
 84 
     | 
    
         
            +
                    end
         
     | 
| 
      
 85 
     | 
    
         
            +
                  else
         
     | 
| 
      
 86 
     | 
    
         
            +
                    # no items if user is unknown
         
     | 
| 
      
 87 
     | 
    
         
            +
                    # TODO maybe most popular items
         
     | 
| 
      
 88 
     | 
    
         
            +
                    []
         
     | 
| 
      
 89 
     | 
    
         
            +
                  end
         
     | 
| 
      
 90 
     | 
    
         
            +
                end
         
     | 
| 
      
 91 
     | 
    
         
            +
             
     | 
| 
      
 92 
     | 
    
         
            +
                # TODO add item_ids
         
     | 
| 
      
 93 
     | 
    
         
            +
                def new_user_recs(data, count: 5, user_info: nil)
         
     | 
| 
      
 94 
     | 
    
         
            +
                  check_fit
         
     | 
| 
      
 95 
     | 
    
         
            +
             
     | 
| 
      
 96 
     | 
    
         
            +
                  a_vec, a_bias = factors_warm(data, user_info: user_info)
         
     | 
| 
      
 97 
     | 
    
         
            +
                  top_n(a_vec: a_vec, a_bias: a_bias, count: count)
         
     | 
| 
      
 98 
     | 
    
         
            +
                end
         
     | 
| 
      
 99 
     | 
    
         
            +
             
     | 
| 
      
 100 
     | 
    
         
            +
                def user_factors
         
     | 
| 
      
 101 
     | 
    
         
            +
                  read_factors(@a, [@m, @m_u].max, @k_user + @k + @k_main)
         
     | 
| 
      
 102 
     | 
    
         
            +
                end
         
     | 
| 
      
 103 
     | 
    
         
            +
             
     | 
| 
      
 104 
     | 
    
         
            +
                def item_factors
         
     | 
| 
      
 105 
     | 
    
         
            +
                  read_factors(@b, [@n, @n_i].max, @k_item + @k + @k_main)
         
     | 
| 
      
 106 
     | 
    
         
            +
                end
         
     | 
| 
      
 107 
     | 
    
         
            +
             
     | 
| 
      
 108 
     | 
    
         
            +
                def user_bias
         
     | 
| 
      
 109 
     | 
    
         
            +
                  read_bias(@bias_a) if @bias_a
         
     | 
| 
      
 110 
     | 
    
         
            +
                end
         
     | 
| 
      
 111 
     | 
    
         
            +
             
     | 
| 
      
 112 
     | 
    
         
            +
                def item_bias
         
     | 
| 
      
 113 
     | 
    
         
            +
                  read_bias(@bias_b) if @bias_b
         
     | 
| 
      
 114 
     | 
    
         
            +
                end
         
     | 
| 
      
 115 
     | 
    
         
            +
             
     | 
| 
      
 116 
     | 
    
         
            +
                def similar_items(item_id, count: 5)
         
     | 
| 
      
 117 
     | 
    
         
            +
                  check_fit
         
     | 
| 
      
 118 
     | 
    
         
            +
                  similar(item_id, @item_map, item_factors, count, item_index)
         
     | 
| 
      
 119 
     | 
    
         
            +
                end
         
     | 
| 
      
 120 
     | 
    
         
            +
                alias_method :item_recs, :similar_items
         
     | 
| 
      
 121 
     | 
    
         
            +
             
     | 
| 
      
 122 
     | 
    
         
            +
                def similar_users(user_id, count: 5)
         
     | 
| 
      
 123 
     | 
    
         
            +
                  check_fit
         
     | 
| 
      
 124 
     | 
    
         
            +
                  similar(user_id, @user_map, user_factors, count, user_index)
         
     | 
| 
      
 125 
     | 
    
         
            +
                end
         
     | 
| 
      
 126 
     | 
    
         
            +
             
     | 
| 
      
 127 
     | 
    
         
            +
                private
         
     | 
| 
      
 128 
     | 
    
         
            +
             
     | 
| 
      
 129 
     | 
    
         
            +
                def user_index
         
     | 
| 
      
 130 
     | 
    
         
            +
                  @user_index ||= create_index(user_factors)
         
     | 
| 
      
 131 
     | 
    
         
            +
                end
         
     | 
| 
      
 132 
     | 
    
         
            +
             
     | 
| 
      
 133 
     | 
    
         
            +
                def item_index
         
     | 
| 
      
 134 
     | 
    
         
            +
                  @item_index ||= create_index(item_factors)
         
     | 
| 
      
 135 
     | 
    
         
            +
                end
         
     | 
| 
      
 136 
     | 
    
         
            +
             
     | 
| 
      
 137 
     | 
    
         
            +
                def create_index(factors)
         
     | 
| 
      
 138 
     | 
    
         
            +
                  require "ngt"
         
     | 
| 
      
 139 
     | 
    
         
            +
             
     | 
| 
      
 140 
     | 
    
         
            +
                  index = Ngt::Index.new(@k, distance_type: "Cosine")
         
     | 
| 
      
 141 
     | 
    
         
            +
                  index.batch_insert(factors)
         
     | 
| 
      
 142 
     | 
    
         
            +
                  index
         
     | 
| 
      
 143 
     | 
    
         
            +
                end
         
     | 
| 
      
 144 
     | 
    
         
            +
             
     | 
| 
      
 145 
     | 
    
         
            +
                # TODO include bias
         
     | 
| 
      
 146 
     | 
    
         
            +
                def similar(id, map, factors, count, index)
         
     | 
| 
      
 147 
     | 
    
         
            +
                  i = map[id]
         
     | 
| 
      
 148 
     | 
    
         
            +
                  if i
         
     | 
| 
      
 149 
     | 
    
         
            +
                    keys = map.keys
         
     | 
| 
      
 150 
     | 
    
         
            +
                    result = index.search(factors[i], size: count + 1)[1..-1]
         
     | 
| 
      
 151 
     | 
    
         
            +
                    result.map do |v|
         
     | 
| 
      
 152 
     | 
    
         
            +
                      {
         
     | 
| 
      
 153 
     | 
    
         
            +
                        # ids from batch_insert start at 1 instead of 0
         
     | 
| 
      
 154 
     | 
    
         
            +
                        item_id: keys[v[:id] - 1],
         
     | 
| 
      
 155 
     | 
    
         
            +
                        # convert cosine distance to cosine similarity
         
     | 
| 
      
 156 
     | 
    
         
            +
                        score: 1 - v[:distance]
         
     | 
| 
      
 157 
     | 
    
         
            +
                      }
         
     | 
| 
      
 158 
     | 
    
         
            +
                    end
         
     | 
| 
      
 159 
     | 
    
         
            +
                  else
         
     | 
| 
      
 160 
     | 
    
         
            +
                    []
         
     | 
| 
      
 161 
     | 
    
         
            +
                  end
         
     | 
| 
      
 162 
     | 
    
         
            +
                end
         
     | 
| 
      
 163 
     | 
    
         
            +
             
     | 
| 
      
 164 
     | 
    
         
            +
                def reset
         
     | 
| 
      
 165 
     | 
    
         
            +
                  @fit = false
         
     | 
| 
      
 166 
     | 
    
         
            +
                  @user_map.clear
         
     | 
| 
      
 167 
     | 
    
         
            +
                  @item_map.clear
         
     | 
| 
      
 168 
     | 
    
         
            +
                  @user_info_map.clear
         
     | 
| 
      
 169 
     | 
    
         
            +
                  @item_info_map.clear
         
     | 
| 
      
 170 
     | 
    
         
            +
                  @user_index = nil
         
     | 
| 
      
 171 
     | 
    
         
            +
                  @item_index = nil
         
     | 
| 
      
 172 
     | 
    
         
            +
                end
         
     | 
| 
      
 173 
     | 
    
         
            +
             
     | 
| 
      
 174 
     | 
    
         
            +
                # TODO resize pointers as needed and reset values for new memory
         
     | 
| 
      
 175 
     | 
    
         
            +
                def partial_fit(train_set, user_info: nil, item_info: nil)
         
     | 
| 
       17 
176 
     | 
    
         
             
                  train_set = to_dataset(train_set)
         
     | 
| 
       18 
177 
     | 
    
         | 
| 
       19 
     | 
    
         
            -
                  @implicit = !train_set.any? { |v| v[:rating] }
     | 
| 
      
 178 
     | 
    
         
            +
                  unless @fit
         
     | 
| 
      
 179 
     | 
    
         
            +
                    @implicit = !train_set.any? { |v| v[:rating] }
         
     | 
| 
      
 180 
     | 
    
         
            +
                  end
         
     | 
| 
      
 181 
     | 
    
         
            +
             
     | 
| 
       20 
182 
     | 
    
         
             
                  unless @implicit
         
     | 
| 
       21 
183 
     | 
    
         
             
                    ratings = train_set.map { |o| o[:rating] }
         
     | 
| 
       22 
184 
     | 
    
         
             
                    check_ratings(ratings)
         
     | 
| 
       23 
185 
     | 
    
         
             
                  end
         
     | 
| 
       24 
186 
     | 
    
         | 
| 
       25 
187 
     | 
    
         
             
                  check_training_set(train_set)
         
     | 
| 
       26 
     | 
    
         
            -
                   
     | 
| 
      
 188 
     | 
    
         
            +
                  update_maps(train_set)
         
     | 
| 
       27 
189 
     | 
    
         | 
| 
       28 
190 
     | 
    
         
             
                  x_row = []
         
     | 
| 
       29 
191 
     | 
    
         
             
                  x_col = []
         
     | 
| 
         @@ -52,16 +214,14 @@ module Cmfrec 
     | 
|
| 
       52 
214 
     | 
    
         
             
                  uu = nil
         
     | 
| 
       53 
215 
     | 
    
         
             
                  ii = nil
         
     | 
| 
       54 
216 
     | 
    
         | 
| 
       55 
     | 
    
         
            -
                   
     | 
| 
      
 217 
     | 
    
         
            +
                  # side info
         
     | 
| 
       56 
218 
     | 
    
         
             
                  u_row, u_col, u_sp, nnz_u, @m_u, p_ = process_info(user_info, @user_map, @user_info_map, :user_id)
         
     | 
| 
       57 
     | 
    
         
            -
             
     | 
| 
       58 
     | 
    
         
            -
                  @item_info_map = {}
         
     | 
| 
       59 
219 
     | 
    
         
             
                  i_row, i_col, i_sp, nnz_i, @n_i, q = process_info(item_info, @item_map, @item_info_map, :item_id)
         
     | 
| 
       60 
220 
     | 
    
         | 
| 
       61 
221 
     | 
    
         
             
                  @precompute_for_predictions = false
         
     | 
| 
       62 
222 
     | 
    
         | 
| 
       63 
223 
     | 
    
         
             
                  # initialize w/ normal distribution
         
     | 
| 
       64 
     | 
    
         
            -
                  reset_values =  
     | 
| 
      
 224 
     | 
    
         
            +
                  reset_values = !@fit
         
     | 
| 
       65 
225 
     | 
    
         | 
| 
       66 
226 
     | 
    
         
             
                  @a = Fiddle::Pointer.malloc([@m, @m_u].max * (@k_user + @k + @k_main) * Fiddle::SIZEOF_DOUBLE)
         
     | 
| 
       67 
227 
     | 
    
         
             
                  @b = Fiddle::Pointer.malloc([@n, @n_i].max * (@k_item + @k + @k_main) * Fiddle::SIZEOF_DOUBLE)
         
     | 
| 
         @@ -75,16 +235,7 @@ module Cmfrec 
     | 
|
| 
       75 
235 
     | 
    
         
             
                  i_colmeans = Fiddle::Pointer.malloc(q * Fiddle::SIZEOF_DOUBLE)
         
     | 
| 
       76 
236 
     | 
    
         | 
| 
       77 
237 
     | 
    
         
             
                  if @implicit
         
     | 
| 
       78 
     | 
    
         
            -
                     
     | 
| 
       79 
     | 
    
         
            -
                    @alpha = 1.0
         
     | 
| 
       80 
     | 
    
         
            -
                    @adjust_weight = false # downweight?
         
     | 
| 
       81 
     | 
    
         
            -
                    @apply_log_transf = false
         
     | 
| 
       82 
     | 
    
         
            -
             
     | 
| 
       83 
     | 
    
         
            -
                    # different defaults
         
     | 
| 
       84 
     | 
    
         
            -
                    @lambda_ = 1e0
         
     | 
| 
       85 
     | 
    
         
            -
                    @w_user = 10
         
     | 
| 
       86 
     | 
    
         
            -
                    @w_item = 10
         
     | 
| 
       87 
     | 
    
         
            -
                    @finalize_chol = false
         
     | 
| 
      
 238 
     | 
    
         
            +
                    set_implicit_vars
         
     | 
| 
       88 
239 
     | 
    
         | 
| 
       89 
240 
     | 
    
         
             
                    args = [
         
     | 
| 
       90 
241 
     | 
    
         
             
                      @a, @b,
         
     | 
| 
         @@ -175,104 +326,13 @@ module Cmfrec 
     | 
|
| 
       175 
326 
     | 
    
         
             
                    @global_mean = real_array(glob_mean).first
         
     | 
| 
       176 
327 
     | 
    
         
             
                  end
         
     | 
| 
       177 
328 
     | 
    
         | 
| 
       178 
     | 
    
         
            -
                  @u_colmeans =  
     | 
| 
       179 
     | 
    
         
            -
                  @i_colmeans = real_array(i_colmeans)
         
     | 
| 
       180 
     | 
    
         
            -
                  @u_colmeans_ptr = u_colmeans
         
     | 
| 
       181 
     | 
    
         
            -
             
     | 
| 
       182 
     | 
    
         
            -
                  self
         
     | 
| 
       183 
     | 
    
         
            -
                end
         
     | 
| 
       184 
     | 
    
         
            -
             
     | 
| 
       185 
     | 
    
         
            -
                def predict(data)
         
     | 
| 
       186 
     | 
    
         
            -
                  check_fit
         
     | 
| 
       187 
     | 
    
         
            -
             
     | 
| 
       188 
     | 
    
         
            -
                  data = to_dataset(data)
         
     | 
| 
       189 
     | 
    
         
            -
             
     | 
| 
       190 
     | 
    
         
            -
                  u = data.map { |v| @user_map[v[:user_id]] || @user_map.size }
         
     | 
| 
       191 
     | 
    
         
            -
                  i = data.map { |v| @item_map[v[:item_id]] || @item_map.size }
         
     | 
| 
       192 
     | 
    
         
            -
             
     | 
| 
       193 
     | 
    
         
            -
                  row = int_ptr(u)
         
     | 
| 
       194 
     | 
    
         
            -
                  col = int_ptr(i)
         
     | 
| 
       195 
     | 
    
         
            -
                  n_predict = data.size
         
     | 
| 
       196 
     | 
    
         
            -
                  predicted = Fiddle::Pointer.malloc(n_predict * Fiddle::SIZEOF_DOUBLE)
         
     | 
| 
       197 
     | 
    
         
            -
             
     | 
| 
       198 
     | 
    
         
            -
                  if @implicit
         
     | 
| 
       199 
     | 
    
         
            -
                    check_status FFI.predict_X_old_collective_implicit(
         
     | 
| 
       200 
     | 
    
         
            -
                      row, col, predicted, n_predict,
         
     | 
| 
       201 
     | 
    
         
            -
                      @a, @b,
         
     | 
| 
       202 
     | 
    
         
            -
                      @k, @k_user, @k_item, @k_main,
         
     | 
| 
       203 
     | 
    
         
            -
                      @m, @n,
         
     | 
| 
       204 
     | 
    
         
            -
                      @nthreads
         
     | 
| 
       205 
     | 
    
         
            -
                    )
         
     | 
| 
       206 
     | 
    
         
            -
                  else
         
     | 
| 
       207 
     | 
    
         
            -
                    check_status FFI.predict_X_old_collective_explicit(
         
     | 
| 
       208 
     | 
    
         
            -
                      row, col, predicted, n_predict,
         
     | 
| 
       209 
     | 
    
         
            -
                      @a, @bias_a,
         
     | 
| 
       210 
     | 
    
         
            -
                      @b, @bias_b,
         
     | 
| 
       211 
     | 
    
         
            -
                      @global_mean,
         
     | 
| 
       212 
     | 
    
         
            -
                      @k, @k_user, @k_item, @k_main,
         
     | 
| 
       213 
     | 
    
         
            -
                      @m, @n,
         
     | 
| 
       214 
     | 
    
         
            -
                      @nthreads
         
     | 
| 
       215 
     | 
    
         
            -
                    )
         
     | 
| 
       216 
     | 
    
         
            -
                  end
         
     | 
| 
       217 
     | 
    
         
            -
             
     | 
| 
       218 
     | 
    
         
            -
                  predictions = real_array(predicted)
         
     | 
| 
       219 
     | 
    
         
            -
                  predictions.map! { |v| v.nan? ? @global_mean : v } if @implicit
         
     | 
| 
       220 
     | 
    
         
            -
                  predictions
         
     | 
| 
       221 
     | 
    
         
            -
                end
         
     | 
| 
       222 
     | 
    
         
            -
             
     | 
| 
       223 
     | 
    
         
            -
                def user_recs(user_id, count: 5, item_ids: nil)
         
     | 
| 
       224 
     | 
    
         
            -
                  check_fit
         
     | 
| 
       225 
     | 
    
         
            -
                  user = @user_map[user_id]
         
     | 
| 
       226 
     | 
    
         
            -
             
     | 
| 
       227 
     | 
    
         
            -
                  if user
         
     | 
| 
       228 
     | 
    
         
            -
                    if item_ids
         
     | 
| 
       229 
     | 
    
         
            -
                      # remove missing ids
         
     | 
| 
       230 
     | 
    
         
            -
                      item_ids = item_ids.select { |v| @item_map[v] }
         
     | 
| 
       231 
     | 
    
         
            -
             
     | 
| 
       232 
     | 
    
         
            -
                      data = item_ids.map { |v| {user_id: user_id, item_id: v} }
         
     | 
| 
       233 
     | 
    
         
            -
                      scores = predict(data)
         
     | 
| 
       234 
     | 
    
         
            -
             
     | 
| 
       235 
     | 
    
         
            -
                      item_ids.zip(scores).map do |item_id, score|
         
     | 
| 
       236 
     | 
    
         
            -
                        {item_id: item_id, score: score}
         
     | 
| 
       237 
     | 
    
         
            -
                      end
         
     | 
| 
       238 
     | 
    
         
            -
                    else
         
     | 
| 
       239 
     | 
    
         
            -
                      a_vec = @a[user * @k * Fiddle::SIZEOF_DOUBLE, @k * Fiddle::SIZEOF_DOUBLE]
         
     | 
| 
       240 
     | 
    
         
            -
                      a_bias = @bias_a ? @bias_a[user * Fiddle::SIZEOF_DOUBLE, Fiddle::SIZEOF_DOUBLE].unpack1("d") : 0
         
     | 
| 
       241 
     | 
    
         
            -
                      top_n(a_vec: a_vec, a_bias: a_bias, count: count)
         
     | 
| 
       242 
     | 
    
         
            -
                    end
         
     | 
| 
       243 
     | 
    
         
            -
                  else
         
     | 
| 
       244 
     | 
    
         
            -
                    # no items if user is unknown
         
     | 
| 
       245 
     | 
    
         
            -
                    # TODO maybe most popular items
         
     | 
| 
       246 
     | 
    
         
            -
                    []
         
     | 
| 
       247 
     | 
    
         
            -
                  end
         
     | 
| 
       248 
     | 
    
         
            -
                end
         
     | 
| 
       249 
     | 
    
         
            -
             
     | 
| 
       250 
     | 
    
         
            -
                # TODO add item_ids
         
     | 
| 
       251 
     | 
    
         
            -
                def new_user_recs(data, count: 5, user_info: nil)
         
     | 
| 
       252 
     | 
    
         
            -
                  check_fit
         
     | 
| 
       253 
     | 
    
         
            -
             
     | 
| 
       254 
     | 
    
         
            -
                  a_vec, a_bias = factors_warm(data, user_info: user_info)
         
     | 
| 
       255 
     | 
    
         
            -
                  top_n(a_vec: a_vec, a_bias: a_bias, count: count)
         
     | 
| 
       256 
     | 
    
         
            -
                end
         
     | 
| 
       257 
     | 
    
         
            -
             
     | 
| 
       258 
     | 
    
         
            -
                def user_factors
         
     | 
| 
       259 
     | 
    
         
            -
                  read_factors(@a, [@m, @m_u].max, @k_user + @k + @k_main)
         
     | 
| 
       260 
     | 
    
         
            -
                end
         
     | 
| 
       261 
     | 
    
         
            -
             
     | 
| 
       262 
     | 
    
         
            -
                def item_factors
         
     | 
| 
       263 
     | 
    
         
            -
                  read_factors(@b, [@n, @n_i].max, @k_item + @k + @k_main)
         
     | 
| 
       264 
     | 
    
         
            -
                end
         
     | 
| 
      
 329 
     | 
    
         
            +
                  @u_colmeans = u_colmeans
         
     | 
| 
       265 
330 
     | 
    
         | 
| 
       266 
     | 
    
         
            -
             
     | 
| 
       267 
     | 
    
         
            -
                  read_bias(@bias_a) if @bias_a
         
     | 
| 
       268 
     | 
    
         
            -
                end
         
     | 
| 
      
 331 
     | 
    
         
            +
                  @fit = true
         
     | 
| 
       269 
332 
     | 
    
         | 
| 
       270 
     | 
    
         
            -
             
     | 
| 
       271 
     | 
    
         
            -
                  read_bias(@bias_b) if @bias_b
         
     | 
| 
      
 333 
     | 
    
         
            +
                  self
         
     | 
| 
       272 
334 
     | 
    
         
             
                end
         
     | 
| 
       273 
335 
     | 
    
         | 
| 
       274 
     | 
    
         
            -
                private
         
     | 
| 
       275 
     | 
    
         
            -
             
     | 
| 
       276 
336 
     | 
    
         
             
                def set_params(
         
     | 
| 
       277 
337 
     | 
    
         
             
                  k: 40, lambda_: 1e+1, method: "als", use_cg: true, user_bias: true,
         
     | 
| 
       278 
338 
     | 
    
         
             
                  item_bias: true, add_implicit_features: false,
         
     | 
| 
         @@ -329,15 +389,14 @@ module Cmfrec 
     | 
|
| 
       329 
389 
     | 
    
         
             
                  @nthreads = nthreads
         
     | 
| 
       330 
390 
     | 
    
         
             
                end
         
     | 
| 
       331 
391 
     | 
    
         | 
| 
       332 
     | 
    
         
            -
                def  
     | 
| 
       333 
     | 
    
         
            -
                   
     | 
| 
       334 
     | 
    
         
            -
                   
     | 
| 
       335 
     | 
    
         
            -
             
     | 
| 
       336 
     | 
    
         
            -
                  raise ArgumentError, "Missing user_id" if user_ids.any?(&:nil?)
         
     | 
| 
       337 
     | 
    
         
            -
                  raise ArgumentError, "Missing item_id" if item_ids.any?(&:nil?)
         
     | 
| 
      
 392 
     | 
    
         
            +
                def update_maps(train_set)
         
     | 
| 
      
 393 
     | 
    
         
            +
                  raise ArgumentError, "Missing user_id" if train_set.any? { |v| v[:user_id].nil? }
         
     | 
| 
      
 394 
     | 
    
         
            +
                  raise ArgumentError, "Missing item_id" if train_set.any? { |v| v[:item_id].nil? }
         
     | 
| 
       338 
395 
     | 
    
         | 
| 
       339 
     | 
    
         
            -
                   
     | 
| 
       340 
     | 
    
         
            -
             
     | 
| 
      
 396 
     | 
    
         
            +
                  train_set.each do |v|
         
     | 
| 
      
 397 
     | 
    
         
            +
                    @user_map[v[:user_id]] ||= @user_map.size
         
     | 
| 
      
 398 
     | 
    
         
            +
                    @item_map[v[:item_id]] ||= @item_map.size
         
     | 
| 
      
 399 
     | 
    
         
            +
                  end
         
     | 
| 
       341 
400 
     | 
    
         
             
                end
         
     | 
| 
       342 
401 
     | 
    
         | 
| 
       343 
402 
     | 
    
         
             
                def check_ratings(ratings)
         
     | 
| 
         @@ -354,7 +413,7 @@ module Cmfrec 
     | 
|
| 
       354 
413 
     | 
    
         
             
                end
         
     | 
| 
       355 
414 
     | 
    
         | 
| 
       356 
415 
     | 
    
         
             
                def check_fit
         
     | 
| 
       357 
     | 
    
         
            -
                  raise "Not fit" unless  
     | 
| 
      
 416 
     | 
    
         
            +
                  raise "Not fit" unless @fit
         
     | 
| 
       358 
417 
     | 
    
         
             
                end
         
     | 
| 
       359 
418 
     | 
    
         | 
| 
       360 
419 
     | 
    
         
             
                def to_dataset(dataset)
         
     | 
| 
         @@ -479,7 +538,7 @@ module Cmfrec 
     | 
|
| 
       479 
538 
     | 
    
         
             
                      u_vec_sp, u_vec_x_col, nnz_u_vec,
         
     | 
| 
       480 
539 
     | 
    
         
             
                      @na_as_zero_user,
         
     | 
| 
       481 
540 
     | 
    
         
             
                      @nonneg,
         
     | 
| 
       482 
     | 
    
         
            -
                      @ 
     | 
| 
      
 541 
     | 
    
         
            +
                      @u_colmeans,
         
     | 
| 
       483 
542 
     | 
    
         
             
                      @b, @n, @c,
         
     | 
| 
       484 
543 
     | 
    
         
             
                      xa, x_col, nnz,
         
     | 
| 
       485 
544 
     | 
    
         
             
                      @k, @k_user, @k_item, @k_main,
         
     | 
| 
         @@ -505,7 +564,7 @@ module Cmfrec 
     | 
|
| 
       505 
564 
     | 
    
         
             
                      @na_as_zero_user, @na_as_zero,
         
     | 
| 
       506 
565 
     | 
    
         
             
                      @nonneg,
         
     | 
| 
       507 
566 
     | 
    
         
             
                      @c, cb,
         
     | 
| 
       508 
     | 
    
         
            -
                      @global_mean, @bias_b, @ 
     | 
| 
      
 567 
     | 
    
         
            +
                      @global_mean, @bias_b, @u_colmeans,
         
     | 
| 
       509 
568 
     | 
    
         
             
                      xa, x_col, nnz, xa_dense,
         
     | 
| 
       510 
569 
     | 
    
         
             
                      @n, weight, @b, @bi,
         
     | 
| 
       511 
570 
     | 
    
         
             
                      @add_implicit_features,
         
     | 
| 
         @@ -585,5 +644,124 @@ module Cmfrec 
     | 
|
| 
       585 
644 
     | 
    
         
             
                def real_array(ptr)
         
     | 
| 
       586 
645 
     | 
    
         
             
                  ptr.to_s(ptr.size).unpack("d*")
         
     | 
| 
       587 
646 
     | 
    
         
             
                end
         
     | 
| 
      
 647 
     | 
    
         
            +
             
     | 
| 
      
 648 
     | 
    
         
            +
                def set_implicit_vars
         
     | 
| 
      
 649 
     | 
    
         
            +
                  @w_main_multiplier = 1.0
         
     | 
| 
      
 650 
     | 
    
         
            +
                  @alpha = 1.0
         
     | 
| 
      
 651 
     | 
    
         
            +
                  @adjust_weight = false # downweight?
         
     | 
| 
      
 652 
     | 
    
         
            +
                  @apply_log_transf = false
         
     | 
| 
      
 653 
     | 
    
         
            +
             
     | 
| 
      
 654 
     | 
    
         
            +
                  # different defaults
         
     | 
| 
      
 655 
     | 
    
         
            +
                  @lambda_ = 1e0
         
     | 
| 
      
 656 
     | 
    
         
            +
                  @w_user = 10
         
     | 
| 
      
 657 
     | 
    
         
            +
                  @w_item = 10
         
     | 
| 
      
 658 
     | 
    
         
            +
                  @finalize_chol = false
         
     | 
| 
      
 659 
     | 
    
         
            +
                end
         
     | 
| 
      
 660 
     | 
    
         
            +
             
     | 
| 
      
 661 
     | 
    
         
            +
                def dump_ptr(ptr)
         
     | 
| 
      
 662 
     | 
    
         
            +
                  ptr.to_s(ptr.size) if ptr
         
     | 
| 
      
 663 
     | 
    
         
            +
                end
         
     | 
| 
      
 664 
     | 
    
         
            +
             
     | 
| 
      
 665 
     | 
    
         
            +
                def load_ptr(str)
         
     | 
| 
      
 666 
     | 
    
         
            +
                  Fiddle::Pointer[str] if str
         
     | 
| 
      
 667 
     | 
    
         
            +
                end
         
     | 
| 
      
 668 
     | 
    
         
            +
             
     | 
| 
      
 669 
     | 
    
         
            +
                def marshal_dump
         
     | 
| 
      
 670 
     | 
    
         
            +
                  obj = {
         
     | 
| 
      
 671 
     | 
    
         
            +
                    implicit: @implicit
         
     | 
| 
      
 672 
     | 
    
         
            +
                  }
         
     | 
| 
      
 673 
     | 
    
         
            +
             
     | 
| 
      
 674 
     | 
    
         
            +
                  # options
         
     | 
| 
      
 675 
     | 
    
         
            +
                  obj[:factors] = @k
         
     | 
| 
      
 676 
     | 
    
         
            +
                  obj[:epochs] = @niter
         
     | 
| 
      
 677 
     | 
    
         
            +
                  obj[:verbose] = @verbose
         
     | 
| 
      
 678 
     | 
    
         
            +
             
     | 
| 
      
 679 
     | 
    
         
            +
                  # factors
         
     | 
| 
      
 680 
     | 
    
         
            +
                  obj[:user_map] = @user_map
         
     | 
| 
      
 681 
     | 
    
         
            +
                  obj[:item_map] = @item_map
         
     | 
| 
      
 682 
     | 
    
         
            +
                  obj[:user_factors] = dump_ptr(@a)
         
     | 
| 
      
 683 
     | 
    
         
            +
                  obj[:item_factors] = dump_ptr(@b)
         
     | 
| 
      
 684 
     | 
    
         
            +
             
     | 
| 
      
 685 
     | 
    
         
            +
                  # bias
         
     | 
| 
      
 686 
     | 
    
         
            +
                  obj[:user_bias] = dump_ptr(@bias_a)
         
     | 
| 
      
 687 
     | 
    
         
            +
                  obj[:item_bias] = dump_ptr(@bias_b)
         
     | 
| 
      
 688 
     | 
    
         
            +
             
     | 
| 
      
 689 
     | 
    
         
            +
                  # mean
         
     | 
| 
      
 690 
     | 
    
         
            +
                  obj[:global_mean] = @global_mean
         
     | 
| 
      
 691 
     | 
    
         
            +
             
     | 
| 
      
 692 
     | 
    
         
            +
                  # side info
         
     | 
| 
      
 693 
     | 
    
         
            +
                  obj[:user_info_map] = @user_info_map
         
     | 
| 
      
 694 
     | 
    
         
            +
                  obj[:item_info_map] = @item_info_map
         
     | 
| 
      
 695 
     | 
    
         
            +
                  obj[:user_info_factors] = dump_ptr(@c)
         
     | 
| 
      
 696 
     | 
    
         
            +
                  obj[:item_info_factors] = dump_ptr(@d)
         
     | 
| 
      
 697 
     | 
    
         
            +
             
     | 
| 
      
 698 
     | 
    
         
            +
                  # implicit features
         
     | 
| 
      
 699 
     | 
    
         
            +
                  obj[:add_implicit_features] = @add_implicit_features
         
     | 
| 
      
 700 
     | 
    
         
            +
                  obj[:user_factors_implicit] = dump_ptr(@ai)
         
     | 
| 
      
 701 
     | 
    
         
            +
                  obj[:item_factors_implicit] = dump_ptr(@bi)
         
     | 
| 
      
 702 
     | 
    
         
            +
             
     | 
| 
      
 703 
     | 
    
         
            +
                  unless @implicit
         
     | 
| 
      
 704 
     | 
    
         
            +
                    obj[:min_rating] = @min_rating
         
     | 
| 
      
 705 
     | 
    
         
            +
                    obj[:max_rating] = @max_rating
         
     | 
| 
      
 706 
     | 
    
         
            +
                  end
         
     | 
| 
      
 707 
     | 
    
         
            +
             
     | 
| 
      
 708 
     | 
    
         
            +
                  obj[:user_means] = dump_ptr(@u_colmeans)
         
     | 
| 
      
 709 
     | 
    
         
            +
             
     | 
| 
      
 710 
     | 
    
         
            +
                  obj
         
     | 
| 
      
 711 
     | 
    
         
            +
                end
         
     | 
| 
      
 712 
     | 
    
         
            +
             
     | 
| 
      
 713 
     | 
    
         
            +
                def marshal_load(obj)
         
     | 
| 
      
 714 
     | 
    
         
            +
                  @implicit = obj[:implicit]
         
     | 
| 
      
 715 
     | 
    
         
            +
             
     | 
| 
      
 716 
     | 
    
         
            +
                  # options
         
     | 
| 
      
 717 
     | 
    
         
            +
                  set_params(
         
     | 
| 
      
 718 
     | 
    
         
            +
                    k: obj[:factors],
         
     | 
| 
      
 719 
     | 
    
         
            +
                    niter: obj[:epochs],
         
     | 
| 
      
 720 
     | 
    
         
            +
                    verbose: obj[:verbose],
         
     | 
| 
      
 721 
     | 
    
         
            +
                    user_bias: !obj[:user_bias].nil?,
         
     | 
| 
      
 722 
     | 
    
         
            +
                    item_bias: !obj[:item_bias].nil?,
         
     | 
| 
      
 723 
     | 
    
         
            +
                    add_implicit_features: obj[:add_implicit_features]
         
     | 
| 
      
 724 
     | 
    
         
            +
                  )
         
     | 
| 
      
 725 
     | 
    
         
            +
             
     | 
| 
      
 726 
     | 
    
         
            +
                  # factors
         
     | 
| 
      
 727 
     | 
    
         
            +
                  @user_map = obj[:user_map]
         
     | 
| 
      
 728 
     | 
    
         
            +
                  @item_map = obj[:item_map]
         
     | 
| 
      
 729 
     | 
    
         
            +
                  @a = load_ptr(obj[:user_factors])
         
     | 
| 
      
 730 
     | 
    
         
            +
                  @b = load_ptr(obj[:item_factors])
         
     | 
| 
      
 731 
     | 
    
         
            +
             
     | 
| 
      
 732 
     | 
    
         
            +
                  # bias
         
     | 
| 
      
 733 
     | 
    
         
            +
                  @bias_a = load_ptr(obj[:user_bias])
         
     | 
| 
      
 734 
     | 
    
         
            +
                  @bias_b = load_ptr(obj[:item_bias])
         
     | 
| 
      
 735 
     | 
    
         
            +
             
     | 
| 
      
 736 
     | 
    
         
            +
                  # mean
         
     | 
| 
      
 737 
     | 
    
         
            +
                  @global_mean = obj[:global_mean]
         
     | 
| 
      
 738 
     | 
    
         
            +
             
     | 
| 
      
 739 
     | 
    
         
            +
                  # side info
         
     | 
| 
      
 740 
     | 
    
         
            +
                  @user_info_map = obj[:user_info_map]
         
     | 
| 
      
 741 
     | 
    
         
            +
                  @item_info_map = obj[:item_info_map]
         
     | 
| 
      
 742 
     | 
    
         
            +
                  @c = load_ptr(obj[:user_info_factors])
         
     | 
| 
      
 743 
     | 
    
         
            +
                  @d = load_ptr(obj[:item_info_factors])
         
     | 
| 
      
 744 
     | 
    
         
            +
             
     | 
| 
      
 745 
     | 
    
         
            +
                  # implicit features
         
     | 
| 
      
 746 
     | 
    
         
            +
                  @add_implicit_features = obj[:add_implicit_features]
         
     | 
| 
      
 747 
     | 
    
         
            +
                  @ai = load_ptr(obj[:user_factors_implicit])
         
     | 
| 
      
 748 
     | 
    
         
            +
                  @bi = load_ptr(obj[:item_factors_implicit])
         
     | 
| 
      
 749 
     | 
    
         
            +
             
     | 
| 
      
 750 
     | 
    
         
            +
                  unless @implicit
         
     | 
| 
      
 751 
     | 
    
         
            +
                    @min_rating = obj[:min_rating]
         
     | 
| 
      
 752 
     | 
    
         
            +
                    @max_rating = obj[:max_rating]
         
     | 
| 
      
 753 
     | 
    
         
            +
                  end
         
     | 
| 
      
 754 
     | 
    
         
            +
             
     | 
| 
      
 755 
     | 
    
         
            +
                  @u_colmeans = load_ptr(obj[:user_means])
         
     | 
| 
      
 756 
     | 
    
         
            +
             
     | 
| 
      
 757 
     | 
    
         
            +
                  @m = @user_map.size
         
     | 
| 
      
 758 
     | 
    
         
            +
                  @n = @item_map.size
         
     | 
| 
      
 759 
     | 
    
         
            +
                  @m_u = @user_info_map.size
         
     | 
| 
      
 760 
     | 
    
         
            +
                  @n_i = @item_info_map.size
         
     | 
| 
      
 761 
     | 
    
         
            +
             
     | 
| 
      
 762 
     | 
    
         
            +
                  set_implicit_vars if @implicit
         
     | 
| 
      
 763 
     | 
    
         
            +
             
     | 
| 
      
 764 
     | 
    
         
            +
                  @fit = @m > 0
         
     | 
| 
      
 765 
     | 
    
         
            +
                end
         
     | 
| 
       588 
766 
     | 
    
         
             
              end
         
     | 
| 
       589 
767 
     | 
    
         
             
            end
         
     | 
    
        data/lib/cmfrec/version.rb
    CHANGED
    
    
    
        metadata
    CHANGED
    
    | 
         @@ -1,17 +1,17 @@ 
     | 
|
| 
       1 
1 
     | 
    
         
             
            --- !ruby/object:Gem::Specification
         
     | 
| 
       2 
2 
     | 
    
         
             
            name: cmfrec
         
     | 
| 
       3 
3 
     | 
    
         
             
            version: !ruby/object:Gem::Version
         
     | 
| 
       4 
     | 
    
         
            -
              version: 0.1.3
     | 
| 
      
 4 
     | 
    
         
            +
              version: 0.1.4
         
     | 
| 
       5 
5 
     | 
    
         
             
            platform: ruby
         
     | 
| 
       6 
6 
     | 
    
         
             
            authors:
         
     | 
| 
       7 
7 
     | 
    
         
             
            - Andrew Kane
         
     | 
| 
       8 
8 
     | 
    
         
             
            autorequire:
         
     | 
| 
       9 
9 
     | 
    
         
             
            bindir: bin
         
     | 
| 
       10 
10 
     | 
    
         
             
            cert_chain: []
         
     | 
| 
       11 
     | 
    
         
            -
            date:  
     | 
| 
      
 11 
     | 
    
         
            +
            date: 2021-02-05 00:00:00.000000000 Z
         
     | 
| 
       12 
12 
     | 
    
         
             
            dependencies: []
         
     | 
| 
       13 
13 
     | 
    
         
             
            description:
         
     | 
| 
       14 
     | 
    
         
            -
            email: andrew@ 
     | 
| 
      
 14 
     | 
    
         
            +
            email: andrew@ankane.org
         
     | 
| 
       15 
15 
     | 
    
         
             
            executables: []
         
     | 
| 
       16 
16 
     | 
    
         
             
            extensions: []
         
     | 
| 
       17 
17 
     | 
    
         
             
            extra_rdoc_files: []
         
     |