libmf 0.1.2 → 0.2.3

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
data/vendor/libmf/mf.cpp DELETED
@@ -1,4683 +0,0 @@
1
- #include <algorithm>
2
- #include <cmath>
3
- #include <condition_variable>
4
- #include <cstdlib>
5
- #include <cstring>
6
- #include <fstream>
7
- #include <iostream>
8
- #include <iomanip>
9
- #include <memory>
10
- #include <numeric>
11
- #include <queue>
12
- #include <random>
13
- #include <stdexcept>
14
- #include <string>
15
- #include <thread>
16
- #include <unordered_set>
17
- #include <vector>
18
- #include <limits>
19
-
20
- #include "mf.h"
21
-
22
- #if defined USESSE
23
- #include <pmmintrin.h>
24
- #endif
25
-
26
- #if defined USEAVX
27
- #include <immintrin.h>
28
- #endif
29
-
30
- #if defined USEOMP
31
- #include <omp.h>
32
- #endif
33
-
34
- namespace mf
35
- {
36
-
37
- using namespace std;
38
-
39
- namespace // unnamed namespace
40
- {
41
-
42
- mf_int const kALIGNByte = 32;
43
- mf_int const kALIGN = kALIGNByte/sizeof(mf_float);
44
-
45
- //--------------------------------------
46
- //---------Scheduler of Blocks----------
47
- //--------------------------------------
48
-
49
- class Scheduler
50
- {
51
- public:
52
- Scheduler(mf_int nr_bins, mf_int nr_threads, vector<mf_int> cv_blocks);
53
- mf_int get_job();
54
- mf_int get_bpr_job(mf_int first_block, bool is_column_oriented);
55
- void put_job(mf_int block, mf_double loss, mf_double error);
56
- void put_bpr_job(mf_int first_block, mf_int second_block);
57
- mf_double get_loss();
58
- mf_double get_error();
59
- mf_int get_negative(mf_int first_block, mf_int second_block,
60
- mf_int m, mf_int n, bool is_column_oriented);
61
- void wait_for_jobs_done();
62
- void resume();
63
- void terminate();
64
- bool is_terminated();
65
-
66
- private:
67
- mf_int nr_bins;
68
- mf_int nr_threads;
69
- mf_int nr_done_jobs;
70
- mf_int target;
71
- mf_int nr_paused_threads;
72
- bool terminated;
73
- vector<mf_int> counts;
74
- vector<mf_int> busy_p_blocks;
75
- vector<mf_int> busy_q_blocks;
76
- vector<mf_double> block_losses;
77
- vector<mf_double> block_errors;
78
- vector<minstd_rand0> block_generators;
79
- unordered_set<mf_int> cv_blocks;
80
- mutex mtx;
81
- condition_variable cond_var;
82
- default_random_engine generator;
83
- uniform_real_distribution<mf_float> distribution;
84
- priority_queue<pair<mf_float, mf_int>,
85
- vector<pair<mf_float, mf_int>>,
86
- greater<pair<mf_float, mf_int>>> pq;
87
- };
88
-
89
- Scheduler::Scheduler(mf_int nr_bins, mf_int nr_threads,
90
- vector<mf_int> cv_blocks)
91
- : nr_bins(nr_bins),
92
- nr_threads(nr_threads),
93
- nr_done_jobs(0),
94
- target(nr_bins*nr_bins),
95
- nr_paused_threads(0),
96
- terminated(false),
97
- counts(nr_bins*nr_bins, 0),
98
- busy_p_blocks(nr_bins, 0),
99
- busy_q_blocks(nr_bins, 0),
100
- block_losses(nr_bins*nr_bins, 0),
101
- block_errors(nr_bins*nr_bins, 0),
102
- cv_blocks(cv_blocks.begin(), cv_blocks.end()),
103
- distribution(0.0, 1.0)
104
- {
105
- for(mf_int i = 0; i < nr_bins*nr_bins; ++i)
106
- {
107
- if(this->cv_blocks.find(i) == this->cv_blocks.end())
108
- pq.emplace(distribution(generator), i);
109
- block_generators.push_back(minstd_rand0(rand()));
110
- }
111
- }
112
-
113
- mf_int Scheduler::get_job()
114
- {
115
- bool is_found = false;
116
- pair<mf_float, mf_int> block;
117
-
118
- while(!is_found)
119
- {
120
- lock_guard<mutex> lock(mtx);
121
- vector<pair<mf_float, mf_int>> locked_blocks;
122
- mf_int p_block = 0;
123
- mf_int q_block = 0;
124
-
125
- while(!pq.empty())
126
- {
127
- block = pq.top();
128
- pq.pop();
129
-
130
- p_block = block.second/nr_bins;
131
- q_block = block.second%nr_bins;
132
-
133
- if(busy_p_blocks[p_block] || busy_q_blocks[q_block])
134
- locked_blocks.push_back(block);
135
- else
136
- {
137
- busy_p_blocks[p_block] = 1;
138
- busy_q_blocks[q_block] = 1;
139
- counts[block.second] += 1;
140
- is_found = true;
141
- break;
142
- }
143
- }
144
-
145
- for(auto &block1 : locked_blocks)
146
- pq.push(block1);
147
- }
148
-
149
- return block.second;
150
- }
151
-
152
- mf_int Scheduler::get_bpr_job(mf_int first_block, bool is_column_oriented)
153
- {
154
- lock_guard<mutex> lock(mtx);
155
- mf_int another = first_block;
156
- vector<pair<mf_float, mf_int>> locked_blocks;
157
-
158
- while(!pq.empty())
159
- {
160
- pair<mf_float, mf_int> block = pq.top();
161
- pq.pop();
162
-
163
- mf_int p_block = block.second/nr_bins;
164
- mf_int q_block = block.second%nr_bins;
165
-
166
- auto is_rejected = [&] ()
167
- {
168
- if(is_column_oriented)
169
- return first_block%nr_bins != q_block ||
170
- busy_p_blocks[p_block];
171
- else
172
- return first_block/nr_bins != p_block ||
173
- busy_q_blocks[q_block];
174
- };
175
-
176
- if(is_rejected())
177
- locked_blocks.push_back(block);
178
- else
179
- {
180
- busy_p_blocks[p_block] = 1;
181
- busy_q_blocks[q_block] = 1;
182
- another = block.second;
183
- break;
184
- }
185
- }
186
-
187
- for(auto &block : locked_blocks)
188
- pq.push(block);
189
-
190
- return another;
191
- }
192
-
193
- void Scheduler::put_job(mf_int block_idx, mf_double loss, mf_double error)
194
- {
195
- // Return the held block to the scheduler
196
- {
197
- lock_guard<mutex> lock(mtx);
198
- busy_p_blocks[block_idx/nr_bins] = 0;
199
- busy_q_blocks[block_idx%nr_bins] = 0;
200
- block_losses[block_idx] = loss;
201
- block_errors[block_idx] = error;
202
- ++nr_done_jobs;
203
- mf_float priority =
204
- (mf_float)counts[block_idx]+distribution(generator);
205
- pq.emplace(priority, block_idx);
206
- ++nr_paused_threads;
207
- // Tell others that a block is available again.
208
- cond_var.notify_all();
209
- }
210
-
211
- // Wait if nr_done_jobs (i.e., the number of processed blocks) grows too
212
- // large, because we want to print the training status roughly every time
213
- // all blocks have been processed once. This is the only place where a
214
- // solver thread waits for anything.
215
- {
216
- unique_lock<mutex> lock(mtx);
217
- cond_var.wait(lock, [&] {
218
- return nr_done_jobs < target;
219
- });
220
- }
221
-
222
- // Nothing is blocking and this thread is going to take another block
223
- {
224
- lock_guard<mutex> lock(mtx);
225
- --nr_paused_threads;
226
- }
227
- }
228
-
229
- void Scheduler::put_bpr_job(mf_int first_block, mf_int second_block)
230
- {
231
- if(first_block == second_block)
232
- return;
233
-
234
- lock_guard<mutex> lock(mtx);
235
- {
236
- busy_p_blocks[second_block/nr_bins] = 0;
237
- busy_q_blocks[second_block%nr_bins] = 0;
238
- mf_float priority =
239
- (mf_float)counts[second_block]+distribution(generator);
240
- pq.emplace(priority, second_block);
241
- }
242
- }
243
-
244
- mf_double Scheduler::get_loss()
245
- {
246
- lock_guard<mutex> lock(mtx);
247
- return accumulate(block_losses.begin(), block_losses.end(), 0.0);
248
- }
249
-
250
- mf_double Scheduler::get_error()
251
- {
252
- lock_guard<mutex> lock(mtx);
253
- return accumulate(block_errors.begin(), block_errors.end(), 0.0);
254
- }
255
-
256
- mf_int Scheduler::get_negative(mf_int first_block, mf_int second_block,
257
- mf_int m, mf_int n, bool is_column_oriented)
258
- {
259
- mf_int rand_val = (mf_int)block_generators[first_block]();
260
-
261
- auto gen_random = [&] (mf_int block_id)
262
- {
263
- mf_int v_min, v_max;
264
-
265
- if(is_column_oriented)
266
- {
267
- mf_int seg_size = (mf_int)ceil((double)m/nr_bins);
268
- v_min = min((block_id/nr_bins)*seg_size, m-1);
269
- v_max = min(v_min+seg_size, m-1);
270
- }
271
- else
272
- {
273
- mf_int seg_size = (mf_int)ceil((double)n/nr_bins);
274
- v_min = min((block_id%nr_bins)*seg_size, n-1);
275
- v_max = min(v_min+seg_size, n-1);
276
- }
277
- if(v_max == v_min)
278
- return v_min;
279
- else
280
- return rand_val%(v_max-v_min)+v_min;
281
- };
282
-
283
- if(rand_val % 2)
284
- return (mf_int)gen_random(first_block);
285
- else
286
- return (mf_int)gen_random(second_block);
287
- }
288
-
289
- void Scheduler::wait_for_jobs_done()
290
- {
291
- unique_lock<mutex> lock(mtx);
292
-
293
- // The first thing the main thread should wait for is that solver threads
294
- // process enough matrix blocks.
295
- // [REVIEW] Is it really needed? Solver threads automatically stop if they
296
- // process too many blocks, so the next wait should be enough for stopping
297
- // the main thread when nr_done_jobs is not large enough.
298
- cond_var.wait(lock, [&] {
299
- return nr_done_jobs >= target;
300
- });
301
-
302
- // Wait for all threads to stop. Once a thread realizes that all threads
303
- // have processed enough blocks it should stop. Then, the main thread can
304
- // print values safely.
305
- cond_var.wait(lock, [&] {
306
- return nr_paused_threads == nr_threads;
307
- });
308
- }
309
-
310
- void Scheduler::resume()
311
- {
312
- lock_guard<mutex> lock(mtx);
313
- target += nr_bins*nr_bins;
314
- cond_var.notify_all();
315
- }
316
-
317
- void Scheduler::terminate()
318
- {
319
- lock_guard<mutex> lock(mtx);
320
- terminated = true;
321
- }
322
-
323
- bool Scheduler::is_terminated()
324
- {
325
- lock_guard<mutex> lock(mtx);
326
- return terminated;
327
- }
328
-
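A note on the policy implemented above: get_job repeatedly pops the least-often-processed free block, i.e., one whose p-row and q-column are not held by another worker (with a random tie-break), and put_job re-queues a finished block with priority count + U(0,1), so rarely updated blocks float to the front. A minimal single-threaded sketch of that selection rule (names are illustrative, not libmf's; it assumes a compatible block is always queued, which libmf arranges by making nr_bins large enough relative to nr_threads):

    #include <functional>
    #include <queue>
    #include <random>
    #include <utility>
    #include <vector>

    class MiniScheduler // hypothetical sketch, no locking
    {
    public:
        explicit MiniScheduler(int nr_bins)
            : nr_bins(nr_bins), counts(nr_bins*nr_bins, 0),
              busy_p(nr_bins, 0), busy_q(nr_bins, 0), dist(0.0f, 1.0f)
        {
            for(int i = 0; i < nr_bins*nr_bins; ++i)
                pq.emplace(dist(gen), i); // random initial priorities
        }

        int get_job() // pop the best block whose row and column are free
        {
            std::vector<std::pair<float, int>> locked;
            int chosen = -1;
            while(chosen < 0)
            {
                std::pair<float, int> blk = pq.top();
                pq.pop();
                int pb = blk.second/nr_bins, qb = blk.second%nr_bins;
                if(busy_p[pb] || busy_q[qb])
                    locked.push_back(blk); // conflicts; set aside
                else
                {
                    busy_p[pb] = busy_q[qb] = 1;
                    counts[blk.second] += 1;
                    chosen = blk.second;
                }
            }
            for(std::pair<float, int> &blk : locked)
                pq.push(blk); // re-queue the conflicting blocks
            return chosen;
        }

        void put_job(int bid) // release; rarely-seen blocks come back first
        {
            busy_p[bid/nr_bins] = busy_q[bid%nr_bins] = 0;
            pq.emplace((float)counts[bid]+dist(gen), bid);
        }

    private:
        int nr_bins;
        std::vector<int> counts, busy_p, busy_q;
        std::minstd_rand gen;
        std::uniform_real_distribution<float> dist;
        std::priority_queue<std::pair<float, int>,
                            std::vector<std::pair<float, int>>,
                            std::greater<std::pair<float, int>>> pq;
    };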
329
- //--------------------------------------
330
- //------------Block of matrix-----------
331
- //--------------------------------------
332
-
333
- class BlockBase
334
- {
335
- public:
336
- virtual bool move_next() { return false; };
337
- virtual mf_node* get_current() { return nullptr; }
338
- virtual void reload() {};
339
- virtual void free() {};
340
- virtual mf_long get_nnz() { return 0; };
341
- virtual ~BlockBase() {};
342
- };
343
-
344
- class Block : public BlockBase
345
- {
346
- public:
347
- Block() : first(nullptr), last(nullptr), current(nullptr) {};
348
- Block(mf_node *first_, mf_node *last_)
349
- : first(first_), last(last_), current(nullptr) {};
350
- bool move_next() { return ++current != last; }
351
- mf_node* get_current() { return current; }
352
- void tie_to(mf_node *first_, mf_node *last_);
353
- void reload() { current = first-1; };
354
- mf_long get_nnz() { return last-first; };
355
-
356
- private:
357
- mf_node* first;
358
- mf_node* last;
359
- mf_node* current;
360
- };
361
-
362
- void Block::tie_to(mf_node *first_, mf_node *last_)
363
- {
364
- first = first_;
365
- last = last_;
366
- };
367
-
368
- class BlockOnDisk : public BlockBase
369
- {
370
- public:
371
- BlockOnDisk() : first(0), last(0), current(0),
372
- source_path(""), buffer(0) {};
373
- bool move_next() { return ++current < last-first; }
374
- mf_node* get_current() { return &buffer[static_cast<size_t>(current)]; }
375
- void tie_to(string source_path_, mf_long first_, mf_long last_);
376
- void reload();
377
- void free() { buffer.resize(0); };
378
- mf_long get_nnz() { return last-first; };
379
-
380
- private:
381
- mf_long first;
382
- mf_long last;
383
- mf_long current;
384
- string source_path;
385
- vector<mf_node> buffer;
386
- };
387
-
388
- void BlockOnDisk::tie_to(string source_path_, mf_long first_, mf_long last_)
389
- {
390
- source_path = source_path_;
391
- first = first_;
392
- last = last_;
393
- }
394
-
395
- void BlockOnDisk::reload()
396
- {
397
- ifstream source(source_path, ifstream::in|ifstream::binary);
398
- if(!source)
399
- throw runtime_error("can not open "+source_path);
400
-
401
- buffer.resize(static_cast<size_t>(last-first));
402
- source.seekg(first*sizeof(mf_node));
403
- source.read((char*)buffer.data(), (last-first)*sizeof(mf_node));
404
- current = -1;
405
- }
406
-
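A BlockOnDisk slice is just a contiguous run of raw mf_node records in a binary cache file, so any record is addressable by offset arithmetic alone. A minimal sketch of that access pattern (read_node is hypothetical, not part of libmf; mf_node and mf_long come from mf.h):

    #include <fstream>
    #include <string>

    mf_node read_node(std::string const &path, mf_long i)
    {
        std::ifstream f(path, std::ifstream::in|std::ifstream::binary);
        f.seekg(i*sizeof(mf_node)); // record i starts at byte i*sizeof(mf_node)
        mf_node N;
        f.read(reinterpret_cast<char*>(&N), sizeof(mf_node));
        return N;
    }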
407
- //--------------------------------------
408
- //-------------Miscellaneous------------
409
- //--------------------------------------
410
-
411
- struct sort_node_by_p
412
- {
413
- bool operator() (mf_node const &lhs, mf_node const &rhs)
414
- {
415
- return tie(lhs.u, lhs.v) < tie(rhs.u, rhs.v);
416
- }
417
- };
418
-
419
- struct sort_node_by_q
420
- {
421
- bool operator() (mf_node const &lhs, mf_node const &rhs)
422
- {
423
- return tie(lhs.v, lhs.u) < tie(rhs.v, rhs.u);
424
- }
425
- };
426
-
427
- struct deleter
428
- {
429
- void operator() (mf_problem *prob)
430
- {
431
- delete[] prob->R;
432
- delete prob;
433
- }
434
- };
435
-
436
-
437
- class Utility
438
- {
439
- public:
440
- Utility(mf_int f, mf_int n) : fun(f), nr_threads(n) {};
441
- void collect_info(mf_problem &prob, mf_float &avg, mf_float &std_dev);
442
- void collect_info_on_disk(string data_path, mf_problem &prob,
443
- mf_float &avg, mf_float &std_dev);
444
- void shuffle_problem(mf_problem &prob, vector<mf_int> &p_map,
445
- vector<mf_int> &q_map);
446
- vector<mf_node*> grid_problem(mf_problem &prob, mf_int nr_bins,
447
- vector<mf_int> &omega_p,
448
- vector<mf_int> &omega_q,
449
- vector<Block> &blocks);
450
- void grid_shuffle_scale_problem_on_disk(mf_int m, mf_int n, mf_int nr_bins,
451
- mf_float scale, string data_path,
452
- vector<mf_int> &p_map,
453
- vector<mf_int> &q_map,
454
- vector<mf_int> &omega_p,
455
- vector<mf_int> &omega_q,
456
- vector<BlockOnDisk> &blocks);
457
- void scale_problem(mf_problem &prob, mf_float scale);
458
- mf_double calc_reg1(mf_model &model, mf_float lambda_p, mf_float lambda_q,
459
- vector<mf_int> &omega_p, vector<mf_int> &omega_q);
460
- mf_double calc_reg2(mf_model &model, mf_float lambda_p, mf_float lambda_q,
461
- vector<mf_int> &omega_p, vector<mf_int> &omega_q);
462
- string get_error_legend() const;
463
- mf_double calc_error(vector<BlockBase*> &blocks,
464
- vector<mf_int> &cv_block_ids,
465
- mf_model const &model);
466
- void scale_model(mf_model &model, mf_float scale);
467
-
468
- static mf_problem* copy_problem(mf_problem const *prob, bool copy_data);
469
- static vector<mf_int> gen_random_map(mf_int size);
470
- // A function used to allocate an aligned float array.
471
- // It hides platform-specific function calls. Memory
472
- // allocated by malloc_aligned_float must be freed by using
473
- // free_aligned_float.
474
- static mf_float* malloc_aligned_float(mf_long size);
475
- // A function used to free an aligned float array.
476
- // It hides platform-specific function calls.
477
- static void free_aligned_float(mf_float* ptr);
478
- // Initialization function for stochastic gradient method.
479
- // Factor matrices P and Q are both randomly initialized.
480
- static mf_model* init_model(mf_int loss, mf_int m, mf_int n,
481
- mf_int k, mf_float avg,
482
- vector<mf_int> &omega_p,
483
- vector<mf_int> &omega_q);
484
- // Initialization function for one-class CD.
485
- // It does zero-initialization on factor matrix P and random initialization
486
- // on factor matrix Q.
487
- static mf_model* init_model(mf_int m, mf_int n, mf_int k);
488
- static mf_float inner_product(mf_float *p, mf_float *q, mf_int k);
489
- static vector<mf_int> gen_inv_map(vector<mf_int> &map);
490
- static void shrink_model(mf_model &model, mf_int k_new);
491
- static void shuffle_model(mf_model &model,
492
- vector<mf_int> &p_map,
493
- vector<mf_int> &q_map);
494
- mf_int get_thread_number() const { return nr_threads; };
495
- private:
496
- mf_int fun;
497
- mf_int nr_threads;
498
- };
499
-
500
- void Utility::collect_info(
501
- mf_problem &prob,
502
- mf_float &avg,
503
- mf_float &std_dev)
504
- {
505
- mf_double ex = 0;
506
- mf_double ex2 = 0;
507
-
508
- #if defined USEOMP
509
- #pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:ex,ex2)
510
- #endif
511
- for(mf_long i = 0; i < prob.nnz; ++i)
512
- {
513
- mf_node &N = prob.R[i];
514
- ex += (mf_double)N.r;
515
- ex2 += (mf_double)N.r*N.r;
516
- }
517
-
518
- ex /= (mf_double)prob.nnz;
519
- ex2 /= (mf_double)prob.nnz;
520
- avg = (mf_float)ex;
521
- std_dev = (mf_float)sqrt(ex2-ex*ex);
522
- }
523
-
524
- void Utility::collect_info_on_disk(
525
- string data_path,
526
- mf_problem &prob,
527
- mf_float &avg,
528
- mf_float &std_dev)
529
- {
530
- mf_double ex = 0;
531
- mf_double ex2 = 0;
532
-
533
- ifstream source(data_path);
534
- if(!source.is_open())
535
- throw runtime_error("cannot open " + data_path);
536
-
537
- for(mf_node N; source >> N.u >> N.v >> N.r;)
538
- {
539
- if(N.u+1 > prob.m)
540
- prob.m = N.u+1;
541
- if(N.v+1 > prob.n)
542
- prob.n = N.v+1;
543
- prob.nnz += 1;
544
- ex += (mf_double)N.r;
545
- ex2 += (mf_double)N.r*N.r;
546
- }
547
- source.close();
548
-
549
- ex /= (mf_double)prob.nnz;
550
- ex2 /= (mf_double)prob.nnz;
551
- avg = (mf_float)ex;
552
- std_dev = (mf_float)sqrt(ex2-ex*ex);
553
- }
554
-
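Both collect_info variants estimate the standard deviation with the one-pass identity std = sqrt(E[r^2] - E[r]^2). For example, ratings {1, 2, 3} give E[r] = 2 and E[r^2] = 14/3, hence std = sqrt(14/3 - 4) = sqrt(2/3) ≈ 0.816. The identity can lose precision when the deviation is tiny relative to |E[r]|; accumulating the sums in mf_double mitigates this.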
555
- void Utility::scale_problem(mf_problem &prob, mf_float scale)
556
- {
557
- if(scale == 1.0)
558
- return;
559
-
560
- #if defined USEOMP
561
- #pragma omp parallel for num_threads(nr_threads) schedule(static)
562
- #endif
563
- for(mf_long i = 0; i < prob.nnz; ++i)
564
- prob.R[i].r *= scale;
565
- }
566
-
567
- void Utility::scale_model(mf_model &model, mf_float scale)
568
- {
569
- if(scale == 1.0)
570
- return;
571
-
572
- mf_int k = model.k;
573
-
574
- model.b *= scale;
575
-
576
- auto scale1 = [&] (mf_float *ptr, mf_int size, mf_float factor_scale)
577
- {
578
- #if defined USEOMP
579
- #pragma omp parallel for num_threads(nr_threads) schedule(static)
580
- #endif
581
- for(mf_int i = 0; i < size; ++i)
582
- {
583
- mf_float *ptr1 = ptr+(mf_long)i*model.k;
584
- for(mf_int d = 0; d < k; ++d)
585
- ptr1[d] *= factor_scale;
586
- }
587
- };
588
-
589
- scale1(model.P, model.m, sqrt(scale));
590
- scale1(model.Q, model.n, sqrt(scale));
591
- }
592
-
593
- mf_float Utility::inner_product(mf_float *p, mf_float *q, mf_int k)
594
- {
595
- #if defined USESSE
596
- __m128 XMM = _mm_setzero_ps();
597
- for(mf_int d = 0; d < k; d += 4)
598
- XMM = _mm_add_ps(XMM, _mm_mul_ps(
599
- _mm_load_ps(p+d), _mm_load_ps(q+d)));
600
- __m128 XMMtmp = _mm_add_ps(XMM, _mm_movehl_ps(XMM, XMM));
601
- XMM = _mm_add_ps(XMMtmp, _mm_shuffle_ps(XMMtmp, XMMtmp, 1));
602
- mf_float product;
603
- _mm_store_ss(&product, XMM);
604
- return product;
605
- #elif defined USEAVX
606
- __m256 XMM = _mm256_setzero_ps();
607
- for(mf_int d = 0; d < k; d += 8)
608
- XMM = _mm256_add_ps(XMM, _mm256_mul_ps(
609
- _mm256_load_ps(p+d), _mm256_load_ps(q+d)));
610
- XMM = _mm256_add_ps(XMM, _mm256_permute2f128_ps(XMM, XMM, 1));
611
- XMM = _mm256_hadd_ps(XMM, XMM);
612
- XMM = _mm256_hadd_ps(XMM, XMM);
613
- mf_float product;
614
- _mm_store_ss(&product, _mm256_castps256_ps128(XMM));
615
- return product;
616
- #else
617
- return std::inner_product(p, p+k, q, (mf_float)0.0);
618
- #endif
619
- }
620
-
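Note that the SIMD paths assume k is a multiple of the vector width (4 for SSE, 8 for AVX) and that p and q are kALIGNByte-aligned; init_model guarantees both by padding k and allocating through malloc_aligned_float. A small usage sketch under those assumptions (plain C++17 std::aligned_alloc, where available, stands in for the allocator):

    #include <cassert>
    #include <cstdlib>
    #include <numeric>

    int main()
    {
        const int k = 8; // multiple of both SSE and AVX widths
        float *p = static_cast<float*>(std::aligned_alloc(32, k*sizeof(float)));
        float *q = static_cast<float*>(std::aligned_alloc(32, k*sizeof(float)));
        for(int d = 0; d < k; ++d) { p[d] = (float)d; q[d] = 1.0f; }
        // Scalar reference: 0+1+...+7 = 28; the SSE/AVX paths must agree.
        assert(std::inner_product(p, p+k, q, 0.0f) == 28.0f);
        std::free(p); std::free(q);
        return 0;
    }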
621
- mf_double Utility::calc_reg1(mf_model &model,
622
- mf_float lambda_p, mf_float lambda_q,
623
- vector<mf_int> &omega_p, vector<mf_int> &omega_q)
624
- {
625
- auto calc_reg1_core = [&] (mf_float *ptr, mf_int size,
626
- vector<mf_int> &omega)
627
- {
628
- mf_double reg = 0;
629
- for(mf_int i = 0; i < size; ++i)
630
- {
631
- if(omega[i] <= 0)
632
- continue;
633
-
634
- mf_float tmp = 0;
635
- for(mf_int j = 0; j < model.k; ++j)
636
- tmp += abs(ptr[(mf_long)i*model.k+j]);
637
- reg += omega[i]*tmp;
638
- }
639
- return reg;
640
- };
641
-
642
- return lambda_p*calc_reg1_core(model.P, model.m, omega_p)+
643
- lambda_q*calc_reg1_core(model.Q, model.n, omega_q);
644
- }
645
-
646
- mf_double Utility::calc_reg2(mf_model &model,
647
- mf_float lambda_p, mf_float lambda_q,
648
- vector<mf_int> &omega_p, vector<mf_int> &omega_q)
649
- {
650
- auto calc_reg2_core = [&] (mf_float *ptr, mf_int size,
651
- vector<mf_int> &omega)
652
- {
653
- mf_double reg = 0;
654
- #if defined USEOMP
655
- #pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:reg)
656
- #endif
657
- for(mf_int i = 0; i < size; ++i)
658
- {
659
- if(omega[i] <= 0)
660
- continue;
661
-
662
- mf_float *ptr1 = ptr+(mf_long)i*model.k;
663
- reg += omega[i]*Utility::inner_product(ptr1, ptr1, model.k);
664
- }
665
-
666
- return reg;
667
- };
668
-
669
- return lambda_p*calc_reg2_core(model.P, model.m, omega_p) +
670
- lambda_q*calc_reg2_core(model.Q, model.n, omega_q);
671
- }
672
-
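As both routines show, each row's penalty is weighted by its rating count, so the quantities being computed are

    reg1 = \lambda_{p1} \sum_u \omega_u \|p_u\|_1 + \lambda_{q1} \sum_v \omega_v \|q_v\|_1
    reg2 = \lambda_{p2} \sum_u \omega_u \|p_u\|_2^2 + \lambda_{q2} \sum_v \omega_v \|q_v\|_2^2

i.e., the totals obtained by charging every observed rating (u, v) its own copy of the regularizer on p_u and q_v.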
673
- mf_double Utility::calc_error(
674
- vector<BlockBase*> &blocks,
675
- vector<mf_int> &cv_block_ids,
676
- mf_model const &model)
677
- {
678
- mf_double error = 0;
679
- if(fun == P_L2_MFR || fun == P_L1_MFR || fun == P_KL_MFR ||
680
- fun == P_LR_MFC || fun == P_L2_MFC || fun == P_L1_MFC)
681
- {
682
- #if defined USEOMP
683
- #pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:error)
684
- #endif
685
- for(mf_int i = 0; i < (mf_long)cv_block_ids.size(); ++i)
686
- {
687
- BlockBase *block = blocks[cv_block_ids[i]];
688
- block->reload();
689
- while(block->move_next())
690
- {
691
- mf_node const &N = *(block->get_current());
692
- mf_float z = mf_predict(&model, N.u, N.v);
693
- switch(fun)
694
- {
695
- case P_L2_MFR:
696
- error += pow(N.r-z, 2);
697
- break;
698
- case P_L1_MFR:
699
- error += abs(N.r-z);
700
- break;
701
- case P_KL_MFR:
702
- error += N.r*log(N.r/z)-N.r+z;
703
- break;
704
- case P_LR_MFC:
705
- if(N.r > 0)
706
- error += log(1.0+exp(-z));
707
- else
708
- error += log(1.0+exp(z));
709
- break;
710
- case P_L2_MFC:
711
- case P_L1_MFC:
712
- if(N.r > 0)
713
- error += z > 0? 1: 0;
714
- else
715
- error += z < 0? 1: 0;
716
- break;
717
- default:
718
- throw invalid_argument("unknown error function");
719
- break;
720
- }
721
- }
722
- block->free();
723
- }
724
- }
725
- else
726
- {
727
- minstd_rand0 generator(rand());
728
- switch(fun)
729
- {
730
- case P_ROW_BPR_MFOC:
731
- {
732
- uniform_int_distribution<mf_int> distribution(0, model.n-1);
733
- #if defined USEOMP
734
- #pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:error)
735
- #endif
736
- for(mf_int i = 0; i < (mf_long)cv_block_ids.size(); ++i)
737
- {
738
- BlockBase *block = blocks[cv_block_ids[i]];
739
- block->reload();
740
- while(block->move_next())
741
- {
742
- mf_node const &N = *(block->get_current());
743
- mf_int w = distribution(generator);
744
- error += log(1+exp(mf_predict(&model, N.u, w)-
745
- mf_predict(&model, N.u, N.v)));
746
- }
747
- block->free();
748
- }
749
- break;
750
- }
751
- case P_COL_BPR_MFOC:
752
- {
753
- uniform_int_distribution<mf_int> distribution(0, model.m-1);
754
- #if defined USEOMP
755
- #pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:error)
756
- #endif
757
- for(mf_int i = 0; i < (mf_long)cv_block_ids.size(); ++i)
758
- {
759
- BlockBase *block = blocks[cv_block_ids[i]];
760
- block->reload();
761
- while(block->move_next())
762
- {
763
- mf_node const &N = *(block->get_current());
764
- mf_int w = distribution(generator);
765
- error += log(1+exp(mf_predict(&model, w, N.v)-
766
- mf_predict(&model, N.u, N.v)));
767
- }
768
- block->free();
769
- }
770
- break;
771
- }
772
- default:
773
- {
774
- throw invalid_argument("unknown error function");
775
- break;
776
- }
777
- }
778
- }
779
-
780
- return error;
781
- }
782
-
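In the BPR branches above, each observed pair is scored against one uniformly sampled negative: the row-oriented case accumulates the pairwise logistic loss log(1 + exp(r̂_{uw} - r̂_{uv})) with w drawn uniformly from all n columns, and the column-oriented case swaps the roles and samples a row instead.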
783
- string Utility::get_error_legend() const
784
- {
785
- switch(fun)
786
- {
787
- case P_L2_MFR:
788
- return string("rmse");
789
- break;
790
- case P_L1_MFR:
791
- return string("mae");
792
- break;
793
- case P_KL_MFR:
794
- return string("gkl");
795
- break;
796
- case P_LR_MFC:
797
- return string("logloss");
798
- break;
799
- case P_L2_MFC:
800
- case P_L1_MFC:
801
- return string("accuracy");
802
- break;
803
- case P_ROW_BPR_MFOC:
804
- case P_COL_BPR_MFOC:
805
- return string("bprloss");
806
- break;
807
- case P_L2_MFOC:
808
- return string("sqerror");
809
- default:
810
- return string();
811
- break;
812
- }
813
- }
814
-
815
- void Utility::shuffle_problem(
816
- mf_problem &prob,
817
- vector<mf_int> &p_map,
818
- vector<mf_int> &q_map)
819
- {
820
- #if defined USEOMP
821
- #pragma omp parallel for num_threads(nr_threads) schedule(static)
822
- #endif
823
- for(mf_long i = 0; i < prob.nnz; ++i)
824
- {
825
- mf_node &N = prob.R[i];
826
- if(N.u < (mf_long)p_map.size())
827
- N.u = p_map[N.u];
828
- if(N.v < (mf_long)q_map.size())
829
- N.v = q_map[N.v];
830
- }
831
- }
832
-
833
- vector<mf_node*> Utility::grid_problem(
834
- mf_problem &prob,
835
- mf_int nr_bins,
836
- vector<mf_int> &omega_p,
837
- vector<mf_int> &omega_q,
838
- vector<Block> &blocks)
839
- {
840
- vector<mf_long> counts(nr_bins*nr_bins, 0);
841
-
842
- mf_int seg_p = (mf_int)ceil((double)prob.m/nr_bins);
843
- mf_int seg_q = (mf_int)ceil((double)prob.n/nr_bins);
844
-
845
- auto get_block_id = [=] (mf_int u, mf_int v)
846
- {
847
- return (u/seg_p)*nr_bins+v/seg_q;
848
- };
849
-
850
- for(mf_long i = 0; i < prob.nnz; ++i)
851
- {
852
- mf_node &N = prob.R[i];
853
- mf_int block = get_block_id(N.u, N.v);
854
- counts[block] += 1;
855
- omega_p[N.u] += 1;
856
- omega_q[N.v] += 1;
857
- }
858
-
859
- vector<mf_node*> ptrs(nr_bins*nr_bins+1);
860
- mf_node *ptr = prob.R;
861
- ptrs[0] = ptr;
862
- for(mf_int block = 0; block < nr_bins*nr_bins; ++block)
863
- ptrs[block+1] = ptrs[block] + counts[block];
864
-
865
- vector<mf_node*> pivots(ptrs.begin(), ptrs.end()-1);
866
- for(mf_int block = 0; block < nr_bins*nr_bins; ++block)
867
- {
868
- for(mf_node* pivot = pivots[block]; pivot != ptrs[block+1];)
869
- {
870
- mf_int curr_block = get_block_id(pivot->u, pivot->v);
871
- if(curr_block == block)
872
- {
873
- ++pivot;
874
- continue;
875
- }
876
-
877
- mf_node *next = pivots[curr_block];
878
- swap(*pivot, *next);
879
- pivots[curr_block] += 1;
880
- }
881
- }
882
-
883
- #if defined USEOMP
884
- #pragma omp parallel for num_threads(nr_threads) schedule(dynamic)
885
- #endif
886
- for(mf_int block = 0; block < nr_bins*nr_bins; ++block)
887
- {
888
- if(prob.m > prob.n)
889
- sort(ptrs[block], ptrs[block+1], sort_node_by_p());
890
- else
891
- sort(ptrs[block], ptrs[block+1], sort_node_by_q());
892
- }
893
-
894
- for(mf_int i = 0; i < (mf_long)blocks.size(); ++i)
895
- blocks[i].tie_to(ptrs[i], ptrs[i+1]);
896
-
897
- return ptrs;
898
- }
899
-
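grid_problem partitions the nnz array into nr_bins^2 contiguous blocks without allocating a second node array: one counting pass fixes each block's range, a prefix sum yields the boundaries, and a pivot-chasing pass swaps every node directly into its range. The same idea in a standalone sketch (illustrative keys: nonnegative ints bucketed by v % nr):

    #include <utility>
    #include <vector>

    void bucket_partition(std::vector<int> &a, int nr)
    {
        std::vector<long> begin(nr+1, 0);
        for(int v : a)
            begin[v % nr + 1] += 1;          // count each bucket
        for(int b = 0; b < nr; ++b)
            begin[b+1] += begin[b];          // begin[b] = first slot of bucket b
        std::vector<long> pivot(begin.begin(), begin.end()-1);
        for(int b = 0; b < nr; ++b)
            for(long i = pivot[b]; i != begin[b+1];)
            {
                int dest = a[i] % nr;
                if(dest == b) { ++i; continue; }  // already in place
                std::swap(a[i], a[pivot[dest]]); // send it home
                pivot[dest] += 1;
            }
    }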
900
- void Utility::grid_shuffle_scale_problem_on_disk(
901
- mf_int m, mf_int n, mf_int nr_bins,
902
- mf_float scale, string data_path,
903
- vector<mf_int> &p_map, vector<mf_int> &q_map,
904
- vector<mf_int> &omega_p, vector<mf_int> &omega_q,
905
- vector<BlockOnDisk> &blocks)
906
- {
907
- string const buffer_path = data_path+string(".disk");
908
- mf_int seg_p = (mf_int)ceil((double)m/nr_bins);
909
- mf_int seg_q = (mf_int)ceil((double)n/nr_bins);
910
- vector<mf_long> counts(nr_bins*nr_bins+1, 0);
911
- vector<mf_long> pivots(nr_bins*nr_bins, 0);
912
- ifstream source(data_path);
913
- fstream buffer(buffer_path, fstream::in|fstream::out|
914
- fstream::binary|fstream::trunc);
915
- auto get_block_id = [=] (mf_int u, mf_int v)
916
- {
917
- return (u/seg_p)*nr_bins+v/seg_q;
918
- };
919
-
920
- if(!source)
921
- throw ios::failure(string("cannot open ")+data_path);
922
- if(!buffer)
923
- throw ios::failure(string("cannot open ")+buffer_path);
924
-
925
- for(mf_node N; source >> N.u >> N.v >> N.r;)
926
- {
927
- N.u = p_map[N.u];
928
- N.v = q_map[N.v];
929
- mf_int bid = get_block_id(N.u, N.v);
930
- omega_p[N.u] += 1;
931
- omega_q[N.v] += 1;
932
- counts[bid+1] += 1;
933
- }
934
-
935
- for(mf_int i = 1; i < nr_bins*nr_bins+1; ++i)
936
- {
937
- counts[i] += counts[i-1];
938
- pivots[i-1] = counts[i-1];
939
- }
940
-
941
- source.clear();
942
- source.seekg(0);
943
- for(mf_node N; source >> N.u >> N.v >> N.r;)
944
- {
945
- N.u = p_map[N.u];
946
- N.v = q_map[N.v];
947
- N.r /= scale;
948
- mf_int bid = get_block_id(N.u, N.v);
949
- buffer.seekp(pivots[bid]*sizeof(mf_node));
950
- buffer.write((char*)&N, sizeof(mf_node));
951
- pivots[bid] += 1;
952
- }
953
-
954
- for(mf_int i = 0; i < nr_bins*nr_bins; ++i)
955
- {
956
- vector<mf_node> nodes(static_cast<size_t>(counts[i+1]-counts[i]));
957
- buffer.clear();
958
- buffer.seekg(counts[i]*sizeof(mf_node));
959
- buffer.read((char*)nodes.data(), sizeof(mf_node)*nodes.size());
960
-
961
- if(m > n)
962
- sort(nodes.begin(), nodes.end(), sort_node_by_p());
963
- else
964
- sort(nodes.begin(), nodes.end(), sort_node_by_q());
965
-
966
- buffer.clear();
967
- buffer.seekp(counts[i]*sizeof(mf_node));
968
- buffer.write((char*)nodes.data(), sizeof(mf_node)*nodes.size());
969
- buffer.read((char*)nodes.data(), sizeof(mf_node)*nodes.size());
970
- }
971
-
972
- for(mf_int i = 0; i < (mf_long)blocks.size(); ++i)
973
- blocks[i].tie_to(buffer_path, counts[i], counts[i+1]);
974
- }
975
-
976
- mf_float* Utility::malloc_aligned_float(mf_long size)
977
- {
978
- // Check if conversion from mf_long to size_t causes overflow.
979
- if (size > numeric_limits<std::size_t>::max() / sizeof(mf_float))
980
- throw bad_alloc();
981
- // [REVIEW] I hope one day we can use C11 aligned_alloc to replace the
983
- // platform-dependent functions below. Neither Windows nor OSX currently
984
- // supports that function.
984
- void *ptr = nullptr;
985
- #ifdef _WIN32
986
- ptr = _aligned_malloc(static_cast<size_t>(size*sizeof(mf_float)),
987
- kALIGNByte);
988
- #else
989
- int status = posix_memalign(&ptr, kALIGNByte, size*sizeof(mf_float));
990
- if(status != 0)
991
- throw bad_alloc();
992
- #endif
993
- if(ptr == nullptr)
994
- throw bad_alloc();
995
-
996
- return (mf_float*)ptr;
997
- }
998
-
999
- void Utility::free_aligned_float(mf_float *ptr)
1000
- {
1001
- #ifdef _WIN32
1002
- // Unfortunately, Visual Studio doesn't support the cross-platform
1004
- // deallocation below.
1004
- _aligned_free(ptr);
1005
- #else
1006
- free(ptr);
1007
- #endif
1008
- }
1009
-
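Since every malloc_aligned_float must be paired with free_aligned_float, the two calls could be tied together with RAII. A sketch (aligned_floats and make_aligned_floats are hypothetical; Utility sits in this file's unnamed namespace, so this only makes sense in the same translation unit):

    #include <memory>

    struct AlignedFloatDeleter
    {
        void operator()(mf_float *ptr) const { Utility::free_aligned_float(ptr); }
    };
    using aligned_floats = std::unique_ptr<mf_float[], AlignedFloatDeleter>;

    aligned_floats make_aligned_floats(mf_long size)
    {
        return aligned_floats(Utility::malloc_aligned_float(size));
    }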
1010
- mf_model* Utility::init_model(mf_int fun,
1011
- mf_int m, mf_int n,
1012
- mf_int k, mf_float avg,
1013
- vector<mf_int> &omega_p,
1014
- vector<mf_int> &omega_q)
1015
- {
1016
- mf_int k_real = k;
1017
- mf_int k_aligned = (mf_int)ceil(mf_double(k)/kALIGN)*kALIGN;
1018
-
1019
- mf_model *model = new mf_model;
1020
-
1021
- model->fun = fun;
1022
- model->m = m;
1023
- model->n = n;
1024
- model->k = k_aligned;
1025
- model->b = avg;
1026
- model->P = nullptr;
1027
- model->Q = nullptr;
1028
-
1029
- mf_float scale = (mf_float)sqrt(1.0/k_real);
1030
- default_random_engine generator;
1031
- uniform_real_distribution<mf_float> distribution(0.0, 1.0);
1032
-
1033
- try
1034
- {
1035
- model->P = Utility::malloc_aligned_float((mf_long)model->m*model->k);
1036
- model->Q = Utility::malloc_aligned_float((mf_long)model->n*model->k);
1037
- }
1038
- catch(bad_alloc const &e)
1039
- {
1040
- cerr << e.what() << endl;
1041
- mf_destroy_model(&model);
1042
- throw;
1043
- }
1044
-
1045
- auto init1 = [&](mf_float *start_ptr, mf_long size, vector<mf_int> counts)
1046
- {
1047
- memset(start_ptr, 0, static_cast<size_t>(
1048
- sizeof(mf_float) * size*model->k));
1049
- for(mf_long i = 0; i < size; ++i)
1050
- {
1051
- mf_float * ptr = start_ptr + i*model->k;
1052
- if(counts[static_cast<size_t>(i)] > 0)
1053
- for(mf_long d = 0; d < k_real; ++d, ++ptr)
1054
- *ptr = (mf_float)(distribution(generator)*scale);
1055
- else
1056
- if(fun != P_ROW_BPR_MFOC && fun != P_COL_BPR_MFOC) // unseen for bpr is 0
1057
- for(mf_long d = 0; d < k_real; ++d, ++ptr)
1058
- *ptr = numeric_limits<mf_float>::quiet_NaN();
1059
- }
1060
- };
1061
-
1062
- init1(model->P, m, omega_p);
1063
- init1(model->Q, n, omega_q);
1064
-
1065
- return model;
1066
- }
1067
-
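A note on the scale above: each used coordinate of P and Q is drawn uniformly from [0, sqrt(1/k_real)], so E[p_d q_d] = (sqrt(1/k_real)/2)^2 = 1/(4 k_real), and the expected initial prediction is E[p^T q] = k_real * 1/(4 k_real) = 1/4, independent of the chosen rank (the padded dimensions beyond k_real stay zero from the memset).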
1068
- // Initialize P=[\bar{p}_1, ..., \bar{p}_d] and Q=[\bar{q}_1, ..., \bar{q}_d].
1069
- // Note that \bar{q}_{kv} is Q[k*n+v] and \bar{p}_{ku} is P[k*m+u]. One may
1070
- // notice that P and Q here are actually the transposes of P and Q in fpsg(...)
1071
- // because fpsg(...) uses P^TQ (where P and Q are respectively k-by-m and
1072
- // k-by-n) to approximate the given rating matrix R while ccd_one_class(...)
1073
- // uses PQ^T (where P and Q are respectively m-by-k and n-by-k).
1074
- mf_model* Utility::init_model(mf_int m, mf_int n, mf_int k)
1075
- {
1076
- mf_model *model = new mf_model;
1077
-
1078
- model->fun = P_L2_MFOC;
1079
- model->m = m;
1080
- model->n = n;
1081
- model->k = k;
1082
- model->b = 0.0; // One-class matrix factorization doesn't have bias.
1083
- model->P = nullptr;
1084
- model->Q = nullptr;
1085
-
1086
- try
1087
- {
1088
- model->P = Utility::malloc_aligned_float((mf_long)model->m*model->k);
1089
- model->Q = Utility::malloc_aligned_float((mf_long)model->n*model->k);
1090
- }
1091
- catch(bad_alloc const &e)
1092
- {
1093
- cerr << e.what() << endl;
1094
- mf_destroy_model(&model);
1095
- throw;
1096
- }
1097
-
1098
- // Our initialization strategy is to set all of P's elements to zero and do
1099
- // random initialization on Q. Thus, all initial predicted ratings are zero
1100
- // since the approximated rating matrix is PQ^T.
1101
-
1102
- // Initialize P with zeros
1103
- for(mf_long i = 0; i < (mf_long)k * m; ++i)
1104
- model->P[i] = 0.0;
1105
-
1106
- // Initialize Q with random numbers
1107
- default_random_engine generator;
1108
- uniform_real_distribution<mf_float> distribution(0.0, 1.0);
1109
- for(mf_long i = 0; i < (mf_long)k * n; ++i)
1110
- model->Q[i] = distribution(generator);
1111
-
1112
- return model;
1113
- }
1114
-
1115
- vector<mf_int> Utility::gen_random_map(mf_int size)
1116
- {
1117
- srand(0);
1118
- vector<mf_int> map(size, 0);
1119
- for(mf_int i = 0; i < size; ++i)
1120
- map[i] = i;
1121
- random_shuffle(map.begin(), map.end());
1122
- return map;
1123
- }
1124
-
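std::random_shuffle, used above, was deprecated in C++14 and removed in C++17. A drop-in sketch with an explicit engine (gen_random_map_cxx17 is hypothetical; mf_int comes from mf.h):

    #include <algorithm>
    #include <numeric>
    #include <random>
    #include <vector>

    std::vector<mf_int> gen_random_map_cxx17(mf_int size)
    {
        std::vector<mf_int> map(size);
        std::iota(map.begin(), map.end(), 0);
        std::mt19937 rng(0); // fixed seed, mirroring srand(0) above
        std::shuffle(map.begin(), map.end(), rng);
        return map;
    }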
1125
- vector<mf_int> Utility::gen_inv_map(vector<mf_int> &map)
1126
- {
1127
- vector<mf_int> inv_map(map.size());
1128
- for(mf_int i = 0; i < (mf_long)map.size(); ++i)
1129
- inv_map[map[i]] = i;
1130
- return inv_map;
1131
- }
1132
-
1133
- void Utility::shuffle_model(
1134
- mf_model &model,
1135
- vector<mf_int> &p_map,
1136
- vector<mf_int> &q_map)
1137
- {
1138
- auto inv_shuffle1 = [] (mf_float *vec, vector<mf_int> &map,
1139
- mf_int size, mf_int k)
1140
- {
1141
- for(mf_int pivot = 0; pivot < size;)
1142
- {
1143
- if(pivot == map[pivot])
1144
- {
1145
- ++pivot;
1146
- continue;
1147
- }
1148
-
1149
- mf_int next = map[pivot];
1150
-
1151
- for(mf_int d = 0; d < k; ++d)
1152
- swap(*(vec+(mf_long)pivot*k+d), *(vec+(mf_long)next*k+d));
1153
-
1154
- map[pivot] = map[next];
1155
- map[next] = next;
1156
- }
1157
- };
1158
-
1159
- inv_shuffle1(model.P, p_map, model.m, model.k);
1160
- inv_shuffle1(model.Q, q_map, model.n, model.k);
1161
- }
1162
-
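inv_shuffle1 moves row i to row map[i] entirely in place by chasing permutation cycles (each cycle of length L costs L-1 swaps), and it consumes map, leaving it as the identity when done. For example, with rows {A, B, C} and map = {1, 2, 0}, the result is {C, A, B}: A lands in row 1, B in row 2, and C in row 0.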
1163
- void Utility::shrink_model(mf_model &model, mf_int k_new)
1164
- {
1165
- mf_int k_old = model.k;
1166
- model.k = k_new;
1167
-
1168
- auto shrink1 = [&] (mf_float *ptr, mf_int size)
1169
- {
1170
- for(mf_int i = 0; i < size; ++i)
1171
- {
1172
- mf_float *src = ptr+(mf_long)i*k_old;
1173
- mf_float *dst = ptr+(mf_long)i*k_new;
1174
- copy(src, src+k_new, dst);
1175
- }
1176
- };
1177
-
1178
- shrink1(model.P, model.m);
1179
- shrink1(model.Q, model.n);
1180
- }
1181
-
1182
- mf_problem* Utility::copy_problem(mf_problem const *prob, bool copy_data)
1183
- {
1184
- mf_problem *new_prob = new mf_problem;
1185
-
1186
- if(prob == nullptr)
1187
- {
1188
- new_prob->m = 0;
1189
- new_prob->n = 0;
1190
- new_prob->nnz = 0;
1191
- new_prob->R = nullptr;
1192
-
1193
- return new_prob;
1194
- }
1195
-
1196
- new_prob->m = prob->m;
1197
- new_prob->n = prob->n;
1198
- new_prob->nnz = prob->nnz;
1199
-
1200
- if(copy_data)
1201
- {
1202
- try
1203
- {
1204
- new_prob->R = new mf_node[static_cast<size_t>(prob->nnz)];
1205
- copy(prob->R, prob->R+prob->nnz, new_prob->R);
1206
- }
1207
- catch(...)
1208
- {
1209
- delete new_prob;
1210
- throw;
1211
- }
1212
- }
1213
- else
1214
- {
1215
- new_prob->R = prob->R;
1216
- }
1217
-
1218
- return new_prob;
1219
- }
1220
-
1221
- //--------------------------------------
1222
- //-----The base class of all solvers----
1223
- //--------------------------------------
1224
-
1225
- class SolverBase
1226
- {
1227
- public:
1228
- SolverBase(Scheduler &scheduler, vector<BlockBase*> &blocks,
1229
- mf_float *PG, mf_float *QG, mf_model &model, mf_parameter param,
1230
- bool &slow_only)
1231
- : scheduler(scheduler), blocks(blocks), PG(PG), QG(QG),
1232
- model(model), param(param), slow_only(slow_only) {}
1233
- void run();
1234
- SolverBase(const SolverBase&) = delete;
1235
- SolverBase& operator=(const SolverBase&) = delete;
1236
- // Solver is a stateless functor, so the default destructor should be
1238
- // good enough.
1238
- virtual ~SolverBase() = default;
1239
-
1240
- protected:
1241
- #if defined USESSE
1242
- static void calc_z(__m128 &XMMz, mf_int k, mf_float *p, mf_float *q);
1243
- virtual void load_fixed_variables(
1244
- __m128 &XMMlambda_p1, __m128 &XMMlambda_q1,
1245
- __m128 &XMMlambda_p2, __m128 &XMMlabmda_q2,
1246
- __m128 &XMMeta, __m128 &XMMrk_slow,
1247
- __m128 &XMMrk_fast);
1248
- virtual void arrange_block(__m128d &XMMloss, __m128d &XMMerror);
1249
- virtual void prepare_for_sg_update(
1250
- __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror) = 0;
1251
- virtual void sg_update(mf_int d_begin, mf_int d_end, __m128 XMMz,
1252
- __m128 XMMlambda_p1, __m128 XMMlambda_q1,
1253
- __m128 XMMlambda_p2, __m128 XMMlamdba_q2,
1254
- __m128 XMMeta, __m128 XMMrk) = 0;
1255
- virtual void finalize(__m128d XMMloss, __m128d XMMerror);
1256
- #elif defined USEAVX
1257
- static void calc_z(__m256 &XMMz, mf_int k, mf_float *p, mf_float *q);
1258
- virtual void load_fixed_variables(
1259
- __m256 &XMMlambda_p1, __m256 &XMMlambda_q1,
1260
- __m256 &XMMlambda_p2, __m256 &XMMlabmda_q2,
1261
- __m256 &XMMeta, __m256 &XMMrk_slow,
1262
- __m256 &XMMrk_fast);
1263
- virtual void arrange_block(__m128d &XMMloss, __m128d &XMMerror);
1264
- virtual void prepare_for_sg_update(
1265
- __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror) = 0;
1266
- virtual void sg_update(mf_int d_begin, mf_int d_end, __m256 XMMz,
1267
- __m256 XMMlambda_p1, __m256 XMMlambda_q1,
1268
- __m256 XMMlambda_p2, __m256 XMMlamdba_q2,
1269
- __m256 XMMeta, __m256 XMMrk) = 0;
1270
- virtual void finalize(__m128d XMMloss, __m128d XMMerror);
1271
- #else
1272
- static void calc_z(mf_float &z, mf_int k, mf_float *p, mf_float *q);
1273
- virtual void load_fixed_variables();
1274
- virtual void arrange_block();
1275
- virtual void prepare_for_sg_update() = 0;
1276
- virtual void sg_update(mf_int d_begin, mf_int d_end, mf_float rk) = 0;
1277
- virtual void finalize();
1278
- static float qrsqrt(float x);
1279
- #endif
1280
- virtual void update() { ++pG; ++qG; };
1281
-
1282
- Scheduler &scheduler;
1283
- vector<BlockBase*> &blocks;
1284
- BlockBase *block;
1285
- mf_float *PG;
1286
- mf_float *QG;
1287
- mf_model &model;
1288
- mf_parameter param;
1289
- bool &slow_only;
1290
-
1291
- mf_node *N;
1292
- mf_float z;
1293
- mf_double loss;
1294
- mf_double error;
1295
- mf_float *p;
1296
- mf_float *q;
1297
- mf_float *pG;
1298
- mf_float *qG;
1299
- mf_int bid;
1300
-
1301
- mf_float lambda_p1;
1302
- mf_float lambda_q1;
1303
- mf_float lambda_p2;
1304
- mf_float lambda_q2;
1305
- mf_float rk_slow;
1306
- mf_float rk_fast;
1307
- };
1308
-
1309
- #if defined USESSE
1310
- inline void SolverBase::run()
1311
- {
1312
- __m128d XMMloss;
1313
- __m128d XMMerror;
1314
- __m128 XMMz;
1315
- __m128 XMMlambda_p1;
1316
- __m128 XMMlambda_q1;
1317
- __m128 XMMlambda_p2;
1318
- __m128 XMMlambda_q2;
1319
- __m128 XMMeta;
1320
- __m128 XMMrk_slow;
1321
- __m128 XMMrk_fast;
1322
- load_fixed_variables(XMMlambda_p1, XMMlambda_q1,
1323
- XMMlambda_p2, XMMlambda_q2,
1324
- XMMeta, XMMrk_slow,
1325
- XMMrk_fast);
1326
- while(!scheduler.is_terminated())
1327
- {
1328
- arrange_block(XMMloss, XMMerror);
1329
- while(block->move_next())
1330
- {
1331
- N = block->get_current();
1332
- p = model.P+(mf_long)N->u*model.k;
1333
- q = model.Q+(mf_long)N->v*model.k;
1334
- pG = PG+N->u*2;
1335
- qG = QG+N->v*2;
1336
- prepare_for_sg_update(XMMz, XMMloss, XMMerror);
1337
- sg_update(0, kALIGN, XMMz, XMMlambda_p1, XMMlambda_q1,
1338
- XMMlambda_p2, XMMlambda_q2, XMMeta, XMMrk_slow);
1339
- if(slow_only)
1340
- continue;
1341
- update();
1342
- sg_update(kALIGN, model.k, XMMz, XMMlambda_p1, XMMlambda_q1,
1343
- XMMlambda_p2, XMMlambda_q2, XMMeta, XMMrk_fast);
1344
- }
1345
- finalize(XMMloss, XMMerror);
1346
- }
1347
- }
1348
-
1349
- void SolverBase::load_fixed_variables(
1350
- __m128 &XMMlambda_p1, __m128 &XMMlambda_q1,
1351
- __m128 &XMMlambda_p2, __m128 &XMMlambda_q2,
1352
- __m128 &XMMeta, __m128 &XMMrk_slow,
1353
- __m128 &XMMrk_fast)
1354
- {
1355
- XMMlambda_p1 = _mm_set1_ps(param.lambda_p1);
1356
- XMMlambda_q1 = _mm_set1_ps(param.lambda_q1);
1357
- XMMlambda_p2 = _mm_set1_ps(param.lambda_p2);
1358
- XMMlambda_q2 = _mm_set1_ps(param.lambda_q2);
1359
- XMMeta = _mm_set1_ps(param.eta);
1360
- XMMrk_slow = _mm_set1_ps((mf_float)1.0/kALIGN);
1361
- XMMrk_fast = _mm_set1_ps((mf_float)1.0/(model.k-kALIGN));
1362
- }
1363
-
1364
- void SolverBase::arrange_block(__m128d &XMMloss, __m128d &XMMerror)
1365
- {
1366
- XMMloss = _mm_setzero_pd();
1367
- XMMerror = _mm_setzero_pd();
1368
- bid = scheduler.get_job();
1369
- block = blocks[bid];
1370
- block->reload();
1371
- }
1372
-
1373
- inline void SolverBase::calc_z(
1374
- __m128 &XMMz, mf_int k, mf_float *p, mf_float *q)
1375
- {
1376
- XMMz = _mm_setzero_ps();
1377
- for(mf_int d = 0; d < k; d += 4)
1378
- XMMz = _mm_add_ps(XMMz, _mm_mul_ps(
1379
- _mm_load_ps(p+d), _mm_load_ps(q+d)));
1380
- // Bit-wise representation of 177 is {1,0}+{1,1}+{0,0}+{0,1} from
1381
- // high-bit to low-bit, where "+" means concatenating two arrays.
1382
- __m128 XMMtmp = _mm_add_ps(XMMz, _mm_shuffle_ps(XMMz, XMMz, 177));
1383
- // Bit-wise representation of 78 is {0,1}+{0,0}+{1,1}+{1,0} from
1384
- // high-bit to low-bit, where "+" means concatenating two arrays.
1385
- XMMz = _mm_add_ps(XMMtmp, _mm_shuffle_ps(XMMtmp, XMMtmp, 78));
1386
- }
1387
-
1388
- void SolverBase::finalize(__m128d XMMloss, __m128d XMMerror)
1389
- {
1390
- _mm_store_sd(&loss, XMMloss);
1391
- _mm_store_sd(&error, XMMerror);
1392
- block->free();
1393
- scheduler.put_job(bid, loss, error);
1394
- }
1395
- #elif defined USEAVX
1396
- inline void SolverBase::run()
1397
- {
1398
- __m128d XMMloss;
1399
- __m128d XMMerror;
1400
- __m256 XMMz;
1401
- __m256 XMMlambda_p1;
1402
- __m256 XMMlambda_q1;
1403
- __m256 XMMlambda_p2;
1404
- __m256 XMMlambda_q2;
1405
- __m256 XMMeta;
1406
- __m256 XMMrk_slow;
1407
- __m256 XMMrk_fast;
1408
- load_fixed_variables(XMMlambda_p1, XMMlambda_q1,
1409
- XMMlambda_p2, XMMlambda_q2,
1410
- XMMeta, XMMrk_slow, XMMrk_fast);
1411
- while(!scheduler.is_terminated())
1412
- {
1413
- arrange_block(XMMloss, XMMerror);
1414
- while(block->move_next())
1415
- {
1416
- N = block->get_current();
1417
- p = model.P+(mf_long)N->u*model.k;
1418
- q = model.Q+(mf_long)N->v*model.k;
1419
- pG = PG+N->u*2;
1420
- qG = QG+N->v*2;
1421
- prepare_for_sg_update(XMMz, XMMloss, XMMerror);
1422
- sg_update(0, kALIGN, XMMz, XMMlambda_p1, XMMlambda_q1,
1423
- XMMlambda_p2, XMMlambda_q2, XMMeta, XMMrk_slow);
1424
- if(slow_only)
1425
- continue;
1426
- update();
1427
- sg_update(kALIGN, model.k, XMMz, XMMlambda_p1, XMMlambda_q1,
1428
- XMMlambda_p2, XMMlambda_q2, XMMeta, XMMrk_fast);
1429
- }
1430
- finalize(XMMloss, XMMerror);
1431
- }
1432
- }
1433
-
1434
- void SolverBase::load_fixed_variables(
1435
- __m256 &XMMlambda_p1, __m256 &XMMlambda_q1,
1436
- __m256 &XMMlambda_p2, __m256 &XMMlambda_q2,
1437
- __m256 &XMMeta, __m256 &XMMrk_slow,
1438
- __m256 &XMMrk_fast)
1439
- {
1440
- XMMlambda_p1 = _mm256_set1_ps(param.lambda_p1);
1441
- XMMlambda_q1 = _mm256_set1_ps(param.lambda_q1);
1442
- XMMlambda_p2 = _mm256_set1_ps(param.lambda_p2);
1443
- XMMlambda_q2 = _mm256_set1_ps(param.lambda_q2);
1444
- XMMeta = _mm256_set1_ps(param.eta);
1445
- XMMrk_slow = _mm256_set1_ps((mf_float)1.0/kALIGN);
1446
- XMMrk_fast = _mm256_set1_ps((mf_float)1.0/(model.k-kALIGN));
1447
- }
1448
-
1449
- void SolverBase::arrange_block(__m128d &XMMloss, __m128d &XMMerror)
1450
- {
1451
- XMMloss = _mm_setzero_pd();
1452
- XMMerror = _mm_setzero_pd();
1453
- bid = scheduler.get_job();
1454
- block = blocks[bid];
1455
- block->reload();
1456
- }
1457
-
1458
- inline void SolverBase::calc_z(
1459
- __m256 &XMMz, mf_int k, mf_float *p, mf_float *q)
1460
- {
1461
- XMMz = _mm256_setzero_ps();
1462
- for(mf_int d = 0; d < k; d += 8)
1463
- XMMz = _mm256_add_ps(XMMz, _mm256_mul_ps(
1464
- _mm256_load_ps(p+d), _mm256_load_ps(q+d)));
1465
- XMMz = _mm256_add_ps(XMMz, _mm256_permute2f128_ps(XMMz, XMMz, 0x1));
1466
- XMMz = _mm256_hadd_ps(XMMz, XMMz);
1467
- XMMz = _mm256_hadd_ps(XMMz, XMMz);
1468
- }
1469
-
1470
- void SolverBase::finalize(__m128d XMMloss, __m128d XMMerror)
1471
- {
1472
- _mm_store_sd(&loss, XMMloss);
1473
- _mm_store_sd(&error, XMMerror);
1474
- block->free();
1475
- scheduler.put_job(bid, loss, error);
1476
- }
1477
- #else
1478
- inline void SolverBase::run()
1479
- {
1480
- load_fixed_variables();
1481
- while(!scheduler.is_terminated())
1482
- {
1483
- arrange_block();
1484
- while(block->move_next())
1485
- {
1486
- N = block->get_current();
1487
- p = model.P+(mf_long)N->u*model.k;
1488
- q = model.Q+(mf_long)N->v*model.k;
1489
- pG = PG+N->u*2;
1490
- qG = QG+N->v*2;
1491
- prepare_for_sg_update();
1492
- sg_update(0, kALIGN, rk_slow);
1493
- if(slow_only)
1494
- continue;
1495
- update();
1496
- sg_update(kALIGN, model.k, rk_fast);
1497
- }
1498
- finalize();
1499
- }
1500
- }
1501
-
1502
- inline float SolverBase::qrsqrt(float x)
1503
- {
1504
- float xhalf = 0.5f*x;
1505
- uint32_t i;
1506
- memcpy(&i, &x, sizeof(i));
1507
- i = 0x5f375a86 - (i>>1);
1508
- memcpy(&x, &i, sizeof(i));
1509
- x = x*(1.5f - xhalf*x*x);
1510
- return x;
1511
- }
1512
-
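qrsqrt above is the classic bit-trick approximation of 1/sqrt(x): reinterpreting the float's bits as an integer, i <- 0x5f375a86 - (i >> 1) yields an initial guess y_0, which one Newton-Raphson step on f(y) = y^{-2} - x refines as y <- y(1.5 - 0.5 x y^2), accurate to roughly 0.2% relative error. It stands in for the hardware _mm_rsqrt_ps used by the SIMD builds when forming the per-row AdaGrad learning rate eta/sqrt(G).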
1513
- void SolverBase::load_fixed_variables()
1514
- {
1515
- lambda_p1 = param.lambda_p1;
1516
- lambda_q1 = param.lambda_q1;
1517
- lambda_p2 = param.lambda_p2;
1518
- lambda_q2 = param.lambda_q2;
1519
- rk_slow = (mf_float)1.0/kALIGN;
1520
- rk_fast = (mf_float)1.0/(model.k-kALIGN);
1521
- }
1522
-
1523
- void SolverBase::arrange_block()
1524
- {
1525
- loss = 0.0;
1526
- error = 0.0;
1527
- bid = scheduler.get_job();
1528
- block = blocks[bid];
1529
- block->reload();
1530
- }
1531
-
1532
- inline void SolverBase::calc_z(mf_float &z, mf_int k, mf_float *p, mf_float *q)
1533
- {
1534
- z = 0;
1535
- for(mf_int d = 0; d < k; ++d)
1536
- z += p[d]*q[d];
1537
- }
1538
-
1539
- void SolverBase::finalize()
1540
- {
1541
- block->free();
1542
- scheduler.put_job(bid, loss, error);
1543
- }
1544
- #endif
1545
-
1546
- //--------------------------------------
1547
- //-----Real-valued MF and binary MF-----
1548
- //--------------------------------------
1549
-
1550
- class MFSolver: public SolverBase
1551
- {
1552
- public:
1553
- MFSolver(Scheduler &scheduler, vector<BlockBase*> &blocks,
1554
- mf_float *PG, mf_float *QG, mf_model &model,
1555
- mf_parameter param, bool &slow_only)
1556
- : SolverBase(scheduler, blocks, PG, QG, model, param, slow_only) {}
1557
-
1558
- protected:
1559
- #if defined USESSE
1560
- void sg_update(mf_int d_begin, mf_int d_end, __m128 XMMz,
1561
- __m128 XMMlambda_p1, __m128 XMMlambda_q1,
1562
- __m128 XMMlambda_p2, __m128 XMMlambda_q2,
1563
- __m128 XMMeta, __m128 XMMrk);
1564
- #elif defined USEAVX
1565
- void sg_update(mf_int d_begin, mf_int d_end, __m256 XMMz,
1566
- __m256 XMMlambda_p1, __m256 XMMlambda_q1,
1567
- __m256 XMMlambda_p2, __m256 XMMlambda_q2,
1568
- __m256 XMMeta, __m256 XMMrk);
1569
- #else
1570
- void sg_update(mf_int d_begin, mf_int d_end, mf_float rk);
1571
- #endif
1572
- };
1573
-
1574
- #if defined USESSE
1575
- void MFSolver::sg_update(mf_int d_begin, mf_int d_end, __m128 XMMz,
1576
- __m128 XMMlambda_p1, __m128 XMMlambda_q1,
1577
- __m128 XMMlambda_p2, __m128 XMMlambda_q2,
1578
- __m128 XMMeta, __m128 XMMrk)
1579
- {
1580
- __m128 XMMpG = _mm_load1_ps(pG);
1581
- __m128 XMMqG = _mm_load1_ps(qG);
1582
- __m128 XMMeta_p = _mm_mul_ps(XMMeta, _mm_rsqrt_ps(XMMpG));
1583
- __m128 XMMeta_q = _mm_mul_ps(XMMeta, _mm_rsqrt_ps(XMMqG));
1584
- __m128 XMMpG1 = _mm_setzero_ps();
1585
- __m128 XMMqG1 = _mm_setzero_ps();
1586
-
1587
- for(mf_int d = d_begin; d < d_end; d += 4)
1588
- {
1589
- __m128 XMMp = _mm_load_ps(p+d);
1590
- __m128 XMMq = _mm_load_ps(q+d);
1591
-
1592
- __m128 XMMpg = _mm_sub_ps(_mm_mul_ps(XMMlambda_p2, XMMp),
1593
- _mm_mul_ps(XMMz, XMMq));
1594
- __m128 XMMqg = _mm_sub_ps(_mm_mul_ps(XMMlambda_q2, XMMq),
1595
- _mm_mul_ps(XMMz, XMMp));
1596
-
1597
- XMMpG1 = _mm_add_ps(XMMpG1, _mm_mul_ps(XMMpg, XMMpg));
1598
- XMMqG1 = _mm_add_ps(XMMqG1, _mm_mul_ps(XMMqg, XMMqg));
1599
-
1600
- XMMp = _mm_sub_ps(XMMp, _mm_mul_ps(XMMeta_p, XMMpg));
1601
- XMMq = _mm_sub_ps(XMMq, _mm_mul_ps(XMMeta_q, XMMqg));
1602
-
1603
- _mm_store_ps(p+d, XMMp);
1604
- _mm_store_ps(q+d, XMMq);
1605
- }
1606
-
1607
- mf_float tmp = 0;
1608
- _mm_store_ss(&tmp, XMMlambda_p1);
1609
- if(tmp > 0)
1610
- {
1611
- for(mf_int d = d_begin; d < d_end; d += 4)
1612
- {
1613
- __m128 XMMp = _mm_load_ps(p+d);
1614
- __m128 XMMflip = _mm_and_ps(_mm_cmple_ps(XMMp, _mm_set1_ps(0.0f)),
1615
- _mm_set1_ps(-0.0f));
1616
- XMMp = _mm_xor_ps(XMMflip,
1617
- _mm_max_ps(_mm_sub_ps(_mm_xor_ps(XMMp, XMMflip),
1618
- _mm_mul_ps(XMMeta_p, XMMlambda_p1)), _mm_set1_ps(0.0f)));
1619
- _mm_store_ps(p+d, XMMp);
1620
- }
1621
- }
1622
-
1623
- _mm_store_ss(&tmp, XMMlambda_q1);
1624
- if(tmp > 0)
1625
- {
1626
- for(mf_int d = d_begin; d < d_end; d += 4)
1627
- {
1628
- __m128 XMMq = _mm_load_ps(q+d);
1629
- __m128 XMMflip = _mm_and_ps(_mm_cmple_ps(XMMq, _mm_set1_ps(0.0f)),
1630
- _mm_set1_ps(-0.0f));
1631
- XMMq = _mm_xor_ps(XMMflip,
1632
- _mm_max_ps(_mm_sub_ps(_mm_xor_ps(XMMq, XMMflip),
1633
- _mm_mul_ps(XMMeta_q, XMMlambda_q1)), _mm_set1_ps(0.0f)));
1634
- _mm_store_ps(q+d, XMMq);
1635
- }
1636
- }
1637
-
1638
- if(param.do_nmf)
1639
- {
1640
- for(mf_int d = d_begin; d < d_end; d += 4)
1641
- {
1642
- __m128 XMMp = _mm_load_ps(p+d);
1643
- __m128 XMMq = _mm_load_ps(q+d);
1644
- XMMp = _mm_max_ps(XMMp, _mm_set1_ps(0.0f));
1645
- XMMq = _mm_max_ps(XMMq, _mm_set1_ps(0.0f));
1646
- _mm_store_ps(p+d, XMMp);
1647
- _mm_store_ps(q+d, XMMq);
1648
- }
1649
- }
1650
-
1651
- __m128 XMMtmp = _mm_add_ps(XMMpG1, _mm_movehl_ps(XMMpG1, XMMpG1));
1652
- XMMpG1 = _mm_add_ps(XMMpG1, _mm_shuffle_ps(XMMtmp, XMMtmp, 1));
1653
- XMMpG = _mm_add_ps(XMMpG, _mm_mul_ps(XMMpG1, XMMrk));
1654
- _mm_store_ss(pG, XMMpG);
1655
-
1656
- XMMtmp = _mm_add_ps(XMMqG1, _mm_movehl_ps(XMMqG1, XMMqG1));
1657
- XMMqG1 = _mm_add_ps(XMMqG1, _mm_shuffle_ps(XMMtmp, XMMtmp, 1));
1658
- XMMqG = _mm_add_ps(XMMqG, _mm_mul_ps(XMMqG1, XMMrk));
1659
- _mm_store_ss(qG, XMMqG);
1660
- }
1661
- #elif defined USEAVX
1662
- void MFSolver::sg_update(mf_int d_begin, mf_int d_end, __m256 XMMz,
1663
- __m256 XMMlambda_p1, __m256 XMMlambda_q1,
1664
- __m256 XMMlambda_p2, __m256 XMMlambda_q2,
1665
- __m256 XMMeta, __m256 XMMrk)
1666
- {
1667
- __m256 XMMpG = _mm256_broadcast_ss(pG);
1668
- __m256 XMMqG = _mm256_broadcast_ss(qG);
1669
- __m256 XMMeta_p = _mm256_mul_ps(XMMeta, _mm256_rsqrt_ps(XMMpG));
1670
- __m256 XMMeta_q = _mm256_mul_ps(XMMeta, _mm256_rsqrt_ps(XMMqG));
1671
- __m256 XMMpG1 = _mm256_setzero_ps();
1672
- __m256 XMMqG1 = _mm256_setzero_ps();
1673
-
1674
- for(mf_int d = d_begin; d < d_end; d += 8)
1675
- {
1676
- __m256 XMMp = _mm256_load_ps(p+d);
1677
- __m256 XMMq = _mm256_load_ps(q+d);
1678
-
1679
- __m256 XMMpg = _mm256_sub_ps(_mm256_mul_ps(XMMlambda_p2, XMMp),
1680
- _mm256_mul_ps(XMMz, XMMq));
1681
- __m256 XMMqg = _mm256_sub_ps(_mm256_mul_ps(XMMlambda_q2, XMMq),
1682
- _mm256_mul_ps(XMMz, XMMp));
1683
-
1684
- XMMpG1 = _mm256_add_ps(XMMpG1, _mm256_mul_ps(XMMpg, XMMpg));
1685
- XMMqG1 = _mm256_add_ps(XMMqG1, _mm256_mul_ps(XMMqg, XMMqg));
1686
-
1687
- XMMp = _mm256_sub_ps(XMMp, _mm256_mul_ps(XMMeta_p, XMMpg));
1688
- XMMq = _mm256_sub_ps(XMMq, _mm256_mul_ps(XMMeta_q, XMMqg));
1689
- _mm256_store_ps(p+d, XMMp);
1690
- _mm256_store_ps(q+d, XMMq);
1691
- }
1692
-
1693
- mf_float tmp = 0;
1694
- _mm_store_ss(&tmp, _mm256_castps256_ps128(XMMlambda_p1));
1695
- if(tmp > 0)
1696
- {
1697
- for(mf_int d = d_begin; d < d_end; d += 8)
1698
- {
1699
- __m256 XMMp = _mm256_load_ps(p+d);
1700
- __m256 XMMflip = _mm256_and_ps(_mm256_cmp_ps(XMMp,
1701
- _mm256_set1_ps(0.0f), _CMP_LE_OS),
1702
- _mm256_set1_ps(-0.0f));
1703
- XMMp = _mm256_xor_ps(XMMflip,
1704
- _mm256_max_ps(_mm256_sub_ps(
1705
- _mm256_xor_ps(XMMp, XMMflip),
1706
- _mm256_mul_ps(XMMeta_p, XMMlambda_p1)),
1707
- _mm256_set1_ps(0.0f)));
1708
- _mm256_store_ps(p+d, XMMp);
1709
- }
1710
- }
1711
-
1712
- _mm_store_ss(&tmp, _mm256_castps256_ps128(XMMlambda_q1));
1713
- if(tmp > 0)
1714
- {
1715
- for(mf_int d = d_begin; d < d_end; d += 8)
1716
- {
1717
- __m256 XMMq = _mm256_load_ps(q+d);
1718
- __m256 XMMflip = _mm256_and_ps(_mm256_cmp_ps(XMMq,
1719
- _mm256_set1_ps(0.0f), _CMP_LE_OS),
1720
- _mm256_set1_ps(-0.0f));
1721
- XMMq = _mm256_xor_ps(XMMflip,
1722
- _mm256_max_ps(_mm256_sub_ps(
1723
- _mm256_xor_ps(XMMq, XMMflip),
1724
- _mm256_mul_ps(XMMeta_q, XMMlambda_q1)),
1725
- _mm256_set1_ps(0.0f)));
1726
- _mm256_store_ps(q+d, XMMq);
1727
- }
1728
- }
1729
-
1730
- if(param.do_nmf)
1731
- {
1732
- for(mf_int d = d_begin; d < d_end; d += 8)
1733
- {
1734
- __m256 XMMp = _mm256_load_ps(p+d);
1735
- __m256 XMMq = _mm256_load_ps(q+d);
1736
- XMMp = _mm256_max_ps(XMMp, _mm256_set1_ps(0));
1737
- XMMq = _mm256_max_ps(XMMq, _mm256_set1_ps(0));
1738
- _mm256_store_ps(p+d, XMMp);
1739
- _mm256_store_ps(q+d, XMMq);
1740
- }
1741
- }
1742
-
1743
- XMMpG1 = _mm256_add_ps(XMMpG1,
1744
- _mm256_permute2f128_ps(XMMpG1, XMMpG1, 0x1));
1745
- XMMpG1 = _mm256_hadd_ps(XMMpG1, XMMpG1);
1746
- XMMpG1 = _mm256_hadd_ps(XMMpG1, XMMpG1);
1747
-
1748
- XMMqG1 = _mm256_add_ps(XMMqG1,
1749
- _mm256_permute2f128_ps(XMMqG1, XMMqG1, 0x1));
1750
- XMMqG1 = _mm256_hadd_ps(XMMqG1, XMMqG1);
1751
- XMMqG1 = _mm256_hadd_ps(XMMqG1, XMMqG1);
1752
-
1753
- XMMpG = _mm256_add_ps(XMMpG, _mm256_mul_ps(XMMpG1, XMMrk));
1754
- XMMqG = _mm256_add_ps(XMMqG, _mm256_mul_ps(XMMqG1, XMMrk));
1755
-
1756
- _mm_store_ss(pG, _mm256_castps256_ps128(XMMpG));
1757
- _mm_store_ss(qG, _mm256_castps256_ps128(XMMqG));
1758
- }
1759
- #else
1760
- void MFSolver::sg_update(mf_int d_begin, mf_int d_end, mf_float rk)
1761
- {
1762
- mf_float eta_p = param.eta*qrsqrt(*pG);
1763
- mf_float eta_q = param.eta*qrsqrt(*qG);
1764
-
1765
- mf_float pG1 = 0;
1766
- mf_float qG1 = 0;
1767
-
1768
- for(mf_int d = d_begin; d < d_end; ++d)
1769
- {
1770
- mf_float gp = -z*q[d]+lambda_p2*p[d];
1771
- mf_float gq = -z*p[d]+lambda_q2*q[d];
1772
-
1773
- pG1 += gp*gp;
1774
- qG1 += gq*gq;
1775
-
1776
- p[d] -= eta_p*gp;
1777
- q[d] -= eta_q*gq;
1778
- }
1779
-
1780
- if(lambda_p1 > 0)
1781
- {
1782
- for(mf_int d = d_begin; d < d_end; ++d)
1783
- {
1784
- mf_float p1 = max(abs(p[d])-lambda_p1*eta_p, 0.0f);
1785
- p[d] = p[d] >= 0? p1: -p1;
1786
- }
1787
- }
1788
-
1789
- if(lambda_q1 > 0)
1790
- {
1791
- for(mf_int d = d_begin; d < d_end; ++d)
1792
- {
1793
- mf_float q1 = max(abs(q[d])-lambda_q1*eta_q, 0.0f);
1794
- q[d] = q[d] >= 0? q1: -q1;
1795
- }
1796
- }
1797
-
1798
- if(param.do_nmf)
1799
- {
1800
- for(mf_int d = d_begin; d < d_end; ++d)
1801
- {
1802
- p[d] = max(p[d], (mf_float)0.0f);
1803
- q[d] = max(q[d], (mf_float)0.0f);
1804
- }
1805
- }
1806
-
1807
- *pG += pG1*rk;
1808
- *qG += qG1*rk;
1809
- }
1810
- #endif
1811
-
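The scalar, SSE, and AVX branches of sg_update all implement the same AdaGrad-style step; as a summary (not new code), with z the loss-specific coefficient computed in prepare_for_sg_update (e.g., the residual r - p^T q for squared loss):

    g_p = \lambda_{p2} p - z q,                    g_q = \lambda_{q2} q - z p
    G_p <- G_p + \|g_p\|^2 / \tilde{k},            \eta_p = \eta / \sqrt{G_p}
    p <- p - \eta_p g_p,   then   p <- sign(p) max(|p| - \eta_p \lambda_{p1}, 0)   if \lambda_{p1} > 0

where \tilde{k} is the number of dimensions touched by the pass (kALIGN for the slow pass, k - kALIGN for the fast pass), the q update is symmetric, and do_nmf additionally projects p and q onto the nonnegative orthant.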
1812
- class L2_MFR : public MFSolver
1813
- {
1814
- public:
1815
- L2_MFR(Scheduler &scheduler, vector<BlockBase*> &blocks, mf_float *PG, mf_float *QG,
1816
- mf_model &model, mf_parameter param, bool &slow_only)
1817
- : MFSolver(scheduler, blocks, PG, QG, model, param, slow_only) {}
1818
-
1819
- protected:
1820
- #if defined USESSE
1821
- void prepare_for_sg_update(
1822
- __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
1823
- #elif defined USEAVX
1824
- void prepare_for_sg_update(
1825
- __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
1826
- #else
1827
- void prepare_for_sg_update();
1828
- #endif
1829
- };
1830
-
1831
- #if defined USESSE
1832
- void L2_MFR::prepare_for_sg_update(
1833
- __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
1834
- {
1835
- calc_z(XMMz, model.k, p, q);
1836
- XMMz = _mm_sub_ps(_mm_set1_ps(N->r), XMMz);
1837
- XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(
1838
- _mm_mul_ps(XMMz, XMMz)));
1839
- XMMerror = XMMloss;
1840
- }
1841
- #elif defined USEAVX
1842
- void L2_MFR::prepare_for_sg_update(
1843
- __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
1844
- {
1845
- calc_z(XMMz, model.k, p, q);
1846
- XMMz = _mm256_sub_ps(_mm256_set1_ps(N->r), XMMz);
1847
- XMMloss = _mm_add_pd(XMMloss,
1848
- _mm_cvtps_pd(_mm256_castps256_ps128(
1849
- _mm256_mul_ps(XMMz, XMMz))));
1850
- XMMerror = XMMloss;
1851
- }
1852
- #else
1853
- void L2_MFR::prepare_for_sg_update()
1854
- {
1855
- calc_z(z, model.k, p, q);
1856
- z = N->r-z;
1857
- loss += z*z;
1858
- error = loss;
1859
- }
1860
- #endif
1861
- class L1_MFR : public MFSolver
1862
- {
1863
- public:
1864
- L1_MFR(Scheduler &scheduler, vector<BlockBase*> &blocks, mf_float *PG, mf_float *QG,
1865
- mf_model &model, mf_parameter param, bool &slow_only)
1866
- : MFSolver(scheduler, blocks, PG, QG, model, param, slow_only) {}
1867
-
1868
- protected:
1869
- #if defined USESSE
1870
- void prepare_for_sg_update(
1871
- __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
1872
- #elif defined USEAVX
1873
- void prepare_for_sg_update(
1874
- __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
1875
- #else
1876
- void prepare_for_sg_update();
1877
- #endif
1878
- };
1879
-
1880
- #if defined USESSE
1881
- void L1_MFR::prepare_for_sg_update(
1882
- __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
1883
- {
1884
- calc_z(XMMz, model.k, p, q);
1885
- XMMz = _mm_sub_ps(_mm_set1_ps(N->r), XMMz);
1886
- XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(
1887
- _mm_andnot_ps(_mm_set1_ps(-0.0f), XMMz)));
1888
- XMMerror = XMMloss;
1889
- XMMz = _mm_add_ps(_mm_and_ps(_mm_cmpgt_ps(XMMz, _mm_set1_ps(0.0f)),
1890
- _mm_set1_ps(1.0f)),
1891
- _mm_and_ps(_mm_cmplt_ps(XMMz, _mm_set1_ps(0.0f)),
1892
- _mm_set1_ps(-1.0f)));
1893
- }
1894
- #elif defined USEAVX
1895
- void L1_MFR::prepare_for_sg_update(
1896
- __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
1897
- {
1898
- calc_z(XMMz, model.k, p, q);
1899
- XMMz = _mm256_sub_ps(_mm256_set1_ps(N->r), XMMz);
1900
- XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(_mm256_castps256_ps128(
1901
- _mm256_andnot_ps(_mm256_set1_ps(-0.0f), XMMz))));
1902
- XMMerror = XMMloss;
1903
- XMMz = _mm256_add_ps(_mm256_and_ps(_mm256_cmp_ps(XMMz,
1904
- _mm256_set1_ps(0.0f), _CMP_GT_OS), _mm256_set1_ps(1.0f)),
1905
- _mm256_and_ps(_mm256_cmp_ps(XMMz,
1906
- _mm256_set1_ps(0.0f), _CMP_LT_OS), _mm256_set1_ps(-1.0f)));
1907
- }
1908
- #else
1909
- void L1_MFR::prepare_for_sg_update()
1910
- {
1911
- calc_z(z, model.k, p, q);
1912
- z = N->r-z;
1913
- loss += abs(z);
1914
- error = loss;
1915
- if(z > 0)
1916
- z = 1;
1917
- else if(z < 0)
1918
- z = -1;
1919
- }
1920
- #endif
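The L1 regression path needs only the sign of the residual. Reading the scalar branch above, with residual z = r - p^T q (again a reconstruction, not libmf documentation):

```latex
\ell = \bigl| r - p^\top q \bigr|, \qquad
z \leftarrow \operatorname{sign}\bigl(r - p^\top q\bigr) \in \{-1, 0, 1\}
% so the shared sg_update step with gradient -z q (plus regularization)
% is a subgradient step on \ell.
```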
1921
-
1922
- class KL_MFR : public MFSolver
1923
- {
1924
- public:
1925
- KL_MFR(Scheduler &scheduler, vector<BlockBase*> &blocks, mf_float *PG, mf_float *QG,
1926
- mf_model &model, mf_parameter param, bool &slow_only)
1927
- : MFSolver(scheduler, blocks, PG, QG, model, param, slow_only) {}
1928
-
1929
- protected:
1930
- #if defined USESSE
1931
- void prepare_for_sg_update(
1932
- __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
1933
- #elif defined USEAVX
1934
- void prepare_for_sg_update(
1935
- __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
1936
- #else
1937
- void prepare_for_sg_update();
1938
- #endif
1939
- };
1940
-
1941
- #if defined USESSE
1942
- void KL_MFR::prepare_for_sg_update(
1943
- __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
1944
- {
1945
- calc_z(XMMz, model.k, p, q);
1946
- XMMz = _mm_div_ps(_mm_set1_ps(N->r), XMMz);
1947
- _mm_store_ss(&z, XMMz);
1948
- XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(
1949
- _mm_set1_ps(N->r*(log(z)-1+1/z))));
1950
- XMMerror = XMMloss;
1951
- XMMz = _mm_sub_ps(XMMz, _mm_set1_ps(1.0f));
1952
- }
1953
- #elif defined USEAVX
1954
- void KL_MFR::prepare_for_sg_update(
1955
- __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
1956
- {
1957
- calc_z(XMMz, model.k, p, q);
1958
- XMMz = _mm256_div_ps(_mm256_set1_ps(N->r), XMMz);
1959
- _mm_store_ss(&z, _mm256_castps256_ps128(XMMz));
1960
- XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(
1961
- _mm_set1_ps(N->r*(log(z)-1+1/z))));
1962
- XMMerror = XMMloss;
1963
- XMMz = _mm256_sub_ps(XMMz, _mm256_set1_ps(1.0f));
1964
- }
1965
- #else
1966
- void KL_MFR::prepare_for_sg_update()
1967
- {
1968
- calc_z(z, model.k, p, q);
1969
- z = N->r/z;
1970
- loss += N->r*(log(z)-1+1/z);
1971
- error = loss;
1972
- z -= 1;
1973
- }
1974
- #endif
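The substitution in the KL solver is worth unpacking: with prediction \hat{z} = p^T q and z = r/\hat{z}, the accumulated term is exactly the generalized KL divergence, and the closing z -= 1 leaves the multiplier the shared sg_update expects. A reconstruction:

```latex
% With \hat{z} = p^\top q and z = r/\hat{z}:
r\Bigl(\log z - 1 + \tfrac{1}{z}\Bigr)
  = r\log\frac{r}{\hat{z}} - r + \hat{z}
% which is the generalized KL divergence between r and \hat{z}. After
% z \leftarrow z - 1:
z = \frac{r}{\hat{z}} - 1 = -\frac{\partial \ell}{\partial \hat{z}}
```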
1975
-
1976
- class LR_MFC : public MFSolver
1977
- {
1978
- public:
1979
- LR_MFC(Scheduler &scheduler, vector<BlockBase*> &blocks,
1980
- mf_float *PG, mf_float *QG, mf_model &model,
1981
- mf_parameter param, bool &slow_only)
1982
- : MFSolver(scheduler, blocks, PG, QG, model, param, slow_only) {}
1983
-
1984
- protected:
1985
- #if defined USESSE
1986
- void prepare_for_sg_update(
1987
- __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
1988
- #elif defined USEAVX
1989
- void prepare_for_sg_update(
1990
- __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
1991
- #else
1992
- void prepare_for_sg_update();
1993
- #endif
1994
- };
1995
-
1996
- #if defined USESSE
1997
- void LR_MFC::prepare_for_sg_update(
1998
- __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
1999
- {
2000
- calc_z(XMMz, model.k, p, q);
2001
- _mm_store_ss(&z, XMMz);
2002
- if(N->r > 0)
2003
- {
2004
- z = exp(-z);
2005
- XMMloss = _mm_add_pd(XMMloss, _mm_set1_pd(log(1+z)));
2006
- XMMz = _mm_set1_ps(z/(1+z));
2007
- }
2008
- else
2009
- {
2010
- z = exp(z);
2011
- XMMloss = _mm_add_pd(XMMloss, _mm_set1_pd(log(1+z)));
2012
- XMMz = _mm_set1_ps(-z/(1+z));
2013
- }
2014
- XMMerror = XMMloss;
2015
- }
2016
- #elif defined USEAVX
2017
- void LR_MFC::prepare_for_sg_update(
2018
- __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
2019
- {
2020
- calc_z(XMMz, model.k, p, q);
2021
- _mm_store_ss(&z, _mm256_castps256_ps128(XMMz));
2022
- if(N->r > 0)
2023
- {
2024
- z = exp(-z);
2025
- XMMloss = _mm_add_pd(XMMloss, _mm_set1_pd(log(1.0+z)));
2026
- XMMz = _mm256_set1_ps(z/(1+z));
2027
- }
2028
- else
2029
- {
2030
- z = exp(z);
2031
- XMMloss = _mm_add_pd(XMMloss, _mm_set1_pd(log(1.0+z)));
2032
- XMMz = _mm256_set1_ps(-z/(1+z));
2033
- }
2034
- XMMerror = XMMloss;
2035
- }
2036
- #else
2037
- void LR_MFC::prepare_for_sg_update()
2038
- {
2039
- calc_z(z, model.k, p, q);
2040
- if(N->r > 0)
2041
- {
2042
- z = exp(-z);
2043
- loss += log(1+z);
2044
- error = loss;
2045
- z = z/(1+z);
2046
- }
2047
- else
2048
- {
2049
- z = exp(z);
2050
- loss += log(1+z);
2051
- error = loss;
2052
- z = -z/(1+z);
2053
- }
2054
- }
2055
- #endif
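Both branches above implement one logistic loss; branching on the label keeps the stored multiplier in the z/(1+z) form. Written out (a reconstruction, with y = +1 when N->r > 0 and y = -1 otherwise):

```latex
\ell(z) = \log\bigl(1 + e^{-y z}\bigr), \qquad
-\frac{\partial \ell}{\partial z} = \frac{y\, e^{-y z}}{1 + e^{-y z}}
% which is exactly the value the code stores back into z (or XMMz)
% before the shared sg_update runs.
```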
2056
-
2057
- class L2_MFC : public MFSolver
2058
- {
2059
- public:
2060
- L2_MFC(Scheduler &scheduler, vector<BlockBase*> &blocks,
2061
- mf_float *PG, mf_float *QG, mf_model &model,
2062
- mf_parameter param, bool &slow_only)
2063
- : MFSolver(scheduler, blocks, PG, QG, model, param, slow_only) {}
2064
-
2065
- protected:
2066
- #if defined USESSE
2067
- void prepare_for_sg_update(
2068
- __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
2069
- #elif defined USEAVX
2070
- void prepare_for_sg_update(
2071
- __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
2072
- #else
2073
- void prepare_for_sg_update();
2074
- #endif
2075
- };
2076
-
2077
- #if defined USESSE
2078
- void L2_MFC::prepare_for_sg_update(
2079
- __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
2080
- {
2081
- calc_z(XMMz, model.k, p, q);
2082
- if(N->r > 0)
2083
- {
2084
- __m128 mask = _mm_cmpgt_ps(XMMz, _mm_set1_ps(0.0f));
2085
- XMMerror = _mm_add_pd(XMMerror, _mm_cvtps_pd(
2086
- _mm_and_ps(_mm_set1_ps(1.0f), mask)));
2087
- XMMz = _mm_max_ps(_mm_set1_ps(0.0f), _mm_sub_ps(
2088
- _mm_set1_ps(1.0f), XMMz));
2089
- }
2090
- else
2091
- {
2092
- __m128 mask = _mm_cmplt_ps(XMMz, _mm_set1_ps(0.0f));
2093
- XMMerror = _mm_add_pd(XMMerror, _mm_cvtps_pd(
2094
- _mm_and_ps(_mm_set1_ps(1.0f), mask)));
2095
- XMMz = _mm_min_ps(_mm_set1_ps(0.0f), _mm_sub_ps(
2096
- _mm_set1_ps(-1.0f), XMMz));
2097
- }
2098
- XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(
2099
- _mm_mul_ps(XMMz, XMMz)));
2100
- }
2101
- #elif defined USEAVX
2102
- void L2_MFC::prepare_for_sg_update(
2103
- __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
2104
- {
2105
- calc_z(XMMz, model.k, p, q);
2106
- if(N->r > 0)
2107
- {
2108
- __m128 mask = _mm_cmpgt_ps(_mm256_castps256_ps128(XMMz),
2109
- _mm_set1_ps(0.0f));
2110
- XMMerror = _mm_add_pd(XMMerror, _mm_cvtps_pd(
2111
- _mm_and_ps(_mm_set1_ps(1.0f), mask)));
2112
- XMMz = _mm256_max_ps(_mm256_set1_ps(0.0f),
2113
- _mm256_sub_ps(_mm256_set1_ps(1.0f), XMMz));
2114
- }
2115
- else
2116
- {
2117
- __m128 mask = _mm_cmplt_ps(_mm256_castps256_ps128(XMMz),
2118
- _mm_set1_ps(0.0f));
2119
- XMMerror = _mm_add_pd(XMMerror, _mm_cvtps_pd(
2120
- _mm_and_ps(_mm_set1_ps(1.0f), mask)));
2121
- XMMz = _mm256_min_ps(_mm256_set1_ps(0.0f),
2122
- _mm256_sub_ps(_mm256_set1_ps(-1.0f), XMMz));
2123
- }
2124
- XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(
2125
- _mm_mul_ps(_mm256_castps256_ps128(XMMz),
2126
- _mm256_castps256_ps128(XMMz))));
2127
- }
2128
- #else
2129
- void L2_MFC::prepare_for_sg_update()
2130
- {
2131
- calc_z(z, model.k, p, q);
2132
- if(N->r > 0)
2133
- {
2134
- error += z > 0? 1: 0;
2135
- z = max(0.0f, 1-z);
2136
- }
2137
- else
2138
- {
2139
- error += z < 0? 1: 0;
2140
- z = min(0.0f, -1-z);
2141
- }
2142
- loss += z*z;
2143
- }
2144
- #endif
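The L2 classifier is a squared hinge loss, and the error counter simply tallies correctly signed predictions. Reconstructed from the branches above (the factor of two from differentiating the square is dropped, effectively folded into the learning rate); the L1_MFC hinge variant below is analogous but stores plus or minus one instead of the margin:

```latex
\ell(z) = \max\bigl(0,\; 1 - y z\bigr)^2, \qquad
\mathrm{error} \mathrel{+}= \mathbb{1}\bigl[\, y z > 0 \,\bigr]
% The code overwrites z with y \max(0, 1 - y z), the (halved) negative
% derivative handed to sg_update.
```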
2145
-
2146
- class L1_MFC : public MFSolver
2147
- {
2148
- public:
2149
- L1_MFC(Scheduler &scheduler, vector<BlockBase*> &blocks, mf_float *PG, mf_float *QG,
2150
- mf_model &model, mf_parameter param, bool &slow_only)
2151
- : MFSolver(scheduler, blocks, PG, QG, model, param, slow_only) {}
2152
-
2153
- protected:
2154
- #if defined USESSE
2155
- void prepare_for_sg_update(
2156
- __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
2157
- #elif defined USEAVX
2158
- void prepare_for_sg_update(
2159
- __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
2160
- #else
2161
- void prepare_for_sg_update();
2162
- #endif
2163
- };
2164
-
2165
- #if defined USESSE
2166
- void L1_MFC::prepare_for_sg_update(
2167
- __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
2168
- {
2169
- calc_z(XMMz, model.k, p, q);
2170
- if(N->r > 0)
2171
- {
2172
- XMMerror = _mm_add_pd(XMMerror, _mm_cvtps_pd(
2173
- _mm_and_ps(_mm_cmpge_ps(XMMz, _mm_set1_ps(0.0f)),
2174
- _mm_set1_ps(1.0f))));
2175
- XMMz = _mm_sub_ps(_mm_set1_ps(1.0f), XMMz);
2176
- XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(
2177
- _mm_max_ps(_mm_set1_ps(0.0f), XMMz)));
2178
- XMMz = _mm_and_ps(_mm_cmpge_ps(XMMz, _mm_set1_ps(0.0f)),
2179
- _mm_set1_ps(1.0f));
2180
- }
2181
- else
2182
- {
2183
- XMMerror = _mm_add_pd(XMMerror, _mm_cvtps_pd(
2184
- _mm_and_ps(_mm_cmplt_ps(XMMz, _mm_set1_ps(0.0f)),
2185
- _mm_set1_ps(1.0f))));
2186
- XMMz = _mm_add_ps(_mm_set1_ps(1.0f), XMMz);
2187
- XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(
2188
- _mm_max_ps(_mm_set1_ps(0.0f), XMMz)));
2189
- XMMz = _mm_and_ps(_mm_cmpge_ps(XMMz, _mm_set1_ps(0.0f)),
2190
- _mm_set1_ps(-1.0f));
2191
- }
2192
- }
2193
- #elif defined USEAVX
2194
- void L1_MFC::prepare_for_sg_update(
2195
- __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
2196
- {
2197
- calc_z(XMMz, model.k, p, q);
2198
- if(N->r > 0)
2199
- {
2200
- XMMerror = _mm_add_pd(XMMerror, _mm_cvtps_pd(_mm_and_ps(
2201
- _mm_cmpge_ps(_mm256_castps256_ps128(XMMz),
2202
- _mm_set1_ps(0.0f)), _mm_set1_ps(1.0f))));
2203
- XMMz = _mm256_sub_ps(_mm256_set1_ps(1.0f), XMMz);
2204
- XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(_mm_max_ps(
2205
- _mm_set1_ps(0.0f), _mm256_castps256_ps128(XMMz))));
2206
- XMMz = _mm256_and_ps(_mm256_cmp_ps(XMMz, _mm256_set1_ps(0.0f),
2207
- _CMP_GE_OS), _mm256_set1_ps(1.0f));
2208
- }
2209
- else
2210
- {
2211
- XMMerror = _mm_add_pd(XMMerror, _mm_cvtps_pd(_mm_and_ps(
2212
- _mm_cmplt_ps(_mm256_castps256_ps128(XMMz),
2213
- _mm_set1_ps(0.0f)), _mm_set1_ps(1.0f))));
2214
- XMMz = _mm256_add_ps(_mm256_set1_ps(1.0f), XMMz);
2215
- XMMloss = _mm_add_pd(XMMloss, _mm_cvtps_pd(_mm_max_ps(
2216
- _mm_set1_ps(0.0f), _mm256_castps256_ps128(XMMz))));
2217
- XMMz = _mm256_and_ps(_mm256_cmp_ps(XMMz, _mm256_set1_ps(0.0f),
2218
- _CMP_GE_OS), _mm256_set1_ps(-1.0f));
2219
- }
2220
- }
2221
- #else
2222
- void L1_MFC::prepare_for_sg_update()
2223
- {
2224
- calc_z(z, model.k, p, q);
2225
- if(N->r > 0)
2226
- {
2227
- loss += max(0.0f, 1-z);
2228
- error += z > 0? 1.0f: 0.0f;
2229
- z = z > 1? 0.0f: 1.0f;
2230
- }
2231
- else
2232
- {
2233
- loss += max(0.0f, 1+z);
2234
- error += z < 0? 1.0f: 0.0f;
2235
- z = z < -1? 0.0f: -1.0f;
2236
- }
2237
- }
2238
- #endif
2239
- //--------------------------------------
2240
- //------------One-class MF--------------
2241
- //--------------------------------------
2242
-
2243
- class BPRSolver : public SolverBase
2244
- {
2245
- public:
2246
- BPRSolver(Scheduler &scheduler, vector<BlockBase*> &blocks,
2247
- mf_float *PG, mf_float *QG, mf_model &model, mf_parameter param,
2248
- bool &slow_only, bool is_column_oriented)
2249
- : SolverBase(scheduler, blocks, PG, QG, model, param, slow_only),
2250
- is_column_oriented(is_column_oriented) {}
2251
-
2252
- protected:
2253
- #if defined USESSE
2254
- static void calc_z(__m128 &XMMz, mf_int k,
2255
- mf_float *p, mf_float *q, mf_float *w);
2256
- void arrange_block(__m128d &XMMloss, __m128d &XMMerror);
2257
- void prepare_for_sg_update(
2258
- __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
2259
- void sg_update(mf_int d_begin, mf_int d_end, __m128 XMMz,
2260
- __m128 XMMlambda_p1, __m128 XMMlambda_q1,
2261
- __m128 XMMlambda_p2, __m128 XMMlambda_q2,
2262
- __m128 XMMeta, __m128 XMMrk);
2263
- void finalize(__m128d XMMloss, __m128d XMMerror);
2264
- #elif defined USEAVX
2265
- static void calc_z(__m256 &XMMz, mf_int k,
2266
- mf_float *p, mf_float *q, mf_float *w);
2267
- void arrange_block(__m128d &XMMloss, __m128d &XMMerror);
2268
- void prepare_for_sg_update(
2269
- __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror);
2270
- void sg_update(mf_int d_begin, mf_int d_end, __m256 XMMz,
2271
- __m256 XMMlambda_p1, __m256 XMMlambda_q1,
2272
- __m256 XMMlambda_p2, __m256 XMMlambda_q2,
2273
- __m256 XMMeta, __m256 XMMrk);
2274
- void finalize(__m128d XMMloss, __m128d XMMerror);
2275
- #else
2276
- static void calc_z(mf_float &z, mf_int k,
2277
- mf_float *p, mf_float *q, mf_float *w);
2278
- void arrange_block();
2279
- void prepare_for_sg_update();
2280
- void sg_update(mf_int d_begin, mf_int d_end, mf_float rk);
2281
- void finalize();
2282
- #endif
2283
- void update() { ++pG; ++qG; ++wG; };
2284
- virtual void prepare_negative() = 0;
2285
-
2286
- bool is_column_oriented;
2287
- mf_int bpr_bid;
2288
- mf_float *w;
2289
- mf_float *wG;
2290
- };
2291
-
2292
-
2293
- #if defined USESSE
2294
- inline void BPRSolver::calc_z(
2295
- __m128 &XMMz, mf_int k, mf_float *p, mf_float *q, mf_float *w)
2296
- {
2297
- XMMz = _mm_setzero_ps();
2298
- for(mf_int d = 0; d < k; d += 4)
2299
- XMMz = _mm_add_ps(XMMz, _mm_mul_ps(_mm_load_ps(p+d),
2300
- _mm_sub_ps(_mm_load_ps(q+d), _mm_load_ps(w+d))));
2301
- // Bit-wise representation of 177 is {1,0}+{1,1}+{0,0}+{0,1} from
2302
- // high-bit to low-bit, where "+" means concatenating two arrays.
2303
- __m128 XMMtmp = _mm_add_ps(XMMz, _mm_shuffle_ps(XMMz, XMMz, 177));
2304
- // Bit-wise representation of 78 is {0,1}+{0,0}+{1,1}+{1,0} from
2305
- // high-bit to low-bit, where "+" means concatenating two arrays.
2306
- XMMz = _mm_add_ps(XMMz, _mm_shuffle_ps(XMMtmp, XMMtmp, 78));
2307
- }
2308
-
2309
- void BPRSolver::arrange_block(__m128d &XMMloss, __m128d &XMMerror)
2310
- {
2311
- XMMloss = _mm_setzero_pd();
2312
- XMMerror = _mm_setzero_pd();
2313
- bid = scheduler.get_job();
2314
- block = blocks[bid];
2315
- block->reload();
2316
- bpr_bid = scheduler.get_bpr_job(bid, is_column_oriented);
2317
- }
2318
-
2319
- void BPRSolver::finalize(__m128d XMMloss, __m128d XMMerror)
2320
- {
2321
- _mm_store_sd(&loss, XMMloss);
2322
- _mm_store_sd(&error, XMMerror);
2323
- scheduler.put_job(bid, loss, error);
2324
- scheduler.put_bpr_job(bid, bpr_bid);
2325
- }
2326
-
2327
- void BPRSolver::sg_update(mf_int d_begin, mf_int d_end, __m128 XMMz,
2328
- __m128 XMMlambda_p1, __m128 XMMlambda_q1,
2329
- __m128 XMMlambda_p2, __m128 XMMlambda_q2,
2330
- __m128 XMMeta, __m128 XMMrk)
2331
- {
2332
- __m128 XMMpG = _mm_load1_ps(pG);
2333
- __m128 XMMqG = _mm_load1_ps(qG);
2334
- __m128 XMMwG = _mm_load1_ps(wG);
2335
- __m128 XMMeta_p = _mm_mul_ps(XMMeta, _mm_rsqrt_ps(XMMpG));
2336
- __m128 XMMeta_q = _mm_mul_ps(XMMeta, _mm_rsqrt_ps(XMMqG));
2337
- __m128 XMMeta_w = _mm_mul_ps(XMMeta, _mm_rsqrt_ps(XMMwG));
2338
-
2339
- __m128 XMMpG1 = _mm_setzero_ps();
2340
- __m128 XMMqG1 = _mm_setzero_ps();
2341
- __m128 XMMwG1 = _mm_setzero_ps();
2342
-
2343
- for(mf_int d = d_begin; d < d_end; d += 4)
2344
- {
2345
- __m128 XMMp = _mm_load_ps(p+d);
2346
- __m128 XMMq = _mm_load_ps(q+d);
2347
- __m128 XMMw = _mm_load_ps(w+d);
2348
-
2349
- __m128 XMMpg = _mm_add_ps(_mm_mul_ps(XMMlambda_p2, XMMp),
2350
- _mm_mul_ps(XMMz, _mm_sub_ps(XMMw, XMMq)));
2351
- __m128 XMMqg = _mm_sub_ps(_mm_mul_ps(XMMlambda_q2, XMMq),
2352
- _mm_mul_ps(XMMz, XMMp));
2353
- __m128 XMMwg = _mm_add_ps(_mm_mul_ps(XMMlambda_q2, XMMw),
2354
- _mm_mul_ps(XMMz, XMMp));
2355
-
2356
- XMMpG1 = _mm_add_ps(XMMpG1, _mm_mul_ps(XMMpg, XMMpg));
2357
- XMMqG1 = _mm_add_ps(XMMqG1, _mm_mul_ps(XMMqg, XMMqg));
2358
- XMMwG1 = _mm_add_ps(XMMwG1, _mm_mul_ps(XMMwg, XMMwg));
2359
-
2360
- XMMp = _mm_sub_ps(XMMp, _mm_mul_ps(XMMeta_p, XMMpg));
2361
- XMMq = _mm_sub_ps(XMMq, _mm_mul_ps(XMMeta_q, XMMqg));
2362
- XMMw = _mm_sub_ps(XMMw, _mm_mul_ps(XMMeta_w, XMMwg));
2363
-
2364
- _mm_store_ps(p+d, XMMp);
2365
- _mm_store_ps(q+d, XMMq);
2366
- _mm_store_ps(w+d, XMMw);
2367
- }
2368
-
2369
- mf_float tmp = 0;
2370
- _mm_store_ss(&tmp, XMMlambda_p1);
2371
- if(tmp > 0)
2372
- {
2373
- for(mf_int d = d_begin; d < d_end; d += 4)
2374
- {
2375
- __m128 XMMp = _mm_load_ps(p+d);
2376
- __m128 XMMflip = _mm_and_ps(_mm_cmple_ps(XMMp, _mm_set1_ps(0.0f)),
2377
- _mm_set1_ps(-0.0f));
2378
- XMMp = _mm_xor_ps(XMMflip,
2379
- _mm_max_ps(_mm_sub_ps(_mm_xor_ps(XMMp, XMMflip),
2380
- _mm_mul_ps(XMMeta_p, XMMlambda_p1)), _mm_set1_ps(0.0f)));
2381
- _mm_store_ps(p+d, XMMp);
2382
- }
2383
- }
2384
-
2385
- _mm_store_ss(&tmp, XMMlambda_q1);
2386
- if(tmp > 0)
2387
- {
2388
- for(mf_int d = d_begin; d < d_end; d += 4)
2389
- {
2390
- __m128 XMMq = _mm_load_ps(q+d);
2391
- __m128 XMMw = _mm_load_ps(w+d);
2392
- __m128 XMMflip = _mm_and_ps(_mm_cmple_ps(XMMq, _mm_set1_ps(0.0f)),
2393
- _mm_set1_ps(-0.0f));
2394
- XMMq = _mm_xor_ps(XMMflip,
2395
- _mm_max_ps(_mm_sub_ps(_mm_xor_ps(XMMq, XMMflip),
2396
- _mm_mul_ps(XMMeta_q, XMMlambda_q1)), _mm_set1_ps(0.0f)));
2397
- _mm_store_ps(q+d, XMMq);
2398
-
2399
-
2400
- XMMflip = _mm_and_ps(_mm_cmple_ps(XMMw, _mm_set1_ps(0.0f)),
2401
- _mm_set1_ps(-0.0f));
2402
- XMMw = _mm_xor_ps(XMMflip,
2403
- _mm_max_ps(_mm_sub_ps(_mm_xor_ps(XMMw, XMMflip),
2404
- _mm_mul_ps(XMMeta_w, XMMlambda_q1)), _mm_set1_ps(0.0f)));
2405
- _mm_store_ps(w+d, XMMw);
2406
- }
2407
- }
2408
-
2409
- if(param.do_nmf)
2410
- {
2411
- for(mf_int d = d_begin; d < d_end; d += 4)
2412
- {
2413
- __m128 XMMp = _mm_load_ps(p+d);
2414
- __m128 XMMq = _mm_load_ps(q+d);
2415
- __m128 XMMw = _mm_load_ps(w+d);
2416
- XMMp = _mm_max_ps(XMMp, _mm_set1_ps(0.0f));
2417
- XMMq = _mm_max_ps(XMMq, _mm_set1_ps(0.0f));
2418
- XMMw = _mm_max_ps(XMMw, _mm_set1_ps(0.0f));
2419
- _mm_store_ps(p+d, XMMp);
2420
- _mm_store_ps(q+d, XMMq);
2421
- _mm_store_ps(w+d, XMMw);
2422
- }
2423
- }
2424
-
2425
- // Update the learning rate of latent vector p. The squared derivatives
2426
- // along all latent dimensions were computed above; here their average is
2427
- // added into the associated squared-gradient sum.
2428
- __m128 XMMtmp = _mm_add_ps(XMMpG1, _mm_movehl_ps(XMMpG1, XMMpG1));
2429
- XMMpG1 = _mm_add_ps(XMMpG1, _mm_shuffle_ps(XMMtmp, XMMtmp, 1));
2430
- XMMpG = _mm_add_ps(XMMpG, _mm_mul_ps(XMMpG1, XMMrk));
2431
- _mm_store_ss(pG, XMMpG);
2432
-
2433
- // Similar code is used to update learning rate of latent vector q.
2434
- XMMtmp = _mm_add_ps(XMMqG1, _mm_movehl_ps(XMMqG1, XMMqG1));
2435
- XMMqG1 = _mm_add_ps(XMMqG1, _mm_shuffle_ps(XMMtmp, XMMtmp, 1));
2436
- XMMqG = _mm_add_ps(XMMqG, _mm_mul_ps(XMMqG1, XMMrk));
2437
- _mm_store_ss(qG, XMMqG);
2438
-
2439
- // Similar code is used to update learning rate of latent vector w.
2440
- XMMtmp = _mm_add_ps(XMMwG1, _mm_movehl_ps(XMMwG1, XMMwG1));
2441
- XMMwG1 = _mm_add_ps(XMMwG1, _mm_shuffle_ps(XMMtmp, XMMtmp, 1));
2442
- XMMwG = _mm_add_ps(XMMwG, _mm_mul_ps(XMMwG1, XMMrk));
2443
- _mm_store_ss(wG, XMMwG);
2444
- }
2445
-
2446
- void BPRSolver::prepare_for_sg_update(
2447
- __m128 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
2448
- {
2449
- prepare_negative();
2450
- calc_z(XMMz, model.k, p, q, w);
2451
- _mm_store_ss(&z, XMMz);
2452
- z = exp(-z);
2453
- XMMloss = _mm_add_pd(XMMloss, _mm_set1_pd(log(1+z)));
2454
- XMMerror = XMMloss;
2455
- XMMz = _mm_set1_ps(z/(1+z));
2456
- }
2457
- #elif defined USEAVX
2458
- inline void BPRSolver::calc_z(
2459
- __m256 &XMMz, mf_int k, mf_float *p, mf_float *q, mf_float *w)
2460
- {
2461
- XMMz = _mm256_setzero_ps();
2462
- for(mf_int d = 0; d < k; d += 8)
2463
- XMMz = _mm256_add_ps(XMMz, _mm256_mul_ps(
2464
- _mm256_load_ps(p+d), _mm256_sub_ps(
2465
- _mm256_load_ps(q+d), _mm256_load_ps(w+d))));
2466
- XMMz = _mm256_add_ps(XMMz, _mm256_permute2f128_ps(XMMz, XMMz, 0x1));
2467
- XMMz = _mm256_hadd_ps(XMMz, XMMz);
2468
- XMMz = _mm256_hadd_ps(XMMz, XMMz);
2469
- }
2470
-
2471
- void BPRSolver::arrange_block(__m128d &XMMloss, __m128d &XMMerror)
2472
- {
2473
- XMMloss = _mm_setzero_pd();
2474
- XMMerror = _mm_setzero_pd();
2475
- bid = scheduler.get_job();
2476
- block = blocks[bid];
2477
- block->reload();
2478
- bpr_bid = scheduler.get_bpr_job(bid, is_column_oriented);
2479
- }
2480
-
2481
- void BPRSolver::finalize(__m128d XMMloss, __m128d XMMerror)
2482
- {
2483
- _mm_store_sd(&loss, XMMloss);
2484
- _mm_store_sd(&error, XMMerror);
2485
- scheduler.put_job(bid, loss, error);
2486
- scheduler.put_bpr_job(bid, bpr_bid);
2487
- }
2488
-
2489
- void BPRSolver::sg_update(mf_int d_begin, mf_int d_end, __m256 XMMz,
2490
- __m256 XMMlambda_p1, __m256 XMMlambda_q1,
2491
- __m256 XMMlambda_p2, __m256 XMMlambda_q2,
2492
- __m256 XMMeta, __m256 XMMrk)
2493
- {
2494
- __m256 XMMpG = _mm256_broadcast_ss(pG);
2495
- __m256 XMMqG = _mm256_broadcast_ss(qG);
2496
- __m256 XMMwG = _mm256_broadcast_ss(wG);
2497
- __m256 XMMeta_p =
2498
- _mm256_mul_ps(XMMeta, _mm256_rsqrt_ps(XMMpG));
2499
- __m256 XMMeta_q =
2500
- _mm256_mul_ps(XMMeta, _mm256_rsqrt_ps(XMMqG));
2501
- __m256 XMMeta_w =
2502
- _mm256_mul_ps(XMMeta, _mm256_rsqrt_ps(XMMwG));
2503
-
2504
- __m256 XMMpG1 = _mm256_setzero_ps();
2505
- __m256 XMMqG1 = _mm256_setzero_ps();
2506
- __m256 XMMwG1 = _mm256_setzero_ps();
2507
-
2508
- for(mf_int d = d_begin; d < d_end; d += 8)
2509
- {
2510
- __m256 XMMp = _mm256_load_ps(p+d);
2511
- __m256 XMMq = _mm256_load_ps(q+d);
2512
- __m256 XMMw = _mm256_load_ps(w+d);
2513
- __m256 XMMpg = _mm256_add_ps(_mm256_mul_ps(XMMlambda_p2, XMMp),
2514
- _mm256_mul_ps(XMMz, _mm256_sub_ps(XMMw, XMMq)));
2515
- __m256 XMMqg = _mm256_sub_ps(_mm256_mul_ps(XMMlambda_q2, XMMq),
2516
- _mm256_mul_ps(XMMz, XMMp));
2517
- __m256 XMMwg = _mm256_add_ps(_mm256_mul_ps(XMMlambda_q2, XMMw),
2518
- _mm256_mul_ps(XMMz, XMMp));
2519
-
2520
- XMMpG1 = _mm256_add_ps(XMMpG1, _mm256_mul_ps(XMMpg, XMMpg));
2521
- XMMqG1 = _mm256_add_ps(XMMqG1, _mm256_mul_ps(XMMqg, XMMqg));
2522
- XMMwG1 = _mm256_add_ps(XMMwG1, _mm256_mul_ps(XMMwg, XMMwg));
2523
-
2524
- XMMp = _mm256_sub_ps(XMMp, _mm256_mul_ps(XMMeta_p, XMMpg));
2525
- XMMq = _mm256_sub_ps(XMMq, _mm256_mul_ps(XMMeta_q, XMMqg));
2526
- XMMw = _mm256_sub_ps(XMMw, _mm256_mul_ps(XMMeta_w, XMMwg));
2527
-
2528
- _mm256_store_ps(p+d, XMMp);
2529
- _mm256_store_ps(q+d, XMMq);
2530
- _mm256_store_ps(w+d, XMMw);
2531
- }
2532
-
2533
- mf_float tmp = 0;
2534
- _mm_store_ss(&tmp, _mm256_castps256_ps128(XMMlambda_p1));
2535
- if(tmp > 0)
2536
- {
2537
- for(mf_int d = d_begin; d < d_end; d += 8)
2538
- {
2539
- __m256 XMMp = _mm256_load_ps(p+d);
2540
- __m256 XMMflip =
2541
- _mm256_and_ps(
2542
- _mm256_cmp_ps(XMMp, _mm256_set1_ps(0.0f), _CMP_LE_OS),
2543
- _mm256_set1_ps(-0.0f));
2544
- XMMp = _mm256_xor_ps(XMMflip,
2545
- _mm256_max_ps(_mm256_sub_ps(_mm256_xor_ps(XMMp, XMMflip),
2546
- _mm256_mul_ps(XMMeta_p, XMMlambda_p1)),
2547
- _mm256_set1_ps(0.0f)));
2548
- _mm256_store_ps(p+d, XMMp);
2549
- }
2550
- }
2551
-
2552
- _mm_store_ss(&tmp, _mm256_castps256_ps128(XMMlambda_q1));
2553
- if(tmp > 0)
2554
- {
2555
- for(mf_int d = d_begin; d < d_end; d += 8)
2556
- {
2557
- __m256 XMMq = _mm256_load_ps(q+d);
2558
- __m256 XMMw = _mm256_load_ps(w+d);
2559
- __m256 XMMflip;
2560
-
2561
- XMMflip = _mm256_and_ps(
2562
- _mm256_cmp_ps(XMMq, _mm256_set1_ps(0.0f), _CMP_LE_OS),
2563
- _mm256_set1_ps(-0.0f));
2564
- XMMq = _mm256_xor_ps(XMMflip,
2565
- _mm256_max_ps(_mm256_sub_ps(_mm256_xor_ps(XMMq, XMMflip),
2566
- _mm256_mul_ps(XMMeta_q, XMMlambda_q1)),
2567
- _mm256_set1_ps(0.0f)));
2568
- _mm256_store_ps(q+d, XMMq);
2569
-
2570
-
2571
- XMMflip = _mm256_and_ps(
2572
- _mm256_cmp_ps(XMMw, _mm256_set1_ps(0.0f), _CMP_LE_OS),
2573
- _mm256_set1_ps(-0.0f));
2574
- XMMw = _mm256_xor_ps(XMMflip,
2575
- _mm256_max_ps(_mm256_sub_ps(_mm256_xor_ps(XMMw, XMMflip),
2576
- _mm256_mul_ps(XMMeta_w, XMMlambda_q1)),
2577
- _mm256_set1_ps(0.0f)));
2578
- _mm256_store_ps(w+d, XMMw);
2579
- }
2580
- }
2581
-
2582
- if(param.do_nmf)
2583
- {
2584
- for(mf_int d = d_begin; d < d_end; d += 8)
2585
- {
2586
- __m256 XMMp = _mm256_load_ps(p+d);
2587
- __m256 XMMq = _mm256_load_ps(q+d);
2588
- __m256 XMMw = _mm256_load_ps(w+d);
2589
- XMMp = _mm256_max_ps(XMMp, _mm256_set1_ps(0.0f));
2590
- XMMq = _mm256_max_ps(XMMq, _mm256_set1_ps(0.0f));
2591
- XMMw = _mm256_max_ps(XMMw, _mm256_set1_ps(0.0f));
2592
- _mm256_store_ps(p+d, XMMp);
2593
- _mm256_store_ps(q+d, XMMq);
2594
- _mm256_store_ps(w+d, XMMw);
2595
- }
2596
- }
2597
-
2598
- XMMpG1 = _mm256_add_ps(XMMpG1,
2599
- _mm256_permute2f128_ps(XMMpG1, XMMpG1, 0x1));
2600
- XMMpG1 = _mm256_hadd_ps(XMMpG1, XMMpG1);
2601
- XMMpG1 = _mm256_hadd_ps(XMMpG1, XMMpG1);
2602
-
2603
- XMMqG1 = _mm256_add_ps(XMMqG1,
2604
- _mm256_permute2f128_ps(XMMqG1, XMMqG1, 0x1));
2605
- XMMqG1 = _mm256_hadd_ps(XMMqG1, XMMqG1);
2606
- XMMqG1 = _mm256_hadd_ps(XMMqG1, XMMqG1);
2607
-
2608
- XMMwG1 = _mm256_add_ps(XMMwG1,
2609
- _mm256_permute2f128_ps(XMMwG1, XMMwG1, 0x1));
2610
- XMMwG1 = _mm256_hadd_ps(XMMwG1, XMMwG1);
2611
- XMMwG1 = _mm256_hadd_ps(XMMwG1, XMMwG1);
2612
-
2613
- XMMpG = _mm256_add_ps(XMMpG, _mm256_mul_ps(XMMpG1, XMMrk));
2614
- XMMqG = _mm256_add_ps(XMMqG, _mm256_mul_ps(XMMqG1, XMMrk));
2615
- XMMwG = _mm256_add_ps(XMMwG, _mm256_mul_ps(XMMwG1, XMMrk));
2616
-
2617
- _mm_store_ss(pG, _mm256_castps256_ps128(XMMpG));
2618
- _mm_store_ss(qG, _mm256_castps256_ps128(XMMqG));
2619
- _mm_store_ss(wG, _mm256_castps256_ps128(XMMwG));
2620
- }
2621
-
2622
- void BPRSolver::prepare_for_sg_update(
2623
- __m256 &XMMz, __m128d &XMMloss, __m128d &XMMerror)
2624
- {
2625
- prepare_negative();
2626
- calc_z(XMMz, model.k, p, q, w);
2627
- _mm_store_ss(&z, _mm256_castps256_ps128(XMMz));
2628
- z = exp(-z);
2629
- XMMloss = _mm_add_pd(XMMloss, _mm_set1_pd(log(1+z)));
2630
- XMMerror = XMMloss;
2631
- XMMz = _mm256_set1_ps(z/(1+z));
2632
- }
2633
- #else
2634
- inline void BPRSolver::calc_z(
2635
- mf_float &z, mf_int k, mf_float *p, mf_float *q, mf_float *w)
2636
- {
2637
- z = 0;
2638
- for(mf_int d = 0; d < k; ++d)
2639
- z += p[d]*(q[d]-w[d]);
2640
- }
2641
-
2642
- void BPRSolver::arrange_block()
2643
- {
2644
- loss = 0.0;
2645
- error = 0.0;
2646
- bid = scheduler.get_job();
2647
- block = blocks[bid];
2648
- block->reload();
2649
- bpr_bid = scheduler.get_bpr_job(bid, is_column_oriented);
2650
- }
2651
-
2652
- void BPRSolver::finalize()
2653
- {
2654
- scheduler.put_job(bid, loss, error);
2655
- scheduler.put_bpr_job(bid, bpr_bid);
2656
- }
2657
-
2658
- void BPRSolver::sg_update(mf_int d_begin, mf_int d_end, mf_float rk)
2659
- {
2660
- mf_float eta_p = param.eta*qrsqrt(*pG);
2661
- mf_float eta_q = param.eta*qrsqrt(*qG);
2662
- mf_float eta_w = param.eta*qrsqrt(*wG);
2663
-
2664
- mf_float pG1 = 0;
2665
- mf_float qG1 = 0;
2666
- mf_float wG1 = 0;
2667
-
2668
- for(mf_int d = d_begin; d < d_end; ++d)
2669
- {
2670
- mf_float gp = z*(w[d]-q[d]) + lambda_p2*p[d];
2671
- mf_float gq = -z*p[d] + lambda_q2*q[d];
2672
- mf_float gw = z*p[d] + lambda_q2*w[d];
2673
-
2674
- pG1 += gp*gp;
2675
- qG1 += gq*gq;
2676
- wG1 += gw*gw;
2677
-
2678
- p[d] -= eta_p*gp;
2679
- q[d] -= eta_q*gq;
2680
- w[d] -= eta_w*gw;
2681
- }
2682
-
2683
- if(lambda_p1 > 0)
2684
- {
2685
- for(mf_int d = d_begin; d < d_end; ++d)
2686
- {
2687
- mf_float p1 = max(abs(p[d])-lambda_p1*eta_p, 0.0f);
2688
- p[d] = p[d] >= 0? p1: -p1;
2689
- }
2690
- }
2691
-
2692
- if(lambda_q1 > 0)
2693
- {
2694
- for(mf_int d = d_begin; d < d_end; ++d)
2695
- {
2696
- mf_float q1 = max(abs(w[d])-lambda_q1*eta_w, 0.0f);
2697
- w[d] = w[d] >= 0? q1: -q1;
2698
- q1 = max(abs(q[d])-lambda_q1*eta_q, 0.0f);
2699
- q[d] = q[d] >= 0? q1: -q1;
2700
- }
2701
- }
2702
-
2703
- if(param.do_nmf)
2704
- {
2705
- for(mf_int d = d_begin; d < d_end; ++d)
2706
- {
2707
- p[d] = max(p[d], (mf_float)0.0);
2708
- q[d] = max(q[d], (mf_float)0.0);
2709
- w[d] = max(w[d], (mf_float)0.0);
2710
- }
2711
- }
2712
-
2713
- *pG += pG1*rk;
2714
- *qG += qG1*rk;
2715
- *wG += wG1*rk;
2716
- }
2717
-
2718
- void BPRSolver::prepare_for_sg_update()
2719
- {
2720
- prepare_negative();
2721
- calc_z(z, model.k, p, q, w);
2722
- z = exp(-z);
2723
- loss += log(1+z);
2724
- error = loss;
2725
- z = z/(1+z);
2726
- }
2727
- #endif
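Putting the scalar BPR pieces together: with p the user's latent vector, q the positive item's vector, and w the sampled negative item's vector (see prepare_negative in the subclasses below), the solver maximizes the log-likelihood of ranking q above w. A reconstruction of what the code computes:

```latex
z = p^\top (q - w), \qquad
\ell = \log\bigl(1 + e^{-z}\bigr), \qquad
\sigma(-z) = \frac{e^{-z}}{1 + e^{-z}}
% The code stores \sigma(-z) back into z, so the gradients used in
% sg_update are
\frac{\partial \ell}{\partial p} = -\sigma(-z)\,(q - w), \quad
\frac{\partial \ell}{\partial q} = -\sigma(-z)\,p, \quad
\frac{\partial \ell}{\partial w} = +\sigma(-z)\,p
```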
2728
-
2729
- class COL_BPR_MFOC : public BPRSolver
2730
- {
2731
- public:
2732
- COL_BPR_MFOC(Scheduler &scheduler, vector<BlockBase*> &blocks,
2733
- mf_float *PG, mf_float *QG, mf_model &model,
2734
- mf_parameter param, bool &slow_only,
2735
- bool is_column_oriented=true)
2736
- : BPRSolver(scheduler, blocks, PG, QG, model, param,
2737
- slow_only, is_column_oriented) {}
2738
- protected:
2739
- #if defined USESSE
2740
- void load_fixed_variables(
2741
- __m128 &XMMlambda_p1, __m128 &XMMlambda_q1,
2742
- __m128 &XMMlambda_p2, __m128 &XMMlambda_q2,
2743
- __m128 &XMMeta, __m128 &XMMrk_slow,
2744
- __m128 &XMMrk_fast);
2745
- #elif defined USEAVX
2746
- void load_fixed_variables(
2747
- __m256 &XMMlambda_p1, __m256 &XMMlambda_q1,
2748
- __m256 &XMMlambda_p2, __m256 &XMMlambda_q2,
2749
- __m256 &XMMeta, __m256 &XMMrk_slow,
2750
- __m256 &XMMrk_fast);
2751
- #else
2752
- void load_fixed_variables();
2753
- #endif
2754
- void prepare_negative();
2755
- };
2756
-
2757
- void COL_BPR_MFOC::prepare_negative()
2758
- {
2759
- mf_int negative = scheduler.get_negative(bid, bpr_bid, model.m, model.n,
2760
- is_column_oriented);
2761
- w = model.P + negative*model.k;
2762
- wG = PG + negative*2;
2763
- swap(p, q);
2764
- swap(pG, qG);
2765
- }
2766
-
2767
- #if defined USESSE
2768
- void COL_BPR_MFOC::load_fixed_variables(
2769
- __m128 &XMMlambda_p1, __m128 &XMMlambda_q1,
2770
- __m128 &XMMlambda_p2, __m128 &XMMlambda_q2,
2771
- __m128 &XMMeta, __m128 &XMMrk_slow,
2772
- __m128 &XMMrk_fast)
2773
- {
2774
- XMMlambda_p1 = _mm_set1_ps(param.lambda_q1);
2775
- XMMlambda_q1 = _mm_set1_ps(param.lambda_p1);
2776
- XMMlambda_p2 = _mm_set1_ps(param.lambda_q2);
2777
- XMMlambda_q2 = _mm_set1_ps(param.lambda_p2);
2778
- XMMeta = _mm_set1_ps(param.eta);
2779
- XMMrk_slow = _mm_set1_ps((mf_float)1.0/kALIGN);
2780
- XMMrk_fast = _mm_set1_ps((mf_float)1.0/(model.k-kALIGN));
2781
- }
2782
- #elif defined USEAVX
2783
- void COL_BPR_MFOC::load_fixed_variables(
2784
- __m256 &XMMlambda_p1, __m256 &XMMlambda_q1,
2785
- __m256 &XMMlambda_p2, __m256 &XMMlambda_q2,
2786
- __m256 &XMMeta, __m256 &XMMrk_slow,
2787
- __m256 &XMMrk_fast)
2788
- {
2789
- XMMlambda_p1 = _mm256_set1_ps(param.lambda_q1);
2790
- XMMlambda_q1 = _mm256_set1_ps(param.lambda_p1);
2791
- XMMlambda_p2 = _mm256_set1_ps(param.lambda_q2);
2792
- XMMlambda_q2 = _mm256_set1_ps(param.lambda_p2);
2793
- XMMeta = _mm256_set1_ps(param.eta);
2794
- XMMrk_slow = _mm256_set1_ps((mf_float)1.0/kALIGN);
2795
- XMMrk_fast = _mm256_set1_ps((mf_float)1.0/(model.k-kALIGN));
2796
- }
2797
- #else
2798
- void COL_BPR_MFOC::load_fixed_variables()
2799
- {
2800
- lambda_p1 = param.lambda_q1;
2801
- lambda_q1 = param.lambda_p1;
2802
- lambda_p2 = param.lambda_q2;
2803
- lambda_q2 = param.lambda_p2;
2804
- rk_slow = (mf_float)1.0/kALIGN;
2805
- rk_fast = (mf_float)1.0/(model.k-kALIGN);
2806
- }
2807
- #endif
2808
-
2809
- class ROW_BPR_MFOC : public BPRSolver
2810
- {
2811
- public:
2812
- ROW_BPR_MFOC(Scheduler &scheduler, vector<BlockBase*> &blocks,
2813
- mf_float *PG, mf_float *QG, mf_model &model,
2814
- mf_parameter param, bool &slow_only,
2815
- bool is_column_oriented = false)
2816
- : BPRSolver(scheduler, blocks, PG, QG, model, param,
2817
- slow_only, is_column_oriented) {}
2818
- protected:
2819
- void prepare_negative();
2820
- };
2821
-
2822
- void ROW_BPR_MFOC::prepare_negative()
2823
- {
2824
- mf_int negative = scheduler.get_negative(bid, bpr_bid, model.m, model.n,
2825
- is_column_oriented);
2826
- w = model.Q + negative*model.k;
2827
- wG = QG + negative*2;
2828
- }
2829
-
2830
-
2831
- class SolverFactory
2832
- {
2833
- public:
2834
- static shared_ptr<SolverBase> get_solver(
2835
- Scheduler &scheduler,
2836
- vector<BlockBase*> &blocks,
2837
- mf_float *PG,
2838
- mf_float *QG,
2839
- mf_model &model,
2840
- mf_parameter param,
2841
- bool &slow_only);
2842
- };
2843
-
2844
- shared_ptr<SolverBase> SolverFactory::get_solver(
2845
- Scheduler &scheduler,
2846
- vector<BlockBase*> &blocks,
2847
- mf_float *PG,
2848
- mf_float *QG,
2849
- mf_model &model,
2850
- mf_parameter param,
2851
- bool &slow_only)
2852
- {
2853
- shared_ptr<SolverBase> solver;
2854
-
2855
- switch(param.fun)
2856
- {
2857
- case P_L2_MFR:
2858
- solver = shared_ptr<SolverBase>(new L2_MFR(scheduler, blocks,
2859
- PG, QG, model, param, slow_only));
2860
- break;
2861
- case P_L1_MFR:
2862
- solver = shared_ptr<SolverBase>(new L1_MFR(scheduler, blocks,
2863
- PG, QG, model, param, slow_only));
2864
- break;
2865
- case P_KL_MFR:
2866
- solver = shared_ptr<SolverBase>(new KL_MFR(scheduler, blocks,
2867
- PG, QG, model, param, slow_only));
2868
- break;
2869
- case P_LR_MFC:
2870
- solver = shared_ptr<SolverBase>(new LR_MFC(scheduler, blocks,
2871
- PG, QG, model, param, slow_only));
2872
- break;
2873
- case P_L2_MFC:
2874
- solver = shared_ptr<SolverBase>(new L2_MFC(scheduler, blocks,
2875
- PG, QG, model, param, slow_only));
2876
- break;
2877
- case P_L1_MFC:
2878
- solver = shared_ptr<SolverBase>(new L1_MFC(scheduler, blocks,
2879
- PG, QG, model, param, slow_only));
2880
- break;
2881
- case P_ROW_BPR_MFOC:
2882
- solver = shared_ptr<SolverBase>(new ROW_BPR_MFOC(scheduler,
2883
- blocks, PG, QG, model, param, slow_only));
2884
- break;
2885
- case P_COL_BPR_MFOC:
2886
- solver = shared_ptr<SolverBase>(new COL_BPR_MFOC(scheduler,
2887
- blocks, PG, QG, model, param, slow_only));
2888
- break;
2889
- default:
2890
- throw invalid_argument("unknown error function");
2891
- }
2892
- return solver;
2893
- }
2894
-
2895
- void fpsg_core(
2896
- Utility &util,
2897
- Scheduler &sched,
2898
- mf_problem *tr,
2899
- mf_problem *va,
2900
- mf_parameter param,
2901
- mf_float scale,
2902
- vector<BlockBase*> &block_ptrs,
2903
- vector<mf_int> &omega_p,
2904
- vector<mf_int> &omega_q,
2905
- shared_ptr<mf_model> &model,
2906
- vector<mf_int> cv_blocks,
2907
- mf_double *cv_error)
2908
- {
2909
- #if defined USESSE || defined USEAVX
2910
- auto flush_zero_mode = _MM_GET_FLUSH_ZERO_MODE();
2911
- _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);
2912
- #endif
2913
- if(tr->nnz == 0)
2914
- {
2915
- cout << "warning: train on an empty training set" << endl;
2916
- return;
2917
- }
2918
-
2919
- if(param.fun == P_L2_MFR ||
2920
- param.fun == P_L1_MFR ||
2921
- param.fun == P_KL_MFR)
2922
- {
2923
- switch(param.fun)
2924
- {
2925
- case P_L2_MFR:
2926
- param.lambda_p2 /= scale;
2927
- param.lambda_q2 /= scale;
2928
- param.lambda_p1 /= (mf_float)pow(scale, 1.5);
2929
- param.lambda_q1 /= (mf_float)pow(scale, 1.5);
2930
- break;
2931
- case P_L1_MFR:
2932
- case P_KL_MFR:
2933
- param.lambda_p1 /= sqrt(scale);
2934
- param.lambda_q1 /= sqrt(scale);
2935
- break;
2936
- }
2937
- }
2938
-
2939
- if(!param.quiet)
2940
- {
2941
- cout.width(4);
2942
- cout << "iter";
2943
- cout.width(13);
2944
- cout << "tr_"+util.get_error_legend();
2945
- if(va->nnz != 0)
2946
- {
2947
- cout.width(13);
2948
- cout << "va_"+util.get_error_legend();
2949
- }
2950
- cout.width(13);
2951
- cout << "obj";
2952
- cout << "\n";
2953
- }
2954
-
2955
- bool slow_only = param.lambda_p1 == 0 && param.lambda_q1 == 0;
2956
- vector<mf_float> PG(model->m*2, 1), QG(model->n*2, 1);
2957
-
2958
- vector<shared_ptr<SolverBase>> solvers(param.nr_threads);
2959
- vector<thread> threads;
2960
- threads.reserve(param.nr_threads);
2961
- for(mf_int i = 0; i < param.nr_threads; ++i)
2962
- {
2963
- solvers[i] = SolverFactory::get_solver(sched, block_ptrs,
2964
- PG.data(), QG.data(),
2965
- *model, param, slow_only);
2966
- threads.emplace_back(&SolverBase::run, solvers[i].get());
2967
- }
2968
-
2969
- for(mf_int iter = 0; iter < param.nr_iters; ++iter)
2970
- {
2971
- sched.wait_for_jobs_done();
2972
-
2973
- if(!param.quiet)
2974
- {
2975
- mf_double reg = 0;
2976
- mf_double reg1 = util.calc_reg1(*model, param.lambda_p1,
2977
- param.lambda_q1, omega_p, omega_q);
2978
- mf_double reg2 = util.calc_reg2(*model, param.lambda_p2,
2979
- param.lambda_q2, omega_p, omega_q);
2980
- mf_double tr_loss = sched.get_loss();
2981
- mf_double tr_error = sched.get_error()/tr->nnz;
2982
-
2983
- switch(param.fun)
2984
- {
2985
- case P_L2_MFR:
2986
- reg = (reg1+reg2)*scale*scale;
2987
- tr_loss *= scale*scale;
2988
- tr_error = sqrt(tr_error*scale*scale);
2989
- break;
2990
- case P_L1_MFR:
2991
- case P_KL_MFR:
2992
- reg = (reg1+reg2)*scale;
2993
- tr_loss *= scale;
2994
- tr_error *= scale;
2995
- break;
2996
- default:
2997
- reg = reg1+reg2;
2998
- break;
2999
- }
3000
-
3001
- cout.width(4);
3002
- cout << iter;
3003
- cout.width(13);
3004
- cout << fixed << setprecision(4) << tr_error;
3005
- if(va->nnz != 0)
3006
- {
3007
- Block va_block(va->R, va->R+va->nnz);
3008
- vector<BlockBase*> va_blocks(1, &va_block);
3009
- vector<mf_int> va_block_ids(1, 0);
3010
- mf_double va_error =
3011
- util.calc_error(va_blocks, va_block_ids, *model)/va->nnz;
3012
- switch(param.fun)
3013
- {
3014
- case P_L2_MFR:
3015
- va_error = sqrt(va_error*scale*scale);
3016
- break;
3017
- case P_L1_MFR:
3018
- case P_KL_MFR:
3019
- va_error *= scale;
3020
- break;
3021
- }
3022
-
3023
- cout.width(13);
3024
- cout << fixed << setprecision(4) << va_error;
3025
- }
3026
- cout.width(13);
3027
- cout << scientific << setprecision(4) << reg+tr_loss;
3028
- cout << "\n" << flush;
3029
- }
3030
-
3031
- if(iter == 0)
3032
- slow_only = false;
3033
- if(iter == param.nr_iters - 1)
3034
- sched.terminate();
3035
- sched.resume();
3036
- }
3037
-
3038
- for(auto &thread : threads)
3039
- thread.join();
3040
-
3041
- if(cv_error != nullptr && cv_blocks.size() > 0)
3042
- {
3043
- mf_long cv_count = 0;
3044
- for(auto block : cv_blocks)
3045
- cv_count += block_ptrs[block]->get_nnz();
3046
-
3047
- *cv_error = util.calc_error(block_ptrs, cv_blocks, *model)/cv_count;
3048
-
3049
- switch(param.fun)
3050
- {
3051
- case P_L2_MFR:
3052
- *cv_error = sqrt(*cv_error*scale*scale);
3053
- break;
3054
- case P_L1_MFR:
3055
- case P_KL_MFR:
3056
- *cv_error *= scale;
3057
- break;
3058
- }
3059
- }
3060
-
3061
- #if defined USESSE || defined USEAVX
3062
- _MM_SET_FLUSH_ZERO_MODE(flush_zero_mode);
3063
- #endif
3064
- }
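A note on the scale bookkeeping in fpsg_core (a reconstruction of the reasoning, not libmf documentation): the ratings were divided by scale before training, so P and Q each shrink by roughly sqrt(scale), and reported quantities must be mapped back to original units:

```latex
% Squared losses are 2-homogeneous in the data:
(r - \hat{z})^2
  = \mathrm{scale}^2 \Bigl(\tfrac{r}{\mathrm{scale}} - \tfrac{\hat{z}}{\mathrm{scale}}\Bigr)^2
% while absolute and KL losses are 1-homogeneous:
|r - \hat{z}|
  = \mathrm{scale}\,\Bigl|\tfrac{r}{\mathrm{scale}} - \tfrac{\hat{z}}{\mathrm{scale}}\Bigr|
```

This is why L2 quantities are multiplied by scale^2 (with RMSE taking a square root), L1/KL quantities by scale, and why the lambda coefficients were pre-divided by the matching powers of scale at the top of fpsg_core.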
3065
-
3066
- shared_ptr<mf_model> fpsg(
3067
- mf_problem const *tr_,
3068
- mf_problem const *va_,
3069
- mf_parameter param,
3070
- vector<mf_int> cv_blocks=vector<mf_int>(),
3071
- mf_double *cv_error=nullptr)
3072
- {
3073
- shared_ptr<mf_model> model;
3074
- try
3075
- {
3076
- Utility util(param.fun, param.nr_threads);
3077
- Scheduler sched(param.nr_bins, param.nr_threads, cv_blocks);
3078
- shared_ptr<mf_problem> tr;
3079
- shared_ptr<mf_problem> va;
3080
- vector<Block> blocks(param.nr_bins*param.nr_bins);
3081
- vector<BlockBase*> block_ptrs(param.nr_bins*param.nr_bins);
3082
- vector<mf_node*> ptrs;
3083
- vector<mf_int> p_map;
3084
- vector<mf_int> q_map;
3085
- vector<mf_int> inv_p_map;
3086
- vector<mf_int> inv_q_map;
3087
- vector<mf_int> omega_p;
3088
- vector<mf_int> omega_q;
3089
- mf_float avg = 0;
3090
- mf_float std_dev = 0;
3091
- mf_float scale = 1;
3092
-
3093
- if(param.copy_data)
3094
- {
3095
- tr = shared_ptr<mf_problem>(
3096
- Utility::copy_problem(tr_, true), deleter());
3097
- va = shared_ptr<mf_problem>(
3098
- Utility::copy_problem(va_, true), deleter());
3099
- }
3100
- else
3101
- {
3102
- tr = shared_ptr<mf_problem>(Utility::copy_problem(tr_, false));
3103
- va = shared_ptr<mf_problem>(Utility::copy_problem(va_, false));
3104
- }
3105
-
3106
- util.collect_info(*tr, avg, std_dev);
3107
-
3108
- if(param.fun == P_L2_MFR ||
3109
- param.fun == P_L1_MFR ||
3110
- param.fun == P_KL_MFR)
3111
- scale = max((mf_float)1e-4, std_dev);
3112
-
3113
- p_map = Utility::gen_random_map(tr->m);
3114
- q_map = Utility::gen_random_map(tr->n);
3115
- inv_p_map = Utility::gen_inv_map(p_map);
3116
- inv_q_map = Utility::gen_inv_map(q_map);
3117
- omega_p = vector<mf_int>(tr->m, 0);
3118
- omega_q = vector<mf_int>(tr->n, 0);
3119
-
3120
- util.shuffle_problem(*tr, p_map, q_map);
3121
- util.shuffle_problem(*va, p_map, q_map);
3122
- util.scale_problem(*tr, (mf_float)1.0/scale);
3123
- util.scale_problem(*va, (mf_float)1.0/scale);
3124
- ptrs = util.grid_problem(*tr, param.nr_bins, omega_p, omega_q, blocks);
3125
-
3126
- model = shared_ptr<mf_model>(Utility::init_model(param.fun,
3127
- tr->m, tr->n, param.k, avg/scale, omega_p, omega_q),
3128
- [] (mf_model *ptr) { mf_destroy_model(&ptr); });
3129
-
3130
- for(mf_int i = 0; i < (mf_int)blocks.size(); ++i)
3131
- block_ptrs[i] = &blocks[i];
3132
-
3133
- fpsg_core(util, sched, tr.get(), va.get(), param, scale,
3134
- block_ptrs, omega_p, omega_q, model, cv_blocks, cv_error);
3135
-
3136
- if(!param.copy_data)
3137
- {
3138
- util.scale_problem(*tr, scale);
3139
- util.scale_problem(*va, scale);
3140
- util.shuffle_problem(*tr, inv_p_map, inv_q_map);
3141
- util.shuffle_problem(*va, inv_p_map, inv_q_map);
3142
- }
3143
-
3144
- util.scale_model(*model, scale);
3145
- Utility::shrink_model(*model, param.k);
3146
- Utility::shuffle_model(*model, inv_p_map, inv_q_map);
3147
- }
3148
- catch(exception const &e)
3149
- {
3150
- cerr << e.what() << endl;
3151
- throw;
3152
- }
3153
- return model;
3154
- }
3155
-
3156
- shared_ptr<mf_model> fpsg_on_disk(
3157
- const string tr_path,
3158
- const string va_path,
3159
- mf_parameter param,
3160
- vector<mf_int> cv_blocks=vector<mf_int>(),
3161
- mf_double *cv_error=nullptr)
3162
- {
3163
- shared_ptr<mf_model> model;
3164
- try
3165
- {
3166
- Utility util(param.fun, param.nr_threads);
3167
- Scheduler sched(param.nr_bins, param.nr_threads, cv_blocks);
3168
- mf_problem tr = {};
3169
- mf_problem va = read_problem(va_path.c_str());
3170
- vector<BlockOnDisk> blocks(param.nr_bins*param.nr_bins);
3171
- vector<BlockBase*> block_ptrs(param.nr_bins*param.nr_bins);
3172
- vector<mf_int> p_map;
3173
- vector<mf_int> q_map;
3174
- vector<mf_int> inv_p_map;
3175
- vector<mf_int> inv_q_map;
3176
- vector<mf_int> omega_p;
3177
- vector<mf_int> omega_q;
3178
- mf_float avg = 0;
3179
- mf_float std_dev = 0;
3180
- mf_float scale = 1;
3181
-
3182
- util.collect_info_on_disk(tr_path, tr, avg, std_dev);
3183
-
3184
- if(param.fun == P_L2_MFR ||
3185
- param.fun == P_L1_MFR ||
3186
- param.fun == P_KL_MFR)
3187
- scale = max((mf_float)1e-4, std_dev);
3188
-
3189
- p_map = Utility::gen_random_map(tr.m);
3190
- q_map = Utility::gen_random_map(tr.n);
3191
- inv_p_map = Utility::gen_inv_map(p_map);
3192
- inv_q_map = Utility::gen_inv_map(q_map);
3193
- omega_p = vector<mf_int>(tr.m, 0);
3194
- omega_q = vector<mf_int>(tr.n, 0);
3195
-
3196
- util.shuffle_problem(va, p_map, q_map);
3197
- util.scale_problem(va, (mf_float)1.0/scale);
3198
-
3199
- util.grid_shuffle_scale_problem_on_disk(
3200
- tr.m, tr.n, param.nr_bins, scale, tr_path,
3201
- p_map, q_map, omega_p, omega_q, blocks);
3202
-
3203
- model = shared_ptr<mf_model>(Utility::init_model(param.fun,
3204
- tr.m, tr.n, param.k, avg/scale, omega_p, omega_q),
3205
- [] (mf_model *ptr) { mf_destroy_model(&ptr); });
3206
-
3207
- for(mf_int i = 0; i < (mf_int)blocks.size(); ++i)
3208
- block_ptrs[i] = &blocks[i];
3209
-
3210
- fpsg_core(util, sched, &tr, &va, param, scale,
3211
- block_ptrs, omega_p, omega_q, model, cv_blocks, cv_error);
3212
-
3213
- delete [] va.R;
3214
-
3215
- util.scale_model(*model, scale);
3216
- Utility::shrink_model(*model, param.k);
3217
- Utility::shuffle_model(*model, inv_p_map, inv_q_map);
3218
- }
3219
- catch(exception const &e)
3220
- {
3221
- cerr << e.what() << endl;
3222
- throw;
3223
- }
3224
- return model;
3225
- }
3226
-
3227
- // The function implements an efficient method to compute the objective
3228
- // function minimized by the coordinate descent method.
3229
- //
3230
- // \min_{P, Q} 0.5 * \sum_{(u,v)\in\Omega^+} (1-r_{u,v})^2 +
3231
- // 0.5 * \alpha \sum_{(u,v)\not\in\Omega^+} (c-r_{u,v})^2 +
3232
- // 0.5 * \lambda_p2 * ||P||_F^2 + 0.5 * \lambda_q2 * ||Q||_F^2
3233
- // where
3234
- // 1. (u,v) is a tuple of row index and column index,
3235
- // 2. \Omega^+ is a collection of (u,v) pairs specifying the locations of
3236
- // positive entries in the training matrix.
3237
- // 3. r_{u,v} is the predicted rating at (u,v)
3238
- // 4. \alpha is the weight of negative entries' loss.
3239
- // 5. c is the desired value at every negative entry.
3240
- // 6. ||P||_F is matrix P's Frobenius norm.
3241
- // 7. \lambda_p2 is the regularization coefficient of P.
3242
- //
3243
- // Note that coordinate descent method's P and Q are the transpose
3244
- // counterparts of P and Q in the stochastic gradient method. Let R denote
3245
- // the training matrix. For the stochastic gradient method, we have R ~ P^TQ.
3246
- // For the coordinate descent method, we have R ~ PQ^T.
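What makes the computation efficient is that the (m n - |\Omega^+|) negative terms are never enumerated. The code below relies on the algebraic identity (a reconstruction of the trick, with r_{u,v} = p_u^T q_v):

```latex
\sum_{(u,v)\notin\Omega^+} (c - r_{u,v})^2
  = \sum_{u=1}^{m}\sum_{v=1}^{n} (c - r_{u,v})^2
    - \sum_{(u,v)\in\Omega^+} (c - r_{u,v})^2
% and the all-entries term expands in O((m+n)d^2) time instead of O(mnd):
\sum_{u,v} \bigl(c - p_u^\top q_v\bigr)^2
  = c^2 m n
    - 2c \sum_{k=1}^{d} \Bigl(\sum_u P_{uk}\Bigr)\Bigl(\sum_v Q_{vk}\Bigr)
    + \operatorname{tr}\bigl((P^\top P)(Q^\top Q)\bigr)
% matching negative_loss1/2/3 and the positive_loss2 correction below.
```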
3247
- void calc_ccd_one_class_obj(const mf_int nr_threads,
3248
- const mf_float alpha, const mf_float c,
3249
- const mf_int m, const mf_int n, const mf_int d,
3250
- const mf_float lambda_p2, const mf_float lambda_q2,
3251
- const mf_float *P, const mf_float *Q,
3252
- shared_ptr<const mf_problem> data,
3253
- /*output*/ mf_double &obj,
3254
- /*output*/ mf_double &positive_loss,
3255
- /*output*/ mf_double &negative_loss,
3256
- /*output*/ mf_double &reg)
3257
- {
3258
- // Declare regularization term of P.
3259
- mf_double p_square_norm = 0.0;
3260
- // Reduce P along column axis, which is the sum of rows in P.
3261
- vector<mf_double> all_p_sum(d, 0.0);
3262
- // Compute square of Frobenius norm on P and sum of all rows in P.
3263
- for(mf_int k = 0; k < d; ++k)
3264
- {
3265
- // Declare a temporary buffer for all_p_sum[k] so OpenMP can reduce into it.
3266
- mf_double all_p_sum_k = 0.0;
3267
- #if defined USEOMP
3268
- #pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:p_square_norm,all_p_sum_k)
3269
- #endif
3270
- for(mf_int u = 0; u < m; ++u)
3271
- {
3272
- const mf_float &p_ku = P[u + k * m];
3273
- p_square_norm += p_ku * p_ku;
3274
- all_p_sum_k += p_ku;
3275
- }
3276
- all_p_sum[k] = all_p_sum_k;
3277
- }
3278
-
3279
- // Declare regularization term of Q
3280
- mf_double q_square_norm = 0.0;
3281
- // Reduce Q along column axis, which is the sum of rows in Q.
3282
- vector<mf_double> all_q_sum(d, 0.0);
3283
- // Compute square of Frobenius norm on Q and sum of all rows in Q.
3284
- for(mf_int k = 0; k < d; ++k)
3285
- {
3286
- // Declare a temporary buffer for all_q_sum[k] so OpenMP can reduce into it.
3287
- mf_double all_q_sum_k = 0.0;
3288
- #if defined USEOMP
3289
- #pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:q_square_norm,all_q_sum_k)
3290
- #endif
3291
- for(mf_int v = 0; v < n; ++v)
3292
- {
3293
- const mf_float &q_kv = Q[v + k * n];
3294
- q_square_norm += q_kv * q_kv;
3295
- all_q_sum_k += q_kv;
3296
- }
3297
- all_q_sum[k] = all_q_sum_k;
3298
- }
3299
-
3300
- // PTP = P^T * P, where P^T is the transpose of P. Note that P is an m-by-d
3301
- // matrix and PTP is a d-by-d matrix.
3302
- vector<mf_double> PTP(d * d, 0.0);
3303
- // QTQ = Q^T * Q, a d-by-d matrix.
3304
- vector<mf_double> QTQ(d * d, 0.0);
3305
- // We calculate PTP and QTQ because they are needed in the computation of
3306
- // negative entries' loss function.
3307
- for(mf_int k1 = 0; k1 < d; ++k1)
3308
- {
3309
- for(mf_int k2 = 0; k2 < d; ++k2)
3310
- {
3311
- // Inner product of the k1 and k2 columns in P, an m-by-d matrix.
3312
- mf_double p_k1_p_k2_inner_product = 0.0;
3313
- #if defined USEOMP
3314
- #pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:p_k1_p_k2_inner_product)
3315
- #endif
3316
- for(mf_int u = 0; u < m; ++u)
3317
- p_k1_p_k2_inner_product += P[u + k1 * m] * P[u + k2 * m];
3318
- PTP[k1 * d + k2] = p_k1_p_k2_inner_product;
3319
-
3320
- // Inner product of the k1 and k2 columns in Q, an n-by-d matrix.
3321
- mf_double q_k1_q_k2_inner_product = 0.0;
3322
- #if defined USEOMP
3323
- #pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:q_k1_q_k2_inner_product)
3324
- #endif
3325
- for(mf_int v = 0; v < n; ++v)
3326
- q_k1_q_k2_inner_product += Q[v + k1 * n] * Q[v + k2 * n];
3327
- QTQ[k1 * d + k2] = q_k1_q_k2_inner_product;
3328
- }
3329
- }
3330
-
3331
- // Initialize the loss value over positive matrix entries.
3332
- // It consists of two parts. The first part is the true prediction error,
3333
- // while the second part exists only to support the faster algorithm.
3334
- mf_double positive_loss1 = 0.0;
3335
- mf_double positive_loss2 = 0.0;
3336
- // Scan through positive matrix entries to compute their loss values.
3337
- // Notice that we assume that positive entries' values are all one.
3338
- #if defined USEOMP
3339
- #pragma omp parallel for num_threads(nr_threads) schedule(static) reduction(+:positive_loss1,positive_loss2)
3340
- #endif
3341
- for(mf_long i = 0; i < data->nnz; ++i)
3342
- {
3343
- const mf_double &r = data->R[i].r;
3344
- positive_loss1 += (1.0 - r) * (1.0 - r);
3345
- positive_loss2 -= alpha * (c - r) * (c - r);
3346
- }
3347
- positive_loss1 *= 0.5;
3348
- positive_loss2 *= 0.5;
3349
-
3350
- // Declare loss terms related to negative matrix entries.
3351
- mf_double negative_loss1 = c * c * m * n;
3352
- mf_double negative_loss2 = 0.0;
3353
- mf_double negative_loss3 = 0.0;
3354
- // Compute loss terms.
3355
- for(mf_int k1 = 0; k1 < d; ++k1)
3356
- {
3357
- negative_loss2 += all_p_sum[k1] * all_q_sum[k1];
3358
- for(mf_int k2 = 0; k2 < d; ++k2)
3359
- negative_loss3 += PTP[k1 + k2 * d] * QTQ[k2 + k1 * d];
3360
- }
3361
- // Compute the loss function of negative matrix entries.
3362
- mf_double negative_loss4 = 0.5 * alpha *
3363
- (negative_loss1 - 2 * c * negative_loss2 + negative_loss3);
3364
-
3365
- // Assign results to output variables.
3366
- reg = 0.5 * lambda_p2 * p_square_norm + 0.5 * lambda_q2 * q_square_norm;
3367
-
3368
- // The function minimized by coordinate descent method.
3369
- obj = positive_loss1 + positive_loss2 + negative_loss4 + reg;
3370
-
3371
- // Sum of squared errors over positive matrix entries (i.e., the mf_node's
3372
- // in data).
3373
- positive_loss = positive_loss1;
3374
-
3375
- // Sum of squared errors over negative matrix entries (i.e., entries not
3376
- // stored in data). The value negative_loss4 accumulates squared errors as
3377
- // if positive entries were also negative, so positive_loss2 is
3378
- // added to compensate for that.
3379
- negative_loss = negative_loss4 + positive_loss2;
3380
- }
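A hypothetical call sketch, only to show how the outputs pair up. The tr_csr problem and CCD-layout model are the ones prepared in ccd_one_class_core below; this exact call site is an illustration, not code from libmf:

```cpp
// Hypothetical usage (names follow the signature above).
mf_double obj = 0, pos_loss = 0, neg_loss = 0, reg = 0;
calc_ccd_one_class_obj(param.nr_threads, param.alpha, param.c,
                       model->m, model->n, param.k,
                       param.lambda_p2, param.lambda_q2,
                       model->P, model->Q, tr_csr,
                       obj, pos_loss, neg_loss, reg);
// The outputs satisfy obj == pos_loss + neg_loss + reg, where pos_loss
// and neg_loss are the squared errors over observed and unobserved
// entries respectively.
```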
3381
-
3382
- void ccd_one_class_core(
3383
- const Utility &util,
3384
- shared_ptr<const mf_problem> tr_csr,
3385
- shared_ptr<const mf_problem> tr_csc,
3386
- shared_ptr<const mf_problem> va,
3387
- const mf_parameter param,
3388
- const vector<mf_node*> &ptrs_u,
3389
- const vector<mf_node*> &ptrs_v,
3390
- /*output*/ shared_ptr<mf_model> &model)
3391
- {
3392
- // Check problems stored in CSR and CSC formats
3393
- if(tr_csr == nullptr) throw invalid_argument("CSR problem pointer is null.");
3394
- if(tr_csc == nullptr) throw invalid_argument("CSC problem pointer is null.");
3395
-
3396
- if(tr_csr->m != tr_csc->m)
3397
- throw logic_error(
3398
- "Row counts must be identical in CSR and CSC formats: " +
3399
- to_string(tr_csr->m) + " != " + to_string(tr_csc->m));
3400
- const mf_int m = tr_csr->m;
3401
-
3402
- if(tr_csr->n != tr_csc->n)
3403
- throw logic_error(
3404
- "Column counts must be identical in CSR and CSC formats: " +
3405
- to_string(tr_csr->n) + " != " + to_string(tr_csc->n));
3406
- const mf_int n = tr_csr->n;
3407
-
3408
- if(tr_csr->nnz != tr_csc->nnz)
3409
- throw logic_error(
3410
- "Numbers of data points must be identical in CSR and CSC formats: " +
3411
- to_string(tr_csr->nnz) + " != " + to_string(tr_csc->nnz));
3412
- const mf_long nnz = tr_csr->nnz;
3413
-
3414
- // Check formulation parameters
3415
- if(param.k <= 0)
3416
- throw invalid_argument(
3417
- "Latent dimension must be positive but got " +
3418
- to_string(param.k));
3419
- const mf_int d = param.k;
3420
-
3421
- if(param.lambda_p1 != 0)
3422
- throw invalid_argument(
3423
- "P's L1-regularization coefficient must be zero but got " +
3424
- to_string(param.lambda_p1));
3425
- if(param.lambda_q1 != 0)
3426
- throw invalid_argument(
3427
- "Q's L1-regularization coefficient must be zero but got " +
3428
- to_string(param.lambda_q1));
3429
-
3430
- if(param.lambda_p2 <= 0)
3431
- throw invalid_argument(
3432
- "P's L2-regularization coefficient must be positive but got " +
3433
- to_string(param.lambda_p2));
3434
- if(param.lambda_q2 <= 0)
3435
- throw invalid_argument(
3436
- "Q's L2-regularization coefficient must be positive but got " +
3437
- to_string(param.lambda_q2));
3438
-
3439
- // REVIEW: It is not difficult to support non-negative matrix factorization
3440
- // for the coordinate descent method; we just need to project the updated
3441
- // value back to the feasible region with max(0, new_value) right after each
3442
- // Newton step. LIBMF does not support it only because we have not seen
3443
- // actual demand from users.
3444
- if(param.do_nmf)
3445
- throw invalid_argument(
3446
- "Coordinate descent does not support non-negative constraint");
3447
-
3448
-
3449
- // Check some resources prepared internally
3450
- if(ptrs_u.size() != (size_t)m + 1)
3451
- throw invalid_argument("Number of row pointer must be " +
3452
- to_string(m + 1) + " but got " + to_string(ptrs_u.size()));
3453
- if(ptrs_v.size() != (size_t)n + 1)
3454
- throw invalid_argument("Number of column pointer must be " +
3455
- to_string(n + 1) + " but got " + to_string(ptrs_v.size()));
3456
-
3457
- // Some constants of the formulation.
3458
- // alpha: coefficient of negative part
3459
- // c: the desired prediction values of unobserved ratings
3460
- // lambda_p2: regularization coefficient of P's L2-norm
3461
- // lambda_q2: regularization coefficient of Q's L2-norm
3462
- const mf_float alpha = param.alpha;
3463
- const mf_float c = param.c;
3464
- const mf_float lambda_p2 = param.lambda_p2;
3465
- const mf_float lambda_q2 = param.lambda_q2;
3466
-
3467
- // Initialize P and Q. Note that \bar{q}_{kv} is Q[k*n+v]
3468
- // and \bar{p}_{ku} is P[k*m+u]. One may notice that P and
3469
- // Q here are actually the transposes of P and Q in FPSG.
3470
- mf_float *P = model->P;
3471
- mf_float *Q = model->Q;
3472
-
3473
- // Cache the prediction values on positive matrix entries.
3474
- // Given that P is zero-initialized and Q randomly initialized in
3475
- // Utility::init_model(mf_int m, mf_int n, mf_int k),
3476
- // all initial predictions are zero.
3477
- #if defined USEOMP
3478
- #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
3479
- #endif
3480
- for(mf_long i = 0; i < nnz; ++i)
3481
- {
3482
- tr_csr->R[i].r = 0.0;
3483
- tr_csc->R[i].r = 0.0;
3484
- }
3485
-
3486
- // If the model is not initialized by
3487
- // Utility::init_model(mf_int m, mf_int n, mf_int k),
3488
- // please use the following initialization code to compute
3489
- // and cache all prediction values on positive entries.
3490
- /*
3491
- for(mf_long i = 0; i < nnz; ++i)
3492
- {
3493
- mf_node &node = tr_csr->R[i];
3494
- node.r = 0;
3495
- for(mf_int k = 0; k < d; ++k)
3496
- node.r += P[node.u + k * m]*Q[node.v + k * n];
3497
- }
3498
- for(mf_long i = 0; i < nnz; ++i)
3499
- {
3500
- mf_node &node = tr_csc->R[i];
3501
- node.r = 0;
3502
- for(mf_int k = 0; k < d; ++k)
3503
- node.r += P[node.u + k * m]*Q[node.v + k * n];
3504
- }
3505
- */
3506
-
3507
- if(!param.quiet)
3508
- {
3509
- cout.width(4);
3510
- cout << "iter";
3511
- cout.width(13);
3512
- cout << "tr_"+util.get_error_legend();
3513
- cout.width(14);
3514
- cout << "tr_"+util.get_error_legend() << "+";
3515
- cout.width(14);
3516
- cout << "tr_"+util.get_error_legend() << "-";
3517
- if(va->nnz != 0)
3518
- {
3519
- cout.width(13);
3520
- cout << "va_"+util.get_error_legend();
3521
- cout.width(14);
3522
- cout << "va_"+util.get_error_legend() << "+";
3523
- cout.width(14);
3524
- cout << "va_"+util.get_error_legend() << "-";
3525
- }
3526
- cout.width(13);
3527
- cout << "obj";
3528
- cout << "\n";
3529
- }
3530
-
3531
- /////////////////////////////////////////////////////////////////
3532
- // Minimize the objective function via coordinate descent method
3533
- ////////////////////////////////////////////////////////////////
3534
- // Solve P and Q using coordinate descent.
3535
- // P = [\bar{p}_1, ..., \bar{p}_d] \in R^{m \times d}
3536
- // Q = [\bar{q}_1, ..., \bar{q}_d] \in R^{n \times d}
3537
- // Finally, the rating matrix R would be approximated via
3538
- // R ~ PQ^T \in R^{m \times n}
3539
- for(mf_int outer = 0; outer < param.nr_iters; ++outer)
3540
- {
3541
- // Update \bar{p}_k and \bar{q}_k. The basic idea is
3542
- // to replace \bar{p}_k and \bar{q}_k with a and b,
3543
- // and then minimize the original objective function.
3544
- for(mf_int k = 0; k < d; ++k)
3545
- {
3546
- // Get the pointer to the first element of \bar{p}_k (and
3547
- // \bar{q}_k).
3548
- mf_float *P_k = P + m * k;
3549
- mf_float *Q_k = Q + n * k;
3550
-
3551
- // Initialize a and b with the values they are replacing
3552
- // so that we can ensure improvement at each iteration.
3553
- vector<mf_float> a(P_k, P_k + m);
3554
- vector<mf_float> b(Q_k, Q_k + n);
3555
-
3556
- for(mf_int inner = 0; inner < 3; ++inner)
3557
- {
3558
- ///////////////////////////////////////////////////////////////
3559
- // Update a:
3560
- // 1. Compute and cache constants
3561
- // 2. For each coordinate of a, calculate optimal update using
3562
- // Newton method
3563
- ///////////////////////////////////////////////////////////////
3564
-
3565
- // Compute and cache constants
3566
- // \hat{b} = \sum_{v=1}^n \bar{b}_v
3567
- // \tilde{b} = \sum_{v=1}^n \bar{b}_v^2
3568
- mf_double b_hat = 0.0;
3569
- mf_double b_tilde = 0.0;
3570
- #if defined USEOMP
3571
- #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static) reduction(+:b_hat,b_tilde)
3572
- #endif
3573
- for(mf_int v = 0; v < n; ++v)
3574
- {
3575
- const mf_double &b_v = b[v];
3576
- b_hat += b_v;
3577
- b_tilde += b_v * b_v;
3578
- }
3579
-
3580
- // Compute and cache a constant vector
3581
- // s_k = \sum_{v=1}^n \bar{q}_{kv}b_v, k = 1, ..., d
3582
- vector<mf_double> s(d, 0.0);
3583
- for(mf_int k1 = 0; k1 < d; ++k1)
3584
- {
3585
- // Buffer variable for using OpenMP
3586
- mf_double s_k1 = 0;
3587
- const mf_float *Q_k1 = Q + k1 * n;
3588
- #if defined USEOMP
3589
- #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static) reduction(+:s_k1)
3590
- #endif
3591
- for(mf_int v = 0; v < n; ++v)
3592
- s_k1 += Q_k1[v] * b[v];
3593
- s[k1] = s_k1;
3594
- }
3595
-
3596
- // Solve a's sub-problem
3597
- #if defined USEOMP
3598
- #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
3599
- #endif
3600
- for(mf_int u = 0; u < m; ++u)
3601
- {
3602
- ////////////////////////////////////////////////////////
3603
- // Update a[u] via Newton method. Let g_u and h_u denote
3604
- // the first-order and second-order derivatives w.r.t.
3605
- // a[u]. The following code implements
3606
- // a[u] <-- a[u] - g_u/h_u
3607
- ////////////////////////////////////////////////////////
3608
-
3609
- // Initialize temporary variables for calculating the gradient and Hessian.
3610
- mf_double g_u_1 = 0.0;
3611
- mf_double h_u_1 = 0.0;
3612
- mf_double g_u_2 = 0.0;
3613
- // Scan through the observed entries in the u-th row
3614
- for(const mf_node *ptr = ptrs_u[u]; ptr != ptrs_u[u+1]; ++ptr)
3615
- {
3616
- const mf_int &v = ptr->v;
3617
- const mf_float &b_v = b[v];
3618
- g_u_1 += b_v;
3619
- h_u_1 += b_v * b_v;
3620
- g_u_2 += (ptr->r - P_k[u] * Q_k[v] + a[u] * b_v) * b_v;
3621
- }
3622
- mf_double g_u_3 = -c * b_hat - P_k[u] * s[k] + a[u] * b_tilde;
3623
- for(mf_int k1 = 0; k1 < d; ++k1)
3624
- g_u_3 += P[m * k1 + u] * s[k1];
3625
- mf_double g_u = -(1.0 - alpha * c) * g_u_1 + (1.0 - alpha) * g_u_2 + alpha * g_u_3 + lambda_p2 * a[u];
3626
- mf_double h_u = (1.0 - alpha) * h_u_1 + alpha * b_tilde + lambda_p2;
3627
- a[u] -= static_cast<mf_float>(g_u / h_u);
3628
- }
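
Spelled out, the update applied in the loop above is the one-dimensional Newton step a_u <- a_u - g_u/h_u with (a reading of the code, not a formula stated in the source; here \hat{r}^{(k)}_{uv} = \hat{r}_{uv} - \bar{p}_{ku}\bar{q}_{kv} + a_u b_v is the prediction with the k-th term replaced by the trial values, and \Omega_u is the set of observed columns in row u):

    g_u = -(1-\alpha c)\sum_{v\in\Omega_u} b_v
          + (1-\alpha)\sum_{v\in\Omega_u} \hat{r}^{(k)}_{uv}\, b_v
          + \alpha\Big(\sum_{v=1}^{n} \hat{r}^{(k)}_{uv}\, b_v - c\,\hat{b}\Big)
          + \lambda_{p2}\, a_u

    h_u = (1-\alpha)\sum_{v\in\Omega_u} b_v^{2} + \alpha\,\tilde{b} + \lambda_{p2}

The update for b below is symmetric, with the roles of a and b (and the cached row/column constants) exchanged.
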
3629
-
3630
- ///////////////////////////////////////////////////////////////
3631
- // Update b:
3632
- // 1. Compute and cache constants
3633
- // 2. For each coordinate of b, calculate optimal update using
3634
- // Newton method
3635
- ///////////////////////////////////////////////////////////////
3636
- // Compute and cache a_hat, a_tilde
3637
- // \hat{a} = \sum_{u=1}^m \bar{a}_u
3638
- // \tilde{a} = \sum_{u=1}^m \bar{a}_u^2
3639
- mf_double a_hat = 0.0;
3640
- mf_double a_tilde = 0.0;
3641
- #if defined USEOMP
3642
- #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static) reduction(+:a_hat,a_tilde)
3643
- #endif
3644
- for(mf_int u = 0; u < m; ++u)
3645
- {
3646
- const mf_float &a_u = a[u];
3647
- a_hat += a_u;
3648
- a_tilde += a_u * a_u;
3649
- }
3650
-
3651
- // Compute and cache t
3652
- // t_k = \sum_{u=1}^m \bar{p}_{ku}a_u, k = 1, ..., d
3653
- vector<mf_double> t(d, 0.0);
3654
- for(mf_int k1 = 0; k1 < d; ++k1)
3655
- {
3656
- // Declare buffer variable for using OpenMP
3657
- mf_double t_k1 = 0;
3658
- const mf_float *P_k1 = P + k1 * m;
3659
- #if defined USEOMP
3660
- #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static) reduction(+:t_k1)
3661
- #endif
3662
- for(mf_int u = 0; u < m; ++u)
3663
- t_k1 += P_k1[u] * a[u];
3664
- t[k1] = t_k1;
3665
- }
3666
-
3667
- #if defined USEOMP
3668
- #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
3669
- #endif
3670
- for(mf_int v = 0; v < n; ++v)
3671
- {
3672
- ////////////////////////////////////////////////////////
3673
- // Update b[v] via Newton method. Let g_v and h_v denote
3674
- // the first-order and second-order derivatives w.r.t.
3675
- // b[v]. The following code implements
3676
- // b[v] <-- b[v] - g_v/h_v
3677
- ////////////////////////////////////////////////////////
3678
-
3679
- // Initialize temporary variables for calculating the gradient and Hessian.
3680
- mf_double g_v_1 = 0;
3681
- mf_double g_v_2 = 0;
3682
- mf_double h_v_1 = 0;
3683
- // Scan through all positive entries in column v
3684
- for(const mf_node *ptr = ptrs_v[v]; ptr != ptrs_v[v+1]; ++ptr)
3685
- {
3686
- const mf_int &u = ptr->u;
3687
- const mf_float &a_u = a[u];
3688
- g_v_1 += a_u;
3689
- h_v_1 += a_u * a_u;
3690
- g_v_2 += (ptr->r - P_k[u] * Q_k[v] + a_u * b[v]) * a_u;
3691
- }
3692
- mf_double g_v_3 = -c * a_hat - Q_k[v] * t[k] + b[v] * a_tilde;
3693
- for(mf_int k1 = 0; k1 < d; ++k1)
3694
- g_v_3 += Q[n * k1 + v] * t[k1];
3695
- mf_double g_v = -(1.0 - alpha * c) * g_v_1 + (1.0 - alpha) * g_v_2 +
3696
- alpha * g_v_3 + lambda_q2 * b[v];
3697
- mf_double h_v = (1 - alpha) * h_v_1 + alpha * a_tilde + lambda_q2;
3698
- b[v] -= static_cast<mf_float>(g_v / h_v);
3699
- }
3700
-
3701
- ///////////////////////////////////////////////////////////////
3702
- // Update cached variables.
3703
- ///////////////////////////////////////////////////////////////
3704
- // Update the cached predictions in both CSR and CSC formats
3705
- // \bar{r}_{uv} <- \bar{r}_{uv} - \bar{p}_{ku}*\bar{q}_{kv} + a_u*b_v
3706
- #if defined USEOMP
3707
- #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
3708
- #endif
3709
- for(mf_long i = 0; i < tr_csr->nnz; ++i)
3710
- {
3711
- // Update prediction values of positive entries in CSR
3712
- mf_node *csr_ptr = tr_csr->R + i;
3713
- const mf_int &u_csr = csr_ptr->u;
3714
- const mf_int &v_csr = csr_ptr->v;
3715
- csr_ptr->r += a[u_csr] * b[v_csr] - P_k[u_csr] * Q_k[v_csr];
3716
-
3717
- // Update prediction values of positive entries in CSC
3718
- mf_node *csc_ptr = tr_csc->R + i;
3719
- const mf_int &u_csc = csc_ptr->u;
3720
- const mf_int &v_csc = csc_ptr->v;
3721
- csc_ptr->r += a[u_csc] * b[v_csc] - P_k[u_csc] * Q_k[v_csc];
3722
- }
3723
-
3724
- #if defined USEOMP
3725
- #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
3726
- #endif
3727
- // Update P_k and Q_k
3728
- for(mf_int u = 0; u < m; ++u)
3729
- P_k[u] = a[u];
3730
- #if defined USEOMP
3731
- #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
3732
- #endif
3733
- for(mf_int v = 0; v < n; ++v)
3734
- Q_k[v] = b[v];
3735
- }
3736
- }
3737
-
3738
- // Skip the whole evaluation if nothing should be printed out.
3739
- if(param.quiet)
3740
- continue;
3741
-
3742
- // Declare a variable for storing the objective value minimized
3743
- // by the training procedure. Note that the objective value consists
3744
- // of two parts: the loss function and the regularization term.
3745
- mf_double obj = 0;
3746
- // Declare variables for storing loss function's value.
3747
- mf_double positive_loss = 0; // for positive entries in training matrix.
3748
- mf_double negative_loss = 0; // for negative entries in training matrix.
3749
- // Declare variable for storing regularization function's value.
3750
- mf_double reg = 0;
3751
-
3752
- // Compute objective value, loss function value, and regularization
3753
- // function value
3754
- calc_ccd_one_class_obj(util.get_thread_number(), alpha, c, m, n, d,
3755
- lambda_p2, lambda_q2, P, Q, tr_csr,
3756
- obj, positive_loss, negative_loss, reg);
3757
-
3758
- // Print the index of the current outer iteration.
3759
- cout.width(4);
3760
- cout << outer;
3761
- cout.width(13);
3762
- cout << fixed << setprecision(4) << positive_loss + negative_loss;
3763
- cout.width(15);
3764
- cout << fixed << setprecision(4) << positive_loss;
3765
- cout.width(15);
3766
- cout << fixed << setprecision(4) << negative_loss;
3767
-
3768
- if(va->nnz != 0)
3769
- {
3770
- // The following loop computes prediction scores on the validation set.
3771
- // Because training scores are already maintained by the coordinate
3772
- // descent framework, we do not need to recompute them on the training set.
3773
- #if defined USEOMP
3774
- #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
3775
- #endif
3776
- for(mf_long i = 0; i < va->nnz; ++i)
3777
- {
3778
- mf_node &node = va->R[i];
3779
- node.r = 0;
3780
- for(mf_int k = 0; k < d; ++k)
3781
- node.r += P[node.u + k * m]*Q[node.v + k * n];
3782
- }
3783
-
3784
- mf_double va_obj = 0;
3785
- mf_double va_positive_loss = 0;
3786
- mf_double va_negative_loss = 0;
3787
- mf_double va_reg = 0;
3788
-
3789
- calc_ccd_one_class_obj(util.get_thread_number(), alpha, c, m, n, d,
3790
- lambda_p2, lambda_q2, P, Q, va,
3791
- va_obj, va_positive_loss, va_negative_loss, va_reg);
3792
-
3793
- cout.width(13);
3794
- cout << fixed << setprecision(4) << va_positive_loss + va_negative_loss;
3795
- cout.width(15);
3796
- cout << fixed << setprecision(4) << va_positive_loss;
3797
- cout.width(15);
3798
- cout << fixed << setprecision(4) << va_negative_loss;
3799
- }
3800
-
3801
- cout.width(13);
3802
- cout << scientific << setprecision(4) << obj;
3803
- cout << "\n" << flush;
3804
- }
3805
-
3806
- // Transpose P and Q. Note that the layout of P and Q here differs
3807
- // from the one used by mf_model.
3808
-
3809
- mf_float *P_transpose = Utility::malloc_aligned_float((mf_long)m * d);
3810
- #if defined USEOMP
3811
- #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
3812
- #endif
3813
- for(mf_int u = 0; u < m; ++u)
3814
- for(mf_int k = 0; k < d; ++k)
3815
- P_transpose[k + u * d] = P[u + k * m];
3816
- Utility::free_aligned_float(P);
3817
- mf_float *Q_transpose = Utility::malloc_aligned_float((mf_long)n * d);
3818
- #if defined USEOMP
3819
- #pragma omp parallel for num_threads(util.get_thread_number()) schedule(static)
3820
- #endif
3821
- for(mf_int v = 0; v < n; ++v)
3822
- for(mf_int k = 0; k < d; ++k)
3823
- Q_transpose[k + v * d] = Q[v + k * n];
3824
- Utility::free_aligned_float(Q);
3825
-
3826
- // Set the passed-in model to the result learned from the given data,
3828
- // using the transposed factor matrices computed above.
3828
- model->m = m;
3829
- model->n = n;
3830
- model->k = d;
3831
- model->b = 0.0;
3832
- model->P = P_transpose;
3833
- model->Q = Q_transpose;
3834
- }
3835
-
3836
- shared_ptr<mf_model> ccd_one_class(
3837
- mf_problem const *tr_,
3838
- mf_problem const *va_,
3839
- mf_parameter param)
3840
- {
3841
- shared_ptr<mf_model> model;
3842
- try
3843
- {
3844
- Utility util(param.fun, param.nr_threads);
3845
- // Training matrix in compressed row format (sort nodes by user id)
3846
- shared_ptr<mf_problem> tr_csr;
3847
- // Training matrix in compressed column format (sort nodes by item id)
3848
- shared_ptr<mf_problem> tr_csc;
3849
- shared_ptr<mf_problem> va;
3850
- // In tr_csr->R, the i-th row starts at ptrs_u[i] and ends right before ptrs_u[i+1]
3851
- vector<mf_node*> ptrs_u(tr_->m + 1, nullptr);
3852
- // In tr_csc->R, the i-th column starts at ptrs_v[i] and ends right before ptrs_v[i+1]
3853
- vector<mf_node*> ptrs_v(tr_->n + 1, nullptr);
3854
-
3855
- if(param.copy_data)
3856
- {
3857
- // Need a row-major and a column-major copy of the training data,
3858
- // so two duplicates are made.
3859
- tr_csr = shared_ptr<mf_problem>(
3860
- Utility::copy_problem(tr_, true), deleter());
3861
- tr_csc = shared_ptr<mf_problem>(
3862
- Utility::copy_problem(tr_, true), deleter());
3863
- va = shared_ptr<mf_problem>(
3864
- Utility::copy_problem(va_, true), deleter());
3865
- }
3866
- else
3867
- {
3868
- // Need a row-major and a column-major copy of the training data.
3869
- // The original data is reused as the row-major copy, so only one
3870
- // duplicate, the column-major copy, is created.
3871
- tr_csr = shared_ptr<mf_problem>(Utility::copy_problem(tr_, false));
3872
- tr_csc = shared_ptr<mf_problem>(Utility::copy_problem(tr_, true));
3873
- va = shared_ptr<mf_problem>(Utility::copy_problem(va_, false));
3874
- }
3875
-
3876
- // Make the training set CSR/CSC by sorting its nodes. More specifically,
3877
- // a matrix whose entries are sorted by row index acts as CSR, and one
3878
- // sorted by column index acts as CSC. We will compute the starting
3879
- // location for each row (CSR) and each column (CSC) later.
3880
- sort(tr_csr->R, tr_csr->R+tr_csr->nnz, sort_node_by_p());
3881
- sort(tr_csc->R, tr_csc->R+tr_csc->nnz, sort_node_by_q());
3882
-
3883
- // Save starting addresses of rows for CSR and columns for CSC.
3884
- mf_int u_current = -1;
3885
- mf_int v_current = -1;
3886
- for(mf_long i = 0; i < tr_->nnz; ++i)
3887
- {
3888
- mf_node* N = nullptr;
3889
-
3890
- // Deal with CSR format.
3891
- N = tr_csr->R + i;
3892
- // Since tr_csr has been sorted by index u, seeing a larger index
3893
- // implies a new row. Assume a node is encoded as a tuple (u, v, r),
3894
- // where u is row index, v is column index, and r is entry value.
3895
- // The nodes in tr_csr->R could be
3896
- // (0, 1, 0.5), (0, 2, 3.7), (0, 4, -1.2), (2, 0, 1.2), (2, 4, 2.5)
3897
- // Then, we can see the first element of the 3rd row (indexed by 2)
3898
- // is (2, 0, 1.2), which is the 4th element in tr_csr->R. Note that
3899
- // we use the row pointer of the next non-empty row as the pointers
3900
- // of empty rows. That is,
3901
- // ptrs_u[0] = pointer of (0, 1, 0.5)
3902
- // ptrs_u[1] = pointer of (2, 0, 1.2)
3903
- // ptrs_u[2] = pointer of (2, 0, 1.2)
3904
- if(N->u > u_current)
3905
- {
3906
- // We (if u_current != -1) have assigned starting addresses to rows
3907
- // indexed by values smaller than or equal to u_current. Thus, we
3908
- // should handle all rows indexed starting from u_current+1 to the
3909
- // seen row index N->u.
3910
- for(mf_int u_passed = u_current + 1; u_passed <= N->u; ++u_passed)
3911
- {
3912
- // i-th non-zero value's location in tr_csr is the starting
3913
- // address of u_passed-th row.
3914
- ptrs_u[u_passed] = tr_csr->R + i;
3915
- }
3916
- u_current = N->u;
3917
- }
3918
-
3919
- // Deal with CSC format
3920
- N = tr_csc->R + i;
3921
- if(N->v > v_current)
3922
- {
3923
- // We (if v_current != -1) have assigned starting addresses to columns
3924
- // indexed by values smaller than or equal to v_current. Thus, we
3925
- // should handle all columns indexed starting from v_current+1 to
3926
- // the seen column index N->v.
3927
- for(mf_int v_passed = v_current + 1; v_passed <= N->v; ++v_passed)
3928
- {
3929
- // i-th non-zero value's location in tr_csc is the starting
3930
- // address of v_passed-th column.
3931
- ptrs_v[v_passed] = tr_csc->R + i;
3932
- }
3933
- v_current = N->v;
3934
- }
3935
-
3936
- }
3937
- // The bound of the last row. It's the address one element past the last
3938
- // matrix entry.
3939
- for(mf_int u_passed = u_current + 1; u_passed <= tr_->m; ++u_passed)
3940
- ptrs_u[u_passed] = tr_csr->R + tr_csr->nnz;
3941
- // The bound of the last column.
3942
- for(mf_int v_passed = v_current + 1; v_passed <= tr_->n; ++v_passed)
3943
- ptrs_v[v_passed] = tr_csc->R + tr_csc->nnz;
3944
-
3945
-
3946
- model = shared_ptr<mf_model>(Utility::init_model(tr_->m, tr_->n, param.k),
3947
- [] (mf_model *ptr) { mf_destroy_model(&ptr); });
3948
-
3949
- ccd_one_class_core(util, tr_csr, tr_csc, va, param, ptrs_u, ptrs_v, model);
3950
- }
3951
- catch(exception const &e)
3952
- {
3953
- cerr << e.what() << endl;
3954
- throw;
3955
- }
3956
- return model;
3957
- }
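
The pointer construction above can be exercised in isolation. Below is a minimal standalone sketch (not part of libmf): it uses a local Node struct in place of mf_node and the five example entries from the comment above, and prints each row through its [ptrs[u], ptrs[u+1]) range.

    // Standalone illustration of the row-pointer construction used above:
    // given COO entries sorted by row index, ptrs[u] is the first entry of
    // row u and ptrs[m] is one past the end; empty rows borrow the pointer
    // of the next non-empty row.
    #include <cstdio>
    #include <vector>

    struct Node { int u, v; float r; };

    int main()
    {
        // The example rows from the comment above: rows 0 and 2 are
        // non-empty, row 1 is empty.
        std::vector<Node> R = {
            {0, 1, 0.5f}, {0, 2, 3.7f}, {0, 4, -1.2f},
            {2, 0, 1.2f}, {2, 4, 2.5f}
        };
        const int m = 3;
        std::vector<const Node*> ptrs(m + 1, nullptr);

        int u_current = -1;
        for(size_t i = 0; i < R.size(); ++i)
            if(R[i].u > u_current)
            {
                for(int u = u_current + 1; u <= R[i].u; ++u)
                    ptrs[u] = R.data() + i;     // start of row u
                u_current = R[i].u;
            }
        for(int u = u_current + 1; u <= m; ++u)
            ptrs[u] = R.data() + R.size();      // one-past-the-end bound

        // Row u spans [ptrs[u], ptrs[u+1]); the empty row 1 prints nothing.
        for(int u = 0; u < m; ++u)
            for(const Node *p = ptrs[u]; p != ptrs[u + 1]; ++p)
                std::printf("row %d: (%d, %d, %g)\n", u, p->u, p->v, p->r);
        return 0;
    }
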
3958
-
3959
- bool check_parameter(mf_parameter param)
3960
- {
3961
- if(param.fun != P_L2_MFR &&
3962
- param.fun != P_L1_MFR &&
3963
- param.fun != P_KL_MFR &&
3964
- param.fun != P_LR_MFC &&
3965
- param.fun != P_L2_MFC &&
3966
- param.fun != P_L1_MFC &&
3967
- param.fun != P_ROW_BPR_MFOC &&
3968
- param.fun != P_COL_BPR_MFOC &&
3969
- param.fun != P_L2_MFOC)
3970
- {
3971
- cerr << "unknown loss function" << endl;
3972
- return false;
3973
- }
3974
-
3975
- if(param.k < 1)
3976
- {
3977
- cerr << "number of factors must be greater than zero" << endl;
3978
- return false;
3979
- }
3980
-
3981
- if(param.nr_threads < 1)
3982
- {
3983
- cerr << "number of threads must be greater than zero" << endl;
3984
- return false;
3985
- }
3986
-
3987
- if(param.nr_bins < 1 || param.nr_bins < param.nr_threads)
3988
- {
3989
- cerr << "number of bins must be greater than number of threads"
3990
- << endl;
3991
- return false;
3992
- }
3993
-
3994
- if(param.nr_iters < 1)
3995
- {
3996
- cerr << "number of iterations must be greater than zero" << endl;
3997
- return false;
3998
- }
3999
-
4000
- if(param.lambda_p1 < 0 ||
4001
- param.lambda_p2 < 0 ||
4002
- param.lambda_q1 < 0 ||
4003
- param.lambda_q2 < 0)
4004
- {
4005
- cerr << "regularization coefficient must be non-negative" << endl;
4006
- return false;
4007
- }
4008
-
4009
- if(param.eta <= 0)
4010
- {
4011
- cerr << "learning rate must be greater than zero" << endl;
4012
- return false;
4013
- }
4014
-
4015
- if(param.fun == P_KL_MFR && !param.do_nmf)
4016
- {
4017
- cerr << "--nmf must be set when using generalized KL-divergence"
4018
- << endl;
4019
- return false;
4020
- }
4021
-
4022
- if(param.nr_bins <= 2*param.nr_threads)
4023
- {
4024
- cerr << "Warning: insufficient blocks may slow down the training"
4025
- << "process (4*nr_threads^2+1 blocks is suggested)" << endl;
4026
- }
4033
-
4034
- if(param.alpha < 0)
4035
- {
4036
- cerr << "alpha must be a non-negative number" << endl;
4037
- }
4038
-
4039
- return true;
4040
- }
4041
-
4042
- //--------------------------------------
4043
- //-----Classes for cross validation-----
4044
- //--------------------------------------
4045
-
4046
- class CrossValidatorBase
4047
- {
4048
- public:
4049
- CrossValidatorBase(mf_parameter param_, mf_int nr_folds_);
4050
- mf_double do_cross_validation();
4051
- virtual mf_double do_cv1(vector<mf_int> &hidden_blocks) = 0;
4052
- protected:
4053
- mf_parameter param;
4054
- mf_int nr_bins;
4055
- mf_int nr_folds;
4056
- mf_int nr_blocks_per_fold;
4057
- bool quiet;
4058
- Utility util;
4059
- mf_double cv_error;
4060
- };
4061
-
4062
- CrossValidatorBase::CrossValidatorBase(mf_parameter param_, mf_int nr_folds_)
4063
- : param(param_), nr_bins(param_.nr_bins), nr_folds(nr_folds_),
4064
- nr_blocks_per_fold(nr_bins*nr_bins/nr_folds), quiet(param_.quiet),
4065
- util(param.fun, param.nr_threads), cv_error(0)
4066
- {
4067
- param.quiet = true;
4068
- }
4069
-
4070
- mf_double CrossValidatorBase::do_cross_validation()
4071
- {
4072
- vector<mf_int> cv_blocks;
4073
- srand(0);
4074
- for(mf_int block = 0; block < nr_bins*nr_bins; ++block)
4075
- cv_blocks.push_back(block);
4076
- shuffle(cv_blocks.begin(), cv_blocks.end(), default_random_engine(0));
4077
-
4078
- if(!quiet)
4079
- {
4080
- cout.width(4);
4081
- cout << "fold";
4082
- cout.width(10);
4083
- cout << util.get_error_legend();
4084
- cout << endl;
4085
- }
4086
-
4087
- cv_error = 0;
4088
-
4089
- for(mf_int fold = 0; fold < nr_folds; ++fold)
4090
- {
4091
- mf_int begin = fold*nr_blocks_per_fold;
4092
- mf_int end = min((fold+1)*nr_blocks_per_fold, nr_bins*nr_bins);
4093
- vector<mf_int> hidden_blocks(cv_blocks.begin()+begin,
4094
- cv_blocks.begin()+end);
4095
-
4096
- mf_double err = do_cv1(hidden_blocks);
4097
- cv_error += err;
4098
-
4099
- if(!quiet)
4100
- {
4101
- cout.width(4);
4102
- cout << fold;
4103
- cout.width(10);
4104
- cout << fixed << setprecision(4) << err;
4105
- cout << endl;
4106
- }
4107
- }
4108
-
4109
- if(!quiet)
4110
- {
4111
- cout.width(14);
4112
- cout.fill('=');
4113
- cout << "" << endl;
4114
- cout.fill(' ');
4115
- cout.width(4);
4116
- cout << "avg";
4117
- cout.width(10);
4118
- cout << fixed << setprecision(4) << cv_error/nr_folds;
4119
- cout << endl;
4120
- }
4121
-
4122
- return cv_error/nr_folds;
4123
- }
4124
-
4125
- class CrossValidator : public CrossValidatorBase
4126
- {
4127
- public:
4128
- CrossValidator(
4129
- mf_parameter param_, mf_int nr_folds_, mf_problem const *prob_)
4130
- : CrossValidatorBase(param_, nr_folds_), prob(prob_) {};
4131
- mf_double do_cv1(vector<mf_int> &hidden_blocks);
4132
- private:
4133
- mf_problem const *prob;
4134
- };
4135
-
4136
- mf_double CrossValidator::do_cv1(vector<mf_int> &hidden_blocks)
4137
- {
4138
- mf_double err = 0;
4139
- fpsg(prob, nullptr, param, hidden_blocks, &err);
4140
- return err;
4141
- }
4142
-
4143
- class CrossValidatorOnDisk : public CrossValidatorBase
4144
- {
4145
- public:
4146
- CrossValidatorOnDisk(
4147
- mf_parameter param_, mf_int nr_folds_, string data_path_)
4148
- : CrossValidatorBase(param_, nr_folds_), data_path(data_path_) {};
4149
- mf_double do_cv1(vector<mf_int> &hidden_blocks);
4150
- private:
4151
- string data_path;
4152
- };
4153
-
4154
- mf_double CrossValidatorOnDisk::do_cv1(vector<mf_int> &hidden_blocks)
4155
- {
4156
- mf_double err = 0;
4157
- fpsg_on_disk(data_path, string(), param, hidden_blocks, &err);
4158
- return err;
4159
- }
4160
-
4161
- } // unnamed namespace
4162
-
4163
- mf_model* mf_train_with_validation(
4164
- mf_problem const *tr,
4165
- mf_problem const *va,
4166
- mf_parameter param)
4167
- {
4168
- if(!check_parameter(param))
4169
- return nullptr;
4170
-
4171
- shared_ptr<mf_model> model(nullptr);
4172
-
4173
- if(param.fun != P_L2_MFOC)
4174
- // Use stochastic gradient method
4175
- model = fpsg(tr, va, param);
4176
- else
4177
- // Use coordinate descent method
4178
- model = ccd_one_class(tr, va, param);
4179
-
4180
- mf_model *model_ret = new mf_model;
4181
-
4182
- model_ret->fun = model->fun;
4183
- model_ret->m = model->m;
4184
- model_ret->n = model->n;
4185
- model_ret->k = model->k;
4186
- model_ret->b = model->b;
4187
-
4188
- model_ret->P = model->P;
4189
- model->P = nullptr;
4190
-
4191
- model_ret->Q = model->Q;
4192
- model->Q = nullptr;
4193
-
4194
- return model_ret;
4195
- }
4196
-
4197
- mf_model* mf_train_with_validation_on_disk(
4198
- char const *tr_path,
4199
- char const *va_path,
4200
- mf_parameter param)
4201
- {
4202
- // Two conditions lead to an empty model. First, some parameter is outside
4203
- // its supported range. Second, one-class matrix factorization with L2-loss
4204
- // (-f 12) doesn't support disk-level training.
4205
- if(!check_parameter(param) || param.fun == P_L2_MFOC)
4206
- return nullptr;
4207
-
4208
- shared_ptr<mf_model> model = fpsg_on_disk(
4209
- string(tr_path), string(va_path), param);
4210
-
4211
- mf_model *model_ret = new mf_model;
4212
-
4213
- model_ret->fun = model->fun;
4214
- model_ret->m = model->m;
4215
- model_ret->n = model->n;
4216
- model_ret->k = model->k;
4217
- model_ret->b = model->b;
4218
-
4219
- model_ret->P = model->P;
4220
- model->P = nullptr;
4221
-
4222
- model_ret->Q = model->Q;
4223
- model->Q = nullptr;
4224
-
4225
- return model_ret;
4226
- }
4227
-
4228
- mf_model* mf_train(mf_problem const *prob, mf_parameter param)
4229
- {
4230
- return mf_train_with_validation(prob, nullptr, param);
4231
- }
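
A minimal usage sketch of the API above (illustrative, not from the source; "train.txt" is a hypothetical path holding whitespace-separated "u v r" triples, and read_problem is the helper defined later in this file):

    #include <iostream>
    #include "mf.h"

    int main()
    {
        using namespace mf;

        mf_problem prob = read_problem("train.txt"); // defined later in this file
        if(prob.nnz == 0)
        {
            std::cerr << "no data read from train.txt" << std::endl;
            return 1;
        }

        mf_parameter param = mf_get_default_param();
        param.k = 16;        // latent dimension
        param.nr_iters = 30; // outer iterations

        mf_model *model = mf_train(&prob, param); // nullptr on bad parameters
        if(model != nullptr)
        {
            std::cout << "r(0,0) ~ " << mf_predict(model, 0, 0) << std::endl;
            mf_destroy_model(&model);
        }

        delete[] prob.R; // read_problem allocates R with new[]
        return 0;
    }
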
4232
-
4233
- mf_model* mf_train_on_disk(char const *tr_path, mf_parameter param)
4234
- {
4235
- return mf_train_with_validation_on_disk(tr_path, "", param);
4236
- }
4237
-
4238
- mf_double mf_cross_validation(
4239
- mf_problem const *prob,
4240
- mf_int nr_folds,
4241
- mf_parameter param)
4242
- {
4243
- // Two conditions lead to a zero return value. First, some parameter is
4244
- // outside its supported range. Second, one-class matrix factorization
4245
- // with L2-loss (-f 12) is not supported by cross validation.
4246
- if(!check_parameter(param) || param.fun == P_L2_MFOC)
4247
- return 0;
4248
-
4249
- CrossValidator validator(param, nr_folds, prob);
4250
-
4251
- return validator.do_cross_validation();
4252
- }
4253
-
4254
- mf_double mf_cross_validation_on_disk(
4255
- char const *prob,
4256
- mf_int nr_folds,
4257
- mf_parameter param)
4258
- {
4259
- // Two conditions lead to a zero return value. First, some parameter is
4260
- // outside its supported range. Second, one-class matrix factorization
4261
- // with L2-loss (-f 12) doesn't support disk-level training.
4262
- if(!check_parameter(param) || param.fun == P_L2_MFOC)
4263
- return 0;
4264
-
4265
- CrossValidatorOnDisk validator(param, nr_folds, string(prob));
4266
-
4267
- return validator.do_cross_validation();
4268
- }
4269
-
4270
- mf_problem read_problem(string path)
4271
- {
4272
- mf_problem prob;
4273
- prob.m = 0;
4274
- prob.n = 0;
4275
- prob.nnz = 0;
4276
- prob.R = nullptr;
4277
-
4278
- if(path.empty())
4279
- return prob;
4280
-
4281
- ifstream f(path);
4282
- if(!f.is_open())
4283
- return prob;
4284
-
4285
- string line;
4286
- while(getline(f, line))
4287
- prob.nnz += 1;
4288
-
4289
- mf_node *R = new mf_node[static_cast<size_t>(prob.nnz)];
4290
-
4291
- f.close();
4292
- f.open(path);
4293
-
4294
- mf_long idx = 0;
4295
- for(mf_node N; f >> N.u >> N.v >> N.r;)
4296
- {
4297
- if(N.u+1 > prob.m)
4298
- prob.m = N.u+1;
4299
- if(N.v+1 > prob.n)
4300
- prob.n = N.v+1;
4301
- R[idx] = N;
4302
- ++idx;
4303
- }
4304
- prob.R = R;
4305
-
4306
- f.close();
4307
-
4308
- return prob;
4309
- }
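
read_problem expects whitespace-separated (u, v, r) triples, one observed entry per line, with zero-based indices. For example, a 2x3 matrix with three observed entries could be stored as:

    0 0 4.0
    0 2 3.5
    1 1 1.0

The dimensions m and n are inferred as one plus the largest row and column index seen, as the loop above shows.
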
4310
-
4311
- mf_int mf_save_model(mf_model const *model, char const *path)
4312
- {
4313
- ofstream f(path);
4314
- if(!f.is_open())
4315
- return 1;
4316
-
4317
- f << "f " << model->fun << endl;
4318
- f << "m " << model->m << endl;
4319
- f << "n " << model->n << endl;
4320
- f << "k " << model->k << endl;
4321
- f << "b " << model->b << endl;
4322
-
4323
- auto write = [&] (mf_float *ptr, mf_int size, char prefix)
4324
- {
4325
- for(mf_int i = 0; i < size; ++i)
4326
- {
4327
- mf_float *ptr1 = ptr + (mf_long)i*model->k;
4328
- f << prefix << i << " ";
4329
- if(isnan(ptr1[0]))
4330
- {
4331
- f << "F ";
4332
- for(mf_int d = 0; d < model->k; ++d)
4333
- f << 0 << " ";
4334
- }
4335
- else
4336
- {
4337
- f << "T ";
4338
- for(mf_int d = 0; d < model->k; ++d)
4339
- f << ptr1[d] << " ";
4340
- }
4341
- f << endl;
4342
- }
4343
-
4344
- };
4345
-
4346
- write(model->P, model->m, 'p');
4347
- write(model->Q, model->n, 'q');
4348
-
4349
- f.close();
4350
-
4351
- return 0;
4352
- }
4353
-
4354
- mf_model* mf_load_model(char const *path)
4355
- {
4356
- ifstream f(path);
4357
- if(!f.is_open())
4358
- return nullptr;
4359
-
4360
- string dummy;
4361
-
4362
- mf_model *model = new mf_model;
4363
- model->P = nullptr;
4364
- model->Q = nullptr;
4365
-
4366
- f >> dummy >> model->fun >> dummy >> model->m >> dummy >> model->n >>
4367
- dummy >> model->k >> dummy >> model->b;
4368
-
4369
- try
4370
- {
4371
- model->P = Utility::malloc_aligned_float((mf_long)model->m*model->k);
4372
- model->Q = Utility::malloc_aligned_float((mf_long)model->n*model->k);
4373
- }
4374
- catch(bad_alloc const &e)
4375
- {
4376
- cerr << e.what() << endl;
4377
- mf_destroy_model(&model);
4378
- return nullptr;
4379
- }
4380
-
4381
- auto read = [&] (mf_float *ptr, mf_int size)
4382
- {
4383
- for(mf_int i = 0; i < size; ++i)
4384
- {
4385
- mf_float *ptr1 = ptr + (mf_long)i*model->k;
4386
- f >> dummy >> dummy;
4387
- if(dummy.compare("F") == 0) // nan vector starts with "F"
4388
- for(mf_int d = 0; d < model->k; ++d)
4389
- {
4390
- f >> dummy;
4391
- ptr1[d] = numeric_limits<mf_float>::quiet_NaN();
4392
- }
4393
- else
4394
- for(mf_int d = 0; d < model->k; ++d)
4395
- f >> ptr1[d];
4396
- }
4397
- };
4398
-
4399
- read(model->P, model->m);
4400
- read(model->Q, model->n);
4401
-
4402
- f.close();
4403
-
4404
- return model;
4405
- }
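
A round-trip sketch for the text model format implemented above (illustrative; "model.txt" is a hypothetical path and "model" is assumed to come from mf_train as sketched earlier):

    #include <iostream>
    #include "mf.h"

    void save_and_reload(mf::mf_model *model)
    {
        if(mf::mf_save_model(model, "model.txt") != 0) // returns 1 on failure
        {
            std::cerr << "cannot write model.txt" << std::endl;
            return;
        }

        mf::mf_model *copy = mf::mf_load_model("model.txt"); // nullptr on failure
        if(copy != nullptr)
        {
            std::cout << "reloaded: m=" << copy->m << " n=" << copy->n
                      << " k=" << copy->k << std::endl;
            mf::mf_destroy_model(&copy);
        }
    }
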
4406
-
4407
- void mf_destroy_model(mf_model **model)
4408
- {
4409
- if(model == nullptr || *model == nullptr)
4410
- return;
4411
- Utility::free_aligned_float((*model)->P);
4412
- Utility::free_aligned_float((*model)->Q);
4413
- delete *model;
4414
- *model = nullptr;
4415
- }
4416
-
4417
- mf_float mf_predict(mf_model const *model, mf_int u, mf_int v)
4418
- {
4419
- if(u < 0 || u >= model->m || v < 0 || v >= model->n)
4420
- return model->b;
4421
-
4422
- mf_float *p = model->P+(mf_long)u*model->k;
4423
- mf_float *q = model->Q+(mf_long)v*model->k;
4424
-
4425
- mf_float z = std::inner_product(p, p+model->k, q, (mf_float)0.0f);
4426
-
4427
- if(isnan(z))
4428
- z = model->b;
4429
-
4430
- if(model->fun == P_L2_MFC ||
4431
- model->fun == P_L1_MFC ||
4432
- model->fun == P_LR_MFC)
4433
- z = z > 0.0f? 1.0f: -1.0f;
4434
-
4435
- return z;
4436
- }
4437
-
4438
- mf_double calc_rmse(mf_problem *prob, mf_model *model)
4439
- {
4440
- if(prob->nnz == 0)
4441
- return 0;
4442
- mf_double loss = 0;
4443
- #if defined USEOMP
4444
- #pragma omp parallel for schedule(static) reduction(+:loss)
4445
- #endif
4446
- for(mf_long i = 0; i < prob->nnz; ++i)
4447
- {
4448
- mf_node &N = prob->R[i];
4449
- mf_float e = N.r - mf_predict(model, N.u, N.v);
4450
- loss += e*e;
4451
- }
4452
- return sqrt(loss/prob->nnz);
4453
- }
4454
-
4455
- mf_double calc_mae(mf_problem *prob, mf_model *model)
4456
- {
4457
- if(prob->nnz == 0)
4458
- return 0;
4459
- mf_double loss = 0;
4460
- #if defined USEOMP
4461
- #pragma omp parallel for schedule(static) reduction(+:loss)
4462
- #endif
4463
- for(mf_long i = 0; i < prob->nnz; ++i)
4464
- {
4465
- mf_node &N = prob->R[i];
4466
- loss += abs(N.r - mf_predict(model, N.u, N.v));
4467
- }
4468
- return loss/prob->nnz;
4469
- }
4470
-
4471
- mf_double calc_gkl(mf_problem *prob, mf_model *model)
4472
- {
4473
- if(prob->nnz == 0)
4474
- return 0;
4475
- mf_double loss = 0;
4476
- #if defined USEOMP
4477
- #pragma omp parallel for schedule(static) reduction(+:loss)
4478
- #endif
4479
- for(mf_long i = 0; i < prob->nnz; ++i)
4480
- {
4481
- mf_node &N = prob->R[i];
4482
- mf_float z = mf_predict(model, N.u, N.v);
4483
- loss += N.r*log(N.r/z)-N.r+z;
4484
- }
4485
- return loss/prob->nnz;
4486
- }
4487
-
4488
- mf_double calc_logloss(mf_problem *prob, mf_model *model)
4489
- {
4490
- if(prob->nnz == 0)
4491
- return 0;
4492
- mf_double logloss = 0;
4493
- #if defined USEOMP
4494
- #pragma omp parallel for schedule(static) reduction(+:logloss)
4495
- #endif
4496
- for(mf_long i = 0; i < prob->nnz; ++i)
4497
- {
4498
- mf_node &N = prob->R[i];
4499
- mf_float z = mf_predict(model, N.u, N.v);
4500
- if(N.r > 0)
4501
- logloss += log(1.0+exp(-z));
4502
- else
4503
- logloss += log(1.0+exp(z));
4504
- }
4505
- return logloss/prob->nnz;
4506
- }
4507
-
4508
- mf_double calc_accuracy(mf_problem *prob, mf_model *model)
4509
- {
4510
- if(prob->nnz == 0)
4511
- return 0;
4512
- mf_double acc = 0;
4513
- #if defined USEOMP
4514
- #pragma omp parallel for schedule(static) reduction(+:acc)
4515
- #endif
4516
- for(mf_long i = 0; i < prob->nnz; ++i)
4517
- {
4518
- mf_node &N = prob->R[i];
4519
- mf_float z = mf_predict(model, N.u, N.v);
4520
- if(N.r > 0)
4521
- acc += z > 0? 1: 0;
4522
- else
4523
- acc += z < 0? 1: 0;
4524
- }
4525
- return acc/prob->nnz;
4526
- }
4527
-
4528
- pair<mf_double, mf_double> calc_mpr_auc(mf_problem *prob,
4529
- mf_model *model, bool transpose)
4530
- {
4531
- mf_int mf_node::*row_ptr;
4532
- mf_int mf_node::*col_ptr;
4533
- mf_int m = 0, n = 0;
4534
- if(!transpose)
4535
- {
4536
- row_ptr = &mf_node::u;
4537
- col_ptr = &mf_node::v;
4538
- m = max(prob->m, model->m);
4539
- n = max(prob->n, model->n);
4540
- }
4541
- else
4542
- {
4543
- row_ptr = &mf_node::v;
4544
- col_ptr = &mf_node::u;
4545
- m = max(prob->n, model->n);
4546
- n = max(prob->m, model->m);
4547
- }
4548
-
4549
- auto sort_by_id = [&] (mf_node const &lhs, mf_node const &rhs)
4550
- {
4551
- return tie(lhs.*row_ptr, lhs.*col_ptr) <
4552
- tie(rhs.*row_ptr, rhs.*col_ptr);
4553
- };
4554
-
4555
- sort(prob->R, prob->R+prob->nnz, sort_by_id);
4556
-
4557
- auto sort_by_pred = [&] (pair<mf_node, mf_float> const &lhs,
4558
- pair<mf_node, mf_float> const &rhs) { return lhs.second < rhs.second; };
4559
-
4560
- vector<mf_int> pos_cnts(m+1, 0);
4561
- for(mf_long i = 0; i < prob->nnz; ++i)
4562
- pos_cnts[prob->R[i].*row_ptr+1] += 1;
4563
- for(mf_int i = 1; i < m+1; ++i)
4564
- pos_cnts[i] += pos_cnts[i-1];
4565
-
4566
- mf_int total_m = 0;
4567
- mf_long total_pos = 0;
4568
- mf_double all_u_mpr = 0;
4569
- mf_double all_u_auc = 0;
4570
- #if defined USEOMP
4571
- #pragma omp parallel for schedule(static) reduction(+: total_m, total_pos, all_u_mpr, all_u_auc)
4572
- #endif
4573
- for(mf_int i = 0; i < m; ++i)
4574
- {
4575
- if(pos_cnts[i+1]-pos_cnts[i] < 1)
4576
- continue;
4577
-
4578
- vector<pair<mf_node, mf_float>> row(n);
4579
-
4580
- for(mf_int j = 0; j < n; ++j)
4581
- {
4582
- mf_node N;
4583
- N.*row_ptr = i;
4584
- N.*col_ptr = j;
4585
- N.r = 0;
4586
- row[j] = make_pair(N, mf_predict(model, N.u, N.v));
4587
- }
4588
-
4589
- mf_int pos = 0;
4590
- vector<mf_int> index(pos_cnts[i+1]-pos_cnts[i], 0);
4591
- for(mf_int j = pos_cnts[i]; j < pos_cnts[i+1]; ++j)
4592
- {
4593
- if(prob->R[j].r <= 0)
4594
- continue;
4595
-
4596
- mf_int col = prob->R[j].*col_ptr;
4597
- row[col].first.r = prob->R[j].r;
4598
- index[pos] = col;
4599
- pos += 1;
4600
- }
4601
-
4602
- if(n-pos < 1 || pos < 1)
4603
- continue;
4604
-
4605
- ++total_m;
4606
- total_pos += pos;
4607
-
4608
- mf_int count = 0;
4609
- for(mf_int k = 0; k < pos; ++k)
4610
- {
4611
- swap(row[count], row[index[k]]);
4612
- ++count;
4613
- }
4614
- sort(row.begin(), row.begin()+pos, sort_by_pred);
4615
-
4616
- mf_double u_mpr = 0;
4617
- mf_double u_auc = 0;
4618
- for(auto neg_it = row.begin()+pos; neg_it != row.end(); ++neg_it)
4619
- {
4620
- if(row[pos-1].second <= neg_it->second)
4621
- {
4622
- u_mpr += pos;
4623
- continue;
4624
- }
4625
-
4626
- mf_int left = 0;
4627
- mf_int right = pos-1;
4628
- while(left < right)
4629
- {
4630
- mf_int mid = (left+right)/2;
4631
- if(row[mid].second > neg_it->second)
4632
- right = mid;
4633
- else
4634
- left = mid+1;
4635
- }
4636
- u_mpr += left;
4637
- u_auc += pos-left;
4638
- }
4639
-
4640
- all_u_mpr += u_mpr/(n-pos);
4641
- all_u_auc += u_auc/(n-pos)/pos;
4642
- }
4643
-
4644
- all_u_mpr /= total_pos;
4645
- all_u_auc /= total_m;
4646
-
4647
- return make_pair(all_u_mpr, all_u_auc);
4648
- }
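
In formulas, writing \Omega_i^{+} and \Omega_i^{-} for the positive and negative columns of row i, the quantities accumulated above are (a reading of the implementation; M is the number of rows having at least one positive and one negative column):

    \mathrm{AUC} = \frac{1}{M}\sum_{i}
        \frac{1}{|\Omega_i^{+}|\,|\Omega_i^{-}|}
        \sum_{p\in\Omega_i^{+}}\sum_{q\in\Omega_i^{-}}
        \mathbb{1}\big[\hat{r}_{ip} > \hat{r}_{iq}\big]

    \mathrm{MPR} = \frac{1}{\sum_i |\Omega_i^{+}|}\sum_{i}
        \frac{1}{|\Omega_i^{-}|}
        \sum_{q\in\Omega_i^{-}}
        \big|\{p\in\Omega_i^{+} : \hat{r}_{ip} \le \hat{r}_{iq}\}\big|
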
4649
-
4650
- mf_double calc_mpr(mf_problem *prob, mf_model *model, bool transpose)
4651
- {
4652
- return calc_mpr_auc(prob, model, transpose).first;
4653
- }
4654
-
4655
- mf_double calc_auc(mf_problem *prob, mf_model *model, bool transpose)
4656
- {
4657
- return calc_mpr_auc(prob, model, transpose).second;
4658
- }
4659
-
4660
- mf_parameter mf_get_default_param()
4661
- {
4662
- mf_parameter param;
4663
-
4664
- param.fun = P_L2_MFR;
4665
- param.k = 8;
4666
- param.nr_threads = 12;
4667
- param.nr_bins = 20;
4668
- param.nr_iters = 20;
4669
- param.lambda_p1 = 0.0f;
4670
- param.lambda_q1 = 0.0f;
4671
- param.lambda_p2 = 0.1f;
4672
- param.lambda_q2 = 0.1f;
4673
- param.eta = 0.1f;
4674
- param.alpha = 1.0f;
4675
- param.c = 0.0001f;
4676
- param.do_nmf = false;
4677
- param.quiet = false;
4678
- param.copy_data = true;
4679
-
4680
- return param;
4681
- }
4682
-
4683
- }