salesforce-data-customcode 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35) hide show
  1. datacustomcode/__init__.py +20 -0
  2. datacustomcode/cli.py +142 -0
  3. datacustomcode/client.py +227 -0
  4. datacustomcode/cmd.py +105 -0
  5. datacustomcode/config.py +149 -0
  6. datacustomcode/config.yaml +15 -0
  7. datacustomcode/credentials.py +97 -0
  8. datacustomcode/deploy.py +379 -0
  9. datacustomcode/io/__init__.py +14 -0
  10. datacustomcode/io/base.py +28 -0
  11. datacustomcode/io/reader/__init__.py +14 -0
  12. datacustomcode/io/reader/base.py +34 -0
  13. datacustomcode/io/reader/query_api.py +115 -0
  14. datacustomcode/io/writer/__init__.py +14 -0
  15. datacustomcode/io/writer/base.py +49 -0
  16. datacustomcode/io/writer/csv.py +41 -0
  17. datacustomcode/io/writer/print.py +33 -0
  18. datacustomcode/mixin.py +94 -0
  19. datacustomcode/py.typed +0 -0
  20. datacustomcode/run.py +47 -0
  21. datacustomcode/scan.py +153 -0
  22. datacustomcode/template.py +36 -0
  23. datacustomcode/templates/.devcontainer/devcontainer.json +10 -0
  24. datacustomcode/templates/Dockerfile +20 -0
  25. datacustomcode/templates/README.md +0 -0
  26. datacustomcode/templates/jupyterlab.sh +97 -0
  27. datacustomcode/templates/payload/config.json +1 -0
  28. datacustomcode/templates/payload/entrypoint.py +10 -0
  29. datacustomcode/templates/requirements-dev.txt +10 -0
  30. datacustomcode/templates/requirements.txt +1 -0
  31. salesforce_data_customcode-0.1.0.dist-info/LICENSE.txt +206 -0
  32. salesforce_data_customcode-0.1.0.dist-info/METADATA +159 -0
  33. salesforce_data_customcode-0.1.0.dist-info/RECORD +35 -0
  34. salesforce_data_customcode-0.1.0.dist-info/WHEEL +4 -0
  35. salesforce_data_customcode-0.1.0.dist-info/entry_points.txt +5 -0
@@ -0,0 +1,20 @@
1
+ # Copyright (c) 2025, Salesforce, Inc.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from datacustomcode.client import Client
17
+ from datacustomcode.io.reader.query_api import QueryAPIDataCloudReader
18
+ from datacustomcode.io.writer.print import PrintDataCloudWriter
19
+
20
+ __all__ = ["Client", "QueryAPIDataCloudReader", "PrintDataCloudWriter"]
datacustomcode/cli.py ADDED
@@ -0,0 +1,142 @@
1
+ # Copyright (c) 2025, Salesforce, Inc.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from importlib import metadata
16
+ import json
17
+ import os
18
+ import sys
19
+ from typing import List, Union
20
+
21
+ import click
22
+ from loguru import logger
23
+
24
+
25
@click.group()
@click.option("--debug", is_flag=True)
def cli(debug: bool):
    """Top-level command group; sets log verbosity for every subcommand."""
    # Replace loguru's default handler with one at the requested level.
    logger.remove()
    level = "DEBUG" if debug else "INFO"
    logger.configure(handlers=[{"sink": sys.stderr, "level": level}])
33
+
34
+
35
@cli.command()
def version():
    """Display the current version of the package."""
    # BUG FIX: removed a stray debug `print(__name__)` left in the released
    # code; it printed the module name before the version on every call.
    try:
        # Local renamed from `version` to avoid shadowing the command name.
        pkg_version = metadata.version("salesforce-data-customcode")
        click.echo(f"salesforce-data-customcode version: {pkg_version}")
    except metadata.PackageNotFoundError:
        # Package metadata missing (e.g. running from a source checkout).
        click.echo("Version information not available")
44
+
45
+
46
@cli.command()
@click.option("--profile", default="default")
@click.option("--username", prompt=True)
@click.option("--password", prompt=True, hide_input=True)
@click.option("--client-id", prompt=True)
@click.option("--client-secret", prompt=True)
@click.option("--login-url", prompt=True)
def configure(
    username: str,
    password: str,
    client_id: str,
    client_secret: str,
    login_url: str,
    profile: str,
) -> None:
    """Collect Salesforce credentials and persist them under *profile*."""
    from datacustomcode.credentials import Credentials

    credentials = Credentials(
        username=username,
        password=password,
        client_id=client_id,
        client_secret=client_secret,
        login_url=login_url,
    )
    # Writes/updates the named profile section in the credentials ini file.
    credentials.update_ini(profile=profile)
70
+
71
+
72
@cli.command()
@click.option("--profile", default="default")
@click.option("--path", default="payload")
@click.option("--name", required=True)
@click.option("--version", default="0.0.1")
@click.option("--description", default="Custom Data Transform Code")
def deploy(profile: str, path: str, name: str, version: str, description: str):
    """Deploy the custom-code payload at *path* to Data Cloud.

    Loads credentials for *profile* from credentials.ini and aborts with a
    readable error when the profile is missing.
    """
    from datacustomcode.credentials import Credentials
    from datacustomcode.deploy import TransformationJobMetadata, deploy_full

    logger.debug("Deploying project")

    metadata = TransformationJobMetadata(
        name=name,
        version=version,
        description=description,
    )
    try:
        credentials = Credentials.from_ini(profile=profile)
    except KeyError:
        # BUG FIX: message previously ran the words together as
        # "credentialsprofile"; restored the missing space.
        click.secho(
            f"Error: Profile {profile} not found in credentials.ini. "
            "Run `datacustomcode configure` to create a credentials profile.",
            fg="red",
        )
        raise click.Abort() from None
    deploy_full(path, metadata, credentials)
99
+
100
+
101
@cli.command()
@click.argument("directory", default=".")
def init(directory: str):
    """Bootstrap *directory* with the project template files."""
    from datacustomcode.template import copy_template

    styled_directory = click.style(directory, fg="blue", bold=True)
    click.echo("Copying template to " + styled_directory)
    copy_template(directory)
    styled_entrypoint = click.style(
        f"{directory}/payload/entrypoint.py", fg="blue", bold=True
    )
    click.echo("Start developing by updating the code in " + styled_entrypoint)
112
+
113
+
114
@cli.command()
@click.argument("filename")
@click.option("--config")
@click.option("--dry-run", is_flag=True)
def scan(filename: str, config: str, dry_run: bool):
    """Scan *filename* for Data Cloud access and emit a config.json."""
    from datacustomcode.scan import dc_config_json_from_file

    # Default the output location to a config.json beside the scanned file.
    if config:
        config_location = config
    else:
        config_location = os.path.join(os.path.dirname(filename), "config.json")
    click.echo(
        "Dumping scan results to config file: "
        + click.style(config_location, fg="blue", bold=True)
    )
    click.echo("Scanning " + click.style(filename, fg="blue", bold=True) + "...")
    config_json = dc_config_json_from_file(filename)

    click.secho(json.dumps(config_json, indent=2), fg="yellow")
    if dry_run:
        return
    with open(config_location, "w") as f:
        json.dump(config_json, f, indent=2)
133
+
134
+
135
@cli.command()
@click.argument("entrypoint")
@click.option("--config-file", default=None)
@click.option("--dependencies", default=[], multiple=True)
def run(entrypoint: str, config_file: Union[str, None], dependencies: List[str]):
    """Execute an entrypoint script locally with the given config and deps."""
    from datacustomcode.run import run_entrypoint

    run_entrypoint(entrypoint, config_file, dependencies)
@@ -0,0 +1,227 @@
1
+ # Copyright (c) 2025, Salesforce, Inc.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from __future__ import annotations
16
+
17
+ from enum import Enum
18
+ from typing import (
19
+ TYPE_CHECKING,
20
+ ClassVar,
21
+ Optional,
22
+ )
23
+
24
+ from pyspark.sql import SparkSession
25
+
26
+ from datacustomcode.config import SparkConfig, config
27
+ from datacustomcode.io.reader.base import BaseDataCloudReader
28
+
29
+ if TYPE_CHECKING:
30
+ from pyspark.sql import DataFrame as PySparkDataFrame
31
+
32
+ from datacustomcode.io.reader.base import BaseDataCloudReader
33
+ from datacustomcode.io.writer.base import BaseDataCloudWriter, WriteMode
34
+
35
+
36
def _setup_spark(spark_config: SparkConfig) -> SparkSession:
    """Build (or retrieve) a SparkSession described by *spark_config*.

    Applies the master URL first (when configured), then the application
    name, then any extra key/value builder options.
    """
    session_builder = SparkSession.builder
    if spark_config.master is not None:
        session_builder = session_builder.master(spark_config.master)

    session_builder = session_builder.appName(spark_config.app_name)
    for option_key, option_value in spark_config.options.items():
        session_builder = session_builder.config(option_key, option_value)
    return session_builder.getOrCreate()
46
+
47
+
48
class DataCloudObjectType(Enum):
    """The two kinds of Data Cloud objects tracked by the client."""

    DLO = "dlo"
    DMO = "dmo"
51
+
52
+
53
class DataCloudAccessLayerException(Exception):
    """Exception raised when mixing DMOs and DLOs is detected."""

    def __init__(
        self,
        data_layer_history: dict[DataCloudObjectType, set[str]],
        should_not_contain: DataCloudObjectType,
    ) -> None:
        # Names of every DLO/DMO read so far, keyed by object type.
        self.data_layer_history = data_layer_history
        # The object type whose read history should have been empty.
        self.should_not_contain = should_not_contain

    def __str__(self) -> str:
        msg = (
            "Mixed use of DMOs and DLOs. "
            "You can only read from DMOs to write to DMOs "
            "and read from DLOs to write to DLOs. "
        )
        if self.should_not_contain is DataCloudObjectType.DLO:
            msg += (
                "You have read from the following DLOs: "
                f"{self.data_layer_history[DataCloudObjectType.DLO]} "
                f"and are attempting to write to DMO. "
            )
        else:
            # BUG FIX: message previously read "write to to a DLO"
            # (doubled "to").
            msg += (
                "You have read from the following DMOs: "
                f"{self.data_layer_history[DataCloudObjectType.DMO]} "
                f"and are attempting to write to a DLO. "
            )
        msg += "Restart to clear history."
        return msg
84
+
85
+
86
class Client:
    """Entrypoint for accessing DataCloud objects.

    This is the object used to access Data Cloud DLOs and DMOs. Accessing DLOs/DMOs
    are tracked and will throw an exception if they are mixed. In other words, you
    can read from DLOs and write to DLOs, read from DMOs and write to DMOs, but you
    cannot read from DLOs and write to DMOs or read from DMOs and write to DLOs.
    Furthermore you cannot mix during merging tables. This class is a singleton to
    prevent accidental mixing of DLOs and DMOs.

    You can provide custom readers and writers to the client for advanced use
    cases, but this is not recommended for testing as they may result in unexpected
    behavior once deployed to Data Cloud. By default, the client intercepts all
    read/write operations and mocks access to Data Cloud. For example, during
    writing, we print to the console instead of writing to Data Cloud.

    Args:
        reader: A custom reader to use for reading Data Cloud objects.
        writer: A custom writer to use for writing Data Cloud objects.

    Example:
        >>> client = Client()
        >>> dlo = client.read_dlo("my_dlo")
        >>> client.write_to_dmo("my_dmo", dlo)
    """

    # Singleton instance; created on first __new__ call and reused afterwards.
    _instance: ClassVar[Optional[Client]] = None
    _reader: BaseDataCloudReader
    _writer: BaseDataCloudWriter
    # Names of every DLO/DMO read so far, used to detect DLO/DMO mixing.
    _data_layer_history: dict[DataCloudObjectType, set[str]]

    def __new__(
        cls,
        reader: Optional[BaseDataCloudReader] = None,
        writer: Optional[BaseDataCloudWriter] = None,
    ) -> Client:
        if cls._instance is None:
            cls._instance = super().__new__(cls)

            # Initialize Readers and Writers from config
            # and/or provided reader and writer
            if reader is None or writer is None:
                # We need a spark because we will initialize readers and writers
                if config.spark_config is None:
                    raise ValueError(
                        "Spark config is required when reader/writer is not provided"
                    )
                spark = _setup_spark(config.spark_config)

                # Precedence: a config entry with force=True overrides a
                # caller-provided reader; otherwise the argument wins.
                if config.reader_config is None and reader is None:
                    raise ValueError(
                        "Reader config is required when reader is not provided"
                    )
                elif reader is None or (
                    config.reader_config is not None and config.reader_config.force
                ):
                    reader_init = config.reader_config.to_object(spark)  # type: ignore
                else:
                    reader_init = reader
                # Same precedence rules for the writer.
                if config.writer_config is None and writer is None:
                    raise ValueError(
                        "Writer config is required when writer is not provided"
                    )
                elif writer is None or (
                    config.writer_config is not None and config.writer_config.force
                ):
                    writer_init = config.writer_config.to_object(spark)  # type: ignore
                else:
                    writer_init = writer
            cls._instance._reader = reader_init
            cls._instance._writer = writer_init
            # Start with empty access history for both object types.
            cls._instance._data_layer_history = {
                DataCloudObjectType.DLO: set(),
                DataCloudObjectType.DMO: set(),
            }
        # Reached only when the singleton already exists: late reader/writer
        # injection is rejected to keep behavior deterministic.
        elif (reader is not None or writer is not None) and cls._instance is not None:
            raise ValueError("Cannot set reader or writer after client is initialized")
        return cls._instance

    def read_dlo(self, name: str) -> PySparkDataFrame:
        """Read a DLO from Data Cloud.

        Args:
            name: The name of the DLO to read.

        Returns:
            A PySpark DataFrame containing the DLO data.
        """
        # Record before reading so mixing is caught even if the read fails later.
        self._record_dlo_access(name)
        return self._reader.read_dlo(name)

    def read_dmo(self, name: str) -> PySparkDataFrame:
        """Read a DMO from Data Cloud.

        Args:
            name: The name of the DMO to read.

        Returns:
            A PySpark DataFrame containing the DMO data.
        """
        self._record_dmo_access(name)
        return self._reader.read_dmo(name)

    def write_to_dlo(
        self, name: str, dataframe: PySparkDataFrame, write_mode: WriteMode, **kwargs
    ) -> None:
        """Write a PySpark DataFrame to a DLO in Data Cloud.

        Args:
            name: The name of the DLO to write to.
            dataframe: The PySpark DataFrame to write.
            write_mode: The write mode to use for writing to the DLO.

        Raises:
            DataCloudAccessLayerException: If any DMO has been read before.
        """
        self._validate_data_layer_history_does_not_contain(DataCloudObjectType.DMO)
        return self._writer.write_to_dlo(name, dataframe, write_mode, **kwargs)

    def write_to_dmo(
        self, name: str, dataframe: PySparkDataFrame, write_mode: WriteMode, **kwargs
    ) -> None:
        """Write a PySpark DataFrame to a DMO in Data Cloud.

        Args:
            name: The name of the DMO to write to.
            dataframe: The PySpark DataFrame to write.
            write_mode: The write mode to use for writing to the DMO.

        Raises:
            DataCloudAccessLayerException: If any DLO has been read before.
        """
        self._validate_data_layer_history_does_not_contain(DataCloudObjectType.DLO)
        return self._writer.write_to_dmo(name, dataframe, write_mode, **kwargs)

    def _validate_data_layer_history_does_not_contain(
        self, data_cloud_object_type: DataCloudObjectType
    ) -> None:
        # Raise if any object of the forbidden type has already been read.
        if len(self._data_layer_history[data_cloud_object_type]) > 0:
            raise DataCloudAccessLayerException(
                self._data_layer_history, data_cloud_object_type
            )

    def _record_dlo_access(self, name: str) -> None:
        # Remember this DLO name for later mixing validation.
        self._data_layer_history[DataCloudObjectType.DLO].add(name)

    def _record_dmo_access(self, name: str) -> None:
        # Remember this DMO name for later mixing validation.
        self._data_layer_history[DataCloudObjectType.DMO].add(name)
datacustomcode/cmd.py ADDED
@@ -0,0 +1,105 @@
1
+ # Copyright (c) 2025, Salesforce, Inc.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ This module is shamelessly copied from conda to nicely wrap subprocess calls.
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ import contextlib
22
+ import subprocess
23
+ from typing import Any, Union
24
+
25
+
26
+ def _force_bytes(exc: Any) -> bytes:
27
+ with contextlib.suppress(TypeError):
28
+ return bytes(exc)
29
+ with contextlib.suppress(Exception):
30
+ return str(exc).encode()
31
+ return f"<unprintable {type(exc).__name__} object>".encode()
32
+
33
+
34
+ def _setdefault_kwargs(kwargs: dict[str, Any]) -> None:
35
+ for arg in ("stdin", "stdout", "stderr"):
36
+ kwargs.setdefault(arg, subprocess.PIPE)
37
+
38
+
39
def _oserror_to_output(e: OSError) -> tuple[int, bytes, None]:
    """Translate an OSError from Popen into a (returncode, stdout, stderr) triple."""
    # Normalize to exactly one trailing newline for display.
    message = _force_bytes(e).rstrip(b"\n") + b"\n"
    return 1, message, None
41
+
42
+
43
class CalledProcessError(RuntimeError):
    """Nicely formatted subprocess call error."""

    def __init__(
        self,
        returncode: int,
        cmd: tuple[str, ...],
        stdout: bytes,
        stderr: Union[bytes, None],
    ) -> None:
        super().__init__(returncode, cmd, stdout, stderr)
        self.returncode = returncode
        self.cmd = cmd
        self.stdout = stdout
        self.stderr = stderr

    def __bytes__(self) -> bytes:
        def _indent_or_none(part: Union[bytes, None]) -> bytes:
            # Indent a captured stream for display, or mark it as absent.
            if not part:
                return b" (none)"
            return b"\n " + part.replace(b"\n", b"\n ").rstrip()

        pieces = [
            f"command: {self.cmd!r}\n".encode(),
            f"return code: {self.returncode}\n".encode(),
            b"stdout:",
            self.stdout,
            b"\n",
            b"stderr:",
            _indent_or_none(self.stderr),
        ]
        return b"".join(pieces)

    def __str__(self) -> str:
        return bytes(self).decode()
80
+
81
+
82
def _cmd_output(
    *cmd: str,
    check: bool = True,
    **kwargs: Any,
) -> tuple[int, bytes, Union[bytes, None]]:
    """Run *cmd* and capture its output.

    Args:
        *cmd: Command and arguments, passed to Popen as an argument list.
        check: When True, raise CalledProcessError on a nonzero exit status.
        **kwargs: Extra Popen keyword arguments; std streams default to PIPE.

    Returns:
        A (returncode, stdout_bytes, stderr_bytes_or_None) triple.

    Raises:
        CalledProcessError: If check is True and the command exits nonzero.
    """
    _setdefault_kwargs(kwargs)
    try:
        # BUG FIX: the original forced kwargs.setdefault("shell", True) while
        # passing *cmd as a sequence; on POSIX the shell then executes only
        # cmd[0] and silently drops the remaining arguments (and a shell
        # invites injection). Run without a shell by default; callers that
        # genuinely need one can still pass shell=True explicitly.
        proc = subprocess.Popen(cmd, **kwargs)
    except OSError as e:
        # Could not even start the process: synthesize an error triple.
        returncode, stdout_b, stderr_b = _oserror_to_output(e)
    else:
        stdout_b, stderr_b = proc.communicate()
        returncode = proc.returncode
    if check and returncode:
        raise CalledProcessError(returncode, cmd, stdout_b, stderr_b)

    return returncode, stdout_b, stderr_b
100
+
101
+
102
def cmd_output(*cmd: str, **kwargs: Any) -> Union[str, None]:
    """Run *cmd* and return its decoded stdout (None when stdout not captured)."""
    _, stdout_b, _ = _cmd_output(*cmd, **kwargs)
    return None if stdout_b is None else stdout_b.decode()
@@ -0,0 +1,149 @@
1
+ # Copyright (c) 2025, Salesforce, Inc.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ from __future__ import annotations
16
+
17
+ import os
18
+ from typing import (
19
+ TYPE_CHECKING,
20
+ Any,
21
+ ClassVar,
22
+ Generic,
23
+ Type,
24
+ TypeVar,
25
+ Union,
26
+ cast,
27
+ )
28
+
29
+ from pydantic import (
30
+ BaseModel,
31
+ ConfigDict,
32
+ Field,
33
+ )
34
+ import yaml
35
+
36
+ # This lets all readers and writers to be findable via config
37
+ from datacustomcode.io import * # noqa: F403
38
+ from datacustomcode.io.base import BaseDataAccessLayer
39
+ from datacustomcode.io.reader.base import BaseDataCloudReader # noqa: TCH001
40
+ from datacustomcode.io.writer.base import BaseDataCloudWriter # noqa: TCH001
41
+
42
+ DEFAULT_CONFIG_NAME = "config.yaml"
43
+
44
+
45
+ if TYPE_CHECKING:
46
+ from pyspark.sql import SparkSession
47
+
48
+
49
class ForceableConfig(BaseModel):
    """Base model for config sections whose values can override caller input."""

    force: bool = Field(
        default=False,
        description="If True, this takes precedence over parameters passed to the "
        "initializer of the client.",
    )
55
+
56
+
57
+ _T = TypeVar("_T", bound="BaseDataAccessLayer")
58
+
59
+
60
class AccessLayerObjectConfig(ForceableConfig, Generic[_T]):
    """Config describing how to build a reader/writer by its registered name."""

    model_config = ConfigDict(validate_default=True, extra="forbid")
    # Root type whose registry is searched via subclass_from_config_name.
    type_base: ClassVar[Type[BaseDataAccessLayer]] = BaseDataAccessLayer
    type_config_name: str = Field(
        # BUG FIX: description previously read "this would might be"
        # (doubled modal verb).
        description="The config name of the object to create. "
        "For metrics, this might be 'ipmnormal'. For custom classes, you can "
        "assign a name to a class variable `CONFIG_NAME` and reference it here.",
    )
    options: dict[str, Any] = Field(
        default_factory=dict,
        description="Options passed to the constructor.",
    )

    def to_object(self, spark: SparkSession) -> _T:
        """Instantiate the configured type, passing *spark* plus ``options``."""
        type_ = self.type_base.subclass_from_config_name(self.type_config_name)
        return cast(_T, type_(spark=spark, **self.options))
76
+
77
+
78
class SparkConfig(ForceableConfig):
    """Settings used by the client to construct its SparkSession."""

    app_name: str = Field(
        description="The name of the Spark application.",
    )
    master: Union[str, None] = Field(
        default=None,
        description="The Spark master URL.",
    )
    options: dict[str, Any] = Field(
        default_factory=dict,
        description="Options passed to the SparkSession constructor.",
    )
90
+
91
+
92
class ClientConfig(BaseModel):
    """Top-level client configuration: reader, writer and Spark sections."""

    reader_config: Union[AccessLayerObjectConfig[BaseDataCloudReader], None] = None
    writer_config: Union[AccessLayerObjectConfig[BaseDataCloudWriter], None] = None
    spark_config: Union[SparkConfig, None] = None

    def update(self, other: ClientConfig) -> ClientConfig:
        """Merge this ClientConfig with another, respecting force flags.

        Args:
            other: Another ClientConfig to merge with this one

        Returns:
            Self, with updated values from the other config based on force flags.
        """
        TypeVarT = TypeVar("TypeVarT", bound=ForceableConfig)

        def merge(
            config_a: Union[TypeVarT, None], config_b: Union[TypeVarT, None]
        ) -> Union[TypeVarT, None]:
            # A forced existing section always wins; otherwise the incoming
            # section (when present) replaces it; otherwise keep what we had.
            if config_a is not None and config_a.force:
                return config_a
            if config_b:
                return config_b
            return config_a

        self.reader_config = merge(self.reader_config, other.reader_config)
        self.writer_config = merge(self.writer_config, other.writer_config)
        self.spark_config = merge(self.spark_config, other.spark_config)
        return self

    def load(self, config_path: str) -> ClientConfig:
        """Load a config from a file and update this config with it.

        Args:
            config_path: The path to the config file

        Returns:
            Self, with updated values from the loaded config.
        """
        with open(config_path, "r") as f:
            config_data = yaml.safe_load(f)
        loaded_config = ClientConfig.model_validate(config_data)

        return self.update(loaded_config)
136
+
137
+
138
config = ClientConfig()
"""Global config object.

This is the object that makes config accessible globally and globally mutable.
"""


def _defaults() -> str:
    """Return the path to the packaged default config.yaml beside this module."""
    return os.path.join(os.path.dirname(__file__), DEFAULT_CONFIG_NAME)


# Populate the global config with packaged defaults at import time.
config.load(_defaults())
@@ -0,0 +1,15 @@
1
# Default local-development configuration for the custom-code SDK client.

# Reader used to satisfy client.read_dlo/read_dmo calls.
reader_config:
  type_config_name: QueryAPIDataCloudReader

# Writer used for client.write_to_* calls; per the Client docs, local writes
# are printed to the console instead of being sent to Data Cloud.
writer_config:
  type_config_name: PrintDataCloudWriter

# Local Spark session settings used when no reader/writer is supplied.
spark_config:
  app_name: DC Custom Code Python SDK Testing
  master: local[*]
  options:
    spark.driver.host: localhost
    spark.driver.bindAddress: 127.0.0.1
    spark.submit.deployMode: client
    spark.sql.execution.arrow.pyspark.enabled: 'true'
    spark.driver.extraJavaOptions: -Djava.security.manager=allow