morphlabs 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- morphlabs-0.1.0/LICENSE +21 -0
- morphlabs-0.1.0/PKG-INFO +51 -0
- morphlabs-0.1.0/README.md +29 -0
- morphlabs-0.1.0/morphlabs/__init__.py +4 -0
- morphlabs-0.1.0/morphlabs/io/__init__.py +3 -0
- morphlabs-0.1.0/morphlabs/io/loading.py +112 -0
- morphlabs-0.1.0/morphlabs/models/__init__.py +3 -0
- morphlabs-0.1.0/morphlabs/models/scientia.py +120 -0
- morphlabs-0.1.0/morphlabs.egg-info/PKG-INFO +51 -0
- morphlabs-0.1.0/morphlabs.egg-info/SOURCES.txt +15 -0
- morphlabs-0.1.0/morphlabs.egg-info/dependency_links.txt +1 -0
- morphlabs-0.1.0/morphlabs.egg-info/requires.txt +10 -0
- morphlabs-0.1.0/morphlabs.egg-info/top_level.txt +3 -0
- morphlabs-0.1.0/pyproject.toml +38 -0
- morphlabs-0.1.0/setup.cfg +4 -0
- morphlabs-0.1.0/tests/test_loading.py +163 -0
- morphlabs-0.1.0/tests/test_scientia.py +332 -0
morphlabs-0.1.0/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Morphlabs
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
morphlabs-0.1.0/PKG-INFO
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: morphlabs
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Python SDK for Morphlabs biosignal processing API
|
|
5
|
+
Author-email: Morphlabs <support@morphlabs.tech>
|
|
6
|
+
Maintainer-email: Morphlabs <support@morphlabs.tech>
|
|
7
|
+
License-Expression: MIT
|
|
8
|
+
Project-URL: Homepage, https://morphlabs.tech
|
|
9
|
+
Requires-Python: >=3.8
|
|
10
|
+
Description-Content-Type: text/markdown
|
|
11
|
+
License-File: LICENSE
|
|
12
|
+
Requires-Dist: numpy>=1.20.0
|
|
13
|
+
Requires-Dist: requests>=2.28.0
|
|
14
|
+
Requires-Dist: mne>=1.0.0
|
|
15
|
+
Requires-Dist: pandas>=1.3.0
|
|
16
|
+
Requires-Dist: scipy>=1.7.0
|
|
17
|
+
Requires-Dist: tenacity>=9.1.2
|
|
18
|
+
Provides-Extra: dev
|
|
19
|
+
Requires-Dist: pytest>=7.0; extra == "dev"
|
|
20
|
+
Requires-Dist: pytest-cov; extra == "dev"
|
|
21
|
+
Dynamic: license-file
|
|
22
|
+
|
|
23
|
+
# morphlabs
|
|
24
|
+
|
|
25
|
+
Python SDK for Morphlabs biosignal processing API.
|
|
26
|
+
|
|
27
|
+
## Installation
|
|
28
|
+
|
|
29
|
+
```bash
|
|
30
|
+
pip install morphlabs
|
|
31
|
+
```
|
|
32
|
+
|
|
33
|
+
## Quick Start
|
|
34
|
+
|
|
35
|
+
```python
|
|
36
|
+
from morphlabs.models import Scientia
|
|
37
|
+
|
|
38
|
+
# Set SCIENTIA_API_KEY environment variable or pass directly
|
|
39
|
+
scientia = Scientia(api_key="your-api-key")
|
|
40
|
+
|
|
41
|
+
# Clean EEG data
|
|
42
|
+
cleaned_data = scientia.clean_data("path/to/eeg_file.csv")
|
|
43
|
+
```
|
|
44
|
+
|
|
45
|
+
## Documentation
|
|
46
|
+
|
|
47
|
+
For full documentation, see [docs.morphlabs.tech](https://docs.morphlabs.tech)
|
|
48
|
+
|
|
49
|
+
## License
|
|
50
|
+
|
|
51
|
+
MIT
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
# morphlabs
|
|
2
|
+
|
|
3
|
+
Python SDK for Morphlabs biosignal processing API.
|
|
4
|
+
|
|
5
|
+
## Installation
|
|
6
|
+
|
|
7
|
+
```bash
|
|
8
|
+
pip install morphlabs
|
|
9
|
+
```
|
|
10
|
+
|
|
11
|
+
## Quick Start
|
|
12
|
+
|
|
13
|
+
```python
|
|
14
|
+
from morphlabs.models import Scientia
|
|
15
|
+
|
|
16
|
+
# Set SCIENTIA_API_KEY environment variable or pass directly
|
|
17
|
+
scientia = Scientia(api_key="your-api-key")
|
|
18
|
+
|
|
19
|
+
# Clean EEG data
|
|
20
|
+
cleaned_data = scientia.clean_data("path/to/eeg_file.csv")
|
|
21
|
+
```
|
|
22
|
+
|
|
23
|
+
## Documentation
|
|
24
|
+
|
|
25
|
+
For full documentation, see [docs.morphlabs.tech](https://docs.morphlabs.tech)
|
|
26
|
+
|
|
27
|
+
## License
|
|
28
|
+
|
|
29
|
+
MIT
|
|
@@ -0,0 +1,112 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pandas as pd
|
|
3
|
+
import mne
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
from typing import Optional
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class EEGData:
|
|
9
|
+
def __init__(self, file_path: str):
|
|
10
|
+
self.file_path = file_path
|
|
11
|
+
self._data = None
|
|
12
|
+
self._channels = None
|
|
13
|
+
self._pad_amount = None
|
|
14
|
+
|
|
15
|
+
if file_path is not None:
|
|
16
|
+
if not self.load_data():
|
|
17
|
+
raise ValueError(f"Failed to load data from {file_path}")
|
|
18
|
+
|
|
19
|
+
def get_data(self) -> Optional[list[np.ndarray]]:
|
|
20
|
+
return self._data
|
|
21
|
+
|
|
22
|
+
def get_channels(self) -> Optional[int]:
|
|
23
|
+
return self._channels
|
|
24
|
+
|
|
25
|
+
def get_pad_amount(self) -> Optional[int]:
|
|
26
|
+
return self._pad_amount
|
|
27
|
+
|
|
28
|
+
def load_data(self) -> bool:
|
|
29
|
+
self._validate_file_path(self.file_path)
|
|
30
|
+
match Path(self.file_path).suffix:
|
|
31
|
+
case '.csv':
|
|
32
|
+
self.load_data_from_csv(self.file_path)
|
|
33
|
+
return True
|
|
34
|
+
case '.edf':
|
|
35
|
+
self.load_data_from_edf(self.file_path)
|
|
36
|
+
return True
|
|
37
|
+
case '.bdf':
|
|
38
|
+
self.load_data_from_bdf(self.file_path)
|
|
39
|
+
return True
|
|
40
|
+
case _:
|
|
41
|
+
raise ValueError(f"Unsupported file type: {str(self.file_path).split('.')[-1]}, please use .csv, .edf, or .bdf files.")
|
|
42
|
+
|
|
43
|
+
def load_data_from_csv(self, file_path: str):
|
|
44
|
+
try:
|
|
45
|
+
data = pd.read_csv(file_path)
|
|
46
|
+
if data.empty:
|
|
47
|
+
raise ValueError(f"File is empty: '{file_path}'. Please provide a file with EEG data.")
|
|
48
|
+
self._channels = len(data.columns)
|
|
49
|
+
self._data, self._pad_amount = self.segment_data(data.values.T.astype(np.float32))
|
|
50
|
+
self.verify_montage()
|
|
51
|
+
except UnicodeDecodeError as e:
|
|
52
|
+
raise ValueError(f"Failed to load CSV file '{file_path}': File contains invalid characters. Details: {e}")
|
|
53
|
+
except pd.errors.ParserError as e:
|
|
54
|
+
raise ValueError(f"Failed to parse CSV file '{file_path}': File may be corrupted or incorrectly formatted. Details: {e}")
|
|
55
|
+
except pd.errors.EmptyDataError:
|
|
56
|
+
raise ValueError(f"File is empty: '{file_path}'. Please provide a file with EEG data.")
|
|
57
|
+
except ValueError:
|
|
58
|
+
raise
|
|
59
|
+
except Exception as e:
|
|
60
|
+
raise ValueError(f"Failed to load CSV file '{file_path}': {type(e).__name__}: {e}")
|
|
61
|
+
|
|
62
|
+
def load_data_from_edf(self, file_path: str):
|
|
63
|
+
try:
|
|
64
|
+
raw = mne.io.read_raw_edf(file_path, preload=True, verbose=False)
|
|
65
|
+
self._channels = len(raw.ch_names)
|
|
66
|
+
self._data, self._pad_amount = self.segment_data(raw.get_data().astype(np.float32))
|
|
67
|
+
self.verify_montage()
|
|
68
|
+
except ValueError:
|
|
69
|
+
raise
|
|
70
|
+
except Exception as e:
|
|
71
|
+
raise ValueError(f"Failed to load EDF file '{file_path}': File may be corrupted or not a valid EDF format. Details: {type(e).__name__}: {e}")
|
|
72
|
+
|
|
73
|
+
def load_data_from_bdf(self, file_path: str):
|
|
74
|
+
try:
|
|
75
|
+
raw = mne.io.read_raw_bdf(file_path, preload=True, verbose=False)
|
|
76
|
+
self._channels = len(raw.ch_names)
|
|
77
|
+
self._data, self._pad_amount = self.segment_data(raw.get_data().astype(np.float32))
|
|
78
|
+
self.verify_montage()
|
|
79
|
+
except ValueError:
|
|
80
|
+
raise
|
|
81
|
+
except Exception as e:
|
|
82
|
+
raise ValueError(f"Failed to load BDF file '{file_path}': File may be corrupted or not a valid BDF format. Details: {type(e).__name__}: {e}")
|
|
83
|
+
|
|
84
|
+
def _validate_file_path(self, file_path: str) -> None:
|
|
85
|
+
path = Path(file_path)
|
|
86
|
+
|
|
87
|
+
if not path.exists():
|
|
88
|
+
raise ValueError(f"File not found: '{file_path}'. Please check the file path exists.")
|
|
89
|
+
if not path.is_file():
|
|
90
|
+
raise ValueError(f"Path is not a file: '{file_path}'. Please provide a path to a file, not a directory.")
|
|
91
|
+
if path.stat().st_size == 0:
|
|
92
|
+
raise ValueError(f"File is empty: '{file_path}'. Please provide a file with EEG data.")
|
|
93
|
+
|
|
94
|
+
def segment_data(self, data: np.ndarray) -> tuple[list[np.ndarray], int]:
|
|
95
|
+
n_samples = data.shape[1]
|
|
96
|
+
window_size = 1000
|
|
97
|
+
segments = []
|
|
98
|
+
pad_amount = 0
|
|
99
|
+
|
|
100
|
+
for i in range(0, n_samples, window_size):
|
|
101
|
+
segment = data[:, i:i+window_size]
|
|
102
|
+
if segment.shape[1] != window_size:
|
|
103
|
+
pad_amount = window_size - segment.shape[1]
|
|
104
|
+
pad_width = ((0, 0), (0, pad_amount))
|
|
105
|
+
segment = np.pad(segment, pad_width, 'constant')
|
|
106
|
+
segments.append(segment)
|
|
107
|
+
return segments, pad_amount
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
def verify_montage(self) -> None:
|
|
111
|
+
if self._channels != 19:
|
|
112
|
+
raise ValueError(f"Unsupported number of channels: {self._channels}, Scientia currently only supports the 19 channels in the 10-20 system.")
|
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
import requests
|
|
2
|
+
import numpy as np
|
|
3
|
+
import os
|
|
4
|
+
import logging
|
|
5
|
+
from typing import Optional
|
|
6
|
+
from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_exception_type
|
|
7
|
+
|
|
8
|
+
from morphlabs.io.loading import EEGData
|
|
9
|
+
|
|
10
|
+
# A library must not call logging.basicConfig() at import time: that silently
# configures the host application's root logger. Following the logging HOWTO,
# only a NullHandler is attached here. Applications that want the progress
# messages enable them themselves, e.g. logging.basicConfig(level=logging.INFO).
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())
|
|
12
|
+
|
|
13
|
+
# User-facing explanations for HTTP error statuses returned by the API.
# Segment width is fixed by the SDK, so no hint asks callers to send smaller
# segments.
API_ERROR_MESSAGES = {
    400: "Bad request: The data sent to the API was invalid. Please check your input data format.",
    401: "Authentication failed: Invalid API key. Please check your SCIENTIA_API_KEY.",
    403: "Access denied: Your API key does not have permission for this operation. Please contact support.",
    404: "API endpoint not found: Please check the base_url configuration.",
    429: "Rate limit exceeded: Too many requests. Please wait a moment and try again.",
    500: "Server error: The Scientia API encountered an internal error. Please try again later.",
    502: "Bad gateway: The Scientia API is temporarily unavailable. Please try again later.",
    503: "Service unavailable: The Scientia API is temporarily down for maintenance. Please try again later.",
    504: "Gateway timeout: The request took too long. Please try again later.",
}

# Statuses treated as transient: Scientia._make_api_request raises
# RetryableAPIError for these, which its tenacity decorator retries.
RETRYABLE_STATUS_CODES = [429, 500, 502, 503, 504]
|
|
26
|
+
|
|
27
|
+
# NOTE(security): version 0.1.0 shipped a RunPod account key as a string
# literal on this line, exposing it to everyone who downloads the package.
# That key must be treated as compromised and revoked. The credential is now
# read from the RUNPOD_API_KEY environment variable and is None when unset.
# Longer term, clients should authenticate to a Morphlabs-controlled gateway
# with their own SCIENTIA_API_KEY rather than holding a shared RunPod key.
RUNPOD_API_KEY = os.getenv("RUNPOD_API_KEY")
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class RetryableAPIError(Exception):
    """Transient API failure: a network error or a retryable HTTP status.

    ``Scientia._make_api_request`` raises this so that its tenacity ``@retry``
    decorator, which retries only this exception type, attempts the request
    again.
    """
|
|
32
|
+
|
|
33
|
+
class Scientia:
    """Client for the Scientia EEG-cleaning model.

    ``clean_data`` works in three steps:

    1. Load a recording with ``EEGData``, which cuts it into fixed-width
       segments and zero-pads the last one.
    2. POST each segment to the configured endpoint.
    3. Concatenate the cleaned segments back into one ``(channels, samples)``
       array, with the padding removed.
    """

    # Hosted RunPod ``runsync`` endpoint used when no base_url is given.
    DEFAULT_BASE_URL = "https://api.runpod.ai/v2/9ni9hifywn9z73/runsync"
    # Per-request timeout, in seconds, passed to ``requests.post``.
    REQUEST_TIMEOUT = 30

    def __init__(self, api_key: Optional[str] = None, base_url: Optional[str] = None):
        """Configure the client.

        Args:
            api_key: Scientia API key. Falls back to the ``SCIENTIA_API_KEY``
                environment variable when omitted. A missing key is reported
                only when ``clean_data`` is called.
            base_url: Endpoint URL; it must start with ``http://`` or
                ``https://``, and a trailing slash is stripped. Defaults to
                ``DEFAULT_BASE_URL``.

        Raises:
            ValueError: If the key is empty or whitespace, or ``base_url`` has
                an unsupported scheme.
        """
        if api_key is None:
            api_key = os.getenv("SCIENTIA_API_KEY")
        if api_key is not None and not api_key.strip():
            raise ValueError("API key cannot be empty or whitespace. Please provide a valid API key.")
        self.api_key = api_key

        if base_url is None:
            self.base_url = self.DEFAULT_BASE_URL
        elif base_url.startswith(("http://", "https://")):
            self.base_url = base_url.rstrip("/")
        else:
            raise ValueError(f"Invalid base_url: '{base_url}'. URL must start with http:// or https://")

    def clean_data(self, data_path: str) -> np.ndarray:
        """Clean an EEG recording with the Scientia API.

        Args:
            data_path: Path (``str`` or ``pathlib.Path``) to a ``.csv``,
                ``.edf`` or ``.bdf`` file accepted by ``EEGData``.

        Returns:
            The cleaned signal: segments concatenated along axis 1, with the
            zero padding of the final segment removed.

        Raises:
            ValueError: If no API key is configured, the file cannot be
                loaded, the API rejects a request, or a response is malformed.
            tenacity.RetryError: If a segment still fails with a retryable
                error after all attempts.
        """
        if self.api_key is None:
            raise ValueError("API key not found. Please pass the API key as an argument or set the SCIENTIA_API_KEY environment variable.")

        data_obj = EEGData(data_path)
        segments = data_obj.get_data()
        pad_amount = data_obj.get_pad_amount()
        if not segments:
            raise ValueError(f"No EEG samples found in '{data_path}'. Please provide a file with EEG data.")

        # Count the real samples from the segments themselves instead of
        # assuming a segment width; only the last segment carries padding.
        total_samples = sum(segment.shape[1] for segment in segments) - pad_amount
        logger.info("Cleaning %d samples", total_samples)

        last_index = len(segments) - 1
        cleaned_segments = []
        processed = 0
        for index, segment in enumerate(segments):
            response = self._make_api_request(
                url=self.base_url,
                json={"input": {"api_key": self.api_key, "data": segment.tolist()}},
                timeout=self.REQUEST_TIMEOUT,
            )
            cleaned = self._parse_response(response, segment.shape)

            is_last = index == last_index
            if is_last and pad_amount > 0:
                # Drop the zeros EEGData appended to the final segment.
                cleaned = cleaned[:, :-pad_amount]
            cleaned_segments.append(cleaned)

            # Report real samples only, so the final line never exceeds the total.
            processed += segment.shape[1] - (pad_amount if is_last else 0)
            logger.info("Cleaned %d of %d samples", processed, total_samples)

        return self._reconstruct_data(cleaned_segments)

    @staticmethod
    def _parse_response(response: requests.Response, expected_shape: tuple) -> np.ndarray:
        """Validate a RunPod ``runsync`` response and return the cleaned segment.

        Args:
            response: Successful HTTP response from ``_make_api_request``.
            expected_shape: Shape of the segment that was sent. The cleaned
                segment must match it so that padding removal and
                concatenation line up.

        Raises:
            ValueError: For a non-JSON or non-object body, a job status other
                than ``COMPLETED``, or a missing or mis-shaped
                ``reconstructed`` field.
        """
        try:
            response_data = response.json()
        except requests.exceptions.JSONDecodeError as e:
            raise ValueError("Invalid response from API: Expected JSON but received invalid data.") from e
        if not isinstance(response_data, dict):
            raise ValueError("Invalid response from API: Expected a JSON object.")

        status = response_data.get("status")
        if status in ("IN_QUEUE", "IN_PROGRESS"):
            raise ValueError(f"Scientia API request is still processing (status: {status}). Please try again.")
        if status == "FAILED":
            error_msg = response_data.get("error", "Unknown error")
            raise ValueError(f"Scientia API request failed: {error_msg}")
        if status == "CANCELLED":
            raise ValueError("Scientia API request was cancelled.")
        if status != "COMPLETED":
            raise ValueError(f"Unexpected response status from Scientia API: {status}")

        output = response_data.get("output")
        if not isinstance(output, dict) or "reconstructed" not in output:
            raise ValueError("Invalid response from API: Missing 'reconstructed' field in response.")

        cleaned = np.array(output["reconstructed"])
        expected = tuple(expected_shape)
        if cleaned.shape != expected:
            raise ValueError(
                f"Invalid response from API: Expected 'reconstructed' with shape {expected}, got {cleaned.shape}."
            )
        return cleaned

    def _reconstruct_data(self, cleaned_samples: list[np.ndarray]) -> np.ndarray:
        """Join cleaned segments along the time axis (axis 1)."""
        return np.concatenate(cleaned_samples, axis=1)

    @retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1, min=4, max=15), retry=retry_if_exception_type(RetryableAPIError))
    def _make_api_request(self, url: str, json: dict, timeout: int) -> requests.Response:
        """POST ``json`` to ``url``, retrying transient failures.

        Network errors and statuses in ``RETRYABLE_STATUS_CODES`` raise
        ``RetryableAPIError``. Tenacity retries these, up to 3 attempts in
        total, with exponential back-off between 4 and 15 s. ``reraise`` is
        not set, so exhausting the attempts surfaces as ``tenacity.RetryError``.

        Raises:
            ValueError: For any other non-2xx status, using the friendly text
                from ``API_ERROR_MESSAGES`` when one exists.
        """
        # Only send the RunPod bearer token when one is configured.
        headers = {"Authorization": f"Bearer {RUNPOD_API_KEY}"} if RUNPOD_API_KEY else {}
        try:
            response = requests.post(url, headers=headers, json=json, timeout=timeout)
        except (requests.exceptions.Timeout, requests.exceptions.ConnectionError) as e:
            raise RetryableAPIError(f"Network error: {e}") from e

        status_code = response.status_code
        if status_code in RETRYABLE_STATUS_CODES:
            detail = f"API request failed with status {status_code}: {response.text}"
            # Surface the friendly hint too; otherwise it is unreachable for these codes.
            hint = API_ERROR_MESSAGES.get(status_code)
            raise RetryableAPIError(f"{hint} ({detail})" if hint else detail)

        if not response.ok:
            error_msg = API_ERROR_MESSAGES.get(status_code, f"API request failed with status {status_code}: {response.text}")
            raise ValueError(error_msg)

        return response
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: morphlabs
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Python SDK for Morphlabs biosignal processing API
|
|
5
|
+
Author-email: Morphlabs <support@morphlabs.tech>
|
|
6
|
+
Maintainer-email: Morphlabs <support@morphlabs.tech>
|
|
7
|
+
License-Expression: MIT
|
|
8
|
+
Project-URL: Homepage, https://morphlabs.tech
|
|
9
|
+
Requires-Python: >=3.8
|
|
10
|
+
Description-Content-Type: text/markdown
|
|
11
|
+
License-File: LICENSE
|
|
12
|
+
Requires-Dist: numpy>=1.20.0
|
|
13
|
+
Requires-Dist: requests>=2.28.0
|
|
14
|
+
Requires-Dist: mne>=1.0.0
|
|
15
|
+
Requires-Dist: pandas>=1.3.0
|
|
16
|
+
Requires-Dist: scipy>=1.7.0
|
|
17
|
+
Requires-Dist: tenacity>=9.1.2
|
|
18
|
+
Provides-Extra: dev
|
|
19
|
+
Requires-Dist: pytest>=7.0; extra == "dev"
|
|
20
|
+
Requires-Dist: pytest-cov; extra == "dev"
|
|
21
|
+
Dynamic: license-file
|
|
22
|
+
|
|
23
|
+
# morphlabs
|
|
24
|
+
|
|
25
|
+
Python SDK for Morphlabs biosignal processing API.
|
|
26
|
+
|
|
27
|
+
## Installation
|
|
28
|
+
|
|
29
|
+
```bash
|
|
30
|
+
pip install morphlabs
|
|
31
|
+
```
|
|
32
|
+
|
|
33
|
+
## Quick Start
|
|
34
|
+
|
|
35
|
+
```python
|
|
36
|
+
from morphlabs.models import Scientia
|
|
37
|
+
|
|
38
|
+
# Set SCIENTIA_API_KEY environment variable or pass directly
|
|
39
|
+
scientia = Scientia(api_key="your-api-key")
|
|
40
|
+
|
|
41
|
+
# Clean EEG data
|
|
42
|
+
cleaned_data = scientia.clean_data("path/to/eeg_file.csv")
|
|
43
|
+
```
|
|
44
|
+
|
|
45
|
+
## Documentation
|
|
46
|
+
|
|
47
|
+
For full documentation, see [docs.morphlabs.tech](https://docs.morphlabs.tech)
|
|
48
|
+
|
|
49
|
+
## License
|
|
50
|
+
|
|
51
|
+
MIT
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
LICENSE
|
|
2
|
+
README.md
|
|
3
|
+
pyproject.toml
|
|
4
|
+
morphlabs/__init__.py
|
|
5
|
+
morphlabs.egg-info/PKG-INFO
|
|
6
|
+
morphlabs.egg-info/SOURCES.txt
|
|
7
|
+
morphlabs.egg-info/dependency_links.txt
|
|
8
|
+
morphlabs.egg-info/requires.txt
|
|
9
|
+
morphlabs.egg-info/top_level.txt
|
|
10
|
+
morphlabs/io/__init__.py
|
|
11
|
+
morphlabs/io/loading.py
|
|
12
|
+
morphlabs/models/__init__.py
|
|
13
|
+
morphlabs/models/scientia.py
|
|
14
|
+
tests/test_loading.py
|
|
15
|
+
tests/test_scientia.py
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
|
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["setuptools>=77.0", "wheel"]
|
|
3
|
+
build-backend = "setuptools.build_meta"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "morphlabs"
|
|
7
|
+
version = "0.1.0"
|
|
8
|
+
description = "Python SDK for Morphlabs biosignal processing API"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
license = "MIT"
|
|
11
|
+
|
|
12
|
+
authors = [
|
|
13
|
+
{name = "Morphlabs", email = "support@morphlabs.tech"}
|
|
14
|
+
]
|
|
15
|
+
maintainers = [
|
|
16
|
+
{name = "Morphlabs", email = "support@morphlabs.tech"}
|
|
17
|
+
]
|
|
18
|
+
|
|
19
|
+
requires-python = ">=3.10"
|
|
20
|
+
dependencies = [
|
|
21
|
+
"numpy>=1.20.0",
|
|
22
|
+
"requests>=2.28.0",
|
|
23
|
+
"mne>=1.0.0",
|
|
24
|
+
"pandas>=1.3.0",
|
|
25
|
+
"scipy>=1.7.0",
|
|
26
|
+
"tenacity>=9.1.2"
|
|
27
|
+
]
|
|
28
|
+
|
|
29
|
+
[project.urls]
|
|
30
|
+
Homepage = "https://morphlabs.tech"
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
[project.optional-dependencies]
|
|
34
|
+
dev = ["pytest>=7.0", "pytest-cov"]
|
|
35
|
+
|
|
36
|
+
[tool.setuptools.packages.find]
|
|
37
|
+
where = ["."]
|
|
38
|
+
exclude = ["tests", "tests.*"]
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
import numpy as np
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
|
|
5
|
+
from morphlabs.io import EEGData
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@pytest.fixture
def test_data_path():
    """Directory holding the sample recordings used by these tests."""
    here = Path(__file__).parent
    return here / "test_data"


def _load_sample(directory, filename):
    """Build an EEGData instance from ``filename`` inside ``directory``."""
    return EEGData(directory / filename)


@pytest.fixture
def valid_edf_2500_samples(test_data_path):
    """EDF sample file named as 19 channels x 2500 samples."""
    return _load_sample(test_data_path, "valid_19ch_2500samples.edf")


@pytest.fixture
def valid_csv_1000_samples(test_data_path):
    """CSV sample file named as 19 channels x 1000 samples."""
    return _load_sample(test_data_path, "valid_19ch_1000samples.csv")


@pytest.fixture
def valid_csv_2500_samples(test_data_path):
    """CSV sample file named as 19 channels x 2500 samples."""
    return _load_sample(test_data_path, "valid_19ch_2500samples.csv")


@pytest.fixture
def valid_bdf_2500_samples(test_data_path):
    """BDF sample file named as 19 channels x 2500 samples."""
    return _load_sample(test_data_path, "valid_19ch_2500samples.bdf")


@pytest.fixture
def valid_csv_500_samples(test_data_path):
    """CSV sample file named as 19 channels x 500 samples."""
    return _load_sample(test_data_path, "valid_19ch_500samples.csv")
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
@pytest.mark.parametrize("fixture_name", [
    "valid_edf_2500_samples",
    "valid_csv_1000_samples",
    "valid_bdf_2500_samples",
])
def test_load_file_success(fixture_name, request):
    data = request.getfixturevalue(fixture_name)
    assert data.get_data() is not None
    assert data.get_channels() == 19
    assert isinstance(data.get_data()[0], np.ndarray)


def test_load_none_success():
    data = EEGData(None)
    assert data.get_data() is None
    assert data.get_channels() is None
    assert data.get_pad_amount() is None


def test_getter_dtype(valid_csv_1000_samples, valid_edf_2500_samples):
    assert valid_edf_2500_samples.get_data()[0][0].dtype == np.float32
    assert valid_csv_1000_samples.get_data()[0][0].dtype == np.float32


def test_getter_pad_amount_no_padding(valid_csv_1000_samples):
    assert valid_csv_1000_samples.get_pad_amount() == 0


def test_getter_pad_amount_with_padding(valid_csv_500_samples):
    assert valid_csv_500_samples.get_pad_amount() == 500


def test_getter_pad_amount_range(valid_edf_2500_samples):
    # The last segment always keeps at least one real sample, so pad < 1000.
    assert 0 <= valid_edf_2500_samples.get_pad_amount() < 1000


def test_getter_pad_amount_csv_2500(valid_csv_2500_samples):
    # 2500 rows -> three 1000-sample windows with 500 zeros in the last one.
    assert valid_csv_2500_samples.get_pad_amount() == 500


@pytest.mark.parametrize("fixture_name", [
    "valid_edf_2500_samples",
    "valid_csv_2500_samples",
    "valid_bdf_2500_samples",
])
def test_segment_data_success(fixture_name, request):
    eeg_data = request.getfixturevalue(fixture_name)
    data = eeg_data.get_data()
    pad_amount = eeg_data.get_pad_amount()

    assert len(data) == 3
    for segment in data:
        assert segment.shape == (19, 1000)
        assert segment.dtype == np.float32

    assert 0 <= pad_amount < 1000
    if pad_amount:
        # Padding is appended as zeros at the end of the final segment.
        assert not data[-1][:, -pad_amount:].any()
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def test_file_not_found(test_data_path):
    with pytest.raises(ValueError, match="File not found"):
        EEGData(test_data_path / "nonexistent_file.csv")


def test_path_is_directory(test_data_path):
    with pytest.raises(ValueError, match="Path is not a file"):
        EEGData(test_data_path)


def test_empty_file(test_data_path):
    with pytest.raises(ValueError, match="File is empty"):
        EEGData(test_data_path / "empty_file.csv")


def test_no_content_file(test_data_path):
    with pytest.raises(ValueError, match="File is empty"):
        EEGData(test_data_path / "no_content.csv")


def test_unsupported_file_type(test_data_path):
    with pytest.raises(ValueError, match="Unsupported file type"):
        EEGData(test_data_path / "invalid_file_format.txt")


@pytest.mark.parametrize("filename,channel_count", [
    ("invalid_10ch.csv", 10),
    ("invalid_32ch.csv", 32),
])
def test_unsupported_channel_count(test_data_path, filename, channel_count):
    # Also pin the reported count, so the parametrized value is actually checked.
    with pytest.raises(ValueError, match=f"Unsupported number of channels: {channel_count},"):
        EEGData(test_data_path / filename)


def test_corrupted_csv(test_data_path):
    with pytest.raises(ValueError, match="Failed to"):
        EEGData(test_data_path / "corrupted.csv")


def test_malformed_csv(test_data_path):
    with pytest.raises(ValueError, match="Failed to parse CSV file"):
        EEGData(test_data_path / "malformed.csv")


def test_corrupted_edf(test_data_path):
    with pytest.raises(ValueError, match="Bad EDF file provided"):
        EEGData(test_data_path / "corrupted.edf")


def test_corrupted_bdf(test_data_path):
    with pytest.raises(ValueError, match="Bad BDF file provided"):
        EEGData(test_data_path / "corrupted.bdf")


def test_invalid_encoding_csv(test_data_path):
    with pytest.raises(ValueError, match="File contains invalid characters"):
        EEGData(test_data_path / "invalid_encoding.csv")
|
|
@@ -0,0 +1,332 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
from unittest.mock import patch, MagicMock
|
|
3
|
+
from pathlib import Path
|
|
4
|
+
import requests
|
|
5
|
+
|
|
6
|
+
from morphlabs.models import Scientia
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
@pytest.fixture
def test_data_path():
    """Directory holding the sample recordings used by these tests."""
    return Path(__file__).parent.joinpath("test_data")


@pytest.fixture
def api_key(monkeypatch):
    """Expose a fake key through SCIENTIA_API_KEY and return it."""
    key = "test_api_key"
    monkeypatch.setenv("SCIENTIA_API_KEY", key)
    return key


@pytest.fixture
def no_api_key(monkeypatch):
    """Make sure SCIENTIA_API_KEY is absent for the duration of a test."""
    monkeypatch.delenv("SCIENTIA_API_KEY", raising=False)


def _completed_response(output):
    """Build a MagicMock standing in for an HTTP 200 RunPod 'COMPLETED' response."""
    response = MagicMock()
    response.status_code = 200
    response.ok = True
    response.json.return_value = {"status": "COMPLETED", "output": output}
    return response


def mock_success_response(url, json, **kwargs):
    """``requests.post`` stand-in that echoes the submitted segment back."""
    return _completed_response({"reconstructed": json["input"]["data"]})


def mock_missing_data_response(url, json, **kwargs):
    """``requests.post`` stand-in whose output lacks ``reconstructed``."""
    return _completed_response({})
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
_DEFAULT_BASE_URL = "https://api.runpod.ai/v2/9ni9hifywn9z73/runsync"


def test_init_defaults(no_api_key):
    client = Scientia()
    assert client.api_key is None
    assert client.base_url == _DEFAULT_BASE_URL


def test_init_with_api_key(no_api_key):
    client = Scientia(api_key="test_api_key1")
    assert client.api_key == "test_api_key1"
    assert client.base_url == _DEFAULT_BASE_URL


def test_init_from_env(api_key):
    client = Scientia(base_url="https://test.scientia.ai")
    assert client.api_key == "test_api_key"
    assert client.base_url == "https://test.scientia.ai"


@pytest.mark.parametrize(
    ("base_url", "expected"),
    [
        ("https://api.scientia.ai/", "https://api.scientia.ai"),
        ("https://api.scientia.ai/v1", "https://api.scientia.ai/v1"),
        ("https://api.scientia.ai/v1/", "https://api.scientia.ai/v1"),
    ],
)
def test_init_strips_trailing_slash(base_url, expected):
    assert Scientia(api_key="test_api_key", base_url=base_url).base_url == expected
|
|
71
|
+
|
|
72
|
+
|
|
73
|
+
# =============================================================================
# Validation Tests
# =============================================================================

def test_missing_api_key_error(no_api_key, test_data_path):
    scientia = Scientia()
    with pytest.raises(ValueError) as e:
        scientia.clean_data(test_data_path / "valid_19ch_2500samples.csv")
    assert "API key not found" in str(e.value)


# The constructor validates the API key and base_url, so the tests below only
# construct the client; a clean_data() call after it could never run.

def test_empty_api_key_error(monkeypatch):
    monkeypatch.setenv("SCIENTIA_API_KEY", "")
    with pytest.raises(ValueError) as e:
        Scientia()
    assert "API key cannot be empty or whitespace" in str(e.value)


def test_whitespace_api_key_error(monkeypatch):
    monkeypatch.setenv("SCIENTIA_API_KEY", " ")
    with pytest.raises(ValueError) as e:
        Scientia()
    assert "API key cannot be empty or whitespace" in str(e.value)


def test_whitespace_api_key_argument_error(no_api_key):
    with pytest.raises(ValueError) as e:
        Scientia(api_key="   ")
    assert "API key cannot be empty or whitespace" in str(e.value)


def test_invalid_base_url_error():
    with pytest.raises(ValueError) as e:
        Scientia(api_key="test_api_key", base_url="invalid_url")
    assert "Invalid base_url" in str(e.value)


def test_invalid_base_url_protocol_error():
    with pytest.raises(ValueError) as e:
        Scientia(api_key="test_api_key", base_url="ftp://invalid_url")
    assert "URL must start with http:// or https://" in str(e.value)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def test_clean_data_success(api_key, test_data_path):
|
|
115
|
+
with patch("morphlabs.models.scientia.requests.post", side_effect=mock_success_response):
|
|
116
|
+
scientia = Scientia()
|
|
117
|
+
data = scientia.clean_data(test_data_path / "valid_19ch_2500samples.csv")
|
|
118
|
+
assert data is not None
|
|
119
|
+
assert data.shape == (19, 2500)
|
|
120
|
+
|
|
121
|
+
|
|
122
|
+
def test_clean_data_removes_padding(api_key, test_data_path):
|
|
123
|
+
with patch("morphlabs.models.scientia.requests.post", side_effect=mock_success_response):
|
|
124
|
+
scientia = Scientia()
|
|
125
|
+
|
|
126
|
+
data_padded = scientia.clean_data(test_data_path / "valid_19ch_2500samples.csv")
|
|
127
|
+
assert data_padded.shape == (19, 2500)
|
|
128
|
+
|
|
129
|
+
data_no_pad = scientia.clean_data(test_data_path / "valid_19ch_2000samples.csv")
|
|
130
|
+
assert data_no_pad.shape == (19, 2000)
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def test_json_missing_reconstructed_field(api_key, test_data_path):
|
|
134
|
+
with patch("morphlabs.models.scientia.requests.post", side_effect=mock_missing_data_response):
|
|
135
|
+
scientia = Scientia()
|
|
136
|
+
with pytest.raises(ValueError) as e:
|
|
137
|
+
scientia.clean_data(test_data_path / "valid_19ch_2500samples.csv")
|
|
138
|
+
assert "Invalid response from API: Missing 'reconstructed' field in response." in str(e.value)
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
def test_json_decode_error(api_key, test_data_path):
    """An HTTP 200 whose body cannot be parsed as JSON surfaces as a ValueError."""

    def mock_invalid_json(url, json, **kwargs):
        # Transport succeeds; parsing the body fails the way requests reports non-JSON content.
        resp = MagicMock(status_code=200, ok=True)
        resp.json.side_effect = requests.exceptions.JSONDecodeError("", "", 0)
        return resp

    target = "morphlabs.models.scientia.requests.post"
    with patch(target, side_effect=mock_invalid_json):
        client = Scientia()
        with pytest.raises(ValueError) as exc_info:
            client.clean_data(test_data_path / "valid_19ch_1000samples.csv")

    assert "Expected JSON but received invalid data" in str(exc_info.value)
|
|
154
|
+
|
|
155
|
+
|
|
156
|
+
@pytest.mark.parametrize("status_code", [429, 500, 502, 503, 504])
def test_retryable_status_codes(api_key, test_data_path, status_code):
    """Rate-limit and 5xx statuses exhaust the retry policy.

    The raised exception is expected to expose ``last_attempt`` (presumably
    tenacity's ``RetryError``; confirm). The exception stored on the final
    attempt carries the status-specific message.
    """

    def mock_retryable_error(url, json, **kwargs):
        return MagicMock(status_code=status_code, ok=False, text=f"Error {status_code}")

    with patch("morphlabs.models.scientia.requests.post", side_effect=mock_retryable_error):
        client = Scientia()
        with pytest.raises(Exception) as exc_info:
            client.clean_data(test_data_path / "valid_19ch_1000samples.csv")

    final_error = exc_info.value.last_attempt.exception()
    assert f"API request failed with status {status_code}" in str(final_error)
|
|
171
|
+
|
|
172
|
+
|
|
173
|
+
@pytest.mark.parametrize("status_code,expected_msg", [
    (400, "Bad request"),
    (401, "Authentication failed"),
    (403, "Access denied"),
    (404, "API endpoint not found"),
])
def test_non_retryable_status_codes(api_key, test_data_path, status_code, expected_msg):
    """Client-error statuses fail fast with a descriptive ValueError.

    Besides the message, this pins the "non-retryable" half of the contract:
    the endpoint is hit exactly once. Asserting the call count makes that
    explicit rather than relying on how the retry wrapper reports exhausted
    attempts.
    """
    call_count = 0

    def mock_non_retryable_error(url, json, **kwargs):
        nonlocal call_count
        call_count += 1
        mock_response = MagicMock()
        mock_response.status_code = status_code
        mock_response.ok = False
        mock_response.text = f"Error {status_code}"
        return mock_response

    with patch("morphlabs.models.scientia.requests.post", side_effect=mock_non_retryable_error):
        scientia = Scientia()
        with pytest.raises(ValueError) as e:
            scientia.clean_data(test_data_path / "valid_19ch_1000samples.csv")
        assert expected_msg in str(e.value)

    # A non-retryable status must not trigger any further attempts.
    assert call_count == 1
|
|
192
|
+
|
|
193
|
+
|
|
194
|
+
def test_retry_attempts_count(api_key, test_data_path):
    """An endpoint that always returns HTTP 500 is called exactly three times before giving up."""
    attempts = []

    def mock_always_fails(url, json, **kwargs):
        attempts.append(url)
        return MagicMock(status_code=500, ok=False, text="Server error")

    with patch("morphlabs.models.scientia.requests.post", side_effect=mock_always_fails):
        client = Scientia()
        with pytest.raises(Exception):
            client.clean_data(test_data_path / "valid_19ch_1000samples.csv")

    assert len(attempts) == 3
|
|
212
|
+
|
|
213
|
+
|
|
214
|
+
def test_network_timeout(api_key, test_data_path):
    """A Timeout on every request ends in a retry failure whose last error mentions 'Network error'."""
    timeout_exc = requests.exceptions.Timeout("Connection timed out")
    with patch("morphlabs.models.scientia.requests.post", side_effect=timeout_exc):
        client = Scientia()
        with pytest.raises(Exception) as exc_info:
            client.clean_data(test_data_path / "valid_19ch_1000samples.csv")

    final_error = exc_info.value.last_attempt.exception()
    assert "Network error" in str(final_error)
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def test_network_connection_error(api_key, test_data_path):
    """A ConnectionError on every request ends in a retry failure whose last error mentions 'Network error'."""
    refused = requests.exceptions.ConnectionError("Connection refused")
    with patch("morphlabs.models.scientia.requests.post", side_effect=refused):
        client = Scientia()
        with pytest.raises(Exception) as exc_info:
            client.clean_data(test_data_path / "valid_19ch_1000samples.csv")

    final_error = exc_info.value.last_attempt.exception()
    assert "Network error" in str(final_error)
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
# =============================================================================
|
|
233
|
+
# RunPod-specific Status Tests
|
|
234
|
+
# =============================================================================
|
|
235
|
+
|
|
236
|
+
@pytest.mark.parametrize("status", ["IN_QUEUE", "IN_PROGRESS"])
def test_runpod_processing_status(api_key, test_data_path, status):
    """A job that is still queued or running is reported as a ValueError naming the status."""

    def mock_processing_response(url, json, **kwargs):
        resp = MagicMock(status_code=200, ok=True)
        # HTTP succeeded, but the job has not produced output yet.
        resp.json.return_value = {"status": status}
        return resp

    with patch("morphlabs.models.scientia.requests.post", side_effect=mock_processing_response):
        client = Scientia()
        with pytest.raises(ValueError) as exc_info:
            client.clean_data(test_data_path / "valid_19ch_1000samples.csv")

    message = str(exc_info.value)
    assert "still processing" in message
    assert status in message
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def test_runpod_failed_status(api_key, test_data_path):
    """A FAILED job raises ValueError that includes the worker's error text."""

    def mock_failed_response(url, json, **kwargs):
        resp = MagicMock(status_code=200, ok=True)
        # Transport succeeded (200); the failure is reported in the job payload.
        resp.json.return_value = {"status": "FAILED", "error": "GPU out of memory"}
        return resp

    with patch("morphlabs.models.scientia.requests.post", side_effect=mock_failed_response):
        client = Scientia()
        with pytest.raises(ValueError) as exc_info:
            client.clean_data(test_data_path / "valid_19ch_1000samples.csv")

    message = str(exc_info.value)
    assert "Scientia API request failed" in message
    assert "GPU out of memory" in message
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def test_runpod_cancelled_status(api_key, test_data_path):
    """A CANCELLED job raises ValueError stating that it was cancelled."""

    def mock_cancelled_response(url, json, **kwargs):
        resp = MagicMock(status_code=200, ok=True)
        resp.json.return_value = {"status": "CANCELLED"}
        return resp

    with patch("morphlabs.models.scientia.requests.post", side_effect=mock_cancelled_response):
        client = Scientia()
        with pytest.raises(ValueError) as exc_info:
            client.clean_data(test_data_path / "valid_19ch_1000samples.csv")

    assert "was cancelled" in str(exc_info.value)
|
|
282
|
+
|
|
283
|
+
|
|
284
|
+
def test_api_key_in_payload(api_key, test_data_path):
    """The configured API key is sent in the request body under ``input.api_key``.

    The expected value "test_api_key" presumably comes from the ``api_key``
    fixture (confirm in conftest).
    """
    captured = {}

    def mock_capture_payload(url, json, **kwargs):
        captured["payload"] = json
        resp = MagicMock(status_code=200, ok=True)
        # Echo the submitted samples back so clean_data completes normally.
        resp.json.return_value = {
            "status": "COMPLETED",
            "output": {"reconstructed": json["input"]["data"]},
        }
        return resp

    with patch("morphlabs.models.scientia.requests.post", side_effect=mock_capture_payload):
        Scientia().clean_data(test_data_path / "valid_19ch_1000samples.csv")

    payload = captured.get("payload")
    assert payload is not None
    assert "input" in payload
    body = payload["input"]
    assert "api_key" in body
    assert body["api_key"] == "test_api_key"
    assert "data" in body
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def test_runpod_api_key_in_header(api_key, test_data_path):
    """Requests carry a Bearer Authorization header containing an ``rpa_``-prefixed key.

    NOTE(review): this key appears to be separate from the user's api_key,
    which is sent in the body. If it is embedded in the package rather than
    read from configuration, the package is distributing a secret. Confirm.
    """
    seen = {}

    def mock_capture_headers(url, headers=None, json=None, **kwargs):
        seen["headers"] = headers
        resp = MagicMock(status_code=200, ok=True)
        # Echo the submitted samples back so clean_data completes normally.
        resp.json.return_value = {
            "status": "COMPLETED",
            "output": {"reconstructed": json["input"]["data"]},
        }
        return resp

    with patch("morphlabs.models.scientia.requests.post", side_effect=mock_capture_headers):
        Scientia().clean_data(test_data_path / "valid_19ch_1000samples.csv")

    sent_headers = seen.get("headers")
    assert sent_headers is not None
    assert "Authorization" in sent_headers
    auth_value = sent_headers["Authorization"]
    assert auth_value.startswith("Bearer ")
    assert "rpa_" in auth_value
|