ltfmselector 0.1.13__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ltfmselector/env.py CHANGED
@@ -14,7 +14,10 @@ capLowerValues = lambda x: 0.0 if x < 0.0 else x
14
14
 
15
15
  class Environment:
16
16
  def __init__(
17
- self, X, y, X_bg, fQueryCost, mQueryCost,
17
+ self, X, y, X_bg,
18
+ fQueryCost, fQueryFunction,
19
+ fThreshold, fCap, fRate,
20
+ mQueryCost,
18
21
  fRepeatQueryCost, p_wNoFCost, errorCost, pType,
19
22
  regression_tol, regression_error_rounding, pModels, device,
20
23
  sample_weight=None, **kwargs
@@ -42,6 +45,21 @@ class Environment:
42
45
  fQueryCost : float
43
46
  Cost of querying a feature
44
47
 
48
+ fQueryFunction : None or {'step', 'linear', 'quadratic'}
49
+ Function to progressively increase cost of recruiting a feature
50
+
51
+ fThreshold : None or int
52
+ If `fQueryFunction` is one of {'step', 'linear', 'quadratic'}:
54
+ number of recruited features above which the cost of recruiting
55
+ a further feature starts to increase
55
+
56
+ fCap : None or float
57
+ If `fQueryFunction == 'step'`, cost charged per feature once `fThreshold` is exceeded; if 'linear' or 'quadratic', optional upper limit of the cost
58
+
59
+ fRate : None or float
60
+ If `fQueryFunction` is 'linear' or 'quadratic', coefficient multiplying
61
+ the number of features beyond `fThreshold` (squared for 'quadratic')
62
+
45
63
  mQueryCost : float
46
64
  Cost of querying a prediction model
47
65
 
@@ -87,6 +105,11 @@ class Environment:
87
105
 
88
106
  # Reward functions
89
107
  self.fQueryCost = fQueryCost
108
+ self.fQueryFunction = fQueryFunction
109
+ self.fThreshold = fThreshold
110
+ self.fCap = fCap
111
+ self.fRate = fRate
112
+
90
113
  self.mQueryCost = mQueryCost
91
114
  self.fRepeatQueryCost = fRepeatQueryCost
92
115
  self.p_wNoFCost = p_wNoFCost
@@ -204,7 +227,7 @@ class Environment:
204
227
  self.state[action] = self.X_test.iloc[0, action]
205
228
 
206
229
  # Punish for querying a feature
207
- return [self.state, -self.fQueryCost, False]
230
+ return [self.state, -self.get_fQueryCost(), False]
208
231
 
209
232
  # Punish agent for attempting to query a feature already
210
233
  # previously selected
@@ -331,6 +354,54 @@ class Environment:
331
354
 
332
355
  return [None, 0.0, True]
333
356
 
357
def get_fQueryCost(self):
    '''
    Get the cost of querying (recruiting) a feature.

    If ``self.fQueryFunction`` is None, the constant ``self.fQueryCost``
    is returned. Otherwise the cost depends on the number of features
    flagged in ``self.get_feature_mask()``. It is computed by the
    progressive cost function named by ``self.fQueryFunction``.

    Returns
    -------
    Cost of the query. This is a penalty; the caller negates it to form
    the reward.

    Raises
    ------
    ValueError
        If ``self.fQueryFunction`` is neither None nor one of
        'step', 'linear' or 'quadratic'.
    '''
    if self.fQueryFunction is None:
        return self.fQueryCost

    # Dispatch table of the supported progressive cost functions
    costFunctions = {
        "step": self.get_fQueryCostStep,
        "linear": self.get_fQueryCostLinear,
        "quadratic": self.get_fQueryCostQuadratic,
    }
    try:
        costFunction = costFunctions[self.fQueryFunction]
    except (KeyError, TypeError):
        # Fail loudly. Returning None here would only surface later, as a
        # TypeError, when the caller negates the cost into a reward.
        raise ValueError(
            f"{self.fQueryFunction} is not a valid option. Available " +
            f"options are {list(costFunctions)}"
        ) from None

    # Number of features recruited so far.
    # NOTE(review): at the call site the queried feature is written into
    # self.state before this method runs. Confirm whether
    # get_feature_mask() already counts it.
    nFSubset = (self.get_feature_mask()).sum()
    return costFunction(nFSubset)

def get_fQueryCostStep(self, _nFSubset):
    '''
    Step cost function for querying a feature.

    Parameters
    ----------
    _nFSubset : int
        Number of features currently recruited.

    Returns
    -------
    ``self.fQueryCost`` while ``_nFSubset <= self.fThreshold``, and
    ``self.fCap`` once ``_nFSubset`` exceeds the threshold.
    '''
    return self.fCap if _nFSubset > self.fThreshold else self.fQueryCost

def get_fQueryCostLinear(self, _nFSubset):
    '''
    Linear cost function for querying a feature.

    The cost is ``fQueryCost + fRate * (_nFSubset - fThreshold)``. It never
    drops below ``fQueryCost`` and is clipped at ``fCap`` when a cap is set.

    Parameters
    ----------
    _nFSubset : int
        Number of features currently recruited.
    '''
    _qC = max(
        self.fQueryCost,
        self.fQueryCost + self.fRate*(_nFSubset - self.fThreshold)
    )
    # fCap is optional for progressive functions; only clip when it is set
    return _qC if self.fCap is None else min(self.fCap, _qC)

def get_fQueryCostQuadratic(self, _nFSubset):
    '''
    Quadratic cost function for querying a feature.

    Above the threshold the cost is
    ``fQueryCost + fRate * (_nFSubset - fThreshold)**2``. At or below it,
    the cost is ``fQueryCost``. The result is clipped at ``fCap`` when a
    cap is set.

    Parameters
    ----------
    _nFSubset : int
        Number of features currently recruited.
    '''
    if _nFSubset > self.fThreshold:
        _qC = self.fQueryCost + self.fRate*(_nFSubset - self.fThreshold)**2
    else:
        _qC = self.fQueryCost

    # fCap is optional for progressive functions; only clip when it is set
    return _qC if self.fCap is None else min(self.fCap, _qC)
404
+
334
405
  def get_feature_mask(self):
335
406
  '''
336
407
  Get the (boolean) feature mask that indicates if a feature has
@@ -72,7 +72,9 @@ class LTFMSelector:
72
72
  def __init__(
73
73
  self, episodes, batch_size=256, tau=0.0005,
74
74
  eps_start=0.9, eps_end=0.05, eps_decay=1000,
75
- fQueryCost=0.01, mQueryCost=0.01,
75
+ fQueryCost=0.01, fQueryFunction=None,
76
+ fThreshold=None, fCap=None, fRate=None,
77
+ mQueryCost=0.01,
76
78
  fRepeatQueryCost=1.0, p_wNoFCost=5.0, errorCost=1.0,
77
79
  pType="regression", regression_tol=0.5,
78
80
  regression_error_rounding=1,
@@ -105,7 +107,33 @@ class LTFMSelector:
105
107
  Rate of exponential decay
106
108
 
107
109
  fQueryCost : float
108
- Cost of querying a feature
110
+ Cost of querying a feature.
111
+
112
+ fQueryFunction : None or {'step', 'linear', 'quadratic'}
113
+ User can also decide to progressively increase the cost of
114
+ querying features in the following manner:
115
+ 'step' :
116
+ Once more than `fThreshold` features have been recruited, each
117
+ further feature costs the fixed amount `fCap`.
118
+ 'linear' :
119
+ Cost of every additional feature linearly increases according
120
+ to user-defined gradient
121
+ 'quadratic' :
122
+ Cost of every additional feature increases quadratically,
123
+ according to a user-defined rate
124
+
125
+ fThreshold : None or int
126
+ If `fQueryFunction` is one of {'step', 'linear', 'quadratic'}:
127
+ number of recruited features above which the cost of recruiting
128
+ a further feature starts to increase
129
+
130
+ fCap : None or float
131
+ Required if `fQueryFunction == 'step'`: cost per feature once
132
+ `fThreshold` is exceeded. Otherwise an optional upper limit of the cost.
133
+
134
+ fRate : None or float
135
+ Required if `fQueryFunction` is 'linear' or 'quadratic': coefficient
136
+ multiplying the number of features beyond `fThreshold` (squared for 'quadratic')
109
137
 
110
138
  mQueryCost : float
111
139
  Cost of querying a prediction model
@@ -196,6 +224,35 @@ class LTFMSelector:
196
224
 
197
225
  # Reward function
198
226
  self.fQueryCost = fQueryCost
227
+ self.fQueryFunction = fQueryFunction
228
+ self.fThreshold = fThreshold
229
+ self.fCap = fCap
230
+ self.fRate = fRate
231
+
232
+ # Options for progressive cost functions
233
+ if isinstance(self.fQueryFunction, str):
234
+ fQueryFunctions = ['step', 'linear', 'quadratic']
235
+ if not self.fQueryFunction in fQueryFunctions:
236
+ raise ValueError(
237
+ f"{self.fQueryFunction} is not a valid option. Available " +
238
+ f"options are {fQueryFunctions}"
239
+ )
240
+ else:
241
+ if not isinstance(fThreshold, int):
242
+ raise ValueError("Parameter fThreshold must be an integer!")
243
+
244
+ if self.fQueryFunction == "step":
245
+ if not (isinstance(fCap, float) or isinstance(fCap, int)):
246
+ raise ValueError("Parameter fCap must be an int or float!")
247
+ else:
248
+ self.fCap = float(fCap)
249
+ else:
250
+ if self.fQueryFunction in ["linear", "quadratic"]:
251
+ if not (isinstance(fRate, float) or isinstance(fRate, int)):
252
+ raise ValueError("Parameter fRate must be an int or float!")
253
+ else:
254
+ self.fRate = float(fRate)
255
+
199
256
  self.mQueryCost = mQueryCost
200
257
  self.fRepeatQueryCost = fRepeatQueryCost
201
258
  self.p_wNoFCost = p_wNoFCost
@@ -336,7 +393,9 @@ class LTFMSelector:
336
393
  # Initializing the environment
337
394
  env = Environment(
338
395
  self.X, self.y, self.background_dataset,
339
- self.fQueryCost, self.mQueryCost,
396
+ self.fQueryCost, self.fQueryFunction,
397
+ self.fThreshold, self.fCap, self.fRate,
398
+ self.mQueryCost,
340
399
  self.fRepeatQueryCost, self.p_wNoFCost, self.errorCost,
341
400
  self.pType, self.regression_tol, self.regression_error_rounding,
342
401
  self.pModels, self.device, sample_weight=self.sample_weight,
@@ -484,6 +543,7 @@ class LTFMSelector:
484
543
  writer.add_scalar("Metrics/Average_QValue", _res[0], monitor_count)
485
544
  writer.add_scalar("Metrics/Average_Reward", _res[1], monitor_count)
486
545
  writer.add_scalar("Metrics/Average_Target", _res[2], monitor_count)
546
+ writer.flush()
487
547
  writer.close()
488
548
 
489
549
  if returnQ:
@@ -517,7 +577,9 @@ class LTFMSelector:
517
577
  # Initializing the environment
518
578
  env = Environment(
519
579
  self.X, self.y, self.background_dataset,
520
- self.fQueryCost, self.mQueryCost,
580
+ self.fQueryCost, self.fQueryFunction,
581
+ self.fThreshold, self.fCap, self.fRate,
582
+ self.mQueryCost,
521
583
  self.fRepeatQueryCost, self.p_wNoFCost, self.errorCost,
522
584
  self.pType, self.regression_tol, self.regression_error_rounding,
523
585
  self.pModels, self.device, **kwargs
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ltfmselector
3
- Version: 0.1.13
3
+ Version: 0.2.1
4
4
  Summary: Locally-Tailored Feature and Model Selector with Deep Q-Learning
5
5
  Project-URL: GitHub, https://github.com/RenZhen95/ltfmselector/
6
6
  Author-email: RenZhen95 <j-liaw@hotmail.com>
@@ -0,0 +1,9 @@
1
+ ltfmselector/__init__.py,sha256=lf3e90CNpEDvEmNZ-0iuoHOPsA7D-WN_opbBsTYLVEA,76
2
+ ltfmselector/env.py,sha256=F0NycqUkNn-p2zC1EPdds73__G8yyMW5f9F93yPDHTA,16371
3
+ ltfmselector/ltfmselector.py,sha256=GdxazN6JG_ELZ7a7x6bbBCVDsgmdRWOsbi-TnRFNk8Y,32354
4
+ ltfmselector/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
+ ltfmselector/utils.py,sha256=VXYZSDm7x4s0p9F_58NLW8WQa3dxi0vHZewRy6miC2E,5438
6
+ ltfmselector-0.2.1.dist-info/METADATA,sha256=mHEsAKWtYYsGOQyRxHX5f8fAaSXePDDklL_CLzMay9A,3020
7
+ ltfmselector-0.2.1.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
8
+ ltfmselector-0.2.1.dist-info/licenses/LICENSE,sha256=tmIDlkkp4a0EudXuGmeTdGjHjPhmmXkEMshACXLqX2w,1092
9
+ ltfmselector-0.2.1.dist-info/RECORD,,
@@ -1,9 +0,0 @@
1
- ltfmselector/__init__.py,sha256=lf3e90CNpEDvEmNZ-0iuoHOPsA7D-WN_opbBsTYLVEA,76
2
- ltfmselector/env.py,sha256=vizWGqDSc_2Zfs9aXjFARanIAz6PTKwUHu2_Lew9s3Y,13878
3
- ltfmselector/ltfmselector.py,sha256=w47J_ktNhfmreX1RqlfFo8Di2aEYdl63oI3EHgguCPM,29571
4
- ltfmselector/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
- ltfmselector/utils.py,sha256=VXYZSDm7x4s0p9F_58NLW8WQa3dxi0vHZewRy6miC2E,5438
6
- ltfmselector-0.1.13.dist-info/METADATA,sha256=dNyHk7_WXjX59D7GmVs_f_Csdngfq2_lesaan24TvEs,3021
7
- ltfmselector-0.1.13.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
8
- ltfmselector-0.1.13.dist-info/licenses/LICENSE,sha256=tmIDlkkp4a0EudXuGmeTdGjHjPhmmXkEMshACXLqX2w,1092
9
- ltfmselector-0.1.13.dist-info/RECORD,,