mistralai 1.6.0__py3-none-any.whl → 1.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- mistralai/_version.py +3 -3
- mistralai/classifiers.py +431 -19
- mistralai/embeddings.py +6 -2
- mistralai/extra/utils/_pydantic_helper.py +2 -1
- mistralai/jobs.py +84 -38
- mistralai/mistral_jobs.py +2 -2
- mistralai/models/__init__.py +197 -46
- mistralai/models/archiveftmodelout.py +3 -11
- mistralai/models/batchjobout.py +3 -9
- mistralai/models/batchjobsout.py +3 -9
- mistralai/models/chatclassificationrequest.py +20 -0
- mistralai/models/chatmoderationrequest.py +4 -7
- mistralai/models/classificationresponse.py +12 -9
- mistralai/models/classificationtargetresult.py +14 -0
- mistralai/models/classifierdetailedjobout.py +156 -0
- mistralai/models/classifierftmodelout.py +101 -0
- mistralai/models/classifierjobout.py +165 -0
- mistralai/models/classifiertargetin.py +55 -0
- mistralai/models/classifiertargetout.py +24 -0
- mistralai/models/classifiertrainingparameters.py +73 -0
- mistralai/models/classifiertrainingparametersin.py +85 -0
- mistralai/models/{detailedjobout.py → completiondetailedjobout.py} +34 -34
- mistralai/models/{ftmodelout.py → completionftmodelout.py} +12 -12
- mistralai/models/{jobout.py → completionjobout.py} +25 -24
- mistralai/models/{trainingparameters.py → completiontrainingparameters.py} +7 -7
- mistralai/models/{trainingparametersin.py → completiontrainingparametersin.py} +7 -7
- mistralai/models/embeddingrequest.py +6 -4
- mistralai/models/finetuneablemodeltype.py +7 -0
- mistralai/models/ftclassifierlossfunction.py +7 -0
- mistralai/models/ftmodelcapabilitiesout.py +3 -0
- mistralai/models/githubrepositoryin.py +3 -11
- mistralai/models/githubrepositoryout.py +3 -11
- mistralai/models/inputs.py +54 -0
- mistralai/models/instructrequest.py +42 -0
- mistralai/models/jobin.py +52 -12
- mistralai/models/jobs_api_routes_batch_get_batch_jobsop.py +3 -3
- mistralai/models/jobs_api_routes_fine_tuning_cancel_fine_tuning_jobop.py +29 -2
- mistralai/models/jobs_api_routes_fine_tuning_create_fine_tuning_jobop.py +21 -4
- mistralai/models/jobs_api_routes_fine_tuning_get_fine_tuning_jobop.py +29 -2
- mistralai/models/jobs_api_routes_fine_tuning_get_fine_tuning_jobsop.py +8 -0
- mistralai/models/jobs_api_routes_fine_tuning_start_fine_tuning_jobop.py +29 -2
- mistralai/models/jobs_api_routes_fine_tuning_update_fine_tuned_modelop.py +28 -2
- mistralai/models/jobsout.py +24 -13
- mistralai/models/legacyjobmetadataout.py +3 -12
- mistralai/models/{classificationobject.py → moderationobject.py} +6 -6
- mistralai/models/moderationresponse.py +21 -0
- mistralai/models/ocrimageobject.py +7 -1
- mistralai/models/ocrrequest.py +15 -0
- mistralai/models/ocrresponse.py +38 -2
- mistralai/models/unarchiveftmodelout.py +3 -11
- mistralai/models/wandbintegration.py +3 -11
- mistralai/models/wandbintegrationout.py +8 -13
- mistralai/models_.py +10 -4
- mistralai/ocr.py +28 -0
- {mistralai-1.6.0.dist-info → mistralai-1.7.1.dist-info}/METADATA +3 -1
- {mistralai-1.6.0.dist-info → mistralai-1.7.1.dist-info}/RECORD +58 -44
- {mistralai-1.6.0.dist-info → mistralai-1.7.1.dist-info}/WHEEL +1 -1
- {mistralai-1.6.0.dist-info → mistralai-1.7.1.dist-info}/LICENSE +0 -0
mistralai/_version.py
CHANGED
|
@@ -3,10 +3,10 @@
|
|
|
3
3
|
import importlib.metadata
|
|
4
4
|
|
|
5
5
|
__title__: str = "mistralai"
|
|
6
|
-
__version__: str = "1.
|
|
7
|
-
__openapi_doc_version__: str = "0.0
|
|
6
|
+
__version__: str = "1.7.1"
|
|
7
|
+
__openapi_doc_version__: str = "1.0.0"
|
|
8
8
|
__gen_version__: str = "2.548.6"
|
|
9
|
-
__user_agent__: str = "speakeasy-sdk/python 1.
|
|
9
|
+
__user_agent__: str = "speakeasy-sdk/python 1.7.1 2.548.6 1.0.0 mistralai"
|
|
10
10
|
|
|
11
11
|
try:
|
|
12
12
|
if __package__ is not None:
|
mistralai/classifiers.py
CHANGED
|
@@ -23,7 +23,7 @@ class Classifiers(BaseSDK):
|
|
|
23
23
|
server_url: Optional[str] = None,
|
|
24
24
|
timeout_ms: Optional[int] = None,
|
|
25
25
|
http_headers: Optional[Mapping[str, str]] = None,
|
|
26
|
-
) -> models.
|
|
26
|
+
) -> models.ModerationResponse:
|
|
27
27
|
r"""Moderations
|
|
28
28
|
|
|
29
29
|
:param model: ID of the model to use.
|
|
@@ -91,7 +91,7 @@ class Classifiers(BaseSDK):
|
|
|
91
91
|
|
|
92
92
|
response_data: Any = None
|
|
93
93
|
if utils.match_response(http_res, "200", "application/json"):
|
|
94
|
-
return utils.unmarshal_json(http_res.text, models.
|
|
94
|
+
return utils.unmarshal_json(http_res.text, models.ModerationResponse)
|
|
95
95
|
if utils.match_response(http_res, "422", "application/json"):
|
|
96
96
|
response_data = utils.unmarshal_json(
|
|
97
97
|
http_res.text, models.HTTPValidationErrorData
|
|
@@ -129,7 +129,7 @@ class Classifiers(BaseSDK):
|
|
|
129
129
|
server_url: Optional[str] = None,
|
|
130
130
|
timeout_ms: Optional[int] = None,
|
|
131
131
|
http_headers: Optional[Mapping[str, str]] = None,
|
|
132
|
-
) -> models.
|
|
132
|
+
) -> models.ModerationResponse:
|
|
133
133
|
r"""Moderations
|
|
134
134
|
|
|
135
135
|
:param model: ID of the model to use.
|
|
@@ -197,7 +197,7 @@ class Classifiers(BaseSDK):
|
|
|
197
197
|
|
|
198
198
|
response_data: Any = None
|
|
199
199
|
if utils.match_response(http_res, "200", "application/json"):
|
|
200
|
-
return utils.unmarshal_json(http_res.text, models.
|
|
200
|
+
return utils.unmarshal_json(http_res.text, models.ModerationResponse)
|
|
201
201
|
if utils.match_response(http_res, "422", "application/json"):
|
|
202
202
|
response_data = utils.unmarshal_json(
|
|
203
203
|
http_res.text, models.HTTPValidationErrorData
|
|
@@ -226,22 +226,20 @@ class Classifiers(BaseSDK):
|
|
|
226
226
|
def moderate_chat(
|
|
227
227
|
self,
|
|
228
228
|
*,
|
|
229
|
-
model: str,
|
|
230
229
|
inputs: Union[
|
|
231
230
|
models.ChatModerationRequestInputs,
|
|
232
231
|
models.ChatModerationRequestInputsTypedDict,
|
|
233
232
|
],
|
|
234
|
-
|
|
233
|
+
model: str,
|
|
235
234
|
retries: OptionalNullable[utils.RetryConfig] = UNSET,
|
|
236
235
|
server_url: Optional[str] = None,
|
|
237
236
|
timeout_ms: Optional[int] = None,
|
|
238
237
|
http_headers: Optional[Mapping[str, str]] = None,
|
|
239
|
-
) -> models.
|
|
238
|
+
) -> models.ModerationResponse:
|
|
240
239
|
r"""Chat Moderations
|
|
241
240
|
|
|
242
|
-
:param model:
|
|
243
241
|
:param inputs: Chat to classify
|
|
244
|
-
:param
|
|
242
|
+
:param model:
|
|
245
243
|
:param retries: Override the default retry configuration for this method
|
|
246
244
|
:param server_url: Override the default server URL for this method
|
|
247
245
|
:param timeout_ms: Override the default request timeout configuration for this method in milliseconds
|
|
@@ -258,9 +256,8 @@ class Classifiers(BaseSDK):
|
|
|
258
256
|
base_url = self._get_url(base_url, url_variables)
|
|
259
257
|
|
|
260
258
|
request = models.ChatModerationRequest(
|
|
261
|
-
model=model,
|
|
262
259
|
inputs=utils.get_pydantic_model(inputs, models.ChatModerationRequestInputs),
|
|
263
|
-
|
|
260
|
+
model=model,
|
|
264
261
|
)
|
|
265
262
|
|
|
266
263
|
req = self._build_request(
|
|
@@ -306,7 +303,7 @@ class Classifiers(BaseSDK):
|
|
|
306
303
|
|
|
307
304
|
response_data: Any = None
|
|
308
305
|
if utils.match_response(http_res, "200", "application/json"):
|
|
309
|
-
return utils.unmarshal_json(http_res.text, models.
|
|
306
|
+
return utils.unmarshal_json(http_res.text, models.ModerationResponse)
|
|
310
307
|
if utils.match_response(http_res, "422", "application/json"):
|
|
311
308
|
response_data = utils.unmarshal_json(
|
|
312
309
|
http_res.text, models.HTTPValidationErrorData
|
|
@@ -335,22 +332,20 @@ class Classifiers(BaseSDK):
|
|
|
335
332
|
async def moderate_chat_async(
|
|
336
333
|
self,
|
|
337
334
|
*,
|
|
338
|
-
model: str,
|
|
339
335
|
inputs: Union[
|
|
340
336
|
models.ChatModerationRequestInputs,
|
|
341
337
|
models.ChatModerationRequestInputsTypedDict,
|
|
342
338
|
],
|
|
343
|
-
|
|
339
|
+
model: str,
|
|
344
340
|
retries: OptionalNullable[utils.RetryConfig] = UNSET,
|
|
345
341
|
server_url: Optional[str] = None,
|
|
346
342
|
timeout_ms: Optional[int] = None,
|
|
347
343
|
http_headers: Optional[Mapping[str, str]] = None,
|
|
348
|
-
) -> models.
|
|
344
|
+
) -> models.ModerationResponse:
|
|
349
345
|
r"""Chat Moderations
|
|
350
346
|
|
|
351
|
-
:param model:
|
|
352
347
|
:param inputs: Chat to classify
|
|
353
|
-
:param
|
|
348
|
+
:param model:
|
|
354
349
|
:param retries: Override the default retry configuration for this method
|
|
355
350
|
:param server_url: Override the default server URL for this method
|
|
356
351
|
:param timeout_ms: Override the default request timeout configuration for this method in milliseconds
|
|
@@ -367,9 +362,8 @@ class Classifiers(BaseSDK):
|
|
|
367
362
|
base_url = self._get_url(base_url, url_variables)
|
|
368
363
|
|
|
369
364
|
request = models.ChatModerationRequest(
|
|
370
|
-
model=model,
|
|
371
365
|
inputs=utils.get_pydantic_model(inputs, models.ChatModerationRequestInputs),
|
|
372
|
-
|
|
366
|
+
model=model,
|
|
373
367
|
)
|
|
374
368
|
|
|
375
369
|
req = self._build_request_async(
|
|
@@ -413,6 +407,424 @@ class Classifiers(BaseSDK):
|
|
|
413
407
|
retry_config=retry_config,
|
|
414
408
|
)
|
|
415
409
|
|
|
410
|
+
response_data: Any = None
|
|
411
|
+
if utils.match_response(http_res, "200", "application/json"):
|
|
412
|
+
return utils.unmarshal_json(http_res.text, models.ModerationResponse)
|
|
413
|
+
if utils.match_response(http_res, "422", "application/json"):
|
|
414
|
+
response_data = utils.unmarshal_json(
|
|
415
|
+
http_res.text, models.HTTPValidationErrorData
|
|
416
|
+
)
|
|
417
|
+
raise models.HTTPValidationError(data=response_data)
|
|
418
|
+
if utils.match_response(http_res, "4XX", "*"):
|
|
419
|
+
http_res_text = await utils.stream_to_text_async(http_res)
|
|
420
|
+
raise models.SDKError(
|
|
421
|
+
"API error occurred", http_res.status_code, http_res_text, http_res
|
|
422
|
+
)
|
|
423
|
+
if utils.match_response(http_res, "5XX", "*"):
|
|
424
|
+
http_res_text = await utils.stream_to_text_async(http_res)
|
|
425
|
+
raise models.SDKError(
|
|
426
|
+
"API error occurred", http_res.status_code, http_res_text, http_res
|
|
427
|
+
)
|
|
428
|
+
|
|
429
|
+
content_type = http_res.headers.get("Content-Type")
|
|
430
|
+
http_res_text = await utils.stream_to_text_async(http_res)
|
|
431
|
+
raise models.SDKError(
|
|
432
|
+
f"Unexpected response received (code: {http_res.status_code}, type: {content_type})",
|
|
433
|
+
http_res.status_code,
|
|
434
|
+
http_res_text,
|
|
435
|
+
http_res,
|
|
436
|
+
)
|
|
437
|
+
|
|
438
|
+
def classify(
|
|
439
|
+
self,
|
|
440
|
+
*,
|
|
441
|
+
model: str,
|
|
442
|
+
inputs: Union[
|
|
443
|
+
models.ClassificationRequestInputs,
|
|
444
|
+
models.ClassificationRequestInputsTypedDict,
|
|
445
|
+
],
|
|
446
|
+
retries: OptionalNullable[utils.RetryConfig] = UNSET,
|
|
447
|
+
server_url: Optional[str] = None,
|
|
448
|
+
timeout_ms: Optional[int] = None,
|
|
449
|
+
http_headers: Optional[Mapping[str, str]] = None,
|
|
450
|
+
) -> models.ClassificationResponse:
|
|
451
|
+
r"""Classifications
|
|
452
|
+
|
|
453
|
+
:param model: ID of the model to use.
|
|
454
|
+
:param inputs: Text to classify.
|
|
455
|
+
:param retries: Override the default retry configuration for this method
|
|
456
|
+
:param server_url: Override the default server URL for this method
|
|
457
|
+
:param timeout_ms: Override the default request timeout configuration for this method in milliseconds
|
|
458
|
+
:param http_headers: Additional headers to set or replace on requests.
|
|
459
|
+
"""
|
|
460
|
+
base_url = None
|
|
461
|
+
url_variables = None
|
|
462
|
+
if timeout_ms is None:
|
|
463
|
+
timeout_ms = self.sdk_configuration.timeout_ms
|
|
464
|
+
|
|
465
|
+
if server_url is not None:
|
|
466
|
+
base_url = server_url
|
|
467
|
+
else:
|
|
468
|
+
base_url = self._get_url(base_url, url_variables)
|
|
469
|
+
|
|
470
|
+
request = models.ClassificationRequest(
|
|
471
|
+
model=model,
|
|
472
|
+
inputs=inputs,
|
|
473
|
+
)
|
|
474
|
+
|
|
475
|
+
req = self._build_request(
|
|
476
|
+
method="POST",
|
|
477
|
+
path="/v1/classifications",
|
|
478
|
+
base_url=base_url,
|
|
479
|
+
url_variables=url_variables,
|
|
480
|
+
request=request,
|
|
481
|
+
request_body_required=True,
|
|
482
|
+
request_has_path_params=False,
|
|
483
|
+
request_has_query_params=True,
|
|
484
|
+
user_agent_header="user-agent",
|
|
485
|
+
accept_header_value="application/json",
|
|
486
|
+
http_headers=http_headers,
|
|
487
|
+
security=self.sdk_configuration.security,
|
|
488
|
+
get_serialized_body=lambda: utils.serialize_request_body(
|
|
489
|
+
request, False, False, "json", models.ClassificationRequest
|
|
490
|
+
),
|
|
491
|
+
timeout_ms=timeout_ms,
|
|
492
|
+
)
|
|
493
|
+
|
|
494
|
+
if retries == UNSET:
|
|
495
|
+
if self.sdk_configuration.retry_config is not UNSET:
|
|
496
|
+
retries = self.sdk_configuration.retry_config
|
|
497
|
+
|
|
498
|
+
retry_config = None
|
|
499
|
+
if isinstance(retries, utils.RetryConfig):
|
|
500
|
+
retry_config = (retries, ["429", "500", "502", "503", "504"])
|
|
501
|
+
|
|
502
|
+
http_res = self.do_request(
|
|
503
|
+
hook_ctx=HookContext(
|
|
504
|
+
base_url=base_url or "",
|
|
505
|
+
operation_id="classifications_v1_classifications_post",
|
|
506
|
+
oauth2_scopes=[],
|
|
507
|
+
security_source=get_security_from_env(
|
|
508
|
+
self.sdk_configuration.security, models.Security
|
|
509
|
+
),
|
|
510
|
+
),
|
|
511
|
+
request=req,
|
|
512
|
+
error_status_codes=["422", "4XX", "5XX"],
|
|
513
|
+
retry_config=retry_config,
|
|
514
|
+
)
|
|
515
|
+
|
|
516
|
+
response_data: Any = None
|
|
517
|
+
if utils.match_response(http_res, "200", "application/json"):
|
|
518
|
+
return utils.unmarshal_json(http_res.text, models.ClassificationResponse)
|
|
519
|
+
if utils.match_response(http_res, "422", "application/json"):
|
|
520
|
+
response_data = utils.unmarshal_json(
|
|
521
|
+
http_res.text, models.HTTPValidationErrorData
|
|
522
|
+
)
|
|
523
|
+
raise models.HTTPValidationError(data=response_data)
|
|
524
|
+
if utils.match_response(http_res, "4XX", "*"):
|
|
525
|
+
http_res_text = utils.stream_to_text(http_res)
|
|
526
|
+
raise models.SDKError(
|
|
527
|
+
"API error occurred", http_res.status_code, http_res_text, http_res
|
|
528
|
+
)
|
|
529
|
+
if utils.match_response(http_res, "5XX", "*"):
|
|
530
|
+
http_res_text = utils.stream_to_text(http_res)
|
|
531
|
+
raise models.SDKError(
|
|
532
|
+
"API error occurred", http_res.status_code, http_res_text, http_res
|
|
533
|
+
)
|
|
534
|
+
|
|
535
|
+
content_type = http_res.headers.get("Content-Type")
|
|
536
|
+
http_res_text = utils.stream_to_text(http_res)
|
|
537
|
+
raise models.SDKError(
|
|
538
|
+
f"Unexpected response received (code: {http_res.status_code}, type: {content_type})",
|
|
539
|
+
http_res.status_code,
|
|
540
|
+
http_res_text,
|
|
541
|
+
http_res,
|
|
542
|
+
)
|
|
543
|
+
|
|
544
|
+
async def classify_async(
|
|
545
|
+
self,
|
|
546
|
+
*,
|
|
547
|
+
model: str,
|
|
548
|
+
inputs: Union[
|
|
549
|
+
models.ClassificationRequestInputs,
|
|
550
|
+
models.ClassificationRequestInputsTypedDict,
|
|
551
|
+
],
|
|
552
|
+
retries: OptionalNullable[utils.RetryConfig] = UNSET,
|
|
553
|
+
server_url: Optional[str] = None,
|
|
554
|
+
timeout_ms: Optional[int] = None,
|
|
555
|
+
http_headers: Optional[Mapping[str, str]] = None,
|
|
556
|
+
) -> models.ClassificationResponse:
|
|
557
|
+
r"""Classifications
|
|
558
|
+
|
|
559
|
+
:param model: ID of the model to use.
|
|
560
|
+
:param inputs: Text to classify.
|
|
561
|
+
:param retries: Override the default retry configuration for this method
|
|
562
|
+
:param server_url: Override the default server URL for this method
|
|
563
|
+
:param timeout_ms: Override the default request timeout configuration for this method in milliseconds
|
|
564
|
+
:param http_headers: Additional headers to set or replace on requests.
|
|
565
|
+
"""
|
|
566
|
+
base_url = None
|
|
567
|
+
url_variables = None
|
|
568
|
+
if timeout_ms is None:
|
|
569
|
+
timeout_ms = self.sdk_configuration.timeout_ms
|
|
570
|
+
|
|
571
|
+
if server_url is not None:
|
|
572
|
+
base_url = server_url
|
|
573
|
+
else:
|
|
574
|
+
base_url = self._get_url(base_url, url_variables)
|
|
575
|
+
|
|
576
|
+
request = models.ClassificationRequest(
|
|
577
|
+
model=model,
|
|
578
|
+
inputs=inputs,
|
|
579
|
+
)
|
|
580
|
+
|
|
581
|
+
req = self._build_request_async(
|
|
582
|
+
method="POST",
|
|
583
|
+
path="/v1/classifications",
|
|
584
|
+
base_url=base_url,
|
|
585
|
+
url_variables=url_variables,
|
|
586
|
+
request=request,
|
|
587
|
+
request_body_required=True,
|
|
588
|
+
request_has_path_params=False,
|
|
589
|
+
request_has_query_params=True,
|
|
590
|
+
user_agent_header="user-agent",
|
|
591
|
+
accept_header_value="application/json",
|
|
592
|
+
http_headers=http_headers,
|
|
593
|
+
security=self.sdk_configuration.security,
|
|
594
|
+
get_serialized_body=lambda: utils.serialize_request_body(
|
|
595
|
+
request, False, False, "json", models.ClassificationRequest
|
|
596
|
+
),
|
|
597
|
+
timeout_ms=timeout_ms,
|
|
598
|
+
)
|
|
599
|
+
|
|
600
|
+
if retries == UNSET:
|
|
601
|
+
if self.sdk_configuration.retry_config is not UNSET:
|
|
602
|
+
retries = self.sdk_configuration.retry_config
|
|
603
|
+
|
|
604
|
+
retry_config = None
|
|
605
|
+
if isinstance(retries, utils.RetryConfig):
|
|
606
|
+
retry_config = (retries, ["429", "500", "502", "503", "504"])
|
|
607
|
+
|
|
608
|
+
http_res = await self.do_request_async(
|
|
609
|
+
hook_ctx=HookContext(
|
|
610
|
+
base_url=base_url or "",
|
|
611
|
+
operation_id="classifications_v1_classifications_post",
|
|
612
|
+
oauth2_scopes=[],
|
|
613
|
+
security_source=get_security_from_env(
|
|
614
|
+
self.sdk_configuration.security, models.Security
|
|
615
|
+
),
|
|
616
|
+
),
|
|
617
|
+
request=req,
|
|
618
|
+
error_status_codes=["422", "4XX", "5XX"],
|
|
619
|
+
retry_config=retry_config,
|
|
620
|
+
)
|
|
621
|
+
|
|
622
|
+
response_data: Any = None
|
|
623
|
+
if utils.match_response(http_res, "200", "application/json"):
|
|
624
|
+
return utils.unmarshal_json(http_res.text, models.ClassificationResponse)
|
|
625
|
+
if utils.match_response(http_res, "422", "application/json"):
|
|
626
|
+
response_data = utils.unmarshal_json(
|
|
627
|
+
http_res.text, models.HTTPValidationErrorData
|
|
628
|
+
)
|
|
629
|
+
raise models.HTTPValidationError(data=response_data)
|
|
630
|
+
if utils.match_response(http_res, "4XX", "*"):
|
|
631
|
+
http_res_text = await utils.stream_to_text_async(http_res)
|
|
632
|
+
raise models.SDKError(
|
|
633
|
+
"API error occurred", http_res.status_code, http_res_text, http_res
|
|
634
|
+
)
|
|
635
|
+
if utils.match_response(http_res, "5XX", "*"):
|
|
636
|
+
http_res_text = await utils.stream_to_text_async(http_res)
|
|
637
|
+
raise models.SDKError(
|
|
638
|
+
"API error occurred", http_res.status_code, http_res_text, http_res
|
|
639
|
+
)
|
|
640
|
+
|
|
641
|
+
content_type = http_res.headers.get("Content-Type")
|
|
642
|
+
http_res_text = await utils.stream_to_text_async(http_res)
|
|
643
|
+
raise models.SDKError(
|
|
644
|
+
f"Unexpected response received (code: {http_res.status_code}, type: {content_type})",
|
|
645
|
+
http_res.status_code,
|
|
646
|
+
http_res_text,
|
|
647
|
+
http_res,
|
|
648
|
+
)
|
|
649
|
+
|
|
650
|
+
def classify_chat(
|
|
651
|
+
self,
|
|
652
|
+
*,
|
|
653
|
+
model: str,
|
|
654
|
+
inputs: Union[models.Inputs, models.InputsTypedDict],
|
|
655
|
+
retries: OptionalNullable[utils.RetryConfig] = UNSET,
|
|
656
|
+
server_url: Optional[str] = None,
|
|
657
|
+
timeout_ms: Optional[int] = None,
|
|
658
|
+
http_headers: Optional[Mapping[str, str]] = None,
|
|
659
|
+
) -> models.ClassificationResponse:
|
|
660
|
+
r"""Chat Classifications
|
|
661
|
+
|
|
662
|
+
:param model:
|
|
663
|
+
:param inputs: Chat to classify
|
|
664
|
+
:param retries: Override the default retry configuration for this method
|
|
665
|
+
:param server_url: Override the default server URL for this method
|
|
666
|
+
:param timeout_ms: Override the default request timeout configuration for this method in milliseconds
|
|
667
|
+
:param http_headers: Additional headers to set or replace on requests.
|
|
668
|
+
"""
|
|
669
|
+
base_url = None
|
|
670
|
+
url_variables = None
|
|
671
|
+
if timeout_ms is None:
|
|
672
|
+
timeout_ms = self.sdk_configuration.timeout_ms
|
|
673
|
+
|
|
674
|
+
if server_url is not None:
|
|
675
|
+
base_url = server_url
|
|
676
|
+
else:
|
|
677
|
+
base_url = self._get_url(base_url, url_variables)
|
|
678
|
+
|
|
679
|
+
request = models.ChatClassificationRequest(
|
|
680
|
+
model=model,
|
|
681
|
+
inputs=utils.get_pydantic_model(inputs, models.Inputs),
|
|
682
|
+
)
|
|
683
|
+
|
|
684
|
+
req = self._build_request(
|
|
685
|
+
method="POST",
|
|
686
|
+
path="/v1/chat/classifications",
|
|
687
|
+
base_url=base_url,
|
|
688
|
+
url_variables=url_variables,
|
|
689
|
+
request=request,
|
|
690
|
+
request_body_required=True,
|
|
691
|
+
request_has_path_params=False,
|
|
692
|
+
request_has_query_params=True,
|
|
693
|
+
user_agent_header="user-agent",
|
|
694
|
+
accept_header_value="application/json",
|
|
695
|
+
http_headers=http_headers,
|
|
696
|
+
security=self.sdk_configuration.security,
|
|
697
|
+
get_serialized_body=lambda: utils.serialize_request_body(
|
|
698
|
+
request, False, False, "json", models.ChatClassificationRequest
|
|
699
|
+
),
|
|
700
|
+
timeout_ms=timeout_ms,
|
|
701
|
+
)
|
|
702
|
+
|
|
703
|
+
if retries == UNSET:
|
|
704
|
+
if self.sdk_configuration.retry_config is not UNSET:
|
|
705
|
+
retries = self.sdk_configuration.retry_config
|
|
706
|
+
|
|
707
|
+
retry_config = None
|
|
708
|
+
if isinstance(retries, utils.RetryConfig):
|
|
709
|
+
retry_config = (retries, ["429", "500", "502", "503", "504"])
|
|
710
|
+
|
|
711
|
+
http_res = self.do_request(
|
|
712
|
+
hook_ctx=HookContext(
|
|
713
|
+
base_url=base_url or "",
|
|
714
|
+
operation_id="chat_classifications_v1_chat_classifications_post",
|
|
715
|
+
oauth2_scopes=[],
|
|
716
|
+
security_source=get_security_from_env(
|
|
717
|
+
self.sdk_configuration.security, models.Security
|
|
718
|
+
),
|
|
719
|
+
),
|
|
720
|
+
request=req,
|
|
721
|
+
error_status_codes=["422", "4XX", "5XX"],
|
|
722
|
+
retry_config=retry_config,
|
|
723
|
+
)
|
|
724
|
+
|
|
725
|
+
response_data: Any = None
|
|
726
|
+
if utils.match_response(http_res, "200", "application/json"):
|
|
727
|
+
return utils.unmarshal_json(http_res.text, models.ClassificationResponse)
|
|
728
|
+
if utils.match_response(http_res, "422", "application/json"):
|
|
729
|
+
response_data = utils.unmarshal_json(
|
|
730
|
+
http_res.text, models.HTTPValidationErrorData
|
|
731
|
+
)
|
|
732
|
+
raise models.HTTPValidationError(data=response_data)
|
|
733
|
+
if utils.match_response(http_res, "4XX", "*"):
|
|
734
|
+
http_res_text = utils.stream_to_text(http_res)
|
|
735
|
+
raise models.SDKError(
|
|
736
|
+
"API error occurred", http_res.status_code, http_res_text, http_res
|
|
737
|
+
)
|
|
738
|
+
if utils.match_response(http_res, "5XX", "*"):
|
|
739
|
+
http_res_text = utils.stream_to_text(http_res)
|
|
740
|
+
raise models.SDKError(
|
|
741
|
+
"API error occurred", http_res.status_code, http_res_text, http_res
|
|
742
|
+
)
|
|
743
|
+
|
|
744
|
+
content_type = http_res.headers.get("Content-Type")
|
|
745
|
+
http_res_text = utils.stream_to_text(http_res)
|
|
746
|
+
raise models.SDKError(
|
|
747
|
+
f"Unexpected response received (code: {http_res.status_code}, type: {content_type})",
|
|
748
|
+
http_res.status_code,
|
|
749
|
+
http_res_text,
|
|
750
|
+
http_res,
|
|
751
|
+
)
|
|
752
|
+
|
|
753
|
+
async def classify_chat_async(
|
|
754
|
+
self,
|
|
755
|
+
*,
|
|
756
|
+
model: str,
|
|
757
|
+
inputs: Union[models.Inputs, models.InputsTypedDict],
|
|
758
|
+
retries: OptionalNullable[utils.RetryConfig] = UNSET,
|
|
759
|
+
server_url: Optional[str] = None,
|
|
760
|
+
timeout_ms: Optional[int] = None,
|
|
761
|
+
http_headers: Optional[Mapping[str, str]] = None,
|
|
762
|
+
) -> models.ClassificationResponse:
|
|
763
|
+
r"""Chat Classifications
|
|
764
|
+
|
|
765
|
+
:param model:
|
|
766
|
+
:param inputs: Chat to classify
|
|
767
|
+
:param retries: Override the default retry configuration for this method
|
|
768
|
+
:param server_url: Override the default server URL for this method
|
|
769
|
+
:param timeout_ms: Override the default request timeout configuration for this method in milliseconds
|
|
770
|
+
:param http_headers: Additional headers to set or replace on requests.
|
|
771
|
+
"""
|
|
772
|
+
base_url = None
|
|
773
|
+
url_variables = None
|
|
774
|
+
if timeout_ms is None:
|
|
775
|
+
timeout_ms = self.sdk_configuration.timeout_ms
|
|
776
|
+
|
|
777
|
+
if server_url is not None:
|
|
778
|
+
base_url = server_url
|
|
779
|
+
else:
|
|
780
|
+
base_url = self._get_url(base_url, url_variables)
|
|
781
|
+
|
|
782
|
+
request = models.ChatClassificationRequest(
|
|
783
|
+
model=model,
|
|
784
|
+
inputs=utils.get_pydantic_model(inputs, models.Inputs),
|
|
785
|
+
)
|
|
786
|
+
|
|
787
|
+
req = self._build_request_async(
|
|
788
|
+
method="POST",
|
|
789
|
+
path="/v1/chat/classifications",
|
|
790
|
+
base_url=base_url,
|
|
791
|
+
url_variables=url_variables,
|
|
792
|
+
request=request,
|
|
793
|
+
request_body_required=True,
|
|
794
|
+
request_has_path_params=False,
|
|
795
|
+
request_has_query_params=True,
|
|
796
|
+
user_agent_header="user-agent",
|
|
797
|
+
accept_header_value="application/json",
|
|
798
|
+
http_headers=http_headers,
|
|
799
|
+
security=self.sdk_configuration.security,
|
|
800
|
+
get_serialized_body=lambda: utils.serialize_request_body(
|
|
801
|
+
request, False, False, "json", models.ChatClassificationRequest
|
|
802
|
+
),
|
|
803
|
+
timeout_ms=timeout_ms,
|
|
804
|
+
)
|
|
805
|
+
|
|
806
|
+
if retries == UNSET:
|
|
807
|
+
if self.sdk_configuration.retry_config is not UNSET:
|
|
808
|
+
retries = self.sdk_configuration.retry_config
|
|
809
|
+
|
|
810
|
+
retry_config = None
|
|
811
|
+
if isinstance(retries, utils.RetryConfig):
|
|
812
|
+
retry_config = (retries, ["429", "500", "502", "503", "504"])
|
|
813
|
+
|
|
814
|
+
http_res = await self.do_request_async(
|
|
815
|
+
hook_ctx=HookContext(
|
|
816
|
+
base_url=base_url or "",
|
|
817
|
+
operation_id="chat_classifications_v1_chat_classifications_post",
|
|
818
|
+
oauth2_scopes=[],
|
|
819
|
+
security_source=get_security_from_env(
|
|
820
|
+
self.sdk_configuration.security, models.Security
|
|
821
|
+
),
|
|
822
|
+
),
|
|
823
|
+
request=req,
|
|
824
|
+
error_status_codes=["422", "4XX", "5XX"],
|
|
825
|
+
retry_config=retry_config,
|
|
826
|
+
)
|
|
827
|
+
|
|
416
828
|
response_data: Any = None
|
|
417
829
|
if utils.match_response(http_res, "200", "application/json"):
|
|
418
830
|
return utils.unmarshal_json(http_res.text, models.ClassificationResponse)
|
mistralai/embeddings.py
CHANGED
|
@@ -15,7 +15,9 @@ class Embeddings(BaseSDK):
|
|
|
15
15
|
self,
|
|
16
16
|
*,
|
|
17
17
|
model: str,
|
|
18
|
-
inputs: Union[
|
|
18
|
+
inputs: Union[
|
|
19
|
+
models.EmbeddingRequestInputs, models.EmbeddingRequestInputsTypedDict
|
|
20
|
+
],
|
|
19
21
|
retries: OptionalNullable[utils.RetryConfig] = UNSET,
|
|
20
22
|
server_url: Optional[str] = None,
|
|
21
23
|
timeout_ms: Optional[int] = None,
|
|
@@ -120,7 +122,9 @@ class Embeddings(BaseSDK):
|
|
|
120
122
|
self,
|
|
121
123
|
*,
|
|
122
124
|
model: str,
|
|
123
|
-
inputs: Union[
|
|
125
|
+
inputs: Union[
|
|
126
|
+
models.EmbeddingRequestInputs, models.EmbeddingRequestInputsTypedDict
|
|
127
|
+
],
|
|
124
128
|
retries: OptionalNullable[utils.RetryConfig] = UNSET,
|
|
125
129
|
server_url: Optional[str] = None,
|
|
126
130
|
timeout_ms: Optional[int] = None,
|
|
@@ -1,11 +1,12 @@
|
|
|
1
1
|
from typing import Any
|
|
2
2
|
|
|
3
|
+
|
|
3
4
|
def rec_strict_json_schema(schema_node: Any) -> Any:
|
|
4
5
|
"""
|
|
5
6
|
Recursively set the additionalProperties property to False for all objects in the JSON Schema.
|
|
6
7
|
This makes the JSON Schema strict (i.e. no additional properties are allowed).
|
|
7
8
|
"""
|
|
8
|
-
if isinstance(schema_node, (str, bool)):
|
|
9
|
+
if isinstance(schema_node, (str, bool)) or schema_node is None:
|
|
9
10
|
return schema_node
|
|
10
11
|
if isinstance(schema_node, dict):
|
|
11
12
|
if "type" in schema_node and schema_node["type"] == "object":
|