actinet 0.0.dev5__tar.gz → 0.0.dev6__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {actinet-0.0.dev5 → actinet-0.0.dev6}/PKG-INFO +2 -2
- {actinet-0.0.dev5 → actinet-0.0.dev6}/README.md +1 -1
- {actinet-0.0.dev5 → actinet-0.0.dev6}/src/actinet/_version.py +3 -3
- {actinet-0.0.dev5 → actinet-0.0.dev6}/src/actinet/actinet.py +36 -19
- actinet-0.0.dev6/src/actinet/evaluate.py +215 -0
- {actinet-0.0.dev5 → actinet-0.0.dev6}/src/actinet/hmm.py +14 -11
- actinet-0.0.dev6/src/actinet/models.py +341 -0
- actinet-0.0.dev6/src/actinet/prepare.py +329 -0
- {actinet-0.0.dev5 → actinet-0.0.dev6}/src/actinet/sslmodel.py +136 -0
- {actinet-0.0.dev5 → actinet-0.0.dev6}/src/actinet/summarisation.py +16 -10
- actinet-0.0.dev6/src/actinet/utils/utils.py +89 -0
- {actinet-0.0.dev5 → actinet-0.0.dev6}/src/actinet.egg-info/PKG-INFO +2 -2
- {actinet-0.0.dev5 → actinet-0.0.dev6}/src/actinet.egg-info/SOURCES.txt +2 -0
- actinet-0.0.dev5/src/actinet/models.py +0 -193
- actinet-0.0.dev5/src/actinet/utils/utils.py +0 -36
- {actinet-0.0.dev5 → actinet-0.0.dev6}/LICENSE.md +0 -0
- {actinet-0.0.dev5 → actinet-0.0.dev6}/pyproject.toml +0 -0
- {actinet-0.0.dev5 → actinet-0.0.dev6}/setup.cfg +0 -0
- {actinet-0.0.dev5 → actinet-0.0.dev6}/setup.py +0 -0
- {actinet-0.0.dev5 → actinet-0.0.dev6}/src/actinet/__init__.py +0 -0
- {actinet-0.0.dev5 → actinet-0.0.dev6}/src/actinet/accPlot.py +0 -0
- {actinet-0.0.dev5 → actinet-0.0.dev6}/src/actinet/circadian.py +0 -0
- {actinet-0.0.dev5 → actinet-0.0.dev6}/src/actinet/utils/__init__.py +0 -0
- {actinet-0.0.dev5 → actinet-0.0.dev6}/src/actinet/utils/collate_outputs.py +0 -0
- {actinet-0.0.dev5 → actinet-0.0.dev6}/src/actinet/utils/generate_commands.py +0 -0
- {actinet-0.0.dev5 → actinet-0.0.dev6}/src/actinet.egg-info/dependency_links.txt +0 -0
- {actinet-0.0.dev5 → actinet-0.0.dev6}/src/actinet.egg-info/entry_points.txt +0 -0
- {actinet-0.0.dev5 → actinet-0.0.dev6}/src/actinet.egg-info/requires.txt +0 -0
- {actinet-0.0.dev5 → actinet-0.0.dev6}/src/actinet.egg-info/top_level.txt +0 -0
- {actinet-0.0.dev5 → actinet-0.0.dev6}/versioneer.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: actinet
|
|
3
|
-
Version: 0.0.dev5
|
|
3
|
+
Version: 0.0.dev6
|
|
4
4
|
Summary: Activity detection algorithm compatible with the UK Biobank Accelerometer Dataset
|
|
5
5
|
Home-page: https://github.com/OxWearables/actinet
|
|
6
6
|
Download-URL: https://github.com/OxWearables/actinet
|
|
@@ -77,7 +77,7 @@ $ actinet -f sample.csv
|
|
|
77
77
|
Some systems may face issues with Java when running the script. If this is your case, try fixing OpenJDK to version 8:
|
|
78
78
|
|
|
79
79
|
```console
|
|
80
|
-
conda
|
|
80
|
+
conda create -n actinet openjdk=8
|
|
81
81
|
```
|
|
82
82
|
|
|
83
83
|
### Offline usage
|
|
@@ -55,7 +55,7 @@ $ actinet -f sample.csv
|
|
|
55
55
|
Some systems may face issues with Java when running the script. If this is your case, try fixing OpenJDK to version 8:
|
|
56
56
|
|
|
57
57
|
```console
|
|
58
|
-
conda
|
|
58
|
+
conda create -n actinet openjdk=8
|
|
59
59
|
```
|
|
60
60
|
|
|
61
61
|
### Offline usage
|
|
@@ -8,11 +8,11 @@ import json
|
|
|
8
8
|
|
|
9
9
|
version_json = '''
|
|
10
10
|
{
|
|
11
|
-
"date": "2024-
|
|
11
|
+
"date": "2024-03-09T08:28:09+0000",
|
|
12
12
|
"dirty": false,
|
|
13
13
|
"error": null,
|
|
14
|
-
"full-revisionid": "
|
|
15
|
-
"version": "0.0.dev5"
|
|
14
|
+
"full-revisionid": "5b419d5fe6079975dbc6ccbca13b128c52709f3f",
|
|
15
|
+
"version": "0.0.dev6"
|
|
16
16
|
}
|
|
17
17
|
''' # END VERSION_JSON
|
|
18
18
|
|
|
@@ -16,12 +16,12 @@ from actinet import __classifier_version__
|
|
|
16
16
|
from actinet import __classifier_md5__
|
|
17
17
|
from actinet.accPlot import plotTimeSeries
|
|
18
18
|
from actinet.models import ActivityClassifier
|
|
19
|
-
from actinet.sslmodel import SAMPLE_RATE
|
|
20
19
|
from actinet.summarisation import getActivitySummary
|
|
21
20
|
from actinet.utils.utils import infer_freq
|
|
22
21
|
|
|
23
22
|
BASE_URL = "https://zenodo.org/records/10625542/files/"
|
|
24
23
|
|
|
24
|
+
|
|
25
25
|
def main():
|
|
26
26
|
|
|
27
27
|
parser = argparse.ArgumentParser(
|
|
@@ -99,10 +99,14 @@ def main():
|
|
|
99
99
|
|
|
100
100
|
return
|
|
101
101
|
|
|
102
|
+
else:
|
|
103
|
+
if not args.filepath:
|
|
104
|
+
raise ValueError("Please provide a file to process.")
|
|
105
|
+
|
|
102
106
|
# Load file
|
|
103
107
|
data, info = read(
|
|
104
108
|
args.filepath,
|
|
105
|
-
resample_hz=
|
|
109
|
+
resample_hz=None,
|
|
106
110
|
sample_rate=args.sample_rate,
|
|
107
111
|
verbose=verbose,
|
|
108
112
|
)
|
|
@@ -118,7 +122,11 @@ def main():
|
|
|
118
122
|
|
|
119
123
|
check_md5 = args.classifier_path is None
|
|
120
124
|
classifier: ActivityClassifier = load_classifier(
|
|
121
|
-
args.classifier_path or classifier_path,
|
|
125
|
+
args.classifier_path or classifier_path,
|
|
126
|
+
args.model_repo_path,
|
|
127
|
+
check_md5,
|
|
128
|
+
args.force_download,
|
|
129
|
+
verbose,
|
|
122
130
|
)
|
|
123
131
|
|
|
124
132
|
classifier.verbose = verbose
|
|
@@ -126,7 +134,7 @@ def main():
|
|
|
126
134
|
|
|
127
135
|
if verbose:
|
|
128
136
|
print("Running activity classifier...")
|
|
129
|
-
Y = classifier.predict_from_frame(data)
|
|
137
|
+
Y = classifier.predict_from_frame(data, args.sample_rate)
|
|
130
138
|
|
|
131
139
|
# Save predicted activities
|
|
132
140
|
timeSeriesFile = f"{outdir}/{basename}-timeSeries.csv.gz"
|
|
@@ -145,7 +153,7 @@ def main():
|
|
|
145
153
|
print("Output plot written to:", plotFile)
|
|
146
154
|
|
|
147
155
|
# Summary
|
|
148
|
-
summary = getActivitySummary(Y, classifier.labels, True, True, verbose)
|
|
156
|
+
summary = getActivitySummary(Y, list(classifier.labels), True, True, verbose)
|
|
149
157
|
|
|
150
158
|
# Join the actipy processing info, with acitivity summary data
|
|
151
159
|
outputSummary = {**summary, **info}
|
|
@@ -162,24 +170,33 @@ def main():
|
|
|
162
170
|
if verbose:
|
|
163
171
|
print("\nSummary Stats\n---------------------")
|
|
164
172
|
print(
|
|
165
|
-
json.dumps(
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
175
|
-
|
|
173
|
+
json.dumps(
|
|
174
|
+
{
|
|
175
|
+
key: outputSummary[key]
|
|
176
|
+
for key in [
|
|
177
|
+
"Filename",
|
|
178
|
+
"Filesize(MB)",
|
|
179
|
+
"WearTime(days)",
|
|
180
|
+
"NonwearTime(days)",
|
|
181
|
+
"ReadOK",
|
|
182
|
+
]
|
|
183
|
+
+ [
|
|
184
|
+
f"{label}-overall-avg"
|
|
185
|
+
for label in ["acc"] + list(classifier.labels)
|
|
186
|
+
]
|
|
187
|
+
},
|
|
188
|
+
indent=4,
|
|
189
|
+
cls=NpEncoder,
|
|
190
|
+
)
|
|
176
191
|
)
|
|
177
192
|
|
|
178
193
|
after = time.time()
|
|
179
194
|
print(f"Done! ({round(after - before,2)}s)")
|
|
180
195
|
|
|
181
196
|
|
|
182
|
-
def read(filepath, resample_hz="uniform", sample_rate=None, verbose=True):
|
|
197
|
+
def read(
|
|
198
|
+
filepath, resample_hz="uniform", sample_rate=None, lowpass_hz=None, verbose=True
|
|
199
|
+
):
|
|
183
200
|
|
|
184
201
|
p = pathlib.Path(filepath)
|
|
185
202
|
ftype = p.suffixes[0].lower()
|
|
@@ -210,7 +227,7 @@ def read(filepath, resample_hz="uniform", sample_rate=None, verbose=True):
|
|
|
210
227
|
data, info = actipy.process(
|
|
211
228
|
data,
|
|
212
229
|
sample_rate,
|
|
213
|
-
lowpass_hz=
|
|
230
|
+
lowpass_hz=lowpass_hz,
|
|
214
231
|
calibrate_gravity=True,
|
|
215
232
|
detect_nonwear=True,
|
|
216
233
|
resample_hz=resample_hz,
|
|
@@ -231,7 +248,7 @@ def read(filepath, resample_hz="uniform", sample_rate=None, verbose=True):
|
|
|
231
248
|
|
|
232
249
|
data, info = actipy.read_device(
|
|
233
250
|
filepath,
|
|
234
|
-
lowpass_hz=
|
|
251
|
+
lowpass_hz=lowpass_hz,
|
|
235
252
|
calibrate_gravity=True,
|
|
236
253
|
detect_nonwear=True,
|
|
237
254
|
resample_hz=resample_hz,
|
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
from sklearn.model_selection import StratifiedGroupKFold
|
|
2
|
+
from sklearn.preprocessing import LabelEncoder
|
|
3
|
+
from sklearn.metrics import (
|
|
4
|
+
classification_report,
|
|
5
|
+
accuracy_score,
|
|
6
|
+
f1_score,
|
|
7
|
+
cohen_kappa_score,
|
|
8
|
+
)
|
|
9
|
+
import numpy as np
|
|
10
|
+
import pandas as pd
|
|
11
|
+
import os
|
|
12
|
+
from imblearn.ensemble import BalancedRandomForestClassifier
|
|
13
|
+
|
|
14
|
+
from actinet.models import ActivityClassifier
|
|
15
|
+
from actinet.hmm import HMM
|
|
16
|
+
from actinet.utils.utils import safe_indexer
|
|
17
|
+
|
|
18
|
+
WINSEC = 30
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def evaluate_preprocessing(
|
|
22
|
+
classifier: ActivityClassifier,
|
|
23
|
+
X,
|
|
24
|
+
Y,
|
|
25
|
+
groups=None,
|
|
26
|
+
T=None,
|
|
27
|
+
weights_path="models/weights.pt",
|
|
28
|
+
verbose=True,
|
|
29
|
+
):
|
|
30
|
+
skf = StratifiedGroupKFold(n_splits=5)
|
|
31
|
+
|
|
32
|
+
le = LabelEncoder().fit(Y)
|
|
33
|
+
Y_encoded = le.transform(Y)
|
|
34
|
+
|
|
35
|
+
Y_preds = np.empty_like(Y_encoded)
|
|
36
|
+
|
|
37
|
+
for fold, (train_index, test_index) in enumerate(skf.split(X, Y_encoded, groups)):
|
|
38
|
+
X_train, X_test = X[train_index], X[test_index]
|
|
39
|
+
y_train, y_test = Y_encoded[train_index], Y_encoded[test_index]
|
|
40
|
+
groups_train = safe_indexer(groups, train_index)
|
|
41
|
+
t_train = safe_indexer(T, train_index)
|
|
42
|
+
|
|
43
|
+
classifier.fit(
|
|
44
|
+
X_train,
|
|
45
|
+
y_train,
|
|
46
|
+
groups_train,
|
|
47
|
+
t_train,
|
|
48
|
+
weights_path.format(fold),
|
|
49
|
+
n_splits=1,
|
|
50
|
+
)
|
|
51
|
+
y_pred = classifier.predict(X_test, False)
|
|
52
|
+
|
|
53
|
+
if verbose:
|
|
54
|
+
print(
|
|
55
|
+
f"Fold {fold+1} Test Scores - Accuracy: {accuracy_score(y_test, y_pred):.3f}, "
|
|
56
|
+
+ f"Macro F1: {f1_score(y_test, y_pred, average='macro'):.3f}"
|
|
57
|
+
)
|
|
58
|
+
|
|
59
|
+
Y_preds[test_index] = y_pred
|
|
60
|
+
|
|
61
|
+
Y_preds = le.inverse_transform(Y_preds)
|
|
62
|
+
|
|
63
|
+
if verbose:
|
|
64
|
+
print(classification_report(Y, Y_preds))
|
|
65
|
+
|
|
66
|
+
return Y_preds
|
|
67
|
+
|
|
68
|
+
|
|
69
|
+
def evaluate_models(
|
|
70
|
+
actinet_classifier: ActivityClassifier,
|
|
71
|
+
rf_classifier: BalancedRandomForestClassifier,
|
|
72
|
+
X_actinet,
|
|
73
|
+
X_rf,
|
|
74
|
+
Y_actinet,
|
|
75
|
+
Y_rf,
|
|
76
|
+
groups_actinet,
|
|
77
|
+
groups_rf,
|
|
78
|
+
T_actinet=None,
|
|
79
|
+
T_rf=None,
|
|
80
|
+
weights_path="models/weights.pt",
|
|
81
|
+
out_dir=None,
|
|
82
|
+
verbose=True,
|
|
83
|
+
):
|
|
84
|
+
skf = StratifiedGroupKFold(n_splits=5)
|
|
85
|
+
|
|
86
|
+
le = LabelEncoder().fit(Y_rf)
|
|
87
|
+
Y_encoded_rf = le.transform(Y_rf)
|
|
88
|
+
Y_encoded_actinet = le.transform(Y_actinet)
|
|
89
|
+
|
|
90
|
+
Y_preds_rf = np.empty_like(Y_encoded_rf)
|
|
91
|
+
Y_preds_actinet = np.empty_like(Y_encoded_actinet)
|
|
92
|
+
|
|
93
|
+
results_rf = []
|
|
94
|
+
results_actinet = []
|
|
95
|
+
|
|
96
|
+
for fold, (train_index, test_index) in enumerate(
|
|
97
|
+
skf.split(X_rf, Y_encoded_rf, groups_rf)
|
|
98
|
+
):
|
|
99
|
+
if verbose:
|
|
100
|
+
print(f"======== Evalating Fold {fold+1} ========")
|
|
101
|
+
# Ensure the same train and test split for groups are used in both models in each fold
|
|
102
|
+
train_index_actinet = np.isin(groups_actinet, np.unique(groups_rf[train_index]))
|
|
103
|
+
test_index_actinet = np.isin(groups_actinet, np.unique(groups_rf[test_index]))
|
|
104
|
+
|
|
105
|
+
train_index_rf = np.isin(groups_rf, np.unique(groups_rf[train_index]))
|
|
106
|
+
test_index_rf = np.isin(groups_rf, np.unique(groups_rf[test_index]))
|
|
107
|
+
|
|
108
|
+
# Train test split for actinet model
|
|
109
|
+
X_train_actinet, X_test_actinet = (
|
|
110
|
+
X_actinet[train_index_actinet],
|
|
111
|
+
X_actinet[test_index_actinet],
|
|
112
|
+
)
|
|
113
|
+
y_train_actinet, y_test_actinet = (
|
|
114
|
+
Y_encoded_actinet[train_index_actinet],
|
|
115
|
+
Y_encoded_actinet[test_index_actinet],
|
|
116
|
+
)
|
|
117
|
+
groups_train_actinet = groups_actinet[train_index_actinet]
|
|
118
|
+
groups_test_actinet = groups_actinet[test_index_actinet]
|
|
119
|
+
|
|
120
|
+
t_train_actinet = safe_indexer(T_actinet, train_index_actinet)
|
|
121
|
+
t_test_actinet = safe_indexer(T_actinet, test_index_actinet)
|
|
122
|
+
|
|
123
|
+
# Train test split for accelerometer model
|
|
124
|
+
X_train_rf, X_test_rf = X_rf[train_index_rf], X_rf[test_index_rf]
|
|
125
|
+
y_train_rf, y_test_rf = (
|
|
126
|
+
Y_encoded_rf[train_index_rf],
|
|
127
|
+
Y_encoded_rf[test_index_rf],
|
|
128
|
+
)
|
|
129
|
+
t_train_rf = safe_indexer(T_rf, train_index_rf)
|
|
130
|
+
t_test_rf = safe_indexer(T_rf, test_index_rf)
|
|
131
|
+
|
|
132
|
+
groups_test_rf = groups_rf[test_index_rf]
|
|
133
|
+
|
|
134
|
+
actinet_classifier.fit(
|
|
135
|
+
X_train_actinet,
|
|
136
|
+
y_train_actinet,
|
|
137
|
+
groups_train_actinet,
|
|
138
|
+
t_train_actinet,
|
|
139
|
+
weights_path.format(fold),
|
|
140
|
+
n_splits=5,
|
|
141
|
+
)
|
|
142
|
+
y_pred_actinet = actinet_classifier.predict(
|
|
143
|
+
X_test_actinet, True, t_test_actinet
|
|
144
|
+
).astype(int)
|
|
145
|
+
|
|
146
|
+
# Analysis of accelerometer random forest model
|
|
147
|
+
rf_classifier.fit(
|
|
148
|
+
X_train_rf,
|
|
149
|
+
y_train_rf,
|
|
150
|
+
)
|
|
151
|
+
|
|
152
|
+
hmm_rf = HMM()
|
|
153
|
+
hmm_rf.fit(
|
|
154
|
+
rf_classifier.oob_decision_function_,
|
|
155
|
+
y_train_rf,
|
|
156
|
+
t_train_rf,
|
|
157
|
+
WINSEC,
|
|
158
|
+
)
|
|
159
|
+
y_pred_rf = hmm_rf.predict(rf_classifier.predict(X_test_rf), t_test_rf, WINSEC)
|
|
160
|
+
|
|
161
|
+
# Display model performance for each fold
|
|
162
|
+
if verbose:
|
|
163
|
+
print(
|
|
164
|
+
f"Actinet test scores for fold {fold+1}\n"
|
|
165
|
+
+ f"Accuracy: {accuracy_score(y_test_actinet, y_pred_actinet):.3f}, "
|
|
166
|
+
+ f"Macro F1: {f1_score(y_test_actinet, y_pred_actinet, average='macro'):.3f}, "
|
|
167
|
+
+ f"Kappa: {cohen_kappa_score(y_test_actinet, y_pred_actinet):.3f}"
|
|
168
|
+
)
|
|
169
|
+
print(
|
|
170
|
+
f"Accelerometer test scores for fold {fold+1}\n"
|
|
171
|
+
+ f"Accuracy: {accuracy_score(y_test_rf, y_pred_rf):.3f}, "
|
|
172
|
+
+ f"Macro F1: {f1_score(y_test_rf, y_pred_rf, average='macro'):.3f}, "
|
|
173
|
+
+ f"Kappa: {cohen_kappa_score(y_test_rf, y_pred_rf):.3f}"
|
|
174
|
+
)
|
|
175
|
+
|
|
176
|
+
Y_preds_actinet[test_index_actinet] = y_pred_actinet
|
|
177
|
+
Y_preds_rf[test_index_rf] = y_pred_rf
|
|
178
|
+
|
|
179
|
+
results_actinet.append(
|
|
180
|
+
{
|
|
181
|
+
"fold": [fold] * len(y_pred_actinet),
|
|
182
|
+
"group": groups_test_actinet,
|
|
183
|
+
"Y_pred": le.inverse_transform(y_pred_actinet),
|
|
184
|
+
"Y_true": le.inverse_transform(y_test_actinet),
|
|
185
|
+
}
|
|
186
|
+
)
|
|
187
|
+
results_rf.append(
|
|
188
|
+
{
|
|
189
|
+
"fold": [fold] * len(y_pred_rf),
|
|
190
|
+
"group": groups_test_rf,
|
|
191
|
+
"Y_pred": le.inverse_transform(y_pred_rf),
|
|
192
|
+
"Y_true": le.inverse_transform(y_test_rf),
|
|
193
|
+
}
|
|
194
|
+
)
|
|
195
|
+
|
|
196
|
+
Y_preds_actinet = le.inverse_transform(Y_preds_actinet)
|
|
197
|
+
Y_preds_rf = le.inverse_transform(Y_preds_rf)
|
|
198
|
+
|
|
199
|
+
# Report performance across all folds
|
|
200
|
+
if verbose:
|
|
201
|
+
print("Actinet performance:")
|
|
202
|
+
print(classification_report(Y_actinet, Y_preds_actinet))
|
|
203
|
+
print("Accelerometer performance:")
|
|
204
|
+
print(classification_report(Y_rf, Y_preds_rf))
|
|
205
|
+
|
|
206
|
+
# Save results to pickle files
|
|
207
|
+
results_actinet = pd.DataFrame(results_actinet)
|
|
208
|
+
results_rf = pd.DataFrame(results_rf)
|
|
209
|
+
|
|
210
|
+
if out_dir is not None:
|
|
211
|
+
os.makedirs(out_dir, exist_ok=True)
|
|
212
|
+
results_actinet.to_pickle(f"{out_dir}/actinet_results.pkl")
|
|
213
|
+
results_rf.to_pickle(f"{out_dir}/rf_results.pkl")
|
|
214
|
+
|
|
215
|
+
return results_actinet, results_rf
|
|
@@ -27,31 +27,29 @@ class HMM:
|
|
|
27
27
|
"Hidden Markov Model\n"
|
|
28
28
|
"prior: {prior}\n"
|
|
29
29
|
"emission: {emission}\n"
|
|
30
|
-
"transition: {transition}\n"
|
|
31
|
-
"labels: {labels}".format(
|
|
30
|
+
"transition: {transition}".format(
|
|
32
31
|
prior=self.prior,
|
|
33
32
|
emission=self.emission,
|
|
34
33
|
transition=self.transition,
|
|
35
|
-
labels=self.labels,
|
|
36
34
|
)
|
|
37
35
|
)
|
|
38
36
|
|
|
39
|
-
def
|
|
37
|
+
def fit(self, Y_prob, Y_true, T=None, interval=None):
|
|
40
38
|
"""https://en.wikipedia.org/wiki/Hidden_Markov_model
|
|
41
|
-
:param
|
|
42
|
-
:param
|
|
39
|
+
:param Y_prob: Observation probabilities
|
|
40
|
+
:param Y_true: Ground truth labels
|
|
43
41
|
"""
|
|
44
42
|
|
|
45
43
|
if self.labels is None:
|
|
46
|
-
self.labels = np.unique(
|
|
44
|
+
self.labels = np.unique(Y_true)
|
|
47
45
|
|
|
48
|
-
prior = np.mean(
|
|
46
|
+
prior = np.mean(Y_true.reshape(-1, 1) == self.labels, axis=0)
|
|
49
47
|
|
|
50
48
|
emission = np.vstack(
|
|
51
|
-
[np.mean(
|
|
49
|
+
[np.mean(Y_prob[Y_true == label], axis=0) for label in self.labels]
|
|
52
50
|
)
|
|
53
51
|
|
|
54
|
-
transition = calculate_transition_matrix(
|
|
52
|
+
transition = calculate_transition_matrix(Y_true, T, interval)
|
|
55
53
|
|
|
56
54
|
self.prior = prior
|
|
57
55
|
self.emission = emission
|
|
@@ -193,4 +191,9 @@ def calculate_transition_matrix(Y, t=None, interval=None):
|
|
|
193
191
|
trans_mat = df.groupby([0, "shift"]).count().unstack().fillna(0)
|
|
194
192
|
|
|
195
193
|
# normalise by occurences and save values to get the transition matrix
|
|
196
|
-
|
|
194
|
+
trans_mat = trans_mat.div(trans_mat.sum(axis=1), axis=0).values
|
|
195
|
+
|
|
196
|
+
if trans_mat.size == 0:
|
|
197
|
+
raise Exception("No transitions found in data")
|
|
198
|
+
|
|
199
|
+
return trans_mat
|