libmf 0.1.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- checksums.yaml +7 -0
- data/CHANGELOG.md +3 -0
- data/LICENSE.txt +22 -0
- data/README.md +125 -0
- data/ext/libmf/extconf.rb +18 -0
- data/lib/libmf.bundle +0 -0
- data/lib/libmf.rb +26 -0
- data/lib/libmf/ffi.rb +62 -0
- data/lib/libmf/model.rb +112 -0
- data/lib/libmf/version.rb +3 -0
- data/vendor/libmf/COPYRIGHT +31 -0
- data/vendor/libmf/Makefile +34 -0
- data/vendor/libmf/Makefile.win +36 -0
- data/vendor/libmf/README +637 -0
- data/vendor/libmf/demo/all_one_matrix.te.txt +1382 -0
- data/vendor/libmf/demo/all_one_matrix.tr.txt +5172 -0
- data/vendor/libmf/demo/binary_matrix.te.txt +1312 -0
- data/vendor/libmf/demo/binary_matrix.tr.txt +4937 -0
- data/vendor/libmf/demo/demo.bat +40 -0
- data/vendor/libmf/demo/demo.sh +58 -0
- data/vendor/libmf/demo/real_matrix.te.txt +794 -0
- data/vendor/libmf/demo/real_matrix.tr.txt +5000 -0
- data/vendor/libmf/mf-predict.cpp +207 -0
- data/vendor/libmf/mf-train.cpp +378 -0
- data/vendor/libmf/mf.cpp +4683 -0
- data/vendor/libmf/mf.def +21 -0
- data/vendor/libmf/mf.h +130 -0
- data/vendor/libmf/windows/mf-predict.exe +0 -0
- data/vendor/libmf/windows/mf-train.exe +0 -0
- data/vendor/libmf/windows/mf.dll +0 -0
- metadata +142 -0
data/vendor/libmf/mf.cpp
ADDED
@@ -0,0 +1,4683 @@
|
|
1
|
+
#include <algorithm>
#include <cmath>
#include <condition_variable>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <functional>
#include <iomanip>
#include <iostream>
#include <limits>
#include <memory>
#include <mutex>
#include <numeric>
#include <queue>
#include <random>
#include <stdexcept>
#include <string>
#include <thread>
#include <tuple>
#include <unordered_set>
#include <utility>
#include <vector>
|
19
|
+
|
20
|
+
#include "mf.h"
|
21
|
+
|
22
|
+
#if defined USESSE
|
23
|
+
#include <pmmintrin.h>
|
24
|
+
#endif
|
25
|
+
|
26
|
+
#if defined USEAVX
|
27
|
+
#include <immintrin.h>
|
28
|
+
#endif
|
29
|
+
|
30
|
+
#if defined USEOMP
|
31
|
+
#include <omp.h>
|
32
|
+
#endif
|
33
|
+
|
34
|
+
namespace mf
|
35
|
+
{
|
36
|
+
|
37
|
+
using namespace std;
|
38
|
+
|
39
|
+
namespace // unnamed namespace
|
40
|
+
{
|
41
|
+
|
42
|
+
// Alignment, in bytes, for factor storage (presumably honored by
// Utility::malloc_aligned_float -- confirm), and the same alignment expressed
// as a count of mf_float elements.
constexpr mf_int kALIGNByte = 32;
constexpr mf_int kALIGN = kALIGNByte/sizeof(mf_float);
|
44
|
+
|
45
|
+
//--------------------------------------
|
46
|
+
//---------Scheduler of Blocks----------
|
47
|
+
//--------------------------------------
|
48
|
+
|
49
|
+
class Scheduler
|
50
|
+
{
|
51
|
+
public:
|
52
|
+
Scheduler(mf_int nr_bins, mf_int nr_threads, vector<mf_int> cv_blocks);
|
53
|
+
mf_int get_job();
|
54
|
+
mf_int get_bpr_job(mf_int first_block, bool is_column_oriented);
|
55
|
+
void put_job(mf_int block, mf_double loss, mf_double error);
|
56
|
+
void put_bpr_job(mf_int first_block, mf_int second_block);
|
57
|
+
mf_double get_loss();
|
58
|
+
mf_double get_error();
|
59
|
+
mf_int get_negative(mf_int first_block, mf_int second_block,
|
60
|
+
mf_int m, mf_int n, bool is_column_oriented);
|
61
|
+
void wait_for_jobs_done();
|
62
|
+
void resume();
|
63
|
+
void terminate();
|
64
|
+
bool is_terminated();
|
65
|
+
|
66
|
+
private:
|
67
|
+
mf_int nr_bins;
|
68
|
+
mf_int nr_threads;
|
69
|
+
mf_int nr_done_jobs;
|
70
|
+
mf_int target;
|
71
|
+
mf_int nr_paused_threads;
|
72
|
+
bool terminated;
|
73
|
+
vector<mf_int> counts;
|
74
|
+
vector<mf_int> busy_p_blocks;
|
75
|
+
vector<mf_int> busy_q_blocks;
|
76
|
+
vector<mf_double> block_losses;
|
77
|
+
vector<mf_double> block_errors;
|
78
|
+
vector<minstd_rand0> block_generators;
|
79
|
+
unordered_set<mf_int> cv_blocks;
|
80
|
+
mutex mtx;
|
81
|
+
condition_variable cond_var;
|
82
|
+
default_random_engine generator;
|
83
|
+
uniform_real_distribution<mf_float> distribution;
|
84
|
+
priority_queue<pair<mf_float, mf_int>,
|
85
|
+
vector<pair<mf_float, mf_int>>,
|
86
|
+
greater<pair<mf_float, mf_int>>> pq;
|
87
|
+
};
|
88
|
+
|
89
|
+
// Builds a scheduler for an nr_bins x nr_bins grid.
// Block ids listed in cv_blocks are held out: they get a private generator
// like every other block but are never queued for training.
Scheduler::Scheduler(mf_int nr_bins, mf_int nr_threads,
                     vector<mf_int> cv_blocks)
    : nr_bins(nr_bins),
      nr_threads(nr_threads),
      nr_done_jobs(0),
      target(nr_bins*nr_bins),
      nr_paused_threads(0),
      terminated(false),
      counts(nr_bins*nr_bins, 0),
      busy_p_blocks(nr_bins, 0),
      busy_q_blocks(nr_bins, 0),
      block_losses(nr_bins*nr_bins, 0),
      block_errors(nr_bins*nr_bins, 0),
      cv_blocks(cv_blocks.begin(), cv_blocks.end()),
      distribution(0.0, 1.0)
{
    mf_int const nr_blocks = nr_bins*nr_bins;

    // One generator per block, seeded from rand() in block-id order; these
    // are the generators consumed by get_negative.
    block_generators.reserve(nr_blocks);
    for(mf_int id = 0; id < nr_blocks; ++id)
        block_generators.emplace_back(rand());

    // Queue every non-held-out block with a random initial priority in [0, 1).
    // (this-> is needed: the parameter cv_blocks shadows the member.)
    for(mf_int id = 0; id < nr_blocks; ++id)
    {
        if(this->cv_blocks.count(id) == 0)
            pq.emplace(distribution(generator), id);
    }
}
|
112
|
+
|
113
|
+
// Claims a training block for the calling solver thread.
// Pops blocks in increasing priority order until one is found whose row and
// column segments are both free. That block's segments are marked busy and
// its usage count is incremented. Skipped blocks are pushed back. If no block
// is available, the lock is dropped and the search repeats (busy wait).
mf_int Scheduler::get_job()
{
    for(;;)
    {
        lock_guard<mutex> guard(mtx);
        vector<pair<mf_float, mf_int>> skipped;
        bool found = false;
        mf_int chosen = 0;

        while(!found && !pq.empty())
        {
            pair<mf_float, mf_int> candidate = pq.top();
            pq.pop();

            mf_int const row = candidate.second/nr_bins;
            mf_int const col = candidate.second%nr_bins;

            if(busy_p_blocks[row] || busy_q_blocks[col])
            {
                skipped.push_back(candidate);
                continue;
            }

            busy_p_blocks[row] = 1;
            busy_q_blocks[col] = 1;
            counts[candidate.second] += 1;
            chosen = candidate.second;
            found = true;
        }

        // Conflicting blocks go back into the queue with their priorities.
        for(auto const &entry : skipped)
            pq.push(entry);

        if(found)
            return chosen;
    }
}
|
151
|
+
|
152
|
+
// Finds a partner block for BPR training of first_block.
// Column-oriented: the partner must share first_block's column segment and
// have a free row segment. Row-oriented: it must share the row segment and
// have a free column segment. The partner's segments are marked busy.
// Returns first_block itself when no partner is available.
mf_int Scheduler::get_bpr_job(mf_int first_block, bool is_column_oriented)
{
    lock_guard<mutex> guard(mtx);
    mf_int chosen = first_block;
    vector<pair<mf_float, mf_int>> skipped;

    mf_int const anchor_row = first_block/nr_bins;
    mf_int const anchor_col = first_block%nr_bins;

    while(!pq.empty())
    {
        pair<mf_float, mf_int> candidate = pq.top();
        pq.pop();

        mf_int const row = candidate.second/nr_bins;
        mf_int const col = candidate.second%nr_bins;

        bool const acceptable = is_column_oriented
            ? (col == anchor_col && !busy_p_blocks[row])
            : (row == anchor_row && !busy_q_blocks[col]);

        if(acceptable)
        {
            busy_p_blocks[row] = 1;
            busy_q_blocks[col] = 1;
            chosen = candidate.second;
            break;
        }
        skipped.push_back(candidate);
    }

    for(auto const &entry : skipped)
        pq.push(entry);

    return chosen;
}
|
192
|
+
|
193
|
+
// Called by a solver thread when it has finished block_idx.
// Releases the block's row/column segments, records the block's loss and
// error, and re-queues the block with priority count + U(0,1) so that
// less-used blocks are preferred. If the number of processed blocks has
// reached `target`, the thread then parks here until resume() raises it.
//
// The whole sequence runs under a single lock acquisition. The original code
// released the mutex between waking from the wait and decrementing
// nr_paused_threads. In that window a thread that was about to grab a new job
// still counted as paused, so wait_for_jobs_done could wrongly conclude that
// every solver was idle.
void Scheduler::put_job(mf_int block_idx, mf_double loss, mf_double error)
{
    unique_lock<mutex> lock(mtx);

    // Return the held block to the scheduler.
    busy_p_blocks[block_idx/nr_bins] = 0;
    busy_q_blocks[block_idx%nr_bins] = 0;
    block_losses[block_idx] = loss;
    block_errors[block_idx] = error;
    ++nr_done_jobs;
    mf_float priority =
        (mf_float)counts[block_idx]+distribution(generator);
    pq.emplace(priority, block_idx);
    ++nr_paused_threads;
    // Wake everything waiting on cond_var, including wait_for_jobs_done.
    cond_var.notify_all();

    // Park while enough blocks have been processed for this pass, so the
    // training status can be printed roughly once per pass. wait() releases
    // the mutex while sleeping and re-acquires it before returning.
    cond_var.wait(lock, [&] {
        return nr_done_jobs < target;
    });

    // Leave the paused state atomically with waking up.
    --nr_paused_threads;
}
|
228
|
+
|
229
|
+
// Releases the partner block obtained from get_bpr_job and re-queues it with
// priority count + U(0,1). Nothing is released when second_block equals
// first_block (e.g. get_bpr_job found no partner and handed back first_block).
void Scheduler::put_bpr_job(mf_int first_block, mf_int second_block)
{
    if(first_block != second_block)
    {
        lock_guard<mutex> guard(mtx);
        busy_p_blocks[second_block/nr_bins] = 0;
        busy_q_blocks[second_block%nr_bins] = 0;
        pq.emplace((mf_float)counts[second_block]+distribution(generator),
                   second_block);
    }
}
|
243
|
+
|
244
|
+
// Returns the sum of the most recent loss reported for each block via
// put_job (blocks never reported contribute 0).
mf_double Scheduler::get_loss()
{
    lock_guard<mutex> guard(mtx);
    mf_double total = 0.0;
    for(mf_double value : block_losses)
        total += value;
    return total;
}
|
249
|
+
|
250
|
+
// Returns the sum of the most recent error reported for each block via
// put_job (blocks never reported contribute 0).
mf_double Scheduler::get_error()
{
    lock_guard<mutex> guard(mtx);
    mf_double total = 0.0;
    for(mf_double value : block_errors)
        total += value;
    return total;
}
|
255
|
+
|
256
|
+
// Draws a negative-sample index for BPR from the row segment
// (is_column_oriented, range [0, m)) or column segment ([0, n)) of either
// first_block or second_block, chosen at random.
//
// Uses the per-block generator of first_block without taking the lock.
// This presumably relies on only the thread holding first_block calling
// here -- confirm at call sites.
//
// Fixes over the original:
//  * the segment end is clamped to m/n (exclusive), not m-1/n-1, so the last
//    row/column of the matrix can be drawn;
//  * the low bit of the draw selects the block and the remaining bits select
//    the offset. Reusing the full value for both tied the offset's parity to
//    the chosen block whenever the segment size was even.
mf_int Scheduler::get_negative(mf_int first_block, mf_int second_block,
                               mf_int m, mf_int n, bool is_column_oriented)
{
    // minstd_rand0 yields values in [1, 2^31-2], which fit in mf_int.
    mf_int const rand_val = (mf_int)block_generators[first_block]();
    bool const use_first_block = (rand_val % 2) != 0;
    mf_int const offset_bits = rand_val / 2;

    auto gen_random = [&] (mf_int block_id)
    {
        mf_int const size = is_column_oriented ? m : n;
        mf_int const segment =
            is_column_oriented ? block_id/nr_bins : block_id%nr_bins;
        mf_int const seg_size = (mf_int)ceil((double)size/nr_bins);

        // Half-open range [v_min, v_end) covering this block's segment.
        mf_int const v_min = min(segment*seg_size, size-1);
        mf_int const v_end = min(v_min+seg_size, size);
        if(v_end <= v_min)
            return v_min;
        return offset_bits%(v_end-v_min)+v_min;
    };

    return (mf_int)gen_random(use_first_block ? first_block : second_block);
}
|
288
|
+
|
289
|
+
void Scheduler::wait_for_jobs_done()
|
290
|
+
{
|
291
|
+
unique_lock<mutex> lock(mtx);
|
292
|
+
|
293
|
+
// The first thing the main thread should wait for is that solver threads
|
294
|
+
// process enough matrix blocks.
|
295
|
+
// [REVIEW] Is it really needed? Solver threads automatically stop if they
|
296
|
+
// process too many blocks, so the next wait should be enough for stopping
|
297
|
+
// the main thread when nr_done_job is not enough.
|
298
|
+
cond_var.wait(lock, [&] {
|
299
|
+
return nr_done_jobs >= target;
|
300
|
+
});
|
301
|
+
|
302
|
+
// Wait for all threads to stop. Once a thread realizes that all threads
|
303
|
+
// have processed enough blocks it should stop. Then, the main thread can
|
304
|
+
// print values safely.
|
305
|
+
cond_var.wait(lock, [&] {
|
306
|
+
return nr_paused_threads == nr_threads;
|
307
|
+
});
|
308
|
+
}
|
309
|
+
|
310
|
+
void Scheduler::resume()
|
311
|
+
{
|
312
|
+
lock_guard<mutex> lock(mtx);
|
313
|
+
target += nr_bins*nr_bins;
|
314
|
+
cond_var.notify_all();
|
315
|
+
}
|
316
|
+
|
317
|
+
void Scheduler::terminate()
|
318
|
+
{
|
319
|
+
lock_guard<mutex> lock(mtx);
|
320
|
+
terminated = true;
|
321
|
+
}
|
322
|
+
|
323
|
+
bool Scheduler::is_terminated()
|
324
|
+
{
|
325
|
+
lock_guard<mutex> lock(mtx);
|
326
|
+
return terminated;
|
327
|
+
}
|
328
|
+
|
329
|
+
//--------------------------------------
|
330
|
+
//------------Block of matrix-----------
|
331
|
+
//--------------------------------------
|
332
|
+
|
333
|
+
class BlockBase
|
334
|
+
{
|
335
|
+
public:
|
336
|
+
virtual bool move_next() { return false; };
|
337
|
+
virtual mf_node* get_current() { return nullptr; }
|
338
|
+
virtual void reload() {};
|
339
|
+
virtual void free() {};
|
340
|
+
virtual mf_long get_nnz() { return 0; };
|
341
|
+
virtual ~BlockBase() {};
|
342
|
+
};
|
343
|
+
|
344
|
+
class Block : public BlockBase
|
345
|
+
{
|
346
|
+
public:
|
347
|
+
Block() : first(nullptr), last(nullptr), current(nullptr) {};
|
348
|
+
Block(mf_node *first_, mf_node *last_)
|
349
|
+
: first(first_), last(last_), current(nullptr) {};
|
350
|
+
bool move_next() { return ++current != last; }
|
351
|
+
mf_node* get_current() { return current; }
|
352
|
+
void tie_to(mf_node *first_, mf_node *last_);
|
353
|
+
void reload() { current = first-1; };
|
354
|
+
mf_long get_nnz() { return last-first; };
|
355
|
+
|
356
|
+
private:
|
357
|
+
mf_node* first;
|
358
|
+
mf_node* last;
|
359
|
+
mf_node* current;
|
360
|
+
};
|
361
|
+
|
362
|
+
// Points this block at the half-open node range [first_, last_).
// The iteration cursor is left untouched; call reload() before iterating.
void Block::tie_to(mf_node *first_, mf_node *last_)
{
    this->first = first_;
    this->last = last_;
}
|
367
|
+
|
368
|
+
class BlockOnDisk : public BlockBase
|
369
|
+
{
|
370
|
+
public:
|
371
|
+
BlockOnDisk() : first(0), last(0), current(0),
|
372
|
+
source_path(""), buffer(0) {};
|
373
|
+
bool move_next() { return ++current < last-first; }
|
374
|
+
mf_node* get_current() { return &buffer[static_cast<size_t>(current)]; }
|
375
|
+
void tie_to(string source_path_, mf_long first_, mf_long last_);
|
376
|
+
void reload();
|
377
|
+
void free() { buffer.resize(0); };
|
378
|
+
mf_long get_nnz() { return last-first; };
|
379
|
+
|
380
|
+
private:
|
381
|
+
mf_long first;
|
382
|
+
mf_long last;
|
383
|
+
mf_long current;
|
384
|
+
string source_path;
|
385
|
+
vector<mf_node> buffer;
|
386
|
+
};
|
387
|
+
|
388
|
+
// Binds this block to node indices [first_, last_) of the binary file at
// source_path_. No I/O happens here; reload() performs the read.
void BlockOnDisk::tie_to(string source_path_, mf_long first_, mf_long last_)
{
    this->first = first_;
    this->last = last_;
    this->source_path = source_path_;
}
|
394
|
+
|
395
|
+
void BlockOnDisk::reload()
|
396
|
+
{
|
397
|
+
ifstream source(source_path, ifstream::in|ifstream::binary);
|
398
|
+
if(!source)
|
399
|
+
throw runtime_error("can not open "+source_path);
|
400
|
+
|
401
|
+
buffer.resize(static_cast<size_t>(last-first));
|
402
|
+
source.seekg(first*sizeof(mf_node));
|
403
|
+
source.read((char*)buffer.data(), (last-first)*sizeof(mf_node));
|
404
|
+
current = -1;
|
405
|
+
}
|
406
|
+
|
407
|
+
//--------------------------------------
|
408
|
+
//-------------Miscellaneous------------
|
409
|
+
//--------------------------------------
|
410
|
+
|
411
|
+
struct sort_node_by_p
|
412
|
+
{
|
413
|
+
bool operator() (mf_node const &lhs, mf_node const &rhs)
|
414
|
+
{
|
415
|
+
return tie(lhs.u, lhs.v) < tie(rhs.u, rhs.v);
|
416
|
+
}
|
417
|
+
};
|
418
|
+
|
419
|
+
struct sort_node_by_q
|
420
|
+
{
|
421
|
+
bool operator() (mf_node const &lhs, mf_node const &rhs)
|
422
|
+
{
|
423
|
+
return tie(lhs.v, lhs.u) < tie(rhs.v, rhs.u);
|
424
|
+
}
|
425
|
+
};
|
426
|
+
|
427
|
+
struct deleter
|
428
|
+
{
|
429
|
+
void operator() (mf_problem *prob)
|
430
|
+
{
|
431
|
+
delete[] prob->R;
|
432
|
+
delete prob;
|
433
|
+
}
|
434
|
+
};
|
435
|
+
|
436
|
+
|
437
|
+
class Utility
|
438
|
+
{
|
439
|
+
public:
|
440
|
+
Utility(mf_int f, mf_int n) : fun(f), nr_threads(n) {};
|
441
|
+
void collect_info(mf_problem &prob, mf_float &avg, mf_float &std_dev);
|
442
|
+
void collect_info_on_disk(string data_path, mf_problem &prob,
|
443
|
+
mf_float &avg, mf_float &std_dev);
|
444
|
+
void shuffle_problem(mf_problem &prob, vector<mf_int> &p_map,
|
445
|
+
vector<mf_int> &q_map);
|
446
|
+
vector<mf_node*> grid_problem(mf_problem &prob, mf_int nr_bins,
|
447
|
+
vector<mf_int> &omega_p,
|
448
|
+
vector<mf_int> &omega_q,
|
449
|
+
vector<Block> &blocks);
|
450
|
+
void grid_shuffle_scale_problem_on_disk(mf_int m, mf_int n, mf_int nr_bins,
|
451
|
+
mf_float scale, string data_path,
|
452
|
+
vector<mf_int> &p_map,
|
453
|
+
vector<mf_int> &q_map,
|
454
|
+
vector<mf_int> &omega_p,
|
455
|
+
vector<mf_int> &omega_q,
|
456
|
+
vector<BlockOnDisk> &blocks);
|
457
|
+
void scale_problem(mf_problem &prob, mf_float scale);
|
458
|
+
mf_double calc_reg1(mf_model &model, mf_float lambda_p, mf_float lambda_q,
|
459
|
+
vector<mf_int> &omega_p, vector<mf_int> &omega_q);
|
460
|
+
mf_double calc_reg2(mf_model &model, mf_float lambda_p, mf_float lambda_q,
|
461
|
+
vector<mf_int> &omega_p, vector<mf_int> &omega_q);
|
462
|
+
string get_error_legend() const;
|
463
|
+
mf_double calc_error(vector<BlockBase*> &blocks,
|
464
|
+
vector<mf_int> &cv_block_ids,
|
465
|
+
mf_model const &model);
|
466
|
+
void scale_model(mf_model &model, mf_float scale);
|
467
|
+
|
468
|
+
static mf_problem* copy_problem(mf_problem const *prob, bool copy_data);
|
469
|
+
static vector<mf_int> gen_random_map(mf_int size);
|
470
|
+
// A function used to allocate all aligned float array.
|
471
|
+
// It hides platform-specific function calls. Memory
|
472
|
+
// allocated by malloc_aligned_float must be freed by using
|
473
|
+
// free_aligned_float.
|
474
|
+
static mf_float* malloc_aligned_float(mf_long size);
|
475
|
+
// A function used to free all aligned float array.
|
476
|
+
// It hides platform-specific function calls.
|
477
|
+
static void free_aligned_float(mf_float* ptr);
|
478
|
+
// Initialization function for stochastic gradient method.
|
479
|
+
// Factor matrices P and Q are both randomly initialized.
|
480
|
+
static mf_model* init_model(mf_int loss, mf_int m, mf_int n,
|
481
|
+
mf_int k, mf_float avg,
|
482
|
+
vector<mf_int> &omega_p,
|
483
|
+
vector<mf_int> &omega_q);
|
484
|
+
// Initialization function for one-class CD.
|
485
|
+
// It does zero-initialization on factor matrix P and random initialization
|
486
|
+
// on factor matrix Q.
|
487
|
+
static mf_model* init_model(mf_int m, mf_int n, mf_int k);
|
488
|
+
static mf_float inner_product(mf_float *p, mf_float *q, mf_int k);
|
489
|
+
static vector<mf_int> gen_inv_map(vector<mf_int> &map);
|
490
|
+
static void shrink_model(mf_model &model, mf_int k_new);
|
491
|
+
static void shuffle_model(mf_model &model,
|
492
|
+
vector<mf_int> &p_map,
|
493
|
+
vector<mf_int> &q_map);
|
494
|
+
mf_int get_thread_number() const { return nr_threads; };
|
495
|
+
private:
|
496
|
+
mf_int fun;
|
497
|
+
mf_int nr_threads;
|
498
|
+
};
|
499
|
+
|
500
|
+
// Computes the mean and the population standard deviation of the ratings
// prob.R[0..nnz).
// Robustness fixes:
//  * with nnz == 0 both outputs are 0 instead of NaN (0/0);
//  * E[r^2]-E[r]^2 can round to a tiny negative number when all ratings are
//    (nearly) equal, so the variance is clamped at 0 before sqrt.
void Utility::collect_info(
    mf_problem &prob,
    mf_float &avg,
    mf_float &std_dev)
{
    if(prob.nnz <= 0)
    {
        avg = 0;
        std_dev = 0;
        return;
    }

    mf_double ex = 0;
    mf_double ex2 = 0;

#if defined USEOMP
#pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:ex,ex2)
#endif
    for(mf_long i = 0; i < prob.nnz; ++i)
    {
        mf_node const &N = prob.R[i];
        ex += (mf_double)N.r;
        ex2 += (mf_double)N.r*N.r;
    }

    ex /= (mf_double)prob.nnz;
    ex2 /= (mf_double)prob.nnz;
    avg = (mf_float)ex;
    std_dev = (mf_float)sqrt(max<mf_double>(ex2-ex*ex, 0));
}
|
523
|
+
|
524
|
+
// Streams "u v r" triples from the text file data_path.
// Grows prob.m / prob.n to cover the largest indices seen and adds the number
// of triples to prob.nnz; these fields are not reset here. Returns the mean
// and population standard deviation of the ratings.
// Throws runtime_error if the file cannot be opened.
// Robustness fixes: nnz == 0 yields zeros instead of NaN, and a slightly
// negative variance caused by rounding is clamped to 0 before sqrt.
void Utility::collect_info_on_disk(
    string data_path,
    mf_problem &prob,
    mf_float &avg,
    mf_float &std_dev)
{
    mf_double ex = 0;
    mf_double ex2 = 0;

    ifstream source(data_path);
    if(!source.is_open())
        throw runtime_error("cannot open " + data_path);

    for(mf_node N; source >> N.u >> N.v >> N.r;)
    {
        if(N.u+1 > prob.m)
            prob.m = N.u+1;
        if(N.v+1 > prob.n)
            prob.n = N.v+1;
        prob.nnz += 1;
        ex += (mf_double)N.r;
        ex2 += (mf_double)N.r*N.r;
    }
    source.close();

    if(prob.nnz <= 0)
    {
        avg = 0;
        std_dev = 0;
        return;
    }

    ex /= (mf_double)prob.nnz;
    ex2 /= (mf_double)prob.nnz;
    avg = (mf_float)ex;
    std_dev = (mf_float)sqrt(max<mf_double>(ex2-ex*ex, 0));
}
|
554
|
+
|
555
|
+
// Multiplies every rating in prob.R[0..nnz) by scale, in place.
// A scale of exactly 1 skips the pass entirely.
void Utility::scale_problem(mf_problem &prob, mf_float scale)
{
    if(scale != 1.0)
    {
#if defined USEOMP
#pragma omp parallel for num_threads(nr_threads) schedule(static)
#endif
        for(mf_long idx = 0; idx < prob.nnz; ++idx)
        {
            mf_node &node = prob.R[idx];
            node.r *= scale;
        }
    }
}
|
566
|
+
|
567
|
+
// Rescales a model so that its predictions are multiplied by `scale`.
// The bias b is multiplied by scale, and every factor of P and Q by
// sqrt(scale), so each inner product p.q is multiplied by scale as well.
void Utility::scale_model(mf_model &model, mf_float scale)
{
    if(scale == 1.0)
        return;

    mf_int const dim = model.k;
    mf_float const factor = sqrt(scale);

    model.b *= scale;

    // Scales `rows` consecutive length-dim factor vectors starting at base.
    auto rescale_rows = [&] (mf_float *base, mf_int rows)
    {
#if defined USEOMP
#pragma omp parallel for num_threads(nr_threads) schedule(static)
#endif
        for(mf_int row = 0; row < rows; ++row)
        {
            mf_float *vec = base+(mf_long)row*dim;
            for(mf_int d = 0; d < dim; ++d)
                vec[d] *= factor;
        }
    };

    rescale_rows(model.P, model.m);
    rescale_rows(model.Q, model.n);
}
|
592
|
+
|
593
|
+
// Dot product of the length-k float vectors p and q.
// NOTE(review): the SSE/AVX paths use aligned loads and step 4/8 floats at a
// time. They presumably require p and q to be kALIGNByte-aligned and k to be
// a multiple of the vector width -- confirm against the model allocator.
//
// Bug fix (SSE path): the horizontal reduction added the shuffled partial sums
// to XMM instead of XMMtmp. Lane 0 then held a0+a1+a3, silently dropping
// a2. The reduction now starts from XMMtmp.
mf_float Utility::inner_product(mf_float *p, mf_float *q, mf_int k)
{
#if defined USESSE
    __m128 XMM = _mm_setzero_ps();
    for(mf_int d = 0; d < k; d += 4)
        XMM = _mm_add_ps(XMM, _mm_mul_ps(
                  _mm_load_ps(p+d), _mm_load_ps(q+d)));
    // XMM = [a0, a1, a2, a3]
    // XMMtmp = [a0+a2, a1+a3, 2*a2, 2*a3]
    // lane 0 of XMMtmp + shuffle(XMMtmp, 1) = (a0+a2)+(a1+a3)
    __m128 XMMtmp = _mm_add_ps(XMM, _mm_movehl_ps(XMM, XMM));
    XMM = _mm_add_ps(XMMtmp, _mm_shuffle_ps(XMMtmp, XMMtmp, 1));
    mf_float product;
    _mm_store_ss(&product, XMM);
    return product;
#elif defined USEAVX
    __m256 XMM = _mm256_setzero_ps();
    for(mf_int d = 0; d < k; d += 8)
        XMM = _mm256_add_ps(XMM, _mm256_mul_ps(
                  _mm256_load_ps(p+d), _mm256_load_ps(q+d)));
    // Fold the two 128-bit halves together, then two horizontal adds leave
    // the full sum in lane 0.
    XMM = _mm256_add_ps(XMM, _mm256_permute2f128_ps(XMM, XMM, 1));
    XMM = _mm256_hadd_ps(XMM, XMM);
    XMM = _mm256_hadd_ps(XMM, XMM);
    mf_float product;
    _mm_store_ss(&product, _mm256_castps256_ps128(XMM));
    return product;
#else
    return std::inner_product(p, p+k, q, (mf_float)0.0);
#endif
}
|
620
|
+
|
621
|
+
// Weighted L1 regularization term:
//   lambda_p * sum_u omega_p[u]*|P_u|_1 + lambda_q * sum_v omega_q[v]*|Q_v|_1,
// skipping rows/columns whose weight omega is not positive.
mf_double Utility::calc_reg1(mf_model &model,
                             mf_float lambda_p, mf_float lambda_q,
                             vector<mf_int> &omega_p, vector<mf_int> &omega_q)
{
    auto weighted_l1 = [&] (mf_float *factors, mf_int rows,
                            vector<mf_int> &omega)
    {
        mf_double total = 0;
        for(mf_int row = 0; row < rows; ++row)
        {
            mf_int const weight = omega[row];
            if(weight > 0)
            {
                mf_float const *vec = factors+(mf_long)row*model.k;
                mf_float norm1 = 0;
                for(mf_int d = 0; d < model.k; ++d)
                    norm1 += abs(vec[d]);
                total += weight*norm1;
            }
        }
        return total;
    };

    mf_double const reg_p = weighted_l1(model.P, model.m, omega_p);
    mf_double const reg_q = weighted_l1(model.Q, model.n, omega_q);
    return lambda_p*reg_p+lambda_q*reg_q;
}
|
645
|
+
|
646
|
+
// Weighted squared-L2 regularization term:
//   lambda_p * sum_u omega_p[u]*|P_u|^2 + lambda_q * sum_v omega_q[v]*|Q_v|^2,
// skipping rows/columns whose weight omega is not positive.
mf_double Utility::calc_reg2(mf_model &model,
                             mf_float lambda_p, mf_float lambda_q,
                             vector<mf_int> &omega_p, vector<mf_int> &omega_q)
{
    auto weighted_sq = [&] (mf_float *factors, mf_int rows,
                            vector<mf_int> &omega)
    {
        mf_double total = 0;
#if defined USEOMP
#pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:total)
#endif
        for(mf_int row = 0; row < rows; ++row)
        {
            if(omega[row] > 0)
            {
                mf_float *vec = factors+(mf_long)row*model.k;
                total += omega[row]*Utility::inner_product(vec, vec, model.k);
            }
        }
        return total;
    };

    mf_double const reg_p = weighted_sq(model.P, model.m, omega_p);
    mf_double const reg_q = weighted_sq(model.Q, model.n, omega_q);
    return lambda_p*reg_p + lambda_q*reg_q;
}
|
672
|
+
|
673
|
+
// Sums the evaluation error of `model` over the blocks listed in cv_block_ids.
// Each block is reload()-ed, iterated and free()-d.
//  * Regression/classification losses: per-node squared error, absolute
//    error, generalized KL, logistic loss, or the count of correctly-signed
//    predictions (see get_error_legend for the names).
//  * BPR losses: for every node, a random negative item (row variant) or
//    user (column variant) is drawn, and log(1+exp(neg-pos)) is accumulated.
// Throws invalid_argument for an unsupported `fun`.
//
// Fix: the BPR branches shared one random engine and one distribution across
// all OpenMP threads, which is a data race. Each block now gets its own
// engine, seeded serially from a master engine before the parallel loop.
// The squared error uses a double multiply instead of pow(x, 2); the result
// is the same double value.
mf_double Utility::calc_error(
    vector<BlockBase*> &blocks,
    vector<mf_int> &cv_block_ids,
    mf_model const &model)
{
    mf_double error = 0;
    if(fun == P_L2_MFR || fun == P_L1_MFR || fun == P_KL_MFR ||
       fun == P_LR_MFC || fun == P_L2_MFC || fun == P_L1_MFC)
    {
#if defined USEOMP
#pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:error)
#endif
        for(mf_int i = 0; i < (mf_long)cv_block_ids.size(); ++i)
        {
            BlockBase *block = blocks[cv_block_ids[i]];
            block->reload();
            while(block->move_next())
            {
                mf_node const &N = *(block->get_current());
                mf_float z = mf_predict(&model, N.u, N.v);
                switch(fun)
                {
                    case P_L2_MFR:
                    {
                        mf_double const diff = N.r-z;
                        error += diff*diff;
                        break;
                    }
                    case P_L1_MFR:
                        error += abs(N.r-z);
                        break;
                    case P_KL_MFR:
                        error += N.r*log(N.r/z)-N.r+z;
                        break;
                    case P_LR_MFC:
                        if(N.r > 0)
                            error += log(1.0+exp(-z));
                        else
                            error += log(1.0+exp(z));
                        break;
                    case P_L2_MFC:
                    case P_L1_MFC:
                        if(N.r > 0)
                            error += z > 0? 1: 0;
                        else
                            error += z < 0? 1: 0;
                        break;
                    default:
                        // Unreachable: the enclosing `if` admits only the
                        // cases above.
                        throw invalid_argument("unknown error function");
                        break;
                }
            }
            block->free();
        }
    }
    else
    {
        minstd_rand0 generator(rand());
        if(fun != P_ROW_BPR_MFOC && fun != P_COL_BPR_MFOC)
            throw invalid_argument("unknown error function");

        // Row variant samples a negative column w for user N.u; column
        // variant samples a negative row w for item N.v.
        bool const is_row = (fun == P_ROW_BPR_MFOC);
        mf_int const upper = is_row ? model.n-1 : model.m-1;

        // Draw one seed per block serially, so that no engine state is
        // shared inside the parallel loop.
        mf_int const nr_cv_blocks = (mf_int)cv_block_ids.size();
        vector<minstd_rand0::result_type> seeds(nr_cv_blocks);
        for(auto &seed : seeds)
            seed = generator();

#if defined USEOMP
#pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:error)
#endif
        for(mf_int i = 0; i < nr_cv_blocks; ++i)
        {
            minstd_rand0 block_generator(seeds[i]);
            uniform_int_distribution<mf_int> distribution(0, upper);
            BlockBase *block = blocks[cv_block_ids[i]];
            block->reload();
            while(block->move_next())
            {
                mf_node const &N = *(block->get_current());
                mf_int w = distribution(block_generator);
                mf_float const positive = mf_predict(&model, N.u, N.v);
                mf_float const negative = is_row
                    ? mf_predict(&model, N.u, w)
                    : mf_predict(&model, w, N.v);
                error += log(1+exp(negative-positive));
            }
            block->free();
        }
    }

    return error;
}
|
782
|
+
|
783
|
+
// Short name of the evaluation metric for the configured `fun`. Returns an
// empty string for unknown values.
string Utility::get_error_legend() const
{
    char const *legend = "";
    switch(fun)
    {
        case P_L2_MFR:       legend = "rmse";     break;
        case P_L1_MFR:       legend = "mae";      break;
        case P_KL_MFR:       legend = "gkl";      break;
        case P_LR_MFC:       legend = "logloss";  break;
        case P_L2_MFC:
        case P_L1_MFC:       legend = "accuracy"; break;
        case P_ROW_BPR_MFOC:
        case P_COL_BPR_MFOC: legend = "bprloss";  break;
        case P_L2_MFOC:      legend = "sqerror";  break;
        default:                                  break;
    }
    return string(legend);
}
|
814
|
+
|
815
|
+
// Renumbers every rating's row and column through p_map / q_map, in place.
// Indices not covered by a map (>= map size) are left unchanged.
void Utility::shuffle_problem(
    mf_problem &prob,
    vector<mf_int> &p_map,
    vector<mf_int> &q_map)
{
    mf_long const p_limit = (mf_long)p_map.size();
    mf_long const q_limit = (mf_long)q_map.size();

#if defined USEOMP
#pragma omp parallel for num_threads(nr_threads) schedule(static)
#endif
    for(mf_long idx = 0; idx < prob.nnz; ++idx)
    {
        mf_node &node = prob.R[idx];
        if(node.u < p_limit)
            node.u = p_map[node.u];
        if(node.v < q_limit)
            node.v = q_map[node.v];
    }
}
|
832
|
+
|
833
|
+
// Partitions prob.R in place into an nr_bins x nr_bins grid of blocks
// (row-major block ids) and counts ratings per row/column into omega_p/q.
// Returns nr_bins*nr_bins+1 boundary pointers: block b is
// [bounds[b], bounds[b+1]). The first blocks.size() entries of `blocks` are
// tied to those ranges.
vector<mf_node*> Utility::grid_problem(
    mf_problem &prob,
    mf_int nr_bins,
    vector<mf_int> &omega_p,
    vector<mf_int> &omega_q,
    vector<Block> &blocks)
{
    mf_int const nr_blocks = nr_bins*nr_bins;
    mf_int const seg_p = (mf_int)ceil((double)prob.m/nr_bins);
    mf_int const seg_q = (mf_int)ceil((double)prob.n/nr_bins);

    // Row-major id of the grid cell containing (u, v).
    auto block_of = [=] (mf_int u, mf_int v)
    {
        return (u/seg_p)*nr_bins+v/seg_q;
    };

    // Pass 1: block sizes plus per-row / per-column rating counts.
    vector<mf_long> sizes(nr_blocks, 0);
    for(mf_long idx = 0; idx < prob.nnz; ++idx)
    {
        mf_node const &node = prob.R[idx];
        sizes[block_of(node.u, node.v)] += 1;
        omega_p[node.u] += 1;
        omega_q[node.v] += 1;
    }

    // Prefix sums give each block's final range inside prob.R.
    vector<mf_node*> bounds(nr_blocks+1);
    bounds[0] = prob.R;
    for(mf_int b = 0; b < nr_blocks; ++b)
        bounds[b+1] = bounds[b]+sizes[b];

    // Pass 2: in-place bucket permutation. fill[b] is the next unfilled slot
    // of block b. A misplaced node is swapped into its owner's next free slot,
    // and the node swapped in is then re-examined at the same position.
    vector<mf_node*> fill(bounds.begin(), bounds.end()-1);
    for(mf_int b = 0; b < nr_blocks; ++b)
    {
        mf_node *cursor = fill[b];
        while(cursor != bounds[b+1])
        {
            mf_int const owner = block_of(cursor->u, cursor->v);
            if(owner == b)
            {
                ++cursor;
            }
            else
            {
                swap(*cursor, *fill[owner]);
                ++fill[owner];
            }
        }
    }

    // Pass 3: sort inside each block by (u, v) when there are more rows than
    // columns, otherwise by (v, u).
    bool const by_row = prob.m > prob.n;
#if defined USEOMP
#pragma omp parallel for num_threads(nr_threads) schedule(dynamic)
#endif
    for(mf_int b = 0; b < nr_blocks; ++b)
    {
        if(by_row)
            sort(bounds[b], bounds[b+1], sort_node_by_p());
        else
            sort(bounds[b], bounds[b+1], sort_node_by_q());
    }

    for(mf_int b = 0; b < (mf_long)blocks.size(); ++b)
        blocks[b].tie_to(bounds[b], bounds[b+1]);

    return bounds;
}
|
899
|
+
|
900
|
+
// Shuffle the ratings of the text file `data_path` into an on-disk grid of
// nr_bins x nr_bins blocks stored in the binary file `data_path + ".disk"`.
//
// Ratings are read as whitespace-separated "u v r" triples. Row and column
// ids are remapped through p_map/q_map, ratings are divided by `scale`, and
// each node is written to the slot range of its block. Inside a block, nodes
// are then sorted with sort_node_by_p() when m > n and with sort_node_by_q()
// otherwise. omega_p/omega_q are incremented once per rating of each
// (remapped) row/column (presumably the caller passes them zero-filled).
// Finally blocks[i] is tied to node slots [counts[i], counts[i+1]) of the
// buffer file.
//
// Throws ios::failure when the input file or the buffer file cannot be opened.
void Utility::grid_shuffle_scale_problem_on_disk(
    mf_int m, mf_int n, mf_int nr_bins,
    mf_float scale, string data_path,
    vector<mf_int> &p_map, vector<mf_int> &q_map,
    vector<mf_int> &omega_p, vector<mf_int> &omega_q,
    vector<BlockOnDisk> &blocks)
{
    string const buffer_path = data_path+string(".disk");
    mf_int seg_p = (mf_int)ceil((double)m/nr_bins);
    mf_int seg_q = (mf_int)ceil((double)n/nr_bins);
    // counts[b+1] first holds the size of block b. After the prefix sum below,
    // block b occupies node slots [counts[b], counts[b+1]) of the buffer file.
    vector<mf_long> counts(nr_bins*nr_bins+1, 0);
    // pivots[b] is the next free slot of block b while scattering nodes.
    vector<mf_long> pivots(nr_bins*nr_bins, 0);
    ifstream source(data_path);
    fstream buffer(buffer_path, fstream::in|fstream::out|
                   fstream::binary|fstream::trunc);
    auto get_block_id = [=] (mf_int u, mf_int v)
    {
        return (u/seg_p)*nr_bins+v/seg_q;
    };

    if(!source)
        throw ios::failure(string("cannot to open ")+data_path);
    if(!buffer)
        throw ios::failure(string("cannot to open ")+buffer_path);

    // Pass 1: count ratings per block and per remapped row/column.
    for(mf_node N; source >> N.u >> N.v >> N.r;)
    {
        N.u = p_map[N.u];
        N.v = q_map[N.v];
        mf_int bid = get_block_id(N.u, N.v);
        omega_p[N.u] += 1;
        omega_q[N.v] += 1;
        counts[bid+1] += 1;
    }

    for(mf_int i = 1; i < nr_bins*nr_bins+1; ++i)
    {
        counts[i] += counts[i-1];
        pivots[i-1] = counts[i-1];
    }

    // Pass 2: rewind the input and scatter each scaled node into the next
    // free slot of its block in the binary buffer file.
    source.clear();
    source.seekg(0);
    for(mf_node N; source >> N.u >> N.v >> N.r;)
    {
        N.u = p_map[N.u];
        N.v = q_map[N.v];
        N.r /= scale;
        mf_int bid = get_block_id(N.u, N.v);
        buffer.seekp(pivots[bid]*sizeof(mf_node));
        buffer.write(reinterpret_cast<char*>(&N), sizeof(mf_node));
        pivots[bid] += 1;
    }

    // Pass 3: sort each block in place. clear() resets the stream state
    // (e.g. eof left by a previous read) so that the following seek works.
    for(mf_int i = 0; i < nr_bins*nr_bins; ++i)
    {
        vector<mf_node> nodes(static_cast<size_t>(counts[i+1]-counts[i]));
        buffer.clear();
        buffer.seekg(counts[i]*sizeof(mf_node));
        buffer.read(reinterpret_cast<char*>(nodes.data()),
                    sizeof(mf_node)*nodes.size());

        if(m > n)
            sort(nodes.begin(), nodes.end(), sort_node_by_p());
        else
            sort(nodes.begin(), nodes.end(), sort_node_by_q());

        // Write the sorted block back over its own slots. No read may follow
        // directly: a file stream must not switch from output to input
        // without an intervening seek or flush. The next iteration seeks
        // before reading.
        buffer.clear();
        buffer.seekp(counts[i]*sizeof(mf_node));
        buffer.write(reinterpret_cast<char*>(nodes.data()),
                     sizeof(mf_node)*nodes.size());
    }

    for(mf_int i = 0; i < (mf_long)blocks.size(); ++i)
        blocks[i].tie_to(buffer_path, counts[i], counts[i+1]);
}
|
975
|
+
|
976
|
+
// Allocate an uninitialized array of `size` mf_float values whose address is
// aligned to kALIGNByte bytes. Release it with Utility::free_aligned_float.
//
// Throws bad_alloc when size is negative, when size*sizeof(mf_float) does not
// fit in size_t, or when the platform allocator fails.
mf_float* Utility::malloc_aligned_float(mf_long size)
{
    // Reject sizes whose byte count is not representable as size_t. The
    // largest valid element count is exactly max/sizeof(mf_float); one more
    // element would make the multiplication below wrap around.
    if(size < 0 ||
       static_cast<unsigned long long>(size) >
           numeric_limits<std::size_t>::max() / sizeof(mf_float))
        throw bad_alloc();
    size_t const bytes = static_cast<size_t>(size)*sizeof(mf_float);

    // [REVIEW] I hope one day we can use C11 aligned_alloc to replace the
    // platform-dependent functions below. Both Windows and OSX currently
    // don't support that function.
    void *ptr = nullptr;
#ifdef _WIN32
    ptr = _aligned_malloc(bytes, kALIGNByte);
#else
    int status = posix_memalign(&ptr, kALIGNByte, bytes);
    if(status != 0)
        throw bad_alloc();
#endif
    if(ptr == nullptr)
        throw bad_alloc();

    return (mf_float*)ptr;
}
|
998
|
+
|
999
|
+
// Release memory obtained from Utility::malloc_aligned_float. The matching
// deallocator differs per platform: _aligned_malloc memory must go back
// through _aligned_free, while posix_memalign memory is released with free().
void Utility::free_aligned_float(mf_float *ptr)
{
#if defined(_WIN32)
    // Visual Studio does not provide posix_memalign/free pairing here.
    _aligned_free(ptr);
#else
    free(ptr);
#endif
}
|
1009
|
+
|
1010
|
+
// Create a model for solver `fun` with m rows, n columns, bias `avg`, and
// latent dimension k rounded up to a multiple of kALIGN (model->k).
//
// Each row of P/Q holds model->k floats and the padding past k is zero.
// A row whose count in omega_p/omega_q is positive gets k values drawn
// uniformly from [0, 1) and scaled by sqrt(1/k). A row with a zero (or
// negative) count is filled with quiet NaN, except for the BPR solvers
// (P_ROW_BPR_MFOC / P_COL_BPR_MFOC), where it stays zero.
//
// If P or Q cannot be allocated, the error is printed, the partial model is
// destroyed, and bad_alloc is rethrown.
mf_model* Utility::init_model(mf_int fun,
                              mf_int m, mf_int n,
                              mf_int k, mf_float avg,
                              vector<mf_int> &omega_p,
                              vector<mf_int> &omega_q)
{
    mf_int k_real = k;
    // Round k up so that every row spans a whole number of kALIGN chunks.
    mf_int k_aligned = (mf_int)ceil(mf_double(k)/kALIGN)*kALIGN;

    mf_model *model = new mf_model;

    model->fun = fun;
    model->m = m;
    model->n = n;
    model->k = k_aligned;
    model->b = avg;
    model->P = nullptr;
    model->Q = nullptr;

    mf_float scale = (mf_float)sqrt(1.0/k_real);
    default_random_engine generator;
    uniform_real_distribution<mf_float> distribution(0.0, 1.0);

    try
    {
        model->P = Utility::malloc_aligned_float((mf_long)model->m*model->k);
        model->Q = Utility::malloc_aligned_float((mf_long)model->n*model->k);
    }
    catch(bad_alloc const &e)
    {
        cerr << e.what() << endl;
        mf_destroy_model(&model);
        throw;
    }

    // Initialize `size` rows starting at start_ptr. counts is only read, so
    // it is taken by const reference to avoid copying the omega vector.
    auto init1 = [&](mf_float *start_ptr, mf_long size,
                     vector<mf_int> const &counts)
    {
        memset(start_ptr, 0, static_cast<size_t>(
               sizeof(mf_float) * size*model->k));
        for(mf_long i = 0; i < size; ++i)
        {
            mf_float * ptr = start_ptr + i*model->k;
            if(counts[static_cast<size_t>(i)] > 0)
                for(mf_long d = 0; d < k_real; ++d, ++ptr)
                    *ptr = (mf_float)(distribution(generator)*scale);
            else
                if(fun != P_ROW_BPR_MFOC && fun != P_COL_BPR_MFOC) // unseen for bpr is 0
                    for(mf_long d = 0; d < k_real; ++d, ++ptr)
                        *ptr = numeric_limits<mf_float>::quiet_NaN();
        }
    };

    init1(model->P, m, omega_p);
    init1(model->Q, n, omega_q);

    return model;
}
|
1067
|
+
|
1068
|
+
// Initialize P=[\bar{p}_1, ..., \bar{p}_d] and Q=[\bar{q}_1, ..., \bar{q}_d].
|
1069
|
+
// Note that \bar{q}_{kv} is Q[k*n+v] and \bar{p}_{ku} is P[k*m+u]. One may
|
1070
|
+
// notice that P and Q here are actually the transposes of P and Q in fpsg(...)
|
1071
|
+
// because fpsg(...) uses P^TQ (where P and Q are respectively k-by-m and
|
1072
|
+
// k-by-n) to approximate the given rating matrix R while ccd_one_class(...)
|
1073
|
+
// uses PQ^T (where P and Q are respectively m-by-k and n-by-k.
|
1074
|
+
mf_model* Utility::init_model(mf_int m, mf_int n, mf_int k)
|
1075
|
+
{
|
1076
|
+
mf_model *model = new mf_model;
|
1077
|
+
|
1078
|
+
model->fun = P_L2_MFOC;
|
1079
|
+
model->m = m;
|
1080
|
+
model->n = n;
|
1081
|
+
model->k = k;
|
1082
|
+
model->b = 0.0; // One-class matrix factorization doesn't have bias.
|
1083
|
+
model->P = nullptr;
|
1084
|
+
model->Q = nullptr;
|
1085
|
+
|
1086
|
+
try
|
1087
|
+
{
|
1088
|
+
model->P = Utility::malloc_aligned_float((mf_long)model->m*model->k);
|
1089
|
+
model->Q = Utility::malloc_aligned_float((mf_long)model->n*model->k);
|
1090
|
+
}
|
1091
|
+
catch(bad_alloc const &e)
|
1092
|
+
{
|
1093
|
+
cerr << e.what() << endl;
|
1094
|
+
mf_destroy_model(&model);
|
1095
|
+
throw;
|
1096
|
+
}
|
1097
|
+
|
1098
|
+
// Our initialization strategy is that all P's elements are zero and do
|
1099
|
+
// random initization on Q. Thus, all initial predicted ratings are all zero
|
1100
|
+
// since the approximated rating matrix is PQ^T.
|
1101
|
+
|
1102
|
+
// Initialize P with zeros
|
1103
|
+
for(mf_long i = 0; i < k * m; ++i)
|
1104
|
+
model->P[i] = 0.0;
|
1105
|
+
|
1106
|
+
// Initialize Q with random numbers
|
1107
|
+
default_random_engine generator;
|
1108
|
+
uniform_real_distribution<mf_float> distribution(0.0, 1.0);
|
1109
|
+
for(mf_long i = 0; i < k * n; ++i)
|
1110
|
+
model->Q[i] = distribution(generator);
|
1111
|
+
|
1112
|
+
return model;
|
1113
|
+
}
|
1114
|
+
|
1115
|
+
// Return a pseudo-random permutation of 0, 1, ..., size-1.
//
// A fixed seed is used, so repeated calls with the same size (and the same
// standard library) return the same permutation. std::random_shuffle was
// deprecated in C++14 and removed in C++17, so std::shuffle with an
// explicitly seeded engine is used instead. The permutation differs from the
// one random_shuffle used to produce.
vector<mf_int> Utility::gen_random_map(mf_int size)
{
    // NOTE(review): kept only because callers may rely on this function
    // reseeding the C library RNG; remove if nothing depends on rand().
    srand(0);
    vector<mf_int> map(size);
    iota(map.begin(), map.end(), 0);
    mt19937 engine(0);
    shuffle(map.begin(), map.end(), engine);
    return map;
}
|
1124
|
+
|
1125
|
+
// Invert a permutation: the result satisfies inverse[map[i]] == i for every
// position i of `map`.
vector<mf_int> Utility::gen_inv_map(vector<mf_int> &map)
{
    vector<mf_int> inverse(map.size());
    mf_int position = 0;
    for(mf_int target : map)
        inverse[target] = position++;
    return inverse;
}
|
1132
|
+
|
1133
|
+
// Permute the rows of model.P and model.Q in place. Each row has model.k
// floats; row i of P moves to row p_map[i] and row i of Q moves to row
// q_map[i]. The permutation is applied by following its cycles and is
// consumed in the process: on return, p_map and q_map are the identity.
// NOTE(review): to undo the training-time shuffle the caller presumably
// passes inverse maps (see gen_inv_map); confirm at the call site.
void Utility::shuffle_model(
    mf_model &model,
    vector<mf_int> &p_map,
    vector<mf_int> &q_map)
{
    auto permute_rows = [] (mf_float *rows, vector<mf_int> &perm,
                            mf_int nr_rows, mf_int width)
    {
        mf_int i = 0;
        while(i < nr_rows)
        {
            mf_int const dest = perm[i];
            if(dest == i)
            {
                // The row at i is final; advance.
                ++i;
                continue;
            }

            // Send the row at i to its destination. The row displaced from
            // dest now sits at i and inherits dest's pending target.
            mf_float *row_i = rows + (mf_long)i*width;
            mf_float *row_dest = rows + (mf_long)dest*width;
            swap_ranges(row_i, row_i + width, row_dest);

            perm[i] = perm[dest];
            perm[dest] = dest;
        }
    };

    permute_rows(model.P, p_map, model.m, model.k);
    permute_rows(model.Q, q_map, model.n, model.k);
}
|
1162
|
+
|
1163
|
+
// Reduce the latent dimension of `model` from model.k to k_new in place.
//
// P and Q are stored row by row with model.k floats per row. The first k_new
// values of every row are packed to the new stride k_new. The buffers are not
// reallocated, so the memory past the packed rows keeps stale values.
// NOTE(review): assumes k_new <= model.k; growing the dimension is not
// supported by this routine.
void Utility::shrink_model(mf_model &model, mf_int k_new)
{
    mf_int k_old = model.k;
    model.k = k_new;

    auto shrink1 = [&] (mf_float *ptr, mf_int size)
    {
        for(mf_int i = 0; i < size; ++i)
        {
            mf_float *src = ptr+(mf_long)i*k_old;
            mf_float *dst = ptr+(mf_long)i*k_new;
            // Rows already in place (row 0, or every row when k_new == k_old)
            // are skipped: std::copy requires that the destination does not
            // start inside the source range. For later rows dst < src, so a
            // forward copy never overwrites values it still has to read.
            if(src == dst)
                continue;
            copy(src, src+k_new, dst);
        }
    };

    shrink1(model.P, model.m);
    shrink1(model.Q, model.n);
}
|
1181
|
+
|
1182
|
+
// Return a newly allocated copy of *prob.
//
// A null prob yields an empty problem (m = n = nnz = 0, R = nullptr). With
// copy_data set, the rating array is duplicated into a new mf_node[] array.
// Otherwise the copy shares prob->R. If duplicating R throws, the partially
// built problem is deleted and the exception is rethrown.
mf_problem* Utility::copy_problem(mf_problem const *prob, bool copy_data)
{
    mf_problem *result = new mf_problem;

    if(!prob)
    {
        result->m = 0;
        result->n = 0;
        result->nnz = 0;
        result->R = nullptr;
        return result;
    }

    result->m = prob->m;
    result->n = prob->n;
    result->nnz = prob->nnz;

    if(!copy_data)
    {
        // Shallow copy: share the caller's rating array.
        result->R = prob->R;
        return result;
    }

    try
    {
        result->R = new mf_node[static_cast<size_t>(prob->nnz)];
        copy(prob->R, prob->R + prob->nnz, result->R);
    }
    catch(...)
    {
        delete result;
        throw;
    }

    return result;
}
|
1220
|
+
|
1221
|
+
//--------------------------------------
|
1222
|
+
//-----The base class of all solvers----
|
1223
|
+
//--------------------------------------
|
1224
|
+
|
1225
|
+
class SolverBase
|
1226
|
+
{
|
1227
|
+
public:
|
1228
|
+
SolverBase(Scheduler &scheduler, vector<BlockBase*> &blocks,
|
1229
|
+
mf_float *PG, mf_float *QG, mf_model &model, mf_parameter param,
|
1230
|
+
bool &slow_only)
|
1231
|
+
: scheduler(scheduler), blocks(blocks), PG(PG), QG(QG),
|
1232
|
+
model(model), param(param), slow_only(slow_only) {}
|
1233
|
+
void run();
|
1234
|
+
SolverBase(const SolverBase&) = delete;
|
1235
|
+
SolverBase& operator=(const SolverBase&) = delete;
|
1236
|
+
// Solver is stateless functor, so default destructor should be
|
1237
|
+
// good enough.
|
1238
|
+
virtual ~SolverBase() = default;
|
1239
|
+
|
1240
|
+
protected:
|
1241
|
+
#if defined USESSE
|
1242
|
+
static void calc_z(__m128 &XMMz, mf_int k, mf_float *p, mf_float *q);
|
1243
|
+
virtual void load_fixed_variables(
|
1244
|
+
__m128 &XMMlambda_p1, __m128 &XMMlambda_q1,
|
1245
|
+
__m128 &XMMlambda_p2, __m128 &XMMlabmda_q2,
|
1246
|
+
__m128 &XMMeta, __m128 &XMMrk_slow,
|
1247
|
+
__m128 &XMMrk_fast);
|
1248
|
+
virtual void arrange_block(__m128d &XMMloss, __m128d &XMMerror);
|
1249
|
+
virtual void prepare_for_sg_update(
|
1250
|
+
__m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror) = 0;
|
1251
|
+
virtual void sg_update(mf_int d_begin, mf_int d_end, __m128 XMMz,
|
1252
|
+
__m128 XMMlambda_p1, __m128 XMMlambda_q1,
|
1253
|
+
__m128 XMMlambda_p2, __m128 XMMlamdba_q2,
|
1254
|
+
__m128 XMMeta, __m128 XMMrk) = 0;
|
1255
|
+
virtual void finalize(__m128d XMMloss, __m128d XMMerror);
|
1256
|
+
#elif defined USEAVX
|
1257
|
+
static void calc_z(__m256 &XMMz, mf_int k, mf_float *p, mf_float *q);
|
1258
|
+
virtual void load_fixed_variables(
|
1259
|
+
__m256 &XMMlambda_p1, __m256 &XMMlambda_q1,
|
1260
|
+
__m256 &XMMlambda_p2, __m256 &XMMlabmda_q2,
|
1261
|
+
__m256 &XMMeta, __m256 &XMMrk_slow,
|
1262
|
+
__m256 &XMMrk_fast);
|
1263
|
+
virtual void arrange_block(__m128d &XMMloss, __m128d &XMMerror);
|
1264
|
+
virtual void prepare_for_sg_update(
|
1265
|
+
__m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror) = 0;
|
1266
|
+
virtual void sg_update(mf_int d_begin, mf_int d_end, __m256 XMMz,
|
1267
|
+
__m256 XMMlambda_p1, __m256 XMMlambda_q1,
|
1268
|
+
__m256 XMMlambda_p2, __m256 XMMlamdba_q2,
|
1269
|
+
__m256 XMMeta, __m256 XMMrk) = 0;
|
1270
|
+
virtual void finalize(__m128d XMMloss, __m128d XMMerror);
|
1271
|
+
#else
|
1272
|
+
static void calc_z(mf_float &z, mf_int k, mf_float *p, mf_float *q);
|
1273
|
+
virtual void load_fixed_variables();
|
1274
|
+
virtual void arrange_block();
|
1275
|
+
virtual void prepare_for_sg_update() = 0;
|
1276
|
+
virtual void sg_update(mf_int d_begin, mf_int d_end, mf_float rk) = 0;
|
1277
|
+
virtual void finalize();
|
1278
|
+
static float qrsqrt(float x);
|
1279
|
+
#endif
|
1280
|
+
virtual void update() { ++pG; ++qG; };
|
1281
|
+
|
1282
|
+
Scheduler &scheduler;
|
1283
|
+
vector<BlockBase*> &blocks;
|
1284
|
+
BlockBase *block;
|
1285
|
+
mf_float *PG;
|
1286
|
+
mf_float *QG;
|
1287
|
+
mf_model &model;
|
1288
|
+
mf_parameter param;
|
1289
|
+
bool &slow_only;
|
1290
|
+
|
1291
|
+
mf_node *N;
|
1292
|
+
mf_float z;
|
1293
|
+
mf_double loss;
|
1294
|
+
mf_double error;
|
1295
|
+
mf_float *p;
|
1296
|
+
mf_float *q;
|
1297
|
+
mf_float *pG;
|
1298
|
+
mf_float *qG;
|
1299
|
+
mf_int bid;
|
1300
|
+
|
1301
|
+
mf_float lambda_p1;
|
1302
|
+
mf_float lambda_q1;
|
1303
|
+
mf_float lambda_p2;
|
1304
|
+
mf_float lambda_q2;
|
1305
|
+
mf_float rk_slow;
|
1306
|
+
mf_float rk_fast;
|
1307
|
+
};
|
1308
|
+
|
1309
|
+
#if defined USESSE
|
1310
|
+
inline void SolverBase::run()
|
1311
|
+
{
|
1312
|
+
__m128d XMMloss;
|
1313
|
+
__m128d XMMerror;
|
1314
|
+
__m128 XMMz;
|
1315
|
+
__m128 XMMlambda_p1;
|
1316
|
+
__m128 XMMlambda_q1;
|
1317
|
+
__m128 XMMlambda_p2;
|
1318
|
+
__m128 XMMlambda_q2;
|
1319
|
+
__m128 XMMeta;
|
1320
|
+
__m128 XMMrk_slow;
|
1321
|
+
__m128 XMMrk_fast;
|
1322
|
+
load_fixed_variables(XMMlambda_p1, XMMlambda_q1,
|
1323
|
+
XMMlambda_p2, XMMlambda_q2,
|
1324
|
+
XMMeta, XMMrk_slow,
|
1325
|
+
XMMrk_fast);
|
1326
|
+
while(!scheduler.is_terminated())
|
1327
|
+
{
|
1328
|
+
arrange_block(XMMloss, XMMerror);
|
1329
|
+
while(block->move_next())
|
1330
|
+
{
|
1331
|
+
N = block->get_current();
|
1332
|
+
p = model.P+(mf_long)N->u*model.k;
|
1333
|
+
q = model.Q+(mf_long)N->v*model.k;
|
1334
|
+
pG = PG+N->u*2;
|
1335
|
+
qG = QG+N->v*2;
|
1336
|
+
prepare_for_sg_update(XMMz, XMMloss, XMMerror);
|
1337
|
+
sg_update(0, kALIGN, XMMz, XMMlambda_p1, XMMlambda_q1,
|
1338
|
+
XMMlambda_p2, XMMlambda_q2, XMMeta, XMMrk_slow);
|
1339
|
+
if(slow_only)
|
1340
|
+
continue;
|
1341
|
+
update();
|
1342
|
+
sg_update(kALIGN, model.k, XMMz, XMMlambda_p1, XMMlambda_q1,
|
1343
|
+
XMMlambda_p2, XMMlambda_q2, XMMeta, XMMrk_slow);
|
1344
|
+
}
|
1345
|
+
finalize(XMMloss, XMMerror);
|
1346
|
+
}
|
1347
|
+
}
|
1348
|
+
|
1349
|
+
// Broadcast the loop-invariant parameters into SSE registers (SSE build).
void SolverBase::load_fixed_variables(
    __m128 &XMMlambda_p1, __m128 &XMMlambda_q1,
    __m128 &XMMlambda_p2, __m128 &XMMlambda_q2,
    __m128 &XMMeta, __m128 &XMMrk_slow,
    __m128 &XMMrk_fast)
{
    // Normalizers of the accumulated squared gradients: one over the number
    // of dims in the slow part [0, kALIGN) and in the fast part.
    mf_float const slow_norm = (mf_float)1.0/kALIGN;
    mf_float const fast_norm = (mf_float)1.0/(model.k-kALIGN);

    XMMeta = _mm_set1_ps(param.eta);

    XMMlambda_p1 = _mm_set1_ps(param.lambda_p1);
    XMMlambda_q1 = _mm_set1_ps(param.lambda_q1);
    XMMlambda_p2 = _mm_set1_ps(param.lambda_p2);
    XMMlambda_q2 = _mm_set1_ps(param.lambda_q2);

    XMMrk_slow = _mm_set1_ps(slow_norm);
    XMMrk_fast = _mm_set1_ps(fast_norm);
}
|
1363
|
+
|
1364
|
+
// Reset the per-block accumulators, then fetch the next block id from the
// scheduler and reload that block (SSE build).
void SolverBase::arrange_block(__m128d &XMMloss, __m128d &XMMerror)
{
    XMMerror = _mm_setzero_pd();
    XMMloss = _mm_setzero_pd();

    bid = scheduler.get_job();
    block = blocks[bid];
    block->reload();
}
|
1372
|
+
|
1373
|
+
// Dot product of p and q over the first k entries (SSE build). On return,
// every lane of XMMz holds the full sum.
// NOTE(review): _mm_load_ps needs p and q to be 16-byte aligned, and the loop
// reads 4 floats per step, so k is assumed to be a multiple of 4 — confirm.
inline void SolverBase::calc_z(
    __m128 &XMMz, mf_int k, mf_float *p, mf_float *q)
{
    XMMz = _mm_setzero_ps();
    // Four lane-wise partial sums {z0, z1, z2, z3}.
    for(mf_int d = 0; d < k; d += 4)
        XMMz = _mm_add_ps(XMMz, _mm_mul_ps(
               _mm_load_ps(p+d), _mm_load_ps(q+d)));
    // Bit-wise representation of 177 is {1,0}+{1,1}+{0,0}+{0,1} from
    // high-bit to low-bit, where "+" means concatenating two arrays.
    // This swaps adjacent lanes: XMMtmp = {z0+z1, z0+z1, z2+z3, z2+z3}.
    __m128 XMMtmp = _mm_add_ps(XMMz, _mm_shuffle_ps(XMMz, XMMz, 177));
    // Bit-wise representation of 78 is {0,1}+{0,0}+{1,1}+{1,0} from
    // high-bit to low-bit, where "+" means concatenating two arrays.
    // This swaps the two 64-bit halves, so every lane gets z0+z1+z2+z3.
    XMMz = _mm_add_ps(XMMtmp, _mm_shuffle_ps(XMMtmp, XMMtmp, 78));
}
|
1387
|
+
|
1388
|
+
// Finish the current block (SSE build): read the low double lane of both
// accumulators, call block->free(), then report the block's loss and error to
// the scheduler. free() runs before put_job(), in that order.
void SolverBase::finalize(__m128d XMMloss, __m128d XMMerror)
{
    loss = _mm_cvtsd_f64(XMMloss);
    error = _mm_cvtsd_f64(XMMerror);

    block->free();
    scheduler.put_job(bid, loss, error);
}
|
1395
|
+
#elif defined USEAVX
|
1396
|
+
inline void SolverBase::run()
|
1397
|
+
{
|
1398
|
+
__m128d XMMloss;
|
1399
|
+
__m128d XMMerror;
|
1400
|
+
__m256 XMMz;
|
1401
|
+
__m256 XMMlambda_p1;
|
1402
|
+
__m256 XMMlambda_q1;
|
1403
|
+
__m256 XMMlambda_p2;
|
1404
|
+
__m256 XMMlambda_q2;
|
1405
|
+
__m256 XMMeta;
|
1406
|
+
__m256 XMMrk_slow;
|
1407
|
+
__m256 XMMrk_fast;
|
1408
|
+
load_fixed_variables(XMMlambda_p1, XMMlambda_q1,
|
1409
|
+
XMMlambda_p2, XMMlambda_q2,
|
1410
|
+
XMMeta, XMMrk_slow, XMMrk_fast);
|
1411
|
+
while(!scheduler.is_terminated())
|
1412
|
+
{
|
1413
|
+
arrange_block(XMMloss, XMMerror);
|
1414
|
+
while(block->move_next())
|
1415
|
+
{
|
1416
|
+
N = block->get_current();
|
1417
|
+
p = model.P+(mf_long)N->u*model.k;
|
1418
|
+
q = model.Q+(mf_long)N->v*model.k;
|
1419
|
+
pG = PG+N->u*2;
|
1420
|
+
qG = QG+N->v*2;
|
1421
|
+
prepare_for_sg_update(XMMz, XMMloss, XMMerror);
|
1422
|
+
sg_update(0, kALIGN, XMMz, XMMlambda_p1, XMMlambda_q1,
|
1423
|
+
XMMlambda_p2, XMMlambda_q2, XMMeta, XMMrk_slow);
|
1424
|
+
if(slow_only)
|
1425
|
+
continue;
|
1426
|
+
update();
|
1427
|
+
sg_update(kALIGN, model.k, XMMz, XMMlambda_p1, XMMlambda_q1,
|
1428
|
+
XMMlambda_p2, XMMlambda_q2, XMMeta, XMMrk_fast);
|
1429
|
+
}
|
1430
|
+
finalize(XMMloss, XMMerror);
|
1431
|
+
}
|
1432
|
+
}
|
1433
|
+
|
1434
|
+
// Broadcast the loop-invariant parameters into AVX registers (AVX build).
void SolverBase::load_fixed_variables(
    __m256 &XMMlambda_p1, __m256 &XMMlambda_q1,
    __m256 &XMMlambda_p2, __m256 &XMMlambda_q2,
    __m256 &XMMeta, __m256 &XMMrk_slow,
    __m256 &XMMrk_fast)
{
    // Normalizers of the accumulated squared gradients for the slow dims
    // [0, kALIGN) and the fast dims [kALIGN, model.k).
    mf_float const slow_norm = (mf_float)1.0/kALIGN;
    mf_float const fast_norm = (mf_float)1.0/(model.k-kALIGN);

    XMMeta = _mm256_set1_ps(param.eta);

    XMMlambda_p1 = _mm256_set1_ps(param.lambda_p1);
    XMMlambda_q1 = _mm256_set1_ps(param.lambda_q1);
    XMMlambda_p2 = _mm256_set1_ps(param.lambda_p2);
    XMMlambda_q2 = _mm256_set1_ps(param.lambda_q2);

    XMMrk_slow = _mm256_set1_ps(slow_norm);
    XMMrk_fast = _mm256_set1_ps(fast_norm);
}
|
1448
|
+
|
1449
|
+
// Reset the per-block accumulators, then fetch the next block id from the
// scheduler and reload that block (AVX build).
void SolverBase::arrange_block(__m128d &XMMloss, __m128d &XMMerror)
{
    XMMerror = _mm_setzero_pd();
    XMMloss = _mm_setzero_pd();

    bid = scheduler.get_job();
    block = blocks[bid];
    block->reload();
}
|
1457
|
+
|
1458
|
+
// Dot product of p and q over the first k entries (AVX build). On return,
// every lane of XMMz holds the full sum.
// NOTE(review): _mm256_load_ps needs p and q to be 32-byte aligned, and the
// loop reads 8 floats per step, so k is assumed to be a multiple of 8 — confirm.
inline void SolverBase::calc_z(
    __m256 &XMMz, mf_int k, mf_float *p, mf_float *q)
{
    XMMz = _mm256_setzero_ps();
    // Eight lane-wise partial sums.
    for(mf_int d = 0; d < k; d += 8)
        XMMz = _mm256_add_ps(XMMz, _mm256_mul_ps(
               _mm256_load_ps(p+d), _mm256_load_ps(q+d)));
    // Add the swapped 128-bit halves: lane i now holds z_i + z_{i^4}.
    XMMz = _mm256_add_ps(XMMz, _mm256_permute2f128_ps(XMMz, XMMz, 0x1));
    // Two horizontal adds reduce each 128-bit half to its total, which after
    // the step above equals the full sum in every lane.
    XMMz = _mm256_hadd_ps(XMMz, XMMz);
    XMMz = _mm256_hadd_ps(XMMz, XMMz);
}
|
1469
|
+
|
1470
|
+
// Finish the current block (AVX build): read the low double lane of both
// accumulators, call block->free(), then report the block's loss and error to
// the scheduler. free() runs before put_job(), in that order.
void SolverBase::finalize(__m128d XMMloss, __m128d XMMerror)
{
    loss = _mm_cvtsd_f64(XMMloss);
    error = _mm_cvtsd_f64(XMMerror);

    block->free();
    scheduler.put_job(bid, loss, error);
}
|
1477
|
+
#else
|
1478
|
+
inline void SolverBase::run()
|
1479
|
+
{
|
1480
|
+
load_fixed_variables();
|
1481
|
+
while(!scheduler.is_terminated())
|
1482
|
+
{
|
1483
|
+
arrange_block();
|
1484
|
+
while(block->move_next())
|
1485
|
+
{
|
1486
|
+
N = block->get_current();
|
1487
|
+
p = model.P+(mf_long)N->u*model.k;
|
1488
|
+
q = model.Q+(mf_long)N->v*model.k;
|
1489
|
+
pG = PG+N->u*2;
|
1490
|
+
qG = QG+N->v*2;
|
1491
|
+
prepare_for_sg_update();
|
1492
|
+
sg_update(0, kALIGN, rk_slow);
|
1493
|
+
if(slow_only)
|
1494
|
+
continue;
|
1495
|
+
update();
|
1496
|
+
sg_update(kALIGN, model.k, rk_fast);
|
1497
|
+
}
|
1498
|
+
finalize();
|
1499
|
+
}
|
1500
|
+
}
|
1501
|
+
|
1502
|
+
// Fast approximation of 1/sqrt(x): an initial guess from the float's bit
// pattern (magic constant 0x5f375a86), refined by one Newton-Raphson step.
// memcpy is used for the type pun to avoid aliasing violations.
inline float SolverBase::qrsqrt(float x)
{
    const float half_x = 0.5f*x;

    uint32_t bits = 0;
    memcpy(&bits, &x, sizeof(bits));
    bits = 0x5f375a86 - (bits >> 1);

    float guess = 0.0f;
    memcpy(&guess, &bits, sizeof(guess));
    return guess*(1.5f - half_x*guess*guess);
}
|
1512
|
+
|
1513
|
+
void SolverBase::load_fixed_variables()
|
1514
|
+
{
|
1515
|
+
lambda_p1 = param.lambda_p1;
|
1516
|
+
lambda_q1 = param.lambda_q1;
|
1517
|
+
lambda_p2 = param.lambda_p2;
|
1518
|
+
lambda_q2 = param.lambda_q2;
|
1519
|
+
rk_slow = (mf_float)1.0/kALIGN;
|
1520
|
+
rk_fast = (mf_float)1.0/(model.k-kALIGN);
|
1521
|
+
}
|
1522
|
+
|
1523
|
+
void SolverBase::arrange_block()
|
1524
|
+
{
|
1525
|
+
loss = 0.0;
|
1526
|
+
error = 0.0;
|
1527
|
+
bid = scheduler.get_job();
|
1528
|
+
block = blocks[bid];
|
1529
|
+
block->reload();
|
1530
|
+
}
|
1531
|
+
|
1532
|
+
// Scalar build: z = sum over d in [0, k) of p[d]*q[d], accumulated in index
// order.
inline void SolverBase::calc_z(mf_float &z, mf_int k, mf_float *p, mf_float *q)
{
    mf_float sum = 0;
    for(mf_int d = 0; d < k; ++d)
        sum += p[d]*q[d];
    z = sum;
}
|
1538
|
+
|
1539
|
+
// Finish the current block (scalar build): call block->free(), then report
// the block id with the loss/error accumulated since arrange_block() back to
// the scheduler. The order of the two calls is kept as is.
void SolverBase::finalize()
{
    block->free();
    scheduler.put_job(bid, loss, error);
}
|
1544
|
+
#endif
|
1545
|
+
|
1546
|
+
//--------------------------------------
|
1547
|
+
//-----Real-valued MF and binary MF-----
|
1548
|
+
//--------------------------------------
|
1549
|
+
|
1550
|
+
class MFSolver: public SolverBase
|
1551
|
+
{
|
1552
|
+
public:
|
1553
|
+
MFSolver(Scheduler &scheduler, vector<BlockBase*> &blocks,
|
1554
|
+
mf_float *PG, mf_float *QG, mf_model &model,
|
1555
|
+
mf_parameter param, bool &slow_only)
|
1556
|
+
: SolverBase(scheduler, blocks, PG, QG, model, param, slow_only) {}
|
1557
|
+
|
1558
|
+
protected:
|
1559
|
+
#if defined USESSE
|
1560
|
+
void sg_update(mf_int d_begin, mf_int d_end, __m128 XMMz,
|
1561
|
+
__m128 XMMlambda_p1, __m128 XMMlambda_q1,
|
1562
|
+
__m128 XMMlambda_p2, __m128 XMMlambda_q2,
|
1563
|
+
__m128 XMMeta, __m128 XMMrk);
|
1564
|
+
#elif defined USEAVX
|
1565
|
+
void sg_update(mf_int d_begin, mf_int d_end, __m256 XMMz,
|
1566
|
+
__m256 XMMlambda_p1, __m256 XMMlambda_q1,
|
1567
|
+
__m256 XMMlambda_p2, __m256 XMMlambda_q2,
|
1568
|
+
__m256 XMMeta, __m256 XMMrk);
|
1569
|
+
#else
|
1570
|
+
void sg_update(mf_int d_begin, mf_int d_end, mf_float rk);
|
1571
|
+
#endif
|
1572
|
+
};
|
1573
|
+
|
1574
|
+
#if defined USESSE
|
1575
|
+
// One SG step on dims [d_begin, d_end) of the current rows p and q (SSE).
//
// XMMz holds, in every lane, the value set by prepare_for_sg_update(). The
// steps are:
//  - Form the gradients pg = lambda_p2*p - z*q and qg = lambda_q2*q - z*p.
//  - Move p and q against them with step sizes eta*rsqrt(*pG) and
//    eta*rsqrt(*qG). _mm_rsqrt_ps is an approximation.
//  - Accumulate the squared gradients.
//  - If lambda_p1/lambda_q1 > 0, shrink |p|/|q| toward zero (L1
//    soft-threshold). If param.do_nmf, clip negative entries to zero.
//  - Finally, *pG and *qG grow by the summed squared gradients times XMMrk.
void MFSolver::sg_update(mf_int d_begin, mf_int d_end, __m128 XMMz,
                         __m128 XMMlambda_p1, __m128 XMMlambda_q1,
                         __m128 XMMlambda_p2, __m128 XMMlambda_q2,
                         __m128 XMMeta, __m128 XMMrk)
{
    __m128 XMMpG = _mm_load1_ps(pG);
    __m128 XMMqG = _mm_load1_ps(qG);
    __m128 XMMeta_p = _mm_mul_ps(XMMeta, _mm_rsqrt_ps(XMMpG));
    __m128 XMMeta_q = _mm_mul_ps(XMMeta, _mm_rsqrt_ps(XMMqG));
    __m128 XMMpG1 = _mm_setzero_ps();
    __m128 XMMqG1 = _mm_setzero_ps();

    for(mf_int d = d_begin; d < d_end; d += 4)
    {
        __m128 XMMp = _mm_load_ps(p+d);
        __m128 XMMq = _mm_load_ps(q+d);

        __m128 XMMpg = _mm_sub_ps(_mm_mul_ps(XMMlambda_p2, XMMp),
                                  _mm_mul_ps(XMMz, XMMq));
        __m128 XMMqg = _mm_sub_ps(_mm_mul_ps(XMMlambda_q2, XMMq),
                                  _mm_mul_ps(XMMz, XMMp));

        XMMpG1 = _mm_add_ps(XMMpG1, _mm_mul_ps(XMMpg, XMMpg));
        XMMqG1 = _mm_add_ps(XMMqG1, _mm_mul_ps(XMMqg, XMMqg));

        XMMp = _mm_sub_ps(XMMp, _mm_mul_ps(XMMeta_p, XMMpg));
        XMMq = _mm_sub_ps(XMMq, _mm_mul_ps(XMMeta_q, XMMqg));

        _mm_store_ps(p+d, XMMp);
        _mm_store_ps(q+d, XMMq);
    }

    mf_float tmp = 0;
    _mm_store_ss(&tmp, XMMlambda_p1);
    if(tmp > 0)
    {
        for(mf_int d = d_begin; d < d_end; d += 4)
        {
            // flip is the sign bit where p <= 0: xor with it yields |p|, and
            // xor again restores the sign after shrinking.
            __m128 XMMp = _mm_load_ps(p+d);
            __m128 XMMflip = _mm_and_ps(_mm_cmple_ps(XMMp, _mm_set1_ps(0.0f)),
                             _mm_set1_ps(-0.0f));
            XMMp = _mm_xor_ps(XMMflip,
                   _mm_max_ps(_mm_sub_ps(_mm_xor_ps(XMMp, XMMflip),
                   _mm_mul_ps(XMMeta_p, XMMlambda_p1)), _mm_set1_ps(0.0f)));
            _mm_store_ps(p+d, XMMp);
        }
    }

    _mm_store_ss(&tmp, XMMlambda_q1);
    if(tmp > 0)
    {
        for(mf_int d = d_begin; d < d_end; d += 4)
        {
            __m128 XMMq = _mm_load_ps(q+d);
            __m128 XMMflip = _mm_and_ps(_mm_cmple_ps(XMMq, _mm_set1_ps(0.0f)),
                             _mm_set1_ps(-0.0f));
            XMMq = _mm_xor_ps(XMMflip,
                   _mm_max_ps(_mm_sub_ps(_mm_xor_ps(XMMq, XMMflip),
                   _mm_mul_ps(XMMeta_q, XMMlambda_q1)), _mm_set1_ps(0.0f)));
            _mm_store_ps(q+d, XMMq);
        }
    }

    if(param.do_nmf)
    {
        for(mf_int d = d_begin; d < d_end; d += 4)
        {
            __m128 XMMp = _mm_load_ps(p+d);
            __m128 XMMq = _mm_load_ps(q+d);
            XMMp = _mm_max_ps(XMMp, _mm_set1_ps(0.0f));
            XMMq = _mm_max_ps(XMMq, _mm_set1_ps(0.0f));
            _mm_store_ps(p+d, XMMp);
            _mm_store_ps(q+d, XMMq);
        }
    }

    // Horizontal sum of {a0, a1, a2, a3}:
    //   XMMtmp = {a0+a2, a1+a3, ...}
    //   lane 0 of XMMtmp + shuffle(XMMtmp, 1) = (a0+a2) + (a1+a3).
    // The second add must use XMMtmp, not XMMpG1, or a2 is dropped.
    __m128 XMMtmp = _mm_add_ps(XMMpG1, _mm_movehl_ps(XMMpG1, XMMpG1));
    XMMpG1 = _mm_add_ps(XMMtmp, _mm_shuffle_ps(XMMtmp, XMMtmp, 1));
    XMMpG = _mm_add_ps(XMMpG, _mm_mul_ps(XMMpG1, XMMrk));
    _mm_store_ss(pG, XMMpG);

    XMMtmp = _mm_add_ps(XMMqG1, _mm_movehl_ps(XMMqG1, XMMqG1));
    XMMqG1 = _mm_add_ps(XMMtmp, _mm_shuffle_ps(XMMtmp, XMMtmp, 1));
    XMMqG = _mm_add_ps(XMMqG, _mm_mul_ps(XMMqG1, XMMrk));
    _mm_store_ss(qG, XMMqG);
}
|
1661
|
+
#elif defined USEAVX
|
1662
|
+
// One SG step on dims [d_begin, d_end) of the current rows p and q (AVX).
//
// XMMz holds, in every lane, the value set by prepare_for_sg_update(). The
// steps are:
//  - Form the gradients pg = lambda_p2*p - z*q and qg = lambda_q2*q - z*p.
//  - Move p and q against them with step sizes eta*rsqrt(*pG) and
//    eta*rsqrt(*qG). _mm256_rsqrt_ps is an approximation.
//  - Accumulate the squared gradients.
//  - Apply L1 shrinkage when lambda_p1/lambda_q1 > 0, and clip negatives
//    when param.do_nmf.
//  - Finally, *pG and *qG grow by the summed squared gradients times XMMrk.
// NOTE(review): p/q are assumed 32-byte aligned and the range length a
// multiple of 8 (aligned loads/stores, step 8) — confirm.
void MFSolver::sg_update(mf_int d_begin, mf_int d_end, __m256 XMMz,
        __m256 XMMlambda_p1, __m256 XMMlambda_q1,
        __m256 XMMlambda_p2, __m256 XMMlambda_q2,
        __m256 XMMeta, __m256 XMMrk)
{
    __m256 XMMpG = _mm256_broadcast_ss(pG);
    __m256 XMMqG = _mm256_broadcast_ss(qG);
    __m256 XMMeta_p = _mm256_mul_ps(XMMeta, _mm256_rsqrt_ps(XMMpG));
    __m256 XMMeta_q = _mm256_mul_ps(XMMeta, _mm256_rsqrt_ps(XMMqG));
    __m256 XMMpG1 = _mm256_setzero_ps();
    __m256 XMMqG1 = _mm256_setzero_ps();

    // Gradient step; both gradients use p and q from before this step.
    for(mf_int d = d_begin; d < d_end; d += 8)
    {
        __m256 XMMp = _mm256_load_ps(p+d);
        __m256 XMMq = _mm256_load_ps(q+d);

        __m256 XMMpg = _mm256_sub_ps(_mm256_mul_ps(XMMlambda_p2, XMMp),
                                     _mm256_mul_ps(XMMz, XMMq));
        __m256 XMMqg = _mm256_sub_ps(_mm256_mul_ps(XMMlambda_q2, XMMq),
                                     _mm256_mul_ps(XMMz, XMMp));

        XMMpG1 = _mm256_add_ps(XMMpG1, _mm256_mul_ps(XMMpg, XMMpg));
        XMMqG1 = _mm256_add_ps(XMMqG1, _mm256_mul_ps(XMMqg, XMMqg));

        XMMp = _mm256_sub_ps(XMMp, _mm256_mul_ps(XMMeta_p, XMMpg));
        XMMq = _mm256_sub_ps(XMMq, _mm256_mul_ps(XMMeta_q, XMMqg));
        _mm256_store_ps(p+d, XMMp);
        _mm256_store_ps(q+d, XMMq);
    }

    // lambda_p1 is broadcast, so lane 0 tells whether L1 is enabled.
    mf_float tmp = 0;
    _mm_store_ss(&tmp, _mm256_castps256_ps128(XMMlambda_p1));
    if(tmp > 0)
    {
        for(mf_int d = d_begin; d < d_end; d += 8)
        {
            // flip is the sign bit where p <= 0: xor yields |p|, the shrunk
            // magnitude is clamped at zero, and xor restores the sign.
            __m256 XMMp = _mm256_load_ps(p+d);
            __m256 XMMflip = _mm256_and_ps(_mm256_cmp_ps(XMMp,
                             _mm256_set1_ps(0.0f), _CMP_LE_OS),
                             _mm256_set1_ps(-0.0f));
            XMMp = _mm256_xor_ps(XMMflip,
                   _mm256_max_ps(_mm256_sub_ps(
                   _mm256_xor_ps(XMMp, XMMflip),
                   _mm256_mul_ps(XMMeta_p, XMMlambda_p1)),
                   _mm256_set1_ps(0.0f)));
            _mm256_store_ps(p+d, XMMp);
        }
    }

    _mm_store_ss(&tmp, _mm256_castps256_ps128(XMMlambda_q1));
    if(tmp > 0)
    {
        for(mf_int d = d_begin; d < d_end; d += 8)
        {
            __m256 XMMq = _mm256_load_ps(q+d);
            __m256 XMMflip = _mm256_and_ps(_mm256_cmp_ps(XMMq,
                             _mm256_set1_ps(0.0f), _CMP_LE_OS),
                             _mm256_set1_ps(-0.0f));
            XMMq = _mm256_xor_ps(XMMflip,
                   _mm256_max_ps(_mm256_sub_ps(
                   _mm256_xor_ps(XMMq, XMMflip),
                   _mm256_mul_ps(XMMeta_q, XMMlambda_q1)),
                   _mm256_set1_ps(0.0f)));
            _mm256_store_ps(q+d, XMMq);
        }
    }

    // Non-negative MF: clip negative entries to zero.
    if(param.do_nmf)
    {
        for(mf_int d = d_begin; d < d_end; d += 8)
        {
            __m256 XMMp = _mm256_load_ps(p+d);
            __m256 XMMq = _mm256_load_ps(q+d);
            XMMp = _mm256_max_ps(XMMp, _mm256_set1_ps(0));
            XMMq = _mm256_max_ps(XMMq, _mm256_set1_ps(0));
            _mm256_store_ps(p+d, XMMp);
            _mm256_store_ps(q+d, XMMq);
        }
    }

    // Horizontal sums: add the swapped 128-bit halves, then two hadds leave
    // the total of all eight lanes in every lane.
    XMMpG1 = _mm256_add_ps(XMMpG1,
             _mm256_permute2f128_ps(XMMpG1, XMMpG1, 0x1));
    XMMpG1 = _mm256_hadd_ps(XMMpG1, XMMpG1);
    XMMpG1 = _mm256_hadd_ps(XMMpG1, XMMpG1);

    XMMqG1 = _mm256_add_ps(XMMqG1,
             _mm256_permute2f128_ps(XMMqG1, XMMqG1, 0x1));
    XMMqG1 = _mm256_hadd_ps(XMMqG1, XMMqG1);
    XMMqG1 = _mm256_hadd_ps(XMMqG1, XMMqG1);

    XMMpG = _mm256_add_ps(XMMpG, _mm256_mul_ps(XMMpG1, XMMrk));
    XMMqG = _mm256_add_ps(XMMqG, _mm256_mul_ps(XMMqG1, XMMrk));

    _mm_store_ss(pG, _mm256_castps256_ps128(XMMpG));
    _mm_store_ss(qG, _mm256_castps256_ps128(XMMqG));
}
|
1759
|
+
#else
|
1760
|
+
// One SG step on dims [d_begin, d_end) of the current rows p and q (scalar).
//
// The steps are:
//  - Gradients are gp = -z*q + lambda_p2*p and gq = -z*p + lambda_q2*q,
//    both computed from the values before this step.
//  - Step sizes are eta*qrsqrt(*pG) and eta*qrsqrt(*qG).
//  - Afterwards an optional L1 soft-threshold and an optional non-negativity
//    clip are applied.
//  - The accumulators grow by the summed squared gradients times rk.
void MFSolver::sg_update(mf_int d_begin, mf_int d_end, mf_float rk)
{
    mf_float const eta_p = param.eta*qrsqrt(*pG);
    mf_float const eta_q = param.eta*qrsqrt(*qG);

    mf_float sq_grad_p = 0;
    mf_float sq_grad_q = 0;

    for(mf_int d = d_begin; d < d_end; ++d)
    {
        mf_float const grad_p = -z*q[d]+lambda_p2*p[d];
        mf_float const grad_q = -z*p[d]+lambda_q2*q[d];

        sq_grad_p += grad_p*grad_p;
        sq_grad_q += grad_q*grad_q;

        p[d] -= eta_p*grad_p;
        q[d] -= eta_q*grad_q;
    }

    // L1 proximal step: shrink |v[d]| by lambda1*eta1 (not below zero) and
    // keep the original sign (entries >= 0 stay non-negative).
    auto soft_threshold = [d_begin, d_end] (mf_float *v, mf_float lambda1,
                                            mf_float eta1)
    {
        for(mf_int d = d_begin; d < d_end; ++d)
        {
            mf_float const magnitude = max(abs(v[d])-lambda1*eta1, 0.0f);
            v[d] = v[d] >= 0? magnitude: -magnitude;
        }
    };

    if(lambda_p1 > 0)
        soft_threshold(p, lambda_p1, eta_p);
    if(lambda_q1 > 0)
        soft_threshold(q, lambda_q1, eta_q);

    // Non-negative MF: clip negative entries to zero.
    if(param.do_nmf)
    {
        for(mf_int d = d_begin; d < d_end; ++d)
        {
            p[d] = max(p[d], (mf_float)0.0f);
            q[d] = max(q[d], (mf_float)0.0f);
        }
    }

    *pG += sq_grad_p*rk;
    *qG += sq_grad_q*rk;
}
|
1810
|
+
#endif
|
1811
|
+
|
1812
|
+
class L2_MFR : public MFSolver
|
1813
|
+
{
|
1814
|
+
public:
|
1815
|
+
L2_MFR(Scheduler &scheduler, vector<BlockBase*> &blocks, mf_float *PG, mf_float *QG,
|
1816
|
+
mf_model &model, mf_parameter param, bool &slow_only)
|
1817
|
+
: MFSolver(scheduler, blocks, PG, QG, model, param, slow_only) {}
|
1818
|
+
|
1819
|
+
protected:
|
1820
|
+
#if defined USESSE
|
1821
|
+
void prepare_for_sg_update(
|
1822
|
+
__m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
|
1823
|
+
#elif defined USEAVX
|
1824
|
+
void prepare_for_sg_update(
|
1825
|
+
__m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
|
1826
|
+
#else
|
1827
|
+
void prepare_for_sg_update();
|
1828
|
+
#endif
|
1829
|
+
};
|
1830
|
+
|
1831
|
+
#if defined USESSE
|
1832
|
+
void L2_MFR::prepare_for_sg_update(
|
1833
|
+
__m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
|
1834
|
+
{
|
1835
|
+
calc_z(XMMz, model.k, p, q);
|
1836
|
+
XMMz = _mm_sub_ps(_mm_set1_ps(N->r), XMMz);
|
1837
|
+
XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(
|
1838
|
+
_mm_mul_ps(XMMz, XMMz)));
|
1839
|
+
XMMerror = XMMloss;
|
1840
|
+
}
|
1841
|
+
#elif defined USEAVX
|
1842
|
+
void L2_MFR::prepare_for_sg_update(
|
1843
|
+
__m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
|
1844
|
+
{
|
1845
|
+
calc_z(XMMz, model.k, p, q);
|
1846
|
+
XMMz = _mm256_sub_ps(_mm256_set1_ps(N->r), XMMz);
|
1847
|
+
XMMloss = _mm_add_pd(XMMloss,
|
1848
|
+
_mm_cvtps_pd(_mm256_castps256_ps128(
|
1849
|
+
_mm256_mul_ps(XMMz, XMMz))));
|
1850
|
+
XMMerror = XMMloss;
|
1851
|
+
}
|
1852
|
+
#else
|
1853
|
+
void L2_MFR::prepare_for_sg_update()
|
1854
|
+
{
|
1855
|
+
calc_z(z, model.k, p, q);
|
1856
|
+
z = N->r-z;
|
1857
|
+
loss += z*z;
|
1858
|
+
error = loss;
|
1859
|
+
}
|
1860
|
+
#endif
|
1861
|
+
class L1_MFR : public MFSolver
|
1862
|
+
{
|
1863
|
+
public:
|
1864
|
+
L1_MFR(Scheduler &scheduler, vector<BlockBase*> &blocks, mf_float *PG, mf_float *QG,
|
1865
|
+
mf_model &model, mf_parameter param, bool &slow_only)
|
1866
|
+
: MFSolver(scheduler, blocks, PG, QG, model, param, slow_only) {}
|
1867
|
+
|
1868
|
+
protected:
|
1869
|
+
#if defined USESSE
|
1870
|
+
void prepare_for_sg_update(
|
1871
|
+
__m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
|
1872
|
+
#elif defined USEAVX
|
1873
|
+
void prepare_for_sg_update(
|
1874
|
+
__m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
|
1875
|
+
#else
|
1876
|
+
void prepare_for_sg_update();
|
1877
|
+
#endif
|
1878
|
+
};
|
1879
|
+
|
1880
|
+
#if defined USESSE
|
1881
|
+
void L1_MFR::prepare_for_sg_update(
|
1882
|
+
__m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
|
1883
|
+
{
|
1884
|
+
calc_z(XMMz, model.k, p, q);
|
1885
|
+
XMMz = _mm_sub_ps(_mm_set1_ps(N->r), XMMz);
|
1886
|
+
XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(
|
1887
|
+
_mm_andnot_ps(_mm_set1_ps(-0.0f), XMMz)));
|
1888
|
+
XMMerror = XMMloss;
|
1889
|
+
XMMz = _mm_add_ps(_mm_and_ps(_mm_cmpgt_ps(XMMz, _mm_set1_ps(0.0f)),
|
1890
|
+
_mm_set1_ps(1.0f)),
|
1891
|
+
_mm_and_ps(_mm_cmplt_ps(XMMz, _mm_set1_ps(0.0f)),
|
1892
|
+
_mm_set1_ps(-1.0f)));
|
1893
|
+
}
|
1894
|
+
#elif defined USEAVX
|
1895
|
+
void L1_MFR::prepare_for_sg_update(
|
1896
|
+
__m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
|
1897
|
+
{
|
1898
|
+
calc_z(XMMz, model.k, p, q);
|
1899
|
+
XMMz = _mm256_sub_ps(_mm256_set1_ps(N->r), XMMz);
|
1900
|
+
XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(_mm256_castps256_ps128(
|
1901
|
+
_mm256_andnot_ps(_mm256_set1_ps(-0.0f), XMMz))));
|
1902
|
+
XMMerror = XMMloss;
|
1903
|
+
XMMz = _mm256_add_ps(_mm256_and_ps(_mm256_cmp_ps(XMMz,
|
1904
|
+
_mm256_set1_ps(0.0f), _CMP_GT_OS), _mm256_set1_ps(1.0f)),
|
1905
|
+
_mm256_and_ps(_mm256_cmp_ps(XMMz,
|
1906
|
+
_mm256_set1_ps(0.0f), _CMP_LT_OS), _mm256_set1_ps(-1.0f)));
|
1907
|
+
}
|
1908
|
+
#else
|
1909
|
+
void L1_MFR::prepare_for_sg_update()
|
1910
|
+
{
|
1911
|
+
calc_z(z, model.k, p, q);
|
1912
|
+
z = N->r-z;
|
1913
|
+
loss += abs(z);
|
1914
|
+
error = loss;
|
1915
|
+
if(z > 0)
|
1916
|
+
z = 1;
|
1917
|
+
else if(z < 0)
|
1918
|
+
z = -1;
|
1919
|
+
}
|
1920
|
+
#endif
|
1921
|
+
|
1922
|
+
class KL_MFR : public MFSolver
|
1923
|
+
{
|
1924
|
+
public:
|
1925
|
+
KL_MFR(Scheduler &scheduler, vector<BlockBase*> &blocks, mf_float *PG, mf_float *QG,
|
1926
|
+
mf_model &model, mf_parameter param, bool &slow_only)
|
1927
|
+
: MFSolver(scheduler, blocks, PG, QG, model, param, slow_only) {}
|
1928
|
+
|
1929
|
+
protected:
|
1930
|
+
#if defined USESSE
|
1931
|
+
void prepare_for_sg_update(
|
1932
|
+
__m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
|
1933
|
+
#elif defined USEAVX
|
1934
|
+
void prepare_for_sg_update(
|
1935
|
+
__m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
|
1936
|
+
#else
|
1937
|
+
void prepare_for_sg_update();
|
1938
|
+
#endif
|
1939
|
+
};
|
1940
|
+
|
1941
|
+
#if defined USESSE
|
1942
|
+
void KL_MFR::prepare_for_sg_update(
|
1943
|
+
__m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
|
1944
|
+
{
|
1945
|
+
calc_z(XMMz, model.k, p, q);
|
1946
|
+
XMMz = _mm_div_ps(_mm_set1_ps(N->r), XMMz);
|
1947
|
+
_mm_store_ss(&z, XMMz);
|
1948
|
+
XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(
|
1949
|
+
_mm_set1_ps(N->r*(log(z)-1+1/z))));
|
1950
|
+
XMMerror = XMMloss;
|
1951
|
+
XMMz = _mm_sub_ps(XMMz, _mm_set1_ps(1.0f));
|
1952
|
+
}
|
1953
|
+
#elif defined USEAVX
|
1954
|
+
void KL_MFR::prepare_for_sg_update(
|
1955
|
+
__m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
|
1956
|
+
{
|
1957
|
+
calc_z(XMMz, model.k, p, q);
|
1958
|
+
XMMz = _mm256_div_ps(_mm256_set1_ps(N->r), XMMz);
|
1959
|
+
_mm_store_ss(&z, _mm256_castps256_ps128(XMMz));
|
1960
|
+
XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(
|
1961
|
+
_mm_set1_ps(N->r*(log(z)-1+1/z))));
|
1962
|
+
XMMerror = XMMloss;
|
1963
|
+
XMMz = _mm256_sub_ps(XMMz, _mm256_set1_ps(1.0f));
|
1964
|
+
}
|
1965
|
+
#else
|
1966
|
+
void KL_MFR::prepare_for_sg_update()
|
1967
|
+
{
|
1968
|
+
calc_z(z, model.k, p, q);
|
1969
|
+
z = N->r/z;
|
1970
|
+
loss += N->r*(log(z)-1+1/z);
|
1971
|
+
error = loss;
|
1972
|
+
z -= 1;
|
1973
|
+
}
|
1974
|
+
#endif
|
1975
|
+
|
1976
|
+
class LR_MFC : public MFSolver
|
1977
|
+
{
|
1978
|
+
public:
|
1979
|
+
LR_MFC(Scheduler &scheduler, vector<BlockBase*> &blocks,
|
1980
|
+
mf_float *PG, mf_float *QG, mf_model &model,
|
1981
|
+
mf_parameter param, bool &slow_only)
|
1982
|
+
: MFSolver(scheduler, blocks, PG, QG, model, param, slow_only) {}
|
1983
|
+
|
1984
|
+
protected:
|
1985
|
+
#if defined USESSE
|
1986
|
+
void prepare_for_sg_update(
|
1987
|
+
__m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
|
1988
|
+
#elif defined USEAVX
|
1989
|
+
void prepare_for_sg_update(
|
1990
|
+
__m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
|
1991
|
+
#else
|
1992
|
+
void prepare_for_sg_update();
|
1993
|
+
#endif
|
1994
|
+
};
|
1995
|
+
|
1996
|
+
#if defined USESSE
|
1997
|
+
void LR_MFC::prepare_for_sg_update(
|
1998
|
+
__m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
|
1999
|
+
{
|
2000
|
+
calc_z(XMMz, model.k, p, q);
|
2001
|
+
_mm_store_ss(&z, XMMz);
|
2002
|
+
if(N->r > 0)
|
2003
|
+
{
|
2004
|
+
z = exp(-z);
|
2005
|
+
XMMloss = _mm_add_pd(XMMloss, _mm_set1_pd(log(1+z)));
|
2006
|
+
XMMz = _mm_set1_ps(z/(1+z));
|
2007
|
+
}
|
2008
|
+
else
|
2009
|
+
{
|
2010
|
+
z = exp(z);
|
2011
|
+
XMMloss = _mm_add_pd(XMMloss, _mm_set1_pd(log(1+z)));
|
2012
|
+
XMMz = _mm_set1_ps(-z/(1+z));
|
2013
|
+
}
|
2014
|
+
XMMerror = XMMloss;
|
2015
|
+
}
|
2016
|
+
#elif defined USEAVX
|
2017
|
+
void LR_MFC::prepare_for_sg_update(
|
2018
|
+
__m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
|
2019
|
+
{
|
2020
|
+
calc_z(XMMz, model.k, p, q);
|
2021
|
+
_mm_store_ss(&z, _mm256_castps256_ps128(XMMz));
|
2022
|
+
if(N->r > 0)
|
2023
|
+
{
|
2024
|
+
z = exp(-z);
|
2025
|
+
XMMloss = _mm_add_pd(XMMloss, _mm_set1_pd(log(1.0+z)));
|
2026
|
+
XMMz = _mm256_set1_ps(z/(1+z));
|
2027
|
+
}
|
2028
|
+
else
|
2029
|
+
{
|
2030
|
+
z = exp(z);
|
2031
|
+
XMMloss = _mm_add_pd(XMMloss, _mm_set1_pd(log(1.0+z)));
|
2032
|
+
XMMz = _mm256_set1_ps(-z/(1+z));
|
2033
|
+
}
|
2034
|
+
XMMerror = XMMloss;
|
2035
|
+
}
|
2036
|
+
#else
|
2037
|
+
void LR_MFC::prepare_for_sg_update()
|
2038
|
+
{
|
2039
|
+
calc_z(z, model.k, p, q);
|
2040
|
+
if(N->r > 0)
|
2041
|
+
{
|
2042
|
+
z = exp(-z);
|
2043
|
+
loss += log(1+z);
|
2044
|
+
error = loss;
|
2045
|
+
z = z/(1+z);
|
2046
|
+
}
|
2047
|
+
else
|
2048
|
+
{
|
2049
|
+
z = exp(z);
|
2050
|
+
loss += log(1+z);
|
2051
|
+
error = loss;
|
2052
|
+
z = -z/(1+z);
|
2053
|
+
}
|
2054
|
+
}
|
2055
|
+
#endif
|
2056
|
+
|
2057
|
+
class L2_MFC : public MFSolver
|
2058
|
+
{
|
2059
|
+
public:
|
2060
|
+
L2_MFC(Scheduler &scheduler, vector<BlockBase*> &blocks,
|
2061
|
+
mf_float *PG, mf_float *QG, mf_model &model,
|
2062
|
+
mf_parameter param, bool &slow_only)
|
2063
|
+
: MFSolver(scheduler, blocks, PG, QG, model, param, slow_only) {}
|
2064
|
+
|
2065
|
+
protected:
|
2066
|
+
#if defined USESSE
|
2067
|
+
void prepare_for_sg_update(
|
2068
|
+
__m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
|
2069
|
+
#elif defined USEAVX
|
2070
|
+
void prepare_for_sg_update(
|
2071
|
+
__m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
|
2072
|
+
#else
|
2073
|
+
void prepare_for_sg_update();
|
2074
|
+
#endif
|
2075
|
+
};
|
2076
|
+
|
2077
|
+
#if defined USESSE
|
2078
|
+
void L2_MFC::prepare_for_sg_update(
|
2079
|
+
__m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
|
2080
|
+
{
|
2081
|
+
calc_z(XMMz, model.k, p, q);
|
2082
|
+
if(N->r > 0)
|
2083
|
+
{
|
2084
|
+
__m128 mask = _mm_cmpgt_ps(XMMz, _mm_set1_ps(0.0f));
|
2085
|
+
XMMerror = _mm_add_pd(XMMerror, _mm_cvtps_pd(
|
2086
|
+
_mm_and_ps(_mm_set1_ps(1.0f), mask)));
|
2087
|
+
XMMz = _mm_max_ps(_mm_set1_ps(0.0f), _mm_sub_ps(
|
2088
|
+
_mm_set1_ps(1.0f), XMMz));
|
2089
|
+
}
|
2090
|
+
else
|
2091
|
+
{
|
2092
|
+
__m128 mask = _mm_cmplt_ps(XMMz, _mm_set1_ps(0.0f));
|
2093
|
+
XMMerror = _mm_add_pd(XMMerror, _mm_cvtps_pd(
|
2094
|
+
_mm_and_ps(_mm_set1_ps(1.0f), mask)));
|
2095
|
+
XMMz = _mm_min_ps(_mm_set1_ps(0.0f), _mm_sub_ps(
|
2096
|
+
_mm_set1_ps(-1.0f), XMMz));
|
2097
|
+
}
|
2098
|
+
XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(
|
2099
|
+
_mm_mul_ps(XMMz, XMMz)));
|
2100
|
+
}
|
2101
|
+
#elif defined USEAVX
|
2102
|
+
void L2_MFC::prepare_for_sg_update(
|
2103
|
+
__m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
|
2104
|
+
{
|
2105
|
+
calc_z(XMMz, model.k, p, q);
|
2106
|
+
if(N->r > 0)
|
2107
|
+
{
|
2108
|
+
__m128 mask = _mm_cmpgt_ps(_mm256_castps256_ps128(XMMz),
|
2109
|
+
_mm_set1_ps(0.0f));
|
2110
|
+
XMMerror = _mm_add_pd(XMMerror, _mm_cvtps_pd(
|
2111
|
+
_mm_and_ps(_mm_set1_ps(1.0f), mask)));
|
2112
|
+
XMMz = _mm256_max_ps(_mm256_set1_ps(0.0f),
|
2113
|
+
_mm256_sub_ps(_mm256_set1_ps(1.0f), XMMz));
|
2114
|
+
}
|
2115
|
+
else
|
2116
|
+
{
|
2117
|
+
__m128 mask = _mm_cmplt_ps(_mm256_castps256_ps128(XMMz),
|
2118
|
+
_mm_set1_ps(0.0f));
|
2119
|
+
XMMerror = _mm_add_pd(XMMerror, _mm_cvtps_pd(
|
2120
|
+
_mm_and_ps(_mm_set1_ps(1.0f), mask)));
|
2121
|
+
XMMz = _mm256_min_ps(_mm256_set1_ps(0.0f),
|
2122
|
+
_mm256_sub_ps(_mm256_set1_ps(-1.0f), XMMz));
|
2123
|
+
}
|
2124
|
+
XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(
|
2125
|
+
_mm_mul_ps(_mm256_castps256_ps128(XMMz),
|
2126
|
+
_mm256_castps256_ps128(XMMz))));
|
2127
|
+
}
|
2128
|
+
#else
|
2129
|
+
void L2_MFC::prepare_for_sg_update()
|
2130
|
+
{
|
2131
|
+
calc_z(z, model.k, p, q);
|
2132
|
+
if(N->r > 0)
|
2133
|
+
{
|
2134
|
+
error += z > 0? 1: 0;
|
2135
|
+
z = max(0.0f, 1-z);
|
2136
|
+
}
|
2137
|
+
else
|
2138
|
+
{
|
2139
|
+
error += z < 0? 1: 0;
|
2140
|
+
z = min(0.0f, -1-z);
|
2141
|
+
}
|
2142
|
+
loss += z*z;
|
2143
|
+
}
|
2144
|
+
#endif
|
2145
|
+
|
2146
|
+
class L1_MFC : public MFSolver
|
2147
|
+
{
|
2148
|
+
public:
|
2149
|
+
L1_MFC(Scheduler &scheduler, vector<BlockBase*> &blocks, mf_float *PG, mf_float *QG,
|
2150
|
+
mf_model &model, mf_parameter param, bool &slow_only)
|
2151
|
+
: MFSolver(scheduler, blocks, PG, QG, model, param, slow_only) {}
|
2152
|
+
|
2153
|
+
protected:
|
2154
|
+
#if defined USESSE
|
2155
|
+
void prepare_for_sg_update(
|
2156
|
+
__m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
|
2157
|
+
#elif defined USEAVX
|
2158
|
+
void prepare_for_sg_update(
|
2159
|
+
__m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
|
2160
|
+
#else
|
2161
|
+
void prepare_for_sg_update();
|
2162
|
+
#endif
|
2163
|
+
};
|
2164
|
+
|
2165
|
+
#if defined USESSE
|
2166
|
+
void L1_MFC::prepare_for_sg_update(
|
2167
|
+
__m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
|
2168
|
+
{
|
2169
|
+
calc_z(XMMz, model.k, p, q);
|
2170
|
+
if(N->r > 0)
|
2171
|
+
{
|
2172
|
+
XMMerror = _mm_add_pd(XMMerror, _mm_cvtps_pd(
|
2173
|
+
_mm_and_ps(_mm_cmpge_ps(XMMz, _mm_set1_ps(0.0f)),
|
2174
|
+
_mm_set1_ps(1.0f))));
|
2175
|
+
XMMz = _mm_sub_ps(_mm_set1_ps(1.0f), XMMz);
|
2176
|
+
XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(
|
2177
|
+
_mm_max_ps(_mm_set1_ps(0.0f), XMMz)));
|
2178
|
+
XMMz = _mm_and_ps(_mm_cmpge_ps(XMMz, _mm_set1_ps(0.0f)),
|
2179
|
+
_mm_set1_ps(1.0f));
|
2180
|
+
}
|
2181
|
+
else
|
2182
|
+
{
|
2183
|
+
XMMerror = _mm_add_pd(XMMerror, _mm_cvtps_pd(
|
2184
|
+
_mm_and_ps(_mm_cmplt_ps(XMMz, _mm_set1_ps(0.0f)),
|
2185
|
+
_mm_set1_ps(1.0f))));
|
2186
|
+
XMMz = _mm_add_ps(_mm_set1_ps(1.0f), XMMz);
|
2187
|
+
XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(
|
2188
|
+
_mm_max_ps(_mm_set1_ps(0.0f), XMMz)));
|
2189
|
+
XMMz = _mm_and_ps(_mm_cmpge_ps(XMMz, _mm_set1_ps(0.0f)),
|
2190
|
+
_mm_set1_ps(-1.0f));
|
2191
|
+
}
|
2192
|
+
}
|
2193
|
+
#elif defined USEAVX
|
2194
|
+
void L1_MFC::prepare_for_sg_update(
|
2195
|
+
__m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
|
2196
|
+
{
|
2197
|
+
calc_z(XMMz, model.k, p, q);
|
2198
|
+
if(N->r > 0)
|
2199
|
+
{
|
2200
|
+
XMMerror = _mm_add_pd(XMMerror, _mm_cvtps_pd(_mm_and_ps(
|
2201
|
+
_mm_cmpge_ps(_mm256_castps256_ps128(XMMz),
|
2202
|
+
_mm_set1_ps(0.0f)), _mm_set1_ps(1.0f))));
|
2203
|
+
XMMz = _mm256_sub_ps(_mm256_set1_ps(1.0f), XMMz);
|
2204
|
+
XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(_mm_max_ps(
|
2205
|
+
_mm_set1_ps(0.0f), _mm256_castps256_ps128(XMMz))));
|
2206
|
+
XMMz = _mm256_and_ps(_mm256_cmp_ps(XMMz, _mm256_set1_ps(0.0f),
|
2207
|
+
_CMP_GE_OS), _mm256_set1_ps(1.0f));
|
2208
|
+
}
|
2209
|
+
else
|
2210
|
+
{
|
2211
|
+
XMMerror = _mm_add_pd(XMMerror, _mm_cvtps_pd(_mm_and_ps(
|
2212
|
+
_mm_cmplt_ps(_mm256_castps256_ps128(XMMz),
|
2213
|
+
_mm_set1_ps(0.0f)), _mm_set1_ps(1.0f))));
|
2214
|
+
XMMz = _mm256_add_ps(_mm256_set1_ps(1.0f), XMMz);
|
2215
|
+
XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(_mm_max_ps(
|
2216
|
+
_mm_set1_ps(0.0f), _mm256_castps256_ps128(XMMz))));
|
2217
|
+
XMMz = _mm256_and_ps(_mm256_cmp_ps(XMMz, _mm256_set1_ps(0.0f),
|
2218
|
+
_CMP_GE_OS), _mm256_set1_ps(-1.0f));
|
2219
|
+
}
|
2220
|
+
}
|
2221
|
+
#else
|
2222
|
+
void L1_MFC::prepare_for_sg_update()
|
2223
|
+
{
|
2224
|
+
calc_z(z, model.k, p, q);
|
2225
|
+
if(N->r > 0)
|
2226
|
+
{
|
2227
|
+
loss += max(0.0f, 1-z);
|
2228
|
+
error += z > 0? 1.0f: 0.0f;
|
2229
|
+
z = z > 1? 0.0f: 1.0f;
|
2230
|
+
}
|
2231
|
+
else
|
2232
|
+
{
|
2233
|
+
loss += max(0.0f, 1+z);
|
2234
|
+
error += z < 0? 1.0f: 0.0f;
|
2235
|
+
z = z < -1? 0.0f: -1.0f;
|
2236
|
+
}
|
2237
|
+
}
|
2238
|
+
#endif
|
2239
|
+
//--------------------------------------
|
2240
|
+
//------------One-class MF--------------
|
2241
|
+
//--------------------------------------
|
2242
|
+
|
2243
|
+
class BPRSolver : public SolverBase
|
2244
|
+
{
|
2245
|
+
public:
|
2246
|
+
BPRSolver(Scheduler &scheduler, vector<BlockBase*> &blocks,
|
2247
|
+
mf_float *PG, mf_float *QG, mf_model &model, mf_parameter param,
|
2248
|
+
bool &slow_only, bool is_column_oriented)
|
2249
|
+
: SolverBase(scheduler, blocks, PG, QG, model, param, slow_only),
|
2250
|
+
is_column_oriented(is_column_oriented) {}
|
2251
|
+
|
2252
|
+
protected:
|
2253
|
+
#if defined USESSE
|
2254
|
+
static void calc_z(__m128 &XMMz, mf_int k,
|
2255
|
+
mf_float *p, mf_float *q, mf_float *w);
|
2256
|
+
void arrange_block(__m128d &XMMloss, __m128d &XMMerror);
|
2257
|
+
void prepare_for_sg_update(
|
2258
|
+
__m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
|
2259
|
+
void sg_update(mf_int d_begin, mf_int d_end, __m128 XMMz,
|
2260
|
+
__m128 XMMlambda_p1, __m128 XMMlambda_q1,
|
2261
|
+
__m128 XMMlambda_p2, __m128 XMMlamdba_q2,
|
2262
|
+
__m128 XMMeta, __m128 XMMrk);
|
2263
|
+
void finalize(__m128d XMMloss, __m128d XMMerror);
|
2264
|
+
#elif defined USEAVX
|
2265
|
+
static void calc_z(__m256 &XMMz, mf_int k,
|
2266
|
+
mf_float *p, mf_float *q, mf_float *w);
|
2267
|
+
void arrange_block(__m128d &XMMloss, __m128d &XMMerror);
|
2268
|
+
void prepare_for_sg_update(
|
2269
|
+
__m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
|
2270
|
+
void sg_update(mf_int d_begin, mf_int d_end, __m256 XMMz,
|
2271
|
+
__m256 XMMlambda_p1, __m256 XMMlambda_q1,
|
2272
|
+
__m256 XMMlambda_p2, __m256 XMMlamdba_q2,
|
2273
|
+
__m256 XMMeta, __m256 XMMrk);
|
2274
|
+
void finalize(__m128d XMMloss, __m128d XMMerror);
|
2275
|
+
#else
|
2276
|
+
static void calc_z(mf_float &z, mf_int k,
|
2277
|
+
mf_float *p, mf_float *q, mf_float *w);
|
2278
|
+
void arrange_block();
|
2279
|
+
void prepare_for_sg_update();
|
2280
|
+
void sg_update(mf_int d_begin, mf_int d_end, mf_float rk);
|
2281
|
+
void finalize();
|
2282
|
+
#endif
|
2283
|
+
void update() { ++pG; ++qG; ++wG; };
|
2284
|
+
virtual void prepare_negative() = 0;
|
2285
|
+
|
2286
|
+
bool is_column_oriented;
|
2287
|
+
mf_int bpr_bid;
|
2288
|
+
mf_float *w;
|
2289
|
+
mf_float *wG;
|
2290
|
+
};
|
2291
|
+
|
2292
|
+
|
2293
|
+
#if defined USESSE
|
2294
|
+
// SSE score p . (q - w), broadcast into every lane of XMMz.
// NOTE(review): assumes k is a multiple of 4 and p/q/w are 16-byte aligned
// (_mm_load_ps requires alignment) -- confirm at the allocation site.
inline void BPRSolver::calc_z(
    __m128 &XMMz, mf_int k, mf_float *p, mf_float *q, mf_float *w)
{
    XMMz = _mm_setzero_ps();
    for(mf_int d = 0; d < k; d += 4)
        XMMz = _mm_add_ps(XMMz, _mm_mul_ps(_mm_load_ps(p+d),
               _mm_sub_ps(_mm_load_ps(q+d), _mm_load_ps(w+d))));
    // Bit-wise representation of 177 is {1,0}+{1,1}+{0,0}+{0,1} from
    // high-bit to low-bit, where "+" means concatenating two arrays.
    // It swaps adjacent lanes: XMMtmp = [z0+z1, z0+z1, z2+z3, z2+z3].
    __m128 XMMtmp = _mm_add_ps(XMMz, _mm_shuffle_ps(XMMz, XMMz, 177));
    // Bit-wise representation of 78 is {0,1}+{0,0}+{1,1}+{1,0} from
    // high-bit to low-bit, where "+" means concatenating two arrays.
    // It swaps the two halves. The halves must be added to XMMtmp, not to
    // XMMz: XMMz + swap(XMMtmp) gave z0+z2+z3 in lane 0, dropping z1.
    XMMz = _mm_add_ps(XMMtmp, _mm_shuffle_ps(XMMtmp, XMMtmp, 78));
}
|
2308
|
+
|
2309
|
+
// SSE variant: zero the SIMD accumulators, then take the next block from
// the scheduler and its companion BPR block id.
void BPRSolver::arrange_block(__m128d &XMMloss, __m128d &XMMerror)
{
    XMMerror = XMMloss = _mm_setzero_pd();
    bid = scheduler.get_job();
    block = blocks[bid];
    // reload() presumably rewinds the block's node iteration -- see BlockBase.
    block->reload();
    bpr_bid = scheduler.get_bpr_job(bid, is_column_oriented);
}
|
2318
|
+
|
2319
|
+
// SSE variant: spill the low lanes of the SIMD accumulators into the
// scalar loss/error members, then hand both job ids back to the scheduler.
void BPRSolver::finalize(__m128d XMMloss, __m128d XMMerror)
{
    _mm_store_sd(&this->loss, XMMloss);
    _mm_store_sd(&this->error, XMMerror);
    this->scheduler.put_job(this->bid, this->loss, this->error);
    this->scheduler.put_bpr_job(this->bid, this->bpr_bid);
}
|
2326
|
+
|
2327
|
+
// SSE SGD step on p, q and w over latent dimensions [d_begin, d_end), four
// floats at a time.
//
// XMMz is the broadcast gradient scale from prepare_for_sg_update(). The
// step sizes are eta * rsqrt(accumulated squared gradient). After the plain
// step the function optionally applies:
//   - L1 soft-thresholding (lambda_q1 also applies to w);
//   - a non-negativity clamp (param.do_nmf).
// Finally rk times each vector's summed squared gradient is added into
// *pG, *qG and *wG.
//
// NOTE(review): assumes d_begin/d_end are multiples of 4 and the vectors are
// 16-byte aligned, as _mm_load_ps/_mm_store_ps require.
void BPRSolver::sg_update(mf_int d_begin, mf_int d_end, __m128 XMMz,
                          __m128 XMMlambda_p1, __m128 XMMlambda_q1,
                          __m128 XMMlambda_p2, __m128 XMMlambda_q2,
                          __m128 XMMeta, __m128 XMMrk)
{
    const __m128 XMMzero = _mm_set1_ps(0.0f);
    const __m128 XMMsign = _mm_set1_ps(-0.0f);

    // Horizontal sum of v; lane 0 of the result is v0+v1+v2+v3. The final
    // add must combine the partial sums t with themselves. The previous code
    // added shuffle(t, t, 1) back to v, which yields v0+v1+v3 and silently
    // dropped v2 from every squared-gradient accumulator.
    auto hsum = [](__m128 v) -> __m128
    {
        __m128 t = _mm_add_ps(v, _mm_movehl_ps(v, v));   // [v0+v2, v1+v3, ...]
        return _mm_add_ps(t, _mm_shuffle_ps(t, t, 1));   // lane 0: all four
    };

    // L1 proximal step per lane: shrink |v| by thresh (clamped at 0), keep
    // v's sign.
    auto shrink = [=](__m128 v, __m128 thresh) -> __m128
    {
        __m128 flip = _mm_and_ps(_mm_cmple_ps(v, XMMzero), XMMsign);
        return _mm_xor_ps(flip, _mm_max_ps(
            _mm_sub_ps(_mm_xor_ps(v, flip), thresh), XMMzero));
    };

    __m128 XMMpG = _mm_load1_ps(pG);
    __m128 XMMqG = _mm_load1_ps(qG);
    __m128 XMMwG = _mm_load1_ps(wG);
    __m128 XMMeta_p = _mm_mul_ps(XMMeta, _mm_rsqrt_ps(XMMpG));
    __m128 XMMeta_q = _mm_mul_ps(XMMeta, _mm_rsqrt_ps(XMMqG));
    __m128 XMMeta_w = _mm_mul_ps(XMMeta, _mm_rsqrt_ps(XMMwG));

    __m128 XMMpG1 = _mm_setzero_ps();
    __m128 XMMqG1 = _mm_setzero_ps();
    __m128 XMMwG1 = _mm_setzero_ps();

    for(mf_int d = d_begin; d < d_end; d += 4)
    {
        __m128 XMMp = _mm_load_ps(p+d);
        __m128 XMMq = _mm_load_ps(q+d);
        __m128 XMMw = _mm_load_ps(w+d);

        __m128 XMMpg = _mm_add_ps(_mm_mul_ps(XMMlambda_p2, XMMp),
                       _mm_mul_ps(XMMz, _mm_sub_ps(XMMw, XMMq)));
        __m128 XMMqg = _mm_sub_ps(_mm_mul_ps(XMMlambda_q2, XMMq),
                       _mm_mul_ps(XMMz, XMMp));
        __m128 XMMwg = _mm_add_ps(_mm_mul_ps(XMMlambda_q2, XMMw),
                       _mm_mul_ps(XMMz, XMMp));

        XMMpG1 = _mm_add_ps(XMMpG1, _mm_mul_ps(XMMpg, XMMpg));
        XMMqG1 = _mm_add_ps(XMMqG1, _mm_mul_ps(XMMqg, XMMqg));
        XMMwG1 = _mm_add_ps(XMMwG1, _mm_mul_ps(XMMwg, XMMwg));

        _mm_store_ps(p+d, _mm_sub_ps(XMMp, _mm_mul_ps(XMMeta_p, XMMpg)));
        _mm_store_ps(q+d, _mm_sub_ps(XMMq, _mm_mul_ps(XMMeta_q, XMMqg)));
        _mm_store_ps(w+d, _mm_sub_ps(XMMw, _mm_mul_ps(XMMeta_w, XMMwg)));
    }

    // The lambda arguments are broadcasts; lane 0 decides whether L1 is on.
    mf_float l1_weight = 0;
    _mm_store_ss(&l1_weight, XMMlambda_p1);
    if(l1_weight > 0)
    {
        const __m128 XMMthresh_p = _mm_mul_ps(XMMeta_p, XMMlambda_p1);
        for(mf_int d = d_begin; d < d_end; d += 4)
            _mm_store_ps(p+d, shrink(_mm_load_ps(p+d), XMMthresh_p));
    }

    _mm_store_ss(&l1_weight, XMMlambda_q1);
    if(l1_weight > 0)
    {
        const __m128 XMMthresh_q = _mm_mul_ps(XMMeta_q, XMMlambda_q1);
        const __m128 XMMthresh_w = _mm_mul_ps(XMMeta_w, XMMlambda_q1);
        for(mf_int d = d_begin; d < d_end; d += 4)
        {
            _mm_store_ps(q+d, shrink(_mm_load_ps(q+d), XMMthresh_q));
            _mm_store_ps(w+d, shrink(_mm_load_ps(w+d), XMMthresh_w));
        }
    }

    if(param.do_nmf)
    {
        for(mf_int d = d_begin; d < d_end; d += 4)
        {
            _mm_store_ps(p+d, _mm_max_ps(_mm_load_ps(p+d), XMMzero));
            _mm_store_ps(q+d, _mm_max_ps(_mm_load_ps(q+d), XMMzero));
            _mm_store_ps(w+d, _mm_max_ps(_mm_load_ps(w+d), XMMzero));
        }
    }

    // Fold rk times the summed squared gradients into the scalar
    // accumulators (only lane 0 is stored).
    XMMpG = _mm_add_ps(XMMpG, _mm_mul_ps(hsum(XMMpG1), XMMrk));
    _mm_store_ss(pG, XMMpG);

    XMMqG = _mm_add_ps(XMMqG, _mm_mul_ps(hsum(XMMqG1), XMMrk));
    _mm_store_ss(qG, XMMqG);

    XMMwG = _mm_add_ps(XMMwG, _mm_mul_ps(hsum(XMMwG1), XMMrk));
    _mm_store_ss(wG, XMMwG);
}
|
2445
|
+
|
2446
|
+
// SSE variant.
//
// Calls prepare_negative() to select w, scores p . (q - w), adds the loss
// log(1+exp(-score)) and mirrors it into the error accumulator. XMMz is then
// set to the broadcast gradient scale exp(-score)/(1+exp(-score)).
void BPRSolver::prepare_for_sg_update(
    __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
{
    prepare_negative();
    calc_z(XMMz, model.k, p, q, w);
    _mm_store_ss(&z, XMMz);
    // From here on z holds exp(-score).
    z = exp(-z);
    XMMz = _mm_set1_ps(z/(1+z));
    XMMloss = _mm_add_pd(XMMloss, _mm_set1_pd(log(1+z)));
    XMMerror = XMMloss;
}
|
2457
|
+
#elif defined USEAVX
|
2458
|
+
// AVX score p . (q - w); after the reduction every lane holds the total.
// NOTE(review): assumes k is a multiple of 8 and 32-byte aligned vectors.
inline void BPRSolver::calc_z(
    __m256 &XMMz, mf_int k, mf_float *p, mf_float *q, mf_float *w)
{
    __m256 acc = _mm256_setzero_ps();
    for(mf_int d = 0; d < k; d += 8)
    {
        __m256 diff = _mm256_sub_ps(_mm256_load_ps(q+d), _mm256_load_ps(w+d));
        acc = _mm256_add_ps(acc, _mm256_mul_ps(_mm256_load_ps(p+d), diff));
    }
    // Fold the two 128-bit halves, then two horizontal adds.
    acc = _mm256_add_ps(acc, _mm256_permute2f128_ps(acc, acc, 0x1));
    acc = _mm256_hadd_ps(acc, acc);
    XMMz = _mm256_hadd_ps(acc, acc);
}
|
2470
|
+
|
2471
|
+
// AVX variant: zero the SIMD accumulators, then take the next block from
// the scheduler and its companion BPR block id.
void BPRSolver::arrange_block(__m128d &XMMloss, __m128d &XMMerror)
{
    XMMerror = XMMloss = _mm_setzero_pd();
    bid = scheduler.get_job();
    block = blocks[bid];
    // reload() presumably rewinds the block's node iteration -- see BlockBase.
    block->reload();
    bpr_bid = scheduler.get_bpr_job(bid, is_column_oriented);
}
|
2480
|
+
|
2481
|
+
// AVX variant: spill the low lanes of the SIMD accumulators into the
// scalar loss/error members, then hand both job ids back to the scheduler.
void BPRSolver::finalize(__m128d XMMloss, __m128d XMMerror)
{
    _mm_store_sd(&this->loss, XMMloss);
    _mm_store_sd(&this->error, XMMerror);
    this->scheduler.put_job(this->bid, this->loss, this->error);
    this->scheduler.put_bpr_job(this->bid, this->bpr_bid);
}
|
2488
|
+
|
2489
|
+
// AVX SGD step on p, q and w over latent dimensions [d_begin, d_end), eight
// floats at a time.
//
// XMMz is the broadcast gradient scale from prepare_for_sg_update(). The
// step sizes are eta * rsqrt(accumulated squared gradient). After the plain
// step the function optionally applies:
//   - L1 soft-thresholding (lambda_q1 also applies to w);
//   - a non-negativity clamp (param.do_nmf).
// Finally rk times each vector's summed squared gradient is added into
// *pG, *qG and *wG.
//
// NOTE(review): assumes d_begin/d_end are multiples of 8 and 32-byte
// alignment, as _mm256_load_ps/_mm256_store_ps require.
void BPRSolver::sg_update(mf_int d_begin, mf_int d_end, __m256 XMMz,
                          __m256 XMMlambda_p1, __m256 XMMlambda_q1,
                          __m256 XMMlambda_p2, __m256 XMMlambda_q2,
                          __m256 XMMeta, __m256 XMMrk)
{
    const __m256 vzero = _mm256_set1_ps(0.0f);
    const __m256 vsign = _mm256_set1_ps(-0.0f);

    // L1 proximal step per lane: shrink |v| by thresh (clamped at 0), keep
    // v's sign.
    auto shrink = [=](__m256 v, __m256 thresh) -> __m256
    {
        __m256 flip = _mm256_and_ps(
            _mm256_cmp_ps(v, vzero, _CMP_LE_OS), vsign);
        return _mm256_xor_ps(flip, _mm256_max_ps(
            _mm256_sub_ps(_mm256_xor_ps(v, flip), thresh), vzero));
    };

    // Sum all eight lanes; lane 0 of the result holds the total.
    auto hsum = [](__m256 v) -> __m256
    {
        v = _mm256_add_ps(v, _mm256_permute2f128_ps(v, v, 0x1));
        v = _mm256_hadd_ps(v, v);
        return _mm256_hadd_ps(v, v);
    };

    __m256 acc_p = _mm256_broadcast_ss(pG);
    __m256 acc_q = _mm256_broadcast_ss(qG);
    __m256 acc_w = _mm256_broadcast_ss(wG);
    const __m256 step_p = _mm256_mul_ps(XMMeta, _mm256_rsqrt_ps(acc_p));
    const __m256 step_q = _mm256_mul_ps(XMMeta, _mm256_rsqrt_ps(acc_q));
    const __m256 step_w = _mm256_mul_ps(XMMeta, _mm256_rsqrt_ps(acc_w));

    __m256 sq_p = _mm256_setzero_ps();
    __m256 sq_q = _mm256_setzero_ps();
    __m256 sq_w = _mm256_setzero_ps();

    for(mf_int d = d_begin; d < d_end; d += 8)
    {
        __m256 vp = _mm256_load_ps(p+d);
        __m256 vq = _mm256_load_ps(q+d);
        __m256 vw = _mm256_load_ps(w+d);

        __m256 gp = _mm256_add_ps(_mm256_mul_ps(XMMlambda_p2, vp),
                    _mm256_mul_ps(XMMz, _mm256_sub_ps(vw, vq)));
        __m256 gq = _mm256_sub_ps(_mm256_mul_ps(XMMlambda_q2, vq),
                    _mm256_mul_ps(XMMz, vp));
        __m256 gw = _mm256_add_ps(_mm256_mul_ps(XMMlambda_q2, vw),
                    _mm256_mul_ps(XMMz, vp));

        sq_p = _mm256_add_ps(sq_p, _mm256_mul_ps(gp, gp));
        sq_q = _mm256_add_ps(sq_q, _mm256_mul_ps(gq, gq));
        sq_w = _mm256_add_ps(sq_w, _mm256_mul_ps(gw, gw));

        _mm256_store_ps(p+d, _mm256_sub_ps(vp, _mm256_mul_ps(step_p, gp)));
        _mm256_store_ps(q+d, _mm256_sub_ps(vq, _mm256_mul_ps(step_q, gq)));
        _mm256_store_ps(w+d, _mm256_sub_ps(vw, _mm256_mul_ps(step_w, gw)));
    }

    // Lane 0 of each broadcast lambda decides whether L1 shrinkage runs.
    mf_float l1_weight = 0;
    _mm_store_ss(&l1_weight, _mm256_castps256_ps128(XMMlambda_p1));
    if(l1_weight > 0)
    {
        const __m256 thresh_p = _mm256_mul_ps(step_p, XMMlambda_p1);
        for(mf_int d = d_begin; d < d_end; d += 8)
            _mm256_store_ps(p+d, shrink(_mm256_load_ps(p+d), thresh_p));
    }

    _mm_store_ss(&l1_weight, _mm256_castps256_ps128(XMMlambda_q1));
    if(l1_weight > 0)
    {
        const __m256 thresh_q = _mm256_mul_ps(step_q, XMMlambda_q1);
        const __m256 thresh_w = _mm256_mul_ps(step_w, XMMlambda_q1);
        for(mf_int d = d_begin; d < d_end; d += 8)
        {
            _mm256_store_ps(q+d, shrink(_mm256_load_ps(q+d), thresh_q));
            _mm256_store_ps(w+d, shrink(_mm256_load_ps(w+d), thresh_w));
        }
    }

    if(param.do_nmf)
    {
        for(mf_int d = d_begin; d < d_end; d += 8)
        {
            _mm256_store_ps(p+d, _mm256_max_ps(_mm256_load_ps(p+d), vzero));
            _mm256_store_ps(q+d, _mm256_max_ps(_mm256_load_ps(q+d), vzero));
            _mm256_store_ps(w+d, _mm256_max_ps(_mm256_load_ps(w+d), vzero));
        }
    }

    acc_p = _mm256_add_ps(acc_p, _mm256_mul_ps(hsum(sq_p), XMMrk));
    acc_q = _mm256_add_ps(acc_q, _mm256_mul_ps(hsum(sq_q), XMMrk));
    acc_w = _mm256_add_ps(acc_w, _mm256_mul_ps(hsum(sq_w), XMMrk));

    _mm_store_ss(pG, _mm256_castps256_ps128(acc_p));
    _mm_store_ss(qG, _mm256_castps256_ps128(acc_q));
    _mm_store_ss(wG, _mm256_castps256_ps128(acc_w));
}
|
2621
|
+
|
2622
|
+
// AVX variant.
//
// Calls prepare_negative() to select w, scores p . (q - w), adds the loss
// log(1+exp(-score)) and mirrors it into the error accumulator. XMMz is then
// set to the broadcast gradient scale exp(-score)/(1+exp(-score)).
void BPRSolver::prepare_for_sg_update(
    __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
{
    prepare_negative();
    calc_z(XMMz, model.k, p, q, w);
    _mm_store_ss(&z, _mm256_castps256_ps128(XMMz));
    // From here on z holds exp(-score).
    z = exp(-z);
    XMMz = _mm256_set1_ps(z/(1+z));
    XMMloss = _mm_add_pd(XMMloss, _mm_set1_pd(log(1+z)));
    XMMerror = XMMloss;
}
|
2633
|
+
#else
|
2634
|
+
// Scalar score z = sum over d < k of p[d]*(q[d]-w[d]), i.e. p . (q - w).
inline void BPRSolver::calc_z(
    mf_float &z, mf_int k, mf_float *p, mf_float *q, mf_float *w)
{
    mf_float sum = 0;
    for(mf_int d = 0; d < k; ++d)
        sum += p[d]*(q[d]-w[d]);
    z = sum;
}
|
2641
|
+
|
2642
|
+
void BPRSolver::arrange_block()
|
2643
|
+
{
|
2644
|
+
loss = 0.0;
|
2645
|
+
error = 0.0;
|
2646
|
+
bid = scheduler.get_job();
|
2647
|
+
block = blocks[bid];
|
2648
|
+
block->reload();
|
2649
|
+
bpr_bid = scheduler.get_bpr_job(bid, is_column_oriented);
|
2650
|
+
}
|
2651
|
+
|
2652
|
+
void BPRSolver::finalize()
|
2653
|
+
{
|
2654
|
+
scheduler.put_job(bid, loss, error);
|
2655
|
+
scheduler.put_bpr_job(bid, bpr_bid);
|
2656
|
+
}
|
2657
|
+
|
2658
|
+
// Scalar stochastic-gradient step for one BPR triple (p, q, w) restricted to
// latent coordinates [d_begin, d_end).
//
// Each vector uses its own step size param.eta * qrsqrt(G), where G is read
// from its accumulator slot (*pG, *qG, *wG). qrsqrt is presumably a fast
// reciprocal square root, i.e. an AdaGrad-style rate; TODO confirm.
// After the plain gradient step:
//  - optional L1 soft-thresholding is applied (w before q on each coordinate);
//  - an optional non-negativity projection is applied;
//  - the squared-gradient sums, scaled by rk, are added to the accumulators.
void BPRSolver::sg_update(mf_int d_begin, mf_int d_end, mf_float rk)
{
    const mf_float step_p = param.eta*qrsqrt(*pG);
    const mf_float step_q = param.eta*qrsqrt(*qG);
    const mf_float step_w = param.eta*qrsqrt(*wG);

    mf_float sq_p = 0;
    mf_float sq_q = 0;
    mf_float sq_w = 0;

    for(mf_int i = d_begin; i < d_end; ++i)
    {
        const mf_float grad_p = z*(w[i]-q[i]) + lambda_p2*p[i];
        const mf_float grad_q = -z*p[i] + lambda_q2*q[i];
        const mf_float grad_w = z*p[i] + lambda_q2*w[i];

        sq_p += grad_p*grad_p;
        sq_q += grad_q*grad_q;
        sq_w += grad_w*grad_w;

        p[i] -= step_p*grad_p;
        q[i] -= step_q*grad_q;
        w[i] -= step_w*grad_w;
    }

    // Soft-thresholding: shrink |x| by t, clamp at zero, keep x's sign.
    auto shrink = [](mf_float x, mf_float t) -> mf_float
    {
        const mf_float mag = max(abs(x)-t, 0.0f);
        return x >= 0? mag: -mag;
    };

    if(lambda_p1 > 0)
    {
        const mf_float thr_p = lambda_p1*step_p;
        for(mf_int i = d_begin; i < d_end; ++i)
            p[i] = shrink(p[i], thr_p);
    }

    if(lambda_q1 > 0)
    {
        const mf_float thr_w = lambda_q1*step_w;
        const mf_float thr_q = lambda_q1*step_q;
        for(mf_int i = d_begin; i < d_end; ++i)
        {
            w[i] = shrink(w[i], thr_w);
            q[i] = shrink(q[i], thr_q);
        }
    }

    if(param.do_nmf)
    {
        for(mf_int i = d_begin; i < d_end; ++i)
        {
            p[i] = max(p[i], (mf_float)0.0);
            q[i] = max(q[i], (mf_float)0.0);
            w[i] = max(w[i], (mf_float)0.0);
        }
    }

    *pG += sq_p*rk;
    *qG += sq_q*rk;
    *wG += sq_w*rk;
}
|
2717
|
+
|
2718
|
+
void BPRSolver::prepare_for_sg_update()
|
2719
|
+
{
|
2720
|
+
prepare_negative();
|
2721
|
+
calc_z(z, model.k, p, q, w);
|
2722
|
+
z = exp(-z);
|
2723
|
+
loss += log(1+z);
|
2724
|
+
error = loss;
|
2725
|
+
z = z/(1+z);
|
2726
|
+
}
|
2727
|
+
#endif
|
2728
|
+
|
2729
|
+
class COL_BPR_MFOC : public BPRSolver
|
2730
|
+
{
|
2731
|
+
public:
|
2732
|
+
COL_BPR_MFOC(Scheduler &scheduler, vector<BlockBase*> &blocks,
|
2733
|
+
mf_float *PG, mf_float *QG, mf_model &model,
|
2734
|
+
mf_parameter param, bool &slow_only,
|
2735
|
+
bool is_column_oriented=true)
|
2736
|
+
: BPRSolver(scheduler, blocks, PG, QG, model, param,
|
2737
|
+
slow_only, is_column_oriented) {}
|
2738
|
+
protected:
|
2739
|
+
#if defined USESSE
|
2740
|
+
void load_fixed_variables(
|
2741
|
+
__m128 &XMMlambda_p1, __m128 &XMMlambda_q1,
|
2742
|
+
__m128 &XMMlambda_p2, __m128 &XMMlabmda_q2,
|
2743
|
+
__m128 &XMMeta, __m128 &XMMrk_slow,
|
2744
|
+
__m128 &XMMrk_fast);
|
2745
|
+
#elif defined USEAVX
|
2746
|
+
void load_fixed_variables(
|
2747
|
+
__m256 &XMMlambda_p1, __m256 &XMMlambda_q1,
|
2748
|
+
__m256 &XMMlambda_p2, __m256 &XMMlabmda_q2,
|
2749
|
+
__m256 &XMMeta, __m256 &XMMrk_slow,
|
2750
|
+
__m256 &XMMrk_fast);
|
2751
|
+
#else
|
2752
|
+
void load_fixed_variables();
|
2753
|
+
#endif
|
2754
|
+
void prepare_negative();
|
2755
|
+
};
|
2756
|
+
|
2757
|
+
void COL_BPR_MFOC::prepare_negative()
|
2758
|
+
{
|
2759
|
+
mf_int negative = scheduler.get_negative(bid, bpr_bid, model.m, model.n,
|
2760
|
+
is_column_oriented);
|
2761
|
+
w = model.P + negative*model.k;
|
2762
|
+
wG = PG + negative*2;
|
2763
|
+
swap(p, q);
|
2764
|
+
swap(pG, qG);
|
2765
|
+
}
|
2766
|
+
|
2767
|
+
#if defined USESSE
|
2768
|
+
void COL_BPR_MFOC::load_fixed_variables(
|
2769
|
+
__m128 &XMMlambda_p1, __m128 &XMMlambda_q1,
|
2770
|
+
__m128 &XMMlambda_p2, __m128 &XMMlambda_q2,
|
2771
|
+
__m128 &XMMeta, __m128 &XMMrk_slow,
|
2772
|
+
__m128 &XMMrk_fast)
|
2773
|
+
{
|
2774
|
+
XMMlambda_p1 = _mm_set1_ps(param.lambda_q1);
|
2775
|
+
XMMlambda_q1 = _mm_set1_ps(param.lambda_p1);
|
2776
|
+
XMMlambda_p2 = _mm_set1_ps(param.lambda_q2);
|
2777
|
+
XMMlambda_q2 = _mm_set1_ps(param.lambda_p2);
|
2778
|
+
XMMeta = _mm_set1_ps(param.eta);
|
2779
|
+
XMMrk_slow = _mm_set1_ps((mf_float)1.0/kALIGN);
|
2780
|
+
XMMrk_fast = _mm_set1_ps((mf_float)1.0/(model.k-kALIGN));
|
2781
|
+
}
|
2782
|
+
#elif defined USEAVX
|
2783
|
+
void COL_BPR_MFOC::load_fixed_variables(
|
2784
|
+
__m256 &XMMlambda_p1, __m256 &XMMlambda_q1,
|
2785
|
+
__m256 &XMMlambda_p2, __m256 &XMMlambda_q2,
|
2786
|
+
__m256 &XMMeta, __m256 &XMMrk_slow,
|
2787
|
+
__m256 &XMMrk_fast)
|
2788
|
+
{
|
2789
|
+
XMMlambda_p1 = _mm256_set1_ps(param.lambda_q1);
|
2790
|
+
XMMlambda_q1 = _mm256_set1_ps(param.lambda_p1);
|
2791
|
+
XMMlambda_p2 = _mm256_set1_ps(param.lambda_q2);
|
2792
|
+
XMMlambda_q2 = _mm256_set1_ps(param.lambda_p2);
|
2793
|
+
XMMeta = _mm256_set1_ps(param.eta);
|
2794
|
+
XMMrk_slow = _mm256_set1_ps((mf_float)1.0/kALIGN);
|
2795
|
+
XMMrk_fast = _mm256_set1_ps((mf_float)1.0/(model.k-kALIGN));
|
2796
|
+
}
|
2797
|
+
#else
|
2798
|
+
void COL_BPR_MFOC::load_fixed_variables()
|
2799
|
+
{
|
2800
|
+
lambda_p1 = param.lambda_q1;
|
2801
|
+
lambda_q1 = param.lambda_p1;
|
2802
|
+
lambda_p2 = param.lambda_q2;
|
2803
|
+
lambda_q2 = param.lambda_p2;
|
2804
|
+
rk_slow = (mf_float)1.0/kALIGN;
|
2805
|
+
rk_fast = (mf_float)1.0/(model.k-kALIGN);
|
2806
|
+
}
|
2807
|
+
#endif
|
2808
|
+
|
2809
|
+
class ROW_BPR_MFOC : public BPRSolver
|
2810
|
+
{
|
2811
|
+
public:
|
2812
|
+
ROW_BPR_MFOC(Scheduler &scheduler, vector<BlockBase*> &blocks,
|
2813
|
+
mf_float *PG, mf_float *QG, mf_model &model,
|
2814
|
+
mf_parameter param, bool &slow_only,
|
2815
|
+
bool is_column_oriented = false)
|
2816
|
+
: BPRSolver(scheduler, blocks, PG, QG, model, param,
|
2817
|
+
slow_only, is_column_oriented) {}
|
2818
|
+
protected:
|
2819
|
+
void prepare_negative();
|
2820
|
+
};
|
2821
|
+
|
2822
|
+
void ROW_BPR_MFOC::prepare_negative()
|
2823
|
+
{
|
2824
|
+
mf_int negative = scheduler.get_negative(bid, bpr_bid, model.m, model.n,
|
2825
|
+
is_column_oriented);
|
2826
|
+
w = model.Q + negative*model.k;
|
2827
|
+
wG = QG + negative*2;
|
2828
|
+
}
|
2829
|
+
|
2830
|
+
|
2831
|
+
class SolverFactory
|
2832
|
+
{
|
2833
|
+
public:
|
2834
|
+
static shared_ptr<SolverBase> get_solver(
|
2835
|
+
Scheduler &scheduler,
|
2836
|
+
vector<BlockBase*> &blocks,
|
2837
|
+
mf_float *PG,
|
2838
|
+
mf_float *QG,
|
2839
|
+
mf_model &model,
|
2840
|
+
mf_parameter param,
|
2841
|
+
bool &slow_only);
|
2842
|
+
};
|
2843
|
+
|
2844
|
+
// Instantiates the solver matching param.fun, forwarding every argument
// unchanged to its constructor.
//
// Throws invalid_argument("unknown error function") when param.fun is not
// one of the supported loss codes.
shared_ptr<SolverBase> SolverFactory::get_solver(
    Scheduler &scheduler,
    vector<BlockBase*> &blocks,
    mf_float *PG,
    mf_float *QG,
    mf_model &model,
    mf_parameter param,
    bool &slow_only)
{
    switch(param.fun)
    {
        case P_L2_MFR:
            return shared_ptr<SolverBase>(new L2_MFR(
                scheduler, blocks, PG, QG, model, param, slow_only));
        case P_L1_MFR:
            return shared_ptr<SolverBase>(new L1_MFR(
                scheduler, blocks, PG, QG, model, param, slow_only));
        case P_KL_MFR:
            return shared_ptr<SolverBase>(new KL_MFR(
                scheduler, blocks, PG, QG, model, param, slow_only));
        case P_LR_MFC:
            return shared_ptr<SolverBase>(new LR_MFC(
                scheduler, blocks, PG, QG, model, param, slow_only));
        case P_L2_MFC:
            return shared_ptr<SolverBase>(new L2_MFC(
                scheduler, blocks, PG, QG, model, param, slow_only));
        case P_L1_MFC:
            return shared_ptr<SolverBase>(new L1_MFC(
                scheduler, blocks, PG, QG, model, param, slow_only));
        case P_ROW_BPR_MFOC:
            return shared_ptr<SolverBase>(new ROW_BPR_MFOC(
                scheduler, blocks, PG, QG, model, param, slow_only));
        case P_COL_BPR_MFOC:
            return shared_ptr<SolverBase>(new COL_BPR_MFOC(
                scheduler, blocks, PG, QG, model, param, slow_only));
        default:
            throw invalid_argument("unknown error function");
    }
}
|
2894
|
+
|
2895
|
+
// Shared FPSG training loop used by fpsg() and fpsg_on_disk().
//
// Parameters:
//  util, sched  - helper/scheduler objects created by the caller.
//  tr, va       - training and validation problems, already shuffled and
//                 divided by `scale`.
//  param        - training parameters (taken by value; for the real-valued
//                 losses L2/L1/KL the regularization coefficients are
//                 rescaled locally to match the scaled data).
//  scale        - factor the ratings were divided by; used to report errors
//                 and the objective in the original units.
//  block_ptrs   - the nr_bins x nr_bins rating blocks.
//  omega_p/q    - per-row/per-column rating counts used by the regularizers.
//  model        - model to train in place.
//  cv_blocks    - optional held-out block ids; when non-empty and cv_error
//                 is non-null, *cv_error receives the error on them.
//
// One solver thread per param.nr_threads runs until the scheduler is
// terminated after param.nr_iters iterations. Per-iteration progress is
// printed unless param.quiet is set.
//
// With USESSE/USEAVX, flush-to-zero is enabled for the calling thread during
// training and the previous mode is restored before returning.
void fpsg_core(
    Utility &util,
    Scheduler &sched,
    mf_problem *tr,
    mf_problem *va,
    mf_parameter param,
    mf_float scale,
    vector<BlockBase*> &block_ptrs,
    vector<mf_int> &omega_p,
    vector<mf_int> &omega_q,
    shared_ptr<mf_model> &model,
    vector<mf_int> cv_blocks,
    mf_double *cv_error)
{
    // Checked before touching the floating-point control state so that this
    // early return cannot leave flush-to-zero enabled in the caller's thread.
    if(tr->nnz == 0)
    {
        cout << "warning: train on an empty training set" << endl;
        return;
    }

#if defined USESSE || defined USEAVX
    auto flush_zero_mode = _MM_GET_FLUSH_ZERO_MODE();
    _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);
#endif

    // Regularization terms must be expressed in the scaled units.
    if(param.fun == P_L2_MFR ||
       param.fun == P_L1_MFR ||
       param.fun == P_KL_MFR)
    {
        switch(param.fun)
        {
            case P_L2_MFR:
                param.lambda_p2 /= scale;
                param.lambda_q2 /= scale;
                param.lambda_p1 /= (mf_float)pow(scale, 1.5);
                param.lambda_q1 /= (mf_float)pow(scale, 1.5);
                break;
            case P_L1_MFR:
            case P_KL_MFR:
                param.lambda_p1 /= sqrt(scale);
                param.lambda_q1 /= sqrt(scale);
                break;
        }
    }

    if(!param.quiet)
    {
        cout.width(4);
        cout << "iter";
        cout.width(13);
        cout << "tr_"+util.get_error_legend();
        if(va->nnz != 0)
        {
            cout.width(13);
            cout << "va_"+util.get_error_legend();
        }
        cout.width(13);
        cout << "obj";
        cout << "\n";
    }

    bool slow_only = param.lambda_p1 == 0 && param.lambda_q1 == 0;
    vector<mf_float> PG(model->m*2, 1), QG(model->n*2, 1);

    vector<shared_ptr<SolverBase>> solvers(param.nr_threads);
    vector<thread> threads;
    threads.reserve(param.nr_threads);
    for(mf_int i = 0; i < param.nr_threads; ++i)
    {
        solvers[i] = SolverFactory::get_solver(sched, block_ptrs,
                                               PG.data(), QG.data(),
                                               *model, param, slow_only);
        threads.emplace_back(&SolverBase::run, solvers[i].get());
    }

    for(mf_int iter = 0; iter < param.nr_iters; ++iter)
    {
        sched.wait_for_jobs_done();

        if(!param.quiet)
        {
            mf_double reg = 0;
            mf_double reg1 = util.calc_reg1(*model, param.lambda_p1,
                                            param.lambda_q1, omega_p, omega_q);
            mf_double reg2 = util.calc_reg2(*model, param.lambda_p2,
                                            param.lambda_q2, omega_p, omega_q);
            mf_double tr_loss = sched.get_loss();
            mf_double tr_error = sched.get_error()/tr->nnz;

            // Convert back to the original rating units.
            switch(param.fun)
            {
                case P_L2_MFR:
                    reg = (reg1+reg2)*scale*scale;
                    tr_loss *= scale*scale;
                    tr_error = sqrt(tr_error*scale*scale);
                    break;
                case P_L1_MFR:
                case P_KL_MFR:
                    reg = (reg1+reg2)*scale;
                    tr_loss *= scale;
                    tr_error *= scale;
                    break;
                default:
                    reg = reg1+reg2;
                    break;
            }

            cout.width(4);
            cout << iter;
            cout.width(13);
            cout << fixed << setprecision(4) << tr_error;
            if(va->nnz != 0)
            {
                Block va_block(va->R, va->R+va->nnz);
                vector<BlockBase*> va_blocks(1, &va_block);
                vector<mf_int> va_block_ids(1, 0);
                mf_double va_error =
                    util.calc_error(va_blocks, va_block_ids, *model)/va->nnz;
                switch(param.fun)
                {
                    case P_L2_MFR:
                        va_error = sqrt(va_error*scale*scale);
                        break;
                    case P_L1_MFR:
                    case P_KL_MFR:
                        va_error *= scale;
                        break;
                }

                cout.width(13);
                cout << fixed << setprecision(4) << va_error;
            }
            cout.width(13);
            cout << fixed << setprecision(4) << scientific << reg+tr_loss;
            cout << "\n" << flush;
        }

        // Solvers read slow_only by reference; after the first iteration
        // every coordinate is updated.
        if(iter == 0)
            slow_only = false;
        if(iter == param.nr_iters - 1)
            sched.terminate();
        sched.resume();
    }

    for(auto &worker : threads)
        worker.join();

    if(cv_error != nullptr && cv_blocks.size() > 0)
    {
        mf_long cv_count = 0;
        for(auto block : cv_blocks)
            cv_count += block_ptrs[block]->get_nnz();

        *cv_error = util.calc_error(block_ptrs, cv_blocks, *model)/cv_count;

        switch(param.fun)
        {
            case P_L2_MFR:
                *cv_error = sqrt(*cv_error*scale*scale);
                break;
            case P_L1_MFR:
            case P_KL_MFR:
                *cv_error *= scale;
                break;
        }
    }

#if defined USESSE || defined USEAVX
    _MM_SET_FLUSH_ZERO_MODE(flush_zero_mode);
#endif
}
|
3065
|
+
|
3066
|
+
// In-memory FPSG training entry point.
//
// Steps:
//  1. Wrap the training/validation problems: copies owned via deleter() when
//     param.copy_data is set, otherwise wrappers from
//     copy_problem(..., false), which presumably share the caller's rating
//     arrays (TODO confirm).
//  2. Randomly permute rows and columns, and divide ratings by `scale`.
//     `scale` is the rating standard deviation (floored at 1e-4) for the
//     L2/L1/KL losses and 1 otherwise.
//  3. Grid the training set into nr_bins x nr_bins blocks and run fpsg_core.
//  4. When copy_data is off, undo the scaling/permutation on the problems.
//  5. Rescale the model, shrink it to param.k, and permute it back to the
//     original indices.
// Any std::exception is printed to stderr and rethrown.
shared_ptr<mf_model> fpsg(
    mf_problem const *tr_,
    mf_problem const *va_,
    mf_parameter param,
    vector<mf_int> cv_blocks=vector<mf_int>(),
    mf_double *cv_error=nullptr)
{
    shared_ptr<mf_model> model;
    try
    {
        Utility util(param.fun, param.nr_threads);
        Scheduler sched(param.nr_bins, param.nr_threads, cv_blocks);
        shared_ptr<mf_problem> tr;
        shared_ptr<mf_problem> va;
        vector<Block> blocks(param.nr_bins*param.nr_bins);
        vector<BlockBase*> block_ptrs(param.nr_bins*param.nr_bins);
        vector<mf_node*> ptrs;
        vector<mf_int> p_map;
        vector<mf_int> q_map;
        vector<mf_int> inv_p_map;
        vector<mf_int> inv_q_map;
        vector<mf_int> omega_p;
        vector<mf_int> omega_q;
        mf_float avg = 0;
        mf_float std_dev = 0;
        mf_float scale = 1;

        if(!param.copy_data)
        {
            tr.reset(Utility::copy_problem(tr_, false));
            va.reset(Utility::copy_problem(va_, false));
        }
        else
        {
            tr.reset(Utility::copy_problem(tr_, true), deleter());
            va.reset(Utility::copy_problem(va_, true), deleter());
        }

        util.collect_info(*tr, avg, std_dev);

        const bool real_valued_loss = param.fun == P_L2_MFR ||
                                      param.fun == P_L1_MFR ||
                                      param.fun == P_KL_MFR;
        if(real_valued_loss)
            scale = max((mf_float)1e-4, std_dev);

        // Row map first, then column map: the random draws must stay in
        // this order.
        p_map = Utility::gen_random_map(tr->m);
        q_map = Utility::gen_random_map(tr->n);
        inv_p_map = Utility::gen_inv_map(p_map);
        inv_q_map = Utility::gen_inv_map(q_map);
        omega_p.assign(tr->m, 0);
        omega_q.assign(tr->n, 0);

        const mf_float inv_scale = (mf_float)1.0/scale;
        util.shuffle_problem(*tr, p_map, q_map);
        util.shuffle_problem(*va, p_map, q_map);
        util.scale_problem(*tr, inv_scale);
        util.scale_problem(*va, inv_scale);
        ptrs = util.grid_problem(*tr, param.nr_bins, omega_p, omega_q, blocks);

        model.reset(Utility::init_model(param.fun, tr->m, tr->n, param.k,
                                        avg/scale, omega_p, omega_q),
                    [] (mf_model *ptr) { mf_destroy_model(&ptr); });

        for(size_t i = 0; i < blocks.size(); ++i)
            block_ptrs[i] = &blocks[i];

        fpsg_core(util, sched, tr.get(), va.get(), param, scale,
                  block_ptrs, omega_p, omega_q, model, cv_blocks, cv_error);

        if(!param.copy_data)
        {
            // Restore the caller-visible problems: undo scaling, then the
            // permutation.
            util.scale_problem(*tr, scale);
            util.scale_problem(*va, scale);
            util.shuffle_problem(*tr, inv_p_map, inv_q_map);
            util.shuffle_problem(*va, inv_p_map, inv_q_map);
        }

        util.scale_model(*model, scale);
        Utility::shrink_model(*model, param.k);
        Utility::shuffle_model(*model, inv_p_map, inv_q_map);
    }
    catch(exception const &e)
    {
        cerr << e.what() << endl;
        throw;
    }
    return model;
}
|
3155
|
+
|
3156
|
+
shared_ptr<mf_model> fpsg_on_disk(
|
3157
|
+
const string tr_path,
|
3158
|
+
const string va_path,
|
3159
|
+
mf_parameter param,
|
3160
|
+
vector<mf_int> cv_blocks=vector<mf_int>(),
|
3161
|
+
mf_double *cv_error=nullptr)
|
3162
|
+
{
|
3163
|
+
shared_ptr<mf_model> model;
|
3164
|
+
try
|
3165
|
+
{
|
3166
|
+
Utility util(param.fun, param.nr_threads);
|
3167
|
+
Scheduler sched(param.nr_bins, param.nr_threads, cv_blocks);
|
3168
|
+
mf_problem tr = {};
|
3169
|
+
mf_problem va = read_problem(va_path.c_str());
|
3170
|
+
vector<BlockOnDisk> blocks(param.nr_bins*param.nr_bins);
|
3171
|
+
vector<BlockBase*> block_ptrs(param.nr_bins*param.nr_bins);
|
3172
|
+
vector<mf_int> p_map;
|
3173
|
+
vector<mf_int> q_map;
|
3174
|
+
vector<mf_int> inv_p_map;
|
3175
|
+
vector<mf_int> inv_q_map;
|
3176
|
+
vector<mf_int> omega_p;
|
3177
|
+
vector<mf_int> omega_q;
|
3178
|
+
mf_float avg = 0;
|
3179
|
+
mf_float std_dev = 0;
|
3180
|
+
mf_float scale = 1;
|
3181
|
+
|
3182
|
+
util.collect_info_on_disk(tr_path, tr, avg, std_dev);
|
3183
|
+
|
3184
|
+
if(param.fun == P_L2_MFR ||
|
3185
|
+
param.fun == P_L1_MFR ||
|
3186
|
+
param.fun == P_KL_MFR)
|
3187
|
+
scale = max((mf_float)1e-4, std_dev);
|
3188
|
+
|
3189
|
+
p_map = Utility::gen_random_map(tr.m);
|
3190
|
+
q_map = Utility::gen_random_map(tr.n);
|
3191
|
+
inv_p_map = Utility::gen_inv_map(p_map);
|
3192
|
+
inv_q_map = Utility::gen_inv_map(q_map);
|
3193
|
+
omega_p = vector<mf_int>(tr.m, 0);
|
3194
|
+
omega_q = vector<mf_int>(tr.n, 0);
|
3195
|
+
|
3196
|
+
util.shuffle_problem(va, p_map, q_map);
|
3197
|
+
util.scale_problem(va, (mf_float)1.0/scale);
|
3198
|
+
|
3199
|
+
util.grid_shuffle_scale_problem_on_disk(
|
3200
|
+
tr.m, tr.n, param.nr_bins, scale, tr_path,
|
3201
|
+
p_map, q_map, omega_p, omega_q, blocks);
|
3202
|
+
|
3203
|
+
model = shared_ptr<mf_model>(Utility::init_model(param.fun,
|
3204
|
+
tr.m, tr.n, param.k, avg/scale, omega_p, omega_q),
|
3205
|
+
[] (mf_model *ptr) { mf_destroy_model(&ptr); });
|
3206
|
+
|
3207
|
+
for(mf_int i = 0; i < (mf_long)blocks.size(); ++i)
|
3208
|
+
block_ptrs[i] = &blocks[i];
|
3209
|
+
|
3210
|
+
fpsg_core(util, sched, &tr, &va, param, scale,
|
3211
|
+
block_ptrs, omega_p, omega_q, model, cv_blocks, cv_error);
|
3212
|
+
|
3213
|
+
delete [] va.R;
|
3214
|
+
|
3215
|
+
util.scale_model(*model, scale);
|
3216
|
+
Utility::shrink_model(*model, param.k);
|
3217
|
+
Utility::shuffle_model(*model, inv_p_map, inv_q_map);
|
3218
|
+
}
|
3219
|
+
catch(exception const &e)
|
3220
|
+
{
|
3221
|
+
cerr << e.what() << endl;
|
3222
|
+
throw;
|
3223
|
+
}
|
3224
|
+
return model;
|
3225
|
+
}
|
3226
|
+
|
3227
|
+
// The function implements an efficient method to compute objective function
|
3228
|
+
// minimized by coordinate descent method.
|
3229
|
+
//
|
3230
|
+
// \min_{P, Q} 0.5 * \sum_{(u,v)\in\Omega^+} (1-r_{u,v})^2 +
|
3231
|
+
// 0.5 * \alpha \sum_{(u,v)\not\in\Omega^+} (c-r_{u,v})^2 +
|
3232
|
+
// 0.5 * \lambda_p2 * ||P||_F^2 + 0.5 * \lambda_q2 * ||Q||_F^2
|
3233
|
+
// where
|
3234
|
+
// 1. (u,v) is a tuple of row index and column index,
|
3235
|
+
// 2. \Omega^+ a collections of (u,v) which specifies the locations of
|
3236
|
+
// positive entries in the training matrix.
|
3237
|
+
// 3. r_{u,v} is the predicted rating at (u,v)
|
3238
|
+
// 4. \alpha is the weight of negative entries' loss.
|
3239
|
+
// 5. c is the desired value at every negative entries.
|
3240
|
+
// 6. ||P||_F is matrix P's Frobenius norm.
|
3241
|
+
// 7. \lambda_p2 is the regularization coefficient of P.
|
3242
|
+
//
|
3243
|
+
// Note that coordinate descent method's P and Q are the transpose
|
3244
|
+
// counterparts of P and Q in stochastic gradient method. Let R denoates
|
3245
|
+
// the training matrix. For stochastic gradient method, we have R ~ P^TQ.
|
3246
|
+
// For coordinate descent method, we have R ~ PQ^T.
|
3247
|
+
void calc_ccd_one_class_obj(const mf_int nr_threads,
|
3248
|
+
const mf_float alpha, const mf_float c,
|
3249
|
+
const mf_int m, const mf_int n, const mf_int d,
|
3250
|
+
const mf_float lambda_p2, const mf_float lambda_q2,
|
3251
|
+
const mf_float *P, const mf_float *Q,
|
3252
|
+
shared_ptr<const mf_problem> data,
|
3253
|
+
/*output*/ mf_double &obj,
|
3254
|
+
/*output*/ mf_double &positive_loss,
|
3255
|
+
/*output*/ mf_double &negative_loss,
|
3256
|
+
/*output*/ mf_double ®)
|
3257
|
+
{
|
3258
|
+
// Declare regularization term of P.
|
3259
|
+
mf_double p_square_norm = 0.0;
|
3260
|
+
// Reduce P along column axis, which is the sum of rows in P.
|
3261
|
+
vector<mf_double> all_p_sum(d, 0.0);
|
3262
|
+
// Compute square of Frobenius norm on P and sum of all rows in P.
|
3263
|
+
for(mf_int k = 0; k < d; ++k)
|
3264
|
+
{
|
3265
|
+
// Declare a temporal buffer of all_p_sum[k] for using OpenMP.
|
3266
|
+
mf_double all_p_sum_k = 0.0;
|
3267
|
+
#if defined USEOMP
|
3268
|
+
#pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:p_square_norm,all_p_sum_k)
|
3269
|
+
#endif
|
3270
|
+
for(mf_int u = 0; u < m; ++u)
|
3271
|
+
{
|
3272
|
+
const mf_float &p_ku = P[u + k * m];
|
3273
|
+
p_square_norm += p_ku * p_ku;
|
3274
|
+
all_p_sum_k += p_ku;
|
3275
|
+
}
|
3276
|
+
all_p_sum[k] = all_p_sum_k;
|
3277
|
+
}
|
3278
|
+
|
3279
|
+
// Declare regularization term of Q
|
3280
|
+
mf_double q_square_norm = 0.0;
|
3281
|
+
// Reduce Q along column axis, whihc is the sum of rows in Q.
|
3282
|
+
vector<mf_double> all_q_sum(d, 0.0);
|
3283
|
+
// Compute square of Frobenius norm on Q and sum of all elements in Q
|
3284
|
+
for(mf_int k = 0; k < d; ++k)
|
3285
|
+
{
|
3286
|
+
// Declare a temporal buffer of all_p_sum[k] for using OpenMP.
|
3287
|
+
mf_double all_q_sum_k = 0.0;
|
3288
|
+
#if defined USEOMP
|
3289
|
+
#pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:q_square_norm,all_q_sum_k)
|
3290
|
+
#endif
|
3291
|
+
for(mf_int v = 0; v < n; ++v)
|
3292
|
+
{
|
3293
|
+
const mf_float &q_kv = Q[v + k * n];
|
3294
|
+
q_square_norm += q_kv * q_kv;
|
3295
|
+
all_q_sum_k += q_kv;
|
3296
|
+
}
|
3297
|
+
all_q_sum[k] = all_q_sum_k;
|
3298
|
+
}
|
3299
|
+
|
3300
|
+
// PTP = P^T * P, where P^T is the transpose of P. Note that P is a m-by-d
|
3301
|
+
// matrix and PTP is a d-by-d matrix.
|
3302
|
+
vector<mf_double> PTP(d * d, 0.0);
|
3303
|
+
// QTQ = Q^T * P, a d-by-d matrix.
|
3304
|
+
vector<mf_double> QTQ(d * d, 0.0);
|
3305
|
+
// We calculate PTP and QTQ because they are needed in the computation of
|
3306
|
+
// negative entries' loss function.
|
3307
|
+
for(mf_int k1 = 0; k1 < d; ++k1)
|
3308
|
+
{
|
3309
|
+
for(mf_int k2 = 0; k2 < d; ++k2)
|
3310
|
+
{
|
3311
|
+
// Inner product of the k1 and k2 columns in P, a m-by-d matrix.
|
3312
|
+
mf_double p_k1_p_k2_inner_product = 0.0;
|
3313
|
+
#if defined USEOMP
|
3314
|
+
#pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:p_k1_p_k2_inner_product)
|
3315
|
+
#endif
|
3316
|
+
for(mf_int u = 0; u < m; ++u)
|
3317
|
+
p_k1_p_k2_inner_product += P[u + k1 * m] * P[u + k2 * m];
|
3318
|
+
PTP[k1 * d + k2] = p_k1_p_k2_inner_product;
|
3319
|
+
|
3320
|
+
// Inner product of the k1 and k2 columns in Q, a n-by-d matrix.
|
3321
|
+
mf_double q_k1_q_k2_inner_product = 0.0;
|
3322
|
+
#if defined USEOMP
|
3323
|
+
#pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:q_k1_q_k2_inner_product)
|
3324
|
+
#endif
|
3325
|
+
for(mf_int v = 0; v < n; ++v)
|
3326
|
+
q_k1_q_k2_inner_product += Q[v + k1 * n] * Q[v + k2 * n];
|
3327
|
+
QTQ[k1 * d + k2] = q_k1_q_k2_inner_product;
|
3328
|
+
}
|
3329
|
+
}
|
3330
|
+
|
3331
|
+
// Initialize loss function value of positive matrix entries.
|
3332
|
+
// It consists two parts. The first part is the true prediction error
|
3333
|
+
// while the second part is only used for implementing faster algorithm.
|
3334
|
+
mf_double positive_loss1 = 0.0;
|
3335
|
+
mf_double positive_loss2 = 0.0;
|
3336
|
+
// Scan through positive matrix entries to compute their loss values.
|
3337
|
+
// Notice that we assume that positive entries' values are all one.
|
3338
|
+
#if defined USEOMP
|
3339
|
+
#pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:positive_loss1,positive_loss2)
|
3340
|
+
#endif
|
3341
|
+
for(mf_long i = 0; i < data->nnz; ++i)
|
3342
|
+
{
|
3343
|
+
const mf_double &r = data->R[i].r;
|
3344
|
+
positive_loss1 += (1.0 - r) * (1.0 - r);
|
3345
|
+
positive_loss2 -= alpha * (c - r) * (c - r);
|
3346
|
+
}
|
3347
|
+
positive_loss1 *= 0.5;
|
3348
|
+
positive_loss2 *= 0.5;
|
3349
|
+
|
3350
|
+
// Declare loss terms related to negative matrix entries.
|
3351
|
+
mf_double negative_loss1 = c * c * m * n;
|
3352
|
+
mf_double negative_loss2 = 0.0;
|
3353
|
+
mf_double negative_loss3 = 0.0;
|
3354
|
+
// Compute loss terms.
|
3355
|
+
for(mf_int k1 = 0; k1 < d; ++k1)
|
3356
|
+
{
|
3357
|
+
negative_loss2 += all_p_sum[k1] * all_q_sum[k1];
|
3358
|
+
for(mf_int k2 = 0; k2 < d; ++k2)
|
3359
|
+
negative_loss3 += PTP[k1 + k2 * d] * QTQ[k2 + k1 * d];
|
3360
|
+
}
|
3361
|
+
// Compute the loss function of negative matrix entries.
|
3362
|
+
mf_double negative_loss4 = 0.5 * alpha *
|
3363
|
+
(negative_loss1 - 2 * c * negative_loss2 + negative_loss3);
|
3364
|
+
|
3365
|
+
// Assign results to output variables.
|
3366
|
+
reg = 0.5 * lambda_p2 * p_square_norm + 0.5 * lambda_q2 * q_square_norm;
|
3367
|
+
|
3368
|
+
// The function minimized by coordinate descent method.
|
3369
|
+
obj = positive_loss1 + positive_loss2 + negative_loss4 + reg;
|
3370
|
+
|
3371
|
+
// Sume of squared error over positive matrix entries (i.e., those mf_node's
|
3372
|
+
// in data).
|
3373
|
+
positive_loss = positive_loss1;
|
3374
|
+
|
3375
|
+
// Sume of squared error over negative matrix entries (i.e., those mf_node's
|
3376
|
+
// in data). The value negative_loss4 contains the squared errors by
|
3377
|
+
// considering positive entries as negative entries, so positive_loss2 is
|
3378
|
+
// added to compensate that.
|
3379
|
+
negative_loss = negative_loss4 + positive_loss2;
|
3380
|
+
}
|
3381
|
+
|
3382
|
+
void ccd_one_class_core(
|
3383
|
+
const Utility &util,
|
3384
|
+
shared_ptr<const mf_problem> tr_csr,
|
3385
|
+
shared_ptr<const mf_problem> tr_csc,
|
3386
|
+
shared_ptr<const mf_problem> va,
|
3387
|
+
const mf_parameter param,
|
3388
|
+
const vector<mf_node*> &ptrs_u,
|
3389
|
+
const vector<mf_node*> &ptrs_v,
|
3390
|
+
/*output*/ shared_ptr<mf_model> &model)
|
3391
|
+
{
|
3392
|
+
// Check problems stored in CSR and CSC formats
|
3393
|
+
if(tr_csr == nullptr) throw invalid_argument("CSR problem pointer is null.");
|
3394
|
+
if(tr_csc == nullptr) throw invalid_argument("CSC problem pointer is null.");
|
3395
|
+
|
3396
|
+
if(tr_csr->m != tr_csc->m)
|
3397
|
+
throw logic_error(
|
3398
|
+
"Row counts must be identical in CSR and CSC formats: " +
|
3399
|
+
to_string(tr_csr->m) + " != " + to_string(tr_csc->m));
|
3400
|
+
const mf_int m = tr_csr->m;
|
3401
|
+
|
3402
|
+
if(tr_csr->n != tr_csc->n)
|
3403
|
+
throw logic_error(
|
3404
|
+
"Column counts must be identical in CSR and CSC formats: " +
|
3405
|
+
to_string(tr_csr->n) + " != " + to_string(tr_csc->n));
|
3406
|
+
const mf_int n = tr_csr->n;
|
3407
|
+
|
3408
|
+
if(tr_csc->nnz != tr_csc->nnz)
|
3409
|
+
throw logic_error(
|
3410
|
+
"Numbers of data points must be identical in CSR and CSC formats: " +
|
3411
|
+
to_string(tr_csr->nnz) + " != " + to_string(tr_csc->nnz));
|
3412
|
+
const mf_long nnz = tr_csr->nnz;
|
3413
|
+
|
3414
|
+
// Check formulation parameters
|
3415
|
+
if(param.k <= 0)
|
3416
|
+
throw invalid_argument(
|
3417
|
+
"Latent dimension must be positive but got " +
|
3418
|
+
to_string(param.k));
|
3419
|
+
const mf_int d = param.k;
|
3420
|
+
|
3421
|
+
if(param.lambda_p1 != 0)
|
3422
|
+
throw invalid_argument(
|
3423
|
+
"P's L1-regularization coefficient must be zero but got " +
|
3424
|
+
to_string(param.lambda_p1));
|
3425
|
+
if(param.lambda_q1 != 0)
|
3426
|
+
throw invalid_argument(
|
3427
|
+
"Q's L1-regularization coefficient must be zero but got " +
|
3428
|
+
to_string(param.lambda_q1));
|
3429
|
+
|
3430
|
+
if(param.lambda_p2 <= 0)
|
3431
|
+
throw invalid_argument(
|
3432
|
+
"P's L2-regularization coefficient must be positive but got " +
|
3433
|
+
to_string(param.lambda_p2));
|
3434
|
+
if(param.lambda_q2 <= 0)
|
3435
|
+
throw invalid_argument(
|
3436
|
+
"Q's L2-regularization coefficient must be positive but got " +
|
3437
|
+
to_string(param.lambda_q2));
|
3438
|
+
|
3439
|
+
// REVIEW: It is not difficult to support non-negative matrix factorization
|
3440
|
+
// for coordinate descent method; we just need to project the updated value
|
3441
|
+
// back to the feasible region by using max(0, new_value) right after each
|
3442
|
+
// Newton step. LIBMF hasn't support it only because we don't see actual
|
3443
|
+
// users.
|
3444
|
+
if(param.do_nmf)
|
3445
|
+
throw invalid_argument(
|
3446
|
+
"Coordinate descent does not support non-negative constraint");
|
3447
|
+
|
3448
|
+
|
3449
|
+
// Check some resources prepared internally
|
3450
|
+
if(ptrs_u.size() != (size_t)m + 1)
|
3451
|
+
throw invalid_argument("Number of row pointer must be " +
|
3452
|
+
to_string(m + 1) + " but got " + to_string(ptrs_u.size()));
|
3453
|
+
if(ptrs_v.size() != (size_t)n + 1)
|
3454
|
+
throw invalid_argument("Number of column pointer must be " +
|
3455
|
+
to_string(n + 1) + " but got " + to_string(ptrs_v.size()));
|
3456
|
+
|
3457
|
+
// Some constants of the formulation.
|
3458
|
+
// alpha: coefficient of negative part
|
3459
|
+
// c: the desired prediction values of unobserved ratings
|
3460
|
+
// lambda_p2: regularization coefficient of P's L2-norm
|
3461
|
+
// lambda_q2: regularization coefficient of P's Q2-norm
|
3462
|
+
const mf_float alpha = param.alpha;
|
3463
|
+
const mf_float c = param.c;
|
3464
|
+
const mf_float lambda_p2 = param.lambda_p2;
|
3465
|
+
const mf_float lambda_q2 = param.lambda_q2;
|
3466
|
+
|
3467
|
+
// Initialize P and Q. Note that \bar{q}_{kv} is Q[k*n+v]
|
3468
|
+
// and \bar{p}_{ku} is P[k*m+u]. One may notice that P and
|
3469
|
+
// Q here are actually the transposes of P and Q in FPSG.
|
3470
|
+
mf_float *P = model->P;
|
3471
|
+
mf_float *Q = model->Q;
|
3472
|
+
|
3473
|
+
// Cache the prediction values on positive matrix entries.
|
3474
|
+
// Given that P=zero and Q=random initialized in
|
3475
|
+
// Utility::init_model(mf_int m, mf_int n, mf_int k),
|
3476
|
+
// all predictions are zeros.
|
3477
|
+
#if defined USEOMP
|
3478
|
+
#pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
|
3479
|
+
#endif
|
3480
|
+
for(mf_long i = 0; i < nnz; ++i)
|
3481
|
+
{
|
3482
|
+
tr_csr->R[i].r = 0.0;
|
3483
|
+
tr_csc->R[i].r = 0.0;
|
3484
|
+
}
|
3485
|
+
|
3486
|
+
// If the model is not initialized by
|
3487
|
+
// Utility::init_model(mf_int m, mf_int n, mf_int k),
|
3488
|
+
// please use the following initialization code to compute
|
3489
|
+
// and cache all prediction values on positive entries.
|
3490
|
+
/*
|
3491
|
+
for(mf_long i = 0; i < nnz; ++i)
|
3492
|
+
{
|
3493
|
+
mf_node &node = tr_csr->R[i];
|
3494
|
+
node.r = 0;
|
3495
|
+
for(mf_int k = 0; k < d; ++k)
|
3496
|
+
node.r += P[node.u + k * m]*Q[node.v + k * n];
|
3497
|
+
}
|
3498
|
+
for(mf_long i = 0; i < nnz; ++i)
|
3499
|
+
{
|
3500
|
+
mf_node &node = tr_csc->R[i];
|
3501
|
+
node.r = 0;
|
3502
|
+
for(mf_int k = 0; k < d; ++k)
|
3503
|
+
node.r += P[node.u + k * m]*Q[node.v + k * n];
|
3504
|
+
}
|
3505
|
+
*/
|
3506
|
+
|
3507
|
+
if(!param.quiet)
|
3508
|
+
{
|
3509
|
+
cout.width(4);
|
3510
|
+
cout << "iter";
|
3511
|
+
cout.width(13);
|
3512
|
+
cout << "tr_"+util.get_error_legend();
|
3513
|
+
cout.width(14);
|
3514
|
+
cout << "tr_"+util.get_error_legend() << "+";
|
3515
|
+
cout.width(14);
|
3516
|
+
cout << "tr_"+util.get_error_legend() << "-";
|
3517
|
+
if(va->nnz != 0)
|
3518
|
+
{
|
3519
|
+
cout.width(13);
|
3520
|
+
cout << "va_"+util.get_error_legend();
|
3521
|
+
cout.width(14);
|
3522
|
+
cout << "va_"+util.get_error_legend() << "+";
|
3523
|
+
cout.width(14);
|
3524
|
+
cout << "va_"+util.get_error_legend() << "-";
|
3525
|
+
}
|
3526
|
+
cout.width(13);
|
3527
|
+
cout << "obj";
|
3528
|
+
cout << "\n";
|
3529
|
+
}
|
3530
|
+
|
3531
|
+
/////////////////////////////////////////////////////////////////
|
3532
|
+
// Minimize the objective function via coordinate descent method
|
3533
|
+
////////////////////////////////////////////////////////////////
|
3534
|
+
// Solve P and Q using coordinate descent.
|
3535
|
+
// P = [\bar{p}_1, ..., \bar{p}_d] \in R^{m \times k}
|
3536
|
+
// Q = [\bar{q}_1, ..., \bar{q}_d] \in R^{n \times k}
|
3537
|
+
// Finally, the rating matrice R would be approximated via
|
3538
|
+
// R ~ PQ^T \in R^{m \times n}
|
3539
|
+
for(mf_int outer = 0; outer < param.nr_iters; ++outer)
|
3540
|
+
{
|
3541
|
+
// Update \bar{p}_k and \bar{q}_k. The basic idea is
|
3542
|
+
// to replace \bar{p}_k and \bar{q}_k with a and b,
|
3543
|
+
// and then minimizes the original objective function.
|
3544
|
+
for(mf_int k = 0; k < d; ++k)
|
3545
|
+
{
|
3546
|
+
// Get the pointer to the first element of \bar{p}_k (and
|
3547
|
+
// \bar{q}_k).
|
3548
|
+
mf_float *P_k = P + m * k;
|
3549
|
+
mf_float *Q_k = Q + n * k;
|
3550
|
+
|
3551
|
+
// Initialize a and b with the value they need to replace
|
3552
|
+
// so that we can ensure improvement at each iteration.
|
3553
|
+
vector<mf_float> a(P_k, P_k + m);
|
3554
|
+
vector<mf_float> b(Q_k, Q_k + n);
|
3555
|
+
|
3556
|
+
for(mf_int inner = 0; inner < 3; ++inner)
|
3557
|
+
{
|
3558
|
+
///////////////////////////////////////////////////////////////
|
3559
|
+
// Update a:
|
3560
|
+
// 1. Compute and cache constants
|
3561
|
+
// 2. For each coordinate of a, calculate optimal update using
|
3562
|
+
// Newton method
|
3563
|
+
///////////////////////////////////////////////////////////////
|
3564
|
+
|
3565
|
+
// Compute and cache constants
|
3566
|
+
// \hat{b} = \sum_{v=1}^n \bar{b}_v
|
3567
|
+
// \tilde{b} = \sum_{v=1}^n \bar{b}_v^2
|
3568
|
+
mf_double b_hat = 0.0;
|
3569
|
+
mf_double b_tilde = 0.0;
|
3570
|
+
#if defined USEOMP
|
3571
|
+
#pragma omp parallel for num_threads(util.get_thread_number()) schedule(static) reduction(+:b_hat,b_tilde)
|
3572
|
+
#endif
|
3573
|
+
for(mf_int v = 0; v < n; ++v)
|
3574
|
+
{
|
3575
|
+
const mf_double &b_v = b[v];
|
3576
|
+
b_hat += b_v;
|
3577
|
+
b_tilde += b_v * b_v;
|
3578
|
+
}
|
3579
|
+
|
3580
|
+
// Compute and cache a constant vector
|
3581
|
+
// s_k = \sum_{v=1}^n \bar{q}_{kv}b_v, k = 1, ..., d
|
3582
|
+
vector<mf_double> s(d, 0.0);
|
3583
|
+
for(mf_int k1 = 0; k1 < d; ++k1)
|
3584
|
+
{
|
3585
|
+
// Buffer variable for using OpenMP
|
3586
|
+
mf_double s_k1 = 0;
|
3587
|
+
const mf_float *Q_k1 = Q + k1 * n;
|
3588
|
+
#if defined USEOMP
|
3589
|
+
#pragma omp parallel for num_threads(util.get_thread_number()) schedule(static) reduction(+:s_k1)
|
3590
|
+
#endif
|
3591
|
+
for(mf_int v = 0; v < n; ++v)
|
3592
|
+
s_k1 += Q_k1[v] * b[v];
|
3593
|
+
s[k1] = s_k1;
|
3594
|
+
}
|
3595
|
+
|
3596
|
+
// Solve a's sub-problem
|
3597
|
+
#if defined USEOMP
|
3598
|
+
#pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
|
3599
|
+
#endif
|
3600
|
+
for(mf_int u = 0; u < m; ++u)
|
3601
|
+
{
|
3602
|
+
////////////////////////////////////////////////////////
|
3603
|
+
// Update a[u] via Newton method. Let g_u and h_u denote
|
3604
|
+
// the first-order and second-order derivatives w.r.t.
|
3605
|
+
// a[u]. The following code implements
|
3606
|
+
// a[u] <-- a[u] - g_u/h_u
|
3607
|
+
////////////////////////////////////////////////////////
|
3608
|
+
|
3609
|
+
// Initialize temporal variables for calculating gradient and hessian.
|
3610
|
+
mf_double g_u_1 = 0.0;
|
3611
|
+
mf_double h_u_1 = 0.0;
|
3612
|
+
mf_double g_u_2 = 0.0;
|
3613
|
+
// Scan through specified entries at the u-th row
|
3614
|
+
for(const mf_node *ptr = ptrs_u[u]; ptr != ptrs_u[u+1]; ++ptr)
|
3615
|
+
{
|
3616
|
+
const mf_int &v = ptr->v;
|
3617
|
+
const mf_float &b_v = b[v];
|
3618
|
+
g_u_1 += b_v;
|
3619
|
+
h_u_1 += b_v * b_v;
|
3620
|
+
g_u_2 += (ptr->r - P_k[u] * Q_k[v] + a[u] * b_v) * b_v;
|
3621
|
+
}
|
3622
|
+
mf_double g_u_3 = -c * b_hat - P_k[u] * s[k] + a[u] * b_tilde;
|
3623
|
+
for(mf_int k1 = 0; k1 < d; ++k1)
|
3624
|
+
g_u_3 += P[m * k1 + u] * s[k1];
|
3625
|
+
mf_double g_u = -(1.0 - alpha * c) * g_u_1 + (1.0 - alpha) * g_u_2 + alpha * g_u_3 + lambda_p2 * a[u];
|
3626
|
+
mf_double h_u = (1.0 - alpha) * h_u_1 + alpha * b_tilde + lambda_p2;
|
3627
|
+
a[u] -= static_cast<mf_float>(g_u / h_u);
|
3628
|
+
}
|
3629
|
+
|
3630
|
+
///////////////////////////////////////////////////////////////
|
3631
|
+
// Update b:
|
3632
|
+
// 1. Compute and cache constants
|
3633
|
+
// 2. For each coordinate of b, calculate optimal update using
|
3634
|
+
// Newton method
|
3635
|
+
///////////////////////////////////////////////////////////////
|
3636
|
+
// Compute and cache a_hat, a_tilde
|
3637
|
+
// \hat{a} = \sum_{u=1}^m \bar{a}_u
|
3638
|
+
// \tilde{a} = \sum_{u=1}^m \bar{a}_u^2
|
3639
|
+
mf_double a_hat = 0.0;
|
3640
|
+
mf_double a_tilde = 0.0;
|
3641
|
+
#if defined USEOMP
|
3642
|
+
#pragma omp parallel for num_threads(util.get_thread_number()) schedule(static) reduction(+:a_hat,a_tilde)
|
3643
|
+
#endif
|
3644
|
+
for(mf_int u = 0; u < m; ++u)
|
3645
|
+
{
|
3646
|
+
const mf_float &a_u = a[u];
|
3647
|
+
a_hat += a_u;
|
3648
|
+
a_tilde += a_u * a_u;
|
3649
|
+
}
|
3650
|
+
|
3651
|
+
// Compute and cache t
|
3652
|
+
// t_k = \sum_{u=1}^m \bar{a}_{ku}a_u, k = 1, ..., d
|
3653
|
+
vector<mf_double> t(d, 0.0);
|
3654
|
+
for(mf_int k1 = 0; k1 < d; ++k1)
|
3655
|
+
{
|
3656
|
+
// Declare buffer variable for using OpenMP
|
3657
|
+
mf_double t_k1 = 0;
|
3658
|
+
const mf_float *P_k1 = P + k1 * m;
|
3659
|
+
#if defined USEOMP
|
3660
|
+
#pragma omp parallel for num_threads(util.get_thread_number()) schedule(static) reduction(+:t_k1)
|
3661
|
+
#endif
|
3662
|
+
for(mf_int u = 0; u < m; ++u)
|
3663
|
+
t_k1 += P_k1[u] * a[u];
|
3664
|
+
t[k1] = t_k1;
|
3665
|
+
}
|
3666
|
+
|
3667
|
+
#if defined USEOMP
|
3668
|
+
#pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
|
3669
|
+
#endif
|
3670
|
+
for(mf_int v = 0; v < n; ++v)
|
3671
|
+
{
|
3672
|
+
////////////////////////////////////////////////////////
|
3673
|
+
// Update b[v] via Newton method. Let g_v and h_v denote
|
3674
|
+
// the first-order and second-order derivatives w.r.t.
|
3675
|
+
// b[v]. The following code implements
|
3676
|
+
// b[v] <-- b[v] - g_v/h_v
|
3677
|
+
////////////////////////////////////////////////////////
|
3678
|
+
|
3679
|
+
// Initialize temporal variables for calculating gradient and hessian.
|
3680
|
+
mf_double g_v_1 = 0;
|
3681
|
+
mf_double g_v_2 = 0;
|
3682
|
+
mf_double h_v_1 = 0;
|
3683
|
+
// Scan through all positive entries at column v
|
3684
|
+
for(const mf_node *ptr = ptrs_v[v]; ptr != ptrs_v[v+1]; ++ptr)
|
3685
|
+
{
|
3686
|
+
const mf_int &u = ptr->u;
|
3687
|
+
const mf_float &a_u = a[u];
|
3688
|
+
g_v_1 += a_u;
|
3689
|
+
h_v_1 += a_u * a_u;
|
3690
|
+
g_v_2 += (ptr->r - P_k[u] * Q_k[v] + a_u * b[v]) * a_u;
|
3691
|
+
}
|
3692
|
+
mf_double g_v_3 = -c * a_hat - Q_k[v] * t[k] + b[v] * a_tilde;
|
3693
|
+
for(mf_int k1 = 0; k1 < d; ++k1)
|
3694
|
+
g_v_3 += Q[n * k1 + v] * t[k1];
|
3695
|
+
mf_double g_v = -(1.0 - alpha * c) * g_v_1 + (1.0 - alpha) * g_v_2 +
|
3696
|
+
alpha * g_v_3 + lambda_q2 * b[v];
|
3697
|
+
mf_double h_v = (1 - alpha) * h_v_1 + alpha * a_tilde + lambda_q2;
|
3698
|
+
b[v] -= static_cast<mf_float>(g_v / h_v);
|
3699
|
+
}
|
3700
|
+
|
3701
|
+
///////////////////////////////////////////////////////////////
|
3702
|
+
// Update cached variables.
|
3703
|
+
///////////////////////////////////////////////////////////////
|
3704
|
+
// Update prediction error in CSR format
|
3705
|
+
// \bar{r}_{uv} <- \bar{r}_{uv} - \bar_{p}_{ku}*\bar_{q}_{kv} + a_u*b_v
|
3706
|
+
#if defined USEOMP
|
3707
|
+
#pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
|
3708
|
+
#endif
|
3709
|
+
for(mf_long i = 0; i < tr_csr->nnz; ++i)
|
3710
|
+
{
|
3711
|
+
// Update prediction values of positive entries in CSR
|
3712
|
+
mf_node *csr_ptr = tr_csr->R + i;
|
3713
|
+
const mf_int &u_csr = csr_ptr->u;
|
3714
|
+
const mf_int &v_csr = csr_ptr->v;
|
3715
|
+
csr_ptr->r += a[u_csr] * b[v_csr] - P_k[u_csr] * Q_k[v_csr];
|
3716
|
+
|
3717
|
+
// Update prediction values of positive entries in CSC
|
3718
|
+
mf_node *csc_ptr = tr_csc->R + i;
|
3719
|
+
const mf_int &u_csc = csc_ptr->u;
|
3720
|
+
const mf_int &v_csc = csc_ptr->v;
|
3721
|
+
csc_ptr->r += a[u_csc] * b[v_csc] - P_k[u_csc] * Q_k[v_csc];
|
3722
|
+
}
|
3723
|
+
|
3724
|
+
#if defined USEOMP
|
3725
|
+
#pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
|
3726
|
+
#endif
|
3727
|
+
// Update P_k and Q_k
|
3728
|
+
for(mf_int u = 0; u < m; ++u)
|
3729
|
+
P_k[u] = a[u];
|
3730
|
+
#if defined USEOMP
|
3731
|
+
#pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
|
3732
|
+
#endif
|
3733
|
+
for(mf_int v = 0; v < n; ++v)
|
3734
|
+
Q_k[v] = b[v];
|
3735
|
+
}
|
3736
|
+
}
|
3737
|
+
|
3738
|
+
// Skip the whole evaluation if nothing should be printed out.
|
3739
|
+
if(param.quiet)
|
3740
|
+
continue;
|
3741
|
+
|
3742
|
+
// Declare variable for storing objective value being minimized
|
3743
|
+
// by the training procedure. Note that The objective value consists
|
3744
|
+
// of two parts, loss function and regularization function.
|
3745
|
+
mf_double obj = 0;
|
3746
|
+
// Declare variables for storing loss function's value.
|
3747
|
+
mf_double positive_loss = 0; // for positive entries in training matrix.
|
3748
|
+
mf_double negative_loss = 0; // for negative entries in training matrix.
|
3749
|
+
// Declare variable for storing regularization function's value.
|
3750
|
+
mf_double reg = 0;
|
3751
|
+
|
3752
|
+
// Compute objective value, loss function value, and regularization
|
3753
|
+
// function value
|
3754
|
+
calc_ccd_one_class_obj(util.get_thread_number(), alpha, c, m, n, d,
|
3755
|
+
lambda_p2, lambda_q2, P, Q, tr_csr,
|
3756
|
+
obj, positive_loss, negative_loss, reg);
|
3757
|
+
|
3758
|
+
// Print number of outer iterations.
|
3759
|
+
cout.width(4);
|
3760
|
+
cout << outer;
|
3761
|
+
cout.width(13);
|
3762
|
+
cout << fixed << setprecision(4) << positive_loss + negative_loss;
|
3763
|
+
cout.width(15);
|
3764
|
+
cout << fixed << setprecision(4) << positive_loss;
|
3765
|
+
cout.width(15);
|
3766
|
+
cout << fixed << setprecision(4) << negative_loss;
|
3767
|
+
|
3768
|
+
if(va->nnz != 0)
|
3769
|
+
{
|
3770
|
+
// The following loop computes prediction scores on validation set.
|
3771
|
+
// Because training scores is also maintained in coordinate descent
|
3772
|
+
// framework, we didn't need to actively compute scores on training set.
|
3773
|
+
#if defined USEOMP
|
3774
|
+
#pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
|
3775
|
+
#endif
|
3776
|
+
for(mf_long i = 0; i < va->nnz; ++i)
|
3777
|
+
{
|
3778
|
+
mf_node &node = va->R[i];
|
3779
|
+
node.r = 0;
|
3780
|
+
for(mf_int k = 0; k < d; ++k)
|
3781
|
+
node.r += P[node.u + k * m]*Q[node.v + k * n];
|
3782
|
+
}
|
3783
|
+
|
3784
|
+
mf_double va_obj = 0;
|
3785
|
+
mf_double va_positive_loss = 0;
|
3786
|
+
mf_double va_negative_loss = 0;
|
3787
|
+
mf_double va_reg = 0;
|
3788
|
+
|
3789
|
+
calc_ccd_one_class_obj(util.get_thread_number(), alpha, c, m, n, d,
|
3790
|
+
lambda_p2, lambda_q2, P, Q, va,
|
3791
|
+
va_obj, va_positive_loss, va_negative_loss, va_reg);
|
3792
|
+
|
3793
|
+
cout.width(13);
|
3794
|
+
cout << fixed << setprecision(4) << va_positive_loss + va_negative_loss;
|
3795
|
+
cout.width(15);
|
3796
|
+
cout << fixed << setprecision(4) << va_positive_loss;
|
3797
|
+
cout.width(15);
|
3798
|
+
cout << fixed << setprecision(4) << va_negative_loss;
|
3799
|
+
}
|
3800
|
+
|
3801
|
+
cout.width(13);
|
3802
|
+
cout << fixed << setprecision(4) << scientific << obj;
|
3803
|
+
cout << "\n" << flush;
|
3804
|
+
}
|
3805
|
+
|
3806
|
+
// Transpose P and Q. Note that the format of P and Q here are different
|
3807
|
+
// than that for mf_model.
|
3808
|
+
|
3809
|
+
mf_float *P_transpose = Utility::malloc_aligned_float((mf_long)m * d);
|
3810
|
+
#if defined USEOMP
|
3811
|
+
#pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
|
3812
|
+
#endif
|
3813
|
+
for(mf_int u = 0; u < m; ++u)
|
3814
|
+
for(mf_int k = 0; k < d; ++k)
|
3815
|
+
P_transpose[k + u * d] = P[u + k * m];
|
3816
|
+
Utility::free_aligned_float(P);
|
3817
|
+
mf_float *Q_transpose = Utility::malloc_aligned_float((mf_long)n * d);
|
3818
|
+
#if defined USEOMP
|
3819
|
+
#pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
|
3820
|
+
#endif
|
3821
|
+
for(mf_int v = 0; v < n; ++v)
|
3822
|
+
for(mf_int k = 0; k < d; ++k)
|
3823
|
+
Q_transpose[k + v * d] = Q[v + k * n];
|
3824
|
+
Utility::free_aligned_float(Q);
|
3825
|
+
|
3826
|
+
// Set the passed-in model to the result learned from the given data
|
3827
|
+
// model is null
|
3828
|
+
model->m = m;
|
3829
|
+
model->n = n;
|
3830
|
+
model->k = d;
|
3831
|
+
model->b = 0.0;
|
3832
|
+
model->P = P_transpose;
|
3833
|
+
model->Q = Q_transpose;
|
3834
|
+
}
|
3835
|
+
|
3836
|
+
// Trains a one-class MF model with L2-loss (-f 12) by coordinate descent.
// Builds a row-sorted (CSR) and a column-sorted (CSC) view of the training
// nodes together with per-row / per-column start pointers, then hands them
// to ccd_one_class_core. Any exception is reported on stderr and rethrown.
shared_ptr<mf_model> ccd_one_class(
    mf_problem const *tr_,
    mf_problem const *va_,
    mf_parameter param)
{
    shared_ptr<mf_model> model;
    try
    {
        Utility util(param.fun, param.nr_threads);
        // Training matrix in compressed row format (nodes sorted by user id)
        shared_ptr<mf_problem> tr_csr;
        // Training matrix in compressed column format (nodes sorted by item id)
        shared_ptr<mf_problem> tr_csc;
        shared_ptr<mf_problem> va;
        // In tr_csr->R, row u starts at ptrs_u[u] and ends right before
        // ptrs_u[u+1].
        vector<mf_node*> ptrs_u(tr_->m + 1, nullptr);
        // In tr_csc->R, column v starts at ptrs_v[v] and ends right before
        // ptrs_v[v+1].
        vector<mf_node*> ptrs_v(tr_->n + 1, nullptr);

        if(param.copy_data)
        {
            // Need a row-major and a column-major training format.
            // Thus, two duplicates are made.
            tr_csr = shared_ptr<mf_problem>(
                Utility::copy_problem(tr_, true), deleter());
            tr_csc = shared_ptr<mf_problem>(
                Utility::copy_problem(tr_, true), deleter());
            va = shared_ptr<mf_problem>(
                Utility::copy_problem(va_, true), deleter());
        }
        else
        {
            // Need a row-major and a column-major training format.
            // The original data is reused as the row-major one (so the sort
            // below reorders the caller's nodes) and only one duplicate is
            // created for the column-major one.
            tr_csr = shared_ptr<mf_problem>(Utility::copy_problem(tr_, false));
            // This is a duplicating copy, exactly like the ones in the
            // copy_data branch, so it needs the same deleter() to release its
            // node array; the default deleter would only free the mf_problem
            // struct and leak the copied nodes.
            tr_csc = shared_ptr<mf_problem>(
                Utility::copy_problem(tr_, true), deleter());
            va = shared_ptr<mf_problem>(Utility::copy_problem(va_, false));
        }

        // Make the training set CSR/CSC by sorting its nodes: sorted by row
        // index it is CSR, sorted by column index it is CSC. The starting
        // location of each row (CSR) and each column (CSC) is computed next.
        sort(tr_csr->R, tr_csr->R+tr_csr->nnz, sort_node_by_p());
        sort(tr_csc->R, tr_csc->R+tr_csc->nnz, sort_node_by_q());

        // Save starting addresses of rows for CSR and columns for CSC.
        // A node is a tuple (u, v, r): row index u, column index v, value r.
        // If tr_csr->R holds
        //   (0, 1, 0.5), (0, 2, 3.7), (0, 4, -1.2), (2, 0, 1.2), (2, 4, 2.5)
        // then row 2 starts at (2, 0, 1.2), the 4th element. An empty row
        // gets the start of the next non-empty row, and rows after the last
        // non-empty one get the end of the array, so for m = 3
        //   ptrs_u[0] = pointer of (0, 1, 0.5)
        //   ptrs_u[1] = pointer of (2, 0, 1.2)   (row 1 is empty)
        //   ptrs_u[2] = pointer of (2, 0, 1.2)
        //   ptrs_u[3] = tr_csr->R + 5            (one past the last node)
        mf_int u_current = -1;
        mf_int v_current = -1;
        for(mf_long i = 0; i < tr_->nnz; ++i)
        {
            // CSR: tr_csr is sorted by u, so a larger u starts a new row.
            // Every row index in (u_current, N->u] starts at node i.
            mf_node *N = tr_csr->R + i;
            if(N->u > u_current)
            {
                for(mf_int u_passed = u_current + 1; u_passed <= N->u; ++u_passed)
                    ptrs_u[u_passed] = tr_csr->R + i;
                u_current = N->u;
            }

            // CSC: the same for columns of tr_csc, which is sorted by v.
            // Every column index in (v_current, N->v] starts at node i.
            N = tr_csc->R + i;
            if(N->v > v_current)
            {
                for(mf_int v_passed = v_current + 1; v_passed <= N->v; ++v_passed)
                    ptrs_v[v_passed] = tr_csc->R + i;
                v_current = N->v;
            }
        }
        // Rows/columns after the last non-empty one, and the closing bounds
        // ptrs_u[m] / ptrs_v[n], point one past the last node.
        for(mf_int u_passed = u_current + 1; u_passed <= tr_->m; ++u_passed)
            ptrs_u[u_passed] = tr_csr->R + tr_csr->nnz;
        for(mf_int v_passed = v_current + 1; v_passed <= tr_->n; ++v_passed)
            ptrs_v[v_passed] = tr_csc->R + tr_csc->nnz;

        model = shared_ptr<mf_model>(Utility::init_model(tr_->m, tr_->n, param.k),
            [] (mf_model *ptr) { mf_destroy_model(&ptr); });

        ccd_one_class_core(util, tr_csr, tr_csc, va, param, ptrs_u, ptrs_v, model);

        // ccd_one_class_core assigns m, n, k, b, P and Q when it finishes but
        // not the loss type; record it so the returned (and saved) model
        // states how it was trained.
        // NOTE(review): init_model's default for fun is not visible here.
        model->fun = param.fun;
    }
    catch(exception const &e)
    {
        cerr << e.what() << endl;
        throw;
    }
    return model;
}
|
3958
|
+
|
3959
|
+
// Validates user-supplied training parameters. On the first invalid setting
// it prints a message to stderr and returns false; otherwise it returns
// true. A small block count (nr_bins <= 2*nr_threads) only triggers a
// warning, printed after all hard checks have passed.
bool check_parameter(mf_parameter param)
{
    if(param.fun != P_L2_MFR &&
       param.fun != P_L1_MFR &&
       param.fun != P_KL_MFR &&
       param.fun != P_LR_MFC &&
       param.fun != P_L2_MFC &&
       param.fun != P_L1_MFC &&
       param.fun != P_ROW_BPR_MFOC &&
       param.fun != P_COL_BPR_MFOC &&
       param.fun != P_L2_MFOC)
    {
        cerr << "unknown loss function" << endl;
        return false;
    }

    if(param.k < 1)
    {
        cerr << "number of factors must be greater than zero" << endl;
        return false;
    }

    if(param.nr_threads < 1)
    {
        cerr << "number of threads must be greater than zero" << endl;
        return false;
    }

    if(param.nr_bins < 1 || param.nr_bins < param.nr_threads)
    {
        cerr << "number of bins must be greater than number of threads"
             << endl;
        return false;
    }

    if(param.nr_iters < 1)
    {
        cerr << "number of iterations must be greater than zero" << endl;
        return false;
    }

    if(param.lambda_p1 < 0 ||
       param.lambda_p2 < 0 ||
       param.lambda_q1 < 0 ||
       param.lambda_q2 < 0)
    {
        cerr << "regularization coefficient must be non-negative" << endl;
        return false;
    }

    if(param.eta <= 0)
    {
        cerr << "learning rate must be greater than zero" << endl;
        return false;
    }

    if(param.fun == P_KL_MFR && !param.do_nmf)
    {
        cerr << "--nmf must be set when using generalized KL-divergence"
             << endl;
        return false;
    }

    // alpha weights the loss on unobserved entries in one-class CCD; a
    // negative weight is reported as an error, so it must also be rejected.
    if(param.alpha < 0)
    {
        cerr << "alpha must be a non-negative number" << endl;
        return false;
    }

    // Not an error: training still works, it may just be slower.
    if(param.nr_bins <= 2*param.nr_threads)
    {
        cerr << "Warning: insufficient blocks may slow down the training "
             << "process (4*nr_threads^2+1 blocks is suggested)" << endl;
    }

    return true;
}
|
4041
|
+
|
4042
|
+
//--------------------------------------
|
4043
|
+
//-----Classes for cross validation-----
|
4044
|
+
//--------------------------------------
|
4045
|
+
|
4046
|
+
class CrossValidatorBase
|
4047
|
+
{
|
4048
|
+
public:
|
4049
|
+
CrossValidatorBase(mf_parameter param_, mf_int nr_folds_);
|
4050
|
+
mf_double do_cross_validation();
|
4051
|
+
virtual mf_double do_cv1(vector<mf_int> &hidden_blocks) = 0;
|
4052
|
+
protected:
|
4053
|
+
mf_parameter param;
|
4054
|
+
mf_int nr_bins;
|
4055
|
+
mf_int nr_folds;
|
4056
|
+
mf_int nr_blocks_per_fold;
|
4057
|
+
bool quiet;
|
4058
|
+
Utility util;
|
4059
|
+
mf_double cv_error;
|
4060
|
+
};
|
4061
|
+
|
4062
|
+
// Copies the parameters, remembers the caller's quiet flag for the summary
// table, and silences the individual training runs. A non-positive
// nr_folds_ yields nr_blocks_per_fold == 0 instead of dividing by zero.
CrossValidatorBase::CrossValidatorBase(mf_parameter param_, mf_int nr_folds_)
    : param(param_), nr_bins(param_.nr_bins), nr_folds(nr_folds_),
      nr_blocks_per_fold(nr_folds_ > 0 ? nr_bins*nr_bins/nr_folds_ : 0),
      quiet(param_.quiet), util(param.fun, param.nr_threads), cv_error(0)
{
    // Per-fold training must not print its own progress log.
    param.quiet = true;
}
|
4069
|
+
|
4070
|
+
// Runs k-fold cross validation over the nr_bins x nr_bins blocks and
// returns the mean of the per-fold errors reported by do_cv1. Unless quiet,
// prints one line per fold followed by the average.
mf_double CrossValidatorBase::do_cross_validation()
{
    mf_int const nr_blocks = nr_bins*nr_bins;

    // Block ids 0..nr_blocks-1 in a reproducible random order. std::shuffle
    // with a fixed-seed mt19937 is used because random_shuffle was removed in
    // C++17 and its rand()-based order differs between C libraries.
    vector<mf_int> cv_blocks(nr_blocks);
    iota(cv_blocks.begin(), cv_blocks.end(), 0);
    mt19937 generator(0);
    shuffle(cv_blocks.begin(), cv_blocks.end(), generator);
    // NOTE(review): the C RNG is still reset here in case code reached
    // through do_cv1 draws from rand(); the shuffle above no longer uses it.
    srand(0);

    if(!quiet)
    {
        cout.width(4);
        cout << "fold";
        cout.width(10);
        cout << util.get_error_legend();
        cout << endl;
    }

    cv_error = 0;

    for(mf_int fold = 0; fold < nr_folds; ++fold)
    {
        // Fold f hides positions [f*N/F, (f+1)*N/F) of the shuffled order.
        // When N is a multiple of F this equals a stride of
        // nr_blocks_per_fold; otherwise the remainder is spread across folds
        // so every block is hidden exactly once (a fixed stride would never
        // validate the last N % F blocks). mf_long avoids overflow in the
        // products.
        mf_int const begin = (mf_int)((mf_long)fold*nr_blocks/nr_folds);
        mf_int const end = (mf_int)((mf_long)(fold+1)*nr_blocks/nr_folds);
        vector<mf_int> hidden_blocks(cv_blocks.begin()+begin,
                                     cv_blocks.begin()+end);

        mf_double err = do_cv1(hidden_blocks);
        cv_error += err;

        if(!quiet)
        {
            cout.width(4);
            cout << fold;
            cout.width(10);
            cout << fixed << setprecision(4) << err;
            cout << endl;
        }
    }

    if(!quiet)
    {
        cout.width(14);
        cout.fill('=');
        cout << "" << endl;
        cout.fill(' ');
        cout.width(4);
        cout << "avg";
        cout.width(10);
        cout << fixed << setprecision(4) << cv_error/nr_folds;
        cout << endl;
    }

    return cv_error/nr_folds;
}
|
4124
|
+
|
4125
|
+
class CrossValidator : public CrossValidatorBase
|
4126
|
+
{
|
4127
|
+
public:
|
4128
|
+
CrossValidator(
|
4129
|
+
mf_parameter param_, mf_int nr_folds_, mf_problem const *prob_)
|
4130
|
+
: CrossValidatorBase(param_, nr_folds_), prob(prob_) {};
|
4131
|
+
mf_double do_cv1(vector<mf_int> &hidden_blocks);
|
4132
|
+
private:
|
4133
|
+
mf_problem const *prob;
|
4134
|
+
};
|
4135
|
+
|
4136
|
+
// Trains on the in-memory problem with no separate validation set while
// holding out hidden_blocks; fpsg writes the resulting error through the
// out-pointer, and that value is returned as this fold's error.
mf_double CrossValidator::do_cv1(vector<mf_int> &hidden_blocks)
{
    mf_double fold_error = 0;
    mf_problem const *const no_validation = nullptr;
    fpsg(prob, no_validation, param, hidden_blocks, &fold_error);
    return fold_error;
}
|
4142
|
+
|
4143
|
+
class CrossValidatorOnDisk : public CrossValidatorBase
|
4144
|
+
{
|
4145
|
+
public:
|
4146
|
+
CrossValidatorOnDisk(
|
4147
|
+
mf_parameter param_, mf_int nr_folds_, string data_path_)
|
4148
|
+
: CrossValidatorBase(param_, nr_folds_), data_path(data_path_) {};
|
4149
|
+
mf_double do_cv1(vector<mf_int> &hidden_blocks);
|
4150
|
+
private:
|
4151
|
+
string data_path;
|
4152
|
+
};
|
4153
|
+
|
4154
|
+
// Trains from the on-disk data set (empty validation path) while holding out
// hidden_blocks; fpsg_on_disk writes the resulting error through the
// out-pointer, and that value is returned as this fold's error.
mf_double CrossValidatorOnDisk::do_cv1(vector<mf_int> &hidden_blocks)
{
    mf_double fold_error = 0;
    fpsg_on_disk(data_path, string(), param, hidden_blocks, &fold_error);
    return fold_error;
}
|
4160
|
+
|
4161
|
+
} // unnamed namespace
|
4162
|
+
|
4163
|
+
// Trains a model on `tr`, optionally reporting progress on `va`, and returns
// a heap-allocated mf_model owned by the caller (nullptr when the
// parameters are rejected). One-class L2-loss (P_L2_MFOC) is solved by
// coordinate descent; every other loss uses stochastic gradient (fpsg).
mf_model* mf_train_with_validation(
    mf_problem const *tr,
    mf_problem const *va,
    mf_parameter param)
{
    if(!check_parameter(param))
        return nullptr;

    shared_ptr<mf_model> trained;
    if(param.fun == P_L2_MFOC)
        trained = ccd_one_class(tr, va, param);   // coordinate descent
    else
        trained = fpsg(tr, va, param);            // stochastic gradient

    mf_model *result = new mf_model;
    result->fun = trained->fun;
    result->m = trained->m;
    result->n = trained->n;
    result->k = trained->k;
    result->b = trained->b;

    // Hand the factor matrices to `result` and clear them in `trained`, so
    // releasing `trained` at scope exit leaves them intact.
    result->P = trained->P;
    result->Q = trained->Q;
    trained->P = nullptr;
    trained->Q = nullptr;

    return result;
}
|
4196
|
+
|
4197
|
+
// Trains a model from the file at tr_path, optionally reporting progress on
// the file at va_path, and returns a heap-allocated mf_model owned by the
// caller. Returns nullptr when:
//   - any parameter is outside its supported range,
//   - one-class matrix factorization with L2-loss (-f 12) is requested,
//     which has no disk-level training, or
//   - tr_path is null.
// A null va_path is treated like "" (no validation file), which is what
// mf_train_on_disk passes, instead of constructing std::string from nullptr.
mf_model* mf_train_with_validation_on_disk(
    char const *tr_path,
    char const *va_path,
    mf_parameter param)
{
    if(!check_parameter(param) || param.fun == P_L2_MFOC)
        return nullptr;

    if(tr_path == nullptr)
        return nullptr;

    shared_ptr<mf_model> model = fpsg_on_disk(
        string(tr_path), string(va_path != nullptr ? va_path : ""), param);

    mf_model *model_ret = new mf_model;

    model_ret->fun = model->fun;
    model_ret->m = model->m;
    model_ret->n = model->n;
    model_ret->k = model->k;
    model_ret->b = model->b;

    // Move ownership of the factor matrices into model_ret so they survive
    // the release of `model`.
    model_ret->P = model->P;
    model->P = nullptr;

    model_ret->Q = model->Q;
    model->Q = nullptr;

    return model_ret;
}
|
4227
|
+
|
4228
|
+
// Trains a model on `prob` without a validation set; see
// mf_train_with_validation for the ownership of the returned model.
mf_model* mf_train(mf_problem const *prob, mf_parameter param)
{
    mf_problem const *const no_validation = nullptr;
    return mf_train_with_validation(prob, no_validation, param);
}
|
4232
|
+
|
4233
|
+
// Trains a model from the file at tr_path without a validation file; an
// empty validation path stands for "no validation data".
mf_model* mf_train_on_disk(char const *tr_path, mf_parameter param)
{
    char const *const no_validation_path = "";
    return mf_train_with_validation_on_disk(tr_path, no_validation_path, param);
}
|
4237
|
+
|
4238
|
+
// Runs nr_folds-fold cross validation on the in-memory problem `prob` and
// returns the average per-fold error. Returns 0 without training when:
//   - any parameter is outside its supported range,
//   - one-class matrix factorization with L2-loss (-f 12) is requested,
//     which is not supported here, or
//   - nr_folds is not positive.
mf_double mf_cross_validation(
    mf_problem const *prob,
    mf_int nr_folds,
    mf_parameter param)
{
    if(!check_parameter(param) || param.fun == P_L2_MFOC)
        return 0;

    if(nr_folds < 1)
    {
        cerr << "number of folds must be greater than zero" << endl;
        return 0;
    }

    CrossValidator validator(param, nr_folds, prob);

    return validator.do_cross_validation();
}
|
4253
|
+
|
4254
|
+
// Runs nr_folds-fold cross validation on the training file at `prob` and
// returns the average per-fold error. Returns 0 without training when:
//   - any parameter is outside its supported range,
//   - one-class matrix factorization with L2-loss (-f 12) is requested,
//     which has no disk-level training,
//   - the path is null, or
//   - nr_folds is not positive.
mf_double mf_cross_validation_on_disk(
    char const *prob,
    mf_int nr_folds,
    mf_parameter param)
{
    if(!check_parameter(param) || param.fun == P_L2_MFOC)
        return 0;

    // std::string cannot be constructed from a null pointer.
    if(prob == nullptr)
        return 0;

    if(nr_folds < 1)
    {
        cerr << "number of folds must be greater than zero" << endl;
        return 0;
    }

    CrossValidatorOnDisk validator(param, nr_folds, string(prob));

    return validator.do_cross_validation();
}
|
4269
|
+
|
4270
|
+
// Reads a sparse matrix stored as whitespace-separated "u v r" triples
// (0-based row index, 0-based column index, value). m and n are one more
// than the largest row / column index seen. Returns an empty problem
// (R == nullptr) for an empty path or an unopenable file. R is allocated
// with new[] and owned by the caller.
mf_problem read_problem(string path)
{
    mf_problem prob;
    prob.m = 0;
    prob.n = 0;
    prob.nnz = 0;
    prob.R = nullptr;

    if(path.empty())
        return prob;

    ifstream f(path);
    if(!f.is_open())
        return prob;

    // First pass: count triples with exactly the parsing the second pass
    // uses. A line count would be wrong in both directions:
    //   - blank or malformed lines inflate it, leaving uninitialized nodes
    //     inside nnz;
    //   - a line holding several triples undercounts it and overflows R.
    mf_long capacity = 0;
    for(mf_node N; f >> N.u >> N.v >> N.r;)
        ++capacity;

    mf_node *R = new mf_node[static_cast<size_t>(capacity)];

    f.close();
    f.open(path);
    f.clear();

    // Second pass: fill R, never writing past the buffer.
    mf_long idx = 0;
    for(mf_node N; idx < capacity && f >> N.u >> N.v >> N.r;)
    {
        if(N.u+1 > prob.m)
            prob.m = N.u+1;
        if(N.v+1 > prob.n)
            prob.n = N.v+1;
        R[idx] = N;
        ++idx;
    }
    // idx < capacity only if the file changed between the passes; expose
    // only the nodes actually filled.
    prob.nnz = idx;
    prob.R = R;

    f.close();

    return prob;
}
|
4310
|
+
|
4311
|
+
// Writes `model` to the text file at `path`.
// Header: one "<tag> <value>" line each for fun (f), m, n, k and b.
// Body: one line per row of P (prefix 'p') and Q (prefix 'q').
// Returns 0 on success and 1 when:
//   - the model or path is null,
//   - the file cannot be opened, or
//   - any write (including the final flush on close) fails.
mf_int mf_save_model(mf_model const *model, char const *path)
{
    if(model == nullptr || path == nullptr)
        return 1;

    ofstream f(path);
    if(!f.is_open())
        return 1;

    // The default stream precision (6 significant digits) rounds the learned
    // values; max_digits10 digits let each mf_float be read back unchanged.
    f << setprecision(numeric_limits<mf_float>::max_digits10);

    f << "f " << model->fun << "\n";
    f << "m " << model->m << "\n";
    f << "n " << model->n << "\n";
    f << "k " << model->k << "\n";
    f << "b " << model->b << "\n";

    // Writes `size` rows of model->k values, one row per line:
    //   <prefix><index> T v_0 v_1 ... v_{k-1}
    // A row whose first value is NaN is written as "F" followed by k zeros.
    auto write = [&] (mf_float *ptr, mf_int size, char prefix)
    {
        for(mf_int i = 0; i < size; ++i)
        {
            mf_float *ptr1 = ptr + (mf_long)i*model->k;
            f << prefix << i << " ";
            if(isnan(ptr1[0]))
            {
                f << "F ";
                for(mf_int d = 0; d < model->k; ++d)
                    f << 0 << " ";
            }
            else
            {
                f << "T ";
                for(mf_int d = 0; d < model->k; ++d)
                    f << ptr1[d] << " ";
            }
            // '\n' instead of endl: one flush at close instead of per row.
            f << "\n";
        }
    };

    write(model->P, model->m, 'p');
    write(model->Q, model->n, 'q');

    f.close();

    // close() flushes; a failed write (e.g. a full disk) sets failbit.
    return f.fail() ? 1 : 0;
}
|
4353
|
+
|
4354
|
+
// Reads a model in the text format written by mf_save_model from `path`.
//
// Returns a heap-allocated model that the caller releases with
// mf_destroy_model. Returns nullptr if:
//   * the file cannot be opened,
//   * the header cannot be parsed or has a negative dimension,
//   * allocating P/Q fails (bad_alloc is reported on cerr), or
//   * the factor section ends or fails to parse before all m + n rows are
//     read (previously this returned a model with uninitialized factors).
// Rows marked "F" are restored as all-NaN vectors.
mf_model* mf_load_model(char const *path)
{
    ifstream f(path);
    if(!f.is_open())
        return nullptr;

    string dummy;

    // Initialize every field so that a failed header read cannot leave
    // indeterminate sizes behind, and mf_destroy_model is always safe.
    mf_model *model = new mf_model;
    model->fun = 0;
    model->m = 0;
    model->n = 0;
    model->k = 0;
    model->b = 0;
    model->P = nullptr;
    model->Q = nullptr;

    // Header lines are "<label> <value>"; the labels are skipped.
    f >> dummy >> model->fun >> dummy >> model->m >> dummy >> model->n >>
         dummy >> model->k >> dummy >> model->b;

    // Reject a truncated or corrupt header before sizing allocations with it.
    if(f.fail() || model->m < 0 || model->n < 0 || model->k < 0)
    {
        mf_destroy_model(&model);
        return nullptr;
    }

    try
    {
        model->P = Utility::malloc_aligned_float((mf_long)model->m*model->k);
        model->Q = Utility::malloc_aligned_float((mf_long)model->n*model->k);
    }
    catch(bad_alloc const &e)
    {
        cerr << e.what() << endl;
        mf_destroy_model(&model);
        return nullptr;
    }

    auto read = [&] (mf_float *ptr, mf_int size)
    {
        for(mf_int i = 0; i < size; ++i)
        {
            mf_float *ptr1 = ptr + (mf_long)i*model->k;
            // Skip the "<prefix><index>" token, then read the T/F marker.
            f >> dummy >> dummy;
            if(dummy.compare("F") == 0) // nan vector starts with "F"
                for(mf_int d = 0; d < model->k; ++d)
                {
                    f >> dummy;
                    ptr1[d] = numeric_limits<mf_float>::quiet_NaN();
                }
            else
                for(mf_int d = 0; d < model->k; ++d)
                    f >> ptr1[d];
        }
    };

    read(model->P, model->m);
    read(model->Q, model->n);

    // Any missing or unparsable token above leaves failbit set; do not hand
    // back a model whose factors are partly uninitialized.
    if(f.fail())
    {
        mf_destroy_model(&model);
        return nullptr;
    }

    f.close();

    return model;
}
|
4406
|
+
|
4407
|
+
// Releases a model created by this library (e.g. by mf_load_model).
// It frees the P and Q factor buffers, deletes the struct itself, and resets
// the caller's pointer to nullptr.
// Passing nullptr, or a pointer to a null model, does nothing.
void mf_destroy_model(mf_model **model)
{
    if(model != nullptr && *model != nullptr)
    {
        mf_model *victim = *model;
        Utility::free_aligned_float(victim->P);
        Utility::free_aligned_float(victim->Q);
        delete victim;
        *model = nullptr;
    }
}
|
4416
|
+
|
4417
|
+
// Predicts entry (u, v) as the dot product of row u of P and row v of Q.
//
// Falls back to the bias model->b when:
//   * u or v lies outside [0, m) x [0, n), or
//   * the dot product is NaN (e.g. a row stored as a NaN vector).
// For the classification losses (P_L2_MFC, P_L1_MFC, P_LR_MFC) the result is
// then mapped to +1 when positive and to -1 otherwise.
mf_float mf_predict(mf_model const *model, mf_int u, mf_int v)
{
    bool const u_known = u >= 0 && u < model->m;
    bool const v_known = v >= 0 && v < model->n;
    if(!(u_known && v_known))
        return model->b;

    mf_float const *row_p = model->P + static_cast<mf_long>(u)*model->k;
    mf_float const *row_q = model->Q + static_cast<mf_long>(v)*model->k;

    mf_float score = std::inner_product(row_p, row_p + model->k, row_q,
                                        static_cast<mf_float>(0.0f));
    if(isnan(score))
        score = model->b;

    bool const is_classifier = model->fun == P_L2_MFC ||
                               model->fun == P_L1_MFC ||
                               model->fun == P_LR_MFC;
    if(is_classifier)
        score = score > 0.0f ? 1.0f : -1.0f;

    return score;
}
|
4437
|
+
|
4438
|
+
// Root-mean-square error of mf_predict over every entry of `prob`.
// Returns 0 for a problem with no entries.
mf_double calc_rmse(mf_problem *prob, mf_model *model)
{
    if(prob->nnz == 0)
        return 0;

    mf_double sq_sum = 0;
#if defined USEOMP
#pragma omp parallel for schedule(static) reduction(+:sq_sum)
#endif
    for(mf_long idx = 0; idx < prob->nnz; ++idx)
    {
        mf_node const &entry = prob->R[idx];
        mf_float const diff = entry.r - mf_predict(model, entry.u, entry.v);
        sq_sum += diff*diff;
    }

    return sqrt(sq_sum/prob->nnz);
}
|
4454
|
+
|
4455
|
+
// Mean absolute error of mf_predict over every entry of `prob`.
// Returns 0 for a problem with no entries.
mf_double calc_mae(mf_problem *prob, mf_model *model)
{
    if(prob->nnz == 0)
        return 0;

    mf_double abs_sum = 0;
#if defined USEOMP
#pragma omp parallel for schedule(static) reduction(+:abs_sum)
#endif
    for(mf_long idx = 0; idx < prob->nnz; ++idx)
    {
        mf_node const &entry = prob->R[idx];
        mf_float const diff = entry.r - mf_predict(model, entry.u, entry.v);
        abs_sum += abs(diff);
    }

    return abs_sum/prob->nnz;
}
|
4470
|
+
|
4471
|
+
// Average generalized KL divergence, r*log(r/z) - r + z, between each
// observed value r in `prob` and its prediction z = mf_predict(u, v).
// Returns 0 for a problem with no entries.
mf_double calc_gkl(mf_problem *prob, mf_model *model)
{
    if(prob->nnz == 0)
        return 0;

    mf_double divergence = 0;
#if defined USEOMP
#pragma omp parallel for schedule(static) reduction(+:divergence)
#endif
    for(mf_long idx = 0; idx < prob->nnz; ++idx)
    {
        mf_node const &entry = prob->R[idx];
        mf_float const z = mf_predict(model, entry.u, entry.v);
        mf_float const r = entry.r;
        divergence += r*log(r/z) - r + z;
    }

    return divergence/prob->nnz;
}
|
4487
|
+
|
4488
|
+
mf_double calc_logloss(mf_problem *prob, mf_model *model)
|
4489
|
+
{
|
4490
|
+
if(prob->nnz == 0)
|
4491
|
+
return 0;
|
4492
|
+
mf_double logloss = 0;
|
4493
|
+
#if defined USEOMP
|
4494
|
+
#pragma omp parallel for schedule(static) reduction(+:logloss)
|
4495
|
+
#endif
|
4496
|
+
for(mf_long i = 0; i < prob->nnz; ++i)
|
4497
|
+
{
|
4498
|
+
mf_node &N = prob->R[i];
|
4499
|
+
mf_float z = mf_predict(model, N.u, N.v);
|
4500
|
+
if(N.r > 0)
|
4501
|
+
logloss += log(1.0+exp(-z));
|
4502
|
+
else
|
4503
|
+
logloss += log(1.0+exp(z));
|
4504
|
+
}
|
4505
|
+
return logloss/prob->nnz;
|
4506
|
+
}
|
4507
|
+
|
4508
|
+
// Fraction of entries in `prob` whose predicted sign matches the label.
// A label r > 0 counts as correct when the prediction is > 0. Any other
// label counts as correct when the prediction is < 0, so a prediction of
// exactly 0 is never counted as correct.
// Returns 0 for a problem with no entries.
mf_double calc_accuracy(mf_problem *prob, mf_model *model)
{
    if(prob->nnz == 0)
        return 0;

    mf_double hits = 0;
#if defined USEOMP
#pragma omp parallel for schedule(static) reduction(+:hits)
#endif
    for(mf_long idx = 0; idx < prob->nnz; ++idx)
    {
        mf_node const &entry = prob->R[idx];
        mf_float const z = mf_predict(model, entry.u, entry.v);
        bool const correct = entry.r > 0 ? (z > 0) : (z < 0);
        if(correct)
            hits += 1;
    }

    return hits/prob->nnz;
}
|
4527
|
+
|
4528
|
+
// Computes (MPR, AUC) by ranking every column of each row by mf_predict.
// With transpose == true, rows and columns swap roles.
//
// Row selection: a row is evaluated only if prob contains at least one
// positive entry (r > 0) in it and at least one of its n columns is not
// positive.
//
// Per row, for each non-positive column c:
//   * u_mpr accumulates the number of positives whose prediction is <= c's;
//   * u_auc accumulates the number of positives whose prediction is > c's.
//
// Results:
//   * first  = sum over rows of u_mpr/(#non-positives), divided by the total
//              number of positives;
//   * second = mean over rows of u_auc/(#non-positives * #positives).
// If no row qualifies, these are 0/0 divisions and yield NaN.
//
// Side effect: prob->R is re-sorted in place by (row, column).
//
// Entry-position counters (the nnz loop and the CSR offsets pos_cnts) are
// mf_long, so problems with more than INT_MAX nonzeros do not overflow.
pair<mf_double, mf_double> calc_mpr_auc(mf_problem *prob,
                                        mf_model *model, bool transpose)
{
    mf_int mf_node::*row_ptr;
    mf_int mf_node::*col_ptr;
    mf_int m = 0, n = 0;
    if(!transpose)
    {
        row_ptr = &mf_node::u;
        col_ptr = &mf_node::v;
        m = max(prob->m, model->m);
        n = max(prob->n, model->n);
    }
    else
    {
        row_ptr = &mf_node::v;
        col_ptr = &mf_node::u;
        m = max(prob->n, model->n);
        n = max(prob->m, model->m);
    }

    auto sort_by_id = [&] (mf_node const &lhs, mf_node const &rhs)
    {
        return tie(lhs.*row_ptr, lhs.*col_ptr) <
               tie(rhs.*row_ptr, rhs.*col_ptr);
    };

    sort(prob->R, prob->R+prob->nnz, sort_by_id);

    auto sort_by_pred = [&] (pair<mf_node, mf_float> const &lhs,
        pair<mf_node, mf_float> const &rhs) { return lhs.second < rhs.second; };

    // After the sort, row i occupies prob->R[pos_cnts[i] .. pos_cnts[i+1]).
    // Offsets can reach nnz, hence mf_long.
    vector<mf_long> pos_cnts(m+1, 0);
    for(mf_long i = 0; i < prob->nnz; ++i)
        pos_cnts[prob->R[i].*row_ptr+1] += 1;
    for(mf_int i = 1; i < m+1; ++i)
        pos_cnts[i] += pos_cnts[i-1];

    mf_int total_m = 0;
    mf_long total_pos = 0;
    mf_double all_u_mpr = 0;
    mf_double all_u_auc = 0;
#if defined USEOMP
#pragma omp parallel for schedule(static) reduction(+: total_m, total_pos, all_u_mpr, all_u_auc)
#endif
    for(mf_int i = 0; i < m; ++i)
    {
        if(pos_cnts[i+1]-pos_cnts[i] < 1)
            continue;

        // Predictions for every column of this row; r marks positives below.
        vector<pair<mf_node, mf_float>> row(n);

        for(mf_int j = 0; j < n; ++j)
        {
            mf_node N;
            N.*row_ptr = i;
            N.*col_ptr = j;
            N.r = 0;
            row[j] = make_pair(N, mf_predict(model, N.u, N.v));
        }

        // Collect the (sorted) columns of this row's positive entries.
        mf_int pos = 0;
        vector<mf_int> index(pos_cnts[i+1]-pos_cnts[i], 0);
        for(mf_long j = pos_cnts[i]; j < pos_cnts[i+1]; ++j)
        {
            if(prob->R[j].r <= 0)
                continue;

            mf_int col = prob->R[j].*col_ptr;
            row[col].first.r = prob->R[j].r;
            index[pos] = col;
            pos += 1;
        }

        if(n-pos < 1 || pos < 1)
            continue;

        ++total_m;
        total_pos += pos;

        // Move positives to the front. Because index is ascending, position
        // index[k] has not been touched by earlier swaps.
        mf_int count = 0;
        for(mf_int k = 0; k < pos; ++k)
        {
            swap(row[count], row[index[k]]);
            ++count;
        }
        sort(row.begin(), row.begin()+pos, sort_by_pred);

        mf_double u_mpr = 0;
        mf_double u_auc = 0;
        for(auto neg_it = row.begin()+pos; neg_it != row.end(); ++neg_it)
        {
            // Every positive scores <= this negative.
            if(row[pos-1].second <= neg_it->second)
            {
                u_mpr += pos;
                continue;
            }

            // Binary search: left = number of positives scoring <= negative.
            mf_int left = 0;
            mf_int right = pos-1;
            while(left < right)
            {
                mf_int mid = (left+right)/2;
                if(row[mid].second > neg_it->second)
                    right = mid;
                else
                    left = mid+1;
            }
            u_mpr += left;
            u_auc += pos-left;
        }

        all_u_mpr += u_mpr/(n-pos);
        all_u_auc += u_auc/(n-pos)/pos;
    }

    all_u_mpr /= total_pos;
    all_u_auc /= total_m;

    return make_pair(all_u_mpr, all_u_auc);
}
|
4649
|
+
|
4650
|
+
// Mean percentile rank: the first component of
// calc_mpr_auc(prob, model, transpose).
// calc_mpr_auc re-sorts prob->R in place, so this does too.
mf_double calc_mpr(mf_problem *prob, mf_model *model, bool transpose)
{
    pair<mf_double, mf_double> const mpr_and_auc =
        calc_mpr_auc(prob, model, transpose);
    return mpr_and_auc.first;
}
|
4654
|
+
|
4655
|
+
// Area under the ROC curve: the second component of
// calc_mpr_auc(prob, model, transpose).
// calc_mpr_auc re-sorts prob->R in place, so this does too.
mf_double calc_auc(mf_problem *prob, mf_model *model, bool transpose)
{
    pair<mf_double, mf_double> const mpr_and_auc =
        calc_mpr_auc(prob, model, transpose);
    return mpr_and_auc.second;
}
|
4659
|
+
|
4660
|
+
// Returns the library's default training parameters:
//   * loss P_L2_MFR with k = 8 latent factors;
//   * 12 threads, 20 bins, 20 iterations;
//   * lambda_p1 = lambda_q1 = 0, lambda_p2 = lambda_q2 = 0.1;
//   * eta = 0.1, alpha = 1, c = 0.0001;
//   * do_nmf and quiet off, copy_data on.
mf_parameter mf_get_default_param()
{
    mf_parameter param;

    // Boolean switches.
    param.copy_data = true;
    param.quiet = false;
    param.do_nmf = false;

    // Scalar constants.
    param.c = 0.0001f;
    param.alpha = 1.0f;
    param.eta = 0.1f;

    // Regularization coefficients for P and Q.
    param.lambda_q2 = 0.1f;
    param.lambda_p2 = 0.1f;
    param.lambda_q1 = 0.0f;
    param.lambda_p1 = 0.0f;

    // Schedule / parallelism.
    param.nr_iters = 20;
    param.nr_bins = 20;
    param.nr_threads = 12;

    // Model shape and loss function.
    param.k = 8;
    param.fun = P_L2_MFR;

    return param;
}
|
4682
|
+
|
4683
|
+
}
|