most-client 1.0.36__py3-none-any.whl → 1.0.38__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- most/__init__.py +3 -0
- most/api.py +16 -29
- most/async_api.py +21 -29
- most/async_searcher.py +57 -0
- most/score_calculation.py +53 -1
- most/search_types.py +143 -0
- most/searcher.py +57 -0
- most/types.py +1 -28
- {most_client-1.0.36.dist-info → most_client-1.0.38.dist-info}/METADATA +2 -1
- most_client-1.0.38.dist-info/RECORD +16 -0
- {most_client-1.0.36.dist-info → most_client-1.0.38.dist-info}/WHEEL +1 -1
- most_client-1.0.36.dist-info/RECORD +0 -13
- {most_client-1.0.36.dist-info → most_client-1.0.38.dist-info}/top_level.txt +0 -0
- {most_client-1.0.36.dist-info → most_client-1.0.38.dist-info}/zip-safe +0 -0
most/__init__.py
CHANGED
@@ -2,3 +2,6 @@ from .api import MostClient
|
|
2
2
|
from .async_api import AsyncMostClient
|
3
3
|
from .trainer_api import Trainer
|
4
4
|
from .async_trainer_api import AsyncTrainer
|
5
|
+
from .searcher import MostSearcher
|
6
|
+
from .async_searcher import AsyncMostClient
|
7
|
+
from .search_types import SearchParams, IDCondition, ChannelsCondition, DurationCondition, ResultsCondition, StoredInfoCondition
|
most/api.py
CHANGED
@@ -18,14 +18,14 @@ from most.types import (
|
|
18
18
|
Script,
|
19
19
|
StoredAudioData,
|
20
20
|
Text,
|
21
|
-
is_valid_id,
|
21
|
+
is_valid_id, ScriptScoreMapping, Dialog, Usage,
|
22
22
|
)
|
23
23
|
|
24
24
|
|
25
25
|
class MostClient(object):
|
26
26
|
retort = Retort(recipe=[
|
27
|
-
loader(int, lambda x: int(x)),
|
28
|
-
loader(float, lambda x: float(x)),
|
27
|
+
loader(int, lambda x: x if isinstance(x, int) else int(x)),
|
28
|
+
loader(float, lambda x: x if isinstance(x, float) else float(x)),
|
29
29
|
loader(datetime, lambda x: datetime.fromtimestamp(x).astimezone(tz=timezone.utc) if isinstance(x, (int, float)) else datetime.fromisoformat(x)),
|
30
30
|
],)
|
31
31
|
|
@@ -264,7 +264,7 @@ class MostClient(object):
|
|
264
264
|
params={"overwrite": overwrite})
|
265
265
|
result = self.retort.load(resp.json(), Result)
|
266
266
|
if modify_scores:
|
267
|
-
result = self.
|
267
|
+
result = self.get_score_modifier().modify(result)
|
268
268
|
return result
|
269
269
|
|
270
270
|
def apply_on_text(self, text_id,
|
@@ -280,7 +280,7 @@ class MostClient(object):
|
|
280
280
|
params={"overwrite": overwrite})
|
281
281
|
result = self.retort.load(resp.json(), Result)
|
282
282
|
if modify_scores:
|
283
|
-
result = self.
|
283
|
+
result = self.get_score_modifier().modify(result)
|
284
284
|
return result
|
285
285
|
|
286
286
|
def transcribe_later(self, audio_id,
|
@@ -308,7 +308,7 @@ class MostClient(object):
|
|
308
308
|
params={"overwrite": overwrite})
|
309
309
|
result = self.retort.load(resp.json(), Result)
|
310
310
|
if modify_scores:
|
311
|
-
result = self.
|
311
|
+
result = self.get_score_modifier().modify(result)
|
312
312
|
return result
|
313
313
|
|
314
314
|
def apply_on_text_later(self, text_id,
|
@@ -324,7 +324,7 @@ class MostClient(object):
|
|
324
324
|
params={"overwrite": overwrite})
|
325
325
|
result = self.retort.load(resp.json(), Result)
|
326
326
|
if modify_scores:
|
327
|
-
result = self.
|
327
|
+
result = self.get_score_modifier().modify(result)
|
328
328
|
return result
|
329
329
|
|
330
330
|
def get_job_status(self, audio_id) -> JobStatus:
|
@@ -348,7 +348,7 @@ class MostClient(object):
|
|
348
348
|
resp = self.get(f"/{self.client_id}/audio/{audio_id}/model/{self.model_id}/results")
|
349
349
|
result = self.retort.load(resp.json(), Result)
|
350
350
|
if modify_scores:
|
351
|
-
result = self.
|
351
|
+
result = self.get_score_modifier().modify(result)
|
352
352
|
return result
|
353
353
|
|
354
354
|
def fetch_text(self, audio_id: str) -> Result:
|
@@ -384,7 +384,8 @@ class MostClient(object):
|
|
384
384
|
|
385
385
|
def export(self, audio_ids: List[str],
|
386
386
|
aggregated_by: Optional[str] = None,
|
387
|
-
aggregation_title: Optional[str] = None
|
387
|
+
aggregation_title: Optional[str] = None,
|
388
|
+
modify_scores: bool = False) -> str:
|
388
389
|
if aggregation_title is None:
|
389
390
|
aggregation_title = aggregated_by
|
390
391
|
|
@@ -394,12 +395,13 @@ class MostClient(object):
|
|
394
395
|
resp = self.get(f"/{self.client_id}/model/{self.model_id}/export",
|
395
396
|
params={'audio_ids': ','.join(audio_ids),
|
396
397
|
'aggregated_by': aggregated_by,
|
397
|
-
'aggregation_title': aggregation_title
|
398
|
+
'aggregation_title': aggregation_title,
|
399
|
+
'modify_scores': modify_scores})
|
398
400
|
return resp.url
|
399
401
|
|
400
402
|
def store_info(self,
|
401
403
|
audio_id: str,
|
402
|
-
data: Dict[str, str]):
|
404
|
+
data: Dict[str, Union[str, int, float]]) -> StoredAudioData:
|
403
405
|
if not is_valid_id(audio_id):
|
404
406
|
raise RuntimeError("Please use valid audio_id. [try audio.id from list_audios()]")
|
405
407
|
|
@@ -407,14 +409,14 @@ class MostClient(object):
|
|
407
409
|
json={
|
408
410
|
"data": data,
|
409
411
|
})
|
410
|
-
return
|
412
|
+
return StoredAudioData.from_dict(resp.json())
|
411
413
|
|
412
|
-
def fetch_info(self, audio_id: str) ->
|
414
|
+
def fetch_info(self, audio_id: str) -> StoredAudioData:
|
413
415
|
if not is_valid_id(audio_id):
|
414
416
|
raise RuntimeError("Please use valid audio_id. [try audio.id from list_audios()]")
|
415
417
|
|
416
418
|
resp = self.get(f"/{self.client_id}/audio/{audio_id}/info")
|
417
|
-
return
|
419
|
+
return StoredAudioData.from_dict(resp.json())
|
418
420
|
|
419
421
|
def __call__(self, audio_path: Path,
|
420
422
|
modify_scores: bool = False) -> Result:
|
@@ -446,21 +448,6 @@ class MostClient(object):
|
|
446
448
|
raise RuntimeError("Audio can't be indexed")
|
447
449
|
return None
|
448
450
|
|
449
|
-
def search(self,
|
450
|
-
query: str,
|
451
|
-
filter: SearchParams,
|
452
|
-
limit: int = 10) -> List[Audio]:
|
453
|
-
resp = self.post(f"/{self.client_id}/model/{self.model_id}/search",
|
454
|
-
json={
|
455
|
-
"query": query,
|
456
|
-
"filter": filter.to_dict(),
|
457
|
-
"limit": limit,
|
458
|
-
})
|
459
|
-
if resp.status_code >= 400:
|
460
|
-
raise RuntimeError("Audio can't be indexed")
|
461
|
-
audio_list = resp.json()
|
462
|
-
return self.retort.load(audio_list, List[Audio])
|
463
|
-
|
464
451
|
def get_usage(self,
|
465
452
|
start_dt: datetime,
|
466
453
|
end_dt: datetime):
|
most/async_api.py
CHANGED
@@ -18,14 +18,14 @@ from most.types import (
|
|
18
18
|
Script,
|
19
19
|
StoredAudioData,
|
20
20
|
Text,
|
21
|
-
is_valid_id,
|
21
|
+
is_valid_id, ScriptScoreMapping, Dialog, Usage,
|
22
22
|
)
|
23
23
|
|
24
24
|
|
25
25
|
class AsyncMostClient(object):
|
26
26
|
retort = Retort(recipe=[
|
27
|
-
loader(int, lambda x: int(x)),
|
28
|
-
loader(float, lambda x: float(x)),
|
27
|
+
loader(int, lambda x: x if isinstance(x, int) else int(x)),
|
28
|
+
loader(float, lambda x: x if isinstance(x, float) else float(x)),
|
29
29
|
loader(datetime, lambda x: datetime.fromtimestamp(x).astimezone(tz=timezone.utc) if isinstance(x, (int, float)) else datetime.fromisoformat(x)),
|
30
30
|
])
|
31
31
|
|
@@ -274,7 +274,8 @@ class AsyncMostClient(object):
|
|
274
274
|
params={"overwrite": overwrite})
|
275
275
|
result = self.retort.load(resp.json(), Result)
|
276
276
|
if modify_scores:
|
277
|
-
|
277
|
+
score_modifier = await self.get_score_modifier()
|
278
|
+
result = score_modifier.modify(result)
|
278
279
|
return result
|
279
280
|
|
280
281
|
async def apply_on_text(self, text_id,
|
@@ -290,7 +291,8 @@ class AsyncMostClient(object):
|
|
290
291
|
params={"overwrite": overwrite})
|
291
292
|
result = self.retort.load(resp.json(), Result)
|
292
293
|
if modify_scores:
|
293
|
-
|
294
|
+
score_modifier = await self.get_score_modifier()
|
295
|
+
result = score_modifier.modify(result)
|
294
296
|
return result
|
295
297
|
|
296
298
|
async def transcribe_later(self, audio_id,
|
@@ -318,7 +320,8 @@ class AsyncMostClient(object):
|
|
318
320
|
params={"overwrite": overwrite})
|
319
321
|
result = self.retort.load(resp.json(), Result)
|
320
322
|
if modify_scores:
|
321
|
-
|
323
|
+
score_modifier = await self.get_score_modifier()
|
324
|
+
result = score_modifier.modify(result)
|
322
325
|
return result
|
323
326
|
|
324
327
|
async def apply_on_text_later(self, text_id,
|
@@ -334,7 +337,8 @@ class AsyncMostClient(object):
|
|
334
337
|
params={"overwrite": overwrite})
|
335
338
|
result = self.retort.load(resp.json(), Result)
|
336
339
|
if modify_scores:
|
337
|
-
|
340
|
+
score_modifier = await self.get_score_modifier()
|
341
|
+
result = score_modifier.modify(result)
|
338
342
|
return result
|
339
343
|
|
340
344
|
async def get_job_status(self, audio_id) -> JobStatus:
|
@@ -358,7 +362,8 @@ class AsyncMostClient(object):
|
|
358
362
|
resp = await self.get(f"/{self.client_id}/audio/{audio_id}/model/{self.model_id}/results")
|
359
363
|
result = self.retort.load(resp.json(), Result)
|
360
364
|
if modify_scores:
|
361
|
-
|
365
|
+
score_modifier = await self.get_score_modifier()
|
366
|
+
result = score_modifier.modify(result)
|
362
367
|
return result
|
363
368
|
|
364
369
|
async def fetch_text(self, audio_id) -> Result:
|
@@ -394,7 +399,8 @@ class AsyncMostClient(object):
|
|
394
399
|
|
395
400
|
async def export(self, audio_ids: List[str],
|
396
401
|
aggregated_by: Optional[str] = None,
|
397
|
-
aggregation_title: Optional[str] = None
|
402
|
+
aggregation_title: Optional[str] = None,
|
403
|
+
modify_scores: bool = False) -> str:
|
398
404
|
if aggregation_title is None:
|
399
405
|
aggregation_title = aggregated_by
|
400
406
|
|
@@ -404,23 +410,24 @@ class AsyncMostClient(object):
|
|
404
410
|
resp = await self.get(f"/{self.client_id}/model/{self.model_id}/export",
|
405
411
|
params={'audio_ids': ','.join(audio_ids),
|
406
412
|
"aggregated_by": aggregated_by,
|
407
|
-
"aggregation_title": aggregation_title
|
413
|
+
"aggregation_title": aggregation_title,
|
414
|
+
"modify_scores": modify_scores})
|
408
415
|
return resp.next_request.url
|
409
416
|
|
410
417
|
async def store_info(self,
|
411
418
|
audio_id: str,
|
412
|
-
data: Dict[str, str]):
|
419
|
+
data: Dict[str, Union[str, int, float]]) -> StoredAudioData:
|
413
420
|
resp = await self.post(f"/{self.client_id}/audio/{audio_id}/info",
|
414
421
|
json={
|
415
422
|
"data": data,
|
416
423
|
})
|
417
|
-
return
|
424
|
+
return StoredAudioData.from_dict(resp.json())
|
418
425
|
|
419
|
-
async def fetch_info(self, audio_id: str) ->
|
426
|
+
async def fetch_info(self, audio_id: str) -> StoredAudioData:
|
420
427
|
if not is_valid_id(audio_id):
|
421
428
|
raise RuntimeError("Please use valid audio_id. [try audio.id from list_audios()]")
|
422
429
|
resp = await self.get(f"/{self.client_id}/audio/{audio_id}/info")
|
423
|
-
return
|
430
|
+
return StoredAudioData.from_dict(resp.json())
|
424
431
|
|
425
432
|
async def __call__(self, audio_path: Path,
|
426
433
|
modify_scores: bool = False) -> Result:
|
@@ -452,21 +459,6 @@ class AsyncMostClient(object):
|
|
452
459
|
raise RuntimeError("Audio can't be indexed")
|
453
460
|
return None
|
454
461
|
|
455
|
-
async def search(self,
|
456
|
-
query: str,
|
457
|
-
filter: SearchParams,
|
458
|
-
limit: int = 10) -> List[Audio]:
|
459
|
-
resp = await self.post(f"/{self.client_id}/model/{self.model_id}/search",
|
460
|
-
json={
|
461
|
-
"query": query,
|
462
|
-
"filter": filter.to_dict(),
|
463
|
-
"limit": limit,
|
464
|
-
})
|
465
|
-
if resp.status_code >= 400:
|
466
|
-
raise RuntimeError("Audio can't be indexed")
|
467
|
-
audio_list = resp.json()
|
468
|
-
return self.retort.load(audio_list, List[Audio])
|
469
|
-
|
470
462
|
async def get_usage(self,
|
471
463
|
start_dt: datetime,
|
472
464
|
end_dt: datetime):
|
most/async_searcher.py
ADDED
@@ -0,0 +1,57 @@
|
|
1
|
+
from typing import List, Literal, Optional
|
2
|
+
from . import AsyncMostClient
|
3
|
+
from .types import Audio
|
4
|
+
from .search_types import SearchParams
|
5
|
+
|
6
|
+
|
7
|
+
class AsyncMostSearcher(object):
|
8
|
+
def __init__(self, client: AsyncMostClient,
|
9
|
+
data_source: Literal["text", "audio"]):
|
10
|
+
self.client = client
|
11
|
+
self.data_source = data_source
|
12
|
+
|
13
|
+
async def count(self,
|
14
|
+
filter: Optional[SearchParams] = None):
|
15
|
+
if filter is None:
|
16
|
+
filter = SearchParams()
|
17
|
+
|
18
|
+
resp = await self.client.get(f"/{self.client.client_id}/{self.data_source}/count",
|
19
|
+
params={
|
20
|
+
"filter": filter.to_json(),
|
21
|
+
})
|
22
|
+
if resp.status_code >= 400:
|
23
|
+
raise RuntimeError("Audio can't be indexed")
|
24
|
+
return resp.json()
|
25
|
+
|
26
|
+
|
27
|
+
async def distinct(self,
|
28
|
+
key: str,
|
29
|
+
filter: Optional[SearchParams] = None) -> List[str]:
|
30
|
+
"""
|
31
|
+
Distinct values of key.
|
32
|
+
:param key: key should be stored in info (fetch_info, store_info)
|
33
|
+
:return:
|
34
|
+
"""
|
35
|
+
if filter is None:
|
36
|
+
filter = SearchParams()
|
37
|
+
resp = await self.client.get(f"/{self.client.client_id}/{self.data_source}/distinct",
|
38
|
+
params={"filter": filter.to_json(),
|
39
|
+
"key": key})
|
40
|
+
if resp.status_code >= 400:
|
41
|
+
raise RuntimeError("Key is not valid")
|
42
|
+
return resp.json()
|
43
|
+
|
44
|
+
async def search(self,
|
45
|
+
filter: Optional[SearchParams] = None,
|
46
|
+
limit: int = 10) -> List[Audio]:
|
47
|
+
if filter is None:
|
48
|
+
filter = SearchParams()
|
49
|
+
resp = await self.client.get(f"/{self.client.client_id}/{self.data_source}/search",
|
50
|
+
params={
|
51
|
+
"filter": filter.to_json(),
|
52
|
+
"limit": limit,
|
53
|
+
})
|
54
|
+
if resp.status_code >= 400:
|
55
|
+
raise RuntimeError("Audio can't be indexed")
|
56
|
+
audio_list = resp.json()
|
57
|
+
return self.client.retort.load(audio_list, List[Audio])
|
most/score_calculation.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
from typing import Dict, Tuple, List, Optional
|
1
|
+
from typing import Dict, Tuple, List, Optional, Literal
|
2
2
|
from dataclasses_json import dataclass_json, DataClassJsonMixin
|
3
3
|
from dataclasses import dataclass, replace
|
4
4
|
from .types import Result, ScriptScoreMapping
|
@@ -25,3 +25,55 @@ class ScoreCalculation(DataClassJsonMixin):
|
|
25
25
|
subcolumn_result.score)
|
26
26
|
|
27
27
|
return result
|
28
|
+
|
29
|
+
def unmodify(self, result: Optional[Result]):
|
30
|
+
score_mapping = {
|
31
|
+
(sm.column, sm.subcolumn, sm.to_score): sm.from_score
|
32
|
+
for sm in self.score_mapping
|
33
|
+
}
|
34
|
+
if result is None:
|
35
|
+
return None
|
36
|
+
result = replace(result)
|
37
|
+
for column_result in result.results:
|
38
|
+
for subcolumn_result in column_result.subcolumns:
|
39
|
+
subcolumn_result.score = score_mapping.get((column_result.name,
|
40
|
+
subcolumn_result.name,
|
41
|
+
subcolumn_result.score),
|
42
|
+
subcolumn_result.score)
|
43
|
+
|
44
|
+
return result
|
45
|
+
|
46
|
+
def modify_single(self,
|
47
|
+
column: str, subcolumn: str,
|
48
|
+
from_score: int):
|
49
|
+
for sm in self.score_mapping:
|
50
|
+
if sm.column == column and sm.subcolumn == subcolumn and sm.from_score == from_score:
|
51
|
+
return sm.to_score
|
52
|
+
|
53
|
+
def unmodify_single(self,
|
54
|
+
column: str, subcolumn: str,
|
55
|
+
to_score: int,
|
56
|
+
bound: Literal["strict", "upper", "lower"] = "strict"):
|
57
|
+
upper_from_score = None
|
58
|
+
lower_from_score = None
|
59
|
+
|
60
|
+
for sm in self.score_mapping:
|
61
|
+
if sm.column == column and sm.subcolumn == subcolumn:
|
62
|
+
if sm.to_score == to_score:
|
63
|
+
return sm.from_score
|
64
|
+
|
65
|
+
if sm.to_score > to_score:
|
66
|
+
if upper_from_score is None or sm.to_score < upper_from_score[1]:
|
67
|
+
upper_from_score = (sm.from_score, sm.to_score)
|
68
|
+
elif sm.to_score < to_score:
|
69
|
+
if lower_from_score is None or sm.to_score > lower_from_score[1]:
|
70
|
+
lower_from_score = (sm.from_score, sm.to_score)
|
71
|
+
|
72
|
+
if bound == "strict":
|
73
|
+
return None
|
74
|
+
elif bound == "upper" and upper_from_score is not None:
|
75
|
+
return upper_from_score[0]
|
76
|
+
elif bound == "lower" and lower_from_score is not None:
|
77
|
+
return lower_from_score[0]
|
78
|
+
else:
|
79
|
+
return None
|
most/search_types.py
ADDED
@@ -0,0 +1,143 @@
|
|
1
|
+
from dataclasses import dataclass, field
|
2
|
+
from typing import List, Optional
|
3
|
+
|
4
|
+
from bson import ObjectId
|
5
|
+
from dataclasses_json import DataClassJsonMixin, dataclass_json
|
6
|
+
|
7
|
+
|
8
|
+
@dataclass_json
|
9
|
+
@dataclass
|
10
|
+
class IDCondition(DataClassJsonMixin):
|
11
|
+
equal: Optional[ObjectId] = None
|
12
|
+
in_set: Optional[List[ObjectId]] = None
|
13
|
+
greater_than: Optional[ObjectId] = None
|
14
|
+
less_than: Optional[ObjectId] = None
|
15
|
+
|
16
|
+
|
17
|
+
@dataclass_json
|
18
|
+
@dataclass
|
19
|
+
class ChannelsCondition(DataClassJsonMixin):
|
20
|
+
equal: Optional[int] = None
|
21
|
+
|
22
|
+
|
23
|
+
@dataclass_json
|
24
|
+
@dataclass
|
25
|
+
class DurationCondition(DataClassJsonMixin):
|
26
|
+
greater_than: Optional[int] = None
|
27
|
+
less_than: Optional[int] = None
|
28
|
+
|
29
|
+
|
30
|
+
@dataclass_json
|
31
|
+
@dataclass
|
32
|
+
class StoredInfoCondition(DataClassJsonMixin):
|
33
|
+
key: str
|
34
|
+
match: Optional[int | str | float] = None
|
35
|
+
starts_with: Optional[str] = None
|
36
|
+
ends_with: Optional[str] = None
|
37
|
+
greater_than: Optional[int | str | float] = None
|
38
|
+
less_than: Optional[int | str | float] = None
|
39
|
+
|
40
|
+
|
41
|
+
@dataclass_json
|
42
|
+
@dataclass
|
43
|
+
class ResultsCondition(DataClassJsonMixin):
|
44
|
+
column_idx: int
|
45
|
+
subcolumn_idx: int
|
46
|
+
model_id: str
|
47
|
+
score_equal: Optional[int] = None
|
48
|
+
score_in_set: Optional[List[int]] = None
|
49
|
+
score_greater_than: Optional[int] = None
|
50
|
+
score_less_than: Optional[int] = None
|
51
|
+
|
52
|
+
def create_from(self, client,
|
53
|
+
column: str, subcolumn: str,
|
54
|
+
score_equal: Optional[int] = None,
|
55
|
+
score_in_set: Optional[List[int]] = None,
|
56
|
+
score_greater_than: Optional[int] = None,
|
57
|
+
score_less_than: Optional[int] = None,
|
58
|
+
modified_scores: bool = False) -> 'ResultsCondition':
|
59
|
+
from .api import MostClient
|
60
|
+
client: MostClient
|
61
|
+
script = client.get_model_script()
|
62
|
+
column_idx = [column.name for column in script.columns].index(column)
|
63
|
+
subcolumn_idx = script.columns[column_idx].subcolumns.index(subcolumn)
|
64
|
+
|
65
|
+
if modified_scores:
|
66
|
+
score_modifier = client.get_score_modifier()
|
67
|
+
if score_equal is not None:
|
68
|
+
score_equal = score_modifier.unmodify_single(column, subcolumn,
|
69
|
+
score_equal,
|
70
|
+
bound="strict")
|
71
|
+
if score_in_set is not None:
|
72
|
+
score_in_set = [score_modifier.unmodify_single(column, subcolumn,
|
73
|
+
score,
|
74
|
+
bound="strict")
|
75
|
+
for score in score_in_set]
|
76
|
+
if score_greater_than is not None:
|
77
|
+
score_greater_than = score_modifier.unmodify_single(column, subcolumn,
|
78
|
+
score_greater_than,
|
79
|
+
bound="upper")
|
80
|
+
|
81
|
+
if score_less_than is not None:
|
82
|
+
score_less_than = score_modifier.unmodify_single(column, subcolumn,
|
83
|
+
score_less_than,
|
84
|
+
bound="lower")
|
85
|
+
|
86
|
+
return ResultsCondition(model_id=client.model_id,
|
87
|
+
column_idx=column_idx,
|
88
|
+
subcolumn_idx=subcolumn_idx,
|
89
|
+
score_equal=score_equal,
|
90
|
+
score_in_set=score_in_set,
|
91
|
+
score_greater_than=score_greater_than,
|
92
|
+
score_less_than=score_less_than)
|
93
|
+
|
94
|
+
async def acreate_from(self, client,
|
95
|
+
column: str, subcolumn: str,
|
96
|
+
score_equal: Optional[int] = None,
|
97
|
+
score_in_set: Optional[List[int]] = None,
|
98
|
+
score_greater_than: Optional[int] = None,
|
99
|
+
score_less_than: Optional[int] = None,
|
100
|
+
modified_scores: bool = False) -> 'ResultsCondition':
|
101
|
+
from .async_api import AsyncMostClient
|
102
|
+
client: AsyncMostClient
|
103
|
+
script = await client.get_model_script()
|
104
|
+
column_idx = [column.name for column in script.columns].index(column)
|
105
|
+
subcolumn_idx = script.columns[column_idx].subcolumns.index(subcolumn)
|
106
|
+
|
107
|
+
if modified_scores:
|
108
|
+
score_modifier = await client.get_score_modifier()
|
109
|
+
if score_equal is not None:
|
110
|
+
score_equal = score_modifier.unmodify_single(column, subcolumn,
|
111
|
+
score_equal,
|
112
|
+
bound="strict")
|
113
|
+
if score_in_set is not None:
|
114
|
+
score_in_set = [score_modifier.unmodify_single(column, subcolumn,
|
115
|
+
score,
|
116
|
+
bound="strict")
|
117
|
+
for score in score_in_set]
|
118
|
+
if score_greater_than is not None:
|
119
|
+
score_greater_than = score_modifier.unmodify_single(column, subcolumn,
|
120
|
+
score_greater_than,
|
121
|
+
bound="upper")
|
122
|
+
|
123
|
+
if score_less_than is not None:
|
124
|
+
score_less_than = score_modifier.unmodify_single(column, subcolumn,
|
125
|
+
score_less_than,
|
126
|
+
bound="lower")
|
127
|
+
|
128
|
+
return ResultsCondition(model_id=client.model_id,
|
129
|
+
column_idx=column_idx,
|
130
|
+
subcolumn_idx=subcolumn_idx,
|
131
|
+
score_equal=score_equal,
|
132
|
+
score_in_set=score_in_set,
|
133
|
+
score_greater_than=score_greater_than,
|
134
|
+
score_less_than=score_less_than)
|
135
|
+
|
136
|
+
|
137
|
+
@dataclass_json
|
138
|
+
@dataclass
|
139
|
+
class SearchParams(DataClassJsonMixin):
|
140
|
+
must: List[StoredInfoCondition | ResultsCondition | DurationCondition | ChannelsCondition | IDCondition ] = field(default_factory=list)
|
141
|
+
should: List[StoredInfoCondition | ResultsCondition | DurationCondition | ChannelsCondition | IDCondition ] = field(default_factory=list)
|
142
|
+
must_not: List[StoredInfoCondition | ResultsCondition | DurationCondition | ChannelsCondition | IDCondition ] = field(default_factory=list)
|
143
|
+
should_not: List[StoredInfoCondition | ResultsCondition | DurationCondition | ChannelsCondition | IDCondition ] = field(default_factory=list)
|
most/searcher.py
ADDED
@@ -0,0 +1,57 @@
|
|
1
|
+
from typing import List, Literal, Optional
|
2
|
+
|
3
|
+
from .api import MostClient
|
4
|
+
from .search_types import SearchParams
|
5
|
+
from .types import Audio, Text
|
6
|
+
|
7
|
+
|
8
|
+
class MostSearcher(object):
|
9
|
+
def __init__(self, client: MostClient,
|
10
|
+
data_source: Literal["text", "audio"]):
|
11
|
+
self.client = client
|
12
|
+
self.data_source = data_source
|
13
|
+
|
14
|
+
def count(self,
|
15
|
+
filter: Optional[SearchParams] = None):
|
16
|
+
if filter is None:
|
17
|
+
filter = SearchParams()
|
18
|
+
|
19
|
+
resp = self.client.get(f"/{self.client.client_id}/{self.data_source}/count",
|
20
|
+
params={
|
21
|
+
"filter": filter.to_json(),
|
22
|
+
})
|
23
|
+
if resp.status_code >= 400:
|
24
|
+
raise RuntimeError("Audio can't be indexed")
|
25
|
+
return resp.json()
|
26
|
+
|
27
|
+
def distinct(self,
|
28
|
+
key: str,
|
29
|
+
filter: Optional[SearchParams] = None) -> List[str]:
|
30
|
+
"""
|
31
|
+
Distinct values of key.
|
32
|
+
:param key: key should be stored in info (fetch_info, store_info)
|
33
|
+
:return:
|
34
|
+
"""
|
35
|
+
if filter is None:
|
36
|
+
filter = SearchParams()
|
37
|
+
resp = self.client.get(f"/{self.client.client_id}/{self.data_source}/distinct",
|
38
|
+
params={"filter": filter.to_json(),
|
39
|
+
"key": key})
|
40
|
+
if resp.status_code >= 400:
|
41
|
+
raise RuntimeError("Key is not valid")
|
42
|
+
return resp.json()
|
43
|
+
|
44
|
+
def search(self,
|
45
|
+
filter: Optional[SearchParams] = None,
|
46
|
+
limit: int = 10) -> List[Audio | Text]:
|
47
|
+
if filter is None:
|
48
|
+
filter = SearchParams()
|
49
|
+
resp = self.client.get(f"/{self.client.client_id}/{self.data_source}/search",
|
50
|
+
params={
|
51
|
+
"filter": filter.to_json(),
|
52
|
+
"limit": limit,
|
53
|
+
})
|
54
|
+
if resp.status_code >= 400:
|
55
|
+
raise RuntimeError("Audio can't be indexed")
|
56
|
+
audio_list = resp.json()
|
57
|
+
return self.client.retort.load(audio_list, List[Audio])
|
most/types.py
CHANGED
@@ -2,7 +2,6 @@ import re
|
|
2
2
|
from dataclasses import dataclass
|
3
3
|
from datetime import datetime
|
4
4
|
from typing import Dict, List, Literal, Optional, Union
|
5
|
-
|
6
5
|
from dataclasses_json import DataClassJsonMixin, dataclass_json
|
7
6
|
|
8
7
|
|
@@ -10,7 +9,7 @@ from dataclasses_json import DataClassJsonMixin, dataclass_json
|
|
10
9
|
@dataclass
|
11
10
|
class StoredAudioData(DataClassJsonMixin):
|
12
11
|
id: str
|
13
|
-
data: Dict[str, str]
|
12
|
+
data: Dict[str, Union[str, int, float]]
|
14
13
|
|
15
14
|
|
16
15
|
@dataclass_json
|
@@ -125,32 +124,6 @@ class DialogResult(DataClassJsonMixin):
|
|
125
124
|
results: Optional[List[ColumnResult]] = None
|
126
125
|
|
127
126
|
|
128
|
-
@dataclass_json
|
129
|
-
@dataclass
|
130
|
-
class StoredInfoCondition(DataClassJsonMixin):
|
131
|
-
key: str
|
132
|
-
match: Optional[str] = None
|
133
|
-
starts_with: Optional[str] = None
|
134
|
-
ends_with: Optional[str] = None
|
135
|
-
|
136
|
-
|
137
|
-
@dataclass_json
|
138
|
-
@dataclass
|
139
|
-
class ResultsCondition(DataClassJsonMixin):
|
140
|
-
column: str
|
141
|
-
subcolumn: str
|
142
|
-
score_greater_than: Optional[int] = None
|
143
|
-
score_less_than: Optional[int] = None
|
144
|
-
|
145
|
-
|
146
|
-
@dataclass_json
|
147
|
-
@dataclass
|
148
|
-
class SearchParams(DataClassJsonMixin):
|
149
|
-
must: List[StoredInfoCondition | ResultsCondition]
|
150
|
-
should: List[StoredInfoCondition | ResultsCondition]
|
151
|
-
must_not: List[StoredInfoCondition | ResultsCondition]
|
152
|
-
|
153
|
-
|
154
127
|
@dataclass_json
|
155
128
|
@dataclass
|
156
129
|
class HumanFeedback(DataClassJsonMixin):
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: most-client
|
3
|
-
Version: 1.0.
|
3
|
+
Version: 1.0.38
|
4
4
|
Summary: Most AI API for https://the-most.ai
|
5
5
|
Home-page: https://github.com/the-most-ai/most-client
|
6
6
|
Author: George Kasparyants
|
@@ -27,6 +27,7 @@ Requires-Dist: tox
|
|
27
27
|
Requires-Dist: twine
|
28
28
|
Requires-Dist: httpx
|
29
29
|
Requires-Dist: pydub
|
30
|
+
Requires-Dist: bson
|
30
31
|
Dynamic: author
|
31
32
|
Dynamic: author-email
|
32
33
|
Dynamic: classifier
|
@@ -0,0 +1,16 @@
|
|
1
|
+
most/__init__.py,sha256=yoPKMxjZYAzrVqZ7l9rN0Skh0HPuttPX7ub0wlZMby4,351
|
2
|
+
most/_constrants.py,sha256=SlHKcBoXwe_sPzk8tdbb7lqhQz-Bfo__FhSoeFWodZE,217
|
3
|
+
most/api.py,sha256=HmJSAwLVuy7uaeVAM3y1nd-F-UDZ2kOVww_i2EY0Sjs,19145
|
4
|
+
most/async_api.py,sha256=Xyo1sGN4QstqIZh_4YDT6FStaXR0HRy6fbtbqc7uqK8,20562
|
5
|
+
most/async_searcher.py,sha256=C0zViW20K7OhKO1BzBZktTbMJYBBvor3uK6LAHZTxz0,2238
|
6
|
+
most/async_trainer_api.py,sha256=99rED8RjnOn8VezeEgrTgoVfQrO7DdmOE2Jajumno2g,1052
|
7
|
+
most/score_calculation.py,sha256=vLtGqXrR43xZhGjrH5dpQZfWX1q3s74LvTaHn-SKBAg,3254
|
8
|
+
most/search_types.py,sha256=63kBZvvkqlN2UOHXcxTraI9vEq-ANoqXmBnHi_QABDE,6729
|
9
|
+
most/searcher.py,sha256=9UdiSlScsE6EPc6RpK8xkRLeB5gHNxgPQpXTJ17i3lQ,2135
|
10
|
+
most/trainer_api.py,sha256=ZwOv4mhROfY97n6i7IY_ZpafsuNRazOqMBAf2dh708k,992
|
11
|
+
most/types.py,sha256=AU74VqYilM9DXfBwptliRbhV5urLAf4BygdIY3wlAN8,4309
|
12
|
+
most_client-1.0.38.dist-info/METADATA,sha256=17PN1v9-yRjWGMSVCEIlIKdkK7z6QQlFIaP9ZtvtQFU,1047
|
13
|
+
most_client-1.0.38.dist-info/WHEEL,sha256=DnLRTWE75wApRYVsjgc6wsVswC54sMSJhAEd4xhDpBk,91
|
14
|
+
most_client-1.0.38.dist-info/top_level.txt,sha256=2g5fk02LKkM1hV3pVVti_LQ60TToLBcR2zQ3JEKGVk8,5
|
15
|
+
most_client-1.0.38.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
16
|
+
most_client-1.0.38.dist-info/RECORD,,
|
@@ -1,13 +0,0 @@
|
|
1
|
-
most/__init__.py,sha256=b0EXXaPA4kmt-FtGXKRWZr7SCwjipMLcpC6uT5WRIdY,144
|
2
|
-
most/_constrants.py,sha256=SlHKcBoXwe_sPzk8tdbb7lqhQz-Bfo__FhSoeFWodZE,217
|
3
|
-
most/api.py,sha256=vg7zYsRebQjVOmTLRQE6jqmK-zG1j-k-KMN68gzFpk4,19538
|
4
|
-
most/async_api.py,sha256=iRIjHLVkrjBKimlVmz5Ev_pPNB0dwCa5QgKRy1rCF3c,20753
|
5
|
-
most/async_trainer_api.py,sha256=99rED8RjnOn8VezeEgrTgoVfQrO7DdmOE2Jajumno2g,1052
|
6
|
-
most/score_calculation.py,sha256=1XU1LfIH5LSCwAbAaKkr-EjH5qOTXrJKOUvhCCawka4,1054
|
7
|
-
most/trainer_api.py,sha256=ZwOv4mhROfY97n6i7IY_ZpafsuNRazOqMBAf2dh708k,992
|
8
|
-
most/types.py,sha256=KP34dzS8ayQUhToIxwxTL9O8-7TZz1ySfJzA3ZNGIGw,4921
|
9
|
-
most_client-1.0.36.dist-info/METADATA,sha256=GnPH3grUAX3tLjJKswpB1_vDWtZ8i59zX0miAxZTRx8,1027
|
10
|
-
most_client-1.0.36.dist-info/WHEEL,sha256=SmOxYU7pzNKBqASvQJ7DjX3XGUF92lrGhMb3R6_iiqI,91
|
11
|
-
most_client-1.0.36.dist-info/top_level.txt,sha256=2g5fk02LKkM1hV3pVVti_LQ60TToLBcR2zQ3JEKGVk8,5
|
12
|
-
most_client-1.0.36.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
|
13
|
-
most_client-1.0.36.dist-info/RECORD,,
|
File without changes
|
File without changes
|