most-client 1.0.28__py3-none-any.whl → 1.0.30__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- most/__init__.py +2 -0
- most/async_api.py +2 -2
- most/async_trainer_api.py +26 -0
- most/trainer_api.py +25 -0
- most/types.py +23 -1
- {most_client-1.0.28.dist-info → most_client-1.0.30.dist-info}/METADATA +1 -1
- most_client-1.0.30.dist-info/RECORD +13 -0
- most_client-1.0.28.dist-info/RECORD +0 -11
- {most_client-1.0.28.dist-info → most_client-1.0.30.dist-info}/WHEEL +0 -0
- {most_client-1.0.28.dist-info → most_client-1.0.30.dist-info}/top_level.txt +0 -0
- {most_client-1.0.28.dist-info → most_client-1.0.30.dist-info}/zip-safe +0 -0
most/__init__.py
CHANGED
most/async_api.py
CHANGED
@@ -140,8 +140,8 @@ class AsyncMostClient(object):
|
|
140
140
|
headers=headers,
|
141
141
|
**kwargs)
|
142
142
|
|
143
|
-
if resp.
|
144
|
-
raise RuntimeError(resp.json()['message'])
|
143
|
+
if resp.status_code >= 400:
|
144
|
+
raise RuntimeError(resp.json()['message'] if resp.headers.get("Content-Type") == "application/json" else "Something went wrong.")
|
145
145
|
resp.raise_for_status()
|
146
146
|
return resp
|
147
147
|
|
most/async_trainer_api.py
ADDED
@@ -0,0 +1,26 @@
|
|
1
|
+
from typing import List
|
2
|
+
|
3
|
+
from . import AsyncMostClient
|
4
|
+
from .types import HumanFeedback
|
5
|
+
|
6
|
+
|
7
|
+
class AsyncTrainer(object):
    """Async helper that uploads and evaluates human feedback for one model.

    Wraps an ``AsyncMostClient`` and talks to the
    ``/{client_id}/model/{model_id}/data`` endpoint through that client.
    """

    def __init__(self, client: AsyncMostClient):
        """Bind the trainer to *client*.

        Raises:
            RuntimeError: if ``client.model_id`` is ``None``; data points can
                only be attached to a fixed model.
        """
        super().__init__()
        self.client = client
        if self.client.model_id is None:
            raise RuntimeError("Train must be implemented for stable model_id")

    def _data_url(self) -> str:
        """Return the data-points endpoint path for the bound client/model."""
        return f"/{self.client.client_id}/model/{self.client.model_id}/data"

    async def fit(self, data: List[HumanFeedback]):
        """Upload *data* as labelled data points and return ``self``.

        Every item is serialized with ``HumanFeedback.to_dict`` and all of
        them are sent in one ``PUT`` request under the ``"data"`` key.
        """
        # The response body is not used. NOTE(review): presumably
        # client.put raises on HTTP errors — confirm in AsyncMostClient.
        await self.client.put(self._data_url(),
                              json={"data": [hf.to_dict() for hf in data]})
        return self

    async def evaluate(self, data: List[HumanFeedback]):
        """Compare the predictions in *data* with the stored data points.

        Returns ``HumanFeedback.calculate_accuracy(data, stored)``, where
        ``stored`` is the result of :meth:`get_data_points`.
        """
        gt_data = await self.get_data_points()
        return HumanFeedback.calculate_accuracy(data, gt_data)

    async def get_data_points(self) -> List[HumanFeedback]:
        """Fetch the data points stored for the bound model.

        The JSON response body is converted into ``HumanFeedback`` objects
        with ``self.client.retort``.
        """
        resp = await self.client.get(self._data_url())
        return self.client.retort.load(resp.json(), List[HumanFeedback])
|
most/trainer_api.py
ADDED
@@ -0,0 +1,25 @@
|
|
1
|
+
from typing import List
|
2
|
+
from .api import MostClient
|
3
|
+
from .types import HumanFeedback
|
4
|
+
|
5
|
+
|
6
|
+
class Trainer(object):
    """Synchronous helper that uploads and evaluates human feedback for one model.

    Wraps a ``MostClient`` and talks to the
    ``/{client_id}/model/{model_id}/data`` endpoint through that client.
    """

    def __init__(self, client: MostClient):
        """Bind the trainer to *client*.

        Raises:
            RuntimeError: if ``client.model_id`` is ``None``; data points can
                only be attached to a fixed model.
        """
        super().__init__()
        self.client = client
        if self.client.model_id is None:
            raise RuntimeError("Train must be implemented for stable model_id")

    def _data_url(self) -> str:
        """Return the data-points endpoint path for the bound client/model."""
        return f"/{self.client.client_id}/model/{self.client.model_id}/data"

    def fit(self, data: List[HumanFeedback]):
        """Upload *data* as labelled data points and return ``self``.

        Every item is serialized with ``HumanFeedback.to_dict`` and all of
        them are sent in one ``PUT`` request under the ``"data"`` key.
        """
        # The response body is not used. NOTE(review): presumably
        # client.put raises on HTTP errors — confirm in MostClient.
        self.client.put(self._data_url(),
                        json={"data": [hf.to_dict() for hf in data]})
        return self

    def evaluate(self, data: List[HumanFeedback]):
        """Compare the predictions in *data* with the stored data points.

        Returns ``HumanFeedback.calculate_accuracy(data, stored)``, where
        ``stored`` is the result of :meth:`get_data_points`.
        """
        gt_data = self.get_data_points()
        return HumanFeedback.calculate_accuracy(data, gt_data)

    def get_data_points(self) -> List[HumanFeedback]:
        """Fetch the data points stored for the bound model.

        The JSON response body is converted into ``HumanFeedback`` objects
        with ``self.client.retort``.
        """
        resp = self.client.get(self._data_url())
        return self.client.retort.load(resp.json(), List[HumanFeedback])
|
most/types.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1
1
|
import re
|
2
2
|
from dataclasses import dataclass
|
3
3
|
from datetime import datetime
|
4
|
-
from typing import Dict, List, Literal, Optional
|
4
|
+
from typing import Dict, List, Literal, Optional, Union
|
5
5
|
|
6
6
|
from dataclasses_json import DataClassJsonMixin, dataclass_json
|
7
7
|
|
@@ -151,6 +151,28 @@ class SearchParams(DataClassJsonMixin):
|
|
151
151
|
must_not: List[StoredInfoCondition | ResultsCondition]
|
152
152
|
|
153
153
|
|
154
|
+
@dataclass_json
@dataclass
class HumanFeedback(DataClassJsonMixin):
    """A human-assigned score for one (data point, column, subcolumn) cell."""

    data_point_id: str
    data_point_type: Literal["audio", "text"]
    column_name: str
    subcolumn_name: str
    score: int
    # Optional free-text note; defaults to an empty string.
    description: str = ""

    @classmethod
    def calculate_accuracy(cls,
                           preds: List["HumanFeedback"],
                           gt: List["HumanFeedback"],
                           *,
                           zero_division: float = 0.0) -> float:
        """Return the share of common cells whose scores match exactly.

        A cell is identified by ``(data_point_id, column_name,
        subcolumn_name)``. Only cells that appear in both *preds* and *gt*
        are compared. If a cell occurs more than once in the same list, its
        last entry wins.

        Args:
            preds: predicted feedback items.
            gt: ground-truth feedback items.
            zero_division: value returned when the two lists share no cell,
                i.e. when the ratio would be 0/0.

        Returns:
            Number of cells with equal scores divided by number of common cells.
        """
        def cell(item: "HumanFeedback") -> tuple:
            return (item.data_point_id, item.column_name, item.subcolumn_name)

        # Separate names so the list parameters are not rebound to dicts.
        pred_scores = {cell(item): item.score for item in preds}
        true_scores = {cell(item): item.score for item in gt}
        common_keys = pred_scores.keys() & true_scores.keys()
        if not common_keys:
            # Nothing to compare (empty or disjoint inputs): the original
            # code raised ZeroDivisionError here.
            return zero_division
        matches = sum(pred_scores[key] == true_scores[key] for key in common_keys)
        return matches / len(common_keys)
|
174
|
+
|
175
|
+
|
154
176
|
def is_valid_objectid(oid: str) -> bool:
|
155
177
|
"""
|
156
178
|
Check if a given string is a valid MongoDB ObjectId (24-character hex).
|
most_client-1.0.30.dist-info/RECORD
ADDED
@@ -0,0 +1,13 @@
|
|
1
|
+
most/__init__.py,sha256=b0EXXaPA4kmt-FtGXKRWZr7SCwjipMLcpC6uT5WRIdY,144
|
2
|
+
most/_constrants.py,sha256=SlHKcBoXwe_sPzk8tdbb7lqhQz-Bfo__FhSoeFWodZE,217
|
3
|
+
most/api.py,sha256=_xqIj24dm1bINPy54zZ4Xqp5V8DQa05TN3zocuZW7io,17895
|
4
|
+
most/async_api.py,sha256=7uymQ643SVzFU-j_iR-8MqT5wkNad9v54nR9vT87QoY,18968
|
5
|
+
most/async_trainer_api.py,sha256=99rED8RjnOn8VezeEgrTgoVfQrO7DdmOE2Jajumno2g,1052
|
6
|
+
most/score_calculation.py,sha256=1XU1LfIH5LSCwAbAaKkr-EjH5qOTXrJKOUvhCCawka4,1054
|
7
|
+
most/trainer_api.py,sha256=ZwOv4mhROfY97n6i7IY_ZpafsuNRazOqMBAf2dh708k,992
|
8
|
+
most/types.py,sha256=-voecFH0E4ScMHu0DR_2S6XNdlkGuVJgy0Z1Oui2iM8,4578
|
9
|
+
most_client-1.0.30.dist-info/METADATA,sha256=f-x0jH2hHk7uYN2DMTohvxYEJFRczhoH0eQVHBSn8dg,1027
|
10
|
+
most_client-1.0.30.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
11
|
+
most_client-1.0.30.dist-info/top_level.txt,sha256=2g5fk02LKkM1hV3pVVti_LQ60TToLBcR2zQ3JEKGVk8,5
|
12
|
+
most_client-1.0.30.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
13
|
+
most_client-1.0.30.dist-info/RECORD,,
|
@@ -1,11 +0,0 @@
|
|
1
|
-
most/__init__.py,sha256=62uFFeM_1VVR83K3bTYWK3PEoqnmFCy9aWYerQ6U4Ds,67
|
2
|
-
most/_constrants.py,sha256=SlHKcBoXwe_sPzk8tdbb7lqhQz-Bfo__FhSoeFWodZE,217
|
3
|
-
most/api.py,sha256=_xqIj24dm1bINPy54zZ4Xqp5V8DQa05TN3zocuZW7io,17895
|
4
|
-
most/async_api.py,sha256=bt3natlTmqbafo1hJ5TUkxMssz4l8o6_sLNRbTAoU4A,18912
|
5
|
-
most/score_calculation.py,sha256=1XU1LfIH5LSCwAbAaKkr-EjH5qOTXrJKOUvhCCawka4,1054
|
6
|
-
most/types.py,sha256=ukhG70TPfSR9Xp11mQpXsaSBvKcbFW64PbXiUWtgG5E,3771
|
7
|
-
most_client-1.0.28.dist-info/METADATA,sha256=fGmSQqhd578FQN8SgZV0dY6A4Wo7GXV6_e5HB0b15i8,1027
|
8
|
-
most_client-1.0.28.dist-info/WHEEL,sha256=CmyFI0kx5cdEMTLiONQRbGQwjIoR1aIYB7eCAQ4KPJ0,91
|
9
|
-
most_client-1.0.28.dist-info/top_level.txt,sha256=2g5fk02LKkM1hV3pVVti_LQ60TToLBcR2zQ3JEKGVk8,5
|
10
|
-
most_client-1.0.28.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
11
|
-
most_client-1.0.28.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|