libmf 0.1.3 → 0.2.0

@@ -1,4683 +0,0 @@
1
- #include <algorithm>
2
- #include <cmath>
3
- #include <condition_variable>
4
- #include <cstdlib>
5
- #include <cstring>
6
- #include <fstream>
7
- #include <iostream>
8
- #include <iomanip>
9
- #include <memory>
10
- #include <numeric>
11
- #include <queue>
12
- #include <random>
13
- #include <stdexcept>
14
- #include <string>
15
- #include <thread>
16
- #include <unordered_set>
17
- #include <vector>
18
- #include <limits>
19
-
20
- #include "mf.h"
21
-
22
- #if defined USESSE
23
- #include <pmmintrin.h>
24
- #endif
25
-
26
- #if defined USEAVX
27
- #include <immintrin.h>
28
- #endif
29
-
30
- #if defined USEOMP
31
- #include <omp.h>
32
- #endif
33
-
34
- namespace mf
35
- {
36
-
37
- using namespace std;
38
-
39
- namespace // unnamed namespace
40
- {
41
-
42
- mf_int const kALIGNByte = 32;
43
- mf_int const kALIGN = kALIGNByte/sizeof(mf_float);
44
-
45
- //--------------------------------------
46
- //---------Scheduler of Blocks----------
47
- //--------------------------------------
48
-
49
- class Scheduler
50
- {
51
- public:
52
- Scheduler(mf_int nr_bins, mf_int nr_threads, vector<mf_int> cv_blocks);
53
- mf_int get_job();
54
- mf_int get_bpr_job(mf_int first_block, bool is_column_oriented);
55
- void put_job(mf_int block, mf_double loss, mf_double error);
56
- void put_bpr_job(mf_int first_block, mf_int second_block);
57
- mf_double get_loss();
58
- mf_double get_error();
59
- mf_int get_negative(mf_int first_block, mf_int second_block,
60
- mf_int m, mf_int n, bool is_column_oriented);
61
- void wait_for_jobs_done();
62
- void resume();
63
- void terminate();
64
- bool is_terminated();
65
-
66
- private:
67
- mf_int nr_bins;
68
- mf_int nr_threads;
69
- mf_int nr_done_jobs;
70
- mf_int target;
71
- mf_int nr_paused_threads;
72
- bool terminated;
73
- vector<mf_int> counts;
74
- vector<mf_int> busy_p_blocks;
75
- vector<mf_int> busy_q_blocks;
76
- vector<mf_double> block_losses;
77
- vector<mf_double> block_errors;
78
- vector<minstd_rand0> block_generators;
79
- unordered_set<mf_int> cv_blocks;
80
- mutex mtx;
81
- condition_variable cond_var;
82
- default_random_engine generator;
83
- uniform_real_distribution<mf_float> distribution;
84
- priority_queue<pair<mf_float, mf_int>,
85
- vector<pair<mf_float, mf_int>>,
86
- greater<pair<mf_float, mf_int>>> pq;
87
- };
88
-
89
- Scheduler::Scheduler(mf_int nr_bins, mf_int nr_threads,
90
- vector<mf_int> cv_blocks)
91
- : nr_bins(nr_bins),
92
- nr_threads(nr_threads),
93
- nr_done_jobs(0),
94
- target(nr_bins*nr_bins),
95
- nr_paused_threads(0),
96
- terminated(false),
97
- counts(nr_bins*nr_bins, 0),
98
- busy_p_blocks(nr_bins, 0),
99
- busy_q_blocks(nr_bins, 0),
100
- block_losses(nr_bins*nr_bins, 0),
101
- block_errors(nr_bins*nr_bins, 0),
102
- cv_blocks(cv_blocks.begin(), cv_blocks.end()),
103
- distribution(0.0, 1.0)
104
- {
105
- for(mf_int i = 0; i < nr_bins*nr_bins; ++i)
106
- {
107
- if(this->cv_blocks.find(i) == this->cv_blocks.end())
108
- pq.emplace(distribution(generator), i);
109
- block_generators.push_back(minstd_rand0(rand()));
110
- }
111
- }
112
-
113
- mf_int Scheduler::get_job()
114
- {
115
- bool is_found = false;
116
- pair<mf_float, mf_int> block;
117
-
118
- while(!is_found)
119
- {
120
- lock_guard<mutex> lock(mtx);
121
- vector<pair<mf_float, mf_int>> locked_blocks;
122
- mf_int p_block = 0;
123
- mf_int q_block = 0;
124
-
125
- while(!pq.empty())
126
- {
127
- block = pq.top();
128
- pq.pop();
129
-
130
- p_block = block.second/nr_bins;
131
- q_block = block.second%nr_bins;
132
-
133
- if(busy_p_blocks[p_block] || busy_q_blocks[q_block])
134
- locked_blocks.push_back(block);
135
- else
136
- {
137
- busy_p_blocks[p_block] = 1;
138
- busy_q_blocks[q_block] = 1;
139
- counts[block.second] += 1;
140
- is_found = true;
141
- break;
142
- }
143
- }
144
-
145
- for(auto &block1 : locked_blocks)
146
- pq.push(block1);
147
- }
148
-
149
- return block.second;
150
- }
151
-
152
- mf_int Scheduler::get_bpr_job(mf_int first_block, bool is_column_oriented)
153
- {
154
- lock_guard<mutex> lock(mtx);
155
- mf_int another = first_block;
156
- vector<pair<mf_float, mf_int>> locked_blocks;
157
-
158
- while(!pq.empty())
159
- {
160
- pair<mf_float, mf_int> block = pq.top();
161
- pq.pop();
162
-
163
- mf_int p_block = block.second/nr_bins;
164
- mf_int q_block = block.second%nr_bins;
165
-
166
- auto is_rejected = [&] ()
167
- {
168
- if(is_column_oriented)
169
- return first_block%nr_bins != q_block ||
170
- busy_p_blocks[p_block];
171
- else
172
- return first_block/nr_bins != p_block ||
173
- busy_q_blocks[q_block];
174
- };
175
-
176
- if(is_rejected())
177
- locked_blocks.push_back(block);
178
- else
179
- {
180
- busy_p_blocks[p_block] = 1;
181
- busy_q_blocks[q_block] = 1;
182
- another = block.second;
183
- break;
184
- }
185
- }
186
-
187
- for(auto &block : locked_blocks)
188
- pq.push(block);
189
-
190
- return another;
191
- }
192
-
193
- void Scheduler::put_job(mf_int block_idx, mf_double loss, mf_double error)
194
- {
195
- // Return the held block to the scheduler
196
- {
197
- lock_guard<mutex> lock(mtx);
198
- busy_p_blocks[block_idx/nr_bins] = 0;
199
- busy_q_blocks[block_idx%nr_bins] = 0;
200
- block_losses[block_idx] = loss;
201
- block_errors[block_idx] = error;
202
- ++nr_done_jobs;
203
- mf_float priority =
204
- (mf_float)counts[block_idx]+distribution(generator);
205
- pq.emplace(priority, block_idx);
206
- ++nr_paused_threads;
207
- // Tell others that a block is available again.
208
- cond_var.notify_all();
209
- }
210
-
211
- // Wait if nr_done_jobs (i.e., the number of processed blocks) has reached
213
- // the target, because we want to print the training status roughly once
214
- // per pass over all blocks. This is the only place where a solver thread
215
- // should wait for anything.
215
- {
216
- unique_lock<mutex> lock(mtx);
217
- cond_var.wait(lock, [&] {
218
- return nr_done_jobs < target;
219
- });
220
- }
221
-
222
- // Nothing is blocking and this thread is going to take another block
223
- {
224
- lock_guard<mutex> lock(mtx);
225
- --nr_paused_threads;
226
- }
227
- }
228
-
229
- void Scheduler::put_bpr_job(mf_int first_block, mf_int second_block)
230
- {
231
- if(first_block == second_block)
232
- return;
233
-
234
- lock_guard<mutex> lock(mtx);
235
- {
236
- busy_p_blocks[second_block/nr_bins] = 0;
237
- busy_q_blocks[second_block%nr_bins] = 0;
238
- mf_float priority =
239
- (mf_float)counts[second_block]+distribution(generator);
240
- pq.emplace(priority, second_block);
241
- }
242
- }
243
-
244
- mf_double Scheduler::get_loss()
245
- {
246
- lock_guard<mutex> lock(mtx);
247
- return accumulate(block_losses.begin(), block_losses.end(), 0.0);
248
- }
249
-
250
- mf_double Scheduler::get_error()
251
- {
252
- lock_guard<mutex> lock(mtx);
253
- return accumulate(block_errors.begin(), block_errors.end(), 0.0);
254
- }
255
-
256
- mf_int Scheduler::get_negative(mf_int first_block, mf_int second_block,
257
- mf_int m, mf_int n, bool is_column_oriented)
258
- {
259
- mf_int rand_val = (mf_int)block_generators[first_block]();
260
-
261
- auto gen_random = [&] (mf_int block_id)
262
- {
263
- mf_int v_min, v_max;
264
-
265
- if(is_column_oriented)
266
- {
267
- mf_int seg_size = (mf_int)ceil((double)m/nr_bins);
268
- v_min = min((block_id/nr_bins)*seg_size, m-1);
269
- v_max = min(v_min+seg_size, m-1);
270
- }
271
- else
272
- {
273
- mf_int seg_size = (mf_int)ceil((double)n/nr_bins);
274
- v_min = min((block_id%nr_bins)*seg_size, n-1);
275
- v_max = min(v_min+seg_size, n-1);
276
- }
277
- if(v_max == v_min)
278
- return v_min;
279
- else
280
- return rand_val%(v_max-v_min)+v_min;
281
- };
282
-
283
- if(rand_val % 2)
284
- return (mf_int)gen_random(first_block);
285
- else
286
- return (mf_int)gen_random(second_block);
287
- }
288
-
289
- void Scheduler::wait_for_jobs_done()
290
- {
291
- unique_lock<mutex> lock(mtx);
292
-
293
- // The first thing the main thread should wait for is that solver threads
294
- // process enough matrix blocks.
295
- // [REVIEW] Is it really needed? Solver threads automatically stop if they
296
- // process too many blocks, so the next wait should be enough for stopping
297
- // the main thread when nr_done_jobs is not enough.
298
- cond_var.wait(lock, [&] {
299
- return nr_done_jobs >= target;
300
- });
301
-
302
- // Wait for all threads to stop. Once a thread realizes that all threads
303
- // have processed enough blocks it should stop. Then, the main thread can
304
- // print values safely.
305
- cond_var.wait(lock, [&] {
306
- return nr_paused_threads == nr_threads;
307
- });
308
- }
309
-
310
- void Scheduler::resume()
311
- {
312
- lock_guard<mutex> lock(mtx);
313
- target += nr_bins*nr_bins;
314
- cond_var.notify_all();
315
- }
316
-
317
- void Scheduler::terminate()
318
- {
319
- lock_guard<mutex> lock(mtx);
320
- terminated = true;
321
- }
322
-
323
- bool Scheduler::is_terminated()
324
- {
325
- lock_guard<mutex> lock(mtx);
326
- return terminated;
327
- }
328
-
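The comments in put_job and wait_for_jobs_done above describe a hand-off between solver threads and the main thread. A minimal sketch of that calling sequence, assuming the Scheduler defined in this file (the per-block processing step is a placeholder):

    // Sketch only: the "process block" step stands in for the per-block SG work.
    void solver_thread(Scheduler &sched)
    {
        while(!sched.is_terminated())
        {
            mf_int bid = sched.get_job();     // acquire a free (p, q) block pair
            mf_double loss = 0, error = 0;
            // ... process block bid, accumulating loss and error ...
            sched.put_job(bid, loss, error);  // release it; may pause here until resume()
        }
    }

    void main_thread(Scheduler &sched, mf_int nr_iters)
    {
        for(mf_int iter = 0; iter < nr_iters; ++iter)
        {
            sched.wait_for_jobs_done();       // every block processed once, threads paused
            // ... print the training status ...
            sched.resume();                   // raise the target and wake the solvers
        }
        // The real driver's shutdown handshake (terminate() plus joining threads) is omitted.
    }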
329
- //--------------------------------------
330
- //------------Block of matrix-----------
331
- //--------------------------------------
332
-
333
- class BlockBase
334
- {
335
- public:
336
- virtual bool move_next() { return false; };
337
- virtual mf_node* get_current() { return nullptr; }
338
- virtual void reload() {};
339
- virtual void free() {};
340
- virtual mf_long get_nnz() { return 0; };
341
- virtual ~BlockBase() {};
342
- };
343
-
344
- class Block : public BlockBase
345
- {
346
- public:
347
- Block() : first(nullptr), last(nullptr), current(nullptr) {};
348
- Block(mf_node *first_, mf_node *last_)
349
- : first(first_), last(last_), current(nullptr) {};
350
- bool move_next() { return ++current != last; }
351
- mf_node* get_current() { return current; }
352
- void tie_to(mf_node *first_, mf_node *last_);
353
- void reload() { current = first-1; };
354
- mf_long get_nnz() { return last-first; };
355
-
356
- private:
357
- mf_node* first;
358
- mf_node* last;
359
- mf_node* current;
360
- };
361
-
362
- void Block::tie_to(mf_node *first_, mf_node *last_)
363
- {
364
- first = first_;
365
- last = last_;
366
- };
367
-
368
- class BlockOnDisk : public BlockBase
369
- {
370
- public:
371
- BlockOnDisk() : first(0), last(0), current(0),
372
- source_path(""), buffer(0) {};
373
- bool move_next() { return ++current < last-first; }
374
- mf_node* get_current() { return &buffer[static_cast<size_t>(current)]; }
375
- void tie_to(string source_path_, mf_long first_, mf_long last_);
376
- void reload();
377
- void free() { buffer.resize(0); };
378
- mf_long get_nnz() { return last-first; };
379
-
380
- private:
381
- mf_long first;
382
- mf_long last;
383
- mf_long current;
384
- string source_path;
385
- vector<mf_node> buffer;
386
- };
387
-
388
- void BlockOnDisk::tie_to(string source_path_, mf_long first_, mf_long last_)
389
- {
390
- source_path = source_path_;
391
- first = first_;
392
- last = last_;
393
- }
394
-
395
- void BlockOnDisk::reload()
396
- {
397
- ifstream source(source_path, ifstream::in|ifstream::binary);
398
- if(!source)
399
- throw runtime_error("cannot open "+source_path);
400
-
401
- buffer.resize(static_cast<size_t>(last-first));
402
- source.seekg(first*sizeof(mf_node));
403
- source.read((char*)buffer.data(), (last-first)*sizeof(mf_node));
404
- current = -1;
405
- }
406
-
407
- //--------------------------------------
408
- //-------------Miscellaneous------------
409
- //--------------------------------------
410
-
411
- struct sort_node_by_p
412
- {
413
- bool operator() (mf_node const &lhs, mf_node const &rhs)
414
- {
415
- return tie(lhs.u, lhs.v) < tie(rhs.u, rhs.v);
416
- }
417
- };
418
-
419
- struct sort_node_by_q
420
- {
421
- bool operator() (mf_node const &lhs, mf_node const &rhs)
422
- {
423
- return tie(lhs.v, lhs.u) < tie(rhs.v, rhs.u);
424
- }
425
- };
426
-
427
- struct deleter
428
- {
429
- void operator() (mf_problem *prob)
430
- {
431
- delete[] prob->R;
432
- delete prob;
433
- }
434
- };
435
-
436
-
437
- class Utility
438
- {
439
- public:
440
- Utility(mf_int f, mf_int n) : fun(f), nr_threads(n) {};
441
- void collect_info(mf_problem &prob, mf_float &avg, mf_float &std_dev);
442
- void collect_info_on_disk(string data_path, mf_problem &prob,
443
- mf_float &avg, mf_float &std_dev);
444
- void shuffle_problem(mf_problem &prob, vector<mf_int> &p_map,
445
- vector<mf_int> &q_map);
446
- vector<mf_node*> grid_problem(mf_problem &prob, mf_int nr_bins,
447
- vector<mf_int> &omega_p,
448
- vector<mf_int> &omega_q,
449
- vector<Block> &blocks);
450
- void grid_shuffle_scale_problem_on_disk(mf_int m, mf_int n, mf_int nr_bins,
451
- mf_float scale, string data_path,
452
- vector<mf_int> &p_map,
453
- vector<mf_int> &q_map,
454
- vector<mf_int> &omega_p,
455
- vector<mf_int> &omega_q,
456
- vector<BlockOnDisk> &blocks);
457
- void scale_problem(mf_problem &prob, mf_float scale);
458
- mf_double calc_reg1(mf_model &model, mf_float lambda_p, mf_float lambda_q,
459
- vector<mf_int> &omega_p, vector<mf_int> &omega_q);
460
- mf_double calc_reg2(mf_model &model, mf_float lambda_p, mf_float lambda_q,
461
- vector<mf_int> &omega_p, vector<mf_int> &omega_q);
462
- string get_error_legend() const;
463
- mf_double calc_error(vector<BlockBase*> &blocks,
464
- vector<mf_int> &cv_block_ids,
465
- mf_model const &model);
466
- void scale_model(mf_model &model, mf_float scale);
467
-
468
- static mf_problem* copy_problem(mf_problem const *prob, bool copy_data);
469
- static vector<mf_int> gen_random_map(mf_int size);
470
- // A function used to allocate an aligned float array.
471
- // It hides platform-specific function calls. Memory
472
- // allocated by malloc_aligned_float must be freed by using
473
- // free_aligned_float.
474
- static mf_float* malloc_aligned_float(mf_long size);
475
- // A function used to free an aligned float array.
476
- // It hides platform-specific function calls.
477
- static void free_aligned_float(mf_float* ptr);
478
- // Initialization function for stochastic gradient method.
479
- // Factor matrices P and Q are both randomly initialized.
480
- static mf_model* init_model(mf_int loss, mf_int m, mf_int n,
481
- mf_int k, mf_float avg,
482
- vector<mf_int> &omega_p,
483
- vector<mf_int> &omega_q);
484
- // Initialization function for one-class CD.
485
- // It does zero-initialization on factor matrix P and random initialization
486
- // on factor matrix Q.
487
- static mf_model* init_model(mf_int m, mf_int n, mf_int k);
488
- static mf_float inner_product(mf_float *p, mf_float *q, mf_int k);
489
- static vector<mf_int> gen_inv_map(vector<mf_int> &map);
490
- static void shrink_model(mf_model &model, mf_int k_new);
491
- static void shuffle_model(mf_model &model,
492
- vector<mf_int> &p_map,
493
- vector<mf_int> &q_map);
494
- mf_int get_thread_number() const { return nr_threads; };
495
- private:
496
- mf_int fun;
497
- mf_int nr_threads;
498
- };
499
-
500
- void Utility::collect_info(
501
- mf_problem &prob,
502
- mf_float &avg,
503
- mf_float &std_dev)
504
- {
505
- mf_double ex = 0;
506
- mf_double ex2 = 0;
507
-
508
- #if defined USEOMP
509
- #pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:ex,ex2)
510
- #endif
511
- for(mf_long i = 0; i < prob.nnz; ++i)
512
- {
513
- mf_node &N = prob.R[i];
514
- ex += (mf_double)N.r;
515
- ex2 += (mf_double)N.r*N.r;
516
- }
517
-
518
- ex /= (mf_double)prob.nnz;
519
- ex2 /= (mf_double)prob.nnz;
520
- avg = (mf_float)ex;
521
- std_dev = (mf_float)sqrt(ex2-ex*ex);
522
- }
523
-
524
- void Utility::collect_info_on_disk(
525
- string data_path,
526
- mf_problem &prob,
527
- mf_float &avg,
528
- mf_float &std_dev)
529
- {
530
- mf_double ex = 0;
531
- mf_double ex2 = 0;
532
-
533
- ifstream source(data_path);
534
- if(!source.is_open())
535
- throw runtime_error("cannot open " + data_path);
536
-
537
- for(mf_node N; source >> N.u >> N.v >> N.r;)
538
- {
539
- if(N.u+1 > prob.m)
540
- prob.m = N.u+1;
541
- if(N.v+1 > prob.n)
542
- prob.n = N.v+1;
543
- prob.nnz += 1;
544
- ex += (mf_double)N.r;
545
- ex2 += (mf_double)N.r*N.r;
546
- }
547
- source.close();
548
-
549
- ex /= (mf_double)prob.nnz;
550
- ex2 /= (mf_double)prob.nnz;
551
- avg = (mf_float)ex;
552
- std_dev = (mf_float)sqrt(ex2-ex*ex);
553
- }
554
-
555
- void Utility::scale_problem(mf_problem &prob, mf_float scale)
556
- {
557
- if(scale == 1.0)
558
- return;
559
-
560
- #if defined USEOMP
561
- #pragma omp parallel for num_threads(nr_threads) schedule(static)
562
- #endif
563
- for(mf_long i = 0; i < prob.nnz; ++i)
564
- prob.R[i].r *= scale;
565
- }
566
-
567
- void Utility::scale_model(mf_model &model, mf_float scale)
568
- {
569
- if(scale == 1.0)
570
- return;
571
-
572
- mf_int k = model.k;
573
-
574
- model.b *= scale;
575
-
576
- auto scale1 = [&] (mf_float *ptr, mf_int size, mf_float factor_scale)
577
- {
578
- #if defined USEOMP
579
- #pragma omp parallel for num_threads(nr_threads) schedule(static)
580
- #endif
581
- for(mf_int i = 0; i < size; ++i)
582
- {
583
- mf_float *ptr1 = ptr+(mf_long)i*model.k;
584
- for(mf_int d = 0; d < k; ++d)
585
- ptr1[d] *= factor_scale;
586
- }
587
- };
588
-
589
- scale1(model.P, model.m, sqrt(scale));
590
- scale1(model.Q, model.n, sqrt(scale));
591
- }
592
-
593
- mf_float Utility::inner_product(mf_float *p, mf_float *q, mf_int k)
594
- {
595
- #if defined USESSE
596
- __m128 XMM = _mm_setzero_ps();
597
- for(mf_int d = 0; d < k; d += 4)
598
- XMM = _mm_add_ps(XMM, _mm_mul_ps(
599
- _mm_load_ps(p+d), _mm_load_ps(q+d)));
600
- __m128 XMMtmp = _mm_add_ps(XMM, _mm_movehl_ps(XMM, XMM));
601
- XMM = _mm_add_ps(XMMtmp, _mm_shuffle_ps(XMMtmp, XMMtmp, 1));
602
- mf_float product;
603
- _mm_store_ss(&product, XMM);
604
- return product;
605
- #elif defined USEAVX
606
- __m256 XMM = _mm256_setzero_ps();
607
- for(mf_int d = 0; d < k; d += 8)
608
- XMM = _mm256_add_ps(XMM, _mm256_mul_ps(
609
- _mm256_load_ps(p+d), _mm256_load_ps(q+d)));
610
- XMM = _mm256_add_ps(XMM, _mm256_permute2f128_ps(XMM, XMM, 1));
611
- XMM = _mm256_hadd_ps(XMM, XMM);
612
- XMM = _mm256_hadd_ps(XMM, XMM);
613
- mf_float product;
614
- _mm_store_ss(&product, _mm256_castps256_ps128(XMM));
615
- return product;
616
- #else
617
- return std::inner_product(p, p+k, q, (mf_float)0.0);
618
- #endif
619
- }
620
-
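Note that all three branches of inner_product assume the factor vectors are kALIGNByte-aligned (the SSE/AVX paths use aligned loads) and that k is a multiple of kALIGN. Both properties are arranged elsewhere in this file: malloc_aligned_float provides the alignment and init_model rounds k up. A small self-contained sketch of that rounding, with an illustrative k:

    #include <cmath>
    int main()
    {
        const int kALIGNByte = 32;
        const int kALIGN = kALIGNByte / sizeof(float);   // 8 floats per 32-byte lane
        int k = 13;                                      // illustrative requested dimension
        int k_aligned = (int)std::ceil((double)k / kALIGN) * kALIGN;   // -> 16
        // The padded dimensions are zero-initialized, so they add nothing to the
        // dot products computed four (SSE) or eight (AVX) floats at a time above.
        return k_aligned == 16 ? 0 : 1;
    }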
621
- mf_double Utility::calc_reg1(mf_model &model,
622
- mf_float lambda_p, mf_float lambda_q,
623
- vector<mf_int> &omega_p, vector<mf_int> &omega_q)
624
- {
625
- auto calc_reg1_core = [&] (mf_float *ptr, mf_int size,
626
- vector<mf_int> &omega)
627
- {
628
- mf_double reg = 0;
629
- for(mf_int i = 0; i < size; ++i)
630
- {
631
- if(omega[i] <= 0)
632
- continue;
633
-
634
- mf_float tmp = 0;
635
- for(mf_int j = 0; j < model.k; ++j)
636
- tmp += abs(ptr[(mf_long)i*model.k+j]);
637
- reg += omega[i]*tmp;
638
- }
639
- return reg;
640
- };
641
-
642
- return lambda_p*calc_reg1_core(model.P, model.m, omega_p)+
643
- lambda_q*calc_reg1_core(model.Q, model.n, omega_q);
644
- }
645
-
646
- mf_double Utility::calc_reg2(mf_model &model,
647
- mf_float lambda_p, mf_float lambda_q,
648
- vector<mf_int> &omega_p, vector<mf_int> &omega_q)
649
- {
650
- auto calc_reg2_core = [&] (mf_float *ptr, mf_int size,
651
- vector<mf_int> &omega)
652
- {
653
- mf_double reg = 0;
654
- #if defined USEOMP
655
- #pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:reg)
656
- #endif
657
- for(mf_int i = 0; i < size; ++i)
658
- {
659
- if(omega[i] <= 0)
660
- continue;
661
-
662
- mf_float *ptr1 = ptr+(mf_long)i*model.k;
663
- reg += omega[i]*Utility::inner_product(ptr1, ptr1, model.k);
664
- }
665
-
666
- return reg;
667
- };
668
-
669
- return lambda_p*calc_reg2_core(model.P, model.m, omega_p) +
670
- lambda_q*calc_reg2_core(model.Q, model.n, omega_q);
671
- }
672
-
673
- mf_double Utility::calc_error(
674
- vector<BlockBase*> &blocks,
675
- vector<mf_int> &cv_block_ids,
676
- mf_model const &model)
677
- {
678
- mf_double error = 0;
679
- if(fun == P_L2_MFR || fun == P_L1_MFR || fun == P_KL_MFR ||
680
- fun == P_LR_MFC || fun == P_L2_MFC || fun == P_L1_MFC)
681
- {
682
- #if defined USEOMP
683
- #pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:error)
684
- #endif
685
- for(mf_int i = 0; i < (mf_long)cv_block_ids.size(); ++i)
686
- {
687
- BlockBase *block = blocks[cv_block_ids[i]];
688
- block->reload();
689
- while(block->move_next())
690
- {
691
- mf_node const &N = *(block->get_current());
692
- mf_float z = mf_predict(&model, N.u, N.v);
693
- switch(fun)
694
- {
695
- case P_L2_MFR:
696
- error += pow(N.r-z, 2);
697
- break;
698
- case P_L1_MFR:
699
- error += abs(N.r-z);
700
- break;
701
- case P_KL_MFR:
702
- error += N.r*log(N.r/z)-N.r+z;
703
- break;
704
- case P_LR_MFC:
705
- if(N.r > 0)
706
- error += log(1.0+exp(-z));
707
- else
708
- error += log(1.0+exp(z));
709
- break;
710
- case P_L2_MFC:
711
- case P_L1_MFC:
712
- if(N.r > 0)
713
- error += z > 0? 1: 0;
714
- else
715
- error += z < 0? 1: 0;
716
- break;
717
- default:
718
- throw invalid_argument("unknown error function");
719
- break;
720
- }
721
- }
722
- block->free();
723
- }
724
- }
725
- else
726
- {
727
- minstd_rand0 generator(rand());
728
- switch(fun)
729
- {
730
- case P_ROW_BPR_MFOC:
731
- {
732
- uniform_int_distribution<mf_int> distribution(0, model.n-1);
733
- #if defined USEOMP
734
- #pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:error)
735
- #endif
736
- for(mf_int i = 0; i < (mf_long)cv_block_ids.size(); ++i)
737
- {
738
- BlockBase *block = blocks[cv_block_ids[i]];
739
- block->reload();
740
- while(block->move_next())
741
- {
742
- mf_node const &N = *(block->get_current());
743
- mf_int w = distribution(generator);
744
- error += log(1+exp(mf_predict(&model, N.u, w)-
745
- mf_predict(&model, N.u, N.v)));
746
- }
747
- block->free();
748
- }
749
- break;
750
- }
751
- case P_COL_BPR_MFOC:
752
- {
753
- uniform_int_distribution<mf_int> distribution(0, model.m-1);
754
- #if defined USEOMP
755
- #pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:error)
756
- #endif
757
- for(mf_int i = 0; i < (mf_long)cv_block_ids.size(); ++i)
758
- {
759
- BlockBase *block = blocks[cv_block_ids[i]];
760
- block->reload();
761
- while(block->move_next())
762
- {
763
- mf_node const &N = *(block->get_current());
764
- mf_int w = distribution(generator);
765
- error += log(1+exp(mf_predict(&model, w, N.v)-
766
- mf_predict(&model, N.u, N.v)));
767
- }
768
- block->free();
769
- }
770
- break;
771
- }
772
- default:
773
- {
774
- throw invalid_argument("unknown error function");
775
- break;
776
- }
777
- }
778
- }
779
-
780
- return error;
781
- }
782
-
783
- string Utility::get_error_legend() const
784
- {
785
- switch(fun)
786
- {
787
- case P_L2_MFR:
788
- return string("rmse");
789
- break;
790
- case P_L1_MFR:
791
- return string("mae");
792
- break;
793
- case P_KL_MFR:
794
- return string("gkl");
795
- break;
796
- case P_LR_MFC:
797
- return string("logloss");
798
- break;
799
- case P_L2_MFC:
800
- case P_L1_MFC:
801
- return string("accuracy");
802
- break;
803
- case P_ROW_BPR_MFOC:
804
- case P_COL_BPR_MFOC:
805
- return string("bprloss");
806
- break;
807
- case P_L2_MFOC:
808
- return string("sqerror");
809
- default:
810
- return string();
811
- break;
812
- }
813
- }
814
-
815
- void Utility::shuffle_problem(
816
- mf_problem &prob,
817
- vector<mf_int> &p_map,
818
- vector<mf_int> &q_map)
819
- {
820
- #if defined USEOMP
821
- #pragma omp parallel for num_threads(nr_threads) schedule(static)
822
- #endif
823
- for(mf_long i = 0; i < prob.nnz; ++i)
824
- {
825
- mf_node &N = prob.R[i];
826
- if(N.u < (mf_long)p_map.size())
827
- N.u = p_map[N.u];
828
- if(N.v < (mf_long)q_map.size())
829
- N.v = q_map[N.v];
830
- }
831
- }
832
-
833
- vector<mf_node*> Utility::grid_problem(
834
- mf_problem &prob,
835
- mf_int nr_bins,
836
- vector<mf_int> &omega_p,
837
- vector<mf_int> &omega_q,
838
- vector<Block> &blocks)
839
- {
840
- vector<mf_long> counts(nr_bins*nr_bins, 0);
841
-
842
- mf_int seg_p = (mf_int)ceil((double)prob.m/nr_bins);
843
- mf_int seg_q = (mf_int)ceil((double)prob.n/nr_bins);
844
-
845
- auto get_block_id = [=] (mf_int u, mf_int v)
846
- {
847
- return (u/seg_p)*nr_bins+v/seg_q;
848
- };
849
-
850
- for(mf_long i = 0; i < prob.nnz; ++i)
851
- {
852
- mf_node &N = prob.R[i];
853
- mf_int block = get_block_id(N.u, N.v);
854
- counts[block] += 1;
855
- omega_p[N.u] += 1;
856
- omega_q[N.v] += 1;
857
- }
858
-
859
- vector<mf_node*> ptrs(nr_bins*nr_bins+1);
860
- mf_node *ptr = prob.R;
861
- ptrs[0] = ptr;
862
- for(mf_int block = 0; block < nr_bins*nr_bins; ++block)
863
- ptrs[block+1] = ptrs[block] + counts[block];
864
-
865
- vector<mf_node*> pivots(ptrs.begin(), ptrs.end()-1);
866
- for(mf_int block = 0; block < nr_bins*nr_bins; ++block)
867
- {
868
- for(mf_node* pivot = pivots[block]; pivot != ptrs[block+1];)
869
- {
870
- mf_int curr_block = get_block_id(pivot->u, pivot->v);
871
- if(curr_block == block)
872
- {
873
- ++pivot;
874
- continue;
875
- }
876
-
877
- mf_node *next = pivots[curr_block];
878
- swap(*pivot, *next);
879
- pivots[curr_block] += 1;
880
- }
881
- }
882
-
883
- #if defined USEOMP
884
- #pragma omp parallel for num_threads(nr_threads) schedule(dynamic)
885
- #endif
886
- for(mf_int block = 0; block < nr_bins*nr_bins; ++block)
887
- {
888
- if(prob.m > prob.n)
889
- sort(ptrs[block], ptrs[block+1], sort_node_by_p());
890
- else
891
- sort(ptrs[block], ptrs[block+1], sort_node_by_q());
892
- }
893
-
894
- for(mf_int i = 0; i < (mf_long)blocks.size(); ++i)
895
- blocks[i].tie_to(ptrs[i], ptrs[i+1]);
896
-
897
- return ptrs;
898
- }
899
-
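As a concrete illustration of the blocking rule used by grid_problem above (the sizes below are made up for the example):

    #include <cassert>
    int main()
    {
        int m = 10, n = 6, nr_bins = 2;
        int seg_p = (m + nr_bins - 1) / nr_bins;         // ceil(10/2) = 5 rows per band
        int seg_q = (n + nr_bins - 1) / nr_bins;         // ceil(6/2)  = 3 columns per band
        int u = 7, v = 2;                                // one rating at row 7, column 2
        int block = (u / seg_p) * nr_bins + v / seg_q;   // (7/5)*2 + 2/3 = 2
        assert(block == 2);                              // row band 1, column band 0
        return 0;
    }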
900
- void Utility::grid_shuffle_scale_problem_on_disk(
901
- mf_int m, mf_int n, mf_int nr_bins,
902
- mf_float scale, string data_path,
903
- vector<mf_int> &p_map, vector<mf_int> &q_map,
904
- vector<mf_int> &omega_p, vector<mf_int> &omega_q,
905
- vector<BlockOnDisk> &blocks)
906
- {
907
- string const buffer_path = data_path+string(".disk");
908
- mf_int seg_p = (mf_int)ceil((double)m/nr_bins);
909
- mf_int seg_q = (mf_int)ceil((double)n/nr_bins);
910
- vector<mf_long> counts(nr_bins*nr_bins+1, 0);
911
- vector<mf_long> pivots(nr_bins*nr_bins, 0);
912
- ifstream source(data_path);
913
- fstream buffer(buffer_path, fstream::in|fstream::out|
914
- fstream::binary|fstream::trunc);
915
- auto get_block_id = [=] (mf_int u, mf_int v)
916
- {
917
- return (u/seg_p)*nr_bins+v/seg_q;
918
- };
919
-
920
- if(!source)
921
- throw ios::failure(string("cannot open ")+data_path);
922
- if(!buffer)
923
- throw ios::failure(string("cannot open ")+buffer_path);
924
-
925
- for(mf_node N; source >> N.u >> N.v >> N.r;)
926
- {
927
- N.u = p_map[N.u];
928
- N.v = q_map[N.v];
929
- mf_int bid = get_block_id(N.u, N.v);
930
- omega_p[N.u] += 1;
931
- omega_q[N.v] += 1;
932
- counts[bid+1] += 1;
933
- }
934
-
935
- for(mf_int i = 1; i < nr_bins*nr_bins+1; ++i)
936
- {
937
- counts[i] += counts[i-1];
938
- pivots[i-1] = counts[i-1];
939
- }
940
-
941
- source.clear();
942
- source.seekg(0);
943
- for(mf_node N; source >> N.u >> N.v >> N.r;)
944
- {
945
- N.u = p_map[N.u];
946
- N.v = q_map[N.v];
947
- N.r /= scale;
948
- mf_int bid = get_block_id(N.u, N.v);
949
- buffer.seekp(pivots[bid]*sizeof(mf_node));
950
- buffer.write((char*)&N, sizeof(mf_node));
951
- pivots[bid] += 1;
952
- }
953
-
954
- for(mf_int i = 0; i < nr_bins*nr_bins; ++i)
955
- {
956
- vector<mf_node> nodes(static_cast<size_t>(counts[i+1]-counts[i]));
957
- buffer.clear();
958
- buffer.seekg(counts[i]*sizeof(mf_node));
959
- buffer.read((char*)nodes.data(), sizeof(mf_node)*nodes.size());
960
-
961
- if(m > n)
962
- sort(nodes.begin(), nodes.end(), sort_node_by_p());
963
- else
964
- sort(nodes.begin(), nodes.end(), sort_node_by_q());
965
-
966
- buffer.clear();
967
- buffer.seekp(counts[i]*sizeof(mf_node));
968
- buffer.write((char*)nodes.data(), sizeof(mf_node)*nodes.size());
969
- buffer.read((char*)nodes.data(), sizeof(mf_node)*nodes.size());
970
- }
971
-
972
- for(mf_int i = 0; i < (mf_long)blocks.size(); ++i)
973
- blocks[i].tie_to(buffer_path, counts[i], counts[i+1]);
974
- }
975
-
976
- mf_float* Utility::malloc_aligned_float(mf_long size)
977
- {
978
- // Check if conversion from mf_long to size_t causes overflow.
979
- if (size > numeric_limits<std::size_t>::max() / sizeof(mf_float) + 1)
980
- throw bad_alloc();
981
- // [REVIEW] I hope one day we can use C11 aligned_alloc to replace
982
- // platform-dependent functions below. Neither Windows nor OSX currently
983
- // supports that function.
984
- void *ptr = nullptr;
985
- #ifdef _WIN32
986
- ptr = _aligned_malloc(static_cast<size_t>(size*sizeof(mf_float)),
987
- kALIGNByte);
988
- #else
989
- int status = posix_memalign(&ptr, kALIGNByte, size*sizeof(mf_float));
990
- if(status != 0)
991
- throw bad_alloc();
992
- #endif
993
- if(ptr == nullptr)
994
- throw bad_alloc();
995
-
996
- return (mf_float*)ptr;
997
- }
998
-
999
- void Utility::free_aligned_float(mf_float *ptr)
1000
- {
1001
- #ifdef _WIN32
1002
- // Unfortunately, Visual Studio doesn't want to support the
1003
- // cross-platform allocation below.
1004
- _aligned_free(ptr);
1005
- #else
1006
- free(ptr);
1007
- #endif
1008
- }
1009
-
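A minimal usage sketch of the allocate/free pair above; it only restates the contract from the comments (memory from malloc_aligned_float must be released with free_aligned_float, never plain free or delete):

    // Sketch: 16 floats, 32-byte aligned; the size here is illustrative.
    mf_float *buf = Utility::malloc_aligned_float((mf_long)16);
    // ... use buf like an ordinary float array ...
    Utility::free_aligned_float(buf);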
1010
- mf_model* Utility::init_model(mf_int fun,
1011
- mf_int m, mf_int n,
1012
- mf_int k, mf_float avg,
1013
- vector<mf_int> &omega_p,
1014
- vector<mf_int> &omega_q)
1015
- {
1016
- mf_int k_real = k;
1017
- mf_int k_aligned = (mf_int)ceil(mf_double(k)/kALIGN)*kALIGN;
1018
-
1019
- mf_model *model = new mf_model;
1020
-
1021
- model->fun = fun;
1022
- model->m = m;
1023
- model->n = n;
1024
- model->k = k_aligned;
1025
- model->b = avg;
1026
- model->P = nullptr;
1027
- model->Q = nullptr;
1028
-
1029
- mf_float scale = (mf_float)sqrt(1.0/k_real);
1030
- default_random_engine generator;
1031
- uniform_real_distribution<mf_float> distribution(0.0, 1.0);
1032
-
1033
- try
1034
- {
1035
- model->P = Utility::malloc_aligned_float((mf_long)model->m*model->k);
1036
- model->Q = Utility::malloc_aligned_float((mf_long)model->n*model->k);
1037
- }
1038
- catch(bad_alloc const &e)
1039
- {
1040
- cerr << e.what() << endl;
1041
- mf_destroy_model(&model);
1042
- throw;
1043
- }
1044
-
1045
- auto init1 = [&](mf_float *start_ptr, mf_long size, vector<mf_int> counts)
1046
- {
1047
- memset(start_ptr, 0, static_cast<size_t>(
1048
- sizeof(mf_float) * size*model->k));
1049
- for(mf_long i = 0; i < size; ++i)
1050
- {
1051
- mf_float * ptr = start_ptr + i*model->k;
1052
- if(counts[static_cast<size_t>(i)] > 0)
1053
- for(mf_long d = 0; d < k_real; ++d, ++ptr)
1054
- *ptr = (mf_float)(distribution(generator)*scale);
1055
- else
1056
- if(fun != P_ROW_BPR_MFOC && fun != P_COL_BPR_MFOC) // unseen for bpr is 0
1057
- for(mf_long d = 0; d < k_real; ++d, ++ptr)
1058
- *ptr = numeric_limits<mf_float>::quiet_NaN();
1059
- }
1060
- };
1061
-
1062
- init1(model->P, m, omega_p);
1063
- init1(model->Q, n, omega_q);
1064
-
1065
- return model;
1066
- }
1067
-
1068
- // Initialize P=[\bar{p}_1, ..., \bar{p}_d] and Q=[\bar{q}_1, ..., \bar{q}_d].
1069
- // Note that \bar{q}_{kv} is Q[k*n+v] and \bar{p}_{ku} is P[k*m+u]. One may
1070
- // notice that P and Q here are actually the transposes of P and Q in fpsg(...)
1071
- // because fpsg(...) uses P^TQ (where P and Q are respectively k-by-m and
1072
- // k-by-n) to approximate the given rating matrix R while ccd_one_class(...)
1073
- // uses PQ^T (where P and Q are respectively m-by-k and n-by-k).
1074
- mf_model* Utility::init_model(mf_int m, mf_int n, mf_int k)
1075
- {
1076
- mf_model *model = new mf_model;
1077
-
1078
- model->fun = P_L2_MFOC;
1079
- model->m = m;
1080
- model->n = n;
1081
- model->k = k;
1082
- model->b = 0.0; // One-class matrix factorization doesn't have bias.
1083
- model->P = nullptr;
1084
- model->Q = nullptr;
1085
-
1086
- try
1087
- {
1088
- model->P = Utility::malloc_aligned_float((mf_long)model->m*model->k);
1089
- model->Q = Utility::malloc_aligned_float((mf_long)model->n*model->k);
1090
- }
1091
- catch(bad_alloc const &e)
1092
- {
1093
- cerr << e.what() << endl;
1094
- mf_destroy_model(&model);
1095
- throw;
1096
- }
1097
-
1098
- // Our initialization strategy is to set all of P's elements to zero and do
1100
- // random initialization on Q. Thus, all initial predicted ratings are zero
1100
- // since the approximated rating matrix is PQ^T.
1101
-
1102
- // Initialize P with zeros
1103
- for(mf_long i = 0; i < k * m; ++i)
1104
- model->P[i] = 0.0;
1105
-
1106
- // Initialize Q with random numbers
1107
- default_random_engine generator;
1108
- uniform_real_distribution<mf_float> distribution(0.0, 1.0);
1109
- for(mf_long i = 0; i < k * n; ++i)
1110
- model->Q[i] = distribution(generator);
1111
-
1112
- return model;
1113
- }
1114
-
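To make the layout described in the comment before this function concrete (sizes are illustrative only):

    // With m = 3 rows, n = 4 columns and k = 2 factors,
    //   \bar{p}_{ku} lives at P[k*m + u]: factor 1 of row 2 is P[1*3 + 2] = P[5];
    //   \bar{q}_{kv} lives at Q[k*n + v]: factor 0 of column 3 is Q[0*4 + 3] = Q[3].
    // This is the transpose of the P[u*model.k + d] layout used by the SG-based
    // init_model above, which is why the comment calls these matrices transposes.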
1115
- vector<mf_int> Utility::gen_random_map(mf_int size)
1116
- {
1117
- srand(0);
1118
- vector<mf_int> map(size, 0);
1119
- for(mf_int i = 0; i < size; ++i)
1120
- map[i] = i;
1121
- random_shuffle(map.begin(), map.end());
1122
- return map;
1123
- }
1124
-
1125
- vector<mf_int> Utility::gen_inv_map(vector<mf_int> &map)
1126
- {
1127
- vector<mf_int> inv_map(map.size());
1128
- for(mf_int i = 0; i < (mf_long)map.size(); ++i)
1129
- inv_map[map[i]] = i;
1130
- return inv_map;
1131
- }
1132
-
1133
- void Utility::shuffle_model(
1134
- mf_model &model,
1135
- vector<mf_int> &p_map,
1136
- vector<mf_int> &q_map)
1137
- {
1138
- auto inv_shuffle1 = [] (mf_float *vec, vector<mf_int> &map,
1139
- mf_int size, mf_int k)
1140
- {
1141
- for(mf_int pivot = 0; pivot < size;)
1142
- {
1143
- if(pivot == map[pivot])
1144
- {
1145
- ++pivot;
1146
- continue;
1147
- }
1148
-
1149
- mf_int next = map[pivot];
1150
-
1151
- for(mf_int d = 0; d < k; ++d)
1152
- swap(*(vec+(mf_long)pivot*k+d), *(vec+(mf_long)next*k+d));
1153
-
1154
- map[pivot] = map[next];
1155
- map[next] = next;
1156
- }
1157
- };
1158
-
1159
- inv_shuffle1(model.P, p_map, model.m, model.k);
1160
- inv_shuffle1(model.Q, q_map, model.n, model.k);
1161
- }
1162
-
1163
- void Utility::shrink_model(mf_model &model, mf_int k_new)
1164
- {
1165
- mf_int k_old = model.k;
1166
- model.k = k_new;
1167
-
1168
- auto shrink1 = [&] (mf_float *ptr, mf_int size)
1169
- {
1170
- for(mf_int i = 0; i < size; ++i)
1171
- {
1172
- mf_float *src = ptr+(mf_long)i*k_old;
1173
- mf_float *dst = ptr+(mf_long)i*k_new;
1174
- copy(src, src+k_new, dst);
1175
- }
1176
- };
1177
-
1178
- shrink1(model.P, model.m);
1179
- shrink1(model.Q, model.n);
1180
- }
1181
-
1182
- mf_problem* Utility::copy_problem(mf_problem const *prob, bool copy_data)
1183
- {
1184
- mf_problem *new_prob = new mf_problem;
1185
-
1186
- if(prob == nullptr)
1187
- {
1188
- new_prob->m = 0;
1189
- new_prob->n = 0;
1190
- new_prob->nnz = 0;
1191
- new_prob->R = nullptr;
1192
-
1193
- return new_prob;
1194
- }
1195
-
1196
- new_prob->m = prob->m;
1197
- new_prob->n = prob->n;
1198
- new_prob->nnz = prob->nnz;
1199
-
1200
- if(copy_data)
1201
- {
1202
- try
1203
- {
1204
- new_prob->R = new mf_node[static_cast<size_t>(prob->nnz)];
1205
- copy(prob->R, prob->R+prob->nnz, new_prob->R);
1206
- }
1207
- catch(...)
1208
- {
1209
- delete new_prob;
1210
- throw;
1211
- }
1212
- }
1213
- else
1214
- {
1215
- new_prob->R = prob->R;
1216
- }
1217
-
1218
- return new_prob;
1219
- }
1220
-
1221
- //--------------------------------------
1222
- //-----The base class of all solvers----
1223
- //--------------------------------------
1224
-
1225
- class SolverBase
1226
- {
1227
- public:
1228
- SolverBase(Scheduler &scheduler, vector<BlockBase*> &blocks,
1229
- mf_float *PG, mf_float *QG, mf_model &model, mf_parameter param,
1230
- bool &slow_only)
1231
- : scheduler(scheduler), blocks(blocks), PG(PG), QG(QG),
1232
- model(model), param(param), slow_only(slow_only) {}
1233
- void run();
1234
- SolverBase(const SolverBase&) = delete;
1235
- SolverBase& operator=(const SolverBase&) = delete;
1236
- // Solver is a stateless functor, so the default destructor should be
1237
- // good enough.
1238
- virtual ~SolverBase() = default;
1239
-
1240
- protected:
1241
- #if defined USESSE
1242
- static void calc_z(__m128 &XMMz, mf_int k, mf_float *p, mf_float *q);
1243
- virtual void load_fixed_variables(
1244
- __m128 &XMMlambda_p1, __m128 &XMMlambda_q1,
1245
- __m128 &XMMlambda_p2, __m128 &XMMlambda_q2,
1246
- __m128 &XMMeta, __m128 &XMMrk_slow,
1247
- __m128 &XMMrk_fast);
1248
- virtual void arrange_block(__m128d &XMMloss, __m128d &XMMerror);
1249
- virtual void prepare_for_sg_update(
1250
- __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror) = 0;
1251
- virtual void sg_update(mf_int d_begin, mf_int d_end, __m128 XMMz,
1252
- __m128 XMMlambda_p1, __m128 XMMlambda_q1,
1253
- __m128 XMMlambda_p2, __m128 XMMlambda_q2,
1254
- __m128 XMMeta, __m128 XMMrk) = 0;
1255
- virtual void finalize(__m128d XMMloss, __m128d XMMerror);
1256
- #elif defined USEAVX
1257
- static void calc_z(__m256 &XMMz, mf_int k, mf_float *p, mf_float *q);
1258
- virtual void load_fixed_variables(
1259
- __m256 &XMMlambda_p1, __m256 &XMMlambda_q1,
1260
- __m256 &XMMlambda_p2, __m256 &XMMlambda_q2,
1261
- __m256 &XMMeta, __m256 &XMMrk_slow,
1262
- __m256 &XMMrk_fast);
1263
- virtual void arrange_block(__m128d &XMMloss, __m128d &XMMerror);
1264
- virtual void prepare_for_sg_update(
1265
- __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror) = 0;
1266
- virtual void sg_update(mf_int d_begin, mf_int d_end, __m256 XMMz,
1267
- __m256 XMMlambda_p1, __m256 XMMlambda_q1,
1268
- __m256 XMMlambda_p2, __m256 XMMlambda_q2,
1269
- __m256 XMMeta, __m256 XMMrk) = 0;
1270
- virtual void finalize(__m128d XMMloss, __m128d XMMerror);
1271
- #else
1272
- static void calc_z(mf_float &z, mf_int k, mf_float *p, mf_float *q);
1273
- virtual void load_fixed_variables();
1274
- virtual void arrange_block();
1275
- virtual void prepare_for_sg_update() = 0;
1276
- virtual void sg_update(mf_int d_begin, mf_int d_end, mf_float rk) = 0;
1277
- virtual void finalize();
1278
- static float qrsqrt(float x);
1279
- #endif
1280
- virtual void update() { ++pG; ++qG; };
1281
-
1282
- Scheduler &scheduler;
1283
- vector<BlockBase*> &blocks;
1284
- BlockBase *block;
1285
- mf_float *PG;
1286
- mf_float *QG;
1287
- mf_model &model;
1288
- mf_parameter param;
1289
- bool &slow_only;
1290
-
1291
- mf_node *N;
1292
- mf_float z;
1293
- mf_double loss;
1294
- mf_double error;
1295
- mf_float *p;
1296
- mf_float *q;
1297
- mf_float *pG;
1298
- mf_float *qG;
1299
- mf_int bid;
1300
-
1301
- mf_float lambda_p1;
1302
- mf_float lambda_q1;
1303
- mf_float lambda_p2;
1304
- mf_float lambda_q2;
1305
- mf_float rk_slow;
1306
- mf_float rk_fast;
1307
- };
1308
-
1309
- #if defined USESSE
1310
- inline void SolverBase::run()
1311
- {
1312
- __m128d XMMloss;
1313
- __m128d XMMerror;
1314
- __m128 XMMz;
1315
- __m128 XMMlambda_p1;
1316
- __m128 XMMlambda_q1;
1317
- __m128 XMMlambda_p2;
1318
- __m128 XMMlambda_q2;
1319
- __m128 XMMeta;
1320
- __m128 XMMrk_slow;
1321
- __m128 XMMrk_fast;
1322
- load_fixed_variables(XMMlambda_p1, XMMlambda_q1,
1323
- XMMlambda_p2, XMMlambda_q2,
1324
- XMMeta, XMMrk_slow,
1325
- XMMrk_fast);
1326
- while(!scheduler.is_terminated())
1327
- {
1328
- arrange_block(XMMloss, XMMerror);
1329
- while(block->move_next())
1330
- {
1331
- N = block->get_current();
1332
- p = model.P+(mf_long)N->u*model.k;
1333
- q = model.Q+(mf_long)N->v*model.k;
1334
- pG = PG+N->u*2;
1335
- qG = QG+N->v*2;
1336
- prepare_for_sg_update(XMMz, XMMloss, XMMerror);
1337
- sg_update(0, kALIGN, XMMz, XMMlambda_p1, XMMlambda_q1,
1338
- XMMlambda_p2, XMMlambda_q2, XMMeta, XMMrk_slow);
1339
- if(slow_only)
1340
- continue;
1341
- update();
1342
- sg_update(kALIGN, model.k, XMMz, XMMlambda_p1, XMMlambda_q1,
1343
- XMMlambda_p2, XMMlambda_q2, XMMeta, XMMrk_fast);
1344
- }
1345
- finalize(XMMloss, XMMerror);
1346
- }
1347
- }
1348
-
1349
- void SolverBase::load_fixed_variables(
1350
- __m128 &XMMlambda_p1, __m128 &XMMlambda_q1,
1351
- __m128 &XMMlambda_p2, __m128 &XMMlambda_q2,
1352
- __m128 &XMMeta, __m128 &XMMrk_slow,
1353
- __m128 &XMMrk_fast)
1354
- {
1355
- XMMlambda_p1 = _mm_set1_ps(param.lambda_p1);
1356
- XMMlambda_q1 = _mm_set1_ps(param.lambda_q1);
1357
- XMMlambda_p2 = _mm_set1_ps(param.lambda_p2);
1358
- XMMlambda_q2 = _mm_set1_ps(param.lambda_q2);
1359
- XMMeta = _mm_set1_ps(param.eta);
1360
- XMMrk_slow = _mm_set1_ps((mf_float)1.0/kALIGN);
1361
- XMMrk_fast = _mm_set1_ps((mf_float)1.0/(model.k-kALIGN));
1362
- }
1363
-
1364
- void SolverBase::arrange_block(__m128d &XMMloss, __m128d &XMMerror)
1365
- {
1366
- XMMloss = _mm_setzero_pd();
1367
- XMMerror = _mm_setzero_pd();
1368
- bid = scheduler.get_job();
1369
- block = blocks[bid];
1370
- block->reload();
1371
- }
1372
-
1373
- inline void SolverBase::calc_z(
1374
- __m128 &XMMz, mf_int k, mf_float *p, mf_float *q)
1375
- {
1376
- XMMz = _mm_setzero_ps();
1377
- for(mf_int d = 0; d < k; d += 4)
1378
- XMMz = _mm_add_ps(XMMz, _mm_mul_ps(
1379
- _mm_load_ps(p+d), _mm_load_ps(q+d)));
1380
- // Bit-wise representation of 177 is {1,0}+{1,1}+{0,0}+{0,1} from
1381
- // high-bit to low-bit, where "+" means concatenating two arrays.
1382
- __m128 XMMtmp = _mm_add_ps(XMMz, _mm_shuffle_ps(XMMz, XMMz, 177));
1383
- // Bit-wise representation of 78 is {0,1}+{0,0}+{1,1}+{1,0} from
1384
- // high-bit to low-bit, where "+" means concatenating two arrays.
1385
- XMMz = _mm_add_ps(XMMtmp, _mm_shuffle_ps(XMMtmp, XMMtmp, 78));
1386
- }
1387
-
1388
- void SolverBase::finalize(__m128d XMMloss, __m128d XMMerror)
1389
- {
1390
- _mm_store_sd(&loss, XMMloss);
1391
- _mm_store_sd(&error, XMMerror);
1392
- block->free();
1393
- scheduler.put_job(bid, loss, error);
1394
- }
1395
- #elif defined USEAVX
1396
- inline void SolverBase::run()
1397
- {
1398
- __m128d XMMloss;
1399
- __m128d XMMerror;
1400
- __m256 XMMz;
1401
- __m256 XMMlambda_p1;
1402
- __m256 XMMlambda_q1;
1403
- __m256 XMMlambda_p2;
1404
- __m256 XMMlambda_q2;
1405
- __m256 XMMeta;
1406
- __m256 XMMrk_slow;
1407
- __m256 XMMrk_fast;
1408
- load_fixed_variables(XMMlambda_p1, XMMlambda_q1,
1409
- XMMlambda_p2, XMMlambda_q2,
1410
- XMMeta, XMMrk_slow, XMMrk_fast);
1411
- while(!scheduler.is_terminated())
1412
- {
1413
- arrange_block(XMMloss, XMMerror);
1414
- while(block->move_next())
1415
- {
1416
- N = block->get_current();
1417
- p = model.P+(mf_long)N->u*model.k;
1418
- q = model.Q+(mf_long)N->v*model.k;
1419
- pG = PG+N->u*2;
1420
- qG = QG+N->v*2;
1421
- prepare_for_sg_update(XMMz, XMMloss, XMMerror);
1422
- sg_update(0, kALIGN, XMMz, XMMlambda_p1, XMMlambda_q1,
1423
- XMMlambda_p2, XMMlambda_q2, XMMeta, XMMrk_slow);
1424
- if(slow_only)
1425
- continue;
1426
- update();
1427
- sg_update(kALIGN, model.k, XMMz, XMMlambda_p1, XMMlambda_q1,
1428
- XMMlambda_p2, XMMlambda_q2, XMMeta, XMMrk_fast);
1429
- }
1430
- finalize(XMMloss, XMMerror);
1431
- }
1432
- }
1433
-
1434
- void SolverBase::load_fixed_variables(
1435
- __m256 &XMMlambda_p1, __m256 &XMMlambda_q1,
1436
- __m256 &XMMlambda_p2, __m256 &XMMlambda_q2,
1437
- __m256 &XMMeta, __m256 &XMMrk_slow,
1438
- __m256 &XMMrk_fast)
1439
- {
1440
- XMMlambda_p1 = _mm256_set1_ps(param.lambda_p1);
1441
- XMMlambda_q1 = _mm256_set1_ps(param.lambda_q1);
1442
- XMMlambda_p2 = _mm256_set1_ps(param.lambda_p2);
1443
- XMMlambda_q2 = _mm256_set1_ps(param.lambda_q2);
1444
- XMMeta = _mm256_set1_ps(param.eta);
1445
- XMMrk_slow = _mm256_set1_ps((mf_float)1.0/kALIGN);
1446
- XMMrk_fast = _mm256_set1_ps((mf_float)1.0/(model.k-kALIGN));
1447
- }
1448
-
1449
- void SolverBase::arrange_block(__m128d &XMMloss, __m128d &XMMerror)
1450
- {
1451
- XMMloss = _mm_setzero_pd();
1452
- XMMerror = _mm_setzero_pd();
1453
- bid = scheduler.get_job();
1454
- block = blocks[bid];
1455
- block->reload();
1456
- }
1457
-
1458
- inline void SolverBase::calc_z(
1459
- __m256 &XMMz, mf_int k, mf_float *p, mf_float *q)
1460
- {
1461
- XMMz = _mm256_setzero_ps();
1462
- for(mf_int d = 0; d < k; d += 8)
1463
- XMMz = _mm256_add_ps(XMMz, _mm256_mul_ps(
1464
- _mm256_load_ps(p+d), _mm256_load_ps(q+d)));
1465
- XMMz = _mm256_add_ps(XMMz, _mm256_permute2f128_ps(XMMz, XMMz, 0x1));
1466
- XMMz = _mm256_hadd_ps(XMMz, XMMz);
1467
- XMMz = _mm256_hadd_ps(XMMz, XMMz);
1468
- }
1469
-
1470
- void SolverBase::finalize(__m128d XMMloss, __m128d XMMerror)
1471
- {
1472
- _mm_store_sd(&loss, XMMloss);
1473
- _mm_store_sd(&error, XMMerror);
1474
- block->free();
1475
- scheduler.put_job(bid, loss, error);
1476
- }
1477
- #else
1478
- inline void SolverBase::run()
1479
- {
1480
- load_fixed_variables();
1481
- while(!scheduler.is_terminated())
1482
- {
1483
- arrange_block();
1484
- while(block->move_next())
1485
- {
1486
- N = block->get_current();
1487
- p = model.P+(mf_long)N->u*model.k;
1488
- q = model.Q+(mf_long)N->v*model.k;
1489
- pG = PG+N->u*2;
1490
- qG = QG+N->v*2;
1491
- prepare_for_sg_update();
1492
- sg_update(0, kALIGN, rk_slow);
1493
- if(slow_only)
1494
- continue;
1495
- update();
1496
- sg_update(kALIGN, model.k, rk_fast);
1497
- }
1498
- finalize();
1499
- }
1500
- }
1501
-
1502
- inline float SolverBase::qrsqrt(float x)
1503
- {
1504
- float xhalf = 0.5f*x;
1505
- uint32_t i;
1506
- memcpy(&i, &x, sizeof(i));
1507
- i = 0x5f375a86 - (i>>1);
1508
- memcpy(&x, &i, sizeof(i));
1509
- x = x*(1.5f - xhalf*x*x);
1510
- return x;
1511
- }
1512
-
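qrsqrt above is the classic fast inverse square root: reinterpret the float's bits, subtract from a magic constant, then apply one Newton step, giving an approximation of 1/sqrt(x). In the scalar solver path below it supplies the per-coordinate step size; a tiny sketch of that use (the eta value is illustrative):

    // Sketch, assuming the qrsqrt defined above.
    float G = 4.0f;                 // accumulated squared-gradient value, as in *pG
    float approx = qrsqrt(G);       // roughly 0.5 == 1/sqrt(4) after one Newton step
    float eta_p = 0.1f * approx;    // mirrors param.eta * qrsqrt(*pG) in sg_update below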
1513
- void SolverBase::load_fixed_variables()
1514
- {
1515
- lambda_p1 = param.lambda_p1;
1516
- lambda_q1 = param.lambda_q1;
1517
- lambda_p2 = param.lambda_p2;
1518
- lambda_q2 = param.lambda_q2;
1519
- rk_slow = (mf_float)1.0/kALIGN;
1520
- rk_fast = (mf_float)1.0/(model.k-kALIGN);
1521
- }
1522
-
1523
- void SolverBase::arrange_block()
1524
- {
1525
- loss = 0.0;
1526
- error = 0.0;
1527
- bid = scheduler.get_job();
1528
- block = blocks[bid];
1529
- block->reload();
1530
- }
1531
-
1532
- inline void SolverBase::calc_z(mf_float &z, mf_int k, mf_float *p, mf_float *q)
1533
- {
1534
- z = 0;
1535
- for(mf_int d = 0; d < k; ++d)
1536
- z += p[d]*q[d];
1537
- }
1538
-
1539
- void SolverBase::finalize()
1540
- {
1541
- block->free();
1542
- scheduler.put_job(bid, loss, error);
1543
- }
1544
- #endif
1545
-
1546
- //--------------------------------------
1547
- //-----Real-valued MF and binary MF-----
1548
- //--------------------------------------
1549
-
1550
- class MFSolver: public SolverBase
1551
- {
1552
- public:
1553
- MFSolver(Scheduler &scheduler, vector<BlockBase*> &blocks,
1554
- mf_float *PG, mf_float *QG, mf_model &model,
1555
- mf_parameter param, bool &slow_only)
1556
- : SolverBase(scheduler, blocks, PG, QG, model, param, slow_only) {}
1557
-
1558
- protected:
1559
- #if defined USESSE
1560
- void sg_update(mf_int d_begin, mf_int d_end, __m128 XMMz,
1561
- __m128 XMMlambda_p1, __m128 XMMlambda_q1,
1562
- __m128 XMMlambda_p2, __m128 XMMlambda_q2,
1563
- __m128 XMMeta, __m128 XMMrk);
1564
- #elif defined USEAVX
1565
- void sg_update(mf_int d_begin, mf_int d_end, __m256 XMMz,
1566
- __m256 XMMlambda_p1, __m256 XMMlambda_q1,
1567
- __m256 XMMlambda_p2, __m256 XMMlambda_q2,
1568
- __m256 XMMeta, __m256 XMMrk);
1569
- #else
1570
- void sg_update(mf_int d_begin, mf_int d_end, mf_float rk);
1571
- #endif
1572
- };
1573
-
1574
- #if defined USESSE
1575
- void MFSolver::sg_update(mf_int d_begin, mf_int d_end, __m128 XMMz,
1576
- __m128 XMMlambda_p1, __m128 XMMlambda_q1,
1577
- __m128 XMMlambda_p2, __m128 XMMlambda_q2,
1578
- __m128 XMMeta, __m128 XMMrk)
1579
- {
1580
- __m128 XMMpG = _mm_load1_ps(pG);
1581
- __m128 XMMqG = _mm_load1_ps(qG);
1582
- __m128 XMMeta_p = _mm_mul_ps(XMMeta, _mm_rsqrt_ps(XMMpG));
1583
- __m128 XMMeta_q = _mm_mul_ps(XMMeta, _mm_rsqrt_ps(XMMqG));
1584
- __m128 XMMpG1 = _mm_setzero_ps();
1585
- __m128 XMMqG1 = _mm_setzero_ps();
1586
-
1587
- for(mf_int d = d_begin; d < d_end; d += 4)
1588
- {
1589
- __m128 XMMp = _mm_load_ps(p+d);
1590
- __m128 XMMq = _mm_load_ps(q+d);
1591
-
1592
- __m128 XMMpg = _mm_sub_ps(_mm_mul_ps(XMMlambda_p2, XMMp),
1593
- _mm_mul_ps(XMMz, XMMq));
1594
- __m128 XMMqg = _mm_sub_ps(_mm_mul_ps(XMMlambda_q2, XMMq),
1595
- _mm_mul_ps(XMMz, XMMp));
1596
-
1597
- XMMpG1 = _mm_add_ps(XMMpG1, _mm_mul_ps(XMMpg, XMMpg));
1598
- XMMqG1 = _mm_add_ps(XMMqG1, _mm_mul_ps(XMMqg, XMMqg));
1599
-
1600
- XMMp = _mm_sub_ps(XMMp, _mm_mul_ps(XMMeta_p, XMMpg));
1601
- XMMq = _mm_sub_ps(XMMq, _mm_mul_ps(XMMeta_q, XMMqg));
1602
-
1603
- _mm_store_ps(p+d, XMMp);
1604
- _mm_store_ps(q+d, XMMq);
1605
- }
1606
-
1607
- mf_float tmp = 0;
1608
- _mm_store_ss(&tmp, XMMlambda_p1);
1609
- if(tmp > 0)
1610
- {
1611
- for(mf_int d = d_begin; d < d_end; d += 4)
1612
- {
1613
- __m128 XMMp = _mm_load_ps(p+d);
1614
- __m128 XMMflip = _mm_and_ps(_mm_cmple_ps(XMMp, _mm_set1_ps(0.0f)),
1615
- _mm_set1_ps(-0.0f));
1616
- XMMp = _mm_xor_ps(XMMflip,
1617
- _mm_max_ps(_mm_sub_ps(_mm_xor_ps(XMMp, XMMflip),
1618
- _mm_mul_ps(XMMeta_p, XMMlambda_p1)), _mm_set1_ps(0.0f)));
1619
- _mm_store_ps(p+d, XMMp);
1620
- }
1621
- }
1622
-
1623
- _mm_store_ss(&tmp, XMMlambda_q1);
1624
- if(tmp > 0)
1625
- {
1626
- for(mf_int d = d_begin; d < d_end; d += 4)
1627
- {
1628
- __m128 XMMq = _mm_load_ps(q+d);
1629
- __m128 XMMflip = _mm_and_ps(_mm_cmple_ps(XMMq, _mm_set1_ps(0.0f)),
1630
- _mm_set1_ps(-0.0f));
1631
- XMMq = _mm_xor_ps(XMMflip,
1632
- _mm_max_ps(_mm_sub_ps(_mm_xor_ps(XMMq, XMMflip),
1633
- _mm_mul_ps(XMMeta_q, XMMlambda_q1)), _mm_set1_ps(0.0f)));
1634
- _mm_store_ps(q+d, XMMq);
1635
- }
1636
- }
1637
-
1638
- if(param.do_nmf)
1639
- {
1640
- for(mf_int d = d_begin; d < d_end; d += 4)
1641
- {
1642
- __m128 XMMp = _mm_load_ps(p+d);
1643
- __m128 XMMq = _mm_load_ps(q+d);
1644
- XMMp = _mm_max_ps(XMMp, _mm_set1_ps(0.0f));
1645
- XMMq = _mm_max_ps(XMMq, _mm_set1_ps(0.0f));
1646
- _mm_store_ps(p+d, XMMp);
1647
- _mm_store_ps(q+d, XMMq);
1648
- }
1649
- }
1650
-
1651
- __m128 XMMtmp = _mm_add_ps(XMMpG1, _mm_movehl_ps(XMMpG1, XMMpG1));
1652
- XMMpG1 = _mm_add_ps(XMMpG1, _mm_shuffle_ps(XMMtmp, XMMtmp, 1));
1653
- XMMpG = _mm_add_ps(XMMpG, _mm_mul_ps(XMMpG1, XMMrk));
1654
- _mm_store_ss(pG, XMMpG);
1655
-
1656
- XMMtmp = _mm_add_ps(XMMqG1, _mm_movehl_ps(XMMqG1, XMMqG1));
1657
- XMMqG1 = _mm_add_ps(XMMqG1, _mm_shuffle_ps(XMMtmp, XMMtmp, 1));
1658
- XMMqG = _mm_add_ps(XMMqG, _mm_mul_ps(XMMqG1, XMMrk));
1659
- _mm_store_ss(qG, XMMqG);
1660
- }
1661
- #elif defined USEAVX
1662
- void MFSolver::sg_update(mf_int d_begin, mf_int d_end, __m256 XMMz,
1663
- __m256 XMMlambda_p1, __m256 XMMlambda_q1,
1664
- __m256 XMMlambda_p2, __m256 XMMlambda_q2,
1665
- __m256 XMMeta, __m256 XMMrk)
1666
- {
1667
- __m256 XMMpG = _mm256_broadcast_ss(pG);
1668
- __m256 XMMqG = _mm256_broadcast_ss(qG);
1669
- __m256 XMMeta_p = _mm256_mul_ps(XMMeta, _mm256_rsqrt_ps(XMMpG));
1670
- __m256 XMMeta_q = _mm256_mul_ps(XMMeta, _mm256_rsqrt_ps(XMMqG));
1671
- __m256 XMMpG1 = _mm256_setzero_ps();
1672
- __m256 XMMqG1 = _mm256_setzero_ps();
1673
-
1674
- for(mf_int d = d_begin; d < d_end; d += 8)
1675
- {
1676
- __m256 XMMp = _mm256_load_ps(p+d);
1677
- __m256 XMMq = _mm256_load_ps(q+d);
1678
-
1679
- __m256 XMMpg = _mm256_sub_ps(_mm256_mul_ps(XMMlambda_p2, XMMp),
1680
- _mm256_mul_ps(XMMz, XMMq));
1681
- __m256 XMMqg = _mm256_sub_ps(_mm256_mul_ps(XMMlambda_q2, XMMq),
1682
- _mm256_mul_ps(XMMz, XMMp));
1683
-
1684
- XMMpG1 = _mm256_add_ps(XMMpG1, _mm256_mul_ps(XMMpg, XMMpg));
1685
- XMMqG1 = _mm256_add_ps(XMMqG1, _mm256_mul_ps(XMMqg, XMMqg));
1686
-
1687
- XMMp = _mm256_sub_ps(XMMp, _mm256_mul_ps(XMMeta_p, XMMpg));
1688
- XMMq = _mm256_sub_ps(XMMq, _mm256_mul_ps(XMMeta_q, XMMqg));
1689
- _mm256_store_ps(p+d, XMMp);
1690
- _mm256_store_ps(q+d, XMMq);
1691
- }
1692
-
1693
- mf_float tmp = 0;
1694
- _mm_store_ss(&tmp, _mm256_castps256_ps128(XMMlambda_p1));
1695
- if(tmp > 0)
1696
- {
1697
- for(mf_int d = d_begin; d < d_end; d += 8)
1698
- {
1699
- __m256 XMMp = _mm256_load_ps(p+d);
1700
- __m256 XMMflip = _mm256_and_ps(_mm256_cmp_ps(XMMp,
1701
- _mm256_set1_ps(0.0f), _CMP_LE_OS),
1702
- _mm256_set1_ps(-0.0f));
1703
- XMMp = _mm256_xor_ps(XMMflip,
1704
- _mm256_max_ps(_mm256_sub_ps(
1705
- _mm256_xor_ps(XMMp, XMMflip),
1706
- _mm256_mul_ps(XMMeta_p, XMMlambda_p1)),
1707
- _mm256_set1_ps(0.0f)));
1708
- _mm256_store_ps(p+d, XMMp);
1709
- }
1710
- }
1711
-
1712
- _mm_store_ss(&tmp, _mm256_castps256_ps128(XMMlambda_q1));
1713
- if(tmp > 0)
1714
- {
1715
- for(mf_int d = d_begin; d < d_end; d += 8)
1716
- {
1717
- __m256 XMMq = _mm256_load_ps(q+d);
1718
- __m256 XMMflip = _mm256_and_ps(_mm256_cmp_ps(XMMq,
1719
- _mm256_set1_ps(0.0f), _CMP_LE_OS),
1720
- _mm256_set1_ps(-0.0f));
1721
- XMMq = _mm256_xor_ps(XMMflip,
1722
- _mm256_max_ps(_mm256_sub_ps(
1723
- _mm256_xor_ps(XMMq, XMMflip),
1724
- _mm256_mul_ps(XMMeta_q, XMMlambda_q1)),
1725
- _mm256_set1_ps(0.0f)));
1726
- _mm256_store_ps(q+d, XMMq);
1727
- }
1728
- }
1729
-
1730
- if(param.do_nmf)
1731
- {
1732
- for(mf_int d = d_begin; d < d_end; d += 8)
1733
- {
1734
- __m256 XMMp = _mm256_load_ps(p+d);
1735
- __m256 XMMq = _mm256_load_ps(q+d);
1736
- XMMp = _mm256_max_ps(XMMp, _mm256_set1_ps(0));
1737
- XMMq = _mm256_max_ps(XMMq, _mm256_set1_ps(0));
1738
- _mm256_store_ps(p+d, XMMp);
1739
- _mm256_store_ps(q+d, XMMq);
1740
- }
1741
- }
1742
-
1743
- XMMpG1 = _mm256_add_ps(XMMpG1,
1744
- _mm256_permute2f128_ps(XMMpG1, XMMpG1, 0x1));
1745
- XMMpG1 = _mm256_hadd_ps(XMMpG1, XMMpG1);
1746
- XMMpG1 = _mm256_hadd_ps(XMMpG1, XMMpG1);
1747
-
1748
- XMMqG1 = _mm256_add_ps(XMMqG1,
1749
- _mm256_permute2f128_ps(XMMqG1, XMMqG1, 0x1));
1750
- XMMqG1 = _mm256_hadd_ps(XMMqG1, XMMqG1);
1751
- XMMqG1 = _mm256_hadd_ps(XMMqG1, XMMqG1);
1752
-
1753
- XMMpG = _mm256_add_ps(XMMpG, _mm256_mul_ps(XMMpG1, XMMrk));
1754
- XMMqG = _mm256_add_ps(XMMqG, _mm256_mul_ps(XMMqG1, XMMrk));
1755
-
1756
- _mm_store_ss(pG, _mm256_castps256_ps128(XMMpG));
1757
- _mm_store_ss(qG, _mm256_castps256_ps128(XMMqG));
1758
- }
1759
- #else
1760
- void MFSolver::sg_update(mf_int d_begin, mf_int d_end, mf_float rk)
1761
- {
1762
- mf_float eta_p = param.eta*qrsqrt(*pG);
1763
- mf_float eta_q = param.eta*qrsqrt(*qG);
1764
-
1765
- mf_float pG1 = 0;
1766
- mf_float qG1 = 0;
1767
-
1768
- for(mf_int d = d_begin; d < d_end; ++d)
1769
- {
1770
- mf_float gp = -z*q[d]+lambda_p2*p[d];
1771
- mf_float gq = -z*p[d]+lambda_q2*q[d];
1772
-
1773
- pG1 += gp*gp;
1774
- qG1 += gq*gq;
1775
-
1776
- p[d] -= eta_p*gp;
1777
- q[d] -= eta_q*gq;
1778
- }
1779
-
1780
- if(lambda_p1 > 0)
1781
- {
1782
- for(mf_int d = d_begin; d < d_end; ++d)
1783
- {
1784
- mf_float p1 = max(abs(p[d])-lambda_p1*eta_p, 0.0f);
1785
- p[d] = p[d] >= 0? p1: -p1;
1786
- }
1787
- }
1788
-
1789
- if(lambda_q1 > 0)
1790
- {
1791
- for(mf_int d = d_begin; d < d_end; ++d)
1792
- {
1793
- mf_float q1 = max(abs(q[d])-lambda_q1*eta_q, 0.0f);
1794
- q[d] = q[d] >= 0? q1: -q1;
1795
- }
1796
- }
1797
-
1798
- if(param.do_nmf)
1799
- {
1800
- for(mf_int d = d_begin; d < d_end; ++d)
1801
- {
1802
- p[d] = max(p[d], (mf_float)0.0f);
1803
- q[d] = max(q[d], (mf_float)0.0f);
1804
- }
1805
- }
1806
-
1807
- *pG += pG1*rk;
1808
- *qG += qG1*rk;
1809
- }
1810
- #endif
1811
-
1812
- class L2_MFR : public MFSolver
1813
- {
1814
- public:
1815
- L2_MFR(Scheduler &scheduler, vector<BlockBase*> &blocks, mf_float *PG, mf_float *QG,
1816
- mf_model &model, mf_parameter param, bool &slow_only)
1817
- : MFSolver(scheduler, blocks, PG, QG, model, param, slow_only) {}
1818
-
1819
- protected:
1820
- #if defined USESSE
1821
- void prepare_for_sg_update(
1822
- __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
1823
- #elif defined USEAVX
1824
- void prepare_for_sg_update(
1825
- __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
1826
- #else
1827
- void prepare_for_sg_update();
1828
- #endif
1829
- };
1830
-
1831
- #if defined USESSE
1832
- void L2_MFR::prepare_for_sg_update(
1833
- __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
1834
- {
1835
- calc_z(XMMz, model.k, p, q);
1836
- XMMz = _mm_sub_ps(_mm_set1_ps(N->r), XMMz);
1837
- XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(
1838
- _mm_mul_ps(XMMz, XMMz)));
1839
- XMMerror = XMMloss;
1840
- }
1841
- #elif defined USEAVX
1842
- void L2_MFR::prepare_for_sg_update(
1843
- __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
1844
- {
1845
- calc_z(XMMz, model.k, p, q);
1846
- XMMz = _mm256_sub_ps(_mm256_set1_ps(N->r), XMMz);
1847
- XMMloss = _mm_add_pd(XMMloss,
1848
- _mm_cvtps_pd(_mm256_castps256_ps128(
1849
- _mm256_mul_ps(XMMz, XMMz))));
1850
- XMMerror = XMMloss;
1851
- }
1852
- #else
1853
- void L2_MFR::prepare_for_sg_update()
1854
- {
1855
- calc_z(z, model.k, p, q);
1856
- z = N->r-z;
1857
- loss += z*z;
1858
- error = loss;
1859
- }
1860
- #endif
1861
- class L1_MFR : public MFSolver
1862
- {
1863
- public:
1864
- L1_MFR(Scheduler &scheduler, vector<BlockBase*> &blocks, mf_float *PG, mf_float *QG,
1865
- mf_model &model, mf_parameter param, bool &slow_only)
1866
- : MFSolver(scheduler, blocks, PG, QG, model, param, slow_only) {}
1867
-
1868
- protected:
1869
- #if defined USESSE
1870
- void prepare_for_sg_update(
1871
- __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
1872
- #elif defined USEAVX
1873
- void prepare_for_sg_update(
1874
- __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
1875
- #else
1876
- void prepare_for_sg_update();
1877
- #endif
1878
- };
1879
-
1880
- #if defined USESSE
1881
- void L1_MFR::prepare_for_sg_update(
1882
- __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
1883
- {
1884
- calc_z(XMMz, model.k, p, q);
1885
- XMMz = _mm_sub_ps(_mm_set1_ps(N->r), XMMz);
1886
- XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(
1887
- _mm_andnot_ps(_mm_set1_ps(-0.0f), XMMz)));
1888
- XMMerror = XMMloss;
1889
- XMMz = _mm_add_ps(_mm_and_ps(_mm_cmpgt_ps(XMMz, _mm_set1_ps(0.0f)),
1890
- _mm_set1_ps(1.0f)),
1891
- _mm_and_ps(_mm_cmplt_ps(XMMz, _mm_set1_ps(0.0f)),
1892
- _mm_set1_ps(-1.0f)));
1893
- }
1894
- #elif defined USEAVX
1895
- void L1_MFR::prepare_for_sg_update(
1896
- __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
1897
- {
1898
- calc_z(XMMz, model.k, p, q);
1899
- XMMz = _mm256_sub_ps(_mm256_set1_ps(N->r), XMMz);
1900
- XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(_mm256_castps256_ps128(
1901
- _mm256_andnot_ps(_mm256_set1_ps(-0.0f), XMMz))));
1902
- XMMerror = XMMloss;
1903
- XMMz = _mm256_add_ps(_mm256_and_ps(_mm256_cmp_ps(XMMz,
1904
- _mm256_set1_ps(0.0f), _CMP_GT_OS), _mm256_set1_ps(1.0f)),
1905
- _mm256_and_ps(_mm256_cmp_ps(XMMz,
1906
- _mm256_set1_ps(0.0f), _CMP_LT_OS), _mm256_set1_ps(-1.0f)));
1907
- }
1908
- #else
1909
- void L1_MFR::prepare_for_sg_update()
1910
- {
1911
- calc_z(z, model.k, p, q);
1912
- z = N->r-z;
1913
- loss += abs(z);
1914
- error = loss;
1915
- if(z > 0)
1916
- z = 1;
1917
- else if(z < 0)
1918
- z = -1;
1919
- }
1920
- #endif
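For the two regression losses above (L2_MFR and L1_MFR), with prediction \(\hat z = p^{\mathsf T} q\) and target r = N->r, the value left in z is the negative derivative of the per-entry loss with respect to \(\hat z\), which is exactly what sg_update expects; the loss counter accumulates the unhalved square and the absolute value, respectively, which fpsg_core later turns into a per-entry root-mean-square or mean-absolute error:

```latex
\begin{gather*}
\text{L2\_MFR:}\quad \ell = \tfrac{1}{2}\,(r - \hat z)^2, \qquad z = r - \hat z \\
\text{L1\_MFR:}\quad \ell = |\,r - \hat z\,|, \qquad z = \operatorname{sign}(r - \hat z)
\end{gather*}
```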
1921
-
1922
- class KL_MFR : public MFSolver
1923
- {
1924
- public:
1925
- KL_MFR(Scheduler &scheduler, vector<BlockBase*> &blocks, mf_float *PG, mf_float *QG,
1926
- mf_model &model, mf_parameter param, bool &slow_only)
1927
- : MFSolver(scheduler, blocks, PG, QG, model, param, slow_only) {}
1928
-
1929
- protected:
1930
- #if defined USESSE
1931
- void prepare_for_sg_update(
1932
- __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
1933
- #elif defined USEAVX
1934
- void prepare_for_sg_update(
1935
- __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
1936
- #else
1937
- void prepare_for_sg_update();
1938
- #endif
1939
- };
1940
-
1941
- #if defined USESSE
1942
- void KL_MFR::prepare_for_sg_update(
1943
- __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
1944
- {
1945
- calc_z(XMMz, model.k, p, q);
1946
- XMMz = _mm_div_ps(_mm_set1_ps(N->r), XMMz);
1947
- _mm_store_ss(&z, XMMz);
1948
- XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(
1949
- _mm_set1_ps(N->r*(log(z)-1+1/z))));
1950
- XMMerror = XMMloss;
1951
- XMMz = _mm_sub_ps(XMMz, _mm_set1_ps(1.0f));
1952
- }
1953
- #elif defined USEAVX
1954
- void KL_MFR::prepare_for_sg_update(
1955
- __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
1956
- {
1957
- calc_z(XMMz, model.k, p, q);
1958
- XMMz = _mm256_div_ps(_mm256_set1_ps(N->r), XMMz);
1959
- _mm_store_ss(&z, _mm256_castps256_ps128(XMMz));
1960
- XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(
1961
- _mm_set1_ps(N->r*(log(z)-1+1/z))));
1962
- XMMerror = XMMloss;
1963
- XMMz = _mm256_sub_ps(XMMz, _mm256_set1_ps(1.0f));
1964
- }
1965
- #else
1966
- void KL_MFR::prepare_for_sg_update()
1967
- {
1968
- calc_z(z, model.k, p, q);
1969
- z = N->r/z;
1970
- loss += N->r*(log(z)-1+1/z);
1971
- error = loss;
1972
- z -= 1;
1973
- }
1974
- #endif
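In the KL solver above the ratings are assumed positive. With \(\hat z = p^{\mathsf T} q\) and \(t = r/\hat z\) (the value held in z before the final subtraction), the accumulated term is the generalized KL divergence between r and \(\hat z\), and the z handed to sg_update is again the negative derivative of the loss:

```latex
\[
\ell = r\Bigl(\log t - 1 + \frac{1}{t}\Bigr)
     = r\log\frac{r}{\hat z} - r + \hat z,
\qquad
z = t - 1 = \frac{r}{\hat z} - 1 = -\frac{\partial \ell}{\partial \hat z}
\]
```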
1975
-
1976
- class LR_MFC : public MFSolver
1977
- {
1978
- public:
1979
- LR_MFC(Scheduler &scheduler, vector<BlockBase*> &blocks,
1980
- mf_float *PG, mf_float *QG, mf_model &model,
1981
- mf_parameter param, bool &slow_only)
1982
- : MFSolver(scheduler, blocks, PG, QG, model, param, slow_only) {}
1983
-
1984
- protected:
1985
- #if defined USESSE
1986
- void prepare_for_sg_update(
1987
- __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
1988
- #elif defined USEAVX
1989
- void prepare_for_sg_update(
1990
- __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
1991
- #else
1992
- void prepare_for_sg_update();
1993
- #endif
1994
- };
1995
-
1996
- #if defined USESSE
1997
- void LR_MFC::prepare_for_sg_update(
1998
- __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
1999
- {
2000
- calc_z(XMMz, model.k, p, q);
2001
- _mm_store_ss(&z, XMMz);
2002
- if(N->r > 0)
2003
- {
2004
- z = exp(-z);
2005
- XMMloss = _mm_add_pd(XMMloss, _mm_set1_pd(log(1+z)));
2006
- XMMz = _mm_set1_ps(z/(1+z));
2007
- }
2008
- else
2009
- {
2010
- z = exp(z);
2011
- XMMloss = _mm_add_pd(XMMloss, _mm_set1_pd(log(1+z)));
2012
- XMMz = _mm_set1_ps(-z/(1+z));
2013
- }
2014
- XMMerror = XMMloss;
2015
- }
2016
- #elif defined USEAVX
2017
- void LR_MFC::prepare_for_sg_update(
2018
- __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
2019
- {
2020
- calc_z(XMMz, model.k, p, q);
2021
- _mm_store_ss(&z, _mm256_castps256_ps128(XMMz));
2022
- if(N->r > 0)
2023
- {
2024
- z = exp(-z);
2025
- XMMloss = _mm_add_pd(XMMloss, _mm_set1_pd(log(1.0+z)));
2026
- XMMz = _mm256_set1_ps(z/(1+z));
2027
- }
2028
- else
2029
- {
2030
- z = exp(z);
2031
- XMMloss = _mm_add_pd(XMMloss, _mm_set1_pd(log(1.0+z)));
2032
- XMMz = _mm256_set1_ps(-z/(1+z));
2033
- }
2034
- XMMerror = XMMloss;
2035
- }
2036
- #else
2037
- void LR_MFC::prepare_for_sg_update()
2038
- {
2039
- calc_z(z, model.k, p, q);
2040
- if(N->r > 0)
2041
- {
2042
- z = exp(-z);
2043
- loss += log(1+z);
2044
- error = loss;
2045
- z = z/(1+z);
2046
- }
2047
- else
2048
- {
2049
- z = exp(z);
2050
- loss += log(1+z);
2051
- error = loss;
2052
- z = -z/(1+z);
2053
- }
2054
- }
2055
- #endif
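For the logistic solver above, writing the label as y = +1 when N->r > 0 and y = -1 otherwise, the two branches compute the same quantities, expanded for each label value (σ is the logistic function):

```latex
\[
\ell = \log\bigl(1 + e^{-y\hat z}\bigr),
\qquad
z = -\frac{\partial \ell}{\partial \hat z}
  = \frac{y\,e^{-y\hat z}}{1 + e^{-y\hat z}}
  = y\,\sigma(-y\hat z)
\]
```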
2056
-
2057
- class L2_MFC : public MFSolver
2058
- {
2059
- public:
2060
- L2_MFC(Scheduler &scheduler, vector<BlockBase*> &blocks,
2061
- mf_float *PG, mf_float *QG, mf_model &model,
2062
- mf_parameter param, bool &slow_only)
2063
- : MFSolver(scheduler, blocks, PG, QG, model, param, slow_only) {}
2064
-
2065
- protected:
2066
- #if defined USESSE
2067
- void prepare_for_sg_update(
2068
- __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
2069
- #elif defined USEAVX
2070
- void prepare_for_sg_update(
2071
- __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
2072
- #else
2073
- void prepare_for_sg_update();
2074
- #endif
2075
- };
2076
-
2077
- #if defined USESSE
2078
- void L2_MFC::prepare_for_sg_update(
2079
- __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
2080
- {
2081
- calc_z(XMMz, model.k, p, q);
2082
- if(N->r > 0)
2083
- {
2084
- __m128 mask = _mm_cmpgt_ps(XMMz, _mm_set1_ps(0.0f));
2085
- XMMerror = _mm_add_pd(XMMerror, _mm_cvtps_pd(
2086
- _mm_and_ps(_mm_set1_ps(1.0f), mask)));
2087
- XMMz = _mm_max_ps(_mm_set1_ps(0.0f), _mm_sub_ps(
2088
- _mm_set1_ps(1.0f), XMMz));
2089
- }
2090
- else
2091
- {
2092
- __m128 mask = _mm_cmplt_ps(XMMz, _mm_set1_ps(0.0f));
2093
- XMMerror = _mm_add_pd(XMMerror, _mm_cvtps_pd(
2094
- _mm_and_ps(_mm_set1_ps(1.0f), mask)));
2095
- XMMz = _mm_min_ps(_mm_set1_ps(0.0f), _mm_sub_ps(
2096
- _mm_set1_ps(-1.0f), XMMz));
2097
- }
2098
- XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(
2099
- _mm_mul_ps(XMMz, XMMz)));
2100
- }
2101
- #elif defined USEAVX
2102
- void L2_MFC::prepare_for_sg_update(
2103
- __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
2104
- {
2105
- calc_z(XMMz, model.k, p, q);
2106
- if(N->r > 0)
2107
- {
2108
- __m128 mask = _mm_cmpgt_ps(_mm256_castps256_ps128(XMMz),
2109
- _mm_set1_ps(0.0f));
2110
- XMMerror = _mm_add_pd(XMMerror, _mm_cvtps_pd(
2111
- _mm_and_ps(_mm_set1_ps(1.0f), mask)));
2112
- XMMz = _mm256_max_ps(_mm256_set1_ps(0.0f),
2113
- _mm256_sub_ps(_mm256_set1_ps(1.0f), XMMz));
2114
- }
2115
- else
2116
- {
2117
- __m128 mask = _mm_cmplt_ps(_mm256_castps256_ps128(XMMz),
2118
- _mm_set1_ps(0.0f));
2119
- XMMerror = _mm_add_pd(XMMerror, _mm_cvtps_pd(
2120
- _mm_and_ps(_mm_set1_ps(1.0f), mask)));
2121
- XMMz = _mm256_min_ps(_mm256_set1_ps(0.0f),
2122
- _mm256_sub_ps(_mm256_set1_ps(-1.0f), XMMz));
2123
- }
2124
- XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(
2125
- _mm_mul_ps(_mm256_castps256_ps128(XMMz),
2126
- _mm256_castps256_ps128(XMMz))));
2127
- }
2128
- #else
2129
- void L2_MFC::prepare_for_sg_update()
2130
- {
2131
- calc_z(z, model.k, p, q);
2132
- if(N->r > 0)
2133
- {
2134
- error += z > 0? 1: 0;
2135
- z = max(0.0f, 1-z);
2136
- }
2137
- else
2138
- {
2139
- error += z < 0? 1: 0;
2140
- z = min(0.0f, -1-z);
2141
- }
2142
- loss += z*z;
2143
- }
2144
- #endif
2145
-
2146
- class L1_MFC : public MFSolver
2147
- {
2148
- public:
2149
- L1_MFC(Scheduler &scheduler, vector<BlockBase*> &blocks, mf_float *PG, mf_float *QG,
2150
- mf_model &model, mf_parameter param, bool &slow_only)
2151
- : MFSolver(scheduler, blocks, PG, QG, model, param, slow_only) {}
2152
-
2153
- protected:
2154
- #if defined USESSE
2155
- void prepare_for_sg_update(
2156
- __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
2157
- #elif defined USEAVX
2158
- void prepare_for_sg_update(
2159
- __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
2160
- #else
2161
- void prepare_for_sg_update();
2162
- #endif
2163
- };
2164
-
2165
- #if defined USESSE
2166
- void L1_MFC::prepare_for_sg_update(
2167
- __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
2168
- {
2169
- calc_z(XMMz, model.k, p, q);
2170
- if(N->r > 0)
2171
- {
2172
- XMMerror = _mm_add_pd(XMMerror, _mm_cvtps_pd(
2173
- _mm_and_ps(_mm_cmpge_ps(XMMz, _mm_set1_ps(0.0f)),
2174
- _mm_set1_ps(1.0f))));
2175
- XMMz = _mm_sub_ps(_mm_set1_ps(1.0f), XMMz);
2176
- XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(
2177
- _mm_max_ps(_mm_set1_ps(0.0f), XMMz)));
2178
- XMMz = _mm_and_ps(_mm_cmpge_ps(XMMz, _mm_set1_ps(0.0f)),
2179
- _mm_set1_ps(1.0f));
2180
- }
2181
- else
2182
- {
2183
- XMMerror = _mm_add_pd(XMMerror, _mm_cvtps_pd(
2184
- _mm_and_ps(_mm_cmplt_ps(XMMz, _mm_set1_ps(0.0f)),
2185
- _mm_set1_ps(1.0f))));
2186
- XMMz = _mm_add_ps(_mm_set1_ps(1.0f), XMMz);
2187
- XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(
2188
- _mm_max_ps(_mm_set1_ps(0.0f), XMMz)));
2189
- XMMz = _mm_and_ps(_mm_cmpge_ps(XMMz, _mm_set1_ps(0.0f)),
2190
- _mm_set1_ps(-1.0f));
2191
- }
2192
- }
2193
- #elif defined USEAVX
2194
- void L1_MFC::prepare_for_sg_update(
2195
- __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
2196
- {
2197
- calc_z(XMMz, model.k, p, q);
2198
- if(N->r > 0)
2199
- {
2200
- XMMerror = _mm_add_pd(XMMerror, _mm_cvtps_pd(_mm_and_ps(
2201
- _mm_cmpge_ps(_mm256_castps256_ps128(XMMz),
2202
- _mm_set1_ps(0.0f)), _mm_set1_ps(1.0f))));
2203
- XMMz = _mm256_sub_ps(_mm256_set1_ps(1.0f), XMMz);
2204
- XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(_mm_max_ps(
2205
- _mm_set1_ps(0.0f), _mm256_castps256_ps128(XMMz))));
2206
- XMMz = _mm256_and_ps(_mm256_cmp_ps(XMMz, _mm256_set1_ps(0.0f),
2207
- _CMP_GE_OS), _mm256_set1_ps(1.0f));
2208
- }
2209
- else
2210
- {
2211
- XMMerror = _mm_add_pd(XMMerror, _mm_cvtps_pd(_mm_and_ps(
2212
- _mm_cmplt_ps(_mm256_castps256_ps128(XMMz),
2213
- _mm_set1_ps(0.0f)), _mm_set1_ps(1.0f))));
2214
- XMMz = _mm256_add_ps(_mm256_set1_ps(1.0f), XMMz);
2215
- XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(_mm_max_ps(
2216
- _mm_set1_ps(0.0f), _mm256_castps256_ps128(XMMz))));
2217
- XMMz = _mm256_and_ps(_mm256_cmp_ps(XMMz, _mm256_set1_ps(0.0f),
2218
- _CMP_GE_OS), _mm256_set1_ps(-1.0f));
2219
- }
2220
- }
2221
- #else
2222
- void L1_MFC::prepare_for_sg_update()
2223
- {
2224
- calc_z(z, model.k, p, q);
2225
- if(N->r > 0)
2226
- {
2227
- loss += max(0.0f, 1-z);
2228
- error += z > 0? 1.0f: 0.0f;
2229
- z = z > 1? 0.0f: 1.0f;
2230
- }
2231
- else
2232
- {
2233
- loss += max(0.0f, 1+z);
2234
- error += z < 0? 1.0f: 0.0f;
2235
- z = z < -1? 0.0f: -1.0f;
2236
- }
2237
- }
2238
- #endif
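The two hinge-type classifiers above follow the same pattern, again with y = ±1 as for LR_MFC. L2_MFC uses the squared hinge loss and L1_MFC the plain hinge loss; in both, the error accumulator counts the entries whose prediction already has the correct sign, and z is the (sub)gradient factor passed to sg_update (for L2_MFC it corresponds to the gradient of half the accumulated squared hinge):

```latex
\begin{gather*}
\text{L2\_MFC:}\quad \text{loss} \mathrel{+}= \max(0,\,1 - y\hat z)^2, \qquad
z = y\,\max(0,\,1 - y\hat z) \\
\text{L1\_MFC:}\quad \text{loss} \mathrel{+}= \max(0,\,1 - y\hat z), \qquad
z = y\,\mathbf{1}\bigl[\,y\hat z \le 1\,\bigr]
\end{gather*}
```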
2239
- //--------------------------------------
2240
- //------------One-class MF--------------
2241
- //--------------------------------------
2242
-
2243
- class BPRSolver : public SolverBase
2244
- {
2245
- public:
2246
- BPRSolver(Scheduler &scheduler, vector<BlockBase*> &blocks,
2247
- mf_float *PG, mf_float *QG, mf_model &model, mf_parameter param,
2248
- bool &slow_only, bool is_column_oriented)
2249
- : SolverBase(scheduler, blocks, PG, QG, model, param, slow_only),
2250
- is_column_oriented(is_column_oriented) {}
2251
-
2252
- protected:
2253
- #if defined USESSE
2254
- static void calc_z(__m128 &XMMz, mf_int k,
2255
- mf_float *p, mf_float *q, mf_float *w);
2256
- void arrange_block(__m128d &XMMloss, __m128d &XMMerror);
2257
- void prepare_for_sg_update(
2258
- __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
2259
- void sg_update(mf_int d_begin, mf_int d_end, __m128 XMMz,
2260
- __m128 XMMlambda_p1, __m128 XMMlambda_q1,
2261
- __m128 XMMlambda_p2, __m128 XMMlamdba_q2,
2262
- __m128 XMMeta, __m128 XMMrk);
2263
- void finalize(__m128d XMMloss, __m128d XMMerror);
2264
- #elif defined USEAVX
2265
- static void calc_z(__m256 &XMMz, mf_int k,
2266
- mf_float *p, mf_float *q, mf_float *w);
2267
- void arrange_block(__m128d &XMMloss, __m128d &XMMerror);
2268
- void prepare_for_sg_update(
2269
- __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
2270
- void sg_update(mf_int d_begin, mf_int d_end, __m256 XMMz,
2271
- __m256 XMMlambda_p1, __m256 XMMlambda_q1,
2272
- __m256 XMMlambda_p2, __m256 XMMlamdba_q2,
2273
- __m256 XMMeta, __m256 XMMrk);
2274
- void finalize(__m128d XMMloss, __m128d XMMerror);
2275
- #else
2276
- static void calc_z(mf_float &z, mf_int k,
2277
- mf_float *p, mf_float *q, mf_float *w);
2278
- void arrange_block();
2279
- void prepare_for_sg_update();
2280
- void sg_update(mf_int d_begin, mf_int d_end, mf_float rk);
2281
- void finalize();
2282
- #endif
2283
- void update() { ++pG; ++qG; ++wG; };
2284
- virtual void prepare_negative() = 0;
2285
-
2286
- bool is_column_oriented;
2287
- mf_int bpr_bid;
2288
- mf_float *w;
2289
- mf_float *wG;
2290
- };
2291
-
2292
-
2293
- #if defined USESSE
2294
- inline void BPRSolver::calc_z(
2295
- __m128 &XMMz, mf_int k, mf_float *p, mf_float *q, mf_float *w)
2296
- {
2297
- XMMz = _mm_setzero_ps();
2298
- for(mf_int d = 0; d < k; d += 4)
2299
- XMMz = _mm_add_ps(XMMz, _mm_mul_ps(_mm_load_ps(p+d),
2300
- _mm_sub_ps(_mm_load_ps(q+d), _mm_load_ps(w+d))));
2301
- // Bit-wise representation of 177 is {1,0}+{1,1}+{0,0}+{0,1} from
2302
- // high-bit to low-bit, where "+" means concatenating two arrays.
2303
- __m128 XMMtmp = _mm_add_ps(XMMz, _mm_shuffle_ps(XMMz, XMMz, 177));
2304
- // Bit-wise representation of 78 is {0,1}+{0,0}+{1,1}+{1,0} from
2305
- // high-bit to low-bit, where "+" means concatenating two arrays.
2306
- XMMz = _mm_add_ps(XMMtmp, _mm_shuffle_ps(XMMtmp, XMMtmp, 78));
2307
- }
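A scalar sketch of the horizontal sum implemented by the two shuffle/add steps above; hsum4 is a hypothetical helper for illustration only, with z[0..3] standing for the four lanes of XMMz after the loop:

```cpp
// Illustration only: scalar equivalent of the SSE reduction above.
// The first shuffle(177)+add produces pairwise sums in every lane,
// the second shuffle(78)+add combines the two pairs, so each lane of
// the result holds z[0]+z[1]+z[2]+z[3].
static inline float hsum4(const float z[4])
{
    float lo = z[0] + z[1]; // lanes {0,1} after the first shuffle+add
    float hi = z[2] + z[3]; // lanes {2,3} after the first shuffle+add
    return lo + hi;         // second shuffle+add merges the halves
}
```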
2308
-
2309
- void BPRSolver::arrange_block(__m128d &XMMloss, __m128d &XMMerror)
2310
- {
2311
- XMMloss = _mm_setzero_pd();
2312
- XMMerror = _mm_setzero_pd();
2313
- bid = scheduler.get_job();
2314
- block = blocks[bid];
2315
- block->reload();
2316
- bpr_bid = scheduler.get_bpr_job(bid, is_column_oriented);
2317
- }
2318
-
2319
- void BPRSolver::finalize(__m128d XMMloss, __m128d XMMerror)
2320
- {
2321
- _mm_store_sd(&loss, XMMloss);
2322
- _mm_store_sd(&error, XMMerror);
2323
- scheduler.put_job(bid, loss, error);
2324
- scheduler.put_bpr_job(bid, bpr_bid);
2325
- }
2326
-
2327
- void BPRSolver::sg_update(mf_int d_begin, mf_int d_end, __m128 XMMz,
2328
- __m128 XMMlambda_p1, __m128 XMMlambda_q1,
2329
- __m128 XMMlambda_p2, __m128 XMMlambda_q2,
2330
- __m128 XMMeta, __m128 XMMrk)
2331
- {
2332
- __m128 XMMpG = _mm_load1_ps(pG);
2333
- __m128 XMMqG = _mm_load1_ps(qG);
2334
- __m128 XMMwG = _mm_load1_ps(wG);
2335
- __m128 XMMeta_p = _mm_mul_ps(XMMeta, _mm_rsqrt_ps(XMMpG));
2336
- __m128 XMMeta_q = _mm_mul_ps(XMMeta, _mm_rsqrt_ps(XMMqG));
2337
- __m128 XMMeta_w = _mm_mul_ps(XMMeta, _mm_rsqrt_ps(XMMwG));
2338
-
2339
- __m128 XMMpG1 = _mm_setzero_ps();
2340
- __m128 XMMqG1 = _mm_setzero_ps();
2341
- __m128 XMMwG1 = _mm_setzero_ps();
2342
-
2343
- for(mf_int d = d_begin; d < d_end; d += 4)
2344
- {
2345
- __m128 XMMp = _mm_load_ps(p+d);
2346
- __m128 XMMq = _mm_load_ps(q+d);
2347
- __m128 XMMw = _mm_load_ps(w+d);
2348
-
2349
- __m128 XMMpg = _mm_add_ps(_mm_mul_ps(XMMlambda_p2, XMMp),
2350
- _mm_mul_ps(XMMz, _mm_sub_ps(XMMw, XMMq)));
2351
- __m128 XMMqg = _mm_sub_ps(_mm_mul_ps(XMMlambda_q2, XMMq),
2352
- _mm_mul_ps(XMMz, XMMp));
2353
- __m128 XMMwg = _mm_add_ps(_mm_mul_ps(XMMlambda_q2, XMMw),
2354
- _mm_mul_ps(XMMz, XMMp));
2355
-
2356
- XMMpG1 = _mm_add_ps(XMMpG1, _mm_mul_ps(XMMpg, XMMpg));
2357
- XMMqG1 = _mm_add_ps(XMMqG1, _mm_mul_ps(XMMqg, XMMqg));
2358
- XMMwG1 = _mm_add_ps(XMMwG1, _mm_mul_ps(XMMwg, XMMwg));
2359
-
2360
- XMMp = _mm_sub_ps(XMMp, _mm_mul_ps(XMMeta_p, XMMpg));
2361
- XMMq = _mm_sub_ps(XMMq, _mm_mul_ps(XMMeta_q, XMMqg));
2362
- XMMw = _mm_sub_ps(XMMw, _mm_mul_ps(XMMeta_w, XMMwg));
2363
-
2364
- _mm_store_ps(p+d, XMMp);
2365
- _mm_store_ps(q+d, XMMq);
2366
- _mm_store_ps(w+d, XMMw);
2367
- }
2368
-
2369
- mf_float tmp = 0;
2370
- _mm_store_ss(&tmp, XMMlambda_p1);
2371
- if(tmp > 0)
2372
- {
2373
- for(mf_int d = d_begin; d < d_end; d += 4)
2374
- {
2375
- __m128 XMMp = _mm_load_ps(p+d);
2376
- __m128 XMMflip = _mm_and_ps(_mm_cmple_ps(XMMp, _mm_set1_ps(0.0f)),
2377
- _mm_set1_ps(-0.0f));
2378
- XMMp = _mm_xor_ps(XMMflip,
2379
- _mm_max_ps(_mm_sub_ps(_mm_xor_ps(XMMp, XMMflip),
2380
- _mm_mul_ps(XMMeta_p, XMMlambda_p1)), _mm_set1_ps(0.0f)));
2381
- _mm_store_ps(p+d, XMMp);
2382
- }
2383
- }
2384
-
2385
- _mm_store_ss(&tmp, XMMlambda_q1);
2386
- if(tmp > 0)
2387
- {
2388
- for(mf_int d = d_begin; d < d_end; d += 4)
2389
- {
2390
- __m128 XMMq = _mm_load_ps(q+d);
2391
- __m128 XMMw = _mm_load_ps(w+d);
2392
- __m128 XMMflip = _mm_and_ps(_mm_cmple_ps(XMMq, _mm_set1_ps(0.0f)),
2393
- _mm_set1_ps(-0.0f));
2394
- XMMq = _mm_xor_ps(XMMflip,
2395
- _mm_max_ps(_mm_sub_ps(_mm_xor_ps(XMMq, XMMflip),
2396
- _mm_mul_ps(XMMeta_q, XMMlambda_q1)), _mm_set1_ps(0.0f)));
2397
- _mm_store_ps(q+d, XMMq);
2398
-
2399
-
2400
- XMMflip = _mm_and_ps(_mm_cmple_ps(XMMw, _mm_set1_ps(0.0f)),
2401
- _mm_set1_ps(-0.0f));
2402
- XMMw = _mm_xor_ps(XMMflip,
2403
- _mm_max_ps(_mm_sub_ps(_mm_xor_ps(XMMw, XMMflip),
2404
- _mm_mul_ps(XMMeta_w, XMMlambda_q1)), _mm_set1_ps(0.0f)));
2405
- _mm_store_ps(w+d, XMMw);
2406
- }
2407
- }
2408
-
2409
- if(param.do_nmf)
2410
- {
2411
- for(mf_int d = d_begin; d < d_end; d += 4)
2412
- {
2413
- __m128 XMMp = _mm_load_ps(p+d);
2414
- __m128 XMMq = _mm_load_ps(q+d);
2415
- __m128 XMMw = _mm_load_ps(w+d);
2416
- XMMp = _mm_max_ps(XMMp, _mm_set1_ps(0.0f));
2417
- XMMq = _mm_max_ps(XMMq, _mm_set1_ps(0.0f));
2418
- XMMw = _mm_max_ps(XMMw, _mm_set1_ps(0.0f));
2419
- _mm_store_ps(p+d, XMMp);
2420
- _mm_store_ps(q+d, XMMq);
2421
- _mm_store_ps(w+d, XMMw);
2422
- }
2423
- }
2424
-
2425
- // Update learning rate of latent vector p. Squared derivatives along all
2426
- // latent dimensions were computed above. Here their average is
2427
- // added into the associated squared-gradient sum.
2428
- __m128 XMMtmp = _mm_add_ps(XMMpG1, _mm_movehl_ps(XMMpG1, XMMpG1));
2429
- XMMpG1 = _mm_add_ps(XMMtmp, _mm_shuffle_ps(XMMtmp, XMMtmp, 1));
2430
- XMMpG = _mm_add_ps(XMMpG, _mm_mul_ps(XMMpG1, XMMrk));
2431
- _mm_store_ss(pG, XMMpG);
2432
-
2433
- // Similar code is used to update learning rate of latent vector q.
2434
- XMMtmp = _mm_add_ps(XMMqG1, _mm_movehl_ps(XMMqG1, XMMqG1));
2435
- XMMqG1 = _mm_add_ps(XMMtmp, _mm_shuffle_ps(XMMtmp, XMMtmp, 1));
2436
- XMMqG = _mm_add_ps(XMMqG, _mm_mul_ps(XMMqG1, XMMrk));
2437
- _mm_store_ss(qG, XMMqG);
2438
-
2439
- // Similar code is used to update learning rate of latent vector w.
2440
- XMMtmp = _mm_add_ps(XMMwG1, _mm_movehl_ps(XMMwG1, XMMwG1));
2441
- XMMwG1 = _mm_add_ps(XMMtmp, _mm_shuffle_ps(XMMtmp, XMMtmp, 1));
2442
- XMMwG = _mm_add_ps(XMMwG, _mm_mul_ps(XMMwG1, XMMrk));
2443
- _mm_store_ss(wG, XMMwG);
2444
- }
2445
-
2446
- void BPRSolver::prepare_for_sg_update(
2447
- __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
2448
- {
2449
- prepare_negative();
2450
- calc_z(XMMz, model.k, p, q, w);
2451
- _mm_store_ss(&z, XMMz);
2452
- z = exp(-z);
2453
- XMMloss = _mm_add_pd(XMMloss, _mm_set1_pd(log(1+z)));
2454
- XMMerror = XMMloss;
2455
- XMMz = _mm_set1_ps(z/(1+z));
2456
- }
2457
- #elif defined USEAVX
2458
- inline void BPRSolver::calc_z(
2459
- __m256 &XMMz, mf_int k, mf_float *p, mf_float *q, mf_float *w)
2460
- {
2461
- XMMz = _mm256_setzero_ps();
2462
- for(mf_int d = 0; d < k; d += 8)
2463
- XMMz = _mm256_add_ps(XMMz, _mm256_mul_ps(
2464
- _mm256_load_ps(p+d), _mm256_sub_ps(
2465
- _mm256_load_ps(q+d), _mm256_load_ps(w+d))));
2466
- XMMz = _mm256_add_ps(XMMz, _mm256_permute2f128_ps(XMMz, XMMz, 0x1));
2467
- XMMz = _mm256_hadd_ps(XMMz, XMMz);
2468
- XMMz = _mm256_hadd_ps(XMMz, XMMz);
2469
- }
2470
-
2471
- void BPRSolver::arrange_block(__m128d &XMMloss, __m128d &XMMerror)
2472
- {
2473
- XMMloss = _mm_setzero_pd();
2474
- XMMerror = _mm_setzero_pd();
2475
- bid = scheduler.get_job();
2476
- block = blocks[bid];
2477
- block->reload();
2478
- bpr_bid = scheduler.get_bpr_job(bid, is_column_oriented);
2479
- }
2480
-
2481
- void BPRSolver::finalize(__m128d XMMloss, __m128d XMMerror)
2482
- {
2483
- _mm_store_sd(&loss, XMMloss);
2484
- _mm_store_sd(&error, XMMerror);
2485
- scheduler.put_job(bid, loss, error);
2486
- scheduler.put_bpr_job(bid, bpr_bid);
2487
- }
2488
-
2489
- void BPRSolver::sg_update(mf_int d_begin, mf_int d_end, __m256 XMMz,
2490
- __m256 XMMlambda_p1, __m256 XMMlambda_q1,
2491
- __m256 XMMlambda_p2, __m256 XMMlambda_q2,
2492
- __m256 XMMeta, __m256 XMMrk)
2493
- {
2494
- __m256 XMMpG = _mm256_broadcast_ss(pG);
2495
- __m256 XMMqG = _mm256_broadcast_ss(qG);
2496
- __m256 XMMwG = _mm256_broadcast_ss(wG);
2497
- __m256 XMMeta_p =
2498
- _mm256_mul_ps(XMMeta, _mm256_rsqrt_ps(XMMpG));
2499
- __m256 XMMeta_q =
2500
- _mm256_mul_ps(XMMeta, _mm256_rsqrt_ps(XMMqG));
2501
- __m256 XMMeta_w =
2502
- _mm256_mul_ps(XMMeta, _mm256_rsqrt_ps(XMMwG));
2503
-
2504
- __m256 XMMpG1 = _mm256_setzero_ps();
2505
- __m256 XMMqG1 = _mm256_setzero_ps();
2506
- __m256 XMMwG1 = _mm256_setzero_ps();
2507
-
2508
- for(mf_int d = d_begin; d < d_end; d += 8)
2509
- {
2510
- __m256 XMMp = _mm256_load_ps(p+d);
2511
- __m256 XMMq = _mm256_load_ps(q+d);
2512
- __m256 XMMw = _mm256_load_ps(w+d);
2513
- __m256 XMMpg = _mm256_add_ps(_mm256_mul_ps(XMMlambda_p2, XMMp),
2514
- _mm256_mul_ps(XMMz, _mm256_sub_ps(XMMw, XMMq)));
2515
- __m256 XMMqg = _mm256_sub_ps(_mm256_mul_ps(XMMlambda_q2, XMMq),
2516
- _mm256_mul_ps(XMMz, XMMp));
2517
- __m256 XMMwg = _mm256_add_ps(_mm256_mul_ps(XMMlambda_q2, XMMw),
2518
- _mm256_mul_ps(XMMz, XMMp));
2519
-
2520
- XMMpG1 = _mm256_add_ps(XMMpG1, _mm256_mul_ps(XMMpg, XMMpg));
2521
- XMMqG1 = _mm256_add_ps(XMMqG1, _mm256_mul_ps(XMMqg, XMMqg));
2522
- XMMwG1 = _mm256_add_ps(XMMwG1, _mm256_mul_ps(XMMwg, XMMwg));
2523
-
2524
- XMMp = _mm256_sub_ps(XMMp, _mm256_mul_ps(XMMeta_p, XMMpg));
2525
- XMMq = _mm256_sub_ps(XMMq, _mm256_mul_ps(XMMeta_q, XMMqg));
2526
- XMMw = _mm256_sub_ps(XMMw, _mm256_mul_ps(XMMeta_w, XMMwg));
2527
-
2528
- _mm256_store_ps(p+d, XMMp);
2529
- _mm256_store_ps(q+d, XMMq);
2530
- _mm256_store_ps(w+d, XMMw);
2531
- }
2532
-
2533
- mf_float tmp = 0;
2534
- _mm_store_ss(&tmp, _mm256_castps256_ps128(XMMlambda_p1));
2535
- if(tmp > 0)
2536
- {
2537
- for(mf_int d = d_begin; d < d_end; d += 8)
2538
- {
2539
- __m256 XMMp = _mm256_load_ps(p+d);
2540
- __m256 XMMflip =
2541
- _mm256_and_ps(
2542
- _mm256_cmp_ps(XMMp, _mm256_set1_ps(0.0f), _CMP_LE_OS),
2543
- _mm256_set1_ps(-0.0f));
2544
- XMMp = _mm256_xor_ps(XMMflip,
2545
- _mm256_max_ps(_mm256_sub_ps(_mm256_xor_ps(XMMp, XMMflip),
2546
- _mm256_mul_ps(XMMeta_p, XMMlambda_p1)),
2547
- _mm256_set1_ps(0.0f)));
2548
- _mm256_store_ps(p+d, XMMp);
2549
- }
2550
- }
2551
-
2552
- _mm_store_ss(&tmp, _mm256_castps256_ps128(XMMlambda_q1));
2553
- if(tmp > 0)
2554
- {
2555
- for(mf_int d = d_begin; d < d_end; d += 8)
2556
- {
2557
- __m256 XMMq = _mm256_load_ps(q+d);
2558
- __m256 XMMw = _mm256_load_ps(w+d);
2559
- __m256 XMMflip;
2560
-
2561
- XMMflip = _mm256_and_ps(
2562
- _mm256_cmp_ps(XMMq, _mm256_set1_ps(0.0f), _CMP_LE_OS),
2563
- _mm256_set1_ps(-0.0f));
2564
- XMMq = _mm256_xor_ps(XMMflip,
2565
- _mm256_max_ps(_mm256_sub_ps(_mm256_xor_ps(XMMq, XMMflip),
2566
- _mm256_mul_ps(XMMeta_q, XMMlambda_q1)),
2567
- _mm256_set1_ps(0.0f)));
2568
- _mm256_store_ps(q+d, XMMq);
2569
-
2570
-
2571
- XMMflip = _mm256_and_ps(
2572
- _mm256_cmp_ps(XMMw, _mm256_set1_ps(0.0f), _CMP_LE_OS),
2573
- _mm256_set1_ps(-0.0f));
2574
- XMMw = _mm256_xor_ps(XMMflip,
2575
- _mm256_max_ps(_mm256_sub_ps(_mm256_xor_ps(XMMw, XMMflip),
2576
- _mm256_mul_ps(XMMeta_w, XMMlambda_q1)),
2577
- _mm256_set1_ps(0.0f)));
2578
- _mm256_store_ps(w+d, XMMw);
2579
- }
2580
- }
2581
-
2582
- if(param.do_nmf)
2583
- {
2584
- for(mf_int d = d_begin; d < d_end; d += 8)
2585
- {
2586
- __m256 XMMp = _mm256_load_ps(p+d);
2587
- __m256 XMMq = _mm256_load_ps(q+d);
2588
- __m256 XMMw = _mm256_load_ps(w+d);
2589
- XMMp = _mm256_max_ps(XMMp, _mm256_set1_ps(0.0f));
2590
- XMMq = _mm256_max_ps(XMMq, _mm256_set1_ps(0.0f));
2591
- XMMw = _mm256_max_ps(XMMw, _mm256_set1_ps(0.0f));
2592
- _mm256_store_ps(p+d, XMMp);
2593
- _mm256_store_ps(q+d, XMMq);
2594
- _mm256_store_ps(w+d, XMMw);
2595
- }
2596
- }
2597
-
2598
- XMMpG1 = _mm256_add_ps(XMMpG1,
2599
- _mm256_permute2f128_ps(XMMpG1, XMMpG1, 0x1));
2600
- XMMpG1 = _mm256_hadd_ps(XMMpG1, XMMpG1);
2601
- XMMpG1 = _mm256_hadd_ps(XMMpG1, XMMpG1);
2602
-
2603
- XMMqG1 = _mm256_add_ps(XMMqG1,
2604
- _mm256_permute2f128_ps(XMMqG1, XMMqG1, 0x1));
2605
- XMMqG1 = _mm256_hadd_ps(XMMqG1, XMMqG1);
2606
- XMMqG1 = _mm256_hadd_ps(XMMqG1, XMMqG1);
2607
-
2608
- XMMwG1 = _mm256_add_ps(XMMwG1,
2609
- _mm256_permute2f128_ps(XMMwG1, XMMwG1, 0x1));
2610
- XMMwG1 = _mm256_hadd_ps(XMMwG1, XMMwG1);
2611
- XMMwG1 = _mm256_hadd_ps(XMMwG1, XMMwG1);
2612
-
2613
- XMMpG = _mm256_add_ps(XMMpG, _mm256_mul_ps(XMMpG1, XMMrk));
2614
- XMMqG = _mm256_add_ps(XMMqG, _mm256_mul_ps(XMMqG1, XMMrk));
2615
- XMMwG = _mm256_add_ps(XMMwG, _mm256_mul_ps(XMMwG1, XMMrk));
2616
-
2617
- _mm_store_ss(pG, _mm256_castps256_ps128(XMMpG));
2618
- _mm_store_ss(qG, _mm256_castps256_ps128(XMMqG));
2619
- _mm_store_ss(wG, _mm256_castps256_ps128(XMMwG));
2620
- }
2621
-
2622
- void BPRSolver::prepare_for_sg_update(
2623
- __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
2624
- {
2625
- prepare_negative();
2626
- calc_z(XMMz, model.k, p, q, w);
2627
- _mm_store_ss(&z, _mm256_castps256_ps128(XMMz));
2628
- z = exp(-z);
2629
- XMMloss = _mm_add_pd(XMMloss, _mm_set1_pd(log(1+z)));
2630
- XMMerror = XMMloss;
2631
- XMMz = _mm256_set1_ps(z/(1+z));
2632
- }
2633
- #else
2634
- inline void BPRSolver::calc_z(
2635
- mf_float &z, mf_int k, mf_float *p, mf_float *q, mf_float *w)
2636
- {
2637
- z = 0;
2638
- for(mf_int d = 0; d < k; ++d)
2639
- z += p[d]*(q[d]-w[d]);
2640
- }
2641
-
2642
- void BPRSolver::arrange_block()
2643
- {
2644
- loss = 0.0;
2645
- error = 0.0;
2646
- bid = scheduler.get_job();
2647
- block = blocks[bid];
2648
- block->reload();
2649
- bpr_bid = scheduler.get_bpr_job(bid, is_column_oriented);
2650
- }
2651
-
2652
- void BPRSolver::finalize()
2653
- {
2654
- scheduler.put_job(bid, loss, error);
2655
- scheduler.put_bpr_job(bid, bpr_bid);
2656
- }
2657
-
2658
- void BPRSolver::sg_update(mf_int d_begin, mf_int d_end, mf_float rk)
2659
- {
2660
- mf_float eta_p = param.eta*qrsqrt(*pG);
2661
- mf_float eta_q = param.eta*qrsqrt(*qG);
2662
- mf_float eta_w = param.eta*qrsqrt(*wG);
2663
-
2664
- mf_float pG1 = 0;
2665
- mf_float qG1 = 0;
2666
- mf_float wG1 = 0;
2667
-
2668
- for(mf_int d = d_begin; d < d_end; ++d)
2669
- {
2670
- mf_float gp = z*(w[d]-q[d]) + lambda_p2*p[d];
2671
- mf_float gq = -z*p[d] + lambda_q2*q[d];
2672
- mf_float gw = z*p[d] + lambda_q2*w[d];
2673
-
2674
- pG1 += gp*gp;
2675
- qG1 += gq*gq;
2676
- wG1 += gw*gw;
2677
-
2678
- p[d] -= eta_p*gp;
2679
- q[d] -= eta_q*gq;
2680
- w[d] -= eta_w*gw;
2681
- }
2682
-
2683
- if(lambda_p1 > 0)
2684
- {
2685
- for(mf_int d = d_begin; d < d_end; ++d)
2686
- {
2687
- mf_float p1 = max(abs(p[d])-lambda_p1*eta_p, 0.0f);
2688
- p[d] = p[d] >= 0? p1: -p1;
2689
- }
2690
- }
2691
-
2692
- if(lambda_q1 > 0)
2693
- {
2694
- for(mf_int d = d_begin; d < d_end; ++d)
2695
- {
2696
- mf_float q1 = max(abs(w[d])-lambda_q1*eta_w, 0.0f);
2697
- w[d] = w[d] >= 0? q1: -q1;
2698
- q1 = max(abs(q[d])-lambda_q1*eta_q, 0.0f);
2699
- q[d] = q[d] >= 0? q1: -q1;
2700
- }
2701
- }
2702
-
2703
- if(param.do_nmf)
2704
- {
2705
- for(mf_int d = d_begin; d < d_end; ++d)
2706
- {
2707
- p[d] = max(p[d], (mf_float)0.0);
2708
- q[d] = max(q[d], (mf_float)0.0);
2709
- w[d] = max(w[d], (mf_float)0.0);
2710
- }
2711
- }
2712
-
2713
- *pG += pG1*rk;
2714
- *qG += qG1*rk;
2715
- *wG += wG1*rk;
2716
- }
2717
-
2718
- void BPRSolver::prepare_for_sg_update()
2719
- {
2720
- prepare_negative();
2721
- calc_z(z, model.k, p, q, w);
2722
- z = exp(-z);
2723
- loss += log(1+z);
2724
- error = loss;
2725
- z = z/(1+z);
2726
- }
2727
- #endif
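Putting the pieces together for the BPR solvers: for a positive pair (u, v) and the sampled negative item w chosen by prepare_negative, calc_z computes \(x = p_u^{\mathsf T}(q_v - w)\), and the factor stored in z is the usual BPR gradient weight:

```latex
\[
\ell = \log\bigl(1 + e^{-x}\bigr) = -\log\sigma(x),
\qquad
z = \frac{e^{-x}}{1 + e^{-x}} = \sigma(-x) = -\frac{\partial \ell}{\partial x}
\]
```

sg_update then applies z through the gradients shown above, gp = z(w − q) + λ_p2 p, gq = −z p + λ_q2 q, and gw = z p + λ_q2 w, with the same AdaGrad accumulation and optional L1/non-negativity steps as in MFSolver.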
2728
-
2729
- class COL_BPR_MFOC : public BPRSolver
2730
- {
2731
- public:
2732
- COL_BPR_MFOC(Scheduler &scheduler, vector<BlockBase*> &blocks,
2733
- mf_float *PG, mf_float *QG, mf_model &model,
2734
- mf_parameter param, bool &slow_only,
2735
- bool is_column_oriented=true)
2736
- : BPRSolver(scheduler, blocks, PG, QG, model, param,
2737
- slow_only, is_column_oriented) {}
2738
- protected:
2739
- #if defined USESSE
2740
- void load_fixed_variables(
2741
- __m128 &XMMlambda_p1, __m128 &XMMlambda_q1,
2742
- __m128 &XMMlambda_p2, __m128 &XMMlabmda_q2,
2743
- __m128 &XMMeta, __m128 &XMMrk_slow,
2744
- __m128 &XMMrk_fast);
2745
- #elif defined USEAVX
2746
- void load_fixed_variables(
2747
- __m256 &XMMlambda_p1, __m256 &XMMlambda_q1,
2748
- __m256 &XMMlambda_p2, __m256 &XMMlabmda_q2,
2749
- __m256 &XMMeta, __m256 &XMMrk_slow,
2750
- __m256 &XMMrk_fast);
2751
- #else
2752
- void load_fixed_variables();
2753
- #endif
2754
- void prepare_negative();
2755
- };
2756
-
2757
- void COL_BPR_MFOC::prepare_negative()
2758
- {
2759
- mf_int negative = scheduler.get_negative(bid, bpr_bid, model.m, model.n,
2760
- is_column_oriented);
2761
- w = model.P + negative*model.k;
2762
- wG = PG + negative*2;
2763
- swap(p, q);
2764
- swap(pG, qG);
2765
- }
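// Note on the column-oriented variant: prepare_negative() above swaps p with q
// and pG with qG, so inside sg_update the roles of the P and Q factors are
// exchanged. load_fixed_variables() below therefore loads the "p" constants
// from param.lambda_q* and the "q" constants from param.lambda_p*, which keeps
// each regularization coefficient attached to the factor it was specified for.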
2766
-
2767
- #if defined USESSE
2768
- void COL_BPR_MFOC::load_fixed_variables(
2769
- __m128 &XMMlambda_p1, __m128 &XMMlambda_q1,
2770
- __m128 &XMMlambda_p2, __m128 &XMMlambda_q2,
2771
- __m128 &XMMeta, __m128 &XMMrk_slow,
2772
- __m128 &XMMrk_fast)
2773
- {
2774
- XMMlambda_p1 = _mm_set1_ps(param.lambda_q1);
2775
- XMMlambda_q1 = _mm_set1_ps(param.lambda_p1);
2776
- XMMlambda_p2 = _mm_set1_ps(param.lambda_q2);
2777
- XMMlambda_q2 = _mm_set1_ps(param.lambda_p2);
2778
- XMMeta = _mm_set1_ps(param.eta);
2779
- XMMrk_slow = _mm_set1_ps((mf_float)1.0/kALIGN);
2780
- XMMrk_fast = _mm_set1_ps((mf_float)1.0/(model.k-kALIGN));
2781
- }
2782
- #elif defined USEAVX
2783
- void COL_BPR_MFOC::load_fixed_variables(
2784
- __m256 &XMMlambda_p1, __m256 &XMMlambda_q1,
2785
- __m256 &XMMlambda_p2, __m256 &XMMlambda_q2,
2786
- __m256 &XMMeta, __m256 &XMMrk_slow,
2787
- __m256 &XMMrk_fast)
2788
- {
2789
- XMMlambda_p1 = _mm256_set1_ps(param.lambda_q1);
2790
- XMMlambda_q1 = _mm256_set1_ps(param.lambda_p1);
2791
- XMMlambda_p2 = _mm256_set1_ps(param.lambda_q2);
2792
- XMMlambda_q2 = _mm256_set1_ps(param.lambda_p2);
2793
- XMMeta = _mm256_set1_ps(param.eta);
2794
- XMMrk_slow = _mm256_set1_ps((mf_float)1.0/kALIGN);
2795
- XMMrk_fast = _mm256_set1_ps((mf_float)1.0/(model.k-kALIGN));
2796
- }
2797
- #else
2798
- void COL_BPR_MFOC::load_fixed_variables()
2799
- {
2800
- lambda_p1 = param.lambda_q1;
2801
- lambda_q1 = param.lambda_p1;
2802
- lambda_p2 = param.lambda_q2;
2803
- lambda_q2 = param.lambda_p2;
2804
- rk_slow = (mf_float)1.0/kALIGN;
2805
- rk_fast = (mf_float)1.0/(model.k-kALIGN);
2806
- }
2807
- #endif
2808
-
2809
- class ROW_BPR_MFOC : public BPRSolver
2810
- {
2811
- public:
2812
- ROW_BPR_MFOC(Scheduler &scheduler, vector<BlockBase*> &blocks,
2813
- mf_float *PG, mf_float *QG, mf_model &model,
2814
- mf_parameter param, bool &slow_only,
2815
- bool is_column_oriented = false)
2816
- : BPRSolver(scheduler, blocks, PG, QG, model, param,
2817
- slow_only, is_column_oriented) {}
2818
- protected:
2819
- void prepare_negative();
2820
- };
2821
-
2822
- void ROW_BPR_MFOC::prepare_negative()
2823
- {
2824
- mf_int negative = scheduler.get_negative(bid, bpr_bid, model.m, model.n,
2825
- is_column_oriented);
2826
- w = model.Q + negative*model.k;
2827
- wG = QG + negative*2;
2828
- }
2829
-
2830
-
2831
- class SolverFactory
2832
- {
2833
- public:
2834
- static shared_ptr<SolverBase> get_solver(
2835
- Scheduler &scheduler,
2836
- vector<BlockBase*> &blocks,
2837
- mf_float *PG,
2838
- mf_float *QG,
2839
- mf_model &model,
2840
- mf_parameter param,
2841
- bool &slow_only);
2842
- };
2843
-
2844
- shared_ptr<SolverBase> SolverFactory::get_solver(
2845
- Scheduler &scheduler,
2846
- vector<BlockBase*> &blocks,
2847
- mf_float *PG,
2848
- mf_float *QG,
2849
- mf_model &model,
2850
- mf_parameter param,
2851
- bool &slow_only)
2852
- {
2853
- shared_ptr<SolverBase> solver;
2854
-
2855
- switch(param.fun)
2856
- {
2857
- case P_L2_MFR:
2858
- solver = shared_ptr<SolverBase>(new L2_MFR(scheduler, blocks,
2859
- PG, QG, model, param, slow_only));
2860
- break;
2861
- case P_L1_MFR:
2862
- solver = shared_ptr<SolverBase>(new L1_MFR(scheduler, blocks,
2863
- PG, QG, model, param, slow_only));
2864
- break;
2865
- case P_KL_MFR:
2866
- solver = shared_ptr<SolverBase>(new KL_MFR(scheduler, blocks,
2867
- PG, QG, model, param, slow_only));
2868
- break;
2869
- case P_LR_MFC:
2870
- solver = shared_ptr<SolverBase>(new LR_MFC(scheduler, blocks,
2871
- PG, QG, model, param, slow_only));
2872
- break;
2873
- case P_L2_MFC:
2874
- solver = shared_ptr<SolverBase>(new L2_MFC(scheduler, blocks,
2875
- PG, QG, model, param, slow_only));
2876
- break;
2877
- case P_L1_MFC:
2878
- solver = shared_ptr<SolverBase>(new L1_MFC(scheduler, blocks,
2879
- PG, QG, model, param, slow_only));
2880
- break;
2881
- case P_ROW_BPR_MFOC:
2882
- solver = shared_ptr<SolverBase>(new ROW_BPR_MFOC(scheduler,
2883
- blocks, PG, QG, model, param, slow_only));
2884
- break;
2885
- case P_COL_BPR_MFOC:
2886
- solver = shared_ptr<SolverBase>(new COL_BPR_MFOC(scheduler,
2887
- blocks, PG, QG, model, param, slow_only));
2888
- break;
2889
- default:
2890
- throw invalid_argument("unknown error function");
2891
- }
2892
- return solver;
2893
- }
2894
-
2895
- void fpsg_core(
2896
- Utility &util,
2897
- Scheduler &sched,
2898
- mf_problem *tr,
2899
- mf_problem *va,
2900
- mf_parameter param,
2901
- mf_float scale,
2902
- vector<BlockBase*> &block_ptrs,
2903
- vector<mf_int> &omega_p,
2904
- vector<mf_int> &omega_q,
2905
- shared_ptr<mf_model> &model,
2906
- vector<mf_int> cv_blocks,
2907
- mf_double *cv_error)
2908
- {
2909
- #if defined USESSE || defined USEAVX
2910
- auto flush_zero_mode = _MM_GET_FLUSH_ZERO_MODE();
2911
- _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);
2912
- #endif
2913
- if(tr->nnz == 0)
2914
- {
2915
- cout << "warning: train on an empty training set" << endl;
2916
- return;
2917
- }
2918
-
2919
- if(param.fun == P_L2_MFR ||
2920
- param.fun == P_L1_MFR ||
2921
- param.fun == P_KL_MFR)
2922
- {
2923
- switch(param.fun)
2924
- {
2925
- case P_L2_MFR:
2926
- param.lambda_p2 /= scale;
2927
- param.lambda_q2 /= scale;
2928
- param.lambda_p1 /= (mf_float)pow(scale, 1.5);
2929
- param.lambda_q1 /= (mf_float)pow(scale, 1.5);
2930
- break;
2931
- case P_L1_MFR:
2932
- case P_KL_MFR:
2933
- param.lambda_p1 /= sqrt(scale);
2934
- param.lambda_q1 /= sqrt(scale);
2935
- break;
2936
- }
2937
- }
2938
-
2939
- if(!param.quiet)
2940
- {
2941
- cout.width(4);
2942
- cout << "iter";
2943
- cout.width(13);
2944
- cout << "tr_"+util.get_error_legend();
2945
- if(va->nnz != 0)
2946
- {
2947
- cout.width(13);
2948
- cout << "va_"+util.get_error_legend();
2949
- }
2950
- cout.width(13);
2951
- cout << "obj";
2952
- cout << "\n";
2953
- }
2954
-
2955
- bool slow_only = param.lambda_p1 == 0 && param.lambda_q1 == 0? true: false;
2956
- vector<mf_float> PG(model->m*2, 1), QG(model->n*2, 1);
2957
-
2958
- vector<shared_ptr<SolverBase>> solvers(param.nr_threads);
2959
- vector<thread> threads;
2960
- threads.reserve(param.nr_threads);
2961
- for(mf_int i = 0; i < param.nr_threads; ++i)
2962
- {
2963
- solvers[i] = SolverFactory::get_solver(sched, block_ptrs,
2964
- PG.data(), QG.data(),
2965
- *model, param, slow_only);
2966
- threads.emplace_back(&SolverBase::run, solvers[i].get());
2967
- }
2968
-
2969
- for(mf_int iter = 0; iter < param.nr_iters; ++iter)
2970
- {
2971
- sched.wait_for_jobs_done();
2972
-
2973
- if(!param.quiet)
2974
- {
2975
- mf_double reg = 0;
2976
- mf_double reg1 = util.calc_reg1(*model, param.lambda_p1,
2977
- param.lambda_q1, omega_p, omega_q);
2978
- mf_double reg2 = util.calc_reg2(*model, param.lambda_p2,
2979
- param.lambda_q2, omega_p, omega_q);
2980
- mf_double tr_loss = sched.get_loss();
2981
- mf_double tr_error = sched.get_error()/tr->nnz;
2982
-
2983
- switch(param.fun)
2984
- {
2985
- case P_L2_MFR:
2986
- reg = (reg1+reg2)*scale*scale;
2987
- tr_loss *= scale*scale;
2988
- tr_error = sqrt(tr_error*scale*scale);
2989
- break;
2990
- case P_L1_MFR:
2991
- case P_KL_MFR:
2992
- reg = (reg1+reg2)*scale;
2993
- tr_loss *= scale;
2994
- tr_error *= scale;
2995
- break;
2996
- default:
2997
- reg = reg1+reg2;
2998
- break;
2999
- }
3000
-
3001
- cout.width(4);
3002
- cout << iter;
3003
- cout.width(13);
3004
- cout << fixed << setprecision(4) << tr_error;
3005
- if(va->nnz != 0)
3006
- {
3007
- Block va_block(va->R, va->R+va->nnz);
3008
- vector<BlockBase*> va_blocks(1, &va_block);
3009
- vector<mf_int> va_block_ids(1, 0);
3010
- mf_double va_error =
3011
- util.calc_error(va_blocks, va_block_ids, *model)/va->nnz;
3012
- switch(param.fun)
3013
- {
3014
- case P_L2_MFR:
3015
- va_error = sqrt(va_error*scale*scale);
3016
- break;
3017
- case P_L1_MFR:
3018
- case P_KL_MFR:
3019
- va_error *= scale;
3020
- break;
3021
- }
3022
-
3023
- cout.width(13);
3024
- cout << fixed << setprecision(4) << va_error;
3025
- }
3026
- cout.width(13);
3027
- cout << fixed << setprecision(4) << scientific << reg+tr_loss;
3028
- cout << "\n" << flush;
3029
- }
3030
-
3031
- if(iter == 0)
3032
- slow_only = false;
3033
- if(iter == param.nr_iters - 1)
3034
- sched.terminate();
3035
- sched.resume();
3036
- }
3037
-
3038
- for(auto &thread : threads)
3039
- thread.join();
3040
-
3041
- if(cv_error != nullptr && cv_blocks.size() > 0)
3042
- {
3043
- mf_long cv_count = 0;
3044
- for(auto block : cv_blocks)
3045
- cv_count += block_ptrs[block]->get_nnz();
3046
-
3047
- *cv_error = util.calc_error(block_ptrs, cv_blocks, *model)/cv_count;
3048
-
3049
- switch(param.fun)
3050
- {
3051
- case P_L2_MFR:
3052
- *cv_error = sqrt(*cv_error*scale*scale);
3053
- break;
3054
- case P_L1_MFR:
3055
- case P_KL_MFR:
3056
- *cv_error *= scale;
3057
- break;
3058
- }
3059
- }
3060
-
3061
- #if defined USESSE || defined USEAVX
3062
- _MM_SET_FLUSH_ZERO_MODE(flush_zero_mode);
3063
- #endif
3064
- }
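The scale-related multiplications in the progress report above undo the data scaling performed by the callers below: fpsg and fpsg_on_disk divide all ratings by scale (util.scale_problem(..., 1.0/scale)) before training, and the regularization coefficients are pre-divided accordingly at the top of fpsg_core. If \(\bar e\) is the mean per-entry loss measured on the scaled data and s the scale factor, the printed errors are

```latex
\[
\text{L2:}\quad \sqrt{\bar e \cdot s^2} = s\sqrt{\bar e},
\qquad\qquad
\text{L1, KL:}\quad s\,\bar e
\]
```

and the training loss and regularization terms are rescaled correspondingly (by s² for L2, by s for L1/KL) before being added into the reported objective.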
3065
-
3066
- shared_ptr<mf_model> fpsg(
3067
- mf_problem const *tr_,
3068
- mf_problem const *va_,
3069
- mf_parameter param,
3070
- vector<mf_int> cv_blocks=vector<mf_int>(),
3071
- mf_double *cv_error=nullptr)
3072
- {
3073
- shared_ptr<mf_model> model;
3074
- try
3075
- {
3076
- Utility util(param.fun, param.nr_threads);
3077
- Scheduler sched(param.nr_bins, param.nr_threads, cv_blocks);
3078
- shared_ptr<mf_problem> tr;
3079
- shared_ptr<mf_problem> va;
3080
- vector<Block> blocks(param.nr_bins*param.nr_bins);
3081
- vector<BlockBase*> block_ptrs(param.nr_bins*param.nr_bins);
3082
- vector<mf_node*> ptrs;
3083
- vector<mf_int> p_map;
3084
- vector<mf_int> q_map;
3085
- vector<mf_int> inv_p_map;
3086
- vector<mf_int> inv_q_map;
3087
- vector<mf_int> omega_p;
3088
- vector<mf_int> omega_q;
3089
- mf_float avg = 0;
3090
- mf_float std_dev = 0;
3091
- mf_float scale = 1;
3092
-
3093
- if(param.copy_data)
3094
- {
3095
- tr = shared_ptr<mf_problem>(
3096
- Utility::copy_problem(tr_, true), deleter());
3097
- va = shared_ptr<mf_problem>(
3098
- Utility::copy_problem(va_, true), deleter());
3099
- }
3100
- else
3101
- {
3102
- tr = shared_ptr<mf_problem>(Utility::copy_problem(tr_, false));
3103
- va = shared_ptr<mf_problem>(Utility::copy_problem(va_, false));
3104
- }
3105
-
3106
- util.collect_info(*tr, avg, std_dev);
3107
-
3108
- if(param.fun == P_L2_MFR ||
3109
- param.fun == P_L1_MFR ||
3110
- param.fun == P_KL_MFR)
3111
- scale = max((mf_float)1e-4, std_dev);
3112
-
3113
- p_map = Utility::gen_random_map(tr->m);
3114
- q_map = Utility::gen_random_map(tr->n);
3115
- inv_p_map = Utility::gen_inv_map(p_map);
3116
- inv_q_map = Utility::gen_inv_map(q_map);
3117
- omega_p = vector<mf_int>(tr->m, 0);
3118
- omega_q = vector<mf_int>(tr->n, 0);
3119
-
3120
- util.shuffle_problem(*tr, p_map, q_map);
3121
- util.shuffle_problem(*va, p_map, q_map);
3122
- util.scale_problem(*tr, (mf_float)1.0/scale);
3123
- util.scale_problem(*va, (mf_float)1.0/scale);
3124
- ptrs = util.grid_problem(*tr, param.nr_bins, omega_p, omega_q, blocks);
3125
-
3126
- model = shared_ptr<mf_model>(Utility::init_model(param.fun,
3127
- tr->m, tr->n, param.k, avg/scale, omega_p, omega_q),
3128
- [] (mf_model *ptr) { mf_destroy_model(&ptr); });
3129
-
3130
- for(mf_int i = 0; i < (mf_long)blocks.size(); ++i)
3131
- block_ptrs[i] = &blocks[i];
3132
-
3133
- fpsg_core(util, sched, tr.get(), va.get(), param, scale,
3134
- block_ptrs, omega_p, omega_q, model, cv_blocks, cv_error);
3135
-
3136
- if(!param.copy_data)
3137
- {
3138
- util.scale_problem(*tr, scale);
3139
- util.scale_problem(*va, scale);
3140
- util.shuffle_problem(*tr, inv_p_map, inv_q_map);
3141
- util.shuffle_problem(*va, inv_p_map, inv_q_map);
3142
- }
3143
-
3144
- util.scale_model(*model, scale);
3145
- Utility::shrink_model(*model, param.k);
3146
- Utility::shuffle_model(*model, inv_p_map, inv_q_map);
3147
- }
3148
- catch(exception const &e)
3149
- {
3150
- cerr << e.what() << endl;
3151
- throw;
3152
- }
3153
- return model;
3154
- }
3155
-
3156
- shared_ptr<mf_model> fpsg_on_disk(
3157
- const string tr_path,
3158
- const string va_path,
3159
- mf_parameter param,
3160
- vector<mf_int> cv_blocks=vector<mf_int>(),
3161
- mf_double *cv_error=nullptr)
3162
- {
3163
- shared_ptr<mf_model> model;
3164
- try
3165
- {
3166
- Utility util(param.fun, param.nr_threads);
3167
- Scheduler sched(param.nr_bins, param.nr_threads, cv_blocks);
3168
- mf_problem tr = {};
3169
- mf_problem va = read_problem(va_path.c_str());
3170
- vector<BlockOnDisk> blocks(param.nr_bins*param.nr_bins);
3171
- vector<BlockBase*> block_ptrs(param.nr_bins*param.nr_bins);
3172
- vector<mf_int> p_map;
3173
- vector<mf_int> q_map;
3174
- vector<mf_int> inv_p_map;
3175
- vector<mf_int> inv_q_map;
3176
- vector<mf_int> omega_p;
3177
- vector<mf_int> omega_q;
3178
- mf_float avg = 0;
3179
- mf_float std_dev = 0;
3180
- mf_float scale = 1;
3181
-
3182
- util.collect_info_on_disk(tr_path, tr, avg, std_dev);
3183
-
3184
- if(param.fun == P_L2_MFR ||
3185
- param.fun == P_L1_MFR ||
3186
- param.fun == P_KL_MFR)
3187
- scale = max((mf_float)1e-4, std_dev);
3188
-
3189
- p_map = Utility::gen_random_map(tr.m);
3190
- q_map = Utility::gen_random_map(tr.n);
3191
- inv_p_map = Utility::gen_inv_map(p_map);
3192
- inv_q_map = Utility::gen_inv_map(q_map);
3193
- omega_p = vector<mf_int>(tr.m, 0);
3194
- omega_q = vector<mf_int>(tr.n, 0);
3195
-
3196
- util.shuffle_problem(va, p_map, q_map);
3197
- util.scale_problem(va, (mf_float)1.0/scale);
3198
-
3199
- util.grid_shuffle_scale_problem_on_disk(
3200
- tr.m, tr.n, param.nr_bins, scale, tr_path,
3201
- p_map, q_map, omega_p, omega_q, blocks);
3202
-
3203
- model = shared_ptr<mf_model>(Utility::init_model(param.fun,
3204
- tr.m, tr.n, param.k, avg/scale, omega_p, omega_q),
3205
- [] (mf_model *ptr) { mf_destroy_model(&ptr); });
3206
-
3207
- for(mf_int i = 0; i < (mf_long)blocks.size(); ++i)
3208
- block_ptrs[i] = &blocks[i];
3209
-
3210
- fpsg_core(util, sched, &tr, &va, param, scale,
3211
- block_ptrs, omega_p, omega_q, model, cv_blocks, cv_error);
3212
-
3213
- delete [] va.R;
3214
-
3215
- util.scale_model(*model, scale);
3216
- Utility::shrink_model(*model, param.k);
3217
- Utility::shuffle_model(*model, inv_p_map, inv_q_map);
3218
- }
3219
- catch(exception const &e)
3220
- {
3221
- cerr << e.what() << endl;
3222
- throw;
3223
- }
3224
- return model;
3225
- }
3226
-
3227
- // The function implements an efficient method to compute objective function
3228
- // minimized by coordinate descent method.
3229
- //
3230
- // \min_{P, Q} 0.5 * \sum_{(u,v)\in\Omega^+} (1-r_{u,v})^2 +
3231
- // 0.5 * \alpha \sum_{(u,v)\not\in\Omega^+} (c-r_{u,v})^2 +
3232
- // 0.5 * \lambda_p2 * ||P||_F^2 + 0.5 * \lambda_q2 * ||Q||_F^2
3233
- // where
3234
- // 1. (u,v) is a tuple of row index and column index,
3235
- // 2. \Omega^+ is a collection of (u,v) pairs specifying the locations of
3236
- // positive entries in the training matrix.
3237
- // 3. r_{u,v} is the predicted rating at (u,v)
3238
- // 4. \alpha is the weight of negative entries' loss.
3239
- // 5. c is the desired value of every negative entry.
3240
- // 6. ||P||_F is matrix P's Frobenius norm.
3241
- // 7. \lambda_p2 is the regularization coefficient of P.
3242
- //
3243
- // Note that coordinate descent method's P and Q are the transpose
3244
- // counterparts of P and Q in the stochastic gradient method. Let R denote
3245
- // the training matrix. For stochastic gradient method, we have R ~ P^TQ.
3246
- // For coordinate descent method, we have R ~ PQ^T.
3247
- void calc_ccd_one_class_obj(const mf_int nr_threads,
3248
- const mf_float alpha, const mf_float c,
3249
- const mf_int m, const mf_int n, const mf_int d,
3250
- const mf_float lambda_p2, const mf_float lambda_q2,
3251
- const mf_float *P, const mf_float *Q,
3252
- shared_ptr<const mf_problem> data,
3253
- /*output*/ mf_double &obj,
3254
- /*output*/ mf_double &positive_loss,
3255
- /*output*/ mf_double &negative_loss,
3256
- /*output*/ mf_double &reg)
3257
- {
3258
- // Declare regularization term of P.
3259
- mf_double p_square_norm = 0.0;
3260
- // Reduce P along column axis, which is the sum of rows in P.
3261
- vector<mf_double> all_p_sum(d, 0.0);
3262
- // Compute square of Frobenius norm on P and sum of all rows in P.
3263
- for(mf_int k = 0; k < d; ++k)
3264
- {
3265
- // Declare a temporary buffer for all_p_sum[k] when using OpenMP.
3266
- mf_double all_p_sum_k = 0.0;
3267
- #if defined USEOMP
3268
- #pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:p_square_norm,all_p_sum_k)
3269
- #endif
3270
- for(mf_int u = 0; u < m; ++u)
3271
- {
3272
- const mf_float &p_ku = P[u + k * m];
3273
- p_square_norm += p_ku * p_ku;
3274
- all_p_sum_k += p_ku;
3275
- }
3276
- all_p_sum[k] = all_p_sum_k;
3277
- }
3278
-
3279
- // Declare regularization term of Q
3280
- mf_double q_square_norm = 0.0;
3281
- // Reduce Q along column axis, which is the sum of rows in Q.
3282
- vector<mf_double> all_q_sum(d, 0.0);
3283
- // Compute square of Frobenius norm on Q and sum of all elements in Q
3284
- for(mf_int k = 0; k < d; ++k)
3285
- {
3286
- // Declare a temporary buffer for all_q_sum[k] when using OpenMP.
3287
- mf_double all_q_sum_k = 0.0;
3288
- #if defined USEOMP
3289
- #pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:q_square_norm,all_q_sum_k)
3290
- #endif
3291
- for(mf_int v = 0; v < n; ++v)
3292
- {
3293
- const mf_float &q_kv = Q[v + k * n];
3294
- q_square_norm += q_kv * q_kv;
3295
- all_q_sum_k += q_kv;
3296
- }
3297
- all_q_sum[k] = all_q_sum_k;
3298
- }
3299
-
3300
- // PTP = P^T * P, where P^T is the transpose of P. Note that P is an m-by-d
3301
- // matrix and PTP is a d-by-d matrix.
3302
- vector<mf_double> PTP(d * d, 0.0);
3303
- // QTQ = Q^T * Q, a d-by-d matrix.
3304
- vector<mf_double> QTQ(d * d, 0.0);
3305
- // We calculate PTP and QTQ because they are needed in the computation of
3306
- // negative entries' loss function.
3307
- for(mf_int k1 = 0; k1 < d; ++k1)
3308
- {
3309
- for(mf_int k2 = 0; k2 < d; ++k2)
3310
- {
3311
- // Inner product of the k1 and k2 columns in P, an m-by-d matrix.
3312
- mf_double p_k1_p_k2_inner_product = 0.0;
3313
- #if defined USEOMP
3314
- #pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:p_k1_p_k2_inner_product)
3315
- #endif
3316
- for(mf_int u = 0; u < m; ++u)
3317
- p_k1_p_k2_inner_product += P[u + k1 * m] * P[u + k2 * m];
3318
- PTP[k1 * d + k2] = p_k1_p_k2_inner_product;
3319
-
3320
- // Inner product of the k1 and k2 columns in Q, an n-by-d matrix.
3321
- mf_double q_k1_q_k2_inner_product = 0.0;
3322
- #if defined USEOMP
3323
- #pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:q_k1_q_k2_inner_product)
3324
- #endif
3325
- for(mf_int v = 0; v < n; ++v)
3326
- q_k1_q_k2_inner_product += Q[v + k1 * n] * Q[v + k2 * n];
3327
- QTQ[k1 * d + k2] = q_k1_q_k2_inner_product;
3328
- }
3329
- }
3330
-
3331
- // Initialize loss function value of positive matrix entries.
3332
- // It consists of two parts. The first part is the true prediction error
3333
- // while the second part is only used for implementing a faster algorithm.
3334
- mf_double positive_loss1 = 0.0;
3335
- mf_double positive_loss2 = 0.0;
3336
- // Scan through positive matrix entries to compute their loss values.
3337
- // Notice that we assume that positive entries' values are all one.
3338
- #if defined USEOMP
3339
- #pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:positive_loss1,positive_loss2)
3340
- #endif
3341
- for(mf_long i = 0; i < data->nnz; ++i)
3342
- {
3343
- const mf_double &r = data->R[i].r;
3344
- positive_loss1 += (1.0 - r) * (1.0 - r);
3345
- positive_loss2 -= alpha * (c - r) * (c - r);
3346
- }
3347
- positive_loss1 *= 0.5;
3348
- positive_loss2 *= 0.5;
3349
-
3350
- // Declare loss terms related to negative matrix entries.
3351
- mf_double negative_loss1 = c * c * m * n;
3352
- mf_double negative_loss2 = 0.0;
3353
- mf_double negative_loss3 = 0.0;
3354
- // Compute loss terms.
3355
- for(mf_int k1 = 0; k1 < d; ++k1)
3356
- {
3357
- negative_loss2 += all_p_sum[k1] * all_q_sum[k1];
3358
- for(mf_int k2 = 0; k2 < d; ++k2)
3359
- negative_loss3 += PTP[k1 + k2 * d] * QTQ[k2 + k1 * d];
3360
- }
3361
- // Compute the loss function of negative matrix entries.
3362
- mf_double negative_loss4 = 0.5 * alpha *
3363
- (negative_loss1 - 2 * c * negative_loss2 + negative_loss3);
3364
-
3365
- // Assign results to output variables.
3366
- reg = 0.5 * lambda_p2 * p_square_norm + 0.5 * lambda_q2 * q_square_norm;
3367
-
3368
- // The function minimized by coordinate descent method.
3369
- obj = positive_loss1 + positive_loss2 + negative_loss4 + reg;
3370
-
3371
- // Sum of squared errors over positive matrix entries (i.e., the mf_node's
3372
- // in data).
3373
- positive_loss = positive_loss1;
3374
-
3375
- // Sum of squared errors over negative matrix entries (i.e., entries not
3376
- // in data). The value negative_loss4 contains the squared errors by
3377
- // considering positive entries as negative entries, so positive_loss2 is
3378
- // added to compensate that.
3379
- negative_loss = negative_loss4 + positive_loss2;
3380
- }
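The reason calc_ccd_one_class_obj never loops over all m*n cells: with \(\hat r_{uv} = \sum_k P_{uk} Q_{vk}\) cached in the data nodes, the squared residual summed over the full matrix factors into exactly the quantities the code accumulates (all_p_sum/all_q_sum, PTP and QTQ):

```latex
\begin{gather*}
\sum_{u=1}^{m}\sum_{v=1}^{n}\bigl(c-\hat r_{uv}\bigr)^2
= \underbrace{c^2 m n}_{\texttt{negative\_loss1}}
- 2c\,\underbrace{\sum_{k}\Bigl(\sum_{u}P_{uk}\Bigr)\Bigl(\sum_{v}Q_{vk}\Bigr)}_{\texttt{negative\_loss2}}
+ \underbrace{\sum_{k_1,k_2}\bigl(P^{\mathsf T}P\bigr)_{k_1k_2}\bigl(Q^{\mathsf T}Q\bigr)_{k_1k_2}}_{\texttt{negative\_loss3}}
\end{gather*}
```

negative_loss4 is \(\alpha/2\) times this full-matrix sum; adding positive_loss2 \(= -(\alpha/2)\sum_{(u,v)\in\Omega^+}(c-\hat r_{uv})^2\) removes the cells that are actually positive, leaving the loss over the true negatives at a cost of roughly O(nnz + (m+n)d²) instead of O(mnd).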
3381
-
3382
- void ccd_one_class_core(
3383
- const Utility &util,
3384
- shared_ptr<const mf_problem> tr_csr,
3385
- shared_ptr<const mf_problem> tr_csc,
3386
- shared_ptr<const mf_problem> va,
3387
- const mf_parameter param,
3388
- const vector<mf_node*> &ptrs_u,
3389
- const vector<mf_node*> &ptrs_v,
3390
- /*output*/ shared_ptr<mf_model> &model)
3391
- {
3392
- // Check problems stored in CSR and CSC formats
3393
- if(tr_csr == nullptr) throw invalid_argument("CSR problem pointer is null.");
3394
- if(tr_csc == nullptr) throw invalid_argument("CSC problem pointer is null.");
3395
-
3396
- if(tr_csr->m != tr_csc->m)
3397
- throw logic_error(
3398
- "Row counts must be identical in CSR and CSC formats: " +
3399
- to_string(tr_csr->m) + " != " + to_string(tr_csc->m));
3400
- const mf_int m = tr_csr->m;
3401
-
3402
- if(tr_csr->n != tr_csc->n)
3403
- throw logic_error(
3404
- "Column counts must be identical in CSR and CSC formats: " +
3405
- to_string(tr_csr->n) + " != " + to_string(tr_csc->n));
3406
- const mf_int n = tr_csr->n;
3407
-
3408
- if(tr_csr->nnz != tr_csc->nnz)
3409
- throw logic_error(
3410
- "Numbers of data points must be identical in CSR and CSC formats: " +
3411
- to_string(tr_csr->nnz) + " != " + to_string(tr_csc->nnz));
3412
- const mf_long nnz = tr_csr->nnz;
3413
-
3414
- // Check formulation parameters
3415
- if(param.k <= 0)
3416
- throw invalid_argument(
3417
- "Latent dimension must be positive but got " +
3418
- to_string(param.k));
3419
- const mf_int d = param.k;
3420
-
3421
- if(param.lambda_p1 != 0)
3422
- throw invalid_argument(
3423
- "P's L1-regularization coefficient must be zero but got " +
3424
- to_string(param.lambda_p1));
3425
- if(param.lambda_q1 != 0)
3426
- throw invalid_argument(
3427
- "Q's L1-regularization coefficient must be zero but got " +
3428
- to_string(param.lambda_q1));
3429
-
3430
- if(param.lambda_p2 <= 0)
3431
- throw invalid_argument(
3432
- "P's L2-regularization coefficient must be positive but got " +
3433
- to_string(param.lambda_p2));
3434
- if(param.lambda_q2 <= 0)
3435
- throw invalid_argument(
3436
- "Q's L2-regularization coefficient must be positive but got " +
3437
- to_string(param.lambda_q2));
3438
-
3439
- // REVIEW: It is not difficult to support non-negative matrix factorization
3440
- // for coordinate descent method; we just need to project the updated value
3441
- // back to the feasible region by using max(0, new_value) right after each
3442
- // Newton step. LIBMF hasn't supported it only because we haven't seen actual
3443
- // users.
3444
- if(param.do_nmf)
3445
- throw invalid_argument(
3446
- "Coordinate descent does not support non-negative constraint");
3447
-
3448
-
3449
- // Check some resources prepared internally
3450
- if(ptrs_u.size() != (size_t)m + 1)
3451
- throw invalid_argument("Number of row pointer must be " +
3452
- to_string(m + 1) + " but got " + to_string(ptrs_u.size()));
3453
- if(ptrs_v.size() != (size_t)n + 1)
3454
- throw invalid_argument("Number of column pointer must be " +
3455
- to_string(n + 1) + " but got " + to_string(ptrs_v.size()));
3456
-
3457
- // Some constants of the formulation.
3458
- // alpha: coefficient of negative part
3459
- // c: the desired prediction values of unobserved ratings
3460
- // lambda_p2: regularization coefficient of P's L2-norm
3461
- // lambda_q2: regularization coefficient of Q's L2-norm
3462
- const mf_float alpha = param.alpha;
3463
- const mf_float c = param.c;
3464
- const mf_float lambda_p2 = param.lambda_p2;
3465
- const mf_float lambda_q2 = param.lambda_q2;
3466
-
3467
- // Initialize P and Q. Note that \bar{q}_{kv} is Q[k*n+v]
3468
- // and \bar{p}_{ku} is P[k*m+u]. One may notice that P and
3469
- // Q here are actually the transposes of P and Q in FPSG.
3470
- mf_float *P = model->P;
3471
- mf_float *Q = model->Q;
3472
-
3473
- // Cache the prediction values on positive matrix entries.
3474
- // Given that P is zero-initialized and Q is randomly initialized in
3475
- // Utility::init_model(mf_int m, mf_int n, mf_int k),
3476
- // all predictions are zeros.
3477
- #if defined USEOMP
3478
- #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
3479
- #endif
3480
- for(mf_long i = 0; i < nnz; ++i)
3481
- {
3482
- tr_csr->R[i].r = 0.0;
3483
- tr_csc->R[i].r = 0.0;
3484
- }
3485
-
3486
- // If the model is not initialized by
3487
- // Utility::init_model(mf_int m, mf_int n, mf_int k),
3488
- // please use the following initialization code to compute
3489
- // and cache all prediction values on positive entries.
3490
- /*
3491
- for(mf_long i = 0; i < nnz; ++i)
3492
- {
3493
- mf_node &node = tr_csr->R[i];
3494
- node.r = 0;
3495
- for(mf_int k = 0; k < d; ++k)
3496
- node.r += P[node.u + k * m]*Q[node.v + k * n];
3497
- }
3498
- for(mf_long i = 0; i < nnz; ++i)
3499
- {
3500
- mf_node &node = tr_csc->R[i];
3501
- node.r = 0;
3502
- for(mf_int k = 0; k < d; ++k)
3503
- node.r += P[node.u + k * m]*Q[node.v + k * n];
3504
- }
3505
- */
3506
-
3507
- if(!param.quiet)
3508
- {
3509
- cout.width(4);
3510
- cout << "iter";
3511
- cout.width(13);
3512
- cout << "tr_"+util.get_error_legend();
3513
- cout.width(14);
3514
- cout << "tr_"+util.get_error_legend() << "+";
3515
- cout.width(14);
3516
- cout << "tr_"+util.get_error_legend() << "-";
3517
- if(va->nnz != 0)
3518
- {
3519
- cout.width(13);
3520
- cout << "va_"+util.get_error_legend();
3521
- cout.width(14);
3522
- cout << "va_"+util.get_error_legend() << "+";
3523
- cout.width(14);
3524
- cout << "va_"+util.get_error_legend() << "-";
3525
- }
3526
- cout.width(13);
3527
- cout << "obj";
3528
- cout << "\n";
3529
- }
3530
-
3531
- /////////////////////////////////////////////////////////////////
3532
- // Minimize the objective function via coordinate descent method
3533
- ////////////////////////////////////////////////////////////////
3534
- // Solve P and Q using coordinate descent.
3535
- // P = [\bar{p}_1, ..., \bar{p}_d] \in R^{m \times d}
3536
- // Q = [\bar{q}_1, ..., \bar{q}_d] \in R^{n \times d}
3537
- // Finally, the rating matrix R is approximated via
3538
- // R ~ PQ^T \in R^{m \times n}
3539
- for(mf_int outer = 0; outer < param.nr_iters; ++outer)
3540
- {
3541
- // Update \bar{p}_k and \bar{q}_k. The basic idea is
3542
- // to replace \bar{p}_k and \bar{q}_k with a and b,
3543
- // and then minimize the original objective function.
3544
- for(mf_int k = 0; k < d; ++k)
3545
- {
3546
- // Get the pointer to the first element of \bar{p}_k (and
3547
- // \bar{q}_k).
3548
- mf_float *P_k = P + m * k;
3549
- mf_float *Q_k = Q + n * k;
3550
-
3551
- // Initialize a and b with the values they will replace
3552
- // so that we can ensure improvement at each iteration.
3553
- vector<mf_float> a(P_k, P_k + m);
3554
- vector<mf_float> b(Q_k, Q_k + n);
3555
-
3556
- for(mf_int inner = 0; inner < 3; ++inner)
3557
- {
3558
- ///////////////////////////////////////////////////////////////
3559
- // Update a:
3560
- // 1. Compute and cache constants
3561
- // 2. For each coordinate of a, calculate optimal update using
3562
- // Newton method
3563
- ///////////////////////////////////////////////////////////////
3564
-
3565
- // Compute and cache constants
3566
- // \hat{b} = \sum_{v=1}^n \bar{b}_v
3567
- // \tilde{b} = \sum_{v=1}^n \bar{b}_v^2
3568
- mf_double b_hat = 0.0;
3569
- mf_double b_tilde = 0.0;
3570
- #if defined USEOMP
3571
- #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static) reduction(+:b_hat,b_tilde)
3572
- #endif
3573
- for(mf_int v = 0; v < n; ++v)
3574
- {
3575
- const mf_double &b_v = b[v];
3576
- b_hat += b_v;
3577
- b_tilde += b_v * b_v;
3578
- }
3579
-
3580
- // Compute and cache a constant vector
3581
- // s_k = \sum_{v=1}^n \bar{q}_{kv}b_v, k = 1, ..., d
3582
- vector<mf_double> s(d, 0.0);
3583
- for(mf_int k1 = 0; k1 < d; ++k1)
3584
- {
3585
- // Buffer variable for using OpenMP
3586
- mf_double s_k1 = 0;
3587
- const mf_float *Q_k1 = Q + k1 * n;
3588
- #if defined USEOMP
3589
- #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static) reduction(+:s_k1)
3590
- #endif
3591
- for(mf_int v = 0; v < n; ++v)
3592
- s_k1 += Q_k1[v] * b[v];
3593
- s[k1] = s_k1;
3594
- }
3595
-
3596
- // Solve a's sub-problem
3597
- #if defined USEOMP
3598
- #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
3599
- #endif
3600
- for(mf_int u = 0; u < m; ++u)
3601
- {
3602
- ////////////////////////////////////////////////////////
3603
- // Update a[u] via Newton method. Let g_u and h_u denote
3604
- // the first-order and second-order derivatives w.r.t.
3605
- // a[u]. The following code implements
3606
- // a[u] <-- a[u] - g_u/h_u
3607
- ////////////////////////////////////////////////////////
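// Written out (a reading aid derived from the grouped terms computed below,
// with Omega_u denoting the columns observed in row u and \hat{r}_{uv}
// evaluated with a[u] substituted in dimension k):
//   g_u = \sum_{v in Omega_u} (\hat{r}_{uv} - 1) * b_v
//       + alpha * \sum_{v not in Omega_u} (\hat{r}_{uv} - c) * b_v
//       + lambda_p2 * a[u]
//   h_u = \sum_{v in Omega_u} b_v^2
//       + alpha * \sum_{v not in Omega_u} b_v^2 + lambda_p2
// The terms g_u_1, g_u_2, and g_u_3 below regroup these sums so that the
// complement (unobserved columns) never has to be enumerated explicitly.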
3608
-
3609
- // Initialize temporary variables for calculating the gradient and Hessian.
3610
- mf_double g_u_1 = 0.0;
3611
- mf_double h_u_1 = 0.0;
3612
- mf_double g_u_2 = 0.0;
3613
- // Scan through all positive entries in the u-th row
3614
- for(const mf_node *ptr = ptrs_u[u]; ptr != ptrs_u[u+1]; ++ptr)
3615
- {
3616
- const mf_int &v = ptr->v;
3617
- const mf_float &b_v = b[v];
3618
- g_u_1 += b_v;
3619
- h_u_1 += b_v * b_v;
3620
- g_u_2 += (ptr->r - P_k[u] * Q_k[v] + a[u] * b_v) * b_v;
3621
- }
3622
- mf_double g_u_3 = -c * b_hat - P_k[u] * s[k] + a[u] * b_tilde;
3623
- for(mf_int k1 = 0; k1 < d; ++k1)
3624
- g_u_3 += P[m * k1 + u] * s[k1];
3625
- mf_double g_u = -(1.0 - alpha * c) * g_u_1 + (1.0 - alpha) * g_u_2 + alpha * g_u_3 + lambda_p2 * a[u];
3626
- mf_double h_u = (1.0 - alpha) * h_u_1 + alpha * b_tilde + lambda_p2;
3627
- a[u] -= static_cast<mf_float>(g_u / h_u);
3628
- }
3629
-
3630
- ///////////////////////////////////////////////////////////////
3631
- // Update b:
3632
- // 1. Compute and cache constants
3633
- // 2. For each coordinate of b, calculate optimal update using
3634
- // Newton method
3635
- ///////////////////////////////////////////////////////////////
3636
- // Compute and cache a_hat, a_tilde
3637
- // \hat{a} = \sum_{u=1}^m \bar{a}_u
3638
- // \tilde{a} = \sum_{u=1}^m \bar{a}_u^2
3639
- mf_double a_hat = 0.0;
3640
- mf_double a_tilde = 0.0;
3641
- #if defined USEOMP
3642
- #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static) reduction(+:a_hat,a_tilde)
3643
- #endif
3644
- for(mf_int u = 0; u < m; ++u)
3645
- {
3646
- const mf_float &a_u = a[u];
3647
- a_hat += a_u;
3648
- a_tilde += a_u * a_u;
3649
- }
3650
-
3651
- // Compute and cache t
3652
- // t_k = \sum_{u=1}^m \bar{p}_{ku}a_u, k = 1, ..., d
3653
- vector<mf_double> t(d, 0.0);
3654
- for(mf_int k1 = 0; k1 < d; ++k1)
3655
- {
3656
- // Declare buffer variable for using OpenMP
3657
- mf_double t_k1 = 0;
3658
- const mf_float *P_k1 = P + k1 * m;
3659
- #if defined USEOMP
3660
- #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static) reduction(+:t_k1)
3661
- #endif
3662
- for(mf_int u = 0; u < m; ++u)
3663
- t_k1 += P_k1[u] * a[u];
3664
- t[k1] = t_k1;
3665
- }
3666
-
3667
- #if defined USEOMP
3668
- #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
3669
- #endif
3670
- for(mf_int v = 0; v < n; ++v)
3671
- {
3672
- ////////////////////////////////////////////////////////
3673
- // Update b[v] via Newton method. Let g_v and h_v denote
3674
- // the first-order and second-order derivatives w.r.t.
3675
- // b[v]. The following code implements
3676
- // b[v] <-- b[v] - g_v/h_v
3677
- ////////////////////////////////////////////////////////
3678
-
3679
- // Initialize temporary variables for calculating the gradient and Hessian.
3680
- mf_double g_v_1 = 0;
3681
- mf_double g_v_2 = 0;
3682
- mf_double h_v_1 = 0;
3683
- // Scan through all positive entries at column v
3684
- for(const mf_node *ptr = ptrs_v[v]; ptr != ptrs_v[v+1]; ++ptr)
3685
- {
3686
- const mf_int &u = ptr->u;
3687
- const mf_float &a_u = a[u];
3688
- g_v_1 += a_u;
3689
- h_v_1 += a_u * a_u;
3690
- g_v_2 += (ptr->r - P_k[u] * Q_k[v] + a_u * b[v]) * a_u;
3691
- }
3692
- mf_double g_v_3 = -c * a_hat - Q_k[v] * t[k] + b[v] * a_tilde;
3693
- for(mf_int k1 = 0; k1 < d; ++k1)
3694
- g_v_3 += Q[n * k1 + v] * t[k1];
3695
- mf_double g_v = -(1.0 - alpha * c) * g_v_1 + (1.0 - alpha) * g_v_2 +
3696
- alpha * g_v_3 + lambda_q2 * b[v];
3697
- mf_double h_v = (1 - alpha) * h_v_1 + alpha * a_tilde + lambda_q2;
3698
- b[v] -= static_cast<mf_float>(g_v / h_v);
3699
- }
3700
-
3701
- ///////////////////////////////////////////////////////////////
3702
- // Update cached variables.
3703
- ///////////////////////////////////////////////////////////////
3704
- // Update prediction error in CSR format
3705
- // \bar{r}_{uv} <- \bar{r}_{uv} - \bar{p}_{ku}*\bar{q}_{kv} + a_u*b_v
3706
- #if defined USEOMP
3707
- #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
3708
- #endif
3709
- for(mf_long i = 0; i < tr_csr->nnz; ++i)
3710
- {
3711
- // Update prediction values of positive entries in CSR
3712
- mf_node *csr_ptr = tr_csr->R + i;
3713
- const mf_int &u_csr = csr_ptr->u;
3714
- const mf_int &v_csr = csr_ptr->v;
3715
- csr_ptr->r += a[u_csr] * b[v_csr] - P_k[u_csr] * Q_k[v_csr];
3716
-
3717
- // Update prediction values of positive entries in CSC
3718
- mf_node *csc_ptr = tr_csc->R + i;
3719
- const mf_int &u_csc = csc_ptr->u;
3720
- const mf_int &v_csc = csc_ptr->v;
3721
- csc_ptr->r += a[u_csc] * b[v_csc] - P_k[u_csc] * Q_k[v_csc];
3722
- }
3723
-
3724
- #if defined USEOMP
3725
- #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
3726
- #endif
3727
- // Update P_k and Q_k
3728
- for(mf_int u = 0; u < m; ++u)
3729
- P_k[u] = a[u];
3730
- #if defined USEOMP
3731
- #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
3732
- #endif
3733
- for(mf_int v = 0; v < n; ++v)
3734
- Q_k[v] = b[v];
3735
- }
3736
- }
3737
-
3738
- // Skip the whole evaluation if nothing should be printed out.
3739
- if(param.quiet)
3740
- continue;
3741
-
3742
- // Declare variable for storing objective value being minimized
3743
- // by the training procedure. Note that the objective value consists
3744
- // of two parts, loss function and regularization function.
3745
- mf_double obj = 0;
3746
- // Declare variables for storing loss function's value.
3747
- mf_double positive_loss = 0; // for positive entries in training matrix.
3748
- mf_double negative_loss = 0; // for negative entries in training matrix.
3749
- // Declare variable for storing regularization function's value.
3750
- mf_double reg = 0;
3751
-
3752
- // Compute objective value, loss function value, and regularization
3753
- // function value
3754
- calc_ccd_one_class_obj(util.get_thread_number(), alpha, c, m, n, d,
3755
- lambda_p2, lambda_q2, P, Q, tr_csr,
3756
- obj, positive_loss, negative_loss, reg);
3757
-
3758
- // Print number of outer iterations.
3759
- cout.width(4);
3760
- cout << outer;
3761
- cout.width(13);
3762
- cout << fixed << setprecision(4) << positive_loss + negative_loss;
3763
- cout.width(15);
3764
- cout << fixed << setprecision(4) << positive_loss;
3765
- cout.width(15);
3766
- cout << fixed << setprecision(4) << negative_loss;
3767
-
3768
- if(va->nnz != 0)
3769
- {
3770
- // The following loop computes prediction scores on the validation set.
3771
- // Because training scores are also maintained in the coordinate descent
3772
- // framework, we don't need to recompute scores on the training set.
3773
- #if defined USEOMP
3774
- #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
3775
- #endif
3776
- for(mf_long i = 0; i < va->nnz; ++i)
3777
- {
3778
- mf_node &node = va->R[i];
3779
- node.r = 0;
3780
- for(mf_int k = 0; k < d; ++k)
3781
- node.r += P[node.u + k * m]*Q[node.v + k * n];
3782
- }
3783
-
3784
- mf_double va_obj = 0;
3785
- mf_double va_positive_loss = 0;
3786
- mf_double va_negative_loss = 0;
3787
- mf_double va_reg = 0;
3788
-
3789
- calc_ccd_one_class_obj(util.get_thread_number(), alpha, c, m, n, d,
3790
- lambda_p2, lambda_q2, P, Q, va,
3791
- va_obj, va_positive_loss, va_negative_loss, va_reg);
3792
-
3793
- cout.width(13);
3794
- cout << fixed << setprecision(4) << va_positive_loss + va_negative_loss;
3795
- cout.width(15);
3796
- cout << fixed << setprecision(4) << va_positive_loss;
3797
- cout.width(15);
3798
- cout << fixed << setprecision(4) << va_negative_loss;
3799
- }
3800
-
3801
- cout.width(13);
3802
- cout << fixed << setprecision(4) << scientific << obj;
3803
- cout << "\n" << flush;
3804
- }
3805
-
3806
- // Transpose P and Q. Note that the formats of P and Q here differ
3807
- // from those used by mf_model.
3808
-
3809
- mf_float *P_transpose = Utility::malloc_aligned_float((mf_long)m * d);
3810
- #if defined USEOMP
3811
- #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
3812
- #endif
3813
- for(mf_int u = 0; u < m; ++u)
3814
- for(mf_int k = 0; k < d; ++k)
3815
- P_transpose[k + u * d] = P[u + k * m];
3816
- Utility::free_aligned_float(P);
3817
- mf_float *Q_transpose = Utility::malloc_aligned_float((mf_long)n * d);
3818
- #if defined USEOMP
3819
- #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
3820
- #endif
3821
- for(mf_int v = 0; v < n; ++v)
3822
- for(mf_int k = 0; k < d; ++k)
3823
- Q_transpose[k + v * d] = Q[v + k * n];
3824
- Utility::free_aligned_float(Q);
3825
-
3826
- // Set the passed-in model to the result learned from the given data
3827
- // (model->P and model->Q were freed above, so they are re-assigned here)
3828
- model->m = m;
3829
- model->n = n;
3830
- model->k = d;
3831
- model->b = 0.0;
3832
- model->P = P_transpose;
3833
- model->Q = Q_transpose;
3834
- }
3835
-
3836
- shared_ptr<mf_model> ccd_one_class(
3837
- mf_problem const *tr_,
3838
- mf_problem const *va_,
3839
- mf_parameter param)
3840
- {
3841
- shared_ptr<mf_model> model;
3842
- try
3843
- {
3844
- Utility util(param.fun, param.nr_threads);
3845
- // Training matrix in compressed row format (sort nodes by user id)
3846
- shared_ptr<mf_problem> tr_csr;
3847
- // Training matrix in compressed column format (sort nodes by item id)
3848
- shared_ptr<mf_problem> tr_csc;
3849
- shared_ptr<mf_problem> va;
3850
- // In tr_csr->R, the i-th row starts at ptrs_u[i] and ends right before ptrs_u[i+1]
3851
- vector<mf_node*> ptrs_u(tr_->m + 1, nullptr);
3852
- // In tr_csc->R, the i-th column starts at ptrs_v[i] and ends right before ptrs_v[i+1]
3853
- vector<mf_node*> ptrs_v(tr_->n + 1, nullptr);
3854
-
3855
- if(param.copy_data)
3856
- {
3857
- // We need the training data in both row-major and column-major formats.
3858
- // Thus, two duplicates are made.
3859
- tr_csr = shared_ptr<mf_problem>(
3860
- Utility::copy_problem(tr_, true), deleter());
3861
- tr_csc = shared_ptr<mf_problem>(
3862
- Utility::copy_problem(tr_, true), deleter());
3863
- va = shared_ptr<mf_problem>(
3864
- Utility::copy_problem(va_, true), deleter());
3865
- }
3866
- else
3867
- {
3868
- // We need both row-major and column-major formats of the training data.
3869
- // The original data is reused as the row-major copy, so
3870
- // only one duplicate (the column-major copy) needs to be created.
3871
- tr_csr = shared_ptr<mf_problem>(Utility::copy_problem(tr_, false));
3872
- tr_csc = shared_ptr<mf_problem>(Utility::copy_problem(tr_, true));
3873
- va = shared_ptr<mf_problem>(Utility::copy_problem(va_, false));
3874
- }
3875
-
3876
- // Make the training set CSR/CSC by sorting its nodes. More specifically,
3877
- // a matrix with values sorted by row index is CSR and vice versa. We will
3878
- // compute the starting location for each row (CSR) and each column (CSC)
3879
- // later.
3880
- sort(tr_csr->R, tr_csr->R+tr_csr->nnz, sort_node_by_p());
3881
- sort(tr_csc->R, tr_csc->R+tr_csc->nnz, sort_node_by_q());
3882
-
3883
- // Save starting addresses of rows for CSR and columns for CSC.
3884
- mf_int u_current = -1;
3885
- mf_int v_current = -1;
3886
- for(mf_long i = 0; i < tr_->nnz; ++i)
3887
- {
3888
- mf_node* N = nullptr;
3889
-
3890
- // Deal with CSR format.
3891
- N = tr_csr->R + i;
3892
- // Since tr_csr has been sorted by index u, seeing a larger index
3893
- // implies a new row. Assume a node is encoded as a tuple of (u, v, r),
3894
- // where u is row index, v is column index, and r is entry value.
3895
- // The nodes in tr_csr->R could be
3896
- // (0, 1, 0.5), (0, 2, 3.7), (0, 4, -1.2), (2, 0, 1.2), (2, 4, 2.5)
3897
- // Then, we can see the first element of the 3rd row (indexed by 2)
3898
- // is (2, 0, 1.2), which is the 4th element in tr_csr->R. Note that
3899
- // we use the row pointer of the next non-empty row as the pointers
3900
- // of empty rows. That is,
3901
- // ptrs[0] = pointer of (0, 1, 0.5)
3902
- // ptrs[1] = pointer of (2, 0, 1.2) (row 1 is empty)
3903
- // ptrs[2] = pointer of (2, 0, 1.2)
3904
- if(N->u > u_current)
3905
- {
3906
- // We (if u_current != -1) have assigned starting addresses to rows
3907
- // indexed by values smaller than or equal to u_current. Thus, we
3908
- // should handle all rows indexed starting from u_current+1 to the
3909
- // seen row index N->u.
3910
- for(mf_int u_passed = u_current + 1; u_passed <= N->u; ++u_passed)
3911
- {
3912
- // i-th non-zero value's location in tr_csr is the starting
3913
- // address of u_passed-th row.
3914
- ptrs_u[u_passed] = tr_csr->R + i;
3915
- }
3916
- u_current = N->u;
3917
- }
3918
-
3919
- // Deal with CSC format
3920
- N = tr_csc->R + i;
3921
- if(N->v > v_current)
3922
- {
3923
- // We (if v_current != -1) have assigned starting addresses to columns
3924
- // indexed by values smaller than or equal to v_current. Thus, we
3925
- // should handle all columns indexed starting from v_current+1 to
3926
- // the seen column index N->v.
3927
- for(mf_int v_passed = v_current + 1; v_passed <= N->v; ++v_passed)
3928
- {
3929
- // i-th non-zero value's location in tr_csc is the starting
3930
- // address of v_passed-th column.
3931
- ptrs_v[v_passed] = tr_csc->R + i;
3932
- }
3933
- v_current = N->v;
3934
- }
3935
-
3936
- }
3937
- // The bound of the last row (and of any trailing empty rows): the address one
3938
- // element past the last matrix entry.
3939
- for(mf_int u_passed = u_current + 1; u_passed <= tr_->m; ++u_passed)
3940
- ptrs_u[u_passed] = tr_csr->R + tr_csr->nnz;
3941
- // The bound of the last column.
3942
- for(mf_int v_passed = v_current + 1; v_passed <= tr_->n; ++v_passed)
3943
- ptrs_v[v_passed] = tr_csc->R + tr_csc->nnz;
3944
-
3945
-
3946
- model = shared_ptr<mf_model>(Utility::init_model(tr_->m, tr_->n, param.k),
3947
- [] (mf_model *ptr) { mf_destroy_model(&ptr); });
3948
-
3949
- ccd_one_class_core(util, tr_csr, tr_csc, va, param, ptrs_u, ptrs_v, model);
3950
- }
3951
- catch(exception const &e)
3952
- {
3953
- cerr << e.what() << endl;
3954
- throw;
3955
- }
3956
- return model;
3957
- }
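The row-pointer construction above can be illustrated with a minimal, self-contained
sketch (the plain int/struct types, the index-based pointers, and the main() wrapper
are illustrative only). Empty rows receive the starting position of the next non-empty
row, and the final sentinel points one element past the last node:

#include <cstdio>
#include <vector>

struct Node { int u, v; float r; };

int main()
{
    // Nodes already sorted by row index, as tr_csr->R is after sort().
    std::vector<Node> R = {{0, 1, 0.5f}, {0, 2, 3.7f}, {0, 4, -1.2f},
                           {2, 0, 1.2f}, {2, 4, 2.5f}};
    int m = 3; // number of rows
    std::vector<int> ptrs(m + 1, 0); // indices instead of raw pointers

    int u_current = -1;
    for (int i = 0; i < (int)R.size(); ++i)
        if (R[i].u > u_current)
        {
            for (int u = u_current + 1; u <= R[i].u; ++u)
                ptrs[u] = i; // empty rows share the next non-empty row's start
            u_current = R[i].u;
        }
    for (int u = u_current + 1; u <= m; ++u)
        ptrs[u] = (int)R.size(); // one element past the last node

    for (int u = 0; u < m; ++u)
        std::printf("row %d: [%d, %d)\n", u, ptrs[u], ptrs[u + 1]);
    // Expected: row 0: [0, 3)   row 1: [3, 3)   row 2: [3, 5)
    return 0;
}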
3958
-
3959
- bool check_parameter(mf_parameter param)
3960
- {
3961
- if(param.fun != P_L2_MFR &&
3962
- param.fun != P_L1_MFR &&
3963
- param.fun != P_KL_MFR &&
3964
- param.fun != P_LR_MFC &&
3965
- param.fun != P_L2_MFC &&
3966
- param.fun != P_L1_MFC &&
3967
- param.fun != P_ROW_BPR_MFOC &&
3968
- param.fun != P_COL_BPR_MFOC &&
3969
- param.fun != P_L2_MFOC)
3970
- {
3971
- cerr << "unknown loss function" << endl;
3972
- return false;
3973
- }
3974
-
3975
- if(param.k < 1)
3976
- {
3977
- cerr << "number of factors must be greater than zero" << endl;
3978
- return false;
3979
- }
3980
-
3981
- if(param.nr_threads < 1)
3982
- {
3983
- cerr << "number of threads must be greater than zero" << endl;
3984
- return false;
3985
- }
3986
-
3987
- if(param.nr_bins < 1 || param.nr_bins < param.nr_threads)
3988
- {
3989
- cerr << "number of bins must be greater than number of threads"
3990
- << endl;
3991
- return false;
3992
- }
3993
-
3994
- if(param.nr_iters < 1)
3995
- {
3996
- cerr << "number of iterations must be greater than zero" << endl;
3997
- return false;
3998
- }
3999
-
4000
- if(param.lambda_p1 < 0 ||
4001
- param.lambda_p2 < 0 ||
4002
- param.lambda_q1 < 0 ||
4003
- param.lambda_q2 < 0)
4004
- {
4005
- cerr << "regularization coefficient must be non-negative" << endl;
4006
- return false;
4007
- }
4008
-
4009
- if(param.eta <= 0)
4010
- {
4011
- cerr << "learning rate must be greater than zero" << endl;
4012
- return false;
4013
- }
4014
-
4015
- if(param.fun == P_KL_MFR && !param.do_nmf)
4016
- {
4017
- cerr << "--nmf must be set when using generalized KL-divergence"
4018
- << endl;
4019
- return false;
4020
- }
4021
-
4022
- if(param.nr_bins <= 2*param.nr_threads)
4023
- {
4024
- cerr << "Warning: insufficient blocks may slow down the training"
4025
- << "process (4*nr_threads^2+1 blocks is suggested)" << endl;
4026
- }
4033
-
4034
- if(param.alpha < 0)
4035
- {
4036
- cerr << "alpha must be a non-negative number" << endl;
4037
- }
4038
-
4039
- return true;
4040
- }
4041
-
4042
- //--------------------------------------
4043
- //-----Classes for cross validation-----
4044
- //--------------------------------------
4045
-
4046
- class CrossValidatorBase
4047
- {
4048
- public:
4049
- CrossValidatorBase(mf_parameter param_, mf_int nr_folds_);
4050
- mf_double do_cross_validation();
4051
- virtual mf_double do_cv1(vector<mf_int> &hidden_blocks) = 0;
4052
- protected:
4053
- mf_parameter param;
4054
- mf_int nr_bins;
4055
- mf_int nr_folds;
4056
- mf_int nr_blocks_per_fold;
4057
- bool quiet;
4058
- Utility util;
4059
- mf_double cv_error;
4060
- };
4061
-
4062
- CrossValidatorBase::CrossValidatorBase(mf_parameter param_, mf_int nr_folds_)
4063
- : param(param_), nr_bins(param_.nr_bins), nr_folds(nr_folds_),
4064
- nr_blocks_per_fold(nr_bins*nr_bins/nr_folds), quiet(param_.quiet),
4065
- util(param.fun, param.nr_threads), cv_error(0)
4066
- {
4067
- param.quiet = true;
4068
- }
4069
-
4070
- mf_double CrossValidatorBase::do_cross_validation()
4071
- {
4072
- vector<mf_int> cv_blocks;
4073
- srand(0);
4074
- for(mf_int block = 0; block < nr_bins*nr_bins; ++block)
4075
- cv_blocks.push_back(block);
4076
- random_shuffle(cv_blocks.begin(), cv_blocks.end());
4077
-
4078
- if(!quiet)
4079
- {
4080
- cout.width(4);
4081
- cout << "fold";
4082
- cout.width(10);
4083
- cout << util.get_error_legend();
4084
- cout << endl;
4085
- }
4086
-
4087
- cv_error = 0;
4088
-
4089
- for(mf_int fold = 0; fold < nr_folds; ++fold)
4090
- {
4091
- mf_int begin = fold*nr_blocks_per_fold;
4092
- mf_int end = min((fold+1)*nr_blocks_per_fold, nr_bins*nr_bins);
4093
- vector<mf_int> hidden_blocks(cv_blocks.begin()+begin,
4094
- cv_blocks.begin()+end);
4095
-
4096
- mf_double err = do_cv1(hidden_blocks);
4097
- cv_error += err;
4098
-
4099
- if(!quiet)
4100
- {
4101
- cout.width(4);
4102
- cout << fold;
4103
- cout.width(10);
4104
- cout << fixed << setprecision(4) << err;
4105
- cout << endl;
4106
- }
4107
- }
4108
-
4109
- if(!quiet)
4110
- {
4111
- cout.width(14);
4112
- cout.fill('=');
4113
- cout << "" << endl;
4114
- cout.fill(' ');
4115
- cout.width(4);
4116
- cout << "avg";
4117
- cout.width(10);
4118
- cout << fixed << setprecision(4) << cv_error/nr_folds;
4119
- cout << endl;
4120
- }
4121
-
4122
- return cv_error/nr_folds;
4123
- }
4124
-
4125
- class CrossValidator : public CrossValidatorBase
4126
- {
4127
- public:
4128
- CrossValidator(
4129
- mf_parameter param_, mf_int nr_folds_, mf_problem const *prob_)
4130
- : CrossValidatorBase(param_, nr_folds_), prob(prob_) {};
4131
- mf_double do_cv1(vector<mf_int> &hidden_blocks);
4132
- private:
4133
- mf_problem const *prob;
4134
- };
4135
-
4136
- mf_double CrossValidator::do_cv1(vector<mf_int> &hidden_blocks)
4137
- {
4138
- mf_double err = 0;
4139
- fpsg(prob, nullptr, param, hidden_blocks, &err);
4140
- return err;
4141
- }
4142
-
4143
- class CrossValidatorOnDisk : public CrossValidatorBase
4144
- {
4145
- public:
4146
- CrossValidatorOnDisk(
4147
- mf_parameter param_, mf_int nr_folds_, string data_path_)
4148
- : CrossValidatorBase(param_, nr_folds_), data_path(data_path_) {};
4149
- mf_double do_cv1(vector<mf_int> &hidden_blocks);
4150
- private:
4151
- string data_path;
4152
- };
4153
-
4154
- mf_double CrossValidatorOnDisk::do_cv1(vector<mf_int> &hidden_blocks)
4155
- {
4156
- mf_double err = 0;
4157
- fpsg_on_disk(data_path, string(), param, hidden_blocks, &err);
4158
- return err;
4159
- }
4160
-
4161
- } // unnamed namespace
4162
-
4163
- mf_model* mf_train_with_validation(
4164
- mf_problem const *tr,
4165
- mf_problem const *va,
4166
- mf_parameter param)
4167
- {
4168
- if(!check_parameter(param))
4169
- return nullptr;
4170
-
4171
- shared_ptr<mf_model> model(nullptr);
4172
-
4173
- if(param.fun != P_L2_MFOC)
4174
- // Use stochastic gradient method
4175
- model = fpsg(tr, va, param);
4176
- else
4177
- // Use coordinate descent method
4178
- model = ccd_one_class(tr, va, param);
4179
-
4180
- mf_model *model_ret = new mf_model;
4181
-
4182
- model_ret->fun = model->fun;
4183
- model_ret->m = model->m;
4184
- model_ret->n = model->n;
4185
- model_ret->k = model->k;
4186
- model_ret->b = model->b;
4187
-
4188
- model_ret->P = model->P;
4189
- model->P = nullptr;
4190
-
4191
- model_ret->Q = model->Q;
4192
- model->Q = nullptr;
4193
-
4194
- return model_ret;
4195
- }
4196
-
4197
- mf_model* mf_train_with_validation_on_disk(
4198
- char const *tr_path,
4199
- char const *va_path,
4200
- mf_parameter param)
4201
- {
4202
- // Two conditions lead to a null return value. First, some parameter is outside
4203
- // its supported range. Second, one-class matrix factorization with L2-loss
4204
- // (-f 12) doesn't support disk-level training.
4205
- if(!check_parameter(param) || param.fun == P_L2_MFOC)
4206
- return nullptr;
4207
-
4208
- shared_ptr<mf_model> model = fpsg_on_disk(
4209
- string(tr_path), string(va_path), param);
4210
-
4211
- mf_model *model_ret = new mf_model;
4212
-
4213
- model_ret->fun = model->fun;
4214
- model_ret->m = model->m;
4215
- model_ret->n = model->n;
4216
- model_ret->k = model->k;
4217
- model_ret->b = model->b;
4218
-
4219
- model_ret->P = model->P;
4220
- model->P = nullptr;
4221
-
4222
- model_ret->Q = model->Q;
4223
- model->Q = nullptr;
4224
-
4225
- return model_ret;
4226
- }
4227
-
4228
- mf_model* mf_train(mf_problem const *prob, mf_parameter param)
4229
- {
4230
- return mf_train_with_validation(prob, nullptr, param);
4231
- }
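As a usage sketch of the in-memory training entry point above (the toy problem,
parameter tweaks, and output file name are illustrative only; error handling is
kept minimal):

#include <cstdio>
#include "mf.h"

int main()
{
    using namespace mf;

    // A tiny 2x2 problem with three observed entries (u, v, r).
    static mf_node nodes[] = {{0, 0, 4.0f}, {0, 1, 2.0f}, {1, 1, 5.0f}};
    mf_problem prob;
    prob.m = 2;
    prob.n = 2;
    prob.nnz = 3;
    prob.R = nodes;

    mf_parameter param = mf_get_default_param();
    param.k = 4;          // latent dimensions
    param.nr_threads = 1;
    param.quiet = true;

    mf_model *model = mf_train(&prob, param);
    if (model == nullptr)
        return 1;

    std::printf("prediction (0,1): %f\n", mf_predict(model, 0, 1));
    mf_save_model(model, "toy_model.txt"); // illustrative path
    mf_destroy_model(&model);
    return 0;
}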
4232
-
4233
- mf_model* mf_train_on_disk(char const *tr_path, mf_parameter param)
4234
- {
4235
- return mf_train_with_validation_on_disk(tr_path, "", param);
4236
- }
4237
-
4238
- mf_double mf_cross_validation(
4239
- mf_problem const *prob,
4240
- mf_int nr_folds,
4241
- mf_parameter param)
4242
- {
4243
- // Two conditions lead to a zero return value. First, some parameter is outside
4244
- // its supported range. Second, one-class matrix factorization with L2-loss
4245
- // (-f 12) doesn't support cross validation.
4246
- if(!check_parameter(param) || param.fun == P_L2_MFOC)
4247
- return 0;
4248
-
4249
- CrossValidator validator(param, nr_folds, prob);
4250
-
4251
- return validator.do_cross_validation();
4252
- }
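A brief sketch of invoking the in-memory cross validation defined above (the
hand-built problem and the fold count are illustrative only):

#include <cstdio>
#include "mf.h"

int main()
{
    using namespace mf;

    static mf_node nodes[] = {{0, 0, 4.0f}, {0, 1, 2.0f}, {1, 1, 5.0f}};
    mf_problem prob;
    prob.m = 2;
    prob.n = 2;
    prob.nnz = 3;
    prob.R = nodes;

    mf_parameter param = mf_get_default_param();
    param.quiet = true;

    // Average held-out error over 5 folds of hidden blocks.
    mf_double err = mf_cross_validation(&prob, 5, param);
    std::printf("5-fold error: %f\n", err);
    return 0;
}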
4253
-
4254
- mf_double mf_cross_validation_on_disk(
4255
- char const *prob,
4256
- mf_int nr_folds,
4257
- mf_parameter param)
4258
- {
4259
- // Two conditions lead to a zero return value. First, some parameter is outside
4260
- // its supported range. Second, one-class matrix factorization with L2-loss
4261
- // (-f 12) supports neither disk-level training nor cross validation.
4262
- if(!check_parameter(param) || param.fun == P_L2_MFOC)
4263
- return 0;
4264
-
4265
- CrossValidatorOnDisk validator(param, nr_folds, string(prob));
4266
-
4267
- return validator.do_cross_validation();
4268
- }
4269
-
4270
- mf_problem read_problem(string path)
4271
- {
4272
- mf_problem prob;
4273
- prob.m = 0;
4274
- prob.n = 0;
4275
- prob.nnz = 0;
4276
- prob.R = nullptr;
4277
-
4278
- if(path.empty())
4279
- return prob;
4280
-
4281
- ifstream f(path);
4282
- if(!f.is_open())
4283
- return prob;
4284
-
4285
- string line;
4286
- while(getline(f, line))
4287
- prob.nnz += 1;
4288
-
4289
- mf_node *R = new mf_node[static_cast<size_t>(prob.nnz)];
4290
-
4291
- f.close();
4292
- f.open(path);
4293
-
4294
- mf_long idx = 0;
4295
- for(mf_node N; f >> N.u >> N.v >> N.r;)
4296
- {
4297
- if(N.u+1 > prob.m)
4298
- prob.m = N.u+1;
4299
- if(N.v+1 > prob.n)
4300
- prob.n = N.v+1;
4301
- R[idx] = N;
4302
- ++idx;
4303
- }
4304
- prob.R = R;
4305
-
4306
- f.close();
4307
-
4308
- return prob;
4309
- }
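As written, read_problem parses whitespace-separated "row column value" triples,
one per line, with zero-based indices; m and n are inferred as the largest row
and column index plus one. A file describing a 2x3 matrix with three observed
entries could therefore look like this (contents are illustrative only):

0 0 4.0
0 2 1.5
1 1 3.0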
4310
-
4311
- mf_int mf_save_model(mf_model const *model, char const *path)
4312
- {
4313
- ofstream f(path);
4314
- if(!f.is_open())
4315
- return 1;
4316
-
4317
- f << "f " << model->fun << endl;
4318
- f << "m " << model->m << endl;
4319
- f << "n " << model->n << endl;
4320
- f << "k " << model->k << endl;
4321
- f << "b " << model->b << endl;
4322
-
4323
- auto write = [&] (mf_float *ptr, mf_int size, char prefix)
4324
- {
4325
- for(mf_int i = 0; i < size; ++i)
4326
- {
4327
- mf_float *ptr1 = ptr + (mf_long)i*model->k;
4328
- f << prefix << i << " ";
4329
- if(isnan(ptr1[0]))
4330
- {
4331
- f << "F ";
4332
- for(mf_int d = 0; d < model->k; ++d)
4333
- f << 0 << " ";
4334
- }
4335
- else
4336
- {
4337
- f << "T ";
4338
- for(mf_int d = 0; d < model->k; ++d)
4339
- f << ptr1[d] << " ";
4340
- }
4341
- f << endl;
4342
- }
4343
-
4344
- };
4345
-
4346
- write(model->P, model->m, 'p');
4347
- write(model->Q, model->n, 'q');
4348
-
4349
- f.close();
4350
-
4351
- return 0;
4352
- }
4353
-
4354
- mf_model* mf_load_model(char const *path)
4355
- {
4356
- ifstream f(path);
4357
- if(!f.is_open())
4358
- return nullptr;
4359
-
4360
- string dummy;
4361
-
4362
- mf_model *model = new mf_model;
4363
- model->P = nullptr;
4364
- model->Q = nullptr;
4365
-
4366
- f >> dummy >> model->fun >> dummy >> model->m >> dummy >> model->n >>
4367
- dummy >> model->k >> dummy >> model->b;
4368
-
4369
- try
4370
- {
4371
- model->P = Utility::malloc_aligned_float((mf_long)model->m*model->k);
4372
- model->Q = Utility::malloc_aligned_float((mf_long)model->n*model->k);
4373
- }
4374
- catch(bad_alloc const &e)
4375
- {
4376
- cerr << e.what() << endl;
4377
- mf_destroy_model(&model);
4378
- return nullptr;
4379
- }
4380
-
4381
- auto read = [&] (mf_float *ptr, mf_int size)
4382
- {
4383
- for(mf_int i = 0; i < size; ++i)
4384
- {
4385
- mf_float *ptr1 = ptr + (mf_long)i*model->k;
4386
- f >> dummy >> dummy;
4387
- if(dummy.compare("F") == 0) // nan vector starts with "F"
4388
- for(mf_int d = 0; d < model->k; ++d)
4389
- {
4390
- f >> dummy;
4391
- ptr1[d] = numeric_limits<mf_float>::quiet_NaN();
4392
- }
4393
- else
4394
- for(mf_int d = 0; d < model->k; ++d)
4395
- f >> ptr1[d];
4396
- }
4397
- };
4398
-
4399
- read(model->P, model->m);
4400
- read(model->Q, model->n);
4401
-
4402
- f.close();
4403
-
4404
- return model;
4405
- }
4406
-
4407
- void mf_destroy_model(mf_model **model)
4408
- {
4409
- if(model == nullptr || *model == nullptr)
4410
- return;
4411
- Utility::free_aligned_float((*model)->P);
4412
- Utility::free_aligned_float((*model)->Q);
4413
- delete *model;
4414
- *model = nullptr;
4415
- }
4416
-
4417
- mf_float mf_predict(mf_model const *model, mf_int u, mf_int v)
4418
- {
4419
- if(u < 0 || u >= model->m || v < 0 || v >= model->n)
4420
- return model->b;
4421
-
4422
- mf_float *p = model->P+(mf_long)u*model->k;
4423
- mf_float *q = model->Q+(mf_long)v*model->k;
4424
-
4425
- mf_float z = std::inner_product(p, p+model->k, q, (mf_float)0.0f);
4426
-
4427
- if(isnan(z))
4428
- z = model->b;
4429
-
4430
- if(model->fun == P_L2_MFC ||
4431
- model->fun == P_L1_MFC ||
4432
- model->fun == P_LR_MFC)
4433
- z = z > 0.0f? 1.0f: -1.0f;
4434
-
4435
- return z;
4436
- }
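A complementary sketch of the load/predict path (the model file name is
illustrative and assumed to have been written earlier by mf_save_model):

#include <cstdio>
#include "mf.h"

int main()
{
    using namespace mf;

    mf_model *model = mf_load_model("toy_model.txt"); // illustrative path
    if (model == nullptr)
        return 1;

    // Out-of-range indices fall back to the model's bias term b.
    std::printf("score(0,1)   = %f\n", mf_predict(model, 0, 1));
    std::printf("score(99,99) = %f\n", mf_predict(model, 99, 99));

    mf_destroy_model(&model);
    return 0;
}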
4437
-
4438
- mf_double calc_rmse(mf_problem *prob, mf_model *model)
4439
- {
4440
- if(prob->nnz == 0)
4441
- return 0;
4442
- mf_double loss = 0;
4443
- #if defined USEOMP
4444
- #pragma omp parallel for schedule(static) reduction(+:loss)
4445
- #endif
4446
- for(mf_long i = 0; i < prob->nnz; ++i)
4447
- {
4448
- mf_node &N = prob->R[i];
4449
- mf_float e = N.r - mf_predict(model, N.u, N.v);
4450
- loss += e*e;
4451
- }
4452
- return sqrt(loss/prob->nnz);
4453
- }
4454
-
4455
- mf_double calc_mae(mf_problem *prob, mf_model *model)
4456
- {
4457
- if(prob->nnz == 0)
4458
- return 0;
4459
- mf_double loss = 0;
4460
- #if defined USEOMP
4461
- #pragma omp parallel for schedule(static) reduction(+:loss)
4462
- #endif
4463
- for(mf_long i = 0; i < prob->nnz; ++i)
4464
- {
4465
- mf_node &N = prob->R[i];
4466
- loss += abs(N.r - mf_predict(model, N.u, N.v));
4467
- }
4468
- return loss/prob->nnz;
4469
- }
4470
-
4471
- mf_double calc_gkl(mf_problem *prob, mf_model *model)
4472
- {
4473
- if(prob->nnz == 0)
4474
- return 0;
4475
- mf_double loss = 0;
4476
- #if defined USEOMP
4477
- #pragma omp parallel for schedule(static) reduction(+:loss)
4478
- #endif
4479
- for(mf_long i = 0; i < prob->nnz; ++i)
4480
- {
4481
- mf_node &N = prob->R[i];
4482
- mf_float z = mf_predict(model, N.u, N.v);
4483
- loss += N.r*log(N.r/z)-N.r+z;
4484
- }
4485
- return loss/prob->nnz;
4486
- }
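In symbols, the quantity averaged above is the generalized KL divergence between
observed values and predictions, with \hat{r}_{uv} = mf_predict(model, u, v):

\mathrm{GKL} = \frac{1}{|\Omega|} \sum_{(u,v) \in \Omega}
               \left( r_{uv} \log\frac{r_{uv}}{\hat{r}_{uv}} - r_{uv} + \hat{r}_{uv} \right)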
4487
-
4488
- mf_double calc_logloss(mf_problem *prob, mf_model *model)
4489
- {
4490
- if(prob->nnz == 0)
4491
- return 0;
4492
- mf_double logloss = 0;
4493
- #if defined USEOMP
4494
- #pragma omp parallel for schedule(static) reduction(+:logloss)
4495
- #endif
4496
- for(mf_long i = 0; i < prob->nnz; ++i)
4497
- {
4498
- mf_node &N = prob->R[i];
4499
- mf_float z = mf_predict(model, N.u, N.v);
4500
- if(N.r > 0)
4501
- logloss += log(1.0+exp(-z));
4502
- else
4503
- logloss += log(1.0+exp(z));
4504
- }
4505
- return logloss/prob->nnz;
4506
- }
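Equivalently, encoding the label as y_{uv} = +1 when r_{uv} > 0 and y_{uv} = -1
otherwise, the loop above averages the logistic loss:

\mathrm{logloss} = \frac{1}{|\Omega|} \sum_{(u,v) \in \Omega}
                   \log\left(1 + e^{-y_{uv} \hat{r}_{uv}}\right)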
4507
-
4508
- mf_double calc_accuracy(mf_problem *prob, mf_model *model)
4509
- {
4510
- if(prob->nnz == 0)
4511
- return 0;
4512
- mf_double acc = 0;
4513
- #if defined USEOMP
4514
- #pragma omp parallel for schedule(static) reduction(+:acc)
4515
- #endif
4516
- for(mf_long i = 0; i < prob->nnz; ++i)
4517
- {
4518
- mf_node &N = prob->R[i];
4519
- mf_float z = mf_predict(model, N.u, N.v);
4520
- if(N.r > 0)
4521
- acc += z > 0? 1: 0;
4522
- else
4523
- acc += z < 0? 1: 0;
4524
- }
4525
- return acc/prob->nnz;
4526
- }
4527
-
4528
- pair<mf_double, mf_double> calc_mpr_auc(mf_problem *prob,
4529
- mf_model *model, bool transpose)
4530
- {
4531
- mf_int mf_node::*row_ptr;
4532
- mf_int mf_node::*col_ptr;
4533
- mf_int m = 0, n = 0;
4534
- if(!transpose)
4535
- {
4536
- row_ptr = &mf_node::u;
4537
- col_ptr = &mf_node::v;
4538
- m = max(prob->m, model->m);
4539
- n = max(prob->n, model->n);
4540
- }
4541
- else
4542
- {
4543
- row_ptr = &mf_node::v;
4544
- col_ptr = &mf_node::u;
4545
- m = max(prob->n, model->n);
4546
- n = max(prob->m, model->m);
4547
- }
4548
-
4549
- auto sort_by_id = [&] (mf_node const &lhs, mf_node const &rhs)
4550
- {
4551
- return tie(lhs.*row_ptr, lhs.*col_ptr) <
4552
- tie(rhs.*row_ptr, rhs.*col_ptr);
4553
- };
4554
-
4555
- sort(prob->R, prob->R+prob->nnz, sort_by_id);
4556
-
4557
- auto sort_by_pred = [&] (pair<mf_node, mf_float> const &lhs,
4558
- pair<mf_node, mf_float> const &rhs) { return lhs.second < rhs.second; };
4559
-
4560
- vector<mf_int> pos_cnts(m+1, 0);
4561
- for(mf_int i = 0; i < prob->nnz; ++i)
4562
- pos_cnts[prob->R[i].*row_ptr+1] += 1;
4563
- for(mf_int i = 1; i < m+1; ++i)
4564
- pos_cnts[i] += pos_cnts[i-1];
4565
-
4566
- mf_int total_m = 0;
4567
- mf_long total_pos = 0;
4568
- mf_double all_u_mpr = 0;
4569
- mf_double all_u_auc = 0;
4570
- #if defined USEOMP
4571
- #pragma omp parallel for schedule(static) reduction(+: total_m, total_pos, all_u_mpr, all_u_auc)
4572
- #endif
4573
- for(mf_int i = 0; i < m; ++i)
4574
- {
4575
- if(pos_cnts[i+1]-pos_cnts[i] < 1)
4576
- continue;
4577
-
4578
- vector<pair<mf_node, mf_float>> row(n);
4579
-
4580
- for(mf_int j = 0; j < n; ++j)
4581
- {
4582
- mf_node N;
4583
- N.*row_ptr = i;
4584
- N.*col_ptr = j;
4585
- N.r = 0;
4586
- row[j] = make_pair(N, mf_predict(model, N.u, N.v));
4587
- }
4588
-
4589
- mf_int pos = 0;
4590
- vector<mf_int> index(pos_cnts[i+1]-pos_cnts[i], 0);
4591
- for(mf_int j = pos_cnts[i]; j < pos_cnts[i+1]; ++j)
4592
- {
4593
- if(prob->R[j].r <= 0)
4594
- continue;
4595
-
4596
- mf_int col = prob->R[j].*col_ptr;
4597
- row[col].first.r = prob->R[j].r;
4598
- index[pos] = col;
4599
- pos += 1;
4600
- }
4601
-
4602
- if(n-pos < 1 || pos < 1)
4603
- continue;
4604
-
4605
- ++total_m;
4606
- total_pos += pos;
4607
-
4608
- mf_int count = 0;
4609
- for(mf_int k = 0; k < pos; ++k)
4610
- {
4611
- swap(row[count], row[index[k]]);
4612
- ++count;
4613
- }
4614
- sort(row.begin(), row.begin()+pos, sort_by_pred);
4615
-
4616
- mf_double u_mpr = 0;
4617
- mf_double u_auc = 0;
4618
- for(auto neg_it = row.begin()+pos; neg_it != row.end(); ++neg_it)
4619
- {
4620
- if(row[pos-1].second <= neg_it->second)
4621
- {
4622
- u_mpr += pos;
4623
- continue;
4624
- }
4625
-
4626
- mf_int left = 0;
4627
- mf_int right = pos-1;
4628
- while(left < right)
4629
- {
4630
- mf_int mid = (left+right)/2;
4631
- if(row[mid].second > neg_it->second)
4632
- right = mid;
4633
- else
4634
- left = mid+1;
4635
- }
4636
- u_mpr += left;
4637
- u_auc += pos-left;
4638
- }
4639
-
4640
- all_u_mpr += u_mpr/(n-pos);
4641
- all_u_auc += u_auc/(n-pos)/pos;
4642
- }
4643
-
4644
- all_u_mpr /= total_pos;
4645
- all_u_auc /= total_m;
4646
-
4647
- return make_pair(all_u_mpr, all_u_auc);
4648
- }
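Reading off the loop above: for each row i with at least one positive and one
negative column, the positives are moved to the front of row and sorted by
predicted score, and the binary search counts, for every negative, how many
positives score at or below it. The per-row quantity accumulated into all_u_auc
is therefore the usual pairwise AUC,

\mathrm{AUC}_i = \frac{1}{|P_i|\,|N_i|} \sum_{p \in P_i} \sum_{q \in N_i}
                 \mathbf{1}[\hat{s}_p > \hat{s}_q],

averaged over such rows, while all_u_mpr accumulates the complementary counts
(positives ranked at or below each negative, divided by |N_i|) and is finally
normalized by the total number of positives.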
4649
-
4650
- mf_double calc_mpr(mf_problem *prob, mf_model *model, bool transpose)
4651
- {
4652
- return calc_mpr_auc(prob, model, transpose).first;
4653
- }
4654
-
4655
- mf_double calc_auc(mf_problem *prob, mf_model *model, bool transpose)
4656
- {
4657
- return calc_mpr_auc(prob, model, transpose).second;
4658
- }
4659
-
4660
- mf_parameter mf_get_default_param()
4661
- {
4662
- mf_parameter param;
4663
-
4664
- param.fun = P_L2_MFR;
4665
- param.k = 8;
4666
- param.nr_threads = 12;
4667
- param.nr_bins = 20;
4668
- param.nr_iters = 20;
4669
- param.lambda_p1 = 0.0f;
4670
- param.lambda_q1 = 0.0f;
4671
- param.lambda_p2 = 0.1f;
4672
- param.lambda_q2 = 0.1f;
4673
- param.eta = 0.1f;
4674
- param.alpha = 1.0f;
4675
- param.c = 0.0001f;
4676
- param.do_nmf = false;
4677
- param.quiet = false;
4678
- param.copy_data = true;
4679
-
4680
- return param;
4681
- }
4682
-
4683
- }