craft-ai-sdk 0.63.1rc1.tar.gz → 0.64.0rc1.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of craft-ai-sdk might be problematic. See the original registry diff page for more details.

Files changed (33)
  1. {craft_ai_sdk-0.63.1rc1 → craft_ai_sdk-0.64.0rc1}/PKG-INFO +1 -1
  2. {craft_ai_sdk-0.63.1rc1 → craft_ai_sdk-0.64.0rc1}/craft_ai_sdk/__init__.py +1 -1
  3. {craft_ai_sdk-0.63.1rc1 → craft_ai_sdk-0.64.0rc1}/craft_ai_sdk/core/data_store.py +35 -12
  4. {craft_ai_sdk-0.63.1rc1 → craft_ai_sdk-0.64.0rc1}/craft_ai_sdk/core/deployments.py +117 -30
  5. {craft_ai_sdk-0.63.1rc1 → craft_ai_sdk-0.64.0rc1}/craft_ai_sdk/core/endpoints.py +45 -10
  6. {craft_ai_sdk-0.63.1rc1 → craft_ai_sdk-0.64.0rc1}/craft_ai_sdk/core/environment_variables.py +11 -3
  7. {craft_ai_sdk-0.63.1rc1 → craft_ai_sdk-0.64.0rc1}/craft_ai_sdk/core/pipeline_metrics.py +32 -13
  8. {craft_ai_sdk-0.63.1rc1 → craft_ai_sdk-0.64.0rc1}/craft_ai_sdk/core/pipelines.py +72 -17
  9. {craft_ai_sdk-0.63.1rc1 → craft_ai_sdk-0.64.0rc1}/craft_ai_sdk/core/resource_metrics.py +51 -1
  10. {craft_ai_sdk-0.63.1rc1 → craft_ai_sdk-0.64.0rc1}/craft_ai_sdk/core/steps.py +113 -29
  11. {craft_ai_sdk-0.63.1rc1 → craft_ai_sdk-0.64.0rc1}/craft_ai_sdk/core/users.py +8 -1
  12. {craft_ai_sdk-0.63.1rc1 → craft_ai_sdk-0.64.0rc1}/craft_ai_sdk/core/vector_database.py +11 -2
  13. {craft_ai_sdk-0.63.1rc1 → craft_ai_sdk-0.64.0rc1}/craft_ai_sdk/io.py +26 -4
  14. {craft_ai_sdk-0.63.1rc1 → craft_ai_sdk-0.64.0rc1}/craft_ai_sdk/sdk.py +2 -1
  15. craft_ai_sdk-0.64.0rc1/craft_ai_sdk/shared/types.py +6 -0
  16. {craft_ai_sdk-0.63.1rc1 → craft_ai_sdk-0.64.0rc1}/craft_ai_sdk/utils/datetime_utils.py +2 -2
  17. {craft_ai_sdk-0.63.1rc1 → craft_ai_sdk-0.64.0rc1}/craft_ai_sdk/utils/dict_utils.py +1 -1
  18. {craft_ai_sdk-0.63.1rc1 → craft_ai_sdk-0.64.0rc1}/craft_ai_sdk/utils/file_utils.py +10 -4
  19. {craft_ai_sdk-0.63.1rc1 → craft_ai_sdk-0.64.0rc1}/documentation.pdf +0 -0
  20. {craft_ai_sdk-0.63.1rc1 → craft_ai_sdk-0.64.0rc1}/pyproject.toml +1 -1
  21. {craft_ai_sdk-0.63.1rc1 → craft_ai_sdk-0.64.0rc1}/LICENSE +0 -0
  22. {craft_ai_sdk-0.63.1rc1 → craft_ai_sdk-0.64.0rc1}/README.md +0 -0
  23. {craft_ai_sdk-0.63.1rc1 → craft_ai_sdk-0.64.0rc1}/craft_ai_sdk/constants.py +0 -0
  24. {craft_ai_sdk-0.63.1rc1 → craft_ai_sdk-0.64.0rc1}/craft_ai_sdk/core/pipeline_executions.py +0 -0
  25. {craft_ai_sdk-0.63.1rc1 → craft_ai_sdk-0.64.0rc1}/craft_ai_sdk/exceptions.py +0 -0
  26. {craft_ai_sdk-0.63.1rc1 → craft_ai_sdk-0.64.0rc1}/craft_ai_sdk/shared/authentication.py +0 -0
  27. {craft_ai_sdk-0.63.1rc1 → craft_ai_sdk-0.64.0rc1}/craft_ai_sdk/shared/environments.py +0 -0
  28. {craft_ai_sdk-0.63.1rc1 → craft_ai_sdk-0.64.0rc1}/craft_ai_sdk/shared/execution_context.py +0 -0
  29. {craft_ai_sdk-0.63.1rc1 → craft_ai_sdk-0.64.0rc1}/craft_ai_sdk/shared/helpers.py +0 -0
  30. {craft_ai_sdk-0.63.1rc1 → craft_ai_sdk-0.64.0rc1}/craft_ai_sdk/shared/logger.py +0 -0
  31. {craft_ai_sdk-0.63.1rc1 → craft_ai_sdk-0.64.0rc1}/craft_ai_sdk/shared/request_response_handler.py +0 -0
  32. {craft_ai_sdk-0.63.1rc1 → craft_ai_sdk-0.64.0rc1}/craft_ai_sdk/shared/warnings.py +0 -0
  33. {craft_ai_sdk-0.63.1rc1 → craft_ai_sdk-0.64.0rc1}/craft_ai_sdk/utils/__init__.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: craft-ai-sdk
3
- Version: 0.63.1rc1
3
+ Version: 0.64.0rc1
4
4
  Summary: Craft AI MLOps platform SDK
5
5
  License: Apache-2.0
6
6
  Author: Craft AI
@@ -13,4 +13,4 @@ from .io import ( # noqa: F401
13
13
  )
14
14
  from .sdk import CraftAiSdk # noqa: F401
15
15
 
16
- __version__ = "0.63.1rc1"
16
+ __version__ = "0.64.0rc1"
@@ -1,4 +1,5 @@
1
1
  import io
2
+ from typing import TypedDict, Union
2
3
 
3
4
  import requests
4
5
 
@@ -8,7 +9,19 @@ from ..shared.request_response_handler import handle_data_store_response
8
9
  from ..utils import chunk_buffer, convert_size
9
10
 
10
11
 
11
- def get_data_store_object_information(sdk: BaseCraftAiSdk, object_path_in_datastore):
12
+ class DataStoreObjectInformation(TypedDict):
13
+ path: str
14
+ last_modified: str
15
+ size: str
16
+
17
+
18
+ class DataStoreDeletedObject(TypedDict):
19
+ path: str
20
+
21
+
22
+ def get_data_store_object_information(
23
+ sdk: BaseCraftAiSdk, object_path_in_datastore: str
24
+ ) -> DataStoreObjectInformation:
12
25
  """Get information about a single object in the data store.
13
26
 
14
27
  Args:
@@ -33,7 +46,9 @@ def get_data_store_object_information(sdk: BaseCraftAiSdk, object_path_in_datast
33
46
  return result
34
47
 
35
48
 
36
- def list_data_store_objects(sdk: BaseCraftAiSdk, folder_path=None):
49
+ def list_data_store_objects(
50
+ sdk: BaseCraftAiSdk, folder_path: Union[str, None] = None
51
+ ) -> list[DataStoreObjectInformation]:
37
52
  """Get the list of the objects stored in the data store.
38
53
 
39
54
  Args:
@@ -63,7 +78,7 @@ def list_data_store_objects(sdk: BaseCraftAiSdk, folder_path=None):
63
78
  return all_items
64
79
 
65
80
 
66
- def _get_upload_presigned_url(sdk: BaseCraftAiSdk, object_path_in_datastore):
81
+ def _get_upload_presigned_url(sdk: BaseCraftAiSdk, object_path_in_datastore: str):
67
82
  url = f"{sdk.base_environment_api_url}/data-store/upload"
68
83
  params = {"path_to_object": object_path_in_datastore}
69
84
  resp = sdk._get(url, params=params)
@@ -74,7 +89,9 @@ def _get_upload_presigned_url(sdk: BaseCraftAiSdk, object_path_in_datastore):
74
89
 
75
90
  @log_func_result("Object upload")
76
91
  def upload_data_store_object(
77
- sdk: BaseCraftAiSdk, filepath_or_buffer, object_path_in_datastore
92
+ sdk: BaseCraftAiSdk,
93
+ filepath_or_buffer: Union[str, io.IOBase],
94
+ object_path_in_datastore: str,
78
95
  ):
79
96
  """Upload a file as an object into the data store.
80
97
 
@@ -117,7 +134,7 @@ def upload_data_store_object(
117
134
 
118
135
 
119
136
  def _upload_singlepart_data_store_object(
120
- sdk: BaseCraftAiSdk, buffer, object_path_in_datastore
137
+ sdk: BaseCraftAiSdk, buffer: io.IOBase, object_path_in_datastore: str
121
138
  ):
122
139
  files = {"file": buffer}
123
140
 
@@ -128,7 +145,7 @@ def _upload_singlepart_data_store_object(
128
145
 
129
146
 
130
147
  def _upload_multipart_data_store_object(
131
- sdk: BaseCraftAiSdk, buffer, object_path_in_datastore
148
+ sdk: BaseCraftAiSdk, buffer: io.IOBase, object_path_in_datastore: str
132
149
  ):
133
150
  multipart_base_url = f"{sdk.base_environment_api_url}/data-store/upload/multipart"
134
151
  multipart_upload_configuration = sdk._post(
@@ -156,10 +173,10 @@ def _upload_multipart_data_store_object(
156
173
  data=chunk,
157
174
  headers=multipart_part_result["headers"],
158
175
  )
159
- partData = {"number": part_idx}
176
+ part_data: dict = {"number": part_idx}
160
177
  if "ETag" in resp.headers:
161
- partData["metadata"] = resp.headers["ETag"]
162
- parts.append(partData)
178
+ part_data["metadata"] = resp.headers["ETag"]
179
+ parts.append(part_data)
163
180
  if part["lastChunk"]:
164
181
  break
165
182
 
@@ -173,7 +190,9 @@ def _upload_multipart_data_store_object(
173
190
  )
174
191
 
175
192
 
176
- def _get_download_presigned_url(sdk: BaseCraftAiSdk, object_path_in_datastore):
193
+ def _get_download_presigned_url(
194
+ sdk: BaseCraftAiSdk, object_path_in_datastore: str
195
+ ) -> str:
177
196
  url = f"{sdk.base_environment_api_url}/data-store/download"
178
197
  data = {
179
198
  "path_to_object": object_path_in_datastore,
@@ -184,7 +203,9 @@ def _get_download_presigned_url(sdk: BaseCraftAiSdk, object_path_in_datastore):
184
203
 
185
204
  @log_func_result("Object download")
186
205
  def download_data_store_object(
187
- sdk: BaseCraftAiSdk, object_path_in_datastore, filepath_or_buffer
206
+ sdk: BaseCraftAiSdk,
207
+ object_path_in_datastore: str,
208
+ filepath_or_buffer: Union[str, io.IOBase],
188
209
  ):
189
210
  """Download an object in the data store and save it into a file.
190
211
 
@@ -218,7 +239,9 @@ def download_data_store_object(
218
239
 
219
240
 
220
241
  @log_func_result("Object deletion")
221
- def delete_data_store_object(sdk: BaseCraftAiSdk, object_path_in_datastore):
242
+ def delete_data_store_object(
243
+ sdk: BaseCraftAiSdk, object_path_in_datastore: str
244
+ ) -> DataStoreDeletedObject:
222
245
  """Delete an object on the datastore.
223
246
 
224
247
  Args:
@@ -1,6 +1,17 @@
1
+ from datetime import datetime
2
+ from typing import Literal, TypedDict, Union
3
+ from typing_extensions import NotRequired
4
+
5
+ from craft_ai_sdk.shared.types import Log
1
6
  from ..constants import DEPLOYMENT_EXECUTION_RULES, DEPLOYMENT_MODES, DEPLOYMENT_STATUS
2
7
  from ..exceptions import SdkException
3
- from ..io import _validate_inputs_mapping, _validate_outputs_mapping
8
+ from ..io import (
9
+ InputSource,
10
+ InputSourceDict,
11
+ OutputDestination,
12
+ _validate_inputs_mapping,
13
+ _validate_outputs_mapping,
14
+ )
4
15
  from ..sdk import BaseCraftAiSdk
5
16
  from ..shared.logger import log_action, log_func_result
6
17
  from ..utils import (
@@ -10,25 +21,98 @@ from ..utils import (
10
21
  )
11
22
 
12
23
 
24
+ class DeploymentInputMapping(InputSourceDict):
25
+ data_type: str
26
+ description: str
27
+
28
+
29
+ class DeploymentOutputMapping(OutputDestination):
30
+ data_type: str
31
+ description: str
32
+
33
+
34
+ class DeploymentPipeline(TypedDict):
35
+ name: str
36
+
37
+
38
+ class DeploymentPod(TypedDict):
39
+ pod_id: str
40
+ status: str
41
+
42
+
43
+ class DeploymentBase(TypedDict):
44
+ name: str
45
+ mode: str
46
+ pipeline: DeploymentPipeline
47
+ inputs_mapping: list[DeploymentInputMapping]
48
+ outputs_mapping: list[DeploymentOutputMapping]
49
+ created_at: str
50
+ created_by: str
51
+ updated_at: str
52
+ updated_by: str
53
+ last_execution_id: str
54
+ is_enabled: bool
55
+ description: str
56
+ status: str
57
+ enable_parallel_executions: bool
58
+ max_parallel_executions_per_pod: int
59
+ pods: NotRequired[list[DeploymentPod]]
60
+
61
+
62
+ class DeploymentEndpoint(DeploymentBase):
63
+ endpoint_token: str
64
+ endpoint_url_path: str
65
+ execution_rule: Literal["endpoint"]
66
+
67
+
68
+ class DeploymentPeriodic(DeploymentBase):
69
+ schedule: str
70
+ human_readable_schedule: str
71
+ execution_rule: Literal["periodic"]
72
+
73
+
74
+ Deployment = Union[DeploymentEndpoint, DeploymentPeriodic]
75
+
76
+
77
+ class DeploymentLog(Log):
78
+ stream: str
79
+ pod_id: str
80
+ type: str
81
+
82
+
83
+ class DeploymentListItem(TypedDict):
84
+ name: str
85
+ pipeline_name: str
86
+ execution_rule: str
87
+ mode: str
88
+ is_enabled: bool
89
+ created_at: str
90
+
91
+
92
+ class DeploymentDeleted(TypedDict):
93
+ name: str
94
+ execution_rule: str
95
+
96
+
13
97
  @log_func_result("Deployment creation")
14
98
  def create_deployment(
15
99
  sdk: BaseCraftAiSdk,
16
- pipeline_name,
17
- deployment_name,
18
- execution_rule,
100
+ pipeline_name: str,
101
+ deployment_name: str,
102
+ execution_rule: DEPLOYMENT_EXECUTION_RULES,
19
103
  mode=DEPLOYMENT_MODES.ELASTIC,
20
- schedule=None,
21
- endpoint_url_path=None,
22
- inputs_mapping=None,
23
- outputs_mapping=None,
24
- description=None,
25
- enable_parallel_executions=None,
26
- max_parallel_executions_per_pod=None,
27
- ram_request=None,
28
- gpu_request=None,
104
+ schedule: Union[str, None] = None,
105
+ endpoint_url_path: Union[str, None] = None,
106
+ inputs_mapping: Union[list[InputSource], None] = None,
107
+ outputs_mapping: Union[list[OutputDestination], None] = None,
108
+ description: Union[str, None] = None,
109
+ enable_parallel_executions: Union[bool, None] = None,
110
+ max_parallel_executions_per_pod: Union[int, None] = None,
111
+ ram_request: Union[str, None] = None,
112
+ gpu_request: Union[int, None] = None,
29
113
  wait_for_completion=True,
30
- timeout_s=None,
31
- execution_timeout_s=None,
114
+ timeout_s: Union[int, None] = None,
115
+ execution_timeout_s: Union[int, None] = None,
32
116
  ):
33
117
  """Create a deployment associated with a given pipeline.
34
118
 
@@ -397,8 +481,11 @@ FINAL_DEPLOYMENT_STATUSES = [
397
481
 
398
482
 
399
483
  def get_deployment(
400
- sdk: BaseCraftAiSdk, deployment_name, wait_for_completion=False, timeout_s=None
401
- ):
484
+ sdk: BaseCraftAiSdk,
485
+ deployment_name: str,
486
+ wait_for_completion=False,
487
+ timeout_s: Union[int, None] = None,
488
+ ) -> Deployment:
402
489
  """Get information of a deployment.
403
490
 
404
491
  Args:
@@ -563,13 +650,13 @@ wait_for_completion parameter set to false.',
563
650
  @log_func_result("Deployment update")
564
651
  def update_deployment(
565
652
  sdk: BaseCraftAiSdk,
566
- deployment_name,
567
- is_enabled=None,
568
- inputs_mapping=None,
569
- outputs_mapping=None,
570
- schedule=None,
653
+ deployment_name: str,
654
+ is_enabled: Union[bool, None] = None,
655
+ inputs_mapping: Union[list[InputSource], None] = None,
656
+ outputs_mapping: Union[list[OutputDestination], None] = None,
657
+ schedule: Union[str, None] = None,
571
658
  wait_for_completion=True,
572
- timeout_s=None,
659
+ timeout_s: Union[int, None] = None,
573
660
  ):
574
661
  """Update the specified properties of a deployment. The properties that can be
575
662
  updated include enabling/disabling the deployment, updating input/output values,
@@ -624,11 +711,11 @@ def update_deployment(
624
711
 
625
712
  def get_deployment_logs(
626
713
  sdk: BaseCraftAiSdk,
627
- deployment_name,
628
- from_datetime=None,
629
- to_datetime=None,
630
- limit=None,
631
- ):
714
+ deployment_name: str,
715
+ from_datetime: Union[datetime, None] = None,
716
+ to_datetime: Union[datetime, None] = None,
717
+ limit: Union[int, None] = None,
718
+ ) -> list[DeploymentLog]:
632
719
  """Get the logs of a deployment with ``"low_latency"`` mode.
633
720
 
634
721
  Args:
@@ -667,7 +754,7 @@ def get_deployment_logs(
667
754
  return sdk._post(url, json=data)
668
755
 
669
756
 
670
- def list_deployments(sdk: BaseCraftAiSdk):
757
+ def list_deployments(sdk: BaseCraftAiSdk) -> list[DeploymentListItem]:
671
758
  """Get the list of all deployments.
672
759
 
673
760
  Returns:
@@ -689,7 +776,7 @@ def list_deployments(sdk: BaseCraftAiSdk):
689
776
 
690
777
 
691
778
  @log_func_result("Deployment deletion")
692
- def delete_deployment(sdk: BaseCraftAiSdk, deployment_name):
779
+ def delete_deployment(sdk: BaseCraftAiSdk, deployment_name: str) -> DeploymentDeleted:
693
780
  """Delete a deployment identified by its name.
694
781
 
695
782
  Args:
@@ -1,4 +1,5 @@
1
1
  import io
2
+ from typing import Any, Literal, TypedDict, Union, overload
2
3
  from urllib.parse import urlencode
3
4
 
4
5
  import requests
@@ -9,7 +10,7 @@ from ..shared.request_response_handler import handle_http_response
9
10
  from .deployments import get_deployment
10
11
 
11
12
 
12
- def _get_endpoint_url_path(sdk: BaseCraftAiSdk, endpoint_name):
13
+ def _get_endpoint_url_path(sdk: BaseCraftAiSdk, endpoint_name: str):
13
14
  deployment = get_deployment(sdk, endpoint_name)
14
15
 
15
16
  if deployment.get("execution_rule", "") != "endpoint":
@@ -18,14 +19,46 @@ def _get_endpoint_url_path(sdk: BaseCraftAiSdk, endpoint_name):
18
19
  return deployment.get("endpoint_url_path", "")
19
20
 
20
21
 
22
+ class EndpointTriggerBase(TypedDict):
23
+ execution_id: str
24
+
25
+
26
+ class EndpointTriggerWithOutputs(EndpointTriggerBase):
27
+ outputs: dict[str, Any]
28
+
29
+
30
+ class EndpointNewToken(TypedDict):
31
+ endpoint_token: str
32
+
33
+
34
+ @overload
35
+ def trigger_endpoint(
36
+ sdk: BaseCraftAiSdk,
37
+ endpoint_name: str,
38
+ endpoint_token: str,
39
+ inputs: dict[str, Any],
40
+ wait_for_completion: Literal[True],
41
+ ) -> EndpointTriggerWithOutputs: ...
42
+
43
+
44
+ @overload
45
+ def trigger_endpoint(
46
+ sdk: BaseCraftAiSdk,
47
+ endpoint_name: str,
48
+ endpoint_token: str,
49
+ inputs: dict[str, Any],
50
+ wait_for_completion: Literal[False],
51
+ ) -> EndpointTriggerBase: ...
52
+
53
+
21
54
  @log_func_result("Endpoint trigger")
22
55
  def trigger_endpoint(
23
56
  sdk: BaseCraftAiSdk,
24
- endpoint_name,
25
- endpoint_token,
26
- inputs=None,
57
+ endpoint_name: str,
58
+ endpoint_token: str,
59
+ inputs: Union[dict[str, Any], None] = None,
27
60
  wait_for_completion=True,
28
- ):
61
+ ) -> Union[EndpointTriggerWithOutputs, EndpointTriggerBase]:
29
62
  """Trigger an endpoint.
30
63
 
31
64
  Args:
@@ -88,10 +121,10 @@ def trigger_endpoint(
88
121
  @log_func_result("Endpoint result retrieval")
89
122
  def retrieve_endpoint_results(
90
123
  sdk: BaseCraftAiSdk,
91
- endpoint_name,
92
- execution_id,
93
- endpoint_token,
94
- ):
124
+ endpoint_name: str,
125
+ execution_id: str,
126
+ endpoint_token: str,
127
+ ) -> EndpointTriggerWithOutputs:
95
128
  """Get the results of an endpoint execution.
96
129
 
97
130
  Args:
@@ -135,7 +168,9 @@ def retrieve_endpoint_results(
135
168
  return {**handle_http_response(response), "execution_id": execution_id}
136
169
 
137
170
 
138
- def generate_new_endpoint_token(sdk: BaseCraftAiSdk, endpoint_name):
171
+ def generate_new_endpoint_token(
172
+ sdk: BaseCraftAiSdk, endpoint_name: str
173
+ ) -> EndpointNewToken:
139
174
  """Generate a new endpoint token for an endpoint.
140
175
 
141
176
  Args:
@@ -1,10 +1,16 @@
1
+ from typing import TypedDict
1
2
  from ..sdk import BaseCraftAiSdk
2
3
  from ..shared.logger import log_func_result
3
4
 
4
5
 
6
+ class EnvironmentVariable(TypedDict):
7
+ name: str
8
+ value: str
9
+
10
+
5
11
  @log_func_result("Environment variable definition")
6
12
  def create_or_update_environment_variable(
7
- sdk: BaseCraftAiSdk, environment_variable_name, environment_variable_value
13
+ sdk: BaseCraftAiSdk, environment_variable_name: str, environment_variable_value: str
8
14
  ):
9
15
  """Create or update an environment variable available for
10
16
  all pipelines executions.
@@ -28,7 +34,7 @@ def create_or_update_environment_variable(
28
34
  return None
29
35
 
30
36
 
31
- def list_environment_variables(sdk: BaseCraftAiSdk):
37
+ def list_environment_variables(sdk: BaseCraftAiSdk) -> list[EnvironmentVariable]:
32
38
  """Get a list of all environments variables.
33
39
 
34
40
  Returns:
@@ -43,7 +49,9 @@ def list_environment_variables(sdk: BaseCraftAiSdk):
43
49
 
44
50
 
45
51
  @log_func_result("Environment variable deletion")
46
- def delete_environment_variable(sdk: BaseCraftAiSdk, environment_variable_name):
52
+ def delete_environment_variable(
53
+ sdk: BaseCraftAiSdk, environment_variable_name: str
54
+ ) -> EnvironmentVariable:
47
55
  """Delete the specified environment variable
48
56
 
49
57
  Args:
@@ -1,3 +1,4 @@
1
+ from typing import TypedDict, Union
1
2
  import warnings
2
3
 
3
4
  from ..sdk import BaseCraftAiSdk
@@ -6,8 +7,26 @@ from ..shared.logger import log_func_result
6
7
  from ..utils import remove_none_values
7
8
 
8
9
 
10
+ class Metric(TypedDict):
11
+ name: str
12
+ value: float
13
+ created_at: str
14
+ execution_id: str
15
+ deployment_name: str
16
+ pipeline_name: str
17
+
18
+
19
+ class ListMetric(TypedDict):
20
+ name: str
21
+ values: list[float]
22
+ created_at: str
23
+ execution_id: str
24
+ deployment_name: str
25
+ pipeline_name: str
26
+
27
+
9
28
  @log_func_result("Pipeline metrics definition", get_execution_id)
10
- def record_metric_value(sdk: BaseCraftAiSdk, name, value):
29
+ def record_metric_value(sdk: BaseCraftAiSdk, name: str, value: float):
11
30
  """Create or update a pipeline metric. Note that this function can only be used
12
31
  inside a step code.
13
32
 
@@ -34,7 +53,7 @@ been sent",
34
53
 
35
54
 
36
55
  @log_func_result("Pipeline list metric definition", get_execution_id)
37
- def record_list_metric_values(sdk: BaseCraftAiSdk, name, values):
56
+ def record_list_metric_values(sdk: BaseCraftAiSdk, name: str, values: list[float]):
38
57
  """Add values to a pipeline metric list. Note that this function can only be
39
58
  used inside a step code.
40
59
 
@@ -74,11 +93,11 @@ been sent",
74
93
  @log_func_result("Pipeline metrics listing")
75
94
  def get_metrics(
76
95
  sdk: BaseCraftAiSdk,
77
- name=None,
78
- pipeline_name=None,
79
- deployment_name=None,
80
- execution_id=None,
81
- ):
96
+ name: Union[str, None] = None,
97
+ pipeline_name: Union[str, None] = None,
98
+ deployment_name: Union[str, None] = None,
99
+ execution_id: Union[str, None] = None,
100
+ ) -> list[Metric]:
82
101
  """Get a list of pipeline metrics. Note that only one of the
83
102
  parameters (pipeline_name, deployment_name, execution_id) can be set.
84
103
 
@@ -140,11 +159,11 @@ def get_metrics(
140
159
  @log_func_result("Pipeline list metrics listing")
141
160
  def get_list_metrics(
142
161
  sdk: BaseCraftAiSdk,
143
- name=None,
144
- pipeline_name=None,
145
- deployment_name=None,
146
- execution_id=None,
147
- ):
162
+ name: Union[str, None] = None,
163
+ pipeline_name: Union[str, None] = None,
164
+ deployment_name: Union[str, None] = None,
165
+ execution_id: Union[str, None] = None,
166
+ ) -> list[ListMetric]:
148
167
  """Get a list of pipeline metric lists. Note that only one of the
149
168
  parameters (pipeline_name, deployment_name, execution_id) can be set.
150
169
 
@@ -162,7 +181,7 @@ def get_list_metrics(
162
181
  with the following keys:
163
182
 
164
183
  * ``name`` (:obj:`str`): Name of the metric.
165
- * ``value`` (:obj:`float`): Value of the metric.
184
+ * ``values`` (:obj:`list[float]`): Values list of the metric.
166
185
  * ``created_at`` (:obj:`str`): Date of the metric creation.
167
186
  * ``execution_id`` (:obj:`str`): Name of the execution the metric
168
187
  belongs to.
@@ -1,21 +1,70 @@
1
1
  import os
2
+ from typing import TypedDict, Union, cast
3
+ from typing_extensions import NotRequired
2
4
  import warnings
3
5
 
4
6
  import requests
5
7
 
8
+ from craft_ai_sdk.io import Input, Output
9
+ from craft_ai_sdk.shared.types import Log
10
+
6
11
  from ..sdk import BaseCraftAiSdk
7
12
  from ..shared.helpers import wait_create_until_ready
8
13
  from ..shared.logger import log_action, log_func_result
9
14
  from ..shared.request_response_handler import handle_data_store_response
10
15
  from ..utils import datetime_to_timestamp_in_ms, multipartify, remove_keys_from_dict
11
16
  from .steps import (
17
+ ContainerConfig,
18
+ StepCreationInfo,
12
19
  _prepare_create_step_data,
13
20
  _prepare_create_step_files,
14
21
  _validate_create_step_parameters,
15
22
  )
16
23
 
17
24
 
18
- def _create_pipeline_with_step(sdk: BaseCraftAiSdk, pipeline_name: str, step_name: str):
25
+ class PipelineParameter(TypedDict):
26
+ step_name: NotRequired[str]
27
+ pipeline_name: str
28
+ function_path: str
29
+ function_name: str
30
+ description: str
31
+ inputs: list[Input]
32
+ outputs: list[Output]
33
+ container_config: ContainerConfig
34
+
35
+
36
+ class PipelineCreationInfo(StepCreationInfo):
37
+ last_execution_id: str
38
+
39
+
40
+ class Pipeline(TypedDict):
41
+ parameters: PipelineParameter
42
+ creation_info: PipelineCreationInfo
43
+
44
+
45
+ class PipelineListItem(TypedDict):
46
+ pipeline_name: str
47
+ created_at: str
48
+ status: str
49
+
50
+
51
+ class PipelineDeletedPipeline(TypedDict):
52
+ name: str
53
+
54
+
55
+ class PipelineDeletedDeployment(TypedDict):
56
+ name: str
57
+ execution_rule: str
58
+
59
+
60
+ class PipelineDeleted(TypedDict):
61
+ pipeline: PipelineDeletedPipeline
62
+ deployments: list[PipelineDeletedDeployment]
63
+
64
+
65
+ def _create_pipeline_with_step(
66
+ sdk: BaseCraftAiSdk, pipeline_name: str, step_name: str
67
+ ) -> Pipeline:
19
68
  url = f"{sdk.base_environment_api_url}/pipelines"
20
69
  body = {
21
70
  "pipeline_name": pipeline_name,
@@ -30,15 +79,15 @@ def _create_pipeline_with_step(sdk: BaseCraftAiSdk, pipeline_name: str, step_nam
30
79
  @log_func_result("Pipeline creation")
31
80
  def create_pipeline(
32
81
  sdk: BaseCraftAiSdk,
33
- pipeline_name,
34
- function_path=None,
35
- function_name=None,
36
- description=None,
37
- container_config=None,
38
- inputs=None,
39
- outputs=None,
82
+ pipeline_name: str,
83
+ function_path: Union[str, None] = None,
84
+ function_name: Union[str, None] = None,
85
+ description: Union[str, None] = None,
86
+ container_config: Union[ContainerConfig, None] = None,
87
+ inputs: Union[list[Input], None] = None,
88
+ outputs: Union[list[Output], None] = None,
40
89
  wait_for_completion=True,
41
- timeout_s=None,
90
+ timeout_s: Union[int, None] = None,
42
91
  **kwargs,
43
92
  ):
44
93
  """Create a pipeline from a function located on a remote repository or locally.
@@ -211,7 +260,11 @@ def create_pipeline(
211
260
  )
212
261
 
213
262
  # Otherwise, the pipeline is created with a "hidden" step
214
- container_config = {} if container_config is None else container_config.copy()
263
+ container_config = (
264
+ cast(ContainerConfig, {})
265
+ if container_config is None
266
+ else cast(ContainerConfig, container_config.copy())
267
+ )
215
268
  _validate_create_step_parameters(inputs, outputs, timeout_s)
216
269
 
217
270
  url = f"{sdk.base_environment_api_url}/pipelines"
@@ -239,7 +292,7 @@ def create_pipeline(
239
292
 
240
293
  def get_pipeline(
241
294
  sdk: BaseCraftAiSdk, pipeline_name, wait_for_completion=False, timeout_s=None
242
- ):
295
+ ) -> Pipeline:
243
296
  """Get a single pipeline if it exists.
244
297
 
245
298
  Args:
@@ -336,7 +389,7 @@ def get_pipeline(
336
389
  )
337
390
 
338
391
 
339
- def list_pipelines(sdk: BaseCraftAiSdk):
392
+ def list_pipelines(sdk: BaseCraftAiSdk) -> list[PipelineListItem]:
340
393
  """Get the list of all pipelines.
341
394
 
342
395
  Returns:
@@ -354,8 +407,8 @@ def list_pipelines(sdk: BaseCraftAiSdk):
354
407
 
355
408
  @log_func_result("Pipeline deletion")
356
409
  def delete_pipeline(
357
- sdk: BaseCraftAiSdk, pipeline_name, force_deployments_deletion=False
358
- ):
410
+ sdk: BaseCraftAiSdk, pipeline_name: str, force_deployments_deletion=False
411
+ ) -> PipelineDeleted:
359
412
  """Delete a pipeline identified by its name.
360
413
 
361
414
  Args:
@@ -386,14 +439,16 @@ def delete_pipeline(
386
439
  return sdk._delete(url, params=params)
387
440
 
388
441
 
389
- def _get_download_presigned_url(sdk: BaseCraftAiSdk, pipeline_name):
442
+ def _get_download_presigned_url(sdk: BaseCraftAiSdk, pipeline_name: str) -> str:
390
443
  url = f"{sdk.base_environment_api_url}/pipelines/{pipeline_name}/download"
391
444
  presigned_url = sdk._get(url)["signed_url"]
392
445
  return presigned_url
393
446
 
394
447
 
395
448
  @log_func_result("pipeline download")
396
- def download_pipeline_local_folder(sdk: BaseCraftAiSdk, pipeline_name, folder):
449
+ def download_pipeline_local_folder(
450
+ sdk: BaseCraftAiSdk, pipeline_name: str, folder: str
451
+ ):
397
452
  """Download a pipeline's local folder as a `.tgz` archive.
398
453
 
399
454
  Only available if the pipeline's ``origin`` is ``"local_folder"``. This archive
@@ -427,7 +482,7 @@ def get_pipeline_logs(
427
482
  from_datetime=None,
428
483
  to_datetime=None,
429
484
  limit=None,
430
- ):
485
+ ) -> list[Log]:
431
486
  """Get the logs of a pipeline.
432
487
 
433
488
  Args:
@@ -1,8 +1,58 @@
1
+ from datetime import datetime
2
+ from typing import Literal, TypedDict, Union, overload
3
+ from typing_extensions import NotRequired
1
4
  from ..sdk import BaseCraftAiSdk
2
5
  from ..utils import datetime_to_timestamp_in_ms
3
6
 
4
7
 
5
- def get_resource_metrics(sdk: BaseCraftAiSdk, start_date, end_date, csv=False):
8
+ class AdditionalData(TypedDict):
9
+ total_disk: int
10
+ total_ram: int
11
+ total_vram: NotRequired[str]
12
+
13
+
14
+ class MetricWorker(TypedDict):
15
+ worker: str
16
+
17
+
18
+ class Metric(TypedDict):
19
+ metric: MetricWorker
20
+ values: list[list[Union[int, float]]]
21
+
22
+
23
+ class MetricsDict(TypedDict):
24
+ cpu_usage: list[Metric]
25
+ disk_usage: list[Metric]
26
+ ram_usage: list[Metric]
27
+ vram_usage: NotRequired[list[Metric]]
28
+ gpu_usage: NotRequired[list[Metric]]
29
+ network_input_usage: list[Metric]
30
+ network_output_usage: list[Metric]
31
+
32
+
33
+ class ResourceMetrics(TypedDict):
34
+ additional_data: AdditionalData
35
+ metrics: MetricsDict
36
+
37
+
38
+ @overload
39
+ def get_resource_metrics(
40
+ sdk: BaseCraftAiSdk, start_date: datetime, end_date: datetime, csv: Literal[True]
41
+ ) -> bytes: ...
42
+
43
+
44
+ @overload
45
+ def get_resource_metrics(
46
+ sdk: BaseCraftAiSdk,
47
+ start_date: datetime,
48
+ end_date: datetime,
49
+ csv: Literal[False] = False,
50
+ ) -> ResourceMetrics: ...
51
+
52
+
53
+ def get_resource_metrics(
54
+ sdk: BaseCraftAiSdk, start_date: datetime, end_date: datetime, csv=False
55
+ ) -> Union[ResourceMetrics, bytes]:
6
56
  """Get resource metrics of the environment.
7
57
 
8
58
  Args:
@@ -1,9 +1,14 @@
1
+ from datetime import datetime
1
2
  import io
2
3
  import os
3
4
  import tarfile
5
+ from typing import TypeVar, TypedDict, Union, cast
6
+ from typing_extensions import NotRequired
4
7
 
5
8
  import requests
6
9
 
10
+ from craft_ai_sdk.shared.types import Log
11
+
7
12
  from ..constants import CREATION_PARAMETER_VALUE
8
13
  from ..io import Input, Output
9
14
  from ..sdk import BaseCraftAiSdk
@@ -12,8 +17,67 @@ from ..shared.logger import log_action, log_func_result
12
17
  from ..shared.request_response_handler import handle_data_store_response
13
18
  from ..utils import datetime_to_timestamp_in_ms, multipartify, remove_none_values
14
19
 
20
+ T = TypeVar("T")
21
+
22
+ CreationParameter = Union[T, CREATION_PARAMETER_VALUE, None]
23
+
24
+
25
+ class ContainerConfigBase(TypedDict):
26
+ language: NotRequired[CreationParameter[str]]
27
+ requirements_path: NotRequired[CreationParameter[str]]
28
+ included_folders: NotRequired[CreationParameter[list[str]]]
29
+ system_dependencies: NotRequired[CreationParameter[list[str]]]
30
+ dockerfile_path: NotRequired[CreationParameter[str]]
31
+
32
+
33
+ class ContainerConfigWithGit(ContainerConfigBase):
34
+ repository_url: NotRequired[CreationParameter[str]]
35
+ repository_branch: NotRequired[CreationParameter[str]]
36
+ repository_deploy_key: NotRequired[CreationParameter[str]]
37
+
38
+
39
+ class ContainerConfigWithLocalFolder(ContainerConfigBase):
40
+ local_folder: str
41
+
42
+
43
+ ContainerConfig = Union[ContainerConfigWithGit, ContainerConfigWithLocalFolder]
44
+
45
+
46
+ class StepParameter(TypedDict):
47
+ step_name: str
48
+ function_path: str
49
+ function_name: str
50
+ description: str
51
+ inputs: list[Input]
52
+ outputs: list[Output]
53
+ container_config: ContainerConfig
54
+ "This type is actually too large, you'll not get any CREATION_PARAMETER_VALUE here"
55
+
56
+
57
+ class StepCreationInfo(TypedDict):
58
+ created_at: str
59
+ created_by: str
60
+ commit_id: str
61
+ status: str
62
+ origin: str
15
63
 
16
- def _compress_folder_to_memory(local_folder, include):
64
+
65
+ class Step(TypedDict):
66
+ parameters: StepParameter
67
+ creation_info: StepCreationInfo
68
+
69
+
70
+ class StepListItem(TypedDict):
71
+ step_name: str
72
+ status: str
73
+ created_at: str
74
+
75
+
76
+ class StepDeleted(TypedDict):
77
+ step_name: str
78
+
79
+
80
+ def _compress_folder_to_memory(local_folder: str, include: list[str]):
17
81
  tar_data = io.BytesIO()
18
82
  # Remove leading slashes from the paths
19
83
  include = [item.lstrip("/") for item in include]
@@ -24,7 +88,11 @@ def _compress_folder_to_memory(local_folder, include):
24
88
  return tar_data
25
89
 
26
90
 
27
- def _validate_create_step_parameters(inputs, outputs, timeout_s):
91
+ def _validate_create_step_parameters(
92
+ inputs: Union[list[Input], None],
93
+ outputs: Union[list[Output], None],
94
+ timeout_s: Union[int, None],
95
+ ):
28
96
  if timeout_s is not None and timeout_s <= 0:
29
97
  raise ValueError("The timeout must be greater than 0 or None.")
30
98
 
@@ -37,7 +105,7 @@ def _validate_create_step_parameters(inputs, outputs, timeout_s):
37
105
  raise ValueError("'outputs' must be a list of instances of Output.")
38
106
 
39
107
 
40
- def _map_container_config_step_parameter(container_config):
108
+ def _map_container_config_step_parameter(container_config: ContainerConfig):
41
109
  """
42
110
  Maps container config with :obj:`CREATION_PARAMETER_VALUE` enum values to final
43
111
  container config. `None` is considered to be equivalent to
@@ -57,12 +125,12 @@ def _map_container_config_step_parameter(container_config):
57
125
 
58
126
 
59
127
  def _prepare_create_step_data(
60
- function_path,
61
- function_name,
62
- description,
63
- container_config,
64
- inputs,
65
- outputs,
128
+ function_path: Union[str, None],
129
+ function_name: Union[str, None],
130
+ description: Union[str, None],
131
+ container_config: ContainerConfig,
132
+ inputs: Union[list[Input], None],
133
+ outputs: Union[list[Output], None],
66
134
  **step_or_pipeline_name,
67
135
  ):
68
136
  assert step_or_pipeline_name.keys() == {
@@ -87,7 +155,12 @@ def _prepare_create_step_data(
87
155
  return data
88
156
 
89
157
 
90
- def _prepare_create_step_files(sdk, container_config, data, function_path):
158
+ def _prepare_create_step_files(
159
+ sdk: BaseCraftAiSdk,
160
+ container_config: ContainerConfig,
161
+ data: dict,
162
+ function_path: Union[str, None],
163
+ ):
91
164
  if "local_folder" not in container_config:
92
165
  return {}
93
166
 
@@ -136,15 +209,15 @@ def _remove_id_from_step(step):
136
209
  @log_func_result("Steps creation")
137
210
  def create_step(
138
211
  sdk: BaseCraftAiSdk,
139
- step_name,
140
- function_path=None,
141
- function_name=None,
142
- description=None,
143
- container_config=None,
144
- inputs=None,
145
- outputs=None,
212
+ step_name: str,
213
+ function_path: Union[str, None] = None,
214
+ function_name: Union[str, None] = None,
215
+ description: Union[str, None] = None,
216
+ container_config: Union[ContainerConfig, None] = None,
217
+ inputs: Union[list[Input], None] = None,
218
+ outputs: Union[list[Output], None] = None,
146
219
  wait_for_completion=True,
147
- timeout_s=None,
220
+ timeout_s: Union[int, None] = None,
148
221
  ):
149
222
  """Create pipeline step from a function located on a remote repository or locally.
150
223
 
@@ -277,7 +350,11 @@ def create_step(
277
350
  ``"git_repository"`` or ``"local"``.
278
351
  """
279
352
 
280
- container_config = {} if container_config is None else container_config.copy()
353
+ container_config = (
354
+ cast(ContainerConfig, {})
355
+ if container_config is None
356
+ else cast(ContainerConfig, container_config.copy())
357
+ )
281
358
  _validate_create_step_parameters(inputs, outputs, timeout_s=timeout_s)
282
359
 
283
360
  url = f"{sdk.base_environment_api_url}/steps"
@@ -302,7 +379,12 @@ def create_step(
302
379
  return get_step(sdk, step_name, wait_for_completion, timeout_s)
303
380
 
304
381
 
305
- def get_step(sdk: BaseCraftAiSdk, step_name, wait_for_completion=False, timeout_s=None):
382
+ def get_step(
383
+ sdk: BaseCraftAiSdk,
384
+ step_name: str,
385
+ wait_for_completion=False,
386
+ timeout_s: Union[int, None] = None,
387
+ ) -> Step:
306
388
  """Get a single step if it exists.
307
389
 
308
390
  Args:
@@ -391,7 +473,7 @@ def get_step(sdk: BaseCraftAiSdk, step_name, wait_for_completion=False, timeout_
391
473
  return _remove_id_from_step(step)
392
474
 
393
475
 
394
- def list_steps(sdk: BaseCraftAiSdk):
476
+ def list_steps(sdk: BaseCraftAiSdk) -> list[StepListItem]:
395
477
  """Get the list of all steps.
396
478
 
397
479
  Returns:
@@ -408,7 +490,9 @@ def list_steps(sdk: BaseCraftAiSdk):
408
490
 
409
491
 
410
492
  @log_func_result("Step deletion")
411
- def delete_step(sdk: BaseCraftAiSdk, step_name, force_dependents_deletion=False):
493
+ def delete_step(
494
+ sdk: BaseCraftAiSdk, step_name: str, force_dependents_deletion=False
495
+ ) -> StepDeleted:
412
496
  """Delete one step.
413
497
 
414
498
  Args:
@@ -431,14 +515,14 @@ def delete_step(sdk: BaseCraftAiSdk, step_name, force_dependents_deletion=False)
431
515
  return sdk._delete(url, params=params)
432
516
 
433
517
 
434
- def _get_download_presigned_url(sdk: BaseCraftAiSdk, step_name):
518
+ def _get_download_presigned_url(sdk: BaseCraftAiSdk, step_name: str) -> str:
435
519
  url = f"{sdk.base_environment_api_url}/steps/{step_name}/download"
436
520
  presigned_url = sdk._get(url)["signed_url"]
437
521
  return presigned_url
438
522
 
439
523
 
440
524
  @log_func_result("Step download")
441
- def download_step_local_folder(sdk: BaseCraftAiSdk, step_name, folder):
525
+ def download_step_local_folder(sdk: BaseCraftAiSdk, step_name: str, folder: str):
442
526
  """Download a step's local folder as a `.tgz` archive.
443
527
 
444
528
  Only available if the step's ``origin`` is ``"local_folder"``. This archive
@@ -468,11 +552,11 @@ def download_step_local_folder(sdk: BaseCraftAiSdk, step_name, folder):
468
552
  @log_func_result("Step logs")
469
553
  def get_step_logs(
470
554
  sdk: BaseCraftAiSdk,
471
- step_name,
472
- from_datetime=None,
473
- to_datetime=None,
474
- limit=None,
475
- ):
555
+ step_name: str,
556
+ from_datetime: Union[datetime, None] = None,
557
+ to_datetime: Union[datetime, None] = None,
558
+ limit: Union[int, None] = None,
559
+ ) -> list[Log]:
476
560
  """Get the logs of a step.
477
561
 
478
562
  Args:
@@ -1,7 +1,14 @@
1
+ from typing import TypedDict
1
2
  from ..sdk import BaseCraftAiSdk
2
3
 
3
4
 
4
- def get_user(sdk: BaseCraftAiSdk, user_id):
5
+ class User(TypedDict):
6
+ id: str
7
+ name: str
8
+ email: str
9
+
10
+
11
+ def get_user(sdk: BaseCraftAiSdk, user_id: str) -> User:
5
12
  """Get information about a user.
6
13
 
7
14
  Args:
@@ -1,10 +1,16 @@
1
+ from typing import TypedDict
1
2
  from craft_ai_sdk.shared.environments import get_environment_id
2
3
 
3
4
  from ..sdk import BaseCraftAiSdk
4
5
  from ..shared.logger import log_action, log_func_result
5
6
 
6
7
 
7
- def get_vector_database_credentials(sdk: BaseCraftAiSdk):
8
+ class VectorDatabaseCredentials(TypedDict):
9
+ vector_database_url: str
10
+ vector_database_token: str
11
+
12
+
13
+ def get_vector_database_credentials(sdk: BaseCraftAiSdk) -> VectorDatabaseCredentials:
8
14
  """Get the credentials of the vector database.
9
15
 
10
16
  Returns:
@@ -62,6 +68,9 @@ def get_weaviate_client(sdk: BaseCraftAiSdk):
62
68
  auth_credentials=None,
63
69
  )
64
70
 
65
- log_action(sdk, "Connected to Weaviate, using version", weaviate.__version__)
71
+ log_action(
72
+ sdk,
73
+ f"Connected to Weaviate, using version {weaviate.__version__}",
74
+ )
66
75
 
67
76
  return weaviate_client
@@ -1,3 +1,5 @@
1
+ from typing import Any, TypedDict, cast
2
+ from typing_extensions import NotRequired
1
3
  import warnings
2
4
 
3
5
  from strenum import LowercaseStrEnum
@@ -82,6 +84,18 @@ class Output:
82
84
  return remove_none_values(output)
83
85
 
84
86
 
87
+ class InputSourceDict(TypedDict):
88
+ pipeline_input_name: str
89
+ description: str
90
+ constant_value: NotRequired[Any]
91
+ environment_variable_name: NotRequired[str]
92
+ endpoint_input_name: NotRequired[str]
93
+ is_null: NotRequired[bool]
94
+ datastore_path: NotRequired[str]
95
+ is_required: NotRequired[bool]
96
+ default_value: NotRequired[Any]
97
+
98
+
85
99
  class InputSource:
86
100
  """Class to specify to which source a step input should be mapped when creating
87
101
  a deployment (cf. :meth:`.CraftAiSdk.create_deployment`). The different sources can
@@ -167,7 +181,7 @@ class InputSource:
167
181
  self.datastore_path = datastore_path
168
182
  self.step_input_name = step_input_name
169
183
 
170
- def to_dict(self):
184
+ def to_dict(self) -> InputSourceDict:
171
185
  input_mapping_dict = {
172
186
  "pipeline_input_name": self.pipeline_input_name,
173
187
  "endpoint_input_name": self.endpoint_input_name,
@@ -180,7 +194,15 @@ class InputSource:
180
194
  "step_input_name": self.step_input_name,
181
195
  }
182
196
 
183
- return remove_none_values(input_mapping_dict)
197
+ return cast(InputSourceDict, remove_none_values(input_mapping_dict))
198
+
199
+
200
+ class OutputDestinationDict(TypedDict):
201
+ pipeline_output_name: str
202
+ endpoint_output_name: NotRequired[str]
203
+ is_null: NotRequired[bool]
204
+ datastore_path: NotRequired[str]
205
+ step_output_name: NotRequired[str]
184
206
 
185
207
 
186
208
  class OutputDestination:
@@ -252,7 +274,7 @@ class OutputDestination:
252
274
  self.datastore_path = datastore_path
253
275
  self.step_output_name = step_output_name
254
276
 
255
- def to_dict(self):
277
+ def to_dict(self) -> OutputDestinationDict:
256
278
  output_mapping_dict = {
257
279
  "pipeline_output_name": self.pipeline_output_name,
258
280
  "endpoint_output_name": self.endpoint_output_name,
@@ -261,7 +283,7 @@ class OutputDestination:
261
283
  "step_output_name": self.step_output_name,
262
284
  }
263
285
 
264
- return remove_none_values(output_mapping_dict)
286
+ return cast(OutputDestinationDict, remove_none_values(output_mapping_dict))
265
287
 
266
288
 
267
289
  def _format_execution_output(name, output):
@@ -23,6 +23,7 @@ class BaseCraftAiSdk(ABC):
23
23
  _MULTIPART_THRESHOLD: int
24
24
  _MULTIPART_PART_SIZE: int
25
25
  _version: str
26
+ warn_on_metric_outside_of_step: bool
26
27
 
27
28
  @abstractmethod
28
29
  def _get(self, url: str, params=None, **kwargs) -> Any:
@@ -137,7 +138,7 @@ class CraftAiSdk(BaseCraftAiSdk):
137
138
  os.environ.get("CRAFT_AI__MULTIPART_PART_SIZE__B", str(38 * 256 * 1024))
138
139
  )
139
140
  _access_token_margin = timedelta(seconds=30)
140
- _version = "0.63.1rc1" # Would be better to share it somewhere
141
+ _version = "0.64.0rc1" # Would be better to share it somewhere
141
142
 
142
143
  def __init__(
143
144
  self,
@@ -0,0 +1,6 @@
1
+ from typing import TypedDict
2
+
3
+
4
+ class Log(TypedDict):
5
+ timestamp: str
6
+ message: str
@@ -2,13 +2,13 @@ import re
2
2
  from datetime import datetime
3
3
 
4
4
 
5
- def datetime_to_timestamp_in_ms(dt):
5
+ def datetime_to_timestamp_in_ms(dt: datetime) -> int:
6
6
  if not isinstance(dt, datetime):
7
7
  raise ValueError("Parameter must be a datetime.datetime object.")
8
8
  return int(1_000 * dt.timestamp())
9
9
 
10
10
 
11
- def parse_isodate(date_string):
11
+ def parse_isodate(date_string: str):
12
12
  """_summary_
13
13
 
14
14
  Args:
@@ -1,7 +1,7 @@
1
1
  from typing import Union
2
2
 
3
3
 
4
- def remove_none_values(obj):
4
+ def remove_none_values(obj: dict):
5
5
  return {key: value for key, value in obj.items() if value is not None}
6
6
 
7
7
 
@@ -1,5 +1,5 @@
1
1
  from io import BytesIO, IOBase, StringIO
2
- from typing import Iterable, Union
2
+ from typing import Iterable, TypedDict, Union
3
3
 
4
4
 
5
5
  def merge_paths(prefix, path):
@@ -8,7 +8,13 @@ def merge_paths(prefix, path):
8
8
 
9
9
 
10
10
  # From https://stackoverflow.com/a/58767245/4839162
11
- def chunk_buffer(buffer: IOBase, size: int) -> Iterable[Union[BytesIO, StringIO]]:
11
+ class ChunkedIO(TypedDict):
12
+ chunk: Union[BytesIO, StringIO]
13
+ len: int
14
+ lastChunk: bool
15
+
16
+
17
+ def chunk_buffer(buffer: IOBase, size: int) -> Iterable[ChunkedIO]:
12
18
  size_int = int(size)
13
19
  b = buffer.read(size_int)
14
20
  next_data = None
@@ -22,7 +28,7 @@ def chunk_buffer(buffer: IOBase, size: int) -> Iterable[Union[BytesIO, StringIO]
22
28
 
23
29
  next_data = buffer.read(1)
24
30
 
25
- data = {
31
+ data: ChunkedIO = {
26
32
  "chunk": chunk,
27
33
  "len": len(b) + (len(previous_data) if previous_data else 0),
28
34
  "lastChunk": len(next_data) == 0,
@@ -32,7 +38,7 @@ def chunk_buffer(buffer: IOBase, size: int) -> Iterable[Union[BytesIO, StringIO]
32
38
  b = buffer.read(size_int - 1)
33
39
 
34
40
 
35
- def convert_size(size_in_bytes):
41
+ def convert_size(size_in_bytes: Union[int, float]):
36
42
  """
37
43
  Convert a size in bytes to a human readable string.
38
44
  """
@@ -1,6 +1,6 @@
1
1
  [tool.poetry]
2
2
  name = "craft-ai-sdk"
3
- version = "0.63.1rc1"
3
+ version = "0.64.0rc1"
4
4
  description = "Craft AI MLOps platform SDK"
5
5
  license = "Apache-2.0"
6
6
  authors = ["Craft AI <contact@craft.ai>"]