most-client 1.0.29__tar.gz → 1.0.30__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: most-client
3
- Version: 1.0.29
3
+ Version: 1.0.30
4
4
  Summary: Most AI API for https://the-most.ai
5
5
  Home-page: https://github.com/the-most-ai/most-client
6
6
  Author: George Kasparyants
@@ -0,0 +1,4 @@
1
+ from .api import MostClient
2
+ from .async_api import AsyncMostClient
3
+ from .trainer_api import Trainer
4
+ from .async_trainer_api import AsyncTrainer
@@ -0,0 +1,26 @@
1
+ from typing import List
2
+
3
+ from . import AsyncMostClient
4
+ from .types import HumanFeedback
5
+
6
+
7
class AsyncTrainer(object):
    """Async helper that uploads and evaluates ``HumanFeedback`` for one model.

    All requests go through the wrapped ``AsyncMostClient`` to the
    ``/{client_id}/model/{model_id}/data`` endpoint, so the client must be
    bound to a concrete ``model_id``.
    """

    def __init__(self, client: AsyncMostClient):
        """Bind the trainer to ``client``.

        Raises:
            RuntimeError: if ``client.model_id`` is None — the data endpoint is
                addressed by model id, so there is nothing to train against.
        """
        super().__init__()
        self.client = client
        if self.client.model_id is None:
            raise RuntimeError("Train must be implemented for stable model_id")

    def _data_url(self) -> str:
        """Return the path of the per-model feedback data endpoint."""
        return f"/{self.client.client_id}/model/{self.client.model_id}/data"

    async def fit(self, data: List[HumanFeedback]) -> "AsyncTrainer":
        """Upload ``data`` (via PUT) as feedback for the bound model.

        Returns ``self`` so calls can be chained, e.g.
        ``await (await trainer.fit(data)).evaluate(preds)``.
        """
        # NOTE(review): the response body/status is not inspected here; whether
        # an HTTP failure surfaces depends on ``client.put`` — confirm it raises.
        await self.client.put(self._data_url(),
                              json={"data": [hf.to_dict() for hf in data]})
        return self

    async def evaluate(self, data: List[HumanFeedback]) -> float:
        """Score ``data`` (treated as predictions) against the stored feedback.

        The feedback currently stored on the server is fetched and used as
        ground truth; entries are compared with
        ``HumanFeedback.calculate_accuracy``.
        """
        gt_data = await self.get_data_points()
        return HumanFeedback.calculate_accuracy(data, gt_data)

    async def get_data_points(self) -> List[HumanFeedback]:
        """Fetch the feedback stored for the bound model as ``HumanFeedback``."""
        resp = await self.client.get(self._data_url())
        # Points may be audio or text, so keep the name type-neutral.
        raw_points = resp.json()
        return self.client.retort.load(raw_points, List[HumanFeedback])
@@ -0,0 +1,25 @@
1
+ from typing import List
2
+ from .api import MostClient
3
+ from .types import HumanFeedback
4
+
5
+
6
class Trainer(object):
    """Synchronous helper that uploads and evaluates ``HumanFeedback`` for one model.

    All requests go through the wrapped ``MostClient`` to the
    ``/{client_id}/model/{model_id}/data`` endpoint, so the client must be
    bound to a concrete ``model_id``.
    """

    def __init__(self, client: MostClient):
        """Bind the trainer to ``client``.

        Raises:
            RuntimeError: if ``client.model_id`` is None — the data endpoint is
                addressed by model id, so there is nothing to train against.
        """
        super().__init__()
        self.client = client
        if self.client.model_id is None:
            raise RuntimeError("Train must be implemented for stable model_id")

    def _data_url(self) -> str:
        """Return the path of the per-model feedback data endpoint."""
        return f"/{self.client.client_id}/model/{self.client.model_id}/data"

    def fit(self, data: List[HumanFeedback]) -> "Trainer":
        """Upload ``data`` (via PUT) as feedback for the bound model.

        Returns ``self`` so calls can be chained, e.g.
        ``trainer.fit(data).evaluate(preds)``.
        """
        # NOTE(review): the response body/status is not inspected here; whether
        # an HTTP failure surfaces depends on ``client.put`` — confirm it raises.
        self.client.put(self._data_url(),
                        json={"data": [hf.to_dict() for hf in data]})
        return self

    def evaluate(self, data: List[HumanFeedback]) -> float:
        """Score ``data`` (treated as predictions) against the stored feedback.

        The feedback currently stored on the server is fetched and used as
        ground truth; entries are compared with
        ``HumanFeedback.calculate_accuracy``.
        """
        gt_data = self.get_data_points()
        return HumanFeedback.calculate_accuracy(data, gt_data)

    def get_data_points(self) -> List[HumanFeedback]:
        """Fetch the feedback stored for the bound model as ``HumanFeedback``."""
        resp = self.client.get(self._data_url())
        # Points may be audio or text, so keep the name type-neutral.
        raw_points = resp.json()
        return self.client.retort.load(raw_points, List[HumanFeedback])
@@ -1,7 +1,7 @@
1
1
  import re
2
2
  from dataclasses import dataclass
3
3
  from datetime import datetime
4
- from typing import Dict, List, Literal, Optional
4
+ from typing import Dict, List, Literal, Optional, Union
5
5
 
6
6
  from dataclasses_json import DataClassJsonMixin, dataclass_json
7
7
 
@@ -151,6 +151,28 @@ class SearchParams(DataClassJsonMixin):
151
151
  must_not: List[StoredInfoCondition | ResultsCondition]
152
152
 
153
153
 
154
@dataclass_json
@dataclass
class HumanFeedback(DataClassJsonMixin):
    """One feedback record: a ``score`` for a column/subcolumn of a data point."""

    data_point_id: str
    data_point_type: Literal["audio", "text"]
    column_name: str
    subcolumn_name: str
    score: int
    description: str = ""

    @classmethod
    def calculate_accuracy(cls,
                           preds: List["HumanFeedback"],
                           gt: List["HumanFeedback"]) -> float:
        """Return the fraction of shared entries whose scores agree.

        Entries of ``preds`` and ``gt`` are paired by
        ``(data_point_id, column_name, subcolumn_name)``; keys present in only
        one list are ignored. If a key occurs more than once in a list, the
        last occurrence wins.

        Returns:
            A value in ``[0.0, 1.0]``; ``0.0`` when the lists share no keys
            (e.g. either list is empty) instead of dividing by zero.
        """
        def _scores_by_key(items: List["HumanFeedback"]) -> Dict[tuple, int]:
            return {(fb.data_point_id, fb.column_name, fb.subcolumn_name): fb.score
                    for fb in items}

        pred_scores = _scores_by_key(preds)
        gt_scores = _scores_by_key(gt)
        # dict key views support set intersection directly.
        common_keys = pred_scores.keys() & gt_scores.keys()
        if not common_keys:
            return 0.0
        matches = sum(pred_scores[key] == gt_scores[key] for key in common_keys)
        return matches / len(common_keys)
174
+
175
+
154
176
  def is_valid_objectid(oid: str) -> bool:
155
177
  """
156
178
  Check if a given string is a valid MongoDB ObjectId (24-character hex).
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: most-client
3
- Version: 1.0.29
3
+ Version: 1.0.30
4
4
  Summary: Most AI API for https://the-most.ai
5
5
  Home-page: https://github.com/the-most-ai/most-client
6
6
  Author: George Kasparyants
@@ -6,7 +6,9 @@ most/__init__.py
6
6
  most/_constrants.py
7
7
  most/api.py
8
8
  most/async_api.py
9
+ most/async_trainer_api.py
9
10
  most/score_calculation.py
11
+ most/trainer_api.py
10
12
  most/types.py
11
13
  most_client.egg-info/PKG-INFO
12
14
  most_client.egg-info/SOURCES.txt
@@ -8,7 +8,7 @@ with open('requirements.txt', 'r') as f:
8
8
 
9
9
  setup(
10
10
  name='most-client',
11
- version='1.0.29',
11
+ version='1.0.30',
12
12
  python_requires=f'>=3.6',
13
13
  description='Most AI API for https://the-most.ai',
14
14
  url='https://github.com/the-most-ai/most-client',
@@ -1,2 +0,0 @@
1
- from .api import MostClient
2
- from .async_api import AsyncMostClient
File without changes
File without changes
File without changes
File without changes