ltfmselector 0.1.10__py3-none-any.whl → 0.1.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ltfmselector/env.py CHANGED
@@ -16,7 +16,8 @@ class Environment:
     def __init__(
         self, X, y, X_bg, fQueryCost, mQueryCost,
         fRepeatQueryCost, p_wNoFCost, errorCost, pType,
-        regression_tol, regression_error_rounding, pModels, device
+        regression_tol, regression_error_rounding, pModels, device,
+        sample_weight=None, **kwargs
     ):
         '''
         The environment with which the agent interacts, including the actions
@@ -34,7 +35,7 @@ class Environment:
         X_bg : pd.DataFrame
             Background dataaset, pandas dataframe with the shape:
             (n_samples+1, n_features)
-
+
             An extra row for 'Total', average feature values for all training
             samples
 
@@ -67,7 +68,7 @@ class Environment:
             error is bigger than regression_tol
 
         regression_error_rounding : int
-            Only applicable for regression models. The error between the
+            Only applicable for regression models. The error between the
             prediction and true value is rounded to the input decimal place.
 
         pModels : None or ``list of prediction models``
@@ -75,6 +76,9 @@ class Environment:
 
         device : ``CPU`` or ``GPU``
             Computation device
+
+        sample_weight : list or array or None
+            Per-sample weights
         '''
         # Datasets
         self.X = X
@@ -91,6 +95,7 @@ class Environment:
         self.regression_error_rounding = regression_error_rounding
 
         self.device = device
+        self.sample_weight = sample_weight
 
         # Available prediction models
         self.pType = pType
@@ -106,6 +111,14 @@ class Environment:
         # Counter for prediction model change
         self.pm_nChange = 0
 
+        ### Special-tailored implementation ###
+        if "smsproject" in list(kwargs.keys()):
+            self.smsproject = True
+        else:
+            self.smsproject = False
+
+        self.y_pred_bg = self.get_bgPrediction()
+
         self.state = None
 
     def reset(self, sample=None):
@@ -167,7 +180,7 @@ class Environment:
 
         return self.state
 
-    def step(self, action, sample_weight=None, **kwargs):
+    def step(self, action):
         '''
         Agent carries out an action.
 
@@ -177,9 +190,6 @@ class Environment:
             = -1 (make a prediction with selected features and prediction model)
             = int : [0, n_features] (query a feature)
            = int : [n_features, n_features + n_model] (query a prediction model)
-
-        sample_weight : list or array or None
-            Per-sample weights
         '''
         # === === === ===
         # Query a feature
@@ -229,13 +239,14 @@ class Environment:
         # Punish agent if it decides to predict without selecting any
         # features
         if len(col_to_retain) == 0:
+            self.y_pred = self.y_pred_bg[int(self.state[-1])]
             return [None, -self.p_wNoFCost, True]
 
         # === === === ===
         # Make a prediction with selected features and prediction model
 
         ### Special-tailored implementation ###
-        if "smsproject" in list(kwargs.keys()):
+        if self.smsproject:
             testpatientID = getPatientID(X_test.index[0])
             otherSP_of_testPatient = [
                 sp for sp in X_train.index if getPatientID(sp) == testpatientID
@@ -254,13 +265,15 @@ class Environment:
         selected_predModel = self.pModels[int(self.state[-1])]
 
         ### Special-tailored implementation ###
-        if "smsproject" in list(kwargs.keys()):
+        if self.smsproject:
             X_train_wLabel = X_train.copy()
             X_train_wLabel["Target"] = self.y.loc[X_train_wLabel.index]
 
-            sample_weight = balance_classDistribution_patient(
+            _weights = balance_classDistribution_patient(
                 X_train_wLabel, "Target"
             ).to_numpy(dtype=np.float32)[:,0]
+        else:
+            _weights = self.sample_weight
 
         # Convert X_train and y_train into numpy arrays if they are Pandas
         # DataFrame or Series
@@ -273,16 +286,16 @@ class Environment:
         if isinstance(X_test, pd.DataFrame):
             X_test = X_test.values
 
-        if sample_weight is None:
+        if _weights is None:
             selected_predModel.fit(X_train, y_train)
         else:
             selected_predModel.fit(
-                X_train, y_train, sample_weight=sample_weight
+                X_train, y_train, sample_weight=_weights
             )
-
+
         self.y_pred = selected_predModel.predict(X_test)[0]
 
-        if "smsproject" in list(kwargs.keys()):
+        if self.smsproject:
             # Capping values between 0 and 3
             self.y_pred = capUpperValues(self.y_pred)
             self.y_pred = capLowerValues(self.y_pred)
@@ -337,6 +350,54 @@ class Environment:
         '''
         return int(self.state[-1])
 
+    def get_bgPrediction(self):
+        '''
+        Get prediction based on background dataset for each type of
+        prediction model, fitted with the training samples, to be used
+        for the case that the agent decides to make a prediction without
+        any recruited features.
+        '''
+        # Initialize map between model type with background prediction
+        yBg_Model = []
+
+        ### Special-tailored implementation ###
+        if self.smsproject:
+            X_train_wLabel = self.X.copy()
+            X_train_wLabel["Target"] = self.y.loc[X_train_wLabel.index]
+
+            _weights = balance_classDistribution_patient(
+                X_train_wLabel, "Target"
+            ).to_numpy(dtype=np.float32)[:,0]
+        else:
+            _weights = self.sample_weight
+
+        # DataFrame or Series -> convert to numpy arrays
+        if isinstance(self.X, pd.DataFrame):
+            _X = self.X.values
+
+        if isinstance(self.y, pd.Series):
+            _y = self.y.values
+
+        for m in self.pModels:
+            # Fit each prediction model with the entire dataset
+            if _weights is None:
+                m.fit(_X, _y)
+            else:
+                m.fit(_X, _y, sample_weight=_weights)
+
+            # Use fitted model to make a prediction based on background
+            # dataset
+            yBg_Model.append(
+                m.predict(self.X_bg.loc[["Total"]])[0]
+            )
+
+            # Capping values between 0 and 3
+            if self.smsproject:
+                yBg_Model[-1] = capUpperValues(yBg_Model[-1])
+                yBg_Model[-1] = capLowerValues(yBg_Model[-1])
+
+        return yBg_Model
+
     def __getstate__(self):
         state = self.__dict__.copy()
         print(state.keys())
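Context for the env.py changes above: per-sample weights are now stored on the Environment and forwarded to each prediction model's fit call, and get_bgPrediction pre-computes one background prediction per model for the case where the agent predicts without recruiting any features. A minimal, illustrative sketch of that fit/predict pattern follows; the estimator, data and weights are made up for illustration and are not taken from the package:

```python
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestRegressor

# Illustrative data: 4 samples, 2 features (placeholder values)
X = pd.DataFrame({"f1": [0.1, 0.4, 0.8, 0.9], "f2": [1.0, 0.7, 0.2, 0.1]})
y = pd.Series([0.0, 1.0, 2.0, 3.0])

# Background dataset with a 'Total' row of average feature values, as in env.py
X_bg = pd.DataFrame([X.mean()], index=["Total"])

# Hypothetical per-sample weights (the new sample_weight argument)
weights = np.array([1.0, 1.0, 2.0, 2.0], dtype=np.float32)

model = RandomForestRegressor(random_state=0)
# Same branching as env.py: plain fit when no weights are supplied,
# weighted fit otherwise
if weights is None:
    model.fit(X.values, y.values)
else:
    model.fit(X.values, y.values, sample_weight=weights)

# Background prediction: the fallback value used when no features were recruited
y_pred_bg = model.predict(X_bg.loc[["Total"]].values)[0]
print(y_pred_bg)
```

The 'Total' row plays the same role as env.py's background dataset: an extra row of average feature values used as the fallback input.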
ltfmselector/ltfmselector.py CHANGED
@@ -1,6 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.optim as optim
+from torch.utils.tensorboard import SummaryWriter
 
 import os
 import random
@@ -120,7 +121,7 @@ class LTFMSelector:
 
            If pType == 'regression', then
                Agent is punished -errorCost*abs(``prediction`` - ``target``)
-
+
            If pType == 'classification', then
                Agent is punished -errorCost
 
@@ -132,7 +133,7 @@ class LTFMSelector:
            error is bigger than regression_tol
 
        regression_error_rounding : int (default = 1)
-            Only applicable for regression models. The error between the
+            Only applicable for regression models. The error between the
            prediction and true value is rounded to the input decimal place.
 
        pModels : None or ``list of prediction models``
@@ -142,7 +143,7 @@ class LTFMSelector:
                1. Support Vector Machine
                2. Random Forest
                3. Gaussian Naive Bayes
-
+
            For regression:
                1. Support Vector Machine
                2. Random Forest
@@ -160,7 +161,7 @@ class LTFMSelector:
            Maximum number of time-steps per episode. Agent will be forced to
            make a prediction with the selected features and prediction model,
            if max_timesteps is reached
-
+
            If None, max_timesteps will be set to 3 x number_of_features
 
        checkpoint_interval : int or None
@@ -226,7 +227,7 @@ class LTFMSelector:
    def fit(
        self, X, y, loss_function='mse', sample_weight=None,
        agent_neuralnetwork=None, lr=1e-5, returnQ=False,
-        background_dataset=None, **kwargs
+        monitor=False, background_dataset=None, **kwargs
    ):
        '''
        Initializes the environment and agent, then trains the agent to select
@@ -255,7 +256,7 @@ class LTFMSelector:
            integer elements (n1, n2). n1 and n2 pertains to the number of units
            in the first and second layer of a multilayer-perceptron,
            implemented in PyTorch.
-
+
            If None, a default multilayer-perceptron of two hidden layers, each
            with 1024 units is used.
 
@@ -265,7 +266,14 @@ class LTFMSelector:
 
        returnQ : bool
            Return average computed action-value functions and rewards of
-            the sampled batches, for debugging purposes.
+            the sampled batches, as a (<total_iterations>, 3) matrix. The
+            columns correspond to the averaged Q, reward, and target functions.
+
+        monitor : bool
+            Monitor training process using a TensorBoard.
+
+            Run `tensorboard --logdir=runs` in the terminal to monitor the p
+            progression of the action-value function.
 
        background_dataset : None or pd.DataFrame
            If None, numerical features will be assumed when computing the
@@ -284,10 +292,10 @@ class LTFMSelector:
            List of policy network's action-value function, Q(s,a),
            averaged over the sampled batch during training, per iteration
        r_avr_list : list
-            List of rewards, r, averaged over the sampled batch during
+            List of rewards, r, averaged over the sampled batch during
            training, per iteration
        V_avr_list : list
-            List of max action-value function for the next state (s'),
+            List of max action-value function for the next state (s'),
            max{a} Q(s', a), averaged over the sampled batch during
            training, per iteration
        '''
@@ -312,10 +320,18 @@ class LTFMSelector:
        self.sample_weight = sample_weight
 
        # If user wants to monitor progression of terms in the loss function
+        if monitor:
+            writer = SummaryWriter()
+            monitor_count = 1
+
+        # If user wants to save average computed action-value functions and
+        # rewards of sampled batches
        if returnQ:
-            Q_avr_list = []
-            r_avr_list = []
-            V_avr_list = []
+            total_iterations = 10000000000
+            LearningValuesMatrix = np.zeros(
+                (total_iterations, 3), dtype=np.float32
+            )
+            Q_count = 1
 
        # Initializing the environment
        env = Environment(
@@ -323,7 +339,8 @@ class LTFMSelector:
            self.fQueryCost, self.mQueryCost,
            self.fRepeatQueryCost, self.p_wNoFCost, self.errorCost,
            self.pType, self.regression_tol, self.regression_error_rounding,
-            self.pModels, self.device
+            self.pModels, self.device, sample_weight=self.sample_weight,
+            **kwargs
        )
        env.reset()
 
@@ -387,9 +404,7 @@ class LTFMSelector:
                # Agent carries out action on the environment and returns:
                # - observation (state in next time-step)
                # - reward
-                observation, reward, terminated = env.step(
-                    action.item(), sample_weight=self.sample_weight, **kwargs
-                )
+                observation, reward, terminated = env.step(action.item())
 
                if terminated:
                    next_state = None
@@ -407,13 +422,21 @@ class LTFMSelector:
                state = next_state
 
                # Optimize the model
-                _res = self.optimize_model(optimizer, loss_function, returnQ)
+                _res = self.optimize_model(optimizer, loss_function, monitor, returnQ)
+
+                if monitor:
+                    if not _res is None:
+                        writer.add_scalar("Metrics/Average_QValue", _res[0], monitor_count)
+                        writer.add_scalar("Metrics/Average_Reward", _res[1], monitor_count)
+                        writer.add_scalar("Metrics/Average_Target", _res[2], monitor_count)
+                        monitor_count += 1
 
                if returnQ:
                    if not _res is None:
-                        Q_avr_list.append(_res[0])
-                        r_avr_list.append(_res[1])
-                        V_avr_list.append(_res[2])
+                        LearningValuesMatrix[Q_count, 0] = _res[0]
+                        LearningValuesMatrix[Q_count, 1] = _res[1]
+                        LearningValuesMatrix[Q_count, 2] = _res[2]
+                        Q_count += 1
 
                # Apply soft update to target network's weights
                targetParameters = self.target_net.state_dict()
@@ -457,11 +480,19 @@ class LTFMSelector:
        self.policy_network_checkpoints[self.episodes] =\
            self.policy_net.state_dict()
 
+        if monitor:
+            writer.add_scalar("Metrics/Average_QValue", _res[0], monitor_count)
+            writer.add_scalar("Metrics/Average_Reward", _res[1], monitor_count)
+            writer.add_scalar("Metrics/Average_Target", _res[2], monitor_count)
+            writer.close()
+
        if returnQ:
-            Q_avr_list.append(_res[0])
-            r_avr_list.append(_res[1])
-            V_avr_list.append(_res[2])
-            return doc, (Q_avr_list, r_avr_list, V_avr_list)
+            LearningValuesMatrix[Q_count, 0] = _res[0]
+            LearningValuesMatrix[Q_count, 1] = _res[1]
+            LearningValuesMatrix[Q_count, 2] = _res[2]
+
+        if (monitor or returnQ):
+            return doc, LearningValuesMatrix[0:Q_count+1, :]
        else:
            return doc
 
@@ -489,7 +520,7 @@ class LTFMSelector:
            self.fQueryCost, self.mQueryCost,
            self.fRepeatQueryCost, self.p_wNoFCost, self.errorCost,
            self.pType, self.regression_tol, self.regression_error_rounding,
-            self.pModels, self.device
+            self.pModels, self.device, **kwargs
        )
 
        # Create dictionary to save information per episode
@@ -513,9 +544,7 @@ class LTFMSelector:
            if t > self.max_timesteps:
                action = torch.tensor([[-1]], device=self.device)
 
-            observation, reward, terminated = env.step(
-                action.item(), sample_weight=self.sample_weight, **kwargs
-            )
+            observation, reward, terminated = env.step(action.item())
 
            if terminated:
                next_state = None
@@ -575,7 +604,7 @@ class LTFMSelector:
        with torch.no_grad():
            return (self.policy_net(state).max(1)[1].view(1, 1) - 1)
 
-    def optimize_model(self, optimizer, loss_function, returnQ):
+    def optimize_model(self, optimizer, loss_function, monitor, returnQ):
        '''
        Optimize the policy network.
 
@@ -603,8 +632,8 @@ class LTFMSelector:
        # 1. Draw a random batch of experiences
        experiences = self.ReplayMemory.sample(self.batch_size)
        # [
-        # Experience #1: (state, action, next_state, reward),
-        # Experience #2: (state, action, next_state, reward),
+        # Experience #1: (state, action, next_state, reward),
+        # Experience #2: (state, action, next_state, reward),
        # ...
        # ]
 
@@ -688,7 +717,7 @@ class LTFMSelector:
            criterion = nn.SmoothL1Loss()
        else:
            criterion = loss_function
-
+
        loss = criterion(
            state_action_values, expected_state_action_values.unsqueeze(1)
        )
@@ -701,7 +730,7 @@ class LTFMSelector:
        # Optimize the model (policy network)
        optimizer.step()
 
-        if returnQ:
+        if (monitor or returnQ):
            Q_avr = state_action_values.detach().numpy().mean()
            r_avr = reward_batch.unsqueeze(1).numpy().mean()
            V_avr = expected_state_action_values.unsqueeze(1).numpy().mean()
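The monitor=True path added above is a thin wrapper around PyTorch's TensorBoard writer: one SummaryWriter plus three add_scalar calls per optimization step, viewed with `tensorboard --logdir=runs`. A standalone sketch of that logging pattern with synthetic values (the tag names come from the diff; the loop and values are illustrative, not the package's training loop):

```python
import numpy as np
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter()  # writes event files under ./runs/<timestamp> by default

for step in range(1, 101):
    # Stand-ins for the averaged Q-value, reward and target that
    # optimize_model() returns; here they are just synthetic numbers
    q_avr, r_avr, v_avr = np.random.rand(3)
    writer.add_scalar("Metrics/Average_QValue", q_avr, step)
    writer.add_scalar("Metrics/Average_Reward", r_avr, step)
    writer.add_scalar("Metrics/Average_Target", v_avr, step)

writer.close()

# Then inspect the curves with:
#   tensorboard --logdir=runs
```

With returnQ=True, fit now returns these same three quantities as the columns of a single LearningValuesMatrix instead of three separate lists.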
{ltfmselector-0.1.10.dist-info → ltfmselector-0.1.12.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: ltfmselector
-Version: 0.1.10
+Version: 0.1.12
 Summary: Locally-Tailored Feature and Model Selector with Deep Q-Learning
 Project-URL: GitHub, https://github.com/RenZhen95/ltfmselector/
 Author-email: RenZhen95 <j-liaw@hotmail.com>
@@ -32,8 +32,9 @@ Requires-Dist: matplotlib>=3.10.1
 Requires-Dist: numpy>=2.2.4
 Requires-Dist: openpyxl>=3.1.5
 Requires-Dist: pandas>=2.2.3
-Requires-Dist: scikit-learn>=1.6.1
+Requires-Dist: scikit-learn<1.6
 Requires-Dist: seaborn>=0.13.2
+Requires-Dist: tensorboard>=2.20.0
 Requires-Dist: torch>=2.6.0
 Description-Content-Type: text/markdown
 
ltfmselector-0.1.12.dist-info/RECORD ADDED
@@ -0,0 +1,9 @@
+ltfmselector/__init__.py,sha256=lf3e90CNpEDvEmNZ-0iuoHOPsA7D-WN_opbBsTYLVEA,76
+ltfmselector/env.py,sha256=vizWGqDSc_2Zfs9aXjFARanIAz6PTKwUHu2_Lew9s3Y,13878
+ltfmselector/ltfmselector.py,sha256=vs9unOmoDKq1piV6t87GC1wdy7kP8ucKHihw6i0F4KI,29567
+ltfmselector/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ltfmselector/utils.py,sha256=VXYZSDm7x4s0p9F_58NLW8WQa3dxi0vHZewRy6miC2E,5438
+ltfmselector-0.1.12.dist-info/METADATA,sha256=QaUPeSx9NlZx0ZUbkEPRyFS-8nfJz9Y8yV5TXXPc7fA,3021
+ltfmselector-0.1.12.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+ltfmselector-0.1.12.dist-info/licenses/LICENSE,sha256=tmIDlkkp4a0EudXuGmeTdGjHjPhmmXkEMshACXLqX2w,1092
+ltfmselector-0.1.12.dist-info/RECORD,,
{ltfmselector-0.1.10.dist-info → ltfmselector-0.1.12.dist-info}/WHEEL CHANGED
@@ -1,4 +1,4 @@
 Wheel-Version: 1.0
-Generator: hatchling 1.27.0
+Generator: hatchling 1.28.0
 Root-Is-Purelib: true
 Tag: py3-none-any
ltfmselector-0.1.10.dist-info/RECORD REMOVED
@@ -1,9 +0,0 @@
-ltfmselector/__init__.py,sha256=lf3e90CNpEDvEmNZ-0iuoHOPsA7D-WN_opbBsTYLVEA,76
-ltfmselector/env.py,sha256=mHa6l7mWE5mZGFTGA7sqr2xbGLAuE1ll0c5Lh8Ju5Gw,11854
-ltfmselector/ltfmselector.py,sha256=JX3jtlRE2KRUssH-LGwcrvw0y9HALPNQutete6PI09c,28150
-ltfmselector/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-ltfmselector/utils.py,sha256=VXYZSDm7x4s0p9F_58NLW8WQa3dxi0vHZewRy6miC2E,5438
-ltfmselector-0.1.10.dist-info/METADATA,sha256=TjeFKEBs09qrB3cbDRMXVCJJ-mcE5-CDJ2nju5qoc6w,2989
-ltfmselector-0.1.10.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
-ltfmselector-0.1.10.dist-info/licenses/LICENSE,sha256=tmIDlkkp4a0EudXuGmeTdGjHjPhmmXkEMshACXLqX2w,1092
-ltfmselector-0.1.10.dist-info/RECORD,,