libmf 0.1.0

@@ -0,0 +1,4683 @@
1
+ #include <algorithm>
2
+ #include <cmath>
3
+ #include <condition_variable>
4
+ #include <cstdlib>
5
+ #include <cstring>
6
+ #include <fstream>
7
+ #include <iostream>
8
+ #include <iomanip>
9
+ #include <memory>
+ #include <mutex>
10
+ #include <numeric>
11
+ #include <queue>
12
+ #include <random>
13
+ #include <stdexcept>
14
+ #include <string>
15
+ #include <thread>
16
+ #include <unordered_set>
17
+ #include <vector>
18
+ #include <limits>
19
+
20
+ #include "mf.h"
21
+
22
+ #if defined USESSE
23
+ #include <pmmintrin.h>
24
+ #endif
25
+
26
+ #if defined USEAVX
27
+ #include <immintrin.h>
28
+ #endif
29
+
30
+ #if defined USEOMP
31
+ #include <omp.h>
32
+ #endif
33
+
34
+ namespace mf
35
+ {
36
+
37
+ using namespace std;
38
+
39
+ namespace // unnamed namespace
40
+ {
41
+
42
+ mf_int const kALIGNByte = 32;
43
+ mf_int const kALIGN = kALIGNByte/sizeof(mf_float);
44
+
45
+ //--------------------------------------
46
+ //---------Scheduler of Blocks----------
47
+ //--------------------------------------
48
+
49
+ class Scheduler
50
+ {
51
+ public:
52
+ Scheduler(mf_int nr_bins, mf_int nr_threads, vector<mf_int> cv_blocks);
53
+ mf_int get_job();
54
+ mf_int get_bpr_job(mf_int first_block, bool is_column_oriented);
55
+ void put_job(mf_int block, mf_double loss, mf_double error);
56
+ void put_bpr_job(mf_int first_block, mf_int second_block);
57
+ mf_double get_loss();
58
+ mf_double get_error();
59
+ mf_int get_negative(mf_int first_block, mf_int second_block,
60
+ mf_int m, mf_int n, bool is_column_oriented);
61
+ void wait_for_jobs_done();
62
+ void resume();
63
+ void terminate();
64
+ bool is_terminated();
65
+
66
+ private:
67
+ mf_int nr_bins;
68
+ mf_int nr_threads;
69
+ mf_int nr_done_jobs;
70
+ mf_int target;
71
+ mf_int nr_paused_threads;
72
+ bool terminated;
73
+ vector<mf_int> counts;
74
+ vector<mf_int> busy_p_blocks;
75
+ vector<mf_int> busy_q_blocks;
76
+ vector<mf_double> block_losses;
77
+ vector<mf_double> block_errors;
78
+ vector<minstd_rand0> block_generators;
79
+ unordered_set<mf_int> cv_blocks;
80
+ mutex mtx;
81
+ condition_variable cond_var;
82
+ default_random_engine generator;
83
+ uniform_real_distribution<mf_float> distribution;
84
+ priority_queue<pair<mf_float, mf_int>,
85
+ vector<pair<mf_float, mf_int>>,
86
+ greater<pair<mf_float, mf_int>>> pq;
87
+ };
88
+
89
+ Scheduler::Scheduler(mf_int nr_bins, mf_int nr_threads,
90
+ vector<mf_int> cv_blocks)
91
+ : nr_bins(nr_bins),
92
+ nr_threads(nr_threads),
93
+ nr_done_jobs(0),
94
+ target(nr_bins*nr_bins),
95
+ nr_paused_threads(0),
96
+ terminated(false),
97
+ counts(nr_bins*nr_bins, 0),
98
+ busy_p_blocks(nr_bins, 0),
99
+ busy_q_blocks(nr_bins, 0),
100
+ block_losses(nr_bins*nr_bins, 0),
101
+ block_errors(nr_bins*nr_bins, 0),
102
+ cv_blocks(cv_blocks.begin(), cv_blocks.end()),
103
+ distribution(0.0, 1.0)
104
+ {
105
+ for(mf_int i = 0; i < nr_bins*nr_bins; ++i)
106
+ {
107
+ if(this->cv_blocks.find(i) == this->cv_blocks.end())
108
+ pq.emplace(distribution(generator), i);
109
+ block_generators.push_back(minstd_rand0(rand()));
110
+ }
111
+ }
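+ // The constructor above places every block that is not held out for
+ // cross-validation into the priority queue with a random initial priority,
+ // and gives each block its own minstd_rand0 generator so negative sampling
+ // in get_negative() does not share one generator across threads.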
112
+
113
+ mf_int Scheduler::get_job()
114
+ {
115
+ bool is_found = false;
116
+ pair<mf_float, mf_int> block;
117
+
118
+ while(!is_found)
119
+ {
120
+ lock_guard<mutex> lock(mtx);
121
+ vector<pair<mf_float, mf_int>> locked_blocks;
122
+ mf_int p_block = 0;
123
+ mf_int q_block = 0;
124
+
125
+ while(!pq.empty())
126
+ {
127
+ block = pq.top();
128
+ pq.pop();
129
+
130
+ p_block = block.second/nr_bins;
131
+ q_block = block.second%nr_bins;
132
+
133
+ if(busy_p_blocks[p_block] || busy_q_blocks[q_block])
134
+ locked_blocks.push_back(block);
135
+ else
136
+ {
137
+ busy_p_blocks[p_block] = 1;
138
+ busy_q_blocks[q_block] = 1;
139
+ counts[block.second] += 1;
140
+ is_found = true;
141
+ break;
142
+ }
143
+ }
144
+
145
+ for(auto &block1 : locked_blocks)
146
+ pq.push(block1);
147
+ }
148
+
149
+ return block.second;
150
+ }
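+ // get_job() pops candidate blocks in increasing priority (update count plus
+ // a random tie-breaker) and hands out the first one whose row block and
+ // column block are both free; conflicting blocks are pushed back, so no two
+ // threads ever update the same rows of P or the same columns of Q at once.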
151
+
152
+ mf_int Scheduler::get_bpr_job(mf_int first_block, bool is_column_oriented)
153
+ {
154
+ lock_guard<mutex> lock(mtx);
155
+ mf_int another = first_block;
156
+ vector<pair<mf_float, mf_int>> locked_blocks;
157
+
158
+ while(!pq.empty())
159
+ {
160
+ pair<mf_float, mf_int> block = pq.top();
161
+ pq.pop();
162
+
163
+ mf_int p_block = block.second/nr_bins;
164
+ mf_int q_block = block.second%nr_bins;
165
+
166
+ auto is_rejected = [&] ()
167
+ {
168
+ if(is_column_oriented)
169
+ return first_block%nr_bins != q_block ||
170
+ busy_p_blocks[p_block];
171
+ else
172
+ return first_block/nr_bins != p_block ||
173
+ busy_q_blocks[q_block];
174
+ };
175
+
176
+ if(is_rejected())
177
+ locked_blocks.push_back(block);
178
+ else
179
+ {
180
+ busy_p_blocks[p_block] = 1;
181
+ busy_q_blocks[q_block] = 1;
182
+ another = block.second;
183
+ break;
184
+ }
185
+ }
186
+
187
+ for(auto &block : locked_blocks)
188
+ pq.push(block);
189
+
190
+ return another;
191
+ }
192
+
193
+ void Scheduler::put_job(mf_int block_idx, mf_double loss, mf_double error)
194
+ {
195
+ // Return the held block to the scheduler
196
+ {
197
+ lock_guard<mutex> lock(mtx);
198
+ busy_p_blocks[block_idx/nr_bins] = 0;
199
+ busy_q_blocks[block_idx%nr_bins] = 0;
200
+ block_losses[block_idx] = loss;
201
+ block_errors[block_idx] = error;
202
+ ++nr_done_jobs;
203
+ mf_float priority =
204
+ (mf_float)counts[block_idx]+distribution(generator);
205
+ pq.emplace(priority, block_idx);
206
+ ++nr_paused_threads;
207
+ // Tell others that a block is available again.
208
+ cond_var.notify_all();
209
+ }
210
+
211
+ // Wait if nr_done_jobs (i.e., the number of processed blocks) grows too large
212
+ // because we want to print out the training status roughly once every time
213
+ // all blocks have been processed. This is the only place where a solver
214
+ // thread should wait for something.
215
+ {
216
+ unique_lock<mutex> lock(mtx);
217
+ cond_var.wait(lock, [&] {
218
+ return nr_done_jobs < target;
219
+ });
220
+ }
221
+
222
+ // Nothing is blocking and this thread is going to take another block
223
+ {
224
+ lock_guard<mutex> lock(mtx);
225
+ --nr_paused_threads;
226
+ }
227
+ }
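+ // In put_job() above, a worker that has just returned a block parks once
+ // nr_done_jobs reaches the current target (one full pass over the
+ // nr_bins*nr_bins blocks). The main thread, blocked in wait_for_jobs_done(),
+ // can then print statistics and call resume(), which raises the target and
+ // wakes all workers again.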
228
+
229
+ void Scheduler::put_bpr_job(mf_int first_block, mf_int second_block)
230
+ {
231
+ if(first_block == second_block)
232
+ return;
233
+
234
+ lock_guard<mutex> lock(mtx);
235
+ {
236
+ busy_p_blocks[second_block/nr_bins] = 0;
237
+ busy_q_blocks[second_block%nr_bins] = 0;
238
+ mf_float priority =
239
+ (mf_float)counts[second_block]+distribution(generator);
240
+ pq.emplace(priority, second_block);
241
+ }
242
+ }
243
+
244
+ mf_double Scheduler::get_loss()
245
+ {
246
+ lock_guard<mutex> lock(mtx);
247
+ return accumulate(block_losses.begin(), block_losses.end(), 0.0);
248
+ }
249
+
250
+ mf_double Scheduler::get_error()
251
+ {
252
+ lock_guard<mutex> lock(mtx);
253
+ return accumulate(block_errors.begin(), block_errors.end(), 0.0);
254
+ }
255
+
256
+ mf_int Scheduler::get_negative(mf_int first_block, mf_int second_block,
257
+ mf_int m, mf_int n, bool is_column_oriented)
258
+ {
259
+ mf_int rand_val = (mf_int)block_generators[first_block]();
260
+
261
+ auto gen_random = [&] (mf_int block_id)
262
+ {
263
+ mf_int v_min, v_max;
264
+
265
+ if(is_column_oriented)
266
+ {
267
+ mf_int seg_size = (mf_int)ceil((double)m/nr_bins);
268
+ v_min = min((block_id/nr_bins)*seg_size, m-1);
269
+ v_max = min(v_min+seg_size, m-1);
270
+ }
271
+ else
272
+ {
273
+ mf_int seg_size = (mf_int)ceil((double)n/nr_bins);
274
+ v_min = min((block_id%nr_bins)*seg_size, n-1);
275
+ v_max = min(v_min+seg_size, n-1);
276
+ }
277
+ if(v_max == v_min)
278
+ return v_min;
279
+ else
280
+ return rand_val%(v_max-v_min)+v_min;
281
+ };
282
+
283
+ if(rand_val % 2)
284
+ return (mf_int)gen_random(first_block);
285
+ else
286
+ return (mf_int)gen_random(second_block);
287
+ }
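+ // get_negative() above supplies the negative example for BPR: with equal
+ // probability it picks one of the two blocks a worker currently holds and
+ // draws a random row index (column-oriented case) or column index from the
+ // segment of the matrix covered by that block.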
288
+
289
+ void Scheduler::wait_for_jobs_done()
290
+ {
291
+ unique_lock<mutex> lock(mtx);
292
+
293
+ // The first thing the main thread should wait for is that solver threads
294
+ // process enough matrix blocks.
295
+ // [REVIEW] Is it really needed? Solver threads automatically stop if they
296
+ // process too many blocks, so the next wait should be enough for stopping
297
+ // the main thread when nr_done_jobs is not large enough.
298
+ cond_var.wait(lock, [&] {
299
+ return nr_done_jobs >= target;
300
+ });
301
+
302
+ // Wait for all threads to stop. Once a thread realizes that all threads
303
+ // have processed enough blocks it should stop. Then, the main thread can
304
+ // print values safely.
305
+ cond_var.wait(lock, [&] {
306
+ return nr_paused_threads == nr_threads;
307
+ });
308
+ }
309
+
310
+ void Scheduler::resume()
311
+ {
312
+ lock_guard<mutex> lock(mtx);
313
+ target += nr_bins*nr_bins;
314
+ cond_var.notify_all();
315
+ }
316
+
317
+ void Scheduler::terminate()
318
+ {
319
+ lock_guard<mutex> lock(mtx);
320
+ terminated = true;
321
+ }
322
+
323
+ bool Scheduler::is_terminated()
324
+ {
325
+ lock_guard<mutex> lock(mtx);
326
+ return terminated;
327
+ }
328
+
329
+ //--------------------------------------
330
+ //------------Block of matrix-----------
331
+ //--------------------------------------
332
+
333
+ class BlockBase
334
+ {
335
+ public:
336
+ virtual bool move_next() { return false; };
337
+ virtual mf_node* get_current() { return nullptr; }
338
+ virtual void reload() {};
339
+ virtual void free() {};
340
+ virtual mf_long get_nnz() { return 0; };
341
+ virtual ~BlockBase() {};
342
+ };
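+ // BlockBase gives the solvers a small iterator-like interface (reload,
+ // move_next, get_current, free) so they can stream a block's nonzeros
+ // without knowing whether the block lives in memory (Block) or is paged in
+ // from a binary file (BlockOnDisk).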
343
+
344
+ class Block : public BlockBase
345
+ {
346
+ public:
347
+ Block() : first(nullptr), last(nullptr), current(nullptr) {};
348
+ Block(mf_node *first_, mf_node *last_)
349
+ : first(first_), last(last_), current(nullptr) {};
350
+ bool move_next() { return ++current != last; }
351
+ mf_node* get_current() { return current; }
352
+ void tie_to(mf_node *first_, mf_node *last_);
353
+ void reload() { current = first-1; };
354
+ mf_long get_nnz() { return last-first; };
355
+
356
+ private:
357
+ mf_node* first;
358
+ mf_node* last;
359
+ mf_node* current;
360
+ };
361
+
362
+ void Block::tie_to(mf_node *first_, mf_node *last_)
363
+ {
364
+ first = first_;
365
+ last = last_;
366
+ };
367
+
368
+ class BlockOnDisk : public BlockBase
369
+ {
370
+ public:
371
+ BlockOnDisk() : first(0), last(0), current(0),
372
+ source_path(""), buffer(0) {};
373
+ bool move_next() { return ++current < last-first; }
374
+ mf_node* get_current() { return &buffer[static_cast<size_t>(current)]; }
375
+ void tie_to(string source_path_, mf_long first_, mf_long last_);
376
+ void reload();
377
+ void free() { buffer.resize(0); };
378
+ mf_long get_nnz() { return last-first; };
379
+
380
+ private:
381
+ mf_long first;
382
+ mf_long last;
383
+ mf_long current;
384
+ string source_path;
385
+ vector<mf_node> buffer;
386
+ };
387
+
388
+ void BlockOnDisk::tie_to(string source_path_, mf_long first_, mf_long last_)
389
+ {
390
+ source_path = source_path_;
391
+ first = first_;
392
+ last = last_;
393
+ }
394
+
395
+ void BlockOnDisk::reload()
396
+ {
397
+ ifstream source(source_path, ifstream::in|ifstream::binary);
398
+ if(!source)
399
+ throw runtime_error("can not open "+source_path);
400
+
401
+ buffer.resize(static_cast<size_t>(last-first));
402
+ source.seekg(first*sizeof(mf_node));
403
+ source.read((char*)buffer.data(), (last-first)*sizeof(mf_node));
404
+ current = -1;
405
+ }
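+ // reload() pages one block back from the binary cache file written by
+ // grid_shuffle_scale_problem_on_disk(); current is set to -1 so that the
+ // first call to move_next() lands on element 0 of the buffer.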
406
+
407
+ //--------------------------------------
408
+ //-------------Miscellaneous------------
409
+ //--------------------------------------
410
+
411
+ struct sort_node_by_p
412
+ {
413
+ bool operator() (mf_node const &lhs, mf_node const &rhs)
414
+ {
415
+ return tie(lhs.u, lhs.v) < tie(rhs.u, rhs.v);
416
+ }
417
+ };
418
+
419
+ struct sort_node_by_q
420
+ {
421
+ bool operator() (mf_node const &lhs, mf_node const &rhs)
422
+ {
423
+ return tie(lhs.v, lhs.u) < tie(rhs.v, rhs.u);
424
+ }
425
+ };
426
+
427
+ struct deleter
428
+ {
429
+ void operator() (mf_problem *prob)
430
+ {
431
+ delete[] prob->R;
432
+ delete prob;
433
+ }
434
+ };
435
+
436
+
437
+ class Utility
438
+ {
439
+ public:
440
+ Utility(mf_int f, mf_int n) : fun(f), nr_threads(n) {};
441
+ void collect_info(mf_problem &prob, mf_float &avg, mf_float &std_dev);
442
+ void collect_info_on_disk(string data_path, mf_problem &prob,
443
+ mf_float &avg, mf_float &std_dev);
444
+ void shuffle_problem(mf_problem &prob, vector<mf_int> &p_map,
445
+ vector<mf_int> &q_map);
446
+ vector<mf_node*> grid_problem(mf_problem &prob, mf_int nr_bins,
447
+ vector<mf_int> &omega_p,
448
+ vector<mf_int> &omega_q,
449
+ vector<Block> &blocks);
450
+ void grid_shuffle_scale_problem_on_disk(mf_int m, mf_int n, mf_int nr_bins,
451
+ mf_float scale, string data_path,
452
+ vector<mf_int> &p_map,
453
+ vector<mf_int> &q_map,
454
+ vector<mf_int> &omega_p,
455
+ vector<mf_int> &omega_q,
456
+ vector<BlockOnDisk> &blocks);
457
+ void scale_problem(mf_problem &prob, mf_float scale);
458
+ mf_double calc_reg1(mf_model &model, mf_float lambda_p, mf_float lambda_q,
459
+ vector<mf_int> &omega_p, vector<mf_int> &omega_q);
460
+ mf_double calc_reg2(mf_model &model, mf_float lambda_p, mf_float lambda_q,
461
+ vector<mf_int> &omega_p, vector<mf_int> &omega_q);
462
+ string get_error_legend() const;
463
+ mf_double calc_error(vector<BlockBase*> &blocks,
464
+ vector<mf_int> &cv_block_ids,
465
+ mf_model const &model);
466
+ void scale_model(mf_model &model, mf_float scale);
467
+
468
+ static mf_problem* copy_problem(mf_problem const *prob, bool copy_data);
469
+ static vector<mf_int> gen_random_map(mf_int size);
470
+ // A function used to allocate an aligned float array.
471
+ // It hides platform-specific function calls. Memory
472
+ // allocated by malloc_aligned_float must be freed by using
473
+ // free_aligned_float.
474
+ static mf_float* malloc_aligned_float(mf_long size);
475
+ // A function used to free an aligned float array.
476
+ // It hides platform-specific function calls.
477
+ static void free_aligned_float(mf_float* ptr);
478
+ // Initialization function for stochastic gradient method.
479
+ // Factor matrices P and Q are both randomly initialized.
480
+ static mf_model* init_model(mf_int loss, mf_int m, mf_int n,
481
+ mf_int k, mf_float avg,
482
+ vector<mf_int> &omega_p,
483
+ vector<mf_int> &omega_q);
484
+ // Initialization function for one-class CD.
485
+ // It does zero-initialization on factor matrix P and random initialization
486
+ // on factor matrix Q.
487
+ static mf_model* init_model(mf_int m, mf_int n, mf_int k);
488
+ static mf_float inner_product(mf_float *p, mf_float *q, mf_int k);
489
+ static vector<mf_int> gen_inv_map(vector<mf_int> &map);
490
+ static void shrink_model(mf_model &model, mf_int k_new);
491
+ static void shuffle_model(mf_model &model,
492
+ vector<mf_int> &p_map,
493
+ vector<mf_int> &q_map);
494
+ mf_int get_thread_number() const { return nr_threads; };
495
+ private:
496
+ mf_int fun;
497
+ mf_int nr_threads;
498
+ };
499
+
500
+ void Utility::collect_info(
501
+ mf_problem &prob,
502
+ mf_float &avg,
503
+ mf_float &std_dev)
504
+ {
505
+ mf_double ex = 0;
506
+ mf_double ex2 = 0;
507
+
508
+ #if defined USEOMP
509
+ #pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:ex,ex2)
510
+ #endif
511
+ for(mf_long i = 0; i < prob.nnz; ++i)
512
+ {
513
+ mf_node &N = prob.R[i];
514
+ ex += (mf_double)N.r;
515
+ ex2 += (mf_double)N.r*N.r;
516
+ }
517
+
518
+ ex /= (mf_double)prob.nnz;
519
+ ex2 /= (mf_double)prob.nnz;
520
+ avg = (mf_float)ex;
521
+ std_dev = (mf_float)sqrt(ex2-ex*ex);
522
+ }
523
+
524
+ void Utility::collect_info_on_disk(
525
+ string data_path,
526
+ mf_problem &prob,
527
+ mf_float &avg,
528
+ mf_float &std_dev)
529
+ {
530
+ mf_double ex = 0;
531
+ mf_double ex2 = 0;
532
+
533
+ ifstream source(data_path);
534
+ if(!source.is_open())
535
+ throw runtime_error("cannot open " + data_path);
536
+
537
+ for(mf_node N; source >> N.u >> N.v >> N.r;)
538
+ {
539
+ if(N.u+1 > prob.m)
540
+ prob.m = N.u+1;
541
+ if(N.v+1 > prob.n)
542
+ prob.n = N.v+1;
543
+ prob.nnz += 1;
544
+ ex += (mf_double)N.r;
545
+ ex2 += (mf_double)N.r*N.r;
546
+ }
547
+ source.close();
548
+
549
+ ex /= (mf_double)prob.nnz;
550
+ ex2 /= (mf_double)prob.nnz;
551
+ avg = (mf_float)ex;
552
+ std_dev = (mf_float)sqrt(ex2-ex*ex);
553
+ }
554
+
555
+ void Utility::scale_problem(mf_problem &prob, mf_float scale)
556
+ {
557
+ if(scale == 1.0)
558
+ return;
559
+
560
+ #if defined USEOMP
561
+ #pragma omp parallel for num_threads(nr_threads) schedule(static)
562
+ #endif
563
+ for(mf_long i = 0; i < prob.nnz; ++i)
564
+ prob.R[i].r *= scale;
565
+ }
566
+
567
+ void Utility::scale_model(mf_model &model, mf_float scale)
568
+ {
569
+ if(scale == 1.0)
570
+ return;
571
+
572
+ mf_int k = model.k;
573
+
574
+ model.b *= scale;
575
+
576
+ auto scale1 = [&] (mf_float *ptr, mf_int size, mf_float factor_scale)
577
+ {
578
+ #if defined USEOMP
579
+ #pragma omp parallel for num_threads(nr_threads) schedule(static)
580
+ #endif
581
+ for(mf_int i = 0; i < size; ++i)
582
+ {
583
+ mf_float *ptr1 = ptr+(mf_long)i*model.k;
584
+ for(mf_int d = 0; d < k; ++d)
585
+ ptr1[d] *= factor_scale;
586
+ }
587
+ };
588
+
589
+ scale1(model.P, model.m, sqrt(scale));
590
+ scale1(model.Q, model.n, sqrt(scale));
591
+ }
592
+
593
+ mf_float Utility::inner_product(mf_float *p, mf_float *q, mf_int k)
594
+ {
595
+ #if defined USESSE
596
+ __m128 XMM = _mm_setzero_ps();
597
+ for(mf_int d = 0; d < k; d += 4)
598
+ XMM = _mm_add_ps(XMM, _mm_mul_ps(
599
+ _mm_load_ps(p+d), _mm_load_ps(q+d)));
600
+ __m128 XMMtmp = _mm_add_ps(XMM, _mm_movehl_ps(XMM, XMM));
601
+ XMM = _mm_add_ps(XMMtmp, _mm_shuffle_ps(XMMtmp, XMMtmp, 1));
602
+ mf_float product;
603
+ _mm_store_ss(&product, XMM);
604
+ return product;
605
+ #elif defined USEAVX
606
+ __m256 XMM = _mm256_setzero_ps();
607
+ for(mf_int d = 0; d < k; d += 8)
608
+ XMM = _mm256_add_ps(XMM, _mm256_mul_ps(
609
+ _mm256_load_ps(p+d), _mm256_load_ps(q+d)));
610
+ XMM = _mm256_add_ps(XMM, _mm256_permute2f128_ps(XMM, XMM, 1));
611
+ XMM = _mm256_hadd_ps(XMM, XMM);
612
+ XMM = _mm256_hadd_ps(XMM, XMM);
613
+ mf_float product;
614
+ _mm_store_ss(&product, _mm256_castps256_ps128(XMM));
615
+ return product;
616
+ #else
617
+ return std::inner_product(p, p+k, q, (mf_float)0.0);
618
+ #endif
619
+ }
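+ // The SSE/AVX paths above assume p and q are kALIGNByte-aligned and that k
+ // is a multiple of 4 (SSE) or 8 (AVX); init_model() for the SG-based solvers
+ // guarantees this by rounding k up to a multiple of kALIGN and allocating P
+ // and Q with malloc_aligned_float().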
620
+
621
+ mf_double Utility::calc_reg1(mf_model &model,
622
+ mf_float lambda_p, mf_float lambda_q,
623
+ vector<mf_int> &omega_p, vector<mf_int> &omega_q)
624
+ {
625
+ auto calc_reg1_core = [&] (mf_float *ptr, mf_int size,
626
+ vector<mf_int> &omega)
627
+ {
628
+ mf_double reg = 0;
629
+ for(mf_int i = 0; i < size; ++i)
630
+ {
631
+ if(omega[i] <= 0)
632
+ continue;
633
+
634
+ mf_float tmp = 0;
635
+ for(mf_int j = 0; j < model.k; ++j)
636
+ tmp += abs(ptr[(mf_long)i*model.k+j]);
637
+ reg += omega[i]*tmp;
638
+ }
639
+ return reg;
640
+ };
641
+
642
+ return lambda_p*calc_reg1_core(model.P, model.m, omega_p)+
643
+ lambda_q*calc_reg1_core(model.Q, model.n, omega_q);
644
+ }
645
+
646
+ mf_double Utility::calc_reg2(mf_model &model,
647
+ mf_float lambda_p, mf_float lambda_q,
648
+ vector<mf_int> &omega_p, vector<mf_int> &omega_q)
649
+ {
650
+ auto calc_reg2_core = [&] (mf_float *ptr, mf_int size,
651
+ vector<mf_int> &omega)
652
+ {
653
+ mf_double reg = 0;
654
+ #if defined USEOMP
655
+ #pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:reg)
656
+ #endif
657
+ for(mf_int i = 0; i < size; ++i)
658
+ {
659
+ if(omega[i] <= 0)
660
+ continue;
661
+
662
+ mf_float *ptr1 = ptr+(mf_long)i*model.k;
663
+ reg += omega[i]*Utility::inner_product(ptr1, ptr1, model.k);
664
+ }
665
+
666
+ return reg;
667
+ };
668
+
669
+ return lambda_p*calc_reg2_core(model.P, model.m, omega_p) +
670
+ lambda_q*calc_reg2_core(model.Q, model.n, omega_q);
671
+ }
672
+
673
+ mf_double Utility::calc_error(
674
+ vector<BlockBase*> &blocks,
675
+ vector<mf_int> &cv_block_ids,
676
+ mf_model const &model)
677
+ {
678
+ mf_double error = 0;
679
+ if(fun == P_L2_MFR || fun == P_L1_MFR || fun == P_KL_MFR ||
680
+ fun == P_LR_MFC || fun == P_L2_MFC || fun == P_L1_MFC)
681
+ {
682
+ #if defined USEOMP
683
+ #pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:error)
684
+ #endif
685
+ for(mf_int i = 0; i < (mf_long)cv_block_ids.size(); ++i)
686
+ {
687
+ BlockBase *block = blocks[cv_block_ids[i]];
688
+ block->reload();
689
+ while(block->move_next())
690
+ {
691
+ mf_node const &N = *(block->get_current());
692
+ mf_float z = mf_predict(&model, N.u, N.v);
693
+ switch(fun)
694
+ {
695
+ case P_L2_MFR:
696
+ error += pow(N.r-z, 2);
697
+ break;
698
+ case P_L1_MFR:
699
+ error += abs(N.r-z);
700
+ break;
701
+ case P_KL_MFR:
702
+ error += N.r*log(N.r/z)-N.r+z;
703
+ break;
704
+ case P_LR_MFC:
705
+ if(N.r > 0)
706
+ error += log(1.0+exp(-z));
707
+ else
708
+ error += log(1.0+exp(z));
709
+ break;
710
+ case P_L2_MFC:
711
+ case P_L1_MFC:
712
+ if(N.r > 0)
713
+ error += z > 0? 1: 0;
714
+ else
715
+ error += z < 0? 1: 0;
716
+ break;
717
+ default:
718
+ throw invalid_argument("unknown error function");
719
+ break;
720
+ }
721
+ }
722
+ block->free();
723
+ }
724
+ }
725
+ else
726
+ {
727
+ minstd_rand0 generator(rand());
728
+ switch(fun)
729
+ {
730
+ case P_ROW_BPR_MFOC:
731
+ {
732
+ uniform_int_distribution<mf_int> distribution(0, model.n-1);
733
+ #if defined USEOMP
734
+ #pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:error)
735
+ #endif
736
+ for(mf_int i = 0; i < (mf_long)cv_block_ids.size(); ++i)
737
+ {
738
+ BlockBase *block = blocks[cv_block_ids[i]];
739
+ block->reload();
740
+ while(block->move_next())
741
+ {
742
+ mf_node const &N = *(block->get_current());
743
+ mf_int w = distribution(generator);
744
+ error += log(1+exp(mf_predict(&model, N.u, w)-
745
+ mf_predict(&model, N.u, N.v)));
746
+ }
747
+ block->free();
748
+ }
749
+ break;
750
+ }
751
+ case P_COL_BPR_MFOC:
752
+ {
753
+ uniform_int_distribution<mf_int> distribution(0, model.m-1);
754
+ #if defined USEOMP
755
+ #pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:error)
756
+ #endif
757
+ for(mf_int i = 0; i < (mf_long)cv_block_ids.size(); ++i)
758
+ {
759
+ BlockBase *block = blocks[cv_block_ids[i]];
760
+ block->reload();
761
+ while(block->move_next())
762
+ {
763
+ mf_node const &N = *(block->get_current());
764
+ mf_int w = distribution(generator);
765
+ error += log(1+exp(mf_predict(&model, w, N.v)-
766
+ mf_predict(&model, N.u, N.v)));
767
+ }
768
+ block->free();
769
+ }
770
+ break;
771
+ }
772
+ default:
773
+ {
774
+ throw invalid_argument("unknown error function");
775
+ break;
776
+ }
777
+ }
778
+ }
779
+
780
+ return error;
781
+ }
782
+
783
+ string Utility::get_error_legend() const
784
+ {
785
+ switch(fun)
786
+ {
787
+ case P_L2_MFR:
788
+ return string("rmse");
789
+ break;
790
+ case P_L1_MFR:
791
+ return string("mae");
792
+ break;
793
+ case P_KL_MFR:
794
+ return string("gkl");
795
+ break;
796
+ case P_LR_MFC:
797
+ return string("logloss");
798
+ break;
799
+ case P_L2_MFC:
800
+ case P_L1_MFC:
801
+ return string("accuracy");
802
+ break;
803
+ case P_ROW_BPR_MFOC:
804
+ case P_COL_BPR_MFOC:
805
+ return string("bprloss");
806
+ break;
807
+ case P_L2_MFOC:
808
+ return string("sqerror");
809
+ default:
810
+ return string();
811
+ break;
812
+ }
813
+ }
814
+
815
+ void Utility::shuffle_problem(
816
+ mf_problem &prob,
817
+ vector<mf_int> &p_map,
818
+ vector<mf_int> &q_map)
819
+ {
820
+ #if defined USEOMP
821
+ #pragma omp parallel for num_threads(nr_threads) schedule(static)
822
+ #endif
823
+ for(mf_long i = 0; i < prob.nnz; ++i)
824
+ {
825
+ mf_node &N = prob.R[i];
826
+ if(N.u < (mf_long)p_map.size())
827
+ N.u = p_map[N.u];
828
+ if(N.v < (mf_long)q_map.size())
829
+ N.v = q_map[N.v];
830
+ }
831
+ }
832
+
833
+ vector<mf_node*> Utility::grid_problem(
834
+ mf_problem &prob,
835
+ mf_int nr_bins,
836
+ vector<mf_int> &omega_p,
837
+ vector<mf_int> &omega_q,
838
+ vector<Block> &blocks)
839
+ {
840
+ vector<mf_long> counts(nr_bins*nr_bins, 0);
841
+
842
+ mf_int seg_p = (mf_int)ceil((double)prob.m/nr_bins);
843
+ mf_int seg_q = (mf_int)ceil((double)prob.n/nr_bins);
844
+
845
+ auto get_block_id = [=] (mf_int u, mf_int v)
846
+ {
847
+ return (u/seg_p)*nr_bins+v/seg_q;
848
+ };
849
+
850
+ for(mf_long i = 0; i < prob.nnz; ++i)
851
+ {
852
+ mf_node &N = prob.R[i];
853
+ mf_int block = get_block_id(N.u, N.v);
854
+ counts[block] += 1;
855
+ omega_p[N.u] += 1;
856
+ omega_q[N.v] += 1;
857
+ }
858
+
859
+ vector<mf_node*> ptrs(nr_bins*nr_bins+1);
860
+ mf_node *ptr = prob.R;
861
+ ptrs[0] = ptr;
862
+ for(mf_int block = 0; block < nr_bins*nr_bins; ++block)
863
+ ptrs[block+1] = ptrs[block] + counts[block];
864
+
865
+ vector<mf_node*> pivots(ptrs.begin(), ptrs.end()-1);
866
+ for(mf_int block = 0; block < nr_bins*nr_bins; ++block)
867
+ {
868
+ for(mf_node* pivot = pivots[block]; pivot != ptrs[block+1];)
869
+ {
870
+ mf_int curr_block = get_block_id(pivot->u, pivot->v);
871
+ if(curr_block == block)
872
+ {
873
+ ++pivot;
874
+ continue;
875
+ }
876
+
877
+ mf_node *next = pivots[curr_block];
878
+ swap(*pivot, *next);
879
+ pivots[curr_block] += 1;
880
+ }
881
+ }
882
+
883
+ #if defined USEOMP
884
+ #pragma omp parallel for num_threads(nr_threads) schedule(dynamic)
885
+ #endif
886
+ for(mf_int block = 0; block < nr_bins*nr_bins; ++block)
887
+ {
888
+ if(prob.m > prob.n)
889
+ sort(ptrs[block], ptrs[block+1], sort_node_by_p());
890
+ else
891
+ sort(ptrs[block], ptrs[block+1], sort_node_by_q());
892
+ }
893
+
894
+ for(mf_int i = 0; i < (mf_long)blocks.size(); ++i)
895
+ blocks[i].tie_to(ptrs[i], ptrs[i+1]);
896
+
897
+ return ptrs;
898
+ }
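+ // grid_problem() partitions R in place: a counting pass sizes the
+ // nr_bins*nr_bins blocks, a pivot-and-swap pass moves every node into its
+ // block's range, and each block is finally sorted by row index (if m > n) or
+ // by column index, which keeps consecutive updates close together in memory.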
899
+
900
+ void Utility::grid_shuffle_scale_problem_on_disk(
901
+ mf_int m, mf_int n, mf_int nr_bins,
902
+ mf_float scale, string data_path,
903
+ vector<mf_int> &p_map, vector<mf_int> &q_map,
904
+ vector<mf_int> &omega_p, vector<mf_int> &omega_q,
905
+ vector<BlockOnDisk> &blocks)
906
+ {
907
+ string const buffer_path = data_path+string(".disk");
908
+ mf_int seg_p = (mf_int)ceil((double)m/nr_bins);
909
+ mf_int seg_q = (mf_int)ceil((double)n/nr_bins);
910
+ vector<mf_long> counts(nr_bins*nr_bins+1, 0);
911
+ vector<mf_long> pivots(nr_bins*nr_bins, 0);
912
+ ifstream source(data_path);
913
+ fstream buffer(buffer_path, fstream::in|fstream::out|
914
+ fstream::binary|fstream::trunc);
915
+ auto get_block_id = [=] (mf_int u, mf_int v)
916
+ {
917
+ return (u/seg_p)*nr_bins+v/seg_q;
918
+ };
919
+
920
+ if(!source)
921
+ throw ios::failure(string("cannot to open ")+data_path);
922
+ if(!buffer)
923
+ throw ios::failure(string("cannot to open ")+buffer_path);
924
+
925
+ for(mf_node N; source >> N.u >> N.v >> N.r;)
926
+ {
927
+ N.u = p_map[N.u];
928
+ N.v = q_map[N.v];
929
+ mf_int bid = get_block_id(N.u, N.v);
930
+ omega_p[N.u] += 1;
931
+ omega_q[N.v] += 1;
932
+ counts[bid+1] += 1;
933
+ }
934
+
935
+ for(mf_int i = 1; i < nr_bins*nr_bins+1; ++i)
936
+ {
937
+ counts[i] += counts[i-1];
938
+ pivots[i-1] = counts[i-1];
939
+ }
940
+
941
+ source.clear();
942
+ source.seekg(0);
943
+ for(mf_node N; source >> N.u >> N.v >> N.r;)
944
+ {
945
+ N.u = p_map[N.u];
946
+ N.v = q_map[N.v];
947
+ N.r /= scale;
948
+ mf_int bid = get_block_id(N.u, N.v);
949
+ buffer.seekp(pivots[bid]*sizeof(mf_node));
950
+ buffer.write((char*)&N, sizeof(mf_node));
951
+ pivots[bid] += 1;
952
+ }
953
+
954
+ for(mf_int i = 0; i < nr_bins*nr_bins; ++i)
955
+ {
956
+ vector<mf_node> nodes(static_cast<size_t>(counts[i+1]-counts[i]));
957
+ buffer.clear();
958
+ buffer.seekg(counts[i]*sizeof(mf_node));
959
+ buffer.read((char*)nodes.data(), sizeof(mf_node)*nodes.size());
960
+
961
+ if(m > n)
962
+ sort(nodes.begin(), nodes.end(), sort_node_by_p());
963
+ else
964
+ sort(nodes.begin(), nodes.end(), sort_node_by_q());
965
+
966
+ buffer.clear();
967
+ buffer.seekp(counts[i]*sizeof(mf_node));
968
+ buffer.write((char*)nodes.data(), sizeof(mf_node)*nodes.size());
969
+ buffer.read((char*)nodes.data(), sizeof(mf_node)*nodes.size());
970
+ }
971
+
972
+ for(mf_int i = 0; i < (mf_long)blocks.size(); ++i)
973
+ blocks[i].tie_to(buffer_path, counts[i], counts[i+1]);
974
+ }
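+ // The on-disk variant reads the text-format ratings twice: the first pass
+ // counts entries per block and accumulates omega_p/omega_q, the second pass
+ // writes each shuffled and scaled node at its block's offset in the binary
+ // ".disk" cache file, and every block is then loaded, sorted, and written
+ // back in place.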
975
+
976
+ mf_float* Utility::malloc_aligned_float(mf_long size)
977
+ {
978
+ // Check if conversion from mf_long to size_t causes overflow.
979
+ if (size > numeric_limits<std::size_t>::max() / sizeof(mf_float) + 1)
980
+ throw bad_alloc();
981
+ // [REVIEW] I hope one day we can use C11 aligned_alloc to replace
982
+ // platform-dependent functions below. Neither Windows nor OSX currently
983
+ // supports that function.
984
+ void *ptr = nullptr;
985
+ #ifdef _WIN32
986
+ ptr = _aligned_malloc(static_cast<size_t>(size*sizeof(mf_float)),
987
+ kALIGNByte);
988
+ #else
989
+ int status = posix_memalign(&ptr, kALIGNByte, size*sizeof(mf_float));
990
+ if(status != 0)
991
+ throw bad_alloc();
992
+ #endif
993
+ if(ptr == nullptr)
994
+ throw bad_alloc();
995
+
996
+ return (mf_float*)ptr;
997
+ }
998
+
999
+ void Utility::free_aligned_float(mf_float *ptr)
1000
+ {
1001
+ #ifdef _WIN32
1002
+ // Unfortunately, Visual Studio doesn't support the
1003
+ // cross-platform deallocation below, so _aligned_free is used instead.
1004
+ _aligned_free(ptr);
1005
+ #else
1006
+ free(ptr);
1007
+ #endif
1008
+ }
1009
+
1010
+ mf_model* Utility::init_model(mf_int fun,
1011
+ mf_int m, mf_int n,
1012
+ mf_int k, mf_float avg,
1013
+ vector<mf_int> &omega_p,
1014
+ vector<mf_int> &omega_q)
1015
+ {
1016
+ mf_int k_real = k;
1017
+ mf_int k_aligned = (mf_int)ceil(mf_double(k)/kALIGN)*kALIGN;
1018
+
1019
+ mf_model *model = new mf_model;
1020
+
1021
+ model->fun = fun;
1022
+ model->m = m;
1023
+ model->n = n;
1024
+ model->k = k_aligned;
1025
+ model->b = avg;
1026
+ model->P = nullptr;
1027
+ model->Q = nullptr;
1028
+
1029
+ mf_float scale = (mf_float)sqrt(1.0/k_real);
1030
+ default_random_engine generator;
1031
+ uniform_real_distribution<mf_float> distribution(0.0, 1.0);
1032
+
1033
+ try
1034
+ {
1035
+ model->P = Utility::malloc_aligned_float((mf_long)model->m*model->k);
1036
+ model->Q = Utility::malloc_aligned_float((mf_long)model->n*model->k);
1037
+ }
1038
+ catch(bad_alloc const &e)
1039
+ {
1040
+ cerr << e.what() << endl;
1041
+ mf_destroy_model(&model);
1042
+ throw;
1043
+ }
1044
+
1045
+ auto init1 = [&](mf_float *start_ptr, mf_long size, vector<mf_int> counts)
1046
+ {
1047
+ memset(start_ptr, 0, static_cast<size_t>(
1048
+ sizeof(mf_float) * size*model->k));
1049
+ for(mf_long i = 0; i < size; ++i)
1050
+ {
1051
+ mf_float * ptr = start_ptr + i*model->k;
1052
+ if(counts[static_cast<size_t>(i)] > 0)
1053
+ for(mf_long d = 0; d < k_real; ++d, ++ptr)
1054
+ *ptr = (mf_float)(distribution(generator)*scale);
1055
+ else
1056
+ if(fun != P_ROW_BPR_MFOC && fun != P_COL_BPR_MFOC) // unseen for bpr is 0
1057
+ for(mf_long d = 0; d < k_real; ++d, ++ptr)
1058
+ *ptr = numeric_limits<mf_float>::quiet_NaN();
1059
+ }
1060
+ };
1061
+
1062
+ init1(model->P, m, omega_p);
1063
+ init1(model->Q, n, omega_q);
1064
+
1065
+ return model;
1066
+ }
1067
+
1068
+ // Initialize P=[\bar{p}_1, ..., \bar{p}_d] and Q=[\bar{q}_1, ..., \bar{q}_d].
1069
+ // Note that \bar{q}_{kv} is Q[k*n+v] and \bar{p}_{ku} is P[k*m+u]. One may
1070
+ // notice that P and Q here are actually the transposes of P and Q in fpsg(...)
1071
+ // because fpsg(...) uses P^TQ (where P and Q are respectively k-by-m and
1072
+ // k-by-n) to approximate the given rating matrix R while ccd_one_class(...)
1073
+ // uses PQ^T (where P and Q are respectively m-by-k and n-by-k).
1074
+ mf_model* Utility::init_model(mf_int m, mf_int n, mf_int k)
1075
+ {
1076
+ mf_model *model = new mf_model;
1077
+
1078
+ model->fun = P_L2_MFOC;
1079
+ model->m = m;
1080
+ model->n = n;
1081
+ model->k = k;
1082
+ model->b = 0.0; // One-class matrix factorization doesn't have bias.
1083
+ model->P = nullptr;
1084
+ model->Q = nullptr;
1085
+
1086
+ try
1087
+ {
1088
+ model->P = Utility::malloc_aligned_float((mf_long)model->m*model->k);
1089
+ model->Q = Utility::malloc_aligned_float((mf_long)model->n*model->k);
1090
+ }
1091
+ catch(bad_alloc const &e)
1092
+ {
1093
+ cerr << e.what() << endl;
1094
+ mf_destroy_model(&model);
1095
+ throw;
1096
+ }
1097
+
1098
+ // Our initialization strategy is to set all of P's elements to zero and do
1099
+ // random initialization on Q. Thus, all initial predicted ratings are zero
1100
+ // since the approximated rating matrix is PQ^T.
1101
+
1102
+ // Initialize P with zeros
1103
+ for(mf_long i = 0; i < k * m; ++i)
1104
+ model->P[i] = 0.0;
1105
+
1106
+ // Initialize Q with random numbers
1107
+ default_random_engine generator;
1108
+ uniform_real_distribution<mf_float> distribution(0.0, 1.0);
1109
+ for(mf_long i = 0; i < k * n; ++i)
1110
+ model->Q[i] = distribution(generator);
1111
+
1112
+ return model;
1113
+ }
1114
+
1115
+ vector<mf_int> Utility::gen_random_map(mf_int size)
1116
+ {
1117
+ srand(0);
1118
+ vector<mf_int> map(size, 0);
1119
+ for(mf_int i = 0; i < size; ++i)
1120
+ map[i] = i;
1121
+ random_shuffle(map.begin(), map.end());
1122
+ return map;
1123
+ }
1124
+
1125
+ vector<mf_int> Utility::gen_inv_map(vector<mf_int> &map)
1126
+ {
1127
+ vector<mf_int> inv_map(map.size());
1128
+ for(mf_int i = 0; i < (mf_long)map.size(); ++i)
1129
+ inv_map[map[i]] = i;
1130
+ return inv_map;
1131
+ }
1132
+
1133
+ void Utility::shuffle_model(
1134
+ mf_model &model,
1135
+ vector<mf_int> &p_map,
1136
+ vector<mf_int> &q_map)
1137
+ {
1138
+ auto inv_shuffle1 = [] (mf_float *vec, vector<mf_int> &map,
1139
+ mf_int size, mf_int k)
1140
+ {
1141
+ for(mf_int pivot = 0; pivot < size;)
1142
+ {
1143
+ if(pivot == map[pivot])
1144
+ {
1145
+ ++pivot;
1146
+ continue;
1147
+ }
1148
+
1149
+ mf_int next = map[pivot];
1150
+
1151
+ for(mf_int d = 0; d < k; ++d)
1152
+ swap(*(vec+(mf_long)pivot*k+d), *(vec+(mf_long)next*k+d));
1153
+
1154
+ map[pivot] = map[next];
1155
+ map[next] = next;
1156
+ }
1157
+ };
1158
+
1159
+ inv_shuffle1(model.P, p_map, model.m, model.k);
1160
+ inv_shuffle1(model.Q, q_map, model.n, model.k);
1161
+ }
1162
+
1163
+ void Utility::shrink_model(mf_model &model, mf_int k_new)
1164
+ {
1165
+ mf_int k_old = model.k;
1166
+ model.k = k_new;
1167
+
1168
+ auto shrink1 = [&] (mf_float *ptr, mf_int size)
1169
+ {
1170
+ for(mf_int i = 0; i < size; ++i)
1171
+ {
1172
+ mf_float *src = ptr+(mf_long)i*k_old;
1173
+ mf_float *dst = ptr+(mf_long)i*k_new;
1174
+ copy(src, src+k_new, dst);
1175
+ }
1176
+ };
1177
+
1178
+ shrink1(model.P, model.m);
1179
+ shrink1(model.Q, model.n);
1180
+ }
1181
+
1182
+ mf_problem* Utility::copy_problem(mf_problem const *prob, bool copy_data)
1183
+ {
1184
+ mf_problem *new_prob = new mf_problem;
1185
+
1186
+ if(prob == nullptr)
1187
+ {
1188
+ new_prob->m = 0;
1189
+ new_prob->n = 0;
1190
+ new_prob->nnz = 0;
1191
+ new_prob->R = nullptr;
1192
+
1193
+ return new_prob;
1194
+ }
1195
+
1196
+ new_prob->m = prob->m;
1197
+ new_prob->n = prob->n;
1198
+ new_prob->nnz = prob->nnz;
1199
+
1200
+ if(copy_data)
1201
+ {
1202
+ try
1203
+ {
1204
+ new_prob->R = new mf_node[static_cast<size_t>(prob->nnz)];
1205
+ copy(prob->R, prob->R+prob->nnz, new_prob->R);
1206
+ }
1207
+ catch(...)
1208
+ {
1209
+ delete new_prob;
1210
+ throw;
1211
+ }
1212
+ }
1213
+ else
1214
+ {
1215
+ new_prob->R = prob->R;
1216
+ }
1217
+
1218
+ return new_prob;
1219
+ }
1220
+
1221
+ //--------------------------------------
1222
+ //-----The base class of all solvers----
1223
+ //--------------------------------------
1224
+
1225
+ class SolverBase
1226
+ {
1227
+ public:
1228
+ SolverBase(Scheduler &scheduler, vector<BlockBase*> &blocks,
1229
+ mf_float *PG, mf_float *QG, mf_model &model, mf_parameter param,
1230
+ bool &slow_only)
1231
+ : scheduler(scheduler), blocks(blocks), PG(PG), QG(QG),
1232
+ model(model), param(param), slow_only(slow_only) {}
1233
+ void run();
1234
+ SolverBase(const SolverBase&) = delete;
1235
+ SolverBase& operator=(const SolverBase&) = delete;
1236
+ // The solver is a stateless functor, so the default destructor should be
1237
+ // good enough.
1238
+ virtual ~SolverBase() = default;
1239
+
1240
+ protected:
1241
+ #if defined USESSE
1242
+ static void calc_z(__m128 &XMMz, mf_int k, mf_float *p, mf_float *q);
1243
+ virtual void load_fixed_variables(
1244
+ __m128 &XMMlambda_p1, __m128 &XMMlambda_q1,
1245
+ __m128 &XMMlambda_p2, __m128 &XMMlabmda_q2,
1246
+ __m128 &XMMeta, __m128 &XMMrk_slow,
1247
+ __m128 &XMMrk_fast);
1248
+ virtual void arrange_block(__m128d &XMMloss, __m128d &XMMerror);
1249
+ virtual void prepare_for_sg_update(
1250
+ __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror) = 0;
1251
+ virtual void sg_update(mf_int d_begin, mf_int d_end, __m128 XMMz,
1252
+ __m128 XMMlambda_p1, __m128 XMMlambda_q1,
1253
+ __m128 XMMlambda_p2, __m128 XMMlamdba_q2,
1254
+ __m128 XMMeta, __m128 XMMrk) = 0;
1255
+ virtual void finalize(__m128d XMMloss, __m128d XMMerror);
1256
+ #elif defined USEAVX
1257
+ static void calc_z(__m256 &XMMz, mf_int k, mf_float *p, mf_float *q);
1258
+ virtual void load_fixed_variables(
1259
+ __m256 &XMMlambda_p1, __m256 &XMMlambda_q1,
1260
+ __m256 &XMMlambda_p2, __m256 &XMMlabmda_q2,
1261
+ __m256 &XMMeta, __m256 &XMMrk_slow,
1262
+ __m256 &XMMrk_fast);
1263
+ virtual void arrange_block(__m128d &XMMloss, __m128d &XMMerror);
1264
+ virtual void prepare_for_sg_update(
1265
+ __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror) = 0;
1266
+ virtual void sg_update(mf_int d_begin, mf_int d_end, __m256 XMMz,
1267
+ __m256 XMMlambda_p1, __m256 XMMlambda_q1,
1268
+ __m256 XMMlambda_p2, __m256 XMMlamdba_q2,
1269
+ __m256 XMMeta, __m256 XMMrk) = 0;
1270
+ virtual void finalize(__m128d XMMloss, __m128d XMMerror);
1271
+ #else
1272
+ static void calc_z(mf_float &z, mf_int k, mf_float *p, mf_float *q);
1273
+ virtual void load_fixed_variables();
1274
+ virtual void arrange_block();
1275
+ virtual void prepare_for_sg_update() = 0;
1276
+ virtual void sg_update(mf_int d_begin, mf_int d_end, mf_float rk) = 0;
1277
+ virtual void finalize();
1278
+ static float qrsqrt(float x);
1279
+ #endif
1280
+ virtual void update() { ++pG; ++qG; };
1281
+
1282
+ Scheduler &scheduler;
1283
+ vector<BlockBase*> &blocks;
1284
+ BlockBase *block;
1285
+ mf_float *PG;
1286
+ mf_float *QG;
1287
+ mf_model &model;
1288
+ mf_parameter param;
1289
+ bool &slow_only;
1290
+
1291
+ mf_node *N;
1292
+ mf_float z;
1293
+ mf_double loss;
1294
+ mf_double error;
1295
+ mf_float *p;
1296
+ mf_float *q;
1297
+ mf_float *pG;
1298
+ mf_float *qG;
1299
+ mf_int bid;
1300
+
1301
+ mf_float lambda_p1;
1302
+ mf_float lambda_q1;
1303
+ mf_float lambda_p2;
1304
+ mf_float lambda_q2;
1305
+ mf_float rk_slow;
1306
+ mf_float rk_fast;
1307
+ };
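+ // Each worker thread executes run(): it repeatedly takes a block from the
+ // scheduler, streams its nonzeros, and for every rating updates the first
+ // kALIGN factor dimensions; when slow_only is false it advances pG/qG to the
+ // second per-row gradient slot (PG and QG store two entries per row) and
+ // updates the remaining dimensions as well.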
1308
+
1309
+ #if defined USESSE
1310
+ inline void SolverBase::run()
1311
+ {
1312
+ __m128d XMMloss;
1313
+ __m128d XMMerror;
1314
+ __m128 XMMz;
1315
+ __m128 XMMlambda_p1;
1316
+ __m128 XMMlambda_q1;
1317
+ __m128 XMMlambda_p2;
1318
+ __m128 XMMlambda_q2;
1319
+ __m128 XMMeta;
1320
+ __m128 XMMrk_slow;
1321
+ __m128 XMMrk_fast;
1322
+ load_fixed_variables(XMMlambda_p1, XMMlambda_q1,
1323
+ XMMlambda_p2, XMMlambda_q2,
1324
+ XMMeta, XMMrk_slow,
1325
+ XMMrk_fast);
1326
+ while(!scheduler.is_terminated())
1327
+ {
1328
+ arrange_block(XMMloss, XMMerror);
1329
+ while(block->move_next())
1330
+ {
1331
+ N = block->get_current();
1332
+ p = model.P+(mf_long)N->u*model.k;
1333
+ q = model.Q+(mf_long)N->v*model.k;
1334
+ pG = PG+N->u*2;
1335
+ qG = QG+N->v*2;
1336
+ prepare_for_sg_update(XMMz, XMMloss, XMMerror);
1337
+ sg_update(0, kALIGN, XMMz, XMMlambda_p1, XMMlambda_q1,
1338
+ XMMlambda_p2, XMMlambda_q2, XMMeta, XMMrk_slow);
1339
+ if(slow_only)
1340
+ continue;
1341
+ update();
1342
+ sg_update(kALIGN, model.k, XMMz, XMMlambda_p1, XMMlambda_q1,
1343
+ XMMlambda_p2, XMMlambda_q2, XMMeta, XMMrk_fast);
1344
+ }
1345
+ finalize(XMMloss, XMMerror);
1346
+ }
1347
+ }
1348
+
1349
+ void SolverBase::load_fixed_variables(
1350
+ __m128 &XMMlambda_p1, __m128 &XMMlambda_q1,
1351
+ __m128 &XMMlambda_p2, __m128 &XMMlambda_q2,
1352
+ __m128 &XMMeta, __m128 &XMMrk_slow,
1353
+ __m128 &XMMrk_fast)
1354
+ {
1355
+ XMMlambda_p1 = _mm_set1_ps(param.lambda_p1);
1356
+ XMMlambda_q1 = _mm_set1_ps(param.lambda_q1);
1357
+ XMMlambda_p2 = _mm_set1_ps(param.lambda_p2);
1358
+ XMMlambda_q2 = _mm_set1_ps(param.lambda_q2);
1359
+ XMMeta = _mm_set1_ps(param.eta);
1360
+ XMMrk_slow = _mm_set1_ps((mf_float)1.0/kALIGN);
1361
+ XMMrk_fast = _mm_set1_ps((mf_float)1.0/(model.k-kALIGN));
1362
+ }
1363
+
1364
+ void SolverBase::arrange_block(__m128d &XMMloss, __m128d &XMMerror)
1365
+ {
1366
+ XMMloss = _mm_setzero_pd();
1367
+ XMMerror = _mm_setzero_pd();
1368
+ bid = scheduler.get_job();
1369
+ block = blocks[bid];
1370
+ block->reload();
1371
+ }
1372
+
1373
+ inline void SolverBase::calc_z(
1374
+ __m128 &XMMz, mf_int k, mf_float *p, mf_float *q)
1375
+ {
1376
+ XMMz = _mm_setzero_ps();
1377
+ for(mf_int d = 0; d < k; d += 4)
1378
+ XMMz = _mm_add_ps(XMMz, _mm_mul_ps(
1379
+ _mm_load_ps(p+d), _mm_load_ps(q+d)));
1380
+ // Bit-wise representation of 177 is {1,0}+{1,1}+{0,0}+{0,1} from
1381
+ // high-bit to low-bit, where "+" means concatenating two arrays.
1382
+ __m128 XMMtmp = _mm_add_ps(XMMz, _mm_shuffle_ps(XMMz, XMMz, 177));
1383
+ // Bit-wise representation of 78 is {0,1}+{0,0}+{1,1}+{1,0} from
1384
+ // high-bit to low-bit, where "+" means concatenating two arrays.
1385
+ XMMz = _mm_add_ps(XMMtmp, _mm_shuffle_ps(XMMtmp, XMMtmp, 78));
1386
+ }
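+ // After the two shuffle-and-add steps above, every lane of XMMz holds the
+ // full dot product of p and q, so any value the loss-specific code derives
+ // from XMMz is likewise broadcast across all four lanes.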
1387
+
1388
+ void SolverBase::finalize(__m128d XMMloss, __m128d XMMerror)
1389
+ {
1390
+ _mm_store_sd(&loss, XMMloss);
1391
+ _mm_store_sd(&error, XMMerror);
1392
+ block->free();
1393
+ scheduler.put_job(bid, loss, error);
1394
+ }
1395
+ #elif defined USEAVX
1396
+ inline void SolverBase::run()
1397
+ {
1398
+ __m128d XMMloss;
1399
+ __m128d XMMerror;
1400
+ __m256 XMMz;
1401
+ __m256 XMMlambda_p1;
1402
+ __m256 XMMlambda_q1;
1403
+ __m256 XMMlambda_p2;
1404
+ __m256 XMMlambda_q2;
1405
+ __m256 XMMeta;
1406
+ __m256 XMMrk_slow;
1407
+ __m256 XMMrk_fast;
1408
+ load_fixed_variables(XMMlambda_p1, XMMlambda_q1,
1409
+ XMMlambda_p2, XMMlambda_q2,
1410
+ XMMeta, XMMrk_slow, XMMrk_fast);
1411
+ while(!scheduler.is_terminated())
1412
+ {
1413
+ arrange_block(XMMloss, XMMerror);
1414
+ while(block->move_next())
1415
+ {
1416
+ N = block->get_current();
1417
+ p = model.P+(mf_long)N->u*model.k;
1418
+ q = model.Q+(mf_long)N->v*model.k;
1419
+ pG = PG+N->u*2;
1420
+ qG = QG+N->v*2;
1421
+ prepare_for_sg_update(XMMz, XMMloss, XMMerror);
1422
+ sg_update(0, kALIGN, XMMz, XMMlambda_p1, XMMlambda_q1,
1423
+ XMMlambda_p2, XMMlambda_q2, XMMeta, XMMrk_slow);
1424
+ if(slow_only)
1425
+ continue;
1426
+ update();
1427
+ sg_update(kALIGN, model.k, XMMz, XMMlambda_p1, XMMlambda_q1,
1428
+ XMMlambda_p2, XMMlambda_q2, XMMeta, XMMrk_fast);
1429
+ }
1430
+ finalize(XMMloss, XMMerror);
1431
+ }
1432
+ }
1433
+
1434
+ void SolverBase::load_fixed_variables(
1435
+ __m256 &XMMlambda_p1, __m256 &XMMlambda_q1,
1436
+ __m256 &XMMlambda_p2, __m256 &XMMlambda_q2,
1437
+ __m256 &XMMeta, __m256 &XMMrk_slow,
1438
+ __m256 &XMMrk_fast)
1439
+ {
1440
+ XMMlambda_p1 = _mm256_set1_ps(param.lambda_p1);
1441
+ XMMlambda_q1 = _mm256_set1_ps(param.lambda_q1);
1442
+ XMMlambda_p2 = _mm256_set1_ps(param.lambda_p2);
1443
+ XMMlambda_q2 = _mm256_set1_ps(param.lambda_q2);
1444
+ XMMeta = _mm256_set1_ps(param.eta);
1445
+ XMMrk_slow = _mm256_set1_ps((mf_float)1.0/kALIGN);
1446
+ XMMrk_fast = _mm256_set1_ps((mf_float)1.0/(model.k-kALIGN));
1447
+ }
1448
+
1449
+ void SolverBase::arrange_block(__m128d &XMMloss, __m128d &XMMerror)
1450
+ {
1451
+ XMMloss = _mm_setzero_pd();
1452
+ XMMerror = _mm_setzero_pd();
1453
+ bid = scheduler.get_job();
1454
+ block = blocks[bid];
1455
+ block->reload();
1456
+ }
1457
+
1458
+ inline void SolverBase::calc_z(
1459
+ __m256 &XMMz, mf_int k, mf_float *p, mf_float *q)
1460
+ {
1461
+ XMMz = _mm256_setzero_ps();
1462
+ for(mf_int d = 0; d < k; d += 8)
1463
+ XMMz = _mm256_add_ps(XMMz, _mm256_mul_ps(
1464
+ _mm256_load_ps(p+d), _mm256_load_ps(q+d)));
1465
+ XMMz = _mm256_add_ps(XMMz, _mm256_permute2f128_ps(XMMz, XMMz, 0x1));
1466
+ XMMz = _mm256_hadd_ps(XMMz, XMMz);
1467
+ XMMz = _mm256_hadd_ps(XMMz, XMMz);
1468
+ }
1469
+
1470
+ void SolverBase::finalize(__m128d XMMloss, __m128d XMMerror)
1471
+ {
1472
+ _mm_store_sd(&loss, XMMloss);
1473
+ _mm_store_sd(&error, XMMerror);
1474
+ block->free();
1475
+ scheduler.put_job(bid, loss, error);
1476
+ }
1477
+ #else
1478
+ inline void SolverBase::run()
1479
+ {
1480
+ load_fixed_variables();
1481
+ while(!scheduler.is_terminated())
1482
+ {
1483
+ arrange_block();
1484
+ while(block->move_next())
1485
+ {
1486
+ N = block->get_current();
1487
+ p = model.P+(mf_long)N->u*model.k;
1488
+ q = model.Q+(mf_long)N->v*model.k;
1489
+ pG = PG+N->u*2;
1490
+ qG = QG+N->v*2;
1491
+ prepare_for_sg_update();
1492
+ sg_update(0, kALIGN, rk_slow);
1493
+ if(slow_only)
1494
+ continue;
1495
+ update();
1496
+ sg_update(kALIGN, model.k, rk_fast);
1497
+ }
1498
+ finalize();
1499
+ }
1500
+ }
1501
+
1502
+ inline float SolverBase::qrsqrt(float x)
1503
+ {
1504
+ float xhalf = 0.5f*x;
1505
+ uint32_t i;
1506
+ memcpy(&i, &x, sizeof(i));
1507
+ i = 0x5f375a86 - (i>>1);
1508
+ memcpy(&x, &i, sizeof(i));
1509
+ x = x*(1.5f - xhalf*x*x);
1510
+ return x;
1511
+ }
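+ // qrsqrt() is the classic fast inverse square root: a bit-level initial
+ // guess (constant 0x5f375a86) refined by one Newton-Raphson step. It plays
+ // the same role as _mm_rsqrt_ps/_mm256_rsqrt_ps in the SIMD paths,
+ // approximating 1/sqrt(G) for the adaptive per-row learning rates.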
1512
+
1513
+ void SolverBase::load_fixed_variables()
1514
+ {
1515
+ lambda_p1 = param.lambda_p1;
1516
+ lambda_q1 = param.lambda_q1;
1517
+ lambda_p2 = param.lambda_p2;
1518
+ lambda_q2 = param.lambda_q2;
1519
+ rk_slow = (mf_float)1.0/kALIGN;
1520
+ rk_fast = (mf_float)1.0/(model.k-kALIGN);
1521
+ }
1522
+
1523
+ void SolverBase::arrange_block()
1524
+ {
1525
+ loss = 0.0;
1526
+ error = 0.0;
1527
+ bid = scheduler.get_job();
1528
+ block = blocks[bid];
1529
+ block->reload();
1530
+ }
1531
+
1532
+ inline void SolverBase::calc_z(mf_float &z, mf_int k, mf_float *p, mf_float *q)
1533
+ {
1534
+ z = 0;
1535
+ for(mf_int d = 0; d < k; ++d)
1536
+ z += p[d]*q[d];
1537
+ }
1538
+
1539
+ void SolverBase::finalize()
1540
+ {
1541
+ block->free();
1542
+ scheduler.put_job(bid, loss, error);
1543
+ }
1544
+ #endif
1545
+
1546
+ //--------------------------------------
1547
+ //-----Real-valued MF and binary MF-----
1548
+ //--------------------------------------
1549
+
1550
+ class MFSolver: public SolverBase
1551
+ {
1552
+ public:
1553
+ MFSolver(Scheduler &scheduler, vector<BlockBase*> &blocks,
1554
+ mf_float *PG, mf_float *QG, mf_model &model,
1555
+ mf_parameter param, bool &slow_only)
1556
+ : SolverBase(scheduler, blocks, PG, QG, model, param, slow_only) {}
1557
+
1558
+ protected:
1559
+ #if defined USESSE
1560
+ void sg_update(mf_int d_begin, mf_int d_end, __m128 XMMz,
1561
+ __m128 XMMlambda_p1, __m128 XMMlambda_q1,
1562
+ __m128 XMMlambda_p2, __m128 XMMlambda_q2,
1563
+ __m128 XMMeta, __m128 XMMrk);
1564
+ #elif defined USEAVX
1565
+ void sg_update(mf_int d_begin, mf_int d_end, __m256 XMMz,
1566
+ __m256 XMMlambda_p1, __m256 XMMlambda_q1,
1567
+ __m256 XMMlambda_p2, __m256 XMMlambda_q2,
1568
+ __m256 XMMeta, __m256 XMMrk);
1569
+ #else
1570
+ void sg_update(mf_int d_begin, mf_int d_end, mf_float rk);
1571
+ #endif
1572
+ };
1573
+
1574
+ #if defined USESSE
1575
+ void MFSolver::sg_update(mf_int d_begin, mf_int d_end, __m128 XMMz,
1576
+ __m128 XMMlambda_p1, __m128 XMMlambda_q1,
1577
+ __m128 XMMlambda_p2, __m128 XMMlambda_q2,
1578
+ __m128 XMMeta, __m128 XMMrk)
1579
+ {
1580
+ __m128 XMMpG = _mm_load1_ps(pG);
1581
+ __m128 XMMqG = _mm_load1_ps(qG);
1582
+ __m128 XMMeta_p = _mm_mul_ps(XMMeta, _mm_rsqrt_ps(XMMpG));
1583
+ __m128 XMMeta_q = _mm_mul_ps(XMMeta, _mm_rsqrt_ps(XMMqG));
1584
+ __m128 XMMpG1 = _mm_setzero_ps();
1585
+ __m128 XMMqG1 = _mm_setzero_ps();
1586
+
1587
+ for(mf_int d = d_begin; d < d_end; d += 4)
1588
+ {
1589
+ __m128 XMMp = _mm_load_ps(p+d);
1590
+ __m128 XMMq = _mm_load_ps(q+d);
1591
+
1592
+ __m128 XMMpg = _mm_sub_ps(_mm_mul_ps(XMMlambda_p2, XMMp),
1593
+ _mm_mul_ps(XMMz, XMMq));
1594
+ __m128 XMMqg = _mm_sub_ps(_mm_mul_ps(XMMlambda_q2, XMMq),
1595
+ _mm_mul_ps(XMMz, XMMp));
1596
+
1597
+ XMMpG1 = _mm_add_ps(XMMpG1, _mm_mul_ps(XMMpg, XMMpg));
1598
+ XMMqG1 = _mm_add_ps(XMMqG1, _mm_mul_ps(XMMqg, XMMqg));
1599
+
1600
+ XMMp = _mm_sub_ps(XMMp, _mm_mul_ps(XMMeta_p, XMMpg));
1601
+ XMMq = _mm_sub_ps(XMMq, _mm_mul_ps(XMMeta_q, XMMqg));
1602
+
1603
+ _mm_store_ps(p+d, XMMp);
1604
+ _mm_store_ps(q+d, XMMq);
1605
+ }
1606
+
1607
+ mf_float tmp = 0;
1608
+ _mm_store_ss(&tmp, XMMlambda_p1);
1609
+ if(tmp > 0)
1610
+ {
1611
+ for(mf_int d = d_begin; d < d_end; d += 4)
1612
+ {
1613
+ __m128 XMMp = _mm_load_ps(p+d);
1614
+ __m128 XMMflip = _mm_and_ps(_mm_cmple_ps(XMMp, _mm_set1_ps(0.0f)),
1615
+ _mm_set1_ps(-0.0f));
1616
+ XMMp = _mm_xor_ps(XMMflip,
1617
+ _mm_max_ps(_mm_sub_ps(_mm_xor_ps(XMMp, XMMflip),
1618
+ _mm_mul_ps(XMMeta_p, XMMlambda_p1)), _mm_set1_ps(0.0f)));
1619
+ _mm_store_ps(p+d, XMMp);
1620
+ }
1621
+ }
1622
+
1623
+ _mm_store_ss(&tmp, XMMlambda_q1);
1624
+ if(tmp > 0)
1625
+ {
1626
+ for(mf_int d = d_begin; d < d_end; d += 4)
1627
+ {
1628
+ __m128 XMMq = _mm_load_ps(q+d);
1629
+ __m128 XMMflip = _mm_and_ps(_mm_cmple_ps(XMMq, _mm_set1_ps(0.0f)),
1630
+ _mm_set1_ps(-0.0f));
1631
+ XMMq = _mm_xor_ps(XMMflip,
1632
+ _mm_max_ps(_mm_sub_ps(_mm_xor_ps(XMMq, XMMflip),
1633
+ _mm_mul_ps(XMMeta_q, XMMlambda_q1)), _mm_set1_ps(0.0f)));
1634
+ _mm_store_ps(q+d, XMMq);
1635
+ }
1636
+ }
1637
+
1638
+ if(param.do_nmf)
1639
+ {
1640
+ for(mf_int d = d_begin; d < d_end; d += 4)
1641
+ {
1642
+ __m128 XMMp = _mm_load_ps(p+d);
1643
+ __m128 XMMq = _mm_load_ps(q+d);
1644
+ XMMp = _mm_max_ps(XMMp, _mm_set1_ps(0.0f));
1645
+ XMMq = _mm_max_ps(XMMq, _mm_set1_ps(0.0f));
1646
+ _mm_store_ps(p+d, XMMp);
1647
+ _mm_store_ps(q+d, XMMq);
1648
+ }
1649
+ }
1650
+
1651
+ __m128 XMMtmp = _mm_add_ps(XMMpG1, _mm_movehl_ps(XMMpG1, XMMpG1));
1652
+ XMMpG1 = _mm_add_ps(XMMtmp, _mm_shuffle_ps(XMMtmp, XMMtmp, 1));
1653
+ XMMpG = _mm_add_ps(XMMpG, _mm_mul_ps(XMMpG1, XMMrk));
1654
+ _mm_store_ss(pG, XMMpG);
1655
+
1656
+ XMMtmp = _mm_add_ps(XMMqG1, _mm_movehl_ps(XMMqG1, XMMqG1));
1657
+ XMMqG1 = _mm_add_ps(XMMtmp, _mm_shuffle_ps(XMMtmp, XMMtmp, 1));
1658
+ XMMqG = _mm_add_ps(XMMqG, _mm_mul_ps(XMMqG1, XMMrk));
1659
+ _mm_store_ss(qG, XMMqG);
1660
+ }
1661
+ #elif defined USEAVX
1662
+ void MFSolver::sg_update(mf_int d_begin, mf_int d_end, __m256 XMMz,
1663
+ __m256 XMMlambda_p1, __m256 XMMlambda_q1,
1664
+ __m256 XMMlambda_p2, __m256 XMMlambda_q2,
1665
+ __m256 XMMeta, __m256 XMMrk)
1666
+ {
1667
+ __m256 XMMpG = _mm256_broadcast_ss(pG);
1668
+ __m256 XMMqG = _mm256_broadcast_ss(qG);
1669
+ __m256 XMMeta_p = _mm256_mul_ps(XMMeta, _mm256_rsqrt_ps(XMMpG));
1670
+ __m256 XMMeta_q = _mm256_mul_ps(XMMeta, _mm256_rsqrt_ps(XMMqG));
1671
+ __m256 XMMpG1 = _mm256_setzero_ps();
1672
+ __m256 XMMqG1 = _mm256_setzero_ps();
1673
+
1674
+ for(mf_int d = d_begin; d < d_end; d += 8)
1675
+ {
1676
+ __m256 XMMp = _mm256_load_ps(p+d);
1677
+ __m256 XMMq = _mm256_load_ps(q+d);
1678
+
1679
+ __m256 XMMpg = _mm256_sub_ps(_mm256_mul_ps(XMMlambda_p2, XMMp),
1680
+ _mm256_mul_ps(XMMz, XMMq));
1681
+ __m256 XMMqg = _mm256_sub_ps(_mm256_mul_ps(XMMlambda_q2, XMMq),
1682
+ _mm256_mul_ps(XMMz, XMMp));
1683
+
1684
+ XMMpG1 = _mm256_add_ps(XMMpG1, _mm256_mul_ps(XMMpg, XMMpg));
1685
+ XMMqG1 = _mm256_add_ps(XMMqG1, _mm256_mul_ps(XMMqg, XMMqg));
1686
+
1687
+ XMMp = _mm256_sub_ps(XMMp, _mm256_mul_ps(XMMeta_p, XMMpg));
1688
+ XMMq = _mm256_sub_ps(XMMq, _mm256_mul_ps(XMMeta_q, XMMqg));
1689
+ _mm256_store_ps(p+d, XMMp);
1690
+ _mm256_store_ps(q+d, XMMq);
1691
+ }
1692
+
1693
+ mf_float tmp = 0;
1694
+ _mm_store_ss(&tmp, _mm256_castps256_ps128(XMMlambda_p1));
1695
+ if(tmp > 0)
1696
+ {
1697
+ for(mf_int d = d_begin; d < d_end; d += 8)
1698
+ {
1699
+ __m256 XMMp = _mm256_load_ps(p+d);
1700
+ __m256 XMMflip = _mm256_and_ps(_mm256_cmp_ps(XMMp,
1701
+ _mm256_set1_ps(0.0f), _CMP_LE_OS),
1702
+ _mm256_set1_ps(-0.0f));
1703
+ XMMp = _mm256_xor_ps(XMMflip,
1704
+ _mm256_max_ps(_mm256_sub_ps(
1705
+ _mm256_xor_ps(XMMp, XMMflip),
1706
+ _mm256_mul_ps(XMMeta_p, XMMlambda_p1)),
1707
+ _mm256_set1_ps(0.0f)));
1708
+ _mm256_store_ps(p+d, XMMp);
1709
+ }
1710
+ }
1711
+
1712
+ _mm_store_ss(&tmp, _mm256_castps256_ps128(XMMlambda_q1));
1713
+ if(tmp > 0)
1714
+ {
1715
+ for(mf_int d = d_begin; d < d_end; d += 8)
1716
+ {
1717
+ __m256 XMMq = _mm256_load_ps(q+d);
1718
+ __m256 XMMflip = _mm256_and_ps(_mm256_cmp_ps(XMMq,
1719
+ _mm256_set1_ps(0.0f), _CMP_LE_OS),
1720
+ _mm256_set1_ps(-0.0f));
1721
+ XMMq = _mm256_xor_ps(XMMflip,
1722
+ _mm256_max_ps(_mm256_sub_ps(
1723
+ _mm256_xor_ps(XMMq, XMMflip),
1724
+ _mm256_mul_ps(XMMeta_q, XMMlambda_q1)),
1725
+ _mm256_set1_ps(0.0f)));
1726
+ _mm256_store_ps(q+d, XMMq);
1727
+ }
1728
+ }
1729
+
1730
+ if(param.do_nmf)
1731
+ {
1732
+ for(mf_int d = d_begin; d < d_end; d += 8)
1733
+ {
1734
+ __m256 XMMp = _mm256_load_ps(p+d);
1735
+ __m256 XMMq = _mm256_load_ps(q+d);
1736
+ XMMp = _mm256_max_ps(XMMp, _mm256_set1_ps(0));
1737
+ XMMq = _mm256_max_ps(XMMq, _mm256_set1_ps(0));
1738
+ _mm256_store_ps(p+d, XMMp);
1739
+ _mm256_store_ps(q+d, XMMq);
1740
+ }
1741
+ }
1742
+
1743
+ XMMpG1 = _mm256_add_ps(XMMpG1,
1744
+ _mm256_permute2f128_ps(XMMpG1, XMMpG1, 0x1));
1745
+ XMMpG1 = _mm256_hadd_ps(XMMpG1, XMMpG1);
1746
+ XMMpG1 = _mm256_hadd_ps(XMMpG1, XMMpG1);
1747
+
1748
+ XMMqG1 = _mm256_add_ps(XMMqG1,
1749
+ _mm256_permute2f128_ps(XMMqG1, XMMqG1, 0x1));
1750
+ XMMqG1 = _mm256_hadd_ps(XMMqG1, XMMqG1);
1751
+ XMMqG1 = _mm256_hadd_ps(XMMqG1, XMMqG1);
1752
+
1753
+ XMMpG = _mm256_add_ps(XMMpG, _mm256_mul_ps(XMMpG1, XMMrk));
1754
+ XMMqG = _mm256_add_ps(XMMqG, _mm256_mul_ps(XMMqG1, XMMrk));
1755
+
1756
+ _mm_store_ss(pG, _mm256_castps256_ps128(XMMpG));
1757
+ _mm_store_ss(qG, _mm256_castps256_ps128(XMMqG));
1758
+ }
1759
+ #else
1760
+ void MFSolver::sg_update(mf_int d_begin, mf_int d_end, mf_float rk)
1761
+ {
1762
+ mf_float eta_p = param.eta*qrsqrt(*pG);
1763
+ mf_float eta_q = param.eta*qrsqrt(*qG);
1764
+
1765
+ mf_float pG1 = 0;
1766
+ mf_float qG1 = 0;
1767
+
1768
+ for(mf_int d = d_begin; d < d_end; ++d)
1769
+ {
1770
+ mf_float gp = -z*q[d]+lambda_p2*p[d];
1771
+ mf_float gq = -z*p[d]+lambda_q2*q[d];
1772
+
1773
+ pG1 += gp*gp;
1774
+ qG1 += gq*gq;
1775
+
1776
+ p[d] -= eta_p*gp;
1777
+ q[d] -= eta_q*gq;
1778
+ }
1779
+
1780
+ if(lambda_p1 > 0)
1781
+ {
1782
+ for(mf_int d = d_begin; d < d_end; ++d)
1783
+ {
1784
+ mf_float p1 = max(abs(p[d])-lambda_p1*eta_p, 0.0f);
1785
+ p[d] = p[d] >= 0? p1: -p1;
1786
+ }
1787
+ }
1788
+
1789
+ if(lambda_q1 > 0)
1790
+ {
1791
+ for(mf_int d = d_begin; d < d_end; ++d)
1792
+ {
1793
+ mf_float q1 = max(abs(q[d])-lambda_q1*eta_q, 0.0f);
1794
+ q[d] = q[d] >= 0? q1: -q1;
1795
+ }
1796
+ }
1797
+
1798
+ if(param.do_nmf)
1799
+ {
1800
+ for(mf_int d = d_begin; d < d_end; ++d)
1801
+ {
1802
+ p[d] = max(p[d], (mf_float)0.0f);
1803
+ q[d] = max(q[d], (mf_float)0.0f);
1804
+ }
1805
+ }
1806
+
1807
+ *pG += pG1*rk;
1808
+ *qG += qG1*rk;
1809
+ }
1810
+ #endif
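+ // In all three variants of MFSolver::sg_update above, one SGD step is
+ // applied to row u of P and row v of Q: the gradient combines the loss term
+ // (through z prepared by the subclass) with the L2 penalties, an optional L1
+ // soft-thresholding step and a non-negativity projection (do_nmf) follow,
+ // and the accumulated squared gradients are folded back into the pG/qG
+ // entries that scale future learning rates.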
1811
+
1812
+ class L2_MFR : public MFSolver
1813
+ {
1814
+ public:
1815
+ L2_MFR(Scheduler &scheduler, vector<BlockBase*> &blocks, mf_float *PG, mf_float *QG,
1816
+ mf_model &model, mf_parameter param, bool &slow_only)
1817
+ : MFSolver(scheduler, blocks, PG, QG, model, param, slow_only) {}
1818
+
1819
+ protected:
1820
+ #if defined USESSE
1821
+ void prepare_for_sg_update(
1822
+ __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
1823
+ #elif defined USEAVX
1824
+ void prepare_for_sg_update(
1825
+ __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
1826
+ #else
1827
+ void prepare_for_sg_update();
1828
+ #endif
1829
+ };
1830
+
1831
+ #if defined USESSE
1832
+ void L2_MFR::prepare_for_sg_update(
1833
+ __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
1834
+ {
1835
+ calc_z(XMMz, model.k, p, q);
1836
+ XMMz = _mm_sub_ps(_mm_set1_ps(N->r), XMMz);
1837
+ XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(
1838
+ _mm_mul_ps(XMMz, XMMz)));
1839
+ XMMerror = XMMloss;
1840
+ }
1841
+ #elif defined USEAVX
1842
+ void L2_MFR::prepare_for_sg_update(
1843
+ __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
1844
+ {
1845
+ calc_z(XMMz, model.k, p, q);
1846
+ XMMz = _mm256_sub_ps(_mm256_set1_ps(N->r), XMMz);
1847
+ XMMloss = _mm_add_pd(XMMloss,
1848
+ _mm_cvtps_pd(_mm256_castps256_ps128(
1849
+ _mm256_mul_ps(XMMz, XMMz))));
1850
+ XMMerror = XMMloss;
1851
+ }
1852
+ #else
1853
+ void L2_MFR::prepare_for_sg_update()
1854
+ {
1855
+ calc_z(z, model.k, p, q);
1856
+ z = N->r-z;
1857
+ loss += z*z;
1858
+ error = loss;
1859
+ }
1860
+ #endif
1861
+ class L1_MFR : public MFSolver
1862
+ {
1863
+ public:
1864
+ L1_MFR(Scheduler &scheduler, vector<BlockBase*> &blocks, mf_float *PG, mf_float *QG,
1865
+ mf_model &model, mf_parameter param, bool &slow_only)
1866
+ : MFSolver(scheduler, blocks, PG, QG, model, param, slow_only) {}
1867
+
1868
+ protected:
1869
+ #if defined USESSE
1870
+ void prepare_for_sg_update(
1871
+ __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
1872
+ #elif defined USEAVX
1873
+ void prepare_for_sg_update(
1874
+ __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
1875
+ #else
1876
+ void prepare_for_sg_update();
1877
+ #endif
1878
+ };
1879
+
1880
+ #if defined USESSE
1881
+ void L1_MFR::prepare_for_sg_update(
1882
+ __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
1883
+ {
1884
+ calc_z(XMMz, model.k, p, q);
1885
+ XMMz = _mm_sub_ps(_mm_set1_ps(N->r), XMMz);
1886
+ XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(
1887
+ _mm_andnot_ps(_mm_set1_ps(-0.0f), XMMz)));
1888
+ XMMerror = XMMloss;
1889
+ XMMz = _mm_add_ps(_mm_and_ps(_mm_cmpgt_ps(XMMz, _mm_set1_ps(0.0f)),
1890
+ _mm_set1_ps(1.0f)),
1891
+ _mm_and_ps(_mm_cmplt_ps(XMMz, _mm_set1_ps(0.0f)),
1892
+ _mm_set1_ps(-1.0f)));
1893
+ }
1894
+ #elif defined USEAVX
1895
+ void L1_MFR::prepare_for_sg_update(
1896
+ __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
1897
+ {
1898
+ calc_z(XMMz, model.k, p, q);
1899
+ XMMz = _mm256_sub_ps(_mm256_set1_ps(N->r), XMMz);
1900
+ XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(_mm256_castps256_ps128(
1901
+ _mm256_andnot_ps(_mm256_set1_ps(-0.0f), XMMz))));
1902
+ XMMerror = XMMloss;
1903
+ XMMz = _mm256_add_ps(_mm256_and_ps(_mm256_cmp_ps(XMMz,
1904
+ _mm256_set1_ps(0.0f), _CMP_GT_OS), _mm256_set1_ps(1.0f)),
1905
+ _mm256_and_ps(_mm256_cmp_ps(XMMz,
1906
+ _mm256_set1_ps(0.0f), _CMP_LT_OS), _mm256_set1_ps(-1.0f)));
1907
+ }
1908
+ #else
1909
+ void L1_MFR::prepare_for_sg_update()
1910
+ {
1911
+ calc_z(z, model.k, p, q);
1912
+ z = N->r-z;
1913
+ loss += abs(z);
1914
+ error = loss;
1915
+ if(z > 0)
1916
+ z = 1;
1917
+ else if(z < 0)
1918
+ z = -1;
1919
+ }
1920
+ #endif
1921
+
1922
+ class KL_MFR : public MFSolver
1923
+ {
1924
+ public:
1925
+ KL_MFR(Scheduler &scheduler, vector<BlockBase*> &blocks, mf_float *PG, mf_float *QG,
1926
+ mf_model &model, mf_parameter param, bool &slow_only)
1927
+ : MFSolver(scheduler, blocks, PG, QG, model, param, slow_only) {}
1928
+
1929
+ protected:
1930
+ #if defined USESSE
1931
+ void prepare_for_sg_update(
1932
+ __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
1933
+ #elif defined USEAVX
1934
+ void prepare_for_sg_update(
1935
+ __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
1936
+ #else
1937
+ void prepare_for_sg_update();
1938
+ #endif
1939
+ };
1940
+
1941
+ #if defined USESSE
1942
+ void KL_MFR::prepare_for_sg_update(
1943
+ __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
1944
+ {
1945
+ calc_z(XMMz, model.k, p, q);
1946
+ XMMz = _mm_div_ps(_mm_set1_ps(N->r), XMMz);
1947
+ _mm_store_ss(&z, XMMz);
1948
+ XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(
1949
+ _mm_set1_ps(N->r*(log(z)-1+1/z))));
1950
+ XMMerror = XMMloss;
1951
+ XMMz = _mm_sub_ps(XMMz, _mm_set1_ps(1.0f));
1952
+ }
1953
+ #elif defined USEAVX
1954
+ void KL_MFR::prepare_for_sg_update(
1955
+ __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
1956
+ {
1957
+ calc_z(XMMz, model.k, p, q);
1958
+ XMMz = _mm256_div_ps(_mm256_set1_ps(N->r), XMMz);
1959
+ _mm_store_ss(&z, _mm256_castps256_ps128(XMMz));
1960
+ XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(
1961
+ _mm_set1_ps(N->r*(log(z)-1+1/z))));
1962
+ XMMerror = XMMloss;
1963
+ XMMz = _mm256_sub_ps(XMMz, _mm256_set1_ps(1.0f));
1964
+ }
1965
+ #else
1966
+ void KL_MFR::prepare_for_sg_update()
1967
+ {
1968
+ calc_z(z, model.k, p, q);
1969
+ z = N->r/z;
1970
+ loss += N->r*(log(z)-1+1/z);
1971
+ error = loss;
1972
+ z -= 1;
1973
+ }
1974
+ #endif
1975
+
1976
+ class LR_MFC : public MFSolver
1977
+ {
1978
+ public:
1979
+ LR_MFC(Scheduler &scheduler, vector<BlockBase*> &blocks,
1980
+ mf_float *PG, mf_float *QG, mf_model &model,
1981
+ mf_parameter param, bool &slow_only)
1982
+ : MFSolver(scheduler, blocks, PG, QG, model, param, slow_only) {}
1983
+
1984
+ protected:
1985
+ #if defined USESSE
1986
+ void prepare_for_sg_update(
1987
+ __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
1988
+ #elif defined USEAVX
1989
+ void prepare_for_sg_update(
1990
+ __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
1991
+ #else
1992
+ void prepare_for_sg_update();
1993
+ #endif
1994
+ };
1995
+
1996
+ #if defined USESSE
1997
+ void LR_MFC::prepare_for_sg_update(
1998
+ __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
1999
+ {
2000
+ calc_z(XMMz, model.k, p, q);
2001
+ _mm_store_ss(&z, XMMz);
2002
+ if(N->r > 0)
2003
+ {
2004
+ z = exp(-z);
2005
+ XMMloss = _mm_add_pd(XMMloss, _mm_set1_pd(log(1+z)));
2006
+ XMMz = _mm_set1_ps(z/(1+z));
2007
+ }
2008
+ else
2009
+ {
2010
+ z = exp(z);
2011
+ XMMloss = _mm_add_pd(XMMloss, _mm_set1_pd(log(1+z)));
2012
+ XMMz = _mm_set1_ps(-z/(1+z));
2013
+ }
2014
+ XMMerror = XMMloss;
2015
+ }
2016
+ #elif defined USEAVX
2017
+ void LR_MFC::prepare_for_sg_update(
2018
+ __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
2019
+ {
2020
+ calc_z(XMMz, model.k, p, q);
2021
+ _mm_store_ss(&z, _mm256_castps256_ps128(XMMz));
2022
+ if(N->r > 0)
2023
+ {
2024
+ z = exp(-z);
2025
+ XMMloss = _mm_add_pd(XMMloss, _mm_set1_pd(log(1.0+z)));
2026
+ XMMz = _mm256_set1_ps(z/(1+z));
2027
+ }
2028
+ else
2029
+ {
2030
+ z = exp(z);
2031
+ XMMloss = _mm_add_pd(XMMloss, _mm_set1_pd(log(1.0+z)));
2032
+ XMMz = _mm256_set1_ps(-z/(1+z));
2033
+ }
2034
+ XMMerror = XMMloss;
2035
+ }
2036
+ #else
2037
+ void LR_MFC::prepare_for_sg_update()
2038
+ {
2039
+ calc_z(z, model.k, p, q);
2040
+ if(N->r > 0)
2041
+ {
2042
+ z = exp(-z);
2043
+ loss += log(1+z);
2044
+ error = loss;
2045
+ z = z/(1+z);
2046
+ }
2047
+ else
2048
+ {
2049
+ z = exp(z);
2050
+ loss += log(1+z);
2051
+ error = loss;
2052
+ z = -z/(1+z);
2053
+ }
2054
+ }
2055
+ #endif
2056
+
2057
+ class L2_MFC : public MFSolver
2058
+ {
2059
+ public:
2060
+ L2_MFC(Scheduler &scheduler, vector<BlockBase*> &blocks,
2061
+ mf_float *PG, mf_float *QG, mf_model &model,
2062
+ mf_parameter param, bool &slow_only)
2063
+ : MFSolver(scheduler, blocks, PG, QG, model, param, slow_only) {}
2064
+
2065
+ protected:
2066
+ #if defined USESSE
2067
+ void prepare_for_sg_update(
2068
+ __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
2069
+ #elif defined USEAVX
2070
+ void prepare_for_sg_update(
2071
+ __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
2072
+ #else
2073
+ void prepare_for_sg_update();
2074
+ #endif
2075
+ };
2076
+
2077
+ #if defined USESSE
2078
+ void L2_MFC::prepare_for_sg_update(
2079
+ __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
2080
+ {
2081
+ calc_z(XMMz, model.k, p, q);
2082
+ if(N->r > 0)
2083
+ {
2084
+ __m128 mask = _mm_cmpgt_ps(XMMz, _mm_set1_ps(0.0f));
2085
+ XMMerror = _mm_add_pd(XMMerror, _mm_cvtps_pd(
2086
+ _mm_and_ps(_mm_set1_ps(1.0f), mask)));
2087
+ XMMz = _mm_max_ps(_mm_set1_ps(0.0f), _mm_sub_ps(
2088
+ _mm_set1_ps(1.0f), XMMz));
2089
+ }
2090
+ else
2091
+ {
2092
+ __m128 mask = _mm_cmplt_ps(XMMz, _mm_set1_ps(0.0f));
2093
+ XMMerror = _mm_add_pd(XMMerror, _mm_cvtps_pd(
2094
+ _mm_and_ps(_mm_set1_ps(1.0f), mask)));
2095
+ XMMz = _mm_min_ps(_mm_set1_ps(0.0f), _mm_sub_ps(
2096
+ _mm_set1_ps(-1.0f), XMMz));
2097
+ }
2098
+ XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(
2099
+ _mm_mul_ps(XMMz, XMMz)));
2100
+ }
2101
+ #elif defined USEAVX
2102
+ void L2_MFC::prepare_for_sg_update(
2103
+ __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
2104
+ {
2105
+ calc_z(XMMz, model.k, p, q);
2106
+ if(N->r > 0)
2107
+ {
2108
+ __m128 mask = _mm_cmpgt_ps(_mm256_castps256_ps128(XMMz),
2109
+ _mm_set1_ps(0.0f));
2110
+ XMMerror = _mm_add_pd(XMMerror, _mm_cvtps_pd(
2111
+ _mm_and_ps(_mm_set1_ps(1.0f), mask)));
2112
+ XMMz = _mm256_max_ps(_mm256_set1_ps(0.0f),
2113
+ _mm256_sub_ps(_mm256_set1_ps(1.0f), XMMz));
2114
+ }
2115
+ else
2116
+ {
2117
+ __m128 mask = _mm_cmplt_ps(_mm256_castps256_ps128(XMMz),
2118
+ _mm_set1_ps(0.0f));
2119
+ XMMerror = _mm_add_pd(XMMerror, _mm_cvtps_pd(
2120
+ _mm_and_ps(_mm_set1_ps(1.0f), mask)));
2121
+ XMMz = _mm256_min_ps(_mm256_set1_ps(0.0f),
2122
+ _mm256_sub_ps(_mm256_set1_ps(-1.0f), XMMz));
2123
+ }
2124
+ XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(
2125
+ _mm_mul_ps(_mm256_castps256_ps128(XMMz),
2126
+ _mm256_castps256_ps128(XMMz))));
2127
+ }
2128
+ #else
2129
+ void L2_MFC::prepare_for_sg_update()
2130
+ {
2131
+ calc_z(z, model.k, p, q);
2132
+ if(N->r > 0)
2133
+ {
2134
+ error += z > 0? 1: 0;
2135
+ z = max(0.0f, 1-z);
2136
+ }
2137
+ else
2138
+ {
2139
+ error += z < 0? 1: 0;
2140
+ z = min(0.0f, -1-z);
2141
+ }
2142
+ loss += z*z;
2143
+ }
2144
+ #endif
2145
+
2146
+ class L1_MFC : public MFSolver
2147
+ {
2148
+ public:
2149
+ L1_MFC(Scheduler &scheduler, vector<BlockBase*> &blocks, mf_float *PG, mf_float *QG,
2150
+ mf_model &model, mf_parameter param, bool &slow_only)
2151
+ : MFSolver(scheduler, blocks, PG, QG, model, param, slow_only) {}
2152
+
2153
+ protected:
2154
+ #if defined USESSE
2155
+ void prepare_for_sg_update(
2156
+ __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
2157
+ #elif defined USEAVX
2158
+ void prepare_for_sg_update(
2159
+ __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
2160
+ #else
2161
+ void prepare_for_sg_update();
2162
+ #endif
2163
+ };
2164
+
2165
+ #if defined USESSE
2166
+ void L1_MFC::prepare_for_sg_update(
2167
+ __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
2168
+ {
2169
+ calc_z(XMMz, model.k, p, q);
2170
+ if(N->r > 0)
2171
+ {
2172
+ XMMerror = _mm_add_pd(XMMerror, _mm_cvtps_pd(
2173
+ _mm_and_ps(_mm_cmpge_ps(XMMz, _mm_set1_ps(0.0f)),
2174
+ _mm_set1_ps(1.0f))));
2175
+ XMMz = _mm_sub_ps(_mm_set1_ps(1.0f), XMMz);
2176
+ XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(
2177
+ _mm_max_ps(_mm_set1_ps(0.0f), XMMz)));
2178
+ XMMz = _mm_and_ps(_mm_cmpge_ps(XMMz, _mm_set1_ps(0.0f)),
2179
+ _mm_set1_ps(1.0f));
2180
+ }
2181
+ else
2182
+ {
2183
+ XMMerror = _mm_add_pd(XMMerror, _mm_cvtps_pd(
2184
+ _mm_and_ps(_mm_cmplt_ps(XMMz, _mm_set1_ps(0.0f)),
2185
+ _mm_set1_ps(1.0f))));
2186
+ XMMz = _mm_add_ps(_mm_set1_ps(1.0f), XMMz);
2187
+ XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(
2188
+ _mm_max_ps(_mm_set1_ps(0.0f), XMMz)));
2189
+ XMMz = _mm_and_ps(_mm_cmpge_ps(XMMz, _mm_set1_ps(0.0f)),
2190
+ _mm_set1_ps(-1.0f));
2191
+ }
2192
+ }
2193
+ #elif defined USEAVX
2194
+ void L1_MFC::prepare_for_sg_update(
2195
+ __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
2196
+ {
2197
+ calc_z(XMMz, model.k, p, q);
2198
+ if(N->r > 0)
2199
+ {
2200
+ XMMerror = _mm_add_pd(XMMerror, _mm_cvtps_pd(_mm_and_ps(
2201
+ _mm_cmpge_ps(_mm256_castps256_ps128(XMMz),
2202
+ _mm_set1_ps(0.0f)), _mm_set1_ps(1.0f))));
2203
+ XMMz = _mm256_sub_ps(_mm256_set1_ps(1.0f), XMMz);
2204
+ XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(_mm_max_ps(
2205
+ _mm_set1_ps(0.0f), _mm256_castps256_ps128(XMMz))));
2206
+ XMMz = _mm256_and_ps(_mm256_cmp_ps(XMMz, _mm256_set1_ps(0.0f),
2207
+ _CMP_GE_OS), _mm256_set1_ps(1.0f));
2208
+ }
2209
+ else
2210
+ {
2211
+ XMMerror = _mm_add_pd(XMMerror, _mm_cvtps_pd(_mm_and_ps(
2212
+ _mm_cmplt_ps(_mm256_castps256_ps128(XMMz),
2213
+ _mm_set1_ps(0.0f)), _mm_set1_ps(1.0f))));
2214
+ XMMz = _mm256_add_ps(_mm256_set1_ps(1.0f), XMMz);
2215
+ XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(_mm_max_ps(
2216
+ _mm_set1_ps(0.0f), _mm256_castps256_ps128(XMMz))));
2217
+ XMMz = _mm256_and_ps(_mm256_cmp_ps(XMMz, _mm256_set1_ps(0.0f),
2218
+ _CMP_GE_OS), _mm256_set1_ps(-1.0f));
2219
+ }
2220
+ }
2221
+ #else
2222
+ void L1_MFC::prepare_for_sg_update()
2223
+ {
2224
+ calc_z(z, model.k, p, q);
2225
+ if(N->r > 0)
2226
+ {
2227
+ loss += max(0.0f, 1-z);
2228
+ error += z > 0? 1.0f: 0.0f;
2229
+ z = z > 1? 0.0f: 1.0f;
2230
+ }
2231
+ else
2232
+ {
2233
+ loss += max(0.0f, 1+z);
2234
+ error += z < 0? 1.0f: 0.0f;
2235
+ z = z < -1? 0.0f: -1.0f;
2236
+ }
2237
+ }
2238
+ #endif
2239
+ //--------------------------------------
2240
+ //------------One-class MF--------------
2241
+ //--------------------------------------
2242
+
2243
+ class BPRSolver : public SolverBase
2244
+ {
2245
+ public:
2246
+ BPRSolver(Scheduler &scheduler, vector<BlockBase*> &blocks,
2247
+ mf_float *PG, mf_float *QG, mf_model &model, mf_parameter param,
2248
+ bool &slow_only, bool is_column_oriented)
2249
+ : SolverBase(scheduler, blocks, PG, QG, model, param, slow_only),
2250
+ is_column_oriented(is_column_oriented) {}
2251
+
2252
+ protected:
2253
+ #if defined USESSE
2254
+ static void calc_z(__m128 &XMMz, mf_int k,
2255
+ mf_float *p, mf_float *q, mf_float *w);
2256
+ void arrange_block(__m128d &XMMloss, __m128d &XMMerror);
2257
+ void prepare_for_sg_update(
2258
+ __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
2259
+ void sg_update(mf_int d_begin, mf_int d_end, __m128 XMMz,
2260
+ __m128 XMMlambda_p1, __m128 XMMlambda_q1,
2261
+ __m128 XMMlambda_p2, __m128 XMMlambda_q2,
2262
+ __m128 XMMeta, __m128 XMMrk);
2263
+ void finalize(__m128d XMMloss, __m128d XMMerror);
2264
+ #elif defined USEAVX
2265
+ static void calc_z(__m256 &XMMz, mf_int k,
2266
+ mf_float *p, mf_float *q, mf_float *w);
2267
+ void arrange_block(__m128d &XMMloss, __m128d &XMMerror);
2268
+ void prepare_for_sg_update(
2269
+ __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
2270
+ void sg_update(mf_int d_begin, mf_int d_end, __m256 XMMz,
2271
+ __m256 XMMlambda_p1, __m256 XMMlambda_q1,
2272
+ __m256 XMMlambda_p2, __m256 XMMlambda_q2,
2273
+ __m256 XMMeta, __m256 XMMrk);
2274
+ void finalize(__m128d XMMloss, __m128d XMMerror);
2275
+ #else
2276
+ static void calc_z(mf_float &z, mf_int k,
2277
+ mf_float *p, mf_float *q, mf_float *w);
2278
+ void arrange_block();
2279
+ void prepare_for_sg_update();
2280
+ void sg_update(mf_int d_begin, mf_int d_end, mf_float rk);
2281
+ void finalize();
2282
+ #endif
2283
+ void update() { ++pG; ++qG; ++wG; };
2284
+ virtual void prepare_negative() = 0;
2285
+
2286
+ bool is_column_oriented;
2287
+ mf_int bpr_bid;
2288
+ mf_float *w;
2289
+ mf_float *wG;
2290
+ };
2291
+
2292
+
2293
+ #if defined USESSE
2294
+ inline void BPRSolver::calc_z(
2295
+ __m128 &XMMz, mf_int k, mf_float *p, mf_float *q, mf_float *w)
2296
+ {
2297
+ XMMz = _mm_setzero_ps();
2298
+ for(mf_int d = 0; d < k; d += 4)
2299
+ XMMz = _mm_add_ps(XMMz, _mm_mul_ps(_mm_load_ps(p+d),
2300
+ _mm_sub_ps(_mm_load_ps(q+d), _mm_load_ps(w+d))));
2301
+ // Bit-wise representation of 177 is {1,0}+{1,1}+{0,0}+{0,1} from
2302
+ // high-bit to low-bit, where "+" means concatenating two arrays.
2303
+ __m128 XMMtmp = _mm_add_ps(XMMz, _mm_shuffle_ps(XMMz, XMMz, 177));
2304
+ // Bit-wise representation of 78 is {0,1}+{0,0}+{1,1}+{1,0} from
2305
+ // high-bit to low-bit, where "+" means concatenating two arrays.
2306
+ XMMz = _mm_add_ps(XMMtmp, _mm_shuffle_ps(XMMtmp, XMMtmp, 78));
2307
+ }
2308
+
2309
+ void BPRSolver::arrange_block(__m128d &XMMloss, __m128d &XMMerror)
2310
+ {
2311
+ XMMloss = _mm_setzero_pd();
2312
+ XMMerror = _mm_setzero_pd();
2313
+ bid = scheduler.get_job();
2314
+ block = blocks[bid];
2315
+ block->reload();
2316
+ bpr_bid = scheduler.get_bpr_job(bid, is_column_oriented);
2317
+ }
2318
+
2319
+ void BPRSolver::finalize(__m128d XMMloss, __m128d XMMerror)
2320
+ {
2321
+ _mm_store_sd(&loss, XMMloss);
2322
+ _mm_store_sd(&error, XMMerror);
2323
+ scheduler.put_job(bid, loss, error);
2324
+ scheduler.put_bpr_job(bid, bpr_bid);
2325
+ }
2326
+
2327
+ void BPRSolver::sg_update(mf_int d_begin, mf_int d_end, __m128 XMMz,
2328
+ __m128 XMMlambda_p1, __m128 XMMlambda_q1,
2329
+ __m128 XMMlambda_p2, __m128 XMMlambda_q2,
2330
+ __m128 XMMeta, __m128 XMMrk)
2331
+ {
2332
+ __m128 XMMpG = _mm_load1_ps(pG);
2333
+ __m128 XMMqG = _mm_load1_ps(qG);
2334
+ __m128 XMMwG = _mm_load1_ps(wG);
2335
+ __m128 XMMeta_p = _mm_mul_ps(XMMeta, _mm_rsqrt_ps(XMMpG));
2336
+ __m128 XMMeta_q = _mm_mul_ps(XMMeta, _mm_rsqrt_ps(XMMqG));
2337
+ __m128 XMMeta_w = _mm_mul_ps(XMMeta, _mm_rsqrt_ps(XMMwG));
2338
+
2339
+ __m128 XMMpG1 = _mm_setzero_ps();
2340
+ __m128 XMMqG1 = _mm_setzero_ps();
2341
+ __m128 XMMwG1 = _mm_setzero_ps();
2342
+
2343
+ for(mf_int d = d_begin; d < d_end; d += 4)
2344
+ {
2345
+ __m128 XMMp = _mm_load_ps(p+d);
2346
+ __m128 XMMq = _mm_load_ps(q+d);
2347
+ __m128 XMMw = _mm_load_ps(w+d);
2348
+
2349
+ __m128 XMMpg = _mm_add_ps(_mm_mul_ps(XMMlambda_p2, XMMp),
2350
+ _mm_mul_ps(XMMz, _mm_sub_ps(XMMw, XMMq)));
2351
+ __m128 XMMqg = _mm_sub_ps(_mm_mul_ps(XMMlambda_q2, XMMq),
2352
+ _mm_mul_ps(XMMz, XMMp));
2353
+ __m128 XMMwg = _mm_add_ps(_mm_mul_ps(XMMlambda_q2, XMMw),
2354
+ _mm_mul_ps(XMMz, XMMp));
2355
+
2356
+ XMMpG1 = _mm_add_ps(XMMpG1, _mm_mul_ps(XMMpg, XMMpg));
2357
+ XMMqG1 = _mm_add_ps(XMMqG1, _mm_mul_ps(XMMqg, XMMqg));
2358
+ XMMwG1 = _mm_add_ps(XMMwG1, _mm_mul_ps(XMMwg, XMMwg));
2359
+
2360
+ XMMp = _mm_sub_ps(XMMp, _mm_mul_ps(XMMeta_p, XMMpg));
2361
+ XMMq = _mm_sub_ps(XMMq, _mm_mul_ps(XMMeta_q, XMMqg));
2362
+ XMMw = _mm_sub_ps(XMMw, _mm_mul_ps(XMMeta_w, XMMwg));
2363
+
2364
+ _mm_store_ps(p+d, XMMp);
2365
+ _mm_store_ps(q+d, XMMq);
2366
+ _mm_store_ps(w+d, XMMw);
2367
+ }
2368
+
2369
+ mf_float tmp = 0;
2370
+ _mm_store_ss(&tmp, XMMlambda_p1);
2371
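+ // Vectorized soft-thresholding for the L1 penalty: XMMflip keeps the sign bit
+ // of each non-positive lane, so XOR-ing with it yields |p|; the threshold
+ // eta_p*lambda_p1 is then subtracted, clipped at zero, and the original sign
+ // is restored by the final XOR.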
+ if(tmp > 0)
2372
+ {
2373
+ for(mf_int d = d_begin; d < d_end; d += 4)
2374
+ {
2375
+ __m128 XMMp = _mm_load_ps(p+d);
2376
+ __m128 XMMflip = _mm_and_ps(_mm_cmple_ps(XMMp, _mm_set1_ps(0.0f)),
2377
+ _mm_set1_ps(-0.0f));
2378
+ XMMp = _mm_xor_ps(XMMflip,
2379
+ _mm_max_ps(_mm_sub_ps(_mm_xor_ps(XMMp, XMMflip),
2380
+ _mm_mul_ps(XMMeta_p, XMMlambda_p1)), _mm_set1_ps(0.0f)));
2381
+ _mm_store_ps(p+d, XMMp);
2382
+ }
2383
+ }
2384
+
2385
+ _mm_store_ss(&tmp, XMMlambda_q1);
2386
+ if(tmp > 0)
2387
+ {
2388
+ for(mf_int d = d_begin; d < d_end; d += 4)
2389
+ {
2390
+ __m128 XMMq = _mm_load_ps(q+d);
2391
+ __m128 XMMw = _mm_load_ps(w+d);
2392
+ __m128 XMMflip = _mm_and_ps(_mm_cmple_ps(XMMq, _mm_set1_ps(0.0f)),
2393
+ _mm_set1_ps(-0.0f));
2394
+ XMMq = _mm_xor_ps(XMMflip,
2395
+ _mm_max_ps(_mm_sub_ps(_mm_xor_ps(XMMq, XMMflip),
2396
+ _mm_mul_ps(XMMeta_q, XMMlambda_q1)), _mm_set1_ps(0.0f)));
2397
+ _mm_store_ps(q+d, XMMq);
2398
+
2399
+
2400
+ XMMflip = _mm_and_ps(_mm_cmple_ps(XMMw, _mm_set1_ps(0.0f)),
2401
+ _mm_set1_ps(-0.0f));
2402
+ XMMw = _mm_xor_ps(XMMflip,
2403
+ _mm_max_ps(_mm_sub_ps(_mm_xor_ps(XMMw, XMMflip),
2404
+ _mm_mul_ps(XMMeta_w, XMMlambda_q1)), _mm_set1_ps(0.0f)));
2405
+ _mm_store_ps(w+d, XMMw);
2406
+ }
2407
+ }
2408
+
2409
+ if(param.do_nmf)
2410
+ {
2411
+ for(mf_int d = d_begin; d < d_end; d += 4)
2412
+ {
2413
+ __m128 XMMp = _mm_load_ps(p+d);
2414
+ __m128 XMMq = _mm_load_ps(q+d);
2415
+ __m128 XMMw = _mm_load_ps(w+d);
2416
+ XMMp = _mm_max_ps(XMMp, _mm_set1_ps(0.0f));
2417
+ XMMq = _mm_max_ps(XMMq, _mm_set1_ps(0.0f));
2418
+ XMMw = _mm_max_ps(XMMw, _mm_set1_ps(0.0f));
2419
+ _mm_store_ps(p+d, XMMp);
2420
+ _mm_store_ps(q+d, XMMq);
2421
+ _mm_store_ps(w+d, XMMw);
2422
+ }
2423
+ }
2424
+
2425
+ // Update learning rate of latent vector p. Squared derivatives along all
2426
+ // latent dimensions were computed above. Here their average is
2427
+ // added into the associated squared-gradient sum.
2428
+ __m128 XMMtmp = _mm_add_ps(XMMpG1, _mm_movehl_ps(XMMpG1, XMMpG1));
2429
+ XMMpG1 = _mm_add_ps(XMMtmp, _mm_shuffle_ps(XMMtmp, XMMtmp, 1));
2430
+ XMMpG = _mm_add_ps(XMMpG, _mm_mul_ps(XMMpG1, XMMrk));
2431
+ _mm_store_ss(pG, XMMpG);
2432
+
2433
+ // Similar code is used to update learning rate of latent vector q.
2434
+ XMMtmp = _mm_add_ps(XMMqG1, _mm_movehl_ps(XMMqG1, XMMqG1));
2435
+ XMMqG1 = _mm_add_ps(XMMtmp, _mm_shuffle_ps(XMMtmp, XMMtmp, 1));
2436
+ XMMqG = _mm_add_ps(XMMqG, _mm_mul_ps(XMMqG1, XMMrk));
2437
+ _mm_store_ss(qG, XMMqG);
2438
+
2439
+ // Similar code is used to update learning rate of latent vector w.
2440
+ XMMtmp = _mm_add_ps(XMMwG1, _mm_movehl_ps(XMMwG1, XMMwG1));
2441
+ XMMwG1 = _mm_add_ps(XMMtmp, _mm_shuffle_ps(XMMtmp, XMMtmp, 1));
2442
+ XMMwG = _mm_add_ps(XMMwG, _mm_mul_ps(XMMwG1, XMMrk));
2443
+ _mm_store_ss(wG, XMMwG);
2444
+ }
2445
+
2446
+ void BPRSolver::prepare_for_sg_update(
2447
+ __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
2448
+ {
2449
+ prepare_negative();
2450
+ calc_z(XMMz, model.k, p, q, w);
2451
+ _mm_store_ss(&z, XMMz);
2452
+ z = exp(-z);
2453
+ XMMloss = _mm_add_pd(XMMloss, _mm_set1_pd(log(1+z)));
2454
+ XMMerror = XMMloss;
2455
+ XMMz = _mm_set1_ps(z/(1+z));
2456
+ }
2457
+ #elif defined USEAVX
2458
+ inline void BPRSolver::calc_z(
2459
+ __m256 &XMMz, mf_int k, mf_float *p, mf_float *q, mf_float *w)
2460
+ {
2461
+ XMMz = _mm256_setzero_ps();
2462
+ for(mf_int d = 0; d < k; d += 8)
2463
+ XMMz = _mm256_add_ps(XMMz, _mm256_mul_ps(
2464
+ _mm256_load_ps(p+d), _mm256_sub_ps(
2465
+ _mm256_load_ps(q+d), _mm256_load_ps(w+d))));
2466
+ XMMz = _mm256_add_ps(XMMz, _mm256_permute2f128_ps(XMMz, XMMz, 0x1));
2467
+ XMMz = _mm256_hadd_ps(XMMz, XMMz);
2468
+ XMMz = _mm256_hadd_ps(XMMz, XMMz);
2469
+ }
2470
+
2471
+ void BPRSolver::arrange_block(__m128d &XMMloss, __m128d &XMMerror)
2472
+ {
2473
+ XMMloss = _mm_setzero_pd();
2474
+ XMMerror = _mm_setzero_pd();
2475
+ bid = scheduler.get_job();
2476
+ block = blocks[bid];
2477
+ block->reload();
2478
+ bpr_bid = scheduler.get_bpr_job(bid, is_column_oriented);
2479
+ }
2480
+
2481
+ void BPRSolver::finalize(__m128d XMMloss, __m128d XMMerror)
2482
+ {
2483
+ _mm_store_sd(&loss, XMMloss);
2484
+ _mm_store_sd(&error, XMMerror);
2485
+ scheduler.put_job(bid, loss, error);
2486
+ scheduler.put_bpr_job(bid, bpr_bid);
2487
+ }
2488
+
2489
+ void BPRSolver::sg_update(mf_int d_begin, mf_int d_end, __m256 XMMz,
2490
+ __m256 XMMlambda_p1, __m256 XMMlambda_q1,
2491
+ __m256 XMMlambda_p2, __m256 XMMlambda_q2,
2492
+ __m256 XMMeta, __m256 XMMrk)
2493
+ {
2494
+ __m256 XMMpG = _mm256_broadcast_ss(pG);
2495
+ __m256 XMMqG = _mm256_broadcast_ss(qG);
2496
+ __m256 XMMwG = _mm256_broadcast_ss(wG);
2497
+ __m256 XMMeta_p =
2498
+ _mm256_mul_ps(XMMeta, _mm256_rsqrt_ps(XMMpG));
2499
+ __m256 XMMeta_q =
2500
+ _mm256_mul_ps(XMMeta, _mm256_rsqrt_ps(XMMqG));
2501
+ __m256 XMMeta_w =
2502
+ _mm256_mul_ps(XMMeta, _mm256_rsqrt_ps(XMMwG));
2503
+
2504
+ __m256 XMMpG1 = _mm256_setzero_ps();
2505
+ __m256 XMMqG1 = _mm256_setzero_ps();
2506
+ __m256 XMMwG1 = _mm256_setzero_ps();
2507
+
2508
+ for(mf_int d = d_begin; d < d_end; d += 8)
2509
+ {
2510
+ __m256 XMMp = _mm256_load_ps(p+d);
2511
+ __m256 XMMq = _mm256_load_ps(q+d);
2512
+ __m256 XMMw = _mm256_load_ps(w+d);
2513
+ __m256 XMMpg = _mm256_add_ps(_mm256_mul_ps(XMMlambda_p2, XMMp),
2514
+ _mm256_mul_ps(XMMz, _mm256_sub_ps(XMMw, XMMq)));
2515
+ __m256 XMMqg = _mm256_sub_ps(_mm256_mul_ps(XMMlambda_q2, XMMq),
2516
+ _mm256_mul_ps(XMMz, XMMp));
2517
+ __m256 XMMwg = _mm256_add_ps(_mm256_mul_ps(XMMlambda_q2, XMMw),
2518
+ _mm256_mul_ps(XMMz, XMMp));
2519
+
2520
+ XMMpG1 = _mm256_add_ps(XMMpG1, _mm256_mul_ps(XMMpg, XMMpg));
2521
+ XMMqG1 = _mm256_add_ps(XMMqG1, _mm256_mul_ps(XMMqg, XMMqg));
2522
+ XMMwG1 = _mm256_add_ps(XMMwG1, _mm256_mul_ps(XMMwg, XMMwg));
2523
+
2524
+ XMMp = _mm256_sub_ps(XMMp, _mm256_mul_ps(XMMeta_p, XMMpg));
2525
+ XMMq = _mm256_sub_ps(XMMq, _mm256_mul_ps(XMMeta_q, XMMqg));
2526
+ XMMw = _mm256_sub_ps(XMMw, _mm256_mul_ps(XMMeta_w, XMMwg));
2527
+
2528
+ _mm256_store_ps(p+d, XMMp);
2529
+ _mm256_store_ps(q+d, XMMq);
2530
+ _mm256_store_ps(w+d, XMMw);
2531
+ }
2532
+
2533
+ mf_float tmp = 0;
2534
+ _mm_store_ss(&tmp, _mm256_castps256_ps128(XMMlambda_p1));
2535
+ if(tmp > 0)
2536
+ {
2537
+ for(mf_int d = d_begin; d < d_end; d += 8)
2538
+ {
2539
+ __m256 XMMp = _mm256_load_ps(p+d);
2540
+ __m256 XMMflip =
2541
+ _mm256_and_ps(
2542
+ _mm256_cmp_ps(XMMp, _mm256_set1_ps(0.0f), _CMP_LE_OS),
2543
+ _mm256_set1_ps(-0.0f));
2544
+ XMMp = _mm256_xor_ps(XMMflip,
2545
+ _mm256_max_ps(_mm256_sub_ps(_mm256_xor_ps(XMMp, XMMflip),
2546
+ _mm256_mul_ps(XMMeta_p, XMMlambda_p1)),
2547
+ _mm256_set1_ps(0.0f)));
2548
+ _mm256_store_ps(p+d, XMMp);
2549
+ }
2550
+ }
2551
+
2552
+ _mm_store_ss(&tmp, _mm256_castps256_ps128(XMMlambda_q1));
2553
+ if(tmp > 0)
2554
+ {
2555
+ for(mf_int d = d_begin; d < d_end; d += 8)
2556
+ {
2557
+ __m256 XMMq = _mm256_load_ps(q+d);
2558
+ __m256 XMMw = _mm256_load_ps(w+d);
2559
+ __m256 XMMflip;
2560
+
2561
+ XMMflip = _mm256_and_ps(
2562
+ _mm256_cmp_ps(XMMq, _mm256_set1_ps(0.0f), _CMP_LE_OS),
2563
+ _mm256_set1_ps(-0.0f));
2564
+ XMMq = _mm256_xor_ps(XMMflip,
2565
+ _mm256_max_ps(_mm256_sub_ps(_mm256_xor_ps(XMMq, XMMflip),
2566
+ _mm256_mul_ps(XMMeta_q, XMMlambda_q1)),
2567
+ _mm256_set1_ps(0.0f)));
2568
+ _mm256_store_ps(q+d, XMMq);
2569
+
2570
+
2571
+ XMMflip = _mm256_and_ps(
2572
+ _mm256_cmp_ps(XMMw, _mm256_set1_ps(0.0f), _CMP_LE_OS),
2573
+ _mm256_set1_ps(-0.0f));
2574
+ XMMw = _mm256_xor_ps(XMMflip,
2575
+ _mm256_max_ps(_mm256_sub_ps(_mm256_xor_ps(XMMw, XMMflip),
2576
+ _mm256_mul_ps(XMMeta_w, XMMlambda_q1)),
2577
+ _mm256_set1_ps(0.0f)));
2578
+ _mm256_store_ps(w+d, XMMw);
2579
+ }
2580
+ }
2581
+
2582
+ if(param.do_nmf)
2583
+ {
2584
+ for(mf_int d = d_begin; d < d_end; d += 8)
2585
+ {
2586
+ __m256 XMMp = _mm256_load_ps(p+d);
2587
+ __m256 XMMq = _mm256_load_ps(q+d);
2588
+ __m256 XMMw = _mm256_load_ps(w+d);
2589
+ XMMp = _mm256_max_ps(XMMp, _mm256_set1_ps(0.0f));
2590
+ XMMq = _mm256_max_ps(XMMq, _mm256_set1_ps(0.0f));
2591
+ XMMw = _mm256_max_ps(XMMw, _mm256_set1_ps(0.0f));
2592
+ _mm256_store_ps(p+d, XMMp);
2593
+ _mm256_store_ps(q+d, XMMq);
2594
+ _mm256_store_ps(w+d, XMMw);
2595
+ }
2596
+ }
2597
+
2598
+ XMMpG1 = _mm256_add_ps(XMMpG1,
2599
+ _mm256_permute2f128_ps(XMMpG1, XMMpG1, 0x1));
2600
+ XMMpG1 = _mm256_hadd_ps(XMMpG1, XMMpG1);
2601
+ XMMpG1 = _mm256_hadd_ps(XMMpG1, XMMpG1);
2602
+
2603
+ XMMqG1 = _mm256_add_ps(XMMqG1,
2604
+ _mm256_permute2f128_ps(XMMqG1, XMMqG1, 0x1));
2605
+ XMMqG1 = _mm256_hadd_ps(XMMqG1, XMMqG1);
2606
+ XMMqG1 = _mm256_hadd_ps(XMMqG1, XMMqG1);
2607
+
2608
+ XMMwG1 = _mm256_add_ps(XMMwG1,
2609
+ _mm256_permute2f128_ps(XMMwG1, XMMwG1, 0x1));
2610
+ XMMwG1 = _mm256_hadd_ps(XMMwG1, XMMwG1);
2611
+ XMMwG1 = _mm256_hadd_ps(XMMwG1, XMMwG1);
2612
+
2613
+ XMMpG = _mm256_add_ps(XMMpG, _mm256_mul_ps(XMMpG1, XMMrk));
2614
+ XMMqG = _mm256_add_ps(XMMqG, _mm256_mul_ps(XMMqG1, XMMrk));
2615
+ XMMwG = _mm256_add_ps(XMMwG, _mm256_mul_ps(XMMwG1, XMMrk));
2616
+
2617
+ _mm_store_ss(pG, _mm256_castps256_ps128(XMMpG));
2618
+ _mm_store_ss(qG, _mm256_castps256_ps128(XMMqG));
2619
+ _mm_store_ss(wG, _mm256_castps256_ps128(XMMwG));
2620
+ }
2621
+
2622
+ void BPRSolver::prepare_for_sg_update(
2623
+ __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
2624
+ {
2625
+ prepare_negative();
2626
+ calc_z(XMMz, model.k, p, q, w);
2627
+ _mm_store_ss(&z, _mm256_castps256_ps128(XMMz));
2628
+ z = exp(-z);
2629
+ XMMloss = _mm_add_pd(XMMloss, _mm_set1_pd(log(1+z)));
2630
+ XMMerror = XMMloss;
2631
+ XMMz = _mm256_set1_ps(z/(1+z));
2632
+ }
2633
+ #else
2634
+ inline void BPRSolver::calc_z(
2635
+ mf_float &z, mf_int k, mf_float *p, mf_float *q, mf_float *w)
2636
+ {
2637
+ z = 0;
2638
+ for(mf_int d = 0; d < k; ++d)
2639
+ z += p[d]*(q[d]-w[d]);
2640
+ }
2641
+
2642
+ void BPRSolver::arrange_block()
2643
+ {
2644
+ loss = 0.0;
2645
+ error = 0.0;
2646
+ bid = scheduler.get_job();
2647
+ block = blocks[bid];
2648
+ block->reload();
2649
+ bpr_bid = scheduler.get_bpr_job(bid, is_column_oriented);
2650
+ }
2651
+
2652
+ void BPRSolver::finalize()
2653
+ {
2654
+ scheduler.put_job(bid, loss, error);
2655
+ scheduler.put_bpr_job(bid, bpr_bid);
2656
+ }
2657
+
2658
+ void BPRSolver::sg_update(mf_int d_begin, mf_int d_end, mf_float rk)
2659
+ {
2660
+ mf_float eta_p = param.eta*qrsqrt(*pG);
2661
+ mf_float eta_q = param.eta*qrsqrt(*qG);
2662
+ mf_float eta_w = param.eta*qrsqrt(*wG);
2663
+
2664
+ mf_float pG1 = 0;
2665
+ mf_float qG1 = 0;
2666
+ mf_float wG1 = 0;
2667
+
2668
+ for(mf_int d = d_begin; d < d_end; ++d)
2669
+ {
2670
+ mf_float gp = z*(w[d]-q[d]) + lambda_p2*p[d];
2671
+ mf_float gq = -z*p[d] + lambda_q2*q[d];
2672
+ mf_float gw = z*p[d] + lambda_q2*w[d];
2673
+
2674
+ pG1 += gp*gp;
2675
+ qG1 += gq*gq;
2676
+ wG1 += gw*gw;
2677
+
2678
+ p[d] -= eta_p*gp;
2679
+ q[d] -= eta_q*gq;
2680
+ w[d] -= eta_w*gw;
2681
+ }
2682
+
2683
+ if(lambda_p1 > 0)
2684
+ {
2685
+ for(mf_int d = d_begin; d < d_end; ++d)
2686
+ {
2687
+ mf_float p1 = max(abs(p[d])-lambda_p1*eta_p, 0.0f);
2688
+ p[d] = p[d] >= 0? p1: -p1;
2689
+ }
2690
+ }
2691
+
2692
+ if(lambda_q1 > 0)
2693
+ {
2694
+ for(mf_int d = d_begin; d < d_end; ++d)
2695
+ {
2696
+ mf_float q1 = max(abs(w[d])-lambda_q1*eta_w, 0.0f);
2697
+ w[d] = w[d] >= 0? q1: -q1;
2698
+ q1 = max(abs(q[d])-lambda_q1*eta_q, 0.0f);
2699
+ q[d] = q[d] >= 0? q1: -q1;
2700
+ }
2701
+ }
2702
+
2703
+ if(param.do_nmf)
2704
+ {
2705
+ for(mf_int d = d_begin; d < d_end; ++d)
2706
+ {
2707
+ p[d] = max(p[d], (mf_float)0.0);
2708
+ q[d] = max(q[d], (mf_float)0.0);
2709
+ w[d] = max(w[d], (mf_float)0.0);
2710
+ }
2711
+ }
2712
+
2713
+ *pG += pG1*rk;
2714
+ *qG += qG1*rk;
2715
+ *wG += wG1*rk;
2716
+ }
2717
+
2718
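+ // Computes the pairwise score x = p^T(q-w) for the sampled positive/negative
+ // pair, accumulates the logistic loss log(1+exp(-x)), and leaves
+ // z = exp(-x)/(1+exp(-x)), the common factor of the gradients in sg_update().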
+ void BPRSolver::prepare_for_sg_update()
2719
+ {
2720
+ prepare_negative();
2721
+ calc_z(z, model.k, p, q, w);
2722
+ z = exp(-z);
2723
+ loss += log(1+z);
2724
+ error = loss;
2725
+ z = z/(1+z);
2726
+ }
2727
+ #endif
2728
+
2729
+ class COL_BPR_MFOC : public BPRSolver
2730
+ {
2731
+ public:
2732
+ COL_BPR_MFOC(Scheduler &scheduler, vector<BlockBase*> &blocks,
2733
+ mf_float *PG, mf_float *QG, mf_model &model,
2734
+ mf_parameter param, bool &slow_only,
2735
+ bool is_column_oriented=true)
2736
+ : BPRSolver(scheduler, blocks, PG, QG, model, param,
2737
+ slow_only, is_column_oriented) {}
2738
+ protected:
2739
+ #if defined USESSE
2740
+ void load_fixed_variables(
2741
+ __m128 &XMMlambda_p1, __m128 &XMMlambda_q1,
2742
+ __m128 &XMMlambda_p2, __m128 &XMMlambda_q2,
2743
+ __m128 &XMMeta, __m128 &XMMrk_slow,
2744
+ __m128 &XMMrk_fast);
2745
+ #elif defined USEAVX
2746
+ void load_fixed_variables(
2747
+ __m256 &XMMlambda_p1, __m256 &XMMlambda_q1,
2748
+ __m256 &XMMlambda_p2, __m256 &XMMlambda_q2,
2749
+ __m256 &XMMeta, __m256 &XMMrk_slow,
2750
+ __m256 &XMMrk_fast);
2751
+ #else
2752
+ void load_fixed_variables();
2753
+ #endif
2754
+ void prepare_negative();
2755
+ };
2756
+
2757
+ void COL_BPR_MFOC::prepare_negative()
2758
+ {
2759
+ mf_int negative = scheduler.get_negative(bid, bpr_bid, model.m, model.n,
2760
+ is_column_oriented);
2761
+ w = model.P + negative*model.k;
2762
+ wG = PG + negative*2;
2763
+ swap(p, q);
2764
+ swap(pG, qG);
2765
+ }
2766
+
2767
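+ // Because prepare_negative() above swaps p/q (and pG/qG) in the
+ // column-oriented case, load_fixed_variables() below assigns the user's
+ // p-side regularization parameters to the q-side variables and vice versa,
+ // so that each coefficient still applies to the intended factor matrix.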
+ #if defined USESSE
2768
+ void COL_BPR_MFOC::load_fixed_variables(
2769
+ __m128 &XMMlambda_p1, __m128 &XMMlambda_q1,
2770
+ __m128 &XMMlambda_p2, __m128 &XMMlambda_q2,
2771
+ __m128 &XMMeta, __m128 &XMMrk_slow,
2772
+ __m128 &XMMrk_fast)
2773
+ {
2774
+ XMMlambda_p1 = _mm_set1_ps(param.lambda_q1);
2775
+ XMMlambda_q1 = _mm_set1_ps(param.lambda_p1);
2776
+ XMMlambda_p2 = _mm_set1_ps(param.lambda_q2);
2777
+ XMMlambda_q2 = _mm_set1_ps(param.lambda_p2);
2778
+ XMMeta = _mm_set1_ps(param.eta);
2779
+ XMMrk_slow = _mm_set1_ps((mf_float)1.0/kALIGN);
2780
+ XMMrk_fast = _mm_set1_ps((mf_float)1.0/(model.k-kALIGN));
2781
+ }
2782
+ #elif defined USEAVX
2783
+ void COL_BPR_MFOC::load_fixed_variables(
2784
+ __m256 &XMMlambda_p1, __m256 &XMMlambda_q1,
2785
+ __m256 &XMMlambda_p2, __m256 &XMMlambda_q2,
2786
+ __m256 &XMMeta, __m256 &XMMrk_slow,
2787
+ __m256 &XMMrk_fast)
2788
+ {
2789
+ XMMlambda_p1 = _mm256_set1_ps(param.lambda_q1);
2790
+ XMMlambda_q1 = _mm256_set1_ps(param.lambda_p1);
2791
+ XMMlambda_p2 = _mm256_set1_ps(param.lambda_q2);
2792
+ XMMlambda_q2 = _mm256_set1_ps(param.lambda_p2);
2793
+ XMMeta = _mm256_set1_ps(param.eta);
2794
+ XMMrk_slow = _mm256_set1_ps((mf_float)1.0/kALIGN);
2795
+ XMMrk_fast = _mm256_set1_ps((mf_float)1.0/(model.k-kALIGN));
2796
+ }
2797
+ #else
2798
+ void COL_BPR_MFOC::load_fixed_variables()
2799
+ {
2800
+ lambda_p1 = param.lambda_q1;
2801
+ lambda_q1 = param.lambda_p1;
2802
+ lambda_p2 = param.lambda_q2;
2803
+ lambda_q2 = param.lambda_p2;
2804
+ rk_slow = (mf_float)1.0/kALIGN;
2805
+ rk_fast = (mf_float)1.0/(model.k-kALIGN);
2806
+ }
2807
+ #endif
2808
+
2809
+ class ROW_BPR_MFOC : public BPRSolver
2810
+ {
2811
+ public:
2812
+ ROW_BPR_MFOC(Scheduler &scheduler, vector<BlockBase*> &blocks,
2813
+ mf_float *PG, mf_float *QG, mf_model &model,
2814
+ mf_parameter param, bool &slow_only,
2815
+ bool is_column_oriented = false)
2816
+ : BPRSolver(scheduler, blocks, PG, QG, model, param,
2817
+ slow_only, is_column_oriented) {}
2818
+ protected:
2819
+ void prepare_negative();
2820
+ };
2821
+
2822
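+ // Sample a negative index from the paired block and point w/wG at the
+ // corresponding latent vector in Q and its squared-gradient accumulators in
+ // QG (two accumulator entries are kept per vector, hence the stride of 2).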
+ void ROW_BPR_MFOC::prepare_negative()
2823
+ {
2824
+ mf_int negative = scheduler.get_negative(bid, bpr_bid, model.m, model.n,
2825
+ is_column_oriented);
2826
+ w = model.Q + negative*model.k;
2827
+ wG = QG + negative*2;
2828
+ }
2829
+
2830
+
2831
+ class SolverFactory
2832
+ {
2833
+ public:
2834
+ static shared_ptr<SolverBase> get_solver(
2835
+ Scheduler &scheduler,
2836
+ vector<BlockBase*> &blocks,
2837
+ mf_float *PG,
2838
+ mf_float *QG,
2839
+ mf_model &model,
2840
+ mf_parameter param,
2841
+ bool &slow_only);
2842
+ };
2843
+
2844
+ shared_ptr<SolverBase> SolverFactory::get_solver(
2845
+ Scheduler &scheduler,
2846
+ vector<BlockBase*> &blocks,
2847
+ mf_float *PG,
2848
+ mf_float *QG,
2849
+ mf_model &model,
2850
+ mf_parameter param,
2851
+ bool &slow_only)
2852
+ {
2853
+ shared_ptr<SolverBase> solver;
2854
+
2855
+ switch(param.fun)
2856
+ {
2857
+ case P_L2_MFR:
2858
+ solver = shared_ptr<SolverBase>(new L2_MFR(scheduler, blocks,
2859
+ PG, QG, model, param, slow_only));
2860
+ break;
2861
+ case P_L1_MFR:
2862
+ solver = shared_ptr<SolverBase>(new L1_MFR(scheduler, blocks,
2863
+ PG, QG, model, param, slow_only));
2864
+ break;
2865
+ case P_KL_MFR:
2866
+ solver = shared_ptr<SolverBase>(new KL_MFR(scheduler, blocks,
2867
+ PG, QG, model, param, slow_only));
2868
+ break;
2869
+ case P_LR_MFC:
2870
+ solver = shared_ptr<SolverBase>(new LR_MFC(scheduler, blocks,
2871
+ PG, QG, model, param, slow_only));
2872
+ break;
2873
+ case P_L2_MFC:
2874
+ solver = shared_ptr<SolverBase>(new L2_MFC(scheduler, blocks,
2875
+ PG, QG, model, param, slow_only));
2876
+ break;
2877
+ case P_L1_MFC:
2878
+ solver = shared_ptr<SolverBase>(new L1_MFC(scheduler, blocks,
2879
+ PG, QG, model, param, slow_only));
2880
+ break;
2881
+ case P_ROW_BPR_MFOC:
2882
+ solver = shared_ptr<SolverBase>(new ROW_BPR_MFOC(scheduler,
2883
+ blocks, PG, QG, model, param, slow_only));
2884
+ break;
2885
+ case P_COL_BPR_MFOC:
2886
+ solver = shared_ptr<SolverBase>(new COL_BPR_MFOC(scheduler,
2887
+ blocks, PG, QG, model, param, slow_only));
2888
+ break;
2889
+ default:
2890
+ throw invalid_argument("unknown error function");
2891
+ }
2892
+ return solver;
2893
+ }
2894
+
2895
+ void fpsg_core(
2896
+ Utility &util,
2897
+ Scheduler &sched,
2898
+ mf_problem *tr,
2899
+ mf_problem *va,
2900
+ mf_parameter param,
2901
+ mf_float scale,
2902
+ vector<BlockBase*> &block_ptrs,
2903
+ vector<mf_int> &omega_p,
2904
+ vector<mf_int> &omega_q,
2905
+ shared_ptr<mf_model> &model,
2906
+ vector<mf_int> cv_blocks,
2907
+ mf_double *cv_error)
2908
+ {
2909
+ #if defined USESSE || defined USEAVX
2910
+ auto flush_zero_mode = _MM_GET_FLUSH_ZERO_MODE();
2911
+ _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);
2912
+ #endif
2913
+ if(tr->nnz == 0)
2914
+ {
2915
+ cout << "warning: train on an empty training set" << endl;
2916
+ return;
2917
+ }
2918
+
2919
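+ // The training ratings are divided by scale before fpsg_core is called (see
+ // fpsg and fpsg_on_disk), so the regularization coefficients of the
+ // regression losses are rescaled here to keep the scaled problem roughly
+ // equivalent to the original one.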
+ if(param.fun == P_L2_MFR ||
2920
+ param.fun == P_L1_MFR ||
2921
+ param.fun == P_KL_MFR)
2922
+ {
2923
+ switch(param.fun)
2924
+ {
2925
+ case P_L2_MFR:
2926
+ param.lambda_p2 /= scale;
2927
+ param.lambda_q2 /= scale;
2928
+ param.lambda_p1 /= (mf_float)pow(scale, 1.5);
2929
+ param.lambda_q1 /= (mf_float)pow(scale, 1.5);
2930
+ break;
2931
+ case P_L1_MFR:
2932
+ case P_KL_MFR:
2933
+ param.lambda_p1 /= sqrt(scale);
2934
+ param.lambda_q1 /= sqrt(scale);
2935
+ break;
2936
+ }
2937
+ }
2938
+
2939
+ if(!param.quiet)
2940
+ {
2941
+ cout.width(4);
2942
+ cout << "iter";
2943
+ cout.width(13);
2944
+ cout << "tr_"+util.get_error_legend();
2945
+ if(va->nnz != 0)
2946
+ {
2947
+ cout.width(13);
2948
+ cout << "va_"+util.get_error_legend();
2949
+ }
2950
+ cout.width(13);
2951
+ cout << "obj";
2952
+ cout << "\n";
2953
+ }
2954
+
2955
+ bool slow_only = param.lambda_p1 == 0 && param.lambda_q1 == 0? true: false;
2956
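+ // Per-vector squared-gradient accumulators for the adaptive learning rates.
+ // Two entries are kept for each row of P and each column of Q (the solvers
+ // advance to the second entry via update(); cf. rk_slow/rk_fast), and all
+ // entries start at 1.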
+ vector<mf_float> PG(model->m*2, 1), QG(model->n*2, 1);
2957
+
2958
+ vector<shared_ptr<SolverBase>> solvers(param.nr_threads);
2959
+ vector<thread> threads;
2960
+ threads.reserve(param.nr_threads);
2961
+ for(mf_int i = 0; i < param.nr_threads; ++i)
2962
+ {
2963
+ solvers[i] = SolverFactory::get_solver(sched, block_ptrs,
2964
+ PG.data(), QG.data(),
2965
+ *model, param, slow_only);
2966
+ threads.emplace_back(&SolverBase::run, solvers[i].get());
2967
+ }
2968
+
2969
+ for(mf_int iter = 0; iter < param.nr_iters; ++iter)
2970
+ {
2971
+ sched.wait_for_jobs_done();
2972
+
2973
+ if(!param.quiet)
2974
+ {
2975
+ mf_double reg = 0;
2976
+ mf_double reg1 = util.calc_reg1(*model, param.lambda_p1,
2977
+ param.lambda_q1, omega_p, omega_q);
2978
+ mf_double reg2 = util.calc_reg2(*model, param.lambda_p2,
2979
+ param.lambda_q2, omega_p, omega_q);
2980
+ mf_double tr_loss = sched.get_loss();
2981
+ mf_double tr_error = sched.get_error()/tr->nnz;
2982
+
2983
+ switch(param.fun)
2984
+ {
2985
+ case P_L2_MFR:
2986
+ reg = (reg1+reg2)*scale*scale;
2987
+ tr_loss *= scale*scale;
2988
+ tr_error = sqrt(tr_error*scale*scale);
2989
+ break;
2990
+ case P_L1_MFR:
2991
+ case P_KL_MFR:
2992
+ reg = (reg1+reg2)*scale;
2993
+ tr_loss *= scale;
2994
+ tr_error *= scale;
2995
+ break;
2996
+ default:
2997
+ reg = reg1+reg2;
2998
+ break;
2999
+ }
3000
+
3001
+ cout.width(4);
3002
+ cout << iter;
3003
+ cout.width(13);
3004
+ cout << fixed << setprecision(4) << tr_error;
3005
+ if(va->nnz != 0)
3006
+ {
3007
+ Block va_block(va->R, va->R+va->nnz);
3008
+ vector<BlockBase*> va_blocks(1, &va_block);
3009
+ vector<mf_int> va_block_ids(1, 0);
3010
+ mf_double va_error =
3011
+ util.calc_error(va_blocks, va_block_ids, *model)/va->nnz;
3012
+ switch(param.fun)
3013
+ {
3014
+ case P_L2_MFR:
3015
+ va_error = sqrt(va_error*scale*scale);
3016
+ break;
3017
+ case P_L1_MFR:
3018
+ case P_KL_MFR:
3019
+ va_error *= scale;
3020
+ break;
3021
+ }
3022
+
3023
+ cout.width(13);
3024
+ cout << fixed << setprecision(4) << va_error;
3025
+ }
3026
+ cout.width(13);
3027
+ cout << fixed << setprecision(4) << scientific << reg+tr_loss;
3028
+ cout << "\n" << flush;
3029
+ }
3030
+
3031
+ if(iter == 0)
3032
+ slow_only = false;
3033
+ if(iter == param.nr_iters - 1)
3034
+ sched.terminate();
3035
+ sched.resume();
3036
+ }
3037
+
3038
+ for(auto &thread : threads)
3039
+ thread.join();
3040
+
3041
+ if(cv_error != nullptr && cv_blocks.size() > 0)
3042
+ {
3043
+ mf_long cv_count = 0;
3044
+ for(auto block : cv_blocks)
3045
+ cv_count += block_ptrs[block]->get_nnz();
3046
+
3047
+ *cv_error = util.calc_error(block_ptrs, cv_blocks, *model)/cv_count;
3048
+
3049
+ switch(param.fun)
3050
+ {
3051
+ case P_L2_MFR:
3052
+ *cv_error = sqrt(*cv_error*scale*scale);
3053
+ break;
3054
+ case P_L1_MFR:
3055
+ case P_KL_MFR:
3056
+ *cv_error *= scale;
3057
+ break;
3058
+ }
3059
+ }
3060
+
3061
+ #if defined USESSE || defined USEAVX
3062
+ _MM_SET_FLUSH_ZERO_MODE(flush_zero_mode);
3063
+ #endif
3064
+ }
3065
+
3066
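+ // In-memory training entry point: copies (or borrows) the problem, shuffles
+ // and scales it, grids it into nr_bins x nr_bins blocks, runs fpsg_core, and
+ // finally restores the data (when it was modified in place) and rescales and
+ // unshuffles the returned model.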
+ shared_ptr<mf_model> fpsg(
3067
+ mf_problem const *tr_,
3068
+ mf_problem const *va_,
3069
+ mf_parameter param,
3070
+ vector<mf_int> cv_blocks=vector<mf_int>(),
3071
+ mf_double *cv_error=nullptr)
3072
+ {
3073
+ shared_ptr<mf_model> model;
3074
+ try
3075
+ {
3076
+ Utility util(param.fun, param.nr_threads);
3077
+ Scheduler sched(param.nr_bins, param.nr_threads, cv_blocks);
3078
+ shared_ptr<mf_problem> tr;
3079
+ shared_ptr<mf_problem> va;
3080
+ vector<Block> blocks(param.nr_bins*param.nr_bins);
3081
+ vector<BlockBase*> block_ptrs(param.nr_bins*param.nr_bins);
3082
+ vector<mf_node*> ptrs;
3083
+ vector<mf_int> p_map;
3084
+ vector<mf_int> q_map;
3085
+ vector<mf_int> inv_p_map;
3086
+ vector<mf_int> inv_q_map;
3087
+ vector<mf_int> omega_p;
3088
+ vector<mf_int> omega_q;
3089
+ mf_float avg = 0;
3090
+ mf_float std_dev = 0;
3091
+ mf_float scale = 1;
3092
+
3093
+ if(param.copy_data)
3094
+ {
3095
+ tr = shared_ptr<mf_problem>(
3096
+ Utility::copy_problem(tr_, true), deleter());
3097
+ va = shared_ptr<mf_problem>(
3098
+ Utility::copy_problem(va_, true), deleter());
3099
+ }
3100
+ else
3101
+ {
3102
+ tr = shared_ptr<mf_problem>(Utility::copy_problem(tr_, false));
3103
+ va = shared_ptr<mf_problem>(Utility::copy_problem(va_, false));
3104
+ }
3105
+
3106
+ util.collect_info(*tr, avg, std_dev);
3107
+
3108
+ if(param.fun == P_L2_MFR ||
3109
+ param.fun == P_L1_MFR ||
3110
+ param.fun == P_KL_MFR)
3111
+ scale = max((mf_float)1e-4, std_dev);
3112
+
3113
+ p_map = Utility::gen_random_map(tr->m);
3114
+ q_map = Utility::gen_random_map(tr->n);
3115
+ inv_p_map = Utility::gen_inv_map(p_map);
3116
+ inv_q_map = Utility::gen_inv_map(q_map);
3117
+ omega_p = vector<mf_int>(tr->m, 0);
3118
+ omega_q = vector<mf_int>(tr->n, 0);
3119
+
3120
+ util.shuffle_problem(*tr, p_map, q_map);
3121
+ util.shuffle_problem(*va, p_map, q_map);
3122
+ util.scale_problem(*tr, (mf_float)1.0/scale);
3123
+ util.scale_problem(*va, (mf_float)1.0/scale);
3124
+ ptrs = util.grid_problem(*tr, param.nr_bins, omega_p, omega_q, blocks);
3125
+
3126
+ model = shared_ptr<mf_model>(Utility::init_model(param.fun,
3127
+ tr->m, tr->n, param.k, avg/scale, omega_p, omega_q),
3128
+ [] (mf_model *ptr) { mf_destroy_model(&ptr); });
3129
+
3130
+ for(mf_int i = 0; i < (mf_long)blocks.size(); ++i)
3131
+ block_ptrs[i] = &blocks[i];
3132
+
3133
+ fpsg_core(util, sched, tr.get(), va.get(), param, scale,
3134
+ block_ptrs, omega_p, omega_q, model, cv_blocks, cv_error);
3135
+
3136
+ if(!param.copy_data)
3137
+ {
3138
+ util.scale_problem(*tr, scale);
3139
+ util.scale_problem(*va, scale);
3140
+ util.shuffle_problem(*tr, inv_p_map, inv_q_map);
3141
+ util.shuffle_problem(*va, inv_p_map, inv_q_map);
3142
+ }
3143
+
3144
+ util.scale_model(*model, scale);
3145
+ Utility::shrink_model(*model, param.k);
3146
+ Utility::shuffle_model(*model, inv_p_map, inv_q_map);
3147
+ }
3148
+ catch(exception const &e)
3149
+ {
3150
+ cerr << e.what() << endl;
3151
+ throw;
3152
+ }
3153
+ return model;
3154
+ }
3155
+
3156
+ shared_ptr<mf_model> fpsg_on_disk(
3157
+ const string tr_path,
3158
+ const string va_path,
3159
+ mf_parameter param,
3160
+ vector<mf_int> cv_blocks=vector<mf_int>(),
3161
+ mf_double *cv_error=nullptr)
3162
+ {
3163
+ shared_ptr<mf_model> model;
3164
+ try
3165
+ {
3166
+ Utility util(param.fun, param.nr_threads);
3167
+ Scheduler sched(param.nr_bins, param.nr_threads, cv_blocks);
3168
+ mf_problem tr = {};
3169
+ mf_problem va = read_problem(va_path.c_str());
3170
+ vector<BlockOnDisk> blocks(param.nr_bins*param.nr_bins);
3171
+ vector<BlockBase*> block_ptrs(param.nr_bins*param.nr_bins);
3172
+ vector<mf_int> p_map;
3173
+ vector<mf_int> q_map;
3174
+ vector<mf_int> inv_p_map;
3175
+ vector<mf_int> inv_q_map;
3176
+ vector<mf_int> omega_p;
3177
+ vector<mf_int> omega_q;
3178
+ mf_float avg = 0;
3179
+ mf_float std_dev = 0;
3180
+ mf_float scale = 1;
3181
+
3182
+ util.collect_info_on_disk(tr_path, tr, avg, std_dev);
3183
+
3184
+ if(param.fun == P_L2_MFR ||
3185
+ param.fun == P_L1_MFR ||
3186
+ param.fun == P_KL_MFR)
3187
+ scale = max((mf_float)1e-4, std_dev);
3188
+
3189
+ p_map = Utility::gen_random_map(tr.m);
3190
+ q_map = Utility::gen_random_map(tr.n);
3191
+ inv_p_map = Utility::gen_inv_map(p_map);
3192
+ inv_q_map = Utility::gen_inv_map(q_map);
3193
+ omega_p = vector<mf_int>(tr.m, 0);
3194
+ omega_q = vector<mf_int>(tr.n, 0);
3195
+
3196
+ util.shuffle_problem(va, p_map, q_map);
3197
+ util.scale_problem(va, (mf_float)1.0/scale);
3198
+
3199
+ util.grid_shuffle_scale_problem_on_disk(
3200
+ tr.m, tr.n, param.nr_bins, scale, tr_path,
3201
+ p_map, q_map, omega_p, omega_q, blocks);
3202
+
3203
+ model = shared_ptr<mf_model>(Utility::init_model(param.fun,
3204
+ tr.m, tr.n, param.k, avg/scale, omega_p, omega_q),
3205
+ [] (mf_model *ptr) { mf_destroy_model(&ptr); });
3206
+
3207
+ for(mf_int i = 0; i < (mf_long)blocks.size(); ++i)
3208
+ block_ptrs[i] = &blocks[i];
3209
+
3210
+ fpsg_core(util, sched, &tr, &va, param, scale,
3211
+ block_ptrs, omega_p, omega_q, model, cv_blocks, cv_error);
3212
+
3213
+ delete [] va.R;
3214
+
3215
+ util.scale_model(*model, scale);
3216
+ Utility::shrink_model(*model, param.k);
3217
+ Utility::shuffle_model(*model, inv_p_map, inv_q_map);
3218
+ }
3219
+ catch(exception const &e)
3220
+ {
3221
+ cerr << e.what() << endl;
3222
+ throw;
3223
+ }
3224
+ return model;
3225
+ }
3226
+
3227
+ // This function implements an efficient method to compute the objective function
3228
+ // minimized by the coordinate descent method.
3229
+ //
3230
+ // \min_{P, Q} 0.5 * \sum_{(u,v)\in\Omega^+} (1-r_{u,v})^2 +
3231
+ // 0.5 * \alpha \sum_{(u,v)\not\in\Omega^+} (c-r_{u,v})^2 +
3232
+ // 0.5 * \lambda_p2 * ||P||_F^2 + 0.5 * \lambda_q2 * ||Q||_F^2
3233
+ // where
3234
+ // 1. (u,v) is a tuple of row index and column index,
3235
+ // 2. \Omega^+ is a collection of (u,v) pairs which specifies the locations of
3236
+ // positive entries in the training matrix.
3237
+ // 3. r_{u,v} is the predicted rating at (u,v)
3238
+ // 4. \alpha is the weight of negative entries' loss.
3239
+ // 5. c is the desired value at every negative entry.
3240
+ // 6. ||P||_F is matrix P's Frobenius norm.
3241
+ // 7. \lambda_p2 is the regularization coefficient of P.
3242
+ //
3243
+ // Note that coordinate descent method's P and Q are the transpose
3244
+ // counterparts of P and Q in stochastic gradient method. Let R denote
3245
+ // the training matrix. For stochastic gradient method, we have R ~ P^TQ.
3246
+ // For coordinate descent method, we have R ~ PQ^T.
3247
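+ //
+ // A sketch of the identity this function relies on to evaluate the
+ // negative-entry loss without scanning all m*n entries: expanding the square,
+ //
+ // \sum_{u,v} (c - r_{u,v})^2 = c^2*m*n
+ //                              - 2c * \sum_k (\sum_u P[u+k*m]) * (\sum_v Q[v+k*n])
+ //                              + trace((P^T*P) * (Q^T*Q)),
+ //
+ // so only the column sums of P and Q and the d-by-d products P^T*P and Q^T*Q
+ // are needed (all_p_sum/all_q_sum, PTP, and QTQ below); positive entries
+ // included in this sum are compensated through positive_loss2.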
+ void calc_ccd_one_class_obj(const mf_int nr_threads,
3248
+ const mf_float alpha, const mf_float c,
3249
+ const mf_int m, const mf_int n, const mf_int d,
3250
+ const mf_float lambda_p2, const mf_float lambda_q2,
3251
+ const mf_float *P, const mf_float *Q,
3252
+ shared_ptr<const mf_problem> data,
3253
+ /*output*/ mf_double &obj,
3254
+ /*output*/ mf_double &positive_loss,
3255
+ /*output*/ mf_double &negative_loss,
3256
+ /*output*/ mf_double &reg)
3257
+ {
3258
+ // Declare regularization term of P.
3259
+ mf_double p_square_norm = 0.0;
3260
+ // Reduce P along column axis, which is the sum of rows in P.
3261
+ vector<mf_double> all_p_sum(d, 0.0);
3262
+ // Compute square of Frobenius norm on P and sum of all rows in P.
3263
+ for(mf_int k = 0; k < d; ++k)
3264
+ {
3265
+ // Declare a temporary buffer for all_p_sum[k] for the OpenMP reduction.
3266
+ mf_double all_p_sum_k = 0.0;
3267
+ #if defined USEOMP
3268
+ #pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:p_square_norm,all_p_sum_k)
3269
+ #endif
3270
+ for(mf_int u = 0; u < m; ++u)
3271
+ {
3272
+ const mf_float &p_ku = P[u + k * m];
3273
+ p_square_norm += p_ku * p_ku;
3274
+ all_p_sum_k += p_ku;
3275
+ }
3276
+ all_p_sum[k] = all_p_sum_k;
3277
+ }
3278
+
3279
+ // Declare regularization term of Q
3280
+ mf_double q_square_norm = 0.0;
3281
+ // Reduce Q along column axis, which is the sum of rows in Q.
3282
+ vector<mf_double> all_q_sum(d, 0.0);
3283
+ // Compute square of Frobenius norm on Q and sum of all elements in Q
3284
+ for(mf_int k = 0; k < d; ++k)
3285
+ {
3286
+ // Declare a temporary buffer for all_q_sum[k] for the OpenMP reduction.
3287
+ mf_double all_q_sum_k = 0.0;
3288
+ #if defined USEOMP
3289
+ #pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:q_square_norm,all_q_sum_k)
3290
+ #endif
3291
+ for(mf_int v = 0; v < n; ++v)
3292
+ {
3293
+ const mf_float &q_kv = Q[v + k * n];
3294
+ q_square_norm += q_kv * q_kv;
3295
+ all_q_sum_k += q_kv;
3296
+ }
3297
+ all_q_sum[k] = all_q_sum_k;
3298
+ }
3299
+
3300
+ // PTP = P^T * P, where P^T is the transpose of P. Note that P is an m-by-d
3301
+ // matrix and PTP is a d-by-d matrix.
3302
+ vector<mf_double> PTP(d * d, 0.0);
3303
+ // QTQ = Q^T * Q, a d-by-d matrix.
3304
+ vector<mf_double> QTQ(d * d, 0.0);
3305
+ // We calculate PTP and QTQ because they are needed in the computation of
3306
+ // negative entries' loss function.
3307
+ for(mf_int k1 = 0; k1 < d; ++k1)
3308
+ {
3309
+ for(mf_int k2 = 0; k2 < d; ++k2)
3310
+ {
3311
+ // Inner product of the k1-th and k2-th columns of P, an m-by-d matrix.
3312
+ mf_double p_k1_p_k2_inner_product = 0.0;
3313
+ #if defined USEOMP
3314
+ #pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:p_k1_p_k2_inner_product)
3315
+ #endif
3316
+ for(mf_int u = 0; u < m; ++u)
3317
+ p_k1_p_k2_inner_product += P[u + k1 * m] * P[u + k2 * m];
3318
+ PTP[k1 * d + k2] = p_k1_p_k2_inner_product;
3319
+
3320
+ // Inner product of the k1-th and k2-th columns of Q, an n-by-d matrix.
3321
+ mf_double q_k1_q_k2_inner_product = 0.0;
3322
+ #if defined USEOMP
3323
+ #pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:q_k1_q_k2_inner_product)
3324
+ #endif
3325
+ for(mf_int v = 0; v < n; ++v)
3326
+ q_k1_q_k2_inner_product += Q[v + k1 * n] * Q[v + k2 * n];
3327
+ QTQ[k1 * d + k2] = q_k1_q_k2_inner_product;
3328
+ }
3329
+ }
3330
+
3331
+ // Initialize loss function value of positive matrix entries.
3332
+ // It consists of two parts. The first part is the true prediction error
3333
+ // while the second part is only used for implementing a faster algorithm.
3334
+ mf_double positive_loss1 = 0.0;
3335
+ mf_double positive_loss2 = 0.0;
3336
+ // Scan through positive matrix entries to compute their loss values.
3337
+ // Notice that we assume that positive entries' values are all one.
3338
+ #if defined USEOMP
3339
+ #pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:positive_loss1,positive_loss2)
3340
+ #endif
3341
+ for(mf_long i = 0; i < data->nnz; ++i)
3342
+ {
3343
+ const mf_double &r = data->R[i].r;
3344
+ positive_loss1 += (1.0 - r) * (1.0 - r);
3345
+ positive_loss2 -= alpha * (c - r) * (c - r);
3346
+ }
3347
+ positive_loss1 *= 0.5;
3348
+ positive_loss2 *= 0.5;
3349
+
3350
+ // Declare loss terms related to negative matrix entries.
3351
+ mf_double negative_loss1 = c * c * m * n;
3352
+ mf_double negative_loss2 = 0.0;
3353
+ mf_double negative_loss3 = 0.0;
3354
+ // Compute loss terms.
3355
+ for(mf_int k1 = 0; k1 < d; ++k1)
3356
+ {
3357
+ negative_loss2 += all_p_sum[k1] * all_q_sum[k1];
3358
+ for(mf_int k2 = 0; k2 < d; ++k2)
3359
+ negative_loss3 += PTP[k1 + k2 * d] * QTQ[k2 + k1 * d];
3360
+ }
3361
+ // Compute the loss function of negative matrix entries.
3362
+ mf_double negative_loss4 = 0.5 * alpha *
3363
+ (negative_loss1 - 2 * c * negative_loss2 + negative_loss3);
3364
+
3365
+ // Assign results to output variables.
3366
+ reg = 0.5 * lambda_p2 * p_square_norm + 0.5 * lambda_q2 * q_square_norm;
3367
+
3368
+ // The function minimized by coordinate descent method.
3369
+ obj = positive_loss1 + positive_loss2 + negative_loss4 + reg;
3370
+
3371
+ // Sum of squared errors over positive matrix entries (i.e., those mf_node's
3372
+ // in data).
3373
+ positive_loss = positive_loss1;
3374
+
3375
+ // Sum of squared errors over negative matrix entries (i.e., the entries not
3376
+ // stored in data). The value negative_loss4 contains the squared errors by
3377
+ // considering positive entries as negative entries, so positive_loss2 is
3378
+ // added to compensate that.
3379
+ negative_loss = negative_loss4 + positive_loss2;
3380
+ }
3381
+
3382
+ void ccd_one_class_core(
3383
+ const Utility &util,
3384
+ shared_ptr<const mf_problem> tr_csr,
3385
+ shared_ptr<const mf_problem> tr_csc,
3386
+ shared_ptr<const mf_problem> va,
3387
+ const mf_parameter param,
3388
+ const vector<mf_node*> &ptrs_u,
3389
+ const vector<mf_node*> &ptrs_v,
3390
+ /*output*/ shared_ptr<mf_model> &model)
3391
+ {
3392
+ // Check problems stored in CSR and CSC formats
3393
+ if(tr_csr == nullptr) throw invalid_argument("CSR problem pointer is null.");
3394
+ if(tr_csc == nullptr) throw invalid_argument("CSC problem pointer is null.");
3395
+
3396
+ if(tr_csr->m != tr_csc->m)
3397
+ throw logic_error(
3398
+ "Row counts must be identical in CSR and CSC formats: " +
3399
+ to_string(tr_csr->m) + " != " + to_string(tr_csc->m));
3400
+ const mf_int m = tr_csr->m;
3401
+
3402
+ if(tr_csr->n != tr_csc->n)
3403
+ throw logic_error(
3404
+ "Column counts must be identical in CSR and CSC formats: " +
3405
+ to_string(tr_csr->n) + " != " + to_string(tr_csc->n));
3406
+ const mf_int n = tr_csr->n;
3407
+
3408
+ if(tr_csr->nnz != tr_csc->nnz)
3409
+ throw logic_error(
3410
+ "Numbers of data points must be identical in CSR and CSC formats: " +
3411
+ to_string(tr_csr->nnz) + " != " + to_string(tr_csc->nnz));
3412
+ const mf_long nnz = tr_csr->nnz;
3413
+
3414
+ // Check formulation parameters
3415
+ if(param.k <= 0)
3416
+ throw invalid_argument(
3417
+ "Latent dimension must be positive but got " +
3418
+ to_string(param.k));
3419
+ const mf_int d = param.k;
3420
+
3421
+ if(param.lambda_p1 != 0)
3422
+ throw invalid_argument(
3423
+ "P's L1-regularization coefficient must be zero but got " +
3424
+ to_string(param.lambda_p1));
3425
+ if(param.lambda_q1 != 0)
3426
+ throw invalid_argument(
3427
+ "Q's L1-regularization coefficient must be zero but got " +
3428
+ to_string(param.lambda_q1));
3429
+
3430
+ if(param.lambda_p2 <= 0)
3431
+ throw invalid_argument(
3432
+ "P's L2-regularization coefficient must be positive but got " +
3433
+ to_string(param.lambda_p2));
3434
+ if(param.lambda_q2 <= 0)
3435
+ throw invalid_argument(
3436
+ "Q's L2-regularization coefficient must be positive but got " +
3437
+ to_string(param.lambda_q2));
3438
+
3439
+ // REVIEW: It is not difficult to support non-negative matrix factorization
3440
+ // for coordinate descent method; we just need to project the updated value
3441
+ // back to the feasible region by using max(0, new_value) right after each
3442
+ // Newton step. LIBMF does not support it only because we have not seen actual
3443
+ // users.
3444
+ if(param.do_nmf)
3445
+ throw invalid_argument(
3446
+ "Coordinate descent does not support non-negative constraint");
3447
+
3448
+
3449
+ // Check some resources prepared internally
3450
+ if(ptrs_u.size() != (size_t)m + 1)
3451
+ throw invalid_argument("Number of row pointers must be " +
3452
+ to_string(m + 1) + " but got " + to_string(ptrs_u.size()));
3453
+ if(ptrs_v.size() != (size_t)n + 1)
3454
+ throw invalid_argument("Number of column pointers must be " +
3455
+ to_string(n + 1) + " but got " + to_string(ptrs_v.size()));
3456
+
3457
+ // Some constants of the formulation.
3458
+ // alpha: coefficient of negative part
3459
+ // c: the desired prediction values of unobserved ratings
3460
+ // lambda_p2: regularization coefficient of P's L2-norm
3461
+ // lambda_q2: regularization coefficient of Q's L2-norm
3462
+ const mf_float alpha = param.alpha;
3463
+ const mf_float c = param.c;
3464
+ const mf_float lambda_p2 = param.lambda_p2;
3465
+ const mf_float lambda_q2 = param.lambda_q2;
3466
+
3467
+ // Initialize P and Q. Note that \bar{q}_{kv} is Q[k*n+v]
3468
+ // and \bar{p}_{ku} is P[k*m+u]. One may notice that P and
3469
+ // Q here are actually the transposes of P and Q in FPSG.
3470
+ mf_float *P = model->P;
3471
+ mf_float *Q = model->Q;
3472
+
3473
+ // Cache the prediction values on positive matrix entries.
3474
+ // Given that P is zero-initialized and Q is randomly initialized in
3475
+ // Utility::init_model(mf_int m, mf_int n, mf_int k),
3476
+ // all predictions are zeros.
3477
+ #if defined USEOMP
3478
+ #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
3479
+ #endif
3480
+ for(mf_long i = 0; i < nnz; ++i)
3481
+ {
3482
+ tr_csr->R[i].r = 0.0;
3483
+ tr_csc->R[i].r = 0.0;
3484
+ }
3485
+
3486
+ // If the model is not initialized by
3487
+ // Utility::init_model(mf_int m, mf_int n, mf_int k),
3488
+ // please use the following initialization code to compute
3489
+ // and cache all prediction values on positive entries.
3490
+ /*
3491
+ for(mf_long i = 0; i < nnz; ++i)
3492
+ {
3493
+ mf_node &node = tr_csr->R[i];
3494
+ node.r = 0;
3495
+ for(mf_int k = 0; k < d; ++k)
3496
+ node.r += P[node.u + k * m]*Q[node.v + k * n];
3497
+ }
3498
+ for(mf_long i = 0; i < nnz; ++i)
3499
+ {
3500
+ mf_node &node = tr_csc->R[i];
3501
+ node.r = 0;
3502
+ for(mf_int k = 0; k < d; ++k)
3503
+ node.r += P[node.u + k * m]*Q[node.v + k * n];
3504
+ }
3505
+ */
3506
+
3507
+ if(!param.quiet)
3508
+ {
3509
+ cout.width(4);
3510
+ cout << "iter";
3511
+ cout.width(13);
3512
+ cout << "tr_"+util.get_error_legend();
3513
+ cout.width(14);
3514
+ cout << "tr_"+util.get_error_legend() << "+";
3515
+ cout.width(14);
3516
+ cout << "tr_"+util.get_error_legend() << "-";
3517
+ if(va->nnz != 0)
3518
+ {
3519
+ cout.width(13);
3520
+ cout << "va_"+util.get_error_legend();
3521
+ cout.width(14);
3522
+ cout << "va_"+util.get_error_legend() << "+";
3523
+ cout.width(14);
3524
+ cout << "va_"+util.get_error_legend() << "-";
3525
+ }
3526
+ cout.width(13);
3527
+ cout << "obj";
3528
+ cout << "\n";
3529
+ }
3530
+
3531
+ /////////////////////////////////////////////////////////////////
3532
+ // Minimize the objective function via coordinate descent method
3533
+ ////////////////////////////////////////////////////////////////
3534
+ // Solve P and Q using coordinate descent.
3535
+ // P = [\bar{p}_1, ..., \bar{p}_d] \in R^{m \times d}
3536
+ // Q = [\bar{q}_1, ..., \bar{q}_d] \in R^{n \times d}
3537
+ // Finally, the rating matrix R is approximated via
3538
+ // R ~ PQ^T \in R^{m \times n}
3539
+ for(mf_int outer = 0; outer < param.nr_iters; ++outer)
3540
+ {
3541
+ // Update \bar{p}_k and \bar{q}_k. The basic idea is
3542
+ // to replace \bar{p}_k and \bar{q}_k with a and b,
3543
+ // and then minimize the original objective function.
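+ // In other words, with every other factor pair fixed, the sub-problem in
+ // (a, b) is a regularized rank-one fit: choose vectors a and b so that
+ // a b^T best replaces the old contribution \bar{p}_k \bar{q}_k^T.
+ // The inner loop below alternates a few Newton sweeps over a and b.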
3544
+ for(mf_int k = 0; k < d; ++k)
3545
+ {
3546
+ // Get the pointer to the first element of \bar{p}_k (and
3547
+ // \bar{q}_k).
3548
+ mf_float *P_k = P + m * k;
3549
+ mf_float *Q_k = Q + n * k;
3550
+
3551
+ // Initialize a and b with the values they are about to replace
3552
+ // so that we can ensure improvement at each iteration.
3553
+ vector<mf_float> a(P_k, P_k + m);
3554
+ vector<mf_float> b(Q_k, Q_k + n);
3555
+
3556
+ for(mf_int inner = 0; inner < 3; ++inner)
3557
+ {
3558
+ ///////////////////////////////////////////////////////////////
3559
+ // Update a:
3560
+ // 1. Compute and cache constants
3561
+ // 2. For each coordinate of a, calculate optimal update using
3562
+ // Newton method
3563
+ ///////////////////////////////////////////////////////////////
3564
+
3565
+ // Compute and cache constants
3566
+ // \hat{b} = \sum_{v=1}^n \bar{b}_v
3567
+ // \tilde{b} = \sum_{v=1}^n \bar{b}_v^2
3568
+ mf_double b_hat = 0.0;
3569
+ mf_double b_tilde = 0.0;
3570
+ #if defined USEOMP
3571
+ #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static) reduction(+:b_hat,b_tilde)
3572
+ #endif
3573
+ for(mf_int v = 0; v < n; ++v)
3574
+ {
3575
+ const mf_double &b_v = b[v];
3576
+ b_hat += b_v;
3577
+ b_tilde += b_v * b_v;
3578
+ }
3579
+
3580
+ // Compute and cache a constant vector
3581
+ // s_k = \sum_{v=1}^n \bar{q}_{kv}b_v, k = 1, ..., d
3582
+ vector<mf_double> s(d, 0.0);
3583
+ for(mf_int k1 = 0; k1 < d; ++k1)
3584
+ {
3585
+ // Buffer variable for using OpenMP
3586
+ mf_double s_k1 = 0;
3587
+ const mf_float *Q_k1 = Q + k1 * n;
3588
+ #if defined USEOMP
3589
+ #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static) reduction(+:s_k1)
3590
+ #endif
3591
+ for(mf_int v = 0; v < n; ++v)
3592
+ s_k1 += Q_k1[v] * b[v];
3593
+ s[k1] = s_k1;
3594
+ }
3595
+
3596
+ // Solve a's sub-problem
3597
+ #if defined USEOMP
3598
+ #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
3599
+ #endif
3600
+ for(mf_int u = 0; u < m; ++u)
3601
+ {
3602
+ ////////////////////////////////////////////////////////
3603
+ // Update a[u] via Newton method. Let g_u and h_u denote
3604
+ // the first-order and second-order derivatives w.r.t.
3605
+ // a[u]. The following code implements
3606
+ // a[u] <-- a[u] - g_u/h_u
3607
+ ////////////////////////////////////////////////////////
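+ // Written out, the quantities accumulated below are
+ //   g_u = -(1 - alpha*c) * sum_{v in Omega_u} b_v
+ //         + (1 - alpha) * sum_{v in Omega_u} (\hat{r}_{uv} - p_{ku} q_{kv} + a_u b_v) b_v
+ //         + alpha * (sum_{k'} p_{k'u} s_{k'} - c*\hat{b} - p_{ku} s_k + a_u \tilde{b})
+ //         + lambda_p2 * a_u
+ //   h_u = (1 - alpha) * sum_{v in Omega_u} b_v^2 + alpha * \tilde{b} + lambda_p2
+ // where Omega_u is the set of observed entries in row u and \hat{r}_{uv}
+ // is the cached prediction stored in node.r.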
3608
+
3609
+ // Initialize temporary variables for calculating the gradient and Hessian.
3610
+ mf_double g_u_1 = 0.0;
3611
+ mf_double h_u_1 = 0.0;
3612
+ mf_double g_u_2 = 0.0;
3613
+ // Scan through the positive entries in the u-th row
3614
+ for(const mf_node *ptr = ptrs_u[u]; ptr != ptrs_u[u+1]; ++ptr)
3615
+ {
3616
+ const mf_int &v = ptr->v;
3617
+ const mf_float &b_v = b[v];
3618
+ g_u_1 += b_v;
3619
+ h_u_1 += b_v * b_v;
3620
+ g_u_2 += (ptr->r - P_k[u] * Q_k[v] + a[u] * b_v) * b_v;
3621
+ }
3622
+ mf_double g_u_3 = -c * b_hat - P_k[u] * s[k] + a[u] * b_tilde;
3623
+ for(mf_int k1 = 0; k1 < d; ++k1)
3624
+ g_u_3 += P[m * k1 + u] * s[k1];
3625
+ mf_double g_u = -(1.0 - alpha * c) * g_u_1 + (1.0 - alpha) * g_u_2 + alpha * g_u_3 + lambda_p2 * a[u];
3626
+ mf_double h_u = (1.0 - alpha) * h_u_1 + alpha * b_tilde + lambda_p2;
3627
+ a[u] -= static_cast<mf_float>(g_u / h_u);
3628
+ }
3629
+
3630
+ ///////////////////////////////////////////////////////////////
3631
+ // Update b:
3632
+ // 1. Compute and cache constants
3633
+ // 2. For each coordinate of b, calculate optimal update using
3634
+ // Newton method
3635
+ ///////////////////////////////////////////////////////////////
3636
+ // Compute and cache a_hat, a_tilde
3637
+ // \hat{a} = \sum_{u=1}^m \bar{a}_u
3638
+ // \tilde{a} = \sum_{u=1}^m \bar{a}_u^2
3639
+ mf_double a_hat = 0.0;
3640
+ mf_double a_tilde = 0.0;
3641
+ #if defined USEOMP
3642
+ #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static) reduction(+:a_hat,a_tilde)
3643
+ #endif
3644
+ for(mf_int u = 0; u < m; ++u)
3645
+ {
3646
+ const mf_float &a_u = a[u];
3647
+ a_hat += a_u;
3648
+ a_tilde += a_u * a_u;
3649
+ }
3650
+
3651
+ // Compute and cache t
3652
+ // t_k = \sum_{u=1}^m \bar{p}_{ku}a_u, k = 1, ..., d
3653
+ vector<mf_double> t(d, 0.0);
3654
+ for(mf_int k1 = 0; k1 < d; ++k1)
3655
+ {
3656
+ // Declare buffer variable for using OpenMP
3657
+ mf_double t_k1 = 0;
3658
+ const mf_float *P_k1 = P + k1 * m;
3659
+ #if defined USEOMP
3660
+ #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static) reduction(+:t_k1)
3661
+ #endif
3662
+ for(mf_int u = 0; u < m; ++u)
3663
+ t_k1 += P_k1[u] * a[u];
3664
+ t[k1] = t_k1;
3665
+ }
3666
+
3667
+ #if defined USEOMP
3668
+ #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
3669
+ #endif
3670
+ for(mf_int v = 0; v < n; ++v)
3671
+ {
3672
+ ////////////////////////////////////////////////////////
3673
+ // Update b[v] via Newton method. Let g_v and h_v denote
3674
+ // the first-order and second-order derivatives w.r.t.
3675
+ // b[v]. The following code implements
3676
+ // b[v] <-- b[v] - g_v/h_v
3677
+ ////////////////////////////////////////////////////////
3678
+
3679
+ // Initialize temporary variables for calculating the gradient and Hessian.
3680
+ mf_double g_v_1 = 0;
3681
+ mf_double g_v_2 = 0;
3682
+ mf_double h_v_1 = 0;
3683
+ // Scan through all positive entries in column v
3684
+ for(const mf_node *ptr = ptrs_v[v]; ptr != ptrs_v[v+1]; ++ptr)
3685
+ {
3686
+ const mf_int &u = ptr->u;
3687
+ const mf_float &a_u = a[u];
3688
+ g_v_1 += a_u;
3689
+ h_v_1 += a_u * a_u;
3690
+ g_v_2 += (ptr->r - P_k[u] * Q_k[v] + a_u * b[v]) * a_u;
3691
+ }
3692
+ mf_double g_v_3 = -c * a_hat - Q_k[v] * t[k] + b[v] * a_tilde;
3693
+ for(mf_int k1 = 0; k1 < d; ++k1)
3694
+ g_v_3 += Q[n * k1 + v] * t[k1];
3695
+ mf_double g_v = -(1.0 - alpha * c) * g_v_1 + (1.0 - alpha) * g_v_2 +
3696
+ alpha * g_v_3 + lambda_q2 * b[v];
3697
+ mf_double h_v = (1 - alpha) * h_v_1 + alpha * a_tilde + lambda_q2;
3698
+ b[v] -= static_cast<mf_float>(g_v / h_v);
3699
+ }
3700
+
3701
+ ///////////////////////////////////////////////////////////////
3702
+ // Update cached variables.
3703
+ ///////////////////////////////////////////////////////////////
3704
+ // Update cached prediction values in the CSR and CSC copies
3705
+ // \bar{r}_{uv} <- \bar{r}_{uv} - \bar{p}_{ku}\bar{q}_{kv} + a_u b_v
3706
+ #if defined USEOMP
3707
+ #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
3708
+ #endif
3709
+ for(mf_long i = 0; i < tr_csr->nnz; ++i)
3710
+ {
3711
+ // Update prediction values of positive entries in CSR
3712
+ mf_node *csr_ptr = tr_csr->R + i;
3713
+ const mf_int &u_csr = csr_ptr->u;
3714
+ const mf_int &v_csr = csr_ptr->v;
3715
+ csr_ptr->r += a[u_csr] * b[v_csr] - P_k[u_csr] * Q_k[v_csr];
3716
+
3717
+ // Update prediction values of positive entries in CSC
3718
+ mf_node *csc_ptr = tr_csc->R + i;
3719
+ const mf_int &u_csc = csc_ptr->u;
3720
+ const mf_int &v_csc = csc_ptr->v;
3721
+ csc_ptr->r += a[u_csc] * b[v_csc] - P_k[u_csc] * Q_k[v_csc];
3722
+ }
3723
+
3724
+ #if defined USEOMP
3725
+ #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
3726
+ #endif
3727
+ // Update P_k and Q_k
3728
+ for(mf_int u = 0; u < m; ++u)
3729
+ P_k[u] = a[u];
3730
+ #if defined USEOMP
3731
+ #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
3732
+ #endif
3733
+ for(mf_int v = 0; v < n; ++v)
3734
+ Q_k[v] = b[v];
3735
+ }
3736
+ }
3737
+
3738
+ // Skip the whole evaluation if nothing should be printed out.
3739
+ if(param.quiet)
3740
+ continue;
3741
+
3742
+ // Declare variable for storing objective value being minimized
3743
+ // by the training procedure. Note that the objective value consists
3744
+ // of two parts, loss function and regularization function.
3745
+ mf_double obj = 0;
3746
+ // Declare variables for storing loss function's value.
3747
+ mf_double positive_loss = 0; // for positive entries in training matrix.
3748
+ mf_double negative_loss = 0; // for negative entries in training matrix.
3749
+ // Declare variable for storing regularization function's value.
3750
+ mf_double reg = 0;
3751
+
3752
+ // Compute objective value, loss function value, and regularization
3753
+ // function value
3754
+ calc_ccd_one_class_obj(util.get_thread_number(), alpha, c, m, n, d,
3755
+ lambda_p2, lambda_q2, P, Q, tr_csr,
3756
+ obj, positive_loss, negative_loss, reg);
3757
+
3758
+ // Print number of outer iterations.
3759
+ cout.width(4);
3760
+ cout << outer;
3761
+ cout.width(13);
3762
+ cout << fixed << setprecision(4) << positive_loss + negative_loss;
3763
+ cout.width(15);
3764
+ cout << fixed << setprecision(4) << positive_loss;
3765
+ cout.width(15);
3766
+ cout << fixed << setprecision(4) << negative_loss;
3767
+
3768
+ if(va->nnz != 0)
3769
+ {
3770
+ // The following loop computes prediction scores on validation set.
3771
+ // Because training scores are maintained within the coordinate descent
3772
+ // framework, we don't need to recompute scores on the training set.
3773
+ #if defined USEOMP
3774
+ #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
3775
+ #endif
3776
+ for(mf_long i = 0; i < va->nnz; ++i)
3777
+ {
3778
+ mf_node &node = va->R[i];
3779
+ node.r = 0;
3780
+ for(mf_int k = 0; k < d; ++k)
3781
+ node.r += P[node.u + k * m]*Q[node.v + k * n];
3782
+ }
3783
+
3784
+ mf_double va_obj = 0;
3785
+ mf_double va_positive_loss = 0;
3786
+ mf_double va_negative_loss = 0;
3787
+ mf_double va_reg = 0;
3788
+
3789
+ calc_ccd_one_class_obj(util.get_thread_number(), alpha, c, m, n, d,
3790
+ lambda_p2, lambda_q2, P, Q, va,
3791
+ va_obj, va_positive_loss, va_negative_loss, va_reg);
3792
+
3793
+ cout.width(13);
3794
+ cout << fixed << setprecision(4) << va_positive_loss + va_negative_loss;
3795
+ cout.width(15);
3796
+ cout << fixed << setprecision(4) << va_positive_loss;
3797
+ cout.width(15);
3798
+ cout << fixed << setprecision(4) << va_negative_loss;
3799
+ }
3800
+
3801
+ cout.width(13);
3802
+ cout << fixed << setprecision(4) << scientific << obj;
3803
+ cout << "\n" << flush;
3804
+ }
3805
+
3806
+ // Transpose P and Q. Note that the formats of P and Q here are different
3807
+ // from those used by mf_model.
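+ // Internally this solver keeps the factors factor-major, i.e. the k-th
+ // latent coordinate of user u lives at P[k*m + u]; mf_model expects one
+ // contiguous length-d vector per user/item, i.e. P[u*d + k], so we copy
+ // into that layout here.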
3808
+
3809
+ mf_float *P_transpose = Utility::malloc_aligned_float((mf_long)m * d);
3810
+ #if defined USEOMP
3811
+ #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
3812
+ #endif
3813
+ for(mf_int u = 0; u < m; ++u)
3814
+ for(mf_int k = 0; k < d; ++k)
3815
+ P_transpose[k + u * d] = P[u + k * m];
3816
+ Utility::free_aligned_float(P);
3817
+ mf_float *Q_transpose = Utility::malloc_aligned_float((mf_long)n * d);
3818
+ #if defined USEOMP
3819
+ #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
3820
+ #endif
3821
+ for(mf_int v = 0; v < n; ++v)
3822
+ for(mf_int k = 0; k < d; ++k)
3823
+ Q_transpose[k + v * d] = Q[v + k * n];
3824
+ Utility::free_aligned_float(Q);
3825
+
3826
+ // Set the passed-in model to the result learned from the given data
3827
+ // (the model object was pre-allocated by the caller).
3828
+ model->m = m;
3829
+ model->n = n;
3830
+ model->k = d;
3831
+ model->b = 0.0;
3832
+ model->P = P_transpose;
3833
+ model->Q = Q_transpose;
3834
+ }
3835
+
3836
+ shared_ptr<mf_model> ccd_one_class(
3837
+ mf_problem const *tr_,
3838
+ mf_problem const *va_,
3839
+ mf_parameter param)
3840
+ {
3841
+ shared_ptr<mf_model> model;
3842
+ try
3843
+ {
3844
+ Utility util(param.fun, param.nr_threads);
3845
+ // Training matrix in compressed row format (sort nodes by user id)
3846
+ shared_ptr<mf_problem> tr_csr;
3847
+ // Training matrix in compressed column format (sort nodes by item id)
3848
+ shared_ptr<mf_problem> tr_csc;
3849
+ shared_ptr<mf_problem> va;
3850
+ // In tr_csr->R, the i-th row starts at ptrs_u[i] and ends right before ptrs_u[i+1]
3851
+ vector<mf_node*> ptrs_u(tr_->m + 1, nullptr);
3852
+ // In tr_csc->R, the i-th column starts at ptrs_v[i] and ends right before ptrs_v[i+1]
3853
+ vector<mf_node*> ptrs_v(tr_->n + 1, nullptr);
3854
+
3855
+ if(param.copy_data)
3856
+ {
3857
+ // We need both row-major and column-major copies of the training data,
3858
+ // so two duplicates are made.
3859
+ tr_csr = shared_ptr<mf_problem>(
3860
+ Utility::copy_problem(tr_, true), deleter());
3861
+ tr_csc = shared_ptr<mf_problem>(
3862
+ Utility::copy_problem(tr_, true), deleter());
3863
+ va = shared_ptr<mf_problem>(
3864
+ Utility::copy_problem(va_, true), deleter());
3865
+ }
3866
+ else
3867
+ {
3868
+ // We need both row-major and column-major training formats.
3869
+ // The original data is reused as the row-major copy, so
3870
+ // only one duplicate (the column-major copy) is created.
3871
+ tr_csr = shared_ptr<mf_problem>(Utility::copy_problem(tr_, false));
3872
+ tr_csc = shared_ptr<mf_problem>(Utility::copy_problem(tr_, true));
3873
+ va = shared_ptr<mf_problem>(Utility::copy_problem(va_, false));
3874
+ }
3875
+
3876
+ // Make the training set CSR/CSC by sorting its nodes. More specifically,
3877
+ // a matrix with values sorted by row index is CSR and vice versa. We will
3878
+ // compute the starting location for each row (CSR) and each column (CSC)
3879
+ // later.
3880
+ sort(tr_csr->R, tr_csr->R+tr_csr->nnz, sort_node_by_p());
3881
+ sort(tr_csc->R, tr_csc->R+tr_csc->nnz, sort_node_by_q());
3882
+
3883
+ // Save starting addresses of rows for CSR and columns for CSC.
3884
+ mf_int u_current = -1;
3885
+ mf_int v_current = -1;
3886
+ for(mf_long i = 0; i < tr_->nnz; ++i)
3887
+ {
3888
+ mf_node* N = nullptr;
3889
+
3890
+ // Deal with CSR format.
3891
+ N = tr_csr->R + i;
3892
+ // Since tr_csr has been sorted by index u, seeing a larger index
3893
+ // implies a new row. Assume a node is encoded as a tuple (u, v, r),
3894
+ // where u is row index, v is column index, and r is entry value.
3895
+ // The nodes in tr_csr->R could be
3896
+ // (0, 1, 0.5), (0, 2, 3.7), (0, 4, -1.2), (2, 0, 1.2), (2, 4, 2.5)
3897
+ // Then, we can see the first element of the 3rd row (indexed by 2)
3898
+ // is (2, 0, 1.2), which is the 4th element in tr_csr->R. Note that
3899
+ // we use the row pointer of the next non-empty row as the pointers
3900
+ // of empty rows. That is,
3901
+ // ptrs[0] = pointer of (0, 1, 0.5)
3902
+ // ptrs[1] = pointer of (2, 0, 1.2)
3903
+ // ptrs[2] = pointer of (2, 0, 1.2)
3904
+ if(N->u > u_current)
3905
+ {
3906
+ // We (if u_current != -1) have assigned starting addresses to rows
3907
+ // indexed by values smaller than or equal to u_current. Thus, we
3908
+ // should handle all rows indexed starting from u_current+1 to the
3909
+ // seen row index N->u.
3910
+ for(mf_int u_passed = u_current + 1; u_passed <= N->u; ++u_passed)
3911
+ {
3912
+ // i-th non-zero value's location in tr_csr is the starting
3913
+ // address of u_passed-th row.
3914
+ ptrs_u[u_passed] = tr_csr->R + i;
3915
+ }
3916
+ u_current = N->u;
3917
+ }
3918
+
3919
+ // Deal with CSC format
3920
+ N = tr_csc->R + i;
3921
+ if(N->v > v_current)
3922
+ {
3923
+ // We (if v_current != -1) have assigned starting addresses to columns
3924
+ // indexed by values smaller than or equal to v_current. Thus, we
3925
+ // should handle all columns indexed starting from v_current+1 to
3926
+ // the seen column index N->v.
3927
+ for(mf_int v_passed = v_current + 1; v_passed <= N->v; ++v_passed)
3928
+ {
3929
+ // i-th non-zero value's location in tr_csc is the starting
3930
+ // address of v_passed-th column.
3931
+ ptrs_v[v_passed] = tr_csc->R + i;
3932
+ }
3933
+ v_current = N->v;
3934
+ }
3935
+
3936
+ }
3937
+ // The bound of the last row. It's the address one element past the last
3938
+ // matrix entry.
3939
+ for(mf_int u_passed = u_current + 1; u_passed <= tr_->m; ++u_passed)
3940
+ ptrs_u[u_passed] = tr_csr->R + tr_csr->nnz;
3941
+ // The bound of the last column.
3942
+ for(mf_int v_passed = v_current + 1; v_passed <= tr_->n; ++v_passed)
3943
+ ptrs_v[v_passed] = tr_csc->R + tr_csc->nnz;
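+
+ // With these pointers in place, the observed entries of row u (or column
+ // v) can be walked with a plain pointer range, which is exactly how the
+ // core solver consumes them, e.g.
+ //   for(const mf_node *ptr = ptrs_u[u]; ptr != ptrs_u[u+1]; ++ptr) { ... }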
3944
+
3945
+
3946
+ model = shared_ptr<mf_model>(Utility::init_model(tr_->m, tr_->n, param.k),
3947
+ [] (mf_model *ptr) { mf_destroy_model(&ptr); });
3948
+
3949
+ ccd_one_class_core(util, tr_csr, tr_csc, va, param, ptrs_u, ptrs_v, model);
3950
+ }
3951
+ catch(exception const &e)
3952
+ {
3953
+ cerr << e.what() << endl;
3954
+ throw;
3955
+ }
3956
+ return model;
3957
+ }
3958
+
3959
+ bool check_parameter(mf_parameter param)
3960
+ {
3961
+ if(param.fun != P_L2_MFR &&
3962
+ param.fun != P_L1_MFR &&
3963
+ param.fun != P_KL_MFR &&
3964
+ param.fun != P_LR_MFC &&
3965
+ param.fun != P_L2_MFC &&
3966
+ param.fun != P_L1_MFC &&
3967
+ param.fun != P_ROW_BPR_MFOC &&
3968
+ param.fun != P_COL_BPR_MFOC &&
3969
+ param.fun != P_L2_MFOC)
3970
+ {
3971
+ cerr << "unknown loss function" << endl;
3972
+ return false;
3973
+ }
3974
+
3975
+ if(param.k < 1)
3976
+ {
3977
+ cerr << "number of factors must be greater than zero" << endl;
3978
+ return false;
3979
+ }
3980
+
3981
+ if(param.nr_threads < 1)
3982
+ {
3983
+ cerr << "number of threads must be greater than zero" << endl;
3984
+ return false;
3985
+ }
3986
+
3987
+ if(param.nr_bins < 1 || param.nr_bins < param.nr_threads)
3988
+ {
3989
+ cerr << "number of bins must not be smaller than the number of threads"
3990
+ << endl;
3991
+ return false;
3992
+ }
3993
+
3994
+ if(param.nr_iters < 1)
3995
+ {
3996
+ cerr << "number of iterations must be greater than zero" << endl;
3997
+ return false;
3998
+ }
3999
+
4000
+ if(param.lambda_p1 < 0 ||
4001
+ param.lambda_p2 < 0 ||
4002
+ param.lambda_q1 < 0 ||
4003
+ param.lambda_q2 < 0)
4004
+ {
4005
+ cerr << "regularization coefficient must be non-negative" << endl;
4006
+ return false;
4007
+ }
4008
+
4009
+ if(param.eta <= 0)
4010
+ {
4011
+ cerr << "learning rate must be greater than zero" << endl;
4012
+ return false;
4013
+ }
4014
+
4015
+ if(param.fun == P_KL_MFR && !param.do_nmf)
4016
+ {
4017
+ cerr << "--nmf must be set when using generalized KL-divergence"
4018
+ << endl;
4019
+ return false;
4020
+ }
4021
+
4022
+ if(param.nr_bins <= 2*param.nr_threads)
4023
+ {
4024
+ cerr << "Warning: insufficient blocks may slow down the training"
4025
+ << " process (4*nr_threads^2+1 blocks is suggested)" << endl;
4026
+ }
4027
+
4034
+ if(param.alpha < 0)
4035
+ {
4036
+ cerr << "alpha must be a non-negative number" << endl;
4037
+ return false;
+ }
4038
+
4039
+ return true;
4040
+ }
4041
+
4042
+ //--------------------------------------
4043
+ //-----Classes for cross validation-----
4044
+ //--------------------------------------
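+ //
+ // Cross validation works at the block level: the nr_bins x nr_bins grid
+ // of data blocks is shuffled once, split into nr_folds groups, and each
+ // fold trains with one group of blocks hidden while the error is measured
+ // on exactly those hidden blocks. The reported value is the average error
+ // over all folds.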
4045
+
4046
+ class CrossValidatorBase
4047
+ {
4048
+ public:
4049
+ CrossValidatorBase(mf_parameter param_, mf_int nr_folds_);
4050
+ mf_double do_cross_validation();
4051
+ virtual mf_double do_cv1(vector<mf_int> &hidden_blocks) = 0;
4052
+ protected:
4053
+ mf_parameter param;
4054
+ mf_int nr_bins;
4055
+ mf_int nr_folds;
4056
+ mf_int nr_blocks_per_fold;
4057
+ bool quiet;
4058
+ Utility util;
4059
+ mf_double cv_error;
4060
+ };
4061
+
4062
+ CrossValidatorBase::CrossValidatorBase(mf_parameter param_, mf_int nr_folds_)
4063
+ : param(param_), nr_bins(param_.nr_bins), nr_folds(nr_folds_),
4064
+ nr_blocks_per_fold(nr_bins*nr_bins/nr_folds), quiet(param_.quiet),
4065
+ util(param.fun, param.nr_threads), cv_error(0)
4066
+ {
4067
+ param.quiet = true;
4068
+ }
4069
+
4070
+ mf_double CrossValidatorBase::do_cross_validation()
4071
+ {
4072
+ vector<mf_int> cv_blocks;
4073
+ srand(0);
4074
+ for(mf_int block = 0; block < nr_bins*nr_bins; ++block)
4075
+ cv_blocks.push_back(block);
4076
+ random_shuffle(cv_blocks.begin(), cv_blocks.end());
4077
+
4078
+ if(!quiet)
4079
+ {
4080
+ cout.width(4);
4081
+ cout << "fold";
4082
+ cout.width(10);
4083
+ cout << util.get_error_legend();
4084
+ cout << endl;
4085
+ }
4086
+
4087
+ cv_error = 0;
4088
+
4089
+ for(mf_int fold = 0; fold < nr_folds; ++fold)
4090
+ {
4091
+ mf_int begin = fold*nr_blocks_per_fold;
4092
+ mf_int end = min((fold+1)*nr_blocks_per_fold, nr_bins*nr_bins);
4093
+ vector<mf_int> hidden_blocks(cv_blocks.begin()+begin,
4094
+ cv_blocks.begin()+end);
4095
+
4096
+ mf_double err = do_cv1(hidden_blocks);
4097
+ cv_error += err;
4098
+
4099
+ if(!quiet)
4100
+ {
4101
+ cout.width(4);
4102
+ cout << fold;
4103
+ cout.width(10);
4104
+ cout << fixed << setprecision(4) << err;
4105
+ cout << endl;
4106
+ }
4107
+ }
4108
+
4109
+ if(!quiet)
4110
+ {
4111
+ cout.width(14);
4112
+ cout.fill('=');
4113
+ cout << "" << endl;
4114
+ cout.fill(' ');
4115
+ cout.width(4);
4116
+ cout << "avg";
4117
+ cout.width(10);
4118
+ cout << fixed << setprecision(4) << cv_error/nr_folds;
4119
+ cout << endl;
4120
+ }
4121
+
4122
+ return cv_error/nr_folds;
4123
+ }
4124
+
4125
+ class CrossValidator : public CrossValidatorBase
4126
+ {
4127
+ public:
4128
+ CrossValidator(
4129
+ mf_parameter param_, mf_int nr_folds_, mf_problem const *prob_)
4130
+ : CrossValidatorBase(param_, nr_folds_), prob(prob_) {};
4131
+ mf_double do_cv1(vector<mf_int> &hidden_blocks);
4132
+ private:
4133
+ mf_problem const *prob;
4134
+ };
4135
+
4136
+ mf_double CrossValidator::do_cv1(vector<mf_int> &hidden_blocks)
4137
+ {
4138
+ mf_double err = 0;
4139
+ fpsg(prob, nullptr, param, hidden_blocks, &err);
4140
+ return err;
4141
+ }
4142
+
4143
+ class CrossValidatorOnDisk : public CrossValidatorBase
4144
+ {
4145
+ public:
4146
+ CrossValidatorOnDisk(
4147
+ mf_parameter param_, mf_int nr_folds_, string data_path_)
4148
+ : CrossValidatorBase(param_, nr_folds_), data_path(data_path_) {};
4149
+ mf_double do_cv1(vector<mf_int> &hidden_blocks);
4150
+ private:
4151
+ string data_path;
4152
+ };
4153
+
4154
+ mf_double CrossValidatorOnDisk::do_cv1(vector<mf_int> &hidden_blocks)
4155
+ {
4156
+ mf_double err = 0;
4157
+ fpsg_on_disk(data_path, string(), param, hidden_blocks, &err);
4158
+ return err;
4159
+ }
4160
+
4161
+ } // unnamed namespace
4162
+
4163
+ mf_model* mf_train_with_validation(
4164
+ mf_problem const *tr,
4165
+ mf_problem const *va,
4166
+ mf_parameter param)
4167
+ {
4168
+ if(!check_parameter(param))
4169
+ return nullptr;
4170
+
4171
+ shared_ptr<mf_model> model(nullptr);
4172
+
4173
+ if(param.fun != P_L2_MFOC)
4174
+ // Use stochastic gradient method
4175
+ model = fpsg(tr, va, param);
4176
+ else
4177
+ // Use coordinate descent method
4178
+ model = ccd_one_class(tr, va, param);
4179
+
4180
+ mf_model *model_ret = new mf_model;
4181
+
4182
+ model_ret->fun = model->fun;
4183
+ model_ret->m = model->m;
4184
+ model_ret->n = model->n;
4185
+ model_ret->k = model->k;
4186
+ model_ret->b = model->b;
4187
+
4188
+ model_ret->P = model->P;
4189
+ model->P = nullptr;
4190
+
4191
+ model_ret->Q = model->Q;
4192
+ model->Q = nullptr;
4193
+
4194
+ return model_ret;
4195
+ }
4196
+
4197
+ mf_model* mf_train_with_validation_on_disk(
4198
+ char const *tr_path,
4199
+ char const *va_path,
4200
+ mf_parameter param)
4201
+ {
4202
+ // Two conditions lead to an empty model. First, a parameter is outside its
4203
+ // supported range. Second, one-class matrix factorization with L2-loss
4204
+ // (-f 12) doesn't support disk-level training.
4205
+ if(!check_parameter(param) || param.fun == P_L2_MFOC)
4206
+ return nullptr;
4207
+
4208
+ shared_ptr<mf_model> model = fpsg_on_disk(
4209
+ string(tr_path), string(va_path), param);
4210
+
4211
+ mf_model *model_ret = new mf_model;
4212
+
4213
+ model_ret->fun = model->fun;
4214
+ model_ret->m = model->m;
4215
+ model_ret->n = model->n;
4216
+ model_ret->k = model->k;
4217
+ model_ret->b = model->b;
4218
+
4219
+ model_ret->P = model->P;
4220
+ model->P = nullptr;
4221
+
4222
+ model_ret->Q = model->Q;
4223
+ model->Q = nullptr;
4224
+
4225
+ return model_ret;
4226
+ }
4227
+
4228
+ mf_model* mf_train(mf_problem const *prob, mf_parameter param)
4229
+ {
4230
+ return mf_train_with_validation(prob, nullptr, param);
4231
+ }
4232
+
4233
+ mf_model* mf_train_on_disk(char const *tr_path, mf_parameter param)
4234
+ {
4235
+ return mf_train_with_validation_on_disk(tr_path, "", param);
4236
+ }
4237
+
4238
+ mf_double mf_cross_validation(
4239
+ mf_problem const *prob,
4240
+ mf_int nr_folds,
4241
+ mf_parameter param)
4242
+ {
4243
+ // Two conditions make this function return zero. First, a parameter is outside its
4244
+ // supported range. Second, one-class matrix factorization with L2-loss
4245
+ // (-f 12) doesn't support cross-validation.
4246
+ if(!check_parameter(param) || param.fun == P_L2_MFOC)
4247
+ return 0;
4248
+
4249
+ CrossValidator validator(param, nr_folds, prob);
4250
+
4251
+ return validator.do_cross_validation();
4252
+ }
4253
+
4254
+ mf_double mf_cross_validation_on_disk(
4255
+ char const *prob,
4256
+ mf_int nr_folds,
4257
+ mf_parameter param)
4258
+ {
4259
+ // Two conditions make this function return zero. First, a parameter is outside its
4260
+ // supported range. Second, one-class matrix factorization with L2-loss
4261
+ // (-f 12) doesn't support disk-level training.
4262
+ if(!check_parameter(param) || param.fun == P_L2_MFOC)
4263
+ return 0;
4264
+
4265
+ CrossValidatorOnDisk validator(param, nr_folds, string(prob));
4266
+
4267
+ return validator.do_cross_validation();
4268
+ }
4269
+
4270
+ mf_problem read_problem(string path)
4271
+ {
4272
+ mf_problem prob;
4273
+ prob.m = 0;
4274
+ prob.n = 0;
4275
+ prob.nnz = 0;
4276
+ prob.R = nullptr;
4277
+
4278
+ if(path.empty())
4279
+ return prob;
4280
+
4281
+ ifstream f(path);
4282
+ if(!f.is_open())
4283
+ return prob;
4284
+
4285
+ string line;
4286
+ while(getline(f, line))
4287
+ prob.nnz += 1;
4288
+
4289
+ mf_node *R = new mf_node[static_cast<size_t>(prob.nnz)];
4290
+
4291
+ f.close();
4292
+ f.open(path);
4293
+
4294
+ mf_long idx = 0;
4295
+ for(mf_node N; f >> N.u >> N.v >> N.r;)
4296
+ {
4297
+ if(N.u+1 > prob.m)
4298
+ prob.m = N.u+1;
4299
+ if(N.v+1 > prob.n)
4300
+ prob.n = N.v+1;
4301
+ R[idx] = N;
4302
+ ++idx;
4303
+ }
4304
+ prob.R = R;
4305
+
4306
+ f.close();
4307
+
4308
+ return prob;
4309
+ }
4310
+
4311
+ mf_int mf_save_model(mf_model const *model, char const *path)
4312
+ {
4313
+ ofstream f(path);
4314
+ if(!f.is_open())
4315
+ return 1;
4316
+
4317
+ f << "f " << model->fun << endl;
4318
+ f << "m " << model->m << endl;
4319
+ f << "n " << model->n << endl;
4320
+ f << "k " << model->k << endl;
4321
+ f << "b " << model->b << endl;
4322
+
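+ // The model is stored as plain text: five header lines (f, m, n, k, b)
+ // followed by one line per row of P and then per row of Q. Each line is
+ // "p<i> T v_1 ... v_k" for a valid factor vector, or "p<i> F 0 ... 0"
+ // when the vector was never trained (its values are NaN); Q rows use the
+ // prefix 'q'. mf_load_model below parses exactly this layout.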
4323
+ auto write = [&] (mf_float *ptr, mf_int size, char prefix)
4324
+ {
4325
+ for(mf_int i = 0; i < size; ++i)
4326
+ {
4327
+ mf_float *ptr1 = ptr + (mf_long)i*model->k;
4328
+ f << prefix << i << " ";
4329
+ if(isnan(ptr1[0]))
4330
+ {
4331
+ f << "F ";
4332
+ for(mf_int d = 0; d < model->k; ++d)
4333
+ f << 0 << " ";
4334
+ }
4335
+ else
4336
+ {
4337
+ f << "T ";
4338
+ for(mf_int d = 0; d < model->k; ++d)
4339
+ f << ptr1[d] << " ";
4340
+ }
4341
+ f << endl;
4342
+ }
4343
+
4344
+ };
4345
+
4346
+ write(model->P, model->m, 'p');
4347
+ write(model->Q, model->n, 'q');
4348
+
4349
+ f.close();
4350
+
4351
+ return 0;
4352
+ }
4353
+
4354
+ mf_model* mf_load_model(char const *path)
4355
+ {
4356
+ ifstream f(path);
4357
+ if(!f.is_open())
4358
+ return nullptr;
4359
+
4360
+ string dummy;
4361
+
4362
+ mf_model *model = new mf_model;
4363
+ model->P = nullptr;
4364
+ model->Q = nullptr;
4365
+
4366
+ f >> dummy >> model->fun >> dummy >> model->m >> dummy >> model->n >>
4367
+ dummy >> model->k >> dummy >> model->b;
4368
+
4369
+ try
4370
+ {
4371
+ model->P = Utility::malloc_aligned_float((mf_long)model->m*model->k);
4372
+ model->Q = Utility::malloc_aligned_float((mf_long)model->n*model->k);
4373
+ }
4374
+ catch(bad_alloc const &e)
4375
+ {
4376
+ cerr << e.what() << endl;
4377
+ mf_destroy_model(&model);
4378
+ return nullptr;
4379
+ }
4380
+
4381
+ auto read = [&] (mf_float *ptr, mf_int size)
4382
+ {
4383
+ for(mf_int i = 0; i < size; ++i)
4384
+ {
4385
+ mf_float *ptr1 = ptr + (mf_long)i*model->k;
4386
+ f >> dummy >> dummy;
4387
+ if(dummy.compare("F") == 0) // nan vector starts with "F"
4388
+ for(mf_int d = 0; d < model->k; ++d)
4389
+ {
4390
+ f >> dummy;
4391
+ ptr1[d] = numeric_limits<mf_float>::quiet_NaN();
4392
+ }
4393
+ else
4394
+ for(mf_int d = 0; d < model->k; ++d)
4395
+ f >> ptr1[d];
4396
+ }
4397
+ };
4398
+
4399
+ read(model->P, model->m);
4400
+ read(model->Q, model->n);
4401
+
4402
+ f.close();
4403
+
4404
+ return model;
4405
+ }
4406
+
4407
+ void mf_destroy_model(mf_model **model)
4408
+ {
4409
+ if(model == nullptr || *model == nullptr)
4410
+ return;
4411
+ Utility::free_aligned_float((*model)->P);
4412
+ Utility::free_aligned_float((*model)->Q);
4413
+ delete *model;
4414
+ *model = nullptr;
4415
+ }
4416
+
4417
+ mf_float mf_predict(mf_model const *model, mf_int u, mf_int v)
4418
+ {
4419
+ if(u < 0 || u >= model->m || v < 0 || v >= model->n)
4420
+ return model->b;
4421
+
4422
+ mf_float *p = model->P+(mf_long)u*model->k;
4423
+ mf_float *q = model->Q+(mf_long)v*model->k;
4424
+
4425
+ mf_float z = std::inner_product(p, p+model->k, q, (mf_float)0.0f);
4426
+
4427
+ if(isnan(z))
4428
+ z = model->b;
4429
+
4430
+ if(model->fun == P_L2_MFC ||
4431
+ model->fun == P_L1_MFC ||
4432
+ model->fun == P_LR_MFC)
4433
+ z = z > 0.0f? 1.0f: -1.0f;
4434
+
4435
+ return z;
4436
+ }
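+
+ // A note on mf_predict: indices are zero-based. Out-of-range (u, v) pairs
+ // and NaN factor vectors fall back to the bias model->b, and for the
+ // classification losses the returned value is the predicted label
+ // (+1 or -1) rather than a raw score.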
4437
+
4438
+ mf_double calc_rmse(mf_problem *prob, mf_model *model)
4439
+ {
4440
+ if(prob->nnz == 0)
4441
+ return 0;
4442
+ mf_double loss = 0;
4443
+ #if defined USEOMP
4444
+ #pragma omp parallel for schedule(static) reduction(+:loss)
4445
+ #endif
4446
+ for(mf_long i = 0; i < prob->nnz; ++i)
4447
+ {
4448
+ mf_node &N = prob->R[i];
4449
+ mf_float e = N.r - mf_predict(model, N.u, N.v);
4450
+ loss += e*e;
4451
+ }
4452
+ return sqrt(loss/prob->nnz);
4453
+ }
4454
+
4455
+ mf_double calc_mae(mf_problem *prob, mf_model *model)
4456
+ {
4457
+ if(prob->nnz == 0)
4458
+ return 0;
4459
+ mf_double loss = 0;
4460
+ #if defined USEOMP
4461
+ #pragma omp parallel for schedule(static) reduction(+:loss)
4462
+ #endif
4463
+ for(mf_long i = 0; i < prob->nnz; ++i)
4464
+ {
4465
+ mf_node &N = prob->R[i];
4466
+ loss += abs(N.r - mf_predict(model, N.u, N.v));
4467
+ }
4468
+ return loss/prob->nnz;
4469
+ }
4470
+
4471
+ mf_double calc_gkl(mf_problem *prob, mf_model *model)
4472
+ {
4473
+ if(prob->nnz == 0)
4474
+ return 0;
4475
+ mf_double loss = 0;
4476
+ #if defined USEOMP
4477
+ #pragma omp parallel for schedule(static) reduction(+:loss)
4478
+ #endif
4479
+ for(mf_long i = 0; i < prob->nnz; ++i)
4480
+ {
4481
+ mf_node &N = prob->R[i];
4482
+ mf_float z = mf_predict(model, N.u, N.v);
4483
+ loss += N.r*log(N.r/z)-N.r+z;
4484
+ }
4485
+ return loss/prob->nnz;
4486
+ }
4487
+
4488
+ mf_double calc_logloss(mf_problem *prob, mf_model *model)
4489
+ {
4490
+ if(prob->nnz == 0)
4491
+ return 0;
4492
+ mf_double logloss = 0;
4493
+ #if defined USEOMP
4494
+ #pragma omp parallel for schedule(static) reduction(+:logloss)
4495
+ #endif
4496
+ for(mf_long i = 0; i < prob->nnz; ++i)
4497
+ {
4498
+ mf_node &N = prob->R[i];
4499
+ mf_float z = mf_predict(model, N.u, N.v);
4500
+ if(N.r > 0)
4501
+ logloss += log(1.0+exp(-z));
4502
+ else
4503
+ logloss += log(1.0+exp(z));
4504
+ }
4505
+ return logloss/prob->nnz;
4506
+ }
4507
+
4508
+ mf_double calc_accuracy(mf_problem *prob, mf_model *model)
4509
+ {
4510
+ if(prob->nnz == 0)
4511
+ return 0;
4512
+ mf_double acc = 0;
4513
+ #if defined USEOMP
4514
+ #pragma omp parallel for schedule(static) reduction(+:acc)
4515
+ #endif
4516
+ for(mf_long i = 0; i < prob->nnz; ++i)
4517
+ {
4518
+ mf_node &N = prob->R[i];
4519
+ mf_float z = mf_predict(model, N.u, N.v);
4520
+ if(N.r > 0)
4521
+ acc += z > 0? 1: 0;
4522
+ else
4523
+ acc += z < 0? 1: 0;
4524
+ }
4525
+ return acc/prob->nnz;
4526
+ }
4527
+
4528
+ pair<mf_double, mf_double> calc_mpr_auc(mf_problem *prob,
4529
+ mf_model *model, bool transpose)
4530
+ {
4531
+ mf_int mf_node::*row_ptr;
4532
+ mf_int mf_node::*col_ptr;
4533
+ mf_int m = 0, n = 0;
4534
+ if(!transpose)
4535
+ {
4536
+ row_ptr = &mf_node::u;
4537
+ col_ptr = &mf_node::v;
4538
+ m = max(prob->m, model->m);
4539
+ n = max(prob->n, model->n);
4540
+ }
4541
+ else
4542
+ {
4543
+ row_ptr = &mf_node::v;
4544
+ col_ptr = &mf_node::u;
4545
+ m = max(prob->n, model->n);
4546
+ n = max(prob->m, model->m);
4547
+ }
4548
+
4549
+ auto sort_by_id = [&] (mf_node const &lhs, mf_node const &rhs)
4550
+ {
4551
+ return tie(lhs.*row_ptr, lhs.*col_ptr) <
4552
+ tie(rhs.*row_ptr, rhs.*col_ptr);
4553
+ };
4554
+
4555
+ sort(prob->R, prob->R+prob->nnz, sort_by_id);
4556
+
4557
+ auto sort_by_pred = [&] (pair<mf_node, mf_float> const &lhs,
4558
+ pair<mf_node, mf_float> const &rhs) { return lhs.second < rhs.second; };
4559
+
4560
+ vector<mf_int> pos_cnts(m+1, 0);
4561
+ for(mf_int i = 0; i < prob->nnz; ++i)
4562
+ pos_cnts[prob->R[i].*row_ptr+1] += 1;
4563
+ for(mf_int i = 1; i < m+1; ++i)
4564
+ pos_cnts[i] += pos_cnts[i-1];
4565
+
4566
+ mf_int total_m = 0;
4567
+ mf_long total_pos = 0;
4568
+ mf_double all_u_mpr = 0;
4569
+ mf_double all_u_auc = 0;
4570
+ #if defined USEOMP
4571
+ #pragma omp parallel for schedule(static) reduction(+: total_m, total_pos, all_u_mpr, all_u_auc)
4572
+ #endif
4573
+ for(mf_int i = 0; i < m; ++i)
4574
+ {
4575
+ if(pos_cnts[i+1]-pos_cnts[i] < 1)
4576
+ continue;
4577
+
4578
+ vector<pair<mf_node, mf_float>> row(n);
4579
+
4580
+ for(mf_int j = 0; j < n; ++j)
4581
+ {
4582
+ mf_node N;
4583
+ N.*row_ptr = i;
4584
+ N.*col_ptr = j;
4585
+ N.r = 0;
4586
+ row[j] = make_pair(N, mf_predict(model, N.u, N.v));
4587
+ }
4588
+
4589
+ mf_int pos = 0;
4590
+ vector<mf_int> index(pos_cnts[i+1]-pos_cnts[i], 0);
4591
+ for(mf_int j = pos_cnts[i]; j < pos_cnts[i+1]; ++j)
4592
+ {
4593
+ if(prob->R[j].r <= 0)
4594
+ continue;
4595
+
4596
+ mf_int col = prob->R[j].*col_ptr;
4597
+ row[col].first.r = prob->R[j].r;
4598
+ index[pos] = col;
4599
+ pos += 1;
4600
+ }
4601
+
4602
+ if(n-pos < 1 || pos < 1)
4603
+ continue;
4604
+
4605
+ ++total_m;
4606
+ total_pos += pos;
4607
+
4608
+ mf_int count = 0;
4609
+ for(mf_int k = 0; k < pos; ++k)
4610
+ {
4611
+ swap(row[count], row[index[k]]);
4612
+ ++count;
4613
+ }
4614
+ sort(row.begin(), row.begin()+pos, sort_by_pred);
4615
+
4616
+ mf_double u_mpr = 0;
4617
+ mf_double u_auc = 0;
4618
+ for(auto neg_it = row.begin()+pos; neg_it != row.end(); ++neg_it)
4619
+ {
4620
+ if(row[pos-1].second <= neg_it->second)
4621
+ {
4622
+ u_mpr += pos;
4623
+ continue;
4624
+ }
4625
+
4626
+ mf_int left = 0;
4627
+ mf_int right = pos-1;
4628
+ while(left < right)
4629
+ {
4630
+ mf_int mid = (left+right)/2;
4631
+ if(row[mid].second > neg_it->second)
4632
+ right = mid;
4633
+ else
4634
+ left = mid+1;
4635
+ }
4636
+ u_mpr += left;
4637
+ u_auc += pos-left;
4638
+ }
4639
+
4640
+ all_u_mpr += u_mpr/(n-pos);
4641
+ all_u_auc += u_auc/(n-pos)/pos;
4642
+ }
4643
+
4644
+ all_u_mpr /= total_pos;
4645
+ all_u_auc /= total_m;
4646
+
4647
+ return make_pair(all_u_mpr, all_u_auc);
4648
+ }
4649
+
4650
+ mf_double calc_mpr(mf_problem *prob, mf_model *model, bool transpose)
4651
+ {
4652
+ return calc_mpr_auc(prob, model, transpose).first;
4653
+ }
4654
+
4655
+ mf_double calc_auc(mf_problem *prob, mf_model *model, bool transpose)
4656
+ {
4657
+ return calc_mpr_auc(prob, model, transpose).second;
4658
+ }
4659
+
4660
+ mf_parameter mf_get_default_param()
4661
+ {
4662
+ mf_parameter param;
4663
+
4664
+ param.fun = P_L2_MFR;
4665
+ param.k = 8;
4666
+ param.nr_threads = 12;
4667
+ param.nr_bins = 20;
4668
+ param.nr_iters = 20;
4669
+ param.lambda_p1 = 0.0f;
4670
+ param.lambda_q1 = 0.0f;
4671
+ param.lambda_p2 = 0.1f;
4672
+ param.lambda_q2 = 0.1f;
4673
+ param.eta = 0.1f;
4674
+ param.alpha = 1.0f;
4675
+ param.c = 0.0001f;
4676
+ param.do_nmf = false;
4677
+ param.quiet = false;
4678
+ param.copy_data = true;
4679
+
4680
+ return param;
4681
+ }
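+
+ // A minimal usage sketch of the public API defined in this file. It is
+ // kept inside a comment on purpose; the file paths are illustrative only.
+ /*
+ int main()
+ {
+     mf_parameter param = mf_get_default_param();
+     param.nr_threads = 4;
+     mf_problem tr = read_problem("train.txt"); // "u v r" triplets, one per line
+     mf_model *model = mf_train(&tr, param);
+     mf_float score = mf_predict(model, 0, 0);  // predict entry (0, 0)
+     mf_save_model(model, "model.txt");
+     mf_destroy_model(&model);
+     delete[] tr.R; // read_problem allocates R with new[]
+     return 0;
+ }
+ */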
4682
+
4683
+ } // namespace mf