salesforce-data-customcode 3.0.1__tar.gz → 3.0.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (83)
  1. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/PKG-INFO +1 -1
  2. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/pyproject.toml +1 -1
  3. salesforce_data_customcode-3.0.2/src/datacustomcode/common_config.py +63 -0
  4. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/config.py +13 -64
  5. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/config.yaml +4 -0
  6. salesforce_data_customcode-3.0.2/src/datacustomcode/einstein_predictions/__init__.py +22 -0
  7. salesforce_data_customcode-3.0.2/src/datacustomcode/einstein_predictions/base.py +32 -0
  8. salesforce_data_customcode-3.0.2/src/datacustomcode/einstein_predictions/impl/default.py +35 -0
  9. salesforce_data_customcode-3.0.2/src/datacustomcode/einstein_predictions/types.py +184 -0
  10. salesforce_data_customcode-3.0.2/src/datacustomcode/einstein_predictions_config.py +67 -0
  11. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/function/runtime.py +16 -0
  12. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/llm_gateway/base.py +7 -3
  13. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/llm_gateway/default.py +2 -0
  14. salesforce_data_customcode-3.0.2/src/datacustomcode/llm_gateway_config.py +65 -0
  15. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/templates/function/payload/entrypoint.py +29 -0
  16. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/LICENSE.txt +0 -0
  17. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/README.md +0 -0
  18. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/__init__.py +0 -0
  19. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/auth.py +0 -0
  20. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/cli.py +0 -0
  21. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/client.py +0 -0
  22. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/cmd.py +0 -0
  23. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/credentials.py +0 -0
  24. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/deploy.py +0 -0
  25. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/file/__init__.py +0 -0
  26. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/file/base.py +0 -0
  27. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/file/path/__init__.py +0 -0
  28. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/file/path/default.py +0 -0
  29. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/function/__init__.py +0 -0
  30. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/function/base.py +0 -0
  31. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/function/feature_types/__init__.py +0 -0
  32. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/function/feature_types/chunking.py +0 -0
  33. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/io/__init__.py +0 -0
  34. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/io/base.py +0 -0
  35. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/io/reader/__init__.py +0 -0
  36. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/io/reader/base.py +0 -0
  37. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/io/reader/query_api.py +0 -0
  38. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/io/reader/sf_cli.py +0 -0
  39. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/io/reader/utils.py +0 -0
  40. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/io/writer/__init__.py +0 -0
  41. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/io/writer/base.py +0 -0
  42. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/io/writer/csv.py +0 -0
  43. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/io/writer/print.py +0 -0
  44. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/llm_gateway/__init__.py +0 -0
  45. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/llm_gateway/types/__init__.py +0 -0
  46. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/llm_gateway/types/generate_text_request.py +0 -0
  47. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/llm_gateway/types/generate_text_request_builder.py +0 -0
  48. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/llm_gateway/types/generate_text_response.py +0 -0
  49. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/llm_gateway/types/generate_text_response_builder.py +0 -0
  50. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/mixin.py +0 -0
  51. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/proxy/__init__.py +0 -0
  52. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/proxy/base.py +0 -0
  53. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/proxy/client/LocalProxyClientProvider.py +0 -0
  54. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/proxy/client/__init__.py +0 -0
  55. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/proxy/client/base.py +0 -0
  56. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/py.typed +0 -0
  57. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/run.py +0 -0
  58. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/scan.py +0 -0
  59. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/spark/__init__.py +0 -0
  60. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/spark/base.py +0 -0
  61. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/spark/default.py +0 -0
  62. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/template.py +0 -0
  63. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/templates/function/.devcontainer/devcontainer.json +0 -0
  64. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/templates/function/Dockerfile.dependencies +0 -0
  65. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/templates/function/README.md +0 -0
  66. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/templates/function/build_native_dependencies.sh +0 -0
  67. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/templates/function/payload/config.json +0 -0
  68. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/templates/function/requirements-dev.txt +0 -0
  69. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/templates/function/requirements.txt +0 -0
  70. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/templates/script/.devcontainer/devcontainer.json +0 -0
  71. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/templates/script/Dockerfile +0 -0
  72. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/templates/script/Dockerfile.dependencies +0 -0
  73. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/templates/script/README.md +0 -0
  74. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/templates/script/account.ipynb +0 -0
  75. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/templates/script/build_native_dependencies.sh +0 -0
  76. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/templates/script/examples/employee_hierarchy/employee_data.csv +0 -0
  77. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/templates/script/examples/employee_hierarchy/entrypoint.py +0 -0
  78. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/templates/script/jupyterlab.sh +0 -0
  79. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/templates/script/payload/config.json +0 -0
  80. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/templates/script/payload/entrypoint.py +0 -0
  81. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/templates/script/requirements-dev.txt +0 -0
  82. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/templates/script/requirements.txt +0 -0
  83. {salesforce_data_customcode-3.0.1 → salesforce_data_customcode-3.0.2}/src/datacustomcode/version.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: salesforce-data-customcode
3
- Version: 3.0.1
3
+ Version: 3.0.2
4
4
  Summary: Data Cloud Custom Code SDK
5
5
  License-Expression: Apache-2.0
6
6
  License-File: LICENSE.txt
@@ -18,7 +18,7 @@ license = "Apache-2.0"
18
18
  name = "salesforce-data-customcode"
19
19
  readme = "README.md"
20
20
  requires-python = ">=3.10,<3.12"
21
- version = "3.0.1"
21
+ version = "3.0.2"
22
22
 
23
23
  [tool.black]
24
24
  exclude = '''
@@ -0,0 +1,63 @@
1
+ # Copyright (c) 2025, Salesforce, Inc.
2
+ # SPDX-License-Identifier: Apache-2
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from abc import ABC, abstractmethod
16
+ import os
17
+ from typing import Any
18
+
19
+ from pydantic import (
20
+ BaseModel,
21
+ ConfigDict,
22
+ Field,
23
+ )
24
+ import yaml
25
+
26
# File name of the configuration bundled with this package;
# default_config_file() joins it with this module's directory.
DEFAULT_CONFIG_NAME = "config.yaml"
27
+
28
+
29
def default_config_file() -> str:
    """Return the path of the default config file shipped with this package.

    The file (named by ``DEFAULT_CONFIG_NAME``) is looked up in the same
    directory as this module.
    """
    package_dir = os.path.dirname(__file__)
    return os.path.join(package_dir, DEFAULT_CONFIG_NAME)
31
+
32
+
33
# Base model for config entries carrying a ``force`` flag (default False).
# Subclasses' merge logic decides how ``force`` affects precedence.
class ForceableConfig(BaseModel):
    force: bool = Field(
        description=(
            "If True, this takes precedence over parameters passed to the "
            "initializer of the client"
        ),
        default=False,
    )
39
+
40
+
41
# Shared shape for "build an object from config" entries: the config name of
# the object to create plus keyword options for its constructor. Unknown keys
# are rejected (extra="forbid") and defaults are validated too.
class BaseObjectConfig(ForceableConfig):
    model_config = ConfigDict(extra="forbid", validate_default=True)

    type_config_name: str = Field(
        description="The config name of the object to create",
    )
    options: dict[str, Any] = Field(
        description="Options passed to the constructor.",
        default_factory=dict,
    )
50
+
51
+
52
# Abstract base for top-level config models that can be populated from a YAML
# file and merged with another instance of the same model via ``update``.
class BaseConfig(ABC, BaseModel):
    @abstractmethod
    def update(self, other: Any) -> "BaseConfig":
        """Merge ``other`` into this config and return the merged config.

        Subclasses define the precedence rules between the two configs.
        """

    def load(self, config_path: str) -> "BaseConfig":
        """Load configuration from a YAML file and merge with existing config.

        Args:
            config_path: Path to a YAML file whose top-level mapping matches
                this model's fields.

        Returns:
            ``self``, after ``update`` has merged the loaded values into it.

        Raises:
            OSError: If the file cannot be opened.
            yaml.YAMLError: If the file is not valid YAML.
            pydantic.ValidationError: If the content does not match the model.
        """
        # YAML is a Unicode format; don't depend on the platform's locale
        # encoding when reading it.
        with open(config_path, "r", encoding="utf-8") as f:
            config_data = yaml.safe_load(f)

        # safe_load returns None for an empty document. Treat that as "no
        # overrides" instead of failing validation on a None input.
        if config_data is None:
            config_data = {}

        loaded_config = self.__class__.model_validate(config_data)
        self.update(loaded_config)
        return self
@@ -14,7 +14,6 @@
14
14
  # limitations under the License.
15
15
  from __future__ import annotations
16
16
 
17
- import os
18
17
  from typing import (
19
18
  TYPE_CHECKING,
20
19
  Any,
@@ -26,12 +25,14 @@ from typing import (
26
25
  cast,
27
26
  )
28
27
 
29
- from pydantic import (
30
- BaseModel,
31
- ConfigDict,
32
- Field,
28
+ from pydantic import Field
29
+
30
+ from datacustomcode.common_config import (
31
+ BaseConfig,
32
+ BaseObjectConfig,
33
+ ForceableConfig,
34
+ default_config_file,
33
35
  )
34
- import yaml
35
36
 
36
37
  # This lets all readers and writers to be findable via config
37
38
  from datacustomcode.io import * # noqa: F403
@@ -42,36 +43,15 @@ from datacustomcode.proxy.base import BaseProxyAccessLayer
42
43
  from datacustomcode.proxy.client.base import BaseProxyClient # noqa: TCH002
43
44
  from datacustomcode.spark.base import BaseSparkSessionProvider
44
45
 
45
- DEFAULT_CONFIG_NAME = "config.yaml"
46
-
47
-
48
46
  if TYPE_CHECKING:
49
47
  from pyspark.sql import SparkSession
50
48
 
51
49
 
52
- class ForceableConfig(BaseModel):
53
- force: bool = Field(
54
- default=False,
55
- description="If True, this takes precedence over parameters passed to the "
56
- "initializer of the client.",
57
- )
58
-
59
-
60
50
  _T = TypeVar("_T", bound="BaseDataAccessLayer")
61
51
 
62
52
 
63
- class AccessLayerObjectConfig(ForceableConfig, Generic[_T]):
64
- model_config = ConfigDict(validate_default=True, extra="forbid")
53
+ class AccessLayerObjectConfig(BaseObjectConfig, Generic[_T]):
65
54
  type_base: ClassVar[Type[BaseDataAccessLayer]] = BaseDataAccessLayer
66
- type_config_name: str = Field(
67
- description="The config name of the object to create. "
68
- "For metrics, this would might be 'ipmnormal'. For custom classes, you can "
69
- "assign a name to a class variable `CONFIG_NAME` and reference it here.",
70
- )
71
- options: dict[str, Any] = Field(
72
- default_factory=dict,
73
- description="Options passed to the constructor.",
74
- )
75
55
 
76
56
  def to_object(self, spark: SparkSession) -> _T:
77
57
  type_ = self.type_base.subclass_from_config_name(self.type_config_name)
@@ -97,35 +77,25 @@ _P = TypeVar("_P", bound=BaseSparkSessionProvider)
97
77
  _PX = TypeVar("_PX", bound=BaseProxyAccessLayer)
98
78
 
99
79
 
100
- class ProxyAccessLayerObjectConfig(ForceableConfig, Generic[_PX]):
80
+ class ProxyAccessLayerObjectConfig(BaseObjectConfig, Generic[_PX]):
101
81
  """Config for proxy clients that take no constructor args (e.g. no spark)."""
102
82
 
103
- model_config = ConfigDict(validate_default=True, extra="forbid")
104
83
  type_base: ClassVar[Type[BaseProxyAccessLayer]] = BaseProxyAccessLayer
105
- type_config_name: str = Field(
106
- description="CONFIG_NAME of the proxy client (e.g. 'LocalProxyClient').",
107
- )
108
- options: dict[str, Any] = Field(default_factory=dict)
109
84
 
110
85
  def to_object(self) -> _PX:
111
86
  type_ = self.type_base.subclass_from_config_name(self.type_config_name)
112
87
  return cast(_PX, type_(**self.options))
113
88
 
114
89
 
115
- class SparkProviderConfig(ForceableConfig, Generic[_P]):
116
- model_config = ConfigDict(validate_default=True, extra="forbid")
90
+ class SparkProviderConfig(BaseObjectConfig, Generic[_P]):
117
91
  type_base: ClassVar[Type[BaseSparkSessionProvider]] = BaseSparkSessionProvider
118
- type_config_name: str = Field(
119
- description="CONFIG_NAME of the Spark session provider."
120
- )
121
- options: dict[str, Any] = Field(default_factory=dict)
122
92
 
123
93
  def to_object(self) -> _P:
124
94
  type_ = self.type_base.subclass_from_config_name(self.type_config_name)
125
95
  return cast(_P, type_(**self.options))
126
96
 
127
97
 
128
- class ClientConfig(BaseModel):
98
+ class ClientConfig(BaseConfig):
129
99
  reader_config: Union[AccessLayerObjectConfig[BaseDataCloudReader], None] = None
130
100
  writer_config: Union[AccessLayerObjectConfig[BaseDataCloudWriter], None] = None
131
101
  proxy_config: Union[ProxyAccessLayerObjectConfig[BaseProxyClient], None] = None
@@ -163,31 +133,10 @@ class ClientConfig(BaseModel):
163
133
  )
164
134
  return self
165
135
 
166
- def load(self, config_path: str) -> ClientConfig:
167
- """Load a config from a file and update this config with it.
168
136
 
169
- Args:
170
- config_path: The path to the config file
171
-
172
- Returns:
173
- Self, with updated values from the loaded config.
174
- """
175
- with open(config_path, "r") as f:
176
- config_data = yaml.safe_load(f)
177
- loaded_config = ClientConfig.model_validate(config_data)
178
-
179
- return self.update(loaded_config)
180
-
181
-
182
- config = ClientConfig()
183
137
  """Global config object.
184
138
 
185
139
  This is the object that makes config accessible globally and globally mutable.
186
140
  """
187
-
188
-
189
- def _defaults() -> str:
190
- return os.path.join(os.path.dirname(__file__), DEFAULT_CONFIG_NAME)
191
-
192
-
193
- config.load(_defaults())
141
+ config = ClientConfig()
142
+ config.load(default_config_file())
@@ -23,3 +23,7 @@ proxy_config:
23
23
  type_config_name: LocalProxyClientProvider
24
24
  options:
25
25
  credentials_profile: default
26
+
27
+ einstein_predictions_config:
28
+ type_config_name: DefaultEinsteinPredictions
29
+ options: {}
@@ -0,0 +1,22 @@
1
+ # Copyright (c) 2025, Salesforce, Inc.
2
+ # SPDX-License-Identifier: Apache-2
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from datacustomcode.einstein_predictions.base import EinsteinPredictions
17
+ from datacustomcode.einstein_predictions.impl.default import DefaultEinsteinPredictions
18
+
19
# Public API of the einstein_predictions package: the abstract interface and
# its default implementation, both imported above.
__all__ = [
    "EinsteinPredictions",
    "DefaultEinsteinPredictions",
]
@@ -0,0 +1,32 @@
1
+ # Copyright (c) 2025, Salesforce, Inc.
2
+ # SPDX-License-Identifier: Apache-2
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from abc import ABC, abstractmethod
17
+
18
+ from datacustomcode.einstein_predictions.types import (
19
+ PredictionRequest,
20
+ PredictionResponse,
21
+ )
22
+ from datacustomcode.mixin import UserExtendableNamedConfigMixin
23
+
24
+
25
class EinsteinPredictions(ABC, UserExtendableNamedConfigMixin):
    """Abstract interface for Einstein Predictions backends.

    Concrete subclasses set ``CONFIG_NAME`` (the name a config entry's
    ``type_config_name`` refers to) and implement :meth:`predict`.
    """

    CONFIG_NAME: str

    def __init__(self, **kwargs):
        """Accept arbitrary keyword options; the base class uses none of them."""

    @abstractmethod
    def predict(self, request: PredictionRequest) -> PredictionResponse:
        """Run a prediction for ``request`` and return its response."""
@@ -0,0 +1,35 @@
1
+ # Copyright (c) 2025, Salesforce, Inc.
2
+ # SPDX-License-Identifier: Apache-2
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from datacustomcode.einstein_predictions.base import EinsteinPredictions
17
+ from datacustomcode.einstein_predictions.types import (
18
+ PredictionRequest,
19
+ PredictionResponse,
20
+ )
21
+
22
+
23
class DefaultEinsteinPredictions(EinsteinPredictions):
    """Implementation that returns a fixed, successful response.

    ``CONFIG_NAME`` matches the ``type_config_name`` used in the bundled
    config.yaml. No external service is called.
    """

    CONFIG_NAME = "DefaultEinsteinPredictions"

    def __init__(self, **kwargs):
        # Nothing to configure here; defer to the base initializer.
        super().__init__(**kwargs)

    def predict(self, request: PredictionRequest) -> PredictionResponse:
        """Return a canned v1 response echoing the request's prediction type.

        The payload always holds a single result whose ``predictedValue`` is
        1.0; the request's model name and columns are not inspected.
        """
        canned_results = [{"prediction": {"predictedValue": 1.0}}]
        return PredictionResponse(
            version="v1",
            status_code=200,
            prediction_type=request.prediction_type,
            data={"results": canned_results},
        )
@@ -0,0 +1,184 @@
1
+ # Copyright (c) 2025, Salesforce, Inc.
2
+ # SPDX-License-Identifier: Apache-2
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from enum import Enum, unique
17
+ from typing import (
18
+ Any,
19
+ Dict,
20
+ Literal,
21
+ Optional,
22
+ )
23
+
24
+ from pydantic import (
25
+ BaseModel,
26
+ Field,
27
+ model_validator,
28
+ )
29
+
30
+
31
# Kinds of prediction a request can ask for. Values are explicit integers and
# @unique guarantees no two members share a value.
@unique
class PredictionType(Enum):
    REGRESSION = 1
    CLUSTERING = 2
    CLASSIFICATION = 3
    MULTI_OUTCOME = 4
    BINARY_CLASSIFICATION = 5
38
+
39
+
40
# One named input column of a prediction request. Exactly one of the typed
# ``*_values`` lists must be provided, and a provided list must be non-empty.
class PredictionColumn(BaseModel):
    column_name: str = Field(min_length=1, description="Column name")
    string_values: Optional[list[str]] = Field(
        default=None, min_length=1, description="Column string values"
    )
    double_values: Optional[list[float]] = Field(
        default=None, min_length=1, description="Column double values"
    )
    boolean_values: Optional[list[bool]] = Field(
        default=None, min_length=1, description="Column boolean values"
    )
    date_values: Optional[list[str]] = Field(
        default=None, min_length=1, description="Column date values"
    )
    datetime_values: Optional[list[str]] = Field(
        default=None, min_length=1, description="Column datetime values"
    )

    @model_validator(mode="after")
    def validate_exactly_one_value_type(self):
        """Reject columns that set zero, or more than one, ``*_values`` field."""
        candidates = (
            self.string_values,
            self.double_values,
            self.boolean_values,
            self.date_values,
            self.datetime_values,
        )
        populated = sum(1 for values in candidates if values is not None)
        if populated == 1:
            return self
        raise ValueError("Exactly one value type must be set")
74
+
75
+
76
class PredictionColumBuilder:
    """Fluent builder for ``PredictionColumn``.

    Each ``set_*`` method stores one value and returns the builder, so calls
    can be chained. ``build`` passes everything collected to
    ``PredictionColumn``, whose validation requires a non-empty
    ``column_name`` and exactly one of the ``*_values`` lists.

    Note: the class name is misspelled ("Colum"). It is kept for backward
    compatibility, and ``PredictionColumnBuilder`` is a correctly spelled alias.
    """

    def __init__(self) -> None:
        self._column_name: Optional[str] = None
        self._string_values: Optional[list[str]] = None
        self._double_values: Optional[list[float]] = None
        self._boolean_values: Optional[list[bool]] = None
        self._date_values: Optional[list[str]] = None
        self._datetime_values: Optional[list[str]] = None

    def set_column_name(self, column_name: str) -> "PredictionColumBuilder":
        """Set the column name (validated as non-empty by ``build``)."""
        self._column_name = column_name
        return self

    def set_string_values(self, string_values: list[str]) -> "PredictionColumBuilder":
        """Set string values for the column."""
        self._string_values = string_values
        return self

    def set_double_values(self, double_values: list[float]) -> "PredictionColumBuilder":
        """Set floating-point values for the column."""
        self._double_values = double_values
        return self

    def set_boolean_values(
        self, boolean_values: list[bool]
    ) -> "PredictionColumBuilder":
        """Set boolean values for the column."""
        self._boolean_values = boolean_values
        return self

    def set_date_values(self, date_values: list[str]) -> "PredictionColumBuilder":
        """Set date values (as strings) for the column."""
        self._date_values = date_values
        return self

    def set_datetime_values(
        self, datetime_values: list[str]
    ) -> "PredictionColumBuilder":
        """Set datetime values (as strings) for the column."""
        self._datetime_values = datetime_values
        return self

    def build(self) -> "PredictionColumn":
        """Create a ``PredictionColumn`` from the values set so far.

        Raises:
            pydantic.ValidationError: If the name is missing or empty, if not
                exactly one ``*_values`` list was set, or if a set list is empty.
        """
        return PredictionColumn(
            column_name=self._column_name,
            string_values=self._string_values,
            double_values=self._double_values,
            boolean_values=self._boolean_values,
            date_values=self._date_values,
            datetime_values=self._datetime_values,
        )


# Correctly spelled alias; both names refer to the same class.
PredictionColumnBuilder = PredictionColumBuilder
122
+
123
+
124
# Input to EinsteinPredictions.predict: the API version, the kind of
# prediction, the model to use, at least one input column, and optional
# free-form settings.
class PredictionRequest(BaseModel):
    # ``model_api_name`` starts with "model_", which older pydantic v2
    # releases reserve as a protected namespace and warn about at class
    # creation. Disable the protected namespaces for this model.
    model_config = {"protected_namespaces": ()}

    version: Literal["v1"] = Field(
        default="v1", description="API version, must be 'v1'"
    )
    prediction_type: PredictionType = Field(description="Prediction type")
    model_api_name: str = Field(
        min_length=1, description="API name of the model to use"
    )
    prediction_columns: list[PredictionColumn] = Field(
        min_length=1, description="List of prediction columns"
    )
    settings: Optional[Dict[str, Any]] = Field(
        default=None, description="Settings for the prediction request"
    )
138
+
139
+
140
class PredictionRequestBuilder:
    """Fluent builder for ``PredictionRequest``.

    Every setter returns the builder for chaining. ``build`` leaves
    ``version`` at the model default and relies on ``PredictionRequest``
    validation to reject a missing prediction type, a missing or empty model
    API name, or an empty column list.
    """

    def __init__(self) -> None:
        self._prediction_type: Optional[PredictionType] = None
        self._model_api_name: Optional[str] = None
        self._prediction_columns: list[PredictionColumn] = []
        self._settings: Optional[Dict[str, Any]] = None

    def set_prediction_type(
        self, prediction_type: "PredictionType"
    ) -> "PredictionRequestBuilder":
        """Set the kind of prediction to request."""
        self._prediction_type = prediction_type
        return self

    def set_model_api_name(self, model_api_name: str) -> "PredictionRequestBuilder":
        """Set the API name of the model to call."""
        self._model_api_name = model_api_name
        return self

    def set_prediction_columns(
        self, prediction_columns: "list[PredictionColumn]"
    ) -> "PredictionRequestBuilder":
        """Replace the collected columns with ``prediction_columns``.

        The list is stored as given (not copied).
        """
        self._prediction_columns = prediction_columns
        return self

    def add_prediction_column(
        self, prediction_column: "PredictionColumn"
    ) -> "PredictionRequestBuilder":
        """Append one column to those collected so far.

        A new list is built, so a list previously passed to
        ``set_prediction_columns`` is never mutated.
        """
        self._prediction_columns = [*self._prediction_columns, prediction_column]
        return self

    def set_settings(self, settings: Dict[str, Any]) -> "PredictionRequestBuilder":
        """Set the optional settings mapping for the request."""
        self._settings = settings
        return self

    def build(self) -> "PredictionRequest":
        """Create a ``PredictionRequest`` from the values set so far.

        Raises:
            pydantic.ValidationError: If the collected values fail the
                request model's validation.
        """
        return PredictionRequest(
            prediction_type=self._prediction_type,
            model_api_name=self._model_api_name,
            prediction_columns=self._prediction_columns,
            settings=self._settings,
        )
174
+
175
+
176
# Result of a prediction call: API version, the echoed prediction type, the
# HTTP status code, and an optional free-form ``data`` payload.
class PredictionResponse(BaseModel):
    version: Literal["v1"] = Field(default="v1", description="API version")
    prediction_type: PredictionType = Field(description="Prediction type")
    status_code: int = Field(description="HTTP status code")
    data: Optional[Dict[str, Any]] = Field(
        default=None, description="Response data"
    )

    @property
    def is_success(self) -> bool:
        """True only when ``status_code`` is exactly 200."""
        succeeded = self.status_code == 200
        return succeeded
@@ -0,0 +1,67 @@
1
+ # Copyright (c) 2025, Salesforce, Inc.
2
+ # SPDX-License-Identifier: Apache-2
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import (
17
+ ClassVar,
18
+ Generic,
19
+ Type,
20
+ TypeVar,
21
+ Union,
22
+ cast,
23
+ )
24
+
25
+ from datacustomcode.common_config import (
26
+ BaseConfig,
27
+ BaseObjectConfig,
28
+ default_config_file,
29
+ )
30
+ from datacustomcode.einstein_predictions.base import EinsteinPredictions
31
+
32
_E = TypeVar("_E", bound=EinsteinPredictions)


# Config entry selecting an EinsteinPredictions implementation by name.
# ``type_config_name``, ``options`` and ``force`` come from BaseObjectConfig.
class EinsteinPredictionsObjectConfig(BaseObjectConfig, Generic[_E]):
    type_base: ClassVar[Type[EinsteinPredictions]] = EinsteinPredictions  # type: ignore[type-abstract]

    def to_object(self) -> _E:
        """Instantiate the configured implementation.

        Resolves ``type_config_name`` through
        ``type_base.subclass_from_config_name`` and calls the resulting class
        with ``options`` as keyword arguments.
        """
        implementation = self.type_base.subclass_from_config_name(
            self.type_config_name
        )
        instance = implementation(**self.options)
        return cast(_E, instance)
41
+
42
+
43
# Top-level config model holding the optional ``einstein_predictions_config``
# entry.
class EinsteinPredictionsConfig(BaseConfig):
    einstein_predictions_config: Union[
        EinsteinPredictionsObjectConfig[EinsteinPredictions], None
    ] = None

    def update(
        self, other: "EinsteinPredictionsConfig"
    ) -> "EinsteinPredictionsConfig":
        """Merge ``other`` into this config in place and return ``self``.

        Precedence for ``einstein_predictions_config``:
        1. the current entry, if it is set with ``force=True``;
        2. otherwise ``other``'s entry, if it is set;
        3. otherwise the current entry is kept.
        """
        current = self.einstein_predictions_config
        incoming = other.einstein_predictions_config
        if current is not None and current.force:
            chosen = current
        elif incoming:
            chosen = incoming
        else:
            chosen = current
        self.einstein_predictions_config = chosen
        return self
63
+
64
+
65
# Global Einstein Predictions config instance, populated at import time from
# the bundled config file (default_config_file()). That file also holds other
# top-level sections (e.g. proxy_config) that are not fields of
# EinsteinPredictionsConfig; presumably ignored by pydantic's default extra
# handling — NOTE(review): confirm BaseConfig does not forbid extra keys.
einstein_predictions_config = EinsteinPredictionsConfig()
einstein_predictions_config.load(default_config_file())
@@ -17,6 +17,8 @@
17
17
  import threading
18
18
  from typing import Optional
19
19
 
20
+ from datacustomcode.einstein_predictions.base import EinsteinPredictions
21
+ from datacustomcode.einstein_predictions_config import einstein_predictions_config
20
22
  from datacustomcode.file.path.default import DefaultFindFilePath
21
23
  from datacustomcode.function.base import BaseRuntime
22
24
  from datacustomcode.llm_gateway.default import DefaultLLMGateway
@@ -65,6 +67,7 @@ class Runtime(BaseRuntime):
65
67
  # Initialize resources
66
68
  self._llm_gateway = DefaultLLMGateway()
67
69
  self._file = DefaultFindFilePath()
70
+ self._einstein_predictions: Optional[EinsteinPredictions] = None
68
71
 
69
72
  @property
70
73
  def llm_gateway(self) -> DefaultLLMGateway:
@@ -75,3 +78,16 @@ class Runtime(BaseRuntime):
75
78
  def file(self) -> DefaultFindFilePath:
76
79
  """Access file operations."""
77
80
  return self._file
81
+
82
+ @property
83
+ def einstein_predictions(self) -> EinsteinPredictions:
84
+ if self._einstein_predictions is None:
85
+ if einstein_predictions_config.einstein_predictions_config is None:
86
+ raise RuntimeError(
87
+ "Einstein Predictions is not configured. Add "
88
+ "'einstein_predictions_config' section to config.yaml"
89
+ )
90
+ self._einstein_predictions = (
91
+ einstein_predictions_config.einstein_predictions_config.to_object()
92
+ )
93
+ return self._einstein_predictions
@@ -14,9 +14,11 @@
14
14
  # limitations under the License.
15
15
  from __future__ import annotations
16
16
 
17
- from abc import abstractmethod
17
+ from abc import ABC, abstractmethod
18
18
  from typing import TYPE_CHECKING
19
19
 
20
+ from datacustomcode.mixin import UserExtendableNamedConfigMixin
21
+
20
22
  if TYPE_CHECKING:
21
23
  from datacustomcode.llm_gateway.types.generate_text_request import (
22
24
  GenerateTextRequest,
@@ -26,8 +28,10 @@ if TYPE_CHECKING:
26
28
  )
27
29
 
28
30
 
29
- class LLMGateway:
30
- def __init__(self) -> None:
31
+ class LLMGateway(ABC, UserExtendableNamedConfigMixin):
32
+ CONFIG_NAME: str
33
+
34
+ def __init__(self, **kwargs):
31
35
  pass
32
36
 
33
37
  @abstractmethod
@@ -22,6 +22,8 @@ from datacustomcode.llm_gateway.types.generate_text_response_builder import (
22
22
 
23
23
 
24
24
  class DefaultLLMGateway(LLMGateway):
25
+ CONFIG_NAME = "DefaultLLMGateway"
26
+
25
27
  def generate_text(self, request: GenerateTextRequest) -> GenerateTextResponse:
26
28
 
27
29
  response_data = {
@@ -0,0 +1,65 @@
1
+ # Copyright (c) 2025, Salesforce, Inc.
2
+ # SPDX-License-Identifier: Apache-2
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import (
17
+ ClassVar,
18
+ Generic,
19
+ Type,
20
+ TypeVar,
21
+ Union,
22
+ cast,
23
+ )
24
+
25
+ from datacustomcode.common_config import (
26
+ BaseConfig,
27
+ BaseObjectConfig,
28
+ default_config_file,
29
+ )
30
+ from datacustomcode.llm_gateway.base import LLMGateway
31
+
32
+ _E = TypeVar("_E", bound=LLMGateway)
33
+
34
+
35
+ class LLMGatewayObjectConfig(BaseObjectConfig, Generic[_E]):
36
+ type_base: ClassVar[Type[LLMGateway]] = LLMGateway # type: ignore[type-abstract]
37
+
38
+ def to_object(self) -> _E:
39
+ type_ = self.type_base.subclass_from_config_name(self.type_config_name)
40
+ return cast(_E, type_(**self.options))
41
+
42
+
43
+ class LLMGatewayConfig(BaseConfig):
44
+ llm_gateway_config: Union[LLMGatewayObjectConfig[LLMGateway], None] = None
45
+
46
+ def update(self, other: "LLMGatewayConfig") -> "LLMGatewayConfig":
47
+ def merge(
48
+ config_a: Union[LLMGatewayObjectConfig, None],
49
+ config_b: Union[LLMGatewayObjectConfig, None],
50
+ ) -> Union[LLMGatewayObjectConfig, None]:
51
+ if config_a is not None and config_a.force:
52
+ return config_a
53
+ if config_b:
54
+ return config_b
55
+ return config_a
56
+
57
+ self.llm_gateway_config = merge(
58
+ self.llm_gateway_config, other.llm_gateway_config
59
+ )
60
+ return self
61
+
62
+
63
+ # Global LLM Gateway config instance
64
+ llm_gateway_config = LLMGatewayConfig()
65
+ llm_gateway_config.load(default_config_file())
@@ -2,6 +2,11 @@ import logging
2
2
  from typing import List
3
3
  from uuid import uuid4
4
4
 
5
+ from datacustomcode.einstein_predictions.types import (
6
+ PredictionColumBuilder,
7
+ PredictionRequestBuilder,
8
+ PredictionType,
9
+ )
5
10
  from datacustomcode.function import Runtime
6
11
  from datacustomcode.llm_gateway.types.generate_text_request_builder import (
7
12
  GenerateTextRequestBuilder,
@@ -38,6 +43,28 @@ def chunk_text(text: str, chunk_size: int = 1000) -> List[str]:
38
43
  return chunks
39
44
 
40
45
 
46
+ def make_einstein_prediction(runtime: Runtime) -> None:
47
+ column = (
48
+ PredictionColumBuilder()
49
+ .set_column_name("col1")
50
+ .set_string_values(["str1", "str2"])
51
+ .build()
52
+ )
53
+ prediction_request = (
54
+ PredictionRequestBuilder()
55
+ .set_prediction_type(PredictionType.REGRESSION)
56
+ .set_model_api_name("regressionModel")
57
+ .set_prediction_columns([column])
58
+ .build()
59
+ )
60
+
61
+ prediction_response = runtime.einstein_predictions.predict(prediction_request)
62
+ print(
63
+ f"Einstein prediction results - success: {prediction_response.is_success} \
64
+ response data: {prediction_response.data}"
65
+ )
66
+
67
+
41
68
  def function(request: dict, runtime: Runtime) -> dict:
42
69
  logger.info("Inside Function")
43
70
  logger.info(request)
@@ -46,6 +73,8 @@ def function(request: dict, runtime: Runtime) -> dict:
46
73
  output_chunks = []
47
74
  current_seq_no = 1 # Start sequence number from 1
48
75
 
76
+ make_einstein_prediction(runtime)
77
+
49
78
  builder = GenerateTextRequestBuilder()
50
79
  llm_request = builder.set_prompt("Hello").set_model("modelName").build()
51
80
  llm_response = runtime.llm_gateway.generate_text(llm_request)